From 04f744032cf21e182b2d03391d64fa2a970a3aea Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Thu, 14 May 2026 14:08:25 -0700 Subject: [PATCH 01/11] fix(security): add SafeInt overflow protection in Expand and constant folding output size limit (#28055) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ### Description Harden the constant folding optimizer and the `Expand` CPU kernel against integer overflow attacks from crafted ONNX models. **Problem:** The `Expand::Compute()` kernel performs cumulative dimension multiplications (`input_count *= input_dim`, `output_count *= output_dim`) using raw `int64_t` arithmetic. When triggered during constant folding at `CreateSession()` time via a crafted model with extreme shape values, signed integer overflow can produce corrupted values used for buffer offset calculations and `memcpy` lengths, creating a potential out-of-bounds write. The downstream SafeInt check in the allocator catches overflow only when the total byte count wraps, but carefully chosen dimensions can make the overflowed value appear valid. Additionally, the constant folding optimizer has no output size budget — any deterministic node with constant inputs is eligible for constant folding regardless of output size, enabling memory exhaustion attacks at model load time. ### Key Changes **1. SafeInt-protected arithmetic in `expand.cc`** Wraps all dimension accumulation and offset/length calculations with `SafeInt` or `SafeInt` to catch overflow before it can corrupt buffer arithmetic: | Location | Before | After | |---|---|---| | Accumulator loop (L97-98) | `input_count *= input_dim` | `SafeInt(input_count) * input_dim` | | Accumulator loop (L109) | `last_dim_size *= expand_dim_size[...]` | `SafeInt(last_dim_size) * ...` | | copy_byte (L116) | `copy_len * sizeof(T)` | `SafeInt(copy_len) * sizeof(T)` | | input_offset (L122) | `i * copy_len` | `SafeInt(i) * copy_len` | | output_offset (L126) | `output_offset += current_count * ...` | `SafeInt(output_offset) + SafeInt(current_count) * ...` | **2. Constant folding output size limit in `constant_folding.cc`** - **Pre-execution check**: `EstimateNodeOutputSizeInBytes()` uses shape inference results with SafeInt-protected arithmetic to estimate total output bytes. Nodes exceeding the limit are skipped. - **Post-execution check**: After `kernel->Compute()`, actual output `SizeInBytes()` is verified against the limit (catches cases where shape inference couldn't determine output size). - **Exception isolation**: `kernel->Compute()` is wrapped in `try/catch` so that SafeInt overflow exceptions from individual nodes skip the node rather than aborting the entire optimization pass. - **Configurable limit**: New session option `optimization.constant_folding_max_output_size_in_bytes` (default: 1 GB, `"0"` to disable). **3. Session option** New key `kOrtSessionOptionsConstantFoldingMaxOutputSizeInBytes` in `onnxruntime_session_options_config_keys.h`. ### Motivation and Context This addresses a security vulnerability where a malicious ONNX model can cause signed integer overflow in the Expand kernel during constant folding at model load time (`CreateSession()`), potentially leading to out-of-bounds memory writes. The constant folding size limit provides defense-in-depth against memory exhaustion attacks from untrusted models. ### Testing - `ConstantFoldingOutputSizeLimit` — Verifies 4 MB Expand is blocked at 1 MB limit, allowed at 8 MB limit. - `ConstantFoldingDefaultLimitBlocksLargeExpand` — Verifies 1 GB ConstantOfShape is blocked at 512 MB limit. - `ConstantFoldingSmallOutputAllowed` — Verifies small Expand (64 bytes) is still folded normally. - `ConstantFoldingExpandOverflowDimsSkipped` — Verifies Expand with `[2^32, 2^32]` dimensions (int64 overflow) is gracefully skipped during constant folding. --- .../onnxruntime_session_options_config_keys.h | 12 + .../core/optimizer/constant_folding.cc | 231 +++++++++++++++++- .../core/providers/cpu/tensor/expand.cc | 14 +- .../test/optimizer/graph_transform_test.cc | 195 +++++++++++++++ 4 files changed, 443 insertions(+), 9 deletions(-) diff --git a/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h b/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h index 24557bb81bce3..9d61165927d8c 100644 --- a/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h +++ b/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h @@ -124,6 +124,18 @@ static const char* const kOrtSessionOptionsMemoryOptimizerProbeConfig = "optimiz // Default is an empty string which means no optimizers are disabled. static const char* const kOrtSessionOptionsDisableSpecifiedOptimizers = "optimization.disable_specified_optimizers"; +// Maximum total output size in bytes that the constant folding optimizer is allowed to produce per node. +// Prevents malicious models from causing excessive memory allocation during optimization. +// If the estimated or actual output size of a constant-foldable node exceeds this limit, the node will +// not be constant folded and will instead be executed at runtime. +// +// Option values: +// - A positive integer (as string): Maximum allowed output size in bytes per constant-folded node. +// Default is "1073741824" (1 GB). +// - "0": Disable the size limit (not recommended for untrusted models). +static const char* const kOrtSessionOptionsConstantFoldingMaxOutputSizeInBytes = + "optimization.constant_folding_max_output_size_in_bytes"; + // It controls whether to run graph optimizations in loop or not. // // "0": disable. Graph Optimization Loop is disabled. diff --git a/onnxruntime/core/optimizer/constant_folding.cc b/onnxruntime/core/optimizer/constant_folding.cc index cb6d65342bc54..1b8bb57d6a74e 100644 --- a/onnxruntime/core/optimizer/constant_folding.cc +++ b/onnxruntime/core/optimizer/constant_folding.cc @@ -2,15 +2,18 @@ // Licensed under the MIT License. #include +#include #include "core/optimizer/constant_folding.h" #include "core/optimizer/initializer.h" #include "core/optimizer/utils.h" #include "core/graph/graph_utils.h" #include "core/optimizer/optimizer_execution_frame.h" -#include "core/optimizer/utils.h" #include "core/framework/op_kernel.h" #include "core/framework/tensorprotoutils.h" +#include "core/session/onnxruntime_session_options_config_keys.h" +#include "core/common/safeint.h" +#include "core/common/parse_string.h" using namespace onnxruntime::common; @@ -140,6 +143,154 @@ static Status ConstantFoldIfNode(Graph& graph, Node& if_node, const logging::Log return status; } +// Default maximum output size per constant-folded node: 1 GB. +// This prevents malicious models from causing excessive memory allocation during optimization. +static constexpr int64_t kDefaultConstantFoldingMaxOutputSizeInBytes = 1024 * 1024 * 1024; + +static size_t GetElementSizeForConstantFolding(ONNX_NAMESPACE::TensorProto_DataType elem_type) { + const size_t element_size = utils::GetElementSizeOfTensor(elem_type); + if (element_size != 0) { + return element_size; + } + + // String tensors allocate storage for std::string slots even though the payload size is variable. + return elem_type == ONNX_NAMESPACE::TensorProto_DataType_STRING ? sizeof(std::string) : 0; +} + +static int64_t EstimateTensorElementCount(const ONNX_NAMESPACE::TensorShapeProto& shape) { + SafeInt num_elements = 1; + for (int i = 0; i < shape.dim_size(); ++i) { + const auto& dim = shape.dim(i); + if (!utils::HasDimValue(dim)) { + return -1; // Symbolic dimension + } + int64_t dim_value = dim.dim_value(); + if (dim_value < 0) { + return -1; // Invalid dimension + } + num_elements *= dim_value; + } + + return num_elements; +} + +static int64_t EstimateTensorSizeInBytes(const NodeArg& node_arg) { + const auto* type_proto = node_arg.TypeAsProto(); + if (type_proto == nullptr || !utils::HasTensorType(*type_proto)) { + return -1; // Cannot estimate non-tensor or unknown types + } + + const auto* shape = node_arg.Shape(); + if (shape == nullptr) { + return -1; // Unknown shape + } + + auto elem_type = static_cast( + type_proto->tensor_type().elem_type()); + size_t element_size = GetElementSizeForConstantFolding(elem_type); + if (element_size == 0) { + return -1; // Unknown element type + } + + int64_t num_elements = EstimateTensorElementCount(*shape); + if (num_elements < 0) { + return -1; + } + + return SafeInt(num_elements) * static_cast(element_size); +} + +static int64_t EstimateUniqueOutputSizeInBytes(const Node& node) { + const auto& input_defs = node.InputDefs(); + if (input_defs.empty() || input_defs[0] == nullptr) { + return -1; + } + + const int64_t input_num_elements = input_defs[0]->Shape() != nullptr + ? EstimateTensorElementCount(*input_defs[0]->Shape()) + : -1; + if (input_num_elements < 0) { + return -1; + } + + const auto* input_type_proto = input_defs[0]->TypeAsProto(); + if (input_type_proto == nullptr || !utils::HasTensorType(*input_type_proto)) { + return -1; + } + + auto input_elem_type = static_cast( + input_type_proto->tensor_type().elem_type()); + const size_t input_element_size = GetElementSizeForConstantFolding(input_elem_type); + if (input_element_size == 0) { + return -1; + } + + SafeInt total_size = 0; + const auto& output_defs = node.OutputDefs(); + for (size_t output_idx = 0; output_idx < output_defs.size(); ++output_idx) { + const auto* output_def = output_defs[output_idx]; + if (!output_def->Exists()) { + continue; + } + + const size_t element_size = output_idx == 0 ? input_element_size : sizeof(int64_t); + total_size += SafeInt(input_num_elements) * static_cast(element_size); + } + + return total_size; +} + +static int64_t EstimateIdentityOutputSizeInBytes(const Node& node) { + const auto& input_defs = node.InputDefs(); + if (input_defs.empty() || input_defs[0] == nullptr) { + return -1; + } + + return EstimateTensorSizeInBytes(*input_defs[0]); +} + +// Estimate the total output size in bytes for a node using shape inference results. +// Returns -1 if the output size cannot be estimated (e.g., unknown shapes or types). +static int64_t EstimateNodeOutputSizeInBytes(const Node& node) { + if (node.OpType() == "Identity" && node.Domain().empty()) { + return EstimateIdentityOutputSizeInBytes(node); + } + + if (node.OpType() == "Unique" && node.Domain().empty()) { + return EstimateUniqueOutputSizeInBytes(node); + } + + SafeInt total_size = 0; + for (const auto* output_def : node.OutputDefs()) { + if (!output_def->Exists()) { + continue; + } + + const int64_t output_size = EstimateTensorSizeInBytes(*output_def); + if (output_size < 0) { + return -1; + } + + total_size += output_size; + } + + return total_size; +} + +// Get the configured max output size from session options, or use the default. +static int64_t GetConstantFoldingMaxOutputSize(const ConfigOptions& config_options) { + std::string max_size_str = config_options.GetConfigOrDefault( + kOrtSessionOptionsConstantFoldingMaxOutputSizeInBytes, + std::to_string(kDefaultConstantFoldingMaxOutputSizeInBytes)); + + int64_t max_size = 0; + if (!TryParseStringWithClassicLocale(max_size_str, max_size) || max_size < 0) { + max_size = kDefaultConstantFoldingMaxOutputSizeInBytes; + } + + return max_size; +} + Status ConstantFolding::ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const { bool have_updated_nodes = false; GraphViewer graph_viewer(graph); @@ -151,6 +302,8 @@ Status ConstantFolding::ApplyImpl(Graph& graph, bool& modified, int graph_level, }; #endif + const int64_t max_output_size = GetConstantFoldingMaxOutputSize(config_options_); + for (NodeIndex i : order) { auto* node = graph.GetNode(i); if (!node || !AllowConstantFolding(*node)) { @@ -233,6 +386,35 @@ Status ConstantFolding::ApplyImpl(Graph& graph, bool& modified, int graph_level, } } + // Check if the estimated output size exceeds the configured limit. + // This prevents malicious models from causing excessive memory allocation during constant folding. + if (max_output_size > 0) { + int64_t estimated_size = -1; + try { + estimated_size = EstimateNodeOutputSizeInBytes(*node); + } catch (const std::exception&) { + // SafeInt overflow means the size is astronomically large - definitely skip + LOGS(logger, WARNING) << "Integer overflow while estimating output size of " + << node->OpType() << " node '" << node->Name() + << "'. Skipping constant folding for this node."; + continue; + } + + if (estimated_size > max_output_size) { + LOGS(logger, WARNING) << "Skipping constant folding for " << node->OpType() + << " node '" << node->Name() + << "' because estimated output size (" << estimated_size + << " bytes) exceeds the limit (" << max_output_size << " bytes)."; + continue; + } + if (estimated_size < 0) { + LOGS(logger, INFO) << "Skipping constant folding for " << node->OpType() + << " node '" << node->Name() + << "' because output size could not be estimated before execution."; + continue; + } + } + #if !defined(DISABLE_SPARSE_TENSORS) // Create execution frame for executing constant nodes. OptimizerExecutionFrame::Info info({node}, constant_inputs, graph.ModelPath(), execution_provider_, @@ -312,7 +494,25 @@ Status ConstantFolding::ApplyImpl(Graph& graph, bool& modified, int graph_level, #pragma warning(disable : 6387) #endif OpKernelContext op_kernel_context(&frame, kernel.get(), /*stream*/ nullptr, nullptr, logger); - ORT_RETURN_IF_ERROR(kernel->Compute(&op_kernel_context)); + + // Skip the current node if Compute fails so one bad constant-fold candidate does not abort + // the entire constant folding pass. + Status compute_status = Status::OK(); + try { + compute_status = kernel->Compute(&op_kernel_context); + } catch (const std::exception& ex) { + LOGS(logger, WARNING) << "Exception during constant folding of " << node->OpType() + << " node '" << node->Name() << "': " << ex.what() + << ". Skipping constant folding for this node."; + continue; + } + + if (!compute_status.IsOK()) { + LOGS(logger, WARNING) << "Failure during constant folding of " << node->OpType() + << " node '" << node->Name() << "': " << compute_status.ErrorMessage() + << ". Skipping constant folding for this node."; + continue; + } #ifdef _WIN32 #pragma warning(pop) #endif @@ -320,6 +520,33 @@ Status ConstantFolding::ApplyImpl(Graph& graph, bool& modified, int graph_level, std::vector fetches; ORT_RETURN_IF_ERROR(frame.GetOutputs(fetches)); + // Post-execution size check: verify actual output sizes don't exceed the limit. + // This catches cases where pre-execution shape inference couldn't determine the output size. + if (max_output_size > 0) { + SafeInt actual_total_size = 0; + bool size_exceeded = false; + try { + for (size_t fetch_idx = 0; fetch_idx < fetches.size(); ++fetch_idx) { + if (fetches[fetch_idx].IsAllocated() && fetches[fetch_idx].IsTensor()) { + const auto& tensor = fetches[fetch_idx].Get(); + actual_total_size += tensor.SizeInBytes(); + } + } + size_exceeded = actual_total_size > max_output_size; + } catch (const std::exception&) { + // SafeInt overflow means total size is astronomically large + size_exceeded = true; + } + + if (size_exceeded) { + LOGS(logger, WARNING) << "Skipping constant folding for " << node->OpType() + << " node '" << node->Name() + << "' because actual output size exceeds the limit (" + << max_output_size << " bytes)."; + continue; + } + } + // Go over all output node args and substitute them with the newly computed tensors, which will be // added to the graph as initializers. ORT_ENFORCE(fetches.size() == fetch_to_output_idx.size()); diff --git a/onnxruntime/core/providers/cpu/tensor/expand.cc b/onnxruntime/core/providers/cpu/tensor/expand.cc index b0c636281bc7a..6d299282f3e60 100644 --- a/onnxruntime/core/providers/cpu/tensor/expand.cc +++ b/onnxruntime/core/providers/cpu/tensor/expand.cc @@ -94,8 +94,8 @@ Status Expand::Compute(OpKernelContext* context) const { auto input_dim = input_dims_iter > -1 ? input_dims[input_dims_iter] : 1; auto output_dim = output_dims[output_dims_iter]; - input_count *= input_dim; - output_count *= output_dim; + input_count = SafeInt(input_count) * input_dim; + output_count = SafeInt(output_count) * output_dim; if (0 == input_count || 0 == output_count) { return Status::OK(); @@ -106,26 +106,26 @@ Status Expand::Compute(OpKernelContext* context) const { input_dim_group[onnxruntime::narrow(dim_group_start)] = input_count; output_dim_group[onnxruntime::narrow(dim_group_start)] = output_count; expand_dim_size[onnxruntime::narrow(dim_group_start)] = output_count / input_count / last_dim_size; - last_dim_size *= expand_dim_size[onnxruntime::narrow(dim_group_start)]; + last_dim_size = SafeInt(last_dim_size) * expand_dim_size[onnxruntime::narrow(dim_group_start)]; } } auto distribute_count = input_dim_group[onnxruntime::narrow(dim_group_start)] / input_dim_group[SafeInt(max_dims_size) - 1]; std::vector output_offsets(onnxruntime::narrow(distribute_count), 0); int64_t copy_len = input_dim_group[SafeInt(max_dims_size) - 1]; - auto copy_byte = copy_len * sizeof(T); + size_t copy_byte = SafeInt(copy_len) * sizeof(T); auto distribute_fn = [&](ptrdiff_t i_start, ptrdiff_t i_end) { for (auto i = i_start; i < i_end; i++) { - auto input_offset = i * copy_len; + int64_t input_offset = SafeInt(i) * copy_len; int64_t output_offset = 0; for (auto j = dim_group_start + 1, remains = input_offset; j < max_dims_size; ++j) { auto current_count = remains / input_dim_group[onnxruntime::narrow(j)]; - output_offset += current_count * output_dim_group[onnxruntime::narrow(j)]; + output_offset = SafeInt(output_offset) + SafeInt(current_count) * output_dim_group[onnxruntime::narrow(j)]; remains = remains % input_dim_group[onnxruntime::narrow(j)]; } // for j - memcpy(output_data + output_offset, input_data + input_offset, onnxruntime::narrow(copy_byte)); + memcpy(output_data + output_offset, input_data + input_offset, copy_byte); output_offsets[onnxruntime::narrow(i)] = output_offset; } // for i }; // distribute_fn diff --git a/onnxruntime/test/optimizer/graph_transform_test.cc b/onnxruntime/test/optimizer/graph_transform_test.cc index 1357ede82017a..591740be0263d 100644 --- a/onnxruntime/test/optimizer/graph_transform_test.cc +++ b/onnxruntime/test/optimizer/graph_transform_test.cc @@ -1521,6 +1521,201 @@ TEST_F(GraphTransformationTests, ConstantFoldingIfConstantInliningEdgesWithMiddl ASSERT_TRUE(dest_edges.find(2) != dest_edges.end()); } +// Test that constant folding respects the output size limit and skips nodes +// whose output would exceed it. This is a security measure against malicious +// models that could cause memory exhaustion during optimization. +TEST_F(GraphTransformationTests, ConstantFoldingOutputSizeLimit) { + // Build a model with an Expand node: scalar input [1.0] expanded by shape [1024, 1024]. + // Output = 1024*1024 * 4 bytes = 4 MB of float data. + // With a 1 MB limit, this should NOT be constant folded. + // With a 8 MB limit, this SHOULD be constant folded. + + auto build_model = [&](ModelTestBuilder& builder) { + auto* input_data = builder.MakeInitializer({1}, {1.0f}); + auto* shape_data = builder.MakeInitializer({2}, {1024, 1024}); + auto* output_arg = builder.MakeOutput(); + + builder.AddNode("Expand", {input_data, shape_data}, {output_arg}); + }; + + // Test 1: With a 1 MB limit, the Expand node should NOT be folded (output is ~4 MB). + { + auto pre_graph_checker = [](Graph& graph) -> Status { + auto op_to_count = CountOpsInGraph(graph); + TEST_RETURN_IF_NOT(op_to_count["Expand"] == 1); + return Status::OK(); + }; + + auto post_graph_checker = [](Graph& graph) -> Status { + auto op_to_count = CountOpsInGraph(graph); + // Expand should remain because output is too large + TEST_RETURN_IF_NOT(op_to_count["Expand"] == 1); + return Status::OK(); + }; + + std::unique_ptr e = std::make_unique(CPUExecutionProviderInfo()); + ConfigOptions config_options; + // Set limit to 1 MB + ASSERT_STATUS_OK(config_options.AddConfigEntry( + kOrtSessionOptionsConstantFoldingMaxOutputSizeInBytes, "1048576")); + + ASSERT_STATUS_OK(TestGraphTransformer(build_model, 14, *logger_, + std::make_unique(*e.get(), false, config_options), + TransformerLevel::Level1, 1, + pre_graph_checker, post_graph_checker)); + } + + // Test 2: With an 8 MB limit, the Expand node SHOULD be folded (output is ~4 MB). + { + auto pre_graph_checker = [](Graph& graph) -> Status { + auto op_to_count = CountOpsInGraph(graph); + TEST_RETURN_IF_NOT(op_to_count["Expand"] == 1); + return Status::OK(); + }; + + auto post_graph_checker = [](Graph& graph) -> Status { + auto op_to_count = CountOpsInGraph(graph); + // Expand should be folded since output is within limit + TEST_RETURN_IF_NOT(op_to_count["Expand"] == 0); + return Status::OK(); + }; + + std::unique_ptr e = std::make_unique(CPUExecutionProviderInfo()); + ConfigOptions config_options; + // Set limit to 8 MB + ASSERT_STATUS_OK(config_options.AddConfigEntry( + kOrtSessionOptionsConstantFoldingMaxOutputSizeInBytes, "8388608")); + + ASSERT_STATUS_OK(TestGraphTransformer(build_model, 14, *logger_, + std::make_unique(*e.get(), false, config_options), + TransformerLevel::Level1, 1, + pre_graph_checker, post_graph_checker)); + } +} + +// Test that an explicitly configured constant folding output-size limit blocks +// folding a very large ConstantOfShape output. +TEST_F(GraphTransformationTests, ConstantFoldingConfiguredLimitBlocksLargeConstantOfShape) { + // Build a model with a ConstantOfShape node producing a huge output. + // Shape = [16384, 16384] = 268M elements * 4 bytes = 1 GB. + // Use an explicit 512 MB limit so the 1 GB output is not folded. + + auto build_model = [&](ModelTestBuilder& builder) { + auto* shape_data = builder.MakeInitializer({2}, {16384, 16384}); + auto* output_arg = builder.MakeOutput(); + + auto& node = builder.AddNode("ConstantOfShape", {shape_data}, {output_arg}); + // Default value is float 0.0 + ONNX_NAMESPACE::AttributeProto value_attr; + value_attr.set_name("value"); + value_attr.set_type(ONNX_NAMESPACE::AttributeProto_AttributeType_TENSOR); + auto* tensor = value_attr.mutable_t(); + tensor->set_data_type(ONNX_NAMESPACE::TensorProto_DataType_FLOAT); + tensor->add_dims(1); + tensor->add_float_data(0.0f); + node.AddAttributeProto(std::move(value_attr)); + }; + + auto pre_graph_checker = [](Graph& graph) -> Status { + auto op_to_count = CountOpsInGraph(graph); + TEST_RETURN_IF_NOT(op_to_count["ConstantOfShape"] == 1); + return Status::OK(); + }; + + auto post_graph_checker = [](Graph& graph) -> Status { + auto op_to_count = CountOpsInGraph(graph); + // ConstantOfShape should remain because output is too large (1 GB > 512 MB limit) + TEST_RETURN_IF_NOT(op_to_count["ConstantOfShape"] == 1); + return Status::OK(); + }; + + std::unique_ptr e = std::make_unique(CPUExecutionProviderInfo()); + ConfigOptions config_options; + // Set limit to 512 MB so the 1 GB output is blocked + ASSERT_STATUS_OK(config_options.AddConfigEntry( + kOrtSessionOptionsConstantFoldingMaxOutputSizeInBytes, "536870912")); + + ASSERT_STATUS_OK(TestGraphTransformer(build_model, 14, *logger_, + std::make_unique(*e.get(), false, config_options), + TransformerLevel::Level1, 1, + pre_graph_checker, post_graph_checker)); +} + +// Test that small constant folding still works with the size limit. +TEST_F(GraphTransformationTests, ConstantFoldingSmallOutputAllowed) { + // Build a model with a small Expand: scalar -> [4, 4] = 16 * 4 = 64 bytes. + // This is well within even a small limit and should be folded. + + auto build_model = [&](ModelTestBuilder& builder) { + auto* input_data = builder.MakeInitializer({1}, {42.0f}); + auto* shape_data = builder.MakeInitializer({2}, {4, 4}); + auto* output_arg = builder.MakeOutput(); + + builder.AddNode("Expand", {input_data, shape_data}, {output_arg}); + }; + + auto pre_graph_checker = [](Graph& graph) -> Status { + auto op_to_count = CountOpsInGraph(graph); + TEST_RETURN_IF_NOT(op_to_count["Expand"] == 1); + return Status::OK(); + }; + + auto post_graph_checker = [](Graph& graph) -> Status { + auto op_to_count = CountOpsInGraph(graph); + // Small Expand should be constant folded + TEST_RETURN_IF_NOT(op_to_count["Expand"] == 0); + return Status::OK(); + }; + + std::unique_ptr e = std::make_unique(CPUExecutionProviderInfo()); + const ConfigOptions empty_config_options; + + ASSERT_STATUS_OK(TestGraphTransformer(build_model, 14, *logger_, + std::make_unique(*e.get(), false, empty_config_options), + TransformerLevel::Level1, 1, + pre_graph_checker, post_graph_checker)); +} + +// Test that constant folding gracefully handles an Expand node whose output shape +// dimensions would cause integer overflow. This simulates the attack vector where +// a malicious model embeds constant initializers with extreme shape values, causing +// kernel Compute() to execute during graph optimization. The SafeInt-protected +// arithmetic in Expand::Compute (or TensorShape overflow) should be caught by the +// try/catch around Compute, and the node should NOT be constant folded. +TEST_F(GraphTransformationTests, ConstantFoldingExpandOverflowDimsSkipped) { + constexpr int64_t kLargeDim = int64_t(1) << 32; // 4294967296 + + auto build_model = [&](ModelTestBuilder& builder) { + auto* input_data = builder.MakeInitializer({1}, {1.0f}); + auto* shape_data = builder.MakeInitializer({2}, {kLargeDim, kLargeDim}); + auto* output_arg = builder.MakeOutput(); + + builder.AddNode("Expand", {input_data, shape_data}, {output_arg}); + }; + + auto pre_graph_checker = [](Graph& graph) -> Status { + auto op_to_count = CountOpsInGraph(graph); + TEST_RETURN_IF_NOT(op_to_count["Expand"] == 1); + return Status::OK(); + }; + + auto post_graph_checker = [](Graph& graph) -> Status { + auto op_to_count = CountOpsInGraph(graph); + // Expand should remain because the overflow prevents constant folding. + TEST_RETURN_IF_NOT(op_to_count["Expand"] == 1); + return Status::OK(); + }; + + std::unique_ptr e = + std::make_unique(CPUExecutionProviderInfo()); + const ConfigOptions empty_config_options; + + ASSERT_STATUS_OK(TestGraphTransformer(build_model, 14, *logger_, + std::make_unique(*e.get(), false, empty_config_options), + TransformerLevel::Level1, 1, + pre_graph_checker, post_graph_checker)); +} + // Check transformations in the case of a subgraph with constant inputs. TEST_F(GraphTransformationTests, SubgraphWithConstantInputs) { constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "constant-subgraph.onnx"; From 70dcc0e22c2c2adabe694be37f39ecc1f80a251d Mon Sep 17 00:00:00 2001 From: Dmitri Smirnov Date: Thu, 14 May 2026 14:32:47 -0700 Subject: [PATCH 02/11] Improve error reporting for pre-allocated outputs with a wrong shape (#28481) This pull request improves the handling of pre-allocated output buffers in ONNX Runtime, especially for models with dynamic output shapes. The changes ensure that when a user provides an output buffer whose shape does not match the computed output shape, the library returns a clear error message. Additionally, the error handling and testing around this scenario are strengthened. The most important changes are: **Pre-allocated Output Buffer Shape Validation:** * Enhanced the logic in `IExecutionFrame::GetOrCreateNodeOutputMLValue` to check if the shape of a pre-allocated output OrtValue matches the computed output shape. If there is a mismatch (typically due to dynamic shapes), the code now returns an explicit `INVALID_ARGUMENT` error with a detailed message, guiding the user to fix their usage. **API and Error Handling Improvements:** * Updated `OpKernelContext::OutputMLValue` to throw an exception with the detailed error message if output OrtValue allocation fails, ensuring that shape mismatches are surfaced clearly to the caller. * Added a catch block for `OnnxRuntimeException` in `sequential_executor.cc` to convert exceptions into proper `Status` objects, improving robustness and error propagation. **Testing and Regression Coverage:** * Added a comprehensive regression test (`ExecutionFrameTestInit.FetchWithMismatchedDynamicShapes`) to verify correct handling of pre-allocated outputs with mismatched shapes, including both error and success cases. This test covers scenarios where output buffers are reused across runs with different dynamic shapes, ensuring the new logic works as intended. --- onnxruntime/core/framework/execution_frame.cc | 53 ++++++- onnxruntime/core/framework/op_kernel.cc | 4 +- .../core/framework/sequential_executor.cc | 5 + .../test/framework/execution_frame_test.cc | 135 ++++++++++++++++++ 4 files changed, 188 insertions(+), 9 deletions(-) diff --git a/onnxruntime/core/framework/execution_frame.cc b/onnxruntime/core/framework/execution_frame.cc index 8030690e7c92d..59efae597ceb2 100644 --- a/onnxruntime/core/framework/execution_frame.cc +++ b/onnxruntime/core/framework/execution_frame.cc @@ -154,19 +154,58 @@ Status IExecutionFrame::GetOrCreateNodeOutputMLValue(const int output_index, int p_ort_value = &all_values_[ort_value_idx]; if (p_ort_value->IsAllocated()) { - // already allocated. verify shape matches if tensor. + // The OrtValue at this index is already allocated. This happens when the caller provides + // a pre-allocated OrtValue as an output for Run(). IExecutionFrame::Init() populates + // all_values_ with caller-provided outputs (fetches) before any kernel executes. + // Each NodeArg in an ONNX graph has exactly one producer, so at this point no kernel + // in the current run could have written here — the only source is a value placed + // during Init(). + // + // When the shapes match, we reuse the caller's buffer (zero-cost for repeated runs + // with the same shapes). + // + // When they differ, it means the caller supplied an output OrtValue whose shape does + // not match what the kernel computed for this run. This typically happens when the + // caller reuses the output OrtValue array across Run() calls with different input + // shapes on a model with dynamic dimensions. The caller should either supply + // unallocated output OrtValues or ensure the pre-allocated shape matches. + // + // Pre-run validation (ValidateInputsOutputs in inference_session.cc) catches + // structural mismatches (element type, rank, fixed dimensions) before execution + // begins. Only dynamic dimension differences reach this point, since the actual + // shape is only known once the kernel computes it. + bool shape_matched = true; + if (p_ort_value->IsTensor()) { + ORT_RETURN_IF_NOT(shape != nullptr, "shape must not be null for tensor output that is already allocated"); const Tensor& tensor = p_ort_value->Get(); - ORT_ENFORCE(shape && tensor.Shape() == *shape, - "OrtValue shape verification failed. Current shape:", tensor.Shape(), - " Requested shape:", shape ? shape->ToString() : "null"); + shape_matched = (tensor.Shape() == *shape); } else if (p_ort_value->IsSparseTensor()) { #if !defined(DISABLE_SPARSE_TENSORS) + ORT_RETURN_IF_NOT(shape != nullptr, "shape must not be null for sparse tensor output that is already allocated"); const SparseTensor& sp_tensor = p_ort_value->Get(); - ORT_ENFORCE(shape && sp_tensor.DenseShape() == *shape, - "OrtValue shape verification failed. Current shape:", sp_tensor.DenseShape(), - " Requested shape:", shape ? shape->ToString() : "null"); + shape_matched = (sp_tensor.DenseShape() == *shape); +#endif + } + + if (!shape_matched) { + const TensorShape& existing_shape = p_ort_value->IsTensor() + ? p_ort_value->Get().Shape() +#if !defined(DISABLE_SPARSE_TENSORS) + : p_ort_value->Get().DenseShape(); +#else + : *shape; // unreachable, but satisfies compiler #endif + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "The output OrtValue provided for output '", + node.OutputDefs()[output_index]->Name(), + "' of node '", node.Name(), + "' (", node.OpType(), ") has shape ", existing_shape, + " but the computed output shape for this run is ", *shape, + ". When calling Run() with pre-allocated output OrtValues on a model " + "with dynamic output shapes, either supply unallocated output OrtValues " + "or ensure the pre-allocated shapes match the expected output shapes " + "for each run."); } } else { // shape is nullptr for traditional ML output values diff --git a/onnxruntime/core/framework/op_kernel.cc b/onnxruntime/core/framework/op_kernel.cc index 212ce9c5069ea..287127a2a2ec9 100644 --- a/onnxruntime/core/framework/op_kernel.cc +++ b/onnxruntime/core/framework/op_kernel.cc @@ -80,7 +80,7 @@ OrtValue* OpKernelContext::OutputMLValue(int index, const TensorShape& shape) { OrtValue* p_ml_value = nullptr; Status status = execution_frame_->GetOrCreateNodeOutputMLValue(index, GetOutputArgIndex(index), &shape, p_ml_value, kernel_->Node()); - ORT_ENFORCE(status.IsOK(), status.ErrorMessage()); + ORT_THROW_IF_ERROR(status); return p_ml_value; } @@ -126,7 +126,7 @@ OrtValue* OpKernelContext::GetOrCreateOutputMLValue(int index) { auto output_arg_index = GetOutputArgIndex(index); OrtValue* value = nullptr; auto status = execution_frame_->GetOrCreateNodeOutputMLValue(index, output_arg_index, nullptr, value, kernel_->Node()); - ORT_ENFORCE(status.IsOK(), status.ErrorMessage()); + ORT_THROW_IF_ERROR(status); return value; } diff --git a/onnxruntime/core/framework/sequential_executor.cc b/onnxruntime/core/framework/sequential_executor.cc index 2103183f8f452..c95d2d0d5ab8e 100644 --- a/onnxruntime/core/framework/sequential_executor.cc +++ b/onnxruntime/core/framework/sequential_executor.cc @@ -594,6 +594,11 @@ onnxruntime::Status ExecuteKernel(StreamExecutionContext& ctx, #endif #endif } + ORT_CATCH(const OnnxRuntimeException& ort_ex) { + ORT_HANDLE_EXCEPTION([&]() { + status = Status(ort_ex.Category(), ort_ex.Code(), ort_ex.what()); + }); + } ORT_CATCH(const std::exception& ex) { ORT_HANDLE_EXCEPTION([&]() { status = ORT_MAKE_STATUS(ONNXRUNTIME, RUNTIME_EXCEPTION, ex.what()); diff --git a/onnxruntime/test/framework/execution_frame_test.cc b/onnxruntime/test/framework/execution_frame_test.cc index 4808d2bfb447a..bbfe22a2d4bc7 100644 --- a/onnxruntime/test/framework/execution_frame_test.cc +++ b/onnxruntime/test/framework/execution_frame_test.cc @@ -559,6 +559,141 @@ TEST(ExecutionFrameTestInit, InitializerAsOutput) { } } +// Test that when a caller provides pre-allocated output OrtValues whose shapes don't match +// the computed output shapes, ORT returns a clear INVALID_ARGUMENT error with an actionable +// message. This is the scenario described in GitHub issue #28359. +// +// The caller's copy of the old output remains valid (OrtValue uses shared_ptr internally). +// Pre-run validation (ValidateInputsOutputs) catches structural mismatches (wrong type, rank, +// fixed dims). What remains at kernel execution time is purely dynamic dimension differences. +TEST(ExecutionFrameTestInit, FetchWithMismatchedDynamicShapes) { + // Regression test for https://github.com/microsoft/onnxruntime/issues/28359 + // Verifies that Run() returns a clear INVALID_ARGUMENT error when the caller provides + // a pre-allocated output OrtValue whose shape doesn't match the kernel's computed shape, + // and that pre-allocated outputs with matching shapes are used correctly. + SessionOptions so; + so.enable_mem_pattern = true; + + InferenceSession session(so, GetEnvironment()); + + // Use Relu which preserves shape: output shape == input shape + onnxruntime::Model model("dynamic_shape_test", false, ModelMetaData(), PathString(), + IOnnxRuntimeOpSchemaRegistryList(), + {{kOnnxDomain, 12}}, {}, DefaultLoggingManager().DefaultLogger()); + auto& graph = model.MainGraph(); + + TypeProto float_tensor; + float_tensor.mutable_tensor_type()->set_elem_type(TensorProto_DataType_FLOAT); + // dynamic shape: {N, M} + auto* input_shape = float_tensor.mutable_tensor_type()->mutable_shape(); + input_shape->add_dim()->set_dim_param("N"); + input_shape->add_dim()->set_dim_param("M"); + + auto& input_arg = graph.GetOrCreateNodeArg("X", &float_tensor); + auto& output_arg = graph.GetOrCreateNodeArg("Y", &float_tensor); + graph.AddNode("relu", "Relu", "relu", {&input_arg}, {&output_arg}); + graph.SetInputs({&input_arg}); + graph.SetOutputs({&output_arg}); + ASSERT_STATUS_OK(graph.Resolve()); + + std::string serialized; + ASSERT_TRUE(model.ToProto().SerializeToString(&serialized)); + std::istringstream model_stream(serialized); + ASSERT_STATUS_OK(session.Load(model_stream)); + ASSERT_STATUS_OK(session.Initialize()); + + RunOptions ro; + auto allocator = test::AllocatorManager::Instance().GetAllocator(CPU); + + // Run 1: pre-allocate the output buffer with the correct shape {2, 3}. + // The kernel should write into this buffer directly. + std::vector input_data_1(6, 1.0f); + OrtValue input_1; + Tensor::InitOrtValue(DataTypeImpl::GetType(), TensorShape({2, 3}), input_data_1.data(), + allocator->Info(), input_1); + + OrtValue preallocated_output; + Tensor::InitOrtValue(DataTypeImpl::GetType(), TensorShape({2, 3}), allocator, preallocated_output); + const void* preallocated_buffer = preallocated_output.Get().DataRaw(); + + std::vector results = {preallocated_output}; + ASSERT_STATUS_OK(session.Run(ro, + AsSpan({std::string("X")}), AsSpan({input_1}), + AsSpan({std::string("Y")}), &results, nullptr)); + + ASSERT_EQ(results.size(), 1u); + ASSERT_TRUE(results[0].IsTensor()); + EXPECT_EQ(results[0].Get().Shape(), TensorShape({2, 3})); + + // The output should be in the pre-allocated buffer since shapes matched + EXPECT_EQ(results[0].Get().DataRaw(), preallocated_buffer); + + // Verify Run 1 output correctness (Relu of all 1.0f = all 1.0f) + auto run1_data = results[0].Get().DataAsSpan(); + for (float v : run1_data) { + EXPECT_EQ(v, 1.0f); + } + + // Run 2: different input shape {4, 5}, but results still contains the {2,3} output + // from Run 1. Run() should fail with INVALID_ARGUMENT because the pre-allocated + // output shape doesn't match the computed output shape. + std::vector input_data_2(20, 2.0f); + OrtValue input_2; + Tensor::InitOrtValue(DataTypeImpl::GetType(), TensorShape({4, 5}), input_data_2.data(), + allocator->Info(), input_2); + + auto status = session.Run(ro, + AsSpan({std::string("X")}), AsSpan({input_2}), + AsSpan({std::string("Y")}), &results, nullptr); + + ASSERT_FALSE(status.IsOK()); + ASSERT_EQ(status.Code(), common::StatusCode::INVALID_ARGUMENT); + EXPECT_THAT(status.ErrorMessage(), ::testing::HasSubstr("pre-allocated")); + EXPECT_THAT(status.ErrorMessage(), ::testing::HasSubstr("{2,3}")); + EXPECT_THAT(status.ErrorMessage(), ::testing::HasSubstr("{4,5}")); + + // Run 3: clear the output and retry — should succeed + results = {}; + ASSERT_STATUS_OK(session.Run(ro, + AsSpan({std::string("X")}), AsSpan({input_2}), + AsSpan({std::string("Y")}), &results, nullptr)); + + ASSERT_EQ(results.size(), 1u); + ASSERT_TRUE(results[0].IsTensor()); + EXPECT_EQ(results[0].Get().Shape(), TensorShape({4, 5})); + + // Verify Run 3 output correctness (Relu of all 2.0f = all 2.0f) + auto run3_data = results[0].Get().DataAsSpan(); + for (float v : run3_data) { + EXPECT_EQ(v, 2.0f); + } + + // Run 4: same shape {4, 5} with results carrying over — shapes match, should reuse + const void* run3_buffer = results[0].Get().DataRaw(); + + std::vector input_data_4(20, 3.0f); + OrtValue input_4; + Tensor::InitOrtValue(DataTypeImpl::GetType(), TensorShape({4, 5}), input_data_4.data(), + allocator->Info(), input_4); + + ASSERT_STATUS_OK(session.Run(ro, + AsSpan({std::string("X")}), AsSpan({input_4}), + AsSpan({std::string("Y")}), &results, nullptr)); + + ASSERT_EQ(results.size(), 1u); + ASSERT_TRUE(results[0].IsTensor()); + EXPECT_EQ(results[0].Get().Shape(), TensorShape({4, 5})); + + // Same shape — buffer should be reused + EXPECT_EQ(results[0].Get().DataRaw(), run3_buffer); + + // Verify Run 4 output correctness (Relu of all 3.0f = all 3.0f) + auto run4_data = results[0].Get().DataAsSpan(); + for (float v : run4_data) { + EXPECT_EQ(v, 3.0f); + } +} + #if !defined(DISABLE_SPARSE_TENSORS) TEST(ExecutionFrameTestInit, SparseInitializerAsOutput) { constexpr std::array dense_shape{3, 3}; From a8ba94a147eea82b49409fadcf183a9be45e60f0 Mon Sep 17 00:00:00 2001 From: Orlaith Monahan Date: Thu, 14 May 2026 23:02:20 +0100 Subject: [PATCH 03/11] [MLAS] Add an NHWC implementation of convolution to avoid transposes (#26834) * Modification to the CPU EP to specify channels_last when data format is NWHC * Added a FusedNhwcConv kernel * Implementation of the kernel in mlas * Added compiler guards so it is inly used with KleidiAi (for now, can be removed if needed) * Added unittests ### Description Currently OnnxRT supports NCHW as a default datalayout. For optimisations and kernels that operate better in NHWC layout, or where the datalayout is NHWC in the first place Transposes are added around the layers. This patch seeks to eliminate them in cases of convolutions where it would cause a performance decrease. ### Motivation and Context KleidiAi specific implementation of this feature. Only supports convolutions, DepthWise to follow. Currently a little strict with the filters as a result. --------- Signed-off-by: Orlaith Monahan Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- cmake/CMakeLists.txt | 1 + cmake/onnxruntime_unittests.cmake | 15 +- docs/ContribOperators.md | 13 +- .../contrib_ops/cpu/cpu_contrib_kernels.cc | 6 + onnxruntime/contrib_ops/cpu/fused_conv.cc | 13 + .../kernel_type_str_resolver_utils.cc | 696 +++++++++--------- .../kernel_type_str_resolver_utils.h | 7 - .../graph/contrib_ops/nhwc_schema_defs.cc | 28 +- onnxruntime/core/mlas/inc/mlas.h | 19 + onnxruntime/core/mlas/lib/convolve.cpp | 83 ++- .../mlas/lib/kleidiai/convolve_kleidiai.cpp | 262 +++---- .../core/mlas/lib/kleidiai/mlasi_kleidiai.h | 23 +- onnxruntime/core/mlas/lib/mlasi.h | 2 + .../core/optimizer/conv_activation_fusion.cc | 5 +- .../core/optimizer/conv_add_act_fusion.cc | 17 +- .../core/optimizer/graph_transformer_utils.cc | 4 +- ...out_transformation_potentially_added_ops.h | 1 + .../core/optimizer/nchwc_transformer.cc | 17 + .../core/optimizer/nhwc_transformer.cc | 314 +++++++- onnxruntime/core/optimizer/nhwc_transformer.h | 9 +- onnxruntime/core/providers/cpu/nn/conv.cc | 268 ++++++- onnxruntime/core/providers/cpu/nn/conv.h | 36 +- onnxruntime/core/util/math_cpu.cc | 1 + .../test/contrib_ops/fused_conv_test.cc | 139 ++++ .../kernel_type_str_resolver_utils_test.cc | 45 +- .../test/framework/ort_model_only_test.cc | 36 +- .../internal_testing_tests.cc | 72 +- onnxruntime/test/mlas/bench/bench_sconv.cpp | 2 + .../test/mlas/bench/bench_transcendental.cpp | 2 + onnxruntime/test/mlas/unittest/test_conv2d.h | 172 +---- .../test/optimizer/conv_add_act_test.cc | 3 +- .../fuse_initializers_transformer_test.cc | 9 + .../test/optimizer/nchwc_optimizer_test.cc | 31 + .../test/optimizer/nhwc_transformer_test.cc | 404 +++++++++- .../optimizer/transpose_optimizer_test.cc | 60 ++ 35 files changed, 2036 insertions(+), 779 deletions(-) diff --git a/cmake/CMakeLists.txt b/cmake/CMakeLists.txt index 83d1751e55543..dde6d44919092 100644 --- a/cmake/CMakeLists.txt +++ b/cmake/CMakeLists.txt @@ -638,6 +638,7 @@ else() check_cxx_compiler_flag(-Wcatch-value HAS_CATCH_VALUE) check_cxx_compiler_flag(-Wclass-memaccess HAS_CLASS_MEMACCESS) check_cxx_compiler_flag(-Wcharacter-conversion HAS_CHARACTER_CONVERSION) + check_cxx_compiler_flag(-Wno-error=character-conversion HAS_NO_ERROR_CHARACTER_CONVERSION) check_cxx_compiler_flag(-Wdangling-reference HAS_DANGLING_REFERENCE) check_cxx_compiler_flag(-Wdeprecated-anon-enum-enum-conversion HAS_DEPRECATED_ANON_ENUM_ENUM_CONVERSION) check_cxx_compiler_flag(-Wdeprecated-builtins HAS_DEPRECATED_BUILTINS) diff --git a/cmake/onnxruntime_unittests.cmake b/cmake/onnxruntime_unittests.cmake index 7f361aa63921f..a061858fa068f 100644 --- a/cmake/onnxruntime_unittests.cmake +++ b/cmake/onnxruntime_unittests.cmake @@ -50,6 +50,13 @@ function(filter_test_srcs test_srcs_var) endfunction() set(disabled_warnings) + +function(onnxruntime_disable_gtest_character_conversion_as_error target_name) + if (HAS_NO_ERROR_CHARACTER_CONVERSION) + target_compile_options(${target_name} PRIVATE "$<$:-Wno-error=character-conversion>") + endif() +endfunction() + function(AddTest) cmake_parse_arguments(_UT "DYN" "TARGET" "LIBS;SOURCES;DEPENDS;TEST_ARGS" ${ARGN}) list(REMOVE_DUPLICATES _UT_SOURCES) @@ -170,9 +177,7 @@ function(AddTest) if (${HAS_NOERROR}) target_compile_options(${_UT_TARGET} PRIVATE "$<$:-Wno-error=uninitialized>") endif() - if (${HAS_CHARACTER_CONVERSION}) - target_compile_options(${_UT_TARGET} PRIVATE "$<$:-Wno-error=character-conversion>") - endif() + onnxruntime_disable_gtest_character_conversion_as_error(${_UT_TARGET}) endif() set(TEST_ARGS ${_UT_TEST_ARGS}) @@ -847,9 +852,7 @@ if(MSVC) "$<$>:/wd6326>") else() target_include_directories(onnxruntime_test_utils PRIVATE ${CMAKE_CURRENT_BINARY_DIR} ${ONNXRUNTIME_ROOT}) - if (HAS_CHARACTER_CONVERSION) - target_compile_options(onnxruntime_test_utils PRIVATE "$<$:-Wno-error=character-conversion>") - endif() + onnxruntime_disable_gtest_character_conversion_as_error(onnxruntime_test_utils) endif() if (onnxruntime_USE_NCCL) target_include_directories(onnxruntime_test_utils PRIVATE ${NCCL_INCLUDE_DIRS}) diff --git a/docs/ContribOperators.md b/docs/ContribOperators.md index fb6a4eb22a872..ca072113e832d 100644 --- a/docs/ContribOperators.md +++ b/docs/ContribOperators.md @@ -3569,7 +3569,6 @@ This version of the operator has been available since version 1 of the 'com.micr ### **com.microsoft.NhwcFusedConv** NhwcFusedConv is a Conv operator with optional activation and add operators fused in. - Only has fp16 implementation as of 2023/04/15. #### Version @@ -3600,26 +3599,26 @@ This version of the operator has been available since version 1 of the 'com.micr
X : T
-
+
Input activation tensor in channels-last layout. For 2D convolution this is [N, H, W, C], where N is batch size, H/W are spatial dimensions, and C is the number of input channels.
W : T
-
+
Convolution weight tensor in the standard ONNX Conv filter layout [M, C/group, kH, kW], where M is the number of output channels.
B (optional) : T
-
+
Optional 1D bias tensor of shape [M].
Z (optional) : T
-
Tensor to be added to the output, must be the same shape and format as the output tensor.
+
Optional residual/add tensor in the same channels-last layout and shape as the output tensor Y. For 2D convolution this is [N, out_H, out_W, M].
#### Outputs
Y : T
-
+
Output tensor in channels-last layout. For 2D convolution this is [N, out_H, out_W, M], where M is the number of output channels.
#### Type Constraints
-
T : tensor(float16)
+
T : tensor(float16), tensor(float)
Constrain input and output types to float tensors
diff --git a/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc b/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc index cc652ed52ee72..0749457f5a182 100644 --- a/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc +++ b/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc @@ -20,6 +20,9 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, EmbedLayerNormalization); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, ExpandDims); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, FusedConv); +#ifdef USE_KLEIDIAI +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, NhwcFusedConv); +#endif class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, FusedGemm); #if !defined(DISABLE_GENERATION_OPS) class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, GreedySearch); @@ -313,6 +316,9 @@ Status RegisterCpuContribKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, +#ifdef USE_KLEIDIAI + BuildKernelCreateInfo, +#endif BuildKernelCreateInfo, #if !defined(DISABLE_GENERATION_OPS) BuildKernelCreateInfo, diff --git a/onnxruntime/contrib_ops/cpu/fused_conv.cc b/onnxruntime/contrib_ops/cpu/fused_conv.cc index 5374222dbabcc..76451232421da 100644 --- a/onnxruntime/contrib_ops/cpu/fused_conv.cc +++ b/onnxruntime/contrib_ops/cpu/fused_conv.cc @@ -26,5 +26,18 @@ ONNX_CPU_OPERATOR_TYPED_MS_KERNEL( .TypeConstraint("T", DataTypeImpl::GetTensorType()), FusedConvFloat); +#ifdef USE_KLEIDIAI +ONNX_CPU_OPERATOR_TYPED_MS_KERNEL( + NhwcFusedConv, + 1, + float, + KernelDefBuilder() + // Allow the optional "sum" input (index 3) to be reused as the output buffer (index 0), + // consistent with the FusedConv kernel registration. + .MayInplace(3, 0) + .TypeConstraint("T", DataTypeImpl::GetTensorType()), + FusedConvFloat); +#endif + } // namespace contrib } // namespace onnxruntime diff --git a/onnxruntime/core/framework/kernel_type_str_resolver_utils.cc b/onnxruntime/core/framework/kernel_type_str_resolver_utils.cc index 65ad60f478d78..27312a2f0ed37 100644 --- a/onnxruntime/core/framework/kernel_type_str_resolver_utils.cc +++ b/onnxruntime/core/framework/kernel_type_str_resolver_utils.cc @@ -9,7 +9,6 @@ #include "core/common/common.h" #include "core/flatbuffers/schema/ort.fbs.h" -#include "core/optimizer/layout_transformation/layout_transformation_potentially_added_ops.h" namespace onnxruntime::kernel_type_str_resolver_utils { @@ -17,10 +16,6 @@ static constexpr auto* kStandaloneKernelTypeStrResolverFileIdentifier = "ktsr"; #if !defined(ORT_MINIMAL_BUILD) -gsl::span GetLayoutTransformationRequiredOpIdentifiers() { - return kLayoutTransformationPotentiallyAddedOps; -} - Status SaveKernelTypeStrResolverToBuffer(const KernelTypeStrResolver& kernel_type_str_resolver, flatbuffers::DetachedBuffer& buffer, gsl::span& buffer_span) { flatbuffers::FlatBufferBuilder builder; @@ -53,368 +48,377 @@ Status AddLayoutTransformationRequiredOpsToKernelTypeStrResolver(KernelTypeStrRe // clang-format off constexpr uint8_t kLayoutTransformationRequiredOpsKernelTypeStrResolverBytes[] = { 0x10, 0x00, 0x00, 0x00, 0x6b, 0x74, 0x73, 0x72, 0x00, 0x00, 0x06, 0x00, 0x08, 0x00, 0x04, 0x00, - 0x06, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x32, 0x00, 0x00, 0x00, 0x7c, 0x15, 0x00, 0x00, - 0xf8, 0x0a, 0x00, 0x00, 0x64, 0x04, 0x00, 0x00, 0x44, 0x14, 0x00, 0x00, 0xc0, 0x09, 0x00, 0x00, - 0xbc, 0x14, 0x00, 0x00, 0x78, 0x0b, 0x00, 0x00, 0x08, 0x0c, 0x00, 0x00, 0xbc, 0x13, 0x00, 0x00, - 0x44, 0x08, 0x00, 0x00, 0x94, 0x0d, 0x00, 0x00, 0xd8, 0x0d, 0x00, 0x00, 0xa0, 0x07, 0x00, 0x00, - 0xe8, 0x06, 0x00, 0x00, 0x2c, 0x0a, 0x00, 0x00, 0x34, 0x0d, 0x00, 0x00, 0xdc, 0x07, 0x00, 0x00, - 0x70, 0x12, 0x00, 0x00, 0x88, 0x06, 0x00, 0x00, 0x50, 0x0e, 0x00, 0x00, 0x4c, 0x01, 0x00, 0x00, - 0xa0, 0x0c, 0x00, 0x00, 0x20, 0x11, 0x00, 0x00, 0x8c, 0x10, 0x00, 0x00, 0x38, 0x02, 0x00, 0x00, - 0x88, 0x0f, 0x00, 0x00, 0x3c, 0x0f, 0x00, 0x00, 0x64, 0x05, 0x00, 0x00, 0x84, 0x11, 0x00, 0x00, - 0xf4, 0x06, 0x00, 0x00, 0xf0, 0x08, 0x00, 0x00, 0x10, 0x0c, 0x00, 0x00, 0xb4, 0x00, 0x00, 0x00, - 0x0c, 0x13, 0x00, 0x00, 0xc8, 0x0d, 0x00, 0x00, 0x44, 0x08, 0x00, 0x00, 0x20, 0x0a, 0x00, 0x00, - 0xd4, 0x11, 0x00, 0x00, 0x84, 0x08, 0x00, 0x00, 0xe8, 0x05, 0x00, 0x00, 0x8c, 0x15, 0x00, 0x00, - 0xd8, 0x0f, 0x00, 0x00, 0x20, 0x00, 0x00, 0x00, 0x80, 0x01, 0x00, 0x00, 0x68, 0x05, 0x00, 0x00, - 0x84, 0x0e, 0x00, 0x00, 0x8c, 0x04, 0x00, 0x00, 0x30, 0x04, 0x00, 0x00, 0x68, 0x02, 0x00, 0x00, - 0x44, 0x12, 0x00, 0x00, 0x74, 0xea, 0xff, 0xff, 0x14, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, - 0x02, 0x00, 0x00, 0x00, 0x38, 0x00, 0x00, 0x00, 0x18, 0x00, 0x00, 0x00, 0x0d, 0x00, 0x00, 0x00, - 0x3a, 0x55, 0x6e, 0x73, 0x71, 0x75, 0x65, 0x65, 0x7a, 0x65, 0x3a, 0x32, 0x31, 0x00, 0x00, 0x00, - 0xa0, 0xea, 0xff, 0xff, 0x54, 0x11, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, - 0x04, 0x00, 0x00, 0x00, 0xf0, 0xea, 0xff, 0xff, 0x01, 0x00, 0x00, 0x00, 0xbc, 0xea, 0xff, 0xff, - 0x54, 0x15, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, - 0x04, 0x00, 0x00, 0x00, 0xaa, 0xea, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01, 0xa4, 0xea, 0xff, 0xff, - 0xe0, 0xea, 0xff, 0xff, 0x14, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, - 0x18, 0x00, 0x00, 0x00, 0x38, 0x00, 0x00, 0x00, 0x0b, 0x00, 0x00, 0x00, 0x3a, 0x53, 0x71, 0x75, - 0x65, 0x65, 0x7a, 0x65, 0x3a, 0x32, 0x35, 0x00, 0x08, 0xeb, 0xff, 0xff, 0x08, 0x15, 0x00, 0x00, + 0x06, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x33, 0x00, 0x00, 0x00, 0x3c, 0x02, 0x00, 0x00, + 0x88, 0x0a, 0x00, 0x00, 0x60, 0x15, 0x00, 0x00, 0x6c, 0x13, 0x00, 0x00, 0x64, 0x11, 0x00, 0x00, + 0x7c, 0x04, 0x00, 0x00, 0xe4, 0x0e, 0x00, 0x00, 0xdc, 0x13, 0x00, 0x00, 0xb0, 0x02, 0x00, 0x00, + 0x10, 0x0b, 0x00, 0x00, 0xf8, 0x14, 0x00, 0x00, 0x40, 0x06, 0x00, 0x00, 0xa0, 0x08, 0x00, 0x00, + 0xa0, 0x10, 0x00, 0x00, 0x08, 0x0a, 0x00, 0x00, 0xb4, 0x01, 0x00, 0x00, 0xe4, 0x04, 0x00, 0x00, + 0xb0, 0x0b, 0x00, 0x00, 0x40, 0x10, 0x00, 0x00, 0x14, 0x0e, 0x00, 0x00, 0xb0, 0x03, 0x00, 0x00, + 0xe4, 0x02, 0x00, 0x00, 0x34, 0x0c, 0x00, 0x00, 0x18, 0x12, 0x00, 0x00, 0xb4, 0x00, 0x00, 0x00, + 0x2c, 0x0f, 0x00, 0x00, 0x08, 0x05, 0x00, 0x00, 0x60, 0x00, 0x00, 0x00, 0x48, 0x14, 0x00, 0x00, + 0x44, 0x05, 0x00, 0x00, 0x70, 0x15, 0x00, 0x00, 0xa4, 0x0f, 0x00, 0x00, 0xe8, 0x07, 0x00, 0x00, + 0x70, 0x09, 0x00, 0x00, 0x1c, 0x01, 0x00, 0x00, 0x9c, 0x10, 0x00, 0x00, 0xb0, 0x0b, 0x00, 0x00, + 0xd8, 0x13, 0x00, 0x00, 0x1c, 0x03, 0x00, 0x00, 0x84, 0x05, 0x00, 0x00, 0x08, 0x09, 0x00, 0x00, + 0x50, 0x0d, 0x00, 0x00, 0x10, 0x06, 0x00, 0x00, 0x60, 0x12, 0x00, 0x00, 0x58, 0x11, 0x00, 0x00, + 0xd4, 0x0c, 0x00, 0x00, 0x64, 0x08, 0x00, 0x00, 0x4c, 0x0c, 0x00, 0x00, 0xdc, 0x0a, 0x00, 0x00, + 0x60, 0x06, 0x00, 0x00, 0x9c, 0x15, 0x00, 0x00, 0xfc, 0xe9, 0xff, 0xff, 0x10, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x14, 0x00, 0x00, 0x00, 0x0b, 0x00, 0x00, 0x00, + 0x3a, 0x53, 0x71, 0x75, 0x65, 0x65, 0x7a, 0x65, 0x3a, 0x31, 0x31, 0x00, 0x20, 0xea, 0xff, 0xff, + 0x34, 0x15, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, 0x0e, 0xea, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01, 0x48, 0xea, 0xff, 0xff, + 0x44, 0xea, 0xff, 0xff, 0x18, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, + 0x40, 0x00, 0x00, 0x00, 0x20, 0x00, 0x00, 0x00, 0x50, 0x00, 0x00, 0x00, 0x12, 0x00, 0x00, 0x00, + 0x3a, 0x51, 0x75, 0x61, 0x6e, 0x74, 0x69, 0x7a, 0x65, 0x4c, 0x69, 0x6e, 0x65, 0x61, 0x72, 0x3a, + 0x32, 0x34, 0x00, 0x00, 0x78, 0xea, 0xff, 0xff, 0xa4, 0x15, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x54, 0xea, 0xff, 0xff, 0x01, 0x00, 0x00, 0x00, + 0x94, 0xea, 0xff, 0xff, 0x50, 0x15, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, 0xb0, 0xea, 0xff, 0xff, 0xac, 0xea, 0xff, 0xff, 0x40, 0x11, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, - 0xf6, 0xea, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01, 0xf0, 0xea, 0xff, 0xff, 0x2c, 0xeb, 0xff, 0xff, - 0xc8, 0x10, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, - 0x7c, 0xeb, 0xff, 0xff, 0x01, 0x00, 0x00, 0x00, 0x48, 0xeb, 0xff, 0xff, 0x18, 0x00, 0x00, 0x00, - 0x04, 0x00, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, 0x24, 0x00, 0x00, 0x00, 0x54, 0x00, 0x00, 0x00, - 0x34, 0x00, 0x00, 0x00, 0x12, 0x00, 0x00, 0x00, 0x3a, 0x51, 0x75, 0x61, 0x6e, 0x74, 0x69, 0x7a, - 0x65, 0x4c, 0x69, 0x6e, 0x65, 0x61, 0x72, 0x3a, 0x31, 0x33, 0x00, 0x00, 0x7c, 0xeb, 0xff, 0xff, - 0x64, 0x13, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, - 0x58, 0xeb, 0xff, 0xff, 0x94, 0xeb, 0xff, 0xff, 0xf0, 0x0c, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, - 0x01, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0xe4, 0xeb, 0xff, 0xff, 0x01, 0x00, 0x00, 0x00, - 0xb0, 0xeb, 0xff, 0xff, 0x0c, 0x13, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, - 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x9e, 0xeb, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01, - 0x0c, 0xec, 0xff, 0xff, 0x02, 0x00, 0x00, 0x00, 0xd8, 0xeb, 0xff, 0xff, 0x14, 0x00, 0x00, 0x00, - 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x38, 0x00, 0x00, 0x00, 0x18, 0x00, 0x00, 0x00, - 0x0d, 0x00, 0x00, 0x00, 0x3a, 0x55, 0x6e, 0x73, 0x71, 0x75, 0x65, 0x65, 0x7a, 0x65, 0x3a, 0x32, - 0x33, 0x00, 0x00, 0x00, 0x04, 0xec, 0xff, 0xff, 0xf0, 0x0f, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, - 0x01, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x54, 0xec, 0xff, 0xff, 0x01, 0x00, 0x00, 0x00, - 0x20, 0xec, 0xff, 0xff, 0xf0, 0x13, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, - 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x0e, 0xec, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01, - 0x08, 0xec, 0xff, 0xff, 0x44, 0xec, 0xff, 0xff, 0x18, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, - 0x03, 0x00, 0x00, 0x00, 0x4c, 0x00, 0x00, 0x00, 0x60, 0x00, 0x00, 0x00, 0x1c, 0x00, 0x00, 0x00, - 0x12, 0x00, 0x00, 0x00, 0x3a, 0x51, 0x75, 0x61, 0x6e, 0x74, 0x69, 0x7a, 0x65, 0x4c, 0x69, 0x6e, - 0x65, 0x61, 0x72, 0x3a, 0x32, 0x34, 0x00, 0x00, 0x78, 0xec, 0xff, 0xff, 0x94, 0x12, 0x00, 0x00, + 0x9a, 0xea, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01, 0x94, 0xea, 0xff, 0xff, 0x02, 0x00, 0x00, 0x00, + 0xd4, 0xea, 0xff, 0xff, 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, + 0x18, 0x00, 0x00, 0x00, 0x0d, 0x00, 0x00, 0x00, 0x3a, 0x54, 0x72, 0x61, 0x6e, 0x73, 0x70, 0x6f, + 0x73, 0x65, 0x3a, 0x31, 0x33, 0x00, 0x00, 0x00, 0xfc, 0xea, 0xff, 0xff, 0x58, 0x14, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, - 0x66, 0xec, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01, 0xd4, 0xec, 0xff, 0xff, 0x02, 0x00, 0x00, 0x00, - 0xa0, 0xec, 0xff, 0xff, 0x40, 0x12, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, - 0x04, 0x00, 0x00, 0x00, 0x7c, 0xec, 0xff, 0xff, 0xb8, 0xec, 0xff, 0xff, 0x04, 0x12, 0x00, 0x00, - 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x08, 0xed, 0xff, 0xff, - 0x01, 0x00, 0x00, 0x00, 0xd4, 0xec, 0xff, 0xff, 0x28, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, - 0x07, 0x00, 0x00, 0x00, 0x88, 0x00, 0x00, 0x00, 0x38, 0x00, 0x00, 0x00, 0x04, 0x01, 0x00, 0x00, - 0x58, 0x00, 0x00, 0x00, 0xb8, 0x00, 0x00, 0x00, 0x98, 0x00, 0x00, 0x00, 0xd8, 0x00, 0x00, 0x00, - 0x1b, 0x00, 0x00, 0x00, 0x63, 0x6f, 0x6d, 0x2e, 0x6d, 0x69, 0x63, 0x72, 0x6f, 0x73, 0x6f, 0x66, - 0x74, 0x3a, 0x51, 0x4c, 0x69, 0x6e, 0x65, 0x61, 0x72, 0x43, 0x6f, 0x6e, 0x76, 0x3a, 0x31, 0x00, - 0x20, 0xed, 0xff, 0xff, 0x9c, 0x11, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, - 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x74, 0xed, 0xff, 0xff, 0x05, 0x00, 0x00, 0x00, - 0x7c, 0xed, 0xff, 0xff, 0x03, 0x00, 0x00, 0x00, 0x48, 0xed, 0xff, 0xff, 0x08, 0x00, 0x00, 0x00, - 0x0c, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x54, 0x34, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, - 0x04, 0x00, 0x00, 0x00, 0xa0, 0xed, 0xff, 0xff, 0x08, 0x00, 0x00, 0x00, 0x6c, 0xed, 0xff, 0xff, - 0x74, 0x11, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, - 0x04, 0x00, 0x00, 0x00, 0xc0, 0xed, 0xff, 0xff, 0x02, 0x00, 0x00, 0x00, 0x54, 0xed, 0xff, 0xff, - 0x90, 0xed, 0xff, 0xff, 0xd8, 0x11, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, - 0x04, 0x00, 0x00, 0x00, 0xe0, 0xed, 0xff, 0xff, 0x01, 0x00, 0x00, 0x00, 0xac, 0xed, 0xff, 0xff, - 0x08, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x07, 0x00, 0x00, 0x00, 0x77, 0x5f, 0x73, 0x63, - 0x61, 0x6c, 0x65, 0x00, 0x01, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x08, 0xee, 0xff, 0xff, - 0x04, 0x00, 0x00, 0x00, 0xd4, 0xed, 0xff, 0xff, 0xb0, 0x0a, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, - 0x01, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x24, 0xee, 0xff, 0xff, 0x06, 0x00, 0x00, 0x00, - 0xf0, 0xed, 0xff, 0xff, 0x1c, 0x11, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, - 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0xde, 0xed, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01, - 0x4c, 0xee, 0xff, 0xff, 0x07, 0x00, 0x00, 0x00, 0x18, 0xee, 0xff, 0xff, 0x14, 0x00, 0x00, 0x00, - 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x24, 0x00, 0x00, 0x00, 0x44, 0x00, 0x00, 0x00, - 0x14, 0x00, 0x00, 0x00, 0x3a, 0x44, 0x65, 0x71, 0x75, 0x61, 0x6e, 0x74, 0x69, 0x7a, 0x65, 0x4c, - 0x69, 0x6e, 0x65, 0x61, 0x72, 0x3a, 0x31, 0x39, 0x00, 0x00, 0x00, 0x00, 0x4c, 0xee, 0xff, 0xff, - 0x94, 0x10, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, - 0x04, 0x00, 0x00, 0x00, 0xa0, 0xee, 0xff, 0xff, 0x02, 0x00, 0x00, 0x00, 0x34, 0xee, 0xff, 0xff, - 0x70, 0xee, 0xff, 0xff, 0x4c, 0x10, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, - 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x5e, 0xee, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01, - 0xcc, 0xee, 0xff, 0xff, 0x01, 0x00, 0x00, 0x00, 0x98, 0xee, 0xff, 0xff, 0x10, 0x00, 0x00, 0x00, - 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x24, 0x00, 0x00, 0x00, 0x1b, 0x00, 0x00, 0x00, - 0x63, 0x6f, 0x6d, 0x2e, 0x6d, 0x69, 0x63, 0x72, 0x6f, 0x73, 0x6f, 0x66, 0x74, 0x3a, 0x4e, 0x68, - 0x77, 0x63, 0x4d, 0x61, 0x78, 0x50, 0x6f, 0x6f, 0x6c, 0x3a, 0x31, 0x00, 0xcc, 0xee, 0xff, 0xff, - 0x44, 0x11, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, - 0x04, 0x00, 0x00, 0x00, 0xba, 0xee, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01, 0xb4, 0xee, 0xff, 0xff, - 0xf0, 0xee, 0xff, 0xff, 0x14, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, - 0x30, 0x00, 0x00, 0x00, 0x50, 0x00, 0x00, 0x00, 0x20, 0x00, 0x00, 0x00, 0x63, 0x6f, 0x6d, 0x2e, - 0x6d, 0x69, 0x63, 0x72, 0x6f, 0x73, 0x6f, 0x66, 0x74, 0x3a, 0x44, 0x65, 0x71, 0x75, 0x61, 0x6e, - 0x74, 0x69, 0x7a, 0x65, 0x4c, 0x69, 0x6e, 0x65, 0x61, 0x72, 0x3a, 0x31, 0x00, 0x00, 0x00, 0x00, - 0x30, 0xef, 0xff, 0xff, 0xb0, 0x0f, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, - 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x84, 0xef, 0xff, 0xff, 0x02, 0x00, 0x00, 0x00, - 0x18, 0xef, 0xff, 0xff, 0x54, 0xef, 0xff, 0xff, 0x68, 0x0f, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, - 0x02, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x42, 0xef, 0xff, 0xff, - 0x00, 0x00, 0x00, 0x01, 0xb0, 0xef, 0xff, 0xff, 0x01, 0x00, 0x00, 0x00, 0x7c, 0xef, 0xff, 0xff, + 0xea, 0xea, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01, 0x24, 0xeb, 0xff, 0xff, 0x20, 0xeb, 0xff, 0xff, + 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x18, 0x00, 0x00, 0x00, + 0x0c, 0x00, 0x00, 0x00, 0x3a, 0x49, 0x64, 0x65, 0x6e, 0x74, 0x69, 0x74, 0x79, 0x3a, 0x32, 0x31, + 0x00, 0x00, 0x00, 0x00, 0x48, 0xeb, 0xff, 0xff, 0xf8, 0x0e, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0x02, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x36, 0xeb, 0xff, 0xff, + 0x00, 0x00, 0x00, 0x01, 0x70, 0xeb, 0xff, 0xff, 0x6c, 0xeb, 0xff, 0xff, 0x18, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, 0x60, 0x00, 0x00, 0x00, 0x40, 0x00, 0x00, 0x00, + 0x20, 0x00, 0x00, 0x00, 0x14, 0x00, 0x00, 0x00, 0x3a, 0x44, 0x65, 0x71, 0x75, 0x61, 0x6e, 0x74, + 0x69, 0x7a, 0x65, 0x4c, 0x69, 0x6e, 0x65, 0x61, 0x72, 0x3a, 0x31, 0x30, 0x00, 0x00, 0x00, 0x00, + 0xa4, 0xeb, 0xff, 0xff, 0x60, 0x08, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, 0x8e, 0xeb, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01, 0xc0, 0xeb, 0xff, 0xff, + 0x8c, 0x08, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0x9c, 0xeb, 0xff, 0xff, 0x01, 0x00, 0x00, 0x00, 0xdc, 0xeb, 0xff, 0xff, 0x78, 0x13, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0xbc, 0xeb, 0xff, 0xff, 0x02, 0x00, 0x00, 0x00, 0x04, 0xec, 0xff, 0xff, 0x00, 0xec, 0xff, 0xff, + 0x14, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x34, 0x00, 0x00, 0x00, + 0x14, 0x00, 0x00, 0x00, 0x0a, 0x00, 0x00, 0x00, 0x3a, 0x47, 0x61, 0x74, 0x68, 0x65, 0x72, 0x3a, + 0x31, 0x31, 0x00, 0x00, 0x28, 0xec, 0xff, 0xff, 0x38, 0x11, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x04, 0xec, 0xff, 0xff, 0x01, 0x00, 0x00, 0x00, + 0x44, 0xec, 0xff, 0xff, 0x10, 0x13, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, + 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x32, 0xec, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01, + 0x6c, 0xec, 0xff, 0xff, 0x68, 0xec, 0xff, 0xff, 0x14, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0x02, 0x00, 0x00, 0x00, 0x48, 0x00, 0x00, 0x00, 0x1c, 0x00, 0x00, 0x00, 0x12, 0x00, 0x00, 0x00, + 0x3a, 0x51, 0x75, 0x61, 0x6e, 0x74, 0x69, 0x7a, 0x65, 0x4c, 0x69, 0x6e, 0x65, 0x61, 0x72, 0x3a, + 0x31, 0x39, 0x00, 0x00, 0x98, 0xec, 0xff, 0xff, 0x84, 0x13, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0x02, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x86, 0xec, 0xff, 0xff, + 0x00, 0x00, 0x00, 0x01, 0x80, 0xec, 0xff, 0xff, 0x02, 0x00, 0x00, 0x00, 0xc0, 0xec, 0xff, 0xff, + 0x24, 0x13, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, 0xa0, 0xec, 0xff, 0xff, 0x01, 0x00, 0x00, 0x00, 0xe8, 0xec, 0xff, 0xff, + 0xe4, 0xec, 0xff, 0xff, 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, + 0x18, 0x00, 0x00, 0x00, 0x0d, 0x00, 0x00, 0x00, 0x3a, 0x54, 0x72, 0x61, 0x6e, 0x73, 0x70, 0x6f, + 0x73, 0x65, 0x3a, 0x32, 0x35, 0x00, 0x00, 0x00, 0x0c, 0xed, 0xff, 0xff, 0x48, 0x12, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0xfa, 0xec, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01, 0x34, 0xed, 0xff, 0xff, 0x30, 0xed, 0xff, 0xff, + 0x18, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, 0x68, 0x00, 0x00, 0x00, + 0x3c, 0x00, 0x00, 0x00, 0x1c, 0x00, 0x00, 0x00, 0x12, 0x00, 0x00, 0x00, 0x3a, 0x51, 0x75, 0x61, + 0x6e, 0x74, 0x69, 0x7a, 0x65, 0x4c, 0x69, 0x6e, 0x65, 0x61, 0x72, 0x3a, 0x31, 0x33, 0x00, 0x00, + 0x64, 0xed, 0xff, 0xff, 0xb0, 0x0a, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, 0x40, 0xed, 0xff, 0xff, 0x01, 0x00, 0x00, 0x00, 0x80, 0xed, 0xff, 0xff, + 0x9c, 0x12, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, 0x6e, 0xed, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01, 0x68, 0xed, 0xff, 0xff, + 0x02, 0x00, 0x00, 0x00, 0xa8, 0xed, 0xff, 0xff, 0x3c, 0x12, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0xc4, 0xed, 0xff, 0xff, 0xc0, 0xed, 0xff, 0xff, + 0x18, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, 0x44, 0x00, 0x00, 0x00, + 0x64, 0x00, 0x00, 0x00, 0x20, 0x00, 0x00, 0x00, 0x14, 0x00, 0x00, 0x00, 0x3a, 0x44, 0x65, 0x71, + 0x75, 0x61, 0x6e, 0x74, 0x69, 0x7a, 0x65, 0x4c, 0x69, 0x6e, 0x65, 0x61, 0x72, 0x3a, 0x32, 0x34, + 0x00, 0x00, 0x00, 0x00, 0xf8, 0xed, 0xff, 0xff, 0xf4, 0x0d, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0xe2, 0xed, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01, + 0x14, 0xee, 0xff, 0xff, 0xd0, 0x11, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, + 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0xf4, 0xed, 0xff, 0xff, 0x02, 0x00, 0x00, 0x00, + 0x3c, 0xee, 0xff, 0xff, 0x38, 0xee, 0xff, 0xff, 0xe4, 0x11, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x14, 0xee, 0xff, 0xff, 0x01, 0x00, 0x00, 0x00, + 0x54, 0xee, 0xff, 0xff, 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, + 0x18, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x00, 0x00, 0x3a, 0x49, 0x64, 0x65, 0x6e, 0x74, 0x69, 0x74, + 0x79, 0x3a, 0x32, 0x33, 0x00, 0x00, 0x00, 0x00, 0x7c, 0xee, 0xff, 0xff, 0xc4, 0x0b, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0x6a, 0xee, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01, 0xa4, 0xee, 0xff, 0xff, 0xa0, 0xee, 0xff, 0xff, 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x14, 0x00, 0x00, 0x00, - 0x0b, 0x00, 0x00, 0x00, 0x3a, 0x53, 0x71, 0x75, 0x65, 0x65, 0x7a, 0x65, 0x3a, 0x31, 0x31, 0x00, - 0xa0, 0xef, 0xff, 0xff, 0x70, 0x10, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, - 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x8e, 0xef, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01, - 0x88, 0xef, 0xff, 0xff, 0xc4, 0xef, 0xff, 0xff, 0x14, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, - 0x02, 0x00, 0x00, 0x00, 0x1c, 0x00, 0x00, 0x00, 0x3c, 0x00, 0x00, 0x00, 0x0d, 0x00, 0x00, 0x00, - 0x3a, 0x55, 0x6e, 0x73, 0x71, 0x75, 0x65, 0x65, 0x7a, 0x65, 0x3a, 0x32, 0x34, 0x00, 0x00, 0x00, - 0xf0, 0xef, 0xff, 0xff, 0x20, 0x10, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, - 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0xde, 0xef, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01, - 0xd8, 0xef, 0xff, 0xff, 0x14, 0xf0, 0xff, 0xff, 0xe0, 0x0b, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, - 0x01, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x64, 0xf0, 0xff, 0xff, 0x01, 0x00, 0x00, 0x00, - 0x30, 0xf0, 0xff, 0xff, 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, - 0x18, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x00, 0x00, 0x3a, 0x55, 0x6e, 0x73, 0x71, 0x75, 0x65, 0x65, - 0x7a, 0x65, 0x3a, 0x31, 0x00, 0x00, 0x00, 0x00, 0x58, 0xf0, 0xff, 0xff, 0xb8, 0x0f, 0x00, 0x00, + 0x0a, 0x00, 0x00, 0x00, 0x3a, 0x53, 0x71, 0x75, 0x65, 0x65, 0x7a, 0x65, 0x3a, 0x31, 0x00, 0x00, + 0xc4, 0xee, 0xff, 0xff, 0x90, 0x10, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, + 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0xb2, 0xee, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01, + 0xec, 0xee, 0xff, 0xff, 0xe8, 0xee, 0xff, 0xff, 0x14, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0x02, 0x00, 0x00, 0x00, 0x34, 0x00, 0x00, 0x00, 0x14, 0x00, 0x00, 0x00, 0x0b, 0x00, 0x00, 0x00, + 0x3a, 0x53, 0x71, 0x75, 0x65, 0x65, 0x7a, 0x65, 0x3a, 0x32, 0x31, 0x00, 0x10, 0xef, 0xff, 0xff, + 0x6c, 0x10, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0xec, 0xee, 0xff, 0xff, 0x01, 0x00, 0x00, 0x00, 0x2c, 0xef, 0xff, 0xff, 0x28, 0x10, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, - 0x46, 0xf0, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01, 0x40, 0xf0, 0xff, 0xff, 0x7c, 0xf0, 0xff, 0xff, - 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x18, 0x00, 0x00, 0x00, - 0x0c, 0x00, 0x00, 0x00, 0x3a, 0x49, 0x64, 0x65, 0x6e, 0x74, 0x69, 0x74, 0x79, 0x3a, 0x32, 0x35, - 0x00, 0x00, 0x00, 0x00, 0xa4, 0xf0, 0xff, 0xff, 0xec, 0x0b, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, - 0x02, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x92, 0xf0, 0xff, 0xff, - 0x00, 0x00, 0x00, 0x01, 0x8c, 0xf0, 0xff, 0xff, 0xc8, 0xf0, 0xff, 0xff, 0x10, 0x00, 0x00, 0x00, - 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x18, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x00, 0x00, - 0x3a, 0x49, 0x64, 0x65, 0x6e, 0x74, 0x69, 0x74, 0x79, 0x3a, 0x31, 0x36, 0x00, 0x00, 0x00, 0x00, - 0xf0, 0xf0, 0xff, 0xff, 0xa0, 0x0b, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, - 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0xde, 0xf0, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01, - 0xd8, 0xf0, 0xff, 0xff, 0x14, 0xf1, 0xff, 0xff, 0x14, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, - 0x02, 0x00, 0x00, 0x00, 0x18, 0x00, 0x00, 0x00, 0x38, 0x00, 0x00, 0x00, 0x0b, 0x00, 0x00, 0x00, - 0x3a, 0x53, 0x71, 0x75, 0x65, 0x65, 0x7a, 0x65, 0x3a, 0x32, 0x31, 0x00, 0x3c, 0xf1, 0xff, 0xff, - 0xd4, 0x0e, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, - 0x04, 0x00, 0x00, 0x00, 0x2a, 0xf1, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01, 0x24, 0xf1, 0xff, 0xff, - 0x60, 0xf1, 0xff, 0xff, 0x94, 0x0a, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, - 0x04, 0x00, 0x00, 0x00, 0xb0, 0xf1, 0xff, 0xff, 0x01, 0x00, 0x00, 0x00, 0x7c, 0xf1, 0xff, 0xff, + 0x1a, 0xef, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01, 0x54, 0xef, 0xff, 0xff, 0x50, 0xef, 0xff, 0xff, 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x18, 0x00, 0x00, 0x00, - 0x0c, 0x00, 0x00, 0x00, 0x3a, 0x49, 0x64, 0x65, 0x6e, 0x74, 0x69, 0x74, 0x79, 0x3a, 0x31, 0x34, - 0x00, 0x00, 0x00, 0x00, 0xa4, 0xf1, 0xff, 0xff, 0xec, 0x0a, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, - 0x02, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x92, 0xf1, 0xff, 0xff, - 0x00, 0x00, 0x00, 0x01, 0x8c, 0xf1, 0xff, 0xff, 0xc8, 0xf1, 0xff, 0xff, 0x10, 0x00, 0x00, 0x00, + 0x0c, 0x00, 0x00, 0x00, 0x3a, 0x55, 0x6e, 0x73, 0x71, 0x75, 0x65, 0x65, 0x7a, 0x65, 0x3a, 0x31, + 0x00, 0x00, 0x00, 0x00, 0x78, 0xef, 0xff, 0xff, 0xdc, 0x0f, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0x02, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x66, 0xef, 0xff, 0xff, + 0x00, 0x00, 0x00, 0x01, 0xa0, 0xef, 0xff, 0xff, 0x9c, 0xef, 0xff, 0xff, 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x18, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x00, 0x00, - 0x3a, 0x49, 0x64, 0x65, 0x6e, 0x74, 0x69, 0x74, 0x79, 0x3a, 0x32, 0x33, 0x00, 0x00, 0x00, 0x00, - 0xf0, 0xf1, 0xff, 0xff, 0xa0, 0x0a, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, - 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0xde, 0xf1, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01, - 0xd8, 0xf1, 0xff, 0xff, 0x14, 0xf2, 0xff, 0xff, 0x14, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, - 0x02, 0x00, 0x00, 0x00, 0x34, 0x00, 0x00, 0x00, 0x14, 0x00, 0x00, 0x00, 0x0a, 0x00, 0x00, 0x00, - 0x3a, 0x47, 0x61, 0x74, 0x68, 0x65, 0x72, 0x3a, 0x31, 0x33, 0x00, 0x00, 0x3c, 0xf2, 0xff, 0xff, - 0xa0, 0x0b, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, - 0x8c, 0xf2, 0xff, 0xff, 0x01, 0x00, 0x00, 0x00, 0x58, 0xf2, 0xff, 0xff, 0xb8, 0x0d, 0x00, 0x00, + 0x3a, 0x49, 0x64, 0x65, 0x6e, 0x74, 0x69, 0x74, 0x79, 0x3a, 0x31, 0x33, 0x00, 0x00, 0x00, 0x00, + 0xc4, 0xef, 0xff, 0xff, 0x90, 0x0f, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, + 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0xb2, 0xef, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01, + 0xec, 0xef, 0xff, 0xff, 0xe8, 0xef, 0xff, 0xff, 0x14, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0x02, 0x00, 0x00, 0x00, 0x38, 0x00, 0x00, 0x00, 0x18, 0x00, 0x00, 0x00, 0x0d, 0x00, 0x00, 0x00, + 0x3a, 0x55, 0x6e, 0x73, 0x71, 0x75, 0x65, 0x65, 0x7a, 0x65, 0x3a, 0x32, 0x31, 0x00, 0x00, 0x00, + 0x14, 0xf0, 0xff, 0xff, 0x68, 0x0f, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, 0xf0, 0xef, 0xff, 0xff, 0x01, 0x00, 0x00, 0x00, 0x30, 0xf0, 0xff, 0xff, + 0x24, 0x0f, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, 0x1e, 0xf0, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01, 0x58, 0xf0, 0xff, 0xff, + 0x54, 0xf0, 0xff, 0xff, 0x28, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x07, 0x00, 0x00, 0x00, + 0x98, 0x00, 0x00, 0x00, 0xb8, 0x00, 0x00, 0x00, 0x04, 0x01, 0x00, 0x00, 0x30, 0x00, 0x00, 0x00, + 0xd4, 0x00, 0x00, 0x00, 0x68, 0x00, 0x00, 0x00, 0x48, 0x00, 0x00, 0x00, 0x1b, 0x00, 0x00, 0x00, + 0x63, 0x6f, 0x6d, 0x2e, 0x6d, 0x69, 0x63, 0x72, 0x6f, 0x73, 0x6f, 0x66, 0x74, 0x3a, 0x51, 0x4c, + 0x69, 0x6e, 0x65, 0x61, 0x72, 0x43, 0x6f, 0x6e, 0x76, 0x3a, 0x31, 0x00, 0xa0, 0xf0, 0xff, 0xff, + 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x00, 0x00, + 0x02, 0x00, 0x00, 0x00, 0x54, 0x34, 0x00, 0x00, 0x84, 0xf0, 0xff, 0xff, 0x08, 0x00, 0x00, 0x00, + 0xc4, 0xf0, 0xff, 0xff, 0x50, 0x07, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, 0xa0, 0xf0, 0xff, 0xff, 0x06, 0x00, 0x00, 0x00, 0xe0, 0xf0, 0xff, 0xff, + 0x6c, 0x03, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0xbc, 0xf0, 0xff, 0xff, 0x01, 0x00, 0x00, 0x00, 0xfc, 0xf0, 0xff, 0xff, 0xe8, 0x0e, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, - 0x46, 0xf2, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01, 0x40, 0xf2, 0xff, 0xff, 0x7c, 0xf2, 0xff, 0xff, - 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x18, 0x00, 0x00, 0x00, - 0x0d, 0x00, 0x00, 0x00, 0x3a, 0x54, 0x72, 0x61, 0x6e, 0x73, 0x70, 0x6f, 0x73, 0x65, 0x3a, 0x32, - 0x31, 0x00, 0x00, 0x00, 0xa4, 0xf2, 0xff, 0xff, 0x6c, 0x0d, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, - 0x02, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x92, 0xf2, 0xff, 0xff, - 0x00, 0x00, 0x00, 0x01, 0x8c, 0xf2, 0xff, 0xff, 0xc8, 0xf2, 0xff, 0xff, 0x10, 0x00, 0x00, 0x00, - 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x18, 0x00, 0x00, 0x00, 0x0d, 0x00, 0x00, 0x00, - 0x3a, 0x54, 0x72, 0x61, 0x6e, 0x73, 0x70, 0x6f, 0x73, 0x65, 0x3a, 0x32, 0x35, 0x00, 0x00, 0x00, - 0xf0, 0xf2, 0xff, 0xff, 0x20, 0x0d, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, - 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0xde, 0xf2, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01, - 0xd8, 0xf2, 0xff, 0xff, 0x14, 0xf3, 0xff, 0xff, 0x14, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0xdc, 0xf0, 0xff, 0xff, 0x02, 0x00, 0x00, 0x00, 0x24, 0xf1, 0xff, 0xff, 0x20, 0xf1, 0xff, 0xff, + 0xfc, 0x0e, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, 0x00, 0xf1, 0xff, 0xff, 0x05, 0x00, 0x00, 0x00, 0x08, 0xf1, 0xff, 0xff, + 0x03, 0x00, 0x00, 0x00, 0x48, 0xf1, 0xff, 0xff, 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x07, 0x00, 0x00, 0x00, 0x77, 0x5f, 0x73, 0x63, + 0x61, 0x6c, 0x65, 0x00, 0x30, 0xf1, 0xff, 0xff, 0x04, 0x00, 0x00, 0x00, 0x70, 0xf1, 0xff, 0xff, + 0x7c, 0x0a, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, 0x5e, 0xf1, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01, 0x58, 0xf1, 0xff, 0xff, + 0x07, 0x00, 0x00, 0x00, 0x98, 0xf1, 0xff, 0xff, 0x14, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x34, 0x00, 0x00, 0x00, 0x14, 0x00, 0x00, 0x00, 0x0b, 0x00, 0x00, 0x00, - 0x3a, 0x53, 0x71, 0x75, 0x65, 0x65, 0x7a, 0x65, 0x3a, 0x32, 0x33, 0x00, 0x3c, 0xf3, 0xff, 0xff, - 0xb8, 0x08, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, - 0x8c, 0xf3, 0xff, 0xff, 0x01, 0x00, 0x00, 0x00, 0x58, 0xf3, 0xff, 0xff, 0xb8, 0x0c, 0x00, 0x00, + 0x3a, 0x53, 0x71, 0x75, 0x65, 0x65, 0x7a, 0x65, 0x3a, 0x32, 0x35, 0x00, 0xc0, 0xf1, 0xff, 0xff, + 0xbc, 0x0d, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0x9c, 0xf1, 0xff, 0xff, 0x01, 0x00, 0x00, 0x00, 0xdc, 0xf1, 0xff, 0xff, 0x78, 0x0d, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, - 0x46, 0xf3, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01, 0x40, 0xf3, 0xff, 0xff, 0x7c, 0xf3, 0xff, 0xff, - 0x18, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, 0x44, 0x00, 0x00, 0x00, - 0x64, 0x00, 0x00, 0x00, 0x20, 0x00, 0x00, 0x00, 0x14, 0x00, 0x00, 0x00, 0x3a, 0x44, 0x65, 0x71, - 0x75, 0x61, 0x6e, 0x74, 0x69, 0x7a, 0x65, 0x4c, 0x69, 0x6e, 0x65, 0x61, 0x72, 0x3a, 0x32, 0x33, - 0x00, 0x00, 0x00, 0x00, 0xb4, 0xf3, 0xff, 0xff, 0x58, 0x0b, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, - 0x01, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x9e, 0xf3, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01, - 0xd0, 0xf3, 0xff, 0xff, 0x10, 0x0b, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, - 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x24, 0xf4, 0xff, 0xff, 0x02, 0x00, 0x00, 0x00, - 0xb8, 0xf3, 0xff, 0xff, 0xf4, 0xf3, 0xff, 0xff, 0xc8, 0x0a, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, - 0x01, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x44, 0xf4, 0xff, 0xff, 0x01, 0x00, 0x00, 0x00, - 0x10, 0xf4, 0xff, 0xff, 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, - 0x18, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x00, 0x00, 0x3a, 0x49, 0x64, 0x65, 0x6e, 0x74, 0x69, 0x74, - 0x79, 0x3a, 0x31, 0x39, 0x00, 0x00, 0x00, 0x00, 0x38, 0xf4, 0xff, 0xff, 0x58, 0x08, 0x00, 0x00, + 0xca, 0xf1, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01, 0x04, 0xf2, 0xff, 0xff, 0x00, 0xf2, 0xff, 0xff, + 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x18, 0x00, 0x00, 0x00, + 0x0c, 0x00, 0x00, 0x00, 0x3a, 0x49, 0x64, 0x65, 0x6e, 0x74, 0x69, 0x74, 0x79, 0x3a, 0x31, 0x34, + 0x00, 0x00, 0x00, 0x00, 0x28, 0xf2, 0xff, 0xff, 0x18, 0x08, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0x02, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x16, 0xf2, 0xff, 0xff, + 0x00, 0x00, 0x00, 0x01, 0x50, 0xf2, 0xff, 0xff, 0x4c, 0xf2, 0xff, 0xff, 0x14, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x58, 0x00, 0x00, 0x00, 0x2c, 0x00, 0x00, 0x00, + 0x20, 0x00, 0x00, 0x00, 0x63, 0x6f, 0x6d, 0x2e, 0x6d, 0x69, 0x63, 0x72, 0x6f, 0x73, 0x6f, 0x66, + 0x74, 0x3a, 0x44, 0x65, 0x71, 0x75, 0x61, 0x6e, 0x74, 0x69, 0x7a, 0x65, 0x4c, 0x69, 0x6e, 0x65, + 0x61, 0x72, 0x3a, 0x31, 0x00, 0x00, 0x00, 0x00, 0x8c, 0xf2, 0xff, 0xff, 0x90, 0x0d, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, - 0x26, 0xf4, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01, 0x20, 0xf4, 0xff, 0xff, 0x5c, 0xf4, 0xff, 0xff, + 0x7a, 0xf2, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01, 0x74, 0xf2, 0xff, 0xff, 0x01, 0x00, 0x00, 0x00, + 0xb4, 0xf2, 0xff, 0xff, 0x30, 0x0d, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, + 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x94, 0xf2, 0xff, 0xff, 0x02, 0x00, 0x00, 0x00, + 0xdc, 0xf2, 0xff, 0xff, 0xd8, 0xf2, 0xff, 0xff, 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, 0x18, 0x00, 0x00, 0x00, 0x0d, 0x00, 0x00, 0x00, 0x3a, 0x55, 0x6e, 0x73, + 0x71, 0x75, 0x65, 0x65, 0x7a, 0x65, 0x3a, 0x31, 0x31, 0x00, 0x00, 0x00, 0x00, 0xf3, 0xff, 0xff, + 0x54, 0x0c, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, 0xee, 0xf2, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01, 0x28, 0xf3, 0xff, 0xff, + 0x24, 0xf3, 0xff, 0xff, 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, + 0x18, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x00, 0x00, 0x3a, 0x54, 0x72, 0x61, 0x6e, 0x73, 0x70, 0x6f, + 0x73, 0x65, 0x3a, 0x31, 0x00, 0x00, 0x00, 0x00, 0x4c, 0xf3, 0xff, 0xff, 0x08, 0x0c, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0x3a, 0xf3, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01, 0x74, 0xf3, 0xff, 0xff, 0x70, 0xf3, 0xff, 0xff, 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x18, 0x00, 0x00, 0x00, - 0x0d, 0x00, 0x00, 0x00, 0x3a, 0x54, 0x72, 0x61, 0x6e, 0x73, 0x70, 0x6f, 0x73, 0x65, 0x3a, 0x32, - 0x33, 0x00, 0x00, 0x00, 0x84, 0xf4, 0xff, 0xff, 0x8c, 0x0b, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, - 0x02, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x72, 0xf4, 0xff, 0xff, - 0x00, 0x00, 0x00, 0x01, 0x6c, 0xf4, 0xff, 0xff, 0xa8, 0xf4, 0xff, 0xff, 0x18, 0x00, 0x00, 0x00, - 0x04, 0x00, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, 0x44, 0x00, 0x00, 0x00, 0x64, 0x00, 0x00, 0x00, + 0x0c, 0x00, 0x00, 0x00, 0x3a, 0x49, 0x64, 0x65, 0x6e, 0x74, 0x69, 0x74, 0x79, 0x3a, 0x31, 0x39, + 0x00, 0x00, 0x00, 0x00, 0x98, 0xf3, 0xff, 0xff, 0xa8, 0x06, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0x02, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x86, 0xf3, 0xff, 0xff, + 0x00, 0x00, 0x00, 0x01, 0xc0, 0xf3, 0xff, 0xff, 0xbc, 0xf3, 0xff, 0xff, 0x18, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, 0x4c, 0x00, 0x00, 0x00, 0x6c, 0x00, 0x00, 0x00, 0x20, 0x00, 0x00, 0x00, 0x14, 0x00, 0x00, 0x00, 0x3a, 0x44, 0x65, 0x71, 0x75, 0x61, 0x6e, 0x74, 0x69, 0x7a, 0x65, 0x4c, 0x69, 0x6e, 0x65, 0x61, 0x72, 0x3a, 0x31, 0x33, 0x00, 0x00, 0x00, 0x00, - 0xe0, 0xf4, 0xff, 0xff, 0xb0, 0x0a, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, - 0x04, 0x00, 0x00, 0x00, 0xca, 0xf4, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01, 0xfc, 0xf4, 0xff, 0xff, - 0x14, 0x0b, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, - 0x04, 0x00, 0x00, 0x00, 0x50, 0xf5, 0xff, 0xff, 0x02, 0x00, 0x00, 0x00, 0xe4, 0xf4, 0xff, 0xff, - 0x20, 0xf5, 0xff, 0xff, 0x48, 0x0a, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, - 0x04, 0x00, 0x00, 0x00, 0x70, 0xf5, 0xff, 0xff, 0x01, 0x00, 0x00, 0x00, 0x3c, 0xf5, 0xff, 0xff, - 0x18, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, 0x28, 0x00, 0x00, 0x00, - 0x64, 0x00, 0x00, 0x00, 0x44, 0x00, 0x00, 0x00, 0x14, 0x00, 0x00, 0x00, 0x3a, 0x44, 0x65, 0x71, - 0x75, 0x61, 0x6e, 0x74, 0x69, 0x7a, 0x65, 0x4c, 0x69, 0x6e, 0x65, 0x61, 0x72, 0x3a, 0x32, 0x35, - 0x00, 0x00, 0x00, 0x00, 0x74, 0xf5, 0xff, 0xff, 0x6c, 0x09, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, - 0x02, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0xc8, 0xf5, 0xff, 0xff, - 0x02, 0x00, 0x00, 0x00, 0x5c, 0xf5, 0xff, 0xff, 0x98, 0xf5, 0xff, 0xff, 0x74, 0x09, 0x00, 0x00, - 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x82, 0xf5, 0xff, 0xff, - 0x00, 0x00, 0x00, 0x01, 0xb4, 0xf5, 0xff, 0xff, 0x08, 0x09, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, - 0x01, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x04, 0xf6, 0xff, 0xff, 0x01, 0x00, 0x00, 0x00, - 0xd0, 0xf5, 0xff, 0xff, 0x14, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, - 0x34, 0x00, 0x00, 0x00, 0x14, 0x00, 0x00, 0x00, 0x09, 0x00, 0x00, 0x00, 0x3a, 0x47, 0x61, 0x74, - 0x68, 0x65, 0x72, 0x3a, 0x31, 0x00, 0x00, 0x00, 0xf8, 0xf5, 0xff, 0xff, 0xe4, 0x07, 0x00, 0x00, - 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x48, 0xf6, 0xff, 0xff, - 0x01, 0x00, 0x00, 0x00, 0x14, 0xf6, 0xff, 0xff, 0xfc, 0x09, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, - 0x02, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x02, 0xf6, 0xff, 0xff, - 0x00, 0x00, 0x00, 0x01, 0xfc, 0xf5, 0xff, 0xff, 0x38, 0xf6, 0xff, 0xff, 0x14, 0x00, 0x00, 0x00, - 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x34, 0x00, 0x00, 0x00, 0x14, 0x00, 0x00, 0x00, - 0x0b, 0x00, 0x00, 0x00, 0x3a, 0x53, 0x71, 0x75, 0x65, 0x65, 0x7a, 0x65, 0x3a, 0x32, 0x34, 0x00, - 0x60, 0xf6, 0xff, 0xff, 0x94, 0x05, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, - 0x04, 0x00, 0x00, 0x00, 0xb0, 0xf6, 0xff, 0xff, 0x01, 0x00, 0x00, 0x00, 0x7c, 0xf6, 0xff, 0xff, - 0x94, 0x09, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, - 0x04, 0x00, 0x00, 0x00, 0x6a, 0xf6, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01, 0x64, 0xf6, 0xff, 0xff, - 0xa0, 0xf6, 0xff, 0xff, 0x14, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, - 0x20, 0x00, 0x00, 0x00, 0x40, 0x00, 0x00, 0x00, 0x12, 0x00, 0x00, 0x00, 0x3a, 0x51, 0x75, 0x61, - 0x6e, 0x74, 0x69, 0x7a, 0x65, 0x4c, 0x69, 0x6e, 0x65, 0x61, 0x72, 0x3a, 0x31, 0x39, 0x00, 0x00, - 0xd0, 0xf6, 0xff, 0xff, 0x10, 0x08, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, - 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x24, 0xf7, 0xff, 0xff, 0x01, 0x00, 0x00, 0x00, - 0xb8, 0xf6, 0xff, 0xff, 0xf4, 0xf6, 0xff, 0xff, 0xc8, 0x07, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, - 0x02, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0xe2, 0xf6, 0xff, 0xff, - 0x00, 0x00, 0x00, 0x01, 0x50, 0xf7, 0xff, 0xff, 0x02, 0x00, 0x00, 0x00, 0x1c, 0xf7, 0xff, 0xff, - 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x18, 0x00, 0x00, 0x00, - 0x0c, 0x00, 0x00, 0x00, 0x3a, 0x49, 0x64, 0x65, 0x6e, 0x74, 0x69, 0x74, 0x79, 0x3a, 0x32, 0x31, - 0x00, 0x00, 0x00, 0x00, 0x44, 0xf7, 0xff, 0xff, 0x4c, 0x05, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, - 0x02, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x32, 0xf7, 0xff, 0xff, - 0x00, 0x00, 0x00, 0x01, 0x2c, 0xf7, 0xff, 0xff, 0x68, 0xf7, 0xff, 0xff, 0x10, 0x00, 0x00, 0x00, - 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x14, 0x00, 0x00, 0x00, 0x0b, 0x00, 0x00, 0x00, - 0x3a, 0x49, 0x64, 0x65, 0x6e, 0x74, 0x69, 0x74, 0x79, 0x3a, 0x31, 0x00, 0x8c, 0xf7, 0xff, 0xff, - 0x84, 0x08, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, - 0x04, 0x00, 0x00, 0x00, 0x7a, 0xf7, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01, 0x74, 0xf7, 0xff, 0xff, - 0xb0, 0xf7, 0xff, 0xff, 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, + 0xf4, 0xf3, 0xff, 0xff, 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, + 0x0c, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x79, 0x00, 0x00, 0x00, 0xe6, 0xf3, 0xff, 0xff, + 0x00, 0x00, 0x00, 0x01, 0x18, 0xf4, 0xff, 0xff, 0x3c, 0x0b, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0x02, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0xf8, 0xf3, 0xff, 0xff, + 0x02, 0x00, 0x00, 0x00, 0x40, 0xf4, 0xff, 0xff, 0x3c, 0xf4, 0xff, 0xff, 0x10, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x07, 0x00, 0x00, 0x00, + 0x78, 0x5f, 0x73, 0x63, 0x61, 0x6c, 0x65, 0x00, 0x24, 0xf4, 0xff, 0xff, 0x01, 0x00, 0x00, 0x00, + 0x64, 0xf4, 0xff, 0xff, 0x14, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, + 0x34, 0x00, 0x00, 0x00, 0x14, 0x00, 0x00, 0x00, 0x0a, 0x00, 0x00, 0x00, 0x3a, 0x47, 0x61, 0x74, + 0x68, 0x65, 0x72, 0x3a, 0x31, 0x33, 0x00, 0x00, 0x8c, 0xf4, 0xff, 0xff, 0xd4, 0x08, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x68, 0xf4, 0xff, 0xff, + 0x01, 0x00, 0x00, 0x00, 0xa8, 0xf4, 0xff, 0xff, 0xac, 0x0a, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0x02, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x96, 0xf4, 0xff, 0xff, + 0x00, 0x00, 0x00, 0x01, 0xd0, 0xf4, 0xff, 0xff, 0xcc, 0xf4, 0xff, 0xff, 0x10, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x24, 0x00, 0x00, 0x00, 0x1b, 0x00, 0x00, 0x00, + 0x63, 0x6f, 0x6d, 0x2e, 0x6d, 0x69, 0x63, 0x72, 0x6f, 0x73, 0x6f, 0x66, 0x74, 0x3a, 0x4e, 0x68, + 0x77, 0x63, 0x4d, 0x61, 0x78, 0x50, 0x6f, 0x6f, 0x6c, 0x3a, 0x31, 0x00, 0x00, 0xf5, 0xff, 0xff, + 0x54, 0x0a, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, 0xee, 0xf4, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01, 0x28, 0xf5, 0xff, 0xff, + 0x24, 0xf5, 0xff, 0xff, 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x18, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x00, 0x00, 0x3a, 0x49, 0x64, 0x65, 0x6e, 0x74, 0x69, 0x74, - 0x79, 0x3a, 0x31, 0x33, 0x00, 0x00, 0x00, 0x00, 0xd8, 0xf7, 0xff, 0xff, 0x38, 0x08, 0x00, 0x00, + 0x79, 0x3a, 0x32, 0x34, 0x00, 0x00, 0x00, 0x00, 0x4c, 0xf5, 0xff, 0xff, 0xf4, 0x04, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, - 0xc6, 0xf7, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01, 0xc0, 0xf7, 0xff, 0xff, 0xfc, 0xf7, 0xff, 0xff, + 0x3a, 0xf5, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01, 0x74, 0xf5, 0xff, 0xff, 0x70, 0xf5, 0xff, 0xff, 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x18, 0x00, 0x00, 0x00, - 0x0d, 0x00, 0x00, 0x00, 0x3a, 0x54, 0x72, 0x61, 0x6e, 0x73, 0x70, 0x6f, 0x73, 0x65, 0x3a, 0x31, - 0x33, 0x00, 0x00, 0x00, 0x24, 0xf8, 0xff, 0xff, 0xec, 0x07, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, - 0x02, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x12, 0xf8, 0xff, 0xff, - 0x00, 0x00, 0x00, 0x01, 0x0c, 0xf8, 0xff, 0xff, 0x48, 0xf8, 0xff, 0xff, 0x18, 0x00, 0x00, 0x00, - 0x04, 0x00, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, 0x74, 0x00, 0x00, 0x00, 0x48, 0x00, 0x00, 0x00, - 0x1c, 0x00, 0x00, 0x00, 0x12, 0x00, 0x00, 0x00, 0x3a, 0x51, 0x75, 0x61, 0x6e, 0x74, 0x69, 0x7a, - 0x65, 0x4c, 0x69, 0x6e, 0x65, 0x61, 0x72, 0x3a, 0x31, 0x30, 0x00, 0x00, 0x7c, 0xf8, 0xff, 0xff, - 0x08, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x07, 0x00, 0x00, 0x00, 0x79, 0x5f, 0x73, 0x63, - 0x61, 0x6c, 0x65, 0x00, 0x01, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0xd8, 0xf8, 0xff, 0xff, - 0x01, 0x00, 0x00, 0x00, 0xa4, 0xf8, 0xff, 0xff, 0x18, 0x06, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, - 0x02, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x92, 0xf8, 0xff, 0xff, - 0x00, 0x00, 0x00, 0x01, 0x00, 0xf9, 0xff, 0xff, 0x02, 0x00, 0x00, 0x00, 0xcc, 0xf8, 0xff, 0xff, - 0x14, 0x06, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, - 0xa8, 0xf8, 0xff, 0xff, 0xe4, 0xf8, 0xff, 0xff, 0x14, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, - 0x02, 0x00, 0x00, 0x00, 0x1c, 0x00, 0x00, 0x00, 0x3c, 0x00, 0x00, 0x00, 0x0d, 0x00, 0x00, 0x00, - 0x3a, 0x55, 0x6e, 0x73, 0x71, 0x75, 0x65, 0x65, 0x7a, 0x65, 0x3a, 0x32, 0x35, 0x00, 0x00, 0x00, - 0x10, 0xf9, 0xff, 0xff, 0x00, 0x07, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, - 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0xfe, 0xf8, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01, - 0xf8, 0xf8, 0xff, 0xff, 0x34, 0xf9, 0xff, 0xff, 0xc0, 0x02, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, - 0x01, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x84, 0xf9, 0xff, 0xff, 0x01, 0x00, 0x00, 0x00, - 0x50, 0xf9, 0xff, 0xff, 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, - 0x14, 0x00, 0x00, 0x00, 0x0a, 0x00, 0x00, 0x00, 0x3a, 0x53, 0x71, 0x75, 0x65, 0x65, 0x7a, 0x65, - 0x3a, 0x31, 0x00, 0x00, 0x74, 0xf9, 0xff, 0xff, 0x9c, 0x06, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, - 0x02, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x62, 0xf9, 0xff, 0xff, - 0x00, 0x00, 0x00, 0x01, 0x5c, 0xf9, 0xff, 0xff, 0x98, 0xf9, 0xff, 0xff, 0x18, 0x00, 0x00, 0x00, - 0x04, 0x00, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, 0x24, 0x00, 0x00, 0x00, 0x60, 0x00, 0x00, 0x00, - 0x34, 0x00, 0x00, 0x00, 0x12, 0x00, 0x00, 0x00, 0x3a, 0x51, 0x75, 0x61, 0x6e, 0x74, 0x69, 0x7a, - 0x65, 0x4c, 0x69, 0x6e, 0x65, 0x61, 0x72, 0x3a, 0x32, 0x35, 0x00, 0x00, 0xcc, 0xf9, 0xff, 0xff, - 0x14, 0x05, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, - 0xa8, 0xf9, 0xff, 0xff, 0xe4, 0xf9, 0xff, 0xff, 0x28, 0x05, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, - 0x02, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0xd2, 0xf9, 0xff, 0xff, - 0x00, 0x00, 0x00, 0x01, 0x40, 0xfa, 0xff, 0xff, 0x02, 0x00, 0x00, 0x00, 0x0c, 0xfa, 0xff, 0xff, - 0xb0, 0x04, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, - 0x5c, 0xfa, 0xff, 0xff, 0x01, 0x00, 0x00, 0x00, 0x28, 0xfa, 0xff, 0xff, 0x14, 0x00, 0x00, 0x00, - 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x1c, 0x00, 0x00, 0x00, 0x3c, 0x00, 0x00, 0x00, - 0x0d, 0x00, 0x00, 0x00, 0x3a, 0x55, 0x6e, 0x73, 0x71, 0x75, 0x65, 0x65, 0x7a, 0x65, 0x3a, 0x31, - 0x33, 0x00, 0x00, 0x00, 0x54, 0xfa, 0xff, 0xff, 0xbc, 0x05, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, - 0x02, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x42, 0xfa, 0xff, 0xff, - 0x00, 0x00, 0x00, 0x01, 0x3c, 0xfa, 0xff, 0xff, 0x78, 0xfa, 0xff, 0xff, 0x7c, 0x01, 0x00, 0x00, - 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0xc8, 0xfa, 0xff, 0xff, - 0x01, 0x00, 0x00, 0x00, 0x94, 0xfa, 0xff, 0xff, 0x18, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, - 0x03, 0x00, 0x00, 0x00, 0x68, 0x00, 0x00, 0x00, 0x20, 0x00, 0x00, 0x00, 0x38, 0x00, 0x00, 0x00, + 0x0d, 0x00, 0x00, 0x00, 0x3a, 0x54, 0x72, 0x61, 0x6e, 0x73, 0x70, 0x6f, 0x73, 0x65, 0x3a, 0x32, + 0x33, 0x00, 0x00, 0x00, 0x98, 0xf5, 0xff, 0xff, 0xbc, 0x09, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0x02, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x86, 0xf5, 0xff, 0xff, + 0x00, 0x00, 0x00, 0x01, 0xc0, 0xf5, 0xff, 0xff, 0xbc, 0xf5, 0xff, 0xff, 0x14, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x20, 0x00, 0x00, 0x00, 0x40, 0x00, 0x00, 0x00, 0x12, 0x00, 0x00, 0x00, 0x3a, 0x51, 0x75, 0x61, 0x6e, 0x74, 0x69, 0x7a, 0x65, 0x4c, 0x69, 0x6e, - 0x65, 0x61, 0x72, 0x3a, 0x32, 0x33, 0x00, 0x00, 0xc8, 0xfa, 0xff, 0xff, 0xf4, 0x03, 0x00, 0x00, - 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x18, 0xfb, 0xff, 0xff, - 0x01, 0x00, 0x00, 0x00, 0xe4, 0xfa, 0xff, 0xff, 0x28, 0x04, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, - 0x02, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0xd2, 0xfa, 0xff, 0xff, - 0x00, 0x00, 0x00, 0x01, 0x40, 0xfb, 0xff, 0xff, 0x02, 0x00, 0x00, 0x00, 0x0c, 0xfb, 0xff, 0xff, - 0xd4, 0x03, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, - 0xe8, 0xfa, 0xff, 0xff, 0x24, 0xfb, 0xff, 0xff, 0x14, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, - 0x02, 0x00, 0x00, 0x00, 0x48, 0x00, 0x00, 0x00, 0x1c, 0x00, 0x00, 0x00, 0x12, 0x00, 0x00, 0x00, - 0x3a, 0x51, 0x75, 0x61, 0x6e, 0x74, 0x69, 0x7a, 0x65, 0x4c, 0x69, 0x6e, 0x65, 0x61, 0x72, 0x3a, - 0x32, 0x31, 0x00, 0x00, 0x54, 0xfb, 0xff, 0xff, 0x68, 0x03, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, - 0x02, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x42, 0xfb, 0xff, 0xff, - 0x00, 0x00, 0x00, 0x01, 0xb0, 0xfb, 0xff, 0xff, 0x02, 0x00, 0x00, 0x00, 0x7c, 0xfb, 0xff, 0xff, - 0x64, 0x03, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, - 0x04, 0x00, 0x00, 0x00, 0xd0, 0xfb, 0xff, 0xff, 0x01, 0x00, 0x00, 0x00, 0x64, 0xfb, 0xff, 0xff, - 0xa0, 0xfb, 0xff, 0xff, 0x14, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, - 0x18, 0x00, 0x00, 0x00, 0x38, 0x00, 0x00, 0x00, 0x0b, 0x00, 0x00, 0x00, 0x3a, 0x53, 0x71, 0x75, - 0x65, 0x65, 0x7a, 0x65, 0x3a, 0x31, 0x33, 0x00, 0xc8, 0xfb, 0xff, 0xff, 0x48, 0x04, 0x00, 0x00, + 0x65, 0x61, 0x72, 0x3a, 0x32, 0x31, 0x00, 0x00, 0xec, 0xf5, 0xff, 0xff, 0xf8, 0x09, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0xcc, 0xf5, 0xff, 0xff, 0x01, 0x00, 0x00, 0x00, 0x14, 0xf6, 0xff, 0xff, 0x10, 0xf6, 0xff, 0xff, + 0x0c, 0x0a, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, 0xfe, 0xf5, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01, 0xf8, 0xf5, 0xff, 0xff, + 0x02, 0x00, 0x00, 0x00, 0x38, 0xf6, 0xff, 0xff, 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, 0x28, 0x00, 0x00, 0x00, 0x1d, 0x00, 0x00, 0x00, 0x63, 0x6f, 0x6d, 0x2e, + 0x6d, 0x69, 0x63, 0x72, 0x6f, 0x73, 0x6f, 0x66, 0x74, 0x3a, 0x4e, 0x68, 0x77, 0x63, 0x46, 0x75, + 0x73, 0x65, 0x64, 0x43, 0x6f, 0x6e, 0x76, 0x3a, 0x31, 0x00, 0x00, 0x00, 0x70, 0xf6, 0xff, 0xff, + 0xe4, 0x08, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x05, 0x00, 0x00, 0x00, 0x34, 0x00, 0x00, 0x00, + 0x28, 0x00, 0x00, 0x00, 0x1c, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0x6a, 0xf6, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01, 0x64, 0xf6, 0xff, 0xff, 0x03, 0x00, 0x00, 0x00, + 0x6c, 0xf6, 0xff, 0xff, 0x02, 0x00, 0x00, 0x00, 0x74, 0xf6, 0xff, 0xff, 0x01, 0x00, 0x00, 0x00, + 0xbc, 0xf6, 0xff, 0xff, 0xb8, 0xf6, 0xff, 0xff, 0x14, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0x02, 0x00, 0x00, 0x00, 0x38, 0x00, 0x00, 0x00, 0x18, 0x00, 0x00, 0x00, 0x0d, 0x00, 0x00, 0x00, + 0x3a, 0x55, 0x6e, 0x73, 0x71, 0x75, 0x65, 0x65, 0x7a, 0x65, 0x3a, 0x32, 0x35, 0x00, 0x00, 0x00, + 0xe4, 0xf6, 0xff, 0xff, 0x98, 0x08, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, 0xc0, 0xf6, 0xff, 0xff, 0x01, 0x00, 0x00, 0x00, 0x00, 0xf7, 0xff, 0xff, + 0x54, 0x08, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, 0xee, 0xf6, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01, 0x28, 0xf7, 0xff, 0xff, + 0x24, 0xf7, 0xff, 0xff, 0x14, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, + 0x38, 0x00, 0x00, 0x00, 0x18, 0x00, 0x00, 0x00, 0x0d, 0x00, 0x00, 0x00, 0x3a, 0x55, 0x6e, 0x73, + 0x71, 0x75, 0x65, 0x65, 0x7a, 0x65, 0x3a, 0x31, 0x33, 0x00, 0x00, 0x00, 0x50, 0xf7, 0xff, 0xff, + 0x2c, 0x08, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0x2c, 0xf7, 0xff, 0xff, 0x01, 0x00, 0x00, 0x00, 0x6c, 0xf7, 0xff, 0xff, 0xe8, 0x07, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0x5a, 0xf7, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01, 0x94, 0xf7, 0xff, 0xff, 0x90, 0xf7, 0xff, 0xff, + 0x18, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, 0x4c, 0x00, 0x00, 0x00, + 0x20, 0x00, 0x00, 0x00, 0x5c, 0x00, 0x00, 0x00, 0x12, 0x00, 0x00, 0x00, 0x3a, 0x51, 0x75, 0x61, + 0x6e, 0x74, 0x69, 0x7a, 0x65, 0x4c, 0x69, 0x6e, 0x65, 0x61, 0x72, 0x3a, 0x31, 0x30, 0x00, 0x00, + 0xc4, 0xf7, 0xff, 0xff, 0x58, 0x08, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, + 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0xb2, 0xf7, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01, + 0xac, 0xf7, 0xff, 0xff, 0x02, 0x00, 0x00, 0x00, 0xec, 0xf7, 0xff, 0xff, 0xf8, 0x07, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x08, 0xf8, 0xff, 0xff, + 0x04, 0xf8, 0xff, 0xff, 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, + 0x10, 0x00, 0x00, 0x00, 0x07, 0x00, 0x00, 0x00, 0x79, 0x5f, 0x73, 0x63, 0x61, 0x6c, 0x65, 0x00, + 0xec, 0xf7, 0xff, 0xff, 0x01, 0x00, 0x00, 0x00, 0x2c, 0xf8, 0xff, 0xff, 0x18, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, 0x60, 0x00, 0x00, 0x00, 0x24, 0x00, 0x00, 0x00, + 0x3c, 0x00, 0x00, 0x00, 0x14, 0x00, 0x00, 0x00, 0x3a, 0x44, 0x65, 0x71, 0x75, 0x61, 0x6e, 0x74, + 0x69, 0x7a, 0x65, 0x4c, 0x69, 0x6e, 0x65, 0x61, 0x72, 0x3a, 0x32, 0x35, 0x00, 0x00, 0x00, 0x00, + 0x64, 0xf8, 0xff, 0xff, 0xb8, 0x07, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, 0x40, 0xf8, 0xff, 0xff, 0x01, 0x00, 0x00, 0x00, 0x80, 0xf8, 0xff, 0xff, + 0x6c, 0x03, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0x6a, 0xf8, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01, 0x9c, 0xf8, 0xff, 0xff, 0x48, 0x07, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, - 0xb6, 0xfb, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01, 0xb0, 0xfb, 0xff, 0xff, 0xec, 0xfb, 0xff, 0xff, - 0x08, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x61, 0x78, 0x65, 0x73, - 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x48, 0xfc, 0xff, 0xff, - 0x01, 0x00, 0x00, 0x00, 0x14, 0xfc, 0xff, 0xff, 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0x7c, 0xf8, 0xff, 0xff, 0x02, 0x00, 0x00, 0x00, 0xc4, 0xf8, 0xff, 0xff, 0xc0, 0xf8, 0xff, 0xff, + 0x18, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, 0x4c, 0x00, 0x00, 0x00, + 0x60, 0x00, 0x00, 0x00, 0x1c, 0x00, 0x00, 0x00, 0x12, 0x00, 0x00, 0x00, 0x3a, 0x51, 0x75, 0x61, + 0x6e, 0x74, 0x69, 0x7a, 0x65, 0x4c, 0x69, 0x6e, 0x65, 0x61, 0x72, 0x3a, 0x32, 0x35, 0x00, 0x00, + 0xf4, 0xf8, 0xff, 0xff, 0xf8, 0x02, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, + 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0xe2, 0xf8, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01, + 0xdc, 0xf8, 0xff, 0xff, 0x02, 0x00, 0x00, 0x00, 0x1c, 0xf9, 0xff, 0xff, 0xc8, 0x06, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x38, 0xf9, 0xff, 0xff, + 0x34, 0xf9, 0xff, 0xff, 0xe8, 0x06, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, 0x10, 0xf9, 0xff, 0xff, 0x01, 0x00, 0x00, 0x00, 0x50, 0xf9, 0xff, 0xff, + 0x14, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x34, 0x00, 0x00, 0x00, + 0x14, 0x00, 0x00, 0x00, 0x0b, 0x00, 0x00, 0x00, 0x3a, 0x53, 0x71, 0x75, 0x65, 0x65, 0x7a, 0x65, + 0x3a, 0x32, 0x34, 0x00, 0x78, 0xf9, 0xff, 0xff, 0x04, 0x06, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x54, 0xf9, 0xff, 0xff, 0x01, 0x00, 0x00, 0x00, + 0x94, 0xf9, 0xff, 0xff, 0xc0, 0x05, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, + 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x82, 0xf9, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01, + 0xbc, 0xf9, 0xff, 0xff, 0xb8, 0xf9, 0xff, 0xff, 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, 0x18, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x00, 0x00, 0x3a, 0x49, 0x64, 0x65, + 0x6e, 0x74, 0x69, 0x74, 0x79, 0x3a, 0x32, 0x35, 0x00, 0x00, 0x00, 0x00, 0xe0, 0xf9, 0xff, 0xff, + 0x60, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, 0xce, 0xf9, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01, 0x08, 0xfa, 0xff, 0xff, + 0x04, 0xfa, 0xff, 0xff, 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, + 0x18, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x00, 0x00, 0x3a, 0x49, 0x64, 0x65, 0x6e, 0x74, 0x69, 0x74, + 0x79, 0x3a, 0x31, 0x36, 0x00, 0x00, 0x00, 0x00, 0x2c, 0xfa, 0xff, 0xff, 0x14, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x18, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, 0x56, 0x00, 0x00, 0x00, 0x22, 0xfa, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01, + 0x5c, 0xfa, 0xff, 0xff, 0x58, 0xfa, 0xff, 0xff, 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x18, 0x00, 0x00, 0x00, 0x0d, 0x00, 0x00, 0x00, 0x3a, 0x54, 0x72, 0x61, - 0x6e, 0x73, 0x70, 0x6f, 0x73, 0x65, 0x3a, 0x32, 0x34, 0x00, 0x00, 0x00, 0x3c, 0xfc, 0xff, 0xff, + 0x6e, 0x73, 0x70, 0x6f, 0x73, 0x65, 0x3a, 0x32, 0x31, 0x00, 0x00, 0x00, 0x80, 0xfa, 0xff, 0xff, + 0xd4, 0x04, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, 0x6e, 0xfa, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01, 0xa8, 0xfa, 0xff, 0xff, + 0xa4, 0xfa, 0xff, 0xff, 0x18, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, + 0x44, 0x00, 0x00, 0x00, 0x64, 0x00, 0x00, 0x00, 0x20, 0x00, 0x00, 0x00, 0x14, 0x00, 0x00, 0x00, + 0x3a, 0x44, 0x65, 0x71, 0x75, 0x61, 0x6e, 0x74, 0x69, 0x7a, 0x65, 0x4c, 0x69, 0x6e, 0x65, 0x61, + 0x72, 0x3a, 0x32, 0x33, 0x00, 0x00, 0x00, 0x00, 0xdc, 0xfa, 0xff, 0xff, 0x10, 0x01, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0xc6, 0xfa, 0xff, 0xff, + 0x00, 0x00, 0x00, 0x01, 0xf8, 0xfa, 0xff, 0xff, 0xec, 0x04, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0x02, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0xd8, 0xfa, 0xff, 0xff, + 0x02, 0x00, 0x00, 0x00, 0x20, 0xfb, 0xff, 0xff, 0x1c, 0xfb, 0xff, 0xff, 0x00, 0x05, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0xf8, 0xfa, 0xff, 0xff, + 0x01, 0x00, 0x00, 0x00, 0x38, 0xfb, 0xff, 0xff, 0x14, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0x02, 0x00, 0x00, 0x00, 0x38, 0x00, 0x00, 0x00, 0x18, 0x00, 0x00, 0x00, 0x0d, 0x00, 0x00, 0x00, + 0x3a, 0x55, 0x6e, 0x73, 0x71, 0x75, 0x65, 0x65, 0x7a, 0x65, 0x3a, 0x32, 0x34, 0x00, 0x00, 0x00, + 0x64, 0xfb, 0xff, 0xff, 0x18, 0x04, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, 0x40, 0xfb, 0xff, 0xff, 0x01, 0x00, 0x00, 0x00, 0x80, 0xfb, 0xff, 0xff, 0xd4, 0x03, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, - 0x04, 0x00, 0x00, 0x00, 0x2a, 0xfc, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01, 0x24, 0xfc, 0xff, 0xff, - 0x60, 0xfc, 0xff, 0xff, 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, - 0x18, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x00, 0x00, 0x3a, 0x49, 0x64, 0x65, 0x6e, 0x74, 0x69, 0x74, - 0x79, 0x3a, 0x32, 0x34, 0x00, 0x00, 0x00, 0x00, 0x88, 0xfc, 0xff, 0xff, 0x08, 0x00, 0x00, 0x00, - 0x0c, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x56, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, - 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x7e, 0xfc, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01, - 0x78, 0xfc, 0xff, 0xff, 0xb4, 0xfc, 0xff, 0xff, 0x14, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, - 0x02, 0x00, 0x00, 0x00, 0x2c, 0x00, 0x00, 0x00, 0x4c, 0x00, 0x00, 0x00, 0x1e, 0x00, 0x00, 0x00, - 0x63, 0x6f, 0x6d, 0x2e, 0x6d, 0x69, 0x63, 0x72, 0x6f, 0x73, 0x6f, 0x66, 0x74, 0x3a, 0x51, 0x75, - 0x61, 0x6e, 0x74, 0x69, 0x7a, 0x65, 0x4c, 0x69, 0x6e, 0x65, 0x61, 0x72, 0x3a, 0x31, 0x00, 0x00, - 0xf0, 0xfc, 0xff, 0xff, 0xf0, 0x01, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, - 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x44, 0xfd, 0xff, 0xff, 0x01, 0x00, 0x00, 0x00, - 0xd8, 0xfc, 0xff, 0xff, 0x14, 0xfd, 0xff, 0xff, 0xa8, 0x01, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, - 0x02, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x02, 0xfd, 0xff, 0xff, - 0x00, 0x00, 0x00, 0x01, 0x70, 0xfd, 0xff, 0xff, 0x02, 0x00, 0x00, 0x00, 0x3c, 0xfd, 0xff, 0xff, - 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x18, 0x00, 0x00, 0x00, - 0x0c, 0x00, 0x00, 0x00, 0x3a, 0x54, 0x72, 0x61, 0x6e, 0x73, 0x70, 0x6f, 0x73, 0x65, 0x3a, 0x31, - 0x00, 0x00, 0x00, 0x00, 0x64, 0xfd, 0xff, 0xff, 0xac, 0x02, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, - 0x02, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x52, 0xfd, 0xff, 0xff, - 0x00, 0x00, 0x00, 0x01, 0x4c, 0xfd, 0xff, 0xff, 0x88, 0xfd, 0xff, 0xff, 0x14, 0x00, 0x00, 0x00, - 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x18, 0x00, 0x00, 0x00, 0x38, 0x00, 0x00, 0x00, - 0x0a, 0x00, 0x00, 0x00, 0x3a, 0x47, 0x61, 0x74, 0x68, 0x65, 0x72, 0x3a, 0x31, 0x31, 0x00, 0x00, - 0xb0, 0xfd, 0xff, 0xff, 0x60, 0x02, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, - 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x9e, 0xfd, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01, - 0x98, 0xfd, 0xff, 0xff, 0xd4, 0xfd, 0xff, 0xff, 0x08, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, - 0x04, 0x00, 0x00, 0x00, 0x54, 0x69, 0x6e, 0x64, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, - 0x04, 0x00, 0x00, 0x00, 0x30, 0xfe, 0xff, 0xff, 0x01, 0x00, 0x00, 0x00, 0xfc, 0xfd, 0xff, 0xff, - 0x14, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x4c, 0x00, 0x00, 0x00, - 0x20, 0x00, 0x00, 0x00, 0x14, 0x00, 0x00, 0x00, 0x3a, 0x44, 0x65, 0x71, 0x75, 0x61, 0x6e, 0x74, - 0x69, 0x7a, 0x65, 0x4c, 0x69, 0x6e, 0x65, 0x61, 0x72, 0x3a, 0x32, 0x31, 0x00, 0x00, 0x00, 0x00, - 0x30, 0xfe, 0xff, 0xff, 0x8c, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, - 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x1e, 0xfe, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01, - 0x8c, 0xfe, 0xff, 0xff, 0x01, 0x00, 0x00, 0x00, 0x58, 0xfe, 0xff, 0xff, 0x88, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, 0x6e, 0xfb, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01, 0xa8, 0xfb, 0xff, 0xff, + 0xa4, 0xfb, 0xff, 0xff, 0x18, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, + 0x70, 0x00, 0x00, 0x00, 0x50, 0x00, 0x00, 0x00, 0x1c, 0x00, 0x00, 0x00, 0x12, 0x00, 0x00, 0x00, + 0x3a, 0x51, 0x75, 0x61, 0x6e, 0x74, 0x69, 0x7a, 0x65, 0x4c, 0x69, 0x6e, 0x65, 0x61, 0x72, 0x3a, + 0x32, 0x33, 0x00, 0x00, 0xd8, 0xfb, 0xff, 0xff, 0x14, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0x02, 0x00, 0x00, 0x00, 0x18, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, + 0x54, 0x33, 0x00, 0x00, 0xce, 0xfb, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01, 0xc8, 0xfb, 0xff, 0xff, + 0x02, 0x00, 0x00, 0x00, 0x08, 0xfc, 0xff, 0xff, 0x14, 0x04, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0xe4, 0xfb, 0xff, 0xff, 0x01, 0x00, 0x00, 0x00, + 0x24, 0xfc, 0xff, 0xff, 0xc0, 0x03, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, 0x40, 0xfc, 0xff, 0xff, 0x3c, 0xfc, 0xff, 0xff, 0x14, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x38, 0x00, 0x00, 0x00, 0x18, 0x00, 0x00, 0x00, + 0x0d, 0x00, 0x00, 0x00, 0x3a, 0x55, 0x6e, 0x73, 0x71, 0x75, 0x65, 0x65, 0x7a, 0x65, 0x3a, 0x32, + 0x33, 0x00, 0x00, 0x00, 0x68, 0xfc, 0xff, 0xff, 0x14, 0x03, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x44, 0xfc, 0xff, 0xff, 0x01, 0x00, 0x00, 0x00, + 0x84, 0xfc, 0xff, 0xff, 0xd0, 0x02, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, + 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x72, 0xfc, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01, + 0xac, 0xfc, 0xff, 0xff, 0xa8, 0xfc, 0xff, 0xff, 0x14, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0x02, 0x00, 0x00, 0x00, 0x4c, 0x00, 0x00, 0x00, 0x20, 0x00, 0x00, 0x00, 0x14, 0x00, 0x00, 0x00, + 0x3a, 0x44, 0x65, 0x71, 0x75, 0x61, 0x6e, 0x74, 0x69, 0x7a, 0x65, 0x4c, 0x69, 0x6e, 0x65, 0x61, + 0x72, 0x3a, 0x32, 0x31, 0x00, 0x00, 0x00, 0x00, 0xdc, 0xfc, 0xff, 0xff, 0x40, 0x03, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, - 0xac, 0xfe, 0xff, 0xff, 0x02, 0x00, 0x00, 0x00, 0x40, 0xfe, 0xff, 0xff, 0x7c, 0xfe, 0xff, 0xff, - 0x18, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, 0x4c, 0x00, 0x00, 0x00, - 0x24, 0x00, 0x00, 0x00, 0x70, 0x00, 0x00, 0x00, 0x14, 0x00, 0x00, 0x00, 0x3a, 0x44, 0x65, 0x71, - 0x75, 0x61, 0x6e, 0x74, 0x69, 0x7a, 0x65, 0x4c, 0x69, 0x6e, 0x65, 0x61, 0x72, 0x3a, 0x32, 0x34, - 0x00, 0x00, 0x00, 0x00, 0xb4, 0xfe, 0xff, 0xff, 0x08, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x00, 0x00, - 0x02, 0x00, 0x00, 0x00, 0x54, 0x32, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, - 0x0c, 0xff, 0xff, 0xff, 0x01, 0x00, 0x00, 0x00, 0xd8, 0xfe, 0xff, 0xff, 0x08, 0x00, 0x00, 0x00, - 0x0c, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x54, 0x31, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, - 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x34, 0xff, 0xff, 0xff, 0x02, 0x00, 0x00, 0x00, - 0xc8, 0xfe, 0xff, 0xff, 0x04, 0xff, 0xff, 0xff, 0x08, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x00, 0x00, - 0x02, 0x00, 0x00, 0x00, 0x54, 0x33, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, - 0xf6, 0xfe, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01, 0x28, 0xff, 0xff, 0xff, 0x18, 0x00, 0x00, 0x00, - 0x04, 0x00, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, 0x74, 0x00, 0x00, 0x00, 0x24, 0x00, 0x00, 0x00, - 0x48, 0x00, 0x00, 0x00, 0x14, 0x00, 0x00, 0x00, 0x3a, 0x44, 0x65, 0x71, 0x75, 0x61, 0x6e, 0x74, - 0x69, 0x7a, 0x65, 0x4c, 0x69, 0x6e, 0x65, 0x61, 0x72, 0x3a, 0x31, 0x30, 0x00, 0x00, 0x00, 0x00, - 0x60, 0xff, 0xff, 0xff, 0x08, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x07, 0x00, 0x00, 0x00, - 0x78, 0x5f, 0x73, 0x63, 0x61, 0x6c, 0x65, 0x00, 0x01, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, - 0xbc, 0xff, 0xff, 0xff, 0x01, 0x00, 0x00, 0x00, 0x88, 0xff, 0xff, 0xff, 0x08, 0x00, 0x00, 0x00, - 0x0c, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x79, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, - 0x04, 0x00, 0x00, 0x00, 0x7a, 0xff, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01, 0xac, 0xff, 0xff, 0xff, - 0x64, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x18, 0x00, 0x00, 0x00, - 0x0c, 0x00, 0x00, 0x00, 0x08, 0x00, 0x08, 0x00, 0x00, 0x00, 0x04, 0x00, 0x08, 0x00, 0x00, 0x00, - 0x02, 0x00, 0x00, 0x00, 0x9c, 0xff, 0xff, 0xff, 0xd8, 0xff, 0xff, 0xff, 0x10, 0x00, 0x00, 0x00, - 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x20, 0x00, 0x00, 0x00, 0x0d, 0x00, 0x00, 0x00, - 0x3a, 0x55, 0x6e, 0x73, 0x71, 0x75, 0x65, 0x65, 0x7a, 0x65, 0x3a, 0x31, 0x31, 0x00, 0x00, 0x00, - 0x08, 0x00, 0x0c, 0x00, 0x04, 0x00, 0x08, 0x00, 0x08, 0x00, 0x00, 0x00, 0x08, 0x00, 0x00, 0x00, - 0x0c, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x54, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, - 0x1c, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x00, 0x00, 0x00, 0x00, 0x06, 0x00, 0x08, 0x00, 0x07, 0x00, - 0x06, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x04, 0x00, 0x04, 0x00, 0x04, 0x00, 0x00, 0x00, + 0xca, 0xfc, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01, 0xc4, 0xfc, 0xff, 0xff, 0x01, 0x00, 0x00, 0x00, + 0x04, 0xfd, 0xff, 0xff, 0xe0, 0x02, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, + 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0xe4, 0xfc, 0xff, 0xff, 0x02, 0x00, 0x00, 0x00, + 0x2c, 0xfd, 0xff, 0xff, 0x28, 0xfd, 0xff, 0xff, 0x14, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0x02, 0x00, 0x00, 0x00, 0x40, 0x00, 0x00, 0x00, 0x14, 0x00, 0x00, 0x00, 0x09, 0x00, 0x00, 0x00, + 0x3a, 0x47, 0x61, 0x74, 0x68, 0x65, 0x72, 0x3a, 0x31, 0x00, 0x00, 0x00, 0x50, 0xfd, 0xff, 0xff, + 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, 0x54, 0x69, 0x6e, 0x64, 0x00, 0x00, 0x00, 0x00, 0x38, 0xfd, 0xff, 0xff, + 0x01, 0x00, 0x00, 0x00, 0x78, 0xfd, 0xff, 0xff, 0xdc, 0x01, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0x02, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x66, 0xfd, 0xff, 0xff, + 0x00, 0x00, 0x00, 0x01, 0xa0, 0xfd, 0xff, 0xff, 0x9c, 0xfd, 0xff, 0xff, 0x10, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x18, 0x00, 0x00, 0x00, 0x0d, 0x00, 0x00, 0x00, + 0x3a, 0x54, 0x72, 0x61, 0x6e, 0x73, 0x70, 0x6f, 0x73, 0x65, 0x3a, 0x32, 0x34, 0x00, 0x00, 0x00, + 0xc4, 0xfd, 0xff, 0xff, 0x90, 0x01, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, + 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0xb2, 0xfd, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01, + 0xec, 0xfd, 0xff, 0xff, 0xe8, 0xfd, 0xff, 0xff, 0x14, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0x02, 0x00, 0x00, 0x00, 0x18, 0x00, 0x00, 0x00, 0x38, 0x00, 0x00, 0x00, 0x0b, 0x00, 0x00, 0x00, + 0x3a, 0x53, 0x71, 0x75, 0x65, 0x65, 0x7a, 0x65, 0x3a, 0x31, 0x33, 0x00, 0x10, 0xfe, 0xff, 0xff, + 0x44, 0x01, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, 0xfe, 0xfd, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01, 0x38, 0xfe, 0xff, 0xff, + 0x34, 0xfe, 0xff, 0xff, 0x48, 0x01, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, 0x10, 0xfe, 0xff, 0xff, 0x01, 0x00, 0x00, 0x00, 0x50, 0xfe, 0xff, 0xff, + 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x14, 0x00, 0x00, 0x00, + 0x0b, 0x00, 0x00, 0x00, 0x3a, 0x49, 0x64, 0x65, 0x6e, 0x74, 0x69, 0x74, 0x79, 0x3a, 0x31, 0x00, + 0x74, 0xfe, 0xff, 0xff, 0xe0, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, + 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x62, 0xfe, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01, + 0x9c, 0xfe, 0xff, 0xff, 0x98, 0xfe, 0xff, 0xff, 0x14, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0x02, 0x00, 0x00, 0x00, 0x4c, 0x00, 0x00, 0x00, 0x20, 0x00, 0x00, 0x00, 0x14, 0x00, 0x00, 0x00, + 0x3a, 0x44, 0x65, 0x71, 0x75, 0x61, 0x6e, 0x74, 0x69, 0x7a, 0x65, 0x4c, 0x69, 0x6e, 0x65, 0x61, + 0x72, 0x3a, 0x31, 0x39, 0x00, 0x00, 0x00, 0x00, 0xcc, 0xfe, 0xff, 0xff, 0x50, 0x01, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0xba, 0xfe, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01, 0xb4, 0xfe, 0xff, 0xff, 0x01, 0x00, 0x00, 0x00, + 0xf4, 0xfe, 0xff, 0xff, 0xf0, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, + 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0xd4, 0xfe, 0xff, 0xff, 0x02, 0x00, 0x00, 0x00, + 0x1c, 0xff, 0xff, 0xff, 0x18, 0xff, 0xff, 0xff, 0x14, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0x02, 0x00, 0x00, 0x00, 0x18, 0x00, 0x00, 0x00, 0x40, 0x00, 0x00, 0x00, 0x0b, 0x00, 0x00, 0x00, + 0x3a, 0x53, 0x71, 0x75, 0x65, 0x65, 0x7a, 0x65, 0x3a, 0x32, 0x33, 0x00, 0x40, 0xff, 0xff, 0xff, + 0x14, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x18, 0x00, 0x00, 0x00, + 0x0c, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x54, 0x00, 0x00, 0x00, 0x36, 0xff, 0xff, 0xff, + 0x00, 0x00, 0x00, 0x01, 0x70, 0xff, 0xff, 0xff, 0x6c, 0xff, 0xff, 0xff, 0x10, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0x61, 0x78, 0x65, 0x73, 0x00, 0x00, 0x00, 0x00, 0x54, 0xff, 0xff, 0xff, 0x01, 0x00, 0x00, 0x00, + 0x94, 0xff, 0xff, 0xff, 0x14, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, + 0x2c, 0x00, 0x00, 0x00, 0x60, 0x00, 0x00, 0x00, 0x1e, 0x00, 0x00, 0x00, 0x63, 0x6f, 0x6d, 0x2e, + 0x6d, 0x69, 0x63, 0x72, 0x6f, 0x73, 0x6f, 0x66, 0x74, 0x3a, 0x51, 0x75, 0x61, 0x6e, 0x74, 0x69, + 0x7a, 0x65, 0x4c, 0x69, 0x6e, 0x65, 0x61, 0x72, 0x3a, 0x31, 0x00, 0x00, 0xd0, 0xff, 0xff, 0xff, + 0x14, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x1c, 0x00, 0x00, 0x00, + 0x0c, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x54, 0x31, 0x00, 0x00, 0xb8, 0xff, 0xff, 0xff, + 0x01, 0x00, 0x00, 0x00, 0x04, 0x00, 0x04, 0x00, 0x04, 0x00, 0x00, 0x00, 0x08, 0x00, 0x0c, 0x00, + 0x04, 0x00, 0x08, 0x00, 0x08, 0x00, 0x00, 0x00, 0x14, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0x02, 0x00, 0x00, 0x00, 0x28, 0x00, 0x00, 0x00, 0x14, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, + 0x54, 0x32, 0x00, 0x00, 0x00, 0x00, 0x06, 0x00, 0x08, 0x00, 0x07, 0x00, 0x06, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x01, 0x08, 0x00, 0x08, 0x00, 0x00, 0x00, 0x04, 0x00, 0x08, 0x00, 0x00, 0x00, + 0x02, 0x00, 0x00, 0x00, }; // clang-format on diff --git a/onnxruntime/core/framework/kernel_type_str_resolver_utils.h b/onnxruntime/core/framework/kernel_type_str_resolver_utils.h index 5daab7c1159be..488970890a1a0 100644 --- a/onnxruntime/core/framework/kernel_type_str_resolver_utils.h +++ b/onnxruntime/core/framework/kernel_type_str_resolver_utils.h @@ -8,8 +8,6 @@ #include #include "core/common/status.h" #include "core/framework/kernel_type_str_resolver.h" -#include "core/graph/op_identifier.h" - namespace flatbuffers { class DetachedBuffer; } @@ -18,11 +16,6 @@ namespace onnxruntime::kernel_type_str_resolver_utils { #if !defined(ORT_MINIMAL_BUILD) -/** - * Gets the ops that the layout transformation may potentially add. - */ -gsl::span GetLayoutTransformationRequiredOpIdentifiers(); - /** * Saves `kernel_type_str_resolver` to a byte buffer owned by `buffer` and referenced by `buffer_span`. */ diff --git a/onnxruntime/core/graph/contrib_ops/nhwc_schema_defs.cc b/onnxruntime/core/graph/contrib_ops/nhwc_schema_defs.cc index 8fe3a4d5f3b6f..cb015a3a3c500 100644 --- a/onnxruntime/core/graph/contrib_ops/nhwc_schema_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/nhwc_schema_defs.cc @@ -388,7 +388,6 @@ ONNX_MS_OPERATOR_SET_SCHEMA(NhwcFusedConv, 1, OpSchema() .SetDoc(R"DOC( NhwcFusedConv is a Conv operator with optional activation and add operators fused in. -Only has fp16 implementation as of 2023/04/15. )DOC") .Attr("auto_pad", "", AttributeProto::STRING, std::string("NOTSET")) .Attr("kernel_shape", "", AttributeProto::INTS, OPTIONAL_VALUE) @@ -398,12 +397,27 @@ Only has fp16 implementation as of 2023/04/15. .Attr("group", "", AttributeProto::INT, static_cast(1)) .Attr("activation", "", AttributeProto::STRING, OPTIONAL_VALUE) .Attr("activation_params", "", AttributeProto::FLOATS, OPTIONAL_VALUE) - .Input(0, "X", "", "T") - .Input(1, "W", "", "T") - .Input(2, "B", "", "T", OpSchema::Optional) - .Input(3, "Z", "Tensor to be added to the output, must be the same shape and format as the output tensor.", "T", OpSchema::Optional) - .Output(0, "Y", "", "T") - .TypeConstraint("T", {"tensor(float16)"}, "Constrain input and output types to float tensors") + .Input(0, "X", + "Input activation tensor in channels-last layout. For 2D convolution this is " + "[N, H, W, C], where N is batch size, H/W are spatial dimensions, and C is " + "the number of input channels.", + "T") + .Input(1, "W", + "Convolution weight tensor in the standard ONNX Conv filter layout " + "[M, C/group, kH, kW], where M is the number of output channels.", + "T") + .Input(2, "B", + "Optional 1D bias tensor of shape [M].", + "T", OpSchema::Optional) + .Input(3, "Z", + "Optional residual/add tensor in the same channels-last layout and shape as " + "the output tensor Y. For 2D convolution this is [N, out_H, out_W, M].", + "T", OpSchema::Optional) + .Output(0, "Y", + "Output tensor in channels-last layout. For 2D convolution this is " + "[N, out_H, out_W, M], where M is the number of output channels.", + "T") + .TypeConstraint("T", {"tensor(float16)", "tensor(float)"}, "Constrain input and output types to float tensors") .TypeAndShapeInferenceFunction([](InferenceContext& ctx) { ONNX_NAMESPACE::propagateElemTypeFromInputToOutput(ctx, 0, 0); convPoolShapeInferenceNhwc(ctx, true, false, 0, 1); diff --git a/onnxruntime/core/mlas/inc/mlas.h b/onnxruntime/core/mlas/inc/mlas.h index 218aaef0c8f4b..ddb9daa5e244b 100644 --- a/onnxruntime/core/mlas/inc/mlas.h +++ b/onnxruntime/core/mlas/inc/mlas.h @@ -892,6 +892,7 @@ struct MLAS_CONV_PARAMETERS { size_t BatchCount; size_t GroupCount; size_t InputChannels; + bool ChannelsLast; size_t InputShape[3]; size_t KernelShape[3]; size_t DilationShape[3]; @@ -906,6 +907,9 @@ struct MLAS_CONV_PARAMETERS { MLAS_CONV_ALGORITHM Algorithm; ptrdiff_t ThreadCount; const MLAS_BACKEND_KERNEL_SELECTOR_CONFIG* BackendKernelSelectorConfig = nullptr; + const void* PackedFilter = nullptr; + size_t PackedFilterGroupStride = 0; + bool FilterIsPacked = false; union { struct { CBLAS_TRANSPOSE TransB; @@ -932,9 +936,24 @@ MlasConvPrepare(MLAS_CONV_PARAMETERS* Parameters, size_t FilterCount, const MLAS_ACTIVATION* Activation, size_t* WorkingBufferSize, + bool ChannelsLast, float Beta, MLAS_THREADPOOL* ThreadPool); +bool +MLASCALL +MlasConvSupportsSymmetricChannelsLast2DFloatKernel( + size_t Dimensions, + size_t BatchCount, + size_t GroupCount, + const size_t* InputShape, + const size_t* KernelShape, + const size_t* DilationShape, + const size_t* Padding, + const size_t* StrideShape, + size_t FilterCount, + float Beta); + void MLASCALL MlasConv( diff --git a/onnxruntime/core/mlas/lib/convolve.cpp b/onnxruntime/core/mlas/lib/convolve.cpp index ba2d98f46e960..696314f20267e 100644 --- a/onnxruntime/core/mlas/lib/convolve.cpp +++ b/onnxruntime/core/mlas/lib/convolve.cpp @@ -1324,6 +1324,85 @@ Return Value: // Chance of arithmetic overflow could be reduced #pragma warning(disable : 26451) #endif + +namespace { + +#if defined(USE_KLEIDIAI) && defined(MLAS_TARGET_ARM64) +static constexpr size_t ComputeChannelsLastDilatedKernelSize(size_t dilation, size_t kernel) { + return (dilation * kernel) - (dilation - 1); +} + +static constexpr size_t ComputeChannelsLastConvOutSize(size_t input, size_t kernel, size_t padding, size_t stride) { + if (stride > 0 && (input + 2 * padding) >= kernel) { + return (((input - kernel) + (2 * padding)) / stride) + 1; + } + + return 0; +} +#endif + +} // namespace + +bool +MLASCALL +MlasConvSupportsSymmetricChannelsLast2DFloatKernel( + size_t Dimensions, + size_t BatchCount, + size_t GroupCount, + const size_t* InputShape, + const size_t* KernelShape, + const size_t* DilationShape, + const size_t* Padding, + const size_t* StrideShape, + size_t FilterCount, + float Beta) +{ +#if !defined(USE_KLEIDIAI) || !defined(MLAS_TARGET_ARM64) + MLAS_UNREFERENCED_PARAMETER(Dimensions); + MLAS_UNREFERENCED_PARAMETER(BatchCount); + MLAS_UNREFERENCED_PARAMETER(GroupCount); + MLAS_UNREFERENCED_PARAMETER(InputShape); + MLAS_UNREFERENCED_PARAMETER(KernelShape); + MLAS_UNREFERENCED_PARAMETER(DilationShape); + MLAS_UNREFERENCED_PARAMETER(Padding); + MLAS_UNREFERENCED_PARAMETER(StrideShape); + MLAS_UNREFERENCED_PARAMETER(FilterCount); + MLAS_UNREFERENCED_PARAMETER(Beta); + return false; +#else + // Channels-last float convolution is only implemented by the KleidiAI + // override. The generic MLAS convolution path assumes NCHW layout. + if (GetMlasPlatform().MlasConvPrepareOverride == nullptr || + GetMlasPlatform().MlasConvOverride == nullptr) { + return false; + } + + if (Dimensions != 2 || BatchCount != 1 || GroupCount != 1 || Beta != 0.0f) { + return false; + } + + if (Padding[0] != Padding[2] || Padding[1] != Padding[3]) { + return false; + } + + const size_t output_h = + ComputeChannelsLastConvOutSize(InputShape[0], ComputeChannelsLastDilatedKernelSize(DilationShape[0], KernelShape[0]), + Padding[0], StrideShape[0]); + const size_t output_w = + ComputeChannelsLastConvOutSize(InputShape[1], ComputeChannelsLastDilatedKernelSize(DilationShape[1], KernelShape[1]), + Padding[1], StrideShape[1]); + if (output_h == 0 || output_w == 0) { + return false; + } + + if (FilterCount <= 1 || KernelShape[0] < 3 || KernelShape[1] < 3) { + return false; + } + + return true; +#endif +} + void MLASCALL MlasConvPrepare( @@ -1341,6 +1420,7 @@ MlasConvPrepare( size_t FilterCount, const MLAS_ACTIVATION* Activation, size_t* WorkingBufferSize, + bool ChannelsLast, float Beta, MLAS_THREADPOOL* ThreadPool ) @@ -1399,7 +1479,7 @@ Return Value: if (GetMlasPlatform().MlasConvPrepareOverride != nullptr && GetMlasPlatform().MlasConvPrepareOverride(Parameters, Dimensions, BatchCount, GroupCount, InputChannels, InputShape,KernelShape,DilationShape, Padding, StrideShape, OutputShape, FilterCount, - Activation, WorkingBufferSize, Beta, ThreadPool)){ + Activation, WorkingBufferSize, ChannelsLast, Beta, ThreadPool)){ return; } // @@ -1410,6 +1490,7 @@ Return Value: Parameters->BatchCount = BatchCount; Parameters->GroupCount = GroupCount; Parameters->InputChannels = InputChannels; + Parameters->ChannelsLast = ChannelsLast; Parameters->FilterCount = FilterCount; Parameters->Beta = Beta; diff --git a/onnxruntime/core/mlas/lib/kleidiai/convolve_kleidiai.cpp b/onnxruntime/core/mlas/lib/kleidiai/convolve_kleidiai.cpp index a37dcb37a17ac..cca4f5a19c417 100644 --- a/onnxruntime/core/mlas/lib/kleidiai/convolve_kleidiai.cpp +++ b/onnxruntime/core/mlas/lib/kleidiai/convolve_kleidiai.cpp @@ -24,20 +24,6 @@ const KaiF32IMatmulKernel& imatmul_conv = GetKleidiAIF32IMatmulUKernel(); -// Right-hand-side (weights) cache key -struct RhsCacheKey { - size_t co, ci, kh, kw, dilationh, dilationw; - size_t weights_hash; - - bool operator==(const RhsCacheKey& other) const { - return co == other.co && ci == other.ci && - kh == other.kh && kw == other.kw && - dilationh == other.dilationh && dilationw == other.dilationw && - weights_hash == other.weights_hash; - } -}; - - // Left-hand-side (input indirection) cache key struct LhsCacheKey { size_t ci, ih, iw; @@ -53,90 +39,29 @@ struct LhsCacheKey { } }; -// Derived from 2^32 * (sqrt(5) - 1) / 2 ≈ 0.6180339887 (reciprocal of the golden ratio) -// Based on Knuth's multiplicative hashing method -constexpr size_t HASH_GOLDEN_RATIO_CONST = 0x9e3779b9; - -size_t HashWeights(const float* data, size_t count = 16) { - size_t h = 0; - for (size_t i = 0; i < count; ++i) { - h ^= std::hash()(data[i]) + HASH_GOLDEN_RATIO_CONST + (h << 6) + (h >> 2); - } - return h; -} - namespace std { // Specialize hash type for cache keys and do it within namespace std. // Doing this allows standard containers like std::unordered_map to find // the appropriate hash function via template specialization, as ADL // (argument-dependent lookup) does not apply to std::hash. - template<> - struct hash { - size_t operator()(const RhsCacheKey& k) const { - return k.weights_hash ^ - (std::hash()(k.co) << 1) ^ - (std::hash()(k.ci) << 2) ^ - (std::hash()(k.kh) << 3) ^ - (std::hash()(k.kw) << 4) ^ - (std::hash()(k.dilationh) << 5) ^ - (std::hash()(k.dilationw) << 6); - } - }; - template<> struct hash { size_t operator()(const LhsCacheKey& k) const { - return (std::hash()(k.ci) << 1) ^ - (std::hash()(k.ih) << 2) ^ - (std::hash()(k.iw) << 3) ^ - (std::hash()(k.padding) << 4) ^ - (std::hash()(k.sh) << 5) ^ - (std::hash()(k.sw) << 6) ^ - (std::hash()(k.kh) << 7) ^ - (std::hash()(k.kw) << 8) ^ - (std::hash()(k.dilationh) << 9) ^ - (std::hash()(k.dilationw) << 10); + return std::hash()(k.ci) ^ + (std::hash()(k.ih) << 1) ^ + (std::hash()(k.iw) << 2) ^ + (std::hash()(k.padding) << 3) ^ + (std::hash()(k.sh) << 4) ^ + (std::hash()(k.sw) << 5) ^ + (std::hash()(k.kh) << 6) ^ + (std::hash()(k.kw) << 7) ^ + (std::hash()(k.dilationh) << 8) ^ + (std::hash()(k.dilationw) << 9); } }; } -namespace { - -using LhsPtrsCache = std::unordered_map>; - -thread_local std::unordered_map lhs_ptrs_cache_by_pad; -thread_local const float* last_pad_ptr = nullptr; - -size_t LhsPtrsCacheEntryCount() { - size_t count = 0; - for (const auto& cache_group : lhs_ptrs_cache_by_pad) { - count += cache_group.second.size(); - } - return count; -} - -void ClearLhsPtrsCache() { - lhs_ptrs_cache_by_pad.clear(); - last_pad_ptr = nullptr; -} - -} // namespace - -#if defined(MLAS_ENABLE_TEST_HOOKS) -size_t -MLASCALL -ArmKleidiAI::MlasConvLhsCacheEntryCountForTest() { - return LhsPtrsCacheEntryCount(); -} - -void -MLASCALL -ArmKleidiAI::MlasConvClearLhsCacheForTest() { - ClearLhsPtrsCache(); -} -#endif - static constexpr size_t ComputeKernelSize(const size_t D, const size_t K) { // D - dilation size @@ -184,29 +109,21 @@ static size_t ComputeMlasWorkingBufferSize(const size_t co, } static bool CheckCapabilitiesSme(const MLAS_CONV_PARAMETERS* Parameters) { - - //functional checks - logically can the conv be performed - if ((Parameters->Dimensions != 2) || - (Parameters->BatchCount != 1) || - (Parameters->Beta != 0.f) || - (Parameters->Padding[0] != Parameters->Padding[1]) || - (Parameters->Padding[0] != Parameters->Padding[2]) || - (Parameters->Padding[0] != Parameters->Padding[3]) || - (ComputeConvOutSize(Parameters->InputShape[0], - ComputeKernelSize(Parameters->DilationShape[0],Parameters->KernelShape[0]), - Parameters->Padding[0], Parameters->StrideShape[0]) * - ComputeConvOutSize(Parameters->InputShape[1], - ComputeKernelSize(Parameters->DilationShape[1],Parameters->KernelShape[1]), - Parameters->Padding[1], Parameters->StrideShape[1]) == 0)) { - KLEIDIAI_DEBUG_LOG("CheckCapabilitiesSme returning false on functional checks."); + if (!MlasConvSupportsSymmetricChannelsLast2DFloatKernel( + Parameters->Dimensions, + Parameters->BatchCount, + Parameters->GroupCount, + Parameters->InputShape, + Parameters->KernelShape, + Parameters->DilationShape, + Parameters->Padding, + Parameters->StrideShape, + Parameters->FilterCount, + Parameters->Beta)) { + KLEIDIAI_DEBUG_LOG("CheckCapabilitiesSme returning false on shared capability checks."); return false; } - auto N = Parameters->FilterCount; - if (N == 1 || Parameters->KernelShape[0] < 3 || Parameters->KernelShape[1] < 3) { - KLEIDIAI_DEBUG_LOG("CheckCapabilitiesSme returning false on optimization checks."); - return false; - } return true; } @@ -365,52 +282,65 @@ static void MultiThreadedLHSPackSme(MLAS_THREADPOOL* ThreadPool, const size_t ci }); } -static std::shared_ptr RhsPackWeightsBiasSme(const size_t co, const size_t ci, - const size_t kh, const size_t kw, - const size_t dilationh, const size_t dilationw, - const float* weights, const float* bias, - MLAS_THREADPOOL* ThreadPool) -{ - // Cache of prepacked kai rhs weights and biases. thread_local to prevent interference from parallel sessions. - thread_local std::unordered_map> rhs_cache; +size_t +MLASCALL +ArmKleidiAI::MlasConvSymmetricChannelsLast2DFloatPackWSize( + size_t FilterCount, + size_t InputChannels, + const int64_t* KernelShape, + const int64_t* DilationShape) { + const auto d_kh = ComputeKernelSize(static_cast(DilationShape[0]), static_cast(KernelShape[0])); + const auto d_kw = ComputeKernelSize(static_cast(DilationShape[1]), static_cast(KernelShape[1])); + return kai_get_rhs_packed_size_rhs_imatmul_pack_kxn_x32p2vlx1b_x32_x32_sme(FilterCount, d_kh * d_kw, + InputChannels); +} + +void +MLASCALL +ArmKleidiAI::MlasConvSymmetricChannelsLast2DFloatPackW( + size_t FilterCount, + size_t InputChannels, + const int64_t* KernelShape, + const int64_t* DilationShape, + size_t GroupCount, + const float* Filter, + const float* Bias, + void* PackedFilter, + size_t PackedFilterGroupStride, + MLAS_THREADPOOL* ThreadPool) { + const size_t kh = static_cast(KernelShape[0]); + const size_t kw = static_cast(KernelShape[1]); + const size_t dilationh = static_cast(DilationShape[0]); + const size_t dilationw = static_cast(DilationShape[1]); + const auto d_kh = ComputeKernelSize(dilationh, kh); + const auto d_kw = ComputeKernelSize(dilationw, kw); - RhsCacheKey key = { co, ci, kh, kw, dilationh, dilationw, HashWeights(weights) }; + for (size_t group_idx = 0; group_idx < GroupCount; ++group_idx) { + const float* weights = Filter + group_idx * FilterCount * InputChannels * kh * kw; + const float* bias = Bias ? Bias + group_idx * FilterCount : nullptr; + auto* packed_group = reinterpret_cast(PackedFilter) + group_idx * PackedFilterGroupStride; - auto found = rhs_cache.find(key); - if (found != rhs_cache.end()) { - return found->second; - } else { // prepare mlas filter weights for kai rhs packing // dilated nhwc format - auto nhwc = NChwToNhwc(co, ci, kh, kw, weights, dilationh, dilationw, true, ThreadPool); - - - //dilation, axis swap (n x k -> k x n) where n == co, k == d_kh x d_kw x ci - const auto d_kh = ComputeKernelSize(dilationh,kh); - const auto d_kw = ComputeKernelSize(dilationw,kw); + auto nhwc = NChwToNhwc(FilterCount, InputChannels, kh, kw, weights, dilationh, dilationw, true, ThreadPool); //t_weights[d_kh][d_kw][ci][co] = nhwc[co][d_kh][d_kw][ci] - auto t_weights = Transpose4D({co,d_kh,d_kw,ci},&nhwc[0],{1,2,3,0}); - - const auto packed_size = kai_get_rhs_packed_size_rhs_imatmul_pack_kxn_x32p2vlx1b_x32_x32_sme(co,d_kh*d_kw,ci); - auto packed = std::shared_ptr(new std::byte[packed_size], std::default_delete()); - - rhs_cache[key] = packed; + auto t_weights = Transpose4D({FilterCount, d_kh, d_kw, InputChannels}, &nhwc[0], {1, 2, 3, 0}); std::vector bias_copy; if (bias) { - bias_copy.assign(bias, bias + co); + bias_copy.assign(bias, bias + FilterCount); } else { - bias_copy.resize(co, 0.0f); + bias_copy.resize(FilterCount, 0.0f); } KLEIDIAI_KERNEL_LOG("kai_run_rhs_imatmul_pack_kxn_x32p2vlx1b_x32_x32_sme" - << " N=" << co << " k_chunk_count=" << (d_kh*d_kw) << " k_chunk_length=" << ci << " rhs_stride_row=" << (co * sizeof(float))); + << " N=" << FilterCount << " k_chunk_count=" << (d_kh * d_kw) + << " k_chunk_length=" << InputChannels + << " rhs_stride_row=" << (FilterCount * sizeof(float))); kai_run_rhs_imatmul_pack_kxn_x32p2vlx1b_x32_x32_sme( - co, d_kh*d_kw, ci, co * sizeof(float), &t_weights[0], bias_copy.data(), packed.get() - ); - - return packed; + FilterCount, d_kh * d_kw, InputChannels, FilterCount * sizeof(float), &t_weights[0], bias_copy.data(), + packed_group); } } @@ -491,6 +421,7 @@ static std::shared_ptr LhsPtrFill(const size_t ci, const size_t i static std::unique_ptr LhsPackImageDataSme(const size_t ci, const size_t ih, const size_t iw, const size_t kh, const size_t kw, const size_t sh, const size_t sw, const size_t padding, const float* in, + bool input_is_channels_last, MLAS_THREADPOOL* ThreadPool) { size_t padsize = 256; @@ -517,21 +448,31 @@ static std::unique_ptr LhsPackImageDataSme(const size_t ci, const s const auto lhs_size = kai_get_lhs_packed_size_lhs_imatmul_pack_x32p2vlx1_x32p_sme(m,kh*kw,ci); auto lhs = std::make_unique(lhs_size); - auto nhwc = NChwToNhwc(1, ci, ih, iw, in, 1, 1, false, ThreadPool); + std::unique_ptr nhwc_holder; + const float* activation_src = nullptr; + if (input_is_channels_last) { + activation_src = in; + } else { + nhwc_holder = NChwToNhwc(1, ci, ih, iw, in, 1, 1, false, ThreadPool); + activation_src = nhwc_holder.get(); + } // Cache of computed lhs ptr offsets. thread_local to prevent interference from parallel sessions. // // Entries include pointers to the pad buffer for out-of-bounds pixels, so we must not reuse entries after the // pad buffer is reallocated. To avoid clearing the entire cache, we group caches by pad buffer identity and // invalidate only the old group when the pad buffer moves. + using LhsPtrsCache = std::unordered_map>; + thread_local std::unordered_map lhs_ptrs_cache_by_pad; + // If pad_ptr moved (vector reallocation), drop only the old group to avoid accumulating unreachable entries. + thread_local const float* last_pad_ptr = nullptr; const float* cur_pad_ptr = pad_ptr.data(); if (last_pad_ptr != nullptr && last_pad_ptr != cur_pad_ptr) { lhs_ptrs_cache_by_pad.erase(last_pad_ptr); } last_pad_ptr = cur_pad_ptr; - // LhsPtrFill stores geometry offsets only; the current input base is supplied when packing. LhsCacheKey key = { ci, ih, iw, padding, sh, sw, @@ -549,7 +490,7 @@ static std::unique_ptr LhsPackImageDataSme(const size_t ci, const s lhs_ptrs_cache[key] = lhs_ptrs; } - MultiThreadedLHSPackSme(ThreadPool, ci, m, kh, kw, &lhs_ptrs[0], &lhs[0], &nhwc[0], &pad_ptr[0]); + MultiThreadedLHSPackSme(ThreadPool, ci, m, kh, kw, &lhs_ptrs[0], &lhs[0], activation_src, &pad_ptr[0]); return lhs; } @@ -568,9 +509,12 @@ static void ConvolveSme(const size_t co, //channels out const size_t groups, //number of filter groups const float* weights, //kernel weights [co,ci,ih,iw] const float* bias, //kernel biases + const std::byte* packed_rhs, + const size_t packed_rhs_group_stride, const float* in, //in image data float* out, //out image data float* tmp_mlas_aligned, //intermediate buffer if we need to perform a transpose + bool input_is_channels_last, MLAS_THREADPOOL* ThreadPool) { //RhsPackWeightsBiasSme() - to perform dilation increases kernel size and masks unused weights @@ -608,18 +552,27 @@ static void ConvolveSme(const size_t co, //channels out for (size_t g = 0; g < groups; ++g) { - auto result{out}; - //do we require a post matmul transpose ? - //output is m x n or image_data x co or hw x co - //MLAS require it as n x m (or co x hw), transpose required - if (co > 1) { - //intermediate buffer required, pre-transpose - //Note: because we are calling MlasTranspose() need to ensure we use a MLAS aligned buffer + auto result = out; + const bool need_transpose = (!input_is_channels_last) && (co > 1); + if (need_transpose) { result = tmp_mlas_aligned; } - auto lhs = LhsPackImageDataSme(ci, ih, iw, d_kh, d_kw, sh, sw, padding, in, ThreadPool); - auto rhs = RhsPackWeightsBiasSme(co, ci, kh, kw, dilationh, dilationw, weights, bias, ThreadPool); + auto lhs = LhsPackImageDataSme(ci, ih, iw, d_kh, d_kw, sh, sw, padding, in, input_is_channels_last, ThreadPool); + const std::byte* rhs_data = packed_rhs ? packed_rhs + g * packed_rhs_group_stride : nullptr; + std::unique_ptr rhs_storage; + if (rhs_data == nullptr) { + const std::array kernel_shape{static_cast(kh), static_cast(kw)}; + const std::array dilation_shape{static_cast(dilationh), static_cast(dilationw)}; + const size_t packed_size = + ArmKleidiAI::MlasConvSymmetricChannelsLast2DFloatPackWSize(co, ci, kernel_shape.data(), + dilation_shape.data()); + rhs_storage = std::make_unique(packed_size); + ArmKleidiAI::MlasConvSymmetricChannelsLast2DFloatPackW(co, ci, kernel_shape.data(), dilation_shape.data(), 1, + weights, bias, rhs_storage.get(), packed_size, + ThreadPool); + rhs_data = rhs_storage.get(); + } MlasTrySimpleParallel(ThreadPool, static_cast(dim[0] * dim[1] * dim[2]), [&](ptrdiff_t tid) { //compute B,M,N index from iteration index @@ -632,7 +585,7 @@ static void ConvolveSme(const size_t co, //channels out imatmul_conv.ukernel.get_rhs_packed_offset(NIdx * n_step, d_kh * d_kw, ci); auto BTile = reinterpret_cast( - reinterpret_cast(rhs.get()) + rhs_packed_offset + rhs_data + rhs_packed_offset ); // Get lhs tile, A @@ -660,7 +613,7 @@ static void ConvolveSme(const size_t co, //channels out ); }); - if (result == tmp_mlas_aligned) { + if (need_transpose) { //Note: this could be absorbed into post conv activation MlasTranspose(tmp_mlas_aligned, out, m, co, ThreadPool); } @@ -689,6 +642,7 @@ ArmKleidiAI::MlasConvPrepare(MLAS_CONV_PARAMETERS* Parameters, size_t FilterCount, const MLAS_ACTIVATION* Activation, size_t* WorkingBufferSize, + bool ChannelsLast, float Beta, MLAS_THREADPOOL* ThreadPool) { @@ -708,6 +662,7 @@ ArmKleidiAI::MlasConvPrepare(MLAS_CONV_PARAMETERS* Parameters, Parameters->BatchCount = BatchCount; Parameters->GroupCount = GroupCount; Parameters->InputChannels = InputChannels; + Parameters->ChannelsLast = ChannelsLast; Parameters->FilterCount = FilterCount; Parameters->Beta = Beta; @@ -779,7 +734,10 @@ ArmKleidiAI::MlasConv( Parameters->DilationShape[0], Parameters->DilationShape[1], // kernel dilation Parameters->Padding[0], // image padding Parameters->GroupCount, // filter groups - Filter, Bias, Input, Output, WorkingBuffer, ThreadPool); + Filter, Bias, + reinterpret_cast(Parameters->PackedFilter), + Parameters->PackedFilterGroupStride, + Input, Output, WorkingBuffer, Parameters->ChannelsLast, ThreadPool); MlasActivation(Parameters->Activation, Output, nullptr, Parameters->FilterCount, Parameters->OutputSize, Parameters->OutputSize); diff --git a/onnxruntime/core/mlas/lib/kleidiai/mlasi_kleidiai.h b/onnxruntime/core/mlas/lib/kleidiai/mlasi_kleidiai.h index 1d620388bb5f0..3c9f398ece887 100644 --- a/onnxruntime/core/mlas/lib/kleidiai/mlasi_kleidiai.h +++ b/onnxruntime/core/mlas/lib/kleidiai/mlasi_kleidiai.h @@ -185,6 +185,7 @@ MlasConvPrepare(MLAS_CONV_PARAMETERS* Parameters, size_t FilterCount, const MLAS_ACTIVATION* Activation, size_t* WorkingBufferSize, + bool ChannelsLast, float Beta, MLAS_THREADPOOL* ThreadPool); @@ -200,15 +201,29 @@ MlasConv( MLAS_THREADPOOL* ThreadPool ); -#if defined(MLAS_ENABLE_TEST_HOOKS) size_t MLASCALL -MlasConvLhsCacheEntryCountForTest(); +MlasConvSymmetricChannelsLast2DFloatPackWSize( + size_t FilterCount, + size_t InputChannels, + const int64_t* KernelShape, + const int64_t* DilationShape + ); void MLASCALL -MlasConvClearLhsCacheForTest(); -#endif +MlasConvSymmetricChannelsLast2DFloatPackW( + size_t FilterCount, + size_t InputChannels, + const int64_t* KernelShape, + const int64_t* DilationShape, + size_t GroupCount, + const float* Filter, + const float* Bias, + void* PackedFilter, + size_t PackedFilterGroupStride, + MLAS_THREADPOOL* ThreadPool + ); } /*++ diff --git a/onnxruntime/core/mlas/lib/mlasi.h b/onnxruntime/core/mlas/lib/mlasi.h index 592733f0dd1f9..1193c1c5bbe27 100644 --- a/onnxruntime/core/mlas/lib/mlasi.h +++ b/onnxruntime/core/mlas/lib/mlasi.h @@ -853,6 +853,7 @@ void size_t FilterCount, const MLAS_ACTIVATION* Activation, size_t* WorkingBufferSize, + bool ChannelsLast, float Beta, MLAS_THREADPOOL* ThreadPool ); @@ -873,6 +874,7 @@ bool size_t FilterCount, const MLAS_ACTIVATION* Activation, size_t* WorkingBufferSize, + bool ChannelsLast, float Beta, MLAS_THREADPOOL* ThreadPool ); diff --git a/onnxruntime/core/optimizer/conv_activation_fusion.cc b/onnxruntime/core/optimizer/conv_activation_fusion.cc index 44a20428a09e0..35f3683b77764 100644 --- a/onnxruntime/core/optimizer/conv_activation_fusion.cc +++ b/onnxruntime/core/optimizer/conv_activation_fusion.cc @@ -140,9 +140,12 @@ class FuseConvActivationAction : public ReplaceWithNew { return "FusedConv"; } } else if (domain == kMSDomain) { - if (op_type == "NhwcConv") { + if (op_type == "NhwcConv" || op_type == "NhwcFusedConv") { return "NhwcFusedConv"; } + if (op_type == "FusedConv") { + return "FusedConv"; + } } else if (domain == kMSInternalNHWCDomain) { if (op_type == "Conv") { return "Conv"; diff --git a/onnxruntime/core/optimizer/conv_add_act_fusion.cc b/onnxruntime/core/optimizer/conv_add_act_fusion.cc index 45441d20a4112..d51ea894725fa 100644 --- a/onnxruntime/core/optimizer/conv_add_act_fusion.cc +++ b/onnxruntime/core/optimizer/conv_add_act_fusion.cc @@ -211,7 +211,22 @@ class FuseConvAddActivationAction : public ReplaceWithNew { private: std::string OpType(const RuntimeState& runtimeState) const override { - return (runtimeState.selected_nodes.Target().OpType() == "Conv") ? "FusedConv" : "NhwcFusedConv"; + const auto& target = runtimeState.selected_nodes.Target(); + const auto* channels_last_attr = graph_utils::GetNodeAttribute(target, "channels_last"); + const bool channels_last = channels_last_attr != nullptr && channels_last_attr->i() != 0; + const std::string& op_type = target.OpType(); + + // If channels_last is set, use NHWC fused convolution regardless of original op type. + if (channels_last) { + return "NhwcFusedConv"; + } + + // Without channels_last, convert Conv to FusedConv, and leave other op types unchanged. + if (op_type == "Conv") { + return "FusedConv"; + } + + return op_type; } std::string Domain(const RuntimeState&) const override { return kMSDomain; } diff --git a/onnxruntime/core/optimizer/graph_transformer_utils.cc b/onnxruntime/core/optimizer/graph_transformer_utils.cc index 282261cab0e58..702f20a96dccf 100644 --- a/onnxruntime/core/optimizer/graph_transformer_utils.cc +++ b/onnxruntime/core/optimizer/graph_transformer_utils.cc @@ -463,7 +463,7 @@ InlinedVector> GenerateTransformers( auto cpu_registry = cpu_execution_provider.GetKernelRegistry(); auto nhwc_transformer = std::make_unique(std::move(cpu_allocator), std::move(cpu_registry), - logger); + logger, session_options.config_options); if (nhwc_transformer->IsActive()) { transformers.emplace_back(std::move(nhwc_transformer)); } @@ -566,7 +566,7 @@ InlinedVector> GenerateTransformersForMinimalB AllocatorPtr cpu_allocator = CPUAllocator::DefaultInstance(); auto cpu_registry = cpu_execution_provider.GetKernelRegistry(); auto nhwc_transformer = std::make_unique(std::move(cpu_allocator), std::move(cpu_registry), - logger); + logger, session_options.config_options); if (nhwc_transformer->IsActive()) { transformers.emplace_back(std::move(nhwc_transformer)); } diff --git a/onnxruntime/core/optimizer/layout_transformation/layout_transformation_potentially_added_ops.h b/onnxruntime/core/optimizer/layout_transformation/layout_transformation_potentially_added_ops.h index 81eb9f59eeada..3e3a010728086 100644 --- a/onnxruntime/core/optimizer/layout_transformation/layout_transformation_potentially_added_ops.h +++ b/onnxruntime/core/optimizer/layout_transformation/layout_transformation_potentially_added_ops.h @@ -70,6 +70,7 @@ inline constexpr std::array kLayoutTransformationPotentiallyAddedOps = { #if !defined(DISABLE_CONTRIB_OPS) // kMSDomain ops OpIdentifierWithStringViews{kMSDomain, "DequantizeLinear", 1}, + OpIdentifierWithStringViews{kMSDomain, "NhwcFusedConv", 1}, OpIdentifierWithStringViews{kMSDomain, "NhwcMaxPool", 1}, OpIdentifierWithStringViews{kMSDomain, "QLinearConv", 1}, OpIdentifierWithStringViews{kMSDomain, "QuantizeLinear", 1}, diff --git a/onnxruntime/core/optimizer/nchwc_transformer.cc b/onnxruntime/core/optimizer/nchwc_transformer.cc index b9366ff0abae8..a971a058f43b7 100644 --- a/onnxruntime/core/optimizer/nchwc_transformer.cc +++ b/onnxruntime/core/optimizer/nchwc_transformer.cc @@ -324,6 +324,18 @@ void NchwcTransformerImpl::ConvPoolShapeInference(const Node& node, void NchwcTransformerImpl::TransformConv(Node& node) { auto& input_defs = node.MutableInputDefs(); auto& output_defs = node.MutableOutputDefs(); + NchwcArgument* nchwc_sum_input = nullptr; + + // The internal NCHWc Conv kernel can consume an optional fused Sum input, but it expects + // that tensor to already be in NCHWc layout. Only allow transforming a pre-existing + // FusedConv(X, W, B, Sum) when the Sum input already has a tracked NCHWc variant that + // can be wired through directly. + if (node.OpType() == "FusedConv" && input_defs.size() >= 4 && input_defs[3] != nullptr && input_defs[3]->Exists()) { + nchwc_sum_input = LookupNchwcArgument(input_defs[3]); + if (nchwc_sum_input == nullptr) { + return; + } + } // Require that the weights tensor be static. const ONNX_NAMESPACE::TensorProto* conv_W_tensor_proto = nullptr; @@ -490,6 +502,11 @@ void NchwcTransformerImpl::TransformConv(Node& node) { nchwc_node.MutableInputDefs()[2] = nchwc_conv_B_arg; } + if (nchwc_sum_input != nullptr) { + nchwc_node.MutableInputDefs()[3] = nchwc_sum_input->nchwc_arg_; + nchwc_sum_input->remaining_original_uses_--; + } + NchwcArgument::Shape output_shape(output_defs[0]); if (do_reorder_input) { diff --git a/onnxruntime/core/optimizer/nhwc_transformer.cc b/onnxruntime/core/optimizer/nhwc_transformer.cc index cd654991c92d5..6c0717865b135 100644 --- a/onnxruntime/core/optimizer/nhwc_transformer.cc +++ b/onnxruntime/core/optimizer/nhwc_transformer.cc @@ -2,7 +2,12 @@ // SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates // Licensed under the MIT License. +#include +#include #include +#include +#include "core/common/cpuid_info.h" +#include "core/graph/constants.h" #include "core/mlas/inc/mlas.h" #include "core/graph/graph_utils.h" #include "core/optimizer/initializer.h" @@ -11,6 +16,8 @@ #include "core/optimizer/layout_transformation/layout_transformation.h" #include "core/optimizer/transpose_optimization/ort_optimizer_utils.h" #include "core/optimizer/transpose_optimization/ort_transpose_optimization.h" +#include "core/providers/common.h" +#include "core/providers/cpu/mlas_backend_kernel_selector_config_utils.h" using namespace ONNX_NAMESPACE; using namespace ::onnxruntime::common; @@ -21,6 +28,221 @@ namespace onnxruntime { using namespace layout_transformation; +#ifdef USE_KLEIDIAI +namespace { + +bool TryGetDimValueAsSizeT(const ONNX_NAMESPACE::TensorShapeProto& shape, int index, size_t& value) { + if (shape.dim_size() <= index || !shape.dim(index).has_dim_value()) { + return false; + } + + const int64_t dim_value = shape.dim(index).dim_value(); + if (dim_value < 0) { + return false; + } + + value = narrow(dim_value); + return true; +} + +bool TryReadPositiveInts(const std::vector& values, std::array& out) { + if (values.size() != out.size()) { + return false; + } + + for (size_t i = 0; i < out.size(); ++i) { + if (values[i] <= 0) { + return false; + } + + out[i] = narrow(values[i]); + } + + return true; +} + +bool TryReadPositiveOrZeroInts(const std::vector& values, std::array& out) { + if (values.size() != out.size()) { + return false; + } + + for (size_t i = 0; i < out.size(); ++i) { + if (values[i] < 0) { + return false; + } + + out[i] = narrow(values[i]); + } + + return true; +} + +bool TryParseAutoPadType(std::string_view value, AutoPadType& auto_pad_type) { + if (value.empty() || value == "NOTSET") { + auto_pad_type = AutoPadType::NOTSET; + return true; + } + + if (value == "VALID") { + auto_pad_type = AutoPadType::VALID; + return true; + } + + if (value == "SAME_UPPER") { + auto_pad_type = AutoPadType::SAME_UPPER; + return true; + } + + if (value == "SAME_LOWER") { + auto_pad_type = AutoPadType::SAME_LOWER; + return true; + } + + return false; +} + +bool TryComputeFloatNhwcPads(const api::NodeRef& node, + const std::array& input_shape, + const std::array& kernel_shape, + const std::array& strides, + const std::array& dilations, + std::array& pads) { + for (size_t i = 0; i < 2; ++i) { + if (kernel_shape[i] == 0 || strides[i] == 0 || dilations[i] == 0) { + return false; + } + } + + const auto auto_pad_value = node.GetAttributeString("auto_pad"); + AutoPadType auto_pad = AutoPadType::NOTSET; + if (!TryParseAutoPadType(auto_pad_value.value_or("NOTSET"), auto_pad)) { + return false; + } + + if (auto_pad == AutoPadType::NOTSET) { + const auto pads_opt = node.GetAttributeInts("pads"); + if (!pads_opt.has_value()) { + pads.fill(0); + return true; + } + + return TryReadPositiveOrZeroInts(*pads_opt, pads); + } + + std::array pads_int64{}; + for (size_t i = 0; i < 2; ++i) { + int64_t pad_head = 0; + int64_t pad_tail = 0; + int64_t out_dim = 0; + const auto status = ComputePadAndOutputShape( + narrow(input_shape[i]), + narrow(strides[i]), + narrow(kernel_shape[i]), + narrow(dilations[i]), + auto_pad, + pad_head, + pad_tail, + out_dim, + /*force_symmetric_auto_padding*/ false); + if (!status.IsOK() || pad_head < 0 || pad_tail < 0 || out_dim < 0) { + return false; + } + + pads_int64[i] = pad_head; + pads_int64[i + 2] = pad_tail; + } + + for (size_t i = 0; i < pads.size(); ++i) { + pads[i] = narrow(pads_int64[i]); + } + + return true; +} + +bool FloatNhwcWrapperFilter(const onnx_transpose_optimization::api::GraphRef& graph, + onnx_transpose_optimization::api::NodeRef& node) { + auto& base_node = NodeFromApiNode(node); + + ORT_UNUSED_PARAMETER(graph); +#if !defined(MLAS_TARGET_ARM64) + return false; +#else + if (!CPUIDInfo::GetCPUIDInfo().HasArm_SME()) { + return false; + } + + if (base_node.InputDefs().size() < 2) { + return false; + } + + const auto* input_shape = base_node.InputDefs()[0]->Shape(); + if (input_shape == nullptr || input_shape->dim_size() != 4) { + return false; + } + + const auto* weight_shape = base_node.InputDefs()[1]->Shape(); + if (weight_shape == nullptr || weight_shape->dim_size() != 4) { + return false; + } + + const auto inputs = node.Inputs(); + if (base_node.OpType() == "FusedConv" && inputs.size() > 3 && !inputs[3].empty()) { + return false; + } + + const auto group = node.GetAttributeInt("group").value_or(1); + if (group != 1) { + return false; + } + + std::array input_spatial_shape{}; + std::array kernel_spatial_shape{}; + std::array dilations{1, 1}; + std::array strides{1, 1}; + std::array pads{}; + size_t batch_count = 0; + size_t filter_count = 0; + + if (!TryGetDimValueAsSizeT(*input_shape, 0, batch_count) || + !TryGetDimValueAsSizeT(*input_shape, 2, input_spatial_shape[0]) || + !TryGetDimValueAsSizeT(*input_shape, 3, input_spatial_shape[1]) || + !TryGetDimValueAsSizeT(*weight_shape, 0, filter_count) || + !TryGetDimValueAsSizeT(*weight_shape, 2, kernel_spatial_shape[0]) || + !TryGetDimValueAsSizeT(*weight_shape, 3, kernel_spatial_shape[1])) { + return false; + } + + const auto dilations_opt = node.GetAttributeInts("dilations"); + if (dilations_opt.has_value() && !TryReadPositiveInts(*dilations_opt, dilations)) { + return false; + } + + const auto strides_opt = node.GetAttributeInts("strides"); + if (strides_opt.has_value() && !TryReadPositiveInts(*strides_opt, strides)) { + return false; + } + + if (!TryComputeFloatNhwcPads(node, input_spatial_shape, kernel_spatial_shape, strides, dilations, pads)) { + return false; + } + + return MlasConvSupportsSymmetricChannelsLast2DFloatKernel( + /*Dimensions*/ 2, + batch_count, + /*GroupCount*/ 1, + input_spatial_shape.data(), + kernel_spatial_shape.data(), + dilations.data(), + pads.data(), + strides.data(), + filter_count, + /*Beta*/ 0.0f); +#endif +} + +} // namespace +#endif + static inline const OpTransformInfo* NhwcConvLookup( const OpTransformMap& conv_table, @@ -41,18 +263,29 @@ NhwcConvLookup( if (iter == conv_table.end()) { return nullptr; } + + if (iter->second.filter_ != nullptr) { + if (!iter->second.filter_(graph, node)) { + return nullptr; + } + } + return &(iter->second); } NhwcTransformer::NhwcTransformer(AllocatorPtr cpu_allocator, std::shared_ptr cpu_kernel_registry, - const logging::Logger& logger) noexcept + const logging::Logger& logger, + const ConfigOptions& config_options) noexcept : GraphTransformer("NhwcTransformer"), cpu_allocator_(std::move(cpu_allocator)) { if (!cpu_kernel_registry) { // This is a CPU op nodes optimizer, not useful if cpu EP is not available. return; } + MLAS_BACKEND_KERNEL_SELECTOR_CONFIG mlas_backend_kernel_selector_config{}; + SetupMlasBackendKernelSelectorFromConfigOptions(mlas_backend_kernel_selector_config, config_options); + // // Constructing a mapping table from operators to be transformed to their target. // Make sure that the new nodes we are about to create during graph transformation, @@ -71,10 +304,10 @@ NhwcTransformer::NhwcTransformer(AllocatorPtr cpu_allocator, kernel_create_info = nullptr; conv_table_.emplace( OpIdInfo("QLinearConv", kOnnxDomain, api::DataType::INT8), - OpTransformInfo{qconv_int8.op_type_, qconv_int8.domain_, qconv_int8.version_, true}); + OpTransformInfo{qconv_int8.op_type_, qconv_int8.domain_, qconv_int8.version_, true, false}); conv_table_.emplace( OpIdInfo("QLinearConv", kMSDomain, api::DataType::INT8), - OpTransformInfo{qconv_int8.op_type_, qconv_int8.domain_, qconv_int8.version_, true}); + OpTransformInfo{qconv_int8.op_type_, qconv_int8.domain_, qconv_int8.version_, true, false}); } } @@ -90,10 +323,10 @@ NhwcTransformer::NhwcTransformer(AllocatorPtr cpu_allocator, kernel_create_info = nullptr; conv_table_.emplace( OpIdInfo("QLinearConv", kOnnxDomain, api::DataType::UINT8), - OpTransformInfo{qconv_uint8.op_type_, qconv_uint8.domain_, qconv_uint8.version_, true}); + OpTransformInfo{qconv_uint8.op_type_, qconv_uint8.domain_, qconv_uint8.version_, true, false}); conv_table_.emplace( OpIdInfo("QLinearConv", kMSDomain, api::DataType::UINT8), - OpTransformInfo{qconv_uint8.op_type_, qconv_uint8.domain_, qconv_uint8.version_, true}); + OpTransformInfo{qconv_uint8.op_type_, qconv_uint8.domain_, qconv_uint8.version_, true, false}); } } @@ -108,14 +341,61 @@ NhwcTransformer::NhwcTransformer(AllocatorPtr cpu_allocator, nhwc_conv_fp16.version_, nhwc_conv_fp16.type_constraints_, logger, &kernel_create_info); if (status.IsOK() && kernel_create_info != nullptr) { kernel_create_info = nullptr; + const auto filter = [](const api::GraphRef&, api::NodeRef& node) { + const auto dilations_opt = node.GetAttributeInts("dilations"); + if (dilations_opt.has_value()) { + const auto& dilations = dilations_opt.value(); + if ((dilations.size() >= 1 && dilations[0] != 1) || + (dilations.size() >= 2 && dilations[1] != 1)) { + return false; + } + } + + const auto group_opt = node.GetAttributeInt("group"); + if (group_opt.has_value() && group_opt.value() != 1) { + return false; + } + + return true; + }; + conv_table_.emplace( OpIdInfo("Conv", kOnnxDomain, api::DataType::FLOAT16), - OpTransformInfo{nhwc_conv_fp16.op_type_, nhwc_conv_fp16.domain_, nhwc_conv_fp16.version_, false}); + OpTransformInfo{nhwc_conv_fp16.op_type_, nhwc_conv_fp16.domain_, nhwc_conv_fp16.version_, false, false, filter}); conv_table_.emplace( OpIdInfo("FusedConv", kMSDomain, api::DataType::FLOAT16), - OpTransformInfo{nhwc_conv_fp16.op_type_, nhwc_conv_fp16.domain_, nhwc_conv_fp16.version_, false}); + OpTransformInfo{nhwc_conv_fp16.op_type_, nhwc_conv_fp16.domain_, nhwc_conv_fp16.version_, false, true, filter}); + } + } + +#ifdef USE_KLEIDIAI + // KleidiAI specific block for NhwcFusedConvolutions + if (mlas_backend_kernel_selector_config.use_kleidiai) { + // F32 Conv -> F32 NHWC Conv + OpKernelRegistryId nhwc_conv_fp32{ + "NhwcFusedConv", kMSDomain, 1, {{"T", {DataTypeImpl::GetTensorType()}}}}; + + const KernelCreateInfo* kernel_create_info{}; + const auto status = cpu_kernel_registry->TryFindKernel( + kCpuExecutionProvider, nhwc_conv_fp32.op_type_, nhwc_conv_fp32.domain_, + nhwc_conv_fp32.version_, nhwc_conv_fp32.type_constraints_, logger, &kernel_create_info); + + if (status.IsOK() && kernel_create_info != nullptr) { + kernel_create_info = nullptr; + + const auto filter = [](const api::GraphRef& graph, api::NodeRef& node) { + return FloatNhwcWrapperFilter(graph, node); + }; + + conv_table_.emplace( + OpIdInfo("Conv", kOnnxDomain, api::DataType::FLOAT), + OpTransformInfo{nhwc_conv_fp32.op_type_, nhwc_conv_fp32.domain_, nhwc_conv_fp32.version_, false, true, filter}); + conv_table_.emplace( + OpIdInfo("FusedConv", kMSDomain, api::DataType::FLOAT), + OpTransformInfo{nhwc_conv_fp32.op_type_, nhwc_conv_fp32.domain_, nhwc_conv_fp32.version_, false, true, filter}); } } +#endif { // fp16 MaxPool -> fp16 nhwc MaxPool @@ -130,7 +410,7 @@ NhwcTransformer::NhwcTransformer(AllocatorPtr cpu_allocator, kernel_create_info = nullptr; conv_table_.emplace( OpIdInfo("MaxPool", kOnnxDomain, api::DataType::FLOAT16), - OpTransformInfo{nhwc_maxpool_fp16.op_type_, nhwc_maxpool_fp16.domain_, nhwc_maxpool_fp16.version_, false}); + OpTransformInfo{nhwc_maxpool_fp16.op_type_, nhwc_maxpool_fp16.domain_, nhwc_maxpool_fp16.version_, false, false}); } } @@ -147,7 +427,7 @@ NhwcTransformer::NhwcTransformer(AllocatorPtr cpu_allocator, kernel_create_info = nullptr; conv_table_.emplace( OpIdInfo("AveragePool", kOnnxDomain, api::DataType::FLOAT16), - OpTransformInfo{nhwc_avgpool_fp16.op_type_, nhwc_avgpool_fp16.domain_, nhwc_avgpool_fp16.version_, false}); + OpTransformInfo{nhwc_avgpool_fp16.op_type_, nhwc_avgpool_fp16.domain_, nhwc_avgpool_fp16.version_, false, false}); } } @@ -164,7 +444,7 @@ NhwcTransformer::NhwcTransformer(AllocatorPtr cpu_allocator, kernel_create_info = nullptr; conv_table_.emplace( OpIdInfo("GlobalAveragePool", kOnnxDomain, api::DataType::FLOAT16), - OpTransformInfo{nhwc_gavgpool_fp16.op_type_, nhwc_gavgpool_fp16.domain_, nhwc_gavgpool_fp16.version_, false}); + OpTransformInfo{nhwc_gavgpool_fp16.op_type_, nhwc_gavgpool_fp16.domain_, nhwc_gavgpool_fp16.version_, false, false}); } } }; @@ -214,10 +494,22 @@ Status NhwcTransformer::ApplyImpl(Graph& graph, bool& modified, int graph_level, if (transform->has_channels_last_attrib_) { node->SetAttributeInt("channels_last", 1); } + size_t rank = shape->dim_size(); std::vector input_perm = ChannelFirstToLastPerm(rank); std::vector output_perm = ChannelLastToFirstPerm(rank); - WrapTransposesAroundNode(*api_graph, *node, {&input_perm}, {&output_perm}); + const auto inputs = node->Inputs(); + std::vector*> input_perms(inputs.size(), nullptr); + if (!inputs.empty()) { + input_perms[0] = &input_perm; + } + // Some transformed operators require the optional fused Sum (Z) input at index 3 + // to be converted alongside the activation tensor. + if (transform->transpose_fused_sum_input_ && inputs.size() > 3 && !inputs[3].empty()) { + input_perms[3] = &input_perm; + } + + WrapTransposesAroundNode(*api_graph, *node, input_perms, {&output_perm}); // Replace the operator if needed if (node->Domain() != transform->domain_ || diff --git a/onnxruntime/core/optimizer/nhwc_transformer.h b/onnxruntime/core/optimizer/nhwc_transformer.h index c65f851fdab9d..4755ceb316fef 100644 --- a/onnxruntime/core/optimizer/nhwc_transformer.h +++ b/onnxruntime/core/optimizer/nhwc_transformer.h @@ -3,7 +3,9 @@ #pragma once +#include #include "core/common/common.h" +#include "core/framework/config_options.h" #include "core/framework/execution_provider.h" #include "core/framework/kernel_registry.h" #include "core/optimizer/graph_transformer.h" @@ -54,10 +56,15 @@ class OpIdHash { * @brief Information needed for operator layout transformation */ struct OpTransformInfo { + using FilterFn = std::function; + const std::string optype_; const std::string domain_; const int version_; const bool has_channels_last_attrib_; + const bool transpose_fused_sum_input_; + const FilterFn filter_{nullptr}; }; using OpTransformMap = std::unordered_map; @@ -76,7 +83,7 @@ class NhwcTransformer : public GraphTransformer { private: public: explicit NhwcTransformer(AllocatorPtr cpu_allocator, std::shared_ptr cpu_kernel_registry, - const logging::Logger& logger) noexcept; + const logging::Logger& logger, const ConfigOptions& config_options) noexcept; /** * @brief Usually called right after constructor, it shows whether diff --git a/onnxruntime/core/providers/cpu/nn/conv.cc b/onnxruntime/core/providers/cpu/nn/conv.cc index e2db5ac238f52..87ce1b05caae2 100644 --- a/onnxruntime/core/providers/cpu/nn/conv.cc +++ b/onnxruntime/core/providers/cpu/nn/conv.cc @@ -15,15 +15,53 @@ */ /* Modifications Copyright (c) Microsoft. */ +#include + #include "core/providers/cpu/nn/conv.h" #include "core/common/narrow.h" #include "core/common/safeint.h" #include "core/util/math_cpuonly.h" +#if defined(USE_KLEIDIAI) && defined(__aarch64__) && defined(__linux__) +#include "core/mlas/lib/kleidiai/mlasi_kleidiai.h" +#endif + namespace onnxruntime { using ConvPadVector = ConvAttributes::ConvPadVector; +namespace { + +template +void ConvertNHWCToNCHW(const T* src, T* dst, + int64_t n, int64_t c, int64_t h, int64_t w, + concurrency::ThreadPool* thread_pool) { + const size_t n_count = narrow(n); + const size_t c_count = narrow(c); + const size_t hw = narrow(SafeInt(h) * w); + for (size_t n_idx = 0; n_idx < n_count; ++n_idx) { + const size_t n_src_offset = SafeInt(SafeInt(n_idx) * hw) * c_count; + const size_t n_dst_offset = SafeInt(SafeInt(n_idx) * c_count) * hw; + MlasTranspose(src + n_src_offset, dst + n_dst_offset, hw, c_count, thread_pool); + } +} + +template +void ConvertNCHWToNHWC(const T* src, T* dst, + int64_t n, int64_t c, int64_t h, int64_t w, + concurrency::ThreadPool* thread_pool) { + const size_t n_count = narrow(n); + const size_t c_count = narrow(c); + const size_t hw = narrow(SafeInt(h) * w); + for (size_t n_idx = 0; n_idx < n_count; ++n_idx) { + const size_t n_src_offset = SafeInt(SafeInt(n_idx) * c_count) * hw; + const size_t n_dst_offset = SafeInt(SafeInt(n_idx) * hw) * c_count; + MlasTranspose(src + n_src_offset, dst + n_dst_offset, c_count, hw, thread_pool); + } +} + +} // namespace + template Status Conv::Compute(OpKernelContext* context) const { const auto* X = context->Input(0); @@ -153,20 +191,77 @@ Status Conv::Compute(OpKernelContext* context) const { return Status::OK(); } +#if defined(USE_KLEIDIAI) && defined(__aarch64__) && defined(__linux__) +Status Conv::EnsurePackedChannelsLastFilter(concurrency::ThreadPool* thread_pool, + size_t filter_count_per_group, + size_t input_channels_per_group, + const TensorShapeVector& kernel_shape, + const TensorShapeVector& dilations) const { + if (!can_cache_packed_filter_) { + return Status::OK(); + } + + std::call_once(packed_filter_once_, [&] { + packed_filter_status_ = Status::OK(); + + auto alloc = Info().GetAllocator(OrtMemTypeDefault); + if (alloc == nullptr) { + packed_filter_status_ = ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, + "Failed to get allocator for cached KleidiAI packed filter."); + return; + } + + packed_filter_group_stride_ = + ArmKleidiAI::MlasConvSymmetricChannelsLast2DFloatPackWSize(filter_count_per_group, + input_channels_per_group, + kernel_shape.data(), + dilations.data()); + if (packed_filter_group_stride_ == 0) { + packed_filter_status_ = ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, + "Failed to get KleidiAI packed filter size."); + return; + } + + const size_t packed_filter_size = + packed_filter_group_stride_ * onnxruntime::narrow(conv_attrs_.group); + packed_filter_ = IAllocator::MakeUniquePtr(alloc, packed_filter_size, true); + memset(packed_filter_.get(), 0, packed_filter_size); + + ArmKleidiAI::MlasConvSymmetricChannelsLast2DFloatPackW(filter_count_per_group, + input_channels_per_group, + kernel_shape.data(), + dilations.data(), + onnxruntime::narrow(conv_attrs_.group), + constant_filter_tensor_->Data(), + constant_bias_tensor_ ? constant_bias_tensor_->Data() : nullptr, + packed_filter_.get(), + packed_filter_group_stride_, + thread_pool); + }); + + return packed_filter_status_; +} +#endif + Status Conv::Compute(OpKernelContext* context) const { size_t num_inputs = OpKernel::Node().InputDefs().size(); const Tensor* X = context->Input(0); const Tensor* W = context->Input(1); const Tensor* B = num_inputs >= 3 ? context->Input(2) : nullptr; const Tensor* Sum = num_inputs >= 4 ? context->Input(3) : nullptr; + ORT_RETURN_IF_ERROR(conv_attrs_.ValidateInputShape(X->Shape(), W->Shape(), channels_last_)); const int64_t N = X->Shape()[0]; - const int64_t C = X->Shape()[1]; + // If channels_last_ we should get the back dim for channels instead of [1] + const int64_t C = channels_last_ ? X->Shape().GetDims().back() : X->Shape()[1]; const int64_t M = W->Shape()[0]; - ORT_RETURN_IF_ERROR(conv_attrs_.ValidateInputShape(X, W)); - // kernel_shape is an optional attribute and has to be inferred from W if not provided TensorShapeVector kernel_shape; ORT_RETURN_IF_ERROR(conv_attrs_.ComputeKernelShape(W->Shape(), kernel_shape)); + const size_t kernel_rank = kernel_shape.size(); + + if (channels_last_) { + ORT_RETURN_IF_NOT(kernel_rank == 2, "Conv with channels_last layout currently supports 2D kernels."); + } ConvPadVector pads(conv_attrs_.pads); if (pads.empty()) { @@ -182,12 +277,14 @@ Status Conv::Compute(OpKernelContext* context) const { } TensorShapeVector Y_dims({N, M}); - TensorShape input_shape = X->Shape().Slice(2); + TensorShape input_shape = channels_last_ ? X->Shape().Slice(1, 3) : X->Shape().Slice(2); ORT_RETURN_IF_ERROR(conv_attrs_.InferPadsAndOutputShape(input_shape, kernel_shape, strides, dilations, pads, Y_dims)); + if (channels_last_) { + Y_dims = {Y_dims[0], Y_dims[2], Y_dims[3], Y_dims[1]}; + } Tensor* Y = context->Output(0, TensorShape(Y_dims)); - TensorShape output_shape = Y->Shape().Slice(2); + TensorShape output_shape = channels_last_ ? TensorShape(Y_dims).Slice(1, 3) : Y->Shape().Slice(2); - // Bail out early if one of the dimensions is zero. if (Y->Shape().Size() == 0) { return Status::OK(); } @@ -198,23 +295,75 @@ Status Conv::Compute(OpKernelContext* context) const { auto Xdata = X->DataAsSpan(); const auto* Bdata = B != nullptr ? B->Data() : nullptr; auto Ydata = Y->MutableDataAsSpan(); - // Check for the optional Conv/Sum fusion. + concurrency::ThreadPool* thread_pool = context->GetOperatorThreadPool(); + + const bool wants_channels_last = channels_last_; + const bool sum_present = Sum != nullptr; + std::array input_shape_size_t{}; + std::array kernel_shape_size_t{}; + std::array dilations_size_t{}; + std::array pads_size_t{}; + std::array strides_size_t{}; + if (wants_channels_last) { + ORT_RETURN_IF_NOT(input_shape.NumDimensions() == 2, "Nhwc Conv fast-path expects 2D input shape."); + for (size_t i = 0; i < 2; ++i) { + input_shape_size_t[i] = narrow(input_shape[i]); + kernel_shape_size_t[i] = narrow(kernel_shape[i]); + dilations_size_t[i] = narrow(dilations[i]); + strides_size_t[i] = narrow(strides[i]); + pads_size_t[i] = narrow(pads[i]); + pads_size_t[i + 2] = narrow(pads[i + 2]); + } + } + const bool nhwc_fastpath = + wants_channels_last && !sum_present && + MlasConvSupportsSymmetricChannelsLast2DFloatKernel( + kernel_rank, + narrow(N), + narrow(conv_attrs_.group), + input_shape_size_t.data(), + kernel_shape_size_t.data(), + dilations_size_t.data(), + pads_size_t.data(), + strides_size_t.data(), + narrow(M / conv_attrs_.group), + /*Beta*/ 0.0f); + +#if defined(USE_KLEIDIAI) && defined(__aarch64__) && defined(__linux__) + if (nhwc_fastpath && can_cache_packed_filter_) { + ORT_RETURN_IF_ERROR(EnsurePackedChannelsLastFilter(thread_pool, + narrow(M / conv_attrs_.group), + narrow(C / conv_attrs_.group), + kernel_shape, + dilations)); + } +#endif + + const bool manual_sum = wants_channels_last && !nhwc_fastpath && sum_present; + MLAS_ACTIVATION pre_sum_activation = activation_; + if (manual_sum) { + pre_sum_activation.ActivationKind = MlasIdentityActivation; + } + + const float* sum_manual_data = nullptr; + float Beta = 0.0f; - if (Sum != nullptr) { + if (sum_present) { const auto& sum_shape = Sum->Shape(); ORT_RETURN_IF_NOT(Y->Shape() == sum_shape, "output and sum shape must match"); - // If the output was not allocated inplace with the sum tensor, then copy here. - auto sum_data = Sum->DataAsSpan(); - if (Ydata.data() != sum_data.data()) { - gsl::copy(sum_data, Ydata); + if (manual_sum) { + sum_manual_data = Sum->Data(); + } else { + auto sum_span = Sum->DataAsSpan(); + if (Ydata.data() != sum_span.data()) { + gsl::copy(sum_span, Ydata); + } + Beta = 1.0f; } - Beta = 1.0f; } - const size_t kernel_rank = kernel_shape.size(); - concurrency::ThreadPool* thread_pool = context->GetOperatorThreadPool(); if (kernel_rank >= 1 && kernel_rank <= 3) { - MLAS_CONV_PARAMETERS Parameters; + MLAS_CONV_PARAMETERS Parameters{}; Parameters.BackendKernelSelectorConfig = &mlas_backend_kernel_selector_config_; size_t WorkingBufferSize; @@ -230,22 +379,90 @@ Status Conv::Compute(OpKernelContext* context) const { strides.data(), output_shape.GetDims().data(), narrow(M / conv_attrs_.group), - &activation_, + manual_sum ? &pre_sum_activation : &activation_, &WorkingBufferSize, - Beta, + nhwc_fastpath, + nhwc_fastpath ? 0.0f : Beta, thread_pool); - auto* working_data = WorkingBufferSize > 0 ? alloc->Alloc(sizeof(float) * SafeInt(WorkingBufferSize)) - : nullptr; - BufferUniquePtr working_buffer(working_data, BufferDeleter(std::move(alloc))); +#if defined(USE_KLEIDIAI) && defined(__aarch64__) && defined(__linux__) + if (nhwc_fastpath && packed_filter_ != nullptr) { + Parameters.FilterIsPacked = true; + Parameters.PackedFilter = packed_filter_.get(); + Parameters.PackedFilterGroupStride = packed_filter_group_stride_; + } +#endif + + float* working_data = nullptr; + BufferUniquePtr working_buffer; + if (WorkingBufferSize > 0) { + working_data = static_cast(alloc->Alloc(sizeof(float) * SafeInt(WorkingBufferSize))); + working_buffer = BufferUniquePtr(working_data, BufferDeleter(alloc)); + } + + float* output_compute = Ydata.data(); + BufferUniquePtr output_temp; + if (wants_channels_last && !nhwc_fastpath) { + const SafeInt output_compute_size = + SafeInt(Y->Shape()[0]) * SafeInt(M) * + SafeInt(output_shape[0]) * SafeInt(output_shape[1]); + float* temp_output = static_cast(alloc->Alloc(sizeof(float) * output_compute_size)); + output_temp = BufferUniquePtr(temp_output, BufferDeleter(alloc)); + output_compute = temp_output; + } + + const float* input_compute = Xdata.data(); + BufferUniquePtr input_temp; + if (wants_channels_last && !nhwc_fastpath) { + ORT_RETURN_IF_NOT(X->Shape().NumDimensions() == 4, "Nhwc fallback expects 4D input."); + const auto& x_dims = X->Shape().GetDims(); + const int64_t input_n = x_dims[0]; + const int64_t input_h = x_dims[1]; + const int64_t input_w = x_dims[2]; + const int64_t input_c = x_dims[3]; + const SafeInt input_elements = SafeInt(X->Shape().Size()); + float* temp_input = static_cast(alloc->Alloc(sizeof(float) * input_elements)); + input_temp = BufferUniquePtr(temp_input, BufferDeleter(alloc)); + ConvertNHWCToNCHW(X->Data(), temp_input, + input_n, input_c, input_h, input_w, thread_pool); + input_compute = temp_input; + } MlasConv(&Parameters, - Xdata.data(), + input_compute, W->Data(), Bdata, - static_cast(working_buffer.get()), - Ydata.data(), + working_data, + output_compute, thread_pool); + + if (wants_channels_last && !nhwc_fastpath) { + const auto& y_dims = Y->Shape().GetDims(); + ORT_RETURN_IF_NOT(y_dims.size() == 4, "Nhwc fallback expects 4D output."); + if (manual_sum) { + const SafeInt output_elements = SafeInt(Y->Shape().Size()); + float* sum_nchw = static_cast(alloc->Alloc(sizeof(float) * output_elements)); + BufferUniquePtr sum_nchw_buffer(sum_nchw, BufferDeleter(alloc)); + ConvertNHWCToNCHW(sum_manual_data, + sum_nchw, + y_dims[0], y_dims[3], y_dims[1], y_dims[2], thread_pool); + + auto output_span = gsl::make_span(output_compute, static_cast(output_elements)); + auto sum_span = gsl::make_span(sum_nchw, static_cast(output_elements)); + for (size_t i = 0; i < output_span.size(); ++i) { + output_span[i] += sum_span[i]; + } + + const auto activation_rows = narrow(SafeInt(y_dims[0]) * y_dims[3]); + const auto activation_cols = narrow(output_shape.Size()); + MlasActivation(&activation_, output_compute, nullptr, activation_rows, + activation_cols, activation_cols); + } + + ConvertNCHWToNHWC(output_compute, + Ydata.data(), + y_dims[0], y_dims[3], y_dims[1], y_dims[2], thread_pool); + } } else { const int64_t input_image_size = input_shape.Size(); const int64_t output_image_size = output_shape.Size(); @@ -287,7 +504,8 @@ Status Conv::Compute(OpKernelContext* context) const { &mlas_backend_kernel_selector_config_); } - MlasActivation(&activation_, Ydata.data(), Bdata, narrow(M), narrow(output_image_size), narrow(output_image_size)); + MlasActivation(&activation_, Ydata.data(), Bdata, narrow(M), + narrow(output_image_size), narrow(output_image_size)); Xdata = Xdata.subspan(X_offset * conv_attrs_.group); Ydata = Ydata.subspan(Y_offset * conv_attrs_.group); diff --git a/onnxruntime/core/providers/cpu/nn/conv.h b/onnxruntime/core/providers/cpu/nn/conv.h index 8992af9792d13..1cbe417cdbd96 100644 --- a/onnxruntime/core/providers/cpu/nn/conv.h +++ b/onnxruntime/core/providers/cpu/nn/conv.h @@ -3,6 +3,8 @@ #pragma once +#include + #include "core/framework/op_kernel.h" #include "core/providers/cpu/nn/conv_attributes.h" #include "core/providers/cpu/mlas_backend_kernel_selector_config_utils.h" @@ -25,9 +27,23 @@ class Conv : public OpKernel { template <> class Conv : public OpKernel { public: - Conv(const OpKernelInfo& info) : OpKernel(info), conv_attrs_(info) { + Conv(const OpKernelInfo& info) : OpKernel(info), conv_attrs_(info), channels_last_(info.GetKernelDef().OpName() == "NhwcFusedConv") { activation_.ActivationKind = MlasIdentityActivation; SetupMlasBackendKernelSelectorFromConfigOptions(mlas_backend_kernel_selector_config_, info.GetConfigOptions()); + +#if defined(USE_KLEIDIAI) && defined(__aarch64__) && defined(__linux__) + if (channels_last_) { + const auto& input_defs = info.node().InputDefs(); + const bool has_bias_input = input_defs.size() >= 3 && input_defs[2] != nullptr; + info.TryGetConstantInput(1, &constant_filter_tensor_); + if (has_bias_input) { + info.TryGetConstantInput(2, &constant_bias_tensor_); + } + + can_cache_packed_filter_ = + constant_filter_tensor_ != nullptr && (!has_bias_input || constant_bias_tensor_ != nullptr); + } +#endif } Status Compute(OpKernelContext* context) const override; @@ -38,6 +54,24 @@ class Conv : public OpKernel { MLAS_BACKEND_KERNEL_SELECTOR_CONFIG mlas_backend_kernel_selector_config_; ConvAttributes conv_attrs_; + bool channels_last_{false}; + +#if defined(USE_KLEIDIAI) && defined(__aarch64__) && defined(__linux__) + private: + Status EnsurePackedChannelsLastFilter(concurrency::ThreadPool* thread_pool, + size_t filter_count_per_group, + size_t input_channels_per_group, + const TensorShapeVector& kernel_shape, + const TensorShapeVector& dilations) const; + + const Tensor* constant_filter_tensor_{nullptr}; + const Tensor* constant_bias_tensor_{nullptr}; + bool can_cache_packed_filter_{false}; + mutable std::once_flag packed_filter_once_; + mutable Status packed_filter_status_; + mutable IAllocatorUniquePtr packed_filter_; + mutable size_t packed_filter_group_stride_{0}; +#endif }; } // namespace onnxruntime diff --git a/onnxruntime/core/util/math_cpu.cc b/onnxruntime/core/util/math_cpu.cc index 3c7dc53b1098f..608d30c12b587 100644 --- a/onnxruntime/core/util/math_cpu.cc +++ b/onnxruntime/core/util/math_cpu.cc @@ -845,6 +845,7 @@ void Im2col::operator()( template struct Im2col; template struct Im2col; template struct Im2col; +template struct Im2col; template <> void Col2im(const float* data_col, int64_t channels, int64_t height, diff --git a/onnxruntime/test/contrib_ops/fused_conv_test.cc b/onnxruntime/test/contrib_ops/fused_conv_test.cc index 9df222db43501..7bfacb996526f 100644 --- a/onnxruntime/test/contrib_ops/fused_conv_test.cc +++ b/onnxruntime/test/contrib_ops/fused_conv_test.cc @@ -120,6 +120,58 @@ void RunConvOp(const ConvOpAndTestAttributes& attributes, disable_cpu, disable_cuda, disable_webgpu, use_float16, weight_is_initializer); } +#ifdef USE_KLEIDIAI +void TestNhwcFusedConvFloatOp(const ConvOpAndTestAttributes& attributes, + const vector>& inputs, + const vector>& input_shapes, + const std::initializer_list& expected_output, + const vector& expected_output_shape, + bool weight_is_initializer = false) { + auto cpu_ep = DefaultCpuExecutionProvider(); + if (cpu_ep == nullptr) { + return; + } + + OpTester test("NhwcFusedConv", 1, onnxruntime::kMSDomain); + test.AddAttribute("group", attributes.group); + test.AddAttribute("kernel_shape", attributes.kernel_shape); + test.AddAttribute("activation", attributes.activation); + + if (!attributes.dilations.empty()) { + test.AddAttribute("dilations", attributes.dilations); + } + + if (!attributes.pads.empty()) { + test.AddAttribute("pads", attributes.pads); + } else { + test.AddAttribute("auto_pad", attributes.auto_pad); + } + + if (!attributes.strides.empty()) { + test.AddAttribute("strides", attributes.strides); + } + + if (!attributes.activation_parameters.empty()) { + test.AddAttribute("activation_params", attributes.activation_parameters); + } + + const char* szNames[] = {"X", "W", "B", "Z"}; + test.AddInput(szNames[0], input_shapes[0], inputs[0]); + test.AddInput(szNames[1], input_shapes[1], inputs[1], weight_is_initializer); + if (inputs.size() >= 3) { + test.AddInput(szNames[2], input_shapes[2], inputs[2]); + } + if (inputs.size() >= 4) { + test.AddInput(szNames[3], input_shapes[3], inputs[3]); + } + test.AddOutput("Y", expected_output_shape, expected_output); + + std::vector> execution_providers; + execution_providers.push_back(std::move(cpu_ep)); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); +} +#endif + TEST(FusedConvTest, Conv2D_HardSigmoid) { ConvOpAndTestAttributes attrs = { "", // auto_pad @@ -235,6 +287,93 @@ TEST(FusedConvTest, Cpu_Conv2D_Bias_Z_Relu) { RunConvOp(attrs, {X, W, B, Z}, {X_shape, W_shape, B_shape, Z_shape}, expected_vals, Y_shape, false, true, true); } +#ifdef USE_KLEIDIAI +TEST(FusedConvTest, Cpu_NhwcConv2D_Bias_Z_Relu) { + ConvOpAndTestAttributes attrs = { + "", // auto_pad + vector{1, 1}, // dilations + 1, // group + vector{2, 2}, // kernel_shape + vector{0, 0, 0, 0}, // pads + vector{1, 1}, // strides + "Relu" // activation + }; + + vector X = {1.0f, 2.0f, 3.0f, + 4.0f, 5.0f, 6.0f, + 7.0f, 8.0f, 9.0f}; + vector X_shape = {1, 3, 3, 1}; + vector W = {1.0f, 1.0f, 1.0f, 1.0f, + 1.0f, 1.0f, 1.0f, 1.0f}; + vector W_shape = {2, 1, 2, 2}; + vector Y_shape = {1, 2, 2, 2}; + vector B = {1.0f, -1.0f}; + vector B_shape = {2}; + vector Z = {-1.0f, 0.0f, 0.0f, 0.0f, + 0.0f, 0.0f, 0.0f, 1.0f}; + vector Z_shape = {1, 2, 2, 2}; + auto expected_vals = {12.0f, 11.0f, 17.0f, 15.0f, 25.0f, 23.0f, 29.0f, 28.0f}; + TestNhwcFusedConvFloatOp(attrs, {X, W, B, Z}, {X_shape, W_shape, B_shape, Z_shape}, expected_vals, Y_shape); + TestNhwcFusedConvFloatOp(attrs, {X, W, B, Z}, {X_shape, W_shape, B_shape, Z_shape}, expected_vals, Y_shape, true); +} + +TEST(FusedConvTest, Cpu_NhwcConv2D_Z_Relu_Batch2) { + ConvOpAndTestAttributes attrs = { + "", // auto_pad + vector{1, 1}, // dilations + 1, // group + vector{1, 1}, // kernel_shape + vector{0, 0, 0, 0}, // pads + vector{1, 1}, // strides + "Relu" // activation + }; + + vector X = {1.0f, 2.0f, 3.0f, 4.0f, + 1.0f, 1.0f, 1.0f, 1.0f}; + vector X_shape = {2, 2, 2, 1}; + vector W = {1.0f}; + vector W_shape = {1, 1, 1, 1}; + vector B = {0.0f}; + vector B_shape = {1}; + vector Z = {0.0f, 0.0f, 0.0f, 0.0f, + -2.0f, -3.0f, -4.0f, -5.0f}; + vector Z_shape = {2, 2, 2, 1}; + vector Y_shape = {2, 2, 2, 1}; + auto expected_vals = {1.0f, 2.0f, 3.0f, 4.0f, + 0.0f, 0.0f, 0.0f, 0.0f}; + + TestNhwcFusedConvFloatOp(attrs, {X, W, B, Z}, {X_shape, W_shape, B_shape, Z_shape}, expected_vals, Y_shape); + TestNhwcFusedConvFloatOp(attrs, {X, W, B, Z}, {X_shape, W_shape, B_shape, Z_shape}, expected_vals, Y_shape, true); +} + +TEST(FusedConvTest, Cpu_NhwcConv2D_AutoPadSameUpper) { + ConvOpAndTestAttributes attrs = { + "SAME_UPPER", // auto_pad + vector{1, 1}, // dilations + 1, // group + vector{3, 3}, // kernel_shape + {}, // pads + vector{1, 1}, // strides + "Relu" // activation + }; + + vector X(25, 1.0f); + vector X_shape = {1, 5, 5, 1}; + vector W = {0.0f, 1.0f, 2.0f, + 3.0f, 4.0f, 5.0f, + 6.0f, 7.0f, 8.0f}; + vector W_shape = {1, 1, 3, 3}; + vector Y_shape = {1, 5, 5, 1}; + auto expected_vals = {24.0f, 33.0f, 33.0f, 33.0f, 20.0f, + 27.0f, 36.0f, 36.0f, 36.0f, 21.0f, + 27.0f, 36.0f, 36.0f, 36.0f, 21.0f, + 27.0f, 36.0f, 36.0f, 36.0f, 21.0f, + 12.0f, 15.0f, 15.0f, 15.0f, 8.0f}; + TestNhwcFusedConvFloatOp(attrs, {X, W}, {X_shape, W_shape}, expected_vals, Y_shape); + TestNhwcFusedConvFloatOp(attrs, {X, W}, {X_shape, W_shape}, expected_vals, Y_shape, true); +} +#endif + #endif } // namespace test diff --git a/onnxruntime/test/framework/kernel_type_str_resolver_utils_test.cc b/onnxruntime/test/framework/kernel_type_str_resolver_utils_test.cc index 86ffef6c49dc9..32720e42988f6 100644 --- a/onnxruntime/test/framework/kernel_type_str_resolver_utils_test.cc +++ b/onnxruntime/test/framework/kernel_type_str_resolver_utils_test.cc @@ -9,18 +9,29 @@ #include "gtest/gtest.h" #include "core/flatbuffers/schema/ort.fbs.h" +#include "core/graph/constants.h" +#include "core/graph/model.h" #include "core/graph/schema_registry.h" +#include "core/optimizer/layout_transformation/layout_transformation_potentially_added_ops.h" +#include "core/session/onnxruntime_session_options_config_keys.h" +#include "test/test_environment.h" #include "test/util/include/asserts.h" +#include "test/util/include/inference_session_wrapper.h" + +#include +#include namespace onnxruntime::test { static Status LoadLayoutTransformationRequiredOpsFromOpSchemas(KernelTypeStrResolver& kernel_type_str_resolver) { - const auto required_op_ids = kernel_type_str_resolver_utils::GetLayoutTransformationRequiredOpIdentifiers(); const auto schema_registry = SchemaRegistryManager{}; - for (const auto& op_id : required_op_ids) { + for (const auto& op_id : kLayoutTransformationPotentiallyAddedOps) { const auto* op_schema = schema_registry.GetSchema(std::string{op_id.op_type}, op_id.since_version, std::string{op_id.domain}); - ORT_RETURN_IF(op_schema == nullptr, "Failed to get op schema."); + ORT_RETURN_IF(op_schema == nullptr, + "Failed to get op schema for domain='", op_id.domain, + "', op_type='", op_id.op_type, + "', since_version=", op_id.since_version, "."); ORT_RETURN_IF_ERROR(kernel_type_str_resolver.RegisterOpSchema(*op_schema)); } return Status::OK(); @@ -49,6 +60,34 @@ TEST(KernelTypeStrResolverUtilsTest, VerifyLayoutTransformationRequiredOpsResolv #endif // !defined(DISABLE_CONTRIB_OPS) } +#if !defined(ORT_MINIMAL_BUILD) && !defined(DISABLE_CONTRIB_OPS) +TEST(KernelTypeStrResolverUtilsTest, ResolveNhwcFusedConvFromLayoutTransformationRequiredOps) { + KernelTypeStrResolver resolver; + ASSERT_STATUS_OK(kernel_type_str_resolver_utils::AddLayoutTransformationRequiredOpsToKernelTypeStrResolver(resolver)); + + Model model("nhwc_fused_conv_layout_transform_resolver_test", false, DefaultLoggingManager().DefaultLogger()); + auto& graph = model.MainGraph(); + + ONNX_NAMESPACE::TypeProto float_tensor; + auto* tensor_type = float_tensor.mutable_tensor_type(); + tensor_type->set_elem_type(ONNX_NAMESPACE::TensorProto_DataType_FLOAT); + tensor_type->mutable_shape()->add_dim()->set_dim_value(1); + + auto& x = graph.GetOrCreateNodeArg("x", &float_tensor); + auto& w = graph.GetOrCreateNodeArg("w", &float_tensor); + auto& y = graph.GetOrCreateNodeArg("y", &float_tensor); + + auto& nhwc_fused_conv = graph.AddNode( + "nhwc_fused_conv", "NhwcFusedConv", "test node", {&x, &w}, {&y}, nullptr, kMSDomain); + nhwc_fused_conv.SetSinceVersion(1); + + gsl::span resolved_args; + ASSERT_STATUS_OK(resolver.ResolveKernelTypeStr(nhwc_fused_conv, "T", resolved_args)); + ASSERT_FALSE(resolved_args.empty()); +} + +#endif // !defined(ORT_MINIMAL_BUILD) && !defined(DISABLE_CONTRIB_OPS) + // run this test manually to output a hard-coded byte array. // update AddLayoutTransformationRequiredOpsToKernelTypeStrResolver in // onnxruntime/core/framework/kernel_type_str_resolver_utils.cc diff --git a/onnxruntime/test/framework/ort_model_only_test.cc b/onnxruntime/test/framework/ort_model_only_test.cc index ec4f8967fd2a3..d711e4c47498d 100644 --- a/onnxruntime/test/framework/ort_model_only_test.cc +++ b/onnxruntime/test/framework/ort_model_only_test.cc @@ -17,6 +17,7 @@ #include "test/util/include/asserts.h" #include "test/util/include/inference_session_wrapper.h" +#include #include "flatbuffers/idl.h" #include "flatbuffers/util.h" @@ -27,6 +28,7 @@ using namespace ONNX_NAMESPACE; namespace onnxruntime { namespace test { + struct OrtModelTestInfo { std::basic_string model_filename; std::string logid; @@ -65,17 +67,20 @@ static void RunOrtModel(const OrtModelTestInfo& test_info) { std::vector model_data; InferenceSessionWrapper session_object{so, GetEnvironment()}; + std::filesystem::path model_path{test_info.model_filename}; + + const auto& model_path_str = model_path.native(); if (test_info.run_use_buffer) { // Load the file into a buffer and use the buffer to create inference session size_t num_bytes = 0; - ASSERT_STATUS_OK(Env::Default().GetFileLength(test_info.model_filename.c_str(), num_bytes)); + ASSERT_STATUS_OK(Env::Default().GetFileLength(model_path_str.c_str(), num_bytes)); model_data.resize(num_bytes); - std::ifstream bytes_stream(test_info.model_filename, std::ifstream::in | std::ifstream::binary); + std::ifstream bytes_stream(model_path, std::ifstream::in | std::ifstream::binary); bytes_stream.read(model_data.data(), num_bytes); bytes_stream.close(); ASSERT_STATUS_OK(session_object.Load(model_data.data(), static_cast(num_bytes))); } else { - ASSERT_STATUS_OK(session_object.Load(test_info.model_filename)); // infer type from filename + ASSERT_STATUS_OK(session_object.Load(model_path_str)); // infer type from filename } ASSERT_STATUS_OK(session_object.Initialize()); @@ -151,7 +156,7 @@ static void CompareGraphAndSessionState(const InferenceSessionWrapper& session_o for (const auto& pair : i1) { auto iter = i2.find(pair.first); - ASSERT_NE(iter, i2.cend()); + ASSERT_NE(iter, i2.cend()) << "Missing initializer " << pair.first; const OrtValue& left = pair.second; const OrtValue& right = iter->second; @@ -218,10 +223,17 @@ static void CompareSessionMetadata(const InferenceSessionWrapper& session_object static void SaveAndCompareModels(const PathString& orig_file, const PathString& ort_file, - TransformerLevel optimization_level = TransformerLevel::Level3) { + TransformerLevel optimization_level = TransformerLevel::Level3, + bool compare_saved_model = true) { + std::filesystem::path orig_path{orig_file}; + std::filesystem::path ort_path{ort_file}; + if (ort_path.has_parent_path()) { + std::filesystem::create_directories(ort_path.parent_path()); + } + SessionOptions so; so.session_logid = "SerializeToOrtFormat"; - so.optimized_model_filepath = ort_file; + so.optimized_model_filepath = ort_path.native(); so.graph_optimization_level = optimization_level; // not strictly necessary - type should be inferred from the filename @@ -229,7 +241,7 @@ static void SaveAndCompareModels(const PathString& orig_file, InferenceSessionWrapper session_object{so, GetEnvironment()}; // create .ort file during Initialize due to values in SessionOptions - ASSERT_STATUS_OK(session_object.Load(orig_file)); + ASSERT_STATUS_OK(session_object.Load(orig_path.native())); ASSERT_STATUS_OK(session_object.Initialize()); SessionOptions so2; @@ -240,9 +252,13 @@ static void SaveAndCompareModels(const PathString& orig_file, // load serialized version InferenceSessionWrapper session_object2{so2, GetEnvironment()}; - ASSERT_STATUS_OK(session_object2.Load(ort_file)); + ASSERT_STATUS_OK(session_object2.Load(ort_path.native())); ASSERT_STATUS_OK(session_object2.Initialize()); + if (!compare_saved_model) { + return; + } + CompareSessionMetadata(session_object, session_object2); CompareGraphAndSessionState(session_object, session_object2); } @@ -368,7 +384,9 @@ void TestOrtModelUpdate(const PathString& onnx_file, // ort_file_v4 is ORT format model using v4 where we used kernel hashes instead of constraints // update v4 model and save as v5. do not run optimizations in order to preserve the model as-is. - SaveAndCompareModels(ort_file_v4, generated_ort_file_v5, TransformerLevel::Default); + // Loading a v4 ORT model updates it as part of deserialization, so the in-memory graph/session state is not + // expected to match a separately reloaded v5 model exactly. Just validate that we can save and reload it. + SaveAndCompareModels(ort_file_v4, generated_ort_file_v5, TransformerLevel::Default, false); // run the original, v4 and v5 models and check the output is the same OrtModelTestInfo test_info; diff --git a/onnxruntime/test/internal_testing_ep/internal_testing_tests.cc b/onnxruntime/test/internal_testing_ep/internal_testing_tests.cc index 74a812062875a..5850afd2f84e8 100644 --- a/onnxruntime/test/internal_testing_ep/internal_testing_tests.cc +++ b/onnxruntime/test/internal_testing_ep/internal_testing_tests.cc @@ -22,6 +22,7 @@ #include "gtest/gtest.h" #include "gmock/gmock.h" +#include using namespace ONNX_NAMESPACE; using namespace onnxruntime::logging; @@ -36,12 +37,36 @@ using namespace onnxruntime::internal_testing_ep; #define ORT_MODEL_FOLDER ORT_TSTR("testdata/") +namespace { +std::filesystem::path ResolveInternalTestPath(const std::filesystem::path& path) { + if (path.is_absolute() || path.empty()) { + return path; + } + + std::filesystem::path candidate = std::filesystem::current_path() / path; + std::error_code ec; + if (std::filesystem::exists(candidate, ec)) { + return candidate; + } + + static const std::filesystem::path kSourceTestRoot = + std::filesystem::path{ORT_TSTR_ON_MACRO(__FILE__)}.parent_path().parent_path(); + return kSourceTestRoot / path; +} + +std::basic_string ResolveInternalTestPathString(const ORTCHAR_T* path) { + return ResolveInternalTestPath(std::filesystem::path{path}).native(); +} +} // namespace + static Status CreateSession(const SessionOptions& so, std::unique_ptr& session, const ORTCHAR_T* model_path = ORT_MODEL_FOLDER "mnist.onnx", // arbitrary test model bool enable_custom_ep = true, const std::unordered_set* override_supported_ops = nullptr) { session = std::make_unique(so, GetEnvironment()); + std::filesystem::path resolved_model_path = ResolveInternalTestPath(std::filesystem::path{model_path}); + // set supported ops to ops that are ideally found consecutively in the model. // we can say the EP potentially handles them all, but can also test removing handling of one or more ops // at runtime to simulate a lower spec device where not all ops can be handled. this allows us to test @@ -55,7 +80,7 @@ static Status CreateSession(const SessionOptions& so, std::unique_ptr(*supported_ops))); } - ORT_RETURN_IF_ERROR(session->Load(model_path)); + ORT_RETURN_IF_ERROR(session->Load(resolved_model_path.c_str())); ORT_RETURN_IF_ERROR(session->Initialize()); return Status::OK(); } @@ -98,7 +123,9 @@ static void ExecuteMnist(InferenceSessionWrapper& session, bool custom_ep_enable #if !defined(ORT_MINIMAL_BUILD) TEST(InternalTestingEP, TestSaveAndLoadOrtModel) { - const ORTCHAR_T* ort_model_path = ORT_MODEL_FOLDER "mnist.internal_testing_ep.test_output.ort"; + const auto ort_model_dir = ResolveInternalTestPath(std::filesystem::path{ORT_MODEL_FOLDER}); + const std::basic_string ort_model_path = + (ort_model_dir / ORT_TSTR("mnist.internal_testing_ep.test_output.ort")).native(); // // First load the onnx format model and save as an ORT model. @@ -121,10 +148,10 @@ TEST(InternalTestingEP, TestSaveAndLoadOrtModel) { so.optimized_model_filepath.clear(); bool enable_custom_ep = false; - ASSERT_STATUS_OK(CreateSession(so, session2, ort_model_path, enable_custom_ep)); + ASSERT_STATUS_OK(CreateSession(so, session2, ort_model_path.c_str(), enable_custom_ep)); const auto& graph1 = session2->GetGraph(); - // model should have all the original nodes and we should be able to execute with the fallback to CPU EP - ASSERT_EQ(graph1.NumberOfNodes(), num_nodes); + // ensure we can execute with the fallback to CPU EP even if additional nodes are introduced during loading + ASSERT_GE(graph1.NumberOfNodes(), num_nodes); ExecuteMnist(*session2, enable_custom_ep); session2 = nullptr; @@ -133,7 +160,7 @@ TEST(InternalTestingEP, TestSaveAndLoadOrtModel) { // for the ORT format model. // enable_custom_ep = true; - ASSERT_STATUS_OK(CreateSession(so, session2, ort_model_path, enable_custom_ep)); + ASSERT_STATUS_OK(CreateSession(so, session2, ort_model_path.c_str(), enable_custom_ep)); const auto& graph2 = session2->GetGraph(); // model should be able to be loaded, and we should compile using custom ep. that will result in one node for the // custom EP (with Conv/Add/Relu/MaxPool), one for a reshape, and one for the fused MatMul+Add. @@ -142,7 +169,7 @@ TEST(InternalTestingEP, TestSaveAndLoadOrtModel) { } TEST(InternalTestingEP, PreventSaveOfModelWithCompiledOps) { - const ORTCHAR_T* ort_model_path = ORT_MODEL_FOLDER "mnist.internal_testing_ep.ort"; + const auto ort_model_path = ResolveInternalTestPathString(ORT_MODEL_FOLDER "mnist.internal_testing_ep.ort"); // make sure we can't save a model with compiled ops. input/output model format doesn't matter SessionOptions so; @@ -154,7 +181,7 @@ TEST(InternalTestingEP, PreventSaveOfModelWithCompiledOps) { ASSERT_STATUS_OK(session->RegisterExecutionProvider( std::make_unique(supported_ops))); - ASSERT_STATUS_OK(session->Load(ort_model_path)); + ASSERT_STATUS_OK(session->Load(ort_model_path.c_str())); ASSERT_STATUS_NOT_OK_AND_HAS_SUBSTR(session->Initialize(), "Unable to serialize model as it contains compiled nodes"); } @@ -163,7 +190,7 @@ TEST(InternalTestingEP, PreventSaveOfModelWithCompiledOps) { // version of the ONNX operator when matching a static kernel, those are required. #if !defined(DISABLE_CONTRIB_OPS) TEST(InternalTestingEP, TestMixOfStaticAndCompiledKernels) { - const ORTCHAR_T* ort_model_path = ORT_MODEL_FOLDER "transform/fusion/conv_relu_opset12.onnx"; + const auto ort_model_path = ResolveInternalTestPathString(ORT_MODEL_FOLDER "transform/fusion/conv_relu_opset12.onnx"); SessionOptions so; InferenceSessionWrapper session(so, GetEnvironment()); @@ -175,7 +202,7 @@ TEST(InternalTestingEP, TestMixOfStaticAndCompiledKernels) { ep->EnableStaticKernels(); ASSERT_STATUS_OK(session.RegisterExecutionProvider(std::move(ep))); - ASSERT_STATUS_OK(session.Load(ort_model_path)); + ASSERT_STATUS_OK(session.Load(ort_model_path.c_str())); ASSERT_STATUS_OK(session.Initialize()); TensorShape input_shape_x{1, 1, 7, 7}; @@ -204,7 +231,8 @@ TEST(InternalTestingEP, TestMixOfStaticAndCompiledKernels) { TEST(InternalTestingEP, TestNhwcConversionOfStaticKernels) { auto run_test = [&](const ORTCHAR_T* model_path) { - SCOPED_TRACE("model path: " + ToUTF8String(model_path)); + auto resolved_model_path = ResolveInternalTestPathString(model_path); + SCOPED_TRACE("model path: " + ToUTF8String(resolved_model_path.c_str())); SessionOptions so; // set this if you want to manually inspect the optimized model @@ -218,7 +246,7 @@ TEST(InternalTestingEP, TestNhwcConversionOfStaticKernels) { ep->EnableStaticKernels(); ASSERT_STATUS_OK(session.RegisterExecutionProvider(std::move(ep))); - ASSERT_STATUS_OK(session.Load(model_path)); + ASSERT_STATUS_OK(session.Load(resolved_model_path.c_str())); ASSERT_STATUS_OK(session.Initialize()); const auto& graph = session.GetGraph(); @@ -249,13 +277,11 @@ TEST(InternalTestingEP, TestNhwcConversionOfStaticKernels) { }; // the internal NHWC domain supports opset 11 and later - const ORTCHAR_T* onnx_model_path = ORT_MODEL_FOLDER "squeezenet/model_opset11.onnx"; - run_test(onnx_model_path); + run_test(ORT_MODEL_FOLDER "squeezenet/model_opset11.onnx"); // Note: Using ORT format model with runtime optimizations so that the Conv nodes are preserved in the graph, // not converted into FusedConv nodes. The InternalTestingExecutionProvider handles Conv nodes. - const ORTCHAR_T* ort_model_path = ORT_MODEL_FOLDER "squeezenet/model_opset11.with_runtime_opt.ort"; - run_test(ort_model_path); + run_test(ORT_MODEL_FOLDER "squeezenet/model_opset11.with_runtime_opt.ort"); } // make sure allocators returned by SessionState::GetAllocator are valid when IExecutionProvider::ReplaceAllocator @@ -283,8 +309,8 @@ TEST(InternalTestingEP, TestReplaceAllocatorDoesntBreakDueToLocalAllocatorStorag ASSERT_STATUS_OK(session.RegisterExecutionProvider(ep)); } - const ORTCHAR_T* ort_model_path = ORT_MODEL_FOLDER "squeezenet/model.onnx"; - ASSERT_STATUS_OK(session.Load(ort_model_path)); + const auto ort_model_path = ResolveInternalTestPathString(ORT_MODEL_FOLDER "squeezenet/model.onnx"); + ASSERT_STATUS_OK(session.Load(ort_model_path.c_str())); ASSERT_STATUS_OK(session.Initialize()); // Need to undo the wrapping that happens in Environment::RegisterAllocator to be able to compare the pointers @@ -301,25 +327,25 @@ TEST(InternalTestingEP, TestReplaceAllocatorDoesntBreakDueToLocalAllocatorStorag // test to validate a minimal build TEST(InternalTestingEP, TestLoadOrtModel) { - const ORTCHAR_T* ort_model_path = ORT_MODEL_FOLDER "mnist.internal_testing_ep.ort"; + const auto ort_model_path = ResolveInternalTestPathString(ORT_MODEL_FOLDER "mnist.internal_testing_ep.ort"); std::unique_ptr session; bool enable_custom_ep = true; - ASSERT_STATUS_OK(CreateSession(SessionOptions{}, session, ort_model_path, enable_custom_ep)); + ASSERT_STATUS_OK(CreateSession(SessionOptions{}, session, ort_model_path.c_str(), enable_custom_ep)); ExecuteMnist(*session, enable_custom_ep); } // test that if the custom EP cannot take all nodes due to device limitations // that we fallback to the CPU implementations and can execute the model TEST(InternalTestingEP, TestLoadOrtModelWithReducedOpCoverage) { - const ORTCHAR_T* ort_model_path = ORT_MODEL_FOLDER "mnist.internal_testing_ep.ort"; + const auto ort_model_path = ResolveInternalTestPathString(ORT_MODEL_FOLDER "mnist.internal_testing_ep.ort"); const std::unordered_set supported_ops{"Conv", "Add", "Relu" /*, "MaxPool"*/}; std::unique_ptr session; bool enable_custom_ep = true; - ASSERT_STATUS_OK(CreateSession(SessionOptions{}, session, ort_model_path, enable_custom_ep, &supported_ops)); + ASSERT_STATUS_OK(CreateSession(SessionOptions{}, session, ort_model_path.c_str(), enable_custom_ep, &supported_ops)); const auto& graph = session->GetGraph(); // Conv+Add gets fused by level 1 optimizer into single node. The 'Conv'/'Add'/'Relu' nodes should be compiled and @@ -454,7 +480,7 @@ TEST(InternalTestingEP, TestOrtModelWithCompileFailure) { // the layout transformation for this EP is already done at this stage and reverting // can result in more failures. // This is to test the model initialization fails if compile fails. - const ORTCHAR_T* ort_model_path = ORT_MODEL_FOLDER "mnist.internal_testing_ep.ort"; + const auto ort_model_path = ResolveInternalTestPathString(ORT_MODEL_FOLDER "mnist.internal_testing_ep.ort"); const std::unordered_set& supported_ops{"Conv", "Gemm"}; const std::unordered_set& compile_failure_ops{"Gemm"}; diff --git a/onnxruntime/test/mlas/bench/bench_sconv.cpp b/onnxruntime/test/mlas/bench/bench_sconv.cpp index 4b6128f7e19da..9df09728ffa17 100644 --- a/onnxruntime/test/mlas/bench/bench_sconv.cpp +++ b/onnxruntime/test/mlas/bench/bench_sconv.cpp @@ -110,6 +110,7 @@ void SCONV_NCHW(benchmark::State& state, const char* /*dummy*/) { static_cast(output_channels_per_group), &activation, &WorkingBufferSize, + false, 0.0f, nullptr); @@ -217,6 +218,7 @@ void SCONV_NCHW_THREADED(benchmark::State& state, const char* /*dummy*/) { static_cast(output_channels_per_group), &activation, &WorkingBufferSize, + false, 0.0f, tp); diff --git a/onnxruntime/test/mlas/bench/bench_transcendental.cpp b/onnxruntime/test/mlas/bench/bench_transcendental.cpp index f7e461c29843a..3d42c0f84e6cc 100644 --- a/onnxruntime/test/mlas/bench/bench_transcendental.cpp +++ b/onnxruntime/test/mlas/bench/bench_transcendental.cpp @@ -19,7 +19,9 @@ constexpr float kSiluMaxValue = 20.0f; constexpr float kGeluMinValue = -10.0f; constexpr float kGeluMaxValue = 10.0f; constexpr float kInvSqrt2 = 0.7071067811865475244f; +#if defined(MLAS_TARGET_AMD64) constexpr int64_t kFusedBytesPerElement = 2; +#endif constexpr int64_t kSiluUnfusedBytesPerElement = 5; constexpr int64_t kGeluUnfusedBytesPerElement = 7; diff --git a/onnxruntime/test/mlas/unittest/test_conv2d.h b/onnxruntime/test/mlas/unittest/test_conv2d.h index 37a844fdb4b02..6ac47c69ae0b8 100644 --- a/onnxruntime/test/mlas/unittest/test_conv2d.h +++ b/onnxruntime/test/mlas/unittest/test_conv2d.h @@ -4,7 +4,6 @@ #pragma once #include -#include #include "test_util.h" @@ -12,10 +11,6 @@ #include "core/mlas/lib/mlasi.h" #endif -#if defined(USE_KLEIDIAI) && defined(MLAS_ENABLE_TEST_HOOKS) -#include "core/mlas/lib/kleidiai/mlasi_kleidiai.h" -#endif - template class MlasConv2DTest : public MlasTestBase { protected: @@ -75,6 +70,7 @@ class MlasConv2DTest : public MlasTestBase { FilterCount, &Activation, &WorkingBufferSize, + false, Beta, threadpool_); @@ -308,168 +304,6 @@ class MlasConv2DTest : public MlasTestBase { MlasConv2DTest() : threadpool_(Threaded ? GetMlasThreadPool() : nullptr) {} -#if defined(USE_KLEIDIAI) && defined(MLAS_ENABLE_TEST_HOOKS) - void TestKleidiAILhsCacheIgnoresInputContent() { - if (!ArmKleidiAI::UseSME) { - return; - } - - struct ConvGeometry { - const char* name; - size_t input_channels; - size_t input_height; - size_t input_width; - size_t filter_count; - size_t kernel_height; - size_t kernel_width; - size_t padding; - size_t dilation_height; - size_t dilation_width; - size_t stride_height; - size_t stride_width; - }; - - constexpr size_t BatchCount = 1; - constexpr size_t GroupCount = 1; - constexpr ConvGeometry geometries[] = { - {"padded_3x3", 16, 11, 11, 32, 3, 3, 1, 1, 1, 1, 1}, - {"strided_3x3", 24, 13, 9, 16, 3, 3, 1, 1, 1, 2, 2}, - {"dilated_3x3", 8, 15, 13, 20, 3, 3, 2, 2, 2, 1, 1}, - {"wide_kernel", 12, 10, 12, 24, 3, 5, 2, 1, 1, 1, 2}, - }; - - const auto fill_input = [](std::vector& input, size_t seed) { - for (size_t i = 0; i < input.size(); ++i) { - input[i] = static_cast((static_cast((i * (seed + 3)) % 31) - 15) * 0.075f + - static_cast(seed) * 0.125f); - } - }; - - for (const auto& geometry : geometries) { - SCOPED_TRACE(geometry.name); - - const int64_t output_height64 = - ((int64_t(geometry.input_height) + 2 * int64_t(geometry.padding)) - - (int64_t(geometry.dilation_height) * (int64_t(geometry.kernel_height) - 1) + 1)) / - int64_t(geometry.stride_height) + - 1; - const int64_t output_width64 = - ((int64_t(geometry.input_width) + 2 * int64_t(geometry.padding)) - - (int64_t(geometry.dilation_width) * (int64_t(geometry.kernel_width) - 1) + 1)) / - int64_t(geometry.stride_width) + - 1; - - ASSERT_GT(output_height64, 0); - ASSERT_GT(output_width64, 0); - - const size_t output_height = static_cast(output_height64); - const size_t output_width = static_cast(output_width64); - const size_t input_elements = - BatchCount * GroupCount * geometry.input_channels * geometry.input_height * geometry.input_width; - const size_t filter_elements = GroupCount * geometry.filter_count * geometry.input_channels * - geometry.kernel_height * geometry.kernel_width; - const size_t bias_elements = GroupCount * geometry.filter_count; - const size_t output_elements = BatchCount * GroupCount * geometry.filter_count * output_height * output_width; - - std::vector input_a(input_elements); - std::vector input_a_copy(input_elements); - std::vector input_b(input_elements); - std::vector filter(filter_elements); - std::vector bias(bias_elements); - std::vector output(output_elements); - std::vector output_reference(output_elements); - - fill_input(input_a, 1); - input_a_copy = input_a; - fill_input(input_b, 2); - - for (size_t i = 0; i < filter_elements; ++i) { - filter[i] = static_cast((static_cast((i * 5) % 23) - 11) * 0.05f); - } - - for (size_t i = 0; i < bias_elements; ++i) { - bias[i] = static_cast((static_cast(i % 7) - 3) * 0.02f); - } - - const auto verify_conv = [&](const std::vector& input, const char* label) { - SCOPED_TRACE(label); - - MlasConv2D(BatchCount, - GroupCount, - geometry.input_channels, - geometry.input_height, - geometry.input_width, - geometry.filter_count, - geometry.kernel_height, - geometry.kernel_width, - geometry.padding, - geometry.padding, - geometry.padding, - geometry.padding, - geometry.dilation_height, - geometry.dilation_width, - geometry.stride_height, - geometry.stride_width, - output_height, - output_width, - input.data(), - filter.data(), - bias.data(), - output.data()); - - ReferenceConv2D(BatchCount, - GroupCount, - geometry.input_channels, - geometry.input_height, - geometry.input_width, - geometry.filter_count, - geometry.kernel_height, - geometry.kernel_width, - geometry.padding, - geometry.padding, - geometry.dilation_height, - geometry.dilation_width, - geometry.stride_height, - geometry.stride_width, - output_height, - output_width, - input.data(), - filter.data(), - bias.data(), - output_reference.data()); - - for (size_t i = 0; i < output_elements; ++i) { - ASSERT_TRUE(CloseEnough(output[i], output_reference[i])) - << "Mismatch at output index " << i - << ": actual=" << output[i] - << ", expected=" << output_reference[i]; - } - }; - - ArmKleidiAI::MlasConvClearLhsCacheForTest(); - EXPECT_EQ(ArmKleidiAI::MlasConvLhsCacheEntryCountForTest(), size_t{0}); - - verify_conv(input_a, "initial_input"); - EXPECT_EQ(ArmKleidiAI::MlasConvLhsCacheEntryCountForTest(), size_t{1}); - - verify_conv(input_a_copy, "same_content_different_buffer"); - EXPECT_EQ(ArmKleidiAI::MlasConvLhsCacheEntryCountForTest(), size_t{1}) - << "same geometry with a different input buffer should reuse the LHS indirection cache"; - - verify_conv(input_b, "different_content_different_buffer"); - EXPECT_EQ(ArmKleidiAI::MlasConvLhsCacheEntryCountForTest(), size_t{1}) - << "same geometry with different input content should reuse the LHS indirection cache"; - - fill_input(input_a, 3); - verify_conv(input_a, "different_content_same_buffer"); - EXPECT_EQ(ArmKleidiAI::MlasConvLhsCacheEntryCountForTest(), size_t{1}) - << "same geometry with mutated input content should not add cache entries"; - } - - ArmKleidiAI::MlasConvClearLhsCacheForTest(); - } -#endif - #if defined(MLAS_TARGET_AMD64) void TestMobileClipAvx512DispatchSelection(size_t GroupCount, size_t InputHeight, @@ -533,6 +367,7 @@ class MlasConv2DTest : public MlasTestBase { FilterCount, &Activation, &WorkingBufferSize, + false, 0.0f, threadpool_); @@ -845,9 +680,6 @@ class MlasConv2DTest : public MlasTestBase { } void ExecuteShort(void) override { -#if defined(USE_KLEIDIAI) && defined(MLAS_ENABLE_TEST_HOOKS) - TestKleidiAILhsCacheIgnoresInputContent(); -#endif #if defined(MLAS_TARGET_AMD64) TestMobileClipAvx512DispatchSelection(64, 64, 64); TestMobileClipAvx512DispatchSelection(128, 32, 32); diff --git a/onnxruntime/test/optimizer/conv_add_act_test.cc b/onnxruntime/test/optimizer/conv_add_act_test.cc index bb409a2bbb82e..03ca950050d64 100644 --- a/onnxruntime/test/optimizer/conv_add_act_test.cc +++ b/onnxruntime/test/optimizer/conv_add_act_test.cc @@ -31,8 +31,9 @@ void TestConvPath(const std::vector& input_shape, const std::vector disabled_optimizers = {"NchwcTransformer"}; + InlinedHashSet disabled_optimizers = {"NchwcTransformer", "NhwcTransformer"}; TransformerTester(build_test_case, check_graph, TransformerLevel::Default, diff --git a/onnxruntime/test/optimizer/fuse_initializers_transformer_test.cc b/onnxruntime/test/optimizer/fuse_initializers_transformer_test.cc index de973679c8f80..b1997701e132f 100644 --- a/onnxruntime/test/optimizer/fuse_initializers_transformer_test.cc +++ b/onnxruntime/test/optimizer/fuse_initializers_transformer_test.cc @@ -363,6 +363,9 @@ TEST(TransformerTest, FuseFp16InitializersWithFp32Node_with_graph_optimizations_ // Create session and check graph before / after initiation InferenceSessionWrapper session{so, GetEnvironment()}; + // Keep this test focused on FuseInitializersTransformer/NCHWC behavior. NhwcTransformer is + // hardware/kernel dependent and can otherwise change the post-init node counts this test asserts on. + ASSERT_STATUS_OK(session.FilterEnabledOptimizers({"NhwcTransformer"})); ASSERT_STATUS_OK(session.Load(model_uri)); test_graph_structure_at_init(session.GetGraph()); ASSERT_STATUS_OK(session.Initialize()); @@ -402,6 +405,9 @@ TEST(TransformerTest, FuseFp16InitializersWithFp32Node_with_graph_optimizations_ // Create session and check graph before / after initiation InferenceSessionWrapper session{so, GetEnvironment()}; + // Keep this test focused on FuseInitializersTransformer/NCHWC behavior. NhwcTransformer is + // hardware/kernel dependent and can otherwise change the post-init node counts this test asserts on. + ASSERT_STATUS_OK(session.FilterEnabledOptimizers({"NhwcTransformer"})); ASSERT_STATUS_OK(session.Load(model_uri)); test_graph_structure_at_init(session.GetGraph()); ASSERT_STATUS_OK(session.Initialize()); @@ -443,6 +449,9 @@ TEST(TransformerTest, FuseFp16InitializersWithFp32Node_with_graph_optimizations_ // Create session and check graph before / after initiation InferenceSessionWrapper session{so, GetEnvironment()}; + // Keep this test focused on FuseInitializersTransformer/NCHWC behavior. NhwcTransformer is + // hardware/kernel dependent and can otherwise change the post-init node counts this test asserts on. + ASSERT_STATUS_OK(session.FilterEnabledOptimizers({"NhwcTransformer"})); ASSERT_STATUS_OK(session.Load(model_uri)); test_graph_structure_at_init(session.GetGraph()); ASSERT_STATUS_OK(session.Initialize()); diff --git a/onnxruntime/test/optimizer/nchwc_optimizer_test.cc b/onnxruntime/test/optimizer/nchwc_optimizer_test.cc index cd210f7bc70ba..a07f173950051 100644 --- a/onnxruntime/test/optimizer/nchwc_optimizer_test.cc +++ b/onnxruntime/test/optimizer/nchwc_optimizer_test.cc @@ -202,6 +202,7 @@ void NchwcOptimizerTester(const std::function& bu session_options.session_logid = "NchwcOptimizerTests"; InferenceSessionWrapper session{session_options, GetEnvironment()}; ASSERT_STATUS_OK(session.Load(model_data.data(), static_cast(model_data.size()))); + ASSERT_STATUS_OK(session.FilterEnabledOptimizers({"NhwcTransformer"})); ASSERT_STATUS_OK(session.Initialize()); RunOptions run_options; @@ -643,6 +644,36 @@ TEST(NchwcOptimizerTests, FusedConvAddFusion) { test_case(true, true, 1); } +TEST(NchwcOptimizerTests, PreExistingFusedConvWithNchwcSumInput) { + auto build_test_case = [&](NchwcTestHelper& helper) { + auto* input_arg = helper.MakeInput({1, 32, 28, 28}); + auto* sum_arg = helper.MakeIntermediate(); + auto* output_arg = helper.MakeOutput(); + + auto& sum_node = helper.AddConvNode(input_arg, sum_arg, {32, 32, 3, 3}); + sum_node.AddAttribute("pads", std::vector{1, 1, 1, 1}); + + auto* weights_arg = helper.MakeInitializer({32, 32, 3, 3}); + auto* bias_arg = helper.MakeInitializer({32}); + auto& fused_conv_node = + helper.AddNode("FusedConv", {input_arg, weights_arg, bias_arg, sum_arg}, {output_arg}, kMSDomain); + fused_conv_node.AddAttribute("activation", "Relu"); + fused_conv_node.AddAttribute("pads", std::vector{1, 1, 1, 1}); + fused_conv_node.AddAttribute("strides", std::vector{1, 1}); + fused_conv_node.AddAttribute("kernel_shape", std::vector{3, 3}); + }; + + auto check_nchwc_graph = [&](InferenceSessionWrapper& session) { + auto op_to_count = CountOpsInGraph(session.GetGraph()); + EXPECT_EQ(op_to_count["com.microsoft.nchwc.Conv"], 2); + EXPECT_EQ(op_to_count["com.microsoft.nchwc.ReorderInput"], 1); + EXPECT_EQ(op_to_count["com.microsoft.nchwc.ReorderOutput"], 1); + EXPECT_EQ(op_to_count["com.microsoft.FusedConv"], 0); + }; + + NchwcOptimizerTester(build_test_case, check_nchwc_graph); +} + TEST(NchwcOptimizerTests, ConvBinary) { auto test_case = [&](const std::string& op_type) { auto build_test_case = [&](NchwcTestHelper& helper) { diff --git a/onnxruntime/test/optimizer/nhwc_transformer_test.cc b/onnxruntime/test/optimizer/nhwc_transformer_test.cc index 21ea7af4e7389..b73929efab8a6 100644 --- a/onnxruntime/test/optimizer/nhwc_transformer_test.cc +++ b/onnxruntime/test/optimizer/nhwc_transformer_test.cc @@ -1,14 +1,23 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +#include #include +#include #include #include "gtest/gtest.h" -#include "test/unittest_util/graph_transform_test_builder.h" +#include "core/framework/kernel_registry.h" #include "core/mlas/inc/mlas.h" +#include "core/providers/common.h" +#include "core/session/onnxruntime_session_options_config_keys.h" +#include "test/unittest_util/graph_transform_test_builder.h" +#if defined(USE_KLEIDIAI) && defined(MLAS_TARGET_ARM64) +#include "core/mlas/lib/mlasi.h" +#endif #include "core/graph/graph.h" #include "test/common/dnnl_op_test_utils.h" +#include "test/util/include/test_environment.h" namespace onnxruntime { namespace test { @@ -33,6 +42,180 @@ NodeArg* NhwcMakeInitializer(ModelTestBuilder& builder, const std::vector::max_value); } +#if defined(USE_KLEIDIAI) && defined(MLAS_TARGET_ARM64) +static bool HasFloatNhwcFusedConvKernel() { + auto* cpu_ep = TestCPUExecutionProvider(); + auto kernel_registry = cpu_ep->GetKernelRegistry(); + if (!kernel_registry) { + return false; + } + + KernelRegistry::TypeConstraintMap type_constraints{ + {"T", DataTypeImpl::GetTensorType()}, + }; + + const KernelCreateInfo* kernel_create_info{}; + const auto status = kernel_registry->TryFindKernel( + kCpuExecutionProvider, + "NhwcFusedConv", + kMSDomain, + 1, + type_constraints, + DefaultLoggingManager().DefaultLogger(), + &kernel_create_info); + + return status.IsOK() && kernel_create_info != nullptr; +} +#endif + +static bool HasFloatNhwcNoTransposeSupport(const std::vector& input_shape, + const std::vector& weight_shape, + std::vector pads = {}, + std::vector strides = {}, + std::vector dilations = {}, + int64_t group = 1, + bool has_sum_input = false, + std::string_view auto_pad = "NOTSET") { +#if defined(USE_KLEIDIAI) && defined(MLAS_TARGET_ARM64) + if (!HasFloatNhwcFusedConvKernel() || !MLAS_CPUIDINFO::GetCPUIDInfo().HasArm_SME()) { + return false; + } + + if (has_sum_input || group != 1 || input_shape.size() != 4 || weight_shape.size() != 4) { + return false; + } + + std::array input_spatial_shape{ + narrow(input_shape[2]), + narrow(input_shape[3]), + }; + std::array kernel_spatial_shape{ + narrow(weight_shape[2]), + narrow(weight_shape[3]), + }; + std::array strides_size_t{1, 1}; + std::array dilations_size_t{1, 1}; + std::array pads_size_t{}; + + if (!strides.empty()) { + if (strides.size() != strides_size_t.size()) { + return false; + } + + for (size_t i = 0; i < strides_size_t.size(); ++i) { + if (strides[i] < 0) { + return false; + } + + strides_size_t[i] = narrow(strides[i]); + } + } + + if (!dilations.empty()) { + if (dilations.size() != dilations_size_t.size()) { + return false; + } + + for (size_t i = 0; i < dilations_size_t.size(); ++i) { + if (dilations[i] < 0) { + return false; + } + + dilations_size_t[i] = narrow(dilations[i]); + } + } + + const AutoPadType auto_pad_type = StringToAutoPadType(std::string(auto_pad)); + if (auto_pad_type == AutoPadType::NOTSET) { + if (pads.empty()) { + pads_size_t.fill(0); + } else { + if (pads.size() != pads_size_t.size()) { + return false; + } + + for (size_t i = 0; i < pads_size_t.size(); ++i) { + if (pads[i] < 0) { + return false; + } + + pads_size_t[i] = narrow(pads[i]); + } + } + } else { + for (size_t i = 0; i < 2; ++i) { + int64_t pad_head = 0; + int64_t pad_tail = 0; + int64_t out_dim = 0; + const auto status = ComputePadAndOutputShape( + input_shape[2 + i], + narrow(strides_size_t[i]), + weight_shape[2 + i], + narrow(dilations_size_t[i]), + auto_pad_type, + pad_head, + pad_tail, + out_dim, + /*force_symmetric_auto_padding*/ false); + if (!status.IsOK() || pad_head < 0 || pad_tail < 0 || out_dim < 0) { + return false; + } + + pads_size_t[i] = narrow(pad_head); + pads_size_t[i + 2] = narrow(pad_tail); + } + } + + return MlasConvSupportsSymmetricChannelsLast2DFloatKernel( + /*Dimensions*/ 2, + narrow(input_shape[0]), + /*GroupCount*/ 1, + input_spatial_shape.data(), + kernel_spatial_shape.data(), + dilations_size_t.data(), + pads_size_t.data(), + strides_size_t.data(), + narrow(weight_shape[0]), + /*Beta*/ 0.0f); +#else + ORT_UNUSED_PARAMETER(input_shape); + ORT_UNUSED_PARAMETER(weight_shape); + ORT_UNUSED_PARAMETER(pads); + ORT_UNUSED_PARAMETER(strides); + ORT_UNUSED_PARAMETER(dilations); + ORT_UNUSED_PARAMETER(group); + ORT_UNUSED_PARAMETER(has_sum_input); + ORT_UNUSED_PARAMETER(auto_pad); + return false; +#endif +} + +#ifdef MLAS_F16VEC_INTRINSICS_SUPPORTED +static bool HasFp16NhwcFusedConvKernel() { + auto* cpu_ep = TestCPUExecutionProvider(); + auto kernel_registry = cpu_ep->GetKernelRegistry(); + if (!kernel_registry) { + return false; + } + + KernelRegistry::TypeConstraintMap type_constraints{ + {"T", DataTypeImpl::GetTensorType()}, + }; + + const KernelCreateInfo* kernel_create_info{}; + const auto status = kernel_registry->TryFindKernel( + kCpuExecutionProvider, + "NhwcFusedConv", + kMSDomain, + 1, + type_constraints, + DefaultLoggingManager().DefaultLogger(), + &kernel_create_info); + + return status.IsOK() && kernel_create_info != nullptr; +} +#endif + #ifndef DISABLE_CONTRIB_OPS TEST(NhwcTransformerTests, Conv) { @@ -224,6 +407,196 @@ TEST(NhwcTransformerTests, ConvGlobalAveragePool) { TransformerLevel::Level3); } +TEST(NhwcTransformerTests, ConvDepthwiseFloat_SkipNhwc) { + auto build_test_case = [&](ModelTestBuilder& builder) { + auto* input_arg = builder.MakeInput({1, 8, 7, 7}, -1.0f, 1.0f); + auto* weight_arg = builder.MakeInitializer({8, 1, 3, 3}, -1.0f, 1.0f); + auto* output_arg = builder.MakeOutput(); + + Node& conv_node = builder.AddConvNode(input_arg, weight_arg, output_arg); + conv_node.AddAttribute("group", static_cast(8)); + }; + + auto check_nhwc_graph = [&](InferenceSessionWrapper& session) { + auto op_to_count = CountOpsInGraph(session.GetGraph()); + const bool expect_nhwc = HasFloatNhwcNoTransposeSupport({1, 8, 7, 7}, {8, 1, 3, 3}, {}, {}, {}, 8); + const int expected_nhwc_fused_conv = expect_nhwc ? 1 : 0; + const int expected_transposes = expect_nhwc ? 2 : 0; + EXPECT_EQ(op_to_count["com.microsoft.NhwcFusedConv"], expected_nhwc_fused_conv); + EXPECT_EQ(op_to_count["Transpose"], expected_transposes); + }; + + TransformerTester(build_test_case, + check_nhwc_graph, + TransformerLevel::Level2, + TransformerLevel::Level3, + /*opset_version*/ 12, + /*per_sample_tolerance*/ 1e-6, + /*relative_per_sample_tolerance*/ 1e-6); +} + +TEST(NhwcTransformerTests, ConvFloat_UsesNhwcOnlyWithKleidi) { + auto build_test_case = [&](ModelTestBuilder& builder) { + auto* input_arg = builder.MakeInput({1, 8, 7, 7}, -1.0f, 1.0f); + auto* weight_arg = builder.MakeInitializer({16, 8, 3, 3}, -1.0f, 1.0f); + auto* output_arg = builder.MakeOutput(); + + Node& conv_node = builder.AddConvNode(input_arg, weight_arg, output_arg); + conv_node.AddAttribute("pads", std::vector{1, 1, 1, 1}); + }; + + auto check_nhwc_graph = [&](InferenceSessionWrapper& session) { + auto op_to_count = CountOpsInGraph(session.GetGraph()); + const bool expect_nhwc = HasFloatNhwcNoTransposeSupport({1, 8, 7, 7}, {16, 8, 3, 3}, {1, 1, 1, 1}); + const int expected_nhwc_fused_conv = expect_nhwc ? 1 : 0; + const int expected_transposes = expect_nhwc ? 2 : 0; + EXPECT_EQ(op_to_count["com.microsoft.NhwcFusedConv"], expected_nhwc_fused_conv); + EXPECT_EQ(op_to_count["Transpose"], expected_transposes); + }; + + TransformerTester(build_test_case, + check_nhwc_graph, + TransformerLevel::Level2, + TransformerLevel::Level3, + /*opset_version*/ 12, + /*per_sample_tolerance*/ 1e-6, + /*relative_per_sample_tolerance*/ 1e-6); +} + +TEST(NhwcTransformerTests, ConvFloat_RespectsKleidiDisableConfig) { + if (!HasFloatNhwcNoTransposeSupport({1, 8, 7, 7}, {16, 8, 3, 3}, {1, 1, 1, 1})) { + GTEST_SKIP() << "Float NHWC KleidiAI path is not available on this configuration."; + } + + auto build_test_case = [&](ModelTestBuilder& builder) { + auto* input_arg = builder.MakeInput({1, 8, 7, 7}, -1.0f, 1.0f); + auto* weight_arg = builder.MakeInitializer({16, 8, 3, 3}, -1.0f, 1.0f); + auto* output_arg = builder.MakeOutput(); + + Node& conv_node = builder.AddConvNode(input_arg, weight_arg, output_arg); + conv_node.AddAttribute("pads", std::vector{1, 1, 1, 1}); + }; + + auto check_nhwc_graph = [&](InferenceSessionWrapper& session) { + auto op_to_count = CountOpsInGraph(session.GetGraph()); + EXPECT_EQ(op_to_count["com.microsoft.NhwcFusedConv"], 0); + EXPECT_EQ(op_to_count["Transpose"], 0); + }; + + auto add_session_options = [](SessionOptions& session_options) { + const auto status = session_options.config_options.AddConfigEntry(kOrtSessionOptionsMlasDisableKleidiAi, "1"); + ASSERT_TRUE(status.IsOK()) << status.ErrorMessage(); + }; + + TransformerTester(build_test_case, + check_nhwc_graph, + TransformerLevel::Level2, + TransformerLevel::Level3, + /*opset_version*/ 12, + /*per_sample_tolerance*/ 1e-6, + /*relative_per_sample_tolerance*/ 1e-6, + nullptr, + add_session_options); +} + +TEST(NhwcTransformerTests, FusedConvFloat_UsesNhwcOnlyWithKleidi) { + auto build_test_case = [&](ModelTestBuilder& builder) { + auto* input_arg = builder.MakeInput({1, 8, 7, 7}, -1.0f, 1.0f); + auto* weight_arg = builder.MakeInitializer({16, 8, 3, 3}, -1.0f, 1.0f); + auto* bias_arg = builder.MakeInitializer({16}, -0.5f, 0.5f); + auto* output_arg = builder.MakeOutput(); + + Node& fused_conv_node = builder.AddNode("FusedConv", {input_arg, weight_arg, bias_arg}, {output_arg}, kMSDomain); + fused_conv_node.AddAttribute("activation", "Relu"); + fused_conv_node.AddAttribute("pads", std::vector{1, 1, 1, 1}); + fused_conv_node.AddAttribute("strides", std::vector{1, 1}); + fused_conv_node.AddAttribute("kernel_shape", std::vector{3, 3}); + }; + + auto check_nhwc_graph = [&](InferenceSessionWrapper& session) { + auto op_to_count = CountOpsInGraph(session.GetGraph()); + const bool expect_nhwc = HasFloatNhwcNoTransposeSupport( + {1, 8, 7, 7}, {16, 8, 3, 3}, {1, 1, 1, 1}, {1, 1}); + const int expected_nhwc_fused_conv = expect_nhwc ? 1 : 0; + const int expected_transposes = expect_nhwc ? 2 : 0; + EXPECT_EQ(op_to_count["com.microsoft.NhwcFusedConv"], expected_nhwc_fused_conv); + EXPECT_EQ(op_to_count["Transpose"], expected_transposes); + }; + + TransformerTester(build_test_case, + check_nhwc_graph, + TransformerLevel::Level2, + TransformerLevel::Level3, + /*opset_version*/ 12, + /*per_sample_tolerance*/ 1e-6, + /*relative_per_sample_tolerance*/ 1e-6); +} + +TEST(NhwcTransformerTests, FusedConvWithSumFloat_SkipNhwc) { + auto build_test_case = [&](ModelTestBuilder& builder) { + auto* input_arg = builder.MakeInput({1, 8, 7, 7}, -1.0f, 1.0f); + auto* weight_arg = builder.MakeInitializer({16, 8, 3, 3}, -1.0f, 1.0f); + auto* bias_arg = builder.MakeInitializer({16}, -0.5f, 0.5f); + auto* sum_arg = builder.MakeInput({1, 16, 5, 5}, -1.0f, 1.0f); + auto* output_arg = builder.MakeOutput(); + + Node& fused_conv_node = builder.AddNode("FusedConv", {input_arg, weight_arg, bias_arg, sum_arg}, {output_arg}, kMSDomain); + fused_conv_node.AddAttribute("activation", "Relu"); + fused_conv_node.AddAttribute("pads", std::vector{0, 0, 0, 0}); + fused_conv_node.AddAttribute("strides", std::vector{1, 1}); + fused_conv_node.AddAttribute("kernel_shape", std::vector{3, 3}); + }; + + auto check_nhwc_graph = [&](InferenceSessionWrapper& session) { + auto op_to_count = CountOpsInGraph(session.GetGraph()); + EXPECT_EQ(op_to_count["com.microsoft.nchwc.Conv"], 0); + const bool expect_nhwc = HasFloatNhwcNoTransposeSupport( + {1, 8, 7, 7}, {16, 8, 3, 3}, {0, 0, 0, 0}, {1, 1}, {}, 1, true); + const int expected_nhwc_fused_conv = expect_nhwc ? 1 : 0; + const int expected_transposes = expect_nhwc ? 3 : 0; + EXPECT_EQ(op_to_count["com.microsoft.NhwcFusedConv"], expected_nhwc_fused_conv); + EXPECT_EQ(op_to_count["Transpose"], expected_transposes); + }; + + TransformerTester(build_test_case, + check_nhwc_graph, + TransformerLevel::Level2, + TransformerLevel::Level3, + /*opset_version*/ 12, + /*per_sample_tolerance*/ 1e-6, + /*relative_per_sample_tolerance*/ 1e-6); +} + +TEST(NhwcTransformerTests, ConvAutoPadFloat_SkipNhwc) { + auto build_test_case = [&](ModelTestBuilder& builder) { + auto* input_arg = builder.MakeInput({1, 8, 6, 6}, -1.0f, 1.0f); + auto* weight_arg = builder.MakeInitializer({16, 8, 3, 3}, -1.0f, 1.0f); + auto* output_arg = builder.MakeOutput(); + + Node& conv_node = builder.AddConvNode(input_arg, weight_arg, output_arg); + conv_node.AddAttribute("auto_pad", "SAME_UPPER"); + conv_node.AddAttribute("strides", std::vector{2, 2}); + }; + + auto check_nhwc_graph = [&](InferenceSessionWrapper& session) { + auto op_to_count = CountOpsInGraph(session.GetGraph()); + const bool expect_nhwc = HasFloatNhwcNoTransposeSupport( + {1, 8, 6, 6}, {16, 8, 3, 3}, {}, {2, 2}, {}, 1, false, "SAME_UPPER"); + const int expected_nhwc_fused_conv = expect_nhwc ? 1 : 0; + const int expected_transposes = expect_nhwc ? 2 : 0; + EXPECT_EQ(op_to_count["com.microsoft.NhwcFusedConv"], expected_nhwc_fused_conv); + EXPECT_EQ(op_to_count["Transpose"], expected_transposes); + }; + + TransformerTester(build_test_case, + check_nhwc_graph, + TransformerLevel::Level2, + TransformerLevel::Level3, + /*opset_version*/ 12, + /*per_sample_tolerance*/ 1e-6, + /*relative_per_sample_tolerance*/ 1e-6); +} + TEST(NhwcTransformerTests, ConvAveragePool) { DNNL_GTEST_SKIP(); @@ -598,6 +971,35 @@ TEST_F(NhwcTransformerTestsFp16, ConvFp16) { test_case({1, 22, 11, 13, 15}, {30, 22, 5, 3, 3}); } +TEST_F(NhwcTransformerTestsFp16, FusedConvWithSumFp16) { + auto build_test_case = [&](ModelTestBuilder& builder) { + auto* input_arg = MakeInputARangeFP16(builder, {1, 8, 7, 7}, MLFloat16(-1.0f), MLFloat16(1.0f)); + auto* weight_arg = MakeInitializerARangeFP16(builder, {16, 8, 3, 3}, MLFloat16(-1.0f), MLFloat16(1.0f)); + auto* bias_arg = MakeInitializerARangeFP16(builder, {16}, MLFloat16(-0.5f), MLFloat16(0.5f)); + auto* sum_arg = MakeInputARangeFP16(builder, {1, 16, 5, 5}, MLFloat16(-1.0f), MLFloat16(1.0f)); + auto* output_arg = builder.MakeOutput(); + + Node& fused_conv_node = builder.AddNode("FusedConv", {input_arg, weight_arg, bias_arg, sum_arg}, {output_arg}, kMSDomain); + fused_conv_node.AddAttribute("activation", "Relu"); + fused_conv_node.AddAttribute("pads", std::vector{0, 0, 0, 0}); + fused_conv_node.AddAttribute("strides", std::vector{1, 1}); + fused_conv_node.AddAttribute("kernel_shape", std::vector{3, 3}); + }; + + auto check_nhwc_graph = [&](InferenceSessionWrapper& session) { + auto op_to_count = CountOpsInGraph(session.GetGraph()); + const int expected_nhwc_fused_conv = HasFp16NhwcFusedConvKernel() ? 1 : 0; + const int expected_transposes = HasFp16NhwcFusedConvKernel() ? 3 : 0; + EXPECT_EQ(op_to_count["com.microsoft.NhwcFusedConv"], expected_nhwc_fused_conv); + EXPECT_EQ(op_to_count["Transpose"], expected_transposes); + }; + + TransformerTester(build_test_case, + check_nhwc_graph, + TransformerLevel::Level2, + TransformerLevel::Level3); +} + TEST_F(NhwcTransformerTestsFp16, ConvMaxPoolFp16) { auto test_case = [&](const std::vector& input_shape, const std::vector& weights_shape) { auto build_test_case = [&](ModelTestBuilder& builder) { diff --git a/onnxruntime/test/optimizer/transpose_optimizer_test.cc b/onnxruntime/test/optimizer/transpose_optimizer_test.cc index b57118bbb4ba3..080c382db5d93 100644 --- a/onnxruntime/test/optimizer/transpose_optimizer_test.cc +++ b/onnxruntime/test/optimizer/transpose_optimizer_test.cc @@ -4604,6 +4604,66 @@ TEST(TransposeOptimizerTests, QnnTransposeReshape) { } } +// Verifies that layout transformation preserves an existing NHWC-native +// NhwcFusedConv as-is instead of retargeting it or inserting Transpose nodes. +TEST(TransposeOptimizerTests, LayoutTransformDoesNotRetargetNhwcFusedConv) { + std::unordered_map domain_to_version{{kOnnxDomain, 13}, {kMSDomain, 1}}; + Model model("LayoutTransformDoesNotRetargetNhwcFusedConv", false, ModelMetaData(), PathString(), + IOnnxRuntimeOpSchemaRegistryList(), domain_to_version, {}, + DefaultLoggingManager().DefaultLogger()); + Graph& graph = model.MainGraph(); + ModelTestBuilder builder(graph); + + auto* input_arg = builder.MakeInput({1, 7, 7, 8}, -1.0f, 1.0f); + auto* weight_arg = builder.MakeInitializer({16, 8, 3, 3}, -1.0f, 1.0f); + auto* bias_arg = builder.MakeInitializer({16}, -0.5f, 0.5f); + auto* output_arg = builder.MakeOutput(); + + auto& nhwc_fused_conv = builder.AddNode("NhwcFusedConv", {input_arg, weight_arg, bias_arg}, {output_arg}, kMSDomain); + nhwc_fused_conv.AddAttribute("activation", "Relu"); + nhwc_fused_conv.AddAttribute("pads", std::vector{1, 1, 1, 1}); + nhwc_fused_conv.AddAttribute("strides", std::vector{1, 1}); + nhwc_fused_conv.AddAttribute("kernel_shape", std::vector{3, 3}); + + builder.SetGraphOutputs(); + ASSERT_STATUS_OK(graph.Resolve()); + + std::string model_data; + model.ToProto().SerializeToString(&model_data); + + SessionOptions so; + using InternalTestingEP = internal_testing_ep::InternalTestingExecutionProvider; + const std::unordered_set empty_set; + auto internal_testing_ep = std::make_unique(empty_set, empty_set, DataLayout::NHWC); + internal_testing_ep->EnableStaticKernels().TakeAllNodes(); + + InferenceSessionWrapper session{so, GetEnvironment()}; + ASSERT_STATUS_OK(session.RegisterExecutionProvider(std::move(internal_testing_ep))); + ASSERT_STATUS_OK(session.Load(model_data.data(), static_cast(model_data.size()))); + ASSERT_STATUS_OK(session.Initialize()); + + const auto& optimized_graph = session.GetGraph(); + const auto op_to_count = CountOpsInGraph(optimized_graph); + const auto get_op_count = [&op_to_count](std::string_view op_type) { + const auto it = op_to_count.find(std::string{op_type}); + return it == op_to_count.end() ? 0 : it->second; + }; + + EXPECT_EQ(get_op_count("com.microsoft.NhwcFusedConv"), 1); + EXPECT_EQ(get_op_count("Transpose"), 0); + + int nhwc_fused_conv_count = 0; + for (const auto& node : optimized_graph.Nodes()) { + if (node.OpType() == "NhwcFusedConv") { + ++nhwc_fused_conv_count; + EXPECT_EQ(node.Domain(), kMSDomain); + EXPECT_EQ(node.GetExecutionProviderType(), internal_testing_ep::kInternalTestingExecutionProvider); + } + } + + EXPECT_EQ(nhwc_fused_conv_count, 1); +} + TEST(TransposeOptimizerTests, QnnTransposeReshapeQDQ) { Status status; auto model_uri = ORT_TSTR("testdata/layout_transform_reshape.qdq.onnx"); From b36d8dbd503ed1b1e56eb8819080b29821e18fe1 Mon Sep 17 00:00:00 2001 From: Guenther Schmuelling Date: Thu, 14 May 2026 15:25:55 -0700 Subject: [PATCH 04/11] enable dynamic max_k_step in FA for nvidia (#28511) apply https://github.com/microsoft/onnxruntime/pull/27780 for nvidia. I see 10-15% performance improvement for prefill on rtx5060ti --- onnxruntime/contrib_ops/webgpu/bert/flash_attention.h | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/onnxruntime/contrib_ops/webgpu/bert/flash_attention.h b/onnxruntime/contrib_ops/webgpu/bert/flash_attention.h index 6458bbea0d2e4..27fa56e333874 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/flash_attention.h +++ b/onnxruntime/contrib_ops/webgpu/bert/flash_attention.h @@ -92,11 +92,12 @@ class FlashAttentionProgram final : public Program { q_BNSH_(q_BNSH), use_seqlen_k_(use_seqlen_k), has_head_sink_(has_head_sink) { - if (is_apple) { - // On Apple GPUs, use an optimized loop-based path with dynamic max_k_step. - // Compute max_k_step from shared memory budget: k_tile + v_tile = 2 * element_size * head_size * max_k_step + if (is_apple || is_nvidia) { + // On Apple and NVIDIA, use an optimized loop-based path with dynamic max_k_step. + // Compute max_k_step from workgroup shared memory budget: k_tile + v_tile = 2 * element_size * head_size * max_k_step const int element_size = is_fp16 ? 2 : 4; - int max_k_from_shm = 16384 / (2 * element_size * qkv_head_size); + constexpr int kMinWorkgroupStorageBudgetBytes = 16384; + int max_k_from_shm = kMinWorkgroupStorageBudgetBytes / (2 * element_size * qkv_head_size); if (max_k_from_shm >= 32) { max_k_step_ = 32; } else { From 18ae6da8a4e61be8a3d15079e32900acbb96f55d Mon Sep 17 00:00:00 2001 From: Edward Chen <18449977+edgchen1@users.noreply.github.com> Date: Thu, 14 May 2026 15:48:28 -0700 Subject: [PATCH 05/11] Validate per-column weight_scale/weight_zero_point shape in CPU QAttention; harden integer arithmetic in QAttention and AttentionBase (#28480) ### Description The CPU `QAttention` kernel did not validate the shape of per-column `weight_scale` and `weight_zero_point` inputs against the expected `3 * hidden_size`. A model could supply a per-column tensor smaller than expected, causing the GEMM dequantization loop to read past the end of the buffer (offsets up to `~3 * hidden_size - head_size`). This PR adds the missing shape validation and, while in the area, hardens integer arithmetic across `QAttention` and `AttentionBase` against malformed shape attributes / dimensions. ### Changes **`onnxruntime/contrib_ops/cpu/quantization/attention_quant.cc`** - Validate per-column `weight_scale` and `weight_zero_point` are 1-D with size `3 * hidden_size`; reject otherwise. - Use `narrow` / `narrow` when converting `int64_t` shape dims, so out-of-range values throw rather than silently truncating. - Use `SafeInt` for multiplications whose operands are not provably bounded by upstream validation (`loop_len`, `input_offset`, `qkv_offset`, the gemm allocation, and `packed_weights_data_size` in `PrePack`). - Refactor the gemm allocation and Q/K/V pointer arithmetic to share a single `SafeInt`-validated `batch_size * sequence_length * hidden_size` value. - Drop a few redundant `static_cast`s in the per-iteration index math. - Remove the `hidden_size_x3 % 3 == 0` and `hidden_size % num_heads_ == 0` checks here; they are now enforced uniformly in `AttentionBase::CheckInputs` with clearer error messages. **`onnxruntime/contrib_ops/cpu/bert/attention_base.h`** - Replace `static_cast` with `narrow` for `num_heads_`, `rotary_embedding_`, the `parameters` struct outputs, and `GetPresent`'s `past_sequence_length`. Without this, any `int64_t` value outside the `int` range (e.g., a `num_heads` attribute of `2^31`, or a `past` sequence length of `2^31`) silently truncates to an unrelated `int` value that is then propagated to downstream kernels and used in arithmetic, enabling division by zero, sign flips, or out-of-bounds indexing. - Drop the `static_cast` from the `past_dims[2]` / `past_dims[4]` shape comparisons so the equality check uses the full `int64_t` value; previously a `past` tensor whose dim's low 32 bits happened to match `num_heads_` (or `k_hidden_size / num_heads_`) would pass validation despite having the wrong physical shape. - In `CheckInputs`, when `require_same_hidden_size_` is true, reject `bias_dims[0]` not a multiple of 3 with a clear error (Q, K, V are packed and share a hidden size). - In `CheckInputs`, when `qkv_hidden_sizes` is not set, also reject `q_hidden_size % num_heads_ != 0` (mirrors the existing check on the `qkv_hidden_sizes` path). **`onnxruntime/test/contrib_ops/quantize_attention_op_test.cc`** - 4 regression tests for the per-column shape validation: - `InvalidWeightScalePerColumnShape` - `InvalidWeightScalePerColumnRank` - `InvalidWeightZeroPointPerColumnShape` - `InvalidWeightZeroPointPerColumnRank` - 3 regression tests for the divisibility / narrowing checks (sharing a `RunQAttentionExpectFailure` helper): - `InvalidBiasDimNotMultipleOfThree` - `InvalidHiddenSizeNotDivisibleByNumHeads` - `InvalidNumHeadsOverflowsInt` (`num_heads = INT_MAX + 1` triggers `gsl::narrowing_error`) ### Testing All `QAttention*` / `AttentionTest*` / `MultiHeadAttention*` tests (97/97) pass locally on CPU Release build. --------- Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com> --- .../contrib_ops/cpu/bert/attention_base.h | 47 ++++-- .../cpu/quantization/attention_quant.cc | 75 ++++++--- .../contrib_ops/quantize_attention_op_test.cc | 159 ++++++++++++++++++ 3 files changed, 239 insertions(+), 42 deletions(-) diff --git a/onnxruntime/contrib_ops/cpu/bert/attention_base.h b/onnxruntime/contrib_ops/cpu/bert/attention_base.h index fad8d9275c555..b7590a2d6a547 100644 --- a/onnxruntime/contrib_ops/cpu/bert/attention_base.h +++ b/onnxruntime/contrib_ops/cpu/bert/attention_base.h @@ -6,6 +6,7 @@ #include #include #include "core/common/common.h" +#include "core/common/narrow.h" #include "core/providers/cpu/mlas_backend_kernel_selector_config_utils.h" #ifndef SHARED_PROVIDER #include "core/framework/op_kernel.h" @@ -57,11 +58,11 @@ class AttentionBase { AttentionBase(const KernelInfoType& info, bool require_same_hidden_size) { int64_t num_heads = 0; ORT_ENFORCE(info.GetAttr("num_heads", &num_heads).IsOK() && num_heads > 0); - num_heads_ = static_cast(num_heads); + num_heads_ = narrow(num_heads); is_unidirectional_ = info.template GetAttrOrDefault("unidirectional", 0) == 1; do_rotary_ = info.template GetAttrOrDefault("do_rotary", 0) == 1; - rotary_embedding_ = static_cast(info.template GetAttrOrDefault("rotary_embedding_dim", 0)); + rotary_embedding_ = narrow(info.template GetAttrOrDefault("rotary_embedding_dim", 0)); mask_filter_value_ = info.template GetAttrOrDefault("mask_filter_value", -10000.0f); scale_ = info.template GetAttrOrDefault("scale", 0.0f); if (!info.template GetAttrs("qkv_hidden_sizes", qkv_hidden_sizes_).IsOK()) { @@ -222,6 +223,14 @@ inline Status AttentionBase::CheckInputs(const TensorShape& input_shape, "Input 'bias' dimension 0 should have same length as dimension 1 of input 'weights'"); } + // Q, K, V are packed along bias_dims[0]. When their hidden sizes are required to be equal, + // bias_dims[0] == 3 * hidden_size must be a multiple of 3. + if (require_same_hidden_size_ && bias_dims[0] % 3 != 0) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input 'bias' dimension 0 (", bias_dims[0], + ") must be a multiple of 3 (Q, K, V are packed and have equal hidden sizes)."); + } + int64_t q_hidden_size = bias_dims[0] / static_cast(3); int64_t k_hidden_size = q_hidden_size; int64_t v_hidden_size = k_hidden_size; @@ -241,6 +250,10 @@ inline Status AttentionBase::CheckInputs(const TensorShape& input_shape, q_hidden_size = qkv_hidden_sizes_[0]; k_hidden_size = qkv_hidden_sizes_[1]; v_hidden_size = qkv_hidden_sizes_[2]; + } else if (q_hidden_size % num_heads_ != 0) { + // Match the error message produced by the qkv_hidden_sizes path above. + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "hidden_size should be divisible by num_heads:", q_hidden_size); } int64_t kv_sequence_length = sequence_length; @@ -282,14 +295,14 @@ inline Status AttentionBase::CheckInputs(const TensorShape& input_shape, "Inputs 'past' dimension 1 shall have same length as dimension 0 of input 0"); } - if (static_cast(past_dims[2]) != num_heads_) { + if (past_dims[2] != num_heads_) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Inputs 'past' dimension 2 shall have length of num_heads", num_heads_); } - if (static_cast(past_dims[4]) != k_hidden_size / num_heads_) { + if (past_dims[4] != k_hidden_size / num_heads_) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Inputs 'past' dimension 2 shall have length of ", k_hidden_size / num_heads_); + "Inputs 'past' dimension 4 shall have length of ", k_hidden_size / num_heads_); } if (!past_present_share_buffer_) { @@ -348,17 +361,17 @@ inline Status AttentionBase::CheckInputs(const TensorShape& input_shape, if (parameters != nullptr) { AttentionParameters* output_parameters = reinterpret_cast(parameters); - output_parameters->batch_size = static_cast(batch_size); - output_parameters->sequence_length = static_cast(sequence_length); - output_parameters->past_sequence_length = static_cast(past_sequence_length); - output_parameters->kv_sequence_length = static_cast(kv_sequence_length); - output_parameters->total_sequence_length = static_cast(total_sequence_length); - output_parameters->max_sequence_length = static_cast(max_sequence_length); - output_parameters->input_hidden_size = static_cast(input_hidden_size); - output_parameters->hidden_size = static_cast(q_hidden_size); - output_parameters->v_hidden_size = static_cast(v_hidden_size); - output_parameters->head_size = static_cast(q_hidden_size) / num_heads_; - output_parameters->v_head_size = static_cast(v_hidden_size) / num_heads_; + output_parameters->batch_size = narrow(batch_size); + output_parameters->sequence_length = narrow(sequence_length); + output_parameters->past_sequence_length = narrow(past_sequence_length); + output_parameters->kv_sequence_length = narrow(kv_sequence_length); + output_parameters->total_sequence_length = narrow(total_sequence_length); + output_parameters->max_sequence_length = narrow(max_sequence_length); + output_parameters->input_hidden_size = narrow(input_hidden_size); + output_parameters->hidden_size = narrow(q_hidden_size); + output_parameters->v_hidden_size = narrow(v_hidden_size); + output_parameters->head_size = narrow(q_hidden_size) / num_heads_; + output_parameters->v_head_size = narrow(v_hidden_size) / num_heads_; output_parameters->num_heads = num_heads_; output_parameters->is_unidirectional = is_unidirectional_; output_parameters->past_present_share_buffer = (past_present_share_buffer_ != 0 && past != nullptr); @@ -398,7 +411,7 @@ inline Tensor* AttentionBase::GetPresent(TOpKernelContext* context, int head_size, int kv_sequence_length, int& past_sequence_length) const { - past_sequence_length = (nullptr != past) ? static_cast(past->Shape().GetDims()[3]) : 0; + past_sequence_length = (nullptr != past) ? narrow(past->Shape().GetDims()[3]) : 0; std::array present_dims{2, batch_size, num_heads_, static_cast(kv_sequence_length) + past_sequence_length, head_size}; diff --git a/onnxruntime/contrib_ops/cpu/quantization/attention_quant.cc b/onnxruntime/contrib_ops/cpu/quantization/attention_quant.cc index 931677582d469..c07acbfcb0e47 100644 --- a/onnxruntime/contrib_ops/cpu/quantization/attention_quant.cc +++ b/onnxruntime/contrib_ops/cpu/quantization/attention_quant.cc @@ -7,6 +7,7 @@ #include "core/util/math.h" #include "core/util/qmath.h" #include "core/util/math_cpuonly.h" +#include "core/common/narrow.h" #include "core/common/safeint.h" #include "core/platform/threadpool.h" #include "core/mlas/inc/mlas.h" @@ -71,8 +72,8 @@ Status QAttention::PrePack(const Tensor& weights, int input_idx, AllocatorPtr return Status::OK(); } - const size_t input_hidden_size = static_cast(weights_dims[0]); - const size_t hidden_size_x3 = static_cast(weights_dims[1]); + const size_t input_hidden_size = narrow(weights_dims[0]); + const size_t hidden_size_x3 = narrow(weights_dims[1]); const size_t hidden_size = hidden_size_x3 / 3; const size_t head_size = hidden_size / num_heads_; @@ -89,8 +90,8 @@ Status QAttention::PrePack(const Tensor& weights, int input_idx, AllocatorPtr return Status::OK(); } - const size_t loop_len = 3 * static_cast(num_heads_); - size_t packed_weights_data_size = packed_weights_size_ * loop_len; + const size_t loop_len = SafeInt(3) * num_heads_; + size_t packed_weights_data_size = SafeInt(packed_weights_size_) * loop_len; packed_weights_ = IAllocator::MakeUniquePtr(alloc, packed_weights_data_size, true); std::byte* packed_weights_data = static_cast(packed_weights_.get()); @@ -171,12 +172,6 @@ Status QAttention::Compute(OpKernelContext* context) const { T input_scale = *(input_scale_tensor->Data()); bool is_weight_scale_per_column = !IsScalarOr1ElementVector(weight_scale_tensor); - const T* weight_scale_data = weight_scale_tensor->Data(); - - std::vector dequant_scales(weight_scale_data, weight_scale_data + weight_scale_tensor->Shape().Size()); - std::for_each(dequant_scales.begin(), dequant_scales.end(), [&input_scale](float& dequant_scale) { - return dequant_scale *= input_scale; - }); uint8_t input_zero_point = 0; if (i_zp_tensor != nullptr) { @@ -194,14 +189,42 @@ Status QAttention::Compute(OpKernelContext* context) const { } const auto& shape = input->Shape(); - const int batch_size = static_cast(shape[0]); - const int sequence_length = static_cast(shape[1]); - const int input_hidden_size = static_cast(shape[2]); - - const auto hidden_size_x3 = weights_shape.GetDims()[1]; - const int hidden_size = static_cast(hidden_size_x3) / 3; + const int batch_size = narrow(shape[0]); + const int sequence_length = narrow(shape[1]); + const int input_hidden_size = narrow(shape[2]); + + const int hidden_size_x3 = narrow(weights_shape.GetDims()[1]); + // AttentionBase::CheckInputs verifies that weights_dims[1] (== bias_dims[0]) is a multiple of 3 + // and that hidden_size = bias_dims[0] / 3 is divisible by num_heads_. + const int hidden_size = hidden_size_x3 / 3; const int head_size = hidden_size / num_heads_; + // Validate per-column 'weight_scale' / 'weight_zero_point' shapes against the expected + // 3 * hidden_size. Without this check, a malicious or malformed model can supply a + // smaller per-column tensor and cause an out-of-bounds read in the GEMM loop below + // (which indexes scales/zero-points using offsets up to ~3 * hidden_size - head_size). + if (is_weight_scale_per_column) { + ORT_RETURN_IF_NOT(weight_scale_tensor->Shape().NumDimensions() == 1 && + weight_scale_tensor->Shape().Size() == hidden_size_x3, + "Input 'weight_scale' must be a scalar or a 1D tensor of size 3 * hidden_size (= ", + hidden_size_x3, "), got shape ", weight_scale_tensor->Shape().ToString()); + } + + if (is_weight_zp_per_column) { + ORT_RETURN_IF_NOT(w_zp_tensor->Shape().NumDimensions() == 1 && + w_zp_tensor->Shape().Size() == hidden_size_x3, + "Input 'weight_zero_point' must be a scalar or a 1D tensor of size 3 * hidden_size (= ", + hidden_size_x3, "), got shape ", w_zp_tensor->Shape().ToString()); + } + + // Build the dequantization scales after shape validation so that malformed + // inputs are rejected before any allocation/copy work. + const T* weight_scale_data = weight_scale_tensor->Data(); + std::vector dequant_scales(weight_scale_data, weight_scale_data + weight_scale_tensor->Shape().Size()); + std::for_each(dequant_scales.begin(), dequant_scales.end(), [&input_scale](T& dequant_scale) { + return dequant_scale *= input_scale; + }); + std::vector output_shape(3); output_shape[0] = shape[0]; output_shape[1] = shape[1]; @@ -216,16 +239,17 @@ Status QAttention::Compute(OpKernelContext* context) const { auto* tp = context->GetOperatorThreadPool(); // STEP.1: gemm_data(BS, 3NH) = Scale(input(BS, D) x weights(D, 3NH)) + bias(3NH) // D is hidden dimension of input, where input_hidden_size (D) could be larger than hidden_size (NH) when model is pruned. - auto gemm_data = allocator->Alloc(SafeInt(batch_size) * sequence_length * 3 * hidden_size * element_size); + const auto batch_size_x_sequence_length_x_hidden_size = SafeInt(batch_size) * sequence_length * hidden_size; + void* gemm_data = allocator->Alloc(batch_size_x_sequence_length_x_hidden_size * 3 * element_size); BufferUniquePtr gemm_buffer(gemm_data, BufferDeleter(std::move(allocator))); auto Q = reinterpret_cast(gemm_data); - auto K = Q + static_cast(batch_size) * sequence_length * hidden_size; - auto V = K + static_cast(batch_size) * sequence_length * hidden_size; + auto K = Q + static_cast(batch_size_x_sequence_length_x_hidden_size); + auto V = K + static_cast(batch_size_x_sequence_length_x_hidden_size); T* QKV[3] = {Q, K, V}; { - const int loop_len = 3 * batch_size * num_heads_; + const int loop_len = SafeInt(3) * batch_size * num_heads_; const auto* input_data = input->Data(); const auto* bias_data = bias->Data(); @@ -243,16 +267,17 @@ Status QAttention::Compute(OpKernelContext* context) const { scale_bias_procs.reserve(loop_len); for (int i = 0; i < loop_len; i++) { - const int batch_index = static_cast((i / 3) / num_heads_); - const int head_index = static_cast((i / 3) % num_heads_); - const int qkv_index = static_cast(i % 3); + const int batch_index = (i / 3) / num_heads_; + const int head_index = (i / 3) % num_heads_; + const int qkv_index = i % 3; - int input_offset = batch_index * sequence_length * input_hidden_size; + int input_offset = SafeInt(batch_index) * sequence_length * input_hidden_size; int weights_offset = qkv_index * hidden_size + head_index * head_size; int weights_scale_offset = is_weight_scale_per_column ? weights_offset : 0; int weights_zp_offset = is_weight_zp_per_column ? weights_offset : 0; float* qkv_dest = QKV[qkv_index]; - int qkv_offset = (batch_index * num_heads_ + head_index) * (sequence_length * head_size); + int qkv_offset = (SafeInt(batch_index) * num_heads_ + head_index) * + (SafeInt(sequence_length) * head_size); // original transposed iteration // A: input (BxSxD) (B.)S x D S x D diff --git a/onnxruntime/test/contrib_ops/quantize_attention_op_test.cc b/onnxruntime/test/contrib_ops/quantize_attention_op_test.cc index 907cfd7fa0833..0e1b4b6bde720 100644 --- a/onnxruntime/test/contrib_ops/quantize_attention_op_test.cc +++ b/onnxruntime/test/contrib_ops/quantize_attention_op_test.cc @@ -3,6 +3,7 @@ #include #include +#include #include #include "gtest/gtest.h" @@ -1162,5 +1163,163 @@ TEST(QAttentionTest, SharedPrepackedWeights) { } #endif +namespace { + +// Build a small valid QAttention OpTester and let the caller mutate the +// per-column weight_scale / weight_zero_point shapes to invalid sizes to +// exercise the input-validation path that protects against OOB reads in +// QAttention::Compute. +void RunQAttentionInvalidPerColumnShapes(const std::vector& weight_scale_dims, + const std::vector& weight_scale_data, + const std::vector& weight_zp_dims, + const std::vector& weight_zp_data, + const std::string& expected_error_substring) { + constexpr int batch_size = 1; + constexpr int sequence_length = 2; + constexpr int hidden_size = 4; + constexpr int number_of_heads = 2; + + std::vector input_dims = {batch_size, sequence_length, hidden_size}; + std::vector weights_dims = {hidden_size, 3 * hidden_size}; + std::vector bias_dims = {3 * hidden_size}; + + std::vector input_data(static_cast(batch_size * sequence_length * hidden_size), 128); + std::vector weight_data(static_cast(hidden_size * 3 * hidden_size), 128); + std::vector bias_data(static_cast(3 * hidden_size), 0.0f); + + OpTester tester("QAttention", 1, onnxruntime::kMSDomain); + tester.AddAttribute("num_heads", static_cast(number_of_heads)); + tester.AddInput("input", input_dims, input_data); + tester.AddInput("weight", weights_dims, weight_data); + tester.AddInput("bias", bias_dims, bias_data); + tester.AddInput("input_scale", {1}, {0.1f}); + tester.AddInput("weight_scale", weight_scale_dims, weight_scale_data); + tester.AddOptionalInputEdge(); // mask_index + tester.AddInput("input_zero_point", {1}, {128}); + tester.AddInput("weight_zero_point", weight_zp_dims, weight_zp_data); + + // Output shape is required by OpTester even though we expect failure before + // the kernel writes anything meaningful. + std::vector output_dims = {batch_size, sequence_length, hidden_size}; + std::vector dummy_output(static_cast(batch_size * sequence_length * hidden_size), 0.0f); + tester.AddOutput("output", output_dims, dummy_output); + + // CPU EP only — CUDA QAttention does not support per-column scale/zp. + std::vector> execution_providers; + execution_providers.push_back(DefaultCpuExecutionProvider()); + tester.Run(OpTester::ExpectResult::kExpectFailure, expected_error_substring, + {}, nullptr, &execution_providers); +} + +} // namespace + +// Regression tests: a malformed model that supplies a "per-column" +// weight_scale or weight_zero_point of the wrong size used to trigger +// an out-of-bounds read in QAttention::Compute. The kernel must +// reject such inputs with a descriptive error. +TEST(QAttentionTest, InvalidWeightScalePerColumnShape) { + // hidden_size=4, expected per-column size = 3 * hidden_size = 12. + // Supply 2 elements (smaller than expected) to trigger the validation. + RunQAttentionInvalidPerColumnShapes( + /*weight_scale_dims=*/{2}, /*weight_scale_data=*/{0.1f, 0.1f}, + /*weight_zp_dims=*/{1}, /*weight_zp_data=*/{128}, + /*expected_error_substring=*/"'weight_scale' must be a scalar or a 1D tensor of size 3 * hidden_size"); +} + +TEST(QAttentionTest, InvalidWeightScalePerColumnRank) { + // 2-D weight_scale with size 12 should still be rejected (rank must be 1). + RunQAttentionInvalidPerColumnShapes( + /*weight_scale_dims=*/{3, 4}, + /*weight_scale_data=*/std::vector(12, 0.1f), + /*weight_zp_dims=*/{1}, /*weight_zp_data=*/{128}, + /*expected_error_substring=*/"'weight_scale' must be a scalar or a 1D tensor of size 3 * hidden_size"); +} + +TEST(QAttentionTest, InvalidWeightZeroPointPerColumnShape) { + // hidden_size=4, expected per-column size = 12. Supply 2 elements. + RunQAttentionInvalidPerColumnShapes( + /*weight_scale_dims=*/{1}, /*weight_scale_data=*/{0.1f}, + /*weight_zp_dims=*/{2}, /*weight_zp_data=*/{128, 128}, + /*expected_error_substring=*/"'weight_zero_point' must be a scalar or a 1D tensor of size 3 * hidden_size"); +} + +TEST(QAttentionTest, InvalidWeightZeroPointPerColumnRank) { + // 2-D weight_zero_point with size 12 should still be rejected (rank must be 1). + RunQAttentionInvalidPerColumnShapes( + /*weight_scale_dims=*/{1}, /*weight_scale_data=*/{0.1f}, + /*weight_zp_dims=*/{3, 4}, + /*weight_zp_data=*/std::vector(12, 128), + /*expected_error_substring=*/"'weight_zero_point' must be a scalar or a 1D tensor of size 3 * hidden_size"); +} + +namespace { + +// Builds a QAttention OpTester with hidden_size = hidden_size_x3 / 3 and the given num_heads +// attribute, and asserts that Run() fails with an error matching expected_error_substring. The +// inputs are otherwise valid (per-tensor scales / zero points, no past, no mask). +void RunQAttentionExpectFailure(int64_t hidden_size_x3, + int64_t num_heads_attr, + const std::string& expected_error_substring) { + constexpr int batch_size = 1; + constexpr int sequence_length = 2; + // input_hidden_size is the model's input embedding dimension; weights project it to 3 * hidden_size. + // It is independent of hidden_size, so use a distinct value here to make that clear. + constexpr int64_t input_hidden_size = 8; + const int64_t hidden_size = hidden_size_x3 / 3; + + std::vector input_dims = {batch_size, sequence_length, input_hidden_size}; + std::vector weights_dims = {input_hidden_size, hidden_size_x3}; + std::vector bias_dims = {hidden_size_x3}; + + std::vector input_data(static_cast(batch_size * sequence_length * input_hidden_size), 128); + std::vector weight_data(static_cast(input_hidden_size * hidden_size_x3), 128); + std::vector bias_data(static_cast(hidden_size_x3), 0.0f); + + OpTester tester("QAttention", 1, onnxruntime::kMSDomain); + tester.AddAttribute("num_heads", num_heads_attr); + tester.AddInput("input", input_dims, input_data); + tester.AddInput("weight", weights_dims, weight_data); + tester.AddInput("bias", bias_dims, bias_data); + tester.AddInput("input_scale", {1}, {0.1f}); + tester.AddInput("weight_scale", {1}, {0.1f}); + tester.AddOptionalInputEdge(); // mask_index + tester.AddInput("input_zero_point", {1}, {128}); + tester.AddInput("weight_zero_point", {1}, {128}); + + std::vector output_dims = {batch_size, sequence_length, hidden_size}; + std::vector dummy_output(static_cast(batch_size * sequence_length * hidden_size), 0.0f); + tester.AddOutput("output", output_dims, dummy_output); + + std::vector> execution_providers; + execution_providers.push_back(DefaultCpuExecutionProvider()); + tester.Run(OpTester::ExpectResult::kExpectFailure, expected_error_substring, + {}, nullptr, &execution_providers); +} + +} // namespace + +// Regression test: QAttention requires bias_dims[0] (== weights_dims[1]) to be a multiple of 3 +// since Q, K, V are packed along that dimension with equal hidden sizes. +TEST(QAttentionTest, InvalidBiasDimNotMultipleOfThree) { + RunQAttentionExpectFailure(/*hidden_size_x3=*/13, /*num_heads_attr=*/1, "must be a multiple of 3"); +} + +// Regression test: hidden_size must be divisible by num_heads. Otherwise head_size silently +// truncates and the per-head GEMM loop leaves part of the hidden dimension uncovered. +TEST(QAttentionTest, InvalidHiddenSizeNotDivisibleByNumHeads) { + // hidden_size_x3 = 12 -> hidden_size = 4; 4 % 3 != 0. + RunQAttentionExpectFailure(/*hidden_size_x3=*/12, /*num_heads_attr=*/3, + "hidden_size should be divisible by num_heads"); +} + +// Regression test: num_heads attribute exceeding INT_MAX must be rejected by the narrow +// conversion in AttentionBase's constructor rather than silently truncating. gsl::narrowing_error +// is thrown during session init; its what() returns the literal "narrowing_error". +TEST(QAttentionTest, InvalidNumHeadsOverflowsInt) { + RunQAttentionExpectFailure(/*hidden_size_x3=*/12, + /*num_heads_attr=*/static_cast(std::numeric_limits::max()) + 1, + "narrowing_error"); +} + } // namespace test } // namespace onnxruntime From c0b3212f31f2ad9c95259307ff6c5207f5c66a41 Mon Sep 17 00:00:00 2001 From: Copilot <198982749+Copilot@users.noreply.github.com> Date: Thu, 14 May 2026 16:51:31 -0700 Subject: [PATCH 06/11] Fill LSTM CUDA operator opset gap: extend coverage from opset 14 to opset 22 (#27737) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ### Description Extends LSTM CUDA kernel registration from opset 14 to opset 22. - **`lstm.cc`**: Cap existing opset 14 kernel to versioned 14–21, add new non-versioned kernel at opset 22 - **`cuda_execution_provider.cc`**: Update forward declarations and `BuildKernelCreateInfo` entries accordingly (versioned 14–21 + non-versioned 22) for all three types (`float`, `double`, `MLFloat16`) - **`deep_cpu_lstm_op_test.cc`**: Add `ONNXRuntime_TestLSTMForward_OpSet22_CUDA` test targeting the new registration - **`docs/OperatorKernels.md`**: Update CUDA LSTM entry from `14+` to `[14, 21]` and `22+` No spec-level behavior changes between opsets 14 and 22 for LSTM — this is purely a registration gap fill so the CUDA EP correctly claims nodes exported at newer opset versions. ### Motivation and Context LSTM CUDA kernel was registered only up to opset 14 while the ONNX spec defines LSTM through opset 22. Models exported at opset ≥15 would fall back to CPU. Follows the same pattern established by other opset gap PRs (ConvTranspose, MaxPool, Pad, etc.) referenced in #27729. --- 📍 Connect Copilot coding agent with [Jira](https://gh.io/cca-jira-docs), [Azure Boards](https://gh.io/cca-azure-boards-docs) or [Linear](https://gh.io/cca-linear-docs) to delegate work to Copilot in one click without leaving your project management tool. --------- Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com> Co-authored-by: tianleiwu <30328909+tianleiwu@users.noreply.github.com> Co-authored-by: Tianlei Wu --- docs/OperatorKernels.md | 3 +- .../providers/cuda/cuda_execution_provider.cc | 18 +++-- .../core/providers/cuda/rnn/cudnn_rnn_base.cc | 80 ++++++++++++------- .../core/providers/cuda/rnn/cudnn_rnn_base.h | 11 ++- onnxruntime/core/providers/cuda/rnn/lstm.cc | 20 ++++- onnxruntime/core/providers/cuda/rnn/lstm.h | 7 ++ .../cpu/rnn/deep_cpu_lstm_op_test.cc | 64 +++++++++++++++ 7 files changed, 158 insertions(+), 45 deletions(-) diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md index 5cf92c6c53e4b..d163301189875 100644 --- a/docs/OperatorKernels.md +++ b/docs/OperatorKernels.md @@ -810,7 +810,8 @@ The **OpSet Version** column uses the following notation: |||[9, 12]|**T1** = tensor(double), tensor(float), tensor(float16)
**T2** = tensor(bool)| |LRN|*in* X:**T**
*out* Y:**T**|13+|**T** = tensor(double), tensor(float), tensor(float16)| |||[1, 12]|**T** = tensor(double), tensor(float), tensor(float16)| -|LSTM|*in* X:**T**
*in* W:**T**
*in* R:**T**
*in* B:**T**
*in* sequence_lens:**T1**
*in* initial_h:**T**
*in* initial_c:**T**
*in* P:**T**
*out* Y:**T**
*out* Y_h:**T**
*out* Y_c:**T**|14+|**T** = tensor(double), tensor(float), tensor(float16)
**T1** = tensor(int32)| +|LSTM|*in* X:**T**
*in* W:**T**
*in* R:**T**
*in* B:**T**
*in* sequence_lens:**T1**
*in* initial_h:**T**
*in* initial_c:**T**
*in* P:**T**
*out* Y:**T**
*out* Y_h:**T**
*out* Y_c:**T**|22+|**T** = tensor(double), tensor(float), tensor(float16)
**T1** = tensor(int32)| +|||[14, 21]|**T** = tensor(double), tensor(float), tensor(float16)
**T1** = tensor(int32)| |||[7, 13]|**T** = tensor(double), tensor(float), tensor(float16)
**T1** = tensor(int32)| |LayerNormalization|*in* X:**T**
*in* Scale:**T**
*in* B:**T**
*out* Y:**T**
*out* Mean:**U**
*out* InvStdDev:**U**

or

*in* X:**T**
*in* Scale:**V**
*in* B:**V**
*out* Y:**V**
*out* Mean:**U**
*out* InvStdDev:**U**|17+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)
**U** = tensor(float)| |||[1, 16]|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)
**U** = tensor(double), tensor(float)
**V** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)| diff --git a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc index 9241e35616ba2..d9b5760848678 100755 --- a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc +++ b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc @@ -1409,9 +1409,9 @@ class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kO class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, 21, double, GRU); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, 21, MLFloat16, GRU); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, 18, Identity); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, float, LSTM); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, double, LSTM); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, MLFloat16, LSTM); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, 21, float, LSTM); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, 21, double, LSTM); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, 21, MLFloat16, LSTM); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, 18, Reshape); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, 21, float, RNN); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, 21, double, RNN); @@ -1702,6 +1702,9 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 22, double, HardSwish); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 22, MLFloat16, HardSwish); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 22, BFloat16, HardSwish); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 22, float, LSTM); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 22, double, LSTM); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 22, MLFloat16, LSTM); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 22, RandomNormal); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 22, RandomNormalLike); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 22, RandomUniform); @@ -2681,9 +2684,9 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -2975,6 +2978,9 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, diff --git a/onnxruntime/core/providers/cuda/rnn/cudnn_rnn_base.cc b/onnxruntime/core/providers/cuda/rnn/cudnn_rnn_base.cc index e7c8f52950141..62ec59c0ef798 100644 --- a/onnxruntime/core/providers/cuda/rnn/cudnn_rnn_base.cc +++ b/onnxruntime/core/providers/cuda/rnn/cudnn_rnn_base.cc @@ -4,6 +4,7 @@ #include "core/providers/shared_library/provider_api.h" #include "cudnn_rnn_base.h" #include "rnn_impl.h" +#include "core/common/safeint.h" namespace onnxruntime { namespace cuda { @@ -16,7 +17,7 @@ Status CudnnRnnBase::SetWeightBias(const cudnnHandle_t handle, const void* reorganized_w_data, const int lin_layer_id, const T* pos, - int& offset, + size_t& offset, bool is_matrix, cudaStream_t cuda_stream) const { int numDims; @@ -37,12 +38,12 @@ Status CudnnRnnBase::SetWeightBias(const cudnnHandle_t handle, is_matrix ? tensor_desc_matrix : tensor_desc_bias, 3, &dt, &numDims, matDims.data(), strideA.data())); mem_offset = is_matrix ? mem_offset_matrix : mem_offset_bias; - int count = matDims[0] * matDims[1] * matDims[2]; + size_t count = SafeInt(matDims[0]) * matDims[1] * matDims[2]; - if (strideA[0] != count) { + if (static_cast(strideA[0]) != count) { return ORT_MAKE_STATUS(ONNXRUNTIME, StatusCode::INVALID_ARGUMENT, "Stride is not packed"); } - CUDA_CALL_THROW(cudaMemcpyAsync(mem_offset, pos + offset, count * sizeof(T), cudaMemcpyDeviceToDevice, cuda_stream)); + CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(mem_offset, pos + offset, count * sizeof(T), cudaMemcpyDeviceToDevice, cuda_stream)); offset += count; @@ -57,9 +58,9 @@ Status CudnnRnnBase::SetCudnnRnnWeightBias(const cudnnHandle_t cudnn_handle, const T* R_data, const T* B_data, cudaStream_t cuda_stream) const { - int w_offset = 0; - int r_offset = 0; - int bias_offset = 0; + size_t w_offset = 0; + size_t r_offset = 0; + size_t bias_offset = 0; for (int layer = 0; layer < RNN_NUM_LAYERS * num_directions_; ++layer) { for (size_t idx = 0; idx < W_lin_layer_id_.size(); ++idx) { ORT_RETURN_IF_ERROR(SetWeightBias( @@ -92,6 +93,12 @@ Status CudnnRnnBase::ReorganizeWeights(const Tensor* W, const Tensor* R, cons void* alloc_stream, cudaStream_t cuda_stream, cudnnHandle_t cudnn_handle) const { typedef typename ToCudaType::MappedType CudaT; + ORT_RETURN_IF(W->Shape().NumDimensions() != 3, + "Weight W must be 3-D [num_directions, hidden_size, input_size], got rank ", + W->Shape().NumDimensions()); + ORT_RETURN_IF(R->Shape().NumDimensions() != 3, + "Recurrence R must be 3-D [num_directions, hidden_size, hidden_size], got rank ", + R->Shape().NumDimensions()); int64_t input_size = W->Shape()[2]; // RNN W[num_directions_, hidden_size_, input_size] // RNN R[num_directions_, hidden_size_, hidden_size_] @@ -103,12 +110,12 @@ Status CudnnRnnBase::ReorganizeWeights(const Tensor* W, const Tensor* R, cons // LSTM R[num_directions_, 4*hidden_size_, hidden_size_] // LSTM B[num_directions_, 8*hidden_size_] size_t number = W_lin_layer_id_.size(); - int64_t w_size = num_directions_ * (number * hidden_size_ * (input_size + hidden_size_ + 2)); + int64_t w_size = SafeInt(num_directions_) * number * hidden_size_ * (input_size + hidden_size_ + 2); TensorShapeVector dims_w({w_size, 1, 1}); ORT_RETURN_IF_ERROR(target_w_desc.Set(dims_w, CudnnTensor::GetDataType())); // Prepare the weight data - reorganized_w_data_size_in_bytes = w_size * sizeof(T); + reorganized_w_data_size_in_bytes = SafeInt(w_size) * sizeof(T); reorganized_w_data = GetScratchBuffer(reorganized_w_data_size_in_bytes, alloc_stream); // In many cases, this allocation is bigger than needed, leaving part of @@ -142,6 +149,8 @@ Status CudnnRnnBase::CacheCudnnRnnWeights(const OpKernelInfo& info) { bool has_bias = B != nullptr; if (get_W && get_R) { + ORT_RETURN_IF(W->Shape().NumDimensions() != 3, + "Constant W must be 3-D, got rank ", W->Shape().NumDimensions()); CudnnRNN tmp_rnn_desc; auto proj_size = hidden_size_; ORT_RETURN_IF_ERROR(tmp_rnn_desc.Set(W->Shape()[2], // input_size @@ -163,7 +172,7 @@ Status CudnnRnnBase::CacheCudnnRnnWeights(const OpKernelInfo& info) { w_data_cache_size_in_bytes_, w_data_cache_, w_desc_cache_, tmp_rnn_desc, nullptr, nullptr, DefaultCudnnHandle())); } - cudaStreamSynchronize(nullptr); + CUDA_RETURN_IF_ERROR(cudaStreamSynchronize(nullptr)); weight_cached_ = true; } @@ -177,6 +186,9 @@ Status CudnnRnnBase::ComputeInternal(OpKernelContext* ctx) const { // inputs const Tensor* X = ctx->Input(RNN_Input_Index::X); // inputs. [seq_length, batch_size, input_size] ORT_ENFORCE(nullptr != X); + ORT_RETURN_IF(X->Shape().NumDimensions() != 3, + "Input X must be 3-D [seq_length, batch_size, input_size], got rank ", + X->Shape().NumDimensions()); // optional inputs // [batch_size] @@ -187,6 +199,10 @@ Status CudnnRnnBase::ComputeInternal(OpKernelContext* ctx) const { if (rnn_mode_ == CUDNN_LSTM) { // initial cell. [num_directions_, batch_size, hidden_size_] initial_c = ctx->Input(RNN_Input_Index::initial_c); + // cuDNN LSTM does not support peephole weights (ONNX input P at index 7) + const Tensor* P = ctx->Input(7); + ORT_RETURN_IF(P != nullptr, + "CUDA LSTM does not support peephole weights (input P). Use CPU EP instead."); } size_t proj_size = hidden_size_; @@ -197,7 +213,7 @@ Status CudnnRnnBase::ComputeInternal(OpKernelContext* ctx) const { // we thread a single input as sequence_lens of length 1, require to expand to [batch_size]? std::vector sequence_lengths_temp; if (!sequence_lens) { - sequence_lengths_temp.resize(batch_size, gsl::narrow_cast(seq_length)); + sequence_lengths_temp.resize(batch_size, gsl::narrow(seq_length)); } const int32_t* sequence_lens_data = (sequence_lens == nullptr) @@ -214,10 +230,10 @@ Status CudnnRnnBase::ComputeInternal(OpKernelContext* ctx) const { // 0-len sequences are not supported by cuDNN. // Replace them by sequences of len 1 and mask them out with SetZeroSequences - for (int i = 0; i < batch_size; ++i) { + for (int64_t i = 0; i < batch_size; ++i) { if (0 == sequence_lens_data[i]) { seq_len_array[i] = 1; - zero_seq_index_cache[zero_seq_count] = i; + zero_seq_index_cache[zero_seq_count] = gsl::narrow(i); ++zero_seq_count; } else { seq_len_array[i] = sequence_lens_data[i]; @@ -257,15 +273,15 @@ Status CudnnRnnBase::ComputeInternal(OpKernelContext* ctx) const { const T* x_data = X->Data(); if (reverse_) { // reverse input data - x_reversed_data = GetScratchBuffer(seq_length * batch_size * input_size, GetComputeStream(ctx)); + x_reversed_data = GetScratchBuffer(SafeInt(seq_length) * batch_size * input_size, GetComputeStream(ctx)); ReverseBySequence(Stream(ctx), - gsl::narrow_cast(seq_length), + gsl::narrow(seq_length), sequence_lens_buffer.GpuPtr(), - gsl::narrow_cast(batch_size), - gsl::narrow_cast(input_size), + gsl::narrow(batch_size), + gsl::narrow(input_size), reinterpret_cast(x_data), reinterpret_cast(x_reversed_data.get()), - seq_length * batch_size * input_size); + SafeInt(seq_length) * batch_size * input_size); } const T* x_data_input = reverse_ ? x_reversed_data.get() : x_data; @@ -274,7 +290,7 @@ Status CudnnRnnBase::ComputeInternal(OpKernelContext* ctx) const { const T* cx_data = (initial_c == nullptr) ? nullptr : initial_c->Data(); T* y_h_data = (Y_h == nullptr) ? nullptr : Y_h->MutableData(); T* y_c_data = (Y_c == nullptr) ? nullptr : Y_c->MutableData(); - int64_t output_size = seq_length * num_directions_ * batch_size * hidden_size_; + int64_t output_size = SafeInt(seq_length) * num_directions_ * batch_size * hidden_size_; T* y_data = nullptr; IAllocatorUniquePtr y_alloc_data; if (Y != nullptr) { @@ -357,7 +373,8 @@ Status CudnnRnnBase::ComputeInternal(OpKernelContext* ctx) const { // Mask on output for 0 sequence batches if (zero_seq_count > 0) { // Mask on output for 0 sequence batches - SetZeroSequences(zero_seq_count, zero_seq_index_cache, y_data, y_h_data, y_c_data, GetComputeStream(ctx), Stream(ctx)); + SetZeroSequences(gsl::span(zero_seq_index_cache.data(), zero_seq_count), + y_data, y_h_data, y_c_data, GetComputeStream(ctx), Stream(ctx)); } return Status::OK(); } @@ -369,18 +386,18 @@ Status CudnnRnnBase::ComputeInternal(OpKernelContext* ctx) const { if (reverse_) { // reverse output data ReverseBySequence(Stream(ctx), - gsl::narrow_cast(seq_length), + gsl::narrow(seq_length), sequence_lens_buffer.GpuPtr(), - gsl::narrow_cast(batch_size), - gsl::narrow_cast(hidden_size_), + gsl::narrow(batch_size), + gsl::narrow(hidden_size_), reinterpret_cast(y_data), reinterpret_cast(y_reorganized_data.get()), output_size); } else { ReorderBidirectionalDataInSequence(Stream(ctx), - gsl::narrow_cast(seq_length), - gsl::narrow_cast(batch_size), - gsl::narrow_cast(hidden_size_), + gsl::narrow(seq_length), + gsl::narrow(batch_size), + gsl::narrow(hidden_size_), reinterpret_cast(y_data), reinterpret_cast(y_reorganized_data.get()), output_size); @@ -388,7 +405,7 @@ Status CudnnRnnBase::ComputeInternal(OpKernelContext* ctx) const { if (Y != nullptr) { // User specified this optional output, so need to copy the reversed data to original place - CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(y_data, y_reorganized_data.get(), output_size * sizeof(T), + CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(y_data, y_reorganized_data.get(), SafeInt(output_size) * sizeof(T), cudaMemcpyDeviceToDevice, Stream(ctx))); } else { y_data = y_reorganized_data.get(); @@ -397,26 +414,27 @@ Status CudnnRnnBase::ComputeInternal(OpKernelContext* ctx) const { // Mask on output for 0 sequence batches if (zero_seq_count > 0) { - SetZeroSequences(zero_seq_count, zero_seq_index_cache, y_data, y_h_data, y_c_data, GetComputeStream(ctx), Stream(ctx)); + SetZeroSequences(gsl::span(zero_seq_index_cache.data(), zero_seq_count), + y_data, y_h_data, y_c_data, GetComputeStream(ctx), Stream(ctx)); } return Status::OK(); } template -void CudnnRnnBase::SetZeroSequences(const int64_t zero_seq_index_cache_size, - const std::vector zero_seq_index_cache, +void CudnnRnnBase::SetZeroSequences(gsl::span zero_seq_index_cache, T* y_data, T* y_h_data, T* y_c_data, void* alloc_stream, cudaStream_t cuda_stream) const { typedef typename ToCudaType::MappedType CudaT; + const int64_t zero_seq_index_cache_size = static_cast(zero_seq_index_cache.size()); CudaAsyncBuffer zero_seq_index_cache_async_buffer(this, zero_seq_index_cache_size); memcpy(zero_seq_index_cache_async_buffer.CpuPtr(), zero_seq_index_cache.data(), zero_seq_index_cache_size * sizeof(int32_t)); ORT_THROW_IF_ERROR(zero_seq_index_cache_async_buffer.CopyToGpu(alloc_stream)); MaskZeroSequences(cuda_stream, - gsl::narrow_cast(hidden_size_), + gsl::narrow(hidden_size_), reinterpret_cast(y_data), reinterpret_cast(y_h_data), reinterpret_cast(y_c_data), diff --git a/onnxruntime/core/providers/cuda/rnn/cudnn_rnn_base.h b/onnxruntime/core/providers/cuda/rnn/cudnn_rnn_base.h index b7a3d67b45e93..e7bccae2aad1c 100644 --- a/onnxruntime/core/providers/cuda/rnn/cudnn_rnn_base.h +++ b/onnxruntime/core/providers/cuda/rnn/cudnn_rnn_base.h @@ -58,9 +58,9 @@ class CudnnRNN { dataType, dataType, mathType, - gsl::narrow_cast(input_size), - gsl::narrow_cast(hidden_size), - gsl::narrow_cast(proj_size), // projected size + gsl::narrow(input_size), + gsl::narrow(hidden_size), + gsl::narrow(proj_size), // projected size num_layers, cudnn_dropout_desc, // CUDNN_RNN_DATA_LAYOUT_SEQ_MAJOR_UNPACKED works with CUDNN_RNN_PADDED_IO_ENABLED, so that it will auto fill 0 for the shorter sequences @@ -148,12 +148,11 @@ class CudnnRnnBase : public CudaKernel { const void* w_data, const int lin_layer_id, const T* pos, - int& offset, + size_t& offset, bool is_matrix, cudaStream_t cuda_stream) const; - void SetZeroSequences(const int64_t zero_seq_index_cache_size, - const std::vector zero_seq_index_cache, + void SetZeroSequences(gsl::span zero_seq_index_cache, T* y_data, T* y_h_data, T* y_c_data, diff --git a/onnxruntime/core/providers/cuda/rnn/lstm.cc b/onnxruntime/core/providers/cuda/rnn/lstm.cc index 890d15cef6501..ce625371b3e5f 100644 --- a/onnxruntime/core/providers/cuda/rnn/lstm.cc +++ b/onnxruntime/core/providers/cuda/rnn/lstm.cc @@ -21,11 +21,25 @@ namespace cuda { .InputMemoryType(OrtMemTypeCPUInput, RNN_Input_Index::sequence_lens), \ LSTM); +#define REGISTER_KERNEL_VERSIONED_TYPED_14(T) \ + ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \ + LSTM, \ + kOnnxDomain, \ + 14, \ + 21, \ + T, \ + kCudaExecutionProvider, \ + (*KernelDefBuilder::Create()) \ + .TypeConstraint("T", DataTypeImpl::GetTensorType()) \ + .TypeConstraint("T1", DataTypeImpl::GetTensorType()) \ + .InputMemoryType(OrtMemTypeCPUInput, RNN_Input_Index::sequence_lens), \ + LSTM); + #define REGISTER_KERNEL_TYPED(T) \ ONNX_OPERATOR_TYPED_KERNEL_EX( \ LSTM, \ kOnnxDomain, \ - 14, \ + 22, \ T, \ kCudaExecutionProvider, \ (*KernelDefBuilder::Create()) \ @@ -38,6 +52,10 @@ REGISTER_KERNEL_VERSIONED_TYPED(float); REGISTER_KERNEL_VERSIONED_TYPED(double); REGISTER_KERNEL_VERSIONED_TYPED(MLFloat16); +REGISTER_KERNEL_VERSIONED_TYPED_14(float); +REGISTER_KERNEL_VERSIONED_TYPED_14(double); +REGISTER_KERNEL_VERSIONED_TYPED_14(MLFloat16); + REGISTER_KERNEL_TYPED(float); REGISTER_KERNEL_TYPED(double); REGISTER_KERNEL_TYPED(MLFloat16); diff --git a/onnxruntime/core/providers/cuda/rnn/lstm.h b/onnxruntime/core/providers/cuda/rnn/lstm.h index b82061088a2c8..fb1fdc7165725 100644 --- a/onnxruntime/core/providers/cuda/rnn/lstm.h +++ b/onnxruntime/core/providers/cuda/rnn/lstm.h @@ -14,6 +14,13 @@ class LSTM final : public CudnnRnnBase { LSTM(const OpKernelInfo& info) : CudnnRnnBase(info) { CudnnRnnBase::SetRNNMode(CUDNN_LSTM); + // cuDNN LSTM does not support input_forget coupling (attribute added in opset 14) + int64_t input_forget = 0; + if (info.GetAttr("input_forget", &input_forget).IsOK()) { + ORT_ENFORCE(input_forget == 0, + "CUDA LSTM does not support input_forget=1. Use CPU EP instead."); + } + // ONNX W layout is W[iofc], WB[iofc], mapping to RNNLinLayerMatrixParams the linLayerID is 0, 3, 1, 2 CudnnRnnBase::W_lin_layer_id_.assign({0, 3, 1, 2}); // ONNX R layout is R[iofc], RB[iofc], mapping to RNNLinLayerMatrixParams the linLayerID is 4, 7, 5, 6 diff --git a/onnxruntime/test/providers/cpu/rnn/deep_cpu_lstm_op_test.cc b/onnxruntime/test/providers/cpu/rnn/deep_cpu_lstm_op_test.cc index 3b7e93b8f7668..c23e843b6f143 100644 --- a/onnxruntime/test/providers/cpu/rnn/deep_cpu_lstm_op_test.cc +++ b/onnxruntime/test/providers/cpu/rnn/deep_cpu_lstm_op_test.cc @@ -1351,6 +1351,70 @@ TEST(LSTMTest, ONNXRuntime_TestLSTMZeroSeqInMiddle) { #ifndef ENABLE_TRAINING // Prepacking is disabled in full training build so no need to test the feature in a training build. +TEST(LSTMTest, ONNXRuntime_TestLSTMForward_OpSet22_CUDA) { + auto cuda_ep = DefaultCudaExecutionProvider(); + if (!cuda_ep) { + return; + } + + constexpr int seq_len = 2, batch_size = 1; + constexpr int64_t input_size = 1, hidden_size = 2; + constexpr int num_directions = 1; + + OpTester test("LSTM", 22); + + test.AddAttribute>("activations", {"sigmoid", "tanh", "tanh"}); + test.AddAttribute("direction", "forward"); + test.AddAttribute("hidden_size", hidden_size); + + // X shape: [seq_len, batch_size, input_size] = [2, 1, 1] + std::vector X_data = {-0.455351f, -0.185934f}; + std::vector X_dims = {seq_len, batch_size, input_size}; + + // W shape: [num_directions, 4*hidden_size, input_size] = [1, 8, 1] + std::vector W_data = {-0.494659f, 0.0453352f, + -0.487793f, 0.417264f, + -0.0175329f, 0.489074f, + -0.446013f, 0.414029f}; + std::vector W_dims = {num_directions, 4 * hidden_size, input_size}; + + // R shape: [num_directions, 4*hidden_size, hidden_size] = [1, 8, 2] + std::vector R_data = {0.146304f, -0.0243403f, + -0.487793f, 0.417264f, + -0.0175329f, 0.489074f, + -0.446013f, 0.414029f, + 0.146304f, -0.0243403f, + -0.487793f, 0.417264f, + -0.0175329f, 0.489074f, + -0.446013f, 0.414029f}; + std::vector R_dims = {num_directions, 4 * hidden_size, hidden_size}; + + test.AddInput("X", X_dims, X_data); + test.AddInput("W", W_dims, W_data, true); + test.AddInput("R", R_dims, R_data, true); + + // B, sequence_lens, initial_h, initial_c, P - all optional, not provided + test.AddOptionalInputEdge(); // B + test.AddOptionalInputEdge(); // sequence_lens + test.AddOptionalInputEdge(); // initial_h + test.AddOptionalInputEdge(); // initial_c + test.AddOptionalInputEdge(); // P + + // Expected values computed via reference LSTM implementation with the same weights. + // Y (full sequence output) shape: [seq_len, num_directions, batch_size, hidden_size] + std::vector Y_dims = {seq_len, num_directions, batch_size, hidden_size}; + test.AddOutput("Y", Y_dims, {0.0616098f, -0.0416164f, 0.0455843f, -0.0476148f}, false, 1e-4f, 1e-4f); + // Y_h (final hidden state) + std::vector Y_h_dims = {num_directions, batch_size, hidden_size}; + test.AddOutput("Y_h", Y_h_dims, {0.0455843f, -0.0476148f}, false, 1e-4f, 1e-4f); + // Y_c + test.AddOptionalOutputEdge(); + + std::vector> execution_providers; + execution_providers.push_back(std::move(cuda_ep)); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); +} + TEST(LSTMTest, SharedPrepackedWeights) { int64_t seq_length = 2; int batch_size = 2; From 127faf0ccc40b98e9fb1b4e1ab978ebd594d0928 Mon Sep 17 00:00:00 2001 From: KV2773 Date: Fri, 15 May 2026 11:22:04 +0530 Subject: [PATCH 07/11] Build issues on AIX for POWER10 and POWER11 (#26704) While building onnxruntime from source for AIX I ran into macro pre-defined errors for the POWER10 and POWER11 machines. This patch resolves the issue. --------- Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- cmake/onnxruntime_mlas.cmake | 6 +++++- onnxruntime/core/mlas/lib/platform.cpp | 6 +++++- 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/cmake/onnxruntime_mlas.cmake b/cmake/onnxruntime_mlas.cmake index ae6253cecdc03..caaee84d32acb 100644 --- a/cmake/onnxruntime_mlas.cmake +++ b/cmake/onnxruntime_mlas.cmake @@ -636,9 +636,13 @@ else() enable_language(ASM) check_cxx_source_compiles(" #ifdef _AIX + #include + #if !defined(POWER_10) #define POWER_10 0x40000 + #endif + #if !defined(POWER_10_ANDUP) #define POWER_10_ANDUP (POWER_10) - #include + #endif #define __power_10_andup() (_system_configuration.implementation & POWER_10_ANDUP) int main() { bool HasP10 = (__power_10_andup() && __power_mma_version() == MMA_V31); diff --git a/onnxruntime/core/mlas/lib/platform.cpp b/onnxruntime/core/mlas/lib/platform.cpp index 6bd116a3a417d..de521de0c3785 100644 --- a/onnxruntime/core/mlas/lib/platform.cpp +++ b/onnxruntime/core/mlas/lib/platform.cpp @@ -40,9 +40,13 @@ Module Name: #if defined(__linux__) #include #elif defined(_AIX) +#include +#if !defined(POWER_10) #define POWER_10 0x40000 +#endif +#if !defined(POWER_10_ANDUP) #define POWER_10_ANDUP (POWER_10) -#include +#endif #define __power_10_andup() (_system_configuration.implementation & POWER_10_ANDUP) #elif defined(__FreeBSD__) #include From ddd1eae299fb5dda0ae0146845b2194e9dbe5b83 Mon Sep 17 00:00:00 2001 From: Brenden Sosnader Date: Fri, 15 May 2026 09:40:33 -0700 Subject: [PATCH 08/11] Fix WeaklyCanonicalPath ERROR_ACCESS_DENIED in Windows AppContainers (#28509) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ### Description Adds a Windows-only fallback to `WeaklyCanonicalPath` (in `onnxruntime/core/framework/tensorprotoutils.cc`) for use inside Windows AppContainers, where `std::filesystem::weakly_canonical` always fails with `ERROR_ACCESS_DENIED` because the underlying `GetFinalPathNameByHandleW(VOLUME_NAME_DOS)` call goes through the Volume Mount Manager, which AppContainer tokens cannot query regardless of file ACL grants. On `ERROR_ACCESS_DENIED`, fall back to a manual canonicalization that uses `GetFinalPathNameByHandleW(FILE_NAME_NORMALIZED | VOLUME_NAME_NT)` and prefixes the result with `\\?\GLOBALROOT` so it remains a valid Win32 path. All other error paths, non-Windows builds, and non-AppContainer Windows runs are unchanged. `VOLUME_NAME_NT` (not `VOLUME_NAME_NONE`) is required: it preserves volume identity, so the cross-volume escape rejection in `ValidateExternalDataPath` introduced by #26776 continues to hold. 8 new unit tests cover the fallback helper directly (existing dir/file, non-existent leaf, multi-component miss, all-non-existent → false, equivalence with `weakly_canonical`, symlink resolution, `..` collapse). The AppContainer trigger itself cannot be reproduced in a unit test environment. ### Motivation and Context Fixes #28508. Regression introduced in v1.24.1 by #26776 (`ValidateExternalDataPath`); current `WeaklyCanonicalPath` wrapper from #27539 in v1.25.0. Loading any ONNX model with external data fails inside a Windows AppContainer with: ``` Failed to get the weakly canonical path: "" - Access is denied. ``` Affected callers have no in-process workaround. Downstream report: microsoft/Foundry-Local#709. CC @yuslepukhin (#26776), @adrianlizarraga (#27539), @tianleiwu (#27374). --------- Co-authored-by: Brenden Sosnader Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- .../core/framework/tensorprotoutils.cc | 8 +- onnxruntime/core/platform/env.h | 6 + onnxruntime/core/platform/posix/env.cc | 13 ++ onnxruntime/core/platform/windows/env.cc | 131 ++++++++++++++++++ onnxruntime/core/platform/windows/env.h | 8 ++ .../test/framework/tensorutils_test.cc | 120 ++++++++++++++++ 6 files changed, 282 insertions(+), 4 deletions(-) diff --git a/onnxruntime/core/framework/tensorprotoutils.cc b/onnxruntime/core/framework/tensorprotoutils.cc index ed130686f302a..6f73456742160 100644 --- a/onnxruntime/core/framework/tensorprotoutils.cc +++ b/onnxruntime/core/framework/tensorprotoutils.cc @@ -380,11 +380,11 @@ Status TensorProtoWithExternalDataToTensorProto( return Status::OK(); } -// Wraps std::filesystem::weakly_canonical with error_code handling. +// Wraps Env::GetWeaklyCanonicalPath for std::filesystem::path. static Status WeaklyCanonicalPath(const std::filesystem::path& path, std::filesystem::path& result) { - std::error_code ec; - result = std::filesystem::weakly_canonical(path, ec); - ORT_RETURN_IF(ec, "Failed to get the weakly canonical path: ", path, " - ", ec.message()); + PathString canonical_str; + ORT_RETURN_IF_ERROR(Env::Default().GetWeaklyCanonicalPath(path.native(), canonical_str)); + result = std::filesystem::path(std::move(canonical_str)); return Status::OK(); } diff --git a/onnxruntime/core/platform/env.h b/onnxruntime/core/platform/env.h index 8412f8c9b4f6f..f45f6c088d2a5 100644 --- a/onnxruntime/core/platform/env.h +++ b/onnxruntime/core/platform/env.h @@ -231,6 +231,12 @@ class Env { const PathString& path, PathString& canonical_path) const = 0; + /** Like GetCanonicalPath, but the path is not required to exist. Mirrors + * std::filesystem::weakly_canonical. */ + virtual common::Status GetWeaklyCanonicalPath( + const PathString& path, + PathString& canonical_path) const = 0; + // This functions is always successful. It can't fail. virtual PIDType GetSelfPid() const = 0; diff --git a/onnxruntime/core/platform/posix/env.cc b/onnxruntime/core/platform/posix/env.cc index ba7af3b383b54..cf28d7a0cce2a 100644 --- a/onnxruntime/core/platform/posix/env.cc +++ b/onnxruntime/core/platform/posix/env.cc @@ -543,6 +543,19 @@ class PosixEnv : public Env { return Status::OK(); } + common::Status GetWeaklyCanonicalPath( + const PathString& path, + PathString& canonical_path) const override { + std::error_code ec; + auto canonical = std::filesystem::weakly_canonical(std::filesystem::path{path}, ec); + if (ec) { + return common::Status(common::ONNXRUNTIME, common::FAIL, + "Failed to get the weakly canonical path: " + path + " - " + ec.message()); + } + canonical_path.assign(canonical.native()); + return Status::OK(); + } + common::Status LoadDynamicLibrary(const PathString& library_filename, bool global_symbols, void** handle) const override { dlerror(); // clear any old error_str *handle = dlopen(library_filename.c_str(), RTLD_NOW | (global_symbols ? RTLD_GLOBAL : RTLD_LOCAL)); diff --git a/onnxruntime/core/platform/windows/env.cc b/onnxruntime/core/platform/windows/env.cc index 4d80b5afff4b8..07d5dfc9c0b22 100644 --- a/onnxruntime/core/platform/windows/env.cc +++ b/onnxruntime/core/platform/windows/env.cc @@ -20,6 +20,7 @@ limitations under the License. #include #include +#include #include #include #include @@ -685,6 +686,136 @@ common::Status WindowsEnv::GetCanonicalPath( return Status::OK(); } +namespace { + +constexpr std::wstring_view kGlobalRootPrefix{L"\\\\?\\GLOBALROOT"}; + +wil::unique_hfile OpenHandleForFinalPath(const std::filesystem::path& path) { + CREATEFILE2_EXTENDED_PARAMETERS params{}; + params.dwSize = sizeof(params); + params.dwFileFlags = FILE_FLAG_BACKUP_SEMANTICS; + return wil::unique_hfile{::CreateFile2(path.c_str(), + FILE_READ_ATTRIBUTES, + FILE_SHARE_READ | FILE_SHARE_WRITE | FILE_SHARE_DELETE, + OPEN_EXISTING, + ¶ms)}; +} + +// Final-path query using VOLUME_NAME_NT, prefixed with "\\?\GLOBALROOT" to stay a valid Win32 path. +bool TryGetFinalPathNt(const std::filesystem::path& path, std::filesystem::path& result) { + wil::unique_hfile handle = OpenHandleForFinalPath(path); + if (handle.get() == INVALID_HANDLE_VALUE) { + return false; + } + + std::wstring buffer(MAX_PATH, L'\0'); + constexpr DWORD kFlags = FILE_NAME_NORMALIZED | VOLUME_NAME_NT; + DWORD needed = ::GetFinalPathNameByHandleW(handle.get(), buffer.data(), + static_cast(buffer.size()), kFlags); + if (needed != 0 && needed >= buffer.size()) { + buffer.resize(needed); + needed = ::GetFinalPathNameByHandleW(handle.get(), buffer.data(), + static_cast(buffer.size()), kFlags); + } + + if (needed == 0 || needed >= buffer.size()) { + return false; + } + buffer.resize(needed); + + std::wstring prefixed; + prefixed.reserve(kGlobalRootPrefix.size() + buffer.size()); + prefixed.append(kGlobalRootPrefix); + prefixed.append(buffer); + result = std::filesystem::path(std::move(prefixed)); + return true; +} + +// weakly_canonical analogue using TryGetFinalPathNt for the existing prefix. +bool TryWeaklyCanonicalPathNtVolume(const std::filesystem::path& input, + std::filesystem::path& result) { + std::filesystem::path head = input; + std::filesystem::path tail; + std::filesystem::path canonical_head; + bool found_existing_prefix = false; + + while (true) { + std::error_code ec; + const bool exists = std::filesystem::exists(head, ec); + if (ec) { + return false; + } + if (exists) { + if (!TryGetFinalPathNt(head, canonical_head)) { + return false; + } + found_existing_prefix = true; + break; + } + if (head.empty()) { + break; + } + const auto parent = head.parent_path(); + if (parent == head) { + break; + } + const auto leaf = head.filename(); + if (!leaf.empty()) { + // path / empty would insert a trailing separator. + tail = tail.empty() ? leaf : (leaf / tail); + } + head = parent; + } + + if (!found_existing_prefix) { + return false; + } + + if (tail.empty()) { + result = std::move(canonical_head); + } else { + result = (canonical_head / tail).lexically_normal(); + } + return true; +} + +} // namespace + +// On AppContainer, std::filesystem::weakly_canonical fails with ERROR_ACCESS_DENIED +// because VOLUME_NAME_DOS goes through the Volume Mount Manager. Fall back to +// VOLUME_NAME_NT, which preserves volume identity (cross-volume escape rejection in +// ValidateExternalDataPath relies on this — do NOT use VOLUME_NAME_NONE). +common::Status WindowsEnv::GetWeaklyCanonicalPath( + const PathString& path, + PathString& canonical_path) const { + std::filesystem::path fs_path{path}; + std::error_code ec; + std::filesystem::path canonical = std::filesystem::weakly_canonical(fs_path, ec); + if (!ec) { + canonical_path = canonical.native(); + return Status::OK(); + } + + if (ec.value() == ERROR_ACCESS_DENIED) { + std::filesystem::path fallback; + if (TryWeaklyCanonicalPathNtVolume(fs_path, fallback)) { + canonical_path = fallback.native(); + return Status::OK(); + } + } + + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, + "Failed to get the weakly canonical path: ", + ToUTF8String(path), " - ", ec.message()); +} + +namespace internal { +bool WeaklyCanonicalPathNtVolumeFallbackForTesting(const std::filesystem::path& input, + std::filesystem::path& result) { + return TryWeaklyCanonicalPathNtVolume(input, result); +} +} // namespace internal + // Return the path of the executable/shared library for the current running code. This is to make it // possible to load other shared libraries installed next to our core runtime code. PathString WindowsEnv::GetRuntimePath() const { diff --git a/onnxruntime/core/platform/windows/env.h b/onnxruntime/core/platform/windows/env.h index f118e05d42ead..df8a3e10d512a 100644 --- a/onnxruntime/core/platform/windows/env.h +++ b/onnxruntime/core/platform/windows/env.h @@ -18,6 +18,7 @@ limitations under the License. #include "core/platform/windows/telemetry.h" #include "core/common/inlined_containers.h" #include +#include namespace onnxruntime { @@ -79,6 +80,7 @@ class WindowsEnv : public Env { common::Status FileOpenWr(const std::string& path, /*out*/ int& fd) const override; common::Status FileClose(int fd) const override; common::Status GetCanonicalPath(const PathString& path, PathString& canonical_path) const override; + common::Status GetWeaklyCanonicalPath(const PathString& path, PathString& canonical_path) const override; PathString GetRuntimePath() const override; Status LoadDynamicLibrary(const PathString& library_filename, bool /*global_symbols*/, void** handle) const override; Status UnloadDynamicLibrary(void* handle) const override; @@ -141,4 +143,10 @@ class WindowsEnv : public Env { WindowsTelemetry telemetry_provider_; }; +namespace internal { +// Test-only: exposes the AppContainer fallback used by WindowsEnv::GetWeaklyCanonicalPath. +bool WeaklyCanonicalPathNtVolumeFallbackForTesting(const std::filesystem::path& input, + std::filesystem::path& result); +} // namespace internal + } // namespace onnxruntime diff --git a/onnxruntime/test/framework/tensorutils_test.cc b/onnxruntime/test/framework/tensorutils_test.cc index 880208960a63c..71ac5b49e9718 100644 --- a/onnxruntime/test/framework/tensorutils_test.cc +++ b/onnxruntime/test/framework/tensorutils_test.cc @@ -21,6 +21,7 @@ #ifdef _WIN32 #include +#include "core/platform/windows/env.h" #endif using namespace ::onnxruntime::utils; @@ -715,6 +716,125 @@ TEST_F(PathValidationTest, ValidateExternalDataPathEmptyModelPathWithSymlinkOuts EXPECT_THAT(status.ErrorMessage(), testing::HasSubstr("escapes working directory")); } +#if defined(_WIN32) +// Direct tests for the Windows AppContainer fallback used by +// WindowsEnv::GetWeaklyCanonicalPath. The AppContainer trigger itself can't be +// reproduced in a unit test environment; see microsoft/onnxruntime#28508. +TEST_F(PathValidationTest, WeaklyCanonicalPathNtVolumeFallback_ExistingDirectory) { + std::filesystem::path canonical; + ASSERT_TRUE(onnxruntime::internal::WeaklyCanonicalPathNtVolumeFallbackForTesting(base_dir_, canonical)); + + EXPECT_THAT(canonical.wstring(), testing::StartsWith(L"\\\\?\\GLOBALROOT\\Device\\")); + + std::error_code ec; + EXPECT_TRUE(std::filesystem::exists(canonical, ec)) << "ec=" << ec.message(); +} + +TEST_F(PathValidationTest, WeaklyCanonicalPathNtVolumeFallback_ExistingFile) { + CreateEmptyFile(base_dir_ / "data.bin"); + + std::filesystem::path canonical; + ASSERT_TRUE( + onnxruntime::internal::WeaklyCanonicalPathNtVolumeFallbackForTesting(base_dir_ / "data.bin", canonical)); + + EXPECT_THAT(canonical.wstring(), testing::StartsWith(L"\\\\?\\GLOBALROOT\\Device\\")); + + std::error_code ec; + EXPECT_TRUE(std::filesystem::exists(canonical, ec)) << "ec=" << ec.message(); + EXPECT_TRUE(std::filesystem::is_regular_file(canonical, ec)) << "ec=" << ec.message(); +} + +TEST_F(PathValidationTest, WeaklyCanonicalPathNtVolumeFallback_NonExistentLeafLexicallyAppended) { + const std::filesystem::path leaf{L"does_not_exist.bin"}; + std::filesystem::path canonical; + ASSERT_TRUE(onnxruntime::internal::WeaklyCanonicalPathNtVolumeFallbackForTesting(base_dir_ / leaf, canonical)); + + EXPECT_THAT(canonical.wstring(), testing::StartsWith(L"\\\\?\\GLOBALROOT\\Device\\")); + EXPECT_EQ(canonical.filename(), leaf); + + // The canonicalized parent must be a path-component prefix of the result so that the + // containment check in ValidateExternalDataPath still works. + std::filesystem::path parent_canonical; + ASSERT_TRUE(onnxruntime::internal::WeaklyCanonicalPathNtVolumeFallbackForTesting(base_dir_, parent_canonical)); + auto [parent_end, full_it] = std::mismatch(parent_canonical.begin(), parent_canonical.end(), + canonical.begin(), canonical.end()); + EXPECT_EQ(parent_end, parent_canonical.end()) + << "parent: " << parent_canonical << " full: " << canonical; +} + +TEST_F(PathValidationTest, WeaklyCanonicalPathNtVolumeFallback_NonExistentMiddleAndLeaf) { + std::filesystem::path canonical; + ASSERT_TRUE(onnxruntime::internal::WeaklyCanonicalPathNtVolumeFallbackForTesting( + base_dir_ / L"missing_dir" / L"data.bin", canonical)); + + EXPECT_THAT(canonical.wstring(), testing::StartsWith(L"\\\\?\\GLOBALROOT\\Device\\")); + EXPECT_EQ(canonical.filename(), std::filesystem::path{L"data.bin"}); + EXPECT_EQ(canonical.parent_path().filename(), std::filesystem::path{L"missing_dir"}); +} + +TEST_F(PathValidationTest, WeaklyCanonicalPathNtVolumeFallback_AllNonExistentReturnsFalse) { + // Synthetic absolute path on a non-existent volume. The fallback must return false so + // the caller surfaces the original weakly_canonical error rather than substituting an + // unverified path. + const std::filesystem::path bogus{L"\\\\?\\Volume{00000000-0000-0000-0000-000000000000}\\nope\\data.bin"}; + std::filesystem::path canonical; + EXPECT_FALSE(onnxruntime::internal::WeaklyCanonicalPathNtVolumeFallbackForTesting(bogus, canonical)); +} + +TEST_F(PathValidationTest, WeaklyCanonicalPathNtVolumeFallback_MatchesWeaklyCanonicalAtFile) { + // Compare via std::filesystem::equivalent: the fallback returns the NT form + // (\\?\GLOBALROOT\Device\HarddiskVolumeN\...) while weakly_canonical returns the DOS + // form (C:\...), but both must point at the same file. + CreateEmptyFile(base_dir_ / "compare.bin"); + const auto target = base_dir_ / "compare.bin"; + + std::error_code ec; + const auto reference = std::filesystem::weakly_canonical(target, ec); + ASSERT_FALSE(ec) << ec.message(); + + std::filesystem::path fallback; + ASSERT_TRUE(onnxruntime::internal::WeaklyCanonicalPathNtVolumeFallbackForTesting(target, fallback)); + + EXPECT_TRUE(std::filesystem::equivalent(reference, fallback, ec)) + << "reference=" << reference << " fallback=" << fallback << " ec=" << ec.message(); +} + +TEST_F(PathValidationTest, WeaklyCanonicalPathNtVolumeFallback_ResolvesSymlinks) { + const auto target = base_dir_ / "symlink_target.bin"; + const auto link = base_dir_ / "symlink_link.bin"; + try { + std::ofstream{target}; + std::filesystem::create_symlink(target, link); + } catch (const std::exception& e) { + GTEST_SKIP() << "Symlink creation not supported in this environment: " << e.what(); + } + + std::filesystem::path link_canonical; + std::filesystem::path target_canonical; + ASSERT_TRUE(onnxruntime::internal::WeaklyCanonicalPathNtVolumeFallbackForTesting(link, link_canonical)); + ASSERT_TRUE(onnxruntime::internal::WeaklyCanonicalPathNtVolumeFallbackForTesting(target, target_canonical)); + + std::error_code ec; + EXPECT_TRUE(std::filesystem::equivalent(link_canonical, target_canonical, ec)) + << "link=" << link_canonical << " target=" << target_canonical << " ec=" << ec.message(); +} + +TEST_F(PathValidationTest, WeaklyCanonicalPathNtVolumeFallback_ResolvesDotDot) { + CreateDirectories(base_dir_ / "sub_for_dotdot"); + + std::filesystem::path canonical; + ASSERT_TRUE(onnxruntime::internal::WeaklyCanonicalPathNtVolumeFallbackForTesting( + base_dir_ / "sub_for_dotdot" / "..", canonical)); + + std::filesystem::path base_canonical; + ASSERT_TRUE(onnxruntime::internal::WeaklyCanonicalPathNtVolumeFallbackForTesting(base_dir_, base_canonical)); + + std::error_code ec; + EXPECT_TRUE(std::filesystem::equivalent(canonical, base_canonical, ec)) + << "canonical=" << canonical << " base=" << base_canonical << " ec=" << ec.message(); +} +#endif // defined(_WIN32) + TEST(TensorProtoUtilsTest, GetNodeProtoLayeringAnnotation) { // Case 1: Annotation exists { From 00e9ef3d67647002441b9f8d591e3e1cd8e9f765 Mon Sep 17 00:00:00 2001 From: vraspar Date: Fri, 15 May 2026 09:41:38 -0700 Subject: [PATCH 09/11] Fix input validation and null-pointer dereference in STFTDecomposition graph transformer (#28465) ### Description Validate that `frame_step` and `dft_size` (derived from `frame_length`) are positive before they are used in buffer sizing and loop arithmetic. Use SafeInt for the weight buffer allocation to guard against overflow. Also fix an unconditional dereference of `window_recipient` which is nullptr when the STFT node has no window input. ### Motivation and Context A model with non-positive initializer values for frame_length or frame_step causes signed-to-unsigned wrapping in size computations, leading to out-of-bounds writes during graph optimization. The nullptr dereference is a crash on any windowless STFT node. --------- Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- .../core/optimizer/stft_decomposition.cc | 32 +++++++-- .../test/optimizer/graph_transform_test.cc | 61 ++++++++++++++++++ .../transform/stft_negative_frame_length.onnx | Bin 0 -> 204 bytes .../transform/stft_negative_frame_length.py | 46 +++++++++++++ .../transform/stft_negative_frame_step.onnx | Bin 0 -> 204 bytes .../testdata/transform/stft_no_window.onnx | Bin 0 -> 195 bytes 6 files changed, 132 insertions(+), 7 deletions(-) create mode 100644 onnxruntime/test/testdata/transform/stft_negative_frame_length.onnx create mode 100644 onnxruntime/test/testdata/transform/stft_negative_frame_length.py create mode 100644 onnxruntime/test/testdata/transform/stft_negative_frame_step.onnx create mode 100644 onnxruntime/test/testdata/transform/stft_no_window.onnx diff --git a/onnxruntime/core/optimizer/stft_decomposition.cc b/onnxruntime/core/optimizer/stft_decomposition.cc index 49fe180b9470c..713cbcc6193d3 100644 --- a/onnxruntime/core/optimizer/stft_decomposition.cc +++ b/onnxruntime/core/optimizer/stft_decomposition.cc @@ -9,6 +9,7 @@ #include "core/graph/graph_utils.h" #include "core/optimizer/optimizer_execution_frame.h" #include "core/optimizer/utils.h" +#include "core/common/safeint.h" #include "core/framework/op_kernel.h" #include "core/framework/tensorprotoutils.h" #include @@ -210,6 +211,14 @@ Status STFTDecomposition::ApplyImpl(Graph& graph, bool& modified, int graph_leve dft_size = window_length_dim.dim_value(); } + // Validate model-provided scalar values before using them in size calculations. + // These come from untrusted model initializers/shapes and must be positive. + if (dft_size <= 0 || frame_step_value <= 0) { + LOGS(logger, WARNING) << "STFT decomposition skipped: invalid dft_size (" << dft_size + << ") or frame_step_value (" << frame_step_value << ")"; + continue; + } + bool is_onesided = true; auto& attrs = stft.GetAttributes(); if (attrs.find("onesided") != attrs.end()) { @@ -227,14 +236,22 @@ Status STFTDecomposition::ApplyImpl(Graph& graph, bool& modified, int graph_leve if (is_real) { auto output_num_frames = stft.MutableOutputDefs()[0]->Shape()->dim(1).dim_value(); auto output_frame_length = stft.MutableOutputDefs()[0]->Shape()->dim(2).dim_value(); - auto weight_size = static_cast(dft_unique_bins * dft_size); + + size_t dft_size_sz, dft_unique_bins_sz, weight_size; + if (!SafeCast(dft_unique_bins, dft_unique_bins_sz) || + !SafeCast(dft_size, dft_size_sz) || + !SafeMultiply(dft_unique_bins_sz, dft_size_sz, weight_size)) { + LOGS(logger, WARNING) << "STFT decomposition skipped: weight size overflow"; + continue; + } + auto real_weights_data = std::vector(weight_size); auto imag_weights_data = std::vector(weight_size); // Populate weights - for (size_t k = 0; k < static_cast(dft_unique_bins); k++) { - for (size_t n = 0; n < static_cast(dft_size); n++) { - auto index = static_cast(k * dft_size + n); + for (size_t k = 0; k < dft_unique_bins_sz; k++) { + for (size_t n = 0; n < dft_size_sz; n++) { + auto index = k * dft_size_sz + n; auto theta = -2 * std::numbers::pi_v * k * n / static_cast(dft_size); real_weights_data[index] = static_cast(cos(theta)); imag_weights_data[index] = static_cast(sin(theta)); @@ -356,7 +373,6 @@ Status STFTDecomposition::ApplyImpl(Graph& graph, bool& modified, int graph_leve // Copy inputs auto signal_target_idx = signal_recipient->Index(); - auto window_target_idx = window_recipient->Index(); for (auto cur = input_edges.cbegin(), end = input_edges.cend(); cur != end; ++cur) { const graph_utils::GraphEdge& edge = *cur; NodeIndex target_idx = 0; @@ -367,8 +383,10 @@ Status STFTDecomposition::ApplyImpl(Graph& graph, bool& modified, int graph_leve recipient = signal_recipient; break; case 2: - target_idx = window_target_idx; - recipient = window_recipient; + if (window_recipient) { + target_idx = window_recipient->Index(); + recipient = window_recipient; + } break; } diff --git a/onnxruntime/test/optimizer/graph_transform_test.cc b/onnxruntime/test/optimizer/graph_transform_test.cc index 591740be0263d..0e4ab5c2d3b73 100644 --- a/onnxruntime/test/optimizer/graph_transform_test.cc +++ b/onnxruntime/test/optimizer/graph_transform_test.cc @@ -76,6 +76,7 @@ #include "core/optimizer/rule_based_graph_transformer.h" #include "core/optimizer/slice_concat_to_space_to_depth_fusion.h" #include "core/optimizer/slice_elimination.h" +#include "core/optimizer/stft_decomposition.h" #include "core/optimizer/unsqueeze_elimination.h" #include "core/optimizer/utils.h" #include "core/platform/env.h" @@ -10425,5 +10426,65 @@ TEST_F(GraphTransformationTests, DivMulFusion_MultiElementInitializer) { // `dropout_elimination.cc` remains as pure defense-in-depth against future // internal callers that may bypass shape inference. +// These tests verify that STFTDecomposition skips malformed models +// instead of crashing with OOB writes from negative initializer values. +TEST_F(GraphTransformationTests, STFTDecomposition_NegativeFrameLength) { + constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "stft_negative_frame_length.onnx"; + std::shared_ptr model; + ASSERT_STATUS_OK(Model::Load(model_uri, model, nullptr, *logger_)); + Graph& graph = model->MainGraph(); + std::map op_to_count = CountOpsInGraph(graph); + ASSERT_EQ(op_to_count["STFT"], 1); + + const InlinedHashSet empty_ep = {}; + auto stft_transformer = std::make_unique(empty_ep); + onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; + ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::move(stft_transformer), TransformerLevel::Level1)); + ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_)); + + // STFT node should NOT be decomposed — transformer skips invalid models + op_to_count = CountOpsInGraph(graph); + ASSERT_EQ(op_to_count["STFT"], 1); +} + +TEST_F(GraphTransformationTests, STFTDecomposition_NegativeFrameStep) { + constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "stft_negative_frame_step.onnx"; + std::shared_ptr model; + ASSERT_STATUS_OK(Model::Load(model_uri, model, nullptr, *logger_)); + Graph& graph = model->MainGraph(); + std::map op_to_count = CountOpsInGraph(graph); + ASSERT_EQ(op_to_count["STFT"], 1); + + const InlinedHashSet empty_ep = {}; + auto stft_transformer = std::make_unique(empty_ep); + onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; + ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::move(stft_transformer), TransformerLevel::Level1)); + ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_)); + + // STFT node should NOT be decomposed — transformer skips invalid models + op_to_count = CountOpsInGraph(graph); + ASSERT_EQ(op_to_count["STFT"], 1); +} + +TEST_F(GraphTransformationTests, STFTDecomposition_NoWindowInput) { + constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "stft_no_window.onnx"; + std::shared_ptr model; + ASSERT_STATUS_OK(Model::Load(model_uri, model, nullptr, *logger_)); + Graph& graph = model->MainGraph(); + std::map op_to_count = CountOpsInGraph(graph); + ASSERT_EQ(op_to_count["STFT"], 1); + + const InlinedHashSet empty_ep = {}; + auto stft_transformer = std::make_unique(empty_ep); + onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; + ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::move(stft_transformer), TransformerLevel::Level1)); + // Should not crash (previously dereferenced nullptr window_recipient) + ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_)); + + // Valid windowless STFT should be successfully decomposed + op_to_count = CountOpsInGraph(graph); + ASSERT_EQ(op_to_count["STFT"], 0); +} + } // namespace test } // namespace onnxruntime diff --git a/onnxruntime/test/testdata/transform/stft_negative_frame_length.onnx b/onnxruntime/test/testdata/transform/stft_negative_frame_length.onnx new file mode 100644 index 0000000000000000000000000000000000000000..6cfb53f2fd431fe2c28bf1793a4d6d7713899003 GIT binary patch literal 204 zcmdKUM1ZnjQJ@S95HkZYlM@RUgMc6aNboGE literal 0 HcmV?d00001 From ad6d34c65ed4a5c1bde519f29bc0b95780520926 Mon Sep 17 00:00:00 2001 From: Ashwath Shankarnarayan Date: Fri, 15 May 2026 10:23:55 -0700 Subject: [PATCH 10/11] [QNN EP] ETW log level rule change (#27593) ### Description - ETW log level takes precedence over the QNN log level - This is now changed to only using ETW log level when it provides higher fidelity ### Motivation and Context - ETW log level takes precedence over the QNN log level - If ETW log level = Baisc ; Qnn log level = detailed, QNN EP still picks basic logging over detailed. --- .../core/providers/qnn/builder/qnn_backend_manager.cc | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc b/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc index 4a6692778da0b..5758ff3ad2847 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc +++ b/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc @@ -664,8 +664,9 @@ Status QnnBackendManager::ReleaseDevice() { Status QnnBackendManager::InitializeProfiling() { profiling_level_merge_ = profiling_level_; - // use profiling level from ETW if ETW is enabled - if (profiling_level_etw_ != ProfilingLevel::INVALID) { + // Only use ETW level if it provides higher fidelity + if (profiling_level_etw_ != ProfilingLevel::INVALID && + profiling_level_etw_ > profiling_level_) { profiling_level_merge_ = profiling_level_etw_; } From 45c0663223c7dde21c848e0e4d1c9f624189f2d0 Mon Sep 17 00:00:00 2001 From: Ryan VanderMeulen Date: Fri, 15 May 2026 15:36:27 -0400 Subject: [PATCH 11/11] Fix mlasi_sve.h preprocessor guards to allow clang compilation (#28507) ### Description The outer `#ifndef __clang__` in `mlasi_sve.h` (line 20 to line 679) was intended to wrap the GCC-specific `#pragma GCC` directives, but it also ends up hiding every SVE kernel declaration and the typedefs from clang. The `#ifdef __clang__` block that defines `MLAS_SVE_TARGET` for clang's per-function `__attribute__((target("...")))` syntax is unreachable for the same reason. This moves the closing `#endif` up to right after the GCC pragmas so only the pragmas are GCC-only, and the rest of the header (typedefs, kernel declarations, MLAS_SVE_TARGET) is visible to both compilers. ### Motivation and Context Without this, building MLAS with clang for aarch64 fails at platform.cpp - the `MLAS_USE_SVE` runtime-dispatch block references `MlasSveErfKernel`, `MlasSveLogisticKernel`, and friends, all undeclared from clang's perspective. Confirmed working with clang 20.1.8. --- onnxruntime/core/mlas/lib/sve/mlasi_sve.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxruntime/core/mlas/lib/sve/mlasi_sve.h b/onnxruntime/core/mlas/lib/sve/mlasi_sve.h index 922945c702119..3660bc108165c 100644 --- a/onnxruntime/core/mlas/lib/sve/mlasi_sve.h +++ b/onnxruntime/core/mlas/lib/sve/mlasi_sve.h @@ -20,6 +20,7 @@ Module Name: #ifndef __clang__ #pragma GCC push_options #pragma GCC target("arch=armv8.2-a+sve") +#endif // Use Clang-specific per-function attribute #ifdef __clang__ @@ -676,5 +677,4 @@ MlasSveCompareGreaterThan(svbool_t Pred, MLAS_SVFLOAT32 A, MLAS_SVFLOAT32 B) #pragma GCC pop_options #endif -#endif