diff --git a/src/ml_flashpoint/adapter/nemo/wrapper_util.py b/src/ml_flashpoint/adapter/nemo/wrapper_util.py index f2284d8..390a476 100644 --- a/src/ml_flashpoint/adapter/nemo/wrapper_util.py +++ b/src/ml_flashpoint/adapter/nemo/wrapper_util.py @@ -89,14 +89,21 @@ def wrap_trainer_and_auto_resume_with_mlflashpoint( flashpoint_base_container = CheckpointContainerId(flashpoint_base_container) - pool_config = BufferPoolConfig( - pool_dir_path=os.path.join(str(flashpoint_base_container), "buffer_pool"), + local_pool_config = BufferPoolConfig( + pool_dir_path=os.path.join(str(flashpoint_base_container), "buffer_pool", "local"), rank=trainer.global_rank, num_buffers=write_thread_count * NUM_OF_BUFFERS_PER_OBJECT, buffer_size=initial_write_buffer_size_bytes or DEFAULT_INITIAL_BUFFER_SIZE_BYTES, ) - ckpt_obj_manager = CheckpointObjectManager(pool_config=pool_config) + repl_pool_config = BufferPoolConfig( + pool_dir_path=os.path.join(str(flashpoint_base_container), "buffer_pool", "repl"), + rank=trainer.global_rank, + num_buffers=write_thread_count * NUM_OF_BUFFERS_PER_OBJECT, + buffer_size=initial_write_buffer_size_bytes or DEFAULT_INITIAL_BUFFER_SIZE_BYTES, + ) + + ckpt_obj_manager = CheckpointObjectManager(local_pool_config=local_pool_config, repl_pool_config=repl_pool_config) replication_manager = ReplicationManager() replication_manager.initialize(checkpoint_object_manager=ckpt_obj_manager) diff --git a/src/ml_flashpoint/checkpoint_object_manager/buffer_object/CMakeLists.txt b/src/ml_flashpoint/checkpoint_object_manager/buffer_object/CMakeLists.txt index 985e7c2..e9076ce 100644 --- a/src/ml_flashpoint/checkpoint_object_manager/buffer_object/CMakeLists.txt +++ b/src/ml_flashpoint/checkpoint_object_manager/buffer_object/CMakeLists.txt @@ -21,6 +21,7 @@ set(CMAKE_LIBRARY_OUTPUT_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}) add_library(buffer_object_lib STATIC buffer_object.cpp buffer_helper.cpp + buffer_pool.cpp ) target_link_libraries(buffer_object_lib PUBLIC diff --git 
a/src/ml_flashpoint/checkpoint_object_manager/buffer_object/bindings.cpp b/src/ml_flashpoint/checkpoint_object_manager/buffer_object/bindings.cpp
index 133f628..e137d1e 100644
--- a/src/ml_flashpoint/checkpoint_object_manager/buffer_object/bindings.cpp
+++ b/src/ml_flashpoint/checkpoint_object_manager/buffer_object/bindings.cpp
@@ -20,8 +20,10 @@
 #include <pybind11/pybind11.h>
 
 #include "buffer_object.h"
+#include "buffer_pool.h"
 
 namespace py = pybind11;
+using ml_flashpoint::checkpoint_object_manager::buffer_object::BufferPool;
 
 // Module entry point
 PYBIND11_MODULE(buffer_object_ext, m) {
@@ -98,4 +100,20 @@ PYBIND11_MODULE(buffer_object_ext, m) {
         b.is_readonly()  // Readonly flag
     );
   });
+
+  py::class_<BufferPool> buffer_pool_class(m, "BufferPool");
+
+  buffer_pool_class
+      .def(py::init<const std::string&, const std::string&, int, size_t, size_t>(),
+           py::arg("shm_name"), py::arg("pool_dir") = "", py::arg("rank") = 0,
+           py::arg("num_buffers") = 0, py::arg("buffer_size") = 0,
+           py::call_guard<py::gil_scoped_release>())
+      .def("acquire", &BufferPool::Acquire, "Acquires a buffer from the pool.",
+           py::arg("associated_symlink") = "",
+           py::call_guard<py::gil_scoped_release>())
+      .def("release", &BufferPool::Release, "Releases a buffer back to the pool.",
+           py::arg("object_id"),
+           py::call_guard<py::gil_scoped_release>())
+      .def("gc", &BufferPool::GC, "Performs garbage collection.",
+           py::call_guard<py::gil_scoped_release>());
 }
