From 38ea8784c636e69fe89c4d3869b648070b989c0e Mon Sep 17 00:00:00 2001
From: yangzhenzhang
Date: Fri, 15 Jan 2021 11:00:13 +0800
Subject: [PATCH] update infer mirror ops

Replace the per-operator InferMirrorOps overrides in BatchParallelInfo,
SoftmaxCrossEntropyWithLogitsInfo and MatMulBase with a single default
implementation in the OperatorInfo base class. The default builds a
communication group per input from its tensor map, keeps an empty
placeholder for inputs that need no mirroring, and clears the whole list
when no input needs mirror ops at all. The unit tests are updated for the
new per-input layout of mirror_ops_, and the step-parallel tests now create
the device manager once in the fixture's SetUp instead of in every test.
---
 .../parallel/ops_info/batch_parallel_info.cc  | 16 --------
 .../parallel/ops_info/batch_parallel_info.h   |  1 -
 .../frontend/parallel/ops_info/loss_info.h    |  1 -
 .../frontend/parallel/ops_info/matmul_info.cc | 26 ------------
 .../frontend/parallel/ops_info/matmul_info.h  |  1 -
 .../parallel/ops_info/operator_info.cc        | 40 +++++++++++++++++++
 .../parallel/ops_info/operator_info.h         |  2 +-
 .../cpp/parallel/ops_info/matmul_info_test.cc |  2 +-
 tests/ut/cpp/parallel/step_parallel_test.cc   | 13 +++---
 9 files changed, 47 insertions(+), 55 deletions(-)

diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/batch_parallel_info.cc b/mindspore/ccsrc/frontend/parallel/ops_info/batch_parallel_info.cc
index de4617c83c..bf4b350377 100644
--- a/mindspore/ccsrc/frontend/parallel/ops_info/batch_parallel_info.cc
+++ b/mindspore/ccsrc/frontend/parallel/ops_info/batch_parallel_info.cc
@@ -57,22 +57,6 @@ Status BatchParallelInfo::InferDevMatrixShape() {
   return SUCCESS;
 }
 
-Status BatchParallelInfo::InferMirrorOps() {
-  mirror_ops_.clear();
-  if (g_device_manager->DeviceNum() == 1) {
-    MS_LOG(INFO) << name_ << " : The device num is 1, no need to create mirror ops.";
-    return SUCCESS;
-  }
-
-  MS_LOG(INFO) << name_ << " : Batch parallel input number " << strategy_->GetInputNumber();
-  for (size_t i = 0; i < input_value_.size(); i++) {
-    MS_EXCEPTION_IF_NULL(g_device_manager);
-    OperatorVector op_vec = CreateMirrorOps(g_device_manager->world_group(), g_device_manager->DeviceNum());
-    mirror_ops_.push_back(op_vec);
-  }
-  return SUCCESS;
-}
-
 Status BatchParallelInfo::InferForwardCommunication() { return SUCCESS; }
 
 Status BatchParallelInfo::InferTensorMap() {
diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/batch_parallel_info.h b/mindspore/ccsrc/frontend/parallel/ops_info/batch_parallel_info.h
index a96ba18c99..ffe6ecccd9 100644
--- a/mindspore/ccsrc/frontend/parallel/ops_info/batch_parallel_info.h
+++ b/mindspore/ccsrc/frontend/parallel/ops_info/batch_parallel_info.h
@@ -44,7 +44,6 @@ class BatchParallelInfo : public OperatorInfo {
 
  protected:
   Status CheckStrategy(const StrategyPtr &strategy) override;
-  Status InferMirrorOps() override;
   Status InferForwardCommunication() override;
   Status InferTensorInfo() override;
   Status InferDevMatrixShape() override;
diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/loss_info.h b/mindspore/ccsrc/frontend/parallel/ops_info/loss_info.h
index 3f045929ac..f5aa52c4e6 100644
--- a/mindspore/ccsrc/frontend/parallel/ops_info/loss_info.h
+++ b/mindspore/ccsrc/frontend/parallel/ops_info/loss_info.h
@@ -48,7 +48,6 @@ class SoftmaxCrossEntropyWithLogitsInfo : public OperatorInfo {
  protected:
   Status CheckStrategy(const StrategyPtr &strategy) override;
   Status GetAttrs() override;
-  Status InferMirrorOps() override { return SUCCESS; }
   Status InferForwardCommunication() override { return SUCCESS; }
   Status InferTensorMap() override;
   Status InferTensorInfo() override;
diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/matmul_info.cc b/mindspore/ccsrc/frontend/parallel/ops_info/matmul_info.cc
index 6829766057..af77317beb 100644
--- a/mindspore/ccsrc/frontend/parallel/ops_info/matmul_info.cc
+++ b/mindspore/ccsrc/frontend/parallel/ops_info/matmul_info.cc
@@ -209,32 +209,6 @@ Status MatMulBase::InferDevMatrixShape() {
   return SUCCESS;
 }
 
-// all-reduce weight's grad
-Status MatMulBase::InferMirrorOps() {
-  mirror_ops_.clear();
-
-  Shape mat_b_tensor_map = inputs_tensor_map_[1];
-  std::vector<Group> mat_b_group;
-  if (CreateGroupByTensorMap(mat_b_tensor_map, &mat_b_group) != SUCCESS) {
-    return FAILED;
-  }
-
-  OperatorVector op_for_inputs;  // op_for_inputs is empty
-  OperatorVector op_for_weight;
-
-  if (mat_b_group.empty()) {
-    MS_LOG(INFO) << name_ << " : The mirror ops is empty.";
-    return SUCCESS;
-  } else {
-    op_for_weight = CreateMirrorOps(mat_b_group[0].name(), mat_b_group[0].GetDevNum());
-    mirror_ops_.push_back(op_for_inputs);
-    mirror_ops_.push_back(op_for_weight);
-    MS_LOG(INFO) << name_ << " : Create the mirror ops for weight success, group is " << mat_b_group[0].name();
-  }
-
-  return SUCCESS;
-}
-
 Status MatMulBase::InferForwardCommunication() {
   forward_op_.clear();
   size_t dimension = origin_dev_matrix_shape_.size();
diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/matmul_info.h b/mindspore/ccsrc/frontend/parallel/ops_info/matmul_info.h
index 6ea93c7e7c..1006a75c0d 100644
--- a/mindspore/ccsrc/frontend/parallel/ops_info/matmul_info.h
+++ b/mindspore/ccsrc/frontend/parallel/ops_info/matmul_info.h
@@ -49,7 +49,6 @@ class MatMulBase : public OperatorInfo {
   Status SwapLastTwoElements(Shape *shape);
 
  protected:
-  Status InferMirrorOps() override;
   Status InferForwardCommunication() override;
   Status InferTensorInfo() override;
   Status InferDevMatrixShape() override;
diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/operator_info.cc b/mindspore/ccsrc/frontend/parallel/ops_info/operator_info.cc
index 6ad3b32f82..d7161020ab 100644
--- a/mindspore/ccsrc/frontend/parallel/ops_info/operator_info.cc
+++ b/mindspore/ccsrc/frontend/parallel/ops_info/operator_info.cc
@@ -130,6 +130,46 @@ Status OperatorInfo::InferAttrs() {
   return SUCCESS;
 }
 
+Status OperatorInfo::InferMirrorOps() {
+  mirror_ops_.clear();
+  if (inputs_shape_.empty()) {
+    MS_LOG(INFO) << name_ << ": The inputs size is empty";
+    return SUCCESS;
+  }
+
+  if (inputs_tensor_map_.size() != inputs_shape_.size()) {
+    MS_LOG(ERROR) << name_ << ": The size of inputs tensor map is not equal to the size of inputs shape";
+    return FAILED;
+  }
+
+  bool group_is_empty = true;
+  for (size_t i = 0; i < inputs_tensor_map_.size(); ++i) {
+    std::vector<Group> group;
+    if (CreateGroupByTensorMap(inputs_tensor_map_[i], &group) != SUCCESS) {
+      MS_LOG(ERROR) << name_ << ": Create group failed, the input index is " << i;
+      mirror_ops_.clear();
+      return FAILED;
+    }
+
+    OperatorVector mirror_op;
+    if (group.empty()) {
+      MS_LOG(INFO) << name_ << ": The mirror group is empty, the input index is " << i;
+      mirror_ops_.push_back(mirror_op);
+      continue;
+    }
+
+    group_is_empty = false;
+    mirror_op = CreateMirrorOps(group[0].name(), group[0].GetDevNum());
+    mirror_ops_.push_back(mirror_op);
+  }
+
+  if (group_is_empty) {
+    mirror_ops_.clear();
+    MS_LOG(INFO) << name_ << ": No need to insert mirror ops";
+  }
+  return SUCCESS;
+}
+
 Status OperatorInfo::InferRepeatedCalcInfo() {
   int64_t g_dev_list_size = stage_device_size_;
   int64_t dev_matrix_size =
diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/operator_info.h b/mindspore/ccsrc/frontend/parallel/ops_info/operator_info.h
index c382707a88..b49ad00d0f 100644
--- a/mindspore/ccsrc/frontend/parallel/ops_info/operator_info.h
+++ b/mindspore/ccsrc/frontend/parallel/ops_info/operator_info.h
@@ -187,10 +187,10 @@ class OperatorInfo {
   virtual Status CheckStrategy(const StrategyPtr &strategy) = 0;
   virtual Status InferTensorMap() = 0;
   virtual Status InferForwardCommunication() = 0;
-  virtual Status InferMirrorOps() = 0;
   virtual Status GetAttrs() = 0;
   virtual Status InferTensorInfo() = 0;
   virtual Status InferDevMatrixShape() = 0;
+  virtual Status InferMirrorOps();
   Status CheckStrategyValue(const StrategyPtr &strategy, const Shapes &inputs_shape);
   void SetRepeatedCalcDevMatrix();
   void ResetTensorMapIfRepeatedCalc();
diff --git a/tests/ut/cpp/parallel/ops_info/matmul_info_test.cc b/tests/ut/cpp/parallel/ops_info/matmul_info_test.cc
index 5eccc59a48..1c51c6135b 100644
--- a/tests/ut/cpp/parallel/ops_info/matmul_info_test.cc
+++ b/tests/ut/cpp/parallel/ops_info/matmul_info_test.cc
@@ -463,7 +463,7 @@ TEST_F(TestMatmulInfo, GetMirrorOPs4) {
   matmul1->Init(strategy);
   MirrorOps mirror_ops = matmul1->mirror_ops();
 
-  ASSERT_EQ(mirror_ops.size(), 0);  // all reduce only in -3 dim (strategy is 1);
+  ASSERT_EQ(mirror_ops.size(), 2);
 }
 
 TEST_F(TestMatmulInfo, InitTwice) {
diff --git a/tests/ut/cpp/parallel/step_parallel_test.cc b/tests/ut/cpp/parallel/step_parallel_test.cc
index 667da48397..f1b02034e2 100644
--- a/tests/ut/cpp/parallel/step_parallel_test.cc
+++ b/tests/ut/cpp/parallel/step_parallel_test.cc
@@ -32,8 +32,6 @@ class TestStepParallel : public UT::Common {
   void TearDown() {}
 };
 
-void TestStepParallel::SetUp() { UT::InitPythonPath(); }
-
 void Init_Device_Manager() {
   RankList dev_list;
 
@@ -52,6 +50,11 @@ void Init_Device_Manager() {
   g_device_manager->Init(dev_list, local_dev, stage_map, "hccl");
 }
 
+void TestStepParallel::SetUp() {
+  UT::InitPythonPath();
+  Init_Device_Manager();
+}
+
 CNodePtr Make_Node(Shape x, Shape y, Shape out, int64_t condition = 0) {
   FuncGraphPtr func_graph = std::make_shared<FuncGraph>();
   ParameterPtr param1 = func_graph->add_parameter();
@@ -345,7 +348,6 @@ TEST_F(TestStepParallel, CreatOpInstance1) {
 }
 
 TEST_F(TestStepParallel, OperatorInstance) {
-  Init_Device_Manager();
   // creat attrs and prim
   PrimitivePtr prim = NewValueNode(prim::kPrimMatMul)->value()->cast<PrimitivePtr>();
   ValuePtr transpose_a = MakeValue(false);
@@ -369,7 +371,6 @@ TEST_F(TestStepParallel, OperatorInstance) {
 }
 
 TEST_F(TestStepParallel, ExtractInformation) {
-  Init_Device_Manager();
   FuncGraphManagerPtr manager = Make_Manager();
   FuncGraphSet graphs = manager->func_graphs();
   FuncGraphPtr graph = *graphs.begin();
@@ -379,7 +380,6 @@ TEST_F(TestStepParallel, ExtractInformation) {
 }
 
 TEST_F(TestStepParallel, ExtractInformation2) {
-  Init_Device_Manager();
   FuncGraphManagerPtr manager = Make_Manager(2);
   FuncGraphSet graphs = manager->func_graphs();
   FuncGraphPtr graph = *graphs.begin();
@@ -389,7 +389,6 @@ TEST_F(TestStepParallel, ExtractInformation2) {
 }
 
 TEST_F(TestStepParallel, ExtractInformation3) {
-  Init_Device_Manager();
   FuncGraphManagerPtr manager = Make_Manager(3);
   FuncGraphSet graphs = manager->func_graphs();
   FuncGraphPtr graph = *graphs.begin();
@@ -399,7 +398,6 @@ TEST_F(TestStepParallel, ExtractInformation3) {
 }
 
 TEST_F(TestStepParallel, ForwardCommunication1) {
-  Init_Device_Manager();
   ValuePtr attr0_value = MakeValue(REDUCE_OP_SUM);
   ValuePtr attr1_value = MakeValue("0-1-2");
   Attr attr0 = std::make_pair("op", attr0_value);
@@ -499,7 +497,6 @@ TEST_F(TestStepParallel, ForwardCommunication3) {
 }
 
 TEST_F(TestStepParallel, GetTensorInLayout) {
-  Init_Device_Manager();
   // creat attrs and prim
   FuncGraphPtr func_graph = std::make_shared<FuncGraph>();
   Shape inputs_x_dims = {64, 32};
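
Note: the sketch below is not part of the patch. It is a minimal, self-contained
model of the refactor, assuming simplified stand-in types: Status, Group,
TensorMap, OperatorVector, MatMulLikeInfo, and the toy CreateGroupByTensorMap
heuristic are placeholders here, not MindSpore's real declarations. It shows how
a two-input operator now inherits the base-class InferMirrorOps and ends up with
one (possibly empty) mirror-op vector per input, which is why GetMirrorOPs4 now
expects a size of 2.

#include <cstddef>
#include <cstdio>
#include <string>
#include <vector>

// Simplified stand-ins for the real parallel-framework types (assumed).
enum Status { SUCCESS, FAILED };
struct Group { std::string name; };
using TensorMap = std::vector<int>;  // -1 marks an unmapped dimension
using OperatorVector = std::vector<std::string>;

class OperatorInfo {
 public:
  virtual ~OperatorInfo() = default;

  // Default implementation modeled on the one this patch adds to the base
  // class: one entry per input, an empty placeholder when that input needs
  // no mirror op, and an empty overall list when no input needs mirroring.
  virtual Status InferMirrorOps() {
    mirror_ops_.clear();
    bool group_is_empty = true;
    for (std::size_t i = 0; i < inputs_tensor_map_.size(); ++i) {
      std::vector<Group> group;
      if (CreateGroupByTensorMap(inputs_tensor_map_[i], &group) != SUCCESS) {
        mirror_ops_.clear();
        return FAILED;
      }
      if (group.empty()) {  // no device group for input i: keep a placeholder
        mirror_ops_.push_back(OperatorVector());
        continue;
      }
      group_is_empty = false;
      mirror_ops_.push_back(OperatorVector{"Mirror(" + group[0].name + ")"});
    }
    if (group_is_empty) mirror_ops_.clear();  // nothing to mirror at all
    return SUCCESS;
  }

  std::size_t mirror_ops_size() const { return mirror_ops_.size(); }

 protected:
  // Toy heuristic standing in for the real group inference: pretend an input
  // joins a communication group whenever its tensor map has a mapped (>= 0)
  // dimension.
  Status CreateGroupByTensorMap(const TensorMap &map, std::vector<Group> *group) {
    for (int dim : map) {
      if (dim >= 0) {
        group->push_back(Group{"mirror_group_0"});
        break;
      }
    }
    return SUCCESS;
  }

  std::vector<TensorMap> inputs_tensor_map_;
  std::vector<OperatorVector> mirror_ops_;
};

// After the patch, a MatMul-like operator no longer overrides InferMirrorOps.
class MatMulLikeInfo : public OperatorInfo {
 public:
  // Tensor maps chosen so only input 1 gets a device group in this toy model.
  MatMulLikeInfo() { inputs_tensor_map_ = {{-1, -1}, {1, 0}}; }
};

int main() {
  MatMulLikeInfo matmul;
  if (matmul.InferMirrorOps() == SUCCESS) {
    // One (possibly empty) vector per input, hence 2 for a two-input MatMul.
    std::printf("mirror_ops size = %zu\n", matmul.mirror_ops_size());
  }
  return 0;
}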