Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions cmake/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -638,6 +638,7 @@ else()
check_cxx_compiler_flag(-Wcatch-value HAS_CATCH_VALUE)
check_cxx_compiler_flag(-Wclass-memaccess HAS_CLASS_MEMACCESS)
check_cxx_compiler_flag(-Wcharacter-conversion HAS_CHARACTER_CONVERSION)
check_cxx_compiler_flag(-Wno-error=character-conversion HAS_NO_ERROR_CHARACTER_CONVERSION)
check_cxx_compiler_flag(-Wdangling-reference HAS_DANGLING_REFERENCE)
check_cxx_compiler_flag(-Wdeprecated-anon-enum-enum-conversion HAS_DEPRECATED_ANON_ENUM_ENUM_CONVERSION)
check_cxx_compiler_flag(-Wdeprecated-builtins HAS_DEPRECATED_BUILTINS)
Expand Down
6 changes: 5 additions & 1 deletion cmake/onnxruntime_mlas.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -636,9 +636,13 @@ else()
enable_language(ASM)
check_cxx_source_compiles("
#ifdef _AIX
#include <sys/systemcfg.h>
#if !defined(POWER_10)
#define POWER_10 0x40000
#endif
#if !defined(POWER_10_ANDUP)
#define POWER_10_ANDUP (POWER_10)
#include <sys/systemcfg.h>
#endif
#define __power_10_andup() (_system_configuration.implementation & POWER_10_ANDUP)
int main() {
bool HasP10 = (__power_10_andup() && __power_mma_version() == MMA_V31);
Expand Down
15 changes: 9 additions & 6 deletions cmake/onnxruntime_unittests.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,13 @@ function(filter_test_srcs test_srcs_var)
endfunction()

set(disabled_warnings)

# onnxruntime_disable_gtest_character_conversion_as_error(<target_name>)
#
# Adds "-Wno-error=character-conversion" as a PRIVATE compile option for the
# C++ sources of <target_name>, so that -Wcharacter-conversion diagnostics are
# not promoted to errors for that target.
#
# The option is only added when HAS_NO_ERROR_CHARACTER_CONVERSION is true,
# i.e. when the compiler-flag probe for "-Wno-error=character-conversion"
# succeeded; otherwise the function does nothing.
# NOTE(review): judging by the name, the diagnostics being tolerated
# presumably originate from googletest headers - confirm.
#
# Example:
#   onnxruntime_disable_gtest_character_conversion_as_error(onnxruntime_test_utils)
function(onnxruntime_disable_gtest_character_conversion_as_error target_name)
  # Guard clause: leave the target untouched when the flag is unsupported.
  if (NOT HAS_NO_ERROR_CHARACTER_CONVERSION)
    return()
  endif()

  target_compile_options(
    ${target_name}
    PRIVATE
      "$<$<COMPILE_LANGUAGE:CXX>:-Wno-error=character-conversion>"
  )
endfunction()

function(AddTest)
cmake_parse_arguments(_UT "DYN" "TARGET" "LIBS;SOURCES;DEPENDS;TEST_ARGS" ${ARGN})
list(REMOVE_DUPLICATES _UT_SOURCES)
Expand Down Expand Up @@ -170,9 +177,7 @@ function(AddTest)
if (${HAS_NOERROR})
target_compile_options(${_UT_TARGET} PRIVATE "$<$<COMPILE_LANGUAGE:CXX>:-Wno-error=uninitialized>")
endif()
if (${HAS_CHARACTER_CONVERSION})
target_compile_options(${_UT_TARGET} PRIVATE "$<$<COMPILE_LANGUAGE:CXX>:-Wno-error=character-conversion>")
endif()
onnxruntime_disable_gtest_character_conversion_as_error(${_UT_TARGET})
endif()

set(TEST_ARGS ${_UT_TEST_ARGS})
Expand Down Expand Up @@ -847,9 +852,7 @@ if(MSVC)
"$<$<NOT:$<COMPILE_LANGUAGE:CUDA>>:/wd6326>")
else()
target_include_directories(onnxruntime_test_utils PRIVATE ${CMAKE_CURRENT_BINARY_DIR} ${ONNXRUNTIME_ROOT})
if (HAS_CHARACTER_CONVERSION)
target_compile_options(onnxruntime_test_utils PRIVATE "$<$<COMPILE_LANGUAGE:CXX>:-Wno-error=character-conversion>")
endif()
onnxruntime_disable_gtest_character_conversion_as_error(onnxruntime_test_utils)
endif()
if (onnxruntime_USE_NCCL)
target_include_directories(onnxruntime_test_utils PRIVATE ${NCCL_INCLUDE_DIRS})
Expand Down
13 changes: 6 additions & 7 deletions docs/ContribOperators.md
Original file line number Diff line number Diff line change
Expand Up @@ -3569,7 +3569,6 @@ This version of the operator has been available since version 1 of the 'com.micr
### <a name="com.microsoft.NhwcFusedConv"></a><a name="com.microsoft.nhwcfusedconv">**com.microsoft.NhwcFusedConv**</a>

NhwcFusedConv is a Conv operator with optional activation and add operators fused in.
Only has fp16 implementation as of 2023/04/15.

#### Version

Expand Down Expand Up @@ -3600,26 +3599,26 @@ This version of the operator has been available since version 1 of the 'com.micr

<dl>
<dt><tt>X</tt> : T</dt>
<dd></dd>
<dd>Input activation tensor in channels-last layout. For 2D convolution this is [N, H, W, C], where N is batch size, H/W are spatial dimensions, and C is the number of input channels.</dd>
<dt><tt>W</tt> : T</dt>
<dd></dd>
<dd>Convolution weight tensor in the standard ONNX Conv filter layout [M, C/group, kH, kW], where M is the number of output channels.</dd>
<dt><tt>B</tt> (optional) : T</dt>
<dd></dd>
<dd>Optional 1D bias tensor of shape [M].</dd>
<dt><tt>Z</tt> (optional) : T</dt>
<dd>Tensor to be added to the output, must be the same shape and format as the output tensor.</dd>
<dd>Optional residual/add tensor in the same channels-last layout and shape as the output tensor Y. For 2D convolution this is [N, out_H, out_W, M].</dd>
</dl>

#### Outputs

<dl>
<dt><tt>Y</tt> : T</dt>
<dd></dd>
<dd>Output tensor in channels-last layout. For 2D convolution this is [N, out_H, out_W, M], where M is the number of output channels.</dd>
</dl>

#### Type Constraints

<dl>
<dt><tt>T</tt> : tensor(float16)</dt>
<dt><tt>T</tt> : tensor(float16), tensor(float)</dt>
<dd>Constrain input and output types to float tensors</dd>
</dl>

Expand Down
3 changes: 2 additions & 1 deletion docs/OperatorKernels.md
Original file line number Diff line number Diff line change
Expand Up @@ -810,7 +810,8 @@ The **OpSet Version** column uses the following notation:
|||[9, 12]|**T1** = tensor(double), tensor(float), tensor(float16)<br/> **T2** = tensor(bool)|
|LRN|*in* X:**T**<br> *out* Y:**T**|13+|**T** = tensor(double), tensor(float), tensor(float16)|
|||[1, 12]|**T** = tensor(double), tensor(float), tensor(float16)|
|LSTM|*in* X:**T**<br> *in* W:**T**<br> *in* R:**T**<br> *in* B:**T**<br> *in* sequence_lens:**T1**<br> *in* initial_h:**T**<br> *in* initial_c:**T**<br> *in* P:**T**<br> *out* Y:**T**<br> *out* Y_h:**T**<br> *out* Y_c:**T**|14+|**T** = tensor(double), tensor(float), tensor(float16)<br/> **T1** = tensor(int32)|
|LSTM|*in* X:**T**<br> *in* W:**T**<br> *in* R:**T**<br> *in* B:**T**<br> *in* sequence_lens:**T1**<br> *in* initial_h:**T**<br> *in* initial_c:**T**<br> *in* P:**T**<br> *out* Y:**T**<br> *out* Y_h:**T**<br> *out* Y_c:**T**|22+|**T** = tensor(double), tensor(float), tensor(float16)<br/> **T1** = tensor(int32)|
|||[14, 21]|**T** = tensor(double), tensor(float), tensor(float16)<br/> **T1** = tensor(int32)|
|||[7, 13]|**T** = tensor(double), tensor(float), tensor(float16)<br/> **T1** = tensor(int32)|
|LayerNormalization|*in* X:**T**<br> *in* Scale:**T**<br> *in* B:**T**<br> *out* Y:**T**<br> *out* Mean:**U**<br> *out* InvStdDev:**U**<br><br>or<br><br>*in* X:**T**<br> *in* Scale:**V**<br> *in* B:**V**<br> *out* Y:**V**<br> *out* Mean:**U**<br> *out* InvStdDev:**U**|17+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)<br/> **U** = tensor(float)|
|||[1, 16]|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)<br/> **U** = tensor(double), tensor(float)<br/> **V** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)|
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,18 @@ static const char* const kOrtSessionOptionsMemoryOptimizerProbeConfig = "optimiz
// Default is an empty string which means no optimizers are disabled.
static const char* const kOrtSessionOptionsDisableSpecifiedOptimizers = "optimization.disable_specified_optimizers";

