Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 14 additions & 5 deletions src/backend_model.cc
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright 2020-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// Copyright 2020-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions
Expand Down Expand Up @@ -825,10 +825,19 @@ TritonModel::SetConfiguredScheduler(
for (const auto& input : config_.input()) {
if (input.is_shape_tensor()) {
enforce_equal_shape_tensors.insert({input.name(), true});
} else if (
!input.allow_ragged_batch() &&
(triton::common::GetElementCount(input) == -1)) {
enforce_equal_shape_tensors.insert({input.name(), false});
} else {
auto element_count = triton::common::GetElementCount(input);
if (element_count == triton::common::OVERFLOW_SIZE) {
Comment thread
whoisj marked this conversation as resolved.
Outdated
return Status(
Status::Code::INVALID_ARG,
"input '" + input.name() +
"' causes total element count to exceed maximum size of " +
std::to_string(INT64_MAX));
}
if (!input.allow_ragged_batch() &&
(element_count == triton::common::WILDCARD_SIZE)) {
enforce_equal_shape_tensors.insert({input.name(), false});
}
}
}

Expand Down
27 changes: 20 additions & 7 deletions src/backend_model_instance.cc
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright 2020-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// Copyright 2020-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions
Expand Down Expand Up @@ -374,14 +374,19 @@ TritonModelInstance::GenerateWarmupData()
for (const auto& input_meta : warmup_setting.inputs()) {
auto element_count =
triton::common::GetElementCount(input_meta.second.dims());
if (element_count == -1) {
if (element_count == triton::common::WILDCARD_SIZE) {
return Status(
Status::Code::INVALID_ARG,
"warmup setting expects all variable-size dimensions are specified "
"for input '" +
input_meta.first + "'");
} else if (element_count == triton::common::OVERFLOW_SIZE) {
return Status(
Status::Code::INVALID_ARG,
"warmup setting for input '" + input_meta.first +
"' causes total element count to exceed maximum size of " +
std::to_string(INT64_MAX));
}

int64_t batch_byte_size =
element_count *
triton::common::GetDataTypeByteSize(input_meta.second.data_type());
Comment thread
yinggeh marked this conversation as resolved.
Outdated
Expand Down Expand Up @@ -445,12 +450,20 @@ TritonModelInstance::GenerateWarmupData()
for (const auto& input_meta : warmup_setting.inputs()) {
auto batch1_element_count =
triton::common::GetElementCount(input_meta.second.dims());
auto batch_byte_size =
batch1_element_count *
auto dtype_byte_size =
triton::common::GetDataTypeByteSize(input_meta.second.data_type());
if (batch_byte_size == 0) {
batch_byte_size = batch1_element_count * sizeof(int32_t);
dtype_byte_size =
dtype_byte_size == 0 ? sizeof(int32_t) : dtype_byte_size;
if (batch1_element_count == triton::common::OVERFLOW_SIZE ||
(batch1_element_count >
INT64_MAX / static_cast<int64_t>(dtype_byte_size))) {
return Status(
Status::Code::INVALID_ARG,
"warmup setting for input '" + input_meta.first +
"' causes total element count to exceed maximum size of " +
std::to_string(INT64_MAX));
Comment thread
yinggeh marked this conversation as resolved.
Outdated
}
auto batch_byte_size = batch1_element_count * dtype_byte_size;
Comment thread
yinggeh marked this conversation as resolved.
Outdated

const char* allocated_ptr;
switch (input_meta.second.input_data_type_case()) {
Expand Down
21 changes: 18 additions & 3 deletions src/infer_request.cc
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright 2020-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// Copyright 2020-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions
Expand Down Expand Up @@ -1173,14 +1173,14 @@ InferenceRequest::Normalize()
if (input_config->has_reshape()) {
std::deque<int64_t> variable_size_values;
for (int64_t idx = 0; idx < input_config->dims_size(); idx++) {
if (input_config->dims(idx) == -1) {
if (input_config->dims(idx) == triton::common::WILDCARD_DIM) {
variable_size_values.push_back((*shape)[idx]);
}
}

shape->clear();
for (const auto& dim : input_config->reshape().shape()) {
if (dim == -1) {
if (dim == triton::common::WILDCARD_DIM) {
shape->push_back(variable_size_values.front());
variable_size_values.pop_front();
} else {
Expand Down Expand Up @@ -1222,6 +1222,13 @@ InferenceRequest::Normalize()
int64_t expected_byte_size =
triton::common::GetByteSize(data_type, input_dims);
const size_t& byte_size = input.Data()->TotalByteSize();
if (expected_byte_size == triton::common::OVERFLOW_SIZE) {
return Status(
Status::Code::INVALID_ARG,
LogRequest() + "input '" + input_name +
"' causes total byte size to exceed maximum size of " +
Comment thread
yinggeh marked this conversation as resolved.
Outdated
std::to_string(INT64_MAX));
}
if ((byte_size > LLONG_MAX) ||
(static_cast<int64_t>(byte_size) != expected_byte_size)) {
return Status(
Expand Down Expand Up @@ -1322,6 +1329,14 @@ InferenceRequest::ValidateBytesInputs(
size_t remaining_buffer_size = 0;
int64_t buffer_memory_id;

if (element_count == triton::common::OVERFLOW_SIZE) {
return Status(
Status::Code::INVALID_ARG,
LogRequest() + "input '" + input_name +
"' causes total element count to exceed maximum size of " +
Comment thread
yinggeh marked this conversation as resolved.
Outdated
std::to_string(INT64_MAX));
}

// Validate elements until all buffers have been fully processed.
while (remaining_buffer_size || buffer_next_idx < buffer_count) {
// Get the next buffer if not currently processing one.
Expand Down
17 changes: 13 additions & 4 deletions src/model_config_utils.cc
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright 2018-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// Copyright 2018-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions
Expand Down Expand Up @@ -357,6 +357,15 @@ ValidateIOShape(
const int64_t reshape_size =
triton::common::GetElementCount(io.reshape().shape());

if (dims_size == triton::common::OVERFLOW_SIZE ||
reshape_size == triton::common::OVERFLOW_SIZE) {
return Status(
Status::Code::INVALID_ARG,
message_prefix_with_name +
"causes total element count to exceed maximum size of " +
std::to_string(INT64_MAX));
}

// dims and reshape must both have same element count
// or both have variable-size dimension.
// Special case for empty reshape... expect dims to have element
Expand All @@ -372,12 +381,12 @@ ValidateIOShape(
// each pair of the trunks separated by variable-size dimension has
// the same element count. For instance, from [2, 4, -1, 6] to [8, -1, 1, 6]
// is valid reshape as 2 * 4 = 8 and 6 = 1 * 6.
if (dims_size == -1) {
if (dims_size == triton::common::WILDCARD_DIM) {
Comment thread
yinggeh marked this conversation as resolved.
Outdated
std::vector<int64_t> dim_element_cnts;
std::vector<int64_t> reshape_element_cnts;
int64_t current_cnt = 1;
for (const auto& dim : io.dims()) {
if (dim != -1) {
if (dim != triton::common::WILDCARD_DIM) {
current_cnt *= dim;
} else {
dim_element_cnts.push_back(current_cnt);
Expand All @@ -388,7 +397,7 @@ ValidateIOShape(

current_cnt = 1;
for (const auto& dim : io.reshape().shape()) {
if (dim != -1) {
if (dim != triton::common::WILDCARD_DIM) {
current_cnt *= dim;
} else {
reshape_element_cnts.push_back(current_cnt);
Expand Down
19 changes: 16 additions & 3 deletions src/sequence_batch_scheduler/sequence_batch_scheduler.cc
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright 2018-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// Copyright 2018-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions
Expand Down Expand Up @@ -384,13 +384,14 @@ SequenceBatchScheduler::GenerateInitialStateData(
auto state_dim = state.dims().begin();
for (; initial_state_dim != initial_state.dims().end();
initial_state_dim++, state_dim++) {
if (*initial_state_dim == -1) {
if (*initial_state_dim == triton::common::WILDCARD_DIM) {
return Status(
Status::Code::INVALID_ARG,
std::string("'initial_state' field for state input name '") +
state.input_name() + "' contains variable dimensions.");
} else {
if (*state_dim != -1 && *initial_state_dim != *state_dim) {
if (*state_dim != triton::common::WILDCARD_DIM &&
*initial_state_dim != *state_dim) {
return Status(
Status::Code::INVALID_ARG,
std::string("'initial_state' dim for input name '") +
Expand All @@ -409,6 +410,18 @@ SequenceBatchScheduler::GenerateInitialStateData(
triton::common::GetDataTypeByteSize(initial_state.data_type());
size_t total_byte_size = element_count * dtype_byte_size;

if (element_count == triton::common::OVERFLOW_SIZE ||
(dtype_byte_size != 0 &&
(element_count > INT64_MAX / (static_cast<int64_t>(dtype_byte_size)))) ||
(total_byte_size > INT64_MAX / sizeof(int32_t))) {
return Status(
Comment thread
yinggeh marked this conversation as resolved.
Outdated
Status::Code::INVALID_ARG,
std::string("'initial_state' field for state input name '") +
state.input_name() +
"' causes total element count to exceed maximum size of " +
std::to_string(INT64_MAX));
}

// Custom handling for TYPE_BYTES
if (dtype_byte_size == 0) {
total_byte_size = sizeof(int32_t) * element_count;
Expand Down
19 changes: 17 additions & 2 deletions src/sequence_state.cc
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright 2021-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// Copyright 2021-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions
Expand Down Expand Up @@ -159,7 +159,7 @@ SequenceStates::Initialize(

// Convert the variable dimensions to 1 for the first request.
for (auto& dim : state_config.dims()) {
if (dim == -1) {
if (dim == triton::common::WILDCARD_DIM) {
dims.push_back(1);
} else {
dims.push_back(dim);
Expand Down Expand Up @@ -212,12 +212,27 @@ SequenceStates::Initialize(
size_t state_size;
if (state.second.data_type() == inference::DataType::TYPE_STRING) {
auto element_count = triton::common::GetElementCount(dims);
if (element_count == triton::common::OVERFLOW_SIZE ||
(element_count > INT64_MAX / 4)) {
Comment thread
yinggeh marked this conversation as resolved.
Outdated
Comment thread
whoisj marked this conversation as resolved.
Outdated
return Status(
Status::Code::INVALID_ARG,
"state '" + state_config.input_name() +
"' causes total element count to exceed maximum size of " +
Comment thread
yinggeh marked this conversation as resolved.
Outdated
std::to_string(INT64_MAX));
}
// Total number of bytes required is equal to the element count
// multiplied by 4.
state_size = 4 * element_count;
Comment thread
yinggeh marked this conversation as resolved.
Outdated
} else {
state_size =
triton::common::GetByteSize(state.second.data_type(), dims);
if (state_size == static_cast<size_t>(triton::common::OVERFLOW_SIZE)) {
return Status(
Status::Code::INVALID_ARG,
"state '" + state_config.input_name() +
Comment thread
yinggeh marked this conversation as resolved.
Outdated
"' causes total byte size to exceed maximum size of " +
std::to_string(INT64_MAX));
}
}
if (use_growable_memory) {
std::unique_ptr<GrowableMemory> growable_memory;
Expand Down
Loading