Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 10 additions & 3 deletions src/ml_flashpoint/adapter/nemo/wrapper_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,14 +89,21 @@ def wrap_trainer_and_auto_resume_with_mlflashpoint(

flashpoint_base_container = CheckpointContainerId(flashpoint_base_container)

pool_config = BufferPoolConfig(
pool_dir_path=os.path.join(str(flashpoint_base_container), "buffer_pool"),
local_pool_config = BufferPoolConfig(
pool_dir_path=os.path.join(str(flashpoint_base_container), "buffer_pool", "local"),
rank=trainer.global_rank,
num_buffers=write_thread_count * NUM_OF_BUFFERS_PER_OBJECT,
buffer_size=initial_write_buffer_size_bytes or DEFAULT_INITIAL_BUFFER_SIZE_BYTES,
)

ckpt_obj_manager = CheckpointObjectManager(pool_config=pool_config)
repl_pool_config = BufferPoolConfig(
pool_dir_path=os.path.join(str(flashpoint_base_container), "buffer_pool", "repl"),
rank=trainer.global_rank,
num_buffers=write_thread_count * NUM_OF_BUFFERS_PER_OBJECT,
buffer_size=initial_write_buffer_size_bytes or DEFAULT_INITIAL_BUFFER_SIZE_BYTES,
)

ckpt_obj_manager = CheckpointObjectManager(local_pool_config=local_pool_config, repl_pool_config=repl_pool_config)
replication_manager = ReplicationManager()
replication_manager.initialize(checkpoint_object_manager=ckpt_obj_manager)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ set(CMAKE_LIBRARY_OUTPUT_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR})
add_library(buffer_object_lib STATIC
buffer_object.cpp
buffer_helper.cpp
buffer_pool.cpp
)

target_link_libraries(buffer_object_lib PUBLIC
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,10 @@
#include <string>

#include "buffer_object.h"
#include "buffer_pool.h"

namespace py = pybind11;
using ml_flashpoint::checkpoint_object_manager::buffer_object::BufferPool;

// Module entry point
PYBIND11_MODULE(buffer_object_ext, m) {
Expand Down Expand Up @@ -98,4 +100,20 @@ PYBIND11_MODULE(buffer_object_ext, m) {
b.is_readonly() // Readonly flag
);
});

py::class_<BufferPool> buffer_pool_class(m, "BufferPool");

buffer_pool_class
.def(py::init<const std::string&, const std::string&, int, size_t, size_t>(),
py::arg("shm_name"), py::arg("pool_dir") = "", py::arg("rank") = 0,
py::arg("num_buffers") = 0, py::arg("buffer_size") = 0,
py::call_guard<py::gil_scoped_release>())
.def("acquire", &BufferPool::Acquire, "Acquires a buffer from the pool.",
py::arg("associated_symlink") = "",
py::call_guard<py::gil_scoped_release>())
.def("release", &BufferPool::Release, "Releases a buffer back to the pool.",
py::arg("object_id"),
py::call_guard<py::gil_scoped_release>())
.def("gc", &BufferPool::GC, "Performs garbage collection.",
py::call_guard<py::gil_scoped_release>());
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,207 @@
/*
* Copyright 2025 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "buffer_pool.h"

#include <fcntl.h>
#include <pthread.h>
#include <sys/mman.h>
#include <sys/stat.h>
#include <unistd.h>

#include <cstring>
#include <stdexcept>
#include <filesystem>

#include "absl/log/log.h"
#include "buffer_object.h" // Needed for pre-allocation

namespace ml_flashpoint::checkpoint_object_manager::buffer_object {

BufferPool::BufferPool(const std::string& shm_name, const std::string& pool_dir,
int rank, size_t num_buffers,
size_t buffer_size)
: shm_name_(shm_name) {

// Try to create exclusively
shm_fd_ = shm_open(shm_name_.c_str(), O_CREAT | O_EXCL | O_RDWR, 0666);
bool is_creator = true;

if (shm_fd_ == -1) {
if (errno == EEXIST) {
// Already exists, try to open read-write
shm_fd_ = shm_open(shm_name_.c_str(), O_RDWR, 0666);
is_creator = false;
}

if (shm_fd_ == -1) {
throw std::runtime_error("shm_open failed: " + std::string(strerror(errno)));
}
}

initialized_ = is_creator; // We use initialized_ to know if we should unlink in destructor

size_t shm_size = sizeof(SharedBufferPoolState);
if (is_creator) {
if (ftruncate(shm_fd_, shm_size) == -1) {
close(shm_fd_);
shm_unlink(shm_name_.c_str());
throw std::runtime_error("ftruncate failed: " + std::string(strerror(errno)));
}
}

void* ptr = mmap(NULL, shm_size, PROT_READ | PROT_WRITE, MAP_SHARED, shm_fd_, 0);
if (ptr == MAP_FAILED) {
close(shm_fd_);
if (is_creator) {
shm_unlink(shm_name_.c_str());
}
throw std::runtime_error("mmap failed: " + std::string(strerror(errno)));
}

state_ = static_cast<SharedBufferPoolState*>(ptr);

if (is_creator) {
// Initialize mutex
pthread_mutexattr_t attr;
pthread_mutexattr_init(&attr);
pthread_mutexattr_setpshared(&attr, PTHREAD_PROCESS_SHARED);
pthread_mutex_init(&state_->mutex, &attr);
pthread_mutexattr_destroy(&attr);

state_->num_buffers = num_buffers;
state_->buffer_size = buffer_size;

for (size_t i = 0; i < kMaxBuffers; ++i) {
state_->buffers[i].state = BufferState::kFree;
state_->buffers[i].object_id[0] = '\0';
state_->buffers[i].associated_symlink[0] = '\0';
state_->buffers[i].capacity = 0;
}

for (size_t i = 0; i < num_buffers; ++i) {
std::string buffer_name = "buffer_" + std::to_string(rank) + "_" + std::to_string(i) + ".dist";
std::string buffer_path = (std::filesystem::path(pool_dir) / buffer_name).string();
snprintf(state_->buffers[i].object_id, kMaxPathLen, "%s", buffer_path.c_str());
state_->buffers[i].capacity = buffer_size;

// Pre-allocate file
try {
BufferObject buf(buffer_path, buffer_size, true);
LOG(INFO) << "Pre-allocated buffer file: " << buffer_path;
} catch (const std::exception& e) {
LOG(ERROR) << "Failed to pre-allocate buffer " << buffer_path << ": " << e.what();
munmap(state_, sizeof(SharedBufferPoolState));
close(shm_fd_);
shm_unlink(shm_name_.c_str());
throw;
}
}

LOG(INFO) << "BufferPool initialized in shared memory. Num buffers: " << num_buffers;
} else {
LOG(INFO) << "Attached to existing BufferPool in shared memory.";
}
}

BufferPool::~BufferPool() {
munmap(state_, sizeof(SharedBufferPoolState));
close(shm_fd_);
if (initialized_) {
shm_unlink(shm_name_.c_str());
}
}

void BufferPool::Lock() {
pthread_mutex_lock(&state_->mutex);
}

void BufferPool::Unlock() {
pthread_mutex_unlock(&state_->mutex);
}

std::string BufferPool::Acquire(const std::string& associated_symlink) {
Lock();

GC(); // Clean up broken symlinks

for (size_t i = 0; i < state_->num_buffers; ++i) {
if (state_->buffers[i].state == BufferState::kFree) {
state_->buffers[i].state = BufferState::kAcquired;

if (state_->buffers[i].object_id[0] == '\0') {
Unlock();
throw std::runtime_error("BufferPool: object_id is empty for free buffer!");
}

if (!associated_symlink.empty()) {
snprintf(state_->buffers[i].associated_symlink, kMaxPathLen, "%s", associated_symlink.c_str());
// Create symlink
std::filesystem::path link_path(associated_symlink);
std::filesystem::path target_path(state_->buffers[i].object_id);

std::error_code ec;
if (std::filesystem::exists(link_path, ec)) {
std::filesystem::remove(link_path, ec);
}
std::filesystem::create_symlink(target_path, link_path, ec);
if (ec) {
state_->buffers[i].state = BufferState::kFree;
state_->buffers[i].associated_symlink[0] = '\0';
Unlock();
throw std::runtime_error("Failed to create symlink: " + ec.message());
}
}

std::string result = state_->buffers[i].object_id;
Unlock();
return result;
}
}

Unlock();
throw std::runtime_error("BufferPool exhausted");
}

void BufferPool::Release(const std::string& object_id) {
Lock();
for (size_t i = 0; i < state_->num_buffers; ++i) {
if (object_id == state_->buffers[i].object_id) {
state_->buffers[i].state = BufferState::kFree;
state_->buffers[i].associated_symlink[0] = '\0';
Unlock();
return;
}
}
Unlock();
LOG(WARNING) << "Attempted to release unknown buffer: " << object_id;
}

void BufferPool::GC() {
// MUST BE CALLED WITH LOCK HELD!
for (size_t i = 0; i < state_->num_buffers; ++i) {
if (state_->buffers[i].state == BufferState::kAcquired) {
std::string symlink = state_->buffers[i].associated_symlink;
if (!symlink.empty() && !std::filesystem::exists(symlink)) {
LOG(INFO) << "GC: Releasing buffer " << state_->buffers[i].object_id << " because symlink " << symlink << " is gone.";
state_->buffers[i].state = BufferState::kFree;
state_->buffers[i].associated_symlink[0] = '\0';
}
}
}
}

} // namespace ml_flashpoint::checkpoint_object_manager::buffer_object
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
/*
* Copyright 2025 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef BUFFER_POOL_H_
#define BUFFER_POOL_H_

#include <cstddef>
#include <string>
#include <pthread.h>

namespace ml_flashpoint::checkpoint_object_manager::buffer_object {

constexpr size_t kMaxBuffers = 64;
constexpr size_t kMaxPathLen = 256;

enum class BufferState : int {
kFree = 0,
kAcquired = 1,
};

struct SharedBufferInfo {
char object_id[kMaxPathLen];
size_t capacity;
BufferState state;
char associated_symlink[kMaxPathLen];
};

struct SharedBufferPoolState {
pthread_mutex_t mutex;
size_t num_buffers;
size_t buffer_size;
SharedBufferInfo buffers[kMaxBuffers];
};

class BufferPool {
public:
// Constructor: Initializes the pool.
// One process must call it with initialize=true to create the shared memory.
// Others call it with initialize=false to just attach to it.
explicit BufferPool(const std::string& shm_name, const std::string& pool_dir = "",
int rank = 0, size_t num_buffers = 0,
size_t buffer_size = 0);
~BufferPool();

// Non-copyable
BufferPool(const BufferPool&) = delete;
BufferPool& operator=(const BufferPool&) = delete;

// Acquires a buffer from the pool.
// Returns the object_id (path) of the allocated buffer.
std::string Acquire(const std::string& associated_symlink = "");

// Releases a buffer back to the pool.
void Release(const std::string& object_id);

// Performs garbage collection.
void GC();

private:
std::string shm_name_;
int shm_fd_;
SharedBufferPoolState* state_;
bool initialized_;

void Lock();
void Unlock();
};

} // namespace ml_flashpoint::checkpoint_object_manager::buffer_object

#endif // BUFFER_POOL_H_
Loading
Loading