From b6d00f897cd8f52165596debbcf274c21f4aa3b4 Mon Sep 17 00:00:00 2001 From: Seema Mirchandaney Date: Thu, 9 Apr 2026 15:49:29 -0700 Subject: [PATCH 01/14] Add support for replicate op in distributed training - Add perform_pass_expansion_for_replicate for fwd/bwd pass expansion - Add perform_shard_expansion_for_replicate and _bwd for shard expansion - Add build_replicate_invocation in make_dynamic_open_dataflow_graph - Add is_replicate_attrs helper and guard replicate in copy_insertion - Add ReplicateAttrs to TrainingOperationAttrs - Add SumReductionFloat/Double for backward replicate reduce operation - Add issue_replicate_bwd in spawn_dynamic_node_invocation - Fix per_device_op_state init race condition with direct write - Fix .value() calls on optional per_device_op_state across op impls - Update issue_copy to support optional reduction op - Add testcase for replicate op --- .../include/realm-execution/sum_reduction.h | 99 ++++ .../realm-execution/tasks/realm_reduction.h | 96 ++++ .../src/realm-execution/test_op_replicate.cc | 450 ++++++++++++++++++ 3 files changed, 645 insertions(+) create mode 100644 lib/realm-execution/include/realm-execution/sum_reduction.h create mode 100644 lib/realm-execution/include/realm-execution/tasks/realm_reduction.h create mode 100644 lib/realm-execution/test/src/realm-execution/test_op_replicate.cc diff --git a/lib/realm-execution/include/realm-execution/sum_reduction.h b/lib/realm-execution/include/realm-execution/sum_reduction.h new file mode 100644 index 0000000000..b845b5b7f2 --- /dev/null +++ b/lib/realm-execution/include/realm-execution/sum_reduction.h @@ -0,0 +1,99 @@ +#pragma once +#include +#include "op-attrs/datatype.dtg.h" + +namespace FlexFlow { + +// Sum reduction for float +struct SumReductionFloat { + using LHS = float; + using RHS = float; + static const RHS identity; + + template + static void apply(LHS &lhs, RHS rhs) { + if (EXCLUSIVE) { + lhs += rhs; + } else { + // atomic add for non-exclusive + 
__sync_fetch_and_add((int*)&lhs, *(int*)&rhs); + // proper float atomic add — use union trick + union { float f; int i; } old_val, new_val; + do { + old_val.f = lhs; + new_val.f = old_val.f + rhs; + } while (!__sync_bool_compare_and_swap( + (int*)&lhs, old_val.i, new_val.i)); + } + } + + template + static void fold(RHS &rhs1, RHS rhs2) { + if (EXCLUSIVE) { + rhs1 += rhs2; + } else { + union { float f; int i; } old_val, new_val; + do { + old_val.f = rhs1; + new_val.f = old_val.f + rhs2; + } while (!__sync_bool_compare_and_swap( + (int*)&rhs1, old_val.i, new_val.i)); + } + } +}; + +const SumReductionFloat::RHS SumReductionFloat::identity = 0.0f; + +// Sum reduction for double +struct SumReductionDouble { + using LHS = double; + using RHS = double; + static const RHS identity; + + template + static void apply(LHS &lhs, RHS rhs) { + if (EXCLUSIVE) { + lhs += rhs; + } else { + union { double d; long long i; } old_val, new_val; + do { + old_val.d = lhs; + new_val.d = old_val.d + rhs; + } while (!__sync_bool_compare_and_swap( + (long long*)&lhs, old_val.i, new_val.i)); + } + } + + template + static void fold(RHS &rhs1, RHS rhs2) { + if (EXCLUSIVE) { + rhs1 += rhs2; + } else { + union { double d; long long i; } old_val, new_val; + do { + old_val.d = rhs1; + new_val.d = old_val.d + rhs2; + } while (!__sync_bool_compare_and_swap( + (long long*)&rhs1, old_val.i, new_val.i)); + } + } +}; + +const SumReductionDouble::RHS SumReductionDouble::identity = 0.0; + +// Reduction op IDs — must not conflict with other registered redops +enum SumReductionOpIDs { + REDOP_SUM_FLOAT = 1, + REDOP_SUM_DOUBLE = 2, +}; + +inline Realm::ReductionOpID get_sum_reduction_op_id(DataType dtype) { + switch (dtype) { + case DataType::FLOAT: return REDOP_SUM_FLOAT; + case DataType::DOUBLE: return REDOP_SUM_DOUBLE; + default: + PANIC("no sum reduction registered for datatype {}", dtype); + } +} + +} // namespace FlexFlow diff --git a/lib/realm-execution/include/realm-execution/tasks/realm_reduction.h 
b/lib/realm-execution/include/realm-execution/tasks/realm_reduction.h new file mode 100644 index 0000000000..d1d6e1d880 --- /dev/null +++ b/lib/realm-execution/include/realm-execution/tasks/realm_reduction.h @@ -0,0 +1,96 @@ +#pragma once +#include +#include "op-attrs/datatype.dtg.h" + +namespace FlexFlow { + +// Sum reduction for float +struct SumReductionFloat { + using LHS = float; + using RHS = float; + static constexpr RHS identity = 0.0f; // ← inside struct, constexpr + + template + static void apply(LHS &lhs, RHS rhs) { + if (EXCLUSIVE) { + lhs += rhs; + } else { + // atomic add for non-exclusive + __sync_fetch_and_add((int*)&lhs, *(int*)&rhs); + // proper float atomic add — use union trick + union { float f; int i; } old_val, new_val; + do { + old_val.f = lhs; + new_val.f = old_val.f + rhs; + } while (!__sync_bool_compare_and_swap( + (int*)&lhs, old_val.i, new_val.i)); + } + } + + template + static void fold(RHS &rhs1, RHS rhs2) { + if (EXCLUSIVE) { + rhs1 += rhs2; + } else { + union { float f; int i; } old_val, new_val; + do { + old_val.f = rhs1; + new_val.f = old_val.f + rhs2; + } while (!__sync_bool_compare_and_swap( + (int*)&rhs1, old_val.i, new_val.i)); + } + } +}; + + +// Sum reduction for double +struct SumReductionDouble { + using LHS = double; + using RHS = double; + static constexpr RHS identity = 0.0; // ← inside struct, constexpr + + template + static void apply(LHS &lhs, RHS rhs) { + if (EXCLUSIVE) { + lhs += rhs; + } else { + union { double d; long long i; } old_val, new_val; + do { + old_val.d = lhs; + new_val.d = old_val.d + rhs; + } while (!__sync_bool_compare_and_swap( + (long long*)&lhs, old_val.i, new_val.i)); + } + } + + template + static void fold(RHS &rhs1, RHS rhs2) { + if (EXCLUSIVE) { + rhs1 += rhs2; + } else { + union { double d; long long i; } old_val, new_val; + do { + old_val.d = rhs1; + new_val.d = old_val.d + rhs2; + } while (!__sync_bool_compare_and_swap( + (long long*)&rhs1, old_val.i, new_val.i)); + } + } +}; + +// 
Reduction op IDs — must not conflict with other registered redops +enum SumReductionOpIDs { + REDOP_SUM_FLOAT = 1, + REDOP_SUM_DOUBLE = 2, +}; + +inline Realm::ReductionOpID get_sum_reduction_op_id(DataType dtype) { + switch (dtype) { + case DataType::FLOAT: return REDOP_SUM_FLOAT; + case DataType::DOUBLE: return REDOP_SUM_DOUBLE; + default: + PANIC("no sum reduction registered for datatype {}", dtype); + } +} + +} // namespace FlexFlow diff --git a/lib/realm-execution/test/src/realm-execution/test_op_replicate.cc b/lib/realm-execution/test/src/realm-execution/test_op_replicate.cc new file mode 100644 index 0000000000..d1fc941007 --- /dev/null +++ b/lib/realm-execution/test/src/realm-execution/test_op_replicate.cc @@ -0,0 +1,450 @@ +#include "internal/realm_test_utils.h" +#include "kernels/allocation.h" +#include "kernels/compare_tensor_accessors.h" +#include "kernels/copy_tensor_accessor.h" +#include "kernels/format_accessor_contents.h" +#include "kernels/tensor_accessor_reductions.h" +#include "op-attrs/operator_task_space_to_operator_task_space_mapping.h" +#include "op-attrs/ops/element_unary.h" +#include "op-attrs/ops/linear.h" +#include "op-attrs/ops/replicate.h" +#include "op-attrs/parallel_tensor_shape.h" +#include "op-attrs/tensor_shape.dtg.h" +#include "op-attrs/tensor_slot_name.dtg.h" +#include "pcg/device_type.dtg.h" +#include "pcg/machine_space_coordinate.dtg.h" +#include "pcg/mapped_parallel_computation_graph/operator_atomic_task_shard_binding.dtg.h" +#include "pcg/parallel_computation_graph/parallel_computation_graph.h" +#include "pcg/parallel_computation_graph/parallel_computation_graph_builder.h" +#include "pcg/parallel_computation_graph/parallel_layer_guid_t.dtg.h" +#include "pcg/parallel_computation_graph/parallel_tensor_guid_t.dtg.h" +#include "realm-execution/distributed_ff_handle.h" +#include "realm-execution/dynamic_tensor_accessor_from_instance.h" +#include "realm-execution/pcg_instance.h" +#include "realm-execution/realm_context.h" +#include 
"realm-execution/realm_manager.h" +#include "task-spec/permissions.h" +#include "test/utils/doctest/check_kv.h" +#include "utils/containers/require_only_key.h" +#include + +namespace test { + +using namespace ::FlexFlow; +namespace Realm = ::FlexFlow::Realm; + +template +static ParallelLayerAttrs make_layer_attrs(T const &op_attrs) { + return ParallelLayerAttrs{ + /*op_attrs=*/PCGOperatorAttrs{op_attrs}, + /*name=*/std::nullopt, + }; +}; + +static bool did_loss_decrease(GenericTensorAccessorR const &first_epoch, + GenericTensorAccessorR const &last_epoch, + Allocator &allocator) { + return tensor_accessor_all( + compare_tensor_accessors_le(last_epoch, first_epoch, allocator)); +} + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("RealmBackend e2e Training Replicate Op (CPU Model Parallelism)") { + std::vector fake_args = + make_fake_realm_args(/*num_cpus=*/2_p, /*num_gpus=*/0_n); + int fake_argc = fake_args.size(); + char **fake_argv = fake_args.data(); + + RealmManager manager = RealmManager{&fake_argc, &fake_argv}; + ControllerTaskResult result = manager.start_controller([](RealmContext + &ctx) { + Allocator allocator = ctx.get_current_device_allocator(); + + positive_int batch_size = 10_p; + positive_int data_dim = 16_p; + positive_int hidden_dim = 32_p; + positive_int output_dim = 1_p; + + // 10,2 + TensorShape output_tensor_shape = TensorShape{ + TensorDims{FFOrdered{batch_size, output_dim}}, DataType::FLOAT}; + + // 10,2 + TensorShape label_tensor_shape = TensorShape{ + TensorDims{FFOrdered{batch_size, output_dim}}, DataType::FLOAT}; + + GenericTensorAccessorW label_tensor = + allocator.allocate_tensor(label_tensor_shape); + + // construct computation graph + ParallelComputationGraph pcg = empty_parallel_computation_graph(); + + // input tensor + // 10, 16 + TensorShape input_tensor_shape = TensorShape{ + TensorDims{FFOrdered{batch_size, data_dim}}, DataType::FLOAT}; + + // parallel layer -> input tensor + ParallelLayerAddedResult inputs_layer = + 
pcg_add_input_layer(pcg, input_tensor_shape); + parallel_tensor_guid_t t_input = + require_only_key(inputs_layer.outputs, TensorSlotName::OUTPUT); + + // parallel layer -> input tensor 2 + ParallelLayerAddedResult inputs_layer_2 = + pcg_add_input_layer(pcg, input_tensor_shape); + parallel_tensor_guid_t t_input_2 = + require_only_key(inputs_layer_2.outputs, TensorSlotName::OUTPUT); + + // binary ADD attribute + ElementBinaryAttrs add_attrs = ElementBinaryAttrs{ + OperatorType::EW_ADD, + DataType::FLOAT, + false, + false, + }; + + // parallel layer -> perform add + ParallelLayerAddedResult add_operator_1 = + add_parallel_layer(pcg, make_layer_attrs(add_attrs), + { + { + TensorSlotName::LHS_INPUT, + t_input, + }, + { + TensorSlotName::RHS_INPUT, + t_input_2, + }, + }, + {/* weight */}); + + parallel_tensor_guid_t t_add_1 = + require_only_key(add_operator_1.outputs, TensorSlotName::OUTPUT); + + // parallel layer -> perform replicate + const positive_int replicate_degree = 2_p; + ReplicateAttrs repl_attrs = ReplicateAttrs(replicate_degree); + ParallelLayerAddedResult repl_operator_1 = + add_parallel_layer(pcg, make_layer_attrs(repl_attrs), + { + { + TensorSlotName::INPUT, + t_add_1, + }, + }, + /*weight=*/{}); + // output of replicate layer + parallel_tensor_guid_t t_repl_1 = + require_only_key(repl_operator_1.outputs, TensorSlotName::OUTPUT); + + // parallel layer -> perform RelU + ParallelLayerAddedResult relu_operator_1 = + add_parallel_layer(pcg, make_layer_attrs(make_relu_attrs()), + /*inputs=*/ + { + { + TensorSlotName::INPUT, + t_repl_1, + }, + }, + /*weights=*/{}); + // output of relu layer + parallel_tensor_guid_t t_relu_1 = + require_only_key(relu_operator_1.outputs, TensorSlotName::OUTPUT); + + // machine + MachineSpaceCoordinate cpu0{0_n, 0_n, DeviceType::CPU}; + MachineSpaceCoordinate cpu1{0_n, 1_n, DeviceType::CPU}; + + ParallelTensorSpaceCoordinate tensor_coord0{ + /* sum_component */ 0_n, /* discard_copy_component */ 0_n, + /*shard_component*/ 
FFOrdered{0_n}}; + ParallelTensorSpaceCoordinate tensor_coord1{ + /* sum_component */ 0_n, /* discard_copy_component */ 1_n, + /*shard_component*/ FFOrdered{0_n}}; + MappedParallelComputationGraph mpcg{ + pcg, + {{inputs_layer.parallel_layer, + MappedOperatorTaskGroup{ + {{cpu0, OperatorAtomicTaskShardBinding{{{TensorSlotName::OUTPUT, + tensor_coord0}}}}}}}, + {inputs_layer_2.parallel_layer, + MappedOperatorTaskGroup{ + {{cpu0, OperatorAtomicTaskShardBinding{{{TensorSlotName::OUTPUT, + tensor_coord0}}}}}}}, + {add_operator_1.parallel_layer, + MappedOperatorTaskGroup{ + {{cpu0, OperatorAtomicTaskShardBinding{{ + {TensorSlotName::LHS_INPUT, tensor_coord0}, + {TensorSlotName::RHS_INPUT, tensor_coord0}, + {TensorSlotName::OUTPUT, tensor_coord0}, + }}}}}}, + {repl_operator_1.parallel_layer, + MappedOperatorTaskGroup{{ + {cpu0, OperatorAtomicTaskShardBinding{{ + {TensorSlotName::OUTPUT, tensor_coord0}, + }}}, + {cpu1, OperatorAtomicTaskShardBinding{{ + {TensorSlotName::OUTPUT, tensor_coord1}, + }}}, + }}}, + {relu_operator_1.parallel_layer, + MappedOperatorTaskGroup{{ + {cpu0, OperatorAtomicTaskShardBinding{{ + {TensorSlotName::INPUT, tensor_coord0}, + {TensorSlotName::OUTPUT, tensor_coord0}, + }}}, + {cpu1, OperatorAtomicTaskShardBinding{{ + {TensorSlotName::INPUT, tensor_coord1}, + {TensorSlotName::OUTPUT, tensor_coord1}, + }}}, + }}}}, + }; + + MappedOperatorTaskGroup loss_mapping{ + {{cpu0, OperatorAtomicTaskShardBinding{{ + {TensorSlotName::INPUT, tensor_coord0}, + {TensorSlotName::LOGIT, tensor_coord0}, + }}}}}; + + // instantiate computation graph + LossAttrs loss_attrs = LossAttrs{ + NonconfigurableLossAttrs{LossFunction::CATEGORICAL_CROSSENTROPY}}; + OptimizerAttrs optimizer_attrs = + OptimizerAttrs{SGDOptimizerAttrs{/*lr=*/0.001, + /*momentum=*/0.9, + /*nesterov=*/false, + /*weight_decay=*/0.001}}; + + std::unordered_map + input_tensors; + + DistributedFfHandle device_handle = + create_distributed_ff_handle(ctx, + /*workSpaceSize=*/1024 * 1024, + 
/*allowTensorOpMathConversion=*/true); + PCGInstance pcg_instance = create_pcg_instance( + /*ctx=*/ctx, + /*mpcg=*/mpcg, + /*optimizer=*/optimizer_attrs, + /*loss=*/std::nullopt, + /*input_tensors=*/input_tensors, + /*profiling_settings=*/ProfilingSettings{0, 0}, + /*device_handle=*/device_handle, + /*iteration_config=*/FFIterationConfig{1_p}); + + // begin training loop + int num_epochs = 1; + for (int i = 0; i < num_epochs; i++) { + perform_all_passes_for_pcg_instance( + /*instance=*/pcg_instance, + /*profiling_settings=*/ProfilingSettings{0, 0}, + /*device_handle=*/device_handle, + /*iteration_config=*/FFIterationConfig{1_p}); + } + }); + result.wait(); + } +} + +TEST_SUITE(FF_CUDA_TEST_SUITE) { + TEST_CASE("RealmBackend e2e Training Replicate Op (GPU Model Parallelism)") { + std::vector fake_args = + make_fake_realm_args(/*num_cpus=*/1_p, /*num_gpus=*/2_n); + int fake_argc = fake_args.size(); + char **fake_argv = fake_args.data(); + + RealmManager manager = RealmManager{&fake_argc, &fake_argv}; + + ControllerTaskResult result = + manager.start_controller([](RealmContext &ctx) { + Allocator allocator = ctx.get_current_device_allocator(); + + positive_int batch_size = 10_p; + positive_int data_dim = 16_p; + positive_int hidden_dim = 32_p; + positive_int output_dim = 1_p; + + // 10,2 + TensorShape output_tensor_shape = TensorShape{ + TensorDims{FFOrdered{batch_size, output_dim}}, DataType::FLOAT}; + + // 10,2 + TensorShape label_tensor_shape = TensorShape{ + TensorDims{FFOrdered{batch_size, output_dim}}, DataType::FLOAT}; + + GenericTensorAccessorW label_tensor = + allocator.allocate_tensor(label_tensor_shape); + + // construct computation graph + ParallelComputationGraph pcg = empty_parallel_computation_graph(); + + // input tensor + // 10, 16 + TensorShape input_tensor_shape = TensorShape{ + TensorDims{FFOrdered{batch_size, data_dim}}, DataType::FLOAT}; + + // parallel layer -> input tensor + ParallelLayerAddedResult inputs_layer = + pcg_add_input_layer(pcg, 
input_tensor_shape); + parallel_tensor_guid_t t_input = + require_only_key(inputs_layer.outputs, TensorSlotName::OUTPUT); + + // parallel layer -> input tensor 2 + ParallelLayerAddedResult inputs_layer_2 = + pcg_add_input_layer(pcg, input_tensor_shape); + parallel_tensor_guid_t t_input_2 = + require_only_key(inputs_layer_2.outputs, TensorSlotName::OUTPUT); + + // binary ADD attribute + ElementBinaryAttrs add_attrs = ElementBinaryAttrs{ + OperatorType::EW_ADD, + DataType::FLOAT, + false, + false, + }; + + // parallel layer -> perform add + ParallelLayerAddedResult add_operator_1 = + add_parallel_layer(pcg, make_layer_attrs(add_attrs), + { + { + TensorSlotName::LHS_INPUT, + t_input, + }, + { + TensorSlotName::RHS_INPUT, + t_input_2, + }, + }, + {/* weight */}); + + parallel_tensor_guid_t t_add_1 = + require_only_key(add_operator_1.outputs, TensorSlotName::OUTPUT); + + // parallel layer -> perform replicate + const positive_int replicate_degree = 2_p; + ReplicateAttrs repl_attrs = ReplicateAttrs(replicate_degree); + ParallelLayerAddedResult repl_operator_1 = + add_parallel_layer(pcg, make_layer_attrs(repl_attrs), + { + { + TensorSlotName::INPUT, + t_add_1, + }, + }, + /*weight=*/{}); + // output of replicate layer + parallel_tensor_guid_t t_repl_1 = + require_only_key(repl_operator_1.outputs, TensorSlotName::OUTPUT); + + // parallel layer -> perform RelU + ParallelLayerAddedResult relu_operator_1 = + add_parallel_layer(pcg, make_layer_attrs(make_relu_attrs()), + /*inputs=*/ + { + { + TensorSlotName::INPUT, + t_repl_1, + }, + }, + /*weights=*/{}); + // output of relu layer + parallel_tensor_guid_t t_relu_1 = + require_only_key(relu_operator_1.outputs, TensorSlotName::OUTPUT); + + // machine + MachineSpaceCoordinate gpu0{0_n, 0_n, DeviceType::GPU}; + MachineSpaceCoordinate gpu1{0_n, 1_n, DeviceType::GPU}; + ParallelTensorSpaceCoordinate tensor_coord0{0_n, 0_n, FFOrdered{0_n}}; + ParallelTensorSpaceCoordinate tensor_coord1{0_n, 1_n, FFOrdered{0_n}}; + 
MappedParallelComputationGraph mpcg{ + pcg, + { + {inputs_layer.parallel_layer, + MappedOperatorTaskGroup{ + {{gpu0, + OperatorAtomicTaskShardBinding{ + {{TensorSlotName::OUTPUT, tensor_coord0}}}}}}}, + {inputs_layer_2.parallel_layer, + MappedOperatorTaskGroup{ + {{gpu0, + OperatorAtomicTaskShardBinding{ + {{TensorSlotName::OUTPUT, tensor_coord0}}}}}}}, + {add_operator_1.parallel_layer, + MappedOperatorTaskGroup{ + {{gpu0, OperatorAtomicTaskShardBinding{{ + {TensorSlotName::LHS_INPUT, tensor_coord0}, + {TensorSlotName::RHS_INPUT, tensor_coord0}, + {TensorSlotName::OUTPUT, tensor_coord0}, + }}}}}}, + {repl_operator_1.parallel_layer, + MappedOperatorTaskGroup{{ + {gpu0, OperatorAtomicTaskShardBinding{{ + {TensorSlotName::OUTPUT, tensor_coord0}, + }}}, + {gpu1, OperatorAtomicTaskShardBinding{{ + {TensorSlotName::OUTPUT, tensor_coord1}, + }}}}}}, + {relu_operator_1.parallel_layer, + MappedOperatorTaskGroup{{ + {gpu0, OperatorAtomicTaskShardBinding{{ + {TensorSlotName::INPUT, tensor_coord0}, + {TensorSlotName::OUTPUT, tensor_coord0}, + }}}, + {gpu1, OperatorAtomicTaskShardBinding{{ + {TensorSlotName::INPUT, tensor_coord1}, + {TensorSlotName::OUTPUT, tensor_coord1}, + }}}, + }}}, + }, + }; + + MappedOperatorTaskGroup loss_mapping{ + {{gpu0, OperatorAtomicTaskShardBinding{{ + {TensorSlotName::INPUT, tensor_coord0}, + {TensorSlotName::LOGIT, tensor_coord0}, + }}}}}; + + // instantiate computation graph + LossAttrs loss_attrs = LossAttrs{ + NonconfigurableLossAttrs{LossFunction::CATEGORICAL_CROSSENTROPY}}; + OptimizerAttrs optimizer_attrs = + OptimizerAttrs{SGDOptimizerAttrs{/*lr=*/0.001, + /*momentum=*/0.9, + /*nesterov=*/false, + /*weight_decay=*/0.001}}; + + std::unordered_map + input_tensors; + + DistributedFfHandle device_handle = create_distributed_ff_handle( + ctx, + /*workSpaceSize=*/1024 * 1024, + /*allowTensorOpMathConversion=*/true); + + PCGInstance pcg_instance = create_pcg_instance( + /*ctx=*/ctx, + /*mpcg=*/mpcg, + /*optimizer=*/optimizer_attrs, + 
/*loss=*/std::nullopt, + /*input_tensors=*/input_tensors, + /*profiling_settings=*/ProfilingSettings{0, 0}, + /*device_handle=*/device_handle, + /*iteration_config=*/FFIterationConfig{1_p}); + + // begin training loop + int num_epochs = 1; + for (int i = 0; i < num_epochs; i++) { + perform_all_passes_for_pcg_instance( + /*instance=*/pcg_instance, + /*profiling_settings=*/ProfilingSettings{0, 0}, + /*device_handle=*/device_handle, + /*iteration_config=*/FFIterationConfig{1_p}); + } + }); + result.wait(); + } +} +} // namespace test From 34056217cbb4a8067e582a792fa8af726c8d712e Mon Sep 17 00:00:00 2001 From: Seema Mirchandaney Date: Thu, 9 Apr 2026 15:52:21 -0700 Subject: [PATCH 02/14] Add support for replicate op in distributed training - Add perform_pass_expansion_for_replicate for fwd/bwd pass expansion - Add perform_shard_expansion_for_replicate and _bwd for shard expansion - Add build_replicate_invocation in make_dynamic_open_dataflow_graph - Add is_replicate_attrs helper and guard replicate in copy_insertion - Add ReplicateAttrs to TrainingOperationAttrs - Add SumReductionFloat/Double for backward replicate reduce operation - Add issue_replicate_bwd in spawn_dynamic_node_invocation - Fix per_device_op_state init race condition with direct write - Fix .value() calls on optional per_device_op_state across op impls - Update issue_copy to support optional reduction op - Add testcase for replicate op --- .../src/op-attrs/ops/element_unary.cc | 1 - .../test/src/op-attrs/ops/element_unary.cc | 8 - .../include/realm-execution/realm_context.h | 19 +- .../include/realm-execution/sum_reduction.h | 99 ---- .../realm-execution/tasks/realm_reduction.h | 49 +- ...uted_per_device_op_state_initialization.cc | 6 +- .../src/realm-execution/pcg_instance.cc | 54 +++ .../src/realm-execution/realm_context.cc | 9 +- .../impl/per_device_op_state_init_task.cc | 16 +- .../tasks/realm_task_registry.cc | 10 + .../src/realm-execution/test_op_replicate.cc | 444 +++++++++--------- 
.../training_operation_attrs.dtg.toml | 4 + .../task-spec/dynamic_graph/copy_insertion.cc | 47 +- ...mic_open_dataflow_graph_from_mapped_pcg.cc | 127 +++++ .../task-spec/dynamic_graph/pass_expansion.cc | 43 ++ .../dynamic_graph/shard_expansion.cc | 125 ++++- .../src/task-spec/ops/impl/element_binary.cc | 8 +- .../src/task-spec/ops/impl/element_unary.cc | 8 +- 18 files changed, 713 insertions(+), 364 deletions(-) delete mode 100644 lib/realm-execution/include/realm-execution/sum_reduction.h diff --git a/lib/op-attrs/src/op-attrs/ops/element_unary.cc b/lib/op-attrs/src/op-attrs/ops/element_unary.cc index 9d02923689..ca7e417814 100644 --- a/lib/op-attrs/src/op-attrs/ops/element_unary.cc +++ b/lib/op-attrs/src/op-attrs/ops/element_unary.cc @@ -35,7 +35,6 @@ ParallelTensorDimDegrees get_output_parallel_dim_degrees( ElementUnaryAttrs const &attrs, ParallelTensorDimDegrees const &input_degrees) { ASSERT(input_degrees.sum_degree.value == 1); - ASSERT(input_degrees.discard_copy_degree.value == 1); return input_degrees; } diff --git a/lib/op-attrs/test/src/op-attrs/ops/element_unary.cc b/lib/op-attrs/test/src/op-attrs/ops/element_unary.cc index 672b160cbd..43b4be06d8 100644 --- a/lib/op-attrs/test/src/op-attrs/ops/element_unary.cc +++ b/lib/op-attrs/test/src/op-attrs/ops/element_unary.cc @@ -62,13 +62,5 @@ TEST_SUITE(FF_TEST_SUITE) { SumDegree{degree}, DiscardCopyDegree{1_p}, 1_p, 1_p, 1_p))); } - SUBCASE("discard copy degree > 1") { - positive_int degree = 2_p; - - CHECK_THROWS(get_output_shape( - attrs, - make_input( - SumDegree{1_p}, DiscardCopyDegree{degree}, 1_p, 1_p, 1_p))); - } } } diff --git a/lib/realm-execution/include/realm-execution/realm_context.h b/lib/realm-execution/include/realm-execution/realm_context.h index ab89e916c0..eab42d0d79 100644 --- a/lib/realm-execution/include/realm-execution/realm_context.h +++ b/lib/realm-execution/include/realm-execution/realm_context.h @@ -63,15 +63,18 @@ struct RealmContext { int priority = 0); ///\} - /** \name Data 
movement */ + /** \name Data movement and reduction */ ///\{ - Realm::Event issue_copy(ParallelTensorShape const &src_shape, - Realm::RegionInstance src_inst, - ParallelTensorShape const &dst_shape, - Realm::RegionInstance dst_inst, - Realm::ProfilingRequestSet const &requests, - Realm::Event wait_on = Realm::Event::NO_EVENT, - int priority = 0); + Realm::Event + issue_copy(ParallelTensorShape const &src_shape, + Realm::RegionInstance src_inst, + ParallelTensorShape const &dst_shape, + Realm::RegionInstance dst_inst, + Realm::ProfilingRequestSet const &requests, + Realm::Event wait_on = Realm::Event::NO_EVENT, + int priority = 0, + std::optional redop_id = std::nullopt, + bool exclusive = false); ///\} /** \name Instance management */ diff --git a/lib/realm-execution/include/realm-execution/sum_reduction.h b/lib/realm-execution/include/realm-execution/sum_reduction.h deleted file mode 100644 index b845b5b7f2..0000000000 --- a/lib/realm-execution/include/realm-execution/sum_reduction.h +++ /dev/null @@ -1,99 +0,0 @@ -#pragma once -#include -#include "op-attrs/datatype.dtg.h" - -namespace FlexFlow { - -// Sum reduction for float -struct SumReductionFloat { - using LHS = float; - using RHS = float; - static const RHS identity; - - template - static void apply(LHS &lhs, RHS rhs) { - if (EXCLUSIVE) { - lhs += rhs; - } else { - // atomic add for non-exclusive - __sync_fetch_and_add((int*)&lhs, *(int*)&rhs); - // proper float atomic add — use union trick - union { float f; int i; } old_val, new_val; - do { - old_val.f = lhs; - new_val.f = old_val.f + rhs; - } while (!__sync_bool_compare_and_swap( - (int*)&lhs, old_val.i, new_val.i)); - } - } - - template - static void fold(RHS &rhs1, RHS rhs2) { - if (EXCLUSIVE) { - rhs1 += rhs2; - } else { - union { float f; int i; } old_val, new_val; - do { - old_val.f = rhs1; - new_val.f = old_val.f + rhs2; - } while (!__sync_bool_compare_and_swap( - (int*)&rhs1, old_val.i, new_val.i)); - } - } -}; - -const SumReductionFloat::RHS 
SumReductionFloat::identity = 0.0f; - -// Sum reduction for double -struct SumReductionDouble { - using LHS = double; - using RHS = double; - static const RHS identity; - - template - static void apply(LHS &lhs, RHS rhs) { - if (EXCLUSIVE) { - lhs += rhs; - } else { - union { double d; long long i; } old_val, new_val; - do { - old_val.d = lhs; - new_val.d = old_val.d + rhs; - } while (!__sync_bool_compare_and_swap( - (long long*)&lhs, old_val.i, new_val.i)); - } - } - - template - static void fold(RHS &rhs1, RHS rhs2) { - if (EXCLUSIVE) { - rhs1 += rhs2; - } else { - union { double d; long long i; } old_val, new_val; - do { - old_val.d = rhs1; - new_val.d = old_val.d + rhs2; - } while (!__sync_bool_compare_and_swap( - (long long*)&rhs1, old_val.i, new_val.i)); - } - } -}; - -const SumReductionDouble::RHS SumReductionDouble::identity = 0.0; - -// Reduction op IDs — must not conflict with other registered redops -enum SumReductionOpIDs { - REDOP_SUM_FLOAT = 1, - REDOP_SUM_DOUBLE = 2, -}; - -inline Realm::ReductionOpID get_sum_reduction_op_id(DataType dtype) { - switch (dtype) { - case DataType::FLOAT: return REDOP_SUM_FLOAT; - case DataType::DOUBLE: return REDOP_SUM_DOUBLE; - default: - PANIC("no sum reduction registered for datatype {}", dtype); - } -} - -} // namespace FlexFlow diff --git a/lib/realm-execution/include/realm-execution/tasks/realm_reduction.h b/lib/realm-execution/include/realm-execution/tasks/realm_reduction.h index d1d6e1d880..d9cf00441b 100644 --- a/lib/realm-execution/include/realm-execution/tasks/realm_reduction.h +++ b/lib/realm-execution/include/realm-execution/tasks/realm_reduction.h @@ -1,6 +1,6 @@ #pragma once -#include #include "op-attrs/datatype.dtg.h" +#include namespace FlexFlow { @@ -8,7 +8,7 @@ namespace FlexFlow { struct SumReductionFloat { using LHS = float; using RHS = float; - static constexpr RHS identity = 0.0f; // ← inside struct, constexpr + static constexpr RHS identity = 0.0f; // ← inside struct, constexpr template static 
void apply(LHS &lhs, RHS rhs) { @@ -16,14 +16,17 @@ struct SumReductionFloat { lhs += rhs; } else { // atomic add for non-exclusive - __sync_fetch_and_add((int*)&lhs, *(int*)&rhs); + __sync_fetch_and_add((int *)&lhs, *(int *)&rhs); // proper float atomic add — use union trick - union { float f; int i; } old_val, new_val; + union { + float f; + int i; + } old_val, new_val; do { old_val.f = lhs; new_val.f = old_val.f + rhs; - } while (!__sync_bool_compare_and_swap( - (int*)&lhs, old_val.i, new_val.i)); + } while ( + !__sync_bool_compare_and_swap((int *)&lhs, old_val.i, new_val.i)); } } @@ -32,34 +35,39 @@ struct SumReductionFloat { if (EXCLUSIVE) { rhs1 += rhs2; } else { - union { float f; int i; } old_val, new_val; + union { + float f; + int i; + } old_val, new_val; do { old_val.f = rhs1; new_val.f = old_val.f + rhs2; - } while (!__sync_bool_compare_and_swap( - (int*)&rhs1, old_val.i, new_val.i)); + } while ( + !__sync_bool_compare_and_swap((int *)&rhs1, old_val.i, new_val.i)); } } }; - // Sum reduction for double struct SumReductionDouble { using LHS = double; using RHS = double; - static constexpr RHS identity = 0.0; // ← inside struct, constexpr + static constexpr RHS identity = 0.0; // ← inside struct, constexpr template static void apply(LHS &lhs, RHS rhs) { if (EXCLUSIVE) { lhs += rhs; } else { - union { double d; long long i; } old_val, new_val; + union { + double d; + long long i; + } old_val, new_val; do { old_val.d = lhs; new_val.d = old_val.d + rhs; } while (!__sync_bool_compare_and_swap( - (long long*)&lhs, old_val.i, new_val.i)); + (long long *)&lhs, old_val.i, new_val.i)); } } @@ -68,26 +76,31 @@ struct SumReductionDouble { if (EXCLUSIVE) { rhs1 += rhs2; } else { - union { double d; long long i; } old_val, new_val; + union { + double d; + long long i; + } old_val, new_val; do { old_val.d = rhs1; new_val.d = old_val.d + rhs2; } while (!__sync_bool_compare_and_swap( - (long long*)&rhs1, old_val.i, new_val.i)); + (long long *)&rhs1, old_val.i, 
new_val.i)); } } }; // Reduction op IDs — must not conflict with other registered redops enum SumReductionOpIDs { - REDOP_SUM_FLOAT = 1, + REDOP_SUM_FLOAT = 1, REDOP_SUM_DOUBLE = 2, }; inline Realm::ReductionOpID get_sum_reduction_op_id(DataType dtype) { switch (dtype) { - case DataType::FLOAT: return REDOP_SUM_FLOAT; - case DataType::DOUBLE: return REDOP_SUM_DOUBLE; + case DataType::FLOAT: + return REDOP_SUM_FLOAT; + case DataType::DOUBLE: + return REDOP_SUM_DOUBLE; default: PANIC("no sum reduction registered for datatype {}", dtype); } diff --git a/lib/realm-execution/src/realm-execution/distributed_per_device_op_state_initialization.cc b/lib/realm-execution/src/realm-execution/distributed_per_device_op_state_initialization.cc index 1d517a8fe4..e7d8647b12 100644 --- a/lib/realm-execution/src/realm-execution/distributed_per_device_op_state_initialization.cc +++ b/lib/realm-execution/src/realm-execution/distributed_per_device_op_state_initialization.cc @@ -31,6 +31,7 @@ PerDeviceOpStateBacking perform_distributed_per_device_op_state_initialization( std::unordered_map *> device_state_map; + std::vector completion_events; for (DynamicNodeInvocation const &invocation : dg.invocations) { Realm::Processor target_proc = ctx.map_device_coord_to_processor( assert_unwrap(invocation.node_attrs.device_coord)); @@ -56,6 +57,7 @@ PerDeviceOpStateBacking perform_distributed_per_device_op_state_initialization( precondition); if (completion_event.has_value()) { + completion_events.push_back(completion_event.value()); device_state_map.insert(std::pair{invocation, device_state_ptr}); } else { // Task doesn't require initialization, clean up and don't store result @@ -63,7 +65,9 @@ PerDeviceOpStateBacking perform_distributed_per_device_op_state_initialization( } } - ctx.get_outstanding_events().wait(); + // wait for all init tasks — direct write to *result_ptr happens + // before each init task event fires so result is ready after this + 
Realm::Event::merge_events(completion_events).wait(); auto deref = [](DeviceSpecificPtr *const &p) { return *p; }; std::unordered_map> diff --git a/lib/realm-execution/src/realm-execution/pcg_instance.cc b/lib/realm-execution/src/realm-execution/pcg_instance.cc index 0ecd02143e..a0653c3c37 100644 --- a/lib/realm-execution/src/realm-execution/pcg_instance.cc +++ b/lib/realm-execution/src/realm-execution/pcg_instance.cc @@ -6,6 +6,7 @@ #include "realm-execution/instance_allocation.h" #include "realm-execution/realm_context.h" #include "realm-execution/tasks/impl/op_task.h" +#include "realm-execution/tasks/realm_reduction.h" #include "realm-execution/tensor_instance_backing.h" #include "task-spec/dynamic_graph/copy_insertion.h" #include "task-spec/dynamic_graph/dynamic_node_invocation.dtg.h" @@ -215,6 +216,46 @@ static Realm::Event spawn_dynamic_node_invocation( precondition); }; + // issue_replicate_bwd lambda + auto issue_replicate_bwd = [&]() { + std::optional output_grad_opt; + for (auto const &[slot, value] : invocation.inputs) { + if (slot.slot_tensor_role == DynamicTensorRole{FwbTensorType::GRADIENT}) { + output_grad_opt = value; + } + } + DynamicValueAttrs output_grad = assert_unwrap(output_grad_opt); + DynamicValueAttrs input_grad = get_only(invocation.outputs).second; + Realm::RegionInstance dst_inst = + tensor_instance_backing.backing.at(input_grad).first; + + Realm::ReductionOpID redop_id = get_sum_reduction_op_id( + assert_unwrap(output_grad.parallel_tensor_shape).data_type); + + // chain reductions sequentially to avoid write races on dst + Realm::Event e = precondition; + for (auto const &[p, m] : assert_unwrap(output_grad.mapping)) { + DynamicValueAttrs replica_key = output_grad; + replica_key.mapping = + bidict{{p, m}}; + replica_key.shard_coord = p; + + Realm::RegionInstance src_inst = + tensor_instance_backing.backing.at(replica_key).first; + + e = ctx.issue_copy(assert_unwrap(output_grad.parallel_tensor_shape), + src_inst, + 
assert_unwrap(input_grad.parallel_tensor_shape), + dst_inst, + Realm::ProfilingRequestSet{}, + e, + 0, + redop_id, + false); + } + return e; + }; + TrainingOperationAttrs op_attrs = assert_unwrap(invocation.node_attrs.op_attrs); return op_attrs.visit(overload{ @@ -222,11 +263,24 @@ static Realm::Event spawn_dynamic_node_invocation( return pcg_op_attrs.visit(overload{ [&](InputAttrs const &) { return Realm::Event::NO_EVENT; }, [&](WeightAttrs const &) { return Realm::Event::NO_EVENT; }, + [&](ReplicateAttrs const &) { + // this should never be reached since replicate + // goes through TrainingOperationAttrs::ReplicateAttrs + PANIC("unexpected replicate in PCGOperatorAttrs path"); + return Realm::Event::NO_EVENT; + }, [&](auto const &) { return spawn_task(); }, }); }, [&](LossAttrs const &) { return spawn_task(); }, [&](CopyAttrs const &) { return issue_copy(); }, + [&](ReplicateAttrs const &) { + if (invocation.node_attrs.task_type.has_value() && + invocation.node_attrs.task_type.value() == DynamicTaskType::BWD) { + return issue_replicate_bwd(); + } + return issue_copy(); + }, }); } diff --git a/lib/realm-execution/src/realm-execution/realm_context.cc b/lib/realm-execution/src/realm-execution/realm_context.cc index 790c1bd613..a4669bf43e 100644 --- a/lib/realm-execution/src/realm-execution/realm_context.cc +++ b/lib/realm-execution/src/realm-execution/realm_context.cc @@ -161,7 +161,9 @@ Realm::Event Realm::RegionInstance dst_inst, Realm::ProfilingRequestSet const &requests, Realm::Event wait_on, - int priority) { + int priority, + std::optional redop_id, + bool exclusive) { TensorShape src_piece_shape = get_piece_shape(src_shape); TensorShape dst_piece_shape = get_piece_shape(dst_shape); ASSERT(src_piece_shape == dst_piece_shape); // For now, assume they match @@ -183,6 +185,11 @@ Realm::Event size_of_datatype(src_piece_shape.data_type).int_from_positive_int()), /*subfield_offset=*/0); + // set reduction op on dst field if provided + if (redop_id.has_value()) { + 
dst_field.set_redop(redop_id.value(), /*is_fold=*/false, exclusive); + } + Realm::Event result; switch (src_piece_shape.dims.ff_ordered.num_dims()) { #if REALM_MAX_DIM >= 1 diff --git a/lib/realm-execution/src/realm-execution/tasks/impl/per_device_op_state_init_task.cc b/lib/realm-execution/src/realm-execution/tasks/impl/per_device_op_state_init_task.cc index 753fccf74b..0ea51810e4 100644 --- a/lib/realm-execution/src/realm-execution/tasks/impl/per_device_op_state_init_task.cc +++ b/lib/realm-execution/src/realm-execution/tasks/impl/per_device_op_state_init_task.cc @@ -66,11 +66,17 @@ void per_device_op_state_init_task_body(void const *args, result_state, ctx.get_current_device_idx())}; DeviceSpecificPtr result_device_specific{ ctx.get_current_device_idx(), result_state_ptr}; - spawn_per_device_op_state_init_return_task(ctx, - task_args.origin_proc, - result_device_specific, - task_args.origin_result_ptr, - Realm::Event::NO_EVENT); + + // replace spawn_per_device_op_state_init_return_task with: + // NOTE: SM/TODO: direct write assumes single-node shared address space + // For multi-node, replace with UserEvent trigger pattern + *task_args.origin_result_ptr = result_device_specific; + + // spawn_per_device_op_state_init_return_task(ctx, + // task_args.origin_proc, + // result_device_specific, + // task_args.origin_result_ptr, + // Realm::Event::NO_EVENT); } std::optional spawn_per_device_op_state_init_task( diff --git a/lib/realm-execution/src/realm-execution/tasks/realm_task_registry.cc b/lib/realm-execution/src/realm-execution/tasks/realm_task_registry.cc index e7a8948f8d..acafdf59fd 100644 --- a/lib/realm-execution/src/realm-execution/tasks/realm_task_registry.cc +++ b/lib/realm-execution/src/realm-execution/tasks/realm_task_registry.cc @@ -5,6 +5,7 @@ #include "realm-execution/tasks/impl/op_task.h" #include "realm-execution/tasks/impl/per_device_op_state_init_return_task.h" #include "realm-execution/tasks/impl/per_device_op_state_init_task.h" +#include 
"realm-execution/tasks/realm_reduction.h" #include "realm-execution/tasks/task_id_t.h" #include "utils/exception.h" @@ -30,9 +31,18 @@ Realm::Event register_task(Realm::Processor::Kind target_kind, Realm::ProfilingRequestSet()); } +static void register_reductions() { + // register sum reduction ops + Realm::Runtime rt = Realm::Runtime::get_runtime(); + rt.register_reduction(REDOP_SUM_FLOAT); + rt.register_reduction(REDOP_SUM_DOUBLE); + // register_reduction is synchronous — no event returned +} + Realm::Event register_all_tasks() { std::vector pending_registrations; + register_reductions(); std::vector init_task_ids = { // Init tasks task_id_t::BATCHNORM_INIT_TASK_ID, diff --git a/lib/realm-execution/test/src/realm-execution/test_op_replicate.cc b/lib/realm-execution/test/src/realm-execution/test_op_replicate.cc index d1fc941007..632f08d239 100644 --- a/lib/realm-execution/test/src/realm-execution/test_op_replicate.cc +++ b/lib/realm-execution/test/src/realm-execution/test_op_replicate.cc @@ -56,194 +56,207 @@ TEST_SUITE(FF_TEST_SUITE) { char **fake_argv = fake_args.data(); RealmManager manager = RealmManager{&fake_argc, &fake_argv}; - ControllerTaskResult result = manager.start_controller([](RealmContext - &ctx) { - Allocator allocator = ctx.get_current_device_allocator(); - - positive_int batch_size = 10_p; - positive_int data_dim = 16_p; - positive_int hidden_dim = 32_p; - positive_int output_dim = 1_p; - - // 10,2 - TensorShape output_tensor_shape = TensorShape{ - TensorDims{FFOrdered{batch_size, output_dim}}, DataType::FLOAT}; - - // 10,2 - TensorShape label_tensor_shape = TensorShape{ - TensorDims{FFOrdered{batch_size, output_dim}}, DataType::FLOAT}; - - GenericTensorAccessorW label_tensor = - allocator.allocate_tensor(label_tensor_shape); - - // construct computation graph - ParallelComputationGraph pcg = empty_parallel_computation_graph(); - - // input tensor - // 10, 16 - TensorShape input_tensor_shape = TensorShape{ - TensorDims{FFOrdered{batch_size, 
data_dim}}, DataType::FLOAT}; - - // parallel layer -> input tensor - ParallelLayerAddedResult inputs_layer = - pcg_add_input_layer(pcg, input_tensor_shape); - parallel_tensor_guid_t t_input = - require_only_key(inputs_layer.outputs, TensorSlotName::OUTPUT); - - // parallel layer -> input tensor 2 - ParallelLayerAddedResult inputs_layer_2 = - pcg_add_input_layer(pcg, input_tensor_shape); - parallel_tensor_guid_t t_input_2 = - require_only_key(inputs_layer_2.outputs, TensorSlotName::OUTPUT); - - // binary ADD attribute - ElementBinaryAttrs add_attrs = ElementBinaryAttrs{ - OperatorType::EW_ADD, - DataType::FLOAT, - false, - false, - }; - - // parallel layer -> perform add - ParallelLayerAddedResult add_operator_1 = - add_parallel_layer(pcg, make_layer_attrs(add_attrs), - { - { - TensorSlotName::LHS_INPUT, - t_input, - }, + ControllerTaskResult result = + manager.start_controller([](RealmContext &ctx) { + Allocator allocator = ctx.get_current_device_allocator(); + + positive_int batch_size = 10_p; + positive_int data_dim = 16_p; + positive_int hidden_dim = 32_p; + positive_int output_dim = 1_p; + + // 10,2 + TensorShape output_tensor_shape = TensorShape{ + TensorDims{FFOrdered{batch_size, output_dim}}, DataType::FLOAT}; + + // 10,2 + TensorShape label_tensor_shape = TensorShape{ + TensorDims{FFOrdered{batch_size, output_dim}}, DataType::FLOAT}; + + GenericTensorAccessorW label_tensor = + allocator.allocate_tensor(label_tensor_shape); + + // construct computation graph + ParallelComputationGraph pcg = empty_parallel_computation_graph(); + + // input tensor + // 10, 16 + TensorShape input_tensor_shape = TensorShape{ + TensorDims{FFOrdered{batch_size, data_dim}}, DataType::FLOAT}; + + // parallel layer -> input tensor + ParallelLayerAddedResult inputs_layer = + pcg_add_input_layer(pcg, input_tensor_shape); + parallel_tensor_guid_t t_input = + require_only_key(inputs_layer.outputs, TensorSlotName::OUTPUT); + + // parallel layer -> input tensor 2 + 
ParallelLayerAddedResult inputs_layer_2 = + pcg_add_input_layer(pcg, input_tensor_shape); + parallel_tensor_guid_t t_input_2 = + require_only_key(inputs_layer_2.outputs, TensorSlotName::OUTPUT); + + // binary ADD attribute + ElementBinaryAttrs add_attrs = ElementBinaryAttrs{ + OperatorType::EW_ADD, + DataType::FLOAT, + false, + false, + }; + + // parallel layer -> perform add + ParallelLayerAddedResult add_operator_1 = + add_parallel_layer(pcg, + make_layer_attrs(add_attrs), { - TensorSlotName::RHS_INPUT, - t_input_2, + { + TensorSlotName::LHS_INPUT, + t_input, + }, + { + TensorSlotName::RHS_INPUT, + t_input_2, + }, }, - }, - {/* weight */}); - - parallel_tensor_guid_t t_add_1 = - require_only_key(add_operator_1.outputs, TensorSlotName::OUTPUT); - - // parallel layer -> perform replicate - const positive_int replicate_degree = 2_p; - ReplicateAttrs repl_attrs = ReplicateAttrs(replicate_degree); - ParallelLayerAddedResult repl_operator_1 = - add_parallel_layer(pcg, make_layer_attrs(repl_attrs), - { + {/* weight */}); + + parallel_tensor_guid_t t_add_1 = + require_only_key(add_operator_1.outputs, TensorSlotName::OUTPUT); + + // parallel layer -> perform replicate + const positive_int replicate_degree = 2_p; + ReplicateAttrs repl_attrs = ReplicateAttrs(replicate_degree); + ParallelLayerAddedResult repl_operator_1 = + add_parallel_layer(pcg, + make_layer_attrs(repl_attrs), { - TensorSlotName::INPUT, - t_add_1, + { + TensorSlotName::INPUT, + t_add_1, + }, }, - }, - /*weight=*/{}); - // output of replicate layer - parallel_tensor_guid_t t_repl_1 = - require_only_key(repl_operator_1.outputs, TensorSlotName::OUTPUT); - - // parallel layer -> perform RelU - ParallelLayerAddedResult relu_operator_1 = - add_parallel_layer(pcg, make_layer_attrs(make_relu_attrs()), - /*inputs=*/ - { + /*weight=*/{}); + // output of replicate layer + parallel_tensor_guid_t t_repl_1 = + require_only_key(repl_operator_1.outputs, TensorSlotName::OUTPUT); + + // parallel layer -> perform RelU + 
ParallelLayerAddedResult relu_operator_1 = + add_parallel_layer(pcg, + make_layer_attrs(make_relu_attrs()), + /*inputs=*/ { - TensorSlotName::INPUT, - t_repl_1, + { + TensorSlotName::INPUT, + t_repl_1, + }, }, - }, - /*weights=*/{}); - // output of relu layer - parallel_tensor_guid_t t_relu_1 = - require_only_key(relu_operator_1.outputs, TensorSlotName::OUTPUT); - - // machine - MachineSpaceCoordinate cpu0{0_n, 0_n, DeviceType::CPU}; - MachineSpaceCoordinate cpu1{0_n, 1_n, DeviceType::CPU}; - - ParallelTensorSpaceCoordinate tensor_coord0{ - /* sum_component */ 0_n, /* discard_copy_component */ 0_n, - /*shard_component*/ FFOrdered{0_n}}; - ParallelTensorSpaceCoordinate tensor_coord1{ - /* sum_component */ 0_n, /* discard_copy_component */ 1_n, - /*shard_component*/ FFOrdered{0_n}}; - MappedParallelComputationGraph mpcg{ - pcg, - {{inputs_layer.parallel_layer, - MappedOperatorTaskGroup{ - {{cpu0, OperatorAtomicTaskShardBinding{{{TensorSlotName::OUTPUT, - tensor_coord0}}}}}}}, - {inputs_layer_2.parallel_layer, - MappedOperatorTaskGroup{ - {{cpu0, OperatorAtomicTaskShardBinding{{{TensorSlotName::OUTPUT, - tensor_coord0}}}}}}}, - {add_operator_1.parallel_layer, - MappedOperatorTaskGroup{ - {{cpu0, OperatorAtomicTaskShardBinding{{ - {TensorSlotName::LHS_INPUT, tensor_coord0}, - {TensorSlotName::RHS_INPUT, tensor_coord0}, - {TensorSlotName::OUTPUT, tensor_coord0}, - }}}}}}, - {repl_operator_1.parallel_layer, - MappedOperatorTaskGroup{{ - {cpu0, OperatorAtomicTaskShardBinding{{ - {TensorSlotName::OUTPUT, tensor_coord0}, - }}}, - {cpu1, OperatorAtomicTaskShardBinding{{ - {TensorSlotName::OUTPUT, tensor_coord1}, - }}}, - }}}, - {relu_operator_1.parallel_layer, - MappedOperatorTaskGroup{{ - {cpu0, OperatorAtomicTaskShardBinding{{ - {TensorSlotName::INPUT, tensor_coord0}, - {TensorSlotName::OUTPUT, tensor_coord0}, - }}}, - {cpu1, OperatorAtomicTaskShardBinding{{ - {TensorSlotName::INPUT, tensor_coord1}, - {TensorSlotName::OUTPUT, tensor_coord1}, - }}}, - }}}}, - }; - - 
MappedOperatorTaskGroup loss_mapping{ - {{cpu0, OperatorAtomicTaskShardBinding{{ - {TensorSlotName::INPUT, tensor_coord0}, - {TensorSlotName::LOGIT, tensor_coord0}, - }}}}}; - - // instantiate computation graph - LossAttrs loss_attrs = LossAttrs{ - NonconfigurableLossAttrs{LossFunction::CATEGORICAL_CROSSENTROPY}}; - OptimizerAttrs optimizer_attrs = - OptimizerAttrs{SGDOptimizerAttrs{/*lr=*/0.001, - /*momentum=*/0.9, - /*nesterov=*/false, - /*weight_decay=*/0.001}}; - - std::unordered_map - input_tensors; - - DistributedFfHandle device_handle = - create_distributed_ff_handle(ctx, - /*workSpaceSize=*/1024 * 1024, - /*allowTensorOpMathConversion=*/true); - PCGInstance pcg_instance = create_pcg_instance( - /*ctx=*/ctx, - /*mpcg=*/mpcg, - /*optimizer=*/optimizer_attrs, - /*loss=*/std::nullopt, - /*input_tensors=*/input_tensors, - /*profiling_settings=*/ProfilingSettings{0, 0}, - /*device_handle=*/device_handle, - /*iteration_config=*/FFIterationConfig{1_p}); - - // begin training loop - int num_epochs = 1; - for (int i = 0; i < num_epochs; i++) { - perform_all_passes_for_pcg_instance( - /*instance=*/pcg_instance, - /*profiling_settings=*/ProfilingSettings{0, 0}, - /*device_handle=*/device_handle, - /*iteration_config=*/FFIterationConfig{1_p}); - } - }); + /*weights=*/{}); + // output of relu layer + parallel_tensor_guid_t t_relu_1 = + require_only_key(relu_operator_1.outputs, TensorSlotName::OUTPUT); + + // machine + MachineSpaceCoordinate cpu0{0_n, 0_n, DeviceType::CPU}; + MachineSpaceCoordinate cpu1{0_n, 1_n, DeviceType::CPU}; + + ParallelTensorSpaceCoordinate tensor_coord0{ + /* sum_component */ 0_n, + /* discard_copy_component */ 0_n, + /*shard_component*/ FFOrdered{0_n}}; + ParallelTensorSpaceCoordinate tensor_coord1{ + /* sum_component */ 0_n, + /* discard_copy_component */ 1_n, + /*shard_component*/ FFOrdered{0_n}}; + MappedParallelComputationGraph mpcg{ + pcg, + {{inputs_layer.parallel_layer, + MappedOperatorTaskGroup{ + {{cpu0, + OperatorAtomicTaskShardBinding{ 
+ {{TensorSlotName::OUTPUT, tensor_coord0}}}}}}}, + {inputs_layer_2.parallel_layer, + MappedOperatorTaskGroup{ + {{cpu0, + OperatorAtomicTaskShardBinding{ + {{TensorSlotName::OUTPUT, tensor_coord0}}}}}}}, + {add_operator_1.parallel_layer, + MappedOperatorTaskGroup{ + {{cpu0, + OperatorAtomicTaskShardBinding{{ + {TensorSlotName::LHS_INPUT, tensor_coord0}, + {TensorSlotName::RHS_INPUT, tensor_coord0}, + {TensorSlotName::OUTPUT, tensor_coord0}, + }}}}}}, + {repl_operator_1.parallel_layer, + MappedOperatorTaskGroup{{ + {cpu0, + OperatorAtomicTaskShardBinding{{ + {TensorSlotName::OUTPUT, tensor_coord0}, + }}}, + {cpu1, + OperatorAtomicTaskShardBinding{{ + {TensorSlotName::OUTPUT, tensor_coord1}, + }}}, + }}}, + {relu_operator_1.parallel_layer, + MappedOperatorTaskGroup{{ + {cpu0, + OperatorAtomicTaskShardBinding{{ + {TensorSlotName::INPUT, tensor_coord0}, + {TensorSlotName::OUTPUT, tensor_coord0}, + }}}, + {cpu1, + OperatorAtomicTaskShardBinding{{ + {TensorSlotName::INPUT, tensor_coord1}, + {TensorSlotName::OUTPUT, tensor_coord1}, + }}}, + }}}}, + }; + + MappedOperatorTaskGroup loss_mapping{ + {{cpu0, + OperatorAtomicTaskShardBinding{{ + {TensorSlotName::INPUT, tensor_coord0}, + {TensorSlotName::LOGIT, tensor_coord0}, + }}}}}; + + // instantiate computation graph + LossAttrs loss_attrs = LossAttrs{ + NonconfigurableLossAttrs{LossFunction::CATEGORICAL_CROSSENTROPY}}; + OptimizerAttrs optimizer_attrs = + OptimizerAttrs{SGDOptimizerAttrs{/*lr=*/0.001, + /*momentum=*/0.9, + /*nesterov=*/false, + /*weight_decay=*/0.001}}; + + std::unordered_map + input_tensors; + + DistributedFfHandle device_handle = create_distributed_ff_handle( + ctx, + /*workSpaceSize=*/1024 * 1024, + /*allowTensorOpMathConversion=*/true); + PCGInstance pcg_instance = create_pcg_instance( + /*ctx=*/ctx, + /*mpcg=*/mpcg, + /*optimizer=*/optimizer_attrs, + /*loss=*/std::nullopt, + /*input_tensors=*/input_tensors, + /*profiling_settings=*/ProfilingSettings{0, 0}, + /*device_handle=*/device_handle, + 
/*iteration_config=*/FFIterationConfig{1_p}); + + // begin training loop + int num_epochs = 1; + for (int i = 0; i < num_epochs; i++) { + perform_all_passes_for_pcg_instance( + /*instance=*/pcg_instance, + /*profiling_settings=*/ProfilingSettings{0, 0}, + /*device_handle=*/device_handle, + /*iteration_config=*/FFIterationConfig{1_p}); + } + }); result.wait(); } } @@ -307,7 +320,8 @@ TEST_SUITE(FF_CUDA_TEST_SUITE) { // parallel layer -> perform add ParallelLayerAddedResult add_operator_1 = - add_parallel_layer(pcg, make_layer_attrs(add_attrs), + add_parallel_layer(pcg, + make_layer_attrs(add_attrs), { { TensorSlotName::LHS_INPUT, @@ -327,7 +341,8 @@ TEST_SUITE(FF_CUDA_TEST_SUITE) { const positive_int replicate_degree = 2_p; ReplicateAttrs repl_attrs = ReplicateAttrs(replicate_degree); ParallelLayerAddedResult repl_operator_1 = - add_parallel_layer(pcg, make_layer_attrs(repl_attrs), + add_parallel_layer(pcg, + make_layer_attrs(repl_attrs), { { TensorSlotName::INPUT, @@ -341,7 +356,8 @@ TEST_SUITE(FF_CUDA_TEST_SUITE) { // parallel layer -> perform RelU ParallelLayerAddedResult relu_operator_1 = - add_parallel_layer(pcg, make_layer_attrs(make_relu_attrs()), + add_parallel_layer(pcg, + make_layer_attrs(make_relu_attrs()), /*inputs=*/ { { @@ -357,8 +373,8 @@ TEST_SUITE(FF_CUDA_TEST_SUITE) { // machine MachineSpaceCoordinate gpu0{0_n, 0_n, DeviceType::GPU}; MachineSpaceCoordinate gpu1{0_n, 1_n, DeviceType::GPU}; - ParallelTensorSpaceCoordinate tensor_coord0{0_n, 0_n, FFOrdered{0_n}}; - ParallelTensorSpaceCoordinate tensor_coord1{0_n, 1_n, FFOrdered{0_n}}; + ParallelTensorSpaceCoordinate tensor_coord0{0_n, 0_n, FFOrdered{0_n}}; + ParallelTensorSpaceCoordinate tensor_coord1{0_n, 1_n, FFOrdered{0_n}}; MappedParallelComputationGraph mpcg{ pcg, { @@ -374,38 +390,44 @@ TEST_SUITE(FF_CUDA_TEST_SUITE) { {{TensorSlotName::OUTPUT, tensor_coord0}}}}}}}, {add_operator_1.parallel_layer, MappedOperatorTaskGroup{ - {{gpu0, OperatorAtomicTaskShardBinding{{ - {TensorSlotName::LHS_INPUT, 
tensor_coord0}, - {TensorSlotName::RHS_INPUT, tensor_coord0}, - {TensorSlotName::OUTPUT, tensor_coord0}, - }}}}}}, + {{gpu0, + OperatorAtomicTaskShardBinding{{ + {TensorSlotName::LHS_INPUT, tensor_coord0}, + {TensorSlotName::RHS_INPUT, tensor_coord0}, + {TensorSlotName::OUTPUT, tensor_coord0}, + }}}}}}, {repl_operator_1.parallel_layer, - MappedOperatorTaskGroup{{ - {gpu0, OperatorAtomicTaskShardBinding{{ - {TensorSlotName::OUTPUT, tensor_coord0}, - }}}, - {gpu1, OperatorAtomicTaskShardBinding{{ - {TensorSlotName::OUTPUT, tensor_coord1}, - }}}}}}, + MappedOperatorTaskGroup{ + {{gpu0, + OperatorAtomicTaskShardBinding{{ + {TensorSlotName::OUTPUT, tensor_coord0}, + }}}, + {gpu1, + OperatorAtomicTaskShardBinding{{ + {TensorSlotName::OUTPUT, tensor_coord1}, + }}}}}}, {relu_operator_1.parallel_layer, MappedOperatorTaskGroup{{ - {gpu0, OperatorAtomicTaskShardBinding{{ - {TensorSlotName::INPUT, tensor_coord0}, - {TensorSlotName::OUTPUT, tensor_coord0}, - }}}, - {gpu1, OperatorAtomicTaskShardBinding{{ - {TensorSlotName::INPUT, tensor_coord1}, - {TensorSlotName::OUTPUT, tensor_coord1}, - }}}, + {gpu0, + OperatorAtomicTaskShardBinding{{ + {TensorSlotName::INPUT, tensor_coord0}, + {TensorSlotName::OUTPUT, tensor_coord0}, + }}}, + {gpu1, + OperatorAtomicTaskShardBinding{{ + {TensorSlotName::INPUT, tensor_coord1}, + {TensorSlotName::OUTPUT, tensor_coord1}, + }}}, }}}, }, }; MappedOperatorTaskGroup loss_mapping{ - {{gpu0, OperatorAtomicTaskShardBinding{{ - {TensorSlotName::INPUT, tensor_coord0}, - {TensorSlotName::LOGIT, tensor_coord0}, - }}}}}; + {{gpu0, + OperatorAtomicTaskShardBinding{{ + {TensorSlotName::INPUT, tensor_coord0}, + {TensorSlotName::LOGIT, tensor_coord0}, + }}}}}; // instantiate computation graph LossAttrs loss_attrs = LossAttrs{ diff --git a/lib/task-spec/include/task-spec/dynamic_graph/training_operation_attrs.dtg.toml b/lib/task-spec/include/task-spec/dynamic_graph/training_operation_attrs.dtg.toml index 8f8f6467c8..2bd0714512 100644 --- 
a/lib/task-spec/include/task-spec/dynamic_graph/training_operation_attrs.dtg.toml +++ b/lib/task-spec/include/task-spec/dynamic_graph/training_operation_attrs.dtg.toml @@ -25,3 +25,7 @@ key = "loss" [[values]] type = "::FlexFlow::CopyAttrs" key = "copy" + +[[values]] +type = "::FlexFlow::ReplicateAttrs" +key = "replicate" diff --git a/lib/task-spec/src/task-spec/dynamic_graph/copy_insertion.cc b/lib/task-spec/src/task-spec/dynamic_graph/copy_insertion.cc index 4c1b9d4609..7a28e254aa 100644 --- a/lib/task-spec/src/task-spec/dynamic_graph/copy_insertion.cc +++ b/lib/task-spec/src/task-spec/dynamic_graph/copy_insertion.cc @@ -25,15 +25,43 @@ bool node_is_copy(DynamicNodeAttrs const &n) { return n.op_attrs.has_value() && n.op_attrs.value().is_copy(); } +static bool is_replicate_invocation(DynamicNodeInvocation const &i) { + if (!i.node_attrs.op_attrs.has_value()) { + return false; + } + TrainingOperationAttrs const &op_attrs = i.node_attrs.op_attrs.value(); + if (op_attrs.is_replicate()) { + return true; + } + return false; +} + bool value_is_mapped(DynamicValueAttrs const &n) { return n.mapping.has_value(); } bool no_part_of_graph_is_copy_inserted(DynamicOpenDataflowGraph const &g) { auto slot_is_mapped = [](DynamicTensorSlot const &) -> bool { return false; }; - - return no_part_of_dynamic_graph_satisfies( - g, node_is_copy, value_is_mapped, slot_is_mapped); + // check all non-replicate invocations + for (DynamicNodeInvocation const &i : g.invocations) { + if (is_replicate_invocation(i)) { + continue; // replicate tensors have mapping set by design + } + if (node_is_copy(i.node_attrs)) { + return false; + } + for (auto const &[slot, value] : i.inputs) { + if (value_is_mapped(value)) { + return false; + } + } + for (auto const &[slot, value] : i.outputs) { + if (value_is_mapped(value)) { + return false; + } + } + } + return true; } bool graph_is_fully_copy_inserted(DynamicOpenDataflowGraph const &g) { @@ -85,6 +113,11 @@ std::unordered_set 
perform_copy_insertion_for_invocation( std::unordered_map const &unmapped_value_to_mapped_source_value) { + // replicate nodes have no MappedOperatorTaskGroup — + // pass through unchanged, no copies needed + if (is_replicate_invocation(i)) { + return {i}; + } MappedOperatorTaskGroup mapping = assert_unwrap(i.node_attrs.mapping); auto map_tensor = [&](DynamicTensorSlot const &slot, @@ -157,6 +190,14 @@ DynamicOpenDataflowGraph std::unordered_map unmapped_value_to_mapped_source_value; for (DynamicNodeInvocation const &i : g.invocations) { + // replicate nodes have no MappedOperatorTaskGroup — + // output mapping already fully set, maps to itself + if (is_replicate_invocation(i)) { + for (auto const &[slot, value] : i.outputs) { + unmapped_value_to_mapped_source_value.insert(std::pair{value, value}); + } + continue; + } for (auto const &[slot, value] : i.outputs) { unmapped_value_to_mapped_source_value.insert( std::pair{value, diff --git a/lib/task-spec/src/task-spec/dynamic_graph/make_dynamic_open_dataflow_graph_from_mapped_pcg.cc b/lib/task-spec/src/task-spec/dynamic_graph/make_dynamic_open_dataflow_graph_from_mapped_pcg.cc index 246f9a3242..3d48a0dc2b 100644 --- a/lib/task-spec/src/task-spec/dynamic_graph/make_dynamic_open_dataflow_graph_from_mapped_pcg.cc +++ b/lib/task-spec/src/task-spec/dynamic_graph/make_dynamic_open_dataflow_graph_from_mapped_pcg.cc @@ -7,11 +7,129 @@ #include "task-spec/dynamic_graph/dynamic_open_dataflow_graph.h" #include "task-spec/dynamic_graph/dynamic_tensor_role.h" #include "utils/containers/generate_map.h" +#include "utils/containers/get_only.h" #include #include #include namespace FlexFlow { +static bidict + get_input_mapping_for_replicate( + MappedParallelComputationGraph const &mpcg, + parallel_layer_guid_t const &replicate_layer) { + + auto [input_slot_name, input_tensor_guid] = + get_only(get_incoming_tensors(mpcg.pcg, replicate_layer)); + + // find the layer that produces this tensor + for (auto const &[layer, _] : 
get_parallel_layer_attrs_mapping(mpcg.pcg)) { + for (auto const &[slot_name, t] : get_outgoing_tensors(mpcg.pcg, layer)) { + if (t == input_tensor_guid) { + MappedOperatorTaskGroup producer_mapping = mpcg.mapped_tasks.at(layer); + return get_tensor_bindings_for_slot_name(producer_mapping, slot_name); + } + } + } + + PANIC("could not find producer of replicate layer input tensor"); +} + +static std::unordered_map + get_consumers_of_tensor(MappedParallelComputationGraph const &mpcg, + parallel_tensor_guid_t const &tensor) { + std::unordered_map result; + for (auto const &[layer, _] : get_parallel_layer_attrs_mapping(mpcg.pcg)) { + for (auto const &[slot_name, t] : get_incoming_tensors(mpcg.pcg, layer)) { + if (t == tensor) { + result.insert({layer, slot_name}); + } + } + } + return result; +} + +static bidict + build_replicated_output_mapping( + MappedParallelComputationGraph const &mpcg, + parallel_layer_guid_t const &replicate_layer) { + + auto [output_slot_name, output_tensor_guid] = + get_only(get_outgoing_tensors(mpcg.pcg, replicate_layer)); + + auto consumers = get_consumers_of_tensor(mpcg, output_tensor_guid); + ASSERT(!consumers.empty()); + + // union all consumer bindings — each consumer shard maps to a distinct + // (discard_copy, machine) pair since replicas are always on different machines + bidict result; + for (auto const &[consumer_layer, slot_name] : consumers) { + MappedOperatorTaskGroup consumer_mapping = + mpcg.mapped_tasks.at(consumer_layer); + bidict binding = + get_tensor_bindings_for_slot_name(consumer_mapping, slot_name); + for (auto const &[p, m] : binding) { + result.equate(p, m); + } + } + return result; +} + +static DynamicNodeInvocation + build_replicate_invocation(parallel_layer_guid_t const &layer, + ParallelLayerAttrs const &attrs, + MappedParallelComputationGraph const &mpcg) { + auto [input_slot_name, input_tensor_guid] = + get_only(get_incoming_tensors(mpcg.pcg, layer)); + auto incoming = get_incoming_tensors(mpcg.pcg, layer); + 
ASSERT(!incoming.empty(), + "replicate layer has no incoming tensors — " + "check PCG edge construction in test"); + + ParallelTensorAttrs input_attrs = + get_parallel_tensor_attrs(mpcg.pcg, input_tensor_guid); + bidict input_mapping = + get_input_mapping_for_replicate(mpcg, layer); + + DynamicValueAttrs input_value{ + /*tensor_guid=*/dynamic_tensor_guid_t{input_tensor_guid}, + /*parallel_tensor_shape=*/input_attrs.shape, + /*shard_coord=*/std::nullopt, + /*mapping=*/get_input_mapping_for_replicate(mpcg, layer), + /*accessor=*/std::nullopt, + /*role=*/std::nullopt, + }; + + auto [output_slot_name, output_tensor_guid] = + get_only(get_outgoing_tensors(mpcg.pcg, layer)); + ParallelTensorAttrs output_attrs = + get_parallel_tensor_attrs(mpcg.pcg, output_tensor_guid); + + DynamicValueAttrs output_value{ + /*tensor_guid=*/dynamic_tensor_guid_t{output_tensor_guid}, + /*parallel_tensor_shape=*/output_attrs.shape, + /*shard_coord=*/std::nullopt, + /*mapping=*/build_replicated_output_mapping(mpcg, layer), + /*accessor=*/std::nullopt, + /*role=*/std::nullopt, + }; + DynamicNodeAttrs node_attrs{ + /*task_type=*/std::nullopt, + /*device_coord=*/std::nullopt, + /*mapping=*/std::nullopt, + /*op_attrs=*/TrainingOperationAttrs{attrs.op_attrs.get()}, + /*pcg_layer_guid=*/dynamic_layer_guid_t{layer}, + /*per_device_op_state=*/std::nullopt, + }; + + DynamicNodeInvocation invocation_node{ + /*inputs=*/{ + {DynamicTensorSlot{input_slot_name, std::nullopt}, input_value}}, + /*node_attrs=*/node_attrs, + /*outputs=*/ + {{DynamicTensorSlot{output_slot_name, std::nullopt}, output_value}}, + }; + return invocation_node; +} DynamicOpenDataflowGraph make_dynamic_open_dataflow_graph_from_mapped_pcg( MappedParallelComputationGraph const &mpcg) { @@ -19,6 +137,15 @@ DynamicOpenDataflowGraph make_dynamic_open_dataflow_graph_from_mapped_pcg( for (auto const &[layer, attrs] : get_parallel_layer_attrs_mapping(mpcg.pcg)) { + + if (attrs.op_attrs.has()) { + // build replicate invocation + 
DynamicNodeInvocation repl_inv = + build_replicate_invocation(layer, attrs, mpcg); + result.invocations.emplace(repl_inv); + continue; + } + DynamicNodeAttrs result_attrs{ /*task_type=*/std::nullopt, /*device_coord=*/std::nullopt, diff --git a/lib/task-spec/src/task-spec/dynamic_graph/pass_expansion.cc b/lib/task-spec/src/task-spec/dynamic_graph/pass_expansion.cc index 0cee06368f..aed5f2c4c3 100644 --- a/lib/task-spec/src/task-spec/dynamic_graph/pass_expansion.cc +++ b/lib/task-spec/src/task-spec/dynamic_graph/pass_expansion.cc @@ -4,6 +4,7 @@ #include "utils/containers/are_all_same.h" #include "utils/containers/merge_disjoint_maps.h" #include "utils/containers/transform.h" +#include "utils/containers/get_only.h" namespace FlexFlow { @@ -109,6 +110,44 @@ DynamicNodeInvocation perform_bwd_pass_expansion_for_invocation( transform(invocation.inputs, to_grad), }; } +static std::unordered_set + perform_pass_expansion_for_replicate( + DynamicNodeInvocation const &invocation) { + + auto const &[input_slot, input] = get_only(invocation.inputs); + auto const &[output_slot, output] = get_only(invocation.outputs); + + // forward: INPUT/FWD → OUTPUT/FWD (copy to replicas) + DynamicNodeInvocation fwd{ + /*inputs=*/{{pass_expand_slot(input_slot, FwbTensorType::FORWARD), + pass_expand_value(input, FwbTensorType::FORWARD)}}, + /*node_attrs=*/ + pass_expand_node(invocation.node_attrs, DynamicTaskType::FWD), + /*outputs=*/ + {{pass_expand_slot(output_slot, FwbTensorType::FORWARD), + pass_expand_value(output, FwbTensorType::FORWARD)}}, + }; + + // backward: OUTPUT/FWD + OUTPUT/GRAD → INPUT/GRAD (reduce gradients) + // The backward node needs the mapping from the output (replicated) + // so it knows which replicas to reduce from + DynamicNodeAttrs bwd_node_attrs = invocation.node_attrs; + bwd_node_attrs.task_type = DynamicTaskType::BWD; + + DynamicNodeInvocation bwd{ + /*inputs=*/{ + {pass_expand_slot(output_slot, FwbTensorType::FORWARD), + pass_expand_value(output, 
FwbTensorType::FORWARD)}, + {pass_expand_slot(output_slot, FwbTensorType::GRADIENT), + pass_expand_value(output, FwbTensorType::GRADIENT)}, + }, + /*node_attrs=*/bwd_node_attrs, + /*outputs=*/ + {{pass_expand_slot(input_slot, FwbTensorType::GRADIENT), + pass_expand_value(input, FwbTensorType::GRADIENT)}}, + }; + return {fwd, bwd}; +} DynamicOpenDataflowGraph perform_pass_expansion(DynamicOpenDataflowGraph const &g) { @@ -117,6 +156,10 @@ DynamicOpenDataflowGraph DynamicOpenDataflowGraph result = flatmap_dynamic_invocation_set( g, [](DynamicNodeInvocation const &invocation) { + if (invocation.node_attrs.op_attrs.has_value() && + invocation.node_attrs.op_attrs.value().is_replicate()) { + return perform_pass_expansion_for_replicate(invocation); + } if (invocation.inputs.empty()) { return std::unordered_set{ perform_fwd_pass_expansion_for_invocation(invocation), diff --git a/lib/task-spec/src/task-spec/dynamic_graph/shard_expansion.cc b/lib/task-spec/src/task-spec/dynamic_graph/shard_expansion.cc index fb6efb96d0..f30a4d8470 100644 --- a/lib/task-spec/src/task-spec/dynamic_graph/shard_expansion.cc +++ b/lib/task-spec/src/task-spec/dynamic_graph/shard_expansion.cc @@ -39,7 +39,6 @@ bool graph_is_fully_shard_expanded(DynamicOpenDataflowGraph const &g) { value_is_shard_expanded, slot_is_shard_expanded); } - static bidict restrict_tensor_mapping_keys_to_coord( bidict const @@ -85,6 +84,114 @@ static DynamicNodeInvocation shard_invocation_for_binding( }; } +static std::unordered_set + perform_shard_expansion_for_replicate(DynamicNodeInvocation const &i) { + auto const &[input_slot, input] = get_only(i.inputs); + auto const &[output_slot, output] = get_only(i.outputs); + + bidict input_mapping = + assert_unwrap(input.mapping); + bidict output_mapping = + assert_unwrap(output.mapping); + + return transform(output_mapping.left_values(), + [&](ParallelTensorSpaceCoordinate const &p) { + ParallelTensorSpaceCoordinate input_p{ + /*sum_component=*/p.sum_component, + 
/*discard_copy_component=*/nonnegative_int{0}, + /*shard_components=*/p.shard_components, + }; + return shard_invocation_for_binding( + i, + output_mapping.at_l(p), + OperatorAtomicTaskShardBinding{{ + {input_slot.slot_name, input_p}, + {output_slot.slot_name, p}, + }}); + }); +} + +static std::unordered_set + perform_shard_expansion_for_replicate_bwd(DynamicNodeInvocation const &i) { + + std::optional output_grad_opt; + std::optional output_fwd_opt; + std::optional output_grad_slot_opt; + std::optional output_fwd_slot_opt; + + for (auto const &[slot, value] : i.inputs) { + if (slot.slot_tensor_role == DynamicTensorRole{FwbTensorType::GRADIENT}) { + output_grad_slot_opt = slot; + output_grad_opt = value; + } else { + output_fwd_slot_opt = slot; + output_fwd_opt = value; + } + } + + DynamicValueAttrs output_grad = assert_unwrap(output_grad_opt); + DynamicValueAttrs output_fwd = assert_unwrap(output_fwd_opt); + DynamicTensorSlot output_grad_slot = assert_unwrap(output_grad_slot_opt); + DynamicTensorSlot output_fwd_slot = assert_unwrap(output_fwd_slot_opt); + auto const &[input_grad_slot, input_grad] = get_only(i.outputs); + + bidict + output_grad_mapping = assert_unwrap(output_grad.mapping); + bidict + input_grad_mapping = assert_unwrap(input_grad.mapping); + + std::unordered_map, + std::unordered_set> + by_shard; + for (auto const &p : output_grad_mapping.left_values()) { + by_shard[p.shard_components].insert(p); + } + + std::unordered_set result; + for (auto const &[shard_components, replica_coords] : by_shard) { + ParallelTensorSpaceCoordinate src_p{ + nonnegative_int{0}, nonnegative_int{0}, shard_components}; + MachineSpaceCoordinate src_machine = input_grad_mapping.at_l(src_p); + + bidict + replica_mapping; + for (auto const &p : replica_coords) { + replica_mapping.equate(p, output_grad_mapping.at_l(p)); + } + + DynamicValueAttrs sharded_output_grad = output_grad; + sharded_output_grad.mapping = replica_mapping; + sharded_output_grad.shard_coord = src_p; + + 
DynamicValueAttrs sharded_output_fwd = output_fwd; + sharded_output_fwd.mapping = replica_mapping; + sharded_output_fwd.shard_coord = src_p; + + DynamicValueAttrs sharded_input_grad = input_grad; + sharded_input_grad.mapping = + bidict{ + {src_p, src_machine}}; + sharded_input_grad.shard_coord = src_p; + + DynamicNodeAttrs sharded_node = i.node_attrs; + sharded_node.device_coord = src_machine; + + result.insert(DynamicNodeInvocation{ + /*inputs=*/{ + {output_fwd_slot, sharded_output_fwd}, + {output_grad_slot, sharded_output_grad}, + }, + /*node_attrs=*/sharded_node, + /*outputs=*/ + { + {input_grad_slot, sharded_input_grad}, + }, + }); + } + return result; +} + + static std::unordered_set perform_shard_expansion_for_copy(DynamicNodeInvocation const &i) { auto [input_slot, input] = get_only(i.inputs); @@ -121,6 +228,22 @@ std::unordered_set return perform_shard_expansion_for_copy(i); } + // forward replicate + if (i.node_attrs.op_attrs.has_value() && + i.node_attrs.op_attrs.value().is_replicate() && + i.node_attrs.task_type.has_value() && + i.node_attrs.task_type.value() == DynamicTaskType::FWD) { + return perform_shard_expansion_for_replicate(i); + } + + // backward replicate + if (i.node_attrs.op_attrs.has_value() && + i.node_attrs.op_attrs.value().is_replicate() && + i.node_attrs.task_type.has_value() && + i.node_attrs.task_type.value() == DynamicTaskType::BWD) { + return perform_shard_expansion_for_replicate_bwd(i); + } + MappedOperatorTaskGroup mapping = assert_unwrap(i.node_attrs.mapping); std::unordered_set shard_machine_coords = diff --git a/lib/task-spec/src/task-spec/ops/impl/element_binary.cc b/lib/task-spec/src/task-spec/ops/impl/element_binary.cc index 13465d7a5f..c8460af538 100644 --- a/lib/task-spec/src/task-spec/ops/impl/element_binary.cc +++ b/lib/task-spec/src/task-spec/ops/impl/element_binary.cc @@ -36,8 +36,8 @@ static std::optional forward_task_impl(TaskArgumentAccessor const &acc) { ProfilingSettings profiling = acc.get_profiling_settings(); 
DeviceType kernel_device_type = acc.get_kernel_device_type(); - ElementBinaryPerDeviceState per_device_state = - acc.get_per_device_op_state().require_element_binary().value(); + std::optional per_device_state = + acc.get_per_device_op_state().require_element_binary(); ElementBinaryAttrs attrs = acc.get_op_attrs().require_element_binary(); device_handle_t handle = acc.get_ff_handle(); @@ -62,8 +62,8 @@ static std::optional backward_task_impl(TaskArgumentAccessor const &acc) { ProfilingSettings profiling = acc.get_profiling_settings(); DeviceType kernel_device_type = acc.get_kernel_device_type(); - ElementBinaryPerDeviceState per_device_state = - acc.get_per_device_op_state().require_element_binary().value(); + std::optional per_device_state = + acc.get_per_device_op_state().require_element_binary(); ElementBinaryAttrs attrs = acc.get_op_attrs().require_element_binary(); device_handle_t handle = acc.get_ff_handle(); diff --git a/lib/task-spec/src/task-spec/ops/impl/element_unary.cc b/lib/task-spec/src/task-spec/ops/impl/element_unary.cc index d66ff9ab8d..9a092b90b8 100644 --- a/lib/task-spec/src/task-spec/ops/impl/element_unary.cc +++ b/lib/task-spec/src/task-spec/ops/impl/element_unary.cc @@ -35,8 +35,8 @@ static std::optional ProfilingSettings profiling = acc.get_profiling_settings(); DeviceType kernel_device_type = acc.get_kernel_device_type(); - ElementUnaryPerDeviceState per_device_state = - acc.get_per_device_op_state().require_element_unary().value(); + std::optional per_device_state = + acc.get_per_device_op_state().require_element_unary(); return profile(forward_kernel, profiling, @@ -62,8 +62,8 @@ static std::optional ProfilingSettings profiling = acc.get_profiling_settings(); DeviceType kernel_device_type = acc.get_kernel_device_type(); - ElementUnaryPerDeviceState per_device_state = - acc.get_per_device_op_state().require_element_unary().value(); + std::optional per_device_state = + acc.get_per_device_op_state().require_element_unary(); return 
profile(backward_kernel, profiling, From d033e22f77d08fc6b4d1151ef7d6bf7cc23281cb Mon Sep 17 00:00:00 2001 From: Seema Mirchandaney Date: Tue, 14 Apr 2026 17:10:12 -0700 Subject: [PATCH 03/14] remove ReplicateAttr --- .../src/realm-execution/pcg_instance.cc | 17 ++++++----------- .../src/realm-execution/tasks/task_id_t.cc | 12 +++--------- .../training_operation_attrs.dtg.toml | 4 ---- .../task-spec/dynamic_graph/copy_insertion.cc | 13 +++++-------- ...namic_open_dataflow_graph_from_mapped_pcg.cc | 2 +- .../task-spec/dynamic_graph/pass_expansion.cc | 10 +++++++--- .../task-spec/dynamic_graph/shard_expansion.cc | 16 +++++++++------- 7 files changed, 31 insertions(+), 43 deletions(-) diff --git a/lib/realm-execution/src/realm-execution/pcg_instance.cc b/lib/realm-execution/src/realm-execution/pcg_instance.cc index a0653c3c37..17c62fe70c 100644 --- a/lib/realm-execution/src/realm-execution/pcg_instance.cc +++ b/lib/realm-execution/src/realm-execution/pcg_instance.cc @@ -264,23 +264,18 @@ static Realm::Event spawn_dynamic_node_invocation( [&](InputAttrs const &) { return Realm::Event::NO_EVENT; }, [&](WeightAttrs const &) { return Realm::Event::NO_EVENT; }, [&](ReplicateAttrs const &) { - // this should never be reached since replicate - // goes through TrainingOperationAttrs::ReplicateAttrs - PANIC("unexpected replicate in PCGOperatorAttrs path"); - return Realm::Event::NO_EVENT; + if (invocation.node_attrs.task_type.has_value() && + invocation.node_attrs.task_type.value() == + DynamicTaskType::BWD) { + return issue_replicate_bwd(); + } + return issue_copy(); // forward }, [&](auto const &) { return spawn_task(); }, }); }, [&](LossAttrs const &) { return spawn_task(); }, [&](CopyAttrs const &) { return issue_copy(); }, - [&](ReplicateAttrs const &) { - if (invocation.node_attrs.task_type.has_value() && - invocation.node_attrs.task_type.value() == DynamicTaskType::BWD) { - return issue_replicate_bwd(); - } - return issue_copy(); - }, }); } diff --git 
a/lib/realm-execution/src/realm-execution/tasks/task_id_t.cc b/lib/realm-execution/src/realm-execution/tasks/task_id_t.cc index dd4b0a66ca..0bdc2ca6b5 100644 --- a/lib/realm-execution/src/realm-execution/tasks/task_id_t.cc +++ b/lib/realm-execution/src/realm-execution/tasks/task_id_t.cc @@ -64,9 +64,7 @@ std::optional [](RepartitionAttrs const &attrs) { return task_id_t::REPARTITION_INIT_TASK_ID; }, - [](ReplicateAttrs const &attrs) { - return task_id_t::REPLICATE_INIT_TASK_ID; - }, + [](ReplicateAttrs const &attrs) { return std::nullopt; }, [](ReshapeAttrs const &) { return std::nullopt; }, [](ReverseAttrs const &) { return std::nullopt; }, [](SoftmaxAttrs const &) { return task_id_t::SOFTMAX_INIT_TASK_ID; }, @@ -115,9 +113,7 @@ std::optional [](RepartitionAttrs const &attrs) { return task_id_t::REPARTITION_FWD_TASK_ID; }, - [](ReplicateAttrs const &attrs) { - return task_id_t::REPLICATE_FWD_TASK_ID; - }, + [](ReplicateAttrs const &attrs) { return std::nullopt; }, [](ReshapeAttrs const &) { return task_id_t::RESHAPE_FWD_TASK_ID; }, [](ReverseAttrs const &) { return task_id_t::REVERSE_FWD_TASK_ID; }, [](SoftmaxAttrs const &) { return task_id_t::SOFTMAX_FWD_TASK_ID; }, @@ -166,9 +162,7 @@ std::optional [](RepartitionAttrs const &attrs) { return task_id_t::REPARTITION_BWD_TASK_ID; }, - [](ReplicateAttrs const &attrs) { - return task_id_t::REPLICATE_BWD_TASK_ID; - }, + [](ReplicateAttrs const &attrs) { return std::nullopt; }, [](ReshapeAttrs const &) { return task_id_t::RESHAPE_BWD_TASK_ID; }, [](ReverseAttrs const &) { return task_id_t::REVERSE_BWD_TASK_ID; }, [](SoftmaxAttrs const &) { return task_id_t::SOFTMAX_BWD_TASK_ID; }, diff --git a/lib/task-spec/include/task-spec/dynamic_graph/training_operation_attrs.dtg.toml b/lib/task-spec/include/task-spec/dynamic_graph/training_operation_attrs.dtg.toml index 2bd0714512..8f8f6467c8 100644 --- a/lib/task-spec/include/task-spec/dynamic_graph/training_operation_attrs.dtg.toml +++ 
b/lib/task-spec/include/task-spec/dynamic_graph/training_operation_attrs.dtg.toml @@ -25,7 +25,3 @@ key = "loss" [[values]] type = "::FlexFlow::CopyAttrs" key = "copy" - -[[values]] -type = "::FlexFlow::ReplicateAttrs" -key = "replicate" diff --git a/lib/task-spec/src/task-spec/dynamic_graph/copy_insertion.cc b/lib/task-spec/src/task-spec/dynamic_graph/copy_insertion.cc index 7a28e254aa..ef41042a51 100644 --- a/lib/task-spec/src/task-spec/dynamic_graph/copy_insertion.cc +++ b/lib/task-spec/src/task-spec/dynamic_graph/copy_insertion.cc @@ -26,14 +26,11 @@ bool node_is_copy(DynamicNodeAttrs const &n) { } static bool is_replicate_invocation(DynamicNodeInvocation const &i) { - if (!i.node_attrs.op_attrs.has_value()) { - return false; - } - TrainingOperationAttrs const &op_attrs = i.node_attrs.op_attrs.value(); - if (op_attrs.is_replicate()) { - return true; - } - return false; + return i.node_attrs.op_attrs.has_value() && + i.node_attrs.op_attrs.value().has() && + i.node_attrs.op_attrs.value() + .get() + .has(); } bool value_is_mapped(DynamicValueAttrs const &n) { diff --git a/lib/task-spec/src/task-spec/dynamic_graph/make_dynamic_open_dataflow_graph_from_mapped_pcg.cc b/lib/task-spec/src/task-spec/dynamic_graph/make_dynamic_open_dataflow_graph_from_mapped_pcg.cc index 3d48a0dc2b..a4ef156db9 100644 --- a/lib/task-spec/src/task-spec/dynamic_graph/make_dynamic_open_dataflow_graph_from_mapped_pcg.cc +++ b/lib/task-spec/src/task-spec/dynamic_graph/make_dynamic_open_dataflow_graph_from_mapped_pcg.cc @@ -116,7 +116,7 @@ static DynamicNodeInvocation /*task_type=*/std::nullopt, /*device_coord=*/std::nullopt, /*mapping=*/std::nullopt, - /*op_attrs=*/TrainingOperationAttrs{attrs.op_attrs.get()}, + /*op_attrs=*/TrainingOperationAttrs{attrs.op_attrs}, /*pcg_layer_guid=*/dynamic_layer_guid_t{layer}, /*per_device_op_state=*/std::nullopt, }; diff --git a/lib/task-spec/src/task-spec/dynamic_graph/pass_expansion.cc b/lib/task-spec/src/task-spec/dynamic_graph/pass_expansion.cc index 
aed5f2c4c3..faa1e186c3 100644 --- a/lib/task-spec/src/task-spec/dynamic_graph/pass_expansion.cc +++ b/lib/task-spec/src/task-spec/dynamic_graph/pass_expansion.cc @@ -2,9 +2,9 @@ #include "task-spec/dynamic_graph/dynamic_open_dataflow_graph.h" #include "task-spec/dynamic_graph/dynamic_tensor_role.h" #include "utils/containers/are_all_same.h" +#include "utils/containers/get_only.h" #include "utils/containers/merge_disjoint_maps.h" #include "utils/containers/transform.h" -#include "utils/containers/get_only.h" namespace FlexFlow { @@ -30,6 +30,11 @@ bool graph_is_fully_pass_expanded(DynamicOpenDataflowGraph const &g) { g, node_is_pass_expanded, value_is_pass_expanded, slot_is_pass_expanded); } +static bool is_replicate_attrs(DynamicNodeAttrs const &n) { + return n.op_attrs.has_value() && n.op_attrs.value().has() && + n.op_attrs.value().get().has(); +} + DynamicTensorSlot pass_expand_slot(DynamicTensorSlot const &s, FwbTensorType tensor_type) { ASSERT(!slot_is_pass_expanded(s)); @@ -156,8 +161,7 @@ DynamicOpenDataflowGraph DynamicOpenDataflowGraph result = flatmap_dynamic_invocation_set( g, [](DynamicNodeInvocation const &invocation) { - if (invocation.node_attrs.op_attrs.has_value() && - invocation.node_attrs.op_attrs.value().is_replicate()) { + if (is_replicate_attrs(invocation.node_attrs)) { return perform_pass_expansion_for_replicate(invocation); } if (invocation.inputs.empty()) { diff --git a/lib/task-spec/src/task-spec/dynamic_graph/shard_expansion.cc b/lib/task-spec/src/task-spec/dynamic_graph/shard_expansion.cc index f30a4d8470..d3365ae44c 100644 --- a/lib/task-spec/src/task-spec/dynamic_graph/shard_expansion.cc +++ b/lib/task-spec/src/task-spec/dynamic_graph/shard_expansion.cc @@ -191,7 +191,6 @@ static std::unordered_set return result; } - static std::unordered_set perform_shard_expansion_for_copy(DynamicNodeInvocation const &i) { auto [input_slot, input] = get_only(i.inputs); @@ -228,18 +227,21 @@ std::unordered_set return 
perform_shard_expansion_for_copy(i); } + bool const is_replicate = + i.node_attrs.op_attrs.has_value() && + i.node_attrs.op_attrs.value().has() && + i.node_attrs.op_attrs.value() + .get() + .has(); + // forward replicate - if (i.node_attrs.op_attrs.has_value() && - i.node_attrs.op_attrs.value().is_replicate() && - i.node_attrs.task_type.has_value() && + if (is_replicate && i.node_attrs.task_type.has_value() && i.node_attrs.task_type.value() == DynamicTaskType::FWD) { return perform_shard_expansion_for_replicate(i); } // backward replicate - if (i.node_attrs.op_attrs.has_value() && - i.node_attrs.op_attrs.value().is_replicate() && - i.node_attrs.task_type.has_value() && + if (is_replicate && i.node_attrs.task_type.has_value() && i.node_attrs.task_type.value() == DynamicTaskType::BWD) { return perform_shard_expansion_for_replicate_bwd(i); } From 6cd706091420f4e9c776d75dc3464bbf040f5385 Mon Sep 17 00:00:00 2001 From: Seema Mirchandaney Date: Wed, 15 Apr 2026 16:15:21 -0700 Subject: [PATCH 04/14] Add comments to realm reductions, Use existing graph methods --- .../realm-execution/tasks/realm_reduction.h | 69 +++++++++++++++---- ...mic_open_dataflow_graph_from_mapped_pcg.cc | 44 ++++++------ 2 files changed, 79 insertions(+), 34 deletions(-) diff --git a/lib/realm-execution/include/realm-execution/tasks/realm_reduction.h b/lib/realm-execution/include/realm-execution/tasks/realm_reduction.h index d9cf00441b..512e344824 100644 --- a/lib/realm-execution/include/realm-execution/tasks/realm_reduction.h +++ b/lib/realm-execution/include/realm-execution/tasks/realm_reduction.h @@ -1,23 +1,33 @@ -#pragma once +#ifndef _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_TASKS_REALM_REDUCTION_H +#define _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_TASKS_REALM_REDUCTION_H #include "op-attrs/datatype.dtg.h" #include namespace FlexFlow { -// Sum reduction for float +/** + * \brief Realm Sum Reduction for Float + * \see 
https://legion.stanford.edu/tutorial/realm/reductions.html + */ struct SumReductionFloat { using LHS = float; using RHS = float; - static constexpr RHS identity = 0.0f; // ← inside struct, constexpr + /** \brief Identity element for addition (0.0) */ + static constexpr RHS identity = 0.0f; + + /** + * \brief Apply reduction: lhs += rhs + * \tparam EXCLUSIVE If true, direct addition; if false, atomic CAS loop + * \param lhs Left-hand side accumulator (modified in place) + * \param rhs Value to add + */ template static void apply(LHS &lhs, RHS rhs) { if (EXCLUSIVE) { lhs += rhs; } else { - // atomic add for non-exclusive - __sync_fetch_and_add((int *)&lhs, *(int *)&rhs); - // proper float atomic add — use union trick + // Atomic float add via CAS loop union { float f; int i; @@ -30,11 +40,18 @@ struct SumReductionFloat { } } + /** + * \brief Fold two RHS values: rhs1 += rhs2 + * \tparam EXCLUSIVE If true, direct addition; if false, atomic CAS loop + * \param rhs1 Accumulator (modified in place) + * \param rhs2 Value to fold in + */ template static void fold(RHS &rhs1, RHS rhs2) { if (EXCLUSIVE) { rhs1 += rhs2; } else { + // Atomic float add via CAS loop union { float f; int i; @@ -48,17 +65,29 @@ struct SumReductionFloat { } }; -// Sum reduction for double +/** + * \brief Realm Sum Reduction for Double + * \see https://legion.stanford.edu/tutorial/realm/reductions.html + */ struct SumReductionDouble { using LHS = double; using RHS = double; - static constexpr RHS identity = 0.0; // ← inside struct, constexpr + /** \brief Identity element for addition (0.0) */ + static constexpr RHS identity = 0.0; + + /** + * \brief Apply reduction: lhs += rhs + * \tparam EXCLUSIVE If true, direct addition; if false, atomic CAS loop + * \param lhs Left-hand side accumulator (modified in place) + * \param rhs Value to add + */ template static void apply(LHS &lhs, RHS rhs) { if (EXCLUSIVE) { lhs += rhs; } else { + // Atomic double add via CAS loop using long long reinterpretation union 
{ double d; long long i; @@ -71,11 +100,18 @@ struct SumReductionDouble { } } + /** + * \brief Fold two RHS values: rhs1 += rhs2 + * \tparam EXCLUSIVE If true, direct addition; if false, atomic CAS loop + * \param rhs1 Accumulator (modified in place) + * \param rhs2 Value to fold in + */ template static void fold(RHS &rhs1, RHS rhs2) { if (EXCLUSIVE) { rhs1 += rhs2; } else { + // Atomic double add via CAS loop using long long reinterpretation union { double d; long long i; @@ -89,12 +125,21 @@ struct SumReductionDouble { } }; -// Reduction op IDs — must not conflict with other registered redops +/** + * \brief Reduction op IDs for sum reductions + * \warning These IDs must not conflict with other registered reduction ops + */ enum SumReductionOpIDs { - REDOP_SUM_FLOAT = 1, - REDOP_SUM_DOUBLE = 2, + REDOP_SUM_FLOAT = 1, ///< Sum reduction op ID for float + REDOP_SUM_DOUBLE = 2, ///< Sum reduction op ID for double }; +/** + * \brief Returns the Realm reduction op ID for a sum reduction over the given datatype + * \param dtype The datatype to look up + * \return The corresponding Realm::ReductionOpID + * \throws PANIC if no sum reduction is registered for the given datatype + */ inline Realm::ReductionOpID get_sum_reduction_op_id(DataType dtype) { switch (dtype) { case DataType::FLOAT: @@ -105,5 +150,5 @@ inline Realm::ReductionOpID get_sum_reduction_op_id(DataType dtype) { PANIC("no sum reduction registered for datatype {}", dtype); } } - } // namespace FlexFlow +#endif diff --git a/lib/task-spec/src/task-spec/dynamic_graph/make_dynamic_open_dataflow_graph_from_mapped_pcg.cc b/lib/task-spec/src/task-spec/dynamic_graph/make_dynamic_open_dataflow_graph_from_mapped_pcg.cc index a4ef156db9..9349341d4b 100644 --- a/lib/task-spec/src/task-spec/dynamic_graph/make_dynamic_open_dataflow_graph_from_mapped_pcg.cc +++ b/lib/task-spec/src/task-spec/dynamic_graph/make_dynamic_open_dataflow_graph_from_mapped_pcg.cc @@ -2,6 +2,7 @@ #include "op-attrs/parallel_tensor_shape.h" 
#include "op-attrs/pcg_operator_attrs.h" #include "pcg/parallel_computation_graph/parallel_computation_graph.h" +#include "pcg/parallel_computation_graph/parallel_computation_graph_edge.h" #include "pcg/parallel_computation_graph/parallel_tensor_attrs.dtg.h" #include "task-spec/dynamic_graph/dynamic_layer_guid_t.dtg.h" #include "task-spec/dynamic_graph/dynamic_open_dataflow_graph.h" @@ -18,31 +19,30 @@ static bidict MappedParallelComputationGraph const &mpcg, parallel_layer_guid_t const &replicate_layer) { - auto [input_slot_name, input_tensor_guid] = - get_only(get_incoming_tensors(mpcg.pcg, replicate_layer)); - - // find the layer that produces this tensor - for (auto const &[layer, _] : get_parallel_layer_attrs_mapping(mpcg.pcg)) { - for (auto const &[slot_name, t] : get_outgoing_tensors(mpcg.pcg, layer)) { - if (t == input_tensor_guid) { - MappedOperatorTaskGroup producer_mapping = mpcg.mapped_tasks.at(layer); - return get_tensor_bindings_for_slot_name(producer_mapping, slot_name); - } - } - } + // get_incoming_edges returns map + // replicate has exactly one input + auto [input_slot_name, input_edge] = + get_only(get_incoming_edges(mpcg.pcg, replicate_layer)); - PANIC("could not find producer of replicate layer input tensor"); + parallel_layer_guid_t producer_layer = get_src_layer(input_edge); + TensorSlotName producer_slot = get_src_layer_output_slot_name(input_edge); + + return get_tensor_bindings_for_slot_name(mpcg.mapped_tasks.at(producer_layer), + producer_slot); } static std::unordered_map get_consumers_of_tensor(MappedParallelComputationGraph const &mpcg, parallel_tensor_guid_t const &tensor) { + parallel_layer_guid_t producer_layer = get_source_layer(mpcg.pcg, tensor); + std::unordered_map result; - for (auto const &[layer, _] : get_parallel_layer_attrs_mapping(mpcg.pcg)) { - for (auto const &[slot_name, t] : get_incoming_tensors(mpcg.pcg, layer)) { - if (t == tensor) { - result.insert({layer, slot_name}); - } + // get_outgoing_edges returns 
unordered_set + for (ParallelComputationGraphEdge const &edge : + get_outgoing_edges(mpcg.pcg, producer_layer)) { + if (get_parallel_tensor(edge) == tensor) { + result.insert( + std::pair{get_dst_layer(edge), get_dst_layer_input_slot_name(edge)}); } } return result; @@ -76,7 +76,7 @@ static bidict static DynamicNodeInvocation build_replicate_invocation(parallel_layer_guid_t const &layer, - ParallelLayerAttrs const &attrs, + ReplicateAttrs const &attrs, MappedParallelComputationGraph const &mpcg) { auto [input_slot_name, input_tensor_guid] = get_only(get_incoming_tensors(mpcg.pcg, layer)); @@ -116,7 +116,7 @@ static DynamicNodeInvocation /*task_type=*/std::nullopt, /*device_coord=*/std::nullopt, /*mapping=*/std::nullopt, - /*op_attrs=*/TrainingOperationAttrs{attrs.op_attrs}, + /*op_attrs=*/TrainingOperationAttrs{PCGOperatorAttrs{attrs}}, /*pcg_layer_guid=*/dynamic_layer_guid_t{layer}, /*per_device_op_state=*/std::nullopt, }; @@ -140,8 +140,8 @@ DynamicOpenDataflowGraph make_dynamic_open_dataflow_graph_from_mapped_pcg( if (attrs.op_attrs.has()) { // build replicate invocation - DynamicNodeInvocation repl_inv = - build_replicate_invocation(layer, attrs, mpcg); + DynamicNodeInvocation repl_inv = build_replicate_invocation( + layer, attrs.op_attrs.get(), mpcg); result.invocations.emplace(repl_inv); continue; } From c50f3846e4f59920cce36792daeef22b2a70d9e0 Mon Sep 17 00:00:00 2001 From: Colin Unger Date: Fri, 15 May 2026 17:50:24 -0700 Subject: [PATCH 05/14] Minor PR fixes --- .../mapped_parallel_computation_graph.h | 23 ++ .../parallel_computation_graph.h | 5 + .../mapped_parallel_computation_graph.cc | 43 +++ .../parallel_computation_graph.cc | 15 + .../src/realm-execution/pcg_instance.cc | 38 ++- .../src/realm-execution/test_op_replicate.cc | 298 +++++++++++------- .../sub_parallel_computation_graph.h | 2 +- .../apply_substitution/apply_substitution.cc | 2 +- .../sub_parallel_computation_graph.cc | 2 +- ...mic_open_dataflow_graph_from_mapped_pcg.cc | 32 +- 
.../get_kwarg_dataflow_value_uses.h | 33 ++ .../include/utils/many_to_one/many_to_one.h | 5 + .../include/utils/one_to_many/one_to_many.h | 5 + .../get_kwarg_dataflow_value_uses.cc | 14 + 14 files changed, 373 insertions(+), 144 deletions(-) create mode 100644 lib/utils/include/utils/graph/kwarg_dataflow_graph/algorithms/get_kwarg_dataflow_value_uses.h create mode 100644 lib/utils/src/utils/graph/kwarg_dataflow_graph/algorithms/get_kwarg_dataflow_value_uses.cc diff --git a/lib/pcg/include/pcg/mapped_parallel_computation_graph/mapped_parallel_computation_graph.h b/lib/pcg/include/pcg/mapped_parallel_computation_graph/mapped_parallel_computation_graph.h index 12c7921282..984a524c21 100644 --- a/lib/pcg/include/pcg/mapped_parallel_computation_graph/mapped_parallel_computation_graph.h +++ b/lib/pcg/include/pcg/mapped_parallel_computation_graph/mapped_parallel_computation_graph.h @@ -8,12 +8,35 @@ namespace FlexFlow { std::unordered_set mpcg_get_parallel_layers(MappedParallelComputationGraph const &); + MappedOperatorTaskGroup mpcg_get_mapping_for_layer(MappedParallelComputationGraph const &, parallel_layer_guid_t); ParallelComputationGraph pcg_from_mpcg(MappedParallelComputationGraph const &); +parallel_layer_guid_t mpcg_get_source_layer(MappedParallelComputationGraph const &, + parallel_tensor_guid_t const &); + +ParallelTensorAttrs mpcg_get_parallel_tensor_attrs(MappedParallelComputationGraph const &, + parallel_tensor_guid_t const &); + +std::unordered_map + mpcg_get_incoming_edges(MappedParallelComputationGraph const &, + parallel_layer_guid_t const &); + +std::unordered_set + mpcg_get_outgoing_edges(MappedParallelComputationGraph const &, + parallel_layer_guid_t const &); + +ManyToOne + mpcg_get_incoming_tensors(MappedParallelComputationGraph const &, + parallel_layer_guid_t const &); + +bidict + mpcg_get_outgoing_tensors(MappedParallelComputationGraph const &, + parallel_layer_guid_t const &); + std::unordered_set mpcg_get_edges(MappedParallelComputationGraph 
const &); diff --git a/lib/pcg/include/pcg/parallel_computation_graph/parallel_computation_graph.h b/lib/pcg/include/pcg/parallel_computation_graph/parallel_computation_graph.h index 0368be62bc..1b2d5a0b67 100644 --- a/lib/pcg/include/pcg/parallel_computation_graph/parallel_computation_graph.h +++ b/lib/pcg/include/pcg/parallel_computation_graph/parallel_computation_graph.h @@ -11,6 +11,7 @@ #include "pcg/parallel_computation_graph/parallel_layer_guid_t.dtg.h" #include "pcg/parallel_computation_graph/parallel_tensor_guid_t.dtg.h" #include +#include "pcg/parallel_computation_graph/parallel_tensor_use_t.dtg.h" namespace FlexFlow { @@ -53,6 +54,10 @@ std::unordered_map get_incoming_edges(ParallelComputationGraph const &, parallel_layer_guid_t const &); +std::unordered_set + pcg_get_parallel_tensor_uses(ParallelComputationGraph const &, + parallel_tensor_guid_t const &); + std::unordered_set get_initial_layers(ParallelComputationGraph const &); diff --git a/lib/pcg/src/pcg/mapped_parallel_computation_graph/mapped_parallel_computation_graph.cc b/lib/pcg/src/pcg/mapped_parallel_computation_graph/mapped_parallel_computation_graph.cc index f4fa946a66..571b89b6dd 100644 --- a/lib/pcg/src/pcg/mapped_parallel_computation_graph/mapped_parallel_computation_graph.cc +++ b/lib/pcg/src/pcg/mapped_parallel_computation_graph/mapped_parallel_computation_graph.cc @@ -8,6 +8,8 @@ #include "utils/graph/labelled_kwarg_dataflow_graph/algorithms/labelled_kwarg_dataflow_graph_view_as_dot.h" #include "utils/graph/labelled_kwarg_dataflow_graph/algorithms/materialize_labelled_kwarg_dataflow_graph_view.h" #include "utils/graph/labelled_kwarg_dataflow_graph/algorithms/rewrite_labelled_kwarg_dataflow_graph_node_labels.h" +#include "utils/bidict/algorithms/bidict_from_map.h" +#include "utils/many_to_one/many_to_one_from_map.h" namespace FlexFlow { @@ -46,6 +48,47 @@ ParallelComputationGraph }; } +parallel_layer_guid_t mpcg_get_source_layer(MappedParallelComputationGraph const &mpcg, + 
parallel_tensor_guid_t const &t) +{ + return get_source_layer(pcg_from_mpcg(mpcg), t); +} + +ParallelTensorAttrs mpcg_get_parallel_tensor_attrs(MappedParallelComputationGraph const &mpcg, + parallel_tensor_guid_t const &t) +{ + return get_parallel_tensor_attrs(pcg_from_mpcg(mpcg), t); +} + +std::unordered_map + mpcg_get_incoming_edges(MappedParallelComputationGraph const &mpcg, + parallel_layer_guid_t const &l) +{ + return get_incoming_edges(pcg_from_mpcg(mpcg), l); +} + +std::unordered_set + mpcg_get_outgoing_edges(MappedParallelComputationGraph const &mpcg, + parallel_layer_guid_t const &l) +{ + return get_outgoing_edges(pcg_from_mpcg(mpcg), l); +} + +ManyToOne + mpcg_get_incoming_tensors(MappedParallelComputationGraph const &mpcg, + parallel_layer_guid_t const &l) +{ + return many_to_one_from_map(get_incoming_tensors(pcg_from_mpcg(mpcg), l)); +} + + +bidict + mpcg_get_outgoing_tensors(MappedParallelComputationGraph const &mpcg, + parallel_layer_guid_t const &l) +{ + return bidict_from_map(get_outgoing_tensors(pcg_from_mpcg(mpcg), l)); +} + MappedParallelComputationGraph mapped_pcg_from_pcg_and_mapped_op_task_groups( ParallelComputationGraph const &pcg, std::unordered_map const diff --git a/lib/pcg/src/pcg/parallel_computation_graph/parallel_computation_graph.cc b/lib/pcg/src/pcg/parallel_computation_graph/parallel_computation_graph.cc index a548ceb65a..2c5197242d 100644 --- a/lib/pcg/src/pcg/parallel_computation_graph/parallel_computation_graph.cc +++ b/lib/pcg/src/pcg/parallel_computation_graph/parallel_computation_graph.cc @@ -36,6 +36,7 @@ #include "utils/graph/node/node.dtg.h" #include "utils/record_formatter.h" #include +#include "utils/graph/kwarg_dataflow_graph/algorithms/get_kwarg_dataflow_value_uses.h" namespace FlexFlow { @@ -206,6 +207,20 @@ std::unordered_map }); } +std::unordered_set + pcg_get_parallel_tensor_uses(ParallelComputationGraph const &pcg, + parallel_tensor_guid_t const &t) +{ + std::unordered_set> raw_uses = + 
get_kwarg_dataflow_value_uses(pcg.raw_graph, + t.raw_graph_output); + + return transform(raw_uses, [](KwargDataflowInput const &i) { + return parallel_tensor_use_t{i}; + }); +} + + std::unordered_set get_initial_layers(ParallelComputationGraph const &pcg) { std::unordered_set raw_sources = get_initial_nodes(pcg.raw_graph); diff --git a/lib/realm-execution/src/realm-execution/pcg_instance.cc b/lib/realm-execution/src/realm-execution/pcg_instance.cc index 17c62fe70c..17a6a383e6 100644 --- a/lib/realm-execution/src/realm-execution/pcg_instance.cc +++ b/lib/realm-execution/src/realm-execution/pcg_instance.cc @@ -218,14 +218,17 @@ static Realm::Event spawn_dynamic_node_invocation( // issue_replicate_bwd lambda auto issue_replicate_bwd = [&]() { - std::optional output_grad_opt; - for (auto const &[slot, value] : invocation.inputs) { - if (slot.slot_tensor_role == DynamicTensorRole{FwbTensorType::GRADIENT}) { - output_grad_opt = value; - } - } - DynamicValueAttrs output_grad = assert_unwrap(output_grad_opt); - DynamicValueAttrs input_grad = get_only(invocation.outputs).second; + + DynamicValueAttrs output_grad = get_only( + values( + filter_keys( + invocation.inputs, + [](DynamicTensorSlot const &s) -> bool { + return s.slot_tensor_role == DynamicTensorRole{FwbTensorType::GRADIENT}; + }))); + + DynamicValueAttrs input_grad = get_only(values(invocation.outputs)); + Realm::RegionInstance dst_inst = tensor_instance_backing.backing.at(input_grad).first; @@ -243,15 +246,16 @@ static Realm::Event spawn_dynamic_node_invocation( Realm::RegionInstance src_inst = tensor_instance_backing.backing.at(replica_key).first; - e = ctx.issue_copy(assert_unwrap(output_grad.parallel_tensor_shape), - src_inst, - assert_unwrap(input_grad.parallel_tensor_shape), - dst_inst, - Realm::ProfilingRequestSet{}, - e, - 0, - redop_id, - false); + e = ctx.issue_copy( + /*src_shape=*/assert_unwrap(output_grad.parallel_tensor_shape), + /*src_inst=*/src_inst, + 
/*dst_shape=*/assert_unwrap(input_grad.parallel_tensor_shape), + /*dst_inst=*/dst_inst, + /*requests=*/Realm::ProfilingRequestSet{}, + /*wait_on=*/e, + /*priority=*/0, + /*redop_id=*/redop_id, + /*exlusive=*/false); } return e; }; diff --git a/lib/realm-execution/test/src/realm-execution/test_op_replicate.cc b/lib/realm-execution/test/src/realm-execution/test_op_replicate.cc index 632f08d239..cae5ca1756 100644 --- a/lib/realm-execution/test/src/realm-execution/test_op_replicate.cc +++ b/lib/realm-execution/test/src/realm-execution/test_op_replicate.cc @@ -27,6 +27,7 @@ #include "test/utils/doctest/check_kv.h" #include "utils/containers/require_only_key.h" #include +#include "pcg/mapped_parallel_computation_graph/mapped_parallel_computation_graph.h" namespace test { @@ -168,67 +169,116 @@ TEST_SUITE(FF_TEST_SUITE) { /* sum_component */ 0_n, /* discard_copy_component */ 1_n, /*shard_component*/ FFOrdered{0_n}}; - MappedParallelComputationGraph mpcg{ - pcg, - {{inputs_layer.parallel_layer, - MappedOperatorTaskGroup{ - {{cpu0, - OperatorAtomicTaskShardBinding{ - {{TensorSlotName::OUTPUT, tensor_coord0}}}}}}}, - {inputs_layer_2.parallel_layer, - MappedOperatorTaskGroup{ - {{cpu0, - OperatorAtomicTaskShardBinding{ - {{TensorSlotName::OUTPUT, tensor_coord0}}}}}}}, - {add_operator_1.parallel_layer, - MappedOperatorTaskGroup{ - {{cpu0, - OperatorAtomicTaskShardBinding{{ - {TensorSlotName::LHS_INPUT, tensor_coord0}, - {TensorSlotName::RHS_INPUT, tensor_coord0}, + MappedParallelComputationGraph mpcg = mapped_pcg_from_pcg_and_mapped_op_task_groups( + /*pcg=*/pcg, + /*mapped_op_task_groups=*/{ + { + inputs_layer.parallel_layer, + MappedOperatorTaskGroup{ + { + { + cpu0, + OperatorAtomicTaskShardBinding{{ {TensorSlotName::OUTPUT, tensor_coord0}, - }}}}}}, - {repl_operator_1.parallel_layer, - MappedOperatorTaskGroup{{ - {cpu0, - OperatorAtomicTaskShardBinding{{ - {TensorSlotName::OUTPUT, tensor_coord0}, - }}}, - {cpu1, - OperatorAtomicTaskShardBinding{{ - {TensorSlotName::OUTPUT, 
tensor_coord1}, - }}}, - }}}, - {relu_operator_1.parallel_layer, - MappedOperatorTaskGroup{{ - {cpu0, - OperatorAtomicTaskShardBinding{{ - {TensorSlotName::INPUT, tensor_coord0}, - {TensorSlotName::OUTPUT, tensor_coord0}, - }}}, - {cpu1, - OperatorAtomicTaskShardBinding{{ - {TensorSlotName::INPUT, tensor_coord1}, - {TensorSlotName::OUTPUT, tensor_coord1}, - }}}, - }}}}, - }; + }}, + }, + }, + }, + }, + { + inputs_layer_2.parallel_layer, + MappedOperatorTaskGroup{ + { + { + cpu0, + OperatorAtomicTaskShardBinding{{ + {TensorSlotName::OUTPUT, tensor_coord0}, + }}, + }, + }, + }, + }, + { + add_operator_1.parallel_layer, + MappedOperatorTaskGroup{ + { + { + cpu0, + OperatorAtomicTaskShardBinding{{ + {TensorSlotName::LHS_INPUT, tensor_coord0}, + {TensorSlotName::RHS_INPUT, tensor_coord0}, + {TensorSlotName::OUTPUT, tensor_coord0}, + }}, + }, + }, + }, + }, + { + repl_operator_1.parallel_layer, + MappedOperatorTaskGroup{ + { + { + cpu0, + OperatorAtomicTaskShardBinding{{ + {TensorSlotName::OUTPUT, tensor_coord0}, + }}, + }, + { + cpu1, + OperatorAtomicTaskShardBinding{{ + {TensorSlotName::OUTPUT, tensor_coord1}, + }}, + }, + }, + }, + }, + { + relu_operator_1.parallel_layer, + MappedOperatorTaskGroup{ + { + { + cpu0, + OperatorAtomicTaskShardBinding{{ + {TensorSlotName::INPUT, tensor_coord0}, + {TensorSlotName::OUTPUT, tensor_coord0}, + }}, + }, + { + cpu1, + OperatorAtomicTaskShardBinding{{ + {TensorSlotName::INPUT, tensor_coord1}, + {TensorSlotName::OUTPUT, tensor_coord1}, + }}, + }, + }, + }, + }, + }); MappedOperatorTaskGroup loss_mapping{ - {{cpu0, + { + { + cpu0, OperatorAtomicTaskShardBinding{{ - {TensorSlotName::INPUT, tensor_coord0}, - {TensorSlotName::LOGIT, tensor_coord0}, - }}}}}; + {TensorSlotName::INPUT, tensor_coord0}, + {TensorSlotName::LOGIT, tensor_coord0}, + }}, + }, + }, + }; // instantiate computation graph LossAttrs loss_attrs = LossAttrs{ NonconfigurableLossAttrs{LossFunction::CATEGORICAL_CROSSENTROPY}}; OptimizerAttrs optimizer_attrs = - 
OptimizerAttrs{SGDOptimizerAttrs{/*lr=*/0.001, - /*momentum=*/0.9, - /*nesterov=*/false, - /*weight_decay=*/0.001}}; + OptimizerAttrs{ + SGDOptimizerAttrs{ + /*lr=*/0.001, + /*momentum=*/0.9, + /*nesterov=*/false, + /*weight_decay=*/0.001, + }, + }; std::unordered_map input_tensors; @@ -375,68 +425,102 @@ TEST_SUITE(FF_CUDA_TEST_SUITE) { MachineSpaceCoordinate gpu1{0_n, 1_n, DeviceType::GPU}; ParallelTensorSpaceCoordinate tensor_coord0{0_n, 0_n, FFOrdered{0_n}}; ParallelTensorSpaceCoordinate tensor_coord1{0_n, 1_n, FFOrdered{0_n}}; - MappedParallelComputationGraph mpcg{ - pcg, + MappedParallelComputationGraph mpcg = mapped_pcg_from_pcg_and_mapped_op_task_groups( + /*pcg=*/pcg, + /*mapped_op_task_groups=*/{ { - {inputs_layer.parallel_layer, - MappedOperatorTaskGroup{ - {{gpu0, - OperatorAtomicTaskShardBinding{ - {{TensorSlotName::OUTPUT, tensor_coord0}}}}}}}, - {inputs_layer_2.parallel_layer, - MappedOperatorTaskGroup{ - {{gpu0, - OperatorAtomicTaskShardBinding{ - {{TensorSlotName::OUTPUT, tensor_coord0}}}}}}}, - {add_operator_1.parallel_layer, - MappedOperatorTaskGroup{ - {{gpu0, - OperatorAtomicTaskShardBinding{{ - {TensorSlotName::LHS_INPUT, tensor_coord0}, - {TensorSlotName::RHS_INPUT, tensor_coord0}, - {TensorSlotName::OUTPUT, tensor_coord0}, - }}}}}}, - {repl_operator_1.parallel_layer, - MappedOperatorTaskGroup{ - {{gpu0, - OperatorAtomicTaskShardBinding{{ - {TensorSlotName::OUTPUT, tensor_coord0}, - }}}, - {gpu1, - OperatorAtomicTaskShardBinding{{ - {TensorSlotName::OUTPUT, tensor_coord1}, - }}}}}}, - {relu_operator_1.parallel_layer, - MappedOperatorTaskGroup{{ - {gpu0, - OperatorAtomicTaskShardBinding{{ - {TensorSlotName::INPUT, tensor_coord0}, - {TensorSlotName::OUTPUT, tensor_coord0}, - }}}, - {gpu1, - OperatorAtomicTaskShardBinding{{ - {TensorSlotName::INPUT, tensor_coord1}, - {TensorSlotName::OUTPUT, tensor_coord1}, - }}}, - }}}, + inputs_layer.parallel_layer, + MappedOperatorTaskGroup{{ + { + gpu0, + OperatorAtomicTaskShardBinding{ + 
{{TensorSlotName::OUTPUT, tensor_coord0}}}, + }, + }}, }, - }; - - MappedOperatorTaskGroup loss_mapping{ - {{gpu0, - OperatorAtomicTaskShardBinding{{ - {TensorSlotName::INPUT, tensor_coord0}, - {TensorSlotName::LOGIT, tensor_coord0}, - }}}}}; + { + inputs_layer_2.parallel_layer, + MappedOperatorTaskGroup{{ + { + gpu0, + OperatorAtomicTaskShardBinding{ + {{TensorSlotName::OUTPUT, tensor_coord0}}}, + }}, + }, + }, + { + add_operator_1.parallel_layer, + MappedOperatorTaskGroup{{ + { + gpu0, + OperatorAtomicTaskShardBinding{{ + {TensorSlotName::LHS_INPUT, tensor_coord0}, + {TensorSlotName::RHS_INPUT, tensor_coord0}, + {TensorSlotName::OUTPUT, tensor_coord0}, + }}, + }, + }}, + }, + { + repl_operator_1.parallel_layer, + MappedOperatorTaskGroup{{ + { + gpu0, + OperatorAtomicTaskShardBinding{{ + {TensorSlotName::OUTPUT, tensor_coord0}, + }}, + }, + { + gpu1, + OperatorAtomicTaskShardBinding{{ + {TensorSlotName::OUTPUT, tensor_coord1}, + }}, + }, + }}, + }, + { + relu_operator_1.parallel_layer, + MappedOperatorTaskGroup{{ + { + gpu0, + OperatorAtomicTaskShardBinding{{ + {TensorSlotName::INPUT, tensor_coord0}, + {TensorSlotName::OUTPUT, tensor_coord0}, + }}, + }, + { + gpu1, + OperatorAtomicTaskShardBinding{{ + {TensorSlotName::INPUT, tensor_coord1}, + {TensorSlotName::OUTPUT, tensor_coord1}, + }}, + }, + }}, + }, + }); + + MappedOperatorTaskGroup loss_mapping{{ + { + gpu0, + OperatorAtomicTaskShardBinding{{ + {TensorSlotName::INPUT, tensor_coord0}, + {TensorSlotName::LOGIT, tensor_coord0}, + }}, + }, + }}; // instantiate computation graph LossAttrs loss_attrs = LossAttrs{ NonconfigurableLossAttrs{LossFunction::CATEGORICAL_CROSSENTROPY}}; OptimizerAttrs optimizer_attrs = - OptimizerAttrs{SGDOptimizerAttrs{/*lr=*/0.001, - /*momentum=*/0.9, - /*nesterov=*/false, - /*weight_decay=*/0.001}}; + OptimizerAttrs{ + SGDOptimizerAttrs{ + /*lr=*/0.001, + /*momentum=*/0.9, + /*nesterov=*/false, + /*weight_decay=*/0.001, + }, + }; std::unordered_map input_tensors; diff --git 
a/lib/substitutions/include/substitutions/sub_parallel_computation_graph.h b/lib/substitutions/include/substitutions/sub_parallel_computation_graph.h index cbfe3ab264..26c98e915c 100644 --- a/lib/substitutions/include/substitutions/sub_parallel_computation_graph.h +++ b/lib/substitutions/include/substitutions/sub_parallel_computation_graph.h @@ -48,7 +48,7 @@ std::unordered_set get_subgraph_outgoing_edges( std::unordered_set const &); std::unordered_set - get_parallel_tensor_uses(SubParallelComputationGraph const &, + get_open_parallel_tensor_uses(SubParallelComputationGraph const &, open_parallel_tensor_guid_t const &); SubParallelComputationGraphData diff --git a/lib/substitutions/src/substitutions/apply_substitution/apply_substitution.cc b/lib/substitutions/src/substitutions/apply_substitution/apply_substitution.cc index 6ed2ef563e..a56555550f 100644 --- a/lib/substitutions/src/substitutions/apply_substitution/apply_substitution.cc +++ b/lib/substitutions/src/substitutions/apply_substitution/apply_substitution.cc @@ -109,7 +109,7 @@ SubParallelComputationGraph apply_substitution_from_output_result( input_parallel_tensor_guid_t output_graph_input = output_expr_to_result_sub_pcg_mapping.input_mapping.at_r( output_expr_input); - std::unordered_set uses = get_parallel_tensor_uses( + std::unordered_set uses = get_open_parallel_tensor_uses( substitution_output_graph, open_parallel_tensor_guid_from_input(output_graph_input)); for (parallel_tensor_use_t const &use : uses) { diff --git a/lib/substitutions/src/substitutions/sub_parallel_computation_graph.cc b/lib/substitutions/src/substitutions/sub_parallel_computation_graph.cc index 34b8ae1e96..990975bff9 100644 --- a/lib/substitutions/src/substitutions/sub_parallel_computation_graph.cc +++ b/lib/substitutions/src/substitutions/sub_parallel_computation_graph.cc @@ -131,7 +131,7 @@ std::unordered_set get_subgraph_incoming_edges( } std::unordered_set - get_parallel_tensor_uses(SubParallelComputationGraph const &spcg, + 
get_open_parallel_tensor_uses(SubParallelComputationGraph const &spcg, open_parallel_tensor_guid_t const &t) { std::unordered_set> raw_uses = get_open_kwarg_dataflow_value_uses(spcg.raw_graph, diff --git a/lib/task-spec/src/task-spec/dynamic_graph/make_dynamic_open_dataflow_graph_from_mapped_pcg.cc b/lib/task-spec/src/task-spec/dynamic_graph/make_dynamic_open_dataflow_graph_from_mapped_pcg.cc index 0aea7d2324..b23edc0411 100644 --- a/lib/task-spec/src/task-spec/dynamic_graph/make_dynamic_open_dataflow_graph_from_mapped_pcg.cc +++ b/lib/task-spec/src/task-spec/dynamic_graph/make_dynamic_open_dataflow_graph_from_mapped_pcg.cc @@ -23,24 +23,24 @@ static bidict // get_incoming_edges returns map // replicate has exactly one input auto [input_slot_name, input_edge] = - get_only(get_incoming_edges(mpcg.pcg, replicate_layer)); + get_only(mpcg_get_incoming_edges(mpcg, replicate_layer)); parallel_layer_guid_t producer_layer = get_src_layer(input_edge); TensorSlotName producer_slot = get_src_layer_output_slot_name(input_edge); - return get_tensor_bindings_for_slot_name(mpcg.mapped_tasks.at(producer_layer), - producer_slot); + return get_tensor_bindings_for_slot_name( + /*task_group=*/mpcg_get_mapping_for_layer(mpcg, producer_layer), + /*slot_name=*/producer_slot); } static std::unordered_map get_consumers_of_tensor(MappedParallelComputationGraph const &mpcg, parallel_tensor_guid_t const &tensor) { - parallel_layer_guid_t producer_layer = get_source_layer(mpcg.pcg, tensor); + parallel_layer_guid_t producer_layer = mpcg_get_source_layer(mpcg, tensor); std::unordered_map result; // get_outgoing_edges returns unordered_set - for (ParallelComputationGraphEdge const &edge : - get_outgoing_edges(mpcg.pcg, producer_layer)) { + for (ParallelComputationGraphEdge const &edge : mpcg_get_outgoing_edges(mpcg, producer_layer)) { if (get_parallel_tensor(edge) == tensor) { result.insert( std::pair{get_dst_layer(edge), get_dst_layer_input_slot_name(edge)}); @@ -55,7 +55,7 @@ static bidict 
parallel_layer_guid_t const &replicate_layer) { auto [output_slot_name, output_tensor_guid] = - get_only(get_outgoing_tensors(mpcg.pcg, replicate_layer)); + get_only(mpcg_get_outgoing_tensors(mpcg, replicate_layer)); auto consumers = get_consumers_of_tensor(mpcg, output_tensor_guid); ASSERT(!consumers.empty()); @@ -64,8 +64,7 @@ static bidict // (discard_copy, machine) pair since replicas are always on different machines bidict result; for (auto const &[consumer_layer, slot_name] : consumers) { - MappedOperatorTaskGroup consumer_mapping = - mpcg.mapped_tasks.at(consumer_layer); + MappedOperatorTaskGroup consumer_mapping = mpcg_get_mapping_for_layer(mpcg, consumer_layer); bidict binding = get_tensor_bindings_for_slot_name(consumer_mapping, slot_name); for (auto const &[p, m] : binding) { @@ -80,14 +79,13 @@ static DynamicNodeInvocation ReplicateAttrs const &attrs, MappedParallelComputationGraph const &mpcg) { auto [input_slot_name, input_tensor_guid] = - get_only(get_incoming_tensors(mpcg.pcg, layer)); - auto incoming = get_incoming_tensors(mpcg.pcg, layer); - ASSERT(!incoming.empty(), - "replicate layer has no incoming tensors — " - "check PCG edge construction in test"); + get_only(mpcg_get_incoming_tensors(mpcg, layer).l_to_r()); + + auto incoming = mpcg_get_incoming_tensors(mpcg, layer); + ASSERT(!incoming.empty(), "Replicate layer has no incoming tensors."); ParallelTensorAttrs input_attrs = - get_parallel_tensor_attrs(mpcg.pcg, input_tensor_guid); + mpcg_get_parallel_tensor_attrs(mpcg, input_tensor_guid); bidict input_mapping = get_input_mapping_for_replicate(mpcg, layer); @@ -101,9 +99,9 @@ static DynamicNodeInvocation }; auto [output_slot_name, output_tensor_guid] = - get_only(get_outgoing_tensors(mpcg.pcg, layer)); + get_only(mpcg_get_outgoing_tensors(mpcg, layer)); ParallelTensorAttrs output_attrs = - get_parallel_tensor_attrs(mpcg.pcg, output_tensor_guid); + mpcg_get_parallel_tensor_attrs(mpcg, output_tensor_guid); DynamicValueAttrs output_value{ 
/*tensor_guid=*/dynamic_tensor_guid_t{output_tensor_guid}, diff --git a/lib/utils/include/utils/graph/kwarg_dataflow_graph/algorithms/get_kwarg_dataflow_value_uses.h b/lib/utils/include/utils/graph/kwarg_dataflow_graph/algorithms/get_kwarg_dataflow_value_uses.h new file mode 100644 index 0000000000..b5557e9e49 --- /dev/null +++ b/lib/utils/include/utils/graph/kwarg_dataflow_graph/algorithms/get_kwarg_dataflow_value_uses.h @@ -0,0 +1,33 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_KWARG_DATAFLOW_GRAPH_ALGORITHMS_GET_KWARG_DATAFLOW_VALUE_USES_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_KWARG_DATAFLOW_GRAPH_ALGORITHMS_GET_KWARG_DATAFLOW_VALUE_USES_H + +#include "utils/graph/kwarg_dataflow_graph/kwarg_dataflow_graph_view.h" + +namespace FlexFlow { + +template +std::unordered_set> + get_kwarg_dataflow_value_uses( + KwargDataflowGraphView const &g, + KwargDataflowOutput const &v) { + + KwargDataflowEdgeQuery query = + KwargDataflowEdgeQuery{ + /*src_nodes=*/query_set::match_single_value(v.node), + /*src_slots=*/query_set::match_single_value(v.slot_name), + /*dst_nodes=*/query_set::matchall(), + /*dst_slots=*/query_set::matchall(), + }; + + std::unordered_set> edges = + g.query_edges(query); + + return transform( + edges, [&](KwargDataflowEdge const &e) { + return e.dst; + }); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/many_to_one/many_to_one.h b/lib/utils/include/utils/many_to_one/many_to_one.h index d2f727661c..c73f696172 100644 --- a/lib/utils/include/utils/many_to_one/many_to_one.h +++ b/lib/utils/include/utils/many_to_one/many_to_one.h @@ -19,6 +19,7 @@ #include #include #include +#include "utils/containers/require_same.h" namespace FlexFlow { @@ -106,6 +107,10 @@ struct ManyToOne { return this->m_r_to_l; } + bool empty() const { + return require_same(this->m_l_to_r.empty(), this->m_r_to_l.empty()); + } + private: std::unordered_map m_l_to_r; std::unordered_map> m_r_to_l; diff --git 
a/lib/utils/include/utils/one_to_many/one_to_many.h b/lib/utils/include/utils/one_to_many/one_to_many.h index 30d84d34c3..7b725fdec1 100644 --- a/lib/utils/include/utils/one_to_many/one_to_many.h +++ b/lib/utils/include/utils/one_to_many/one_to_many.h @@ -23,6 +23,7 @@ #include #include #include +#include "utils/containers/require_same.h" namespace FlexFlow { @@ -114,6 +115,10 @@ struct OneToMany { return this->m_r_to_l; } + bool empty() const { + return require_same(this->m_l_to_r.empty(), this->m_r_to_l.empty()); + } + private: std::unordered_map> m_l_to_r; std::unordered_map m_r_to_l; diff --git a/lib/utils/src/utils/graph/kwarg_dataflow_graph/algorithms/get_kwarg_dataflow_value_uses.cc b/lib/utils/src/utils/graph/kwarg_dataflow_graph/algorithms/get_kwarg_dataflow_value_uses.cc new file mode 100644 index 0000000000..2e42863e53 --- /dev/null +++ b/lib/utils/src/utils/graph/kwarg_dataflow_graph/algorithms/get_kwarg_dataflow_value_uses.cc @@ -0,0 +1,14 @@ +#include "utils/graph/kwarg_dataflow_graph/algorithms/get_kwarg_dataflow_value_uses.h" +#include "utils/archetypes/ordered_value_type.h" + +namespace FlexFlow { + +using SlotName = ordered_value_type<0>; + +template + std::unordered_set> + get_kwarg_dataflow_value_uses( + KwargDataflowGraphView const &, + KwargDataflowOutput const &); + +} // namespace FlexFlow From ac4fffcb307fe1116fde88b3c7aa85599c224a4e Mon Sep 17 00:00:00 2001 From: Colin Unger Date: Fri, 15 May 2026 21:39:42 -0700 Subject: [PATCH 06/14] Clean up pass expansion code --- .../op-attrs/pcg_operator_attrs.dtg.toml | 22 +- .../test/src/op-attrs/ops/element_unary.cc | 13 + .../mapped_parallel_computation_graph.h | 7 + .../parallel_tensor_use_t.h | 14 + .../mapped_parallel_computation_graph.cc | 27 +- .../parallel_tensor_use_t.cc | 13 + .../src/realm-execution/pcg_instance.cc | 1 - .../src/realm-execution/test_op_replicate.cc | 587 ++++++------------ .../output_expr_to_result_sub_pcg_mapping.cc | 4 +- .../src/substitutions/pcg_pattern_match.cc | 4 
+- .../dynamic_graph/training_operation_attrs.h | 13 + ...mic_open_dataflow_graph_from_mapped_pcg.cc | 220 +++---- .../task-spec/dynamic_graph/pass_expansion.cc | 85 +-- .../dynamic_graph/training_operation_attrs.cc | 21 + .../task-spec/dynamic_graph/pass_expansion.cc | 270 +++++--- .../binary_merge_disjoint_bidicts.h | 37 ++ .../algorithms/merge_disjoint_bidicts.h | 39 +- lib/utils/include/utils/bidict/bidict.h | 8 + .../utils/containers/transform_pairs.h | 46 ++ .../binary_merge_disjoint_bidicts.cc | 12 + .../algorithms/merge_disjoint_bidicts.cc | 10 + .../src/utils/containers/transform_pairs.cc | 17 + ...ts.cc => binary_merge_disjoint_bidicts.cc} | 12 +- 23 files changed, 792 insertions(+), 690 deletions(-) create mode 100644 lib/pcg/include/pcg/parallel_computation_graph/parallel_tensor_use_t.h create mode 100644 lib/pcg/src/pcg/parallel_computation_graph/parallel_tensor_use_t.cc create mode 100644 lib/task-spec/include/task-spec/dynamic_graph/training_operation_attrs.h create mode 100644 lib/task-spec/src/task-spec/dynamic_graph/training_operation_attrs.cc create mode 100644 lib/utils/include/utils/bidict/algorithms/binary_merge_disjoint_bidicts.h create mode 100644 lib/utils/include/utils/containers/transform_pairs.h create mode 100644 lib/utils/src/utils/bidict/algorithms/binary_merge_disjoint_bidicts.cc create mode 100644 lib/utils/src/utils/containers/transform_pairs.cc rename lib/utils/test/src/utils/bidict/algorithms/{merge_disjoint_bidicts.cc => binary_merge_disjoint_bidicts.cc} (72%) diff --git a/lib/op-attrs/include/op-attrs/pcg_operator_attrs.dtg.toml b/lib/op-attrs/include/op-attrs/pcg_operator_attrs.dtg.toml index 88a65f75c5..f2dd7c9350 100644 --- a/lib/op-attrs/include/op-attrs/pcg_operator_attrs.dtg.toml +++ b/lib/op-attrs/include/op-attrs/pcg_operator_attrs.dtg.toml @@ -11,13 +11,13 @@ features = [ ] includes = [ - "op-attrs/ops/attention_attrs.dtg.h", - "op-attrs/ops/batch_matmul_attrs.dtg.h", - "op-attrs/ops/batch_norm_attrs.dtg.h", - 
"op-attrs/ops/broadcast_attrs.dtg.h", - "op-attrs/ops/cast_attrs.dtg.h", - "op-attrs/ops/combine_attrs.dtg.h", - "op-attrs/ops/concat_attrs.dtg.h", + "op-attrs/ops/attention_attrs.dtg.h", + "op-attrs/ops/batch_matmul_attrs.dtg.h", + "op-attrs/ops/batch_norm_attrs.dtg.h", + "op-attrs/ops/broadcast_attrs.dtg.h", + "op-attrs/ops/cast_attrs.dtg.h", + "op-attrs/ops/combine_attrs.dtg.h", + "op-attrs/ops/concat_attrs.dtg.h", "op-attrs/ops/conv_2d_attrs.dtg.h", "op-attrs/ops/dropout_attrs.dtg.h", "op-attrs/ops/element_binary_attrs.dtg.h", @@ -61,7 +61,7 @@ key = "cast" [[values]] type = "::FlexFlow::CombineAttrs" -key = "combine_distributed" +key = "parallel_combine" [[values]] type = "::FlexFlow::ConcatAttrs" @@ -125,15 +125,15 @@ key = "reduce" [[values]] type = "::FlexFlow::ReductionAttrs" -key = "reduce_distributed" +key = "parallel_reduce" [[values]] type = "::FlexFlow::RepartitionAttrs" -key = "partition_distributed" +key = "parallel_partition" [[values]] type = "::FlexFlow::ReplicateAttrs" -key = "replicate_distributed" +key = "parallel_replicate" [[values]] type = "::FlexFlow::ReverseAttrs" diff --git a/lib/op-attrs/test/src/op-attrs/ops/element_unary.cc b/lib/op-attrs/test/src/op-attrs/ops/element_unary.cc index 43b4be06d8..8b2555610e 100644 --- a/lib/op-attrs/test/src/op-attrs/ops/element_unary.cc +++ b/lib/op-attrs/test/src/op-attrs/ops/element_unary.cc @@ -53,6 +53,19 @@ TEST_SUITE(FF_TEST_SUITE) { CHECK(result == correct); } + SUBCASE("discard copy degree > 1") { + positive_int degree = 2_p; + + ParallelTensorShape par_input = make_input( + SumDegree{1_p}, DiscardCopyDegree{degree}, 1_p, 1_p, 1_p); + + tl::expected result = + get_output_shape(attrs, par_input); + tl::expected correct = par_input; + + CHECK(result == correct); + } + SUBCASE("sum degree > 1") { positive_int degree = 2_p; diff --git a/lib/pcg/include/pcg/mapped_parallel_computation_graph/mapped_parallel_computation_graph.h 
b/lib/pcg/include/pcg/mapped_parallel_computation_graph/mapped_parallel_computation_graph.h index 984a524c21..6c24d4c1e1 100644 --- a/lib/pcg/include/pcg/mapped_parallel_computation_graph/mapped_parallel_computation_graph.h +++ b/lib/pcg/include/pcg/mapped_parallel_computation_graph/mapped_parallel_computation_graph.h @@ -18,6 +18,9 @@ ParallelComputationGraph pcg_from_mpcg(MappedParallelComputationGraph const &); parallel_layer_guid_t mpcg_get_source_layer(MappedParallelComputationGraph const &, parallel_tensor_guid_t const &); +PCGOperatorAttrs mpcg_get_pcg_op_attrs(MappedParallelComputationGraph const &, + parallel_layer_guid_t const &); + ParallelTensorAttrs mpcg_get_parallel_tensor_attrs(MappedParallelComputationGraph const &, parallel_tensor_guid_t const &); @@ -40,6 +43,10 @@ bidict std::unordered_set mpcg_get_edges(MappedParallelComputationGraph const &); +std::unordered_set + mpcg_get_parallel_tensor_uses(MappedParallelComputationGraph const &, + parallel_tensor_guid_t const &); + MappedParallelComputationGraph mapped_pcg_from_pcg_and_mapped_op_task_groups( ParallelComputationGraph const &pcg, std::unordered_map const diff --git a/lib/pcg/include/pcg/parallel_computation_graph/parallel_tensor_use_t.h b/lib/pcg/include/pcg/parallel_computation_graph/parallel_tensor_use_t.h new file mode 100644 index 0000000000..88f1512149 --- /dev/null +++ b/lib/pcg/include/pcg/parallel_computation_graph/parallel_tensor_use_t.h @@ -0,0 +1,14 @@ +#ifndef _FLEXFLOW_LIB_PCG_INCLUDE_PCG_PARALLEL_COMPUTATION_GRAPH_PARALLEL_TENSOR_USE_T_H +#define _FLEXFLOW_LIB_PCG_INCLUDE_PCG_PARALLEL_COMPUTATION_GRAPH_PARALLEL_TENSOR_USE_T_H + +#include "pcg/parallel_computation_graph/parallel_tensor_use_t.dtg.h" +#include "pcg/parallel_computation_graph/parallel_layer_guid_t.dtg.h" + +namespace FlexFlow { + +parallel_layer_guid_t parallel_tensor_use_get_layer(parallel_tensor_use_t const &); +TensorSlotName parallel_tensor_use_get_slot(parallel_tensor_use_t const &); + +} // namespace FlexFlow 
+ +#endif diff --git a/lib/pcg/src/pcg/mapped_parallel_computation_graph/mapped_parallel_computation_graph.cc b/lib/pcg/src/pcg/mapped_parallel_computation_graph/mapped_parallel_computation_graph.cc index 571b89b6dd..3b996ccdab 100644 --- a/lib/pcg/src/pcg/mapped_parallel_computation_graph/mapped_parallel_computation_graph.cc +++ b/lib/pcg/src/pcg/mapped_parallel_computation_graph/mapped_parallel_computation_graph.cc @@ -54,22 +54,28 @@ parallel_layer_guid_t mpcg_get_source_layer(MappedParallelComputationGraph const return get_source_layer(pcg_from_mpcg(mpcg), t); } +PCGOperatorAttrs mpcg_get_pcg_op_attrs(MappedParallelComputationGraph const &mpcg, + parallel_layer_guid_t const &l) +{ + return pcg_get_op_attrs(pcg_from_mpcg(mpcg), l); +} + ParallelTensorAttrs mpcg_get_parallel_tensor_attrs(MappedParallelComputationGraph const &mpcg, - parallel_tensor_guid_t const &t) + parallel_tensor_guid_t const &t) { return get_parallel_tensor_attrs(pcg_from_mpcg(mpcg), t); } std::unordered_map mpcg_get_incoming_edges(MappedParallelComputationGraph const &mpcg, - parallel_layer_guid_t const &l) + parallel_layer_guid_t const &l) { return get_incoming_edges(pcg_from_mpcg(mpcg), l); } std::unordered_set mpcg_get_outgoing_edges(MappedParallelComputationGraph const &mpcg, - parallel_layer_guid_t const &l) + parallel_layer_guid_t const &l) { return get_outgoing_edges(pcg_from_mpcg(mpcg), l); } @@ -84,11 +90,24 @@ ManyToOne bidict mpcg_get_outgoing_tensors(MappedParallelComputationGraph const &mpcg, - parallel_layer_guid_t const &l) + parallel_layer_guid_t const &l) { return bidict_from_map(get_outgoing_tensors(pcg_from_mpcg(mpcg), l)); } +std::unordered_set + mpcg_get_edges(MappedParallelComputationGraph const &mpcg) +{ + return get_edges(pcg_from_mpcg(mpcg)); +} + +std::unordered_set + mpcg_get_parallel_tensor_uses(MappedParallelComputationGraph const &mpcg, + parallel_tensor_guid_t const &t) +{ + return pcg_get_parallel_tensor_uses(pcg_from_mpcg(mpcg), t); +} + 
MappedParallelComputationGraph mapped_pcg_from_pcg_and_mapped_op_task_groups( ParallelComputationGraph const &pcg, std::unordered_map const diff --git a/lib/pcg/src/pcg/parallel_computation_graph/parallel_tensor_use_t.cc b/lib/pcg/src/pcg/parallel_computation_graph/parallel_tensor_use_t.cc new file mode 100644 index 0000000000..e93341d312 --- /dev/null +++ b/lib/pcg/src/pcg/parallel_computation_graph/parallel_tensor_use_t.cc @@ -0,0 +1,13 @@ +#include "pcg/parallel_computation_graph/parallel_tensor_use_t.h" + +namespace FlexFlow { + +parallel_layer_guid_t parallel_tensor_use_get_layer(parallel_tensor_use_t const &u) { + return parallel_layer_guid_t{u.raw_dataflow_input.node}; +} + +TensorSlotName parallel_tensor_use_get_slot(parallel_tensor_use_t const &u) { + return u.raw_dataflow_input.slot_name; +} + +} // namespace FlexFlow diff --git a/lib/realm-execution/src/realm-execution/pcg_instance.cc b/lib/realm-execution/src/realm-execution/pcg_instance.cc index 17a6a383e6..f2edac7f88 100644 --- a/lib/realm-execution/src/realm-execution/pcg_instance.cc +++ b/lib/realm-execution/src/realm-execution/pcg_instance.cc @@ -216,7 +216,6 @@ static Realm::Event spawn_dynamic_node_invocation( precondition); }; - // issue_replicate_bwd lambda auto issue_replicate_bwd = [&]() { DynamicValueAttrs output_grad = get_only( diff --git a/lib/realm-execution/test/src/realm-execution/test_op_replicate.cc b/lib/realm-execution/test/src/realm-execution/test_op_replicate.cc index cae5ca1756..2523cae798 100644 --- a/lib/realm-execution/test/src/realm-execution/test_op_replicate.cc +++ b/lib/realm-execution/test/src/realm-execution/test_op_replicate.cc @@ -49,6 +49,190 @@ static bool did_loss_decrease(GenericTensorAccessorR const &first_epoch, compare_tensor_accessors_le(last_epoch, first_epoch, allocator)); } +MappedParallelComputationGraph make_test_mpcg_for_device_type(DeviceType device_type) { + positive_int batch_size = 10_p; + positive_int data_dim = 16_p; + positive_int hidden_dim = 
32_p; + positive_int output_dim = 1_p; + + TensorShape output_tensor_shape = TensorShape{ + TensorDims{FFOrdered{batch_size, output_dim}}, DataType::FLOAT}; + + TensorShape label_tensor_shape = TensorShape{ + TensorDims{FFOrdered{batch_size, output_dim}}, DataType::FLOAT}; + + ParallelComputationGraph pcg = empty_parallel_computation_graph(); + + TensorShape input_tensor_shape = TensorShape{ + TensorDims{FFOrdered{batch_size, data_dim}}, DataType::FLOAT}; + + ParallelLayerAddedResult inputs_layer = + pcg_add_input_layer(pcg, input_tensor_shape); + parallel_tensor_guid_t t_input = + require_only_key(inputs_layer.outputs, TensorSlotName::OUTPUT); + + ParallelLayerAddedResult inputs_layer_2 = + pcg_add_input_layer(pcg, input_tensor_shape); + parallel_tensor_guid_t t_input_2 = + require_only_key(inputs_layer_2.outputs, TensorSlotName::OUTPUT); + + ElementBinaryAttrs add_attrs = ElementBinaryAttrs{ + OperatorType::EW_ADD, + DataType::FLOAT, + false, + false, + }; + + ParallelLayerAddedResult add_operator_1 = + add_parallel_layer(pcg, + make_layer_attrs(add_attrs), + { + { + TensorSlotName::LHS_INPUT, + t_input, + }, + { + TensorSlotName::RHS_INPUT, + t_input_2, + }, + }, + /*weights=*/{}); + + parallel_tensor_guid_t t_add_1 = + require_only_key(add_operator_1.outputs, TensorSlotName::OUTPUT); + + positive_int replicate_degree = 2_p; + ReplicateAttrs repl_attrs = ReplicateAttrs{replicate_degree}; + ParallelLayerAddedResult repl_operator_1 = + add_parallel_layer(pcg, + make_layer_attrs(repl_attrs), + { + { + TensorSlotName::INPUT, + t_add_1, + }, + }, + /*weight=*/{}); + + parallel_tensor_guid_t t_repl_1 = + require_only_key(repl_operator_1.outputs, TensorSlotName::OUTPUT); + + ParallelLayerAddedResult relu_operator_1 = + add_parallel_layer(pcg, + make_layer_attrs(make_relu_attrs()), + /*inputs=*/ + { + { + TensorSlotName::INPUT, + t_repl_1, + }, + }, + /*weights=*/{}); + + parallel_tensor_guid_t t_relu_1 = + require_only_key(relu_operator_1.outputs, 
TensorSlotName::OUTPUT); + + MachineSpaceCoordinate cpu0{0_n, 0_n, device_type}; + MachineSpaceCoordinate cpu1{0_n, 1_n, device_type}; + + ParallelTensorSpaceCoordinate tensor_coord0{ + /*sum_component=*/0_n, + /*discard_copy_component=*/0_n, + /*shard_component=*/FFOrdered{0_n}}; + ParallelTensorSpaceCoordinate tensor_coord1{ + /*sum_component=*/0_n, + /*discard_copy_component=*/1_n, + /*shard_component=*/FFOrdered{0_n}}; + + MappedParallelComputationGraph mpcg = mapped_pcg_from_pcg_and_mapped_op_task_groups( + /*pcg=*/pcg, + /*mapped_op_task_groups=*/{ + { + inputs_layer.parallel_layer, + MappedOperatorTaskGroup{ + { + { + cpu0, + OperatorAtomicTaskShardBinding{{ + {TensorSlotName::OUTPUT, tensor_coord0}, + }}, + }, + }, + }, + }, + { + inputs_layer_2.parallel_layer, + MappedOperatorTaskGroup{ + { + { + cpu0, + OperatorAtomicTaskShardBinding{{ + {TensorSlotName::OUTPUT, tensor_coord0}, + }}, + }, + }, + }, + }, + { + add_operator_1.parallel_layer, + MappedOperatorTaskGroup{ + { + { + cpu0, + OperatorAtomicTaskShardBinding{{ + {TensorSlotName::LHS_INPUT, tensor_coord0}, + {TensorSlotName::RHS_INPUT, tensor_coord0}, + {TensorSlotName::OUTPUT, tensor_coord0}, + }}, + }, + }, + }, + }, + { + repl_operator_1.parallel_layer, + MappedOperatorTaskGroup{ + { + { + cpu0, + OperatorAtomicTaskShardBinding{{ + {TensorSlotName::OUTPUT, tensor_coord0}, + }}, + }, + { + cpu1, + OperatorAtomicTaskShardBinding{{ + {TensorSlotName::OUTPUT, tensor_coord1}, + }}, + }, + }, + }, + }, + { + relu_operator_1.parallel_layer, + MappedOperatorTaskGroup{ + { + { + cpu0, + OperatorAtomicTaskShardBinding{{ + {TensorSlotName::INPUT, tensor_coord0}, + {TensorSlotName::OUTPUT, tensor_coord0}, + }}, + }, + { + cpu1, + OperatorAtomicTaskShardBinding{{ + {TensorSlotName::INPUT, tensor_coord1}, + {TensorSlotName::OUTPUT, tensor_coord1}, + }}, + }, + }, + }, + }, + }); + + return mpcg; +} + TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("RealmBackend e2e Training Replicate Op (CPU Model Parallelism)") { 
std::vector fake_args = @@ -61,215 +245,12 @@ TEST_SUITE(FF_TEST_SUITE) { manager.start_controller([](RealmContext &ctx) { Allocator allocator = ctx.get_current_device_allocator(); - positive_int batch_size = 10_p; - positive_int data_dim = 16_p; - positive_int hidden_dim = 32_p; - positive_int output_dim = 1_p; - - // 10,2 - TensorShape output_tensor_shape = TensorShape{ - TensorDims{FFOrdered{batch_size, output_dim}}, DataType::FLOAT}; - - // 10,2 - TensorShape label_tensor_shape = TensorShape{ - TensorDims{FFOrdered{batch_size, output_dim}}, DataType::FLOAT}; - - GenericTensorAccessorW label_tensor = - allocator.allocate_tensor(label_tensor_shape); - - // construct computation graph - ParallelComputationGraph pcg = empty_parallel_computation_graph(); - - // input tensor - // 10, 16 - TensorShape input_tensor_shape = TensorShape{ - TensorDims{FFOrdered{batch_size, data_dim}}, DataType::FLOAT}; - - // parallel layer -> input tensor - ParallelLayerAddedResult inputs_layer = - pcg_add_input_layer(pcg, input_tensor_shape); - parallel_tensor_guid_t t_input = - require_only_key(inputs_layer.outputs, TensorSlotName::OUTPUT); - - // parallel layer -> input tensor 2 - ParallelLayerAddedResult inputs_layer_2 = - pcg_add_input_layer(pcg, input_tensor_shape); - parallel_tensor_guid_t t_input_2 = - require_only_key(inputs_layer_2.outputs, TensorSlotName::OUTPUT); - - // binary ADD attribute - ElementBinaryAttrs add_attrs = ElementBinaryAttrs{ - OperatorType::EW_ADD, - DataType::FLOAT, - false, - false, - }; - - // parallel layer -> perform add - ParallelLayerAddedResult add_operator_1 = - add_parallel_layer(pcg, - make_layer_attrs(add_attrs), - { - { - TensorSlotName::LHS_INPUT, - t_input, - }, - { - TensorSlotName::RHS_INPUT, - t_input_2, - }, - }, - {/* weight */}); - - parallel_tensor_guid_t t_add_1 = - require_only_key(add_operator_1.outputs, TensorSlotName::OUTPUT); - - // parallel layer -> perform replicate - const positive_int replicate_degree = 2_p; - ReplicateAttrs 
repl_attrs = ReplicateAttrs(replicate_degree); - ParallelLayerAddedResult repl_operator_1 = - add_parallel_layer(pcg, - make_layer_attrs(repl_attrs), - { - { - TensorSlotName::INPUT, - t_add_1, - }, - }, - /*weight=*/{}); - // output of replicate layer - parallel_tensor_guid_t t_repl_1 = - require_only_key(repl_operator_1.outputs, TensorSlotName::OUTPUT); - - // parallel layer -> perform RelU - ParallelLayerAddedResult relu_operator_1 = - add_parallel_layer(pcg, - make_layer_attrs(make_relu_attrs()), - /*inputs=*/ - { - { - TensorSlotName::INPUT, - t_repl_1, - }, - }, - /*weights=*/{}); - // output of relu layer - parallel_tensor_guid_t t_relu_1 = - require_only_key(relu_operator_1.outputs, TensorSlotName::OUTPUT); - - // machine - MachineSpaceCoordinate cpu0{0_n, 0_n, DeviceType::CPU}; - MachineSpaceCoordinate cpu1{0_n, 1_n, DeviceType::CPU}; - - ParallelTensorSpaceCoordinate tensor_coord0{ - /* sum_component */ 0_n, - /* discard_copy_component */ 0_n, - /*shard_component*/ FFOrdered{0_n}}; - ParallelTensorSpaceCoordinate tensor_coord1{ - /* sum_component */ 0_n, - /* discard_copy_component */ 1_n, - /*shard_component*/ FFOrdered{0_n}}; - MappedParallelComputationGraph mpcg = mapped_pcg_from_pcg_and_mapped_op_task_groups( - /*pcg=*/pcg, - /*mapped_op_task_groups=*/{ - { - inputs_layer.parallel_layer, - MappedOperatorTaskGroup{ - { - { - cpu0, - OperatorAtomicTaskShardBinding{{ - {TensorSlotName::OUTPUT, tensor_coord0}, - }}, - }, - }, - }, - }, - { - inputs_layer_2.parallel_layer, - MappedOperatorTaskGroup{ - { - { - cpu0, - OperatorAtomicTaskShardBinding{{ - {TensorSlotName::OUTPUT, tensor_coord0}, - }}, - }, - }, - }, - }, - { - add_operator_1.parallel_layer, - MappedOperatorTaskGroup{ - { - { - cpu0, - OperatorAtomicTaskShardBinding{{ - {TensorSlotName::LHS_INPUT, tensor_coord0}, - {TensorSlotName::RHS_INPUT, tensor_coord0}, - {TensorSlotName::OUTPUT, tensor_coord0}, - }}, - }, - }, - }, - }, - { - repl_operator_1.parallel_layer, - MappedOperatorTaskGroup{ - { 
- { - cpu0, - OperatorAtomicTaskShardBinding{{ - {TensorSlotName::OUTPUT, tensor_coord0}, - }}, - }, - { - cpu1, - OperatorAtomicTaskShardBinding{{ - {TensorSlotName::OUTPUT, tensor_coord1}, - }}, - }, - }, - }, - }, - { - relu_operator_1.parallel_layer, - MappedOperatorTaskGroup{ - { - { - cpu0, - OperatorAtomicTaskShardBinding{{ - {TensorSlotName::INPUT, tensor_coord0}, - {TensorSlotName::OUTPUT, tensor_coord0}, - }}, - }, - { - cpu1, - OperatorAtomicTaskShardBinding{{ - {TensorSlotName::INPUT, tensor_coord1}, - {TensorSlotName::OUTPUT, tensor_coord1}, - }}, - }, - }, - }, - }, - }); + MappedParallelComputationGraph mpcg = make_test_mpcg_for_device_type(DeviceType::CPU); - MappedOperatorTaskGroup loss_mapping{ - { - { - cpu0, - OperatorAtomicTaskShardBinding{{ - {TensorSlotName::INPUT, tensor_coord0}, - {TensorSlotName::LOGIT, tensor_coord0}, - }}, - }, - }, - }; - // instantiate computation graph - LossAttrs loss_attrs = LossAttrs{ - NonconfigurableLossAttrs{LossFunction::CATEGORICAL_CROSSENTROPY}}; + std::unordered_map + input_tensors; + OptimizerAttrs optimizer_attrs = OptimizerAttrs{ SGDOptimizerAttrs{ @@ -280,13 +261,11 @@ TEST_SUITE(FF_TEST_SUITE) { }, }; - std::unordered_map - input_tensors; - DistributedFfHandle device_handle = create_distributed_ff_handle( ctx, /*workSpaceSize=*/1024 * 1024, /*allowTensorOpMathConversion=*/true); + PCGInstance pcg_instance = create_pcg_instance( /*ctx=*/ctx, /*mpcg=*/mpcg, @@ -324,194 +303,8 @@ TEST_SUITE(FF_CUDA_TEST_SUITE) { manager.start_controller([](RealmContext &ctx) { Allocator allocator = ctx.get_current_device_allocator(); - positive_int batch_size = 10_p; - positive_int data_dim = 16_p; - positive_int hidden_dim = 32_p; - positive_int output_dim = 1_p; - - // 10,2 - TensorShape output_tensor_shape = TensorShape{ - TensorDims{FFOrdered{batch_size, output_dim}}, DataType::FLOAT}; - - // 10,2 - TensorShape label_tensor_shape = TensorShape{ - TensorDims{FFOrdered{batch_size, output_dim}}, DataType::FLOAT}; - - 
GenericTensorAccessorW label_tensor = - allocator.allocate_tensor(label_tensor_shape); - - // construct computation graph - ParallelComputationGraph pcg = empty_parallel_computation_graph(); - - // input tensor - // 10, 16 - TensorShape input_tensor_shape = TensorShape{ - TensorDims{FFOrdered{batch_size, data_dim}}, DataType::FLOAT}; - - // parallel layer -> input tensor - ParallelLayerAddedResult inputs_layer = - pcg_add_input_layer(pcg, input_tensor_shape); - parallel_tensor_guid_t t_input = - require_only_key(inputs_layer.outputs, TensorSlotName::OUTPUT); - - // parallel layer -> input tensor 2 - ParallelLayerAddedResult inputs_layer_2 = - pcg_add_input_layer(pcg, input_tensor_shape); - parallel_tensor_guid_t t_input_2 = - require_only_key(inputs_layer_2.outputs, TensorSlotName::OUTPUT); - - // binary ADD attribute - ElementBinaryAttrs add_attrs = ElementBinaryAttrs{ - OperatorType::EW_ADD, - DataType::FLOAT, - false, - false, - }; - - // parallel layer -> perform add - ParallelLayerAddedResult add_operator_1 = - add_parallel_layer(pcg, - make_layer_attrs(add_attrs), - { - { - TensorSlotName::LHS_INPUT, - t_input, - }, - { - TensorSlotName::RHS_INPUT, - t_input_2, - }, - }, - {/* weight */}); - - parallel_tensor_guid_t t_add_1 = - require_only_key(add_operator_1.outputs, TensorSlotName::OUTPUT); - - // parallel layer -> perform replicate - const positive_int replicate_degree = 2_p; - ReplicateAttrs repl_attrs = ReplicateAttrs(replicate_degree); - ParallelLayerAddedResult repl_operator_1 = - add_parallel_layer(pcg, - make_layer_attrs(repl_attrs), - { - { - TensorSlotName::INPUT, - t_add_1, - }, - }, - /*weight=*/{}); - // output of replicate layer - parallel_tensor_guid_t t_repl_1 = - require_only_key(repl_operator_1.outputs, TensorSlotName::OUTPUT); - - // parallel layer -> perform RelU - ParallelLayerAddedResult relu_operator_1 = - add_parallel_layer(pcg, - make_layer_attrs(make_relu_attrs()), - /*inputs=*/ - { - { - TensorSlotName::INPUT, - t_repl_1, - }, - }, 
- /*weights=*/{}); - // output of relu layer - parallel_tensor_guid_t t_relu_1 = - require_only_key(relu_operator_1.outputs, TensorSlotName::OUTPUT); - - // machine - MachineSpaceCoordinate gpu0{0_n, 0_n, DeviceType::GPU}; - MachineSpaceCoordinate gpu1{0_n, 1_n, DeviceType::GPU}; - ParallelTensorSpaceCoordinate tensor_coord0{0_n, 0_n, FFOrdered{0_n}}; - ParallelTensorSpaceCoordinate tensor_coord1{0_n, 1_n, FFOrdered{0_n}}; - MappedParallelComputationGraph mpcg = mapped_pcg_from_pcg_and_mapped_op_task_groups( - /*pcg=*/pcg, - /*mapped_op_task_groups=*/{ - { - inputs_layer.parallel_layer, - MappedOperatorTaskGroup{{ - { - gpu0, - OperatorAtomicTaskShardBinding{ - {{TensorSlotName::OUTPUT, tensor_coord0}}}, - }, - }}, - }, - { - inputs_layer_2.parallel_layer, - MappedOperatorTaskGroup{{ - { - gpu0, - OperatorAtomicTaskShardBinding{ - {{TensorSlotName::OUTPUT, tensor_coord0}}}, - }}, - }, - }, - { - add_operator_1.parallel_layer, - MappedOperatorTaskGroup{{ - { - gpu0, - OperatorAtomicTaskShardBinding{{ - {TensorSlotName::LHS_INPUT, tensor_coord0}, - {TensorSlotName::RHS_INPUT, tensor_coord0}, - {TensorSlotName::OUTPUT, tensor_coord0}, - }}, - }, - }}, - }, - { - repl_operator_1.parallel_layer, - MappedOperatorTaskGroup{{ - { - gpu0, - OperatorAtomicTaskShardBinding{{ - {TensorSlotName::OUTPUT, tensor_coord0}, - }}, - }, - { - gpu1, - OperatorAtomicTaskShardBinding{{ - {TensorSlotName::OUTPUT, tensor_coord1}, - }}, - }, - }}, - }, - { - relu_operator_1.parallel_layer, - MappedOperatorTaskGroup{{ - { - gpu0, - OperatorAtomicTaskShardBinding{{ - {TensorSlotName::INPUT, tensor_coord0}, - {TensorSlotName::OUTPUT, tensor_coord0}, - }}, - }, - { - gpu1, - OperatorAtomicTaskShardBinding{{ - {TensorSlotName::INPUT, tensor_coord1}, - {TensorSlotName::OUTPUT, tensor_coord1}, - }}, - }, - }}, - }, - }); - - MappedOperatorTaskGroup loss_mapping{{ - { - gpu0, - OperatorAtomicTaskShardBinding{{ - {TensorSlotName::INPUT, tensor_coord0}, - {TensorSlotName::LOGIT, tensor_coord0}, - }}, 
- }, - }}; + MappedParallelComputationGraph mpcg = make_test_mpcg_for_device_type(DeviceType::GPU); - // instantiate computation graph - LossAttrs loss_attrs = LossAttrs{ - NonconfigurableLossAttrs{LossFunction::CATEGORICAL_CROSSENTROPY}}; OptimizerAttrs optimizer_attrs = OptimizerAttrs{ SGDOptimizerAttrs{ diff --git a/lib/substitutions/src/substitutions/apply_substitution/output_expr_to_result_sub_pcg_mapping.cc b/lib/substitutions/src/substitutions/apply_substitution/output_expr_to_result_sub_pcg_mapping.cc index 2ad5b54a17..4374a951f8 100644 --- a/lib/substitutions/src/substitutions/apply_substitution/output_expr_to_result_sub_pcg_mapping.cc +++ b/lib/substitutions/src/substitutions/apply_substitution/output_expr_to_result_sub_pcg_mapping.cc @@ -2,7 +2,7 @@ #include "substitutions/output_graph/output_graph_expr.h" #include "substitutions/sub_parallel_computation_graph.h" #include "utils/bidict/algorithms/bidict_from_pairs.h" -#include "utils/bidict/algorithms/merge_disjoint_bidicts.h" +#include "utils/bidict/algorithms/binary_merge_disjoint_bidicts.h" #include "utils/containers/values.h" #include "utils/containers/zip_values_strict.h" @@ -26,7 +26,7 @@ bidict mapping_for_layer = bidict_from_pairs(values( zip_values_strict(layer_outputs, output_graph_expr_outputs))); - result = merge_disjoint_bidicts(result, mapping_for_layer); + result = binary_merge_disjoint_bidicts(result, mapping_for_layer); } return result; diff --git a/lib/substitutions/src/substitutions/pcg_pattern_match.cc b/lib/substitutions/src/substitutions/pcg_pattern_match.cc index 498fd6c1bf..dbd968d476 100644 --- a/lib/substitutions/src/substitutions/pcg_pattern_match.cc +++ b/lib/substitutions/src/substitutions/pcg_pattern_match.cc @@ -5,7 +5,7 @@ #include "utils/bidict/algorithms/bidict_from_keys_and_values.h" #include "utils/bidict/algorithms/bidict_from_map.h" #include "utils/bidict/algorithms/exhaustive_relational_join.h" -#include "utils/bidict/algorithms/merge_disjoint_bidicts.h" +#include 
"utils/bidict/algorithms/binary_merge_disjoint_bidicts.h" #include "utils/bidict/algorithms/transform_values.h" #include "utils/containers/is_subseteq_of.h" #include "utils/containers/map_values.h" @@ -34,7 +34,7 @@ bidict exhaustive_relational_join(pattern_node_outputs.reversed(), matched_layer_output_tensors); - result = merge_disjoint_bidicts(result, mapping); + result = binary_merge_disjoint_bidicts(result, mapping); } return result; diff --git a/lib/task-spec/include/task-spec/dynamic_graph/training_operation_attrs.h b/lib/task-spec/include/task-spec/dynamic_graph/training_operation_attrs.h new file mode 100644 index 0000000000..bb8ca4f840 --- /dev/null +++ b/lib/task-spec/include/task-spec/dynamic_graph/training_operation_attrs.h @@ -0,0 +1,13 @@ +#ifndef _FLEXFLOW_LIB_TASK_SPEC_INCLUDE_TASK_SPEC_DYNAMIC_GRAPH_TRAINING_OPERATION_ATTRS_H +#define _FLEXFLOW_LIB_TASK_SPEC_INCLUDE_TASK_SPEC_DYNAMIC_GRAPH_TRAINING_OPERATION_ATTRS_H + +#include "task-spec/dynamic_graph/training_operation_attrs.dtg.h" +#include "op-attrs/operator_type.dtg.h" + +namespace FlexFlow { + +bool training_op_attrs_has_op_type(TrainingOperationAttrs const &, OperatorType); + +} // namespace FlexFlow + +#endif diff --git a/lib/task-spec/src/task-spec/dynamic_graph/make_dynamic_open_dataflow_graph_from_mapped_pcg.cc b/lib/task-spec/src/task-spec/dynamic_graph/make_dynamic_open_dataflow_graph_from_mapped_pcg.cc index b23edc0411..664c615a90 100644 --- a/lib/task-spec/src/task-spec/dynamic_graph/make_dynamic_open_dataflow_graph_from_mapped_pcg.cc +++ b/lib/task-spec/src/task-spec/dynamic_graph/make_dynamic_open_dataflow_graph_from_mapped_pcg.cc @@ -13,15 +13,21 @@ #include #include #include +#include "pcg/parallel_computation_graph/parallel_tensor_use_t.h" +#include "utils/containers/require_only_key.h" +#include "utils/bidict/algorithms/merge_disjoint_bidicts.h" +#include "utils/containers/map_keys_and_values.h" +#include "utils/containers/transform_pairs.h" namespace FlexFlow { + static bidict 
get_input_mapping_for_replicate( MappedParallelComputationGraph const &mpcg, parallel_layer_guid_t const &replicate_layer) { - // get_incoming_edges returns map - // replicate has exactly one input + ASSERT(mpcg_get_pcg_op_attrs(mpcg, replicate_layer).is_parallel_replicate()); + auto [input_slot_name, input_edge] = get_only(mpcg_get_incoming_edges(mpcg, replicate_layer)); @@ -33,44 +39,32 @@ static bidict /*slot_name=*/producer_slot); } -static std::unordered_map - get_consumers_of_tensor(MappedParallelComputationGraph const &mpcg, - parallel_tensor_guid_t const &tensor) { - parallel_layer_guid_t producer_layer = mpcg_get_source_layer(mpcg, tensor); - - std::unordered_map result; - // get_outgoing_edges returns unordered_set - for (ParallelComputationGraphEdge const &edge : mpcg_get_outgoing_edges(mpcg, producer_layer)) { - if (get_parallel_tensor(edge) == tensor) { - result.insert( - std::pair{get_dst_layer(edge), get_dst_layer_input_slot_name(edge)}); - } - } - return result; -} - static bidict build_replicated_output_mapping( MappedParallelComputationGraph const &mpcg, - parallel_layer_guid_t const &replicate_layer) { + parallel_tensor_guid_t const &output_tensor_guid) { - auto [output_slot_name, output_tensor_guid] = - get_only(mpcg_get_outgoing_tensors(mpcg, replicate_layer)); - - auto consumers = get_consumers_of_tensor(mpcg, output_tensor_guid); + std::unordered_set consumers = mpcg_get_parallel_tensor_uses(mpcg, output_tensor_guid); ASSERT(!consumers.empty()); // union all consumer bindings — each consumer shard maps to a distinct // (discard_copy, machine) pair since replicas are always on different machines - bidict result; - for (auto const &[consumer_layer, slot_name] : consumers) { - MappedOperatorTaskGroup consumer_mapping = mpcg_get_mapping_for_layer(mpcg, consumer_layer); - bidict binding = - get_tensor_bindings_for_slot_name(consumer_mapping, slot_name); - for (auto const &[p, m] : binding) { - result.equate(p, m); - } - } + bidict result = + 
merge_disjoint_bidicts( + transform(consumers, + [&](parallel_tensor_use_t const &use) + -> bidict + { + parallel_layer_guid_t consumer_layer = parallel_tensor_use_get_layer(use); + TensorSlotName slot_name = parallel_tensor_use_get_slot(use); + + MappedOperatorTaskGroup consumer_mapping = mpcg_get_mapping_for_layer(mpcg, consumer_layer); + bidict binding = + get_tensor_bindings_for_slot_name(consumer_mapping, slot_name); + + return binding; + })); + return result; } @@ -78,14 +72,19 @@ static DynamicNodeInvocation build_replicate_invocation(parallel_layer_guid_t const &layer, ReplicateAttrs const &attrs, MappedParallelComputationGraph const &mpcg) { - auto [input_slot_name, input_tensor_guid] = - get_only(mpcg_get_incoming_tensors(mpcg, layer).l_to_r()); - - auto incoming = mpcg_get_incoming_tensors(mpcg, layer); - ASSERT(!incoming.empty(), "Replicate layer has no incoming tensors."); + ManyToOne incoming = mpcg_get_incoming_tensors(mpcg, layer); + TensorSlotName input_slot_name = TensorSlotName::INPUT; + parallel_tensor_guid_t input_tensor_guid = require_only_key(incoming.l_to_r(), input_slot_name); ParallelTensorAttrs input_attrs = mpcg_get_parallel_tensor_attrs(mpcg, input_tensor_guid); + + bidict outgoing = mpcg_get_outgoing_tensors(mpcg, layer); + TensorSlotName output_slot_name = TensorSlotName::OUTPUT; + parallel_tensor_guid_t output_tensor_guid = require_only_key(outgoing.l_to_r(), output_slot_name); + ParallelTensorAttrs output_attrs = + mpcg_get_parallel_tensor_attrs(mpcg, output_tensor_guid); + bidict input_mapping = get_input_mapping_for_replicate(mpcg, layer); @@ -93,24 +92,20 @@ static DynamicNodeInvocation /*tensor_guid=*/dynamic_tensor_guid_t{input_tensor_guid}, /*parallel_tensor_shape=*/input_attrs.shape, /*shard_coord=*/std::nullopt, - /*mapping=*/get_input_mapping_for_replicate(mpcg, layer), + /*mapping=*/input_mapping, /*accessor=*/std::nullopt, /*role=*/std::nullopt, }; - auto [output_slot_name, output_tensor_guid] = - 
get_only(mpcg_get_outgoing_tensors(mpcg, layer)); - ParallelTensorAttrs output_attrs = - mpcg_get_parallel_tensor_attrs(mpcg, output_tensor_guid); - DynamicValueAttrs output_value{ /*tensor_guid=*/dynamic_tensor_guid_t{output_tensor_guid}, /*parallel_tensor_shape=*/output_attrs.shape, /*shard_coord=*/std::nullopt, - /*mapping=*/build_replicated_output_mapping(mpcg, layer), + /*mapping=*/build_replicated_output_mapping(mpcg, output_tensor_guid), /*accessor=*/std::nullopt, /*role=*/std::nullopt, }; + DynamicNodeAttrs node_attrs{ /*task_type=*/std::nullopt, /*device_coord=*/std::nullopt, @@ -122,85 +117,92 @@ static DynamicNodeInvocation DynamicNodeInvocation invocation_node{ /*inputs=*/{ - {DynamicTensorSlot{input_slot_name, std::nullopt}, input_value}}, + { + DynamicTensorSlot{input_slot_name, std::nullopt}, + input_value, + }, + }, /*node_attrs=*/node_attrs, - /*outputs=*/ - {{DynamicTensorSlot{output_slot_name, std::nullopt}, output_value}}, + /*outputs=*/{ + { + DynamicTensorSlot{output_slot_name, std::nullopt}, + output_value, + }, + }, }; + return invocation_node; } DynamicOpenDataflowGraph make_dynamic_open_dataflow_graph_from_mapped_pcg( MappedParallelComputationGraph const &mpcg) { - DynamicOpenDataflowGraph result = make_empty_dynamic_open_dataflow_graph(); ParallelComputationGraph pcg = pcg_from_mpcg(mpcg); - for (auto const &[layer, attrs] : get_parallel_layer_attrs_mapping(pcg)) { - if (attrs.op_attrs.has()) { + auto mk_invocation = [&](parallel_layer_guid_t layer, ParallelLayerAttrs const &attrs) + -> DynamicNodeInvocation + { + if (attrs.op_attrs.is_parallel_replicate()) { // build replicate invocation DynamicNodeInvocation repl_inv = build_replicate_invocation( - layer, attrs.op_attrs.get(), mpcg); - result.invocations.emplace(repl_inv); - continue; - } - - DynamicNodeAttrs result_attrs{ - /*task_type=*/std::nullopt, - /*device_coord=*/std::nullopt, - /*mapping=*/mpcg_get_mapping_for_layer(mpcg, layer), - 
/*op_attrs=*/TrainingOperationAttrs{attrs.op_attrs}, - /*pcg_layer_guid=*/dynamic_layer_guid_t{layer}, - /*per_device_op_state=*/std::nullopt, + layer, attrs.op_attrs.require_parallel_replicate(), mpcg); + return repl_inv; + } else { + DynamicNodeAttrs result_attrs{ + /*task_type=*/std::nullopt, + /*device_coord=*/std::nullopt, + /*mapping=*/mpcg_get_mapping_for_layer(mpcg, layer), + /*op_attrs=*/TrainingOperationAttrs{attrs.op_attrs}, + /*pcg_layer_guid=*/dynamic_layer_guid_t{layer}, + /*per_device_op_state=*/std::nullopt, + }; + + auto mk_slot = [](TensorSlotName const &slot_name) -> DynamicTensorSlot { + return DynamicTensorSlot{ + /*slot_name=*/slot_name, + /*slot_tensor_role=*/std::nullopt, + }; + }; + + auto mk_value_attrs = [&](parallel_tensor_guid_t const &tensor) -> DynamicValueAttrs + { + ParallelTensorAttrs attrs = + get_parallel_tensor_attrs(pcg, tensor); + + return DynamicValueAttrs{ + /*tensor_guid=*/dynamic_tensor_guid_t{tensor}, + /*parallel_tensor_shape=*/attrs.shape, + /*shard_coord=*/std::nullopt, + /*mapping=*/std::nullopt, + /*accessor=*/std::nullopt, + /*role=*/std::nullopt, + }; + }; + + std::unordered_map result_inputs = + map_keys_and_values(get_incoming_tensors(pcg, layer), + mk_slot, + mk_value_attrs); + + std::unordered_map result_outputs = + map_keys_and_values(get_outgoing_tensors(pcg, layer), + mk_slot, + mk_value_attrs); + + DynamicNodeInvocation invocation = DynamicNodeInvocation{ + /*inputs=*/result_inputs, + /*node_attrs=*/result_attrs, + /*outputs=*/result_outputs, + }; + + return invocation; }; + }; - std::unordered_map result_inputs = - transform(get_incoming_tensors(pcg, layer), - [&](TensorSlotName const &slot_name, - parallel_tensor_guid_t const &tensor) { - ParallelTensorAttrs attrs = - get_parallel_tensor_attrs(pcg, tensor); - return std::pair{ - DynamicTensorSlot{ - /*slot_name=*/slot_name, - /*slot_tensor_role=*/std::nullopt, - }, - DynamicValueAttrs{ - /*tensor_guid=*/dynamic_tensor_guid_t{tensor}, - 
/*parallel_tensor_shape=*/attrs.shape, - /*shard_coord=*/std::nullopt, - /*mapping=*/std::nullopt, - /*accessor=*/std::nullopt, - /*role=*/std::nullopt, - }, - }; - }); - std::unordered_map result_outputs = - transform(get_outgoing_tensors(pcg, layer), - [&](TensorSlotName const &slot_name, - parallel_tensor_guid_t const &tensor) { - ParallelTensorAttrs attrs = - get_parallel_tensor_attrs(pcg, tensor); - return std::pair{ - DynamicTensorSlot{ - /*slot_name=*/slot_name, - /*slot_tensor_role=*/std::nullopt, - }, - DynamicValueAttrs{ - /*tensor_guid=*/dynamic_tensor_guid_t{tensor}, - /*parallel_tensor_shape=*/attrs.shape, - /*shard_coord=*/std::nullopt, - /*mapping=*/std::nullopt, - /*accessor=*/std::nullopt, - /*role=*/std::nullopt, - }, - }; - }); - - result.invocations.emplace(result_inputs, result_attrs, result_outputs); - } - - return result; + return dynamic_open_dataflow_graph_from_invocation_set( + transform_pairs( + unordered_set_of(get_parallel_layer_attrs_mapping(pcg)), + mk_invocation)); } } // namespace FlexFlow diff --git a/lib/task-spec/src/task-spec/dynamic_graph/pass_expansion.cc b/lib/task-spec/src/task-spec/dynamic_graph/pass_expansion.cc index faa1e186c3..25958b5cb7 100644 --- a/lib/task-spec/src/task-spec/dynamic_graph/pass_expansion.cc +++ b/lib/task-spec/src/task-spec/dynamic_graph/pass_expansion.cc @@ -5,6 +5,7 @@ #include "utils/containers/get_only.h" #include "utils/containers/merge_disjoint_maps.h" #include "utils/containers/transform.h" +#include "task-spec/dynamic_graph/training_operation_attrs.h" namespace FlexFlow { @@ -88,6 +89,8 @@ DynamicNodeInvocation perform_fwd_pass_expansion_for_invocation( DynamicNodeInvocation perform_bwd_pass_expansion_for_invocation( DynamicNodeInvocation const &invocation) { + TrainingOperationAttrs op_attrs = assert_unwrap(invocation.node_attrs.op_attrs); + auto to_fwd = [](DynamicTensorSlot const &k, DynamicValueAttrs const &v) { return std::pair{ pass_expand_slot(k, FwbTensorType::FORWARD), @@ -102,56 
+105,37 @@ DynamicNodeInvocation perform_bwd_pass_expansion_for_invocation( }; }; - return DynamicNodeInvocation{ - /*inputs=*/ - merge_disjoint_maps(std::vector{ - transform(invocation.inputs, to_fwd), - transform(invocation.outputs, to_fwd), - transform(invocation.outputs, to_grad), - }), - /*node_attrs=*/ - pass_expand_node(invocation.node_attrs, DynamicTaskType::BWD), - /*outputs=*/ - transform(invocation.inputs, to_grad), - }; -} -static std::unordered_set - perform_pass_expansion_for_replicate( - DynamicNodeInvocation const &invocation) { - - auto const &[input_slot, input] = get_only(invocation.inputs); - auto const &[output_slot, output] = get_only(invocation.outputs); - - // forward: INPUT/FWD → OUTPUT/FWD (copy to replicas) - DynamicNodeInvocation fwd{ - /*inputs=*/{{pass_expand_slot(input_slot, FwbTensorType::FORWARD), - pass_expand_value(input, FwbTensorType::FORWARD)}}, - /*node_attrs=*/ - pass_expand_node(invocation.node_attrs, DynamicTaskType::FWD), - /*outputs=*/ - {{pass_expand_slot(output_slot, FwbTensorType::FORWARD), - pass_expand_value(output, FwbTensorType::FORWARD)}}, - }; - - // backward: OUTPUT/FWD + OUTPUT/GRAD → INPUT/GRAD (reduce gradients) - // The backward node needs the mapping from the output (replicated) - // so it knows which replicas to reduce from - DynamicNodeAttrs bwd_node_attrs = invocation.node_attrs; - bwd_node_attrs.task_type = DynamicTaskType::BWD; - - DynamicNodeInvocation bwd{ - /*inputs=*/{ - {pass_expand_slot(output_slot, FwbTensorType::FORWARD), - pass_expand_value(output, FwbTensorType::FORWARD)}, - {pass_expand_slot(output_slot, FwbTensorType::GRADIENT), - pass_expand_value(output, FwbTensorType::GRADIENT)}, - }, - /*node_attrs=*/bwd_node_attrs, - /*outputs=*/ - {{pass_expand_slot(input_slot, FwbTensorType::GRADIENT), - pass_expand_value(input, FwbTensorType::GRADIENT)}}, + if (training_op_attrs_has_op_type(op_attrs, OperatorType::REPLICATE)) { + auto [input_slot, input] = get_only(invocation.inputs); + auto 
[output_slot, output] = get_only(invocation.outputs); + + DynamicNodeInvocation bwd{ + /*inputs=*/{ + to_fwd(output_slot, output), + to_grad(output_slot, output), + }, + /*node_attrs=*/ + pass_expand_node(invocation.node_attrs, DynamicTaskType::BWD), + /*outputs=*/{ + to_grad(input_slot, input), + }, + }; + + return bwd; + } else { + return DynamicNodeInvocation{ + /*inputs=*/ + merge_disjoint_maps(std::vector{ + transform(invocation.inputs, to_fwd), + transform(invocation.outputs, to_fwd), + transform(invocation.outputs, to_grad), + }), + /*node_attrs=*/ + pass_expand_node(invocation.node_attrs, DynamicTaskType::BWD), + /*outputs=*/ + transform(invocation.inputs, to_grad), + }; }; - return {fwd, bwd}; } DynamicOpenDataflowGraph @@ -161,9 +145,6 @@ DynamicOpenDataflowGraph DynamicOpenDataflowGraph result = flatmap_dynamic_invocation_set( g, [](DynamicNodeInvocation const &invocation) { - if (is_replicate_attrs(invocation.node_attrs)) { - return perform_pass_expansion_for_replicate(invocation); - } if (invocation.inputs.empty()) { return std::unordered_set{ perform_fwd_pass_expansion_for_invocation(invocation), diff --git a/lib/task-spec/src/task-spec/dynamic_graph/training_operation_attrs.cc b/lib/task-spec/src/task-spec/dynamic_graph/training_operation_attrs.cc new file mode 100644 index 0000000000..d1452242ca --- /dev/null +++ b/lib/task-spec/src/task-spec/dynamic_graph/training_operation_attrs.cc @@ -0,0 +1,21 @@ +#include "task-spec/dynamic_graph/training_operation_attrs.h" +#include "op-attrs/pcg_operator_attrs.h" +#include "utils/overload.h" + +namespace FlexFlow { + +bool training_op_attrs_has_op_type(TrainingOperationAttrs const &op_attrs, OperatorType op_type) { + return op_attrs.visit(overload { + [&](PCGOperatorAttrs const &pcg_op_attrs) -> bool { + return pcg_op_attrs_get_op_type(pcg_op_attrs) == op_type; + }, + [](LossAttrs const &) -> bool { + return false; + }, + [](CopyAttrs const &) -> bool { + return false; + }, + }); +} + +} // namespace FlexFlow 
diff --git a/lib/task-spec/test/src/task-spec/dynamic_graph/pass_expansion.cc b/lib/task-spec/test/src/task-spec/dynamic_graph/pass_expansion.cc index fb087f5295..ed22a8cbde 100644 --- a/lib/task-spec/test/src/task-spec/dynamic_graph/pass_expansion.cc +++ b/lib/task-spec/test/src/task-spec/dynamic_graph/pass_expansion.cc @@ -2,6 +2,7 @@ #include "task-spec/dynamic_graph/dynamic_open_dataflow_graph.h" #include "task-spec/dynamic_graph/dynamic_tensor_role.h" #include +#include "op-attrs/ops/element_unary.h" using namespace ::FlexFlow; @@ -36,6 +37,19 @@ TEST_SUITE(FF_TEST_SUITE) { dynamic_layer_guid_t layer_guid{parallel_layer_guid_t{Node{20}}}; + TrainingOperationAttrs op_attrs = + TrainingOperationAttrs{ + PCGOperatorAttrs{ + LinearAttrs{ + /*out_channels=*/8_p, + /*use_bias=*/true, + /*data_type=*/DataType::FLOAT, + /*activation=*/std::nullopt, + /*regularizer=*/std::nullopt, + }, + }, + }; + DynamicNodeInvocation invocation = [&]() -> DynamicNodeInvocation { DynamicValueAttrs v1 = mk_value_attrs(0, std::nullopt); DynamicValueAttrs v2 = mk_value_attrs(1, std::nullopt); @@ -46,14 +60,13 @@ TEST_SUITE(FF_TEST_SUITE) { {mk_slot(TensorSlotName::INPUT, std::nullopt), v1}, {mk_slot(TensorSlotName::WEIGHT, std::nullopt), v2}, {mk_slot(TensorSlotName::BIAS, std::nullopt), v1}, - {mk_slot(TensorSlotName::SCALE, std::nullopt), v1}, }, /*node_attrs=*/ DynamicNodeAttrs{ /*task_type=*/std::nullopt, /*device_coord=*/std::nullopt, /*mapping=*/std::nullopt, - /*op_attrs=*/std::nullopt, + /*op_attrs=*/op_attrs, /*layer_guid=*/layer_guid, /*per_device_op_state=*/std::nullopt, }, @@ -79,14 +92,13 @@ TEST_SUITE(FF_TEST_SUITE) { {mk_slot(TensorSlotName::INPUT, fwd_role), v1_fwd}, {mk_slot(TensorSlotName::WEIGHT, fwd_role), v2_fwd}, {mk_slot(TensorSlotName::BIAS, fwd_role), v1_fwd}, - {mk_slot(TensorSlotName::SCALE, fwd_role), v1_fwd}, }, /*node_attrs=*/ DynamicNodeAttrs{ /*task_type=*/DynamicTaskType::FWD, /*device_coord=*/std::nullopt, /*mapping=*/std::nullopt, - 
/*op_attrs=*/std::nullopt, + /*op_attrs=*/op_attrs, /*layer_guid=*/layer_guid, /*per_device_op_state=*/std::nullopt, }, @@ -130,88 +142,163 @@ TEST_SUITE(FF_TEST_SUITE) { dynamic_layer_guid_t layer_guid{parallel_layer_guid_t{Node{20}}}; - DynamicNodeInvocation invocation = [&]() -> DynamicNodeInvocation { - DynamicValueAttrs v1 = mk_value_attrs(0, std::nullopt); - DynamicValueAttrs v2 = mk_value_attrs(1, std::nullopt); - DynamicValueAttrs v3 = mk_value_attrs(2, std::nullopt); - - return DynamicNodeInvocation{ - /*inputs=*/{ - {mk_slot(TensorSlotName::INPUT, std::nullopt), v1}, - {mk_slot(TensorSlotName::WEIGHT, std::nullopt), v2}, - {mk_slot(TensorSlotName::BIAS, std::nullopt), v1}, - {mk_slot(TensorSlotName::SCALE, std::nullopt), v1}, - }, - /*node_attrs=*/ - DynamicNodeAttrs{ - /*task_type=*/std::nullopt, - /*device_coord=*/std::nullopt, - /*mapping=*/std::nullopt, - /*op_attrs=*/std::nullopt, - /*layer_guid=*/layer_guid, - /*per_device_op_state=*/std::nullopt, - }, - /*outputs=*/ - { - {mk_slot(TensorSlotName::OUTPUT, std::nullopt), v3}, - }, - }; - }(); - - DynamicNodeInvocation result = - perform_bwd_pass_expansion_for_invocation(invocation); - - DynamicNodeInvocation correct = [&]() -> DynamicNodeInvocation { - DynamicTensorRole fwd_role = DynamicTensorRole{FwbTensorType::FORWARD}; - DynamicTensorRole grad_role = DynamicTensorRole{FwbTensorType::GRADIENT}; - - DynamicValueAttrs v1_fwd = mk_value_attrs(0, fwd_role); - DynamicValueAttrs v2_fwd = mk_value_attrs(1, fwd_role); - DynamicValueAttrs v3_fwd = mk_value_attrs(2, fwd_role); - DynamicValueAttrs v1_grad = mk_value_attrs(0, grad_role); - DynamicValueAttrs v2_grad = mk_value_attrs(1, grad_role); - DynamicValueAttrs v3_grad = mk_value_attrs(2, grad_role); - - return DynamicNodeInvocation{ - /*inputs=*/{ - {mk_slot(TensorSlotName::INPUT, fwd_role), v1_fwd}, - {mk_slot(TensorSlotName::WEIGHT, fwd_role), v2_fwd}, - {mk_slot(TensorSlotName::BIAS, fwd_role), v1_fwd}, - {mk_slot(TensorSlotName::SCALE, fwd_role), 
v1_fwd}, - {mk_slot(TensorSlotName::OUTPUT, fwd_role), v3_fwd}, - {mk_slot(TensorSlotName::OUTPUT, grad_role), v3_grad}, - }, - /*node_attrs=*/ - DynamicNodeAttrs{ - /*pass_type=*/DynamicTaskType::BWD, - /*device_coord=*/std::nullopt, - /*mapping=*/std::nullopt, - /*op_attrs=*/std::nullopt, - /*layer_guid=*/layer_guid, - /*per_device_op_state=*/std::nullopt, + DynamicValueAttrs v1 = mk_value_attrs(0, std::nullopt); + DynamicValueAttrs v2 = mk_value_attrs(1, std::nullopt); + DynamicValueAttrs v3 = mk_value_attrs(2, std::nullopt); + + DynamicTensorRole fwd_role = DynamicTensorRole{FwbTensorType::FORWARD}; + DynamicTensorRole grad_role = DynamicTensorRole{FwbTensorType::GRADIENT}; + + DynamicValueAttrs v1_fwd = mk_value_attrs(0, fwd_role); + DynamicValueAttrs v2_fwd = mk_value_attrs(1, fwd_role); + DynamicValueAttrs v3_fwd = mk_value_attrs(2, fwd_role); + DynamicValueAttrs v1_grad = mk_value_attrs(0, grad_role); + DynamicValueAttrs v2_grad = mk_value_attrs(1, grad_role); + DynamicValueAttrs v3_grad = mk_value_attrs(2, grad_role); + + SUBCASE("normal operator") { + TrainingOperationAttrs op_attrs = + TrainingOperationAttrs{ + PCGOperatorAttrs{ + LinearAttrs{ + /*out_channels=*/8_p, + /*use_bias=*/true, + /*data_type=*/DataType::FLOAT, + /*activation=*/std::nullopt, + /*regularizer=*/std::nullopt, + }, }, - /*outputs=*/ - { - {mk_slot(TensorSlotName::INPUT, grad_role), v1_grad}, - {mk_slot(TensorSlotName::WEIGHT, grad_role), v2_grad}, - {mk_slot(TensorSlotName::BIAS, grad_role), v1_grad}, - {mk_slot(TensorSlotName::SCALE, grad_role), v1_grad}, + }; + + DynamicNodeInvocation invocation = [&]() -> DynamicNodeInvocation { + return DynamicNodeInvocation{ + /*inputs=*/{ + {mk_slot(TensorSlotName::INPUT, std::nullopt), v1}, + {mk_slot(TensorSlotName::WEIGHT, std::nullopt), v2}, + {mk_slot(TensorSlotName::BIAS, std::nullopt), v1}, + }, + /*node_attrs=*/ + DynamicNodeAttrs{ + /*task_type=*/std::nullopt, + /*device_coord=*/std::nullopt, + /*mapping=*/std::nullopt, + 
/*op_attrs=*/op_attrs, + /*layer_guid=*/layer_guid, + /*per_device_op_state=*/std::nullopt, + }, + /*outputs=*/ + { + {mk_slot(TensorSlotName::OUTPUT, std::nullopt), v3}, + }, + }; + }(); + + DynamicNodeInvocation result = + perform_bwd_pass_expansion_for_invocation(invocation); + + DynamicNodeInvocation correct = [&]() -> DynamicNodeInvocation { + return DynamicNodeInvocation{ + /*inputs=*/{ + {mk_slot(TensorSlotName::INPUT, fwd_role), v1_fwd}, + {mk_slot(TensorSlotName::WEIGHT, fwd_role), v2_fwd}, + {mk_slot(TensorSlotName::BIAS, fwd_role), v1_fwd}, + {mk_slot(TensorSlotName::OUTPUT, fwd_role), v3_fwd}, + {mk_slot(TensorSlotName::OUTPUT, grad_role), v3_grad}, + }, + /*node_attrs=*/ + DynamicNodeAttrs{ + /*pass_type=*/DynamicTaskType::BWD, + /*device_coord=*/std::nullopt, + /*mapping=*/std::nullopt, + /*op_attrs=*/op_attrs, + /*layer_guid=*/layer_guid, + /*per_device_op_state=*/std::nullopt, + }, + /*outputs=*/ + { + {mk_slot(TensorSlotName::INPUT, grad_role), v1_grad}, + {mk_slot(TensorSlotName::WEIGHT, grad_role), v2_grad}, + {mk_slot(TensorSlotName::BIAS, grad_role), v1_grad}, + }, + }; + }(); + + ASSERT(result == correct); + } + + SUBCASE("replicate operator optimization") { + TrainingOperationAttrs op_attrs = + TrainingOperationAttrs{ + PCGOperatorAttrs{ + ReplicateAttrs{ + /*replicate_degree=*/2_p, + }, }, - }; - }(); - - ASSERT(result == correct); + }; + + DynamicNodeInvocation invocation = [&]() -> DynamicNodeInvocation { + return DynamicNodeInvocation{ + /*inputs=*/{ + {mk_slot(TensorSlotName::INPUT, std::nullopt), v1}, + }, + /*node_attrs=*/ + DynamicNodeAttrs{ + /*task_type=*/std::nullopt, + /*device_coord=*/std::nullopt, + /*mapping=*/std::nullopt, + /*op_attrs=*/op_attrs, + /*layer_guid=*/layer_guid, + /*per_device_op_state=*/std::nullopt, + }, + /*outputs=*/ + { + {mk_slot(TensorSlotName::OUTPUT, std::nullopt), v2}, + }, + }; + }(); + + DynamicNodeInvocation result = + perform_bwd_pass_expansion_for_invocation(invocation); + + DynamicNodeInvocation 
correct = [&]() -> DynamicNodeInvocation { + DynamicTensorRole fwd_role = DynamicTensorRole{FwbTensorType::FORWARD}; + DynamicTensorRole grad_role = DynamicTensorRole{FwbTensorType::GRADIENT}; + + return DynamicNodeInvocation{ + /*inputs=*/{ + {mk_slot(TensorSlotName::OUTPUT, fwd_role), v2_fwd}, + {mk_slot(TensorSlotName::OUTPUT, grad_role), v2_grad}, + }, + /*node_attrs=*/ + DynamicNodeAttrs{ + /*pass_type=*/DynamicTaskType::BWD, + /*device_coord=*/std::nullopt, + /*mapping=*/std::nullopt, + /*op_attrs=*/op_attrs, + /*layer_guid=*/layer_guid, + /*per_device_op_state=*/std::nullopt, + }, + /*outputs=*/ + { + {mk_slot(TensorSlotName::INPUT, grad_role), v1_grad}, + }, + }; + }(); + + ASSERT(result == correct); + } } TEST_CASE("perform_pass_expansion(DynamicOpenDataflowGraph)") { auto mk_node_attrs = [](size_t layer_id, + TrainingOperationAttrs const &op_attrs, std::optional const &pass_type) -> DynamicNodeAttrs { return DynamicNodeAttrs{ /*pass_type=*/pass_type, /*device_coord=*/std::nullopt, /*mapping=*/std::nullopt, - /*op_attrs=*/std::nullopt, + /*op_attrs=*/op_attrs, /*layer_guid=*/ dynamic_layer_guid_t{parallel_layer_guid_t{Node{layer_id}}}, /*per_device_op_state=*/std::nullopt, @@ -236,9 +323,32 @@ TEST_SUITE(FF_TEST_SUITE) { }; }; + TrainingOperationAttrs input_op_attrs = TrainingOperationAttrs{ + PCGOperatorAttrs{ + InputAttrs{ + TensorShape{ + TensorDims{ + FFOrdered{ + 4_p, + 8_p, + }, + }, + DataType::FLOAT, + }, + }, + }, + }; + + TrainingOperationAttrs relu_op_attrs = TrainingOperationAttrs{ + PCGOperatorAttrs{ + make_relu_attrs(), + }, + }; + + DynamicOpenDataflowGraph input = [&]() -> DynamicOpenDataflowGraph { - DynamicNodeAttrs n1 = mk_node_attrs(10, std::nullopt); - DynamicNodeAttrs n2 = mk_node_attrs(11, std::nullopt); + DynamicNodeAttrs n1 = mk_node_attrs(10, input_op_attrs, std::nullopt); + DynamicNodeAttrs n2 = mk_node_attrs(11, relu_op_attrs, std::nullopt); DynamicValueAttrs v1 = mk_value_attrs(0, std::nullopt); DynamicValueAttrs v2 = 
mk_value_attrs(1, std::nullopt); @@ -286,10 +396,10 @@ TEST_SUITE(FF_TEST_SUITE) { DynamicOpenDataflowGraph result = perform_pass_expansion(input); DynamicOpenDataflowGraph correct = [&]() -> DynamicOpenDataflowGraph { - DynamicNodeAttrs n1_fwd = mk_node_attrs(10, DynamicTaskType::FWD); - DynamicNodeAttrs n2_fwd = mk_node_attrs(11, DynamicTaskType::FWD); - DynamicNodeAttrs n1_bwd = mk_node_attrs(10, DynamicTaskType::BWD); - DynamicNodeAttrs n2_bwd = mk_node_attrs(11, DynamicTaskType::BWD); + DynamicNodeAttrs n1_fwd = mk_node_attrs(10, input_op_attrs, DynamicTaskType::FWD); + DynamicNodeAttrs n2_fwd = mk_node_attrs(11, relu_op_attrs, DynamicTaskType::FWD); + DynamicNodeAttrs n1_bwd = mk_node_attrs(10, input_op_attrs, DynamicTaskType::BWD); + DynamicNodeAttrs n2_bwd = mk_node_attrs(11, relu_op_attrs, DynamicTaskType::BWD); DynamicValueAttrs v1_activation = mk_value_attrs(0, mk_dynamic_tensor_role_fwd()); diff --git a/lib/utils/include/utils/bidict/algorithms/binary_merge_disjoint_bidicts.h b/lib/utils/include/utils/bidict/algorithms/binary_merge_disjoint_bidicts.h new file mode 100644 index 0000000000..5b0bb45910 --- /dev/null +++ b/lib/utils/include/utils/bidict/algorithms/binary_merge_disjoint_bidicts.h @@ -0,0 +1,37 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_BIDICT_ALGORITHMS_BINARY_MERGE_DISJOINT_BIDICTS_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_BIDICT_ALGORITHMS_BINARY_MERGE_DISJOINT_BIDICTS_H + +#include "utils/bidict/algorithms/left_entries.h" +#include "utils/bidict/algorithms/right_entries.h" +#include "utils/bidict/bidict.h" +#include "utils/containers/are_disjoint.h" +#include "utils/exception.h" + +namespace FlexFlow { + +template +bidict binary_merge_disjoint_bidicts(bidict const &lhs, + bidict const &rhs) { + if (!are_disjoint(left_entries(lhs), left_entries(rhs))) { + throw mk_runtime_error( + fmt::format("Left entries of {} and {} are non-disjoint", lhs, rhs)); + } + if (!are_disjoint(right_entries(lhs), right_entries(rhs))) { + throw 
mk_runtime_error( + fmt::format("Right entries of {} and {} are non-disjoint", lhs, rhs)); + } + + bidict result; + for (auto const &kv : lhs) { + result.equate_strict(kv.first, kv.second); + } + for (auto const &kv : rhs) { + result.equate_strict(kv.first, kv.second); + } + + return result; +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/bidict/algorithms/merge_disjoint_bidicts.h b/lib/utils/include/utils/bidict/algorithms/merge_disjoint_bidicts.h index 97e7334c26..f2104fd113 100644 --- a/lib/utils/include/utils/bidict/algorithms/merge_disjoint_bidicts.h +++ b/lib/utils/include/utils/bidict/algorithms/merge_disjoint_bidicts.h @@ -1,35 +1,22 @@ #ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_BIDICT_ALGORITHMS_MERGE_DISJOINT_BIDICTS_H #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_BIDICT_ALGORITHMS_MERGE_DISJOINT_BIDICTS_H -#include "utils/bidict/algorithms/left_entries.h" -#include "utils/bidict/algorithms/right_entries.h" -#include "utils/bidict/bidict.h" -#include "utils/containers/are_disjoint.h" -#include "utils/exception.h" +#include "utils/containers/foldl.h" +#include "utils/bidict/algorithms/binary_merge_disjoint_bidicts.h" namespace FlexFlow { -template -bidict merge_disjoint_bidicts(bidict const &lhs, - bidict const &rhs) { - if (!are_disjoint(left_entries(lhs), left_entries(rhs))) { - throw mk_runtime_error( - fmt::format("Left entries of {} and {} are non-disjoint", lhs, rhs)); - } - if (!are_disjoint(right_entries(lhs), right_entries(rhs))) { - throw mk_runtime_error( - fmt::format("Right entries of {} and {} are non-disjoint", lhs, rhs)); - } - - bidict result; - for (auto const &kv : lhs) { - result.equate(kv.first, kv.second); - } - for (auto const &kv : rhs) { - result.equate(kv.first, kv.second); - } - - return result; +template +bidict merge_disjoint_bidicts(C const &c) { + bidict empty = {}; + return foldl(c, + /*init=*/empty, + [](bidict const &lhs, + bidict const &rhs) { + return binary_merge_disjoint_bidicts(lhs, rhs); + }); 
} } // namespace FlexFlow diff --git a/lib/utils/include/utils/bidict/bidict.h b/lib/utils/include/utils/bidict/bidict.h index 5dbd1c603d..2d8c5d23a8 100644 --- a/lib/utils/include/utils/bidict/bidict.h +++ b/lib/utils/include/utils/bidict/bidict.h @@ -213,6 +213,14 @@ struct bidict { return this->fwd_map; } + std::unordered_map const &l_to_r() const { + return this->fwd_map; + } + + std::unordered_map const &r_to_l() const { + return this->bwd_map; + } + bidict(std::unordered_map const &fwd_map, std::unordered_map const &bwd_map) : fwd_map(fwd_map), bwd_map(bwd_map) {} diff --git a/lib/utils/include/utils/containers/transform_pairs.h b/lib/utils/include/utils/containers/transform_pairs.h new file mode 100644 index 0000000000..c01b50554f --- /dev/null +++ b/lib/utils/include/utils/containers/transform_pairs.h @@ -0,0 +1,46 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_TRANSFORM_PAIRS_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_TRANSFORM_PAIRS_H + +#include "utils/containers/transform.h" + +namespace FlexFlow { + +template > +std::vector transform_pairs(std::vector> const &c, F &&f) { + auto ff = [&](std::pair const &p) -> Out { + return f(p.first, p.second); + }; + + return transform(c, ff); +} + +template > +std::unordered_set transform_pairs(std::unordered_set> const &c, F &&f) { + auto ff = [&](std::pair const &p) -> Out { + return f(p.first, p.second); + }; + + return transform(c, ff); +} + +template > +std::set transform_pairs(std::set> const &c, F &&f) { + auto ff = [&](std::pair const &p) -> Out { + return f(p.first, p.second); + }; + + return transform(c, ff); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/src/utils/bidict/algorithms/binary_merge_disjoint_bidicts.cc b/lib/utils/src/utils/bidict/algorithms/binary_merge_disjoint_bidicts.cc new file mode 100644 index 0000000000..8650de44f6 --- /dev/null +++ b/lib/utils/src/utils/bidict/algorithms/binary_merge_disjoint_bidicts.cc @@ -0,0 +1,12 @@ +#include 
"utils/bidict/algorithms/binary_merge_disjoint_bidicts.h" +#include "utils/archetypes/value_type.h" + +namespace FlexFlow { + +using K = value_type<0>; +using V = value_type<1>; + +template + bidict binary_merge_disjoint_bidicts(bidict const &, bidict const &); + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/bidict/algorithms/merge_disjoint_bidicts.cc b/lib/utils/src/utils/bidict/algorithms/merge_disjoint_bidicts.cc index 754b8d2e90..2c27821d3b 100644 --- a/lib/utils/src/utils/bidict/algorithms/merge_disjoint_bidicts.cc +++ b/lib/utils/src/utils/bidict/algorithms/merge_disjoint_bidicts.cc @@ -1 +1,11 @@ #include "utils/bidict/algorithms/merge_disjoint_bidicts.h" +#include "utils/archetypes/value_type.h" + +namespace FlexFlow { + +using K = value_type<0>; +using V = value_type<1>; + +template bidict merge_disjoint_bidicts(std::vector> const &); + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/containers/transform_pairs.cc b/lib/utils/src/utils/containers/transform_pairs.cc new file mode 100644 index 0000000000..241f1ad425 --- /dev/null +++ b/lib/utils/src/utils/containers/transform_pairs.cc @@ -0,0 +1,17 @@ +#include "utils/containers/transform_pairs.h" +#include "utils/archetypes/value_type.h" + +namespace FlexFlow { + +using L = value_type<0>; +using R = value_type<1>; +using Out = value_type<2>; +using F = std::function; + +template + std::vector transform_pairs(std::vector> const &, F &&); + +template + std::unordered_set transform_pairs(std::unordered_set> const &, F &&); + +} // namespace FlexFlow diff --git a/lib/utils/test/src/utils/bidict/algorithms/merge_disjoint_bidicts.cc b/lib/utils/test/src/utils/bidict/algorithms/binary_merge_disjoint_bidicts.cc similarity index 72% rename from lib/utils/test/src/utils/bidict/algorithms/merge_disjoint_bidicts.cc rename to lib/utils/test/src/utils/bidict/algorithms/binary_merge_disjoint_bidicts.cc index 0a1babd9f9..8a3371b8d8 100644 --- 
a/lib/utils/test/src/utils/bidict/algorithms/merge_disjoint_bidicts.cc +++ b/lib/utils/test/src/utils/bidict/algorithms/binary_merge_disjoint_bidicts.cc @@ -1,17 +1,17 @@ -#include "utils/bidict/algorithms/merge_disjoint_bidicts.h" +#include "utils/bidict/algorithms/binary_merge_disjoint_bidicts.h" #include using namespace ::FlexFlow; TEST_SUITE(FF_TEST_SUITE) { - TEST_CASE("merge_disjoint_bidicts") { + TEST_CASE("binary_merge_disjoint_bidicts") { SUBCASE("disjoint keys and values") { bidict bd1 = {{1, "one"}, {2, "two"}}; bidict bd2 = {{3, "three"}, {4, "four"}}; - bidict result = merge_disjoint_bidicts(bd1, bd2); + bidict result = binary_merge_disjoint_bidicts(bd1, bd2); bidict correct = { {1, "one"}, {2, "two"}, {3, "three"}, {4, "four"}}; @@ -22,21 +22,21 @@ TEST_SUITE(FF_TEST_SUITE) { bidict bd1 = {{1, "one"}, {2, "two"}}; bidict bd2 = {{2, "three"}, {3, "four"}}; - CHECK_THROWS(merge_disjoint_bidicts(bd1, bd2)); + CHECK_THROWS(binary_merge_disjoint_bidicts(bd1, bd2)); } SUBCASE("overlapping key, same associated value") { bidict bd1 = {{1, "one"}, {2, "two"}}; bidict bd2 = {{2, "two"}, {3, "three"}}; - CHECK_THROWS(merge_disjoint_bidicts(bd1, bd2)); + CHECK_THROWS(binary_merge_disjoint_bidicts(bd1, bd2)); } SUBCASE("overlapping values") { bidict bd1 = {{1, "one"}, {2, "two"}}; bidict bd2 = {{3, "two"}, {4, "four"}}; - CHECK_THROWS(merge_disjoint_bidicts(bd1, bd2)); + CHECK_THROWS(binary_merge_disjoint_bidicts(bd1, bd2)); } } } From fdf4fe5e74d4ede4cc21da923ee0aaedf5771351 Mon Sep 17 00:00:00 2001 From: Colin Unger Date: Fri, 15 May 2026 21:42:09 -0700 Subject: [PATCH 07/14] Remove unnecessary is_replicate_attrs function --- lib/task-spec/src/task-spec/dynamic_graph/pass_expansion.cc | 5 ----- 1 file changed, 5 deletions(-) diff --git a/lib/task-spec/src/task-spec/dynamic_graph/pass_expansion.cc b/lib/task-spec/src/task-spec/dynamic_graph/pass_expansion.cc index 25958b5cb7..f4960fe67a 100644 --- a/lib/task-spec/src/task-spec/dynamic_graph/pass_expansion.cc +++ 
b/lib/task-spec/src/task-spec/dynamic_graph/pass_expansion.cc @@ -31,11 +31,6 @@ bool graph_is_fully_pass_expanded(DynamicOpenDataflowGraph const &g) { g, node_is_pass_expanded, value_is_pass_expanded, slot_is_pass_expanded); } -static bool is_replicate_attrs(DynamicNodeAttrs const &n) { - return n.op_attrs.has_value() && n.op_attrs.value().has() && - n.op_attrs.value().get().has(); -} - DynamicTensorSlot pass_expand_slot(DynamicTensorSlot const &s, FwbTensorType tensor_type) { ASSERT(!slot_is_pass_expanded(s)); From dce15e23118a128c65d7c35f7e57da91011329f4 Mon Sep 17 00:00:00 2001 From: Elliott Slaughter Date: Thu, 21 May 2026 12:14:33 -0700 Subject: [PATCH 08/14] Format. --- .../test/src/op-attrs/ops/element_unary.cc | 5 +- .../mapped_parallel_computation_graph.h | 20 +-- .../parallel_computation_graph.h | 2 +- .../parallel_tensor_use_t.h | 5 +- .../mapped_parallel_computation_graph.cc | 43 +++--- .../parallel_computation_graph.cc | 9 +- .../parallel_tensor_use_t.cc | 3 +- .../sub_parallel_computation_graph.h | 2 +- .../apply_substitution/apply_substitution.cc | 7 +- .../src/substitutions/pcg_pattern_match.cc | 2 +- .../sub_parallel_computation_graph.cc | 2 +- .../dynamic_graph/training_operation_attrs.h | 5 +- ...mic_open_dataflow_graph_from_mapped_pcg.cc | 129 +++++++++--------- .../task-spec/dynamic_graph/pass_expansion.cc | 16 ++- .../dynamic_graph/training_operation_attrs.cc | 19 ++- .../task-spec/dynamic_graph/pass_expansion.cc | 95 ++++++------- .../algorithms/merge_disjoint_bidicts.h | 5 +- .../utils/containers/transform_pairs.h | 3 +- .../get_kwarg_dataflow_value_uses.h | 33 ++--- .../include/utils/many_to_one/many_to_one.h | 2 +- .../include/utils/one_to_many/one_to_many.h | 2 +- .../binary_merge_disjoint_bidicts.cc | 4 +- .../src/utils/containers/transform_pairs.cc | 8 +- .../get_kwarg_dataflow_value_uses.cc | 8 +- 24 files changed, 210 insertions(+), 219 deletions(-) diff --git a/lib/op-attrs/test/src/op-attrs/ops/element_unary.cc 
b/lib/op-attrs/test/src/op-attrs/ops/element_unary.cc index 8b2555610e..09e49a123c 100644 --- a/lib/op-attrs/test/src/op-attrs/ops/element_unary.cc +++ b/lib/op-attrs/test/src/op-attrs/ops/element_unary.cc @@ -56,8 +56,8 @@ TEST_SUITE(FF_TEST_SUITE) { SUBCASE("discard copy degree > 1") { positive_int degree = 2_p; - ParallelTensorShape par_input = make_input( - SumDegree{1_p}, DiscardCopyDegree{degree}, 1_p, 1_p, 1_p); + ParallelTensorShape par_input = + make_input(SumDegree{1_p}, DiscardCopyDegree{degree}, 1_p, 1_p, 1_p); tl::expected result = get_output_shape(attrs, par_input); @@ -74,6 +74,5 @@ TEST_SUITE(FF_TEST_SUITE) { make_input( SumDegree{degree}, DiscardCopyDegree{1_p}, 1_p, 1_p, 1_p))); } - } } diff --git a/lib/pcg/include/pcg/mapped_parallel_computation_graph/mapped_parallel_computation_graph.h b/lib/pcg/include/pcg/mapped_parallel_computation_graph/mapped_parallel_computation_graph.h index 6c24d4c1e1..a2afdb7914 100644 --- a/lib/pcg/include/pcg/mapped_parallel_computation_graph/mapped_parallel_computation_graph.h +++ b/lib/pcg/include/pcg/mapped_parallel_computation_graph/mapped_parallel_computation_graph.h @@ -15,22 +15,24 @@ MappedOperatorTaskGroup ParallelComputationGraph pcg_from_mpcg(MappedParallelComputationGraph const &); -parallel_layer_guid_t mpcg_get_source_layer(MappedParallelComputationGraph const &, - parallel_tensor_guid_t const &); +parallel_layer_guid_t + mpcg_get_source_layer(MappedParallelComputationGraph const &, + parallel_tensor_guid_t const &); PCGOperatorAttrs mpcg_get_pcg_op_attrs(MappedParallelComputationGraph const &, parallel_layer_guid_t const &); -ParallelTensorAttrs mpcg_get_parallel_tensor_attrs(MappedParallelComputationGraph const &, - parallel_tensor_guid_t const &); +ParallelTensorAttrs + mpcg_get_parallel_tensor_attrs(MappedParallelComputationGraph const &, + parallel_tensor_guid_t const &); std::unordered_map - mpcg_get_incoming_edges(MappedParallelComputationGraph const &, - parallel_layer_guid_t const &); + 
mpcg_get_incoming_edges(MappedParallelComputationGraph const &, + parallel_layer_guid_t const &); std::unordered_set - mpcg_get_outgoing_edges(MappedParallelComputationGraph const &, - parallel_layer_guid_t const &); + mpcg_get_outgoing_edges(MappedParallelComputationGraph const &, + parallel_layer_guid_t const &); ManyToOne mpcg_get_incoming_tensors(MappedParallelComputationGraph const &, @@ -38,7 +40,7 @@ ManyToOne bidict mpcg_get_outgoing_tensors(MappedParallelComputationGraph const &, - parallel_layer_guid_t const &); + parallel_layer_guid_t const &); std::unordered_set mpcg_get_edges(MappedParallelComputationGraph const &); diff --git a/lib/pcg/include/pcg/parallel_computation_graph/parallel_computation_graph.h b/lib/pcg/include/pcg/parallel_computation_graph/parallel_computation_graph.h index 1b2d5a0b67..9764e40627 100644 --- a/lib/pcg/include/pcg/parallel_computation_graph/parallel_computation_graph.h +++ b/lib/pcg/include/pcg/parallel_computation_graph/parallel_computation_graph.h @@ -10,8 +10,8 @@ #include "pcg/parallel_computation_graph/parallel_layer_added_result.dtg.h" #include "pcg/parallel_computation_graph/parallel_layer_guid_t.dtg.h" #include "pcg/parallel_computation_graph/parallel_tensor_guid_t.dtg.h" -#include #include "pcg/parallel_computation_graph/parallel_tensor_use_t.dtg.h" +#include namespace FlexFlow { diff --git a/lib/pcg/include/pcg/parallel_computation_graph/parallel_tensor_use_t.h b/lib/pcg/include/pcg/parallel_computation_graph/parallel_tensor_use_t.h index 88f1512149..f5e5575632 100644 --- a/lib/pcg/include/pcg/parallel_computation_graph/parallel_tensor_use_t.h +++ b/lib/pcg/include/pcg/parallel_computation_graph/parallel_tensor_use_t.h @@ -1,12 +1,13 @@ #ifndef _FLEXFLOW_LIB_PCG_INCLUDE_PCG_PARALLEL_COMPUTATION_GRAPH_PARALLEL_TENSOR_USE_T_H #define _FLEXFLOW_LIB_PCG_INCLUDE_PCG_PARALLEL_COMPUTATION_GRAPH_PARALLEL_TENSOR_USE_T_H -#include "pcg/parallel_computation_graph/parallel_tensor_use_t.dtg.h" #include 
"pcg/parallel_computation_graph/parallel_layer_guid_t.dtg.h" +#include "pcg/parallel_computation_graph/parallel_tensor_use_t.dtg.h" namespace FlexFlow { -parallel_layer_guid_t parallel_tensor_use_get_layer(parallel_tensor_use_t const &); +parallel_layer_guid_t + parallel_tensor_use_get_layer(parallel_tensor_use_t const &); TensorSlotName parallel_tensor_use_get_slot(parallel_tensor_use_t const &); } // namespace FlexFlow diff --git a/lib/pcg/src/pcg/mapped_parallel_computation_graph/mapped_parallel_computation_graph.cc b/lib/pcg/src/pcg/mapped_parallel_computation_graph/mapped_parallel_computation_graph.cc index 3b996ccdab..fc1dff504b 100644 --- a/lib/pcg/src/pcg/mapped_parallel_computation_graph/mapped_parallel_computation_graph.cc +++ b/lib/pcg/src/pcg/mapped_parallel_computation_graph/mapped_parallel_computation_graph.cc @@ -2,13 +2,13 @@ #include "op-attrs/pcg_operator_attrs.h" #include "pcg/mapped_parallel_computation_graph/mapped_parallel_layer_attrs.h" #include "pcg/parallel_computation_graph/parallel_computation_graph.h" +#include "utils/bidict/algorithms/bidict_from_map.h" #include "utils/bidict/algorithms/transform_keys.h" #include "utils/containers/transform.h" #include "utils/graph/kwarg_dataflow_graph/algorithms/find_isomorphism_between_kwarg_dataflow_graphs.h" #include "utils/graph/labelled_kwarg_dataflow_graph/algorithms/labelled_kwarg_dataflow_graph_view_as_dot.h" #include "utils/graph/labelled_kwarg_dataflow_graph/algorithms/materialize_labelled_kwarg_dataflow_graph_view.h" #include "utils/graph/labelled_kwarg_dataflow_graph/algorithms/rewrite_labelled_kwarg_dataflow_graph_node_labels.h" -#include "utils/bidict/algorithms/bidict_from_map.h" #include "utils/many_to_one/many_to_one_from_map.h" namespace FlexFlow { @@ -48,63 +48,56 @@ ParallelComputationGraph }; } -parallel_layer_guid_t mpcg_get_source_layer(MappedParallelComputationGraph const &mpcg, - parallel_tensor_guid_t const &t) -{ +parallel_layer_guid_t + 
mpcg_get_source_layer(MappedParallelComputationGraph const &mpcg, + parallel_tensor_guid_t const &t) { return get_source_layer(pcg_from_mpcg(mpcg), t); } -PCGOperatorAttrs mpcg_get_pcg_op_attrs(MappedParallelComputationGraph const &mpcg, - parallel_layer_guid_t const &l) -{ +PCGOperatorAttrs + mpcg_get_pcg_op_attrs(MappedParallelComputationGraph const &mpcg, + parallel_layer_guid_t const &l) { return pcg_get_op_attrs(pcg_from_mpcg(mpcg), l); } -ParallelTensorAttrs mpcg_get_parallel_tensor_attrs(MappedParallelComputationGraph const &mpcg, - parallel_tensor_guid_t const &t) -{ +ParallelTensorAttrs + mpcg_get_parallel_tensor_attrs(MappedParallelComputationGraph const &mpcg, + parallel_tensor_guid_t const &t) { return get_parallel_tensor_attrs(pcg_from_mpcg(mpcg), t); } std::unordered_map - mpcg_get_incoming_edges(MappedParallelComputationGraph const &mpcg, - parallel_layer_guid_t const &l) -{ + mpcg_get_incoming_edges(MappedParallelComputationGraph const &mpcg, + parallel_layer_guid_t const &l) { return get_incoming_edges(pcg_from_mpcg(mpcg), l); } std::unordered_set - mpcg_get_outgoing_edges(MappedParallelComputationGraph const &mpcg, - parallel_layer_guid_t const &l) -{ + mpcg_get_outgoing_edges(MappedParallelComputationGraph const &mpcg, + parallel_layer_guid_t const &l) { return get_outgoing_edges(pcg_from_mpcg(mpcg), l); } ManyToOne mpcg_get_incoming_tensors(MappedParallelComputationGraph const &mpcg, - parallel_layer_guid_t const &l) -{ + parallel_layer_guid_t const &l) { return many_to_one_from_map(get_incoming_tensors(pcg_from_mpcg(mpcg), l)); } - bidict mpcg_get_outgoing_tensors(MappedParallelComputationGraph const &mpcg, - parallel_layer_guid_t const &l) -{ + parallel_layer_guid_t const &l) { return bidict_from_map(get_outgoing_tensors(pcg_from_mpcg(mpcg), l)); } std::unordered_set - mpcg_get_edges(MappedParallelComputationGraph const &mpcg) -{ + mpcg_get_edges(MappedParallelComputationGraph const &mpcg) { return get_edges(pcg_from_mpcg(mpcg)); } 
std::unordered_set mpcg_get_parallel_tensor_uses(MappedParallelComputationGraph const &mpcg, - parallel_tensor_guid_t const &t) -{ + parallel_tensor_guid_t const &t) { return pcg_get_parallel_tensor_uses(pcg_from_mpcg(mpcg), t); } diff --git a/lib/pcg/src/pcg/parallel_computation_graph/parallel_computation_graph.cc b/lib/pcg/src/pcg/parallel_computation_graph/parallel_computation_graph.cc index 2c5197242d..5098cadafe 100644 --- a/lib/pcg/src/pcg/parallel_computation_graph/parallel_computation_graph.cc +++ b/lib/pcg/src/pcg/parallel_computation_graph/parallel_computation_graph.cc @@ -28,6 +28,7 @@ #include "utils/graph/kwarg_dataflow_graph/algorithms/find_isomorphism_between_kwarg_dataflow_graphs.h" #include "utils/graph/kwarg_dataflow_graph/algorithms/get_incoming_kwarg_dataflow_outputs_for_node.h" #include "utils/graph/kwarg_dataflow_graph/algorithms/get_kwarg_dataflow_edges_from_node_to_node.h" +#include "utils/graph/kwarg_dataflow_graph/algorithms/get_kwarg_dataflow_value_uses.h" #include "utils/graph/kwarg_dataflow_graph/algorithms/get_outgoing_kwarg_dataflow_edges_for_node.h" #include "utils/graph/kwarg_dataflow_graph/algorithms/get_outgoing_kwarg_dataflow_outputs_for_node.h" #include "utils/graph/labelled_kwarg_dataflow_graph/algorithms/labelled_kwarg_dataflow_graph_view_as_dot.h" @@ -36,7 +37,6 @@ #include "utils/graph/node/node.dtg.h" #include "utils/record_formatter.h" #include -#include "utils/graph/kwarg_dataflow_graph/algorithms/get_kwarg_dataflow_value_uses.h" namespace FlexFlow { @@ -209,18 +209,15 @@ std::unordered_map std::unordered_set pcg_get_parallel_tensor_uses(ParallelComputationGraph const &pcg, - parallel_tensor_guid_t const &t) -{ + parallel_tensor_guid_t const &t) { std::unordered_set> raw_uses = - get_kwarg_dataflow_value_uses(pcg.raw_graph, - t.raw_graph_output); + get_kwarg_dataflow_value_uses(pcg.raw_graph, t.raw_graph_output); return transform(raw_uses, [](KwargDataflowInput const &i) { return parallel_tensor_use_t{i}; }); } - 
std::unordered_set get_initial_layers(ParallelComputationGraph const &pcg) { std::unordered_set raw_sources = get_initial_nodes(pcg.raw_graph); diff --git a/lib/pcg/src/pcg/parallel_computation_graph/parallel_tensor_use_t.cc b/lib/pcg/src/pcg/parallel_computation_graph/parallel_tensor_use_t.cc index e93341d312..71a9cadf1c 100644 --- a/lib/pcg/src/pcg/parallel_computation_graph/parallel_tensor_use_t.cc +++ b/lib/pcg/src/pcg/parallel_computation_graph/parallel_tensor_use_t.cc @@ -2,7 +2,8 @@ namespace FlexFlow { -parallel_layer_guid_t parallel_tensor_use_get_layer(parallel_tensor_use_t const &u) { +parallel_layer_guid_t + parallel_tensor_use_get_layer(parallel_tensor_use_t const &u) { return parallel_layer_guid_t{u.raw_dataflow_input.node}; } diff --git a/lib/substitutions/include/substitutions/sub_parallel_computation_graph.h b/lib/substitutions/include/substitutions/sub_parallel_computation_graph.h index 26c98e915c..2a3dc8bbb8 100644 --- a/lib/substitutions/include/substitutions/sub_parallel_computation_graph.h +++ b/lib/substitutions/include/substitutions/sub_parallel_computation_graph.h @@ -49,7 +49,7 @@ std::unordered_set get_subgraph_outgoing_edges( std::unordered_set get_open_parallel_tensor_uses(SubParallelComputationGraph const &, - open_parallel_tensor_guid_t const &); + open_parallel_tensor_guid_t const &); SubParallelComputationGraphData get_sub_pcg_data(SubParallelComputationGraph const &); diff --git a/lib/substitutions/src/substitutions/apply_substitution/apply_substitution.cc b/lib/substitutions/src/substitutions/apply_substitution/apply_substitution.cc index a56555550f..f2686f7cf7 100644 --- a/lib/substitutions/src/substitutions/apply_substitution/apply_substitution.cc +++ b/lib/substitutions/src/substitutions/apply_substitution/apply_substitution.cc @@ -109,9 +109,10 @@ SubParallelComputationGraph apply_substitution_from_output_result( input_parallel_tensor_guid_t output_graph_input = output_expr_to_result_sub_pcg_mapping.input_mapping.at_r( 
output_expr_input); - std::unordered_set uses = get_open_parallel_tensor_uses( - substitution_output_graph, - open_parallel_tensor_guid_from_input(output_graph_input)); + std::unordered_set uses = + get_open_parallel_tensor_uses( + substitution_output_graph, + open_parallel_tensor_guid_from_input(output_graph_input)); for (parallel_tensor_use_t const &use : uses) { SubParallelComputationGraphEdge new_edge = subpcg_edge_from_tensor_and_use(base_graph_tensor, use); diff --git a/lib/substitutions/src/substitutions/pcg_pattern_match.cc b/lib/substitutions/src/substitutions/pcg_pattern_match.cc index dbd968d476..85a0493e33 100644 --- a/lib/substitutions/src/substitutions/pcg_pattern_match.cc +++ b/lib/substitutions/src/substitutions/pcg_pattern_match.cc @@ -4,8 +4,8 @@ #include "substitutions/unlabelled/unlabelled_graph_pattern.h" #include "utils/bidict/algorithms/bidict_from_keys_and_values.h" #include "utils/bidict/algorithms/bidict_from_map.h" -#include "utils/bidict/algorithms/exhaustive_relational_join.h" #include "utils/bidict/algorithms/binary_merge_disjoint_bidicts.h" +#include "utils/bidict/algorithms/exhaustive_relational_join.h" #include "utils/bidict/algorithms/transform_values.h" #include "utils/containers/is_subseteq_of.h" #include "utils/containers/map_values.h" diff --git a/lib/substitutions/src/substitutions/sub_parallel_computation_graph.cc b/lib/substitutions/src/substitutions/sub_parallel_computation_graph.cc index 990975bff9..c0c05ad5b1 100644 --- a/lib/substitutions/src/substitutions/sub_parallel_computation_graph.cc +++ b/lib/substitutions/src/substitutions/sub_parallel_computation_graph.cc @@ -132,7 +132,7 @@ std::unordered_set get_subgraph_incoming_edges( std::unordered_set get_open_parallel_tensor_uses(SubParallelComputationGraph const &spcg, - open_parallel_tensor_guid_t const &t) { + open_parallel_tensor_guid_t const &t) { std::unordered_set> raw_uses = get_open_kwarg_dataflow_value_uses(spcg.raw_graph, t.raw_open_dataflow_value); diff --git 
a/lib/task-spec/include/task-spec/dynamic_graph/training_operation_attrs.h b/lib/task-spec/include/task-spec/dynamic_graph/training_operation_attrs.h index bb8ca4f840..9caea8c341 100644 --- a/lib/task-spec/include/task-spec/dynamic_graph/training_operation_attrs.h +++ b/lib/task-spec/include/task-spec/dynamic_graph/training_operation_attrs.h @@ -1,12 +1,13 @@ #ifndef _FLEXFLOW_LIB_TASK_SPEC_INCLUDE_TASK_SPEC_DYNAMIC_GRAPH_TRAINING_OPERATION_ATTRS_H #define _FLEXFLOW_LIB_TASK_SPEC_INCLUDE_TASK_SPEC_DYNAMIC_GRAPH_TRAINING_OPERATION_ATTRS_H -#include "task-spec/dynamic_graph/training_operation_attrs.dtg.h" #include "op-attrs/operator_type.dtg.h" +#include "task-spec/dynamic_graph/training_operation_attrs.dtg.h" namespace FlexFlow { -bool training_op_attrs_has_op_type(TrainingOperationAttrs const &, OperatorType); +bool training_op_attrs_has_op_type(TrainingOperationAttrs const &, + OperatorType); } // namespace FlexFlow diff --git a/lib/task-spec/src/task-spec/dynamic_graph/make_dynamic_open_dataflow_graph_from_mapped_pcg.cc b/lib/task-spec/src/task-spec/dynamic_graph/make_dynamic_open_dataflow_graph_from_mapped_pcg.cc index 664c615a90..7a149787b9 100644 --- a/lib/task-spec/src/task-spec/dynamic_graph/make_dynamic_open_dataflow_graph_from_mapped_pcg.cc +++ b/lib/task-spec/src/task-spec/dynamic_graph/make_dynamic_open_dataflow_graph_from_mapped_pcg.cc @@ -5,19 +5,19 @@ #include "pcg/parallel_computation_graph/parallel_computation_graph.h" #include "pcg/parallel_computation_graph/parallel_computation_graph_edge.h" #include "pcg/parallel_computation_graph/parallel_tensor_attrs.dtg.h" +#include "pcg/parallel_computation_graph/parallel_tensor_use_t.h" #include "task-spec/dynamic_graph/dynamic_layer_guid_t.dtg.h" #include "task-spec/dynamic_graph/dynamic_open_dataflow_graph.h" #include "task-spec/dynamic_graph/dynamic_tensor_role.h" +#include "utils/bidict/algorithms/merge_disjoint_bidicts.h" #include "utils/containers/generate_map.h" #include "utils/containers/get_only.h" 
+#include "utils/containers/map_keys_and_values.h" +#include "utils/containers/require_only_key.h" +#include "utils/containers/transform_pairs.h" #include #include #include -#include "pcg/parallel_computation_graph/parallel_tensor_use_t.h" -#include "utils/containers/require_only_key.h" -#include "utils/bidict/algorithms/merge_disjoint_bidicts.h" -#include "utils/containers/map_keys_and_values.h" -#include "utils/containers/transform_pairs.h" namespace FlexFlow { @@ -35,8 +35,8 @@ static bidict TensorSlotName producer_slot = get_src_layer_output_slot_name(input_edge); return get_tensor_bindings_for_slot_name( - /*task_group=*/mpcg_get_mapping_for_layer(mpcg, producer_layer), - /*slot_name=*/producer_slot); + /*task_group=*/mpcg_get_mapping_for_layer(mpcg, producer_layer), + /*slot_name=*/producer_slot); } static bidict @@ -44,26 +44,29 @@ static bidict MappedParallelComputationGraph const &mpcg, parallel_tensor_guid_t const &output_tensor_guid) { - std::unordered_set consumers = mpcg_get_parallel_tensor_uses(mpcg, output_tensor_guid); + std::unordered_set consumers = + mpcg_get_parallel_tensor_uses(mpcg, output_tensor_guid); ASSERT(!consumers.empty()); // union all consumer bindings — each consumer shard maps to a distinct // (discard_copy, machine) pair since replicas are always on different machines bidict result = - merge_disjoint_bidicts( - transform(consumers, - [&](parallel_tensor_use_t const &use) - -> bidict - { - parallel_layer_guid_t consumer_layer = parallel_tensor_use_get_layer(use); - TensorSlotName slot_name = parallel_tensor_use_get_slot(use); - - MappedOperatorTaskGroup consumer_mapping = mpcg_get_mapping_for_layer(mpcg, consumer_layer); - bidict binding = - get_tensor_bindings_for_slot_name(consumer_mapping, slot_name); - - return binding; - })); + merge_disjoint_bidicts(transform( + consumers, + [&](parallel_tensor_use_t const &use) + -> bidict { + parallel_layer_guid_t consumer_layer = + parallel_tensor_use_get_layer(use); + TensorSlotName 
slot_name = parallel_tensor_use_get_slot(use); + + MappedOperatorTaskGroup consumer_mapping = + mpcg_get_mapping_for_layer(mpcg, consumer_layer); + bidict + binding = get_tensor_bindings_for_slot_name(consumer_mapping, + slot_name); + + return binding; + })); return result; } @@ -73,15 +76,19 @@ static DynamicNodeInvocation ReplicateAttrs const &attrs, MappedParallelComputationGraph const &mpcg) { - ManyToOne incoming = mpcg_get_incoming_tensors(mpcg, layer); + ManyToOne incoming = + mpcg_get_incoming_tensors(mpcg, layer); TensorSlotName input_slot_name = TensorSlotName::INPUT; - parallel_tensor_guid_t input_tensor_guid = require_only_key(incoming.l_to_r(), input_slot_name); + parallel_tensor_guid_t input_tensor_guid = + require_only_key(incoming.l_to_r(), input_slot_name); ParallelTensorAttrs input_attrs = mpcg_get_parallel_tensor_attrs(mpcg, input_tensor_guid); - bidict outgoing = mpcg_get_outgoing_tensors(mpcg, layer); + bidict outgoing = + mpcg_get_outgoing_tensors(mpcg, layer); TensorSlotName output_slot_name = TensorSlotName::OUTPUT; - parallel_tensor_guid_t output_tensor_guid = require_only_key(outgoing.l_to_r(), output_slot_name); + parallel_tensor_guid_t output_tensor_guid = + require_only_key(outgoing.l_to_r(), output_slot_name); ParallelTensorAttrs output_attrs = mpcg_get_parallel_tensor_attrs(mpcg, output_tensor_guid); @@ -117,17 +124,18 @@ static DynamicNodeInvocation DynamicNodeInvocation invocation_node{ /*inputs=*/{ - { - DynamicTensorSlot{input_slot_name, std::nullopt}, - input_value, - }, + { + DynamicTensorSlot{input_slot_name, std::nullopt}, + input_value, + }, }, /*node_attrs=*/node_attrs, - /*outputs=*/{ - { - DynamicTensorSlot{output_slot_name, std::nullopt}, - output_value, - }, + /*outputs=*/ + { + { + DynamicTensorSlot{output_slot_name, std::nullopt}, + output_value, + }, }, }; @@ -139,9 +147,9 @@ DynamicOpenDataflowGraph make_dynamic_open_dataflow_graph_from_mapped_pcg( ParallelComputationGraph pcg = pcg_from_mpcg(mpcg); - auto 
mk_invocation = [&](parallel_layer_guid_t layer, ParallelLayerAttrs const &attrs) - -> DynamicNodeInvocation - { + auto mk_invocation = + [&](parallel_layer_guid_t layer, + ParallelLayerAttrs const &attrs) -> DynamicNodeInvocation { if (attrs.op_attrs.is_parallel_replicate()) { // build replicate invocation DynamicNodeInvocation repl_inv = build_replicate_invocation( @@ -159,50 +167,45 @@ DynamicOpenDataflowGraph make_dynamic_open_dataflow_graph_from_mapped_pcg( auto mk_slot = [](TensorSlotName const &slot_name) -> DynamicTensorSlot { return DynamicTensorSlot{ - /*slot_name=*/slot_name, - /*slot_tensor_role=*/std::nullopt, + /*slot_name=*/slot_name, + /*slot_tensor_role=*/std::nullopt, }; }; - auto mk_value_attrs = [&](parallel_tensor_guid_t const &tensor) -> DynamicValueAttrs - { - ParallelTensorAttrs attrs = - get_parallel_tensor_attrs(pcg, tensor); + auto mk_value_attrs = + [&](parallel_tensor_guid_t const &tensor) -> DynamicValueAttrs { + ParallelTensorAttrs attrs = get_parallel_tensor_attrs(pcg, tensor); return DynamicValueAttrs{ - /*tensor_guid=*/dynamic_tensor_guid_t{tensor}, - /*parallel_tensor_shape=*/attrs.shape, - /*shard_coord=*/std::nullopt, - /*mapping=*/std::nullopt, - /*accessor=*/std::nullopt, - /*role=*/std::nullopt, + /*tensor_guid=*/dynamic_tensor_guid_t{tensor}, + /*parallel_tensor_shape=*/attrs.shape, + /*shard_coord=*/std::nullopt, + /*mapping=*/std::nullopt, + /*accessor=*/std::nullopt, + /*role=*/std::nullopt, }; }; std::unordered_map result_inputs = - map_keys_and_values(get_incoming_tensors(pcg, layer), - mk_slot, - mk_value_attrs); + map_keys_and_values( + get_incoming_tensors(pcg, layer), mk_slot, mk_value_attrs); std::unordered_map result_outputs = - map_keys_and_values(get_outgoing_tensors(pcg, layer), - mk_slot, - mk_value_attrs); + map_keys_and_values( + get_outgoing_tensors(pcg, layer), mk_slot, mk_value_attrs); DynamicNodeInvocation invocation = DynamicNodeInvocation{ - /*inputs=*/result_inputs, - /*node_attrs=*/result_attrs, - 
/*outputs=*/result_outputs, + /*inputs=*/result_inputs, + /*node_attrs=*/result_attrs, + /*outputs=*/result_outputs, }; return invocation; }; }; - return dynamic_open_dataflow_graph_from_invocation_set( - transform_pairs( - unordered_set_of(get_parallel_layer_attrs_mapping(pcg)), - mk_invocation)); + return dynamic_open_dataflow_graph_from_invocation_set(transform_pairs( + unordered_set_of(get_parallel_layer_attrs_mapping(pcg)), mk_invocation)); } } // namespace FlexFlow diff --git a/lib/task-spec/src/task-spec/dynamic_graph/pass_expansion.cc b/lib/task-spec/src/task-spec/dynamic_graph/pass_expansion.cc index f4960fe67a..64fe2df0be 100644 --- a/lib/task-spec/src/task-spec/dynamic_graph/pass_expansion.cc +++ b/lib/task-spec/src/task-spec/dynamic_graph/pass_expansion.cc @@ -1,11 +1,11 @@ #include "task-spec/dynamic_graph/pass_expansion.h" #include "task-spec/dynamic_graph/dynamic_open_dataflow_graph.h" #include "task-spec/dynamic_graph/dynamic_tensor_role.h" +#include "task-spec/dynamic_graph/training_operation_attrs.h" #include "utils/containers/are_all_same.h" #include "utils/containers/get_only.h" #include "utils/containers/merge_disjoint_maps.h" #include "utils/containers/transform.h" -#include "task-spec/dynamic_graph/training_operation_attrs.h" namespace FlexFlow { @@ -84,7 +84,8 @@ DynamicNodeInvocation perform_fwd_pass_expansion_for_invocation( DynamicNodeInvocation perform_bwd_pass_expansion_for_invocation( DynamicNodeInvocation const &invocation) { - TrainingOperationAttrs op_attrs = assert_unwrap(invocation.node_attrs.op_attrs); + TrainingOperationAttrs op_attrs = + assert_unwrap(invocation.node_attrs.op_attrs); auto to_fwd = [](DynamicTensorSlot const &k, DynamicValueAttrs const &v) { return std::pair{ @@ -106,15 +107,16 @@ DynamicNodeInvocation perform_bwd_pass_expansion_for_invocation( DynamicNodeInvocation bwd{ /*inputs=*/{ - to_fwd(output_slot, output), - to_grad(output_slot, output), + to_fwd(output_slot, output), + to_grad(output_slot, output), }, 
/*node_attrs=*/ pass_expand_node(invocation.node_attrs, DynamicTaskType::BWD), - /*outputs=*/{ - to_grad(input_slot, input), + /*outputs=*/ + { + to_grad(input_slot, input), }, - }; + }; return bwd; } else { diff --git a/lib/task-spec/src/task-spec/dynamic_graph/training_operation_attrs.cc b/lib/task-spec/src/task-spec/dynamic_graph/training_operation_attrs.cc index d1452242ca..a9be225ff5 100644 --- a/lib/task-spec/src/task-spec/dynamic_graph/training_operation_attrs.cc +++ b/lib/task-spec/src/task-spec/dynamic_graph/training_operation_attrs.cc @@ -4,17 +4,14 @@ namespace FlexFlow { -bool training_op_attrs_has_op_type(TrainingOperationAttrs const &op_attrs, OperatorType op_type) { - return op_attrs.visit(overload { - [&](PCGOperatorAttrs const &pcg_op_attrs) -> bool { - return pcg_op_attrs_get_op_type(pcg_op_attrs) == op_type; - }, - [](LossAttrs const &) -> bool { - return false; - }, - [](CopyAttrs const &) -> bool { - return false; - }, +bool training_op_attrs_has_op_type(TrainingOperationAttrs const &op_attrs, + OperatorType op_type) { + return op_attrs.visit(overload{ + [&](PCGOperatorAttrs const &pcg_op_attrs) -> bool { + return pcg_op_attrs_get_op_type(pcg_op_attrs) == op_type; + }, + [](LossAttrs const &) -> bool { return false; }, + [](CopyAttrs const &) -> bool { return false; }, }); } diff --git a/lib/task-spec/test/src/task-spec/dynamic_graph/pass_expansion.cc b/lib/task-spec/test/src/task-spec/dynamic_graph/pass_expansion.cc index ed22a8cbde..bf88d5ec38 100644 --- a/lib/task-spec/test/src/task-spec/dynamic_graph/pass_expansion.cc +++ b/lib/task-spec/test/src/task-spec/dynamic_graph/pass_expansion.cc @@ -1,8 +1,8 @@ #include "task-spec/dynamic_graph/pass_expansion.h" +#include "op-attrs/ops/element_unary.h" #include "task-spec/dynamic_graph/dynamic_open_dataflow_graph.h" #include "task-spec/dynamic_graph/dynamic_tensor_role.h" #include -#include "op-attrs/ops/element_unary.h" using namespace ::FlexFlow; @@ -37,18 +37,17 @@ TEST_SUITE(FF_TEST_SUITE) { 
dynamic_layer_guid_t layer_guid{parallel_layer_guid_t{Node{20}}}; - TrainingOperationAttrs op_attrs = - TrainingOperationAttrs{ + TrainingOperationAttrs op_attrs = TrainingOperationAttrs{ PCGOperatorAttrs{ - LinearAttrs{ - /*out_channels=*/8_p, - /*use_bias=*/true, - /*data_type=*/DataType::FLOAT, - /*activation=*/std::nullopt, - /*regularizer=*/std::nullopt, - }, + LinearAttrs{ + /*out_channels=*/8_p, + /*use_bias=*/true, + /*data_type=*/DataType::FLOAT, + /*activation=*/std::nullopt, + /*regularizer=*/std::nullopt, + }, }, - }; + }; DynamicNodeInvocation invocation = [&]() -> DynamicNodeInvocation { DynamicValueAttrs v1 = mk_value_attrs(0, std::nullopt); @@ -157,18 +156,17 @@ TEST_SUITE(FF_TEST_SUITE) { DynamicValueAttrs v3_grad = mk_value_attrs(2, grad_role); SUBCASE("normal operator") { - TrainingOperationAttrs op_attrs = - TrainingOperationAttrs{ + TrainingOperationAttrs op_attrs = TrainingOperationAttrs{ PCGOperatorAttrs{ - LinearAttrs{ - /*out_channels=*/8_p, - /*use_bias=*/true, - /*data_type=*/DataType::FLOAT, - /*activation=*/std::nullopt, - /*regularizer=*/std::nullopt, - }, + LinearAttrs{ + /*out_channels=*/8_p, + /*use_bias=*/true, + /*data_type=*/DataType::FLOAT, + /*activation=*/std::nullopt, + /*regularizer=*/std::nullopt, + }, }, - }; + }; DynamicNodeInvocation invocation = [&]() -> DynamicNodeInvocation { return DynamicNodeInvocation{ @@ -227,14 +225,13 @@ TEST_SUITE(FF_TEST_SUITE) { } SUBCASE("replicate operator optimization") { - TrainingOperationAttrs op_attrs = - TrainingOperationAttrs{ + TrainingOperationAttrs op_attrs = TrainingOperationAttrs{ PCGOperatorAttrs{ - ReplicateAttrs{ - /*replicate_degree=*/2_p, - }, + ReplicateAttrs{ + /*replicate_degree=*/2_p, + }, }, - }; + }; DynamicNodeInvocation invocation = [&]() -> DynamicNodeInvocation { return DynamicNodeInvocation{ @@ -262,7 +259,8 @@ TEST_SUITE(FF_TEST_SUITE) { DynamicNodeInvocation correct = [&]() -> DynamicNodeInvocation { DynamicTensorRole fwd_role = 
DynamicTensorRole{FwbTensorType::FORWARD}; - DynamicTensorRole grad_role = DynamicTensorRole{FwbTensorType::GRADIENT}; + DynamicTensorRole grad_role = + DynamicTensorRole{FwbTensorType::GRADIENT}; return DynamicNodeInvocation{ /*inputs=*/{ @@ -324,28 +322,27 @@ TEST_SUITE(FF_TEST_SUITE) { }; TrainingOperationAttrs input_op_attrs = TrainingOperationAttrs{ - PCGOperatorAttrs{ - InputAttrs{ - TensorShape{ - TensorDims{ - FFOrdered{ - 4_p, - 8_p, - }, + PCGOperatorAttrs{ + InputAttrs{ + TensorShape{ + TensorDims{ + FFOrdered{ + 4_p, + 8_p, + }, + }, + DataType::FLOAT, + }, }, - DataType::FLOAT, - }, }, - }, }; TrainingOperationAttrs relu_op_attrs = TrainingOperationAttrs{ - PCGOperatorAttrs{ - make_relu_attrs(), - }, + PCGOperatorAttrs{ + make_relu_attrs(), + }, }; - DynamicOpenDataflowGraph input = [&]() -> DynamicOpenDataflowGraph { DynamicNodeAttrs n1 = mk_node_attrs(10, input_op_attrs, std::nullopt); DynamicNodeAttrs n2 = mk_node_attrs(11, relu_op_attrs, std::nullopt); @@ -396,10 +393,14 @@ TEST_SUITE(FF_TEST_SUITE) { DynamicOpenDataflowGraph result = perform_pass_expansion(input); DynamicOpenDataflowGraph correct = [&]() -> DynamicOpenDataflowGraph { - DynamicNodeAttrs n1_fwd = mk_node_attrs(10, input_op_attrs, DynamicTaskType::FWD); - DynamicNodeAttrs n2_fwd = mk_node_attrs(11, relu_op_attrs, DynamicTaskType::FWD); - DynamicNodeAttrs n1_bwd = mk_node_attrs(10, input_op_attrs, DynamicTaskType::BWD); - DynamicNodeAttrs n2_bwd = mk_node_attrs(11, relu_op_attrs, DynamicTaskType::BWD); + DynamicNodeAttrs n1_fwd = + mk_node_attrs(10, input_op_attrs, DynamicTaskType::FWD); + DynamicNodeAttrs n2_fwd = + mk_node_attrs(11, relu_op_attrs, DynamicTaskType::FWD); + DynamicNodeAttrs n1_bwd = + mk_node_attrs(10, input_op_attrs, DynamicTaskType::BWD); + DynamicNodeAttrs n2_bwd = + mk_node_attrs(11, relu_op_attrs, DynamicTaskType::BWD); DynamicValueAttrs v1_activation = mk_value_attrs(0, mk_dynamic_tensor_role_fwd()); diff --git 
a/lib/utils/include/utils/bidict/algorithms/merge_disjoint_bidicts.h b/lib/utils/include/utils/bidict/algorithms/merge_disjoint_bidicts.h index f2104fd113..0c944bb9bd 100644 --- a/lib/utils/include/utils/bidict/algorithms/merge_disjoint_bidicts.h +++ b/lib/utils/include/utils/bidict/algorithms/merge_disjoint_bidicts.h @@ -1,8 +1,8 @@ #ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_BIDICT_ALGORITHMS_MERGE_DISJOINT_BIDICTS_H #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_BIDICT_ALGORITHMS_MERGE_DISJOINT_BIDICTS_H -#include "utils/containers/foldl.h" #include "utils/bidict/algorithms/binary_merge_disjoint_bidicts.h" +#include "utils/containers/foldl.h" namespace FlexFlow { @@ -13,8 +13,7 @@ bidict merge_disjoint_bidicts(C const &c) { bidict empty = {}; return foldl(c, /*init=*/empty, - [](bidict const &lhs, - bidict const &rhs) { + [](bidict const &lhs, bidict const &rhs) { return binary_merge_disjoint_bidicts(lhs, rhs); }); } diff --git a/lib/utils/include/utils/containers/transform_pairs.h b/lib/utils/include/utils/containers/transform_pairs.h index c01b50554f..3e421ea445 100644 --- a/lib/utils/include/utils/containers/transform_pairs.h +++ b/lib/utils/include/utils/containers/transform_pairs.h @@ -21,7 +21,8 @@ template > -std::unordered_set transform_pairs(std::unordered_set> const &c, F &&f) { +std::unordered_set + transform_pairs(std::unordered_set> const &c, F &&f) { auto ff = [&](std::pair const &p) -> Out { return f(p.first, p.second); }; diff --git a/lib/utils/include/utils/graph/kwarg_dataflow_graph/algorithms/get_kwarg_dataflow_value_uses.h b/lib/utils/include/utils/graph/kwarg_dataflow_graph/algorithms/get_kwarg_dataflow_value_uses.h index b5557e9e49..52c225d157 100644 --- a/lib/utils/include/utils/graph/kwarg_dataflow_graph/algorithms/get_kwarg_dataflow_value_uses.h +++ b/lib/utils/include/utils/graph/kwarg_dataflow_graph/algorithms/get_kwarg_dataflow_value_uses.h @@ -7,25 +7,20 @@ namespace FlexFlow { template std::unordered_set> - get_kwarg_dataflow_value_uses( - 
KwargDataflowGraphView const &g, - KwargDataflowOutput const &v) { - - KwargDataflowEdgeQuery query = - KwargDataflowEdgeQuery{ - /*src_nodes=*/query_set::match_single_value(v.node), - /*src_slots=*/query_set::match_single_value(v.slot_name), - /*dst_nodes=*/query_set::matchall(), - /*dst_slots=*/query_set::matchall(), - }; - - std::unordered_set> edges = - g.query_edges(query); - - return transform( - edges, [&](KwargDataflowEdge const &e) { - return e.dst; - }); + get_kwarg_dataflow_value_uses(KwargDataflowGraphView const &g, + KwargDataflowOutput const &v) { + + KwargDataflowEdgeQuery query = KwargDataflowEdgeQuery{ + /*src_nodes=*/query_set::match_single_value(v.node), + /*src_slots=*/query_set::match_single_value(v.slot_name), + /*dst_nodes=*/query_set::matchall(), + /*dst_slots=*/query_set::matchall(), + }; + + std::unordered_set> edges = g.query_edges(query); + + return transform(edges, + [&](KwargDataflowEdge const &e) { return e.dst; }); } } // namespace FlexFlow diff --git a/lib/utils/include/utils/many_to_one/many_to_one.h b/lib/utils/include/utils/many_to_one/many_to_one.h index c73f696172..2d078eb304 100644 --- a/lib/utils/include/utils/many_to_one/many_to_one.h +++ b/lib/utils/include/utils/many_to_one/many_to_one.h @@ -2,6 +2,7 @@ #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_MANY_TO_ONE_MANY_TO_ONE_H #include "utils/containers/keys.h" +#include "utils/containers/require_same.h" #include "utils/containers/try_at.h" #include "utils/containers/unordered_set_of.h" #include "utils/containers/values.h" @@ -19,7 +20,6 @@ #include #include #include -#include "utils/containers/require_same.h" namespace FlexFlow { diff --git a/lib/utils/include/utils/one_to_many/one_to_many.h b/lib/utils/include/utils/one_to_many/one_to_many.h index 7b725fdec1..5492ff3f78 100644 --- a/lib/utils/include/utils/one_to_many/one_to_many.h +++ b/lib/utils/include/utils/one_to_many/one_to_many.h @@ -4,6 +4,7 @@ #include "utils/containers/generate_map.h" #include 
"utils/containers/items.h" #include "utils/containers/keys.h" +#include "utils/containers/require_same.h" #include "utils/containers/transform.h" #include "utils/containers/try_at.h" #include "utils/containers/unordered_set_of.h" @@ -23,7 +24,6 @@ #include #include #include -#include "utils/containers/require_same.h" namespace FlexFlow { diff --git a/lib/utils/src/utils/bidict/algorithms/binary_merge_disjoint_bidicts.cc b/lib/utils/src/utils/bidict/algorithms/binary_merge_disjoint_bidicts.cc index 8650de44f6..13a1bcd968 100644 --- a/lib/utils/src/utils/bidict/algorithms/binary_merge_disjoint_bidicts.cc +++ b/lib/utils/src/utils/bidict/algorithms/binary_merge_disjoint_bidicts.cc @@ -6,7 +6,7 @@ namespace FlexFlow { using K = value_type<0>; using V = value_type<1>; -template - bidict binary_merge_disjoint_bidicts(bidict const &, bidict const &); +template bidict binary_merge_disjoint_bidicts(bidict const &, + bidict const &); } // namespace FlexFlow diff --git a/lib/utils/src/utils/containers/transform_pairs.cc b/lib/utils/src/utils/containers/transform_pairs.cc index 241f1ad425..4afda936e4 100644 --- a/lib/utils/src/utils/containers/transform_pairs.cc +++ b/lib/utils/src/utils/containers/transform_pairs.cc @@ -8,10 +8,10 @@ using R = value_type<1>; using Out = value_type<2>; using F = std::function; -template - std::vector transform_pairs(std::vector> const &, F &&); +template std::vector transform_pairs(std::vector> const &, + F &&); -template - std::unordered_set transform_pairs(std::unordered_set> const &, F &&); +template std::unordered_set + transform_pairs(std::unordered_set> const &, F &&); } // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/kwarg_dataflow_graph/algorithms/get_kwarg_dataflow_value_uses.cc b/lib/utils/src/utils/graph/kwarg_dataflow_graph/algorithms/get_kwarg_dataflow_value_uses.cc index 2e42863e53..b1d2988223 100644 --- a/lib/utils/src/utils/graph/kwarg_dataflow_graph/algorithms/get_kwarg_dataflow_value_uses.cc +++ 
b/lib/utils/src/utils/graph/kwarg_dataflow_graph/algorithms/get_kwarg_dataflow_value_uses.cc @@ -5,10 +5,8 @@ namespace FlexFlow { using SlotName = ordered_value_type<0>; -template - std::unordered_set> - get_kwarg_dataflow_value_uses( - KwargDataflowGraphView const &, - KwargDataflowOutput const &); +template std::unordered_set> + get_kwarg_dataflow_value_uses(KwargDataflowGraphView const &, + KwargDataflowOutput const &); } // namespace FlexFlow From f2b075482f46dff9d5c66fd4a7616bc425f23421 Mon Sep 17 00:00:00 2001 From: Elliott Slaughter Date: Thu, 21 May 2026 12:15:38 -0700 Subject: [PATCH 09/14] Format Realm. --- .../src/realm-execution/pcg_instance.cc | 31 ++- .../src/realm-execution/test_op_replicate.cc | 185 +++++++++--------- 2 files changed, 107 insertions(+), 109 deletions(-) diff --git a/lib/realm-execution/src/realm-execution/pcg_instance.cc b/lib/realm-execution/src/realm-execution/pcg_instance.cc index f2edac7f88..332669a9dc 100644 --- a/lib/realm-execution/src/realm-execution/pcg_instance.cc +++ b/lib/realm-execution/src/realm-execution/pcg_instance.cc @@ -217,14 +217,11 @@ static Realm::Event spawn_dynamic_node_invocation( }; auto issue_replicate_bwd = [&]() { - - DynamicValueAttrs output_grad = get_only( - values( - filter_keys( - invocation.inputs, - [](DynamicTensorSlot const &s) -> bool { - return s.slot_tensor_role == DynamicTensorRole{FwbTensorType::GRADIENT}; - }))); + DynamicValueAttrs output_grad = get_only(values( + filter_keys(invocation.inputs, [](DynamicTensorSlot const &s) -> bool { + return s.slot_tensor_role == + DynamicTensorRole{FwbTensorType::GRADIENT}; + }))); DynamicValueAttrs input_grad = get_only(values(invocation.outputs)); @@ -246,15 +243,15 @@ static Realm::Event spawn_dynamic_node_invocation( tensor_instance_backing.backing.at(replica_key).first; e = ctx.issue_copy( - /*src_shape=*/assert_unwrap(output_grad.parallel_tensor_shape), - /*src_inst=*/src_inst, - /*dst_shape=*/assert_unwrap(input_grad.parallel_tensor_shape), - 
/*dst_inst=*/dst_inst, - /*requests=*/Realm::ProfilingRequestSet{}, - /*wait_on=*/e, - /*priority=*/0, - /*redop_id=*/redop_id, - /*exlusive=*/false); + /*src_shape=*/assert_unwrap(output_grad.parallel_tensor_shape), + /*src_inst=*/src_inst, + /*dst_shape=*/assert_unwrap(input_grad.parallel_tensor_shape), + /*dst_inst=*/dst_inst, + /*requests=*/Realm::ProfilingRequestSet{}, + /*wait_on=*/e, + /*priority=*/0, + /*redop_id=*/redop_id, + /*exlusive=*/false); } return e; }; diff --git a/lib/realm-execution/test/src/realm-execution/test_op_replicate.cc b/lib/realm-execution/test/src/realm-execution/test_op_replicate.cc index 2523cae798..46d29e2bef 100644 --- a/lib/realm-execution/test/src/realm-execution/test_op_replicate.cc +++ b/lib/realm-execution/test/src/realm-execution/test_op_replicate.cc @@ -13,6 +13,7 @@ #include "op-attrs/tensor_slot_name.dtg.h" #include "pcg/device_type.dtg.h" #include "pcg/machine_space_coordinate.dtg.h" +#include "pcg/mapped_parallel_computation_graph/mapped_parallel_computation_graph.h" #include "pcg/mapped_parallel_computation_graph/operator_atomic_task_shard_binding.dtg.h" #include "pcg/parallel_computation_graph/parallel_computation_graph.h" #include "pcg/parallel_computation_graph/parallel_computation_graph_builder.h" @@ -27,7 +28,6 @@ #include "test/utils/doctest/check_kv.h" #include "utils/containers/require_only_key.h" #include -#include "pcg/mapped_parallel_computation_graph/mapped_parallel_computation_graph.h" namespace test { @@ -49,7 +49,8 @@ static bool did_loss_decrease(GenericTensorAccessorR const &first_epoch, compare_tensor_accessors_le(last_epoch, first_epoch, allocator)); } -MappedParallelComputationGraph make_test_mpcg_for_device_type(DeviceType device_type) { +MappedParallelComputationGraph + make_test_mpcg_for_device_type(DeviceType device_type) { positive_int batch_size = 10_p; positive_int data_dim = 16_p; positive_int hidden_dim = 32_p; @@ -63,8 +64,8 @@ MappedParallelComputationGraph 
make_test_mpcg_for_device_type(DeviceType device_ ParallelComputationGraph pcg = empty_parallel_computation_graph(); - TensorShape input_tensor_shape = TensorShape{ - TensorDims{FFOrdered{batch_size, data_dim}}, DataType::FLOAT}; + TensorShape input_tensor_shape = + TensorShape{TensorDims{FFOrdered{batch_size, data_dim}}, DataType::FLOAT}; ParallelLayerAddedResult inputs_layer = pcg_add_input_layer(pcg, input_tensor_shape); @@ -144,91 +145,92 @@ MappedParallelComputationGraph make_test_mpcg_for_device_type(DeviceType device_ /*discard_copy_component=*/1_n, /*shard_component=*/FFOrdered{0_n}}; - MappedParallelComputationGraph mpcg = mapped_pcg_from_pcg_and_mapped_op_task_groups( - /*pcg=*/pcg, - /*mapped_op_task_groups=*/{ - { - inputs_layer.parallel_layer, - MappedOperatorTaskGroup{ - { + MappedParallelComputationGraph mpcg = + mapped_pcg_from_pcg_and_mapped_op_task_groups( + /*pcg=*/pcg, + /*mapped_op_task_groups=*/{ { - cpu0, - OperatorAtomicTaskShardBinding{{ - {TensorSlotName::OUTPUT, tensor_coord0}, - }}, + inputs_layer.parallel_layer, + MappedOperatorTaskGroup{ + { + { + cpu0, + OperatorAtomicTaskShardBinding{{ + {TensorSlotName::OUTPUT, tensor_coord0}, + }}, + }, + }, + }, }, - }, - }, - }, - { - inputs_layer_2.parallel_layer, - MappedOperatorTaskGroup{ - { { - cpu0, - OperatorAtomicTaskShardBinding{{ - {TensorSlotName::OUTPUT, tensor_coord0}, - }}, + inputs_layer_2.parallel_layer, + MappedOperatorTaskGroup{ + { + { + cpu0, + OperatorAtomicTaskShardBinding{{ + {TensorSlotName::OUTPUT, tensor_coord0}, + }}, + }, + }, + }, }, - }, - }, - }, - { - add_operator_1.parallel_layer, - MappedOperatorTaskGroup{ - { { - cpu0, - OperatorAtomicTaskShardBinding{{ - {TensorSlotName::LHS_INPUT, tensor_coord0}, - {TensorSlotName::RHS_INPUT, tensor_coord0}, - {TensorSlotName::OUTPUT, tensor_coord0}, - }}, + add_operator_1.parallel_layer, + MappedOperatorTaskGroup{ + { + { + cpu0, + OperatorAtomicTaskShardBinding{{ + {TensorSlotName::LHS_INPUT, tensor_coord0}, + 
{TensorSlotName::RHS_INPUT, tensor_coord0}, + {TensorSlotName::OUTPUT, tensor_coord0}, + }}, + }, + }, + }, }, - }, - }, - }, - { - repl_operator_1.parallel_layer, - MappedOperatorTaskGroup{ - { { - cpu0, - OperatorAtomicTaskShardBinding{{ - {TensorSlotName::OUTPUT, tensor_coord0}, - }}, + repl_operator_1.parallel_layer, + MappedOperatorTaskGroup{ + { + { + cpu0, + OperatorAtomicTaskShardBinding{{ + {TensorSlotName::OUTPUT, tensor_coord0}, + }}, + }, + { + cpu1, + OperatorAtomicTaskShardBinding{{ + {TensorSlotName::OUTPUT, tensor_coord1}, + }}, + }, + }, + }, }, { - cpu1, - OperatorAtomicTaskShardBinding{{ - {TensorSlotName::OUTPUT, tensor_coord1}, - }}, + relu_operator_1.parallel_layer, + MappedOperatorTaskGroup{ + { + { + cpu0, + OperatorAtomicTaskShardBinding{{ + {TensorSlotName::INPUT, tensor_coord0}, + {TensorSlotName::OUTPUT, tensor_coord0}, + }}, + }, + { + cpu1, + OperatorAtomicTaskShardBinding{{ + {TensorSlotName::INPUT, tensor_coord1}, + {TensorSlotName::OUTPUT, tensor_coord1}, + }}, + }, + }, + }, }, - }, - }, - }, - { - relu_operator_1.parallel_layer, - MappedOperatorTaskGroup{ - { - { - cpu0, - OperatorAtomicTaskShardBinding{{ - {TensorSlotName::INPUT, tensor_coord0}, - {TensorSlotName::OUTPUT, tensor_coord0}, - }}, - }, - { - cpu1, - OperatorAtomicTaskShardBinding{{ - {TensorSlotName::INPUT, tensor_coord1}, - {TensorSlotName::OUTPUT, tensor_coord1}, - }}, - }, - }, - }, - }, - }); + }); return mpcg; } @@ -245,21 +247,20 @@ TEST_SUITE(FF_TEST_SUITE) { manager.start_controller([](RealmContext &ctx) { Allocator allocator = ctx.get_current_device_allocator(); - MappedParallelComputationGraph mpcg = make_test_mpcg_for_device_type(DeviceType::CPU); - + MappedParallelComputationGraph mpcg = + make_test_mpcg_for_device_type(DeviceType::CPU); std::unordered_map input_tensors; - OptimizerAttrs optimizer_attrs = - OptimizerAttrs{ - SGDOptimizerAttrs{ + OptimizerAttrs optimizer_attrs = OptimizerAttrs{ + SGDOptimizerAttrs{ /*lr=*/0.001, /*momentum=*/0.9, 
/*nesterov=*/false, /*weight_decay=*/0.001, - }, - }; + }, + }; DistributedFfHandle device_handle = create_distributed_ff_handle( ctx, @@ -303,17 +304,17 @@ TEST_SUITE(FF_CUDA_TEST_SUITE) { manager.start_controller([](RealmContext &ctx) { Allocator allocator = ctx.get_current_device_allocator(); - MappedParallelComputationGraph mpcg = make_test_mpcg_for_device_type(DeviceType::GPU); + MappedParallelComputationGraph mpcg = + make_test_mpcg_for_device_type(DeviceType::GPU); - OptimizerAttrs optimizer_attrs = - OptimizerAttrs{ - SGDOptimizerAttrs{ + OptimizerAttrs optimizer_attrs = OptimizerAttrs{ + SGDOptimizerAttrs{ /*lr=*/0.001, /*momentum=*/0.9, /*nesterov=*/false, /*weight_decay=*/0.001, - }, - }; + }, + }; std::unordered_map input_tensors; From 9d03c9766beaa67b2cbc1eea87976d2c44ae6153 Mon Sep 17 00:00:00 2001 From: Elliott Slaughter Date: Thu, 21 May 2026 12:16:39 -0700 Subject: [PATCH 10/14] Refactor redop infrastructure and switch to Legion's redops. --- .../redops/realm_redop_registry.h | 16 + .../redops/redop_id_t.dtg.toml | 30 + .../realm-execution/redops/redop_id_t.h | 22 + .../realm-execution/tasks/realm_reduction.h | 154 ---- .../src/realm-execution/realm_manager.cc | 4 +- .../redops/realm_redop_registry.cc | 689 ++++++++++++++++++ .../src/realm-execution/redops/redop_id_t.cc | 32 + .../tasks/realm_task_registry.cc | 10 - 8 files changed, 792 insertions(+), 165 deletions(-) create mode 100644 lib/realm-execution/include/realm-execution/redops/realm_redop_registry.h create mode 100644 lib/realm-execution/include/realm-execution/redops/redop_id_t.dtg.toml create mode 100644 lib/realm-execution/include/realm-execution/redops/redop_id_t.h delete mode 100644 lib/realm-execution/include/realm-execution/tasks/realm_reduction.h create mode 100644 lib/realm-execution/src/realm-execution/redops/realm_redop_registry.cc create mode 100644 lib/realm-execution/src/realm-execution/redops/redop_id_t.cc diff --git 
a/lib/realm-execution/include/realm-execution/redops/realm_redop_registry.h b/lib/realm-execution/include/realm-execution/redops/realm_redop_registry.h new file mode 100644 index 0000000000..a338a38bbf --- /dev/null +++ b/lib/realm-execution/include/realm-execution/redops/realm_redop_registry.h @@ -0,0 +1,16 @@ +#ifndef _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_TASKS_REALM_REDOP_REGISTRY_H +#define _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_TASKS_REALM_REDOP_REGISTRY_H + +#include "realm-execution/realm.h" +#include "realm-execution/redops/redop_id_t.dtg.h" + +namespace FlexFlow { + +/** + * \brief Registers all known reduction operators (redops). + */ +void Realm::Event register_all_redops(Realm::Runtime); + +} // namespace FlexFlow + +#endif diff --git a/lib/realm-execution/include/realm-execution/redops/redop_id_t.dtg.toml b/lib/realm-execution/include/realm-execution/redops/redop_id_t.dtg.toml new file mode 100644 index 0000000000..5183ff5e72 --- /dev/null +++ b/lib/realm-execution/include/realm-execution/redops/redop_id_t.dtg.toml @@ -0,0 +1,30 @@ +namespace = "FlexFlow" +name = "redop_id_t" +type = "enum" +features = [ + "hash", + "fmt", + "rapidcheck", + "json", +] +docstring = ''' +\brief An enum for identifying reduction operators (redops) for use in the Realm runtime. 
+''' + +[[values]] +name = "SUM_BOOL_REDOP_ID" + +[[values]] +name = "SUM_INT32_REDOP_ID" + +[[values]] +name = "SUM_INT64_REDOP_ID" + +[[values]] +name = "SUM_HALF_REDOP_ID" + +[[values]] +name = "SUM_FLOAT_REDOP_ID" + +[[values]] +name = "SUM_DOUBLE_REDOP_ID" diff --git a/lib/realm-execution/include/realm-execution/redops/redop_id_t.h b/lib/realm-execution/include/realm-execution/redops/redop_id_t.h new file mode 100644 index 0000000000..b9ef91a05a --- /dev/null +++ b/lib/realm-execution/include/realm-execution/redops/redop_id_t.h @@ -0,0 +1,22 @@ +#ifndef _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_TASKS_REALM_REDOP_REGISTRY_H +#define _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_TASKS_REALM_REDOP_REGISTRY_H + +#include "realm-execution/realm.h" +#include "realm-execution/redops/redop_id_t.dtg.h" + +namespace FlexFlow { + +/** + * \brief Registers all known reduction operators (redops). + */ +Realm::ReductionOpID get_sum_redop_id_for_data_type(DataType); + +/** + * \brief Convert a \ref FlexFlow::redop_id_t into a Realm reduction op ID. 
+ */ +Realm::Processor::ReductionOpID + get_realm_reduction_op_id_for_redop_id(redop_id_t); + +} // namespace FlexFlow + +#endif diff --git a/lib/realm-execution/include/realm-execution/tasks/realm_reduction.h b/lib/realm-execution/include/realm-execution/tasks/realm_reduction.h deleted file mode 100644 index 512e344824..0000000000 --- a/lib/realm-execution/include/realm-execution/tasks/realm_reduction.h +++ /dev/null @@ -1,154 +0,0 @@ -#ifndef _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_TASKS_REALM_REDUCTION_H -#define _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_TASKS_REALM_REDUCTION_H -#include "op-attrs/datatype.dtg.h" -#include - -namespace FlexFlow { - -/** - * \brief Realm Sum Reduction for Float - * \see https://legion.stanford.edu/tutorial/realm/reductions.html - */ -struct SumReductionFloat { - using LHS = float; - using RHS = float; - - /** \brief Identity element for addition (0.0) */ - static constexpr RHS identity = 0.0f; - - /** - * \brief Apply reduction: lhs += rhs - * \tparam EXCLUSIVE If true, direct addition; if false, atomic CAS loop - * \param lhs Left-hand side accumulator (modified in place) - * \param rhs Value to add - */ - template - static void apply(LHS &lhs, RHS rhs) { - if (EXCLUSIVE) { - lhs += rhs; - } else { - // Atomic float add via CAS loop - union { - float f; - int i; - } old_val, new_val; - do { - old_val.f = lhs; - new_val.f = old_val.f + rhs; - } while ( - !__sync_bool_compare_and_swap((int *)&lhs, old_val.i, new_val.i)); - } - } - - /** - * \brief Fold two RHS values: rhs1 += rhs2 - * \tparam EXCLUSIVE If true, direct addition; if false, atomic CAS loop - * \param rhs1 Accumulator (modified in place) - * \param rhs2 Value to fold in - */ - template - static void fold(RHS &rhs1, RHS rhs2) { - if (EXCLUSIVE) { - rhs1 += rhs2; - } else { - // Atomic float add via CAS loop - union { - float f; - int i; - } old_val, new_val; - do { - old_val.f = rhs1; - new_val.f = old_val.f + rhs2; - } while ( - 
!__sync_bool_compare_and_swap((int *)&rhs1, old_val.i, new_val.i)); - } - } -}; - -/** - * \brief Realm Sum Reduction for Double - * \see https://legion.stanford.edu/tutorial/realm/reductions.html - */ -struct SumReductionDouble { - using LHS = double; - using RHS = double; - - /** \brief Identity element for addition (0.0) */ - static constexpr RHS identity = 0.0; - - /** - * \brief Apply reduction: lhs += rhs - * \tparam EXCLUSIVE If true, direct addition; if false, atomic CAS loop - * \param lhs Left-hand side accumulator (modified in place) - * \param rhs Value to add - */ - template - static void apply(LHS &lhs, RHS rhs) { - if (EXCLUSIVE) { - lhs += rhs; - } else { - // Atomic double add via CAS loop using long long reinterpretation - union { - double d; - long long i; - } old_val, new_val; - do { - old_val.d = lhs; - new_val.d = old_val.d + rhs; - } while (!__sync_bool_compare_and_swap( - (long long *)&lhs, old_val.i, new_val.i)); - } - } - - /** - * \brief Fold two RHS values: rhs1 += rhs2 - * \tparam EXCLUSIVE If true, direct addition; if false, atomic CAS loop - * \param rhs1 Accumulator (modified in place) - * \param rhs2 Value to fold in - */ - template - static void fold(RHS &rhs1, RHS rhs2) { - if (EXCLUSIVE) { - rhs1 += rhs2; - } else { - // Atomic double add via CAS loop using long long reinterpretation - union { - double d; - long long i; - } old_val, new_val; - do { - old_val.d = rhs1; - new_val.d = old_val.d + rhs2; - } while (!__sync_bool_compare_and_swap( - (long long *)&rhs1, old_val.i, new_val.i)); - } - } -}; - -/** - * \brief Reduction op IDs for sum reductions - * \warning These IDs must not conflict with other registered reduction ops - */ -enum SumReductionOpIDs { - REDOP_SUM_FLOAT = 1, ///< Sum reduction op ID for float - REDOP_SUM_DOUBLE = 2, ///< Sum reduction op ID for double -}; - -/** - * \brief Returns the Realm reduction op ID for a sum reduction over the given datatype - * \param dtype The datatype to look up - * \return The 
corresponding Realm::ReductionOpID - * \throws PANIC if no sum reduction is registered for the given datatype - */ -inline Realm::ReductionOpID get_sum_reduction_op_id(DataType dtype) { - switch (dtype) { - case DataType::FLOAT: - return REDOP_SUM_FLOAT; - case DataType::DOUBLE: - return REDOP_SUM_DOUBLE; - default: - PANIC("no sum reduction registered for datatype {}", dtype); - } -} -} // namespace FlexFlow -#endif diff --git a/lib/realm-execution/src/realm-execution/realm_manager.cc b/lib/realm-execution/src/realm-execution/realm_manager.cc index e76be7054b..c7136d8a98 100644 --- a/lib/realm-execution/src/realm-execution/realm_manager.cc +++ b/lib/realm-execution/src/realm-execution/realm_manager.cc @@ -1,6 +1,7 @@ #include "realm-execution/realm_manager.h" #include "realm-execution/realm_context.h" #include "realm-execution/tasks/realm_task_registry.h" +#include "realm-execution/redops/realm_redop_registry.h" namespace FlexFlow { @@ -9,8 +10,9 @@ RealmManager::RealmManager(int *argc, char ***argv) bool ok = this->get_runtime().init(argc, argv); ASSERT(ok); - // Register all tasks at initialization time so we don't need to later + // Register all tasks and redops at initialization time so we don't need to later register_all_tasks().wait(); + register_all_redops(this->get_runtime()); } RealmManager::~RealmManager() { diff --git a/lib/realm-execution/src/realm-execution/redops/realm_redop_registry.cc b/lib/realm-execution/src/realm-execution/redops/realm_redop_registry.cc new file mode 100644 index 0000000000..d10b158463 --- /dev/null +++ b/lib/realm-execution/src/realm-execution/redops/realm_redop_registry.cc @@ -0,0 +1,689 @@ +#include "realm-execution/redops/realm_redop_registry.h" + +namespace FlexFlow { + +// Reduction operators and related infrastructure borrowed from Legion. We +// maintain the Legion naming scheme to maximizing compatibility with the +// existing code, despite not otherwise relying or using Legion in any way. 
+// https://gitlab.com/StanfordLegion/legion/-/blob/5263aeff477fb94239c50d9306d58c4244e9fc38/runtime/legion/api/redop.inl#L31 +#if !defined(__cpp_lib_atomic_ref) || (__cpp_lib_atomic_ref < 201806L) +// We only need this crap if we're using a version of c++ < 20 +// Starting with c++20 we can do all this the right way with atomic_ref +namespace TypePunning { +// The tenth circle of hell is reserved for members of the C++ committee +// that decided to deviate from C's support for type punning unions. +// Add on to it the fact that it took them 9 fucking years to realize +// that they needed std::atomic_ref and it's plain to see they are all +// just a bunch of idiots that should never be allowed near a programming +// language standard ever again. They've clearly never written lock-free +// code in their lives. +template +class Pointer { +public: + Pointer(void *p) : pointer(convert(p)) {} + static inline T *convert(void *p) { + T *ptr = nullptr; + static_assert(sizeof(ptr) == sizeof(p)); + memcpy(&ptr, &p, sizeof(p)); + return ptr; + } + inline operator T *(void) const { + return (T *)pointer; + } + inline T operator*(void) const { + return *pointer; + } + inline T operator[](size_t off) const { + return pointer[off]; + } + +private: + T volatile *const pointer; +}; +template +class AlignedPointer { +public: + AlignedPointer(void *p) : off(align(p)), pointer(convert(p, off)) {} + static inline T *convert(void *p, size_t off) { + uint8_t *p1 = nullptr; + static_assert(sizeof(p1) == sizeof(p)); + memcpy(&p1, &p, sizeof(p)); + p1 = p1 - off; + T *p2 = nullptr; + static_assert(sizeof(p1) == sizeof(p2)); + memcpy(&p2, &p1, sizeof(p1)); + return p2; + } + static inline size_t align(void *p) { + uintptr_t ptr; + static_assert(sizeof(ptr) == sizeof(p)); + memcpy(&ptr, &p, sizeof(ptr)); + return ptr % ALIGNMENT; + } + inline operator T *(void) const { + return (T *)pointer; + } + inline T operator*(void) const { + return *pointer; + } + inline size_t offset(void) const { + 
return off; + } + +private: + size_t off; + T volatile *const pointer; +}; +template +class Alias { +public: + inline void load(Pointer const &pointer, size_t off = 0) { + T1 value = pointer[off]; + memcpy(buffer, (void *)&value, sizeof(T1)); + } + template + inline void load(AlignedPointer const &pointer) { + T1 value = *pointer; + memcpy(buffer, (void *)&value, sizeof(T1)); + } + inline T1 as_one(void) const { + T1 result; + memcpy((void *)&result, buffer, sizeof(result)); + return result; + } + inline T2 as_two(void) const { + T2 result; + memcpy((void *)&result, buffer, sizeof(result)); + return result; + } + inline Alias &operator=(T2 rhs) { + memcpy(buffer, (void *)&rhs, sizeof(rhs)); + return *this; + } + +private: + // Make this one private so it is can never be called + inline Alias &operator=(T1 rhs) { + memcpy(buffer, (void *)&rhs, sizeof(rhs)); + return *this; + } + static_assert(sizeof(T1) == sizeof(T2)); + uint8_t buffer[sizeof(T1)]; +}; +}; // namespace TypePunning +#endif + +// Define a prefix for annotating functions for CUDA compilation +#if defined(__CUDACC__) || defined(__HIPCC__) +#define __LEGION_CUDA_HD__ __host__ __device__ +#else +#define __LEGION_CUDA_HD__ +#endif + +template <> +class SumReduction { +public: + typedef bool LHS; + typedef bool RHS; + + static constexpr bool identity = false; + static constexpr int REDOP_ID = LEGION_REDOP_OR_BOOL; + + template + __LEGION_CUDA_HD__ static void apply(LHS &lhs, RHS rhs); + template + __LEGION_CUDA_HD__ static void fold(RHS &rhs1, RHS rhs2); +}; + +template <> +class SumReduction { +public: + typedef int32_t LHS; + typedef int32_t RHS; + + static constexpr int32_t identity = 0; + static constexpr int REDOP_ID = LEGION_REDOP_SUM_INT32; + + template + __LEGION_CUDA_HD__ static void apply(LHS &lhs, RHS rhs); + template + __LEGION_CUDA_HD__ static void fold(RHS &rhs1, RHS rhs2); +}; + +template <> +class SumReduction { +public: + typedef int64_t LHS; + typedef int64_t RHS; + + static constexpr 
int64_t identity = 0; + static constexpr int REDOP_ID = LEGION_REDOP_SUM_INT64; + + template + __LEGION_CUDA_HD__ static void apply(LHS &lhs, RHS rhs); + template + __LEGION_CUDA_HD__ static void fold(RHS &rhs1, RHS rhs2); +}; + +template <> +class SumReduction<__half> { +public: + typedef __half LHS; + typedef __half RHS; + + static inline const __half identity = __half(0, false /*raw*/); + static constexpr int REDOP_ID = LEGION_REDOP_SUM_FLOAT16; + + template + __LEGION_CUDA_HD__ static void apply(LHS &lhs, RHS rhs); + template + __LEGION_CUDA_HD__ static void fold(RHS &rhs1, RHS rhs2); +}; + +template <> +class SumReduction { +public: + typedef float LHS; + typedef float RHS; + + static constexpr float identity = 0.f; + static constexpr int REDOP_ID = LEGION_REDOP_SUM_FLOAT32; + + template + __LEGION_CUDA_HD__ static void apply(LHS &lhs, RHS rhs); + template + __LEGION_CUDA_HD__ static void fold(RHS &rhs1, RHS rhs2); +}; + +template <> +class SumReduction { +public: + typedef double LHS; + typedef double RHS; + + static constexpr double identity = 0.0; + static constexpr int REDOP_ID = LEGION_REDOP_SUM_FLOAT64; + + template + __LEGION_CUDA_HD__ static void apply(LHS &lhs, RHS rhs); + template + __LEGION_CUDA_HD__ static void fold(RHS &rhs1, RHS rhs2); +}; + +template <> +__LEGION_CUDA_HD__ inline void SumReduction::apply(LHS &lhs, + RHS rhs) { + lhs = lhs || rhs; +} + +template <> +__LEGION_CUDA_HD__ inline void SumReduction::apply(LHS &lhs, + RHS rhs) { +#if defined(__CUDA_ARCH__) || defined(__HIP_DEVICE_COMPILE__) + // GPU atomics need 4 byte alignment + const uintptr_t unaligned = reinterpret_cast(&lhs); + unsigned const offset = unaligned % sizeof(unsigned int); + const uintptr_t aligned = unaligned - offset; + unsigned int *ptr = reinterpret_cast(aligned); + unsigned int newval = *ptr, oldval; + do { + RHS previous = __uint2bool(newval, offset); + RHS next = previous || rhs; + oldval = newval; + newval = __bool2uint(newval, next, offset); + newval = 
atomicCAS(ptr, oldval, newval); + } while (oldval != newval); +#else +#if defined(__cpp_lib_atomic_ref) && (__cpp_lib_atomic_ref >= 201806L) + std::atomic_ref atomic(lhs); + RHS oldval = atomic.load(); + RHS newval; + do { + newval = oldval || rhs; + } while (!atomic.compare_exchange_weak(oldval, newval)); +#else + // No atomic logical operations so use compare and swap + TypePunning::Alias oldval, newval; + TypePunning::Pointer pointer((void *)&lhs); + do { + oldval.load(pointer); + newval = oldval.as_two() || rhs; + } while (!__sync_bool_compare_and_swap( + (int8_t *)pointer, oldval.as_one(), newval.as_one())); +#endif +#endif +} + +template <> +__LEGION_CUDA_HD__ inline void SumReduction::fold(RHS &rhs1, + RHS rhs2) { + rhs1 = rhs1 || rhs2; +} + +template <> +__LEGION_CUDA_HD__ inline void SumReduction::fold(RHS &rhs1, + RHS rhs2) { +#if defined(__CUDA_ARCH__) || defined(__HIP_DEVICE_COMPILE__) + // GPU atomics need 4 byte alignment + const uintptr_t unaligned = reinterpret_cast(&rhs1); + unsigned const offset = unaligned % sizeof(unsigned int); + const uintptr_t aligned = unaligned - offset; + unsigned int *ptr = reinterpret_cast(aligned); + unsigned int newval = *ptr, oldval; + do { + RHS previous = __uint2bool(newval, offset); + RHS next = previous || rhs2; + oldval = newval; + newval = __bool2uint(newval, next, offset); + newval = atomicCAS(ptr, oldval, newval); + } while (oldval != newval); +#else +#if defined(__cpp_lib_atomic_ref) && (__cpp_lib_atomic_ref >= 201806L) + std::atomic_ref atomic(rhs1); + RHS oldval = atomic.load(); + RHS newval; + do { + newval = oldval || rhs2; + } while (!atomic.compare_exchange_weak(oldval, newval)); +#else + // No atomic logical operations so use compare and swap + TypePunning::Alias oldval, newval; + TypePunning::Pointer pointer((void *)&rhs1); + do { + oldval.load(pointer); + newval = oldval.as_two() || rhs2; + } while (!__sync_bool_compare_and_swap( + (int8_t *)pointer, oldval.as_one(), newval.as_one())); +#endif 
+#endif +} + +template <> +__LEGION_CUDA_HD__ inline void SumReduction::apply(LHS &lhs, + RHS rhs) { + lhs += rhs; +} + +template <> +__LEGION_CUDA_HD__ inline void SumReduction::apply(LHS &lhs, + RHS rhs) { +#if defined(__CUDA_ARCH__) || defined(__HIP_DEVICE_COMPILE__) + atomicAdd(&lhs, rhs); +#else + __sync_fetch_and_add(&lhs, rhs); +#endif +} + +template <> +__LEGION_CUDA_HD__ inline void SumReduction::fold(RHS &rhs1, + RHS rhs2) { + rhs1 += rhs2; +} + +template <> +__LEGION_CUDA_HD__ inline void SumReduction::fold(RHS &rhs1, + RHS rhs2) { +#if defined(__CUDA_ARCH__) || defined(__HIP_DEVICE_COMPILE__) + atomicAdd(&rhs1, rhs2); +#else + __sync_fetch_and_add(&rhs1, rhs2); +#endif +} + +template <> +__LEGION_CUDA_HD__ inline void SumReduction::apply(LHS &lhs, + RHS rhs) { + lhs += rhs; +} + +template <> +__LEGION_CUDA_HD__ inline void SumReduction::apply(LHS &lhs, + RHS rhs) { +#if defined(__CUDA_ARCH__) || defined(__HIP_DEVICE_COMPILE__) + // Apparently there is no signed 64bit int atomic yet + RHS newval = lhs, oldval; + // Type punning like this is illegal in C++ but the + // CUDA manual has an example just like it so fuck it + unsigned long long int *ptr = (unsigned long long int *)&lhs; + do { + oldval = newval; + newval += rhs; + newval = __ulonglong_as_longlong(atomicCAS( + ptr, __longlong_as_ulonglong(oldval), __longlong_as_ulonglong(newval))); + } while (oldval != newval); +#else + __sync_fetch_and_add(&lhs, rhs); +#endif +} + +template <> +__LEGION_CUDA_HD__ inline void SumReduction::fold(RHS &rhs1, + RHS rhs2) { + rhs1 += rhs2; +} + +template <> +__LEGION_CUDA_HD__ inline void SumReduction::fold(RHS &rhs1, + RHS rhs2) { +#if defined(__CUDA_ARCH__) || defined(__HIP_DEVICE_COMPILE__) + // Apparently there is no signed 64bit int atomic yet + RHS newval = rhs1, oldval; + // Type punning like this is illegal in C++ but the + // CUDA manual has an example just like it so fuck it + unsigned long long int *ptr = (unsigned long long int *)&rhs1; + do { + oldval = 
newval; + newval += rhs2; + newval = __ulonglong_as_longlong(atomicCAS( + ptr, __longlong_as_ulonglong(oldval), __longlong_as_ulonglong(newval))); + } while (oldval != newval); +#else + __sync_fetch_and_add(&rhs1, rhs2); +#endif +} + +template <> +__LEGION_CUDA_HD__ inline void SumReduction<__half>::apply(LHS &lhs, + RHS rhs) { + lhs = lhs + rhs; +} + +template <> +__LEGION_CUDA_HD__ inline void SumReduction<__half>::apply(LHS &lhs, + RHS rhs) { +#if defined(__CUDA_ARCH__) || defined(__HIP_DEVICE_COMPILE__) +#if (__CUDA_ARCH__ >= 700) && (__CUDACC_VER_MAJOR__ >= 10) + atomicAdd(&lhs, rhs); +#else + // 16-bit atomics are not supported prior to volta + // 32-bit GPU atomics need 4 byte alignment + const uintptr_t unaligned = reinterpret_cast(&lhs); + unsigned const offset = unaligned % sizeof(unsigned int); + const uintptr_t aligned = unaligned - offset; + unsigned int *ptr = reinterpret_cast(aligned); + RHS newval = lhs, oldval, other; + if (offset == 0) { + other = *((&lhs) + 1); + do { + oldval = newval; + newval = newval + rhs; + unsigned int const result = atomicCAS( + ptr, __hilohalf2uint(other, oldval), __hilohalf2uint(other, newval)); + newval = __uint2lohalf(result); + other = __uint2hihalf(result); + } while (oldval != newval); + } else { + other = *((&lhs) - 1); + do { + oldval = newval; + newval = newval + rhs; + unsigned int const result = atomicCAS( + ptr, __hilohalf2uint(oldval, other), __hilohalf2uint(newval, other)); + other = __uint2lohalf(result); + newval = __uint2hihalf(result); + } while (oldval != newval); + } +#endif +#else +#if defined(__cpp_lib_atomic_ref) && (__cpp_lib_atomic_ref >= 201806L) + std::atomic_ref atomic(lhs); + RHS oldval = atomic.load(); + RHS newval; + do { + newval = oldval + rhs; + } while (!atomic.compare_exchange_weak(oldval, newval)); +#else + // No atomic floating point operations so use compare and swap + TypePunning::Alias> oldval, newval; + TypePunning::AlignedPointer pointer((void *)&lhs); + unsigned const offset = 
pointer.offset() / sizeof(__half); + do { + oldval.load(pointer); + std::array next = oldval.as_two(); + next[offset] = __convert_float_to_halfint( + __convert_halfint_to_float(next[offset]) + float(rhs)); + newval = next; + } while (!__sync_bool_compare_and_swap( + (int32_t *)pointer, oldval.as_one(), newval.as_one())); +#endif +#endif +} + +template <> +__LEGION_CUDA_HD__ inline void SumReduction<__half>::fold(RHS &rhs1, + RHS rhs2) { + rhs1 = rhs1 + rhs2; +} + +template <> +__LEGION_CUDA_HD__ inline void SumReduction<__half>::fold(RHS &rhs1, + RHS rhs2) { +#if defined(__CUDA_ARCH__) || defined(__HIP_DEVICE_COMPILE__) +#if (__CUDA_ARCH__ >= 700) && (__CUDACC_VER_MAJOR__ >= 10) + atomicAdd(&rhs1, rhs2); +#else + // 16-bit atomics are not supported prior to volta + // 32-bit GPU atomics need 4 byte alignment + const uintptr_t unaligned = reinterpret_cast(&rhs1); + unsigned const offset = unaligned % sizeof(unsigned int); + const uintptr_t aligned = unaligned - offset; + unsigned int *ptr = reinterpret_cast(aligned); + RHS newval = rhs1, oldval, other; + if (offset == 0) { + other = *((&rhs1) + 1); + do { + oldval = newval; + newval = newval + rhs2; + unsigned int const result = atomicCAS( + ptr, __hilohalf2uint(other, oldval), __hilohalf2uint(other, newval)); + newval = __uint2lohalf(result); + other = __uint2hihalf(result); + } while (oldval != newval); + } else { + other = *((&rhs1) - 1); + do { + oldval = newval; + newval = newval + rhs2; + unsigned int const result = atomicCAS( + ptr, __hilohalf2uint(oldval, other), __hilohalf2uint(newval, other)); + other = __uint2lohalf(result); + newval = __uint2hihalf(result); + } while (oldval != newval); + } +#endif +#else +#if defined(__cpp_lib_atomic_ref) && (__cpp_lib_atomic_ref >= 201806L) + std::atomic_ref atomic(rhs1); + RHS oldval = atomic.load(); + RHS newval; + do { + newval = oldval + rhs2; + } while (!atomic.compare_exchange_weak(oldval, newval)); +#else + // No atomic floating point operations so use compare 
and swap + TypePunning::Alias> oldval, newval; + TypePunning::AlignedPointer pointer((void *)&rhs1); + unsigned const offset = pointer.offset() / sizeof(__half); + do { + oldval.load(pointer); + std::array next = oldval.as_two(); + next[offset] = __convert_float_to_halfint( + __convert_halfint_to_float(next[offset]) + float(rhs2)); + newval = next; + } while (!__sync_bool_compare_and_swap( + (int32_t *)pointer, oldval.as_one(), newval.as_one())); +#endif +#endif +} + +template <> +__LEGION_CUDA_HD__ inline void SumReduction::apply(LHS &lhs, + RHS rhs) { + lhs += rhs; +} + +template <> +__LEGION_CUDA_HD__ inline void SumReduction::apply(LHS &lhs, + RHS rhs) { +#if defined(__CUDA_ARCH__) || defined(__HIP_DEVICE_COMPILE__) + atomicAdd(&lhs, rhs); +#else +#if defined(__cpp_lib_atomic_ref) && (__cpp_lib_atomic_ref >= 201806L) + std::atomic_ref atomic(lhs); + RHS oldval = atomic.load(); + RHS newval; + do { + newval = oldval + rhs; + } while (!atomic.compare_exchange_weak(oldval, newval)); +#else + // No atomic floating point operations so use compare and swap + TypePunning::Alias oldval, newval; + TypePunning::Pointer pointer((void *)&lhs); + do { + oldval.load(pointer); + newval = oldval.as_two() + rhs; + } while (!__sync_bool_compare_and_swap( + (int32_t *)pointer, oldval.as_one(), newval.as_one())); +#endif +#endif +} + +template <> +__LEGION_CUDA_HD__ inline void SumReduction::fold(RHS &rhs1, + RHS rhs2) { + rhs1 += rhs2; +} + +template <> +__LEGION_CUDA_HD__ inline void SumReduction::fold(RHS &rhs1, + RHS rhs2) { +#if defined(__CUDA_ARCH__) || defined(__HIP_DEVICE_COMPILE__) + atomicAdd(&rhs1, rhs2); +#else +#if defined(__cpp_lib_atomic_ref) && (__cpp_lib_atomic_ref >= 201806L) + std::atomic_ref atomic(rhs1); + RHS oldval = atomic.load(); + RHS newval; + do { + newval = oldval + rhs2; + } while (!atomic.compare_exchange_weak(oldval, newval)); +#else + // No atomic floating point operations so use compare and swap + TypePunning::Alias oldval, newval; + 
TypePunning::Pointer pointer((void *)&rhs1); + do { + oldval.load(pointer); + newval = oldval.as_two() + rhs2; + } while (!__sync_bool_compare_and_swap( + (int32_t *)pointer, oldval.as_one(), newval.as_one())); +#endif +#endif +} + +template <> +__LEGION_CUDA_HD__ inline void SumReduction::apply(LHS &lhs, + RHS rhs) { + lhs += rhs; +} + +template <> +__LEGION_CUDA_HD__ inline void SumReduction::apply(LHS &lhs, + RHS rhs) { +#if defined(__CUDA_ARCH__) || defined(__HIP_DEVICE_COMPILE__) +#if (__CUDA_ARCH__ >= 600) || defined(__HIP_DEVICE_COMPILE__) + atomicAdd(&lhs, rhs); +#else + RHS newval = lhs, oldval; + // Type punning like this is illegal in C++ but the + // CUDA manual has an example just like it so fuck it + unsigned long long int *ptr = (unsigned long long int *)&lhs; + do { + oldval = newval; + newval += rhs; + newval = __ulonglong_as_double(atomicCAS( + ptr, __double_as_ulonglong(oldval), __double_as_ulonglong(newval))); + } while (oldval != newval); +#endif +#else +#if defined(__cpp_lib_atomic_ref) && (__cpp_lib_atomic_ref >= 201806L) + std::atomic_ref atomic(lhs); + RHS oldval = atomic.load(); + RHS newval; + do { + newval = oldval + rhs; + } while (!atomic.compare_exchange_weak(oldval, newval)); +#else + // No atomic floating point operations so use compare and swap + TypePunning::Alias oldval, newval; + TypePunning::Pointer pointer((void *)&lhs); + do { + oldval.load(pointer); + newval = oldval.as_two() + rhs; + } while (!__sync_bool_compare_and_swap( + (int64_t *)pointer, oldval.as_one(), newval.as_one())); +#endif +#endif +} + +template <> +__LEGION_CUDA_HD__ inline void SumReduction::fold(RHS &rhs1, + RHS rhs2) { + rhs1 += rhs2; +} + +template <> +__LEGION_CUDA_HD__ inline void SumReduction::fold(RHS &rhs1, + RHS rhs2) { +#if defined(__CUDA_ARCH__) || defined(__HIP_DEVICE_COMPILE__) +#if (__CUDA_ARCH__ >= 600) || defined(__HIP_DEVICE_COMPILE__) + atomicAdd(&rhs1, rhs2); +#else + RHS newval = rhs1, oldval; + // Type punning like this is illegal in 
C++ but the + // CUDA manual has an example just like it so fuck it + unsigned long long int *ptr = (unsigned long long int *)&rhs1; + do { + oldval = newval; + newval += rhs2; + newval = __ulonglong_as_double(atomicCAS( + ptr, __double_as_ulonglong(oldval), __double_as_ulonglong(newval))); + } while (oldval != newval); +#endif +#else +#if defined(__cpp_lib_atomic_ref) && (__cpp_lib_atomic_ref >= 201806L) + std::atomic_ref atomic(rhs1); + RHS oldval = atomic.load(); + RHS newval; + do { + newval = oldval + rhs2; + } while (!atomic.compare_exchange_weak(oldval, newval)); +#else + // No atomic floating point operations so use compare and swap + TypePunning::Alias oldval, newval; + TypePunning::Pointer pointer((void *)&rhs1); + do { + oldval.load(pointer); + newval = oldval.as_two() + rhs2; + } while (!__sync_bool_compare_and_swap( + (int64_t *)pointer, oldval.as_one(), newval.as_one())); +#endif +#endif +} + +void Realm::Event register_all_redops(Realm::Runtime rt) { + // Registration is synchronous, so no need to capture events here + rt.register_reduction>( + get_realm_reduction_op_id_for_redop_id(redop_id_t::SUM_BOOL_REDOP_ID)); + rt.register_reduction>( + get_realm_reduction_op_id_for_redop_id(redop_id_t::SUM_INT32_REDOP_ID)); + rt.register_reduction>( + get_realm_reduction_op_id_for_redop_id(redop_id_t::SUM_INT64_REDOP_ID)); + rt.register_reduction>( + get_realm_reduction_op_id_for_redop_id(redop_id_t::SUM_HALF_REDOP_ID)); + rt.register_reduction>( + get_realm_reduction_op_id_for_redop_id(redop_id_t::SUM_FLOAT_REDOP_ID)); + rt.register_reduction>( + get_realm_reduction_op_id_for_redop_id(redop_id_t::SUM_DOUBLE_REDOP_ID)); +} + +} // namespace FlexFlow diff --git a/lib/realm-execution/src/realm-execution/redops/redop_id_t.cc b/lib/realm-execution/src/realm-execution/redops/redop_id_t.cc new file mode 100644 index 0000000000..702ddd5e97 --- /dev/null +++ b/lib/realm-execution/src/realm-execution/redops/redop_id_t.cc @@ -0,0 +1,32 @@ +#include 
"realm-execution/redops/redop_id_t.h" + +namespace FlexFlow { + +Realm::ReductionOpID get_sum_redop_id_for_data_type(DataType) { + + switch (dtype) { + case DataType::BOOL: + return redop_id_t::SUM_BOOL_REDOP_ID; + case DataType::INT32: + return redop_id_t::SUM_INT32_REDOP_ID; + case DataType::INT64: + return redop_id_t::SUM_INT64_REDOP_ID; + case DataType::HALF: + return redop_id_t::SUM_HALF_REDOP_ID; + case DataType::FLOAT: + return redop_id_t::SUM_FLOAT_REDOP_ID; + case DataType::DOUBLE: + return redop_id_t::SUM_DOUBLE_REDOP_ID; + default: + PANIC("No known sum reduction for data type {}", dtype); + } +} + +Realm::Processor::ReductionOpID + get_realm_reduction_op_id_for_redop_id(redop_id_t redop_id) { + return static_cast(redop_id); +} + +} + +} // namespace FlexFlow diff --git a/lib/realm-execution/src/realm-execution/tasks/realm_task_registry.cc b/lib/realm-execution/src/realm-execution/tasks/realm_task_registry.cc index acafdf59fd..e7a8948f8d 100644 --- a/lib/realm-execution/src/realm-execution/tasks/realm_task_registry.cc +++ b/lib/realm-execution/src/realm-execution/tasks/realm_task_registry.cc @@ -5,7 +5,6 @@ #include "realm-execution/tasks/impl/op_task.h" #include "realm-execution/tasks/impl/per_device_op_state_init_return_task.h" #include "realm-execution/tasks/impl/per_device_op_state_init_task.h" -#include "realm-execution/tasks/realm_reduction.h" #include "realm-execution/tasks/task_id_t.h" #include "utils/exception.h" @@ -31,18 +30,9 @@ Realm::Event register_task(Realm::Processor::Kind target_kind, Realm::ProfilingRequestSet()); } -static void register_reductions() { - // register sum reduction ops - Realm::Runtime rt = Realm::Runtime::get_runtime(); - rt.register_reduction(REDOP_SUM_FLOAT); - rt.register_reduction(REDOP_SUM_DOUBLE); - // register_reduction is synchronous — no event returned -} - Realm::Event register_all_tasks() { std::vector pending_registrations; - register_reductions(); std::vector init_task_ids = { // Init tasks 
task_id_t::BATCHNORM_INIT_TASK_ID, From 0d589f2056022d1068deb2d6abf745e3dcc048cd Mon Sep 17 00:00:00 2001 From: Elliott Slaughter Date: Thu, 21 May 2026 12:39:41 -0700 Subject: [PATCH 11/14] Fix build for reductions. --- .../redops/realm_redop_registry.h | 6 +- .../redops/redop_id_t.dtg.toml | 3 - .../realm-execution/redops/redop_id_t.h | 12 +- .../src/realm-execution/pcg_instance.cc | 7 +- .../src/realm-execution/realm_manager.cc | 2 +- .../redops/realm_redop_registry.cc | 165 +----------------- .../src/realm-execution/redops/redop_id_t.cc | 12 +- 7 files changed, 26 insertions(+), 181 deletions(-) diff --git a/lib/realm-execution/include/realm-execution/redops/realm_redop_registry.h b/lib/realm-execution/include/realm-execution/redops/realm_redop_registry.h index a338a38bbf..e7e51326e1 100644 --- a/lib/realm-execution/include/realm-execution/redops/realm_redop_registry.h +++ b/lib/realm-execution/include/realm-execution/redops/realm_redop_registry.h @@ -1,5 +1,5 @@ -#ifndef _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_TASKS_REALM_REDOP_REGISTRY_H -#define _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_TASKS_REALM_REDOP_REGISTRY_H +#ifndef _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_REDOPS_REALM_REDOP_REGISTRY_H +#define _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_REDOPS_REALM_REDOP_REGISTRY_H #include "realm-execution/realm.h" #include "realm-execution/redops/redop_id_t.dtg.h" @@ -9,7 +9,7 @@ namespace FlexFlow { /** * \brief Registers all known reduction operators (redops). 
*/ -void Realm::Event register_all_redops(Realm::Runtime); +void register_all_redops(Realm::Runtime); } // namespace FlexFlow diff --git a/lib/realm-execution/include/realm-execution/redops/redop_id_t.dtg.toml b/lib/realm-execution/include/realm-execution/redops/redop_id_t.dtg.toml index 5183ff5e72..44e1f32c59 100644 --- a/lib/realm-execution/include/realm-execution/redops/redop_id_t.dtg.toml +++ b/lib/realm-execution/include/realm-execution/redops/redop_id_t.dtg.toml @@ -20,9 +20,6 @@ name = "SUM_INT32_REDOP_ID" [[values]] name = "SUM_INT64_REDOP_ID" -[[values]] -name = "SUM_HALF_REDOP_ID" - [[values]] name = "SUM_FLOAT_REDOP_ID" diff --git a/lib/realm-execution/include/realm-execution/redops/redop_id_t.h b/lib/realm-execution/include/realm-execution/redops/redop_id_t.h index b9ef91a05a..8565b20b17 100644 --- a/lib/realm-execution/include/realm-execution/redops/redop_id_t.h +++ b/lib/realm-execution/include/realm-execution/redops/redop_id_t.h @@ -1,21 +1,21 @@ -#ifndef _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_TASKS_REALM_REDOP_REGISTRY_H -#define _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_TASKS_REALM_REDOP_REGISTRY_H +#ifndef _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_REDOPS_REALM_REDOP_ID_T_H +#define _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_REDOPS_REALM_REDOP_ID_T_H +#include "op-attrs/datatype.dtg.h" #include "realm-execution/realm.h" #include "realm-execution/redops/redop_id_t.dtg.h" namespace FlexFlow { /** - * \brief Registers all known reduction operators (redops). + * \brief Return the sum reduction operator (redop) ID for a given data type. */ -Realm::ReductionOpID get_sum_redop_id_for_data_type(DataType); +redop_id_t get_sum_redop_id_for_data_type(DataType); /** * \brief Convert a \ref FlexFlow::redop_id_t into a Realm reduction op ID. 
*/ -Realm::Processor::ReductionOpID - get_realm_reduction_op_id_for_redop_id(redop_id_t); +Realm::ReductionOpID get_realm_reduction_op_id_for_redop_id(redop_id_t); } // namespace FlexFlow diff --git a/lib/realm-execution/src/realm-execution/pcg_instance.cc b/lib/realm-execution/src/realm-execution/pcg_instance.cc index 332669a9dc..1ac3821142 100644 --- a/lib/realm-execution/src/realm-execution/pcg_instance.cc +++ b/lib/realm-execution/src/realm-execution/pcg_instance.cc @@ -5,8 +5,8 @@ #include "realm-execution/distributed_per_device_op_state_initialization.h" #include "realm-execution/instance_allocation.h" #include "realm-execution/realm_context.h" +#include "realm-execution/redops/redop_id_t.h" #include "realm-execution/tasks/impl/op_task.h" -#include "realm-execution/tasks/realm_reduction.h" #include "realm-execution/tensor_instance_backing.h" #include "task-spec/dynamic_graph/copy_insertion.h" #include "task-spec/dynamic_graph/dynamic_node_invocation.dtg.h" @@ -228,8 +228,9 @@ static Realm::Event spawn_dynamic_node_invocation( Realm::RegionInstance dst_inst = tensor_instance_backing.backing.at(input_grad).first; - Realm::ReductionOpID redop_id = get_sum_reduction_op_id( - assert_unwrap(output_grad.parallel_tensor_shape).data_type); + Realm::ReductionOpID redop_id = + get_realm_reduction_op_id_for_redop_id(get_sum_redop_id_for_data_type( + assert_unwrap(output_grad.parallel_tensor_shape).data_type)); // chain reductions sequentially to avoid write races on dst Realm::Event e = precondition; diff --git a/lib/realm-execution/src/realm-execution/realm_manager.cc b/lib/realm-execution/src/realm-execution/realm_manager.cc index c7136d8a98..5a8f9cbbbb 100644 --- a/lib/realm-execution/src/realm-execution/realm_manager.cc +++ b/lib/realm-execution/src/realm-execution/realm_manager.cc @@ -1,7 +1,7 @@ #include "realm-execution/realm_manager.h" #include "realm-execution/realm_context.h" -#include "realm-execution/tasks/realm_task_registry.h" #include 
"realm-execution/redops/realm_redop_registry.h" +#include "realm-execution/tasks/realm_task_registry.h" namespace FlexFlow { diff --git a/lib/realm-execution/src/realm-execution/redops/realm_redop_registry.cc b/lib/realm-execution/src/realm-execution/redops/realm_redop_registry.cc index d10b158463..ab3304836a 100644 --- a/lib/realm-execution/src/realm-execution/redops/realm_redop_registry.cc +++ b/lib/realm-execution/src/realm-execution/redops/realm_redop_registry.cc @@ -1,4 +1,5 @@ #include "realm-execution/redops/realm_redop_registry.h" +#include "realm-execution/redops/redop_id_t.h" namespace FlexFlow { @@ -120,6 +121,12 @@ class Alias { #define __LEGION_CUDA_HD__ #endif +template +class SumReduction { + // Empty definition + // Specializations provided for each type +}; + template <> class SumReduction { public: @@ -127,7 +134,6 @@ class SumReduction { typedef bool RHS; static constexpr bool identity = false; - static constexpr int REDOP_ID = LEGION_REDOP_OR_BOOL; template __LEGION_CUDA_HD__ static void apply(LHS &lhs, RHS rhs); @@ -142,7 +148,6 @@ class SumReduction { typedef int32_t RHS; static constexpr int32_t identity = 0; - static constexpr int REDOP_ID = LEGION_REDOP_SUM_INT32; template __LEGION_CUDA_HD__ static void apply(LHS &lhs, RHS rhs); @@ -157,22 +162,6 @@ class SumReduction { typedef int64_t RHS; static constexpr int64_t identity = 0; - static constexpr int REDOP_ID = LEGION_REDOP_SUM_INT64; - - template - __LEGION_CUDA_HD__ static void apply(LHS &lhs, RHS rhs); - template - __LEGION_CUDA_HD__ static void fold(RHS &rhs1, RHS rhs2); -}; - -template <> -class SumReduction<__half> { -public: - typedef __half LHS; - typedef __half RHS; - - static inline const __half identity = __half(0, false /*raw*/); - static constexpr int REDOP_ID = LEGION_REDOP_SUM_FLOAT16; template __LEGION_CUDA_HD__ static void apply(LHS &lhs, RHS rhs); @@ -187,7 +176,6 @@ class SumReduction { typedef float RHS; static constexpr float identity = 0.f; - static constexpr int 
REDOP_ID = LEGION_REDOP_SUM_FLOAT32; template __LEGION_CUDA_HD__ static void apply(LHS &lhs, RHS rhs); @@ -202,7 +190,6 @@ class SumReduction { typedef double RHS; static constexpr double identity = 0.0; - static constexpr int REDOP_ID = LEGION_REDOP_SUM_FLOAT64; template __LEGION_CUDA_HD__ static void apply(LHS &lhs, RHS rhs); @@ -382,140 +369,6 @@ __LEGION_CUDA_HD__ inline void SumReduction::fold(RHS &rhs1, #endif } -template <> -__LEGION_CUDA_HD__ inline void SumReduction<__half>::apply(LHS &lhs, - RHS rhs) { - lhs = lhs + rhs; -} - -template <> -__LEGION_CUDA_HD__ inline void SumReduction<__half>::apply(LHS &lhs, - RHS rhs) { -#if defined(__CUDA_ARCH__) || defined(__HIP_DEVICE_COMPILE__) -#if (__CUDA_ARCH__ >= 700) && (__CUDACC_VER_MAJOR__ >= 10) - atomicAdd(&lhs, rhs); -#else - // 16-bit atomics are not supported prior to volta - // 32-bit GPU atomics need 4 byte alignment - const uintptr_t unaligned = reinterpret_cast(&lhs); - unsigned const offset = unaligned % sizeof(unsigned int); - const uintptr_t aligned = unaligned - offset; - unsigned int *ptr = reinterpret_cast(aligned); - RHS newval = lhs, oldval, other; - if (offset == 0) { - other = *((&lhs) + 1); - do { - oldval = newval; - newval = newval + rhs; - unsigned int const result = atomicCAS( - ptr, __hilohalf2uint(other, oldval), __hilohalf2uint(other, newval)); - newval = __uint2lohalf(result); - other = __uint2hihalf(result); - } while (oldval != newval); - } else { - other = *((&lhs) - 1); - do { - oldval = newval; - newval = newval + rhs; - unsigned int const result = atomicCAS( - ptr, __hilohalf2uint(oldval, other), __hilohalf2uint(newval, other)); - other = __uint2lohalf(result); - newval = __uint2hihalf(result); - } while (oldval != newval); - } -#endif -#else -#if defined(__cpp_lib_atomic_ref) && (__cpp_lib_atomic_ref >= 201806L) - std::atomic_ref atomic(lhs); - RHS oldval = atomic.load(); - RHS newval; - do { - newval = oldval + rhs; - } while (!atomic.compare_exchange_weak(oldval, newval)); 
-#else - // No atomic floating point operations so use compare and swap - TypePunning::Alias> oldval, newval; - TypePunning::AlignedPointer pointer((void *)&lhs); - unsigned const offset = pointer.offset() / sizeof(__half); - do { - oldval.load(pointer); - std::array next = oldval.as_two(); - next[offset] = __convert_float_to_halfint( - __convert_halfint_to_float(next[offset]) + float(rhs)); - newval = next; - } while (!__sync_bool_compare_and_swap( - (int32_t *)pointer, oldval.as_one(), newval.as_one())); -#endif -#endif -} - -template <> -__LEGION_CUDA_HD__ inline void SumReduction<__half>::fold(RHS &rhs1, - RHS rhs2) { - rhs1 = rhs1 + rhs2; -} - -template <> -__LEGION_CUDA_HD__ inline void SumReduction<__half>::fold(RHS &rhs1, - RHS rhs2) { -#if defined(__CUDA_ARCH__) || defined(__HIP_DEVICE_COMPILE__) -#if (__CUDA_ARCH__ >= 700) && (__CUDACC_VER_MAJOR__ >= 10) - atomicAdd(&rhs1, rhs2); -#else - // 16-bit atomics are not supported prior to volta - // 32-bit GPU atomics need 4 byte alignment - const uintptr_t unaligned = reinterpret_cast(&rhs1); - unsigned const offset = unaligned % sizeof(unsigned int); - const uintptr_t aligned = unaligned - offset; - unsigned int *ptr = reinterpret_cast(aligned); - RHS newval = rhs1, oldval, other; - if (offset == 0) { - other = *((&rhs1) + 1); - do { - oldval = newval; - newval = newval + rhs2; - unsigned int const result = atomicCAS( - ptr, __hilohalf2uint(other, oldval), __hilohalf2uint(other, newval)); - newval = __uint2lohalf(result); - other = __uint2hihalf(result); - } while (oldval != newval); - } else { - other = *((&rhs1) - 1); - do { - oldval = newval; - newval = newval + rhs2; - unsigned int const result = atomicCAS( - ptr, __hilohalf2uint(oldval, other), __hilohalf2uint(newval, other)); - other = __uint2lohalf(result); - newval = __uint2hihalf(result); - } while (oldval != newval); - } -#endif -#else -#if defined(__cpp_lib_atomic_ref) && (__cpp_lib_atomic_ref >= 201806L) - std::atomic_ref atomic(rhs1); - RHS 
oldval = atomic.load(); - RHS newval; - do { - newval = oldval + rhs2; - } while (!atomic.compare_exchange_weak(oldval, newval)); -#else - // No atomic floating point operations so use compare and swap - TypePunning::Alias> oldval, newval; - TypePunning::AlignedPointer pointer((void *)&rhs1); - unsigned const offset = pointer.offset() / sizeof(__half); - do { - oldval.load(pointer); - std::array next = oldval.as_two(); - next[offset] = __convert_float_to_halfint( - __convert_halfint_to_float(next[offset]) + float(rhs2)); - newval = next; - } while (!__sync_bool_compare_and_swap( - (int32_t *)pointer, oldval.as_one(), newval.as_one())); -#endif -#endif -} - template <> __LEGION_CUDA_HD__ inline void SumReduction::apply(LHS &lhs, RHS rhs) { @@ -670,7 +523,7 @@ __LEGION_CUDA_HD__ inline void SumReduction::fold(RHS &rhs1, #endif } -void Realm::Event register_all_redops(Realm::Runtime rt) { +void register_all_redops(Realm::Runtime rt) { // Registration is synchronous, so no need to capture events here rt.register_reduction>( get_realm_reduction_op_id_for_redop_id(redop_id_t::SUM_BOOL_REDOP_ID)); @@ -678,8 +531,6 @@ void Realm::Event register_all_redops(Realm::Runtime rt) { get_realm_reduction_op_id_for_redop_id(redop_id_t::SUM_INT32_REDOP_ID)); rt.register_reduction>( get_realm_reduction_op_id_for_redop_id(redop_id_t::SUM_INT64_REDOP_ID)); - rt.register_reduction>( - get_realm_reduction_op_id_for_redop_id(redop_id_t::SUM_HALF_REDOP_ID)); rt.register_reduction>( get_realm_reduction_op_id_for_redop_id(redop_id_t::SUM_FLOAT_REDOP_ID)); rt.register_reduction>( diff --git a/lib/realm-execution/src/realm-execution/redops/redop_id_t.cc b/lib/realm-execution/src/realm-execution/redops/redop_id_t.cc index 702ddd5e97..f31769419f 100644 --- a/lib/realm-execution/src/realm-execution/redops/redop_id_t.cc +++ b/lib/realm-execution/src/realm-execution/redops/redop_id_t.cc @@ -1,9 +1,9 @@ #include "realm-execution/redops/redop_id_t.h" +#include "utils/exception.h" namespace FlexFlow { 
-Realm::ReductionOpID get_sum_redop_id_for_data_type(DataType) { - +redop_id_t get_sum_redop_id_for_data_type(DataType dtype) { switch (dtype) { case DataType::BOOL: return redop_id_t::SUM_BOOL_REDOP_ID; @@ -11,8 +11,6 @@ Realm::ReductionOpID get_sum_redop_id_for_data_type(DataType) { return redop_id_t::SUM_INT32_REDOP_ID; case DataType::INT64: return redop_id_t::SUM_INT64_REDOP_ID; - case DataType::HALF: - return redop_id_t::SUM_HALF_REDOP_ID; case DataType::FLOAT: return redop_id_t::SUM_FLOAT_REDOP_ID; case DataType::DOUBLE: @@ -22,11 +20,9 @@ Realm::ReductionOpID get_sum_redop_id_for_data_type(DataType) { } } -Realm::Processor::ReductionOpID +Realm::ReductionOpID get_realm_reduction_op_id_for_redop_id(redop_id_t redop_id) { - return static_cast(redop_id); -} - + return static_cast(redop_id); } } // namespace FlexFlow From 5ef0b070ad9cf21cbae6346e8294074fb81e74ad Mon Sep 17 00:00:00 2001 From: Elliott Slaughter Date: Thu, 21 May 2026 14:45:34 -0700 Subject: [PATCH 12/14] Split reduction from copy and put back device op state init code. 
--- .../include/realm-execution/realm_context.h | 29 ++-- ...uted_per_device_op_state_initialization.cc | 6 +- .../src/realm-execution/pcg_instance.cc | 19 ++- .../src/realm-execution/realm_context.cc | 127 ++++++++++++------ 4 files changed, 112 insertions(+), 69 deletions(-) diff --git a/lib/realm-execution/include/realm-execution/realm_context.h b/lib/realm-execution/include/realm-execution/realm_context.h index eab42d0d79..5b76d52e2c 100644 --- a/lib/realm-execution/include/realm-execution/realm_context.h +++ b/lib/realm-execution/include/realm-execution/realm_context.h @@ -9,6 +9,7 @@ #include "pcg/device_id_t.dtg.h" #include "pcg/machine_space_coordinate.dtg.h" #include "realm-execution/realm.h" +#include "realm-execution/redops/redop_id_t.dtg.h" #include "realm-execution/tasks/task_id_t.dtg.h" #include #include @@ -65,16 +66,24 @@ struct RealmContext { /** \name Data movement and reduction */ ///\{ - Realm::Event - issue_copy(ParallelTensorShape const &src_shape, - Realm::RegionInstance src_inst, - ParallelTensorShape const &dst_shape, - Realm::RegionInstance dst_inst, - Realm::ProfilingRequestSet const &requests, - Realm::Event wait_on = Realm::Event::NO_EVENT, - int priority = 0, - std::optional redop_id = std::nullopt, - bool exclusive = false); + Realm::Event issue_copy(ParallelTensorShape const &src_shape, + Realm::RegionInstance src_inst, + ParallelTensorShape const &dst_shape, + Realm::RegionInstance dst_inst, + Realm::ProfilingRequestSet const &requests, + Realm::Event wait_on = Realm::Event::NO_EVENT, + int priority = 0); + + Realm::Event issue_reduction(ParallelTensorShape const &src_shape, + Realm::RegionInstance src_inst, + ParallelTensorShape const &dst_shape, + Realm::RegionInstance dst_inst, + redop_id_t redop_id, + bool is_fold, + bool exclusive, + Realm::ProfilingRequestSet const &requests, + Realm::Event wait_on = Realm::Event::NO_EVENT, + int priority = 0); ///\} /** \name Instance management */ diff --git 
a/lib/realm-execution/src/realm-execution/distributed_per_device_op_state_initialization.cc b/lib/realm-execution/src/realm-execution/distributed_per_device_op_state_initialization.cc index e7d8647b12..1d517a8fe4 100644 --- a/lib/realm-execution/src/realm-execution/distributed_per_device_op_state_initialization.cc +++ b/lib/realm-execution/src/realm-execution/distributed_per_device_op_state_initialization.cc @@ -31,7 +31,6 @@ PerDeviceOpStateBacking perform_distributed_per_device_op_state_initialization( std::unordered_map *> device_state_map; - std::vector completion_events; for (DynamicNodeInvocation const &invocation : dg.invocations) { Realm::Processor target_proc = ctx.map_device_coord_to_processor( assert_unwrap(invocation.node_attrs.device_coord)); @@ -57,7 +56,6 @@ PerDeviceOpStateBacking perform_distributed_per_device_op_state_initialization( precondition); if (completion_event.has_value()) { - completion_events.push_back(completion_event.value()); device_state_map.insert(std::pair{invocation, device_state_ptr}); } else { // Task doesn't require initialization, clean up and don't store result @@ -65,9 +63,7 @@ PerDeviceOpStateBacking perform_distributed_per_device_op_state_initialization( } } - // wait for all init tasks — direct write to *result_ptr happens - // before each init task event fires so result is ready after this - Realm::Event::merge_events(completion_events).wait(); + ctx.get_outstanding_events().wait(); auto deref = [](DeviceSpecificPtr *const &p) { return *p; }; std::unordered_map> diff --git a/lib/realm-execution/src/realm-execution/pcg_instance.cc b/lib/realm-execution/src/realm-execution/pcg_instance.cc index 1ac3821142..aa67110127 100644 --- a/lib/realm-execution/src/realm-execution/pcg_instance.cc +++ b/lib/realm-execution/src/realm-execution/pcg_instance.cc @@ -228,12 +228,11 @@ static Realm::Event spawn_dynamic_node_invocation( Realm::RegionInstance dst_inst = tensor_instance_backing.backing.at(input_grad).first; - 
Realm::ReductionOpID redop_id = - get_realm_reduction_op_id_for_redop_id(get_sum_redop_id_for_data_type( - assert_unwrap(output_grad.parallel_tensor_shape).data_type)); + redop_id_t redop_id = get_sum_redop_id_for_data_type( + assert_unwrap(output_grad.parallel_tensor_shape).data_type); // chain reductions sequentially to avoid write races on dst - Realm::Event e = precondition; + Realm::Event result = precondition; for (auto const &[p, m] : assert_unwrap(output_grad.mapping)) { DynamicValueAttrs replica_key = output_grad; replica_key.mapping = @@ -243,18 +242,18 @@ static Realm::Event spawn_dynamic_node_invocation( Realm::RegionInstance src_inst = tensor_instance_backing.backing.at(replica_key).first; - e = ctx.issue_copy( + result = ctx.issue_reduction( /*src_shape=*/assert_unwrap(output_grad.parallel_tensor_shape), /*src_inst=*/src_inst, /*dst_shape=*/assert_unwrap(input_grad.parallel_tensor_shape), /*dst_inst=*/dst_inst, - /*requests=*/Realm::ProfilingRequestSet{}, - /*wait_on=*/e, - /*priority=*/0, /*redop_id=*/redop_id, - /*exlusive=*/false); + /*is_fold=*/false, + /*exclusive=*/false, + /*requests=*/Realm::ProfilingRequestSet{}, + /*wait_on=*/result); } - return e; + return result; }; TrainingOperationAttrs op_attrs = diff --git a/lib/realm-execution/src/realm-execution/realm_context.cc b/lib/realm-execution/src/realm-execution/realm_context.cc index a4669bf43e..36dd7c71cc 100644 --- a/lib/realm-execution/src/realm-execution/realm_context.cc +++ b/lib/realm-execution/src/realm-execution/realm_context.cc @@ -7,6 +7,7 @@ #include "pcg/device_id_t.h" #include "pcg/device_type.dtg.h" #include "realm-execution/realm_allocator.h" +#include "realm-execution/redops/redop_id_t.h" #include "realm-execution/tasks/task_id_t.dtg.h" #include "realm-execution/tasks/task_id_t.h" #include "utils/containers/contains_key.h" @@ -154,6 +155,46 @@ static Realm::IndexSpace ispace_from_dims(TensorDims const &dims) { return Realm::IndexSpace{rect}; } +[[nodiscard]] static
Realm::Event + issue_copy_for_field(TensorDims const &dims, + Realm::CopySrcDstField const &src_field, + Realm::CopySrcDstField const &dst_field, + Realm::ProfilingRequestSet const &requests, + Realm::Event wait_on, + int priority) { + switch (dims.ff_ordered.num_dims()) { +#if REALM_MAX_DIM >= 1 + case 1: + return ispace_from_dims<1>(dims).copy( + {src_field}, {dst_field}, requests, wait_on, priority); +#endif +#if REALM_MAX_DIM >= 2 + case 2: + return ispace_from_dims<2>(dims).copy( + {src_field}, {dst_field}, requests, wait_on, priority); +#endif +#if REALM_MAX_DIM >= 3 + case 3: + return ispace_from_dims<3>(dims).copy( + {src_field}, {dst_field}, requests, wait_on, priority); +#endif +#if REALM_MAX_DIM >= 4 + case 4: + return ispace_from_dims<4>(dims).copy( + {src_field}, {dst_field}, requests, wait_on, priority); +#endif +#if REALM_MAX_DIM >= 5 + case 5: + return ispace_from_dims<5>(dims).copy( + {src_field}, {dst_field}, requests, wait_on, priority); +#endif + default: + PANIC("TensorShape dims greater than REALM_MAX_DIM: {}", + dims.ff_ordered.num_dims()); + break; + } +} + Realm::Event RealmContext::issue_copy(ParallelTensorShape const &src_shape, Realm::RegionInstance src_inst, @@ -161,9 +202,7 @@ Realm::Event Realm::RegionInstance dst_inst, Realm::ProfilingRequestSet const &requests, Realm::Event wait_on, - int priority, - std::optional redop_id, - bool exclusive) { + int priority) { TensorShape src_piece_shape = get_piece_shape(src_shape); TensorShape dst_piece_shape = get_piece_shape(dst_shape); ASSERT(src_piece_shape == dst_piece_shape); // For now, assume they match @@ -185,48 +224,48 @@ Realm::Event size_of_datatype(src_piece_shape.data_type).int_from_positive_int()), /*subfield_offset=*/0); - // set reduction op on dst field if provided - if (redop_id.has_value()) { - dst_field.set_redop(redop_id.value(), /*is_fold=*/false, exclusive); - } + Realm::Event result = issue_copy_for_field( + src_piece_shape.dims, src_field, dst_field, requests, wait_on, 
priority); + this->outstanding_events.push_back(result); + return result; +} - Realm::Event result; - switch (src_piece_shape.dims.ff_ordered.num_dims()) { -#if REALM_MAX_DIM >= 1 - case 1: - result = ispace_from_dims<1>(src_piece_shape.dims) - .copy({src_field}, {dst_field}, requests, wait_on, priority); - break; -#endif -#if REALM_MAX_DIM >= 2 - case 2: - result = ispace_from_dims<2>(src_piece_shape.dims) - .copy({src_field}, {dst_field}, requests, wait_on, priority); - break; -#endif -#if REALM_MAX_DIM >= 3 - case 3: - result = ispace_from_dims<3>(src_piece_shape.dims) - .copy({src_field}, {dst_field}, requests, wait_on, priority); - break; -#endif -#if REALM_MAX_DIM >= 4 - case 4: - result = ispace_from_dims<4>(src_piece_shape.dims) - .copy({src_field}, {dst_field}, requests, wait_on, priority); - break; -#endif -#if REALM_MAX_DIM >= 5 - case 5: - result = ispace_from_dims<5>(src_piece_shape.dims) - .copy({src_field}, {dst_field}, requests, wait_on, priority); - break; -#endif - default: - PANIC("TensorShape dims greater than REALM_MAX_DIM: {}", - src_piece_shape.dims.ff_ordered.num_dims()); - break; - } +Realm::Event + RealmContext::issue_reduction(ParallelTensorShape const &src_shape, + Realm::RegionInstance src_inst, + ParallelTensorShape const &dst_shape, + Realm::RegionInstance dst_inst, + redop_id_t redop_id, + bool is_fold, + bool exclusive, + Realm::ProfilingRequestSet const &requests, + Realm::Event wait_on, + int priority) { + TensorShape src_piece_shape = get_piece_shape(src_shape); + TensorShape dst_piece_shape = get_piece_shape(dst_shape); + ASSERT(src_piece_shape == dst_piece_shape); // For now, assume they match + + Realm::CopySrcDstField src_field; + src_field.set_field( + /*inst=*/src_inst, + /*field_id=*/0, + /*size=*/ + static_cast( + size_of_datatype(src_piece_shape.data_type).int_from_positive_int()), + /*subfield_offset=*/0); + Realm::CopySrcDstField dst_field; + dst_field.set_field( + /*inst=*/dst_inst, + /*field_id=*/0, + /*size=*/ + 
static_cast( + size_of_datatype(src_piece_shape.data_type).int_from_positive_int()), + /*subfield_offset=*/0); + dst_field.set_redop( + get_realm_reduction_op_id_for_redop_id(redop_id), is_fold, exclusive); + + Realm::Event result = issue_copy_for_field( + src_piece_shape.dims, src_field, dst_field, requests, wait_on, priority); this->outstanding_events.push_back(result); return result; } From b73751727c77c92358624fb9bd9997e3f77e73c3 Mon Sep 17 00:00:00 2001 From: Elliott Slaughter Date: Thu, 21 May 2026 14:48:55 -0700 Subject: [PATCH 13/14] Replicate is not a task, don't represent it as one. --- .../include/realm-execution/tasks/task_id_t.dtg.toml | 9 --------- .../src/realm-execution/tasks/realm_task_registry.cc | 3 --- 2 files changed, 12 deletions(-) diff --git a/lib/realm-execution/include/realm-execution/tasks/task_id_t.dtg.toml b/lib/realm-execution/include/realm-execution/tasks/task_id_t.dtg.toml index b1e5e07e28..b0bcc23b4d 100644 --- a/lib/realm-execution/include/realm-execution/tasks/task_id_t.dtg.toml +++ b/lib/realm-execution/include/realm-execution/tasks/task_id_t.dtg.toml @@ -327,15 +327,6 @@ name = "COMBINE_FWD_TASK_ID" [[values]] name = "COMBINE_BWD_TASK_ID" -[[values]] -name = "REPLICATE_INIT_TASK_ID" - -[[values]] -name = "REPLICATE_FWD_TASK_ID" - -[[values]] -name = "REPLICATE_BWD_TASK_ID" - [[values]] name = "REDUCTION_INIT_TASK_ID" diff --git a/lib/realm-execution/src/realm-execution/tasks/realm_task_registry.cc b/lib/realm-execution/src/realm-execution/tasks/realm_task_registry.cc index e7a8948f8d..dfdfe72ce0 100644 --- a/lib/realm-execution/src/realm-execution/tasks/realm_task_registry.cc +++ b/lib/realm-execution/src/realm-execution/tasks/realm_task_registry.cc @@ -49,7 +49,6 @@ Realm::Event register_all_tasks() { task_id_t::REDUCE_INIT_TASK_ID, task_id_t::REDUCTION_INIT_TASK_ID, task_id_t::REPARTITION_INIT_TASK_ID, - task_id_t::REPLICATE_INIT_TASK_ID, task_id_t::SOFTMAX_INIT_TASK_ID, }; @@ -86,7 +85,6 @@ Realm::Event register_all_tasks() { 
task_id_t::REDUCE_FWD_TASK_ID, task_id_t::REDUCTION_FWD_TASK_ID, task_id_t::REPARTITION_FWD_TASK_ID, - task_id_t::REPLICATE_FWD_TASK_ID, task_id_t::RESHAPE_FWD_TASK_ID, task_id_t::REVERSE_FWD_TASK_ID, task_id_t::SOFTMAX_FWD_TASK_ID, @@ -115,7 +113,6 @@ Realm::Event register_all_tasks() { task_id_t::REDUCE_BWD_TASK_ID, task_id_t::REDUCTION_BWD_TASK_ID, task_id_t::REPARTITION_BWD_TASK_ID, - task_id_t::REPLICATE_BWD_TASK_ID, task_id_t::RESHAPE_BWD_TASK_ID, task_id_t::REVERSE_BWD_TASK_ID, task_id_t::SOFTMAX_BWD_TASK_ID, From c536cc96bfd5b4b88739dcbcfcab9f00a96a7fed Mon Sep 17 00:00:00 2001 From: Elliott Slaughter Date: Thu, 21 May 2026 14:49:47 -0700 Subject: [PATCH 14/14] Put back the per device op state return code path. --- .../tasks/impl/per_device_op_state_init_task.cc | 16 +++++----------- 1 file changed, 5 insertions(+), 11 deletions(-) diff --git a/lib/realm-execution/src/realm-execution/tasks/impl/per_device_op_state_init_task.cc b/lib/realm-execution/src/realm-execution/tasks/impl/per_device_op_state_init_task.cc index 0ea51810e4..753fccf74b 100644 --- a/lib/realm-execution/src/realm-execution/tasks/impl/per_device_op_state_init_task.cc +++ b/lib/realm-execution/src/realm-execution/tasks/impl/per_device_op_state_init_task.cc @@ -66,17 +66,11 @@ void per_device_op_state_init_task_body(void const *args, result_state, ctx.get_current_device_idx())}; DeviceSpecificPtr result_device_specific{ ctx.get_current_device_idx(), result_state_ptr}; - - // replace spawn_per_device_op_state_init_return_task with: - // NOTE: SM/TODO: direct write assumes single-node shared address space - // For multi-node, replace with UserEvent trigger pattern - *task_args.origin_result_ptr = result_device_specific; - - // spawn_per_device_op_state_init_return_task(ctx, - // task_args.origin_proc, - // result_device_specific, - // task_args.origin_result_ptr, - // Realm::Event::NO_EVENT); + spawn_per_device_op_state_init_return_task(ctx, + task_args.origin_proc, + result_device_specific, + 
task_args.origin_result_ptr, + Realm::Event::NO_EVENT); } std::optional spawn_per_device_op_state_init_task(