92 changes: 75 additions & 17 deletions paddle/fluid/distributed/collective/ProcessGroup.h
@@ -125,6 +125,16 @@ class ProcessGroup {
"ProcessGroup%s does not support broadcast", GetBackendName()));
}

virtual std::shared_ptr<ProcessGroup::Task> Broadcast(
std::vector<phi::DenseTensor>& /* input tensors */, // NOLINT
std::vector<phi::DenseTensor>& /* output tensors */, // NOLINT
const BroadcastOptions&,
bool) {
PADDLE_THROW(platform::errors::InvalidArgument(
"ProcessGroup%s does not support broadcast with sync_op flag",
GetBackendName()));
}

virtual std::shared_ptr<ProcessGroup::Task> Barrier(
const BarrierOptions& = BarrierOptions()) {
PADDLE_THROW(platform::errors::InvalidArgument(
@@ -160,14 +170,14 @@ class ProcessGroup {
virtual std::shared_ptr<ProcessGroup::Task> Send_Partial(
phi::DenseTensor&, // NOLINT
int,
- int,
- int) {
+ int64_t,
+ int64_t) {
PADDLE_THROW(platform::errors::InvalidArgument(
"ProcessGroup%s does not support send_partial", GetBackendName()));
}

virtual std::shared_ptr<ProcessGroup::Task> Send_Partial(
- phi::DenseTensor&, int, int, int, bool) { // NOLINT
+ phi::DenseTensor&, int, int64_t, int64_t, bool) { // NOLINT
PADDLE_THROW(platform::errors::InvalidArgument(
"ProcessGroup%s does not support send_partial with sync_op flag",
GetBackendName()));
@@ -176,14 +186,14 @@ class ProcessGroup {
virtual std::shared_ptr<ProcessGroup::Task> Recv_Partial(
phi::DenseTensor&, // NOLINT
int,
- int,
- int) {
+ int64_t,
+ int64_t) {
PADDLE_THROW(platform::errors::InvalidArgument(
"ProcessGroup%s does not support recv_partial", GetBackendName()));
}

virtual std::shared_ptr<ProcessGroup::Task> Recv_Partial(
- phi::DenseTensor&, int, int, int, bool) { // NOLINT
+ phi::DenseTensor&, int, int64_t, int64_t, bool) { // NOLINT
PADDLE_THROW(platform::errors::InvalidArgument(
"ProcessGroup%s does not support recv_partial with sync_op flag",
GetBackendName()));
@@ -208,18 +218,18 @@ class ProcessGroup {
virtual std::shared_ptr<ProcessGroup::Task> AllGather_Partial(
std::vector<phi::DenseTensor>& in_tensors, // NOLINT
std::vector<phi::DenseTensor>& out_tensors, // NOLINT
- int offset,
- int length) { // NOLINT
+ int64_t offset,
+ int64_t length) {
PADDLE_THROW(platform::errors::InvalidArgument(
"ProcessGroup%s does not support AllGather_Partial", GetBackendName()));
}

virtual std::shared_ptr<ProcessGroup::Task> AllGather_Partial(
std::vector<phi::DenseTensor>& in_tensors, // NOLINT
std::vector<phi::DenseTensor>& out_tensors, // NOLINT
- int offset,
- int length,
- bool) { // NOLINT
+ int64_t offset,
+ int64_t length,
+ bool) {
PADDLE_THROW(platform::errors::InvalidArgument(
"ProcessGroup%s does not support AllGather_Partial", GetBackendName()));
}
@@ -231,6 +241,14 @@ class ProcessGroup {
"ProcessGroup%s does not support AllToAll", GetBackendName()));
}

virtual std::shared_ptr<ProcessGroup::Task> AllToAll(
std::vector<phi::DenseTensor>&, // NOLINT
std::vector<phi::DenseTensor>&, // NOLINT
bool) {
PADDLE_THROW(platform::errors::InvalidArgument(
"ProcessGroup%s does not support alltoall", GetBackendName()));
}

virtual std::shared_ptr<ProcessGroup::Task> AllToAll_Single(
std::vector<phi::DenseTensor>&, // NOLINT
std::vector<phi::DenseTensor>&, // NOLINT
@@ -240,26 +258,66 @@ class ProcessGroup {
"ProcessGroup%s does not support AllToAll_Single", GetBackendName()));
}

virtual std::shared_ptr<ProcessGroup::Task> AllToAllSingle(
std::vector<phi::DenseTensor>&, // NOLINT
std::vector<phi::DenseTensor>&, // NOLINT
std::vector<int64_t>&,
std::vector<int64_t>&,
bool) {
PADDLE_THROW(platform::errors::InvalidArgument(
"ProcessGroup%s does not support alltoall_single", GetBackendName()));
}

virtual std::shared_ptr<ProcessGroup::Task> Reduce(
std::vector<phi::DenseTensor>&, // NOLINT
std::vector<phi::DenseTensor>&, // NOLINT
const ReduceOptions& opts) {
PADDLE_THROW(platform::errors::InvalidArgument(
"ProcessGroup%s does not support Reduce", GetBackendName()));
"ProcessGroup%s does not support reduce", GetBackendName()));
}

virtual std::shared_ptr<ProcessGroup::Task> Reduce(
std::vector<phi::DenseTensor>& /* input tensors */, // NOLINT
std::vector<phi::DenseTensor>& /* output tensors */, // NOLINT
const ReduceOptions&,
bool) {
PADDLE_THROW(platform::errors::InvalidArgument(
"ProcessGroup%s does not support reduce with sync_op flag",
GetBackendName()));
}

virtual std::shared_ptr<ProcessGroup::Task> Scatter(
std::vector<phi::DenseTensor>&, // NOLINT
std::vector<phi::DenseTensor>&, // NOLINT
- const ScatterOptions&) { // NOLINT
+ const ScatterOptions&) {
PADDLE_THROW(platform::errors::InvalidArgument(
"ProcessGroup%s does not support Scatter", GetBackendName()));
"ProcessGroup%s does not support scatter", GetBackendName()));
}

virtual std::shared_ptr<ProcessGroup::Task> Scatter(
std::vector<phi::DenseTensor>&, // NOLINT
std::vector<phi::DenseTensor>&, // NOLINT
const ScatterOptions&,
bool) {
PADDLE_THROW(platform::errors::InvalidArgument(
"ProcessGroup%s does not support scatter with sync_op flag",
GetBackendName()));
}

virtual std::shared_ptr<ProcessGroup::Task> ReduceScatter(
std::vector<phi::DenseTensor>&, // NOLINT
std::vector<phi::DenseTensor>&, // NOLINT
const ReduceScatterOptions&,
bool) {
PADDLE_THROW(platform::errors::InvalidArgument(
"ProcessGroup%s does not support reduce_scatter with sync_op flag",
GetBackendName()));
}

virtual std::shared_ptr<ProcessGroup::Task> _ReduceScatterBase(
- phi::DenseTensor&, // NOLINT
- phi::DenseTensor&, // NOLINT
- const ReduceScatterOptions&) { // NOLINT
+ phi::DenseTensor&, // NOLINT
+ phi::DenseTensor&, // NOLINT
+ const ReduceScatterOptions&) {
PADDLE_THROW(platform::errors::InvalidArgument(
"ProcessGroup%s does not support ReduceScatter", GetBackendName()));
}
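Note on the new overloads above: they all follow one pattern, the existing tensor arguments plus a trailing sync_op flag, with the base class throwing until a concrete backend overrides them. The following is a minimal sketch of such an override, assuming a hypothetical FakeProcessGroup backend; it is illustrative only and is not the implementation added by this PR.

// Hypothetical backend sketch: overrides the new sync_op-aware Broadcast.
// Other required overrides (e.g. GetBackendName) are omitted for brevity,
// so this class remains abstract as written.
class FakeProcessGroup : public ProcessGroup {
 public:
  using ProcessGroup::ProcessGroup;

  std::shared_ptr<ProcessGroup::Task> Broadcast(
      std::vector<phi::DenseTensor>& in_tensors,   // NOLINT
      std::vector<phi::DenseTensor>& out_tensors,  // NOLINT
      const BroadcastOptions& opts,
      bool sync_op) override {
    // A real backend would launch the broadcast on its communication stream
    // and, when sync_op is true, make the returned task block until the
    // collective finishes.
    return nullptr;  // placeholder
  }
};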
4 changes: 2 additions & 2 deletions paddle/fluid/distributed/collective/ProcessGroupCustom.cc
@@ -267,8 +267,8 @@ void* XcclGetPointerByOffset(void* raw_pointer,
std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::AllGather_Partial(
std::vector<phi::DenseTensor>& in_tensors,
std::vector<phi::DenseTensor>& out_tensors,
- int offset,
- int length) {
+ int64_t offset,
+ int64_t length) {
PADDLE_ENFORCE_EQ(
CheckTensorsInCustomPlace(in_tensors, device_type_),
true,
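The widening above is the caller-facing point of this change: the partial collectives address a slice of a tensor by element offset and length, and those values derive from numel(), which is already 64-bit. A short sketch of how a caller might compute them follows; pg, tensor, nranks, rank_id, and dst_rank are hypothetical identifiers, not names from this diff.

// Hypothetical caller: compute a rank-local slice in 64-bit so the arithmetic
// cannot overflow for tensors with more than ~2^31 elements.
int64_t numel = tensor.numel();     // phi::DenseTensor::numel() returns int64_t
int64_t length = numel / nranks;    // elements owned by each rank
int64_t offset = length * rank_id;  // start of this rank's slice
auto task = pg->Send_Partial(tensor, dst_rank, offset, length);
task->Wait();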
8 changes: 4 additions & 4 deletions paddle/fluid/distributed/collective/ProcessGroupCustom.h
@@ -80,8 +80,8 @@ class ProcessGroupCustom : public ProcessGroup {
std::shared_ptr<ProcessGroup::Task> AllGather_Partial(
std::vector<phi::DenseTensor>& in_tensors,
std::vector<phi::DenseTensor>& out_tensors,
- int offset,
- int length) override;
+ int64_t offset,
+ int64_t length) override;

std::shared_ptr<ProcessGroup::Task> AllReduce(
std::vector<phi::DenseTensor>& in_tensors,
@@ -117,8 +117,8 @@ class ProcessGroupCustom : public ProcessGroup {
std::set<int> used_place_ids_;

private:
- void BcastCustomId(std::vector<phi::ccl::CCLRootId>& ccl_ids,
- int root, // NOLINT
+ void BcastCustomId(std::vector<phi::ccl::CCLRootId>& ccl_ids, // NOLINT
+ int root,
int server_fd);

void BroadcastUniqueCustomID(
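For completeness, a usage sketch of the sync_op-aware AllToAllSingle declared in ProcessGroup.h above. The identifiers pg, in_tensor, out_tensor, and nranks are hypothetical, and the base-class version only throws, so this assumes a backend that actually overrides it.

// Hypothetical caller: even all-to-all exchange with explicit split sizes and
// synchronous semantics requested via sync_op.
std::vector<phi::DenseTensor> in = {in_tensor}, out = {out_tensor};
std::vector<int64_t> in_sizes(nranks, in_tensor.dims()[0] / nranks);
std::vector<int64_t> out_sizes(nranks, out_tensor.dims()[0] / nranks);
auto task = pg->AllToAllSingle(in, out, in_sizes, out_sizes, /*sync_op=*/true);
task->Wait();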