| @@ -57,22 +57,6 @@ Status BatchParallelInfo::InferDevMatrixShape() { | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| Status BatchParallelInfo::InferMirrorOps() { | |||||
| mirror_ops_.clear(); | |||||
| if (g_device_manager->DeviceNum() == 1) { | |||||
| MS_LOG(INFO) << name_ << " : The device num is 1, no need to create mirror ops."; | |||||
| return SUCCESS; | |||||
| } | |||||
| MS_LOG(INFO) << name_ << " : Batch parallel input number " << strategy_->GetInputNumber(); | |||||
| for (size_t i = 0; i < input_value_.size(); i++) { | |||||
| MS_EXCEPTION_IF_NULL(g_device_manager); | |||||
| OperatorVector op_vec = CreateMirrorOps(g_device_manager->world_group(), g_device_manager->DeviceNum()); | |||||
| mirror_ops_.push_back(op_vec); | |||||
| } | |||||
| return SUCCESS; | |||||
| } | |||||
| Status BatchParallelInfo::InferForwardCommunication() { return SUCCESS; } | Status BatchParallelInfo::InferForwardCommunication() { return SUCCESS; } | ||||
| Status BatchParallelInfo::InferTensorMap() { | Status BatchParallelInfo::InferTensorMap() { | ||||
| @@ -44,7 +44,6 @@ class BatchParallelInfo : public OperatorInfo { | |||||
| protected: | protected: | ||||
| Status CheckStrategy(const StrategyPtr &strategy) override; | Status CheckStrategy(const StrategyPtr &strategy) override; | ||||
| Status InferMirrorOps() override; | |||||
| Status InferForwardCommunication() override; | Status InferForwardCommunication() override; | ||||
| Status InferTensorInfo() override; | Status InferTensorInfo() override; | ||||
| Status InferDevMatrixShape() override; | Status InferDevMatrixShape() override; | ||||
| @@ -48,7 +48,6 @@ class SoftmaxCrossEntropyWithLogitsInfo : public OperatorInfo { | |||||
| protected: | protected: | ||||
| Status CheckStrategy(const StrategyPtr &strategy) override; | Status CheckStrategy(const StrategyPtr &strategy) override; | ||||
| Status GetAttrs() override; | Status GetAttrs() override; | ||||
| Status InferMirrorOps() override { return SUCCESS; } | |||||
| Status InferForwardCommunication() override { return SUCCESS; } | Status InferForwardCommunication() override { return SUCCESS; } | ||||
| Status InferTensorMap() override; | Status InferTensorMap() override; | ||||
| Status InferTensorInfo() override; | Status InferTensorInfo() override; | ||||
| @@ -209,32 +209,6 @@ Status MatMulBase::InferDevMatrixShape() { | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| // all-reduce weight's grad | |||||
| Status MatMulBase::InferMirrorOps() { | |||||
| mirror_ops_.clear(); | |||||
| Shape mat_b_tensor_map = inputs_tensor_map_[1]; | |||||
| std::vector<Group> mat_b_group; | |||||
| if (CreateGroupByTensorMap(mat_b_tensor_map, &mat_b_group) != SUCCESS) { | |||||
| return FAILED; | |||||
| } | |||||
| OperatorVector op_for_inputs; // op_for_inputs is empty | |||||
| OperatorVector op_for_weight; | |||||
| if (mat_b_group.empty()) { | |||||
| MS_LOG(INFO) << name_ << " : The mirror ops is empty."; | |||||
| return SUCCESS; | |||||
| } else { | |||||
| op_for_weight = CreateMirrorOps(mat_b_group[0].name(), mat_b_group[0].GetDevNum()); | |||||
| mirror_ops_.push_back(op_for_inputs); | |||||
| mirror_ops_.push_back(op_for_weight); | |||||
| MS_LOG(INFO) << name_ << " : Create the mirror ops for weight success, group is " << mat_b_group[0].name(); | |||||
| } | |||||
| return SUCCESS; | |||||
| } | |||||
| Status MatMulBase::InferForwardCommunication() { | Status MatMulBase::InferForwardCommunication() { | ||||
| forward_op_.clear(); | forward_op_.clear(); | ||||
| size_t dimension = origin_dev_matrix_shape_.size(); | size_t dimension = origin_dev_matrix_shape_.size(); | ||||
| @@ -49,7 +49,6 @@ class MatMulBase : public OperatorInfo { | |||||
| Status SwapLastTwoElements(Shape *shape); | Status SwapLastTwoElements(Shape *shape); | ||||
| protected: | protected: | ||||
| Status InferMirrorOps() override; | |||||
| Status InferForwardCommunication() override; | Status InferForwardCommunication() override; | ||||
| Status InferTensorInfo() override; | Status InferTensorInfo() override; | ||||
| Status InferDevMatrixShape() override; | Status InferDevMatrixShape() override; | ||||
| @@ -130,6 +130,46 @@ Status OperatorInfo::InferAttrs() { | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| Status OperatorInfo::InferMirrorOps() { | |||||
| mirror_ops_.clear(); | |||||
| if (inputs_shape_.empty()) { | |||||
| MS_LOG(INFO) << name_ << ": The inputs size is empty"; | |||||
| return SUCCESS; | |||||
| } | |||||
| if (inputs_tensor_map_.size() != inputs_shape_.size()) { | |||||
| MS_LOG(ERROR) << name_ << ": The size of inputs tensor map is not equal to the size of inputs shape"; | |||||
| return FAILED; | |||||
| } | |||||
| bool group_is_empty = true; | |||||
| for (size_t i = 0; i < inputs_tensor_map_.size(); ++i) { | |||||
| std::vector<Group> group; | |||||
| if (CreateGroupByTensorMap(inputs_tensor_map_[i], &group) != SUCCESS) { | |||||
| MS_LOG(ERROR) << name_ << ": Create group failed, the input index is " << i; | |||||
| mirror_ops_.clear(); | |||||
| return FAILED; | |||||
| } | |||||
| OperatorVector mirror_op; | |||||
| if (group.empty()) { | |||||
| MS_LOG(INFO) << name_ << ": The mirror group is empty, the input index is " << i; | |||||
| mirror_ops_.push_back(mirror_op); | |||||
| continue; | |||||
| } | |||||
| group_is_empty = false; | |||||
| mirror_op = CreateMirrorOps(group[0].name(), group[0].GetDevNum()); | |||||
| mirror_ops_.push_back(mirror_op); | |||||
| } | |||||
| if (group_is_empty) { | |||||
| mirror_ops_.clear(); | |||||
| MS_LOG(INFO) << name_ << ": No need to insert mirror ops"; | |||||
| } | |||||
| return SUCCESS; | |||||
| } | |||||
| Status OperatorInfo::InferRepeatedCalcInfo() { | Status OperatorInfo::InferRepeatedCalcInfo() { | ||||
| int64_t g_dev_list_size = stage_device_size_; | int64_t g_dev_list_size = stage_device_size_; | ||||
| int64_t dev_matrix_size = | int64_t dev_matrix_size = | ||||
| @@ -187,10 +187,10 @@ class OperatorInfo { | |||||
| virtual Status CheckStrategy(const StrategyPtr &strategy) = 0; | virtual Status CheckStrategy(const StrategyPtr &strategy) = 0; | ||||
| virtual Status InferTensorMap() = 0; | virtual Status InferTensorMap() = 0; | ||||
| virtual Status InferForwardCommunication() = 0; | virtual Status InferForwardCommunication() = 0; | ||||
| virtual Status InferMirrorOps() = 0; | |||||
| virtual Status GetAttrs() = 0; | virtual Status GetAttrs() = 0; | ||||
| virtual Status InferTensorInfo() = 0; | virtual Status InferTensorInfo() = 0; | ||||
| virtual Status InferDevMatrixShape() = 0; | virtual Status InferDevMatrixShape() = 0; | ||||
| virtual Status InferMirrorOps(); | |||||
| Status CheckStrategyValue(const StrategyPtr &strategy, const Shapes &inputs_shape); | Status CheckStrategyValue(const StrategyPtr &strategy, const Shapes &inputs_shape); | ||||
| void SetRepeatedCalcDevMatrix(); | void SetRepeatedCalcDevMatrix(); | ||||
| void ResetTensorMapIfRepeatedCalc(); | void ResetTensorMapIfRepeatedCalc(); | ||||
| @@ -463,7 +463,7 @@ TEST_F(TestMatmulInfo, GetMirrorOPs4) { | |||||
| matmul1->Init(strategy); | matmul1->Init(strategy); | ||||
| MirrorOps mirror_ops = matmul1->mirror_ops(); | MirrorOps mirror_ops = matmul1->mirror_ops(); | ||||
| ASSERT_EQ(mirror_ops.size(), 0); // all reduce only in -3 dim (strategy is 1); | |||||
| ASSERT_EQ(mirror_ops.size(), 2); | |||||
| } | } | ||||
| TEST_F(TestMatmulInfo, InitTwice) { | TEST_F(TestMatmulInfo, InitTwice) { | ||||
| @@ -32,8 +32,6 @@ class TestStepParallel : public UT::Common { | |||||
| void TearDown() {} | void TearDown() {} | ||||
| }; | }; | ||||
| void TestStepParallel::SetUp() { UT::InitPythonPath(); } | |||||
| void Init_Device_Manager() { | void Init_Device_Manager() { | ||||
| RankList dev_list; | RankList dev_list; | ||||
| @@ -52,6 +50,11 @@ void Init_Device_Manager() { | |||||
| g_device_manager->Init(dev_list, local_dev, stage_map, "hccl"); | g_device_manager->Init(dev_list, local_dev, stage_map, "hccl"); | ||||
| } | } | ||||
| void TestStepParallel::SetUp() { | |||||
| UT::InitPythonPath(); | |||||
| Init_Device_Manager(); | |||||
| } | |||||
| CNodePtr Make_Node(Shape x, Shape y, Shape out, int64_t condition = 0) { | CNodePtr Make_Node(Shape x, Shape y, Shape out, int64_t condition = 0) { | ||||
| FuncGraphPtr func_graph = std::make_shared<FuncGraph>(); | FuncGraphPtr func_graph = std::make_shared<FuncGraph>(); | ||||
| ParameterPtr param1 = func_graph->add_parameter(); | ParameterPtr param1 = func_graph->add_parameter(); | ||||
| @@ -345,7 +348,6 @@ TEST_F(TestStepParallel, CreatOpInstance1) { | |||||
| } | } | ||||
| TEST_F(TestStepParallel, OperatorInstance) { | TEST_F(TestStepParallel, OperatorInstance) { | ||||
| Init_Device_Manager(); | |||||
| // creat attrs and prim | // creat attrs and prim | ||||
| PrimitivePtr prim = NewValueNode(prim::kPrimMatMul)->value()->cast<PrimitivePtr>(); | PrimitivePtr prim = NewValueNode(prim::kPrimMatMul)->value()->cast<PrimitivePtr>(); | ||||
| ValuePtr transpose_a = MakeValue(false); | ValuePtr transpose_a = MakeValue(false); | ||||
| @@ -369,7 +371,6 @@ TEST_F(TestStepParallel, OperatorInstance) { | |||||
| } | } | ||||
| TEST_F(TestStepParallel, ExtractInformation) { | TEST_F(TestStepParallel, ExtractInformation) { | ||||
| Init_Device_Manager(); | |||||
| FuncGraphManagerPtr manager = Make_Manager(); | FuncGraphManagerPtr manager = Make_Manager(); | ||||
| FuncGraphSet graphs = manager->func_graphs(); | FuncGraphSet graphs = manager->func_graphs(); | ||||
| FuncGraphPtr graph = *graphs.begin(); | FuncGraphPtr graph = *graphs.begin(); | ||||
| @@ -379,7 +380,6 @@ TEST_F(TestStepParallel, ExtractInformation) { | |||||
| } | } | ||||
| TEST_F(TestStepParallel, ExtractInformation2) { | TEST_F(TestStepParallel, ExtractInformation2) { | ||||
| Init_Device_Manager(); | |||||
| FuncGraphManagerPtr manager = Make_Manager(2); | FuncGraphManagerPtr manager = Make_Manager(2); | ||||
| FuncGraphSet graphs = manager->func_graphs(); | FuncGraphSet graphs = manager->func_graphs(); | ||||
| FuncGraphPtr graph = *graphs.begin(); | FuncGraphPtr graph = *graphs.begin(); | ||||
| @@ -389,7 +389,6 @@ TEST_F(TestStepParallel, ExtractInformation2) { | |||||
| } | } | ||||
| TEST_F(TestStepParallel, ExtractInformation3) { | TEST_F(TestStepParallel, ExtractInformation3) { | ||||
| Init_Device_Manager(); | |||||
| FuncGraphManagerPtr manager = Make_Manager(3); | FuncGraphManagerPtr manager = Make_Manager(3); | ||||
| FuncGraphSet graphs = manager->func_graphs(); | FuncGraphSet graphs = manager->func_graphs(); | ||||
| FuncGraphPtr graph = *graphs.begin(); | FuncGraphPtr graph = *graphs.begin(); | ||||
| @@ -399,7 +398,6 @@ TEST_F(TestStepParallel, ExtractInformation3) { | |||||
| } | } | ||||
| TEST_F(TestStepParallel, ForwardCommunication1) { | TEST_F(TestStepParallel, ForwardCommunication1) { | ||||
| Init_Device_Manager(); | |||||
| ValuePtr attr0_value = MakeValue(REDUCE_OP_SUM); | ValuePtr attr0_value = MakeValue(REDUCE_OP_SUM); | ||||
| ValuePtr attr1_value = MakeValue("0-1-2"); | ValuePtr attr1_value = MakeValue("0-1-2"); | ||||
| Attr attr0 = std::make_pair("op", attr0_value); | Attr attr0 = std::make_pair("op", attr0_value); | ||||
| @@ -499,7 +497,6 @@ TEST_F(TestStepParallel, ForwardCommunication3) { | |||||
| } | } | ||||
| TEST_F(TestStepParallel, GetTensorInLayout) { | TEST_F(TestStepParallel, GetTensorInLayout) { | ||||
| Init_Device_Manager(); | |||||
| // creat attrs and prim | // creat attrs and prim | ||||
| FuncGraphPtr func_graph = std::make_shared<FuncGraph>(); | FuncGraphPtr func_graph = std::make_shared<FuncGraph>(); | ||||
| Shape inputs_x_dims = {64, 32}; | Shape inputs_x_dims = {64, 32}; | ||||