Merge pull request !859 from Xiaoda/support-inferring-phase-parallel-strategy-searching (tag: v0.3.0-alpha)
| @@ -23,8 +23,17 @@ | |||
| namespace mindspore { | |||
| namespace parallel { | |||
| void Simplify(CostPtrList *clist_ptrs) { | |||
| // Sort the cost_list with the computation_cost_ increasing, and communication_cost decreasing order. This method | |||
| // excludes the cost with greater computation_cost_ and greater communication_cost. | |||
| if (RUN_PHASE == TRAINING_PHASE) { | |||
| // training phase | |||
| SimplifyForDecreasingCommunicationWithPartialPara(clist_ptrs); | |||
| } else { | |||
| // inference phase | |||
| SimplifyForDecreasingCommunicationForward(clist_ptrs); | |||
| } | |||
| } | |||
| void SimplifyForDecreasingCommunicationForward(CostPtrList *clist_ptrs) { | |||
| // Sort the cost_list with the computation_cost_ increasing, and communication_forward decreasing order. This method | |||
| // excludes the cost with greater computation_cost_ and greater communication_forward. | |||
| // E.g. clist_ptrs = {<100, 20>, <200, 10>, <300, 50>}. After this method, clist_ptrs = {<200, 10>, <100, 20>} | |||
| if (!COST_MODEL_SIMPLIFY_CALCULATION) { | |||
| return; | |||
| @@ -37,14 +46,15 @@ void Simplify(CostPtrList *clist_ptrs) { | |||
| }); | |||
| CostPtrList ret; | |||
| for (size_t i = 0; i < clist_ptrs->size(); ++i) { | |||
| if ((ret.size() == size_t(0)) || (clist_ptrs->at(id[i])->communication_cost_ < ret.back()->communication_cost_)) { | |||
| if ((ret.size() == size_t(0)) || | |||
| (clist_ptrs->at(id[i])->communication_forward_ < ret.back()->communication_forward_)) { | |||
| ret.emplace_back(std::move(clist_ptrs->at(id[i]))); | |||
| } | |||
| } | |||
| *clist_ptrs = std::move(ret); | |||
| } | |||
| void SimplifyForDreasingCommunicationWithPartialPara(CostPtrList *clist_ptrs) { | |||
| void SimplifyForDecreasingCommunicationWithPartialPara(CostPtrList *clist_ptrs) { | |||
| // Sort the cost_list with the computation_cost_ increasing, and communication_with_partial_para_cost decreasing | |||
| // order. This method excludes the cost with greater computation_cost_ and greater communication_without_para_cost. | |||
| if (!COST_MODEL_SIMPLIFY_CALCULATION) { | |||
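Both simplification routines above implement the same Pareto-style pruning and differ only in which communication field they compare: `communication_forward_` in the inference routine, the partial-parameter communication cost in the training routine. A minimal Python sketch of that rule, using plain tuples instead of `CostPtr` (names here are illustrative, not the real API):

```python
def simplify_for_decreasing_comm(costs):
    """Sketch of the pruning rule. costs is a list of
    (computation_cost, communication_cost) pairs; a candidate survives only if
    its communication cost is strictly lower than that of every survivor with a
    smaller or equal computation cost."""
    kept = []
    for cost in sorted(costs, key=lambda c: c[0]):   # computation_cost_ increasing
        if not kept or cost[1] < kept[-1][1]:        # communication field decreasing
            kept.append(cost)
    return kept

# The worked example from the comment above: <300, 50> has both a larger
# computation cost and a larger forward communication cost, so it is dropped.
print(simplify_for_decreasing_comm([(100, 20), (200, 10), (300, 50)]))
# -> [(100, 20), (200, 10)], i.e. the surviving set {<100, 20>, <200, 10>}
```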
| @@ -51,18 +51,22 @@ struct Cost { | |||
| communication_with_partial_para_ = 0.0; | |||
| communication_redis_forward_ = 0.0; | |||
| communication_redis_backward_ = 0.0; | |||
| communication_forward_ = 0.0; | |||
| } | |||
| // 'memory_with_reuse_' calculates the peak memory usage in a training phase | |||
| double memory_with_reuse_; | |||
| // 'computation_cost_' models the training time of an iteration in a training phase | |||
| // 'computation_cost_' models the training time of an iteration in a training phase. Currently, this is calculated | |||
| // by ONLY forward phase | |||
| double computation_cost_; | |||
| // 'communication_cost_' includes communications from operators (forward and backward) and edges | |||
| // 'communication_cost_' includes communications from operators (forward and backward) and edges (redistribution) | |||
| double communication_cost_; | |||
| // communication_without_parameter_ = communication_cost_ - (backward communication from operators) | |||
| double communication_without_parameter_; | |||
| // communication_with_partial_para_ = | |||
| // communication_without_parameter_ + COST_MODEL_GAMMA * (communication_cost_ - communication_without_parameter_ ) | |||
| double communication_with_partial_para_; | |||
| // communication_forward_ = communication cost from operators (only forward phase) and forward redistribution. | |||
| double communication_forward_; | |||
| double communication_redis_forward_; | |||
| double communication_redis_backward_; | |||
| std::shared_ptr<Decision> decision_ptr_; | |||
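As this hunk and the operator/edge hunks further down show, the new `communication_forward_` field is populated as `communication_without_parameter_` for operator costs, as `communication_redis_forward_` for edge (redistribution) costs, and is simply summed in every elimination step. A small numeric sketch with made-up values and an assumed gamma:

```python
COST_MODEL_GAMMA = 0.1  # assumed value for illustration; configurable via costmodel_gamma

# Hypothetical operator cost
communication_cost = 120.0               # forward + backward operator communication
communication_without_parameter = 80.0   # communication_cost_ minus backward comm from operators
communication_with_partial_para = (
    communication_without_parameter
    + COST_MODEL_GAMMA * (communication_cost - communication_without_parameter))
print(communication_with_partial_para)   # 84.0, the quantity the training phase prunes on

# Forward-only traffic, the quantity the inference phase prunes on:
op_communication_forward = communication_without_parameter   # operators: forward part only
edge_communication_forward = 30.0                             # edges: communication_redis_forward_
print(op_communication_forward + edge_communication_forward)  # 110.0
```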
| @@ -296,7 +300,8 @@ using FinalDecisionPtr = std::shared_ptr<FinalDecision>; | |||
| using FinalSingleDecisionPtr = std::shared_ptr<FinalSingleDecision>; | |||
| void Simplify(CostPtrList *clist); | |||
| void SimplifyForDreasingCommunicationWithPartialPara(CostPtrList *clist); | |||
| void SimplifyForDecreasingCommunicationForward(CostPtrList *clist); | |||
| void SimplifyForDecreasingCommunicationWithPartialPara(CostPtrList *clist); | |||
| void RefineForPracticalCost(const CostPtr &, bool is_redistribution); | |||
| } // namespace parallel | |||
| } // namespace mindspore | |||
| @@ -76,6 +76,7 @@ Status Edge::InitEdgeCost() { | |||
| << ", communication_with_partial_para_: " << cost->communication_with_partial_para_ << "."; | |||
| // refine communication cost calculation for practice | |||
| RefineForPracticalCost(cost, true); | |||
| cost->communication_forward_ = cost->communication_redis_forward_; | |||
| CostPtrKey ck = {target_output_str, target_input_str}; | |||
| CostPtrList cl; | |||
| cl.push_back(cost); | |||
| @@ -160,8 +161,9 @@ CostPtrList Edge::CreateEdgeEliminationCostList(const StrategyPtr &output_st_ptr | |||
| (void)std::transform(edges.begin(), edges.end(), all_cost_list.begin(), LocalGetCostList); | |||
| CostPtrList selected_cost_list(all_cost_list.size(), nullptr); | |||
| std::function<void(size_t, double, double, double, double)> recursive = | |||
| [&](size_t k, double computation, double memory, double communication, double communication_without_para) { | |||
| std::function<void(size_t, double, double, double, double, double)> recursive = | |||
| [&](size_t k, double computation, double memory, double communication, double communication_without_para, | |||
| double communication_forward) { | |||
| if (k == edges.size()) { | |||
| auto decision = std::make_shared<EdgeEliminationDecision>(selected_cost_list); | |||
| CostPtr new_cost = std::make_shared<Cost>(computation, communication); | |||
| @@ -170,6 +172,7 @@ CostPtrList Edge::CreateEdgeEliminationCostList(const StrategyPtr &output_st_ptr | |||
| new_cost->communication_with_partial_para_ = | |||
| communication_without_para + COST_MODEL_GAMMA * (communication - communication_without_para); | |||
| new_cost->memory_with_reuse_ = memory; | |||
| new_cost->communication_forward_ = communication_forward; | |||
| new_cost->decision_ptr_ = decision; | |||
| result.push_back(new_cost); | |||
| return; | |||
| @@ -179,11 +182,12 @@ CostPtrList Edge::CreateEdgeEliminationCostList(const StrategyPtr &output_st_ptr | |||
| selected_cost_list[k] = c; | |||
| recursive(k + 1, computation + c->computation_cost_, memory + c->memory_with_reuse_, | |||
| communication + c->communication_cost_, | |||
| communication_without_para + c->communication_without_parameter_); | |||
| communication_without_para + c->communication_without_parameter_, | |||
| communication_forward + c->communication_forward_); | |||
| } | |||
| }; | |||
| recursive(0, 0.0, 0.0, 0.0, 0.0); | |||
| SimplifyForDreasingCommunicationWithPartialPara(&result); | |||
| recursive(0, 0.0, 0.0, 0.0, 0.0, 0.0); | |||
| Simplify(&result); | |||
| return result; | |||
| } | |||
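The recursive lambda above now threads a sixth accumulator, `communication_forward`, through the enumeration. Conceptually it walks the Cartesian product of the parallel edges' cost lists and sums each field of the chosen costs; a tuple-based Python sketch of that idea (illustrative only, not the real `CostPtr` interface):

```python
from itertools import product

def edge_elimination_cost_list(per_edge_cost_lists):
    """per_edge_cost_lists: one cost list per parallel edge, where each cost is a tuple
    (computation, memory, communication, communication_without_para, communication_forward).
    Returns one combined candidate per way of picking a single cost from every edge,
    with every field summed."""
    combined = []
    for picks in product(*per_edge_cost_lists):
        combined.append(tuple(sum(field) for field in zip(*picks)))
    return combined

# Two parallel edges, two candidate costs on the first edge -> two combined
# candidates, each field (including the new forward communication) summed.
print(edge_elimination_cost_list([
    [(1.0, 4.0, 2.0, 1.0, 0.5), (2.0, 4.0, 1.0, 0.5, 0.2)],
    [(3.0, 8.0, 0.0, 0.0, 0.0)],
]))
```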
| @@ -219,6 +223,8 @@ void Edge::CreateOpEliminationSubCostList(StrategyPtr op_strategy, const CostPtr | |||
| left_cost->computation_cost_ + middle_cost->computation_cost_ + right_cost->computation_cost_; | |||
| double communication = | |||
| left_cost->communication_cost_ + middle_cost->communication_cost_ + right_cost->communication_cost_; | |||
| double communication_forward = | |||
| left_cost->communication_forward_ + middle_cost->communication_forward_ + right_cost->communication_forward_; | |||
| double communication_without_para = left_cost->communication_without_parameter_ + | |||
| middle_cost->communication_without_parameter_ + | |||
| right_cost->communication_without_parameter_; | |||
| @@ -232,6 +238,7 @@ void Edge::CreateOpEliminationSubCostList(StrategyPtr op_strategy, const CostPtr | |||
| cost->communication_with_partial_para_ = | |||
| communication_without_para + COST_MODEL_GAMMA * (communication - communication_without_para); | |||
| cost->memory_with_reuse_ = memory_cost; | |||
| cost->communication_forward_ = communication_forward; | |||
| ret_cost_list->emplace_back(std::move(cost)); | |||
| } | |||
| } | |||
| @@ -251,7 +258,7 @@ CostPtrList Edge::CreateOpEliminationCostList(const EdgePtr &e1, const StrategyP | |||
| CreateOpEliminationSubCostList(middle_strategy, e1->GetCostList(output_st_ptr, middle_strategy), | |||
| op_strategy->cost_list, e2->GetCostList(middle_strategy, input_st_ptr), &result); | |||
| } | |||
| SimplifyForDreasingCommunicationWithPartialPara(&result); | |||
| Simplify(&result); | |||
| return result; | |||
| } | |||
| @@ -38,6 +38,8 @@ bool TENSOR_SLICE_ALIGNMENT_ENABLE = DEFAULT_TENSOR_SLICE_ALIGNMENT_ENABLE; | |||
| size_t TENSOR_SLICE_ALIGNMENT_SIZE = DEFAULT_TENSOR_SLICE_ALIGNMENT_SIZE; | |||
| bool FULLY_USE_DEVICES = DEFAULT_FULLY_USE_DEVICES; | |||
| bool ELEMENTWISE_OP_STRA_FOLLOW = DEFAULT_ELEMENTWISE_OP_STRA_FOLLOW; | |||
| bool MULTI_SUBGRAPHS = DEFAULT_IS_MULTI_SUBGRAPHS; | |||
| int32_t RUN_PHASE = DEFAULT_RUN_PHASE; | |||
| void CostGraph::SetDeviceMemoryAndCostParameter() { | |||
| MS_EXCEPTION_IF_NULL(CostModelContext::GetInstance()); | |||
| @@ -142,6 +144,23 @@ void CostGraph::SetDeviceMemoryAndCostParameter() { | |||
| } else { | |||
| MS_LOG(INFO) << "elementwise_op_strategy_follow: false."; | |||
| } | |||
| // MULTI_SUBGRAPHS | |||
| auto multi_subgraphs = CostModelContext::GetInstance()->is_multi_subgraphs(); | |||
| MULTI_SUBGRAPHS = multi_subgraphs; | |||
| if (MULTI_SUBGRAPHS) { | |||
| MS_LOG(INFO) << "multi_subgraphs: true."; | |||
| } else { | |||
| MS_LOG(INFO) << "multi_subgraphs: false."; | |||
| } | |||
| // RUN_PHASE | |||
| auto phase = CostModelContext::GetInstance()->run_phase(); | |||
| if (phase != 0 && phase != 1) { | |||
| MS_LOG(EXCEPTION) << "'run_phase' must be in {0, 1}"; | |||
| } | |||
| RUN_PHASE = phase; | |||
| MS_LOG(INFO) << "run_phase: " << RUN_PHASE << "."; | |||
| } | |||
| void CostGraph::RemoveOperator(const OperatorInfoPtr &op) { | |||
| @@ -249,19 +268,21 @@ CostPtrList CostGraph::CreateFinalCostList(const OperatorInfoPtr &u, const std:: | |||
| MS_EXCEPTION_IF_NULL(cost3); | |||
| double computation = cost1->computation_cost_ + cost2->computation_cost_ + cost3->computation_cost_; | |||
| double memory = cost1->memory_with_reuse_ + cost2->memory_with_reuse_ + cost3->memory_with_reuse_; | |||
| double commmunication = | |||
| cost1->communication_cost_ + cost2->communication_cost_ + cost3->communication_cost_; | |||
| double communication = cost1->communication_cost_ + cost2->communication_cost_ + cost3->communication_cost_; | |||
| double communication_forward = | |||
| cost1->communication_forward_ + cost2->communication_forward_ + cost3->communication_forward_; | |||
| double communication_without_para = cost1->communication_without_parameter_ + | |||
| cost2->communication_without_parameter_ + | |||
| cost3->communication_without_parameter_; | |||
| auto decision = | |||
| std::make_shared<FinalDecision>(u_strategy->strategy_ptr, v_strategy->strategy_ptr, cost1, cost2, cost3); | |||
| auto cost = std::make_shared<Cost>(computation, commmunication, decision); | |||
| auto cost = std::make_shared<Cost>(computation, communication, decision); | |||
| MS_EXCEPTION_IF_NULL(cost); | |||
| cost->communication_without_parameter_ = communication_without_para; | |||
| cost->communication_with_partial_para_ = | |||
| communication_without_para + COST_MODEL_GAMMA * (commmunication - communication_without_para); | |||
| communication_without_para + COST_MODEL_GAMMA * (communication - communication_without_para); | |||
| cost->memory_with_reuse_ = memory; | |||
| cost->communication_forward_ = communication_forward; | |||
| ret.push_back(cost); | |||
| } | |||
| } | |||
| @@ -269,7 +290,7 @@ CostPtrList CostGraph::CreateFinalCostList(const OperatorInfoPtr &u, const std:: | |||
| } | |||
| } | |||
| SimplifyForDreasingCommunicationWithPartialPara(&ret); | |||
| Simplify(&ret); | |||
| return ret; | |||
| } | |||
| @@ -291,32 +312,67 @@ CostPtrList CostGraph::CreateFinalSingleCostList(const OperatorInfoPtr &u) { | |||
| cost1->communication_without_parameter_ + | |||
| COST_MODEL_GAMMA * (cost1->communication_cost_ - cost1->communication_without_parameter_); | |||
| new_cost->memory_with_reuse_ = cost1->memory_with_reuse_; | |||
| new_cost->communication_forward_ = cost1->communication_forward_; | |||
| ret.push_back(new_cost); | |||
| } | |||
| } | |||
| SimplifyForDreasingCommunicationWithPartialPara(&ret); | |||
| Simplify(&ret); | |||
| return ret; | |||
| } | |||
| CostPtr CostGraph::SelectCostWithMemoryConstraint(const CostPtrList &cost_list, double memory) { | |||
| CostPtr CostGraph::SelectCostWithMinInferenceTime(const CostPtrList &cost_list, double memory) { | |||
| // Select the cost with minimum inference time. Currently, the inference time is modeled as = | |||
| // costmodel_alpha_ * computation_cost + costmodel_beta_ * communication_forward_ | |||
| if (cost_list.empty()) { | |||
| MS_LOG(ERROR) << "Final cost list is null."; | |||
| return nullptr; | |||
| } | |||
| CostPtrList after_mem_filter; | |||
| // Filter out the valid costs | |||
| double minimum_memory = DBL_MAX; | |||
| // Filter out the valid costs. | |||
| for (auto &a_cost : cost_list) { | |||
| if (a_cost->memory_with_reuse_ <= memory) { | |||
| after_mem_filter.emplace_back(std::move(a_cost)); | |||
| } else if (a_cost->memory_with_reuse_ < minimum_memory) { | |||
| minimum_memory = a_cost->memory_with_reuse_; | |||
| } | |||
| } | |||
| if (after_mem_filter.empty()) { | |||
| MS_LOG(ERROR) << "No available cost. The minimum memory cost is: " << minimum_memory | |||
| << ", the memory capacity is: " << memory << "."; | |||
| return nullptr; | |||
| } | |||
| // Init the returned value with first cost. | |||
| CostPtr ret = after_mem_filter[0]; | |||
| std::function<CostPtr(CostPtr, const CostPtr &)> LocalCompare = [&](CostPtr init, const CostPtr &cost_x) { | |||
| MS_EXCEPTION_IF_NULL(cost_x); | |||
| if (init == nullptr || cost_x->computation_cost_ < memory) { | |||
| init = cost_x; | |||
| double minimum = costmodel_alpha_ * ret->computation_cost_ + costmodel_beta_ * ret->communication_forward_; | |||
| MS_LOG(INFO) << "Cost 0: " | |||
| << "memory_cost: " << ret->memory_with_reuse_ << ", computation_cost_: " << ret->computation_cost_ | |||
| << ", communication_forward_: " << ret->communication_forward_ | |||
| << ", communication_with_partial_para_: " << ret->communication_with_partial_para_ | |||
| << ", communication_cost_: " << ret->communication_cost_ | |||
| << ", communication_without_parameter_: " << ret->communication_without_parameter_ << "."; | |||
| MS_LOG(INFO) << "Cost 0: totoal_cost: " << minimum; | |||
| for (size_t i = 1; i < after_mem_filter.size(); ++i) { | |||
| MS_EXCEPTION_IF_NULL(after_mem_filter[i]); | |||
| MS_LOG(INFO) << "Cost " << i << ": memory_cost: " << after_mem_filter[i]->memory_with_reuse_ | |||
| << ", computation_cost_: " << after_mem_filter[i]->computation_cost_ | |||
| << ", communication_forward_: " << after_mem_filter[i]->communication_forward_ | |||
| << ", communication_with_partial_para_: " << after_mem_filter[i]->communication_with_partial_para_ | |||
| << ", communication_cost_: " << after_mem_filter[i]->communication_cost_ | |||
| << ", communication_without_parameter_: " << after_mem_filter[i]->communication_without_parameter_ | |||
| << "."; | |||
| auto tmp = costmodel_alpha_ * after_mem_filter[i]->computation_cost_ + | |||
| costmodel_beta_ * after_mem_filter[i]->communication_forward_; | |||
| MS_LOG(INFO) << "Cost " << i << ": total_cost: " << tmp; | |||
| if (minimum > tmp) { | |||
| minimum = tmp; | |||
| ret = after_mem_filter[i]; | |||
| MS_LOG(INFO) << "Selected: " << i; | |||
| } | |||
| return init; | |||
| }; | |||
| CostPtr ret = nullptr; | |||
| return std::accumulate(after_mem_filter.begin(), after_mem_filter.end(), ret, LocalCompare); | |||
| } | |||
| return ret; | |||
| } | |||
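Summarizing the selection added above: costs whose peak memory exceeds the device capacity are filtered out, and among the survivors the one minimizing `costmodel_alpha_ * computation_cost_ + costmodel_beta_ * communication_forward_` wins. A compact sketch with hypothetical tuples (not the real `CostPtr` API):

```python
def select_min_inference_time(cost_list, memory, alpha, beta):
    """cost_list: list of (memory_with_reuse, computation_cost, communication_forward)."""
    feasible = [c for c in cost_list if c[0] <= memory]
    if not feasible:
        return None  # mirrors the "No available cost" error path above
    # inference-time model: alpha * computation_cost + beta * communication_forward
    return min(feasible, key=lambda c: alpha * c[1] + beta * c[2])

# Example: the second candidate wins despite higher computation, because its
# forward communication is much smaller; the third is filtered out on memory.
print(select_min_inference_time(
    [(8.0, 100.0, 50.0), (8.0, 120.0, 5.0), (32.0, 10.0, 1.0)],
    memory=16.0, alpha=1.0, beta=2.0))
# -> (8.0, 120.0, 5.0)
```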
| CostPtr CostGraph::SelectCostWithMinTrainingTime(const CostPtrList &cost_list, double memory) { | |||
| @@ -524,12 +580,26 @@ Status CostGraph::SearchStrategy() { | |||
| }); | |||
| if (alive_ops.size() > 2) { | |||
| return SearchStrategyForMultiNodeFinalGraph(alive_ops); | |||
| if (RUN_PHASE == TRAINING_PHASE) { | |||
| // training phase | |||
| return SearchStrategyForMultiNodeFinalGraph(alive_ops); | |||
| } else { | |||
| // inference phase | |||
| MS_LOG(EXCEPTION) | |||
| << "Currently, searching strategy for the multi-node final graph in inference phase is not supported."; | |||
| } | |||
| } else if (alive_ops.size() == 1) { | |||
| MS_LOG(INFO) << "There are 1 single node in the final graph."; | |||
| OperatorInfoPtr u = alive_ops[0]; | |||
| auto cost_list = CreateFinalSingleCostList(u); | |||
| auto cost = SelectCostWithMinTrainingTime(cost_list, dev_memory_); | |||
| CostPtr cost = nullptr; | |||
| if (RUN_PHASE == TRAINING_PHASE) { | |||
| // training phase | |||
| cost = SelectCostWithMinTrainingTime(cost_list, dev_memory_); | |||
| } else { | |||
| // inference phase | |||
| cost = SelectCostWithMinInferenceTime(cost_list, dev_memory_); | |||
| } | |||
| if (cost == nullptr) { | |||
| MS_LOG(ERROR) << "No vaild strategy can be found under the current device memory: " << dev_memory_ << "."; | |||
| return FAILED; | |||
| @@ -575,7 +645,15 @@ Status CostGraph::SearchStrategy() { | |||
| auto cost_list = one_component->CreateFinalSingleCostList(one_component->GetOperators()[0]); | |||
| all_list.push_back(cost_list); | |||
| } | |||
| auto selected_cost_list = SelectCostListWithMinTrainingTimeMultiple(all_list, dev_memory_); | |||
| CostPtrList selected_cost_list; | |||
| if (RUN_PHASE == TRAINING_PHASE) { | |||
| // training phase | |||
| selected_cost_list = SelectCostListWithMinTrainingTimeMultiple(all_list, dev_memory_); | |||
| } else { | |||
| // inference phase | |||
| MS_LOG(EXCEPTION) << "Currently, searching strategy for the two-separated-node final graph in the inference " | |||
| "phase is not supported."; | |||
| } | |||
| for (size_t k = 0; k < selected_cost_list.size(); ++k) { | |||
| auto selected_cost = selected_cost_list[k]; | |||
| if (selected_cost == nullptr) { | |||
| @@ -601,7 +679,14 @@ Status CostGraph::SearchStrategy() { | |||
| auto e = u->GetAliveSuccEdges()[0]; | |||
| MS_EXCEPTION_IF_NULL(e); | |||
| auto cost_list = CreateFinalCostList(u, e, v); | |||
| auto cost = SelectCostWithMinTrainingTime(cost_list, dev_memory_); | |||
| CostPtr cost = nullptr; | |||
| if (RUN_PHASE == TRAINING_PHASE) { | |||
| // training phase | |||
| cost = SelectCostWithMinTrainingTime(cost_list, dev_memory_); | |||
| } else { | |||
| MS_LOG(EXCEPTION) << "Currently, searching strategy for the two-connected-node final graph in the inference " | |||
| "phase is not supported."; | |||
| } | |||
| if (cost == nullptr) { | |||
| MS_LOG(ERROR) << "No vaild strategy can be found under the current device memory: " << dev_memory_ << "."; | |||
| return FAILED; | |||
| @@ -841,6 +926,8 @@ void CostGraph::CreateMergeEliminationSubCostList(StrategyPtr op_strategy, const | |||
| double memory = op_cost->memory_with_reuse_ + edge_cost->memory_with_reuse_ + tar_cost->memory_with_reuse_; | |||
| double communication = | |||
| op_cost->communication_cost_ + edge_cost->communication_cost_ + tar_cost->communication_cost_; | |||
| double communication_forward = | |||
| op_cost->communication_forward_ + edge_cost->communication_forward_ + tar_cost->communication_forward_; | |||
| double communication_without_para = op_cost->communication_without_parameter_ + | |||
| edge_cost->communication_without_parameter_ + | |||
| tar_cost->communication_without_parameter_; | |||
| @@ -853,6 +940,7 @@ void CostGraph::CreateMergeEliminationSubCostList(StrategyPtr op_strategy, const | |||
| new_cost->communication_with_partial_para_ = | |||
| communication_without_para + COST_MODEL_GAMMA * (communication - communication_without_para); | |||
| new_cost->memory_with_reuse_ = memory; | |||
| new_cost->communication_forward_ = communication_forward; | |||
| MS_EXCEPTION_IF_NULL(tar_cost_list_new); | |||
| tar_cost_list_new->emplace_back(std::move(new_cost)); | |||
| } | |||
| @@ -885,7 +973,7 @@ OperatorInfoPtr CostGraph::EliminationMerge(const OperatorInfoPtr &op) { | |||
| CreateMergeEliminationSubCostList(op_stra, op_clist, edge_clist, tar_stra, tar_clist_origin, &tar_clist_new); | |||
| } | |||
| SimplifyForDreasingCommunicationWithPartialPara(&tar_clist_new); | |||
| Simplify(&tar_clist_new); | |||
| // Set the new costlist w.r.t the strategy | |||
| tar_stra_cost->cost_list = tar_clist_new; | |||
| if ((!valid) && (!tar_clist_new.empty())) { | |||
| @@ -922,6 +1010,8 @@ void CostGraph::CreateContractEliminationSubCostList(StrategyPtr contract_op_str | |||
| contract_op_cost->memory_with_reuse_ + edge_cost->memory_with_reuse_ + tar_cost->memory_with_reuse_; | |||
| double communication = | |||
| contract_op_cost->communication_cost_ + edge_cost->communication_cost_ + tar_cost->communication_cost_; | |||
| double communication_forward = contract_op_cost->communication_forward_ + edge_cost->communication_forward_ + | |||
| tar_cost->communication_forward_; | |||
| double communication_without_para = contract_op_cost->communication_without_parameter_ + | |||
| edge_cost->communication_without_parameter_ + | |||
| tar_cost->communication_without_parameter_; | |||
| @@ -933,6 +1023,7 @@ void CostGraph::CreateContractEliminationSubCostList(StrategyPtr contract_op_str | |||
| new_cost->communication_with_partial_para_ = | |||
| communication_without_para + COST_MODEL_GAMMA * (communication - communication_without_para); | |||
| new_cost->memory_with_reuse_ = memory; | |||
| new_cost->communication_forward_ = communication_forward; | |||
| tar_cost_list_new->emplace_back(std::move(new_cost)); | |||
| } | |||
| } | |||
| @@ -962,7 +1053,7 @@ OperatorInfoPtr CostGraph::EliminationContract(const OperatorInfoPtr &op) { | |||
| CreateContractEliminationSubCostList(op_stra, op_clist, edge_clist, tar_stra, tar_clist_origin, &tar_clist_new); | |||
| } | |||
| SimplifyForDreasingCommunicationWithPartialPara(&tar_clist_new); | |||
| Simplify(&tar_clist_new); | |||
| // Set the new costlist w.r.t the strategy | |||
| tar_stra_cost->cost_list = tar_clist_new; | |||
| if ((!valid) && (!tar_clist_new.empty())) { | |||
| @@ -998,6 +1089,8 @@ void CostGraph::CreateTriangleEliminationSubCostList(StrategyPtr elimi_op_stra, | |||
| left_node_cost->memory_with_reuse_ + right_edge_cost->memory_with_reuse_; | |||
| double new_commu_cost = elimi_op_cost->communication_cost_ + left_edge_cost->communication_cost_ + | |||
| left_node_cost->communication_cost_ + right_edge_cost->communication_cost_; | |||
| double new_commu_forward = elimi_op_cost->communication_forward_ + left_edge_cost->communication_forward_ + | |||
| left_node_cost->communication_forward_ + right_edge_cost->communication_forward_; | |||
| double new_commu_without = | |||
| elimi_op_cost->communication_without_parameter_ + left_edge_cost->communication_without_parameter_ + | |||
| left_node_cost->communication_without_parameter_ + right_edge_cost->communication_without_parameter_; | |||
| @@ -1009,6 +1102,7 @@ void CostGraph::CreateTriangleEliminationSubCostList(StrategyPtr elimi_op_stra, | |||
| new_cost->communication_with_partial_para_ = | |||
| new_commu_without + COST_MODEL_GAMMA * (new_commu_cost - new_commu_without); | |||
| new_cost->memory_with_reuse_ = new_memory; | |||
| new_cost->communication_forward_ = new_commu_forward; | |||
| left_node_clist_new->emplace_back(std::move(new_cost)); | |||
| } | |||
| } | |||
| @@ -1079,7 +1173,7 @@ OperatorInfoPtr CostGraph::EliminationTriangle(const OperatorInfoPtr &elimi_op, | |||
| &left_node_clist_new); | |||
| } | |||
| } | |||
| SimplifyForDreasingCommunicationWithPartialPara(&left_node_clist_new); | |||
| Simplify(&left_node_clist_new); | |||
| // Set the new costlist w.r.t the strategy | |||
| left_node_stra_cost->cost_list = left_node_clist_new; | |||
| if ((!valid) && (!left_node_clist_new.empty())) { | |||
| @@ -1112,19 +1206,22 @@ void CostGraph::CreateStarEliminationSubCostList(const StrategyPtr &first_succ_n | |||
| double computation_cost = merged_node_cost->computation_cost_, | |||
| memory_cost = merged_node_cost->memory_with_reuse_, commu_cost = merged_node_cost->communication_cost_, | |||
| commu_without = merged_node_cost->communication_without_parameter_; | |||
| commu_without = merged_node_cost->communication_without_parameter_, | |||
| commu_forward = merged_node_cost->communication_forward_; | |||
| for (size_t i = 0; i < succ_nodes_stras.size(); ++i) { | |||
| MS_EXCEPTION_IF_NULL(succ_edges_costs[i]); | |||
| if (i == 0) { | |||
| computation_cost += succ_edges_costs[i]->computation_cost_ + succ_nodes_costs[i]->computation_cost_; | |||
| memory_cost += succ_edges_costs[i]->memory_with_reuse_ + succ_nodes_costs[i]->memory_with_reuse_; | |||
| commu_cost += succ_edges_costs[i]->communication_cost_ + succ_nodes_costs[i]->communication_cost_; | |||
| commu_forward += succ_edges_costs[i]->communication_forward_ + succ_nodes_costs[i]->communication_forward_; | |||
| commu_without += succ_edges_costs[i]->communication_without_parameter_ + | |||
| succ_nodes_costs[i]->communication_without_parameter_; | |||
| } else { | |||
| computation_cost += succ_edges_costs[i]->computation_cost_; | |||
| memory_cost += succ_edges_costs[i]->memory_with_reuse_; | |||
| commu_cost += succ_edges_costs[i]->communication_cost_; | |||
| commu_forward += succ_edges_costs[i]->communication_forward_; | |||
| commu_without += succ_edges_costs[i]->communication_without_parameter_; | |||
| } | |||
| } | |||
| @@ -1135,6 +1232,7 @@ void CostGraph::CreateStarEliminationSubCostList(const StrategyPtr &first_succ_n | |||
| new_cost->communication_without_parameter_ = commu_without; | |||
| new_cost->communication_with_partial_para_ = commu_without + COST_MODEL_GAMMA * (commu_cost - commu_without); | |||
| new_cost->memory_with_reuse_ = memory_cost; | |||
| new_cost->communication_forward_ = commu_forward; | |||
| first_succ_node_clist_new->emplace_back(std::move(new_cost)); | |||
| } | |||
| } | |||
| @@ -1220,7 +1318,7 @@ std::vector<std::shared_ptr<Edge>> CostGraph::EliminationStar(const OperatorInfo | |||
| CreateStarEliminationCostList(succ_edges, first_succ_node_stra, first_succ_node_clist, first_succ_edge_clist, | |||
| merged_op_stra, merged_op_clist, &first_succ_node_clist_new); | |||
| } | |||
| SimplifyForDreasingCommunicationWithPartialPara(&first_succ_node_clist_new); | |||
| Simplify(&first_succ_node_clist_new); | |||
| // Set the new costlist w.r.t the strategy | |||
| first_succ_node_stra_cost->cost_list = first_succ_node_clist_new; | |||
| if ((!valid) && (!first_succ_node_clist_new.empty())) { | |||
| @@ -45,6 +45,9 @@ namespace parallel { | |||
| #define DEFAULT_FULLY_USE_DEVICES true | |||
| #define DEFAULT_ELEMENTWISE_OP_STRA_FOLLOW false | |||
| #define DEFAULT_IS_MULTI_SUBGRAPHS false | |||
| #define DEFAULT_RUN_PHASE 0 | |||
| #define TRAINING_PHASE 0 | |||
| #define INFERENCE_PHASE 1 | |||
| class CostGraph; | |||
| using CostGraphPtr = std::shared_ptr<CostGraph>; | |||
| @@ -60,6 +63,8 @@ extern bool TENSOR_SLICE_ALIGNMENT_ENABLE; | |||
| extern size_t TENSOR_SLICE_ALIGNMENT_SIZE; | |||
| extern bool FULLY_USE_DEVICES; | |||
| extern bool ELEMENTWISE_OP_STRA_FOLLOW; | |||
| extern bool MULTI_SUBGRAPHS; | |||
| extern int32_t RUN_PHASE; | |||
| class CostGraph { | |||
| // 'CostGraph' consists of Operators and edges between them. An edge is created between two Operators if they have | |||
| @@ -98,7 +103,7 @@ class CostGraph { | |||
| CostPtrList CreateFinalCostList(const OperatorInfoPtr &u, const EdgePtr &e, const OperatorInfoPtr &v); | |||
| CostPtrList CreateFinalSingleCostList(const OperatorInfoPtr &u); | |||
| CostPtr SelectCostWithMemoryConstraint(const CostPtrList &cost_list, double memory); | |||
| CostPtr SelectCostWithMinInferenceTime(const CostPtrList &cost_list, double memory); | |||
| CostPtr SelectCostWithMinTrainingTime(const CostPtrList &cost_list, double memory); | |||
| CostPtrList SelectCostListWithMinTrainingTimeMultiple(const std::vector<CostPtrList> &all_costlist, double memory); | |||
| Status SearchStrategyForMultiNodeFinalGraph(const std::vector<OperatorInfoPtr> &); | |||
| @@ -47,6 +47,7 @@ void CostModelContext::ResetCostModel() { | |||
| costmodel_communi_const_ = DEFAULT_COST_MODEL_COMMUNI_CONST; | |||
| costmodel_communi_bias_ = DEFAULT_COST_MODEL_COMMUNI_BIAS; | |||
| is_multi_subgraphs_ = DEFAULT_IS_MULTI_SUBGRAPHS; | |||
| run_phase_ = DEFAULT_RUN_PHASE; | |||
| costmodel_allreduce_fusion_algorithm_ = DEFAULT_COST_MODEL_ALLREDUCE_FUSION_ALGORITHM; | |||
| costmodel_allreduce_fusion_times_ = DEFAULT_COST_MODEL_ALLREDUCE_FUSION_TIMES; | |||
| costmodel_allreduce_fusion_tail_percent_ = DEFAULT_COST_MODEL_ALLREDUCE_FUSION_TAIL_PERCENT; | |||
| @@ -125,5 +126,7 @@ void CostModelContext::set_fully_use_device(bool fully_use) { fully_use_device_ | |||
| void CostModelContext::set_elementwise_stra_follow(bool elementwise_follow) { | |||
| elementwise_stra_follow_ = elementwise_follow; | |||
| } | |||
| void CostModelContext::set_run_phase(int32_t phase) { run_phase_ = phase; } | |||
| } // namespace parallel | |||
| } // namespace mindspore | |||
| @@ -113,6 +113,9 @@ class CostModelContext { | |||
| void set_elementwise_stra_follow(bool); | |||
| bool elementwise_stra_follow() const { return elementwise_stra_follow_; } | |||
| void set_run_phase(int32_t); | |||
| int32_t run_phase() const { return run_phase_; } | |||
| private: | |||
| CostModelContext(); | |||
| static std::shared_ptr<CostModelContext> cm_context_inst_; | |||
| @@ -141,8 +144,11 @@ class CostModelContext { | |||
| // COST_MODEL_COMMUNI_BIAS | |||
| double costmodel_communi_bias_; | |||
| // MULTI_SUBGRAPHS | |||
| bool is_multi_subgraphs_; | |||
| int32_t run_phase_; // 0: 'training', 1: 'inference' | |||
| int32_t costmodel_allreduce_fusion_algorithm_; | |||
| int32_t costmodel_allreduce_fusion_times_; | |||
| @@ -610,6 +610,7 @@ Status MatMulBase::SetCostUnderStrategy(const mindspore::parallel::StrategyPtr & | |||
| << ", communication_with_partial_para_: " << result->communication_with_partial_para_; | |||
| // refine communication cost calculation for practice | |||
| RefineForPracticalCost(result, false); | |||
| result->communication_forward_ = result->communication_without_parameter_; | |||
| std::shared_ptr<StrategyWithCost> swc = | |||
| std::make_shared<StrategyWithCost>(strategy, inputs_tensor_info_, outputs_tensor_info_); | |||
| @@ -1049,6 +1049,7 @@ Status OperatorInfo::SetCostUnderStrategyBase(const StrategyPtr &strategy) { | |||
| BreakingTiesForPerferringDataParallel(strategy, result); | |||
| // refine communication cost calculation for practice | |||
| RefineForPracticalCost(result, false); | |||
| result->communication_forward_ = result->communication_without_parameter_; | |||
| std::shared_ptr<StrategyWithCost> swc = | |||
| std::make_shared<StrategyWithCost>(strategy, inputs_tensor_info_, outputs_tensor_info_); | |||
| @@ -69,16 +69,16 @@ class TensorRedistribution { | |||
| RankList dev_list_; | |||
| OperatorList operator_list_; | |||
| bool reshape_flag_; | |||
| // communication cost | |||
| // communication cost, which is the sum of forward communication cost and backward communication cost | |||
| double comm_cost_; | |||
| // forward communication cost | |||
| double forward_comm_cost_; | |||
| // backward communication cost | |||
| double backward_comm_cost_; | |||
| // computation_cost models the time spending on computing in this tensor redistribution, which is calculated by the | |||
| // inputs. | |||
| // inputs. This is calculated ONLY for forward phase. | |||
| double computation_cost_; | |||
| // memory_cost models the PEAK memory cost in a traning iteration contributed by this tensor redistribution, which is | |||
| // memory_cost models the PEAK memory cost in a training iteration contributed by this tensor redistribution, which is | |||
| // calculated by the outputs. | |||
| double memory_cost_; | |||
| bool construct_op_flag_; | |||
| @@ -228,6 +228,8 @@ PYBIND11_MODULE(_c_expression, m) { | |||
| "Get the parameter cost_model_communi_bias of the DP algorithm.") | |||
| .def("set_multi_subgraphs", &CostModelContext::set_multi_subgraphs, "Set the parameter is_multi_subgraphs.") | |||
| .def("get_multi_subgraphs", &CostModelContext::is_multi_subgraphs, "Get the parameter is_multi_subgraphs.") | |||
| .def("set_run_phase", &CostModelContext::set_run_phase, "Set the flag run_phase.") | |||
| .def("get_run_phase", &CostModelContext::run_phase, "Get the flag run_phase.") | |||
| .def("set_costmodel_allreduce_fusion_algorithm", &CostModelContext::set_costmodel_allreduce_fusion_algorithm, | |||
| "Set the parameter gradient AllReduce fusion algorithm.") | |||
| .def("get_costmodel_allreduce_fusion_algorithm", &CostModelContext::costmodel_allreduce_fusion_algorithm, | |||
| @@ -239,6 +239,33 @@ class _CostModelContext: | |||
| raise ValueError("Context handle is none in context!!!") | |||
| return self._context_handle.get_multi_subgraphs() | |||
| def set_run_phase(self, phase): | |||
| """ | |||
| Set the flag of running phase: training (0) or inference (1) | |||
| Args: | |||
| phase (int): A parameter indicating which phase is running. | |||
| Raises: | |||
| ValueError: If context handle is none, or phase is not in {0, 1}. | |||
| """ | |||
| if self._context_handle is None: | |||
| raise ValueError("Context handle is none in context!!!") | |||
| if phase not in (0, 1): | |||
| raise ValueError("The argument of set_run_phase() must be '0' or '1', but got {}".format(phase)) | |||
| self._context_handle.set_run_phase(phase) | |||
| def get_run_phase(self): | |||
| """ | |||
| Get the flag of running phase. | |||
| Raises: | |||
| ValueError: If context handle is none. | |||
| """ | |||
| if self._context_handle is None: | |||
| raise ValueError("Context handle is none in context!!!") | |||
| return self._context_handle.get_run_phase() | |||
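A hedged usage sketch of the two new methods (it assumes a working MindSpore build, since the underlying `_context_handle` must exist):

```python
from mindspore.parallel._cost_model_context import cost_model_context

cost_model_context().set_run_phase(1)        # 1: inference, 0: training (default)
assert cost_model_context().get_run_phase() == 1

try:
    cost_model_context().set_run_phase(2)    # only 0 and 1 are accepted
except ValueError as err:
    print(err)
```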
| def set_costmodel_allreduce_fusion_algorithm(self, algorithm): | |||
| """ | |||
| Set costmodel allreduce fusion algorithm. | |||
| @@ -453,6 +480,7 @@ set_cost_model_context_func_map = { | |||
| "costmodel_communi_const": cost_model_context().set_costmodel_communi_const, | |||
| "costmodel_communi_bias": cost_model_context().set_costmodel_communi_bias, | |||
| "multi_subgraphs": cost_model_context().set_multi_subgraphs, | |||
| "run_phase": cost_model_context().set_run_phase, | |||
| "costmodel_allreduce_fusion_algorithm": cost_model_context().set_costmodel_allreduce_fusion_algorithm, | |||
| "costmodel_allreduce_fusion_times": cost_model_context().set_costmodel_allreduce_fusion_times, | |||
| "costmodel_allreduce_fusion_tail_percent": cost_model_context().set_costmodel_allreduce_fusion_tail_percent, | |||
| @@ -473,7 +501,8 @@ get_cost_model_context_func_map = { | |||
| "costmodel_communi_threshold": cost_model_context().get_costmodel_communi_threshold, | |||
| "costmodel_communi_const": cost_model_context().get_costmodel_communi_const, | |||
| "costmodel_communi_bias": cost_model_context().get_costmodel_communi_bias, | |||
| "multi_subgraphs": cost_model_context().get_multi_subgraphs(), | |||
| "multi_subgraphs": cost_model_context().get_multi_subgraphs, | |||
| "run_phase": cost_model_context().get_run_phase, | |||
| "costmodel_allreduce_fusion_algorithm": cost_model_context().get_costmodel_allreduce_fusion_algorithm, | |||
| "costmodel_allreduce_fusion_times": cost_model_context().get_costmodel_allreduce_fusion_times, | |||
| "costmodel_allreduce_fusion_tail_percent": cost_model_context().get_costmodel_allreduce_fusion_tail_percent, | |||
| @@ -488,7 +517,7 @@ get_cost_model_context_func_map = { | |||
| @args_type_check(device_memory_capacity=float, costmodel_alpha=float, costmodel_beta=float, costmodel_gamma=float, | |||
| costmodel_communi_threshold=float, costmodel_communi_const=float, costmodel_communi_bias=float, | |||
| multi_subgraphs=bool, | |||
| multi_subgraphs=bool, run_phase=int, | |||
| costmodel_allreduce_fusion_algorithm=int, costmodel_allreduce_fusion_times=int, | |||
| costmodel_allreduce_fusion_tail_percent=float, costmodel_allreduce_fusion_tail_time=float, | |||
| costmodel_allreduce_fusion_allreduce_inherent_time=float, | |||
| @@ -510,6 +539,7 @@ def set_cost_model_context(**kwargs): | |||
| costmodel_communi_const (float): A parameter used in adjusting communication calculation for practice. | |||
| costmodel_communi_bias (float): A parameter used in adjusting communication calculation for practice. | |||
| multi_subgraphs (bool): A parameter used in marking the flag of ANF graph containing multiple subgraphs. | |||
| run_phase (int): A parameter indicating which phase is running: training (0) or inference (1). Default: 0. | |||
| costmodel_allreduce_fusion_algorithm (int): The allreduce fusion algorithm. | |||
| 0: bypass allreduce fusion; | |||
| 1: only use backward computation time to group allreduce; | |||
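The same flag can be set through the kwargs front end documented above (a sketch; the import path matches the new test added at the end of this diff):

```python
from mindspore.parallel._cost_model_context import set_cost_model_context, cost_model_context

# Ask the strategy-searching algorithm to optimize for the inference phase.
set_cost_model_context(run_phase=1)
assert cost_model_context().get_run_phase() == 1
```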
| @@ -371,7 +371,7 @@ TEST_F(TestCostGraph, test_CreateFinalCostList_AND_Select) { | |||
| ASSERT_EQ(edge_m1_m2->InitEdgeCost(), SUCCESS); | |||
| cost_graph.AddEdge(matmul1, matmul2, edge_m1_m2); | |||
| auto cost_list = cost_graph.CreateFinalCostList(matmul1, edge_m1_m2, matmul2); | |||
| cost_graph.SelectCostWithMemoryConstraint(cost_list, cost_graph.GetDeviceMemory()); | |||
| cost_graph.SelectCostWithMinInferenceTime(cost_list, cost_graph.GetDeviceMemory()); | |||
| } | |||
| TEST_F(TestCostGraph, test_EliminationOp) { | |||
| @@ -14,15 +14,21 @@ | |||
| import mindspore.context as context | |||
| from mindspore.parallel._auto_parallel_context import auto_parallel_context | |||
| from mindspore.parallel._cost_model_context import reset_cost_model_context | |||
| from mindspore.parallel.algo_parameter_config import reset_algo_parameters | |||
| from mindspore.parallel._utils import _reset_op_id | |||
| def setup_module(module): | |||
| auto_parallel_context().set_enable_all_reduce_fusion(enable_all_reduce_fusion=True) | |||
| context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False) | |||
| reset_cost_model_context() | |||
| reset_algo_parameters() | |||
| _reset_op_id() | |||
| def teardown_module(): | |||
| context.reset_auto_parallel_context() | |||
| reset_cost_model_context() | |||
| reset_algo_parameters() | |||
| _reset_op_id() | |||
| @@ -0,0 +1,36 @@ | |||
| import numpy as np | |||
| import mindspore.nn as nn | |||
| from mindspore import Tensor, context | |||
| from mindspore.ops import operations as P | |||
| from mindspore.nn import WithLossCell, TrainOneStepCell | |||
| from mindspore.nn import Momentum | |||
| from mindspore.parallel._cost_model_context import set_cost_model_context | |||
| class Net(nn.Cell): | |||
| def __init__(self, input_ch, out_ch): | |||
| super(Net, self).__init__() | |||
| self.dense = nn.Dense(input_ch, out_ch) | |||
| self.relu = P.ReLU() | |||
| def construct(self, x): | |||
| x = self.dense(x) | |||
| x = self.relu(x) | |||
| return x | |||
| def test_inference_phase(): | |||
| context.set_auto_parallel_context(device_num=8, global_rank=0) | |||
| context.set_auto_parallel_context(parallel_mode="auto_parallel") | |||
| set_cost_model_context(run_phase=1) | |||
| net = Net(512, 128) | |||
| predict = Tensor(np.ones([64, 512]).astype(np.float32) * 0.001) | |||
| label = Tensor(np.ones([64, 128]).astype(np.float32)) | |||
| loss = nn.SoftmaxCrossEntropyWithLogits() | |||
| optimizer = Momentum(params=net.trainable_params(), learning_rate=0.1, momentum=0.9) | |||
| net_with_loss = WithLossCell(net, loss) | |||
| train_network = TrainOneStepCell(net_with_loss, optimizer) | |||
| train_network.set_train() | |||
| output = train_network(predict, label) | |||