diff --git a/lib/op-attrs/include/op-attrs/pcg_operator_attrs.dtg.toml b/lib/op-attrs/include/op-attrs/pcg_operator_attrs.dtg.toml index 88a65f75c5..f2dd7c9350 100644 --- a/lib/op-attrs/include/op-attrs/pcg_operator_attrs.dtg.toml +++ b/lib/op-attrs/include/op-attrs/pcg_operator_attrs.dtg.toml @@ -11,13 +11,13 @@ features = [ ] includes = [ - "op-attrs/ops/attention_attrs.dtg.h", - "op-attrs/ops/batch_matmul_attrs.dtg.h", - "op-attrs/ops/batch_norm_attrs.dtg.h", - "op-attrs/ops/broadcast_attrs.dtg.h", - "op-attrs/ops/cast_attrs.dtg.h", - "op-attrs/ops/combine_attrs.dtg.h", - "op-attrs/ops/concat_attrs.dtg.h", + "op-attrs/ops/attention_attrs.dtg.h", + "op-attrs/ops/batch_matmul_attrs.dtg.h", + "op-attrs/ops/batch_norm_attrs.dtg.h", + "op-attrs/ops/broadcast_attrs.dtg.h", + "op-attrs/ops/cast_attrs.dtg.h", + "op-attrs/ops/combine_attrs.dtg.h", + "op-attrs/ops/concat_attrs.dtg.h", "op-attrs/ops/conv_2d_attrs.dtg.h", "op-attrs/ops/dropout_attrs.dtg.h", "op-attrs/ops/element_binary_attrs.dtg.h", @@ -61,7 +61,7 @@ key = "cast" [[values]] type = "::FlexFlow::CombineAttrs" -key = "combine_distributed" +key = "parallel_combine" [[values]] type = "::FlexFlow::ConcatAttrs" @@ -125,15 +125,15 @@ key = "reduce" [[values]] type = "::FlexFlow::ReductionAttrs" -key = "reduce_distributed" +key = "parallel_reduce" [[values]] type = "::FlexFlow::RepartitionAttrs" -key = "partition_distributed" +key = "parallel_partition" [[values]] type = "::FlexFlow::ReplicateAttrs" -key = "replicate_distributed" +key = "parallel_replicate" [[values]] type = "::FlexFlow::ReverseAttrs" diff --git a/lib/op-attrs/src/op-attrs/ops/element_unary.cc b/lib/op-attrs/src/op-attrs/ops/element_unary.cc index 9d02923689..ca7e417814 100644 --- a/lib/op-attrs/src/op-attrs/ops/element_unary.cc +++ b/lib/op-attrs/src/op-attrs/ops/element_unary.cc @@ -35,7 +35,6 @@ ParallelTensorDimDegrees get_output_parallel_dim_degrees( ElementUnaryAttrs const &attrs, ParallelTensorDimDegrees const &input_degrees) { ASSERT(input_degrees.sum_degree.value == 1); - ASSERT(input_degrees.discard_copy_degree.value == 1); return input_degrees; } diff --git a/lib/op-attrs/test/src/op-attrs/ops/element_unary.cc b/lib/op-attrs/test/src/op-attrs/ops/element_unary.cc index 672b160cbd..09e49a123c 100644 --- a/lib/op-attrs/test/src/op-attrs/ops/element_unary.cc +++ b/lib/op-attrs/test/src/op-attrs/ops/element_unary.cc @@ -53,22 +53,26 @@ TEST_SUITE(FF_TEST_SUITE) { CHECK(result == correct); } - SUBCASE("sum degree > 1") { + SUBCASE("discard copy degree > 1") { positive_int degree = 2_p; - CHECK_THROWS(get_output_shape( - attrs, - make_input( - SumDegree{degree}, DiscardCopyDegree{1_p}, 1_p, 1_p, 1_p))); + ParallelTensorShape par_input = + make_input(SumDegree{1_p}, DiscardCopyDegree{degree}, 1_p, 1_p, 1_p); + + tl::expected result = + get_output_shape(attrs, par_input); + tl::expected correct = par_input; + + CHECK(result == correct); } - SUBCASE("discard copy degree > 1") { + SUBCASE("sum degree > 1") { positive_int degree = 2_p; CHECK_THROWS(get_output_shape( attrs, make_input( - SumDegree{1_p}, DiscardCopyDegree{degree}, 1_p, 1_p, 1_p))); + SumDegree{degree}, DiscardCopyDegree{1_p}, 1_p, 1_p, 1_p))); } } } diff --git a/lib/pcg/include/pcg/mapped_parallel_computation_graph/mapped_parallel_computation_graph.h b/lib/pcg/include/pcg/mapped_parallel_computation_graph/mapped_parallel_computation_graph.h index 12c7921282..a2afdb7914 100644 --- a/lib/pcg/include/pcg/mapped_parallel_computation_graph/mapped_parallel_computation_graph.h +++ b/lib/pcg/include/pcg/mapped_parallel_computation_graph/mapped_parallel_computation_graph.h @@ -8,15 +8,47 @@ namespace FlexFlow { std::unordered_set mpcg_get_parallel_layers(MappedParallelComputationGraph const &); + MappedOperatorTaskGroup mpcg_get_mapping_for_layer(MappedParallelComputationGraph const &, parallel_layer_guid_t); ParallelComputationGraph pcg_from_mpcg(MappedParallelComputationGraph const &); +parallel_layer_guid_t + mpcg_get_source_layer(MappedParallelComputationGraph const &, + parallel_tensor_guid_t const &); + +PCGOperatorAttrs mpcg_get_pcg_op_attrs(MappedParallelComputationGraph const &, + parallel_layer_guid_t const &); + +ParallelTensorAttrs + mpcg_get_parallel_tensor_attrs(MappedParallelComputationGraph const &, + parallel_tensor_guid_t const &); + +std::unordered_map + mpcg_get_incoming_edges(MappedParallelComputationGraph const &, + parallel_layer_guid_t const &); + +std::unordered_set + mpcg_get_outgoing_edges(MappedParallelComputationGraph const &, + parallel_layer_guid_t const &); + +ManyToOne + mpcg_get_incoming_tensors(MappedParallelComputationGraph const &, + parallel_layer_guid_t const &); + +bidict + mpcg_get_outgoing_tensors(MappedParallelComputationGraph const &, + parallel_layer_guid_t const &); + std::unordered_set mpcg_get_edges(MappedParallelComputationGraph const &); +std::unordered_set + mpcg_get_parallel_tensor_uses(MappedParallelComputationGraph const &, + parallel_tensor_guid_t const &); + MappedParallelComputationGraph mapped_pcg_from_pcg_and_mapped_op_task_groups( ParallelComputationGraph const &pcg, std::unordered_map const diff --git a/lib/pcg/include/pcg/parallel_computation_graph/parallel_computation_graph.h b/lib/pcg/include/pcg/parallel_computation_graph/parallel_computation_graph.h index 0368be62bc..9764e40627 100644 --- a/lib/pcg/include/pcg/parallel_computation_graph/parallel_computation_graph.h +++ b/lib/pcg/include/pcg/parallel_computation_graph/parallel_computation_graph.h @@ -10,6 +10,7 @@ #include "pcg/parallel_computation_graph/parallel_layer_added_result.dtg.h" #include "pcg/parallel_computation_graph/parallel_layer_guid_t.dtg.h" #include "pcg/parallel_computation_graph/parallel_tensor_guid_t.dtg.h" +#include "pcg/parallel_computation_graph/parallel_tensor_use_t.dtg.h" #include namespace FlexFlow { @@ -53,6 +54,10 @@ std::unordered_map get_incoming_edges(ParallelComputationGraph const &, parallel_layer_guid_t const &); +std::unordered_set + pcg_get_parallel_tensor_uses(ParallelComputationGraph const &, + parallel_tensor_guid_t const &); + std::unordered_set get_initial_layers(ParallelComputationGraph const &); diff --git a/lib/pcg/include/pcg/parallel_computation_graph/parallel_tensor_use_t.h b/lib/pcg/include/pcg/parallel_computation_graph/parallel_tensor_use_t.h new file mode 100644 index 0000000000..f5e5575632 --- /dev/null +++ b/lib/pcg/include/pcg/parallel_computation_graph/parallel_tensor_use_t.h @@ -0,0 +1,15 @@ +#ifndef _FLEXFLOW_LIB_PCG_INCLUDE_PCG_PARALLEL_COMPUTATION_GRAPH_PARALLEL_TENSOR_USE_T_H +#define _FLEXFLOW_LIB_PCG_INCLUDE_PCG_PARALLEL_COMPUTATION_GRAPH_PARALLEL_TENSOR_USE_T_H + +#include "pcg/parallel_computation_graph/parallel_layer_guid_t.dtg.h" +#include "pcg/parallel_computation_graph/parallel_tensor_use_t.dtg.h" + +namespace FlexFlow { + +parallel_layer_guid_t + parallel_tensor_use_get_layer(parallel_tensor_use_t const &); +TensorSlotName parallel_tensor_use_get_slot(parallel_tensor_use_t const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/pcg/src/pcg/mapped_parallel_computation_graph/mapped_parallel_computation_graph.cc b/lib/pcg/src/pcg/mapped_parallel_computation_graph/mapped_parallel_computation_graph.cc index f4fa946a66..fc1dff504b 100644 --- a/lib/pcg/src/pcg/mapped_parallel_computation_graph/mapped_parallel_computation_graph.cc +++ b/lib/pcg/src/pcg/mapped_parallel_computation_graph/mapped_parallel_computation_graph.cc @@ -2,12 +2,14 @@ #include "op-attrs/pcg_operator_attrs.h" #include "pcg/mapped_parallel_computation_graph/mapped_parallel_layer_attrs.h" #include "pcg/parallel_computation_graph/parallel_computation_graph.h" +#include "utils/bidict/algorithms/bidict_from_map.h" #include "utils/bidict/algorithms/transform_keys.h" #include "utils/containers/transform.h" #include "utils/graph/kwarg_dataflow_graph/algorithms/find_isomorphism_between_kwarg_dataflow_graphs.h" #include "utils/graph/labelled_kwarg_dataflow_graph/algorithms/labelled_kwarg_dataflow_graph_view_as_dot.h" #include "utils/graph/labelled_kwarg_dataflow_graph/algorithms/materialize_labelled_kwarg_dataflow_graph_view.h" #include "utils/graph/labelled_kwarg_dataflow_graph/algorithms/rewrite_labelled_kwarg_dataflow_graph_node_labels.h" +#include "utils/many_to_one/many_to_one_from_map.h" namespace FlexFlow { @@ -46,6 +48,59 @@ ParallelComputationGraph }; } +parallel_layer_guid_t + mpcg_get_source_layer(MappedParallelComputationGraph const &mpcg, + parallel_tensor_guid_t const &t) { + return get_source_layer(pcg_from_mpcg(mpcg), t); +} + +PCGOperatorAttrs + mpcg_get_pcg_op_attrs(MappedParallelComputationGraph const &mpcg, + parallel_layer_guid_t const &l) { + return pcg_get_op_attrs(pcg_from_mpcg(mpcg), l); +} + +ParallelTensorAttrs + mpcg_get_parallel_tensor_attrs(MappedParallelComputationGraph const &mpcg, + parallel_tensor_guid_t const &t) { + return get_parallel_tensor_attrs(pcg_from_mpcg(mpcg), t); +} + +std::unordered_map + mpcg_get_incoming_edges(MappedParallelComputationGraph const &mpcg, + parallel_layer_guid_t const &l) { + return get_incoming_edges(pcg_from_mpcg(mpcg), l); +} + +std::unordered_set + mpcg_get_outgoing_edges(MappedParallelComputationGraph const &mpcg, + parallel_layer_guid_t const &l) { + return get_outgoing_edges(pcg_from_mpcg(mpcg), l); +} + +ManyToOne + mpcg_get_incoming_tensors(MappedParallelComputationGraph const &mpcg, + parallel_layer_guid_t const &l) { + return many_to_one_from_map(get_incoming_tensors(pcg_from_mpcg(mpcg), l)); +} + +bidict + mpcg_get_outgoing_tensors(MappedParallelComputationGraph const &mpcg, + parallel_layer_guid_t const &l) { + return bidict_from_map(get_outgoing_tensors(pcg_from_mpcg(mpcg), l)); +} + +std::unordered_set + mpcg_get_edges(MappedParallelComputationGraph const &mpcg) { + return get_edges(pcg_from_mpcg(mpcg)); +} + +std::unordered_set + mpcg_get_parallel_tensor_uses(MappedParallelComputationGraph const &mpcg, + parallel_tensor_guid_t const &t) { + return pcg_get_parallel_tensor_uses(pcg_from_mpcg(mpcg), t); +} + MappedParallelComputationGraph mapped_pcg_from_pcg_and_mapped_op_task_groups( ParallelComputationGraph const &pcg, std::unordered_map const diff --git a/lib/pcg/src/pcg/parallel_computation_graph/parallel_computation_graph.cc b/lib/pcg/src/pcg/parallel_computation_graph/parallel_computation_graph.cc index a548ceb65a..5098cadafe 100644 --- a/lib/pcg/src/pcg/parallel_computation_graph/parallel_computation_graph.cc +++ b/lib/pcg/src/pcg/parallel_computation_graph/parallel_computation_graph.cc @@ -28,6 +28,7 @@ #include "utils/graph/kwarg_dataflow_graph/algorithms/find_isomorphism_between_kwarg_dataflow_graphs.h" #include "utils/graph/kwarg_dataflow_graph/algorithms/get_incoming_kwarg_dataflow_outputs_for_node.h" #include "utils/graph/kwarg_dataflow_graph/algorithms/get_kwarg_dataflow_edges_from_node_to_node.h" +#include "utils/graph/kwarg_dataflow_graph/algorithms/get_kwarg_dataflow_value_uses.h" #include "utils/graph/kwarg_dataflow_graph/algorithms/get_outgoing_kwarg_dataflow_edges_for_node.h" #include "utils/graph/kwarg_dataflow_graph/algorithms/get_outgoing_kwarg_dataflow_outputs_for_node.h" #include "utils/graph/labelled_kwarg_dataflow_graph/algorithms/labelled_kwarg_dataflow_graph_view_as_dot.h" @@ -206,6 +207,17 @@ std::unordered_map }); } +std::unordered_set + pcg_get_parallel_tensor_uses(ParallelComputationGraph const &pcg, + parallel_tensor_guid_t const &t) { + std::unordered_set> raw_uses = + get_kwarg_dataflow_value_uses(pcg.raw_graph, t.raw_graph_output); + + return transform(raw_uses, [](KwargDataflowInput const &i) { + return parallel_tensor_use_t{i}; + }); +} + std::unordered_set get_initial_layers(ParallelComputationGraph const &pcg) { std::unordered_set raw_sources = get_initial_nodes(pcg.raw_graph); diff --git a/lib/pcg/src/pcg/parallel_computation_graph/parallel_tensor_use_t.cc b/lib/pcg/src/pcg/parallel_computation_graph/parallel_tensor_use_t.cc new file mode 100644 index 0000000000..71a9cadf1c --- /dev/null +++ b/lib/pcg/src/pcg/parallel_computation_graph/parallel_tensor_use_t.cc @@ -0,0 +1,14 @@ +#include "pcg/parallel_computation_graph/parallel_tensor_use_t.h" + +namespace FlexFlow { + +parallel_layer_guid_t + parallel_tensor_use_get_layer(parallel_tensor_use_t const &u) { + return parallel_layer_guid_t{u.raw_dataflow_input.node}; +} + +TensorSlotName parallel_tensor_use_get_slot(parallel_tensor_use_t const &u) { + return u.raw_dataflow_input.slot_name; +} + +} // namespace FlexFlow diff --git a/lib/realm-execution/include/realm-execution/realm_context.h b/lib/realm-execution/include/realm-execution/realm_context.h index ab89e916c0..5b76d52e2c 100644 --- a/lib/realm-execution/include/realm-execution/realm_context.h +++ b/lib/realm-execution/include/realm-execution/realm_context.h @@ -9,6 +9,7 @@ #include "pcg/device_id_t.dtg.h" #include "pcg/machine_space_coordinate.dtg.h" #include "realm-execution/realm.h" +#include "realm-execution/redops/redop_id_t.dtg.h" #include "realm-execution/tasks/task_id_t.dtg.h" #include #include @@ -63,7 +64,7 @@ struct RealmContext { int priority = 0); ///\} - /** \name Data movement */ + /** \name Data movement and reduction */ ///\{ Realm::Event issue_copy(ParallelTensorShape const &src_shape, Realm::RegionInstance src_inst, @@ -72,6 +73,17 @@ struct RealmContext { Realm::ProfilingRequestSet const &requests, Realm::Event wait_on = Realm::Event::NO_EVENT, int priority = 0); + + Realm::Event issue_reduction(ParallelTensorShape const &src_shape, + Realm::RegionInstance src_inst, + ParallelTensorShape const &dst_shape, + Realm::RegionInstance dst_inst, + redop_id_t redop_id, + bool is_fold, + bool exclusive, + Realm::ProfilingRequestSet const &requests, + Realm::Event wait_on = Realm::Event::NO_EVENT, + int priority = 0); ///\} /** \name Instance management */ diff --git a/lib/realm-execution/include/realm-execution/redops/realm_redop_registry.h b/lib/realm-execution/include/realm-execution/redops/realm_redop_registry.h new file mode 100644 index 0000000000..e7e51326e1 --- /dev/null +++ b/lib/realm-execution/include/realm-execution/redops/realm_redop_registry.h @@ -0,0 +1,16 @@ +#ifndef _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_REDOPS_REALM_REDOP_REGISTRY_H +#define _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_REDOPS_REALM_REDOP_REGISTRY_H + +#include "realm-execution/realm.h" +#include "realm-execution/redops/redop_id_t.dtg.h" + +namespace FlexFlow { + +/** + * \brief Registers all known reduction operators (redops). + */ +void register_all_redops(Realm::Runtime); + +} // namespace FlexFlow + +#endif diff --git a/lib/realm-execution/include/realm-execution/redops/redop_id_t.dtg.toml b/lib/realm-execution/include/realm-execution/redops/redop_id_t.dtg.toml new file mode 100644 index 0000000000..44e1f32c59 --- /dev/null +++ b/lib/realm-execution/include/realm-execution/redops/redop_id_t.dtg.toml @@ -0,0 +1,27 @@ +namespace = "FlexFlow" +name = "redop_id_t" +type = "enum" +features = [ + "hash", + "fmt", + "rapidcheck", + "json", +] +docstring = ''' +\brief An enum for identifying reduction operators (redops) for use in the Realm runtime. +''' + +[[values]] +name = "SUM_BOOL_REDOP_ID" + +[[values]] +name = "SUM_INT32_REDOP_ID" + +[[values]] +name = "SUM_INT64_REDOP_ID" + +[[values]] +name = "SUM_FLOAT_REDOP_ID" + +[[values]] +name = "SUM_DOUBLE_REDOP_ID" diff --git a/lib/realm-execution/include/realm-execution/redops/redop_id_t.h b/lib/realm-execution/include/realm-execution/redops/redop_id_t.h new file mode 100644 index 0000000000..8565b20b17 --- /dev/null +++ b/lib/realm-execution/include/realm-execution/redops/redop_id_t.h @@ -0,0 +1,22 @@ +#ifndef _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_REDOPS_REALM_REDOP_ID_T_H +#define _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_REDOPS_REALM_REDOP_ID_T_H + +#include "op-attrs/datatype.dtg.h" +#include "realm-execution/realm.h" +#include "realm-execution/redops/redop_id_t.dtg.h" + +namespace FlexFlow { + +/** + * \brief Return the sum reduction operator (redop) ID for a given data type. + */ +redop_id_t get_sum_redop_id_for_data_type(DataType); + +/** + * \brief Convert a \ref FlexFlow::redop_id_t into a Realm reduction op ID. + */ +Realm::ReductionOpID get_realm_reduction_op_id_for_redop_id(redop_id_t); + +} // namespace FlexFlow + +#endif diff --git a/lib/realm-execution/include/realm-execution/tasks/task_id_t.dtg.toml b/lib/realm-execution/include/realm-execution/tasks/task_id_t.dtg.toml index b1e5e07e28..b0bcc23b4d 100644 --- a/lib/realm-execution/include/realm-execution/tasks/task_id_t.dtg.toml +++ b/lib/realm-execution/include/realm-execution/tasks/task_id_t.dtg.toml @@ -327,15 +327,6 @@ name = "COMBINE_FWD_TASK_ID" [[values]] name = "COMBINE_BWD_TASK_ID" -[[values]] -name = "REPLICATE_INIT_TASK_ID" - -[[values]] -name = "REPLICATE_FWD_TASK_ID" - -[[values]] -name = "REPLICATE_BWD_TASK_ID" - [[values]] name = "REDUCTION_INIT_TASK_ID" diff --git a/lib/realm-execution/src/realm-execution/pcg_instance.cc b/lib/realm-execution/src/realm-execution/pcg_instance.cc index 0ecd02143e..aa67110127 100644 --- a/lib/realm-execution/src/realm-execution/pcg_instance.cc +++ b/lib/realm-execution/src/realm-execution/pcg_instance.cc @@ -5,6 +5,7 @@ #include "realm-execution/distributed_per_device_op_state_initialization.h" #include "realm-execution/instance_allocation.h" #include "realm-execution/realm_context.h" +#include "realm-execution/redops/redop_id_t.h" #include "realm-execution/tasks/impl/op_task.h" #include "realm-execution/tensor_instance_backing.h" #include "task-spec/dynamic_graph/copy_insertion.h" @@ -215,6 +216,46 @@ static Realm::Event spawn_dynamic_node_invocation( precondition); }; + auto issue_replicate_bwd = [&]() { + DynamicValueAttrs output_grad = get_only(values( + filter_keys(invocation.inputs, [](DynamicTensorSlot const &s) -> bool { + return s.slot_tensor_role == + DynamicTensorRole{FwbTensorType::GRADIENT}; + }))); + + DynamicValueAttrs input_grad = get_only(values(invocation.outputs)); + + Realm::RegionInstance dst_inst = + tensor_instance_backing.backing.at(input_grad).first; + + redop_id_t redop_id = get_sum_redop_id_for_data_type( + assert_unwrap(output_grad.parallel_tensor_shape).data_type); + + // chain reductions sequentially to avoid write races on dst + Realm::Event result = precondition; + for (auto const &[p, m] : assert_unwrap(output_grad.mapping)) { + DynamicValueAttrs replica_key = output_grad; + replica_key.mapping = + bidict{{p, m}}; + replica_key.shard_coord = p; + + Realm::RegionInstance src_inst = + tensor_instance_backing.backing.at(replica_key).first; + + result = ctx.issue_reduction( + /*src_shape=*/assert_unwrap(output_grad.parallel_tensor_shape), + /*src_inst=*/src_inst, + /*dst_shape=*/assert_unwrap(input_grad.parallel_tensor_shape), + /*dst_inst=*/dst_inst, + /*redop_id=*/redop_id, + /*is_fold=*/false, + /*exlusive=*/false, + /*requests=*/Realm::ProfilingRequestSet{}, + /*wait_on=*/result); + } + return result; + }; + TrainingOperationAttrs op_attrs = assert_unwrap(invocation.node_attrs.op_attrs); return op_attrs.visit(overload{ @@ -222,6 +263,14 @@ static Realm::Event spawn_dynamic_node_invocation( return pcg_op_attrs.visit(overload{ [&](InputAttrs const &) { return Realm::Event::NO_EVENT; }, [&](WeightAttrs const &) { return Realm::Event::NO_EVENT; }, + [&](ReplicateAttrs const &) { + if (invocation.node_attrs.task_type.has_value() && + invocation.node_attrs.task_type.value() == + DynamicTaskType::BWD) { + return issue_replicate_bwd(); + } + return issue_copy(); // forward + }, [&](auto const &) { return spawn_task(); }, }); }, diff --git a/lib/realm-execution/src/realm-execution/realm_context.cc b/lib/realm-execution/src/realm-execution/realm_context.cc index 790c1bd613..36dd7c71cc 100644 --- a/lib/realm-execution/src/realm-execution/realm_context.cc +++ b/lib/realm-execution/src/realm-execution/realm_context.cc @@ -7,6 +7,7 @@ #include "pcg/device_id_t.h" #include "pcg/device_type.dtg.h" #include "realm-execution/realm_allocator.h" +#include "realm-execution/redops/redop_id_t.h" #include "realm-execution/tasks/task_id_t.dtg.h" #include "realm-execution/tasks/task_id_t.h" #include "utils/containers/contains_key.h" @@ -154,6 +155,46 @@ static Realm::IndexSpace ispace_from_dims(TensorDims const &dims) { return Realm::IndexSpace{rect}; } +[[nodiscard]] static Realm::Event + issue_copy_for_field(TensorDims const &dims, + Realm::CopySrcDstField const &src_field, + Realm::CopySrcDstField const &dst_field, + Realm::ProfilingRequestSet const &requests, + Realm::Event wait_on, + int priority) { + switch (dims.ff_ordered.num_dims()) { +#if REALM_MAX_DIM >= 1 + case 1: + return ispace_from_dims<1>(dims).copy( + {src_field}, {dst_field}, requests, wait_on, priority); +#endif +#if REALM_MAX_DIM >= 2 + case 2: + return ispace_from_dims<2>(dims).copy( + {src_field}, {dst_field}, requests, wait_on, priority); +#endif +#if REALM_MAX_DIM >= 3 + case 3: + return ispace_from_dims<3>(dims).copy( + {src_field}, {dst_field}, requests, wait_on, priority); +#endif +#if REALM_MAX_DIM >= 4 + case 4: + return ispace_from_dims<4>(dims).copy( + {src_field}, {dst_field}, requests, wait_on, priority); +#endif +#if REALM_MAX_DIM >= 5 + case 5: + return ispace_from_dims<5>(dims).copy( + {src_field}, {dst_field}, requests, wait_on, priority); +#endif + default: + PANIC("TensorShape dims greater than REALM_MAX_DIM: {}", + dims.ff_ordered.num_dims()); + break; + } +} + Realm::Event RealmContext::issue_copy(ParallelTensorShape const &src_shape, Realm::RegionInstance src_inst, @@ -183,43 +224,48 @@ Realm::Event size_of_datatype(src_piece_shape.data_type).int_from_positive_int()), /*subfield_offset=*/0); - Realm::Event result; - switch (src_piece_shape.dims.ff_ordered.num_dims()) { -#if REALM_MAX_DIM >= 1 - case 1: - result = ispace_from_dims<1>(src_piece_shape.dims) - .copy({src_field}, {dst_field}, requests, wait_on, priority); - break; -#endif -#if REALM_MAX_DIM >= 2 - case 2: - result = ispace_from_dims<2>(src_piece_shape.dims) - .copy({src_field}, {dst_field}, requests, wait_on, priority); - break; -#endif -#if REALM_MAX_DIM >= 3 - case 3: - result = ispace_from_dims<3>(src_piece_shape.dims) - .copy({src_field}, {dst_field}, requests, wait_on, priority); - break; -#endif -#if REALM_MAX_DIM >= 4 - case 4: - result = ispace_from_dims<4>(src_piece_shape.dims) - .copy({src_field}, {dst_field}, requests, wait_on, priority); - break; -#endif -#if REALM_MAX_DIM >= 5 - case 5: - result = ispace_from_dims<5>(src_piece_shape.dims) - .copy({src_field}, {dst_field}, requests, wait_on, priority); - break; -#endif - default: - PANIC("TensorShape dims greater than REALM_MAX_DIM: {}", - src_piece_shape.dims.ff_ordered.num_dims()); - break; - } + Realm::Event result = issue_copy_for_field( + src_piece_shape.dims, src_field, dst_field, requests, wait_on, priority); + this->outstanding_events.push_back(result); + return result; +} + +Realm::Event + RealmContext::issue_reduction(ParallelTensorShape const &src_shape, + Realm::RegionInstance src_inst, + ParallelTensorShape const &dst_shape, + Realm::RegionInstance dst_inst, + redop_id_t redop_id, + bool is_fold, + bool exclusive, + Realm::ProfilingRequestSet const &requests, + Realm::Event wait_on, + int priority) { + TensorShape src_piece_shape = get_piece_shape(src_shape); + TensorShape dst_piece_shape = get_piece_shape(dst_shape); + ASSERT(src_piece_shape == dst_piece_shape); // For now, assume they match + + Realm::CopySrcDstField src_field; + src_field.set_field( + /*inst=*/src_inst, + /*field_id=*/0, + /*size=*/ + static_cast( + size_of_datatype(src_piece_shape.data_type).int_from_positive_int()), + /*subfield_offset=*/0); + Realm::CopySrcDstField dst_field; + dst_field.set_field( + /*inst=*/dst_inst, + /*field_id=*/0, + /*size=*/ + static_cast( + size_of_datatype(src_piece_shape.data_type).int_from_positive_int()), + /*subfield_offset=*/0); + dst_field.set_redop( + get_realm_reduction_op_id_for_redop_id(redop_id), is_fold, exclusive); + + Realm::Event result = issue_copy_for_field( + src_piece_shape.dims, src_field, dst_field, requests, wait_on, priority); this->outstanding_events.push_back(result); return result; } diff --git a/lib/realm-execution/src/realm-execution/realm_manager.cc b/lib/realm-execution/src/realm-execution/realm_manager.cc index e76be7054b..5a8f9cbbbb 100644 --- a/lib/realm-execution/src/realm-execution/realm_manager.cc +++ b/lib/realm-execution/src/realm-execution/realm_manager.cc @@ -1,5 +1,6 @@ #include "realm-execution/realm_manager.h" #include "realm-execution/realm_context.h" +#include "realm-execution/redops/realm_redop_registry.h" #include "realm-execution/tasks/realm_task_registry.h" namespace FlexFlow { @@ -9,8 +10,9 @@ RealmManager::RealmManager(int *argc, char ***argv) bool ok = this->get_runtime().init(argc, argv); ASSERT(ok); - // Register all tasks at initialization time so we don't need to later + // Register all tasks and redops at initialization time so we don't need to later register_all_tasks().wait(); + register_all_redops(this->get_runtime()); } RealmManager::~RealmManager() { diff --git a/lib/realm-execution/src/realm-execution/redops/realm_redop_registry.cc b/lib/realm-execution/src/realm-execution/redops/realm_redop_registry.cc new file mode 100644 index 0000000000..ab3304836a --- /dev/null +++ b/lib/realm-execution/src/realm-execution/redops/realm_redop_registry.cc @@ -0,0 +1,540 @@ +#include "realm-execution/redops/realm_redop_registry.h" +#include "realm-execution/redops/redop_id_t.h" + +namespace FlexFlow { + +// Reduction operators and related infrastructure borrowed from Legion. We +// maintain the Legion naming scheme to maximizing compatibility with the +// existing code, despite not otherwise relying or using Legion in any way. +// https://gitlab.com/StanfordLegion/legion/-/blob/5263aeff477fb94239c50d9306d58c4244e9fc38/runtime/legion/api/redop.inl#L31 +#if !defined(__cpp_lib_atomic_ref) || (__cpp_lib_atomic_ref < 201806L) +// We only need this crap if we're using a version of c++ < 20 +// Starting with c++20 we can do all this the right way with atomic_ref +namespace TypePunning { +// The tenth circle of hell is reserved for members of the C++ committee +// that decided to deviate from C's support for type punning unions. +// Add on to it the fact that it took them 9 fucking years to realize +// that they needed std::atomic_ref and it's plain to see they are all +// just a bunch of idiots that should never be allowed near a programming +// language standard ever again. They've clearly never written lock-free +// code in their lives. +template +class Pointer { +public: + Pointer(void *p) : pointer(convert(p)) {} + static inline T *convert(void *p) { + T *ptr = nullptr; + static_assert(sizeof(ptr) == sizeof(p)); + memcpy(&ptr, &p, sizeof(p)); + return ptr; + } + inline operator T *(void) const { + return (T *)pointer; + } + inline T operator*(void) const { + return *pointer; + } + inline T operator[](size_t off) const { + return pointer[off]; + } + +private: + T volatile *const pointer; +}; +template +class AlignedPointer { +public: + AlignedPointer(void *p) : off(align(p)), pointer(convert(p, off)) {} + static inline T *convert(void *p, size_t off) { + uint8_t *p1 = nullptr; + static_assert(sizeof(p1) == sizeof(p)); + memcpy(&p1, &p, sizeof(p)); + p1 = p1 - off; + T *p2 = nullptr; + static_assert(sizeof(p1) == sizeof(p2)); + memcpy(&p2, &p1, sizeof(p1)); + return p2; + } + static inline size_t align(void *p) { + uintptr_t ptr; + static_assert(sizeof(ptr) == sizeof(p)); + memcpy(&ptr, &p, sizeof(ptr)); + return ptr % ALIGNMENT; + } + inline operator T *(void) const { + return (T *)pointer; + } + inline T operator*(void) const { + return *pointer; + } + inline size_t offset(void) const { + return off; + } + +private: + size_t off; + T volatile *const pointer; +}; +template +class Alias { +public: + inline void load(Pointer const &pointer, size_t off = 0) { + T1 value = pointer[off]; + memcpy(buffer, (void *)&value, sizeof(T1)); + } + template + inline void load(AlignedPointer const &pointer) { + T1 value = *pointer; + memcpy(buffer, (void *)&value, sizeof(T1)); + } + inline T1 as_one(void) const { + T1 result; + memcpy((void *)&result, buffer, sizeof(result)); + return result; + } + inline T2 as_two(void) const { + T2 result; + memcpy((void *)&result, buffer, sizeof(result)); + return result; + } + inline Alias &operator=(T2 rhs) { + memcpy(buffer, (void *)&rhs, sizeof(rhs)); + return *this; + } + +private: + // Make this one private so it is can never be called + inline Alias &operator=(T1 rhs) { + memcpy(buffer, (void *)&rhs, sizeof(rhs)); + return *this; + } + static_assert(sizeof(T1) == sizeof(T2)); + uint8_t buffer[sizeof(T1)]; +}; +}; // namespace TypePunning +#endif + +// Define a prefix for annotating functions for CUDA compilation +#if defined(__CUDACC__) || defined(__HIPCC__) +#define __LEGION_CUDA_HD__ __host__ __device__ +#else +#define __LEGION_CUDA_HD__ +#endif + +template +class SumReduction { + // Empty definition + // Specializations provided for each type +}; + +template <> +class SumReduction { +public: + typedef bool LHS; + typedef bool RHS; + + static constexpr bool identity = false; + + template + __LEGION_CUDA_HD__ static void apply(LHS &lhs, RHS rhs); + template + __LEGION_CUDA_HD__ static void fold(RHS &rhs1, RHS rhs2); +}; + +template <> +class SumReduction { +public: + typedef int32_t LHS; + typedef int32_t RHS; + + static constexpr int32_t identity = 0; + + template + __LEGION_CUDA_HD__ static void apply(LHS &lhs, RHS rhs); + template + __LEGION_CUDA_HD__ static void fold(RHS &rhs1, RHS rhs2); +}; + +template <> +class SumReduction { +public: + typedef int64_t LHS; + typedef int64_t RHS; + + static constexpr int64_t identity = 0; + + template + __LEGION_CUDA_HD__ static void apply(LHS &lhs, RHS rhs); + template + __LEGION_CUDA_HD__ static void fold(RHS &rhs1, RHS rhs2); +}; + +template <> +class SumReduction { +public: + typedef float LHS; + typedef float RHS; + + static constexpr float identity = 0.f; + + template + __LEGION_CUDA_HD__ static void apply(LHS &lhs, RHS rhs); + template + __LEGION_CUDA_HD__ static void fold(RHS &rhs1, RHS rhs2); +}; + +template <> +class SumReduction { +public: + typedef double LHS; + typedef double RHS; + + static constexpr double identity = 0.0; + + template + __LEGION_CUDA_HD__ static void apply(LHS &lhs, RHS rhs); + template + __LEGION_CUDA_HD__ static void fold(RHS &rhs1, RHS rhs2); +}; + +template <> +__LEGION_CUDA_HD__ inline void SumReduction::apply(LHS &lhs, + RHS rhs) { + lhs = lhs || rhs; +} + +template <> +__LEGION_CUDA_HD__ inline void SumReduction::apply(LHS &lhs, + RHS rhs) { +#if defined(__CUDA_ARCH__) || defined(__HIP_DEVICE_COMPILE__) + // GPU atomics need 4 byte alignment + const uintptr_t unaligned = reinterpret_cast(&lhs); + unsigned const offset = unaligned % sizeof(unsigned int); + const uintptr_t aligned = unaligned - offset; + unsigned int *ptr = reinterpret_cast(aligned); + unsigned int newval = *ptr, oldval; + do { + RHS previous = __uint2bool(newval, offset); + RHS next = previous || rhs; + oldval = newval; + newval = __bool2uint(newval, next, offset); + newval = atomicCAS(ptr, oldval, newval); + } while (oldval != newval); +#else +#if defined(__cpp_lib_atomic_ref) && (__cpp_lib_atomic_ref >= 201806L) + std::atomic_ref atomic(lhs); + RHS oldval = atomic.load(); + RHS newval; + do { + newval = oldval || rhs; + } while (!atomic.compare_exchange_weak(oldval, newval)); +#else + // No atomic logical operations so use compare and swap + TypePunning::Alias oldval, newval; + TypePunning::Pointer pointer((void *)&lhs); + do { + oldval.load(pointer); + newval = oldval.as_two() || rhs; + } while (!__sync_bool_compare_and_swap( + (int8_t *)pointer, oldval.as_one(), newval.as_one())); +#endif +#endif +} + +template <> +__LEGION_CUDA_HD__ inline void SumReduction::fold(RHS &rhs1, + RHS rhs2) { + rhs1 = rhs1 || rhs2; +} + +template <> +__LEGION_CUDA_HD__ inline void SumReduction::fold(RHS &rhs1, + RHS rhs2) { +#if defined(__CUDA_ARCH__) || defined(__HIP_DEVICE_COMPILE__) + // GPU atomics need 4 byte alignment + const uintptr_t unaligned = reinterpret_cast(&rhs1); + unsigned const offset = unaligned % sizeof(unsigned int); + const uintptr_t aligned = unaligned - offset; + unsigned int *ptr = reinterpret_cast(aligned); + unsigned int newval = *ptr, oldval; + do { + RHS previous = __uint2bool(newval, offset); + RHS next = previous || rhs2; + oldval = newval; + newval = __bool2uint(newval, next, offset); + newval = atomicCAS(ptr, oldval, newval); + } while (oldval != newval); +#else +#if defined(__cpp_lib_atomic_ref) && (__cpp_lib_atomic_ref >= 201806L) + std::atomic_ref atomic(rhs1); + RHS oldval = atomic.load(); + RHS newval; + do { + newval = oldval || rhs2; + } while (!atomic.compare_exchange_weak(oldval, newval)); +#else + // No atomic logical operations so use compare and swap + TypePunning::Alias oldval, newval; + TypePunning::Pointer pointer((void *)&rhs1); + do { + oldval.load(pointer); + newval = oldval.as_two() || rhs2; + } while (!__sync_bool_compare_and_swap( + (int8_t *)pointer, oldval.as_one(), newval.as_one())); +#endif +#endif +} + +template <> +__LEGION_CUDA_HD__ inline void SumReduction::apply(LHS &lhs, + RHS rhs) { + lhs += rhs; +} + +template <> +__LEGION_CUDA_HD__ inline void SumReduction::apply(LHS &lhs, + RHS rhs) { +#if defined(__CUDA_ARCH__) || defined(__HIP_DEVICE_COMPILE__) + atomicAdd(&lhs, rhs); +#else + __sync_fetch_and_add(&lhs, rhs); +#endif +} + +template <> +__LEGION_CUDA_HD__ inline void SumReduction::fold(RHS &rhs1, + RHS rhs2) { + rhs1 += rhs2; +} + +template <> +__LEGION_CUDA_HD__ inline void SumReduction::fold(RHS &rhs1, + RHS rhs2) { +#if defined(__CUDA_ARCH__) || defined(__HIP_DEVICE_COMPILE__) + atomicAdd(&rhs1, rhs2); +#else + __sync_fetch_and_add(&rhs1, rhs2); +#endif +} + +template <> +__LEGION_CUDA_HD__ inline void SumReduction::apply(LHS &lhs, + RHS rhs) { + lhs += rhs; +} + +template <> +__LEGION_CUDA_HD__ inline void SumReduction::apply(LHS &lhs, + RHS rhs) { +#if defined(__CUDA_ARCH__) || defined(__HIP_DEVICE_COMPILE__) + // Apparently there is no signed 64bit int atomic yet + RHS newval = lhs, oldval; + // Type punning like this is illegal in C++ but the + // CUDA manual has an example just like it so fuck it + unsigned long long int *ptr = (unsigned long long int *)&lhs; + do { + oldval = newval; + newval += rhs; + newval = __ulonglong_as_longlong(atomicCAS( + ptr, __longlong_as_ulonglong(oldval), __longlong_as_ulonglong(newval))); + } while (oldval != newval); +#else + __sync_fetch_and_add(&lhs, rhs); +#endif +} + +template <> +__LEGION_CUDA_HD__ inline void SumReduction::fold(RHS &rhs1, + RHS rhs2) { + rhs1 += rhs2; +} + +template <> +__LEGION_CUDA_HD__ inline void SumReduction::fold(RHS &rhs1, + RHS rhs2) { +#if defined(__CUDA_ARCH__) || defined(__HIP_DEVICE_COMPILE__) + // Apparently there is no signed 64bit int atomic yet + RHS newval = rhs1, oldval; + // Type punning like this is illegal in C++ but the + // CUDA manual has an example just like it so fuck it + unsigned long long int *ptr = (unsigned long long int *)&rhs1; + do { + oldval = newval; + newval += rhs2; + newval = __ulonglong_as_longlong(atomicCAS( + ptr, __longlong_as_ulonglong(oldval), __longlong_as_ulonglong(newval))); + } while (oldval != newval); +#else + __sync_fetch_and_add(&rhs1, rhs2); +#endif +} + +template <> +__LEGION_CUDA_HD__ inline void SumReduction::apply(LHS &lhs, + RHS rhs) { + lhs += rhs; +} + +template <> +__LEGION_CUDA_HD__ inline void SumReduction::apply(LHS &lhs, + RHS rhs) { +#if defined(__CUDA_ARCH__) || defined(__HIP_DEVICE_COMPILE__) + atomicAdd(&lhs, rhs); +#else +#if defined(__cpp_lib_atomic_ref) && (__cpp_lib_atomic_ref >= 201806L) + std::atomic_ref atomic(lhs); + RHS oldval = atomic.load(); + RHS newval; + do { + newval = oldval + rhs; + } while (!atomic.compare_exchange_weak(oldval, newval)); +#else + // No atomic floating point operations so use compare and swap + TypePunning::Alias oldval, newval; + TypePunning::Pointer pointer((void *)&lhs); + do { + oldval.load(pointer); + newval = oldval.as_two() + rhs; + } while (!__sync_bool_compare_and_swap( + (int32_t *)pointer, oldval.as_one(), newval.as_one())); +#endif +#endif +} + +template <> +__LEGION_CUDA_HD__ inline void SumReduction::fold(RHS &rhs1, + RHS rhs2) { + rhs1 += rhs2; +} + +template <> +__LEGION_CUDA_HD__ inline void SumReduction::fold(RHS &rhs1, + RHS rhs2) { +#if defined(__CUDA_ARCH__) || defined(__HIP_DEVICE_COMPILE__) + atomicAdd(&rhs1, rhs2); +#else +#if defined(__cpp_lib_atomic_ref) && (__cpp_lib_atomic_ref >= 201806L) + std::atomic_ref atomic(rhs1); + RHS oldval = atomic.load(); + RHS newval; + do { + newval = oldval + rhs2; + } while (!atomic.compare_exchange_weak(oldval, newval)); +#else + // No atomic floating point operations so use compare and swap + TypePunning::Alias oldval, newval; + TypePunning::Pointer pointer((void *)&rhs1); + do { + oldval.load(pointer); + newval = oldval.as_two() + rhs2; + } while (!__sync_bool_compare_and_swap( + (int32_t *)pointer, oldval.as_one(), newval.as_one())); +#endif +#endif +} + +template <> +__LEGION_CUDA_HD__ inline void SumReduction::apply(LHS &lhs, + RHS rhs) { + lhs += rhs; +} + +template <> +__LEGION_CUDA_HD__ inline void SumReduction::apply(LHS &lhs, + RHS rhs) { +#if defined(__CUDA_ARCH__) || defined(__HIP_DEVICE_COMPILE__) +#if (__CUDA_ARCH__ >= 600) || defined(__HIP_DEVICE_COMPILE__) + atomicAdd(&lhs, rhs); +#else + RHS newval = lhs, oldval; + // Type punning like this is illegal in C++ but the + // CUDA manual has an example just like it so fuck it + unsigned long long int *ptr = (unsigned long long int *)&lhs; + do { + oldval = newval; + newval += rhs; + newval = __ulonglong_as_double(atomicCAS( + ptr, __double_as_ulonglong(oldval), __double_as_ulonglong(newval))); + } while (oldval != newval); +#endif +#else +#if defined(__cpp_lib_atomic_ref) && (__cpp_lib_atomic_ref >= 201806L) + std::atomic_ref atomic(lhs); + RHS oldval = atomic.load(); + RHS newval; + do { + newval = oldval + rhs; + } while (!atomic.compare_exchange_weak(oldval, newval)); +#else + // No atomic floating point operations so use compare and swap + TypePunning::Alias oldval, newval; + TypePunning::Pointer pointer((void *)&lhs); + do { + oldval.load(pointer); + newval = oldval.as_two() + rhs; + } while (!__sync_bool_compare_and_swap( + (int64_t *)pointer, oldval.as_one(), newval.as_one())); +#endif +#endif +} + +template <> +__LEGION_CUDA_HD__ inline void SumReduction::fold(RHS &rhs1, + RHS rhs2) { + rhs1 += rhs2; +} + +template <> +__LEGION_CUDA_HD__ inline void SumReduction::fold(RHS &rhs1, + RHS rhs2) { +#if defined(__CUDA_ARCH__) || defined(__HIP_DEVICE_COMPILE__) +#if (__CUDA_ARCH__ >= 600) || defined(__HIP_DEVICE_COMPILE__) + atomicAdd(&rhs1, rhs2); +#else + RHS newval = rhs1, oldval; + // Type punning like this is illegal in C++ but the + // CUDA manual has an example just like it so fuck it + unsigned long long int *ptr = (unsigned long long int *)&rhs1; + do { + oldval = newval; + newval += rhs2; + newval = __ulonglong_as_double(atomicCAS( + ptr, __double_as_ulonglong(oldval), __double_as_ulonglong(newval))); + } while (oldval != newval); +#endif +#else +#if defined(__cpp_lib_atomic_ref) && (__cpp_lib_atomic_ref >= 201806L) + std::atomic_ref atomic(rhs1); + RHS oldval = atomic.load(); + RHS newval; + do { + newval = oldval + rhs2; + } while (!atomic.compare_exchange_weak(oldval, newval)); +#else + // No atomic floating point operations so use compare and swap + TypePunning::Alias oldval, newval; + TypePunning::Pointer pointer((void *)&rhs1); + do { + oldval.load(pointer); + newval = oldval.as_two() + rhs2; + } while (!__sync_bool_compare_and_swap( + (int64_t *)pointer, oldval.as_one(), newval.as_one())); +#endif +#endif +} + +void register_all_redops(Realm::Runtime rt) { + // Registration is synchronous, so no need to capture events here + rt.register_reduction>( + get_realm_reduction_op_id_for_redop_id(redop_id_t::SUM_BOOL_REDOP_ID)); + rt.register_reduction>( + get_realm_reduction_op_id_for_redop_id(redop_id_t::SUM_INT32_REDOP_ID)); + rt.register_reduction>( + get_realm_reduction_op_id_for_redop_id(redop_id_t::SUM_INT64_REDOP_ID)); + rt.register_reduction>( + get_realm_reduction_op_id_for_redop_id(redop_id_t::SUM_FLOAT_REDOP_ID)); + rt.register_reduction>( + get_realm_reduction_op_id_for_redop_id(redop_id_t::SUM_DOUBLE_REDOP_ID)); +} + +} // namespace FlexFlow diff --git a/lib/realm-execution/src/realm-execution/redops/redop_id_t.cc b/lib/realm-execution/src/realm-execution/redops/redop_id_t.cc new file mode 100644 index 0000000000..f31769419f --- /dev/null +++ b/lib/realm-execution/src/realm-execution/redops/redop_id_t.cc @@ -0,0 +1,28 @@ +#include "realm-execution/redops/redop_id_t.h" +#include "utils/exception.h" + +namespace FlexFlow { + +redop_id_t get_sum_redop_id_for_data_type(DataType dtype) { + switch (dtype) { + case DataType::BOOL: + return redop_id_t::SUM_BOOL_REDOP_ID; + case DataType::INT32: + return redop_id_t::SUM_INT32_REDOP_ID; + case DataType::INT64: + return redop_id_t::SUM_INT64_REDOP_ID; + case DataType::FLOAT: + return redop_id_t::SUM_FLOAT_REDOP_ID; + case DataType::DOUBLE: + return redop_id_t::SUM_DOUBLE_REDOP_ID; + default: + PANIC("No known sum reduction for data type {}", dtype); + } +} + +Realm::ReductionOpID + get_realm_reduction_op_id_for_redop_id(redop_id_t redop_id) { + return static_cast(redop_id); +} + +} // namespace FlexFlow diff --git a/lib/realm-execution/src/realm-execution/tasks/realm_task_registry.cc b/lib/realm-execution/src/realm-execution/tasks/realm_task_registry.cc index e7a8948f8d..dfdfe72ce0 100644 --- a/lib/realm-execution/src/realm-execution/tasks/realm_task_registry.cc +++ b/lib/realm-execution/src/realm-execution/tasks/realm_task_registry.cc @@ -49,7 +49,6 @@ Realm::Event register_all_tasks() { task_id_t::REDUCE_INIT_TASK_ID, task_id_t::REDUCTION_INIT_TASK_ID, task_id_t::REPARTITION_INIT_TASK_ID, - task_id_t::REPLICATE_INIT_TASK_ID, task_id_t::SOFTMAX_INIT_TASK_ID, }; @@ -86,7 +85,6 @@ Realm::Event register_all_tasks() { task_id_t::REDUCE_FWD_TASK_ID, task_id_t::REDUCTION_FWD_TASK_ID, task_id_t::REPARTITION_FWD_TASK_ID, - task_id_t::REPLICATE_FWD_TASK_ID, task_id_t::RESHAPE_FWD_TASK_ID, task_id_t::REVERSE_FWD_TASK_ID, task_id_t::SOFTMAX_FWD_TASK_ID, @@ -115,7 +113,6 @@ Realm::Event register_all_tasks() { task_id_t::REDUCE_BWD_TASK_ID, task_id_t::REDUCTION_BWD_TASK_ID, task_id_t::REPARTITION_BWD_TASK_ID, - task_id_t::REPLICATE_BWD_TASK_ID, task_id_t::RESHAPE_BWD_TASK_ID, task_id_t::REVERSE_BWD_TASK_ID, task_id_t::SOFTMAX_BWD_TASK_ID, diff --git a/lib/realm-execution/src/realm-execution/tasks/task_id_t.cc b/lib/realm-execution/src/realm-execution/tasks/task_id_t.cc index dd4b0a66ca..0bdc2ca6b5 100644 --- a/lib/realm-execution/src/realm-execution/tasks/task_id_t.cc +++ b/lib/realm-execution/src/realm-execution/tasks/task_id_t.cc @@ -64,9 +64,7 @@ std::optional [](RepartitionAttrs const &attrs) { return task_id_t::REPARTITION_INIT_TASK_ID; }, - [](ReplicateAttrs const &attrs) { - return task_id_t::REPLICATE_INIT_TASK_ID; - }, + [](ReplicateAttrs const &attrs) { return std::nullopt; }, [](ReshapeAttrs const &) { return std::nullopt; }, [](ReverseAttrs const &) { return std::nullopt; }, [](SoftmaxAttrs const &) { return task_id_t::SOFTMAX_INIT_TASK_ID; }, @@ -115,9 +113,7 @@ std::optional [](RepartitionAttrs const &attrs) { return task_id_t::REPARTITION_FWD_TASK_ID; }, - [](ReplicateAttrs const &attrs) { - return task_id_t::REPLICATE_FWD_TASK_ID; - }, + [](ReplicateAttrs const &attrs) { return std::nullopt; }, [](ReshapeAttrs const &) { return task_id_t::RESHAPE_FWD_TASK_ID; }, [](ReverseAttrs const &) { return task_id_t::REVERSE_FWD_TASK_ID; }, [](SoftmaxAttrs const &) { return task_id_t::SOFTMAX_FWD_TASK_ID; }, @@ -166,9 +162,7 @@ std::optional [](RepartitionAttrs const &attrs) { return task_id_t::REPARTITION_BWD_TASK_ID; }, - [](ReplicateAttrs const &attrs) { - return task_id_t::REPLICATE_BWD_TASK_ID; - }, + [](ReplicateAttrs const &attrs) { return std::nullopt; }, [](ReshapeAttrs const &) { return task_id_t::RESHAPE_BWD_TASK_ID; }, [](ReverseAttrs const &) { return task_id_t::REVERSE_BWD_TASK_ID; }, [](SoftmaxAttrs const &) { return task_id_t::SOFTMAX_BWD_TASK_ID; }, diff --git a/lib/realm-execution/test/src/realm-execution/test_op_replicate.cc b/lib/realm-execution/test/src/realm-execution/test_op_replicate.cc new file mode 100644 index 0000000000..46d29e2bef --- /dev/null +++ b/lib/realm-execution/test/src/realm-execution/test_op_replicate.cc @@ -0,0 +1,350 @@ +#include "internal/realm_test_utils.h" +#include "kernels/allocation.h" +#include "kernels/compare_tensor_accessors.h" +#include "kernels/copy_tensor_accessor.h" +#include "kernels/format_accessor_contents.h" +#include "kernels/tensor_accessor_reductions.h" +#include "op-attrs/operator_task_space_to_operator_task_space_mapping.h" +#include "op-attrs/ops/element_unary.h" +#include "op-attrs/ops/linear.h" +#include "op-attrs/ops/replicate.h" +#include "op-attrs/parallel_tensor_shape.h" +#include "op-attrs/tensor_shape.dtg.h" +#include "op-attrs/tensor_slot_name.dtg.h" +#include "pcg/device_type.dtg.h" +#include "pcg/machine_space_coordinate.dtg.h" +#include "pcg/mapped_parallel_computation_graph/mapped_parallel_computation_graph.h" +#include "pcg/mapped_parallel_computation_graph/operator_atomic_task_shard_binding.dtg.h" +#include "pcg/parallel_computation_graph/parallel_computation_graph.h" +#include "pcg/parallel_computation_graph/parallel_computation_graph_builder.h" +#include "pcg/parallel_computation_graph/parallel_layer_guid_t.dtg.h" +#include "pcg/parallel_computation_graph/parallel_tensor_guid_t.dtg.h" +#include "realm-execution/distributed_ff_handle.h" +#include "realm-execution/dynamic_tensor_accessor_from_instance.h" +#include "realm-execution/pcg_instance.h" +#include "realm-execution/realm_context.h" +#include "realm-execution/realm_manager.h" +#include "task-spec/permissions.h" +#include "test/utils/doctest/check_kv.h" +#include "utils/containers/require_only_key.h" +#include + +namespace test { + +using namespace ::FlexFlow; +namespace Realm = ::FlexFlow::Realm; + +template +static ParallelLayerAttrs make_layer_attrs(T const &op_attrs) { + return ParallelLayerAttrs{ + /*op_attrs=*/PCGOperatorAttrs{op_attrs}, + /*name=*/std::nullopt, + }; +}; + +static bool did_loss_decrease(GenericTensorAccessorR const &first_epoch, + GenericTensorAccessorR const &last_epoch, + Allocator &allocator) { + return tensor_accessor_all( + compare_tensor_accessors_le(last_epoch, first_epoch, allocator)); +} + +MappedParallelComputationGraph + make_test_mpcg_for_device_type(DeviceType device_type) { + positive_int batch_size = 10_p; + positive_int data_dim = 16_p; + positive_int hidden_dim = 32_p; + positive_int output_dim = 1_p; + + TensorShape output_tensor_shape = TensorShape{ + TensorDims{FFOrdered{batch_size, output_dim}}, DataType::FLOAT}; + + TensorShape label_tensor_shape = TensorShape{ + TensorDims{FFOrdered{batch_size, output_dim}}, DataType::FLOAT}; + + ParallelComputationGraph pcg = empty_parallel_computation_graph(); + + TensorShape input_tensor_shape = + TensorShape{TensorDims{FFOrdered{batch_size, data_dim}}, DataType::FLOAT}; + + ParallelLayerAddedResult inputs_layer = + pcg_add_input_layer(pcg, input_tensor_shape); + parallel_tensor_guid_t t_input = + require_only_key(inputs_layer.outputs, TensorSlotName::OUTPUT); + + ParallelLayerAddedResult inputs_layer_2 = + pcg_add_input_layer(pcg, input_tensor_shape); + parallel_tensor_guid_t t_input_2 = + require_only_key(inputs_layer_2.outputs, TensorSlotName::OUTPUT); + + ElementBinaryAttrs add_attrs = ElementBinaryAttrs{ + OperatorType::EW_ADD, + DataType::FLOAT, + false, + false, + }; + + ParallelLayerAddedResult add_operator_1 = + add_parallel_layer(pcg, + make_layer_attrs(add_attrs), + { + { + TensorSlotName::LHS_INPUT, + t_input, + }, + { + TensorSlotName::RHS_INPUT, + t_input_2, + }, + }, + /*weights=*/{}); + + parallel_tensor_guid_t t_add_1 = + require_only_key(add_operator_1.outputs, TensorSlotName::OUTPUT); + + positive_int replicate_degree = 2_p; + ReplicateAttrs repl_attrs = ReplicateAttrs{replicate_degree}; + ParallelLayerAddedResult repl_operator_1 = + add_parallel_layer(pcg, + make_layer_attrs(repl_attrs), + { + { + TensorSlotName::INPUT, + t_add_1, + }, + }, + /*weight=*/{}); + + parallel_tensor_guid_t t_repl_1 = + require_only_key(repl_operator_1.outputs, TensorSlotName::OUTPUT); + + ParallelLayerAddedResult relu_operator_1 = + add_parallel_layer(pcg, + make_layer_attrs(make_relu_attrs()), + /*inputs=*/ + { + { + TensorSlotName::INPUT, + t_repl_1, + }, + }, + /*weights=*/{}); + + parallel_tensor_guid_t t_relu_1 = + require_only_key(relu_operator_1.outputs, TensorSlotName::OUTPUT); + + MachineSpaceCoordinate cpu0{0_n, 0_n, device_type}; + MachineSpaceCoordinate cpu1{0_n, 1_n, device_type}; + + ParallelTensorSpaceCoordinate tensor_coord0{ + /*sum_component=*/0_n, + /*discard_copy_component=*/0_n, + /*shard_component=*/FFOrdered{0_n}}; + ParallelTensorSpaceCoordinate tensor_coord1{ + /*sum_component=*/0_n, + /*discard_copy_component=*/1_n, + /*shard_component=*/FFOrdered{0_n}}; + + MappedParallelComputationGraph mpcg = + mapped_pcg_from_pcg_and_mapped_op_task_groups( + /*pcg=*/pcg, + /*mapped_op_task_groups=*/{ + { + inputs_layer.parallel_layer, + MappedOperatorTaskGroup{ + { + { + cpu0, + OperatorAtomicTaskShardBinding{{ + {TensorSlotName::OUTPUT, tensor_coord0}, + }}, + }, + }, + }, + }, + { + inputs_layer_2.parallel_layer, + MappedOperatorTaskGroup{ + { + { + cpu0, + OperatorAtomicTaskShardBinding{{ + {TensorSlotName::OUTPUT, tensor_coord0}, + }}, + }, + }, + }, + }, + { + add_operator_1.parallel_layer, + MappedOperatorTaskGroup{ + { + { + cpu0, + OperatorAtomicTaskShardBinding{{ + {TensorSlotName::LHS_INPUT, tensor_coord0}, + {TensorSlotName::RHS_INPUT, tensor_coord0}, + {TensorSlotName::OUTPUT, tensor_coord0}, + }}, + }, + }, + }, + }, + { + repl_operator_1.parallel_layer, + MappedOperatorTaskGroup{ + { + { + cpu0, + OperatorAtomicTaskShardBinding{{ + {TensorSlotName::OUTPUT, tensor_coord0}, + }}, + }, + { + cpu1, + OperatorAtomicTaskShardBinding{{ + {TensorSlotName::OUTPUT, tensor_coord1}, + }}, + }, + }, + }, + }, + { + relu_operator_1.parallel_layer, + MappedOperatorTaskGroup{ + { + { + cpu0, + OperatorAtomicTaskShardBinding{{ + {TensorSlotName::INPUT, tensor_coord0}, + {TensorSlotName::OUTPUT, tensor_coord0}, + }}, + }, + { + cpu1, + OperatorAtomicTaskShardBinding{{ + {TensorSlotName::INPUT, tensor_coord1}, + {TensorSlotName::OUTPUT, tensor_coord1}, + }}, + }, + }, + }, + }, + }); + + return mpcg; +} + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("RealmBackend e2e Training Replicate Op (CPU Model Parallelism)") { + std::vector fake_args = + make_fake_realm_args(/*num_cpus=*/2_p, /*num_gpus=*/0_n); + int fake_argc = fake_args.size(); + char **fake_argv = fake_args.data(); + + RealmManager manager = RealmManager{&fake_argc, &fake_argv}; + ControllerTaskResult result = + manager.start_controller([](RealmContext &ctx) { + Allocator allocator = ctx.get_current_device_allocator(); + + MappedParallelComputationGraph mpcg = + make_test_mpcg_for_device_type(DeviceType::CPU); + + std::unordered_map + input_tensors; + + OptimizerAttrs optimizer_attrs = OptimizerAttrs{ + SGDOptimizerAttrs{ + /*lr=*/0.001, + /*momentum=*/0.9, + /*nesterov=*/false, + /*weight_decay=*/0.001, + }, + }; + + DistributedFfHandle device_handle = create_distributed_ff_handle( + ctx, + /*workSpaceSize=*/1024 * 1024, + /*allowTensorOpMathConversion=*/true); + + PCGInstance pcg_instance = create_pcg_instance( + /*ctx=*/ctx, + /*mpcg=*/mpcg, + /*optimizer=*/optimizer_attrs, + /*loss=*/std::nullopt, + /*input_tensors=*/input_tensors, + /*profiling_settings=*/ProfilingSettings{0, 0}, + /*device_handle=*/device_handle, + /*iteration_config=*/FFIterationConfig{1_p}); + + // begin training loop + int num_epochs = 1; + for (int i = 0; i < num_epochs; i++) { + perform_all_passes_for_pcg_instance( + /*instance=*/pcg_instance, + /*profiling_settings=*/ProfilingSettings{0, 0}, + /*device_handle=*/device_handle, + /*iteration_config=*/FFIterationConfig{1_p}); + } + }); + result.wait(); + } +} + +TEST_SUITE(FF_CUDA_TEST_SUITE) { + TEST_CASE("RealmBackend e2e Training Replicate Op (GPU Model Parallelism)") { + std::vector fake_args = + make_fake_realm_args(/*num_cpus=*/1_p, /*num_gpus=*/2_n); + int fake_argc = fake_args.size(); + char **fake_argv = fake_args.data(); + + RealmManager manager = RealmManager{&fake_argc, &fake_argv}; + + ControllerTaskResult result = + manager.start_controller([](RealmContext &ctx) { + Allocator allocator = ctx.get_current_device_allocator(); + + MappedParallelComputationGraph mpcg = + make_test_mpcg_for_device_type(DeviceType::GPU); + + OptimizerAttrs optimizer_attrs = OptimizerAttrs{ + SGDOptimizerAttrs{ + /*lr=*/0.001, + /*momentum=*/0.9, + /*nesterov=*/false, + /*weight_decay=*/0.001, + }, + }; + + std::unordered_map + input_tensors; + + DistributedFfHandle device_handle = create_distributed_ff_handle( + ctx, + /*workSpaceSize=*/1024 * 1024, + /*allowTensorOpMathConversion=*/true); + + PCGInstance pcg_instance = create_pcg_instance( + /*ctx=*/ctx, + /*mpcg=*/mpcg, + /*optimizer=*/optimizer_attrs, + /*loss=*/std::nullopt, + /*input_tensors=*/input_tensors, + /*profiling_settings=*/ProfilingSettings{0, 0}, + /*device_handle=*/device_handle, + /*iteration_config=*/FFIterationConfig{1_p}); + + // begin training loop + int num_epochs = 1; + for (int i = 0; i < num_epochs; i++) { + perform_all_passes_for_pcg_instance( + /*instance=*/pcg_instance, + /*profiling_settings=*/ProfilingSettings{0, 0}, + /*device_handle=*/device_handle, + /*iteration_config=*/FFIterationConfig{1_p}); + } + }); + result.wait(); + } +} +} // namespace test diff --git a/lib/substitutions/include/substitutions/sub_parallel_computation_graph.h b/lib/substitutions/include/substitutions/sub_parallel_computation_graph.h index cbfe3ab264..2a3dc8bbb8 100644 --- a/lib/substitutions/include/substitutions/sub_parallel_computation_graph.h +++ b/lib/substitutions/include/substitutions/sub_parallel_computation_graph.h @@ -48,8 +48,8 @@ std::unordered_set get_subgraph_outgoing_edges( std::unordered_set const &); std::unordered_set - get_parallel_tensor_uses(SubParallelComputationGraph const &, - open_parallel_tensor_guid_t const &); + get_open_parallel_tensor_uses(SubParallelComputationGraph const &, + open_parallel_tensor_guid_t const &); SubParallelComputationGraphData get_sub_pcg_data(SubParallelComputationGraph const &); diff --git a/lib/substitutions/src/substitutions/apply_substitution/apply_substitution.cc b/lib/substitutions/src/substitutions/apply_substitution/apply_substitution.cc index 6ed2ef563e..f2686f7cf7 100644 --- a/lib/substitutions/src/substitutions/apply_substitution/apply_substitution.cc +++ b/lib/substitutions/src/substitutions/apply_substitution/apply_substitution.cc @@ -109,9 +109,10 @@ SubParallelComputationGraph apply_substitution_from_output_result( input_parallel_tensor_guid_t output_graph_input = output_expr_to_result_sub_pcg_mapping.input_mapping.at_r( output_expr_input); - std::unordered_set uses = get_parallel_tensor_uses( - substitution_output_graph, - open_parallel_tensor_guid_from_input(output_graph_input)); + std::unordered_set uses = + get_open_parallel_tensor_uses( + substitution_output_graph, + open_parallel_tensor_guid_from_input(output_graph_input)); for (parallel_tensor_use_t const &use : uses) { SubParallelComputationGraphEdge new_edge = subpcg_edge_from_tensor_and_use(base_graph_tensor, use); diff --git a/lib/substitutions/src/substitutions/apply_substitution/output_expr_to_result_sub_pcg_mapping.cc b/lib/substitutions/src/substitutions/apply_substitution/output_expr_to_result_sub_pcg_mapping.cc index 2ad5b54a17..4374a951f8 100644 --- a/lib/substitutions/src/substitutions/apply_substitution/output_expr_to_result_sub_pcg_mapping.cc +++ b/lib/substitutions/src/substitutions/apply_substitution/output_expr_to_result_sub_pcg_mapping.cc @@ -2,7 +2,7 @@ #include "substitutions/output_graph/output_graph_expr.h" #include "substitutions/sub_parallel_computation_graph.h" #include "utils/bidict/algorithms/bidict_from_pairs.h" -#include "utils/bidict/algorithms/merge_disjoint_bidicts.h" +#include "utils/bidict/algorithms/binary_merge_disjoint_bidicts.h" #include "utils/containers/values.h" #include "utils/containers/zip_values_strict.h" @@ -26,7 +26,7 @@ bidict mapping_for_layer = bidict_from_pairs(values( zip_values_strict(layer_outputs, output_graph_expr_outputs))); - result = merge_disjoint_bidicts(result, mapping_for_layer); + result = binary_merge_disjoint_bidicts(result, mapping_for_layer); } return result; diff --git a/lib/substitutions/src/substitutions/pcg_pattern_match.cc b/lib/substitutions/src/substitutions/pcg_pattern_match.cc index 498fd6c1bf..85a0493e33 100644 --- a/lib/substitutions/src/substitutions/pcg_pattern_match.cc +++ b/lib/substitutions/src/substitutions/pcg_pattern_match.cc @@ -4,8 +4,8 @@ #include "substitutions/unlabelled/unlabelled_graph_pattern.h" #include "utils/bidict/algorithms/bidict_from_keys_and_values.h" #include "utils/bidict/algorithms/bidict_from_map.h" +#include "utils/bidict/algorithms/binary_merge_disjoint_bidicts.h" #include "utils/bidict/algorithms/exhaustive_relational_join.h" -#include "utils/bidict/algorithms/merge_disjoint_bidicts.h" #include "utils/bidict/algorithms/transform_values.h" #include "utils/containers/is_subseteq_of.h" #include "utils/containers/map_values.h" @@ -34,7 +34,7 @@ bidict exhaustive_relational_join(pattern_node_outputs.reversed(), matched_layer_output_tensors); - result = merge_disjoint_bidicts(result, mapping); + result = binary_merge_disjoint_bidicts(result, mapping); } return result; diff --git a/lib/substitutions/src/substitutions/sub_parallel_computation_graph.cc b/lib/substitutions/src/substitutions/sub_parallel_computation_graph.cc index 34b8ae1e96..c0c05ad5b1 100644 --- a/lib/substitutions/src/substitutions/sub_parallel_computation_graph.cc +++ b/lib/substitutions/src/substitutions/sub_parallel_computation_graph.cc @@ -131,8 +131,8 @@ std::unordered_set get_subgraph_incoming_edges( } std::unordered_set - get_parallel_tensor_uses(SubParallelComputationGraph const &spcg, - open_parallel_tensor_guid_t const &t) { + get_open_parallel_tensor_uses(SubParallelComputationGraph const &spcg, + open_parallel_tensor_guid_t const &t) { std::unordered_set> raw_uses = get_open_kwarg_dataflow_value_uses(spcg.raw_graph, t.raw_open_dataflow_value); diff --git a/lib/task-spec/include/task-spec/dynamic_graph/training_operation_attrs.h b/lib/task-spec/include/task-spec/dynamic_graph/training_operation_attrs.h new file mode 100644 index 0000000000..9caea8c341 --- /dev/null +++ b/lib/task-spec/include/task-spec/dynamic_graph/training_operation_attrs.h @@ -0,0 +1,14 @@ +#ifndef _FLEXFLOW_LIB_TASK_SPEC_INCLUDE_TASK_SPEC_DYNAMIC_GRAPH_TRAINING_OPERATION_ATTRS_H +#define _FLEXFLOW_LIB_TASK_SPEC_INCLUDE_TASK_SPEC_DYNAMIC_GRAPH_TRAINING_OPERATION_ATTRS_H + +#include "op-attrs/operator_type.dtg.h" +#include "task-spec/dynamic_graph/training_operation_attrs.dtg.h" + +namespace FlexFlow { + +bool training_op_attrs_has_op_type(TrainingOperationAttrs const &, + OperatorType); + +} // namespace FlexFlow + +#endif diff --git a/lib/task-spec/src/task-spec/dynamic_graph/copy_insertion.cc b/lib/task-spec/src/task-spec/dynamic_graph/copy_insertion.cc index 4c1b9d4609..ef41042a51 100644 --- a/lib/task-spec/src/task-spec/dynamic_graph/copy_insertion.cc +++ b/lib/task-spec/src/task-spec/dynamic_graph/copy_insertion.cc @@ -25,15 +25,40 @@ bool node_is_copy(DynamicNodeAttrs const &n) { return n.op_attrs.has_value() && n.op_attrs.value().is_copy(); } +static bool is_replicate_invocation(DynamicNodeInvocation const &i) { + return i.node_attrs.op_attrs.has_value() && + i.node_attrs.op_attrs.value().has() && + i.node_attrs.op_attrs.value() + .get() + .has(); +} + bool value_is_mapped(DynamicValueAttrs const &n) { return n.mapping.has_value(); } bool no_part_of_graph_is_copy_inserted(DynamicOpenDataflowGraph const &g) { auto slot_is_mapped = [](DynamicTensorSlot const &) -> bool { return false; }; - - return no_part_of_dynamic_graph_satisfies( - g, node_is_copy, value_is_mapped, slot_is_mapped); + // check all non-replicate invocations + for (DynamicNodeInvocation const &i : g.invocations) { + if (is_replicate_invocation(i)) { + continue; // replicate tensors have mapping set by design + } + if (node_is_copy(i.node_attrs)) { + return false; + } + for (auto const &[slot, value] : i.inputs) { + if (value_is_mapped(value)) { + return false; + } + } + for (auto const &[slot, value] : i.outputs) { + if (value_is_mapped(value)) { + return false; + } + } + } + return true; } bool graph_is_fully_copy_inserted(DynamicOpenDataflowGraph const &g) { @@ -85,6 +110,11 @@ std::unordered_set perform_copy_insertion_for_invocation( std::unordered_map const &unmapped_value_to_mapped_source_value) { + // replicate nodes have no MappedOperatorTaskGroup — + // pass through unchanged, no copies needed + if (is_replicate_invocation(i)) { + return {i}; + } MappedOperatorTaskGroup mapping = assert_unwrap(i.node_attrs.mapping); auto map_tensor = [&](DynamicTensorSlot const &slot, @@ -157,6 +187,14 @@ DynamicOpenDataflowGraph std::unordered_map unmapped_value_to_mapped_source_value; for (DynamicNodeInvocation const &i : g.invocations) { + // replicate nodes have no MappedOperatorTaskGroup — + // output mapping already fully set, maps to itself + if (is_replicate_invocation(i)) { + for (auto const &[slot, value] : i.outputs) { + unmapped_value_to_mapped_source_value.insert(std::pair{value, value}); + } + continue; + } for (auto const &[slot, value] : i.outputs) { unmapped_value_to_mapped_source_value.insert( std::pair{value, diff --git a/lib/task-spec/src/task-spec/dynamic_graph/make_dynamic_open_dataflow_graph_from_mapped_pcg.cc b/lib/task-spec/src/task-spec/dynamic_graph/make_dynamic_open_dataflow_graph_from_mapped_pcg.cc index 380c2d17a1..7a149787b9 100644 --- a/lib/task-spec/src/task-spec/dynamic_graph/make_dynamic_open_dataflow_graph_from_mapped_pcg.cc +++ b/lib/task-spec/src/task-spec/dynamic_graph/make_dynamic_open_dataflow_graph_from_mapped_pcg.cc @@ -3,80 +3,209 @@ #include "op-attrs/pcg_operator_attrs.h" #include "pcg/mapped_parallel_computation_graph/mapped_parallel_computation_graph.h" #include "pcg/parallel_computation_graph/parallel_computation_graph.h" +#include "pcg/parallel_computation_graph/parallel_computation_graph_edge.h" #include "pcg/parallel_computation_graph/parallel_tensor_attrs.dtg.h" +#include "pcg/parallel_computation_graph/parallel_tensor_use_t.h" #include "task-spec/dynamic_graph/dynamic_layer_guid_t.dtg.h" #include "task-spec/dynamic_graph/dynamic_open_dataflow_graph.h" #include "task-spec/dynamic_graph/dynamic_tensor_role.h" +#include "utils/bidict/algorithms/merge_disjoint_bidicts.h" #include "utils/containers/generate_map.h" +#include "utils/containers/get_only.h" +#include "utils/containers/map_keys_and_values.h" +#include "utils/containers/require_only_key.h" +#include "utils/containers/transform_pairs.h" #include #include #include namespace FlexFlow { +static bidict + get_input_mapping_for_replicate( + MappedParallelComputationGraph const &mpcg, + parallel_layer_guid_t const &replicate_layer) { + + ASSERT(mpcg_get_pcg_op_attrs(mpcg, replicate_layer).is_parallel_replicate()); + + auto [input_slot_name, input_edge] = + get_only(mpcg_get_incoming_edges(mpcg, replicate_layer)); + + parallel_layer_guid_t producer_layer = get_src_layer(input_edge); + TensorSlotName producer_slot = get_src_layer_output_slot_name(input_edge); + + return get_tensor_bindings_for_slot_name( + /*task_group=*/mpcg_get_mapping_for_layer(mpcg, producer_layer), + /*slot_name=*/producer_slot); +} + +static bidict + build_replicated_output_mapping( + MappedParallelComputationGraph const &mpcg, + parallel_tensor_guid_t const &output_tensor_guid) { + + std::unordered_set consumers = + mpcg_get_parallel_tensor_uses(mpcg, output_tensor_guid); + ASSERT(!consumers.empty()); + + // union all consumer bindings — each consumer shard maps to a distinct + // (discard_copy, machine) pair since replicas are always on different machines + bidict result = + merge_disjoint_bidicts(transform( + consumers, + [&](parallel_tensor_use_t const &use) + -> bidict { + parallel_layer_guid_t consumer_layer = + parallel_tensor_use_get_layer(use); + TensorSlotName slot_name = parallel_tensor_use_get_slot(use); + + MappedOperatorTaskGroup consumer_mapping = + mpcg_get_mapping_for_layer(mpcg, consumer_layer); + bidict + binding = get_tensor_bindings_for_slot_name(consumer_mapping, + slot_name); + + return binding; + })); + + return result; +} + +static DynamicNodeInvocation + build_replicate_invocation(parallel_layer_guid_t const &layer, + ReplicateAttrs const &attrs, + MappedParallelComputationGraph const &mpcg) { + + ManyToOne incoming = + mpcg_get_incoming_tensors(mpcg, layer); + TensorSlotName input_slot_name = TensorSlotName::INPUT; + parallel_tensor_guid_t input_tensor_guid = + require_only_key(incoming.l_to_r(), input_slot_name); + ParallelTensorAttrs input_attrs = + mpcg_get_parallel_tensor_attrs(mpcg, input_tensor_guid); + + bidict outgoing = + mpcg_get_outgoing_tensors(mpcg, layer); + TensorSlotName output_slot_name = TensorSlotName::OUTPUT; + parallel_tensor_guid_t output_tensor_guid = + require_only_key(outgoing.l_to_r(), output_slot_name); + ParallelTensorAttrs output_attrs = + mpcg_get_parallel_tensor_attrs(mpcg, output_tensor_guid); + + bidict input_mapping = + get_input_mapping_for_replicate(mpcg, layer); + + DynamicValueAttrs input_value{ + /*tensor_guid=*/dynamic_tensor_guid_t{input_tensor_guid}, + /*parallel_tensor_shape=*/input_attrs.shape, + /*shard_coord=*/std::nullopt, + /*mapping=*/input_mapping, + /*accessor=*/std::nullopt, + /*role=*/std::nullopt, + }; + + DynamicValueAttrs output_value{ + /*tensor_guid=*/dynamic_tensor_guid_t{output_tensor_guid}, + /*parallel_tensor_shape=*/output_attrs.shape, + /*shard_coord=*/std::nullopt, + /*mapping=*/build_replicated_output_mapping(mpcg, output_tensor_guid), + /*accessor=*/std::nullopt, + /*role=*/std::nullopt, + }; + + DynamicNodeAttrs node_attrs{ + /*task_type=*/std::nullopt, + /*device_coord=*/std::nullopt, + /*mapping=*/std::nullopt, + /*op_attrs=*/TrainingOperationAttrs{PCGOperatorAttrs{attrs}}, + /*pcg_layer_guid=*/dynamic_layer_guid_t{layer}, + /*per_device_op_state=*/std::nullopt, + }; + + DynamicNodeInvocation invocation_node{ + /*inputs=*/{ + { + DynamicTensorSlot{input_slot_name, std::nullopt}, + input_value, + }, + }, + /*node_attrs=*/node_attrs, + /*outputs=*/ + { + { + DynamicTensorSlot{output_slot_name, std::nullopt}, + output_value, + }, + }, + }; + + return invocation_node; +} + DynamicOpenDataflowGraph make_dynamic_open_dataflow_graph_from_mapped_pcg( MappedParallelComputationGraph const &mpcg) { - DynamicOpenDataflowGraph result = make_empty_dynamic_open_dataflow_graph(); ParallelComputationGraph pcg = pcg_from_mpcg(mpcg); - for (auto const &[layer, attrs] : get_parallel_layer_attrs_mapping(pcg)) { - DynamicNodeAttrs result_attrs{ - /*task_type=*/std::nullopt, - /*device_coord=*/std::nullopt, - /*mapping=*/mpcg_get_mapping_for_layer(mpcg, layer), - /*op_attrs=*/TrainingOperationAttrs{attrs.op_attrs}, - /*pcg_layer_guid=*/dynamic_layer_guid_t{layer}, - /*per_device_op_state=*/std::nullopt, - }; + auto mk_invocation = + [&](parallel_layer_guid_t layer, + ParallelLayerAttrs const &attrs) -> DynamicNodeInvocation { + if (attrs.op_attrs.is_parallel_replicate()) { + // build replicate invocation + DynamicNodeInvocation repl_inv = build_replicate_invocation( + layer, attrs.op_attrs.require_parallel_replicate(), mpcg); + return repl_inv; + } else { + DynamicNodeAttrs result_attrs{ + /*task_type=*/std::nullopt, + /*device_coord=*/std::nullopt, + /*mapping=*/mpcg_get_mapping_for_layer(mpcg, layer), + /*op_attrs=*/TrainingOperationAttrs{attrs.op_attrs}, + /*pcg_layer_guid=*/dynamic_layer_guid_t{layer}, + /*per_device_op_state=*/std::nullopt, + }; - std::unordered_map result_inputs = - transform(get_incoming_tensors(pcg, layer), - [&](TensorSlotName const &slot_name, - parallel_tensor_guid_t const &tensor) { - ParallelTensorAttrs attrs = - get_parallel_tensor_attrs(pcg, tensor); - return std::pair{ - DynamicTensorSlot{ - /*slot_name=*/slot_name, - /*slot_tensor_role=*/std::nullopt, - }, - DynamicValueAttrs{ - /*tensor_guid=*/dynamic_tensor_guid_t{tensor}, - /*parallel_tensor_shape=*/attrs.shape, - /*shard_coord=*/std::nullopt, - /*mapping=*/std::nullopt, - /*accessor=*/std::nullopt, - /*role=*/std::nullopt, - }, - }; - }); - std::unordered_map result_outputs = - transform(get_outgoing_tensors(pcg, layer), - [&](TensorSlotName const &slot_name, - parallel_tensor_guid_t const &tensor) { - ParallelTensorAttrs attrs = - get_parallel_tensor_attrs(pcg, tensor); - return std::pair{ - DynamicTensorSlot{ - /*slot_name=*/slot_name, - /*slot_tensor_role=*/std::nullopt, - }, - DynamicValueAttrs{ - /*tensor_guid=*/dynamic_tensor_guid_t{tensor}, - /*parallel_tensor_shape=*/attrs.shape, - /*shard_coord=*/std::nullopt, - /*mapping=*/std::nullopt, - /*accessor=*/std::nullopt, - /*role=*/std::nullopt, - }, - }; - }); - - result.invocations.emplace(result_inputs, result_attrs, result_outputs); - } + auto mk_slot = [](TensorSlotName const &slot_name) -> DynamicTensorSlot { + return DynamicTensorSlot{ + /*slot_name=*/slot_name, + /*slot_tensor_role=*/std::nullopt, + }; + }; - return result; + auto mk_value_attrs = + [&](parallel_tensor_guid_t const &tensor) -> DynamicValueAttrs { + ParallelTensorAttrs attrs = get_parallel_tensor_attrs(pcg, tensor); + + return DynamicValueAttrs{ + /*tensor_guid=*/dynamic_tensor_guid_t{tensor}, + /*parallel_tensor_shape=*/attrs.shape, + /*shard_coord=*/std::nullopt, + /*mapping=*/std::nullopt, + /*accessor=*/std::nullopt, + /*role=*/std::nullopt, + }; + }; + + std::unordered_map result_inputs = + map_keys_and_values( + get_incoming_tensors(pcg, layer), mk_slot, mk_value_attrs); + + std::unordered_map result_outputs = + map_keys_and_values( + get_outgoing_tensors(pcg, layer), mk_slot, mk_value_attrs); + + DynamicNodeInvocation invocation = DynamicNodeInvocation{ + /*inputs=*/result_inputs, + /*node_attrs=*/result_attrs, + /*outputs=*/result_outputs, + }; + + return invocation; + }; + }; + + return dynamic_open_dataflow_graph_from_invocation_set(transform_pairs( + unordered_set_of(get_parallel_layer_attrs_mapping(pcg)), mk_invocation)); } } // namespace FlexFlow diff --git a/lib/task-spec/src/task-spec/dynamic_graph/pass_expansion.cc b/lib/task-spec/src/task-spec/dynamic_graph/pass_expansion.cc index 0cee06368f..64fe2df0be 100644 --- a/lib/task-spec/src/task-spec/dynamic_graph/pass_expansion.cc +++ b/lib/task-spec/src/task-spec/dynamic_graph/pass_expansion.cc @@ -1,7 +1,9 @@ #include "task-spec/dynamic_graph/pass_expansion.h" #include "task-spec/dynamic_graph/dynamic_open_dataflow_graph.h" #include "task-spec/dynamic_graph/dynamic_tensor_role.h" +#include "task-spec/dynamic_graph/training_operation_attrs.h" #include "utils/containers/are_all_same.h" +#include "utils/containers/get_only.h" #include "utils/containers/merge_disjoint_maps.h" #include "utils/containers/transform.h" @@ -82,6 +84,9 @@ DynamicNodeInvocation perform_fwd_pass_expansion_for_invocation( DynamicNodeInvocation perform_bwd_pass_expansion_for_invocation( DynamicNodeInvocation const &invocation) { + TrainingOperationAttrs op_attrs = + assert_unwrap(invocation.node_attrs.op_attrs); + auto to_fwd = [](DynamicTensorSlot const &k, DynamicValueAttrs const &v) { return std::pair{ pass_expand_slot(k, FwbTensorType::FORWARD), @@ -96,17 +101,37 @@ DynamicNodeInvocation perform_bwd_pass_expansion_for_invocation( }; }; - return DynamicNodeInvocation{ - /*inputs=*/ - merge_disjoint_maps(std::vector{ - transform(invocation.inputs, to_fwd), - transform(invocation.outputs, to_fwd), - transform(invocation.outputs, to_grad), - }), - /*node_attrs=*/ - pass_expand_node(invocation.node_attrs, DynamicTaskType::BWD), - /*outputs=*/ - transform(invocation.inputs, to_grad), + if (training_op_attrs_has_op_type(op_attrs, OperatorType::REPLICATE)) { + auto [input_slot, input] = get_only(invocation.inputs); + auto [output_slot, output] = get_only(invocation.outputs); + + DynamicNodeInvocation bwd{ + /*inputs=*/{ + to_fwd(output_slot, output), + to_grad(output_slot, output), + }, + /*node_attrs=*/ + pass_expand_node(invocation.node_attrs, DynamicTaskType::BWD), + /*outputs=*/ + { + to_grad(input_slot, input), + }, + }; + + return bwd; + } else { + return DynamicNodeInvocation{ + /*inputs=*/ + merge_disjoint_maps(std::vector{ + transform(invocation.inputs, to_fwd), + transform(invocation.outputs, to_fwd), + transform(invocation.outputs, to_grad), + }), + /*node_attrs=*/ + pass_expand_node(invocation.node_attrs, DynamicTaskType::BWD), + /*outputs=*/ + transform(invocation.inputs, to_grad), + }; }; } diff --git a/lib/task-spec/src/task-spec/dynamic_graph/shard_expansion.cc b/lib/task-spec/src/task-spec/dynamic_graph/shard_expansion.cc index fb6efb96d0..d3365ae44c 100644 --- a/lib/task-spec/src/task-spec/dynamic_graph/shard_expansion.cc +++ b/lib/task-spec/src/task-spec/dynamic_graph/shard_expansion.cc @@ -39,7 +39,6 @@ bool graph_is_fully_shard_expanded(DynamicOpenDataflowGraph const &g) { value_is_shard_expanded, slot_is_shard_expanded); } - static bidict restrict_tensor_mapping_keys_to_coord( bidict const @@ -85,6 +84,113 @@ static DynamicNodeInvocation shard_invocation_for_binding( }; } +static std::unordered_set + perform_shard_expansion_for_replicate(DynamicNodeInvocation const &i) { + auto const &[input_slot, input] = get_only(i.inputs); + auto const &[output_slot, output] = get_only(i.outputs); + + bidict input_mapping = + assert_unwrap(input.mapping); + bidict output_mapping = + assert_unwrap(output.mapping); + + return transform(output_mapping.left_values(), + [&](ParallelTensorSpaceCoordinate const &p) { + ParallelTensorSpaceCoordinate input_p{ + /*sum_component=*/p.sum_component, + /*discard_copy_component=*/nonnegative_int{0}, + /*shard_components=*/p.shard_components, + }; + return shard_invocation_for_binding( + i, + output_mapping.at_l(p), + OperatorAtomicTaskShardBinding{{ + {input_slot.slot_name, input_p}, + {output_slot.slot_name, p}, + }}); + }); +} + +static std::unordered_set + perform_shard_expansion_for_replicate_bwd(DynamicNodeInvocation const &i) { + + std::optional output_grad_opt; + std::optional output_fwd_opt; + std::optional output_grad_slot_opt; + std::optional output_fwd_slot_opt; + + for (auto const &[slot, value] : i.inputs) { + if (slot.slot_tensor_role == DynamicTensorRole{FwbTensorType::GRADIENT}) { + output_grad_slot_opt = slot; + output_grad_opt = value; + } else { + output_fwd_slot_opt = slot; + output_fwd_opt = value; + } + } + + DynamicValueAttrs output_grad = assert_unwrap(output_grad_opt); + DynamicValueAttrs output_fwd = assert_unwrap(output_fwd_opt); + DynamicTensorSlot output_grad_slot = assert_unwrap(output_grad_slot_opt); + DynamicTensorSlot output_fwd_slot = assert_unwrap(output_fwd_slot_opt); + auto const &[input_grad_slot, input_grad] = get_only(i.outputs); + + bidict + output_grad_mapping = assert_unwrap(output_grad.mapping); + bidict + input_grad_mapping = assert_unwrap(input_grad.mapping); + + std::unordered_map, + std::unordered_set> + by_shard; + for (auto const &p : output_grad_mapping.left_values()) { + by_shard[p.shard_components].insert(p); + } + + std::unordered_set result; + for (auto const &[shard_components, replica_coords] : by_shard) { + ParallelTensorSpaceCoordinate src_p{ + nonnegative_int{0}, nonnegative_int{0}, shard_components}; + MachineSpaceCoordinate src_machine = input_grad_mapping.at_l(src_p); + + bidict + replica_mapping; + for (auto const &p : replica_coords) { + replica_mapping.equate(p, output_grad_mapping.at_l(p)); + } + + DynamicValueAttrs sharded_output_grad = output_grad; + sharded_output_grad.mapping = replica_mapping; + sharded_output_grad.shard_coord = src_p; + + DynamicValueAttrs sharded_output_fwd = output_fwd; + sharded_output_fwd.mapping = replica_mapping; + sharded_output_fwd.shard_coord = src_p; + + DynamicValueAttrs sharded_input_grad = input_grad; + sharded_input_grad.mapping = + bidict{ + {src_p, src_machine}}; + sharded_input_grad.shard_coord = src_p; + + DynamicNodeAttrs sharded_node = i.node_attrs; + sharded_node.device_coord = src_machine; + + result.insert(DynamicNodeInvocation{ + /*inputs=*/{ + {output_fwd_slot, sharded_output_fwd}, + {output_grad_slot, sharded_output_grad}, + }, + /*node_attrs=*/sharded_node, + /*outputs=*/ + { + {input_grad_slot, sharded_input_grad}, + }, + }); + } + return result; +} + static std::unordered_set perform_shard_expansion_for_copy(DynamicNodeInvocation const &i) { auto [input_slot, input] = get_only(i.inputs); @@ -121,6 +227,25 @@ std::unordered_set return perform_shard_expansion_for_copy(i); } + bool const is_replicate = + i.node_attrs.op_attrs.has_value() && + i.node_attrs.op_attrs.value().has() && + i.node_attrs.op_attrs.value() + .get() + .has(); + + // forward replicate + if (is_replicate && i.node_attrs.task_type.has_value() && + i.node_attrs.task_type.value() == DynamicTaskType::FWD) { + return perform_shard_expansion_for_replicate(i); + } + + // backward replicate + if (is_replicate && i.node_attrs.task_type.has_value() && + i.node_attrs.task_type.value() == DynamicTaskType::BWD) { + return perform_shard_expansion_for_replicate_bwd(i); + } + MappedOperatorTaskGroup mapping = assert_unwrap(i.node_attrs.mapping); std::unordered_set shard_machine_coords = diff --git a/lib/task-spec/src/task-spec/dynamic_graph/training_operation_attrs.cc b/lib/task-spec/src/task-spec/dynamic_graph/training_operation_attrs.cc new file mode 100644 index 0000000000..a9be225ff5 --- /dev/null +++ b/lib/task-spec/src/task-spec/dynamic_graph/training_operation_attrs.cc @@ -0,0 +1,18 @@ +#include "task-spec/dynamic_graph/training_operation_attrs.h" +#include "op-attrs/pcg_operator_attrs.h" +#include "utils/overload.h" + +namespace FlexFlow { + +bool training_op_attrs_has_op_type(TrainingOperationAttrs const &op_attrs, + OperatorType op_type) { + return op_attrs.visit(overload{ + [&](PCGOperatorAttrs const &pcg_op_attrs) -> bool { + return pcg_op_attrs_get_op_type(pcg_op_attrs) == op_type; + }, + [](LossAttrs const &) -> bool { return false; }, + [](CopyAttrs const &) -> bool { return false; }, + }); +} + +} // namespace FlexFlow diff --git a/lib/task-spec/src/task-spec/ops/impl/element_binary.cc b/lib/task-spec/src/task-spec/ops/impl/element_binary.cc index 13465d7a5f..c8460af538 100644 --- a/lib/task-spec/src/task-spec/ops/impl/element_binary.cc +++ b/lib/task-spec/src/task-spec/ops/impl/element_binary.cc @@ -36,8 +36,8 @@ static std::optional forward_task_impl(TaskArgumentAccessor const &acc) { ProfilingSettings profiling = acc.get_profiling_settings(); DeviceType kernel_device_type = acc.get_kernel_device_type(); - ElementBinaryPerDeviceState per_device_state = - acc.get_per_device_op_state().require_element_binary().value(); + std::optional per_device_state = + acc.get_per_device_op_state().require_element_binary(); ElementBinaryAttrs attrs = acc.get_op_attrs().require_element_binary(); device_handle_t handle = acc.get_ff_handle(); @@ -62,8 +62,8 @@ static std::optional backward_task_impl(TaskArgumentAccessor const &acc) { ProfilingSettings profiling = acc.get_profiling_settings(); DeviceType kernel_device_type = acc.get_kernel_device_type(); - ElementBinaryPerDeviceState per_device_state = - acc.get_per_device_op_state().require_element_binary().value(); + std::optional per_device_state = + acc.get_per_device_op_state().require_element_binary(); ElementBinaryAttrs attrs = acc.get_op_attrs().require_element_binary(); device_handle_t handle = acc.get_ff_handle(); diff --git a/lib/task-spec/src/task-spec/ops/impl/element_unary.cc b/lib/task-spec/src/task-spec/ops/impl/element_unary.cc index d66ff9ab8d..9a092b90b8 100644 --- a/lib/task-spec/src/task-spec/ops/impl/element_unary.cc +++ b/lib/task-spec/src/task-spec/ops/impl/element_unary.cc @@ -35,8 +35,8 @@ static std::optional ProfilingSettings profiling = acc.get_profiling_settings(); DeviceType kernel_device_type = acc.get_kernel_device_type(); - ElementUnaryPerDeviceState per_device_state = - acc.get_per_device_op_state().require_element_unary().value(); + std::optional per_device_state = + acc.get_per_device_op_state().require_element_unary(); return profile(forward_kernel, profiling, @@ -62,8 +62,8 @@ static std::optional ProfilingSettings profiling = acc.get_profiling_settings(); DeviceType kernel_device_type = acc.get_kernel_device_type(); - ElementUnaryPerDeviceState per_device_state = - acc.get_per_device_op_state().require_element_unary().value(); + std::optional per_device_state = + acc.get_per_device_op_state().require_element_unary(); return profile(backward_kernel, profiling, diff --git a/lib/task-spec/test/src/task-spec/dynamic_graph/pass_expansion.cc b/lib/task-spec/test/src/task-spec/dynamic_graph/pass_expansion.cc index fb087f5295..bf88d5ec38 100644 --- a/lib/task-spec/test/src/task-spec/dynamic_graph/pass_expansion.cc +++ b/lib/task-spec/test/src/task-spec/dynamic_graph/pass_expansion.cc @@ -1,4 +1,5 @@ #include "task-spec/dynamic_graph/pass_expansion.h" +#include "op-attrs/ops/element_unary.h" #include "task-spec/dynamic_graph/dynamic_open_dataflow_graph.h" #include "task-spec/dynamic_graph/dynamic_tensor_role.h" #include @@ -36,6 +37,18 @@ TEST_SUITE(FF_TEST_SUITE) { dynamic_layer_guid_t layer_guid{parallel_layer_guid_t{Node{20}}}; + TrainingOperationAttrs op_attrs = TrainingOperationAttrs{ + PCGOperatorAttrs{ + LinearAttrs{ + /*out_channels=*/8_p, + /*use_bias=*/true, + /*data_type=*/DataType::FLOAT, + /*activation=*/std::nullopt, + /*regularizer=*/std::nullopt, + }, + }, + }; + DynamicNodeInvocation invocation = [&]() -> DynamicNodeInvocation { DynamicValueAttrs v1 = mk_value_attrs(0, std::nullopt); DynamicValueAttrs v2 = mk_value_attrs(1, std::nullopt); @@ -46,14 +59,13 @@ TEST_SUITE(FF_TEST_SUITE) { {mk_slot(TensorSlotName::INPUT, std::nullopt), v1}, {mk_slot(TensorSlotName::WEIGHT, std::nullopt), v2}, {mk_slot(TensorSlotName::BIAS, std::nullopt), v1}, - {mk_slot(TensorSlotName::SCALE, std::nullopt), v1}, }, /*node_attrs=*/ DynamicNodeAttrs{ /*task_type=*/std::nullopt, /*device_coord=*/std::nullopt, /*mapping=*/std::nullopt, - /*op_attrs=*/std::nullopt, + /*op_attrs=*/op_attrs, /*layer_guid=*/layer_guid, /*per_device_op_state=*/std::nullopt, }, @@ -79,14 +91,13 @@ TEST_SUITE(FF_TEST_SUITE) { {mk_slot(TensorSlotName::INPUT, fwd_role), v1_fwd}, {mk_slot(TensorSlotName::WEIGHT, fwd_role), v2_fwd}, {mk_slot(TensorSlotName::BIAS, fwd_role), v1_fwd}, - {mk_slot(TensorSlotName::SCALE, fwd_role), v1_fwd}, }, /*node_attrs=*/ DynamicNodeAttrs{ /*task_type=*/DynamicTaskType::FWD, /*device_coord=*/std::nullopt, /*mapping=*/std::nullopt, - /*op_attrs=*/std::nullopt, + /*op_attrs=*/op_attrs, /*layer_guid=*/layer_guid, /*per_device_op_state=*/std::nullopt, }, @@ -130,88 +141,162 @@ TEST_SUITE(FF_TEST_SUITE) { dynamic_layer_guid_t layer_guid{parallel_layer_guid_t{Node{20}}}; - DynamicNodeInvocation invocation = [&]() -> DynamicNodeInvocation { - DynamicValueAttrs v1 = mk_value_attrs(0, std::nullopt); - DynamicValueAttrs v2 = mk_value_attrs(1, std::nullopt); - DynamicValueAttrs v3 = mk_value_attrs(2, std::nullopt); - - return DynamicNodeInvocation{ - /*inputs=*/{ - {mk_slot(TensorSlotName::INPUT, std::nullopt), v1}, - {mk_slot(TensorSlotName::WEIGHT, std::nullopt), v2}, - {mk_slot(TensorSlotName::BIAS, std::nullopt), v1}, - {mk_slot(TensorSlotName::SCALE, std::nullopt), v1}, - }, - /*node_attrs=*/ - DynamicNodeAttrs{ - /*task_type=*/std::nullopt, - /*device_coord=*/std::nullopt, - /*mapping=*/std::nullopt, - /*op_attrs=*/std::nullopt, - /*layer_guid=*/layer_guid, - /*per_device_op_state=*/std::nullopt, - }, - /*outputs=*/ - { - {mk_slot(TensorSlotName::OUTPUT, std::nullopt), v3}, + DynamicValueAttrs v1 = mk_value_attrs(0, std::nullopt); + DynamicValueAttrs v2 = mk_value_attrs(1, std::nullopt); + DynamicValueAttrs v3 = mk_value_attrs(2, std::nullopt); + + DynamicTensorRole fwd_role = DynamicTensorRole{FwbTensorType::FORWARD}; + DynamicTensorRole grad_role = DynamicTensorRole{FwbTensorType::GRADIENT}; + + DynamicValueAttrs v1_fwd = mk_value_attrs(0, fwd_role); + DynamicValueAttrs v2_fwd = mk_value_attrs(1, fwd_role); + DynamicValueAttrs v3_fwd = mk_value_attrs(2, fwd_role); + DynamicValueAttrs v1_grad = mk_value_attrs(0, grad_role); + DynamicValueAttrs v2_grad = mk_value_attrs(1, grad_role); + DynamicValueAttrs v3_grad = mk_value_attrs(2, grad_role); + + SUBCASE("normal operator") { + TrainingOperationAttrs op_attrs = TrainingOperationAttrs{ + PCGOperatorAttrs{ + LinearAttrs{ + /*out_channels=*/8_p, + /*use_bias=*/true, + /*data_type=*/DataType::FLOAT, + /*activation=*/std::nullopt, + /*regularizer=*/std::nullopt, + }, }, }; - }(); - - DynamicNodeInvocation result = - perform_bwd_pass_expansion_for_invocation(invocation); - - DynamicNodeInvocation correct = [&]() -> DynamicNodeInvocation { - DynamicTensorRole fwd_role = DynamicTensorRole{FwbTensorType::FORWARD}; - DynamicTensorRole grad_role = DynamicTensorRole{FwbTensorType::GRADIENT}; - - DynamicValueAttrs v1_fwd = mk_value_attrs(0, fwd_role); - DynamicValueAttrs v2_fwd = mk_value_attrs(1, fwd_role); - DynamicValueAttrs v3_fwd = mk_value_attrs(2, fwd_role); - DynamicValueAttrs v1_grad = mk_value_attrs(0, grad_role); - DynamicValueAttrs v2_grad = mk_value_attrs(1, grad_role); - DynamicValueAttrs v3_grad = mk_value_attrs(2, grad_role); - return DynamicNodeInvocation{ - /*inputs=*/{ - {mk_slot(TensorSlotName::INPUT, fwd_role), v1_fwd}, - {mk_slot(TensorSlotName::WEIGHT, fwd_role), v2_fwd}, - {mk_slot(TensorSlotName::BIAS, fwd_role), v1_fwd}, - {mk_slot(TensorSlotName::SCALE, fwd_role), v1_fwd}, - {mk_slot(TensorSlotName::OUTPUT, fwd_role), v3_fwd}, - {mk_slot(TensorSlotName::OUTPUT, grad_role), v3_grad}, - }, - /*node_attrs=*/ - DynamicNodeAttrs{ - /*pass_type=*/DynamicTaskType::BWD, - /*device_coord=*/std::nullopt, - /*mapping=*/std::nullopt, - /*op_attrs=*/std::nullopt, - /*layer_guid=*/layer_guid, - /*per_device_op_state=*/std::nullopt, - }, - /*outputs=*/ - { - {mk_slot(TensorSlotName::INPUT, grad_role), v1_grad}, - {mk_slot(TensorSlotName::WEIGHT, grad_role), v2_grad}, - {mk_slot(TensorSlotName::BIAS, grad_role), v1_grad}, - {mk_slot(TensorSlotName::SCALE, grad_role), v1_grad}, + DynamicNodeInvocation invocation = [&]() -> DynamicNodeInvocation { + return DynamicNodeInvocation{ + /*inputs=*/{ + {mk_slot(TensorSlotName::INPUT, std::nullopt), v1}, + {mk_slot(TensorSlotName::WEIGHT, std::nullopt), v2}, + {mk_slot(TensorSlotName::BIAS, std::nullopt), v1}, + }, + /*node_attrs=*/ + DynamicNodeAttrs{ + /*task_type=*/std::nullopt, + /*device_coord=*/std::nullopt, + /*mapping=*/std::nullopt, + /*op_attrs=*/op_attrs, + /*layer_guid=*/layer_guid, + /*per_device_op_state=*/std::nullopt, + }, + /*outputs=*/ + { + {mk_slot(TensorSlotName::OUTPUT, std::nullopt), v3}, + }, + }; + }(); + + DynamicNodeInvocation result = + perform_bwd_pass_expansion_for_invocation(invocation); + + DynamicNodeInvocation correct = [&]() -> DynamicNodeInvocation { + return DynamicNodeInvocation{ + /*inputs=*/{ + {mk_slot(TensorSlotName::INPUT, fwd_role), v1_fwd}, + {mk_slot(TensorSlotName::WEIGHT, fwd_role), v2_fwd}, + {mk_slot(TensorSlotName::BIAS, fwd_role), v1_fwd}, + {mk_slot(TensorSlotName::OUTPUT, fwd_role), v3_fwd}, + {mk_slot(TensorSlotName::OUTPUT, grad_role), v3_grad}, + }, + /*node_attrs=*/ + DynamicNodeAttrs{ + /*pass_type=*/DynamicTaskType::BWD, + /*device_coord=*/std::nullopt, + /*mapping=*/std::nullopt, + /*op_attrs=*/op_attrs, + /*layer_guid=*/layer_guid, + /*per_device_op_state=*/std::nullopt, + }, + /*outputs=*/ + { + {mk_slot(TensorSlotName::INPUT, grad_role), v1_grad}, + {mk_slot(TensorSlotName::WEIGHT, grad_role), v2_grad}, + {mk_slot(TensorSlotName::BIAS, grad_role), v1_grad}, + }, + }; + }(); + + ASSERT(result == correct); + } + + SUBCASE("replicate operator optimization") { + TrainingOperationAttrs op_attrs = TrainingOperationAttrs{ + PCGOperatorAttrs{ + ReplicateAttrs{ + /*replicate_degree=*/2_p, + }, }, }; - }(); - ASSERT(result == correct); + DynamicNodeInvocation invocation = [&]() -> DynamicNodeInvocation { + return DynamicNodeInvocation{ + /*inputs=*/{ + {mk_slot(TensorSlotName::INPUT, std::nullopt), v1}, + }, + /*node_attrs=*/ + DynamicNodeAttrs{ + /*task_type=*/std::nullopt, + /*device_coord=*/std::nullopt, + /*mapping=*/std::nullopt, + /*op_attrs=*/op_attrs, + /*layer_guid=*/layer_guid, + /*per_device_op_state=*/std::nullopt, + }, + /*outputs=*/ + { + {mk_slot(TensorSlotName::OUTPUT, std::nullopt), v2}, + }, + }; + }(); + + DynamicNodeInvocation result = + perform_bwd_pass_expansion_for_invocation(invocation); + + DynamicNodeInvocation correct = [&]() -> DynamicNodeInvocation { + DynamicTensorRole fwd_role = DynamicTensorRole{FwbTensorType::FORWARD}; + DynamicTensorRole grad_role = + DynamicTensorRole{FwbTensorType::GRADIENT}; + + return DynamicNodeInvocation{ + /*inputs=*/{ + {mk_slot(TensorSlotName::OUTPUT, fwd_role), v2_fwd}, + {mk_slot(TensorSlotName::OUTPUT, grad_role), v2_grad}, + }, + /*node_attrs=*/ + DynamicNodeAttrs{ + /*pass_type=*/DynamicTaskType::BWD, + /*device_coord=*/std::nullopt, + /*mapping=*/std::nullopt, + /*op_attrs=*/op_attrs, + /*layer_guid=*/layer_guid, + /*per_device_op_state=*/std::nullopt, + }, + /*outputs=*/ + { + {mk_slot(TensorSlotName::INPUT, grad_role), v1_grad}, + }, + }; + }(); + + ASSERT(result == correct); + } } TEST_CASE("perform_pass_expansion(DynamicOpenDataflowGraph)") { auto mk_node_attrs = [](size_t layer_id, + TrainingOperationAttrs const &op_attrs, std::optional const &pass_type) -> DynamicNodeAttrs { return DynamicNodeAttrs{ /*pass_type=*/pass_type, /*device_coord=*/std::nullopt, /*mapping=*/std::nullopt, - /*op_attrs=*/std::nullopt, + /*op_attrs=*/op_attrs, /*layer_guid=*/ dynamic_layer_guid_t{parallel_layer_guid_t{Node{layer_id}}}, /*per_device_op_state=*/std::nullopt, @@ -236,9 +321,31 @@ TEST_SUITE(FF_TEST_SUITE) { }; }; + TrainingOperationAttrs input_op_attrs = TrainingOperationAttrs{ + PCGOperatorAttrs{ + InputAttrs{ + TensorShape{ + TensorDims{ + FFOrdered{ + 4_p, + 8_p, + }, + }, + DataType::FLOAT, + }, + }, + }, + }; + + TrainingOperationAttrs relu_op_attrs = TrainingOperationAttrs{ + PCGOperatorAttrs{ + make_relu_attrs(), + }, + }; + DynamicOpenDataflowGraph input = [&]() -> DynamicOpenDataflowGraph { - DynamicNodeAttrs n1 = mk_node_attrs(10, std::nullopt); - DynamicNodeAttrs n2 = mk_node_attrs(11, std::nullopt); + DynamicNodeAttrs n1 = mk_node_attrs(10, input_op_attrs, std::nullopt); + DynamicNodeAttrs n2 = mk_node_attrs(11, relu_op_attrs, std::nullopt); DynamicValueAttrs v1 = mk_value_attrs(0, std::nullopt); DynamicValueAttrs v2 = mk_value_attrs(1, std::nullopt); @@ -286,10 +393,14 @@ TEST_SUITE(FF_TEST_SUITE) { DynamicOpenDataflowGraph result = perform_pass_expansion(input); DynamicOpenDataflowGraph correct = [&]() -> DynamicOpenDataflowGraph { - DynamicNodeAttrs n1_fwd = mk_node_attrs(10, DynamicTaskType::FWD); - DynamicNodeAttrs n2_fwd = mk_node_attrs(11, DynamicTaskType::FWD); - DynamicNodeAttrs n1_bwd = mk_node_attrs(10, DynamicTaskType::BWD); - DynamicNodeAttrs n2_bwd = mk_node_attrs(11, DynamicTaskType::BWD); + DynamicNodeAttrs n1_fwd = + mk_node_attrs(10, input_op_attrs, DynamicTaskType::FWD); + DynamicNodeAttrs n2_fwd = + mk_node_attrs(11, relu_op_attrs, DynamicTaskType::FWD); + DynamicNodeAttrs n1_bwd = + mk_node_attrs(10, input_op_attrs, DynamicTaskType::BWD); + DynamicNodeAttrs n2_bwd = + mk_node_attrs(11, relu_op_attrs, DynamicTaskType::BWD); DynamicValueAttrs v1_activation = mk_value_attrs(0, mk_dynamic_tensor_role_fwd()); diff --git a/lib/utils/include/utils/bidict/algorithms/binary_merge_disjoint_bidicts.h b/lib/utils/include/utils/bidict/algorithms/binary_merge_disjoint_bidicts.h new file mode 100644 index 0000000000..5b0bb45910 --- /dev/null +++ b/lib/utils/include/utils/bidict/algorithms/binary_merge_disjoint_bidicts.h @@ -0,0 +1,37 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_BIDICT_ALGORITHMS_BINARY_MERGE_DISJOINT_BIDICTS_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_BIDICT_ALGORITHMS_BINARY_MERGE_DISJOINT_BIDICTS_H + +#include "utils/bidict/algorithms/left_entries.h" +#include "utils/bidict/algorithms/right_entries.h" +#include "utils/bidict/bidict.h" +#include "utils/containers/are_disjoint.h" +#include "utils/exception.h" + +namespace FlexFlow { + +template +bidict binary_merge_disjoint_bidicts(bidict const &lhs, + bidict const &rhs) { + if (!are_disjoint(left_entries(lhs), left_entries(rhs))) { + throw mk_runtime_error( + fmt::format("Left entries of {} and {} are non-disjoint", lhs, rhs)); + } + if (!are_disjoint(right_entries(lhs), right_entries(rhs))) { + throw mk_runtime_error( + fmt::format("Right entries of {} and {} are non-disjoint", lhs, rhs)); + } + + bidict result; + for (auto const &kv : lhs) { + result.equate_strict(kv.first, kv.second); + } + for (auto const &kv : rhs) { + result.equate_strict(kv.first, kv.second); + } + + return result; +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/bidict/algorithms/merge_disjoint_bidicts.h b/lib/utils/include/utils/bidict/algorithms/merge_disjoint_bidicts.h index 97e7334c26..0c944bb9bd 100644 --- a/lib/utils/include/utils/bidict/algorithms/merge_disjoint_bidicts.h +++ b/lib/utils/include/utils/bidict/algorithms/merge_disjoint_bidicts.h @@ -1,35 +1,21 @@ #ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_BIDICT_ALGORITHMS_MERGE_DISJOINT_BIDICTS_H #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_BIDICT_ALGORITHMS_MERGE_DISJOINT_BIDICTS_H -#include "utils/bidict/algorithms/left_entries.h" -#include "utils/bidict/algorithms/right_entries.h" -#include "utils/bidict/bidict.h" -#include "utils/containers/are_disjoint.h" -#include "utils/exception.h" +#include "utils/bidict/algorithms/binary_merge_disjoint_bidicts.h" +#include "utils/containers/foldl.h" namespace FlexFlow { -template -bidict merge_disjoint_bidicts(bidict const &lhs, - bidict const &rhs) { - if (!are_disjoint(left_entries(lhs), left_entries(rhs))) { - throw mk_runtime_error( - fmt::format("Left entries of {} and {} are non-disjoint", lhs, rhs)); - } - if (!are_disjoint(right_entries(lhs), right_entries(rhs))) { - throw mk_runtime_error( - fmt::format("Right entries of {} and {} are non-disjoint", lhs, rhs)); - } - - bidict result; - for (auto const &kv : lhs) { - result.equate(kv.first, kv.second); - } - for (auto const &kv : rhs) { - result.equate(kv.first, kv.second); - } - - return result; +template +bidict merge_disjoint_bidicts(C const &c) { + bidict empty = {}; + return foldl(c, + /*init=*/empty, + [](bidict const &lhs, bidict const &rhs) { + return binary_merge_disjoint_bidicts(lhs, rhs); + }); } } // namespace FlexFlow diff --git a/lib/utils/include/utils/bidict/bidict.h b/lib/utils/include/utils/bidict/bidict.h index 5dbd1c603d..2d8c5d23a8 100644 --- a/lib/utils/include/utils/bidict/bidict.h +++ b/lib/utils/include/utils/bidict/bidict.h @@ -213,6 +213,14 @@ struct bidict { return this->fwd_map; } + std::unordered_map const &l_to_r() const { + return this->fwd_map; + } + + std::unordered_map const &r_to_l() const { + return this->bwd_map; + } + bidict(std::unordered_map const &fwd_map, std::unordered_map const &bwd_map) : fwd_map(fwd_map), bwd_map(bwd_map) {} diff --git a/lib/utils/include/utils/containers/transform_pairs.h b/lib/utils/include/utils/containers/transform_pairs.h new file mode 100644 index 0000000000..3e421ea445 --- /dev/null +++ b/lib/utils/include/utils/containers/transform_pairs.h @@ -0,0 +1,47 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_TRANSFORM_PAIRS_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_TRANSFORM_PAIRS_H + +#include "utils/containers/transform.h" + +namespace FlexFlow { + +template > +std::vector transform_pairs(std::vector> const &c, F &&f) { + auto ff = [&](std::pair const &p) -> Out { + return f(p.first, p.second); + }; + + return transform(c, ff); +} + +template > +std::unordered_set + transform_pairs(std::unordered_set> const &c, F &&f) { + auto ff = [&](std::pair const &p) -> Out { + return f(p.first, p.second); + }; + + return transform(c, ff); +} + +template > +std::set transform_pairs(std::set> const &c, F &&f) { + auto ff = [&](std::pair const &p) -> Out { + return f(p.first, p.second); + }; + + return transform(c, ff); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/kwarg_dataflow_graph/algorithms/get_kwarg_dataflow_value_uses.h b/lib/utils/include/utils/graph/kwarg_dataflow_graph/algorithms/get_kwarg_dataflow_value_uses.h new file mode 100644 index 0000000000..52c225d157 --- /dev/null +++ b/lib/utils/include/utils/graph/kwarg_dataflow_graph/algorithms/get_kwarg_dataflow_value_uses.h @@ -0,0 +1,28 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_KWARG_DATAFLOW_GRAPH_ALGORITHMS_GET_KWARG_DATAFLOW_VALUE_USES_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_KWARG_DATAFLOW_GRAPH_ALGORITHMS_GET_KWARG_DATAFLOW_VALUE_USES_H + +#include "utils/graph/kwarg_dataflow_graph/kwarg_dataflow_graph_view.h" + +namespace FlexFlow { + +template +std::unordered_set> + get_kwarg_dataflow_value_uses(KwargDataflowGraphView const &g, + KwargDataflowOutput const &v) { + + KwargDataflowEdgeQuery query = KwargDataflowEdgeQuery{ + /*src_nodes=*/query_set::match_single_value(v.node), + /*src_slots=*/query_set::match_single_value(v.slot_name), + /*dst_nodes=*/query_set::matchall(), + /*dst_slots=*/query_set::matchall(), + }; + + std::unordered_set> edges = g.query_edges(query); + + return transform(edges, + [&](KwargDataflowEdge const &e) { return e.dst; }); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/many_to_one/many_to_one.h b/lib/utils/include/utils/many_to_one/many_to_one.h index d2f727661c..2d078eb304 100644 --- a/lib/utils/include/utils/many_to_one/many_to_one.h +++ b/lib/utils/include/utils/many_to_one/many_to_one.h @@ -2,6 +2,7 @@ #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_MANY_TO_ONE_MANY_TO_ONE_H #include "utils/containers/keys.h" +#include "utils/containers/require_same.h" #include "utils/containers/try_at.h" #include "utils/containers/unordered_set_of.h" #include "utils/containers/values.h" @@ -106,6 +107,10 @@ struct ManyToOne { return this->m_r_to_l; } + bool empty() const { + return require_same(this->m_l_to_r.empty(), this->m_r_to_l.empty()); + } + private: std::unordered_map m_l_to_r; std::unordered_map> m_r_to_l; diff --git a/lib/utils/include/utils/one_to_many/one_to_many.h b/lib/utils/include/utils/one_to_many/one_to_many.h index 30d84d34c3..5492ff3f78 100644 --- a/lib/utils/include/utils/one_to_many/one_to_many.h +++ b/lib/utils/include/utils/one_to_many/one_to_many.h @@ -4,6 +4,7 @@ #include "utils/containers/generate_map.h" #include "utils/containers/items.h" #include "utils/containers/keys.h" +#include "utils/containers/require_same.h" #include "utils/containers/transform.h" #include "utils/containers/try_at.h" #include "utils/containers/unordered_set_of.h" @@ -114,6 +115,10 @@ struct OneToMany { return this->m_r_to_l; } + bool empty() const { + return require_same(this->m_l_to_r.empty(), this->m_r_to_l.empty()); + } + private: std::unordered_map> m_l_to_r; std::unordered_map m_r_to_l; diff --git a/lib/utils/src/utils/bidict/algorithms/binary_merge_disjoint_bidicts.cc b/lib/utils/src/utils/bidict/algorithms/binary_merge_disjoint_bidicts.cc new file mode 100644 index 0000000000..13a1bcd968 --- /dev/null +++ b/lib/utils/src/utils/bidict/algorithms/binary_merge_disjoint_bidicts.cc @@ -0,0 +1,12 @@ +#include "utils/bidict/algorithms/binary_merge_disjoint_bidicts.h" +#include "utils/archetypes/value_type.h" + +namespace FlexFlow { + +using K = value_type<0>; +using V = value_type<1>; + +template bidict binary_merge_disjoint_bidicts(bidict const &, + bidict const &); + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/bidict/algorithms/merge_disjoint_bidicts.cc b/lib/utils/src/utils/bidict/algorithms/merge_disjoint_bidicts.cc index 754b8d2e90..2c27821d3b 100644 --- a/lib/utils/src/utils/bidict/algorithms/merge_disjoint_bidicts.cc +++ b/lib/utils/src/utils/bidict/algorithms/merge_disjoint_bidicts.cc @@ -1 +1,11 @@ #include "utils/bidict/algorithms/merge_disjoint_bidicts.h" +#include "utils/archetypes/value_type.h" + +namespace FlexFlow { + +using K = value_type<0>; +using V = value_type<1>; + +template bidict merge_disjoint_bidicts(std::vector> const &); + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/containers/transform_pairs.cc b/lib/utils/src/utils/containers/transform_pairs.cc new file mode 100644 index 0000000000..4afda936e4 --- /dev/null +++ b/lib/utils/src/utils/containers/transform_pairs.cc @@ -0,0 +1,17 @@ +#include "utils/containers/transform_pairs.h" +#include "utils/archetypes/value_type.h" + +namespace FlexFlow { + +using L = value_type<0>; +using R = value_type<1>; +using Out = value_type<2>; +using F = std::function; + +template std::vector transform_pairs(std::vector> const &, + F &&); + +template std::unordered_set + transform_pairs(std::unordered_set> const &, F &&); + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/kwarg_dataflow_graph/algorithms/get_kwarg_dataflow_value_uses.cc b/lib/utils/src/utils/graph/kwarg_dataflow_graph/algorithms/get_kwarg_dataflow_value_uses.cc new file mode 100644 index 0000000000..b1d2988223 --- /dev/null +++ b/lib/utils/src/utils/graph/kwarg_dataflow_graph/algorithms/get_kwarg_dataflow_value_uses.cc @@ -0,0 +1,12 @@ +#include "utils/graph/kwarg_dataflow_graph/algorithms/get_kwarg_dataflow_value_uses.h" +#include "utils/archetypes/ordered_value_type.h" + +namespace FlexFlow { + +using SlotName = ordered_value_type<0>; + +template std::unordered_set> + get_kwarg_dataflow_value_uses(KwargDataflowGraphView const &, + KwargDataflowOutput const &); + +} // namespace FlexFlow diff --git a/lib/utils/test/src/utils/bidict/algorithms/merge_disjoint_bidicts.cc b/lib/utils/test/src/utils/bidict/algorithms/binary_merge_disjoint_bidicts.cc similarity index 72% rename from lib/utils/test/src/utils/bidict/algorithms/merge_disjoint_bidicts.cc rename to lib/utils/test/src/utils/bidict/algorithms/binary_merge_disjoint_bidicts.cc index 0a1babd9f9..8a3371b8d8 100644 --- a/lib/utils/test/src/utils/bidict/algorithms/merge_disjoint_bidicts.cc +++ b/lib/utils/test/src/utils/bidict/algorithms/binary_merge_disjoint_bidicts.cc @@ -1,17 +1,17 @@ -#include "utils/bidict/algorithms/merge_disjoint_bidicts.h" +#include "utils/bidict/algorithms/binary_merge_disjoint_bidicts.h" #include using namespace ::FlexFlow; TEST_SUITE(FF_TEST_SUITE) { - TEST_CASE("merge_disjoint_bidicts") { + TEST_CASE("binary_merge_disjoint_bidicts") { SUBCASE("disjoint keys and values") { bidict bd1 = {{1, "one"}, {2, "two"}}; bidict bd2 = {{3, "three"}, {4, "four"}}; - bidict result = merge_disjoint_bidicts(bd1, bd2); + bidict result = binary_merge_disjoint_bidicts(bd1, bd2); bidict correct = { {1, "one"}, {2, "two"}, {3, "three"}, {4, "four"}}; @@ -22,21 +22,21 @@ TEST_SUITE(FF_TEST_SUITE) { bidict bd1 = {{1, "one"}, {2, "two"}}; bidict bd2 = {{2, "three"}, {3, "four"}}; - CHECK_THROWS(merge_disjoint_bidicts(bd1, bd2)); + CHECK_THROWS(binary_merge_disjoint_bidicts(bd1, bd2)); } SUBCASE("overlapping key, same associated value") { bidict bd1 = {{1, "one"}, {2, "two"}}; bidict bd2 = {{2, "two"}, {3, "three"}}; - CHECK_THROWS(merge_disjoint_bidicts(bd1, bd2)); + CHECK_THROWS(binary_merge_disjoint_bidicts(bd1, bd2)); } SUBCASE("overlapping values") { bidict bd1 = {{1, "one"}, {2, "two"}}; bidict bd2 = {{3, "two"}, {4, "four"}}; - CHECK_THROWS(merge_disjoint_bidicts(bd1, bd2)); + CHECK_THROWS(binary_merge_disjoint_bidicts(bd1, bd2)); } } }