\ No newline at end of file
diff --git a/src/ml_flashpoint/checkpoint_object_manager/buffer_object/buffer_pool.cpp b/src/ml_flashpoint/checkpoint_object_manager/buffer_object/buffer_pool.cpp
new file mode 100644
index 0000000..b9c2c24
--- /dev/null
+++ b/src/ml_flashpoint/checkpoint_object_manager/buffer_object/buffer_pool.cpp
@@ -0,0 +1,207 @@
+/*
+ * Copyright 2025 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "buffer_pool.h"
+
+#include <fcntl.h>
+#include <pthread.h>
+#include <sys/mman.h>
+#include <sys/stat.h>
+#include <unistd.h>
+
+#include <cstring>
+#include <filesystem>
+#include <stdexcept>
+
+#include "absl/log/log.h"
+#include "buffer_object.h"  // Needed for pre-allocation
+
+namespace ml_flashpoint::checkpoint_object_manager::buffer_object {
+
+BufferPool::BufferPool(const std::string& shm_name, const std::string& pool_dir,
+                       int rank, size_t num_buffers,
+                       size_t buffer_size)
+    : shm_name_(shm_name) {
+
+  // Try to create exclusively
+  shm_fd_ = shm_open(shm_name_.c_str(), O_CREAT | O_EXCL | O_RDWR, 0666);
+  bool is_creator = true;
+
+  if (shm_fd_ == -1) {
+    if (errno == EEXIST) {
+      // Already exists, try to open read-write
+      shm_fd_ = shm_open(shm_name_.c_str(), O_RDWR, 0666);
+      is_creator = false;
+    }
+
+    if (shm_fd_ == -1) {
+      throw std::runtime_error("shm_open failed: " + std::string(strerror(errno)));
+    }
+  }
+
+  initialized_ = is_creator;  // We use initialized_ to know if we should unlink in destructor
+
+  size_t shm_size = sizeof(SharedBufferPoolState);
+  if (is_creator) {
+    if (ftruncate(shm_fd_, shm_size) == -1) {
+      close(shm_fd_);
+      shm_unlink(shm_name_.c_str());
+      throw std::runtime_error("ftruncate failed: " + std::string(strerror(errno)));
+    }
+  }
+
+  void* ptr = mmap(NULL, shm_size, PROT_READ | PROT_WRITE, MAP_SHARED, shm_fd_, 0);
+  if (ptr == MAP_FAILED) {
+    close(shm_fd_);
+    if (is_creator) {
+      shm_unlink(shm_name_.c_str());
+    }
+    throw std::runtime_error("mmap failed: " + std::string(strerror(errno)));
+  }
+
+  state_ = static_cast<SharedBufferPoolState*>(ptr);
+
+  if (is_creator) {
+    // Initialize mutex
+
pthread_mutexattr_t attr; + pthread_mutexattr_init(&attr); + pthread_mutexattr_setpshared(&attr, PTHREAD_PROCESS_SHARED); + pthread_mutex_init(&state_->mutex, &attr); + pthread_mutexattr_destroy(&attr); + + state_->num_buffers = num_buffers; + state_->buffer_size = buffer_size; + + for (size_t i = 0; i < kMaxBuffers; ++i) { + state_->buffers[i].state = BufferState::kFree; + state_->buffers[i].object_id[0] = '\0'; + state_->buffers[i].associated_symlink[0] = '\0'; + state_->buffers[i].capacity = 0; + } + + for (size_t i = 0; i < num_buffers; ++i) { + std::string buffer_name = "buffer_" + std::to_string(rank) + "_" + std::to_string(i) + ".dist"; + std::string buffer_path = (std::filesystem::path(pool_dir) / buffer_name).string(); + snprintf(state_->buffers[i].object_id, kMaxPathLen, "%s", buffer_path.c_str()); + state_->buffers[i].capacity = buffer_size; + + // Pre-allocate file + try { + BufferObject buf(buffer_path, buffer_size, true); + LOG(INFO) << "Pre-allocated buffer file: " << buffer_path; + } catch (const std::exception& e) { + LOG(ERROR) << "Failed to pre-allocate buffer " << buffer_path << ": " << e.what(); + munmap(state_, sizeof(SharedBufferPoolState)); + close(shm_fd_); + shm_unlink(shm_name_.c_str()); + throw; + } + } + + LOG(INFO) << "BufferPool initialized in shared memory. 
Num buffers: " << num_buffers; + } else { + LOG(INFO) << "Attached to existing BufferPool in shared memory."; + } +} + +BufferPool::~BufferPool() { + munmap(state_, sizeof(SharedBufferPoolState)); + close(shm_fd_); + if (initialized_) { + shm_unlink(shm_name_.c_str()); + } +} + +void BufferPool::Lock() { + pthread_mutex_lock(&state_->mutex); +} + +void BufferPool::Unlock() { + pthread_mutex_unlock(&state_->mutex); +} + +std::string BufferPool::Acquire(const std::string& associated_symlink) { + Lock(); + + GC(); // Clean up broken symlinks + + for (size_t i = 0; i < state_->num_buffers; ++i) { + if (state_->buffers[i].state == BufferState::kFree) { + state_->buffers[i].state = BufferState::kAcquired; + + if (state_->buffers[i].object_id[0] == '\0') { + Unlock(); + throw std::runtime_error("BufferPool: object_id is empty for free buffer!"); + } + + if (!associated_symlink.empty()) { + snprintf(state_->buffers[i].associated_symlink, kMaxPathLen, "%s", associated_symlink.c_str()); + // Create symlink + std::filesystem::path link_path(associated_symlink); + std::filesystem::path target_path(state_->buffers[i].object_id); + + std::error_code ec; + if (std::filesystem::exists(link_path, ec)) { + std::filesystem::remove(link_path, ec); + } + std::filesystem::create_symlink(target_path, link_path, ec); + if (ec) { + state_->buffers[i].state = BufferState::kFree; + state_->buffers[i].associated_symlink[0] = '\0'; + Unlock(); + throw std::runtime_error("Failed to create symlink: " + ec.message()); + } + } + + std::string result = state_->buffers[i].object_id; + Unlock(); + return result; + } + } + + Unlock(); + throw std::runtime_error("BufferPool exhausted"); +} + +void BufferPool::Release(const std::string& object_id) { + Lock(); + for (size_t i = 0; i < state_->num_buffers; ++i) { + if (object_id == state_->buffers[i].object_id) { + state_->buffers[i].state = BufferState::kFree; + state_->buffers[i].associated_symlink[0] = '\0'; + Unlock(); + return; + } + } + Unlock(); + 
LOG(WARNING) << "Attempted to release unknown buffer: " << object_id; +} + +void BufferPool::GC() { + // MUST BE CALLED WITH LOCK HELD! + for (size_t i = 0; i < state_->num_buffers; ++i) { + if (state_->buffers[i].state == BufferState::kAcquired) { + std::string symlink = state_->buffers[i].associated_symlink; + if (!symlink.empty() && !std::filesystem::exists(symlink)) { + LOG(INFO) << "GC: Releasing buffer " << state_->buffers[i].object_id << " because symlink " << symlink << " is gone."; + state_->buffers[i].state = BufferState::kFree; + state_->buffers[i].associated_symlink[0] = '\0'; + } + } + } +} + +} // namespace ml_flashpoint::checkpoint_object_manager::buffer_object diff --git a/src/ml_flashpoint/checkpoint_object_manager/buffer_object/buffer_pool.h b/src/ml_flashpoint/checkpoint_object_manager/buffer_object/buffer_pool.h new file mode 100644 index 0000000..e7b954d --- /dev/null +++ b/src/ml_flashpoint/checkpoint_object_manager/buffer_object/buffer_pool.h @@ -0,0 +1,84 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * limitations under the License.
+ */
+
+#ifndef BUFFER_POOL_H_
+#define BUFFER_POOL_H_
+
+#include <pthread.h>
+
+#include <cstddef>
+#include <string>
+
+namespace ml_flashpoint::checkpoint_object_manager::buffer_object {
+
+constexpr size_t kMaxBuffers = 64;
+constexpr size_t kMaxPathLen = 256;
+
+enum class BufferState : int {
+  kFree = 0,
+  kAcquired = 1,
+};
+
+struct SharedBufferInfo {
+  char object_id[kMaxPathLen];
+  size_t capacity;
+  BufferState state;
+  char associated_symlink[kMaxPathLen];
+};
+
+struct SharedBufferPoolState {
+  pthread_mutex_t mutex;
+  size_t num_buffers;
+  size_t buffer_size;
+  SharedBufferInfo buffers[kMaxBuffers];
+};
+
+class BufferPool {
+ public:
+  // Constructor: attaches to the shared-memory pool named `shm_name`.
+  // The first process to construct it creates and initializes the segment
+  // (pre-allocating `num_buffers` files of `buffer_size` bytes in `pool_dir`);
+  // subsequent processes simply attach to the existing state.
+  explicit BufferPool(const std::string& shm_name, const std::string& pool_dir = "",
+                      int rank = 0, size_t num_buffers = 0,
+                      size_t buffer_size = 0);
+  ~BufferPool();
+
+  // Non-copyable
+  BufferPool(const BufferPool&) = delete;
+  BufferPool& operator=(const BufferPool&) = delete;
+
+  // Acquires a buffer from the pool.
+  // Returns the object_id (path) of the allocated buffer.
+  std::string Acquire(const std::string& associated_symlink = "");
+
+  // Releases a buffer back to the pool.
+  void Release(const std::string& object_id);
+
+  // Performs garbage collection.
+ void GC(); + + private: + std::string shm_name_; + int shm_fd_; + SharedBufferPoolState* state_; + bool initialized_; + + void Lock(); + void Unlock(); +}; + +} // namespace ml_flashpoint::checkpoint_object_manager::buffer_object + +#endif // BUFFER_POOL_H_ diff --git a/src/ml_flashpoint/checkpoint_object_manager/checkpoint_object_manager.py b/src/ml_flashpoint/checkpoint_object_manager/checkpoint_object_manager.py index 6f6008e..2f26304 100644 --- a/src/ml_flashpoint/checkpoint_object_manager/checkpoint_object_manager.py +++ b/src/ml_flashpoint/checkpoint_object_manager/checkpoint_object_manager.py @@ -15,12 +15,12 @@ import os import shutil import threading -from typing import ClassVar, Optional +from typing import ClassVar, Dict, Optional from ml_flashpoint.checkpoint_object_manager.buffer_io import BufferIO from ml_flashpoint.checkpoint_object_manager.buffer_metadata import METADATA_SIZE -from ml_flashpoint.checkpoint_object_manager.buffer_object.buffer_object_ext import BufferObject -from ml_flashpoint.core.buffer_pool import BufferPool, BufferPoolConfig +from ml_flashpoint.checkpoint_object_manager.buffer_object.buffer_object_ext import BufferObject, BufferPool +from ml_flashpoint.core.buffer_pool import BufferPoolConfig from ml_flashpoint.core.checkpoint_id_types import CheckpointContainerId, CheckpointObjectId from ml_flashpoint.core.mlf_logging import get_logger @@ -34,68 +34,73 @@ class CheckpointObjectManager: awareness of buffers (within a rank). """ - # Class-level registry for BufferPool in the worker process. - _worker_pool: ClassVar[Optional[BufferPool]] = None + # Class-level registry for BufferPools in the worker process. + # Maps pool type (string) to BufferPool instance. 
+ _worker_pools: ClassVar[Dict[str, BufferPool]] = {} _worker_pool_lock: ClassVar[threading.Lock] = threading.Lock() - def __init__(self, pool_config: Optional[BufferPoolConfig] = None): + def __init__(self, local_pool_config: Optional[BufferPoolConfig] = None, repl_pool_config: Optional[BufferPoolConfig] = None): """Initializes the CheckpointObjectManager. Args: - pool_config: Optional configuration for the BufferPool. + local_pool_config: Optional configuration for the local BufferPool. + repl_pool_config: Optional configuration for the replication BufferPool. """ - self._pool_config = pool_config + self._local_pool_config = local_pool_config + self._repl_pool_config = repl_pool_config - def _get_or_create_buffer_pool(self) -> Optional[BufferPool]: - """Lazily initializes and returns the BufferPool instance. + def _get_or_create_buffer_pool(self, pool_type: str = "local") -> Optional[BufferPool]: + """Lazily initializes and returns the specified BufferPool instance. Check the class-level registry first to reuse existing pools in this process. """ # 1. Fast path: check class var directly - if CheckpointObjectManager._worker_pool: - return CheckpointObjectManager._worker_pool + if pool_type in CheckpointObjectManager._worker_pools: + return CheckpointObjectManager._worker_pools[pool_type] - if not self._pool_config: + config = self._local_pool_config if pool_type == "local" else self._repl_pool_config + if not config: return None # 2. 
Registry lookup / Creation with CheckpointObjectManager._worker_pool_lock: - if CheckpointObjectManager._worker_pool: - _LOGGER.debug("Reusing existing BufferPool from worker registry.") - return CheckpointObjectManager._worker_pool + if pool_type in CheckpointObjectManager._worker_pools: + _LOGGER.debug("Reusing existing %s BufferPool from worker registry.", pool_type) + return CheckpointObjectManager._worker_pools[pool_type] # Create new try: - _LOGGER.info("Initializing BufferPool with config: %s", self._pool_config) + _LOGGER.info("Initializing C++ BufferPool for %s with config: %s", pool_type, config) + shm_suffix = "local" if pool_type == "local" else "repl" pool = BufferPool( - pool_dir_path=self._pool_config.pool_dir_path, - rank=self._pool_config.rank, - num_buffers=self._pool_config.num_buffers, - buffer_size=self._pool_config.buffer_size, + shm_name=f"/mlf_buffer_pool_rank_{config.rank}_{shm_suffix}", + pool_dir=config.pool_dir_path, + rank=config.rank, + num_buffers=config.num_buffers, + buffer_size=config.buffer_size, ) - CheckpointObjectManager._worker_pool = pool + CheckpointObjectManager._worker_pools[pool_type] = pool except Exception: - _LOGGER.exception("Failed to initialize BufferPool") + _LOGGER.exception("Failed to initialize BufferPool for %s", pool_type) pass - return CheckpointObjectManager._worker_pool + return CheckpointObjectManager._worker_pools.get(pool_type) def teardown_pool(self): - """Teardown the BufferPool if it exists and remove from registry.""" - pool_to_teardown = None - + """Teardown the BufferPools if they exist and remove from registry.""" with CheckpointObjectManager._worker_pool_lock: - if CheckpointObjectManager._worker_pool: - pool_to_teardown = CheckpointObjectManager._worker_pool - CheckpointObjectManager._worker_pool = None - - if pool_to_teardown: - try: - pool_to_teardown.teardown() - except Exception as e: - _LOGGER.debug("Failed to teardown BufferPool: %s", e) - - def acquire_buffer(self, object_id: 
CheckpointObjectId, buffer_size: int, overwrite: bool = True) -> "BufferIO": + if CheckpointObjectManager._worker_pools: + _LOGGER.debug("Clearing BufferPools from registry.") + CheckpointObjectManager._worker_pools.clear() + + @property + def replication_pool_shm_name(self) -> str: + """Returns the shared memory name for the replication pool.""" + if not self._repl_pool_config: + return "" + return f"/mlf_buffer_pool_rank_{self._repl_pool_config.rank}_repl" + + def acquire_buffer(self, object_id: CheckpointObjectId, buffer_size: int, overwrite: bool = True, use_replication_pool: bool = False) -> "BufferIO": """Acquires a buffer, preferring the BufferPool if available. This method attempts to acquire a buffer from the rank level BufferPool. If the pool @@ -111,6 +116,7 @@ def acquire_buffer(self, object_id: CheckpointObjectId, buffer_size: int, overwr object_id: A unique identifier (logical path) for the new buffer object. buffer_size: The desired size of the buffer in bytes. overwrite: If True, allows overwriting an existing object. Defaults to True. + use_replication_pool: If True, uses the replication pool instead of the local pool. Returns: A BufferIO instance. @@ -144,14 +150,15 @@ def acquire_buffer(self, object_id: CheckpointObjectId, buffer_size: int, overwr else: raise FileExistsError(f"File {object_id} already exists and overwrite=False") - pool = self._get_or_create_buffer_pool() + pool_type = "repl" if use_replication_pool else "local" + pool = self._get_or_create_buffer_pool(pool_type) if pool: # Pool manages the physical creation/resizing AND the logical link (symlink) creation. 
- buffer_io = pool.acquire(associated_symlink=str(object_id)) - - _LOGGER.debug("Acquired buffer for '%s'", object_id) - - return buffer_io + buffer_path = pool.acquire(associated_symlink=str(object_id)) + _LOGGER.debug("Acquired pool buffer '%s' for '%s'", buffer_path, object_id) + # Create a writable BufferObject pointing to the physical path in the pool + buffer_obj = BufferObject(buffer_path, buffer_size, overwrite=True) + return BufferIO(buffer_obj) else: _LOGGER.debug( "BufferPool not configured or validation failed. Falling back to standalone buffer creation." diff --git a/src/ml_flashpoint/core/buffer_pool.py b/src/ml_flashpoint/core/buffer_pool.py index a822866..ef37d00 100644 --- a/src/ml_flashpoint/core/buffer_pool.py +++ b/src/ml_flashpoint/core/buffer_pool.py @@ -13,121 +13,7 @@ # limitations under the License. -import logging -import os -import threading from dataclasses import dataclass -from typing import Dict, List, Optional - -from ml_flashpoint.checkpoint_object_manager.buffer_io import BufferIO -from ml_flashpoint.checkpoint_object_manager.buffer_metadata import METADATA_SIZE -from ml_flashpoint.checkpoint_object_manager.buffer_object.buffer_object_ext import BufferObject -from ml_flashpoint.core.mlf_logging import get_logger -from ml_flashpoint.core.utils import log_execution_time - -_LOGGER = get_logger(__name__) - -# Constants for buffer resizing -PADDING_SIZE = 1024 * 1024 -RESIZE_FACTOR = 1.1 - - -class PooledBufferIO: - """Proxies a BufferIO object to prevent it from being closed by the client. - - This allows the BufferPool to reuse the underlying BufferIO object even if - the client (e.g. CheckpointSaver) calls close() on it. 
- """ - - def __init__(self, buffer_io: BufferIO): - self._buffer_io = buffer_io - self._closed = False - - def __getattr__(self, name): - # Delegate attribute access to the underlying BufferIO object - if self._closed: - _LOGGER.warning("PooledBufferIO: Accessing closed buffer") - return getattr(self._buffer_io, name) - - @log_execution_time(logger=_LOGGER, name="close", level=logging.INFO) - def close(self, truncate: bool = True): - """Marks the proxy as closed, releasing the buffer back to the pool. - - This method does NOT close the underlying BufferIO object, allowing it to be - reused by the BufferPool. Ideally, it truncates the buffer to the written size - to save space. - - Args: - truncate: If True, truncates the underlying buffer to the size of the - written data (plus metadata) before releasing it. - """ - if self._closed: - return - - self._closed = True - - _LOGGER.debug("Closing PooledBufferIO object...") - if truncate and not self._buffer_io.is_readonly: - try: - final_data_len = self._buffer_io._metadata.len_written_data - truncate_size = METADATA_SIZE + final_data_len - - current_size = len(self._buffer_io._mv) - if truncate_size != current_size: - _LOGGER.debug( - "PooledBufferIO: Truncating reusable buffer from %d to %d bytes", current_size, truncate_size - ) - self._buffer_io.resize(truncate_size) - except Exception: - _LOGGER.warning("PooledBufferIO: Failed to truncate buffer during close.", exc_info=True) - - @property - def closed(self) -> bool: - return self._closed or self._buffer_io.closed - - def __enter__(self): - return self - - def __exit__(self, exc_type, exc_value, traceback): - self.close() - - def _auto_resize(self, required_bytes: int): - """Resizes the buffer to accommodate an additional required_bytes from the current position.""" - current_size = self._buffer_io.buffer_obj.get_capacity() - current_pos = self._buffer_io.tell() - required_size = METADATA_SIZE + current_pos + required_bytes - - # Use 1MB padding or 10% growth 
(whichever is larger) to amortize resize costs - new_size = int(max(current_size * RESIZE_FACTOR, required_size + PADDING_SIZE)) - - _LOGGER.debug( - "PooledBufferIO: Auto-resizing from %d to %d bytes (required: %d)", - current_size, - new_size, - required_size, - ) - self._buffer_io.resize(new_size) - - def write(self, data: bytes) -> int: - data_len = len(data) - current_pos = self._buffer_io.tell() - current_capacity = self._buffer_io.buffer_obj.get_capacity() - required_capacity = METADATA_SIZE + current_pos + data_len - - if required_capacity > current_capacity: - self._auto_resize(data_len) - - return self._buffer_io.write(data) - - def next_buffer_slice(self, size: int) -> memoryview: - current_pos = self._buffer_io.tell() - current_capacity = self._buffer_io.buffer_obj.get_capacity() - required_capacity = METADATA_SIZE + current_pos + size - - if required_capacity > current_capacity: - self._auto_resize(size) - - return self._buffer_io.next_buffer_slice(size) @dataclass @@ -148,195 +34,3 @@ def __post_init__(self): raise ValueError("num_buffers must be a non-negative integer in BufferPoolConfig") if not isinstance(self.buffer_size, int) or self.buffer_size < 0: raise ValueError("buffer_size must be a non-negative integer in BufferPoolConfig") - - -class BufferPool: - """Singleton class to manage a pool of persistent BufferIO objects. - - This class maintains a pool of buffer file paths in a dedicated directory. - When a buffer is requested via `acquire`, it returns a BufferIO object pointing - to a free buffer file (reusing it if available) or creates a new one. - - Buffers are strictly managed by file paths. The underlying file descriptors - are closed when BufferIO.close() is called by the user, but the file remains - on disk for reuse. Garbage collection reclaims these files when their - associated symlinks (checkpoints) are deleted. 
- """ - - def __init__( - self, - pool_dir_path: str, - rank: int = 0, - num_buffers: int = 3, - buffer_size: int = 0, - ): - """Initializes the BufferPool. - - Args: - pool_dir_path: The directory path where buffer files will be stored. - rank: The rank of the process using this pool (used for naming). - num_buffers: The fixed number of buffers to allocate in the pool. - buffer_size: The initial size of each buffer in bytes. - """ - if num_buffers <= 0: - raise ValueError(f"Number of buffers must be positive. Got {num_buffers}.") - if buffer_size <= 0: - raise ValueError(f"Buffer size must be positive. Got {buffer_size}.") - if not pool_dir_path: - raise ValueError("Pool directory path must be provided.") - self.pool_dir = pool_dir_path - self.rank = rank - self.num_buffers = num_buffers - self.buffer_size = buffer_size - self.free_buffers: List[BufferIO] = [] - # active_buffers maps buffer_path (str) -> (BufferIO, associated_symlink_path (str)) - self.active_buffers: Dict[str, tuple[BufferIO, str]] = {} - self._lock = threading.Lock() - - try: - os.makedirs(self.pool_dir, exist_ok=True) - _LOGGER.debug("BufferPool initialized with directory: %s", self.pool_dir) - self._preallocate_buffers() - except OSError: - _LOGGER.exception("Failed to create/populate BufferPool.") - raise - - @log_execution_time(logger=_LOGGER, name="acquire", level=logging.INFO) - def acquire(self, associated_symlink: Optional[str] = None) -> PooledBufferIO: - """Acquires a buffer from the pool. - - Args: - associated_symlink: The path to the symlink that will point to this buffer. - If provided, GC will check this symlink. If None, the buffer is protected - from GC until the pool is torn down. - - Returns: - A BufferIO object (wrapped in a PooledBufferIO). - - Raises: - RuntimeError: If the pool is exhausted and no buffers can be reclaimed. - """ - with self._lock: - # 0. Opportunistic GC - self._gc() - - # 1. 
Try to find a free buffer - if self.free_buffers: - _LOGGER.debug("Number of free buffers: %d", len(self.free_buffers)) - free_buffer = self.free_buffers.pop() - _LOGGER.debug( - "Acquired buffer from pool for %s, the buffer name is %s", - associated_symlink, - free_buffer.buffer_obj.get_id(), - ) - try: - buf_io = self._reuse_buffer(free_buffer, associated_symlink) - return PooledBufferIO(buf_io) - except Exception: - # If reuse fails (e.g. symlink creation), we must put the buffer back! - _LOGGER.exception("Failed to reuse buffer for '%s'. Releasing buffer.", associated_symlink) - # _reuse_buffer might have added it to active_buffers, so ensure it's removed. - buffer_path = free_buffer.buffer_obj.get_id() - if buffer_path in self.active_buffers: - del self.active_buffers[buffer_path] - self.free_buffers.append(free_buffer) - raise - - # 2. If no free buffer, raise RuntimeError (CheckpointObjectManager will catch and fallback) - _LOGGER.debug("BufferPool exhausted. All %d buffers are in use.", self.num_buffers) - raise RuntimeError(f"BufferPool exhausted. All {self.num_buffers} buffers are in use.") - - def _gc(self) -> None: - """Releases buffers whose associated symlinks no longer exist.""" - # Check all active buffers - to_release = [] - for buffer_path, (buf_io, symlink) in self.active_buffers.items(): - # We use os.path.exists (not lexists) here to detect broken symlinks. - # If the symlink exists but points to a non-existent file, exists() returns False, - # correctly identifying it as a candidate for GC. - if symlink and not os.path.exists(symlink): - to_release.append(buffer_path) - - if to_release: - _LOGGER.debug("Garbage collecting %d buffers whose symlinks are gone.", len(to_release)) - for buffer_path in to_release: - _LOGGER.debug("Garbage collecting buffer %s", buffer_path) - buf_io, _ = self.active_buffers.pop(buffer_path) - self.free_buffers.append(buf_io) - - def teardown(self) -> None: - """Closes all buffers and clears the pool. 
- - The files remain on disk (persistent). - If we wanted to adhere to strict cleanup of tests, we might delete them? - But persistent pool implies persistence. - Use cases using temp dir will cleanup the dir. - """ - with self._lock: - _LOGGER.debug( - "Tearing down BufferPool. Closing %d free and %d active buffers.", - len(self.free_buffers), - len(self.active_buffers), - ) - - for buf in self.free_buffers: - try: - buf.close(truncate=True) - except Exception: - _LOGGER.warning("PooledBufferIO: Failed to close buffer during teardown.", exc_info=True) - self.free_buffers.clear() - - for buffer_path, (buf, _) in self.active_buffers.items(): - try: - buf.close(truncate=True) - except Exception: - _LOGGER.warning("PooledBufferIO: Failed to close buffer during teardown.", exc_info=True) - self.active_buffers.clear() - - @log_execution_time(logger=_LOGGER, name="_reuse_buffer", level=logging.INFO) - def _reuse_buffer(self, buffer_io: BufferIO, symlink: Optional[str]) -> BufferIO: - """Reuses an existing buffer object, resizing if necessary.""" - try: - current_capacity = buffer_io.buffer_obj.get_capacity() - _LOGGER.debug("Reusing pool buffer (capacity %d).", current_capacity) - - # Reset the buffer for reuse: start at the beginning and clear written length. - buffer_io.seek(0) - buffer_io._metadata.len_written_data = 0 - - buffer_path = buffer_io.buffer_obj.get_id() - self.active_buffers[buffer_path] = (buffer_io, symlink) - - if symlink: - try: - # Create symlink pointing to the physical buffer path - os.symlink(buffer_path, symlink) - _LOGGER.debug("Created symlink '%s' -> '%s'", symlink, buffer_path) - except OSError: - # If symlink creation fails, propagate the error. - # The caller (acquire) is responsible for cleanup. 
- raise - - return buffer_io - except Exception: - raise - - def _preallocate_buffers(self) -> None: - """Pre-allocates fixed number of buffers.""" - for idx in range(self.num_buffers): - buffer_name = f"buffer_{self.rank}_{idx}.dist" - buffer_path = os.path.join(self.pool_dir, buffer_name) - - # Create/Reset buffer file - try: - _LOGGER.debug("Pre-allocating buffer: %s with size %d", buffer_path, self.buffer_size) - - # Always overwrite to ensure a clean state - buffer_obj = BufferObject(buffer_path, self.buffer_size, overwrite=True) - - buffer_io = BufferIO(buffer_obj) - self.free_buffers.append(buffer_io) - - except Exception: - _LOGGER.exception("Failed to pre-allocate buffer %s", buffer_path) - raise diff --git a/src/ml_flashpoint/replication/replication_manager.py b/src/ml_flashpoint/replication/replication_manager.py index ca2b5df..a4c0268 100644 --- a/src/ml_flashpoint/replication/replication_manager.py +++ b/src/ml_flashpoint/replication/replication_manager.py @@ -219,9 +219,11 @@ def initialize( if self._transfer_service is None: _LOGGER.info("No TransferService provided, initializing a new one...") self._transfer_service = transfer_service_ext.TransferService() + repl_shm_name = self._checkpoint_object_manager.replication_pool_shm_name bound_listen_port = self._transfer_service.initialize( listen_port, global_rank=dist.get_rank(), + repl_shm_name=repl_shm_name, ) if bound_listen_port <= 0: _LOGGER.error("Failed to initialize C++ TransferService") @@ -473,8 +475,19 @@ def sync_bulk_retrieve( futures = [] effective_retrieved_ids = retrieved_object_ids if retrieved_object_ids else object_ids_to_retrieve + for i, obj_id in enumerate(object_ids_to_retrieve): - futures.append(self._async_retrieve(source_address, obj_id, effective_retrieved_ids[i])) + target_id = effective_retrieved_ids[i] + try: + # Use a default size of 1MB, it will be resized by TransferService if needed + buf = self._checkpoint_object_manager.acquire_buffer(target_id, buffer_size=1024*1024, 
overwrite=True, use_replication_pool=True) + # Close it immediately in Python to avoid conflicts with TransferService + buf.close(truncate=False) + _LOGGER.info("Acquired and closed buffer from pool for retrieval to %s", target_id) + except Exception as e: + _LOGGER.warning("Failed to acquire buffer from pool for %s, falling back to direct retrieval. Error: %s", target_id, e) + + futures.append(self._async_retrieve(source_address, obj_id, target_id)) # TODO: Handle container_ids_to_retrieve diff --git a/src/ml_flashpoint/replication/transfer_service/CMakeLists.txt b/src/ml_flashpoint/replication/transfer_service/CMakeLists.txt index 99ac46a..69fb345 100644 --- a/src/ml_flashpoint/replication/transfer_service/CMakeLists.txt +++ b/src/ml_flashpoint/replication/transfer_service/CMakeLists.txt @@ -28,6 +28,7 @@ add_library(transfer_service_lib STATIC mlf_log_sink.cpp ${CMAKE_SOURCE_DIR}/src/ml_flashpoint/checkpoint_object_manager/buffer_object/buffer_object.cpp ${CMAKE_SOURCE_DIR}/src/ml_flashpoint/checkpoint_object_manager/buffer_object/buffer_helper.cpp + ${CMAKE_SOURCE_DIR}/src/ml_flashpoint/checkpoint_object_manager/buffer_object/buffer_pool.cpp ) target_link_libraries(transfer_service_lib PUBLIC absl::log absl::check absl::status absl::statusor absl::strings absl::base absl::time absl::log_initialize absl::log_globals) diff --git a/src/ml_flashpoint/replication/transfer_service/bindings.cpp b/src/ml_flashpoint/replication/transfer_service/bindings.cpp index 258db89..cad6253 100644 --- a/src/ml_flashpoint/replication/transfer_service/bindings.cpp +++ b/src/ml_flashpoint/replication/transfer_service/bindings.cpp @@ -37,7 +37,7 @@ PYBIND11_MODULE(transfer_service_ext, m) { // Synchronous methods remain unchanged (still blocking) .def("initialize", &TransferService::Initialize, py::arg("listen_port"), py::arg("threads") = 16, py::arg("conn_pool_per_peer") = 16, - py::arg("global_rank") = -1, + py::arg("global_rank") = -1, py::arg("repl_shm_name") = "", "Initializes 
and starts the C++ transfer service.") .def("shutdown", &TransferService::Shutdown, diff --git a/src/ml_flashpoint/replication/transfer_service/transfer_service.cpp b/src/ml_flashpoint/replication/transfer_service/transfer_service.cpp index ea7f735..76c7e6c 100644 --- a/src/ml_flashpoint/replication/transfer_service/transfer_service.cpp +++ b/src/ml_flashpoint/replication/transfer_service/transfer_service.cpp @@ -80,7 +80,8 @@ TransferService::~TransferService() { } int TransferService::Initialize(int listen_port, int threads, - int conn_pool_size_per_peer, int global_rank) { + int conn_pool_size_per_peer, int global_rank, + const std::string& repl_shm_name) { std::lock_guard lock(init_mutex_); if (running_.load()) { LOG(INFO) << "Transfer service is already running."; @@ -93,6 +94,19 @@ int TransferService::Initialize(int listen_port, int threads, threads_ = threads; conn_pool_size_per_peer_ = conn_pool_size_per_peer; global_rank_ = global_rank; + repl_shm_name_ = repl_shm_name; + + if (!repl_shm_name_.empty()) { + try { + LOG(INFO) << "Attaching to replication BufferPool with SHM name: " << repl_shm_name_; + repl_pool_ = std::make_unique<BufferPool>(repl_shm_name_); + } catch (const std::exception& e) { + LOG(ERROR) << "Failed to attach to replication BufferPool: " << e.what(); + // Fallback to standalone buffers if pool fails to initialize + repl_pool_.reset(); + } + } + absl::RemoveLogSink(mlf_log_sink_.get()); mlf_log_sink_ = std::make_unique(global_rank); absl::AddLogSink(mlf_log_sink_.get()); @@ -722,11 +736,23 @@ void TransferService::HandleDataReceive(int client_fd, } }); } - std::string tmp_obj_id = - std::string(header.dest_obj_id) + std::string(kTempFileSuffix); - BufferObject buffer_obj(tmp_obj_id, header.obj_size, - /*overwrite=*/true); - LOG(INFO) << "Successfully created buffer object"; + std::string buffer_path = header.dest_obj_id; + bool is_standalone = true; + + if (repl_pool_) { + try { + LOG(INFO) << "Acquiring buffer from pool for " << 
header.dest_obj_id; + buffer_path = repl_pool_->Acquire(header.dest_obj_id); + is_standalone = false; + LOG(INFO) << "Acquired pooled buffer: " << buffer_path; + } catch (const std::exception& e) { + LOG(WARNING) << "Failed to acquire buffer from pool: " << e.what() + << ". Falling back to standalone buffer."; + } + } + + BufferObject buffer_obj(buffer_path, header.obj_size, is_standalone); + LOG(INFO) << "Successfully created buffer object at " << buffer_path; void* receiver_data_ptr = buffer_obj.get_data_ptr(); if (!RecvAll(client_fd, receiver_data_ptr, header.obj_size).ok()) { @@ -739,19 +765,8 @@ void TransferService::HandleDataReceive(int client_fd, } // Close the buffer object to ensure data is flushed and the file descriptor - // is released before renaming. + // is released. buffer_obj.close(); - - // Rename the temporary file to the final destination. - if (rename(tmp_obj_id.c_str(), header.dest_obj_id) != 0) { - PLOG(ERROR) << "Failed to rename temporary file " << tmp_obj_id << " to " - << header.dest_obj_id; - if (is_respond_get_task) { - ReportResult(header.task_id, false, "Failed to rename temporary file"); - } - SendErrorResponse(client_fd, header.task_id, header.dest_obj_id); - return; - } if (is_respond_get_task) { UpdateTaskMetrics(header.task_id, [](TaskMetricContainer& ts) { if (auto* get_ts = dynamic_cast(&ts)) { diff --git a/src/ml_flashpoint/replication/transfer_service/transfer_service.h b/src/ml_flashpoint/replication/transfer_service/transfer_service.h index ddcdd10..c9b1d59 100644 --- a/src/ml_flashpoint/replication/transfer_service/transfer_service.h +++ b/src/ml_flashpoint/replication/transfer_service/transfer_service.h @@ -21,6 +21,8 @@ #include #include +#include "../../checkpoint_object_manager/buffer_object/buffer_pool.h" + // For socket programming #include #include @@ -73,7 +75,8 @@ class TransferService final { // Returns: // The port the service is listening on, or -1 on failure. 
int Initialize(int listen_port = 0, int threads = 16, - int conn_pool_per_peer = 16, int global_rank = -1); + int conn_pool_per_peer = 16, int global_rank = -1, + const std::string& repl_shm_name = ""); // Shuts down the transfer service gracefully. void Shutdown(); @@ -179,6 +182,9 @@ class TransferService final { std::unordered_map pending_tasks_; std::mutex pending_tasks_mutex_; // Guard pending_tasks_ + std::string repl_shm_name_; + std::unique_ptr<BufferPool> repl_pool_; + std::unique_ptr mlf_log_sink_; void UpdateTaskMetrics(const std::string& task_id, diff --git a/tests/adapter/nemo/test_wrapper_util.py b/tests/adapter/nemo/test_wrapper_util.py index 2ad6794..8558d33 100644 --- a/tests/adapter/nemo/test_wrapper_util.py +++ b/tests/adapter/nemo/test_wrapper_util.py @@ -109,10 +109,14 @@ def test_successful_wrap_and_resume_creation(self, mocker, mock_ckpt_obj_manager mock_ckpt_obj_manager_cls.assert_called_once() call_args = mock_ckpt_obj_manager_cls.call_args - assert "pool_config" in call_args.kwargs - pool_config = call_args.kwargs["pool_config"] - assert isinstance(pool_config, BufferPoolConfig) - assert pool_config.pool_dir_path == f"{flashpoint_base_container}/buffer_pool" + assert "local_pool_config" in call_args.kwargs + assert "repl_pool_config" in call_args.kwargs + local_config = call_args.kwargs["local_pool_config"] + repl_config = call_args.kwargs["repl_pool_config"] + assert isinstance(local_config, BufferPoolConfig) + assert isinstance(repl_config, BufferPoolConfig) + assert local_config.pool_dir_path == f"{flashpoint_base_container}/buffer_pool/local" + assert repl_config.pool_dir_path == f"{flashpoint_base_container}/buffer_pool/repl" # Capture the ckpt_obj_manager passed to initialize _, kwargs_init = mock_replication_manager_instance.initialize.call_args diff --git a/tests/checkpoint_object_manager/test_checkpoint_object_manager.py b/tests/checkpoint_object_manager/test_checkpoint_object_manager.py index 3d8508e..9882025 100644 --- 
a/tests/checkpoint_object_manager/test_checkpoint_object_manager.py +++ b/tests/checkpoint_object_manager/test_checkpoint_object_manager.py @@ -21,7 +21,8 @@ from ml_flashpoint.checkpoint_object_manager.buffer_io import METADATA_SIZE, BufferIO from ml_flashpoint.checkpoint_object_manager.checkpoint_object_manager import CheckpointObjectManager -from ml_flashpoint.core.buffer_pool import BufferPool, BufferPoolConfig +from ml_flashpoint.checkpoint_object_manager.buffer_object.buffer_object_ext import BufferPool +from ml_flashpoint.core.buffer_pool import BufferPoolConfig from ml_flashpoint.core.checkpoint_id_types import CheckpointContainerId, CheckpointObjectId @@ -160,7 +161,7 @@ def test_acquire_buffer_falls_back_on_pool_exhaustion(self, manager_setup, mocke # Setup: Pool exists but raises RuntimeError on acquire mock_pool = mocker.MagicMock(spec=BufferPool) mock_pool.acquire.side_effect = RuntimeError("Pool exhausted") - CheckpointObjectManager._worker_pool = mock_pool + CheckpointObjectManager._worker_pools["local"] = mock_pool # Setup: Standalone creation succeeds mock_instance = mocker.MagicMock() @@ -877,17 +878,17 @@ def test_lazy_initialization(self, mocker): """Tests that BufferPool is lazily initialized on acquire_buffer.""" # Ensure pool_config has pool_dir_path for registry check manager = CheckpointObjectManager( - pool_config=BufferPoolConfig(pool_dir_path="/tmp/lazy", num_buffers=1, rank=0, buffer_size=0) + local_pool_config=BufferPoolConfig(pool_dir_path="/tmp/lazy", num_buffers=1, rank=0, buffer_size=0) ) mock_buffer_pool_cls = mocker.patch( "ml_flashpoint.checkpoint_object_manager.checkpoint_object_manager.BufferPool" ) - # Access internal property to verify it's None initially - assert CheckpointObjectManager._worker_pool is None - # Clear registry to ensure no interference - CheckpointObjectManager._worker_pool = None + CheckpointObjectManager._worker_pools.clear() + + # Access internal property to verify it's empty initially + assert "local" 
not in CheckpointObjectManager._worker_pools # Call acquire_buffer should trigger init # We need to mock os.makedirs and exists to pass the early checks in acquire_buffer @@ -896,34 +897,40 @@ def test_lazy_initialization(self, mocker): # It will try to acquire, so mock the pool instance mock_pool_instance = mock_buffer_pool_cls.return_value - mock_pool_instance.acquire.return_value = mocker.Mock(spec=BufferIO) - + mock_pool_instance.acquire.return_value = "/tmp/fake_pool_buffer.dist" + manager.acquire_buffer(CheckpointObjectId("/tmp/lazy/foo"), 100) - mock_buffer_pool_cls.assert_called_once_with(pool_dir_path="/tmp/lazy", num_buffers=1, rank=0, buffer_size=0) - assert CheckpointObjectManager._worker_pool == mock_pool_instance + mock_buffer_pool_cls.assert_called_once_with( + shm_name="/mlf_buffer_pool_rank_0_local", + pool_dir="/tmp/lazy", + num_buffers=1, + rank=0, + buffer_size=0 + ) + assert CheckpointObjectManager._worker_pools["local"] == mock_pool_instance def test_get_pool_returns_none_when_config_is_none(self): """Tests that _get_or_create_buffer_pool returns None when pool_config is None.""" - CheckpointObjectManager._worker_pool = None - manager = CheckpointObjectManager(pool_config=None) + CheckpointObjectManager._worker_pools.clear() + manager = CheckpointObjectManager() - assert manager._get_or_create_buffer_pool() is None + assert manager._get_or_create_buffer_pool("local") is None def test_get_pool_returns_none_on_init_failure(self, mocker): """Tests that _get_or_create_buffer_pool returns None and logs error if BufferPool init fails.""" - CheckpointObjectManager._worker_pool = None + CheckpointObjectManager._worker_pools.clear() config = BufferPoolConfig(pool_dir_path="/tmp/fail_init", num_buffers=1, rank=0, buffer_size=0) - manager = CheckpointObjectManager(pool_config=config) + manager = CheckpointObjectManager(local_pool_config=config) mock_logger = mocker.patch("ml_flashpoint.checkpoint_object_manager.checkpoint_object_manager._LOGGER") 
mock_pool_cls = mocker.patch("ml_flashpoint.checkpoint_object_manager.checkpoint_object_manager.BufferPool") mock_pool_cls.side_effect = Exception("Init failed") - pool = manager._get_or_create_buffer_pool() + pool = manager._get_or_create_buffer_pool("local") assert pool is None - assert CheckpointObjectManager._worker_pool is None + assert "local" not in CheckpointObjectManager._worker_pools # Verify it logged the error mock_logger.exception.assert_called_once() assert "Failed to initialize BufferPool" in mock_logger.exception.call_args[0][0] @@ -934,35 +941,35 @@ def test_pickling_resets_pool(self): # Use a real config object but with minimal values config = BufferPoolConfig(pool_dir_path="/tmp/pickle_test", num_buffers=1, rank=0, buffer_size=0) - manager = CheckpointObjectManager(pool_config=config) + manager = CheckpointObjectManager(local_pool_config=config) # Simulate initialized pool - CheckpointObjectManager._worker_pool = "fake_pool" + CheckpointObjectManager._worker_pools["local"] = "fake_pool" pickled = pickle.dumps(manager) unpickled = pickle.loads(pickled) - assert unpickled._pool_config == config + assert unpickled._local_pool_config == config # Verify that pickling/unpickling works fine and doesn't crash on the class var # Note: Class vars are not pickled with instance, so unpickled instance sees whatever is on the class. 
- assert CheckpointObjectManager._worker_pool == "fake_pool" + assert CheckpointObjectManager._worker_pools["local"] == "fake_pool" # Cleanup - CheckpointObjectManager._worker_pool = None + CheckpointObjectManager._worker_pools.clear() def test_init_with_config(self): """Tests initialization with a pool configuration.""" config = BufferPoolConfig(pool_dir_path="/tmp", num_buffers=1, rank=0, buffer_size=0) - manager = CheckpointObjectManager(pool_config=config) - assert manager._pool_config == config + manager = CheckpointObjectManager(local_pool_config=config) + assert manager._local_pool_config == config def test_worker_side_pool_reuse(self, mocker): """Tests that multiple managers in the same process reuse the same BufferPool.""" # Clear registry first - CheckpointObjectManager._worker_pool = None + CheckpointObjectManager._worker_pools.clear() config = BufferPoolConfig(pool_dir_path="/tmp/reuse_test", num_buffers=1, rank=0, buffer_size=0) - manager1 = CheckpointObjectManager(pool_config=config) - manager2 = CheckpointObjectManager(pool_config=config) + manager1 = CheckpointObjectManager(local_pool_config=config) + manager2 = CheckpointObjectManager(local_pool_config=config) # Mock BufferPool to avoid actual creation mock_pool_cls = mocker.patch("ml_flashpoint.checkpoint_object_manager.checkpoint_object_manager.BufferPool") @@ -971,34 +978,32 @@ def test_worker_side_pool_reuse(self, mocker): # 1. First manager acquires -> creates pool mock_pool_instance = mock_pool_cls.return_value - mock_pool_instance.acquire.return_value = mocker.Mock(spec=BufferIO) - + mock_pool_instance.acquire.return_value = "/tmp/fake_pool_buffer.dist" + manager1.acquire_buffer(CheckpointObjectId("/tmp/reuse_test/1"), 100) manager1.acquire_buffer(CheckpointObjectId("/tmp/reuse_test/1"), 100) - assert CheckpointObjectManager._worker_pool == mock_pool_instance + assert CheckpointObjectManager._worker_pools["local"] == mock_pool_instance mock_pool_cls.assert_called_once() # 2. 
Second manager acquires -> should reuse pool manager2.acquire_buffer(CheckpointObjectId("/tmp/reuse_test/2"), 100) - assert CheckpointObjectManager._worker_pool == mock_pool_instance + assert CheckpointObjectManager._worker_pools["local"] == mock_pool_instance # Should NOT have called constructor again mock_pool_cls.assert_called_once() def test_teardown_clears_worker_registry(self, mocker): """Tests that teardown_pool removes the pool from the registry.""" - CheckpointObjectManager._worker_pool = None + CheckpointObjectManager._worker_pools.clear() config = BufferPoolConfig(pool_dir_path="/tmp/teardown_test", num_buffers=1, rank=0, buffer_size=0) - manager = CheckpointObjectManager(pool_config=config) + manager = CheckpointObjectManager(local_pool_config=config) # Manually populate registry mock_pool = mocker.Mock(spec=BufferPool) - CheckpointObjectManager._worker_pool = mock_pool + CheckpointObjectManager._worker_pools["local"] = mock_pool manager.teardown_pool() - - assert CheckpointObjectManager._worker_pool is None - mock_pool.teardown.assert_called_once() + assert not CheckpointObjectManager._worker_pools diff --git a/tests/core/test_buffer_pool.py b/tests/core/test_buffer_pool.py deleted file mode 100644 index eba656a..0000000 --- a/tests/core/test_buffer_pool.py +++ /dev/null @@ -1,288 +0,0 @@ -# Copyright 2025 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import os -from typing import Generator - -import pytest - -from ml_flashpoint.checkpoint_object_manager.buffer_io import BufferIO -from ml_flashpoint.checkpoint_object_manager.buffer_metadata import METADATA_SIZE -from ml_flashpoint.core.buffer_pool import BufferPool, BufferPoolConfig, PooledBufferIO - - -class TestBufferPool: - @pytest.fixture - def buffer_pool_config(self, tmp_path) -> dict: - pool_dir = tmp_path / ".buffer_pool" - return { - "pool_dir_path": str(pool_dir), - "rank": 0, - "num_buffers": 3, - "buffer_size": METADATA_SIZE + 1024, - } - - @pytest.fixture - def buffer_pool(self, buffer_pool_config) -> Generator[BufferPool, None, None]: - pool = BufferPool(**buffer_pool_config) - yield pool - - pool.teardown() - - def test_acquire_preallocated_buffer(self, buffer_pool, buffer_pool_config): - """Verifies that acquire reuses pre-allocated buffers.""" - buffer_io = buffer_pool.acquire() - assert buffer_io is not None - assert isinstance(buffer_io, PooledBufferIO) - - # Check that the buffer path follows the naming convention - buffer_id = buffer_io.buffer_obj.get_id() - assert "buffer_0_" in buffer_id - assert os.path.exists(buffer_id) - - # Verify it was removed from free_buffers - assert len(buffer_pool.free_buffers) == buffer_pool_config["num_buffers"] - 1 - assert len(buffer_pool.active_buffers) == 1 - assert buffer_id in buffer_pool.active_buffers - - # Cleanup - buffer_io.close() - assert buffer_io.closed - assert not buffer_io._buffer_io.closed - - def test_gc_releases_orphaned_buffers(self, buffer_pool, tmp_path): - """Verifies that GC correctly releases buffers with deleted symlinks.""" - symlink_path = str(tmp_path / "my_symlink") - - # Acquire with symlink=None - buf_io = buffer_pool.acquire(associated_symlink=None) - buffer_id = buf_io.buffer_obj.get_id() - - # Verify IT IS ACTIVE - assert buffer_id in buffer_pool.active_buffers - assert buffer_pool.active_buffers[buffer_id][1] is None - - # Trigger GC - should NOT release (pending 
registration) - buffer_pool._gc() - assert buffer_id in buffer_pool.active_buffers - - # Re-acquire with symlink - buf_io.close() - - # Create valid symlink path - buf_io = buffer_pool.acquire(associated_symlink=symlink_path) - buffer_id = buf_io.buffer_obj.get_id() - - assert os.path.islink(symlink_path) - assert buffer_id in buffer_pool.active_buffers - assert buffer_pool.active_buffers[buffer_id][1] == symlink_path - - # Delete symlink - os.remove(symlink_path) - - # Trigger GC - buffer_pool._gc() - - assert buffer_id not in buffer_pool.active_buffers - assert buf_io._buffer_io in buffer_pool.free_buffers - - def test_symlink_creation_failure(self, buffer_pool, tmp_path): - """Verifies that acquire fails and releases buffer if symlink creation fails.""" - symlink_path = str(tmp_path / "symlink_fail") - # Create a directory at symlink path to cause OSError - os.makedirs(symlink_path) - - with pytest.raises(OSError): - buffer_pool.acquire(associated_symlink=symlink_path) - - # Verify buffer was returned to free pool - assert len(buffer_pool.free_buffers) == 3 - # Since we don't know exactly which buffer was picked, checks size - assert len(buffer_pool.active_buffers) == 0 - - def test_init_invalid_args(self): - """Verifies that initialization raises ValueError for invalid arguments.""" - # Test num_buffers <= 0 - with pytest.raises(ValueError, match="Number of buffers must be positive"): - BufferPool(pool_dir_path="/tmp", num_buffers=0, buffer_size=1024) - - # Test buffer_size <= 0 - with pytest.raises(ValueError, match="Buffer size must be positive"): - BufferPool(pool_dir_path="/tmp", num_buffers=3, buffer_size=0) - - # Test missing pool_dir_path (empty string) - with pytest.raises(ValueError, match="Pool directory path must be provided"): - BufferPool(pool_dir_path="", num_buffers=3, buffer_size=1024) - - -class TestBufferIOProxy: - @pytest.fixture - def mock_buffer_io(self, mocker): - """Creates a mock BufferIO object.""" - buffer_io = 
mocker.Mock(spec=BufferIO) - buffer_io.closed = False - # Setup buffer_obj mock for capacity checks - buffer_io.buffer_obj = mocker.Mock() - buffer_io.buffer_obj.get_capacity.return_value = 1000 - buffer_io.tell.return_value = 0 - return buffer_io - - @pytest.fixture - def proxy(self, mock_buffer_io): - """Creates a BufferIOProxy wrapping the mock.""" - return PooledBufferIO(mock_buffer_io) - - def test_delegation_basic(self, proxy, mock_buffer_io): - """Verifies that basic methods are delegated to the underlying BufferIO.""" - # read - proxy.read(10) - mock_buffer_io.read.assert_called_with(10) - - # seek - proxy.seek(5, 1) - mock_buffer_io.seek.assert_called_with(5, 1) - - # tell - proxy.tell() - mock_buffer_io.tell.assert_called_once() - - # flush - proxy.flush() - mock_buffer_io.flush.assert_called_once() - - def test_properties(self, proxy, mock_buffer_io): - """Verifies property delegation.""" - # buffer_obj - assert proxy.buffer_obj is mock_buffer_io.buffer_obj - - # closed - assert not proxy.closed - mock_buffer_io.closed = True - assert proxy.closed - - def test_write_delegation_success(self, proxy, mock_buffer_io): - """Verifies write delegates correctly when no resize is needed.""" - data = b"test" - proxy.write(data) - mock_buffer_io.write.assert_called_with(data) - - def test_write_auto_resize(self, proxy, mock_buffer_io): - """Verifies write triggers auto-resize when capacity is insufficient.""" - # Setup mock to succeed immediately (resize happens proactively) - mock_buffer_io.write.return_value = 1500 - - # Current capacity 1000, current pos 0 - mock_buffer_io.buffer_obj.get_capacity.return_value = 1000 - mock_buffer_io.tell.return_value = 0 - - data = b"x" * 1500 # Needs more than 1000 - - # Call write - proxy.write(data) - - # Verify resize was called - # Calculation: max(1000 * 1.1, METADATA_SIZE + 0 + 1500 + PADDING_SIZE) - assert mock_buffer_io.resize.called - # Check that it called write exactly once (after resize) - assert 
mock_buffer_io.write.call_count == 1 - mock_buffer_io.write.assert_called_with(data) - - def test_next_buffer_slice_delegation_success(self, proxy, mock_buffer_io): - """Verifies next_buffer_slice delegates correctly.""" - proxy.next_buffer_slice(100) - mock_buffer_io.next_buffer_slice.assert_called_with(100) - - def test_next_buffer_slice_auto_resize(self, proxy, mock_buffer_io, mocker): - """Verifies next_buffer_slice triggers resize.""" - mock_buffer_io.next_buffer_slice.return_value = mocker.Mock() - - mock_buffer_io.buffer_obj.get_capacity.return_value = 1000 - mock_buffer_io.tell.return_value = 900 - - # Request 200 bytes (total 1100 > 1000) - proxy.next_buffer_slice(200) - - assert mock_buffer_io.resize.called - assert mock_buffer_io.next_buffer_slice.call_count == 1 - - def test_close_truncate(self, proxy, mock_buffer_io, mocker): - """Verifies close calls buffer_obj.resize if truncate is True.""" - mock_buffer_io.is_readonly = False - mock_buffer_io._metadata = mocker.Mock() - mock_buffer_io._metadata.len_written_data = 500 - mock_buffer_io._mv = range(1000) # Mock len() - - proxy.close(truncate=True) - - target = METADATA_SIZE + 500 - mock_buffer_io.resize.assert_called_with(target) - - assert proxy.closed - - def test_close_no_truncate(self, proxy, mock_buffer_io): - """Verifies close does not resize if truncate is False.""" - proxy.close(truncate=False) - mock_buffer_io.resize.assert_not_called() - assert proxy.closed - - -class TestBufferPoolConfig: - def test_valid_config(self): - """Tests that a valid configuration does not raise any exceptions.""" - config = BufferPoolConfig(pool_dir_path="/tmp/pool", rank=0, num_buffers=3, buffer_size=1024) - assert config.pool_dir_path == "/tmp/pool" - assert config.rank == 0 - assert config.num_buffers == 3 - assert config.buffer_size == 1024 - - def test_invalid_pool_dir_path(self): - """Tests that missing pool_dir_path raises ValueError.""" - with pytest.raises(ValueError, match="pool_dir_path must be 
provided"): - BufferPoolConfig(pool_dir_path="", rank=0, num_buffers=3, buffer_size=1024) - - def test_invalid_rank(self): - """Tests that invalid rank raises ValueError.""" - with pytest.raises(ValueError, match="rank must be a non-negative integer"): - BufferPoolConfig(pool_dir_path="/tmp/pool", rank=-1, num_buffers=3, buffer_size=1024) - with pytest.raises(ValueError, match="rank must be a non-negative integer"): - BufferPoolConfig( - pool_dir_path="/tmp/pool", - rank="0", # type: ignore - num_buffers=3, - buffer_size=1024, - ) - - def test_invalid_num_buffers(self): - """Tests that invalid num_buffers raises ValueError.""" - with pytest.raises(ValueError, match="num_buffers must be a non-negative integer"): - BufferPoolConfig(pool_dir_path="/tmp/pool", rank=0, num_buffers=-1, buffer_size=1024) - with pytest.raises(ValueError, match="num_buffers must be a non-negative integer"): - BufferPoolConfig( - pool_dir_path="/tmp/pool", - rank=0, - num_buffers="3", # type: ignore - buffer_size=1024, - ) - - def test_invalid_buffer_size(self): - """Tests that invalid buffer_size raises ValueError.""" - with pytest.raises(ValueError, match="buffer_size must be a non-negative integer"): - BufferPoolConfig(pool_dir_path="/tmp/pool", rank=0, num_buffers=3, buffer_size=-100) - with pytest.raises(ValueError, match="buffer_size must be a non-negative integer"): - BufferPoolConfig( - pool_dir_path="/tmp/pool", - rank=0, - num_buffers=3, - buffer_size="1024", # type: ignore - ) diff --git a/tests/core/test_checkpoint_saver.py b/tests/core/test_checkpoint_saver.py index 1659b2d..d76e739 100644 --- a/tests/core/test_checkpoint_saver.py +++ b/tests/core/test_checkpoint_saver.py @@ -109,7 +109,7 @@ def chkpt_object_manager(self, temp_dir_path): CheckpointObjectManager._worker_pool = None config = BufferPoolConfig(pool_dir_path=pool_dir, rank=0, num_buffers=3, buffer_size=1024 * 1024) - manager = CheckpointObjectManager(pool_config=config) + manager = 
CheckpointObjectManager(local_pool_config=config) yield manager # Teardown diff --git a/tests/core/test_cpp_buffer_pool.py b/tests/core/test_cpp_buffer_pool.py new file mode 100644 index 0000000..8c0aef0 --- /dev/null +++ b/tests/core/test_cpp_buffer_pool.py @@ -0,0 +1,111 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import time +import multiprocessing as mp +import pytest + +from ml_flashpoint.checkpoint_object_manager.buffer_object.buffer_object_ext import BufferPool + +def worker_target(shm_name, pool_dir, rank, num_buffers, buffer_size, result_queue): + """Worker process that tries to acquire a buffer.""" + try: + # Attach to pool + pool = BufferPool(shm_name=shm_name, pool_dir=pool_dir, rank=rank, + num_buffers=num_buffers, buffer_size=buffer_size) + + try: + buf_path = pool.acquire() + result_queue.put(("SUCCESS", buf_path)) + except RuntimeError as e: + result_queue.put(("ERROR", str(e))) + + except Exception as e: + result_queue.put(("EXCEPTION", str(e))) + +class TestCppBufferPoolMultiprocess: + def test_shared_pool_exhaustion(self, tmp_path): + """Verifies that pool is shared across processes and exhaustion is respected.""" + # Use a unique shm name to avoid conflicts with running tests + shm_name = f"/test_shm_pool_{int(time.time())}" + pool_dir = str(tmp_path / "pool_dir") + os.makedirs(pool_dir) + rank = 0 + num_buffers = 2 + buffer_size = 1024 + + # 1. 
Initialize pool in main process + pool = BufferPool(shm_name=shm_name, pool_dir=pool_dir, rank=rank, + num_buffers=num_buffers, buffer_size=buffer_size) + + # 2. Acquire all buffers in main process + path1 = pool.acquire() + path2 = pool.acquire() + + # 3. Start worker process to try to acquire another buffer + result_queue = mp.Queue() + p = mp.Process(target=worker_target, args=(shm_name, pool_dir, rank, num_buffers, buffer_size, result_queue)) + p.start() + + p.join(timeout=5) + if p.is_alive(): + p.terminate() + pytest.fail("Worker process timed out") + + status, message = result_queue.get() + assert status == "ERROR" + assert "BufferPool exhausted" in message + + # 4. Release one buffer in main process + pool.release(path1) + + # 5. Start another worker process + p2 = mp.Process(target=worker_target, args=(shm_name, pool_dir, rank, num_buffers, buffer_size, result_queue)) + p2.start() + p2.join(timeout=5) + + status, message = result_queue.get() + assert status == "SUCCESS" + assert message == path1 # Should get the released buffer + + def test_buffer_resize_via_symlink(self, tmp_path): + """Verifies that BufferObject resizes the pooled buffer when opened via symlink with overwrite=True.""" + shm_name = f"/test_shm_pool_resize_{int(time.time())}" + pool_dir = str(tmp_path / "pool_dir") + os.makedirs(pool_dir) + rank = 0 + num_buffers = 1 + buffer_size = 1024 + + pool = BufferPool(shm_name=shm_name, pool_dir=pool_dir, rank=rank, + num_buffers=num_buffers, buffer_size=buffer_size) + + symlink_path = str(tmp_path / "my_symlink") + path1 = pool.acquire(symlink_path) + + assert os.path.islink(symlink_path) + + target_path = os.readlink(symlink_path) + assert os.path.getsize(target_path) == buffer_size + + new_size = 2048 + from ml_flashpoint.checkpoint_object_manager.buffer_object.buffer_object_ext import BufferObject + + buffer_obj = BufferObject(symlink_path, new_size, True) + + assert os.path.getsize(target_path) == new_size + + buffer_obj.close() + 
pool.release(symlink_path) diff --git a/tests/replication/test_replication_manager.py b/tests/replication/test_replication_manager.py index 7820144..13aabc4 100644 --- a/tests/replication/test_replication_manager.py +++ b/tests/replication/test_replication_manager.py @@ -59,7 +59,9 @@ def mock_gather(): # Then assert replication_manager._listen_port == 12345 - mock_transfer_service_instance.initialize.assert_called_once_with(0, global_rank=0) + mock_transfer_service_instance.initialize.assert_called_once_with( + 0, global_rank=0, repl_shm_name=mock_checkpoint_manager.replication_pool_shm_name + ) def test_initialize_binds_port_failure(replication_manager, mocker): diff --git a/tests/replication/test_replication_manager_e2e.py b/tests/replication/test_replication_manager_e2e.py index 002f109..b7a4b55 100644 --- a/tests/replication/test_replication_manager_e2e.py +++ b/tests/replication/test_replication_manager_e2e.py @@ -25,9 +25,9 @@ from ml_flashpoint.replication.transfer_service import transfer_service_ext -def run_service(service, port_container, rank): +def run_service(service, port_container, rank, repl_shm_name=""): """Target function for running a TransferService in a thread.""" - port = service.initialize(0, global_rank=rank) + port = service.initialize(0, global_rank=rank, repl_shm_name=repl_shm_name) port_container.append(port) @@ -191,3 +191,104 @@ def test_sync_bulk_retrieve_end_to_end(tmp_path, services, mocker): retrieved_data = retrieved_buffer_io.read() assert retrieved_data == original_data_map[obj_id] ckpt_obj_manager.close_buffer(retrieved_buffer_io) + + +@pytest.mark.e2e +def test_sync_bulk_retrieve_with_buffer_pool_and_resize(tmp_path, services, mocker): + """ + An end-to-end test for sync_bulk_retrieve using BufferPool and triggering resize. + """ + # Given + # We ignore receiver_service from fixture because we need to create one with pool attached! 
+ sender_service, _, sender_addr, _ = services + + # Mock torch.distributed + mocker.patch("torch.distributed.get_rank", return_value=1) + mocker.patch("torch.distributed.get_world_size", return_value=2) + mocker.patch("torch.cuda.device_count", return_value=1) + mocker.patch("ml_flashpoint.core.utils.get_num_of_nodes", return_value=2) + + # METADATA_SIZE is 4096. We need buffer larger than that. + metadata_size = 4096 + + # Create a BufferObject on the sender side with size metadata_size + 2048 + obj_id = str(tmp_path / "test_buffer_large") + capacity = metadata_size + 2048 + sender_bo = buffer_object_ext.BufferObject(obj_id, capacity, overwrite=True) + original_data = os.urandom(2048) + + buffer_io = BufferIO(sender_bo) + buffer_io.write(original_data) + buffer_io.close() + + # Setup CheckpointObjectManager with BufferPoolConfig on receiver + # Set default buffer_size small (e.g. metadata_size + 1024) to trigger resize! + pool_dir = tmp_path / "pool_dir" + os.makedirs(pool_dir) + from ml_flashpoint.core.buffer_pool import BufferPoolConfig + pool_config = BufferPoolConfig( + pool_dir_path=str(pool_dir), + rank=1, + num_buffers=2, + buffer_size=metadata_size + 1024, # Smaller than metadata_size + 2048! + ) + ckpt_obj_manager = CheckpointObjectManager(repl_pool_config=pool_config) + + # Create and start a NEW receiver TransferService with pool attached! 
+ new_receiver_service = transfer_service_ext.TransferService() + repl_shm_name = ckpt_obj_manager.replication_pool_shm_name + receiver_port_container = [] + + import threading + import time + + receiver_thread = threading.Thread( + target=run_service, + args=(new_receiver_service, receiver_port_container, 1, repl_shm_name) + ) + receiver_thread.daemon = True + receiver_thread.start() + + start_time = time.time() + while not receiver_port_container and time.time() - start_time < 5: + time.sleep(0.1) + assert receiver_port_container, "New receiver service failed to start" + + manager = ReplicationManager() + strategy = PairwiseReplicationStrategy( + replication_service_addresses=[sender_addr, "127.0.0.1:0"], processes_per_node=1 + ) + manager.initialize( + checkpoint_object_manager=ckpt_obj_manager, + replication_transfer_service=new_receiver_service, + repl_strategy=strategy, + ) + + retrieved_obj_id = obj_id + "_retrieved" + + try: + # When + success = manager.sync_bulk_retrieve( + source_global_rank=0, + object_ids_to_retrieve=[obj_id], + container_ids_to_retrieve=[], + retrieved_object_ids=[retrieved_obj_id], + retrieved_container_ids=[], + ) + + # Then + assert success + + # Verify that retrieved_obj_id is a symlink! 
+ assert os.path.islink(retrieved_obj_id) + + # Verify the retrieved data + retrieved_bo = buffer_object_ext.BufferObject(retrieved_obj_id) + retrieved_buffer_io = BufferIO(retrieved_bo) + retrieved_data = retrieved_buffer_io.read() + assert retrieved_data == original_data + ckpt_obj_manager.close_buffer(retrieved_buffer_io) + + finally: + new_receiver_service.shutdown() + receiver_thread.join(timeout=5) diff --git a/tests/replication/test_transer_service.py b/tests/replication/test_transer_service.py index 9cb3890..9a2a2b0 100644 --- a/tests/replication/test_transer_service.py +++ b/tests/replication/test_transer_service.py @@ -351,3 +351,55 @@ def test_async_get_non_existent_object( # Then with pytest.raises(RuntimeError, match="Received error message"): get_future.result(timeout=10) + + +def test_transfer_to_symlink( + transfer_services: tuple[ + transfer_service_ext.TransferService, + transfer_service_ext.TransferService, + str, + str, + ], +) -> None: + """Verifies that async_put follows symlinks when writing to destination.""" + service1, _, _, addr2 = transfer_services + + target_id = "test_target.object" + symlink_id = "test_symlink.object" + data_to_send = np.arange(20, dtype=np.int64) + expected_content = data_to_send.tobytes() + + # Create the target file + with open(target_id, "wb") as f: + f.write(b"initial data") + + # Create the symlink + pathlib.Path(symlink_id).symlink_to(target_id) + + try: + # Call the async_put method to transfer data from service1 to service2. + put_future = service1.async_put( + data_to_send.ctypes.data, # data_ptr + data_to_send.nbytes, # data_size + addr2, # dest_address + symlink_id, # dest_object_id (the symlink!) + ) + + # Block until the C++ future completes and get the result. + result = put_future.result(timeout=10) + + assert result.success is True, "The 'success' flag should be True." 
+ + # Verify that the symlink is still a symlink + assert pathlib.Path(symlink_id).is_symlink(), "The symlink should still be a symlink." + + # Verify the content of the target file + with open(target_id, "rb") as f: + received_content = f.read() + assert received_content == expected_content, "Received content in target file does not match." + + finally: + # Clean up + safe_remove_file(symlink_id) + safe_remove_file(target_id) +