// Maximum total output size in bytes that the constant folding optimizer is allowed to produce per node.
// Prevents malicious models from causing excessive memory allocation during optimization.
// If the estimated or actual output size of a constant-foldable node exceeds this limit, the node will
// not be constant folded and will instead be executed at runtime.
//
// Option values:
// - A positive integer (as string): Maximum allowed output size in bytes per constant-folded node.
// Default is "1073741824" (1 GiB, i.e. 2^30 bytes).
// - "0": Disable the size limit (not recommended for untrusted models).
static const char* const kOrtSessionOptionsConstantFoldingMaxOutputSizeInBytes =
"optimization.constant_folding_max_output_size_in_bytes";

// It controls whether to run graph optimizations in loop or not.
//
// "0": disable. Graph Optimization Loop is disabled.
Expand Down
47 changes: 30 additions & 17 deletions onnxruntime/contrib_ops/cpu/bert/attention_base.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
#include <array>
#include <vector>
#include "core/common/common.h"
#include "core/common/narrow.h"
#include "core/providers/cpu/mlas_backend_kernel_selector_config_utils.h"
#ifndef SHARED_PROVIDER
#include "core/framework/op_kernel.h"
Expand Down Expand Up @@ -57,11 +58,11 @@ class AttentionBase {
AttentionBase(const KernelInfoType& info, bool require_same_hidden_size) {
int64_t num_heads = 0;
ORT_ENFORCE(info.GetAttr("num_heads", &num_heads).IsOK() && num_heads > 0);
num_heads_ = static_cast<int>(num_heads);
num_heads_ = narrow<int>(num_heads);

is_unidirectional_ = info.template GetAttrOrDefault<int64_t>("unidirectional", 0) == 1;
do_rotary_ = info.template GetAttrOrDefault<int64_t>("do_rotary", 0) == 1;
rotary_embedding_ = static_cast<int>(info.template GetAttrOrDefault<int64_t>("rotary_embedding_dim", 0));
rotary_embedding_ = narrow<int>(info.template GetAttrOrDefault<int64_t>("rotary_embedding_dim", 0));
mask_filter_value_ = info.template GetAttrOrDefault<float>("mask_filter_value", -10000.0f);
scale_ = info.template GetAttrOrDefault<float>("scale", 0.0f);
if (!info.template GetAttrs<int64_t>("qkv_hidden_sizes", qkv_hidden_sizes_).IsOK()) {
Expand Down Expand Up @@ -222,6 +223,14 @@ inline Status AttentionBase::CheckInputs(const TensorShape& input_shape,
"Input 'bias' dimension 0 should have same length as dimension 1 of input 'weights'");
}

// Q, K, V are packed along bias_dims[0]. When their hidden sizes are required to be equal,
// bias_dims[0] == 3 * hidden_size must be a multiple of 3.
if (require_same_hidden_size_ && bias_dims[0] % 3 != 0) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"Input 'bias' dimension 0 (", bias_dims[0],
") must be a multiple of 3 (Q, K, V are packed and have equal hidden sizes).");
}

