diff --git a/src/ml_flashpoint/replication/transfer_service/connection_pool.cpp b/src/ml_flashpoint/replication/transfer_service/connection_pool.cpp index e205d33..5b97377 100644 --- a/src/ml_flashpoint/replication/transfer_service/connection_pool.cpp +++ b/src/ml_flashpoint/replication/transfer_service/connection_pool.cpp @@ -51,7 +51,7 @@ ScopedConnection::~ScopedConnection() { Release(); } void ScopedConnection::Release() { if (pool_ != nullptr && sockfd_ >= 0) { - pool_->ReleaseConnection(sockfd_, true); + pool_->ReleaseConnection(sockfd_, reuse_); } else if (sockfd_ >= 0) { close(sockfd_); } @@ -81,7 +81,7 @@ ConnectionPool::~ConnectionPool() { cv_.notify_all(); std::unique_lock lock(mtx_); while (!available_connections_.empty()) { - close(available_connections_.front()); + close(available_connections_.top()); available_connections_.pop(); } } @@ -96,7 +96,7 @@ bool ConnectionPool::Initialize() { int fd = CreateConnection(); if (fd < 0) { while (!available_connections_.empty()) { - close(available_connections_.front()); + close(available_connections_.top()); available_connections_.pop(); } return false; @@ -154,38 +154,109 @@ int ConnectionPool::CreateConnection() { return -1; } +bool ConnectionPool::IsConnectionAlive(int sockfd) { + if (sockfd < 0) return false; + + char buf; + // We use MSG_PEEK | MSG_DONTWAIT to check the status of the TCP connection + // without actually consuming any data from the socket's receive buffer. + // This is a fast, zero-copy way to ask the kernel if the connection has + // been closed or has encountered an error while it was idle in the pool. + ssize_t r = recv(sockfd, &buf, 1, MSG_PEEK | MSG_DONTWAIT); + + if (r == 0) { + // If recv returns 0, it means the remote peer has performed an orderly + // shutdown (sent a FIN packet). The connection is no longer usable for + // sending new requests. 
+ return false; + } else if (r < 0) { + // If recv returns -1, we check errno to distinguish between "no data" + // and a real network error. + if (errno == EAGAIN || errno == EWOULDBLOCK) { + // EAGAIN/EWOULDBLOCK means the connection is still alive and healthy, + // but there is currently no data waiting to be read. This is the + // expected state for an idle connection in the pool. + return true; + } + // Any other error (like ECONNRESET or EPIPE) indicates that the + // connection has been broken or timed out. + return false; + } + // If r > 0, there is actual data waiting in the buffer. While unusual for + // an idle pool connection, it indicates the connection is definitely alive. + return true; +} + std::optional ConnectionPool::GetConnection(int timeout_ms) { CHECK_GT(timeout_ms, 0) << "timeout_ms must be positive"; - std::unique_lock lock(mtx_); - if (!cv_.wait_for(lock, std::chrono::milliseconds(timeout_ms), [this] { - return !available_connections_.empty() || stopping_; - })) { - LOG(WARNING) << "ConnectionPool::GetConnection: timeout"; - return std::nullopt; - } - if (stopping_) { - LOG(WARNING) << "ConnectionPool::GetConnection: stopping"; - return std::nullopt; - } - if (available_connections_.empty()) { - // TODO: Handle the case when we run out of connections - LOG(WARNING) << "ConnectionPool::GetConnection: no available connections"; - return std::nullopt; + + // Calculate the absolute deadline to ensure we respect the user-provided + // timeout even if we have to loop through several dead connections. + auto start_time = std::chrono::steady_clock::now(); + auto end_time = start_time + std::chrono::milliseconds(timeout_ms); + + while (true) { + std::unique_lock lock(mtx_); + + // Re-calculate the remaining wait time for each iteration of the loop. 
+ auto remaining = std::chrono::duration_cast( + end_time - std::chrono::steady_clock::now()); + + if (remaining.count() <= 0 || !cv_.wait_for(lock, remaining, [this] { + return !available_connections_.empty() || stopping_; + })) { + LOG(WARNING) << "ConnectionPool::GetConnection: timeout reached while " + "searching for a healthy connection"; + return std::nullopt; + } + + if (stopping_) { + LOG(WARNING) << "ConnectionPool::GetConnection: pool is shutting down"; + return std::nullopt; + } + + // Pop the most recently used connection from the LIFO stack. + // LIFO (Last-In, First-Out) is preferred for connection pools as it + // increases the likelihood of reusing "hot" connections that still have + // active TCP state (e.g., large congestion windows) and are still cached + // in the kernel/CPU. + int fd = available_connections_.top(); + available_connections_.pop(); + + // Verify the connection's health before handing it to the caller. + // This protects against "stale" connections that were closed by the + // peer or a firewall while sitting idle in the pool. + if (IsConnectionAlive(fd)) { + return ScopedConnection(fd, this); + } + + // The connection is dead. We close it and attempt to retrieve another + // one from the stack. + LOG(INFO) << "ConnectionPool::GetConnection: discarded dead connection; " + "retrying with next available connection"; + close(fd); + + // To maintain the desired pool size, we immediately attempt to open a + // replacement connection. This ensures the pool doesn't slowly drain + // if many connections go stale at once. + int new_fd = CreateConnection(); + if (new_fd >= 0) { + available_connections_.push(new_fd); + // The loop will continue and pick up this or another connection. + } } - int fd = available_connections_.front(); - available_connections_.pop(); - return ScopedConnection(fd, this); } // Returns a connection to the pool, allowing it to be reused. 
// // If `reuse` is true and the pool is not full, the connection is added back to -// the queue of available connections. Otherwise, the connection is closed. +// the stack of available connections. Otherwise, the connection is closed. void ConnectionPool::ReleaseConnection(int sockfd, bool reuse) { if (sockfd < 0) { LOG(WARNING) << "ConnectionPool::ReleaseConnection: invalid sockfd"; return; } + std::unique_lock lock(mtx_); if (stopping_) { LOG(WARNING) @@ -193,11 +264,12 @@ void ConnectionPool::ReleaseConnection(int sockfd, bool reuse) { close(sockfd); return; } + if (reuse) { if (available_connections_.size() < max_size_) { LOG(INFO) << "ConnectionPool::ReleaseConnection: reuse connection"; - // TODO: Check if we need cleanup for the connection before return it to - // the pool + // We push to the stack to ensure this connection is the first to be + // reused by the next caller (LIFO). available_connections_.push(sockfd); cv_.notify_one(); } else { @@ -207,8 +279,22 @@ void ConnectionPool::ReleaseConnection(int sockfd, bool reuse) { } } else { LOG(INFO) - << "ConnectionPool::ReleaseConnection: do not reuse, close connection"; + << "ConnectionPool::ReleaseConnection: connection marked as unusable, " + "closing and replenishing pool"; close(sockfd); + + // Since we are discarding a connection that was previously part of the + // pool's "active" set, we create a new one to maintain the fixed pool size. + // This prevents the pool from permanently shrinking when network errors + // occur. 
+ int new_fd = CreateConnection(); + if (new_fd >= 0) { + available_connections_.push(new_fd); + cv_.notify_one(); + } else { + LOG(ERROR) << "ConnectionPool::ReleaseConnection: failed to replenish " + "pool after discarding unusable connection"; + } } } } // namespace ml_flashpoint::replication::transfer_service diff --git a/src/ml_flashpoint/replication/transfer_service/connection_pool.h b/src/ml_flashpoint/replication/transfer_service/connection_pool.h index 229cac5..715a62a 100644 --- a/src/ml_flashpoint/replication/transfer_service/connection_pool.h +++ b/src/ml_flashpoint/replication/transfer_service/connection_pool.h @@ -35,7 +35,7 @@ #include #include #include -#include +#include #include namespace ml_flashpoint::replication::transfer_service { @@ -57,11 +57,17 @@ class ScopedConnection { int fd() const { return sockfd_; } bool IsValid() const { return sockfd_ >= 0; } + + // Marks the connection as unusable (e.g., after a socket error). + // This prevents it from being returned to the pool for reuse. + void SetUnusable() { reuse_ = false; } + void Release(); private: int sockfd_; ConnectionPool* pool_; + bool reuse_ = true; }; // Manages a thread-safe pool of TCP connections to a single peer. @@ -108,6 +114,9 @@ class ConnectionPool { // Releases a connection back to the pool or closes it. void ReleaseConnection(int sockfd, bool reuse = true); + // Checks if a connection is still alive by performing a non-blocking peek. + bool IsConnectionAlive(int sockfd); + // Creates a new connection to the peer. // Returns the socket file descriptor on a successful connection, or -1 on // failure. @@ -116,7 +125,7 @@ class ConnectionPool { std::string peer_host_; int peer_port_; size_t max_size_; - std::queue available_connections_; // Guarded by mtx_. + std::stack available_connections_; // Guarded by mtx_. std::mutex mtx_; // Protects available_connections_ and stopping_. std::condition_variable cv_; // Signaled when a connection is released or the pool is stopping. 
diff --git a/src/ml_flashpoint/replication/transfer_service/transfer_service.cpp b/src/ml_flashpoint/replication/transfer_service/transfer_service.cpp index ea7f735..e848274 100644 --- a/src/ml_flashpoint/replication/transfer_service/transfer_service.cpp +++ b/src/ml_flashpoint/replication/transfer_service/transfer_service.cpp @@ -635,59 +635,87 @@ void TransferService::ExecutePutTask(PutTask* task) { auto& metric_container = task->GetMetricContainer(); metric_container.start_execution_time = absl::Now(); LOG(INFO) << "Executing PutTask for task_id=" << task->GetTaskId(); - auto conn_opt = GetConnectionFromPool(task->GetDestAddr()); - if (!conn_opt) { - ReportResult(task->GetTaskId(), false, "Failed to get connection"); - return; - } - ScopedConnection conn = std::move(conn_opt.value()); - metric_container.connection_acquired_time = absl::Now(); - - ObjInfoHeader header; - std::memset(&header, 0, sizeof(ObjInfoHeader)); - snprintf(header.dest_obj_id, sizeof(header.dest_obj_id), "%s", - task->GetDestObjId().c_str()); - header.type = MessageType::kPutObj; - header.obj_size = task->GetDataSize(); - - int sockfd = conn.fd(); - if (!SendAll(sockfd, &header, kHeaderSize).ok()) { - LOG(ERROR) << "perform_send_obj: Failed sending header/filename for " - << task->GetDestAddr(); - ReportResult(task->GetTaskId(), false, "Failed to send header"); - return; - } - metric_container.header_sent_time = absl::Now(); - - if (!SendAll(sockfd, task->GetDataPtr(), task->GetDataSize()).ok()) { - LOG(ERROR) << "perform_send_obj: Failed sending data for " - << task->GetDestAddr(); - ReportResult(task->GetTaskId(), false, "Failed to send data"); - return; - } - metric_container.data_sent_time = absl::Now(); + // We use a retry loop to handle cases where a connection from the pool + // appears healthy but fails during the initial handshake (e.g., due to + // a race condition where the remote peer closes the socket just as we + // retrieve it). 
+ constexpr int kMaxRetries = 2; + for (int attempt = 0; attempt < kMaxRetries; ++attempt) { + auto conn_opt = GetConnectionFromPool(task->GetDestAddr()); + if (!conn_opt) { + ReportResult(task->GetTaskId(), false, "Failed to get connection"); + return; + } + ScopedConnection conn = std::move(conn_opt.value()); + metric_container.connection_acquired_time = absl::Now(); + + ObjInfoHeader header; + std::memset(&header, 0, sizeof(ObjInfoHeader)); + snprintf(header.dest_obj_id, sizeof(header.dest_obj_id), "%s", + task->GetDestObjId().c_str()); + header.type = MessageType::kPutObj; + header.obj_size = task->GetDataSize(); + + int sockfd = conn.fd(); + + // Step 1: Send the Object Metadata Header. + if (!SendAll(sockfd, &header, kHeaderSize).ok()) { + LOG(WARNING) << "ExecutePutTask: Failed sending header on attempt " + << attempt; + // Socket error occurred. Mark connection as unusable so it is closed + // instead of returned to the pool, preventing "pool poisoning." + conn.SetUnusable(); + if (attempt < kMaxRetries - 1) continue; // Try with a fresh connection. + ReportResult(task->GetTaskId(), false, "Failed to send header"); + return; + } + metric_container.header_sent_time = absl::Now(); + + // Step 2: Send the actual object data. + if (!SendAll(sockfd, task->GetDataPtr(), task->GetDataSize()).ok()) { + LOG(WARNING) << "ExecutePutTask: Failed sending data on attempt " + << attempt; + conn.SetUnusable(); + if (attempt < kMaxRetries - 1) continue; + ReportResult(task->GetTaskId(), false, "Failed to send data"); + return; + } + metric_container.data_sent_time = absl::Now(); + + // Step 3: Wait for Acknowledgement from the destination. 
+ ObjInfoHeader ack_header; + if (!RecvHeader(sockfd, ack_header).ok()) { + LOG(WARNING) << "ExecutePutTask: Failed to receive ACK on attempt " + << attempt; + conn.SetUnusable(); + if (attempt < kMaxRetries - 1) continue; + ReportResult(task->GetTaskId(), false, "Failed to receive ACK"); + return; + } - ObjInfoHeader ack_header; - if (!RecvHeader(sockfd, ack_header).ok()) { - ReportResult(task->GetTaskId(), false, "Failed to receive ACK"); - return; - } - switch (ack_header.type) { - case MessageType::kAck: - LOG(INFO) << "Buffer data sent and ACK received successfully !!!"; - metric_container.finish_time = absl::Now(); - ReportResult(task->GetTaskId(), true, - "Buffer data sent and ACK received"); - break; - case MessageType::kError: - ReportResult(task->GetTaskId(), false, "Received error from destination"); - break; - case MessageType::kPutObj: - case MessageType::kGetObj: - case MessageType::kRespondToGetObj: - ReportResult(task->GetTaskId(), false, "Received unexpected ACK"); - break; + // Exhaustive handling of message types to ensure no unexpected states. + switch (ack_header.type) { + case MessageType::kAck: + LOG(INFO) << "Buffer data sent and ACK received successfully !!!"; + metric_container.finish_time = absl::Now(); + ReportResult(task->GetTaskId(), true, + "Buffer data sent and ACK received"); + return; + case MessageType::kError: + // Remote peer explicitly reported a failure. Retrying is unlikely + // to help if the error is logic-related (e.g., disk full). + ReportResult(task->GetTaskId(), false, + "Received error from destination"); + return; + case MessageType::kPutObj: + case MessageType::kGetObj: + case MessageType::kRespondToGetObj: + // Receiving a request type instead of an ACK is a protocol violation. 
+ ReportResult(task->GetTaskId(), false, + "Received unexpected message type"); + return; + } } } @@ -786,55 +814,75 @@ void TransferService::ExecuteGetTask(GetTask* task) { metric_container.start_execution_time = absl::Now(); LOG(INFO) << "Executing GetTask for source address " << task->GetSourceAddr() << ", source obj id: " << task->GetSourceObjId(); - auto conn_opt = GetConnectionFromPool(task->GetSourceAddr()); - if (!conn_opt) { - ReportResult(task->GetTaskId(), false, "Failed to get connection"); - return; - } - ScopedConnection conn = std::move(conn_opt.value()); - metric_container.connection_acquired_time = absl::Now(); - ObjInfoHeader header; - header.type = MessageType::kGetObj; - snprintf(header.task_id, sizeof(header.task_id), "%s", - task->GetTaskId().c_str()); - snprintf(header.source_obj_id, sizeof(header.source_obj_id), "%s", - task->GetSourceObjId().c_str()); - snprintf(header.dest_obj_id, sizeof(header.dest_obj_id), "%s", - task->GetDestObjId().c_str()); - snprintf(header.source_address, sizeof(header.source_address), "%s", - task->GetSourceAddr().c_str()); - snprintf(header.dest_address, sizeof(header.dest_address), "%s", - task->GetDestAddr().c_str()); - header.obj_size = 0; // Not used for request - - if (!SendAll(conn.fd(), &header, kHeaderSize).ok()) { - LOG(ERROR) << "Failed to send GET_OBJ header"; - ReportResult(task->GetTaskId(), false, "Failed to send GET_OBJ header"); - return; - } - metric_container.header_sent_time = absl::Now(); - - // Wait for the immediate response (ACK or ERROR) - ObjInfoHeader resp_header; - if (!RecvHeader(conn.fd(), resp_header).ok()) { - ReportResult(task->GetTaskId(), false, - "Failed to receive response for GET request"); - return; - } - switch (resp_header.type) { - case MessageType::kAck: - LOG(INFO) << "Received ACK for GET request. 
Waiting for data transfer."; - break; - case MessageType::kError: - ReportResult(task->GetTaskId(), false, "Received error message"); - break; - case MessageType::kPutObj: - case MessageType::kGetObj: - case MessageType::kRespondToGetObj: + // Retry loop to handle transient connection drops during the Get request. + constexpr int kMaxRetries = 2; + for (int attempt = 0; attempt < kMaxRetries; ++attempt) { + auto conn_opt = GetConnectionFromPool(task->GetSourceAddr()); + if (!conn_opt) { + ReportResult(task->GetTaskId(), false, "Failed to get connection"); + return; + } + ScopedConnection conn = std::move(conn_opt.value()); + metric_container.connection_acquired_time = absl::Now(); + + ObjInfoHeader header; + header.type = MessageType::kGetObj; + snprintf(header.task_id, sizeof(header.task_id), "%s", + task->GetTaskId().c_str()); + snprintf(header.source_obj_id, sizeof(header.source_obj_id), "%s", + task->GetSourceObjId().c_str()); + snprintf(header.dest_obj_id, sizeof(header.dest_obj_id), "%s", + task->GetDestObjId().c_str()); + snprintf(header.source_address, sizeof(header.source_address), "%s", + task->GetSourceAddr().c_str()); + snprintf(header.dest_address, sizeof(header.dest_address), "%s", + task->GetDestAddr().c_str()); + header.obj_size = 0; // Not used for request + + // Send the GET request header. + if (!SendAll(conn.fd(), &header, kHeaderSize).ok()) { + LOG(WARNING) + << "ExecuteGetTask: Failed to send GET_OBJ header on attempt " + << attempt; + conn.SetUnusable(); + if (attempt < kMaxRetries - 1) continue; + ReportResult(task->GetTaskId(), false, "Failed to send GET_OBJ header"); + return; + } + metric_container.header_sent_time = absl::Now(); + + // Wait for the remote peer to confirm they have the object and are + // starting the transfer. 
+ ObjInfoHeader resp_header; + if (!RecvHeader(conn.fd(), resp_header).ok()) { + LOG(WARNING) << "ExecuteGetTask: Failed to receive response on attempt " + << attempt; + conn.SetUnusable(); + if (attempt < kMaxRetries - 1) continue; ReportResult(task->GetTaskId(), false, - "Received unexpected response for GET request"); - break; + "Failed to receive response for GET request"); + return; + } + + switch (resp_header.type) { + case MessageType::kAck: + // Request accepted. The remote peer will now initiate a separate + // connection back to us to send the data. + LOG(INFO) << "Received ACK for GET request. Waiting for data transfer."; + return; + case MessageType::kError: + // Remote peer could not find the object or encountered an error. + ReportResult(task->GetTaskId(), false, "Received error message"); + return; + case MessageType::kPutObj: + case MessageType::kGetObj: + case MessageType::kRespondToGetObj: + // Unexpected message type in this context. + ReportResult(task->GetTaskId(), false, + "Received unexpected response type for GET request"); + return; + } } } @@ -847,75 +895,98 @@ void TransferService::ExecuteRespondToGetTask(RespondToGetTask* task) { << ", source_addr=" << task->GetSourceAddr() << ", dest_addr=" << task->GetDestAddr(); - auto conn_opt = GetConnectionFromPool(task->GetDestAddr()); - if (!conn_opt) { - LOG(ERROR) << "Failed to get connection!"; - ReportResult(task->GetTaskId(), false, "Failed to get connection"); - return; - } - ScopedConnection conn = std::move(conn_opt.value()); - metric_container.connection_acquired_time = absl::Now(); - - // Open file as buffer object - BufferObject buffer_obj(task->GetSourceObjId()); - void* buffer_data_ptr = buffer_obj.get_data_ptr(); - size_t size = buffer_obj.get_capacity(); - - if (buffer_data_ptr == nullptr) { - LOG(ERROR) << "RespondToGetTask failed: Could not open buffer object for '" - << task->GetSourceObjId() << "'"; - ReportResult(task->GetTaskId(), false, "Failed to create buffer object"); - 
return; - } - - ObjInfoHeader header; - - std::memset(header.task_id, 0, sizeof(header.task_id)); - snprintf(header.task_id, sizeof(header.task_id), "%s", - task->GetTaskId().c_str()); - header.task_id[sizeof(header.task_id) - 1] = '\0'; - - std::memset(header.source_obj_id, 0, sizeof(header.source_obj_id)); - snprintf(header.source_obj_id, sizeof(header.source_obj_id), "%s", - task->GetSourceObjId().c_str()); - - std::memset(header.dest_obj_id, 0, sizeof(header.dest_obj_id)); - snprintf(header.dest_obj_id, sizeof(header.dest_obj_id), "%s", - task->GetDestObjId().c_str()); - - header.type = MessageType::kRespondToGetObj; - header.obj_size = size; - int sockfd = conn.fd(); - - if (!SendAll(sockfd, &header, kHeaderSize).ok()) { - LOG(ERROR) << "Failed to send kRespondToGetObj header"; - ReportResult(task->GetTaskId(), false, "Failed to send header"); - return; - } - metric_container.header_sent_time = absl::Now(); - - if (!SendAll(sockfd, buffer_data_ptr, size).ok()) { - LOG(ERROR) << "Failed to send buffer data"; - ReportResult(task->GetTaskId(), false, "Failed to send data"); - return; - } - metric_container.data_sent_time = absl::Now(); + // Retry loop to handle connection failures when trying to push the + // requested data back to the original requester. + constexpr int kMaxRetries = 2; + for (int attempt = 0; attempt < kMaxRetries; ++attempt) { + auto conn_opt = GetConnectionFromPool(task->GetDestAddr()); + if (!conn_opt) { + LOG(ERROR) << "Failed to get connection!"; + ReportResult(task->GetTaskId(), false, "Failed to get connection"); + return; + } + ScopedConnection conn = std::move(conn_opt.value()); + metric_container.connection_acquired_time = absl::Now(); + + // Prepare the local data buffer. 
+ BufferObject buffer_obj(task->GetSourceObjId()); + void* buffer_data_ptr = buffer_obj.get_data_ptr(); + size_t size = buffer_obj.get_capacity(); + + if (buffer_data_ptr == nullptr) { + LOG(ERROR) + << "RespondToGetTask failed: Could not open buffer object for '" + << task->GetSourceObjId() << "'"; + ReportResult(task->GetTaskId(), false, "Failed to create buffer object"); + return; + } - ObjInfoHeader ack_header; - if (!RecvHeader(sockfd, ack_header).ok()) { - LOG(ERROR) << "Failed to receive ACK"; - ReportResult(task->GetTaskId(), false, "Failed to receive ACK"); - return; - } + ObjInfoHeader header; + std::memset(&header, 0, sizeof(header)); + snprintf(header.task_id, sizeof(header.task_id), "%s", + task->GetTaskId().c_str()); + snprintf(header.source_obj_id, sizeof(header.source_obj_id), "%s", + task->GetSourceObjId().c_str()); + snprintf(header.dest_obj_id, sizeof(header.dest_obj_id), "%s", + task->GetDestObjId().c_str()); + + header.type = MessageType::kRespondToGetObj; + header.obj_size = size; + int sockfd = conn.fd(); + + // Send the response header. + if (!SendAll(sockfd, &header, kHeaderSize).ok()) { + LOG(WARNING) + << "ExecuteRespondToGetTask: Failed to send header on attempt " + << attempt; + conn.SetUnusable(); + if (attempt < kMaxRetries - 1) continue; + ReportResult(task->GetTaskId(), false, "Failed to send header"); + return; + } + metric_container.header_sent_time = absl::Now(); + + // Stream the data back to the requester. + if (!SendAll(sockfd, buffer_data_ptr, size).ok()) { + LOG(WARNING) << "ExecuteRespondToGetTask: Failed to send data on attempt " + << attempt; + conn.SetUnusable(); + if (attempt < kMaxRetries - 1) continue; + ReportResult(task->GetTaskId(), false, "Failed to send data"); + return; + } + metric_container.data_sent_time = absl::Now(); + + // Wait for the final ACK to confirm the data was successfully received. 
+ ObjInfoHeader ack_header; + if (!RecvHeader(sockfd, ack_header).ok()) { + LOG(WARNING) + << "ExecuteRespondToGetTask: Failed to receive ACK on attempt " + << attempt; + conn.SetUnusable(); + if (attempt < kMaxRetries - 1) continue; + ReportResult(task->GetTaskId(), false, "Failed to receive ACK"); + return; + } - if (ack_header.type != MessageType::kAck) { - LOG(ERROR) << "Failed to receive ACK for RespondToGetObj"; - ReportResult(task->GetTaskId(), false, "Received unexpected ACK"); - return; + switch (ack_header.type) { + case MessageType::kAck: + metric_container.finish_time = absl::Now(); + ReportResult(task->GetTaskId(), true, + "RespondToGetTask completed successfully"); + return; + case MessageType::kError: + ReportResult(task->GetTaskId(), false, + "Received error from destination"); + return; + case MessageType::kPutObj: + case MessageType::kGetObj: + case MessageType::kRespondToGetObj: + ReportResult(task->GetTaskId(), false, + "Received unexpected message type"); + return; + } } - metric_container.finish_time = absl::Now(); - ReportResult(task->GetTaskId(), true, - "RespondToGetTask completed successfully"); } std::optional TransferService::GetConnectionFromPool( diff --git a/tests/replication/transfer_service/connection_pool_test.cpp b/tests/replication/transfer_service/connection_pool_test.cpp index 3227ecb..22bc750 100644 --- a/tests/replication/transfer_service/connection_pool_test.cpp +++ b/tests/replication/transfer_service/connection_pool_test.cpp @@ -47,7 +47,8 @@ class ConnectionPoolTest : public ::testing::Test { if (fd == -1) { break; } - close(fd); + std::lock_guard lock(accepted_fds_mutex_); + accepted_fds_.push_back(fd); } }); } @@ -56,33 +57,139 @@ class ConnectionPoolTest : public ::testing::Test { shutdown(listen_fd_, SHUT_RDWR); close(listen_fd_); accept_thread_.join(); + std::lock_guard lock(accepted_fds_mutex_); + for (int fd : accepted_fds_) { + close(fd); + } + accepted_fds_.clear(); } int port_ = 0; int listen_fd_ = -1; 
std::thread accept_thread_; + std::vector accepted_fds_; + std::mutex accepted_fds_mutex_; }; +TEST_F(ConnectionPoolTest, DiscardDeadConnection) { + // Given + ConnectionPool pool("127.0.0.1", port_, 1); + EXPECT_TRUE(pool.Initialize()); + + // Wait for the connection to be accepted by the server. + std::this_thread::sleep_for(std::chrono::milliseconds(100)); + + // When + // Close the connection on the server side to make it stale in the pool. + { + std::lock_guard lock(accepted_fds_mutex_); + ASSERT_EQ(accepted_fds_.size(), 1); + close(accepted_fds_[0]); + accepted_fds_.clear(); + } + + // Then + // GetConnection should detect it's dead, discard it, and create a new one. + auto conn = pool.GetConnection(); + EXPECT_TRUE(conn.has_value()); + EXPECT_TRUE(conn->IsValid()); + + // Verify that a new connection was indeed established. + std::this_thread::sleep_for(std::chrono::milliseconds(100)); + std::lock_guard lock(accepted_fds_mutex_); + EXPECT_EQ(accepted_fds_.size(), 1); +} + +TEST_F(ConnectionPoolTest, SetUnusablePreventsReuse) { + // Given + ConnectionPool pool("127.0.0.1", port_, 1); + EXPECT_TRUE(pool.Initialize()); + + // When + { + auto conn = pool.GetConnection(); + EXPECT_TRUE(conn.has_value()); + conn->SetUnusable(); + // Destructor will close fd and not return it to pool. + } + + // Then + // Next GetConnection must result in a new connection being created. + auto conn2 = pool.GetConnection(); + EXPECT_TRUE(conn2.has_value()); + + std::this_thread::sleep_for(std::chrono::milliseconds(100)); + std::lock_guard lock(accepted_fds_mutex_); + // We should have 2 fds in accepted_fds_: the first one (now closed) and the new one. + EXPECT_EQ(accepted_fds_.size(), 2); +} + +TEST_F(ConnectionPoolTest, VerifyLIFOBehavior) { + // Given + // Use a pool of size 2 to verify stack behavior. + ConnectionPool pool("127.0.0.1", port_, 2); + EXPECT_TRUE(pool.Initialize()); + + // Acquire both connections. 
+ auto connA = pool.GetConnection(); + auto connB = pool.GetConnection(); + ASSERT_TRUE(connA.has_value()); + ASSERT_TRUE(connB.has_value()); + + int fdA = connA->fd(); + int fdB = connB->fd(); + + // When + // Return them in a specific order: first A, then B. + // In a Stack (LIFO), the last one returned (B) should be the first one + // retrieved next. + connA.reset(); // Release A + connB.reset(); // Release B + + // Then + auto connNext1 = pool.GetConnection(); + ASSERT_TRUE(connNext1.has_value()); + // Should be B because it was the most recently returned. + EXPECT_EQ(connNext1->fd(), fdB); + + auto connNext2 = pool.GetConnection(); + ASSERT_TRUE(connNext2.has_value()); + // Should be A. + EXPECT_EQ(connNext2->fd(), fdA); +} + TEST_F(ConnectionPoolTest, Initialize) { + // Given ConnectionPool pool("127.0.0.1", port_, 5); + + // When/Then EXPECT_TRUE(pool.Initialize()); } TEST_F(ConnectionPoolTest, InitializeFailure) { + // Given // Use a port that is very unlikely to have a listener. ConnectionPool pool("127.0.0.1", port_ + 1, 1); + + // When/Then EXPECT_FALSE(pool.Initialize()); } TEST_F(ConnectionPoolTest, GetConnection) { + // Given ConnectionPool pool("127.0.0.1", port_, 1); EXPECT_TRUE(pool.Initialize()); + + // When auto conn = pool.GetConnection(); + + // Then EXPECT_TRUE(conn.has_value()); EXPECT_TRUE(conn->IsValid()); } TEST_F(ConnectionPoolTest, MultipleThreads) { + // Given const int pool_size = 3; const int num_threads = 10; const int iterations_per_thread = 5; @@ -90,6 +197,7 @@ TEST_F(ConnectionPoolTest, MultipleThreads) { ConnectionPool pool("127.0.0.1", port_, pool_size); EXPECT_TRUE(pool.Initialize()); + // When std::vector threads; for (int i = 0; i < num_threads; ++i) { threads.emplace_back([&pool, iterations_per_thread]() { @@ -109,6 +217,7 @@ TEST_F(ConnectionPoolTest, MultipleThreads) { thread.join(); } + // Then // After all threads are done, the pool should have all connections back. // We can try to get pool_size connections to verify. 
std::vector<ScopedConnection> conns;
// A ScopedConnection is created with an invalid fd and no pool, // and Release() is called (implicitly by destructor or explicitly) { ScopedConnection conn(invalid_fd, nullptr); EXPECT_FALSE(conn.IsValid()); EXPECT_EQ(conn.fd(), invalid_fd); + + // Then EXPECT_NO_THROW( conn.Release()); // Should be safe to call Release on invalid state } // Destructor calls Release() again, should also be safe