1 change: 1 addition & 0 deletions CMakeLists.txt
@@ -503,6 +503,7 @@ target_link_libraries(
triton-core-backendapi # from repo-core
triton-core-serverstub # from repo-core
triton-backend-utils # from repo-backend
${CMAKE_DL_LIBS} # dlopen/dlsym/dlclose for the MODEL_INIT_LIBRARY hook

Contributor

The CMAKE_ prefix is reserved by CMake; I'm afraid its use here may cause some confusion. See the examples below.
http://cmake.org/cmake/help/latest/manual/cmake-variables.7.html

Also, the change description explains the following, which means those libraries are vital to preserving the functionality, doesn't it?

An HSTU AOTI package (model.pt2) calls nve_ops::embedding_lookup(keys, layer_id),
which looks the embedding table up by layer_id in a process-global
NVELayerRegistry. The embedding weights do not live inside model.pt2 —
they are separate files next to it (<model_dir>/metadata.json +
<model_dir>/weights/*.nve). So loading the .pt2 does not load them, and the
first inference request fails

Could you please share the origin of the missing libraries?

Contributor

@z52527 please follow up with @mc-nv's requests. This is blocking merge.

Author

Sorry for the confusion. The PR description was stale; it still described the earlier "link NVE into the backend" version. I've updated it.
In the current revision the backend links no NVE or vendor libraries. There are no missing libraries to source. The only CMake change is target_link_libraries(... ${CMAKE_DL_LIBS}). CMAKE_DL_LIBS is a built-in CMake variable for the platform's dynamic-loading library (libdl), needed for the dlopen/dlsym/dlclose the hook uses.

${TRITON_PYTORCH_LDFLAGS}
${TRITON_PYTORCH_LIBS}
$<$<BOOL:${TRITON_PYTORCH_OPTREE}>:optree>
50 changes: 50 additions & 0 deletions README.md
@@ -305,6 +305,56 @@ output: [
> Support for batch sizes greater than 1 and for sequence batching for AOT Inductor compiled models has not been completed.
> These Triton Server features are currently unavailable for PyTorch models compiled using AOT Inductor and packaged as a PT2 model archive.

### Model Initialization Hook (`MODEL_INIT_LIBRARY`)

Some models need to run one-time **native** initialization at load time — before
the model package is loaded — that cannot be expressed inside `model.pt2`. A
typical case is a model whose custom operators resolve large weights from a
process-global registry at execute time, where those weights live outside the
package (e.g. in sidecar files next to `model.pt2`) and must be read into memory
and registered first.

The PT2 path supports an optional, per-model hook for this. When a model's
`config.pbtxt` sets the `MODEL_INIT_LIBRARY` parameter, the backend `dlopen()`s
that shared library at model load and calls its initialization entry point. The
backend does **not** link against the library and knows nothing about what it
does; it only loads it, calls a fixed entry point, and releases it on unload.

```
parameters {
key: "MODEL_INIT_LIBRARY"
value: { string_value: "/abs/path/to/libmy_model_init.so" }
}
```

The library must export the following C entry point, and may optionally export a
matching finalizer:

```cpp
// Called once at model load, before the model package is loaded.
// model_dir : the versioned model directory (the one containing model.pt2)
// device_index : the GPU ordinal for the instance, or -1 for CPU
// Returns an opaque handle that is passed back to the finalizer on unload, or
// NULL if there is nothing to keep.
extern "C" void* triton_pytorch_model_init(const char* model_dir, int device_index);

// Optional. Called once on model unload with the handle returned above.
extern "C" void triton_pytorch_model_fini(void* state);
```
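
As an illustration only, a hook library satisfying this interface might look like the sketch below. The `example` namespace, `HookState`, and the `g_registered` map are hypothetical stand-ins for whatever process-global registry the model's custom operators actually consult; a real hook would read the sidecar files under `model_dir` at the point the comment indicates.

```cpp
#include <cassert>
#include <map>
#include <mutex>
#include <string>

namespace example {

// Hypothetical per-model state handed back to the finalizer on unload.
struct HookState {
  std::string model_dir;
  int device_index;
};

// Hypothetical process-global registry: model_dir -> device_index.
std::mutex g_mu;
std::map<std::string, int> g_registered;

}  // namespace example

extern "C" void*
triton_pytorch_model_init(const char* model_dir, int device_index)
{
  // The interface treats NULL as "nothing to keep", so this sketch returns
  // NULL when it has no state to hand back.
  if (model_dir == nullptr) {
    return nullptr;
  }
  assert(device_index >= -1);  // -1 means CPU, otherwise a GPU ordinal

  // A real hook would load the external weights found under model_dir here.
  std::lock_guard<std::mutex> lock(example::g_mu);
  example::g_registered[model_dir] = device_index;
  return new example::HookState{model_dir, device_index};
}

extern "C" void
triton_pytorch_model_fini(void* state)
{
  auto* hook_state = static_cast<example::HookState*>(state);
  if (hook_state == nullptr) {
    return;
  }
  {
    std::lock_guard<std::mutex> lock(example::g_mu);
    example::g_registered.erase(hook_state->model_dir);
  }
  delete hook_state;
}
```

Such a file would be built as a shared object (for example with `g++ -shared -fPIC`) and its absolute path given as the `MODEL_INIT_LIBRARY` value.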

If `MODEL_INIT_LIBRARY` is unset (the default), no library is loaded and this is
a complete no-op. If it is set but the library cannot be `dlopen()`ed or does not
export `triton_pytorch_model_init`, model load fails with an error.
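
The three outcomes described above (unset, load failure, missing entry point) can be sketched in isolation as follows. This is a simplified stand-in, not the backend's code: `HookLoadResult` and `LoadInitHook` are hypothetical names, and errors are returned as strings here, whereas the backend fails the model load instead.

```cpp
#include <dlfcn.h>

#include <cassert>
#include <string>

// Hypothetical result type for this sketch.
struct HookLoadResult {
  void* dl_handle{nullptr};  // handle from dlopen(), or nullptr
  void* state{nullptr};      // opaque value returned by the init entry point
  std::string error;         // non-empty when loading failed
};

using ModelInitFn = void* (*)(const char*, int);

HookLoadResult
LoadInitHook(
    const std::string& library, const std::string& model_dir, int device_index)
{
  HookLoadResult result;
  if (library.empty()) {
    return result;  // parameter unset: complete no-op
  }
  assert(!model_dir.empty());

  void* handle = dlopen(library.c_str(), RTLD_NOW | RTLD_LOCAL);
  if (handle == nullptr) {
    const char* err = dlerror();
    result.error = std::string("dlopen failed: ") +
                   (err != nullptr ? err : "unknown error");
    return result;
  }

  dlerror();  // clear any stale error before dlsym()
  auto init_fn = reinterpret_cast<ModelInitFn>(
      dlsym(handle, "triton_pytorch_model_init"));
  const char* sym_err = dlerror();
  if (init_fn == nullptr || sym_err != nullptr) {
    // Copy the message before dlclose() so it cannot be invalidated.
    result.error = std::string("missing entry point: ") +
                   (sym_err != nullptr ? sym_err : "not found");
    dlclose(handle);
    return result;
  }

  result.state = init_fn(model_dir.c_str(), device_index);
  result.dl_handle = handle;
  return result;
}
```

On platforms where `dlopen()` lives in a separate library, the sketch needs to be linked against it (typically `-ldl`).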

> [!WARNING]
> `MODEL_INIT_LIBRARY` makes the backend load and execute native code from the
> path given in the model configuration, in the server process and with its
> privileges. Only point it at libraries you trust. This is the same trust level
> already required for the model repository itself — a PT2 package contains
> compiled code that runs when the model is loaded — so it does not widen the
> trust boundary, but the model repository must remain a trusted,
> operator-controlled source.
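
Operators who want an extra guard could validate the configured path before deploying a model. The sketch below is a hypothetical, operator-side helper (`IsUnderTrustedRoot` is not part of the backend) that accepts only absolute paths resolving to a location strictly inside a chosen trusted directory. It is a lexical and symlink-resolving check at one point in time, so it complements file permissions on the model repository rather than replacing them.

```cpp
#include <cassert>
#include <filesystem>
#include <system_error>

namespace fs = std::filesystem;

// Hypothetical helper: true only if `library` is an absolute path that
// resolves to somewhere strictly inside `trusted_root`.
bool
IsUnderTrustedRoot(const fs::path& library, const fs::path& trusted_root)
{
  assert(trusted_root.is_absolute());  // operator-supplied, expected absolute
  if (!library.is_absolute()) {
    return false;
  }

  // weakly_canonical() resolves symlinks and ".." for the part of the path
  // that exists and normalizes the remainder lexically.
  std::error_code ec;
  const fs::path lib = fs::weakly_canonical(library, ec);
  if (ec) {
    return false;
  }
  const fs::path root = fs::weakly_canonical(trusted_root, ec);
  if (ec) {
    return false;
  }

  // An empty result means the paths are unrelated, "." means they are the
  // same path, and a leading ".." means the library escapes the root.
  const fs::path rel = lib.lexically_relative(root);
  if (rel.empty() || rel == fs::path(".")) {
    return false;
  }
  return *rel.begin() != fs::path("..");
}
```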

### PyTorch 2.0 Models

PyTorch 2.0 features are available.
96 changes: 95 additions & 1 deletion src/pt2/model_state.cc
@@ -26,6 +26,8 @@

#include "model_state.hh"

#include <dlfcn.h>

#include <mutex>

#include "../libtorch.hh"
@@ -349,6 +351,13 @@ ModelState::LoadModel(
<< local_file_path << "\" is unreachable or inaccessible.");
}

// Optional per-model native init hook (MODEL_INIT_LIBRARY parameter). Lets a
// model run one-time native setup before its package is loaded -- e.g.
// loading external embedding weights into a process-global registry --
// without the backend linking against or knowing anything about that library.
// No-op when the parameter is unset.
MaybeRunModelInitHook(repository_path, repository_version, device);

std::pair<bool, int> device_pair{false, 0};
if (weight_sharing_enabled_) {
device_pair = std::make_pair(!device.is_cpu(), device.index());
@@ -555,7 +564,7 @@ void
ModelState::ParseParameters()
{
TritonJsonValue parameters;
- if (!ModelConfig().Find("parameters", &parameters)) {
+ if (ModelConfig().Find("parameters", &parameters)) {
bool disable_optimized_execution{false};
if (auto err = ParseParameter(
parameters, "DISABLE_OPTIMIZED_EXECUTION",
@@ -696,6 +705,29 @@ ModelState::ParseParameters()
}
TRITONSERVER_ErrorDelete(err);
}
{
TritonJsonValue init_lib_param;
if (parameters.Find("MODEL_INIT_LIBRARY", &init_lib_param)) {
if (auto err = init_lib_param.MemberAsString(
"string_value", &model_init_library_)) {
DEBUG_TRACE_ERROR(
"{ model: \"" << Name() << "\", error: \""
<< TRITONSERVER_ErrorMessage(err) << "\" }");
THROW_TRITON_EXCEPTION(
TRITONSERVER_ErrorCode(err),
"Failed to parse 'MODEL_INIT_LIBRARY' parameter for model \""
<< Name() << "\": " << TRITONSERVER_ErrorMessage(err) << ".");
}
}
}

if (!model_init_library_.empty()) {
TRITON_LOG_INFO(
"Model-init library is \"" << model_init_library_
<< "\" for model instance \"" << Name()
<< "\".");
}

DEBUG_TRACE_INFO(
"{ disable_optimized_execution: "
<< (disable_optimized_execution ? "true" : "false")
@@ -720,6 +752,68 @@ ModelState::ParseParameters()
}
}

void
ModelState::MaybeRunModelInitHook(
const std::string& repository_path, const std::string& repository_version,
const torch::Device& device)
{
// No hook configured, or it already ran for this model.
if (model_init_library_.empty() || model_init_dl_handle_ != nullptr) {
return;
}

const std::string package_dir =
triton::backend::JoinPath({repository_path, repository_version});
// GPU ordinal; -1 for CPU. The hook decides how to interpret it.
const int device_index = device.index();

void* handle = dlopen(model_init_library_.c_str(), RTLD_NOW | RTLD_LOCAL);
if (handle == nullptr) {
const char* dl_err = dlerror();
THROW_TRITON_EXCEPTION(
TRITONSERVER_ERROR_INVALID_ARG,
"Failed to dlopen MODEL_INIT_LIBRARY \""
<< model_init_library_ << "\" for model \"" << Name()
<< "\": " << (dl_err != nullptr ? dl_err : "unknown error") << ".");
}

using ModelInitFn = void* (*)(const char*, int);
dlerror(); // Clear any stale error.
auto* init_fn =
reinterpret_cast<ModelInitFn>(dlsym(handle, "triton_pytorch_model_init"));
const char* sym_err = dlerror();
if (init_fn == nullptr || sym_err != nullptr) {
dlclose(handle);
THROW_TRITON_EXCEPTION(
TRITONSERVER_ERROR_INVALID_ARG,
"MODEL_INIT_LIBRARY \""
<< model_init_library_
<< "\" does not export 'triton_pytorch_model_init' for model \""
<< Name() << "\": " << (sym_err != nullptr ? sym_err : "not found")
<< ".");
}

model_init_state_ = init_fn(package_dir.c_str(), device_index);
model_init_dl_handle_ = handle;
TRITON_LOG_INFO(
"Ran model-init hook \"" << model_init_library_ << "\" for model \""
<< Name() << "\" (device_index=" << device_index
<< ").");
}

ModelState::~ModelState()
{
if (model_init_dl_handle_ != nullptr) {
using ModelFiniFn = void (*)(void*);
auto* fini_fn = reinterpret_cast<ModelFiniFn>(
dlsym(model_init_dl_handle_, "triton_pytorch_model_fini"));
if (fini_fn != nullptr) {
fini_fn(model_init_state_);
}
dlclose(model_init_dl_handle_);
}
}

const std::string&
ModelState::RepositoryPath() const
{
16 changes: 15 additions & 1 deletion src/pt2/model_state.hh
@@ -80,10 +80,17 @@ class ModelState : public triton::backend::BackendModel {
bool optimized_execution_enabled_{true};
bool weight_sharing_enabled_{false};

// Optional per-model native init hook (MODEL_INIT_LIBRARY parameter). The
// backend dlopen()s this library at model load and calls its
// triton_pytorch_model_init() entry point; it never links against it.
std::string model_init_library_;
void* model_init_dl_handle_{nullptr};
void* model_init_state_{nullptr};

public:
ModelState() = delete;

- virtual ~ModelState() = default;
+ virtual ~ModelState();

[[nodiscard]] bool CacheCleaningEnabled() const;

@@ -164,5 +171,12 @@ class ModelState : public triton::backend::BackendModel {
void AutoCompleteConfig();

void ParseParameters();

// Run the optional MODEL_INIT_LIBRARY hook (no-op if unset). dlopen()s the
// library and calls its triton_pytorch_model_init(); throws on failure so a
// misconfigured hook fails model load loudly. The backend does not link it.
void MaybeRunModelInitHook(
const std::string& repository_path, const std::string& repository_version,
const torch::Device& device);
};
} // namespace triton::backend::pytorch::pt2