Merge pull request !859 from Xiaoda/support-inferring-phase-parallel-strategy-searching
@@ -23,8 +23,17 @@
 namespace mindspore {
 namespace parallel {
 void Simplify(CostPtrList *clist_ptrs) {
-  // Sort the cost_list with the computation_cost_ increasing, and communication_cost decreasing order. This method
-  // excludes the cost with greater computation_cost_ and greater communication_cost.
+  if (RUN_PHASE == TRAINING_PHASE) {
+    // training phase
+    SimplifyForDecreasingCommunicationWithPartialPara(clist_ptrs);
+  } else {
+    // inference phase
+    SimplifyForDecreasingCommunicationForward(clist_ptrs);
+  }
+}
+void SimplifyForDecreasingCommunicationForward(CostPtrList *clist_ptrs) {
+  // Sort the cost_list with the computation_cost_ increasing, and communication_forward decreasing order. This method
+  // excludes the cost with greater computation_cost_ and greater communication_forward.
   // E.g. clist_ptrs = {<100, 20>, <200, 10>, <300, 50>}. After this method, clist_ptrs = {<200, 10>, <100, 20>}
   if (!COST_MODEL_SIMPLIFY_CALCULATION) {
     return;
@@ -37,14 +46,15 @@ void Simplify(CostPtrList *clist_ptrs) {
   });
   CostPtrList ret;
   for (size_t i = 0; i < clist_ptrs->size(); ++i) {
-    if ((ret.size() == size_t(0)) || (clist_ptrs->at(id[i])->communication_cost_ < ret.back()->communication_cost_)) {
+    if ((ret.size() == size_t(0)) ||
+        (clist_ptrs->at(id[i])->communication_forward_ < ret.back()->communication_forward_)) {
       ret.emplace_back(std::move(clist_ptrs->at(id[i])));
     }
   }
   *clist_ptrs = std::move(ret);
 }
 
-void SimplifyForDreasingCommunicationWithPartialPara(CostPtrList *clist_ptrs) {
+void SimplifyForDecreasingCommunicationWithPartialPara(CostPtrList *clist_ptrs) {
   // Sort the cost_list with the computation_cost_ increasing, and communication_with_partial_para_cost decreasing
   // order. This method excludes the cost with greater computation_cost_ and greater communication_without_para_cost.
   if (!COST_MODEL_SIMPLIFY_CALCULATION) {
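The two simplification routines above apply the same Pareto-style pruning and differ only in which communication field they compare (communication_forward_ for inference, communication_with_partial_para_ for training). Below is a minimal standalone sketch of that pruning, not part of this patch: it uses a simplified stand-in for the Cost type and assumes the list is sorted by computation cost in increasing order, keeping only candidates that strictly reduce the communication metric.

```cpp
#include <algorithm>
#include <iostream>
#include <vector>

// Simplified stand-in for parallel::Cost: only the two fields the pruning compares.
struct MiniCost {
  double computation_cost;
  double communication_forward;
};

// Keep only Pareto-optimal costs: after sorting by computation cost (ascending),
// a candidate survives only if it strictly decreases the communication metric
// seen so far, so dominated candidates are dropped.
void SimplifyForward(std::vector<MiniCost> *costs) {
  std::sort(costs->begin(), costs->end(),
            [](const MiniCost &a, const MiniCost &b) { return a.computation_cost < b.computation_cost; });
  std::vector<MiniCost> kept;
  for (const auto &c : *costs) {
    if (kept.empty() || c.communication_forward < kept.back().communication_forward) {
      kept.push_back(c);
    }
  }
  *costs = std::move(kept);
}

int main() {
  // The example from the comment above: {<100, 20>, <200, 10>, <300, 50>}.
  std::vector<MiniCost> costs = {{100, 20}, {200, 10}, {300, 50}};
  SimplifyForward(&costs);
  // <300, 50> is dropped: it is worse than <200, 10> in both dimensions.
  for (const auto &c : costs) {
    std::cout << "<" << c.computation_cost << ", " << c.communication_forward << ">\n";
  }
  return 0;
}
```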
@@ -51,18 +51,22 @@ struct Cost {
     communication_with_partial_para_ = 0.0;
     communication_redis_forward_ = 0.0;
     communication_redis_backward_ = 0.0;
+    communication_forward_ = 0.0;
   }
   // 'memory_with_reuse_' calculates the peak memory usage in a training phase
   double memory_with_reuse_;
-  // 'computation_cost_' models the training time of an iteration in a training phase
+  // 'computation_cost_' models the training time of an iteration in a training phase. Currently, this is calculated
+  // by ONLY forward phase
   double computation_cost_;
-  // 'communication_cost_' includes communications from operators (forward and backward) and edges
+  // 'communication_cost_' includes communications from operators (forward and backward) and edges (redistribution)
   double communication_cost_;
   // communication_without_parameter_ = communication_cost_ - (backward communication from operators)
   double communication_without_parameter_;
   // communication_with_partial_para_ =
   // communication_without_parameter_ + COST_MODEL_GAMMA * (communication_cost_ - communication_without_parameter_ )
   double communication_with_partial_para_;
+  // communication_forward_ = communication cost from operators (only forward phase) and forward redistribution.
+  double communication_forward_;
   double communication_redis_forward_;
   double communication_redis_backward_;
   std::shared_ptr<Decision> decision_ptr_;
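For readers following the field definitions in this struct, here is a tiny sketch of how the derived communication numbers relate for one candidate strategy. It is illustrative only: the values are made up, and the local variables simply mirror the member names documented above.

```cpp
#include <iostream>

int main() {
  const double gamma = 0.1;                        // stand-in for COST_MODEL_GAMMA
  double communication_cost = 50.0;                // forward + backward, operators + redistribution
  double communication_without_parameter = 30.0;   // communication_cost_ minus backward operator communication
  double communication_redis_forward = 12.0;       // forward part of the redistribution on an edge

  // communication_with_partial_para_ =
  //   communication_without_parameter_ + COST_MODEL_GAMMA * (communication_cost_ - communication_without_parameter_)
  double communication_with_partial_para =
      communication_without_parameter + gamma * (communication_cost - communication_without_parameter);

  // communication_forward_ only counts forward-phase traffic; for an edge it is seeded
  // from the forward redistribution cost (see the InitEdgeCost hunk below).
  double communication_forward_edge = communication_redis_forward;

  std::cout << communication_with_partial_para << " " << communication_forward_edge << "\n";
  return 0;
}
```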
@@ -296,7 +300,8 @@ using FinalDecisionPtr = std::shared_ptr<FinalDecision>;
 using FinalSingleDecisionPtr = std::shared_ptr<FinalSingleDecision>;
 
 void Simplify(CostPtrList *clist);
-void SimplifyForDreasingCommunicationWithPartialPara(CostPtrList *clist);
+void SimplifyForDecreasingCommunicationForward(CostPtrList *clist);
+void SimplifyForDecreasingCommunicationWithPartialPara(CostPtrList *clist);
 void RefineForPracticalCost(const CostPtr &, bool is_redistribution);
 }  // namespace parallel
 }  // namespace mindspore
@@ -76,6 +76,7 @@ Status Edge::InitEdgeCost() {
                     << ", communication_with_partial_para_: " << cost->communication_with_partial_para_ << ".";
       // refine communication cost calculation for practice
       RefineForPracticalCost(cost, true);
+      cost->communication_forward_ = cost->communication_redis_forward_;
       CostPtrKey ck = {target_output_str, target_input_str};
       CostPtrList cl;
       cl.push_back(cost);
@@ -160,8 +161,9 @@ CostPtrList Edge::CreateEdgeEliminationCostList(const StrategyPtr &output_st_ptr
   (void)std::transform(edges.begin(), edges.end(), all_cost_list.begin(), LocalGetCostList);
 
   CostPtrList selected_cost_list(all_cost_list.size(), nullptr);
-  std::function<void(size_t, double, double, double, double)> recursive =
-    [&](size_t k, double computation, double memory, double communication, double communication_without_para) {
+  std::function<void(size_t, double, double, double, double, double)> recursive =
+    [&](size_t k, double computation, double memory, double communication, double communication_without_para,
+        double communication_forward) {
       if (k == edges.size()) {
         auto decision = std::make_shared<EdgeEliminationDecision>(selected_cost_list);
         CostPtr new_cost = std::make_shared<Cost>(computation, communication);
@@ -170,6 +172,7 @@ CostPtrList Edge::CreateEdgeEliminationCostList(const StrategyPtr &output_st_ptr
         new_cost->communication_with_partial_para_ =
           communication_without_para + COST_MODEL_GAMMA * (communication - communication_without_para);
         new_cost->memory_with_reuse_ = memory;
+        new_cost->communication_forward_ = communication_forward;
         new_cost->decision_ptr_ = decision;
         result.push_back(new_cost);
         return;
@@ -179,11 +182,12 @@ CostPtrList Edge::CreateEdgeEliminationCostList(const StrategyPtr &output_st_ptr
       selected_cost_list[k] = c;
       recursive(k + 1, computation + c->computation_cost_, memory + c->memory_with_reuse_,
                 communication + c->communication_cost_,
-                communication_without_para + c->communication_without_parameter_);
+                communication_without_para + c->communication_without_parameter_,
+                communication_forward + c->communication_forward_);
     }
   };
-  recursive(0, 0.0, 0.0, 0.0, 0.0);
-  SimplifyForDreasingCommunicationWithPartialPara(&result);
+  recursive(0, 0.0, 0.0, 0.0, 0.0, 0.0);
+  Simplify(&result);
   return result;
 }
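The edge-elimination hunks above enumerate one cost from every parallel edge (a Cartesian product) and sum each field along the way, now also carrying communication_forward_. A compact standalone sketch of that recursion over a simplified cost record (assumed field set; not the real CostPtrList/Cost types) follows.

```cpp
#include <functional>
#include <iostream>
#include <vector>

// Simplified cost record carrying only the fields the recursion accumulates.
struct MiniCost {
  double computation;
  double communication;
  double communication_forward;
};

// Enumerate one cost per edge (Cartesian product) and emit the element-wise sums,
// mirroring the structure of the recursive lambda in CreateEdgeEliminationCostList.
std::vector<MiniCost> CombineEdgeCosts(const std::vector<std::vector<MiniCost>> &per_edge_costs) {
  std::vector<MiniCost> result;
  std::function<void(size_t, MiniCost)> recurse = [&](size_t k, MiniCost acc) {
    if (k == per_edge_costs.size()) {
      result.push_back(acc);  // one combined candidate per combination
      return;
    }
    for (const auto &c : per_edge_costs[k]) {
      recurse(k + 1, {acc.computation + c.computation, acc.communication + c.communication,
                      acc.communication_forward + c.communication_forward});
    }
  };
  recurse(0, {0.0, 0.0, 0.0});
  return result;
}

int main() {
  std::vector<std::vector<MiniCost>> per_edge_costs = {{{1, 2, 1}, {3, 1, 0.5}}, {{2, 2, 2}}};
  for (const auto &c : CombineEdgeCosts(per_edge_costs)) {
    std::cout << c.computation << " " << c.communication << " " << c.communication_forward << "\n";
  }
  return 0;
}
```

In the real code the combined list is then pruned by Simplify, so only Pareto-optimal combinations survive.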
@@ -219,6 +223,8 @@ void Edge::CreateOpEliminationSubCostList(StrategyPtr op_strategy, const CostPtr
       left_cost->computation_cost_ + middle_cost->computation_cost_ + right_cost->computation_cost_;
     double communication =
       left_cost->communication_cost_ + middle_cost->communication_cost_ + right_cost->communication_cost_;
+    double communication_forward =
+      left_cost->communication_forward_ + middle_cost->communication_forward_ + right_cost->communication_forward_;
     double communication_without_para = left_cost->communication_without_parameter_ +
                                         middle_cost->communication_without_parameter_ +
                                         right_cost->communication_without_parameter_;
@@ -232,6 +238,7 @@ void Edge::CreateOpEliminationSubCostList(StrategyPtr op_strategy, const CostPtr
     cost->communication_with_partial_para_ =
       communication_without_para + COST_MODEL_GAMMA * (communication - communication_without_para);
     cost->memory_with_reuse_ = memory_cost;
+    cost->communication_forward_ = communication_forward;
     ret_cost_list->emplace_back(std::move(cost));
   }
 }
@@ -251,7 +258,7 @@ CostPtrList Edge::CreateOpEliminationCostList(const EdgePtr &e1, const StrategyP
     CreateOpEliminationSubCostList(middle_strategy, e1->GetCostList(output_st_ptr, middle_strategy),
                                    op_strategy->cost_list, e2->GetCostList(middle_strategy, input_st_ptr), &result);
   }
-  SimplifyForDreasingCommunicationWithPartialPara(&result);
+  Simplify(&result);
   return result;
 }
@@ -38,6 +38,8 @@ bool TENSOR_SLICE_ALIGNMENT_ENABLE = DEFAULT_TENSOR_SLICE_ALIGNMENT_ENABLE;
 size_t TENSOR_SLICE_ALIGNMENT_SIZE = DEFAULT_TENSOR_SLICE_ALIGNMENT_SIZE;
 bool FULLY_USE_DEVICES = DEFAULT_FULLY_USE_DEVICES;
 bool ELEMENTWISE_OP_STRA_FOLLOW = DEFAULT_ELEMENTWISE_OP_STRA_FOLLOW;
+bool MULTI_SUBGRAPHS = DEFAULT_IS_MULTI_SUBGRAPHS;
+int32_t RUN_PHASE = DEFAULT_RUN_PHASE;
 
 void CostGraph::SetDeviceMemoryAndCostParameter() {
   MS_EXCEPTION_IF_NULL(CostModelContext::GetInstance());
@@ -142,6 +144,23 @@ void CostGraph::SetDeviceMemoryAndCostParameter() {
   } else {
     MS_LOG(INFO) << "elementwise_op_strategy_follow: false.";
   }
+
+  // MULTI_SUBGRAPHS
+  auto multi_subgraphs = CostModelContext::GetInstance()->is_multi_subgraphs();
+  MULTI_SUBGRAPHS = multi_subgraphs;
+  if (MULTI_SUBGRAPHS) {
+    MS_LOG(INFO) << "multi_subgraphs: true.";
+  } else {
+    MS_LOG(INFO) << "multi_subgraphs: false.";
+  }
+
+  // RUN_PHASE
+  auto phase = CostModelContext::GetInstance()->run_phase();
+  if (phase != 0 && phase != 1) {
+    MS_LOG(EXCEPTION) << "'run_phase' must be in {0, 1}";
+  }
+  RUN_PHASE = phase;
+  MS_LOG(INFO) << "run_phase: " << RUN_PHASE << ".";
 }
 
 void CostGraph::RemoveOperator(const OperatorInfoPtr &op) {
@@ -249,19 +268,21 @@ CostPtrList CostGraph::CreateFinalCostList(const OperatorInfoPtr &u, const std::
           MS_EXCEPTION_IF_NULL(cost3);
           double computation = cost1->computation_cost_ + cost2->computation_cost_ + cost3->computation_cost_;
           double memory = cost1->memory_with_reuse_ + cost2->memory_with_reuse_ + cost3->memory_with_reuse_;
-          double commmunication =
-            cost1->communication_cost_ + cost2->communication_cost_ + cost3->communication_cost_;
+          double communication = cost1->communication_cost_ + cost2->communication_cost_ + cost3->communication_cost_;
+          double communication_forward =
+            cost1->communication_forward_ + cost2->communication_forward_ + cost3->communication_forward_;
           double communication_without_para = cost1->communication_without_parameter_ +
                                               cost2->communication_without_parameter_ +
                                               cost3->communication_without_parameter_;
           auto decision =
             std::make_shared<FinalDecision>(u_strategy->strategy_ptr, v_strategy->strategy_ptr, cost1, cost2, cost3);
-          auto cost = std::make_shared<Cost>(computation, commmunication, decision);
+          auto cost = std::make_shared<Cost>(computation, communication, decision);
           MS_EXCEPTION_IF_NULL(cost);
           cost->communication_without_parameter_ = communication_without_para;
           cost->communication_with_partial_para_ =
-            communication_without_para + COST_MODEL_GAMMA * (commmunication - communication_without_para);
+            communication_without_para + COST_MODEL_GAMMA * (communication - communication_without_para);
           cost->memory_with_reuse_ = memory;
+          cost->communication_forward_ = communication_forward;
           ret.push_back(cost);
         }
       }
@@ -269,7 +290,7 @@ CostPtrList CostGraph::CreateFinalCostList(const OperatorInfoPtr &u, const std::
     }
   }
-  SimplifyForDreasingCommunicationWithPartialPara(&ret);
+  Simplify(&ret);
   return ret;
 }
@@ -291,32 +312,67 @@ CostPtrList CostGraph::CreateFinalSingleCostList(const OperatorInfoPtr &u) {
         cost1->communication_without_parameter_ +
         COST_MODEL_GAMMA * (cost1->communication_cost_ - cost1->communication_without_parameter_);
       new_cost->memory_with_reuse_ = cost1->memory_with_reuse_;
+      new_cost->communication_forward_ = cost1->communication_forward_;
       ret.push_back(new_cost);
     }
   }
-  SimplifyForDreasingCommunicationWithPartialPara(&ret);
+  Simplify(&ret);
   return ret;
 }
 
-CostPtr CostGraph::SelectCostWithMemoryConstraint(const CostPtrList &cost_list, double memory) {
+CostPtr CostGraph::SelectCostWithMinInferenceTime(const CostPtrList &cost_list, double memory) {
+  // Select the cost with minimum inference time. Currently, the inference time is modeled as =
+  // costmodel_alpha_ * computation_cost + costmodel_beta_ * communication_forward_
+  if (cost_list.empty()) {
+    MS_LOG(ERROR) << "Final cost list is null.";
+    return nullptr;
+  }
   CostPtrList after_mem_filter;
-  // Filter out the valid costs
+  double minimum_memory = DBL_MAX;
+  // Filter out the valid costs.
   for (auto &a_cost : cost_list) {
     if (a_cost->memory_with_reuse_ <= memory) {
       after_mem_filter.emplace_back(std::move(a_cost));
+    } else if (a_cost->memory_with_reuse_ < minimum_memory) {
+      minimum_memory = a_cost->memory_with_reuse_;
     }
   }
+  if (after_mem_filter.empty()) {
+    MS_LOG(ERROR) << "No available cost. The minimum memory cost is: " << minimum_memory
+                  << ", the memory capacity is: " << memory << ".";
+    return nullptr;
+  }
+  // Init the returned value with first cost.
+  CostPtr ret = after_mem_filter[0];
-  std::function<CostPtr(CostPtr, const CostPtr &)> LocalCompare = [&](CostPtr init, const CostPtr &cost_x) {
-    MS_EXCEPTION_IF_NULL(cost_x);
-    if (init == nullptr || cost_x->computation_cost_ < memory) {
-      init = cost_x;
+  double minimum = costmodel_alpha_ * ret->computation_cost_ + costmodel_beta_ * ret->communication_forward_;
+  MS_LOG(INFO) << "Cost 0: "
+               << "memory_cost: " << ret->memory_with_reuse_ << ", computation_cost_: " << ret->computation_cost_
+               << ", communication_forward_: " << ret->communication_forward_
+               << ", communication_with_partial_para_: " << ret->communication_with_partial_para_
+               << ", communication_cost_: " << ret->communication_cost_
+               << ", communication_without_parameter_: " << ret->communication_without_parameter_ << ".";
+  MS_LOG(INFO) << "Cost 0: total_cost: " << minimum;
+  for (size_t i = 1; i < after_mem_filter.size(); ++i) {
+    MS_EXCEPTION_IF_NULL(after_mem_filter[i]);
+    MS_LOG(INFO) << "Cost " << i << ": memory_cost: " << after_mem_filter[i]->memory_with_reuse_
+                 << ", computation_cost_: " << after_mem_filter[i]->computation_cost_
+                 << ", communication_forward_: " << after_mem_filter[i]->communication_forward_
+                 << ", communication_with_partial_para_: " << after_mem_filter[i]->communication_with_partial_para_
+                 << ", communication_cost_: " << after_mem_filter[i]->communication_cost_
+                 << ", communication_without_parameter_: " << after_mem_filter[i]->communication_without_parameter_
+                 << ".";
+    auto tmp = costmodel_alpha_ * after_mem_filter[i]->computation_cost_ +
+               costmodel_beta_ * after_mem_filter[i]->communication_forward_;
+    MS_LOG(INFO) << "Cost " << i << ": total_cost: " << tmp;
+    if (minimum > tmp) {
+      minimum = tmp;
+      ret = after_mem_filter[i];
+      MS_LOG(INFO) << "Selected: " << i;
     }
-    return init;
-  };
-  CostPtr ret = nullptr;
-  return std::accumulate(after_mem_filter.begin(), after_mem_filter.end(), ret, LocalCompare);
+  }
+  return ret;
 }
 
 CostPtr CostGraph::SelectCostWithMinTrainingTime(const CostPtrList &cost_list, double memory) {
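The selection rule introduced here is: keep only candidates that fit the device memory capacity, then pick the one minimizing costmodel_alpha_ * computation_cost_ + costmodel_beta_ * communication_forward_. A minimal sketch of that argmin over a simplified cost type (illustrative alpha/beta values; not the production class, and returning an index instead of a CostPtr):

```cpp
#include <cstddef>
#include <iostream>
#include <vector>

struct MiniCost {
  double memory_with_reuse;
  double computation_cost;
  double communication_forward;
};

// Return the index of the candidate with the smallest modeled inference time under
// the memory capacity, or -1 when nothing fits (the real code returns nullptr and logs).
int SelectMinInferenceTime(const std::vector<MiniCost> &costs, double memory_capacity,
                           double alpha, double beta) {
  int best = -1;
  double best_time = 0.0;
  for (size_t i = 0; i < costs.size(); ++i) {
    if (costs[i].memory_with_reuse > memory_capacity) continue;  // memory filter
    double time = alpha * costs[i].computation_cost + beta * costs[i].communication_forward;
    if (best == -1 || time < best_time) {
      best = static_cast<int>(i);
      best_time = time;
    }
  }
  return best;
}

int main() {
  std::vector<MiniCost> costs = {{8.0, 100.0, 10.0}, {4.0, 120.0, 2.0}, {2.0, 90.0, 50.0}};
  // With alpha = 1.0 and beta = 5.0 the second candidate wins: 120 + 10 = 130.
  std::cout << SelectMinInferenceTime(costs, 6.0, 1.0, 5.0) << "\n";
  return 0;
}
```

The training-phase selector kept below still uses the full communication cost; only the inference path switches to the forward-only communication term.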
@@ -524,12 +580,26 @@ Status CostGraph::SearchStrategy() {
   });
 
   if (alive_ops.size() > 2) {
-    return SearchStrategyForMultiNodeFinalGraph(alive_ops);
+    if (RUN_PHASE == TRAINING_PHASE) {
+      // training phase
+      return SearchStrategyForMultiNodeFinalGraph(alive_ops);
+    } else {
+      // inference phase
+      MS_LOG(EXCEPTION)
+        << "Currently, searching strategy for the multi-node final graph in inference phase is not supported.";
+    }
   } else if (alive_ops.size() == 1) {
     MS_LOG(INFO) << "There are 1 single node in the final graph.";
     OperatorInfoPtr u = alive_ops[0];
     auto cost_list = CreateFinalSingleCostList(u);
-    auto cost = SelectCostWithMinTrainingTime(cost_list, dev_memory_);
+    CostPtr cost = nullptr;
+    if (RUN_PHASE == TRAINING_PHASE) {
+      // training phase
+      cost = SelectCostWithMinTrainingTime(cost_list, dev_memory_);
+    } else {
+      // inference phase
+      cost = SelectCostWithMinInferenceTime(cost_list, dev_memory_);
+    }
     if (cost == nullptr) {
       MS_LOG(ERROR) << "No vaild strategy can be found under the current device memory: " << dev_memory_ << ".";
       return FAILED;
@@ -575,7 +645,15 @@ Status CostGraph::SearchStrategy() {
       auto cost_list = one_component->CreateFinalSingleCostList(one_component->GetOperators()[0]);
       all_list.push_back(cost_list);
     }
-    auto selected_cost_list = SelectCostListWithMinTrainingTimeMultiple(all_list, dev_memory_);
+    CostPtrList selected_cost_list;
+    if (RUN_PHASE == TRAINING_PHASE) {
+      // training phase
+      selected_cost_list = SelectCostListWithMinTrainingTimeMultiple(all_list, dev_memory_);
+    } else {
+      // inference phase
+      MS_LOG(EXCEPTION) << "Currently, searching strategy for the two-separated-node final graph in the inference "
+                           "phase is not supported.";
+    }
     for (size_t k = 0; k < selected_cost_list.size(); ++k) {
       auto selected_cost = selected_cost_list[k];
       if (selected_cost == nullptr) {
@@ -601,7 +679,14 @@ Status CostGraph::SearchStrategy() {
     auto e = u->GetAliveSuccEdges()[0];
     MS_EXCEPTION_IF_NULL(e);
     auto cost_list = CreateFinalCostList(u, e, v);
-    auto cost = SelectCostWithMinTrainingTime(cost_list, dev_memory_);
+    CostPtr cost = nullptr;
+    if (RUN_PHASE == TRAINING_PHASE) {
+      // training phase
+      cost = SelectCostWithMinTrainingTime(cost_list, dev_memory_);
+    } else {
+      MS_LOG(EXCEPTION) << "Currently, searching strategy for the two-connected-node final graph in the inference "
+                           "phase is not supported.";
+    }
     if (cost == nullptr) {
       MS_LOG(ERROR) << "No vaild strategy can be found under the current device memory: " << dev_memory_ << ".";
       return FAILED;
@@ -841,6 +926,8 @@ void CostGraph::CreateMergeEliminationSubCostList(StrategyPtr op_strategy, const
       double memory = op_cost->memory_with_reuse_ + edge_cost->memory_with_reuse_ + tar_cost->memory_with_reuse_;
       double communication =
         op_cost->communication_cost_ + edge_cost->communication_cost_ + tar_cost->communication_cost_;
+      double communication_forward =
+        op_cost->communication_forward_ + edge_cost->communication_forward_ + tar_cost->communication_forward_;
       double communication_without_para = op_cost->communication_without_parameter_ +
                                           edge_cost->communication_without_parameter_ +
                                           tar_cost->communication_without_parameter_;
@@ -853,6 +940,7 @@ void CostGraph::CreateMergeEliminationSubCostList(StrategyPtr op_strategy, const
       new_cost->communication_with_partial_para_ =
         communication_without_para + COST_MODEL_GAMMA * (communication - communication_without_para);
       new_cost->memory_with_reuse_ = memory;
+      new_cost->communication_forward_ = communication_forward;
       MS_EXCEPTION_IF_NULL(tar_cost_list_new);
       tar_cost_list_new->emplace_back(std::move(new_cost));
     }
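The merge, contract, triangle, and star eliminations in this patch all follow the same bookkeeping when they fold several costs into one: sum every accumulated field (now including communication_forward_) and recompute the derived communication_with_partial_para_. A small standalone sketch of that pattern with a simplified cost record and made-up numbers (gamma stands in for COST_MODEL_GAMMA; this is not the production API):

```cpp
#include <iostream>
#include <vector>

// Simplified cost record with the fields each elimination keeps in sync.
struct MiniCost {
  double computation;
  double memory_with_reuse;
  double communication;
  double communication_without_parameter;
  double communication_forward;
  double communication_with_partial_para;
};

// Sum the constituent costs of an elimination (operator + edge + target node, etc.)
// and recompute the derived field, mirroring the CreateXxxEliminationSubCostList pattern.
MiniCost AccumulateCosts(const std::vector<MiniCost> &parts, double gamma) {
  MiniCost total{0, 0, 0, 0, 0, 0};
  for (const auto &p : parts) {
    total.computation += p.computation;
    total.memory_with_reuse += p.memory_with_reuse;
    total.communication += p.communication;
    total.communication_without_parameter += p.communication_without_parameter;
    total.communication_forward += p.communication_forward;  // the field added by this patch
  }
  total.communication_with_partial_para =
      total.communication_without_parameter +
      gamma * (total.communication - total.communication_without_parameter);
  return total;
}

int main() {
  MiniCost op{10, 4, 6, 4, 3, 0}, edge{0, 1, 2, 2, 2, 0}, target{5, 2, 4, 3, 1, 0};
  MiniCost merged = AccumulateCosts({op, edge, target}, 0.1);
  std::cout << merged.computation << " " << merged.communication_forward << " "
            << merged.communication_with_partial_para << "\n";
  return 0;
}
```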
@@ -885,7 +973,7 @@ OperatorInfoPtr CostGraph::EliminationMerge(const OperatorInfoPtr &op) {
       CreateMergeEliminationSubCostList(op_stra, op_clist, edge_clist, tar_stra, tar_clist_origin, &tar_clist_new);
     }
-    SimplifyForDreasingCommunicationWithPartialPara(&tar_clist_new);
+    Simplify(&tar_clist_new);
     // Set the new costlist w.r.t the strategy
     tar_stra_cost->cost_list = tar_clist_new;
     if ((!valid) && (!tar_clist_new.empty())) {
@@ -922,6 +1010,8 @@ void CostGraph::CreateContractEliminationSubCostList(StrategyPtr contract_op_str
         contract_op_cost->memory_with_reuse_ + edge_cost->memory_with_reuse_ + tar_cost->memory_with_reuse_;
       double communication =
         contract_op_cost->communication_cost_ + edge_cost->communication_cost_ + tar_cost->communication_cost_;
+      double communication_forward = contract_op_cost->communication_forward_ + edge_cost->communication_forward_ +
+                                     tar_cost->communication_forward_;
       double communication_without_para = contract_op_cost->communication_without_parameter_ +
                                           edge_cost->communication_without_parameter_ +
                                           tar_cost->communication_without_parameter_;
@@ -933,6 +1023,7 @@ void CostGraph::CreateContractEliminationSubCostList(StrategyPtr contract_op_str
       new_cost->communication_with_partial_para_ =
         communication_without_para + COST_MODEL_GAMMA * (communication - communication_without_para);
       new_cost->memory_with_reuse_ = memory;
+      new_cost->communication_forward_ = communication_forward;
       tar_cost_list_new->emplace_back(std::move(new_cost));
     }
   }
@@ -962,7 +1053,7 @@ OperatorInfoPtr CostGraph::EliminationContract(const OperatorInfoPtr &op) {
       CreateContractEliminationSubCostList(op_stra, op_clist, edge_clist, tar_stra, tar_clist_origin, &tar_clist_new);
     }
-    SimplifyForDreasingCommunicationWithPartialPara(&tar_clist_new);
+    Simplify(&tar_clist_new);
     // Set the new costlist w.r.t the strategy
     tar_stra_cost->cost_list = tar_clist_new;
     if ((!valid) && (!tar_clist_new.empty())) {
@@ -998,6 +1089,8 @@ void CostGraph::CreateTriangleEliminationSubCostList(StrategyPtr elimi_op_stra,
                           left_node_cost->memory_with_reuse_ + right_edge_cost->memory_with_reuse_;
       double new_commu_cost = elimi_op_cost->communication_cost_ + left_edge_cost->communication_cost_ +
                               left_node_cost->communication_cost_ + right_edge_cost->communication_cost_;
+      double new_commu_forward = elimi_op_cost->communication_forward_ + left_edge_cost->communication_forward_ +
+                                 left_node_cost->communication_forward_ + right_edge_cost->communication_forward_;
       double new_commu_without =
         elimi_op_cost->communication_without_parameter_ + left_edge_cost->communication_without_parameter_ +
         left_node_cost->communication_without_parameter_ + right_edge_cost->communication_without_parameter_;
@@ -1009,6 +1102,7 @@ void CostGraph::CreateTriangleEliminationSubCostList(StrategyPtr elimi_op_stra,
       new_cost->communication_with_partial_para_ =
         new_commu_without + COST_MODEL_GAMMA * (new_commu_cost - new_commu_without);
       new_cost->memory_with_reuse_ = new_memory;
+      new_cost->communication_forward_ = new_commu_forward;
       left_node_clist_new->emplace_back(std::move(new_cost));
     }
   }
@@ -1079,7 +1173,7 @@ OperatorInfoPtr CostGraph::EliminationTriangle(const OperatorInfoPtr &elimi_op,
                                        &left_node_clist_new);
     }
   }
-  SimplifyForDreasingCommunicationWithPartialPara(&left_node_clist_new);
+  Simplify(&left_node_clist_new);
   // Set the new costlist w.r.t the strategy
   left_node_stra_cost->cost_list = left_node_clist_new;
   if ((!valid) && (!left_node_clist_new.empty())) {
@@ -1112,19 +1206,22 @@ void CostGraph::CreateStarEliminationSubCostList(const StrategyPtr &first_succ_n
     double computation_cost = merged_node_cost->computation_cost_,
            memory_cost = merged_node_cost->memory_with_reuse_, commu_cost = merged_node_cost->communication_cost_,
-           commu_without = merged_node_cost->communication_without_parameter_;
+           commu_without = merged_node_cost->communication_without_parameter_,
+           commu_forward = merged_node_cost->communication_forward_;
 
     for (size_t i = 0; i < succ_nodes_stras.size(); ++i) {
       MS_EXCEPTION_IF_NULL(succ_edges_costs[i]);
       if (i == 0) {
         computation_cost += succ_edges_costs[i]->computation_cost_ + succ_nodes_costs[i]->computation_cost_;
         memory_cost += succ_edges_costs[i]->memory_with_reuse_ + succ_nodes_costs[i]->memory_with_reuse_;
         commu_cost += succ_edges_costs[i]->communication_cost_ + succ_nodes_costs[i]->communication_cost_;
+        commu_forward += succ_edges_costs[i]->communication_forward_ + succ_nodes_costs[i]->communication_forward_;
         commu_without += succ_edges_costs[i]->communication_without_parameter_ +
                          succ_nodes_costs[i]->communication_without_parameter_;
       } else {
         computation_cost += succ_edges_costs[i]->computation_cost_;
         memory_cost += succ_edges_costs[i]->memory_with_reuse_;
         commu_cost += succ_edges_costs[i]->communication_cost_;
+        commu_forward += succ_edges_costs[i]->communication_forward_;
         commu_without += succ_edges_costs[i]->communication_without_parameter_;
       }
     }
@@ -1135,6 +1232,7 @@ void CostGraph::CreateStarEliminationSubCostList(const StrategyPtr &first_succ_n
     new_cost->communication_without_parameter_ = commu_without;
     new_cost->communication_with_partial_para_ = commu_without + COST_MODEL_GAMMA * (commu_cost - commu_without);
     new_cost->memory_with_reuse_ = memory_cost;
+    new_cost->communication_forward_ = commu_forward;
     first_succ_node_clist_new->emplace_back(std::move(new_cost));
   }
 }
@@ -1220,7 +1318,7 @@ std::vector<std::shared_ptr<Edge>> CostGraph::EliminationStar(const OperatorInfo
       CreateStarEliminationCostList(succ_edges, first_succ_node_stra, first_succ_node_clist, first_succ_edge_clist,
                                     merged_op_stra, merged_op_clist, &first_succ_node_clist_new);
     }
-    SimplifyForDreasingCommunicationWithPartialPara(&first_succ_node_clist_new);
+    Simplify(&first_succ_node_clist_new);
     // Set the new costlist w.r.t the strategy
     first_succ_node_stra_cost->cost_list = first_succ_node_clist_new;
     if ((!valid) && (!first_succ_node_clist_new.empty())) {
@@ -45,6 +45,9 @@ namespace parallel {
 #define DEFAULT_FULLY_USE_DEVICES true
 #define DEFAULT_ELEMENTWISE_OP_STRA_FOLLOW false
 #define DEFAULT_IS_MULTI_SUBGRAPHS false
+#define DEFAULT_RUN_PHASE 0
+#define TRAINING_PHASE 0
+#define INFERENCE_PHASE 1
 
 class CostGraph;
 using CostGraphPtr = std::shared_ptr<CostGraph>;
@@ -60,6 +63,8 @@ extern bool TENSOR_SLICE_ALIGNMENT_ENABLE;
 extern size_t TENSOR_SLICE_ALIGNMENT_SIZE;
 extern bool FULLY_USE_DEVICES;
 extern bool ELEMENTWISE_OP_STRA_FOLLOW;
+extern bool MULTI_SUBGRAPHS;
+extern int32_t RUN_PHASE;
 
 class CostGraph {
   // 'CostGraph' consists of Operators and edges between them. An edge is created between two Operators if they have
@@ -98,7 +103,7 @@ class CostGraph {
   CostPtrList CreateFinalCostList(const OperatorInfoPtr &u, const EdgePtr &e, const OperatorInfoPtr &v);
   CostPtrList CreateFinalSingleCostList(const OperatorInfoPtr &u);
-  CostPtr SelectCostWithMemoryConstraint(const CostPtrList &cost_list, double memory);
+  CostPtr SelectCostWithMinInferenceTime(const CostPtrList &cost_list, double memory);
   CostPtr SelectCostWithMinTrainingTime(const CostPtrList &cost_list, double memory);
   CostPtrList SelectCostListWithMinTrainingTimeMultiple(const std::vector<CostPtrList> &all_costlist, double memory);
   Status SearchStrategyForMultiNodeFinalGraph(const std::vector<OperatorInfoPtr> &);
@@ -47,6 +47,7 @@ void CostModelContext::ResetCostModel() {
   costmodel_communi_const_ = DEFAULT_COST_MODEL_COMMUNI_CONST;
   costmodel_communi_bias_ = DEFAULT_COST_MODEL_COMMUNI_BIAS;
   is_multi_subgraphs_ = DEFAULT_IS_MULTI_SUBGRAPHS;
+  run_phase_ = DEFAULT_RUN_PHASE;
   costmodel_allreduce_fusion_algorithm_ = DEFAULT_COST_MODEL_ALLREDUCE_FUSION_ALGORITHM;
   costmodel_allreduce_fusion_times_ = DEFAULT_COST_MODEL_ALLREDUCE_FUSION_TIMES;
   costmodel_allreduce_fusion_tail_percent_ = DEFAULT_COST_MODEL_ALLREDUCE_FUSION_TAIL_PERCENT;
@@ -125,5 +126,7 @@ void CostModelContext::set_fully_use_device(bool fully_use) { fully_use_device_
 void CostModelContext::set_elementwise_stra_follow(bool elementwise_follow) {
   elementwise_stra_follow_ = elementwise_follow;
 }
+
+void CostModelContext::set_run_phase(int32_t phase) { run_phase_ = phase; }
 }  // namespace parallel
 }  // namespace mindspore
@@ -113,6 +113,9 @@ class CostModelContext {
   void set_elementwise_stra_follow(bool);
   bool elementwise_stra_follow() const { return elementwise_stra_follow_; }
 
+  void set_run_phase(int32_t);
+  int32_t run_phase() const { return run_phase_; }
+
  private:
   CostModelContext();
   static std::shared_ptr<CostModelContext> cm_context_inst_;
@@ -141,8 +144,11 @@ class CostModelContext {
   // COST_MODEL_COMMUNI_BIAS
   double costmodel_communi_bias_;
 
+  // MULTI_SUBGRAPHS
   bool is_multi_subgraphs_;
 
+  int32_t run_phase_;  // 0: 'training', 1: 'inference'
+
   int32_t costmodel_allreduce_fusion_algorithm_;
 
   int32_t costmodel_allreduce_fusion_times_;
@@ -610,6 +610,7 @@ Status MatMulBase::SetCostUnderStrategy(const mindspore::parallel::StrategyPtr &
                   << ", communication_with_partial_para_: " << result->communication_with_partial_para_;
     // refine communication cost calculation for practice
     RefineForPracticalCost(result, false);
+    result->communication_forward_ = result->communication_without_parameter_;
 
     std::shared_ptr<StrategyWithCost> swc =
       std::make_shared<StrategyWithCost>(strategy, inputs_tensor_info_, outputs_tensor_info_);
@@ -1049,6 +1049,7 @@ Status OperatorInfo::SetCostUnderStrategyBase(const StrategyPtr &strategy) {
     BreakingTiesForPerferringDataParallel(strategy, result);
     // refine communication cost calculation for practice
     RefineForPracticalCost(result, false);
+    result->communication_forward_ = result->communication_without_parameter_;
 
     std::shared_ptr<StrategyWithCost> swc =
       std::make_shared<StrategyWithCost>(strategy, inputs_tensor_info_, outputs_tensor_info_);
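Worth noting is how the new field is seeded in the two cases touched by this patch: an edge takes its forward redistribution cost (the InitEdgeCost hunk earlier), while an operator takes its communication cost with parameter-related backward traffic removed (the two hunks above). A tiny sketch of the two seeding rules with simplified types and illustrative values:

```cpp
#include <iostream>

// Illustrative only: simplified stand-ins for the edge and operator cost records.
struct EdgeCost {
  double communication_redis_forward;   // forward redistribution traffic on the edge
  double communication_redis_backward;  // backward redistribution traffic on the edge
};

struct OperatorCost {
  double communication_cost;                // forward + backward operator communication
  double communication_without_parameter;   // communication_cost minus backward parameter traffic
};

int main() {
  EdgeCost edge{2.0, 1.0};
  OperatorCost op{8.0, 5.0};
  // Edge: communication_forward_ = communication_redis_forward_ (InitEdgeCost hunk).
  double edge_forward = edge.communication_redis_forward;
  // Operator: communication_forward_ = communication_without_parameter_ (the two hunks above).
  double op_forward = op.communication_without_parameter;
  std::cout << edge_forward << " " << op_forward << "\n";
  return 0;
}
```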
@@ -69,16 +69,16 @@ class TensorRedistribution {
   RankList dev_list_;
   OperatorList operator_list_;
   bool reshape_flag_;
-  // communication cost
+  // communication cost, which is the sum of forward communication cost and backward communication cost
   double comm_cost_;
   // forward communication cost
   double forward_comm_cost_;
   // backward communication cost
   double backward_comm_cost_;
   // computation_cost models the time spending on computing in this tensor redistribution, which is calculated by the
-  // inputs.
+  // inputs. This is calculated ONLY for forward phase.
   double computation_cost_;
-  // memory_cost models the PEAK memory cost in a traning iteration contributed by this tensor redistribution, which is
+  // memory_cost models the PEAK memory cost in a training iteration contributed by this tensor redistribution, which is
   // calculated by the outputs.
   double memory_cost_;
   bool construct_op_flag_;
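The comment fix above makes the decomposition explicit: a redistribution's comm_cost_ is the sum of its forward and backward parts, and only the forward part feeds the new communication_forward_ bookkeeping. A tiny sketch of that split (hypothetical numbers, simplified struct; not the real class):

```cpp
#include <iostream>

// Simplified view of the redistribution cost fields documented above.
struct RedistributionCost {
  double forward_comm_cost;
  double backward_comm_cost;
  // comm_cost is the sum of the forward and backward communication costs.
  double comm_cost() const { return forward_comm_cost + backward_comm_cost; }
};

int main() {
  RedistributionCost redis{3.5, 1.5};
  // In the inference-phase model only the forward part contributes, e.g.
  // cost->communication_forward_ = cost->communication_redis_forward_ in InitEdgeCost().
  std::cout << "total: " << redis.comm_cost() << ", forward only: " << redis.forward_comm_cost << "\n";
  return 0;
}
```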
@@ -228,6 +228,8 @@ PYBIND11_MODULE(_c_expression, m) {
          "Get the parameter cost_model_communi_bias of the DP algorithm.")
     .def("set_multi_subgraphs", &CostModelContext::set_multi_subgraphs, "Set the parameter is_multi_subgraphs.")
     .def("get_multi_subgraphs", &CostModelContext::is_multi_subgraphs, "Get the parameter is_multi_subgraphs.")
+    .def("set_run_phase", &CostModelContext::set_run_phase, "Set the flag run_phase.")
+    .def("get_run_phase", &CostModelContext::run_phase, "Get the flag run_phase.")
     .def("set_costmodel_allreduce_fusion_algorithm", &CostModelContext::set_costmodel_allreduce_fusion_algorithm,
          "Set the parameter gradient AllReduce fusion algorithm.")
     .def("get_costmodel_allreduce_fusion_algorithm", &CostModelContext::costmodel_allreduce_fusion_algorithm,
@@ -239,6 +239,33 @@ class _CostModelContext:
             raise ValueError("Context handle is none in context!!!")
         return self._context_handle.get_multi_subgraphs()
 
+    def set_run_phase(self, phase):
+        """
+        Set the flag of running phase: training (0) or inference (1)
+
+        Args:
+            phase (int): A parameter indicating which phase is running.
+
+        Raises:
+            ValueError: If context handle is none, or phase is not in {0, 1}.
+        """
+        if self._context_handle is None:
+            raise ValueError("Context handle is none in context!!!")
+        if phase not in (0, 1):
+            raise ValueError("The argument of set_run_phase() must be '0' or '1', but got {}".format(phase))
+        self._context_handle.set_run_phase(phase)
+
+    def get_run_phase(self):
+        """
+        Get the flag of running phase.
+
+        Raises:
+            ValueError: If context handle is none.
+        """
+        if self._context_handle is None:
+            raise ValueError("Context handle is none in context!!!")
+        return self._context_handle.get_run_phase()
+
     def set_costmodel_allreduce_fusion_algorithm(self, algorithm):
         """
         Set costmodel allreduce fusion algorithm.
@@ -453,6 +480,7 @@ set_cost_model_context_func_map = {
     "costmodel_communi_const": cost_model_context().set_costmodel_communi_const,
     "costmodel_communi_bias": cost_model_context().set_costmodel_communi_bias,
     "multi_subgraphs": cost_model_context().set_multi_subgraphs,
+    "run_phase": cost_model_context().set_run_phase,
     "costmodel_allreduce_fusion_algorithm": cost_model_context().set_costmodel_allreduce_fusion_algorithm,
     "costmodel_allreduce_fusion_times": cost_model_context().set_costmodel_allreduce_fusion_times,
     "costmodel_allreduce_fusion_tail_percent": cost_model_context().set_costmodel_allreduce_fusion_tail_percent,
@@ -473,7 +501,8 @@ get_cost_model_context_func_map = {
     "costmodel_communi_threshold": cost_model_context().get_costmodel_communi_threshold,
     "costmodel_communi_const": cost_model_context().get_costmodel_communi_const,
     "costmodel_communi_bias": cost_model_context().get_costmodel_communi_bias,
-    "multi_subgraphs": cost_model_context().get_multi_subgraphs(),
+    "multi_subgraphs": cost_model_context().get_multi_subgraphs,
+    "run_phase": cost_model_context().get_run_phase,
     "costmodel_allreduce_fusion_algorithm": cost_model_context().get_costmodel_allreduce_fusion_algorithm,
     "costmodel_allreduce_fusion_times": cost_model_context().get_costmodel_allreduce_fusion_times,
     "costmodel_allreduce_fusion_tail_percent": cost_model_context().get_costmodel_allreduce_fusion_tail_percent,
@@ -488,7 +517,7 @@ get_cost_model_context_func_map = {
 @args_type_check(device_memory_capacity=float, costmodel_alpha=float, costmodel_beta=float, costmodel_gamma=float,
                  costmodel_communi_threshold=float, costmodel_communi_const=float, costmodel_communi_bias=float,
-                 multi_subgraphs=bool,
+                 multi_subgraphs=bool, run_phase=int,
                  costmodel_allreduce_fusion_algorithm=int, costmodel_allreduce_fusion_times=int,
                  costmodel_allreduce_fusion_tail_percent=float, costmodel_allreduce_fusion_tail_time=float,
                  costmodel_allreduce_fusion_allreduce_inherent_time=float,
@@ -510,6 +539,7 @@ def set_cost_model_context(**kwargs):
         costmodel_communi_const (float): A parameter used in adjusting communication calculation for practice.
         costmodel_communi_bias (float): A parameter used in adjusting communication calculation for practice.
         multi_subgraphs (bool): A parameter used in marking the flag of ANF graph containing multiple subgraphs.
+        run_phase (int): A parameter indicating which phase is running: training (0) or inference (1). Default: 0.
         costmodel_allreduce_fusion_algorithm (int): The allreduce fusion algorithm.
             0: bypass allreduce fusion;
             1: only use backward computation time to group allreduce;
@@ -371,7 +371,7 @@ TEST_F(TestCostGraph, test_CreateFinalCostList_AND_Select) {
   ASSERT_EQ(edge_m1_m2->InitEdgeCost(), SUCCESS);
   cost_graph.AddEdge(matmul1, matmul2, edge_m1_m2);
   auto cost_list = cost_graph.CreateFinalCostList(matmul1, edge_m1_m2, matmul2);
-  cost_graph.SelectCostWithMemoryConstraint(cost_list, cost_graph.GetDeviceMemory());
+  cost_graph.SelectCostWithMinInferenceTime(cost_list, cost_graph.GetDeviceMemory());
 }
 
 TEST_F(TestCostGraph, test_EliminationOp) {
@@ -14,15 +14,21 @@
 import mindspore.context as context
 from mindspore.parallel._auto_parallel_context import auto_parallel_context
+from mindspore.parallel._cost_model_context import reset_cost_model_context
+from mindspore.parallel.algo_parameter_config import reset_algo_parameters
 from mindspore.parallel._utils import _reset_op_id
 
 
 def setup_module(module):
     auto_parallel_context().set_enable_all_reduce_fusion(enable_all_reduce_fusion=True)
     context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False)
+    reset_cost_model_context()
+    reset_algo_parameters()
     _reset_op_id()
 
 
 def teardown_module():
     context.reset_auto_parallel_context()
+    reset_cost_model_context()
+    reset_algo_parameters()
     _reset_op_id()
@@ -0,0 +1,36 @@
+import numpy as np
+
+import mindspore.nn as nn
+from mindspore import Tensor, context
+from mindspore.ops import operations as P
+from mindspore.nn import WithLossCell, TrainOneStepCell
+from mindspore.nn import Momentum
+from mindspore.parallel._cost_model_context import set_cost_model_context
+
+
+class Net(nn.Cell):
+    def __init__(self, input_ch, out_ch):
+        super(Net, self).__init__()
+        self.dense = nn.Dense(input_ch, out_ch)
+        self.relu = P.ReLU()
+
+    def construct(self, x):
+        x = self.dense(x)
+        x = self.relu(x)
+        return x
+
+
+def test_inference_phase():
+    context.set_auto_parallel_context(device_num=8, global_rank=0)
+    context.set_auto_parallel_context(parallel_mode="auto_parallel")
+    set_cost_model_context(run_phase=1)
+    net = Net(512, 128)
+    predict = Tensor(np.ones([64, 512]).astype(np.float32) * 0.001)
+    label = Tensor(np.ones([64, 128]).astype(np.float32))
+    loss = nn.SoftmaxCrossEntropyWithLogits()
+    optimizer = Momentum(params=net.trainable_params(), learning_rate=0.1, momentum=0.9)
+    net_with_loss = WithLossCell(net, loss)
+    train_network = TrainOneStepCell(net_with_loss, optimizer)
+    train_network.set_train()
+
+    output = train_network(predict, label)