Merge pull request !1215 from Xiaoda/change-succive-edges-order-and-add-checkingtags/v0.3.0-alpha
| @@ -211,13 +211,14 @@ struct ContractEliminationDecision : public Decision { | |||
| */ | |||
| struct TriangleEliminationDecision : public Decision { | |||
| TriangleEliminationDecision(StrategyPtr elimi_stra, CostPtr elimi_op_cost, CostPtr l_edge_cost, CostPtr r_edge_cost, | |||
| StrategyPtr left_stra, CostPtr l_node_cost) | |||
| StrategyPtr left_stra, CostPtr l_node_cost, StrategyPtr right_stra) | |||
| : eliminated_op_strategy_(std::move(elimi_stra)), | |||
| eliminated_op_cost_(std::move(elimi_op_cost)), | |||
| left_edge_cost_(std::move(l_edge_cost)), | |||
| right_edge_cost_(std::move(r_edge_cost)), | |||
| left_node_strategy_(std::move(left_stra)), | |||
| left_node_cost_(std::move(l_node_cost)) { | |||
| left_node_cost_(std::move(l_node_cost)), | |||
| right_node_strategy_(std::move(right_stra)) { | |||
| type_ = DecisionType::TRIANGLE_ELIMINATION; | |||
| } | |||
| @@ -227,6 +228,7 @@ struct TriangleEliminationDecision : public Decision { | |||
| CostPtr right_edge_cost_; | |||
| StrategyPtr left_node_strategy_; | |||
| CostPtr left_node_cost_; | |||
| StrategyPtr right_node_strategy_; | |||
| MS_DECLARE_PARENT(TriangleEliminationDecision, Decision); | |||
| }; | |||
| @@ -85,7 +85,9 @@ Status GetStrategy(const CostGraphPtr &graph) { | |||
| right_edge = tmp; | |||
| } | |||
| auto left_node_cpy = graph->EliminationTriangle(eliminated_node, l_r_edge); | |||
| auto elimi = std::make_shared<TriangleElimination>(eliminated_node, left_edge, left_node_cpy, right_edge); | |||
| auto right_node = l_r_edge->next_operator(); | |||
| auto elimi = | |||
| std::make_shared<TriangleElimination>(eliminated_node, left_edge, left_node_cpy, right_edge, right_node); | |||
| eliminations.emplace_back(std::move(elimi)); | |||
| } | |||
| auto star_center = graph->CheckStarElimination(); | |||
| @@ -181,6 +183,7 @@ Status RecoverStrategy(std::vector<EliminationPtr> eliminations) { | |||
| auto left_edge = elimination->left_edge_; | |||
| auto eliminated_node = elimination->eliminated_node_; | |||
| auto right_edge = elimination->right_edge_; | |||
| auto right_node = elimination->right_node_; | |||
| auto decision = left_node->selected_cost()->decision_ptr_->cast<TriangleEliminationDecisionPtr>(); | |||
| eliminated_node->SetSelectedStrategyAndCost(decision->eliminated_op_strategy_, decision->eliminated_op_cost_); | |||
| @@ -188,6 +191,7 @@ Status RecoverStrategy(std::vector<EliminationPtr> eliminations) { | |||
| right_edge->set_selected_cost(decision->right_edge_cost_); | |||
| // Since Triangle is eliminated into 'left_node', only 'left_node' is needed to recover the strategy. | |||
| left_node->SetSelectedStrategyAndCost(decision->left_node_strategy_, decision->left_node_cost_); | |||
| right_node->CheckSelectedStrategy(decision->right_node_strategy_); | |||
| MS_LOG(INFO) << "Recover triangleElimination succeeded."; | |||
| } else if ((*rit)->isa<StarElimination>()) { | |||
| auto elimination = (*rit)->cast<StarEliminationPtr>(); | |||
| @@ -206,6 +210,9 @@ Status RecoverStrategy(std::vector<EliminationPtr> eliminations) { | |||
| MS_EXCEPTION_IF_NULL(decision->succ_ops_cost_list_[0]); | |||
| // Since Star is eliminated into 'succ_nodes[0]', only 'succ_nodes[0]' is needed to recover the strategy. | |||
| succ_nodes[0]->SetSelectedStrategyAndCost(decision->succ_ops_stra_list_[0], decision->succ_ops_cost_list_[0]); | |||
| for (size_t k = 1; k < succ_nodes.size(); ++k) { | |||
| succ_nodes[k]->CheckSelectedStrategy(decision->succ_ops_stra_list_[k]); | |||
| } | |||
| MS_LOG(INFO) << "Recover starElimination succeeded."; | |||
| } else { | |||
| MS_LOG(ERROR) << "Unknown Elimination type."; | |||
| @@ -102,17 +102,20 @@ struct ContractElimination : public Elimination { | |||
| // Triangle Elimination | |||
| struct TriangleElimination : public Elimination { | |||
| TriangleElimination(OperatorInfoPtr elim_node, EdgePtr l_edge, OperatorInfoPtr l_node, EdgePtr r_edge) | |||
| TriangleElimination(OperatorInfoPtr elim_node, EdgePtr l_edge, OperatorInfoPtr l_node, EdgePtr r_edge, | |||
| OperatorInfoPtr r_node) | |||
| : Elimination(nullptr, Elimination::EliminationType::TRIANGLE), | |||
| eliminated_node_(std::move(elim_node)), | |||
| left_edge_(std::move(l_edge)), | |||
| left_node_(std::move(l_node)), | |||
| right_edge_(std::move(r_edge)) {} | |||
| right_edge_(std::move(r_edge)), | |||
| right_node_(std::move(r_node)) {} | |||
| OperatorInfoPtr eliminated_node_; | |||
| EdgePtr left_edge_; | |||
| OperatorInfoPtr left_node_; | |||
| EdgePtr right_edge_; | |||
| OperatorInfoPtr right_node_; | |||
| MS_DECLARE_PARENT(TriangleElimination, Elimination); | |||
| }; | |||
| @@ -1111,8 +1111,8 @@ void CostGraph::CreateTriangleEliminationSubCostList(StrategyPtr elimi_op_stra, | |||
| elimi_op_cost->communication_without_parameter_ + left_edge_cost->communication_without_parameter_ + | |||
| left_node_cost->communication_without_parameter_ + right_edge_cost->communication_without_parameter_; | |||
| auto decision = std::make_shared<TriangleEliminationDecision>(elimi_op_stra, elimi_op_cost, left_edge_cost, | |||
| right_edge_cost, left_op_stra, left_node_cost); | |||
| auto decision = std::make_shared<TriangleEliminationDecision>( | |||
| elimi_op_stra, elimi_op_cost, left_edge_cost, right_edge_cost, left_op_stra, left_node_cost, right_op_stra); | |||
| auto new_cost = std::make_shared<Cost>(new_computation, new_commu_cost, decision); | |||
| new_cost->communication_without_parameter_ = new_commu_without; | |||
| new_cost->communication_with_partial_para_ = | |||
| @@ -546,10 +546,14 @@ std::vector<std::shared_ptr<Edge>> OperatorInfo::GetAliveSuccEdges() { | |||
| for (auto &edge : succ_edges_) { | |||
| if ((edge->next_operator()->is_alive()) && (edge->next_operator()->name().find(RELU) != std::string::npos)) { | |||
| ret.push_back(edge); | |||
| } else if ((edge->next_operator()->is_alive()) && (edge->next_operator()->name().find(CAST) != std::string::npos)) { | |||
| // CAST is ordered in front of L2NORMALIZE | |||
| ret.push_back(edge); | |||
| } | |||
| } | |||
| for (auto &edge : succ_edges_) { | |||
| if ((edge->next_operator()->is_alive()) && (edge->next_operator()->name().find(RELU) == std::string::npos)) { | |||
| if ((edge->next_operator()->is_alive()) && (edge->next_operator()->name().find(RELU) == std::string::npos) && | |||
| (edge->next_operator()->name().find(CAST) == std::string::npos)) { | |||
| ret.push_back(edge); | |||
| } | |||
| } | |||
| @@ -1279,10 +1283,18 @@ void OperatorInfo::BreakingTiesForPerferringDataParallel(const StrategyPtr &stra | |||
| CheckGlobalDeviceManager(); | |||
| auto total_device_num = g_device_manager->GetDeviceListByStageId(stra->GetInputStage()).size(); | |||
| if (IntToSize(stra->GetInputDim()[0][0]) == total_device_num) { | |||
| cost->computation_cost_ -= 1.0; | |||
| cost->communication_cost_ -= 1.0; | |||
| cost->communication_with_partial_para_ -= 1.0; | |||
| cost->communication_without_parameter_ -= 1.0; | |||
| if (cost->computation_cost_ > 1.0) { | |||
| cost->computation_cost_ -= 1.0; | |||
| } | |||
| if (cost->communication_cost_ > 1.0) { | |||
| cost->communication_cost_ -= 1.0; | |||
| } | |||
| if (cost->communication_with_partial_para_ > 1.0) { | |||
| cost->communication_with_partial_para_ -= 1.0; | |||
| } | |||
| if (cost->communication_without_parameter_ > 1.0) { | |||
| cost->communication_without_parameter_ -= 1.0; | |||
| } | |||
| } | |||
| } | |||
| } | |||
| @@ -1290,5 +1302,15 @@ void OperatorInfo::BreakingTiesForPerferringDataParallel(const StrategyPtr &stra | |||
| double OperatorInfo::GetForwardMemoryCostFromCNode() { | |||
| return operator_cost()->GetForwardComputationCost(inputs_tensor_info_, outputs_tensor_info_, 0); | |||
| } | |||
| void OperatorInfo::CheckSelectedStrategy(const StrategyPtr &s_strategy) { | |||
| MS_EXCEPTION_IF_NULL(s_strategy); | |||
| if (!s_strategy->IsEqual(selected_strategy_)) { | |||
| MS_LOG(INFO) << name() << "'s strategy may cause suboptimal, the determined strategy:"; | |||
| PrintStrategy(selected_strategy_); | |||
| MS_LOG(INFO) << "The minimal strategy:"; | |||
| PrintStrategy(s_strategy); | |||
| } | |||
| } | |||
| } // namespace parallel | |||
| } // namespace mindspore | |||
| @@ -138,6 +138,7 @@ class OperatorInfo { | |||
| } | |||
| StrategyPtr selected_strategy() const { return selected_strategy_; } | |||
| CostPtr selected_cost() const { return selected_cost_; } | |||
| void CheckSelectedStrategy(const StrategyPtr &); | |||
| Status InitSelectedStrategy(const StrategyPtr &s_strategy) { return Init(s_strategy); } | |||
| void set_input_value(const std::vector<ValuePtr> &input_value) { input_value_ = input_value; } | |||
| const std::vector<ValuePtr> &input_value() const { return input_value_; } | |||
| @@ -48,6 +48,16 @@ class Strategy { | |||
| } | |||
| void ResetInputs(const std::vector<Dimensions> &input) { inputs_ = input; } | |||
| bool IsEqual(const StrategyPtr &another_stra) { | |||
| if (another_stra == nullptr) { | |||
| return false; | |||
| } | |||
| if ((stage_ != another_stra->GetInputStage()) || (inputs_ != another_stra->GetInputDim())) { | |||
| return false; | |||
| } | |||
| return true; | |||
| } | |||
| private: | |||
| const int32_t stage_; | |||
| @@ -64,5 +64,23 @@ TEST_F(TestStrategy, GetInputDim) { | |||
| ASSERT_EQ(inputs, inputs_test); | |||
| } | |||
| TEST_F(TestStrategy, IsEqual) { | |||
| int32_t stage1 = 0, stage2 = 0, stage3 = 1, stage4 = 0; | |||
| std::vector<int32_t> dimension1 = {8, 1}; | |||
| std::vector<int32_t> dimension2 = {1, 8}; | |||
| std::vector<std::vector<int32_t>> inputs1 = {dimension1}; | |||
| std::vector<std::vector<int32_t>> inputs2 = {dimension1}; | |||
| std::vector<std::vector<int32_t>> inputs3 = {dimension2}; | |||
| std::vector<std::vector<int32_t>> inputs4 = {dimension1, dimension2}; | |||
| StrategyPtr stra1 = std::make_shared<Strategy>(stage1, inputs1); | |||
| StrategyPtr stra2 = std::make_shared<Strategy>(stage2, inputs2); | |||
| StrategyPtr stra3 = std::make_shared<Strategy>(stage3, inputs3); | |||
| StrategyPtr stra4 = std::make_shared<Strategy>(stage4, inputs4); | |||
| ASSERT_EQ(stra1->IsEqual(stra2), true); | |||
| ASSERT_EQ(stra1->IsEqual(stra3), false); | |||
| ASSERT_EQ(stra1->IsEqual(stra4), false); | |||
| } | |||
| } // namespace parallel | |||
| } // namespace mindspore | |||