diff --git a/.agents/skills/cuda-attention-kernel-patterns/SKILL.md b/.agents/skills/cuda-attention-kernel-patterns/SKILL.md new file mode 100644 index 0000000000000..5325a1bf22bdc --- /dev/null +++ b/.agents/skills/cuda-attention-kernel-patterns/SKILL.md @@ -0,0 +1,237 @@ +--- +name: cuda-attention-kernel-patterns +description: Patterns and pitfalls for the ONNX domain Attention operator (opset 23/24) CUDA implementation. Use when modifying the dispatch cascade in core/providers/cuda/llm/attention.cc, writing mask/bias CUDA kernels, debugging attention test routing, or adding features to the ONNX Attention op. NOT for contrib domain MultiHeadAttention/GroupQueryAttention. +--- + +# ONNX Domain Attention (Opset 23/24) CUDA Patterns + +Reusable knowledge from ONNX Attention CUDA development in ORT. + +> **Scope**: This skill covers the **ONNX domain** `Attention` operator (opset 23/24) +> implemented at `core/providers/cuda/llm/attention.cc`. This is **separate from** the +> contrib domain `MultiHeadAttention` / `GroupQueryAttention` at `contrib_ops/cuda/bert/`. +> They share some underlying kernels (CUTLASS FMHA, Flash Attention) and infrastructure +> (`attention_softmax.h`) but have **different dispatch logic, parameter structs, and eligibility checks**. +> +> - **Shared infrastructure**: CUTLASS FMHA kernel, Flash kernel, unified unfused kernel +> (`unfused_attention.cu`), `attention_softmax.h`, `attention_impl.cu` (contrib only) +> - **ONNX-specific**: Dispatch cascade in `attention.cc`, `ConvertAttnMaskToBias`, +> `mask_filter_value` cap, parameter bridge to contrib structs, `attention_mask_impl.cu` +> - **Contrib-specific**: Own dispatch in contrib MHA/GQA ops, uses `contrib::AttentionParameters` +> directly, has XQA kernel, past-present buffer sharing + +## 1. Runner Dispatch Cascade + +CUDA attention dispatches in priority order: **Flash → MEA (Memory Efficient) → Unified Unfused Attention**. 
+ +``` +// onnxruntime/core/providers/cuda/llm/attention.cc: ComputeInternal() +Flash eligible? → RunFlashAttention() + ↓ no +MEA eligible? → RunMemoryEfficientAttention() + ↓ no +Unified Unfused → RunUnfusedAttention() + (handles both MHA and GQA via reshape-Q trick) +``` + +**Flash eligibility**: fp16/bf16 only, SM≥8.0 (Ampere+), `head_size == v_head_size`, `head_size <= 256`, no `output_qk`, `attn_mask == nullptr`. Uses `mha_fwd` / `mha_fwd_kvcache`. + +**MEA eligibility**: SM50+/53+/80+ by dtype, `head_size <= 1024` and divisible by 8, no `output_qk`. Decode requires `head_size == v_head_size` (for `LaunchConcatNewToPastKV`). Bias stride must satisfy `total_sequence_length % 4 == 0`. GQA with FP32 is excluded (LaunchUngroup only has fp16/bf16 instantiations). Supports `softcap + attn_mask`: CUTLASS applies softcap before bias in kernel tiles, matching ONNX spec ordering (onnx/onnx#7865). + +**Unified Unfused Attention**: Always available as the final fallback. Handles both MHA (`num_heads == kv_num_heads`, group=1) and GQA (`num_heads != kv_num_heads`, group>1) via a reshape-Q trick with stride-based cuBLAS batched GEMM (no K/V head replication). Uses FP32 QK scratch for precision. Supports all features: +- softcap + attn_mask (spec-correct ordering) +- output_qk (kQK mode: copies raw QK before softcap/mask mutations) +- past_key + past_value with `head_size != v_head_size` (separate K/V concat) +- causal masking, nonpad_kv_seqlen, all dtypes (fp16/bf16/fp32) + +## 2. CUTLASS kLog2e Overflow + +CUTLASS `iterative_softmax` multiplies all attention scores by `kLog2e ≈ 1.4427` internally (for `exp2f` instead of `expf`). For float/bf16: + +``` +mask_filter_value = std::numeric_limits<float>::lowest() ≈ -3.40e+38 +-3.40e+38 × 1.4427 ≈ -4.91e+38 → overflows fp32 → -inf +``` + +When all values become `-inf`, CUTLASS's special-case path produces `s_prime=0` → `1/s_prime=inf` → `0 × inf = NaN`. + +**Fix**: Cap `mask_filter_value` to `-1.0e+30f` in `ConvertAttnMaskToBias`. 
This value is safe: `1e30 × 1.4427 ≈ 1.4e30 << FLT_MAX`, and `exp(-1e30) ≈ 0` (effectively masked). + +**fp16 is NOT affected**: `lowest() = -65504`, and `-65504 × 1.4427 ≈ -94500` stays within fp32 range. + +This cap is ONLY applied in MEA paths. The unfused path uses `lowest()` directly (its softmax subtracts max first, avoiding overflow). + +**Subtlety**: When bias is present (`kSupportsBias=true`), CUTLASS pre-applies `p.scale` to QK (line 858) and uses `scaling=1.0f` in the softmax loop (line 981). So the full `kLog2e` multiplier hits the bias-dominated values, which makes the overflow head_size-independent. Without bias, `scaling = p.scale * kLog2e = kLog2e/sqrt(head_size)`, which is much smaller. + +## 3. Bias Alignment + +CUTLASS FMHA requires the attention bias row stride to satisfy minimum alignment. The bias has shape `[B, H, S, T]` where `T = total_sequence_length` is the row stride. + +```cpp +constexpr int min_bias_align = 4; // elements, not bytes +if (parameters.total_sequence_length % min_bias_align != 0) { + mea_eligible = false; // fall through to unfused +} +``` + +**Impact on tests**: If a test uses `total_sequence_length` not divisible by 4 (e.g., past=5 + new=6 = 11), MEA is rejected and unfused handles it. To test MEA with bias, ensure `total_sequence_length % 4 == 0`. + +## 4. Softcap Ordering + +ONNX spec ordering (onnx/onnx#7865): `QK → scale → softcap → add mask/bias → softmax` + +- **MEA (CUTLASS)**: Fuses softcap before bias in kernel tile loop (`kernel_forward.h`). Matches spec ordering. +- **Flash**: Handles softcap natively in `mha_fwd`/`mha_fwd_kvcache` but rejects `attn_mask`, so ordering with mask is moot. +- **Unfused**: Handles spec-correct ordering in the fused softmax kernel: `QK → scale → softcap → add bias → softmax`. + +All three paths apply softcap BEFORE mask/bias. If softcap were applied after masking, a masked score would become `sc * tanh(-inf/sc) = -sc` (finite), leaking probability to masked positions. 
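
A minimal C++ sketch of why the ordering matters, assuming softcap is `cap * tanh(x / cap)` and the mask arrives as a large negative additive bias. The helper names (`Softcap`, `Softmax`, `SoftcapThenBias`, `BiasThenSoftcap`) are illustrative only and are not taken from the ORT sources:

```cpp
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <vector>

// softcap(x) = cap * tanh(x / cap); bounds every score to (-cap, cap).
inline float Softcap(float x, float cap) { return cap * std::tanh(x / cap); }

// Max-subtracting softmax over one row of logits.
inline std::vector<float> Softmax(std::vector<float> logits) {
  const float max_v = *std::max_element(logits.begin(), logits.end());
  float sum = 0.0f;
  for (float& v : logits) {
    v = std::exp(v - max_v);
    sum += v;
  }
  for (float& v : logits) v /= sum;
  return logits;
}

// Spec ordering: softcap first, then add the mask bias.
inline std::vector<float> SoftcapThenBias(const std::vector<float>& scores,
                                          const std::vector<float>& bias, float cap) {
  std::vector<float> logits(scores.size());
  for (std::size_t i = 0; i < scores.size(); ++i) {
    logits[i] = Softcap(scores[i], cap) + bias[i];
  }
  return Softmax(logits);
}

// Reversed ordering, shown only for contrast: add the bias, then softcap.
inline std::vector<float> BiasThenSoftcap(const std::vector<float>& scores,
                                          const std::vector<float>& bias, float cap) {
  std::vector<float> logits(scores.size());
  for (std::size_t i = 0; i < scores.size(); ++i) {
    logits[i] = Softcap(scores[i] + bias[i], cap);
  }
  return Softmax(logits);
}
```

With scores `{1, 2, 3}`, bias `{0, 0, -1e30}` and `cap = 2`, the spec ordering gives the masked position a probability of zero, while the reversed ordering squashes the masked logit to `-cap` and leaves it with roughly 2% of the probability mass.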
+ +The unfused path does: `QK → scale → softcap → add bias → softmax` (all fused in `UnfusedSoftmaxKernel`). + +## 5. Grid-Stride Loops for CUDA Kernels + +Always cap grid size to prevent exceeding `gridDim.x` limits, and use grid-stride loops for large workloads: + +```cpp +constexpr int64_t kMaxGridDimX = 65535; +int threads = static_cast<int>(std::min(static_cast<int64_t>(max_threads_per_block), total)); +int64_t blocks = (total + threads - 1) / threads; +unsigned int grid_size = static_cast<unsigned int>(std::min(blocks, kMaxGridDimX)); + +MyKernel<<<grid_size, threads>>>(...); + +// Inside the kernel: +for (int64_t idx = blockIdx.x * blockDim.x + threadIdx.x; + idx < total; + idx += static_cast<int64_t>(gridDim.x) * blockDim.x) { + // work +} +``` + +**Never** cast `int64_t` block count directly to `unsigned int` without capping; it silently truncates. + +Always call `CUDA_CALL(cudaGetLastError())` after kernel launches in standalone helper functions. This is the established pattern in the file (see `ConcatPastToPresent`, `PastPresentBufferShare`). + +## 6. Fully-Masked Batches + +All-false bool masks or `seqlens_k=0` produce NaN in CUTLASS MEA. + +**Additive-bias path** (bool mask converted to bias): Fixed by capping `mask_filter_value` to `-1e+30f` (see section 2). CUTLASS then naturally computes uniform softmax → mean(V). + +**Nonpad path** (`seqlens_k=0`): CUTLASS skips all K/V positions → `s_prime=0` → NaN. Fixed by `ZeroOutputForFullyMaskedBatches` kernel which zeros output for batches where `seqlens_k[b] == 0`. Note: this produces zeros, not mean(V); a cross-EP consistency TODO exists. + +**CPU/Unfused behavior**: `mask_filter_value = lowest()` (not `-inf`). All masked values are equal → `softmax(equal) = 1/N` → output = mean(V). This is the spec reference. + +## 7. 
Test Runner Targeting + +Use `ScopedEnvironmentVariables` to force specific CUDA runners: + +```cpp +// Force MEA (disable Flash) +ScopedEnvironmentVariables scoped_env({ + {"ORT_DISABLE_FLASH_ATTENTION", "1"}, +}); + +// Force Unfused (disable both Flash and MEA) +ScopedEnvironmentVariables scoped_env({ + {"ORT_DISABLE_FLASH_ATTENTION", "1"}, + {"ORT_DISABLE_MEMORY_EFFICIENT_ATTENTION", "1"}, +}); +``` + +**Always verify which runner a test actually hits.** A test designed for MEA may silently fall to unfused if: +- `total_sequence_length % 4 != 0` (bias alignment) +- `head_size != v_head_size` (decode path) +- fp32 dtype with GQA (LaunchUngroup fp16/bf16 only) +- fp32 dtype on SM < 80 + +Enable verbose logging to confirm: `LOGS_DEFAULT(VERBOSE) << "ONNX Attention: using ..."`. + +## 8. Cross-EP Consistency + +CPU is the spec reference implementation. CUDA outputs should match CPU for all valid inputs. + +- CPU uses `mask_filter_value = std::numeric_limits::lowest()` (finite, not `-inf`) +- CPU softmax: subtract-max-first → works correctly with extreme finite values +- CPU handles fully-masked batches naturally (uniform softmax → mean(V)) + +Run tests with `disable_cpu=false` to always validate against CPU. The C++ test framework (`RunTest4D`) supports `disable_cpu`, `disable_cuda`, `disable_dml` flags. + +## 9. 
File Locations + +### ONNX Domain (this op's code) + +| File | Purpose | +|------|---------| +| `core/providers/cuda/llm/attention.cc` | ONNX Attention CUDA dispatch: Flash/MEA/Unfused cascade, `ConvertAttnMaskToBias`, parameter setup | +| `core/providers/cuda/llm/attention_mask_impl.cu` | ONNX-specific mask/bias CUDA kernels: bool→bias, nonpad→seqlens_k, ZeroOutput, bias composition | +| `core/providers/cuda/llm/attention_mask_impl.h` | Declarations for ONNX mask/bias kernels | +| `core/providers/cpu/llm/attention.cc` | CPU reference implementation (ONNX domain) | +| `core/providers/cpu/llm/attention_helper.h` | ONNX parameter validation and shape computation | +| `test/providers/cpu/llm/attention_op_test.cc` | C++ attention tests (all EPs) | +| `test/python/transformers/test_onnx_attention/test_mha.py` | Python parity tests | +| `test/python/transformers/test_onnx_attention/common.py` | Python test utilities and reference `attention_ref()` | + +### Shared Infrastructure (used by both ONNX and contrib ops) + +| File | Purpose | +|------|---------| +| `contrib_ops/cuda/bert/unfused_attention.cu` | Unified unfused attention: QK GEMM (FP32), fused softmax kernel (scale+softcap+bias+causal), V GEMM. Handles MHA and GQA. | +| `contrib_ops/cuda/bert/unfused_attention.h` | `UnfusedAttentionParams`, `LaunchUnfusedAttention`, workspace size | +| `contrib_ops/cuda/bert/attention_impl.cu` | Legacy unfused `QkvToContext` (contrib MHA only). 
Also `ApplySoftcap`, `ConcatPastToPresent` | +| `contrib_ops/cuda/bert/attention_softmax.h` | CUDA softmax kernels (`ComputeSoftmax`, `ComputeSoftmaxWithRawMask`), used by legacy contrib path | +| `contrib_ops/cuda/bert/cutlass_fmha/` | CUTLASS FMHA (Memory Efficient Attention) kernels | +| `contrib_ops/cuda/bert/flash_attention/` | Flash Attention kernels | + +### Contrib Domain (separate ops, NOT covered by this skill) + +| File | Purpose | +|------|---------| +| `contrib_ops/cuda/bert/multihead_attention.cu` | Contrib `MultiHeadAttention`: own dispatch, uses `contrib::AttentionParameters` directly | +| `contrib_ops/cuda/bert/group_query_attention.cu` | Contrib `GroupQueryAttention`: has XQA kernel, past-present buffer sharing | + +## 10. Parameter Bridge (ONNX → Contrib) + +The ONNX Attention op uses `attention_helper::AttentionParameters` (in `core/providers/cpu/llm/attention_parameters.h`). The unified unfused kernel (`LaunchUnfusedAttention`) uses its own `UnfusedAttentionParams` struct populated directly from ONNX parameters in `RunUnfusedAttention`. + +The contrib `QkvToContext` function (used by contrib MHA, NOT by ONNX Attention) uses `contrib::AttentionParameters`. ONNX Attention does **not** bridge to `contrib::AttentionParameters`; it routes through the unified unfused kernel instead. + +## 11. Causal Alignment + +The ONNX spec defines two causal alignment modes based on where query positions sit in the full attention matrix: + +- **Upper-left**: `q_i` attends to `kv[0..i]`. Query positions start at 0 in the full matrix. +- **Lower-right**: `q_i` attends to `kv[0..kv_len - q_len + i]`. Query positions are at the end, so the last query sees every key. + +**ONNX spec rule**: `is_causal=1` always means upper-left in the full matrix. When `past_key` provides context, `past_sequence_length` shifts the query start position forward, so the resulting `[S_q × total_kv]` sub-matrix effectively has lower-right alignment. 
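
The shift can be sketched as a single visibility predicate. This is an illustrative C++ fragment under the assumption that a query row `q_i` sits at absolute position `past_len + q_i`; the names `CausalVisible` and `VisibleKeys` are hypothetical and do not appear in the ORT sources:

```cpp
// True if query row q_i may attend key column k_j under causal masking.
// past_len is the number of cached positions that precede the first query row.
inline bool CausalVisible(int q_i, int k_j, int past_len) {
  return k_j <= past_len + q_i;
}

// Number of keys (out of total_kv) that query row q_i can see.
inline int VisibleKeys(int q_i, int total_kv, int past_len) {
  int count = 0;
  for (int k_j = 0; k_j < total_kv; ++k_j) {
    count += CausalVisible(q_i, k_j, past_len) ? 1 : 0;
  }
  return count;
}
```

For `S_q = 2` and `total_kv = 5`, `past_len = 0` gives upper-left alignment (row 0 sees one key, row 1 sees two), while `past_len = total_kv - S_q = 3` gives lower-right alignment (row 0 sees four keys, row 1 sees all five).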
+ +### Per-kernel behavior + +| Kernel | Alignment | Mechanism | +|--------|-----------|-----------| +| **Flash** | Lower-right only | `is_causal` flag → `seqlen_k - seqlen_q` offset in kernel. No top-left option. | +| **MEA (CUTLASS)** | Both | `causal_from_top_left` flag in `MemoryEfficientAttentionParams`. `true` → `CausalFromTopLeft` (offset=0). `false` → `CausalFromBottomRight` (offset = num_keys - num_queries). | +| **Unfused** | Both | `past_kv_length` param. `0` → upper-left. `total_kv - S_q` → lower-right. | + +### Dispatch logic in attention.cc + +```cpp +// Flash cannot do upper-left → guarded by causal_cross_no_past +bool causal_cross_no_past = parameters.is_causal && + parameters.q_sequence_length != parameters.total_sequence_length && + parameters.past_sequence_length == 0; + +// Flash: skip when causal_cross_no_past (no top-left support) +// MEA: NOT skipped — handles it via causal_from_top_left = (past_sequence_length == 0) +// Unfused: always correct via past_kv_length = parameters.past_sequence_length +``` + +### When S_q == S_kv + +Upper-left and lower-right produce **identical** results when `S_q == S_kv` (the offset is 0 either way). The alignment distinction only matters for cross-attention shapes (`S_q != S_kv`). + +### TensorScatter decode (opset 24 external KV cache) + +TensorScatter manages KV cache externally — `past_key` is nullptr but K/V already contain the full sequence. Per the ONNX spec, `is_causal` with `S_q != S_kv` and no `past_key` means upper-left (q[0] sees only kv[0]), which is **not meaningful for decode**. + +**Correct pattern**: TensorScatter decode must use `is_causal=0` and rely on `nonpad_kv_seqlen` to bound the active KV range. Models using `is_causal=1` with TensorScatter decode have a spec-invalid combination. 
diff --git a/VERSION_NUMBER b/VERSION_NUMBER index 5ff8c4f5d2ad2..5db08bf2dc579 100644 --- a/VERSION_NUMBER +++ b/VERSION_NUMBER @@ -1 +1 @@ -1.26.0 +1.27.0 diff --git a/cmake/onnxruntime_python.cmake b/cmake/onnxruntime_python.cmake index 39985b23da3cc..494d5588c2d03 100644 --- a/cmake/onnxruntime_python.cmake +++ b/cmake/onnxruntime_python.cmake @@ -93,6 +93,12 @@ endif() if(HAS_CAST_FUNCTION_TYPE) target_compile_options(onnxruntime_pybind11_state PRIVATE "-Wno-cast-function-type") endif() +# pybind11 3.0 headers trigger -Wmaybe-uninitialized in GCC's flow analysis +# of property accessor lambdas. Suppress it for this target only. +# See https://github.com/microsoft/onnxruntime/issues/25681 +if(HAS_MAYBE_UNINITIALIZED) + target_compile_options(onnxruntime_pybind11_state PRIVATE "-Wno-maybe-uninitialized") +endif() # We export symbols using linker and the compiler does not know anything about it # There is a problem with classes that have pybind types as members. diff --git a/csharp/OnnxRuntime.CSharp.proj b/csharp/OnnxRuntime.CSharp.proj index 6779fd60bcd0a..9e96c3ca16105 100644 --- a/csharp/OnnxRuntime.CSharp.proj +++ b/csharp/OnnxRuntime.CSharp.proj @@ -50,8 +50,6 @@ CMake creates a target to this project - $(BuildDate) - $(BuildTime) $([System.DateTime]::UtcNow.ToString(yyyyMMdd)) $([System.DateTime]::UtcNow.ToString(hhmm)) diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md index 28b282b25f8f6..bc4602856b3bf 100644 --- a/docs/OperatorKernels.md +++ b/docs/OperatorKernels.md @@ -680,7 +680,8 @@ The **OpSet Version** column uses the following notation: |||14|**T** = tensor(double), tensor(float), tensor(float16)
**U** = tensor(double), tensor(float), tensor(float16)| |||[9, 13]|**T** = tensor(double), tensor(float), tensor(float16)| |||[7, 8]|**T** = tensor(double), tensor(float), tensor(float16)| -|Cast|*in* input:**T1**
*out* output:**T2**|23+|**T1** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float4e2m1), tensor(float8e4m3fn), tensor(float8e5m2), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T2** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float4e2m1), tensor(float8e4m3fn), tensor(float8e5m2), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| +|Cast|*in* input:**T1**
*out* output:**T2**|25+|**T1** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float4e2m1), tensor(float8e4m3fn), tensor(float8e5m2), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T2** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float4e2m1), tensor(float8e4m3fn), tensor(float8e5m2), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| +|||[23, 24]|**T1** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float4e2m1), tensor(float8e4m3fn), tensor(float8e5m2), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T2** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float4e2m1), tensor(float8e4m3fn), tensor(float8e5m2), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |||[21, 22]|**T1** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e5m2), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T2** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float4e2m1), tensor(float8e4m3fn), tensor(float8e5m2), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |||[19, 20]|**T1** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e5m2), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T2** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float4e2m1), tensor(float8e4m3fn), tensor(float8e5m2), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |||[13, 18]|**T1** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T2** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float4e2m1), tensor(float8e4m3fn), tensor(float8e5m2), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| @@ -882,7 +883,8 @@ The **OpSet Version** column uses the following notation: |||[13, 18]|**T1** = tensor(float)
**T2** = tensor(int8), tensor(uint8)| |||[10, 12]|**T1** = tensor(float)
**T2** = tensor(int8), tensor(uint8)| |RMSNormalization|*in* X:**T**
*in* scale:**V**
*out* Y:**V**|23+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)
**V** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)| -|RNN|*in* X:**T**
*in* W:**T**
*in* R:**T**
*in* B:**T**
*in* sequence_lens:**T1**
*in* initial_h:**T**
*out* Y:**T**
*out* Y_h:**T**|14+|**T** = tensor(double), tensor(float), tensor(float16)
**T1** = tensor(int32)| +|RNN|*in* X:**T**
*in* W:**T**
*in* R:**T**
*in* B:**T**
*in* sequence_lens:**T1**
*in* initial_h:**T**
*out* Y:**T**
*out* Y_h:**T**|22+|**T** = tensor(double), tensor(float), tensor(float16)
**T1** = tensor(int32)| +|||[14, 21]|**T** = tensor(double), tensor(float), tensor(float16)
**T1** = tensor(int32)| |||[7, 13]|**T** = tensor(double), tensor(float), tensor(float16)
**T1** = tensor(int32)| |RandomNormal|*out* output:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)| |RandomNormalLike|*in* input:**T1**
*out* output:**T2**|1+|**T1** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T2** = tensor(double), tensor(float), tensor(float16)| @@ -914,7 +916,8 @@ The **OpSet Version** column uses the following notation: |Relu|*in* X:**T**
*out* Y:**T**|14+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)| |||13|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)| |||[6, 12]|**T** = tensor(double), tensor(float), tensor(float16)| -|Reshape|*in* data:**T**
*in* shape:**tensor(int64)**
*out* reshaped:**T**

or

*in* data:**T**
*out* reshaped:**T**|23+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**shape** = tensor(int64)| +|Reshape|*in* data:**T**
*in* shape:**tensor(int64)**
*out* reshaped:**T**

or

*in* data:**T**
*out* reshaped:**T**|25+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**shape** = tensor(int64)| +|||[23, 24]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**shape** = tensor(int64)| |||[21, 22]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**shape** = tensor(int64)| |||[19, 20]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**shape** = tensor(int64)| |||[14, 18]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**shape** = tensor(int64)| diff --git a/docs/python/README.rst b/docs/python/README.rst index 0e03575236613..e8190c584fb62 100644 --- a/docs/python/README.rst +++ b/docs/python/README.rst @@ -8,6 +8,11 @@ For more information on ONNX Runtime, please see `aka.ms/onnxruntime { // & outputs); // ::SetOutputs(std::vector& outputs) { } template -inline void GraphImpl::AddInitializer(const std::string& name, Value& initializer, bool data_is_external) { - // Graph takes ownership of `initializer` - // On error the ownership is not transferred. +inline void GraphImpl::AddInitializer(const std::string& name, const Value& initializer, bool data_is_external) { + // Graph copies the OrtValue internally. Caller retains ownership of initializer. ThrowOnError(GetModelEditorApi().AddInitializerToGraph(this->p_, name.c_str(), initializer, data_is_external)); - initializer.release(); } template diff --git a/js/common/lib/version.ts b/js/common/lib/version.ts index 6957caf427f59..1c679620af4b1 100644 --- a/js/common/lib/version.ts +++ b/js/common/lib/version.ts @@ -4,4 +4,4 @@ // This file is generated by /js/scripts/update-version.ts // Do not modify file content manually. 
-export const version = '1.26.0'; +export const version = '1.27.0'; diff --git a/js/common/package-lock.json b/js/common/package-lock.json index 21a9a5e3aceba..0c63997a2b8fa 100644 --- a/js/common/package-lock.json +++ b/js/common/package-lock.json @@ -1,12 +1,12 @@ { "name": "onnxruntime-common", - "version": "1.26.0", + "version": "1.27.0", "lockfileVersion": 2, "requires": true, "packages": { "": { "name": "onnxruntime-common", - "version": "1.26.0", + "version": "1.27.0", "license": "MIT", "devDependencies": { "globby": "^15.0.0", diff --git a/js/common/package.json b/js/common/package.json index 6746115c97f76..c038faeb1b6cd 100644 --- a/js/common/package.json +++ b/js/common/package.json @@ -2,7 +2,7 @@ "license": "MIT", "type": "module", "name": "onnxruntime-common", - "version": "1.26.0", + "version": "1.27.0", "repository": { "url": "https://github.com/Microsoft/onnxruntime.git", "type": "git" diff --git a/js/node/lib/version.ts b/js/node/lib/version.ts index 6957caf427f59..1c679620af4b1 100644 --- a/js/node/lib/version.ts +++ b/js/node/lib/version.ts @@ -4,4 +4,4 @@ // This file is generated by /js/scripts/update-version.ts // Do not modify file content manually. 
-export const version = '1.26.0'; +export const version = '1.27.0'; diff --git a/js/node/package-lock.json b/js/node/package-lock.json index d7e9867590377..fc25230249bc8 100644 --- a/js/node/package-lock.json +++ b/js/node/package-lock.json @@ -1,12 +1,12 @@ { "name": "onnxruntime-node", - "version": "1.26.0", + "version": "1.27.0", "lockfileVersion": 2, "requires": true, "packages": { "": { "name": "onnxruntime-node", - "version": "1.26.0", + "version": "1.27.0", "hasInstallScript": true, "license": "MIT", "os": [ @@ -30,7 +30,7 @@ }, "../common": { "name": "onnxruntime-common", - "version": "1.26.0", + "version": "1.27.0", "license": "MIT", "devDependencies": { "globby": "^15.0.0", diff --git a/js/node/package.json b/js/node/package.json index 613d227b048fd..c0403665c1a90 100644 --- a/js/node/package.json +++ b/js/node/package.json @@ -11,7 +11,7 @@ 6 ] }, - "version": "1.26.0", + "version": "1.27.0", "dependencies": { "adm-zip": "^0.5.16", "global-agent": "^4.1.3", diff --git a/js/node/script/install-metadata-versions.js b/js/node/script/install-metadata-versions.js index f4238d1981b45..81c5671cac9a3 100644 --- a/js/node/script/install-metadata-versions.js +++ b/js/node/script/install-metadata-versions.js @@ -4,4 +4,4 @@ // This file is generated by /js/scripts/update-version.ts // Do not modify file content manually. 
-module.exports = { nuget: [{ feed: 'nuget', version: '1.26.0' }] }; +module.exports = { nuget: [{ feed: 'nuget', version: '1.27.0' }] }; diff --git a/js/react_native/e2e/package-lock.json b/js/react_native/e2e/package-lock.json index 73d6e2a65f274..907e9cf72b59c 100644 --- a/js/react_native/e2e/package-lock.json +++ b/js/react_native/e2e/package-lock.json @@ -4757,7 +4757,9 @@ "license": "MIT" }, "node_modules/brace-expansion": { - "version": "1.1.11", + "version": "1.1.13", + "resolved": "https://registry.npmjs.org/brace-expansion/-/brace-expansion-1.1.13.tgz", + "integrity": "sha512-9ZLprWS6EENmhEOpjCYW2c8VkmOvckIJZfkr7rBW6dObmfgJ/L1GpSYW5Hpo9lDz4D1+n0Ckz8rU7FwHDQiG/w==", "license": "MIT", "dependencies": { "balanced-match": "^1.0.0", @@ -5763,7 +5765,9 @@ } }, "node_modules/detox/node_modules/brace-expansion": { - "version": "2.0.1", + "version": "2.0.3", + "resolved": "https://registry.npmjs.org/brace-expansion/-/brace-expansion-2.0.3.tgz", + "integrity": "sha512-MCV/fYJEbqx68aE58kv2cA/kiky1G8vux3OR6/jbS+jIMe/6fJWa0DTzJU7dqijOWYwHi1t29FlfYI9uytqlpA==", "dev": true, "license": "MIT", "dependencies": { diff --git a/js/react_native/lib/version.ts b/js/react_native/lib/version.ts index 6957caf427f59..1c679620af4b1 100644 --- a/js/react_native/lib/version.ts +++ b/js/react_native/lib/version.ts @@ -4,4 +4,4 @@ // This file is generated by /js/scripts/update-version.ts // Do not modify file content manually. 
-export const version = '1.26.0'; +export const version = '1.27.0'; diff --git a/js/react_native/package-lock.json b/js/react_native/package-lock.json index 06e97afaac587..4410cc4816a66 100644 --- a/js/react_native/package-lock.json +++ b/js/react_native/package-lock.json @@ -1,12 +1,12 @@ { "name": "onnxruntime-react-native", - "version": "1.26.0", + "version": "1.27.0", "lockfileVersion": 3, "requires": true, "packages": { "": { "name": "onnxruntime-react-native", - "version": "1.26.0", + "version": "1.27.0", "license": "MIT", "dependencies": { "onnxruntime-common": "file:../common" @@ -30,7 +30,7 @@ }, "../common": { "name": "onnxruntime-common", - "version": "1.26.0", + "version": "1.27.0", "license": "MIT", "devDependencies": { "globby": "^15.0.0", diff --git a/js/react_native/package.json b/js/react_native/package.json index b518adf14b327..b25f85e9cf38b 100644 --- a/js/react_native/package.json +++ b/js/react_native/package.json @@ -37,7 +37,7 @@ "registry": "https://registry.npmjs.org/" }, "source": "lib/index", - "version": "1.26.0", + "version": "1.27.0", "main": "dist/commonjs/index", "homepage": "https://github.com/microsoft/onnxruntime/blob/main/js/react_native/README.md", "files": [ diff --git a/js/web/lib/version.ts b/js/web/lib/version.ts index 6957caf427f59..1c679620af4b1 100644 --- a/js/web/lib/version.ts +++ b/js/web/lib/version.ts @@ -4,4 +4,4 @@ // This file is generated by /js/scripts/update-version.ts // Do not modify file content manually. 
-export const version = '1.26.0'; +export const version = '1.27.0'; diff --git a/js/web/package-lock.json b/js/web/package-lock.json index 364254d35f58e..9b07364eb033a 100644 --- a/js/web/package-lock.json +++ b/js/web/package-lock.json @@ -1,12 +1,12 @@ { "name": "onnxruntime-web", - "version": "1.26.0", + "version": "1.27.0", "lockfileVersion": 2, "requires": true, "packages": { "": { "name": "onnxruntime-web", - "version": "1.26.0", + "version": "1.27.0", "license": "MIT", "dependencies": { "flatbuffers": "^25.1.24", @@ -50,7 +50,7 @@ }, "../common": { "name": "onnxruntime-common", - "version": "1.26.0", + "version": "1.27.0", "license": "MIT", "devDependencies": { "globby": "^15.0.0", diff --git a/js/web/package.json b/js/web/package.json index 2b3db3ffc7036..dfbb96c9d0370 100644 --- a/js/web/package.json +++ b/js/web/package.json @@ -7,7 +7,7 @@ "type": "git" }, "author": "fs-eire", - "version": "1.26.0", + "version": "1.27.0", "jsdelivr": "dist/ort.min.js", "dependencies": { "flatbuffers": "^25.1.24", diff --git a/js/web/test/e2e/exports/testcases/vite-default/package-lock.json b/js/web/test/e2e/exports/testcases/vite-default/package-lock.json index d6d1e383641c2..2cfbd6ed4c92a 100644 --- a/js/web/test/e2e/exports/testcases/vite-default/package-lock.json +++ b/js/web/test/e2e/exports/testcases/vite-default/package-lock.json @@ -1068,9 +1068,9 @@ } }, "node_modules/nanoid": { - "version": "3.3.8", - "resolved": "https://registry.npmjs.org/nanoid/-/nanoid-3.3.8.tgz", - "integrity": "sha512-WNLf5Sd8oZxOm+TzppcYk8gVOgP+l58xNy58D0nbUnOxOWRWvlcCV4kUF7ltmI6PsrLl/BgKEyS4mqsGChFN0w==", + "version": "3.3.12", + "resolved": "https://registry.npmjs.org/nanoid/-/nanoid-3.3.12.tgz", + "integrity": "sha512-ZB9RH/39qpq5Vu6Y+NmUaFhQR6pp+M2Xt76XBnEwDaGcVAqhlvxrl3B2bKS5D3NH3QR76v3aSrKaF/Kiy7lEtQ==", "funding": [ { "type": "github", @@ -1105,9 +1105,9 @@ } }, "node_modules/postcss": { - "version": "8.5.3", - "resolved": "https://registry.npmjs.org/postcss/-/postcss-8.5.3.tgz", - 
"integrity": "sha512-dle9A3yYxlBSrt8Fu+IpjGT8SY8hN0mlaA6GY8t0P5PjIOZemULz/E2Bnm/2dcUOena75OTNkHI76uZBNUUq3A==", + "version": "8.5.13", + "resolved": "https://registry.npmjs.org/postcss/-/postcss-8.5.13.tgz", + "integrity": "sha512-qif0+jGGZoLWdHey3UFHHWP0H7Gbmsk8T5VEqyYFbWqPr1XqvLGBbk/sl8V5exGmcYJklJOhOQq1pV9IcsiFag==", "funding": [ { "type": "opencollective", @@ -1124,7 +1124,7 @@ ], "license": "MIT", "dependencies": { - "nanoid": "^3.3.8", + "nanoid": "^3.3.11", "picocolors": "^1.1.1", "source-map-js": "^1.2.1" }, diff --git a/onnxruntime/__init__.py b/onnxruntime/__init__.py index 9f3349a163e91..16b33baaf1b7f 100644 --- a/onnxruntime/__init__.py +++ b/onnxruntime/__init__.py @@ -10,7 +10,7 @@ import contextlib -__version__ = "1.26.0" +__version__ = "1.27.0" __author__ = "Microsoft" # we need to do device version validation (for example to check Cuda version for an onnxruntime-training package). diff --git a/onnxruntime/contrib_ops/cpu/bert/attention_parameters.h b/onnxruntime/contrib_ops/cpu/bert/attention_parameters.h index f316a0dfdf91c..5b7624d11c6fd 100644 --- a/onnxruntime/contrib_ops/cpu/bert/attention_parameters.h +++ b/onnxruntime/contrib_ops/cpu/bert/attention_parameters.h @@ -33,6 +33,7 @@ struct AttentionParameters { bool broadcast_attn_bias_dim_1 = false; float mask_filter_value = 0.0f; float scale = 0.0f; + float softcap = 0.0f; bool use_tf32 = false; bool is_output_bnsh = false; // whether the output format is BNSH AttentionMaskType mask_type = AttentionMaskType::MASK_NONE; diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_data.h b/onnxruntime/contrib_ops/cuda/bert/attention_data.h index 98f92b79e6ec6..60f2d05446da1 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention_data.h +++ b/onnxruntime/contrib_ops/cuda/bert/attention_data.h @@ -205,11 +205,11 @@ struct GroupQueryAttentionData { void* xqa_buffer = nullptr; size_t xqa_buffer_bytes = 0; - // Unfused fallback buffers (see LaunchGqaUnfusedAttention in gqa_unfused_attention.h): + // 
Unfused fallback buffers (see LaunchUnfusedAttention in unfused_attention.h): // unfused_q_bnsh : [B, N_q, S_q, H] (Q transposed from BSNH to BNSH) // unfused_y_bnsh : [B, N_q, S_q, H_v] (output BNSH, transposed to BSNH before leaving op) // unfused_workspace: FP32 QK scratch + T softmax scratch (sized by - // GetGqaUnfusedAttentionWorkspaceSize) + // GetUnfusedAttentionWorkspaceSize) T* unfused_q_bnsh = nullptr; T* unfused_y_bnsh = nullptr; void* unfused_workspace = nullptr; diff --git a/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/fmha_launch_template.h b/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/fmha_launch_template.h index 29bb4fba6a09a..aedb370d38367 100644 --- a/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/fmha_launch_template.h +++ b/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/fmha_launch_template.h @@ -176,7 +176,14 @@ void LaunchCutlassFmha(const MemoryEfficientAttentionParams& params) { p.num_keys = params.kv_sequence_length; if (params.causal) { - p.custom_mask_type = Attention::CausalFromBottomRight; + // ONNX spec: is_causal means upper-left alignment (q_i attends to kv[0..i]). + // When past_sequence_length > 0 (decode with KV cache), positions shift → lower-right. + // causal_from_top_left=true: past_seq==0, use CausalFromTopLeft (offset=0). + // causal_from_top_left=false: past_seq>0 or S_q==S_kv, use CausalFromBottomRight + // (offset = num_keys - num_queries, which is 0 when square). + p.custom_mask_type = params.causal_from_top_left + ? 
Attention::CausalFromTopLeft + : Attention::CausalFromBottomRight; } // We use max_sequence_length to calculate KV stride diff --git a/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/memory_efficient_attention.h b/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/memory_efficient_attention.h index ace598489a226..a961be051a16a 100644 --- a/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/memory_efficient_attention.h +++ b/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/memory_efficient_attention.h @@ -13,6 +13,13 @@ namespace cuda { constexpr int kEfficientAttentionMaxHeadSize = 1024; +// CUTLASS online softmax multiplies attention scores by kLog2e (≈1.4427). +// For float/bf16, |lowest() × kLog2e| > FLT_MAX, overflowing to -inf and +// causing s_prime=0 → NaN for fully-masked batches. Cap to prevent this. +// -1e+30 is safe: 1e30 × 1.4427 ≈ 1.4e30 << FLT_MAX ≈ 3.4e38, and +// exp(-1e30) ≈ 0 (effectively masked). For fp16 lowest()=-65504 > -1e30, no-op. +constexpr float kCutlassSafeMaskFilterValue = -1.0e+30f; + struct MemoryEfficientAttentionParams { int32_t sm = 50; bool is_half = false; @@ -27,6 +34,12 @@ struct MemoryEfficientAttentionParams { int32_t v_head_size = 0; int32_t local_window_size = -1; bool causal = false; + // When true, causal masking uses upper-left alignment (q_i attends to kv[0..i]). + // When false (default), uses lower-right alignment (q_i attends to kv[0..kv_len-q_len+i]). + // ONNX Attention spec requires upper-left for cross-attention without past (S_q != S_kv, past=0). + // Lower-right is correct for decode with KV cache (past > 0). + // For square matrices (S_q == S_kv), both alignments produce identical results. 
+ bool causal_from_top_left = false; bool use_smooth_softmax = false; bool broadcast_attn_bias_dim_0 = false; bool broadcast_attn_bias_dim_1 = false; diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc b/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc index 5f21f3cd34e8f..dfecc2b810a04 100644 --- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc +++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc @@ -14,7 +14,7 @@ #include "contrib_ops/cuda/bert/cutlass_fmha/memory_efficient_attention.h" #include "contrib_ops/cuda/bert/flash_attention/flash_api.h" #include "contrib_ops/cuda/bert/xqa/xqa_loader.h" -#include "contrib_ops/cuda/bert/gqa_unfused_attention.h" +#include "contrib_ops/cuda/bert/unfused_attention.h" #include "contrib_ops/cuda/utils/dump_cuda_tensor.h" #include "contrib_ops/cpu/utils/debug_macros.h" @@ -513,7 +513,7 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* context) cons // GQA-capable unfused fallback (issue #28195). // Activates when Flash / MEA / XQA are all ineligible and KV is not quantized. // Supports any head_size (FP32 QK accumulation), GQA, sliding window, softcap. - // See LaunchGqaUnfusedAttention in contrib_ops/cuda/bert/gqa_unfused_attention.h. + // See LaunchUnfusedAttention in contrib_ops/cuda/bert/unfused_attention.h. 
// --------------------------------------------------------------------- IAllocatorUniquePtr unfused_scratch; if (!data.use_xqa && !data.use_flash_attention && !data.use_memory_efficient_attention && @@ -538,7 +538,7 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* context) cons const SafeInt q_bnsh_bytes = align(SafeInt(B) * N_q * S_q * H * sizeof(T)); const SafeInt y_bnsh_bytes = align(SafeInt(B) * N_q * S_q * H_v * sizeof(T)); const SafeInt ws_bytes = SafeInt( - onnxruntime::contrib::cuda::GetGqaUnfusedAttentionWorkspaceSize( + onnxruntime::contrib::cuda::GetUnfusedAttentionWorkspaceSize( static_cast(B), static_cast(N_q), static_cast(S_q), static_cast(S_kv))); const SafeInt workspace_offset = q_bnsh_bytes + y_bnsh_bytes; diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu index ebb6a0b0da215..70c58e6b8f764 100644 --- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu @@ -38,7 +38,7 @@ limitations under the License. #include "contrib_ops/cuda/bert/bert_padding.h" #include "contrib_ops/cuda/bert/cutlass_fmha/memory_efficient_attention.h" #include "contrib_ops/cuda/bert/flash_attention/flash_api.h" -#include "contrib_ops/cuda/bert/gqa_unfused_attention.h" +#include "contrib_ops/cuda/bert/unfused_attention.h" #include "contrib_ops/cuda/bert/group_query_attention_impl.h" #include "contrib_ops/cpu/bert/attention_common.h" #include "contrib_ops/cuda/bert/group_query_attention_qkv.cuh" @@ -1095,7 +1095,7 @@ Status UnfusedGqaAttention( } // Step 3: run unfused attention with FP32 QK accumulation. 
- GqaUnfusedAttentionParams p; + UnfusedAttentionParams p; p.batch_size = batch_size; p.num_heads = num_heads; p.kv_num_heads = kv_num_heads; @@ -1113,18 +1113,20 @@ Status UnfusedGqaAttention( p.broadcast_attn_bias_dim_1 = false; p.is_causal = parameters.is_unidirectional; p.local_window_size = parameters.local_window_size; // -1 disables + p.past_kv_length = parameters.total_sequence_length - parameters.sequence_length; p.scale = scale; p.softcap = parameters.softcap; p.seqlens_k = data.total_seq_lens; - ORT_RETURN_IF_ERROR((LaunchGqaUnfusedAttention( + ORT_RETURN_IF_ERROR((LaunchUnfusedAttention( device_prop, cublas, stream, p, data.unfused_q_bnsh, reinterpret_cast(data.present_key), reinterpret_cast(data.present_value), /*attn_bias=*/nullptr, data.unfused_y_bnsh, - data.unfused_workspace))); + data.unfused_workspace, + /*output_qk=*/nullptr))); // Step 4: transpose output BNSH → BSNH into data.output. // Use p.v_head_size (== head_size per ORT_ENFORCE) for semantic correctness. diff --git a/onnxruntime/contrib_ops/cuda/bert/gqa_unfused_attention.cu b/onnxruntime/contrib_ops/cuda/bert/unfused_attention.cu similarity index 77% rename from onnxruntime/contrib_ops/cuda/bert/gqa_unfused_attention.cu rename to onnxruntime/contrib_ops/cuda/bert/unfused_attention.cu index 8aac549aeba01..a0c9d4666cae3 100644 --- a/onnxruntime/contrib_ops/cuda/bert/gqa_unfused_attention.cu +++ b/onnxruntime/contrib_ops/cuda/bert/unfused_attention.cu @@ -1,8 +1,9 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -// GQA-capable unfused CUDA attention kernel. See header for contract. +// Unified unfused CUDA attention kernel. See header for contract. 
+#include #include #include "core/providers/cuda/cu_inc/cub.cuh" #include @@ -13,7 +14,7 @@ #include "core/providers/cuda/cuda_type_conversion.h" #include "core/providers/cuda/shared_inc/cuda_call.h" #include "core/providers/cuda/shared_inc/fpgeneric.h" -#include "contrib_ops/cuda/bert/gqa_unfused_attention.h" +#include "contrib_ops/cuda/bert/unfused_attention.h" using onnxruntime::cuda::OrtToCudaType; @@ -38,10 +39,37 @@ __device__ __forceinline__ float ToFloat<__half>(__half v) { return __half2float template <> __device__ __forceinline__ float ToFloat<__nv_bfloat16>(__nv_bfloat16 v) { return __bfloat162float(v); } +// Device helper: convert float to T. +template +__device__ __forceinline__ T FromFloat(float v); +template <> +__device__ __forceinline__ float FromFloat(float v) { return v; } +template <> +__device__ __forceinline__ __half FromFloat<__half>(float v) { return __float2half(v); } +template <> +__device__ __forceinline__ __nv_bfloat16 FromFloat<__nv_bfloat16>(float v) { return __float2bfloat16(v); } + inline size_t QkElementCount(int batch_size, int num_heads, int q_seq, int total_kv) { return SafeInt(batch_size) * num_heads * q_seq * total_kv; } +// --------------------------------------------------------------------------- +// CopyQK kernel: copies FP32 QK scratch to T output with scale applied. +// output_qk[i] = T(qk_fp32[i] * scale) for i in [0, total_elements). +// --------------------------------------------------------------------------- +template +__global__ void ScaledCopyQkKernel( + const float* __restrict__ qk_fp32, + T* __restrict__ output_qk, + const float scale, + const int64_t total_elements) { + for (int64_t idx = static_cast(blockIdx.x) * TPB + threadIdx.x; + idx < total_elements; + idx += static_cast(gridDim.x) * TPB) { + output_qk[idx] = FromFloat(qk_fp32[idx] * scale); + } +} + // --------------------------------------------------------------------------- // Softmax kernel: reads FP32 QK scores, writes T softmax output. 
// @@ -56,7 +84,7 @@ inline size_t QkElementCount(int batch_size, int num_heads, int q_seq, int total // total_kv_length. Handles fully-masked rows by emitting zeros (no NaN). // --------------------------------------------------------------------------- template -__global__ void GqaUnfusedSoftmaxKernel( +__global__ void UnfusedSoftmaxKernel( const int q_sequence_length, const int total_kv_length, const int num_heads, // N_q @@ -68,6 +96,7 @@ __global__ void GqaUnfusedSoftmaxKernel( const int* __restrict__ seqlens_k, const bool is_causal, const int local_window_size, + const int past_kv_length, const float scale, const float softcap, T* __restrict__ softmax_out) { @@ -82,12 +111,13 @@ __global__ void GqaUnfusedSoftmaxKernel( if (v < kv_end) kv_end = v; if (v < 0) kv_end = 0; } - // past (number of KV positions before the current query tokens) must be - // per-batch when seqlens_k is provided, since different batches can have - // different amounts of valid past context. Using the global total_kv_length - // would over-estimate past for short batches and shift the sliding-window - // start past kv_end, producing an all-masked (zero) row. - const int past = kv_end - q_sequence_length; + // past_kv_length is the number of KV positions that precede the current query + // tokens. For upper-left causal alignment (ONNX Attention with no past), + // this is 0. For lower-right alignment (decode with past), this is + // total_kv_length - q_sequence_length. + // When seqlens_k varies per batch (GQA sliding window), derive per-batch + // so the window cutoff stays within the valid range for shorter batches. + const int past = (seqlens_k != nullptr) ? 
(kv_end - q_sequence_length) : past_kv_length; const int q_pos = past + q_in_head; int end = kv_end; @@ -191,16 +221,16 @@ __global__ void GqaUnfusedSoftmaxKernel( } template -void LaunchGqaUnfusedSoftmax( +void LaunchUnfusedSoftmax( cudaStream_t stream, - const GqaUnfusedAttentionParams& params, + const UnfusedAttentionParams& params, const float* qk_in, const T* attn_bias, T* softmax_out) { const dim3 grid(params.num_heads * params.q_sequence_length, params.batch_size, 1); const bool has_bias = (attn_bias != nullptr); constexpr int TPB = 256; - GqaUnfusedSoftmaxKernel<<>>( + UnfusedSoftmaxKernel<<>>( params.q_sequence_length, params.total_kv_length, params.num_heads, @@ -212,6 +242,7 @@ void LaunchGqaUnfusedSoftmax( params.seqlens_k, params.is_causal, params.local_window_size, + params.past_kv_length, params.scale, params.softcap, softmax_out); @@ -250,7 +281,7 @@ template common::Status LaunchQkGemmFp32( const cudaDeviceProp& /*device_prop*/, cublasHandle_t cublas, - const GqaUnfusedAttentionParams& params, + const UnfusedAttentionParams& params, const T* query, const T* key, float* qk_out) { @@ -292,7 +323,7 @@ common::Status LaunchQkGemmFp32( CUBLAS_GEMM_DEFAULT); if (status != CUBLAS_STATUS_SUCCESS) { - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "GqaUnfusedAttention QK GEMM failed: ", status); + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "UnfusedAttention QK GEMM failed: ", status); } return common::Status::OK(); } @@ -312,7 +343,7 @@ common::Status LaunchQkGemmFp32( template common::Status LaunchAttnVGemm( cublasHandle_t cublas, - const GqaUnfusedAttentionParams& params, + const UnfusedAttentionParams& params, const T* softmax_out, const T* value, T* output) { @@ -347,7 +378,7 @@ common::Status LaunchAttnVGemm( CUBLAS_COMPUTE_32F, CUBLAS_GEMM_DEFAULT); if (status != CUBLAS_STATUS_SUCCESS) { - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "GqaUnfusedAttention AV GEMM failed: ", status); + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "UnfusedAttention AV GEMM failed: 
", status); } return common::Status::OK(); } @@ -357,10 +388,10 @@ common::Status LaunchAttnVGemm( // --------------------------------------------------------------------------- // Public API // --------------------------------------------------------------------------- -size_t GetGqaUnfusedAttentionWorkspaceSize(int batch_size, - int num_heads, - int q_sequence_length, - int total_kv_length) { +size_t GetUnfusedAttentionWorkspaceSize(int batch_size, + int num_heads, + int q_sequence_length, + int total_kv_length) { const size_t elems = QkElementCount(batch_size, num_heads, q_sequence_length, total_kv_length); // FP32 QK scratch + T softmax scratch. We always allocate sizeof(float) per // element for the T scratch too (upper bound); caller can cast appropriately. @@ -370,26 +401,27 @@ size_t GetGqaUnfusedAttentionWorkspaceSize(int batch_size, } template -common::Status LaunchGqaUnfusedAttention( +common::Status LaunchUnfusedAttention( const cudaDeviceProp& device_prop, cublasHandle_t cublas, cudaStream_t stream, - const GqaUnfusedAttentionParams& params, + const UnfusedAttentionParams& params, const T* query, const T* key, const T* value, const T* attn_bias, T* output, - void* workspace) { + void* workspace, + T* output_qk) { ORT_RETURN_IF_NOT(params.batch_size > 0 && params.num_heads > 0 && params.kv_num_heads > 0 && params.head_size > 0 && params.v_head_size > 0 && params.q_sequence_length > 0 && params.total_kv_length > 0 && params.max_kv_length >= params.total_kv_length, - "GqaUnfusedAttention: invalid params."); + "UnfusedAttention: invalid params."); ORT_RETURN_IF_NOT(params.num_heads % params.kv_num_heads == 0, - "GqaUnfusedAttention: num_heads (", params.num_heads, + "UnfusedAttention: num_heads (", params.num_heads, ") must be a multiple of kv_num_heads (", params.kv_num_heads, ")."); - ORT_RETURN_IF(workspace == nullptr, "GqaUnfusedAttention: workspace is null."); + ORT_RETURN_IF(workspace == nullptr, "UnfusedAttention: workspace is null."); const size_t 
elems = QkElementCount(params.batch_size, params.num_heads, params.q_sequence_length, params.total_kv_length); @@ -400,7 +432,21 @@ common::Status LaunchGqaUnfusedAttention( ORT_RETURN_IF_ERROR((LaunchQkGemmFp32(device_prop, cublas, params, query, key, qk_fp32))); - LaunchGqaUnfusedSoftmax(stream, params, qk_fp32, attn_bias, softmax_T); + // Copy scaled QK to output_qk BEFORE softcap/mask/softmax. + // output_qk[i] = T(qk_fp32[i] * scale) — this is "kQK" mode (scale * Q @ K^T). + // Note: When seqlens_k is provided, positions [seqlens_k[b], total_kv) in output_qk + // may contain stale KV cache data. Consumers of output_qk should only read positions + // [0, seqlens_k[b]) for batch b. + if (output_qk != nullptr) { + const int64_t total = static_cast(elems); + constexpr int kTPB = 256; + constexpr int kMaxBlocks = 65535; + const int blocks = static_cast(std::min(static_cast(kMaxBlocks), (total + kTPB - 1) / kTPB)); + ScaledCopyQkKernel<<>>(qk_fp32, output_qk, params.scale, total); + CUDA_RETURN_IF_ERROR(cudaGetLastError()); + } + + LaunchUnfusedSoftmax(stream, params, qk_fp32, attn_bias, softmax_T); CUDA_RETURN_IF_ERROR(cudaGetLastError()); ORT_RETURN_IF_ERROR((LaunchAttnVGemm(cublas, params, softmax_T, value, output))); @@ -409,18 +455,18 @@ common::Status LaunchGqaUnfusedAttention( } // Explicit template instantiations. 
-template common::Status LaunchGqaUnfusedAttention<__half>( +template common::Status LaunchUnfusedAttention<__half>( const cudaDeviceProp&, cublasHandle_t, cudaStream_t, - const GqaUnfusedAttentionParams&, const __half*, const __half*, const __half*, - const __half*, __half*, void*); -template common::Status LaunchGqaUnfusedAttention<__nv_bfloat16>( + const UnfusedAttentionParams&, const __half*, const __half*, const __half*, + const __half*, __half*, void*, __half*); +template common::Status LaunchUnfusedAttention<__nv_bfloat16>( const cudaDeviceProp&, cublasHandle_t, cudaStream_t, - const GqaUnfusedAttentionParams&, const __nv_bfloat16*, const __nv_bfloat16*, - const __nv_bfloat16*, const __nv_bfloat16*, __nv_bfloat16*, void*); -template common::Status LaunchGqaUnfusedAttention( + const UnfusedAttentionParams&, const __nv_bfloat16*, const __nv_bfloat16*, + const __nv_bfloat16*, const __nv_bfloat16*, __nv_bfloat16*, void*, __nv_bfloat16*); +template common::Status LaunchUnfusedAttention( const cudaDeviceProp&, cublasHandle_t, cudaStream_t, - const GqaUnfusedAttentionParams&, const float*, const float*, const float*, - const float*, float*, void*); + const UnfusedAttentionParams&, const float*, const float*, const float*, + const float*, float*, void*, float*); } // namespace cuda } // namespace contrib diff --git a/onnxruntime/contrib_ops/cuda/bert/gqa_unfused_attention.h b/onnxruntime/contrib_ops/cuda/bert/unfused_attention.h similarity index 77% rename from onnxruntime/contrib_ops/cuda/bert/gqa_unfused_attention.h rename to onnxruntime/contrib_ops/cuda/bert/unfused_attention.h index 84d645cd2b349..8fb3a18ac7570 100644 --- a/onnxruntime/contrib_ops/cuda/bert/gqa_unfused_attention.h +++ b/onnxruntime/contrib_ops/cuda/bert/unfused_attention.h @@ -13,7 +13,7 @@ namespace contrib { namespace cuda { // ============================================================================ -// GQA Unfused Attention (CUDA fallback for large head_size / fp16 overflow) +// Unified 
Unfused Attention (CUDA fallback for large head_size / fp16 overflow) // ============================================================================ // // Purpose: @@ -38,18 +38,20 @@ namespace cuda { // - scale is applied to raw QK (before softcap / bias). // - softcap (> 0) is applied after scale: x = softcap * tanh(x / softcap). // - attn_bias (if non-null) is added after softcap (additive mask). -// - causal: k > (past + q) is -inf where past = total_kv - S_q. +// - causal: k > (past_kv_length + q) is -inf. +// When past_kv_length=0 (no past), gives upper-left alignment: q_i attends to kv[0..i]. +// When past_kv_length=total_kv-S_q (decode with past), gives lower-right alignment. // - local_window_size (>= 0): k < (past + q) - local_window_size is -inf. // local_window_size == -1 disables the sliding-window mask. // // The new kernel is suitable only as a fallback when Flash / MEA are ineligible -// (head_size > 256, past_key present with mask, GQA with MHA-only unfused, etc). +// (head_size > 256, past_key present with mask, etc). // The QK GEMM runs with CUBLAS_COMPUTE_32F and writes a FP32 scratch to avoid // fp16 overflow. // // ============================================================================ -struct GqaUnfusedAttentionParams { +struct UnfusedAttentionParams { int batch_size = 0; int num_heads = 0; // N_q int kv_num_heads = 0; // N_kv (num_heads % kv_num_heads == 0) @@ -68,6 +70,7 @@ struct GqaUnfusedAttentionParams { bool is_causal = false; int local_window_size = -1; // -1 disables sliding window + int past_kv_length = 0; // number of past KV positions (for causal alignment) float scale = 1.0f; float softcap = 0.0f; // 0 disables @@ -77,27 +80,30 @@ struct GqaUnfusedAttentionParams { }; // Returns required scratch size in bytes. Caller must allocate -// GetGqaUnfusedAttentionWorkspaceSize(...) bytes and pass as workspace. 
-size_t GetGqaUnfusedAttentionWorkspaceSize(int batch_size, - int num_heads, - int q_sequence_length, - int total_kv_length); +// GetUnfusedAttentionWorkspaceSize(...) bytes and pass as workspace. +size_t GetUnfusedAttentionWorkspaceSize(int batch_size, + int num_heads, + int q_sequence_length, + int total_kv_length); // Compute: Y = softmax(scale * Q * K^T [softcap, causal, window, bias, seqlens_k]) * V. // All pointers are on device. Q/K/V/output are in type T (fp16/bf16/float). // attn_bias (if present) is in type T. +// output_qk (optional): when non-null, writes scale * Q @ K^T (FP32→T) before softcap/mask/softmax. +// Shape: [B, N_q, S_q, total_kv]. Caller allocates. template -common::Status LaunchGqaUnfusedAttention( +common::Status LaunchUnfusedAttention( const cudaDeviceProp& device_prop, cublasHandle_t cublas, cudaStream_t stream, - const GqaUnfusedAttentionParams& params, + const UnfusedAttentionParams& params, const T* query, const T* key, const T* value, const T* attn_bias, T* output, - void* workspace); + void* workspace, + T* output_qk = nullptr); } // namespace cuda } // namespace contrib diff --git a/onnxruntime/contrib_ops/webgpu/bert/attention.cc b/onnxruntime/contrib_ops/webgpu/bert/attention.cc index 755bd0c60452f..95a5c7c17bc3a 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/attention.cc +++ b/onnxruntime/contrib_ops/webgpu/bert/attention.cc @@ -726,7 +726,7 @@ Status Attention::ComputeInternal(onnxruntime::webgpu::ComputeContext& context) parameters.qkv_format_ = Q_K_V_BSNH; // Check if we can use flash attention - if (CanApplyFlashAttention(nullptr, parameters, context)) { + if (CanApplyFlashAttention(parameters, context)) { // FlashAttention supports Q_K_V_BSNH format directly return ApplyFlashAttention(&Q_bsd, &K_bsd, &V_bsd, attention_bias, output, nullptr, nullptr, nullptr, nullptr, parameters, context, nullptr); diff --git a/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc b/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc 
index 58c7376895661..c288a82994e98 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc +++ b/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc @@ -574,11 +574,9 @@ Status ApplyFlashAttention(const Tensor* Q, const Tensor* K, const Tensor* V, co return Status::OK(); } -bool CanApplyFlashAttention(const Tensor* bias, - const WebgpuAttentionParameters& parameters, onnxruntime::webgpu::ComputeContext& context) { +bool CanApplyFlashAttention(const WebgpuAttentionParameters& parameters, onnxruntime::webgpu::ComputeContext& context) { return !parameters.is_packed_qkv_ && parameters.head_size_ == parameters.v_head_size_ && - bias == nullptr && context.HasFeature(wgpu::FeatureName::Subgroups) && ((context.AdapterInfo().vendor == std::string_view{"qualcomm"} && parameters.head_size_ % 8 == 0) || parameters.head_size_ % 4 == 0); } diff --git a/onnxruntime/contrib_ops/webgpu/bert/flash_attention.h b/onnxruntime/contrib_ops/webgpu/bert/flash_attention.h index fc2843f6ea908..980ddc3a5373b 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/flash_attention.h +++ b/onnxruntime/contrib_ops/webgpu/bert/flash_attention.h @@ -191,8 +191,7 @@ Status ApplyFlashAttention(const Tensor* Q, const Tensor* K, const Tensor* V, co const WebgpuAttentionParameters& parameters, onnxruntime::webgpu::ComputeContext& context, const Tensor* seqlen_k = nullptr, const Tensor* cos_cache = nullptr, const Tensor* sin_cache = nullptr, const Tensor* head_sink = nullptr); -bool CanApplyFlashAttention(const Tensor* bias, - const WebgpuAttentionParameters& parameters, onnxruntime::webgpu::ComputeContext& context); +bool CanApplyFlashAttention(const WebgpuAttentionParameters& parameters, onnxruntime::webgpu::ComputeContext& context); // Split packed QKV with Q/K rotary embedding and copy KV cache fusion Status RunSplitPackedQKVWithRotaryEmbeddingAndCopyKV(onnxruntime::webgpu::ComputeContext& context, diff --git a/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc 
b/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc index fd72f751ee810..cdf88c2f225e8 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc +++ b/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc @@ -257,7 +257,7 @@ Status GroupQueryAttention::ComputeInternal(onnxruntime::webgpu::ComputeContext& // Create a temporary parameters copy with is_packed_qkv_ set to false to check if flash attention can be applied after unpacking WebgpuAttentionParameters temp_params = parameters; temp_params.is_packed_qkv_ = false; - will_use_flash_attention = CanApplyFlashAttention(nullptr, temp_params, context); + will_use_flash_attention = CanApplyFlashAttention(temp_params, context); } if (parameters.is_packed_qkv_ && do_rotary_) { diff --git a/onnxruntime/contrib_ops/webgpu/bert/multihead_attention.cc b/onnxruntime/contrib_ops/webgpu/bert/multihead_attention.cc index ed43e9b3653b0..2890afae02ab9 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/multihead_attention.cc +++ b/onnxruntime/contrib_ops/webgpu/bert/multihead_attention.cc @@ -104,7 +104,40 @@ Status MultiHeadAttention::ComputeInternal(onnxruntime::webgpu::ComputeContext& Tensor* output_qk = context.Output(3, output_qk_shape); if (output_qk == nullptr && // Flash attention does not output QK scores - CanApplyFlashAttention(bias, parameters, context)) { + CanApplyFlashAttention(parameters, context)) { + if (bias != nullptr) { + // Apply bias and transpose Q from BSD to BNSH before FlashAttention + TensorShapeVector q_dims({parameters.batch_size_, parameters.num_heads_, + parameters.sequence_length_, parameters.head_size_}); + Tensor Q = context.CreateGPUTensor(query->DataType(), TensorShape(q_dims)); + ORT_RETURN_IF_ERROR(TransferBSDToBNSH( + context, parameters.num_heads_, parameters.sequence_length_, parameters.head_size_, query, bias, 0, &Q)); + + WebgpuAttentionParameters params_bnsh(parameters); + if (parameters.qkv_format_ == Q_K_V_BSNH_BNSH_BNSH) { + // Cross-attention: K/V are 
already BNSH, only Q needs bias+transpose + params_bnsh.qkv_format_ = Q_K_V_BNSH; + return ApplyFlashAttention(&Q, key, value, attention_bias, output, past_key, present_key, past_value, + present_value, params_bnsh, context); + } + + // Self-attention: K/V also need bias+transpose + TensorShapeVector k_dims({parameters.batch_size_, parameters.num_heads_, + parameters.kv_sequence_length_, parameters.head_size_}); + Tensor K = context.CreateGPUTensor(key->DataType(), TensorShape(k_dims)); + ORT_RETURN_IF_ERROR(TransferBSDToBNSH(context, parameters.num_heads_, parameters.kv_sequence_length_, + parameters.head_size_, key, bias, parameters.hidden_size_, &K)); + + TensorShapeVector v_dims({parameters.batch_size_, parameters.num_heads_, + parameters.kv_sequence_length_, parameters.v_head_size_}); + Tensor V = context.CreateGPUTensor(value->DataType(), TensorShape(v_dims)); + ORT_RETURN_IF_ERROR(TransferBSDToBNSH(context, parameters.num_heads_, parameters.kv_sequence_length_, + parameters.v_head_size_, value, bias, 2 * parameters.hidden_size_, &V)); + + params_bnsh.qkv_format_ = Q_K_V_BNSH; + return ApplyFlashAttention(&Q, &K, &V, attention_bias, output, past_key, present_key, past_value, + present_value, params_bnsh, context); + } return ApplyFlashAttention(query, key, value, attention_bias, output, past_key, present_key, past_value, present_value, parameters, context); } diff --git a/onnxruntime/contrib_ops/webgpu/quantization/subgroup_matrix_matmul_nbits.cc b/onnxruntime/contrib_ops/webgpu/quantization/subgroup_matrix_matmul_nbits.cc index cdc0f1ded3e45..a14bf26e7c438 100644 --- a/onnxruntime/contrib_ops/webgpu/quantization/subgroup_matrix_matmul_nbits.cc +++ b/onnxruntime/contrib_ops/webgpu/quantization/subgroup_matrix_matmul_nbits.cc @@ -293,13 +293,31 @@ Status ApplySubgroupMatrixMatMulNBits(const Tensor* a, const Tensor* b, const Te const bool has_weight_idx = weight_index > 0 || has_weight_idx_indirect; SubgroupMatrixMatMulNBitsProgram mul_program{nbits, 
config_index, has_zero_points, has_bias, has_weight_idx, has_weight_idx_indirect}; mul_program.SetWorkgroupSize(work_group_size); - mul_program.SetDispatchGroupSize( - (N + tile_size_b - 1) / tile_size_b, - (M + tile_size_a - 1) / tile_size_a, 1); + uint32_t dispatch_x = (N + tile_size_b - 1) / tile_size_b; + uint32_t num_m_tiles = (M + tile_size_a - 1) / tile_size_a; + uint32_t dispatch_y = num_m_tiles; + // For large M on Intel Xe, cap dispatch_y so each workgroup processes multiple + // M-tiles sequentially, reducing scheduling overhead. + if (M > 2048 && context.AdapterInfo().vendor == std::string_view{"intel"}) { + // Each XeCore has 4 XVE x 8 SIMD-32 hardware threads = 32 subgroups. + uint32_t hw_subgroups = 0; + if (context.AdapterInfo().architecture == std::string_view{"xe-3lpg"}) { + hw_subgroups = 384; // 12 XeCore x 32 + } else if (context.AdapterInfo().architecture == std::string_view{"xe-2lpg"}) { + hw_subgroups = 256; // 8 XeCore x 32 + } + if (hw_subgroups > 0) { + constexpr uint32_t kOccupancyFactor = 16; // empirically tuned on Xe2/Xe3 devices + uint32_t target_wgs = hw_subgroups * kOccupancyFactor / (work_group_size / 32); + dispatch_y = std::min(dispatch_y, (target_wgs + dispatch_x - 1) / dispatch_x); + } + } + uint32_t m_tiles_per_wg = (num_m_tiles + dispatch_y - 1) / dispatch_y; + mul_program.SetDispatchGroupSize(dispatch_x, dispatch_y, 1); mul_program.AddInputs({{a, ProgramTensorMetadataDependency::TypeAndRank, 1}, {b, ProgramTensorMetadataDependency::TypeAndRank, static_cast(nbits == 4 ? 
kU32Components : 2 * kU32Components)}, {scales, ProgramTensorMetadataDependency::TypeAndRank, 1}}) - .AddUniformVariables({{M}, {N}, {K}, {zero_blocks_per_col}, {weight_index}}) + .AddUniformVariables({{M}, {N}, {K}, {zero_blocks_per_col}, {weight_index}, {m_tiles_per_wg}}) .AddOutput({y, ProgramTensorMetadataDependency::TypeAndRank, y_shape, 1}) .CacheHint(nbits, has_zero_points, has_bias, has_weight_idx, has_weight_idx_indirect); if (has_zero_points) { diff --git a/onnxruntime/contrib_ops/webgpu/quantization/subgroup_matrix_matmul_nbits.h b/onnxruntime/contrib_ops/webgpu/quantization/subgroup_matrix_matmul_nbits.h index 810bda950b169..f4b0d10262de5 100644 --- a/onnxruntime/contrib_ops/webgpu/quantization/subgroup_matrix_matmul_nbits.h +++ b/onnxruntime/contrib_ops/webgpu/quantization/subgroup_matrix_matmul_nbits.h @@ -32,7 +32,8 @@ class SubgroupMatrixMatMulNBitsProgram final : public Program; - var sg_mat_c1: subgroup_matrix_result; - var sg_mat_c2: subgroup_matrix_result; - var sg_mat_c3: subgroup_matrix_result; - for (var k_idx: u32 = 0; k_idx < uniforms.K; k_idx += kTileK) { - // Load Phase - dequant_b_to_tile(global_base_b, k_idx, local_idx / 4, local_idx % 4); - workgroupBarrier(); - - for (var sg_mat_k_idx: u32 = 0; sg_mat_k_idx < kTileK; sg_mat_k_idx += kSgMatK) - { - // Load A from global memory (prepacked layout). - // Syntax: subgroupMatrixLoad src_ptr, src_offset, is_col_major, src_stride - var sg_mat_a0: subgroup_matrix_left = - subgroupMatrixLoad>( - &input_a, sg_mat_offset_a, false, kSgMatK); - sg_mat_offset_a += kSgMatSizeLeft; - - // Load B from shared local memory. - // tile_b [kTileN, kTileK] is stored as column major. 
- var sg_mat_b0: subgroup_matrix_right = - subgroupMatrixLoad>( - &tile_b, sg_mat_k_idx, true, kTileK); - var sg_mat_b1: subgroup_matrix_right = - subgroupMatrixLoad>( - &tile_b, sg_mat_k_idx + kSgMatStrideN, true, kTileK); - var sg_mat_b2: subgroup_matrix_right = - subgroupMatrixLoad>( - &tile_b, sg_mat_k_idx + 2 * kSgMatStrideN, true, kTileK); - var sg_mat_b3: subgroup_matrix_right = - subgroupMatrixLoad>( - &tile_b, sg_mat_k_idx + 3 * kSgMatStrideN, true, kTileK); - - // Compute Phase - // Syntax: subgroupMatrixMultiplyAccumulate left, right, accumulate -> accumulate - sg_mat_c0 = subgroupMatrixMultiplyAccumulate(sg_mat_a0, sg_mat_b0, sg_mat_c0); - sg_mat_c1 = subgroupMatrixMultiplyAccumulate(sg_mat_a0, sg_mat_b1, sg_mat_c1); - sg_mat_c2 = subgroupMatrixMultiplyAccumulate(sg_mat_a0, sg_mat_b2, sg_mat_c2); - sg_mat_c3 = subgroupMatrixMultiplyAccumulate(sg_mat_a0, sg_mat_b3, sg_mat_c3); + let num_tiles_m = (uniforms.M + kTileM - 1) / kTileM; + + // Zero-initialized accumulator template (used to reset per M-tile iteration). + var sg_mat_zero: subgroup_matrix_result; + + // Sequential M-loop: each workgroup processes a contiguous block of M-tiles. + let m_start = workgroup_id.y * uniforms.m_tiles_per_wg; + let m_end = min(m_start + uniforms.m_tiles_per_wg, num_tiles_m); + for (var m_tile: u32 = m_start; m_tile < m_end; m_tile++) { + let global_base_a = m_tile * kTileM; + let sg_mat_idx = (m_tile * kSgMatCountM + sg_idx) * sg_mat_count_k; + + var sg_mat_offset_a = sg_mat_idx * kSgMatSizeLeft; + + var sg_mat_c0 = sg_mat_zero; + var sg_mat_c1 = sg_mat_zero; + var sg_mat_c2 = sg_mat_zero; + var sg_mat_c3 = sg_mat_zero; + for (var k_idx: u32 = 0; k_idx < uniforms.K; k_idx += kTileK) { + // Load Phase + dequant_b_to_tile(global_base_b, k_idx, local_idx / 4, local_idx % 4); + workgroupBarrier(); + + for (var sg_mat_k_idx: u32 = 0; sg_mat_k_idx < kTileK; sg_mat_k_idx += kSgMatK) + { + // Load A from global memory (prepacked layout). 
+ // Syntax: subgroupMatrixLoad src_ptr, src_offset, is_col_major, src_stride + var sg_mat_a0: subgroup_matrix_left = + subgroupMatrixLoad>( + &input_a, sg_mat_offset_a, false, kSgMatK); + sg_mat_offset_a += kSgMatSizeLeft; + + // Load B from shared local memory. + // tile_b [kTileN, kTileK] is stored as column major. + var sg_mat_b0: subgroup_matrix_right = + subgroupMatrixLoad>( + &tile_b, sg_mat_k_idx, true, kTileK); + var sg_mat_b1: subgroup_matrix_right = + subgroupMatrixLoad>( + &tile_b, sg_mat_k_idx + kSgMatStrideN, true, kTileK); + var sg_mat_b2: subgroup_matrix_right = + subgroupMatrixLoad>( + &tile_b, sg_mat_k_idx + 2 * kSgMatStrideN, true, kTileK); + var sg_mat_b3: subgroup_matrix_right = + subgroupMatrixLoad>( + &tile_b, sg_mat_k_idx + 3 * kSgMatStrideN, true, kTileK); + + // Compute Phase + // Syntax: subgroupMatrixMultiplyAccumulate left, right, accumulate -> accumulate + sg_mat_c0 = subgroupMatrixMultiplyAccumulate(sg_mat_a0, sg_mat_b0, sg_mat_c0); + sg_mat_c1 = subgroupMatrixMultiplyAccumulate(sg_mat_a0, sg_mat_b1, sg_mat_c1); + sg_mat_c2 = subgroupMatrixMultiplyAccumulate(sg_mat_a0, sg_mat_b2, sg_mat_c2); + sg_mat_c3 = subgroupMatrixMultiplyAccumulate(sg_mat_a0, sg_mat_b3, sg_mat_c3); + } + workgroupBarrier(); } - workgroupBarrier(); - } - // Write out + // Write out #if has_bias - // Store results to scratch workgroup memory, then add bias and write to output. - // scratch layout: [kTileM, kTileN] row-major - let scratch_m_base = sg_idx * kSgMatM; - subgroupMatrixStore(&scratch, scratch_m_base * kTileN, sg_mat_c0, false, kTileN); - subgroupMatrixStore(&scratch, scratch_m_base * kTileN + kSgMatN, sg_mat_c1, false, kTileN); - subgroupMatrixStore(&scratch, scratch_m_base * kTileN + 2 * kSgMatN, sg_mat_c2, false, kTileN); - subgroupMatrixStore(&scratch, scratch_m_base * kTileN + 3 * kSgMatN, sg_mat_c3, false, kTileN); - workgroupBarrier(); - - // 256 threads write 64x64 = 4096 elements. Each thread handles 16 elements. 
- // Thread mapping: m = local_idx / 4, n_base = (local_idx % 4) * 16 - let out_m = local_idx / 4; - let out_n_base = (local_idx % 4) * 16; - let global_m = global_base_a + out_m; - if (global_m < uniforms.M) { - let global_n_base = global_base_b + out_n_base; - let scratch_base = out_m * kTileN + out_n_base; - let out_base = global_m * uniforms.N + global_n_base; + // Store results to scratch workgroup memory, then add bias and write to output. + // scratch layout: [kTileM, kTileN] row-major + let scratch_m_base = sg_idx * kSgMatM; + subgroupMatrixStore(&scratch, scratch_m_base * kTileN, sg_mat_c0, false, kTileN); + subgroupMatrixStore(&scratch, scratch_m_base * kTileN + kSgMatN, sg_mat_c1, false, kTileN); + subgroupMatrixStore(&scratch, scratch_m_base * kTileN + 2 * kSgMatN, sg_mat_c2, false, kTileN); + subgroupMatrixStore(&scratch, scratch_m_base * kTileN + 3 * kSgMatN, sg_mat_c3, false, kTileN); + workgroupBarrier(); + + // 256 threads write 64x64 = 4096 elements. Each thread handles 16 elements. 
+ // Thread mapping: m = local_idx / 4, n_base = (local_idx % 4) * 16 + let out_m = local_idx / 4; + let out_n_base = (local_idx % 4) * 16; + let global_m = global_base_a + out_m; + if (global_m < uniforms.M) { + let global_n_base = global_base_b + out_n_base; + let scratch_base = out_m * kTileN + out_n_base; + let out_base = global_m * uniforms.N + global_n_base; #if has_weight_idx_indirect - let bias_offset = weight_index_indirect[uniforms.weight_idx] * uniforms.N; + let bias_offset = weight_index_indirect[uniforms.weight_idx] * uniforms.N; #elif has_weight_idx - let bias_offset = uniforms.weight_idx * uniforms.N; + let bias_offset = uniforms.weight_idx * uniforms.N; #else - const bias_offset: u32 = 0; + const bias_offset: u32 = 0; #endif - for (var i: u32 = 0; i < 16; i++) { - if (global_n_base + i < uniforms.N) { - let val = output_element_t(scratch[scratch_base + i]) - + bias[bias_offset + global_n_base + i]; - output.setByOffset(out_base + i, val); + for (var i: u32 = 0; i < 16; i++) { + if (global_n_base + i < uniforms.N) { + let val = output_element_t(scratch[scratch_base + i]) + + bias[bias_offset + global_n_base + i]; + output.setByOffset(out_base + i, val); + } } } - } #else - let sg_mat_offset_c = global_base_a * uniforms.N + global_base_b + sg_idx * kSgMatM * uniforms.N; - subgroupMatrixStore(&output, sg_mat_offset_c, sg_mat_c0, false, uniforms.N); - subgroupMatrixStore(&output, sg_mat_offset_c + kSgMatN, sg_mat_c1, false, uniforms.N); - subgroupMatrixStore(&output, sg_mat_offset_c + 2 * kSgMatN, sg_mat_c2, false, uniforms.N); - subgroupMatrixStore(&output, sg_mat_offset_c + 3 * kSgMatN, sg_mat_c3, false, uniforms.N); + let sg_mat_offset_c = global_base_a * uniforms.N + global_base_b + sg_idx * kSgMatM * uniforms.N; + subgroupMatrixStore(&output, sg_mat_offset_c, sg_mat_c0, false, uniforms.N); + subgroupMatrixStore(&output, sg_mat_offset_c + kSgMatN, sg_mat_c1, false, uniforms.N); + subgroupMatrixStore(&output, sg_mat_offset_c + 2 * kSgMatN, sg_mat_c2, 
false, uniforms.N); + subgroupMatrixStore(&output, sg_mat_offset_c + 3 * kSgMatN, sg_mat_c3, false, uniforms.N); #endif + } // end M-tile loop } // MAIN diff --git a/onnxruntime/core/framework/tensorprotoutils.cc b/onnxruntime/core/framework/tensorprotoutils.cc index 3e928afcf6c80..360726d780a17 100644 --- a/onnxruntime/core/framework/tensorprotoutils.cc +++ b/onnxruntime/core/framework/tensorprotoutils.cc @@ -1569,7 +1569,7 @@ Status GetExtDataFromTensorProto(const Env& env, if constexpr (endian::native != endian::little) { auto allocator = CPUAllocator::DefaultInstance(); - auto deleter = [&allocator](uint8_t* ptr) { allocator->Free(ptr); }; + auto deleter = [allocator](uint8_t* ptr) { allocator->Free(ptr); }; std::unique_ptr native_data{reinterpret_cast(allocator->Alloc(static_cast(raw_data_safe_len))), deleter}; size_t element_size = onnxruntime::utils::GetElementSizeOfTensor(static_cast(tensor_proto.data_type())); diff --git a/onnxruntime/core/graph/graph.cc b/onnxruntime/core/graph/graph.cc index 1346b976461ce..4487ccf62c0a2 100644 --- a/onnxruntime/core/graph/graph.cc +++ b/onnxruntime/core/graph/graph.cc @@ -4177,12 +4177,43 @@ Status Graph::InjectExternalInitializersFromFilesInMemory( const DataTypeImpl* const type = DataTypeImpl::TensorTypeFromONNXEnum(old_initializer.data_type())->GetElementType(); TensorShape tensor_shape = utils::GetTensorShapeFromTensorProto(old_initializer); - auto tensor = Tensor(type, tensor_shape, user_provided_tensor_buffer, - OrtMemoryInfo(CPU, OrtAllocatorType::OrtDeviceAllocator)); - constexpr const bool use_tensor_buffer_false = false; - auto new_tensor_proto = utils::TensorToTensorProto(tensor, tensor_name, use_tensor_buffer_false); - **existing_entry = std::move(new_tensor_proto); + // Convert data from little endian before assigning it to tensor. + // It would have been better to byteswap it right after loading from file, + // but at that moment information about tensor element size was not available. 
+ if constexpr (endian::native != endian::little) { + size_t element_size = onnxruntime::utils::GetElementSizeOfTensor( + static_cast(old_initializer.data_type())); + + // If element size is unknown, set it to 1 to disable byteswapping + if (element_size < 1) element_size = 1; + + auto allocator = CPUAllocator::DefaultInstance(); + + auto deleter = [allocator](uint8_t* ptr) { allocator->Free(ptr); }; + std::unique_ptr native_data{ + reinterpret_cast(allocator->Alloc(tensor_byte_size)), deleter}; + + auto src_span = gsl::make_span( + reinterpret_cast(user_provided_tensor_buffer), tensor_byte_size); + auto dst_span = gsl::make_span( + reinterpret_cast(native_data.get()), tensor_byte_size); + + ORT_RETURN_IF_ERROR(onnxruntime::utils::ReadLittleEndian(element_size, src_span, dst_span)); + + auto tensor = Tensor{type, tensor_shape, native_data.release(), allocator}; + + constexpr const bool use_tensor_buffer_false = false; + auto new_tensor_proto = utils::TensorToTensorProto(tensor, tensor_name, use_tensor_buffer_false); + **existing_entry = std::move(new_tensor_proto); + } else { + auto tensor = Tensor(type, tensor_shape, user_provided_tensor_buffer, + OrtMemoryInfo(CPU, OrtAllocatorType::OrtDeviceAllocator)); + + constexpr const bool use_tensor_buffer_false = false; + auto new_tensor_proto = utils::TensorToTensorProto(tensor, tensor_name, use_tensor_buffer_false); + **existing_entry = std::move(new_tensor_proto); + } } } @@ -6757,12 +6788,12 @@ Status Graph::LoadFromModelEditorApiModel(const OrtGraph& api_graph, bool updati } }; - auto add_initializers = [this](const std::unordered_map>& initializers, + auto add_initializers = [this](const std::unordered_map& initializers, bool is_external) { for (auto& name_and_ortvalue : initializers) { // convert from OrtValue to TensorProto const std::string& name = name_and_ortvalue.first; - OrtValue& v = *name_and_ortvalue.second; + const OrtValue& v = name_and_ortvalue.second; ORT_ENFORCE(v.IsTensor(), "Initializers must be 
Tensors"); const Tensor& t = v.Get(); @@ -6783,7 +6814,7 @@ Status Graph::LoadFromModelEditorApiModel(const OrtGraph& api_graph, bool updati offset, t.SizeInBytes(), tensor_proto); // add OrtValue to ortvalue_initializers_ to keep it alive and to store the deleter if provided. - ortvalue_initializers_.emplace(name, std::move(v)); + ortvalue_initializers_.emplace(name, v); } else { onnxruntime::utils::SetRawDataInTensorProto(tensor_proto, t.DataRaw(), t.SizeInBytes()); } diff --git a/onnxruntime/core/graph/model_editor_api_types.h b/onnxruntime/core/graph/model_editor_api_types.h index 2c0f6d6174303..6fbd687545ab2 100644 --- a/onnxruntime/core/graph/model_editor_api_types.h +++ b/onnxruntime/core/graph/model_editor_api_types.h @@ -81,6 +81,7 @@ struct ModelEditorValueInfo : public OrtValueInfo { "OrtModelEditorApi does not support querying if a OrtValueInfo is defined in an outer scope."); } + bool owned_ = false; // true after ownership transferred to a graph std::string name; std::unique_ptr type_info; }; @@ -154,6 +155,7 @@ struct ModelEditorNode : public OrtNode { "OrtModelEditorApi does not support getting the parent graph for OrtNode"); } + bool owned_ = false; // true after ownership transferred to a graph size_t id = 0; std::string operator_name; std::string domain_name; @@ -235,8 +237,9 @@ struct ModelEditorGraph : public OrtGraph { onnxruntime::InlinedVector> inputs; onnxruntime::InlinedVector> outputs; - std::unordered_map> initializers; - std::unordered_map> external_initializers; + std::unordered_map initializers; + std::unordered_map external_initializers; + bool owned_ = false; // true after ownership transferred to a model std::vector> nodes; std::string name = "ModelEditorGraph"; std::filesystem::path model_path; diff --git a/onnxruntime/core/optimizer/reshape_fusion.cc b/onnxruntime/core/optimizer/reshape_fusion.cc index 167952356ff58..f88ce56fe36fa 100644 --- a/onnxruntime/core/optimizer/reshape_fusion.cc +++ 
b/onnxruntime/core/optimizer/reshape_fusion.cc
@@ -1,6 +1,8 @@
 // Copyright (c) Microsoft Corporation. All rights reserved.
 // Licensed under the MIT License.
 
+#include
+
 #include "core/graph/graph_utils.h"
 #include "core/optimizer/initializer.h"
 #include "core/optimizer/reshape_fusion.h"
@@ -486,6 +488,16 @@ bool ReshapeFusion::FuseContiguousReshapes(Node& reshape, Graph& graph) {
     return false;
   }
 
+  // The fused shape is taken verbatim from the inferred output shape of the last reshape
+  // (we ensured tensor_shape.Size() != -1 above, so dims are concrete). If any dim is
+  // literally 0, fusing into a single Reshape is unsafe: ONNX Reshape with the default
+  // allowzero=0 would reinterpret the 0 as "copy from input", producing the wrong shape.
+  // Setting allowzero=1 would fix it but requires opset >= 14, which we cannot assume
+  // here (this transformer accepts Reshape opset 5+). Bail out conservatively.
+  if (std::any_of(shape_value.begin(), shape_value.end(), [](int64_t d) { return d == 0; })) {
+    return false;
+  }
+
   const std::string& name = contiguous_reshapes[0].get().Name();
   ONNX_NAMESPACE::TensorProto shape_initializer_proto;
   shape_initializer_proto.set_name(graph.GenerateNodeName(name + "_new_shape"));
diff --git a/onnxruntime/core/providers/cpu/math/element_wise_ops.cc b/onnxruntime/core/providers/cpu/math/element_wise_ops.cc
index 4ddb5c7e78037..935fb3172cc14 100644
--- a/onnxruntime/core/providers/cpu/math/element_wise_ops.cc
+++ b/onnxruntime/core/providers/cpu/math/element_wise_ops.cc
@@ -1334,6 +1334,19 @@ BitShift::BitShift(const OpKernelInfo& info) : OpKernel(info) {
     ORT_THROW("Invalid direction value of '", direction, "'. Valid values are 'LEFT' or 'RIGHT'.");
 }
 
+// Shifting by >= the bit width of an unsigned type is undefined behavior in C++.
+// On x86, 64-bit shifts mask the shift amount to 6 bits, so shift by 64 acts like shift by 0.
+// Guard against this by returning 0 when the shift amount >= the bit width.
+template
+inline T SafeShiftLeft(T value, T shift) {
+  return shift >= sizeof(T) * 8 ? T{0} : value << shift;
+}
+
+template
+inline T SafeShiftRight(T value, T shift) {
+  return shift >= sizeof(T) * 8 ? T{0} : value >> shift;
+}
+
 template
 Status BitShift::Compute(OpKernelContext* context) const {
   ProcessBroadcastSpanFuncs funcs{
@@ -1345,11 +1358,11 @@ Status BitShift::Compute(OpKernelContext* context) const {
         ptrdiff_t i = 0;
         if (shift_left) {
           for (const auto& input : input1.array()) {
-            output[i++] = input0 << input;
+            output[i++] = SafeShiftLeft(input0, input);
           }
         } else {
           for (const auto& input : input1.array()) {
-            output[i++] = input0 >> input;
+            output[i++] = SafeShiftRight(input0, input);
           }
         }
       },
@@ -1361,11 +1374,11 @@ Status BitShift::Compute(OpKernelContext* context) const {
         ptrdiff_t i = 0;
         if (shift_left) {
           for (const auto& input : input0.array()) {
-            output[i++] = input << input1;
+            output[i++] = SafeShiftLeft(input, input1);
           }
         } else {
           for (const auto& input : input0.array()) {
-            output[i++] = input >> input1;
+            output[i++] = SafeShiftRight(input, input1);
           }
         }
       },
@@ -1380,11 +1393,11 @@ Status BitShift::Compute(OpKernelContext* context) const {
         auto cur_out = output.begin(), end_out = output.end();
         if (shift_left) {
           for (; cur0 != end0; ++cur0, ++cur1, ++cur_out) {
-            *cur_out = *cur0 << *cur1;
+            *cur_out = SafeShiftLeft(*cur0, *cur1);
           }
         } else {
           for (; cur0 != end0; ++cur0, ++cur1, ++cur_out) {
-            *cur_out = *cur0 >> *cur1;
+            *cur_out = SafeShiftRight(*cur0, *cur1);
           }
         }
 
diff --git a/onnxruntime/core/providers/cpu/tensor/upsamplebase.h b/onnxruntime/core/providers/cpu/tensor/upsamplebase.h
index ded4813276b1d..e10e896a62d18 100644
--- a/onnxruntime/core/providers/cpu/tensor/upsamplebase.h
+++ b/onnxruntime/core/providers/cpu/tensor/upsamplebase.h
@@ -462,6 +462,12 @@ class UpsampleBase {
         };
       case ROUND_PREFER_CEIL:
         return [](float x_original, bool) {
+          // for half way cases prefer ceil
+          // std::round rounds away from zero which is correct for positive
.5 values + // but for negative .5 values (e.g., -0.5) it rounds to -1 instead of 0 (ceil) + if (x_original == static_cast(x_original) - 0.5f) { + return static_cast(std::ceil(x_original)); + } return static_cast(std::round(x_original)); }; case FLOOR: diff --git a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc index 8b139c2d5514f..bc0a250a90493 100755 --- a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc +++ b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc @@ -840,6 +840,8 @@ class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kO class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 17, MLFloat16, ReduceMax); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 17, int32_t, ReduceMax); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 17, int64_t, ReduceMax); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 17, int8_t, ReduceMax); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 17, uint8_t, ReduceMax); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 17, float, ReduceMean); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 17, double, ReduceMean); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 17, MLFloat16, ReduceMean); @@ -1071,9 +1073,9 @@ class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kO class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 12, float, Equal); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 12, double, Equal); class 
ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 12, MLFloat16, Equal); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, float, Round); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, double, Round); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, MLFloat16, Round); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 21, float, Round); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 21, double, Round); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 21, MLFloat16, Round); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 10, 12, int8_t, QuantizeLinear); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 10, 12, uint8_t, QuantizeLinear); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 10, 12, int8_t, DequantizeLinear); @@ -1190,14 +1192,14 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, E class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, Sum); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, Max); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, Min); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, bool, Equal); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, int32_t, Equal); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, int64_t, Equal); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, uint32_t, Equal); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 
13, uint64_t, Equal); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, float, Equal); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, double, Equal); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, MLFloat16, Equal); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 18, bool, Equal); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 18, int32_t, Equal); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 18, int64_t, Equal); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 18, uint32_t, Equal); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 18, uint64_t, Equal); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 18, float, Equal); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 18, double, Equal); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 18, MLFloat16, Equal); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, int32_t, Greater); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, int64_t, Greater); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, uint32_t, Greater); @@ -1403,30 +1405,30 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, float, Div); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, double, Div); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, MLFloat16, Div); -class 
ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, 18, Identity); -class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, 18, Reshape); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, float, RNN); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, double, RNN); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, MLFloat16, RNN); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, 21, float, GRU); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, 21, double, GRU); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, 21, MLFloat16, GRU); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, 18, Identity); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, float, LSTM); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, double, LSTM); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, MLFloat16, LSTM); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, 18, Reshape); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, 21, float, RNN); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, 21, double, RNN); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, 21, MLFloat16, RNN); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME( kCudaExecutionProvider, kOnnxDomain, 14, 14, float, BatchNormalization); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME( kCudaExecutionProvider, kOnnxDomain, 14, 14, double, BatchNormalization); class 
ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME( kCudaExecutionProvider, kOnnxDomain, 14, 14, MLFloat16, BatchNormalization); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, float, ReduceMin); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, double, ReduceMin); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, MLFloat16, ReduceMin); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, int32_t, ReduceMin); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, int8_t, ReduceMin); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, uint8_t, ReduceMin); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, int64_t, ReduceMin); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, 19, float, ReduceMin); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, 19, double, ReduceMin); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, 19, MLFloat16, ReduceMin); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, 19, int32_t, ReduceMin); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, 19, int8_t, ReduceMin); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, 19, uint8_t, ReduceMin); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, 19, int64_t, ReduceMin); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, Trilu); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, BFloat16, Add); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, 
BFloat16, Sub); @@ -1486,11 +1488,13 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, // Opset 18 class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, Split); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, float, ReduceMax); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, double, ReduceMax); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, MLFloat16, ReduceMax); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, int32_t, ReduceMax); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, int64_t, ReduceMax); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, 19, float, ReduceMax); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, 19, double, ReduceMax); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, 19, MLFloat16, ReduceMax); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, 19, int32_t, ReduceMax); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, 19, int64_t, ReduceMax); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, 19, int8_t, ReduceMax); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, 19, uint8_t, ReduceMax); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, ScatterElements); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, ScatterND); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, 18, float, Pad); @@ -1568,6 +1572,14 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, 
kOnnxDomain, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, uint8_t, Resize); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, 20, Scan); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, 20, Shape); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, bool, Equal); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, int32_t, Equal); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, int64_t, Equal); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, uint32_t, Equal); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, uint64_t, Equal); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, float, Equal); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, double, Equal); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, MLFloat16, Equal); // Opset 20 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 20, float, Gelu); @@ -1577,6 +1589,21 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 20, IsInf); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 20, IsNaN); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 20, 21, float, GridSample); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 20, float, ReduceMax); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 20, double, ReduceMax); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 20, MLFloat16, ReduceMax); +class 
ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 20, int32_t, ReduceMax); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 20, int64_t, ReduceMax); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 20, int8_t, ReduceMax); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 20, uint8_t, ReduceMax); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 20, float, ReduceMin); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 20, double, ReduceMin); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 20, MLFloat16, ReduceMin); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 20, int32_t, ReduceMin); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 20, int8_t, ReduceMin); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 20, uint8_t, ReduceMin); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 20, int64_t, ReduceMin); + // Opset 21. 
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 21, 22, float, Cast); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 21, 22, double, Cast); @@ -1672,34 +1699,40 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 22, double, HardSwish); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 22, MLFloat16, HardSwish); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 22, BFloat16, HardSwish); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 22, float, RNN); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 22, double, RNN); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 22, MLFloat16, RNN); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 22, float, RoiAlign); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 22, double, RoiAlign); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 22, MLFloat16, RoiAlign); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 22, BFloat16, RoiAlign); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 22, float, Round); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 22, double, Round); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 22, MLFloat16, Round); // Opset 23. 
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, 23, float, Attention); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, 23, MLFloat16, Attention); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, 23, BFloat16, Attention); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, float, Cast); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, double, Cast); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, MLFloat16, Cast); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, BFloat16, Cast); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, int8_t, Cast); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, int16_t, Cast); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, int32_t, Cast); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, int64_t, Cast); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, uint8_t, Cast); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, uint16_t, Cast); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, uint32_t, Cast); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, uint64_t, Cast); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, bool, Cast); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, 24, float, Cast); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, 24, double, Cast); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, 
kOnnxDomain, 23, 24, MLFloat16, Cast); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, 24, BFloat16, Cast); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, 24, int8_t, Cast); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, 24, int16_t, Cast); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, 24, int32_t, Cast); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, 24, int64_t, Cast); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, 24, uint8_t, Cast); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, 24, uint16_t, Cast); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, 24, uint32_t, Cast); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, 24, uint64_t, Cast); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, 24, bool, Cast); #if !defined(DISABLE_FLOAT8_TYPES) -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, Float8E4M3FN, Cast); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, Float8E5M2, Cast); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, 24, Float8E4M3FN, Cast); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, 24, Float8E5M2, Cast); #endif #if !defined(DISABLE_FLOAT4_TYPES) -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, Float4E2M1x2, Cast); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, 24, Float4E2M1x2, Cast); #endif class 
ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, 24, ConstantOfShape); class ONNX_OPERATOR_VERSIONED_TWO_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, 24, uint8_t, float, DequantizeLinear); @@ -1738,7 +1771,7 @@ class ONNX_OPERATOR_VERSIONED_TWO_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider class ONNX_OPERATOR_VERSIONED_TWO_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, 24, Float8E4M3FN, MLFloat16, QuantizeLinear); class ONNX_OPERATOR_VERSIONED_TWO_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, 24, Float8E5M2, MLFloat16, QuantizeLinear); #endif -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, Reshape); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, 24, Reshape); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, float_float, RMSNormalization); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, double_double, RMSNormalization); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, MLFloat16_MLFloat16, RMSNormalization); @@ -1769,6 +1802,26 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 24, T class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 24, 24, Unsqueeze); // Opset 25. 
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 25, float, Cast); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 25, double, Cast); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 25, MLFloat16, Cast); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 25, BFloat16, Cast); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 25, int8_t, Cast); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 25, int16_t, Cast); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 25, int32_t, Cast); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 25, int64_t, Cast); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 25, uint8_t, Cast); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 25, uint16_t, Cast); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 25, uint32_t, Cast); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 25, uint64_t, Cast); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 25, bool, Cast); +#if !defined(DISABLE_FLOAT8_TYPES) +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 25, Float8E4M3FN, Cast); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 25, Float8E5M2, Cast); +#endif +#if !defined(DISABLE_FLOAT4_TYPES) +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 25, Float4E2M1x2, Cast); +#endif class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 25, ConstantOfShape); class ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 25, uint8_t, float, DequantizeLinear); class 
ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 25, int8_t, float, DequantizeLinear); @@ -1806,6 +1859,7 @@ class ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDom class ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 25, Float8E4M3FN, MLFloat16, QuantizeLinear); class ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 25, Float8E5M2, MLFloat16, QuantizeLinear); #endif +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 25, Reshape); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 25, Scan); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 25, Shape); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 25, Size); @@ -2044,6 +2098,8 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -2110,15 +2166,15 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -2279,9 +2335,9 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, 
BuildKernelCreateInfo, @@ -2393,14 +2449,14 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -2607,9 +2663,6 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -2617,6 +2670,9 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -2772,6 +2830,14 @@ static Status 
RegisterCudaKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, // Opset 20 BuildKernelCreateInfo, @@ -2781,6 +2847,21 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + // Opset 21 BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -2876,34 +2957,40 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, // Opset 23 BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + 
BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, #if !defined(DISABLE_FLOAT8_TYPES) - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, #endif #if !defined(DISABLE_FLOAT4_TYPES) - BuildKernelCreateInfo, + BuildKernelCreateInfo, #endif BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -2942,7 +3029,7 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, #endif - BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -2973,6 +3060,26 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, // Opset 25 + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, +#if !defined(DISABLE_FLOAT8_TYPES) + BuildKernelCreateInfo, + BuildKernelCreateInfo, +#endif +#if !defined(DISABLE_FLOAT4_TYPES) + BuildKernelCreateInfo, +#endif BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -3010,6 +3117,7 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, #endif + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, diff --git a/onnxruntime/core/providers/cuda/llm/attention.cc b/onnxruntime/core/providers/cuda/llm/attention.cc index 228729745b65b..15f9dcbf8e7f2 100644 --- a/onnxruntime/core/providers/cuda/llm/attention.cc +++ b/onnxruntime/core/providers/cuda/llm/attention.cc @@ -1,17 +1,20 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. 
+#include + #include "core/common/safeint.h" #include "core/providers/cuda/cuda_common.h" #include "core/providers/cpu/llm/attention.h" #include "core/providers/cpu/llm/attention_helper.h" #include "core/providers/cuda/llm/attention.h" #include "core/providers/cuda/llm/attention_mask_impl.h" -#include "contrib_ops/cuda/bert/attention_data.h" +// attention_impl.h provides Transpose_BNSH_to_BSNH / Transpose_BSNH_to_BNSH used +// by the transpose helpers. #include "contrib_ops/cuda/bert/attention_impl.h" #include "contrib_ops/cuda/bert/attention_kv_cache.h" #include "contrib_ops/cuda/bert/group_query_attention_impl.h" -#include "contrib_ops/cuda/bert/gqa_unfused_attention.h" +#include "contrib_ops/cuda/bert/unfused_attention.h" #include "contrib_ops/cuda/bert/cutlass_fmha/memory_efficient_attention.h" #include "contrib_ops/cuda/bert/flash_attention/flash_api.h" #include "core/providers/cuda/cuda_type_conversion.h" @@ -155,7 +158,12 @@ Status Attention::ConvertAttnMaskToBias( int64_t num_elements = attn_mask->Shape().Size(); converted_mask_buffer = GetScratchBuffer( num_elements * sizeof(NativeCudaT), GetComputeStream(context)); - float mask_filter_value = static_cast(std::numeric_limits::lowest()); + // CUTLASS online softmax multiplies attention scores by kLog2e (≈1.4427). + // For float/bf16, |lowest() × kLog2e| > FLT_MAX, overflowing to -inf and + // causing s_prime=0 → NaN for fully-masked batches. Cap to prevent this. + // See kCutlassSafeMaskFilterValue in memory_efficient_attention.h for details. 
+ float mask_filter_value = std::max(static_cast(std::numeric_limits::lowest()), + ::onnxruntime::contrib::cuda::kCutlassSafeMaskFilterValue); ORT_RETURN_IF_ERROR(LaunchConvertBoolMaskToAttentionBias( attn_mask->Data(), reinterpret_cast(converted_mask_buffer.get()), @@ -189,7 +197,7 @@ Status Attention::ConvertAttnMaskToBias( // Path 1: nonpad_kv_seqlen (opset 24 external cache) -> mha_fwd_kvcache // Path 2: past_key + past_value (internal cache decode) -> mha_fwd_kvcache // - No mask support (attn_mask rejected at eligibility) -// - 4D BNSH: transposes Q/K/V to BSNH before kernel +// - 4D BNSH: transposes Q to BSNH; new K/V to BSNH for concat (cache stays BNSH) // Path 3: no past, no mask (prompt) -> mha_fwd // Eligibility: fp16/bf16, head_size==v_head_size, no output_qk, attn_mask==nullptr // Note: softcap is passed to the Flash kernel natively. softmax_precision is @@ -334,10 +342,10 @@ Status Attention::RunFlashAttention( ORT_ENFORCE(present_key != nullptr && present_value != nullptr, "present_key/value outputs are required when past_key is provided."); - // TODO(titaiwang): Consolidate preprocessing (RoPE, mask conversion, KV cache concat) into a + // TODO(titaiwang): Consolidate preprocessing (transpose, KV cache concat) into a // single fused kernel like GQA's LaunchUnpackRoPEAppend. Current decode path uses 4-6 kernel - // launches; a fused approach would reduce to ~2, saving ~21μs launch overhead and ~256KB - // intermediate buffer traffic per decode step. + // launches; a fused approach would reduce to ~2, saving launch overhead and intermediate + // buffer traffic per decode step. // Concat past + new KV directly into present buffers using a single fused kernel. 
// This replaces the old pattern of memset + strided cudaMemcpy2DAsync + Flash's @@ -476,7 +484,7 @@ Status Attention::RunFlashAttention( cuda_stream, device_prop.maxThreadsPerBlock)); } - // --- Populate present_key/value (BNSH) from K/V (BSNH) --- + // --- Populate present_key/value (BNSH) from K/V (BSNH or BNSH) --- // Skip for decode path where mha_fwd_kvcache already populated present buffers. if (!present_kv_already_populated) { if (present_key != nullptr && is_bsnh) { @@ -528,13 +536,15 @@ Status Attention::RunFlashAttention( // ============================================================================ // // Memory Efficient Attention (cutlass FMHA) dispatch paths: -// Path 1: nonpad_kv_seqlen (opset 24 external cache) -> has_custom_right_padding mode -// Path 2: no past, with mask (prompt) -> standard MEA with additive bias -// Path 3: no past, no mask (prompt) -> standard MEA +// Path 1: Decode with past KV cache -> LaunchConcatNewToPastKV then standard MEA +// Path 2: nonpad_kv_seqlen (opset 24 external cache) -> has_custom_right_padding mode +// Path 3: Prompt with mask -> standard MEA with additive bias +// Path 4: Prompt without mask -> standard MEA // Eligibility: see has_memory_efficient_attention() (SM50+/53+/80+ by dtype, -// head_size <= 1024), plus: no output_qk, no past_key (decode excluded), -// bias stride alignment. -// Note: softcap is forwarded to the MEA kernel via p.softcap. softmax_precision +// head_size <= 1024, head_size divisible by 8), plus: no output_qk, bias stride alignment. +// Note: softcap is forwarded to the MEA kernel via p.softcap. CUTLASS applies +// softcap before bias (fused in kernel tiles), matching ONNX spec ordering +// (onnx/onnx#7865): QK → softcap → mask/bias → softmax. softmax_precision // is inherently satisfied (cutlass FMHA accumulates softmax in FP32). 
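The ordering stated in the comment above (QK → softcap → mask/bias → softmax) can be sketched on the host as a scalar reference. This is a minimal illustration, not the CUTLASS kernel: `ApplySoftcap` and `SoftcapThenBiasSoftmax` are hypothetical names, and the `softcap * tanh(x / softcap)` form with `softcap <= 0` meaning "disabled" is an assumption about the operator's semantics.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical helper: squashes a raw QK score into (-softcap, softcap).
// Assumption: softcap <= 0 means the feature is disabled.
inline float ApplySoftcap(float x, float softcap) {
  return softcap > 0.0f ? softcap * std::tanh(x / softcap) : x;
}

// Hypothetical scalar reference for one attention row:
// softcap is applied to the raw scores first, the additive bias second,
// and the softmax last.
inline std::vector<float> SoftcapThenBiasSoftmax(const std::vector<float>& qk,
                                                 const std::vector<float>& bias,
                                                 float softcap) {
  assert(!qk.empty() && qk.size() == bias.size());
  std::vector<float> logits(qk.size());
  for (std::size_t i = 0; i < qk.size(); ++i) {
    logits[i] = ApplySoftcap(qk[i], softcap) + bias[i];
  }
  // Subtract the row maximum before exponentiating for numerical stability.
  const float max_logit = *std::max_element(logits.begin(), logits.end());
  float sum = 0.0f;
  for (float& v : logits) {
    v = std::exp(v - max_logit);
    sum += v;
  }
  for (float& v : logits) v /= sum;
  return logits;
}
```

The order matters for masked positions. If the bias were added before the softcap, a large negative bias would be squashed to roughly `-softcap`, and the masked position would keep a small nonzero weight instead of dropping to zero.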
// template @@ -546,8 +556,6 @@ Status Attention::RunMemoryEfficientAttention( Tensor* Y, Tensor* present_key, Tensor* present_value, const attention_helper::AttentionParameters& parameters) const { #if USE_MEMORY_EFFICIENT_ATTENTION - ORT_UNUSED_PARAMETER(past_key); - ORT_UNUSED_PARAMETER(past_value); auto& device_prop = GetDeviceProp(); auto cuda_stream = Stream(context); const bool is_bsnh = parameters.transpose_output; @@ -582,6 +590,120 @@ Status Attention::RunMemoryEfficientAttention( out_data = out_bsnh_buffer.get(); } + bool present_kv_already_populated = false; + // Track the effective layout of k_data/v_data. Initially matches input layout, + // but changes to BNSH (false) after decode concat into present buffers. + bool kv_is_bsnh = is_bsnh; + + // Scratch buffers for decode concat output when present_key/value are optional. + // Declared at function scope so they outlive the decode block (k_data/v_data may point here). + IAllocatorUniquePtr present_k_scratch; + IAllocatorUniquePtr present_v_scratch; + + // --- Decode path: concat past + new K/V → present buffers (BNSH) --- + // nonpad_kv_seqlen and past_key are mutually exclusive (enforced at validation), + // so the decode path only needs the internal-cache (past_key/present_key) flow. + if (past_key != nullptr) { + ORT_RETURN_IF_NOT(past_value != nullptr, "past_key requires past_value."); + ORT_RETURN_IF_NOT(nonpad_kv_seqlen == nullptr, + "nonpad_kv_seqlen and past_key are mutually exclusive (internal vs external cache)."); + // This mirrors the eligibility check in ComputeInternal — must stay in sync. + ORT_RETURN_IF_NOT(parameters.head_size == parameters.v_head_size, + "MEA decode (past_key) requires head_size == v_head_size for LaunchConcatNewToPastKV."); + + using NativeCudaT = typename OrtToCudaType::type; + + // Allocate scratch buffers for concat output when present_key/value are not requested. 
+ // The concat kernel needs a destination buffer regardless of whether the caller wants present outputs. + T* present_k_data = nullptr; + T* present_v_data = nullptr; + + SafeInt present_k_bytes = SafeInt(parameters.batch_size) * parameters.kv_num_heads * + parameters.total_sequence_length * parameters.head_size * sizeof(T); + SafeInt present_v_bytes = SafeInt(parameters.batch_size) * parameters.kv_num_heads * + parameters.total_sequence_length * parameters.v_head_size * sizeof(T); + + if (present_key != nullptr) { + present_k_data = present_key->MutableData(); + } else { + present_k_scratch = GetScratchBuffer(present_k_bytes, GetComputeStream(context)); + present_k_data = static_cast(present_k_scratch.get()); + } + if (present_value != nullptr) { + present_v_data = present_value->MutableData(); + } else { + present_v_scratch = GetScratchBuffer(present_v_bytes, GetComputeStream(context)); + present_v_data = static_cast(present_v_scratch.get()); + } + + // Step 1: Uniform past sequence lengths for the concat kernel. + // ONNX past_key has shape [B, H, past_seq, head_size] — all batches share + // the same past_seq dimension. Bool masks do NOT change where tokens are stored; + // they change which tokens are attended to (via additive bias, handled below). + auto past_seqlens_buffer = GetScratchBuffer(parameters.batch_size, GetComputeStream(context)); + ORT_RETURN_IF_ERROR(LaunchFillInt32(past_seqlens_buffer.get(), parameters.past_sequence_length, + parameters.batch_size, cuda_stream, + device_prop.maxThreadsPerBlock)); + + // Step 2: Transpose K/V to BSNH if input is 4D BNSH (concat kernel reads new as BSNH). 
+ const T* k_new_bsnh = K->Data(); + const T* v_new_bsnh = V->Data(); + IAllocatorUniquePtr k_bsnh_buffer; + IAllocatorUniquePtr v_bsnh_buffer; + if (!is_bsnh) { + size_t k_bytes = sizeof(T) * parameters.batch_size * parameters.kv_sequence_length * + parameters.kv_num_heads * parameters.head_size; + size_t v_bytes = sizeof(T) * parameters.batch_size * parameters.kv_sequence_length * + parameters.kv_num_heads * parameters.v_head_size; + k_bsnh_buffer = GetScratchBuffer(k_bytes, GetComputeStream(context)); + v_bsnh_buffer = GetScratchBuffer(v_bytes, GetComputeStream(context)); + ORT_RETURN_IF_ERROR(TransposeBNSHtoBSNH( + parameters.batch_size, parameters.kv_sequence_length, + parameters.kv_num_heads, parameters.head_size, + K->Data(), k_bsnh_buffer.get(), + cuda_stream, device_prop.maxThreadsPerBlock)); + ORT_RETURN_IF_ERROR(TransposeBNSHtoBSNH( + parameters.batch_size, parameters.kv_sequence_length, + parameters.kv_num_heads, parameters.v_head_size, + V->Data(), v_bsnh_buffer.get(), + cuda_stream, device_prop.maxThreadsPerBlock)); + k_new_bsnh = static_cast(k_bsnh_buffer.get()); + v_new_bsnh = static_cast(v_bsnh_buffer.get()); + } + + // Step 3: Fused concat: past_key + new_key → present_key (and same for values). + // One kernel copies past data from [0, past_seq) and new data from BSNH layout + // into present buffer at [past_seq, past_seq + kv_seq), all in BNSH. + // No memset needed: uniform past_seq_lens means every position in the present + // buffer is written by the concat kernel. Padding positions in past_key are copied + // as-is; the attention mask (additive bias) handles correctness at the attention level. 
+ ORT_RETURN_IF_ERROR(onnxruntime::contrib::cuda::LaunchConcatNewToPastKV( + parameters.batch_size, + parameters.kv_num_heads, + parameters.head_size, + parameters.kv_sequence_length, + parameters.past_sequence_length, + parameters.total_sequence_length, + /*is_bsnh=*/false, + past_seqlens_buffer.get(), + /*total_seq_lens=*/nullptr, + reinterpret_cast(past_key->Data()), + reinterpret_cast(past_value->Data()), + reinterpret_cast(k_new_bsnh), + reinterpret_cast(v_new_bsnh), + reinterpret_cast(present_k_data), + reinterpret_cast(present_v_data), + cuda_stream, + device_prop.maxThreadsPerBlock, + /*past_only=*/false)); + + // Point MEA's K/V inputs at the concatenated buffers (BNSH). + k_data = present_k_data; + v_data = present_v_data; + kv_is_bsnh = false; + present_kv_already_populated = true; + } + // GQA head expansion: MEA requires matching num_heads for Q/K/V. // When q_num_heads != kv_num_heads, expand K/V via LaunchUngroup. const bool is_gqa = parameters.q_num_heads != parameters.kv_num_heads; @@ -622,7 +744,7 @@ Status Attention::RunMemoryEfficientAttention( reinterpret_cast(v_data), parameters.total_sequence_length, parameters.total_sequence_length, - is_bsnh, + kv_is_bsnh, cuda_stream, device_prop.maxThreadsPerBlock)); @@ -631,8 +753,8 @@ Status Attention::RunMemoryEfficientAttention( } } - // Note: MEA with past_key/value is handled by the unfused fallback. - // The cascade in ComputeInternal ensures past_key == nullptr when we reach here. + // Note: When past_key is present (decode), k_data/v_data already point to present + // buffers (BNSH) after LaunchConcatNewToPastKV above, so MEA sees the full cache. // Handle attention mask → attention_bias conversion IAllocatorUniquePtr converted_mask_buffer; @@ -642,7 +764,8 @@ Status Attention::RunMemoryEfficientAttention( if (nonpad_kv_seqlen != nullptr) { // Convert nonpad_kv_seqlen to seqlens_k for custom right padding. - // MEA expects actual token count (not count-1), so use FlashSeqlensK variant. 
+ // MEA expects seqlens_k as actual token count, so use FlashSeqlensK variant + // (which converts int64→int32 without subtracting 1). auto seqlens_k_buffer = GetScratchBuffer(parameters.batch_size, GetComputeStream(context)); ORT_RETURN_IF_ERROR(LaunchConvertNonpadKvSeqlenToFlashSeqlensK( nonpad_kv_seqlen->Data(), @@ -665,7 +788,7 @@ Status Attention::RunMemoryEfficientAttention( p.sm = sm; p.is_half = std::is_same::value; p.is_bf16 = std::is_same::value; - p.is_kv_bsnh = is_bsnh; + p.is_kv_bsnh = kv_is_bsnh; p.batch_size = parameters.batch_size; p.num_heads = parameters.q_num_heads; p.sequence_length = parameters.q_sequence_length; @@ -674,6 +797,15 @@ Status Attention::RunMemoryEfficientAttention( p.qk_head_size = parameters.head_size; p.v_head_size = parameters.v_head_size; p.causal = parameters.is_causal; + // ONNX spec: is_causal means upper-left alignment in the full attention matrix. + // When past_sequence_length == 0 and S_q != S_kv (cross-attention without KV cache), + // queries start at absolute position 0, so causal mask is upper-left. + // When past_sequence_length > 0 (decode with KV cache), queries start at position + // past_seq, so causal mask is effectively lower-right on the [S_q x total_kv] sub-matrix. + // NOTE: For external KV cache (TensorScatter), nonpad_kv_seqlen provides per-batch + // actual lengths and seqlens_k handles the masking — the causal_from_top_left flag + // is only consulted when params.causal is true, so it's correct here. + p.causal_from_top_left = (parameters.past_sequence_length == 0); p.scale = parameters.scale; p.softcap = parameters.softcap; p.seqlen_k_ptr = seqlens_k_buffer.get(); @@ -700,8 +832,12 @@ Status Attention::RunMemoryEfficientAttention( onnxruntime::contrib::cuda::run_memory_efficient_attention(p); // On the MEA (CUTLASS) path (used for both MHA and GQA when nonpad_kv_seqlen is provided), - // zero out output for fully-masked batches to produce zeros (matching Flash behavior). 
+ // zero out output for fully-masked batches to prevent NaN. // CUTLASS epilogue computes 1/s_prime where s_prime=0 for seqlens_k=0, producing NaN. + // TODO(titaiwang): ZeroOutputForFullyMaskedBatches outputs zeros for fully-masked + // batches (seqlens_k=0), which diverges from CPU/Unfused behavior (uniform mean of V). + // For cross-EP consistency, replace with LaunchMeanOfVForFullyMaskedBatches that + // computes mean(V[b,n,:,h]) for each masked batch. See issue #27516. { using CudaT = typename onnxruntime::cuda::OrtToCudaType::type; int64_t elements_per_batch = static_cast(parameters.q_sequence_length) * @@ -716,9 +852,10 @@ Status Attention::RunMemoryEfficientAttention( } } // Standard MEA path: float attention bias, bool mask (converted to bias), or no mask. - // Bool masks are converted to additive attention bias (true→0, false→mask_filter_value) - // which correctly handles all-false masks (uniform softmax weights) unlike the - // custom_right_padding seqlens approach which would produce NaN. + // Bool masks are converted to additive attention bias (true→0, false→mask_filter_value). + // For fully-masked batches (all-false bool mask), ConvertAttnMaskToBias uses a capped + // mask_filter_value (-1e+30) that stays finite through CUTLASS's kLog2e multiplication, + // producing correct uniform softmax → mean(V) output. 
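The effect of the cap described in the comment above can be reproduced with a scalar softmax in the exp2 domain. This is a sketch under stated assumptions, not the CUTLASS code path: `kSafeMaskFilterValue` stands in for `kCutlassSafeMaskFilterValue` using the `-1e+30` value quoted in the comment, and `CappedMaskFilterValue` and `SoftmaxExp2` are hypothetical helpers. In this reference the NaN arises from `-inf - (-inf)`, whereas the comments in the diff attribute the kernel's NaN to `s_prime = 0`.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <limits>
#include <vector>

// log2(e): scores are scaled by this factor so exp2 can replace exp.
constexpr float kLog2e = 1.4426950408889634f;

// Assumed stand-in for kCutlassSafeMaskFilterValue (value taken from the comment).
constexpr float kSafeMaskFilterValue = -1.0e+30f;

// Hypothetical helper: keep a type's lowest() unless it would overflow fp32
// after the kLog2e scaling.
inline float CappedMaskFilterValue(float lowest_of_type) {
  return std::max(lowest_of_type, kSafeMaskFilterValue);
}

// Hypothetical scalar reference: softmax over (score + bias), computed with exp2.
// The same bias is added to every position, modelling a fully-masked row.
inline std::vector<float> SoftmaxExp2(const std::vector<float>& scores, float bias) {
  assert(!scores.empty());
  std::vector<float> scaled(scores.size());
  float max_val = -std::numeric_limits<float>::infinity();
  for (std::size_t i = 0; i < scores.size(); ++i) {
    scaled[i] = (scores[i] + bias) * kLog2e;  // may overflow to -inf
    max_val = std::max(max_val, scaled[i]);
  }
  float sum = 0.0f;
  std::vector<float> out(scores.size());
  for (std::size_t i = 0; i < scores.size(); ++i) {
    out[i] = std::exp2(scaled[i] - max_val);  // -inf - (-inf) yields NaN
    sum += out[i];
  }
  for (float& v : out) v /= sum;
  return out;
}
```

Calling `SoftmaxExp2` with `std::numeric_limits<float>::lowest()` as the bias yields NaN weights, while the capped value yields uniform weights, which is what produces mean(V) for a fully-masked row. For fp16 the cap is a no-op, since `-65504` is already above `-1e+30`.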
else { if (attn_mask != nullptr) { ORT_RETURN_IF_ERROR(ConvertAttnMaskToBias(context, attn_mask, cuda_stream, @@ -731,7 +868,7 @@ Status Attention::RunMemoryEfficientAttention( p.sm = sm; p.is_half = std::is_same::value; p.is_bf16 = std::is_same::value; - p.is_kv_bsnh = is_bsnh; + p.is_kv_bsnh = kv_is_bsnh; p.batch_size = parameters.batch_size; p.num_heads = parameters.q_num_heads; p.sequence_length = parameters.q_sequence_length; @@ -740,6 +877,8 @@ Status Attention::RunMemoryEfficientAttention( p.qk_head_size = parameters.head_size; p.v_head_size = parameters.v_head_size; p.causal = parameters.is_causal; + // Causal alignment: same logic as above — upper-left when no past. + p.causal_from_top_left = (parameters.past_sequence_length == 0); p.scale = parameters.scale; p.softcap = parameters.softcap; p.broadcast_attn_bias_dim_0 = broadcast_bias_dim_0; @@ -773,30 +912,33 @@ Status Attention::RunMemoryEfficientAttention( cuda_stream, device_prop.maxThreadsPerBlock)); } - // Populate present_key/present_value (BNSH) if requested - if (present_key != nullptr && is_bsnh) { - ORT_RETURN_IF_ERROR(TransposeBSNHtoBNSH( - parameters.batch_size, parameters.kv_sequence_length, - parameters.kv_num_heads, parameters.head_size, - K->Data(), present_key->MutableData(), - cuda_stream, device_prop.maxThreadsPerBlock)); - } else if (present_key != nullptr && !is_bsnh) { - // 4D BNSH prompt: K is already BNSH, just D2D copy to present - CUDA_RETURN_IF_ERROR(cudaMemcpyAsync( - present_key->MutableData(), K->Data(), - K->SizeInBytes(), cudaMemcpyDeviceToDevice, cuda_stream)); - } - if (present_value != nullptr && is_bsnh) { - ORT_RETURN_IF_ERROR(TransposeBSNHtoBNSH( - parameters.batch_size, parameters.kv_sequence_length, - parameters.kv_num_heads, parameters.v_head_size, - V->Data(), present_value->MutableData(), - cuda_stream, device_prop.maxThreadsPerBlock)); - } else if (present_value != nullptr && !is_bsnh) { - // 4D BNSH prompt: V is already BNSH, just D2D copy to present - 
CUDA_RETURN_IF_ERROR(cudaMemcpyAsync( - present_value->MutableData(), V->Data(), - V->SizeInBytes(), cudaMemcpyDeviceToDevice, cuda_stream)); + // Populate present_key/present_value (BNSH) if requested. + // Skip for decode path where LaunchConcatNewToPastKV already populated present buffers. + if (!present_kv_already_populated) { + if (present_key != nullptr && is_bsnh) { + ORT_RETURN_IF_ERROR(TransposeBSNHtoBNSH( + parameters.batch_size, parameters.kv_sequence_length, + parameters.kv_num_heads, parameters.head_size, + K->Data(), present_key->MutableData(), + cuda_stream, device_prop.maxThreadsPerBlock)); + } else if (present_key != nullptr && !is_bsnh) { + // 4D BNSH prompt: K is already BNSH, just D2D copy to present + CUDA_RETURN_IF_ERROR(cudaMemcpyAsync( + present_key->MutableData(), K->Data(), + K->SizeInBytes(), cudaMemcpyDeviceToDevice, cuda_stream)); + } + if (present_value != nullptr && is_bsnh) { + ORT_RETURN_IF_ERROR(TransposeBSNHtoBNSH( + parameters.batch_size, parameters.kv_sequence_length, + parameters.kv_num_heads, parameters.v_head_size, + V->Data(), present_value->MutableData(), + cuda_stream, device_prop.maxThreadsPerBlock)); + } else if (present_value != nullptr && !is_bsnh) { + // 4D BNSH prompt: V is already BNSH, just D2D copy to present + CUDA_RETURN_IF_ERROR(cudaMemcpyAsync( + present_value->MutableData(), V->Data(), + V->SizeInBytes(), cudaMemcpyDeviceToDevice, cuda_stream)); + } } return Status::OK(); @@ -819,250 +961,30 @@ Status Attention::RunMemoryEfficientAttention( } // ============================================================================ -// RunUnfusedAttention: Delegates to MHA's QkvToContext (unfused GEMM+softmax+GEMM) -// ============================================================================ -// -// Unfused Attention dispatch paths: -// Universal fallback via MHA's QkvToContext. 
-// Path 1: nonpad_kv_seqlen only -> converts to attention_bias [B, q_seq, total_seq] -// Path 2: nonpad_kv_seqlen + attn_mask -> composes both into attention_bias [B, q_seq, total_seq] -// (nonpad bias + mask bias added element-wise with cyclic broadcasting) -// Path 3: all other cases -> passes mask/bias directly -// Supports: all dtypes (fp16/bf16/fp32), all mask types (bool/float/none), all head sizes -// Not supported: softcap (rejected at fallback), output_qk modes beyond kNone/kQK -// Limitation: MHA only (q_num_heads must equal kv_num_heads) -// -template -Status Attention::RunUnfusedAttention( - OpKernelContext* context, - const Tensor* Q, const Tensor* K, const Tensor* V, - const Tensor* attn_mask, const Tensor* past_key, const Tensor* past_value, - const Tensor* nonpad_kv_seqlen, - Tensor* Y, Tensor* present_key, Tensor* present_value, - Tensor* output_qk, - const attention_helper::AttentionParameters& parameters) const { - using CudaT = typename ToCudaType::MappedType; - // OrtToCudaType maps BFloat16 → __nv_bfloat16 (native HW type), matching kernel instantiations. 
- using NativeCudaT = typename onnxruntime::cuda::OrtToCudaType::type; - auto& device_prop = GetDeviceProp(); - auto cuda_stream = Stream(context); - auto ort_stream = GetOrtStream(context); - - // Bridge to contrib::AttentionParameters for the MHA unfused path - onnxruntime::contrib::AttentionParameters contribop_parameters; - - if (!parameters.transpose_output) { - contribop_parameters.qkv_format = onnxruntime::contrib::AttentionQkvFormat::Q_K_V_BNSH; - contribop_parameters.is_output_bnsh = true; - } else { - contribop_parameters.qkv_format = onnxruntime::contrib::AttentionQkvFormat::Q_K_V_BSNH; - contribop_parameters.is_output_bnsh = false; - } - - contribop_parameters.batch_size = parameters.batch_size; - contribop_parameters.sequence_length = parameters.q_sequence_length; - contribop_parameters.kv_sequence_length = parameters.kv_sequence_length; - contribop_parameters.past_sequence_length = parameters.past_sequence_length; - contribop_parameters.total_sequence_length = parameters.total_sequence_length; - contribop_parameters.max_sequence_length = parameters.total_sequence_length; - contribop_parameters.input_hidden_size = 0; - contribop_parameters.hidden_size = parameters.q_num_heads * parameters.head_size; - contribop_parameters.head_size = parameters.head_size; - contribop_parameters.v_head_size = parameters.v_head_size; - contribop_parameters.v_hidden_size = parameters.kv_num_heads * parameters.v_head_size; - contribop_parameters.num_heads = parameters.q_num_heads; - contribop_parameters.rotary_dim = 0; - contribop_parameters.num_splits = 1; - contribop_parameters.beam_width = 1; - contribop_parameters.is_unidirectional = parameters.is_causal; - contribop_parameters.past_present_share_buffer = false; - contribop_parameters.is_packed_qkv = false; - contribop_parameters.do_rotary = false; - contribop_parameters.mask_type = onnxruntime::contrib::AttentionMaskType::MASK_NONE; - contribop_parameters.mask_filter_value = static_cast(std::numeric_limits::lowest()); 
- contribop_parameters.scale = parameters.scale; - contribop_parameters.use_tf32 = UseTF32(); - - // Determine broadcast flags for attention_bias - if (attn_mask != nullptr) { - size_t attn_mask_dims_size = attn_mask->Shape().NumDimensions(); - auto attn_mask_dims = attn_mask->Shape().GetDims(); - if (attn_mask_dims_size == 2) { - contribop_parameters.broadcast_attn_bias_dim_0 = true; - contribop_parameters.broadcast_attn_bias_dim_1 = true; - } else if (attn_mask_dims_size == 3) { - contribop_parameters.broadcast_attn_bias_dim_0 = true; - contribop_parameters.broadcast_attn_bias_dim_1 = attn_mask_dims[0] == 1; - } else { - contribop_parameters.broadcast_attn_bias_dim_0 = attn_mask_dims[0] == 1; - contribop_parameters.broadcast_attn_bias_dim_1 = attn_mask_dims[1] == 1; - } - } else { - contribop_parameters.broadcast_attn_bias_dim_0 = false; - contribop_parameters.broadcast_attn_bias_dim_1 = false; - } - - // Construct AttentionData - onnxruntime::contrib::cuda::AttentionData data; - data.query = reinterpret_cast(Q->Data()); - data.key = reinterpret_cast(K->Data()); - data.value = reinterpret_cast(V->Data()); - data.mask_index = nullptr; - data.mask_index_dims = gsl::span(); - data.past_key = (past_key == nullptr) ? nullptr : reinterpret_cast(past_key->Data()); - data.past_value = (past_value == nullptr) ? nullptr : reinterpret_cast(past_value->Data()); - data.output = reinterpret_cast(Y->MutableData()); - data.present_key = (present_key == nullptr) ? nullptr : reinterpret_cast(present_key->MutableData()); - data.present_value = (present_value == nullptr) ? 
nullptr : reinterpret_cast(present_value->MutableData()); - if (output_qk != nullptr) { - data.output_qk = reinterpret_cast(output_qk->MutableData()); - } - data.bias = nullptr; - - // Handle attention mask / nonpad_kv_seqlen → attention_bias - IAllocatorUniquePtr converted_mask_buffer; - IAllocatorUniquePtr mask_bias_buffer; // temp buffer for mask→bias when composing - if (nonpad_kv_seqlen != nullptr) { - // Convert nonpad_kv_seqlen to additive attention bias: [B, q_seq, total_seq] - int64_t bias_elements = static_cast(parameters.batch_size) * - parameters.q_sequence_length * - parameters.total_sequence_length; - converted_mask_buffer = GetScratchBuffer(bias_elements * sizeof(NativeCudaT), GetComputeStream(context)); - ORT_RETURN_IF_ERROR(LaunchConvertNonpadKvSeqlenToAttentionBias( - nonpad_kv_seqlen->Data(), - reinterpret_cast(converted_mask_buffer.get()), - parameters.batch_size, - parameters.q_sequence_length, - parameters.total_sequence_length, - contribop_parameters.mask_filter_value, - cuda_stream, - device_prop.maxThreadsPerBlock)); - - // When attn_mask is also present, compose it into the nonpad bias additively. - // The nonpad bias is [B, q, t]; the mask is added with cyclic broadcasting - // (e.g. a 2D [q, t] mask repeats over the batch dimension). - // Only 2D masks and 4D masks with head_dim=1 are supported — per-head masks - // (3D [H,q,t] or 4D [B,H>1,q,t]) cannot be composed into a [B,q,t] buffer. - if (attn_mask != nullptr) { - const auto& mask_shape = attn_mask->Shape(); - int mask_dims = static_cast(mask_shape.NumDimensions()); - ORT_ENFORCE(mask_dims == 2 || (mask_dims == 4 && mask_shape[1] == 1), - "nonpad_kv_seqlen + attn_mask composition in unfused path only supports " - "2D masks [q, t] and 4D masks with head_dim=1 [B, 1, q, t]. 
" - "Got mask shape: ", - mask_shape); - - int64_t mask_elements = mask_shape.Size(); - const NativeCudaT* mask_bias_ptr = nullptr; - - if (attn_mask->IsDataType()) { - // Convert bool mask to additive bias in a temp buffer, then add in-place. - mask_bias_buffer = GetScratchBuffer(mask_elements * sizeof(NativeCudaT), GetComputeStream(context)); - ORT_RETURN_IF_ERROR(LaunchConvertBoolMaskToAttentionBias( - attn_mask->Data(), - reinterpret_cast(mask_bias_buffer.get()), - mask_elements, - contribop_parameters.mask_filter_value, - cuda_stream, - device_prop.maxThreadsPerBlock)); - mask_bias_ptr = reinterpret_cast(mask_bias_buffer.get()); - } else { - // Float mask is already in additive bias format. - mask_bias_ptr = reinterpret_cast(attn_mask->Data()); - } - - // Add mask bias into nonpad bias with cyclic broadcasting. - // 2D mask [q, t]: mask_elements = q*t, repeats for each batch → correct. - // 4D mask [B, 1, q, t]: mask_elements = B*q*t = bias_elements → direct add. - ORT_RETURN_IF_ERROR(LaunchAddBiasInPlace( - reinterpret_cast(converted_mask_buffer.get()), - mask_bias_ptr, - bias_elements, - mask_elements, - cuda_stream, - device_prop.maxThreadsPerBlock)); - } - - data.attention_bias = reinterpret_cast(converted_mask_buffer.get()); - // Composed bias is [B, q_seq, total_seq] → broadcasts over heads but not batch. 
- contribop_parameters.broadcast_attn_bias_dim_0 = false; - contribop_parameters.broadcast_attn_bias_dim_1 = true; - } else if (attn_mask != nullptr) { - if (attn_mask->IsDataType()) { - int64_t num_elements = attn_mask->Shape().Size(); - converted_mask_buffer = GetScratchBuffer(num_elements * sizeof(NativeCudaT), GetComputeStream(context)); - ORT_RETURN_IF_ERROR(LaunchConvertBoolMaskToAttentionBias( - attn_mask->Data(), - reinterpret_cast(converted_mask_buffer.get()), - num_elements, - contribop_parameters.mask_filter_value, - cuda_stream, - device_prop.maxThreadsPerBlock)); - data.attention_bias = reinterpret_cast(converted_mask_buffer.get()); - } else { - data.attention_bias = reinterpret_cast(attn_mask->Data()); - } - } - - data.qkv_format = contribop_parameters.qkv_format; - data.use_flash_attention = false; - data.use_memory_efficient_attention = false; - data.fused_runner = nullptr; - data.fused_cross_attention_kernel = nullptr; - data.kernel_type = onnxruntime::contrib::AttentionKernelType::AttentionKernel_Unfused; - - // Allocate workspace - const bool no_qkv_workspace = onnxruntime::contrib::cuda::NoQkvWorkspace(contribop_parameters, data); - size_t workspace_bytes = onnxruntime::contrib::cuda::GetAttentionWorkspaceSize( - sizeof(T), - contribop_parameters.batch_size, - contribop_parameters.num_heads, - contribop_parameters.head_size, - contribop_parameters.v_head_size, - contribop_parameters.sequence_length, - contribop_parameters.kv_sequence_length, - contribop_parameters.total_sequence_length, - nullptr, false, false, false, false, false, - no_qkv_workspace); - auto work_space = GetScratchBuffer(workspace_bytes, GetComputeStream(context)); - - data.has_qkv_workspace = !no_qkv_workspace; - data.workspace = reinterpret_cast(work_space.get()); - data.workspace_bytes = workspace_bytes; - - cublasHandle_t cublas = GetCublasHandle(context); - cudnnHandle_t cudnn = GetCudnnHandle(context); - - // Note: unfused attention produces valid finite output (mean-of-V 
via uniform softmax) - // for fully-masked batches, so ZeroOutput is not needed here. Only MEA requires - // ZeroOutput to prevent NaN from the CUTLASS epilogue's 1/s_prime division. - return onnxruntime::contrib::cuda::QkvToContext( - device_prop, cublas, cudnn, ort_stream.get(), contribop_parameters, data); -} - -// ============================================================================ -// RunGqaUnfusedAttention: GQA-capable unfused path + large-head fp16/bf16 fix +// RunUnfusedAttention: Unified unfused path for both MHA and GQA // ============================================================================ // -// Routes to LaunchGqaUnfusedAttention from contrib_ops/cuda/bert/gqa_unfused_attention.h. +// Routes to LaunchUnfusedAttention from contrib_ops/cuda/bert/unfused_attention.h. // // Handles: +// - MHA as a degenerate case (group_size=1, no head expansion needed). // - GQA natively (no K/V head replication; reshape-Q trick inside kernel). // - fp16/bf16 with large head_size via FP32 QK scratch (fixes issue #28195: // unfused attention producing NaN when head_dim > 256 at scale=1.0). // - Different Q/K sequence lengths, past_key+past_value, nonpad_kv_seqlen. // - attn_mask (bool/float, 2D/3D/4D), causal, softcap. // -// Not supported here (caller rejects upstream): -// - output_qk: only MHA unfused emits QK, so this path requires output_qk==nullptr. +// Not supported (returns NOT_IMPLEMENTED upstream): +// - qk_matmul_output_mode beyond kNone/kQK (kQKMask, kQKSoftCap, kQKSoftMax). 
// ============================================================================ template -Status Attention::RunGqaUnfusedAttention( +Status Attention::RunUnfusedAttention( OpKernelContext* context, const Tensor* Q, const Tensor* K, const Tensor* V, const Tensor* attn_mask, const Tensor* past_key, const Tensor* past_value, const Tensor* nonpad_kv_seqlen, Tensor* Y, Tensor* present_key, Tensor* present_value, + Tensor* output_qk, const attention_helper::AttentionParameters& parameters) const { using NativeCudaT = typename onnxruntime::cuda::OrtToCudaType::type; auto& device_prop = GetDeviceProp(); @@ -1108,9 +1030,6 @@ Status Attention::RunGqaUnfusedAttention( ORT_ENFORCE(past_value != nullptr, "past_key requires past_value."); ORT_ENFORCE(present_key != nullptr && present_value != nullptr, "present_key/value outputs are required when past_key is provided."); - // LaunchConcatNewToPastKV uses a single head_size for both K and V caches. - ORT_RETURN_IF(H != H_v, - "RunGqaUnfusedAttention: past_key with H != H_v not supported"); auto past_seqlens_buffer = GetScratchBuffer(B, GetComputeStream(context)); ORT_RETURN_IF_ERROR(LaunchFillInt32(past_seqlens_buffer.get(), parameters.past_sequence_length, B, @@ -1134,17 +1053,51 @@ Status Attention::RunGqaUnfusedAttention( v_new_bsnh = static_cast(v_bnsh_buffer.get()); } - ORT_RETURN_IF_ERROR(onnxruntime::contrib::cuda::LaunchConcatNewToPastKV( - B, N_kv, H, parameters.kv_sequence_length, parameters.past_sequence_length, total_kv, - /*is_bsnh=*/false, - past_seqlens_buffer.get(), /*total_seq_lens=*/nullptr, - reinterpret_cast(past_key->Data()), - reinterpret_cast(past_value->Data()), - reinterpret_cast(k_new_bsnh), - reinterpret_cast(v_new_bsnh), - reinterpret_cast(present_key->MutableData()), - reinterpret_cast(present_value->MutableData()), - cuda_stream, max_threads, /*past_only=*/false)); + if (H == H_v) { + // K and V have the same head_size -- single concat call handles both. 
+ ORT_RETURN_IF_ERROR(onnxruntime::contrib::cuda::LaunchConcatNewToPastKV( + B, N_kv, H, parameters.kv_sequence_length, parameters.past_sequence_length, total_kv, + /*is_bsnh=*/false, + past_seqlens_buffer.get(), /*total_seq_lens=*/nullptr, + reinterpret_cast(past_key->Data()), + reinterpret_cast(past_value->Data()), + reinterpret_cast(k_new_bsnh), + reinterpret_cast(v_new_bsnh), + reinterpret_cast(present_key->MutableData()), + reinterpret_cast(present_value->MutableData()), + cuda_stream, max_threads, /*past_only=*/false)); + } else { + // H != H_v: LaunchConcatNewToPastKV uses a single head_size for both K and V + // (grid Z=0 for K, Z=1 for V with the same block dims). We must call it + // twice with different head_size values -- once for K (head_size=H) and once + // for V (head_size=H_v). Each call duplicates K data into V params (or vice + // versa) so both Z indices write to the same buffer harmlessly. + // + // Trade-off: each call does 2× GPU work (both Z slices execute). This is + // acceptable because H!=H_v decode through MEA is rare, and modifying the + // shared kernel (contrib_ops/cuda/bert/attention_kv_cache.cu) to support + // nullptr outputs or K-only/V-only modes would risk breaking GQA callers. 
+ auto* pk = reinterpret_cast(past_key->Data()); + auto* pv = reinterpret_cast(past_value->Data()); + auto* nk = reinterpret_cast(k_new_bsnh); + auto* nv = reinterpret_cast(v_new_bsnh); + auto* out_k = reinterpret_cast(present_key->MutableData()); + auto* out_v = reinterpret_cast(present_value->MutableData()); + // Concat K with head_size=H (V params duplicate K data -- harmless) + ORT_RETURN_IF_ERROR(onnxruntime::contrib::cuda::LaunchConcatNewToPastKV( + B, N_kv, H, parameters.kv_sequence_length, parameters.past_sequence_length, total_kv, + /*is_bsnh=*/false, + past_seqlens_buffer.get(), /*total_seq_lens=*/nullptr, + pk, pk, nk, nk, out_k, out_k, + cuda_stream, max_threads, /*past_only=*/false)); + // Concat V with head_size=H_v (K params duplicate V data -- harmless) + ORT_RETURN_IF_ERROR(onnxruntime::contrib::cuda::LaunchConcatNewToPastKV( + B, N_kv, H_v, parameters.kv_sequence_length, parameters.past_sequence_length, total_kv, + /*is_bsnh=*/false, + past_seqlens_buffer.get(), /*total_seq_lens=*/nullptr, + pv, pv, nv, nv, out_v, out_v, + cuda_stream, max_threads, /*past_only=*/false)); + } k_cache = reinterpret_cast(present_key->MutableData()); v_cache = reinterpret_cast(present_value->MutableData()); present_already_populated = true; @@ -1214,12 +1167,12 @@ Status Attention::RunGqaUnfusedAttention( } // -------- Allocate kernel workspace ----------------------------------------- - const size_t ws_bytes = onnxruntime::contrib::cuda::GetGqaUnfusedAttentionWorkspaceSize( + const size_t ws_bytes = onnxruntime::contrib::cuda::GetUnfusedAttentionWorkspaceSize( B, N_q, S_q, total_kv); auto ws_buffer = GetScratchBuffer(ws_bytes, GetComputeStream(context)); // -------- Call the kernel --------------------------------------------------- - onnxruntime::contrib::cuda::GqaUnfusedAttentionParams p; + onnxruntime::contrib::cuda::UnfusedAttentionParams p; p.batch_size = B; p.num_heads = N_q; p.kv_num_heads = N_kv; @@ -1232,13 +1185,19 @@ Status 
Attention::RunGqaUnfusedAttention( p.broadcast_attn_bias_dim_1 = bcast1; p.is_causal = parameters.is_causal; p.local_window_size = -1; // ONNX Attention (opset 23/24) does not expose sliding window. + p.past_kv_length = parameters.past_sequence_length; p.scale = parameters.scale; p.softcap = parameters.softcap; p.seqlens_k = seqlens_k_ptr; - ORT_RETURN_IF_ERROR((onnxruntime::contrib::cuda::LaunchGqaUnfusedAttention( + NativeCudaT* output_qk_data = (output_qk != nullptr) + ? reinterpret_cast(output_qk->MutableData()) + : nullptr; + + ORT_RETURN_IF_ERROR((onnxruntime::contrib::cuda::LaunchUnfusedAttention( device_prop, GetCublasHandle(context), cuda_stream, - p, q_bnsh, k_cache, v_cache, attn_bias_data, out_bnsh, ws_buffer.get()))); + p, q_bnsh, k_cache, v_cache, attn_bias_data, out_bnsh, ws_buffer.get(), + output_qk_data))); // -------- Transpose output BNSH -> BSNH if input was 3D -------------------- if (is_bsnh && out_bnsh_buffer != nullptr) { @@ -1279,10 +1238,10 @@ Status Attention::RunGqaUnfusedAttention( // ============================================================================ // ComputeInternal: Dispatch to appropriate attention kernel // ============================================================================ -// MHA path (q_num_heads == kv_num_heads): uses direct kernel dispatch cascade -// flash → memory efficient → unfused -// GQA path (q_num_heads != kv_num_heads): uses flash (handles GQA natively), MEA -// (with head expansion via LaunchUngroup, fp16/bf16 only), or GQA unfused fallback. +// Dispatch cascade: Flash → MEA (Memory Efficient) → Unified Unfused Attention. +// The unified unfused kernel handles both MHA (num_heads == kv_num_heads) and +// GQA (num_heads != kv_num_heads) via a reshape-Q trick (no K/V head replication). +// MEA uses head expansion via LaunchUngroup (fp16/bf16 only) for GQA. 
// ============================================================================ template Status Attention::ComputeInternal(OpKernelContext* context) const { @@ -1331,12 +1290,12 @@ Status Attention::ComputeInternal(OpKernelContext* context) const { // Flash: strictly requires BSNH — Q is transposed BNSH→BSNH before calling mha_fwd*. // K/V passed as BNSH to mha_fwd_kvcache (it handles both layouts). // MEA: accepts both BSNH and BNSH natively via is_kv_bsnh flag. Q transposed to BSNH. - // Unfused: accepts both via QkvToContext's qkv_format (Q_K_V_BSNH or Q_K_V_BNSH). + // Unfused: accepts both BSNH and BNSH (transposes if needed). // // nonpad_kv_seqlen + attn_mask routing: // Flash: cannot handle this combo (no bias param when seqlens_k is used) → excluded. // MEA: supports both (custom_right_padding for seqlens + additive attn_bias for mask). - // Unfused: nonpad → attention_bias; mask composed additively when both present. + // Unfused: nonpad → seqlens_k; mask → attention_bias; both handled independently in softmax kernel. #if USE_FLASH_ATTENTION || USE_MEMORY_EFFICIENT_ATTENTION const bool has_output_qk = (qk_matmul_output_mode_ != attention_helper::QKMatMulOutputMode::kNone); #endif @@ -1347,6 +1306,39 @@ Status Attention::ComputeInternal(OpKernelContext* context) const { // softmax_precision=0 (default) is also fine since higher precision is always // acceptable per the ONNX spec. + // Flash Attention uses lower-right (bottom-right) causal alignment with no option for + // upper-left. The ONNX spec requires upper-left alignment when there is no past context: + // query[0] attends only to key[0]. The difference only manifests when S_q != S_kv + // (cross-attention shape) with no past. Skip Flash for this case; MEA handles it correctly + // via the causal_from_top_left flag, and Unified Unfused uses past_kv_length=0. + // Defined here for visibility — only Flash needs this guard (MEA/Unfused handle upper-left natively). 
+ const bool causal_cross_no_past = parameters.is_causal && + parameters.q_sequence_length != parameters.total_sequence_length && + parameters.past_sequence_length == 0; + + // Reject causal + TensorScatter decode (S_q < S_kv without past_key). + // Per ONNX spec, is_causal without past_key means upper-left alignment: q[i] attends + // only to kv[0..i]. For decode with external cache (S_q=1, S_kv=cache_size), this means + // q[0] sees only kv[0] — not meaningful for autoregressive generation. + // + // Why is_causal=0 is correct for external cache decode: + // - With S_q=1, there's only one query position at the end of the sequence + // - All KV positions are in the "past" relative to this query — nothing to mask + // - nonpad_kv_seqlen already bounds attention to valid cache positions + // + // For external cache prompt (S_q == S_kv), is_causal=1 works correctly (square matrix, + // upper-left == lower-right). For chunked prefill (S_q > 1 but S_q < S_kv), use an + // explicit attn_mask instead of is_causal. + if (causal_cross_no_past && nonpad_kv_seqlen != nullptr) { + return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, + "Causal attention with TensorScatter (nonpad_kv_seqlen) and S_q != S_kv without " + "past_key is not supported. Per ONNX spec, is_causal without past_key produces " + "upper-left alignment where q[i] only attends to kv[0..i], which for decode (S_q=1) " + "means q[0] sees only kv[0]. Use is_causal=0 for TensorScatter decode; the KV bounds " + "are already enforced by nonpad_kv_seqlen without needing a causal mask. 
For chunked " + "prefill with external cache, use an explicit attn_mask instead."); + } + #if USE_FLASH_ATTENTION { auto& device_prop = GetDeviceProp(); @@ -1357,16 +1349,16 @@ Status Attention::ComputeInternal(OpKernelContext* context) const { parameters.q_num_heads, parameters.kv_num_heads) && parameters.head_size == parameters.v_head_size && !has_output_qk && - // Flash does not support attention masks (no bias parameter in mha_fwd/mha_fwd_kvcache). - // Bool attn_mask + past_key is rejected because Flash uses paged KV cache semantics - // that produce spec-divergent present_kv layout for partial masks (e.g. [T,T,T,F]). - // Unfused handles bool+past_key spec-correctly via standard ConcatPastToPresent. - // TODO(titaiwang): GQA + bool attn_mask + past_key currently has no runner (Flash - // rejected here, unfused doesn't support GQA, MEA blocked by past_key != nullptr). - // Once PR #27851 merges (MEA supports past_key), this gap will be covered. + !causal_cross_no_past && + // Flash does not support attention masks — reject when attn_mask is present. attn_mask == nullptr; if (flash_eligible) { + LOGS_DEFAULT(VERBOSE) << "ONNX Attention: using Flash Attention" + << " (batch=" << parameters.batch_size + << ", q_seq=" << parameters.q_sequence_length + << ", total_seq=" << parameters.total_sequence_length + << ", past=" << (past_key != nullptr ? "yes" : "no") << ")"; return RunFlashAttention(context, Q, K, V, past_key, past_value, nonpad_kv_seqlen, Y, present_key, present_value, parameters); } @@ -1383,7 +1375,12 @@ Status Attention::ComputeInternal(OpKernelContext* context) const { sm, std::is_same::value, std::is_same::value, parameters.head_size, parameters.v_head_size) && !has_output_qk && - past_key == nullptr && + // MEA requires head_size == v_head_size in two internal paths: + // - LaunchConcatNewToPastKV (decode with past_key) + // - LaunchUngroup (GQA head expansion) + // Fall back to unfused attention when they differ. 
+ (!is_gqa || parameters.head_size == parameters.v_head_size) && + (past_key == nullptr || parameters.head_size == parameters.v_head_size) && // GQA+MEA requires LaunchUngroup which only has fp16/bf16 instantiations. // FP32 GQA must fall through to the unfused path. !(is_gqa && std::is_same::value); @@ -1408,65 +1405,43 @@ Status Attention::ComputeInternal(OpKernelContext* context) const { } if (mea_eligible) { + LOGS_DEFAULT(VERBOSE) << "ONNX Attention: using Memory Efficient Attention" + << " (batch=" << parameters.batch_size + << ", q_seq=" << parameters.q_sequence_length + << ", total_seq=" << parameters.total_sequence_length + << ", past=" << (past_key != nullptr ? "yes" : "no") + << ", mask=" << (attn_mask != nullptr ? "yes" : "no") << ")"; return RunMemoryEfficientAttention(context, Q, K, V, attn_mask, past_key, past_value, nonpad_kv_seqlen, Y, present_key, present_value, parameters); } } #endif - // TODO(titaiwang): Support additional output_qk modes beyond kNone and kQK. - // Currently only unfused handles output_qk, and only kNone/kQK modes. + // Fallback: unified unfused attention + // Routes ALL cases to LaunchUnfusedAttention, which handles: + // - GQA natively (reshape-Q trick inside kernel, no K/V head replication) + // - MHA as a degenerate case (group_size=1) + // - fp16/bf16 with large head_size via FP32 QK scratch + // - softcap, attn_mask, causal, past_key+past_value, nonpad_kv_seqlen + // - output_qk (kQK mode: scale * Q @ K^T, before softcap/mask/softmax) + // - past_key with H != H_v (separate concat calls for K and V) + + // Guard: unified kernel only supports kNone and kQK output modes. + // Other modes (kQKMask, kQKSoftCap, kQKSoftMax) expect QK values captured at + // different pipeline stages that the unified kernel does not implement. 
if (qk_matmul_output_mode_ != attention_helper::QKMatMulOutputMode::kNone && qk_matmul_output_mode_ != attention_helper::QKMatMulOutputMode::kQK) { return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, - "qk_matmul_output_mode other than kNone and kQK is not supported yet " - "in Attention op (CUDA)."); - } - - // GQA-capable unfused fallback (issue #28195). - // Routes through LaunchGqaUnfusedAttention when: - // - GQA (q_num_heads != kv_num_heads) — the MHA unfused runner cannot handle this. - // - fp16/bf16 with head_size > 128 — raw Q*K^T can overflow fp16 storage even - // though cuBLAS accumulates in FP32; the new kernel writes QK to an FP32 scratch. - // The overflow threshold depends on the distribution of Q/K values and scale. - // head_size=256 at scale=1/sqrt(256)=0.0625 is borderline; head_size=512 at - // scale=1.0 (Gemma 4) definitely overflows. We use 128 as a conservative - // threshold since all fused kernels already handle head_size <= 128 anyway. - // This kernel supports softcap. It does not support output_qk, so we only enter it - // when qk_matmul_output_mode_ == kNone. - const bool is_half_or_bf16 = std::is_same::value || std::is_same::value; - const bool needs_fp32_qk_scratch = is_half_or_bf16 && parameters.head_size > 128; - if ((is_gqa || needs_fp32_qk_scratch) && - qk_matmul_output_mode_ == attention_helper::QKMatMulOutputMode::kNone) { - LOGS_DEFAULT(VERBOSE) << "Attention: using GQA unfused fallback (is_gqa=" << is_gqa - << ", needs_fp32_qk_scratch=" << needs_fp32_qk_scratch - << ", head_size=" << parameters.head_size - << ", softcap=" << parameters.softcap << ")"; - return RunGqaUnfusedAttention(context, Q, K, V, attn_mask, past_key, past_value, - nonpad_kv_seqlen, Y, present_key, present_value, parameters); - } - - if (is_gqa) { - // qk_matmul_output_mode != kNone reaches here; the unfused MHA runner cannot handle GQA. 
- return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, - "ONNX Attention with GQA (q_num_heads != kv_num_heads) and output_qk is not " - "supported by the unfused runner."); - } - - // Fallback: unfused MHA attention (legacy runner). - // Softcap is not implemented in the legacy unfused path — it requires Flash or MEA - // (or the new GQA unfused path above, which supports softcap for fp16/bf16/fp32). - // NOTE: keep this guard even if future PRs add softcap to more fused paths — this - // legacy unfused runner does NOT apply softcap and would silently produce wrong results. - if (parameters.softcap > 0.0f) { - return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, - "softcap requires flash attention or memory efficient attention, " - "but neither is eligible for this configuration. Check dtype (fp16/bf16 required for Flash), " - "head_size constraints, and past_key compatibility."); + "Only kNone and kQK output modes are supported in unified unfused attention. Mode: ", + static_cast(qk_matmul_output_mode_)); } + LOGS_DEFAULT(VERBOSE) << "Attention: using unified unfused path (is_gqa=" << is_gqa + << ", head_size=" << parameters.head_size + << ", softcap=" << parameters.softcap << ")"; return RunUnfusedAttention(context, Q, K, V, attn_mask, past_key, past_value, - nonpad_kv_seqlen, Y, present_key, present_value, output_qk, parameters); + nonpad_kv_seqlen, Y, present_key, present_value, + output_qk, parameters); } } // namespace cuda diff --git a/onnxruntime/core/providers/cuda/llm/attention.h b/onnxruntime/core/providers/cuda/llm/attention.h index 2acbf3b2ed829..f11503f154a30 100644 --- a/onnxruntime/core/providers/cuda/llm/attention.h +++ b/onnxruntime/core/providers/cuda/llm/attention.h @@ -31,27 +31,18 @@ class Attention final : public CudaKernel { Tensor* Y, Tensor* present_key, Tensor* present_value, const attention_helper::AttentionParameters& parameters) const; - Status RunUnfusedAttention( - OpKernelContext* context, - const Tensor* Q, const Tensor* K, const 
Tensor* V, - const Tensor* attn_mask, const Tensor* past_key, const Tensor* past_value, - const Tensor* nonpad_kv_seqlen, - Tensor* Y, Tensor* present_key, Tensor* present_value, - Tensor* output_qk, - const attention_helper::AttentionParameters& parameters) const; - - // GQA-capable unfused fallback. Handles: + // Unified unfused fallback. Handles: // - GQA (q_num_heads != kv_num_heads) without K/V head replication. // - fp16/bf16 with large head_size (FP32 QK accumulation, fixes #28195). // - past_key+past_value, attn_mask (bool/float), nonpad_kv_seqlen. - // Does not support: output_qk - // (output_qk modes other than kNone are rejected upstream). - Status RunGqaUnfusedAttention( + // - output_qk (kQK mode: scale * Q @ K^T, before softcap/mask/softmax). + Status RunUnfusedAttention( OpKernelContext* context, const Tensor* Q, const Tensor* K, const Tensor* V, const Tensor* attn_mask, const Tensor* past_key, const Tensor* past_value, const Tensor* nonpad_kv_seqlen, Tensor* Y, Tensor* present_key, Tensor* present_value, + Tensor* output_qk, const attention_helper::AttentionParameters& parameters) const; Status ConvertAttnMaskToBias( diff --git a/onnxruntime/core/providers/cuda/llm/attention_mask_impl.cu b/onnxruntime/core/providers/cuda/llm/attention_mask_impl.cu index 4ab3990b2f85d..2ba7f2e1a9836 100644 --- a/onnxruntime/core/providers/cuda/llm/attention_mask_impl.cu +++ b/onnxruntime/core/providers/cuda/llm/attention_mask_impl.cu @@ -89,107 +89,6 @@ Status LaunchConvertNonpadKvSeqlenToFlashSeqlensK( return CUDA_CALL(cudaGetLastError()); } -// CUDA kernel to convert nonpad_kv_seqlen to an additive attention bias. 
-// Generates (batch_size, q_seq_len, total_seq_len) output where:
-//   position t < nonpad_kv_seqlen[b]  → 0.0 (attend)
-//   position t >= nonpad_kv_seqlen[b] → mask_filter_value (mask out)
-template <typename T>
-__global__ void ConvertNonpadKvSeqlenToAttentionBiasKernel(
-    const int64_t* __restrict__ nonpad_kv_seqlen,
-    T* __restrict__ attention_bias,
-    const int batch_size,
-    const int q_seq_len,
-    const int total_seq_len,
-    const float mask_filter_value) {
-  int64_t idx = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
-  int64_t total = static_cast<int64_t>(batch_size) * q_seq_len * total_seq_len;
-  for (; idx < total; idx += static_cast<int64_t>(gridDim.x) * blockDim.x) {
-    int b = static_cast<int>(idx / (static_cast<int64_t>(q_seq_len) * total_seq_len));
-    int t = static_cast<int>(idx % total_seq_len);
-    int64_t valid_len = nonpad_kv_seqlen[b];
-    CUDA_KERNEL_ASSERT(valid_len >= 0 && valid_len <= static_cast<int64_t>(total_seq_len));
-    valid_len = max(static_cast<int64_t>(0), min(valid_len, static_cast<int64_t>(total_seq_len)));
-    attention_bias[idx] = (t < static_cast<int>(valid_len)) ? T(0.0f) : T(mask_filter_value);
-  }
-}
-
-template <typename T>
-Status LaunchConvertNonpadKvSeqlenToAttentionBias(
-    const int64_t* nonpad_kv_seqlen,
-    T* attention_bias,
-    int batch_size,
-    int q_seq_len,
-    int total_seq_len,
-    float mask_filter_value,
-    cudaStream_t stream,
-    int max_threads_per_block) {
-  int64_t total = static_cast<int64_t>(batch_size) * q_seq_len * total_seq_len;
-  if (total == 0) {
-    return Status::OK();
-  }
-
-  int threads = static_cast<int>(std::min(static_cast<int64_t>(max_threads_per_block), total));
-  int64_t blocks = (total + threads - 1) / threads;
-  constexpr int64_t kMaxGridDimX = 65535;
-  unsigned int grid_size = static_cast<unsigned int>(std::min(blocks, kMaxGridDimX));
-
-  ConvertNonpadKvSeqlenToAttentionBiasKernel<T><<<grid_size, threads, 0, stream>>>(
-      nonpad_kv_seqlen, attention_bias, batch_size, q_seq_len, total_seq_len, mask_filter_value);
-
-  return CUDA_CALL(cudaGetLastError());
-}
-
-template Status LaunchConvertNonpadKvSeqlenToAttentionBias<float>(
-    const int64_t*, float*, int, int, int, float, cudaStream_t, int);
-template Status LaunchConvertNonpadKvSeqlenToAttentionBias<__half>(
-    const int64_t*, __half*, int, int, int, float, cudaStream_t, int);
-template Status LaunchConvertNonpadKvSeqlenToAttentionBias<__nv_bfloat16>(
-    const int64_t*, __nv_bfloat16*, int, int, int, float, cudaStream_t, int);
-
-// Add an addend bias into an existing bias buffer using cyclic broadcasting.
-// Used to compose nonpad_kv_seqlen bias [B, q, t] with an attn_mask bias that
-// is smaller or equal (e.g. 2D [q, t] cyclic-broadcasts over batch dimension).
-template <typename T>
-__global__ void AddBiasInPlaceKernel(
-    T* __restrict__ bias,
-    const T* __restrict__ addend,
-    int64_t total_elements,
-    int64_t addend_elements) {
-  for (int64_t idx = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
-       idx < total_elements;
-       idx += static_cast<int64_t>(gridDim.x) * blockDim.x) {
-    float sum = static_cast<float>(bias[idx]) + static_cast<float>(addend[idx % addend_elements]);
-    bias[idx] = T(sum);
-  }
-}
-
-template <typename T>
-Status LaunchAddBiasInPlace(
-    T* bias,
-    const T* addend,
-    int64_t total_elements,
-    int64_t addend_elements,
-    cudaStream_t stream,
-    int max_threads_per_block) {
-  if (total_elements == 0 || addend_elements == 0) {
-    return Status::OK();
-  }
-
-  int threads = static_cast<int>(std::min(static_cast<int64_t>(max_threads_per_block), total_elements));
-  int64_t blocks = (total_elements + threads - 1) / threads;
-  constexpr int64_t kMaxGridDimX = 65535;
-  unsigned int grid_size = static_cast<unsigned int>(std::min(blocks, kMaxGridDimX));
-
-  AddBiasInPlaceKernel<T><<<grid_size, threads, 0, stream>>>(
-      bias, addend, total_elements, addend_elements);
-
-  return CUDA_CALL(cudaGetLastError());
-}
-
-template Status LaunchAddBiasInPlace<float>(float*, const float*, int64_t, int64_t, cudaStream_t, int);
-template Status LaunchAddBiasInPlace<__half>(__half*, const __half*, int64_t, int64_t, cudaStream_t, int);
-template Status LaunchAddBiasInPlace<__nv_bfloat16>(__nv_bfloat16*, const __nv_bfloat16*, int64_t, int64_t, cudaStream_t, int);
-
 // Zero output elements for batches where seqlens_k == 0 (fully masked).
 // CUTLASS MEA epilogue computes 1/s_prime where s_prime=0 → NaN for fully-masked
 // batches. The unfused path produces uniform softmax weights (finite mask_filter_value,
diff --git a/onnxruntime/core/providers/cuda/llm/attention_mask_impl.h b/onnxruntime/core/providers/cuda/llm/attention_mask_impl.h
index 1ada783e9d64d..d2cb4dbbd25ae 100644
--- a/onnxruntime/core/providers/cuda/llm/attention_mask_impl.h
+++ b/onnxruntime/core/providers/cuda/llm/attention_mask_impl.h
@@ -31,34 +31,6 @@ Status LaunchConvertNonpadKvSeqlenToFlashSeqlensK(
     cudaStream_t stream,
     int max_threads_per_block);
 
-// Convert nonpad_kv_seqlen to an additive attention bias for the MHA unfused path.
-// Generates a (batch_size, q_seq_len, total_seq_len) tensor where:
-//   position t < nonpad_kv_seqlen[b]  → 0.0 (attend)
-//   position t >= nonpad_kv_seqlen[b] → mask_filter_value (mask out)
-template <typename T>
-Status LaunchConvertNonpadKvSeqlenToAttentionBias(
-    const int64_t* nonpad_kv_seqlen,
-    T* attention_bias,
-    int batch_size,
-    int q_seq_len,
-    int total_seq_len,
-    float mask_filter_value,
-    cudaStream_t stream,
-    int max_threads_per_block);
-
-// Additively compose an addend bias into an existing bias buffer in-place.
-// Supports cyclic broadcasting: addend of size [q, t] is repeated over batch
-// to compose with a bias of size [B, q, t]. When both have the same number
-// of elements (e.g. 4D mask [B, 1, q, t]), it performs a direct element-wise add.
-template <typename T>
-Status LaunchAddBiasInPlace(
-    T* bias,
-    const T* addend,
-    int64_t total_elements,
-    int64_t addend_elements,
-    cudaStream_t stream,
-    int max_threads_per_block);
-
 // Zero output elements for batches where seqlens_k == 0 (fully masked).
 // Used in the MEA path only: CUTLASS epilogue computes 1/s_prime where s_prime=0,
 // producing NaN for fully-masked batches. This kernel overwrites those NaN outputs
diff --git a/onnxruntime/core/providers/cuda/math/binary_elementwise_ops.cc b/onnxruntime/core/providers/cuda/math/binary_elementwise_ops.cc
index e4faa50d7acbc..babbb4b3ba672 100644
--- a/onnxruntime/core/providers/cuda/math/binary_elementwise_ops.cc
+++ b/onnxruntime/core/providers/cuda/math/binary_elementwise_ops.cc
@@ -579,8 +579,10 @@ Status LessOrEqual::ComputeInternal(OpKernelContext* context) const {
   return this->CompareMethod(context, &ImplT2_LessOrEqual);
 }
 
-BINARY_LOGICALOP_REGISTER_UZILHFD(Equal, 13)
-BINARY_ELEMENTWISE_LOGICALOP_REGISTER_KERNEL_TYPED(Equal, 13, bool)
+BINARY_LOGICALOP_REGISTER_VERSIONED_UZILHFD(Equal, 13, 18)
+BINARY_ELEMENTWISE_LOGICALOP_REGISTER_KERNEL_VERSIONED_TYPED(Equal, 13, 18, bool)
+BINARY_LOGICALOP_REGISTER_UZILHFD(Equal, 19)
+BINARY_ELEMENTWISE_LOGICALOP_REGISTER_KERNEL_TYPED(Equal, 19, bool)
 BINARY_OP_REGISTER_VERSIONED_UZILHFD(Equal, 11, 12)
 BINARY_ELEMENTWISE_REGISTER_KERNEL_VERSIONED_TYPED(Equal, 11, 12, bool)
 BINARY_OP_REGISTER_VERSIONED_OIL(Equal, 7, 10)
diff --git a/onnxruntime/core/providers/cuda/math/unary_elementwise_ops.cc b/onnxruntime/core/providers/cuda/math/unary_elementwise_ops.cc
index 86a1b0f5b6102..a54b96da6c174 100644
--- a/onnxruntime/core/providers/cuda/math/unary_elementwise_ops.cc
+++ b/onnxruntime/core/providers/cuda/math/unary_elementwise_ops.cc
@@ -249,7 +249,8 @@ UNARY_OP_HFDX(Erf, 13)
 UNARY_OP_BWUZCSILHFDX(Sign, 13)
 
 UNARY_LOGICALOP_NOT_TYPED(1, bool)
-UNARY_OP_HFD(Round, 11)
+UNARY_OP_VERSIONED_HFD(Round, 11, 21)
+UNARY_OP_HFD(Round, 22)
 UNARY_OP_HFD(Cos, 7)
 UNARY_OP_HFD(Sin, 7)
 
diff --git a/onnxruntime/core/providers/cuda/reduction/reduction_ops.cc b/onnxruntime/core/providers/cuda/reduction/reduction_ops.cc
index 127cfcc557fd5..a0a2f377d0c80 100644
--- a/onnxruntime/core/providers/cuda/reduction/reduction_ops.cc
+++ b/onnxruntime/core/providers/cuda/reduction/reduction_ops.cc
@@ -36,6 +36,16 @@ namespace cuda {
       (*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType<T>()).InputMemoryType(OrtMemTypeCPUInput, 1), \
       name<T>);
 
+#define REGISTER_KERNEL_VERSIONED_RANGE_AXES_INPUT_TYPED(name, T, begin, end) \
+  ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \
+      name, \
+      kOnnxDomain, \
+      begin, end, \
+      T, \
+      kCudaExecutionProvider, \
+      (*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType<T>()).InputMemoryType(OrtMemTypeCPUInput, 1), \
+      name<T>);
+
 #define REGISTER_KERNEL_TYPED_AXES_INPUT_WITH_VERSIONED(name, T, last, cur) \
   REGISTER_KERNEL_VERSIONED_RANGE_TYPED(name, T, 1, last) \
   REGISTER_KERNEL_VERSIONED_SINCE_TYPED(name, T, cur)
@@ -876,13 +886,27 @@ REGISTER_KERNEL_ARGMIN_OR_ARGMAX(ArgMin, MLFloat16)
 REGISTER_KERNEL_ARGMIN_OR_ARGMAX(ArgMin, float)
 REGISTER_KERNEL_ARGMIN_OR_ARGMAX(ArgMin, double)
 
-REGISTER_KERNEL_TYPED_AXES_INPUT_WITH_VERSIONED(ReduceMax, MLFloat16, 17, 18)
-REGISTER_KERNEL_TYPED_AXES_INPUT_WITH_VERSIONED(ReduceMax, float, 17, 18)
-REGISTER_KERNEL_TYPED_AXES_INPUT_WITH_VERSIONED(ReduceMax, double, 17, 18)
-REGISTER_KERNEL_TYPED_AXES_INPUT_WITH_VERSIONED(ReduceMax, int32_t, 17, 18)
-REGISTER_KERNEL_TYPED_AXES_INPUT_WITH_VERSIONED(ReduceMax, int64_t, 17, 18)
-REGISTER_KERNEL_TYPED_AXES_INPUT_WITH_VERSIONED(ReduceMax, int8_t, 17, 18)
-REGISTER_KERNEL_TYPED_AXES_INPUT_WITH_VERSIONED(ReduceMax, uint8_t, 17, 18)
+REGISTER_KERNEL_VERSIONED_RANGE_TYPED(ReduceMax, MLFloat16, 1, 17)
+REGISTER_KERNEL_VERSIONED_RANGE_TYPED(ReduceMax, float, 1, 17)
+REGISTER_KERNEL_VERSIONED_RANGE_TYPED(ReduceMax, double, 1, 17)
+REGISTER_KERNEL_VERSIONED_RANGE_TYPED(ReduceMax, int32_t, 1, 17)
+REGISTER_KERNEL_VERSIONED_RANGE_TYPED(ReduceMax, int64_t, 1, 17)
+REGISTER_KERNEL_VERSIONED_RANGE_TYPED(ReduceMax, int8_t, 1, 17)
+REGISTER_KERNEL_VERSIONED_RANGE_TYPED(ReduceMax, uint8_t, 1, 17)
+REGISTER_KERNEL_VERSIONED_RANGE_AXES_INPUT_TYPED(ReduceMax, MLFloat16, 18, 19)
+REGISTER_KERNEL_VERSIONED_RANGE_AXES_INPUT_TYPED(ReduceMax, float, 18, 19)
+REGISTER_KERNEL_VERSIONED_RANGE_AXES_INPUT_TYPED(ReduceMax, double, 18, 19)
+REGISTER_KERNEL_VERSIONED_RANGE_AXES_INPUT_TYPED(ReduceMax, int32_t, 18, 19)
+REGISTER_KERNEL_VERSIONED_RANGE_AXES_INPUT_TYPED(ReduceMax, int64_t, 18, 19)
+REGISTER_KERNEL_VERSIONED_RANGE_AXES_INPUT_TYPED(ReduceMax, int8_t, 18, 19)
+REGISTER_KERNEL_VERSIONED_RANGE_AXES_INPUT_TYPED(ReduceMax, uint8_t, 18, 19)
+REGISTER_KERNEL_VERSIONED_SINCE_TYPED(ReduceMax, MLFloat16, 20)
+REGISTER_KERNEL_VERSIONED_SINCE_TYPED(ReduceMax, float, 20)
+REGISTER_KERNEL_VERSIONED_SINCE_TYPED(ReduceMax, double, 20)
+REGISTER_KERNEL_VERSIONED_SINCE_TYPED(ReduceMax, int32_t, 20)
+REGISTER_KERNEL_VERSIONED_SINCE_TYPED(ReduceMax, int64_t, 20)
+REGISTER_KERNEL_VERSIONED_SINCE_TYPED(ReduceMax, int8_t, 20)
+REGISTER_KERNEL_VERSIONED_SINCE_TYPED(ReduceMax, uint8_t, 20)
 
 REGISTER_KERNEL_TYPED_AXES_INPUT_WITH_VERSIONED(ReduceMean, MLFloat16, 17, 18)
 REGISTER_KERNEL_TYPED_AXES_INPUT_WITH_VERSIONED(ReduceMean, float, 17, 18)
@@ -890,13 +914,27 @@ REGISTER_KERNEL_TYPED_AXES_INPUT_WITH_VERSIONED(ReduceMean, double, 17, 18)
 REGISTER_KERNEL_TYPED_AXES_INPUT_WITH_VERSIONED(ReduceMean, BFloat16, 17, 18)
 REGISTER_KERNEL_TYPED_AXES_INPUT_WITH_VERSIONED(ReduceMean, int32_t, 17, 18)
 
-REGISTER_KERNEL_TYPED_AXES_INPUT_WITH_VERSIONED(ReduceMin, MLFloat16, 17, 18)
-REGISTER_KERNEL_TYPED_AXES_INPUT_WITH_VERSIONED(ReduceMin, float, 17, 18)
-REGISTER_KERNEL_TYPED_AXES_INPUT_WITH_VERSIONED(ReduceMin, double, 17, 18)
-REGISTER_KERNEL_TYPED_AXES_INPUT_WITH_VERSIONED(ReduceMin, int32_t, 17, 18)
-REGISTER_KERNEL_TYPED_AXES_INPUT_WITH_VERSIONED(ReduceMin, int64_t, 17, 18)
-REGISTER_KERNEL_TYPED_AXES_INPUT_WITH_VERSIONED(ReduceMin, int8_t, 17, 18)
-REGISTER_KERNEL_TYPED_AXES_INPUT_WITH_VERSIONED(ReduceMin, uint8_t, 17, 18)
+REGISTER_KERNEL_VERSIONED_RANGE_TYPED(ReduceMin, MLFloat16, 1, 17)
+REGISTER_KERNEL_VERSIONED_RANGE_TYPED(ReduceMin, float, 1, 17)
+REGISTER_KERNEL_VERSIONED_RANGE_TYPED(ReduceMin, double, 1, 17)
+REGISTER_KERNEL_VERSIONED_RANGE_TYPED(ReduceMin, int32_t, 1, 17)
+REGISTER_KERNEL_VERSIONED_RANGE_TYPED(ReduceMin, int64_t, 1, 17)
+REGISTER_KERNEL_VERSIONED_RANGE_TYPED(ReduceMin, int8_t, 1, 17)
+REGISTER_KERNEL_VERSIONED_RANGE_TYPED(ReduceMin, uint8_t, 1, 17)
+REGISTER_KERNEL_VERSIONED_RANGE_AXES_INPUT_TYPED(ReduceMin, MLFloat16, 18, 19)
+REGISTER_KERNEL_VERSIONED_RANGE_AXES_INPUT_TYPED(ReduceMin, float, 18, 19)
+REGISTER_KERNEL_VERSIONED_RANGE_AXES_INPUT_TYPED(ReduceMin, double, 18, 19)
+REGISTER_KERNEL_VERSIONED_RANGE_AXES_INPUT_TYPED(ReduceMin, int32_t, 18, 19)
+REGISTER_KERNEL_VERSIONED_RANGE_AXES_INPUT_TYPED(ReduceMin, int64_t, 18, 19)
+REGISTER_KERNEL_VERSIONED_RANGE_AXES_INPUT_TYPED(ReduceMin, int8_t, 18, 19)
+REGISTER_KERNEL_VERSIONED_RANGE_AXES_INPUT_TYPED(ReduceMin, uint8_t, 18, 19)
+REGISTER_KERNEL_VERSIONED_SINCE_TYPED(ReduceMin, MLFloat16, 20)
+REGISTER_KERNEL_VERSIONED_SINCE_TYPED(ReduceMin, float, 20)
+REGISTER_KERNEL_VERSIONED_SINCE_TYPED(ReduceMin, double, 20)
+REGISTER_KERNEL_VERSIONED_SINCE_TYPED(ReduceMin, int32_t, 20)
+REGISTER_KERNEL_VERSIONED_SINCE_TYPED(ReduceMin, int64_t, 20)
+REGISTER_KERNEL_VERSIONED_SINCE_TYPED(ReduceMin, int8_t, 20)
+REGISTER_KERNEL_VERSIONED_SINCE_TYPED(ReduceMin, uint8_t, 20)
 
 REGISTER_KERNEL_TYPED_AXES_INPUT_WITH_VERSIONED(ReduceProd, MLFloat16, 17, 18)
 REGISTER_KERNEL_TYPED_AXES_INPUT_WITH_VERSIONED(ReduceProd, float, 17, 18)
diff --git a/onnxruntime/core/providers/cuda/rnn/rnn.cc b/onnxruntime/core/providers/cuda/rnn/rnn.cc
index ed8be63679707..236aa5022fa80 100644
--- a/onnxruntime/core/providers/cuda/rnn/rnn.cc
+++ b/onnxruntime/core/providers/cuda/rnn/rnn.cc
@@ -24,11 +24,25 @@ namespace cuda {
           .InputMemoryType(OrtMemTypeCPUInput, RNN_Input_Index::sequence_lens), \
       RNN<T>);
 
+#define REGISTER_KERNEL_VERSIONED_TYPED_14(T) \
+  ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \
+      RNN, \
+      kOnnxDomain, \
+      14, \
+      21, \
+      T, \
+      kCudaExecutionProvider, \
+      (*KernelDefBuilder::Create()) \
+          .TypeConstraint("T", DataTypeImpl::GetTensorType<T>()) \
+          .TypeConstraint("T1", DataTypeImpl::GetTensorType<int32_t>()) \
+          .InputMemoryType(OrtMemTypeCPUInput, RNN_Input_Index::sequence_lens), \
+      RNN<T>);
+
 #define REGISTER_KERNEL_TYPED(T) \
   ONNX_OPERATOR_TYPED_KERNEL_EX( \
       RNN, \
       kOnnxDomain, \
-      14, \
+      22, \
       T, \
       kCudaExecutionProvider, \
       (*KernelDefBuilder::Create()) \
@@ -41,6 +55,10 @@ REGISTER_KERNEL_VERSIONED_TYPED(float);
 REGISTER_KERNEL_VERSIONED_TYPED(double);
 REGISTER_KERNEL_VERSIONED_TYPED(MLFloat16);
 
+REGISTER_KERNEL_VERSIONED_TYPED_14(float);
+REGISTER_KERNEL_VERSIONED_TYPED_14(double);
+REGISTER_KERNEL_VERSIONED_TYPED_14(MLFloat16);
+
 REGISTER_KERNEL_TYPED(float);
 REGISTER_KERNEL_TYPED(double);
 REGISTER_KERNEL_TYPED(MLFloat16);
diff --git a/onnxruntime/core/providers/cuda/tensor/cast_op.cc b/onnxruntime/core/providers/cuda/tensor/cast_op.cc
index 8f5c9202c1dba..2ed08e25d02d2 100644
--- a/onnxruntime/core/providers/cuda/tensor/cast_op.cc
+++ b/onnxruntime/core/providers/cuda/tensor/cast_op.cc
@@ -90,10 +90,20 @@ const std::vector& CastOpTypeConstraints() {
           .TypeConstraint("T1", DataTypeImpl::GetTensorType<T>()) \
           .TypeConstraint("T2", CastOpTypeConstraints()), \
       Cast<T>); \
+  ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \
+      Cast, \
+      kOnnxDomain, \
+      23, 24, \
+      T, \
+      kCudaExecutionProvider, \
+      (*KernelDefBuilder::Create()) \
+          .TypeConstraint("T1", DataTypeImpl::GetTensorType<T>()) \
+          .TypeConstraint("T2", CastOpTypeConstraints()), \
+      Cast<T>); \
   ONNX_OPERATOR_TYPED_KERNEL_EX( \
       Cast, \
       kOnnxDomain, \
-      23, \
+      25, \
       T, \
       kCudaExecutionProvider, \
       (*KernelDefBuilder::Create()) \
@@ -389,11 +399,23 @@ SPECIALIZE_IMPL(BFloat16)
           .TypeConstraint("T2", CastOpTypeConstraints()), \
       Cast<T>);
 
-#define REGISTER_KERNEL_TYPED_23(T, OutputTypeConstraints) \
+#define REGISTER_KERNEL_TYPED_23_TO_24(T, OutputTypeConstraints) \
+  ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \
+      Cast, \
+      kOnnxDomain, \
+      23, 24, \
+      T, \
+      kCudaExecutionProvider, \
+      (*KernelDefBuilder::Create()) \
+          .TypeConstraint("T1", DataTypeImpl::GetTensorType<T>()) \
+          .TypeConstraint("T2", OutputTypeConstraints), \
+      Cast<T>);
+
+#define REGISTER_KERNEL_TYPED_25(T, OutputTypeConstraints) \
   ONNX_OPERATOR_TYPED_KERNEL_EX( \
       Cast, \
       kOnnxDomain, \
-      23, \
+      25, \
       T, \
       kCudaExecutionProvider, \
       (*KernelDefBuilder::Create()) \
@@ -403,18 +425,20 @@ SPECIALIZE_IMPL(BFloat16)
 
 #if !defined(DISABLE_FLOAT8_TYPES)
 
-#define SPECIALIZE_IMPL_19_TO_23(T) \
-  REGISTER_KERNEL_TYPED_19_TO_22(T) \
-  REGISTER_KERNEL_TYPED_23(T, CastOpTypeConstraints()) \
+#define SPECIALIZE_IMPL_19_TO_25(T) \
+  REGISTER_KERNEL_TYPED_19_TO_22(T) \
+  REGISTER_KERNEL_TYPED_23_TO_24(T, CastOpTypeConstraints()) \
+  REGISTER_KERNEL_TYPED_25(T, CastOpTypeConstraints()) \
   template Status Cast<T>::ComputeInternal(OpKernelContext* context) const;
 
-SPECIALIZE_IMPL_19_TO_23(Float8E4M3FN)
-SPECIALIZE_IMPL_19_TO_23(Float8E5M2)
+SPECIALIZE_IMPL_19_TO_25(Float8E4M3FN)
+SPECIALIZE_IMPL_19_TO_25(Float8E5M2)
 
 #endif
 
 #if !defined(DISABLE_FLOAT4_TYPES)
 
-REGISTER_KERNEL_TYPED_23(Float4E2M1x2, {DataTypeImpl::GetTensorType()})
+REGISTER_KERNEL_TYPED_23_TO_24(Float4E2M1x2, {DataTypeImpl::GetTensorType()})
+REGISTER_KERNEL_TYPED_25(Float4E2M1x2, {DataTypeImpl::GetTensorType()})
 template Status Cast<Float4E2M1x2>::ComputeInternal(OpKernelContext* context) const;
 
 #endif
diff --git a/onnxruntime/core/providers/cuda/tensor/reshape.cc b/onnxruntime/core/providers/cuda/tensor/reshape.cc
index 36ee05e1e2b01..7bf3da4197ba9 100644
--- a/onnxruntime/core/providers/cuda/tensor/reshape.cc
+++ b/onnxruntime/core/providers/cuda/tensor/reshape.cc
@@ -89,7 +89,19 @@ std::unique_ptr FuncReshape(
 ONNX_OPERATOR_KERNEL_EX(
     Reshape,
     kOnnxDomain,
-    23,
+    25,
+    kCudaExecutionProvider,
+    (*KernelDefBuilder::Create())
+        .TypeConstraint("T", DataTypeImpl::AllFixedSizeTensorTypesIRv9())
+        .TypeConstraint("shape", DataTypeImpl::GetTensorType<int64_t>())
+        .Alias(0, 0)
+        .InputMemoryType(OrtMemTypeCPUInput, 1),
+    Reshape);
+
+ONNX_OPERATOR_VERSIONED_KERNEL_EX(
+    Reshape,
+    kOnnxDomain,
+    23, 24,
     kCudaExecutionProvider,
     (*KernelDefBuilder::Create())
         .TypeConstraint("T", DataTypeImpl::AllFixedSizeTensorTypesIRv9())
diff --git a/onnxruntime/core/providers/cuda/tensor/resize_impl.cu b/onnxruntime/core/providers/cuda/tensor/resize_impl.cu
index 6e0586e772334..d4559363ce68b 100644
--- a/onnxruntime/core/providers/cuda/tensor/resize_impl.cu
+++ b/onnxruntime/core/providers/cuda/tensor/resize_impl.cu
@@ -31,6 +31,12 @@ struct NearestPixel_ROUND_PREFER_FLOOR {
 
 struct NearestPixel_ROUND_PREFER_CEIL {
   __device__ __forceinline__ int operator()(float x_original, bool) const {
+    // for half way cases prefer ceil
+    // roundf rounds away from zero which is correct for positive .5 values
+    // but for negative .5 values (e.g., -0.5) it rounds to -1 instead of 0 (ceil)
+    if (x_original == static_cast<int>(x_original) - 0.5f) {
+      return static_cast<int>(_Ceil(x_original));
+    }
     return static_cast<int>(roundf(x_original));
   }
 };
diff --git a/onnxruntime/core/session/model_editor_api.h b/onnxruntime/core/session/model_editor_api.h
index be6da18de2a64..fdd574bb91f34 100644
--- a/onnxruntime/core/session/model_editor_api.h
+++ b/onnxruntime/core/session/model_editor_api.h
@@ -33,7 +33,8 @@ ORT_API_STATUS_IMPL(SetGraphInputs, _In_ OrtGraph* graph,
                     _In_reads_(inputs_len) _In_ OrtValueInfo** inputs, _In_ size_t inputs_len);
 ORT_API_STATUS_IMPL(SetGraphOutputs, _In_ OrtGraph* graph,
                     _In_reads_(outputs_len) _In_ OrtValueInfo** outputs, _In_ size_t outputs_len);
-ORT_API_STATUS_IMPL(AddInitializerToGraph, _In_ OrtGraph* graph, _In_ const char* name, _Inout_ OrtValue* tensor,
+ORT_API_STATUS_IMPL(AddInitializerToGraph, _In_ OrtGraph* graph, _In_ const char* name,
+                    _In_ const OrtValue* ort_value,
                     bool data_is_external);
 ORT_API_STATUS_IMPL(AddNodeToGraph, _In_ OrtGraph* graph, _Inout_ OrtNode* node);
 
diff --git a/onnxruntime/core/session/model_editor_c_api.cc b/onnxruntime/core/session/model_editor_c_api.cc
index 487d5c818f9bc..91cd66f2d2191 100644
--- a/onnxruntime/core/session/model_editor_c_api.cc
+++ b/onnxruntime/core/session/model_editor_c_api.cc
@@ -3,8 +3,11 @@
 
 #if !defined(ORT_MINIMAL_BUILD)
 
+#include
 #include
 
+#include "core/common/inlined_containers.h"
+
 #include "core/framework/error_code_helper.h"
 #include "core/framework/ort_value.h"
 #include "core/framework/onnxruntime_typeinfo.h"
@@ -105,6 +108,14 @@ ORT_API_STATUS_IMPL(OrtModelEditorAPI::CreateGraph, _Outptr_ OrtGraph** graph) {
 ORT_API_STATUS_IMPL(OrtModelEditorAPI::SetGraphInputs, _In_ OrtGraph* ort_graph,
                     _In_reads_(inputs_len) _In_ OrtValueInfo** inputs, _In_ size_t inputs_len) {
   API_IMPL_BEGIN
+  if (ort_graph == nullptr) {
+    return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "graph cannot be null");
+  }
+
+  if (inputs == nullptr && inputs_len != 0) {
+    return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "inputs cannot be null when inputs_len is non-zero");
+  }
+
   onnxruntime::ModelEditorGraph* graph = onnxruntime::ModelEditorGraph::ToInternal(ort_graph);
 
   if (graph == nullptr) {
@@ -112,7 +123,27 @@ ORT_API_STATUS_IMPL(OrtModelEditorAPI::SetGraphInputs, _In_ OrtGraph* ort_graph,
                                  "Invalid OrtGraph variant for use in the OrtModelEditorApi");
   }
 
+  // Check for duplicate pointers in the input array to prevent double-free
+  onnxruntime::InlinedHashSet seen;
+  for (size_t i = 0; i < inputs_len; ++i) {
+    if (inputs[i] == nullptr) {
+      continue;
+    }
+    if (!seen.insert(inputs[i]).second) {
+      return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT,
+                                   "Duplicate OrtValueInfo pointer found in inputs array. "
+                                   "Each OrtValueInfo can only appear once.");
+    }
+    onnxruntime::ModelEditorValueInfo* vi = onnxruntime::ModelEditorValueInfo::ToInternal(inputs[i]);
+    if (vi != nullptr && vi->owned_) {
+      return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT,
+                                   "This OrtValueInfo has already been added to a graph. "
+                                   "Each OrtValueInfo can only be added once.");
+    }
+  }
+
   graph->inputs.clear();
+  graph->inputs.reserve(inputs_len);
   for (size_t i = 0; i < inputs_len; ++i) {
     if (inputs[i] == nullptr) {
       return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "inputs cannot contain null entries");
@@ -125,6 +156,7 @@ ORT_API_STATUS_IMPL(OrtModelEditorAPI::SetGraphInputs, _In_ OrtGraph* ort_graph,
     }
 
     graph->inputs.push_back(std::unique_ptr(input));  // take ownership
+    input->owned_ = true;
     inputs[i] = nullptr;
   }
 
@@ -135,6 +167,14 @@ ORT_API_STATUS_IMPL(OrtModelEditorAPI::SetGraphInputs, _In_ OrtGraph* ort_graph,
 ORT_API_STATUS_IMPL(OrtModelEditorAPI::SetGraphOutputs, _In_ OrtGraph* ort_graph,
                     _In_reads_(outputs_len) _In_ OrtValueInfo** outputs, _In_ size_t outputs_len) {
   API_IMPL_BEGIN
+  if (ort_graph == nullptr) {
+    return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "graph cannot be null");
+  }
+
+  if (outputs == nullptr && outputs_len != 0) {
+    return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "outputs cannot be null when outputs_len is non-zero");
+  }
+
   onnxruntime::ModelEditorGraph* graph = onnxruntime::ModelEditorGraph::ToInternal(ort_graph);
 
   if (graph == nullptr) {
@@ -142,7 +182,27 @@ ORT_API_STATUS_IMPL(OrtModelEditorAPI::SetGraphOutputs, _In_ OrtGraph* ort_graph
                                  "Invalid OrtGraph variant for use in the OrtModelEditorApi");
   }
 
+  // Check for duplicate pointers in the output array to prevent double-free
+  onnxruntime::InlinedHashSet seen;
+  for (size_t i = 0; i < outputs_len; ++i) {
+    if (outputs[i] == nullptr) {
+      continue;
+    }
+    if (!seen.insert(outputs[i]).second) {
+      return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT,
+                                   "Duplicate OrtValueInfo pointer found in outputs array. "
+                                   "Each OrtValueInfo can only appear once.");
+    }
+    onnxruntime::ModelEditorValueInfo* vi = onnxruntime::ModelEditorValueInfo::ToInternal(outputs[i]);
+    if (vi != nullptr && vi->owned_) {
+      return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT,
+                                   "This OrtValueInfo has already been added to a graph. "
+                                   "Each OrtValueInfo can only be added once.");
+    }
+  }
+
   graph->outputs.clear();
+  graph->outputs.reserve(outputs_len);
   for (size_t i = 0; i < outputs_len; ++i) {
     if (outputs[i] == nullptr) {
       return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "outputs cannot contain null entries");
@@ -155,6 +215,7 @@ ORT_API_STATUS_IMPL(OrtModelEditorAPI::SetGraphOutputs, _In_ OrtGraph* ort_graph
     }
 
     graph->outputs.push_back(std::unique_ptr(output));  // take ownership
+    output->owned_ = true;
     outputs[i] = nullptr;
   }
 
@@ -163,8 +224,20 @@ ORT_API_STATUS_IMPL(OrtModelEditorAPI::SetGraphOutputs, _In_ OrtGraph* ort_graph
 }
 
 ORT_API_STATUS_IMPL(OrtModelEditorAPI::AddInitializerToGraph, _In_ OrtGraph* ort_graph, _In_ const char* name,
-                    _Inout_ OrtValue* tensor, bool data_is_external) {
+                    _In_ const OrtValue* ort_value, bool data_is_external) {
   API_IMPL_BEGIN
+  if (name == nullptr || *name == '\0') {
+    return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "name cannot be null or empty string");
+  }
+
+  if (ort_value == nullptr) {
+    return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "ort_value cannot be null");
+  }
+
+  if (ort_graph == nullptr) {
+    return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "graph cannot be null");
+  }
+
   onnxruntime::ModelEditorGraph* graph = onnxruntime::ModelEditorGraph::ToInternal(ort_graph);
 
   if (graph == nullptr) {
@@ -172,19 +245,25 @@ ORT_API_STATUS_IMPL(OrtModelEditorAPI::AddInitializerToGraph, _In_ OrtGraph* ort
                                  "Invalid OrtGraph variant for use in the OrtModelEditorApi");
   }
 
-  if (!tensor->IsTensor()) {
+  if (!ort_value->IsTensor()) {
     return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "Only Tensor is currently supported.");
   }
 
-  if (!tensor->IsAllocated()) {
+  if (!ort_value->IsAllocated()) {
     return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "Tensor must be allocated.");
   }
 
-  const auto& t = tensor->Get();
+  const auto& t = ort_value->Get();
   if (t.Location().device.Type() != OrtDevice::CPU) {
     return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "Only CPU based tensors are currently supported.");
   }
 
+  // Reject duplicate name in either map
+  if (graph->initializers.count(name) != 0 || graph->external_initializers.count(name) != 0) {
+    return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT,
+                                 "An initializer with this name has already been added to the graph.");
+  }
+
   if (data_is_external) {
     // enforce that an external initializer is not used if the data size is < 128 bytes.
     // the reason for this is to avoid potential shape inferencing errors if this initializer is providing an
@@ -195,18 +274,26 @@ ORT_API_STATUS_IMPL(OrtModelEditorAPI::AddInitializerToGraph, _In_ OrtGraph* ort
                                    "External initializer should only be used for data >= 128 bytes. "
                                    "Please use CreateTensorAsOrtValue instead.");
     }
-
-    graph->external_initializers[name] = std::unique_ptr(tensor);  // take ownership
-  } else {
-    graph->initializers[name] = std::unique_ptr(tensor);  // take ownership
   }
 
+  auto& m = data_is_external ? graph->external_initializers : graph->initializers;
+  auto [it, inserted] = m.emplace(name, *ort_value);
+  ORT_ENFORCE(inserted, "Unexpected duplicate name after validation. This is a bug.");
+
   return nullptr;
   API_IMPL_END
 }
 
 ORT_API_STATUS_IMPL(OrtModelEditorAPI::AddNodeToGraph, _In_ OrtGraph* ort_graph, _Inout_ OrtNode* ort_node) {
   API_IMPL_BEGIN
+  if (ort_node == nullptr) {
+    return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "node cannot be null");
+  }
+
+  if (ort_graph == nullptr) {
+    return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "graph cannot be null");
+  }
+
   onnxruntime::ModelEditorGraph* graph = onnxruntime::ModelEditorGraph::ToInternal(ort_graph);
 
   if (graph == nullptr) {
@@ -221,8 +308,19 @@ ORT_API_STATUS_IMPL(OrtModelEditorAPI::AddNodeToGraph, _In_ OrtGraph* ort_graph,
                                  "Invalid OrtNode variant for use in the OrtModelEditorApi");
   }
 
+  // Reject if this node has already been added to a graph (prevents double-free)
+  if (node->owned_) {
+    return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT,
+                                 "This node has already been added to a graph. "
+                                 "Each OrtNode can only be added once.");
+  }
+
   node->id = graph->nodes.size();
-  graph->nodes.push_back(std::unique_ptr(node));  // take ownership
+  if (graph->nodes.size() == graph->nodes.capacity()) {
+    graph->nodes.reserve(std::max(graph->nodes.capacity() * 2, size_t{1}));
+  }
+  graph->nodes.emplace_back(node);
+  node->owned_ = true;
   return nullptr;
   API_IMPL_END
 }
@@ -246,11 +344,36 @@ ORT_API_STATUS_IMPL(OrtModelEditorAPI::CreateModel,
 
 ORT_API_STATUS_IMPL(OrtModelEditorAPI::AddGraphToModel, _In_ OrtModel* model, _Inout_ OrtGraph* graph) {
   API_IMPL_BEGIN
+  if (model == nullptr) {
+    return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "model cannot be null");
+  }
+
   if (graph == nullptr) {
     return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "graph cannot be null");
   }
 
+  // Reject if model already has a graph (prevents double-free/UAF)
+  if (model->graph != nullptr) {
+    return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT,
+                                 "Model already has a graph. Each OrtModel can only have one graph.");
+  }
+
+  // Reject if this graph has already been added to a model (prevents double-free across models)
+  onnxruntime::ModelEditorGraph* me_graph = onnxruntime::ModelEditorGraph::ToInternal(graph);
+  if (me_graph == nullptr) {
+    return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT,
+                                 "Invalid OrtGraph variant for use in the OrtModelEditorApi");
+  }
+
+  if (me_graph->owned_) {
+    return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT,
+                                 "This graph has already been added to a model. "
+                                 "Each OrtGraph can only be added once.");
+  }
+
   model->graph = std::unique_ptr(graph);  // take ownership
+  me_graph->owned_ = true;
+
   return nullptr;
   API_IMPL_END
 }
diff --git a/onnxruntime/core/session/onnxruntime_c_api.cc b/onnxruntime/core/session/onnxruntime_c_api.cc
index 3f28529e7a847..2ac95d6e36466 100644
--- a/onnxruntime/core/session/onnxruntime_c_api.cc
+++ b/onnxruntime/core/session/onnxruntime_c_api.cc
@@ -2752,14 +2752,35 @@ ORT_API_STATUS_IMPL(OrtApis::SessionOptionsSetCustomJoinThreadFn, _Inout_ OrtSes
 }
 
 ORT_API(void, OrtApis::ReleaseValueInfo, _Frees_ptr_opt_ OrtValueInfo* value_info) {
+  if (value_info != nullptr) {
+    if (auto* me = onnxruntime::ModelEditorValueInfo::ToInternal(value_info);
+        me != nullptr && me->owned_) {
+      assert(false && "Releasing an OrtValueInfo that is owned by a graph");
+      return;
+    }
+  }
   delete value_info;
 }
 
 ORT_API(void, OrtApis::ReleaseNode, _Frees_ptr_opt_ OrtNode* node) {
+  if (node != nullptr) {
+    if (auto* me = onnxruntime::ModelEditorNode::ToInternal(node);
+        me != nullptr && me->owned_) {
+      assert(false && "Releasing an OrtNode that is owned by a graph");
+      return;
+    }
+  }
   delete node;
 }
 
 ORT_API(void, OrtApis::ReleaseGraph, _Frees_ptr_opt_ OrtGraph* graph) {
+  if (graph != nullptr) {
+    if (auto* me = onnxruntime::ModelEditorGraph::ToInternal(graph);
+        me != nullptr && me->owned_) {
+      assert(false && "Releasing an OrtGraph that is owned by a model");
+      return;
+    }
+  }
   delete graph;
 }
 
@@ -4387,7 +4408,7 @@ Second example, if we wanted to add and remove some members, we'd do this:
     In GetApi we now make it return ort_api_3 for version 3.
 */
 
-static constexpr OrtApi ort_api_1_to_26 = {
+static constexpr OrtApi ort_api_1_to_27 = {
     // NOTE: The ordering of these fields MUST not change after that version has shipped since existing binaries depend on this ordering.
 
     // Shipped as version 1 - DO NOT MODIFY (see above text for more information)
@@ -4894,6 +4915,7 @@ static constexpr OrtApi ort_api_1_to_26 = {
     &OrtApis::KernelInfoGetAttributeArray_string,
     &OrtApis::SetPerSessionThreadPoolCallbacks,
     // End of Version 25 - DO NOT MODIFY ABOVE (see above text for more information)
+    // End of Version 26 - DO NOT MODIFY ABOVE (see above text for more information)
 };
 
 // OrtApiBase can never change as there is no way to know what version of OrtApiBase is returned by OrtGetApiBase.
@@ -4932,13 +4954,14 @@ static_assert(offsetof(OrtApi, GetEpApi) / sizeof(void*) == 317, "Size of versio
 static_assert(offsetof(OrtApi, CreateExternalInitializerInfo) / sizeof(void*) == 389, "Size of version 23 API cannot change");
 static_assert(offsetof(OrtApi, GetTensorElementTypeAndShapeDataReference) / sizeof(void*) == 414, "Size of version 24 API cannot change");
 static_assert(offsetof(OrtApi, KernelInfoGetAttributeArray_string) / sizeof(void*) == 417, "Size of version 25 API cannot change");
+// no additions in version 26
 
 // So that nobody forgets to finish an API version, this check will serve as a reminder:
-static_assert(std::string_view(ORT_VERSION) == "1.26.0",
+static_assert(std::string_view(ORT_VERSION) == "1.27.0",
               "ORT_Version change detected, please follow below steps to ensure OrtApi is updated properly");
 // 1. Update the hardcoded version string in above static_assert to silence it
 //
-// 2. If there were any APIs added to ort_api_1_to_26 above:
+// 2. If there were any APIs added to ort_api_1_to_X above:
 //    a. Add the 'End of version #' markers (pattern above should be obvious)
 //    b. Add a static_assert in the directly above list of version sizes to ensure nobody adds any more functions to the just shipped API version
 //
@@ -4950,7 +4973,7 @@ static_assert(std::string_view(ORT_VERSION) == "1.26.0",
 
 ORT_API(const OrtApi*, OrtApis::GetApi, uint32_t version) {
   if (version >= 1 && version <= ORT_API_VERSION)
-    return &ort_api_1_to_26;
+    return &ort_api_1_to_27;
 
   fprintf(stderr,
           "The requested API version [%u] is not available, only API versions [1, %u] are supported in this build."
diff --git a/onnxruntime/python/tools/quantization/shape_inference.py b/onnxruntime/python/tools/quantization/shape_inference.py
index cc3bc2ef28c4f..0a1ba0462f9bf 100644
--- a/onnxruntime/python/tools/quantization/shape_inference.py
+++ b/onnxruntime/python/tools/quantization/shape_inference.py
@@ -13,7 +13,6 @@
 import onnx
 
 import onnxruntime
-from onnxruntime.tools.symbolic_shape_infer import SymbolicShapeInference
 from onnxruntime.transformers.onnx_utils import extract_raw_data_from_model, has_external_data
 
 from .fusions import ReplaceUpsampleWithResize
@@ -88,6 +87,13 @@ def quant_pre_process(
             model = save_and_reload_model_with_shape_infer(model)
 
         if not skip_symbolic_shape:
+            try:
+                from onnxruntime.tools.symbolic_shape_infer import SymbolicShapeInference  # noqa: PLC0415
+            except ImportError as e:
+                raise ImportError(
+                    "sympy is required for symbolic shape inference in quantization preprocessing. "
+                    "Install with: 'pip install sympy' or pass skip_symbolic_shape=True to quant_pre_process()."
+                ) from e
             logger.info("Performing symbolic shape inference...")
             model = SymbolicShapeInference.infer_shapes(
                 model,
diff --git a/onnxruntime/python/tools/transformers/onnx_model.py b/onnxruntime/python/tools/transformers/onnx_model.py
index a00cddf18870e..25bd35e479bd2 100644
--- a/onnxruntime/python/tools/transformers/onnx_model.py
+++ b/onnxruntime/python/tools/transformers/onnx_model.py
@@ -2,6 +2,7 @@
 # Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. # -------------------------------------------------------------------------- +from __future__ import annotations import itertools import logging @@ -9,6 +10,10 @@ import sys from collections import deque from pathlib import Path +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from shape_infer_helper import SymbolicShapeInferenceHelper from float16 import convert_float_to_float16 from onnx import ( @@ -23,7 +28,6 @@ save_model, ) from onnx.external_data_helper import load_external_data_for_tensor, uses_external_data -from shape_infer_helper import SymbolicShapeInferenceHelper logger = logging.getLogger(__name__) @@ -51,6 +55,8 @@ def disable_shape_inference(self): def infer_runtime_shape(self, dynamic_axis_mapping={}, update=False): # noqa: B006 if self.enable_shape_infer: if self.shape_infer_helper is None or update: + from shape_infer_helper import SymbolicShapeInferenceHelper # noqa: PLC0415 + self.shape_infer_helper = SymbolicShapeInferenceHelper(self.model) try: @@ -764,6 +770,8 @@ def convert_float_to_float16(self, use_symbolic_shape_infer=True, **kwargs): if use_symbolic_shape_infer: # Use symbolic shape inference since custom operators (like Gelu, SkipLayerNormalization etc) # are not recognized by onnx shape inference. 
+ from shape_infer_helper import SymbolicShapeInferenceHelper # noqa: PLC0415 + shape_infer_helper = SymbolicShapeInferenceHelper(model) try: model_with_shape = shape_infer_helper.infer_shapes(model, auto_merge=True, guess_output_rank=False) diff --git a/onnxruntime/python/tools/transformers/shape_infer_helper.py b/onnxruntime/python/tools/transformers/shape_infer_helper.py index f4d65d05ad0c8..5651c3cddba72 100644 --- a/onnxruntime/python/tools/transformers/shape_infer_helper.py +++ b/onnxruntime/python/tools/transformers/shape_infer_helper.py @@ -14,13 +14,31 @@ else: sys.path.append(os.path.join(file_path, "..")) -from symbolic_shape_infer import SymbolicShapeInference, get_shape_from_type_proto, sympy # noqa: E402 +try: + from symbolic_shape_infer import SymbolicShapeInference, get_shape_from_type_proto, sympy + + _symbolic_shape_infer_available = True + _symbolic_shape_infer_import_error: ImportError | None = None +except ImportError as exc: + SymbolicShapeInference = object # type: ignore[assignment,misc] + get_shape_from_type_proto = None # type: ignore[assignment] + sympy = None # type: ignore[assignment] + _symbolic_shape_infer_available = False + _symbolic_shape_infer_import_error = exc logger = logging.getLogger(__name__) class SymbolicShapeInferenceHelper(SymbolicShapeInference): def __init__(self, model, verbose=0, int_max=2**31 - 1, auto_merge=True, guess_output_rank=False): + if not _symbolic_shape_infer_available: + err = _symbolic_shape_infer_import_error + cause = ( + "missing 'sympy' (install with: pip install sympy)" + if err is not None and "sympy" in str(err) + else f"failed to import symbolic_shape_infer: {err!r}" + ) + raise ImportError(f"SymbolicShapeInferenceHelper is unavailable — {cause}") from err super().__init__(int_max, auto_merge, guess_output_rank, verbose) self.model_ = model self.all_shapes_inferred_: bool = False diff --git a/onnxruntime/test/optimizer/graph_transform_test.cc b/onnxruntime/test/optimizer/graph_transform_test.cc 
index 950355742193c..0779bd4d4ec09 100644 --- a/onnxruntime/test/optimizer/graph_transform_test.cc +++ b/onnxruntime/test/optimizer/graph_transform_test.cc @@ -4713,6 +4713,81 @@ TEST_F(GraphTransformationTests, ReshapeFusionConcatSubgraph) { } } +// Regression test: FuseContiguousReshapes must not collapse a chain of Reshapes +// when the inferred output shape contains a literal 0 dim. Doing so would create +// a single Reshape whose shape data contains 0 and (because allowzero defaults +// to 0) be misinterpreted as "copy from input dim", silently producing wrong shape. +// See https://github.com/microsoft/onnxruntime/issues/28348. +TEST_F(GraphTransformationTests, ReshapeFusionContiguousReshapesWithZeroDim) { + std::unordered_map domain_to_version; + domain_to_version[kOnnxDomain] = 21; + Model model("ReshapeFusionContiguousReshapesWithZeroDim", false, ModelMetaData(), + PathString(), IOnnxRuntimeOpSchemaRegistryList(), domain_to_version, + std::vector(), *logger_); + auto& graph = model.MainGraph(); + + // X: float[0, 6, 2] (zero-sized first dim, fully concrete) + TypeProto x_type; + x_type.mutable_tensor_type()->set_elem_type(TensorProto_DataType_FLOAT); + x_type.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(0); + x_type.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(6); + x_type.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(2); + + TypeProto y_type; + y_type.mutable_tensor_type()->set_elem_type(TensorProto_DataType_FLOAT); + + auto& X = graph.GetOrCreateNodeArg("X", &x_type); + auto& mid = graph.GetOrCreateNodeArg("mid", &y_type); + auto& Y = graph.GetOrCreateNodeArg("Y", &y_type); + + // shape1 = [3, 2, -1] -> mid shape (3, 2, 0) + ONNX_NAMESPACE::TensorProto shape1_proto; + shape1_proto.set_name("shape1"); + shape1_proto.set_data_type(TensorProto_DataType_INT64); + shape1_proto.add_dims(3); + for (int64_t v : {3, 2, -1}) shape1_proto.add_int64_data(v); + graph.AddInitializedTensor(shape1_proto); + + // 
shape2 = [0, 0, 3] with allowzero=1 -> Y shape (0, 0, 3) + ONNX_NAMESPACE::TensorProto shape2_proto; + shape2_proto.set_name("shape2"); + shape2_proto.set_data_type(TensorProto_DataType_INT64); + shape2_proto.add_dims(3); + for (int64_t v : {0, 0, 3}) shape2_proto.add_int64_data(v); + graph.AddInitializedTensor(shape2_proto); + + auto& shape1 = graph.GetOrCreateNodeArg("shape1", nullptr); + auto& shape2 = graph.GetOrCreateNodeArg("shape2", nullptr); + + graph.AddNode("reshape1", "Reshape", "first reshape", {&X, &shape1}, {&mid}); + auto& reshape2 = graph.AddNode("reshape2", "Reshape", "second reshape (allowzero=1)", + {&mid, &shape2}, {&Y}); + reshape2.AddAttribute("allowzero", static_cast(1)); + + ASSERT_STATUS_OK(graph.Resolve()); + + std::map op_to_count_before = CountOpsInGraph(graph); + ASSERT_EQ(op_to_count_before["Reshape"], 2); + + onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; + ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique(), + TransformerLevel::Level1)); + ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_)); + + // Fusion must NOT collapse the two reshapes, otherwise the resulting single + // Reshape would (mis)compute output shape (0, 6, 3) instead of (0, 0, 3). + std::map op_to_count = CountOpsInGraph(graph); + ASSERT_EQ(op_to_count["Reshape"], 2); + + // Y's inferred shape must remain (0, 0, 3). 
+ const auto* y_shape = graph.GetNodeArg("Y")->Shape(); + ASSERT_NE(y_shape, nullptr); + ASSERT_EQ(y_shape->dim_size(), 3); + EXPECT_EQ(y_shape->dim(0).dim_value(), 0); + EXPECT_EQ(y_shape->dim(1).dim_value(), 0); + EXPECT_EQ(y_shape->dim(2).dim_value(), 3); +} + TEST_F(GraphTransformationTests, ReshapeFusionWithSlice1) { constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "fusion/reshape_fusion_with_slice1.onnx"; std::shared_ptr p_model; diff --git a/onnxruntime/test/providers/cpu/llm/attention_op_test.cc b/onnxruntime/test/providers/cpu/llm/attention_op_test.cc index 0cf95141b7a6c..40c45db2dfd66 100644 --- a/onnxruntime/test/providers/cpu/llm/attention_op_test.cc +++ b/onnxruntime/test/providers/cpu/llm/attention_op_test.cc @@ -8,6 +8,8 @@ #include "test/common/tensor_op_test_utils.h" #include "test/common/cuda_op_test_utils.h" #include "test/providers/provider_test_utils.h" +#include "test/util/include/scoped_env_vars.h" +#include "contrib_ops/cpu/bert/attention_common.h" namespace onnxruntime { namespace test { @@ -91,8 +93,12 @@ static void AddInputs(OpTester& test, test.AddOutput("Y", y_shape, y, false, 0, 3e-5f); if (!present_key.empty()) test.AddOutput("present_key", present_key_shape, present_key); + else if (!qk_matmul_output.empty()) + test.AddOptionalOutputEdge(); // present_key placeholder if (!present_value.empty()) test.AddOutput("present_value", present_value_shape, present_value); + else if (!qk_matmul_output.empty()) + test.AddOptionalOutputEdge(); // present_value placeholder if (!qk_matmul_output.empty()) test.AddOutput("qk_matmul_output", qk_matmul_output_shape, qk_matmul_output); } else if (tensor_type == TensorType::kFloat16) { @@ -120,8 +126,12 @@ static void AddInputs(OpTester& test, test.AddOutput("Y", y_shape, ToFloat16(y), false, 0, 3e-3f); if (!present_key.empty()) test.AddOutput("present_key", present_key_shape, ToFloat16(present_key)); + else if (!qk_matmul_output.empty()) + test.AddOptionalOutputEdge(); // present_key placeholder if 
(!present_value.empty()) test.AddOutput("present_value", present_value_shape, ToFloat16(present_value)); + else if (!qk_matmul_output.empty()) + test.AddOptionalOutputEdge(); // present_value placeholder if (!qk_matmul_output.empty()) test.AddOutput("qk_matmul_output", qk_matmul_output_shape, ToFloat16(qk_matmul_output)); } else { @@ -149,8 +159,12 @@ static void AddInputs(OpTester& test, test.AddOutput("Y", y_shape, FloatsToBFloat16s(y), false, 0, 3e-3f); if (!present_key.empty()) test.AddOutput("present_key", present_key_shape, FloatsToBFloat16s(present_key)); + else if (!qk_matmul_output.empty()) + test.AddOptionalOutputEdge(); // present_key placeholder if (!present_value.empty()) test.AddOutput("present_value", present_value_shape, FloatsToBFloat16s(present_value)); + else if (!qk_matmul_output.empty()) + test.AddOptionalOutputEdge(); // present_value placeholder if (!qk_matmul_output.empty()) test.AddOutput("qk_matmul_output", qk_matmul_output_shape, FloatsToBFloat16s(qk_matmul_output)); } @@ -516,11 +530,10 @@ TEST(AttentionTest, Attention4DAttnMaskBoolAllFalse) { // Regression guard: all-false bool mask in decode mode (past_sequence_length > 0). // Guards against a bug where fully-masked batches produce NaN or incorrect output. -// Expected behavior: uniform softmax over past KV values produces Y = mean-of-V. -// With past_v = [10,20,30,40] and [20,40,60,80] per head, and all positions masked out, -// softmax(all -inf + constant mask_filter_value) → uniform weights → Y = {25, 50}. -// This test originally came from upstream/main and validates that both CPU and CUDA -// (unfused path) handle the all-false mask case identically. +// Expected behavior: uniform softmax over all KV values produces Y = mean-of-V. +// On CUDA, MEA decode handles this config (total_seq=4, 4-aligned). The capped +// mask_filter_value (-1e+30) in ConvertAttnMaskToBias prevents CUTLASS overflow, +// producing correct uniform softmax → mean(V). 
TEST(AttentionTest, Attention4DAttnMaskBoolAllFalseDecodeWithPast) { int batch_size = 1; int q_num_heads = 2; @@ -609,8 +622,9 @@ TEST(AttentionTest, Attention4DAttnMaskBoolAllFalseDecodeWithPast) { ); } -// Unfused decode path with fp16 and all-true bool attention mask. -// Flash rejects attn_mask (requires attn_mask==nullptr), so CUDA routes to unfused. +// Decode path with fp16 and all-true bool attention mask. +// Flash rejects attn_mask (requires attn_mask==nullptr). MEA handles decode with +// bool mask via additive bias (past_key concat + ConvertAttnMaskToBias). // head_size=64. Uniform keys make output analytically verifiable: // all attention scores are equal, so softmax is uniform over all positions. TEST(AttentionTest, Attention4DAttnMaskBoolDecodeWithPastFloat16) { @@ -695,8 +709,8 @@ TEST(AttentionTest, Attention4DAttnMaskBoolDecodeWithPastFloat16) { // Decode with partial bool mask [T,T,T,F]: the new token is masked out. // With mask [T,T,T,F] past_seq=3 total=4: only positions 0,1,2 are attended (past only). -// Flash is ineligible (bool+past_key rejected), so CUDA uses unfused which handles this -// spec-correctly via standard ConcatPastToPresent + element-wise mask application. +// Flash is ineligible (bool+past_key rejected). MEA handles decode with bool mask +// via additive bias (past_key concat + ConvertAttnMaskToBias). // Y = uniform mean over the 3 attended past values (Q=K=constant → uniform softmax). // CPU always runs; CUDA runs when SM 5.3+ is available. TEST(AttentionTest, Attention4DAttnMaskBoolPartialMaskDecodeFloat16) { @@ -781,7 +795,8 @@ TEST(AttentionTest, Attention4DAttnMaskBoolPartialMaskDecodeFloat16) { // Multi-batch decode with per-batch partial bool masks. // batch_size=2: batch 0 [T,T,T,F,F,F] (3 leading trues), batch 1 [T,T,T,T,T,T] (all true). -// Flash is ineligible (bool+past_key rejected), CUDA uses unfused. +// Flash is ineligible (bool+past_key rejected). 
MEA rejected by CUTLASS bias alignment +// (total_seq=6, 6%4≠0), so CUDA falls through to unfused. // Unfused applies standard ConcatPastToPresent (new token at position past_sequence_length=5 // for all batches) and element-wise mask in softmax. // Runs on both CPU and CUDA to verify cross-EP consistency. @@ -988,9 +1003,8 @@ TEST(AttentionTest, Attention4DSoftCap) { q, k, v, std::vector(), std::initializer_list(), std::vector(), std::vector(), -1, -1, std::numeric_limits::quiet_NaN(), 2.0f, -1, TensorType::kFloat, // is_causal, qk_matmul_output_mode, scale, softcap, softmax_precision, tensor_type ys, std::vector(), std::vector(), std::vector(), - // disable_cuda: head_size(8) != v_head_size(10) blocks Flash, past_key blocks MEA, - // unfused path doesn't support softcap. Needs test with head_size == v_head_size and no past. - false, true, true // disable_cpu, disable_cuda, disable_dml + // head_size(8) != v_head_size(10) blocks Flash and MEA decode; falls to unfused which now supports softcap. + false, false, true // disable_cpu, disable_cuda, disable_dml ); } @@ -1018,9 +1032,8 @@ TEST(AttentionTest, Attention4DSoftCapFloat16) { q, k, v, std::vector(), std::initializer_list(), std::vector(), std::vector(), -1, -1, std::numeric_limits::quiet_NaN(), 2.0f, -1, TensorType::kFloat16, // is_causal, qk_matmul_output_mode, scale, softcap, softmax_precision, tensor_type ys, std::vector(), std::vector(), std::vector(), - // disable_cuda: head_size(8) != v_head_size(10) blocks Flash, past_key blocks MEA, - // unfused path doesn't support softcap. Needs test with head_size == v_head_size and no past. - false, true, true // disable_cpu, disable_cuda, disable_dml + // head_size(8) != v_head_size(10) blocks Flash and MEA decode; falls to unfused which now supports softcap. 
+ false, false, true // disable_cpu, disable_cuda, disable_dml ); } @@ -1160,7 +1173,6 @@ TEST(AttentionTest, Attention4DAttnPastPresent) { ); } -// TODO(titaiwang, xadupre): Do we really need cross attention + causal mask test case? TEST(AttentionTest, Attention4DAttnIsCausal) { int batch_size = 2; // Q.shape[0] int q_num_heads = 3; // Q.shape[1] @@ -1250,7 +1262,6 @@ TEST(AttentionTest, Attention4DAttnIsCausalBasicFloat16) { ); } -// TODO(titaiwang, xadupre): Do we really need cross attention + causal mask test case? TEST(AttentionTest, Attention4DAttnIsCausalBasicDifferentSequenceLength) { int batch_size = 2; // Q.shape[0] int q_num_heads = 1; // Q.shape[1] @@ -2308,10 +2319,10 @@ TEST(AttentionTest, Attention_NonPadKVSeqLen_WithFloatAttnMask_MultiBatch) { test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); } -// GQA unfused attention with FP32 QK accumulation for large head_size (> 128). -// This exercises the RunGqaUnfusedAttention path in attention.cc which uses +// Unfused attention with FP32 QK accumulation for large head_size (> 128). +// This exercises the RunUnfusedAttention path in attention.cc which uses // an FP32 scratch buffer for QK matmul to prevent overflow in fp16. -TEST(AttentionTest, Attention_GqaUnfused_LargeHeadSize_FP16) { +TEST(AttentionTest, Attention_Unfused_LargeHeadSize_FP16) { if (!HasCudaEnvironment(530)) { return; // fp16 requires SM 5.3+ } @@ -2371,9 +2382,9 @@ TEST(AttentionTest, Attention_GqaUnfused_LargeHeadSize_FP16) { test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); } -// GQA unfused attention with causal mask and large head_size. -// Verifies that is_causal works correctly in the unfused GQA path. -TEST(AttentionTest, Attention_GqaUnfused_LargeHeadSize_Causal_FP16) { +// Unfused attention with causal mask and large head_size. +// Verifies that is_causal works correctly in the unfused path. 
+TEST(AttentionTest, Attention_Unfused_LargeHeadSize_Causal_FP16) { if (!HasCudaEnvironment(530)) { return; // fp16 requires SM 5.3+ } @@ -2440,8 +2451,8 @@ TEST(AttentionTest, Attention_GqaUnfused_LargeHeadSize_Causal_FP16) { test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); } -// GQA unfused with past_key + attn_mask: exercises concat + bias path together. -TEST(AttentionTest, Attention_GqaUnfused_PastKey_AttnMask_FP16) { +// Unfused with past_key + attn_mask: exercises concat + bias path together. +TEST(AttentionTest, Attention_Unfused_PastKey_AttnMask_FP16) { if (!HasCudaEnvironment(530)) { return; } @@ -2519,8 +2530,8 @@ TEST(AttentionTest, Attention_GqaUnfused_PastKey_AttnMask_FP16) { test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); } -// GQA unfused with softcap + attn_mask: verifies the softcap + bias interaction. -TEST(AttentionTest, Attention_GqaUnfused_Softcap_AttnMask_FP16) { +// Unfused with softcap + attn_mask: verifies the softcap + bias interaction. +TEST(AttentionTest, Attention_Unfused_Softcap_AttnMask_FP16) { if (!HasCudaEnvironment(530)) { return; } @@ -2572,8 +2583,8 @@ TEST(AttentionTest, Attention_GqaUnfused_Softcap_AttnMask_FP16) { test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); } -// GQA unfused with BSNH (3D) input: previous tests all use 4D BNSH input. -TEST(AttentionTest, Attention_GqaUnfused_BSNH_FP16) { +// Unfused with BSNH (3D) input: previous tests all use 4D BNSH input. +TEST(AttentionTest, Attention_Unfused_BSNH_FP16) { if (!HasCudaEnvironment(530)) { return; } @@ -2622,8 +2633,8 @@ TEST(AttentionTest, Attention_GqaUnfused_BSNH_FP16) { test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); } -// GQA unfused with fp32: exercises the float template instantiation. -TEST(AttentionTest, Attention_GqaUnfused_FP32) { +// Unfused with fp32: exercises the float template instantiation. 
+TEST(AttentionTest, Attention_Unfused_FP32) { if (!HasCudaEnvironment(0)) { return; } @@ -2673,5 +2684,296 @@ TEST(AttentionTest, Attention_GqaUnfused_FP32) { test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); } +// Test MEA decode path by disabling Flash Attention. +// Uses the same Attention4DDefaultBasic data (head_size == v_head_size, fp16 with past_key) +// but forces MEA runner via environment variable. +TEST(AttentionTest, Attention4DMEADecodeFloat16) { + int batch_size = 2; + int q_num_heads = 3; + int q_sequence_length = 4; + int head_size = 8; + int kv_sequence_length = 6; + int kv_num_heads = 3; + int v_head_size = 8; + int past_sequence_length = 5; + + // Simple test data: one-hot Q/K/V to make expected output predictable + size_t q_size = batch_size * q_num_heads * q_sequence_length * head_size; + size_t k_size = batch_size * kv_num_heads * kv_sequence_length * head_size; + size_t v_size = batch_size * kv_num_heads * kv_sequence_length * v_head_size; + + std::vector q(q_size, 0.0f); + q[0] = 1.0f; // first element of first query is 1 + std::vector k(k_size, 0.0f); + k[0] = 1.0f; // first element of first key is 1 + std::vector v(v_size, 0.0f); + v[0] = 1.0f; // first element of first value is 1 + + // Expected output matches Attention4DDefaultBasic (same data, same math regardless of runner) + std::vector y = {0.221683f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.166667f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.166667f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.166667f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 
0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f}; + + // Force MEA by disabling Flash Attention + ScopedEnvironmentVariables scoped_env_vars{ + EnvVarMap{{onnxruntime::contrib::attention::kDisableFlashAttention, "1"}}}; + + RunTest4D(batch_size, q_num_heads, q_sequence_length, head_size, kv_sequence_length, kv_num_heads, v_head_size, past_sequence_length, + q, k, v, std::vector(), std::initializer_list(), std::vector(), std::vector(), + -1, -1, std::numeric_limits::quiet_NaN(), std::numeric_limits::quiet_NaN(), -1, TensorType::kFloat16, + y, 
std::vector(), std::vector(), std::vector(), + true, false, true // disable_cpu, disable_cuda=false (test CUDA MEA), disable_dml + ); +} + +// Regression test for output_qk + softcap: verifies that qk_matmul_output_mode=0 (kQK) +// returns RAW Q*K logits (before softcap), not softcapped values. +// This test would FAIL if CopyQK were moved after ApplySoftcap: +// - Correct (CopyQK before softcap): output_qk = 2.0 (raw dot product) +// - Wrong (CopyQK after softcap): output_qk = tanh(2.0) ≈ 0.964 (clamped by softcap=1.0) +// Uses constant Q=1, K=1 with head_size=4 so QK = scale * dot(Q,K) = 0.5 * 4 = 2.0. +// v_head_size(6) != head_size(4) blocks Flash Attention and MEA decode, forcing unfused path. +TEST(AttentionTest, Attention4DSoftCapOutputQkRawLogits) { + int batch_size = 1; + int q_num_heads = 2; + int q_sequence_length = 2; + int head_size = 4; + int kv_sequence_length = 3; + int kv_num_heads = 2; + int v_head_size = 6; + int past_sequence_length = 0; + int total_sequence_length = past_sequence_length + kv_sequence_length; + + // Constant Q and K: all 1.0 + // QK = scale * dot(Q[i], K[j]) = (1/sqrt(4)) * 4 = 2.0 for all (i,j) pairs + std::vector q(batch_size * q_num_heads * q_sequence_length * head_size, 1.0f); + std::vector k(batch_size * kv_num_heads * kv_sequence_length * head_size, 1.0f); + + // V: position j gets value (j+1)*0.1 across all v_head_size dims + std::vector v(batch_size * kv_num_heads * kv_sequence_length * v_head_size); + for (int n = 0; n < kv_num_heads; n++) { + for (int s = 0; s < kv_sequence_length; s++) { + float val = static_cast(s + 1) * 0.1f; + for (int h = 0; h < v_head_size; h++) { + v[(n * kv_sequence_length + s) * v_head_size + h] = val; + } + } + } + + // Expected output_qk: raw QK logits = 2.0 for all entries + // Shape: [batch, q_num_heads, q_seq, total_seq] = [1, 2, 2, 3] = 12 values + std::vector expected_qk(batch_size * q_num_heads * q_sequence_length * total_sequence_length, 2.0f); + + // Expected Y: softcap(2.0) ≈ 0.964 
for all QK → uniform softmax → Y = mean(V) = 0.2 + // Shape: [batch, q_num_heads, q_seq, v_head_size] = [1, 2, 2, 6] = 24 values + std::vector ys(batch_size * q_num_heads * q_sequence_length * v_head_size, 0.2f); + + // present_key = K (no past), present_value = V (no past) + // These must be provided so the OpTester has all 4 outputs for correct index mapping. + RunTest4D(batch_size, q_num_heads, q_sequence_length, head_size, kv_sequence_length, kv_num_heads, v_head_size, past_sequence_length, + q, k, v, std::vector(), std::initializer_list(), std::vector(), std::vector(), + -1, 0, std::numeric_limits::quiet_NaN(), 1.0f, -1, TensorType::kFloat, // is_causal, qk_matmul_output_mode=kQK, scale=default, softcap=1.0 + ys, k, v, expected_qk, + false, false, true // disable_cpu, disable_cuda, disable_dml — runs on both CPU and CUDA unfused (v_head_size != head_size blocks Flash/MEA) + ); +} + +// ============================================================================ +// Causal alignment tests: verify upper-left (no past) vs lower-right (with past) +// These are CUDA-only tests that validate the causal masking fix. +// ============================================================================ + +// Test: Causal + cross-attention (S_q=3, S_kv=5, no past) +// ONNX spec mandates upper-left alignment: q_i attends to kv[0..i]. +// V is identity-like so output directly reveals which KV positions were attended. +// Exercises MEA (fp32, head_size divisible by 4) or Unfused kernel on CUDA. 
+TEST(AttentionTest, Attention4DCausalCrossAttentionUpperLeft) { + int batch_size = 1; + int q_num_heads = 1; + int q_sequence_length = 3; + int head_size = 4; + int kv_sequence_length = 5; + int kv_num_heads = 1; + int v_head_size = 4; + int past_sequence_length = 0; + + // clang-format off + std::vector q = {1.0f, 0.5f, 0.3f, 0.2f, + 0.4f, 0.8f, 0.1f, 0.6f, + 0.7f, 0.3f, 0.9f, 0.5f}; + std::vector k = {0.2f, 0.4f, 0.6f, 0.8f, + 0.1f, 0.3f, 0.5f, 0.7f, + 0.9f, 0.1f, 0.2f, 0.3f, + 0.5f, 0.6f, 0.7f, 0.8f, + 0.3f, 0.2f, 0.1f, 0.4f}; + std::vector v = {1.0f, 0.0f, 0.0f, 0.0f, + 0.0f, 1.0f, 0.0f, 0.0f, + 0.0f, 0.0f, 1.0f, 0.0f, + 0.0f, 0.0f, 0.0f, 1.0f, + 0.5f, 0.5f, 0.5f, 0.5f}; + // Upper-left causal (scale=0.5): q0→v[0]=[1,0,0,0], q1→softmax([0.47,0.375])@v[0:2], q2→softmax([0.6,0.48,0.495])@v[0:3] + std::vector y = {1.000000f, 0.000000f, 0.000000f, 0.000000f, + 0.523732f, 0.476268f, 0.000000f, 0.000000f, + 0.358777f, 0.318207f, 0.323016f, 0.000000f}; + // clang-format on + + ASSERT_EQ(q.size(), static_cast(batch_size * q_num_heads * q_sequence_length * head_size)); + ASSERT_EQ(k.size(), static_cast(batch_size * kv_num_heads * kv_sequence_length * head_size)); + ASSERT_EQ(v.size(), static_cast(batch_size * kv_num_heads * kv_sequence_length * v_head_size)); + ASSERT_EQ(y.size(), static_cast(batch_size * q_num_heads * q_sequence_length * v_head_size)); + + RunTest4D(batch_size, q_num_heads, q_sequence_length, head_size, kv_sequence_length, kv_num_heads, v_head_size, past_sequence_length, + q, k, v, std::vector(), std::initializer_list(), std::vector(), std::vector(), + 1, -1, std::numeric_limits::quiet_NaN(), std::numeric_limits::quiet_NaN(), -1, TensorType::kFloat, // is_causal, qk_matmul_output_mode, scale, softcap, softmax_precision, tensor_type + y, std::vector(), std::vector(), std::vector(), + false, false, true // disable_cpu, disable_cuda, disable_dml + ); +} + +// Test: Causal + cross-attention (S_q=3, S_kv=5, no past) with head_size=8. 
+// ONNX spec mandates upper-left alignment: q_i attends to kv[0..i]. +// head_size=8 targets the MEA path (below Flash minimum of 32) but validates +// correctness regardless of which kernel handles it. head_size=8 satisfies +// MEA's head_size%8==0 requirement, so this exercises MEA's CausalFromTopLeft +// path (via causal_from_top_left=true when past_seq==0). +// V is identity-like so output directly reveals which KV positions were attended. +TEST(AttentionTest, Attention4DCausalCrossAttentionUpperLeftSmallHead) { + int batch_size = 1; + int q_num_heads = 1; + int q_sequence_length = 3; + int head_size = 8; + int kv_sequence_length = 5; + int kv_num_heads = 1; + int v_head_size = 8; + int past_sequence_length = 0; + + // clang-format off + std::vector q = {1.0f, 0.5f, 0.3f, 0.2f, 0.8f, 0.4f, 0.6f, 0.1f, + 0.4f, 0.8f, 0.1f, 0.6f, 0.3f, 0.7f, 0.2f, 0.9f, + 0.7f, 0.3f, 0.9f, 0.5f, 0.1f, 0.6f, 0.4f, 0.8f}; + std::vector k = {0.2f, 0.4f, 0.6f, 0.8f, 0.1f, 0.3f, 0.5f, 0.7f, + 0.1f, 0.3f, 0.5f, 0.7f, 0.9f, 0.2f, 0.4f, 0.6f, + 0.9f, 0.1f, 0.2f, 0.3f, 0.4f, 0.8f, 0.7f, 0.5f, + 0.5f, 0.6f, 0.7f, 0.8f, 0.2f, 0.4f, 0.3f, 0.1f, + 0.3f, 0.2f, 0.1f, 0.4f, 0.6f, 0.5f, 0.8f, 0.9f}; + std::vector v = {1.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, + 0.0f, 1.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, + 0.0f, 0.0f, 1.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, + 0.0f, 0.0f, 0.0f, 1.0f, 0.0f, 0.0f, 0.0f, 0.0f, + 0.0f, 0.0f, 0.0f, 0.0f, 1.0f, 0.0f, 0.0f, 0.0f}; + // Upper-left causal (scale=1/sqrt(8)): q0→v[0], q1→softmax(scaled_scores[0:2])@v[0:2], q2→softmax(scaled_scores[0:3])@v[0:3] + std::vector y = {1.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, + 0.511488f, 0.488512f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, + 0.344711f, 0.305668f, 0.349621f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f}; + // clang-format on + + ASSERT_EQ(q.size(), static_cast(batch_size * q_num_heads * q_sequence_length * head_size)); + 
ASSERT_EQ(k.size(), static_cast(batch_size * kv_num_heads * kv_sequence_length * head_size)); + ASSERT_EQ(v.size(), static_cast(batch_size * kv_num_heads * kv_sequence_length * v_head_size)); + ASSERT_EQ(y.size(), static_cast(batch_size * q_num_heads * q_sequence_length * v_head_size)); + + RunTest4D(batch_size, q_num_heads, q_sequence_length, head_size, kv_sequence_length, kv_num_heads, v_head_size, past_sequence_length, + q, k, v, std::vector(), std::initializer_list(), std::vector(), std::vector(), + 1, -1, std::numeric_limits::quiet_NaN(), std::numeric_limits::quiet_NaN(), -1, TensorType::kFloat, // is_causal, qk_matmul_output_mode, scale, softcap, softmax_precision, tensor_type + y, std::vector(), std::vector(), std::vector(), + false, false, true // disable_cpu, disable_cuda, disable_dml + ); +} +// Lower-right alignment: q0 at absolute position 4 attends to all 5 KV positions. +// Exercises Unfused or MEA decode path on CUDA. +TEST(AttentionTest, Attention4DCausalDecodeWithPastLowerRight) { + int batch_size = 1; + int q_num_heads = 1; + int q_sequence_length = 1; + int head_size = 4; + int kv_sequence_length = 1; // new KV tokens + int kv_num_heads = 1; + int v_head_size = 4; + int past_sequence_length = 4; // total = 4 + 1 = 5 + + // clang-format off + std::vector q = {0.7f, 0.3f, 0.9f, 0.5f}; + std::vector k = {0.3f, 0.2f, 0.1f, 0.4f}; // new key + std::vector v = {0.5f, 0.5f, 0.5f, 0.5f}; // new value + std::vector past_key = {0.2f, 0.4f, 0.6f, 0.8f, + 0.1f, 0.3f, 0.5f, 0.7f, + 0.9f, 0.1f, 0.2f, 0.3f, + 0.5f, 0.6f, 0.7f, 0.8f}; + std::vector past_value = {1.0f, 0.0f, 0.0f, 0.0f, + 0.0f, 1.0f, 0.0f, 0.0f, + 0.0f, 0.0f, 1.0f, 0.0f, + 0.0f, 0.0f, 0.0f, 1.0f}; + // Lower-right: q0 at pos 4 sees all 5 positions. 
scores=[0.6,0.48,0.495,0.78,0.28]*scale=0.5 already applied + std::vector y = {0.289363f, 0.265357f, 0.268203f, 0.331229f}; + // present = concat(past, new) in BNSH layout + std::vector present_key = {0.2f, 0.4f, 0.6f, 0.8f, + 0.1f, 0.3f, 0.5f, 0.7f, + 0.9f, 0.1f, 0.2f, 0.3f, + 0.5f, 0.6f, 0.7f, 0.8f, + 0.3f, 0.2f, 0.1f, 0.4f}; + std::vector present_value = {1.0f, 0.0f, 0.0f, 0.0f, + 0.0f, 1.0f, 0.0f, 0.0f, + 0.0f, 0.0f, 1.0f, 0.0f, + 0.0f, 0.0f, 0.0f, 1.0f, + 0.5f, 0.5f, 0.5f, 0.5f}; + // clang-format on + + ASSERT_EQ(q.size(), static_cast(batch_size * q_num_heads * q_sequence_length * head_size)); + ASSERT_EQ(k.size(), static_cast(batch_size * kv_num_heads * kv_sequence_length * head_size)); + ASSERT_EQ(v.size(), static_cast(batch_size * kv_num_heads * kv_sequence_length * v_head_size)); + ASSERT_EQ(y.size(), static_cast(batch_size * q_num_heads * q_sequence_length * v_head_size)); + + RunTest4D(batch_size, q_num_heads, q_sequence_length, head_size, kv_sequence_length, kv_num_heads, v_head_size, past_sequence_length, + q, k, v, std::vector(), std::initializer_list(), past_key, past_value, + 1, -1, std::numeric_limits::quiet_NaN(), std::numeric_limits::quiet_NaN(), -1, TensorType::kFloat, // is_causal, qk_matmul_output_mode, scale, softcap, softmax_precision, tensor_type + y, present_key, present_value, std::vector(), + false, false, true // disable_cpu, disable_cuda, disable_dml + ); +} + +// Test: Causal + square (S_q=S_kv=4, no past) +// Upper-left == lower-right for square matrices. Verifies correctness on both paths. +// Exercises MEA or Unfused kernel depending on GPU capability. 
+TEST(AttentionTest, Attention4DCausalSquareNoPast) { + int batch_size = 1; + int q_num_heads = 1; + int q_sequence_length = 4; + int head_size = 4; + int kv_sequence_length = 4; + int kv_num_heads = 1; + int v_head_size = 4; + int past_sequence_length = 0; + + // clang-format off + std::vector q = {1.0f, 0.5f, 0.3f, 0.2f, + 0.4f, 0.8f, 0.1f, 0.6f, + 0.7f, 0.3f, 0.9f, 0.5f, + 0.2f, 0.6f, 0.4f, 0.8f}; + std::vector k = {0.2f, 0.4f, 0.6f, 0.8f, + 0.1f, 0.3f, 0.5f, 0.7f, + 0.9f, 0.1f, 0.2f, 0.3f, + 0.5f, 0.6f, 0.7f, 0.8f}; + std::vector v = {1.0f, 0.0f, 0.0f, 0.0f, + 0.0f, 1.0f, 0.0f, 0.0f, + 0.0f, 0.0f, 1.0f, 0.0f, + 0.0f, 0.0f, 0.0f, 1.0f}; + // Both alignments give identical result for square (no past). + std::vector y = {1.000000f, 0.000000f, 0.000000f, 0.000000f, + 0.523732f, 0.476268f, 0.000000f, 0.000000f, + 0.358777f, 0.318207f, 0.323016f, 0.000000f, + 0.265821f, 0.240525f, 0.196925f, 0.296730f}; + // clang-format on + + ASSERT_EQ(q.size(), static_cast(batch_size * q_num_heads * q_sequence_length * head_size)); + ASSERT_EQ(k.size(), static_cast(batch_size * kv_num_heads * kv_sequence_length * head_size)); + ASSERT_EQ(v.size(), static_cast(batch_size * kv_num_heads * kv_sequence_length * v_head_size)); + ASSERT_EQ(y.size(), static_cast(batch_size * q_num_heads * q_sequence_length * v_head_size)); + + RunTest4D(batch_size, q_num_heads, q_sequence_length, head_size, kv_sequence_length, kv_num_heads, v_head_size, past_sequence_length, + q, k, v, std::vector(), std::initializer_list(), std::vector(), std::vector(), + 1, -1, std::numeric_limits::quiet_NaN(), std::numeric_limits::quiet_NaN(), -1, TensorType::kFloat, // is_causal, qk_matmul_output_mode, scale, softcap, softmax_precision, tensor_type + y, std::vector(), std::vector(), std::vector(), + false, false, true // disable_cpu, disable_cuda, disable_dml + ); +} + } // namespace test } // namespace onnxruntime diff --git a/onnxruntime/test/providers/cpu/math/element_wise_ops_test.cc 
b/onnxruntime/test/providers/cpu/math/element_wise_ops_test.cc index 48a18210face7..11a4b373c53f1 100644 --- a/onnxruntime/test/providers/cpu/math/element_wise_ops_test.cc +++ b/onnxruntime/test/providers/cpu/math/element_wise_ops_test.cc @@ -3674,6 +3674,127 @@ TEST(MathOpTest, Equal_string) { test.Run(); } +#ifdef USE_CUDA +// Opset 19 tests for numeric types (CUDA EP) +TEST(MathOpTest, Equal_19_bool) { + auto cuda_ep = DefaultCudaExecutionProvider(); + if (!cuda_ep) { + return; + } + + OpTester test("Equal", 19); + std::vector dims{4}; + test.AddInput("A", dims, {false, true, false, true}); + test.AddInput("B", dims, {false, false, true, true}); + test.AddOutput("C", dims, {true, false, false, true}); + + std::vector> execution_providers; + execution_providers.push_back(std::move(cuda_ep)); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); +} + +TEST(MathOpTest, Equal_19_int32) { + auto cuda_ep = DefaultCudaExecutionProvider(); + if (!cuda_ep) { + return; + } + + OpTester test("Equal", 19); + std::vector dims{4}; + test.AddInput("A", dims, {1, 0, -1, -1}); + test.AddInput("B", dims, {1, 1, 2, -1}); + test.AddOutput("C", dims, {true, false, false, true}); + + std::vector> execution_providers; + execution_providers.push_back(std::move(cuda_ep)); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); +} + +TEST(MathOpTest, Equal_19_int64) { + auto cuda_ep = DefaultCudaExecutionProvider(); + if (!cuda_ep) { + return; + } + + OpTester test("Equal", 19); + std::vector dims{4}; + test.AddInput("A", dims, {1, 0, -1, -1}); + test.AddInput("B", dims, {1, 1, 2, -1}); + test.AddOutput("C", dims, {true, false, false, true}); + + std::vector> execution_providers; + execution_providers.push_back(std::move(cuda_ep)); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); +} + +TEST(MathOpTest, Equal_19_float) { + auto cuda_ep = DefaultCudaExecutionProvider(); + if 
(!cuda_ep) { + return; + } + + OpTester test("Equal", 19); + std::vector dims{4}; + test.AddInput("A", dims, {1.0f, 0.0f, -1.0f, -1.0f}); + test.AddInput("B", dims, {1.0f, 1.0f, 2.0f, -1.0f}); + test.AddOutput("C", dims, {true, false, false, true}); + + std::vector> execution_providers; + execution_providers.push_back(std::move(cuda_ep)); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); +} + +TEST(MathOpTest, Equal_19_double) { + auto cuda_ep = DefaultCudaExecutionProvider(); + if (!cuda_ep) { + return; + } + + OpTester test("Equal", 19); + std::vector dims{4}; + test.AddInput("A", dims, {1.0, 0.0, -1.0, -1.0}); + test.AddInput("B", dims, {1.0, 1.0, 2.0, -1.0}); + test.AddOutput("C", dims, {true, false, false, true}); + + std::vector> execution_providers; + execution_providers.push_back(std::move(cuda_ep)); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); +} + +TEST(MathOpTest, Equal_19_float16) { + auto cuda_ep = DefaultCudaExecutionProvider(); + if (!cuda_ep) { + return; + } + + OpTester test("Equal", 19); + std::vector dims{4}; + test.AddInput("A", dims, {MLFloat16(1.0f), MLFloat16(0.0f), MLFloat16(-1.0f), MLFloat16(-1.0f)}); + test.AddInput("B", dims, {MLFloat16(1.0f), MLFloat16(1.0f), MLFloat16(2.0f), MLFloat16(-1.0f)}); + test.AddOutput("C", dims, {true, false, false, true}); + + std::vector> execution_providers; + execution_providers.push_back(std::move(cuda_ep)); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); +} + +TEST(MathOpTest, Equal_19_broadcastAB) { + auto cuda_ep = DefaultCudaExecutionProvider(); + if (!cuda_ep) { + return; + } + + OpTester test("Equal", 19); + test.AddInput("A", {4, 2}, {1, 0, -1, -1, 1, 1, -1, 0}); + test.AddInput("B", {2}, {1, 1}); + test.AddOutput("C", {4, 2}, {true, false, false, false, true, true, false, false}); + + std::vector> execution_providers; + execution_providers.push_back(std::move(cuda_ep)); 
+ test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); +} +#endif + #if defined(USE_DNNL) TEST(MathOpTest, Equal_bfloat16) { #ifdef USE_DNNL @@ -4420,6 +4541,62 @@ TEST(BitShiftOpTest, BroadcastXRight_Uint8) { test.Run(); } +// Test that shift amounts >= bit width produce 0 (not undefined behavior). +// DirectML EP masks the shift amount at the hardware level, so it does not produce 0 here; skip these tests for DML. +TEST(BitShiftOpTest, RightShiftByBitWidth_Uint64) { + OpTester test("BitShift", 11); + test.AddAttribute("direction", "RIGHT"); + test.AddInput("X", {4}, {1000, 255, 1, 42}); + test.AddInput("Y", {4}, {64, 64, 64, 64}); + test.AddOutput("Z", {4}, {0, 0, 0, 0}); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kDmlExecutionProvider}); +} + +TEST(BitShiftOpTest, LeftShiftByBitWidth_Uint64) { + OpTester test("BitShift", 11); + test.AddAttribute("direction", "LEFT"); + test.AddInput("X", {4}, {1000, 255, 1, 42}); + test.AddInput("Y", {4}, {64, 64, 64, 64}); + test.AddOutput("Z", {4}, {0, 0, 0, 0}); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kDmlExecutionProvider}); +} + +TEST(BitShiftOpTest, RightShiftByBitWidth_Uint32) { + OpTester test("BitShift", 11); + test.AddAttribute("direction", "RIGHT"); + test.AddInput("X", {3}, {16, 4, 1}); + test.AddInput("Y", {3}, {32, 32, 32}); + test.AddOutput("Z", {3}, {0, 0, 0}); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kDmlExecutionProvider}); +} + +TEST(BitShiftOpTest, RightShiftByMoreThanBitWidth_Uint64) { + OpTester test("BitShift", 11); + test.AddAttribute("direction", "RIGHT"); + test.AddInput("X", {2}, {1000, 42}); + test.AddInput("Y", {2}, {65, 128}); + test.AddOutput("Z", {2}, {0, 0}); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kDmlExecutionProvider}); +} + +TEST(BitShiftOpTest, ScalarRightShiftByBitWidth_Uint64) { + OpTester test("BitShift", 11); + test.AddAttribute("direction", "RIGHT"); + test.AddInput("X", {1}, {1000}); + test.AddInput("Y", {3}, 
{64, 65, 128}); + test.AddOutput("Z", {3}, {0, 0, 0}); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kDmlExecutionProvider}); +} + +TEST(BitShiftOpTest, ScalarLeftShiftByBitWidth_Uint64) { + OpTester test("BitShift", 11); + test.AddAttribute("direction", "LEFT"); + test.AddInput("X", {3}, {1000, 255, 42}); + test.AddInput("Y", {1}, {64}); + test.AddOutput("Z", {3}, {0, 0, 0}); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kDmlExecutionProvider}); +} + TEST(MathOpTest, BitwiseAnd) { OpTester test("BitwiseAnd", 18); std::vector dims{3}; diff --git a/onnxruntime/test/providers/cpu/math/round_test.cc b/onnxruntime/test/providers/cpu/math/round_test.cc index 5df14ac079a63..48f96fe4f8494 100644 --- a/onnxruntime/test/providers/cpu/math/round_test.cc +++ b/onnxruntime/test/providers/cpu/math/round_test.cc @@ -3,6 +3,7 @@ #include "gtest/gtest.h" #include "test/providers/provider_test_utils.h" +#include "test/util/include/default_providers.h" #include "core/framework/data_types.h" #include "core/util/math.h" @@ -30,5 +31,53 @@ TEST(RoundTest, SimpleTestFloat16) { test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); } +#ifdef USE_CUDA +// Opset 22 tests +TEST(RoundTest, Round22_Float) { + auto cuda_ep = DefaultCudaExecutionProvider(); + if (!cuda_ep) { + return; + } + + OpTester test("Round", 22, onnxruntime::kOnnxDomain); + test.AddInput("x", {5}, {0.9f, 2.5f, 2.3f, 1.5f, -4.5f}); + test.AddOutput("y", {5}, {1.0f, 2.0f, 2.0f, 2.0f, -4.0f}); + + std::vector> execution_providers; + execution_providers.push_back(std::move(cuda_ep)); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); +} + +TEST(RoundTest, Round22_Double) { + auto cuda_ep = DefaultCudaExecutionProvider(); + if (!cuda_ep) { + return; + } + + OpTester test("Round", 22, onnxruntime::kOnnxDomain); + test.AddInput("x", {5}, {0.9, 2.5, 2.3, 1.5, -4.5}); + test.AddOutput("y", {5}, {1.0, 2.0, 2.0, 2.0, -4.0}); + + std::vector> 
execution_providers; + execution_providers.push_back(std::move(cuda_ep)); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); +} + +TEST(RoundTest, Round22_Float16) { + auto cuda_ep = DefaultCudaExecutionProvider(); + if (!cuda_ep) { + return; + } + + OpTester test("Round", 22, onnxruntime::kOnnxDomain); + test.AddInput("x", {5}, {MLFloat16(0.9f), MLFloat16(2.5f), MLFloat16(2.3f), MLFloat16(1.5f), MLFloat16(-4.5f)}); + test.AddOutput("y", {5}, {MLFloat16(1.0f), MLFloat16(2.0f), MLFloat16(2.0f), MLFloat16(2.0f), MLFloat16(-4.0f)}); + + std::vector> execution_providers; + execution_providers.push_back(std::move(cuda_ep)); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); +} +#endif + } // namespace test -} // namespace onnxruntime \ No newline at end of file +} // namespace onnxruntime diff --git a/onnxruntime/test/providers/cpu/reduction/reduction_ops_test.cc b/onnxruntime/test/providers/cpu/reduction/reduction_ops_test.cc index 52e8b55cb3b98..79617dc16e1f5 100644 --- a/onnxruntime/test/providers/cpu/reduction/reduction_ops_test.cc +++ b/onnxruntime/test/providers/cpu/reduction/reduction_ops_test.cc @@ -6318,5 +6318,209 @@ TEST(ReductionOpTest, ReduceSumSquare_NoopWithAxesNotProvided_ElementwiseSquare) test.ConfigEp(DefaultCpuExecutionProvider()).RunWithConfig(); } +// Opset 20 tests for ReduceMax and ReduceMin on CUDA. +// Verifies CUDA kernel registration at opset 20 works for all supported types. 
+#if defined(USE_CUDA) + +TEST(ReductionOpTest, ReduceMax_float_Opset20_Cuda) { + OpTester test("ReduceMax", 20); + test.AddAttribute("keepdims", (int64_t)1); + test.AddInput("data", {3, 2, 2}, + {1.0f, 2.0f, 3.0f, 4.0f, + 5.0f, 6.0f, 7.0f, 8.0f, + 9.0f, 10.0f, 11.0f, 12.0f}); + test.AddInput("axes", {2}, {1, 2}); + test.AddOutput("reduced", {3, 1, 1}, {4.0f, 8.0f, 12.0f}); + std::vector> execution_providers; + execution_providers.push_back(DefaultCudaExecutionProvider()); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); +} + +TEST(ReductionOpTest, ReduceMax_double_Opset20_Cuda) { + OpTester test("ReduceMax", 20); + test.AddAttribute("keepdims", (int64_t)1); + test.AddInput("data", {3, 2, 2}, + {1.0, 2.0, 3.0, 4.0, + 5.0, 6.0, 7.0, 8.0, + 9.0, 10.0, 11.0, 12.0}); + test.AddInput("axes", {2}, {1, 2}); + test.AddOutput("reduced", {3, 1, 1}, {4.0, 8.0, 12.0}); + std::vector> execution_providers; + execution_providers.push_back(DefaultCudaExecutionProvider()); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); +} + +TEST(ReductionOpTest, ReduceMax_half_Opset20_Cuda) { + OpTester test("ReduceMax", 20); + test.AddAttribute("keepdims", (int64_t)1); + test.AddInput("data", {3, 2, 2}, + FloatsToMLFloat16s({1.0f, 2.0f, 3.0f, 4.0f, + 5.0f, 6.0f, 7.0f, 8.0f, + 9.0f, 10.0f, 11.0f, 12.0f})); + test.AddInput("axes", {2}, {1, 2}); + test.AddOutput("reduced", {3, 1, 1}, FloatsToMLFloat16s({4.0f, 8.0f, 12.0f})); + std::vector> execution_providers; + execution_providers.push_back(DefaultCudaExecutionProvider()); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); +} + +TEST(ReductionOpTest, ReduceMax_int32_Opset20_Cuda) { + OpTester test("ReduceMax", 20); + test.AddAttribute("keepdims", (int64_t)1); + test.AddInput("data", {3, 2, 2}, + {1, 2, 3, 4, + 5, 6, 7, 8, + 9, 10, 11, 12}); + test.AddInput("axes", {2}, {1, 2}); + test.AddOutput("reduced", {3, 1, 1}, {4, 8, 
12}); + std::vector> execution_providers; + execution_providers.push_back(DefaultCudaExecutionProvider()); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); +} + +TEST(ReductionOpTest, ReduceMax_int64_Opset20_Cuda) { + OpTester test("ReduceMax", 20); + test.AddAttribute("keepdims", (int64_t)1); + test.AddInput("data", {3, 2, 2}, + {1, 2, 3, 4, + 5, 6, 7, 8, + 9, 10, 11, 12}); + test.AddInput("axes", {2}, {1, 2}); + test.AddOutput("reduced", {3, 1, 1}, {4, 8, 12}); + std::vector> execution_providers; + execution_providers.push_back(DefaultCudaExecutionProvider()); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); +} + +TEST(ReductionOpTest, ReduceMin_float_Opset20_Cuda) { + OpTester test("ReduceMin", 20); + test.AddAttribute("keepdims", (int64_t)1); + test.AddInput("data", {3, 2, 2}, + {1.0f, 2.0f, 3.0f, 4.0f, + 5.0f, 6.0f, 7.0f, 8.0f, + 9.0f, 10.0f, 11.0f, 12.0f}); + test.AddInput("axes", {2}, {1, 2}); + test.AddOutput("reduced", {3, 1, 1}, {1.0f, 5.0f, 9.0f}); + std::vector> execution_providers; + execution_providers.push_back(DefaultCudaExecutionProvider()); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); +} + +TEST(ReductionOpTest, ReduceMin_double_Opset20_Cuda) { + OpTester test("ReduceMin", 20); + test.AddAttribute("keepdims", (int64_t)1); + test.AddInput("data", {3, 2, 2}, + {1.0, 2.0, 3.0, 4.0, + 5.0, 6.0, 7.0, 8.0, + 9.0, 10.0, 11.0, 12.0}); + test.AddInput("axes", {2}, {1, 2}); + test.AddOutput("reduced", {3, 1, 1}, {1.0, 5.0, 9.0}); + std::vector> execution_providers; + execution_providers.push_back(DefaultCudaExecutionProvider()); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); +} + +TEST(ReductionOpTest, ReduceMin_half_Opset20_Cuda) { + OpTester test("ReduceMin", 20); + test.AddAttribute("keepdims", (int64_t)1); + test.AddInput("data", {3, 2, 2}, + FloatsToMLFloat16s({1.0f, 2.0f, 3.0f, 
4.0f, + 5.0f, 6.0f, 7.0f, 8.0f, + 9.0f, 10.0f, 11.0f, 12.0f})); + test.AddInput("axes", {2}, {1, 2}); + test.AddOutput("reduced", {3, 1, 1}, FloatsToMLFloat16s({1.0f, 5.0f, 9.0f})); + std::vector> execution_providers; + execution_providers.push_back(DefaultCudaExecutionProvider()); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); +} + +TEST(ReductionOpTest, ReduceMin_int32_Opset20_Cuda) { + OpTester test("ReduceMin", 20); + test.AddAttribute("keepdims", (int64_t)1); + test.AddInput("data", {3, 2, 2}, + {1, 2, 3, 4, + 5, 6, 7, 8, + 9, 10, 11, 12}); + test.AddInput("axes", {2}, {1, 2}); + test.AddOutput("reduced", {3, 1, 1}, {1, 5, 9}); + std::vector> execution_providers; + execution_providers.push_back(DefaultCudaExecutionProvider()); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); +} + +TEST(ReductionOpTest, ReduceMin_int64_Opset20_Cuda) { + OpTester test("ReduceMin", 20); + test.AddAttribute("keepdims", (int64_t)1); + test.AddInput("data", {3, 2, 2}, + {1, 2, 3, 4, + 5, 6, 7, 8, + 9, 10, 11, 12}); + test.AddInput("axes", {2}, {1, 2}); + test.AddOutput("reduced", {3, 1, 1}, {1, 5, 9}); + std::vector> execution_providers; + execution_providers.push_back(DefaultCudaExecutionProvider()); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); +} + +TEST(ReductionOpTest, ReduceMin_int8_Opset20_Cuda) { + OpTester test("ReduceMin", 20); + test.AddAttribute("keepdims", (int64_t)1); + test.AddInput("data", {3, 2, 2}, + {1, 2, 3, 4, + 5, 6, 7, 8, + 9, 10, 11, 12}); + test.AddInput("axes", {2}, {1, 2}); + test.AddOutput("reduced", {3, 1, 1}, {1, 5, 9}); + std::vector> execution_providers; + execution_providers.push_back(DefaultCudaExecutionProvider()); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); +} + +TEST(ReductionOpTest, ReduceMin_uint8_Opset20_Cuda) { + OpTester test("ReduceMin", 20); + 
test.AddAttribute("keepdims", (int64_t)1); + test.AddInput("data", {3, 2, 2}, + {1, 2, 3, 4, + 5, 6, 7, 8, + 9, 10, 11, 12}); + test.AddInput("axes", {2}, {1, 2}); + test.AddOutput("reduced", {3, 1, 1}, {1, 5, 9}); + std::vector> execution_providers; + execution_providers.push_back(DefaultCudaExecutionProvider()); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); +} + +// Test ReduceMax at opset 20 with keepdims=0 on CUDA +TEST(ReductionOpTest, ReduceMax_float_Opset20_NoKeepdims_Cuda) { + OpTester test("ReduceMax", 20); + test.AddAttribute("keepdims", (int64_t)0); + test.AddInput("data", {3, 2, 2}, + {1.0f, 2.0f, 3.0f, 4.0f, + 5.0f, 6.0f, 7.0f, 8.0f, + 9.0f, 10.0f, 11.0f, 12.0f}); + test.AddInput("axes", {2}, {1, 2}); + test.AddOutput("reduced", {3}, {4.0f, 8.0f, 12.0f}); + std::vector> execution_providers; + execution_providers.push_back(DefaultCudaExecutionProvider()); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); +} + +// Test ReduceMin at opset 20 with keepdims=0 on CUDA +TEST(ReductionOpTest, ReduceMin_float_Opset20_NoKeepdims_Cuda) { + OpTester test("ReduceMin", 20); + test.AddAttribute("keepdims", (int64_t)0); + test.AddInput("data", {3, 2, 2}, + {1.0f, 2.0f, 3.0f, 4.0f, + 5.0f, 6.0f, 7.0f, 8.0f, + 9.0f, 10.0f, 11.0f, 12.0f}); + test.AddInput("axes", {2}, {1, 2}); + test.AddOutput("reduced", {3}, {1.0f, 5.0f, 9.0f}); + std::vector> execution_providers; + execution_providers.push_back(DefaultCudaExecutionProvider()); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); +} + +#endif // defined(USE_CUDA) + } // namespace test } // namespace onnxruntime diff --git a/onnxruntime/test/providers/cpu/rnn/rnn_op_test.cc b/onnxruntime/test/providers/cpu/rnn/rnn_op_test.cc index 0dcf4f597d9c8..49bc10935c2c8 100644 --- a/onnxruntime/test/providers/cpu/rnn/rnn_op_test.cc +++ b/onnxruntime/test/providers/cpu/rnn/rnn_op_test.cc @@ -986,6 +986,69 @@ 
TEST(RNNTest, RNN_forward_sequence_lens_with_zero) { test.ConfigEp(std::move(cpu)).RunWithConfig(); } +TEST(RNNTest, RNN_ForwardDefaultActivations_OpSet22_CUDA) { + auto cuda_ep = DefaultCudaExecutionProvider(); + if (!cuda_ep) { + return; + } + + // Simple forward RNN at opset 22 to verify CUDA registration. + int64_t seq_length = 2; + int batch_size = 1; + int64_t input_size = 2; + int64_t hidden_size = 3; + int num_directions = 1; + + std::vector X_data = {1.f, 2.f, 3.f, 4.f}; + std::vector X_dims = {seq_length, batch_size, input_size}; + + std::vector W_data = {0.1f, 0.2f, 0.3f, 0.4f, 0.5f, 0.6f}; + std::vector W_dims = {num_directions, hidden_size, input_size}; + + std::vector R_data(num_directions * hidden_size * hidden_size, 0.1f); + std::vector R_dims = {num_directions, hidden_size, hidden_size}; + + // Y = tanh(X * W^T + H_prev * R^T), with H_prev = 0 at the first time step + // time_step 0: X=[1,2], W^T rows=[0.1,0.3,0.5; 0.2,0.4,0.6] + // h0 = tanh([0.1*1+0.2*2, 0.3*1+0.4*2, 0.5*1+0.6*2]) = tanh([0.5, 1.1, 1.7]) + float h0_0 = std::tanh(0.5f); + float h0_1 = std::tanh(1.1f); + float h0_2 = std::tanh(1.7f); + + // time_step 1: X=[3,4], h_prev = h0 + // h1 = tanh(X * W^T + h0 * R^T) + // X * W^T = [0.1*3+0.2*4, 0.3*3+0.4*4, 0.5*3+0.6*4] = [1.1, 2.5, 3.9] + // h0 * R^T (R=0.1 everywhere) = [0.1*(h0_0+h0_1+h0_2), ...] 
(same for each) + float h0_sum = h0_0 + h0_1 + h0_2; + float h1_0 = std::tanh(1.1f + 0.1f * h0_sum); + float h1_1 = std::tanh(2.5f + 0.1f * h0_sum); + float h1_2 = std::tanh(3.9f + 0.1f * h0_sum); + + std::vector Y_data = {h0_0, h0_1, h0_2, h1_0, h1_1, h1_2}; + std::vector Y_dims = {seq_length, num_directions, batch_size, hidden_size}; + + std::vector Y_h_data = {h1_0, h1_1, h1_2}; + std::vector Y_h_dims = {num_directions, batch_size, hidden_size}; + + OpTester test("RNN", 22); + test.AddShapeToTensorData(); + + test.AddAttribute>("activations", {"Tanh"}); + test.AddAttribute("direction", string("forward")); + test.AddAttribute("hidden_size", hidden_size); + + test.AddInput("X", X_dims, X_data); + test.AddInput("W", W_dims, W_data, true); + test.AddInput("R", R_dims, R_data, true); + + test.AddOutput("Y", Y_dims, Y_data); + test.AddOutput("Y_h", Y_h_dims, Y_h_data); + + std::vector> execution_providers; + execution_providers.push_back(std::move(cuda_ep)); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); +} + // Test reverse RNN with all-zero sequence_lens and non-zero initial_h. // The bug: reverse direction with sequence_lens=0 would return initial_h instead of zero-filling. TEST(RNNTest, RNN_reverse_sequence_lens_all_zero) { diff --git a/onnxruntime/test/providers/cpu/tensor/cast_op_test.cc b/onnxruntime/test/providers/cpu/tensor/cast_op_test.cc index d5b6630668000..4481cf36554cd 100644 --- a/onnxruntime/test/providers/cpu/tensor/cast_op_test.cc +++ b/onnxruntime/test/providers/cpu/tensor/cast_op_test.cc @@ -1,6 +1,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. 
+#include #include #include "boost/mp11.hpp" @@ -2635,7 +2636,8 @@ TEST(CastOpTest, Float8E4M3FNToInt2x4_OddShape) { template void CastOpTestFloatFloat4(std::vector shape, std::vector float_data, - bool is_fp4_input = false) { + bool is_fp4_input = false, + int opset = 23) { int num_pairs = static_cast(float_data.size()) / 2; int num_fp4_elements = static_cast((float_data.size() + 1) / 2); bool is_odd_count = (float_data.size() % 2 != 0); @@ -2653,7 +2655,7 @@ void CastOpTestFloatFloat4(std::vector shape, if (!is_fp4_input) { TestCastOp(gsl::make_span(float_data), gsl::make_span(fp4_data), shape, - OpTester::ExpectResult::kExpectSuccess, "", 23, Saturate::None, true); + OpTester::ExpectResult::kExpectSuccess, "", opset, Saturate::None, true); } else { std::vector casted_back_float; @@ -2668,7 +2670,7 @@ void CastOpTestFloatFloat4(std::vector shape, } TestCastOp(gsl::make_span(fp4_data), gsl::make_span(casted_back_float), shape, - OpTester::ExpectResult::kExpectSuccess, "", 23, Saturate::None, true); + OpTester::ExpectResult::kExpectSuccess, "", opset, Saturate::None, true); } } @@ -2732,8 +2734,185 @@ TEST(CastOpTest, Float4E2M1x2ToFloat) { } } +// Opset 25 tests for Float4 types on CUDA +TEST(CastOpTest, FloatToFloat4E2M1x2_Opset25) { + CastOpTestFloatFloat4({2, 2, 2}, + {std::numeric_limits::infinity(), + -std::numeric_limits::infinity(), + 7.f, -7.f, + 0.5f, -0.5f, + std::numeric_limits::quiet_NaN(), + -std::numeric_limits::quiet_NaN()}, + false, 25); + + CastOpTestFloatFloat4({1, 3, 1}, + {0.256f, 0.987f, 43.8f}, + false, 25); +} + +TEST(CastOpTest, Float4E2M1x2ToFloat_Opset25) { + CastOpTestFloatFloat4({2, 2, 2}, + {0.5f, 7.34f, + 1.f, 1.5f, + 2.f, 3.f, + 4.f, 6.f}, + true, 25); + + CastOpTestFloatFloat4({1, 3, 1}, + {0.256f, 0.987f, 43.8f}, + true, 25); +} + #endif +// Opset 25 tests for standard types on CUDA. +// Verifies CUDA Cast kernel registration at opset 25 works for common type conversions. 
+#if defined(USE_CUDA) + +TEST(CastOpTest, StandardTypes_Opset25_Cuda) { + const std::vector shape{2, 3}; + + // float -> double + { + const std::vector input = {1.0f, 2.5f, -3.0f, 0.0f, 100.0f, -0.5f}; + const std::vector expected = {1.0, 2.5, -3.0, 0.0, 100.0, -0.5}; + TestCastOp(gsl::make_span(input), gsl::make_span(expected), shape, + OpTester::ExpectResult::kExpectSuccess, "", 25, Saturate::None, /*cuda_only=*/true); + } + + // double -> float + { + const std::vector input = {1.0, 2.5, -3.0, 0.0, 100.0, -0.5}; + const std::vector expected = {1.0f, 2.5f, -3.0f, 0.0f, 100.0f, -0.5f}; + TestCastOp(gsl::make_span(input), gsl::make_span(expected), shape, + OpTester::ExpectResult::kExpectSuccess, "", 25, Saturate::None, /*cuda_only=*/true); + } + + // float -> int32_t + { + const std::vector input = {1.0f, 2.9f, -3.0f, 0.0f, 100.0f, -0.5f}; + const std::vector expected = {1, 2, -3, 0, 100, 0}; + TestCastOp(gsl::make_span(input), gsl::make_span(expected), shape, + OpTester::ExpectResult::kExpectSuccess, "", 25, Saturate::None, /*cuda_only=*/true); + } + + // int32_t -> float + { + const std::vector input = {1, 2, -3, 0, 100, -7}; + const std::vector expected = {1.0f, 2.0f, -3.0f, 0.0f, 100.0f, -7.0f}; + TestCastOp(gsl::make_span(input), gsl::make_span(expected), shape, + OpTester::ExpectResult::kExpectSuccess, "", 25, Saturate::None, /*cuda_only=*/true); + } + + // float -> MLFloat16 + if (HasCudaEnvironment(530)) { + const std::vector input = {1.0f, 2.5f, -3.0f, 0.0f, 100.0f, -0.5f}; + const std::vector expected = CastedValues(gsl::make_span(input)); + TestCastOp(gsl::make_span(input), gsl::make_span(expected), shape, + OpTester::ExpectResult::kExpectSuccess, "", 25, Saturate::None, /*cuda_only=*/true); + } + + // MLFloat16 -> float + if (HasCudaEnvironment(530)) { + const std::vector input = CastedValues( + gsl::make_span(std::vector{1.0f, 2.5f, -3.0f, 0.0f, 100.0f, -0.5f})); + const std::vector expected = {1.0f, 2.5f, -3.0f, 0.0f, 100.0f, -0.5f}; + 
TestCastOp(gsl::make_span(input), gsl::make_span(expected), shape, + OpTester::ExpectResult::kExpectSuccess, "", 25, Saturate::None, /*cuda_only=*/true); + } + + // BFloat16 -> float + if (HasCudaEnvironment(800)) { + const std::vector input = CastedValues( + gsl::make_span(std::vector{1.0f, 2.5f, -3.0f, 0.0f, 100.0f, -0.5f})); + const std::vector expected = CastedValues(gsl::make_span(input)); + TestCastOp(gsl::make_span(input), gsl::make_span(expected), shape, + OpTester::ExpectResult::kExpectSuccess, "", 25, Saturate::None, /*cuda_only=*/true); + } + + // bool -> float + { + const bool input[] = {true, false, true, true, false, false}; + const gsl::span input_span(input); + const std::vector expected = {1.0f, 0.0f, 1.0f, 1.0f, 0.0f, 0.0f}; + TestCastOp(input_span, gsl::make_span(expected), shape, + OpTester::ExpectResult::kExpectSuccess, "", 25, Saturate::None, /*cuda_only=*/true); + } +} + +#if !defined(DISABLE_FLOAT8_TYPES) + +TEST(CastOpTest, Float8_Opset25_Cuda) { + constexpr int min_cuda_architecture = 11080; + if (!HasCudaEnvironment(min_cuda_architecture)) { + return; + } + + const std::vector shape{2, 2, 2}; + const std::vector float_input = {NAN, -1.f, 0.0391877927f, 0.296140194f, + -0.120196559f, 5.0f, + -std::numeric_limits::infinity(), + std::numeric_limits::infinity()}; + + // Float8E4M3FN: float -> Float8E4M3FN at opset 25 + { + std::vector output; + output.reserve(float_input.size()); + for (size_t i = 0; i < float_input.size(); ++i) { + output.emplace_back(Float8E4M3FN(float_input[i], true)); + } + TestCastOp(gsl::make_span(float_input), gsl::make_span(output), shape, + OpTester::ExpectResult::kExpectSuccess, "", 25, Saturate::True, /*cuda_only=*/true); + } + + // Float8E5M2: float -> Float8E5M2 at opset 25 + { + std::vector output; + output.reserve(float_input.size()); + for (size_t i = 0; i < float_input.size(); ++i) { + output.emplace_back(Float8E5M2(float_input[i], true)); + } + TestCastOp(gsl::make_span(float_input), gsl::make_span(output), 
shape, + OpTester::ExpectResult::kExpectSuccess, "", 25, Saturate::True, /*cuda_only=*/true); + } + + // Float8E4M3FN -> float at opset 25 + { + std::vector input; + input.reserve(float_input.size()); + for (size_t i = 0; i < float_input.size(); ++i) { + input.emplace_back(Float8E4M3FN(float_input[i], true)); + } + std::vector expected; + expected.reserve(input.size()); + for (const auto& v : input) { + expected.push_back(v.ToFloat()); + } + TestCastOp(gsl::make_span(input), gsl::make_span(expected), shape, + OpTester::ExpectResult::kExpectSuccess, "", 25, Saturate::None, /*cuda_only=*/true); + } + + // Float8E5M2 -> float at opset 25 + { + std::vector input; + input.reserve(float_input.size()); + for (size_t i = 0; i < float_input.size(); ++i) { + input.emplace_back(Float8E5M2(float_input[i], true)); + } + std::vector expected; + expected.reserve(input.size()); + for (const auto& v : input) { + expected.push_back(v.ToFloat()); + } + TestCastOp(gsl::make_span(input), gsl::make_span(expected), shape, + OpTester::ExpectResult::kExpectSuccess, "", 25, Saturate::None, /*cuda_only=*/true); + } +} + +#endif // !defined(DISABLE_FLOAT8_TYPES) + +#endif // defined(USE_CUDA) + // Regression tests for sub-byte same-type cast (CopyCpuTensor heap overflow fix). // When src and dst types are the same, Cast::Compute calls CopyCpuTensor which must // use SizeInBytes() (not shape.Size() * DataType()->Size()) for the memcpy byte count. @@ -2835,7 +3014,7 @@ TEST(CastOpTest, UInt2x4ToUInt2x4_LargeShape) { // Direct CopyCpuTensor test with guaranteed distinct buffers to exercise the memcpy path. // This bypasses the MayInplace optimization that can alias input/output in OpTester. // Uses guard bytes after the valid buffer region to detect overflow deterministically -// without relying on ASan — the pre-fix code would overwrite these sentinel bytes. +// without relying on ASan; the pre-fix code would overwrite these sentinel bytes. 
TEST(CastOpTest, CopyCpuTensor_SubByteTypes_DistinctBuffers) { constexpr uint8_t kGuardByte = 0xCD; constexpr size_t kGuardSize = 64; diff --git a/onnxruntime/test/providers/cpu/tensor/resize_op_test.cc b/onnxruntime/test/providers/cpu/tensor/resize_op_test.cc index 3129476b1b505..7de16b00dafe3 100644 --- a/onnxruntime/test/providers/cpu/tensor/resize_op_test.cc +++ b/onnxruntime/test/providers/cpu/tensor/resize_op_test.cc @@ -1431,6 +1431,79 @@ TEST(ResizeOpTest, ResizeOpNearestUpSample_Floor_Align_Corners) { test.Run(OpTester::ExpectResult::kExpectSuccess, "", ExcludeTrtOnA100()); } +// Test round_prefer_ceil with half_pixel coordinate transformation. +// Exercises non-integer scale (26->64) where round_prefer_ceil selects +// source pixels at fractional boundaries. +TEST(ResizeOpTest, ResizeOpNearestUpSample_RoundPreferCeil_HalfPixel) { + OpTester test("Resize", 13); + + std::vector roi{}; + std::vector scales{1.0f, 1.0f, 1.0f, 64.0f / 26.0f}; + + test.AddAttribute("mode", "nearest"); + test.AddAttribute("coordinate_transformation_mode", "half_pixel"); + test.AddAttribute("nearest_mode", "round_prefer_ceil"); + + constexpr int64_t N = 1, C = 1, H = 1, W = 26; + std::vector X(26); + for (int i = 0; i < 26; i++) X[i] = static_cast(i); + + test.AddInput("X", {N, C, H, W}, X); + test.AddInput("roi", {0}, roi); + test.AddInput("scales", {4}, scales); + + std::vector Y = { + 0.0f, 0.0f, 1.0f, 1.0f, 1.0f, 2.0f, 2.0f, 3.0f, + 3.0f, 3.0f, 4.0f, 4.0f, 5.0f, 5.0f, 5.0f, 6.0f, + 6.0f, 7.0f, 7.0f, 7.0f, 8.0f, 8.0f, 9.0f, 9.0f, + 9.0f, 10.0f, 10.0f, 11.0f, 11.0f, 11.0f, 12.0f, 12.0f, + 13.0f, 13.0f, 14.0f, 14.0f, 14.0f, 15.0f, 15.0f, 16.0f, + 16.0f, 16.0f, 17.0f, 17.0f, 18.0f, 18.0f, 18.0f, 19.0f, + 19.0f, 20.0f, 20.0f, 20.0f, 21.0f, 21.0f, 22.0f, 22.0f, + 22.0f, 23.0f, 23.0f, 24.0f, 24.0f, 24.0f, 25.0f, 25.0f}; + + test.AddOutput("Y", {N, C, H, 64}, Y); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", ExcludeTrtOnA100()); +} + +// Test round_prefer_ceil with half_pixel 
for a small upsample (2x2 -> 7x8). +// Verifies that at positive .5 boundaries, ceiling is preferred. +TEST(ResizeOpTest, ResizeOpNearestUpSample_RoundPreferCeil_HalfPixel_2x2to7x8) { + OpTester test("Resize", 13); + + std::vector<float> roi{}; + std::vector<float> scales{}; + std::vector<int64_t> sizes{1, 1, 7, 8}; + + test.AddAttribute("mode", "nearest"); + test.AddAttribute("coordinate_transformation_mode", "half_pixel"); + test.AddAttribute("nearest_mode", "round_prefer_ceil"); + + constexpr int64_t N = 1, C = 1, H = 2, W = 2; + std::vector<float> X = {1.0f, 2.0f, 3.0f, 4.0f}; + + test.AddInput<float>("X", {N, C, H, W}, X); + test.AddInput<float>("roi", {0}, roi); + test.AddInput<float>("", {0}, scales); + test.AddInput<int64_t>("sizes", {4}, sizes); + + // half_pixel: x_orig = (x_resized + 0.5) / scale - 0.5 + // H scale = 7/2 = 3.5, W scale = 8/2 = 4.0 + // H coords: i=0: -0.357, i=1: -0.071, i=2: 0.214, i=3: 0.5, i=4: 0.786, i=5: 1.071, i=6: 1.357 + // round_prefer_ceil at 0.5 -> ceil(0.5) = 1 + // W coords: i=0: -0.375, i=1: -0.125, i=2: 0.125, i=3: 0.375, i=4: 0.625, i=5: 0.875, i=6: 1.125, i=7: 1.375 + std::vector<float> Y = {1.0f, 1.0f, 1.0f, 1.0f, 2.0f, 2.0f, 2.0f, 2.0f, + 1.0f, 1.0f, 1.0f, 1.0f, 2.0f, 2.0f, 2.0f, 2.0f, + 1.0f, 1.0f, 1.0f, 1.0f, 2.0f, 2.0f, 2.0f, 2.0f, + 3.0f, 3.0f, 3.0f, 3.0f, 4.0f, 4.0f, 4.0f, 4.0f, + 3.0f, 3.0f, 3.0f, 3.0f, 4.0f, 4.0f, 4.0f, 4.0f, + 3.0f, 3.0f, 3.0f, 3.0f, 4.0f, 4.0f, 4.0f, 4.0f, + 3.0f, 3.0f, 3.0f, 3.0f, 4.0f, 4.0f, 4.0f, 4.0f}; + + test.AddOutput<float>("Y", {N, C, sizes[2], sizes[3]}, Y); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", ExcludeTrtOnA100()); +} + TEST(ResizeOpTest, ResizeOpNearest_OneToOneMappingBetweenInputAndOutputDataDims) { // TODO: Unskip when fixed #41968513 if (DefaultDmlExecutionProvider().get() != nullptr) { diff --git a/onnxruntime/test/python/quantization/test_quant_preprocess.py b/onnxruntime/test/python/quantization/test_quant_preprocess.py index c93f081072f35..f00fb4a05b6d8 100644 --- a/onnxruntime/test/python/quantization/test_quant_preprocess.py +++
b/onnxruntime/test/python/quantization/test_quant_preprocess.py @@ -5,6 +5,7 @@ # license information. # -------------------------------------------------------------------------- +import sys import tempfile import unittest from pathlib import Path @@ -158,5 +159,83 @@ def test_clip_version_conversion(self): assert preprocessed_model.opset_import[0].version >= 11 +class TestSkipSymbolicShape(unittest.TestCase): + """Verify that skip_symbolic_shape=True avoids importing sympy.""" + + def setUp(self): + self.temp_dir = tempfile.TemporaryDirectory(prefix="ort.quant_preprocess_skip_sympy_") + self.temp_path = Path(self.temp_dir.name) + + def tearDown(self): + self.temp_dir.cleanup() + + def build_simple_model(self): + """Build a minimal identity model for testing.""" + input_tensor = onnx.helper.make_tensor_value_info("input", onnx.TensorProto.FLOAT, [1, 4]) + output_tensor = onnx.helper.make_tensor_value_info("output", onnx.TensorProto.FLOAT, [1, 4]) + identity_node = onnx.helper.make_node("Identity", ["input"], ["output"]) + graph = onnx.helper.make_graph([identity_node], "simple_graph", [input_tensor], [output_tensor]) + opset_imports = [onnx.helper.make_opsetid("", 13)] + return onnx.helper.make_model(graph, opset_imports=opset_imports) + + def test_skip_symbolic_shape_does_not_require_sympy(self): + """ + When skip_symbolic_shape=True, quant_pre_process must not attempt to + import onnxruntime.tools.symbolic_shape_infer (which requires sympy). + We verify this by installing a meta_path finder that raises + ModuleNotFoundError for those modules — guaranteeing any fresh import + attempt fails — and asserting the call succeeds without ever loading + them. 
+ """ + + class _BlockSympyAndSymbolicFinder: + blocked_prefixes = ("sympy",) + blocked_substrings = ("symbolic_shape_infer",) + + def find_spec(self, fullname, path=None, target=None): + if fullname == "sympy" or fullname.startswith("sympy."): + raise ModuleNotFoundError(f"blocked by test: {fullname}") + if "symbolic_shape_infer" in fullname: + raise ModuleNotFoundError(f"blocked by test: {fullname}") + return None + + model = self.build_simple_model() + input_path = self.temp_path / "simple_model.onnx" + output_path = self.temp_path / "out_model.onnx" + onnx.save_model(model, str(input_path)) + + saved = {} + for key in list(sys.modules.keys()): + if key == "sympy" or key.startswith("sympy.") or "symbolic_shape_infer" in key: + saved[key] = sys.modules.pop(key) + + blocker = _BlockSympyAndSymbolicFinder() + sys.meta_path.insert(0, blocker) + try: + quant_pre_process( + input_model=str(input_path), + output_model_path=str(output_path), + skip_optimization=True, + skip_onnx_shape=True, + skip_symbolic_shape=True, + ) + + for mod_name in list(sys.modules): + self.assertFalse( + mod_name == "sympy" or mod_name.startswith("sympy."), + f"sympy was imported despite skip_symbolic_shape=True: {mod_name}", + ) + self.assertNotIn( + "symbolic_shape_infer", + mod_name, + f"symbolic_shape_infer was imported despite skip_symbolic_shape=True: {mod_name}", + ) + finally: + sys.meta_path.remove(blocker) + sys.modules.update(saved) + + self.assertTrue(output_path.exists(), "Output model should be created even without sympy") + + if __name__ == "__main__": unittest.main() diff --git a/onnxruntime/test/python/transformers/test_onnx_attention/common.py b/onnxruntime/test/python/transformers/test_onnx_attention/common.py index 48640fa38aca2..1ab38fb1ea0f9 100644 --- a/onnxruntime/test/python/transformers/test_onnx_attention/common.py +++ b/onnxruntime/test/python/transformers/test_onnx_attention/common.py @@ -11,7 +11,7 @@ # 
------------------------------------------------------------------------- """ -Shared utilities for ONNX Attention op (opset 23) tests. +Shared utilities for ONNX Attention op (opset 23/24) tests. Contains configuration, ONNX graph builders, reference implementation, and parity check helpers used by both GQA and MHA test modules. @@ -38,9 +38,6 @@ # Reduces number of tests to run for faster pipeline checks pipeline_mode = os.getenv("PIPELINE_MODE", "1") == "1" -# Number of values per parameter (compared to pipeline mode) -param_count = int(os.getenv("PARAM_COUNT", "3")) if not pipeline_mode else 2 - # When quick build is used, flash attention only supports head_size=128 quick_build = ", quick-build=" in get_build_info() @@ -71,14 +68,6 @@ torch.int8: TensorProto.INT8, } -TORCH_DTYPE_MAP = { - "float32": torch.float32, - "float16": torch.float16, - "bfloat16": torch.bfloat16, - "int8": torch.int8, - "int4": torch.uint8, -} - @dataclass class AttentionConfig: @@ -88,6 +77,7 @@ class AttentionConfig: q_num_heads: int kv_num_heads: int head_size: int + v_head_size: int = 0 # 0 means same as head_size; set explicitly for asymmetric Q/V head sizes is_causal: int = 0 past_kv_sequence_length: int = 0 softcap: float = 0.0 @@ -115,7 +105,7 @@ def create_attention_node_and_io( """ Create ONNX Attention op node and I/O definitions for testing. 
- ONNX Attention op (opset 23) inputs: + ONNX Attention op (opset 23/24) inputs: - 0: Q (query) - required - 1: K (key) - required - 2: V (value) - required @@ -135,6 +125,9 @@ def create_attention_node_and_io( else: # Prompt (no past KV cache) present_kv_seqlen = config.kv_sequence_length + # Effective v_head_size: defaults to head_size when not explicitly set + effective_v_head_size = config.v_head_size or config.head_size + if not config.kv_cache_type: config.kv_cache_type = { TensorProto.FLOAT16: "float16", @@ -168,7 +161,7 @@ def create_attention_node_and_io( while inputs and inputs[-1] == "": inputs.pop() - # ONNX Attention op attributes (opset 23) + # ONNX Attention op attributes (opset 23/24) node = helper.make_node( op_type="Attention", inputs=inputs, @@ -199,13 +192,14 @@ def create_attention_node_and_io( helper.make_tensor_value_info( "value", ort_type, - [config.batch_size, config.kv_num_heads, config.kv_sequence_length, config.head_size], + [config.batch_size, config.kv_num_heads, config.kv_sequence_length, effective_v_head_size], ), ] else: # 3D inputs: [batch, seq_len, hidden_size] q_hidden_size = config.q_num_heads * config.head_size kv_hidden_size = config.kv_num_heads * config.head_size + v_hidden_size = config.kv_num_heads * effective_v_head_size graph_input = [ helper.make_tensor_value_info( "query", ort_type, [config.batch_size, config.q_sequence_length, q_hidden_size] @@ -214,7 +208,7 @@ def create_attention_node_and_io( "key", ort_type, [config.batch_size, config.kv_sequence_length, kv_hidden_size] ), helper.make_tensor_value_info( - "value", ort_type, [config.batch_size, config.kv_sequence_length, kv_hidden_size] + "value", ort_type, [config.batch_size, config.kv_sequence_length, v_hidden_size] ), ] @@ -263,10 +257,11 @@ def create_attention_node_and_io( # Shape: [batch, num_heads, past_seq_len, head_size] (4D BNSH format) if is_past: past_k_shape = [config.batch_size, config.kv_num_heads, config.past_kv_sequence_length, config.head_size] + 
past_v_shape = [config.batch_size, config.kv_num_heads, config.past_kv_sequence_length, effective_v_head_size] graph_input.extend( [ helper.make_tensor_value_info("past_key", cache_ort_type, past_k_shape), - helper.make_tensor_value_info("past_value", cache_ort_type, past_k_shape), + helper.make_tensor_value_info("past_value", cache_ort_type, past_v_shape), ] ) @@ -276,16 +271,17 @@ def create_attention_node_and_io( # --- Graph Outputs --- output_k_shape = [config.batch_size, config.kv_num_heads, present_kv_seqlen, config.head_size] + output_v_shape = [config.batch_size, config.kv_num_heads, present_kv_seqlen, effective_v_head_size] if config.use_4d_bnsh: - output_shape = [config.batch_size, config.q_num_heads, config.q_sequence_length, config.head_size] + output_shape = [config.batch_size, config.q_num_heads, config.q_sequence_length, effective_v_head_size] else: - output_shape = [config.batch_size, config.q_sequence_length, config.q_num_heads * config.head_size] + output_shape = [config.batch_size, config.q_sequence_length, config.q_num_heads * effective_v_head_size] graph_output = [ helper.make_tensor_value_info("output", ort_type, output_shape), helper.make_tensor_value_info("present_key", cache_ort_type, output_k_shape), - helper.make_tensor_value_info("present_value", cache_ort_type, output_k_shape), + helper.make_tensor_value_info("present_value", cache_ort_type, output_v_shape), ] if output_qk > 0: @@ -447,24 +443,26 @@ def attention_prompt_func( bind_tensor(io_binding, "nonpad_kv_seqlen", nonpad_kv_seqlen, device, TensorProto.INT64) # Bind Outputs - hidden_size = config.q_num_heads * config.head_size + effective_v_head_size = config.v_head_size or config.head_size + output_hidden_size = config.q_num_heads * effective_v_head_size out_dtype = _get_out_dtype(ort_type) if config.use_4d_bnsh: out_torch = torch.zeros( - (config.batch_size, config.q_num_heads, config.q_sequence_length, config.head_size), + (config.batch_size, config.q_num_heads, 
config.q_sequence_length, effective_v_head_size), dtype=out_dtype, device=device, ) else: out_torch = torch.zeros( - (config.batch_size, config.q_sequence_length, hidden_size), dtype=out_dtype, device=device + (config.batch_size, config.q_sequence_length, output_hidden_size), dtype=out_dtype, device=device ) bind_output_tensor(io_binding, "output", out_torch, device, ort_type) # present KV shape for prompt (no past) present_seqlen = config.kv_sequence_length - present_dims = [config.batch_size, config.kv_num_heads, present_seqlen, config.head_size] + present_k_dims = [config.batch_size, config.kv_num_heads, present_seqlen, config.head_size] + present_v_dims = [config.batch_size, config.kv_num_heads, present_seqlen, effective_v_head_size] # Determine dtype for cache tensors cache_dtype = out_dtype @@ -473,8 +471,8 @@ def attention_prompt_func( else: cache_ort_type = ONNX_TENSOR_TYPE_MAP[config.kv_cache_type] - present_k = torch.zeros(tuple(present_dims), dtype=cache_dtype, device=device) - present_v = torch.zeros(tuple(present_dims), dtype=cache_dtype, device=device) + present_k = torch.zeros(tuple(present_k_dims), dtype=cache_dtype, device=device) + present_v = torch.zeros(tuple(present_v_dims), dtype=cache_dtype, device=device) bind_output_tensor(io_binding, "present_key", present_k, device, cache_ort_type) bind_output_tensor(io_binding, "present_value", present_v, device, cache_ort_type) @@ -565,28 +563,30 @@ def attention_past_func( bind_tensor(io_binding, "past_value", past_v_sliced, device, cache_ort_type) # Bind Outputs - hidden_size = config.q_num_heads * config.head_size + effective_v_head_size = config.v_head_size or config.head_size + output_hidden_size = config.q_num_heads * effective_v_head_size out_dtype = _get_out_dtype(ort_type) if config.use_4d_bnsh: out_torch = torch.zeros( - (config.batch_size, config.q_num_heads, config.q_sequence_length, config.head_size), + (config.batch_size, config.q_num_heads, config.q_sequence_length, 
effective_v_head_size), dtype=out_dtype, device=device, ) else: out_torch = torch.zeros( - (config.batch_size, config.q_sequence_length, hidden_size), dtype=out_dtype, device=device + (config.batch_size, config.q_sequence_length, output_hidden_size), dtype=out_dtype, device=device ) bind_output_tensor(io_binding, "output", out_torch, device, ort_type) # present KV shape (past + new) present_seqlen = total_seq_len - present_dims = [config.batch_size, config.kv_num_heads, present_seqlen, config.head_size] + present_k_dims = [config.batch_size, config.kv_num_heads, present_seqlen, config.head_size] + present_v_dims = [config.batch_size, config.kv_num_heads, present_seqlen, effective_v_head_size] cache_dtype = out_dtype - present_k = torch.zeros(tuple(present_dims), dtype=cache_dtype, device=device) - present_v = torch.zeros(tuple(present_dims), dtype=cache_dtype, device=device) + present_k = torch.zeros(tuple(present_k_dims), dtype=cache_dtype, device=device) + present_v = torch.zeros(tuple(present_v_dims), dtype=cache_dtype, device=device) bind_output_tensor(io_binding, "present_key", present_k, device, cache_ort_type) bind_output_tensor(io_binding, "present_value", present_v, device, cache_ort_type) @@ -645,6 +645,9 @@ def attention_ref( scores = torch.einsum("bthd,bshd->bhts", q, k) / math.sqrt(q.shape[-1]) + # Corrected ordering per onnx/onnx#7865: QK → softcap → add bias/mask → softmax + # Softcap must be applied before mask so that -inf mask values are not + # squashed to finite -softcap, which would leak probability to masked positions. 
if softcap > 0: scores = (scores / softcap).tanh() * softcap diff --git a/onnxruntime/test/python/transformers/test_onnx_attention/test_gqa.py b/onnxruntime/test/python/transformers/test_onnx_attention/test_gqa.py index c4e3c1b19e85e..55f07666e8c6f 100644 --- a/onnxruntime/test/python/transformers/test_onnx_attention/test_gqa.py +++ b/onnxruntime/test/python/transformers/test_onnx_attention/test_gqa.py @@ -98,16 +98,19 @@ def parity_check_gqa_prompt( ) v = torch.randn_like(k) * std - # --- Create attn_mask as boolean padding mask (simulating seqlens_k) --- + # --- Create attn_mask matching the ONNX model's expected shape --- attn_mask = None key_padding_mask = None if config.has_attn_mask: + total_seq = config.past_kv_sequence_length + config.kv_sequence_length + # 2D mask shape: [q_seq, total_seq] per ONNX spec (matches create_attention_graph_prompt) attn_mask = torch.ones( - config.batch_size, - config.kv_sequence_length, + config.q_sequence_length, + total_seq, device=device, dtype=torch.bool, ) + # key_padding_mask for PyTorch reference: [batch, kv_seq] key_padding_mask = torch.ones( config.batch_size, config.kv_sequence_length, @@ -115,6 +118,17 @@ def parity_check_gqa_prompt( dtype=torch.bool, ) + # --- Create nonpad_kv_seqlen tensor if needed (opset 24+) --- + nonpad_kv_seqlen = None + if config.has_nonpad_kv_seqlen: + # Each batch element has the full kv_sequence_length as valid (no padding) + nonpad_kv_seqlen = torch.full( + (config.batch_size,), + config.kv_sequence_length, + device=device, + dtype=torch.int64, + ) + # --- PyTorch Reference Path --- out_ref, _ = attention_ref( q=q, @@ -138,6 +152,7 @@ def parity_check_gqa_prompt( ep=ep, device=device, ort_type=ort_type, + nonpad_kv_seqlen=nonpad_kv_seqlen, ) if i == 0: first_out = out.clone() @@ -271,7 +286,7 @@ def parity_check_gqa_past( key_padding_mask = None if config.has_attn_mask: attn_mask = torch.ones( - config.batch_size, + config.q_sequence_length, total_seq_len, device=device, dtype=torch.bool, 
@@ -441,7 +456,7 @@ def parity_check_gqa_prompt_with_padding( ) # --- ONNX Runtime Path --- - out, present_k, present_v = attention_prompt_func( + out, _present_k, _present_v = attention_prompt_func( q=q, k=k, v=v, @@ -568,7 +583,7 @@ def parity_check_gqa_past_with_padding( ) # --- ONNX Runtime Path --- - out, present_k, present_v = attention_past_func( + out, _present_k, _present_v = attention_past_func( q=q, past_k=past_k, past_v=past_v, @@ -708,6 +723,9 @@ def gqa_prompt_padding_test_cases(): # Guard case: batch_size=4 != q_seq_len=1 (decode). This catches the original bug # where 2D mask was [batch, total_seq] instead of [q_seq, total_seq]. + # NOTE: is_causal=0 because per ONNX spec, is_causal with S_q!=S_kv and no past_key + # gives upper-left alignment (q[0] sees only kv[0]), which is not meaningful for decode. + # KV bounds are enforced by the attention mask instead. for mask_dims in mask_dims_options: config = AttentionConfig( batch_size=4, @@ -717,7 +735,7 @@ def gqa_prompt_padding_test_cases(): q_num_heads=8, kv_num_heads=2, head_size=128, - is_causal=1, + is_causal=0, has_attn_mask=True, attn_mask_dims=mask_dims, ) @@ -730,7 +748,9 @@ def gqa_past_padding_test_cases(): Generate test cases for ONNX Attention op GQA path with boolean padding masks in decoding phase. """ batches = [2] - seqs = [(1, 32)] + # past=31 + new=1 = total_seq=32, which satisfies MEA's bias alignment + # requirement (total_seq % 4 == 0) when attn_mask is present. + seqs = [(1, 31)] heads = [(8, 2)] h_sizes = [128] mask_dims_options = [2, 3, 4] @@ -863,22 +883,37 @@ def test_gqa_prompt_memory_efficient(self, name, config): # flash attention. -# TODO(titaiwang): Re-enable once PR #27851 merges (MEA supports past_key for GQA). -# Flash now rejects attn_mask (requires attn_mask==nullptr). GQA + bool mask + past_key -# has no runner until MEA supports past_key. See issue #27885. -@unittest.skip( - "Flash now rejects attn_mask. 
GQA + bool mask + past_key has no runner " - "until PR #27851 (MEA with past_key). See issue #27885." -) -@unittest.skipIf(not has_flash_attention(), "Flash Attention is not available, skipping tests.") -@patch.dict(os.environ, {"ORT_DISABLE_FLASH_ATTENTION": "0"}) -class TestONNXAttentionPaddingMaskGQA(unittest.TestCase): +@unittest.skipIf(not has_cuda_device(80), "BF16 requires Ampere or higher GPU, skipping tests.") +@patch.dict(os.environ, {"ORT_DISABLE_FLASH_ATTENTION": "1"}) +class TestONNXAttentionMemoryEfficientGQABF16(unittest.TestCase): + """Test ONNX Attention op (opset 23) GQA path with Memory Efficient Attention using BFloat16.""" + + @parameterized.expand(gqa_past_test_cases()) + def test_gqa_past_memory_efficient_bf16(self, name, config): + if not torch.cuda.is_bf16_supported(): + self.skipTest("BFloat16 not supported on this device") + + config.kv_cache_type = "bfloat16" + parity_check_gqa_past( + config=config, + ep="CUDAExecutionProvider", + device="cuda", + torch_type=torch.bfloat16, + ort_type=TensorProto.BFLOAT16, + causal=True, + rtol=rtol["bf16"], + atol=atol["bf16"], + ) + + +@unittest.skipIf(not has_cuda_device(53), "Memory Efficient Attention is not available, skipping tests.") +@patch.dict(os.environ, {"ORT_DISABLE_FLASH_ATTENTION": "1"}) +class TestONNXAttentionPaddingMaskMEAGQA(unittest.TestCase): """ Test ONNX Attention op (opset 23) GQA path with boolean padding masks. - SKIPPED: Flash now requires attn_mask == nullptr. GQA + bool attn_mask + - past_key currently has no runner (Flash rejected, unfused doesn't support GQA, - MEA blocked by past_key != nullptr). Will be re-enabled when PR #27851 lands. + GQA + bool attn_mask + past_key uses the MEA decode path (Flash requires + attn_mask == nullptr). MEA handles bool masks via additive bias conversion. 
These tests verify that the boolean attn_mask is correctly converted to sequence lengths on GPU and that the attention computation respects the @@ -1011,7 +1046,7 @@ def parity_check_gqa_prompt_with_nonpad_kv_seqlen( # ORT path: use nonpad_kv_seqlen (int64 tensor) nonpad_kv_seqlen_tensor = nonpad_seqlens.to(torch.int64).to(device) - out, present_k, present_v = attention_prompt_func( + out, _present_k, _present_v = attention_prompt_func( q=q, k=k, v=v, @@ -1344,10 +1379,10 @@ def test_gqa_prompt_float_mask_4d(self): # ################################################################################################# -# Large Head Size Unfused GQA Tests (head_size=512, fixes #28195) +# Large Head Size Unfused Tests (head_size=512, fixes #28195) # # Flash Attention and Memory-Efficient Attention cap at head_size=256. For head_size=512 the -# op falls through to RunGqaUnfusedAttention which writes Q*K^T to an FP32 scratch buffer, +# op falls through to RunUnfusedAttention which writes Q*K^T to an FP32 scratch buffer, # eliminating fp16/bf16 overflow that caused NaNs (e.g. Gemma 4 global-attention layers). # # These tests deliberately disable both Flash and MEA to make the unfused fallback explicit @@ -1425,7 +1460,7 @@ class TestONNXAttentionGQALargeHeadUnfused(unittest.TestCase): Regression tests for GQA with head_size=512 via the unfused FP32-QK path (issue #28195). Flash Attention and MEA both cap at head_size=256. With both disabled the op routes - to RunGqaUnfusedAttention, which writes Q*K^T to an FP32 scratch buffer to avoid + to RunUnfusedAttention, which writes Q*K^T to an FP32 scratch buffer to avoid fp16/bf16 overflow that produced NaNs for Gemma 4 global-attention layers. Validates: no NaNs, numerical parity vs. PyTorch SDPA reference, for fp16 and bf16. 
@@ -1532,5 +1567,355 @@ def test_gqa_large_head_unfused_softcap_additive_mask_poison_fp16(self): self.assertLess(out.float().max().item(), 1.0) +@unittest.skipIf(not has_cuda_device(53), "Memory Efficient Attention is not available, skipping tests.") +@patch.dict(os.environ, {"ORT_DISABLE_FLASH_ATTENTION": "1"}) +class TestONNXAttentionMemoryEfficientGQAFloatMaskDecode(unittest.TestCase): + """ + Test GQA with float additive attention mask during decode using MEA. + + This exercises the MEA decode path with float additive masks — a scenario + that was a HARD ERROR before MEA+decode support (MEA was ineligible + when past_key was present, so this fell through to no kernel). + """ + + def test_gqa_past_float_mask_4d(self): + """Test GQA decode with 4D float additive mask via MEA.""" + config = AttentionConfig( + batch_size=2, + q_sequence_length=1, + kv_sequence_length=1, + past_kv_sequence_length=31, # 31+1=32, divisible by 4 (CUTLASS bias alignment for MEA) + q_num_heads=8, + kv_num_heads=2, + head_size=128, + is_causal=1, + has_attn_mask=True, + attn_mask_dims=4, + attn_mask_type="additive", + ) + + torch.manual_seed(0) + device = "cuda" + torch_type = torch.float16 + # std=0.2 keeps values in a numerically stable range for fp16 attention + std = 0.2 + + q = torch.randn(2, 1, 8, 128, device=device, dtype=torch_type) * std + + past_k = torch.randn(2, 2, 31, 128, device=device, dtype=torch_type) * std + past_v = torch.randn_like(past_k) * std + + new_k = torch.randn(2, 1, 2, 128, device=device, dtype=torch_type) * std + new_v = torch.randn_like(new_k) * std + + total_seq_len = 32 # past(31) + new(1), satisfies MEA bias alignment (32 % 4 == 0) + + # Create additive mask with padding pattern: batch 0 has 28 valid past, batch 1 full + past_seqlens = torch.tensor([28, 31], dtype=torch.int32, device=device) + total_seqlens = past_seqlens + config.kv_sequence_length + + attn_mask = create_additive_mask_from_seqlens( + seqlens=total_seqlens, + total_seq_len=total_seq_len, 
+ mask_dims=4, + q_seq_len=1, + num_heads=8, + device=device, + dtype=torch_type, + ) + + # Zero padded past positions for batch 0 + past_k[0, :, 28:, :] = 0 + past_v[0, :, 28:, :] = 0 + + # Reference: concat past + new, then compute attention + new_k_bnsh = new_k.transpose(1, 2) + new_v_bnsh = new_v.transpose(1, 2) + full_k_bnsh = torch.cat([past_k, new_k_bnsh], dim=2) + full_v_bnsh = torch.cat([past_v, new_v_bnsh], dim=2) + full_k_bsnh = full_k_bnsh.transpose(1, 2) + full_v_bsnh = full_v_bnsh.transpose(1, 2) + + # Expand 4D mask to reference attn_bias [batch, heads, q_seq, total_seq] + attn_bias_ref = attn_mask + out_ref, _ = attention_ref(q=q, k=full_k_bsnh, v=full_v_bsnh, attn_bias=attn_bias_ref, causal=False) + + # ORT path + out_ort, present_k, present_v = attention_past_func( + q=q, + past_k=past_k, + past_v=past_v, + new_k=new_k, + new_v=new_v, + config=config, + attn_mask=attn_mask, + ep="CUDAExecutionProvider", + device=device, + ort_type=TensorProto.FLOAT16, + ) + + out_ort = out_ort.reshape(2, 1, 8, 128) + + # --- Verify present_k/v match concatenated reference --- + full_k_ref_np = full_k_bnsh.float().detach().cpu().numpy() + full_v_ref_np = full_v_bnsh.float().detach().cpu().numpy() + present_k_np = present_k.float().detach().cpu().numpy() + present_v_np = present_v.float().detach().cpu().numpy() + + print_diff_statistics(torch.tensor(present_k_np - full_k_ref_np), "present_k") + numpy.testing.assert_allclose(present_k_np, full_k_ref_np, rtol=rtol["fp16"], atol=atol["fp16"]) + print_diff_statistics(torch.tensor(present_v_np - full_v_ref_np), "present_v") + numpy.testing.assert_allclose(present_v_np, full_v_ref_np, rtol=rtol["fp16"], atol=atol["fp16"]) + + # --- Verify output --- + out_np = out_ort.float().detach().cpu().numpy() + out_ref_np = out_ref.float().detach().cpu().numpy() + print_diff_statistics(torch.tensor(out_np - out_ref_np), "out") + numpy.testing.assert_allclose(out_np, out_ref_np, rtol=rtol["fp16"], atol=atol["fp16"]) + + 
+@unittest.skipIf(not has_cuda_device(53), "Memory Efficient Attention is not available, skipping tests.") +@patch.dict(os.environ, {"ORT_DISABLE_FLASH_ATTENTION": "1"}) +class TestONNXAttentionMEAGQASoftcap(unittest.TestCase): + """ + Test softcap support for GQA via the Memory Efficient Attention path. + + Disables Flash Attention to force MEA. Verifies softcap with and without + attention mask for GQA (kv_num_heads != q_num_heads). + + MEA alignment requirement: total_seq % 4 == 0 when attn_mask is present. + """ + + def test_mea_gqa_softcap_with_mask_prompt_fp16(self): + """MEA GQA softcap + causal mask, prompt phase, fp16.""" + config = AttentionConfig( + batch_size=2, + q_sequence_length=8, + kv_sequence_length=8, # total_seq=8, divisible by 4 + q_num_heads=8, + kv_num_heads=4, + head_size=64, + is_causal=1, + softcap=50.0, + has_attn_mask=True, + ) + parity_check_gqa_prompt( + config=config, + ep="CUDAExecutionProvider", + device="cuda", + torch_type=torch.float16, + ort_type=TensorProto.FLOAT16, + causal=True, + rtol=rtol["fp16"], + atol=atol["fp16"], + ) + + def test_mea_gqa_softcap_no_mask_prompt_fp16(self): + """MEA GQA softcap without explicit mask, prompt phase, fp16.""" + config = AttentionConfig( + batch_size=2, + q_sequence_length=8, + kv_sequence_length=8, + q_num_heads=8, + kv_num_heads=4, + head_size=64, + is_causal=1, + softcap=50.0, + ) + parity_check_gqa_prompt( + config=config, + ep="CUDAExecutionProvider", + device="cuda", + torch_type=torch.float16, + ort_type=TensorProto.FLOAT16, + causal=True, + rtol=rtol["fp16"], + atol=atol["fp16"], + ) + + def test_mea_gqa_softcap_with_mask_decode_fp16(self): + """MEA GQA softcap + causal mask, decode phase, fp16.""" + config = AttentionConfig( + batch_size=2, + q_sequence_length=1, + kv_sequence_length=1, + past_kv_sequence_length=31, # total_seq=32, divisible by 4 + q_num_heads=8, + kv_num_heads=4, + head_size=64, + is_causal=1, + softcap=50.0, + has_attn_mask=True, + ) + parity_check_gqa_past( + 
config=config, + ep="CUDAExecutionProvider", + device="cuda", + torch_type=torch.float16, + ort_type=TensorProto.FLOAT16, + causal=True, + rtol=rtol["fp16"], + atol=atol["fp16"], + ) + + def test_mea_gqa_softcap_mask_ordering_no_leakage_prompt_fp16(self): + """Guard test: verify MEA GQA softcap + mask ordering prevents attention leakage. + + Same poison-value technique as the MHA ordering test, but with GQA + (kv_num_heads != q_num_heads) forced to MEA path. + """ + batch_size = 1 + q_seq = 4 + kv_seq = 8 # divisible by 4 for MEA alignment + q_num_heads = 4 + kv_num_heads = 2 + head_size = 64 + softcap_val = 2.0 + valid_kv_len = 4 + + config = AttentionConfig( + batch_size=batch_size, + q_sequence_length=q_seq, + kv_sequence_length=kv_seq, + q_num_heads=q_num_heads, + kv_num_heads=kv_num_heads, + head_size=head_size, + is_causal=0, + softcap=softcap_val, + has_attn_mask=True, + attn_mask_dims=4, + attn_mask_type="additive", + ) + + torch.manual_seed(42) + device = "cuda" + torch_type = torch.float16 + + q = torch.randn(batch_size, q_seq, q_num_heads, head_size, dtype=torch_type, device=device) * 0.2 + k = torch.randn(batch_size, kv_seq, kv_num_heads, head_size, dtype=torch_type, device=device) * 0.2 + v = torch.randn(batch_size, kv_seq, kv_num_heads, head_size, dtype=torch_type, device=device) * 0.2 + + # Place poison values in V at masked positions + poison_value = 1000.0 + v[:, valid_kv_len:, :, :] = poison_value + + # Create additive mask: 0.0 for valid, -inf for masked + # 4D mask: [batch, q_num_heads, q_seq, kv_seq] + attn_mask = torch.zeros(batch_size, q_num_heads, q_seq, kv_seq, dtype=torch_type, device=device) + attn_mask[:, :, :, valid_kv_len:] = float("-inf") + + out, _, _ = attention_prompt_func( + q=q, + k=k, + v=v, + config=config, + attn_mask=attn_mask, + ep="CUDAExecutionProvider", + device=device, + ort_type=TensorProto.FLOAT16, + ) + + out_np = out.to(torch.float32).detach().cpu().numpy().flatten() + max_abs = numpy.max(numpy.abs(out_np)) + 
self.assertLess( + max_abs, + 50.0, + f"MEA GQA attention leakage detected: max |output| = {max_abs:.1f}. " + f"This likely means MEA applies softcap AFTER mask (wrong ordering). " + f"Correct ordering: QK → softcap → mask → softmax (per onnx/onnx#7865).", + ) + + # Also verify against reference + out_ref, _ = attention_ref(q=q, k=k, v=v, attn_bias=attn_mask, softcap=softcap_val) + out_ref_np = out_ref.to(torch.float32).detach().cpu().numpy() + out_reshaped = torch.reshape(out, (batch_size, q_seq, q_num_heads, head_size)) + out_reshaped_np = out_reshaped.to(torch.float32).detach().cpu().numpy() + numpy.testing.assert_allclose(out_reshaped_np, out_ref_np, rtol=0.02, atol=0.02) + + +@unittest.skipIf(not has_flash_attention(), "Flash Attention is not available, skipping Flash GQA softcap tests.") +@patch.dict(os.environ, {"ORT_DISABLE_FLASH_ATTENTION": "0"}) +class TestONNXAttentionFlashGQASoftcap(unittest.TestCase): + """Test softcap support for GQA via the Flash Attention path. + + Flash does NOT accept explicit attn_mask for GQA — uses nonpad_kv_seqlen + (padding mask) instead. Tests verify softcap works correctly through Flash + with and without padding mask. + + Requires SM80+ (Flash Attention hardware requirement). 
+ """ + + def test_flash_gqa_softcap_with_padding_mask_prompt_fp16(self): + """Flash GQA softcap + padding mask (nonpad_kv_seqlen), prompt phase, fp16.""" + config = AttentionConfig( + batch_size=2, + q_sequence_length=8, + kv_sequence_length=8, + q_num_heads=8, + kv_num_heads=4, + head_size=64, + is_causal=1, + softcap=50.0, + has_nonpad_kv_seqlen=True, + ) + parity_check_gqa_prompt( + config=config, + ep="CUDAExecutionProvider", + device="cuda", + torch_type=torch.float16, + ort_type=TensorProto.FLOAT16, + causal=True, + rtol=rtol["fp16"], + atol=atol["fp16"], + ) + + def test_flash_gqa_softcap_no_mask_prompt_fp16(self): + """Flash GQA softcap without any mask, prompt phase, fp16.""" + config = AttentionConfig( + batch_size=2, + q_sequence_length=8, + kv_sequence_length=8, + q_num_heads=8, + kv_num_heads=4, + head_size=64, + is_causal=1, + softcap=50.0, + ) + parity_check_gqa_prompt( + config=config, + ep="CUDAExecutionProvider", + device="cuda", + torch_type=torch.float16, + ort_type=TensorProto.FLOAT16, + causal=True, + rtol=rtol["fp16"], + atol=atol["fp16"], + ) + + def test_flash_gqa_softcap_no_mask_decode_fp16(self): + """Flash GQA softcap, decode phase (past KV), fp16.""" + config = AttentionConfig( + batch_size=2, + q_sequence_length=1, + kv_sequence_length=1, + past_kv_sequence_length=31, + q_num_heads=8, + kv_num_heads=4, + head_size=64, + is_causal=1, + softcap=50.0, + ) + parity_check_gqa_past( + config=config, + ep="CUDAExecutionProvider", + device="cuda", + torch_type=torch.float16, + ort_type=TensorProto.FLOAT16, + causal=True, + rtol=rtol["fp16"], + atol=atol["fp16"], + ) + + if __name__ == "__main__": unittest.main() diff --git a/onnxruntime/test/python/transformers/test_onnx_attention/test_mha.py b/onnxruntime/test/python/transformers/test_onnx_attention/test_mha.py index abe180ee35787..a488e11e39d20 100644 --- a/onnxruntime/test/python/transformers/test_onnx_attention/test_mha.py +++ 
b/onnxruntime/test/python/transformers/test_onnx_attention/test_mha.py @@ -99,9 +99,14 @@ def parity_check_mha_prompt( attn_mask = None attn_bias_ref = None if config.has_attn_mask: - # Create additive mask (0 for valid, -inf for masked) - # For prompt without padding, create a causal-style or zero mask - seqlens = torch.full((config.batch_size,), config.kv_sequence_length, dtype=torch.int32, device=device) + # When softcap is present, use partial seqlens so the mask has both valid and masked + # positions — otherwise the all-zero mask can't detect softcap→bias ordering bugs. + # For non-softcap tests, use full seqlens (existing behavior). + if config.softcap > 0: + mask_valid_len = max(1, config.kv_sequence_length * 3 // 4) + else: + mask_valid_len = config.kv_sequence_length + seqlens = torch.full((config.batch_size,), mask_valid_len, dtype=torch.int32, device=device) attn_mask = create_additive_mask_from_seqlens( seqlens=seqlens, total_seq_len=config.kv_sequence_length, @@ -127,6 +132,7 @@ def parity_check_mha_prompt( v=v, attn_bias=attn_bias_ref, causal=causal, + softcap=config.softcap, ) out_ref_np = out_ref.to(torch.float32).detach().cpu().numpy() @@ -146,9 +152,15 @@ def parity_check_mha_prompt( if i == 0: first_out = out.clone() else: - torch.testing.assert_close(out, first_out, rtol=0, atol=0, msg="Output mismatch between two runs") + # FP16/BF16 GPU kernels may produce bit-level non-determinism across runs. 
+ det_atol = 0 if torch_type == torch.float32 else 1e-3 + det_rtol = 0 if torch_type == torch.float32 else 1e-3 + torch.testing.assert_close( + out, first_out, rtol=det_rtol, atol=det_atol, msg="Output mismatch between two runs" + ) - out = torch.reshape(out, (config.batch_size, config.q_sequence_length, config.q_num_heads, config.head_size)) + effective_v_head_size = config.v_head_size or config.head_size + out = torch.reshape(out, (config.batch_size, config.q_sequence_length, config.q_num_heads, effective_v_head_size)) out_np = out.to(torch.float32).detach().cpu().numpy() # --- Comparison --- @@ -224,6 +236,65 @@ def parity_check_mha_past( ) new_v = torch.randn_like(new_k) * std + # Create attention mask if config requires one + total_seq_len = config.past_kv_sequence_length + config.kv_sequence_length + attn_mask = None + attn_bias_ref = None + if config.has_attn_mask: + # When softcap is present, use partial seqlens so the mask has both valid and masked + # positions — otherwise the all-zero mask can't detect softcap→bias ordering bugs. + # For non-softcap tests, use full seqlens (existing behavior). 
+ if config.softcap > 0: + mask_valid_len = max(1, total_seq_len * 3 // 4) + else: + mask_valid_len = total_seq_len + seqlens = torch.full((config.batch_size,), mask_valid_len, dtype=torch.int32, device=device) + + if config.attn_mask_type == "bool": + # Create boolean mask for ORT (True=attend, False=mask) + arange = torch.arange(total_seq_len, device=device) + if config.attn_mask_dims == 2: + mask_1d = arange < seqlens[0] + attn_mask = mask_1d.unsqueeze(0).expand(config.q_sequence_length, -1).contiguous() + else: + attn_mask = create_boolean_mask_from_seqlens( + seqlens=seqlens, + total_seq_len=total_seq_len, + mask_dims=config.attn_mask_dims, + q_seq_len=config.q_sequence_length, + num_heads=config.q_num_heads, + device=device, + ) + # Create additive bias for PyTorch reference path + attn_bias_ref = create_additive_mask_from_seqlens( + seqlens=seqlens, + total_seq_len=total_seq_len, + mask_dims=4, + q_seq_len=config.q_sequence_length, + num_heads=config.q_num_heads, + device=device, + dtype=torch_type, + ) + else: + # Additive mask: same tensor for both ORT and reference + attn_mask = create_additive_mask_from_seqlens( + seqlens=seqlens, + total_seq_len=total_seq_len, + mask_dims=config.attn_mask_dims, + q_seq_len=config.q_sequence_length, + num_heads=config.q_num_heads, + device=device, + dtype=torch_type, + ) + if config.attn_mask_dims == 2: + attn_bias_ref = ( + attn_mask.unsqueeze(0).unsqueeze(0).expand(config.batch_size, config.q_num_heads, -1, -1) + ) + elif config.attn_mask_dims == 3: + attn_bias_ref = attn_mask.unsqueeze(0).expand(config.batch_size, -1, -1, -1) + else: + attn_bias_ref = attn_mask + # --- PyTorch Reference Path --- new_k_bnsh = new_k.transpose(1, 2) new_v_bnsh = new_v.transpose(1, 2) @@ -236,7 +307,9 @@ def parity_check_mha_past( q=q, k=full_k_bsnh, v=full_v_bsnh, + attn_bias=attn_bias_ref, causal=causal, + softcap=config.softcap, ) out_ref_np = out_ref.to(torch.float32).detach().cpu().numpy() @@ -250,7 +323,7 @@ def 
parity_check_mha_past( new_k=new_k, new_v=new_v, config=config, - attn_mask=None, + attn_mask=attn_mask, ep=ep, device=device, ort_type=ort_type, @@ -258,9 +331,15 @@ def parity_check_mha_past( if i == 0: first_out = out.clone() else: - torch.testing.assert_close(out, first_out, rtol=0, atol=0, msg="Output mismatch between two runs") + # FP16/BF16 GPU kernels may produce bit-level non-determinism across runs. + det_atol = 0 if torch_type == torch.float32 else 1e-3 + det_rtol = 0 if torch_type == torch.float32 else 1e-3 + torch.testing.assert_close( + out, first_out, rtol=det_rtol, atol=det_atol, msg="Output mismatch between two runs" + ) - out = torch.reshape(out, (config.batch_size, config.q_sequence_length, config.q_num_heads, config.head_size)) + effective_v_head_size = config.v_head_size or config.head_size + out = torch.reshape(out, (config.batch_size, config.q_sequence_length, config.q_num_heads, effective_v_head_size)) out_np = out.to(torch.float32).detach().cpu().numpy() # --- Comparison --- @@ -367,10 +446,11 @@ def parity_check_mha_prompt_with_attn_bias( v=v, attn_bias=attn_bias_ref, causal=config.is_causal == 1, + softcap=config.softcap, ) # --- ONNX Runtime Path --- - out, present_k, present_v = attention_prompt_func( + out, _present_k, _present_v = attention_prompt_func( q=q, k=k, v=v, @@ -698,10 +778,11 @@ def parity_check_mha_prompt_with_bool_mask( v=v, key_padding_mask=key_padding_mask, causal=config.is_causal == 1, + softcap=config.softcap, ) # --- ONNX Runtime Path --- - out, present_k, present_v = attention_prompt_func( + out, _present_k, _present_v = attention_prompt_func( q=q, k=k, v=v, @@ -866,6 +947,110 @@ def test_mha_past_fp32(self, name, config): ) +@unittest.skipIf(not has_cuda_device(53), "Memory Efficient Attention is not available, skipping tests.") +@patch.dict(os.environ, {"ORT_DISABLE_FLASH_ATTENTION": "1"}) +class TestONNXAttentionMHAPastMEA(unittest.TestCase): + """Test ONNX Attention op MHA path — decoding with KV cache via 
Memory Efficient Attention. + + Explicitly forces MEA by disabling Flash Attention. This verifies that the + MEA decode path works correctly for MHA (kv_num_heads == q_num_heads). + """ + + @parameterized.expand(mha_past_test_cases()) + def test_mha_past_mea(self, name, config): + parity_check_mha_past( + config=config, + ep="CUDAExecutionProvider", + device="cuda", + torch_type=torch.float16, + ort_type=TensorProto.FLOAT16, + causal=True, + rtol=rtol["fp16"], + atol=atol["fp16"], + ) + + +@unittest.skipIf(not has_cuda_device(53), "Memory Efficient Attention is not available, skipping tests.") +@patch.dict(os.environ, {"ORT_DISABLE_FLASH_ATTENTION": "1"}) +class TestONNXAttentionMHAPastMEAFP32(unittest.TestCase): + """Test MHA decode via MEA with fp32 dtype.""" + + @parameterized.expand(mha_past_test_cases()) + def test_mha_past_mea_fp32(self, name, config): + config.kv_cache_type = "float32" + parity_check_mha_past( + config=config, + ep="CUDAExecutionProvider", + device="cuda", + torch_type=torch.float32, + ort_type=TensorProto.FLOAT, + causal=True, + rtol=rtol["fp32"], + atol=atol["fp32"], + ) + + +@unittest.skipIf(not has_cuda_device(53), "Memory Efficient Attention is not available, skipping tests.") +@patch.dict(os.environ, {"ORT_DISABLE_FLASH_ATTENTION": "1"}) +class TestONNXAttentionMHAPastMEABoolMask(unittest.TestCase): + """Test MHA decode via MEA with boolean attention mask (converted to additive bias).""" + + def test_mha_past_bool_mask_mea(self): + config = AttentionConfig( + batch_size=2, + q_sequence_length=1, + kv_sequence_length=1, + past_kv_sequence_length=31, # 31+1=32, divisible by 4 (CUTLASS bias alignment) + q_num_heads=4, + kv_num_heads=4, + head_size=64, + is_causal=1, + has_attn_mask=True, + attn_mask_dims=2, + ) + parity_check_mha_past( + config=config, + ep="CUDAExecutionProvider", + device="cuda", + torch_type=torch.float16, + ort_type=TensorProto.FLOAT16, + causal=True, + rtol=rtol["fp16"], + atol=atol["fp16"], + ) + + 
+@unittest.skipIf(not has_cuda_device(53), "Memory Efficient Attention is not available, skipping tests.") +@patch.dict(os.environ, {"ORT_DISABLE_FLASH_ATTENTION": "1"}) +class TestONNXAttentionMHAPastMEAFloatMask(unittest.TestCase): + """Test MHA decode via MEA with float additive attention mask.""" + + def test_mha_past_float_mask_4d_mea(self): + config = AttentionConfig( + batch_size=2, + q_sequence_length=1, + kv_sequence_length=1, + past_kv_sequence_length=31, # 31+1=32, divisible by 4 (CUTLASS bias alignment) + q_num_heads=4, + kv_num_heads=4, + head_size=64, + is_causal=1, + has_attn_mask=True, + attn_mask_dims=4, + attn_mask_type="additive", + ) + parity_check_mha_past( + config=config, + ep="CUDAExecutionProvider", + device="cuda", + torch_type=torch.float16, + ort_type=TensorProto.FLOAT16, + causal=True, + rtol=rtol["fp16"], + atol=atol["fp16"], + ) + + @unittest.skipIf(not has_cuda_device(53), "CUDA device not available, skipping MHA tests.") class TestONNXAttentionMHAAttnBias(unittest.TestCase): """ @@ -998,7 +1183,7 @@ def parity_check_mha_prompt_with_nonpad_kv_seqlen( # ORT path: use nonpad_kv_seqlen (int64 tensor) nonpad_kv_seqlen_tensor = nonpad_seqlens.to(torch.int64).to(device) - out, present_k, present_v = attention_prompt_func( + out, _present_k, _present_v = attention_prompt_func( q=q, k=k, v=v, @@ -1249,116 +1434,693 @@ def test_mha_unfused_fp16(self, name, config): atol=atol["fp16"], ) - -# ################################################################################################# -# Broadcast Mask (1,1,q,kv) Tests -# ################################################################################################# + def test_mha_unfused_decode_fp32(self): + """Test unfused decode with fp32 (both Flash and MEA disabled).""" + config = AttentionConfig( + batch_size=2, + q_sequence_length=1, + kv_sequence_length=1, + past_kv_sequence_length=32, + q_num_heads=4, + kv_num_heads=4, + head_size=64, + is_causal=1, + kv_cache_type="float32", + 
attn_mask_type="additive", + ) + parity_check_mha_past( + config=config, + ep="CUDAExecutionProvider", + device="cuda", + torch_type=torch.float32, + ort_type=TensorProto.FLOAT, + causal=True, + rtol=rtol["fp32"], + atol=atol["fp32"], + ) -@unittest.skipIf(not has_cuda_device(53), "CUDA device not available, skipping broadcast mask tests.") -class TestONNXAttentionMHABroadcastMask(unittest.TestCase): +@unittest.skipIf(not has_cuda_device(53), "CUDA device not available, skipping unfused softcap tests.") +@patch.dict(os.environ, {"ORT_DISABLE_FLASH_ATTENTION": "1", "ORT_DISABLE_MEMORY_EFFICIENT_ATTENTION": "1"}) +class TestONNXAttentionMHAUnfusedSoftcap(unittest.TestCase): """ - Test attention with a (1,1,q_seq,kv_seq) mask that broadcasts across batch and heads. + Test softcap support in the unfused attention kernel. - This is a 4D mask with dim_0=1 (batch) and dim_1=1 (heads), verifying that - the broadcast_attn_bias_dim_0 and broadcast_attn_bias_dim_1 flags work correctly. + Disables Flash and MEA to force the unfused path. Verifies that + softcap * tanh(score / softcap) is correctly applied to attention logits + before softmax, matching the reference implementation. 
""" - def test_mha_broadcast_mask_additive(self): - """Test broadcast additive mask (1,1,q,kv) with MHA on CUDA.""" + def test_unfused_softcap_prompt_fp16(self): + """Test softcap on unfused path during prompt (fp16).""" config = AttentionConfig( batch_size=2, - q_sequence_length=16, - kv_sequence_length=16, - q_num_heads=8, - kv_num_heads=8, - head_size=128, - is_causal=0, - has_attn_mask=True, - attn_mask_dims=4, + q_sequence_length=8, + kv_sequence_length=8, + q_num_heads=4, + kv_num_heads=4, + head_size=64, + is_causal=1, + softcap=2.0, attn_mask_type="additive", - broadcast_mask_batch=True, - broadcast_mask_heads=True, ) - - torch.manual_seed(0) - device = "cuda" - torch_type = torch.float16 - - q = torch.randn(2, 16, 8, 128, device=device, dtype=torch_type) * 0.2 - k = torch.randn(2, 16, 8, 128, device=device, dtype=torch_type) * 0.2 - v = torch.randn_like(k) * 0.2 - - # Create (1,1,q,kv) additive mask: lower-triangular causal pattern - mask_filter = float(torch.finfo(torch_type).min) - mask_2d = torch.zeros(16, 16, device=device, dtype=torch_type) - for i in range(16): - mask_2d[i, i + 1 :] = mask_filter - attn_mask = mask_2d.unsqueeze(0).unsqueeze(0) # (1, 1, 16, 16) - - # Reference: expand to full (B, H, Q, K) - attn_bias_ref = attn_mask.expand(2, 8, -1, -1).contiguous() - out_ref, _ = attention_ref(q=q, k=k, v=v, attn_bias=attn_bias_ref, causal=False) - - # ORT path - out_ort, _, _ = attention_prompt_func( - q=q, - k=k, - v=v, + parity_check_mha_prompt( config=config, - attn_mask=attn_mask, ep="CUDAExecutionProvider", - device=device, + device="cuda", + torch_type=torch.float16, ort_type=TensorProto.FLOAT16, + causal=True, + rtol=rtol["fp16"], + atol=atol["fp16"], ) - out_ort = out_ort.reshape(2, 16, 8, 128) - - out_np = out_ort.float().detach().cpu().numpy() - out_ref_np = out_ref.float().detach().cpu().numpy() - numpy.testing.assert_allclose(out_np, out_ref_np, rtol=rtol["fp16"], atol=atol["fp16"]) - - -# 
################################################################################################# -# 2D Mask Broadcast Regression Test -# ################################################################################################# + def test_unfused_softcap_decode_fp16(self): + """Test softcap on unfused path during decode (fp16).""" + config = AttentionConfig( + batch_size=2, + q_sequence_length=1, + kv_sequence_length=1, + past_kv_sequence_length=32, + q_num_heads=4, + kv_num_heads=4, + head_size=64, + is_causal=1, + softcap=2.0, + attn_mask_type="additive", + ) + parity_check_mha_past( + config=config, + ep="CUDAExecutionProvider", + device="cuda", + torch_type=torch.float16, + ort_type=TensorProto.FLOAT16, + causal=True, + rtol=rtol["fp16"], + atol=atol["fp16"], + ) -@unittest.skipIf(not has_cuda_device(53), "CUDA device not available, skipping 2D mask broadcast tests.") -class TestONNXAttentionMHA2DMaskBroadcast(unittest.TestCase): - """ - Regression test for 2D mask [q_seq, total_seq] broadcast correctness. - - Per ONNX spec, a 2D attention mask has shape [q_seq, total_seq] and broadcasts - over batch and heads. This test uses batch_size > q_seq with a non-uniform - mask (different values per row) to verify correct broadcast behavior. - - The old bug indexed the 2D mask by batch index instead of query position, - causing OOB reads when batch_size > q_seq. 
- """ + def test_unfused_softcap_prompt_fp32(self): + """Test softcap on unfused path during prompt (fp32).""" + config = AttentionConfig( + batch_size=2, + q_sequence_length=8, + kv_sequence_length=8, + q_num_heads=4, + kv_num_heads=4, + head_size=64, + is_causal=1, + softcap=2.0, + kv_cache_type="float32", + attn_mask_type="additive", + ) + parity_check_mha_prompt( + config=config, + ep="CUDAExecutionProvider", + device="cuda", + torch_type=torch.float32, + ort_type=TensorProto.FLOAT, + causal=True, + rtol=rtol["fp32"], + atol=atol["fp32"], + ) - def test_2d_additive_mask_batch_gt_qseq(self): - """2D additive mask [q_seq=2, total_seq=8] with batch=4 — would OOB on old code.""" + def test_unfused_softcap_with_mask_prompt_fp16(self): + """Test softcap + float mask on unfused path — verifies spec-correct ordering (softcap→mask→softmax).""" config = AttentionConfig( - batch_size=4, - q_sequence_length=2, + batch_size=2, + q_sequence_length=8, kv_sequence_length=8, q_num_heads=4, kv_num_heads=4, head_size=64, is_causal=0, + softcap=2.0, has_attn_mask=True, - attn_mask_dims=2, + attn_mask_dims=4, attn_mask_type="additive", ) + parity_check_mha_prompt( + config=config, + ep="CUDAExecutionProvider", + device="cuda", + torch_type=torch.float16, + ort_type=TensorProto.FLOAT16, + causal=False, + rtol=rtol["fp16"], + atol=atol["fp16"], + ) - torch.manual_seed(42) - device = "cuda" - torch_type = torch.float16 - mask_filter_value = torch.finfo(torch_type).min - - q = ( - torch.randn( - config.batch_size, + def test_unfused_softcap_with_mask_decode_fp16(self): + """Test softcap + float mask on unfused decode — verifies spec-correct ordering.""" + config = AttentionConfig( + batch_size=2, + q_sequence_length=1, + kv_sequence_length=1, + past_kv_sequence_length=31, # 31+1=32, divisible by 4 + q_num_heads=4, + kv_num_heads=4, + head_size=64, + is_causal=1, + softcap=2.0, + has_attn_mask=True, + attn_mask_dims=4, + attn_mask_type="additive", + ) + parity_check_mha_past( + 
config=config, + ep="CUDAExecutionProvider", + device="cuda", + torch_type=torch.float16, + ort_type=TensorProto.FLOAT16, + causal=True, + rtol=rtol["fp16"], + atol=atol["fp16"], + ) + + # --- Partial masking: fp32 variants --- + + def test_unfused_softcap_with_mask_prompt_fp32(self): + """Test softcap + additive mask on unfused prompt (fp32). + + The helper auto-creates a partial mask (3/4 valid positions) when softcap > 0, + ensuring the mask has both 0.0 and -inf values to exercise the softcap→bias ordering. + """ + config = AttentionConfig( + batch_size=2, + q_sequence_length=8, + kv_sequence_length=8, + q_num_heads=4, + kv_num_heads=4, + head_size=64, + is_causal=0, + softcap=2.0, + has_attn_mask=True, + attn_mask_dims=4, + attn_mask_type="additive", + kv_cache_type="float32", + ) + parity_check_mha_prompt( + config=config, + ep="CUDAExecutionProvider", + device="cuda", + torch_type=torch.float32, + ort_type=TensorProto.FLOAT, + causal=False, + rtol=rtol["fp32"], + atol=atol["fp32"], + ) + + def test_unfused_softcap_with_mask_decode_fp32(self): + """Test softcap + additive mask on unfused decode (fp32). + + Decode with past KV cache: total_seq=32, ~24 valid positions, 8 masked. + """ + config = AttentionConfig( + batch_size=2, + q_sequence_length=1, + kv_sequence_length=1, + past_kv_sequence_length=31, + q_num_heads=4, + kv_num_heads=4, + head_size=64, + is_causal=1, + softcap=2.0, + has_attn_mask=True, + attn_mask_dims=4, + attn_mask_type="additive", + kv_cache_type="float32", + ) + parity_check_mha_past( + config=config, + ep="CUDAExecutionProvider", + device="cuda", + torch_type=torch.float32, + ort_type=TensorProto.FLOAT, + causal=True, + rtol=rtol["fp32"], + atol=atol["fp32"], + ) + + # --- Partial masking: different mask dimensionalities --- + + def test_unfused_softcap_with_mask_2d_prompt_fp16(self): + """Test softcap + 2D additive mask on unfused prompt. + + A 2D mask [q_seq, kv_seq] broadcasts across batch and heads. 
+ This tests the 2D mask indexing path in the unfused kernel. + """ + config = AttentionConfig( + batch_size=2, + q_sequence_length=8, + kv_sequence_length=8, + q_num_heads=4, + kv_num_heads=4, + head_size=64, + is_causal=0, + softcap=2.0, + has_attn_mask=True, + attn_mask_dims=2, + attn_mask_type="additive", + ) + parity_check_mha_prompt( + config=config, + ep="CUDAExecutionProvider", + device="cuda", + torch_type=torch.float16, + ort_type=TensorProto.FLOAT16, + causal=False, + rtol=rtol["fp16"], + atol=atol["fp16"], + ) + + def test_unfused_softcap_with_mask_3d_prompt_fp16(self): + """Test softcap + 3D additive mask on unfused prompt. + + A 3D mask [heads, q_seq, kv_seq] broadcasts across batch dimension. + This tests the 3D mask broadcast path which has its own handling branch. + """ + config = AttentionConfig( + batch_size=2, + q_sequence_length=8, + kv_sequence_length=8, + q_num_heads=4, + kv_num_heads=4, + head_size=64, + is_causal=0, + softcap=2.0, + has_attn_mask=True, + attn_mask_dims=3, + attn_mask_type="additive", + ) + parity_check_mha_prompt( + config=config, + ep="CUDAExecutionProvider", + device="cuda", + torch_type=torch.float16, + ort_type=TensorProto.FLOAT16, + causal=False, + rtol=rtol["fp16"], + atol=atol["fp16"], + ) + + # --- Partial masking: larger sequence (different absolute mask boundary) --- + + def test_unfused_softcap_with_mask_longer_seq_prompt_fp16(self): + """Test softcap + mask with a longer sequence (kv_seq=16). + + With kv_seq=16, mask_valid_len=12 (3/4). This exercises a different absolute + mask boundary compared to the kv_seq=8 tests (valid_len=6) and provides + a wider range of softcapped logit values interacting with the mask. 
+ """ + config = AttentionConfig( + batch_size=2, + q_sequence_length=16, + kv_sequence_length=16, + q_num_heads=4, + kv_num_heads=4, + head_size=64, + is_causal=0, + softcap=2.0, + has_attn_mask=True, + attn_mask_dims=4, + attn_mask_type="additive", + ) + parity_check_mha_prompt( + config=config, + ep="CUDAExecutionProvider", + device="cuda", + torch_type=torch.float16, + ort_type=TensorProto.FLOAT16, + causal=False, + rtol=rtol["fp16"], + atol=atol["fp16"], + ) + + def test_softcap_mask_ordering_no_leakage_prompt(self): + """Guard test: verify softcap + mask ordering prevents attention leakage. + + This test PROVES the ordering matters and would FAIL if someone reverts + to the wrong ordering (mask before softcap). + + Setup: Create a mask where some KV positions are -inf (masked). Place + a distinctive 'poison' value (1000.0) in V at masked positions. With + correct ordering (softcap → mask → softmax), masked positions get + -inf after bias addition → zero attention → output uncontaminated. + With wrong ordering (mask → softcap → softmax), softcap(-inf) = -softcap + (finite) → nonzero attention → output contaminated by poison values. 
+ """ + batch_size = 1 + q_seq = 4 + kv_seq = 8 + num_heads = 2 + head_size = 64 + softcap_val = 2.0 + # Only the first 4 KV positions are valid; last 4 are masked (-inf) + valid_kv_len = 4 + + config = AttentionConfig( + batch_size=batch_size, + q_sequence_length=q_seq, + kv_sequence_length=kv_seq, + q_num_heads=num_heads, + kv_num_heads=num_heads, + head_size=head_size, + is_causal=0, + softcap=softcap_val, + has_attn_mask=True, + attn_mask_dims=4, + attn_mask_type="additive", + ) + + torch.manual_seed(42) + device = "cuda" + torch_type = torch.float32 + + q = torch.randn(batch_size, q_seq, num_heads, head_size, dtype=torch_type, device=device) * 0.2 + k = torch.randn(batch_size, kv_seq, num_heads, head_size, dtype=torch_type, device=device) * 0.2 + v = torch.randn(batch_size, kv_seq, num_heads, head_size, dtype=torch_type, device=device) * 0.2 + + # Place poison values in V at masked positions + poison_value = 1000.0 + v[:, valid_kv_len:, :, :] = poison_value + + # Create additive mask: 0.0 for valid, -inf for masked + attn_mask = torch.zeros(batch_size, num_heads, q_seq, kv_seq, dtype=torch_type, device=device) + attn_mask[:, :, :, valid_kv_len:] = float("-inf") + + # Run ONNX Runtime + out, _, _ = attention_prompt_func( + q=q, + k=k, + v=v, + config=config, + attn_mask=attn_mask, + ep="CUDAExecutionProvider", + device=device, + ort_type=TensorProto.FLOAT, + ) + + out_np = out.to(torch.float32).detach().cpu().numpy().flatten() + + # If ordering is wrong, poison values leak into output producing extreme values. + # Valid output range with std=0.2 inputs and softcap=2.0 is roughly [-10, 10]. + # Any element > 50 indicates attention leakage to the poison=1000 positions. + max_abs = numpy.max(numpy.abs(out_np)) + self.assertLess( + max_abs, + 50.0, + f"Attention leakage detected: max |output| = {max_abs:.1f}. " + f"This likely means softcap is applied AFTER mask (wrong ordering). 
" + f"Correct ordering: QK → softcap → mask → softmax (per onnx/onnx#7865).", + ) + + # Also verify against reference + out_ref, _ = attention_ref(q=q, k=k, v=v, attn_bias=attn_mask, softcap=softcap_val) + out_ref_np = out_ref.to(torch.float32).detach().cpu().numpy() + out_reshaped = torch.reshape(out, (batch_size, q_seq, num_heads, head_size)) + out_reshaped_np = out_reshaped.to(torch.float32).detach().cpu().numpy() + numpy.testing.assert_allclose(out_reshaped_np, out_ref_np, rtol=0.01, atol=0.01) + + def test_softcap_mask_ordering_no_leakage_decode(self): + """Guard test for decode (past KV) path: softcap + mask ordering prevents leakage. + + Same poison-value technique as the prompt test, but exercises the decode + code path with past KV cache. Masked positions in the past cache should + receive zero attention with correct ordering. + """ + batch_size = 1 + q_seq = 1 # decode: single token + kv_seq = 1 + past_kv_seq = 15 + num_heads = 2 + head_size = 64 + softcap_val = 2.0 + total_kv_seq = past_kv_seq + kv_seq # 16 total + valid_kv_len = 8 # Only first 8 of 16 positions are valid + + config = AttentionConfig( + batch_size=batch_size, + q_sequence_length=q_seq, + kv_sequence_length=kv_seq, + past_kv_sequence_length=past_kv_seq, + q_num_heads=num_heads, + kv_num_heads=num_heads, + head_size=head_size, + is_causal=0, + softcap=softcap_val, + has_attn_mask=True, + attn_mask_dims=4, + attn_mask_type="additive", + ) + + torch.manual_seed(42) + device = "cuda" + torch_type = torch.float32 + + q = torch.randn(batch_size, q_seq, num_heads, head_size, dtype=torch_type, device=device) * 0.2 + k = torch.randn(batch_size, kv_seq, num_heads, head_size, dtype=torch_type, device=device) * 0.2 + v = torch.randn(batch_size, kv_seq, num_heads, head_size, dtype=torch_type, device=device) * 0.2 + + # Past KV with poison in masked positions + past_k = torch.randn(batch_size, num_heads, past_kv_seq, head_size, dtype=torch_type, device=device) * 0.2 + past_v = torch.randn(batch_size, 
num_heads, past_kv_seq, head_size, dtype=torch_type, device=device) * 0.2 + poison_value = 1000.0 + past_v[:, :, valid_kv_len:, :] = poison_value + + # Mask: 0.0 for first valid_kv_len positions, -inf for rest + attn_mask = torch.zeros(batch_size, num_heads, q_seq, total_kv_seq, dtype=torch_type, device=device) + attn_mask[:, :, :, valid_kv_len:] = float("-inf") + + # Run ONNX Runtime via attention_past_func + out, _, _ = attention_past_func( + q=q, + past_k=past_k, + past_v=past_v, + new_k=k, + new_v=v, + config=config, + attn_mask=attn_mask, + ep="CUDAExecutionProvider", + device=device, + ort_type=TensorProto.FLOAT, + ) + + out_np = out.to(torch.float32).detach().cpu().numpy().flatten() + + max_abs = numpy.max(numpy.abs(out_np)) + self.assertLess( + max_abs, + 50.0, + f"Attention leakage detected in decode path: max |output| = {max_abs:.1f}. " + f"Softcap must be applied BEFORE mask (per onnx/onnx#7865).", + ) + + +# ################################################################################################# +# Asymmetric Head Size Regression Test (MEA → unfused fallback) +# ################################################################################################# + + +@unittest.skipIf(not has_cuda_device(53), "CUDA device not available, skipping asymmetric head size tests.") +@patch.dict(os.environ, {"ORT_DISABLE_FLASH_ATTENTION": "1"}) +class TestONNXAttentionMHAAsymmetricHeadSize(unittest.TestCase): + """ + Regression test: MEA gracefully falls back to unfused when head_size != v_head_size + with past_key present (decode phase). + + Without the eligibility guard in ComputeInternal, this configuration would select + MEA which then crashes with ORT_ENFORCE because LaunchConcatNewToPastKV requires + head_size == v_head_size. The guard skips MEA and falls back to unfused attention. + + Uses MHA path (kv_num_heads == q_num_heads) because the GQA path has no unfused + fallback (returns NOT_IMPLEMENTED). 
+ """ + + def test_mha_past_asymmetric_v_head_size(self): + """Verify decode with head_size=128, v_head_size=96 doesn't crash (falls to unfused).""" + config = AttentionConfig( + batch_size=2, + q_sequence_length=1, + kv_sequence_length=1, + past_kv_sequence_length=32, + q_num_heads=4, + kv_num_heads=4, + head_size=128, + v_head_size=96, + is_causal=1, + attn_mask_type="additive", + ) + + torch.manual_seed(0) + device = "cuda" + torch_type = torch.float16 + # std=0.2 keeps values in a numerically stable range for fp16 attention + std = 0.2 + + q = torch.randn(2, 1, 4, 128, device=device, dtype=torch_type) * std + + # Past KV in BNSH: K uses head_size=128, V uses v_head_size=96 + past_k = torch.randn(2, 4, 32, 128, device=device, dtype=torch_type) * std + past_v = torch.randn(2, 4, 32, 96, device=device, dtype=torch_type) * std + + new_k = torch.randn(2, 1, 4, 128, device=device, dtype=torch_type) * std + new_v = torch.randn(2, 1, 4, 96, device=device, dtype=torch_type) * std + + # PyTorch reference: concat past + new, compute attention + new_k_bnsh = new_k.transpose(1, 2) + new_v_bnsh = new_v.transpose(1, 2) + full_k_bnsh = torch.cat([past_k, new_k_bnsh], dim=2) + full_v_bnsh = torch.cat([past_v, new_v_bnsh], dim=2) + full_k_bsnh = full_k_bnsh.transpose(1, 2) + full_v_bsnh = full_v_bnsh.transpose(1, 2) + + out_ref, _ = attention_ref(q=q, k=full_k_bsnh, v=full_v_bsnh, causal=True) + + # ORT path — should fall back to unfused (not crash in MEA) + out_ort, present_k, present_v = attention_past_func( + q=q, + past_k=past_k, + past_v=past_v, + new_k=new_k, + new_v=new_v, + config=config, + attn_mask=None, + ep="CUDAExecutionProvider", + device=device, + ort_type=TensorProto.FLOAT16, + ) + + # Reshape output: [B, q_seq, q_num_heads * v_head_size] → [B, q_seq, q_num_heads, v_head_size] + out_ort = out_ort.reshape(2, 1, 4, 96) + + # Verify present_k and present_v + full_k_ref_np = full_k_bnsh.float().detach().cpu().numpy() + full_v_ref_np = 
full_v_bnsh.float().detach().cpu().numpy() + present_k_np = present_k.float().detach().cpu().numpy() + present_v_np = present_v.float().detach().cpu().numpy() + + print_diff_statistics(torch.tensor(present_k_np - full_k_ref_np), "present_k") + numpy.testing.assert_allclose(present_k_np, full_k_ref_np, rtol=rtol["fp16"], atol=atol["fp16"]) + print_diff_statistics(torch.tensor(present_v_np - full_v_ref_np), "present_v") + numpy.testing.assert_allclose(present_v_np, full_v_ref_np, rtol=rtol["fp16"], atol=atol["fp16"]) + + # Verify output + out_np = out_ort.float().detach().cpu().numpy() + out_ref_np = out_ref.float().detach().cpu().numpy() + print_diff_statistics(torch.tensor(out_np - out_ref_np), "out") + numpy.testing.assert_allclose(out_np, out_ref_np, rtol=rtol["fp16"], atol=atol["fp16"]) + + +# ################################################################################################# +# Broadcast Mask (1,1,q,kv) Tests +# ################################################################################################# + + +@unittest.skipIf(not has_cuda_device(53), "CUDA device not available, skipping broadcast mask tests.") +class TestONNXAttentionMHABroadcastMask(unittest.TestCase): + """ + Test attention with a (1,1,q_seq,kv_seq) mask that broadcasts across batch and heads. + + This is a 4D mask with dim_0=1 (batch) and dim_1=1 (heads), verifying that + the broadcast_attn_bias_dim_0 and broadcast_attn_bias_dim_1 flags work correctly. 
+ """ + + def test_mha_broadcast_mask_additive(self): + """Test broadcast additive mask (1,1,q,kv) with MHA on CUDA.""" + config = AttentionConfig( + batch_size=2, + q_sequence_length=16, + kv_sequence_length=16, + q_num_heads=8, + kv_num_heads=8, + head_size=128, + is_causal=0, + has_attn_mask=True, + attn_mask_dims=4, + attn_mask_type="additive", + broadcast_mask_batch=True, + broadcast_mask_heads=True, + ) + + torch.manual_seed(0) + device = "cuda" + torch_type = torch.float16 + + q = torch.randn(2, 16, 8, 128, device=device, dtype=torch_type) * 0.2 + k = torch.randn(2, 16, 8, 128, device=device, dtype=torch_type) * 0.2 + v = torch.randn_like(k) * 0.2 + + # Create (1,1,q,kv) additive mask: lower-triangular causal pattern + mask_filter = float(torch.finfo(torch_type).min) + mask_2d = torch.zeros(16, 16, device=device, dtype=torch_type) + for i in range(16): + mask_2d[i, i + 1 :] = mask_filter + attn_mask = mask_2d.unsqueeze(0).unsqueeze(0) # (1, 1, 16, 16) + + # Reference: expand to full (B, H, Q, K) + attn_bias_ref = attn_mask.expand(2, 8, -1, -1).contiguous() + out_ref, _ = attention_ref(q=q, k=k, v=v, attn_bias=attn_bias_ref, causal=False) + + # ORT path + out_ort, _, _ = attention_prompt_func( + q=q, + k=k, + v=v, + config=config, + attn_mask=attn_mask, + ep="CUDAExecutionProvider", + device=device, + ort_type=TensorProto.FLOAT16, + ) + out_ort = out_ort.reshape(2, 16, 8, 128) + + out_np = out_ort.float().detach().cpu().numpy() + out_ref_np = out_ref.float().detach().cpu().numpy() + numpy.testing.assert_allclose(out_np, out_ref_np, rtol=rtol["fp16"], atol=atol["fp16"]) + + +# ################################################################################################# +# 2D Mask Broadcast Regression Test +# ################################################################################################# + + +@unittest.skipIf(not has_cuda_device(53), "CUDA device not available, skipping 2D mask broadcast tests.") +class 
TestONNXAttentionMHA2DMaskBroadcast(unittest.TestCase): + """ + Regression test for 2D mask [q_seq, total_seq] broadcast correctness. + + Per ONNX spec, a 2D attention mask has shape [q_seq, total_seq] and broadcasts + over batch and heads. This test uses batch_size > q_seq with a non-uniform + mask (different values per row) to verify correct broadcast behavior. + + The old bug indexed the 2D mask by batch index instead of query position, + causing OOB reads when batch_size > q_seq. + """ + + def test_2d_additive_mask_batch_gt_qseq(self): + """2D additive mask [q_seq=2, total_seq=8] with batch=4 — would OOB on old code.""" + config = AttentionConfig( + batch_size=4, + q_sequence_length=2, + kv_sequence_length=8, + q_num_heads=4, + kv_num_heads=4, + head_size=64, + is_causal=0, + has_attn_mask=True, + attn_mask_dims=2, + attn_mask_type="additive", + ) + + torch.manual_seed(42) + device = "cuda" + torch_type = torch.float16 + mask_filter_value = torch.finfo(torch_type).min + + q = ( + torch.randn( + config.batch_size, config.q_sequence_length, config.q_num_heads, config.head_size, @@ -1490,6 +2252,285 @@ def test_2d_bool_mask_batch_gt_qseq(self): numpy.testing.assert_allclose(out_np, out_ref_np, rtol=rtol["fp16"], atol=atol["fp16"]) +@unittest.skipIf(not has_cuda_device(53), "Memory Efficient Attention is not available, skipping tests.") +@patch.dict(os.environ, {"ORT_DISABLE_FLASH_ATTENTION": "1"}) +class TestONNXAttentionMEASoftcap(unittest.TestCase): + """ + Test softcap support in the Memory Efficient Attention (MEA) kernel. + + Disables Flash Attention to force the MEA path. Verifies that + softcap * tanh(score / softcap) is correctly applied to attention logits + in MEA, matching the reference implementation. + + MEA alignment requirement: total_seq % 4 == 0 when attn_mask is present. 
+ """ + + # --- P0: MEA softcap+mask (MHA) --- + + def test_mea_softcap_with_mask_prompt_fp16(self): + """MEA softcap + additive mask, prompt phase, fp16.""" + config = AttentionConfig( + batch_size=2, + q_sequence_length=8, + kv_sequence_length=8, # total_seq=8, divisible by 4 + q_num_heads=4, + kv_num_heads=4, + head_size=64, + is_causal=1, + softcap=50.0, + has_attn_mask=True, + attn_mask_dims=4, + attn_mask_type="additive", + ) + parity_check_mha_prompt( + config=config, + ep="CUDAExecutionProvider", + device="cuda", + torch_type=torch.float16, + ort_type=TensorProto.FLOAT16, + causal=True, + rtol=rtol["fp16"], + atol=atol["fp16"], + ) + + def test_mea_softcap_with_mask_decode_fp16(self): + """MEA softcap + additive mask, decode phase, fp16.""" + config = AttentionConfig( + batch_size=2, + q_sequence_length=1, + kv_sequence_length=1, + past_kv_sequence_length=31, # total_seq = 31+1 = 32, divisible by 4 + q_num_heads=4, + kv_num_heads=4, + head_size=64, + is_causal=1, + softcap=50.0, + has_attn_mask=True, + attn_mask_dims=4, + attn_mask_type="additive", + ) + parity_check_mha_past( + config=config, + ep="CUDAExecutionProvider", + device="cuda", + torch_type=torch.float16, + ort_type=TensorProto.FLOAT16, + causal=True, + rtol=rtol["fp16"], + atol=atol["fp16"], + ) + + # --- P0: MEA softcap-only (no mask) --- + + def test_mea_softcap_no_mask_prompt_fp16(self): + """MEA softcap without explicit mask, prompt phase, fp16.""" + config = AttentionConfig( + batch_size=2, + q_sequence_length=8, + kv_sequence_length=8, + q_num_heads=4, + kv_num_heads=4, + head_size=64, + is_causal=1, + softcap=50.0, + attn_mask_type="additive", + ) + parity_check_mha_prompt( + config=config, + ep="CUDAExecutionProvider", + device="cuda", + torch_type=torch.float16, + ort_type=TensorProto.FLOAT16, + causal=True, + rtol=rtol["fp16"], + atol=atol["fp16"], + ) + + def test_mea_softcap_no_mask_decode_fp16(self): + """MEA softcap without explicit mask, decode phase, fp16.""" + config = 
AttentionConfig( + batch_size=2, + q_sequence_length=1, + kv_sequence_length=1, + past_kv_sequence_length=31, # total_seq=32 + q_num_heads=4, + kv_num_heads=4, + head_size=64, + is_causal=1, + softcap=50.0, + attn_mask_type="additive", + ) + parity_check_mha_past( + config=config, + ep="CUDAExecutionProvider", + device="cuda", + torch_type=torch.float16, + ort_type=TensorProto.FLOAT16, + causal=True, + rtol=rtol["fp16"], + atol=atol["fp16"], + ) + + # --- P1: MEA softcap ordering poison test --- + + def test_mea_softcap_mask_ordering_no_leakage_prompt(self): + """Guard test: verify MEA softcap + mask ordering prevents attention leakage. + + Same poison-value technique as the unfused ordering test, but forces the + MEA path. Proves MEA correctly applies softcap before mask addition. + """ + batch_size = 1 + q_seq = 4 + kv_seq = 8 # divisible by 4 for MEA alignment + num_heads = 2 + head_size = 64 + softcap_val = 2.0 + valid_kv_len = 4 + + config = AttentionConfig( + batch_size=batch_size, + q_sequence_length=q_seq, + kv_sequence_length=kv_seq, + q_num_heads=num_heads, + kv_num_heads=num_heads, + head_size=head_size, + is_causal=0, + softcap=softcap_val, + has_attn_mask=True, + attn_mask_dims=4, + attn_mask_type="additive", + ) + + torch.manual_seed(42) + device = "cuda" + torch_type = torch.float16 + + q = torch.randn(batch_size, q_seq, num_heads, head_size, dtype=torch_type, device=device) * 0.2 + k = torch.randn(batch_size, kv_seq, num_heads, head_size, dtype=torch_type, device=device) * 0.2 + v = torch.randn(batch_size, kv_seq, num_heads, head_size, dtype=torch_type, device=device) * 0.2 + + # Place poison values in V at masked positions + poison_value = 1000.0 + v[:, valid_kv_len:, :, :] = poison_value + + # Create additive mask: 0.0 for valid, -inf for masked + attn_mask = torch.zeros(batch_size, num_heads, q_seq, kv_seq, dtype=torch_type, device=device) + attn_mask[:, :, :, valid_kv_len:] = float("-inf") + + out, _, _ = attention_prompt_func( + q=q, + k=k, + 
v=v, + config=config, + attn_mask=attn_mask, + ep="CUDAExecutionProvider", + device=device, + ort_type=TensorProto.FLOAT16, + ) + + out_np = out.to(torch.float32).detach().cpu().numpy().flatten() + max_abs = numpy.max(numpy.abs(out_np)) + self.assertLess( + max_abs, + 50.0, + f"MEA attention leakage detected: max |output| = {max_abs:.1f}. " + f"This likely means MEA applies softcap AFTER mask (wrong ordering). " + f"Correct ordering: QK → softcap → mask → softmax (per onnx/onnx#7865).", + ) + + # Also verify against reference + out_ref, _ = attention_ref(q=q, k=k, v=v, attn_bias=attn_mask, softcap=softcap_val) + out_ref_np = out_ref.to(torch.float32).detach().cpu().numpy() + out_reshaped = torch.reshape(out, (batch_size, q_seq, num_heads, head_size)) + out_reshaped_np = out_reshaped.to(torch.float32).detach().cpu().numpy() + numpy.testing.assert_allclose(out_reshaped_np, out_ref_np, rtol=0.02, atol=0.02) + + +@unittest.skipIf(not has_cuda_device(80), "Flash Attention requires Ampere or higher GPU, skipping tests.") +class TestONNXAttentionFlashSoftcap(unittest.TestCase): + """ + Test softcap support via Flash Attention path. + + Does NOT disable Flash or MEA — lets the dispatch cascade choose naturally. + On Ampere+ with fp16 and head_size<=256, this should route to Flash Attention. 
+ """ + + def test_flash_softcap_prompt_fp16(self): + """Flash Attention softcap, prompt phase, fp16.""" + config = AttentionConfig( + batch_size=2, + q_sequence_length=16, + kv_sequence_length=16, + q_num_heads=4, + kv_num_heads=4, + head_size=64, + is_causal=1, + softcap=50.0, + attn_mask_type="additive", + ) + parity_check_mha_prompt( + config=config, + ep="CUDAExecutionProvider", + device="cuda", + torch_type=torch.float16, + ort_type=TensorProto.FLOAT16, + causal=True, + rtol=rtol["fp16"], + atol=atol["fp16"], + ) + + def test_flash_softcap_decode_fp16(self): + """Flash Attention softcap, decode phase, fp16.""" + config = AttentionConfig( + batch_size=2, + q_sequence_length=1, + kv_sequence_length=1, + past_kv_sequence_length=31, # total_seq=32 + q_num_heads=4, + kv_num_heads=4, + head_size=64, + is_causal=1, + softcap=50.0, + attn_mask_type="additive", + ) + parity_check_mha_past( + config=config, + ep="CUDAExecutionProvider", + device="cuda", + torch_type=torch.float16, + ort_type=TensorProto.FLOAT16, + causal=True, + rtol=rtol["fp16"], + atol=atol["fp16"], + ) + + def test_flash_softcap_with_mask_prompt_fp16(self): + """Flash Attention softcap + mask, prompt phase, fp16.""" + config = AttentionConfig( + batch_size=2, + q_sequence_length=8, + kv_sequence_length=8, + q_num_heads=4, + kv_num_heads=4, + head_size=64, + is_causal=1, + softcap=50.0, + has_attn_mask=True, + attn_mask_dims=4, + attn_mask_type="additive", + ) + parity_check_mha_prompt( + config=config, + ep="CUDAExecutionProvider", + device="cuda", + torch_type=torch.float16, + ort_type=TensorProto.FLOAT16, + causal=True, + rtol=rtol["fp16"], + atol=atol["fp16"], + ) + + # NOTE: GQA fully-masked batch fix (ZeroOutputForFullyMaskedBatches) is validated by # C++ test Attention_NonPadKVSeqLen_AllMasked_FP16_GQA. 
Python graph-level test omitted # because the fix is a CUDA kernel in the MEA path — a CPU-only test cannot validate it, diff --git a/onnxruntime/test/python/transformers/test_onnx_attention/test_tensorscatter_attention.py b/onnxruntime/test/python/transformers/test_onnx_attention/test_tensorscatter_attention.py index a6a115bb12213..6b3f6d1c3ff34 100644 --- a/onnxruntime/test/python/transformers/test_onnx_attention/test_tensorscatter_attention.py +++ b/onnxruntime/test/python/transformers/test_onnx_attention/test_tensorscatter_attention.py @@ -460,16 +460,22 @@ def cpu_test_cases(): def cuda_fp16_test_cases(): - """CUDA fp16: both GQA and MHA cases. Flash attention handles external KV cache directly.""" + """CUDA fp16: both GQA and MHA cases. Flash attention handles external KV cache directly. + TensorScatter manages KV cache externally with nonpad_kv_seqlen bounding the active range. + Per ONNX spec, is_causal with S_q!=S_kv and no past_key gives upper-left alignment + (q[0] sees only kv[0]), which is not meaningful for decode. KV bounds are enforced by + nonpad_kv_seqlen instead, so is_causal=0 is the correct setting for TensorScatter decode.""" yield from _make_test_params(_GQA_CASES + _MHA_CASES, is_causal=0) - yield from _make_test_params(_GQA_CASES + _MHA_CASES, is_causal=1) def cuda_fp32_test_cases(): """CUDA fp32: MHA only. GQA requires fp16/bf16, and flash attention requires fp16/bf16. - fp32 MHA uses the unfused attention_bias fallback path.""" + fp32 MHA uses the unfused attention_bias fallback path. + TensorScatter manages KV cache externally with nonpad_kv_seqlen bounding the active range. + Per ONNX spec, is_causal with S_q!=S_kv and no past_key gives upper-left alignment + (q[0] sees only kv[0]), which is not meaningful for decode. 
KV bounds are enforced by + nonpad_kv_seqlen instead, so is_causal=0 is the correct setting for TensorScatter decode.""" yield from _make_test_params(_MHA_CASES, is_causal=0) - yield from _make_test_params(_MHA_CASES, is_causal=1) # ################################################################################################# @@ -975,5 +981,71 @@ def test_nonpad_with_bool_mask_cuda_fp16( numpy.testing.assert_allclose(present_v, ref_present_v, rtol=rtol["fp16"], atol=atol["fp16"]) +class TestCausalTensorScatterRejected(unittest.TestCase): + """Test that is_causal=1 + TensorScatter decode (S_q != S_kv, no past) is rejected. + + Per ONNX spec, is_causal without past_key means upper-left alignment: q[i] attends + only to kv[0..i]. For decode with external cache (S_q=1, S_kv=cache_size), this means + q[0] sees only kv[0] — not meaningful for autoregressive generation. + + The dispatch guard should return NOT_IMPLEMENTED for this combination. + Models should use is_causal=0 for TensorScatter decode. 
+ """ + + @unittest.skipUnless("CUDAExecutionProvider" in get_available_providers(), "CUDA not available") + def test_is_causal_with_tensorscatter_no_past_rejected(self): + """Verify NOT_IMPLEMENTED is raised for is_causal=1 + TensorScatter + S_q != S_kv.""" + batch_size = 1 + q_seq_len = 1 + total_kv_seq_len = 8 + q_num_heads = 2 + kv_num_heads = 2 + head_size = 32 + + # Build model with is_causal=1 (the rejected combination) + model_bytes = build_tensorscatter_attention_graph( + batch_size=batch_size, + total_kv_seq_len=total_kv_seq_len, + q_seq_len=q_seq_len, + q_num_heads=q_num_heads, + kv_num_heads=kv_num_heads, + head_size=head_size, + ort_type=TensorProto.FLOAT16, + is_causal=1, + ) + + sess_opts = SessionOptions() + session = InferenceSession(model_bytes, sess_opts, providers=["CUDAExecutionProvider"]) + + kv_hidden = kv_num_heads * head_size + q_hidden = q_num_heads * head_size + key_cache = numpy.random.randn(batch_size, total_kv_seq_len, kv_hidden).astype(numpy.float16) + value_cache = numpy.random.randn(batch_size, total_kv_seq_len, kv_hidden).astype(numpy.float16) + new_k = numpy.random.randn(batch_size, q_seq_len, kv_hidden).astype(numpy.float16) + new_v = numpy.random.randn(batch_size, q_seq_len, kv_hidden).astype(numpy.float16) + write_indices = numpy.array([4], dtype=numpy.int64) + query = numpy.random.randn(batch_size, q_seq_len, q_hidden).astype(numpy.float16) + nonpad_kv_seqlen = numpy.array([5], dtype=numpy.int64) + + feeds = { + "key_cache": key_cache, + "value_cache": value_cache, + "new_k": new_k, + "new_v": new_v, + "write_indices": write_indices, + "query": query, + "nonpad_kv_seqlen": nonpad_kv_seqlen, + } + + with self.assertRaises(Exception) as ctx: + session.run(None, feeds) + + error_msg = str(ctx.exception) + self.assertTrue( + "NOT_IMPLEMENTED" in error_msg or "nonpad_kv_seqlen" in error_msg, + f"Expected NOT_IMPLEMENTED error for is_causal + TensorScatter decode, got: {error_msg}", + ) + + if __name__ == "__main__": unittest.main() 
diff --git a/onnxruntime/test/shared_lib/test_model_builder_api.cc b/onnxruntime/test/shared_lib/test_model_builder_api.cc index c5ec376f7d0f5..0237ce773eda2 100644 --- a/onnxruntime/test/shared_lib/test_model_builder_api.cc +++ b/onnxruntime/test/shared_lib/test_model_builder_api.cc @@ -235,7 +235,7 @@ TEST(ModelEditorAPITest, Basic_CApi) { &y_tensor)); Ort::ThrowOnError(model_editor_api.AddInitializerToGraph(graph, "Y", y_tensor, /*data is external*/ true)); - y_tensor = nullptr; // graph now owns + api.ReleaseValue(y_tensor); if (use_constant_node) { // Test that a Constant node is converted to an initializer @@ -1083,3 +1083,326 @@ TEST(ModelEditorCompileAPITest, EmbedModeWithBufferOutputSatisfiesValidation) { allocator->Free(output_buffer); } } + +// +// Regression tests for double-free / ownership-transfer bugs in OrtModelEditorApi. +// These test that the API rejects attempts to transfer ownership of the same object twice. +// + +TEST(ModelEditorAPITest, AddInitializerToGraph_DuplicateName_Fails) { + const auto& api = Ort::GetApi(); + const auto& model_editor_api = Ort::GetModelEditorApi(); + + OrtGraph* graph = nullptr; + Ort::ThrowOnError(model_editor_api.CreateGraph(&graph)); + + // Create two small ORT-allocated tensors (< 128 bytes, so data_is_external = false) + std::vector<int64_t> dims = {2, 2}; + OrtAllocator* allocator = nullptr; + Ort::ThrowOnError(api.GetAllocatorWithDefaultOptions(&allocator)); + + OrtValue* tensor1 = nullptr; + OrtValue* tensor2 = nullptr; + Ort::ThrowOnError(api.CreateTensorAsOrtValue(allocator, dims.data(), dims.size(), + ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, &tensor1)); + Ort::ThrowOnError(api.CreateTensorAsOrtValue(allocator, dims.data(), dims.size(), + ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, &tensor2)); + + // First add should succeed + ASSERT_ORTSTATUS_OK(model_editor_api.AddInitializerToGraph(graph, "W", tensor1, false)); + + // Second add with same name should fail + Ort::Status status{model_editor_api.AddInitializerToGraph(graph,
"W", tensor2, false)}; + EXPECT_FALSE(status.IsOK()); + EXPECT_THAT(status.GetErrorMessage(), ::testing::HasSubstr("already been added")); + + // Clean up — caller retains ownership under copy semantics + api.ReleaseValue(tensor1); + api.ReleaseValue(tensor2); + api.ReleaseGraph(graph); +} + +TEST(ModelEditorAPITest, AddInitializerToGraph_SamePointerDifferentName_Succeeds) { + const auto& api = Ort::GetApi(); + const auto& model_editor_api = Ort::GetModelEditorApi(); + + OrtGraph* graph = nullptr; + Ort::ThrowOnError(model_editor_api.CreateGraph(&graph)); + + std::vector dims = {2, 2}; + OrtAllocator* allocator = nullptr; + Ort::ThrowOnError(api.GetAllocatorWithDefaultOptions(&allocator)); + + OrtValue* tensor = nullptr; + Ort::ThrowOnError(api.CreateTensorAsOrtValue(allocator, dims.data(), dims.size(), + ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, &tensor)); + + // Both adds succeed — each creates an independent copy sharing the same underlying data + ASSERT_ORTSTATUS_OK(model_editor_api.AddInitializerToGraph(graph, "W1", tensor, false)); + ASSERT_ORTSTATUS_OK(model_editor_api.AddInitializerToGraph(graph, "W2", tensor, false)); + + // Caller retains ownership and releases + api.ReleaseValue(tensor); + api.ReleaseGraph(graph); +} + +TEST(ModelEditorAPITest, AddNodeToGraph_DuplicateNode_Fails) { + const auto& api = Ort::GetApi(); + const auto& model_editor_api = Ort::GetModelEditorApi(); + + OrtGraph* graph = nullptr; + Ort::ThrowOnError(model_editor_api.CreateGraph(&graph)); + + OrtNode* node = CreateNode(model_editor_api, "Relu", "relu1", {"X"}, {"Y"}); + + // First add should succeed + ASSERT_ORTSTATUS_OK(model_editor_api.AddNodeToGraph(graph, node)); + + // Second add of same node should fail (prevents double-free) + Ort::Status status{model_editor_api.AddNodeToGraph(graph, node)}; + EXPECT_FALSE(status.IsOK()); + EXPECT_THAT(status.GetErrorMessage(), ::testing::HasSubstr("already been added")); + + api.ReleaseGraph(graph); +} + +TEST(ModelEditorAPITest, 
AddGraphToModel_DuplicateGraph_Fails) { + const auto& api = Ort::GetApi(); + const auto& model_editor_api = Ort::GetModelEditorApi(); + + OrtGraph* graph1 = nullptr; + OrtGraph* graph2 = nullptr; + Ort::ThrowOnError(model_editor_api.CreateGraph(&graph1)); + Ort::ThrowOnError(model_editor_api.CreateGraph(&graph2)); + + std::vector<const char*> domain_names = {onnxruntime::kOnnxDomain}; + std::vector<int> opset_versions = {18}; + OrtModel* model = nullptr; + Ort::ThrowOnError(model_editor_api.CreateModel(domain_names.data(), opset_versions.data(), + domain_names.size(), &model)); + + // First add should succeed + ASSERT_ORTSTATUS_OK(model_editor_api.AddGraphToModel(model, graph1)); + + // Second add should fail (model already has a graph) + Ort::Status status{model_editor_api.AddGraphToModel(model, graph2)}; + EXPECT_FALSE(status.IsOK()); + EXPECT_THAT(status.GetErrorMessage(), ::testing::HasSubstr("already has a graph")); + + // Clean up graph2 since ownership was NOT transferred + api.ReleaseGraph(graph2); + api.ReleaseModel(model); +} + +TEST(ModelEditorAPITest, SetGraphInputs_DuplicatePointer_Fails) { + const auto& api = Ort::GetApi(); + const auto& model_editor_api = Ort::GetModelEditorApi(); + + OrtGraph* graph = nullptr; + Ort::ThrowOnError(model_editor_api.CreateGraph(&graph)); + + // Create a single OrtValueInfo + OrtTensorTypeAndShapeInfo* tensor_type_info = nullptr; + std::vector<int64_t> dims = {3, 4}; + Ort::ThrowOnError(api.CreateTensorTypeAndShapeInfo(&tensor_type_info)); + Ort::ThrowOnError(api.SetTensorElementType(tensor_type_info, ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT)); + Ort::ThrowOnError(api.SetDimensions(tensor_type_info, dims.data(), dims.size())); + + OrtTypeInfo* type_info = nullptr; + Ort::ThrowOnError(model_editor_api.CreateTensorTypeInfo(tensor_type_info, &type_info)); + api.ReleaseTensorTypeAndShapeInfo(tensor_type_info); + + OrtValueInfo* value_info = nullptr; + Ort::ThrowOnError(model_editor_api.CreateValueInfo("X", type_info, &value_info)); +
api.ReleaseTypeInfo(type_info); + + // Pass the same pointer twice in the inputs array — should fail + std::vector<OrtValueInfo*> inputs = {value_info, value_info}; + Ort::Status status{model_editor_api.SetGraphInputs(graph, inputs.data(), inputs.size())}; + EXPECT_FALSE(status.IsOK()); + EXPECT_THAT(status.GetErrorMessage(), ::testing::HasSubstr("Duplicate")); + + // Clean up — ownership was NOT transferred + api.ReleaseValueInfo(value_info); + api.ReleaseGraph(graph); +} + +TEST(ModelEditorAPITest, SetGraphOutputs_DuplicatePointer_Fails) { + const auto& api = Ort::GetApi(); + const auto& model_editor_api = Ort::GetModelEditorApi(); + + OrtGraph* graph = nullptr; + Ort::ThrowOnError(model_editor_api.CreateGraph(&graph)); + + // Create a single OrtValueInfo + OrtTensorTypeAndShapeInfo* tensor_type_info = nullptr; + std::vector<int64_t> dims = {3, 4}; + Ort::ThrowOnError(api.CreateTensorTypeAndShapeInfo(&tensor_type_info)); + Ort::ThrowOnError(api.SetTensorElementType(tensor_type_info, ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT)); + Ort::ThrowOnError(api.SetDimensions(tensor_type_info, dims.data(), dims.size())); + + OrtTypeInfo* type_info = nullptr; + Ort::ThrowOnError(model_editor_api.CreateTensorTypeInfo(tensor_type_info, &type_info)); + api.ReleaseTensorTypeAndShapeInfo(tensor_type_info); + + OrtValueInfo* value_info = nullptr; + Ort::ThrowOnError(model_editor_api.CreateValueInfo("Y", type_info, &value_info)); + api.ReleaseTypeInfo(type_info); + + // Pass the same pointer twice in the outputs array — should fail + std::vector<OrtValueInfo*> outputs = {value_info, value_info}; + Ort::Status status{model_editor_api.SetGraphOutputs(graph, outputs.data(), outputs.size())}; + EXPECT_FALSE(status.IsOK()); + EXPECT_THAT(status.GetErrorMessage(), ::testing::HasSubstr("Duplicate")); + + // Clean up — ownership was NOT transferred + api.ReleaseValueInfo(value_info); + api.ReleaseGraph(graph); +} + +TEST(ModelEditorAPITest, AddNodeToGraph_NullGraph_Fails) { + const auto& api = Ort::GetApi(); + const auto&
model_editor_api = Ort::GetModelEditorApi(); + + OrtNode* node = CreateNode(model_editor_api, "Relu", "relu1", {"X"}, {"Y"}); + + // Null graph should fail without crashing + Ort::Status status{model_editor_api.AddNodeToGraph(nullptr, node)}; + EXPECT_FALSE(status.IsOK()); + EXPECT_THAT(status.GetErrorMessage(), ::testing::HasSubstr("null")); + + api.ReleaseNode(node); +} + +TEST(ModelEditorAPITest, AddGraphToModel_SameGraphTwoModels_Fails) { + const auto& api = Ort::GetApi(); + const auto& model_editor_api = Ort::GetModelEditorApi(); + + OrtGraph* graph = nullptr; + Ort::ThrowOnError(model_editor_api.CreateGraph(&graph)); + + std::vector<const char*> domain_names = {onnxruntime::kOnnxDomain}; + std::vector<int> opset_versions = {18}; + OrtModel* model1 = nullptr; + OrtModel* model2 = nullptr; + Ort::ThrowOnError(model_editor_api.CreateModel(domain_names.data(), opset_versions.data(), + domain_names.size(), &model1)); + Ort::ThrowOnError(model_editor_api.CreateModel(domain_names.data(), opset_versions.data(), + domain_names.size(), &model2)); + + // First add should succeed + ASSERT_ORTSTATUS_OK(model_editor_api.AddGraphToModel(model1, graph)); + + // Second add to different model should fail (graph already owned) + Ort::Status status{model_editor_api.AddGraphToModel(model2, graph)}; + EXPECT_FALSE(status.IsOK()); + EXPECT_THAT(status.GetErrorMessage(), ::testing::HasSubstr("already been added")); + + // model2 doesn't own anything, model1 owns graph + api.ReleaseModel(model2); + api.ReleaseModel(model1); +} + +// Skipped in debug builds where the assert in Release functions would fire.
+#ifdef NDEBUG +TEST(ModelEditorAPITest, ReleaseNode_AfterAddToGraph_IsNoOp) { + const auto& api = Ort::GetApi(); + const auto& model_editor_api = Ort::GetModelEditorApi(); + + OrtGraph* graph = nullptr; + Ort::ThrowOnError(model_editor_api.CreateGraph(&graph)); + + OrtNode* node = CreateNode(model_editor_api, "Relu", "relu1", {"X"}, {"Y"}); + + ASSERT_ORTSTATUS_OK(model_editor_api.AddNodeToGraph(graph, node)); + api.ReleaseNode(node); + api.ReleaseGraph(graph); +} + +TEST(ModelEditorAPITest, ReleaseGraph_AfterAddToModel_IsNoOp) { + const auto& api = Ort::GetApi(); + const auto& model_editor_api = Ort::GetModelEditorApi(); + + OrtGraph* graph = nullptr; + Ort::ThrowOnError(model_editor_api.CreateGraph(&graph)); + + std::vector<const char*> domain_names = {onnxruntime::kOnnxDomain}; + std::vector<int> opset_versions = {18}; + OrtModel* model = nullptr; + Ort::ThrowOnError(model_editor_api.CreateModel(domain_names.data(), opset_versions.data(), + domain_names.size(), &model)); + + ASSERT_ORTSTATUS_OK(model_editor_api.AddGraphToModel(model, graph)); + api.ReleaseGraph(graph); + api.ReleaseModel(model); +} + +TEST(ModelEditorAPITest, ReleaseValueInfo_AfterSetGraphInputs_IsNoOp) { + const auto& api = Ort::GetApi(); + const auto& model_editor_api = Ort::GetModelEditorApi(); + + OrtGraph* graph = nullptr; + Ort::ThrowOnError(model_editor_api.CreateGraph(&graph)); + + OrtTensorTypeAndShapeInfo* tensor_type_info = nullptr; + std::vector<int64_t> dims = {3, 4}; + Ort::ThrowOnError(api.CreateTensorTypeAndShapeInfo(&tensor_type_info)); + Ort::ThrowOnError(api.SetTensorElementType(tensor_type_info, ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT)); + Ort::ThrowOnError(api.SetDimensions(tensor_type_info, dims.data(), dims.size())); + + OrtTypeInfo* type_info = nullptr; + Ort::ThrowOnError(model_editor_api.CreateTensorTypeInfo(tensor_type_info, &type_info)); + api.ReleaseTensorTypeAndShapeInfo(tensor_type_info); + + OrtValueInfo* x_info = nullptr; + Ort::ThrowOnError(model_editor_api.CreateValueInfo("X", type_info,
&x_info)); + api.ReleaseTypeInfo(type_info); + + OrtValueInfo* saved_ptr = x_info; + std::vector<OrtValueInfo*> inputs = {x_info}; + ASSERT_ORTSTATUS_OK(model_editor_api.SetGraphInputs(graph, inputs.data(), inputs.size())); + + api.ReleaseValueInfo(saved_ptr); + api.ReleaseGraph(graph); +} +#endif // NDEBUG + +TEST(ModelEditorAPITest, SetGraphInputs_AlreadyOwnedValueInfo_Fails) { + const auto& api = Ort::GetApi(); + const auto& model_editor_api = Ort::GetModelEditorApi(); + + OrtGraph* graph1 = nullptr; + OrtGraph* graph2 = nullptr; + Ort::ThrowOnError(model_editor_api.CreateGraph(&graph1)); + Ort::ThrowOnError(model_editor_api.CreateGraph(&graph2)); + + // Create OrtValueInfo + OrtTensorTypeAndShapeInfo* tensor_type_info = nullptr; + std::vector<int64_t> dims = {3, 4}; + Ort::ThrowOnError(api.CreateTensorTypeAndShapeInfo(&tensor_type_info)); + Ort::ThrowOnError(api.SetTensorElementType(tensor_type_info, ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT)); + Ort::ThrowOnError(api.SetDimensions(tensor_type_info, dims.data(), dims.size())); + + OrtTypeInfo* type_info = nullptr; + Ort::ThrowOnError(model_editor_api.CreateTensorTypeInfo(tensor_type_info, &type_info)); + api.ReleaseTensorTypeAndShapeInfo(tensor_type_info); + + OrtValueInfo* x_info = nullptr; + Ort::ThrowOnError(model_editor_api.CreateValueInfo("X", type_info, &x_info)); + api.ReleaseTypeInfo(type_info); + + // Save the raw pointer before SetGraphInputs nulls out the array entry + OrtValueInfo* saved_ptr = x_info; + std::vector<OrtValueInfo*> inputs = {x_info}; + ASSERT_ORTSTATUS_OK(model_editor_api.SetGraphInputs(graph1, inputs.data(), inputs.size())); + + // Try to add the already-owned ValueInfo to a second graph — should fail + std::vector<OrtValueInfo*> inputs2 = {saved_ptr}; + Ort::Status status{model_editor_api.SetGraphInputs(graph2, inputs2.data(), inputs2.size())}; + EXPECT_FALSE(status.IsOK()); + EXPECT_THAT(status.GetErrorMessage(), ::testing::HasSubstr("already been added")); + + // graph1 owns x_info, graph2 is empty + api.ReleaseGraph(graph2); +
api.ReleaseGraph(graph1); +} diff --git a/onnxruntime/test/testdata/onnx_backend_test_series_filters.jsonc b/onnxruntime/test/testdata/onnx_backend_test_series_filters.jsonc index 5f8871d71c80a..5e8a6532e974d 100644 --- a/onnxruntime/test/testdata/onnx_backend_test_series_filters.jsonc +++ b/onnxruntime/test/testdata/onnx_backend_test_series_filters.jsonc @@ -42,14 +42,9 @@ "^test_attention_4d_attn_mask_3d_causal_expanded*", // webgpu "^test_attention_4d_diff_heads_mask4d_padded_kv*", // Need nonpad_kv_seqlen // TODO: support qk_matmul_output modes beyond kQK in Attention-cuda (see issue #27712) - // Tests combining qk_matmul with softcap need unfused-path softcap support (deferred). - "^test_attention_3d_with_past_and_present_qk_matmul_softcap_cuda", // qk_matmul + softcap needs unfused softcap - "^test_attention_4d_with_qk_matmul_softcap_cuda", // qk_matmul + softcap needs unfused softcap - // softcap + diff head sizes (head_size != v_head_size) blocks Flash, falls to unfused which lacks softcap - "^test_attention_3d_diff_heads_sizes_softcap_cuda", // diff head sizes forces unfused, no softcap - "^test_attention_4d_diff_heads_sizes_softcap_cuda", // diff head sizes forces unfused, no softcap - "^test_attention_4d_attn_mask_bool_cuda", // bool mask not supported in Attention-cuda - "^test_attention_4d_attn_mask_bool_4d_cuda", // bool mask not supported in Attention-cuda + // Tests combining qk_matmul with softcap need unfused-path qk_matmul support (deferred). 
+ "^test_attention_3d_with_past_and_present_qk_matmul_softcap_cuda", // qk_matmul modes beyond kQK not supported + "^test_attention_4d_with_qk_matmul_softcap_cuda", // qk_matmul modes beyond kQK not supported "^test_attention_3d_with_past_and_present_qk_matmul_bias_cuda", // QK matmul + bias not supported in Attention-cuda "^test_attention_4d_with_past_and_present_qk_matmul_bias_3d_mask_cuda", // QK matmul + bias not supported in Attention-cuda "^test_attention_4d_with_past_and_present_qk_matmul_bias_4d_mask_cuda", // QK matmul + bias not supported in Attention-cuda @@ -57,27 +52,6 @@ "^test_attention_4d_with_qk_matmul_softmax_cuda", // QK matmul + softmax not supported in Attention-cuda "^test_attention_3d_with_past_and_present_qk_matmul_softmax_cuda", // QK matmul + softmax not supported in Attention-cuda "^test_attention_4d_with_past_and_present_qk_matmul_bias_cuda", // QK matmul + bias not supported in Attention-cuda - // is_causal=Truen && q_seq_len != kv_seq_len not supported in Attention-cuda - "^test_attention_3d_causal_cuda", - "^test_attention_3d_diff_heads_sizes_causal_cuda", - "^test_attention_4d_attn_mask_3d_causal_cuda", - "^test_attention_4d_attn_mask_4d_causal_cuda", - "^test_attention_4d_causal_cuda", - "^test_attention_4d_diff_heads_sizes_causal_cuda", - // GQA Attention-cuda does not support fp16 and 4d QKV - "^test_attention_4d_gqa_with_past_and_present_fp16_cuda", // 4d QKV - "^test_attention_4d_gqa_with_past_and_present_cuda", // fp32 - "^test_attention_4d_gqa_softcap_cuda", // fp32 - "^test_attention_4d_gqa_scaled_cuda", // fp32 - "^test_attention_4d_gqa_cuda", // fp32 - "^test_attention_3d_gqa_attn_mask_cuda", // fp32 - "^test_attention_3d_gqa_causal_cuda", // fp32 - "^test_attention_3d_gqa_cuda", // fp32 - "^test_attention_3d_gqa_scaled_cuda", // fp32 - "^test_attention_3d_gqa_softcap_cuda", // fp32 - "^test_attention_3d_gqa_with_past_and_present_cuda", // fp32 - "^test_attention_4d_gqa_attn_mask_cuda", // fp32 - 
"^test_attention_4d_gqa_causal_cuda", // fp32 "^test_tensorscatter*", // TensorScatter(24) not implemented "^test_castlike_no_saturate_FLOAT_to_FLOAT8*", // ORT does not support ml_dtypes "^test_castlike_UINT4_to*", // ORT does not support ml_dtypes diff --git a/plugin-ep-cuda/MIN_ONNXRUNTIME_VERSION b/plugin-ep-cuda/MIN_ONNXRUNTIME_VERSION new file mode 100644 index 0000000000000..bc584045a3db0 --- /dev/null +++ b/plugin-ep-cuda/MIN_ONNXRUNTIME_VERSION @@ -0,0 +1 @@ +1.26.0 \ No newline at end of file diff --git a/plugin-ep-cuda/README.md b/plugin-ep-cuda/README.md new file mode 100644 index 0000000000000..0dc8c32904820 --- /dev/null +++ b/plugin-ep-cuda/README.md @@ -0,0 +1,30 @@ +# CUDA Plugin Execution Provider + +Packaging sources for the ONNX Runtime CUDA plugin Execution Provider (EP), distributed as a standalone artifact that +plugs into an existing ONNX Runtime installation rather than being built into the main `onnxruntime` binary. + +For more information about plugin EPs, see the documentation +[here](https://onnxruntime.ai/docs/execution-providers/plugin-ep-libraries/). + +## Contents + +- [`MIN_ONNXRUNTIME_VERSION`](MIN_ONNXRUNTIME_VERSION) - Minimum compatible ONNX Runtime version for the Python package. +- [`python/`](python/) - Sources and build script for the `onnxruntime-ep-cuda12`/`onnxruntime-ep-cuda13` Python wheels. + +## Usage + +Install the CUDA-family-specific Python distribution, then register the plugin EP at runtime. The package names are +`onnxruntime-ep-cuda12` for CUDA 12.x builds and `onnxruntime-ep-cuda13` for CUDA 13.x builds. Both distributions expose +the same Python import module, `onnxruntime_ep_cuda`. 
+ +```python +import onnxruntime as ort +import onnxruntime_ep_cuda as cuda_ep + +ort.register_execution_provider_library(cuda_ep.get_ep_name(), cuda_ep.get_library_path()) + +devices = [d for d in ort.get_ep_devices() if d.ep_name == cuda_ep.get_ep_name()] +sess_options = ort.SessionOptions() +sess_options.add_provider_for_devices(devices, {}) +session = ort.InferenceSession("model.onnx", sess_options=sess_options) +``` diff --git a/plugin-ep-cuda/python/README.md b/plugin-ep-cuda/python/README.md new file mode 100644 index 0000000000000..5edf67540f5d0 --- /dev/null +++ b/plugin-ep-cuda/python/README.md @@ -0,0 +1,23 @@ +# CUDA Plugin EP Python Package + +This directory contains the packaging source for the CUDA plugin EP Python packages: + +- `onnxruntime-ep-cuda12` for CUDA 12.x builds +- `onnxruntime-ep-cuda13` for CUDA 13.x builds + +Both distributions install the same import module, `onnxruntime_ep_cuda`. + +## Building the wheel + +Wheels are built via `build_wheel.py`. Running `pip install` or `pip wheel` directly against this directory is not +supported because the source tree contains `pyproject.toml.in` instead of a concrete `pyproject.toml`. + +```bash +python build_wheel.py \ + --binary_dir \ + --version \ + --package_name \ + --output_dir +``` + +The script combines pre-built CUDA plugin EP binaries with the package source to produce a platform-specific wheel. 
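The `pyproject.toml.in` template mentioned above is turned into a concrete `pyproject.toml` by replacing `@name@` placeholders. A minimal sketch of that substitution on an in-memory string, assuming the same `@(\w+)@` placeholder pattern that `build_wheel.py` defines; `render_template` and the sample version `0.0.0` are illustrative names and values, not part of the package:

```python
import re

# Same placeholder syntax as the packaging template: @variable_name@
_PLACEHOLDER = re.compile(r"@(\w+)@")


def render_template(text, values):
    """Replace every @name@ in `text`; fail unless placeholders and keys match exactly."""
    seen = set()

    def substitute(match):
        name = match.group(1)
        seen.add(name)
        # Leave unknown placeholders untouched so the error below can report them.
        return values.get(name, match.group(0))

    rendered = _PLACEHOLDER.sub(substitute, text)
    if seen != set(values):
        raise ValueError(
            f"only in template: {sorted(seen - set(values))}; "
            f"only in values: {sorted(set(values) - seen)}"
        )
    return rendered


# Hypothetical two-line template standing in for pyproject.toml.in.
template = 'name = "@package_name@"\nversion = "@version@"\n'
print(render_template(template, {"package_name": "onnxruntime-ep-cuda12", "version": "0.0.0"}))
```

Requiring an exact match between placeholders and supplied keys means a typo on either side fails the build instead of silently shipping a wheel with a stray `@...@` in its metadata.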
diff --git a/plugin-ep-cuda/python/build_wheel.py b/plugin-ep-cuda/python/build_wheel.py new file mode 100644 index 0000000000000..a709fd06d3904 --- /dev/null +++ b/plugin-ep-cuda/python/build_wheel.py @@ -0,0 +1,168 @@ +#!/usr/bin/env python3 +"""Build a wheel for the onnxruntime-ep-cuda12 or onnxruntime-ep-cuda13 package.""" + +import argparse +import platform +import re +import shutil +import subprocess +import sys +import tempfile +from pathlib import Path + +SCRIPT_DIR = Path(__file__).parent +MIN_ONNXRUNTIME_VERSION_FILE = SCRIPT_DIR.parent / "MIN_ONNXRUNTIME_VERSION" + +_TEMPLATE_VARIABLE_PATTERN = re.compile(r"@(\w+)@") +BINARY_PATTERNS = [ + "onnxruntime_providers_cuda_plugin.dll", + "libonnxruntime_providers_cuda_plugin.so", +] +AUDITWHEEL_EXCLUDE = [ + "libcuda.so.1", + "libcublas.so.12", + "libcublas.so.13", + "libcublasLt.so.12", + "libcublasLt.so.13", + "libcudart.so.12", + "libcudart.so.13", + "libcudnn.so.9", + "libcufft.so.11", + "libcufft.so.12", + "libnvJitLink.so.12", + "libnvJitLink.so.13", + "libnvrtc.so.12", + "libnvrtc.so.13", + "libnvrtc-builtins.so.12", + "libnvrtc-builtins.so.13", +] + + +def gen_file_from_template(template_file: Path, output_file: Path, variable_substitutions: dict[str, str]) -> None: + content = template_file.read_text(encoding="utf-8") + variables_in_file: set[str] = set() + + def replace(match: re.Match[str]) -> str: + name = match.group(1) + variables_in_file.add(name) + return variable_substitutions.get(name, match.group(0)) + + content = _TEMPLATE_VARIABLE_PATTERN.sub(replace, content) + if variables_in_file != variable_substitutions.keys(): + provided = set(variable_substitutions.keys()) + raise ValueError( + f"Template variables and substitution keys do not match for {template_file}. " + f"Only in template: {sorted(variables_in_file - provided)}. " + f"Only in substitutions: {sorted(provided - variables_in_file)}." 
+ ) + + output_file.write_text(content, encoding="utf-8") + + +def prepare_staging_dir(staging_dir: Path, binary_dir: Path, version: str, package_name: str) -> None: + staging_dir.mkdir(parents=True, exist_ok=True) + shutil.copy2(SCRIPT_DIR / "setup.py", staging_dir / "setup.py") + shutil.copytree(SCRIPT_DIR / "onnxruntime_ep_cuda", staging_dir / "onnxruntime_ep_cuda") + + package_dir = staging_dir / "onnxruntime_ep_cuda" + copied = [] + for pattern in BINARY_PATTERNS: + for src in binary_dir.glob(pattern): + dst = package_dir / src.name + print(f"Copying {src} -> {dst}") + shutil.copy2(src, dst) + copied.append(dst) + if not copied: + raise FileNotFoundError(f"No plugin binaries found in {binary_dir}. Looked for: {BINARY_PATTERNS}") + + min_ort_version = MIN_ONNXRUNTIME_VERSION_FILE.read_text(encoding="utf-8").strip() + if not min_ort_version: + raise ValueError(f"{MIN_ONNXRUNTIME_VERSION_FILE} is empty") + + gen_file_from_template( + SCRIPT_DIR / "pyproject.toml.in", + staging_dir / "pyproject.toml", + {"package_name": package_name, "version": version, "min_onnxruntime_version": min_ort_version}, + ) + + +def build_wheel(source_dir: Path, wheel_dir: Path) -> None: + wheel_dir.mkdir(parents=True, exist_ok=True) + cmd = [ + sys.executable, + "-m", + "pip", + "wheel", + str(source_dir), + "--wheel-dir", + str(wheel_dir), + "--no-deps", + "--no-build-isolation", + ] + print(f"Running: {' '.join(cmd)}") + subprocess.check_call(cmd) + + +def auditwheel_repair(wheel_dir: Path, wheel_name_prefix: str) -> None: + if platform.system() != "Linux": + return + + original_wheels = list(wheel_dir.glob(f"{wheel_name_prefix}-*.whl")) + if not original_wheels: + raise RuntimeError(f"No wheel found in {wheel_dir} to repair with auditwheel") + + with tempfile.TemporaryDirectory() as repaired_dir_name: + repaired_dir = Path(repaired_dir_name) + for wheel in original_wheels: + cmd = [sys.executable, "-m", "auditwheel", "repair", str(wheel), "--wheel-dir", str(repaired_dir)] + for lib 
in AUDITWHEEL_EXCLUDE: + cmd.extend(["--exclude", lib]) + print(f"Running: {' '.join(cmd)}") + subprocess.check_call(cmd) + wheel.unlink() + + repaired_wheels = list(repaired_dir.glob("*.whl")) + if not repaired_wheels: + raise RuntimeError(f"auditwheel repair produced no wheels in {repaired_dir}") + + for repaired_wheel in repaired_wheels: + repaired_wheel.replace(wheel_dir / repaired_wheel.name) + + +def collect_wheels(wheel_dir: Path, output_dir: Path, wheel_name_prefix: str) -> None: + wheels = list(wheel_dir.glob(f"{wheel_name_prefix}-*.whl")) + if not wheels: + raise RuntimeError("No wheel was produced") + output_dir.mkdir(parents=True, exist_ok=True) + for wheel in wheels: + dest = output_dir / wheel.name + shutil.copy2(wheel, dest) + print(f"Built wheel: {dest}") + + +def main() -> None: + parser = argparse.ArgumentParser(description="Build onnxruntime-ep-cuda wheel") + parser.add_argument("--binary_dir", required=True, type=Path, help="Directory containing built plugin EP binaries") + parser.add_argument("--version", required=True, help="Package version string (PEP 440 format)") + parser.add_argument("--package_name", required=True, help="Python distribution name to write into pyproject.toml") + parser.add_argument("--output_dir", required=True, type=Path, help="Directory to place the built wheel") + args = parser.parse_args() + + if not args.binary_dir.is_dir(): + raise FileNotFoundError(f"Binary directory does not exist: {args.binary_dir}") + if not re.fullmatch(r"[A-Za-z0-9][A-Za-z0-9._-]*", args.package_name): + raise ValueError(f"Invalid package name: {args.package_name}") + + wheel_name_prefix = args.package_name.replace("-", "_").replace(".", "_") + + with tempfile.TemporaryDirectory(prefix="ort_cuda_wheel_") as tmp: + staging_dir = Path(tmp) / "package" + wheel_dir = Path(tmp) / "wheels" + prepare_staging_dir(staging_dir, args.binary_dir, args.version, args.package_name) + build_wheel(staging_dir, wheel_dir) + auditwheel_repair(wheel_dir, 
wheel_name_prefix) + collect_wheels(wheel_dir, args.output_dir, wheel_name_prefix) + + +if __name__ == "__main__": + main() diff --git a/plugin-ep-cuda/python/onnxruntime_ep_cuda/README.md b/plugin-ep-cuda/python/onnxruntime_ep_cuda/README.md new file mode 100644 index 0000000000000..167ff50801d87 --- /dev/null +++ b/plugin-ep-cuda/python/onnxruntime_ep_cuda/README.md @@ -0,0 +1,17 @@ +# ONNX Runtime CUDA Plugin Execution Provider + +CUDA Execution Provider plugin for ONNX Runtime. Install alongside `onnxruntime` to enable the CUDA plugin EP. + +## Usage + +```python +import onnxruntime as ort +import onnxruntime_ep_cuda as cuda_ep + +ort.register_execution_provider_library(cuda_ep.get_ep_name(), cuda_ep.get_library_path()) + +devices = [d for d in ort.get_ep_devices() if d.ep_name == cuda_ep.get_ep_name()] +sess_options = ort.SessionOptions() +sess_options.add_provider_for_devices(devices, {}) +session = ort.InferenceSession("model.onnx", sess_options=sess_options) +``` \ No newline at end of file diff --git a/plugin-ep-cuda/python/onnxruntime_ep_cuda/__init__.py b/plugin-ep-cuda/python/onnxruntime_ep_cuda/__init__.py new file mode 100644 index 0000000000000..8e0e29c810433 --- /dev/null +++ b/plugin-ep-cuda/python/onnxruntime_ep_cuda/__init__.py @@ -0,0 +1,38 @@ +"""ONNX Runtime CUDA Plugin Execution Provider Python package.""" + +from __future__ import annotations + +import pathlib + +__all__ = [ + "get_ep_name", + "get_ep_names", + "get_library_path", +] + +_module_dir = pathlib.Path(__file__).parent + + +def get_library_path() -> str: + """Return the path to the CUDA plugin EP shared library.""" + candidate_paths = [ + _module_dir / "onnxruntime_providers_cuda_plugin.dll", + _module_dir / "libonnxruntime_providers_cuda_plugin.so", + ] + paths = [p for p in candidate_paths if p.is_file()] + if len(paths) != 1: + raise RuntimeError( + f"Expected exactly one CUDA plugin EP library in {_module_dir}, " + f"found {len(paths)}: {[p.name for p in paths]}" + ) + return 
str(paths[0]) + + +def get_ep_name() -> str: + """Return the CUDA plugin Execution Provider name.""" + return "CudaPluginExecutionProvider" + + +def get_ep_names() -> list[str]: + """Return a list of EP names provided by this plugin.""" + return [get_ep_name()] diff --git a/plugin-ep-cuda/python/pyproject.toml.in b/plugin-ep-cuda/python/pyproject.toml.in new file mode 100644 index 0000000000000..dfca37783d7ed --- /dev/null +++ b/plugin-ep-cuda/python/pyproject.toml.in @@ -0,0 +1,20 @@ +[build-system] +requires = ["setuptools>=68.0", "wheel"] +build-backend = "setuptools.build_meta" + +[project] +name = "@package_name@" +version = "@version@" +description = "ONNX Runtime CUDA Plugin Execution Provider" +readme = "onnxruntime_ep_cuda/README.md" +license = {text = "MIT"} +requires-python = ">=3.11" +dependencies = [ + "onnxruntime>=@min_onnxruntime_version@", +] + +[tool.setuptools.packages.find] +include = ["onnxruntime_ep_cuda*"] + +[tool.setuptools.package-data] +onnxruntime_ep_cuda = ["*.dll", "*.so", "*.so.*"] diff --git a/plugin-ep-cuda/python/requirements-build-wheel.txt b/plugin-ep-cuda/python/requirements-build-wheel.txt new file mode 100644 index 0000000000000..eb72ee3b67d27 --- /dev/null +++ b/plugin-ep-cuda/python/requirements-build-wheel.txt @@ -0,0 +1,4 @@ +setuptools>=68.0 +wheel +auditwheel; sys_platform == "linux" +patchelf; sys_platform == "linux" \ No newline at end of file diff --git a/plugin-ep-cuda/python/setup.py b/plugin-ep-cuda/python/setup.py new file mode 100644 index 0000000000000..7b1968dbc847a --- /dev/null +++ b/plugin-ep-cuda/python/setup.py @@ -0,0 +1,21 @@ +"""Minimal setup.py to produce a platform-specific wheel.""" + +from setuptools import setup +from setuptools.dist import Distribution +from wheel.bdist_wheel import bdist_wheel + + +class PlatformBdistWheel(bdist_wheel): + """Override wheel tags to py3-none-{platform}.""" + + def get_tag(self): + _, _, plat = super().get_tag() + return "py3", "none", plat + + +class 
BinaryDistribution(Distribution): + def has_ext_modules(self): + return True + + +setup(distclass=BinaryDistribution, cmdclass={"bdist_wheel": PlatformBdistWheel}) diff --git a/plugin-ep-cuda/python/test/test_cuda_plugin_ep.py b/plugin-ep-cuda/python/test/test_cuda_plugin_ep.py new file mode 100644 index 0000000000000..885faeb56daf6 --- /dev/null +++ b/plugin-ep-cuda/python/test/test_cuda_plugin_ep.py @@ -0,0 +1,171 @@ +#!/usr/bin/env python3 +"""Smoke test for the onnxruntime-ep-cuda Python package. + +Tests: +1. Package import and library path resolution +2. EP registration with ONNX Runtime +3. Device discovery +4. Inference with a simple Mul model (requires CUDA-capable hardware) + +The inference test is skipped gracefully if no CUDA device is available +(e.g., on CPU-only build agents). +""" + +import os +import platform +import sys +import tempfile +import traceback +from pathlib import Path + +import numpy as np +import onnx + +import onnxruntime as ort + +VERBOSE = os.environ.get("ORT_TEST_VERBOSE", "").strip().lower() in ("1", "true", "yes") + + +def debug_print(*args, **kwargs): + """Print only when ORT_TEST_VERBOSE is set to a truthy value.""" + if VERBOSE: + print(*args, **kwargs) + + +def create_mul_model(output_dir: Path) -> Path: + """Create a simple Mul model in `output_dir` and return the path to the saved .onnx file.""" + x = onnx.helper.make_tensor_value_info("x", onnx.TensorProto.FLOAT, [2, 3]) + y = onnx.helper.make_tensor_value_info("y", onnx.TensorProto.FLOAT, [2, 3]) + z = onnx.helper.make_tensor_value_info("z", onnx.TensorProto.FLOAT, [2, 3]) + + mul_node = onnx.helper.make_node("Mul", inputs=["x", "y"], outputs=["z"]) + + graph = onnx.helper.make_graph([mul_node], "mul_graph", [x, y], [z]) + model = onnx.helper.make_model(graph, opset_imports=[onnx.helper.make_opsetid("", 13)]) + model.ir_version = 7 + + model_path = output_dir / "mul.onnx" + onnx.save(model, str(model_path)) + return model_path + + +def print_environment_info(): + 
"""Print diagnostic information about the runtime environment.""" + print(f" Python: {sys.version}") + print(f" Platform: {platform.platform()}") + print(f" Architecture: {platform.machine()}") + print(f" ONNX Runtime version: {ort.__version__}") + print(f" ONNX Runtime location: {ort.__file__}") + print(f" Available providers (built-in): {ort.get_available_providers()}") + # Print relevant environment variables + for var in sorted(os.environ): + lower = var.lower() + if any(kw in lower for kw in ["onnx", "ort", "gpu", "cuda", "nv", "path", "ld_library"]): + print(f" ENV {var}={os.environ[var]}") + + +def test_import_and_library_path(): + """Test that the package imports and the library path is valid.""" + import onnxruntime_ep_cuda as cuda_ep # noqa: PLC0415 + + debug_print(f" Package location: {cuda_ep.__file__}") + pkg_dir = Path(cuda_ep.__file__).parent + debug_print(f" Package directory contents: {sorted(p.name for p in pkg_dir.iterdir())}") + + lib_path = cuda_ep.get_library_path() + assert Path(lib_path).is_file(), f"Library path does not exist: {lib_path}" + print(f"OK: Library path: {lib_path}") + + ep_name = cuda_ep.get_ep_name() + assert ep_name == "CudaPluginExecutionProvider", f"Unexpected EP name: {ep_name}" + print(f"OK: EP name: {ep_name}") + + ep_names = cuda_ep.get_ep_names() + assert ep_names == ["CudaPluginExecutionProvider"], f"Unexpected EP names: {ep_names}" + print(f"OK: EP names: {ep_names}") + + +def test_registration_and_inference(): + """Test EP registration, device discovery, and inference.""" + import onnxruntime_ep_cuda as cuda_ep # noqa: PLC0415 + + lib_path = cuda_ep.get_library_path() + ep_name = cuda_ep.get_ep_name() + registration_name = "cuda_plugin_test" + + # Register the plugin EP + debug_print(f" Registering library: {lib_path}") + debug_print(f" Library file size: {Path(lib_path).stat().st_size} bytes") + ort.register_execution_provider_library(registration_name, lib_path) + print(f"OK: Registered EP library as 
'{registration_name}'") + + try: + # Discover devices + all_devices = ort.get_ep_devices() + debug_print(f" All devices: {[(d.ep_name, getattr(d, 'device_id', 'N/A')) for d in all_devices]}") + cuda_devices = [d for d in all_devices if d.ep_name == ep_name] + print(f"Found {len(cuda_devices)} CUDA plugin device(s)") + + if not cuda_devices: + print("SKIP: No CUDA plugin devices available — skipping inference test") + return + + # Create session with CUDA plugin EP + sess_options = ort.SessionOptions() + sess_options.add_session_config_entry("session.disable_cpu_ep_fallback", "1") + sess_options.add_provider_for_devices(cuda_devices, {}) + assert sess_options.has_providers(), "SessionOptions should have providers after add_provider_for_devices" + print("OK: Session options configured with CUDA plugin EP") + + with tempfile.TemporaryDirectory() as model_dir: + model_path = create_mul_model(Path(model_dir)) + debug_print(f" Model path: {model_path}") + sess = ort.InferenceSession(str(model_path), sess_options=sess_options) + debug_print(f" Session providers: {sess.get_providers()}") + print("OK: InferenceSession created") + + # Run inference + x = np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], dtype=np.float32) + y = np.array([[2.0, 3.0, 4.0], [5.0, 6.0, 7.0]], dtype=np.float32) + expected = x * y + + outputs = sess.run(None, {"x": x, "y": y}) + result = outputs[0] + + np.testing.assert_allclose(result, expected, rtol=1e-5, atol=1e-5) + print("OK: Inference result matches expected output") + + del sess + print("OK: Session released") + + finally: + ort.unregister_execution_provider_library(registration_name) + print(f"OK: Unregistered EP library '{registration_name}'") + + +def main(): + print("=== CUDA Plugin EP Python Package Test ===") + + if VERBOSE: + # Set verbose ORT logging so ORT internals are visible in CI logs + ort.set_default_logger_severity(0) + + print("\n--- Environment ---") + print_environment_info() + + print("\n--- Test 1: Import and library path 
---") + test_import_and_library_path() + + print("\n--- Test 2: Registration and inference ---") + test_registration_and_inference() + + print("\n=== All tests passed ===") + + +if __name__ == "__main__": + try: + main() + except Exception as e: + print(f"\nFAILED: {e}", file=sys.stderr) + traceback.print_exc() + sys.exit(1) diff --git a/plugin-ep-webgpu/README.md b/plugin-ep-webgpu/README.md index dd874f8af1c3b..889fef10ae5e1 100644 --- a/plugin-ep-webgpu/README.md +++ b/plugin-ep-webgpu/README.md @@ -10,8 +10,12 @@ For more information about plugin EPs, see the documentation [here](https://onnx - [`VERSION_NUMBER`](VERSION_NUMBER) — Base plugin EP version consumed by the CI pipeline. The pipeline derives the final package version (release, dev) from this via [`tools/ci_build/github/azure-pipelines/templates/set-plugin-build-variables-step.yml`](../tools/ci_build/github/azure-pipelines/templates/set-plugin-build-variables-step.yml). +- [`MIN_ONNXRUNTIME_VERSION`](MIN_ONNXRUNTIME_VERSION) — Minimum compatible core `onnxruntime` version. Single source + of truth shared by all packages built from this directory. - [`python/`](python/) — Sources and build script for the `onnxruntime-ep-webgpu` Python wheel. See [`python/README.md`](python/README.md) for build and test instructions. +- [`csharp/`](csharp/) — Sources and packaging script for the `Microsoft.ML.OnnxRuntime.EP.WebGpu` NuGet package. See + [`csharp/README.md`](csharp/README.md) for build and test instructions. ## How it fits together @@ -19,6 +23,7 @@ The plugin EP is built as a shared library (`onnxruntime_providers_webgpu.{dll,s build (`--use_webgpu shared_lib`). The resulting binaries are then packaged into: - A Python wheel (`onnxruntime-ep-webgpu`), built from [`python/`](python/). +- A NuGet package (`Microsoft.ML.OnnxRuntime.EP.WebGpu`), built from [`csharp/`](csharp/). - A universal package published to the internal ORT-Nightly feed for Windows (x64 / arm64), Linux x64, and macOS arm64. 
@@ -29,7 +34,7 @@ and post-build smoke tests run in the companion `WebGPU Plugin EP Test Pipeline` ## Usage -Once installed, the plugin EP is registered at runtime: +Once installed, the plugin EP is registered at runtime. Example in Python: ```python import onnxruntime as ort @@ -43,5 +48,7 @@ sess_options.add_provider_for_devices(devices, {}) session = ort.InferenceSession("model.onnx", sess_options=sess_options) ``` -See [`python/onnxruntime_ep_webgpu/README.md`](python/onnxruntime_ep_webgpu/README.md) for the user-facing package -documentation (this README is bundled into the wheel). +See the user-facing package READMEs (bundled into the published packages) for full per-language usage: + +- Python: [`python/onnxruntime_ep_webgpu/README.md`](python/onnxruntime_ep_webgpu/README.md) +- C# / .NET: [`csharp/Microsoft.ML.OnnxRuntime.EP.WebGpu/README.md`](csharp/Microsoft.ML.OnnxRuntime.EP.WebGpu/README.md) diff --git a/plugin-ep-webgpu/RELEASE.md b/plugin-ep-webgpu/RELEASE.md new file mode 100644 index 0000000000000..8244e38eaee9a --- /dev/null +++ b/plugin-ep-webgpu/RELEASE.md @@ -0,0 +1,62 @@ +# Release Process + +This document describes the release conventions and process for the WebGPU plugin EP. + +## Versioning + +The plugin follows [Semantic Versioning](https://semver.org/): + +- **MAJOR** — incompatible API/ABI changes. +- **MINOR** — backwards-compatible feature additions. +- **PATCH** — backwards-compatible bug and security fixes. + +The current version is tracked in [VERSION_NUMBER](VERSION_NUMBER). + +## Branch and tag naming + +All release refs are namespaced under `plugin-ep-webgpu/` so they group together in `git branch` / `git tag` +listings and don't collide with the main ONNX Runtime release refs. + +- **Release branch:** `plugin-ep-webgpu/rel-X.Y` + - One branch per minor version line (e.g. `plugin-ep-webgpu/rel-1.0`). + - Holds all patch releases for that minor line (1.0.0, 1.0.1, 1.0.2, ...). 
+ - Forked from `main` at the point of the first release on that line. +- **Release tag:** `plugin-ep-webgpu/vX.Y.Z` + - One tag per shipped release (e.g. `plugin-ep-webgpu/v1.0.0`). + - Tags are immutable and are the source of truth for "what shipped." +- **Pre-release tag:** `plugin-ep-webgpu/vX.Y.Z-rc.N` (semver-style) + - Used for release candidates and other pre-release artifacts. + - Note: this convention is forward-looking as we don't have release candidates in the release process yet. + +The `rel-` prefix on branches and the `v` prefix on tags ensure branches and tags are never ambiguous at the ref +level. + +### Difference from the main ONNX Runtime convention + +The main ORT repo uses **per-patch** release branches of the form `rel-X.Y.Z` (e.g. `rel-1.20.0`, `rel-1.20.1`). +This plugin deliberately uses **per-minor** branches (`rel-X.Y`) instead. + +The per-minor model is simpler: one long-lived branch per supported minor line, with each patch release marked by a +tag on that branch. Tags are the immutable record of what shipped; the branch is just where the next patch is staged. +For a component of this size and release cadence, that is sufficient and avoids the branch sprawl of the per-patch +model. + +The per-minor model is also the broader open-source convention (Linux, LLVM, Python, Node, Kubernetes), so +contributors coming from outside the ORT ecosystem will find it familiar. The namespaced ref prefix +(`plugin-ep-webgpu/`) keeps the plugin's release refs cleanly separated from the main ORT release refs. + +## Release workflow + +1. Prepare the release branch. + - New minor or major release: + - Create release branch `plugin-ep-webgpu/rel-X.Y` from `main`. + `main`'s `VERSION_NUMBER` should already be `X.Y.0`, reflecting the release that is about to be cut. + - Bump `VERSION_NUMBER` on `main` to the next development version (e.g. `X.(Y+1).0`). + - Patch release: + - Bump `VERSION_NUMBER` on the release branch to `X.Y.Z`. +2. 
Integrate any fixes into the release branch. These may be cherry-picked from `main` or made directly in the + release branch. The latter should be re-integrated into `main` unless the fix is specific to the release branch. +3. Run the full validation pipeline against the release branch tip. +4. Repeat steps 2 and 3 as needed. +5. Tag the release branch tip as `plugin-ep-webgpu/vX.Y.Z`. +6. Publish artifacts from the tag. diff --git a/plugin-ep-webgpu/VERSION_NUMBER b/plugin-ep-webgpu/VERSION_NUMBER index 6e8bf73aa550d..0ea3a944b399d 100644 --- a/plugin-ep-webgpu/VERSION_NUMBER +++ b/plugin-ep-webgpu/VERSION_NUMBER @@ -1 +1 @@ -0.1.0 +0.2.0 diff --git a/plugin-ep-webgpu/csharp/Microsoft.ML.OnnxRuntime.EP.WebGpu/Microsoft.ML.OnnxRuntime.EP.WebGpu.csproj b/plugin-ep-webgpu/csharp/Microsoft.ML.OnnxRuntime.EP.WebGpu/Microsoft.ML.OnnxRuntime.EP.WebGpu.csproj new file mode 100644 index 0000000000000..94be6bec6ea46 --- /dev/null +++ b/plugin-ep-webgpu/csharp/Microsoft.ML.OnnxRuntime.EP.WebGpu/Microsoft.ML.OnnxRuntime.EP.WebGpu.csproj @@ -0,0 +1,91 @@ + + + + netstandard2.0 + latest + enable + + + Microsoft.ML.OnnxRuntime.EP.WebGpu + + 0.0.0-dev + Microsoft + Microsoft + ONNX Runtime WebGPU Plugin Execution Provider. + README.md + ONNX;ONNX Runtime;Machine Learning;AI;Deep Learning;WebGPU + + + MIT + https://github.com/microsoft/onnxruntime + git + © Microsoft Corporation. All rights reserved. 
+ + + true + snupkg + + + + + $(MSBuildThisFileDirectory)..\..\MIN_ONNXRUNTIME_VERSION + $([System.IO.File]::ReadAllText('$(OnnxRuntimeMinVersionFile)').Trim()) + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/plugin-ep-webgpu/csharp/Microsoft.ML.OnnxRuntime.EP.WebGpu/README.md b/plugin-ep-webgpu/csharp/Microsoft.ML.OnnxRuntime.EP.WebGpu/README.md new file mode 100644 index 0000000000000..f4a717b8836d5 --- /dev/null +++ b/plugin-ep-webgpu/csharp/Microsoft.ML.OnnxRuntime.EP.WebGpu/README.md @@ -0,0 +1,42 @@ +## Microsoft.ML.OnnxRuntime.EP.WebGpu + +WebGPU plugin Execution Provider for [ONNX Runtime](https://github.com/microsoft/onnxruntime). + +### Usage + +```csharp +// Note: Error handling is omitted for brevity. + +using Microsoft.ML.OnnxRuntime; +using Microsoft.ML.OnnxRuntime.EP.WebGpu; + +// Register the WebGPU EP plugin library +var env = OrtEnv.Instance(); +env.RegisterExecutionProviderLibrary("webgpu_ep", WebGpuEp.GetLibraryPath()); + +// Find the WebGPU EP device +OrtEpDevice? 
webGpuDevice = null; +foreach (var d in env.GetEpDevices()) +{ + if (d.EpName == WebGpuEp.GetEpName()) + { + webGpuDevice = d; + break; + } +} + +// Create a session with the WebGPU EP +using var sessionOptions = new SessionOptions(); +sessionOptions.AppendExecutionProvider(env, new[] { webGpuDevice }, new Dictionary<string, string>()); + +using var session = new InferenceSession("model.onnx", sessionOptions); +``` + +### Supported Platforms + +| Runtime Identifier | Native Library | +|---|---| +| win-x64 | `onnxruntime_providers_webgpu.dll`, `dxil.dll`, `dxcompiler.dll` | +| win-arm64 | `onnxruntime_providers_webgpu.dll`, `dxil.dll`, `dxcompiler.dll` | +| linux-x64 | `libonnxruntime_providers_webgpu.so` | +| osx-arm64 | `libonnxruntime_providers_webgpu.dylib` | diff --git a/plugin-ep-webgpu/csharp/Microsoft.ML.OnnxRuntime.EP.WebGpu/WebGpuEp.cs b/plugin-ep-webgpu/csharp/Microsoft.ML.OnnxRuntime.EP.WebGpu/WebGpuEp.cs new file mode 100644 index 0000000000000..2a5ec106aad0d --- /dev/null +++ b/plugin-ep-webgpu/csharp/Microsoft.ML.OnnxRuntime.EP.WebGpu/WebGpuEp.cs @@ -0,0 +1,112 @@ +using System; +using System.IO; +using System.Runtime.InteropServices; + +namespace Microsoft.ML.OnnxRuntime.EP.WebGpu +{ + /// + /// Provides helper methods to locate the WebGPU plugin EP native library + /// and retrieve the EP name for registration with ONNX Runtime. + /// + public static class WebGpuEp + { + /// + /// Returns the path to the WebGPU plugin EP native library contained by this package. + /// Can be passed to OrtEnv.RegisterExecutionProviderLibrary(). + /// + /// Full path to the EP native library. + /// If the native library file does not exist at the expected path.
+ public static string GetLibraryPath() + { + string rootDir = GetNativeDirectory(); + string rid = GetRuntimeIdentifier(); + string libraryName = GetLibraryName(); + + // Probe the standard NuGet runtimes//native/ layout first, then fall back + // to the base directory for single-file/published layouts where native assets + // can land directly next to the managed assembly. + string[] candidates = + { + Path.Combine(rootDir, "runtimes", rid, "native", libraryName), + Path.Combine(rootDir, libraryName), + }; + + foreach (var candidate in candidates) + { + if (File.Exists(candidate)) + return Path.GetFullPath(candidate); + } + + throw new FileNotFoundException( + $"Did not find WebGPU EP library file. Probed: {string.Join(", ", candidates)}"); + } + + /// + /// Returns the names of the EPs created by the WebGPU plugin EP library. + /// Can be used to select an OrtEpDevice from those returned by OrtEnv.GetEpDevices(). + /// + /// Array of EP names. + public static string[] GetEpNames() + { + return new[] { GetEpName() }; + } + + /// + /// Returns the name of the one EP supported by this plugin EP library. + /// Convenience method for plugin EP packages that expose a single EP. + /// + /// The EP name string. 
+ public static string GetEpName() + { + return "WebGpuExecutionProvider"; + } + + private static string GetNativeDirectory() + { + var assemblyDir = Path.GetDirectoryName(typeof(WebGpuEp).Assembly.Location); + + if (!string.IsNullOrEmpty(assemblyDir) && Directory.Exists(assemblyDir)) + return assemblyDir; + + return AppContext.BaseDirectory; + } + + private static string GetRuntimeIdentifier() + { + return GetOSTag() + "-" + GetArchTag(); + } + + private static string GetLibraryName() + { + if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows)) + return "onnxruntime_providers_webgpu.dll"; + if (RuntimeInformation.IsOSPlatform(OSPlatform.Linux)) + return "libonnxruntime_providers_webgpu.so"; + if (RuntimeInformation.IsOSPlatform(OSPlatform.OSX)) + return "libonnxruntime_providers_webgpu.dylib"; + + throw new PlatformNotSupportedException( + $"WebGPU plugin EP does not support OS platform: {RuntimeInformation.OSDescription}"); + } + + private static string GetOSTag() + { + if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows)) return "win"; + if (RuntimeInformation.IsOSPlatform(OSPlatform.Linux)) return "linux"; + if (RuntimeInformation.IsOSPlatform(OSPlatform.OSX)) return "osx"; + throw new PlatformNotSupportedException( + $"WebGPU plugin EP does not support OS platform: {RuntimeInformation.OSDescription}"); + } + + private static string GetArchTag() + { + return RuntimeInformation.ProcessArchitecture switch + { + Architecture.X64 => "x64", + Architecture.Arm64 => "arm64", + _ => throw new PlatformNotSupportedException( + $"WebGPU plugin EP does not support process architecture: {RuntimeInformation.ProcessArchitecture}"), + }; + } + } +} diff --git a/plugin-ep-webgpu/csharp/README.md b/plugin-ep-webgpu/csharp/README.md new file mode 100644 index 0000000000000..7a2b2041e364f --- /dev/null +++ b/plugin-ep-webgpu/csharp/README.md @@ -0,0 +1,140 @@ +# WebGPU Plugin EP — NuGet Packaging + +This directory contains the C# NuGet package project and test app for the 
WebGPU plugin Execution Provider. + +## Directory Structure + +``` +csharp/ +├── pack_nuget.py # Helper script to build the NuGet package +├── Microsoft.ML.OnnxRuntime.EP.WebGpu/ +│ ├── Microsoft.ML.OnnxRuntime.EP.WebGpu.csproj # NuGet package project (netstandard2.0) +│ ├── WebGpuEp.cs # Helper class for native library resolution +│ └── README.md # Package readme (shipped inside .nupkg) +└── test/ + └── WebGpuEpNuGetTest/ + ├── WebGpuEpNuGetTest.csproj # Test console app (net8.0) + ├── Program.cs # Registers EP, runs inference, validates output + ├── mul.onnx # Test model (element-wise multiply) + └── generate_mul_model.py # Script to regenerate mul.onnx +``` + +## Prerequisites + +- .NET SDK 8.0 or later +- A built WebGPU plugin EP shared library + +## Building the NuGet Package + +Use `pack_nuget.py` to stage native binaries and run `dotnet pack`. The script copies everything into a staging +directory before building — the source tree is never modified. By default, an auto-cleaned temporary directory is used; +pass `--staging-dir` to use an explicit one (required when running with `--build-only` or `--pack-only`). + +At least one binary directory (or `--artifacts-dir` with matching subdirectories) must be provided. Platforms without +a binary directory are skipped. Run `python pack_nuget.py --help` for the full list of options and their defaults. + +### Pack with a local build (single platform) + +```powershell +cd plugin-ep-webgpu/csharp + +python pack_nuget.py --version 0.1.0-dev ` + --binary-dir-win-x64 +``` + +### Pack multiple platforms + +Each `--binary-dir-*` points at the directory containing that platform's already-built native binaries. In practice +the four binaries are produced on different machines and combined in CI; locally you'd typically only set the one(s) +you have available. 
+ +```powershell +python pack_nuget.py --version 0.1.0-dev ` + --binary-dir-win-x64 ` + --binary-dir-win-arm64 ` + --binary-dir-linux-x64 ` + --binary-dir-macos-arm64 +``` + +## Versioning + +The package version is supplied to `pack_nuget.py` via `--version`. In the packaging pipeline, the release or +pre-release version is derived from [`plugin-ep-webgpu/VERSION_NUMBER`](../VERSION_NUMBER). + +## Inspecting the Package + +The `.nupkg` is a ZIP file. To verify its contents: + +```powershell +Expand-Archive nuget_output/Microsoft.ML.OnnxRuntime.EP.WebGpu.0.1.0-dev.nupkg ` + -DestinationPath nuget_output/inspect -Force + +Get-ChildItem nuget_output/inspect -Recurse | Select-Object FullName +``` + +Expected layout inside the package: + +``` +lib/netstandard2.0/Microsoft.ML.OnnxRuntime.EP.WebGpu.dll +runtimes/win-x64/native/onnxruntime_providers_webgpu.dll +runtimes/win-x64/native/dxil.dll +runtimes/win-x64/native/dxcompiler.dll +runtimes/win-arm64/native/... +runtimes/linux-x64/native/libonnxruntime_providers_webgpu.so +runtimes/osx-arm64/native/libonnxruntime_providers_webgpu.dylib +``` + +## Testing the Package + +The test app registers the WebGPU EP, creates a session, runs a simple Mul model, and validates the output. + +```powershell +# Point the test project's nuget.config at the pack output +$localFeed = (Resolve-Path nuget_output).Path +@" + + + + + + + + +"@ | Set-Content test/WebGpuEpNuGetTest/nuget.config + +# Build and run +dotnet run --project test/WebGpuEpNuGetTest/WebGpuEpNuGetTest.csproj --configuration Release +``` + +A successful run prints `PASSED: All outputs match expected values.` and exits with code 0. + +## Regenerating the Test Model + +```bash +python test/WebGpuEpNuGetTest/generate_mul_model.py +``` + +Requires the `onnx` Python package. 
+ +## CI Pipeline + +The NuGet packaging is integrated into the WebGPU plugin pipeline: + +- **Pipeline:** `tools/ci_build/github/azure-pipelines/plugin-webgpu-pipeline.yml` +- **Packaging stage:** `tools/ci_build/github/azure-pipelines/stages/plugin-webgpu-nuget-packaging-stage.yml` + +The CI stage downloads build artifacts from all enabled platform stages, invokes `pack_nuget.py`, ESRP-signs the +package, and runs the test app on a GPU agent. + +## Native Binaries Per Platform + +| RID | Required Files | +|---|---| +| `win-x64` | `onnxruntime_providers_webgpu.dll`, `dxil.dll`, `dxcompiler.dll` | +| `win-arm64` | `onnxruntime_providers_webgpu.dll`, `dxil.dll`, `dxcompiler.dll` | +| `linux-x64` | `libonnxruntime_providers_webgpu.so` | +| `osx-arm64` | `libonnxruntime_providers_webgpu.dylib` | + +On Windows, `dxil.dll` and `dxcompiler.dll` are the DirectX Shader Compiler binaries downloaded from the +[DXC GitHub releases](https://github.com/microsoft/DirectXShaderCompiler/releases). The CI pipeline handles this +automatically. diff --git a/plugin-ep-webgpu/csharp/pack_nuget.py b/plugin-ep-webgpu/csharp/pack_nuget.py new file mode 100644 index 0000000000000..9a29d067a4034 --- /dev/null +++ b/plugin-ep-webgpu/csharp/pack_nuget.py @@ -0,0 +1,336 @@ +#!/usr/bin/env python3 +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +"""Build the Microsoft.ML.OnnxRuntime.EP.WebGpu NuGet package. + +Stages native binaries from build artifacts into the runtimes/ layout expected +by the .csproj and runs `dotnet pack` to produce the .nupkg / .snupkg files. + +Can be invoked locally or from CI. In CI, pass --artifacts-dir to point at the +downloaded pipeline artifacts. Locally, pass individual --binary-dir-* options. 
+ +Examples +-------- +Local: pack win-x64 only from a local build: + + python pack_nuget.py --version 0.1.0-dev \\ + --binary-dir-win-x64 ../../build/webgpu.plugin/Release/Release + +CI: pack all platforms from downloaded artifacts: + + python pack_nuget.py --version $(PluginPackageVersion) \\ + --artifacts-dir $(Build.BinariesDirectory)/artifacts \\ + --output-dir $(Build.ArtifactStagingDirectory)/nuget +""" + +from __future__ import annotations + +import argparse +import shutil +import subprocess +import sys +import tempfile +from pathlib import Path + +# Platform name -> (RID, list of native binary filenames expected in the source dir). +PLATFORMS: dict[str, tuple[str, tuple[str, ...]]] = { + "win_x64": ("win-x64", ("onnxruntime_providers_webgpu.dll", "dxil.dll", "dxcompiler.dll")), + "win_arm64": ("win-arm64", ("onnxruntime_providers_webgpu.dll", "dxil.dll", "dxcompiler.dll")), + "linux_x64": ("linux-x64", ("libonnxruntime_providers_webgpu.so",)), + "macos_arm64": ("osx-arm64", ("libonnxruntime_providers_webgpu.dylib",)), +} + +SCRIPT_DIR = Path(__file__).resolve().parent +PROJECT_DIR = SCRIPT_DIR / "Microsoft.ML.OnnxRuntime.EP.WebGpu" +CSPROJ = PROJECT_DIR / "Microsoft.ML.OnnxRuntime.EP.WebGpu.csproj" +MIN_ORT_VERSION_FILE = SCRIPT_DIR.parent / "MIN_ONNXRUNTIME_VERSION" + + +class PackError(RuntimeError): + """Raised for any user-actionable failure during packaging.""" + + +def parse_args() -> argparse.Namespace: + def _absolute_path(value: str) -> Path: + """argparse `type` converter: parse a string as an absolute Path.""" + return Path(value).resolve() + + p = argparse.ArgumentParser( + description="Build the Microsoft.ML.OnnxRuntime.EP.WebGpu NuGet package.", + formatter_class=argparse.RawDescriptionHelpFormatter, + ) + p.add_argument("--version", required=True, help="Package version (e.g. 
0.1.0-dev).") + p.add_argument( + "--output-dir", + type=_absolute_path, + default=(SCRIPT_DIR / "nuget_output").resolve(), + help="Directory for the .nupkg / .snupkg output (default: ./nuget_output).", + ) + p.add_argument("--configuration", default="Release", help="Build configuration (default: Release).") + + # CI mode: a single root containing per-platform subdirectories. + p.add_argument( + "--artifacts-dir", + type=_absolute_path, + help="CI mode: root containing <platform>/bin/ subdirectories for each platform.", + ) + + # Local mode: explicit per-platform binary directories. Each takes precedence over + # --artifacts-dir for that platform. + for name in PLATFORMS: + flag = f"--binary-dir-{name.replace('_', '-')}" + p.add_argument(flag, type=_absolute_path, dest=f"binary_dir_{name}", help=f"Path to {name} native binaries.") + + p.add_argument( + "--nuget-config", type=_absolute_path, help="Optional NuGet.config passed to dotnet via --configfile." + ) + p.add_argument( + "--staging-dir", + type=_absolute_path, + help=( + "Explicit staging directory. Required with --build-only / --pack-only " + "(caller owns its lifecycle). When omitted, an auto-cleaned temporary " + "directory is used for the full build+pack flow." + ), + ) + + phase = p.add_mutually_exclusive_group() + phase.add_argument( + "--build-only", + action="store_true", + help="Stage and build the managed DLL only; skip dotnet pack. Preserves the staging dir.", + ) + phase.add_argument( + "--pack-only", + action="store_true", + help="Skip staging/build and run dotnet pack against an existing staging directory.", + ) + + p.add_argument( + "--required-platforms", + default="", + help=( + "Comma-separated list of platforms that MUST be staged successfully. " + "When omitted, the script just requires at least one platform to be staged."
+ ), + ) + + return p.parse_args() + + +def parse_required_platforms(value: str) -> list[str]: + names = [tok.strip() for tok in value.split(",") if tok.strip()] + invalid = [n for n in names if n not in PLATFORMS] + if invalid: + raise PackError( + f"unknown platform(s) in --required-platforms: {', '.join(invalid)}. valid: {', '.join(PLATFORMS)}." + ) + return names + + +def stage_sources(staging_dir: Path) -> None: + """Copy project sources into staging, excluding bin/obj.""" + print(f"Staging project files to {staging_dir}") + if staging_dir.exists(): + shutil.rmtree(staging_dir) + shutil.copytree( + PROJECT_DIR, + staging_dir, + ignore=shutil.ignore_patterns("bin", "obj"), + ) + + +def resolve_platform_source( + name: str, + binary_dir_override: Path | None, + artifacts_dir: Path | None, + is_required: bool, +) -> Path | None: + """Return the source dir for a platform, or None to skip.""" + if binary_dir_override is not None: + return binary_dir_override + if artifacts_dir is not None: + candidate = artifacts_dir / name / "bin" + if candidate.is_dir(): + return candidate + if is_required: + raise PackError(f"required platform '{name}' artifact directory not found: {candidate}") + if is_required: + raise PackError( + f"required platform '{name}' has no binary directory " + f"(pass --binary-dir-{name.replace('_', '-')} or --artifacts-dir)." 
+ ) + return None + + +def stage_binaries( + staging_dir: Path, + args: argparse.Namespace, + required_platforms: list[str], +) -> None: + staged: set[str] = set() + + for name, (rid, files) in PLATFORMS.items(): + binary_dir_override: Path | None = getattr(args, f"binary_dir_{name}") + is_required = name in required_platforms + source_dir = resolve_platform_source(name, binary_dir_override, args.artifacts_dir, is_required) + if source_dir is None: + print(f"Skipping {name} (no binary directory provided)") + continue + if not source_dir.is_dir(): + raise PackError(f"binary directory does not exist: {source_dir}") + + target_dir = staging_dir / "runtimes" / rid / "native" + target_dir.mkdir(parents=True, exist_ok=True) + + print(f"Staging {name} -> runtimes/{rid}/native/") + for filename in files: + src = source_dir / filename + if not src.is_file(): + raise PackError(f"expected binary not found: {src}") + shutil.copy2(src, target_dir / filename) + print(f" {filename}") + staged.add(name) + + if required_platforms: + missing = [n for n in required_platforms if n not in staged] + if missing: + raise PackError(f"required platforms not staged: {', '.join(missing)}") + elif not staged: + raise PackError("no platform binaries were staged. 
Provide at least one --binary-dir-* or --artifacts-dir.") + + print() + print("Runtimes layout:") + for path in sorted((staging_dir / "runtimes").rglob("*")): + print(f" {path}") + + +def dotnet_common_args( + staged_csproj: Path, + args: argparse.Namespace, + min_ort_version_file: Path, +) -> list[str]: + common = [ + str(staged_csproj), + "--configuration", + args.configuration, + f"-p:Version={args.version}", + f"-p:OnnxRuntimeMinVersionFile={min_ort_version_file}", + ] + if args.nuget_config: + common.extend(["--configfile", str(args.nuget_config)]) + print(f"Using NuGet.config: {args.nuget_config}") + return common + + +def do_build(staged_csproj: Path, staging_dir: Path, args: argparse.Namespace, min_ort_version_file: Path) -> None: + print() + print(f"Running dotnet build (Version={args.version}, Configuration={args.configuration})...") + cmd = ["dotnet", "build", *dotnet_common_args(staged_csproj, args, min_ort_version_file)] + print("+ " + " ".join(cmd)) + subprocess.run(cmd, check=True) + + # Note: "netstandard2.0" must match the target framework declared in Microsoft.ML.OnnxRuntime.EP.WebGpu.csproj.
+ managed_dll = staging_dir / "bin" / args.configuration / "netstandard2.0" / "Microsoft.ML.OnnxRuntime.EP.WebGpu.dll" + if not managed_dll.is_file(): + raise PackError(f"managed DLL not found after build: {managed_dll}") + print() + print(f"Built managed DLL: {managed_dll}") + print("Staging directory preserved for subsequent --pack-only invocation.") + + +def do_pack( + staged_csproj: Path, + output_dir: Path, + args: argparse.Namespace, + min_ort_version_file: Path, +) -> None: + print() + print(f"Running dotnet pack (Version={args.version}, Configuration={args.configuration})...") + pack_args = [ + "dotnet", + "pack", + *dotnet_common_args(staged_csproj, args, min_ort_version_file), + "--output", + str(output_dir), + ] + if args.pack_only: + pack_args.append("--no-build") + print("+ " + " ".join(pack_args)) + subprocess.run(pack_args, check=True) + + print() + nupkgs = sorted(output_dir.glob("*.nupkg")) + if not nupkgs: + raise PackError(f"no .nupkg files found in {output_dir}") + for pkg in nupkgs: + print(f"Produced: {pkg.name} ({pkg.stat().st_size / (1024 * 1024):.2f} MB)") + for pkg in sorted(output_dir.glob("*.snupkg")): + print(f"Produced: {pkg.name} ({pkg.stat().st_size / (1024 * 1024):.2f} MB)") + + +def run_in_staging(args: argparse.Namespace, staging_dir: Path, min_ort_version_file: Path) -> None: + staged_csproj = staging_dir / "Microsoft.ML.OnnxRuntime.EP.WebGpu.csproj" + output_dir: Path = args.output_dir + output_dir.mkdir(parents=True, exist_ok=True) + required_platforms = parse_required_platforms(args.required_platforms) + + if args.pack_only: + if not staged_csproj.is_file(): + raise PackError(f"staged project not found at {staged_csproj}. 
Run with --build-only first.") + print(f"Reusing existing staging directory: {staging_dir}") + else: + stage_sources(staging_dir) + stage_binaries(staging_dir, args, required_platforms) + + if args.build_only: + do_build(staged_csproj, staging_dir, args, min_ort_version_file) + return + + do_pack(staged_csproj, output_dir, args, min_ort_version_file) + + print() + print(f"Done. Output: {output_dir}") + + +def run(args: argparse.Namespace) -> None: + if not CSPROJ.is_file(): + raise PackError(f"project file not found: {CSPROJ}") + if not MIN_ORT_VERSION_FILE.is_file(): + raise PackError(f"MIN_ONNXRUNTIME_VERSION file not found: {MIN_ORT_VERSION_FILE}") + if args.nuget_config and not args.nuget_config.is_file(): + raise PackError(f"NuGet.config not found: {args.nuget_config}") + + if (args.build_only or args.pack_only) and not args.staging_dir: + raise PackError("--staging-dir is required when using --build-only or --pack-only.") + + min_ort_version_file = MIN_ORT_VERSION_FILE.resolve() + + if args.staging_dir: + staging_dir: Path = args.staging_dir + staging_dir.mkdir(parents=True, exist_ok=True) + run_in_staging(args, staging_dir, min_ort_version_file) + return + + # Full build+pack flow with no caller-managed staging dir: use a temp dir that + # is cleaned up automatically (including on exception). 
+ with tempfile.TemporaryDirectory(prefix="webgpu_pack_") as tmp: + run_in_staging(args, Path(tmp), min_ort_version_file) + + +def main() -> int: + args = parse_args() + try: + run(args) + except PackError as e: + print(f"error: {e}", file=sys.stderr) + return 1 + except subprocess.CalledProcessError as e: + cmd_name = e.cmd[0] if e.cmd else "subprocess" + print(f"error: {cmd_name} failed with exit code {e.returncode}", file=sys.stderr) + return e.returncode or 1 + return 0 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/plugin-ep-webgpu/csharp/test/WebGpuEpNuGetTest/Program.cs b/plugin-ep-webgpu/csharp/test/WebGpuEpNuGetTest/Program.cs new file mode 100644 index 0000000000000..f5d1f0628c831 --- /dev/null +++ b/plugin-ep-webgpu/csharp/test/WebGpuEpNuGetTest/Program.cs @@ -0,0 +1,82 @@ +using Microsoft.ML.OnnxRuntime; +using Microsoft.ML.OnnxRuntime.EP.WebGpu; + +class Program +{ + static int Main() + { + string epLibPath = WebGpuEp.GetLibraryPath(); + string epRegistrationName = "webgpu_ep_registration"; + string epName = WebGpuEp.GetEpName(); + + Console.WriteLine($"WebGPU EP library path: {epLibPath}"); + + var env = OrtEnv.Instance(); + env.RegisterExecutionProviderLibrary(epRegistrationName, epLibPath); + Console.WriteLine($"Registered EP library: {epLibPath}"); + + try + { + // Find the OrtEpDevice for the WebGPU EP + OrtEpDevice? 
epDevice = null; + foreach (var d in env.GetEpDevices()) + { + if (string.Equals(epName, d.EpName, StringComparison.Ordinal)) + { + epDevice = d; + break; + } + } + + if (epDevice == null) + { + Console.Error.WriteLine($"ERROR: Unable to find OrtEpDevice with name '{epName}'"); + return 1; + } + Console.WriteLine($"Found OrtEpDevice for EP: {epName}"); + + // Create session with WebGPU EP + using var sessionOptions = new SessionOptions(); + sessionOptions.AppendExecutionProvider(env, new[] { epDevice }, new Dictionary<string, string>()); + sessionOptions.AddSessionConfigEntry("session.disable_cpu_ep_fallback", "1"); + + string inputModelPath = Path.Combine(AppContext.BaseDirectory, "mul.onnx"); + Console.WriteLine($"Loading model: {inputModelPath}"); + + using var session = new InferenceSession(inputModelPath, sessionOptions); + + // Run model: mul(x, y) = x * y + float[] inputData = { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f }; + using var inputOrtValue = OrtValue.CreateTensorValueFromMemory(inputData, new long[] { 2, 3 }); + var inputValues = new List<OrtValue> { inputOrtValue, inputOrtValue }.AsReadOnly(); + var inputNames = new List<string> { "x", "y" }.AsReadOnly(); + using var runOptions = new RunOptions(); + + using var outputs = session.Run(runOptions, inputNames, inputValues, session.OutputNames); + + float[] expected = { 1.0f, 4.0f, 9.0f, 16.0f, 25.0f, 36.0f }; + var actual = outputs[0].GetTensorDataAsSpan<float>().ToArray(); + + Console.WriteLine($"Input: {string.Join(", ", inputData)}"); + Console.WriteLine($"Output: {string.Join(", ", actual)}"); + Console.WriteLine($"Expected: {string.Join(", ", expected)}"); + + // Validate output + for (int i = 0; i < expected.Length; i++) + { + if (Math.Abs(actual[i] - expected[i]) > 1e-5f) + { + Console.Error.WriteLine($"ERROR: Output mismatch at index {i}: expected {expected[i]}, got {actual[i]}"); + return 1; + } + } + + Console.WriteLine("PASSED: All outputs match expected values."); + return 0; + } + finally + { +
env.UnregisterExecutionProviderLibrary(epRegistrationName); + } + } +} diff --git a/plugin-ep-webgpu/csharp/test/WebGpuEpNuGetTest/WebGpuEpNuGetTest.csproj b/plugin-ep-webgpu/csharp/test/WebGpuEpNuGetTest/WebGpuEpNuGetTest.csproj new file mode 100644 index 0000000000000..9554161b1e978 --- /dev/null +++ b/plugin-ep-webgpu/csharp/test/WebGpuEpNuGetTest/WebGpuEpNuGetTest.csproj @@ -0,0 +1,34 @@ + + + + Exe + net8.0 + latest + enable + enable + + *-* + + + + + + + + + + PreserveNewest + + + + + + + diff --git a/plugin-ep-webgpu/csharp/test/WebGpuEpNuGetTest/generate_mul_model.py b/plugin-ep-webgpu/csharp/test/WebGpuEpNuGetTest/generate_mul_model.py new file mode 100644 index 0000000000000..c64b4b7ec96bc --- /dev/null +++ b/plugin-ep-webgpu/csharp/test/WebGpuEpNuGetTest/generate_mul_model.py @@ -0,0 +1,25 @@ +"""Generate a simple Mul ONNX model for testing. + +Produces mul.onnx in the same directory as this script. +The model computes z = x * y (element-wise) for float32 tensors of shape [2, 3]. 
+""" + +import os + +from onnx import TensorProto, checker, helper, save + +X = helper.make_tensor_value_info("x", TensorProto.FLOAT, [2, 3]) +Y = helper.make_tensor_value_info("y", TensorProto.FLOAT, [2, 3]) +Z = helper.make_tensor_value_info("z", TensorProto.FLOAT, [2, 3]) + +mul_node = helper.make_node("Mul", inputs=["x", "y"], outputs=["z"]) + +graph = helper.make_graph([mul_node], "mul_graph", [X, Y], [Z]) +model = helper.make_model(graph, producer_name="onnxruntime-webgpu-ep-test") +model.opset_import[0].version = 13 + +checker.check_model(model) + +output_path = os.path.join(os.path.dirname(__file__), "mul.onnx") +save(model, output_path) +print(f"Saved {output_path}") diff --git a/plugin-ep-webgpu/csharp/test/WebGpuEpNuGetTest/mul.onnx b/plugin-ep-webgpu/csharp/test/WebGpuEpNuGetTest/mul.onnx new file mode 100644 index 0000000000000..6df01feb5cf58 --- /dev/null +++ b/plugin-ep-webgpu/csharp/test/WebGpuEpNuGetTest/mul.onnx @@ -0,0 +1,16 @@ + onnxruntime-webgpu-ep-test:Z + +x +yz"Mul mul_graphZ +x +  + +Z +y +  + +b +z +  + +B \ No newline at end of file diff --git a/plugin-ep-webgpu/python/README.md b/plugin-ep-webgpu/python/README.md index ac14a84a70f48..849105a439396 100644 --- a/plugin-ep-webgpu/python/README.md +++ b/plugin-ep-webgpu/python/README.md @@ -19,19 +19,13 @@ Wheels are built via `build_wheel.py`. Running `pip install` or `pip wheel` dire supported — the source tree contains `pyproject.toml.in` (a template), not a real `pyproject.toml`. ```bash -python build_wheel.py \ - --binary_dir \ - --version \ - --output_dir +python build_wheel.py --binary_dir --version --output_dir ``` Example: ```bash -python build_wheel.py \ - --binary_dir ./build/Release \ - --version 0.1.0.dev20260429 \ - --output_dir ./dist +python build_wheel.py --binary_dir ./build/Release --version 0.1.0.devYYYYMMDD --output_dir ./dist ``` The script combines the pre-built plugin EP binaries with the package source to produce a platform-specific wheel. 
@@ -44,7 +38,7 @@ Install the wheel and dependencies in a clean environment, then run the smoke te python -m venv test_venv source test_venv/bin/activate # or test_venv\Scripts\Activate.ps1 on Windows pip install onnx numpy -pip install dist/onnxruntime_ep_webgpu-*.whl # pulls in onnxruntime>=1.24.4 +pip install dist/onnxruntime_ep_webgpu-*.whl # pulls in the minimum compatible onnxruntime python test/test_webgpu_plugin_ep.py ``` diff --git a/plugin-ep-webgpu/python/build_wheel.py b/plugin-ep-webgpu/python/build_wheel.py index 8f855a5d2179b..b4357bcdfbe0f 100644 --- a/plugin-ep-webgpu/python/build_wheel.py +++ b/plugin-ep-webgpu/python/build_wheel.py @@ -86,6 +86,7 @@ def prepare_staging_dir(staging_dir: Path, binary_dir: Path, version: str): shutil.copytree(SCRIPT_DIR / "onnxruntime_ep_webgpu", staging_dir / "onnxruntime_ep_webgpu") # Copy plugin binaries into the package directory + # Note: The binaries are assumed to be directly under `binary_dir`. package_dir = staging_dir / "onnxruntime_ep_webgpu" copied = [] for pattern in BINARY_PATTERNS: diff --git a/tools/ci_build/github/azure-pipelines/c-api-noopenmp-test-pipelines.yml b/tools/ci_build/github/azure-pipelines/c-api-noopenmp-test-pipelines.yml index 5ddac928b32d3..ba57a4b2c85c9 100644 --- a/tools/ci_build/github/azure-pipelines/c-api-noopenmp-test-pipelines.yml +++ b/tools/ci_build/github/azure-pipelines/c-api-noopenmp-test-pipelines.yml @@ -69,6 +69,8 @@ stages: - stage: Android_Java_API_AAR_Testing_Full dependsOn: Setup + variables: + ReleaseVersionSuffix: $[ stageDependencies.Setup.Restore_And_Use_Variables.outputs['Set_Release_Version_Suffix.ReleaseVersionSuffix'] ] jobs: - template: templates/android-java-api-aar-test.yml parameters: @@ -77,6 +79,8 @@ stages: - stage: Final_AAR_Testing_Android_QNN dependsOn: Setup + variables: + ReleaseVersionSuffix: $[ stageDependencies.Setup.Restore_And_Use_Variables.outputs['Set_Release_Version_Suffix.ReleaseVersionSuffix'] ] jobs: - template: 
templates/android-java-api-aar-test.yml parameters: @@ -84,6 +88,7 @@ stages: packageName: 'onnxruntime-android-qnn' #TODO: get this information from the setup stage QnnSDKVersion: '2.42.0.251225' + ReleaseVersionSuffix: $(ReleaseVersionSuffix) - template: nuget/templates/test_win.yml parameters: diff --git a/tools/ci_build/github/azure-pipelines/dml-nuget-packaging.yml b/tools/ci_build/github/azure-pipelines/dml-nuget-packaging.yml index 3cf28655c36e7..888a9142088ee 100644 --- a/tools/ci_build/github/azure-pipelines/dml-nuget-packaging.yml +++ b/tools/ci_build/github/azure-pipelines/dml-nuget-packaging.yml @@ -31,6 +31,9 @@ parameters: type: number default: 0 +variables: +- template: templates/common-variables.yml + extends: # The pipeline extends the 1ES PT which will inject different SDL and compliance tasks. # For non-production pipelines, use "Unofficial" as defined below. @@ -76,7 +79,7 @@ extends: DoEsrp: ${{ parameters.DoEsrp }} NuPackScript: | python -m pip install setuptools - msbuild $(Build.SourcesDirectory)\csharp\OnnxRuntime.CSharp.proj /p:Configuration=RelWithDebInfo /t:CreatePackage /p:OrtPackageId=Microsoft.ML.OnnxRuntime.DirectML /p:IsReleaseBuild=${{ parameters.IsReleaseBuild }} /p:CurrentData=$(BuildDate) /p:CurrentTime=$(BuildTime) + msbuild $(Build.SourcesDirectory)\csharp\OnnxRuntime.CSharp.proj /p:Configuration=RelWithDebInfo /t:CreatePackage /p:OrtPackageId=Microsoft.ML.OnnxRuntime.DirectML /p:IsReleaseBuild=${{ parameters.IsReleaseBuild }} /p:CurrentData=$(ORT_CI_BUILD_DATE) /p:CurrentTime=$(ORT_CI_BUILD_TIME) if errorlevel 1 exit /b 1 copy $(Build.SourcesDirectory)\csharp\src\Microsoft.ML.OnnxRuntime\bin\RelWithDebInfo\*.nupkg $(Build.ArtifactStagingDirectory) if errorlevel 1 exit /b 1 diff --git a/tools/ci_build/github/azure-pipelines/nuget/templates/dml-vs-2022.yml b/tools/ci_build/github/azure-pipelines/nuget/templates/dml-vs-2022.yml index 2548eebeb9d42..fa009c379a911 100644 --- 
a/tools/ci_build/github/azure-pipelines/nuget/templates/dml-vs-2022.yml +++ b/tools/ci_build/github/azure-pipelines/nuget/templates/dml-vs-2022.yml @@ -20,7 +20,7 @@ parameters: IsReleaseBuild: false stages: - stage: ${{ parameters.StageName }} - dependsOn: Setup + dependsOn: [] jobs: - job: ${{ parameters.StageName }} timeoutInMinutes: 200 @@ -39,8 +39,6 @@ stages: OnnxRuntimeBuildDirectory: '$(Build.BinariesDirectory)' DOTNET_SKIP_FIRST_TIME_EXPERIENCE: true ALLOW_RELEASED_ONNX_OPSET_ONLY: ${{ parameters.AllowReleasedOpsetOnly }} - BuildDate : $[stageDependencies.Setup.Set_Variables.outputs['Set_Build_Date.BuildDate']] - BuildTime : $[stageDependencies.Setup.Set_Variables.outputs['Set_Build_Time.BuildTime']] ${{ if eq(parameters.EnableLto, true) }}: build_py_lto_flag: --enable_lto diff --git a/tools/ci_build/github/azure-pipelines/plugin-cuda-pipeline.yml b/tools/ci_build/github/azure-pipelines/plugin-cuda-pipeline.yml index 5e183e057aee9..4385446c6b741 100644 --- a/tools/ci_build/github/azure-pipelines/plugin-cuda-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/plugin-cuda-pipeline.yml @@ -130,9 +130,11 @@ extends: version_file: ${{ variables.epVersionFile }} cmake_build_type: ${{ parameters.cmake_build_type }} ${{ if eq(parameters.cuda_version, '12.8') }}: + python_package_name: 'onnxruntime-ep-cuda12' docker_base_image: 'onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cuda12_x64_almalinux8_gcc14:20251017.1' cmake_cuda_archs: '52-real;61-real;75-real;86-real;89-real;90-virtual' ${{ if eq(parameters.cuda_version, '13.0') }}: + python_package_name: 'onnxruntime-ep-cuda13' docker_base_image: 'onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cuda13_x64_almalinux8_gcc14:20251107.1' docker_base_image_aarch64: 'onnxruntimebuildcache.azurecr.io/public/azureml/onnxruntime_build_cuda13_aarch64_almalinux9_gcc14:20260323.1' cmake_cuda_archs: '75-real;80-real;86-real;89-real;90-real;100-real;120-real;120-virtual' diff --git 
a/tools/ci_build/github/azure-pipelines/plugin-cuda-test-pipeline.yml b/tools/ci_build/github/azure-pipelines/plugin-cuda-test-pipeline.yml new file mode 100644 index 0000000000000..83273c8870408 --- /dev/null +++ b/tools/ci_build/github/azure-pipelines/plugin-cuda-test-pipeline.yml @@ -0,0 +1,98 @@ +# This pipeline runs tests against artifacts produced by the CUDA +# plugin packaging pipeline. It is resource-triggered on successful +# packaging runs and can also be queued manually against any prior +# packaging run. +# +# Split from the packaging pipeline so the test side can be iterated +# on without rebuilding the CUDA plugin from source. + +trigger: none + +variables: +- name: DisableDockerDetector + value: true +- name: skipNugetSecurityAnalysis + value: true +- name: Codeql.SkipTaskAutoInjection + value: true + +resources: + pipelines: + - pipeline: build + source: 'CUDA Plugin EP Packaging Pipeline' + trigger: true + repositories: + - repository: 1esPipelines + type: git + name: 1ESPipelineTemplates/1ESPipelineTemplates + ref: refs/tags/release + +parameters: +- name: test_windows_x64 + displayName: 'Test Windows x64' + type: boolean + default: true + +- name: test_linux_x64 + displayName: 'Test Linux x64' + type: boolean + default: true + +- name: cuda_version + displayName: 'CUDA Version' + type: string + default: '12.8' + values: + - '12.8' + - '13.0' + +extends: + # The pipeline extends the 1ES PT which will inject SDL and compliance + # tasks. Uses "Official" to stay consistent with the companion + # CUDA plugin packaging pipeline. + template: v1/1ES.Official.PipelineTemplate.yml@1esPipelines + parameters: + settings: + networkIsolationPolicy: Permissive + sdl: + # No top-level `pool:` is declared for this pipeline (each stage + # template pins its own pool), so source analysis needs an + # explicit pool. 
+ sourceAnalysisPool: + name: onnxruntime-Win-CPU-VS2022-Latest + os: windows + componentgovernance: + ignoreDirectories: '$(Build.Repository.LocalPath)/cmake/external/emsdk/upstream/emscripten/tests,$(Build.Repository.LocalPath)/cmake/external/onnx/third_party/benchmark,$(Build.Repository.LocalPath)/cmake/external/onnx/third_party/pybind11,$(Build.Repository.LocalPath)/cmake/external/onnx/third_party/pybind11/tests,$(Build.Repository.LocalPath)/cmake/external/onnxruntime-extensions,$(Build.Repository.LocalPath)/js/react_native/e2e/node_modules,$(Build.Repository.LocalPath)/js/node_modules,$(Build.Repository.LocalPath)/onnxruntime-inference-examples,$(Build.SourcesDirectory)/cmake/external/emsdk/upstream/emscripten/tests,$(Build.SourcesDirectory)/cmake/external/onnx/third_party/benchmark,$(Build.SourcesDirectory)/cmake/external/onnx/third_party/pybind11,$(Build.SourcesDirectory)/cmake/external/onnx/third_party/pybind11/tests,$(Build.SourcesDirectory)/cmake/external/onnxruntime-extensions,$(Build.SourcesDirectory)/js/react_native/e2e/node_modules,$(Build.SourcesDirectory)/js/node_modules,$(Build.SourcesDirectory)/onnxruntime-inference-examples,$(Build.BinariesDirectory)' + alertWarningLevel: High + failOnAlert: false + verbosity: Normal + timeout: 3600 + tsa: + enabled: true + # codeSignValidation is intentionally omitted: this pipeline does + # not produce or publish binaries. The wheels it consumes were + # already signed-and-validated by the packaging pipeline. 
+ policheck: + enabled: true + exclusionsFile: '$(Build.SourcesDirectory)\tools\ci_build\policheck_exclusions.xml' + codeql: + compiled: + enabled: false + justificationForDisabling: 'CodeQL is taking nearly 6 hours resulting in timeouts in our production pipelines' + + stages: + # Windows x64 + - ${{ if eq(parameters.test_windows_x64, true) }}: + - template: stages/plugin-win-cuda-test-stage.yml + parameters: + cuda_version: ${{ parameters.cuda_version }} + + # Linux x64 + - ${{ if eq(parameters.test_linux_x64, true) }}: + - template: stages/plugin-linux-cuda-test-stage.yml + parameters: + cuda_version: ${{ parameters.cuda_version }} + ${{ if eq(parameters.cuda_version, '12.8') }}: + docker_base_image: 'onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cuda12_x64_almalinux8_gcc14:20251017.1' + ${{ if eq(parameters.cuda_version, '13.0') }}: + docker_base_image: 'onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cuda13_x64_almalinux8_gcc14:20251107.1' diff --git a/tools/ci_build/github/azure-pipelines/plugin-webgpu-pipeline.yml b/tools/ci_build/github/azure-pipelines/plugin-webgpu-pipeline.yml index 7d9f7c24b3360..673452d8b110a 100644 --- a/tools/ci_build/github/azure-pipelines/plugin-webgpu-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/plugin-webgpu-pipeline.yml @@ -46,7 +46,7 @@ parameters: type: string values: - release - - RC + # - RC # not implemented yet - dev default: dev diff --git a/tools/ci_build/github/azure-pipelines/stages/nodejs-win-packaging-stage.yml b/tools/ci_build/github/azure-pipelines/stages/nodejs-win-packaging-stage.yml index b9f2cc0987816..1a0ebd783d552 100644 --- a/tools/ci_build/github/azure-pipelines/stages/nodejs-win-packaging-stage.yml +++ b/tools/ci_build/github/azure-pipelines/stages/nodejs-win-packaging-stage.yml @@ -17,7 +17,6 @@ parameters: stages: - stage: ${{ parameters.StageName }} dependsOn: - - Setup - ${{ if ne(parameters.DependsOnStageName, '') }}: - ${{ 
parameters.DependsOnStageName }} @@ -60,8 +59,6 @@ stages: runCodesignValidationInjection: ${{ parameters. DoEsrp}} #For the others, code sign is in a separated job DOTNET_SKIP_FIRST_TIME_EXPERIENCE: true ALLOW_RELEASED_ONNX_OPSET_ONLY: ${{ parameters.AllowReleasedOpsetOnly }} - BuildDate : $[stageDependencies.Setup.Set_Variables.outputs['Set_Build_Date.BuildDate']] - BuildTime : $[stageDependencies.Setup.Set_Variables.outputs['Set_Build_Time.BuildTime']] BuildCommandExtra: '' ${{ if eq(parameters.EnableLto, true) }}: build_py_lto_flag: --enable_lto @@ -179,4 +176,4 @@ stages: mkdir $(Build.ArtifactStagingDirectory)\${{ parameters.WebGpuBuildToolsArtifactName }} copy $(Build.BinariesDirectory)\$(BuildConfig)\_deps\dawn-build\third_party\dxc\RelWithDebInfo\bin\llvm-tblgen.exe $(Build.ArtifactStagingDirectory)\${{ parameters.WebGpuBuildToolsArtifactName }} copy $(Build.BinariesDirectory)\$(BuildConfig)\_deps\dawn-build\third_party\dxc\RelWithDebInfo\bin\clang-tblgen.exe $(Build.ArtifactStagingDirectory)\${{ parameters.WebGpuBuildToolsArtifactName }} - displayName: 'Copy WebGPU build tools' \ No newline at end of file + displayName: 'Copy WebGPU build tools' diff --git a/tools/ci_build/github/azure-pipelines/stages/nuget-cuda-packaging-stage.yml b/tools/ci_build/github/azure-pipelines/stages/nuget-cuda-packaging-stage.yml index 79bbe39ce4af2..20105b467d001 100644 --- a/tools/ci_build/github/azure-pipelines/stages/nuget-cuda-packaging-stage.yml +++ b/tools/ci_build/github/azure-pipelines/stages/nuget-cuda-packaging-stage.yml @@ -34,8 +34,6 @@ stages: variables: breakCodesignValidationInjection: ${{ parameters.DoEsrp }} ReleaseVersionSuffix: $[stageDependencies.Setup.Set_Variables.outputs['Set_Release_Version_Suffix.ReleaseVersionSuffix']] - BuildDate: $[format('{0:yyyyMMdd}', pipeline.startTime)] - BuildTime: $[format('{0:HHmm}', pipeline.startTime)] steps: - checkout: self @@ -134,8 +132,14 @@ stages: solution: 
'$(Build.SourcesDirectory)\csharp\OnnxRuntime.CSharp.proj' configuration: RelWithDebInfo platform: 'Any CPU' - msbuildArguments: '-t:CreatePackage -p:OnnxRuntimeBuildDirectory="$(Build.BinariesDirectory)" -p:OrtPackageId=Microsoft.ML.OnnxRuntime.Gpu -p:IsReleaseBuild=${{ parameters.IsReleaseBuild }} - -p:ReleaseVersionSuffix=$(ReleaseVersionSuffix) -p:CurrentDate=$(BuildDate) -p:CurrentTime=$(BuildTime)' + msbuildArguments: >- + -t:CreatePackage + "-p:OnnxRuntimeBuildDirectory=$(Build.BinariesDirectory)" + -p:OrtPackageId=Microsoft.ML.OnnxRuntime.Gpu + "-p:IsReleaseBuild=${{ parameters.IsReleaseBuild }}" + "-p:ReleaseVersionSuffix=$(ReleaseVersionSuffix)" + "-p:CurrentDate=$(ORT_CI_BUILD_DATE)" + "-p:CurrentTime=$(ORT_CI_BUILD_TIME)" workingDirectory: '$(Build.SourcesDirectory)\csharp' - task: BatchScript@1 diff --git a/tools/ci_build/github/azure-pipelines/stages/plugin-cuda-packaging-stage.yml b/tools/ci_build/github/azure-pipelines/stages/plugin-cuda-packaging-stage.yml index d18ede02d8891..3ce33f87ae276 100644 --- a/tools/ci_build/github/azure-pipelines/stages/plugin-cuda-packaging-stage.yml +++ b/tools/ci_build/github/azure-pipelines/stages/plugin-cuda-packaging-stage.yml @@ -66,6 +66,11 @@ parameters: displayName: 'Docker Python executable path' default: '/opt/python/cp312-cp312/bin/python3.12' +- name: python_package_name + type: string + displayName: 'Python package distribution name' + default: '' + stages: # Windows x64 - ${{ if eq(parameters.build_windows_x64, true) }}: @@ -75,6 +80,7 @@ stages: cmake_cuda_archs: ${{ parameters.cmake_cuda_archs }} package_version: ${{ parameters.package_type }} version_file: ${{ parameters.version_file }} + python_package_name: ${{ parameters.python_package_name }} cmake_build_type: ${{ parameters.cmake_build_type }} # Linux x64 @@ -83,11 +89,12 @@ stages: parameters: stage_name: Linux_plugin_cuda_x64 arch: 'x64' - machine_pool: 'onnxruntime-Ubuntu2204-AMD-CPU' + machine_pool: 'onnxruntime-Ubuntu2404-AMD-CPU' 
cuda_version: ${{ parameters.cuda_version }} cmake_cuda_archs: ${{ parameters.cmake_cuda_archs }} package_version: ${{ parameters.package_type }} version_file: ${{ parameters.version_file }} + python_package_name: ${{ parameters.python_package_name }} cmake_build_type: ${{ parameters.cmake_build_type }} docker_base_image: ${{ parameters.docker_base_image }} python_version: ${{ parameters.python_version }} @@ -105,6 +112,7 @@ stages: cmake_cuda_archs: ${{ parameters.cmake_cuda_archs }} package_version: ${{ parameters.package_type }} version_file: ${{ parameters.version_file }} + python_package_name: ${{ parameters.python_package_name }} cmake_build_type: ${{ parameters.cmake_build_type }} docker_base_image: ${{ parameters.docker_base_image_aarch64 }} python_version: ${{ parameters.python_version }} diff --git a/tools/ci_build/github/azure-pipelines/stages/plugin-linux-cuda-stage.yml b/tools/ci_build/github/azure-pipelines/stages/plugin-linux-cuda-stage.yml index 4c6c60e176a50..8992df31bf848 100644 --- a/tools/ci_build/github/azure-pipelines/stages/plugin-linux-cuda-stage.yml +++ b/tools/ci_build/github/azure-pipelines/stages/plugin-linux-cuda-stage.yml @@ -12,7 +12,7 @@ parameters: - name: machine_pool type: string - default: 'onnxruntime-Ubuntu2204-AMD-CPU' + default: 'onnxruntime-Ubuntu2404-AMD-CPU' - name: package_version type: string @@ -38,6 +38,9 @@ parameters: type: string default: '12.8' +- name: python_package_name + type: string + - name: cmake_cuda_archs type: string default: '52-real;61-real;75-real;86-real;89-real;90-virtual' @@ -81,20 +84,25 @@ stages: - template: ../templates/set-nightly-build-option-variable-step.yml + - template: ../templates/setup-feeds-and-python-steps.yml + parameters: + architecture: ${{ parameters.arch }} + - template: ../templates/set-plugin-build-variables-step.yml parameters: package_version: ${{ parameters.package_version }} version_file: ${{ parameters.version_file }} - - template: 
../templates/setup-feeds-and-python-steps.yml - parameters: - architecture: ${{ parameters.arch }} - - template: ../templates/get-docker-image-steps.yml parameters: Dockerfile: tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_cuda Context: tools/ci_build/github/linux/docker - DockerBuildArgs: "--build-arg BASEIMAGE=${{ parameters.docker_base_image }} --build-arg BUILD_UID=$( id -u ) --build-arg TRT_VERSION=" + DockerBuildArgs: >- + --network=host + --secret id=PIP_INDEX_URL + --build-arg BASEIMAGE=${{ parameters.docker_base_image }} + --build-arg TRT_VERSION= + --build-arg BUILD_UID=$( id -u ) Repository: onnxruntimecuda${{ replace(parameters.cuda_version, '.', '') }}pluginbuild${{ parameters.arch }} - script: >- @@ -147,6 +155,65 @@ stages: command: publish publishDirectory: '$(Build.BinariesDirectory)/universal_package' vstsFeedPublish: 'PublicPackages/ORT-Nightly' - vstsFeedPackagePublish: 'onnxruntime-plugin-ep-cuda${{ replace(parameters.cuda_version, '.', '') }}-linux-${{ parameters.arch }}' + vstsFeedPackagePublish: "onnxruntime-plugin-ep-cuda${{ replace(parameters.cuda_version, '.', '') }}-linux-${{ parameters.arch }}" versionOption: custom versionPublish: '$(PluginUniversalPackageVersion)' + + - ${{ if eq(parameters.arch, 'x64') }}: + - job: ${{ parameters.stage_name }}_Python_Package + dependsOn: ${{ parameters.stage_name }} + timeoutInMinutes: 60 + workspace: + clean: all + pool: + name: ${{ parameters.machine_pool }} + os: linux + templateContext: + outputs: + - output: pipelineArtifact + targetPath: $(Build.ArtifactStagingDirectory)/python + artifactName: cuda_plugin_python_linux_x64_cuda${{ replace(parameters.cuda_version, '.', '') }} + variables: + - template: ../templates/common-variables.yml + steps: + - checkout: self + clean: true + submodules: none + + - template: ../templates/set-nightly-build-option-variable-step.yml + + - template: ../templates/setup-feeds-and-python-steps.yml + parameters: + architecture: ${{ parameters.arch }} + + 
- template: ../templates/set-plugin-build-variables-step.yml + parameters: + package_version: ${{ parameters.package_version }} + version_file: ${{ parameters.version_file }} + + - template: ../templates/get-docker-image-steps.yml + parameters: + Dockerfile: tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_cuda + Context: tools/ci_build/github/linux/docker + DockerBuildArgs: >- + --network=host + --secret id=PIP_INDEX_URL + --build-arg BASEIMAGE=${{ parameters.docker_base_image }} + --build-arg TRT_VERSION= + --build-arg BUILD_UID=$( id -u ) + Repository: onnxruntimecuda${{ replace(parameters.cuda_version, '.', '') }}pluginbuild${{ parameters.arch }} + + - task: DownloadPipelineArtifact@2 + displayName: 'Download plugin build artifacts' + inputs: + artifactName: ${{ parameters.artifact_name }} + targetPath: '$(Build.BinariesDirectory)/plugin_artifacts' + + - script: | + set -e -x + $(Build.SourcesDirectory)/tools/ci_build/github/linux/build_cuda_plugin_python_package.sh \ + -i onnxruntimecuda${{ replace(parameters.cuda_version, '.', '') }}pluginbuild${{ parameters.arch }} \ + -p ${{ parameters.docker_python_exe_path }} \ + -v "$(PluginPythonPackageVersion)" \ + -n "${{ parameters.python_package_name }}" + displayName: 'Build Python wheel' diff --git a/tools/ci_build/github/azure-pipelines/stages/plugin-linux-cuda-test-stage.yml b/tools/ci_build/github/azure-pipelines/stages/plugin-linux-cuda-test-stage.yml new file mode 100644 index 0000000000000..391f002465d96 --- /dev/null +++ b/tools/ci_build/github/azure-pipelines/stages/plugin-linux-cuda-test-stage.yml @@ -0,0 +1,74 @@ +parameters: +- name: machine_pool + type: string + default: 'onnxruntime-Ubuntu2404-AMD-GPU-A10' + +- name: cuda_version + type: string + default: '12.8' + +- name: docker_base_image + type: string + default: 'onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cuda12_x64_almalinux8_gcc14:20251017.1' + +stages: +- stage: Linux_plugin_cuda_x64_Test + dependsOn: [] + 
jobs: + - job: Linux_plugin_cuda_x64_Python_Test + timeoutInMinutes: 60 + workspace: + clean: all + pool: + name: ${{ parameters.machine_pool }} + os: linux + steps: + - checkout: self + clean: true + submodules: none + + - template: ../templates/setup-feeds-and-python-steps.yml + + - template: ../templates/get-docker-image-steps.yml + parameters: + Dockerfile: tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_cuda + Context: tools/ci_build/github/linux/docker + DockerBuildArgs: "--build-arg BASEIMAGE=${{ parameters.docker_base_image }} --build-arg TRT_VERSION= --build-arg BUILD_UID=$( id -u )" + Repository: onnxruntimecuda${{ replace(parameters.cuda_version, '.', '') }}plugintestx64 + + # Download the Python wheel produced by the packaging pipeline run that + # triggered this pipeline (or that was selected at queue time). + - download: build + artifact: cuda_plugin_python_linux_x64_cuda${{ replace(parameters.cuda_version, '.', '') }} + displayName: 'Download Python wheel' + + - script: | + set -e -x + mkdir -p "$(Build.BinariesDirectory)/python_wheel" + cp -R "$(Pipeline.Workspace)/build/cuda_plugin_python_linux_x64_cuda${{ replace(parameters.cuda_version, '.', '') }}/"* "$(Build.BinariesDirectory)/python_wheel/" + displayName: 'Stage Python wheel for test container' + + - script: | + set -e -x + docker run --rm --gpus all \ + --volume "$(Build.SourcesDirectory):/onnxruntime_src" \ + --volume "$(Build.BinariesDirectory):/build" \ + --env "PIP_INDEX_URL=${PIP_INDEX_URL}" \ + --env "NVIDIA_VISIBLE_DEVICES=all" \ + --env "ORT_TEST_VERBOSE=$(System.Debug)" \ + onnxruntimecuda${{ replace(parameters.cuda_version, '.', '') }}plugintestx64 \ + /bin/bash -c " + set -e -x + python3 -m venv /build/test_venv + source /build/test_venv/bin/activate + python3 -m pip install onnxruntime onnx numpy + wheel=\$(find /build/python_wheel -name 'onnxruntime*ep*cuda*.whl' | head -1) + if [ -z \"\$wheel\" ]; then + echo 'ERROR: No matching wheel found in /build/python_wheel' + 
ls -la /build/python_wheel/ + exit 1 + fi + python3 -m pip install \"\$wheel\" + python3 -u /onnxruntime_src/plugin-ep-cuda/python/test/test_cuda_plugin_ep.py + " + displayName: 'Install and test Python package' diff --git a/tools/ci_build/github/azure-pipelines/stages/plugin-linux-webgpu-test-stage.yml b/tools/ci_build/github/azure-pipelines/stages/plugin-linux-webgpu-test-stage.yml index 9ce494d4b3a36..12ee9ca68bb4e 100644 --- a/tools/ci_build/github/azure-pipelines/stages/plugin-linux-webgpu-test-stage.yml +++ b/tools/ci_build/github/azure-pipelines/stages/plugin-linux-webgpu-test-stage.yml @@ -71,7 +71,7 @@ stages: set -e -x python3 -m venv /build/test_venv source /build/test_venv/bin/activate - python3 -m pip install onnxruntime onnx numpy + python3 -m pip install onnx numpy wheel=\$(find /build/python_wheel -name 'onnxruntime_ep_webgpu-*.whl' | head -1) python3 -m pip install \"\$wheel\" python3 -u /onnxruntime_src/plugin-ep-webgpu/python/test/test_webgpu_plugin_ep.py diff --git a/tools/ci_build/github/azure-pipelines/stages/plugin-mac-webgpu-test-stage.yml b/tools/ci_build/github/azure-pipelines/stages/plugin-mac-webgpu-test-stage.yml index 5ad4e170b2855..6dca5dd450fd0 100644 --- a/tools/ci_build/github/azure-pipelines/stages/plugin-mac-webgpu-test-stage.yml +++ b/tools/ci_build/github/azure-pipelines/stages/plugin-mac-webgpu-test-stage.yml @@ -30,7 +30,7 @@ stages: set -e -x python3 -m venv "$(Build.BinariesDirectory)/test_venv" source "$(Build.BinariesDirectory)/test_venv/bin/activate" - python3 -m pip install onnxruntime onnx numpy + python3 -m pip install onnx numpy wheel=$(find "$(Pipeline.Workspace)/build/webgpu_plugin_python_macos_arm64" -name "onnxruntime_ep_webgpu-*.whl" | head -1) python3 -m pip install "$wheel" python3 -u "$(Build.SourcesDirectory)/plugin-ep-webgpu/python/test/test_webgpu_plugin_ep.py" diff --git a/tools/ci_build/github/azure-pipelines/stages/plugin-webgpu-nuget-packaging-stage.yml 
b/tools/ci_build/github/azure-pipelines/stages/plugin-webgpu-nuget-packaging-stage.yml new file mode 100644 index 0000000000000..93210533d2dc0 --- /dev/null +++ b/tools/ci_build/github/azure-pipelines/stages/plugin-webgpu-nuget-packaging-stage.yml @@ -0,0 +1,186 @@ +# NuGet packaging stage for WebGPU plugin EP. +# Downloads platform-specific build artifacts, packs them into a single multi-platform NuGet package, +# signs it, and runs a basic test. + +parameters: +- name: package_version + type: string + +- name: version_file + type: string + +- name: DoEsrp + type: boolean + default: true + +- name: platforms + type: object + default: + win_x64: false + win_arm64: false + linux_x64: false + macos_arm64: false + +stages: +- stage: NuGet_Packaging + displayName: 'NuGet Packaging' + dependsOn: + - ${{ if eq(parameters.platforms.win_x64, true) }}: + - Win_plugin_webgpu_x64_Build + - ${{ if eq(parameters.platforms.win_arm64, true) }}: + - Win_plugin_webgpu_arm64_Build + - ${{ if eq(parameters.platforms.linux_x64, true) }}: + - Linux_plugin_webgpu_x64_Build + - ${{ if eq(parameters.platforms.macos_arm64, true) }}: + - MacOS_plugin_webgpu_arm64_Build + jobs: + # ---------- Pack job ---------- + - job: NuGet_Pack + displayName: 'Pack NuGet' + timeoutInMinutes: 30 + workspace: + clean: all + pool: + name: onnxruntime-Win-CPU-VS2022-Latest + os: windows + templateContext: + outputs: + - output: pipelineArtifact + targetPath: '$(Build.ArtifactStagingDirectory)\nuget' + artifactName: webgpu_plugin_nuget + variables: + - template: ../templates/common-variables.yml + - name: WebGpuPackStagingDir + value: '$(Build.BinariesDirectory)\webgpu_pack_staging' + # Common arguments shared by the Build and Pack invocations of pack_nuget.py. 
+ - name: WebGpuPackNuGetCommonArgs + value: >- + --version "$(PluginPackageVersion)" + --output-dir "$(Build.ArtifactStagingDirectory)\nuget" + --staging-dir "$(WebGpuPackStagingDir)" + --configuration Release + --nuget-config "$(Build.SourcesDirectory)\NuGet.config" + steps: + - checkout: self + clean: true + submodules: none + + - template: ../templates/setup-build-tools.yml + parameters: + host_cpu_arch: 'x64' + + - template: ../templates/set-nightly-build-option-variable-step.yml + + - template: ../templates/set-plugin-build-variables-step.yml + parameters: + package_version: ${{ parameters.package_version }} + version_file: ${{ parameters.version_file }} + + # Download platform artifacts + - ${{ if eq(parameters.platforms.win_x64, true) }}: + - task: DownloadPipelineArtifact@2 + displayName: 'Download win-x64 artifacts' + inputs: + artifactName: webgpu_plugin_win_x64 + targetPath: '$(Build.BinariesDirectory)\artifacts\win_x64' + + - ${{ if eq(parameters.platforms.win_arm64, true) }}: + - task: DownloadPipelineArtifact@2 + displayName: 'Download win-arm64 artifacts' + inputs: + artifactName: webgpu_plugin_win_arm64 + targetPath: '$(Build.BinariesDirectory)\artifacts\win_arm64' + + - ${{ if eq(parameters.platforms.linux_x64, true) }}: + - task: DownloadPipelineArtifact@2 + displayName: 'Download linux-x64 artifacts' + inputs: + artifactName: webgpu_plugin_linux_x64 + targetPath: '$(Build.BinariesDirectory)\artifacts\linux_x64' + + - ${{ if eq(parameters.platforms.macos_arm64, true) }}: + - task: DownloadPipelineArtifact@2 + displayName: 'Download macos-arm64 artifacts' + inputs: + artifactName: webgpu_plugin_macos_arm64 + targetPath: '$(Build.BinariesDirectory)\artifacts\macos_arm64' + + # Compute the set of required platforms from the pipeline parameters and verify the + # corresponding artifact directories actually downloaded. 
This catches renamed/moved + # upstream artifacts loudly before any pack work, and feeds pack_nuget.py the same + # list so it fails fast if any required platform's binaries are missing. + - task: PythonScript@0 + displayName: 'Compute required platforms' + inputs: + scriptSource: inline + script: | + import os + import sys + + # The string literals below are filled in by ADO template expansion at queue + # time and resolve to a boolean value 'True' or 'False'. Compare case-insensitively. + platforms_enabled = { + "win_x64": "${{ parameters.platforms.win_x64 }}".lower() == "true", + "win_arm64": "${{ parameters.platforms.win_arm64 }}".lower() == "true", + "linux_x64": "${{ parameters.platforms.linux_x64 }}".lower() == "true", + "macos_arm64": "${{ parameters.platforms.macos_arm64 }}".lower() == "true", + } + expected = [name for name, enabled in platforms_enabled.items() if enabled] + + if not expected: + print("##vso[task.logissue type=error]No platforms enabled in 'platforms' parameter — nothing to pack.") + sys.exit(1) + + artifacts_dir = r"$(Build.BinariesDirectory)\artifacts" + missing = [ + f"{p} ({d})" + for p in expected + for d in [os.path.join(artifacts_dir, p, "bin")] + if not os.path.isdir(d) + ] + if missing: + print("##vso[task.logissue type=error]Expected artifact directories not found:") + for m in missing: + print(f"##vso[task.logissue type=error] {m}") + sys.exit(1) + + required = ",".join(expected) + print(f"Required platforms: {required}") + print(f"##vso[task.setvariable variable=WebGpuRequiredPlatforms]{required}") + + # Stage binaries and build the managed assembly (so it can be ESRP-signed before packing). 
+ - task: PythonScript@0 + displayName: 'Build managed DLL' + inputs: + scriptSource: filePath + scriptPath: '$(Build.SourcesDirectory)\plugin-ep-webgpu\csharp\pack_nuget.py' + arguments: >- + $(WebGpuPackNuGetCommonArgs) + --artifacts-dir "$(Build.BinariesDirectory)\artifacts" + --required-platforms $(WebGpuRequiredPlatforms) + --build-only + + # ESRP-sign the managed DLL before it gets embedded in the .nupkg. + - template: ../templates/win-esrp-dll.yml + parameters: + FolderPath: '$(WebGpuPackStagingDir)' + Pattern: 'Microsoft.ML.OnnxRuntime.EP.WebGpu.dll' + DisplayName: 'ESRP - Sign managed DLL' + DoEsrp: ${{ parameters.DoEsrp }} + + # Pack the (now-signed) managed DLL plus native binaries into the .nupkg. + - task: PythonScript@0 + displayName: 'Pack NuGet package' + inputs: + scriptSource: filePath + scriptPath: '$(Build.SourcesDirectory)\plugin-ep-webgpu\csharp\pack_nuget.py' + arguments: >- + $(WebGpuPackNuGetCommonArgs) + --pack-only + + # ESRP sign + - template: ../templates/esrp_nuget.yml + parameters: + FolderPath: '$(Build.ArtifactStagingDirectory)\nuget' + DisplayName: 'ESRP - Sign NuGet package' + DoEsrp: ${{ parameters.DoEsrp }} diff --git a/tools/ci_build/github/azure-pipelines/stages/plugin-webgpu-packaging-stage.yml b/tools/ci_build/github/azure-pipelines/stages/plugin-webgpu-packaging-stage.yml index 9db25f5727cc2..996d6fa1af0a6 100644 --- a/tools/ci_build/github/azure-pipelines/stages/plugin-webgpu-packaging-stage.yml +++ b/tools/ci_build/github/azure-pipelines/stages/plugin-webgpu-packaging-stage.yml @@ -2,22 +2,18 @@ parameters: - name: build_windows_x64 displayName: 'Build Windows x64' type: boolean - default: true - name: build_windows_arm64 displayName: 'Build Windows ARM64' type: boolean - default: false - name: build_linux_x64 displayName: 'Build Linux x64' type: boolean - default: false - name: build_macos_arm64 displayName: 'Build macOS ARM64' type: boolean - default: false - name: package_version displayName: 'Package Version' @@ -26,7 
+22,7 @@ parameters: values: - dev - release - - RC + # TODO: release candidate (RC) versioning is not yet implemented - name: version_file type: string @@ -52,9 +48,7 @@ stages: cmake_build_type: ${{ parameters.cmake_build_type }} # Windows ARM64 - # ARM64 build requires the x64 tblgen.exe (used during the build), which is not correctly - # generated in a cross build. So we require x64 to be built first and download tblgen.exe from it. - - ${{ if and(eq(parameters.build_windows_arm64, true), eq(parameters.build_windows_x64, true)) }}: + - ${{ if eq(parameters.build_windows_arm64, true) }}: - template: plugin-win-webgpu-stage.yml parameters: arch: 'arm64' @@ -77,3 +71,167 @@ stages: package_version: ${{ parameters.package_version }} version_file: ${{ parameters.version_file }} cmake_build_type: ${{ parameters.cmake_build_type }} + + # NuGet packaging (runs after all platform builds) + - template: plugin-webgpu-nuget-packaging-stage.yml + parameters: + package_version: ${{ parameters.package_version }} + version_file: ${{ parameters.version_file }} + DoEsrp: true + platforms: + win_x64: ${{ parameters.build_windows_x64 }} + win_arm64: ${{ parameters.build_windows_arm64 }} + linux_x64: ${{ parameters.build_linux_x64 }} + macos_arm64: ${{ parameters.build_macos_arm64 }} + + # Create zip packages for Foundry Local consumption + - stage: Package_Foundry_Local_WebGPU_Zips + displayName: 'Package Foundry Local WebGPU Plugin-EP Zips' + dependsOn: + - ${{ if eq(parameters.build_windows_x64, true) }}: + - Win_plugin_webgpu_x64_Build + - ${{ if eq(parameters.build_windows_arm64, true) }}: + - Win_plugin_webgpu_arm64_Build + - ${{ if eq(parameters.build_linux_x64, true) }}: + - Linux_plugin_webgpu_x64_Build + - ${{ if eq(parameters.build_macos_arm64, true) }}: + - MacOS_plugin_webgpu_arm64_Build + jobs: + - job: CreateZipPackages + displayName: 'Create Foundry Local WebGPU Plugin-EP Zip Packages' + pool: + name: 'onnxruntime-Win-CPU-VS2022-Latest' + os: windows + 
templateContext: + outputs: + - output: pipelineArtifact + targetPath: $(Build.ArtifactStagingDirectory)/webgpu-deps-package + artifactName: foundry-local-webgpu-plugin-ep-zips + steps: + # The 1ES TSA SDL task expects .config/tsaoptions.json in the source directory. + # Use a sparse checkout to pull only the .config directory (avoids full repo clone). + - checkout: self + fetchDepth: 1 + sparseCheckoutDirectories: .config + + - ${{ if eq(parameters.build_windows_x64, true) }}: + - task: DownloadPipelineArtifact@2 + displayName: 'Download webgpu_plugin_win_x64' + inputs: + artifactName: webgpu_plugin_win_x64 + targetPath: $(Build.SourcesDirectory)/webgpu-plugin-win-x64 + + - ${{ if eq(parameters.build_windows_arm64, true) }}: + - task: DownloadPipelineArtifact@2 + displayName: 'Download webgpu_plugin_win_arm64' + inputs: + artifactName: webgpu_plugin_win_arm64 + targetPath: $(Build.SourcesDirectory)/webgpu-plugin-win-arm64 + + - ${{ if eq(parameters.build_linux_x64, true) }}: + - task: DownloadPipelineArtifact@2 + displayName: 'Download webgpu_plugin_linux_x64' + inputs: + artifactName: webgpu_plugin_linux_x64 + targetPath: $(Build.SourcesDirectory)/webgpu-plugin-linux-x64 + + - ${{ if eq(parameters.build_macos_arm64, true) }}: + - task: DownloadPipelineArtifact@2 + displayName: 'Download webgpu_plugin_macos_arm64' + inputs: + artifactName: webgpu_plugin_macos_arm64 + targetPath: $(Build.SourcesDirectory)/webgpu-plugin-macos-arm64 + + - task: PowerShell@2 + displayName: 'Create version.json and zip packages for each platform' + inputs: + targetType: inline + script: | + $outputDir = '$(Build.ArtifactStagingDirectory)/webgpu-deps-package' + New-Item -ItemType Directory -Path $outputDir -Force + + $platforms = @( + @{ name = 'win-x64'; dir = '$(Build.SourcesDirectory)/webgpu-plugin-win-x64' }, + @{ name = 'win-arm64'; dir = '$(Build.SourcesDirectory)/webgpu-plugin-win-arm64' }, + @{ name = 'linux-x64'; dir = '$(Build.SourcesDirectory)/webgpu-plugin-linux-x64' }, + @{ 
name = 'macos-arm64'; dir = '$(Build.SourcesDirectory)/webgpu-plugin-macos-arm64' } + ) + + $resolvedVersion = $null + + foreach ($platform in $platforms) { + $depsDir = $platform.dir + $platformName = $platform.name + + if (-not (Test-Path $depsDir)) { + Write-Host "Skipping $platformName (not built)" + continue + } + + $binDir = Join-Path $depsDir "bin" + $versionDir = Join-Path $depsDir "version" + + if (-not (Test-Path $binDir)) { + throw "Bin directory not found for $platformName $binDir" + } + + Write-Host "--- Processing $platformName ---" + + $versionString = "Unknown" + if (Test-Path $versionDir) { + $versionFile = Get-ChildItem -Path $versionDir -File | Select-Object -First 1 + if ($versionFile) { + $versionString = $versionFile.Name.Trim() + } + } + + # Track the resolved version (all platforms must agree) + # Version formats (full -> filename): + # release: 0.1.0 -> 0.1.0 + # dev: 0.1.0-dev.20260401+2a1ffff2 -> 0.1.0.dev.20260401.2a1ffff2 + # Dev versions have - and + replaced with . for filename compatibility. + # Full version string is preserved in version.json. + # TODO: RC versioning (e.g. 0.1.0-rc1) is not yet implemented + $filenameVersion = $versionString -replace '[-+]', '.' 
+ if ($null -eq $resolvedVersion) { + $resolvedVersion = $filenameVersion + } elseif ($resolvedVersion -ne $filenameVersion) { + throw "Version mismatch across platforms: expected '$resolvedVersion' but $platformName has '$filenameVersion'" + } + + $versionInfo = @{ + version = $versionString + } + + $json = $versionInfo | ConvertTo-Json + $versionPath = Join-Path $binDir "version.json" + Set-Content -Path $versionPath -Value $json -Encoding UTF8 + Write-Host "Created version.json:" + Write-Host $json + + # Collect the binaries (dll, so, dylib) and version.json + $filesToZip = Get-ChildItem -Path $binDir -File | Where-Object { + $_.Extension -in '.dll', '.so', '.dylib' -or $_.Name -eq 'version.json' + } + + $zipPath = Join-Path $outputDir "webgpu_ep_${filenameVersion}_${platformName}.zip" + if ($filesToZip) { + $filesToZip | Compress-Archive -DestinationPath $zipPath -Force + Write-Host "Created zip: $zipPath ($((Get-Item $zipPath).Length) bytes)" + } else { + throw "No files found to zip for $platformName in $binDir" + } + Write-Host "" + } + + if ($null -eq $resolvedVersion) { + throw "No platforms were processed — cannot determine version." + } + + # Create a version folder in the output artifact with a file whose name is the version string. + # This follows the same convention as the per-platform artifacts (e.g. webgpu_plugin_win_x64/version/) + # and allows downstream pipelines to read the version without parsing zip filenames. 
+ $versionOutputDir = Join-Path $outputDir "version" + New-Item -ItemType Directory -Path $versionOutputDir -Force + New-Item -ItemType File -Path (Join-Path $versionOutputDir $resolvedVersion) -Force | Out-Null + Write-Host "Created version marker: $versionOutputDir/$resolvedVersion" diff --git a/tools/ci_build/github/azure-pipelines/stages/plugin-win-cuda-stage.yml b/tools/ci_build/github/azure-pipelines/stages/plugin-win-cuda-stage.yml index 68968a0be86e3..7eac3842514a5 100644 --- a/tools/ci_build/github/azure-pipelines/stages/plugin-win-cuda-stage.yml +++ b/tools/ci_build/github/azure-pipelines/stages/plugin-win-cuda-stage.yml @@ -23,6 +23,9 @@ parameters: type: string default: '12.8' +- name: python_package_name + type: string + - name: cmake_cuda_archs type: string default: '52-real;61-real;75-real;86-real;89-real;90-virtual' @@ -74,6 +77,7 @@ stages: parameters: package_version: ${{ parameters.package_version }} version_file: ${{ parameters.version_file }} + python_command: python - script: | python -m pip install -r "$(Build.SourcesDirectory)\tools\ci_build\github\windows\python\requirements.txt" @@ -81,28 +85,35 @@ stages: env: TMPDIR: "$(Agent.TempDirectory)" - - powershell: | - azcopy.exe cp --recursive "https://lotusscus.blob.core.windows.net/models/cuda_sdk/v${{ parameters.cuda_version }}" "$(Agent.TempDirectory)" + - task: AzureCLI@2 displayName: 'Download CUDA SDK v${{ parameters.cuda_version }}' - env: - AZCOPY_AUTO_LOGIN_TYPE: MSI - AZCOPY_MSI_CLIENT_ID: 63b63039-6328-442f-954b-5a64d124e5b4 + inputs: + azureSubscription: AIInfraBuildOnnxRuntimeOSS + scriptType: 'batch' + scriptLocation: 'inlineScript' + inlineScript: | + set AZCOPY_AUTO_LOGIN_TYPE=AZCLI + azcopy.exe cp --recursive https://lotusscus.blob.core.windows.net/models/cuda_sdk/v${{ parameters.cuda_version }} "$(Agent.TempDirectory)" + # Since CUDA 13.0, CUDA DLLs are in bin\x64 folder instead of bin folder for Windows. 
- powershell: | Write-Host "Adding CUDA to PATH" - Write-Host "CUDA Path: $(Agent.TempDirectory)\v${{ parameters.cuda_version }}\bin" Write-Host "##vso[task.prependpath]$(Agent.TempDirectory)\v${{ parameters.cuda_version }}\bin" + Write-Host "##vso[task.prependpath]$(Agent.TempDirectory)\v${{ parameters.cuda_version }}\bin\x64" Write-Host "##vso[task.prependpath]$(Agent.TempDirectory)\v${{ parameters.cuda_version }}\extras\CUPTI\lib64" displayName: 'Add CUDA to PATH' # Download cuDNN separately for CUDA 13.0 - ${{ if eq(parameters.cuda_version, '13.0') }}: - - powershell: | - azcopy.exe cp --recursive "https://lotusscus.blob.core.windows.net/models/cudnn_9/9.14.0.64_cuda13" "$(Agent.TempDirectory)" + - task: AzureCLI@2 displayName: 'Download cuDNN for CUDA 13.0' - env: - AZCOPY_AUTO_LOGIN_TYPE: MSI - AZCOPY_MSI_CLIENT_ID: 63b63039-6328-442f-954b-5a64d124e5b4 + inputs: + azureSubscription: AIInfraBuildOnnxRuntimeOSS + scriptType: 'batch' + scriptLocation: 'inlineScript' + inlineScript: | + set AZCOPY_AUTO_LOGIN_TYPE=AZCLI + azcopy.exe cp --recursive https://lotusscus.blob.core.windows.net/models/cudnn_sdk/9.14.0.64_cuda13 "$(Agent.TempDirectory)" # CUDA 12.x build (no separate cuDNN) - ${{ if ne(parameters.cuda_version, '13.0') }}: @@ -130,9 +141,6 @@ stages: --cmake_extra_defines $(PluginEpVersionDefine) $(TelemetryOption) workingDirectory: '$(Build.BinariesDirectory)' - env: - AZCOPY_AUTO_LOGIN_TYPE: MSI - AZCOPY_MSI_CLIENT_ID: 63b63039-6328-442f-954b-5a64d124e5b4 # CUDA 13.0 build (separate cuDNN folder) - ${{ if eq(parameters.cuda_version, '13.0') }}: @@ -161,9 +169,6 @@ stages: --cmake_extra_defines $(PluginEpVersionDefine) $(TelemetryOption) workingDirectory: '$(Build.BinariesDirectory)' - env: - AZCOPY_AUTO_LOGIN_TYPE: MSI - AZCOPY_MSI_CLIENT_ID: 63b63039-6328-442f-954b-5a64d124e5b4 # Esrp signing - template: ../templates/win-esrp-dll.yml @@ -209,6 +214,58 @@ stages: command: publish publishDirectory: '$(Build.BinariesDirectory)\universal_package' 
vstsFeedPublish: 'PublicPackages/ORT-Nightly' - vstsFeedPackagePublish: 'onnxruntime-plugin-ep-cuda${{ replace(parameters.cuda_version, '.', '') }}-win-x64' + vstsFeedPackagePublish: "onnxruntime-plugin-ep-cuda${{ replace(parameters.cuda_version, '.', '') }}-win-x64" versionOption: custom versionPublish: '$(PluginUniversalPackageVersion)' + + - job: Win_plugin_cuda_x64_Python_Package + dependsOn: Win_plugin_cuda_x64_Build + timeoutInMinutes: 30 + workspace: + clean: all + pool: + name: onnxruntime-Win-CPU-VS2022-Latest + os: windows + templateContext: + outputs: + - output: pipelineArtifact + targetPath: '$(Build.ArtifactStagingDirectory)\python' + artifactName: cuda_plugin_python_win_x64_cuda${{ replace(parameters.cuda_version, '.', '') }} + variables: + - template: ../templates/common-variables.yml + steps: + - checkout: self + clean: true + submodules: none + + - template: ../templates/setup-build-tools.yml + parameters: + host_cpu_arch: 'x64' + python_version: ${{ parameters.python_version }} + + - template: ../templates/set-nightly-build-option-variable-step.yml + + - template: ../templates/set-plugin-build-variables-step.yml + parameters: + package_version: ${{ parameters.package_version }} + version_file: ${{ parameters.version_file }} + python_command: python + + - task: DownloadPipelineArtifact@2 + displayName: 'Download plugin build artifacts' + inputs: + artifactName: cuda_plugin_win_x64 + targetPath: '$(Build.BinariesDirectory)\plugin_artifacts' + + - task: PowerShell@2 + displayName: 'Build Python wheel' + inputs: + targetType: inline + pwsh: true + script: | + python -m pip install -r "$(Build.SourcesDirectory)\plugin-ep-cuda\python\requirements-build-wheel.txt" + python "$(Build.SourcesDirectory)\plugin-ep-cuda\python\build_wheel.py" ` + --binary_dir "$(Build.BinariesDirectory)\plugin_artifacts\bin" ` + --version "$(PluginPythonPackageVersion)" ` + --package_name "${{ parameters.python_package_name }}" ` + --output_dir 
"$(Build.ArtifactStagingDirectory)\python" diff --git a/tools/ci_build/github/azure-pipelines/stages/plugin-win-cuda-test-stage.yml b/tools/ci_build/github/azure-pipelines/stages/plugin-win-cuda-test-stage.yml new file mode 100644 index 0000000000000..813737ed8ecef --- /dev/null +++ b/tools/ci_build/github/azure-pipelines/stages/plugin-win-cuda-test-stage.yml @@ -0,0 +1,72 @@ +parameters: +- name: cuda_version + type: string + default: '12.8' + +stages: +- stage: Win_plugin_cuda_x64_Test + dependsOn: [] + jobs: + - job: Win_plugin_cuda_x64_Python_Test + timeoutInMinutes: 60 + workspace: + clean: all + pool: + name: onnxruntime-Win2022-GPU-A10 + os: windows + steps: + - checkout: self + clean: true + submodules: none + + - template: ../templates/setup-feeds-and-python-steps.yml + + # Download the Python wheel produced by the packaging pipeline run that + # triggered this pipeline (or that was selected at queue time). + - download: build + artifact: cuda_plugin_python_win_x64_cuda${{ replace(parameters.cuda_version, '.', '') }} + displayName: 'Download Python wheel' + + - task: AzureCLI@2 + displayName: 'Download CUDA SDK v${{ parameters.cuda_version }}' + inputs: + azureSubscription: AIInfraBuildOnnxRuntimeOSS + scriptType: 'batch' + scriptLocation: 'inlineScript' + inlineScript: | + set AZCOPY_AUTO_LOGIN_TYPE=AZCLI + azcopy.exe cp --recursive https://lotusscus.blob.core.windows.net/models/cuda_sdk/v${{ parameters.cuda_version }} "$(Agent.TempDirectory)" + + - powershell: | + Write-Host "Adding CUDA to PATH" + Write-Host "##vso[task.prependpath]$(Agent.TempDirectory)\v${{ parameters.cuda_version }}\bin" + Write-Host "##vso[task.prependpath]$(Agent.TempDirectory)\v${{ parameters.cuda_version }}\bin\x64" + Write-Host "##vso[task.prependpath]$(Agent.TempDirectory)\v${{ parameters.cuda_version }}\extras\CUPTI\lib64" + displayName: 'Add CUDA to PATH' + + - task: PowerShell@2 + displayName: 'Install and test Python package' + env: + ORT_TEST_VERBOSE: $(System.Debug) + 
inputs: + targetType: inline + pwsh: true + script: | + $ErrorActionPreference = 'Stop' + + echo "creating test_venv" + python -m venv "$(Build.BinariesDirectory)\test_venv" + + echo "activating test_venv" + & "$(Build.BinariesDirectory)\test_venv\Scripts\Activate.ps1" + + echo "installing onnxruntime onnx numpy" + python -m pip install onnxruntime onnx numpy + + $wheelDir = "$(Pipeline.Workspace)\build\cuda_plugin_python_win_x64_cuda${{ replace(parameters.cuda_version, '.', '') }}" + $wheel = (Get-ChildItem "$wheelDir\onnxruntime*ep*cuda*.whl")[0] + echo "installing ${wheel}" + python -m pip install $wheel.FullName + + echo "running test_cuda_plugin_ep.py" + python -u "$(Build.SourcesDirectory)\plugin-ep-cuda\python\test\test_cuda_plugin_ep.py" diff --git a/tools/ci_build/github/azure-pipelines/stages/plugin-win-webgpu-stage.yml b/tools/ci_build/github/azure-pipelines/stages/plugin-win-webgpu-stage.yml index acad674143961..332d5b0224f37 100644 --- a/tools/ci_build/github/azure-pipelines/stages/plugin-win-webgpu-stage.yml +++ b/tools/ci_build/github/azure-pipelines/stages/plugin-win-webgpu-stage.yml @@ -28,6 +28,7 @@ parameters: stages: - stage: Win_plugin_webgpu_${{ parameters.arch }}_Build ${{ if eq(parameters.arch, 'arm64') }}: + # The ARM64 build consumes the x64 tblgen.exe artifact published by the Windows x64 stage. 
dependsOn: Win_plugin_webgpu_x64_Build ${{ else }}: dependsOn: [] @@ -86,6 +87,7 @@ stages: parameters: package_version: ${{ parameters.package_version }} version_file: ${{ parameters.version_file }} + python_command: python - script: | python -m pip install -r "$(Build.SourcesDirectory)\tools\ci_build\github\windows\python\requirements.txt" @@ -267,6 +269,7 @@ stages: parameters: package_version: ${{ parameters.package_version }} version_file: ${{ parameters.version_file }} + python_command: python - task: DownloadPipelineArtifact@2 displayName: 'Download plugin build artifacts' diff --git a/tools/ci_build/github/azure-pipelines/stages/plugin-win-webgpu-test-stage.yml b/tools/ci_build/github/azure-pipelines/stages/plugin-win-webgpu-test-stage.yml index af29a62d69329..1494584ff98fd 100644 --- a/tools/ci_build/github/azure-pipelines/stages/plugin-win-webgpu-test-stage.yml +++ b/tools/ci_build/github/azure-pipelines/stages/plugin-win-webgpu-test-stage.yml @@ -31,29 +31,103 @@ stages: artifact: webgpu_plugin_python_win_${{ parameters.arch }} displayName: 'Download Python wheel' - - task: PowerShell@2 + - pwsh: | + $ErrorActionPreference = 'Stop' + + echo "creating test_venv" + python -m venv "$(Build.BinariesDirectory)\test_venv" + + echo "activating test_venv" + & "$(Build.BinariesDirectory)\test_venv\Scripts\Activate.ps1" + + echo "installing test dependencies" + python -m pip install onnx numpy + + $wheelDir = "$(Pipeline.Workspace)\build\webgpu_plugin_python_win_${{ parameters.arch }}" + $wheel = (Get-ChildItem "$wheelDir\onnxruntime_ep_webgpu-*.whl")[0] + echo "installing ${wheel}" + python -m pip install $wheel.FullName + + echo "running test_webgpu_plugin_ep.py" + python -u "$(Build.SourcesDirectory)\plugin-ep-webgpu\python\test\test_webgpu_plugin_ep.py" displayName: 'Install and test Python package' env: ORT_TEST_VERBOSE: $(System.Debug) - inputs: - targetType: inline - pwsh: true - script: | + + # NuGet package test (x64 only — the NuGet package is 
multi-platform but + # the test runs on a single Windows agent that exercises the WebGPU EP). + - ${{ if eq(parameters.arch, 'x64') }}: + - job: Win_plugin_webgpu_nuget_Test + timeoutInMinutes: 30 + workspace: + clean: all + pool: + name: onnxruntime-Win2022-VS2022-webgpu-A10 + os: windows + variables: + WebGpuTestProject: '$(Build.SourcesDirectory)\plugin-ep-webgpu\csharp\test\WebGpuEpNuGetTest\WebGpuEpNuGetTest.csproj' + steps: + - checkout: self + submodules: none + + - template: ../templates/setup-feeds-and-python-steps.yml + + # Download the NuGet package produced by the packaging pipeline run that + # triggered this pipeline (or that was selected at queue time). + - download: build + artifact: webgpu_plugin_nuget + displayName: 'Download NuGet package' + + # Set up local NuGet feed and extract the package version from the .nupkg filename + # so the test project can pin to it (instead of resolving via a floating version). + - pwsh: | $ErrorActionPreference = 'Stop' + $localFeedDir = "$(Build.BinariesDirectory)\local_feed" + New-Item -ItemType Directory -Path $localFeedDir -Force | Out-Null - echo "creating test_venv" - python -m venv "$(Build.BinariesDirectory)\test_venv" + # Locate the .nupkg. + $nupkg = Get-ChildItem "$(Pipeline.Workspace)\build\webgpu_plugin_nuget\Microsoft.ML.OnnxRuntime.EP.WebGpu.*.nupkg" | + Select-Object -First 1 + if (-not $nupkg) { + throw "No matching .nupkg found under $(Pipeline.Workspace)\build\webgpu_plugin_nuget" + } + Copy-Item $nupkg.FullName $localFeedDir -Force - echo "activating test_venv" - & "$(Build.BinariesDirectory)\test_venv\Scripts\Activate.ps1" + # Extract version from filename: Microsoft.ML.OnnxRuntime.EP.WebGpu..nupkg + # The version starts with a digit, which disambiguates from any future filename suffixes. 
+ if ($nupkg.BaseName -notmatch '^Microsoft\.ML\.OnnxRuntime\.EP\.WebGpu\.(\d.*)$') { + throw "Could not extract version from .nupkg filename: $($nupkg.Name)" + } + $packageVersion = $Matches[1] + Write-Host "Detected package version: $packageVersion" + Write-Host "##vso[task.setvariable variable=OrtWebGpuPackageVersion]$packageVersion" - echo "installing onnxruntime onnx numpy" - python -m pip install onnxruntime onnx numpy + # Write a project-level nuget.config that adds ONLY the local feed. + # NuGet merges this with the repo-root NuGet.config. + $nugetConfig = "$(Build.SourcesDirectory)\plugin-ep-webgpu\csharp\test\WebGpuEpNuGetTest\nuget.config" + Set-Content -Path $nugetConfig -Encoding UTF8 -Value @" + + + + + + + "@ + Write-Host "Wrote project-level nuget.config with local feed: $localFeedDir" + Write-Host "Local feed contents:" + Get-ChildItem $localFeedDir | ForEach-Object { Write-Host " $($_.Name)" } + displayName: 'Set up local NuGet feed' - $wheelDir = "$(Pipeline.Workspace)\build\webgpu_plugin_python_win_${{ parameters.arch }}" - $wheel = (Get-ChildItem "$wheelDir\onnxruntime_ep_webgpu-*.whl")[0] - echo "installing ${wheel}" - python -m pip install $wheel.FullName + - pwsh: | + dotnet build ` + "$(WebGpuTestProject)" ` + --configuration Release ` + -p:OrtWebGpuPackageVersion=$(OrtWebGpuPackageVersion) + displayName: 'Build test project' - echo "running test_webgpu_plugin_ep.py" - python -u "$(Build.SourcesDirectory)\plugin-ep-webgpu\python\test\test_webgpu_plugin_ep.py" + - pwsh: | + dotnet run ` + --project "$(WebGpuTestProject)" ` + --configuration Release ` + --no-build + displayName: 'Run NuGet package test' diff --git a/tools/ci_build/github/azure-pipelines/stages/set_packaging_variables_stage.yml b/tools/ci_build/github/azure-pipelines/stages/set_packaging_variables_stage.yml index 07dd1549acd2d..a11fd8b89b8b1 100644 --- a/tools/ci_build/github/azure-pipelines/stages/set_packaging_variables_stage.yml +++ 
b/tools/ci_build/github/azure-pipelines/stages/set_packaging_variables_stage.yml @@ -53,21 +53,6 @@ stages: echo "##vso[task.setvariable variable=ReleaseVersionSuffix;isOutput=true]" fi name: Set_Release_Version_Suffix - - script: | - # Extracting hours and minutes - date=$(date +'%Y%m%d') - # Set the hhmm value as a pipeline variable - echo "##vso[task.setvariable variable=BuildDate;isOutput=true]$date" - displayName: 'Set Start Date as Variable' - name: Set_Build_Date - - - script: | - # Extracting hours and minutes - hhmm=$(date +'%H%M') - # Set the hhmm value as a pipeline variable - echo "##vso[task.setvariable variable=BuildTime;isOutput=true]$hhmm" - displayName: 'Set Start Time as Variable' - name: Set_Build_Time - bash: | echo "Recording pipeline parameters to a file..." diff --git a/tools/ci_build/github/azure-pipelines/templates/c-api-cpu.yml b/tools/ci_build/github/azure-pipelines/templates/c-api-cpu.yml index 7b8b5758e79b5..04066a0c0b90c 100644 --- a/tools/ci_build/github/azure-pipelines/templates/c-api-cpu.yml +++ b/tools/ci_build/github/azure-pipelines/templates/c-api-cpu.yml @@ -286,8 +286,6 @@ stages: variables: OrtPackageId: ${{ parameters.OrtNugetPackageId }} ReleaseVersionSuffix: $[stageDependencies.Setup.Set_Variables.outputs['Set_Release_Version_Suffix.ReleaseVersionSuffix']] - BuildDate: $[stageDependencies.Setup.Set_Variables.outputs['Set_Build_Date.BuildDate']] - BuildTime: $[stageDependencies.Setup.Set_Variables.outputs['Set_Build_Time.BuildTime']] steps: - checkout: self @@ -356,7 +354,14 @@ stages: solution: '$(Build.SourcesDirectory)\csharp\OnnxRuntime.CSharp.proj' platform: 'Any CPU' configuration: RelWithDebInfo - msbuildArguments: '-t:CreatePackage -p:OnnxRuntimeBuildDirectory="$(Build.BinariesDirectory)" -p:OrtPackageId=$(OrtPackageId) -p:IsReleaseBuild=${{ parameters.IsReleaseBuild }} -p:ReleaseVersionSuffix=$(ReleaseVersionSuffix) -p:CurrentTime=$(BuildTime) -p:CurrentDate=$(BuildDate)' + msbuildArguments: >- + -t:CreatePackage + 
"-p:OnnxRuntimeBuildDirectory=$(Build.BinariesDirectory)" + "-p:OrtPackageId=$(OrtPackageId)" + "-p:IsReleaseBuild=${{ parameters.IsReleaseBuild }}" + "-p:ReleaseVersionSuffix=$(ReleaseVersionSuffix)" + "-p:CurrentTime=$(ORT_CI_BUILD_TIME)" + "-p:CurrentDate=$(ORT_CI_BUILD_DATE)" workingDirectory: '$(Build.SourcesDirectory)\csharp' - task: CopyFiles@2 @@ -416,12 +421,14 @@ stages: targetPath: $(Build.ArtifactStagingDirectory) artifactName: 'NPM_packages' variables: - ${{ if eq(parameters.IsReleaseBuild, true) }}: + ${{ if and(parameters.IsReleaseBuild, eq(parameters.PreReleaseVersionSuffixString, 'none')) }}: NpmPackagingMode: 'release' - ${{ if not(eq(parameters.IsReleaseBuild, true)) }}: + ${{ elseif and(parameters.IsReleaseBuild, eq(parameters.PreReleaseVersionSuffixString, 'rc')) }}: + NpmPackagingMode: 'rc' + ${{ elseif not(parameters.IsReleaseBuild) }}: NpmPackagingMode: 'dev' - BuildDate: $[stageDependencies.Setup.Set_Variables.outputs['Set_Build_Date.BuildDate']] - BuildTime: $[stageDependencies.Setup.Set_Variables.outputs['Set_Build_Time.BuildTime']] + ${{ else }}: # IsReleaseBuild + beta, alpha, etc. We don't support those and those suffixes are deprecated. 
+ NpmPackagingMode: '' steps: - checkout: self @@ -635,7 +642,7 @@ stages: Write-Host "Latest version of ${packageName}: $latestVersion" # Generate current version - $currentVersion = "$(cat .\VERSION_NUMBER)-dev-$($env:BuildDate)-$($env:BuildTime)-$(git rev-parse --short HEAD)" + $currentVersion = "$(cat .\VERSION_NUMBER)-dev-$($env:ORT_CI_BUILD_DATE)-$($env:ORT_CI_BUILD_TIME)-$(git rev-parse --short HEAD)" Write-Host "Current version: $currentVersion" # Set the version as an environment variable diff --git a/tools/ci_build/github/azure-pipelines/templates/common-variables.yml b/tools/ci_build/github/azure-pipelines/templates/common-variables.yml index 8c8dae9820810..250a023bcc158 100644 --- a/tools/ci_build/github/azure-pipelines/templates/common-variables.yml +++ b/tools/ci_build/github/azure-pipelines/templates/common-variables.yml @@ -7,5 +7,7 @@ variables: linux_trt_version_cuda12: ${{ variables.cuda12_trt_version }}-1.cuda12.9 # aarch64 TRT tar download (no RPMs available for aarch64) aarch64_trt_download_url_cuda13: https://developer.nvidia.com/downloads/compute/machine-learning/tensorrt/10.15.1/tars/TensorRT-${{ variables.aarch64_trt_version }}.Linux.aarch64-gnu.cuda-13.1.tar.gz + ORT_CI_BUILD_DATE: $[ format('{0:yyyyMMdd}', pipeline.startTime) ] + ORT_CI_BUILD_TIME: $[ format('{0:HHmm}', pipeline.startTime) ] win_trt_folder_cuda13: TensorRT-${{ variables.cuda13_trt_version }}.Windows.win10.cuda-13.0 win_trt_folder_cuda12: TensorRT-${{ variables.cuda12_trt_version }}.Windows.win10.cuda-12.9 diff --git a/tools/ci_build/github/azure-pipelines/templates/foundry-local-nuget-packaging.yml b/tools/ci_build/github/azure-pipelines/templates/foundry-local-nuget-packaging.yml index 2d1c182ec7512..44012af808a46 100644 --- a/tools/ci_build/github/azure-pipelines/templates/foundry-local-nuget-packaging.yml +++ b/tools/ci_build/github/azure-pipelines/templates/foundry-local-nuget-packaging.yml @@ -30,8 +30,6 @@ stages: variables: DoEsrp: ${{ parameters.DoEsrp }} 
ReleaseVersionSuffix: $[stageDependencies.Setup.Set_Variables.outputs['Set_Release_Version_Suffix.ReleaseVersionSuffix']] - BuildDate: $[stageDependencies.Setup.Set_Variables.outputs['Set_Build_Date.BuildDate']] - BuildTime: $[stageDependencies.Setup.Set_Variables.outputs['Set_Build_Time.BuildTime']] steps: - task: DownloadPipelineArtifact@0 diff --git a/tools/ci_build/github/azure-pipelines/templates/managed-nuget-for-foundry-local.yml b/tools/ci_build/github/azure-pipelines/templates/managed-nuget-for-foundry-local.yml index 0be3f4de65647..c14ac6cc7a3fd 100644 --- a/tools/ci_build/github/azure-pipelines/templates/managed-nuget-for-foundry-local.yml +++ b/tools/ci_build/github/azure-pipelines/templates/managed-nuget-for-foundry-local.yml @@ -31,8 +31,6 @@ stages: variables: OrtPackageId: ${{ parameters.OrtNugetPackageId }} ReleaseVersionSuffix: $[stageDependencies.Setup.Set_Variables.outputs['Set_Release_Version_Suffix.ReleaseVersionSuffix']] - BuildDate: $[stageDependencies.Setup.Set_Variables.outputs['Set_Build_Date.BuildDate']] - BuildTime: $[stageDependencies.Setup.Set_Variables.outputs['Set_Build_Time.BuildTime']] steps: - template: set-version-number-variables-step.yml @@ -86,7 +84,15 @@ stages: solution: '$(Build.SourcesDirectory)\csharp\OnnxRuntime.CSharp.proj' platform: 'AnyCPU' configuration: RelWithDebInfo - msbuildArguments: '-t:CreatePackage -p:OnnxRuntimeBuildDirectory="$(Build.BinariesDirectory)" -p:OrtPackageId=Microsoft.ML.OnnxRuntime -p:IsReleaseBuild=${{ parameters.IsReleaseBuild }} -p:ReleaseVersionSuffix=$(ReleaseVersionSuffix) -p:CurrentTime=$(BuildTime) -p:CurrentDate=$(BuildDate) -p:IncludeMobileTargets=false' + msbuildArguments: >- + -t:CreatePackage + "-p:OnnxRuntimeBuildDirectory=$(Build.BinariesDirectory)" + -p:OrtPackageId=Microsoft.ML.OnnxRuntime + "-p:IsReleaseBuild=${{ parameters.IsReleaseBuild }}" + "-p:ReleaseVersionSuffix=$(ReleaseVersionSuffix)" + "-p:CurrentTime=$(ORT_CI_BUILD_TIME)" + "-p:CurrentDate=$(ORT_CI_BUILD_DATE)" + 
-p:IncludeMobileTargets=false workingDirectory: '$(Build.SourcesDirectory)\csharp' - task: CopyFiles@2 diff --git a/tools/ci_build/github/azure-pipelines/templates/set-plugin-build-variables-step.yml b/tools/ci_build/github/azure-pipelines/templates/set-plugin-build-variables-step.yml index fcc388ef7e342..cc0e766f49ddb 100644 --- a/tools/ci_build/github/azure-pipelines/templates/set-plugin-build-variables-step.yml +++ b/tools/ci_build/github/azure-pipelines/templates/set-plugin-build-variables-step.yml @@ -17,99 +17,16 @@ parameters: - name: version_file type: string +# Python executable used to run the helper script. Default is python3 which works on +# Linux (including aarch64) and macOS. Windows callers must override with 'python'. +- name: python_command + type: string + default: 'python3' + steps: # Set package version string -- task: PythonScript@0 +# Use 'script' (not 'bash') so this works on both Linux and Windows agents. +# On Linux aarch64 agents UsePythonVersion@0 is unavailable, so we call the configured +# Python executable directly instead of using PythonScript@0. 
+- script: ${{ parameters.python_command }} "$(Build.SourcesDirectory)/tools/ci_build/set_plugin_build_variables.py" "${{ parameters.package_version }}" "${{ parameters.version_file }}" displayName: 'Set plugin package version string' - inputs: - scriptSource: inline - script: | - import os - import re - import subprocess - import sys - - package_version = "${{ parameters.package_version }}" - version_file_rel = "${{ parameters.version_file }}" - - if not version_file_rel: - print("##vso[task.logissue type=error]version_file parameter is empty.") - sys.exit(1) - - src_root = os.environ.get("BUILD_SOURCESDIRECTORY", "") - version_file = os.path.join(src_root, version_file_rel) - if not os.path.isfile(version_file): - print("##vso[task.logissue type=error]Cannot find version number file at: {}".format(version_file)) - sys.exit(1) - - with open(version_file, "r") as f: - original_ver = f.read().strip() - - if not original_ver: - print("##vso[task.logissue type=error]VERSION_NUMBER is empty.") - sys.exit(1) - - print("Original version: {}".format(original_ver)) - print("Package version type: {}".format(package_version)) - - if package_version == "release": - version_string = original_ver - universal_version = original_ver - python_version = original_ver - - elif package_version == "RC": - # RC versioning is not yet implemented. Fail the build to prevent publishing - # an ambiguous version without an RC number. - print("##vso[task.logissue type=error]RC versioning is not yet implemented. 
Use 'dev' or 'release' instead.") - sys.exit(1) - - elif package_version == "dev": - try: - commit_sha = subprocess.check_output( - ["git", "rev-parse", "--short=8", "HEAD"], - cwd=src_root - ).decode("utf-8").strip() - date_str = subprocess.check_output( - ["git", "show", "-s", "--format=%cd", "--date=format:%Y%m%d", "HEAD"], - cwd=src_root - ).decode("utf-8").strip() - except Exception as e: - print("##vso[task.logissue type=error]Failed to get git info: {}".format(e)) - sys.exit(1) - version_string = "{}-dev.{}+{}".format(original_ver, date_str, commit_sha) - # Prefix the SHA with "commit-" so the pre-release identifier always contains a - # non-digit. Otherwise, an all-numeric short SHA with a leading zero (e.g. "01234567") - # would violate SemVer 2.0.0's rule against leading zeros in numeric identifiers. - universal_version = "{}-dev.{}.commit-{}".format(original_ver, date_str, commit_sha) - python_version = "{}.dev{}".format(original_ver, date_str) - - else: - print("##vso[task.logissue type=error]Unknown package_version '{}'. 
Must be 'release', 'RC', or 'dev'.".format(package_version)) - sys.exit(1) - - print("Plugin package version string: {}".format(version_string)) - print("Plugin universal package version string: {}".format(universal_version)) - print("Plugin Python package version string: {}".format(python_version)) - - # Validate semver 2.0.0 format - semver_pattern = r"^(0|[1-9]\d*)\.(0|[1-9]\d*)\.(0|[1-9]\d*)(?:-((?:0|[1-9]\d*|\d*[a-zA-Z-][0-9a-zA-Z-]*)(?:\.(?:0|[1-9]\d*|\d*[a-zA-Z-][0-9a-zA-Z-]*))*))?(?:\+([0-9a-zA-Z-]+(?:\.[0-9a-zA-Z-]+)*))?$" - if not re.match(semver_pattern, version_string): - print("##vso[task.logissue type=error]Version string '{}' is not valid semver 2.0.0.".format(version_string)) - sys.exit(1) - - # Validate universal version (SemVer 2.0.0, without build metadata) - universal_semver_pattern = r"^(0|[1-9]\d*)\.(0|[1-9]\d*)\.(0|[1-9]\d*)(?:-((?:0|[1-9]\d*|\d*[a-zA-Z-][0-9a-zA-Z-]*)(?:\.(?:0|[1-9]\d*|\d*[a-zA-Z-][0-9a-zA-Z-]*))*))?$" - if not re.match(universal_semver_pattern, universal_version): - print("##vso[task.logissue type=error]Universal version string '{}' is not valid semver 2.0.0 (without build metadata).".format(universal_version)) - sys.exit(1) - - # Validate Python version (PEP 440) - pep440_pattern = r"^([1-9][0-9]*!)?(0|[1-9][0-9]*)(\.(0|[1-9][0-9]*))*((a|b|rc)(0|[1-9][0-9]*))?(\.post(0|[1-9][0-9]*))?(\.dev(0|[1-9][0-9]*))?$" - if not re.match(pep440_pattern, python_version): - print("##vso[task.logissue type=error]Python version string '{}' is not valid PEP 440.".format(python_version)) - sys.exit(1) - - print("##vso[task.setvariable variable=PluginPackageVersion]{}".format(version_string)) - print("##vso[task.setvariable variable=PluginUniversalPackageVersion]{}".format(universal_version)) - print("##vso[task.setvariable variable=PluginPythonPackageVersion]{}".format(python_version)) - print("##vso[task.setvariable variable=PluginEpVersionDefine]onnxruntime_PLUGIN_EP_VERSION={}".format(version_string)) diff --git 
a/tools/ci_build/github/linux/build_cuda_plugin_package.sh b/tools/ci_build/github/linux/build_cuda_plugin_package.sh index 7c89fc6b892df..1b4e897b05389 100755 --- a/tools/ci_build/github/linux/build_cuda_plugin_package.sh +++ b/tools/ci_build/github/linux/build_cuda_plugin_package.sh @@ -39,6 +39,7 @@ docker run --rm \ --volume "${BUILD_BINARIESDIRECTORY}:/build" \ --volume /data/models:/build/models:ro \ --volume "${HOME}/.onnx:/home/onnxruntimedev/.onnx" \ + -e PIP_INDEX_URL \ -e NIGHTLY_BUILD \ -e BUILD_BUILDNUMBER \ -e SYSTEM_COLLECTIONURI \ diff --git a/tools/ci_build/github/linux/build_cuda_plugin_python_package.sh b/tools/ci_build/github/linux/build_cuda_plugin_python_package.sh new file mode 100755 index 0000000000000..171d5b2facea8 --- /dev/null +++ b/tools/ci_build/github/linux/build_cuda_plugin_python_package.sh @@ -0,0 +1,50 @@ +#!/bin/bash +set -e -x + +DOCKER_IMAGE="onnxruntimecuda128pluginbuildx64" +PYTHON_EXE="/opt/python/cp312-cp312/bin/python3.12" +VERSION="" +PACKAGE_NAME="" + +while getopts "i:p:v:n:" parameter_Option +do case "${parameter_Option}" +in +i) DOCKER_IMAGE=${OPTARG};; +p) PYTHON_EXE=${OPTARG};; +v) VERSION=${OPTARG};; +n) PACKAGE_NAME=${OPTARG};; +*) echo "Usage: $0 -i -p -v -n " + exit 1;; +esac +done + +if [ -z "$VERSION" ]; then + echo "ERROR: Version is required. Use -v " + exit 1 +fi + +if [ -z "$PACKAGE_NAME" ]; then + echo "ERROR: Package name is required. 
Use -n " + exit 1 +fi + +PYTHON_BIN_DIR=$(dirname "${PYTHON_EXE}") + +docker run --rm \ + --volume "${BUILD_SOURCESDIRECTORY}:/onnxruntime_src" \ + --volume "${BUILD_BINARIESDIRECTORY}:/build" \ + --volume "${BUILD_ARTIFACTSTAGINGDIRECTORY}:/staging" \ + --env PIP_INDEX_URL \ + --env "ORT_CUDA_PLUGIN_EP_VERSION=${VERSION}" \ + --env "ORT_CUDA_PLUGIN_EP_PACKAGE_NAME=${PACKAGE_NAME}" \ + "$DOCKER_IMAGE" \ + /bin/bash -c ' + set -e -x + PATH="'"${PYTHON_BIN_DIR}"'":$PATH + "'"${PYTHON_EXE}"'" -m pip install -r /onnxruntime_src/plugin-ep-cuda/python/requirements-build-wheel.txt + "'"${PYTHON_EXE}"'" /onnxruntime_src/plugin-ep-cuda/python/build_wheel.py \ + --binary_dir /build/plugin_artifacts/bin \ + --version "$ORT_CUDA_PLUGIN_EP_VERSION" \ + --package_name "$ORT_CUDA_PLUGIN_EP_PACKAGE_NAME" \ + --output_dir /staging/python + ' diff --git a/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_cuda b/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_cuda index ee7869f50bee5..3296fcc77f10f 100644 --- a/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_cuda +++ b/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_cuda @@ -35,7 +35,9 @@ fi ENV JAVA_HOME=/usr/lib/jvm/msopenjdk-17 ADD scripts /tmp/scripts -RUN cd /tmp/scripts && /tmp/scripts/manylinux/install_centos.sh && /tmp/scripts/manylinux/install_deps.sh && rm -rf /tmp/scripts +RUN --mount=type=secret,id=PIP_INDEX_URL,required=false \ + if [ -f /run/secrets/PIP_INDEX_URL ]; then export PIP_INDEX_URL=$(cat /run/secrets/PIP_INDEX_URL); fi && \ + cd /tmp/scripts && /tmp/scripts/manylinux/install_centos.sh && /tmp/scripts/manylinux/install_deps.sh && rm -rf /tmp/scripts ARG BUILD_UID=1001 ARG BUILD_USER=onnxruntimedev diff --git a/tools/ci_build/github/linux/docker/scripts/manylinux/install_centos.sh b/tools/ci_build/github/linux/docker/scripts/manylinux/install_centos.sh index 093da075be13c..e4b05f8a0d1d7 100755 --- 
a/tools/ci_build/github/linux/docker/scripts/manylinux/install_centos.sh +++ b/tools/ci_build/github/linux/docker/scripts/manylinux/install_centos.sh @@ -6,7 +6,11 @@ os_major_version=$(tr -dc '0-9.' + +Where: + package_version: 'release', 'RC', or 'dev' + version_file_rel: path relative to BUILD_SOURCESDIRECTORY of the VERSION_NUMBER file +""" + +import os +import re +import subprocess +import sys + + +def main(): + if len(sys.argv) != 3: + print(f"Usage: {sys.argv[0]} ") + sys.exit(1) + + package_version = sys.argv[1] + version_file_rel = sys.argv[2] + + if not version_file_rel: + print("##vso[task.logissue type=error]version_file parameter is empty.") + sys.exit(1) + + src_root = os.environ.get("BUILD_SOURCESDIRECTORY", "") + version_file = os.path.join(src_root, version_file_rel) + if not os.path.isfile(version_file): + print(f"##vso[task.logissue type=error]Cannot find version number file at: {version_file}") + sys.exit(1) + + with open(version_file) as f: + original_ver = f.read().strip() + + if not original_ver: + print("##vso[task.logissue type=error]VERSION_NUMBER is empty.") + sys.exit(1) + + print(f"Original version: {original_ver}") + print(f"Package version type: {package_version}") + + if package_version == "release": + version_string = original_ver + universal_version = original_ver + python_version = original_ver + + elif package_version == "RC": + # RC versioning is not yet implemented. Fail the build to prevent publishing + # an ambiguous version without an RC number. + print("##vso[task.logissue type=error]RC versioning is not yet implemented. 
Use 'dev' or 'release' instead.") + sys.exit(1) + + elif package_version == "dev": + try: + commit_sha = ( + subprocess.check_output( + ["git", "rev-parse", "--short=8", "HEAD"], + cwd=src_root, + ) + .decode("utf-8") + .strip() + ) + date_str = ( + subprocess.check_output( + ["git", "show", "-s", "--format=%cd", "--date=format:%Y%m%d", "HEAD"], + cwd=src_root, + ) + .decode("utf-8") + .strip() + ) + except Exception as e: + print(f"##vso[task.logissue type=error]Failed to get git info: {e}") + sys.exit(1) + version_string = f"{original_ver}-dev.{date_str}+{commit_sha}" + # Prefix the SHA with "commit-" so the pre-release identifier always contains a + # non-digit. Otherwise, an all-numeric short SHA with a leading zero (e.g. "01234567") + # would violate SemVer 2.0.0's rule against leading zeros in numeric identifiers. + universal_version = f"{original_ver}-dev.{date_str}.commit-{commit_sha}" + python_version = f"{original_ver}.dev{date_str}" + + else: + print( + f"##vso[task.logissue type=error]Unknown package_version '{package_version}'. Must be 'release', 'RC', or 'dev'." 
+ ) + sys.exit(1) + + print(f"Plugin package version string: {version_string}") + print(f"Plugin universal package version string: {universal_version}") + print(f"Plugin Python package version string: {python_version}") + + # Validate semver 2.0.0 format + semver_pattern = r"^(0|[1-9]\d*)\.(0|[1-9]\d*)\.(0|[1-9]\d*)(?:-((?:0|[1-9]\d*|\d*[a-zA-Z-][0-9a-zA-Z-]*)(?:\.(?:0|[1-9]\d*|\d*[a-zA-Z-][0-9a-zA-Z-]*))*))?(?:\+([0-9a-zA-Z-]+(?:\.[0-9a-zA-Z-]+)*))?$" + if not re.match(semver_pattern, version_string): + print(f"##vso[task.logissue type=error]Version string '{version_string}' is not valid semver 2.0.0.") + sys.exit(1) + + # Validate universal version (SemVer 2.0.0, without build metadata) + universal_semver_pattern = r"^(0|[1-9]\d*)\.(0|[1-9]\d*)\.(0|[1-9]\d*)(?:-((?:0|[1-9]\d*|\d*[a-zA-Z-][0-9a-zA-Z-]*)(?:\.(?:0|[1-9]\d*|\d*[a-zA-Z-][0-9a-zA-Z-]*))*))?$" + if not re.match(universal_semver_pattern, universal_version): + print( + f"##vso[task.logissue type=error]Universal version string '{universal_version}' is not valid semver 2.0.0 (without build metadata)." + ) + sys.exit(1) + + # Validate Python version (PEP 440) + pep440_pattern = r"^([1-9][0-9]*!)?(0|[1-9][0-9]*)(\.(0|[1-9][0-9]*))*((a|b|rc)(0|[1-9][0-9]*))?(\.post(0|[1-9][0-9]*))?(\.dev(0|[1-9][0-9]*))?$" + if not re.match(pep440_pattern, python_version): + print(f"##vso[task.logissue type=error]Python version string '{python_version}' is not valid PEP 440.") + sys.exit(1) + + print(f"##vso[task.setvariable variable=PluginPackageVersion]{version_string}") + print(f"##vso[task.setvariable variable=PluginUniversalPackageVersion]{universal_version}") + print(f"##vso[task.setvariable variable=PluginPythonPackageVersion]{python_version}") + print(f"##vso[task.setvariable variable=PluginEpVersionDefine]onnxruntime_PLUGIN_EP_VERSION={version_string}") + + +if __name__ == "__main__": + main()
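Reviewer note on the `commit-` prefix in `universal_version`: the sketch below is not part of the change. It reuses the universal SemVer pattern from `set_plugin_build_variables.py` to illustrate the leading-zero case described in the script's comment. The helper names `build_universal_version` and `is_valid_universal`, and the sample version, date, and SHA values, are hypothetical and exist only for this illustration.

```python
import re

# Same SemVer 2.0.0 pattern (without build metadata) that the script applies
# to universal_version, split across lines for readability only.
UNIVERSAL_SEMVER = re.compile(
    r"^(0|[1-9]\d*)\.(0|[1-9]\d*)\.(0|[1-9]\d*)"
    r"(?:-((?:0|[1-9]\d*|\d*[a-zA-Z-][0-9a-zA-Z-]*)"
    r"(?:\.(?:0|[1-9]\d*|\d*[a-zA-Z-][0-9a-zA-Z-]*))*))?$"
)


def build_universal_version(base, date_str, sha, prefix="commit-"):
    """Hypothetical helper mirroring the f-string used in the 'dev' branch."""
    return f"{base}-dev.{date_str}.{prefix}{sha}"


def is_valid_universal(version):
    """Return True when the string matches the universal SemVer pattern."""
    return UNIVERSAL_SEMVER.match(version) is not None


# Illustrative inputs (assumed values, not taken from a real build).
all_digit_sha = "01234567"  # all-numeric short SHA with a leading zero
mixed_sha = "0a1b2c3d"      # typical short SHA containing hex letters

with_prefix = build_universal_version("1.25.0", "20250101", all_digit_sha)
without_prefix = build_universal_version("1.25.0", "20250101", all_digit_sha, prefix="")
mixed_without_prefix = build_universal_version("1.25.0", "20250101", mixed_sha, prefix="")

print(with_prefix, is_valid_universal(with_prefix))
print(without_prefix, is_valid_universal(without_prefix))
print(mixed_without_prefix, is_valid_universal(mixed_without_prefix))
```

Only the all-numeric SHA with a leading zero is rejected without the prefix, which is why the failure would be rare and hard to reproduce; prefixing every SHA keeps the identifier alphanumeric in all cases.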