diff --git a/xllm/core/framework/parallel_state/ilu_process_group.h b/xllm/core/framework/parallel_state/ilu_process_group.h
index ed51e3db6..eed3a682e 100644
--- a/xllm/core/framework/parallel_state/ilu_process_group.h
+++ b/xllm/core/framework/parallel_state/ilu_process_group.h
@@ -50,6 +50,44 @@ class ProcessGroupImpl : public ProcessGroup {
     pg_ = std::make_unique<c10d::ProcessGroupNCCL>(
         store, rank, rank_size, pg_options);
   }
+
+  // Constructs the process group from an explicit list of member ranks.
+  // The base class is initialized with (global_rank, world_size, device),
+  // while the backend group itself is created with (local_rank, rank_size).
+  // NOTE(review): group_ranks presumably lists the global ranks that belong
+  // to this group - confirm against callers.
+  ProcessGroupImpl(int32_t global_rank,
+                   int32_t local_rank,
+                   const std::vector<int32_t>& group_ranks,
+                   int32_t world_size,
+                   int32_t rank_size,
+                   int32_t port,
+                   const std::string& host,
+                   const std::string& group_name,
+                   const torch::Device& device)
+      : ProcessGroup(global_rank, world_size, device) {
+    c10::intrusive_ptr<c10d::ProcessGroupNCCL::Options> pg_options =
+        c10d::ProcessGroupNCCL::Options::create();
+    // The group_name option is only assigned when building with torch >= 2.7.
+#if TORCH_VERSION_MAJOR > 2 || \
+    (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR >= 7)
+    pg_options->group_name = group_name;
+#endif
+    // Only forward the member ranks when the group does not span the whole
+    // world; the ranks are widened to uint64_t before being handed over.
+    if (world_size != rank_size) {
+      std::vector<uint64_t> uint64_ranks;
+      uint64_ranks.reserve(group_ranks.size());
+      for (const int32_t rank : group_ranks) {
+        uint64_ranks.emplace_back(static_cast<uint64_t>(rank));
+      }
+      pg_options->global_ranks_in_group = std::move(uint64_ranks);
+    }
+
+    auto store = create_tcp_store(host, port, local_rank);
+    pg_ = std::make_unique<c10d::ProcessGroupNCCL>(
+        store, local_rank, rank_size, pg_options);
+  }
 };
 
 }  // namespace xllm