int64_t q_hidden_size = bias_dims[0] / static_cast<int64_t>(3);
int64_t k_hidden_size = q_hidden_size;
int64_t v_hidden_size = k_hidden_size;
Expand All @@ -241,6 +250,10 @@ inline Status AttentionBase::CheckInputs(const TensorShape& input_shape,
q_hidden_size = qkv_hidden_sizes_[0];
k_hidden_size = qkv_hidden_sizes_[1];
v_hidden_size = qkv_hidden_sizes_[2];
} else if (q_hidden_size % num_heads_ != 0) {
// Match the error message produced by the qkv_hidden_sizes path above.
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"hidden_size should be divisible by num_heads:", q_hidden_size);
}

int64_t kv_sequence_length = sequence_length;
Expand Down Expand Up @@ -282,14 +295,14 @@ inline Status AttentionBase::CheckInputs(const TensorShape& input_shape,
"Inputs 'past' dimension 1 shall have same length as dimension 0 of input 0");
}

if (static_cast<int>(past_dims[2]) != num_heads_) {
if (past_dims[2] != num_heads_) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"Inputs 'past' dimension 2 shall have length of num_heads", num_heads_);
}

if (static_cast<int>(past_dims[4]) != k_hidden_size / num_heads_) {
if (past_dims[4] != k_hidden_size / num_heads_) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"Inputs 'past' dimension 2 shall have length of ", k_hidden_size / num_heads_);
"Inputs 'past' dimension 4 shall have length of ", k_hidden_size / num_heads_);
}

