From b81e0c65e686c19f6de91750612a55b0c2419ce7 Mon Sep 17 00:00:00 2001 From: eserscor Date: Mon, 4 May 2026 09:34:58 -0400 Subject: [PATCH 01/34] Bump version for 1.27.0 (#28324) ### Description Bump version to 1.27.0. --- VERSION_NUMBER | 2 +- docs/python/README.rst | 5 +++++ include/onnxruntime/core/session/onnxruntime_c_api.h | 2 +- js/common/lib/version.ts | 2 +- js/common/package-lock.json | 4 ++-- js/common/package.json | 2 +- js/node/lib/version.ts | 2 +- js/node/package-lock.json | 6 +++--- js/node/package.json | 2 +- js/node/script/install-metadata-versions.js | 2 +- js/react_native/lib/version.ts | 2 +- js/react_native/package-lock.json | 6 +++--- js/react_native/package.json | 2 +- js/web/lib/version.ts | 2 +- js/web/package-lock.json | 6 +++--- js/web/package.json | 2 +- onnxruntime/__init__.py | 2 +- onnxruntime/core/session/onnxruntime_c_api.cc | 2 +- 18 files changed, 29 insertions(+), 24 deletions(-) diff --git a/VERSION_NUMBER b/VERSION_NUMBER index 5ff8c4f5d2ad2..5db08bf2dc579 100644 --- a/VERSION_NUMBER +++ b/VERSION_NUMBER @@ -1 +1 @@ -1.26.0 +1.27.0 diff --git a/docs/python/README.rst b/docs/python/README.rst index 0e03575236613..e8190c584fb62 100644 --- a/docs/python/README.rst +++ b/docs/python/README.rst @@ -8,6 +8,11 @@ For more information on ONNX Runtime, please see `aka.ms/onnxruntime Date: Mon, 4 May 2026 09:30:58 -0700 Subject: [PATCH 02/34] Bump plugin-ep-webgpu/VERSION_NUMBER to 0.2.0. (#28322) ### Description Bump plugin-ep-webgpu/VERSION_NUMBER to 0.2.0. ### Motivation and Context Version bump after creating the release branch. --- plugin-ep-webgpu/VERSION_NUMBER | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/plugin-ep-webgpu/VERSION_NUMBER b/plugin-ep-webgpu/VERSION_NUMBER index 6e8bf73aa550d..0ea3a944b399d 100644 --- a/plugin-ep-webgpu/VERSION_NUMBER +++ b/plugin-ep-webgpu/VERSION_NUMBER @@ -1 +1 @@ -0.1.0 +0.2.0 From f1c96d56c86d92e56136383cdcab76071419f05c Mon Sep 17 00:00:00 2001 From: Sanaa Hamel Date: Mon, 4 May 2026 13:04:06 -0400 Subject: [PATCH 03/34] refactor(ci): simplify build date/time metadata propagation (#28294) ### Description Avoid having to depend on setup job/task for build date/time. Use pipeline var & runtime expression instead. ### Motivation and Context Faster, no need to wait for an ad-hoc job to set pipeline variables. Easier to read/reason about, reduces cross-stage deps. --- csharp/OnnxRuntime.CSharp.proj | 2 -- .../azure-pipelines/dml-nuget-packaging.yml | 5 ++++- .../nuget/templates/dml-vs-2022.yml | 4 +--- .../stages/nodejs-win-packaging-stage.yml | 5 +---- .../stages/nuget-cuda-packaging-stage.yml | 12 ++++++++---- .../stages/set_packaging_variables_stage.yml | 15 --------------- .../azure-pipelines/templates/c-api-cpu.yml | 15 +++++++++------ .../templates/common-variables.yml | 2 ++ .../templates/foundry-local-nuget-packaging.yml | 2 -- .../templates/managed-nuget-for-foundry-local.yml | 12 +++++++++--- 10 files changed, 34 insertions(+), 40 deletions(-) diff --git a/csharp/OnnxRuntime.CSharp.proj b/csharp/OnnxRuntime.CSharp.proj index 6779fd60bcd0a..9e96c3ca16105 100644 --- a/csharp/OnnxRuntime.CSharp.proj +++ b/csharp/OnnxRuntime.CSharp.proj @@ -50,8 +50,6 @@ CMake creates a target to this project - $(BuildDate) - $(BuildTime) $([System.DateTime]::UtcNow.ToString(yyyyMMdd)) $([System.DateTime]::UtcNow.ToString(hhmm)) diff --git a/tools/ci_build/github/azure-pipelines/dml-nuget-packaging.yml b/tools/ci_build/github/azure-pipelines/dml-nuget-packaging.yml index 3cf28655c36e7..888a9142088ee 100644 --- a/tools/ci_build/github/azure-pipelines/dml-nuget-packaging.yml +++ b/tools/ci_build/github/azure-pipelines/dml-nuget-packaging.yml @@ -31,6 +31,9 @@ parameters: type: number default: 0 +variables: +- template: templates/common-variables.yml + extends: # The pipeline extends the 1ES PT which will inject different SDL and compliance tasks. # For non-production pipelines, use "Unofficial" as defined below. @@ -76,7 +79,7 @@ extends: DoEsrp: ${{ parameters.DoEsrp }} NuPackScript: | python -m pip install setuptools - msbuild $(Build.SourcesDirectory)\csharp\OnnxRuntime.CSharp.proj /p:Configuration=RelWithDebInfo /t:CreatePackage /p:OrtPackageId=Microsoft.ML.OnnxRuntime.DirectML /p:IsReleaseBuild=${{ parameters.IsReleaseBuild }} /p:CurrentData=$(BuildDate) /p:CurrentTime=$(BuildTime) + msbuild $(Build.SourcesDirectory)\csharp\OnnxRuntime.CSharp.proj /p:Configuration=RelWithDebInfo /t:CreatePackage /p:OrtPackageId=Microsoft.ML.OnnxRuntime.DirectML /p:IsReleaseBuild=${{ parameters.IsReleaseBuild }} /p:CurrentData=$(ORT_CI_BUILD_DATE) /p:CurrentTime=$(ORT_CI_BUILD_TIME) if errorlevel 1 exit /b 1 copy $(Build.SourcesDirectory)\csharp\src\Microsoft.ML.OnnxRuntime\bin\RelWithDebInfo\*.nupkg $(Build.ArtifactStagingDirectory) if errorlevel 1 exit /b 1 diff --git a/tools/ci_build/github/azure-pipelines/nuget/templates/dml-vs-2022.yml b/tools/ci_build/github/azure-pipelines/nuget/templates/dml-vs-2022.yml index 2548eebeb9d42..fa009c379a911 100644 --- a/tools/ci_build/github/azure-pipelines/nuget/templates/dml-vs-2022.yml +++ b/tools/ci_build/github/azure-pipelines/nuget/templates/dml-vs-2022.yml @@ -20,7 +20,7 @@ parameters: IsReleaseBuild: false stages: - stage: ${{ parameters.StageName }} - dependsOn: Setup + dependsOn: [] jobs: - job: ${{ parameters.StageName }} timeoutInMinutes: 200 @@ -39,8 +39,6 @@ stages: OnnxRuntimeBuildDirectory: '$(Build.BinariesDirectory)' DOTNET_SKIP_FIRST_TIME_EXPERIENCE: true ALLOW_RELEASED_ONNX_OPSET_ONLY: ${{ parameters.AllowReleasedOpsetOnly }} - BuildDate : $[stageDependencies.Setup.Set_Variables.outputs['Set_Build_Date.BuildDate']] - BuildTime : $[stageDependencies.Setup.Set_Variables.outputs['Set_Build_Time.BuildTime']] ${{ if eq(parameters.EnableLto, true) }}: build_py_lto_flag: --enable_lto diff --git a/tools/ci_build/github/azure-pipelines/stages/nodejs-win-packaging-stage.yml b/tools/ci_build/github/azure-pipelines/stages/nodejs-win-packaging-stage.yml index b9f2cc0987816..1a0ebd783d552 100644 --- a/tools/ci_build/github/azure-pipelines/stages/nodejs-win-packaging-stage.yml +++ b/tools/ci_build/github/azure-pipelines/stages/nodejs-win-packaging-stage.yml @@ -17,7 +17,6 @@ parameters: stages: - stage: ${{ parameters.StageName }} dependsOn: - - Setup - ${{ if ne(parameters.DependsOnStageName, '') }}: - ${{ parameters.DependsOnStageName }} @@ -60,8 +59,6 @@ stages: runCodesignValidationInjection: ${{ parameters. DoEsrp}} #For the others, code sign is in a separated job DOTNET_SKIP_FIRST_TIME_EXPERIENCE: true ALLOW_RELEASED_ONNX_OPSET_ONLY: ${{ parameters.AllowReleasedOpsetOnly }} - BuildDate : $[stageDependencies.Setup.Set_Variables.outputs['Set_Build_Date.BuildDate']] - BuildTime : $[stageDependencies.Setup.Set_Variables.outputs['Set_Build_Time.BuildTime']] BuildCommandExtra: '' ${{ if eq(parameters.EnableLto, true) }}: build_py_lto_flag: --enable_lto @@ -179,4 +176,4 @@ stages: mkdir $(Build.ArtifactStagingDirectory)\${{ parameters.WebGpuBuildToolsArtifactName }} copy $(Build.BinariesDirectory)\$(BuildConfig)\_deps\dawn-build\third_party\dxc\RelWithDebInfo\bin\llvm-tblgen.exe $(Build.ArtifactStagingDirectory)\${{ parameters.WebGpuBuildToolsArtifactName }} copy $(Build.BinariesDirectory)\$(BuildConfig)\_deps\dawn-build\third_party\dxc\RelWithDebInfo\bin\clang-tblgen.exe $(Build.ArtifactStagingDirectory)\${{ parameters.WebGpuBuildToolsArtifactName }} - displayName: 'Copy WebGPU build tools' \ No newline at end of file + displayName: 'Copy WebGPU build tools' diff --git a/tools/ci_build/github/azure-pipelines/stages/nuget-cuda-packaging-stage.yml b/tools/ci_build/github/azure-pipelines/stages/nuget-cuda-packaging-stage.yml index 79bbe39ce4af2..20105b467d001 100644 --- a/tools/ci_build/github/azure-pipelines/stages/nuget-cuda-packaging-stage.yml +++ b/tools/ci_build/github/azure-pipelines/stages/nuget-cuda-packaging-stage.yml @@ -34,8 +34,6 @@ stages: variables: breakCodesignValidationInjection: ${{ parameters.DoEsrp }} ReleaseVersionSuffix: $[stageDependencies.Setup.Set_Variables.outputs['Set_Release_Version_Suffix.ReleaseVersionSuffix']] - BuildDate: $[format('{0:yyyyMMdd}', pipeline.startTime)] - BuildTime: $[format('{0:HHmm}', pipeline.startTime)] steps: - checkout: self @@ -134,8 +132,14 @@ stages: solution: '$(Build.SourcesDirectory)\csharp\OnnxRuntime.CSharp.proj' configuration: RelWithDebInfo platform: 'Any CPU' - msbuildArguments: '-t:CreatePackage -p:OnnxRuntimeBuildDirectory="$(Build.BinariesDirectory)" -p:OrtPackageId=Microsoft.ML.OnnxRuntime.Gpu -p:IsReleaseBuild=${{ parameters.IsReleaseBuild }} - -p:ReleaseVersionSuffix=$(ReleaseVersionSuffix) -p:CurrentDate=$(BuildDate) -p:CurrentTime=$(BuildTime)' + msbuildArguments: >- + -t:CreatePackage + "-p:OnnxRuntimeBuildDirectory=$(Build.BinariesDirectory)" + -p:OrtPackageId=Microsoft.ML.OnnxRuntime.Gpu + "-p:IsReleaseBuild=${{ parameters.IsReleaseBuild }}" + "-p:ReleaseVersionSuffix=$(ReleaseVersionSuffix)" + "-p:CurrentDate=$(ORT_CI_BUILD_DATE)" + "-p:CurrentTime=$(ORT_CI_BUILD_TIME)" workingDirectory: '$(Build.SourcesDirectory)\csharp' - task: BatchScript@1 diff --git a/tools/ci_build/github/azure-pipelines/stages/set_packaging_variables_stage.yml b/tools/ci_build/github/azure-pipelines/stages/set_packaging_variables_stage.yml index 07dd1549acd2d..a11fd8b89b8b1 100644 --- a/tools/ci_build/github/azure-pipelines/stages/set_packaging_variables_stage.yml +++ b/tools/ci_build/github/azure-pipelines/stages/set_packaging_variables_stage.yml @@ -53,21 +53,6 @@ stages: echo "##vso[task.setvariable variable=ReleaseVersionSuffix;isOutput=true]" fi name: Set_Release_Version_Suffix - - script: | - # Extracting hours and minutes - date=$(date +'%Y%m%d') - # Set the hhmm value as a pipeline variable - echo "##vso[task.setvariable variable=BuildDate;isOutput=true]$date" - displayName: 'Set Start Date as Variable' - name: Set_Build_Date - - - script: | - # Extracting hours and minutes - hhmm=$(date +'%H%M') - # Set the hhmm value as a pipeline variable - echo "##vso[task.setvariable variable=BuildTime;isOutput=true]$hhmm" - displayName: 'Set Start Time as Variable' - name: Set_Build_Time - bash: | echo "Recording pipeline parameters to a file..." diff --git a/tools/ci_build/github/azure-pipelines/templates/c-api-cpu.yml b/tools/ci_build/github/azure-pipelines/templates/c-api-cpu.yml index 7b8b5758e79b5..13d5578262102 100644 --- a/tools/ci_build/github/azure-pipelines/templates/c-api-cpu.yml +++ b/tools/ci_build/github/azure-pipelines/templates/c-api-cpu.yml @@ -286,8 +286,6 @@ stages: variables: OrtPackageId: ${{ parameters.OrtNugetPackageId }} ReleaseVersionSuffix: $[stageDependencies.Setup.Set_Variables.outputs['Set_Release_Version_Suffix.ReleaseVersionSuffix']] - BuildDate: $[stageDependencies.Setup.Set_Variables.outputs['Set_Build_Date.BuildDate']] - BuildTime: $[stageDependencies.Setup.Set_Variables.outputs['Set_Build_Time.BuildTime']] steps: - checkout: self @@ -356,7 +354,14 @@ stages: solution: '$(Build.SourcesDirectory)\csharp\OnnxRuntime.CSharp.proj' platform: 'Any CPU' configuration: RelWithDebInfo - msbuildArguments: '-t:CreatePackage -p:OnnxRuntimeBuildDirectory="$(Build.BinariesDirectory)" -p:OrtPackageId=$(OrtPackageId) -p:IsReleaseBuild=${{ parameters.IsReleaseBuild }} -p:ReleaseVersionSuffix=$(ReleaseVersionSuffix) -p:CurrentTime=$(BuildTime) -p:CurrentDate=$(BuildDate)' + msbuildArguments: >- + -t:CreatePackage + "-p:OnnxRuntimeBuildDirectory=$(Build.BinariesDirectory)" + "-p:OrtPackageId=$(OrtPackageId)" + "-p:IsReleaseBuild=${{ parameters.IsReleaseBuild }}" + "-p:ReleaseVersionSuffix=$(ReleaseVersionSuffix)" + "-p:CurrentTime=$(ORT_CI_BUILD_TIME)" + "-p:CurrentDate=$(ORT_CI_BUILD_DATE)" workingDirectory: '$(Build.SourcesDirectory)\csharp' - task: CopyFiles@2 @@ -420,8 +425,6 @@ stages: NpmPackagingMode: 'release' ${{ if not(eq(parameters.IsReleaseBuild, true)) }}: NpmPackagingMode: 'dev' - BuildDate: $[stageDependencies.Setup.Set_Variables.outputs['Set_Build_Date.BuildDate']] - BuildTime: $[stageDependencies.Setup.Set_Variables.outputs['Set_Build_Time.BuildTime']] steps: - checkout: self @@ -635,7 +638,7 @@ stages: Write-Host "Latest version of ${packageName}: $latestVersion" # Generate current version - $currentVersion = "$(cat .\VERSION_NUMBER)-dev-$($env:BuildDate)-$($env:BuildTime)-$(git rev-parse --short HEAD)" + $currentVersion = "$(cat .\VERSION_NUMBER)-dev-$($env:ORT_CI_BUILD_DATE)-$($env:ORT_CI_BUILD_TIME)-$(git rev-parse --short HEAD)" Write-Host "Current version: $currentVersion" # Set the version as an environment variable diff --git a/tools/ci_build/github/azure-pipelines/templates/common-variables.yml b/tools/ci_build/github/azure-pipelines/templates/common-variables.yml index 8c8dae9820810..250a023bcc158 100644 --- a/tools/ci_build/github/azure-pipelines/templates/common-variables.yml +++ b/tools/ci_build/github/azure-pipelines/templates/common-variables.yml @@ -7,5 +7,7 @@ variables: linux_trt_version_cuda12: ${{ variables.cuda12_trt_version }}-1.cuda12.9 # aarch64 TRT tar download (no RPMs available for aarch64) aarch64_trt_download_url_cuda13: https://developer.nvidia.com/downloads/compute/machine-learning/tensorrt/10.15.1/tars/TensorRT-${{ variables.aarch64_trt_version }}.Linux.aarch64-gnu.cuda-13.1.tar.gz + ORT_CI_BUILD_DATE: $[ format('{0:yyyyMMdd}', pipeline.startTime) ] + ORT_CI_BUILD_TIME: $[ format('{0:HHmm}', pipeline.startTime) ] win_trt_folder_cuda13: TensorRT-${{ variables.cuda13_trt_version }}.Windows.win10.cuda-13.0 win_trt_folder_cuda12: TensorRT-${{ variables.cuda12_trt_version }}.Windows.win10.cuda-12.9 diff --git a/tools/ci_build/github/azure-pipelines/templates/foundry-local-nuget-packaging.yml b/tools/ci_build/github/azure-pipelines/templates/foundry-local-nuget-packaging.yml index 2d1c182ec7512..44012af808a46 100644 --- a/tools/ci_build/github/azure-pipelines/templates/foundry-local-nuget-packaging.yml +++ b/tools/ci_build/github/azure-pipelines/templates/foundry-local-nuget-packaging.yml @@ -30,8 +30,6 @@ stages: variables: DoEsrp: ${{ parameters.DoEsrp }} ReleaseVersionSuffix: $[stageDependencies.Setup.Set_Variables.outputs['Set_Release_Version_Suffix.ReleaseVersionSuffix']] - BuildDate: $[stageDependencies.Setup.Set_Variables.outputs['Set_Build_Date.BuildDate']] - BuildTime: $[stageDependencies.Setup.Set_Variables.outputs['Set_Build_Time.BuildTime']] steps: - task: DownloadPipelineArtifact@0 diff --git a/tools/ci_build/github/azure-pipelines/templates/managed-nuget-for-foundry-local.yml b/tools/ci_build/github/azure-pipelines/templates/managed-nuget-for-foundry-local.yml index 0be3f4de65647..c14ac6cc7a3fd 100644 --- a/tools/ci_build/github/azure-pipelines/templates/managed-nuget-for-foundry-local.yml +++ b/tools/ci_build/github/azure-pipelines/templates/managed-nuget-for-foundry-local.yml @@ -31,8 +31,6 @@ stages: variables: OrtPackageId: ${{ parameters.OrtNugetPackageId }} ReleaseVersionSuffix: $[stageDependencies.Setup.Set_Variables.outputs['Set_Release_Version_Suffix.ReleaseVersionSuffix']] - BuildDate: $[stageDependencies.Setup.Set_Variables.outputs['Set_Build_Date.BuildDate']] - BuildTime: $[stageDependencies.Setup.Set_Variables.outputs['Set_Build_Time.BuildTime']] steps: - template: set-version-number-variables-step.yml @@ -86,7 +84,15 @@ stages: solution: '$(Build.SourcesDirectory)\csharp\OnnxRuntime.CSharp.proj' platform: 'AnyCPU' configuration: RelWithDebInfo - msbuildArguments: '-t:CreatePackage -p:OnnxRuntimeBuildDirectory="$(Build.BinariesDirectory)" -p:OrtPackageId=Microsoft.ML.OnnxRuntime -p:IsReleaseBuild=${{ parameters.IsReleaseBuild }} -p:ReleaseVersionSuffix=$(ReleaseVersionSuffix) -p:CurrentTime=$(BuildTime) -p:CurrentDate=$(BuildDate) -p:IncludeMobileTargets=false' + msbuildArguments: >- + -t:CreatePackage + "-p:OnnxRuntimeBuildDirectory=$(Build.BinariesDirectory)" + -p:OrtPackageId=Microsoft.ML.OnnxRuntime + "-p:IsReleaseBuild=${{ parameters.IsReleaseBuild }}" + "-p:ReleaseVersionSuffix=$(ReleaseVersionSuffix)" + "-p:CurrentTime=$(ORT_CI_BUILD_TIME)" + "-p:CurrentDate=$(ORT_CI_BUILD_DATE)" + -p:IncludeMobileTargets=false workingDirectory: '$(Build.SourcesDirectory)\csharp' - task: CopyFiles@2 From 40c9f85f698b37c33cc5cca4f381f7b54dcb1087 Mon Sep 17 00:00:00 2001 From: Edward Chen <18449977+edgchen1@users.noreply.github.com> Date: Mon, 4 May 2026 11:53:57 -0700 Subject: [PATCH 04/34] Add plugin-ep-webgpu/RELEASE.md (#28321) ### Description Add release info doc for WebGPU plugin EP. ### Motivation and Context Document release info. --- plugin-ep-webgpu/RELEASE.md | 62 +++++++++++++++++++++++++++++++++++++ 1 file changed, 62 insertions(+) create mode 100644 plugin-ep-webgpu/RELEASE.md diff --git a/plugin-ep-webgpu/RELEASE.md b/plugin-ep-webgpu/RELEASE.md new file mode 100644 index 0000000000000..8244e38eaee9a --- /dev/null +++ b/plugin-ep-webgpu/RELEASE.md @@ -0,0 +1,62 @@ +# Release Process + +This document describes the release conventions and process for the WebGPU plugin EP. + +## Versioning + +The plugin follows [Semantic Versioning](https://semver.org/): + +- **MAJOR** — incompatible API/ABI changes. +- **MINOR** — backwards-compatible feature additions. +- **PATCH** — backwards-compatible bug and security fixes. + +The current version is tracked in [VERSION_NUMBER](VERSION_NUMBER). + +## Branch and tag naming + +All release refs are namespaced under `plugin-ep-webgpu/` so they group together in `git branch` / `git tag` +listings and don't collide with the main ONNX Runtime release refs. + +- **Release branch:** `plugin-ep-webgpu/rel-X.Y` + - One branch per minor version line (e.g. `plugin-ep-webgpu/rel-1.0`). + - Holds all patch releases for that minor line (1.0.0, 1.0.1, 1.0.2, ...). + - Forked from `main` at the point of the first release on that line. +- **Release tag:** `plugin-ep-webgpu/vX.Y.Z` + - One tag per shipped release (e.g. `plugin-ep-webgpu/v1.0.0`). + - Tags are immutable and are the source of truth for "what shipped." +- **Pre-release tag:** `plugin-ep-webgpu/vX.Y.Z-rc.N` (semver-style) + - Used for release candidates and other pre-release artifacts. + - Note: this convention is forward-looking as we don't have release candidates in the release process yet. + +The `rel-` prefix on branches and the `v` prefix on tags ensure branches and tags are never ambiguous at the ref +level. + +### Difference from the main ONNX Runtime convention + +The main ORT repo uses **per-patch** release branches of the form `rel-X.Y.Z` (e.g. `rel-1.20.0`, `rel-1.20.1`). +This plugin deliberately uses **per-minor** branches (`rel-X.Y`) instead. + +The per-minor model is simpler: one long-lived branch per supported minor line, with each patch release marked by a +tag on that branch. Tags are the immutable record of what shipped; the branch is just where the next patch is staged. +For a component of this size and release cadence, that is sufficient and avoids the branch sprawl of the per-patch +model. + +The per-minor model is also the broader open-source convention (Linux, LLVM, Python, Node, Kubernetes), so +contributors coming from outside the ORT ecosystem will find it familiar. The namespaced ref prefix +(`plugin-ep-webgpu/`) keeps the plugin's release refs cleanly separated from the main ORT release refs. + +## Release workflow + +1. Prepare the release branch. + - New minor or major release: + - Create release branch `plugin-ep-webgpu/rel-X.Y` from `main`. + `main`'s `VERSION_NUMBER` should already be `X.Y.0`, reflecting the release that is about to be cut. + - Bump `VERSION_NUMBER` on `main` to the next development version (e.g. `X.(Y+1).0`). + - Patch release: + - Bump `VERSION_NUMBER` on the release branch to `X.Y.Z`. +2. Integrate any fixes into the release branch. These may be cherry-picked from `main` or made directly in the + release branch. The latter should be re-integrated into `main` unless the fix is specific to the release branch. +3. Run the full validation pipeline against the release branch tip. +4. Repeat steps 2 and 3 as needed. +5. Tag the release branch tip as `plugin-ep-webgpu/vX.Y.Z`. +6. Publish artifacts from the tag. From 4ca6b22880674707f359627c8651b4c81e8bdf8e Mon Sep 17 00:00:00 2001 From: Ti-Tai Wang Date: Mon, 4 May 2026 13:41:48 -0700 Subject: [PATCH 05/34] Eliminate Legacy MHA Unfused path from ONNX Attention; unify on 3-tier dispatch with causal alignment fix (#27992) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## Motivation Eliminate the legacy MHA Unfused path (`QkvToContext` in `attention_impl.cu`) from the ONNX standard Attention op, simplifying the CUDA dispatch to a clean 3-tier cascade. ## Design ``` Flash Attention → Memory-Efficient Attention (MEA) → Unified Unfused Attention ``` - **Flash**: Handles fp16/bf16 with head_size ≤ 256, no explicit attn_mask. Fastest path. - **MEA (CUTLASS)**: Handles cases Flash cannot (explicit masks, softcap+mask combos). Requires head_size % 8 == 0. - **Unified Unfused**: Fallback for everything else — fp32, small heads, H≠H_v, output_qk. Handles both MHA and GQA via FP32 QK accumulation. The legacy `RunUnfusedAttention` wrapper (which called contrib ops `QkvToContext`) is deleted. The contrib MHA op is unaffected. ## Key Behavior Changes - **Unified unfused kernel** replaces separate GQA-only and MHA-only unfused paths - **Causal alignment**: lower-right when past_key is present, upper-left otherwise (per ONNX spec) - **H≠H_v + past KV** now supported (separate K/V concat calls) - **output_qk (mode 0)** supported in unified kernel via `ScaledCopyQkKernel` - **29 ONNX backend test filters removed** — tests now pass natively ## Testing All existing tests pass (40 C++ attention tests, 215 Python parametrized cases) plus new coverage for causal alignment on CPU EP and softcap ordering verification. Closes #27880. Related: #27516, #28198. --------- Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- .../cuda-attention-kernel-patterns/SKILL.md | 237 ++++ .../cpu/bert/attention_parameters.h | 1 + .../contrib_ops/cuda/bert/attention_data.h | 4 +- .../bert/cutlass_fmha/fmha_launch_template.h | 9 +- .../cutlass_fmha/memory_efficient_attention.h | 13 + .../cuda/bert/group_query_attention.cc | 6 +- .../cuda/bert/group_query_attention_impl.cu | 10 +- ...used_attention.cu => unfused_attention.cu} | 118 +- ...nfused_attention.h => unfused_attention.h} | 30 +- .../core/providers/cuda/llm/attention.cc | 684 +++++---- .../core/providers/cuda/llm/attention.h | 17 +- .../providers/cuda/llm/attention_mask_impl.cu | 101 -- .../providers/cuda/llm/attention_mask_impl.h | 28 - .../providers/cpu/llm/attention_op_test.cc | 366 ++++- .../test_onnx_attention/common.py | 67 +- .../test_onnx_attention/test_gqa.py | 435 +++++- .../test_onnx_attention/test_mha.py | 1229 +++++++++++++++-- .../test_tensorscatter_attention.py | 80 +- .../onnx_backend_test_series_filters.jsonc | 32 +- 19 files changed, 2695 insertions(+), 772 deletions(-) create mode 100644 .agents/skills/cuda-attention-kernel-patterns/SKILL.md rename onnxruntime/contrib_ops/cuda/bert/{gqa_unfused_attention.cu => unfused_attention.cu} (77%) rename onnxruntime/contrib_ops/cuda/bert/{gqa_unfused_attention.h => unfused_attention.h} (77%) diff --git a/.agents/skills/cuda-attention-kernel-patterns/SKILL.md b/.agents/skills/cuda-attention-kernel-patterns/SKILL.md new file mode 100644 index 0000000000000..5325a1bf22bdc --- /dev/null +++ b/.agents/skills/cuda-attention-kernel-patterns/SKILL.md @@ -0,0 +1,237 @@ +--- +name: cuda-attention-kernel-patterns +description: Patterns and pitfalls for the ONNX domain Attention operator (opset 23/24) CUDA implementation. Use when modifying the dispatch cascade in core/providers/cuda/llm/attention.cc, writing mask/bias CUDA kernels, debugging attention test routing, or adding features to the ONNX Attention op. NOT for contrib domain MultiHeadAttention/GroupQueryAttention. +--- + +# ONNX Domain Attention (Opset 23/24) CUDA Patterns + +Reusable knowledge from ONNX Attention CUDA development in ORT. + +> **Scope**: This skill covers the **ONNX domain** `Attention` operator (opset 23/24) +> implemented at `core/providers/cuda/llm/attention.cc`. This is **separate from** the +> contrib domain `MultiHeadAttention` / `GroupQueryAttention` at `contrib_ops/cuda/bert/`. +> They share some underlying kernels (CUTLASS FMHA, Flash Attention) and infrastructure +> (`attention_softmax.h`) but have **different dispatch logic, parameter structs, and eligibility checks**. +> +> - **Shared infrastructure**: CUTLASS FMHA kernel, Flash kernel, unified unfused kernel +> (`unfused_attention.cu`), `attention_softmax.h`, `attention_impl.cu` (contrib only) +> - **ONNX-specific**: Dispatch cascade in `attention.cc`, `ConvertAttnMaskToBias`, +> `mask_filter_value` cap, parameter bridge to contrib structs, `attention_mask_impl.cu` +> - **Contrib-specific**: Own dispatch in contrib MHA/GQA ops, uses `contrib::AttentionParameters` +> directly, has XQA kernel, past-present buffer sharing + +## 1. Runner Dispatch Cascade + +CUDA attention dispatches in priority order: **Flash → MEA (Memory Efficient) → Unified Unfused Attention**. + +``` +// onnxruntime/core/providers/cuda/llm/attention.cc — ComputeInternal() +Flash eligible? → RunFlashAttention() + ↓ no +MEA eligible? → RunMemoryEfficientAttention() + ↓ no +Unified Unfused → RunUnfusedAttention() + (handles both MHA and GQA via reshape-Q trick) +``` + +**Flash eligibility**: fp16/bf16 only, SM≥8.0 (Ampere+), `head_size == v_head_size`, `head_size <= 256`, no `output_qk`, `attn_mask == nullptr`. Uses `mha_fwd` / `mha_fwd_kvcache`. + +**MEA eligibility**: SM50+/53+/80+ by dtype, `head_size <= 1024` and divisible by 8, no `output_qk`. Decode requires `head_size == v_head_size` (for `LaunchConcatNewToPastKV`). Bias stride must satisfy `total_sequence_length % 4 == 0`. GQA with FP32 is excluded (LaunchUngroup only has fp16/bf16 instantiations). Supports `softcap + attn_mask` — CUTLASS applies softcap before bias in kernel tiles, matching ONNX spec ordering (onnx/onnx#7865). + +**Unified Unfused Attention**: Always available as the final fallback. Handles both MHA (`num_heads == kv_num_heads`, group=1) and GQA (`num_heads != kv_num_heads`, group>1) via a reshape-Q trick with stride-based cuBLAS batched GEMM (no K/V head replication). Uses FP32 QK scratch for precision. Supports all features: +- softcap + attn_mask (spec-correct ordering) +- output_qk (kQK mode: copies raw QK before softcap/mask mutations) +- past_key + past_value with `head_size != v_head_size` (separate K/V concat) +- causal masking, nonpad_kv_seqlen, all dtypes (fp16/bf16/fp32) + +## 2. CUTLASS kLog2e Overflow + +CUTLASS `iterative_softmax` multiplies all attention scores by `kLog2e ≈ 1.4427` internally (for `exp2f` instead of `expf`). For float/bf16: + +``` +mask_filter_value = std::numeric_limits::lowest() ≈ -3.40e+38 +-3.40e+38 × 1.4427 ≈ -4.91e+38 → overflows fp32 → -inf +``` + +When all values become `-inf`, CUTLASS's special-case path produces `s_prime=0` → `1/s_prime=inf` → `0 × inf = NaN`. + +**Fix**: Cap `mask_filter_value` to `-1.0e+30f` in `ConvertAttnMaskToBias`. This value is safe: `1e30 × 1.4427 ≈ 1.4e30 << FLT_MAX`, and `exp(-1e30) ≈ 0` (effectively masked). + +**fp16 is NOT affected**: `lowest() = -65504`, and `-65504 × 1.4427 ≈ -94500` stays within fp32 range. + +This cap is ONLY applied in MEA paths. The unfused path uses `lowest()` directly (its softmax subtracts max first, avoiding overflow). + +**Subtlety**: When bias is present (`kSupportsBias=true`), CUTLASS pre-applies `p.scale` to QK (line 858) and uses `scaling=1.0f` in the softmax loop (line 981). So the full `kLog2e` multiplier hits the bias-dominated values — the overflow is head_size-independent. Without bias, `scaling = p.scale * kLog2e = kLog2e/sqrt(head_size)`, which is much smaller. + +## 3. Bias Alignment + +CUTLASS FMHA requires the attention bias row stride to satisfy minimum alignment. The bias has shape `[B, H, S, T]` where `T = total_sequence_length` is the row stride. + +```cpp +constexpr int min_bias_align = 4; // elements, not bytes +if (parameters.total_sequence_length % min_bias_align != 0) { + mea_eligible = false; // fall through to unfused +} +``` + +**Impact on tests**: If a test uses `total_sequence_length` not divisible by 4 (e.g., past=5 + new=6 = 11), MEA is rejected and unfused handles it. To test MEA with bias, ensure `total_sequence_length % 4 == 0`. + +## 4. Softcap Ordering + +ONNX spec ordering (onnx/onnx#7865): `QK → scale → softcap → add mask/bias → softmax` + +- **MEA (CUTLASS)**: Fuses softcap before bias in kernel tile loop (`kernel_forward.h`). Matches spec ordering. +- **Flash**: Handles softcap natively in `mha_fwd`/`mha_fwd_kvcache` but rejects `attn_mask`, so ordering with mask is moot. +- **Unfused**: Handles spec-correct ordering in the fused softmax kernel: `QK → scale → softcap → add bias → softmax`. + +All three paths apply softcap BEFORE mask/bias. If softcap were applied after masking, `tanh(-inf/sc) = -sc` (finite), leaking probability to masked positions. + +The unfused path does: `QK → scale → softcap → add bias → softmax` (all fused in `UnfusedSoftmaxKernel`). + +## 5. Grid-Stride Loops for CUDA Kernels + +Always cap grid size to prevent exceeding `gridDim.x` limits, and use grid-stride loops for large workloads: + +```cpp +constexpr int64_t kMaxGridDimX = 65535; +int threads = static_cast(std::min(static_cast(max_threads_per_block), total)); +int64_t blocks = (total + threads - 1) / threads; +unsigned int grid_size = static_cast(std::min(blocks, kMaxGridDimX)); + +MyKernel<<>>(...); + +// Inside the kernel: +for (int64_t idx = blockIdx.x * blockDim.x + threadIdx.x; + idx < total; + idx += static_cast(gridDim.x) * blockDim.x) { + // work +} +``` + +**Never** cast `int64_t` block count directly to `unsigned int` without capping — it silently truncates. + +Always call `CUDA_CALL(cudaGetLastError())` after kernel launches in standalone helper functions. This is the established pattern in the file (see `ConcatPastToPresent`, `PastPresentBufferShare`). + +## 6. Fully-Masked Batches + +All-false bool masks or `seqlens_k=0` produce NaN in CUTLASS MEA. + +**Additive-bias path** (bool mask converted to bias): Fixed by capping `mask_filter_value` to `-1e+30f` (see section 2). CUTLASS then naturally computes uniform softmax → mean(V). + +**Nonpad path** (`seqlens_k=0`): CUTLASS skips all K/V positions → `s_prime=0` → NaN. Fixed by `ZeroOutputForFullyMaskedBatches` kernel which zeros output for batches where `seqlens_k[b] == 0`. Note: this produces zeros, not mean(V) — a cross-EP consistency TODO exists. + +**CPU/Unfused behavior**: `mask_filter_value = lowest()` (not `-inf`). All masked values are equal → `softmax(equal) = 1/N` → output = mean(V). This is the spec reference. + +## 7. Test Runner Targeting + +Use `ScopedEnvironmentVariables` to force specific CUDA runners: + +```cpp +// Force MEA (disable Flash) +ScopedEnvironmentVariables scoped_env({ + {"ORT_DISABLE_FLASH_ATTENTION", "1"}, +}); + +// Force Unfused (disable both Flash and MEA) +ScopedEnvironmentVariables scoped_env({ + {"ORT_DISABLE_FLASH_ATTENTION", "1"}, + {"ORT_DISABLE_MEMORY_EFFICIENT_ATTENTION", "1"}, +}); +``` + +**Always verify which runner a test actually hits.** A test designed for MEA may silently fall to unfused if: +- `total_sequence_length % 4 != 0` (bias alignment) +- `head_size != v_head_size` (decode path) +- fp32 dtype with GQA (LaunchUngroup fp16/bf16 only) +- fp32 dtype on SM < 80 + +Enable verbose logging to confirm: `LOGS_DEFAULT(VERBOSE) << "ONNX Attention: using ..."`. + +## 8. Cross-EP Consistency + +CPU is the spec reference implementation. CUDA outputs should match CPU for all valid inputs. + +- CPU uses `mask_filter_value = std::numeric_limits::lowest()` (finite, not `-inf`) +- CPU softmax: subtract-max-first → works correctly with extreme finite values +- CPU handles fully-masked batches naturally (uniform softmax → mean(V)) + +Run tests with `disable_cpu=false` to always validate against CPU. The C++ test framework (`RunTest4D`) supports `disable_cpu`, `disable_cuda`, `disable_dml` flags. + +## 9. File Locations + +### ONNX Domain (this op's code) + +| File | Purpose | +|------|---------| +| `core/providers/cuda/llm/attention.cc` | ONNX Attention CUDA dispatch: Flash/MEA/Unfused cascade, `ConvertAttnMaskToBias`, parameter setup | +| `core/providers/cuda/llm/attention_mask_impl.cu` | ONNX-specific mask/bias CUDA kernels: bool→bias, nonpad→seqlens_k, ZeroOutput, bias composition | +| `core/providers/cuda/llm/attention_mask_impl.h` | Declarations for ONNX mask/bias kernels | +| `core/providers/cpu/llm/attention.cc` | CPU reference implementation (ONNX domain) | +| `core/providers/cpu/llm/attention_helper.h` | ONNX parameter validation and shape computation | +| `test/providers/cpu/llm/attention_op_test.cc` | C++ attention tests (all EPs) | +| `test/python/transformers/test_onnx_attention/test_mha.py` | Python parity tests | +| `test/python/transformers/test_onnx_attention/common.py` | Python test utilities and reference `attention_ref()` | + +### Shared Infrastructure (used by both ONNX and contrib ops) + +| File | Purpose | +|------|---------| +| `contrib_ops/cuda/bert/unfused_attention.cu` | Unified unfused attention: QK GEMM (FP32), fused softmax kernel (scale+softcap+bias+causal), V GEMM. Handles MHA and GQA. | +| `contrib_ops/cuda/bert/unfused_attention.h` | `UnfusedAttentionParams`, `LaunchUnfusedAttention`, workspace size | +| `contrib_ops/cuda/bert/attention_impl.cu` | Legacy unfused `QkvToContext` (contrib MHA only). Also `ApplySoftcap`, `ConcatPastToPresent` | +| `contrib_ops/cuda/bert/attention_softmax.h` | CUDA softmax kernels (`ComputeSoftmax`, `ComputeSoftmaxWithRawMask`) — used by legacy contrib path | +| `contrib_ops/cuda/bert/cutlass_fmha/` | CUTLASS FMHA (Memory Efficient Attention) kernels | +| `contrib_ops/cuda/bert/flash_attention/` | Flash Attention kernels | + +### Contrib Domain (separate ops, NOT covered by this skill) + +| File | Purpose | +|------|---------| +| `contrib_ops/cuda/bert/multihead_attention.cu` | Contrib `MultiHeadAttention` — own dispatch, uses `contrib::AttentionParameters` directly | +| `contrib_ops/cuda/bert/group_query_attention.cu` | Contrib `GroupQueryAttention` — has XQA kernel, past-present buffer sharing | + +## 10. Parameter Bridge (ONNX → Contrib) + +The ONNX Attention op uses `attention_helper::AttentionParameters` (in `core/providers/cpu/llm/attention_parameters.h`). The unified unfused kernel (`LaunchUnfusedAttention`) uses its own `UnfusedAttentionParams` struct populated directly from ONNX parameters in `RunUnfusedAttention`. + +The contrib `QkvToContext` function (used by contrib MHA, NOT by ONNX Attention) uses `contrib::AttentionParameters`. ONNX Attention does **not** bridge to `contrib::AttentionParameters` — it routes through the unified unfused kernel instead. + +## 11. Causal Alignment + +The ONNX spec defines two causal alignment modes based on where query positions sit in the full attention matrix: + +- **Upper-left**: `q_i` attends to `kv[0..i]`. Query positions start at 0 in the full matrix. +- **Lower-right**: `q_i` attends to `kv[kv_len - q_len + i..kv_len - 1]`. Query positions are at the end. + +**ONNX spec rule**: `is_causal=1` always means upper-left in the full matrix. When `past_key` provides context, `past_sequence_length` shifts the query start position forward — the resulting `[S_q × total_kv]` sub-matrix effectively has lower-right alignment. + +### Per-kernel behavior + +| Kernel | Alignment | Mechanism | +|--------|-----------|-----------| +| **Flash** | Lower-right only | `is_causal` flag → `seqlen_k - seqlen_q` offset in kernel. No top-left option. | +| **MEA (CUTLASS)** | Both | `causal_from_top_left` flag in `MemoryEfficientAttentionParams`. `true` → `CausalFromTopLeft` (offset=0). `false` → `CausalFromBottomRight` (offset = num_keys - num_queries). | +| **Unfused** | Both | `past_kv_length` param. `0` → upper-left. `total_kv - S_q` → lower-right. | + +### Dispatch logic in attention.cc + +```cpp +// Flash cannot do upper-left → guarded by causal_cross_no_past +bool causal_cross_no_past = parameters.is_causal && + parameters.q_sequence_length != parameters.total_sequence_length && + parameters.past_sequence_length == 0; + +// Flash: skip when causal_cross_no_past (no top-left support) +// MEA: NOT skipped — handles it via causal_from_top_left = (past_sequence_length == 0) +// Unfused: always correct via past_kv_length = parameters.past_sequence_length +``` + +### When S_q == S_kv + +Upper-left and lower-right produce **identical** results when `S_q == S_kv` (the offset is 0 either way). The alignment distinction only matters for cross-attention shapes (`S_q != S_kv`). + +### TensorScatter decode (opset 24 external KV cache) + +TensorScatter manages KV cache externally — `past_key` is nullptr but K/V already contain the full sequence. Per the ONNX spec, `is_causal` with `S_q != S_kv` and no `past_key` means upper-left (q[0] sees only kv[0]), which is **not meaningful for decode**. + +**Correct pattern**: TensorScatter decode must use `is_causal=0` and rely on `nonpad_kv_seqlen` to bound the active KV range. Models using `is_causal=1` with TensorScatter decode have a spec-invalid combination. diff --git a/onnxruntime/contrib_ops/cpu/bert/attention_parameters.h b/onnxruntime/contrib_ops/cpu/bert/attention_parameters.h index f316a0dfdf91c..5b7624d11c6fd 100644 --- a/onnxruntime/contrib_ops/cpu/bert/attention_parameters.h +++ b/onnxruntime/contrib_ops/cpu/bert/attention_parameters.h @@ -33,6 +33,7 @@ struct AttentionParameters { bool broadcast_attn_bias_dim_1 = false; float mask_filter_value = 0.0f; float scale = 0.0f; + float softcap = 0.0f; bool use_tf32 = false; bool is_output_bnsh = false; // whether the output format is BNSH AttentionMaskType mask_type = AttentionMaskType::MASK_NONE; diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_data.h b/onnxruntime/contrib_ops/cuda/bert/attention_data.h index 98f92b79e6ec6..60f2d05446da1 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention_data.h +++ b/onnxruntime/contrib_ops/cuda/bert/attention_data.h @@ -205,11 +205,11 @@ struct GroupQueryAttentionData { void* xqa_buffer = nullptr; size_t xqa_buffer_bytes = 0; - // Unfused fallback buffers (see LaunchGqaUnfusedAttention in gqa_unfused_attention.h): + // Unfused fallback buffers (see LaunchUnfusedAttention in unfused_attention.h): // unfused_q_bnsh : [B, N_q, S_q, H] (Q transposed from BSNH to BNSH) // unfused_y_bnsh : [B, N_q, S_q, H_v] (output BNSH, transposed to BSNH before leaving op) // unfused_workspace: FP32 QK scratch + T softmax scratch (sized by - // GetGqaUnfusedAttentionWorkspaceSize) + // GetUnfusedAttentionWorkspaceSize) T* unfused_q_bnsh = nullptr; T* unfused_y_bnsh = nullptr; void* unfused_workspace = nullptr; diff --git a/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/fmha_launch_template.h b/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/fmha_launch_template.h index 29bb4fba6a09a..aedb370d38367 100644 --- a/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/fmha_launch_template.h +++ b/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/fmha_launch_template.h @@ -176,7 +176,14 @@ void LaunchCutlassFmha(const MemoryEfficientAttentionParams& params) { p.num_keys = params.kv_sequence_length; if (params.causal) { - p.custom_mask_type = Attention::CausalFromBottomRight; + // ONNX spec: is_causal means upper-left alignment (q_i attends to kv[0..i]). + // When past_sequence_length > 0 (decode with KV cache), positions shift → lower-right. + // causal_from_top_left=true: past_seq==0, use CausalFromTopLeft (offset=0). + // causal_from_top_left=false: past_seq>0 or S_q==S_kv, use CausalFromBottomRight + // (offset = num_keys - num_queries, which is 0 when square). + p.custom_mask_type = params.causal_from_top_left + ? Attention::CausalFromTopLeft + : Attention::CausalFromBottomRight; } // We use max_sequence_length to calculate KV stride diff --git a/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/memory_efficient_attention.h b/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/memory_efficient_attention.h index ace598489a226..a961be051a16a 100644 --- a/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/memory_efficient_attention.h +++ b/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/memory_efficient_attention.h @@ -13,6 +13,13 @@ namespace cuda { constexpr int kEfficientAttentionMaxHeadSize = 1024; +// CUTLASS online softmax multiplies attention scores by kLog2e (≈1.4427). +// For float/bf16, |lowest() × kLog2e| > FLT_MAX, overflowing to -inf and +// causing s_prime=0 → NaN for fully-masked batches. Cap to prevent this. +// -1e+30 is safe: 1e30 × 1.4427 ≈ 1.4e30 << FLT_MAX ≈ 3.4e38, and +// exp(-1e30) ≈ 0 (effectively masked). For fp16 lowest()=-65504 > -1e30, no-op. +constexpr float kCutlassSafeMaskFilterValue = -1.0e+30f; + struct MemoryEfficientAttentionParams { int32_t sm = 50; bool is_half = false; @@ -27,6 +34,12 @@ struct MemoryEfficientAttentionParams { int32_t v_head_size = 0; int32_t local_window_size = -1; bool causal = false; + // When true, causal masking uses upper-left alignment (q_i attends to kv[0..i]). + // When false (default), uses lower-right alignment (q_i attends to kv[kv_len-q_len+i..kv_len-1]). + // ONNX Attention spec requires upper-left for cross-attention without past (S_q != S_kv, past=0). + // Lower-right is correct for decode with KV cache (past > 0). + // For square matrices (S_q == S_kv), both alignments produce identical results. + bool causal_from_top_left = false; bool use_smooth_softmax = false; bool broadcast_attn_bias_dim_0 = false; bool broadcast_attn_bias_dim_1 = false; diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc b/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc index 5f21f3cd34e8f..dfecc2b810a04 100644 --- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc +++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc @@ -14,7 +14,7 @@ #include "contrib_ops/cuda/bert/cutlass_fmha/memory_efficient_attention.h" #include "contrib_ops/cuda/bert/flash_attention/flash_api.h" #include "contrib_ops/cuda/bert/xqa/xqa_loader.h" -#include "contrib_ops/cuda/bert/gqa_unfused_attention.h" +#include "contrib_ops/cuda/bert/unfused_attention.h" #include "contrib_ops/cuda/utils/dump_cuda_tensor.h" #include "contrib_ops/cpu/utils/debug_macros.h" @@ -513,7 +513,7 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* context) cons // GQA-capable unfused fallback (issue #28195). // Activates when Flash / MEA / XQA are all ineligible and KV is not quantized. // Supports any head_size (FP32 QK accumulation), GQA, sliding window, softcap. - // See LaunchGqaUnfusedAttention in contrib_ops/cuda/bert/gqa_unfused_attention.h. + // See LaunchUnfusedAttention in contrib_ops/cuda/bert/unfused_attention.h. // --------------------------------------------------------------------- IAllocatorUniquePtr unfused_scratch; if (!data.use_xqa && !data.use_flash_attention && !data.use_memory_efficient_attention && @@ -538,7 +538,7 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* context) cons const SafeInt q_bnsh_bytes = align(SafeInt(B) * N_q * S_q * H * sizeof(T)); const SafeInt y_bnsh_bytes = align(SafeInt(B) * N_q * S_q * H_v * sizeof(T)); const SafeInt ws_bytes = SafeInt( - onnxruntime::contrib::cuda::GetGqaUnfusedAttentionWorkspaceSize( + onnxruntime::contrib::cuda::GetUnfusedAttentionWorkspaceSize( static_cast(B), static_cast(N_q), static_cast(S_q), static_cast(S_kv))); const SafeInt workspace_offset = q_bnsh_bytes + y_bnsh_bytes; diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu index ebb6a0b0da215..70c58e6b8f764 100644 --- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu @@ -38,7 +38,7 @@ limitations under the License. #include "contrib_ops/cuda/bert/bert_padding.h" #include "contrib_ops/cuda/bert/cutlass_fmha/memory_efficient_attention.h" #include "contrib_ops/cuda/bert/flash_attention/flash_api.h" -#include "contrib_ops/cuda/bert/gqa_unfused_attention.h" +#include "contrib_ops/cuda/bert/unfused_attention.h" #include "contrib_ops/cuda/bert/group_query_attention_impl.h" #include "contrib_ops/cpu/bert/attention_common.h" #include "contrib_ops/cuda/bert/group_query_attention_qkv.cuh" @@ -1095,7 +1095,7 @@ Status UnfusedGqaAttention( } // Step 3: run unfused attention with FP32 QK accumulation. - GqaUnfusedAttentionParams p; + UnfusedAttentionParams p; p.batch_size = batch_size; p.num_heads = num_heads; p.kv_num_heads = kv_num_heads; @@ -1113,18 +1113,20 @@ Status UnfusedGqaAttention( p.broadcast_attn_bias_dim_1 = false; p.is_causal = parameters.is_unidirectional; p.local_window_size = parameters.local_window_size; // -1 disables + p.past_kv_length = parameters.total_sequence_length - parameters.sequence_length; p.scale = scale; p.softcap = parameters.softcap; p.seqlens_k = data.total_seq_lens; - ORT_RETURN_IF_ERROR((LaunchGqaUnfusedAttention( + ORT_RETURN_IF_ERROR((LaunchUnfusedAttention( device_prop, cublas, stream, p, data.unfused_q_bnsh, reinterpret_cast(data.present_key), reinterpret_cast(data.present_value), /*attn_bias=*/nullptr, data.unfused_y_bnsh, - data.unfused_workspace))); + data.unfused_workspace, + /*output_qk=*/nullptr))); // Step 4: transpose output BNSH → BSNH into data.output. // Use p.v_head_size (== head_size per ORT_ENFORCE) for semantic correctness. diff --git a/onnxruntime/contrib_ops/cuda/bert/gqa_unfused_attention.cu b/onnxruntime/contrib_ops/cuda/bert/unfused_attention.cu similarity index 77% rename from onnxruntime/contrib_ops/cuda/bert/gqa_unfused_attention.cu rename to onnxruntime/contrib_ops/cuda/bert/unfused_attention.cu index 8aac549aeba01..a0c9d4666cae3 100644 --- a/onnxruntime/contrib_ops/cuda/bert/gqa_unfused_attention.cu +++ b/onnxruntime/contrib_ops/cuda/bert/unfused_attention.cu @@ -1,8 +1,9 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -// GQA-capable unfused CUDA attention kernel. See header for contract. +// Unified unfused CUDA attention kernel. See header for contract. +#include #include #include "core/providers/cuda/cu_inc/cub.cuh" #include @@ -13,7 +14,7 @@ #include "core/providers/cuda/cuda_type_conversion.h" #include "core/providers/cuda/shared_inc/cuda_call.h" #include "core/providers/cuda/shared_inc/fpgeneric.h" -#include "contrib_ops/cuda/bert/gqa_unfused_attention.h" +#include "contrib_ops/cuda/bert/unfused_attention.h" using onnxruntime::cuda::OrtToCudaType; @@ -38,10 +39,37 @@ __device__ __forceinline__ float ToFloat<__half>(__half v) { return __half2float template <> __device__ __forceinline__ float ToFloat<__nv_bfloat16>(__nv_bfloat16 v) { return __bfloat162float(v); } +// Device helper: convert float to T. +template +__device__ __forceinline__ T FromFloat(float v); +template <> +__device__ __forceinline__ float FromFloat(float v) { return v; } +template <> +__device__ __forceinline__ __half FromFloat<__half>(float v) { return __float2half(v); } +template <> +__device__ __forceinline__ __nv_bfloat16 FromFloat<__nv_bfloat16>(float v) { return __float2bfloat16(v); } + inline size_t QkElementCount(int batch_size, int num_heads, int q_seq, int total_kv) { return SafeInt(batch_size) * num_heads * q_seq * total_kv; } +// --------------------------------------------------------------------------- +// CopyQK kernel: copies FP32 QK scratch to T output with scale applied. +// output_qk[i] = T(qk_fp32[i] * scale) for i in [0, total_elements). +// --------------------------------------------------------------------------- +template +__global__ void ScaledCopyQkKernel( + const float* __restrict__ qk_fp32, + T* __restrict__ output_qk, + const float scale, + const int64_t total_elements) { + for (int64_t idx = static_cast(blockIdx.x) * TPB + threadIdx.x; + idx < total_elements; + idx += static_cast(gridDim.x) * TPB) { + output_qk[idx] = FromFloat(qk_fp32[idx] * scale); + } +} + // --------------------------------------------------------------------------- // Softmax kernel: reads FP32 QK scores, writes T softmax output. // @@ -56,7 +84,7 @@ inline size_t QkElementCount(int batch_size, int num_heads, int q_seq, int total // total_kv_length. Handles fully-masked rows by emitting zeros (no NaN). // --------------------------------------------------------------------------- template -__global__ void GqaUnfusedSoftmaxKernel( +__global__ void UnfusedSoftmaxKernel( const int q_sequence_length, const int total_kv_length, const int num_heads, // N_q @@ -68,6 +96,7 @@ __global__ void GqaUnfusedSoftmaxKernel( const int* __restrict__ seqlens_k, const bool is_causal, const int local_window_size, + const int past_kv_length, const float scale, const float softcap, T* __restrict__ softmax_out) { @@ -82,12 +111,13 @@ __global__ void GqaUnfusedSoftmaxKernel( if (v < kv_end) kv_end = v; if (v < 0) kv_end = 0; } - // past (number of KV positions before the current query tokens) must be - // per-batch when seqlens_k is provided, since different batches can have - // different amounts of valid past context. Using the global total_kv_length - // would over-estimate past for short batches and shift the sliding-window - // start past kv_end, producing an all-masked (zero) row. - const int past = kv_end - q_sequence_length; + // past_kv_length is the number of KV positions that precede the current query + // tokens. For upper-left causal alignment (ONNX Attention with no past), + // this is 0. For lower-right alignment (decode with past), this is + // total_kv_length - q_sequence_length. + // When seqlens_k varies per batch (GQA sliding window), derive per-batch + // so the window cutoff stays within the valid range for shorter batches. + const int past = (seqlens_k != nullptr) ? (kv_end - q_sequence_length) : past_kv_length; const int q_pos = past + q_in_head; int end = kv_end; @@ -191,16 +221,16 @@ __global__ void GqaUnfusedSoftmaxKernel( } template -void LaunchGqaUnfusedSoftmax( +void LaunchUnfusedSoftmax( cudaStream_t stream, - const GqaUnfusedAttentionParams& params, + const UnfusedAttentionParams& params, const float* qk_in, const T* attn_bias, T* softmax_out) { const dim3 grid(params.num_heads * params.q_sequence_length, params.batch_size, 1); const bool has_bias = (attn_bias != nullptr); constexpr int TPB = 256; - GqaUnfusedSoftmaxKernel<<>>( + UnfusedSoftmaxKernel<<>>( params.q_sequence_length, params.total_kv_length, params.num_heads, @@ -212,6 +242,7 @@ void LaunchGqaUnfusedSoftmax( params.seqlens_k, params.is_causal, params.local_window_size, + params.past_kv_length, params.scale, params.softcap, softmax_out); @@ -250,7 +281,7 @@ template common::Status LaunchQkGemmFp32( const cudaDeviceProp& /*device_prop*/, cublasHandle_t cublas, - const GqaUnfusedAttentionParams& params, + const UnfusedAttentionParams& params, const T* query, const T* key, float* qk_out) { @@ -292,7 +323,7 @@ common::Status LaunchQkGemmFp32( CUBLAS_GEMM_DEFAULT); if (status != CUBLAS_STATUS_SUCCESS) { - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "GqaUnfusedAttention QK GEMM failed: ", status); + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "UnfusedAttention QK GEMM failed: ", status); } return common::Status::OK(); } @@ -312,7 +343,7 @@ common::Status LaunchQkGemmFp32( template common::Status LaunchAttnVGemm( cublasHandle_t cublas, - const GqaUnfusedAttentionParams& params, + const UnfusedAttentionParams& params, const T* softmax_out, const T* value, T* output) { @@ -347,7 +378,7 @@ common::Status LaunchAttnVGemm( CUBLAS_COMPUTE_32F, CUBLAS_GEMM_DEFAULT); if (status != CUBLAS_STATUS_SUCCESS) { - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "GqaUnfusedAttention AV GEMM failed: ", status); + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "UnfusedAttention AV GEMM failed: ", status); } return common::Status::OK(); } @@ -357,10 +388,10 @@ common::Status LaunchAttnVGemm( // --------------------------------------------------------------------------- // Public API // --------------------------------------------------------------------------- -size_t GetGqaUnfusedAttentionWorkspaceSize(int batch_size, - int num_heads, - int q_sequence_length, - int total_kv_length) { +size_t GetUnfusedAttentionWorkspaceSize(int batch_size, + int num_heads, + int q_sequence_length, + int total_kv_length) { const size_t elems = QkElementCount(batch_size, num_heads, q_sequence_length, total_kv_length); // FP32 QK scratch + T softmax scratch. We always allocate sizeof(float) per // element for the T scratch too (upper bound); caller can cast appropriately. @@ -370,26 +401,27 @@ size_t GetGqaUnfusedAttentionWorkspaceSize(int batch_size, } template -common::Status LaunchGqaUnfusedAttention( +common::Status LaunchUnfusedAttention( const cudaDeviceProp& device_prop, cublasHandle_t cublas, cudaStream_t stream, - const GqaUnfusedAttentionParams& params, + const UnfusedAttentionParams& params, const T* query, const T* key, const T* value, const T* attn_bias, T* output, - void* workspace) { + void* workspace, + T* output_qk) { ORT_RETURN_IF_NOT(params.batch_size > 0 && params.num_heads > 0 && params.kv_num_heads > 0 && params.head_size > 0 && params.v_head_size > 0 && params.q_sequence_length > 0 && params.total_kv_length > 0 && params.max_kv_length >= params.total_kv_length, - "GqaUnfusedAttention: invalid params."); + "UnfusedAttention: invalid params."); ORT_RETURN_IF_NOT(params.num_heads % params.kv_num_heads == 0, - "GqaUnfusedAttention: num_heads (", params.num_heads, + "UnfusedAttention: num_heads (", params.num_heads, ") must be a multiple of kv_num_heads (", params.kv_num_heads, ")."); - ORT_RETURN_IF(workspace == nullptr, "GqaUnfusedAttention: workspace is null."); + ORT_RETURN_IF(workspace == nullptr, "UnfusedAttention: workspace is null."); const size_t elems = QkElementCount(params.batch_size, params.num_heads, params.q_sequence_length, params.total_kv_length); @@ -400,7 +432,21 @@ common::Status LaunchGqaUnfusedAttention( ORT_RETURN_IF_ERROR((LaunchQkGemmFp32(device_prop, cublas, params, query, key, qk_fp32))); - LaunchGqaUnfusedSoftmax(stream, params, qk_fp32, attn_bias, softmax_T); + // Copy scaled QK to output_qk BEFORE softcap/mask/softmax. + // output_qk[i] = T(qk_fp32[i] * scale) — this is "kQK" mode (scale * Q @ K^T). + // Note: When seqlens_k is provided, positions [seqlens_k[b], total_kv) in output_qk + // may contain stale KV cache data. Consumers of output_qk should only read positions + // [0, seqlens_k[b]) for batch b. + if (output_qk != nullptr) { + const int64_t total = static_cast(elems); + constexpr int kTPB = 256; + constexpr int kMaxBlocks = 65535; + const int blocks = static_cast(std::min(static_cast(kMaxBlocks), (total + kTPB - 1) / kTPB)); + ScaledCopyQkKernel<<>>(qk_fp32, output_qk, params.scale, total); + CUDA_RETURN_IF_ERROR(cudaGetLastError()); + } + + LaunchUnfusedSoftmax(stream, params, qk_fp32, attn_bias, softmax_T); CUDA_RETURN_IF_ERROR(cudaGetLastError()); ORT_RETURN_IF_ERROR((LaunchAttnVGemm(cublas, params, softmax_T, value, output))); @@ -409,18 +455,18 @@ common::Status LaunchGqaUnfusedAttention( } // Explicit template instantiations. -template common::Status LaunchGqaUnfusedAttention<__half>( +template common::Status LaunchUnfusedAttention<__half>( const cudaDeviceProp&, cublasHandle_t, cudaStream_t, - const GqaUnfusedAttentionParams&, const __half*, const __half*, const __half*, - const __half*, __half*, void*); -template common::Status LaunchGqaUnfusedAttention<__nv_bfloat16>( + const UnfusedAttentionParams&, const __half*, const __half*, const __half*, + const __half*, __half*, void*, __half*); +template common::Status LaunchUnfusedAttention<__nv_bfloat16>( const cudaDeviceProp&, cublasHandle_t, cudaStream_t, - const GqaUnfusedAttentionParams&, const __nv_bfloat16*, const __nv_bfloat16*, - const __nv_bfloat16*, const __nv_bfloat16*, __nv_bfloat16*, void*); -template common::Status LaunchGqaUnfusedAttention( + const UnfusedAttentionParams&, const __nv_bfloat16*, const __nv_bfloat16*, + const __nv_bfloat16*, const __nv_bfloat16*, __nv_bfloat16*, void*, __nv_bfloat16*); +template common::Status LaunchUnfusedAttention( const cudaDeviceProp&, cublasHandle_t, cudaStream_t, - const GqaUnfusedAttentionParams&, const float*, const float*, const float*, - const float*, float*, void*); + const UnfusedAttentionParams&, const float*, const float*, const float*, + const float*, float*, void*, float*); } // namespace cuda } // namespace contrib diff --git a/onnxruntime/contrib_ops/cuda/bert/gqa_unfused_attention.h b/onnxruntime/contrib_ops/cuda/bert/unfused_attention.h similarity index 77% rename from onnxruntime/contrib_ops/cuda/bert/gqa_unfused_attention.h rename to onnxruntime/contrib_ops/cuda/bert/unfused_attention.h index 84d645cd2b349..8fb3a18ac7570 100644 --- a/onnxruntime/contrib_ops/cuda/bert/gqa_unfused_attention.h +++ b/onnxruntime/contrib_ops/cuda/bert/unfused_attention.h @@ -13,7 +13,7 @@ namespace contrib { namespace cuda { // ============================================================================ -// GQA Unfused Attention (CUDA fallback for large head_size / fp16 overflow) +// Unified Unfused Attention (CUDA fallback for large head_size / fp16 overflow) // ============================================================================ // // Purpose: @@ -38,18 +38,20 @@ namespace cuda { // - scale is applied to raw QK (before softcap / bias). // - softcap (> 0) is applied after scale: x = softcap * tanh(x / softcap). // - attn_bias (if non-null) is added after softcap (additive mask). -// - causal: k > (past + q) is -inf where past = total_kv - S_q. +// - causal: k > (past_kv_length + q) is -inf. +// When past_kv_length=0 (no past), gives upper-left alignment: q_i attends to kv[0..i]. +// When past_kv_length=total_kv-S_q (decode with past), gives lower-right alignment. // - local_window_size (>= 0): k < (past + q) - local_window_size is -inf. // local_window_size == -1 disables the sliding-window mask. // // The new kernel is suitable only as a fallback when Flash / MEA are ineligible -// (head_size > 256, past_key present with mask, GQA with MHA-only unfused, etc). +// (head_size > 256, past_key present with mask, etc). // The QK GEMM runs with CUBLAS_COMPUTE_32F and writes a FP32 scratch to avoid // fp16 overflow. // // ============================================================================ -struct GqaUnfusedAttentionParams { +struct UnfusedAttentionParams { int batch_size = 0; int num_heads = 0; // N_q int kv_num_heads = 0; // N_kv (num_heads % kv_num_heads == 0) @@ -68,6 +70,7 @@ struct GqaUnfusedAttentionParams { bool is_causal = false; int local_window_size = -1; // -1 disables sliding window + int past_kv_length = 0; // number of past KV positions (for causal alignment) float scale = 1.0f; float softcap = 0.0f; // 0 disables @@ -77,27 +80,30 @@ struct GqaUnfusedAttentionParams { }; // Returns required scratch size in bytes. Caller must allocate -// GetGqaUnfusedAttentionWorkspaceSize(...) bytes and pass as workspace. -size_t GetGqaUnfusedAttentionWorkspaceSize(int batch_size, - int num_heads, - int q_sequence_length, - int total_kv_length); +// GetUnfusedAttentionWorkspaceSize(...) bytes and pass as workspace. +size_t GetUnfusedAttentionWorkspaceSize(int batch_size, + int num_heads, + int q_sequence_length, + int total_kv_length); // Compute: Y = softmax(scale * Q * K^T [softcap, causal, window, bias, seqlens_k]) * V. // All pointers are on device. Q/K/V/output are in type T (fp16/bf16/float). // attn_bias (if present) is in type T. +// output_qk (optional): when non-null, writes scale * Q @ K^T (FP32→T) before softcap/mask/softmax. +// Shape: [B, N_q, S_q, total_kv]. Caller allocates. template -common::Status LaunchGqaUnfusedAttention( +common::Status LaunchUnfusedAttention( const cudaDeviceProp& device_prop, cublasHandle_t cublas, cudaStream_t stream, - const GqaUnfusedAttentionParams& params, + const UnfusedAttentionParams& params, const T* query, const T* key, const T* value, const T* attn_bias, T* output, - void* workspace); + void* workspace, + T* output_qk = nullptr); } // namespace cuda } // namespace contrib diff --git a/onnxruntime/core/providers/cuda/llm/attention.cc b/onnxruntime/core/providers/cuda/llm/attention.cc index 228729745b65b..00ce18c65efd8 100644 --- a/onnxruntime/core/providers/cuda/llm/attention.cc +++ b/onnxruntime/core/providers/cuda/llm/attention.cc @@ -1,17 +1,20 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +#include + #include "core/common/safeint.h" #include "core/providers/cuda/cuda_common.h" #include "core/providers/cpu/llm/attention.h" #include "core/providers/cpu/llm/attention_helper.h" #include "core/providers/cuda/llm/attention.h" #include "core/providers/cuda/llm/attention_mask_impl.h" -#include "contrib_ops/cuda/bert/attention_data.h" +// attention_impl.h provides Transpose_BNSH_to_BSNH / Transpose_BSNH_to_BNSH used +// by the transpose helpers. #include "contrib_ops/cuda/bert/attention_impl.h" #include "contrib_ops/cuda/bert/attention_kv_cache.h" #include "contrib_ops/cuda/bert/group_query_attention_impl.h" -#include "contrib_ops/cuda/bert/gqa_unfused_attention.h" +#include "contrib_ops/cuda/bert/unfused_attention.h" #include "contrib_ops/cuda/bert/cutlass_fmha/memory_efficient_attention.h" #include "contrib_ops/cuda/bert/flash_attention/flash_api.h" #include "core/providers/cuda/cuda_type_conversion.h" @@ -155,7 +158,12 @@ Status Attention::ConvertAttnMaskToBias( int64_t num_elements = attn_mask->Shape().Size(); converted_mask_buffer = GetScratchBuffer( num_elements * sizeof(NativeCudaT), GetComputeStream(context)); - float mask_filter_value = static_cast(std::numeric_limits::lowest()); + // CUTLASS online softmax multiplies attention scores by kLog2e (≈1.4427). + // For float/bf16, |lowest() × kLog2e| > FLT_MAX, overflowing to -inf and + // causing s_prime=0 → NaN for fully-masked batches. Cap to prevent this. + // See kCutlassSafeMaskFilterValue in memory_efficient_attention.h for details. + float mask_filter_value = std::max(static_cast(std::numeric_limits::lowest()), + ::onnxruntime::contrib::cuda::kCutlassSafeMaskFilterValue); ORT_RETURN_IF_ERROR(LaunchConvertBoolMaskToAttentionBias( attn_mask->Data(), reinterpret_cast(converted_mask_buffer.get()), @@ -189,7 +197,7 @@ Status Attention::ConvertAttnMaskToBias( // Path 1: nonpad_kv_seqlen (opset 24 external cache) -> mha_fwd_kvcache // Path 2: past_key + past_value (internal cache decode) -> mha_fwd_kvcache // - No mask support (attn_mask rejected at eligibility) -// - 4D BNSH: transposes Q/K/V to BSNH before kernel +// - 4D BNSH: transposes Q to BSNH; new K/V to BSNH for concat (cache stays BNSH) // Path 3: no past, no mask (prompt) -> mha_fwd // Eligibility: fp16/bf16, head_size==v_head_size, no output_qk, attn_mask==nullptr // Note: softcap is passed to the Flash kernel natively. softmax_precision is @@ -334,10 +342,10 @@ Status Attention::RunFlashAttention( ORT_ENFORCE(present_key != nullptr && present_value != nullptr, "present_key/value outputs are required when past_key is provided."); - // TODO(titaiwang): Consolidate preprocessing (RoPE, mask conversion, KV cache concat) into a + // TODO(titaiwang): Consolidate preprocessing (transpose, KV cache concat) into a // single fused kernel like GQA's LaunchUnpackRoPEAppend. Current decode path uses 4-6 kernel - // launches; a fused approach would reduce to ~2, saving ~21μs launch overhead and ~256KB - // intermediate buffer traffic per decode step. + // launches; a fused approach would reduce to ~2, saving launch overhead and intermediate + // buffer traffic per decode step. // Concat past + new KV directly into present buffers using a single fused kernel. // This replaces the old pattern of memset + strided cudaMemcpy2DAsync + Flash's @@ -476,7 +484,7 @@ Status Attention::RunFlashAttention( cuda_stream, device_prop.maxThreadsPerBlock)); } - // --- Populate present_key/value (BNSH) from K/V (BSNH) --- + // --- Populate present_key/value (BNSH) from K/V (BSNH or BNSH) --- // Skip for decode path where mha_fwd_kvcache already populated present buffers. if (!present_kv_already_populated) { if (present_key != nullptr && is_bsnh) { @@ -528,13 +536,15 @@ Status Attention::RunFlashAttention( // ============================================================================ // // Memory Efficient Attention (cutlass FMHA) dispatch paths: -// Path 1: nonpad_kv_seqlen (opset 24 external cache) -> has_custom_right_padding mode -// Path 2: no past, with mask (prompt) -> standard MEA with additive bias -// Path 3: no past, no mask (prompt) -> standard MEA +// Path 1: Decode with past KV cache -> LaunchConcatNewToPastKV then standard MEA +// Path 2: nonpad_kv_seqlen (opset 24 external cache) -> has_custom_right_padding mode +// Path 3: Prompt with mask -> standard MEA with additive bias +// Path 4: Prompt without mask -> standard MEA // Eligibility: see has_memory_efficient_attention() (SM50+/53+/80+ by dtype, -// head_size <= 1024), plus: no output_qk, no past_key (decode excluded), -// bias stride alignment. -// Note: softcap is forwarded to the MEA kernel via p.softcap. softmax_precision +// head_size <= 1024, head_size divisible by 8), plus: no output_qk, bias stride alignment. +// Note: softcap is forwarded to the MEA kernel via p.softcap. CUTLASS applies +// softcap before bias (fused in kernel tiles), matching ONNX spec ordering +// (onnx/onnx#7865): QK → softcap → mask/bias → softmax. softmax_precision // is inherently satisfied (cutlass FMHA accumulates softmax in FP32). // template @@ -546,8 +556,6 @@ Status Attention::RunMemoryEfficientAttention( Tensor* Y, Tensor* present_key, Tensor* present_value, const attention_helper::AttentionParameters& parameters) const { #if USE_MEMORY_EFFICIENT_ATTENTION - ORT_UNUSED_PARAMETER(past_key); - ORT_UNUSED_PARAMETER(past_value); auto& device_prop = GetDeviceProp(); auto cuda_stream = Stream(context); const bool is_bsnh = parameters.transpose_output; @@ -582,6 +590,120 @@ Status Attention::RunMemoryEfficientAttention( out_data = out_bsnh_buffer.get(); } + bool present_kv_already_populated = false; + // Track the effective layout of k_data/v_data. Initially matches input layout, + // but changes to BNSH (false) after decode concat into present buffers. + bool kv_is_bsnh = is_bsnh; + + // Scratch buffers for decode concat output when present_key/value are optional. + // Declared at function scope so they outlive the decode block (k_data/v_data may point here). + IAllocatorUniquePtr present_k_scratch; + IAllocatorUniquePtr present_v_scratch; + + // --- Decode path: concat past + new K/V → present buffers (BNSH) --- + // nonpad_kv_seqlen and past_key are mutually exclusive (enforced at validation), + // so the decode path only needs the internal-cache (past_key/present_key) flow. + if (past_key != nullptr) { + ORT_RETURN_IF_NOT(past_value != nullptr, "past_key requires past_value."); + ORT_RETURN_IF_NOT(nonpad_kv_seqlen == nullptr, + "nonpad_kv_seqlen and past_key are mutually exclusive (internal vs external cache)."); + // This mirrors the eligibility check in ComputeInternal — must stay in sync. + ORT_RETURN_IF_NOT(parameters.head_size == parameters.v_head_size, + "MEA decode (past_key) requires head_size == v_head_size for LaunchConcatNewToPastKV."); + + using NativeCudaT = typename OrtToCudaType::type; + + // Allocate scratch buffers for concat output when present_key/value are not requested. + // The concat kernel needs a destination buffer regardless of whether the caller wants present outputs. + T* present_k_data = nullptr; + T* present_v_data = nullptr; + + SafeInt present_k_bytes = SafeInt(parameters.batch_size) * parameters.kv_num_heads * + parameters.total_sequence_length * parameters.head_size * sizeof(T); + SafeInt present_v_bytes = SafeInt(parameters.batch_size) * parameters.kv_num_heads * + parameters.total_sequence_length * parameters.v_head_size * sizeof(T); + + if (present_key != nullptr) { + present_k_data = present_key->MutableData(); + } else { + present_k_scratch = GetScratchBuffer(present_k_bytes, GetComputeStream(context)); + present_k_data = static_cast(present_k_scratch.get()); + } + if (present_value != nullptr) { + present_v_data = present_value->MutableData(); + } else { + present_v_scratch = GetScratchBuffer(present_v_bytes, GetComputeStream(context)); + present_v_data = static_cast(present_v_scratch.get()); + } + + // Step 1: Uniform past sequence lengths for the concat kernel. + // ONNX past_key has shape [B, H, past_seq, head_size] — all batches share + // the same past_seq dimension. Bool masks do NOT change where tokens are stored; + // they change which tokens are attended to (via additive bias, handled below). + auto past_seqlens_buffer = GetScratchBuffer(parameters.batch_size, GetComputeStream(context)); + ORT_RETURN_IF_ERROR(LaunchFillInt32(past_seqlens_buffer.get(), parameters.past_sequence_length, + parameters.batch_size, cuda_stream, + device_prop.maxThreadsPerBlock)); + + // Step 2: Transpose K/V to BSNH if input is 4D BNSH (concat kernel reads new as BSNH). + const T* k_new_bsnh = K->Data(); + const T* v_new_bsnh = V->Data(); + IAllocatorUniquePtr k_bsnh_buffer; + IAllocatorUniquePtr v_bsnh_buffer; + if (!is_bsnh) { + size_t k_bytes = sizeof(T) * parameters.batch_size * parameters.kv_sequence_length * + parameters.kv_num_heads * parameters.head_size; + size_t v_bytes = sizeof(T) * parameters.batch_size * parameters.kv_sequence_length * + parameters.kv_num_heads * parameters.v_head_size; + k_bsnh_buffer = GetScratchBuffer(k_bytes, GetComputeStream(context)); + v_bsnh_buffer = GetScratchBuffer(v_bytes, GetComputeStream(context)); + ORT_RETURN_IF_ERROR(TransposeBNSHtoBSNH( + parameters.batch_size, parameters.kv_sequence_length, + parameters.kv_num_heads, parameters.head_size, + K->Data(), k_bsnh_buffer.get(), + cuda_stream, device_prop.maxThreadsPerBlock)); + ORT_RETURN_IF_ERROR(TransposeBNSHtoBSNH( + parameters.batch_size, parameters.kv_sequence_length, + parameters.kv_num_heads, parameters.v_head_size, + V->Data(), v_bsnh_buffer.get(), + cuda_stream, device_prop.maxThreadsPerBlock)); + k_new_bsnh = static_cast(k_bsnh_buffer.get()); + v_new_bsnh = static_cast(v_bsnh_buffer.get()); + } + + // Step 3: Fused concat: past_key + new_key → present_key (and same for values). + // One kernel copies past data from [0, past_seq) and new data from BSNH layout + // into present buffer at [past_seq, past_seq + kv_seq), all in BNSH. + // No memset needed: uniform past_seq_lens means every position in the present + // buffer is written by the concat kernel. Padding positions in past_key are copied + // as-is; the attention mask (additive bias) handles correctness at the attention level. + ORT_RETURN_IF_ERROR(onnxruntime::contrib::cuda::LaunchConcatNewToPastKV( + parameters.batch_size, + parameters.kv_num_heads, + parameters.head_size, + parameters.kv_sequence_length, + parameters.past_sequence_length, + parameters.total_sequence_length, + /*is_bsnh=*/false, + past_seqlens_buffer.get(), + /*total_seq_lens=*/nullptr, + reinterpret_cast(past_key->Data()), + reinterpret_cast(past_value->Data()), + reinterpret_cast(k_new_bsnh), + reinterpret_cast(v_new_bsnh), + reinterpret_cast(present_k_data), + reinterpret_cast(present_v_data), + cuda_stream, + device_prop.maxThreadsPerBlock, + /*past_only=*/false)); + + // Point MEA's K/V inputs at the concatenated buffers (BNSH). + k_data = present_k_data; + v_data = present_v_data; + kv_is_bsnh = false; + present_kv_already_populated = true; + } + // GQA head expansion: MEA requires matching num_heads for Q/K/V. // When q_num_heads != kv_num_heads, expand K/V via LaunchUngroup. const bool is_gqa = parameters.q_num_heads != parameters.kv_num_heads; @@ -622,7 +744,7 @@ Status Attention::RunMemoryEfficientAttention( reinterpret_cast(v_data), parameters.total_sequence_length, parameters.total_sequence_length, - is_bsnh, + kv_is_bsnh, cuda_stream, device_prop.maxThreadsPerBlock)); @@ -631,8 +753,8 @@ Status Attention::RunMemoryEfficientAttention( } } - // Note: MEA with past_key/value is handled by the unfused fallback. - // The cascade in ComputeInternal ensures past_key == nullptr when we reach here. + // Note: When past_key is present (decode), k_data/v_data already point to present + // buffers (BNSH) after LaunchConcatNewToPastKV above, so MEA sees the full cache. // Handle attention mask → attention_bias conversion IAllocatorUniquePtr converted_mask_buffer; @@ -642,7 +764,8 @@ Status Attention::RunMemoryEfficientAttention( if (nonpad_kv_seqlen != nullptr) { // Convert nonpad_kv_seqlen to seqlens_k for custom right padding. - // MEA expects actual token count (not count-1), so use FlashSeqlensK variant. + // MEA expects seqlens_k as actual token count, so use FlashSeqlensK variant + // (which converts int64→int32 without subtracting 1). auto seqlens_k_buffer = GetScratchBuffer(parameters.batch_size, GetComputeStream(context)); ORT_RETURN_IF_ERROR(LaunchConvertNonpadKvSeqlenToFlashSeqlensK( nonpad_kv_seqlen->Data(), @@ -665,7 +788,7 @@ Status Attention::RunMemoryEfficientAttention( p.sm = sm; p.is_half = std::is_same::value; p.is_bf16 = std::is_same::value; - p.is_kv_bsnh = is_bsnh; + p.is_kv_bsnh = kv_is_bsnh; p.batch_size = parameters.batch_size; p.num_heads = parameters.q_num_heads; p.sequence_length = parameters.q_sequence_length; @@ -674,6 +797,15 @@ Status Attention::RunMemoryEfficientAttention( p.qk_head_size = parameters.head_size; p.v_head_size = parameters.v_head_size; p.causal = parameters.is_causal; + // ONNX spec: is_causal means upper-left alignment in the full attention matrix. + // When past_sequence_length == 0 and S_q != S_kv (cross-attention without KV cache), + // queries start at absolute position 0, so causal mask is upper-left. + // When past_sequence_length > 0 (decode with KV cache), queries start at position + // past_seq, so causal mask is effectively lower-right on the [S_q x total_kv] sub-matrix. + // NOTE: For external KV cache (TensorScatter), nonpad_kv_seqlen provides per-batch + // actual lengths and seqlens_k handles the masking — the causal_from_top_left flag + // is only consulted when params.causal is true, so it's correct here. + p.causal_from_top_left = (parameters.past_sequence_length == 0); p.scale = parameters.scale; p.softcap = parameters.softcap; p.seqlen_k_ptr = seqlens_k_buffer.get(); @@ -700,8 +832,12 @@ Status Attention::RunMemoryEfficientAttention( onnxruntime::contrib::cuda::run_memory_efficient_attention(p); // On the MEA (CUTLASS) path (used for both MHA and GQA when nonpad_kv_seqlen is provided), - // zero out output for fully-masked batches to produce zeros (matching Flash behavior). + // zero out output for fully-masked batches to prevent NaN. // CUTLASS epilogue computes 1/s_prime where s_prime=0 for seqlens_k=0, producing NaN. + // TODO(titaiwang): ZeroOutputForFullyMaskedBatches outputs zeros for fully-masked + // batches (seqlens_k=0), which diverges from CPU/Unfused behavior (uniform mean of V). + // For cross-EP consistency, replace with LaunchMeanOfVForFullyMaskedBatches that + // computes mean(V[b,n,:,h]) for each masked batch. See issue #27516. { using CudaT = typename onnxruntime::cuda::OrtToCudaType::type; int64_t elements_per_batch = static_cast(parameters.q_sequence_length) * @@ -716,9 +852,10 @@ Status Attention::RunMemoryEfficientAttention( } } // Standard MEA path: float attention bias, bool mask (converted to bias), or no mask. - // Bool masks are converted to additive attention bias (true→0, false→mask_filter_value) - // which correctly handles all-false masks (uniform softmax weights) unlike the - // custom_right_padding seqlens approach which would produce NaN. + // Bool masks are converted to additive attention bias (true→0, false→mask_filter_value). + // For fully-masked batches (all-false bool mask), ConvertAttnMaskToBias uses a capped + // mask_filter_value (-1e+30) that stays finite through CUTLASS's kLog2e multiplication, + // producing correct uniform softmax → mean(V) output. else { if (attn_mask != nullptr) { ORT_RETURN_IF_ERROR(ConvertAttnMaskToBias(context, attn_mask, cuda_stream, @@ -731,7 +868,7 @@ Status Attention::RunMemoryEfficientAttention( p.sm = sm; p.is_half = std::is_same::value; p.is_bf16 = std::is_same::value; - p.is_kv_bsnh = is_bsnh; + p.is_kv_bsnh = kv_is_bsnh; p.batch_size = parameters.batch_size; p.num_heads = parameters.q_num_heads; p.sequence_length = parameters.q_sequence_length; @@ -740,6 +877,8 @@ Status Attention::RunMemoryEfficientAttention( p.qk_head_size = parameters.head_size; p.v_head_size = parameters.v_head_size; p.causal = parameters.is_causal; + // Causal alignment: same logic as above — upper-left when no past. + p.causal_from_top_left = (parameters.past_sequence_length == 0); p.scale = parameters.scale; p.softcap = parameters.softcap; p.broadcast_attn_bias_dim_0 = broadcast_bias_dim_0; @@ -773,30 +912,33 @@ Status Attention::RunMemoryEfficientAttention( cuda_stream, device_prop.maxThreadsPerBlock)); } - // Populate present_key/present_value (BNSH) if requested - if (present_key != nullptr && is_bsnh) { - ORT_RETURN_IF_ERROR(TransposeBSNHtoBNSH( - parameters.batch_size, parameters.kv_sequence_length, - parameters.kv_num_heads, parameters.head_size, - K->Data(), present_key->MutableData(), - cuda_stream, device_prop.maxThreadsPerBlock)); - } else if (present_key != nullptr && !is_bsnh) { - // 4D BNSH prompt: K is already BNSH, just D2D copy to present - CUDA_RETURN_IF_ERROR(cudaMemcpyAsync( - present_key->MutableData(), K->Data(), - K->SizeInBytes(), cudaMemcpyDeviceToDevice, cuda_stream)); - } - if (present_value != nullptr && is_bsnh) { - ORT_RETURN_IF_ERROR(TransposeBSNHtoBNSH( - parameters.batch_size, parameters.kv_sequence_length, - parameters.kv_num_heads, parameters.v_head_size, - V->Data(), present_value->MutableData(), - cuda_stream, device_prop.maxThreadsPerBlock)); - } else if (present_value != nullptr && !is_bsnh) { - // 4D BNSH prompt: V is already BNSH, just D2D copy to present - CUDA_RETURN_IF_ERROR(cudaMemcpyAsync( - present_value->MutableData(), V->Data(), - V->SizeInBytes(), cudaMemcpyDeviceToDevice, cuda_stream)); + // Populate present_key/present_value (BNSH) if requested. + // Skip for decode path where LaunchConcatNewToPastKV already populated present buffers. + if (!present_kv_already_populated) { + if (present_key != nullptr && is_bsnh) { + ORT_RETURN_IF_ERROR(TransposeBSNHtoBNSH( + parameters.batch_size, parameters.kv_sequence_length, + parameters.kv_num_heads, parameters.head_size, + K->Data(), present_key->MutableData(), + cuda_stream, device_prop.maxThreadsPerBlock)); + } else if (present_key != nullptr && !is_bsnh) { + // 4D BNSH prompt: K is already BNSH, just D2D copy to present + CUDA_RETURN_IF_ERROR(cudaMemcpyAsync( + present_key->MutableData(), K->Data(), + K->SizeInBytes(), cudaMemcpyDeviceToDevice, cuda_stream)); + } + if (present_value != nullptr && is_bsnh) { + ORT_RETURN_IF_ERROR(TransposeBSNHtoBNSH( + parameters.batch_size, parameters.kv_sequence_length, + parameters.kv_num_heads, parameters.v_head_size, + V->Data(), present_value->MutableData(), + cuda_stream, device_prop.maxThreadsPerBlock)); + } else if (present_value != nullptr && !is_bsnh) { + // 4D BNSH prompt: V is already BNSH, just D2D copy to present + CUDA_RETURN_IF_ERROR(cudaMemcpyAsync( + present_value->MutableData(), V->Data(), + V->SizeInBytes(), cudaMemcpyDeviceToDevice, cuda_stream)); + } } return Status::OK(); @@ -819,250 +961,30 @@ Status Attention::RunMemoryEfficientAttention( } // ============================================================================ -// RunUnfusedAttention: Delegates to MHA's QkvToContext (unfused GEMM+softmax+GEMM) -// ============================================================================ -// -// Unfused Attention dispatch paths: -// Universal fallback via MHA's QkvToContext. -// Path 1: nonpad_kv_seqlen only -> converts to attention_bias [B, q_seq, total_seq] -// Path 2: nonpad_kv_seqlen + attn_mask -> composes both into attention_bias [B, q_seq, total_seq] -// (nonpad bias + mask bias added element-wise with cyclic broadcasting) -// Path 3: all other cases -> passes mask/bias directly -// Supports: all dtypes (fp16/bf16/fp32), all mask types (bool/float/none), all head sizes -// Not supported: softcap (rejected at fallback), output_qk modes beyond kNone/kQK -// Limitation: MHA only (q_num_heads must equal kv_num_heads) -// -template -Status Attention::RunUnfusedAttention( - OpKernelContext* context, - const Tensor* Q, const Tensor* K, const Tensor* V, - const Tensor* attn_mask, const Tensor* past_key, const Tensor* past_value, - const Tensor* nonpad_kv_seqlen, - Tensor* Y, Tensor* present_key, Tensor* present_value, - Tensor* output_qk, - const attention_helper::AttentionParameters& parameters) const { - using CudaT = typename ToCudaType::MappedType; - // OrtToCudaType maps BFloat16 → __nv_bfloat16 (native HW type), matching kernel instantiations. - using NativeCudaT = typename onnxruntime::cuda::OrtToCudaType::type; - auto& device_prop = GetDeviceProp(); - auto cuda_stream = Stream(context); - auto ort_stream = GetOrtStream(context); - - // Bridge to contrib::AttentionParameters for the MHA unfused path - onnxruntime::contrib::AttentionParameters contribop_parameters; - - if (!parameters.transpose_output) { - contribop_parameters.qkv_format = onnxruntime::contrib::AttentionQkvFormat::Q_K_V_BNSH; - contribop_parameters.is_output_bnsh = true; - } else { - contribop_parameters.qkv_format = onnxruntime::contrib::AttentionQkvFormat::Q_K_V_BSNH; - contribop_parameters.is_output_bnsh = false; - } - - contribop_parameters.batch_size = parameters.batch_size; - contribop_parameters.sequence_length = parameters.q_sequence_length; - contribop_parameters.kv_sequence_length = parameters.kv_sequence_length; - contribop_parameters.past_sequence_length = parameters.past_sequence_length; - contribop_parameters.total_sequence_length = parameters.total_sequence_length; - contribop_parameters.max_sequence_length = parameters.total_sequence_length; - contribop_parameters.input_hidden_size = 0; - contribop_parameters.hidden_size = parameters.q_num_heads * parameters.head_size; - contribop_parameters.head_size = parameters.head_size; - contribop_parameters.v_head_size = parameters.v_head_size; - contribop_parameters.v_hidden_size = parameters.kv_num_heads * parameters.v_head_size; - contribop_parameters.num_heads = parameters.q_num_heads; - contribop_parameters.rotary_dim = 0; - contribop_parameters.num_splits = 1; - contribop_parameters.beam_width = 1; - contribop_parameters.is_unidirectional = parameters.is_causal; - contribop_parameters.past_present_share_buffer = false; - contribop_parameters.is_packed_qkv = false; - contribop_parameters.do_rotary = false; - contribop_parameters.mask_type = onnxruntime::contrib::AttentionMaskType::MASK_NONE; - contribop_parameters.mask_filter_value = static_cast(std::numeric_limits::lowest()); - contribop_parameters.scale = parameters.scale; - contribop_parameters.use_tf32 = UseTF32(); - - // Determine broadcast flags for attention_bias - if (attn_mask != nullptr) { - size_t attn_mask_dims_size = attn_mask->Shape().NumDimensions(); - auto attn_mask_dims = attn_mask->Shape().GetDims(); - if (attn_mask_dims_size == 2) { - contribop_parameters.broadcast_attn_bias_dim_0 = true; - contribop_parameters.broadcast_attn_bias_dim_1 = true; - } else if (attn_mask_dims_size == 3) { - contribop_parameters.broadcast_attn_bias_dim_0 = true; - contribop_parameters.broadcast_attn_bias_dim_1 = attn_mask_dims[0] == 1; - } else { - contribop_parameters.broadcast_attn_bias_dim_0 = attn_mask_dims[0] == 1; - contribop_parameters.broadcast_attn_bias_dim_1 = attn_mask_dims[1] == 1; - } - } else { - contribop_parameters.broadcast_attn_bias_dim_0 = false; - contribop_parameters.broadcast_attn_bias_dim_1 = false; - } - - // Construct AttentionData - onnxruntime::contrib::cuda::AttentionData data; - data.query = reinterpret_cast(Q->Data()); - data.key = reinterpret_cast(K->Data()); - data.value = reinterpret_cast(V->Data()); - data.mask_index = nullptr; - data.mask_index_dims = gsl::span(); - data.past_key = (past_key == nullptr) ? nullptr : reinterpret_cast(past_key->Data()); - data.past_value = (past_value == nullptr) ? nullptr : reinterpret_cast(past_value->Data()); - data.output = reinterpret_cast(Y->MutableData()); - data.present_key = (present_key == nullptr) ? nullptr : reinterpret_cast(present_key->MutableData()); - data.present_value = (present_value == nullptr) ? nullptr : reinterpret_cast(present_value->MutableData()); - if (output_qk != nullptr) { - data.output_qk = reinterpret_cast(output_qk->MutableData()); - } - data.bias = nullptr; - - // Handle attention mask / nonpad_kv_seqlen → attention_bias - IAllocatorUniquePtr converted_mask_buffer; - IAllocatorUniquePtr mask_bias_buffer; // temp buffer for mask→bias when composing - if (nonpad_kv_seqlen != nullptr) { - // Convert nonpad_kv_seqlen to additive attention bias: [B, q_seq, total_seq] - int64_t bias_elements = static_cast(parameters.batch_size) * - parameters.q_sequence_length * - parameters.total_sequence_length; - converted_mask_buffer = GetScratchBuffer(bias_elements * sizeof(NativeCudaT), GetComputeStream(context)); - ORT_RETURN_IF_ERROR(LaunchConvertNonpadKvSeqlenToAttentionBias( - nonpad_kv_seqlen->Data(), - reinterpret_cast(converted_mask_buffer.get()), - parameters.batch_size, - parameters.q_sequence_length, - parameters.total_sequence_length, - contribop_parameters.mask_filter_value, - cuda_stream, - device_prop.maxThreadsPerBlock)); - - // When attn_mask is also present, compose it into the nonpad bias additively. - // The nonpad bias is [B, q, t]; the mask is added with cyclic broadcasting - // (e.g. a 2D [q, t] mask repeats over the batch dimension). - // Only 2D masks and 4D masks with head_dim=1 are supported — per-head masks - // (3D [H,q,t] or 4D [B,H>1,q,t]) cannot be composed into a [B,q,t] buffer. - if (attn_mask != nullptr) { - const auto& mask_shape = attn_mask->Shape(); - int mask_dims = static_cast(mask_shape.NumDimensions()); - ORT_ENFORCE(mask_dims == 2 || (mask_dims == 4 && mask_shape[1] == 1), - "nonpad_kv_seqlen + attn_mask composition in unfused path only supports " - "2D masks [q, t] and 4D masks with head_dim=1 [B, 1, q, t]. " - "Got mask shape: ", - mask_shape); - - int64_t mask_elements = mask_shape.Size(); - const NativeCudaT* mask_bias_ptr = nullptr; - - if (attn_mask->IsDataType()) { - // Convert bool mask to additive bias in a temp buffer, then add in-place. - mask_bias_buffer = GetScratchBuffer(mask_elements * sizeof(NativeCudaT), GetComputeStream(context)); - ORT_RETURN_IF_ERROR(LaunchConvertBoolMaskToAttentionBias( - attn_mask->Data(), - reinterpret_cast(mask_bias_buffer.get()), - mask_elements, - contribop_parameters.mask_filter_value, - cuda_stream, - device_prop.maxThreadsPerBlock)); - mask_bias_ptr = reinterpret_cast(mask_bias_buffer.get()); - } else { - // Float mask is already in additive bias format. - mask_bias_ptr = reinterpret_cast(attn_mask->Data()); - } - - // Add mask bias into nonpad bias with cyclic broadcasting. - // 2D mask [q, t]: mask_elements = q*t, repeats for each batch → correct. - // 4D mask [B, 1, q, t]: mask_elements = B*q*t = bias_elements → direct add. - ORT_RETURN_IF_ERROR(LaunchAddBiasInPlace( - reinterpret_cast(converted_mask_buffer.get()), - mask_bias_ptr, - bias_elements, - mask_elements, - cuda_stream, - device_prop.maxThreadsPerBlock)); - } - - data.attention_bias = reinterpret_cast(converted_mask_buffer.get()); - // Composed bias is [B, q_seq, total_seq] → broadcasts over heads but not batch. - contribop_parameters.broadcast_attn_bias_dim_0 = false; - contribop_parameters.broadcast_attn_bias_dim_1 = true; - } else if (attn_mask != nullptr) { - if (attn_mask->IsDataType()) { - int64_t num_elements = attn_mask->Shape().Size(); - converted_mask_buffer = GetScratchBuffer(num_elements * sizeof(NativeCudaT), GetComputeStream(context)); - ORT_RETURN_IF_ERROR(LaunchConvertBoolMaskToAttentionBias( - attn_mask->Data(), - reinterpret_cast(converted_mask_buffer.get()), - num_elements, - contribop_parameters.mask_filter_value, - cuda_stream, - device_prop.maxThreadsPerBlock)); - data.attention_bias = reinterpret_cast(converted_mask_buffer.get()); - } else { - data.attention_bias = reinterpret_cast(attn_mask->Data()); - } - } - - data.qkv_format = contribop_parameters.qkv_format; - data.use_flash_attention = false; - data.use_memory_efficient_attention = false; - data.fused_runner = nullptr; - data.fused_cross_attention_kernel = nullptr; - data.kernel_type = onnxruntime::contrib::AttentionKernelType::AttentionKernel_Unfused; - - // Allocate workspace - const bool no_qkv_workspace = onnxruntime::contrib::cuda::NoQkvWorkspace(contribop_parameters, data); - size_t workspace_bytes = onnxruntime::contrib::cuda::GetAttentionWorkspaceSize( - sizeof(T), - contribop_parameters.batch_size, - contribop_parameters.num_heads, - contribop_parameters.head_size, - contribop_parameters.v_head_size, - contribop_parameters.sequence_length, - contribop_parameters.kv_sequence_length, - contribop_parameters.total_sequence_length, - nullptr, false, false, false, false, false, - no_qkv_workspace); - auto work_space = GetScratchBuffer(workspace_bytes, GetComputeStream(context)); - - data.has_qkv_workspace = !no_qkv_workspace; - data.workspace = reinterpret_cast(work_space.get()); - data.workspace_bytes = workspace_bytes; - - cublasHandle_t cublas = GetCublasHandle(context); - cudnnHandle_t cudnn = GetCudnnHandle(context); - - // Note: unfused attention produces valid finite output (mean-of-V via uniform softmax) - // for fully-masked batches, so ZeroOutput is not needed here. Only MEA requires - // ZeroOutput to prevent NaN from the CUTLASS epilogue's 1/s_prime division. - return onnxruntime::contrib::cuda::QkvToContext( - device_prop, cublas, cudnn, ort_stream.get(), contribop_parameters, data); -} - -// ============================================================================ -// RunGqaUnfusedAttention: GQA-capable unfused path + large-head fp16/bf16 fix +// RunUnfusedAttention: Unified unfused path for both MHA and GQA // ============================================================================ // -// Routes to LaunchGqaUnfusedAttention from contrib_ops/cuda/bert/gqa_unfused_attention.h. +// Routes to LaunchUnfusedAttention from contrib_ops/cuda/bert/unfused_attention.h. // // Handles: +// - MHA as a degenerate case (group_size=1, no head expansion needed). // - GQA natively (no K/V head replication; reshape-Q trick inside kernel). // - fp16/bf16 with large head_size via FP32 QK scratch (fixes issue #28195: // unfused attention producing NaN when head_dim > 256 at scale=1.0). // - Different Q/K sequence lengths, past_key+past_value, nonpad_kv_seqlen. // - attn_mask (bool/float, 2D/3D/4D), causal, softcap. // -// Not supported here (caller rejects upstream): -// - output_qk: only MHA unfused emits QK, so this path requires output_qk==nullptr. +// Not supported (returns NOT_IMPLEMENTED upstream): +// - qk_matmul_output_mode beyond kNone/kQK (kQKMask, kQKSoftCap, kQKSoftMax). // ============================================================================ template -Status Attention::RunGqaUnfusedAttention( +Status Attention::RunUnfusedAttention( OpKernelContext* context, const Tensor* Q, const Tensor* K, const Tensor* V, const Tensor* attn_mask, const Tensor* past_key, const Tensor* past_value, const Tensor* nonpad_kv_seqlen, Tensor* Y, Tensor* present_key, Tensor* present_value, + Tensor* output_qk, const attention_helper::AttentionParameters& parameters) const { using NativeCudaT = typename onnxruntime::cuda::OrtToCudaType::type; auto& device_prop = GetDeviceProp(); @@ -1108,9 +1030,6 @@ Status Attention::RunGqaUnfusedAttention( ORT_ENFORCE(past_value != nullptr, "past_key requires past_value."); ORT_ENFORCE(present_key != nullptr && present_value != nullptr, "present_key/value outputs are required when past_key is provided."); - // LaunchConcatNewToPastKV uses a single head_size for both K and V caches. - ORT_RETURN_IF(H != H_v, - "RunGqaUnfusedAttention: past_key with H != H_v not supported"); auto past_seqlens_buffer = GetScratchBuffer(B, GetComputeStream(context)); ORT_RETURN_IF_ERROR(LaunchFillInt32(past_seqlens_buffer.get(), parameters.past_sequence_length, B, @@ -1134,17 +1053,51 @@ Status Attention::RunGqaUnfusedAttention( v_new_bsnh = static_cast(v_bnsh_buffer.get()); } - ORT_RETURN_IF_ERROR(onnxruntime::contrib::cuda::LaunchConcatNewToPastKV( - B, N_kv, H, parameters.kv_sequence_length, parameters.past_sequence_length, total_kv, - /*is_bsnh=*/false, - past_seqlens_buffer.get(), /*total_seq_lens=*/nullptr, - reinterpret_cast(past_key->Data()), - reinterpret_cast(past_value->Data()), - reinterpret_cast(k_new_bsnh), - reinterpret_cast(v_new_bsnh), - reinterpret_cast(present_key->MutableData()), - reinterpret_cast(present_value->MutableData()), - cuda_stream, max_threads, /*past_only=*/false)); + if (H == H_v) { + // K and V have the same head_size -- single concat call handles both. + ORT_RETURN_IF_ERROR(onnxruntime::contrib::cuda::LaunchConcatNewToPastKV( + B, N_kv, H, parameters.kv_sequence_length, parameters.past_sequence_length, total_kv, + /*is_bsnh=*/false, + past_seqlens_buffer.get(), /*total_seq_lens=*/nullptr, + reinterpret_cast(past_key->Data()), + reinterpret_cast(past_value->Data()), + reinterpret_cast(k_new_bsnh), + reinterpret_cast(v_new_bsnh), + reinterpret_cast(present_key->MutableData()), + reinterpret_cast(present_value->MutableData()), + cuda_stream, max_threads, /*past_only=*/false)); + } else { + // H != H_v: LaunchConcatNewToPastKV uses a single head_size for both K and V + // (grid Z=0 for K, Z=1 for V with the same block dims). We must call it + // twice with different head_size values -- once for K (head_size=H) and once + // for V (head_size=H_v). Each call duplicates K data into V params (or vice + // versa) so both Z indices write to the same buffer harmlessly. + // + // Trade-off: each call does 2× GPU work (both Z slices execute). This is + // acceptable because H!=H_v decode through MEA is rare, and modifying the + // shared kernel (contrib_ops/cuda/bert/attention_kv_cache.cu) to support + // nullptr outputs or K-only/V-only modes would risk breaking GQA callers. + auto* pk = reinterpret_cast(past_key->Data()); + auto* pv = reinterpret_cast(past_value->Data()); + auto* nk = reinterpret_cast(k_new_bsnh); + auto* nv = reinterpret_cast(v_new_bsnh); + auto* out_k = reinterpret_cast(present_key->MutableData()); + auto* out_v = reinterpret_cast(present_value->MutableData()); + // Concat K with head_size=H (V params duplicate K data -- harmless) + ORT_RETURN_IF_ERROR(onnxruntime::contrib::cuda::LaunchConcatNewToPastKV( + B, N_kv, H, parameters.kv_sequence_length, parameters.past_sequence_length, total_kv, + /*is_bsnh=*/false, + past_seqlens_buffer.get(), /*total_seq_lens=*/nullptr, + pk, pk, nk, nk, out_k, out_k, + cuda_stream, max_threads, /*past_only=*/false)); + // Concat V with head_size=H_v (K params duplicate V data -- harmless) + ORT_RETURN_IF_ERROR(onnxruntime::contrib::cuda::LaunchConcatNewToPastKV( + B, N_kv, H_v, parameters.kv_sequence_length, parameters.past_sequence_length, total_kv, + /*is_bsnh=*/false, + past_seqlens_buffer.get(), /*total_seq_lens=*/nullptr, + pv, pv, nv, nv, out_v, out_v, + cuda_stream, max_threads, /*past_only=*/false)); + } k_cache = reinterpret_cast(present_key->MutableData()); v_cache = reinterpret_cast(present_value->MutableData()); present_already_populated = true; @@ -1214,12 +1167,12 @@ Status Attention::RunGqaUnfusedAttention( } // -------- Allocate kernel workspace ----------------------------------------- - const size_t ws_bytes = onnxruntime::contrib::cuda::GetGqaUnfusedAttentionWorkspaceSize( + const size_t ws_bytes = onnxruntime::contrib::cuda::GetUnfusedAttentionWorkspaceSize( B, N_q, S_q, total_kv); auto ws_buffer = GetScratchBuffer(ws_bytes, GetComputeStream(context)); // -------- Call the kernel --------------------------------------------------- - onnxruntime::contrib::cuda::GqaUnfusedAttentionParams p; + onnxruntime::contrib::cuda::UnfusedAttentionParams p; p.batch_size = B; p.num_heads = N_q; p.kv_num_heads = N_kv; @@ -1232,13 +1185,19 @@ Status Attention::RunGqaUnfusedAttention( p.broadcast_attn_bias_dim_1 = bcast1; p.is_causal = parameters.is_causal; p.local_window_size = -1; // ONNX Attention (opset 23/24) does not expose sliding window. + p.past_kv_length = parameters.past_sequence_length; p.scale = parameters.scale; p.softcap = parameters.softcap; p.seqlens_k = seqlens_k_ptr; - ORT_RETURN_IF_ERROR((onnxruntime::contrib::cuda::LaunchGqaUnfusedAttention( + NativeCudaT* output_qk_data = (output_qk != nullptr) + ? reinterpret_cast(output_qk->MutableData()) + : nullptr; + + ORT_RETURN_IF_ERROR((onnxruntime::contrib::cuda::LaunchUnfusedAttention( device_prop, GetCublasHandle(context), cuda_stream, - p, q_bnsh, k_cache, v_cache, attn_bias_data, out_bnsh, ws_buffer.get()))); + p, q_bnsh, k_cache, v_cache, attn_bias_data, out_bnsh, ws_buffer.get(), + output_qk_data))); // -------- Transpose output BNSH -> BSNH if input was 3D -------------------- if (is_bsnh && out_bnsh_buffer != nullptr) { @@ -1279,10 +1238,10 @@ Status Attention::RunGqaUnfusedAttention( // ============================================================================ // ComputeInternal: Dispatch to appropriate attention kernel // ============================================================================ -// MHA path (q_num_heads == kv_num_heads): uses direct kernel dispatch cascade -// flash → memory efficient → unfused -// GQA path (q_num_heads != kv_num_heads): uses flash (handles GQA natively), MEA -// (with head expansion via LaunchUngroup, fp16/bf16 only), or GQA unfused fallback. +// Dispatch cascade: Flash → MEA (Memory Efficient) → Unified Unfused Attention. +// The unified unfused kernel handles both MHA (num_heads == kv_num_heads) and +// GQA (num_heads != kv_num_heads) via a reshape-Q trick (no K/V head replication). +// MEA uses head expansion via LaunchUngroup (fp16/bf16 only) for GQA. // ============================================================================ template Status Attention::ComputeInternal(OpKernelContext* context) const { @@ -1331,12 +1290,12 @@ Status Attention::ComputeInternal(OpKernelContext* context) const { // Flash: strictly requires BSNH — Q is transposed BNSH→BSNH before calling mha_fwd*. // K/V passed as BNSH to mha_fwd_kvcache (it handles both layouts). // MEA: accepts both BSNH and BNSH natively via is_kv_bsnh flag. Q transposed to BSNH. - // Unfused: accepts both via QkvToContext's qkv_format (Q_K_V_BSNH or Q_K_V_BNSH). + // Unfused: accepts both BSNH and BNSH (transposes if needed). // // nonpad_kv_seqlen + attn_mask routing: // Flash: cannot handle this combo (no bias param when seqlens_k is used) → excluded. // MEA: supports both (custom_right_padding for seqlens + additive attn_bias for mask). - // Unfused: nonpad → attention_bias; mask composed additively when both present. + // Unfused: nonpad → seqlens_k; mask → attention_bias; both handled independently in softmax kernel. #if USE_FLASH_ATTENTION || USE_MEMORY_EFFICIENT_ATTENTION const bool has_output_qk = (qk_matmul_output_mode_ != attention_helper::QKMatMulOutputMode::kNone); #endif @@ -1347,6 +1306,39 @@ Status Attention::ComputeInternal(OpKernelContext* context) const { // softmax_precision=0 (default) is also fine since higher precision is always // acceptable per the ONNX spec. + // Flash Attention uses lower-right (bottom-right) causal alignment with no option for + // upper-left. The ONNX spec requires upper-left alignment when there is no past context: + // query[0] attends only to key[0]. The difference only manifests when S_q != S_kv + // (cross-attention shape) with no past. Skip Flash for this case; MEA handles it correctly + // via the causal_from_top_left flag, and Unified Unfused uses past_kv_length=0. + // Defined here for visibility — only Flash needs this guard (MEA/Unfused handle upper-left natively). + const bool causal_cross_no_past = parameters.is_causal && + parameters.q_sequence_length != parameters.total_sequence_length && + parameters.past_sequence_length == 0; + + // Reject causal + TensorScatter decode (S_q < S_kv without past_key). + // Per ONNX spec, is_causal without past_key means upper-left alignment: q[i] attends + // only to kv[0..i]. For decode with external cache (S_q=1, S_kv=cache_size), this means + // q[0] sees only kv[0] — not meaningful for autoregressive generation. + // + // Why is_causal=0 is correct for external cache decode: + // - With S_q=1, there's only one query position at the end of the sequence + // - All KV positions are in the "past" relative to this query — nothing to mask + // - nonpad_kv_seqlen already bounds attention to valid cache positions + // + // For external cache prompt (S_q == S_kv), is_causal=1 works correctly (square matrix, + // upper-left == lower-right). For chunked prefill (S_q > 1 but S_q < S_kv), use an + // explicit attn_mask instead of is_causal. + if (causal_cross_no_past && nonpad_kv_seqlen != nullptr) { + return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, + "Causal attention with TensorScatter (nonpad_kv_seqlen) and S_q != S_kv without " + "past_key is not supported. Per ONNX spec, is_causal without past_key produces " + "upper-left alignment where q[i] only attends to kv[0..i], which for decode (S_q=1) " + "means q[0] sees only kv[0]. Use is_causal=0 for TensorScatter decode; the KV bounds " + "are already enforced by nonpad_kv_seqlen without needing a causal mask. For chunked " + "prefill with external cache, use an explicit attn_mask instead."); + } + #if USE_FLASH_ATTENTION { auto& device_prop = GetDeviceProp(); @@ -1357,16 +1349,16 @@ Status Attention::ComputeInternal(OpKernelContext* context) const { parameters.q_num_heads, parameters.kv_num_heads) && parameters.head_size == parameters.v_head_size && !has_output_qk && - // Flash does not support attention masks (no bias parameter in mha_fwd/mha_fwd_kvcache). - // Bool attn_mask + past_key is rejected because Flash uses paged KV cache semantics - // that produce spec-divergent present_kv layout for partial masks (e.g. [T,T,T,F]). - // Unfused handles bool+past_key spec-correctly via standard ConcatPastToPresent. - // TODO(titaiwang): GQA + bool attn_mask + past_key currently has no runner (Flash - // rejected here, unfused doesn't support GQA, MEA blocked by past_key != nullptr). - // Once PR #27851 merges (MEA supports past_key), this gap will be covered. + !causal_cross_no_past && + // Flash does not support attention masks — reject when attn_mask is present. attn_mask == nullptr; if (flash_eligible) { + LOGS_DEFAULT(VERBOSE) << "ONNX Attention: using Flash Attention" + << " (batch=" << parameters.batch_size + << ", q_seq=" << parameters.q_sequence_length + << ", total_seq=" << parameters.total_sequence_length + << ", past=" << (past_key != nullptr ? "yes" : "no") << ")"; return RunFlashAttention(context, Q, K, V, past_key, past_value, nonpad_kv_seqlen, Y, present_key, present_value, parameters); } @@ -1383,7 +1375,9 @@ Status Attention::ComputeInternal(OpKernelContext* context) const { sm, std::is_same::value, std::is_same::value, parameters.head_size, parameters.v_head_size) && !has_output_qk && - past_key == nullptr && + // MEA decode requires head_size == v_head_size for LaunchConcatNewToPastKV + // (single head_size parameter). Fall back to unfused when they differ. + (past_key == nullptr || parameters.head_size == parameters.v_head_size) && // GQA+MEA requires LaunchUngroup which only has fp16/bf16 instantiations. // FP32 GQA must fall through to the unfused path. !(is_gqa && std::is_same::value); @@ -1408,65 +1402,43 @@ Status Attention::ComputeInternal(OpKernelContext* context) const { } if (mea_eligible) { + LOGS_DEFAULT(VERBOSE) << "ONNX Attention: using Memory Efficient Attention" + << " (batch=" << parameters.batch_size + << ", q_seq=" << parameters.q_sequence_length + << ", total_seq=" << parameters.total_sequence_length + << ", past=" << (past_key != nullptr ? "yes" : "no") + << ", mask=" << (attn_mask != nullptr ? "yes" : "no") << ")"; return RunMemoryEfficientAttention(context, Q, K, V, attn_mask, past_key, past_value, nonpad_kv_seqlen, Y, present_key, present_value, parameters); } } #endif - // TODO(titaiwang): Support additional output_qk modes beyond kNone and kQK. - // Currently only unfused handles output_qk, and only kNone/kQK modes. + // Fallback: unified unfused attention + // Routes ALL cases to LaunchUnfusedAttention, which handles: + // - GQA natively (reshape-Q trick inside kernel, no K/V head replication) + // - MHA as a degenerate case (group_size=1) + // - fp16/bf16 with large head_size via FP32 QK scratch + // - softcap, attn_mask, causal, past_key+past_value, nonpad_kv_seqlen + // - output_qk (kQK mode: scale * Q @ K^T, before softcap/mask/softmax) + // - past_key with H != H_v (separate concat calls for K and V) + + // Guard: unified kernel only supports kNone and kQK output modes. + // Other modes (kQKMask, kQKSoftCap, kQKSoftMax) expect QK values captured at + // different pipeline stages that the unified kernel does not implement. if (qk_matmul_output_mode_ != attention_helper::QKMatMulOutputMode::kNone && qk_matmul_output_mode_ != attention_helper::QKMatMulOutputMode::kQK) { return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, - "qk_matmul_output_mode other than kNone and kQK is not supported yet " - "in Attention op (CUDA)."); - } - - // GQA-capable unfused fallback (issue #28195). - // Routes through LaunchGqaUnfusedAttention when: - // - GQA (q_num_heads != kv_num_heads) — the MHA unfused runner cannot handle this. - // - fp16/bf16 with head_size > 128 — raw Q*K^T can overflow fp16 storage even - // though cuBLAS accumulates in FP32; the new kernel writes QK to an FP32 scratch. - // The overflow threshold depends on the distribution of Q/K values and scale. - // head_size=256 at scale=1/sqrt(256)=0.0625 is borderline; head_size=512 at - // scale=1.0 (Gemma 4) definitely overflows. We use 128 as a conservative - // threshold since all fused kernels already handle head_size <= 128 anyway. - // This kernel supports softcap. It does not support output_qk, so we only enter it - // when qk_matmul_output_mode_ == kNone. - const bool is_half_or_bf16 = std::is_same::value || std::is_same::value; - const bool needs_fp32_qk_scratch = is_half_or_bf16 && parameters.head_size > 128; - if ((is_gqa || needs_fp32_qk_scratch) && - qk_matmul_output_mode_ == attention_helper::QKMatMulOutputMode::kNone) { - LOGS_DEFAULT(VERBOSE) << "Attention: using GQA unfused fallback (is_gqa=" << is_gqa - << ", needs_fp32_qk_scratch=" << needs_fp32_qk_scratch - << ", head_size=" << parameters.head_size - << ", softcap=" << parameters.softcap << ")"; - return RunGqaUnfusedAttention(context, Q, K, V, attn_mask, past_key, past_value, - nonpad_kv_seqlen, Y, present_key, present_value, parameters); - } - - if (is_gqa) { - // qk_matmul_output_mode != kNone reaches here; the unfused MHA runner cannot handle GQA. - return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, - "ONNX Attention with GQA (q_num_heads != kv_num_heads) and output_qk is not " - "supported by the unfused runner."); - } - - // Fallback: unfused MHA attention (legacy runner). - // Softcap is not implemented in the legacy unfused path — it requires Flash or MEA - // (or the new GQA unfused path above, which supports softcap for fp16/bf16/fp32). - // NOTE: keep this guard even if future PRs add softcap to more fused paths — this - // legacy unfused runner does NOT apply softcap and would silently produce wrong results. - if (parameters.softcap > 0.0f) { - return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, - "softcap requires flash attention or memory efficient attention, " - "but neither is eligible for this configuration. Check dtype (fp16/bf16 required for Flash), " - "head_size constraints, and past_key compatibility."); + "Only kNone and kQK output modes are supported in unified unfused attention. Mode: ", + static_cast(qk_matmul_output_mode_)); } + LOGS_DEFAULT(VERBOSE) << "Attention: using unified unfused path (is_gqa=" << is_gqa + << ", head_size=" << parameters.head_size + << ", softcap=" << parameters.softcap << ")"; return RunUnfusedAttention(context, Q, K, V, attn_mask, past_key, past_value, - nonpad_kv_seqlen, Y, present_key, present_value, output_qk, parameters); + nonpad_kv_seqlen, Y, present_key, present_value, + output_qk, parameters); } } // namespace cuda diff --git a/onnxruntime/core/providers/cuda/llm/attention.h b/onnxruntime/core/providers/cuda/llm/attention.h index 2acbf3b2ed829..f11503f154a30 100644 --- a/onnxruntime/core/providers/cuda/llm/attention.h +++ b/onnxruntime/core/providers/cuda/llm/attention.h @@ -31,27 +31,18 @@ class Attention final : public CudaKernel { Tensor* Y, Tensor* present_key, Tensor* present_value, const attention_helper::AttentionParameters& parameters) const; - Status RunUnfusedAttention( - OpKernelContext* context, - const Tensor* Q, const Tensor* K, const Tensor* V, - const Tensor* attn_mask, const Tensor* past_key, const Tensor* past_value, - const Tensor* nonpad_kv_seqlen, - Tensor* Y, Tensor* present_key, Tensor* present_value, - Tensor* output_qk, - const attention_helper::AttentionParameters& parameters) const; - - // GQA-capable unfused fallback. Handles: + // Unified unfused fallback. Handles: // - GQA (q_num_heads != kv_num_heads) without K/V head replication. // - fp16/bf16 with large head_size (FP32 QK accumulation, fixes #28195). // - past_key+past_value, attn_mask (bool/float), nonpad_kv_seqlen. - // Does not support: output_qk - // (output_qk modes other than kNone are rejected upstream). - Status RunGqaUnfusedAttention( + // - output_qk (kQK mode: scale * Q @ K^T, before softcap/mask/softmax). + Status RunUnfusedAttention( OpKernelContext* context, const Tensor* Q, const Tensor* K, const Tensor* V, const Tensor* attn_mask, const Tensor* past_key, const Tensor* past_value, const Tensor* nonpad_kv_seqlen, Tensor* Y, Tensor* present_key, Tensor* present_value, + Tensor* output_qk, const attention_helper::AttentionParameters& parameters) const; Status ConvertAttnMaskToBias( diff --git a/onnxruntime/core/providers/cuda/llm/attention_mask_impl.cu b/onnxruntime/core/providers/cuda/llm/attention_mask_impl.cu index 4ab3990b2f85d..2ba7f2e1a9836 100644 --- a/onnxruntime/core/providers/cuda/llm/attention_mask_impl.cu +++ b/onnxruntime/core/providers/cuda/llm/attention_mask_impl.cu @@ -89,107 +89,6 @@ Status LaunchConvertNonpadKvSeqlenToFlashSeqlensK( return CUDA_CALL(cudaGetLastError()); } -// CUDA kernel to convert nonpad_kv_seqlen to an additive attention bias. -// Generates (batch_size, q_seq_len, total_seq_len) output where: -// position t < nonpad_kv_seqlen[b] → 0.0 (attend) -// position t >= nonpad_kv_seqlen[b] → mask_filter_value (mask out) -template -__global__ void ConvertNonpadKvSeqlenToAttentionBiasKernel( - const int64_t* __restrict__ nonpad_kv_seqlen, - T* __restrict__ attention_bias, - const int batch_size, - const int q_seq_len, - const int total_seq_len, - const float mask_filter_value) { - int64_t idx = static_cast(blockIdx.x) * blockDim.x + threadIdx.x; - int64_t total = static_cast(batch_size) * q_seq_len * total_seq_len; - for (; idx < total; idx += static_cast(gridDim.x) * blockDim.x) { - int b = static_cast(idx / (static_cast(q_seq_len) * total_seq_len)); - int t = static_cast(idx % total_seq_len); - int64_t valid_len = nonpad_kv_seqlen[b]; - CUDA_KERNEL_ASSERT(valid_len >= 0 && valid_len <= static_cast(total_seq_len)); - valid_len = max(static_cast(0), min(valid_len, static_cast(total_seq_len))); - attention_bias[idx] = (t < static_cast(valid_len)) ? T(0.0f) : T(mask_filter_value); - } -} - -template -Status LaunchConvertNonpadKvSeqlenToAttentionBias( - const int64_t* nonpad_kv_seqlen, - T* attention_bias, - int batch_size, - int q_seq_len, - int total_seq_len, - float mask_filter_value, - cudaStream_t stream, - int max_threads_per_block) { - int64_t total = static_cast(batch_size) * q_seq_len * total_seq_len; - if (total == 0) { - return Status::OK(); - } - - int threads = static_cast(std::min(static_cast(max_threads_per_block), total)); - int64_t blocks = (total + threads - 1) / threads; - constexpr int64_t kMaxGridDimX = 65535; - unsigned int grid_size = static_cast(std::min(blocks, kMaxGridDimX)); - - ConvertNonpadKvSeqlenToAttentionBiasKernel<<>>( - nonpad_kv_seqlen, attention_bias, batch_size, q_seq_len, total_seq_len, mask_filter_value); - - return CUDA_CALL(cudaGetLastError()); -} - -template Status LaunchConvertNonpadKvSeqlenToAttentionBias( - const int64_t*, float*, int, int, int, float, cudaStream_t, int); -template Status LaunchConvertNonpadKvSeqlenToAttentionBias<__half>( - const int64_t*, __half*, int, int, int, float, cudaStream_t, int); -template Status LaunchConvertNonpadKvSeqlenToAttentionBias<__nv_bfloat16>( - const int64_t*, __nv_bfloat16*, int, int, int, float, cudaStream_t, int); - -// Add an addend bias into an existing bias buffer using cyclic broadcasting. -// Used to compose nonpad_kv_seqlen bias [B, q, t] with an attn_mask bias that -// is smaller or equal (e.g. 2D [q, t] cyclic-broadcasts over batch dimension). -template -__global__ void AddBiasInPlaceKernel( - T* __restrict__ bias, - const T* __restrict__ addend, - int64_t total_elements, - int64_t addend_elements) { - for (int64_t idx = static_cast(blockIdx.x) * blockDim.x + threadIdx.x; - idx < total_elements; - idx += static_cast(gridDim.x) * blockDim.x) { - float sum = static_cast(bias[idx]) + static_cast(addend[idx % addend_elements]); - bias[idx] = T(sum); - } -} - -template -Status LaunchAddBiasInPlace( - T* bias, - const T* addend, - int64_t total_elements, - int64_t addend_elements, - cudaStream_t stream, - int max_threads_per_block) { - if (total_elements == 0 || addend_elements == 0) { - return Status::OK(); - } - - int threads = static_cast(std::min(static_cast(max_threads_per_block), total_elements)); - int64_t blocks = (total_elements + threads - 1) / threads; - constexpr int64_t kMaxGridDimX = 65535; - unsigned int grid_size = static_cast(std::min(blocks, kMaxGridDimX)); - - AddBiasInPlaceKernel<<>>( - bias, addend, total_elements, addend_elements); - - return CUDA_CALL(cudaGetLastError()); -} - -template Status LaunchAddBiasInPlace(float*, const float*, int64_t, int64_t, cudaStream_t, int); -template Status LaunchAddBiasInPlace<__half>(__half*, const __half*, int64_t, int64_t, cudaStream_t, int); -template Status LaunchAddBiasInPlace<__nv_bfloat16>(__nv_bfloat16*, const __nv_bfloat16*, int64_t, int64_t, cudaStream_t, int); - // Zero output elements for batches where seqlens_k == 0 (fully masked). // CUTLASS MEA epilogue computes 1/s_prime where s_prime=0 → NaN for fully-masked // batches. The unfused path produces uniform softmax weights (finite mask_filter_value, diff --git a/onnxruntime/core/providers/cuda/llm/attention_mask_impl.h b/onnxruntime/core/providers/cuda/llm/attention_mask_impl.h index 1ada783e9d64d..d2cb4dbbd25ae 100644 --- a/onnxruntime/core/providers/cuda/llm/attention_mask_impl.h +++ b/onnxruntime/core/providers/cuda/llm/attention_mask_impl.h @@ -31,34 +31,6 @@ Status LaunchConvertNonpadKvSeqlenToFlashSeqlensK( cudaStream_t stream, int max_threads_per_block); -// Convert nonpad_kv_seqlen to an additive attention bias for the MHA unfused path. -// Generates a (batch_size, q_seq_len, total_seq_len) tensor where: -// position t < nonpad_kv_seqlen[b] → 0.0 (attend) -// position t >= nonpad_kv_seqlen[b] → mask_filter_value (mask out) -template -Status LaunchConvertNonpadKvSeqlenToAttentionBias( - const int64_t* nonpad_kv_seqlen, - T* attention_bias, - int batch_size, - int q_seq_len, - int total_seq_len, - float mask_filter_value, - cudaStream_t stream, - int max_threads_per_block); - -// Additively compose an addend bias into an existing bias buffer in-place. -// Supports cyclic broadcasting: addend of size [q, t] is repeated over batch -// to compose with a bias of size [B, q, t]. When both have the same number -// of elements (e.g. 4D mask [B, 1, q, t]), it performs a direct element-wise add. -template -Status LaunchAddBiasInPlace( - T* bias, - const T* addend, - int64_t total_elements, - int64_t addend_elements, - cudaStream_t stream, - int max_threads_per_block); - // Zero output elements for batches where seqlens_k == 0 (fully masked). // Used in the MEA path only: CUTLASS epilogue computes 1/s_prime where s_prime=0, // producing NaN for fully-masked batches. This kernel overwrites those NaN outputs diff --git a/onnxruntime/test/providers/cpu/llm/attention_op_test.cc b/onnxruntime/test/providers/cpu/llm/attention_op_test.cc index 0cf95141b7a6c..40c45db2dfd66 100644 --- a/onnxruntime/test/providers/cpu/llm/attention_op_test.cc +++ b/onnxruntime/test/providers/cpu/llm/attention_op_test.cc @@ -8,6 +8,8 @@ #include "test/common/tensor_op_test_utils.h" #include "test/common/cuda_op_test_utils.h" #include "test/providers/provider_test_utils.h" +#include "test/util/include/scoped_env_vars.h" +#include "contrib_ops/cpu/bert/attention_common.h" namespace onnxruntime { namespace test { @@ -91,8 +93,12 @@ static void AddInputs(OpTester& test, test.AddOutput("Y", y_shape, y, false, 0, 3e-5f); if (!present_key.empty()) test.AddOutput("present_key", present_key_shape, present_key); + else if (!qk_matmul_output.empty()) + test.AddOptionalOutputEdge(); // present_key placeholder if (!present_value.empty()) test.AddOutput("present_value", present_value_shape, present_value); + else if (!qk_matmul_output.empty()) + test.AddOptionalOutputEdge(); // present_value placeholder if (!qk_matmul_output.empty()) test.AddOutput("qk_matmul_output", qk_matmul_output_shape, qk_matmul_output); } else if (tensor_type == TensorType::kFloat16) { @@ -120,8 +126,12 @@ static void AddInputs(OpTester& test, test.AddOutput("Y", y_shape, ToFloat16(y), false, 0, 3e-3f); if (!present_key.empty()) test.AddOutput("present_key", present_key_shape, ToFloat16(present_key)); + else if (!qk_matmul_output.empty()) + test.AddOptionalOutputEdge(); // present_key placeholder if (!present_value.empty()) test.AddOutput("present_value", present_value_shape, ToFloat16(present_value)); + else if (!qk_matmul_output.empty()) + test.AddOptionalOutputEdge(); // present_value placeholder if (!qk_matmul_output.empty()) test.AddOutput("qk_matmul_output", qk_matmul_output_shape, ToFloat16(qk_matmul_output)); } else { @@ -149,8 +159,12 @@ static void AddInputs(OpTester& test, test.AddOutput("Y", y_shape, FloatsToBFloat16s(y), false, 0, 3e-3f); if (!present_key.empty()) test.AddOutput("present_key", present_key_shape, FloatsToBFloat16s(present_key)); + else if (!qk_matmul_output.empty()) + test.AddOptionalOutputEdge(); // present_key placeholder if (!present_value.empty()) test.AddOutput("present_value", present_value_shape, FloatsToBFloat16s(present_value)); + else if (!qk_matmul_output.empty()) + test.AddOptionalOutputEdge(); // present_value placeholder if (!qk_matmul_output.empty()) test.AddOutput("qk_matmul_output", qk_matmul_output_shape, FloatsToBFloat16s(qk_matmul_output)); } @@ -516,11 +530,10 @@ TEST(AttentionTest, Attention4DAttnMaskBoolAllFalse) { // Regression guard: all-false bool mask in decode mode (past_sequence_length > 0). // Guards against a bug where fully-masked batches produce NaN or incorrect output. -// Expected behavior: uniform softmax over past KV values produces Y = mean-of-V. -// With past_v = [10,20,30,40] and [20,40,60,80] per head, and all positions masked out, -// softmax(all -inf + constant mask_filter_value) → uniform weights → Y = {25, 50}. -// This test originally came from upstream/main and validates that both CPU and CUDA -// (unfused path) handle the all-false mask case identically. +// Expected behavior: uniform softmax over all KV values produces Y = mean-of-V. +// On CUDA, MEA decode handles this config (total_seq=4, 4-aligned). The capped +// mask_filter_value (-1e+30) in ConvertAttnMaskToBias prevents CUTLASS overflow, +// producing correct uniform softmax → mean(V). TEST(AttentionTest, Attention4DAttnMaskBoolAllFalseDecodeWithPast) { int batch_size = 1; int q_num_heads = 2; @@ -609,8 +622,9 @@ TEST(AttentionTest, Attention4DAttnMaskBoolAllFalseDecodeWithPast) { ); } -// Unfused decode path with fp16 and all-true bool attention mask. -// Flash rejects attn_mask (requires attn_mask==nullptr), so CUDA routes to unfused. +// Decode path with fp16 and all-true bool attention mask. +// Flash rejects attn_mask (requires attn_mask==nullptr). MEA handles decode with +// bool mask via additive bias (past_key concat + ConvertAttnMaskToBias). // head_size=64. Uniform keys make output analytically verifiable: // all attention scores are equal, so softmax is uniform over all positions. TEST(AttentionTest, Attention4DAttnMaskBoolDecodeWithPastFloat16) { @@ -695,8 +709,8 @@ TEST(AttentionTest, Attention4DAttnMaskBoolDecodeWithPastFloat16) { // Decode with partial bool mask [T,T,T,F]: the new token is masked out. // With mask [T,T,T,F] past_seq=3 total=4: only positions 0,1,2 are attended (past only). -// Flash is ineligible (bool+past_key rejected), so CUDA uses unfused which handles this -// spec-correctly via standard ConcatPastToPresent + element-wise mask application. +// Flash is ineligible (bool+past_key rejected). MEA handles decode with bool mask +// via additive bias (past_key concat + ConvertAttnMaskToBias). // Y = uniform mean over the 3 attended past values (Q=K=constant → uniform softmax). // CPU always runs; CUDA runs when SM 5.3+ is available. TEST(AttentionTest, Attention4DAttnMaskBoolPartialMaskDecodeFloat16) { @@ -781,7 +795,8 @@ TEST(AttentionTest, Attention4DAttnMaskBoolPartialMaskDecodeFloat16) { // Multi-batch decode with per-batch partial bool masks. // batch_size=2: batch 0 [T,T,T,F,F,F] (3 leading trues), batch 1 [T,T,T,T,T,T] (all true). -// Flash is ineligible (bool+past_key rejected), CUDA uses unfused. +// Flash is ineligible (bool+past_key rejected). MEA rejected by CUTLASS bias alignment +// (total_seq=6, 6%4≠0), so CUDA falls through to unfused. // Unfused applies standard ConcatPastToPresent (new token at position past_sequence_length=5 // for all batches) and element-wise mask in softmax. // Runs on both CPU and CUDA to verify cross-EP consistency. @@ -988,9 +1003,8 @@ TEST(AttentionTest, Attention4DSoftCap) { q, k, v, std::vector(), std::initializer_list(), std::vector(), std::vector(), -1, -1, std::numeric_limits::quiet_NaN(), 2.0f, -1, TensorType::kFloat, // is_causal, qk_matmul_output_mode, scale, softcap, softmax_precision, tensor_type ys, std::vector(), std::vector(), std::vector(), - // disable_cuda: head_size(8) != v_head_size(10) blocks Flash, past_key blocks MEA, - // unfused path doesn't support softcap. Needs test with head_size == v_head_size and no past. - false, true, true // disable_cpu, disable_cuda, disable_dml + // head_size(8) != v_head_size(10) blocks Flash and MEA decode; falls to unfused which now supports softcap. + false, false, true // disable_cpu, disable_cuda, disable_dml ); } @@ -1018,9 +1032,8 @@ TEST(AttentionTest, Attention4DSoftCapFloat16) { q, k, v, std::vector(), std::initializer_list(), std::vector(), std::vector(), -1, -1, std::numeric_limits::quiet_NaN(), 2.0f, -1, TensorType::kFloat16, // is_causal, qk_matmul_output_mode, scale, softcap, softmax_precision, tensor_type ys, std::vector(), std::vector(), std::vector(), - // disable_cuda: head_size(8) != v_head_size(10) blocks Flash, past_key blocks MEA, - // unfused path doesn't support softcap. Needs test with head_size == v_head_size and no past. - false, true, true // disable_cpu, disable_cuda, disable_dml + // head_size(8) != v_head_size(10) blocks Flash and MEA decode; falls to unfused which now supports softcap. + false, false, true // disable_cpu, disable_cuda, disable_dml ); } @@ -1160,7 +1173,6 @@ TEST(AttentionTest, Attention4DAttnPastPresent) { ); } -// TODO(titaiwang, xadupre): Do we really need cross attention + causal mask test case? TEST(AttentionTest, Attention4DAttnIsCausal) { int batch_size = 2; // Q.shape[0] int q_num_heads = 3; // Q.shape[1] @@ -1250,7 +1262,6 @@ TEST(AttentionTest, Attention4DAttnIsCausalBasicFloat16) { ); } -// TODO(titaiwang, xadupre): Do we really need cross attention + causal mask test case? TEST(AttentionTest, Attention4DAttnIsCausalBasicDifferentSequenceLength) { int batch_size = 2; // Q.shape[0] int q_num_heads = 1; // Q.shape[1] @@ -2308,10 +2319,10 @@ TEST(AttentionTest, Attention_NonPadKVSeqLen_WithFloatAttnMask_MultiBatch) { test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); } -// GQA unfused attention with FP32 QK accumulation for large head_size (> 128). -// This exercises the RunGqaUnfusedAttention path in attention.cc which uses +// Unfused attention with FP32 QK accumulation for large head_size (> 128). +// This exercises the RunUnfusedAttention path in attention.cc which uses // an FP32 scratch buffer for QK matmul to prevent overflow in fp16. -TEST(AttentionTest, Attention_GqaUnfused_LargeHeadSize_FP16) { +TEST(AttentionTest, Attention_Unfused_LargeHeadSize_FP16) { if (!HasCudaEnvironment(530)) { return; // fp16 requires SM 5.3+ } @@ -2371,9 +2382,9 @@ TEST(AttentionTest, Attention_GqaUnfused_LargeHeadSize_FP16) { test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); } -// GQA unfused attention with causal mask and large head_size. -// Verifies that is_causal works correctly in the unfused GQA path. -TEST(AttentionTest, Attention_GqaUnfused_LargeHeadSize_Causal_FP16) { +// Unfused attention with causal mask and large head_size. +// Verifies that is_causal works correctly in the unfused path. +TEST(AttentionTest, Attention_Unfused_LargeHeadSize_Causal_FP16) { if (!HasCudaEnvironment(530)) { return; // fp16 requires SM 5.3+ } @@ -2440,8 +2451,8 @@ TEST(AttentionTest, Attention_GqaUnfused_LargeHeadSize_Causal_FP16) { test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); } -// GQA unfused with past_key + attn_mask: exercises concat + bias path together. -TEST(AttentionTest, Attention_GqaUnfused_PastKey_AttnMask_FP16) { +// Unfused with past_key + attn_mask: exercises concat + bias path together. +TEST(AttentionTest, Attention_Unfused_PastKey_AttnMask_FP16) { if (!HasCudaEnvironment(530)) { return; } @@ -2519,8 +2530,8 @@ TEST(AttentionTest, Attention_GqaUnfused_PastKey_AttnMask_FP16) { test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); } -// GQA unfused with softcap + attn_mask: verifies the softcap + bias interaction. -TEST(AttentionTest, Attention_GqaUnfused_Softcap_AttnMask_FP16) { +// Unfused with softcap + attn_mask: verifies the softcap + bias interaction. +TEST(AttentionTest, Attention_Unfused_Softcap_AttnMask_FP16) { if (!HasCudaEnvironment(530)) { return; } @@ -2572,8 +2583,8 @@ TEST(AttentionTest, Attention_GqaUnfused_Softcap_AttnMask_FP16) { test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); } -// GQA unfused with BSNH (3D) input: previous tests all use 4D BNSH input. -TEST(AttentionTest, Attention_GqaUnfused_BSNH_FP16) { +// Unfused with BSNH (3D) input: previous tests all use 4D BNSH input. +TEST(AttentionTest, Attention_Unfused_BSNH_FP16) { if (!HasCudaEnvironment(530)) { return; } @@ -2622,8 +2633,8 @@ TEST(AttentionTest, Attention_GqaUnfused_BSNH_FP16) { test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); } -// GQA unfused with fp32: exercises the float template instantiation. -TEST(AttentionTest, Attention_GqaUnfused_FP32) { +// Unfused with fp32: exercises the float template instantiation. +TEST(AttentionTest, Attention_Unfused_FP32) { if (!HasCudaEnvironment(0)) { return; } @@ -2673,5 +2684,296 @@ TEST(AttentionTest, Attention_GqaUnfused_FP32) { test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); } +// Test MEA decode path by disabling Flash Attention. +// Uses the same Attention4DDefaultBasic data (head_size == v_head_size, fp16 with past_key) +// but forces MEA runner via environment variable. +TEST(AttentionTest, Attention4DMEADecodeFloat16) { + int batch_size = 2; + int q_num_heads = 3; + int q_sequence_length = 4; + int head_size = 8; + int kv_sequence_length = 6; + int kv_num_heads = 3; + int v_head_size = 8; + int past_sequence_length = 5; + + // Simple test data: one-hot Q/K/V to make expected output predictable + size_t q_size = batch_size * q_num_heads * q_sequence_length * head_size; + size_t k_size = batch_size * kv_num_heads * kv_sequence_length * head_size; + size_t v_size = batch_size * kv_num_heads * kv_sequence_length * v_head_size; + + std::vector q(q_size, 0.0f); + q[0] = 1.0f; // first element of first query is 1 + std::vector k(k_size, 0.0f); + k[0] = 1.0f; // first element of first key is 1 + std::vector v(v_size, 0.0f); + v[0] = 1.0f; // first element of first value is 1 + + // Expected output matches Attention4DDefaultBasic (same data, same math regardless of runner) + std::vector y = {0.221683f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.166667f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.166667f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.166667f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f}; + + // Force MEA by disabling Flash Attention + ScopedEnvironmentVariables scoped_env_vars{ + EnvVarMap{{onnxruntime::contrib::attention::kDisableFlashAttention, "1"}}}; + + RunTest4D(batch_size, q_num_heads, q_sequence_length, head_size, kv_sequence_length, kv_num_heads, v_head_size, past_sequence_length, + q, k, v, std::vector(), std::initializer_list(), std::vector(), std::vector(), + -1, -1, std::numeric_limits::quiet_NaN(), std::numeric_limits::quiet_NaN(), -1, TensorType::kFloat16, + y, std::vector(), std::vector(), std::vector(), + true, false, true // disable_cpu, disable_cuda=false (test CUDA MEA), disable_dml + ); +} + +// Regression test for output_qk + softcap: verifies that qk_matmul_output_mode=0 (kQK) +// returns RAW Q*K logits (before softcap), not softcapped values. +// This test would FAIL if CopyQK were moved after ApplySoftcap: +// - Correct (CopyQK before softcap): output_qk = 2.0 (raw dot product) +// - Wrong (CopyQK after softcap): output_qk = tanh(2.0) ≈ 0.964 (clamped by softcap=1.0) +// Uses constant Q=1, K=1 with head_size=4 so QK = scale * dot(Q,K) = 0.5 * 4 = 2.0. +// v_head_size(6) != head_size(4) blocks Flash Attention and MEA decode, forcing unfused path. +TEST(AttentionTest, Attention4DSoftCapOutputQkRawLogits) { + int batch_size = 1; + int q_num_heads = 2; + int q_sequence_length = 2; + int head_size = 4; + int kv_sequence_length = 3; + int kv_num_heads = 2; + int v_head_size = 6; + int past_sequence_length = 0; + int total_sequence_length = past_sequence_length + kv_sequence_length; + + // Constant Q and K: all 1.0 + // QK = scale * dot(Q[i], K[j]) = (1/sqrt(4)) * 4 = 2.0 for all (i,j) pairs + std::vector q(batch_size * q_num_heads * q_sequence_length * head_size, 1.0f); + std::vector k(batch_size * kv_num_heads * kv_sequence_length * head_size, 1.0f); + + // V: position j gets value (j+1)*0.1 across all v_head_size dims + std::vector v(batch_size * kv_num_heads * kv_sequence_length * v_head_size); + for (int n = 0; n < kv_num_heads; n++) { + for (int s = 0; s < kv_sequence_length; s++) { + float val = static_cast(s + 1) * 0.1f; + for (int h = 0; h < v_head_size; h++) { + v[(n * kv_sequence_length + s) * v_head_size + h] = val; + } + } + } + + // Expected output_qk: raw QK logits = 2.0 for all entries + // Shape: [batch, q_num_heads, q_seq, total_seq] = [1, 2, 2, 3] = 12 values + std::vector expected_qk(batch_size * q_num_heads * q_sequence_length * total_sequence_length, 2.0f); + + // Expected Y: softcap(2.0) ≈ 0.964 for all QK → uniform softmax → Y = mean(V) = 0.2 + // Shape: [batch, q_num_heads, q_seq, v_head_size] = [1, 2, 2, 6] = 24 values + std::vector ys(batch_size * q_num_heads * q_sequence_length * v_head_size, 0.2f); + + // present_key = K (no past), present_value = V (no past) + // These must be provided so the OpTester has all 4 outputs for correct index mapping. + RunTest4D(batch_size, q_num_heads, q_sequence_length, head_size, kv_sequence_length, kv_num_heads, v_head_size, past_sequence_length, + q, k, v, std::vector(), std::initializer_list(), std::vector(), std::vector(), + -1, 0, std::numeric_limits::quiet_NaN(), 1.0f, -1, TensorType::kFloat, // is_causal, qk_matmul_output_mode=kQK, scale=default, softcap=1.0 + ys, k, v, expected_qk, + false, false, true // disable_cpu, disable_cuda, disable_dml — runs on both CPU and CUDA unfused (v_head_size != head_size blocks Flash/MEA) + ); +} + +// ============================================================================ +// Causal alignment tests: verify upper-left (no past) vs lower-right (with past) +// These are CUDA-only tests that validate the causal masking fix. +// ============================================================================ + +// Test: Causal + cross-attention (S_q=3, S_kv=5, no past) +// ONNX spec mandates upper-left alignment: q_i attends to kv[0..i]. +// V is identity-like so output directly reveals which KV positions were attended. +// Exercises MEA (fp32, head_size divisible by 4) or Unfused kernel on CUDA. +TEST(AttentionTest, Attention4DCausalCrossAttentionUpperLeft) { + int batch_size = 1; + int q_num_heads = 1; + int q_sequence_length = 3; + int head_size = 4; + int kv_sequence_length = 5; + int kv_num_heads = 1; + int v_head_size = 4; + int past_sequence_length = 0; + + // clang-format off + std::vector q = {1.0f, 0.5f, 0.3f, 0.2f, + 0.4f, 0.8f, 0.1f, 0.6f, + 0.7f, 0.3f, 0.9f, 0.5f}; + std::vector k = {0.2f, 0.4f, 0.6f, 0.8f, + 0.1f, 0.3f, 0.5f, 0.7f, + 0.9f, 0.1f, 0.2f, 0.3f, + 0.5f, 0.6f, 0.7f, 0.8f, + 0.3f, 0.2f, 0.1f, 0.4f}; + std::vector v = {1.0f, 0.0f, 0.0f, 0.0f, + 0.0f, 1.0f, 0.0f, 0.0f, + 0.0f, 0.0f, 1.0f, 0.0f, + 0.0f, 0.0f, 0.0f, 1.0f, + 0.5f, 0.5f, 0.5f, 0.5f}; + // Upper-left causal (scale=0.5): q0→v[0]=[1,0,0,0], q1→softmax([0.47,0.375])@v[0:2], q2→softmax([0.6,0.48,0.495])@v[0:3] + std::vector y = {1.000000f, 0.000000f, 0.000000f, 0.000000f, + 0.523732f, 0.476268f, 0.000000f, 0.000000f, + 0.358777f, 0.318207f, 0.323016f, 0.000000f}; + // clang-format on + + ASSERT_EQ(q.size(), static_cast(batch_size * q_num_heads * q_sequence_length * head_size)); + ASSERT_EQ(k.size(), static_cast(batch_size * kv_num_heads * kv_sequence_length * head_size)); + ASSERT_EQ(v.size(), static_cast(batch_size * kv_num_heads * kv_sequence_length * v_head_size)); + ASSERT_EQ(y.size(), static_cast(batch_size * q_num_heads * q_sequence_length * v_head_size)); + + RunTest4D(batch_size, q_num_heads, q_sequence_length, head_size, kv_sequence_length, kv_num_heads, v_head_size, past_sequence_length, + q, k, v, std::vector(), std::initializer_list(), std::vector(), std::vector(), + 1, -1, std::numeric_limits::quiet_NaN(), std::numeric_limits::quiet_NaN(), -1, TensorType::kFloat, // is_causal, qk_matmul_output_mode, scale, softcap, softmax_precision, tensor_type + y, std::vector(), std::vector(), std::vector(), + false, false, true // disable_cpu, disable_cuda, disable_dml + ); +} + +// Test: Causal + cross-attention (S_q=3, S_kv=5, no past) with head_size=8. +// ONNX spec mandates upper-left alignment: q_i attends to kv[0..i]. +// head_size=8 targets the MEA path (below Flash minimum of 32) but validates +// correctness regardless of which kernel handles it. head_size=8 satisfies +// MEA's head_size%8==0 requirement, so this exercises MEA's CausalFromTopLeft +// path (via causal_from_top_left=true when past_seq==0). +// V is identity-like so output directly reveals which KV positions were attended. +TEST(AttentionTest, Attention4DCausalCrossAttentionUpperLeftSmallHead) { + int batch_size = 1; + int q_num_heads = 1; + int q_sequence_length = 3; + int head_size = 8; + int kv_sequence_length = 5; + int kv_num_heads = 1; + int v_head_size = 8; + int past_sequence_length = 0; + + // clang-format off + std::vector q = {1.0f, 0.5f, 0.3f, 0.2f, 0.8f, 0.4f, 0.6f, 0.1f, + 0.4f, 0.8f, 0.1f, 0.6f, 0.3f, 0.7f, 0.2f, 0.9f, + 0.7f, 0.3f, 0.9f, 0.5f, 0.1f, 0.6f, 0.4f, 0.8f}; + std::vector k = {0.2f, 0.4f, 0.6f, 0.8f, 0.1f, 0.3f, 0.5f, 0.7f, + 0.1f, 0.3f, 0.5f, 0.7f, 0.9f, 0.2f, 0.4f, 0.6f, + 0.9f, 0.1f, 0.2f, 0.3f, 0.4f, 0.8f, 0.7f, 0.5f, + 0.5f, 0.6f, 0.7f, 0.8f, 0.2f, 0.4f, 0.3f, 0.1f, + 0.3f, 0.2f, 0.1f, 0.4f, 0.6f, 0.5f, 0.8f, 0.9f}; + std::vector v = {1.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, + 0.0f, 1.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, + 0.0f, 0.0f, 1.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, + 0.0f, 0.0f, 0.0f, 1.0f, 0.0f, 0.0f, 0.0f, 0.0f, + 0.0f, 0.0f, 0.0f, 0.0f, 1.0f, 0.0f, 0.0f, 0.0f}; + // Upper-left causal (scale=1/sqrt(8)): q0→v[0], q1→softmax(scaled_scores[0:2])@v[0:2], q2→softmax(scaled_scores[0:3])@v[0:3] + std::vector y = {1.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, + 0.511488f, 0.488512f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, + 0.344711f, 0.305668f, 0.349621f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f}; + // clang-format on + + ASSERT_EQ(q.size(), static_cast(batch_size * q_num_heads * q_sequence_length * head_size)); + ASSERT_EQ(k.size(), static_cast(batch_size * kv_num_heads * kv_sequence_length * head_size)); + ASSERT_EQ(v.size(), static_cast(batch_size * kv_num_heads * kv_sequence_length * v_head_size)); + ASSERT_EQ(y.size(), static_cast(batch_size * q_num_heads * q_sequence_length * v_head_size)); + + RunTest4D(batch_size, q_num_heads, q_sequence_length, head_size, kv_sequence_length, kv_num_heads, v_head_size, past_sequence_length, + q, k, v, std::vector(), std::initializer_list(), std::vector(), std::vector(), + 1, -1, std::numeric_limits::quiet_NaN(), std::numeric_limits::quiet_NaN(), -1, TensorType::kFloat, // is_causal, qk_matmul_output_mode, scale, softcap, softmax_precision, tensor_type + y, std::vector(), std::vector(), std::vector(), + false, false, true // disable_cpu, disable_cuda, disable_dml + ); +} +// Lower-right alignment: q0 at absolute position 4 attends to all 5 KV positions. +// Exercises Unfused or MEA decode path on CUDA. +TEST(AttentionTest, Attention4DCausalDecodeWithPastLowerRight) { + int batch_size = 1; + int q_num_heads = 1; + int q_sequence_length = 1; + int head_size = 4; + int kv_sequence_length = 1; // new KV tokens + int kv_num_heads = 1; + int v_head_size = 4; + int past_sequence_length = 4; // total = 4 + 1 = 5 + + // clang-format off + std::vector q = {0.7f, 0.3f, 0.9f, 0.5f}; + std::vector k = {0.3f, 0.2f, 0.1f, 0.4f}; // new key + std::vector v = {0.5f, 0.5f, 0.5f, 0.5f}; // new value + std::vector past_key = {0.2f, 0.4f, 0.6f, 0.8f, + 0.1f, 0.3f, 0.5f, 0.7f, + 0.9f, 0.1f, 0.2f, 0.3f, + 0.5f, 0.6f, 0.7f, 0.8f}; + std::vector past_value = {1.0f, 0.0f, 0.0f, 0.0f, + 0.0f, 1.0f, 0.0f, 0.0f, + 0.0f, 0.0f, 1.0f, 0.0f, + 0.0f, 0.0f, 0.0f, 1.0f}; + // Lower-right: q0 at pos 4 sees all 5 positions. scores=[0.6,0.48,0.495,0.78,0.28]*scale=0.5 already applied + std::vector y = {0.289363f, 0.265357f, 0.268203f, 0.331229f}; + // present = concat(past, new) in BNSH layout + std::vector present_key = {0.2f, 0.4f, 0.6f, 0.8f, + 0.1f, 0.3f, 0.5f, 0.7f, + 0.9f, 0.1f, 0.2f, 0.3f, + 0.5f, 0.6f, 0.7f, 0.8f, + 0.3f, 0.2f, 0.1f, 0.4f}; + std::vector present_value = {1.0f, 0.0f, 0.0f, 0.0f, + 0.0f, 1.0f, 0.0f, 0.0f, + 0.0f, 0.0f, 1.0f, 0.0f, + 0.0f, 0.0f, 0.0f, 1.0f, + 0.5f, 0.5f, 0.5f, 0.5f}; + // clang-format on + + ASSERT_EQ(q.size(), static_cast(batch_size * q_num_heads * q_sequence_length * head_size)); + ASSERT_EQ(k.size(), static_cast(batch_size * kv_num_heads * kv_sequence_length * head_size)); + ASSERT_EQ(v.size(), static_cast(batch_size * kv_num_heads * kv_sequence_length * v_head_size)); + ASSERT_EQ(y.size(), static_cast(batch_size * q_num_heads * q_sequence_length * v_head_size)); + + RunTest4D(batch_size, q_num_heads, q_sequence_length, head_size, kv_sequence_length, kv_num_heads, v_head_size, past_sequence_length, + q, k, v, std::vector(), std::initializer_list(), past_key, past_value, + 1, -1, std::numeric_limits::quiet_NaN(), std::numeric_limits::quiet_NaN(), -1, TensorType::kFloat, // is_causal, qk_matmul_output_mode, scale, softcap, softmax_precision, tensor_type + y, present_key, present_value, std::vector(), + false, false, true // disable_cpu, disable_cuda, disable_dml + ); +} + +// Test: Causal + square (S_q=S_kv=4, no past) +// Upper-left == lower-right for square matrices. Verifies correctness on both paths. +// Exercises MEA or Unfused kernel depending on GPU capability. +TEST(AttentionTest, Attention4DCausalSquareNoPast) { + int batch_size = 1; + int q_num_heads = 1; + int q_sequence_length = 4; + int head_size = 4; + int kv_sequence_length = 4; + int kv_num_heads = 1; + int v_head_size = 4; + int past_sequence_length = 0; + + // clang-format off + std::vector q = {1.0f, 0.5f, 0.3f, 0.2f, + 0.4f, 0.8f, 0.1f, 0.6f, + 0.7f, 0.3f, 0.9f, 0.5f, + 0.2f, 0.6f, 0.4f, 0.8f}; + std::vector k = {0.2f, 0.4f, 0.6f, 0.8f, + 0.1f, 0.3f, 0.5f, 0.7f, + 0.9f, 0.1f, 0.2f, 0.3f, + 0.5f, 0.6f, 0.7f, 0.8f}; + std::vector v = {1.0f, 0.0f, 0.0f, 0.0f, + 0.0f, 1.0f, 0.0f, 0.0f, + 0.0f, 0.0f, 1.0f, 0.0f, + 0.0f, 0.0f, 0.0f, 1.0f}; + // Both alignments give identical result for square (no past). + std::vector y = {1.000000f, 0.000000f, 0.000000f, 0.000000f, + 0.523732f, 0.476268f, 0.000000f, 0.000000f, + 0.358777f, 0.318207f, 0.323016f, 0.000000f, + 0.265821f, 0.240525f, 0.196925f, 0.296730f}; + // clang-format on + + ASSERT_EQ(q.size(), static_cast(batch_size * q_num_heads * q_sequence_length * head_size)); + ASSERT_EQ(k.size(), static_cast(batch_size * kv_num_heads * kv_sequence_length * head_size)); + ASSERT_EQ(v.size(), static_cast(batch_size * kv_num_heads * kv_sequence_length * v_head_size)); + ASSERT_EQ(y.size(), static_cast(batch_size * q_num_heads * q_sequence_length * v_head_size)); + + RunTest4D(batch_size, q_num_heads, q_sequence_length, head_size, kv_sequence_length, kv_num_heads, v_head_size, past_sequence_length, + q, k, v, std::vector(), std::initializer_list(), std::vector(), std::vector(), + 1, -1, std::numeric_limits::quiet_NaN(), std::numeric_limits::quiet_NaN(), -1, TensorType::kFloat, // is_causal, qk_matmul_output_mode, scale, softcap, softmax_precision, tensor_type + y, std::vector(), std::vector(), std::vector(), + false, false, true // disable_cpu, disable_cuda, disable_dml + ); +} + } // namespace test } // namespace onnxruntime diff --git a/onnxruntime/test/python/transformers/test_onnx_attention/common.py b/onnxruntime/test/python/transformers/test_onnx_attention/common.py index 48640fa38aca2..1ab38fb1ea0f9 100644 --- a/onnxruntime/test/python/transformers/test_onnx_attention/common.py +++ b/onnxruntime/test/python/transformers/test_onnx_attention/common.py @@ -11,7 +11,7 @@ # ------------------------------------------------------------------------- """ -Shared utilities for ONNX Attention op (opset 23) tests. +Shared utilities for ONNX Attention op (opset 23/24) tests. Contains configuration, ONNX graph builders, reference implementation, and parity check helpers used by both GQA and MHA test modules. @@ -38,9 +38,6 @@ # Reduces number of tests to run for faster pipeline checks pipeline_mode = os.getenv("PIPELINE_MODE", "1") == "1" -# Number of values per parameter (compared to pipeline mode) -param_count = int(os.getenv("PARAM_COUNT", "3")) if not pipeline_mode else 2 - # When quick build is used, flash attention only supports head_size=128 quick_build = ", quick-build=" in get_build_info() @@ -71,14 +68,6 @@ torch.int8: TensorProto.INT8, } -TORCH_DTYPE_MAP = { - "float32": torch.float32, - "float16": torch.float16, - "bfloat16": torch.bfloat16, - "int8": torch.int8, - "int4": torch.uint8, -} - @dataclass class AttentionConfig: @@ -88,6 +77,7 @@ class AttentionConfig: q_num_heads: int kv_num_heads: int head_size: int + v_head_size: int = 0 # 0 means same as head_size; set explicitly for asymmetric Q/V head sizes is_causal: int = 0 past_kv_sequence_length: int = 0 softcap: float = 0.0 @@ -115,7 +105,7 @@ def create_attention_node_and_io( """ Create ONNX Attention op node and I/O definitions for testing. - ONNX Attention op (opset 23) inputs: + ONNX Attention op (opset 23/24) inputs: - 0: Q (query) - required - 1: K (key) - required - 2: V (value) - required @@ -135,6 +125,9 @@ def create_attention_node_and_io( else: # Prompt (no past KV cache) present_kv_seqlen = config.kv_sequence_length + # Effective v_head_size: defaults to head_size when not explicitly set + effective_v_head_size = config.v_head_size or config.head_size + if not config.kv_cache_type: config.kv_cache_type = { TensorProto.FLOAT16: "float16", @@ -168,7 +161,7 @@ def create_attention_node_and_io( while inputs and inputs[-1] == "": inputs.pop() - # ONNX Attention op attributes (opset 23) + # ONNX Attention op attributes (opset 23/24) node = helper.make_node( op_type="Attention", inputs=inputs, @@ -199,13 +192,14 @@ def create_attention_node_and_io( helper.make_tensor_value_info( "value", ort_type, - [config.batch_size, config.kv_num_heads, config.kv_sequence_length, config.head_size], + [config.batch_size, config.kv_num_heads, config.kv_sequence_length, effective_v_head_size], ), ] else: # 3D inputs: [batch, seq_len, hidden_size] q_hidden_size = config.q_num_heads * config.head_size kv_hidden_size = config.kv_num_heads * config.head_size + v_hidden_size = config.kv_num_heads * effective_v_head_size graph_input = [ helper.make_tensor_value_info( "query", ort_type, [config.batch_size, config.q_sequence_length, q_hidden_size] @@ -214,7 +208,7 @@ def create_attention_node_and_io( "key", ort_type, [config.batch_size, config.kv_sequence_length, kv_hidden_size] ), helper.make_tensor_value_info( - "value", ort_type, [config.batch_size, config.kv_sequence_length, kv_hidden_size] + "value", ort_type, [config.batch_size, config.kv_sequence_length, v_hidden_size] ), ] @@ -263,10 +257,11 @@ def create_attention_node_and_io( # Shape: [batch, num_heads, past_seq_len, head_size] (4D BNSH format) if is_past: past_k_shape = [config.batch_size, config.kv_num_heads, config.past_kv_sequence_length, config.head_size] + past_v_shape = [config.batch_size, config.kv_num_heads, config.past_kv_sequence_length, effective_v_head_size] graph_input.extend( [ helper.make_tensor_value_info("past_key", cache_ort_type, past_k_shape), - helper.make_tensor_value_info("past_value", cache_ort_type, past_k_shape), + helper.make_tensor_value_info("past_value", cache_ort_type, past_v_shape), ] ) @@ -276,16 +271,17 @@ def create_attention_node_and_io( # --- Graph Outputs --- output_k_shape = [config.batch_size, config.kv_num_heads, present_kv_seqlen, config.head_size] + output_v_shape = [config.batch_size, config.kv_num_heads, present_kv_seqlen, effective_v_head_size] if config.use_4d_bnsh: - output_shape = [config.batch_size, config.q_num_heads, config.q_sequence_length, config.head_size] + output_shape = [config.batch_size, config.q_num_heads, config.q_sequence_length, effective_v_head_size] else: - output_shape = [config.batch_size, config.q_sequence_length, config.q_num_heads * config.head_size] + output_shape = [config.batch_size, config.q_sequence_length, config.q_num_heads * effective_v_head_size] graph_output = [ helper.make_tensor_value_info("output", ort_type, output_shape), helper.make_tensor_value_info("present_key", cache_ort_type, output_k_shape), - helper.make_tensor_value_info("present_value", cache_ort_type, output_k_shape), + helper.make_tensor_value_info("present_value", cache_ort_type, output_v_shape), ] if output_qk > 0: @@ -447,24 +443,26 @@ def attention_prompt_func( bind_tensor(io_binding, "nonpad_kv_seqlen", nonpad_kv_seqlen, device, TensorProto.INT64) # Bind Outputs - hidden_size = config.q_num_heads * config.head_size + effective_v_head_size = config.v_head_size or config.head_size + output_hidden_size = config.q_num_heads * effective_v_head_size out_dtype = _get_out_dtype(ort_type) if config.use_4d_bnsh: out_torch = torch.zeros( - (config.batch_size, config.q_num_heads, config.q_sequence_length, config.head_size), + (config.batch_size, config.q_num_heads, config.q_sequence_length, effective_v_head_size), dtype=out_dtype, device=device, ) else: out_torch = torch.zeros( - (config.batch_size, config.q_sequence_length, hidden_size), dtype=out_dtype, device=device + (config.batch_size, config.q_sequence_length, output_hidden_size), dtype=out_dtype, device=device ) bind_output_tensor(io_binding, "output", out_torch, device, ort_type) # present KV shape for prompt (no past) present_seqlen = config.kv_sequence_length - present_dims = [config.batch_size, config.kv_num_heads, present_seqlen, config.head_size] + present_k_dims = [config.batch_size, config.kv_num_heads, present_seqlen, config.head_size] + present_v_dims = [config.batch_size, config.kv_num_heads, present_seqlen, effective_v_head_size] # Determine dtype for cache tensors cache_dtype = out_dtype @@ -473,8 +471,8 @@ def attention_prompt_func( else: cache_ort_type = ONNX_TENSOR_TYPE_MAP[config.kv_cache_type] - present_k = torch.zeros(tuple(present_dims), dtype=cache_dtype, device=device) - present_v = torch.zeros(tuple(present_dims), dtype=cache_dtype, device=device) + present_k = torch.zeros(tuple(present_k_dims), dtype=cache_dtype, device=device) + present_v = torch.zeros(tuple(present_v_dims), dtype=cache_dtype, device=device) bind_output_tensor(io_binding, "present_key", present_k, device, cache_ort_type) bind_output_tensor(io_binding, "present_value", present_v, device, cache_ort_type) @@ -565,28 +563,30 @@ def attention_past_func( bind_tensor(io_binding, "past_value", past_v_sliced, device, cache_ort_type) # Bind Outputs - hidden_size = config.q_num_heads * config.head_size + effective_v_head_size = config.v_head_size or config.head_size + output_hidden_size = config.q_num_heads * effective_v_head_size out_dtype = _get_out_dtype(ort_type) if config.use_4d_bnsh: out_torch = torch.zeros( - (config.batch_size, config.q_num_heads, config.q_sequence_length, config.head_size), + (config.batch_size, config.q_num_heads, config.q_sequence_length, effective_v_head_size), dtype=out_dtype, device=device, ) else: out_torch = torch.zeros( - (config.batch_size, config.q_sequence_length, hidden_size), dtype=out_dtype, device=device + (config.batch_size, config.q_sequence_length, output_hidden_size), dtype=out_dtype, device=device ) bind_output_tensor(io_binding, "output", out_torch, device, ort_type) # present KV shape (past + new) present_seqlen = total_seq_len - present_dims = [config.batch_size, config.kv_num_heads, present_seqlen, config.head_size] + present_k_dims = [config.batch_size, config.kv_num_heads, present_seqlen, config.head_size] + present_v_dims = [config.batch_size, config.kv_num_heads, present_seqlen, effective_v_head_size] cache_dtype = out_dtype - present_k = torch.zeros(tuple(present_dims), dtype=cache_dtype, device=device) - present_v = torch.zeros(tuple(present_dims), dtype=cache_dtype, device=device) + present_k = torch.zeros(tuple(present_k_dims), dtype=cache_dtype, device=device) + present_v = torch.zeros(tuple(present_v_dims), dtype=cache_dtype, device=device) bind_output_tensor(io_binding, "present_key", present_k, device, cache_ort_type) bind_output_tensor(io_binding, "present_value", present_v, device, cache_ort_type) @@ -645,6 +645,9 @@ def attention_ref( scores = torch.einsum("bthd,bshd->bhts", q, k) / math.sqrt(q.shape[-1]) + # Corrected ordering per onnx/onnx#7865: QK → softcap → add bias/mask → softmax + # Softcap must be applied before mask so that -inf mask values are not + # squashed to finite -softcap, which would leak probability to masked positions. if softcap > 0: scores = (scores / softcap).tanh() * softcap diff --git a/onnxruntime/test/python/transformers/test_onnx_attention/test_gqa.py b/onnxruntime/test/python/transformers/test_onnx_attention/test_gqa.py index c4e3c1b19e85e..55f07666e8c6f 100644 --- a/onnxruntime/test/python/transformers/test_onnx_attention/test_gqa.py +++ b/onnxruntime/test/python/transformers/test_onnx_attention/test_gqa.py @@ -98,16 +98,19 @@ def parity_check_gqa_prompt( ) v = torch.randn_like(k) * std - # --- Create attn_mask as boolean padding mask (simulating seqlens_k) --- + # --- Create attn_mask matching the ONNX model's expected shape --- attn_mask = None key_padding_mask = None if config.has_attn_mask: + total_seq = config.past_kv_sequence_length + config.kv_sequence_length + # 2D mask shape: [q_seq, total_seq] per ONNX spec (matches create_attention_graph_prompt) attn_mask = torch.ones( - config.batch_size, - config.kv_sequence_length, + config.q_sequence_length, + total_seq, device=device, dtype=torch.bool, ) + # key_padding_mask for PyTorch reference: [batch, kv_seq] key_padding_mask = torch.ones( config.batch_size, config.kv_sequence_length, @@ -115,6 +118,17 @@ def parity_check_gqa_prompt( dtype=torch.bool, ) + # --- Create nonpad_kv_seqlen tensor if needed (opset 24+) --- + nonpad_kv_seqlen = None + if config.has_nonpad_kv_seqlen: + # Each batch element has the full kv_sequence_length as valid (no padding) + nonpad_kv_seqlen = torch.full( + (config.batch_size,), + config.kv_sequence_length, + device=device, + dtype=torch.int64, + ) + # --- PyTorch Reference Path --- out_ref, _ = attention_ref( q=q, @@ -138,6 +152,7 @@ def parity_check_gqa_prompt( ep=ep, device=device, ort_type=ort_type, + nonpad_kv_seqlen=nonpad_kv_seqlen, ) if i == 0: first_out = out.clone() @@ -271,7 +286,7 @@ def parity_check_gqa_past( key_padding_mask = None if config.has_attn_mask: attn_mask = torch.ones( - config.batch_size, + config.q_sequence_length, total_seq_len, device=device, dtype=torch.bool, @@ -441,7 +456,7 @@ def parity_check_gqa_prompt_with_padding( ) # --- ONNX Runtime Path --- - out, present_k, present_v = attention_prompt_func( + out, _present_k, _present_v = attention_prompt_func( q=q, k=k, v=v, @@ -568,7 +583,7 @@ def parity_check_gqa_past_with_padding( ) # --- ONNX Runtime Path --- - out, present_k, present_v = attention_past_func( + out, _present_k, _present_v = attention_past_func( q=q, past_k=past_k, past_v=past_v, @@ -708,6 +723,9 @@ def gqa_prompt_padding_test_cases(): # Guard case: batch_size=4 != q_seq_len=1 (decode). This catches the original bug # where 2D mask was [batch, total_seq] instead of [q_seq, total_seq]. + # NOTE: is_causal=0 because per ONNX spec, is_causal with S_q!=S_kv and no past_key + # gives upper-left alignment (q[0] sees only kv[0]), which is not meaningful for decode. + # KV bounds are enforced by the attention mask instead. for mask_dims in mask_dims_options: config = AttentionConfig( batch_size=4, @@ -717,7 +735,7 @@ def gqa_prompt_padding_test_cases(): q_num_heads=8, kv_num_heads=2, head_size=128, - is_causal=1, + is_causal=0, has_attn_mask=True, attn_mask_dims=mask_dims, ) @@ -730,7 +748,9 @@ def gqa_past_padding_test_cases(): Generate test cases for ONNX Attention op GQA path with boolean padding masks in decoding phase. """ batches = [2] - seqs = [(1, 32)] + # past=31 + new=1 = total_seq=32, which satisfies MEA's bias alignment + # requirement (total_seq % 4 == 0) when attn_mask is present. + seqs = [(1, 31)] heads = [(8, 2)] h_sizes = [128] mask_dims_options = [2, 3, 4] @@ -863,22 +883,37 @@ def test_gqa_prompt_memory_efficient(self, name, config): # flash attention. -# TODO(titaiwang): Re-enable once PR #27851 merges (MEA supports past_key for GQA). -# Flash now rejects attn_mask (requires attn_mask==nullptr). GQA + bool mask + past_key -# has no runner until MEA supports past_key. See issue #27885. -@unittest.skip( - "Flash now rejects attn_mask. GQA + bool mask + past_key has no runner " - "until PR #27851 (MEA with past_key). See issue #27885." -) -@unittest.skipIf(not has_flash_attention(), "Flash Attention is not available, skipping tests.") -@patch.dict(os.environ, {"ORT_DISABLE_FLASH_ATTENTION": "0"}) -class TestONNXAttentionPaddingMaskGQA(unittest.TestCase): +@unittest.skipIf(not has_cuda_device(80), "BF16 requires Ampere or higher GPU, skipping tests.") +@patch.dict(os.environ, {"ORT_DISABLE_FLASH_ATTENTION": "1"}) +class TestONNXAttentionMemoryEfficientGQABF16(unittest.TestCase): + """Test ONNX Attention op (opset 23) GQA path with Memory Efficient Attention using BFloat16.""" + + @parameterized.expand(gqa_past_test_cases()) + def test_gqa_past_memory_efficient_bf16(self, name, config): + if not torch.cuda.is_bf16_supported(): + self.skipTest("BFloat16 not supported on this device") + + config.kv_cache_type = "bfloat16" + parity_check_gqa_past( + config=config, + ep="CUDAExecutionProvider", + device="cuda", + torch_type=torch.bfloat16, + ort_type=TensorProto.BFLOAT16, + causal=True, + rtol=rtol["bf16"], + atol=atol["bf16"], + ) + + +@unittest.skipIf(not has_cuda_device(53), "Memory Efficient Attention is not available, skipping tests.") +@patch.dict(os.environ, {"ORT_DISABLE_FLASH_ATTENTION": "1"}) +class TestONNXAttentionPaddingMaskMEAGQA(unittest.TestCase): """ Test ONNX Attention op (opset 23) GQA path with boolean padding masks. - SKIPPED: Flash now requires attn_mask == nullptr. GQA + bool attn_mask + - past_key currently has no runner (Flash rejected, unfused doesn't support GQA, - MEA blocked by past_key != nullptr). Will be re-enabled when PR #27851 lands. + GQA + bool attn_mask + past_key uses the MEA decode path (Flash requires + attn_mask == nullptr). MEA handles bool masks via additive bias conversion. These tests verify that the boolean attn_mask is correctly converted to sequence lengths on GPU and that the attention computation respects the @@ -1011,7 +1046,7 @@ def parity_check_gqa_prompt_with_nonpad_kv_seqlen( # ORT path: use nonpad_kv_seqlen (int64 tensor) nonpad_kv_seqlen_tensor = nonpad_seqlens.to(torch.int64).to(device) - out, present_k, present_v = attention_prompt_func( + out, _present_k, _present_v = attention_prompt_func( q=q, k=k, v=v, @@ -1344,10 +1379,10 @@ def test_gqa_prompt_float_mask_4d(self): # ################################################################################################# -# Large Head Size Unfused GQA Tests (head_size=512, fixes #28195) +# Large Head Size Unfused Tests (head_size=512, fixes #28195) # # Flash Attention and Memory-Efficient Attention cap at head_size=256. For head_size=512 the -# op falls through to RunGqaUnfusedAttention which writes Q*K^T to an FP32 scratch buffer, +# op falls through to RunUnfusedAttention which writes Q*K^T to an FP32 scratch buffer, # eliminating fp16/bf16 overflow that caused NaNs (e.g. Gemma 4 global-attention layers). # # These tests deliberately disable both Flash and MEA to make the unfused fallback explicit @@ -1425,7 +1460,7 @@ class TestONNXAttentionGQALargeHeadUnfused(unittest.TestCase): Regression tests for GQA with head_size=512 via the unfused FP32-QK path (issue #28195). Flash Attention and MEA both cap at head_size=256. With both disabled the op routes - to RunGqaUnfusedAttention, which writes Q*K^T to an FP32 scratch buffer to avoid + to RunUnfusedAttention, which writes Q*K^T to an FP32 scratch buffer to avoid fp16/bf16 overflow that produced NaNs for Gemma 4 global-attention layers. Validates: no NaNs, numerical parity vs. PyTorch SDPA reference, for fp16 and bf16. @@ -1532,5 +1567,355 @@ def test_gqa_large_head_unfused_softcap_additive_mask_poison_fp16(self): self.assertLess(out.float().max().item(), 1.0) +@unittest.skipIf(not has_cuda_device(53), "Memory Efficient Attention is not available, skipping tests.") +@patch.dict(os.environ, {"ORT_DISABLE_FLASH_ATTENTION": "1"}) +class TestONNXAttentionMemoryEfficientGQAFloatMaskDecode(unittest.TestCase): + """ + Test GQA with float additive attention mask during decode using MEA. + + This exercises the MEA decode path with float additive masks — a scenario + that was a HARD ERROR before MEA+decode support (MEA was ineligible + when past_key was present, so this fell through to no kernel). + """ + + def test_gqa_past_float_mask_4d(self): + """Test GQA decode with 4D float additive mask via MEA.""" + config = AttentionConfig( + batch_size=2, + q_sequence_length=1, + kv_sequence_length=1, + past_kv_sequence_length=31, # 31+1=32, divisible by 4 (CUTLASS bias alignment for MEA) + q_num_heads=8, + kv_num_heads=2, + head_size=128, + is_causal=1, + has_attn_mask=True, + attn_mask_dims=4, + attn_mask_type="additive", + ) + + torch.manual_seed(0) + device = "cuda" + torch_type = torch.float16 + # std=0.2 keeps values in a numerically stable range for fp16 attention + std = 0.2 + + q = torch.randn(2, 1, 8, 128, device=device, dtype=torch_type) * std + + past_k = torch.randn(2, 2, 31, 128, device=device, dtype=torch_type) * std + past_v = torch.randn_like(past_k) * std + + new_k = torch.randn(2, 1, 2, 128, device=device, dtype=torch_type) * std + new_v = torch.randn_like(new_k) * std + + total_seq_len = 32 # past(31) + new(1), satisfies MEA bias alignment (32 % 4 == 0) + + # Create additive mask with padding pattern: batch 0 has 28 valid past, batch 1 full + past_seqlens = torch.tensor([28, 31], dtype=torch.int32, device=device) + total_seqlens = past_seqlens + config.kv_sequence_length + + attn_mask = create_additive_mask_from_seqlens( + seqlens=total_seqlens, + total_seq_len=total_seq_len, + mask_dims=4, + q_seq_len=1, + num_heads=8, + device=device, + dtype=torch_type, + ) + + # Zero padded past positions for batch 0 + past_k[0, :, 28:, :] = 0 + past_v[0, :, 28:, :] = 0 + + # Reference: concat past + new, then compute attention + new_k_bnsh = new_k.transpose(1, 2) + new_v_bnsh = new_v.transpose(1, 2) + full_k_bnsh = torch.cat([past_k, new_k_bnsh], dim=2) + full_v_bnsh = torch.cat([past_v, new_v_bnsh], dim=2) + full_k_bsnh = full_k_bnsh.transpose(1, 2) + full_v_bsnh = full_v_bnsh.transpose(1, 2) + + # Expand 4D mask to reference attn_bias [batch, heads, q_seq, total_seq] + attn_bias_ref = attn_mask + out_ref, _ = attention_ref(q=q, k=full_k_bsnh, v=full_v_bsnh, attn_bias=attn_bias_ref, causal=False) + + # ORT path + out_ort, present_k, present_v = attention_past_func( + q=q, + past_k=past_k, + past_v=past_v, + new_k=new_k, + new_v=new_v, + config=config, + attn_mask=attn_mask, + ep="CUDAExecutionProvider", + device=device, + ort_type=TensorProto.FLOAT16, + ) + + out_ort = out_ort.reshape(2, 1, 8, 128) + + # --- Verify present_k/v match concatenated reference --- + full_k_ref_np = full_k_bnsh.float().detach().cpu().numpy() + full_v_ref_np = full_v_bnsh.float().detach().cpu().numpy() + present_k_np = present_k.float().detach().cpu().numpy() + present_v_np = present_v.float().detach().cpu().numpy() + + print_diff_statistics(torch.tensor(present_k_np - full_k_ref_np), "present_k") + numpy.testing.assert_allclose(present_k_np, full_k_ref_np, rtol=rtol["fp16"], atol=atol["fp16"]) + print_diff_statistics(torch.tensor(present_v_np - full_v_ref_np), "present_v") + numpy.testing.assert_allclose(present_v_np, full_v_ref_np, rtol=rtol["fp16"], atol=atol["fp16"]) + + # --- Verify output --- + out_np = out_ort.float().detach().cpu().numpy() + out_ref_np = out_ref.float().detach().cpu().numpy() + print_diff_statistics(torch.tensor(out_np - out_ref_np), "out") + numpy.testing.assert_allclose(out_np, out_ref_np, rtol=rtol["fp16"], atol=atol["fp16"]) + + +@unittest.skipIf(not has_cuda_device(53), "Memory Efficient Attention is not available, skipping tests.") +@patch.dict(os.environ, {"ORT_DISABLE_FLASH_ATTENTION": "1"}) +class TestONNXAttentionMEAGQASoftcap(unittest.TestCase): + """ + Test softcap support for GQA via the Memory Efficient Attention path. + + Disables Flash Attention to force MEA. Verifies softcap with and without + attention mask for GQA (kv_num_heads != q_num_heads). + + MEA alignment requirement: total_seq % 4 == 0 when attn_mask is present. + """ + + def test_mea_gqa_softcap_with_mask_prompt_fp16(self): + """MEA GQA softcap + causal mask, prompt phase, fp16.""" + config = AttentionConfig( + batch_size=2, + q_sequence_length=8, + kv_sequence_length=8, # total_seq=8, divisible by 4 + q_num_heads=8, + kv_num_heads=4, + head_size=64, + is_causal=1, + softcap=50.0, + has_attn_mask=True, + ) + parity_check_gqa_prompt( + config=config, + ep="CUDAExecutionProvider", + device="cuda", + torch_type=torch.float16, + ort_type=TensorProto.FLOAT16, + causal=True, + rtol=rtol["fp16"], + atol=atol["fp16"], + ) + + def test_mea_gqa_softcap_no_mask_prompt_fp16(self): + """MEA GQA softcap without explicit mask, prompt phase, fp16.""" + config = AttentionConfig( + batch_size=2, + q_sequence_length=8, + kv_sequence_length=8, + q_num_heads=8, + kv_num_heads=4, + head_size=64, + is_causal=1, + softcap=50.0, + ) + parity_check_gqa_prompt( + config=config, + ep="CUDAExecutionProvider", + device="cuda", + torch_type=torch.float16, + ort_type=TensorProto.FLOAT16, + causal=True, + rtol=rtol["fp16"], + atol=atol["fp16"], + ) + + def test_mea_gqa_softcap_with_mask_decode_fp16(self): + """MEA GQA softcap + causal mask, decode phase, fp16.""" + config = AttentionConfig( + batch_size=2, + q_sequence_length=1, + kv_sequence_length=1, + past_kv_sequence_length=31, # total_seq=32, divisible by 4 + q_num_heads=8, + kv_num_heads=4, + head_size=64, + is_causal=1, + softcap=50.0, + has_attn_mask=True, + ) + parity_check_gqa_past( + config=config, + ep="CUDAExecutionProvider", + device="cuda", + torch_type=torch.float16, + ort_type=TensorProto.FLOAT16, + causal=True, + rtol=rtol["fp16"], + atol=atol["fp16"], + ) + + def test_mea_gqa_softcap_mask_ordering_no_leakage_prompt_fp16(self): + """Guard test: verify MEA GQA softcap + mask ordering prevents attention leakage. + + Same poison-value technique as the MHA ordering test, but with GQA + (kv_num_heads != q_num_heads) forced to MEA path. + """ + batch_size = 1 + q_seq = 4 + kv_seq = 8 # divisible by 4 for MEA alignment + q_num_heads = 4 + kv_num_heads = 2 + head_size = 64 + softcap_val = 2.0 + valid_kv_len = 4 + + config = AttentionConfig( + batch_size=batch_size, + q_sequence_length=q_seq, + kv_sequence_length=kv_seq, + q_num_heads=q_num_heads, + kv_num_heads=kv_num_heads, + head_size=head_size, + is_causal=0, + softcap=softcap_val, + has_attn_mask=True, + attn_mask_dims=4, + attn_mask_type="additive", + ) + + torch.manual_seed(42) + device = "cuda" + torch_type = torch.float16 + + q = torch.randn(batch_size, q_seq, q_num_heads, head_size, dtype=torch_type, device=device) * 0.2 + k = torch.randn(batch_size, kv_seq, kv_num_heads, head_size, dtype=torch_type, device=device) * 0.2 + v = torch.randn(batch_size, kv_seq, kv_num_heads, head_size, dtype=torch_type, device=device) * 0.2 + + # Place poison values in V at masked positions + poison_value = 1000.0 + v[:, valid_kv_len:, :, :] = poison_value + + # Create additive mask: 0.0 for valid, -inf for masked + # 4D mask: [batch, q_num_heads, q_seq, kv_seq] + attn_mask = torch.zeros(batch_size, q_num_heads, q_seq, kv_seq, dtype=torch_type, device=device) + attn_mask[:, :, :, valid_kv_len:] = float("-inf") + + out, _, _ = attention_prompt_func( + q=q, + k=k, + v=v, + config=config, + attn_mask=attn_mask, + ep="CUDAExecutionProvider", + device=device, + ort_type=TensorProto.FLOAT16, + ) + + out_np = out.to(torch.float32).detach().cpu().numpy().flatten() + max_abs = numpy.max(numpy.abs(out_np)) + self.assertLess( + max_abs, + 50.0, + f"MEA GQA attention leakage detected: max |output| = {max_abs:.1f}. " + f"This likely means MEA applies softcap AFTER mask (wrong ordering). " + f"Correct ordering: QK → softcap → mask → softmax (per onnx/onnx#7865).", + ) + + # Also verify against reference + out_ref, _ = attention_ref(q=q, k=k, v=v, attn_bias=attn_mask, softcap=softcap_val) + out_ref_np = out_ref.to(torch.float32).detach().cpu().numpy() + out_reshaped = torch.reshape(out, (batch_size, q_seq, q_num_heads, head_size)) + out_reshaped_np = out_reshaped.to(torch.float32).detach().cpu().numpy() + numpy.testing.assert_allclose(out_reshaped_np, out_ref_np, rtol=0.02, atol=0.02) + + +@unittest.skipIf(not has_flash_attention(), "Flash Attention is not available, skipping Flash GQA softcap tests.") +@patch.dict(os.environ, {"ORT_DISABLE_FLASH_ATTENTION": "0"}) +class TestONNXAttentionFlashGQASoftcap(unittest.TestCase): + """Test softcap support for GQA via the Flash Attention path. + + Flash does NOT accept explicit attn_mask for GQA — uses nonpad_kv_seqlen + (padding mask) instead. Tests verify softcap works correctly through Flash + with and without padding mask. + + Requires SM80+ (Flash Attention hardware requirement). + """ + + def test_flash_gqa_softcap_with_padding_mask_prompt_fp16(self): + """Flash GQA softcap + padding mask (nonpad_kv_seqlen), prompt phase, fp16.""" + config = AttentionConfig( + batch_size=2, + q_sequence_length=8, + kv_sequence_length=8, + q_num_heads=8, + kv_num_heads=4, + head_size=64, + is_causal=1, + softcap=50.0, + has_nonpad_kv_seqlen=True, + ) + parity_check_gqa_prompt( + config=config, + ep="CUDAExecutionProvider", + device="cuda", + torch_type=torch.float16, + ort_type=TensorProto.FLOAT16, + causal=True, + rtol=rtol["fp16"], + atol=atol["fp16"], + ) + + def test_flash_gqa_softcap_no_mask_prompt_fp16(self): + """Flash GQA softcap without any mask, prompt phase, fp16.""" + config = AttentionConfig( + batch_size=2, + q_sequence_length=8, + kv_sequence_length=8, + q_num_heads=8, + kv_num_heads=4, + head_size=64, + is_causal=1, + softcap=50.0, + ) + parity_check_gqa_prompt( + config=config, + ep="CUDAExecutionProvider", + device="cuda", + torch_type=torch.float16, + ort_type=TensorProto.FLOAT16, + causal=True, + rtol=rtol["fp16"], + atol=atol["fp16"], + ) + + def test_flash_gqa_softcap_no_mask_decode_fp16(self): + """Flash GQA softcap, decode phase (past KV), fp16.""" + config = AttentionConfig( + batch_size=2, + q_sequence_length=1, + kv_sequence_length=1, + past_kv_sequence_length=31, + q_num_heads=8, + kv_num_heads=4, + head_size=64, + is_causal=1, + softcap=50.0, + ) + parity_check_gqa_past( + config=config, + ep="CUDAExecutionProvider", + device="cuda", + torch_type=torch.float16, + ort_type=TensorProto.FLOAT16, + causal=True, + rtol=rtol["fp16"], + atol=atol["fp16"], + ) + + if __name__ == "__main__": unittest.main() diff --git a/onnxruntime/test/python/transformers/test_onnx_attention/test_mha.py b/onnxruntime/test/python/transformers/test_onnx_attention/test_mha.py index abe180ee35787..a488e11e39d20 100644 --- a/onnxruntime/test/python/transformers/test_onnx_attention/test_mha.py +++ b/onnxruntime/test/python/transformers/test_onnx_attention/test_mha.py @@ -99,9 +99,14 @@ def parity_check_mha_prompt( attn_mask = None attn_bias_ref = None if config.has_attn_mask: - # Create additive mask (0 for valid, -inf for masked) - # For prompt without padding, create a causal-style or zero mask - seqlens = torch.full((config.batch_size,), config.kv_sequence_length, dtype=torch.int32, device=device) + # When softcap is present, use partial seqlens so the mask has both valid and masked + # positions — otherwise the all-zero mask can't detect softcap→bias ordering bugs. + # For non-softcap tests, use full seqlens (existing behavior). + if config.softcap > 0: + mask_valid_len = max(1, config.kv_sequence_length * 3 // 4) + else: + mask_valid_len = config.kv_sequence_length + seqlens = torch.full((config.batch_size,), mask_valid_len, dtype=torch.int32, device=device) attn_mask = create_additive_mask_from_seqlens( seqlens=seqlens, total_seq_len=config.kv_sequence_length, @@ -127,6 +132,7 @@ def parity_check_mha_prompt( v=v, attn_bias=attn_bias_ref, causal=causal, + softcap=config.softcap, ) out_ref_np = out_ref.to(torch.float32).detach().cpu().numpy() @@ -146,9 +152,15 @@ def parity_check_mha_prompt( if i == 0: first_out = out.clone() else: - torch.testing.assert_close(out, first_out, rtol=0, atol=0, msg="Output mismatch between two runs") + # FP16/BF16 GPU kernels may produce bit-level non-determinism across runs. + det_atol = 0 if torch_type == torch.float32 else 1e-3 + det_rtol = 0 if torch_type == torch.float32 else 1e-3 + torch.testing.assert_close( + out, first_out, rtol=det_rtol, atol=det_atol, msg="Output mismatch between two runs" + ) - out = torch.reshape(out, (config.batch_size, config.q_sequence_length, config.q_num_heads, config.head_size)) + effective_v_head_size = config.v_head_size or config.head_size + out = torch.reshape(out, (config.batch_size, config.q_sequence_length, config.q_num_heads, effective_v_head_size)) out_np = out.to(torch.float32).detach().cpu().numpy() # --- Comparison --- @@ -224,6 +236,65 @@ def parity_check_mha_past( ) new_v = torch.randn_like(new_k) * std + # Create attention mask if config requires one + total_seq_len = config.past_kv_sequence_length + config.kv_sequence_length + attn_mask = None + attn_bias_ref = None + if config.has_attn_mask: + # When softcap is present, use partial seqlens so the mask has both valid and masked + # positions — otherwise the all-zero mask can't detect softcap→bias ordering bugs. + # For non-softcap tests, use full seqlens (existing behavior). + if config.softcap > 0: + mask_valid_len = max(1, total_seq_len * 3 // 4) + else: + mask_valid_len = total_seq_len + seqlens = torch.full((config.batch_size,), mask_valid_len, dtype=torch.int32, device=device) + + if config.attn_mask_type == "bool": + # Create boolean mask for ORT (True=attend, False=mask) + arange = torch.arange(total_seq_len, device=device) + if config.attn_mask_dims == 2: + mask_1d = arange < seqlens[0] + attn_mask = mask_1d.unsqueeze(0).expand(config.q_sequence_length, -1).contiguous() + else: + attn_mask = create_boolean_mask_from_seqlens( + seqlens=seqlens, + total_seq_len=total_seq_len, + mask_dims=config.attn_mask_dims, + q_seq_len=config.q_sequence_length, + num_heads=config.q_num_heads, + device=device, + ) + # Create additive bias for PyTorch reference path + attn_bias_ref = create_additive_mask_from_seqlens( + seqlens=seqlens, + total_seq_len=total_seq_len, + mask_dims=4, + q_seq_len=config.q_sequence_length, + num_heads=config.q_num_heads, + device=device, + dtype=torch_type, + ) + else: + # Additive mask: same tensor for both ORT and reference + attn_mask = create_additive_mask_from_seqlens( + seqlens=seqlens, + total_seq_len=total_seq_len, + mask_dims=config.attn_mask_dims, + q_seq_len=config.q_sequence_length, + num_heads=config.q_num_heads, + device=device, + dtype=torch_type, + ) + if config.attn_mask_dims == 2: + attn_bias_ref = ( + attn_mask.unsqueeze(0).unsqueeze(0).expand(config.batch_size, config.q_num_heads, -1, -1) + ) + elif config.attn_mask_dims == 3: + attn_bias_ref = attn_mask.unsqueeze(0).expand(config.batch_size, -1, -1, -1) + else: + attn_bias_ref = attn_mask + # --- PyTorch Reference Path --- new_k_bnsh = new_k.transpose(1, 2) new_v_bnsh = new_v.transpose(1, 2) @@ -236,7 +307,9 @@ def parity_check_mha_past( q=q, k=full_k_bsnh, v=full_v_bsnh, + attn_bias=attn_bias_ref, causal=causal, + softcap=config.softcap, ) out_ref_np = out_ref.to(torch.float32).detach().cpu().numpy() @@ -250,7 +323,7 @@ def parity_check_mha_past( new_k=new_k, new_v=new_v, config=config, - attn_mask=None, + attn_mask=attn_mask, ep=ep, device=device, ort_type=ort_type, @@ -258,9 +331,15 @@ def parity_check_mha_past( if i == 0: first_out = out.clone() else: - torch.testing.assert_close(out, first_out, rtol=0, atol=0, msg="Output mismatch between two runs") + # FP16/BF16 GPU kernels may produce bit-level non-determinism across runs. + det_atol = 0 if torch_type == torch.float32 else 1e-3 + det_rtol = 0 if torch_type == torch.float32 else 1e-3 + torch.testing.assert_close( + out, first_out, rtol=det_rtol, atol=det_atol, msg="Output mismatch between two runs" + ) - out = torch.reshape(out, (config.batch_size, config.q_sequence_length, config.q_num_heads, config.head_size)) + effective_v_head_size = config.v_head_size or config.head_size + out = torch.reshape(out, (config.batch_size, config.q_sequence_length, config.q_num_heads, effective_v_head_size)) out_np = out.to(torch.float32).detach().cpu().numpy() # --- Comparison --- @@ -367,10 +446,11 @@ def parity_check_mha_prompt_with_attn_bias( v=v, attn_bias=attn_bias_ref, causal=config.is_causal == 1, + softcap=config.softcap, ) # --- ONNX Runtime Path --- - out, present_k, present_v = attention_prompt_func( + out, _present_k, _present_v = attention_prompt_func( q=q, k=k, v=v, @@ -698,10 +778,11 @@ def parity_check_mha_prompt_with_bool_mask( v=v, key_padding_mask=key_padding_mask, causal=config.is_causal == 1, + softcap=config.softcap, ) # --- ONNX Runtime Path --- - out, present_k, present_v = attention_prompt_func( + out, _present_k, _present_v = attention_prompt_func( q=q, k=k, v=v, @@ -866,6 +947,110 @@ def test_mha_past_fp32(self, name, config): ) +@unittest.skipIf(not has_cuda_device(53), "Memory Efficient Attention is not available, skipping tests.") +@patch.dict(os.environ, {"ORT_DISABLE_FLASH_ATTENTION": "1"}) +class TestONNXAttentionMHAPastMEA(unittest.TestCase): + """Test ONNX Attention op MHA path — decoding with KV cache via Memory Efficient Attention. + + Explicitly forces MEA by disabling Flash Attention. This verifies that the + MEA decode path works correctly for MHA (kv_num_heads == q_num_heads). + """ + + @parameterized.expand(mha_past_test_cases()) + def test_mha_past_mea(self, name, config): + parity_check_mha_past( + config=config, + ep="CUDAExecutionProvider", + device="cuda", + torch_type=torch.float16, + ort_type=TensorProto.FLOAT16, + causal=True, + rtol=rtol["fp16"], + atol=atol["fp16"], + ) + + +@unittest.skipIf(not has_cuda_device(53), "Memory Efficient Attention is not available, skipping tests.") +@patch.dict(os.environ, {"ORT_DISABLE_FLASH_ATTENTION": "1"}) +class TestONNXAttentionMHAPastMEAFP32(unittest.TestCase): + """Test MHA decode via MEA with fp32 dtype.""" + + @parameterized.expand(mha_past_test_cases()) + def test_mha_past_mea_fp32(self, name, config): + config.kv_cache_type = "float32" + parity_check_mha_past( + config=config, + ep="CUDAExecutionProvider", + device="cuda", + torch_type=torch.float32, + ort_type=TensorProto.FLOAT, + causal=True, + rtol=rtol["fp32"], + atol=atol["fp32"], + ) + + +@unittest.skipIf(not has_cuda_device(53), "Memory Efficient Attention is not available, skipping tests.") +@patch.dict(os.environ, {"ORT_DISABLE_FLASH_ATTENTION": "1"}) +class TestONNXAttentionMHAPastMEABoolMask(unittest.TestCase): + """Test MHA decode via MEA with boolean attention mask (converted to additive bias).""" + + def test_mha_past_bool_mask_mea(self): + config = AttentionConfig( + batch_size=2, + q_sequence_length=1, + kv_sequence_length=1, + past_kv_sequence_length=31, # 31+1=32, divisible by 4 (CUTLASS bias alignment) + q_num_heads=4, + kv_num_heads=4, + head_size=64, + is_causal=1, + has_attn_mask=True, + attn_mask_dims=2, + ) + parity_check_mha_past( + config=config, + ep="CUDAExecutionProvider", + device="cuda", + torch_type=torch.float16, + ort_type=TensorProto.FLOAT16, + causal=True, + rtol=rtol["fp16"], + atol=atol["fp16"], + ) + + +@unittest.skipIf(not has_cuda_device(53), "Memory Efficient Attention is not available, skipping tests.") +@patch.dict(os.environ, {"ORT_DISABLE_FLASH_ATTENTION": "1"}) +class TestONNXAttentionMHAPastMEAFloatMask(unittest.TestCase): + """Test MHA decode via MEA with float additive attention mask.""" + + def test_mha_past_float_mask_4d_mea(self): + config = AttentionConfig( + batch_size=2, + q_sequence_length=1, + kv_sequence_length=1, + past_kv_sequence_length=31, # 31+1=32, divisible by 4 (CUTLASS bias alignment) + q_num_heads=4, + kv_num_heads=4, + head_size=64, + is_causal=1, + has_attn_mask=True, + attn_mask_dims=4, + attn_mask_type="additive", + ) + parity_check_mha_past( + config=config, + ep="CUDAExecutionProvider", + device="cuda", + torch_type=torch.float16, + ort_type=TensorProto.FLOAT16, + causal=True, + rtol=rtol["fp16"], + atol=atol["fp16"], + ) + + @unittest.skipIf(not has_cuda_device(53), "CUDA device not available, skipping MHA tests.") class TestONNXAttentionMHAAttnBias(unittest.TestCase): """ @@ -998,7 +1183,7 @@ def parity_check_mha_prompt_with_nonpad_kv_seqlen( # ORT path: use nonpad_kv_seqlen (int64 tensor) nonpad_kv_seqlen_tensor = nonpad_seqlens.to(torch.int64).to(device) - out, present_k, present_v = attention_prompt_func( + out, _present_k, _present_v = attention_prompt_func( q=q, k=k, v=v, @@ -1249,116 +1434,693 @@ def test_mha_unfused_fp16(self, name, config): atol=atol["fp16"], ) - -# ################################################################################################# -# Broadcast Mask (1,1,q,kv) Tests -# ################################################################################################# + def test_mha_unfused_decode_fp32(self): + """Test unfused decode with fp32 (both Flash and MEA disabled).""" + config = AttentionConfig( + batch_size=2, + q_sequence_length=1, + kv_sequence_length=1, + past_kv_sequence_length=32, + q_num_heads=4, + kv_num_heads=4, + head_size=64, + is_causal=1, + kv_cache_type="float32", + attn_mask_type="additive", + ) + parity_check_mha_past( + config=config, + ep="CUDAExecutionProvider", + device="cuda", + torch_type=torch.float32, + ort_type=TensorProto.FLOAT, + causal=True, + rtol=rtol["fp32"], + atol=atol["fp32"], + ) -@unittest.skipIf(not has_cuda_device(53), "CUDA device not available, skipping broadcast mask tests.") -class TestONNXAttentionMHABroadcastMask(unittest.TestCase): +@unittest.skipIf(not has_cuda_device(53), "CUDA device not available, skipping unfused softcap tests.") +@patch.dict(os.environ, {"ORT_DISABLE_FLASH_ATTENTION": "1", "ORT_DISABLE_MEMORY_EFFICIENT_ATTENTION": "1"}) +class TestONNXAttentionMHAUnfusedSoftcap(unittest.TestCase): """ - Test attention with a (1,1,q_seq,kv_seq) mask that broadcasts across batch and heads. + Test softcap support in the unfused attention kernel. - This is a 4D mask with dim_0=1 (batch) and dim_1=1 (heads), verifying that - the broadcast_attn_bias_dim_0 and broadcast_attn_bias_dim_1 flags work correctly. + Disables Flash and MEA to force the unfused path. Verifies that + softcap * tanh(score / softcap) is correctly applied to attention logits + before softmax, matching the reference implementation. """ - def test_mha_broadcast_mask_additive(self): - """Test broadcast additive mask (1,1,q,kv) with MHA on CUDA.""" + def test_unfused_softcap_prompt_fp16(self): + """Test softcap on unfused path during prompt (fp16).""" config = AttentionConfig( batch_size=2, - q_sequence_length=16, - kv_sequence_length=16, - q_num_heads=8, - kv_num_heads=8, - head_size=128, - is_causal=0, - has_attn_mask=True, - attn_mask_dims=4, + q_sequence_length=8, + kv_sequence_length=8, + q_num_heads=4, + kv_num_heads=4, + head_size=64, + is_causal=1, + softcap=2.0, attn_mask_type="additive", - broadcast_mask_batch=True, - broadcast_mask_heads=True, ) - - torch.manual_seed(0) - device = "cuda" - torch_type = torch.float16 - - q = torch.randn(2, 16, 8, 128, device=device, dtype=torch_type) * 0.2 - k = torch.randn(2, 16, 8, 128, device=device, dtype=torch_type) * 0.2 - v = torch.randn_like(k) * 0.2 - - # Create (1,1,q,kv) additive mask: lower-triangular causal pattern - mask_filter = float(torch.finfo(torch_type).min) - mask_2d = torch.zeros(16, 16, device=device, dtype=torch_type) - for i in range(16): - mask_2d[i, i + 1 :] = mask_filter - attn_mask = mask_2d.unsqueeze(0).unsqueeze(0) # (1, 1, 16, 16) - - # Reference: expand to full (B, H, Q, K) - attn_bias_ref = attn_mask.expand(2, 8, -1, -1).contiguous() - out_ref, _ = attention_ref(q=q, k=k, v=v, attn_bias=attn_bias_ref, causal=False) - - # ORT path - out_ort, _, _ = attention_prompt_func( - q=q, - k=k, - v=v, + parity_check_mha_prompt( config=config, - attn_mask=attn_mask, ep="CUDAExecutionProvider", - device=device, + device="cuda", + torch_type=torch.float16, ort_type=TensorProto.FLOAT16, + causal=True, + rtol=rtol["fp16"], + atol=atol["fp16"], ) - out_ort = out_ort.reshape(2, 16, 8, 128) - - out_np = out_ort.float().detach().cpu().numpy() - out_ref_np = out_ref.float().detach().cpu().numpy() - numpy.testing.assert_allclose(out_np, out_ref_np, rtol=rtol["fp16"], atol=atol["fp16"]) - - -# ################################################################################################# -# 2D Mask Broadcast Regression Test -# ################################################################################################# + def test_unfused_softcap_decode_fp16(self): + """Test softcap on unfused path during decode (fp16).""" + config = AttentionConfig( + batch_size=2, + q_sequence_length=1, + kv_sequence_length=1, + past_kv_sequence_length=32, + q_num_heads=4, + kv_num_heads=4, + head_size=64, + is_causal=1, + softcap=2.0, + attn_mask_type="additive", + ) + parity_check_mha_past( + config=config, + ep="CUDAExecutionProvider", + device="cuda", + torch_type=torch.float16, + ort_type=TensorProto.FLOAT16, + causal=True, + rtol=rtol["fp16"], + atol=atol["fp16"], + ) -@unittest.skipIf(not has_cuda_device(53), "CUDA device not available, skipping 2D mask broadcast tests.") -class TestONNXAttentionMHA2DMaskBroadcast(unittest.TestCase): - """ - Regression test for 2D mask [q_seq, total_seq] broadcast correctness. - - Per ONNX spec, a 2D attention mask has shape [q_seq, total_seq] and broadcasts - over batch and heads. This test uses batch_size > q_seq with a non-uniform - mask (different values per row) to verify correct broadcast behavior. - - The old bug indexed the 2D mask by batch index instead of query position, - causing OOB reads when batch_size > q_seq. - """ + def test_unfused_softcap_prompt_fp32(self): + """Test softcap on unfused path during prompt (fp32).""" + config = AttentionConfig( + batch_size=2, + q_sequence_length=8, + kv_sequence_length=8, + q_num_heads=4, + kv_num_heads=4, + head_size=64, + is_causal=1, + softcap=2.0, + kv_cache_type="float32", + attn_mask_type="additive", + ) + parity_check_mha_prompt( + config=config, + ep="CUDAExecutionProvider", + device="cuda", + torch_type=torch.float32, + ort_type=TensorProto.FLOAT, + causal=True, + rtol=rtol["fp32"], + atol=atol["fp32"], + ) - def test_2d_additive_mask_batch_gt_qseq(self): - """2D additive mask [q_seq=2, total_seq=8] with batch=4 — would OOB on old code.""" + def test_unfused_softcap_with_mask_prompt_fp16(self): + """Test softcap + float mask on unfused path — verifies spec-correct ordering (softcap→mask→softmax).""" config = AttentionConfig( - batch_size=4, - q_sequence_length=2, + batch_size=2, + q_sequence_length=8, kv_sequence_length=8, q_num_heads=4, kv_num_heads=4, head_size=64, is_causal=0, + softcap=2.0, has_attn_mask=True, - attn_mask_dims=2, + attn_mask_dims=4, attn_mask_type="additive", ) + parity_check_mha_prompt( + config=config, + ep="CUDAExecutionProvider", + device="cuda", + torch_type=torch.float16, + ort_type=TensorProto.FLOAT16, + causal=False, + rtol=rtol["fp16"], + atol=atol["fp16"], + ) - torch.manual_seed(42) - device = "cuda" - torch_type = torch.float16 - mask_filter_value = torch.finfo(torch_type).min - - q = ( - torch.randn( - config.batch_size, + def test_unfused_softcap_with_mask_decode_fp16(self): + """Test softcap + float mask on unfused decode — verifies spec-correct ordering.""" + config = AttentionConfig( + batch_size=2, + q_sequence_length=1, + kv_sequence_length=1, + past_kv_sequence_length=31, # 31+1=32, divisible by 4 + q_num_heads=4, + kv_num_heads=4, + head_size=64, + is_causal=1, + softcap=2.0, + has_attn_mask=True, + attn_mask_dims=4, + attn_mask_type="additive", + ) + parity_check_mha_past( + config=config, + ep="CUDAExecutionProvider", + device="cuda", + torch_type=torch.float16, + ort_type=TensorProto.FLOAT16, + causal=True, + rtol=rtol["fp16"], + atol=atol["fp16"], + ) + + # --- Partial masking: fp32 variants --- + + def test_unfused_softcap_with_mask_prompt_fp32(self): + """Test softcap + additive mask on unfused prompt (fp32). + + The helper auto-creates a partial mask (3/4 valid positions) when softcap > 0, + ensuring the mask has both 0.0 and -inf values to exercise the softcap→bias ordering. + """ + config = AttentionConfig( + batch_size=2, + q_sequence_length=8, + kv_sequence_length=8, + q_num_heads=4, + kv_num_heads=4, + head_size=64, + is_causal=0, + softcap=2.0, + has_attn_mask=True, + attn_mask_dims=4, + attn_mask_type="additive", + kv_cache_type="float32", + ) + parity_check_mha_prompt( + config=config, + ep="CUDAExecutionProvider", + device="cuda", + torch_type=torch.float32, + ort_type=TensorProto.FLOAT, + causal=False, + rtol=rtol["fp32"], + atol=atol["fp32"], + ) + + def test_unfused_softcap_with_mask_decode_fp32(self): + """Test softcap + additive mask on unfused decode (fp32). + + Decode with past KV cache: total_seq=32, ~24 valid positions, 8 masked. + """ + config = AttentionConfig( + batch_size=2, + q_sequence_length=1, + kv_sequence_length=1, + past_kv_sequence_length=31, + q_num_heads=4, + kv_num_heads=4, + head_size=64, + is_causal=1, + softcap=2.0, + has_attn_mask=True, + attn_mask_dims=4, + attn_mask_type="additive", + kv_cache_type="float32", + ) + parity_check_mha_past( + config=config, + ep="CUDAExecutionProvider", + device="cuda", + torch_type=torch.float32, + ort_type=TensorProto.FLOAT, + causal=True, + rtol=rtol["fp32"], + atol=atol["fp32"], + ) + + # --- Partial masking: different mask dimensionalities --- + + def test_unfused_softcap_with_mask_2d_prompt_fp16(self): + """Test softcap + 2D additive mask on unfused prompt. + + A 2D mask [q_seq, kv_seq] broadcasts across batch and heads. + This tests the 2D mask indexing path in the unfused kernel. + """ + config = AttentionConfig( + batch_size=2, + q_sequence_length=8, + kv_sequence_length=8, + q_num_heads=4, + kv_num_heads=4, + head_size=64, + is_causal=0, + softcap=2.0, + has_attn_mask=True, + attn_mask_dims=2, + attn_mask_type="additive", + ) + parity_check_mha_prompt( + config=config, + ep="CUDAExecutionProvider", + device="cuda", + torch_type=torch.float16, + ort_type=TensorProto.FLOAT16, + causal=False, + rtol=rtol["fp16"], + atol=atol["fp16"], + ) + + def test_unfused_softcap_with_mask_3d_prompt_fp16(self): + """Test softcap + 3D additive mask on unfused prompt. + + A 3D mask [heads, q_seq, kv_seq] broadcasts across batch dimension. + This tests the 3D mask broadcast path which has its own handling branch. + """ + config = AttentionConfig( + batch_size=2, + q_sequence_length=8, + kv_sequence_length=8, + q_num_heads=4, + kv_num_heads=4, + head_size=64, + is_causal=0, + softcap=2.0, + has_attn_mask=True, + attn_mask_dims=3, + attn_mask_type="additive", + ) + parity_check_mha_prompt( + config=config, + ep="CUDAExecutionProvider", + device="cuda", + torch_type=torch.float16, + ort_type=TensorProto.FLOAT16, + causal=False, + rtol=rtol["fp16"], + atol=atol["fp16"], + ) + + # --- Partial masking: larger sequence (different absolute mask boundary) --- + + def test_unfused_softcap_with_mask_longer_seq_prompt_fp16(self): + """Test softcap + mask with a longer sequence (kv_seq=16). + + With kv_seq=16, mask_valid_len=12 (3/4). This exercises a different absolute + mask boundary compared to the kv_seq=8 tests (valid_len=6) and provides + a wider range of softcapped logit values interacting with the mask. + """ + config = AttentionConfig( + batch_size=2, + q_sequence_length=16, + kv_sequence_length=16, + q_num_heads=4, + kv_num_heads=4, + head_size=64, + is_causal=0, + softcap=2.0, + has_attn_mask=True, + attn_mask_dims=4, + attn_mask_type="additive", + ) + parity_check_mha_prompt( + config=config, + ep="CUDAExecutionProvider", + device="cuda", + torch_type=torch.float16, + ort_type=TensorProto.FLOAT16, + causal=False, + rtol=rtol["fp16"], + atol=atol["fp16"], + ) + + def test_softcap_mask_ordering_no_leakage_prompt(self): + """Guard test: verify softcap + mask ordering prevents attention leakage. + + This test PROVES the ordering matters and would FAIL if someone reverts + to the wrong ordering (mask before softcap). + + Setup: Create a mask where some KV positions are -inf (masked). Place + a distinctive 'poison' value (1000.0) in V at masked positions. With + correct ordering (softcap → mask → softmax), masked positions get + -inf after bias addition → zero attention → output uncontaminated. + With wrong ordering (mask → softcap → softmax), softcap(-inf) = -softcap + (finite) → nonzero attention → output contaminated by poison values. + """ + batch_size = 1 + q_seq = 4 + kv_seq = 8 + num_heads = 2 + head_size = 64 + softcap_val = 2.0 + # Only the first 4 KV positions are valid; last 4 are masked (-inf) + valid_kv_len = 4 + + config = AttentionConfig( + batch_size=batch_size, + q_sequence_length=q_seq, + kv_sequence_length=kv_seq, + q_num_heads=num_heads, + kv_num_heads=num_heads, + head_size=head_size, + is_causal=0, + softcap=softcap_val, + has_attn_mask=True, + attn_mask_dims=4, + attn_mask_type="additive", + ) + + torch.manual_seed(42) + device = "cuda" + torch_type = torch.float32 + + q = torch.randn(batch_size, q_seq, num_heads, head_size, dtype=torch_type, device=device) * 0.2 + k = torch.randn(batch_size, kv_seq, num_heads, head_size, dtype=torch_type, device=device) * 0.2 + v = torch.randn(batch_size, kv_seq, num_heads, head_size, dtype=torch_type, device=device) * 0.2 + + # Place poison values in V at masked positions + poison_value = 1000.0 + v[:, valid_kv_len:, :, :] = poison_value + + # Create additive mask: 0.0 for valid, -inf for masked + attn_mask = torch.zeros(batch_size, num_heads, q_seq, kv_seq, dtype=torch_type, device=device) + attn_mask[:, :, :, valid_kv_len:] = float("-inf") + + # Run ONNX Runtime + out, _, _ = attention_prompt_func( + q=q, + k=k, + v=v, + config=config, + attn_mask=attn_mask, + ep="CUDAExecutionProvider", + device=device, + ort_type=TensorProto.FLOAT, + ) + + out_np = out.to(torch.float32).detach().cpu().numpy().flatten() + + # If ordering is wrong, poison values leak into output producing extreme values. + # Valid output range with std=0.2 inputs and softcap=2.0 is roughly [-10, 10]. + # Any element > 50 indicates attention leakage to the poison=1000 positions. + max_abs = numpy.max(numpy.abs(out_np)) + self.assertLess( + max_abs, + 50.0, + f"Attention leakage detected: max |output| = {max_abs:.1f}. " + f"This likely means softcap is applied AFTER mask (wrong ordering). " + f"Correct ordering: QK → softcap → mask → softmax (per onnx/onnx#7865).", + ) + + # Also verify against reference + out_ref, _ = attention_ref(q=q, k=k, v=v, attn_bias=attn_mask, softcap=softcap_val) + out_ref_np = out_ref.to(torch.float32).detach().cpu().numpy() + out_reshaped = torch.reshape(out, (batch_size, q_seq, num_heads, head_size)) + out_reshaped_np = out_reshaped.to(torch.float32).detach().cpu().numpy() + numpy.testing.assert_allclose(out_reshaped_np, out_ref_np, rtol=0.01, atol=0.01) + + def test_softcap_mask_ordering_no_leakage_decode(self): + """Guard test for decode (past KV) path: softcap + mask ordering prevents leakage. + + Same poison-value technique as the prompt test, but exercises the decode + code path with past KV cache. Masked positions in the past cache should + receive zero attention with correct ordering. + """ + batch_size = 1 + q_seq = 1 # decode: single token + kv_seq = 1 + past_kv_seq = 15 + num_heads = 2 + head_size = 64 + softcap_val = 2.0 + total_kv_seq = past_kv_seq + kv_seq # 16 total + valid_kv_len = 8 # Only first 8 of 16 positions are valid + + config = AttentionConfig( + batch_size=batch_size, + q_sequence_length=q_seq, + kv_sequence_length=kv_seq, + past_kv_sequence_length=past_kv_seq, + q_num_heads=num_heads, + kv_num_heads=num_heads, + head_size=head_size, + is_causal=0, + softcap=softcap_val, + has_attn_mask=True, + attn_mask_dims=4, + attn_mask_type="additive", + ) + + torch.manual_seed(42) + device = "cuda" + torch_type = torch.float32 + + q = torch.randn(batch_size, q_seq, num_heads, head_size, dtype=torch_type, device=device) * 0.2 + k = torch.randn(batch_size, kv_seq, num_heads, head_size, dtype=torch_type, device=device) * 0.2 + v = torch.randn(batch_size, kv_seq, num_heads, head_size, dtype=torch_type, device=device) * 0.2 + + # Past KV with poison in masked positions + past_k = torch.randn(batch_size, num_heads, past_kv_seq, head_size, dtype=torch_type, device=device) * 0.2 + past_v = torch.randn(batch_size, num_heads, past_kv_seq, head_size, dtype=torch_type, device=device) * 0.2 + poison_value = 1000.0 + past_v[:, :, valid_kv_len:, :] = poison_value + + # Mask: 0.0 for first valid_kv_len positions, -inf for rest + attn_mask = torch.zeros(batch_size, num_heads, q_seq, total_kv_seq, dtype=torch_type, device=device) + attn_mask[:, :, :, valid_kv_len:] = float("-inf") + + # Run ONNX Runtime via attention_past_func + out, _, _ = attention_past_func( + q=q, + past_k=past_k, + past_v=past_v, + new_k=k, + new_v=v, + config=config, + attn_mask=attn_mask, + ep="CUDAExecutionProvider", + device=device, + ort_type=TensorProto.FLOAT, + ) + + out_np = out.to(torch.float32).detach().cpu().numpy().flatten() + + max_abs = numpy.max(numpy.abs(out_np)) + self.assertLess( + max_abs, + 50.0, + f"Attention leakage detected in decode path: max |output| = {max_abs:.1f}. " + f"Softcap must be applied BEFORE mask (per onnx/onnx#7865).", + ) + + +# ################################################################################################# +# Asymmetric Head Size Regression Test (MEA → unfused fallback) +# ################################################################################################# + + +@unittest.skipIf(not has_cuda_device(53), "CUDA device not available, skipping asymmetric head size tests.") +@patch.dict(os.environ, {"ORT_DISABLE_FLASH_ATTENTION": "1"}) +class TestONNXAttentionMHAAsymmetricHeadSize(unittest.TestCase): + """ + Regression test: MEA gracefully falls back to unfused when head_size != v_head_size + with past_key present (decode phase). + + Without the eligibility guard in ComputeInternal, this configuration would select + MEA which then crashes with ORT_ENFORCE because LaunchConcatNewToPastKV requires + head_size == v_head_size. The guard skips MEA and falls back to unfused attention. + + Uses MHA path (kv_num_heads == q_num_heads) because the GQA path has no unfused + fallback (returns NOT_IMPLEMENTED). + """ + + def test_mha_past_asymmetric_v_head_size(self): + """Verify decode with head_size=128, v_head_size=96 doesn't crash (falls to unfused).""" + config = AttentionConfig( + batch_size=2, + q_sequence_length=1, + kv_sequence_length=1, + past_kv_sequence_length=32, + q_num_heads=4, + kv_num_heads=4, + head_size=128, + v_head_size=96, + is_causal=1, + attn_mask_type="additive", + ) + + torch.manual_seed(0) + device = "cuda" + torch_type = torch.float16 + # std=0.2 keeps values in a numerically stable range for fp16 attention + std = 0.2 + + q = torch.randn(2, 1, 4, 128, device=device, dtype=torch_type) * std + + # Past KV in BNSH: K uses head_size=128, V uses v_head_size=96 + past_k = torch.randn(2, 4, 32, 128, device=device, dtype=torch_type) * std + past_v = torch.randn(2, 4, 32, 96, device=device, dtype=torch_type) * std + + new_k = torch.randn(2, 1, 4, 128, device=device, dtype=torch_type) * std + new_v = torch.randn(2, 1, 4, 96, device=device, dtype=torch_type) * std + + # PyTorch reference: concat past + new, compute attention + new_k_bnsh = new_k.transpose(1, 2) + new_v_bnsh = new_v.transpose(1, 2) + full_k_bnsh = torch.cat([past_k, new_k_bnsh], dim=2) + full_v_bnsh = torch.cat([past_v, new_v_bnsh], dim=2) + full_k_bsnh = full_k_bnsh.transpose(1, 2) + full_v_bsnh = full_v_bnsh.transpose(1, 2) + + out_ref, _ = attention_ref(q=q, k=full_k_bsnh, v=full_v_bsnh, causal=True) + + # ORT path — should fall back to unfused (not crash in MEA) + out_ort, present_k, present_v = attention_past_func( + q=q, + past_k=past_k, + past_v=past_v, + new_k=new_k, + new_v=new_v, + config=config, + attn_mask=None, + ep="CUDAExecutionProvider", + device=device, + ort_type=TensorProto.FLOAT16, + ) + + # Reshape output: [B, q_seq, q_num_heads * v_head_size] → [B, q_seq, q_num_heads, v_head_size] + out_ort = out_ort.reshape(2, 1, 4, 96) + + # Verify present_k and present_v + full_k_ref_np = full_k_bnsh.float().detach().cpu().numpy() + full_v_ref_np = full_v_bnsh.float().detach().cpu().numpy() + present_k_np = present_k.float().detach().cpu().numpy() + present_v_np = present_v.float().detach().cpu().numpy() + + print_diff_statistics(torch.tensor(present_k_np - full_k_ref_np), "present_k") + numpy.testing.assert_allclose(present_k_np, full_k_ref_np, rtol=rtol["fp16"], atol=atol["fp16"]) + print_diff_statistics(torch.tensor(present_v_np - full_v_ref_np), "present_v") + numpy.testing.assert_allclose(present_v_np, full_v_ref_np, rtol=rtol["fp16"], atol=atol["fp16"]) + + # Verify output + out_np = out_ort.float().detach().cpu().numpy() + out_ref_np = out_ref.float().detach().cpu().numpy() + print_diff_statistics(torch.tensor(out_np - out_ref_np), "out") + numpy.testing.assert_allclose(out_np, out_ref_np, rtol=rtol["fp16"], atol=atol["fp16"]) + + +# ################################################################################################# +# Broadcast Mask (1,1,q,kv) Tests +# ################################################################################################# + + +@unittest.skipIf(not has_cuda_device(53), "CUDA device not available, skipping broadcast mask tests.") +class TestONNXAttentionMHABroadcastMask(unittest.TestCase): + """ + Test attention with a (1,1,q_seq,kv_seq) mask that broadcasts across batch and heads. + + This is a 4D mask with dim_0=1 (batch) and dim_1=1 (heads), verifying that + the broadcast_attn_bias_dim_0 and broadcast_attn_bias_dim_1 flags work correctly. + """ + + def test_mha_broadcast_mask_additive(self): + """Test broadcast additive mask (1,1,q,kv) with MHA on CUDA.""" + config = AttentionConfig( + batch_size=2, + q_sequence_length=16, + kv_sequence_length=16, + q_num_heads=8, + kv_num_heads=8, + head_size=128, + is_causal=0, + has_attn_mask=True, + attn_mask_dims=4, + attn_mask_type="additive", + broadcast_mask_batch=True, + broadcast_mask_heads=True, + ) + + torch.manual_seed(0) + device = "cuda" + torch_type = torch.float16 + + q = torch.randn(2, 16, 8, 128, device=device, dtype=torch_type) * 0.2 + k = torch.randn(2, 16, 8, 128, device=device, dtype=torch_type) * 0.2 + v = torch.randn_like(k) * 0.2 + + # Create (1,1,q,kv) additive mask: lower-triangular causal pattern + mask_filter = float(torch.finfo(torch_type).min) + mask_2d = torch.zeros(16, 16, device=device, dtype=torch_type) + for i in range(16): + mask_2d[i, i + 1 :] = mask_filter + attn_mask = mask_2d.unsqueeze(0).unsqueeze(0) # (1, 1, 16, 16) + + # Reference: expand to full (B, H, Q, K) + attn_bias_ref = attn_mask.expand(2, 8, -1, -1).contiguous() + out_ref, _ = attention_ref(q=q, k=k, v=v, attn_bias=attn_bias_ref, causal=False) + + # ORT path + out_ort, _, _ = attention_prompt_func( + q=q, + k=k, + v=v, + config=config, + attn_mask=attn_mask, + ep="CUDAExecutionProvider", + device=device, + ort_type=TensorProto.FLOAT16, + ) + out_ort = out_ort.reshape(2, 16, 8, 128) + + out_np = out_ort.float().detach().cpu().numpy() + out_ref_np = out_ref.float().detach().cpu().numpy() + numpy.testing.assert_allclose(out_np, out_ref_np, rtol=rtol["fp16"], atol=atol["fp16"]) + + +# ################################################################################################# +# 2D Mask Broadcast Regression Test +# ################################################################################################# + + +@unittest.skipIf(not has_cuda_device(53), "CUDA device not available, skipping 2D mask broadcast tests.") +class TestONNXAttentionMHA2DMaskBroadcast(unittest.TestCase): + """ + Regression test for 2D mask [q_seq, total_seq] broadcast correctness. + + Per ONNX spec, a 2D attention mask has shape [q_seq, total_seq] and broadcasts + over batch and heads. This test uses batch_size > q_seq with a non-uniform + mask (different values per row) to verify correct broadcast behavior. + + The old bug indexed the 2D mask by batch index instead of query position, + causing OOB reads when batch_size > q_seq. + """ + + def test_2d_additive_mask_batch_gt_qseq(self): + """2D additive mask [q_seq=2, total_seq=8] with batch=4 — would OOB on old code.""" + config = AttentionConfig( + batch_size=4, + q_sequence_length=2, + kv_sequence_length=8, + q_num_heads=4, + kv_num_heads=4, + head_size=64, + is_causal=0, + has_attn_mask=True, + attn_mask_dims=2, + attn_mask_type="additive", + ) + + torch.manual_seed(42) + device = "cuda" + torch_type = torch.float16 + mask_filter_value = torch.finfo(torch_type).min + + q = ( + torch.randn( + config.batch_size, config.q_sequence_length, config.q_num_heads, config.head_size, @@ -1490,6 +2252,285 @@ def test_2d_bool_mask_batch_gt_qseq(self): numpy.testing.assert_allclose(out_np, out_ref_np, rtol=rtol["fp16"], atol=atol["fp16"]) +@unittest.skipIf(not has_cuda_device(53), "Memory Efficient Attention is not available, skipping tests.") +@patch.dict(os.environ, {"ORT_DISABLE_FLASH_ATTENTION": "1"}) +class TestONNXAttentionMEASoftcap(unittest.TestCase): + """ + Test softcap support in the Memory Efficient Attention (MEA) kernel. + + Disables Flash Attention to force the MEA path. Verifies that + softcap * tanh(score / softcap) is correctly applied to attention logits + in MEA, matching the reference implementation. + + MEA alignment requirement: total_seq % 4 == 0 when attn_mask is present. + """ + + # --- P0: MEA softcap+mask (MHA) --- + + def test_mea_softcap_with_mask_prompt_fp16(self): + """MEA softcap + additive mask, prompt phase, fp16.""" + config = AttentionConfig( + batch_size=2, + q_sequence_length=8, + kv_sequence_length=8, # total_seq=8, divisible by 4 + q_num_heads=4, + kv_num_heads=4, + head_size=64, + is_causal=1, + softcap=50.0, + has_attn_mask=True, + attn_mask_dims=4, + attn_mask_type="additive", + ) + parity_check_mha_prompt( + config=config, + ep="CUDAExecutionProvider", + device="cuda", + torch_type=torch.float16, + ort_type=TensorProto.FLOAT16, + causal=True, + rtol=rtol["fp16"], + atol=atol["fp16"], + ) + + def test_mea_softcap_with_mask_decode_fp16(self): + """MEA softcap + additive mask, decode phase, fp16.""" + config = AttentionConfig( + batch_size=2, + q_sequence_length=1, + kv_sequence_length=1, + past_kv_sequence_length=31, # total_seq = 31+1 = 32, divisible by 4 + q_num_heads=4, + kv_num_heads=4, + head_size=64, + is_causal=1, + softcap=50.0, + has_attn_mask=True, + attn_mask_dims=4, + attn_mask_type="additive", + ) + parity_check_mha_past( + config=config, + ep="CUDAExecutionProvider", + device="cuda", + torch_type=torch.float16, + ort_type=TensorProto.FLOAT16, + causal=True, + rtol=rtol["fp16"], + atol=atol["fp16"], + ) + + # --- P0: MEA softcap-only (no mask) --- + + def test_mea_softcap_no_mask_prompt_fp16(self): + """MEA softcap without explicit mask, prompt phase, fp16.""" + config = AttentionConfig( + batch_size=2, + q_sequence_length=8, + kv_sequence_length=8, + q_num_heads=4, + kv_num_heads=4, + head_size=64, + is_causal=1, + softcap=50.0, + attn_mask_type="additive", + ) + parity_check_mha_prompt( + config=config, + ep="CUDAExecutionProvider", + device="cuda", + torch_type=torch.float16, + ort_type=TensorProto.FLOAT16, + causal=True, + rtol=rtol["fp16"], + atol=atol["fp16"], + ) + + def test_mea_softcap_no_mask_decode_fp16(self): + """MEA softcap without explicit mask, decode phase, fp16.""" + config = AttentionConfig( + batch_size=2, + q_sequence_length=1, + kv_sequence_length=1, + past_kv_sequence_length=31, # total_seq=32 + q_num_heads=4, + kv_num_heads=4, + head_size=64, + is_causal=1, + softcap=50.0, + attn_mask_type="additive", + ) + parity_check_mha_past( + config=config, + ep="CUDAExecutionProvider", + device="cuda", + torch_type=torch.float16, + ort_type=TensorProto.FLOAT16, + causal=True, + rtol=rtol["fp16"], + atol=atol["fp16"], + ) + + # --- P1: MEA softcap ordering poison test --- + + def test_mea_softcap_mask_ordering_no_leakage_prompt(self): + """Guard test: verify MEA softcap + mask ordering prevents attention leakage. + + Same poison-value technique as the unfused ordering test, but forces the + MEA path. Proves MEA correctly applies softcap before mask addition. + """ + batch_size = 1 + q_seq = 4 + kv_seq = 8 # divisible by 4 for MEA alignment + num_heads = 2 + head_size = 64 + softcap_val = 2.0 + valid_kv_len = 4 + + config = AttentionConfig( + batch_size=batch_size, + q_sequence_length=q_seq, + kv_sequence_length=kv_seq, + q_num_heads=num_heads, + kv_num_heads=num_heads, + head_size=head_size, + is_causal=0, + softcap=softcap_val, + has_attn_mask=True, + attn_mask_dims=4, + attn_mask_type="additive", + ) + + torch.manual_seed(42) + device = "cuda" + torch_type = torch.float16 + + q = torch.randn(batch_size, q_seq, num_heads, head_size, dtype=torch_type, device=device) * 0.2 + k = torch.randn(batch_size, kv_seq, num_heads, head_size, dtype=torch_type, device=device) * 0.2 + v = torch.randn(batch_size, kv_seq, num_heads, head_size, dtype=torch_type, device=device) * 0.2 + + # Place poison values in V at masked positions + poison_value = 1000.0 + v[:, valid_kv_len:, :, :] = poison_value + + # Create additive mask: 0.0 for valid, -inf for masked + attn_mask = torch.zeros(batch_size, num_heads, q_seq, kv_seq, dtype=torch_type, device=device) + attn_mask[:, :, :, valid_kv_len:] = float("-inf") + + out, _, _ = attention_prompt_func( + q=q, + k=k, + v=v, + config=config, + attn_mask=attn_mask, + ep="CUDAExecutionProvider", + device=device, + ort_type=TensorProto.FLOAT16, + ) + + out_np = out.to(torch.float32).detach().cpu().numpy().flatten() + max_abs = numpy.max(numpy.abs(out_np)) + self.assertLess( + max_abs, + 50.0, + f"MEA attention leakage detected: max |output| = {max_abs:.1f}. " + f"This likely means MEA applies softcap AFTER mask (wrong ordering). " + f"Correct ordering: QK → softcap → mask → softmax (per onnx/onnx#7865).", + ) + + # Also verify against reference + out_ref, _ = attention_ref(q=q, k=k, v=v, attn_bias=attn_mask, softcap=softcap_val) + out_ref_np = out_ref.to(torch.float32).detach().cpu().numpy() + out_reshaped = torch.reshape(out, (batch_size, q_seq, num_heads, head_size)) + out_reshaped_np = out_reshaped.to(torch.float32).detach().cpu().numpy() + numpy.testing.assert_allclose(out_reshaped_np, out_ref_np, rtol=0.02, atol=0.02) + + +@unittest.skipIf(not has_cuda_device(80), "Flash Attention requires Ampere or higher GPU, skipping tests.") +class TestONNXAttentionFlashSoftcap(unittest.TestCase): + """ + Test softcap support via Flash Attention path. + + Does NOT disable Flash or MEA — lets the dispatch cascade choose naturally. + On Ampere+ with fp16 and head_size<=256, this should route to Flash Attention. + """ + + def test_flash_softcap_prompt_fp16(self): + """Flash Attention softcap, prompt phase, fp16.""" + config = AttentionConfig( + batch_size=2, + q_sequence_length=16, + kv_sequence_length=16, + q_num_heads=4, + kv_num_heads=4, + head_size=64, + is_causal=1, + softcap=50.0, + attn_mask_type="additive", + ) + parity_check_mha_prompt( + config=config, + ep="CUDAExecutionProvider", + device="cuda", + torch_type=torch.float16, + ort_type=TensorProto.FLOAT16, + causal=True, + rtol=rtol["fp16"], + atol=atol["fp16"], + ) + + def test_flash_softcap_decode_fp16(self): + """Flash Attention softcap, decode phase, fp16.""" + config = AttentionConfig( + batch_size=2, + q_sequence_length=1, + kv_sequence_length=1, + past_kv_sequence_length=31, # total_seq=32 + q_num_heads=4, + kv_num_heads=4, + head_size=64, + is_causal=1, + softcap=50.0, + attn_mask_type="additive", + ) + parity_check_mha_past( + config=config, + ep="CUDAExecutionProvider", + device="cuda", + torch_type=torch.float16, + ort_type=TensorProto.FLOAT16, + causal=True, + rtol=rtol["fp16"], + atol=atol["fp16"], + ) + + def test_flash_softcap_with_mask_prompt_fp16(self): + """Flash Attention softcap + mask, prompt phase, fp16.""" + config = AttentionConfig( + batch_size=2, + q_sequence_length=8, + kv_sequence_length=8, + q_num_heads=4, + kv_num_heads=4, + head_size=64, + is_causal=1, + softcap=50.0, + has_attn_mask=True, + attn_mask_dims=4, + attn_mask_type="additive", + ) + parity_check_mha_prompt( + config=config, + ep="CUDAExecutionProvider", + device="cuda", + torch_type=torch.float16, + ort_type=TensorProto.FLOAT16, + causal=True, + rtol=rtol["fp16"], + atol=atol["fp16"], + ) + + # NOTE: GQA fully-masked batch fix (ZeroOutputForFullyMaskedBatches) is validated by # C++ test Attention_NonPadKVSeqLen_AllMasked_FP16_GQA. Python graph-level test omitted # because the fix is a CUDA kernel in the MEA path — a CPU-only test cannot validate it, diff --git a/onnxruntime/test/python/transformers/test_onnx_attention/test_tensorscatter_attention.py b/onnxruntime/test/python/transformers/test_onnx_attention/test_tensorscatter_attention.py index a6a115bb12213..6b3f6d1c3ff34 100644 --- a/onnxruntime/test/python/transformers/test_onnx_attention/test_tensorscatter_attention.py +++ b/onnxruntime/test/python/transformers/test_onnx_attention/test_tensorscatter_attention.py @@ -460,16 +460,22 @@ def cpu_test_cases(): def cuda_fp16_test_cases(): - """CUDA fp16: both GQA and MHA cases. Flash attention handles external KV cache directly.""" + """CUDA fp16: both GQA and MHA cases. Flash attention handles external KV cache directly. + TensorScatter manages KV cache externally with nonpad_kv_seqlen bounding the active range. + Per ONNX spec, is_causal with S_q!=S_kv and no past_key gives upper-left alignment + (q[0] sees only kv[0]), which is not meaningful for decode. KV bounds are enforced by + nonpad_kv_seqlen instead, so is_causal=0 is the correct setting for TensorScatter decode.""" yield from _make_test_params(_GQA_CASES + _MHA_CASES, is_causal=0) - yield from _make_test_params(_GQA_CASES + _MHA_CASES, is_causal=1) def cuda_fp32_test_cases(): """CUDA fp32: MHA only. GQA requires fp16/bf16, and flash attention requires fp16/bf16. - fp32 MHA uses the unfused attention_bias fallback path.""" + fp32 MHA uses the unfused attention_bias fallback path. + TensorScatter manages KV cache externally with nonpad_kv_seqlen bounding the active range. + Per ONNX spec, is_causal with S_q!=S_kv and no past_key gives upper-left alignment + (q[0] sees only kv[0]), which is not meaningful for decode. KV bounds are enforced by + nonpad_kv_seqlen instead, so is_causal=0 is the correct setting for TensorScatter decode.""" yield from _make_test_params(_MHA_CASES, is_causal=0) - yield from _make_test_params(_MHA_CASES, is_causal=1) # ################################################################################################# @@ -975,5 +981,71 @@ def test_nonpad_with_bool_mask_cuda_fp16( numpy.testing.assert_allclose(present_v, ref_present_v, rtol=rtol["fp16"], atol=atol["fp16"]) +class TestCausalTensorScatterRejected(unittest.TestCase): + """Test that is_causal=1 + TensorScatter decode (S_q != S_kv, no past) is rejected. + + Per ONNX spec, is_causal without past_key means upper-left alignment: q[i] attends + only to kv[0..i]. For decode with external cache (S_q=1, S_kv=cache_size), this means + q[0] sees only kv[0] — not meaningful for autoregressive generation. + + The dispatch guard should return NOT_IMPLEMENTED for this combination. + Models should use is_causal=0 for TensorScatter decode. + """ + + @unittest.skipUnless("CUDAExecutionProvider" in get_available_providers(), "CUDA not available") + def test_is_causal_with_tensorscatter_no_past_rejected(self): + """Verify NOT_IMPLEMENTED is raised for is_causal=1 + TensorScatter + S_q != S_kv.""" + batch_size = 1 + q_seq_len = 1 + total_kv_seq_len = 8 + q_num_heads = 2 + kv_num_heads = 2 + head_size = 32 + + # Build model with is_causal=1 (the rejected combination) + model_bytes = build_tensorscatter_attention_graph( + batch_size=batch_size, + total_kv_seq_len=total_kv_seq_len, + q_seq_len=q_seq_len, + q_num_heads=q_num_heads, + kv_num_heads=kv_num_heads, + head_size=head_size, + ort_type=TensorProto.FLOAT16, + is_causal=1, + ) + + sess_opts = SessionOptions() + session = InferenceSession(model_bytes, sess_opts, providers=["CUDAExecutionProvider"]) + + kv_hidden = kv_num_heads * head_size + q_hidden = q_num_heads * head_size + key_cache = numpy.random.randn(batch_size, total_kv_seq_len, kv_hidden).astype(numpy.float16) + value_cache = numpy.random.randn(batch_size, total_kv_seq_len, kv_hidden).astype(numpy.float16) + new_k = numpy.random.randn(batch_size, q_seq_len, kv_hidden).astype(numpy.float16) + new_v = numpy.random.randn(batch_size, q_seq_len, kv_hidden).astype(numpy.float16) + write_indices = numpy.array([4], dtype=numpy.int64) + query = numpy.random.randn(batch_size, q_seq_len, q_hidden).astype(numpy.float16) + nonpad_kv_seqlen = numpy.array([5], dtype=numpy.int64) + + feeds = { + "key_cache": key_cache, + "value_cache": value_cache, + "new_k": new_k, + "new_v": new_v, + "write_indices": write_indices, + "query": query, + "nonpad_kv_seqlen": nonpad_kv_seqlen, + } + + with self.assertRaises(Exception) as ctx: + session.run(None, feeds) + + error_msg = str(ctx.exception) + self.assertTrue( + "NOT_IMPLEMENTED" in error_msg or "nonpad_kv_seqlen" in error_msg, + f"Expected NOT_IMPLEMENTED error for is_causal + TensorScatter decode, got: {error_msg}", + ) + + if __name__ == "__main__": unittest.main() diff --git a/onnxruntime/test/testdata/onnx_backend_test_series_filters.jsonc b/onnxruntime/test/testdata/onnx_backend_test_series_filters.jsonc index 5f8871d71c80a..5e8a6532e974d 100644 --- a/onnxruntime/test/testdata/onnx_backend_test_series_filters.jsonc +++ b/onnxruntime/test/testdata/onnx_backend_test_series_filters.jsonc @@ -42,14 +42,9 @@ "^test_attention_4d_attn_mask_3d_causal_expanded*", // webgpu "^test_attention_4d_diff_heads_mask4d_padded_kv*", // Need nonpad_kv_seqlen // TODO: support qk_matmul_output modes beyond kQK in Attention-cuda (see issue #27712) - // Tests combining qk_matmul with softcap need unfused-path softcap support (deferred). - "^test_attention_3d_with_past_and_present_qk_matmul_softcap_cuda", // qk_matmul + softcap needs unfused softcap - "^test_attention_4d_with_qk_matmul_softcap_cuda", // qk_matmul + softcap needs unfused softcap - // softcap + diff head sizes (head_size != v_head_size) blocks Flash, falls to unfused which lacks softcap - "^test_attention_3d_diff_heads_sizes_softcap_cuda", // diff head sizes forces unfused, no softcap - "^test_attention_4d_diff_heads_sizes_softcap_cuda", // diff head sizes forces unfused, no softcap - "^test_attention_4d_attn_mask_bool_cuda", // bool mask not supported in Attention-cuda - "^test_attention_4d_attn_mask_bool_4d_cuda", // bool mask not supported in Attention-cuda + // Tests combining qk_matmul with softcap need unfused-path qk_matmul support (deferred). + "^test_attention_3d_with_past_and_present_qk_matmul_softcap_cuda", // qk_matmul modes beyond kQK not supported + "^test_attention_4d_with_qk_matmul_softcap_cuda", // qk_matmul modes beyond kQK not supported "^test_attention_3d_with_past_and_present_qk_matmul_bias_cuda", // QK matmul + bias not supported in Attention-cuda "^test_attention_4d_with_past_and_present_qk_matmul_bias_3d_mask_cuda", // QK matmul + bias not supported in Attention-cuda "^test_attention_4d_with_past_and_present_qk_matmul_bias_4d_mask_cuda", // QK matmul + bias not supported in Attention-cuda @@ -57,27 +52,6 @@ "^test_attention_4d_with_qk_matmul_softmax_cuda", // QK matmul + softmax not supported in Attention-cuda "^test_attention_3d_with_past_and_present_qk_matmul_softmax_cuda", // QK matmul + softmax not supported in Attention-cuda "^test_attention_4d_with_past_and_present_qk_matmul_bias_cuda", // QK matmul + bias not supported in Attention-cuda - // is_causal=Truen && q_seq_len != kv_seq_len not supported in Attention-cuda - "^test_attention_3d_causal_cuda", - "^test_attention_3d_diff_heads_sizes_causal_cuda", - "^test_attention_4d_attn_mask_3d_causal_cuda", - "^test_attention_4d_attn_mask_4d_causal_cuda", - "^test_attention_4d_causal_cuda", - "^test_attention_4d_diff_heads_sizes_causal_cuda", - // GQA Attention-cuda does not support fp16 and 4d QKV - "^test_attention_4d_gqa_with_past_and_present_fp16_cuda", // 4d QKV - "^test_attention_4d_gqa_with_past_and_present_cuda", // fp32 - "^test_attention_4d_gqa_softcap_cuda", // fp32 - "^test_attention_4d_gqa_scaled_cuda", // fp32 - "^test_attention_4d_gqa_cuda", // fp32 - "^test_attention_3d_gqa_attn_mask_cuda", // fp32 - "^test_attention_3d_gqa_causal_cuda", // fp32 - "^test_attention_3d_gqa_cuda", // fp32 - "^test_attention_3d_gqa_scaled_cuda", // fp32 - "^test_attention_3d_gqa_softcap_cuda", // fp32 - "^test_attention_3d_gqa_with_past_and_present_cuda", // fp32 - "^test_attention_4d_gqa_attn_mask_cuda", // fp32 - "^test_attention_4d_gqa_causal_cuda", // fp32 "^test_tensorscatter*", // TensorScatter(24) not implemented "^test_castlike_no_saturate_FLOAT_to_FLOAT8*", // ORT does not support ml_dtypes "^test_castlike_UINT4_to*", // ORT does not support ml_dtypes From 3454f86eef5d96217bf6998ef0ec42fc68b65ed9 Mon Sep 17 00:00:00 2001 From: Copilot <198982749+Copilot@users.noreply.github.com> Date: Mon, 4 May 2026 15:16:10 -0700 Subject: [PATCH 06/34] Fix BitShift UB when shift amount >= bit width (#28272) ### Description Shifting by >= the bit width of an unsigned type is undefined behavior in C++. On x86-64, the hardware masks 64-bit shift amounts to 6 bits, so `x >> 64` silently becomes `x >> 0`, returning the original value instead of 0. Added `SafeShiftLeft`/`SafeShiftRight` helpers that return 0 when `shift >= sizeof(T) * 8`, applied across all three broadcast code paths (scalar-X, scalar-Y, element-wise). ```cpp template inline T SafeShiftRight(T value, T shift) { return shift >= sizeof(T) * 8 ? T{0} : value >> shift; } ``` Added tests covering: - Shift by exact bit width (32, 64) for `uint32_t` and `uint64_t` - Shift by more than bit width (65, 128) - All three broadcast paths (scalar-X, scalar-Y, element-wise) - New tests are excluded for DirectML EP, which has the same hardware-level shift masking behavior ### Motivation and Context `BitShift` with `direction="RIGHT"` on `uint64` inputs with shift amount 64 returns the original values instead of zeros. Reproduces with `CPUExecutionProvider` and `ORT_DISABLE_ALL` (constant folding masks the bug under `ORT_ENABLE_ALL`). --------- Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com> Co-authored-by: tianleiwu <30328909+tianleiwu@users.noreply.github.com> --- .../providers/cpu/math/element_wise_ops.cc | 25 +++++++-- .../cpu/math/element_wise_ops_test.cc | 56 +++++++++++++++++++ 2 files changed, 75 insertions(+), 6 deletions(-) diff --git a/onnxruntime/core/providers/cpu/math/element_wise_ops.cc b/onnxruntime/core/providers/cpu/math/element_wise_ops.cc index 4ddb5c7e78037..935fb3172cc14 100644 --- a/onnxruntime/core/providers/cpu/math/element_wise_ops.cc +++ b/onnxruntime/core/providers/cpu/math/element_wise_ops.cc @@ -1334,6 +1334,19 @@ BitShift::BitShift(const OpKernelInfo& info) : OpKernel(info) { ORT_THROW("Invalid direction value of '", direction, "'. Valid values are 'LEFT' or 'RIGHT'."); } +// Shifting by >= the bit width of an unsigned type is undefined behavior in C++. +// On x86, 64-bit shifts mask the shift amount to 6 bits, so shift by 64 acts like shift by 0. +// Guard against this by returning 0 when the shift amount >= the bit width. +template +inline T SafeShiftLeft(T value, T shift) { + return shift >= sizeof(T) * 8 ? T{0} : value << shift; +} + +template +inline T SafeShiftRight(T value, T shift) { + return shift >= sizeof(T) * 8 ? T{0} : value >> shift; +} + template Status BitShift::Compute(OpKernelContext* context) const { ProcessBroadcastSpanFuncs funcs{ @@ -1345,11 +1358,11 @@ Status BitShift::Compute(OpKernelContext* context) const { ptrdiff_t i = 0; if (shift_left) { for (const auto& input : input1.array()) { - output[i++] = input0 << input; + output[i++] = SafeShiftLeft(input0, input); } } else { for (const auto& input : input1.array()) { - output[i++] = input0 >> input; + output[i++] = SafeShiftRight(input0, input); } } }, @@ -1361,11 +1374,11 @@ Status BitShift::Compute(OpKernelContext* context) const { ptrdiff_t i = 0; if (shift_left) { for (const auto& input : input0.array()) { - output[i++] = input << input1; + output[i++] = SafeShiftLeft(input, input1); } } else { for (const auto& input : input0.array()) { - output[i++] = input >> input1; + output[i++] = SafeShiftRight(input, input1); } } }, @@ -1380,11 +1393,11 @@ Status BitShift::Compute(OpKernelContext* context) const { auto cur_out = output.begin(), end_out = output.end(); if (shift_left) { for (; cur0 != end0; ++cur0, ++cur1, ++cur_out) { - *cur_out = *cur0 << *cur1; + *cur_out = SafeShiftLeft(*cur0, *cur1); } } else { for (; cur0 != end0; ++cur0, ++cur1, ++cur_out) { - *cur_out = *cur0 >> *cur1; + *cur_out = SafeShiftRight(*cur0, *cur1); } } diff --git a/onnxruntime/test/providers/cpu/math/element_wise_ops_test.cc b/onnxruntime/test/providers/cpu/math/element_wise_ops_test.cc index 48a18210face7..283f20a4be9b0 100644 --- a/onnxruntime/test/providers/cpu/math/element_wise_ops_test.cc +++ b/onnxruntime/test/providers/cpu/math/element_wise_ops_test.cc @@ -4420,6 +4420,62 @@ TEST(BitShiftOpTest, BroadcastXRight_Uint8) { test.Run(); } +// Test that shift amounts >= bit width produce 0 (not undefined behavior). +// DirectML EP has the same hardware-level shift masking behavior, so skip these tests for DML. +TEST(BitShiftOpTest, RightShiftByBitWidth_Uint64) { + OpTester test("BitShift", 11); + test.AddAttribute("direction", "RIGHT"); + test.AddInput("X", {4}, {1000, 255, 1, 42}); + test.AddInput("Y", {4}, {64, 64, 64, 64}); + test.AddOutput("Z", {4}, {0, 0, 0, 0}); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kDmlExecutionProvider}); +} + +TEST(BitShiftOpTest, LeftShiftByBitWidth_Uint64) { + OpTester test("BitShift", 11); + test.AddAttribute("direction", "LEFT"); + test.AddInput("X", {4}, {1000, 255, 1, 42}); + test.AddInput("Y", {4}, {64, 64, 64, 64}); + test.AddOutput("Z", {4}, {0, 0, 0, 0}); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kDmlExecutionProvider}); +} + +TEST(BitShiftOpTest, RightShiftByBitWidth_Uint32) { + OpTester test("BitShift", 11); + test.AddAttribute("direction", "RIGHT"); + test.AddInput("X", {3}, {16, 4, 1}); + test.AddInput("Y", {3}, {32, 32, 32}); + test.AddOutput("Z", {3}, {0, 0, 0}); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kDmlExecutionProvider}); +} + +TEST(BitShiftOpTest, RightShiftByMoreThanBitWidth_Uint64) { + OpTester test("BitShift", 11); + test.AddAttribute("direction", "RIGHT"); + test.AddInput("X", {2}, {1000, 42}); + test.AddInput("Y", {2}, {65, 128}); + test.AddOutput("Z", {2}, {0, 0}); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kDmlExecutionProvider}); +} + +TEST(BitShiftOpTest, ScalarRightShiftByBitWidth_Uint64) { + OpTester test("BitShift", 11); + test.AddAttribute("direction", "RIGHT"); + test.AddInput("X", {1}, {1000}); + test.AddInput("Y", {3}, {64, 65, 128}); + test.AddOutput("Z", {3}, {0, 0, 0}); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kDmlExecutionProvider}); +} + +TEST(BitShiftOpTest, ScalarLeftShiftByBitWidth_Uint64) { + OpTester test("BitShift", 11); + test.AddAttribute("direction", "LEFT"); + test.AddInput("X", {3}, {1000, 255, 42}); + test.AddInput("Y", {1}, {64}); + test.AddOutput("Z", {3}, {0, 0, 0}); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kDmlExecutionProvider}); +} + TEST(MathOpTest, BitwiseAnd) { OpTester test("BitwiseAnd", 18); std::vector dims{3}; From a1aa3bbfb531ce589b1743c641da41beb5970ca1 Mon Sep 17 00:00:00 2001 From: Prathik Rao Date: Mon, 4 May 2026 16:20:49 -0700 Subject: [PATCH 07/34] adds foundry local packaging to webgpu plugin ep packaging pipeline (#28300) test run: https://dev.azure.com/aiinfra/Lotus/_build/results?buildId=1201168&view=results --------- Co-authored-by: Prathik Rao --- .../stages/plugin-webgpu-packaging-stage.yml | 161 +++++++++++++++++- 1 file changed, 156 insertions(+), 5 deletions(-) diff --git a/tools/ci_build/github/azure-pipelines/stages/plugin-webgpu-packaging-stage.yml b/tools/ci_build/github/azure-pipelines/stages/plugin-webgpu-packaging-stage.yml index 9db25f5727cc2..6777f207d67b9 100644 --- a/tools/ci_build/github/azure-pipelines/stages/plugin-webgpu-packaging-stage.yml +++ b/tools/ci_build/github/azure-pipelines/stages/plugin-webgpu-packaging-stage.yml @@ -2,22 +2,18 @@ parameters: - name: build_windows_x64 displayName: 'Build Windows x64' type: boolean - default: true - name: build_windows_arm64 displayName: 'Build Windows ARM64' type: boolean - default: false - name: build_linux_x64 displayName: 'Build Linux x64' type: boolean - default: false - name: build_macos_arm64 displayName: 'Build macOS ARM64' type: boolean - default: false - name: package_version displayName: 'Package Version' @@ -26,7 +22,7 @@ parameters: values: - dev - release - - RC + # TODO: release candidate (RC) versioning is not yet implemented - name: version_file type: string @@ -77,3 +73,158 @@ stages: package_version: ${{ parameters.package_version }} version_file: ${{ parameters.version_file }} cmake_build_type: ${{ parameters.cmake_build_type }} + + # Create zip packages for Foundry Local consumption + - stage: Package_Foundry_Local_WebGPU_Zips + displayName: 'Package Foundry Local WebGPU Plugin-EP Zips' + dependsOn: + - ${{ if eq(parameters.build_windows_x64, true) }}: + - Win_plugin_webgpu_x64_Build + - ${{ if and(eq(parameters.build_windows_arm64, true), eq(parameters.build_windows_x64, true)) }}: + - Win_plugin_webgpu_arm64_Build + - ${{ if eq(parameters.build_linux_x64, true) }}: + - Linux_plugin_webgpu_x64_Build + - ${{ if eq(parameters.build_macos_arm64, true) }}: + - MacOS_plugin_webgpu_arm64_Build + jobs: + - job: CreateZipPackages + displayName: 'Create Foundry Local WebGPU Plugin-EP Zip Packages' + pool: + name: 'onnxruntime-Win-CPU-VS2022-Latest' + os: windows + templateContext: + outputs: + - output: pipelineArtifact + targetPath: $(Build.ArtifactStagingDirectory)/webgpu-deps-package + artifactName: foundry-local-webgpu-plugin-ep-zips + steps: + # The 1ES TSA SDL task expects .config/tsaoptions.json in the source directory. + # Use a sparse checkout to pull only the .config directory (avoids full repo clone). + - checkout: self + fetchDepth: 1 + sparseCheckoutDirectories: .config + + - ${{ if eq(parameters.build_windows_x64, true) }}: + - task: DownloadPipelineArtifact@2 + displayName: 'Download webgpu_plugin_win_x64' + inputs: + artifactName: webgpu_plugin_win_x64 + targetPath: $(Build.SourcesDirectory)/webgpu-plugin-win-x64 + + # Windows ARM64 + # ARM64 build requires the x64 tblgen.exe (used during the build), which is not correctly + # generated in a cross build. So we require x64 to be built first and download tblgen.exe from it. + - ${{ if and(eq(parameters.build_windows_arm64, true), eq(parameters.build_windows_x64, true)) }}: + - task: DownloadPipelineArtifact@2 + displayName: 'Download webgpu_plugin_win_arm64' + inputs: + artifactName: webgpu_plugin_win_arm64 + targetPath: $(Build.SourcesDirectory)/webgpu-plugin-win-arm64 + + - ${{ if eq(parameters.build_linux_x64, true) }}: + - task: DownloadPipelineArtifact@2 + displayName: 'Download webgpu_plugin_linux_x64' + inputs: + artifactName: webgpu_plugin_linux_x64 + targetPath: $(Build.SourcesDirectory)/webgpu-plugin-linux-x64 + + - ${{ if eq(parameters.build_macos_arm64, true) }}: + - task: DownloadPipelineArtifact@2 + displayName: 'Download webgpu_plugin_macos_arm64' + inputs: + artifactName: webgpu_plugin_macos_arm64 + targetPath: $(Build.SourcesDirectory)/webgpu-plugin-macos-arm64 + + - task: PowerShell@2 + displayName: 'Create version.json and zip packages for each platform' + inputs: + targetType: inline + script: | + $outputDir = '$(Build.ArtifactStagingDirectory)/webgpu-deps-package' + New-Item -ItemType Directory -Path $outputDir -Force + + $platforms = @( + @{ name = 'win-x64'; dir = '$(Build.SourcesDirectory)/webgpu-plugin-win-x64' }, + @{ name = 'win-arm64'; dir = '$(Build.SourcesDirectory)/webgpu-plugin-win-arm64' }, + @{ name = 'linux-x64'; dir = '$(Build.SourcesDirectory)/webgpu-plugin-linux-x64' }, + @{ name = 'macos-arm64'; dir = '$(Build.SourcesDirectory)/webgpu-plugin-macos-arm64' } + ) + + $resolvedVersion = $null + + foreach ($platform in $platforms) { + $depsDir = $platform.dir + $platformName = $platform.name + + if (-not (Test-Path $depsDir)) { + Write-Host "Skipping $platformName (not built)" + continue + } + + $binDir = Join-Path $depsDir "bin" + $versionDir = Join-Path $depsDir "version" + + if (-not (Test-Path $binDir)) { + throw "Bin directory not found for $platformName $binDir" + } + + Write-Host "--- Processing $platformName ---" + + $versionString = "Unknown" + if (Test-Path $versionDir) { + $versionFile = Get-ChildItem -Path $versionDir -File | Select-Object -First 1 + if ($versionFile) { + $versionString = $versionFile.Name.Trim() + } + } + + # Track the resolved version (all platforms must agree) + # Version formats (full -> filename): + # release: 0.1.0 -> 0.1.0 + # dev: 0.1.0-dev.20260401+2a1ffff2 -> 0.1.0.dev.20260401.2a1ffff2 + # Dev versions have - and + replaced with . for filename compatibility. + # Full version string is preserved in version.json. + # TODO: RC versioning (e.g. 0.1.0-rc1) is not yet implemented + $filenameVersion = $versionString -replace '[-+]', '.' + if ($null -eq $resolvedVersion) { + $resolvedVersion = $filenameVersion + } elseif ($resolvedVersion -ne $filenameVersion) { + throw "Version mismatch across platforms: expected '$resolvedVersion' but $platformName has '$filenameVersion'" + } + + $versionInfo = @{ + version = $versionString + } + + $json = $versionInfo | ConvertTo-Json + $versionPath = Join-Path $binDir "version.json" + Set-Content -Path $versionPath -Value $json -Encoding UTF8 + Write-Host "Created version.json:" + Write-Host $json + + # Collect the binaries (dll, so, dylib) and version.json + $filesToZip = Get-ChildItem -Path $binDir -File | Where-Object { + $_.Extension -in '.dll', '.so', '.dylib' -or $_.Name -eq 'version.json' + } + + $zipPath = Join-Path $outputDir "webgpu_ep_${filenameVersion}_${platformName}.zip" + if ($filesToZip) { + $filesToZip | Compress-Archive -DestinationPath $zipPath -Force + Write-Host "Created zip: $zipPath ($((Get-Item $zipPath).Length) bytes)" + } else { + throw "No files found to zip for $platformName in $binDir" + } + Write-Host "" + } + + if ($null -eq $resolvedVersion) { + throw "No platforms were processed — cannot determine version." + } + + # Create a version folder in the output artifact with a file whose name is the version string. + # This follows the same convention as the per-platform artifacts (e.g. webgpu_plugin_win_x64/version/) + # and allows downstream pipelines to read the version without parsing zip filenames. + $versionOutputDir = Join-Path $outputDir "version" + New-Item -ItemType Directory -Path $versionOutputDir -Force + New-Item -ItemType File -Path (Join-Path $versionOutputDir $resolvedVersion) -Force | Out-Null + Write-Host "Created version marker: $versionOutputDir/$resolvedVersion" From 8a0950166b8030041e0f5bd0c94d1a1459d795b3 Mon Sep 17 00:00:00 2001 From: Rishi Dave <62260675+Rishi-Dave@users.noreply.github.com> Date: Mon, 4 May 2026 18:40:43 -0700 Subject: [PATCH 08/34] Suppress -Wmaybe-uninitialized for onnxruntime_pybind11_state under pybind11 3.0 (#28251) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## Summary - Probes `-Wmaybe-uninitialized` via `check_cxx_compiler_flag` and applies `-Wno-maybe-uninitialized` only to the `onnxruntime_pybind11_state` target when the compiler accepts it. - Fixes the GCC build break introduced when ORT is compiled against pybind11 3.0, currently blocking Fedora's pybind11 3.0 package update. ## Motivation pybind11 3.0 rewrote `def_readwrite` to use a `property_cpp_function_classic` template that generates a lambda capturing a member pointer by value. GCC's `-Wmaybe-uninitialized` flow analysis flags that lambda inside pybind11's own headers, so any consumer compiling ORT's Python bindings against system pybind11 3.0 fails the build. This is a header-side false positive — there is no real uninitialized read in ORT code or in pybind11. Fixes #25681 ## Changes - `cmake/CMakeLists.txt`: add `check_cxx_compiler_flag(-Wno-maybe-uninitialized HAS_NO_MAYBE_UNINITIALIZED)` next to the existing `HAS_CAST_FUNCTION_TYPE` probe. - `cmake/onnxruntime_python.cmake`: when `HAS_NO_MAYBE_UNINITIALIZED` is set, append `-Wno-maybe-uninitialized` to the `onnxruntime_pybind11_state` target's private compile options. Mirrors the established `HAS_CAST_FUNCTION_TYPE` pattern in the same file. The suppression is target-scoped (only the Python binding shared library), compiler-scoped (only when the flag is accepted — effectively GCC), and warning-scoped (only the flow-sensitive `-Wmaybe-uninitialized`, not the strict `-Wuninitialized`). ## Test Plan - [x] `lintrunner -a` clean on the diff. - [ ] CI: confirm Linux GCC builds remain green. - [ ] Downstream verification: Fedora packagers can rebuild ORT against system pybind11 3.0 without `-Wmaybe-uninitialized` errors (per issue reporter). --- cmake/onnxruntime_python.cmake | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/cmake/onnxruntime_python.cmake b/cmake/onnxruntime_python.cmake index 39985b23da3cc..494d5588c2d03 100644 --- a/cmake/onnxruntime_python.cmake +++ b/cmake/onnxruntime_python.cmake @@ -93,6 +93,12 @@ endif() if(HAS_CAST_FUNCTION_TYPE) target_compile_options(onnxruntime_pybind11_state PRIVATE "-Wno-cast-function-type") endif() +# pybind11 3.0 headers trigger -Wmaybe-uninitialized in GCC's flow analysis +# of property accessor lambdas. Suppress it for this target only. +# See https://github.com/microsoft/onnxruntime/issues/25681 +if(HAS_MAYBE_UNINITIALIZED) + target_compile_options(onnxruntime_pybind11_state PRIVATE "-Wno-maybe-uninitialized") +endif() # We export symbols using linker and the compiler does not know anything about it # There is a problem with classes that have pybind types as members. From 7529033fef1e422823f37278c60483851731a3fb Mon Sep 17 00:00:00 2001 From: Ti-Tai Wang Date: Mon, 4 May 2026 18:43:28 -0700 Subject: [PATCH 09/34] Fix ReshapeFusion dropping allowzero on inferred 0-sized intermediate dims (#28349) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ### Description `ReshapeFusion::FuseContiguousReshapes` collapses a chain of `Reshape` / `Squeeze` / `Unsqueeze` nodes into a single `Reshape` whose shape data is taken verbatim from the fully-inferred output shape of the last node in the chain. The new node is created without an `allowzero` attribute, so it defaults to `allowzero = 0`. When that inferred shape contains a literal `0` dim (legitimate when the original chain used `allowzero=1`, or when intermediate tensors had zero-sized dimensions), the fused `Reshape` misinterprets the `0` as "copy the corresponding dim from the input tensor" — but the input here is the original input of the *first* reshape in the chain, with unrelated dims. The result is a silently wrong output shape (and a benign-looking `MergeShapeInfo` warning at graph load). ### Repro (before the fix) ```python import numpy as np, onnx, onnxruntime as ort, onnx.reference from onnx import helper, TensorProto X = helper.make_tensor_value_info("X", TensorProto.FLOAT, [0, 6, 2]) Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [None, None, None]) s1 = helper.make_tensor("s1", TensorProto.INT64, [3], [3, 2, -1]) s2 = helper.make_tensor("s2", TensorProto.INT64, [3], [0, 0, 3]) n1 = helper.make_node("Reshape", ["X", "s1"], ["mid"]) n2 = helper.make_node("Reshape", ["mid", "s2"], ["Y"], allowzero=1) m = helper.make_model(helper.make_graph([n1, n2], "g", [X], [Y], initializer=[s1, s2]), opset_imports=[helper.make_opsetid("", 18)]) inp = np.random.default_rng(7).random((0, 6, 2), dtype=np.float32) print("REF:", onnx.reference.ReferenceEvaluator(m).run(None, {"X": inp})[0].shape) print("ORT:", ort.InferenceSession(m.SerializeToString(), providers=["CPUExecutionProvider"]).run(None, {"X": inp})[0].shape) ``` Output on `main` (`40c9f85f69`): ``` REF: (0, 0, 3) [W ... graph.cc:122 MergeShapeInfo] Error merging shape info for output. 'Y' source:{0,6,3} target:{0,0,3}. Falling back to lenient merge. ORT: (0, 6, 3) ❌ ``` ### Fix Setting `allowzero=1` on the fused node would also work but requires opset >= 14, which this transformer cannot assume (it accepts `Reshape` opset 5+). Bail out of fusion conservatively when `shape_value` contains any literal `0` dim. ### Test Adds `ReshapeFusionContiguousReshapesWithZeroDim` that builds the bug repro programmatically and asserts: - the two reshapes are NOT collapsed - the inferred output shape stays `(0, 0, 3)` The existing happy-path test `ReshapeFusion_Contiguous_Reshape` (added in #22494) is unaffected — its inferred output shape `(2, 1, 64, 32)` contains no zero dims, so the new guard does not trigger. ### Provenance `FuseContiguousReshapes` was introduced in #22494 (Feb 2025). The bug has been latent in `main` since then. ### Motivation and Context Found while reviewing https://github.com/microsoft/onnxscript/pull/2907 — the rewriter rule under test there is semantically correct, but its numerical-equivalence check using ORT as the oracle fails because of this fusion bug. Fixes #28348. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- onnxruntime/core/optimizer/reshape_fusion.cc | 12 +++ .../test/optimizer/graph_transform_test.cc | 75 +++++++++++++++++++ 2 files changed, 87 insertions(+) diff --git a/onnxruntime/core/optimizer/reshape_fusion.cc b/onnxruntime/core/optimizer/reshape_fusion.cc index 167952356ff58..f88ce56fe36fa 100644 --- a/onnxruntime/core/optimizer/reshape_fusion.cc +++ b/onnxruntime/core/optimizer/reshape_fusion.cc @@ -1,6 +1,8 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +#include + #include "core/graph/graph_utils.h" #include "core/optimizer/initializer.h" #include "core/optimizer/reshape_fusion.h" @@ -486,6 +488,16 @@ bool ReshapeFusion::FuseContiguousReshapes(Node& reshape, Graph& graph) { return false; } + // The fused shape is taken verbatim from the inferred output shape of the last reshape + // (we ensured tensor_shape.Size() != -1 above, so dims are concrete). If any dim is + // literally 0, fusing into a single Reshape is unsafe: ONNX Reshape with the default + // allowzero=0 would reinterpret the 0 as "copy from input", producing the wrong shape. + // Setting allowzero=1 would fix it but requires opset >= 14, which we cannot assume + // here (this transformer accepts Reshape opset 5+). Bail out conservatively. + if (std::any_of(shape_value.begin(), shape_value.end(), [](int64_t d) { return d == 0; })) { + return false; + } + const std::string& name = contiguous_reshapes[0].get().Name(); ONNX_NAMESPACE::TensorProto shape_initializer_proto; shape_initializer_proto.set_name(graph.GenerateNodeName(name + "_new_shape")); diff --git a/onnxruntime/test/optimizer/graph_transform_test.cc b/onnxruntime/test/optimizer/graph_transform_test.cc index 950355742193c..0779bd4d4ec09 100644 --- a/onnxruntime/test/optimizer/graph_transform_test.cc +++ b/onnxruntime/test/optimizer/graph_transform_test.cc @@ -4713,6 +4713,81 @@ TEST_F(GraphTransformationTests, ReshapeFusionConcatSubgraph) { } } +// Regression test: FuseContiguousReshapes must not collapse a chain of Reshapes +// when the inferred output shape contains a literal 0 dim. Doing so would create +// a single Reshape whose shape data contains 0 and (because allowzero defaults +// to 0) be misinterpreted as "copy from input dim", silently producing wrong shape. +// See https://github.com/microsoft/onnxruntime/issues/28348. +TEST_F(GraphTransformationTests, ReshapeFusionContiguousReshapesWithZeroDim) { + std::unordered_map domain_to_version; + domain_to_version[kOnnxDomain] = 21; + Model model("ReshapeFusionContiguousReshapesWithZeroDim", false, ModelMetaData(), + PathString(), IOnnxRuntimeOpSchemaRegistryList(), domain_to_version, + std::vector(), *logger_); + auto& graph = model.MainGraph(); + + // X: float[0, 6, 2] (zero-sized first dim, fully concrete) + TypeProto x_type; + x_type.mutable_tensor_type()->set_elem_type(TensorProto_DataType_FLOAT); + x_type.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(0); + x_type.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(6); + x_type.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(2); + + TypeProto y_type; + y_type.mutable_tensor_type()->set_elem_type(TensorProto_DataType_FLOAT); + + auto& X = graph.GetOrCreateNodeArg("X", &x_type); + auto& mid = graph.GetOrCreateNodeArg("mid", &y_type); + auto& Y = graph.GetOrCreateNodeArg("Y", &y_type); + + // shape1 = [3, 2, -1] -> mid shape (3, 2, 0) + ONNX_NAMESPACE::TensorProto shape1_proto; + shape1_proto.set_name("shape1"); + shape1_proto.set_data_type(TensorProto_DataType_INT64); + shape1_proto.add_dims(3); + for (int64_t v : {3, 2, -1}) shape1_proto.add_int64_data(v); + graph.AddInitializedTensor(shape1_proto); + + // shape2 = [0, 0, 3] with allowzero=1 -> Y shape (0, 0, 3) + ONNX_NAMESPACE::TensorProto shape2_proto; + shape2_proto.set_name("shape2"); + shape2_proto.set_data_type(TensorProto_DataType_INT64); + shape2_proto.add_dims(3); + for (int64_t v : {0, 0, 3}) shape2_proto.add_int64_data(v); + graph.AddInitializedTensor(shape2_proto); + + auto& shape1 = graph.GetOrCreateNodeArg("shape1", nullptr); + auto& shape2 = graph.GetOrCreateNodeArg("shape2", nullptr); + + graph.AddNode("reshape1", "Reshape", "first reshape", {&X, &shape1}, {&mid}); + auto& reshape2 = graph.AddNode("reshape2", "Reshape", "second reshape (allowzero=1)", + {&mid, &shape2}, {&Y}); + reshape2.AddAttribute("allowzero", static_cast(1)); + + ASSERT_STATUS_OK(graph.Resolve()); + + std::map op_to_count_before = CountOpsInGraph(graph); + ASSERT_EQ(op_to_count_before["Reshape"], 2); + + onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; + ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique(), + TransformerLevel::Level1)); + ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_)); + + // Fusion must NOT collapse the two reshapes, otherwise the resulting single + // Reshape would (mis)compute output shape (0, 6, 3) instead of (0, 0, 3). + std::map op_to_count = CountOpsInGraph(graph); + ASSERT_EQ(op_to_count["Reshape"], 2); + + // Y's inferred shape must remain (0, 0, 3). + const auto* y_shape = graph.GetNodeArg("Y")->Shape(); + ASSERT_NE(y_shape, nullptr); + ASSERT_EQ(y_shape->dim_size(), 3); + EXPECT_EQ(y_shape->dim(0).dim_value(), 0); + EXPECT_EQ(y_shape->dim(1).dim_value(), 0); + EXPECT_EQ(y_shape->dim(2).dim_value(), 3); +} + TEST_F(GraphTransformationTests, ReshapeFusionWithSlice1) { constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "fusion/reshape_fusion_with_slice1.onnx"; std::shared_ptr p_model; From b81f3f85580a3740261c568b0f9b03bc93215267 Mon Sep 17 00:00:00 2001 From: Rishi Dave <62260675+Rishi-Dave@users.noreply.github.com> Date: Mon, 4 May 2026 22:47:37 -0700 Subject: [PATCH 10/34] fix: make sympy an optional runtime dependency (#28141) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## Summary - Defer `sympy` import so `import onnxruntime.quantization` succeeds without sympy installed - Move `SymbolicShapeInference` import in `quant_pre_process` behind `skip_symbolic_shape` gate - Defer sympy-dependent imports in `transformers.onnx_model` and `transformers.shape_infer_helper` - Raise a clear, actionable `ImportError` instructing users to install sympy when needed ## Motivation Fixes #24872. `sympy` (~29 MB plus `mpmath` ~2 MB) was a hard runtime dependency even though it is only needed for symbolic shape inference. Pure-inference users — the common case — pay the install/import cost for functionality they do not use. `setup.py` already declares sympy as an optional extra (`"symbolic": ["sympy"]`), but top-level imports forced it to load unconditionally. ## Changes - `onnxruntime/python/tools/quantization/shape_inference.py`: move `from onnxruntime.tools.symbolic_shape_infer import SymbolicShapeInference` from module top-level into `quant_pre_process`, guarded by `if not skip_symbolic_shape`. Wrap in `try/except ImportError` that re-raises with install instructions. - `onnxruntime/python/tools/transformers/onnx_model.py`: move the `from shape_infer_helper import SymbolicShapeInferenceHelper` from module top-level into the two methods that instantiate it. Add `TYPE_CHECKING`-guarded import for type annotations. - `onnxruntime/python/tools/transformers/shape_infer_helper.py`: wrap the import of `symbolic_shape_infer` in `try/except ImportError`. The `SymbolicShapeInferenceHelper.__init__` now raises a clear `ImportError` when sympy is unavailable, instead of failing at module load time. - `onnxruntime/test/python/quantization/test_quant_preprocess.py`: add `test_skip_symbolic_shape_does_not_require_sympy` which removes sympy from `sys.modules` and verifies `quant_pre_process(..., skip_symbolic_shape=True)` completes successfully. No public API signatures change. Users who want symbolic shape inference install sympy as before (`pip install sympy` or `pip install onnxruntime[symbolic]`). ## Test Plan - `python -m pytest onnxruntime/test/python/quantization/test_quant_preprocess.py -v` — all tests pass including the new coverage. - Smoke-tested locally: `import onnxruntime.quantization` no longer pulls `sympy` into `sys.modules`. - `lintrunner -a` clean on all changed files. Fixes #24872 --- .../tools/quantization/shape_inference.py | 8 +- .../python/tools/transformers/onnx_model.py | 10 ++- .../tools/transformers/shape_infer_helper.py | 20 ++++- .../quantization/test_quant_preprocess.py | 79 +++++++++++++++++++ 4 files changed, 114 insertions(+), 3 deletions(-) diff --git a/onnxruntime/python/tools/quantization/shape_inference.py b/onnxruntime/python/tools/quantization/shape_inference.py index cc3bc2ef28c4f..0a1ba0462f9bf 100644 --- a/onnxruntime/python/tools/quantization/shape_inference.py +++ b/onnxruntime/python/tools/quantization/shape_inference.py @@ -13,7 +13,6 @@ import onnx import onnxruntime -from onnxruntime.tools.symbolic_shape_infer import SymbolicShapeInference from onnxruntime.transformers.onnx_utils import extract_raw_data_from_model, has_external_data from .fusions import ReplaceUpsampleWithResize @@ -88,6 +87,13 @@ def quant_pre_process( model = save_and_reload_model_with_shape_infer(model) if not skip_symbolic_shape: + try: + from onnxruntime.tools.symbolic_shape_infer import SymbolicShapeInference # noqa: PLC0415 + except ImportError as e: + raise ImportError( + "sympy is required for symbolic shape inference in quantization preprocessing. " + "Install with: 'pip install sympy' or pass skip_symbolic_shape=True to quant_pre_process()." + ) from e logger.info("Performing symbolic shape inference...") model = SymbolicShapeInference.infer_shapes( model, diff --git a/onnxruntime/python/tools/transformers/onnx_model.py b/onnxruntime/python/tools/transformers/onnx_model.py index a00cddf18870e..25bd35e479bd2 100644 --- a/onnxruntime/python/tools/transformers/onnx_model.py +++ b/onnxruntime/python/tools/transformers/onnx_model.py @@ -2,6 +2,7 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. # -------------------------------------------------------------------------- +from __future__ import annotations import itertools import logging @@ -9,6 +10,10 @@ import sys from collections import deque from pathlib import Path +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from shape_infer_helper import SymbolicShapeInferenceHelper from float16 import convert_float_to_float16 from onnx import ( @@ -23,7 +28,6 @@ save_model, ) from onnx.external_data_helper import load_external_data_for_tensor, uses_external_data -from shape_infer_helper import SymbolicShapeInferenceHelper logger = logging.getLogger(__name__) @@ -51,6 +55,8 @@ def disable_shape_inference(self): def infer_runtime_shape(self, dynamic_axis_mapping={}, update=False): # noqa: B006 if self.enable_shape_infer: if self.shape_infer_helper is None or update: + from shape_infer_helper import SymbolicShapeInferenceHelper # noqa: PLC0415 + self.shape_infer_helper = SymbolicShapeInferenceHelper(self.model) try: @@ -764,6 +770,8 @@ def convert_float_to_float16(self, use_symbolic_shape_infer=True, **kwargs): if use_symbolic_shape_infer: # Use symbolic shape inference since custom operators (like Gelu, SkipLayerNormalization etc) # are not recognized by onnx shape inference. + from shape_infer_helper import SymbolicShapeInferenceHelper # noqa: PLC0415 + shape_infer_helper = SymbolicShapeInferenceHelper(model) try: model_with_shape = shape_infer_helper.infer_shapes(model, auto_merge=True, guess_output_rank=False) diff --git a/onnxruntime/python/tools/transformers/shape_infer_helper.py b/onnxruntime/python/tools/transformers/shape_infer_helper.py index f4d65d05ad0c8..5651c3cddba72 100644 --- a/onnxruntime/python/tools/transformers/shape_infer_helper.py +++ b/onnxruntime/python/tools/transformers/shape_infer_helper.py @@ -14,13 +14,31 @@ else: sys.path.append(os.path.join(file_path, "..")) -from symbolic_shape_infer import SymbolicShapeInference, get_shape_from_type_proto, sympy # noqa: E402 +try: + from symbolic_shape_infer import SymbolicShapeInference, get_shape_from_type_proto, sympy + + _symbolic_shape_infer_available = True + _symbolic_shape_infer_import_error: ImportError | None = None +except ImportError as exc: + SymbolicShapeInference = object # type: ignore[assignment,misc] + get_shape_from_type_proto = None # type: ignore[assignment] + sympy = None # type: ignore[assignment] + _symbolic_shape_infer_available = False + _symbolic_shape_infer_import_error = exc logger = logging.getLogger(__name__) class SymbolicShapeInferenceHelper(SymbolicShapeInference): def __init__(self, model, verbose=0, int_max=2**31 - 1, auto_merge=True, guess_output_rank=False): + if not _symbolic_shape_infer_available: + err = _symbolic_shape_infer_import_error + cause = ( + "missing 'sympy' (install with: pip install sympy)" + if err is not None and "sympy" in str(err) + else f"failed to import symbolic_shape_infer: {err!r}" + ) + raise ImportError(f"SymbolicShapeInferenceHelper is unavailable — {cause}") from err super().__init__(int_max, auto_merge, guess_output_rank, verbose) self.model_ = model self.all_shapes_inferred_: bool = False diff --git a/onnxruntime/test/python/quantization/test_quant_preprocess.py b/onnxruntime/test/python/quantization/test_quant_preprocess.py index c93f081072f35..f00fb4a05b6d8 100644 --- a/onnxruntime/test/python/quantization/test_quant_preprocess.py +++ b/onnxruntime/test/python/quantization/test_quant_preprocess.py @@ -5,6 +5,7 @@ # license information. # -------------------------------------------------------------------------- +import sys import tempfile import unittest from pathlib import Path @@ -158,5 +159,83 @@ def test_clip_version_conversion(self): assert preprocessed_model.opset_import[0].version >= 11 +class TestSkipSymbolicShape(unittest.TestCase): + """Verify that skip_symbolic_shape=True avoids importing sympy.""" + + def setUp(self): + self.temp_dir = tempfile.TemporaryDirectory(prefix="ort.quant_preprocess_skip_sympy_") + self.temp_path = Path(self.temp_dir.name) + + def tearDown(self): + self.temp_dir.cleanup() + + def build_simple_model(self): + """Build a minimal identity model for testing.""" + input_tensor = onnx.helper.make_tensor_value_info("input", onnx.TensorProto.FLOAT, [1, 4]) + output_tensor = onnx.helper.make_tensor_value_info("output", onnx.TensorProto.FLOAT, [1, 4]) + identity_node = onnx.helper.make_node("Identity", ["input"], ["output"]) + graph = onnx.helper.make_graph([identity_node], "simple_graph", [input_tensor], [output_tensor]) + opset_imports = [onnx.helper.make_opsetid("", 13)] + return onnx.helper.make_model(graph, opset_imports=opset_imports) + + def test_skip_symbolic_shape_does_not_require_sympy(self): + """ + When skip_symbolic_shape=True, quant_pre_process must not attempt to + import onnxruntime.tools.symbolic_shape_infer (which requires sympy). + We verify this by installing a meta_path finder that raises + ModuleNotFoundError for those modules — guaranteeing any fresh import + attempt fails — and asserting the call succeeds without ever loading + them. + """ + + class _BlockSympyAndSymbolicFinder: + blocked_prefixes = ("sympy",) + blocked_substrings = ("symbolic_shape_infer",) + + def find_spec(self, fullname, path=None, target=None): + if fullname == "sympy" or fullname.startswith("sympy."): + raise ModuleNotFoundError(f"blocked by test: {fullname}") + if "symbolic_shape_infer" in fullname: + raise ModuleNotFoundError(f"blocked by test: {fullname}") + return None + + model = self.build_simple_model() + input_path = self.temp_path / "simple_model.onnx" + output_path = self.temp_path / "out_model.onnx" + onnx.save_model(model, str(input_path)) + + saved = {} + for key in list(sys.modules.keys()): + if key == "sympy" or key.startswith("sympy.") or "symbolic_shape_infer" in key: + saved[key] = sys.modules.pop(key) + + blocker = _BlockSympyAndSymbolicFinder() + sys.meta_path.insert(0, blocker) + try: + quant_pre_process( + input_model=str(input_path), + output_model_path=str(output_path), + skip_optimization=True, + skip_onnx_shape=True, + skip_symbolic_shape=True, + ) + + for mod_name in list(sys.modules): + self.assertFalse( + mod_name == "sympy" or mod_name.startswith("sympy."), + f"sympy was imported despite skip_symbolic_shape=True: {mod_name}", + ) + self.assertNotIn( + "symbolic_shape_infer", + mod_name, + f"symbolic_shape_infer was imported despite skip_symbolic_shape=True: {mod_name}", + ) + finally: + sys.meta_path.remove(blocker) + sys.modules.update(saved) + + self.assertTrue(output_path.exists(), "Output model should be created even without sympy") + + if __name__ == "__main__": unittest.main() From ebee6069d9f21f2c60194417684cf8fbbf571e2c Mon Sep 17 00:00:00 2001 From: Sanaa Hamel Date: Tue, 5 May 2026 08:53:20 -0400 Subject: [PATCH 11/34] fix(ci): test pipeline didn't correctly specify `ReleaseVersionSuffix` (#28346) ### Description Fix `ReleaseVersionSuffix` passing in 'Nuget Test' pipeline. --- .../github/azure-pipelines/c-api-noopenmp-test-pipelines.yml | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/tools/ci_build/github/azure-pipelines/c-api-noopenmp-test-pipelines.yml b/tools/ci_build/github/azure-pipelines/c-api-noopenmp-test-pipelines.yml index 5ddac928b32d3..ba57a4b2c85c9 100644 --- a/tools/ci_build/github/azure-pipelines/c-api-noopenmp-test-pipelines.yml +++ b/tools/ci_build/github/azure-pipelines/c-api-noopenmp-test-pipelines.yml @@ -69,6 +69,8 @@ stages: - stage: Android_Java_API_AAR_Testing_Full dependsOn: Setup + variables: + ReleaseVersionSuffix: $[ stageDependencies.Setup.Restore_And_Use_Variables.outputs['Set_Release_Version_Suffix.ReleaseVersionSuffix'] ] jobs: - template: templates/android-java-api-aar-test.yml parameters: @@ -77,6 +79,8 @@ stages: - stage: Final_AAR_Testing_Android_QNN dependsOn: Setup + variables: + ReleaseVersionSuffix: $[ stageDependencies.Setup.Restore_And_Use_Variables.outputs['Set_Release_Version_Suffix.ReleaseVersionSuffix'] ] jobs: - template: templates/android-java-api-aar-test.yml parameters: @@ -84,6 +88,7 @@ stages: packageName: 'onnxruntime-android-qnn' #TODO: get this information from the setup stage QnnSDKVersion: '2.42.0.251225' + ReleaseVersionSuffix: $(ReleaseVersionSuffix) - template: nuget/templates/test_win.yml parameters: From ef44604558d3cdac3704115ddbd1828a1b9c7c40 Mon Sep 17 00:00:00 2001 From: Sanaa Hamel Date: Tue, 5 May 2026 08:54:19 -0400 Subject: [PATCH 12/34] chore: rename `ort_api_1_to_26` to `ort_api_1_to_27` (#28341) ### Description Rename `ort_api_1_to_26` -> `ort_api_1_to_27`. ### Motivation and Context This should have been done in #28324, but we wanted to merge ASAP. --- onnxruntime/core/session/onnxruntime_c_api.cc | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/onnxruntime/core/session/onnxruntime_c_api.cc b/onnxruntime/core/session/onnxruntime_c_api.cc index 3bbb8f170dbc7..5ee5f1486b137 100644 --- a/onnxruntime/core/session/onnxruntime_c_api.cc +++ b/onnxruntime/core/session/onnxruntime_c_api.cc @@ -4387,7 +4387,7 @@ Second example, if we wanted to add and remove some members, we'd do this: In GetApi we now make it return ort_api_3 for version 3. */ -static constexpr OrtApi ort_api_1_to_26 = { +static constexpr OrtApi ort_api_1_to_27 = { // NOTE: The ordering of these fields MUST not change after that version has shipped since existing binaries depend on this ordering. // Shipped as version 1 - DO NOT MODIFY (see above text for more information) @@ -4894,6 +4894,7 @@ static constexpr OrtApi ort_api_1_to_26 = { &OrtApis::KernelInfoGetAttributeArray_string, &OrtApis::SetPerSessionThreadPoolCallbacks, // End of Version 25 - DO NOT MODIFY ABOVE (see above text for more information) + // End of Version 26 - DO NOT MODIFY ABOVE (see above text for more information) }; // OrtApiBase can never change as there is no way to know what version of OrtApiBase is returned by OrtGetApiBase. @@ -4932,13 +4933,14 @@ static_assert(offsetof(OrtApi, GetEpApi) / sizeof(void*) == 317, "Size of versio static_assert(offsetof(OrtApi, CreateExternalInitializerInfo) / sizeof(void*) == 389, "Size of version 23 API cannot change"); static_assert(offsetof(OrtApi, GetTensorElementTypeAndShapeDataReference) / sizeof(void*) == 414, "Size of version 24 API cannot change"); static_assert(offsetof(OrtApi, KernelInfoGetAttributeArray_string) / sizeof(void*) == 417, "Size of version 25 API cannot change"); +// no additions in version 26 // So that nobody forgets to finish an API version, this check will serve as a reminder: static_assert(std::string_view(ORT_VERSION) == "1.27.0", "ORT_Version change detected, please follow below steps to ensure OrtApi is updated properly"); // 1. Update the hardcoded version string in above static_assert to silence it // -// 2. If there were any APIs added to ort_api_1_to_26 above: +// 2. If there were any APIs added to ort_api_1_to_X above: // a. Add the 'End of version #' markers (pattern above should be obvious) // b. Add a static_assert in the directly above list of version sizes to ensure nobody adds any more functions to the just shipped API version // @@ -4950,7 +4952,7 @@ static_assert(std::string_view(ORT_VERSION) == "1.27.0", ORT_API(const OrtApi*, OrtApis::GetApi, uint32_t version) { if (version >= 1 && version <= ORT_API_VERSION) - return &ort_api_1_to_26; + return &ort_api_1_to_27; fprintf(stderr, "The requested API version [%u] is not available, only API versions [1, %u] are supported in this build." From 1f25783745ba452554257304945d8aacd02c1210 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Tue, 5 May 2026 07:13:44 -0700 Subject: [PATCH 13/34] Fix CUDA Attention dispatch: skip MEA when head_size != v_head_size in GQA (#28358) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## Summary ## Problem The Memory-Efficient Attention (MEA) path crashes with `cudaErrorMisalignedAddress` when: - GQA mode (`q_num_heads != kv_num_heads`) - `head_size != v_head_size` (e.g., Q.head_dim=256, K.head_dim=512) - `seq_len >= 4` (Flash Attention not eligible due to attention mask) This is because MEA's `LaunchUngroup` requires equal head sizes, but the dispatch logic only checked this constraint for the past_key case (line 1380), not the general GQA case. ## Fix Skip MEA for GQA when head sizes differ. The Unfused Attention fallback handles this correctly. ## Affected Models Gemma 4 was not affected. This was a previously incorrect graph. But the fix is still good to have that improves robustness anyways. ~~**Gemma4** (google/gemma-4-e2b-it) with KV sharing:~~ - Layers 15-34 borrow K,V from source layers - Q projection: 1536 → 2048 (8 heads × 256) - K/V from source: [batch, 1, seq, 512] - `head_size = 256`, `v_head_size = 512` ## Testing Minimal repro (from #28357): ```python # Attention(Q=[1,S,2048], K=[1,S,512], V=[1,S,512], q_num_heads=8, kv_num_heads=1) # Before fix: seq=4+ crashes with misaligned address # After fix: all seq lengths work ``` Full Gemma4 decoder (35 layers, 15 GQA + 20 standard Attention): - Prefill seq=32: ✅ - Decode seq=1: ✅ Fixes #28357 Signed-off-by: Justin Chu Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- onnxruntime/core/providers/cuda/llm/attention.cc | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/onnxruntime/core/providers/cuda/llm/attention.cc b/onnxruntime/core/providers/cuda/llm/attention.cc index 00ce18c65efd8..15f9dcbf8e7f2 100644 --- a/onnxruntime/core/providers/cuda/llm/attention.cc +++ b/onnxruntime/core/providers/cuda/llm/attention.cc @@ -1375,8 +1375,11 @@ Status Attention::ComputeInternal(OpKernelContext* context) const { sm, std::is_same::value, std::is_same::value, parameters.head_size, parameters.v_head_size) && !has_output_qk && - // MEA decode requires head_size == v_head_size for LaunchConcatNewToPastKV - // (single head_size parameter). Fall back to unfused when they differ. + // MEA requires head_size == v_head_size in two internal paths: + // - LaunchConcatNewToPastKV (decode with past_key) + // - LaunchUngroup (GQA head expansion) + // Fall back to unfused attention when they differ. + (!is_gqa || parameters.head_size == parameters.v_head_size) && (past_key == nullptr || parameters.head_size == parameters.v_head_size) && // GQA+MEA requires LaunchUngroup which only has fp16/bf16 instantiations. // FP32 GQA must fall through to the unfused path. From 5e38dfe95499bb7901c2c6fe9342735bd9eede94 Mon Sep 17 00:00:00 2001 From: Aleksei Nikiforov <103434461+AlekseiNikiforovIBM@users.noreply.github.com> Date: Tue, 5 May 2026 18:51:41 +0200 Subject: [PATCH 14/34] Fix CApi tests on S390x (#28074) ### Description When loading data into tensors from memory buffers from external files, byteswap it if necessary. Also add a fix for deleter when byteswapping: keep copy of AllocatorPtr instead of reference. ### Motivation and Context While trying to setup local s390x CI, I've found 4 more tests that fail on s390x: CApiTest.TestLoadModelFromArrayWithExternalInitializerFromFileArray CApiTest.TestLoadModelFromArrayWithExternalInitializersFromFileArray CApiTest.TestLoadModelFromArrayWithExternalInitializersFromFileArrayPathRobust CApiTest.TestLoadModelFromArrayWithExternalInitializersFromFileMmap --- .../core/framework/tensorprotoutils.cc | 2 +- onnxruntime/core/graph/graph.cc | 41 ++++++++++++++++--- 2 files changed, 37 insertions(+), 6 deletions(-) diff --git a/onnxruntime/core/framework/tensorprotoutils.cc b/onnxruntime/core/framework/tensorprotoutils.cc index 3e928afcf6c80..360726d780a17 100644 --- a/onnxruntime/core/framework/tensorprotoutils.cc +++ b/onnxruntime/core/framework/tensorprotoutils.cc @@ -1569,7 +1569,7 @@ Status GetExtDataFromTensorProto(const Env& env, if constexpr (endian::native != endian::little) { auto allocator = CPUAllocator::DefaultInstance(); - auto deleter = [&allocator](uint8_t* ptr) { allocator->Free(ptr); }; + auto deleter = [allocator](uint8_t* ptr) { allocator->Free(ptr); }; std::unique_ptr native_data{reinterpret_cast(allocator->Alloc(static_cast(raw_data_safe_len))), deleter}; size_t element_size = onnxruntime::utils::GetElementSizeOfTensor(static_cast(tensor_proto.data_type())); diff --git a/onnxruntime/core/graph/graph.cc b/onnxruntime/core/graph/graph.cc index 1346b976461ce..7da1c6936ff31 100644 --- a/onnxruntime/core/graph/graph.cc +++ b/onnxruntime/core/graph/graph.cc @@ -4177,12 +4177,43 @@ Status Graph::InjectExternalInitializersFromFilesInMemory( const DataTypeImpl* const type = DataTypeImpl::TensorTypeFromONNXEnum(old_initializer.data_type())->GetElementType(); TensorShape tensor_shape = utils::GetTensorShapeFromTensorProto(old_initializer); - auto tensor = Tensor(type, tensor_shape, user_provided_tensor_buffer, - OrtMemoryInfo(CPU, OrtAllocatorType::OrtDeviceAllocator)); - constexpr const bool use_tensor_buffer_false = false; - auto new_tensor_proto = utils::TensorToTensorProto(tensor, tensor_name, use_tensor_buffer_false); - **existing_entry = std::move(new_tensor_proto); + // Convert data from little endian before assigning it to tensor. + // It would have been better to byteswap it right after loading from file, + // but at that moment information about tensor element size was not available. + if constexpr (endian::native != endian::little) { + size_t element_size = onnxruntime::utils::GetElementSizeOfTensor( + static_cast(old_initializer.data_type())); + + // If element size is unknown, set it to 1 to disable byteswapping + if (element_size < 1) element_size = 1; + + auto allocator = CPUAllocator::DefaultInstance(); + + auto deleter = [allocator](uint8_t* ptr) { allocator->Free(ptr); }; + std::unique_ptr native_data{ + reinterpret_cast(allocator->Alloc(tensor_byte_size)), deleter}; + + auto src_span = gsl::make_span( + reinterpret_cast(user_provided_tensor_buffer), tensor_byte_size); + auto dst_span = gsl::make_span( + reinterpret_cast(native_data.get()), tensor_byte_size); + + ORT_RETURN_IF_ERROR(onnxruntime::utils::ReadLittleEndian(element_size, src_span, dst_span)); + + auto tensor = Tensor{type, tensor_shape, native_data.release(), allocator}; + + constexpr const bool use_tensor_buffer_false = false; + auto new_tensor_proto = utils::TensorToTensorProto(tensor, tensor_name, use_tensor_buffer_false); + **existing_entry = std::move(new_tensor_proto); + } else { + auto tensor = Tensor(type, tensor_shape, user_provided_tensor_buffer, + OrtMemoryInfo(CPU, OrtAllocatorType::OrtDeviceAllocator)); + + constexpr const bool use_tensor_buffer_false = false; + auto new_tensor_proto = utils::TensorToTensorProto(tensor, tensor_name, use_tensor_buffer_false); + **existing_entry = std::move(new_tensor_proto); + } } } From 07b8f3952b453ee0e57f2ef29c2b10d81e1ccaae Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Tue, 5 May 2026 10:48:45 -0700 Subject: [PATCH 15/34] Bump postcss from 8.5.3 to 8.5.13 in /js/web/test/e2e/exports/testcases/vite-default (#28304) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Bumps [postcss](https://github.com/postcss/postcss) from 8.5.3 to 8.5.13.
Release notes

Sourced from postcss's releases.

8.5.13

  • Fixed postcss-scss commend regression.

8.5.12

  • Fixed reading any file via user-generated CSS.
  • Added opts.unsafeMap to disable checks.

8.5.11

  • Fixed nested brackets parsing performance (by @​offset).

8.5.10

  • Fixed XSS via unescaped </style> in non-bundler cases (by @​TharVid).

8.5.9

  • Speed up source map encoding paring in case of the error.

8.5.8

  • Fixed Processor#version.

8.5.7

  • Improved source map annotation cleaning performance (by CodeAnt AI).

8.5.6

  • Fixed ContainerWithChildren type discriminating (by @​Goodwine).

8.5.5

  • Fixed package.jsonexports compatibility with some tools (by @​JounQin).

8.5.4

Changelog

Sourced from postcss's changelog.

8.5.13

  • Fixed postcss-scss commend regression.

8.5.12

  • Fixed reading any file via user-generated CSS.
  • Added opts.unsafeMap to disable checks.

8.5.11

  • Fixed nested brackets parsing performance (by @​offset).

8.5.10

  • Fixed XSS via unescaped </style> in non-bundler cases (by @​TharVid).

8.5.9

  • Speed up source map encoding paring in case of the error.

8.5.8

  • Fixed Processor#version.

8.5.7

  • Improved source map annotation cleaning performance (by CodeAnt AI).

8.5.6

  • Fixed ContainerWithChildren type discriminating (by @​Goodwine).

8.5.5

  • Fixed package.jsonexports compatibility with some tools (by @​JounQin).

8.5.4

Commits

[![Dependabot compatibility score](https://dependabot-badges.githubapp.com/badges/compatibility_score?dependency-name=postcss&package-manager=npm_and_yarn&previous-version=8.5.3&new-version=8.5.13)](https://docs.github.com/en/github/managing-security-vulnerabilities/about-dependabot-security-updates#about-compatibility-scores) Dependabot will resolve any conflicts with this PR as long as you don't alter it yourself. You can also trigger a rebase manually by commenting `@dependabot rebase`. [//]: # (dependabot-automerge-start) [//]: # (dependabot-automerge-end) ---
Dependabot commands and options
You can trigger Dependabot actions by commenting on this PR: - `@dependabot rebase` will rebase this PR - `@dependabot recreate` will recreate this PR, overwriting any edits that have been made to it - `@dependabot show ignore conditions` will show all of the ignore conditions of the specified dependency - `@dependabot ignore this major version` will close this PR and stop Dependabot creating any more for this major version (unless you reopen the PR or upgrade to it yourself) - `@dependabot ignore this minor version` will close this PR and stop Dependabot creating any more for this minor version (unless you reopen the PR or upgrade to it yourself) - `@dependabot ignore this dependency` will close this PR and stop Dependabot creating any more for this dependency (unless you reopen the PR or upgrade to it yourself) You can disable automated security fix PRs for this repo from the [Security Alerts page](https://github.com/microsoft/onnxruntime/network/alerts).
Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- .../testcases/vite-default/package-lock.json | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/js/web/test/e2e/exports/testcases/vite-default/package-lock.json b/js/web/test/e2e/exports/testcases/vite-default/package-lock.json index d6d1e383641c2..2cfbd6ed4c92a 100644 --- a/js/web/test/e2e/exports/testcases/vite-default/package-lock.json +++ b/js/web/test/e2e/exports/testcases/vite-default/package-lock.json @@ -1068,9 +1068,9 @@ } }, "node_modules/nanoid": { - "version": "3.3.8", - "resolved": "https://registry.npmjs.org/nanoid/-/nanoid-3.3.8.tgz", - "integrity": "sha512-WNLf5Sd8oZxOm+TzppcYk8gVOgP+l58xNy58D0nbUnOxOWRWvlcCV4kUF7ltmI6PsrLl/BgKEyS4mqsGChFN0w==", + "version": "3.3.12", + "resolved": "https://registry.npmjs.org/nanoid/-/nanoid-3.3.12.tgz", + "integrity": "sha512-ZB9RH/39qpq5Vu6Y+NmUaFhQR6pp+M2Xt76XBnEwDaGcVAqhlvxrl3B2bKS5D3NH3QR76v3aSrKaF/Kiy7lEtQ==", "funding": [ { "type": "github", @@ -1105,9 +1105,9 @@ } }, "node_modules/postcss": { - "version": "8.5.3", - "resolved": "https://registry.npmjs.org/postcss/-/postcss-8.5.3.tgz", - "integrity": "sha512-dle9A3yYxlBSrt8Fu+IpjGT8SY8hN0mlaA6GY8t0P5PjIOZemULz/E2Bnm/2dcUOena75OTNkHI76uZBNUUq3A==", + "version": "8.5.13", + "resolved": "https://registry.npmjs.org/postcss/-/postcss-8.5.13.tgz", + "integrity": "sha512-qif0+jGGZoLWdHey3UFHHWP0H7Gbmsk8T5VEqyYFbWqPr1XqvLGBbk/sl8V5exGmcYJklJOhOQq1pV9IcsiFag==", "funding": [ { "type": "opencollective", @@ -1124,7 +1124,7 @@ ], "license": "MIT", "dependencies": { - "nanoid": "^3.3.8", + "nanoid": "^3.3.11", "picocolors": "^1.1.1", "source-map-js": "^1.2.1" }, From c85ec499b943a54cf232e6edc6a65ba5b6ec2d0b Mon Sep 17 00:00:00 2001 From: Sanaa Hamel Date: Tue, 5 May 2026 15:04:28 -0400 Subject: [PATCH 16/34] fix(ci): 'rc' qualifier ignored when packaging `onnxruntime-node` (#28350) ### Description Fix `onnxruntime-node` and `onnxruntime-common` NPM packages lacking an RC suffix when built in Release + RC mode. This isn't great, the suffix looks like `-QUAL.DATE-COMMIT`. This'll break the publishing pipeline if the packaging pipelines (zip-nuget and NPM) span more than a single day due to same-version checks/enforcement. ### Motivation and Context Missing the RC qualifier/suffix fails the NPM publish pipeline. It correctly assets that the (onnxruntime-node, onnxruntime-common, and onnxruntime-web) do not share a common version specifier. --- .../github/azure-pipelines/templates/c-api-cpu.yml | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/tools/ci_build/github/azure-pipelines/templates/c-api-cpu.yml b/tools/ci_build/github/azure-pipelines/templates/c-api-cpu.yml index 13d5578262102..04066a0c0b90c 100644 --- a/tools/ci_build/github/azure-pipelines/templates/c-api-cpu.yml +++ b/tools/ci_build/github/azure-pipelines/templates/c-api-cpu.yml @@ -421,10 +421,14 @@ stages: targetPath: $(Build.ArtifactStagingDirectory) artifactName: 'NPM_packages' variables: - ${{ if eq(parameters.IsReleaseBuild, true) }}: + ${{ if and(parameters.IsReleaseBuild, eq(parameters.PreReleaseVersionSuffixString, 'none')) }}: NpmPackagingMode: 'release' - ${{ if not(eq(parameters.IsReleaseBuild, true)) }}: + ${{ elseif and(parameters.IsReleaseBuild, eq(parameters.PreReleaseVersionSuffixString, 'rc')) }}: + NpmPackagingMode: 'rc' + ${{ elseif not(parameters.IsReleaseBuild) }}: NpmPackagingMode: 'dev' + ${{ else }}: # IsReleaseBuild + beta, alpha, etc. We don't support those and those suffixes are deprecated. + NpmPackagingMode: '' steps: - checkout: self From 80a23527f2183ee75776717c5de6fcf356ee43a6 Mon Sep 17 00:00:00 2001 From: Copilot <198982749+Copilot@users.noreply.github.com> Date: Tue, 5 May 2026 13:03:49 -0700 Subject: [PATCH 17/34] Fix round_prefer_ceil nearest mode for negative halfway values in Resize op (#28345) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ### Description `ROUND_PREFER_CEIL` in the Resize operator used bare `std::round`/`roundf`, which rounds away from zero. This is correct for positive halfway values (e.g., `round(0.5) = 1 = ceil(0.5)`) but wrong for negative halfway values (e.g., `round(-0.5) = -1`, but `ceil(-0.5) = 0`). Negative coordinates occur naturally with the `half_pixel` coordinate transformation mode for the first output pixels when upsampling. Added an explicit negative-halfway check, mirroring the existing positive-halfway check in `ROUND_PREFER_FLOOR`: ```cpp // CPU (upsamplebase.h) case ROUND_PREFER_CEIL: return [](float x_original, bool) { if (x_original == static_cast(x_original) - 0.5f) { return static_cast(std::ceil(x_original)); } return static_cast(std::round(x_original)); }; ``` Same fix applied to the CUDA implementation (`resize_impl.cu`). Added two test cases in `resize_op_test.cc`: 1. `ResizeOpNearestUpSample_RoundPreferCeil_HalfPixel` — exercises non-integer scale (26→64) from the original issue report, verifying correct source pixel selection at fractional boundaries. 2. `ResizeOpNearestUpSample_RoundPreferCeil_HalfPixel_2x2to7x8` — exercises a positive 0.5 boundary where `round_prefer_ceil` selects ceiling. ### Motivation and Context The `round_prefer_floor` path already had an explicit halfway-case override (for positive values where `std::round` disagrees with floor). The `round_prefer_ceil` path was missing the symmetric fix for negative values, violating the ONNX spec semantics of "at ties, prefer ceiling." --------- Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com> Co-authored-by: tianleiwu <30328909+tianleiwu@users.noreply.github.com> Co-authored-by: Tianlei Wu Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- .../core/providers/cpu/tensor/upsamplebase.h | 6 ++ .../core/providers/cuda/tensor/resize_impl.cu | 6 ++ .../providers/cpu/tensor/resize_op_test.cc | 73 +++++++++++++++++++ 3 files changed, 85 insertions(+) diff --git a/onnxruntime/core/providers/cpu/tensor/upsamplebase.h b/onnxruntime/core/providers/cpu/tensor/upsamplebase.h index ded4813276b1d..e10e896a62d18 100644 --- a/onnxruntime/core/providers/cpu/tensor/upsamplebase.h +++ b/onnxruntime/core/providers/cpu/tensor/upsamplebase.h @@ -462,6 +462,12 @@ class UpsampleBase { }; case ROUND_PREFER_CEIL: return [](float x_original, bool) { + // for half way cases prefer ceil + // std::round rounds away from zero which is correct for positive .5 values + // but for negative .5 values (e.g., -0.5) it rounds to -1 instead of 0 (ceil) + if (x_original == static_cast(x_original) - 0.5f) { + return static_cast(std::ceil(x_original)); + } return static_cast(std::round(x_original)); }; case FLOOR: diff --git a/onnxruntime/core/providers/cuda/tensor/resize_impl.cu b/onnxruntime/core/providers/cuda/tensor/resize_impl.cu index 6e0586e772334..d4559363ce68b 100644 --- a/onnxruntime/core/providers/cuda/tensor/resize_impl.cu +++ b/onnxruntime/core/providers/cuda/tensor/resize_impl.cu @@ -31,6 +31,12 @@ struct NearestPixel_ROUND_PREFER_FLOOR { struct NearestPixel_ROUND_PREFER_CEIL { __device__ __forceinline__ int operator()(float x_original, bool) const { + // for half way cases prefer ceil + // roundf rounds away from zero which is correct for positive .5 values + // but for negative .5 values (e.g., -0.5) it rounds to -1 instead of 0 (ceil) + if (x_original == static_cast(x_original) - 0.5f) { + return static_cast(_Ceil(x_original)); + } return static_cast(roundf(x_original)); } }; diff --git a/onnxruntime/test/providers/cpu/tensor/resize_op_test.cc b/onnxruntime/test/providers/cpu/tensor/resize_op_test.cc index 3129476b1b505..7de16b00dafe3 100644 --- a/onnxruntime/test/providers/cpu/tensor/resize_op_test.cc +++ b/onnxruntime/test/providers/cpu/tensor/resize_op_test.cc @@ -1431,6 +1431,79 @@ TEST(ResizeOpTest, ResizeOpNearestUpSample_Floor_Align_Corners) { test.Run(OpTester::ExpectResult::kExpectSuccess, "", ExcludeTrtOnA100()); } +// Test round_prefer_ceil with half_pixel coordinate transformation. +// Exercises non-integer scale (26->64) where round_prefer_ceil selects +// source pixels at fractional boundaries. +TEST(ResizeOpTest, ResizeOpNearestUpSample_RoundPreferCeil_HalfPixel) { + OpTester test("Resize", 13); + + std::vector roi{}; + std::vector scales{1.0f, 1.0f, 1.0f, 64.0f / 26.0f}; + + test.AddAttribute("mode", "nearest"); + test.AddAttribute("coordinate_transformation_mode", "half_pixel"); + test.AddAttribute("nearest_mode", "round_prefer_ceil"); + + constexpr int64_t N = 1, C = 1, H = 1, W = 26; + std::vector X(26); + for (int i = 0; i < 26; i++) X[i] = static_cast(i); + + test.AddInput("X", {N, C, H, W}, X); + test.AddInput("roi", {0}, roi); + test.AddInput("scales", {4}, scales); + + std::vector Y = { + 0.0f, 0.0f, 1.0f, 1.0f, 1.0f, 2.0f, 2.0f, 3.0f, + 3.0f, 3.0f, 4.0f, 4.0f, 5.0f, 5.0f, 5.0f, 6.0f, + 6.0f, 7.0f, 7.0f, 7.0f, 8.0f, 8.0f, 9.0f, 9.0f, + 9.0f, 10.0f, 10.0f, 11.0f, 11.0f, 11.0f, 12.0f, 12.0f, + 13.0f, 13.0f, 14.0f, 14.0f, 14.0f, 15.0f, 15.0f, 16.0f, + 16.0f, 16.0f, 17.0f, 17.0f, 18.0f, 18.0f, 18.0f, 19.0f, + 19.0f, 20.0f, 20.0f, 20.0f, 21.0f, 21.0f, 22.0f, 22.0f, + 22.0f, 23.0f, 23.0f, 24.0f, 24.0f, 24.0f, 25.0f, 25.0f}; + + test.AddOutput("Y", {N, C, H, 64}, Y); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", ExcludeTrtOnA100()); +} + +// Test round_prefer_ceil with half_pixel for a small upsample (2x2 -> 7x8). +// Verifies that at positive .5 boundaries, ceiling is preferred. +TEST(ResizeOpTest, ResizeOpNearestUpSample_RoundPreferCeil_HalfPixel_2x2to7x8) { + OpTester test("Resize", 13); + + std::vector roi{}; + std::vector scales{}; + std::vector sizes{1, 1, 7, 8}; + + test.AddAttribute("mode", "nearest"); + test.AddAttribute("coordinate_transformation_mode", "half_pixel"); + test.AddAttribute("nearest_mode", "round_prefer_ceil"); + + constexpr int64_t N = 1, C = 1, H = 2, W = 2; + std::vector X = {1.0f, 2.0f, 3.0f, 4.0f}; + + test.AddInput("X", {N, C, H, W}, X); + test.AddInput("roi", {0}, roi); + test.AddInput("", {0}, scales); + test.AddInput("sizes", {4}, sizes); + + // half_pixel: x_orig = (x_resized + 0.5) / scale - 0.5 + // H scale = 7/2 = 3.5, W scale = 8/2 = 4.0 + // H coords: i=0: -0.357, i=1: -0.071, i=2: 0.214, i=3: 0.5, i=4: 0.786, i=5: 1.071, i=6: 1.357 + // round_prefer_ceil at 0.5 -> ceil(0.5) = 1 + // W coords: i=0: -0.375, i=1: -0.125, i=2: 0.125, i=3: 0.375, i=4: 0.625, i=5: 0.875, i=6: 1.125, i=7: 1.375 + std::vector Y = {1.0f, 1.0f, 1.0f, 1.0f, 2.0f, 2.0f, 2.0f, 2.0f, + 1.0f, 1.0f, 1.0f, 1.0f, 2.0f, 2.0f, 2.0f, 2.0f, + 1.0f, 1.0f, 1.0f, 1.0f, 2.0f, 2.0f, 2.0f, 2.0f, + 3.0f, 3.0f, 3.0f, 3.0f, 4.0f, 4.0f, 4.0f, 4.0f, + 3.0f, 3.0f, 3.0f, 3.0f, 4.0f, 4.0f, 4.0f, 4.0f, + 3.0f, 3.0f, 3.0f, 3.0f, 4.0f, 4.0f, 4.0f, 4.0f, + 3.0f, 3.0f, 3.0f, 3.0f, 4.0f, 4.0f, 4.0f, 4.0f}; + + test.AddOutput("Y", {N, C, sizes[2], sizes[3]}, Y); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", ExcludeTrtOnA100()); +} + TEST(ResizeOpTest, ResizeOpNearest_OneToOneMappingBetweenInputAndOutputDataDims) { // TODO: Unskip when fixed #41968513 if (DefaultDmlExecutionProvider().get() != nullptr) { From a2373233602a66a6748f11d27c0502f5e5000487 Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Tue, 5 May 2026 13:54:33 -0700 Subject: [PATCH 18/34] Add CUDA plugin EP Python package pipeline (#28299) ### Description Add Python wheel packaging support for the CUDA plugin EP, following the WebGPU plugin EP packaging pattern from #28226. Changes include: - Add `plugin-ep-cuda/python` packaging sources for the `onnxruntime-ep-cuda` wheel. - Add helper APIs to locate/register the CUDA plugin EP shared library. - Add Linux and Windows x64 Python package jobs that consume the CUDA plugin binary artifacts. - Extend plugin package version setup to emit a PEP 440-compatible `PluginPythonPackageVersion`. - Add a Linux Docker helper script to build the CUDA plugin Python wheel inside the manylinux CUDA image. ### Validation - Parsed touched Azure pipeline YAML files with PyYAML. - Ran Python syntax checks for the new package helper and wheel builder. ### Notes The Linux Python package job is limited to x64 for now, matching the existing x64 plugin artifact packaging flow. --------- Signed-off-by: Jonathan Clohessy Signed-off-by: bfilipek Signed-off-by: dependabot[bot] Signed-off-by: Christian Bourjau Co-authored-by: Copilot Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Co-authored-by: Xiaoxi Han Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> Co-authored-by: Dmitri Smirnov Co-authored-by: Copilot <198982749+Copilot@users.noreply.github.com> Co-authored-by: Ankit Maheshkar Co-authored-by: Ryan Metcalfe <107415876+RyanMetcalfeInt8@users.noreply.github.com> Co-authored-by: jatinwadhwa921 <110383850+jatinwadhwa921@users.noreply.github.com> Co-authored-by: Jaswanth Gannamaneni Co-authored-by: Klimenko, Mikhail Co-authored-by: Vishnudas Thaniel S Co-authored-by: n1harika Co-authored-by: TejalKhade28 Co-authored-by: Preetha Veeramalai Co-authored-by: liang Co-authored-by: Javier Martinez Co-authored-by: MayureshV1 <47039074+MayureshV1@users.noreply.github.com> Co-authored-by: sfatimar Co-authored-by: Garth Long Co-authored-by: Eric Crawford Co-authored-by: derdeljan-msft Co-authored-by: Jonathan Clohessy Co-authored-by: Akshay Sonawane <111780983+apsonawane@users.noreply.github.com> Co-authored-by: Christopher Warrington Co-authored-by: Ishwar Raut Co-authored-by: Gaurav Garg Co-authored-by: Xinpeng Dou <15529241576@163.com> Co-authored-by: Chi Lo <54722500+chilo-ms@users.noreply.github.com> Co-authored-by: adrastogi Co-authored-by: Aditya Rastogi Co-authored-by: qti-hungjuiw Co-authored-by: qti-yuduo Co-authored-by: Pradeep Sakhamoori Co-authored-by: Adam Pocock Co-authored-by: Changming Sun Co-authored-by: mingyue <131847423+mingyueliuh@users.noreply.github.com> Co-authored-by: Edward Chen <18449977+edgchen1@users.noreply.github.com> Co-authored-by: Susanta Bhattacharjee Co-authored-by: jatinwadhwa921 Co-authored-by: Jozef Wludzik Co-authored-by: Bartlomiej Filipek Co-authored-by: Kotomi-Du Co-authored-by: Rajeev Sekar Co-authored-by: Mayuresh M Varerkar Co-authored-by: Mikhail Dvoretckii Co-authored-by: bopeng1234 Co-authored-by: fs-eire <7679871+fs-eire@users.noreply.github.com> Co-authored-by: Wenqin Yang Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> Co-authored-by: xieofxie Co-authored-by: hualxie Co-authored-by: Jiajia Qin Co-authored-by: Joshua Lochner Co-authored-by: Christian Bourjau Co-authored-by: Xiaofei Han Co-authored-by: chunghow-qti Co-authored-by: Guenther Schmuelling Co-authored-by: Jiawei Shao Co-authored-by: czekun Co-authored-by: Ryan Metcalfe Co-authored-by: Jaskaran Singh Nagi Co-authored-by: ai-fw-intg Co-authored-by: Rajeev Sekar Co-authored-by: RajeevSekar <117911837+RajeevSekar@users.noreply.github.com> Co-authored-by: Nazanin Beheshti Co-authored-by: Copilot Autofix powered by AI <62310815+github-advanced-security[bot]@users.noreply.github.com> --- plugin-ep-cuda/MIN_ONNXRUNTIME_VERSION | 1 + plugin-ep-cuda/README.md | 30 +++ plugin-ep-cuda/python/README.md | 23 +++ plugin-ep-cuda/python/build_wheel.py | 168 +++++++++++++++++ .../python/onnxruntime_ep_cuda/README.md | 17 ++ .../python/onnxruntime_ep_cuda/__init__.py | 38 ++++ plugin-ep-cuda/python/pyproject.toml.in | 20 ++ .../python/requirements-build-wheel.txt | 4 + plugin-ep-cuda/python/setup.py | 21 +++ .../python/test/test_cuda_plugin_ep.py | 171 ++++++++++++++++++ .../azure-pipelines/plugin-cuda-pipeline.yml | 2 + .../plugin-cuda-test-pipeline.yml | 98 ++++++++++ .../stages/plugin-cuda-packaging-stage.yml | 10 +- .../stages/plugin-linux-cuda-stage.yml | 81 ++++++++- .../stages/plugin-linux-cuda-test-stage.yml | 74 ++++++++ .../stages/plugin-win-cuda-stage.yml | 93 ++++++++-- .../stages/plugin-win-cuda-test-stage.yml | 72 ++++++++ .../stages/plugin-win-webgpu-stage.yml | 2 + .../set-plugin-build-variables-step.yml | 103 +---------- .../github/linux/build_cuda_plugin_package.sh | 1 + .../linux/build_cuda_plugin_python_package.sh | 50 +++++ .../docker/Dockerfile.manylinux2_28_cuda | 4 +- .../scripts/manylinux/install_centos.sh | 6 +- tools/ci_build/set_plugin_build_variables.py | 125 +++++++++++++ 24 files changed, 1093 insertions(+), 121 deletions(-) create mode 100644 plugin-ep-cuda/MIN_ONNXRUNTIME_VERSION create mode 100644 plugin-ep-cuda/README.md create mode 100644 plugin-ep-cuda/python/README.md create mode 100644 plugin-ep-cuda/python/build_wheel.py create mode 100644 plugin-ep-cuda/python/onnxruntime_ep_cuda/README.md create mode 100644 plugin-ep-cuda/python/onnxruntime_ep_cuda/__init__.py create mode 100644 plugin-ep-cuda/python/pyproject.toml.in create mode 100644 plugin-ep-cuda/python/requirements-build-wheel.txt create mode 100644 plugin-ep-cuda/python/setup.py create mode 100644 plugin-ep-cuda/python/test/test_cuda_plugin_ep.py create mode 100644 tools/ci_build/github/azure-pipelines/plugin-cuda-test-pipeline.yml create mode 100644 tools/ci_build/github/azure-pipelines/stages/plugin-linux-cuda-test-stage.yml create mode 100644 tools/ci_build/github/azure-pipelines/stages/plugin-win-cuda-test-stage.yml create mode 100755 tools/ci_build/github/linux/build_cuda_plugin_python_package.sh create mode 100644 tools/ci_build/set_plugin_build_variables.py diff --git a/plugin-ep-cuda/MIN_ONNXRUNTIME_VERSION b/plugin-ep-cuda/MIN_ONNXRUNTIME_VERSION new file mode 100644 index 0000000000000..bc584045a3db0 --- /dev/null +++ b/plugin-ep-cuda/MIN_ONNXRUNTIME_VERSION @@ -0,0 +1 @@ +1.26.0 \ No newline at end of file diff --git a/plugin-ep-cuda/README.md b/plugin-ep-cuda/README.md new file mode 100644 index 0000000000000..0dc8c32904820 --- /dev/null +++ b/plugin-ep-cuda/README.md @@ -0,0 +1,30 @@ +# CUDA Plugin Execution Provider + +Packaging sources for the ONNX Runtime CUDA plugin Execution Provider (EP), distributed as a standalone artifact that +plugs into an existing ONNX Runtime installation rather than being built into the main `onnxruntime` binary. + +For more information about plugin EPs, see the documentation +[here](https://onnxruntime.ai/docs/execution-providers/plugin-ep-libraries/). + +## Contents + +- [`MIN_ONNXRUNTIME_VERSION`](MIN_ONNXRUNTIME_VERSION) - Minimum compatible ONNX Runtime version for the Python package. +- [`python/`](python/) - Sources and build script for the `onnxruntime-ep-cuda12`/`onnxruntime-ep-cuda13` Python wheels. + +## Usage + +Install the CUDA-family-specific Python distribution, then register the plugin EP at runtime. The package names are +`onnxruntime-ep-cuda12` for CUDA 12.x builds and `onnxruntime-ep-cuda13` for CUDA 13.x builds. Both distributions expose +the same Python import module, `onnxruntime_ep_cuda`. + +```python +import onnxruntime as ort +import onnxruntime_ep_cuda as cuda_ep + +ort.register_execution_provider_library(cuda_ep.get_ep_name(), cuda_ep.get_library_path()) + +devices = [d for d in ort.get_ep_devices() if d.ep_name == cuda_ep.get_ep_name()] +sess_options = ort.SessionOptions() +sess_options.add_provider_for_devices(devices, {}) +session = ort.InferenceSession("model.onnx", sess_options=sess_options) +``` diff --git a/plugin-ep-cuda/python/README.md b/plugin-ep-cuda/python/README.md new file mode 100644 index 0000000000000..5edf67540f5d0 --- /dev/null +++ b/plugin-ep-cuda/python/README.md @@ -0,0 +1,23 @@ +# CUDA Plugin EP Python Package + +This directory contains the packaging source for the CUDA plugin EP Python packages: + +- `onnxruntime-ep-cuda12` for CUDA 12.x builds +- `onnxruntime-ep-cuda13` for CUDA 13.x builds + +Both distributions install the same import module, `onnxruntime_ep_cuda`. + +## Building the wheel + +Wheels are built via `build_wheel.py`. Running `pip install` or `pip wheel` directly against this directory is not +supported because the source tree contains `pyproject.toml.in` instead of a concrete `pyproject.toml`. + +```bash +python build_wheel.py \ + --binary_dir \ + --version \ + --package_name \ + --output_dir +``` + +The script combines pre-built CUDA plugin EP binaries with the package source to produce a platform-specific wheel. diff --git a/plugin-ep-cuda/python/build_wheel.py b/plugin-ep-cuda/python/build_wheel.py new file mode 100644 index 0000000000000..a709fd06d3904 --- /dev/null +++ b/plugin-ep-cuda/python/build_wheel.py @@ -0,0 +1,168 @@ +#!/usr/bin/env python3 +"""Build a wheel for the onnxruntime-ep-cuda12 or onnxruntime-ep-cuda13 package.""" + +import argparse +import platform +import re +import shutil +import subprocess +import sys +import tempfile +from pathlib import Path + +SCRIPT_DIR = Path(__file__).parent +MIN_ONNXRUNTIME_VERSION_FILE = SCRIPT_DIR.parent / "MIN_ONNXRUNTIME_VERSION" + +_TEMPLATE_VARIABLE_PATTERN = re.compile(r"@(\w+)@") +BINARY_PATTERNS = [ + "onnxruntime_providers_cuda_plugin.dll", + "libonnxruntime_providers_cuda_plugin.so", +] +AUDITWHEEL_EXCLUDE = [ + "libcuda.so.1", + "libcublas.so.12", + "libcublas.so.13", + "libcublasLt.so.12", + "libcublasLt.so.13", + "libcudart.so.12", + "libcudart.so.13", + "libcudnn.so.9", + "libcufft.so.11", + "libcufft.so.12", + "libnvJitLink.so.12", + "libnvJitLink.so.13", + "libnvrtc.so.12", + "libnvrtc.so.13", + "libnvrtc-builtins.so.12", + "libnvrtc-builtins.so.13", +] + + +def gen_file_from_template(template_file: Path, output_file: Path, variable_substitutions: dict[str, str]) -> None: + content = template_file.read_text(encoding="utf-8") + variables_in_file: set[str] = set() + + def replace(match: re.Match[str]) -> str: + name = match.group(1) + variables_in_file.add(name) + return variable_substitutions.get(name, match.group(0)) + + content = _TEMPLATE_VARIABLE_PATTERN.sub(replace, content) + if variables_in_file != variable_substitutions.keys(): + provided = set(variable_substitutions.keys()) + raise ValueError( + f"Template variables and substitution keys do not match for {template_file}. " + f"Only in template: {sorted(variables_in_file - provided)}. " + f"Only in substitutions: {sorted(provided - variables_in_file)}." + ) + + output_file.write_text(content, encoding="utf-8") + + +def prepare_staging_dir(staging_dir: Path, binary_dir: Path, version: str, package_name: str) -> None: + staging_dir.mkdir(parents=True, exist_ok=True) + shutil.copy2(SCRIPT_DIR / "setup.py", staging_dir / "setup.py") + shutil.copytree(SCRIPT_DIR / "onnxruntime_ep_cuda", staging_dir / "onnxruntime_ep_cuda") + + package_dir = staging_dir / "onnxruntime_ep_cuda" + copied = [] + for pattern in BINARY_PATTERNS: + for src in binary_dir.glob(pattern): + dst = package_dir / src.name + print(f"Copying {src} -> {dst}") + shutil.copy2(src, dst) + copied.append(dst) + if not copied: + raise FileNotFoundError(f"No plugin binaries found in {binary_dir}. Looked for: {BINARY_PATTERNS}") + + min_ort_version = MIN_ONNXRUNTIME_VERSION_FILE.read_text(encoding="utf-8").strip() + if not min_ort_version: + raise ValueError(f"{MIN_ONNXRUNTIME_VERSION_FILE} is empty") + + gen_file_from_template( + SCRIPT_DIR / "pyproject.toml.in", + staging_dir / "pyproject.toml", + {"package_name": package_name, "version": version, "min_onnxruntime_version": min_ort_version}, + ) + + +def build_wheel(source_dir: Path, wheel_dir: Path) -> None: + wheel_dir.mkdir(parents=True, exist_ok=True) + cmd = [ + sys.executable, + "-m", + "pip", + "wheel", + str(source_dir), + "--wheel-dir", + str(wheel_dir), + "--no-deps", + "--no-build-isolation", + ] + print(f"Running: {' '.join(cmd)}") + subprocess.check_call(cmd) + + +def auditwheel_repair(wheel_dir: Path, wheel_name_prefix: str) -> None: + if platform.system() != "Linux": + return + + original_wheels = list(wheel_dir.glob(f"{wheel_name_prefix}-*.whl")) + if not original_wheels: + raise RuntimeError(f"No wheel found in {wheel_dir} to repair with auditwheel") + + with tempfile.TemporaryDirectory() as repaired_dir_name: + repaired_dir = Path(repaired_dir_name) + for wheel in original_wheels: + cmd = [sys.executable, "-m", "auditwheel", "repair", str(wheel), "--wheel-dir", str(repaired_dir)] + for lib in AUDITWHEEL_EXCLUDE: + cmd.extend(["--exclude", lib]) + print(f"Running: {' '.join(cmd)}") + subprocess.check_call(cmd) + wheel.unlink() + + repaired_wheels = list(repaired_dir.glob("*.whl")) + if not repaired_wheels: + raise RuntimeError(f"auditwheel repair produced no wheels in {repaired_dir}") + + for repaired_wheel in repaired_wheels: + repaired_wheel.replace(wheel_dir / repaired_wheel.name) + + +def collect_wheels(wheel_dir: Path, output_dir: Path, wheel_name_prefix: str) -> None: + wheels = list(wheel_dir.glob(f"{wheel_name_prefix}-*.whl")) + if not wheels: + raise RuntimeError("No wheel was produced") + output_dir.mkdir(parents=True, exist_ok=True) + for wheel in wheels: + dest = output_dir / wheel.name + shutil.copy2(wheel, dest) + print(f"Built wheel: {dest}") + + +def main() -> None: + parser = argparse.ArgumentParser(description="Build onnxruntime-ep-cuda wheel") + parser.add_argument("--binary_dir", required=True, type=Path, help="Directory containing built plugin EP binaries") + parser.add_argument("--version", required=True, help="Package version string (PEP 440 format)") + parser.add_argument("--package_name", required=True, help="Python distribution name to write into pyproject.toml") + parser.add_argument("--output_dir", required=True, type=Path, help="Directory to place the built wheel") + args = parser.parse_args() + + if not args.binary_dir.is_dir(): + raise FileNotFoundError(f"Binary directory does not exist: {args.binary_dir}") + if not re.fullmatch(r"[A-Za-z0-9][A-Za-z0-9._-]*", args.package_name): + raise ValueError(f"Invalid package name: {args.package_name}") + + wheel_name_prefix = args.package_name.replace("-", "_").replace(".", "_") + + with tempfile.TemporaryDirectory(prefix="ort_cuda_wheel_") as tmp: + staging_dir = Path(tmp) / "package" + wheel_dir = Path(tmp) / "wheels" + prepare_staging_dir(staging_dir, args.binary_dir, args.version, args.package_name) + build_wheel(staging_dir, wheel_dir) + auditwheel_repair(wheel_dir, wheel_name_prefix) + collect_wheels(wheel_dir, args.output_dir, wheel_name_prefix) + + +if __name__ == "__main__": + main() diff --git a/plugin-ep-cuda/python/onnxruntime_ep_cuda/README.md b/plugin-ep-cuda/python/onnxruntime_ep_cuda/README.md new file mode 100644 index 0000000000000..167ff50801d87 --- /dev/null +++ b/plugin-ep-cuda/python/onnxruntime_ep_cuda/README.md @@ -0,0 +1,17 @@ +# ONNX Runtime CUDA Plugin Execution Provider + +CUDA Execution Provider plugin for ONNX Runtime. Install alongside `onnxruntime` to enable the CUDA plugin EP. + +## Usage + +```python +import onnxruntime as ort +import onnxruntime_ep_cuda as cuda_ep + +ort.register_execution_provider_library(cuda_ep.get_ep_name(), cuda_ep.get_library_path()) + +devices = [d for d in ort.get_ep_devices() if d.ep_name == cuda_ep.get_ep_name()] +sess_options = ort.SessionOptions() +sess_options.add_provider_for_devices(devices, {}) +session = ort.InferenceSession("model.onnx", sess_options=sess_options) +``` \ No newline at end of file diff --git a/plugin-ep-cuda/python/onnxruntime_ep_cuda/__init__.py b/plugin-ep-cuda/python/onnxruntime_ep_cuda/__init__.py new file mode 100644 index 0000000000000..8e0e29c810433 --- /dev/null +++ b/plugin-ep-cuda/python/onnxruntime_ep_cuda/__init__.py @@ -0,0 +1,38 @@ +"""ONNX Runtime CUDA Plugin Execution Provider Python package.""" + +from __future__ import annotations + +import pathlib + +__all__ = [ + "get_ep_name", + "get_ep_names", + "get_library_path", +] + +_module_dir = pathlib.Path(__file__).parent + + +def get_library_path() -> str: + """Return the path to the CUDA plugin EP shared library.""" + candidate_paths = [ + _module_dir / "onnxruntime_providers_cuda_plugin.dll", + _module_dir / "libonnxruntime_providers_cuda_plugin.so", + ] + paths = [p for p in candidate_paths if p.is_file()] + if len(paths) != 1: + raise RuntimeError( + f"Expected exactly one CUDA plugin EP library in {_module_dir}, " + f"found {len(paths)}: {[p.name for p in paths]}" + ) + return str(paths[0]) + + +def get_ep_name() -> str: + """Return the CUDA plugin Execution Provider name.""" + return "CudaPluginExecutionProvider" + + +def get_ep_names() -> list[str]: + """Return a list of EP names provided by this plugin.""" + return [get_ep_name()] diff --git a/plugin-ep-cuda/python/pyproject.toml.in b/plugin-ep-cuda/python/pyproject.toml.in new file mode 100644 index 0000000000000..dfca37783d7ed --- /dev/null +++ b/plugin-ep-cuda/python/pyproject.toml.in @@ -0,0 +1,20 @@ +[build-system] +requires = ["setuptools>=68.0", "wheel"] +build-backend = "setuptools.build_meta" + +[project] +name = "@package_name@" +version = "@version@" +description = "ONNX Runtime CUDA Plugin Execution Provider" +readme = "onnxruntime_ep_cuda/README.md" +license = {text = "MIT"} +requires-python = ">=3.11" +dependencies = [ + "onnxruntime>=@min_onnxruntime_version@", +] + +[tool.setuptools.packages.find] +include = ["onnxruntime_ep_cuda*"] + +[tool.setuptools.package-data] +onnxruntime_ep_cuda = ["*.dll", "*.so", "*.so.*"] diff --git a/plugin-ep-cuda/python/requirements-build-wheel.txt b/plugin-ep-cuda/python/requirements-build-wheel.txt new file mode 100644 index 0000000000000..eb72ee3b67d27 --- /dev/null +++ b/plugin-ep-cuda/python/requirements-build-wheel.txt @@ -0,0 +1,4 @@ +setuptools>=68.0 +wheel +auditwheel; sys_platform == "linux" +patchelf; sys_platform == "linux" \ No newline at end of file diff --git a/plugin-ep-cuda/python/setup.py b/plugin-ep-cuda/python/setup.py new file mode 100644 index 0000000000000..7b1968dbc847a --- /dev/null +++ b/plugin-ep-cuda/python/setup.py @@ -0,0 +1,21 @@ +"""Minimal setup.py to produce a platform-specific wheel.""" + +from setuptools import setup +from setuptools.dist import Distribution +from wheel.bdist_wheel import bdist_wheel + + +class PlatformBdistWheel(bdist_wheel): + """Override wheel tags to py3-none-{platform}.""" + + def get_tag(self): + _, _, plat = super().get_tag() + return "py3", "none", plat + + +class BinaryDistribution(Distribution): + def has_ext_modules(self): + return True + + +setup(distclass=BinaryDistribution, cmdclass={"bdist_wheel": PlatformBdistWheel}) diff --git a/plugin-ep-cuda/python/test/test_cuda_plugin_ep.py b/plugin-ep-cuda/python/test/test_cuda_plugin_ep.py new file mode 100644 index 0000000000000..885faeb56daf6 --- /dev/null +++ b/plugin-ep-cuda/python/test/test_cuda_plugin_ep.py @@ -0,0 +1,171 @@ +#!/usr/bin/env python3 +"""Smoke test for the onnxruntime-ep-cuda Python package. + +Tests: +1. Package import and library path resolution +2. EP registration with ONNX Runtime +3. Device discovery +4. Inference with a simple Mul model (requires CUDA-capable hardware) + +The inference test is skipped gracefully if no CUDA device is available +(e.g., on CPU-only build agents). +""" + +import os +import platform +import sys +import tempfile +import traceback +from pathlib import Path + +import numpy as np +import onnx + +import onnxruntime as ort + +VERBOSE = os.environ.get("ORT_TEST_VERBOSE", "").strip().lower() in ("1", "true", "yes") + + +def debug_print(*args, **kwargs): + """Print only when ORT_TEST_VERBOSE is set to a truthy value.""" + if VERBOSE: + print(*args, **kwargs) + + +def create_mul_model(output_dir: Path) -> Path: + """Create a simple Mul model in `output_dir` and return the path to the saved .onnx file.""" + x = onnx.helper.make_tensor_value_info("x", onnx.TensorProto.FLOAT, [2, 3]) + y = onnx.helper.make_tensor_value_info("y", onnx.TensorProto.FLOAT, [2, 3]) + z = onnx.helper.make_tensor_value_info("z", onnx.TensorProto.FLOAT, [2, 3]) + + mul_node = onnx.helper.make_node("Mul", inputs=["x", "y"], outputs=["z"]) + + graph = onnx.helper.make_graph([mul_node], "mul_graph", [x, y], [z]) + model = onnx.helper.make_model(graph, opset_imports=[onnx.helper.make_opsetid("", 13)]) + model.ir_version = 7 + + model_path = output_dir / "mul.onnx" + onnx.save(model, str(model_path)) + return model_path + + +def print_environment_info(): + """Print diagnostic information about the runtime environment.""" + print(f" Python: {sys.version}") + print(f" Platform: {platform.platform()}") + print(f" Architecture: {platform.machine()}") + print(f" ONNX Runtime version: {ort.__version__}") + print(f" ONNX Runtime location: {ort.__file__}") + print(f" Available providers (built-in): {ort.get_available_providers()}") + # Print relevant environment variables + for var in sorted(os.environ): + lower = var.lower() + if any(kw in lower for kw in ["onnx", "ort", "gpu", "cuda", "nv", "path", "ld_library"]): + print(f" ENV {var}={os.environ[var]}") + + +def test_import_and_library_path(): + """Test that the package imports and the library path is valid.""" + import onnxruntime_ep_cuda as cuda_ep # noqa: PLC0415 + + debug_print(f" Package location: {cuda_ep.__file__}") + pkg_dir = Path(cuda_ep.__file__).parent + debug_print(f" Package directory contents: {sorted(p.name for p in pkg_dir.iterdir())}") + + lib_path = cuda_ep.get_library_path() + assert Path(lib_path).is_file(), f"Library path does not exist: {lib_path}" + print(f"OK: Library path: {lib_path}") + + ep_name = cuda_ep.get_ep_name() + assert ep_name == "CudaPluginExecutionProvider", f"Unexpected EP name: {ep_name}" + print(f"OK: EP name: {ep_name}") + + ep_names = cuda_ep.get_ep_names() + assert ep_names == ["CudaPluginExecutionProvider"], f"Unexpected EP names: {ep_names}" + print(f"OK: EP names: {ep_names}") + + +def test_registration_and_inference(): + """Test EP registration, device discovery, and inference.""" + import onnxruntime_ep_cuda as cuda_ep # noqa: PLC0415 + + lib_path = cuda_ep.get_library_path() + ep_name = cuda_ep.get_ep_name() + registration_name = "cuda_plugin_test" + + # Register the plugin EP + debug_print(f" Registering library: {lib_path}") + debug_print(f" Library file size: {Path(lib_path).stat().st_size} bytes") + ort.register_execution_provider_library(registration_name, lib_path) + print(f"OK: Registered EP library as '{registration_name}'") + + try: + # Discover devices + all_devices = ort.get_ep_devices() + debug_print(f" All devices: {[(d.ep_name, getattr(d, 'device_id', 'N/A')) for d in all_devices]}") + cuda_devices = [d for d in all_devices if d.ep_name == ep_name] + print(f"Found {len(cuda_devices)} CUDA plugin device(s)") + + if not cuda_devices: + print("SKIP: No CUDA plugin devices available — skipping inference test") + return + + # Create session with CUDA plugin EP + sess_options = ort.SessionOptions() + sess_options.add_session_config_entry("session.disable_cpu_ep_fallback", "1") + sess_options.add_provider_for_devices(cuda_devices, {}) + assert sess_options.has_providers(), "SessionOptions should have providers after add_provider_for_devices" + print("OK: Session options configured with CUDA plugin EP") + + with tempfile.TemporaryDirectory() as model_dir: + model_path = create_mul_model(Path(model_dir)) + debug_print(f" Model path: {model_path}") + sess = ort.InferenceSession(str(model_path), sess_options=sess_options) + debug_print(f" Session providers: {sess.get_providers()}") + print("OK: InferenceSession created") + + # Run inference + x = np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], dtype=np.float32) + y = np.array([[2.0, 3.0, 4.0], [5.0, 6.0, 7.0]], dtype=np.float32) + expected = x * y + + outputs = sess.run(None, {"x": x, "y": y}) + result = outputs[0] + + np.testing.assert_allclose(result, expected, rtol=1e-5, atol=1e-5) + print("OK: Inference result matches expected output") + + del sess + print("OK: Session released") + + finally: + ort.unregister_execution_provider_library(registration_name) + print(f"OK: Unregistered EP library '{registration_name}'") + + +def main(): + print("=== CUDA Plugin EP Python Package Test ===") + + if VERBOSE: + # Set verbose ORT logging so ORT internals are visible in CI logs + ort.set_default_logger_severity(0) + + print("\n--- Environment ---") + print_environment_info() + + print("\n--- Test 1: Import and library path ---") + test_import_and_library_path() + + print("\n--- Test 2: Registration and inference ---") + test_registration_and_inference() + + print("\n=== All tests passed ===") + + +if __name__ == "__main__": + try: + main() + except Exception as e: + print(f"\nFAILED: {e}", file=sys.stderr) + traceback.print_exc() + sys.exit(1) diff --git a/tools/ci_build/github/azure-pipelines/plugin-cuda-pipeline.yml b/tools/ci_build/github/azure-pipelines/plugin-cuda-pipeline.yml index 5e183e057aee9..4385446c6b741 100644 --- a/tools/ci_build/github/azure-pipelines/plugin-cuda-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/plugin-cuda-pipeline.yml @@ -130,9 +130,11 @@ extends: version_file: ${{ variables.epVersionFile }} cmake_build_type: ${{ parameters.cmake_build_type }} ${{ if eq(parameters.cuda_version, '12.8') }}: + python_package_name: 'onnxruntime-ep-cuda12' docker_base_image: 'onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cuda12_x64_almalinux8_gcc14:20251017.1' cmake_cuda_archs: '52-real;61-real;75-real;86-real;89-real;90-virtual' ${{ if eq(parameters.cuda_version, '13.0') }}: + python_package_name: 'onnxruntime-ep-cuda13' docker_base_image: 'onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cuda13_x64_almalinux8_gcc14:20251107.1' docker_base_image_aarch64: 'onnxruntimebuildcache.azurecr.io/public/azureml/onnxruntime_build_cuda13_aarch64_almalinux9_gcc14:20260323.1' cmake_cuda_archs: '75-real;80-real;86-real;89-real;90-real;100-real;120-real;120-virtual' diff --git a/tools/ci_build/github/azure-pipelines/plugin-cuda-test-pipeline.yml b/tools/ci_build/github/azure-pipelines/plugin-cuda-test-pipeline.yml new file mode 100644 index 0000000000000..83273c8870408 --- /dev/null +++ b/tools/ci_build/github/azure-pipelines/plugin-cuda-test-pipeline.yml @@ -0,0 +1,98 @@ +# This pipeline runs tests against artifacts produced by the CUDA +# plugin packaging pipeline. It is resource-triggered on successful +# packaging runs and can also be queued manually against any prior +# packaging run. +# +# Split from the packaging pipeline so the test side can be iterated +# on without rebuilding the CUDA plugin from source. + +trigger: none + +variables: +- name: DisableDockerDetector + value: true +- name: skipNugetSecurityAnalysis + value: true +- name: Codeql.SkipTaskAutoInjection + value: true + +resources: + pipelines: + - pipeline: build + source: 'CUDA Plugin EP Packaging Pipeline' + trigger: true + repositories: + - repository: 1esPipelines + type: git + name: 1ESPipelineTemplates/1ESPipelineTemplates + ref: refs/tags/release + +parameters: +- name: test_windows_x64 + displayName: 'Test Windows x64' + type: boolean + default: true + +- name: test_linux_x64 + displayName: 'Test Linux x64' + type: boolean + default: true + +- name: cuda_version + displayName: 'CUDA Version' + type: string + default: '12.8' + values: + - '12.8' + - '13.0' + +extends: + # The pipeline extends the 1ES PT which will inject SDL and compliance + # tasks. Uses "Official" to stay consistent with the companion + # CUDA plugin packaging pipeline. + template: v1/1ES.Official.PipelineTemplate.yml@1esPipelines + parameters: + settings: + networkIsolationPolicy: Permissive + sdl: + # No top-level `pool:` is declared for this pipeline (each stage + # template pins its own pool), so source analysis needs an + # explicit pool. + sourceAnalysisPool: + name: onnxruntime-Win-CPU-VS2022-Latest + os: windows + componentgovernance: + ignoreDirectories: '$(Build.Repository.LocalPath)/cmake/external/emsdk/upstream/emscripten/tests,$(Build.Repository.LocalPath)/cmake/external/onnx/third_party/benchmark,$(Build.Repository.LocalPath)/cmake/external/onnx/third_party/pybind11,$(Build.Repository.LocalPath)/cmake/external/onnx/third_party/pybind11/tests,$(Build.Repository.LocalPath)/cmake/external/onnxruntime-extensions,$(Build.Repository.LocalPath)/js/react_native/e2e/node_modules,$(Build.Repository.LocalPath)/js/node_modules,$(Build.Repository.LocalPath)/onnxruntime-inference-examples,$(Build.SourcesDirectory)/cmake/external/emsdk/upstream/emscripten/tests,$(Build.SourcesDirectory)/cmake/external/onnx/third_party/benchmark,$(Build.SourcesDirectory)/cmake/external/onnx/third_party/pybind11,$(Build.SourcesDirectory)/cmake/external/onnx/third_party/pybind11/tests,$(Build.SourcesDirectory)/cmake/external/onnxruntime-extensions,$(Build.SourcesDirectory)/js/react_native/e2e/node_modules,$(Build.SourcesDirectory)/js/node_modules,$(Build.SourcesDirectory)/onnxruntime-inference-examples,$(Build.BinariesDirectory)' + alertWarningLevel: High + failOnAlert: false + verbosity: Normal + timeout: 3600 + tsa: + enabled: true + # codeSignValidation is intentionally omitted: this pipeline does + # not produce or publish binaries. The wheels it consumes were + # already signed-and-validated by the packaging pipeline. + policheck: + enabled: true + exclusionsFile: '$(Build.SourcesDirectory)\tools\ci_build\policheck_exclusions.xml' + codeql: + compiled: + enabled: false + justificationForDisabling: 'CodeQL is taking nearly 6 hours resulting in timeouts in our production pipelines' + + stages: + # Windows x64 + - ${{ if eq(parameters.test_windows_x64, true) }}: + - template: stages/plugin-win-cuda-test-stage.yml + parameters: + cuda_version: ${{ parameters.cuda_version }} + + # Linux x64 + - ${{ if eq(parameters.test_linux_x64, true) }}: + - template: stages/plugin-linux-cuda-test-stage.yml + parameters: + cuda_version: ${{ parameters.cuda_version }} + ${{ if eq(parameters.cuda_version, '12.8') }}: + docker_base_image: 'onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cuda12_x64_almalinux8_gcc14:20251017.1' + ${{ if eq(parameters.cuda_version, '13.0') }}: + docker_base_image: 'onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cuda13_x64_almalinux8_gcc14:20251107.1' diff --git a/tools/ci_build/github/azure-pipelines/stages/plugin-cuda-packaging-stage.yml b/tools/ci_build/github/azure-pipelines/stages/plugin-cuda-packaging-stage.yml index d18ede02d8891..3ce33f87ae276 100644 --- a/tools/ci_build/github/azure-pipelines/stages/plugin-cuda-packaging-stage.yml +++ b/tools/ci_build/github/azure-pipelines/stages/plugin-cuda-packaging-stage.yml @@ -66,6 +66,11 @@ parameters: displayName: 'Docker Python executable path' default: '/opt/python/cp312-cp312/bin/python3.12' +- name: python_package_name + type: string + displayName: 'Python package distribution name' + default: '' + stages: # Windows x64 - ${{ if eq(parameters.build_windows_x64, true) }}: @@ -75,6 +80,7 @@ stages: cmake_cuda_archs: ${{ parameters.cmake_cuda_archs }} package_version: ${{ parameters.package_type }} version_file: ${{ parameters.version_file }} + python_package_name: ${{ parameters.python_package_name }} cmake_build_type: ${{ parameters.cmake_build_type }} # Linux x64 @@ -83,11 +89,12 @@ stages: parameters: stage_name: Linux_plugin_cuda_x64 arch: 'x64' - machine_pool: 'onnxruntime-Ubuntu2204-AMD-CPU' + machine_pool: 'onnxruntime-Ubuntu2404-AMD-CPU' cuda_version: ${{ parameters.cuda_version }} cmake_cuda_archs: ${{ parameters.cmake_cuda_archs }} package_version: ${{ parameters.package_type }} version_file: ${{ parameters.version_file }} + python_package_name: ${{ parameters.python_package_name }} cmake_build_type: ${{ parameters.cmake_build_type }} docker_base_image: ${{ parameters.docker_base_image }} python_version: ${{ parameters.python_version }} @@ -105,6 +112,7 @@ stages: cmake_cuda_archs: ${{ parameters.cmake_cuda_archs }} package_version: ${{ parameters.package_type }} version_file: ${{ parameters.version_file }} + python_package_name: ${{ parameters.python_package_name }} cmake_build_type: ${{ parameters.cmake_build_type }} docker_base_image: ${{ parameters.docker_base_image_aarch64 }} python_version: ${{ parameters.python_version }} diff --git a/tools/ci_build/github/azure-pipelines/stages/plugin-linux-cuda-stage.yml b/tools/ci_build/github/azure-pipelines/stages/plugin-linux-cuda-stage.yml index 4c6c60e176a50..8992df31bf848 100644 --- a/tools/ci_build/github/azure-pipelines/stages/plugin-linux-cuda-stage.yml +++ b/tools/ci_build/github/azure-pipelines/stages/plugin-linux-cuda-stage.yml @@ -12,7 +12,7 @@ parameters: - name: machine_pool type: string - default: 'onnxruntime-Ubuntu2204-AMD-CPU' + default: 'onnxruntime-Ubuntu2404-AMD-CPU' - name: package_version type: string @@ -38,6 +38,9 @@ parameters: type: string default: '12.8' +- name: python_package_name + type: string + - name: cmake_cuda_archs type: string default: '52-real;61-real;75-real;86-real;89-real;90-virtual' @@ -81,20 +84,25 @@ stages: - template: ../templates/set-nightly-build-option-variable-step.yml + - template: ../templates/setup-feeds-and-python-steps.yml + parameters: + architecture: ${{ parameters.arch }} + - template: ../templates/set-plugin-build-variables-step.yml parameters: package_version: ${{ parameters.package_version }} version_file: ${{ parameters.version_file }} - - template: ../templates/setup-feeds-and-python-steps.yml - parameters: - architecture: ${{ parameters.arch }} - - template: ../templates/get-docker-image-steps.yml parameters: Dockerfile: tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_cuda Context: tools/ci_build/github/linux/docker - DockerBuildArgs: "--build-arg BASEIMAGE=${{ parameters.docker_base_image }} --build-arg BUILD_UID=$( id -u ) --build-arg TRT_VERSION=" + DockerBuildArgs: >- + --network=host + --secret id=PIP_INDEX_URL + --build-arg BASEIMAGE=${{ parameters.docker_base_image }} + --build-arg TRT_VERSION= + --build-arg BUILD_UID=$( id -u ) Repository: onnxruntimecuda${{ replace(parameters.cuda_version, '.', '') }}pluginbuild${{ parameters.arch }} - script: >- @@ -147,6 +155,65 @@ stages: command: publish publishDirectory: '$(Build.BinariesDirectory)/universal_package' vstsFeedPublish: 'PublicPackages/ORT-Nightly' - vstsFeedPackagePublish: 'onnxruntime-plugin-ep-cuda${{ replace(parameters.cuda_version, '.', '') }}-linux-${{ parameters.arch }}' + vstsFeedPackagePublish: "onnxruntime-plugin-ep-cuda${{ replace(parameters.cuda_version, '.', '') }}-linux-${{ parameters.arch }}" versionOption: custom versionPublish: '$(PluginUniversalPackageVersion)' + + - ${{ if eq(parameters.arch, 'x64') }}: + - job: ${{ parameters.stage_name }}_Python_Package + dependsOn: ${{ parameters.stage_name }} + timeoutInMinutes: 60 + workspace: + clean: all + pool: + name: ${{ parameters.machine_pool }} + os: linux + templateContext: + outputs: + - output: pipelineArtifact + targetPath: $(Build.ArtifactStagingDirectory)/python + artifactName: cuda_plugin_python_linux_x64_cuda${{ replace(parameters.cuda_version, '.', '') }} + variables: + - template: ../templates/common-variables.yml + steps: + - checkout: self + clean: true + submodules: none + + - template: ../templates/set-nightly-build-option-variable-step.yml + + - template: ../templates/setup-feeds-and-python-steps.yml + parameters: + architecture: ${{ parameters.arch }} + + - template: ../templates/set-plugin-build-variables-step.yml + parameters: + package_version: ${{ parameters.package_version }} + version_file: ${{ parameters.version_file }} + + - template: ../templates/get-docker-image-steps.yml + parameters: + Dockerfile: tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_cuda + Context: tools/ci_build/github/linux/docker + DockerBuildArgs: >- + --network=host + --secret id=PIP_INDEX_URL + --build-arg BASEIMAGE=${{ parameters.docker_base_image }} + --build-arg TRT_VERSION= + --build-arg BUILD_UID=$( id -u ) + Repository: onnxruntimecuda${{ replace(parameters.cuda_version, '.', '') }}pluginbuild${{ parameters.arch }} + + - task: DownloadPipelineArtifact@2 + displayName: 'Download plugin build artifacts' + inputs: + artifactName: ${{ parameters.artifact_name }} + targetPath: '$(Build.BinariesDirectory)/plugin_artifacts' + + - script: | + set -e -x + $(Build.SourcesDirectory)/tools/ci_build/github/linux/build_cuda_plugin_python_package.sh \ + -i onnxruntimecuda${{ replace(parameters.cuda_version, '.', '') }}pluginbuild${{ parameters.arch }} \ + -p ${{ parameters.docker_python_exe_path }} \ + -v "$(PluginPythonPackageVersion)" \ + -n "${{ parameters.python_package_name }}" + displayName: 'Build Python wheel' diff --git a/tools/ci_build/github/azure-pipelines/stages/plugin-linux-cuda-test-stage.yml b/tools/ci_build/github/azure-pipelines/stages/plugin-linux-cuda-test-stage.yml new file mode 100644 index 0000000000000..391f002465d96 --- /dev/null +++ b/tools/ci_build/github/azure-pipelines/stages/plugin-linux-cuda-test-stage.yml @@ -0,0 +1,74 @@ +parameters: +- name: machine_pool + type: string + default: 'onnxruntime-Ubuntu2404-AMD-GPU-A10' + +- name: cuda_version + type: string + default: '12.8' + +- name: docker_base_image + type: string + default: 'onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cuda12_x64_almalinux8_gcc14:20251017.1' + +stages: +- stage: Linux_plugin_cuda_x64_Test + dependsOn: [] + jobs: + - job: Linux_plugin_cuda_x64_Python_Test + timeoutInMinutes: 60 + workspace: + clean: all + pool: + name: ${{ parameters.machine_pool }} + os: linux + steps: + - checkout: self + clean: true + submodules: none + + - template: ../templates/setup-feeds-and-python-steps.yml + + - template: ../templates/get-docker-image-steps.yml + parameters: + Dockerfile: tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_cuda + Context: tools/ci_build/github/linux/docker + DockerBuildArgs: "--build-arg BASEIMAGE=${{ parameters.docker_base_image }} --build-arg TRT_VERSION= --build-arg BUILD_UID=$( id -u )" + Repository: onnxruntimecuda${{ replace(parameters.cuda_version, '.', '') }}plugintestx64 + + # Download the Python wheel produced by the packaging pipeline run that + # triggered this pipeline (or that was selected at queue time). + - download: build + artifact: cuda_plugin_python_linux_x64_cuda${{ replace(parameters.cuda_version, '.', '') }} + displayName: 'Download Python wheel' + + - script: | + set -e -x + mkdir -p "$(Build.BinariesDirectory)/python_wheel" + cp -R "$(Pipeline.Workspace)/build/cuda_plugin_python_linux_x64_cuda${{ replace(parameters.cuda_version, '.', '') }}/"* "$(Build.BinariesDirectory)/python_wheel/" + displayName: 'Stage Python wheel for test container' + + - script: | + set -e -x + docker run --rm --gpus all \ + --volume "$(Build.SourcesDirectory):/onnxruntime_src" \ + --volume "$(Build.BinariesDirectory):/build" \ + --env "PIP_INDEX_URL=${PIP_INDEX_URL}" \ + --env "NVIDIA_VISIBLE_DEVICES=all" \ + --env "ORT_TEST_VERBOSE=$(System.Debug)" \ + onnxruntimecuda${{ replace(parameters.cuda_version, '.', '') }}plugintestx64 \ + /bin/bash -c " + set -e -x + python3 -m venv /build/test_venv + source /build/test_venv/bin/activate + python3 -m pip install onnxruntime onnx numpy + wheel=\$(find /build/python_wheel -name 'onnxruntime*ep*cuda*.whl' | head -1) + if [ -z \"\$wheel\" ]; then + echo 'ERROR: No matching wheel found in /build/python_wheel' + ls -la /build/python_wheel/ + exit 1 + fi + python3 -m pip install \"\$wheel\" + python3 -u /onnxruntime_src/plugin-ep-cuda/python/test/test_cuda_plugin_ep.py + " + displayName: 'Install and test Python package' diff --git a/tools/ci_build/github/azure-pipelines/stages/plugin-win-cuda-stage.yml b/tools/ci_build/github/azure-pipelines/stages/plugin-win-cuda-stage.yml index 68968a0be86e3..7eac3842514a5 100644 --- a/tools/ci_build/github/azure-pipelines/stages/plugin-win-cuda-stage.yml +++ b/tools/ci_build/github/azure-pipelines/stages/plugin-win-cuda-stage.yml @@ -23,6 +23,9 @@ parameters: type: string default: '12.8' +- name: python_package_name + type: string + - name: cmake_cuda_archs type: string default: '52-real;61-real;75-real;86-real;89-real;90-virtual' @@ -74,6 +77,7 @@ stages: parameters: package_version: ${{ parameters.package_version }} version_file: ${{ parameters.version_file }} + python_command: python - script: | python -m pip install -r "$(Build.SourcesDirectory)\tools\ci_build\github\windows\python\requirements.txt" @@ -81,28 +85,35 @@ stages: env: TMPDIR: "$(Agent.TempDirectory)" - - powershell: | - azcopy.exe cp --recursive "https://lotusscus.blob.core.windows.net/models/cuda_sdk/v${{ parameters.cuda_version }}" "$(Agent.TempDirectory)" + - task: AzureCLI@2 displayName: 'Download CUDA SDK v${{ parameters.cuda_version }}' - env: - AZCOPY_AUTO_LOGIN_TYPE: MSI - AZCOPY_MSI_CLIENT_ID: 63b63039-6328-442f-954b-5a64d124e5b4 + inputs: + azureSubscription: AIInfraBuildOnnxRuntimeOSS + scriptType: 'batch' + scriptLocation: 'inlineScript' + inlineScript: | + set AZCOPY_AUTO_LOGIN_TYPE=AZCLI + azcopy.exe cp --recursive https://lotusscus.blob.core.windows.net/models/cuda_sdk/v${{ parameters.cuda_version }} "$(Agent.TempDirectory)" + # Since CUDA 13.0, CUDA DLLs are in bin\x64 folder instead of bin folder for Windows. - powershell: | Write-Host "Adding CUDA to PATH" - Write-Host "CUDA Path: $(Agent.TempDirectory)\v${{ parameters.cuda_version }}\bin" Write-Host "##vso[task.prependpath]$(Agent.TempDirectory)\v${{ parameters.cuda_version }}\bin" + Write-Host "##vso[task.prependpath]$(Agent.TempDirectory)\v${{ parameters.cuda_version }}\bin\x64" Write-Host "##vso[task.prependpath]$(Agent.TempDirectory)\v${{ parameters.cuda_version }}\extras\CUPTI\lib64" displayName: 'Add CUDA to PATH' # Download cuDNN separately for CUDA 13.0 - ${{ if eq(parameters.cuda_version, '13.0') }}: - - powershell: | - azcopy.exe cp --recursive "https://lotusscus.blob.core.windows.net/models/cudnn_9/9.14.0.64_cuda13" "$(Agent.TempDirectory)" + - task: AzureCLI@2 displayName: 'Download cuDNN for CUDA 13.0' - env: - AZCOPY_AUTO_LOGIN_TYPE: MSI - AZCOPY_MSI_CLIENT_ID: 63b63039-6328-442f-954b-5a64d124e5b4 + inputs: + azureSubscription: AIInfraBuildOnnxRuntimeOSS + scriptType: 'batch' + scriptLocation: 'inlineScript' + inlineScript: | + set AZCOPY_AUTO_LOGIN_TYPE=AZCLI + azcopy.exe cp --recursive https://lotusscus.blob.core.windows.net/models/cudnn_sdk/9.14.0.64_cuda13 "$(Agent.TempDirectory)" # CUDA 12.x build (no separate cuDNN) - ${{ if ne(parameters.cuda_version, '13.0') }}: @@ -130,9 +141,6 @@ stages: --cmake_extra_defines $(PluginEpVersionDefine) $(TelemetryOption) workingDirectory: '$(Build.BinariesDirectory)' - env: - AZCOPY_AUTO_LOGIN_TYPE: MSI - AZCOPY_MSI_CLIENT_ID: 63b63039-6328-442f-954b-5a64d124e5b4 # CUDA 13.0 build (separate cuDNN folder) - ${{ if eq(parameters.cuda_version, '13.0') }}: @@ -161,9 +169,6 @@ stages: --cmake_extra_defines $(PluginEpVersionDefine) $(TelemetryOption) workingDirectory: '$(Build.BinariesDirectory)' - env: - AZCOPY_AUTO_LOGIN_TYPE: MSI - AZCOPY_MSI_CLIENT_ID: 63b63039-6328-442f-954b-5a64d124e5b4 # Esrp signing - template: ../templates/win-esrp-dll.yml @@ -209,6 +214,58 @@ stages: command: publish publishDirectory: '$(Build.BinariesDirectory)\universal_package' vstsFeedPublish: 'PublicPackages/ORT-Nightly' - vstsFeedPackagePublish: 'onnxruntime-plugin-ep-cuda${{ replace(parameters.cuda_version, '.', '') }}-win-x64' + vstsFeedPackagePublish: "onnxruntime-plugin-ep-cuda${{ replace(parameters.cuda_version, '.', '') }}-win-x64" versionOption: custom versionPublish: '$(PluginUniversalPackageVersion)' + + - job: Win_plugin_cuda_x64_Python_Package + dependsOn: Win_plugin_cuda_x64_Build + timeoutInMinutes: 30 + workspace: + clean: all + pool: + name: onnxruntime-Win-CPU-VS2022-Latest + os: windows + templateContext: + outputs: + - output: pipelineArtifact + targetPath: '$(Build.ArtifactStagingDirectory)\python' + artifactName: cuda_plugin_python_win_x64_cuda${{ replace(parameters.cuda_version, '.', '') }} + variables: + - template: ../templates/common-variables.yml + steps: + - checkout: self + clean: true + submodules: none + + - template: ../templates/setup-build-tools.yml + parameters: + host_cpu_arch: 'x64' + python_version: ${{ parameters.python_version }} + + - template: ../templates/set-nightly-build-option-variable-step.yml + + - template: ../templates/set-plugin-build-variables-step.yml + parameters: + package_version: ${{ parameters.package_version }} + version_file: ${{ parameters.version_file }} + python_command: python + + - task: DownloadPipelineArtifact@2 + displayName: 'Download plugin build artifacts' + inputs: + artifactName: cuda_plugin_win_x64 + targetPath: '$(Build.BinariesDirectory)\plugin_artifacts' + + - task: PowerShell@2 + displayName: 'Build Python wheel' + inputs: + targetType: inline + pwsh: true + script: | + python -m pip install -r "$(Build.SourcesDirectory)\plugin-ep-cuda\python\requirements-build-wheel.txt" + python "$(Build.SourcesDirectory)\plugin-ep-cuda\python\build_wheel.py" ` + --binary_dir "$(Build.BinariesDirectory)\plugin_artifacts\bin" ` + --version "$(PluginPythonPackageVersion)" ` + --package_name "${{ parameters.python_package_name }}" ` + --output_dir "$(Build.ArtifactStagingDirectory)\python" diff --git a/tools/ci_build/github/azure-pipelines/stages/plugin-win-cuda-test-stage.yml b/tools/ci_build/github/azure-pipelines/stages/plugin-win-cuda-test-stage.yml new file mode 100644 index 0000000000000..813737ed8ecef --- /dev/null +++ b/tools/ci_build/github/azure-pipelines/stages/plugin-win-cuda-test-stage.yml @@ -0,0 +1,72 @@ +parameters: +- name: cuda_version + type: string + default: '12.8' + +stages: +- stage: Win_plugin_cuda_x64_Test + dependsOn: [] + jobs: + - job: Win_plugin_cuda_x64_Python_Test + timeoutInMinutes: 60 + workspace: + clean: all + pool: + name: onnxruntime-Win2022-GPU-A10 + os: windows + steps: + - checkout: self + clean: true + submodules: none + + - template: ../templates/setup-feeds-and-python-steps.yml + + # Download the Python wheel produced by the packaging pipeline run that + # triggered this pipeline (or that was selected at queue time). + - download: build + artifact: cuda_plugin_python_win_x64_cuda${{ replace(parameters.cuda_version, '.', '') }} + displayName: 'Download Python wheel' + + - task: AzureCLI@2 + displayName: 'Download CUDA SDK v${{ parameters.cuda_version }}' + inputs: + azureSubscription: AIInfraBuildOnnxRuntimeOSS + scriptType: 'batch' + scriptLocation: 'inlineScript' + inlineScript: | + set AZCOPY_AUTO_LOGIN_TYPE=AZCLI + azcopy.exe cp --recursive https://lotusscus.blob.core.windows.net/models/cuda_sdk/v${{ parameters.cuda_version }} "$(Agent.TempDirectory)" + + - powershell: | + Write-Host "Adding CUDA to PATH" + Write-Host "##vso[task.prependpath]$(Agent.TempDirectory)\v${{ parameters.cuda_version }}\bin" + Write-Host "##vso[task.prependpath]$(Agent.TempDirectory)\v${{ parameters.cuda_version }}\bin\x64" + Write-Host "##vso[task.prependpath]$(Agent.TempDirectory)\v${{ parameters.cuda_version }}\extras\CUPTI\lib64" + displayName: 'Add CUDA to PATH' + + - task: PowerShell@2 + displayName: 'Install and test Python package' + env: + ORT_TEST_VERBOSE: $(System.Debug) + inputs: + targetType: inline + pwsh: true + script: | + $ErrorActionPreference = 'Stop' + + echo "creating test_venv" + python -m venv "$(Build.BinariesDirectory)\test_venv" + + echo "activating test_venv" + & "$(Build.BinariesDirectory)\test_venv\Scripts\Activate.ps1" + + echo "installing onnxruntime onnx numpy" + python -m pip install onnxruntime onnx numpy + + $wheelDir = "$(Pipeline.Workspace)\build\cuda_plugin_python_win_x64_cuda${{ replace(parameters.cuda_version, '.', '') }}" + $wheel = (Get-ChildItem "$wheelDir\onnxruntime*ep*cuda*.whl")[0] + echo "installing ${wheel}" + python -m pip install $wheel.FullName + + echo "running test_cuda_plugin_ep.py" + python -u "$(Build.SourcesDirectory)\plugin-ep-cuda\python\test\test_cuda_plugin_ep.py" diff --git a/tools/ci_build/github/azure-pipelines/stages/plugin-win-webgpu-stage.yml b/tools/ci_build/github/azure-pipelines/stages/plugin-win-webgpu-stage.yml index acad674143961..c774d42776afb 100644 --- a/tools/ci_build/github/azure-pipelines/stages/plugin-win-webgpu-stage.yml +++ b/tools/ci_build/github/azure-pipelines/stages/plugin-win-webgpu-stage.yml @@ -86,6 +86,7 @@ stages: parameters: package_version: ${{ parameters.package_version }} version_file: ${{ parameters.version_file }} + python_command: python - script: | python -m pip install -r "$(Build.SourcesDirectory)\tools\ci_build\github\windows\python\requirements.txt" @@ -267,6 +268,7 @@ stages: parameters: package_version: ${{ parameters.package_version }} version_file: ${{ parameters.version_file }} + python_command: python - task: DownloadPipelineArtifact@2 displayName: 'Download plugin build artifacts' diff --git a/tools/ci_build/github/azure-pipelines/templates/set-plugin-build-variables-step.yml b/tools/ci_build/github/azure-pipelines/templates/set-plugin-build-variables-step.yml index fcc388ef7e342..cc0e766f49ddb 100644 --- a/tools/ci_build/github/azure-pipelines/templates/set-plugin-build-variables-step.yml +++ b/tools/ci_build/github/azure-pipelines/templates/set-plugin-build-variables-step.yml @@ -17,99 +17,16 @@ parameters: - name: version_file type: string +# Python executable used to run the helper script. Default is python3 which works on +# Linux (including aarch64) and macOS. Windows callers must override with 'python'. +- name: python_command + type: string + default: 'python3' + steps: # Set package version string -- task: PythonScript@0 +# Use 'script' (not 'bash') so this works on both Linux and Windows agents. +# On Linux aarch64 agents UsePythonVersion@0 is unavailable, so we call the configured +# Python executable directly instead of using PythonScript@0. +- script: ${{ parameters.python_command }} "$(Build.SourcesDirectory)/tools/ci_build/set_plugin_build_variables.py" "${{ parameters.package_version }}" "${{ parameters.version_file }}" displayName: 'Set plugin package version string' - inputs: - scriptSource: inline - script: | - import os - import re - import subprocess - import sys - - package_version = "${{ parameters.package_version }}" - version_file_rel = "${{ parameters.version_file }}" - - if not version_file_rel: - print("##vso[task.logissue type=error]version_file parameter is empty.") - sys.exit(1) - - src_root = os.environ.get("BUILD_SOURCESDIRECTORY", "") - version_file = os.path.join(src_root, version_file_rel) - if not os.path.isfile(version_file): - print("##vso[task.logissue type=error]Cannot find version number file at: {}".format(version_file)) - sys.exit(1) - - with open(version_file, "r") as f: - original_ver = f.read().strip() - - if not original_ver: - print("##vso[task.logissue type=error]VERSION_NUMBER is empty.") - sys.exit(1) - - print("Original version: {}".format(original_ver)) - print("Package version type: {}".format(package_version)) - - if package_version == "release": - version_string = original_ver - universal_version = original_ver - python_version = original_ver - - elif package_version == "RC": - # RC versioning is not yet implemented. Fail the build to prevent publishing - # an ambiguous version without an RC number. - print("##vso[task.logissue type=error]RC versioning is not yet implemented. Use 'dev' or 'release' instead.") - sys.exit(1) - - elif package_version == "dev": - try: - commit_sha = subprocess.check_output( - ["git", "rev-parse", "--short=8", "HEAD"], - cwd=src_root - ).decode("utf-8").strip() - date_str = subprocess.check_output( - ["git", "show", "-s", "--format=%cd", "--date=format:%Y%m%d", "HEAD"], - cwd=src_root - ).decode("utf-8").strip() - except Exception as e: - print("##vso[task.logissue type=error]Failed to get git info: {}".format(e)) - sys.exit(1) - version_string = "{}-dev.{}+{}".format(original_ver, date_str, commit_sha) - # Prefix the SHA with "commit-" so the pre-release identifier always contains a - # non-digit. Otherwise, an all-numeric short SHA with a leading zero (e.g. "01234567") - # would violate SemVer 2.0.0's rule against leading zeros in numeric identifiers. - universal_version = "{}-dev.{}.commit-{}".format(original_ver, date_str, commit_sha) - python_version = "{}.dev{}".format(original_ver, date_str) - - else: - print("##vso[task.logissue type=error]Unknown package_version '{}'. Must be 'release', 'RC', or 'dev'.".format(package_version)) - sys.exit(1) - - print("Plugin package version string: {}".format(version_string)) - print("Plugin universal package version string: {}".format(universal_version)) - print("Plugin Python package version string: {}".format(python_version)) - - # Validate semver 2.0.0 format - semver_pattern = r"^(0|[1-9]\d*)\.(0|[1-9]\d*)\.(0|[1-9]\d*)(?:-((?:0|[1-9]\d*|\d*[a-zA-Z-][0-9a-zA-Z-]*)(?:\.(?:0|[1-9]\d*|\d*[a-zA-Z-][0-9a-zA-Z-]*))*))?(?:\+([0-9a-zA-Z-]+(?:\.[0-9a-zA-Z-]+)*))?$" - if not re.match(semver_pattern, version_string): - print("##vso[task.logissue type=error]Version string '{}' is not valid semver 2.0.0.".format(version_string)) - sys.exit(1) - - # Validate universal version (SemVer 2.0.0, without build metadata) - universal_semver_pattern = r"^(0|[1-9]\d*)\.(0|[1-9]\d*)\.(0|[1-9]\d*)(?:-((?:0|[1-9]\d*|\d*[a-zA-Z-][0-9a-zA-Z-]*)(?:\.(?:0|[1-9]\d*|\d*[a-zA-Z-][0-9a-zA-Z-]*))*))?$" - if not re.match(universal_semver_pattern, universal_version): - print("##vso[task.logissue type=error]Universal version string '{}' is not valid semver 2.0.0 (without build metadata).".format(universal_version)) - sys.exit(1) - - # Validate Python version (PEP 440) - pep440_pattern = r"^([1-9][0-9]*!)?(0|[1-9][0-9]*)(\.(0|[1-9][0-9]*))*((a|b|rc)(0|[1-9][0-9]*))?(\.post(0|[1-9][0-9]*))?(\.dev(0|[1-9][0-9]*))?$" - if not re.match(pep440_pattern, python_version): - print("##vso[task.logissue type=error]Python version string '{}' is not valid PEP 440.".format(python_version)) - sys.exit(1) - - print("##vso[task.setvariable variable=PluginPackageVersion]{}".format(version_string)) - print("##vso[task.setvariable variable=PluginUniversalPackageVersion]{}".format(universal_version)) - print("##vso[task.setvariable variable=PluginPythonPackageVersion]{}".format(python_version)) - print("##vso[task.setvariable variable=PluginEpVersionDefine]onnxruntime_PLUGIN_EP_VERSION={}".format(version_string)) diff --git a/tools/ci_build/github/linux/build_cuda_plugin_package.sh b/tools/ci_build/github/linux/build_cuda_plugin_package.sh index 7c89fc6b892df..1b4e897b05389 100755 --- a/tools/ci_build/github/linux/build_cuda_plugin_package.sh +++ b/tools/ci_build/github/linux/build_cuda_plugin_package.sh @@ -39,6 +39,7 @@ docker run --rm \ --volume "${BUILD_BINARIESDIRECTORY}:/build" \ --volume /data/models:/build/models:ro \ --volume "${HOME}/.onnx:/home/onnxruntimedev/.onnx" \ + -e PIP_INDEX_URL \ -e NIGHTLY_BUILD \ -e BUILD_BUILDNUMBER \ -e SYSTEM_COLLECTIONURI \ diff --git a/tools/ci_build/github/linux/build_cuda_plugin_python_package.sh b/tools/ci_build/github/linux/build_cuda_plugin_python_package.sh new file mode 100755 index 0000000000000..171d5b2facea8 --- /dev/null +++ b/tools/ci_build/github/linux/build_cuda_plugin_python_package.sh @@ -0,0 +1,50 @@ +#!/bin/bash +set -e -x + +DOCKER_IMAGE="onnxruntimecuda128pluginbuildx64" +PYTHON_EXE="/opt/python/cp312-cp312/bin/python3.12" +VERSION="" +PACKAGE_NAME="" + +while getopts "i:p:v:n:" parameter_Option +do case "${parameter_Option}" +in +i) DOCKER_IMAGE=${OPTARG};; +p) PYTHON_EXE=${OPTARG};; +v) VERSION=${OPTARG};; +n) PACKAGE_NAME=${OPTARG};; +*) echo "Usage: $0 -i -p -v -n " + exit 1;; +esac +done + +if [ -z "$VERSION" ]; then + echo "ERROR: Version is required. Use -v " + exit 1 +fi + +if [ -z "$PACKAGE_NAME" ]; then + echo "ERROR: Package name is required. Use -n " + exit 1 +fi + +PYTHON_BIN_DIR=$(dirname "${PYTHON_EXE}") + +docker run --rm \ + --volume "${BUILD_SOURCESDIRECTORY}:/onnxruntime_src" \ + --volume "${BUILD_BINARIESDIRECTORY}:/build" \ + --volume "${BUILD_ARTIFACTSTAGINGDIRECTORY}:/staging" \ + --env PIP_INDEX_URL \ + --env "ORT_CUDA_PLUGIN_EP_VERSION=${VERSION}" \ + --env "ORT_CUDA_PLUGIN_EP_PACKAGE_NAME=${PACKAGE_NAME}" \ + "$DOCKER_IMAGE" \ + /bin/bash -c ' + set -e -x + PATH="'"${PYTHON_BIN_DIR}"'":$PATH + "'"${PYTHON_EXE}"'" -m pip install -r /onnxruntime_src/plugin-ep-cuda/python/requirements-build-wheel.txt + "'"${PYTHON_EXE}"'" /onnxruntime_src/plugin-ep-cuda/python/build_wheel.py \ + --binary_dir /build/plugin_artifacts/bin \ + --version "$ORT_CUDA_PLUGIN_EP_VERSION" \ + --package_name "$ORT_CUDA_PLUGIN_EP_PACKAGE_NAME" \ + --output_dir /staging/python + ' diff --git a/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_cuda b/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_cuda index ee7869f50bee5..3296fcc77f10f 100644 --- a/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_cuda +++ b/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_cuda @@ -35,7 +35,9 @@ fi ENV JAVA_HOME=/usr/lib/jvm/msopenjdk-17 ADD scripts /tmp/scripts -RUN cd /tmp/scripts && /tmp/scripts/manylinux/install_centos.sh && /tmp/scripts/manylinux/install_deps.sh && rm -rf /tmp/scripts +RUN --mount=type=secret,id=PIP_INDEX_URL,required=false \ + if [ -f /run/secrets/PIP_INDEX_URL ]; then export PIP_INDEX_URL=$(cat /run/secrets/PIP_INDEX_URL); fi && \ + cd /tmp/scripts && /tmp/scripts/manylinux/install_centos.sh && /tmp/scripts/manylinux/install_deps.sh && rm -rf /tmp/scripts ARG BUILD_UID=1001 ARG BUILD_USER=onnxruntimedev diff --git a/tools/ci_build/github/linux/docker/scripts/manylinux/install_centos.sh b/tools/ci_build/github/linux/docker/scripts/manylinux/install_centos.sh index 093da075be13c..e4b05f8a0d1d7 100755 --- a/tools/ci_build/github/linux/docker/scripts/manylinux/install_centos.sh +++ b/tools/ci_build/github/linux/docker/scripts/manylinux/install_centos.sh @@ -6,7 +6,11 @@ os_major_version=$(tr -dc '0-9.' + +Where: + package_version: 'release', 'RC', or 'dev' + version_file_rel: path relative to BUILD_SOURCESDIRECTORY of the VERSION_NUMBER file +""" + +import os +import re +import subprocess +import sys + + +def main(): + if len(sys.argv) != 3: + print(f"Usage: {sys.argv[0]} ") + sys.exit(1) + + package_version = sys.argv[1] + version_file_rel = sys.argv[2] + + if not version_file_rel: + print("##vso[task.logissue type=error]version_file parameter is empty.") + sys.exit(1) + + src_root = os.environ.get("BUILD_SOURCESDIRECTORY", "") + version_file = os.path.join(src_root, version_file_rel) + if not os.path.isfile(version_file): + print(f"##vso[task.logissue type=error]Cannot find version number file at: {version_file}") + sys.exit(1) + + with open(version_file) as f: + original_ver = f.read().strip() + + if not original_ver: + print("##vso[task.logissue type=error]VERSION_NUMBER is empty.") + sys.exit(1) + + print(f"Original version: {original_ver}") + print(f"Package version type: {package_version}") + + if package_version == "release": + version_string = original_ver + universal_version = original_ver + python_version = original_ver + + elif package_version == "RC": + # RC versioning is not yet implemented. Fail the build to prevent publishing + # an ambiguous version without an RC number. + print("##vso[task.logissue type=error]RC versioning is not yet implemented. Use 'dev' or 'release' instead.") + sys.exit(1) + + elif package_version == "dev": + try: + commit_sha = ( + subprocess.check_output( + ["git", "rev-parse", "--short=8", "HEAD"], + cwd=src_root, + ) + .decode("utf-8") + .strip() + ) + date_str = ( + subprocess.check_output( + ["git", "show", "-s", "--format=%cd", "--date=format:%Y%m%d", "HEAD"], + cwd=src_root, + ) + .decode("utf-8") + .strip() + ) + except Exception as e: + print(f"##vso[task.logissue type=error]Failed to get git info: {e}") + sys.exit(1) + version_string = f"{original_ver}-dev.{date_str}+{commit_sha}" + # Prefix the SHA with "commit-" so the pre-release identifier always contains a + # non-digit. Otherwise, an all-numeric short SHA with a leading zero (e.g. "01234567") + # would violate SemVer 2.0.0's rule against leading zeros in numeric identifiers. + universal_version = f"{original_ver}-dev.{date_str}.commit-{commit_sha}" + python_version = f"{original_ver}.dev{date_str}" + + else: + print( + f"##vso[task.logissue type=error]Unknown package_version '{package_version}'. Must be 'release', 'RC', or 'dev'." + ) + sys.exit(1) + + print(f"Plugin package version string: {version_string}") + print(f"Plugin universal package version string: {universal_version}") + print(f"Plugin Python package version string: {python_version}") + + # Validate semver 2.0.0 format + semver_pattern = r"^(0|[1-9]\d*)\.(0|[1-9]\d*)\.(0|[1-9]\d*)(?:-((?:0|[1-9]\d*|\d*[a-zA-Z-][0-9a-zA-Z-]*)(?:\.(?:0|[1-9]\d*|\d*[a-zA-Z-][0-9a-zA-Z-]*))*))?(?:\+([0-9a-zA-Z-]+(?:\.[0-9a-zA-Z-]+)*))?$" + if not re.match(semver_pattern, version_string): + print(f"##vso[task.logissue type=error]Version string '{version_string}' is not valid semver 2.0.0.") + sys.exit(1) + + # Validate universal version (SemVer 2.0.0, without build metadata) + universal_semver_pattern = r"^(0|[1-9]\d*)\.(0|[1-9]\d*)\.(0|[1-9]\d*)(?:-((?:0|[1-9]\d*|\d*[a-zA-Z-][0-9a-zA-Z-]*)(?:\.(?:0|[1-9]\d*|\d*[a-zA-Z-][0-9a-zA-Z-]*))*))?$" + if not re.match(universal_semver_pattern, universal_version): + print( + f"##vso[task.logissue type=error]Universal version string '{universal_version}' is not valid semver 2.0.0 (without build metadata)." + ) + sys.exit(1) + + # Validate Python version (PEP 440) + pep440_pattern = r"^([1-9][0-9]*!)?(0|[1-9][0-9]*)(\.(0|[1-9][0-9]*))*((a|b|rc)(0|[1-9][0-9]*))?(\.post(0|[1-9][0-9]*))?(\.dev(0|[1-9][0-9]*))?$" + if not re.match(pep440_pattern, python_version): + print(f"##vso[task.logissue type=error]Python version string '{python_version}' is not valid PEP 440.") + sys.exit(1) + + print(f"##vso[task.setvariable variable=PluginPackageVersion]{version_string}") + print(f"##vso[task.setvariable variable=PluginUniversalPackageVersion]{universal_version}") + print(f"##vso[task.setvariable variable=PluginPythonPackageVersion]{python_version}") + print(f"##vso[task.setvariable variable=PluginEpVersionDefine]onnxruntime_PLUGIN_EP_VERSION={version_string}") + + +if __name__ == "__main__": + main() From 513b9bf9b8ac2ac8514d74b23ac38a67ba78f4a4 Mon Sep 17 00:00:00 2001 From: Edward Chen <18449977+edgchen1@users.noreply.github.com> Date: Tue, 5 May 2026 14:03:14 -0700 Subject: [PATCH 19/34] [WebGPU plugin EP] NuGet packaging (#28313) ### Description This pull request adds C#/.NET (NuGet) packaging support for the WebGPU plugin Execution Provider, including all necessary project files, documentation, and helper code. It introduces a new NuGet package (`Microsoft.ML.OnnxRuntime.EP.WebGpu`), updates the main plugin documentation to reflect C# support, and provides detailed instructions and code samples for building, packaging, and using the provider in .NET applications. It also has some minor changes for the existing Python packaging setup. The most important changes are: **C#/.NET Packaging Infrastructure:** - Added the `Microsoft.ML.OnnxRuntime.EP.WebGpu` project (`.csproj`) for NuGet packaging, including metadata, dependency management, and logic to read the minimum ONNX Runtime version from a shared file. Native binaries are included per platform, and the README is bundled in the package. - Introduced the `WebGpuEp.cs` helper class to resolve the native library path and EP name at runtime, simplifying registration and usage in .NET. **Documentation:** - Added a detailed `README.md` for the C# package, including usage instructions, supported platforms, and example code for registering and using the WebGPU EP in .NET. - Added a top-level `csharp/README.md` with instructions for building, packaging, and testing the NuGet package, as well as information on CI integration and native binary requirements. ### Motivation and Context Create WebGPU plugin EP NuGet package. --- plugin-ep-webgpu/README.md | 13 +- .../Microsoft.ML.OnnxRuntime.EP.WebGpu.csproj | 91 +++++ .../README.md | 42 +++ .../WebGpuEp.cs | 112 ++++++ plugin-ep-webgpu/csharp/README.md | 140 ++++++++ plugin-ep-webgpu/csharp/pack_nuget.py | 336 ++++++++++++++++++ .../csharp/test/WebGpuEpNuGetTest/Program.cs | 82 +++++ .../WebGpuEpNuGetTest.csproj | 34 ++ .../WebGpuEpNuGetTest/generate_mul_model.py | 25 ++ .../csharp/test/WebGpuEpNuGetTest/mul.onnx | 16 + plugin-ep-webgpu/python/README.md | 12 +- plugin-ep-webgpu/python/build_wheel.py | 1 + .../plugin-webgpu-pipeline.yml | 2 +- .../stages/plugin-linux-webgpu-test-stage.yml | 2 +- .../stages/plugin-mac-webgpu-test-stage.yml | 2 +- .../plugin-webgpu-nuget-packaging-stage.yml | 186 ++++++++++ .../stages/plugin-webgpu-packaging-stage.yml | 23 +- .../stages/plugin-win-webgpu-stage.yml | 1 + .../stages/plugin-win-webgpu-test-stage.yml | 108 +++++- 19 files changed, 1188 insertions(+), 40 deletions(-) create mode 100644 plugin-ep-webgpu/csharp/Microsoft.ML.OnnxRuntime.EP.WebGpu/Microsoft.ML.OnnxRuntime.EP.WebGpu.csproj create mode 100644 plugin-ep-webgpu/csharp/Microsoft.ML.OnnxRuntime.EP.WebGpu/README.md create mode 100644 plugin-ep-webgpu/csharp/Microsoft.ML.OnnxRuntime.EP.WebGpu/WebGpuEp.cs create mode 100644 plugin-ep-webgpu/csharp/README.md create mode 100644 plugin-ep-webgpu/csharp/pack_nuget.py create mode 100644 plugin-ep-webgpu/csharp/test/WebGpuEpNuGetTest/Program.cs create mode 100644 plugin-ep-webgpu/csharp/test/WebGpuEpNuGetTest/WebGpuEpNuGetTest.csproj create mode 100644 plugin-ep-webgpu/csharp/test/WebGpuEpNuGetTest/generate_mul_model.py create mode 100644 plugin-ep-webgpu/csharp/test/WebGpuEpNuGetTest/mul.onnx create mode 100644 tools/ci_build/github/azure-pipelines/stages/plugin-webgpu-nuget-packaging-stage.yml diff --git a/plugin-ep-webgpu/README.md b/plugin-ep-webgpu/README.md index dd874f8af1c3b..889fef10ae5e1 100644 --- a/plugin-ep-webgpu/README.md +++ b/plugin-ep-webgpu/README.md @@ -10,8 +10,12 @@ For more information about plugin EPs, see the documentation [here](https://onnx - [`VERSION_NUMBER`](VERSION_NUMBER) — Base plugin EP version consumed by the CI pipeline. The pipeline derives the final package version (release, dev) from this via [`tools/ci_build/github/azure-pipelines/templates/set-plugin-build-variables-step.yml`](../tools/ci_build/github/azure-pipelines/templates/set-plugin-build-variables-step.yml). +- [`MIN_ONNXRUNTIME_VERSION`](MIN_ONNXRUNTIME_VERSION) — Minimum compatible core `onnxruntime` version. Single source + of truth shared by all packages built from this directory. - [`python/`](python/) — Sources and build script for the `onnxruntime-ep-webgpu` Python wheel. See [`python/README.md`](python/README.md) for build and test instructions. +- [`csharp/`](csharp/) — Sources and packaging script for the `Microsoft.ML.OnnxRuntime.EP.WebGpu` NuGet package. See + [`csharp/README.md`](csharp/README.md) for build and test instructions. ## How it fits together @@ -19,6 +23,7 @@ The plugin EP is built as a shared library (`onnxruntime_providers_webgpu.{dll,s build (`--use_webgpu shared_lib`). The resulting binaries are then packaged into: - A Python wheel (`onnxruntime-ep-webgpu`), built from [`python/`](python/). +- A NuGet package (`Microsoft.ML.OnnxRuntime.EP.WebGpu`), built from [`csharp/`](csharp/). - A universal package published to the internal ORT-Nightly feed for Windows (x64 / arm64), Linux x64, and macOS arm64. @@ -29,7 +34,7 @@ and post-build smoke tests run in the companion `WebGPU Plugin EP Test Pipeline` ## Usage -Once installed, the plugin EP is registered at runtime: +Once installed, the plugin EP is registered at runtime. Example in Python: ```python import onnxruntime as ort @@ -43,5 +48,7 @@ sess_options.add_provider_for_devices(devices, {}) session = ort.InferenceSession("model.onnx", sess_options=sess_options) ``` -See [`python/onnxruntime_ep_webgpu/README.md`](python/onnxruntime_ep_webgpu/README.md) for the user-facing package -documentation (this README is bundled into the wheel). +See the user-facing package READMEs (bundled into the published packages) for full per-language usage: + +- Python: [`python/onnxruntime_ep_webgpu/README.md`](python/onnxruntime_ep_webgpu/README.md) +- C# / .NET: [`csharp/Microsoft.ML.OnnxRuntime.EP.WebGpu/README.md`](csharp/Microsoft.ML.OnnxRuntime.EP.WebGpu/README.md) diff --git a/plugin-ep-webgpu/csharp/Microsoft.ML.OnnxRuntime.EP.WebGpu/Microsoft.ML.OnnxRuntime.EP.WebGpu.csproj b/plugin-ep-webgpu/csharp/Microsoft.ML.OnnxRuntime.EP.WebGpu/Microsoft.ML.OnnxRuntime.EP.WebGpu.csproj new file mode 100644 index 0000000000000..94be6bec6ea46 --- /dev/null +++ b/plugin-ep-webgpu/csharp/Microsoft.ML.OnnxRuntime.EP.WebGpu/Microsoft.ML.OnnxRuntime.EP.WebGpu.csproj @@ -0,0 +1,91 @@ + + + + netstandard2.0 + latest + enable + + + Microsoft.ML.OnnxRuntime.EP.WebGpu + + 0.0.0-dev + Microsoft + Microsoft + ONNX Runtime WebGPU Plugin Execution Provider. + README.md + ONNX;ONNX Runtime;Machine Learning;AI;Deep Learning;WebGPU + + + MIT + https://github.com/microsoft/onnxruntime + git + © Microsoft Corporation. All rights reserved. + + + true + snupkg + + + + + $(MSBuildThisFileDirectory)..\..\MIN_ONNXRUNTIME_VERSION + $([System.IO.File]::ReadAllText('$(OnnxRuntimeMinVersionFile)').Trim()) + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/plugin-ep-webgpu/csharp/Microsoft.ML.OnnxRuntime.EP.WebGpu/README.md b/plugin-ep-webgpu/csharp/Microsoft.ML.OnnxRuntime.EP.WebGpu/README.md new file mode 100644 index 0000000000000..f4a717b8836d5 --- /dev/null +++ b/plugin-ep-webgpu/csharp/Microsoft.ML.OnnxRuntime.EP.WebGpu/README.md @@ -0,0 +1,42 @@ +## Microsoft.ML.OnnxRuntime.EP.WebGpu + +WebGPU plugin Execution Provider for [ONNX Runtime](https://github.com/microsoft/onnxruntime). + +### Usage + +```csharp +// Note: Error handling is omitted for brevity. + +using Microsoft.ML.OnnxRuntime; +using Microsoft.ML.OnnxRuntime.EP.WebGpu; + +// Register the WebGPU EP plugin library +var env = OrtEnv.Instance(); +env.RegisterExecutionProviderLibrary("webgpu_ep", WebGpuEp.GetLibraryPath()); + +// Find the WebGPU EP device +OrtEpDevice? webGpuDevice = null; +foreach (var d in env.GetEpDevices()) +{ + if (d.EpName == WebGpuEp.GetEpName()) + { + webGpuDevice = d; + break; + } +} + +// Create a session with the WebGPU EP +using var sessionOptions = new SessionOptions(); +sessionOptions.AppendExecutionProvider(env, new[] { webGpuDevice }, new Dictionary()); + +using var session = new InferenceSession("model.onnx", sessionOptions); +``` + +### Supported Platforms + +| Runtime Identifier | Native Library | +|---|---| +| win-x64 | `onnxruntime_providers_webgpu.dll`, `dxil.dll`, `dxcompiler.dll` | +| win-arm64 | `onnxruntime_providers_webgpu.dll`, `dxil.dll`, `dxcompiler.dll` | +| linux-x64 | `libonnxruntime_providers_webgpu.so` | +| osx-arm64 | `libonnxruntime_providers_webgpu.dylib` | diff --git a/plugin-ep-webgpu/csharp/Microsoft.ML.OnnxRuntime.EP.WebGpu/WebGpuEp.cs b/plugin-ep-webgpu/csharp/Microsoft.ML.OnnxRuntime.EP.WebGpu/WebGpuEp.cs new file mode 100644 index 0000000000000..2a5ec106aad0d --- /dev/null +++ b/plugin-ep-webgpu/csharp/Microsoft.ML.OnnxRuntime.EP.WebGpu/WebGpuEp.cs @@ -0,0 +1,112 @@ +using System; +using System.IO; +using System.Runtime.InteropServices; + +namespace Microsoft.ML.OnnxRuntime.EP.WebGpu +{ + /// + /// Provides helper methods to locate the WebGPU plugin EP native library + /// and retrieve the EP name for registration with ONNX Runtime. + /// + public static class WebGpuEp + { + /// + /// Returns the path to the WebGPU plugin EP native library contained by this package. + /// Can be passed to OrtEnv.RegisterExecutionProviderLibrary(). + /// + /// Full path to the EP native library. + /// If the native library file does not exist at the expected path. + public static string GetLibraryPath() + { + string rootDir = GetNativeDirectory(); + string rid = GetRuntimeIdentifier(); + string libraryName = GetLibraryName(); + + // Probe the standard NuGet runtimes//native/ layout first, then fall back + // to the base directory for single-file/published layouts where native assets + // can land directly next to the managed assembly. + string[] candidates = + { + Path.Combine(rootDir, "runtimes", rid, "native", libraryName), + Path.Combine(rootDir, libraryName), + }; + + foreach (var candidate in candidates) + { + if (File.Exists(candidate)) + return Path.GetFullPath(candidate); + } + + throw new FileNotFoundException( + $"Did not find WebGPU EP library file. Probed: {string.Join(", ", candidates)}"); + } + + /// + /// Returns the names of the EPs created by the WebGPU plugin EP library. + /// Can be used to select an OrtEpDevice from those returned by OrtEnv.GetEpDevices(). + /// + /// Array of EP names. + public static string[] GetEpNames() + { + return new[] { GetEpName() }; + } + + /// + /// Returns the name of the one EP supported by this plugin EP library. + /// Convenience method for plugin EP packages that expose a single EP. + /// + /// The EP name string. + public static string GetEpName() + { + return "WebGpuExecutionProvider"; + } + + private static string GetNativeDirectory() + { + var assemblyDir = Path.GetDirectoryName(typeof(WebGpuEp).Assembly.Location); + + if (!string.IsNullOrEmpty(assemblyDir) && Directory.Exists(assemblyDir)) + return assemblyDir; + + return AppContext.BaseDirectory; + } + + private static string GetRuntimeIdentifier() + { + return GetOSTag() + "-" + GetArchTag(); + } + + private static string GetLibraryName() + { + if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows)) + return "onnxruntime_providers_webgpu.dll"; + if (RuntimeInformation.IsOSPlatform(OSPlatform.Linux)) + return "libonnxruntime_providers_webgpu.so"; + if (RuntimeInformation.IsOSPlatform(OSPlatform.OSX)) + return "libonnxruntime_providers_webgpu.dylib"; + + throw new PlatformNotSupportedException( + $"WebGPU plugin EP does not support OS platform: {RuntimeInformation.OSDescription}"); + } + + private static string GetOSTag() + { + if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows)) return "win"; + if (RuntimeInformation.IsOSPlatform(OSPlatform.Linux)) return "linux"; + if (RuntimeInformation.IsOSPlatform(OSPlatform.OSX)) return "osx"; + throw new PlatformNotSupportedException( + $"WebGPU plugin EP does not support OS platform: {RuntimeInformation.OSDescription}"); + } + + private static string GetArchTag() + { + return RuntimeInformation.ProcessArchitecture switch + { + Architecture.X64 => "x64", + Architecture.Arm64 => "arm64", + _ => throw new PlatformNotSupportedException( + $"WebGPU plugin EP does not support process architecture: {RuntimeInformation.ProcessArchitecture}"), + }; + } + } +} diff --git a/plugin-ep-webgpu/csharp/README.md b/plugin-ep-webgpu/csharp/README.md new file mode 100644 index 0000000000000..7a2b2041e364f --- /dev/null +++ b/plugin-ep-webgpu/csharp/README.md @@ -0,0 +1,140 @@ +# WebGPU Plugin EP — NuGet Packaging + +This directory contains the C# NuGet package project and test app for the WebGPU plugin Execution Provider. + +## Directory Structure + +``` +csharp/ +├── pack_nuget.py # Helper script to build the NuGet package +├── Microsoft.ML.OnnxRuntime.EP.WebGpu/ +│ ├── Microsoft.ML.OnnxRuntime.EP.WebGpu.csproj # NuGet package project (netstandard2.0) +│ ├── WebGpuEp.cs # Helper class for native library resolution +│ └── README.md # Package readme (shipped inside .nupkg) +└── test/ + └── WebGpuEpNuGetTest/ + ├── WebGpuEpNuGetTest.csproj # Test console app (net8.0) + ├── Program.cs # Registers EP, runs inference, validates output + ├── mul.onnx # Test model (element-wise multiply) + └── generate_mul_model.py # Script to regenerate mul.onnx +``` + +## Prerequisites + +- .NET SDK 8.0 or later +- A built WebGPU plugin EP shared library + +## Building the NuGet Package + +Use `pack_nuget.py` to stage native binaries and run `dotnet pack`. The script copies everything into a staging +directory before building — the source tree is never modified. By default, an auto-cleaned temporary directory is used; +pass `--staging-dir` to use an explicit one (required when running with `--build-only` or `--pack-only`). + +At least one binary directory (or `--artifacts-dir` with matching subdirectories) must be provided. Platforms without +a binary directory are skipped. Run `python pack_nuget.py --help` for the full list of options and their defaults. + +### Pack with a local build (single platform) + +```powershell +cd plugin-ep-webgpu/csharp + +python pack_nuget.py --version 0.1.0-dev ` + --binary-dir-win-x64 +``` + +### Pack multiple platforms + +Each `--binary-dir-*` points at the directory containing that platform's already-built native binaries. In practice +the four binaries are produced on different machines and combined in CI; locally you'd typically only set the one(s) +you have available. + +```powershell +python pack_nuget.py --version 0.1.0-dev ` + --binary-dir-win-x64 ` + --binary-dir-win-arm64 ` + --binary-dir-linux-x64 ` + --binary-dir-macos-arm64 +``` + +## Versioning + +The package version is supplied to `pack_nuget.py` via `--version`. In the packaging pipeline, the release or +pre-release version is derived from [`plugin-ep-webgpu/VERSION_NUMBER`](../VERSION_NUMBER). + +## Inspecting the Package + +The `.nupkg` is a ZIP file. To verify its contents: + +```powershell +Expand-Archive nuget_output/Microsoft.ML.OnnxRuntime.EP.WebGpu.0.1.0-dev.nupkg ` + -DestinationPath nuget_output/inspect -Force + +Get-ChildItem nuget_output/inspect -Recurse | Select-Object FullName +``` + +Expected layout inside the package: + +``` +lib/netstandard2.0/Microsoft.ML.OnnxRuntime.EP.WebGpu.dll +runtimes/win-x64/native/onnxruntime_providers_webgpu.dll +runtimes/win-x64/native/dxil.dll +runtimes/win-x64/native/dxcompiler.dll +runtimes/win-arm64/native/... +runtimes/linux-x64/native/libonnxruntime_providers_webgpu.so +runtimes/osx-arm64/native/libonnxruntime_providers_webgpu.dylib +``` + +## Testing the Package + +The test app registers the WebGPU EP, creates a session, runs a simple Mul model, and validates the output. + +```powershell +# Point the test project's nuget.config at the pack output +$localFeed = (Resolve-Path nuget_output).Path +@" + + + + + + + + +"@ | Set-Content test/WebGpuEpNuGetTest/nuget.config + +# Build and run +dotnet run --project test/WebGpuEpNuGetTest/WebGpuEpNuGetTest.csproj --configuration Release +``` + +A successful run prints `PASSED: All outputs match expected values.` and exits with code 0. + +## Regenerating the Test Model + +```bash +python test/WebGpuEpNuGetTest/generate_mul_model.py +``` + +Requires the `onnx` Python package. + +## CI Pipeline + +The NuGet packaging is integrated into the WebGPU plugin pipeline: + +- **Pipeline:** `tools/ci_build/github/azure-pipelines/plugin-webgpu-pipeline.yml` +- **Packaging stage:** `tools/ci_build/github/azure-pipelines/stages/plugin-webgpu-nuget-packaging-stage.yml` + +The CI stage downloads build artifacts from all enabled platform stages, invokes `pack_nuget.py`, ESRP-signs the +package, and runs the test app on a GPU agent. + +## Native Binaries Per Platform + +| RID | Required Files | +|---|---| +| `win-x64` | `onnxruntime_providers_webgpu.dll`, `dxil.dll`, `dxcompiler.dll` | +| `win-arm64` | `onnxruntime_providers_webgpu.dll`, `dxil.dll`, `dxcompiler.dll` | +| `linux-x64` | `libonnxruntime_providers_webgpu.so` | +| `osx-arm64` | `libonnxruntime_providers_webgpu.dylib` | + +On Windows, `dxil.dll` and `dxcompiler.dll` are the DirectX Shader Compiler binaries downloaded from the +[DXC GitHub releases](https://github.com/microsoft/DirectXShaderCompiler/releases). The CI pipeline handles this +automatically. diff --git a/plugin-ep-webgpu/csharp/pack_nuget.py b/plugin-ep-webgpu/csharp/pack_nuget.py new file mode 100644 index 0000000000000..9a29d067a4034 --- /dev/null +++ b/plugin-ep-webgpu/csharp/pack_nuget.py @@ -0,0 +1,336 @@ +#!/usr/bin/env python3 +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +"""Build the Microsoft.ML.OnnxRuntime.EP.WebGpu NuGet package. + +Stages native binaries from build artifacts into the runtimes/ layout expected +by the .csproj and runs `dotnet pack` to produce the .nupkg / .snupkg files. + +Can be invoked locally or from CI. In CI, pass --artifacts-dir to point at the +downloaded pipeline artifacts. Locally, pass individual --binary-dir-* options. + +Examples +-------- +Local: pack win-x64 only from a local build: + + python pack_nuget.py --version 0.1.0-dev \\ + --binary-dir-win-x64 ../../build/webgpu.plugin/Release/Release + +CI: pack all platforms from downloaded artifacts: + + python pack_nuget.py --version $(PluginPackageVersion) \\ + --artifacts-dir $(Build.BinariesDirectory)/artifacts \\ + --output-dir $(Build.ArtifactStagingDirectory)/nuget +""" + +from __future__ import annotations + +import argparse +import shutil +import subprocess +import sys +import tempfile +from pathlib import Path + +# Platform name -> (RID, list of native binary filenames expected in the source dir). +PLATFORMS: dict[str, tuple[str, tuple[str, ...]]] = { + "win_x64": ("win-x64", ("onnxruntime_providers_webgpu.dll", "dxil.dll", "dxcompiler.dll")), + "win_arm64": ("win-arm64", ("onnxruntime_providers_webgpu.dll", "dxil.dll", "dxcompiler.dll")), + "linux_x64": ("linux-x64", ("libonnxruntime_providers_webgpu.so",)), + "macos_arm64": ("osx-arm64", ("libonnxruntime_providers_webgpu.dylib",)), +} + +SCRIPT_DIR = Path(__file__).resolve().parent +PROJECT_DIR = SCRIPT_DIR / "Microsoft.ML.OnnxRuntime.EP.WebGpu" +CSPROJ = PROJECT_DIR / "Microsoft.ML.OnnxRuntime.EP.WebGpu.csproj" +MIN_ORT_VERSION_FILE = SCRIPT_DIR.parent / "MIN_ONNXRUNTIME_VERSION" + + +class PackError(RuntimeError): + """Raised for any user-actionable failure during packaging.""" + + +def parse_args() -> argparse.Namespace: + def _absolute_path(value: str) -> Path: + """argparse `type` converter: parse a string as an absolute Path.""" + return Path(value).resolve() + + p = argparse.ArgumentParser( + description="Build the Microsoft.ML.OnnxRuntime.EP.WebGpu NuGet package.", + formatter_class=argparse.RawDescriptionHelpFormatter, + ) + p.add_argument("--version", required=True, help="Package version (e.g. 0.1.0-dev).") + p.add_argument( + "--output-dir", + type=_absolute_path, + default=(SCRIPT_DIR / "nuget_output").resolve(), + help="Directory for the .nupkg / .snupkg output (default: ./nuget_output).", + ) + p.add_argument("--configuration", default="Release", help="Build configuration (default: Release).") + + # CI mode: a single root containing per-platform subdirectories. + p.add_argument( + "--artifacts-dir", + type=_absolute_path, + help="CI mode: root containing /bin/ subdirectories for each platform.", + ) + + # Local mode: explicit per-platform binary directories. Each takes precedence over + # --artifacts-dir for that platform. + for name in PLATFORMS: + flag = f"--binary-dir-{name.replace('_', '-')}" + p.add_argument(flag, type=_absolute_path, dest=f"binary_dir_{name}", help=f"Path to {name} native binaries.") + + p.add_argument( + "--nuget-config", type=_absolute_path, help="Optional NuGet.config passed to dotnet via --configfile." + ) + p.add_argument( + "--staging-dir", + type=_absolute_path, + help=( + "Explicit staging directory. Required with --build-only / --pack-only " + "(caller owns its lifecycle). When omitted, an auto-cleaned temporary " + "directory is used for the full build+pack flow." + ), + ) + + phase = p.add_mutually_exclusive_group() + phase.add_argument( + "--build-only", + action="store_true", + help="Stage and build the managed DLL only; skip dotnet pack. Preserves the staging dir.", + ) + phase.add_argument( + "--pack-only", + action="store_true", + help="Skip staging/build and run dotnet pack against an existing staging directory.", + ) + + p.add_argument( + "--required-platforms", + default="", + help=( + "Comma-separated list of platforms that MUST be staged successfully. " + "When omitted, the script just requires at least one platform to be staged." + ), + ) + + return p.parse_args() + + +def parse_required_platforms(value: str) -> list[str]: + names = [tok.strip() for tok in value.split(",") if tok.strip()] + invalid = [n for n in names if n not in PLATFORMS] + if invalid: + raise PackError( + f"unknown platform(s) in --required-platforms: {', '.join(invalid)}. valid: {', '.join(PLATFORMS)}." + ) + return names + + +def stage_sources(staging_dir: Path) -> None: + """Copy project sources into staging, excluding bin/obj.""" + print(f"Staging project files to {staging_dir}") + if staging_dir.exists(): + shutil.rmtree(staging_dir) + shutil.copytree( + PROJECT_DIR, + staging_dir, + ignore=shutil.ignore_patterns("bin", "obj"), + ) + + +def resolve_platform_source( + name: str, + binary_dir_override: Path | None, + artifacts_dir: Path | None, + is_required: bool, +) -> Path | None: + """Return the source dir for a platform, or None to skip.""" + if binary_dir_override is not None: + return binary_dir_override + if artifacts_dir is not None: + candidate = artifacts_dir / name / "bin" + if candidate.is_dir(): + return candidate + if is_required: + raise PackError(f"required platform '{name}' artifact directory not found: {candidate}") + if is_required: + raise PackError( + f"required platform '{name}' has no binary directory " + f"(pass --binary-dir-{name.replace('_', '-')} or --artifacts-dir)." + ) + return None + + +def stage_binaries( + staging_dir: Path, + args: argparse.Namespace, + required_platforms: list[str], +) -> None: + staged: set[str] = set() + + for name, (rid, files) in PLATFORMS.items(): + binary_dir_override: Path | None = getattr(args, f"binary_dir_{name}") + is_required = name in required_platforms + source_dir = resolve_platform_source(name, binary_dir_override, args.artifacts_dir, is_required) + if source_dir is None: + print(f"Skipping {name} (no binary directory provided)") + continue + if not source_dir.is_dir(): + raise PackError(f"binary directory does not exist: {source_dir}") + + target_dir = staging_dir / "runtimes" / rid / "native" + target_dir.mkdir(parents=True, exist_ok=True) + + print(f"Staging {name} -> runtimes/{rid}/native/") + for filename in files: + src = source_dir / filename + if not src.is_file(): + raise PackError(f"expected binary not found: {src}") + shutil.copy2(src, target_dir / filename) + print(f" {filename}") + staged.add(name) + + if required_platforms: + missing = [n for n in required_platforms if n not in staged] + if missing: + raise PackError(f"required platforms not staged: {', '.join(missing)}") + elif not staged: + raise PackError("no platform binaries were staged. Provide at least one --binary-dir-* or --artifacts-dir.") + + print() + print("Runtimes layout:") + for path in sorted((staging_dir / "runtimes").rglob("*")): + print(f" {path}") + + +def dotnet_common_args( + staged_csproj: Path, + args: argparse.Namespace, + min_ort_version_file: Path, +) -> list[str]: + common = [ + str(staged_csproj), + "--configuration", + args.configuration, + f"-p:Version={args.version}", + f"-p:OnnxRuntimeMinVersionFile={min_ort_version_file}", + ] + if args.nuget_config: + common.extend(["--configfile", str(args.nuget_config)]) + print(f"Using NuGet.config: {args.nuget_config}") + return common + + +def do_build(staged_csproj: Path, staging_dir: Path, args: argparse.Namespace, min_ort_version_file: Path) -> None: + print() + print(f"Running dotnet build (Version={args.version}, Configuration={args.configuration})...") + cmd = ["dotnet", "build", *dotnet_common_args(staged_csproj, args, min_ort_version_file)] + print("+ " + " ".join(cmd)) + subprocess.run(cmd, check=True) + + # Note: "netstandard2.0" must match in Microsoft.ML.OnnxRuntime.EP.WebGpu.csproj. + managed_dll = staging_dir / "bin" / args.configuration / "netstandard2.0" / "Microsoft.ML.OnnxRuntime.EP.WebGpu.dll" + if not managed_dll.is_file(): + raise PackError(f"managed DLL not found after build: {managed_dll}") + print() + print(f"Built managed DLL: {managed_dll}") + print("Staging directory preserved for subsequent --pack-only invocation.") + + +def do_pack( + staged_csproj: Path, + output_dir: Path, + args: argparse.Namespace, + min_ort_version_file: Path, +) -> None: + print() + print(f"Running dotnet pack (Version={args.version}, Configuration={args.configuration})...") + pack_args = [ + "dotnet", + "pack", + *dotnet_common_args(staged_csproj, args, min_ort_version_file), + "--output", + str(output_dir), + ] + if args.pack_only: + pack_args.append("--no-build") + print("+ " + " ".join(pack_args)) + subprocess.run(pack_args, check=True) + + print() + nupkgs = sorted(output_dir.glob("*.nupkg")) + if not nupkgs: + raise PackError(f"no .nupkg files found in {output_dir}") + for pkg in nupkgs: + print(f"Produced: {pkg.name} ({pkg.stat().st_size / (1024 * 1024):.2f} MB)") + for pkg in sorted(output_dir.glob("*.snupkg")): + print(f"Produced: {pkg.name} ({pkg.stat().st_size / (1024 * 1024):.2f} MB)") + + +def run_in_staging(args: argparse.Namespace, staging_dir: Path, min_ort_version_file: Path) -> None: + staged_csproj = staging_dir / "Microsoft.ML.OnnxRuntime.EP.WebGpu.csproj" + output_dir: Path = args.output_dir + output_dir.mkdir(parents=True, exist_ok=True) + required_platforms = parse_required_platforms(args.required_platforms) + + if args.pack_only: + if not staged_csproj.is_file(): + raise PackError(f"staged project not found at {staged_csproj}. Run with --build-only first.") + print(f"Reusing existing staging directory: {staging_dir}") + else: + stage_sources(staging_dir) + stage_binaries(staging_dir, args, required_platforms) + + if args.build_only: + do_build(staged_csproj, staging_dir, args, min_ort_version_file) + return + + do_pack(staged_csproj, output_dir, args, min_ort_version_file) + + print() + print(f"Done. Output: {output_dir}") + + +def run(args: argparse.Namespace) -> None: + if not CSPROJ.is_file(): + raise PackError(f"project file not found: {CSPROJ}") + if not MIN_ORT_VERSION_FILE.is_file(): + raise PackError(f"MIN_ONNXRUNTIME_VERSION file not found: {MIN_ORT_VERSION_FILE}") + if args.nuget_config and not args.nuget_config.is_file(): + raise PackError(f"NuGet.config not found: {args.nuget_config}") + + if (args.build_only or args.pack_only) and not args.staging_dir: + raise PackError("--staging-dir is required when using --build-only or --pack-only.") + + min_ort_version_file = MIN_ORT_VERSION_FILE.resolve() + + if args.staging_dir: + staging_dir: Path = args.staging_dir + staging_dir.mkdir(parents=True, exist_ok=True) + run_in_staging(args, staging_dir, min_ort_version_file) + return + + # Full build+pack flow with no caller-managed staging dir: use a temp dir that + # is cleaned up automatically (including on exception). + with tempfile.TemporaryDirectory(prefix="webgpu_pack_") as tmp: + run_in_staging(args, Path(tmp), min_ort_version_file) + + +def main() -> int: + args = parse_args() + try: + run(args) + except PackError as e: + print(f"error: {e}", file=sys.stderr) + return 1 + except subprocess.CalledProcessError as e: + cmd_name = e.cmd[0] if e.cmd else "subprocess" + print(f"error: {cmd_name} failed with exit code {e.returncode}", file=sys.stderr) + return e.returncode or 1 + return 0 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/plugin-ep-webgpu/csharp/test/WebGpuEpNuGetTest/Program.cs b/plugin-ep-webgpu/csharp/test/WebGpuEpNuGetTest/Program.cs new file mode 100644 index 0000000000000..f5d1f0628c831 --- /dev/null +++ b/plugin-ep-webgpu/csharp/test/WebGpuEpNuGetTest/Program.cs @@ -0,0 +1,82 @@ +using Microsoft.ML.OnnxRuntime; +using Microsoft.ML.OnnxRuntime.EP.WebGpu; + +class Program +{ + static int Main() + { + string epLibPath = WebGpuEp.GetLibraryPath(); + string epRegistrationName = "webgpu_ep_registration"; + string epName = WebGpuEp.GetEpName(); + + Console.WriteLine($"WebGPU EP library path: {epLibPath}"); + + var env = OrtEnv.Instance(); + env.RegisterExecutionProviderLibrary(epRegistrationName, epLibPath); + Console.WriteLine($"Registered EP library: {epLibPath}"); + + try + { + // Find the OrtEpDevice for the WebGPU EP + OrtEpDevice? epDevice = null; + foreach (var d in env.GetEpDevices()) + { + if (string.Equals(epName, d.EpName, StringComparison.Ordinal)) + { + epDevice = d; + break; + } + } + + if (epDevice == null) + { + Console.Error.WriteLine($"ERROR: Unable to find OrtEpDevice with name '{epName}'"); + return 1; + } + Console.WriteLine($"Found OrtEpDevice for EP: {epName}"); + + // Create session with WebGPU EP + using var sessionOptions = new SessionOptions(); + sessionOptions.AppendExecutionProvider(env, new[] { epDevice }, new Dictionary()); + sessionOptions.AddSessionConfigEntry("session.disable_cpu_ep_fallback", "1"); + + string inputModelPath = Path.Combine(AppContext.BaseDirectory, "mul.onnx"); + Console.WriteLine($"Loading model: {inputModelPath}"); + + using var session = new InferenceSession(inputModelPath, sessionOptions); + + // Run model: mul(x, y) = x * y + float[] inputData = { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f }; + using var inputOrtValue = OrtValue.CreateTensorValueFromMemory(inputData, new long[] { 2, 3 }); + var inputValues = new List { inputOrtValue, inputOrtValue }.AsReadOnly(); + var inputNames = new List { "x", "y" }.AsReadOnly(); + using var runOptions = new RunOptions(); + + using var outputs = session.Run(runOptions, inputNames, inputValues, session.OutputNames); + + float[] expected = { 1.0f, 4.0f, 9.0f, 16.0f, 25.0f, 36.0f }; + var actual = outputs[0].GetTensorDataAsSpan().ToArray(); + + Console.WriteLine($"Input: {string.Join(", ", inputData)}"); + Console.WriteLine($"Output: {string.Join(", ", actual)}"); + Console.WriteLine($"Expected: {string.Join(", ", expected)}"); + + // Validate output + for (int i = 0; i < expected.Length; i++) + { + if (Math.Abs(actual[i] - expected[i]) > 1e-5f) + { + Console.Error.WriteLine($"ERROR: Output mismatch at index {i}: expected {expected[i]}, got {actual[i]}"); + return 1; + } + } + + Console.WriteLine("PASSED: All outputs match expected values."); + return 0; + } + finally + { + env.UnregisterExecutionProviderLibrary(epRegistrationName); + } + } +} diff --git a/plugin-ep-webgpu/csharp/test/WebGpuEpNuGetTest/WebGpuEpNuGetTest.csproj b/plugin-ep-webgpu/csharp/test/WebGpuEpNuGetTest/WebGpuEpNuGetTest.csproj new file mode 100644 index 0000000000000..9554161b1e978 --- /dev/null +++ b/plugin-ep-webgpu/csharp/test/WebGpuEpNuGetTest/WebGpuEpNuGetTest.csproj @@ -0,0 +1,34 @@ + + + + Exe + net8.0 + latest + enable + enable + + *-* + + + + + + + + + + PreserveNewest + + + + + + + diff --git a/plugin-ep-webgpu/csharp/test/WebGpuEpNuGetTest/generate_mul_model.py b/plugin-ep-webgpu/csharp/test/WebGpuEpNuGetTest/generate_mul_model.py new file mode 100644 index 0000000000000..c64b4b7ec96bc --- /dev/null +++ b/plugin-ep-webgpu/csharp/test/WebGpuEpNuGetTest/generate_mul_model.py @@ -0,0 +1,25 @@ +"""Generate a simple Mul ONNX model for testing. + +Produces mul.onnx in the same directory as this script. +The model computes z = x * y (element-wise) for float32 tensors of shape [2, 3]. +""" + +import os + +from onnx import TensorProto, checker, helper, save + +X = helper.make_tensor_value_info("x", TensorProto.FLOAT, [2, 3]) +Y = helper.make_tensor_value_info("y", TensorProto.FLOAT, [2, 3]) +Z = helper.make_tensor_value_info("z", TensorProto.FLOAT, [2, 3]) + +mul_node = helper.make_node("Mul", inputs=["x", "y"], outputs=["z"]) + +graph = helper.make_graph([mul_node], "mul_graph", [X, Y], [Z]) +model = helper.make_model(graph, producer_name="onnxruntime-webgpu-ep-test") +model.opset_import[0].version = 13 + +checker.check_model(model) + +output_path = os.path.join(os.path.dirname(__file__), "mul.onnx") +save(model, output_path) +print(f"Saved {output_path}") diff --git a/plugin-ep-webgpu/csharp/test/WebGpuEpNuGetTest/mul.onnx b/plugin-ep-webgpu/csharp/test/WebGpuEpNuGetTest/mul.onnx new file mode 100644 index 0000000000000..6df01feb5cf58 --- /dev/null +++ b/plugin-ep-webgpu/csharp/test/WebGpuEpNuGetTest/mul.onnx @@ -0,0 +1,16 @@ + onnxruntime-webgpu-ep-test:Z + +x +yz"Mul mul_graphZ +x +  + +Z +y +  + +b +z +  + +B \ No newline at end of file diff --git a/plugin-ep-webgpu/python/README.md b/plugin-ep-webgpu/python/README.md index ac14a84a70f48..849105a439396 100644 --- a/plugin-ep-webgpu/python/README.md +++ b/plugin-ep-webgpu/python/README.md @@ -19,19 +19,13 @@ Wheels are built via `build_wheel.py`. Running `pip install` or `pip wheel` dire supported — the source tree contains `pyproject.toml.in` (a template), not a real `pyproject.toml`. ```bash -python build_wheel.py \ - --binary_dir \ - --version \ - --output_dir +python build_wheel.py --binary_dir --version --output_dir ``` Example: ```bash -python build_wheel.py \ - --binary_dir ./build/Release \ - --version 0.1.0.dev20260429 \ - --output_dir ./dist +python build_wheel.py --binary_dir ./build/Release --version 0.1.0.devYYYYMMDD --output_dir ./dist ``` The script combines the pre-built plugin EP binaries with the package source to produce a platform-specific wheel. @@ -44,7 +38,7 @@ Install the wheel and dependencies in a clean environment, then run the smoke te python -m venv test_venv source test_venv/bin/activate # or test_venv\Scripts\Activate.ps1 on Windows pip install onnx numpy -pip install dist/onnxruntime_ep_webgpu-*.whl # pulls in onnxruntime>=1.24.4 +pip install dist/onnxruntime_ep_webgpu-*.whl # pulls in the minimum compatible onnxruntime python test/test_webgpu_plugin_ep.py ``` diff --git a/plugin-ep-webgpu/python/build_wheel.py b/plugin-ep-webgpu/python/build_wheel.py index 8f855a5d2179b..b4357bcdfbe0f 100644 --- a/plugin-ep-webgpu/python/build_wheel.py +++ b/plugin-ep-webgpu/python/build_wheel.py @@ -86,6 +86,7 @@ def prepare_staging_dir(staging_dir: Path, binary_dir: Path, version: str): shutil.copytree(SCRIPT_DIR / "onnxruntime_ep_webgpu", staging_dir / "onnxruntime_ep_webgpu") # Copy plugin binaries into the package directory + # Note: The binaries are assumed to be directly under `binary_dir`. package_dir = staging_dir / "onnxruntime_ep_webgpu" copied = [] for pattern in BINARY_PATTERNS: diff --git a/tools/ci_build/github/azure-pipelines/plugin-webgpu-pipeline.yml b/tools/ci_build/github/azure-pipelines/plugin-webgpu-pipeline.yml index 7d9f7c24b3360..673452d8b110a 100644 --- a/tools/ci_build/github/azure-pipelines/plugin-webgpu-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/plugin-webgpu-pipeline.yml @@ -46,7 +46,7 @@ parameters: type: string values: - release - - RC + # - RC # not implemented yet - dev default: dev diff --git a/tools/ci_build/github/azure-pipelines/stages/plugin-linux-webgpu-test-stage.yml b/tools/ci_build/github/azure-pipelines/stages/plugin-linux-webgpu-test-stage.yml index 9ce494d4b3a36..12ee9ca68bb4e 100644 --- a/tools/ci_build/github/azure-pipelines/stages/plugin-linux-webgpu-test-stage.yml +++ b/tools/ci_build/github/azure-pipelines/stages/plugin-linux-webgpu-test-stage.yml @@ -71,7 +71,7 @@ stages: set -e -x python3 -m venv /build/test_venv source /build/test_venv/bin/activate - python3 -m pip install onnxruntime onnx numpy + python3 -m pip install onnx numpy wheel=\$(find /build/python_wheel -name 'onnxruntime_ep_webgpu-*.whl' | head -1) python3 -m pip install \"\$wheel\" python3 -u /onnxruntime_src/plugin-ep-webgpu/python/test/test_webgpu_plugin_ep.py diff --git a/tools/ci_build/github/azure-pipelines/stages/plugin-mac-webgpu-test-stage.yml b/tools/ci_build/github/azure-pipelines/stages/plugin-mac-webgpu-test-stage.yml index 5ad4e170b2855..6dca5dd450fd0 100644 --- a/tools/ci_build/github/azure-pipelines/stages/plugin-mac-webgpu-test-stage.yml +++ b/tools/ci_build/github/azure-pipelines/stages/plugin-mac-webgpu-test-stage.yml @@ -30,7 +30,7 @@ stages: set -e -x python3 -m venv "$(Build.BinariesDirectory)/test_venv" source "$(Build.BinariesDirectory)/test_venv/bin/activate" - python3 -m pip install onnxruntime onnx numpy + python3 -m pip install onnx numpy wheel=$(find "$(Pipeline.Workspace)/build/webgpu_plugin_python_macos_arm64" -name "onnxruntime_ep_webgpu-*.whl" | head -1) python3 -m pip install "$wheel" python3 -u "$(Build.SourcesDirectory)/plugin-ep-webgpu/python/test/test_webgpu_plugin_ep.py" diff --git a/tools/ci_build/github/azure-pipelines/stages/plugin-webgpu-nuget-packaging-stage.yml b/tools/ci_build/github/azure-pipelines/stages/plugin-webgpu-nuget-packaging-stage.yml new file mode 100644 index 0000000000000..93210533d2dc0 --- /dev/null +++ b/tools/ci_build/github/azure-pipelines/stages/plugin-webgpu-nuget-packaging-stage.yml @@ -0,0 +1,186 @@ +# NuGet packaging stage for WebGPU plugin EP. +# Downloads platform-specific build artifacts, packs them into a single multi-platform NuGet package, +# signs it, and runs a basic test. + +parameters: +- name: package_version + type: string + +- name: version_file + type: string + +- name: DoEsrp + type: boolean + default: true + +- name: platforms + type: object + default: + win_x64: false + win_arm64: false + linux_x64: false + macos_arm64: false + +stages: +- stage: NuGet_Packaging + displayName: 'NuGet Packaging' + dependsOn: + - ${{ if eq(parameters.platforms.win_x64, true) }}: + - Win_plugin_webgpu_x64_Build + - ${{ if eq(parameters.platforms.win_arm64, true) }}: + - Win_plugin_webgpu_arm64_Build + - ${{ if eq(parameters.platforms.linux_x64, true) }}: + - Linux_plugin_webgpu_x64_Build + - ${{ if eq(parameters.platforms.macos_arm64, true) }}: + - MacOS_plugin_webgpu_arm64_Build + jobs: + # ---------- Pack job ---------- + - job: NuGet_Pack + displayName: 'Pack NuGet' + timeoutInMinutes: 30 + workspace: + clean: all + pool: + name: onnxruntime-Win-CPU-VS2022-Latest + os: windows + templateContext: + outputs: + - output: pipelineArtifact + targetPath: '$(Build.ArtifactStagingDirectory)\nuget' + artifactName: webgpu_plugin_nuget + variables: + - template: ../templates/common-variables.yml + - name: WebGpuPackStagingDir + value: '$(Build.BinariesDirectory)\webgpu_pack_staging' + # Common arguments shared by the Build and Pack invocations of pack_nuget.py. + - name: WebGpuPackNuGetCommonArgs + value: >- + --version "$(PluginPackageVersion)" + --output-dir "$(Build.ArtifactStagingDirectory)\nuget" + --staging-dir "$(WebGpuPackStagingDir)" + --configuration Release + --nuget-config "$(Build.SourcesDirectory)\NuGet.config" + steps: + - checkout: self + clean: true + submodules: none + + - template: ../templates/setup-build-tools.yml + parameters: + host_cpu_arch: 'x64' + + - template: ../templates/set-nightly-build-option-variable-step.yml + + - template: ../templates/set-plugin-build-variables-step.yml + parameters: + package_version: ${{ parameters.package_version }} + version_file: ${{ parameters.version_file }} + + # Download platform artifacts + - ${{ if eq(parameters.platforms.win_x64, true) }}: + - task: DownloadPipelineArtifact@2 + displayName: 'Download win-x64 artifacts' + inputs: + artifactName: webgpu_plugin_win_x64 + targetPath: '$(Build.BinariesDirectory)\artifacts\win_x64' + + - ${{ if eq(parameters.platforms.win_arm64, true) }}: + - task: DownloadPipelineArtifact@2 + displayName: 'Download win-arm64 artifacts' + inputs: + artifactName: webgpu_plugin_win_arm64 + targetPath: '$(Build.BinariesDirectory)\artifacts\win_arm64' + + - ${{ if eq(parameters.platforms.linux_x64, true) }}: + - task: DownloadPipelineArtifact@2 + displayName: 'Download linux-x64 artifacts' + inputs: + artifactName: webgpu_plugin_linux_x64 + targetPath: '$(Build.BinariesDirectory)\artifacts\linux_x64' + + - ${{ if eq(parameters.platforms.macos_arm64, true) }}: + - task: DownloadPipelineArtifact@2 + displayName: 'Download macos-arm64 artifacts' + inputs: + artifactName: webgpu_plugin_macos_arm64 + targetPath: '$(Build.BinariesDirectory)\artifacts\macos_arm64' + + # Compute the set of required platforms from the pipeline parameters and verify the + # corresponding artifact directories actually downloaded. This catches renamed/moved + # upstream artifacts loudly before any pack work, and feeds pack_nuget.py the same + # list so it fails fast if any required platform's binaries are missing. + - task: PythonScript@0 + displayName: 'Compute required platforms' + inputs: + scriptSource: inline + script: | + import os + import sys + + # The string literals below are filled in by ADO template expansion at queue + # time and resolve to a boolean value 'True' or 'False'. Compare case-insensitively. + platforms_enabled = { + "win_x64": "${{ parameters.platforms.win_x64 }}".lower() == "true", + "win_arm64": "${{ parameters.platforms.win_arm64 }}".lower() == "true", + "linux_x64": "${{ parameters.platforms.linux_x64 }}".lower() == "true", + "macos_arm64": "${{ parameters.platforms.macos_arm64 }}".lower() == "true", + } + expected = [name for name, enabled in platforms_enabled.items() if enabled] + + if not expected: + print("##vso[task.logissue type=error]No platforms enabled in 'platforms' parameter — nothing to pack.") + sys.exit(1) + + artifacts_dir = r"$(Build.BinariesDirectory)\artifacts" + missing = [ + f"{p} ({d})" + for p in expected + for d in [os.path.join(artifacts_dir, p, "bin")] + if not os.path.isdir(d) + ] + if missing: + print("##vso[task.logissue type=error]Expected artifact directories not found:") + for m in missing: + print(f"##vso[task.logissue type=error] {m}") + sys.exit(1) + + required = ",".join(expected) + print(f"Required platforms: {required}") + print(f"##vso[task.setvariable variable=WebGpuRequiredPlatforms]{required}") + + # Stage binaries and build the managed assembly (so it can be ESRP-signed before packing). + - task: PythonScript@0 + displayName: 'Build managed DLL' + inputs: + scriptSource: filePath + scriptPath: '$(Build.SourcesDirectory)\plugin-ep-webgpu\csharp\pack_nuget.py' + arguments: >- + $(WebGpuPackNuGetCommonArgs) + --artifacts-dir "$(Build.BinariesDirectory)\artifacts" + --required-platforms $(WebGpuRequiredPlatforms) + --build-only + + # ESRP-sign the managed DLL before it gets embedded in the .nupkg. + - template: ../templates/win-esrp-dll.yml + parameters: + FolderPath: '$(WebGpuPackStagingDir)' + Pattern: 'Microsoft.ML.OnnxRuntime.EP.WebGpu.dll' + DisplayName: 'ESRP - Sign managed DLL' + DoEsrp: ${{ parameters.DoEsrp }} + + # Pack the (now-signed) managed DLL plus native binaries into the .nupkg. + - task: PythonScript@0 + displayName: 'Pack NuGet package' + inputs: + scriptSource: filePath + scriptPath: '$(Build.SourcesDirectory)\plugin-ep-webgpu\csharp\pack_nuget.py' + arguments: >- + $(WebGpuPackNuGetCommonArgs) + --pack-only + + # ESRP sign + - template: ../templates/esrp_nuget.yml + parameters: + FolderPath: '$(Build.ArtifactStagingDirectory)\nuget' + DisplayName: 'ESRP - Sign NuGet package' + DoEsrp: ${{ parameters.DoEsrp }} diff --git a/tools/ci_build/github/azure-pipelines/stages/plugin-webgpu-packaging-stage.yml b/tools/ci_build/github/azure-pipelines/stages/plugin-webgpu-packaging-stage.yml index 6777f207d67b9..996d6fa1af0a6 100644 --- a/tools/ci_build/github/azure-pipelines/stages/plugin-webgpu-packaging-stage.yml +++ b/tools/ci_build/github/azure-pipelines/stages/plugin-webgpu-packaging-stage.yml @@ -48,9 +48,7 @@ stages: cmake_build_type: ${{ parameters.cmake_build_type }} # Windows ARM64 - # ARM64 build requires the x64 tblgen.exe (used during the build), which is not correctly - # generated in a cross build. So we require x64 to be built first and download tblgen.exe from it. - - ${{ if and(eq(parameters.build_windows_arm64, true), eq(parameters.build_windows_x64, true)) }}: + - ${{ if eq(parameters.build_windows_arm64, true) }}: - template: plugin-win-webgpu-stage.yml parameters: arch: 'arm64' @@ -74,13 +72,25 @@ stages: version_file: ${{ parameters.version_file }} cmake_build_type: ${{ parameters.cmake_build_type }} + # NuGet packaging (runs after all platform builds) + - template: plugin-webgpu-nuget-packaging-stage.yml + parameters: + package_version: ${{ parameters.package_version }} + version_file: ${{ parameters.version_file }} + DoEsrp: true + platforms: + win_x64: ${{ parameters.build_windows_x64 }} + win_arm64: ${{ parameters.build_windows_arm64 }} + linux_x64: ${{ parameters.build_linux_x64 }} + macos_arm64: ${{ parameters.build_macos_arm64 }} + # Create zip packages for Foundry Local consumption - stage: Package_Foundry_Local_WebGPU_Zips displayName: 'Package Foundry Local WebGPU Plugin-EP Zips' dependsOn: - ${{ if eq(parameters.build_windows_x64, true) }}: - Win_plugin_webgpu_x64_Build - - ${{ if and(eq(parameters.build_windows_arm64, true), eq(parameters.build_windows_x64, true)) }}: + - ${{ if eq(parameters.build_windows_arm64, true) }}: - Win_plugin_webgpu_arm64_Build - ${{ if eq(parameters.build_linux_x64, true) }}: - Linux_plugin_webgpu_x64_Build @@ -111,10 +121,7 @@ stages: artifactName: webgpu_plugin_win_x64 targetPath: $(Build.SourcesDirectory)/webgpu-plugin-win-x64 - # Windows ARM64 - # ARM64 build requires the x64 tblgen.exe (used during the build), which is not correctly - # generated in a cross build. So we require x64 to be built first and download tblgen.exe from it. - - ${{ if and(eq(parameters.build_windows_arm64, true), eq(parameters.build_windows_x64, true)) }}: + - ${{ if eq(parameters.build_windows_arm64, true) }}: - task: DownloadPipelineArtifact@2 displayName: 'Download webgpu_plugin_win_arm64' inputs: diff --git a/tools/ci_build/github/azure-pipelines/stages/plugin-win-webgpu-stage.yml b/tools/ci_build/github/azure-pipelines/stages/plugin-win-webgpu-stage.yml index c774d42776afb..332d5b0224f37 100644 --- a/tools/ci_build/github/azure-pipelines/stages/plugin-win-webgpu-stage.yml +++ b/tools/ci_build/github/azure-pipelines/stages/plugin-win-webgpu-stage.yml @@ -28,6 +28,7 @@ parameters: stages: - stage: Win_plugin_webgpu_${{ parameters.arch }}_Build ${{ if eq(parameters.arch, 'arm64') }}: + # The ARM64 build consumes the x64 tblgen.exe artifact published by the Windows x64 stage. dependsOn: Win_plugin_webgpu_x64_Build ${{ else }}: dependsOn: [] diff --git a/tools/ci_build/github/azure-pipelines/stages/plugin-win-webgpu-test-stage.yml b/tools/ci_build/github/azure-pipelines/stages/plugin-win-webgpu-test-stage.yml index af29a62d69329..1494584ff98fd 100644 --- a/tools/ci_build/github/azure-pipelines/stages/plugin-win-webgpu-test-stage.yml +++ b/tools/ci_build/github/azure-pipelines/stages/plugin-win-webgpu-test-stage.yml @@ -31,29 +31,103 @@ stages: artifact: webgpu_plugin_python_win_${{ parameters.arch }} displayName: 'Download Python wheel' - - task: PowerShell@2 + - pwsh: | + $ErrorActionPreference = 'Stop' + + echo "creating test_venv" + python -m venv "$(Build.BinariesDirectory)\test_venv" + + echo "activating test_venv" + & "$(Build.BinariesDirectory)\test_venv\Scripts\Activate.ps1" + + echo "installing test dependencies" + python -m pip install onnx numpy + + $wheelDir = "$(Pipeline.Workspace)\build\webgpu_plugin_python_win_${{ parameters.arch }}" + $wheel = (Get-ChildItem "$wheelDir\onnxruntime_ep_webgpu-*.whl")[0] + echo "installing ${wheel}" + python -m pip install $wheel.FullName + + echo "running test_webgpu_plugin_ep.py" + python -u "$(Build.SourcesDirectory)\plugin-ep-webgpu\python\test\test_webgpu_plugin_ep.py" displayName: 'Install and test Python package' env: ORT_TEST_VERBOSE: $(System.Debug) - inputs: - targetType: inline - pwsh: true - script: | + + # NuGet package test (x64 only — the NuGet package is multi-platform but + # the test runs on a single Windows agent that exercises the WebGPU EP). + - ${{ if eq(parameters.arch, 'x64') }}: + - job: Win_plugin_webgpu_nuget_Test + timeoutInMinutes: 30 + workspace: + clean: all + pool: + name: onnxruntime-Win2022-VS2022-webgpu-A10 + os: windows + variables: + WebGpuTestProject: '$(Build.SourcesDirectory)\plugin-ep-webgpu\csharp\test\WebGpuEpNuGetTest\WebGpuEpNuGetTest.csproj' + steps: + - checkout: self + submodules: none + + - template: ../templates/setup-feeds-and-python-steps.yml + + # Download the NuGet package produced by the packaging pipeline run that + # triggered this pipeline (or that was selected at queue time). + - download: build + artifact: webgpu_plugin_nuget + displayName: 'Download NuGet package' + + # Set up local NuGet feed and extract the package version from the .nupkg filename + # so the test project can pin to it (instead of resolving via a floating version). + - pwsh: | $ErrorActionPreference = 'Stop' + $localFeedDir = "$(Build.BinariesDirectory)\local_feed" + New-Item -ItemType Directory -Path $localFeedDir -Force | Out-Null - echo "creating test_venv" - python -m venv "$(Build.BinariesDirectory)\test_venv" + # Locate the .nupkg. + $nupkg = Get-ChildItem "$(Pipeline.Workspace)\build\webgpu_plugin_nuget\Microsoft.ML.OnnxRuntime.EP.WebGpu.*.nupkg" | + Select-Object -First 1 + if (-not $nupkg) { + throw "No matching .nupkg found under $(Pipeline.Workspace)\build\webgpu_plugin_nuget" + } + Copy-Item $nupkg.FullName $localFeedDir -Force - echo "activating test_venv" - & "$(Build.BinariesDirectory)\test_venv\Scripts\Activate.ps1" + # Extract version from filename: Microsoft.ML.OnnxRuntime.EP.WebGpu..nupkg + # The version starts with a digit, which disambiguates from any future filename suffixes. + if ($nupkg.BaseName -notmatch '^Microsoft\.ML\.OnnxRuntime\.EP\.WebGpu\.(\d.*)$') { + throw "Could not extract version from .nupkg filename: $($nupkg.Name)" + } + $packageVersion = $Matches[1] + Write-Host "Detected package version: $packageVersion" + Write-Host "##vso[task.setvariable variable=OrtWebGpuPackageVersion]$packageVersion" - echo "installing onnxruntime onnx numpy" - python -m pip install onnxruntime onnx numpy + # Write a project-level nuget.config that adds ONLY the local feed. + # NuGet merges this with the repo-root NuGet.config. + $nugetConfig = "$(Build.SourcesDirectory)\plugin-ep-webgpu\csharp\test\WebGpuEpNuGetTest\nuget.config" + Set-Content -Path $nugetConfig -Encoding UTF8 -Value @" + + + + + + + "@ + Write-Host "Wrote project-level nuget.config with local feed: $localFeedDir" + Write-Host "Local feed contents:" + Get-ChildItem $localFeedDir | ForEach-Object { Write-Host " $($_.Name)" } + displayName: 'Set up local NuGet feed' - $wheelDir = "$(Pipeline.Workspace)\build\webgpu_plugin_python_win_${{ parameters.arch }}" - $wheel = (Get-ChildItem "$wheelDir\onnxruntime_ep_webgpu-*.whl")[0] - echo "installing ${wheel}" - python -m pip install $wheel.FullName + - pwsh: | + dotnet build ` + "$(WebGpuTestProject)" ` + --configuration Release ` + -p:OrtWebGpuPackageVersion=$(OrtWebGpuPackageVersion) + displayName: 'Build test project' - echo "running test_webgpu_plugin_ep.py" - python -u "$(Build.SourcesDirectory)\plugin-ep-webgpu\python\test\test_webgpu_plugin_ep.py" + - pwsh: | + dotnet run ` + --project "$(WebGpuTestProject)" ` + --configuration Release ` + --no-build + displayName: 'Run NuGet package test' From ee5158e73d019fcd672e26629b1d718587055589 Mon Sep 17 00:00:00 2001 From: Copilot <198982749+Copilot@users.noreply.github.com> Date: Tue, 5 May 2026 22:13:30 +0000 Subject: [PATCH 20/34] Fill CUDA Cast operator opset gap: extend registration from opset 23 to 25 (#27744) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ### Description Extends CUDA Cast kernel registration to cover opset 25 (latest ONNX spec). The existing non-versioned opset 23 registration is capped to VERSIONED (23, 24), and a new non-versioned opset 25 registration is added for all type specializations. **`cast_op.cc`**: - `REGISTER_KERNEL_TYPED(T)`: opset 23 → VERSIONED (23, 24), added non-versioned opset 25 - Renamed `REGISTER_KERNEL_TYPED_23` → `REGISTER_KERNEL_TYPED_23_TO_24` (VERSIONED) - Added `REGISTER_KERNEL_TYPED_25` macro (non-versioned) - Renamed `SPECIALIZE_IMPL_19_TO_23` → `SPECIALIZE_IMPL_19_TO_25`, covering Float8 types through opset 25 - Updated Float4E2M1x2 registration to use new versioned/non-versioned macros **`cuda_execution_provider.cc`**: - Forward declarations: all opset 23 Cast entries → VERSIONED (23, 24), added opset 25 non-versioned entries (all 16 types: 13 standard + 2 Float8 + 1 Float4) - `BuildKernelCreateInfo`: same pattern — capped 23 to (23, 24), added opset 25 block ### Motivation and Context CUDA Cast operator was registered up to opset 23, but ONNX spec defines Cast through opset 25. This gap can cause kernel lookup failures when running models exported at opset 25. Part of the broader CUDA opset gap-filling effort tracked in #27729. --------- Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com> Co-authored-by: tianleiwu <30328909+tianleiwu@users.noreply.github.com> Co-authored-by: Tianlei Wu Co-authored-by: Copilot --- docs/OperatorKernels.md | 3 +- .../providers/cuda/cuda_execution_provider.cc | 104 +++++++--- .../core/providers/cuda/tensor/cast_op.cc | 42 +++- .../test/providers/cpu/tensor/cast_op_test.cc | 187 +++++++++++++++++- 4 files changed, 290 insertions(+), 46 deletions(-) diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md index 28b282b25f8f6..05e8b072e130e 100644 --- a/docs/OperatorKernels.md +++ b/docs/OperatorKernels.md @@ -680,7 +680,8 @@ The **OpSet Version** column uses the following notation: |||14|**T** = tensor(double), tensor(float), tensor(float16)
**U** = tensor(double), tensor(float), tensor(float16)| |||[9, 13]|**T** = tensor(double), tensor(float), tensor(float16)| |||[7, 8]|**T** = tensor(double), tensor(float), tensor(float16)| -|Cast|*in* input:**T1**
*out* output:**T2**|23+|**T1** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float4e2m1), tensor(float8e4m3fn), tensor(float8e5m2), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T2** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float4e2m1), tensor(float8e4m3fn), tensor(float8e5m2), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| +|Cast|*in* input:**T1**
*out* output:**T2**|25+|**T1** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float4e2m1), tensor(float8e4m3fn), tensor(float8e5m2), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T2** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float4e2m1), tensor(float8e4m3fn), tensor(float8e5m2), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| +|||[23, 24]|**T1** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float4e2m1), tensor(float8e4m3fn), tensor(float8e5m2), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T2** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float4e2m1), tensor(float8e4m3fn), tensor(float8e5m2), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |||[21, 22]|**T1** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e5m2), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T2** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float4e2m1), tensor(float8e4m3fn), tensor(float8e5m2), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |||[19, 20]|**T1** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e5m2), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T2** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float4e2m1), tensor(float8e4m3fn), tensor(float8e5m2), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |||[13, 18]|**T1** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T2** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float4e2m1), tensor(float8e4m3fn), tensor(float8e5m2), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| diff --git a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc index 8b139c2d5514f..ffe255f4277c2 100755 --- a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc +++ b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc @@ -1681,25 +1681,25 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, 23, float, Attention); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, 23, MLFloat16, Attention); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, 23, BFloat16, Attention); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, float, Cast); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, double, Cast); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, MLFloat16, Cast); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, BFloat16, Cast); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, int8_t, Cast); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, int16_t, Cast); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, int32_t, Cast); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, int64_t, Cast); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, uint8_t, Cast); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, uint16_t, Cast); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, uint32_t, Cast); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, uint64_t, Cast); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, bool, Cast); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, 24, float, Cast); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, 24, double, Cast); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, 24, MLFloat16, Cast); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, 24, BFloat16, Cast); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, 24, int8_t, Cast); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, 24, int16_t, Cast); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, 24, int32_t, Cast); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, 24, int64_t, Cast); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, 24, uint8_t, Cast); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, 24, uint16_t, Cast); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, 24, uint32_t, Cast); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, 24, uint64_t, Cast); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, 24, bool, Cast); #if !defined(DISABLE_FLOAT8_TYPES) -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, Float8E4M3FN, Cast); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, Float8E5M2, Cast); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, 24, Float8E4M3FN, Cast); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, 24, Float8E5M2, Cast); #endif #if !defined(DISABLE_FLOAT4_TYPES) -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, Float4E2M1x2, Cast); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, 24, Float4E2M1x2, Cast); #endif class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, 24, ConstantOfShape); class ONNX_OPERATOR_VERSIONED_TWO_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, 24, uint8_t, float, DequantizeLinear); @@ -1769,6 +1769,26 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 24, T class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 24, 24, Unsqueeze); // Opset 25. +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 25, float, Cast); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 25, double, Cast); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 25, MLFloat16, Cast); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 25, BFloat16, Cast); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 25, int8_t, Cast); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 25, int16_t, Cast); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 25, int32_t, Cast); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 25, int64_t, Cast); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 25, uint8_t, Cast); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 25, uint16_t, Cast); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 25, uint32_t, Cast); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 25, uint64_t, Cast); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 25, bool, Cast); +#if !defined(DISABLE_FLOAT8_TYPES) +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 25, Float8E4M3FN, Cast); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 25, Float8E5M2, Cast); +#endif +#if !defined(DISABLE_FLOAT4_TYPES) +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 25, Float4E2M1x2, Cast); +#endif class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 25, ConstantOfShape); class ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 25, uint8_t, float, DequantizeLinear); class ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 25, int8_t, float, DequantizeLinear); @@ -2885,25 +2905,25 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, #if !defined(DISABLE_FLOAT8_TYPES) - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, #endif #if !defined(DISABLE_FLOAT4_TYPES) - BuildKernelCreateInfo, + BuildKernelCreateInfo, #endif BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -2973,6 +2993,26 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, // Opset 25 + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, +#if !defined(DISABLE_FLOAT8_TYPES) + BuildKernelCreateInfo, + BuildKernelCreateInfo, +#endif +#if !defined(DISABLE_FLOAT4_TYPES) + BuildKernelCreateInfo, +#endif BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, diff --git a/onnxruntime/core/providers/cuda/tensor/cast_op.cc b/onnxruntime/core/providers/cuda/tensor/cast_op.cc index 8f5c9202c1dba..2ed08e25d02d2 100644 --- a/onnxruntime/core/providers/cuda/tensor/cast_op.cc +++ b/onnxruntime/core/providers/cuda/tensor/cast_op.cc @@ -90,10 +90,20 @@ const std::vector& CastOpTypeConstraints() { .TypeConstraint("T1", DataTypeImpl::GetTensorType()) \ .TypeConstraint("T2", CastOpTypeConstraints()), \ Cast); \ + ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \ + Cast, \ + kOnnxDomain, \ + 23, 24, \ + T, \ + kCudaExecutionProvider, \ + (*KernelDefBuilder::Create()) \ + .TypeConstraint("T1", DataTypeImpl::GetTensorType()) \ + .TypeConstraint("T2", CastOpTypeConstraints()), \ + Cast); \ ONNX_OPERATOR_TYPED_KERNEL_EX( \ Cast, \ kOnnxDomain, \ - 23, \ + 25, \ T, \ kCudaExecutionProvider, \ (*KernelDefBuilder::Create()) \ @@ -389,11 +399,23 @@ SPECIALIZE_IMPL(BFloat16) .TypeConstraint("T2", CastOpTypeConstraints()), \ Cast); -#define REGISTER_KERNEL_TYPED_23(T, OutputTypeConstraints) \ +#define REGISTER_KERNEL_TYPED_23_TO_24(T, OutputTypeConstraints) \ + ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \ + Cast, \ + kOnnxDomain, \ + 23, 24, \ + T, \ + kCudaExecutionProvider, \ + (*KernelDefBuilder::Create()) \ + .TypeConstraint("T1", DataTypeImpl::GetTensorType()) \ + .TypeConstraint("T2", OutputTypeConstraints), \ + Cast); + +#define REGISTER_KERNEL_TYPED_25(T, OutputTypeConstraints) \ ONNX_OPERATOR_TYPED_KERNEL_EX( \ Cast, \ kOnnxDomain, \ - 23, \ + 25, \ T, \ kCudaExecutionProvider, \ (*KernelDefBuilder::Create()) \ @@ -403,18 +425,20 @@ SPECIALIZE_IMPL(BFloat16) #if !defined(DISABLE_FLOAT8_TYPES) -#define SPECIALIZE_IMPL_19_TO_23(T) \ - REGISTER_KERNEL_TYPED_19_TO_22(T) \ - REGISTER_KERNEL_TYPED_23(T, CastOpTypeConstraints()) \ +#define SPECIALIZE_IMPL_19_TO_25(T) \ + REGISTER_KERNEL_TYPED_19_TO_22(T) \ + REGISTER_KERNEL_TYPED_23_TO_24(T, CastOpTypeConstraints()) \ + REGISTER_KERNEL_TYPED_25(T, CastOpTypeConstraints()) \ template Status Cast::ComputeInternal(OpKernelContext* context) const; -SPECIALIZE_IMPL_19_TO_23(Float8E4M3FN) -SPECIALIZE_IMPL_19_TO_23(Float8E5M2) +SPECIALIZE_IMPL_19_TO_25(Float8E4M3FN) +SPECIALIZE_IMPL_19_TO_25(Float8E5M2) #endif #if !defined(DISABLE_FLOAT4_TYPES) -REGISTER_KERNEL_TYPED_23(Float4E2M1x2, {DataTypeImpl::GetTensorType()}) +REGISTER_KERNEL_TYPED_23_TO_24(Float4E2M1x2, {DataTypeImpl::GetTensorType()}) +REGISTER_KERNEL_TYPED_25(Float4E2M1x2, {DataTypeImpl::GetTensorType()}) template Status Cast::ComputeInternal(OpKernelContext* context) const; #endif diff --git a/onnxruntime/test/providers/cpu/tensor/cast_op_test.cc b/onnxruntime/test/providers/cpu/tensor/cast_op_test.cc index d5b6630668000..4481cf36554cd 100644 --- a/onnxruntime/test/providers/cpu/tensor/cast_op_test.cc +++ b/onnxruntime/test/providers/cpu/tensor/cast_op_test.cc @@ -1,6 +1,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +#include #include #include "boost/mp11.hpp" @@ -2635,7 +2636,8 @@ TEST(CastOpTest, Float8E4M3FNToInt2x4_OddShape) { template void CastOpTestFloatFloat4(std::vector shape, std::vector float_data, - bool is_fp4_input = false) { + bool is_fp4_input = false, + int opset = 23) { int num_pairs = static_cast(float_data.size()) / 2; int num_fp4_elements = static_cast((float_data.size() + 1) / 2); bool is_odd_count = (float_data.size() % 2 != 0); @@ -2653,7 +2655,7 @@ void CastOpTestFloatFloat4(std::vector shape, if (!is_fp4_input) { TestCastOp(gsl::make_span(float_data), gsl::make_span(fp4_data), shape, - OpTester::ExpectResult::kExpectSuccess, "", 23, Saturate::None, true); + OpTester::ExpectResult::kExpectSuccess, "", opset, Saturate::None, true); } else { std::vector casted_back_float; @@ -2668,7 +2670,7 @@ void CastOpTestFloatFloat4(std::vector shape, } TestCastOp(gsl::make_span(fp4_data), gsl::make_span(casted_back_float), shape, - OpTester::ExpectResult::kExpectSuccess, "", 23, Saturate::None, true); + OpTester::ExpectResult::kExpectSuccess, "", opset, Saturate::None, true); } } @@ -2732,8 +2734,185 @@ TEST(CastOpTest, Float4E2M1x2ToFloat) { } } +// Opset 25 tests for Float4 types on CUDA +TEST(CastOpTest, FloatToFloat4E2M1x2_Opset25) { + CastOpTestFloatFloat4({2, 2, 2}, + {std::numeric_limits::infinity(), + -std::numeric_limits::infinity(), + 7.f, -7.f, + 0.5f, -0.5f, + std::numeric_limits::quiet_NaN(), + -std::numeric_limits::quiet_NaN()}, + false, 25); + + CastOpTestFloatFloat4({1, 3, 1}, + {0.256f, 0.987f, 43.8f}, + false, 25); +} + +TEST(CastOpTest, Float4E2M1x2ToFloat_Opset25) { + CastOpTestFloatFloat4({2, 2, 2}, + {0.5f, 7.34f, + 1.f, 1.5f, + 2.f, 3.f, + 4.f, 6.f}, + true, 25); + + CastOpTestFloatFloat4({1, 3, 1}, + {0.256f, 0.987f, 43.8f}, + true, 25); +} + #endif +// Opset 25 tests for standard types on CUDA. +// Verifies CUDA Cast kernel registration at opset 25 works for common type conversions. +#if defined(USE_CUDA) + +TEST(CastOpTest, StandardTypes_Opset25_Cuda) { + const std::vector shape{2, 3}; + + // float -> double + { + const std::vector input = {1.0f, 2.5f, -3.0f, 0.0f, 100.0f, -0.5f}; + const std::vector expected = {1.0, 2.5, -3.0, 0.0, 100.0, -0.5}; + TestCastOp(gsl::make_span(input), gsl::make_span(expected), shape, + OpTester::ExpectResult::kExpectSuccess, "", 25, Saturate::None, /*cuda_only=*/true); + } + + // double -> float + { + const std::vector input = {1.0, 2.5, -3.0, 0.0, 100.0, -0.5}; + const std::vector expected = {1.0f, 2.5f, -3.0f, 0.0f, 100.0f, -0.5f}; + TestCastOp(gsl::make_span(input), gsl::make_span(expected), shape, + OpTester::ExpectResult::kExpectSuccess, "", 25, Saturate::None, /*cuda_only=*/true); + } + + // float -> int32_t + { + const std::vector input = {1.0f, 2.9f, -3.0f, 0.0f, 100.0f, -0.5f}; + const std::vector expected = {1, 2, -3, 0, 100, 0}; + TestCastOp(gsl::make_span(input), gsl::make_span(expected), shape, + OpTester::ExpectResult::kExpectSuccess, "", 25, Saturate::None, /*cuda_only=*/true); + } + + // int32_t -> float + { + const std::vector input = {1, 2, -3, 0, 100, -7}; + const std::vector expected = {1.0f, 2.0f, -3.0f, 0.0f, 100.0f, -7.0f}; + TestCastOp(gsl::make_span(input), gsl::make_span(expected), shape, + OpTester::ExpectResult::kExpectSuccess, "", 25, Saturate::None, /*cuda_only=*/true); + } + + // float -> MLFloat16 + if (HasCudaEnvironment(530)) { + const std::vector input = {1.0f, 2.5f, -3.0f, 0.0f, 100.0f, -0.5f}; + const std::vector expected = CastedValues(gsl::make_span(input)); + TestCastOp(gsl::make_span(input), gsl::make_span(expected), shape, + OpTester::ExpectResult::kExpectSuccess, "", 25, Saturate::None, /*cuda_only=*/true); + } + + // MLFloat16 -> float + if (HasCudaEnvironment(530)) { + const std::vector input = CastedValues( + gsl::make_span(std::vector{1.0f, 2.5f, -3.0f, 0.0f, 100.0f, -0.5f})); + const std::vector expected = {1.0f, 2.5f, -3.0f, 0.0f, 100.0f, -0.5f}; + TestCastOp(gsl::make_span(input), gsl::make_span(expected), shape, + OpTester::ExpectResult::kExpectSuccess, "", 25, Saturate::None, /*cuda_only=*/true); + } + + // BFloat16 -> float + if (HasCudaEnvironment(800)) { + const std::vector input = CastedValues( + gsl::make_span(std::vector{1.0f, 2.5f, -3.0f, 0.0f, 100.0f, -0.5f})); + const std::vector expected = CastedValues(gsl::make_span(input)); + TestCastOp(gsl::make_span(input), gsl::make_span(expected), shape, + OpTester::ExpectResult::kExpectSuccess, "", 25, Saturate::None, /*cuda_only=*/true); + } + + // bool -> float + { + const bool input[] = {true, false, true, true, false, false}; + const gsl::span input_span(input); + const std::vector expected = {1.0f, 0.0f, 1.0f, 1.0f, 0.0f, 0.0f}; + TestCastOp(input_span, gsl::make_span(expected), shape, + OpTester::ExpectResult::kExpectSuccess, "", 25, Saturate::None, /*cuda_only=*/true); + } +} + +#if !defined(DISABLE_FLOAT8_TYPES) + +TEST(CastOpTest, Float8_Opset25_Cuda) { + constexpr int min_cuda_architecture = 11080; + if (!HasCudaEnvironment(min_cuda_architecture)) { + return; + } + + const std::vector shape{2, 2, 2}; + const std::vector float_input = {NAN, -1.f, 0.0391877927f, 0.296140194f, + -0.120196559f, 5.0f, + -std::numeric_limits::infinity(), + std::numeric_limits::infinity()}; + + // Float8E4M3FN: float -> Float8E4M3FN at opset 25 + { + std::vector output; + output.reserve(float_input.size()); + for (size_t i = 0; i < float_input.size(); ++i) { + output.emplace_back(Float8E4M3FN(float_input[i], true)); + } + TestCastOp(gsl::make_span(float_input), gsl::make_span(output), shape, + OpTester::ExpectResult::kExpectSuccess, "", 25, Saturate::True, /*cuda_only=*/true); + } + + // Float8E5M2: float -> Float8E5M2 at opset 25 + { + std::vector output; + output.reserve(float_input.size()); + for (size_t i = 0; i < float_input.size(); ++i) { + output.emplace_back(Float8E5M2(float_input[i], true)); + } + TestCastOp(gsl::make_span(float_input), gsl::make_span(output), shape, + OpTester::ExpectResult::kExpectSuccess, "", 25, Saturate::True, /*cuda_only=*/true); + } + + // Float8E4M3FN -> float at opset 25 + { + std::vector input; + input.reserve(float_input.size()); + for (size_t i = 0; i < float_input.size(); ++i) { + input.emplace_back(Float8E4M3FN(float_input[i], true)); + } + std::vector expected; + expected.reserve(input.size()); + for (const auto& v : input) { + expected.push_back(v.ToFloat()); + } + TestCastOp(gsl::make_span(input), gsl::make_span(expected), shape, + OpTester::ExpectResult::kExpectSuccess, "", 25, Saturate::None, /*cuda_only=*/true); + } + + // Float8E5M2 -> float at opset 25 + { + std::vector input; + input.reserve(float_input.size()); + for (size_t i = 0; i < float_input.size(); ++i) { + input.emplace_back(Float8E5M2(float_input[i], true)); + } + std::vector expected; + expected.reserve(input.size()); + for (const auto& v : input) { + expected.push_back(v.ToFloat()); + } + TestCastOp(gsl::make_span(input), gsl::make_span(expected), shape, + OpTester::ExpectResult::kExpectSuccess, "", 25, Saturate::None, /*cuda_only=*/true); + } +} + +#endif // !defined(DISABLE_FLOAT8_TYPES) + +#endif // defined(USE_CUDA) + // Regression tests for sub-byte same-type cast (CopyCpuTensor heap overflow fix). // When src and dst types are the same, Cast::Compute calls CopyCpuTensor which must // use SizeInBytes() (not shape.Size() * DataType()->Size()) for the memcpy byte count. @@ -2835,7 +3014,7 @@ TEST(CastOpTest, UInt2x4ToUInt2x4_LargeShape) { // Direct CopyCpuTensor test with guaranteed distinct buffers to exercise the memcpy path. // This bypasses the MayInplace optimization that can alias input/output in OpTester. // Uses guard bytes after the valid buffer region to detect overflow deterministically -// without relying on ASan — the pre-fix code would overwrite these sentinel bytes. +// without relying on ASan; the pre-fix code would overwrite these sentinel bytes. TEST(CastOpTest, CopyCpuTensor_SubByteTypes_DistinctBuffers) { constexpr uint8_t kGuardByte = 0xCD; constexpr size_t kGuardSize = 64; From 28bcc9cf304b3af4eece60700c09f2e2561ba655 Mon Sep 17 00:00:00 2001 From: Copilot <198982749+Copilot@users.noreply.github.com> Date: Tue, 5 May 2026 23:04:36 +0000 Subject: [PATCH 21/34] =?UTF-8?q?Fill=20CUDA=20opset=20gap=20for=20ReduceM?= =?UTF-8?q?ax=20and=20ReduceMin=20(18=20=E2=86=92=2020)=20(#27755)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ### Description Extends CUDA ReduceMax and ReduceMin kernel registrations from opset 18 to opset 20. - **`reduction_ops.cc`**: Added `REGISTER_KERNEL_VERSIONED_RANGE_AXES_INPUT_TYPED` macro for versioned ranges requiring `InputMemoryType(OrtMemTypeCPUInput, 1)`. Split both operators from 2-way (1–17, 18+) to 3-way (1–17, 18–19, 20+). - **`cuda_execution_provider.cc`**: Capped opset 18 forward declarations and `BuildKernelCreateInfo` entries to versioned 18–19. Added opset 20 non-versioned entries for both operators. Type coverage maintained as-is: ReduceMax (float, double, MLFloat16, int32_t, int64_t), ReduceMin adds int8_t, uint8_t. ### Motivation and Context ReduceMax and ReduceMin CUDA registrations stopped at opset 18; ONNX latest is opset 20. Models exported with opset 19–20 could fail to find a matching CUDA kernel for these ops. Follows the same pattern used in #27735 (TopK) and other opset gap PRs tracked in #27729. --------- Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com> Co-authored-by: tianleiwu <30328909+tianleiwu@users.noreply.github.com> Co-authored-by: Tianlei Wu --- .../providers/cuda/cuda_execution_provider.cc | 86 +++++--- .../providers/cuda/reduction/reduction_ops.cc | 66 ++++-- .../cpu/reduction/reduction_ops_test.cc | 204 ++++++++++++++++++ 3 files changed, 318 insertions(+), 38 deletions(-) diff --git a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc index ffe255f4277c2..6cd906f5f1ea6 100755 --- a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc +++ b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc @@ -840,6 +840,8 @@ class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kO class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 17, MLFloat16, ReduceMax); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 17, int32_t, ReduceMax); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 17, int64_t, ReduceMax); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 17, int8_t, ReduceMax); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 17, uint8_t, ReduceMax); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 17, float, ReduceMean); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 17, double, ReduceMean); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 17, MLFloat16, ReduceMean); @@ -1420,13 +1422,13 @@ class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME( kCudaExecutionProvider, kOnnxDomain, 14, 14, double, BatchNormalization); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME( kCudaExecutionProvider, kOnnxDomain, 14, 14, MLFloat16, BatchNormalization); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, float, ReduceMin); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, double, ReduceMin); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, MLFloat16, ReduceMin); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, int32_t, ReduceMin); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, int8_t, ReduceMin); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, uint8_t, ReduceMin); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, int64_t, ReduceMin); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, 19, float, ReduceMin); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, 19, double, ReduceMin); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, 19, MLFloat16, ReduceMin); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, 19, int32_t, ReduceMin); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, 19, int8_t, ReduceMin); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, 19, uint8_t, ReduceMin); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, 19, int64_t, ReduceMin); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, Trilu); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, BFloat16, Add); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, BFloat16, Sub); @@ -1486,11 +1488,13 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, // Opset 18 class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, Split); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, float, ReduceMax); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, double, ReduceMax); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, MLFloat16, ReduceMax); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, int32_t, ReduceMax); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, int64_t, ReduceMax); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, 19, float, ReduceMax); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, 19, double, ReduceMax); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, 19, MLFloat16, ReduceMax); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, 19, int32_t, ReduceMax); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, 19, int64_t, ReduceMax); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, 19, int8_t, ReduceMax); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, 19, uint8_t, ReduceMax); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, ScatterElements); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, ScatterND); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, 18, float, Pad); @@ -1577,6 +1581,21 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 20, IsInf); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 20, IsNaN); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 20, 21, float, GridSample); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 20, float, ReduceMax); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 20, double, ReduceMax); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 20, MLFloat16, ReduceMax); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 20, int32_t, ReduceMax); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 20, int64_t, ReduceMax); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 20, int8_t, ReduceMax); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 20, uint8_t, ReduceMax); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 20, float, ReduceMin); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 20, double, ReduceMin); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 20, MLFloat16, ReduceMin); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 20, int32_t, ReduceMin); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 20, int8_t, ReduceMin); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 20, uint8_t, ReduceMin); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 20, int64_t, ReduceMin); + // Opset 21. class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 21, 22, float, Cast); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 21, 22, double, Cast); @@ -2064,6 +2083,8 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -2701,18 +2722,20 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) { // Opset 18 BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -2801,6 +2824,21 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + // Opset 21 BuildKernelCreateInfo, BuildKernelCreateInfo, diff --git a/onnxruntime/core/providers/cuda/reduction/reduction_ops.cc b/onnxruntime/core/providers/cuda/reduction/reduction_ops.cc index 127cfcc557fd5..a0a2f377d0c80 100644 --- a/onnxruntime/core/providers/cuda/reduction/reduction_ops.cc +++ b/onnxruntime/core/providers/cuda/reduction/reduction_ops.cc @@ -36,6 +36,16 @@ namespace cuda { (*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType()).InputMemoryType(OrtMemTypeCPUInput, 1), \ name); +#define REGISTER_KERNEL_VERSIONED_RANGE_AXES_INPUT_TYPED(name, T, begin, end) \ + ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \ + name, \ + kOnnxDomain, \ + begin, end, \ + T, \ + kCudaExecutionProvider, \ + (*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType()).InputMemoryType(OrtMemTypeCPUInput, 1), \ + name); + #define REGISTER_KERNEL_TYPED_AXES_INPUT_WITH_VERSIONED(name, T, last, cur) \ REGISTER_KERNEL_VERSIONED_RANGE_TYPED(name, T, 1, last) \ REGISTER_KERNEL_VERSIONED_SINCE_TYPED(name, T, cur) @@ -876,13 +886,27 @@ REGISTER_KERNEL_ARGMIN_OR_ARGMAX(ArgMin, MLFloat16) REGISTER_KERNEL_ARGMIN_OR_ARGMAX(ArgMin, float) REGISTER_KERNEL_ARGMIN_OR_ARGMAX(ArgMin, double) -REGISTER_KERNEL_TYPED_AXES_INPUT_WITH_VERSIONED(ReduceMax, MLFloat16, 17, 18) -REGISTER_KERNEL_TYPED_AXES_INPUT_WITH_VERSIONED(ReduceMax, float, 17, 18) -REGISTER_KERNEL_TYPED_AXES_INPUT_WITH_VERSIONED(ReduceMax, double, 17, 18) -REGISTER_KERNEL_TYPED_AXES_INPUT_WITH_VERSIONED(ReduceMax, int32_t, 17, 18) -REGISTER_KERNEL_TYPED_AXES_INPUT_WITH_VERSIONED(ReduceMax, int64_t, 17, 18) -REGISTER_KERNEL_TYPED_AXES_INPUT_WITH_VERSIONED(ReduceMax, int8_t, 17, 18) -REGISTER_KERNEL_TYPED_AXES_INPUT_WITH_VERSIONED(ReduceMax, uint8_t, 17, 18) +REGISTER_KERNEL_VERSIONED_RANGE_TYPED(ReduceMax, MLFloat16, 1, 17) +REGISTER_KERNEL_VERSIONED_RANGE_TYPED(ReduceMax, float, 1, 17) +REGISTER_KERNEL_VERSIONED_RANGE_TYPED(ReduceMax, double, 1, 17) +REGISTER_KERNEL_VERSIONED_RANGE_TYPED(ReduceMax, int32_t, 1, 17) +REGISTER_KERNEL_VERSIONED_RANGE_TYPED(ReduceMax, int64_t, 1, 17) +REGISTER_KERNEL_VERSIONED_RANGE_TYPED(ReduceMax, int8_t, 1, 17) +REGISTER_KERNEL_VERSIONED_RANGE_TYPED(ReduceMax, uint8_t, 1, 17) +REGISTER_KERNEL_VERSIONED_RANGE_AXES_INPUT_TYPED(ReduceMax, MLFloat16, 18, 19) +REGISTER_KERNEL_VERSIONED_RANGE_AXES_INPUT_TYPED(ReduceMax, float, 18, 19) +REGISTER_KERNEL_VERSIONED_RANGE_AXES_INPUT_TYPED(ReduceMax, double, 18, 19) +REGISTER_KERNEL_VERSIONED_RANGE_AXES_INPUT_TYPED(ReduceMax, int32_t, 18, 19) +REGISTER_KERNEL_VERSIONED_RANGE_AXES_INPUT_TYPED(ReduceMax, int64_t, 18, 19) +REGISTER_KERNEL_VERSIONED_RANGE_AXES_INPUT_TYPED(ReduceMax, int8_t, 18, 19) +REGISTER_KERNEL_VERSIONED_RANGE_AXES_INPUT_TYPED(ReduceMax, uint8_t, 18, 19) +REGISTER_KERNEL_VERSIONED_SINCE_TYPED(ReduceMax, MLFloat16, 20) +REGISTER_KERNEL_VERSIONED_SINCE_TYPED(ReduceMax, float, 20) +REGISTER_KERNEL_VERSIONED_SINCE_TYPED(ReduceMax, double, 20) +REGISTER_KERNEL_VERSIONED_SINCE_TYPED(ReduceMax, int32_t, 20) +REGISTER_KERNEL_VERSIONED_SINCE_TYPED(ReduceMax, int64_t, 20) +REGISTER_KERNEL_VERSIONED_SINCE_TYPED(ReduceMax, int8_t, 20) +REGISTER_KERNEL_VERSIONED_SINCE_TYPED(ReduceMax, uint8_t, 20) REGISTER_KERNEL_TYPED_AXES_INPUT_WITH_VERSIONED(ReduceMean, MLFloat16, 17, 18) REGISTER_KERNEL_TYPED_AXES_INPUT_WITH_VERSIONED(ReduceMean, float, 17, 18) @@ -890,13 +914,27 @@ REGISTER_KERNEL_TYPED_AXES_INPUT_WITH_VERSIONED(ReduceMean, double, 17, 18) REGISTER_KERNEL_TYPED_AXES_INPUT_WITH_VERSIONED(ReduceMean, BFloat16, 17, 18) REGISTER_KERNEL_TYPED_AXES_INPUT_WITH_VERSIONED(ReduceMean, int32_t, 17, 18) -REGISTER_KERNEL_TYPED_AXES_INPUT_WITH_VERSIONED(ReduceMin, MLFloat16, 17, 18) -REGISTER_KERNEL_TYPED_AXES_INPUT_WITH_VERSIONED(ReduceMin, float, 17, 18) -REGISTER_KERNEL_TYPED_AXES_INPUT_WITH_VERSIONED(ReduceMin, double, 17, 18) -REGISTER_KERNEL_TYPED_AXES_INPUT_WITH_VERSIONED(ReduceMin, int32_t, 17, 18) -REGISTER_KERNEL_TYPED_AXES_INPUT_WITH_VERSIONED(ReduceMin, int64_t, 17, 18) -REGISTER_KERNEL_TYPED_AXES_INPUT_WITH_VERSIONED(ReduceMin, int8_t, 17, 18) -REGISTER_KERNEL_TYPED_AXES_INPUT_WITH_VERSIONED(ReduceMin, uint8_t, 17, 18) +REGISTER_KERNEL_VERSIONED_RANGE_TYPED(ReduceMin, MLFloat16, 1, 17) +REGISTER_KERNEL_VERSIONED_RANGE_TYPED(ReduceMin, float, 1, 17) +REGISTER_KERNEL_VERSIONED_RANGE_TYPED(ReduceMin, double, 1, 17) +REGISTER_KERNEL_VERSIONED_RANGE_TYPED(ReduceMin, int32_t, 1, 17) +REGISTER_KERNEL_VERSIONED_RANGE_TYPED(ReduceMin, int64_t, 1, 17) +REGISTER_KERNEL_VERSIONED_RANGE_TYPED(ReduceMin, int8_t, 1, 17) +REGISTER_KERNEL_VERSIONED_RANGE_TYPED(ReduceMin, uint8_t, 1, 17) +REGISTER_KERNEL_VERSIONED_RANGE_AXES_INPUT_TYPED(ReduceMin, MLFloat16, 18, 19) +REGISTER_KERNEL_VERSIONED_RANGE_AXES_INPUT_TYPED(ReduceMin, float, 18, 19) +REGISTER_KERNEL_VERSIONED_RANGE_AXES_INPUT_TYPED(ReduceMin, double, 18, 19) +REGISTER_KERNEL_VERSIONED_RANGE_AXES_INPUT_TYPED(ReduceMin, int32_t, 18, 19) +REGISTER_KERNEL_VERSIONED_RANGE_AXES_INPUT_TYPED(ReduceMin, int64_t, 18, 19) +REGISTER_KERNEL_VERSIONED_RANGE_AXES_INPUT_TYPED(ReduceMin, int8_t, 18, 19) +REGISTER_KERNEL_VERSIONED_RANGE_AXES_INPUT_TYPED(ReduceMin, uint8_t, 18, 19) +REGISTER_KERNEL_VERSIONED_SINCE_TYPED(ReduceMin, MLFloat16, 20) +REGISTER_KERNEL_VERSIONED_SINCE_TYPED(ReduceMin, float, 20) +REGISTER_KERNEL_VERSIONED_SINCE_TYPED(ReduceMin, double, 20) +REGISTER_KERNEL_VERSIONED_SINCE_TYPED(ReduceMin, int32_t, 20) +REGISTER_KERNEL_VERSIONED_SINCE_TYPED(ReduceMin, int64_t, 20) +REGISTER_KERNEL_VERSIONED_SINCE_TYPED(ReduceMin, int8_t, 20) +REGISTER_KERNEL_VERSIONED_SINCE_TYPED(ReduceMin, uint8_t, 20) REGISTER_KERNEL_TYPED_AXES_INPUT_WITH_VERSIONED(ReduceProd, MLFloat16, 17, 18) REGISTER_KERNEL_TYPED_AXES_INPUT_WITH_VERSIONED(ReduceProd, float, 17, 18) diff --git a/onnxruntime/test/providers/cpu/reduction/reduction_ops_test.cc b/onnxruntime/test/providers/cpu/reduction/reduction_ops_test.cc index 52e8b55cb3b98..79617dc16e1f5 100644 --- a/onnxruntime/test/providers/cpu/reduction/reduction_ops_test.cc +++ b/onnxruntime/test/providers/cpu/reduction/reduction_ops_test.cc @@ -6318,5 +6318,209 @@ TEST(ReductionOpTest, ReduceSumSquare_NoopWithAxesNotProvided_ElementwiseSquare) test.ConfigEp(DefaultCpuExecutionProvider()).RunWithConfig(); } +// Opset 20 tests for ReduceMax and ReduceMin on CUDA. +// Verifies CUDA kernel registration at opset 20 works for all supported types. +#if defined(USE_CUDA) + +TEST(ReductionOpTest, ReduceMax_float_Opset20_Cuda) { + OpTester test("ReduceMax", 20); + test.AddAttribute("keepdims", (int64_t)1); + test.AddInput("data", {3, 2, 2}, + {1.0f, 2.0f, 3.0f, 4.0f, + 5.0f, 6.0f, 7.0f, 8.0f, + 9.0f, 10.0f, 11.0f, 12.0f}); + test.AddInput("axes", {2}, {1, 2}); + test.AddOutput("reduced", {3, 1, 1}, {4.0f, 8.0f, 12.0f}); + std::vector> execution_providers; + execution_providers.push_back(DefaultCudaExecutionProvider()); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); +} + +TEST(ReductionOpTest, ReduceMax_double_Opset20_Cuda) { + OpTester test("ReduceMax", 20); + test.AddAttribute("keepdims", (int64_t)1); + test.AddInput("data", {3, 2, 2}, + {1.0, 2.0, 3.0, 4.0, + 5.0, 6.0, 7.0, 8.0, + 9.0, 10.0, 11.0, 12.0}); + test.AddInput("axes", {2}, {1, 2}); + test.AddOutput("reduced", {3, 1, 1}, {4.0, 8.0, 12.0}); + std::vector> execution_providers; + execution_providers.push_back(DefaultCudaExecutionProvider()); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); +} + +TEST(ReductionOpTest, ReduceMax_half_Opset20_Cuda) { + OpTester test("ReduceMax", 20); + test.AddAttribute("keepdims", (int64_t)1); + test.AddInput("data", {3, 2, 2}, + FloatsToMLFloat16s({1.0f, 2.0f, 3.0f, 4.0f, + 5.0f, 6.0f, 7.0f, 8.0f, + 9.0f, 10.0f, 11.0f, 12.0f})); + test.AddInput("axes", {2}, {1, 2}); + test.AddOutput("reduced", {3, 1, 1}, FloatsToMLFloat16s({4.0f, 8.0f, 12.0f})); + std::vector> execution_providers; + execution_providers.push_back(DefaultCudaExecutionProvider()); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); +} + +TEST(ReductionOpTest, ReduceMax_int32_Opset20_Cuda) { + OpTester test("ReduceMax", 20); + test.AddAttribute("keepdims", (int64_t)1); + test.AddInput("data", {3, 2, 2}, + {1, 2, 3, 4, + 5, 6, 7, 8, + 9, 10, 11, 12}); + test.AddInput("axes", {2}, {1, 2}); + test.AddOutput("reduced", {3, 1, 1}, {4, 8, 12}); + std::vector> execution_providers; + execution_providers.push_back(DefaultCudaExecutionProvider()); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); +} + +TEST(ReductionOpTest, ReduceMax_int64_Opset20_Cuda) { + OpTester test("ReduceMax", 20); + test.AddAttribute("keepdims", (int64_t)1); + test.AddInput("data", {3, 2, 2}, + {1, 2, 3, 4, + 5, 6, 7, 8, + 9, 10, 11, 12}); + test.AddInput("axes", {2}, {1, 2}); + test.AddOutput("reduced", {3, 1, 1}, {4, 8, 12}); + std::vector> execution_providers; + execution_providers.push_back(DefaultCudaExecutionProvider()); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); +} + +TEST(ReductionOpTest, ReduceMin_float_Opset20_Cuda) { + OpTester test("ReduceMin", 20); + test.AddAttribute("keepdims", (int64_t)1); + test.AddInput("data", {3, 2, 2}, + {1.0f, 2.0f, 3.0f, 4.0f, + 5.0f, 6.0f, 7.0f, 8.0f, + 9.0f, 10.0f, 11.0f, 12.0f}); + test.AddInput("axes", {2}, {1, 2}); + test.AddOutput("reduced", {3, 1, 1}, {1.0f, 5.0f, 9.0f}); + std::vector> execution_providers; + execution_providers.push_back(DefaultCudaExecutionProvider()); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); +} + +TEST(ReductionOpTest, ReduceMin_double_Opset20_Cuda) { + OpTester test("ReduceMin", 20); + test.AddAttribute("keepdims", (int64_t)1); + test.AddInput("data", {3, 2, 2}, + {1.0, 2.0, 3.0, 4.0, + 5.0, 6.0, 7.0, 8.0, + 9.0, 10.0, 11.0, 12.0}); + test.AddInput("axes", {2}, {1, 2}); + test.AddOutput("reduced", {3, 1, 1}, {1.0, 5.0, 9.0}); + std::vector> execution_providers; + execution_providers.push_back(DefaultCudaExecutionProvider()); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); +} + +TEST(ReductionOpTest, ReduceMin_half_Opset20_Cuda) { + OpTester test("ReduceMin", 20); + test.AddAttribute("keepdims", (int64_t)1); + test.AddInput("data", {3, 2, 2}, + FloatsToMLFloat16s({1.0f, 2.0f, 3.0f, 4.0f, + 5.0f, 6.0f, 7.0f, 8.0f, + 9.0f, 10.0f, 11.0f, 12.0f})); + test.AddInput("axes", {2}, {1, 2}); + test.AddOutput("reduced", {3, 1, 1}, FloatsToMLFloat16s({1.0f, 5.0f, 9.0f})); + std::vector> execution_providers; + execution_providers.push_back(DefaultCudaExecutionProvider()); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); +} + +TEST(ReductionOpTest, ReduceMin_int32_Opset20_Cuda) { + OpTester test("ReduceMin", 20); + test.AddAttribute("keepdims", (int64_t)1); + test.AddInput("data", {3, 2, 2}, + {1, 2, 3, 4, + 5, 6, 7, 8, + 9, 10, 11, 12}); + test.AddInput("axes", {2}, {1, 2}); + test.AddOutput("reduced", {3, 1, 1}, {1, 5, 9}); + std::vector> execution_providers; + execution_providers.push_back(DefaultCudaExecutionProvider()); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); +} + +TEST(ReductionOpTest, ReduceMin_int64_Opset20_Cuda) { + OpTester test("ReduceMin", 20); + test.AddAttribute("keepdims", (int64_t)1); + test.AddInput("data", {3, 2, 2}, + {1, 2, 3, 4, + 5, 6, 7, 8, + 9, 10, 11, 12}); + test.AddInput("axes", {2}, {1, 2}); + test.AddOutput("reduced", {3, 1, 1}, {1, 5, 9}); + std::vector> execution_providers; + execution_providers.push_back(DefaultCudaExecutionProvider()); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); +} + +TEST(ReductionOpTest, ReduceMin_int8_Opset20_Cuda) { + OpTester test("ReduceMin", 20); + test.AddAttribute("keepdims", (int64_t)1); + test.AddInput("data", {3, 2, 2}, + {1, 2, 3, 4, + 5, 6, 7, 8, + 9, 10, 11, 12}); + test.AddInput("axes", {2}, {1, 2}); + test.AddOutput("reduced", {3, 1, 1}, {1, 5, 9}); + std::vector> execution_providers; + execution_providers.push_back(DefaultCudaExecutionProvider()); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); +} + +TEST(ReductionOpTest, ReduceMin_uint8_Opset20_Cuda) { + OpTester test("ReduceMin", 20); + test.AddAttribute("keepdims", (int64_t)1); + test.AddInput("data", {3, 2, 2}, + {1, 2, 3, 4, + 5, 6, 7, 8, + 9, 10, 11, 12}); + test.AddInput("axes", {2}, {1, 2}); + test.AddOutput("reduced", {3, 1, 1}, {1, 5, 9}); + std::vector> execution_providers; + execution_providers.push_back(DefaultCudaExecutionProvider()); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); +} + +// Test ReduceMax at opset 20 with keepdims=0 on CUDA +TEST(ReductionOpTest, ReduceMax_float_Opset20_NoKeepdims_Cuda) { + OpTester test("ReduceMax", 20); + test.AddAttribute("keepdims", (int64_t)0); + test.AddInput("data", {3, 2, 2}, + {1.0f, 2.0f, 3.0f, 4.0f, + 5.0f, 6.0f, 7.0f, 8.0f, + 9.0f, 10.0f, 11.0f, 12.0f}); + test.AddInput("axes", {2}, {1, 2}); + test.AddOutput("reduced", {3}, {4.0f, 8.0f, 12.0f}); + std::vector> execution_providers; + execution_providers.push_back(DefaultCudaExecutionProvider()); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); +} + +// Test ReduceMin at opset 20 with keepdims=0 on CUDA +TEST(ReductionOpTest, ReduceMin_float_Opset20_NoKeepdims_Cuda) { + OpTester test("ReduceMin", 20); + test.AddAttribute("keepdims", (int64_t)0); + test.AddInput("data", {3, 2, 2}, + {1.0f, 2.0f, 3.0f, 4.0f, + 5.0f, 6.0f, 7.0f, 8.0f, + 9.0f, 10.0f, 11.0f, 12.0f}); + test.AddInput("axes", {2}, {1, 2}); + test.AddOutput("reduced", {3}, {1.0f, 5.0f, 9.0f}); + std::vector> execution_providers; + execution_providers.push_back(DefaultCudaExecutionProvider()); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); +} + +#endif // defined(USE_CUDA) + } // namespace test } // namespace onnxruntime From b8f21f1ee36917a3e61b6380b66bb41e11147613 Mon Sep 17 00:00:00 2001 From: Copilot <198982749+Copilot@users.noreply.github.com> Date: Tue, 5 May 2026 23:53:35 +0000 Subject: [PATCH 22/34] =?UTF-8?q?Fill=20RNN=20CUDA=20operator=20opset=20ga?= =?UTF-8?q?p=20(14=20=E2=86=92=2022)=20(#27743)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ### Description Extends RNN CUDA kernel registration from opset 14 to opset 22, following the standard opset gap-filling pattern: - **`rnn.cc`**: Cap existing opset 14 non-versioned kernel to versioned 14–21; add new non-versioned kernel at opset 22 - **`cuda_execution_provider.cc`**: Update forward declarations and `BuildKernelCreateInfo` entries to match (versioned 14–21 + non-versioned 22); remove duplicate GRU opset 22 entries introduced during merge - **`OperatorKernels.md`**: Update CUDA RNN entry to reflect three tiers: `[7,13]`, `[14,21]`, `22+` No behavioral changes — the operator implementation is identical across opset 14–22. This is a registration-only change. ### Motivation and Context RNN CUDA operator was registered at opset 14 while ONNX defines it through opset 22, causing models exported at newer opsets to fall back to CPU. Part of the broader CUDA EP opset gap effort tracked in #27729. --------- Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com> Co-authored-by: tianleiwu <30328909+tianleiwu@users.noreply.github.com> Co-authored-by: Tianlei Wu --- docs/OperatorKernels.md | 3 +- .../providers/cuda/cuda_execution_provider.cc | 28 +++++---- onnxruntime/core/providers/cuda/rnn/rnn.cc | 20 +++++- .../test/providers/cpu/rnn/rnn_op_test.cc | 63 +++++++++++++++++++ 4 files changed, 101 insertions(+), 13 deletions(-) diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md index 05e8b072e130e..ab2fd24fc6423 100644 --- a/docs/OperatorKernels.md +++ b/docs/OperatorKernels.md @@ -883,7 +883,8 @@ The **OpSet Version** column uses the following notation: |||[13, 18]|**T1** = tensor(float)
**T2** = tensor(int8), tensor(uint8)| |||[10, 12]|**T1** = tensor(float)
**T2** = tensor(int8), tensor(uint8)| |RMSNormalization|*in* X:**T**
*in* scale:**V**
*out* Y:**V**|23+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)
**V** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)| -|RNN|*in* X:**T**
*in* W:**T**
*in* R:**T**
*in* B:**T**
*in* sequence_lens:**T1**
*in* initial_h:**T**
*out* Y:**T**
*out* Y_h:**T**|14+|**T** = tensor(double), tensor(float), tensor(float16)
**T1** = tensor(int32)| +|RNN|*in* X:**T**
*in* W:**T**
*in* R:**T**
*in* B:**T**
*in* sequence_lens:**T1**
*in* initial_h:**T**
*out* Y:**T**
*out* Y_h:**T**|22+|**T** = tensor(double), tensor(float), tensor(float16)
**T1** = tensor(int32)| +|||[14, 21]|**T** = tensor(double), tensor(float), tensor(float16)
**T1** = tensor(int32)| |||[7, 13]|**T** = tensor(double), tensor(float), tensor(float16)
**T1** = tensor(int32)| |RandomNormal|*out* output:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)| |RandomNormalLike|*in* input:**T1**
*out* output:**T2**|1+|**T1** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T2** = tensor(double), tensor(float), tensor(float16)| diff --git a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc index 6cd906f5f1ea6..2a61300a8e556 100755 --- a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc +++ b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc @@ -1405,17 +1405,17 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, float, Div); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, double, Div); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, MLFloat16, Div); -class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, 18, Identity); -class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, 18, Reshape); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, float, RNN); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, double, RNN); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, MLFloat16, RNN); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, 21, float, GRU); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, 21, double, GRU); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, 21, MLFloat16, GRU); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, 18, Identity); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, float, LSTM); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, double, LSTM); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, MLFloat16, LSTM); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, 18, Reshape); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, 21, float, RNN); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, 21, double, RNN); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, 21, MLFloat16, RNN); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME( kCudaExecutionProvider, kOnnxDomain, 14, 14, float, BatchNormalization); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME( @@ -1691,6 +1691,9 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 22, double, HardSwish); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 22, MLFloat16, HardSwish); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 22, BFloat16, HardSwish); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 22, float, RNN); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 22, double, RNN); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 22, MLFloat16, RNN); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 22, float, RoiAlign); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 22, double, RoiAlign); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 22, MLFloat16, RoiAlign); @@ -2151,15 +2154,15 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -2648,9 +2651,6 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -2658,6 +2658,9 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, diff --git a/onnxruntime/core/providers/cuda/rnn/rnn.cc b/onnxruntime/core/providers/cuda/rnn/rnn.cc index ed8be63679707..236aa5022fa80 100644 --- a/onnxruntime/core/providers/cuda/rnn/rnn.cc +++ b/onnxruntime/core/providers/cuda/rnn/rnn.cc @@ -24,11 +24,25 @@ namespace cuda { .InputMemoryType(OrtMemTypeCPUInput, RNN_Input_Index::sequence_lens), \ RNN); +#define REGISTER_KERNEL_VERSIONED_TYPED_14(T) \ + ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \ + RNN, \ + kOnnxDomain, \ + 14, \ + 21, \ + T, \ + kCudaExecutionProvider, \ + (*KernelDefBuilder::Create()) \ + .TypeConstraint("T", DataTypeImpl::GetTensorType()) \ + .TypeConstraint("T1", DataTypeImpl::GetTensorType()) \ + .InputMemoryType(OrtMemTypeCPUInput, RNN_Input_Index::sequence_lens), \ + RNN); + #define REGISTER_KERNEL_TYPED(T) \ ONNX_OPERATOR_TYPED_KERNEL_EX( \ RNN, \ kOnnxDomain, \ - 14, \ + 22, \ T, \ kCudaExecutionProvider, \ (*KernelDefBuilder::Create()) \ @@ -41,6 +55,10 @@ REGISTER_KERNEL_VERSIONED_TYPED(float); REGISTER_KERNEL_VERSIONED_TYPED(double); REGISTER_KERNEL_VERSIONED_TYPED(MLFloat16); +REGISTER_KERNEL_VERSIONED_TYPED_14(float); +REGISTER_KERNEL_VERSIONED_TYPED_14(double); +REGISTER_KERNEL_VERSIONED_TYPED_14(MLFloat16); + REGISTER_KERNEL_TYPED(float); REGISTER_KERNEL_TYPED(double); REGISTER_KERNEL_TYPED(MLFloat16); diff --git a/onnxruntime/test/providers/cpu/rnn/rnn_op_test.cc b/onnxruntime/test/providers/cpu/rnn/rnn_op_test.cc index 0dcf4f597d9c8..49bc10935c2c8 100644 --- a/onnxruntime/test/providers/cpu/rnn/rnn_op_test.cc +++ b/onnxruntime/test/providers/cpu/rnn/rnn_op_test.cc @@ -986,6 +986,69 @@ TEST(RNNTest, RNN_forward_sequence_lens_with_zero) { test.ConfigEp(std::move(cpu)).RunWithConfig(); } +TEST(RNNTest, RNN_ForwardDefaultActivations_OpSet22_CUDA) { + auto cuda_ep = DefaultCudaExecutionProvider(); + if (!cuda_ep) { + return; + } + + // Simple forward RNN at opset 22 to verify CUDA registration. + int64_t seq_length = 2; + int batch_size = 1; + int64_t input_size = 2; + int64_t hidden_size = 3; + int num_directions = 1; + + std::vector X_data = {1.f, 2.f, 3.f, 4.f}; + std::vector X_dims = {seq_length, batch_size, input_size}; + + std::vector W_data = {0.1f, 0.2f, 0.3f, 0.4f, 0.5f, 0.6f}; + std::vector W_dims = {num_directions, hidden_size, input_size}; + + std::vector R_data(num_directions * hidden_size * hidden_size, 0.1f); + std::vector R_dims = {num_directions, hidden_size, hidden_size}; + + // Y = tanh(X * W^T + H_prev * R^T), H_prev = 0 + // time_step 0: X=[1,2], W^T cols=[0.1,0.3,0.5; 0.2,0.4,0.6] + // h0 = tanh([0.1*1+0.2*2, 0.3*1+0.4*2, 0.5*1+0.6*2]) = tanh([0.5, 1.1, 1.7]) + float h0_0 = std::tanh(0.5f); + float h0_1 = std::tanh(1.1f); + float h0_2 = std::tanh(1.7f); + + // time_step 1: X=[3,4], h_prev = h0 + // h1 = tanh(X * W^T + h0 * R^T) + // X * W^T = [0.1*3+0.2*4, 0.3*3+0.4*4, 0.5*3+0.6*4] = [1.1, 2.5, 3.9] + // h0 * R^T (R=0.1 everywhere) = [0.1*(h0_0+h0_1+h0_2), ...] (same for each) + float h0_sum = h0_0 + h0_1 + h0_2; + float h1_0 = std::tanh(1.1f + 0.1f * h0_sum); + float h1_1 = std::tanh(2.5f + 0.1f * h0_sum); + float h1_2 = std::tanh(3.9f + 0.1f * h0_sum); + + std::vector Y_data = {h0_0, h0_1, h0_2, h1_0, h1_1, h1_2}; + std::vector Y_dims = {seq_length, num_directions, batch_size, hidden_size}; + + std::vector Y_h_data = {h1_0, h1_1, h1_2}; + std::vector Y_h_dims = {num_directions, batch_size, hidden_size}; + + OpTester test("RNN", 22); + test.AddShapeToTensorData(); + + test.AddAttribute>("activations", {"Tanh"}); + test.AddAttribute("direction", string("forward")); + test.AddAttribute("hidden_size", hidden_size); + + test.AddInput("X", X_dims, X_data); + test.AddInput("W", W_dims, W_data, true); + test.AddInput("R", R_dims, R_data, true); + + test.AddOutput("Y", Y_dims, Y_data); + test.AddOutput("Y_h", Y_h_dims, Y_h_data); + + std::vector> execution_providers; + execution_providers.push_back(std::move(cuda_ep)); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); +} + // Test reverse RNN with all-zero sequence_lens and non-zero initial_h. // The bug: reverse direction with sequence_lens=0 would return initial_h instead of zero-filling. TEST(RNNTest, RNN_reverse_sequence_lens_all_zero) { From 8aec1a5e650a1630fdca4e8c73d50fd53469366c Mon Sep 17 00:00:00 2001 From: Copilot <198982749+Copilot@users.noreply.github.com> Date: Wed, 6 May 2026 01:25:50 +0000 Subject: [PATCH 23/34] Fill Reshape CUDA operator opset gap from 23 to 25 (#27742) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ### Description Extends the Reshape CUDA kernel registration from opset 23 to opset 25, following the same pattern used in #27728. - **`reshape.cc`**: Cap existing non-versioned opset 23 kernel → versioned (23, 24); add new non-versioned kernel at opset 25 - **`cuda_execution_provider.cc`**: Update forward declaration and `BuildKernelCreateInfo` for versioned (23, 24); add opset 25 entries - **`docs/OperatorKernels.md`**: Update Reshape CUDA EP entry from `23+` to `25+` and add `[23, 24]` versioned range row No functional changes to the kernel itself — the opset 25 schema is backward-compatible with opset 23. ### Motivation and Context Reshape is listed as a P1 gap in #27729 (CUDA max opset 23, ONNX latest opset 25). Models exported at opset 25 would fail to find a matching Reshape kernel on the CUDA EP. --- 🔒 GitHub Advanced Security automatically protects Copilot coding agent pull requests. You can protect all pull requests by enabling Advanced Security for your repositories. [Learn more about Advanced Security.](https://gh.io/cca-advanced-security) --------- Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com> Co-authored-by: tianleiwu <30328909+tianleiwu@users.noreply.github.com> Co-authored-by: Tianlei Wu Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com> --- docs/OperatorKernels.md | 3 ++- .../core/providers/cuda/cuda_execution_provider.cc | 6 ++++-- onnxruntime/core/providers/cuda/tensor/reshape.cc | 14 +++++++++++++- 3 files changed, 19 insertions(+), 4 deletions(-) diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md index ab2fd24fc6423..bc4602856b3bf 100644 --- a/docs/OperatorKernels.md +++ b/docs/OperatorKernels.md @@ -916,7 +916,8 @@ The **OpSet Version** column uses the following notation: |Relu|*in* X:**T**
*out* Y:**T**|14+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)| |||13|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)| |||[6, 12]|**T** = tensor(double), tensor(float), tensor(float16)| -|Reshape|*in* data:**T**
*in* shape:**tensor(int64)**
*out* reshaped:**T**

or

*in* data:**T**
*out* reshaped:**T**|23+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**shape** = tensor(int64)| +|Reshape|*in* data:**T**
*in* shape:**tensor(int64)**
*out* reshaped:**T**

or

*in* data:**T**
*out* reshaped:**T**|25+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**shape** = tensor(int64)| +|||[23, 24]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**shape** = tensor(int64)| |||[21, 22]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**shape** = tensor(int64)| |||[19, 20]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**shape** = tensor(int64)| |||[14, 18]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**shape** = tensor(int64)| diff --git a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc index 2a61300a8e556..ec44cfd3f7090 100755 --- a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc +++ b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc @@ -1760,7 +1760,7 @@ class ONNX_OPERATOR_VERSIONED_TWO_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider class ONNX_OPERATOR_VERSIONED_TWO_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, 24, Float8E4M3FN, MLFloat16, QuantizeLinear); class ONNX_OPERATOR_VERSIONED_TWO_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, 24, Float8E5M2, MLFloat16, QuantizeLinear); #endif -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, Reshape); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, 24, Reshape); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, float_float, RMSNormalization); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, double_double, RMSNormalization); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, MLFloat16_MLFloat16, RMSNormalization); @@ -1848,6 +1848,7 @@ class ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDom class ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 25, Float8E4M3FN, MLFloat16, QuantizeLinear); class ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 25, Float8E5M2, MLFloat16, QuantizeLinear); #endif +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 25, Reshape); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 25, Scan); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 25, Shape); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 25, Size); @@ -3006,7 +3007,7 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, #endif - BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -3094,6 +3095,7 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, #endif + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, diff --git a/onnxruntime/core/providers/cuda/tensor/reshape.cc b/onnxruntime/core/providers/cuda/tensor/reshape.cc index 36ee05e1e2b01..7bf3da4197ba9 100644 --- a/onnxruntime/core/providers/cuda/tensor/reshape.cc +++ b/onnxruntime/core/providers/cuda/tensor/reshape.cc @@ -89,7 +89,19 @@ std::unique_ptr FuncReshape( ONNX_OPERATOR_KERNEL_EX( Reshape, kOnnxDomain, - 23, + 25, + kCudaExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", DataTypeImpl::AllFixedSizeTensorTypesIRv9()) + .TypeConstraint("shape", DataTypeImpl::GetTensorType()) + .Alias(0, 0) + .InputMemoryType(OrtMemTypeCPUInput, 1), + Reshape); + +ONNX_OPERATOR_VERSIONED_KERNEL_EX( + Reshape, + kOnnxDomain, + 23, 24, kCudaExecutionProvider, (*KernelDefBuilder::Create()) .TypeConstraint("T", DataTypeImpl::AllFixedSizeTensorTypesIRv9()) From 3e217610dd3ad7b6db799df9ecaad568517ca7e9 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Wed, 6 May 2026 01:39:26 +0000 Subject: [PATCH 24/34] Bump brace-expansion in /js/react_native/e2e (#27894) Bumps and [brace-expansion](https://github.com/juliangruber/brace-expansion). These dependencies needed to be updated together. Updates `brace-expansion` from 1.1.11 to 1.1.13
Release notes

Sourced from brace-expansion's releases.

v1.1.12

  • pkg: publish on tag 1.x c460dbd
  • fmt ccb8ac6
  • Fix potential ReDoS Vulnerability or Inefficient Regular Expression (#65) c3c73c8

https://github.com/juliangruber/brace-expansion/compare/v1.1.11...v1.1.12

Commits

Updates `brace-expansion` from 2.0.1 to 2.0.3
Release notes

Sourced from brace-expansion's releases.

v1.1.12

  • pkg: publish on tag 1.x c460dbd
  • fmt ccb8ac6
  • Fix potential ReDoS Vulnerability or Inefficient Regular Expression (#65) c3c73c8

https://github.com/juliangruber/brace-expansion/compare/v1.1.11...v1.1.12

Commits

Dependabot will resolve any conflicts with this PR as long as you don't alter it yourself. You can also trigger a rebase manually by commenting `@dependabot rebase`. [//]: # (dependabot-automerge-start) [//]: # (dependabot-automerge-end) ---
Dependabot commands and options
You can trigger Dependabot actions by commenting on this PR: - `@dependabot rebase` will rebase this PR - `@dependabot recreate` will recreate this PR, overwriting any edits that have been made to it - `@dependabot show ignore conditions` will show all of the ignore conditions of the specified dependency - `@dependabot ignore this major version` will close this PR and stop Dependabot creating any more for this major version (unless you reopen the PR or upgrade to it yourself) - `@dependabot ignore this minor version` will close this PR and stop Dependabot creating any more for this minor version (unless you reopen the PR or upgrade to it yourself) - `@dependabot ignore this dependency` will close this PR and stop Dependabot creating any more for this dependency (unless you reopen the PR or upgrade to it yourself) You can disable automated security fix PRs for this repo from the [Security Alerts page](https://github.com/microsoft/onnxruntime/network/alerts).
Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- js/react_native/e2e/package-lock.json | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/js/react_native/e2e/package-lock.json b/js/react_native/e2e/package-lock.json index 73d6e2a65f274..907e9cf72b59c 100644 --- a/js/react_native/e2e/package-lock.json +++ b/js/react_native/e2e/package-lock.json @@ -4757,7 +4757,9 @@ "license": "MIT" }, "node_modules/brace-expansion": { - "version": "1.1.11", + "version": "1.1.13", + "resolved": "https://registry.npmjs.org/brace-expansion/-/brace-expansion-1.1.13.tgz", + "integrity": "sha512-9ZLprWS6EENmhEOpjCYW2c8VkmOvckIJZfkr7rBW6dObmfgJ/L1GpSYW5Hpo9lDz4D1+n0Ckz8rU7FwHDQiG/w==", "license": "MIT", "dependencies": { "balanced-match": "^1.0.0", @@ -5763,7 +5765,9 @@ } }, "node_modules/detox/node_modules/brace-expansion": { - "version": "2.0.1", + "version": "2.0.3", + "resolved": "https://registry.npmjs.org/brace-expansion/-/brace-expansion-2.0.3.tgz", + "integrity": "sha512-MCV/fYJEbqx68aE58kv2cA/kiky1G8vux3OR6/jbS+jIMe/6fJWa0DTzJU7dqijOWYwHi1t29FlfYI9uytqlpA==", "dev": true, "license": "MIT", "dependencies": { From 673c3320fca00be7d6029d2a32299059388e034b Mon Sep 17 00:00:00 2001 From: Copilot <198982749+Copilot@users.noreply.github.com> Date: Wed, 6 May 2026 06:19:34 +0000 Subject: [PATCH 25/34] Fill CUDA EP opset gaps for Round and Equal operators (#27754) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ### Description Caps existing non-versioned CUDA kernel registrations and adds new registrations at the latest ONNX opset: - **Round**: opset 11 (non-versioned) → versioned 11–21 + new opset 22 - **Equal**: opset 13 (non-versioned) → versioned 13–18 + new opset 19 Changes across three files: - `unary_elementwise_ops.cc` — `UNARY_OP_HFD(Round, 11)` → `UNARY_OP_VERSIONED_HFD` + `UNARY_OP_HFD` - `binary_elementwise_ops.cc` — `BINARY_LOGICALOP_REGISTER_UZILHFD(Equal, 13)` → versioned 13–18 + new 19 (same for `bool` typed registration) - `cuda_execution_provider.cc` — corresponding forward declarations and `BuildKernelCreateInfo` entries No type changes; both operators retain their existing CUDA type support at the new opsets. ### Motivation and Context Tracks with the ongoing effort to close ONNX opset coverage gaps in the CUDA execution provider (https://github.com/microsoft/onnxruntime/issues/27729). Without these registrations, models targeting opset 19+ (Equal) or 22+ (Round) fall back from CUDA to CPU. --------- Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com> Co-authored-by: tianleiwu <30328909+tianleiwu@users.noreply.github.com> Co-authored-by: Tianlei Wu --- .../providers/cuda/cuda_execution_provider.cc | 66 ++++++---- .../cuda/math/binary_elementwise_ops.cc | 6 +- .../cuda/math/unary_elementwise_ops.cc | 3 +- .../cpu/math/element_wise_ops_test.cc | 121 ++++++++++++++++++ .../test/providers/cpu/math/round_test.cc | 51 +++++++- 5 files changed, 221 insertions(+), 26 deletions(-) diff --git a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc index ec44cfd3f7090..bc0a250a90493 100755 --- a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc +++ b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc @@ -1073,9 +1073,9 @@ class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kO class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 12, float, Equal); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 12, double, Equal); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 12, MLFloat16, Equal); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, float, Round); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, double, Round); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, MLFloat16, Round); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 21, float, Round); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 21, double, Round); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 21, MLFloat16, Round); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 10, 12, int8_t, QuantizeLinear); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 10, 12, uint8_t, QuantizeLinear); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 10, 12, int8_t, DequantizeLinear); @@ -1192,14 +1192,14 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, E class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, Sum); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, Max); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, Min); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, bool, Equal); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, int32_t, Equal); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, int64_t, Equal); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, uint32_t, Equal); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, uint64_t, Equal); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, float, Equal); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, double, Equal); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, MLFloat16, Equal); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 18, bool, Equal); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 18, int32_t, Equal); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 18, int64_t, Equal); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 18, uint32_t, Equal); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 18, uint64_t, Equal); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 18, float, Equal); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 18, double, Equal); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 18, MLFloat16, Equal); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, int32_t, Greater); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, int64_t, Greater); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, uint32_t, Greater); @@ -1572,6 +1572,14 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, uint8_t, Resize); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, 20, Scan); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, 20, Shape); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, bool, Equal); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, int32_t, Equal); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, int64_t, Equal); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, uint32_t, Equal); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, uint64_t, Equal); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, float, Equal); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, double, Equal); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, MLFloat16, Equal); // Opset 20 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 20, float, Gelu); @@ -1698,6 +1706,9 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 22, double, RoiAlign); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 22, MLFloat16, RoiAlign); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 22, BFloat16, RoiAlign); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 22, float, Round); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 22, double, Round); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 22, MLFloat16, Round); // Opset 23. class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, 23, float, Attention); @@ -2324,9 +2335,9 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -2438,14 +2449,14 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -2819,6 +2830,14 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, // Opset 20 BuildKernelCreateInfo, @@ -2945,6 +2964,9 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, // Opset 23 BuildKernelCreateInfo, diff --git a/onnxruntime/core/providers/cuda/math/binary_elementwise_ops.cc b/onnxruntime/core/providers/cuda/math/binary_elementwise_ops.cc index e4faa50d7acbc..babbb4b3ba672 100644 --- a/onnxruntime/core/providers/cuda/math/binary_elementwise_ops.cc +++ b/onnxruntime/core/providers/cuda/math/binary_elementwise_ops.cc @@ -579,8 +579,10 @@ Status LessOrEqual::ComputeInternal(OpKernelContext* context) const { return this->CompareMethod(context, &ImplT2_LessOrEqual); } -BINARY_LOGICALOP_REGISTER_UZILHFD(Equal, 13) -BINARY_ELEMENTWISE_LOGICALOP_REGISTER_KERNEL_TYPED(Equal, 13, bool) +BINARY_LOGICALOP_REGISTER_VERSIONED_UZILHFD(Equal, 13, 18) +BINARY_ELEMENTWISE_LOGICALOP_REGISTER_KERNEL_VERSIONED_TYPED(Equal, 13, 18, bool) +BINARY_LOGICALOP_REGISTER_UZILHFD(Equal, 19) +BINARY_ELEMENTWISE_LOGICALOP_REGISTER_KERNEL_TYPED(Equal, 19, bool) BINARY_OP_REGISTER_VERSIONED_UZILHFD(Equal, 11, 12) BINARY_ELEMENTWISE_REGISTER_KERNEL_VERSIONED_TYPED(Equal, 11, 12, bool) BINARY_OP_REGISTER_VERSIONED_OIL(Equal, 7, 10) diff --git a/onnxruntime/core/providers/cuda/math/unary_elementwise_ops.cc b/onnxruntime/core/providers/cuda/math/unary_elementwise_ops.cc index 86a1b0f5b6102..a54b96da6c174 100644 --- a/onnxruntime/core/providers/cuda/math/unary_elementwise_ops.cc +++ b/onnxruntime/core/providers/cuda/math/unary_elementwise_ops.cc @@ -249,7 +249,8 @@ UNARY_OP_HFDX(Erf, 13) UNARY_OP_BWUZCSILHFDX(Sign, 13) UNARY_LOGICALOP_NOT_TYPED(1, bool) -UNARY_OP_HFD(Round, 11) +UNARY_OP_VERSIONED_HFD(Round, 11, 21) +UNARY_OP_HFD(Round, 22) UNARY_OP_HFD(Cos, 7) UNARY_OP_HFD(Sin, 7) diff --git a/onnxruntime/test/providers/cpu/math/element_wise_ops_test.cc b/onnxruntime/test/providers/cpu/math/element_wise_ops_test.cc index 283f20a4be9b0..11a4b373c53f1 100644 --- a/onnxruntime/test/providers/cpu/math/element_wise_ops_test.cc +++ b/onnxruntime/test/providers/cpu/math/element_wise_ops_test.cc @@ -3674,6 +3674,127 @@ TEST(MathOpTest, Equal_string) { test.Run(); } +#ifdef USE_CUDA +// Opset 19 tests for numeric types (CUDA EP) +TEST(MathOpTest, Equal_19_bool) { + auto cuda_ep = DefaultCudaExecutionProvider(); + if (!cuda_ep) { + return; + } + + OpTester test("Equal", 19); + std::vector dims{4}; + test.AddInput("A", dims, {false, true, false, true}); + test.AddInput("B", dims, {false, false, true, true}); + test.AddOutput("C", dims, {true, false, false, true}); + + std::vector> execution_providers; + execution_providers.push_back(std::move(cuda_ep)); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); +} + +TEST(MathOpTest, Equal_19_int32) { + auto cuda_ep = DefaultCudaExecutionProvider(); + if (!cuda_ep) { + return; + } + + OpTester test("Equal", 19); + std::vector dims{4}; + test.AddInput("A", dims, {1, 0, -1, -1}); + test.AddInput("B", dims, {1, 1, 2, -1}); + test.AddOutput("C", dims, {true, false, false, true}); + + std::vector> execution_providers; + execution_providers.push_back(std::move(cuda_ep)); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); +} + +TEST(MathOpTest, Equal_19_int64) { + auto cuda_ep = DefaultCudaExecutionProvider(); + if (!cuda_ep) { + return; + } + + OpTester test("Equal", 19); + std::vector dims{4}; + test.AddInput("A", dims, {1, 0, -1, -1}); + test.AddInput("B", dims, {1, 1, 2, -1}); + test.AddOutput("C", dims, {true, false, false, true}); + + std::vector> execution_providers; + execution_providers.push_back(std::move(cuda_ep)); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); +} + +TEST(MathOpTest, Equal_19_float) { + auto cuda_ep = DefaultCudaExecutionProvider(); + if (!cuda_ep) { + return; + } + + OpTester test("Equal", 19); + std::vector dims{4}; + test.AddInput("A", dims, {1.0f, 0.0f, -1.0f, -1.0f}); + test.AddInput("B", dims, {1.0f, 1.0f, 2.0f, -1.0f}); + test.AddOutput("C", dims, {true, false, false, true}); + + std::vector> execution_providers; + execution_providers.push_back(std::move(cuda_ep)); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); +} + +TEST(MathOpTest, Equal_19_double) { + auto cuda_ep = DefaultCudaExecutionProvider(); + if (!cuda_ep) { + return; + } + + OpTester test("Equal", 19); + std::vector dims{4}; + test.AddInput("A", dims, {1.0, 0.0, -1.0, -1.0}); + test.AddInput("B", dims, {1.0, 1.0, 2.0, -1.0}); + test.AddOutput("C", dims, {true, false, false, true}); + + std::vector> execution_providers; + execution_providers.push_back(std::move(cuda_ep)); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); +} + +TEST(MathOpTest, Equal_19_float16) { + auto cuda_ep = DefaultCudaExecutionProvider(); + if (!cuda_ep) { + return; + } + + OpTester test("Equal", 19); + std::vector dims{4}; + test.AddInput("A", dims, {MLFloat16(1.0f), MLFloat16(0.0f), MLFloat16(-1.0f), MLFloat16(-1.0f)}); + test.AddInput("B", dims, {MLFloat16(1.0f), MLFloat16(1.0f), MLFloat16(2.0f), MLFloat16(-1.0f)}); + test.AddOutput("C", dims, {true, false, false, true}); + + std::vector> execution_providers; + execution_providers.push_back(std::move(cuda_ep)); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); +} + +TEST(MathOpTest, Equal_19_broadcastAB) { + auto cuda_ep = DefaultCudaExecutionProvider(); + if (!cuda_ep) { + return; + } + + OpTester test("Equal", 19); + test.AddInput("A", {4, 2}, {1, 0, -1, -1, 1, 1, -1, 0}); + test.AddInput("B", {2}, {1, 1}); + test.AddOutput("C", {4, 2}, {true, false, false, false, true, true, false, false}); + + std::vector> execution_providers; + execution_providers.push_back(std::move(cuda_ep)); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); +} +#endif + #if defined(USE_DNNL) TEST(MathOpTest, Equal_bfloat16) { #ifdef USE_DNNL diff --git a/onnxruntime/test/providers/cpu/math/round_test.cc b/onnxruntime/test/providers/cpu/math/round_test.cc index 5df14ac079a63..48f96fe4f8494 100644 --- a/onnxruntime/test/providers/cpu/math/round_test.cc +++ b/onnxruntime/test/providers/cpu/math/round_test.cc @@ -3,6 +3,7 @@ #include "gtest/gtest.h" #include "test/providers/provider_test_utils.h" +#include "test/util/include/default_providers.h" #include "core/framework/data_types.h" #include "core/util/math.h" @@ -30,5 +31,53 @@ TEST(RoundTest, SimpleTestFloat16) { test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); } +#ifdef USE_CUDA +// Opset 22 tests +TEST(RoundTest, Round22_Float) { + auto cuda_ep = DefaultCudaExecutionProvider(); + if (!cuda_ep) { + return; + } + + OpTester test("Round", 22, onnxruntime::kOnnxDomain); + test.AddInput("x", {5}, {0.9f, 2.5f, 2.3f, 1.5f, -4.5f}); + test.AddOutput("y", {5}, {1.0f, 2.0f, 2.0f, 2.0f, -4.0f}); + + std::vector> execution_providers; + execution_providers.push_back(std::move(cuda_ep)); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); +} + +TEST(RoundTest, Round22_Double) { + auto cuda_ep = DefaultCudaExecutionProvider(); + if (!cuda_ep) { + return; + } + + OpTester test("Round", 22, onnxruntime::kOnnxDomain); + test.AddInput("x", {5}, {0.9, 2.5, 2.3, 1.5, -4.5}); + test.AddOutput("y", {5}, {1.0, 2.0, 2.0, 2.0, -4.0}); + + std::vector> execution_providers; + execution_providers.push_back(std::move(cuda_ep)); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); +} + +TEST(RoundTest, Round22_Float16) { + auto cuda_ep = DefaultCudaExecutionProvider(); + if (!cuda_ep) { + return; + } + + OpTester test("Round", 22, onnxruntime::kOnnxDomain); + test.AddInput("x", {5}, {MLFloat16(0.9f), MLFloat16(2.5f), MLFloat16(2.3f), MLFloat16(1.5f), MLFloat16(-4.5f)}); + test.AddOutput("y", {5}, {MLFloat16(1.0f), MLFloat16(2.0f), MLFloat16(2.0f), MLFloat16(2.0f), MLFloat16(-4.0f)}); + + std::vector> execution_providers; + execution_providers.push_back(std::move(cuda_ep)); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); +} +#endif + } // namespace test -} // namespace onnxruntime \ No newline at end of file +} // namespace onnxruntime From 3b007a6873c33591d6f527c78928852a0df609b9 Mon Sep 17 00:00:00 2001 From: Jiajia Qin Date: Thu, 7 May 2026 00:43:03 +0800 Subject: [PATCH 26/34] webgpu: Support QKV bias in FlashAttention for MultiHeadAttention (#28380) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## Summary - Remove the `bias == nullptr` requirement from `CanApplyFlashAttention`, enabling FlashAttention for MultiHeadAttention nodes with QKV bias (e.g., whisper decoder). - Apply `TransferBSDToBNSH` to add bias and transpose Q/K/V to BNSH format before calling FlashAttention. - Handle cross-attention (only Q needs bias+transpose, K/V already BNSH from encoder) and self-attention (all Q/K/V need bias+transpose) separately. ## Motivation Whisper decoder's MultiHeadAttention nodes all have QKV bias, which previously forced them into the slower unfused attention path. Enabling FlashAttention for these nodes yields ~45% speedup on whisper-tiny-int4 (~92 → ~134 tokens/s). ## Test plan - [x] Existing MHA unit tests with bias data now exercise the FlashAttention path on WebGPU with Subgroups support - [x] whisper-tiny-int4 end-to-end: correct transcription at ~134 tps (vs ~92 tps baseline) - [x] clang-format passes - [x] D3D12 build succeeds --- .../contrib_ops/webgpu/bert/attention.cc | 2 +- .../webgpu/bert/flash_attention.cc | 4 +-- .../contrib_ops/webgpu/bert/flash_attention.h | 3 +- .../webgpu/bert/group_query_attention.cc | 2 +- .../webgpu/bert/multihead_attention.cc | 35 ++++++++++++++++++- 5 files changed, 38 insertions(+), 8 deletions(-) diff --git a/onnxruntime/contrib_ops/webgpu/bert/attention.cc b/onnxruntime/contrib_ops/webgpu/bert/attention.cc index 755bd0c60452f..95a5c7c17bc3a 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/attention.cc +++ b/onnxruntime/contrib_ops/webgpu/bert/attention.cc @@ -726,7 +726,7 @@ Status Attention::ComputeInternal(onnxruntime::webgpu::ComputeContext& context) parameters.qkv_format_ = Q_K_V_BSNH; // Check if we can use flash attention - if (CanApplyFlashAttention(nullptr, parameters, context)) { + if (CanApplyFlashAttention(parameters, context)) { // FlashAttention supports Q_K_V_BSNH format directly return ApplyFlashAttention(&Q_bsd, &K_bsd, &V_bsd, attention_bias, output, nullptr, nullptr, nullptr, nullptr, parameters, context, nullptr); diff --git a/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc b/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc index 58c7376895661..c288a82994e98 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc +++ b/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc @@ -574,11 +574,9 @@ Status ApplyFlashAttention(const Tensor* Q, const Tensor* K, const Tensor* V, co return Status::OK(); } -bool CanApplyFlashAttention(const Tensor* bias, - const WebgpuAttentionParameters& parameters, onnxruntime::webgpu::ComputeContext& context) { +bool CanApplyFlashAttention(const WebgpuAttentionParameters& parameters, onnxruntime::webgpu::ComputeContext& context) { return !parameters.is_packed_qkv_ && parameters.head_size_ == parameters.v_head_size_ && - bias == nullptr && context.HasFeature(wgpu::FeatureName::Subgroups) && ((context.AdapterInfo().vendor == std::string_view{"qualcomm"} && parameters.head_size_ % 8 == 0) || parameters.head_size_ % 4 == 0); } diff --git a/onnxruntime/contrib_ops/webgpu/bert/flash_attention.h b/onnxruntime/contrib_ops/webgpu/bert/flash_attention.h index fc2843f6ea908..980ddc3a5373b 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/flash_attention.h +++ b/onnxruntime/contrib_ops/webgpu/bert/flash_attention.h @@ -191,8 +191,7 @@ Status ApplyFlashAttention(const Tensor* Q, const Tensor* K, const Tensor* V, co const WebgpuAttentionParameters& parameters, onnxruntime::webgpu::ComputeContext& context, const Tensor* seqlen_k = nullptr, const Tensor* cos_cache = nullptr, const Tensor* sin_cache = nullptr, const Tensor* head_sink = nullptr); -bool CanApplyFlashAttention(const Tensor* bias, - const WebgpuAttentionParameters& parameters, onnxruntime::webgpu::ComputeContext& context); +bool CanApplyFlashAttention(const WebgpuAttentionParameters& parameters, onnxruntime::webgpu::ComputeContext& context); // Split packed QKV with Q/K rotary embedding and copy KV cache fusion Status RunSplitPackedQKVWithRotaryEmbeddingAndCopyKV(onnxruntime::webgpu::ComputeContext& context, diff --git a/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc b/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc index fd72f751ee810..cdf88c2f225e8 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc +++ b/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc @@ -257,7 +257,7 @@ Status GroupQueryAttention::ComputeInternal(onnxruntime::webgpu::ComputeContext& // Create a temporary parameters copy with is_packed_qkv_ set to false to check if flash attention can be applied after unpacking WebgpuAttentionParameters temp_params = parameters; temp_params.is_packed_qkv_ = false; - will_use_flash_attention = CanApplyFlashAttention(nullptr, temp_params, context); + will_use_flash_attention = CanApplyFlashAttention(temp_params, context); } if (parameters.is_packed_qkv_ && do_rotary_) { diff --git a/onnxruntime/contrib_ops/webgpu/bert/multihead_attention.cc b/onnxruntime/contrib_ops/webgpu/bert/multihead_attention.cc index ed43e9b3653b0..2890afae02ab9 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/multihead_attention.cc +++ b/onnxruntime/contrib_ops/webgpu/bert/multihead_attention.cc @@ -104,7 +104,40 @@ Status MultiHeadAttention::ComputeInternal(onnxruntime::webgpu::ComputeContext& Tensor* output_qk = context.Output(3, output_qk_shape); if (output_qk == nullptr && // Flash attention does not output QK scores - CanApplyFlashAttention(bias, parameters, context)) { + CanApplyFlashAttention(parameters, context)) { + if (bias != nullptr) { + // Apply bias and transpose Q from BSD to BNSH before FlashAttention + TensorShapeVector q_dims({parameters.batch_size_, parameters.num_heads_, + parameters.sequence_length_, parameters.head_size_}); + Tensor Q = context.CreateGPUTensor(query->DataType(), TensorShape(q_dims)); + ORT_RETURN_IF_ERROR(TransferBSDToBNSH( + context, parameters.num_heads_, parameters.sequence_length_, parameters.head_size_, query, bias, 0, &Q)); + + WebgpuAttentionParameters params_bnsh(parameters); + if (parameters.qkv_format_ == Q_K_V_BSNH_BNSH_BNSH) { + // Cross-attention: K/V are already BNSH, only Q needs bias+transpose + params_bnsh.qkv_format_ = Q_K_V_BNSH; + return ApplyFlashAttention(&Q, key, value, attention_bias, output, past_key, present_key, past_value, + present_value, params_bnsh, context); + } + + // Self-attention: K/V also need bias+transpose + TensorShapeVector k_dims({parameters.batch_size_, parameters.num_heads_, + parameters.kv_sequence_length_, parameters.head_size_}); + Tensor K = context.CreateGPUTensor(key->DataType(), TensorShape(k_dims)); + ORT_RETURN_IF_ERROR(TransferBSDToBNSH(context, parameters.num_heads_, parameters.kv_sequence_length_, + parameters.head_size_, key, bias, parameters.hidden_size_, &K)); + + TensorShapeVector v_dims({parameters.batch_size_, parameters.num_heads_, + parameters.kv_sequence_length_, parameters.v_head_size_}); + Tensor V = context.CreateGPUTensor(value->DataType(), TensorShape(v_dims)); + ORT_RETURN_IF_ERROR(TransferBSDToBNSH(context, parameters.num_heads_, parameters.kv_sequence_length_, + parameters.v_head_size_, value, bias, 2 * parameters.hidden_size_, &V)); + + params_bnsh.qkv_format_ = Q_K_V_BNSH; + return ApplyFlashAttention(&Q, &K, &V, attention_bias, output, past_key, present_key, past_value, + present_value, params_bnsh, context); + } return ApplyFlashAttention(query, key, value, attention_bias, output, past_key, present_key, past_value, present_value, parameters, context); } From dbc55dbb203f4c6e4d3b3c9d3193c38e6f262cbf Mon Sep 17 00:00:00 2001 From: vraspar Date: Wed, 6 May 2026 10:42:32 -0700 Subject: [PATCH 27/34] Prevent double-free in OrtModelEditorApi ownership transfer (#28123) ### Description The OrtModelEditorApi C API functions (AddNodeToGraph, AddGraphToModel, SetGraphInputs/SetGraphOutputs) take raw pointers and wrap them in unique_ptr to transfer ownership. Without guards, callers can pass the same pointer twice or call Release after ownership transfer, causing double-free on destruction. ### Changes - **AddInitializerToGraph**: Copy OrtValue internally instead of taking raw pointer ownership. OrtValue uses shared_ptr for its data, so copying is cheap (refcount increment). The caller retains ownership and is responsible for releasing. This eliminates the double-free class entirely for initializers. - **AddNodeToGraph**: Add \owned_\ flag to ModelEditorNode to reject double-add, add null check - **AddGraphToModel**: Reject if model already has a graph, add null check for model. Add \owned_\ flag to ModelEditorGraph to reject same graph added to two models. - **SetGraphInputs/SetGraphOutputs**: Add \owned_\ flag to ModelEditorValueInfo to reject already-owned ValueInfos. Detect duplicate pointers in input arrays. Pre-allocate vector capacity before ownership-transfer loop for exception safety. - **ReleaseNode/ReleaseGraph/ReleaseValueInfo**: Check \owned_\ flag before deleting. If already owned by a graph/model, the release is a safe no-op. - **C++ wrapper**: Remove initializer.release() in AddInitializer to match copy semantics. - **Regression tests**: Tests covering ownership-transfer guard paths, release-after-ownership, and duplicate detection. --------- Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- .../core/session/onnxruntime_c_api.h | 53 +-- .../core/session/onnxruntime_cxx_api.h | 2 +- .../core/session/onnxruntime_cxx_inline.h | 6 +- onnxruntime/core/graph/graph.cc | 6 +- .../core/graph/model_editor_api_types.h | 7 +- onnxruntime/core/session/model_editor_api.h | 3 +- .../core/session/model_editor_c_api.cc | 141 +++++++- onnxruntime/core/session/onnxruntime_c_api.cc | 21 ++ .../test/shared_lib/test_model_builder_api.cc | 325 +++++++++++++++++- 9 files changed, 518 insertions(+), 46 deletions(-) diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h index c728428348b53..e400180931b04 100644 --- a/include/onnxruntime/core/session/onnxruntime_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_c_api.h @@ -7696,7 +7696,7 @@ struct OrtModelEditorApi { /** \brief Set the inputs for the OrtGraph. * * Set the graph inputs. This will replace any existing inputs with the new values. - * The OrtGraph takes ownership of the OrtValueInfo instances and you should NOT call ReleaseOrtValueInfo. + * The OrtGraph takes ownership of the OrtValueInfo instances and you should NOT call OrtApi::ReleaseValueInfo. * * \param[in] graph The OrtGraph instance to update. * \param[in] inputs The input OrtValueInfo instances. @@ -7712,7 +7712,7 @@ struct OrtModelEditorApi { /** \brief Set the outputs for the OrtGraph. * * Set the graph outputs. This will replace any existing outputs with the new values. - * The OrtGraph takes ownership of the OrtValueInfo instances provided and you should NOT call ReleaseOrtValueInfo. + * The OrtGraph takes ownership of the OrtValueInfo instances provided and you should NOT call OrtApi::ReleaseValueInfo. * * \param[in] graph The OrtGraph instance to update. * \param[in] outputs The output OrtValueInfo instances. @@ -7727,27 +7727,30 @@ struct OrtModelEditorApi { /** \brief Add an initializer to the OrtGraph * - * ORT will take ownership of the OrtValue and you should NOT call ReleaseOrtValue. + * ORT will copy the OrtValue wrapper internally. The caller retains ownership of the OrtValue and should + * release it with OrtApi::ReleaseValue when done. Note that the underlying data buffer is not copied. + * If the OrtValue was created with a user-provided buffer (e.g., OrtApi::CreateTensorWithDataAsOrtValue), + * that buffer must remain valid for the duration of the inference session. * * Two options: * * Allocated memory: - * Use CreateTensorAsOrtValue (allocates memory) and populate the tensor with the data. + * Use OrtApi::CreateTensorAsOrtValue (allocates memory) and populate the tensor with the data. * Set `data_is_external` to false. * * Pre-existing memory: - * Use CreateTensorWithDataAsOrtValue or CreateTensorWithDataAndDeleterAsOrtValue to create an OrtValue + * Use OrtApi::CreateTensorWithDataAsOrtValue or OrtApi::CreateTensorWithDataAndDeleterAsOrtValue to create an OrtValue * with a tensor that contains a pointer to the existing data. * Set `data_is_external` to true. * * The pointer must remain valid for the duration of the inference session. - * If using CreateTensorWithDataAsOrtValue you are responsible for freeing the memory after the inference session + * If using OrtApi::CreateTensorWithDataAsOrtValue you are responsible for freeing the memory after the inference session * is released. - * If using CreateTensorWithDataAndDeleterAsOrtValue, ORT will free the memory using the provided deleter as + * If using OrtApi::CreateTensorWithDataAndDeleterAsOrtValue, ORT will free the memory using the provided deleter as * soon as the OrtValue is no longer in use. * * NOTE: A tensor containing pre-existing memory MUST have 128 bytes of data or more. - * For smaller tensors use CreateTensorAsOrtValue. + * For smaller tensors use OrtApi::CreateTensorAsOrtValue. * * ONNX shape inferencing does not support external data. An initializer involved in shape inferencing is * typically small (a single value or limited by the rank of a tensor) and uses less than 128 bytes of @@ -7756,19 +7759,19 @@ struct OrtModelEditorApi { * * \param[in] graph The OrtGraph instance to update. * \param[in] name The value name for the initializer. - * \param[in] tensor The OrtValue instance containing the tensor data. - * \param[in] data_is_external Set to true if the data is external and should not be copied. + * \param[in] ort_value The OrtValue instance containing the tensor data. + * \param[in] data_is_external Set to true if the data is external and should not be serialized into the model. * * \snippet{doc} snippets.dox OrtStatus Return Value * * \since Version 1.22. */ - ORT_API2_STATUS(AddInitializerToGraph, _Inout_ OrtGraph* graph, _In_ const char* name, _In_ OrtValue* tensor, - bool data_is_external); + ORT_API2_STATUS(AddInitializerToGraph, _Inout_ OrtGraph* graph, _In_ const char* name, + _In_ const OrtValue* ort_value, bool data_is_external); /** \brief Add an OrtNode to an OrtGraph * - * Add the node to the graph. The OrtGraph will take ownership of OrtNode and you should NOT call ReleaseOrtNode. + * Add the node to the graph. The OrtGraph will take ownership of OrtNode and you should NOT call OrtApi::ReleaseNode. * * \param[in] graph The OrtGraph instance to update. * \param[in] node The OrtNode instance to add to the graph. @@ -7807,7 +7810,7 @@ struct OrtModelEditorApi { * * Add the graph to a model. This should be called once when creating a new model. * - * The OrtModel takes ownership of the OrtGraph and you should NOT call ReleaseOrtGraph. + * The OrtModel takes ownership of the OrtGraph and you should NOT call OrtApi::ReleaseGraph. * * \param[in] model The OrtModel instance to update. * \param[in] graph The OrtGraph instance to add to the model. @@ -7825,7 +7828,7 @@ struct OrtModelEditorApi { * and SetGraphOutputs must have been called. * This will validate the model, run optimizers, and prepare the session for inferencing. * - * ReleaseOrtModel must be called to free the OrtModel after session creation. + * OrtApi::ReleaseModel must be called to free the OrtModel after session creation. * * \param[in] env The OrtEnv instance. * \param[in] model The OrtModel instance. @@ -7845,13 +7848,13 @@ struct OrtModelEditorApi { * Nodes can be added before or after the existing nodes in the model. ONNX Runtime will connect the nodes when the * model is finalized. * - * To add nodes and initializers to the existing model, first create an OrtModel using CreateModel. - * Add nodes and initializers to the OrtModel using AddNodeToGraph and AddInitializerToGraph. - * Graph inputs/outputs should be updated with SetGraphInputs and SetGraphOutputs as needed to reflect changes made + * To add nodes and initializers to the existing model, first create an OrtModel using CreateModel(). + * Add nodes and initializers to the OrtModel using AddNodeToGraph() and AddInitializerToGraph(). + * Graph inputs/outputs should be updated with SetGraphInputs() and SetGraphOutputs() as needed to reflect changes made * by the new nodes. The list of graph inputs/outputs should be for the overall model and not just the new nodes. * - * Add the new information from the OrtModel to the original model using ApplyModelToSession, and prepare the - * session for inferencing by calling FinalizeModelEditorSession. + * Add the new information from the OrtModel to the original model using ApplyModelToSession(), and prepare the + * session for inferencing by calling FinalizeModelEditorSession(). * * \param{in} env The OrtEnv instance. * \param{in} model_path The path to the existing ONNX model to augment. @@ -7871,13 +7874,13 @@ struct OrtModelEditorApi { * Nodes can be added before or after the existing nodes in the model. ONNX Runtime will connect the nodes when the * model is finalized. * - * To add nodes and initializers to the existing model, first create an OrtModel using CreateModel. - * Add nodes and initializers to the OrtModel using AddNodeToGraph and AddInitializerToGraph. - * Graph inputs/outputs should be updated with SetGraphInputs and SetGraphOutputs as needed to reflect changes made + * To add nodes and initializers to the existing model, first create an OrtModel using CreateModel(). + * Add nodes and initializers to the OrtModel using AddNodeToGraph() and AddInitializerToGraph(). + * Graph inputs/outputs should be updated with SetGraphInputs() and SetGraphOutputs() as needed to reflect changes made * by the new nodes. The list of graph inputs/outputs should be for the overall model and not just the new nodes. * - * Add the new information from the OrtModel to the original model using ApplyModelToSession, and prepare the - * session for inferencing by calling FinalizeModelEditorSession. + * Add the new information from the OrtModel to the original model using ApplyModelToSession(), and prepare the + * session for inferencing by calling FinalizeModelEditorSession(). * * \param{in} env The OrtEnv instance. * \param{in} model_data The model data for the existing model to augment. diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_api.h b/include/onnxruntime/core/session/onnxruntime_cxx_api.h index a19793f4c67d2..f6845914ed77e 100644 --- a/include/onnxruntime/core/session/onnxruntime_cxx_api.h +++ b/include/onnxruntime/core/session/onnxruntime_cxx_api.h @@ -3531,7 +3531,7 @@ struct GraphImpl : ConstGraphImpl { // & outputs); // ::SetOutputs(std::vector& outputs) { } template -inline void GraphImpl::AddInitializer(const std::string& name, Value& initializer, bool data_is_external) { - // Graph takes ownership of `initializer` - // On error the ownership is not transferred. +inline void GraphImpl::AddInitializer(const std::string& name, const Value& initializer, bool data_is_external) { + // Graph copies the OrtValue internally. Caller retains ownership of initializer. ThrowOnError(GetModelEditorApi().AddInitializerToGraph(this->p_, name.c_str(), initializer, data_is_external)); - initializer.release(); } template diff --git a/onnxruntime/core/graph/graph.cc b/onnxruntime/core/graph/graph.cc index 7da1c6936ff31..4487ccf62c0a2 100644 --- a/onnxruntime/core/graph/graph.cc +++ b/onnxruntime/core/graph/graph.cc @@ -6788,12 +6788,12 @@ Status Graph::LoadFromModelEditorApiModel(const OrtGraph& api_graph, bool updati } }; - auto add_initializers = [this](const std::unordered_map>& initializers, + auto add_initializers = [this](const std::unordered_map& initializers, bool is_external) { for (auto& name_and_ortvalue : initializers) { // convert from OrtValue to TensorProto const std::string& name = name_and_ortvalue.first; - OrtValue& v = *name_and_ortvalue.second; + const OrtValue& v = name_and_ortvalue.second; ORT_ENFORCE(v.IsTensor(), "Initializers must be Tensors"); const Tensor& t = v.Get(); @@ -6814,7 +6814,7 @@ Status Graph::LoadFromModelEditorApiModel(const OrtGraph& api_graph, bool updati offset, t.SizeInBytes(), tensor_proto); // add OrtValue to ortvalue_initializers_ to keep it alive and to store the deleter if provided. - ortvalue_initializers_.emplace(name, std::move(v)); + ortvalue_initializers_.emplace(name, v); } else { onnxruntime::utils::SetRawDataInTensorProto(tensor_proto, t.DataRaw(), t.SizeInBytes()); } diff --git a/onnxruntime/core/graph/model_editor_api_types.h b/onnxruntime/core/graph/model_editor_api_types.h index 2c0f6d6174303..6fbd687545ab2 100644 --- a/onnxruntime/core/graph/model_editor_api_types.h +++ b/onnxruntime/core/graph/model_editor_api_types.h @@ -81,6 +81,7 @@ struct ModelEditorValueInfo : public OrtValueInfo { "OrtModelEditorApi does not support querying if a OrtValueInfo is defined in an outer scope."); } + bool owned_ = false; // true after ownership transferred to a graph std::string name; std::unique_ptr type_info; }; @@ -154,6 +155,7 @@ struct ModelEditorNode : public OrtNode { "OrtModelEditorApi does not support getting the parent graph for OrtNode"); } + bool owned_ = false; // true after ownership transferred to a graph size_t id = 0; std::string operator_name; std::string domain_name; @@ -235,8 +237,9 @@ struct ModelEditorGraph : public OrtGraph { onnxruntime::InlinedVector> inputs; onnxruntime::InlinedVector> outputs; - std::unordered_map> initializers; - std::unordered_map> external_initializers; + std::unordered_map initializers; + std::unordered_map external_initializers; + bool owned_ = false; // true after ownership transferred to a model std::vector> nodes; std::string name = "ModelEditorGraph"; std::filesystem::path model_path; diff --git a/onnxruntime/core/session/model_editor_api.h b/onnxruntime/core/session/model_editor_api.h index be6da18de2a64..fdd574bb91f34 100644 --- a/onnxruntime/core/session/model_editor_api.h +++ b/onnxruntime/core/session/model_editor_api.h @@ -33,7 +33,8 @@ ORT_API_STATUS_IMPL(SetGraphInputs, _In_ OrtGraph* graph, _In_reads_(inputs_len) _In_ OrtValueInfo** inputs, _In_ size_t inputs_len); ORT_API_STATUS_IMPL(SetGraphOutputs, _In_ OrtGraph* graph, _In_reads_(outputs_len) _In_ OrtValueInfo** outputs, _In_ size_t outputs_len); -ORT_API_STATUS_IMPL(AddInitializerToGraph, _In_ OrtGraph* graph, _In_ const char* name, _Inout_ OrtValue* tensor, +ORT_API_STATUS_IMPL(AddInitializerToGraph, _In_ OrtGraph* graph, _In_ const char* name, + _In_ const OrtValue* ort_value, bool data_is_external); ORT_API_STATUS_IMPL(AddNodeToGraph, _In_ OrtGraph* graph, _Inout_ OrtNode* node); diff --git a/onnxruntime/core/session/model_editor_c_api.cc b/onnxruntime/core/session/model_editor_c_api.cc index 487d5c818f9bc..91cd66f2d2191 100644 --- a/onnxruntime/core/session/model_editor_c_api.cc +++ b/onnxruntime/core/session/model_editor_c_api.cc @@ -3,8 +3,11 @@ #if !defined(ORT_MINIMAL_BUILD) +#include #include +#include "core/common/inlined_containers.h" + #include "core/framework/error_code_helper.h" #include "core/framework/ort_value.h" #include "core/framework/onnxruntime_typeinfo.h" @@ -105,6 +108,14 @@ ORT_API_STATUS_IMPL(OrtModelEditorAPI::CreateGraph, _Outptr_ OrtGraph** graph) { ORT_API_STATUS_IMPL(OrtModelEditorAPI::SetGraphInputs, _In_ OrtGraph* ort_graph, _In_reads_(inputs_len) _In_ OrtValueInfo** inputs, _In_ size_t inputs_len) { API_IMPL_BEGIN + if (ort_graph == nullptr) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "graph cannot be null"); + } + + if (inputs == nullptr && inputs_len != 0) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "inputs cannot be null when inputs_len is non-zero"); + } + onnxruntime::ModelEditorGraph* graph = onnxruntime::ModelEditorGraph::ToInternal(ort_graph); if (graph == nullptr) { @@ -112,7 +123,27 @@ ORT_API_STATUS_IMPL(OrtModelEditorAPI::SetGraphInputs, _In_ OrtGraph* ort_graph, "Invalid OrtGraph variant for use in the OrtModelEditorApi"); } + // Check for duplicate pointers in the input array to prevent double-free + onnxruntime::InlinedHashSet seen; + for (size_t i = 0; i < inputs_len; ++i) { + if (inputs[i] == nullptr) { + continue; + } + if (!seen.insert(inputs[i]).second) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, + "Duplicate OrtValueInfo pointer found in inputs array. " + "Each OrtValueInfo can only appear once."); + } + onnxruntime::ModelEditorValueInfo* vi = onnxruntime::ModelEditorValueInfo::ToInternal(inputs[i]); + if (vi != nullptr && vi->owned_) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, + "This OrtValueInfo has already been added to a graph. " + "Each OrtValueInfo can only be added once."); + } + } + graph->inputs.clear(); + graph->inputs.reserve(inputs_len); for (size_t i = 0; i < inputs_len; ++i) { if (inputs[i] == nullptr) { return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "inputs cannot contain null entries"); @@ -125,6 +156,7 @@ ORT_API_STATUS_IMPL(OrtModelEditorAPI::SetGraphInputs, _In_ OrtGraph* ort_graph, } graph->inputs.push_back(std::unique_ptr(input)); // take ownership + input->owned_ = true; inputs[i] = nullptr; } @@ -135,6 +167,14 @@ ORT_API_STATUS_IMPL(OrtModelEditorAPI::SetGraphInputs, _In_ OrtGraph* ort_graph, ORT_API_STATUS_IMPL(OrtModelEditorAPI::SetGraphOutputs, _In_ OrtGraph* ort_graph, _In_reads_(outputs_len) _In_ OrtValueInfo** outputs, _In_ size_t outputs_len) { API_IMPL_BEGIN + if (ort_graph == nullptr) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "graph cannot be null"); + } + + if (outputs == nullptr && outputs_len != 0) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "outputs cannot be null when outputs_len is non-zero"); + } + onnxruntime::ModelEditorGraph* graph = onnxruntime::ModelEditorGraph::ToInternal(ort_graph); if (graph == nullptr) { @@ -142,7 +182,27 @@ ORT_API_STATUS_IMPL(OrtModelEditorAPI::SetGraphOutputs, _In_ OrtGraph* ort_graph "Invalid OrtGraph variant for use in the OrtModelEditorApi"); } + // Check for duplicate pointers in the output array to prevent double-free + onnxruntime::InlinedHashSet seen; + for (size_t i = 0; i < outputs_len; ++i) { + if (outputs[i] == nullptr) { + continue; + } + if (!seen.insert(outputs[i]).second) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, + "Duplicate OrtValueInfo pointer found in outputs array. " + "Each OrtValueInfo can only appear once."); + } + onnxruntime::ModelEditorValueInfo* vi = onnxruntime::ModelEditorValueInfo::ToInternal(outputs[i]); + if (vi != nullptr && vi->owned_) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, + "This OrtValueInfo has already been added to a graph. " + "Each OrtValueInfo can only be added once."); + } + } + graph->outputs.clear(); + graph->outputs.reserve(outputs_len); for (size_t i = 0; i < outputs_len; ++i) { if (outputs[i] == nullptr) { return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "outputs cannot contain null entries"); @@ -155,6 +215,7 @@ ORT_API_STATUS_IMPL(OrtModelEditorAPI::SetGraphOutputs, _In_ OrtGraph* ort_graph } graph->outputs.push_back(std::unique_ptr(output)); // take ownership + output->owned_ = true; outputs[i] = nullptr; } @@ -163,8 +224,20 @@ ORT_API_STATUS_IMPL(OrtModelEditorAPI::SetGraphOutputs, _In_ OrtGraph* ort_graph } ORT_API_STATUS_IMPL(OrtModelEditorAPI::AddInitializerToGraph, _In_ OrtGraph* ort_graph, _In_ const char* name, - _Inout_ OrtValue* tensor, bool data_is_external) { + _In_ const OrtValue* ort_value, bool data_is_external) { API_IMPL_BEGIN + if (name == nullptr || *name == '\0') { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "name cannot be null or empty string"); + } + + if (ort_value == nullptr) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "ort_value cannot be null"); + } + + if (ort_graph == nullptr) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "graph cannot be null"); + } + onnxruntime::ModelEditorGraph* graph = onnxruntime::ModelEditorGraph::ToInternal(ort_graph); if (graph == nullptr) { @@ -172,19 +245,25 @@ ORT_API_STATUS_IMPL(OrtModelEditorAPI::AddInitializerToGraph, _In_ OrtGraph* ort "Invalid OrtGraph variant for use in the OrtModelEditorApi"); } - if (!tensor->IsTensor()) { + if (!ort_value->IsTensor()) { return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "Only Tensor is currently supported."); } - if (!tensor->IsAllocated()) { + if (!ort_value->IsAllocated()) { return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "Tensor must be allocated."); } - const auto& t = tensor->Get(); + const auto& t = ort_value->Get(); if (t.Location().device.Type() != OrtDevice::CPU) { return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "Only CPU based tensors are currently supported."); } + // Reject duplicate name in either map + if (graph->initializers.count(name) != 0 || graph->external_initializers.count(name) != 0) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, + "An initializer with this name has already been added to the graph."); + } + if (data_is_external) { // enforce that an external initializer is not used if the data size is < 128 bytes. // the reason for this is to avoid potential shape inferencing errors if this initializer is providing an @@ -195,18 +274,26 @@ ORT_API_STATUS_IMPL(OrtModelEditorAPI::AddInitializerToGraph, _In_ OrtGraph* ort "External initializer should only be used for data >= 128 bytes. " "Please use CreateTensorAsOrtValue instead."); } - - graph->external_initializers[name] = std::unique_ptr(tensor); // take ownership - } else { - graph->initializers[name] = std::unique_ptr(tensor); // take ownership } + auto& m = data_is_external ? graph->external_initializers : graph->initializers; + auto [it, inserted] = m.emplace(name, *ort_value); + ORT_ENFORCE(inserted, "Unexpected duplicate name after validation. This is a bug."); + return nullptr; API_IMPL_END } ORT_API_STATUS_IMPL(OrtModelEditorAPI::AddNodeToGraph, _In_ OrtGraph* ort_graph, _Inout_ OrtNode* ort_node) { API_IMPL_BEGIN + if (ort_node == nullptr) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "node cannot be null"); + } + + if (ort_graph == nullptr) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "graph cannot be null"); + } + onnxruntime::ModelEditorGraph* graph = onnxruntime::ModelEditorGraph::ToInternal(ort_graph); if (graph == nullptr) { @@ -221,8 +308,19 @@ ORT_API_STATUS_IMPL(OrtModelEditorAPI::AddNodeToGraph, _In_ OrtGraph* ort_graph, "Invalid OrtNode variant for use in the OrtModelEditorApi"); } + // Reject if this node has already been added to a graph (prevents double-free) + if (node->owned_) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, + "This node has already been added to a graph. " + "Each OrtNode can only be added once."); + } + node->id = graph->nodes.size(); - graph->nodes.push_back(std::unique_ptr(node)); // take ownership + if (graph->nodes.size() == graph->nodes.capacity()) { + graph->nodes.reserve(std::max(graph->nodes.capacity() * 2, size_t{1})); + } + graph->nodes.emplace_back(node); + node->owned_ = true; return nullptr; API_IMPL_END } @@ -246,11 +344,36 @@ ORT_API_STATUS_IMPL(OrtModelEditorAPI::CreateModel, ORT_API_STATUS_IMPL(OrtModelEditorAPI::AddGraphToModel, _In_ OrtModel* model, _Inout_ OrtGraph* graph) { API_IMPL_BEGIN + if (model == nullptr) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "model cannot be null"); + } + if (graph == nullptr) { return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "graph cannot be null"); } + // Reject if model already has a graph (prevents double-free/UAF) + if (model->graph != nullptr) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, + "Model already has a graph. Each OrtModel can only have one graph."); + } + + // Reject if this graph has already been added to a model (prevents double-free across models) + onnxruntime::ModelEditorGraph* me_graph = onnxruntime::ModelEditorGraph::ToInternal(graph); + if (me_graph == nullptr) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, + "Invalid OrtGraph variant for use in the OrtModelEditorApi"); + } + + if (me_graph->owned_) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, + "This graph has already been added to a model. " + "Each OrtGraph can only be added once."); + } + model->graph = std::unique_ptr(graph); // take ownership + me_graph->owned_ = true; + return nullptr; API_IMPL_END } diff --git a/onnxruntime/core/session/onnxruntime_c_api.cc b/onnxruntime/core/session/onnxruntime_c_api.cc index 5ee5f1486b137..2ac95d6e36466 100644 --- a/onnxruntime/core/session/onnxruntime_c_api.cc +++ b/onnxruntime/core/session/onnxruntime_c_api.cc @@ -2752,14 +2752,35 @@ ORT_API_STATUS_IMPL(OrtApis::SessionOptionsSetCustomJoinThreadFn, _Inout_ OrtSes } ORT_API(void, OrtApis::ReleaseValueInfo, _Frees_ptr_opt_ OrtValueInfo* value_info) { + if (value_info != nullptr) { + if (auto* me = onnxruntime::ModelEditorValueInfo::ToInternal(value_info); + me != nullptr && me->owned_) { + assert(false && "Releasing an OrtValueInfo that is owned by a graph"); + return; + } + } delete value_info; } ORT_API(void, OrtApis::ReleaseNode, _Frees_ptr_opt_ OrtNode* node) { + if (node != nullptr) { + if (auto* me = onnxruntime::ModelEditorNode::ToInternal(node); + me != nullptr && me->owned_) { + assert(false && "Releasing an OrtNode that is owned by a graph"); + return; + } + } delete node; } ORT_API(void, OrtApis::ReleaseGraph, _Frees_ptr_opt_ OrtGraph* graph) { + if (graph != nullptr) { + if (auto* me = onnxruntime::ModelEditorGraph::ToInternal(graph); + me != nullptr && me->owned_) { + assert(false && "Releasing an OrtGraph that is owned by a model"); + return; + } + } delete graph; } diff --git a/onnxruntime/test/shared_lib/test_model_builder_api.cc b/onnxruntime/test/shared_lib/test_model_builder_api.cc index c5ec376f7d0f5..0237ce773eda2 100644 --- a/onnxruntime/test/shared_lib/test_model_builder_api.cc +++ b/onnxruntime/test/shared_lib/test_model_builder_api.cc @@ -235,7 +235,7 @@ TEST(ModelEditorAPITest, Basic_CApi) { &y_tensor)); Ort::ThrowOnError(model_editor_api.AddInitializerToGraph(graph, "Y", y_tensor, /*data is external*/ true)); - y_tensor = nullptr; // graph now owns + api.ReleaseValue(y_tensor); if (use_constant_node) { // Test that a Constant node is converted to an initializer @@ -1083,3 +1083,326 @@ TEST(ModelEditorCompileAPITest, EmbedModeWithBufferOutputSatisfiesValidation) { allocator->Free(output_buffer); } } + +// +// Regression tests for double-free / ownership-transfer bugs in OrtModelEditorApi. +// These test that the API rejects attempts to transfer ownership of the same object twice. +// + +TEST(ModelEditorAPITest, AddInitializerToGraph_DuplicateName_Fails) { + const auto& api = Ort::GetApi(); + const auto& model_editor_api = Ort::GetModelEditorApi(); + + OrtGraph* graph = nullptr; + Ort::ThrowOnError(model_editor_api.CreateGraph(&graph)); + + // Create two small ORT-allocated tensors (< 128 bytes, so data_is_external = false) + std::vector dims = {2, 2}; + OrtAllocator* allocator = nullptr; + Ort::ThrowOnError(api.GetAllocatorWithDefaultOptions(&allocator)); + + OrtValue* tensor1 = nullptr; + OrtValue* tensor2 = nullptr; + Ort::ThrowOnError(api.CreateTensorAsOrtValue(allocator, dims.data(), dims.size(), + ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, &tensor1)); + Ort::ThrowOnError(api.CreateTensorAsOrtValue(allocator, dims.data(), dims.size(), + ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, &tensor2)); + + // First add should succeed + ASSERT_ORTSTATUS_OK(model_editor_api.AddInitializerToGraph(graph, "W", tensor1, false)); + + // Second add with same name should fail + Ort::Status status{model_editor_api.AddInitializerToGraph(graph, "W", tensor2, false)}; + EXPECT_FALSE(status.IsOK()); + EXPECT_THAT(status.GetErrorMessage(), ::testing::HasSubstr("already been added")); + + // Clean up — caller retains ownership under copy semantics + api.ReleaseValue(tensor1); + api.ReleaseValue(tensor2); + api.ReleaseGraph(graph); +} + +TEST(ModelEditorAPITest, AddInitializerToGraph_SamePointerDifferentName_Succeeds) { + const auto& api = Ort::GetApi(); + const auto& model_editor_api = Ort::GetModelEditorApi(); + + OrtGraph* graph = nullptr; + Ort::ThrowOnError(model_editor_api.CreateGraph(&graph)); + + std::vector dims = {2, 2}; + OrtAllocator* allocator = nullptr; + Ort::ThrowOnError(api.GetAllocatorWithDefaultOptions(&allocator)); + + OrtValue* tensor = nullptr; + Ort::ThrowOnError(api.CreateTensorAsOrtValue(allocator, dims.data(), dims.size(), + ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, &tensor)); + + // Both adds succeed — each creates an independent copy sharing the same underlying data + ASSERT_ORTSTATUS_OK(model_editor_api.AddInitializerToGraph(graph, "W1", tensor, false)); + ASSERT_ORTSTATUS_OK(model_editor_api.AddInitializerToGraph(graph, "W2", tensor, false)); + + // Caller retains ownership and releases + api.ReleaseValue(tensor); + api.ReleaseGraph(graph); +} + +TEST(ModelEditorAPITest, AddNodeToGraph_DuplicateNode_Fails) { + const auto& api = Ort::GetApi(); + const auto& model_editor_api = Ort::GetModelEditorApi(); + + OrtGraph* graph = nullptr; + Ort::ThrowOnError(model_editor_api.CreateGraph(&graph)); + + OrtNode* node = CreateNode(model_editor_api, "Relu", "relu1", {"X"}, {"Y"}); + + // First add should succeed + ASSERT_ORTSTATUS_OK(model_editor_api.AddNodeToGraph(graph, node)); + + // Second add of same node should fail (prevents double-free) + Ort::Status status{model_editor_api.AddNodeToGraph(graph, node)}; + EXPECT_FALSE(status.IsOK()); + EXPECT_THAT(status.GetErrorMessage(), ::testing::HasSubstr("already been added")); + + api.ReleaseGraph(graph); +} + +TEST(ModelEditorAPITest, AddGraphToModel_DuplicateGraph_Fails) { + const auto& api = Ort::GetApi(); + const auto& model_editor_api = Ort::GetModelEditorApi(); + + OrtGraph* graph1 = nullptr; + OrtGraph* graph2 = nullptr; + Ort::ThrowOnError(model_editor_api.CreateGraph(&graph1)); + Ort::ThrowOnError(model_editor_api.CreateGraph(&graph2)); + + std::vector domain_names = {onnxruntime::kOnnxDomain}; + std::vector opset_versions = {18}; + OrtModel* model = nullptr; + Ort::ThrowOnError(model_editor_api.CreateModel(domain_names.data(), opset_versions.data(), + domain_names.size(), &model)); + + // First add should succeed + ASSERT_ORTSTATUS_OK(model_editor_api.AddGraphToModel(model, graph1)); + + // Second add should fail (model already has a graph) + Ort::Status status{model_editor_api.AddGraphToModel(model, graph2)}; + EXPECT_FALSE(status.IsOK()); + EXPECT_THAT(status.GetErrorMessage(), ::testing::HasSubstr("already has a graph")); + + // Clean up graph2 since ownership was NOT transferred + api.ReleaseGraph(graph2); + api.ReleaseModel(model); +} + +TEST(ModelEditorAPITest, SetGraphInputs_DuplicatePointer_Fails) { + const auto& api = Ort::GetApi(); + const auto& model_editor_api = Ort::GetModelEditorApi(); + + OrtGraph* graph = nullptr; + Ort::ThrowOnError(model_editor_api.CreateGraph(&graph)); + + // Create a single OrtValueInfo + OrtTensorTypeAndShapeInfo* tensor_type_info = nullptr; + std::vector dims = {3, 4}; + Ort::ThrowOnError(api.CreateTensorTypeAndShapeInfo(&tensor_type_info)); + Ort::ThrowOnError(api.SetTensorElementType(tensor_type_info, ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT)); + Ort::ThrowOnError(api.SetDimensions(tensor_type_info, dims.data(), dims.size())); + + OrtTypeInfo* type_info = nullptr; + Ort::ThrowOnError(model_editor_api.CreateTensorTypeInfo(tensor_type_info, &type_info)); + api.ReleaseTensorTypeAndShapeInfo(tensor_type_info); + + OrtValueInfo* value_info = nullptr; + Ort::ThrowOnError(model_editor_api.CreateValueInfo("X", type_info, &value_info)); + api.ReleaseTypeInfo(type_info); + + // Pass the same pointer twice in the inputs array — should fail + std::vector inputs = {value_info, value_info}; + Ort::Status status{model_editor_api.SetGraphInputs(graph, inputs.data(), inputs.size())}; + EXPECT_FALSE(status.IsOK()); + EXPECT_THAT(status.GetErrorMessage(), ::testing::HasSubstr("Duplicate")); + + // Clean up — ownership was NOT transferred + api.ReleaseValueInfo(value_info); + api.ReleaseGraph(graph); +} + +TEST(ModelEditorAPITest, SetGraphOutputs_DuplicatePointer_Fails) { + const auto& api = Ort::GetApi(); + const auto& model_editor_api = Ort::GetModelEditorApi(); + + OrtGraph* graph = nullptr; + Ort::ThrowOnError(model_editor_api.CreateGraph(&graph)); + + // Create a single OrtValueInfo + OrtTensorTypeAndShapeInfo* tensor_type_info = nullptr; + std::vector dims = {3, 4}; + Ort::ThrowOnError(api.CreateTensorTypeAndShapeInfo(&tensor_type_info)); + Ort::ThrowOnError(api.SetTensorElementType(tensor_type_info, ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT)); + Ort::ThrowOnError(api.SetDimensions(tensor_type_info, dims.data(), dims.size())); + + OrtTypeInfo* type_info = nullptr; + Ort::ThrowOnError(model_editor_api.CreateTensorTypeInfo(tensor_type_info, &type_info)); + api.ReleaseTensorTypeAndShapeInfo(tensor_type_info); + + OrtValueInfo* value_info = nullptr; + Ort::ThrowOnError(model_editor_api.CreateValueInfo("Y", type_info, &value_info)); + api.ReleaseTypeInfo(type_info); + + // Pass the same pointer twice in the outputs array — should fail + std::vector outputs = {value_info, value_info}; + Ort::Status status{model_editor_api.SetGraphOutputs(graph, outputs.data(), outputs.size())}; + EXPECT_FALSE(status.IsOK()); + EXPECT_THAT(status.GetErrorMessage(), ::testing::HasSubstr("Duplicate")); + + // Clean up — ownership was NOT transferred + api.ReleaseValueInfo(value_info); + api.ReleaseGraph(graph); +} + +TEST(ModelEditorAPITest, AddNodeToGraph_NullGraph_Fails) { + const auto& api = Ort::GetApi(); + const auto& model_editor_api = Ort::GetModelEditorApi(); + + OrtNode* node = CreateNode(model_editor_api, "Relu", "relu1", {"X"}, {"Y"}); + + // Null graph should fail without crashing + Ort::Status status{model_editor_api.AddNodeToGraph(nullptr, node)}; + EXPECT_FALSE(status.IsOK()); + EXPECT_THAT(status.GetErrorMessage(), ::testing::HasSubstr("null")); + + api.ReleaseNode(node); +} + +TEST(ModelEditorAPITest, AddGraphToModel_SameGraphTwoModels_Fails) { + const auto& api = Ort::GetApi(); + const auto& model_editor_api = Ort::GetModelEditorApi(); + + OrtGraph* graph = nullptr; + Ort::ThrowOnError(model_editor_api.CreateGraph(&graph)); + + std::vector domain_names = {onnxruntime::kOnnxDomain}; + std::vector opset_versions = {18}; + OrtModel* model1 = nullptr; + OrtModel* model2 = nullptr; + Ort::ThrowOnError(model_editor_api.CreateModel(domain_names.data(), opset_versions.data(), + domain_names.size(), &model1)); + Ort::ThrowOnError(model_editor_api.CreateModel(domain_names.data(), opset_versions.data(), + domain_names.size(), &model2)); + + // First add should succeed + ASSERT_ORTSTATUS_OK(model_editor_api.AddGraphToModel(model1, graph)); + + // Second add to different model should fail (graph already owned) + Ort::Status status{model_editor_api.AddGraphToModel(model2, graph)}; + EXPECT_FALSE(status.IsOK()); + EXPECT_THAT(status.GetErrorMessage(), ::testing::HasSubstr("already been added")); + + // model2 doesn't own anything, model1 owns graph + api.ReleaseModel(model2); + api.ReleaseModel(model1); +} + +// Skipped in debug builds where the assert in Release functions would fire. +#ifdef NDEBUG +TEST(ModelEditorAPITest, ReleaseNode_AfterAddToGraph_IsNoOp) { + const auto& api = Ort::GetApi(); + const auto& model_editor_api = Ort::GetModelEditorApi(); + + OrtGraph* graph = nullptr; + Ort::ThrowOnError(model_editor_api.CreateGraph(&graph)); + + OrtNode* node = CreateNode(model_editor_api, "Relu", "relu1", {"X"}, {"Y"}); + + ASSERT_ORTSTATUS_OK(model_editor_api.AddNodeToGraph(graph, node)); + api.ReleaseNode(node); + api.ReleaseGraph(graph); +} + +TEST(ModelEditorAPITest, ReleaseGraph_AfterAddToModel_IsNoOp) { + const auto& api = Ort::GetApi(); + const auto& model_editor_api = Ort::GetModelEditorApi(); + + OrtGraph* graph = nullptr; + Ort::ThrowOnError(model_editor_api.CreateGraph(&graph)); + + std::vector domain_names = {onnxruntime::kOnnxDomain}; + std::vector opset_versions = {18}; + OrtModel* model = nullptr; + Ort::ThrowOnError(model_editor_api.CreateModel(domain_names.data(), opset_versions.data(), + domain_names.size(), &model)); + + ASSERT_ORTSTATUS_OK(model_editor_api.AddGraphToModel(model, graph)); + api.ReleaseGraph(graph); + api.ReleaseModel(model); +} + +TEST(ModelEditorAPITest, ReleaseValueInfo_AfterSetGraphInputs_IsNoOp) { + const auto& api = Ort::GetApi(); + const auto& model_editor_api = Ort::GetModelEditorApi(); + + OrtGraph* graph = nullptr; + Ort::ThrowOnError(model_editor_api.CreateGraph(&graph)); + + OrtTensorTypeAndShapeInfo* tensor_type_info = nullptr; + std::vector dims = {3, 4}; + Ort::ThrowOnError(api.CreateTensorTypeAndShapeInfo(&tensor_type_info)); + Ort::ThrowOnError(api.SetTensorElementType(tensor_type_info, ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT)); + Ort::ThrowOnError(api.SetDimensions(tensor_type_info, dims.data(), dims.size())); + + OrtTypeInfo* type_info = nullptr; + Ort::ThrowOnError(model_editor_api.CreateTensorTypeInfo(tensor_type_info, &type_info)); + api.ReleaseTensorTypeAndShapeInfo(tensor_type_info); + + OrtValueInfo* x_info = nullptr; + Ort::ThrowOnError(model_editor_api.CreateValueInfo("X", type_info, &x_info)); + api.ReleaseTypeInfo(type_info); + + OrtValueInfo* saved_ptr = x_info; + std::vector inputs = {x_info}; + ASSERT_ORTSTATUS_OK(model_editor_api.SetGraphInputs(graph, inputs.data(), inputs.size())); + + api.ReleaseValueInfo(saved_ptr); + api.ReleaseGraph(graph); +} +#endif // NDEBUG + +TEST(ModelEditorAPITest, SetGraphInputs_AlreadyOwnedValueInfo_Fails) { + const auto& api = Ort::GetApi(); + const auto& model_editor_api = Ort::GetModelEditorApi(); + + OrtGraph* graph1 = nullptr; + OrtGraph* graph2 = nullptr; + Ort::ThrowOnError(model_editor_api.CreateGraph(&graph1)); + Ort::ThrowOnError(model_editor_api.CreateGraph(&graph2)); + + // Create OrtValueInfo + OrtTensorTypeAndShapeInfo* tensor_type_info = nullptr; + std::vector dims = {3, 4}; + Ort::ThrowOnError(api.CreateTensorTypeAndShapeInfo(&tensor_type_info)); + Ort::ThrowOnError(api.SetTensorElementType(tensor_type_info, ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT)); + Ort::ThrowOnError(api.SetDimensions(tensor_type_info, dims.data(), dims.size())); + + OrtTypeInfo* type_info = nullptr; + Ort::ThrowOnError(model_editor_api.CreateTensorTypeInfo(tensor_type_info, &type_info)); + api.ReleaseTensorTypeAndShapeInfo(tensor_type_info); + + OrtValueInfo* x_info = nullptr; + Ort::ThrowOnError(model_editor_api.CreateValueInfo("X", type_info, &x_info)); + api.ReleaseTypeInfo(type_info); + + // Save the raw pointer before SetGraphInputs nulls out the array entry + OrtValueInfo* saved_ptr = x_info; + std::vector inputs = {x_info}; + ASSERT_ORTSTATUS_OK(model_editor_api.SetGraphInputs(graph1, inputs.data(), inputs.size())); + + // Try to add the already-owned ValueInfo to a second graph — should fail + std::vector inputs2 = {saved_ptr}; + Ort::Status status{model_editor_api.SetGraphInputs(graph2, inputs2.data(), inputs2.size())}; + EXPECT_FALSE(status.IsOK()); + EXPECT_THAT(status.GetErrorMessage(), ::testing::HasSubstr("already been added")); + + // graph1 owns x_info, graph2 is empty + api.ReleaseGraph(graph2); + api.ReleaseGraph(graph1); +} From 5f071fb9fcc0fd95d95ec9a81a8aad6866e16725 Mon Sep 17 00:00:00 2001 From: Jie Chen Date: Thu, 7 May 2026 03:03:03 +0800 Subject: [PATCH 28/34] Add M-tile loop with dispatch capping for Intel Xe2/3-LPG (#28250) - Wrap 8x16x16 MatMulNBits(SubgroupMatrix) kernel body in M-tile loop using uniforms.m_tiles_per_wg for tile assignment per workgroup - Cap dispatch_y on Xe2/3-LPG when M > 2k, with occupancy factor 16x - Non-Intel or small-M paths pass m_tiles_per_wg=1 (no behavior change) --- .../subgroup_matrix_matmul_nbits.cc | 26 ++- .../subgroup_matrix_matmul_nbits.h | 3 +- ..._matrix_matmul_nbits_8x16x16.wgsl.template | 169 +++++++++--------- 3 files changed, 113 insertions(+), 85 deletions(-) diff --git a/onnxruntime/contrib_ops/webgpu/quantization/subgroup_matrix_matmul_nbits.cc b/onnxruntime/contrib_ops/webgpu/quantization/subgroup_matrix_matmul_nbits.cc index cdc0f1ded3e45..a14bf26e7c438 100644 --- a/onnxruntime/contrib_ops/webgpu/quantization/subgroup_matrix_matmul_nbits.cc +++ b/onnxruntime/contrib_ops/webgpu/quantization/subgroup_matrix_matmul_nbits.cc @@ -293,13 +293,31 @@ Status ApplySubgroupMatrixMatMulNBits(const Tensor* a, const Tensor* b, const Te const bool has_weight_idx = weight_index > 0 || has_weight_idx_indirect; SubgroupMatrixMatMulNBitsProgram mul_program{nbits, config_index, has_zero_points, has_bias, has_weight_idx, has_weight_idx_indirect}; mul_program.SetWorkgroupSize(work_group_size); - mul_program.SetDispatchGroupSize( - (N + tile_size_b - 1) / tile_size_b, - (M + tile_size_a - 1) / tile_size_a, 1); + uint32_t dispatch_x = (N + tile_size_b - 1) / tile_size_b; + uint32_t num_m_tiles = (M + tile_size_a - 1) / tile_size_a; + uint32_t dispatch_y = num_m_tiles; + // For large M on Intel Xe, cap dispatch_y so each workgroup processes multiple + // M-tiles sequentially, reducing scheduling overhead. + if (M > 2048 && context.AdapterInfo().vendor == std::string_view{"intel"}) { + // Each XeCore has 4 XVE x 8 SIMD-32 hardware threads = 32 subgroups. + uint32_t hw_subgroups = 0; + if (context.AdapterInfo().architecture == std::string_view{"xe-3lpg"}) { + hw_subgroups = 384; // 12 XeCore x 32 + } else if (context.AdapterInfo().architecture == std::string_view{"xe-2lpg"}) { + hw_subgroups = 256; // 8 XeCore x 32 + } + if (hw_subgroups > 0) { + constexpr uint32_t kOccupancyFactor = 16; // empirically tuned on Xe2/Xe3 devices + uint32_t target_wgs = hw_subgroups * kOccupancyFactor / (work_group_size / 32); + dispatch_y = std::min(dispatch_y, (target_wgs + dispatch_x - 1) / dispatch_x); + } + } + uint32_t m_tiles_per_wg = (num_m_tiles + dispatch_y - 1) / dispatch_y; + mul_program.SetDispatchGroupSize(dispatch_x, dispatch_y, 1); mul_program.AddInputs({{a, ProgramTensorMetadataDependency::TypeAndRank, 1}, {b, ProgramTensorMetadataDependency::TypeAndRank, static_cast(nbits == 4 ? kU32Components : 2 * kU32Components)}, {scales, ProgramTensorMetadataDependency::TypeAndRank, 1}}) - .AddUniformVariables({{M}, {N}, {K}, {zero_blocks_per_col}, {weight_index}}) + .AddUniformVariables({{M}, {N}, {K}, {zero_blocks_per_col}, {weight_index}, {m_tiles_per_wg}}) .AddOutput({y, ProgramTensorMetadataDependency::TypeAndRank, y_shape, 1}) .CacheHint(nbits, has_zero_points, has_bias, has_weight_idx, has_weight_idx_indirect); if (has_zero_points) { diff --git a/onnxruntime/contrib_ops/webgpu/quantization/subgroup_matrix_matmul_nbits.h b/onnxruntime/contrib_ops/webgpu/quantization/subgroup_matrix_matmul_nbits.h index 810bda950b169..f4b0d10262de5 100644 --- a/onnxruntime/contrib_ops/webgpu/quantization/subgroup_matrix_matmul_nbits.h +++ b/onnxruntime/contrib_ops/webgpu/quantization/subgroup_matrix_matmul_nbits.h @@ -32,7 +32,8 @@ class SubgroupMatrixMatMulNBitsProgram final : public Program; - var sg_mat_c1: subgroup_matrix_result; - var sg_mat_c2: subgroup_matrix_result; - var sg_mat_c3: subgroup_matrix_result; - for (var k_idx: u32 = 0; k_idx < uniforms.K; k_idx += kTileK) { - // Load Phase - dequant_b_to_tile(global_base_b, k_idx, local_idx / 4, local_idx % 4); - workgroupBarrier(); - - for (var sg_mat_k_idx: u32 = 0; sg_mat_k_idx < kTileK; sg_mat_k_idx += kSgMatK) - { - // Load A from global memory (prepacked layout). - // Syntax: subgroupMatrixLoad src_ptr, src_offset, is_col_major, src_stride - var sg_mat_a0: subgroup_matrix_left = - subgroupMatrixLoad>( - &input_a, sg_mat_offset_a, false, kSgMatK); - sg_mat_offset_a += kSgMatSizeLeft; - - // Load B from shared local memory. - // tile_b [kTileN, kTileK] is stored as column major. - var sg_mat_b0: subgroup_matrix_right = - subgroupMatrixLoad>( - &tile_b, sg_mat_k_idx, true, kTileK); - var sg_mat_b1: subgroup_matrix_right = - subgroupMatrixLoad>( - &tile_b, sg_mat_k_idx + kSgMatStrideN, true, kTileK); - var sg_mat_b2: subgroup_matrix_right = - subgroupMatrixLoad>( - &tile_b, sg_mat_k_idx + 2 * kSgMatStrideN, true, kTileK); - var sg_mat_b3: subgroup_matrix_right = - subgroupMatrixLoad>( - &tile_b, sg_mat_k_idx + 3 * kSgMatStrideN, true, kTileK); - - // Compute Phase - // Syntax: subgroupMatrixMultiplyAccumulate left, right, accumulate -> accumulate - sg_mat_c0 = subgroupMatrixMultiplyAccumulate(sg_mat_a0, sg_mat_b0, sg_mat_c0); - sg_mat_c1 = subgroupMatrixMultiplyAccumulate(sg_mat_a0, sg_mat_b1, sg_mat_c1); - sg_mat_c2 = subgroupMatrixMultiplyAccumulate(sg_mat_a0, sg_mat_b2, sg_mat_c2); - sg_mat_c3 = subgroupMatrixMultiplyAccumulate(sg_mat_a0, sg_mat_b3, sg_mat_c3); + let num_tiles_m = (uniforms.M + kTileM - 1) / kTileM; + + // Zero-initialized accumulator template (used to reset per M-tile iteration). + var sg_mat_zero: subgroup_matrix_result; + + // Sequential M-loop: each workgroup processes a contiguous block of M-tiles. + let m_start = workgroup_id.y * uniforms.m_tiles_per_wg; + let m_end = min(m_start + uniforms.m_tiles_per_wg, num_tiles_m); + for (var m_tile: u32 = m_start; m_tile < m_end; m_tile++) { + let global_base_a = m_tile * kTileM; + let sg_mat_idx = (m_tile * kSgMatCountM + sg_idx) * sg_mat_count_k; + + var sg_mat_offset_a = sg_mat_idx * kSgMatSizeLeft; + + var sg_mat_c0 = sg_mat_zero; + var sg_mat_c1 = sg_mat_zero; + var sg_mat_c2 = sg_mat_zero; + var sg_mat_c3 = sg_mat_zero; + for (var k_idx: u32 = 0; k_idx < uniforms.K; k_idx += kTileK) { + // Load Phase + dequant_b_to_tile(global_base_b, k_idx, local_idx / 4, local_idx % 4); + workgroupBarrier(); + + for (var sg_mat_k_idx: u32 = 0; sg_mat_k_idx < kTileK; sg_mat_k_idx += kSgMatK) + { + // Load A from global memory (prepacked layout). + // Syntax: subgroupMatrixLoad src_ptr, src_offset, is_col_major, src_stride + var sg_mat_a0: subgroup_matrix_left = + subgroupMatrixLoad>( + &input_a, sg_mat_offset_a, false, kSgMatK); + sg_mat_offset_a += kSgMatSizeLeft; + + // Load B from shared local memory. + // tile_b [kTileN, kTileK] is stored as column major. + var sg_mat_b0: subgroup_matrix_right = + subgroupMatrixLoad>( + &tile_b, sg_mat_k_idx, true, kTileK); + var sg_mat_b1: subgroup_matrix_right = + subgroupMatrixLoad>( + &tile_b, sg_mat_k_idx + kSgMatStrideN, true, kTileK); + var sg_mat_b2: subgroup_matrix_right = + subgroupMatrixLoad>( + &tile_b, sg_mat_k_idx + 2 * kSgMatStrideN, true, kTileK); + var sg_mat_b3: subgroup_matrix_right = + subgroupMatrixLoad>( + &tile_b, sg_mat_k_idx + 3 * kSgMatStrideN, true, kTileK); + + // Compute Phase + // Syntax: subgroupMatrixMultiplyAccumulate left, right, accumulate -> accumulate + sg_mat_c0 = subgroupMatrixMultiplyAccumulate(sg_mat_a0, sg_mat_b0, sg_mat_c0); + sg_mat_c1 = subgroupMatrixMultiplyAccumulate(sg_mat_a0, sg_mat_b1, sg_mat_c1); + sg_mat_c2 = subgroupMatrixMultiplyAccumulate(sg_mat_a0, sg_mat_b2, sg_mat_c2); + sg_mat_c3 = subgroupMatrixMultiplyAccumulate(sg_mat_a0, sg_mat_b3, sg_mat_c3); + } + workgroupBarrier(); } - workgroupBarrier(); - } - // Write out + // Write out #if has_bias - // Store results to scratch workgroup memory, then add bias and write to output. - // scratch layout: [kTileM, kTileN] row-major - let scratch_m_base = sg_idx * kSgMatM; - subgroupMatrixStore(&scratch, scratch_m_base * kTileN, sg_mat_c0, false, kTileN); - subgroupMatrixStore(&scratch, scratch_m_base * kTileN + kSgMatN, sg_mat_c1, false, kTileN); - subgroupMatrixStore(&scratch, scratch_m_base * kTileN + 2 * kSgMatN, sg_mat_c2, false, kTileN); - subgroupMatrixStore(&scratch, scratch_m_base * kTileN + 3 * kSgMatN, sg_mat_c3, false, kTileN); - workgroupBarrier(); - - // 256 threads write 64x64 = 4096 elements. Each thread handles 16 elements. - // Thread mapping: m = local_idx / 4, n_base = (local_idx % 4) * 16 - let out_m = local_idx / 4; - let out_n_base = (local_idx % 4) * 16; - let global_m = global_base_a + out_m; - if (global_m < uniforms.M) { - let global_n_base = global_base_b + out_n_base; - let scratch_base = out_m * kTileN + out_n_base; - let out_base = global_m * uniforms.N + global_n_base; + // Store results to scratch workgroup memory, then add bias and write to output. + // scratch layout: [kTileM, kTileN] row-major + let scratch_m_base = sg_idx * kSgMatM; + subgroupMatrixStore(&scratch, scratch_m_base * kTileN, sg_mat_c0, false, kTileN); + subgroupMatrixStore(&scratch, scratch_m_base * kTileN + kSgMatN, sg_mat_c1, false, kTileN); + subgroupMatrixStore(&scratch, scratch_m_base * kTileN + 2 * kSgMatN, sg_mat_c2, false, kTileN); + subgroupMatrixStore(&scratch, scratch_m_base * kTileN + 3 * kSgMatN, sg_mat_c3, false, kTileN); + workgroupBarrier(); + + // 256 threads write 64x64 = 4096 elements. Each thread handles 16 elements. + // Thread mapping: m = local_idx / 4, n_base = (local_idx % 4) * 16 + let out_m = local_idx / 4; + let out_n_base = (local_idx % 4) * 16; + let global_m = global_base_a + out_m; + if (global_m < uniforms.M) { + let global_n_base = global_base_b + out_n_base; + let scratch_base = out_m * kTileN + out_n_base; + let out_base = global_m * uniforms.N + global_n_base; #if has_weight_idx_indirect - let bias_offset = weight_index_indirect[uniforms.weight_idx] * uniforms.N; + let bias_offset = weight_index_indirect[uniforms.weight_idx] * uniforms.N; #elif has_weight_idx - let bias_offset = uniforms.weight_idx * uniforms.N; + let bias_offset = uniforms.weight_idx * uniforms.N; #else - const bias_offset: u32 = 0; + const bias_offset: u32 = 0; #endif - for (var i: u32 = 0; i < 16; i++) { - if (global_n_base + i < uniforms.N) { - let val = output_element_t(scratch[scratch_base + i]) - + bias[bias_offset + global_n_base + i]; - output.setByOffset(out_base + i, val); + for (var i: u32 = 0; i < 16; i++) { + if (global_n_base + i < uniforms.N) { + let val = output_element_t(scratch[scratch_base + i]) + + bias[bias_offset + global_n_base + i]; + output.setByOffset(out_base + i, val); + } } } - } #else - let sg_mat_offset_c = global_base_a * uniforms.N + global_base_b + sg_idx * kSgMatM * uniforms.N; - subgroupMatrixStore(&output, sg_mat_offset_c, sg_mat_c0, false, uniforms.N); - subgroupMatrixStore(&output, sg_mat_offset_c + kSgMatN, sg_mat_c1, false, uniforms.N); - subgroupMatrixStore(&output, sg_mat_offset_c + 2 * kSgMatN, sg_mat_c2, false, uniforms.N); - subgroupMatrixStore(&output, sg_mat_offset_c + 3 * kSgMatN, sg_mat_c3, false, uniforms.N); + let sg_mat_offset_c = global_base_a * uniforms.N + global_base_b + sg_idx * kSgMatM * uniforms.N; + subgroupMatrixStore(&output, sg_mat_offset_c, sg_mat_c0, false, uniforms.N); + subgroupMatrixStore(&output, sg_mat_offset_c + kSgMatN, sg_mat_c1, false, uniforms.N); + subgroupMatrixStore(&output, sg_mat_offset_c + 2 * kSgMatN, sg_mat_c2, false, uniforms.N); + subgroupMatrixStore(&output, sg_mat_offset_c + 3 * kSgMatN, sg_mat_c3, false, uniforms.N); #endif + } // end M-tile loop } // MAIN From 470977a461b517d3bcde89ef58ba6414d6afc86e Mon Sep 17 00:00:00 2001 From: Max Buckley Date: Wed, 6 May 2026 23:00:31 +0200 Subject: [PATCH 29/34] [CoreML EP] Support pre-opset-13 Split via 'split' attribute (#28270) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ### Description The CoreML `SplitOpBuilder` previously gated `GetMinSupportedOpSet` at 13 because pre-13 `Split` carries split sizes via an INTS attribute rather than a second input. This PR lowers the gate to 1 and reads the attribute in both the MLProgram and NeuralNetwork emitters, so `Split` from any opset is supported on the CoreML EP. The validation in `IsOpSupportedImpl` mirrors the existing input-form rules — ≥2 outputs, sum of sizes equals the axis dim, all sizes positive, axis dim not dynamic. For the no-attribute / no-input case (legacy even-split) we also explicitly require the axis dim to be evenly divisible by `num_outputs`, since CoreML's `num_splits` requires that. This is a behavior change only for opset 2–12 graphs that were 100% rejected before, so no path that used to work regresses. ### Motivation DWPose `dw-ll_ucoco_384.onnx` (opset 11), a common pose-estimation model, has two `Split` nodes — one uneven (`split=[512, 512, 128]`) and one even (`split=[1, 1]`). Both fall back to CPU today, fragmenting the CoreML partition. | | Without this PR | With this PR | |---|---|---| | CoreML partitions | 3 | **1** | | Nodes on CoreML EP | 301 / 303 | **303 / 303** | ### Benchmark — M3 Max, MLProgram, batch 1, 1299-iter steady state | Metric | Without PR | With PR | Δ | |---|---|---|---| | Mean | 6.838 ms | 6.565 ms | −4.0% | | **StdDev** | **0.239 ms** | **0.170 ms** | **−29%** | | P50 | 6.810 ms | 6.545 ms | −3.9% | | P95 | 7.070 ms | 6.775 ms | −4.2% | | P99 | 7.330 ms | 6.928 ms | −5.5% | | P99.9 | 8.917 ms | 8.164 ms | −8.4% | | **Max** | **12.616 ms** | **10.360 ms** | **−17.9%** | Removing the two CPU↔CoreML round trips improves the tail far more than the median — useful for real-time perception pipelines where worst-case latency determines the frame budget. ### Tests Eight new tests in `onnxruntime/test/providers/coreml/coreml_basic_test.cc`, each exercising both the NeuralNetwork and MLProgram emitters and asserting full CoreML EP node assignment (no CPU fallback). **Pre-opset-13 attribute form (the new code path):** - `Split7UnevenAttribute` — opset 7 uneven `split=[4, 3, 2]`, covering the opset 7–10 range. - `Split11UnevenAttribute` — DWPose's pattern, `split=[4, 3, 2]`. - `Split11EvenAttribute` — uniform sizes via attribute. - `Split11NoAttributeEven` — falls through to the `num_splits = num_outputs` branch. **Post-opset-13 input form (parity with the existing, untouched path):** - `Split13UnevenInput` — `split` input `[4, 3, 2]`. - `Split13EvenInput` — uniform sizes via input. - `Split13NoSplitInputEven` — no `split` input, even-split fallback. **Negative coverage:** - `Split11ZeroSplitValueNotSupported` — verifies the attribute-form rejection of a non-positive entry; expects no CoreML assignment. All eight pass locally on macOS 26.3 / M3 Max. ### Motivation for upstreaming Most pre-2023 vision exports (DWPose, MMPose models, original YOLOv5/v7/v8, etc.) target ONNX opset 11/12 and use the `Split` attribute form. They currently lose any `Split` to CPU on the CoreML EP. This is a self-contained gap with a clean fix. --------- Co-authored-by: Claude Opus 4.7 (1M context) --- .../coreml/builders/impl/split_op_builder.cc | 50 +- .../providers/coreml/coreml_basic_test.cc | 571 ++++++++++++++++++ 2 files changed, 619 insertions(+), 2 deletions(-) diff --git a/onnxruntime/core/providers/coreml/builders/impl/split_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/split_op_builder.cc index 4ee9b54cebd16..875754138e408 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/split_op_builder.cc +++ b/onnxruntime/core/providers/coreml/builders/impl/split_op_builder.cc @@ -23,8 +23,7 @@ class SplitOpBuilder : public BaseOpBuilder { bool IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& input_params, const logging::Logger& logger) const override; - // Split opset 13- uses "split" as attribute. Currently it's not supported. - int GetMinSupportedOpSet(const Node& /* node */) const override { return 13; } + int GetMinSupportedOpSet(const Node& /* node */) const override { return 1; } bool SupportsMLProgram() const override { return true; } }; @@ -56,6 +55,9 @@ Status SplitOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, return std::make_tuple(remainder, chunk_size); }; + // Pre-opset-13 'split' is an INTS attribute. If present, it overrides even splitting. + const auto split_attr = helper.GetInt64s("split"); + if (model_builder.CreateMLProgram()) { using namespace CoreML::Specification::MILSpec; std::unique_ptr split_op = model_builder.CreateOperation(node, "split"); @@ -68,6 +70,10 @@ Status SplitOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, auto split_span = unpacked_tensor.DataAsSpan(); AddOperationInput(*split_op, "split_sizes", model_builder.AddConstant(split_op->type(), "split_sizes", split_span)); + } else if (split_attr) { + // pre-opset-13 'split' attribute + AddOperationInput(*split_op, "split_sizes", + model_builder.AddConstant(split_op->type(), "split_sizes", *split_attr)); } else if (node.SinceVersion() < 18) { int64_t num_outputs = narrow(node.OutputDefs().size()); AddOperationInput(*split_op, "num_splits", @@ -109,6 +115,11 @@ Status SplitOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, for (const auto& split_size : split_span) { coreml_splitnd->add_splitsizes(split_size); } + } else if (split_attr) { + // pre-opset-13 'split' attribute + for (const auto& split_size : *split_attr) { + coreml_splitnd->add_splitsizes(split_size); + } } else if (node.SinceVersion() < 18) { int64_t num_outputs = narrow(node.OutputDefs().size()); coreml_splitnd->set_numsplits(num_outputs); @@ -166,6 +177,10 @@ bool SplitOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInputPar return false; } + if (split_dims_at_axis == -1) { + LOGS(logger, VERBOSE) << "Dim at the splitting axis is not allowed to be dynamic."; + return false; + } Initializer unpacked_tensor(input_params.graph_viewer.GetGraph(), *splits_tensor, input_params.graph_viewer.ModelPath()); auto splits_span = unpacked_tensor.DataAsSpan(); @@ -182,10 +197,27 @@ bool SplitOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInputPar LOGS(logger, VERBOSE) << "Invalid value in 'splits' input."; return false; } + } else if (const auto split_attr = helper.GetInt64s("split"); split_attr) { + // pre-opset-13: 'split' is an INTS attribute. Validate the same way we + // validate the input form above. + if (split_attr->size() < 2) { + LOGS(logger, VERBOSE) << "CoreML Split must produce at least 2 outputs."; + return false; + } if (split_dims_at_axis == -1) { LOGS(logger, VERBOSE) << "Dim at the splitting axis is not allowed to be dynamic."; return false; } + int64_t sum_of_splits = std::accumulate(split_attr->begin(), split_attr->end(), int64_t{0}); + if (sum_of_splits != split_dims_at_axis) { + LOGS(logger, VERBOSE) << "Mismatch between sum of 'split' attribute and split-axis size. Expected: " + << split_dims_at_axis << " Actual: " << sum_of_splits; + return false; + } + if (std::any_of(split_attr->begin(), split_attr->end(), [](int64_t v) { return v <= 0; })) { + LOGS(logger, VERBOSE) << "Invalid value in 'split' attribute (sizes must be positive)."; + return false; + } } else { if (node.SinceVersion() >= 18) { const auto num_outputs = helper.GetInt64("num_outputs"); @@ -205,6 +237,20 @@ bool SplitOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInputPar << num_outputs.value(); return false; } + } else if (node.OutputDefs().size() < 2) { + LOGS(logger, VERBOSE) << "CoreML Split must produce at least 2 outputs."; + return false; + } else if (split_dims_at_axis == -1) { + // No 'split' attr or input: ONNX spec says split evenly, but we cannot + // verify divisibility without a known axis size. + LOGS(logger, VERBOSE) << "Dim at the splitting axis is not allowed to be dynamic when 'split' is omitted."; + return false; + } else if (split_dims_at_axis % static_cast(node.OutputDefs().size()) != 0) { + // No 'split' attr or input: ONNX spec says split evenly. CoreML's + // num_splits requires the axis size be evenly divisible. + LOGS(logger, VERBOSE) << "Even split required when 'split' is omitted; axis size " + << split_dims_at_axis << " not divisible by num outputs " << node.OutputDefs().size(); + return false; } } return true; diff --git a/onnxruntime/test/providers/coreml/coreml_basic_test.cc b/onnxruntime/test/providers/coreml/coreml_basic_test.cc index f56c81d2e89de..61c3297a43118 100644 --- a/onnxruntime/test/providers/coreml/coreml_basic_test.cc +++ b/onnxruntime/test/providers/coreml/coreml_basic_test.cc @@ -1164,6 +1164,577 @@ TEST(CoreMLExecutionProviderTest, QuickGeluTestFp16) { #endif } +TEST(CoreMLExecutionProviderTest, Split11UnevenAttribute) { + // ai.onnx:Split-11 with `split` attribute carrying non-uniform sizes. + // This is the form used by DWPose (`dw-ll_ucoco_384.onnx`); without + // attribute support the node falls back to CPU and fragments the CoreML + // partition. + std::unordered_map domain_to_version{{kOnnxDomain, 11}}; + onnxruntime::Model model("split11_uneven_attribute", false, ModelMetaData(), PathString(), + IOnnxRuntimeOpSchemaRegistryList(), domain_to_version, {}, + DefaultLoggingManager().DefaultLogger()); + auto& graph = model.MainGraph(); + + // Input X: {1, 9} float + ONNX_NAMESPACE::TypeProto input_type; + input_type.mutable_tensor_type()->set_elem_type(ONNX_NAMESPACE::TensorProto_DataType_FLOAT); + auto* input_shape = input_type.mutable_tensor_type()->mutable_shape(); + input_shape->add_dim()->set_dim_value(1); + input_shape->add_dim()->set_dim_value(9); + + // Outputs along axis=1 with split=[4, 3, 2]: {1,4}, {1,3}, {1,2} + auto make_output_type = [](int64_t split_size) { + ONNX_NAMESPACE::TypeProto t; + t.mutable_tensor_type()->set_elem_type(ONNX_NAMESPACE::TensorProto_DataType_FLOAT); + auto* s = t.mutable_tensor_type()->mutable_shape(); + s->add_dim()->set_dim_value(1); + s->add_dim()->set_dim_value(split_size); + return t; + }; + ONNX_NAMESPACE::TypeProto out0_type = make_output_type(4); + ONNX_NAMESPACE::TypeProto out1_type = make_output_type(3); + ONNX_NAMESPACE::TypeProto out2_type = make_output_type(2); + + auto& input_arg = graph.GetOrCreateNodeArg("X", &input_type); + auto& out0_arg = graph.GetOrCreateNodeArg("Y0", &out0_type); + auto& out1_arg = graph.GetOrCreateNodeArg("Y1", &out1_type); + auto& out2_arg = graph.GetOrCreateNodeArg("Y2", &out2_type); + + auto& node = graph.AddNode("split11_uneven", "Split", "Split-11 with uneven 'split' attribute", + {&input_arg}, {&out0_arg, &out1_arg, &out2_arg}); + node.AddAttribute("axis", static_cast(1)); + node.AddAttribute("split", std::vector{4, 3, 2}); + + ASSERT_STATUS_OK(graph.Resolve()); + +#if defined(__APPLE__) + std::vector dims = {1, 9}; + std::vector input_data = {0.5f, -1.0f, 2.25f, -3.5f, 4.0f, -0.125f, 7.5f, -8.0f, 0.0f}; + OrtValue ml_value_x; + AllocatorPtr allocator = CPUAllocator::DefaultInstance(); + CreateMLValue(allocator, dims, input_data, &ml_value_x); + + NameMLValMap feeds; + feeds.insert(std::make_pair("X", ml_value_x)); + + std::string model_data; + model.ToProto().SerializeToString(&model_data); + gsl::span model_span{reinterpret_cast(model_data.data()), model_data.size()}; + + RunAndVerifyOutputsWithEP(model_span, "Split11UnevenAttribute_NN", + MakeCoreMLExecutionProvider(), + feeds, + EPVerificationParams{ExpectedEPNodeAssignment::All}); + RunAndVerifyOutputsWithEP(model_span, "Split11UnevenAttribute_MLProgram", + MakeCoreMLExecutionProvider("MLProgram"), + feeds, + EPVerificationParams{ExpectedEPNodeAssignment::All}); +#else + std::string model_data; + model.ToProto().SerializeToString(&model_data); + gsl::span model_span{reinterpret_cast(model_data.data()), model_data.size()}; + TestModelLoad(model_span, MakeCoreMLExecutionProvider(), ExpectedEPNodeAssignment::All); + TestModelLoad(model_span, MakeCoreMLExecutionProvider("MLProgram"), ExpectedEPNodeAssignment::All); +#endif +} + +TEST(CoreMLExecutionProviderTest, Split11EvenAttribute) { + // Even sizes via attribute — exercises the split_sizes path with uniform + // values rather than the fall-through num_splits path. + std::unordered_map domain_to_version{{kOnnxDomain, 11}}; + onnxruntime::Model model("split11_even_attribute", false, ModelMetaData(), PathString(), + IOnnxRuntimeOpSchemaRegistryList(), domain_to_version, {}, + DefaultLoggingManager().DefaultLogger()); + auto& graph = model.MainGraph(); + + ONNX_NAMESPACE::TypeProto input_type; + input_type.mutable_tensor_type()->set_elem_type(ONNX_NAMESPACE::TensorProto_DataType_FLOAT); + auto* input_shape = input_type.mutable_tensor_type()->mutable_shape(); + input_shape->add_dim()->set_dim_value(1); + input_shape->add_dim()->set_dim_value(6); + + ONNX_NAMESPACE::TypeProto output_type; + output_type.mutable_tensor_type()->set_elem_type(ONNX_NAMESPACE::TensorProto_DataType_FLOAT); + auto* output_shape = output_type.mutable_tensor_type()->mutable_shape(); + output_shape->add_dim()->set_dim_value(1); + output_shape->add_dim()->set_dim_value(3); + + auto& input_arg = graph.GetOrCreateNodeArg("X", &input_type); + auto& out0_arg = graph.GetOrCreateNodeArg("Y0", &output_type); + auto& out1_arg = graph.GetOrCreateNodeArg("Y1", &output_type); + + auto& node = graph.AddNode("split11_even", "Split", "Split-11 with even 'split' attribute", + {&input_arg}, {&out0_arg, &out1_arg}); + node.AddAttribute("axis", static_cast(1)); + node.AddAttribute("split", std::vector{3, 3}); + + ASSERT_STATUS_OK(graph.Resolve()); + +#if defined(__APPLE__) + std::vector dims = {1, 6}; + std::vector input_data = {1.0f, -2.0f, 3.0f, -4.0f, 5.0f, -6.0f}; + OrtValue ml_value_x; + AllocatorPtr allocator = CPUAllocator::DefaultInstance(); + CreateMLValue(allocator, dims, input_data, &ml_value_x); + + NameMLValMap feeds; + feeds.insert(std::make_pair("X", ml_value_x)); + + std::string model_data; + model.ToProto().SerializeToString(&model_data); + gsl::span model_span{reinterpret_cast(model_data.data()), model_data.size()}; + + RunAndVerifyOutputsWithEP(model_span, "Split11EvenAttribute_NN", + MakeCoreMLExecutionProvider(), + feeds, + EPVerificationParams{ExpectedEPNodeAssignment::All}); + RunAndVerifyOutputsWithEP(model_span, "Split11EvenAttribute_MLProgram", + MakeCoreMLExecutionProvider("MLProgram"), + feeds, + EPVerificationParams{ExpectedEPNodeAssignment::All}); +#else + std::string model_data; + model.ToProto().SerializeToString(&model_data); + gsl::span model_span{reinterpret_cast(model_data.data()), model_data.size()}; + TestModelLoad(model_span, MakeCoreMLExecutionProvider(), ExpectedEPNodeAssignment::All); + TestModelLoad(model_span, MakeCoreMLExecutionProvider("MLProgram"), ExpectedEPNodeAssignment::All); +#endif +} + +TEST(CoreMLExecutionProviderTest, Split11NoAttributeEven) { + // No `split` attribute, axis size divides evenly: must fall through to the + // num_splits = num_outputs branch. + std::unordered_map domain_to_version{{kOnnxDomain, 11}}; + onnxruntime::Model model("split11_no_attribute_even", false, ModelMetaData(), PathString(), + IOnnxRuntimeOpSchemaRegistryList(), domain_to_version, {}, + DefaultLoggingManager().DefaultLogger()); + auto& graph = model.MainGraph(); + + ONNX_NAMESPACE::TypeProto input_type; + input_type.mutable_tensor_type()->set_elem_type(ONNX_NAMESPACE::TensorProto_DataType_FLOAT); + auto* input_shape = input_type.mutable_tensor_type()->mutable_shape(); + input_shape->add_dim()->set_dim_value(1); + input_shape->add_dim()->set_dim_value(8); + + ONNX_NAMESPACE::TypeProto output_type; + output_type.mutable_tensor_type()->set_elem_type(ONNX_NAMESPACE::TensorProto_DataType_FLOAT); + auto* output_shape = output_type.mutable_tensor_type()->mutable_shape(); + output_shape->add_dim()->set_dim_value(1); + output_shape->add_dim()->set_dim_value(4); + + auto& input_arg = graph.GetOrCreateNodeArg("X", &input_type); + auto& out0_arg = graph.GetOrCreateNodeArg("Y0", &output_type); + auto& out1_arg = graph.GetOrCreateNodeArg("Y1", &output_type); + + auto& node = graph.AddNode("split11_no_attr", "Split", "Split-11 with no 'split' attribute", + {&input_arg}, {&out0_arg, &out1_arg}); + node.AddAttribute("axis", static_cast(1)); + + ASSERT_STATUS_OK(graph.Resolve()); + +#if defined(__APPLE__) + std::vector dims = {1, 8}; + std::vector input_data = {0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f}; + OrtValue ml_value_x; + AllocatorPtr allocator = CPUAllocator::DefaultInstance(); + CreateMLValue(allocator, dims, input_data, &ml_value_x); + + NameMLValMap feeds; + feeds.insert(std::make_pair("X", ml_value_x)); + + std::string model_data; + model.ToProto().SerializeToString(&model_data); + gsl::span model_span{reinterpret_cast(model_data.data()), model_data.size()}; + + RunAndVerifyOutputsWithEP(model_span, "Split11NoAttributeEven_NN", + MakeCoreMLExecutionProvider(), + feeds, + EPVerificationParams{ExpectedEPNodeAssignment::All}); + RunAndVerifyOutputsWithEP(model_span, "Split11NoAttributeEven_MLProgram", + MakeCoreMLExecutionProvider("MLProgram"), + feeds, + EPVerificationParams{ExpectedEPNodeAssignment::All}); +#else + std::string model_data; + model.ToProto().SerializeToString(&model_data); + gsl::span model_span{reinterpret_cast(model_data.data()), model_data.size()}; + TestModelLoad(model_span, MakeCoreMLExecutionProvider(), ExpectedEPNodeAssignment::All); + TestModelLoad(model_span, MakeCoreMLExecutionProvider("MLProgram"), ExpectedEPNodeAssignment::All); +#endif +} + +TEST(CoreMLExecutionProviderTest, Split13UnevenInput) { + // Parity with Split11UnevenAttribute: same shapes and split sizes, but using + // the opset-13 input form ('split' as a constant initializer) instead of the + // pre-13 attribute form. Locks in that the existing input path still works. + std::unordered_map domain_to_version{{kOnnxDomain, 13}}; + onnxruntime::Model model("split13_uneven_input", false, ModelMetaData(), PathString(), + IOnnxRuntimeOpSchemaRegistryList(), domain_to_version, {}, + DefaultLoggingManager().DefaultLogger()); + auto& graph = model.MainGraph(); + + ONNX_NAMESPACE::TypeProto input_type; + input_type.mutable_tensor_type()->set_elem_type(ONNX_NAMESPACE::TensorProto_DataType_FLOAT); + auto* input_shape = input_type.mutable_tensor_type()->mutable_shape(); + input_shape->add_dim()->set_dim_value(1); + input_shape->add_dim()->set_dim_value(9); + + auto make_output_type = [](int64_t split_size) { + ONNX_NAMESPACE::TypeProto t; + t.mutable_tensor_type()->set_elem_type(ONNX_NAMESPACE::TensorProto_DataType_FLOAT); + auto* s = t.mutable_tensor_type()->mutable_shape(); + s->add_dim()->set_dim_value(1); + s->add_dim()->set_dim_value(split_size); + return t; + }; + ONNX_NAMESPACE::TypeProto out0_type = make_output_type(4); + ONNX_NAMESPACE::TypeProto out1_type = make_output_type(3); + ONNX_NAMESPACE::TypeProto out2_type = make_output_type(2); + + auto& input_arg = graph.GetOrCreateNodeArg("X", &input_type); + auto& out0_arg = graph.GetOrCreateNodeArg("Y0", &out0_type); + auto& out1_arg = graph.GetOrCreateNodeArg("Y1", &out1_type); + auto& out2_arg = graph.GetOrCreateNodeArg("Y2", &out2_type); + + ONNX_NAMESPACE::TensorProto split_init; + split_init.set_name("split_sizes"); + split_init.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_INT64); + split_init.add_dims(3); + for (auto v : std::vector{4, 3, 2}) { + split_init.add_int64_data(v); + } + graph.AddInitializedTensor(split_init); + auto& split_arg = graph.GetOrCreateNodeArg("split_sizes", nullptr); + + auto& node = graph.AddNode("split13_uneven", "Split", "Split-13 with uneven 'split' input", + {&input_arg, &split_arg}, {&out0_arg, &out1_arg, &out2_arg}); + node.AddAttribute("axis", static_cast(1)); + + ASSERT_STATUS_OK(graph.Resolve()); + +#if defined(__APPLE__) + std::vector dims = {1, 9}; + std::vector input_data = {0.5f, -1.0f, 2.25f, -3.5f, 4.0f, -0.125f, 7.5f, -8.0f, 0.0f}; + OrtValue ml_value_x; + AllocatorPtr allocator = CPUAllocator::DefaultInstance(); + CreateMLValue(allocator, dims, input_data, &ml_value_x); + + NameMLValMap feeds; + feeds.insert(std::make_pair("X", ml_value_x)); + + std::string model_data; + model.ToProto().SerializeToString(&model_data); + gsl::span model_span{reinterpret_cast(model_data.data()), model_data.size()}; + + RunAndVerifyOutputsWithEP(model_span, "Split13UnevenInput_NN", + MakeCoreMLExecutionProvider(), + feeds, + EPVerificationParams{ExpectedEPNodeAssignment::All}); + RunAndVerifyOutputsWithEP(model_span, "Split13UnevenInput_MLProgram", + MakeCoreMLExecutionProvider("MLProgram"), + feeds, + EPVerificationParams{ExpectedEPNodeAssignment::All}); +#else + std::string model_data; + model.ToProto().SerializeToString(&model_data); + gsl::span model_span{reinterpret_cast(model_data.data()), model_data.size()}; + TestModelLoad(model_span, MakeCoreMLExecutionProvider(), ExpectedEPNodeAssignment::All); + TestModelLoad(model_span, MakeCoreMLExecutionProvider("MLProgram"), ExpectedEPNodeAssignment::All); +#endif +} + +TEST(CoreMLExecutionProviderTest, Split13EvenInput) { + // Parity with Split11EvenAttribute via the opset-13 input form. + std::unordered_map domain_to_version{{kOnnxDomain, 13}}; + onnxruntime::Model model("split13_even_input", false, ModelMetaData(), PathString(), + IOnnxRuntimeOpSchemaRegistryList(), domain_to_version, {}, + DefaultLoggingManager().DefaultLogger()); + auto& graph = model.MainGraph(); + + ONNX_NAMESPACE::TypeProto input_type; + input_type.mutable_tensor_type()->set_elem_type(ONNX_NAMESPACE::TensorProto_DataType_FLOAT); + auto* input_shape = input_type.mutable_tensor_type()->mutable_shape(); + input_shape->add_dim()->set_dim_value(1); + input_shape->add_dim()->set_dim_value(6); + + ONNX_NAMESPACE::TypeProto output_type; + output_type.mutable_tensor_type()->set_elem_type(ONNX_NAMESPACE::TensorProto_DataType_FLOAT); + auto* output_shape = output_type.mutable_tensor_type()->mutable_shape(); + output_shape->add_dim()->set_dim_value(1); + output_shape->add_dim()->set_dim_value(3); + + auto& input_arg = graph.GetOrCreateNodeArg("X", &input_type); + auto& out0_arg = graph.GetOrCreateNodeArg("Y0", &output_type); + auto& out1_arg = graph.GetOrCreateNodeArg("Y1", &output_type); + + ONNX_NAMESPACE::TensorProto split_init; + split_init.set_name("split_sizes"); + split_init.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_INT64); + split_init.add_dims(2); + for (auto v : std::vector{3, 3}) { + split_init.add_int64_data(v); + } + graph.AddInitializedTensor(split_init); + auto& split_arg = graph.GetOrCreateNodeArg("split_sizes", nullptr); + + auto& node = graph.AddNode("split13_even", "Split", "Split-13 with even 'split' input", + {&input_arg, &split_arg}, {&out0_arg, &out1_arg}); + node.AddAttribute("axis", static_cast(1)); + + ASSERT_STATUS_OK(graph.Resolve()); + +#if defined(__APPLE__) + std::vector dims = {1, 6}; + std::vector input_data = {1.0f, -2.0f, 3.0f, -4.0f, 5.0f, -6.0f}; + OrtValue ml_value_x; + AllocatorPtr allocator = CPUAllocator::DefaultInstance(); + CreateMLValue(allocator, dims, input_data, &ml_value_x); + + NameMLValMap feeds; + feeds.insert(std::make_pair("X", ml_value_x)); + + std::string model_data; + model.ToProto().SerializeToString(&model_data); + gsl::span model_span{reinterpret_cast(model_data.data()), model_data.size()}; + + RunAndVerifyOutputsWithEP(model_span, "Split13EvenInput_NN", + MakeCoreMLExecutionProvider(), + feeds, + EPVerificationParams{ExpectedEPNodeAssignment::All}); + RunAndVerifyOutputsWithEP(model_span, "Split13EvenInput_MLProgram", + MakeCoreMLExecutionProvider("MLProgram"), + feeds, + EPVerificationParams{ExpectedEPNodeAssignment::All}); +#else + std::string model_data; + model.ToProto().SerializeToString(&model_data); + gsl::span model_span{reinterpret_cast(model_data.data()), model_data.size()}; + TestModelLoad(model_span, MakeCoreMLExecutionProvider(), ExpectedEPNodeAssignment::All); + TestModelLoad(model_span, MakeCoreMLExecutionProvider("MLProgram"), ExpectedEPNodeAssignment::All); +#endif +} + +TEST(CoreMLExecutionProviderTest, Split13NoSplitInputEven) { + // Parity with Split11NoAttributeEven: opset 13 with no 'split' input must + // fall through to the SinceVersion() < 18 even-split branch (num_splits = + // num_outputs) for both emitters. + std::unordered_map domain_to_version{{kOnnxDomain, 13}}; + onnxruntime::Model model("split13_no_split_input_even", false, ModelMetaData(), PathString(), + IOnnxRuntimeOpSchemaRegistryList(), domain_to_version, {}, + DefaultLoggingManager().DefaultLogger()); + auto& graph = model.MainGraph(); + + ONNX_NAMESPACE::TypeProto input_type; + input_type.mutable_tensor_type()->set_elem_type(ONNX_NAMESPACE::TensorProto_DataType_FLOAT); + auto* input_shape = input_type.mutable_tensor_type()->mutable_shape(); + input_shape->add_dim()->set_dim_value(1); + input_shape->add_dim()->set_dim_value(8); + + ONNX_NAMESPACE::TypeProto output_type; + output_type.mutable_tensor_type()->set_elem_type(ONNX_NAMESPACE::TensorProto_DataType_FLOAT); + auto* output_shape = output_type.mutable_tensor_type()->mutable_shape(); + output_shape->add_dim()->set_dim_value(1); + output_shape->add_dim()->set_dim_value(4); + + auto& input_arg = graph.GetOrCreateNodeArg("X", &input_type); + auto& out0_arg = graph.GetOrCreateNodeArg("Y0", &output_type); + auto& out1_arg = graph.GetOrCreateNodeArg("Y1", &output_type); + + auto& node = graph.AddNode("split13_no_split_input", "Split", "Split-13 with no 'split' input", + {&input_arg}, {&out0_arg, &out1_arg}); + node.AddAttribute("axis", static_cast(1)); + + ASSERT_STATUS_OK(graph.Resolve()); + +#if defined(__APPLE__) + std::vector dims = {1, 8}; + std::vector input_data = {0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f}; + OrtValue ml_value_x; + AllocatorPtr allocator = CPUAllocator::DefaultInstance(); + CreateMLValue(allocator, dims, input_data, &ml_value_x); + + NameMLValMap feeds; + feeds.insert(std::make_pair("X", ml_value_x)); + + std::string model_data; + model.ToProto().SerializeToString(&model_data); + gsl::span model_span{reinterpret_cast(model_data.data()), model_data.size()}; + + RunAndVerifyOutputsWithEP(model_span, "Split13NoSplitInputEven_NN", + MakeCoreMLExecutionProvider(), + feeds, + EPVerificationParams{ExpectedEPNodeAssignment::All}); + RunAndVerifyOutputsWithEP(model_span, "Split13NoSplitInputEven_MLProgram", + MakeCoreMLExecutionProvider("MLProgram"), + feeds, + EPVerificationParams{ExpectedEPNodeAssignment::All}); +#else + std::string model_data; + model.ToProto().SerializeToString(&model_data); + gsl::span model_span{reinterpret_cast(model_data.data()), model_data.size()}; + TestModelLoad(model_span, MakeCoreMLExecutionProvider(), ExpectedEPNodeAssignment::All); + TestModelLoad(model_span, MakeCoreMLExecutionProvider("MLProgram"), ExpectedEPNodeAssignment::All); +#endif +} + +TEST(CoreMLExecutionProviderTest, Split7UnevenAttribute) { + // Opset 7 (≤ 10) parity check. The builder advertises support from opset 1 + // and reads the 'split' attribute; the Split11* tests cover opset 11, this + // test covers the opset 7-10 range explicitly. + std::unordered_map domain_to_version{{kOnnxDomain, 7}}; + onnxruntime::Model model("split7_uneven_attribute", false, ModelMetaData(), PathString(), + IOnnxRuntimeOpSchemaRegistryList(), domain_to_version, {}, + DefaultLoggingManager().DefaultLogger()); + auto& graph = model.MainGraph(); + + ONNX_NAMESPACE::TypeProto input_type; + input_type.mutable_tensor_type()->set_elem_type(ONNX_NAMESPACE::TensorProto_DataType_FLOAT); + auto* input_shape = input_type.mutable_tensor_type()->mutable_shape(); + input_shape->add_dim()->set_dim_value(1); + input_shape->add_dim()->set_dim_value(9); + + auto make_output_type = [](int64_t split_size) { + ONNX_NAMESPACE::TypeProto t; + t.mutable_tensor_type()->set_elem_type(ONNX_NAMESPACE::TensorProto_DataType_FLOAT); + auto* s = t.mutable_tensor_type()->mutable_shape(); + s->add_dim()->set_dim_value(1); + s->add_dim()->set_dim_value(split_size); + return t; + }; + ONNX_NAMESPACE::TypeProto out0_type = make_output_type(4); + ONNX_NAMESPACE::TypeProto out1_type = make_output_type(3); + ONNX_NAMESPACE::TypeProto out2_type = make_output_type(2); + + auto& input_arg = graph.GetOrCreateNodeArg("X", &input_type); + auto& out0_arg = graph.GetOrCreateNodeArg("Y0", &out0_type); + auto& out1_arg = graph.GetOrCreateNodeArg("Y1", &out1_type); + auto& out2_arg = graph.GetOrCreateNodeArg("Y2", &out2_type); + + auto& node = graph.AddNode("split7_uneven", "Split", "Split-7 with uneven 'split' attribute", + {&input_arg}, {&out0_arg, &out1_arg, &out2_arg}); + node.AddAttribute("axis", static_cast(1)); + node.AddAttribute("split", std::vector{4, 3, 2}); + + ASSERT_STATUS_OK(graph.Resolve()); + +#if defined(__APPLE__) + std::vector dims = {1, 9}; + std::vector input_data = {0.5f, -1.0f, 2.25f, -3.5f, 4.0f, -0.125f, 7.5f, -8.0f, 0.0f}; + OrtValue ml_value_x; + AllocatorPtr allocator = CPUAllocator::DefaultInstance(); + CreateMLValue(allocator, dims, input_data, &ml_value_x); + + NameMLValMap feeds; + feeds.insert(std::make_pair("X", ml_value_x)); + + std::string model_data; + model.ToProto().SerializeToString(&model_data); + gsl::span model_span{reinterpret_cast(model_data.data()), model_data.size()}; + + RunAndVerifyOutputsWithEP(model_span, "Split7UnevenAttribute_NN", + MakeCoreMLExecutionProvider(), + feeds, + EPVerificationParams{ExpectedEPNodeAssignment::All}); + RunAndVerifyOutputsWithEP(model_span, "Split7UnevenAttribute_MLProgram", + MakeCoreMLExecutionProvider("MLProgram"), + feeds, + EPVerificationParams{ExpectedEPNodeAssignment::All}); +#else + std::string model_data; + model.ToProto().SerializeToString(&model_data); + gsl::span model_span{reinterpret_cast(model_data.data()), model_data.size()}; + TestModelLoad(model_span, MakeCoreMLExecutionProvider(), ExpectedEPNodeAssignment::All); + TestModelLoad(model_span, MakeCoreMLExecutionProvider("MLProgram"), ExpectedEPNodeAssignment::All); +#endif +} + +TEST(CoreMLExecutionProviderTest, Split11ZeroSplitValueNotSupported) { + // Negative: a zero entry in the 'split' attribute must be rejected so the + // node falls back to CPU. Sum still equals the axis size, so this exercises + // the non-positive value check specifically. + std::unordered_map domain_to_version{{kOnnxDomain, 11}}; + onnxruntime::Model model("split11_zero_split_value", false, ModelMetaData(), PathString(), + IOnnxRuntimeOpSchemaRegistryList(), domain_to_version, {}, + DefaultLoggingManager().DefaultLogger()); + auto& graph = model.MainGraph(); + + ONNX_NAMESPACE::TypeProto input_type; + input_type.mutable_tensor_type()->set_elem_type(ONNX_NAMESPACE::TensorProto_DataType_FLOAT); + auto* input_shape = input_type.mutable_tensor_type()->mutable_shape(); + input_shape->add_dim()->set_dim_value(1); + input_shape->add_dim()->set_dim_value(9); + + auto make_output_type = [](int64_t split_size) { + ONNX_NAMESPACE::TypeProto t; + t.mutable_tensor_type()->set_elem_type(ONNX_NAMESPACE::TensorProto_DataType_FLOAT); + auto* s = t.mutable_tensor_type()->mutable_shape(); + s->add_dim()->set_dim_value(1); + s->add_dim()->set_dim_value(split_size); + return t; + }; + ONNX_NAMESPACE::TypeProto out0_type = make_output_type(3); + ONNX_NAMESPACE::TypeProto out1_type = make_output_type(0); + ONNX_NAMESPACE::TypeProto out2_type = make_output_type(6); + + auto& input_arg = graph.GetOrCreateNodeArg("X", &input_type); + auto& out0_arg = graph.GetOrCreateNodeArg("Y0", &out0_type); + auto& out1_arg = graph.GetOrCreateNodeArg("Y1", &out1_type); + auto& out2_arg = graph.GetOrCreateNodeArg("Y2", &out2_type); + + auto& node = graph.AddNode("split11_zero", "Split", "Split-11 with a zero 'split' entry", + {&input_arg}, {&out0_arg, &out1_arg, &out2_arg}); + node.AddAttribute("axis", static_cast(1)); + node.AddAttribute("split", std::vector{3, 0, 6}); + + ASSERT_STATUS_OK(graph.Resolve()); + + std::string model_data; + model.ToProto().SerializeToString(&model_data); + gsl::span model_span{reinterpret_cast(model_data.data()), model_data.size()}; + TestModelLoad(model_span, MakeCoreMLExecutionProvider(), ExpectedEPNodeAssignment::None); + TestModelLoad(model_span, MakeCoreMLExecutionProvider("MLProgram"), ExpectedEPNodeAssignment::None); +} + +TEST(CoreMLExecutionProviderTest, Split11SingleOutputNotSupported) { + // Negative: a Split node with only 1 output. CoreML SplitND requires ≥2, + // so the attribute-form path's split_attr->size() < 2 check rejects it. + // ONNX schema allows variadic ≥1 outputs and CPU's Split kernel accepts + // a single output, so this case can be observed via partition assertion. + std::unordered_map domain_to_version{{kOnnxDomain, 11}}; + onnxruntime::Model model("split11_single_output", false, ModelMetaData(), PathString(), + IOnnxRuntimeOpSchemaRegistryList(), domain_to_version, {}, + DefaultLoggingManager().DefaultLogger()); + auto& graph = model.MainGraph(); + + ONNX_NAMESPACE::TypeProto input_type; + input_type.mutable_tensor_type()->set_elem_type(ONNX_NAMESPACE::TensorProto_DataType_FLOAT); + auto* input_shape = input_type.mutable_tensor_type()->mutable_shape(); + input_shape->add_dim()->set_dim_value(1); + input_shape->add_dim()->set_dim_value(5); + + ONNX_NAMESPACE::TypeProto output_type; + output_type.mutable_tensor_type()->set_elem_type(ONNX_NAMESPACE::TensorProto_DataType_FLOAT); + output_type.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(1); + output_type.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(5); + + auto& input_arg = graph.GetOrCreateNodeArg("X", &input_type); + auto& out0_arg = graph.GetOrCreateNodeArg("Y0", &output_type); + + auto& node = graph.AddNode("split11_single_output", "Split", + "Split-11 with a single output", + {&input_arg}, {&out0_arg}); + node.AddAttribute("axis", static_cast(1)); + node.AddAttribute("split", std::vector{5}); + + ASSERT_STATUS_OK(graph.Resolve()); + + std::string model_data; + model.ToProto().SerializeToString(&model_data); + gsl::span model_span{reinterpret_cast(model_data.data()), model_data.size()}; + TestModelLoad(model_span, MakeCoreMLExecutionProvider(), ExpectedEPNodeAssignment::None); + TestModelLoad(model_span, MakeCoreMLExecutionProvider("MLProgram"), ExpectedEPNodeAssignment::None); +} + #endif // !(ORT_MINIMAL_BUILD) } // namespace test } // namespace onnxruntime From 19738c570937b280ce037bec2f01d2ddbcf81687 Mon Sep 17 00:00:00 2001 From: Adrian Lizarraga Date: Wed, 6 May 2026 14:45:16 -0700 Subject: [PATCH 30/34] [Plugin EP] Add OrtEp::OnSessionInitializationEnd() (#28319) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ### Description Port `IExecutionProvider::OnSessionInitializationEnd()` to the public plugin EP C API (`OrtEp` struct). This gives plugin EP authors an opportunity to synchronize and clean up temporary resources after session initialization is complete, reducing memory usage and ensuring the first inference run is fast. ### Changes - **`onnxruntime_ep_c_api.h`** — Added optional `OnSessionInitializationEnd` function pointer to end of `OrtEp` struct (`\since Version 1.27`). - **`ep_plugin_provider_interfaces.h/.cc`** — Added `PluginExecutionProvider::OnSessionInitializationEnd()` override that delegates to the `OrtEp` callback with version (`< 27`) and null guards. - **`ep_plugin_provider_test.cc`** — Added 4 unit tests covering: null callback fallback, success, error propagation, and old-version fallback. ### Motivation and Context `OnSessionInitializationEnd()` already exists on the internal `IExecutionProvider` (used by DML EP, for example). Plugin EPs built as shared libraries against the public C API had no way to receive this notification. This change closes that gap. ### Test Coverage | Test | Scenario | |------|----------| | `OnSessionInitializationEnd_NullCallback` | NULL pointer falls back to base class (OK) | | `OnSessionInitializationEnd_Success` | Callback returns OK | | `OnSessionInitializationEnd_Error` | Error status propagates | | `OnSessionInitializationEnd_OldVersionFallback` | Version < 27 bypasses callback | --- .../core/session/onnxruntime_ep_c_api.h | 16 +++++++ onnxruntime/core/session/onnxruntime_c_api.cc | 2 +- .../ep_plugin_provider_interfaces.cc | 7 +++ .../plugin_ep/ep_plugin_provider_interfaces.h | 2 + .../test/framework/ep_plugin_provider_test.cc | 48 +++++++++++++++++++ 5 files changed, 74 insertions(+), 1 deletion(-) diff --git a/include/onnxruntime/core/session/onnxruntime_ep_c_api.h b/include/onnxruntime/core/session/onnxruntime_ep_c_api.h index 62757812cd6e3..76fb7ce93b600 100644 --- a/include/onnxruntime/core/session/onnxruntime_ep_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_ep_c_api.h @@ -2551,6 +2551,22 @@ struct OrtEp { * \since Version 1.26. */ ORT_API2_STATUS(GetAvailableResource, _In_ const OrtEp* this_ptr, _Out_ OrtResourceCount* available); + + /** \brief Called by ORT when session initialization is complete. + * + * This provides an opportunity for execution providers to optionally synchronize and + * clean up temporary resources to reduce memory usage and ensure the first inference run is fast. + * + * \param[in] this_ptr The OrtEp instance. + * + * \note Implementation of this function is optional. If set to NULL, ORT assumes no + * post-initialization work is needed and treats it as a no-op success. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.27. + */ + ORT_API2_STATUS(OnSessionInitializationEnd, _In_ OrtEp* this_ptr); }; /** \brief The function signature that ORT will call to create OrtEpFactory instances. diff --git a/onnxruntime/core/session/onnxruntime_c_api.cc b/onnxruntime/core/session/onnxruntime_c_api.cc index 2ac95d6e36466..f1454f4e74ab5 100644 --- a/onnxruntime/core/session/onnxruntime_c_api.cc +++ b/onnxruntime/core/session/onnxruntime_c_api.cc @@ -4953,7 +4953,7 @@ static_assert(offsetof(OrtApi, SetEpDynamicOptions) / sizeof(void*) == 284, "Siz static_assert(offsetof(OrtApi, GetEpApi) / sizeof(void*) == 317, "Size of version 22 API cannot change"); static_assert(offsetof(OrtApi, CreateExternalInitializerInfo) / sizeof(void*) == 389, "Size of version 23 API cannot change"); static_assert(offsetof(OrtApi, GetTensorElementTypeAndShapeDataReference) / sizeof(void*) == 414, "Size of version 24 API cannot change"); -static_assert(offsetof(OrtApi, KernelInfoGetAttributeArray_string) / sizeof(void*) == 417, "Size of version 25 API cannot change"); +static_assert(offsetof(OrtApi, SetPerSessionThreadPoolCallbacks) / sizeof(void*) == 418, "Size of version 25 API cannot change"); // no additions in version 26 // So that nobody forgets to finish an API version, this check will serve as a reminder: diff --git a/onnxruntime/core/session/plugin_ep/ep_plugin_provider_interfaces.cc b/onnxruntime/core/session/plugin_ep/ep_plugin_provider_interfaces.cc index b94497853fea2..d8094fe68ea53 100644 --- a/onnxruntime/core/session/plugin_ep/ep_plugin_provider_interfaces.cc +++ b/onnxruntime/core/session/plugin_ep/ep_plugin_provider_interfaces.cc @@ -783,6 +783,13 @@ Status PluginExecutionProvider::OnRunEnd(bool sync_stream, const RunOptions& run return ToStatusAndRelease(ort_ep_->OnRunEnd(ort_ep_.get(), &run_options, sync_stream)); } +Status PluginExecutionProvider::OnSessionInitializationEnd() { + if (ort_ep_->ort_version_supported < 27 || ort_ep_->OnSessionInitializationEnd == nullptr) { + return Base::OnSessionInitializationEnd(); + } + return ToStatusAndRelease(ort_ep_->OnSessionInitializationEnd(ort_ep_.get())); +} + Status PluginExecutionProvider::Sync() const { if (ort_ep_->ort_version_supported < 25 || ort_ep_->Sync == nullptr) { return Base::Sync(); diff --git a/onnxruntime/core/session/plugin_ep/ep_plugin_provider_interfaces.h b/onnxruntime/core/session/plugin_ep/ep_plugin_provider_interfaces.h index 8218571a8b1fe..ba84403dec8aa 100644 --- a/onnxruntime/core/session/plugin_ep/ep_plugin_provider_interfaces.h +++ b/onnxruntime/core/session/plugin_ep/ep_plugin_provider_interfaces.h @@ -117,6 +117,8 @@ class PluginExecutionProvider : public IExecutionProvider { Status OnRunEnd(bool sync_stream, const RunOptions& run_options) override; + Status OnSessionInitializationEnd() override; + Status Sync() const override; Status SetEpDynamicOptions(gsl::span keys, diff --git a/onnxruntime/test/framework/ep_plugin_provider_test.cc b/onnxruntime/test/framework/ep_plugin_provider_test.cc index 5fe77d8c62e09..80b638314bad9 100644 --- a/onnxruntime/test/framework/ep_plugin_provider_test.cc +++ b/onnxruntime/test/framework/ep_plugin_provider_test.cc @@ -1513,4 +1513,52 @@ TEST(PluginExecutionProviderTest, GetAvailableResource_NullCallbackLeavesThresho EXPECT_FALSE(accountant->GetThreshold().has_value()); } +// OnSessionInitializationEnd is nullptr -> falls back to base class (returns OK). +TEST(PluginExecutionProviderTest, OnSessionInitializationEnd_NullCallback) { + auto [ep, ort_ep] = test_plugin_ep::MakeTestOrtEp(); + + ort_ep->OnSessionInitializationEnd = nullptr; + ASSERT_STATUS_OK(ep->OnSessionInitializationEnd()); +} + +// OnSessionInitializationEnd returns OK status. +TEST(PluginExecutionProviderTest, OnSessionInitializationEnd_Success) { + auto [ep, ort_ep] = test_plugin_ep::MakeTestOrtEp(); + + ort_ep->OnSessionInitializationEnd = [](OrtEp* /*this_ptr*/) noexcept -> OrtStatus* { + return nullptr; + }; + + ASSERT_STATUS_OK(ep->OnSessionInitializationEnd()); +} + +// OnSessionInitializationEnd returns an error status -> error propagates. +TEST(PluginExecutionProviderTest, OnSessionInitializationEnd_Error) { + auto [ep, ort_ep] = test_plugin_ep::MakeTestOrtEp(); + + ort_ep->OnSessionInitializationEnd = [](OrtEp* this_ptr) noexcept -> OrtStatus* { + auto* test_ep = static_cast(this_ptr); + return test_ep->ort_api->CreateStatus(ORT_RUNTIME_EXCEPTION, "cleanup failed"); + }; + + auto status = ep->OnSessionInitializationEnd(); + ASSERT_FALSE(status.IsOK()); + EXPECT_THAT(status.ErrorMessage(), ::testing::HasSubstr("cleanup failed")); +} + +// OnSessionInitializationEnd with old ort_version_supported -> falls back to base class even if pointer is set. +TEST(PluginExecutionProviderTest, OnSessionInitializationEnd_OldVersionFallback) { + auto [ep, ort_ep] = test_plugin_ep::MakeTestOrtEp(); + + ort_ep->OnSessionInitializationEnd = [](OrtEp* this_ptr) noexcept -> OrtStatus* { + auto* test_ep = static_cast(this_ptr); + return test_ep->ort_api->CreateStatus(ORT_RUNTIME_EXCEPTION, "should not be called"); + }; + + // Simulate an older EP version that doesn't have this field. + ort_ep->ort_version_supported = 26; + + ASSERT_STATUS_OK(ep->OnSessionInitializationEnd()); +} + } // namespace onnxruntime::test From e3c34da40639669f3dbb7ae95db0662afbec8cc9 Mon Sep 17 00:00:00 2001 From: Dmitri Smirnov Date: Wed, 6 May 2026 15:00:48 -0700 Subject: [PATCH 31/34] Refactor and modernize StringNormalizer. (#28320) This pull request refactors and modernizes the UTF-8 and wide character (wchar_t) string conversion logic in the string normalizer CPU kernel, replacing deprecated and complex code with new, platform-appropriate utilities. The changes improve code maintainability, portability, and performance, especially on non-Windows platforms, by introducing custom UTF-8 conversion routines and simplifying buffer management. The most important changes are: **UTF-8 and Wide Character Conversion Utilities:** * Added new UTF-8 <-> wchar_t conversion functions (`WideToUtf8RequiredSize`, `WideToUtf8`, `Utf8ToWide`, and `Utf8ToWideString`) for non-Windows platforms in `utf8_util.h`, avoiding deprecated `std::codecvt` and providing robust Unicode handling. * Updated `Utf8ConverterGeneric` in `string_normalizer.cc` to use these new utilities, greatly simplifying the code and removing legacy/deprecated conversion logic. **Code Simplification and Performance:** * Simplified buffer size estimation for conversions: now directly uses the UTF-8 string size as an upper bound for the wide buffer, removing the need for a full decode pass just to compute buffer sizes. * Improved comments and logic for case-insensitive filtering, clarifying why lowercasing is used and how conversions are managed for efficiency. [[1]](diffhunk://#diff-20cdc2200d64f7c8dba541825ed6de8e69c5aaf0c0ece6967d3613482d0aaf16L32-R39) [[2]](diffhunk://#diff-26d2562f008c04f6d64a9c805054957c6a888040bd0912d5c16a53ed05512ca8L614-R446) **Cleanup and Modernization:** * Removed all usage of deprecated `std::codecvt` and related workaround code, as well as unnecessary includes and platform-specific handling, resulting in cleaner and more maintainable code. [[1]](diffhunk://#diff-26d2562f008c04f6d64a9c805054957c6a888040bd0912d5c16a53ed05512ca8R8-L27) [[2]](diffhunk://#diff-26d2562f008c04f6d64a9c805054957c6a888040bd0912d5c16a53ed05512ca8L39-R57) [[3]](diffhunk://#diff-26d2562f008c04f6d64a9c805054957c6a888040bd0912d5c16a53ed05512ca8L419-L428) These changes collectively modernize the string normalization kernel, improve portability, and make the codebase easier to maintain. --- onnxruntime/core/common/utf8_util.h | 217 ++++++++- .../providers/cpu/text/string_normalizer.cc | 412 ++++++------------ .../providers/cpu/text/string_normalizer.h | 36 +- onnxruntime/test/common/utf8_util_test.cc | 394 +++++++++++++++++ .../cpu/text/string_normalizer_test.cc | 370 +++++++++++++++- 5 files changed, 1126 insertions(+), 303 deletions(-) diff --git a/onnxruntime/core/common/utf8_util.h b/onnxruntime/core/common/utf8_util.h index 583aaf0a47cf7..360d17327fd66 100644 --- a/onnxruntime/core/common/utf8_util.h +++ b/onnxruntime/core/common/utf8_util.h @@ -5,16 +5,19 @@ #include "core/common/common.h" +#include + namespace onnxruntime { namespace utf8_util { /// -/// Checks the extension bytes and returns a number of -/// bytes in the UTF-8 character +/// Classifies a UTF-8 lead byte by encoded length. +/// This is a structural prefix check only; full well-formedness validation +/// is handled by utf8_validate. /// -/// -/// result -/// false if the char len is greater than 4 otherwise true +/// lead byte candidate +/// decoded byte length +/// false if the byte does not match any 1-4 byte UTF-8 lead-byte prefix inline bool utf8_bytes(unsigned char ch, size_t& len) { if ((ch & 0x80) == 0) { len = 1; @@ -24,12 +27,11 @@ inline bool utf8_bytes(unsigned char ch, size_t& len) { len = 2; return true; } - unsigned int result = (ch & 0xF0); - if (result == 0xE0) { + if ((ch & 0xF0) == 0xE0) { len = 3; return true; } - if (result == 0xF0) { + if ((ch & 0xF8) == 0xF0) { len = 4; return true; } @@ -64,6 +66,11 @@ inline bool utf8_validate(const unsigned char* s, size_t len, size_t& utf8_chars case 1: break; case 2: { + // Reject overlong 2-byte sequences. Valid Unicode 2-byte encodings + // start at U+0080, so lead bytes 0xC0 and 0xC1 are invalid. + if (ch < 0xC2u) { + return false; + } if (++idx >= len || s[idx] < 0x80u || s[idx] > 0xBFu) { return false; } @@ -147,5 +154,199 @@ inline bool utf8_validate(const unsigned char* s, size_t len, size_t& utf8_chars return true; } +// UTF-8 <-> wchar_t conversion utilities for non-Windows builds. +// These helpers operate on one wchar_t code unit per Unicode scalar value. +// They are fully Unicode-correct on platforms where wchar_t stores scalar values +// directly, which is commonly the case for 32-bit wchar_t builds such as Linux, +// macOS, and wasm. +// They do not implement UTF-16 surrogate-pair handling, so non-Windows builds +// with 16-bit wchar_t cannot represent supplementary-plane characters correctly +// via these helpers. +// On Windows, use the Win32 MultiByteToWideChar/WideCharToMultiByte APIs instead. +#ifndef _WIN32 + +static_assert(sizeof(wchar_t) >= 4, + "Non-Windows UTF-8/wchar_t conversion helpers require wchar_t to be at least 32 bits."); + +/// Compute the number of UTF-8 bytes required to encode a wide string. +inline size_t WideToUtf8RequiredSize(const std::wstring& wstr) { + size_t result = 0; + for (wchar_t wc : wstr) { + char32_t cp = static_cast(wc); + if (cp >= 0xD800 && cp <= 0xDFFF) { + ORT_THROW("Invalid Unicode surrogate codepoint U+", std::hex, static_cast(cp)); + } + if (cp <= 0x7F) { + result += 1; + } else if (cp <= 0x7FF) { + result += 2; + } else if (cp <= 0xFFFF) { + result += 3; + } else if (cp <= 0x10FFFF) { + result += 4; + } else { + ORT_THROW("Invalid Unicode codepoint U+", std::hex, static_cast(cp)); + } + } + return result; +} + +/// Convert a wide string to UTF-8, writing into a pre-allocated std::string. +/// The string is resized to the actual number of bytes written. +inline Status WideToUtf8(const std::wstring& wstr, std::string& str) { + if (wstr.empty()) { + str.clear(); + return Status::OK(); + } + + char* dest = str.data(); + char* dest_end = dest + str.size(); + + for (wchar_t wc : wstr) { + char32_t cp = static_cast(wc); + if (cp >= 0xD800 && cp <= 0xDFFF) { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, + "Invalid Unicode surrogate codepoint during UTF-8 conversion"); + } + if (cp <= 0x7F) { + const size_t remaining = static_cast(dest_end - dest); + if (remaining < 1) { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Destination buffer too small for UTF-8 conversion"); + } + *dest++ = static_cast(cp); + } else if (cp <= 0x7FF) { + const size_t remaining = static_cast(dest_end - dest); + if (remaining < 2) { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Destination buffer too small for UTF-8 conversion"); + } + *dest++ = static_cast(0xC0 | (cp >> 6)); + *dest++ = static_cast(0x80 | (cp & 0x3F)); + } else if (cp <= 0xFFFF) { + const size_t remaining = static_cast(dest_end - dest); + if (remaining < 3) { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Destination buffer too small for UTF-8 conversion"); + } + *dest++ = static_cast(0xE0 | (cp >> 12)); + *dest++ = static_cast(0x80 | ((cp >> 6) & 0x3F)); + *dest++ = static_cast(0x80 | (cp & 0x3F)); + } else if (cp <= 0x10FFFF) { + const size_t remaining = static_cast(dest_end - dest); + if (remaining < 4) { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Destination buffer too small for UTF-8 conversion"); + } + *dest++ = static_cast(0xF0 | (cp >> 18)); + *dest++ = static_cast(0x80 | ((cp >> 12) & 0x3F)); + *dest++ = static_cast(0x80 | ((cp >> 6) & 0x3F)); + *dest++ = static_cast(0x80 | (cp & 0x3F)); + } else { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, + "Invalid Unicode codepoint during UTF-8 conversion"); + } + } + + str.resize(static_cast(dest - str.data())); + return Status::OK(); +} + +/// Convert a UTF-8 string to a wide string, writing into a pre-allocated std::wstring. +/// The wstring is resized to the actual number of characters written. +/// Validates continuation bytes and rejects overlong encodings and surrogates. +inline Status Utf8ToWide(const std::string& str, std::wstring& wstr) { + if (str.empty()) { + wstr.clear(); + return Status::OK(); + } + + if (wstr.size() < str.size()) { + wstr.resize(str.size()); + } + + const auto* src = reinterpret_cast(str.data()); + const auto* src_end = src + str.size(); + wchar_t* dest = wstr.data(); + + while (src < src_end) { + char32_t cp = 0; + size_t byte_len = 0; + + if ((*src & 0x80) == 0) { + cp = *src; + byte_len = 1; + } else if ((*src & 0xE0) == 0xC0) { + byte_len = 2; + if (static_cast(src_end - src) < 2) { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Truncated UTF-8 sequence"); + } + if ((src[1] & 0xC0) != 0x80) { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Invalid UTF-8 continuation byte"); + } + cp = (static_cast(src[0] & 0x1F) << 6) | + static_cast(src[1] & 0x3F); + // Reject overlong encoding (must be >= 0x80 for 2-byte) + if (cp < 0x80) { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Overlong UTF-8 encoding"); + } + } else if ((*src & 0xF0) == 0xE0) { + byte_len = 3; + if (static_cast(src_end - src) < 3) { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Truncated UTF-8 sequence"); + } + if ((src[1] & 0xC0) != 0x80 || (src[2] & 0xC0) != 0x80) { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Invalid UTF-8 continuation byte"); + } + cp = (static_cast(src[0] & 0x0F) << 12) | + (static_cast(src[1] & 0x3F) << 6) | + static_cast(src[2] & 0x3F); + // Reject overlong encoding (must be >= 0x800 for 3-byte) + if (cp < 0x800) { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Overlong UTF-8 encoding"); + } + // Reject UTF-16 surrogates (U+D800..U+DFFF) + if (cp >= 0xD800 && cp <= 0xDFFF) { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Invalid UTF-8: surrogate codepoint"); + } + } else if ((*src & 0xF8) == 0xF0) { + byte_len = 4; + if (static_cast(src_end - src) < 4) { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Truncated UTF-8 sequence"); + } + if ((src[1] & 0xC0) != 0x80 || (src[2] & 0xC0) != 0x80 || (src[3] & 0xC0) != 0x80) { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Invalid UTF-8 continuation byte"); + } + cp = (static_cast(src[0] & 0x07) << 18) | + (static_cast(src[1] & 0x3F) << 12) | + (static_cast(src[2] & 0x3F) << 6) | + static_cast(src[3] & 0x3F); + // Reject overlong encoding (must be >= 0x10000 for 4-byte) + if (cp < 0x10000) { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Overlong UTF-8 encoding"); + } + // Reject codepoints beyond Unicode range + if (cp > 0x10FFFF) { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Invalid UTF-8: codepoint beyond U+10FFFF"); + } + } else { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Invalid UTF-8 lead byte"); + } + + *dest++ = static_cast(cp); + src += byte_len; + } + + wstr.resize(static_cast(dest - wstr.data())); + return Status::OK(); +} + +/// Convenience: convert UTF-8 string to wstring (throws on error). +inline std::wstring Utf8ToWideString(const std::string& s) { + // UTF-8 byte count is an upper bound on wchar_t count + std::wstring result; + result.resize(s.size()); + ORT_THROW_IF_ERROR(Utf8ToWide(s, result)); + return result; +} + +#endif // !_WIN32 + } // namespace utf8_util } // namespace onnxruntime diff --git a/onnxruntime/core/providers/cpu/text/string_normalizer.cc b/onnxruntime/core/providers/cpu/text/string_normalizer.cc index 8722c335a96a2..92bfd7ebdff62 100644 --- a/onnxruntime/core/providers/cpu/text/string_normalizer.cc +++ b/onnxruntime/core/providers/cpu/text/string_normalizer.cc @@ -5,26 +5,17 @@ #include "string_normalizer.h" #include "core/common/common.h" +#include "core/common/utf8_util.h" #include "core/framework/tensor.h" -// Used below HAS_DEPRECATED_DECLARATIONS -#include "onnxruntime_config.h" -#ifdef _MSC_VER +#ifdef _WIN32 #include #include -#endif // _MSC_VER +#endif // _WIN32 -#include #include #include -#if defined(__GNUC__) -// Allow deprecated-declarations warning - std::codecvt_utf8 is deprecatedd -#if defined(HAS_DEPRECATED_DECLARATIONS) -#pragma GCC diagnostic warning "-Wdeprecated-declarations" -#endif // defined(HAS_DEPRECATED_DECLARATIONS) -#endif // defined(__GNUC__) - namespace onnxruntime { ONNX_CPU_OPERATOR_KERNEL( @@ -36,235 +27,39 @@ ONNX_CPU_OPERATOR_KERNEL( namespace string_normalizer { -// codecvt_utf8 is deprecated, we will want to replace it with our class +#ifndef _WIN32 +// Thin wrapper around the common utf8_util functions, providing the same interface +// as Utf8ConverterWindows so the code below can use either via the Utf8Converter alias. class Utf8ConverterGeneric { public: size_t ComputeRequiredSizeToUtf8(const std::wstring& wstr) const { - if (wstr.empty()) { - return 0; - } - - size_t result = 0; - std::mbstate_t state = std::mbstate_t(); - - const wchar_t* src = wstr.data(); - const wchar_t* src_end = src + wstr.length(); - - char dummy_dest[128] = {0}; - - char* char_next = dummy_dest; - const wchar_t* wchar_next = src; - - size_t converted = 0; - - std::codecvt_base::result ret_code = std::codecvt_base::ok; - - // Continue while we exhaust the sequence - while (converted < wstr.length()) { - ret_code = converter_.out(state, - wchar_next, - src_end, - wchar_next, - std::begin(dummy_dest), - std::end(dummy_dest), - char_next); - result += (char_next - dummy_dest); - converted = (wchar_next - src); - - if (ret_code != std::codecvt_base::partial && - ret_code != std::codecvt_base::ok) { - break; - } - } - - ORT_ENFORCE(ret_code != std::codecvt_base::noconv, "Conversion is expected"); - - if (ret_code != std::codecvt_base::ok) { - ORT_THROW("Failed to compute size for UTF-8. Converted only first: ", - converted, " codepoints out of: ", wstr.length()); - } - - return result; + return utf8_util::WideToUtf8RequiredSize(wstr); } - // We assume the caller pre-allocated the correct length Status ConvertToUtf8(const std::wstring& wstr, std::string& str) const { - if (wstr.empty()) { - str.clear(); - return Status::OK(); - } - - std::mbstate_t state = std::mbstate_t(); - - const wchar_t* src = wstr.data(); - const wchar_t* src_end = src + wstr.length(); - - char* dest = str.data(); - char* dest_end = dest + str.length(); - - char* char_next = dest; - const wchar_t* wchar_next = src; - - std::codecvt_base::result ret_code = converter_.out(state, - src, - src_end, - wchar_next, - dest, - dest_end, - char_next); - - if (ret_code != std::codecvt_base::ok) { - size_t converted = narrow(wchar_next - wstr.data()); - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Failed to convert to UTF-8. Converted only first: ", - converted, " codepoints out of: ", wstr.length()); - } - - str.resize(char_next - dest); - - return Status::OK(); + return utf8_util::WideToUtf8(wstr, str); } Status ComputeRequiredSizeToWideChar(const std::string& str, size_t& wchars) { - if (str.empty()) { - wchars = 0; - return Status::OK(); - } - - size_t result = 0; - std::mbstate_t state = std::mbstate_t(); - - const char* src = str.data(); - const char* src_end = src + str.length(); - - wchar_t dummy_dest[128] = {0}; - const char* char_next = src; - wchar_t* wchar_next = dummy_dest; - - size_t converted = 0; - - std::codecvt_base::result ret_code = std::codecvt_base::ok; - while (converted < str.length()) { - ret_code = converter_.in(state, - char_next, - src_end, - char_next, - std::begin(dummy_dest), - std::end(dummy_dest), - wchar_next); - result += (wchar_next - dummy_dest); - converted = (char_next - src); - - if (ret_code != std::codecvt_base::partial && - ret_code != std::codecvt_base::ok) { - break; - } - } - - ORT_ENFORCE(ret_code != std::codecvt_base::noconv, "Conversion is expected"); - - if (ret_code != std::codecvt_base::ok) { - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, - "Failed to compute buffer size for wchar_t. Converted only first: ", - converted, " bytes out of: ", str.length(), - " Source: ", src); - } - - wchars = result; + // UTF-8 byte count is an upper bound on wchar_t count; use it directly. + wchars = str.size(); return Status::OK(); } - // We assume the destination buffer is preallocated correctly Status ConvertToWideChar(const std::string& str, std::wstring& wstr) { - if (str.empty()) { - // Preserve the buffer for re-use, just set size to 0 - wstr.clear(); - return Status::OK(); - } - - std::mbstate_t state = std::mbstate_t(); - const char* src = str.data(); - const char* src_end = src + str.length(); - - wchar_t* dest = wstr.data(); - wchar_t* dest_end = dest + wstr.length(); - - const char* char_next = src; - wchar_t* wchar_next = dest; - - std::codecvt_base::result ret_code = converter_.in(state, - src, - src_end, - char_next, - dest, - dest_end, - wchar_next); - - if (ret_code != std::codecvt_base::ok) { - size_t converted = narrow(char_next - str.data()); - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Failed to convert to wchar_t. Converted only first: ", - converted, " bytes out of: ", str.length(), - " Source: ", src); - } - - wstr.resize(wchar_next - dest); - - return Status::OK(); + return utf8_util::Utf8ToWide(str, wstr); } std::wstring from_bytes(const std::string& s) { - std::wstring result; - - size_t wchars = 0; - ORT_THROW_IF_ERROR(ComputeRequiredSizeToWideChar(s, wchars)); - - result.resize(wchars); - ORT_THROW_IF_ERROR(ConvertToWideChar(s, result)); - return result; + return utf8_util::Utf8ToWideString(s); } - - private: - std::codecvt_utf8 converter_; }; +#endif // !_WIN32 // We need to specialize for MS as there is // a std::locale creation bug that affects different // environments in a different way -#ifdef _MSC_VER - -class Locale { - public: - explicit Locale(const std::string& name) - : loc_(nullptr) { - loc_ = _create_locale(LC_CTYPE, name.c_str()); - if (loc_ == nullptr) { - ORT_THROW("Failed to construct locale with name:", - name, ":", ":Please, install necessary language-pack-XX and configure locales"); - } - } - - ~Locale() { - if (loc_ != nullptr) { - _free_locale(loc_); - } - } - - ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(Locale); - - void ChangeCase(StringNormalizer::CaseAction caseaction, - std::wstring& wstr) const { - assert(caseaction != StringNormalizer::NONE); - if (caseaction == StringNormalizer::LOWER) { - std::transform(wstr.begin(), wstr.end(), wstr.begin(), - [this](wchar_t ch) { return ::_towlower_l(ch, loc_); }); - } else { - std::transform(wstr.begin(), wstr.end(), wstr.begin(), - [this](wchar_t ch) { return ::_towupper_l(ch, loc_); }); - } - } - - private: - _locale_t loc_; -}; +#ifdef _WIN32 class Utf8ConverterWindows { public: @@ -382,50 +177,10 @@ const std::string default_locale("en-US"); using Utf8Converter = Utf8ConverterWindows; -#else // _MSC_VER - -class Locale { - public: - explicit Locale(const std::string& name) { - ORT_TRY { - loc_ = std::locale(name.c_str()); - } - ORT_CATCH(const std::runtime_error& e) { - ORT_HANDLE_EXCEPTION([&]() { - ORT_THROW("Failed to construct locale with name:", - name, ":", e.what(), ":Please, install necessary language-pack-XX and configure locales"); - }); - } - } - - ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(Locale); - - void ChangeCase(StringNormalizer::CaseAction caseaction, - std::wstring& wstr) const { - assert(caseaction != StringNormalizer::NONE); - if (caseaction == StringNormalizer::LOWER) { - std::transform(wstr.begin(), wstr.end(), wstr.begin(), - [this](wchar_t ch) { return std::tolower(ch, loc_); }); - } else { - std::transform(wstr.begin(), wstr.end(), wstr.begin(), - [this](wchar_t ch) { return std::toupper(ch, loc_); }); - } - } - - private: - std::locale loc_; -}; - -#if defined(__APPLE__) || defined(__ANDROID__) - -using Utf8Converter = Utf8ConverterGeneric; - -#else +#else // _WIN32 using Utf8Converter = Utf8ConverterGeneric; -#endif - #if defined(__APPLE__) #include #if TARGET_OS_IPHONE || TARGET_OS_SIMULATOR @@ -437,11 +192,67 @@ const std::string default_locale("en_US.UTF-8"); // Other kinds of Apple Platfo const std::string default_locale("en_US.UTF-8"); // All non-MS and not Apple #endif -#endif // _MSC_VER +#endif // _WIN32 } // namespace string_normalizer using namespace string_normalizer; +#ifdef _WIN32 + +StringNormalizer::Locale::Locale(const std::string& name) { + loc_ = _create_locale(LC_CTYPE, name.c_str()); + if (loc_ == nullptr) { + ORT_THROW("Failed to construct locale with name:", + name, ":", ":Please, install necessary language-pack-XX and configure locales"); + } +} + +StringNormalizer::Locale::~Locale() { + if (loc_ != nullptr) { + _free_locale(loc_); + } +} + +void StringNormalizer::Locale::ChangeCase(CaseAction caseaction, std::wstring& wstr) const { + assert(caseaction != NONE); + if (caseaction == LOWER) { + std::transform(wstr.begin(), wstr.end(), wstr.begin(), + [this](wchar_t ch) { return ::_towlower_l(ch, loc_); }); + } else { + std::transform(wstr.begin(), wstr.end(), wstr.begin(), + [this](wchar_t ch) { return ::_towupper_l(ch, loc_); }); + } +} + +#else + +StringNormalizer::Locale::Locale(const std::string& name) { + ORT_TRY { + loc_ = std::locale(name.c_str()); + } + ORT_CATCH(const std::runtime_error& e) { + ORT_HANDLE_EXCEPTION([&]() { + ORT_THROW("Failed to construct locale with name:", + name, ":", e.what(), ":Please, install necessary language-pack-XX and configure locales"); + }); + } +} + +StringNormalizer::Locale::~Locale() = default; + +void StringNormalizer::Locale::ChangeCase(CaseAction caseaction, std::wstring& wstr) const { + assert(caseaction != NONE); + if (caseaction == LOWER) { + std::transform(wstr.begin(), wstr.end(), wstr.begin(), + [this](wchar_t ch) { return std::tolower(ch, loc_); }); + } else { + std::transform(wstr.begin(), wstr.end(), wstr.begin(), + [this](wchar_t ch) { return std::toupper(ch, loc_); }); + } +} + +#endif + StringNormalizer::StringNormalizer(const OpKernelInfo& info) : OpKernel(info) { int64_t iscasesensitive = 0; Status status = info.GetAttr("is_case_sensitive", &iscasesensitive); @@ -461,21 +272,26 @@ StringNormalizer::StringNormalizer(const OpKernelInfo& info) : OpKernel(info) { ORT_ENFORCE(false, "attribute case_change_action has invalid value"); } - locale_name_ = info.GetAttrOrDefault("locale", default_locale); + const std::string locale_name = info.GetAttrOrDefault("locale", default_locale); std::vector stop_words = info.GetAttrsOrDefault("stopwords"); + const bool needs_runtime_locale = case_change_action_ != NONE || (!is_case_sensitive_ && !stop_words.empty()); + if (needs_runtime_locale) { + locale_.emplace(locale_name); + } + if (is_case_sensitive_) { stopwords_.reserve(stop_words.size()); for (std::string& s : stop_words) { stopwords_.insert(std::move(s)); } - } else { - Locale locale(locale_name_); + } else if (!stop_words.empty()) { + assert(locale_.has_value()); Utf8Converter converter; wstopwords_.reserve(stop_words.size()); for (std::string& s : stop_words) { std::wstring wstr = converter.from_bytes(s); - locale.ChangeCase(compare_caseaction_, wstr); + locale_->ChangeCase(compare_caseaction_, wstr); wstopwords_.insert(std::move(wstr)); } } @@ -508,6 +324,13 @@ Status StringNormalizer::Compute(OpKernelContext* ctx) const { "Input dimensions are either[C > 0] or [1][C > 0] allowed"); } + auto validate_utf8 = [](const std::string& value) { + size_t utf8_chars = 0; + ORT_RETURN_IF_NOT(utf8_util::utf8_validate(reinterpret_cast(value.data()), value.size(), utf8_chars), + "Input strings must be valid UTF-8"); + return Status::OK(); + }; + // Special case, no filtering and no case change if (case_change_action_ == NONE && ((is_case_sensitive_ && stopwords_.empty()) || @@ -515,7 +338,10 @@ Status StringNormalizer::Compute(OpKernelContext* ctx) const { output_shape.push_back(C); auto output_tensor = ctx->Output(0, output_shape); auto const output_data = output_tensor->MutableData(); - std::copy(input_span.begin(), input_span.end(), output_data); + for (size_t i = 0, lim = input_span.size(); i < lim; ++i) { + ORT_RETURN_IF_ERROR(validate_utf8(input_span[i])); + output_data[i] = input_span[i]; + } return Status::OK(); } @@ -525,31 +351,39 @@ Status StringNormalizer::Compute(OpKernelContext* ctx) const { // to widechar, lowercase it and then compare. Case-insensitive comparison is complicated // for UTF-8 and requires additional dependency. - Locale locale(locale_name_); Utf8Converter converter; + const Locale* locale = locale_ ? &*locale_ : nullptr; + + // Determine whether we need wchar conversion at all. + // We need it if: (a) case change is requested, or (b) case-insensitive filtering. + const bool needs_wchar = (case_change_action_ != NONE) || !is_case_sensitive_; - // Compute the largest widestring buffer needed. size_t max_wide_buffer_len = 0; - for (const auto& s : input_span) { - size_t wchars = 0; - // Checks for invalid UTF-8 characters on Windows - ORT_RETURN_IF_ERROR(converter.ComputeRequiredSizeToWideChar(s, wchars)); - max_wide_buffer_len = std::max(max_wide_buffer_len, wchars); + if (needs_wchar) { + // UTF-8 byte count is an upper bound on wchar_t count: each codepoint requires + // at least 1 byte but produces exactly 1 wchar_t (UTF-32) or at most 2 (UTF-16). + // This avoids a full UTF-8 decode pass just to compute buffer sizes. + for (const auto& s : input_span) { + max_wide_buffer_len = std::max(max_wide_buffer_len, s.size()); + } } // Reuse reserved space std::wstring wchar_buffer; - wchar_buffer.reserve(max_wide_buffer_len); + if (needs_wchar) { + wchar_buffer.reserve(max_wide_buffer_len); + } // Output everything and change case as required auto output_no_filtering = [&](const TensorShape& output_shape) { - auto output_tensor = ctx->Output(0, output_shape); - auto const output_data = output_tensor->MutableData(); + auto* output_tensor = ctx->Output(0, output_shape); + auto* output_data = output_tensor->MutableData(); for (size_t i = 0, lim = input_span.size(); i < lim; ++i) { const std::string& s = input_span[i]; wchar_buffer.resize(max_wide_buffer_len); ORT_RETURN_IF_ERROR(converter.ConvertToWideChar(s, wchar_buffer)); - locale.ChangeCase(case_change_action_, wchar_buffer); + assert(locale != nullptr); + locale->ChangeCase(case_change_action_, wchar_buffer); auto& dest = output_data[i]; size_t utf8_buffer_len = converter.ComputeRequiredSizeToUtf8(wchar_buffer); @@ -560,14 +394,15 @@ Status StringNormalizer::Compute(OpKernelContext* ctx) const { }; auto output_filtered = [&](const TensorShape& output_shape, gsl::span filtered_indices) { - auto output_tensor = ctx->Output(0, output_shape); - auto output_data = output_tensor->MutableData(); + auto* output_tensor = ctx->Output(0, output_shape); + auto* output_data = output_tensor->MutableData(); for (size_t i : filtered_indices) { const std::string& s = input_span[i]; if (case_change_action_ != NONE) { wchar_buffer.resize(max_wide_buffer_len); ORT_RETURN_IF_ERROR(converter.ConvertToWideChar(s, wchar_buffer)); - locale.ChangeCase(case_change_action_, wchar_buffer); + assert(locale != nullptr); + locale->ChangeCase(case_change_action_, wchar_buffer); auto& dest = *output_data++; size_t utf8_buffer_len = converter.ComputeRequiredSizeToUtf8(wchar_buffer); @@ -588,19 +423,18 @@ Status StringNormalizer::Compute(OpKernelContext* ctx) const { output_shape.push_back(C); status = output_no_filtering(output_shape); } else { - // we need to filter + // Case-sensitive filtering: direct string compare, no wchar needed for comparison. InlinedVector filtered_strings_indices; filtered_strings_indices.reserve(input_span.size()); for (size_t i = 0, lim = input_span.size(); i < lim; ++i) { const std::string& s = input_span[i]; + ORT_RETURN_IF_ERROR(validate_utf8(s)); if (stopwords_.count(s) == 0) { filtered_strings_indices.push_back(i); } } - // According to the spec, if all strings are filtered out - // the output must have a shape of {1} with a single empty string. const int64_t filtered_count = std::max(1, narrow(filtered_strings_indices.size())); output_shape.push_back(filtered_count); status = output_filtered(output_shape, filtered_strings_indices); @@ -611,23 +445,23 @@ Status StringNormalizer::Compute(OpKernelContext* ctx) const { output_shape.push_back(C); status = output_no_filtering(output_shape); } else { - // Case insensitive filtering is performed by converting the input strings - // to compare_caseaction_. For that we convert to wchar_t UNICODE. - // Otherwise, we need to pull ICU library on all platforms. + // Case insensitive filtering: convert to wchar_t and lowercase for comparison. + // Re-conversion during output is cheaper than caching N wide strings (each requiring + // a heap allocation), especially under multi-threaded contention for the allocator lock. InlinedVector filtered_strings_indices; filtered_strings_indices.reserve(input_span.size()); + for (size_t i = 0, lim = input_span.size(); i < lim; ++i) { const std::string& s = input_span[i]; wchar_buffer.resize(max_wide_buffer_len); ORT_RETURN_IF_ERROR(converter.ConvertToWideChar(s, wchar_buffer)); - locale.ChangeCase(compare_caseaction_, wchar_buffer); + assert(locale != nullptr); + locale->ChangeCase(compare_caseaction_, wchar_buffer); if (wstopwords_.count(wchar_buffer) == 0) { filtered_strings_indices.push_back(i); } } - // According to the spec, if all strings are filtered out - // the output must have a shape of {1} with a single empty string. const int64_t filtered_count = std::max(1, narrow(filtered_strings_indices.size())); output_shape.push_back(filtered_count); status = output_filtered(output_shape, filtered_strings_indices); diff --git a/onnxruntime/core/providers/cpu/text/string_normalizer.h b/onnxruntime/core/providers/cpu/text/string_normalizer.h index 4e66a66b00893..1852233b90289 100644 --- a/onnxruntime/core/providers/cpu/text/string_normalizer.h +++ b/onnxruntime/core/providers/cpu/text/string_normalizer.h @@ -9,8 +9,13 @@ #include "core/framework/op_kernel.h" #include +#include #include +#ifdef _WIN32 +#include +#endif + namespace onnxruntime { class StringNormalizer : public OpKernel { @@ -27,12 +32,37 @@ class StringNormalizer : public OpKernel { Status Compute(OpKernelContext* ctx) const override; private: + class Locale { + public: + explicit Locale(const std::string& name); + ~Locale(); + + ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(Locale); + + void ChangeCase(CaseAction caseaction, std::wstring& wstr) const; + + private: +#ifdef _WIN32 + _locale_t loc_{nullptr}; +#else + std::locale loc_; +#endif + }; + bool is_case_sensitive_{true}; CaseAction case_change_action_{NONE}; - // Set this to lower because some characters do not have upper case. - // used for case-insensitive compare + // Hardcoded to LOWER for case-insensitive stopword comparison. + // Lowercase is used here as a practical fit for the current per-character + // std::transform-based implementation: + // - Some characters have no uppercase form or uppercase to multiple characters + // (e.g., ß -> SS), which this implementation cannot handle because it + // transforms one wchar_t at a time. + // - Unicode casing can be locale-, context-, and length-dependent, so this + // should not be interpreted as full Unicode case folding. + // The ideal approach would be Unicode case folding (ICU), but that's not + // warranted for this operator. CaseAction compare_caseaction_{LOWER}; - std::string locale_name_; + std::optional locale_; // Either if these are populated but not both InlinedHashSet stopwords_; InlinedHashSet wstopwords_; diff --git a/onnxruntime/test/common/utf8_util_test.cc b/onnxruntime/test/common/utf8_util_test.cc index 775d53090328d..c21ecd9920934 100644 --- a/onnxruntime/test/common/utf8_util_test.cc +++ b/onnxruntime/test/common/utf8_util_test.cc @@ -41,5 +41,399 @@ TEST(Utf8UtilTest, Validate) { } } +// --- utf8_bytes tests --- + +TEST(Utf8UtilTest, Utf8Bytes_Ascii) { + using namespace utf8_util; + size_t len = 0; + // All ASCII bytes (0x00-0x7F) should be 1-byte + EXPECT_TRUE(utf8_bytes(0x00, len)); + EXPECT_EQ(1U, len); + EXPECT_TRUE(utf8_bytes('A', len)); + EXPECT_EQ(1U, len); + EXPECT_TRUE(utf8_bytes(0x7F, len)); + EXPECT_EQ(1U, len); +} + +TEST(Utf8UtilTest, Utf8Bytes_TwoByte) { + using namespace utf8_util; + size_t len = 0; + // 0xC0-0xDF share the 2-byte lead-byte prefix. + // Full well-formedness checks happen in utf8_validate. + EXPECT_TRUE(utf8_bytes(0xC0, len)); + EXPECT_EQ(2U, len); + EXPECT_TRUE(utf8_bytes(0xC2, len)); + EXPECT_EQ(2U, len); + EXPECT_TRUE(utf8_bytes(0xDF, len)); + EXPECT_EQ(2U, len); +} + +TEST(Utf8UtilTest, Utf8Bytes_ThreeByte) { + using namespace utf8_util; + size_t len = 0; + // 0xE0-0xEF are 3-byte lead bytes + EXPECT_TRUE(utf8_bytes(0xE0, len)); + EXPECT_EQ(3U, len); + EXPECT_TRUE(utf8_bytes(0xED, len)); + EXPECT_EQ(3U, len); + EXPECT_TRUE(utf8_bytes(0xEF, len)); + EXPECT_EQ(3U, len); +} + +TEST(Utf8UtilTest, Utf8Bytes_FourByte) { + using namespace utf8_util; + size_t len = 0; + // 0xF0-0xF7 share the 4-byte lead-byte prefix. + // Full well-formedness checks happen in utf8_validate. + EXPECT_TRUE(utf8_bytes(0xF0, len)); + EXPECT_EQ(4U, len); + EXPECT_TRUE(utf8_bytes(0xF4, len)); + EXPECT_EQ(4U, len); + EXPECT_TRUE(utf8_bytes(0xF7, len)); + EXPECT_EQ(4U, len); +} + +TEST(Utf8UtilTest, Utf8Bytes_Invalid) { + using namespace utf8_util; + size_t len = 99; + // Continuation bytes (0x80-0xBF) are not valid lead bytes + EXPECT_FALSE(utf8_bytes(0x80, len)); + EXPECT_FALSE(utf8_bytes(0xBF, len)); + // 0xF8-0xFF are invalid (would be 5+ byte sequences) + EXPECT_FALSE(utf8_bytes(0xF8, len)); + EXPECT_FALSE(utf8_bytes(0xF9, len)); + EXPECT_FALSE(utf8_bytes(0xFC, len)); + EXPECT_FALSE(utf8_bytes(0xFE, len)); + EXPECT_FALSE(utf8_bytes(0xFF, len)); +} + +// --- utf8_len tests --- + +TEST(Utf8UtilTest, Utf8Len_Empty) { + using namespace utf8_util; + size_t len = 99; + EXPECT_TRUE(utf8_len(reinterpret_cast(""), 0, len)); + EXPECT_EQ(0U, len); +} + +TEST(Utf8UtilTest, Utf8Len_Ascii) { + using namespace utf8_util; + size_t len = 0; + const char* s = "Hello"; + EXPECT_TRUE(utf8_len(reinterpret_cast(s), 5, len)); + EXPECT_EQ(5U, len); +} + +TEST(Utf8UtilTest, Utf8Len_Multibyte) { + using namespace utf8_util; + size_t len = 0; + // "café" = 'c' 'a' 'f' U+00E9(2 bytes) = 5 bytes, 4 chars + const char* s = "caf\xc3\xa9"; + EXPECT_TRUE(utf8_len(reinterpret_cast(s), 5, len)); + EXPECT_EQ(4U, len); +} + +TEST(Utf8UtilTest, Utf8Len_ThreeByteChars) { + using namespace utf8_util; + size_t len = 0; + // U+4E16 (世) = 0xE4 0xB8 0x96, U+754C (界) = 0xE7 0x95 0x8C + const char* s = "\xe4\xb8\x96\xe7\x95\x8c"; // "世界" + EXPECT_TRUE(utf8_len(reinterpret_cast(s), 6, len)); + EXPECT_EQ(2U, len); +} + +TEST(Utf8UtilTest, Utf8Len_FourByteChars) { + using namespace utf8_util; + size_t len = 0; + // U+1F600 (😀) = 0xF0 0x9F 0x98 0x80 + const char* s = "\xf0\x9f\x98\x80"; + EXPECT_TRUE(utf8_len(reinterpret_cast(s), 4, len)); + EXPECT_EQ(1U, len); +} + +TEST(Utf8UtilTest, Utf8Len_Mixed) { + using namespace utf8_util; + size_t len = 0; + // "A" (1) + U+00F1 (2) + U+4E16 (3) + U+1F600 (4) = 10 bytes, 4 chars + const char* s = "A\xc3\xb1\xe4\xb8\x96\xf0\x9f\x98\x80"; + EXPECT_TRUE(utf8_len(reinterpret_cast(s), 10, len)); + EXPECT_EQ(4U, len); +} + +TEST(Utf8UtilTest, Utf8Len_InvalidLeadByte) { + using namespace utf8_util; + size_t len = 0; + // 0xF8 is invalid lead byte + const char* s = "\xf8\x80\x80\x80"; + EXPECT_FALSE(utf8_len(reinterpret_cast(s), 4, len)); +} + +TEST(Utf8UtilTest, Utf8Len_Truncated) { + using namespace utf8_util; + size_t len = 0; + // 2-byte sequence but only 1 byte available + const char* s = "\xc3"; + EXPECT_FALSE(utf8_len(reinterpret_cast(s), 1, len)); +} + +// --- utf8_validate additional tests --- + +TEST(Utf8UtilTest, Validate_EmptyString) { + using namespace utf8_util; + size_t chars = 99; + EXPECT_TRUE(utf8_validate(reinterpret_cast(""), 0, chars)); + EXPECT_EQ(0U, chars); +} + +TEST(Utf8UtilTest, Validate_MultiCharString) { + using namespace utf8_util; + size_t chars = 0; + // "Héllo" = 'H' U+00E9(2b) 'l' 'l' 'o' = 6 bytes, 5 chars + const char* s = "H\xc3\xa9llo"; + EXPECT_TRUE(utf8_validate(reinterpret_cast(s), 6, chars)); + EXPECT_EQ(5U, chars); +} + +TEST(Utf8UtilTest, Validate_OverlongTwoByte) { + using namespace utf8_util; + size_t chars = 0; + // Overlong encoding of U+0000: 0xC0 0x80 (should be rejected) + const char* s = "\xc0\x80"; + EXPECT_FALSE(utf8_validate(reinterpret_cast(s), 2, chars)); +} + +TEST(Utf8UtilTest, Validate_OverlongTwoByteLeadByteC1) { + using namespace utf8_util; + size_t chars = 0; + const char* s = "\xc1\xbf"; + EXPECT_FALSE(utf8_validate(reinterpret_cast(s), 2, chars)); +} + +TEST(Utf8UtilTest, Validate_SurrogatePair) { + using namespace utf8_util; + size_t chars = 0; + // U+D800 encoded as 3-byte: 0xED 0xA0 0x80 (invalid surrogate) + const char* s = "\xed\xa0\x80"; + EXPECT_FALSE(utf8_validate(reinterpret_cast(s), 3, chars)); +} + +TEST(Utf8UtilTest, Validate_MaxCodepoint) { + using namespace utf8_util; + size_t chars = 0; + // U+10FFFF = 0xF4 0x8F 0xBF 0xBF (valid, max Unicode codepoint) + const char* s = "\xf4\x8f\xbf\xbf"; + EXPECT_TRUE(utf8_validate(reinterpret_cast(s), 4, chars)); + EXPECT_EQ(1U, chars); +} + +TEST(Utf8UtilTest, Validate_BeyondMaxCodepoint) { + using namespace utf8_util; + size_t chars = 0; + // U+110000 = 0xF4 0x90 0x80 0x80 (invalid, beyond U+10FFFF) + const char* s = "\xf4\x90\x80\x80"; + EXPECT_FALSE(utf8_validate(reinterpret_cast(s), 4, chars)); +} + +TEST(Utf8UtilTest, Validate_FourByteLeadByteAboveUnicodeRange) { + using namespace utf8_util; + size_t chars = 0; + const char* s = "\xf7\xbf\xbf\xbf"; + EXPECT_FALSE(utf8_validate(reinterpret_cast(s), 4, chars)); +} + +TEST(Utf8UtilTest, Validate_ContinuationByteAlone) { + using namespace utf8_util; + size_t chars = 0; + // A lone continuation byte + const char* s = "\x80"; + EXPECT_FALSE(utf8_validate(reinterpret_cast(s), 1, chars)); +} + +// --- Non-Windows conversion tests --- +#ifndef _WIN32 + +using namespace utf8_util; + +TEST(Utf8UtilTest, WideToUtf8RequiredSize_Ascii) { + std::wstring ws = L"Hello"; + EXPECT_EQ(5U, WideToUtf8RequiredSize(ws)); +} + +TEST(Utf8UtilTest, WideToUtf8RequiredSize_Multibyte) { + // U+00E9 -> 2 bytes, U+4E16 -> 3 bytes, U+1F600 -> 4 bytes + std::wstring ws; + ws += static_cast(0x00E9); // 2 bytes + ws += static_cast(0x4E16); // 3 bytes + ws += static_cast(0x1F600); // 4 bytes + EXPECT_EQ(9U, WideToUtf8RequiredSize(ws)); +} + +TEST(Utf8UtilTest, WideToUtf8_RoundTrip_Ascii) { + std::wstring ws = L"Hello World"; + std::string result; + result.resize(WideToUtf8RequiredSize(ws)); + ASSERT_TRUE(WideToUtf8(ws, result).IsOK()); + EXPECT_EQ("Hello World", result); +} + +TEST(Utf8UtilTest, WideToUtf8_BufferTooSmall) { + std::wstring ws; + ws += static_cast(0x00E9); // 2 bytes in UTF-8 + std::string result; + result.resize(1); + EXPECT_FALSE(WideToUtf8(ws, result).IsOK()); +} + +TEST(Utf8UtilTest, WideToUtf8_EmptyDestinationBuffer) { + std::wstring ws = L"A"; + std::string result; + EXPECT_FALSE(WideToUtf8(ws, result).IsOK()); +} + +TEST(Utf8UtilTest, WideToUtf8_ThreeByteBufferTooSmall) { + std::wstring ws; + ws += static_cast(0x4E16); // 3 bytes in UTF-8 + std::string result; + result.resize(2); + EXPECT_FALSE(WideToUtf8(ws, result).IsOK()); +} + +TEST(Utf8UtilTest, WideToUtf8_FourByteBufferTooSmall) { + std::wstring ws; + ws += static_cast(0x1F600); // 4 bytes in UTF-8 + std::string result; + result.resize(3); + EXPECT_FALSE(WideToUtf8(ws, result).IsOK()); +} + +TEST(Utf8UtilTest, WideToUtf8_RoundTrip_Multibyte) { + // Build wide string with various codepoints + std::wstring ws; + ws += static_cast(0x00E9); // é + ws += static_cast(0x4E16); // 世 + ws += static_cast(0x1F600); // 😀 + + std::string utf8; + utf8.resize(WideToUtf8RequiredSize(ws)); + ASSERT_TRUE(WideToUtf8(ws, utf8).IsOK()); + + // Verify via round-trip + std::wstring back; + back.resize(utf8.size()); + ASSERT_TRUE(Utf8ToWide(utf8, back).IsOK()); + EXPECT_EQ(ws, back); +} + +TEST(Utf8UtilTest, WideToUtf8_Empty) { + std::wstring ws; + std::string result = "notempty"; + ASSERT_TRUE(WideToUtf8(ws, result).IsOK()); + EXPECT_TRUE(result.empty()); +} + +TEST(Utf8UtilTest, Utf8ToWide_Empty) { + std::string s; + std::wstring result = L"notempty"; + ASSERT_TRUE(Utf8ToWide(s, result).IsOK()); + EXPECT_TRUE(result.empty()); +} + +TEST(Utf8UtilTest, Utf8ToWide_Ascii) { + std::string s = "ABC"; + std::wstring result; + result.resize(s.size()); + ASSERT_TRUE(Utf8ToWide(s, result).IsOK()); + EXPECT_EQ(L"ABC", result); +} + +TEST(Utf8UtilTest, Utf8ToWide_AutoResizeDestination) { + std::string s = "ABC"; + std::wstring result; + ASSERT_TRUE(Utf8ToWide(s, result).IsOK()); + EXPECT_EQ(L"ABC", result); +} + +TEST(Utf8UtilTest, Utf8ToWide_TruncatedSequence) { + // 3-byte sequence missing last byte + std::string s = "\xe4\xb8"; + std::wstring result; + result.resize(s.size()); + EXPECT_FALSE(Utf8ToWide(s, result).IsOK()); +} + +TEST(Utf8UtilTest, Utf8ToWide_InvalidContinuationByte) { + // 2-byte lead 0xC3 followed by non-continuation 0x28 + std::string s = "\xc3\x28"; + std::wstring result; + result.resize(s.size()); + EXPECT_FALSE(Utf8ToWide(s, result).IsOK()); +} + +TEST(Utf8UtilTest, Utf8ToWide_OverlongEncoding) { + // Overlong 2-byte for U+002F ('/') = 0xC0 0xAF + std::string s = "\xc0\xaf"; + std::wstring result; + result.resize(s.size()); + EXPECT_FALSE(Utf8ToWide(s, result).IsOK()); +} + +TEST(Utf8UtilTest, Utf8ToWide_SurrogateCodepoint) { + // U+D800 as 3-byte UTF-8: 0xED 0xA0 0x80 + std::string s = "\xed\xa0\x80"; + std::wstring result; + result.resize(s.size()); + EXPECT_FALSE(Utf8ToWide(s, result).IsOK()); +} + +TEST(Utf8UtilTest, Utf8ToWide_BeyondUnicode) { + // U+110000: 0xF4 0x90 0x80 0x80 + std::string s = "\xf4\x90\x80\x80"; + std::wstring result; + result.resize(s.size()); + EXPECT_FALSE(Utf8ToWide(s, result).IsOK()); +} + +TEST(Utf8UtilTest, Utf8ToWide_InvalidLeadByte) { + // 0xF8 is not a valid UTF-8 lead byte + std::string s = "\xf8\x80\x80\x80\x80"; + std::wstring result; + result.resize(s.size()); + EXPECT_FALSE(Utf8ToWide(s, result).IsOK()); +} + +TEST(Utf8UtilTest, Utf8ToWideString_ValidInput) { + std::string s = "caf\xc3\xa9"; // "café" + std::wstring result = Utf8ToWideString(s); + EXPECT_EQ(4U, result.size()); + EXPECT_EQ(static_cast('c'), result[0]); + EXPECT_EQ(static_cast('a'), result[1]); + EXPECT_EQ(static_cast('f'), result[2]); + EXPECT_EQ(static_cast(0x00E9), result[3]); +} + +#if !defined(ORT_NO_EXCEPTIONS) +TEST(Utf8UtilTest, Utf8ToWideString_InvalidInput) { + // Should throw on invalid UTF-8 + std::string s = "\xc0\xaf"; + EXPECT_THROW(Utf8ToWideString(s), OnnxRuntimeException); +} + +TEST(Utf8UtilTest, WideToUtf8RequiredSize_SurrogateCodepoint) { + std::wstring ws; + ws += static_cast(0xD800); + EXPECT_THROW(WideToUtf8RequiredSize(ws), OnnxRuntimeException); +} +#endif // !defined(ORT_NO_EXCEPTIONS) + +TEST(Utf8UtilTest, WideToUtf8_SurrogateCodepoint) { + std::wstring ws; + ws += static_cast(0xD800); + std::string result; + result.resize(4); + EXPECT_FALSE(WideToUtf8(ws, result).IsOK()); +} + +#endif // !_WIN32 + } // namespace test } // namespace onnxruntime diff --git a/onnxruntime/test/providers/cpu/text/string_normalizer_test.cc b/onnxruntime/test/providers/cpu/text/string_normalizer_test.cc index 724fdb078e2fd..229d23926c97f 100644 --- a/onnxruntime/test/providers/cpu/text/string_normalizer_test.cc +++ b/onnxruntime/test/providers/cpu/text/string_normalizer_test.cc @@ -76,6 +76,7 @@ TEST(ContribOpTest, StringNormalizerSensitiveFilterOutNoCase) { test.Run(OpTester::ExpectResult::kExpectSuccess); } +#ifndef ORT_IOS TEST(ContribOpTest, StringNormalizerSensitiveFilterOutLower) { // - casesensitive approach // - filter out monday @@ -211,9 +212,6 @@ TEST(ContribOpTest, StringNormalizerSensitiveFilterOutUpperEmptyCase) { test.Run(OpTester::ExpectResult::kExpectSuccess); } -// Fails on iOS because necessary locales are not installed -// MacOS runs fine. -#ifndef ORT_IOS TEST(ContribOpTest, StringNormalizerSensitiveFilterOutUpperSameOutput) { // Empty output case // - casesensitive approach @@ -232,5 +230,371 @@ TEST(ContribOpTest, StringNormalizerSensitiveFilterOutUpperSameOutput) { } #endif +// ============================================================ +// Additional tests for coverage gaps +// ============================================================ + +#ifndef ORT_IOS +TEST(ContribOpTest, StringNormalizerDefaultIsCaseSensitiveIsFalse) { + // Omit is_case_sensitive and rely on the schema default of false. + OpTester test("StringNormalizer", opset_ver, domain); + test.AddAttribute("stopwords", std::vector{"monday"}); + std::vector dims{3}; + std::vector input = {"Monday", "Tuesday", "Wednesday"}; + test.AddInput("T", dims, input); + + std::vector output = {"Tuesday", "Wednesday"}; + test.AddOutput("Y", {2}, output); + test.Run(OpTester::ExpectResult::kExpectSuccess); +} + +TEST(ContribOpTest, StringNormalizerInsensitiveFilterOutLower) { + // Case-insensitive filtering + LOWER case change. + // This exercises the can_reuse_wide fast path (compare_caseaction_ == LOWER == case_change_action_). + // Tests French (accented), German (umlaut/eszett), Russian, Chinese. + OpTester test("StringNormalizer", opset_ver, domain); + InitTestAttr(test, "LOWER", false, {"Понедельник", "Besançon"}, test_locale); + std::vector dims{6}; + std::vector input = {"ПОНЕДЕЛЬНИК", // matches "Понедельник" case-insensitively + "BESANÇON", // matches "Besançon" case-insensitively + "École élémentaire", + "mit freundlichen grüßen", + "中文", + "Tuesday"}; + test.AddInput("T", dims, input); + + std::vector output = {"école élémentaire", + "mit freundlichen grüßen", // ß stays ß when lowercased + "中文", // Chinese has no case + "tuesday"}; + test.AddOutput("Y", {4}, output); + test.Run(OpTester::ExpectResult::kExpectSuccess); +} + +TEST(ContribOpTest, StringNormalizerInsensitiveFilterOutNone) { + // Case-insensitive filtering + NO case change. + // Strings matching stopwords are removed; survivors keep original case. + // Tests that Cyrillic and accented Latin stopwords match case-insensitively. + OpTester test("StringNormalizer", opset_ver, domain); + InitTestAttr(test, "NONE", false, {"понедельник", "école élémentaire"}, test_locale); + std::vector dims{5}; + std::vector input = {"Понедельник", // matches "понедельник" + "École Élémentaire", // matches "école élémentaire" + "Besançon", + "中文", + "Thursday"}; + test.AddInput("T", dims, input); + + // Filtered strings are removed; survivors keep original case + std::vector output = {"Besançon", "中文", "Thursday"}; + test.AddOutput("Y", {3}, output); + test.Run(OpTester::ExpectResult::kExpectSuccess); +} + +TEST(ContribOpTest, StringNormalizerInsensitiveNoStopwordsLower) { + // Case-insensitive, no stopwords, LOWER case change. + // Exercises output_no_filtering path with multilingual input. + OpTester test("StringNormalizer", opset_ver, domain); + InitTestAttr(test, "LOWER", false, {}, test_locale); + std::vector dims{5}; + std::vector input = {"BESANÇON", + "ÉCOLE ÉLÉMENTAIRE", + "ПОНЕДЕЛЬНИК", + "MIT FREUNDLICHEN GRÜßEN", + "中文"}; + test.AddInput("T", dims, input); + + std::vector output = {"besançon", + "école élémentaire", + "понедельник", + "mit freundlichen grüßen", + "中文"}; + test.AddOutput("Y", {5}, output); + test.Run(OpTester::ExpectResult::kExpectSuccess); +} + +TEST(ContribOpTest, StringNormalizerInsensitiveFilterUpperMultilingual) { + // Case-insensitive filtering + UPPER case change (case_change != compare_caseaction_). + // Exercises the output_filtered_with_wide fallback (cannot reuse cached lowercase wide forms). + OpTester test("StringNormalizer", opset_ver, domain); + InitTestAttr(test, "UPPER", false, {"besançon", "中文"}, test_locale); + std::vector dims{5}; + std::vector input = {"Besançon", // matches "besançon" + "École élémentaire", + "Понедельник", + "mit freundlichen grüßen", + "中文"}; // matches "中文" (no case, exact match) + test.AddInput("T", dims, input); + + std::vector output = {"ÉCOLE ÉLÉMENTAIRE", + "ПОНЕДЕЛЬНИК", + // Eszett behavior differs by platform +#ifdef __wasm__ + "MIT FREUNDLICHEN GRÜẞEN" +#else + "MIT FREUNDLICHEN GRÜßEN" +#endif + }; + test.AddOutput("Y", {3}, output); + test.Run(OpTester::ExpectResult::kExpectSuccess); +} + +TEST(ContribOpTest, StringNormalizerEmptyStringInInput) { + // Input contains empty strings — should not crash or produce invalid output. + OpTester test("StringNormalizer", opset_ver, domain); + InitTestAttr(test, "UPPER", true, {}, test_locale); + std::vector dims{3}; + std::vector input = {"hello", "", "world"}; + test.AddInput("T", dims, input); + + std::vector output = {"HELLO", "", "WORLD"}; + test.AddOutput("Y", {3}, output); + test.Run(OpTester::ExpectResult::kExpectSuccess); +} + +TEST(ContribOpTest, StringNormalizerSingleElement) { + // Single-element input tensor with multi-byte UTF-8. + OpTester test("StringNormalizer", opset_ver, domain); + InitTestAttr(test, "LOWER", true, {}, test_locale); + std::vector dims{1}; + std::vector input = {"ÉCOLE"}; + test.AddInput("T", dims, input); + + std::vector output = {"école"}; + test.AddOutput("Y", {1}, output); + test.Run(OpTester::ExpectResult::kExpectSuccess); +} + +TEST(ContribOpTest, StringNormalizerInsensitiveAllFilteredOutMultilingual) { + // Case-insensitive: all strings match stopwords → output is [1] with empty string. + // Uses Cyrillic and Chinese stopwords. + OpTester test("StringNormalizer", opset_ver, domain); + InitTestAttr(test, "UPPER", false, {"понедельник", "中文", "grüßen"}, test_locale); + std::vector dims{3}; + std::vector input = {"ПОНЕДЕЛЬНИК", "中文", "Grüßen"}; + test.AddInput("T", dims, input); + + std::vector output{""}; + test.AddOutput("Y", {1}, output); + test.Run(OpTester::ExpectResult::kExpectSuccess); +} + +TEST(ContribOpTest, StringNormalizerInsensitiveMixedCaseStopwords) { + // Stopwords given in mixed case with accented characters should still match. + OpTester test("StringNormalizer", opset_ver, domain); + InitTestAttr(test, "NONE", false, {"ПОНЕДЕЛЬНИК", "École Élémentaire"}, test_locale); + std::vector dims{4}; + std::vector input = {"понедельник", // matches "ПОНЕДЕЛЬНИК" + "école élémentaire", // matches "École Élémentaire" + "Besançon", + "中文"}; + test.AddInput("T", dims, input); + + std::vector output = {"Besançon", "中文"}; + test.AddOutput("Y", {2}, output); + test.Run(OpTester::ExpectResult::kExpectSuccess); +} + +TEST(ContribOpTest, StringNormalizer2DInputWithFilteringMultilingual) { + // 2D shape [1, C] with filtering using multilingual input. + OpTester test("StringNormalizer", opset_ver, domain); + InitTestAttr(test, "LOWER", true, {"Понедельник"}, test_locale); + std::vector dims{1, 4}; + std::vector input = {"Понедельник", "BESANÇON", "中文", "ÉCOLE"}; + test.AddInput("T", dims, input); + + std::vector output = {"besançon", "中文", "école"}; + test.AddOutput("Y", {1, 3}, output); + test.Run(OpTester::ExpectResult::kExpectSuccess); +} +#endif + +TEST(ContribOpTest, StringNormalizer2DInputAllFilteredOut) { + // 2D shape [1, C] with all filtered → output shape [1, 1] with empty string. + OpTester test("StringNormalizer", opset_ver, domain); + InitTestAttr(test, "NONE", true, {"中文", "Понедельник"}, test_locale); + std::vector dims{1, 2}; + std::vector input = {"中文", "Понедельник"}; + test.AddInput("T", dims, input); + + std::vector output{""}; + test.AddOutput("Y", {1, 1}, output); + test.Run(OpTester::ExpectResult::kExpectSuccess); +} + +TEST(ContribOpTest, StringNormalizerInvalidDimensions3D) { + // Input with 3 dimensions -> bypass shape metadata so the kernel validation path runs. + OpTester test("StringNormalizer", opset_ver, domain); + test.AddShapeToTensorData(false); + InitTestAttr(test, "NONE", true, {}, test_locale); + std::vector dims{1, 1, 2}; + std::vector input = {"hello", "world"}; + test.AddInput("T", dims, input); + test.AddOutput("Y", {1, 1, 2}, input); + test.Run(OpTester::ExpectResult::kExpectFailure, + "Input dimensions are either[C > 0] or [1][C > 0] allowed"); +} + +TEST(ContribOpTest, StringNormalizerInvalidDimensions2DFirstNotOne) { + // 2D input with first dim != 1 -> bypass shape metadata so the kernel validation path runs. + OpTester test("StringNormalizer", opset_ver, domain); + test.AddShapeToTensorData(false); + InitTestAttr(test, "NONE", true, {}, test_locale); + std::vector dims{2, 2}; + std::vector input = {"a", "b", "c", "d"}; + test.AddInput("T", dims, input); + test.AddOutput("Y", {2, 2}, input); + test.Run(OpTester::ExpectResult::kExpectFailure, + "Input dimensions are either[C > 0] or [1][C > 0] allowed"); +} + +TEST(ContribOpTest, StringNormalizerEmpty1DInputRejectedForCompatibility) { + // Preserve current ORT behavior for empty 1D input. + OpTester test("StringNormalizer", opset_ver, domain); + InitTestAttr(test, "NONE", true, {}, test_locale); + std::vector dims{0}; + std::vector input{}; + test.AddInput("T", dims, input); + test.AddOutput("Y", {0}, input); + test.Run(OpTester::ExpectResult::kExpectFailure, + "Single dimension value must be greater than 0"); +} + +TEST(ContribOpTest, StringNormalizerEmpty2DInputRejectedForCompatibility) { + // Preserve current ORT behavior for empty [1, 0] input. + OpTester test("StringNormalizer", opset_ver, domain); + InitTestAttr(test, "NONE", true, {}, test_locale); + std::vector dims{1, 0}; + std::vector input{}; + test.AddInput("T", dims, input); + test.AddOutput("Y", {1, 0}, input); + test.Run(OpTester::ExpectResult::kExpectFailure, + "Input dimensions are either[C > 0] or [1][C > 0] allowed"); +} + +TEST(ContribOpTest, StringNormalizerInvalidCaseChangeAction) { + // Invalid case_change_action should be rejected during kernel construction. + OpTester test("StringNormalizer", opset_ver, domain); + InitTestAttr(test, "TITLE", true, {}, test_locale); + std::vector dims{1}; + std::vector input{"hello"}; + test.AddInput("T", dims, input); + test.AddOutput("Y", dims, input); + test.Run(OpTester::ExpectResult::kExpectFailure, + "attribute case_change_action has invalid value"); +} + +TEST(ContribOpTest, StringNormalizerInvalidLocale) { + // Invalid locale should be rejected when locale-sensitive processing is required + // on platforms that validate locale names. On wasm, locale construction accepts + // arbitrary names, so this path succeeds and UPPER still applies. + OpTester test("StringNormalizer", opset_ver, domain); + InitTestAttr(test, "UPPER", true, {}, "ort_invalid_locale_for_test"); + std::vector dims{1}; + std::vector input{"hello"}; + test.AddInput("T", dims, input); +#ifdef __wasm__ + test.AddOutput("Y", dims, std::vector{"HELLO"}); + test.Run(OpTester::ExpectResult::kExpectSuccess); +#else + test.AddOutput("Y", dims, input); + test.Run(OpTester::ExpectResult::kExpectFailure, + "Failed to construct locale with name:"); +#endif +} + +TEST(ContribOpTest, StringNormalizerInvalidLocaleIgnoredWhenUnused) { + // Invalid locale should not matter when the runtime stays on the UTF-8 passthrough fast path. + OpTester test("StringNormalizer", opset_ver, domain); + InitTestAttr(test, "NONE", false, {}, "ort_invalid_locale_for_test"); + std::vector dims{1}; + std::vector input{"hello"}; + test.AddInput("T", dims, input); + test.AddOutput("Y", dims, input); + test.Run(OpTester::ExpectResult::kExpectSuccess); +} + +TEST(ContribOpTest, StringNormalizerPassthroughRejectsInvalidUtf8) { + // Byte-only passthrough path should still validate UTF-8 input. + OpTester test("StringNormalizer", opset_ver, domain); + InitTestAttr(test, "NONE", true, {}, test_locale); + std::vector dims{1}; + std::vector input{std::string("\xF0\x28\x8C\x28", 4)}; + test.AddInput("T", dims, input); + test.AddOutput("Y", dims, input); + test.Run(OpTester::ExpectResult::kExpectFailure, + "Input strings must be valid UTF-8"); +} + +TEST(ContribOpTest, StringNormalizerSensitiveFilteringRejectsInvalidUtf8) { + // Case-sensitive filtering path should validate UTF-8 even when no wchar conversion is needed. + OpTester test("StringNormalizer", opset_ver, domain); + InitTestAttr(test, "NONE", true, {"keep"}, test_locale); + std::vector dims{2}; + std::vector input{"keep", std::string("\xF0\x28\x8C\x28", 4)}; + test.AddInput("T", dims, input); + test.AddOutput("Y", {1}, std::vector{std::string("\xF0\x28\x8C\x28", 4)}); + test.Run(OpTester::ExpectResult::kExpectFailure, + "Input strings must be valid UTF-8"); +} + +#ifndef ORT_IOS +TEST(ContribOpTest, StringNormalizerGermanEszettLower) { + // German Eszett (ß) lowercasing: ß should remain ß. + // This tests the converter and case logic with the problematic German character. + OpTester test("StringNormalizer", opset_ver, domain); + InitTestAttr(test, "LOWER", true, {}, test_locale); + std::vector dims{2}; + std::vector input = {"GRÜßEN", "STRAßE"}; + test.AddInput("T", dims, input); + + std::vector output = {"grüßen", "straße"}; + test.AddOutput("Y", {2}, output); + test.Run(OpTester::ExpectResult::kExpectSuccess); +} + +TEST(ContribOpTest, StringNormalizerInsensitiveGermanEszettFilter) { + // Case-insensitive filtering with German Eszett in stopwords. + // "grüßen" lowercased stays "grüßen", should match stopword "grüßen". + OpTester test("StringNormalizer", opset_ver, domain); + InitTestAttr(test, "NONE", false, {"grüßen"}, test_locale); + std::vector dims{3}; + std::vector input = {"Grüßen", "Straße", "中文"}; + test.AddInput("T", dims, input); + + // "Grüßen" lowercased → "grüßen" → matches stopword + std::vector output = {"Straße", "中文"}; + test.AddOutput("Y", {2}, output); + test.Run(OpTester::ExpectResult::kExpectSuccess); +} + +TEST(ContribOpTest, StringNormalizerCyrillicCaseChange) { + // Full Cyrillic case conversion test. + OpTester test("StringNormalizer", opset_ver, domain); + InitTestAttr(test, "UPPER", true, {}, test_locale); + std::vector dims{3}; + std::vector input = {"понедельник", "Вторник", "среда"}; + test.AddInput("T", dims, input); + + std::vector output = {"ПОНЕДЕЛЬНИК", "ВТОРНИК", "СРЕДА"}; + test.AddOutput("Y", {3}, output); + test.Run(OpTester::ExpectResult::kExpectSuccess); +} +#endif + +TEST(ContribOpTest, StringNormalizerNoStopwordsNoCaseChange) { + // No stopwords, NONE case change → pure passthrough (fast path). + // Tests with multilingual content to ensure passthrough preserves bytes exactly. + OpTester test("StringNormalizer", opset_ver, domain); + InitTestAttr(test, "NONE", true, {}, test_locale); + std::vector dims{4}; + std::vector input = {"Besançon", "Понедельник", "中文", "grüßen"}; + test.AddInput("T", dims, input); + + std::vector output = {"Besançon", "Понедельник", "中文", "grüßen"}; + test.AddOutput("Y", {4}, output); + test.Run(OpTester::ExpectResult::kExpectSuccess); +} + } // namespace test } // namespace onnxruntime From bf76a0b72b785d6c46780a49c3216106365bba1f Mon Sep 17 00:00:00 2001 From: Rishi Dave <62260675+Rishi-Dave@users.noreply.github.com> Date: Thu, 7 May 2026 10:14:10 -0700 Subject: [PATCH 32/34] feat(quantization): add calibration cache to quantize_static (#28221) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## Summary - Add an optional `calibration_cache_path` parameter to `quantize_static()` so users can save and reload the calibration result (`TensorsData`) across runs. - Avoids re-running the expensive calibration inference pass when iterating on post-calibration options such as `nodes_to_exclude`, `activation_type`, or `weight_type`. - Cache format is JSON, mirroring the encoder already used by `write_calibration_table` — no new serialization surface area. ## Motivation Fixes #21908. Users commonly re-run `quantize_static` multiple times on the same model and calibration dataset while varying the set of excluded nodes or the quant types, to trade off accuracy vs. speed. Today, every call repeats the full calibration inference loop even though the calibration result is identical, which is costly on large calibration datasets. There was no supported way to persist the computed tensor ranges — `write_calibration_table` writes a lossy table (drops histogram data) and has no paired reader. This PR closes that gap. ## Changes - `python/tools/quantization/calibrate.py`: - Add `TensorData.from_dict` and `TensorsData.from_dict` classmethods (inverse of existing `to_dict`). - Add module-level `_CalibrationCacheEncoder(json.JSONEncoder)`, `save_tensors_data(tensors, path)`, and `load_tensors_data(path)`. The encoder handles `TensorData`/`TensorsData`/`np.ndarray`/`CalibrationMethod`/numpy scalars. Writes are atomic (tmp file + `os.replace`) and auto-create parent directories. - `python/tools/quantization/quantize.py`: - `quantize_static` gains `calibration_cache_path: str | Path | None = None`. If the path exists, calibration is skipped and ranges are loaded from the cache. If the path is new, calibration runs and the result is saved. Raises `ValueError` if the cached `calibration_method` does not match the caller's `calibrate_method`. - `calibration_data_reader` becomes optional; at least one of it or an existing cache must be provided, else `ValueError`. - `python/tools/quantization/__init__.py`: export `TensorData`, `TensorsData`, `save_tensors_data`, `load_tensors_data`. - Tests: new `TestCalibrationCache` in `test/python/quantization/test_calibration.py` covering MinMax roundtrip, Entropy roundtrip (with histogram), missing-path error, parent-dir auto-creation, numpy scalar `bins` handling, method-mismatch guard, end-to-end `quantize_static` cache hit/miss, and `ValueError` when neither reader nor cache is provided. ## Test Plan - `python -m pytest onnxruntime/test/python/quantization/test_calibration.py::TestCalibrationCache -v` - `python -m pytest onnxruntime/test/python/quantization/test_calibration.py::TestCalibrateMinMaxCalibrator -v` (regression) - `lintrunner -a` on changed files: clean. ## Backward Compatibility `calibration_data_reader` changes from required-positional to optional-keyword. Existing call sites — whether positional or keyword — continue to work unchanged. The new behavior is only engaged when `calibration_cache_path` is provided. --- .../python/tools/quantization/__init__.py | 4 + .../python/tools/quantization/calibrate.py | 86 ++++++ .../python/tools/quantization/quant_utils.py | 17 +- .../python/tools/quantization/quantize.py | 139 ++++++--- .../python/quantization/test_calibration.py | 284 +++++++++++++++++- 5 files changed, 475 insertions(+), 55 deletions(-) diff --git a/onnxruntime/python/tools/quantization/__init__.py b/onnxruntime/python/tools/quantization/__init__.py index ac99de348f612..50b0bd08ae360 100644 --- a/onnxruntime/python/tools/quantization/__init__.py +++ b/onnxruntime/python/tools/quantization/__init__.py @@ -3,7 +3,11 @@ CalibrationDataReader, CalibrationMethod, MinMaxCalibrater, + TensorData, + TensorsData, create_calibrator, + load_tensors_data, + save_tensors_data, ) from .qdq_quantizer import QDQQuantizer # noqa: F401 from .quant_utils import QuantFormat, QuantType, write_calibration_table # noqa: F401 diff --git a/onnxruntime/python/tools/quantization/calibrate.py b/onnxruntime/python/tools/quantization/calibrate.py index 05a5b0873d93d..305804661cf64 100644 --- a/onnxruntime/python/tools/quantization/calibrate.py +++ b/onnxruntime/python/tools/quantization/calibrate.py @@ -5,9 +5,12 @@ # license information. # -------------------------------------------------------------------------- import abc +import contextlib import copy import itertools +import json import os +import tempfile import uuid from collections.abc import Sequence from enum import Enum @@ -98,6 +101,21 @@ def to_dict(self): data["CLS"] = self.__class__.__name__ return data + @classmethod + def from_dict(cls, d: dict) -> "TensorData": + """Reconstruct a TensorData from a dict produced by to_dict().""" + kwargs = {} + for k, v in d.items(): + if k == "CLS": + continue + value = v + if isinstance(value, dict) and value.get("CLS") == "numpy.array": + value = np.array(value["data"], dtype=np.dtype(value["dtype"])) + elif k in cls._floats and isinstance(value, (int, float)): + value = np.array(value, dtype=np.float32) + kwargs[k] = value + return cls(**kwargs) + class TensorsData: def __init__(self, calibration_method, data: dict[str, TensorData | tuple]): @@ -150,6 +168,18 @@ def to_dict(self): } return data + @classmethod + def from_dict(cls, d: dict) -> "TensorsData": + """Reconstruct a TensorsData from a dict produced by to_dict().""" + method_val = d["calibration_method"] + if isinstance(method_val, dict) and method_val.get("CLS") == "CalibrationMethod": + name = method_val["value"].split(".")[-1] + method = CalibrationMethod[name] + else: + method = method_val + reconstructed = {k: TensorData.from_dict(v) for k, v in d["data"].items()} + return cls(method, reconstructed) + class CalibrationMethod(Enum): MinMax = 0 @@ -184,6 +214,62 @@ def set_range(self, start_index: int, end_index: int): raise NotImplementedError +class CalibrationCacheEncoder(json.JSONEncoder): + """Shared JSON encoder for calibration caches. + + Handles numpy ndarrays and numpy scalar types (integer/floating) so + calibration JSON output is consistent across ``save_tensors_data`` and + ``quant_utils.write_calibration_table``. + """ + + def default(self, obj): + if isinstance(obj, (TensorData, TensorsData)): + return obj.to_dict() + if isinstance(obj, np.ndarray): + return {"data": obj.tolist(), "dtype": str(obj.dtype), "CLS": "numpy.array"} + if isinstance(obj, CalibrationMethod): + return {"CLS": obj.__class__.__name__, "value": str(obj)} + if isinstance(obj, np.integer): + return int(obj) + if isinstance(obj, np.floating): + return float(obj) + return json.JSONEncoder.default(self, obj) + + +def save_tensors_data(tensors_data: "TensorsData", path: "str | Path", *, smooth_quant: bool = False) -> None: + """Serialize calibration tensor ranges to a JSON file at *path*. + + :param smooth_quant: whether the producing run used SmoothQuant. Stored in + the cache so a later load can detect a mismatch and recompute. + """ + path = Path(path) + path.parent.mkdir(parents=True, exist_ok=True) + fd, tmp_name = tempfile.mkstemp(dir=path.parent, prefix=".calibcache_", suffix=".tmp") + try: + with os.fdopen(fd, "w") as f: + payload = tensors_data.to_dict() + payload["smooth_quant"] = smooth_quant + json.dump(payload, f, cls=CalibrationCacheEncoder) + f.flush() + os.replace(tmp_name, path) + except BaseException: + with contextlib.suppress(FileNotFoundError): + os.unlink(tmp_name) + raise + + +def load_tensors_data(path: "str | Path") -> "TensorsData": + """Load calibration tensor ranges from a JSON file written by save_tensors_data().""" + path = Path(path) + if not path.exists(): + raise FileNotFoundError(f"Calibration cache not found: {path}") + if not path.is_file(): + raise ValueError(f"Calibration cache path is not a file: {path}") + with path.open("r") as f: + d = json.load(f) + return TensorsData.from_dict(d) + + class CalibraterBase: def __init__( self, diff --git a/onnxruntime/python/tools/quantization/quant_utils.py b/onnxruntime/python/tools/quantization/quant_utils.py index 0ce1e1a0d75de..c8deb0d3e395a 100644 --- a/onnxruntime/python/tools/quantization/quant_utils.py +++ b/onnxruntime/python/tools/quantization/quant_utils.py @@ -796,21 +796,14 @@ def write_calibration_table(calibration_cache, dir="."): import onnxruntime.quantization.CalTableFlatBuffers.KeyValue as KeyValue # noqa: PLC0415 import onnxruntime.quantization.CalTableFlatBuffers.TrtTable as TrtTable # noqa: PLC0415 - from onnxruntime.quantization.calibrate import CalibrationMethod, TensorData, TensorsData # noqa: PLC0415 + + # Use the shared encoder from calibrate.py so write_calibration_table and + # save_tensors_data produce identical JSON for numpy scalar/array values. + from onnxruntime.quantization.calibrate import CalibrationCacheEncoder # noqa: PLC0415 logging.info(f"calibration cache: {calibration_cache}") - class MyEncoder(json.JSONEncoder): - def default(self, obj): - if isinstance(obj, (TensorData, TensorsData)): - return obj.to_dict() - if isinstance(obj, np.ndarray): - return {"data": obj.tolist(), "dtype": str(obj.dtype), "CLS": "numpy.array"} - if isinstance(obj, CalibrationMethod): - return {"CLS": obj.__class__.__name__, "value": str(obj)} - return json.JSONEncoder.default(self, obj) - - json_data = json.dumps(calibration_cache, cls=MyEncoder) + json_data = json.dumps(calibration_cache, cls=CalibrationCacheEncoder) with open(os.path.join(dir, "calibration.json"), "w") as file: file.write(json_data) # use `json.loads` to do the reverse diff --git a/onnxruntime/python/tools/quantization/quantize.py b/onnxruntime/python/tools/quantization/quantize.py index b8b239b85e7ad..d6b2ecb2b17ed 100644 --- a/onnxruntime/python/tools/quantization/quantize.py +++ b/onnxruntime/python/tools/quantization/quantize.py @@ -6,6 +6,7 @@ from __future__ import annotations import copy +import json import logging import tempfile from collections.abc import Callable @@ -14,7 +15,14 @@ import onnx -from .calibrate import CalibrationDataReader, CalibrationMethod, TensorsData, create_calibrator +from .calibrate import ( + CalibrationDataReader, + CalibrationMethod, + TensorsData, + create_calibrator, + load_tensors_data, + save_tensors_data, +) from .onnx_quantizer import ONNXQuantizer from .qdq_quantizer import QDQQuantizer from .quant_utils import ( @@ -479,7 +487,7 @@ def check_static_quant_arguments(quant_format: QuantFormat, activation_type: Qua def quantize_static( model_input: str | Path | onnx.ModelProto, model_output: str | Path, - calibration_data_reader: CalibrationDataReader, + calibration_data_reader: CalibrationDataReader | None = None, quant_format=QuantFormat.QDQ, op_types_to_quantize=None, per_channel=False, @@ -492,6 +500,7 @@ def quantize_static( calibrate_method=CalibrationMethod.MinMax, calibration_providers=None, extra_options=None, + calibration_cache_path: str | Path | None = None, ): """ Given an onnx model and calibration data reader, create a quantized onnx model and save it into a file @@ -506,7 +515,13 @@ def quantize_static( model_output: file path of quantized model calibration_data_reader: a calibration data reader. It enumerates calibration data and generates inputs for the - original model. + original model. May be None if calibration_cache_path points to an + existing cache file. + calibration_cache_path: optional path to a JSON calibration cache. If + the file already exists, calibration inference is skipped and the + cached tensor ranges are loaded instead. If the file does not yet + exist, calibration runs normally and the result is saved to this + path for future reuse. quant_format: QuantFormat{QOperator, QDQ}. QOperator format quantizes the model with quantized operators directly. QDQ format quantize the model by inserting QuantizeLinear/DeQuantizeLinear on the tensor. @@ -673,6 +688,11 @@ def quantize_static( } if extra_options.get("SmoothQuant", False): + if calibration_data_reader is None: + raise ValueError( + "SmoothQuant requires a non-None calibration_data_reader; the calibration cache " + "stores per-tensor ranges only and cannot drive the SmoothQuant transform." + ) import importlib # noqa: PLC0415 try: @@ -704,48 +724,83 @@ def inc_dataloader(): if is_model_updated: model = updated_model - with tempfile.TemporaryDirectory(prefix="ort.quant.") as quant_tmp_dir: - if is_model_updated: - # Update model_input and avoid to use the original one - model_input = copy.deepcopy(model) - - if isinstance(model_input, onnx.ModelProto): - output_path = Path(quant_tmp_dir).joinpath("model_input.onnx").as_posix() - onnx.save_model( - model_input, - output_path, - save_as_external_data=True, + _cache_path = Path(calibration_cache_path) if calibration_cache_path is not None else None + if _cache_path is not None and _cache_path.exists() and not _cache_path.is_file(): + raise ValueError(f"calibration_cache_path is not a file: {_cache_path}") + _cache_hit = _cache_path is not None and _cache_path.is_file() + _smooth_quant = bool(extra_options.get("SmoothQuant", False)) + + if _cache_hit: + with _cache_path.open("r") as _f: + _raw = json.load(_f) + _cached_sq = bool(_raw.get("smooth_quant", False)) + if _cached_sq != _smooth_quant: + logging.warning( + "Calibration cache at %s was produced with smooth_quant=%s; " + "current run uses smooth_quant=%s. Recomputing ranges and overwriting cache.", + _cache_path, + _cached_sq, + _smooth_quant, ) - model_input = output_path - - calibrator = create_calibrator( - Path(model_input), - op_types_to_quantize, - augmented_model_path=Path(quant_tmp_dir).joinpath("augmented_model.onnx").as_posix(), - calibrate_method=calibrate_method, - use_external_data_format=use_external_data_format, - providers=calibration_providers, - extra_options=calib_extra_options, - ) - - stride = extra_options.get("CalibStridedMinMax", None) - if stride: - total_data_size = len(calibration_data_reader) - if total_data_size % stride != 0: - raise ValueError(f"Total data size ({total_data_size}) is not divisible by stride size ({stride}).") - - for start in range(0, total_data_size, stride): - end_index = start + stride - calibration_data_reader.set_range(start_index=start, end_index=end_index) - calibrator.collect_data(calibration_data_reader) + _cache_hit = False else: - calibrator.collect_data(calibration_data_reader) - tensors_range = calibrator.compute_data() - if not isinstance(tensors_range, TensorsData): - raise TypeError( - f"Unexpected type {type(tensors_range)} for tensors_range and calibrator={type(calibrator)}." + tensors_range = load_tensors_data(_cache_path) + if tensors_range.calibration_method != calibrate_method: + raise ValueError( + f"Calibration cache at {_cache_path} was produced with " + f"{tensors_range.calibration_method}, but quantize_static was called " + f"with calibrate_method={calibrate_method}. Delete the cache or " + f"pass a matching calibrate_method." + ) + + if not _cache_hit: + if calibration_data_reader is None: + raise ValueError("Either calibration_data_reader or an existing calibration_cache_path must be provided.") + with tempfile.TemporaryDirectory(prefix="ort.quant.") as quant_tmp_dir: + if is_model_updated: + # Update model_input and avoid to use the original one + model_input = copy.deepcopy(model) + + if isinstance(model_input, onnx.ModelProto): + output_path = Path(quant_tmp_dir).joinpath("model_input.onnx").as_posix() + onnx.save_model( + model_input, + output_path, + save_as_external_data=True, + ) + model_input = output_path + + calibrator = create_calibrator( + Path(model_input), + op_types_to_quantize, + augmented_model_path=Path(quant_tmp_dir).joinpath("augmented_model.onnx").as_posix(), + calibrate_method=calibrate_method, + use_external_data_format=use_external_data_format, + providers=calibration_providers, + extra_options=calib_extra_options, ) - del calibrator + + stride = extra_options.get("CalibStridedMinMax", None) + if stride: + total_data_size = len(calibration_data_reader) + if total_data_size % stride != 0: + raise ValueError(f"Total data size ({total_data_size}) is not divisible by stride size ({stride}).") + + for start in range(0, total_data_size, stride): + end_index = start + stride + calibration_data_reader.set_range(start_index=start, end_index=end_index) + calibrator.collect_data(calibration_data_reader) + else: + calibrator.collect_data(calibration_data_reader) + tensors_range = calibrator.compute_data() + if not isinstance(tensors_range, TensorsData): + raise TypeError( + f"Unexpected type {type(tensors_range)} for tensors_range and calibrator={type(calibrator)}." + ) + del calibrator + + if _cache_path is not None: + save_tensors_data(tensors_range, _cache_path, smooth_quant=_smooth_quant) check_static_quant_arguments(quant_format, activation_type, weight_type) diff --git a/onnxruntime/test/python/quantization/test_calibration.py b/onnxruntime/test/python/quantization/test_calibration.py index 60c5f9d404258..7afb77f9dfab2 100644 --- a/onnxruntime/test/python/quantization/test_calibration.py +++ b/onnxruntime/test/python/quantization/test_calibration.py @@ -5,6 +5,8 @@ # license information. # -------------------------------------------------------------------------- +import json +import logging import tempfile import unittest from pathlib import Path @@ -14,7 +16,16 @@ from onnx import TensorProto, helper, numpy_helper import onnxruntime -from onnxruntime.quantization.calibrate import CalibrationDataReader, CalibrationMethod, create_calibrator +from onnxruntime.quantization import quantize_static +from onnxruntime.quantization.calibrate import ( + CalibrationDataReader, + CalibrationMethod, + TensorData, + TensorsData, + create_calibrator, + load_tensors_data, + save_tensors_data, +) def generate_input_initializer(tensor_shape, tensor_dtype, input_name): @@ -528,5 +539,276 @@ def test_compute_data_per_channel(self): np.testing.assert_equal(min_max, tensors_range[output_name].range_value) +class TestCalibrationCache(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls._tmp_dir = tempfile.TemporaryDirectory(prefix="test_calibration_cache.") + + @classmethod + def tearDownClass(cls): + cls._tmp_dir.cleanup() + + def _make_simple_model(self, path): + """Build a tiny Conv+Relu model for end-to-end cache tests.""" + vi_input = helper.make_tensor_value_info("input", TensorProto.FLOAT, [1, 3, 1, 3]) + vi_output = helper.make_tensor_value_info("X6", TensorProto.FLOAT, [1, 3, 1, 3]) + w1 = generate_input_initializer([3, 3, 1, 1], np.float32, "W1") + b1 = generate_input_initializer([3], np.float32, "B1") + conv_node = helper.make_node("Conv", ["input", "W1", "B1"], ["X2"], name="Conv1") + relu_node = helper.make_node("Relu", ["X2"], ["X6"], name="Relu1") + graph = helper.make_graph([conv_node, relu_node], "cache_test_graph", [vi_input], [vi_output]) + graph.initializer.add().CopyFrom(w1) + graph.initializer.add().CopyFrom(b1) + model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)]) + onnx.save(model, path) + + def test_save_load_tensors_data_minmax_roundtrip(self): + td = TensorsData( + CalibrationMethod.MinMax, + {"x": TensorData(lowest=np.array(-1.0, dtype=np.float32), highest=np.array(2.0, dtype=np.float32))}, + ) + cache_path = Path(self._tmp_dir.name) / "minmax_cache.json" + save_tensors_data(td, cache_path) + self.assertTrue(cache_path.exists()) + + loaded = load_tensors_data(cache_path) + self.assertEqual(loaded.calibration_method, CalibrationMethod.MinMax) + self.assertEqual(list(loaded.keys()), ["x"]) + lo, hi = loaded["x"].range_value + np.testing.assert_array_equal(lo, np.array(-1.0, dtype=np.float32)) + np.testing.assert_array_equal(hi, np.array(2.0, dtype=np.float32)) + self.assertEqual(lo.shape, ()) + self.assertEqual(hi.shape, ()) + + def test_save_load_tensors_data_entropy_roundtrip(self): + hist = np.array([1.0, 2.0, 3.0], dtype=np.float32) + hist_edges = np.array([0.0, 1.0, 2.0, 3.0], dtype=np.float32) + td = TensorsData( + CalibrationMethod.Entropy, + { + "y": TensorData( + lowest=np.array(-0.5, dtype=np.float32), + highest=np.array(0.5, dtype=np.float32), + hist=hist, + hist_edges=hist_edges, + ) + }, + ) + cache_path = Path(self._tmp_dir.name) / "entropy_cache.json" + save_tensors_data(td, cache_path) + + loaded = load_tensors_data(cache_path) + self.assertEqual(loaded.calibration_method, CalibrationMethod.Entropy) + lo, hi = loaded["y"].range_value + np.testing.assert_array_almost_equal(lo, np.array(-0.5, dtype=np.float32)) + np.testing.assert_array_almost_equal(hi, np.array(0.5, dtype=np.float32)) + np.testing.assert_array_almost_equal(loaded["y"].hist, hist) + np.testing.assert_array_almost_equal(loaded["y"].hist_edges, hist_edges) + + def test_load_tensors_data_invalid_path(self): + bogus = Path(self._tmp_dir.name) / "does_not_exist.json" + with self.assertRaises(FileNotFoundError): + load_tensors_data(bogus) + + def test_quantize_static_calibration_cache_path(self): + model_path = Path(self._tmp_dir.name) / "tiny_model.onnx" + self._make_simple_model(str(model_path)) + + cache_path = Path(self._tmp_dir.name) / "quant_cache.json" + out1_path = Path(self._tmp_dir.name) / "quantized1.onnx" + out2_path = Path(self._tmp_dir.name) / "quantized2.onnx" + + # First call: calibration_data_reader provided, cache written + data_reader = TestDataReader() + quantize_static( + str(model_path), + str(out1_path), + calibration_data_reader=data_reader, + calibration_cache_path=cache_path, + ) + self.assertTrue(cache_path.exists()) + td1 = load_tensors_data(cache_path) + + # Second call: no data_reader, load from cache + quantize_static( + str(model_path), + str(out2_path), + calibration_data_reader=None, + calibration_cache_path=cache_path, + ) + self.assertTrue(out2_path.exists()) + td2 = load_tensors_data(cache_path) + self.assertEqual(td1.calibration_method, td2.calibration_method) + + def test_quantize_static_no_reader_no_cache_raises(self): + model_path = Path(self._tmp_dir.name) / "tiny_model2.onnx" + self._make_simple_model(str(model_path)) + out_path = Path(self._tmp_dir.name) / "quantized_err.onnx" + + with self.assertRaises(ValueError): + quantize_static(str(model_path), str(out_path), calibration_data_reader=None) + + def test_save_tensors_data_creates_parent_dir(self): + nested_path = Path(self._tmp_dir.name) / "nested" / "dir" / "cache.json" + td = TensorsData( + CalibrationMethod.MinMax, + {"x": TensorData(lowest=np.array(-1.0, dtype=np.float32), highest=np.array(1.0, dtype=np.float32))}, + ) + save_tensors_data(td, nested_path) + self.assertTrue(nested_path.exists()) + + def test_save_tensors_data_handles_scalar_bins(self): + td = TensorsData( + CalibrationMethod.Entropy, + { + "z": TensorData( + lowest=np.array(0.0, dtype=np.float32), + highest=np.array(1.0, dtype=np.float32), + hist=np.array([1, 2], dtype=np.int64), + bins=np.int64(5), + ) + }, + ) + cache_path = Path(self._tmp_dir.name) / "scalar_bins_cache.json" + save_tensors_data(td, cache_path) + loaded = load_tensors_data(cache_path) + self.assertEqual(loaded["z"].bins, 5) + + def test_load_tensors_data_method_mismatch_raises(self): + model_path = Path(self._tmp_dir.name) / "tiny_mismatch.onnx" + self._make_simple_model(str(model_path)) + cache_path = Path(self._tmp_dir.name) / "mismatch_cache.json" + out_path = Path(self._tmp_dir.name) / "quantized_mismatch.onnx" + + data_reader = TestDataReader() + quantize_static( + str(model_path), + str(out_path), + calibration_data_reader=data_reader, + calibrate_method=CalibrationMethod.MinMax, + calibration_cache_path=cache_path, + ) + + with self.assertRaises(ValueError): + quantize_static( + str(model_path), + str(out_path), + calibration_data_reader=None, + calibrate_method=CalibrationMethod.Entropy, + calibration_cache_path=cache_path, + ) + + def test_save_tensors_data_writes_smooth_quant_field(self): + """save_tensors_data persists the smooth_quant flag in the JSON payload.""" + td = TensorsData( + CalibrationMethod.MinMax, + {"x": TensorData(lowest=np.array(-1.0, dtype=np.float32), highest=np.array(1.0, dtype=np.float32))}, + ) + for flag in (False, True): + cache_path = Path(self._tmp_dir.name) / f"sq_{flag}_cache.json" + save_tensors_data(td, cache_path, smooth_quant=flag) + with cache_path.open("r") as f: + raw = json.load(f) + self.assertIn("smooth_quant", raw) + self.assertEqual(raw["smooth_quant"], flag) + + def test_smooth_quant_mismatch_triggers_recompute(self): + """Cache produced with smooth_quant=True must not be used for a smooth_quant=False run.""" + model_path = Path(self._tmp_dir.name) / "sq_mismatch_model.onnx" + self._make_simple_model(str(model_path)) + cache_path = Path(self._tmp_dir.name) / "sq_mismatch_cache.json" + out1_path = Path(self._tmp_dir.name) / "sq_mismatch_out1.onnx" + + # Write a cache that claims smooth_quant=True by injecting the field directly. + td = TensorsData( + CalibrationMethod.MinMax, + {"x": TensorData(lowest=np.array(-1.0, dtype=np.float32), highest=np.array(1.0, dtype=np.float32))}, + ) + save_tensors_data(td, cache_path, smooth_quant=True) + with cache_path.open("r") as f: + self.assertEqual(json.load(f)["smooth_quant"], True) + + # Run with smooth_quant=False (default): the cache must be treated as a miss. + # We supply a real data_reader so recompute can proceed. + data_reader = TestDataReader() + with self.assertLogs("root", level=logging.WARNING) as log_cm: + quantize_static( + str(model_path), + str(out1_path), + calibration_data_reader=data_reader, + calibration_cache_path=cache_path, + # SmoothQuant not set -> defaults to False + ) + # At least one WARNING about the mismatch must have been emitted. + self.assertTrue( + any("smooth_quant" in msg for msg in log_cm.output), + msg=f"Expected smooth_quant warning; got: {log_cm.output}", + ) + # The rewritten cache must now have smooth_quant=False. + with cache_path.open("r") as f: + self.assertEqual(json.load(f)["smooth_quant"], False) + + def test_smooth_quant_match_produces_cache_hit(self): + """Cache with smooth_quant=False is reused when the run also uses smooth_quant=False.""" + model_path = Path(self._tmp_dir.name) / "sq_hit_model.onnx" + self._make_simple_model(str(model_path)) + cache_path = Path(self._tmp_dir.name) / "sq_hit_cache.json" + out1_path = Path(self._tmp_dir.name) / "sq_hit_out1.onnx" + out2_path = Path(self._tmp_dir.name) / "sq_hit_out2.onnx" + + # First run: write cache with smooth_quant=False (the default). + data_reader = TestDataReader() + quantize_static( + str(model_path), + str(out1_path), + calibration_data_reader=data_reader, + calibration_cache_path=cache_path, + ) + with cache_path.open("r") as f: + self.assertEqual(json.load(f)["smooth_quant"], False) + + # Second run: no data_reader, cache should be a hit (smooth_quant matches). + quantize_static( + str(model_path), + str(out2_path), + calibration_data_reader=None, + calibration_cache_path=cache_path, + ) + self.assertTrue(out2_path.exists()) + + def test_old_cache_without_smooth_quant_field_treated_as_false(self): + """A legacy cache without a smooth_quant key is assumed smooth_quant=False.""" + model_path = Path(self._tmp_dir.name) / "legacy_sq_model.onnx" + self._make_simple_model(str(model_path)) + cache_path = Path(self._tmp_dir.name) / "legacy_sq_cache.json" + out1_path = Path(self._tmp_dir.name) / "legacy_sq_out1.onnx" + out2_path = Path(self._tmp_dir.name) / "legacy_sq_out2.onnx" + + # First run: populate a real cache against the actual model so tensor names match. + data_reader = TestDataReader() + quantize_static( + str(model_path), + str(out1_path), + calibration_data_reader=data_reader, + calibration_cache_path=cache_path, + ) + + # Strip the smooth_quant field to simulate a legacy cache file. + with cache_path.open("r") as f: + raw = json.load(f) + raw.pop("smooth_quant", None) + with cache_path.open("w") as f: + json.dump(raw, f) + + # A run with smooth_quant=False (default) must treat the legacy cache as a hit (no recompute needed). + quantize_static( + str(model_path), + str(out2_path), + calibration_data_reader=None, + calibration_cache_path=cache_path, + ) + self.assertTrue(out2_path.exists()) + + if __name__ == "__main__": unittest.main() From 0a341b096271492fd02b9fc939ca7f9a9c54c466 Mon Sep 17 00:00:00 2001 From: Edward Chen <18449977+edgchen1@users.noreply.github.com> Date: Thu, 7 May 2026 11:31:46 -0700 Subject: [PATCH 33/34] [WebGPU plugin EP packaging] Remove explicit ORT package dependency (#28384) ### Description This pull request refactors how the minimum required ONNX Runtime (ORT) version is handled and communicated for the WebGPU plugin EP packages (both Python and C#). Instead of declaring a hard dependency on a specific ORT version in package manifests, the minimum compatible version is now injected into package READMEs at build time, and ORT version compatibility will be checked at runtime. The packaging scripts are updated to use a shared template utility, and the CI/test setup is adjusted accordingly. The most important changes are: **Minimum ORT Version Handling and Documentation:** * The Python and C# plugin EP packages no longer declare a hard dependency on a specific ONNX Runtime package version in their manifests (`pyproject.toml`, `.csproj`). Instead, the minimum required ORT version is injected into the package `README.md` during packaging, and users are instructed to install or reference the correct ORT version themselves. **Packaging and Build Script Refactoring:** * A new shared utility, `_packaging_utils.py`, is added for template-based file generation, and both Python (`build_wheel.py`) and C# (`pack_nuget.py`) packaging scripts are updated to use this helper for injecting the minimum ORT version into `README.md`. Duplicate template code is removed. * The C# packaging script and project file are refactored to remove the mechanism for passing the minimum ORT version via MSBuild properties, since the dependency is now documented rather than enforced at build/package time. **Test and CI Adjustments:** * The Python test pipelines for Linux, macOS, and Windows are updated to explicitly install `onnxruntime` before testing the plugin EP, since it is no longer a transitive dependency. These changes improve clarity for users about required dependencies, centralize version management, and simplify the packaging scripts. NOTE: The WebGPU EP ORT version compatibility check does not actually use the value in `./plugin-ep-webgpu/MIN_ONNXRUNTIME_VERSION` yet. That can be done in another PR. ### Motivation and Context There are multiple packages that can provide an ORT library. Even though the basic, CPU EP-only ORT package would typically be sufficient, we don't want to mandate usage of a specific package flavor. --- plugin-ep-webgpu/README.md | 4 +- plugin-ep-webgpu/_packaging_utils.py | 44 ++++++++++++++++ .../Microsoft.ML.OnnxRuntime.EP.WebGpu.csproj | 22 -------- .../README.md | 8 +++ plugin-ep-webgpu/csharp/pack_nuget.py | 36 ++++++++++--- .../WebGpuEpNuGetTest.csproj | 10 +++- plugin-ep-webgpu/python/build_wheel.py | 51 +++++-------------- .../python/onnxruntime_ep_webgpu/README.md | 12 ++++- plugin-ep-webgpu/python/pyproject.toml.in | 3 -- .../stages/plugin-linux-webgpu-test-stage.yml | 17 +++++-- .../stages/plugin-mac-webgpu-test-stage.yml | 10 +++- .../stages/plugin-win-webgpu-test-stage.yml | 14 ++++- 12 files changed, 149 insertions(+), 82 deletions(-) create mode 100644 plugin-ep-webgpu/_packaging_utils.py diff --git a/plugin-ep-webgpu/README.md b/plugin-ep-webgpu/README.md index 889fef10ae5e1..2b66c06eeefa5 100644 --- a/plugin-ep-webgpu/README.md +++ b/plugin-ep-webgpu/README.md @@ -11,7 +11,9 @@ For more information about plugin EPs, see the documentation [here](https://onnx final package version (release, dev) from this via [`tools/ci_build/github/azure-pipelines/templates/set-plugin-build-variables-step.yml`](../tools/ci_build/github/azure-pipelines/templates/set-plugin-build-variables-step.yml). - [`MIN_ONNXRUNTIME_VERSION`](MIN_ONNXRUNTIME_VERSION) — Minimum compatible core `onnxruntime` version. Single source - of truth shared by all packages built from this directory. + of truth shared by all packages built from this directory. The packages do not declare a hard dependency on a + specific ONNX Runtime package; instead, this version string is injected into each package's README at build/pack + time, and the native plugin EP code validates compatibility at registration time. - [`python/`](python/) — Sources and build script for the `onnxruntime-ep-webgpu` Python wheel. See [`python/README.md`](python/README.md) for build and test instructions. - [`csharp/`](csharp/) — Sources and packaging script for the `Microsoft.ML.OnnxRuntime.EP.WebGpu` NuGet package. See diff --git a/plugin-ep-webgpu/_packaging_utils.py b/plugin-ep-webgpu/_packaging_utils.py new file mode 100644 index 0000000000000..201b3342ff39c --- /dev/null +++ b/plugin-ep-webgpu/_packaging_utils.py @@ -0,0 +1,44 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +"""Shared utilities for the WebGPU plugin EP packaging scripts. Not a public API.""" + +from __future__ import annotations + +import re +from pathlib import Path + +# Matches "@var@" template variables. +_TEMPLATE_VARIABLE_PATTERN = re.compile(r"@(\w+)@") + + +def gen_file_from_template( + template_file: Path, output_file: Path, variable_substitutions: dict[str, str], strict: bool = True +) -> None: + """Generate a file from a template by substituting "@var@" markers with provided values. + + If `strict` is True, raises ValueError when the set of "@var@" names found in the template + does not match the keys of `variable_substitutions`. + + Note: substituted values are inserted verbatim with no awareness of the target file's syntax. + The caller is responsible for any quoting/escaping required by the target format. + """ + content = template_file.read_text(encoding="utf-8") + + variables_in_file: set[str] = set() + + def replace(match: re.Match[str]) -> str: + name = match.group(1) + variables_in_file.add(name) + return variable_substitutions.get(name, match.group(0)) + + content = _TEMPLATE_VARIABLE_PATTERN.sub(replace, content) + + if strict and variables_in_file != variable_substitutions.keys(): + provided = set(variable_substitutions.keys()) + raise ValueError( + f"Template variables and substitution keys do not match for {template_file}. " + f"Only in template: {sorted(variables_in_file - provided)}. " + f"Only in substitutions: {sorted(provided - variables_in_file)}." + ) + + output_file.write_text(content, encoding="utf-8") diff --git a/plugin-ep-webgpu/csharp/Microsoft.ML.OnnxRuntime.EP.WebGpu/Microsoft.ML.OnnxRuntime.EP.WebGpu.csproj b/plugin-ep-webgpu/csharp/Microsoft.ML.OnnxRuntime.EP.WebGpu/Microsoft.ML.OnnxRuntime.EP.WebGpu.csproj index 94be6bec6ea46..58860c46b9c16 100644 --- a/plugin-ep-webgpu/csharp/Microsoft.ML.OnnxRuntime.EP.WebGpu/Microsoft.ML.OnnxRuntime.EP.WebGpu.csproj +++ b/plugin-ep-webgpu/csharp/Microsoft.ML.OnnxRuntime.EP.WebGpu/Microsoft.ML.OnnxRuntime.EP.WebGpu.csproj @@ -26,28 +26,6 @@ snupkg
- - - $(MSBuildThisFileDirectory)..\..\MIN_ONNXRUNTIME_VERSION - $([System.IO.File]::ReadAllText('$(OnnxRuntimeMinVersionFile)').Trim()) - - - - - - - - - - - diff --git a/plugin-ep-webgpu/csharp/Microsoft.ML.OnnxRuntime.EP.WebGpu/README.md b/plugin-ep-webgpu/csharp/Microsoft.ML.OnnxRuntime.EP.WebGpu/README.md index f4a717b8836d5..6b92dc909784f 100644 --- a/plugin-ep-webgpu/csharp/Microsoft.ML.OnnxRuntime.EP.WebGpu/README.md +++ b/plugin-ep-webgpu/csharp/Microsoft.ML.OnnxRuntime.EP.WebGpu/README.md @@ -2,6 +2,14 @@ WebGPU plugin Execution Provider for [ONNX Runtime](https://github.com/microsoft/onnxruntime). +### Prerequisites + +This package provides the WebGPU plugin EP only. Your project must separately reference an ONNX Runtime +core package (e.g. `Microsoft.ML.OnnxRuntime`) of version `@min_onnxruntime_version@` or later. + +If the referenced ONNX Runtime is incompatible, the plugin EP will report an error when its library is +registered. + ### Usage ```csharp diff --git a/plugin-ep-webgpu/csharp/pack_nuget.py b/plugin-ep-webgpu/csharp/pack_nuget.py index 9a29d067a4034..b1ce61c0480e2 100644 --- a/plugin-ep-webgpu/csharp/pack_nuget.py +++ b/plugin-ep-webgpu/csharp/pack_nuget.py @@ -45,6 +45,10 @@ CSPROJ = PROJECT_DIR / "Microsoft.ML.OnnxRuntime.EP.WebGpu.csproj" MIN_ORT_VERSION_FILE = SCRIPT_DIR.parent / "MIN_ONNXRUNTIME_VERSION" +# Import the shared template helper from _packaging_utils.py in the parent directory. +sys.path.insert(0, str(SCRIPT_DIR.parent)) +from _packaging_utils import gen_file_from_template # noqa: E402 (path setup must precede import) + class PackError(RuntimeError): """Raised for any user-actionable failure during packaging.""" @@ -208,14 +212,12 @@ def stage_binaries( def dotnet_common_args( staged_csproj: Path, args: argparse.Namespace, - min_ort_version_file: Path, ) -> list[str]: common = [ str(staged_csproj), "--configuration", args.configuration, f"-p:Version={args.version}", - f"-p:OnnxRuntimeMinVersionFile={min_ort_version_file}", ] if args.nuget_config: common.extend(["--configfile", str(args.nuget_config)]) @@ -223,10 +225,10 @@ def dotnet_common_args( return common -def do_build(staged_csproj: Path, staging_dir: Path, args: argparse.Namespace, min_ort_version_file: Path) -> None: +def do_build(staged_csproj: Path, staging_dir: Path, args: argparse.Namespace) -> None: print() print(f"Running dotnet build (Version={args.version}, Configuration={args.configuration})...") - cmd = ["dotnet", "build", *dotnet_common_args(staged_csproj, args, min_ort_version_file)] + cmd = ["dotnet", "build", *dotnet_common_args(staged_csproj, args)] print("+ " + " ".join(cmd)) subprocess.run(cmd, check=True) @@ -243,14 +245,13 @@ def do_pack( staged_csproj: Path, output_dir: Path, args: argparse.Namespace, - min_ort_version_file: Path, ) -> None: print() print(f"Running dotnet pack (Version={args.version}, Configuration={args.configuration})...") pack_args = [ "dotnet", "pack", - *dotnet_common_args(staged_csproj, args, min_ort_version_file), + *dotnet_common_args(staged_csproj, args), "--output", str(output_dir), ] @@ -269,6 +270,21 @@ def do_pack( print(f"Produced: {pkg.name} ({pkg.stat().st_size / (1024 * 1024):.2f} MB)") +def render_readme(staging_dir: Path, min_ort_version: str) -> None: + """Substitute the minimum ORT version into the staged README in place.""" + readme = staging_dir / "README.md" + if not readme.is_file(): + raise PackError(f"staged README not found: {readme}") + try: + gen_file_from_template( + readme, + readme, + {"min_onnxruntime_version": min_ort_version}, + ) + except ValueError as e: + raise PackError(str(e)) from e + + def run_in_staging(args: argparse.Namespace, staging_dir: Path, min_ort_version_file: Path) -> None: staged_csproj = staging_dir / "Microsoft.ML.OnnxRuntime.EP.WebGpu.csproj" output_dir: Path = args.output_dir @@ -282,12 +298,16 @@ def run_in_staging(args: argparse.Namespace, staging_dir: Path, min_ort_version_ else: stage_sources(staging_dir) stage_binaries(staging_dir, args, required_platforms) + min_ort_version = min_ort_version_file.read_text(encoding="utf-8").strip() + if not min_ort_version: + raise PackError(f"{min_ort_version_file} is empty") + render_readme(staging_dir, min_ort_version) if args.build_only: - do_build(staged_csproj, staging_dir, args, min_ort_version_file) + do_build(staged_csproj, staging_dir, args) return - do_pack(staged_csproj, output_dir, args, min_ort_version_file) + do_pack(staged_csproj, output_dir, args) print() print(f"Done. Output: {output_dir}") diff --git a/plugin-ep-webgpu/csharp/test/WebGpuEpNuGetTest/WebGpuEpNuGetTest.csproj b/plugin-ep-webgpu/csharp/test/WebGpuEpNuGetTest/WebGpuEpNuGetTest.csproj index 9554161b1e978..6162c9a33e81c 100644 --- a/plugin-ep-webgpu/csharp/test/WebGpuEpNuGetTest/WebGpuEpNuGetTest.csproj +++ b/plugin-ep-webgpu/csharp/test/WebGpuEpNuGetTest/WebGpuEpNuGetTest.csproj @@ -13,10 +13,18 @@ a single-package local feed. --> *-* + + *-* - + + diff --git a/plugin-ep-webgpu/python/build_wheel.py b/plugin-ep-webgpu/python/build_wheel.py index b4357bcdfbe0f..6f19b88838bf9 100644 --- a/plugin-ep-webgpu/python/build_wheel.py +++ b/plugin-ep-webgpu/python/build_wheel.py @@ -10,7 +10,6 @@ import argparse import platform -import re import shutil import subprocess import sys @@ -20,41 +19,9 @@ SCRIPT_DIR = Path(__file__).parent MIN_ONNXRUNTIME_VERSION_FILE = SCRIPT_DIR.parent / "MIN_ONNXRUNTIME_VERSION" -# Matches "@var@" template variables. -_TEMPLATE_VARIABLE_PATTERN = re.compile(r"@(\w+)@") - - -def gen_file_from_template( - template_file: Path, output_file: Path, variable_substitutions: dict[str, str], strict: bool = True -) -> None: - """Generate a file from a template by substituting "@var@" markers with provided values. - - If `strict` is True, raises ValueError when the set of "@var@" names found in the template - does not match the keys of `variable_substitutions`. - - Note: substituted values are inserted verbatim with no awareness of the target file's syntax. - The caller is responsible for any quoting/escaping required by the target format. - """ - content = template_file.read_text(encoding="utf-8") - - variables_in_file: set[str] = set() - - def replace(match: re.Match[str]) -> str: - name = match.group(1) - variables_in_file.add(name) - return variable_substitutions.get(name, match.group(0)) - - content = _TEMPLATE_VARIABLE_PATTERN.sub(replace, content) - - if strict and variables_in_file != variable_substitutions.keys(): - provided = set(variable_substitutions.keys()) - raise ValueError( - f"Template variables and substitution keys do not match for {template_file}. " - f"Only in template: {sorted(variables_in_file - provided)}. " - f"Only in substitutions: {sorted(provided - variables_in_file)}." - ) - - output_file.write_text(content, encoding="utf-8") +# Import the shared template helper from _packaging_utils.py in the parent directory. +sys.path.insert(0, str(SCRIPT_DIR.parent)) +from _packaging_utils import gen_file_from_template # noqa: E402, I001 (path setup must precede import) # Patterns for binaries to include in the package @@ -98,15 +65,23 @@ def prepare_staging_dir(staging_dir: Path, binary_dir: Path, version: str): if not copied: raise FileNotFoundError(f"No plugin binaries found in {binary_dir}. Looked for: {BINARY_PATTERNS}") - # Render pyproject.toml from its template + # Substitute the minimum ORT version into the staged README in place. min_ort_version = MIN_ONNXRUNTIME_VERSION_FILE.read_text(encoding="utf-8").strip() if not min_ort_version: raise ValueError(f"{MIN_ONNXRUNTIME_VERSION_FILE} is empty") + staged_readme = package_dir / "README.md" + gen_file_from_template( + staged_readme, + staged_readme, + {"min_onnxruntime_version": min_ort_version}, + ) + + # Render pyproject.toml from its template gen_file_from_template( SCRIPT_DIR / "pyproject.toml.in", staging_dir / "pyproject.toml", - {"version": version, "min_onnxruntime_version": min_ort_version}, + {"version": version}, ) diff --git a/plugin-ep-webgpu/python/onnxruntime_ep_webgpu/README.md b/plugin-ep-webgpu/python/onnxruntime_ep_webgpu/README.md index 3200f0dd08ff0..3d28d23a96172 100644 --- a/plugin-ep-webgpu/python/onnxruntime_ep_webgpu/README.md +++ b/plugin-ep-webgpu/python/onnxruntime_ep_webgpu/README.md @@ -1,10 +1,20 @@ # ONNX Runtime WebGPU Plugin Execution Provider -WebGPU Execution Provider plugin for ONNX Runtime. Install alongside `onnxruntime` to enable WebGPU acceleration. +WebGPU Execution Provider plugin for ONNX Runtime. Install alongside `onnxruntime` to enable WebGPU +acceleration. + +## Prerequisites + +This package provides the WebGPU plugin EP only. You must separately install an ONNX Runtime package +(e.g. `onnxruntime`) of version `@min_onnxruntime_version@` or later. + +If the installed ONNX Runtime is incompatible, the plugin EP will report an error when its library is +registered. ## Installation ```bash +pip install "onnxruntime>=@min_onnxruntime_version@" pip install onnxruntime-ep-webgpu ``` diff --git a/plugin-ep-webgpu/python/pyproject.toml.in b/plugin-ep-webgpu/python/pyproject.toml.in index 9eee3235f0294..83ce01f38d1c8 100644 --- a/plugin-ep-webgpu/python/pyproject.toml.in +++ b/plugin-ep-webgpu/python/pyproject.toml.in @@ -9,9 +9,6 @@ description = "ONNX Runtime WebGPU Plugin Execution Provider" readme = "onnxruntime_ep_webgpu/README.md" license = {text = "MIT"} requires-python = ">=3.11" -dependencies = [ - "onnxruntime>=@min_onnxruntime_version@", -] [tool.setuptools.packages.find] include = ["onnxruntime_ep_webgpu*"] diff --git a/tools/ci_build/github/azure-pipelines/stages/plugin-linux-webgpu-test-stage.yml b/tools/ci_build/github/azure-pipelines/stages/plugin-linux-webgpu-test-stage.yml index 12ee9ca68bb4e..771d92fe2c314 100644 --- a/tools/ci_build/github/azure-pipelines/stages/plugin-linux-webgpu-test-stage.yml +++ b/tools/ci_build/github/azure-pipelines/stages/plugin-linux-webgpu-test-stage.yml @@ -53,12 +53,18 @@ stages: - script: | set -e -x + # Pin Vulkan to SwiftShader (software Vulkan, built from source in # the Docker image) so the test does not require a GPU agent. # Keeping these env vars at `docker run` time (rather than baking # them into the image) leaves the image reusable for a potential # future real-GPU test job. swiftshader_icd=/opt/swiftshader/vk_swiftshader_icd.json + + min_ort_version_file="$(Build.SourcesDirectory)/plugin-ep-webgpu/MIN_ONNXRUNTIME_VERSION" + min_ort_version=$(tr -d '\r\n' < "$min_ort_version_file") + echo "using minimum onnxruntime version ${min_ort_version} from ${min_ort_version_file}" + docker run --rm \ --volume "$(Build.SourcesDirectory):/onnxruntime_src" \ --volume "$(Build.BinariesDirectory):/build" \ @@ -66,14 +72,15 @@ stages: --env "VK_ICD_FILENAMES=${swiftshader_icd}" \ --env "VK_DRIVER_FILES=${swiftshader_icd}" \ --env "ORT_TEST_VERBOSE=$(System.Debug)" \ + --env "MIN_ORT_VERSION=${min_ort_version}" \ onnxruntimewebgpuplugin \ - /bin/bash -c " + /bin/bash -c ' set -e -x python3 -m venv /build/test_venv source /build/test_venv/bin/activate - python3 -m pip install onnx numpy - wheel=\$(find /build/python_wheel -name 'onnxruntime_ep_webgpu-*.whl' | head -1) - python3 -m pip install \"\$wheel\" + python3 -m pip install onnx "onnxruntime==${MIN_ORT_VERSION}" numpy + wheel=$(find /build/python_wheel -name "onnxruntime_ep_webgpu-*.whl" | head -1) + python3 -m pip install "$wheel" python3 -u /onnxruntime_src/plugin-ep-webgpu/python/test/test_webgpu_plugin_ep.py - " + ' displayName: 'Install and test Python package' diff --git a/tools/ci_build/github/azure-pipelines/stages/plugin-mac-webgpu-test-stage.yml b/tools/ci_build/github/azure-pipelines/stages/plugin-mac-webgpu-test-stage.yml index 6dca5dd450fd0..70d14d184120b 100644 --- a/tools/ci_build/github/azure-pipelines/stages/plugin-mac-webgpu-test-stage.yml +++ b/tools/ci_build/github/azure-pipelines/stages/plugin-mac-webgpu-test-stage.yml @@ -28,11 +28,19 @@ stages: - script: | set -e -x + python3 -m venv "$(Build.BinariesDirectory)/test_venv" source "$(Build.BinariesDirectory)/test_venv/bin/activate" - python3 -m pip install onnx numpy + + min_ort_version_file="$(Build.SourcesDirectory)/plugin-ep-webgpu/MIN_ONNXRUNTIME_VERSION" + min_ort_version=$(tr -d '\r\n' < "$min_ort_version_file") + echo "using minimum onnxruntime version ${min_ort_version} from ${min_ort_version_file}" + + python3 -m pip install onnx "onnxruntime==${min_ort_version}" numpy + wheel=$(find "$(Pipeline.Workspace)/build/webgpu_plugin_python_macos_arm64" -name "onnxruntime_ep_webgpu-*.whl" | head -1) python3 -m pip install "$wheel" + python3 -u "$(Build.SourcesDirectory)/plugin-ep-webgpu/python/test/test_webgpu_plugin_ep.py" displayName: 'Install and test Python package' env: diff --git a/tools/ci_build/github/azure-pipelines/stages/plugin-win-webgpu-test-stage.yml b/tools/ci_build/github/azure-pipelines/stages/plugin-win-webgpu-test-stage.yml index 1494584ff98fd..997cf528483fb 100644 --- a/tools/ci_build/github/azure-pipelines/stages/plugin-win-webgpu-test-stage.yml +++ b/tools/ci_build/github/azure-pipelines/stages/plugin-win-webgpu-test-stage.yml @@ -40,8 +40,12 @@ stages: echo "activating test_venv" & "$(Build.BinariesDirectory)\test_venv\Scripts\Activate.ps1" + $minOrtVersionFile = "$(Build.SourcesDirectory)\plugin-ep-webgpu\MIN_ONNXRUNTIME_VERSION" + $minOrtVersion = (Get-Content $minOrtVersionFile -Raw).Trim() + echo "using minimum onnxruntime version $minOrtVersion from $minOrtVersionFile" + echo "installing test dependencies" - python -m pip install onnx numpy + python -m pip install onnx "onnxruntime==$minOrtVersion" numpy $wheelDir = "$(Pipeline.Workspace)\build\webgpu_plugin_python_win_${{ parameters.arch }}" $wheel = (Get-ChildItem "$wheelDir\onnxruntime_ep_webgpu-*.whl")[0] @@ -102,6 +106,11 @@ stages: Write-Host "Detected package version: $packageVersion" Write-Host "##vso[task.setvariable variable=OrtWebGpuPackageVersion]$packageVersion" + $minOrtVersionFile = "$(Build.SourcesDirectory)\plugin-ep-webgpu\MIN_ONNXRUNTIME_VERSION" + $minOrtVersion = (Get-Content $minOrtVersionFile -Raw).Trim() + Write-Host "Using minimum core ORT version: $minOrtVersion from $minOrtVersionFile" + Write-Host "##vso[task.setvariable variable=OrtCoreTestVersion]$minOrtVersion" + # Write a project-level nuget.config that adds ONLY the local feed. # NuGet merges this with the repo-root NuGet.config. $nugetConfig = "$(Build.SourcesDirectory)\plugin-ep-webgpu\csharp\test\WebGpuEpNuGetTest\nuget.config" @@ -122,7 +131,8 @@ stages: dotnet build ` "$(WebGpuTestProject)" ` --configuration Release ` - -p:OrtWebGpuPackageVersion=$(OrtWebGpuPackageVersion) + -p:OrtWebGpuPackageVersion=$(OrtWebGpuPackageVersion) ` + -p:OrtCoreTestVersion=$(OrtCoreTestVersion) displayName: 'Build test project' - pwsh: | From ec55d3cf336d3a5cdbd03cecb10748a4ee4397f0 Mon Sep 17 00:00:00 2001 From: umangb-09 Date: Fri, 8 May 2026 00:43:43 +0530 Subject: [PATCH 34/34] Fix Subgraph_t issues with TRT RTX ver 1.5.x (#28361) Fixed a build break issue on ORT with NV EP for RTX 1.5.0.97 version --- .../core/providers/nv_tensorrt_rtx/nv_execution_provider.h | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.h b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.h index a436f2f424947..7382ec6e19f19 100644 --- a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.h +++ b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.h @@ -138,9 +138,9 @@ class OutputAllocator : public nvinfer1::IOutputAllocator { */ using ShapeRangesMap = std::unordered_map>>>; -// SubGraph_t and SubGraphCollection_t were defined in NvOnnxParser.h up to TRT-RTX 1.5.x -// but removed in 1.6.0. Define them here for 1.6+ so the provider owns these ORT-internal types. -#if TRT_MINOR_RTX >= 6 +// SubGraph_t and SubGraphCollection_t were removed from NvOnnxParser.h starting in TRT-RTX 1.5.0.99. +// Define them here so the provider owns these ORT-internal types for any SDK that no longer ships them. +#if TRT_MAJOR_RTX >= 2 || (TRT_MAJOR_RTX == 1 && ((TRT_MINOR_RTX == 5 && TRT_BUILD_RTX >= 99) || TRT_MINOR_RTX >= 6)) using SubGraph_t = std::pair, bool>; using SubGraphCollection_t = std::vector; #endif