if (!past_present_share_buffer_) {
Expand Down Expand Up @@ -348,17 +361,17 @@ inline Status AttentionBase::CheckInputs(const TensorShape& input_shape,

if (parameters != nullptr) {
AttentionParameters* output_parameters = reinterpret_cast<AttentionParameters*>(parameters);
output_parameters->batch_size = static_cast<int>(batch_size);
output_parameters->sequence_length = static_cast<int>(sequence_length);
output_parameters->past_sequence_length = static_cast<int>(past_sequence_length);
output_parameters->kv_sequence_length = static_cast<int>(kv_sequence_length);
output_parameters->total_sequence_length = static_cast<int>(total_sequence_length);
output_parameters->max_sequence_length = static_cast<int>(max_sequence_length);
output_parameters->input_hidden_size = static_cast<int>(input_hidden_size);
output_parameters->hidden_size = static_cast<int>(q_hidden_size);
output_parameters->v_hidden_size = static_cast<int>(v_hidden_size);
output_parameters->head_size = static_cast<int>(q_hidden_size) / num_heads_;
output_parameters->v_head_size = static_cast<int>(v_hidden_size) / num_heads_;
output_parameters->batch_size = narrow<int>(batch_size);
output_parameters->sequence_length = narrow<int>(sequence_length);
output_parameters->past_sequence_length = narrow<int>(past_sequence_length);
output_parameters->kv_sequence_length = narrow<int>(kv_sequence_length);
output_parameters->total_sequence_length = narrow<int>(total_sequence_length);
output_parameters->max_sequence_length = narrow<int>(max_sequence_length);
output_parameters->input_hidden_size = narrow<int>(input_hidden_size);
output_parameters->hidden_size = narrow<int>(q_hidden_size);
output_parameters->v_hidden_size = narrow<int>(v_hidden_size);
output_parameters->head_size = narrow<int>(q_hidden_size) / num_heads_;
output_parameters->v_head_size = narrow<int>(v_hidden_size) / num_heads_;
output_parameters->num_heads = num_heads_;
output_parameters->is_unidirectional = is_unidirectional_;
output_parameters->past_present_share_buffer = (past_present_share_buffer_ != 0 && past != nullptr);
Expand Down Expand Up @@ -398,7 +411,7 @@ inline Tensor* AttentionBase::GetPresent(TOpKernelContext* context,
int head_size,
int kv_sequence_length,
int& past_sequence_length) const {
past_sequence_length = (nullptr != past) ? static_cast<int>(past->Shape().GetDims()[3]) : 0;
past_sequence_length = (nullptr != past) ? narrow<int>(past->Shape().GetDims()[3]) : 0;
std::array<int64_t, 5> present_dims{2, batch_size, num_heads_,
static_cast<int64_t>(kv_sequence_length) + past_sequence_length, head_size};

Expand Down
6 changes: 6 additions & 0 deletions onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,9 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1,
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, EmbedLayerNormalization);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, ExpandDims);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, FusedConv);
#ifdef USE_KLEIDIAI
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, NhwcFusedConv);
#endif
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, FusedGemm);
#if !defined(DISABLE_GENERATION_OPS)
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, GreedySearch);
Expand Down Expand Up @@ -313,6 +316,9 @@ Status RegisterCpuContribKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, EmbedLayerNormalization)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, ExpandDims)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, FusedConv)>,
#ifdef USE_KLEIDIAI
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, NhwcFusedConv)>,
#endif
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, FusedGemm)>,
#if !defined(DISABLE_GENERATION_OPS)
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, GreedySearch)>,
Expand Down
13 changes: 13 additions & 0 deletions onnxruntime/contrib_ops/cpu/fused_conv.cc
Original file line number Diff line number Diff line change
Expand Up @@ -26,5 +26,18 @@ ONNX_CPU_OPERATOR_TYPED_MS_KERNEL(
.TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
FusedConvFloat);

// Registers FusedConvFloat as the float-typed, version 1 kernel for the NhwcFusedConv contrib op.
// The registration is compiled only when USE_KLEIDIAI is defined.
// NOTE(review): this reuses the kernel class name used for the float FusedConv registration;
// confirm that FusedConvFloat handles the channels-last (NHWC) layout when invoked for this op.
#ifdef USE_KLEIDIAI
ONNX_CPU_OPERATOR_TYPED_MS_KERNEL(
NhwcFusedConv,
1,
float,
KernelDefBuilder()
// Allow the optional "sum" input (index 3) to be reused as the output buffer (index 0),
// consistent with the FusedConv kernel registration.
.MayInplace(3, 0)
.TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
FusedConvFloat);
#endif

} // namespace contrib
} // namespace onnxruntime
Loading
Loading