From: @xiaoda_zh
Reviewed-by: @yangzhenzhang, @stsuteng
Signed-off-by: @stsuteng
pull/15777/MERGE
@@ -23,7 +23,8 @@
 namespace mindspore {
 namespace parallel {
 void Simplify(CostPtrList *clist_ptrs) {
-  if (RUN_PHASE == TRAINING_PHASE) {
+  const auto run_phase = CostModelContext::GetInstance()->run_phase();
+  if (run_phase == TRAINING_PHASE) {
     // training phase
     SimplifyForDecreasingCommunicationWithPartialPara(clist_ptrs);
   } else {
@@ -35,7 +36,8 @@ void SimplifyForDecreasingCommunicationForward(CostPtrList *clist_ptrs) {
   // Sort the cost_list with the computation_cost_ increasing, and communication_forward decreasing order. This method
   // excludes the cost with greater computation_cost_ and greater communication_forward.
   // E.g. clist_ptrs = {<100, 20>, <200, 10>, <300, 50>}. After this method, clist_ptrs = {<200, 10>, <100, 20>}
-  if (!COST_MODEL_SIMPLIFY_CALCULATION) {
+  const auto simplify_cal = CostModelContext::GetInstance()->costmodel_simplify_cal();
+  if (!simplify_cal) {
     return;
   }
   MS_EXCEPTION_IF_NULL(clist_ptrs);
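The dominance rule in the comment above can be sanity-checked with a small standalone sketch (illustrative only, not MindSpore code; the pair<computation, communication_forward> stands in for the corresponding Cost fields):

#include <algorithm>
#include <limits>
#include <utility>
#include <vector>

// Keep only costs that are not dominated in both computation (first) and
// communication_forward (second).
std::vector<std::pair<double, double>> KeepNonDominated(std::vector<std::pair<double, double>> costs) {
  std::sort(costs.begin(), costs.end());  // computation increasing
  std::vector<std::pair<double, double>> kept;
  double best_comm = std::numeric_limits<double>::max();
  for (const auto &c : costs) {
    if (c.second < best_comm) {  // strictly less communication than every cheaper candidate
      kept.push_back(c);
      best_comm = c.second;
    }
  }
  return kept;
}

Feeding in {<100, 20>, <200, 10>, <300, 50>} keeps <100, 20> and <200, 10> and drops the dominated <300, 50>, matching the example in the comment (the order of the survivors may differ from the one shown there).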
@@ -57,7 +59,8 @@ void SimplifyForDecreasingCommunicationForward(CostPtrList *clist_ptrs) {
 void SimplifyForDecreasingCommunicationWithPartialPara(CostPtrList *clist_ptrs) {
   // Sort the cost_list with the computation_cost_ increasing, and communication_with_partial_para_cost decreasing
   // order. This method excludes the cost with greater computation_cost_ and greater communication_without_para_cost.
-  if (!COST_MODEL_SIMPLIFY_CALCULATION) {
+  const auto simplify_cal = CostModelContext::GetInstance()->costmodel_simplify_cal();
+  if (!simplify_cal) {
     return;
   }
   MS_EXCEPTION_IF_NULL(clist_ptrs);
@@ -78,19 +81,23 @@ void SimplifyForDecreasingCommunicationWithPartialPara(CostPtrList *clist_ptrs)
 void RefineForPracticalCost(const CostPtr &origin_cost, bool is_redistribution) {
   MS_EXCEPTION_IF_NULL(origin_cost);
+  const auto comm_threshold = CostModelContext::GetInstance()->costmodel_communi_threshold();
+  const auto comm_const = CostModelContext::GetInstance()->costmodel_communi_const();
+  const auto comm_bias = CostModelContext::GetInstance()->costmodel_communi_bias();
+  const auto gamma = CostModelContext::GetInstance()->costmodel_gamma();
   if (is_redistribution) {
     // Redistribution cost
     if ((origin_cost->communication_redis_forward_ > EPS) &&
-        (origin_cost->communication_redis_forward_ <= COST_MODEL_COMMUNI_THRESHOLD)) {
-      origin_cost->communication_redis_forward_ = COST_MODEL_COMMUNI_CONST;
-    } else if (origin_cost->communication_redis_forward_ > COST_MODEL_COMMUNI_THRESHOLD) {
-      origin_cost->communication_redis_forward_ += COST_MODEL_COMMUNI_BIAS;
+        (origin_cost->communication_redis_forward_ <= comm_threshold)) {
+      origin_cost->communication_redis_forward_ = comm_const;
+    } else if (origin_cost->communication_redis_forward_ > comm_threshold) {
+      origin_cost->communication_redis_forward_ += comm_bias;
     }
     if ((origin_cost->communication_redis_backward_ > EPS) &&
-        (origin_cost->communication_redis_backward_ <= COST_MODEL_COMMUNI_THRESHOLD)) {
-      origin_cost->communication_redis_backward_ = COST_MODEL_COMMUNI_CONST;
-    } else if (origin_cost->communication_redis_backward_ > COST_MODEL_COMMUNI_THRESHOLD) {
-      origin_cost->communication_redis_backward_ += COST_MODEL_COMMUNI_BIAS;
+        (origin_cost->communication_redis_backward_ <= comm_threshold)) {
+      origin_cost->communication_redis_backward_ = comm_const;
+    } else if (origin_cost->communication_redis_backward_ > comm_threshold) {
+      origin_cost->communication_redis_backward_ += comm_bias;
     }
     origin_cost->communication_cost_ =
       origin_cost->communication_redis_forward_ + origin_cost->communication_redis_backward_;
@@ -104,18 +111,17 @@ void RefineForPracticalCost(const CostPtr &origin_cost, bool is_redistribution)
   }
   // forward cost
   if ((origin_cost->communication_without_parameter_ > EPS) &&
-      (origin_cost->communication_without_parameter_ <= COST_MODEL_COMMUNI_THRESHOLD)) {
-    origin_cost->communication_without_parameter_ = COST_MODEL_COMMUNI_CONST;
-  } else if (origin_cost->communication_without_parameter_ > COST_MODEL_COMMUNI_THRESHOLD) {
-    origin_cost->communication_without_parameter_ += COST_MODEL_COMMUNI_BIAS;
+      (origin_cost->communication_without_parameter_ <= comm_threshold)) {
+    origin_cost->communication_without_parameter_ = comm_const;
+  } else if (origin_cost->communication_without_parameter_ > comm_threshold) {
+    origin_cost->communication_without_parameter_ += comm_bias;
   }
   // total
   if (origin_cost->communication_cost_ > EPS) {
     origin_cost->communication_cost_ = origin_cost->communication_without_parameter_ + backward;
   }
   if (origin_cost->communication_with_partial_para_ > EPS) {
-    origin_cost->communication_with_partial_para_ =
-      origin_cost->communication_without_parameter_ + COST_MODEL_GAMMA * backward;
+    origin_cost->communication_with_partial_para_ = origin_cost->communication_without_parameter_ + gamma * backward;
   }
 }
 }
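RefineForPracticalCost applies the same piecewise adjustment to every communication term it touches; a minimal sketch of that rule, with the threshold/constant/bias standing in for the values now read from CostModelContext (a hypothetical helper, not part of the patch):

// Small-but-nonzero communication is modeled as a fixed constant latency;
// communication above the threshold gets a fixed start-up bias added.
double RefineCommunication(double cost, double eps, double threshold, double constant, double bias) {
  if (cost > eps && cost <= threshold) {
    return constant;
  } else if (cost > threshold) {
    return cost + bias;
  }
  return cost;  // effectively zero: left unchanged
}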
@@ -24,6 +24,7 @@
 #include <vector>
 #include "frontend/parallel/strategy.h"
 #include "frontend/parallel/tensor_layout/tensor_info.h"
+#include "frontend/parallel/costmodel_context.h"
 namespace mindspore {
 namespace parallel {
@@ -44,8 +45,8 @@ using RedistributionOpListPtr = std::shared_ptr<std::pair<OperatorVector, OutPut
 struct Cost {
   Cost();
-  Cost(double computation, double commuication, const std::shared_ptr<Decision> &decision_ = nullptr)
-      : computation_cost_(computation), communication_cost_(commuication), decision_ptr_(std::move(decision_)) {
+  Cost(double computation, double communication, const std::shared_ptr<Decision> &decision_ = nullptr)
+      : computation_cost_(computation), communication_cost_(communication), decision_ptr_(std::move(decision_)) {
     memory_with_reuse_ = 0.0;
     communication_without_parameter_ = 0.0;
     communication_with_partial_para_ = 0.0;
@@ -273,7 +274,7 @@ struct TriangleEliminationDecision : public Decision {
  *
  * v <--- u ---> w ==> v w In the original graph, u has 0 incoming edges, and multiple outgoing edges.
  * In addition, v and w have other complicated connections, resulting in v and w can not be performed other
- * eliminations. After the StarElimination, u is merged into v, and the resulting graph is splitted into multiple
+ * eliminations. After the StarElimination, u is merged into v, and the resulting graph is split into multiple
 * connected components.
 * NOTE: this elimination MUST be performed only when the above 5 operation cannot be applied.
 */
@@ -132,18 +132,14 @@ Status GetStrategy(const CostGraphPtr &graph) {
 Status RecoverStrategy(std::vector<EliminationPtr> eliminations) {
   std::vector<EliminationPtr>::reverse_iterator rit;
+  const auto triangle_star_overwrite = CostModelContext::GetInstance()->triangle_star_strategy_overwrite();
   for (rit = eliminations.rbegin(); rit != eliminations.rend(); ++rit) {
     if ((*rit)->isa<OpElimination>()) {
       auto elimination = (*rit)->cast<OpEliminationPtr>();
       auto e = elimination->new_edge_;
       auto w = elimination->op_;
-      MS_EXCEPTION_IF_NULL(e);
-      MS_EXCEPTION_IF_NULL(w);
       auto left_edge = elimination->left_edge_;
       auto right_edge = elimination->right_edge_;
-      MS_EXCEPTION_IF_NULL(left_edge);
-      MS_EXCEPTION_IF_NULL(right_edge);
       auto decision = e->selected_cost()->decision_ptr_->cast<OpEliminationDecisionPtr>();
       w->SetSelectedStrategyAndCost(decision->op_strategy_, decision->middle_cost_);
       left_edge->set_selected_cost(decision->left_cost_);
@@ -201,12 +197,12 @@ Status RecoverStrategy(std::vector<EliminationPtr> eliminations) {
       right_edge->set_selected_cost(decision->right_edge_cost_);
       // 'left_node' recovers the strategy.
       left_node->SetSelectedStrategyAndCost(decision->left_node_strategy_, decision->left_node_cost_);
-      if (TRIANGLE_STAR_STRATEGY_OVERWRITE) {
+      if (triangle_star_overwrite) {
        // 'right_node' recovers the strategy.
        MS_LOG(INFO) << "Overwrite the right-node: " << right_node->name() << " in recovering triangle elimination.";
        right_node->SetSelectedStrategyAndCost(decision->right_node_strategy_, decision->right_node_cost_);
      } else {
-        // In this case, 'right_node' is not overwriten strategy, and it checks strategy consistency.
+        // In this case, 'right_node' is not overwritten strategy, and it checks strategy consistency.
        right_node->CheckSelectedStrategy(decision->right_node_strategy_);
      }
      MS_LOG(INFO) << "Recover triangleElimination succeeded.";
@@ -215,7 +211,7 @@ Status RecoverStrategy(std::vector<EliminationPtr> eliminations) {
      auto merged_node = elimination->eliminated_node_;
      auto succ_edges = elimination->succ_edges_;
      auto succ_nodes = elimination->succ_ops_;
-      // decision is hided in succ_nodes[0]
+      // decision is hidden in succ_nodes[0]
      auto decision = succ_nodes[0]->selected_cost()->decision_ptr_->cast<StarEliminationDecisionPtr>();
      merged_node->SetSelectedStrategyAndCost(decision->eliminated_op_strategy_, decision->eliminated_op_cost_);
@@ -228,7 +224,7 @@ Status RecoverStrategy(std::vector<EliminationPtr> eliminations) {
      // Star is eliminated into 'succ_nodes[0]'
      succ_nodes[0]->SetSelectedStrategyAndCost(decision->succ_ops_stra_list_[0], decision->succ_ops_cost_list_[0]);
      for (size_t k = 1; k < succ_nodes.size(); ++k) {
-        if (TRIANGLE_STAR_STRATEGY_OVERWRITE) {
+        if (triangle_star_overwrite) {
          // 'succ_nodes[k]' is overwritten strategy and cost.
          succ_nodes[k]->SetSelectedStrategyAndCost(decision->succ_ops_stra_list_[k], decision->succ_ops_cost_list_[k]);
        } else {
@@ -90,11 +90,13 @@ Status Edge::InitEdgeCost() {
    }
  }
  if (!has_available_cost) {
-    if (FULLY_USE_DEVICES) {
+    const auto fully_use = CostModelContext::GetInstance()->fully_use_device();
+    const auto stra_follow = CostModelContext::GetInstance()->elementwise_stra_follow();
+    if (fully_use) {
      MS_LOG(EXCEPTION) << "Generating cost for edge: " << edge_name_
                        << " failed, it may be caused by setting 'fully_use_devices' true. Try to set "
                           "'fully_use_devices' false.";
-    } else if (ELEMENTWISE_OP_STRA_FOLLOW) {
+    } else if (stra_follow) {
      MS_LOG(EXCEPTION) << "Generating cost for edge: " << edge_name_
                        << " failed, it may be caused by setting 'elementwise_op_strategy_follow' true. "
                           "Try to set 'elementwise_op_strategy_follow' false.";
@@ -130,6 +132,7 @@ Status Edge::GetRedistributionCost(const TensorLayout &prev_op_output_layout, co
  double backward_comm_cost = tensor_redistribution.backward_comm_cost();
  double computation_cost = tensor_redistribution.computation_cost();
  double mem_cost = tensor_redistribution.memory_cost();
+  const auto gamma = CostModelContext::GetInstance()->costmodel_gamma();
  // Now AllGather, ReduceScatter, AlltoAll don't support bool type
  MS_EXCEPTION_IF_NULL(type);
@@ -142,7 +145,7 @@ Status Edge::GetRedistributionCost(const TensorLayout &prev_op_output_layout, co
   (*cost)->communication_without_parameter_ = type_length * comm_cost;
   (*cost)->communication_with_partial_para_ =
     (*cost)->communication_without_parameter_ +
-    COST_MODEL_GAMMA * ((*cost)->communication_cost_ - (*cost)->communication_without_parameter_);
+    gamma * ((*cost)->communication_cost_ - (*cost)->communication_without_parameter_);
   (*cost)->communication_redis_forward_ = type_length * forward_comm_cost;
   (*cost)->communication_redis_backward_ = type_length * backward_comm_cost;
   (*cost)->memory_with_reuse_ = mem_cost;
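The communication_with_partial_para_ updates throughout this patch use the same convex blend, now parameterized by the gamma read from CostModelContext. A small worked sketch (illustrative names, not the MindSpore API):

// gamma in [0, 1] interpolates between the parameter-free communication and the total.
double CommunicationWithPartialPara(double total, double without_para, double gamma) {
  return without_para + gamma * (total - without_para);
}
// E.g. total = 10.0, without_para = 4.0, gamma = 0.5  ->  4.0 + 0.5 * 6.0 = 7.0.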
@@ -173,13 +176,14 @@ CostPtrList Edge::CreateEdgeEliminationCostList(const StrategyPtr &output_st_ptr
   std::function<void(size_t, double, double, double, double, double)> recursive =
     [&](size_t k, double computation, double memory, double communication, double communication_without_para,
         double communication_forward) {
+      const auto gamma = CostModelContext::GetInstance()->costmodel_gamma();
       if (k == edges.size()) {
         auto decision = std::make_shared<EdgeEliminationDecision>(selected_cost_list);
         CostPtr new_cost = std::make_shared<Cost>(computation, communication);
         MS_EXCEPTION_IF_NULL(new_cost);
         new_cost->communication_without_parameter_ = communication_without_para;
         new_cost->communication_with_partial_para_ =
-          communication_without_para + COST_MODEL_GAMMA * (communication - communication_without_para);
+          communication_without_para + gamma * (communication - communication_without_para);
         new_cost->memory_with_reuse_ = memory;
         new_cost->communication_forward_ = communication_forward;
         new_cost->decision_ptr_ = decision;
@@ -242,10 +246,11 @@ void Edge::CreateOpEliminationSubCostList(StrategyPtr op_strategy, const CostPtr
   auto decision = std::make_shared<OpEliminationDecision>(op_strategy, left_cost, middle_cost, right_cost);
   auto cost = std::make_shared<Cost>(computation, communication, decision);
+  const auto gamma = CostModelContext::GetInstance()->costmodel_gamma();
   MS_EXCEPTION_IF_NULL(cost);
   cost->communication_without_parameter_ = communication_without_para;
   cost->communication_with_partial_para_ =
-    communication_without_para + COST_MODEL_GAMMA * (communication - communication_without_para);
+    communication_without_para + gamma * (communication - communication_without_para);
   cost->memory_with_reuse_ = memory_cost;
   cost->communication_forward_ = communication_forward;
   ret_cost_list->emplace_back(std::move(cost));
@@ -28,175 +28,6 @@ namespace mindspore {
 namespace parallel {
 CostGraphPtr entire_costgraph = nullptr;
 size_t TOTAL_OPS = 0;
-double COST_MODEL_GAMMA = DEFAULT_COST_MODEL_GAMMA;
-bool COST_MODEL_SIMPLIFY_CALCULATION = DEFAULT_COST_MODEL_SIMPLIFY_CALCULATION;
-double DEVICE_MEMORY_CAPACITY = DEFAULT_DEVICE_MEMORY_CAPACITY;
-double COST_MODEL_COMMUNI_THRESHOLD = DEFAULT_COST_MODEL_COMMUNI_THRESHOLD;
-double COST_MODEL_COMMUNI_CONST = DEFAULT_COST_MODEL_COMMUNI_CONST;
-double COST_MODEL_COMMUNI_BIAS = DEFAULT_COST_MODEL_COMMUNI_BIAS;
-bool TENSOR_SLICE_ALIGNMENT_ENABLE = DEFAULT_TENSOR_SLICE_ALIGNMENT_ENABLE;
-size_t TENSOR_SLICE_ALIGNMENT_SIZE = DEFAULT_TENSOR_SLICE_ALIGNMENT_SIZE;
-bool FULLY_USE_DEVICES = DEFAULT_FULLY_USE_DEVICES;
-bool ELEMENTWISE_OP_STRA_FOLLOW = DEFAULT_ELEMENTWISE_OP_STRA_FOLLOW;
-bool MULTI_SUBGRAPHS = DEFAULT_IS_MULTI_SUBGRAPHS;
-int64_t RUN_PHASE = DEFAULT_RUN_PHASE;
-bool TRIANGLE_STAR_STRATEGY_OVERWRITE = DEFAULT_TRIANGLE_STAR_STRATEGY_OVERWRITE;
-bool DP_ALGO_ENABLE_APPROX = DEFAULT_DP_ALGO_ENABLE_APPROX;
-double DP_ALGO_APPROX_EPSILON = DEFAULT_DP_ALGO_APPROX_EPSILON;
-bool DP_ALGO_SINGLE_LOOP = DEFAULT_DP_ALGO_SINGLE_LOOP;
-void CostGraph::SetDeviceMemoryAndCostParameter() {
-  MS_EXCEPTION_IF_NULL(CostModelContext::GetInstance());
-  // DEVICE_MEMORY_CAPACITY
-  auto device_memory = CostModelContext::GetInstance()->device_memory_capacity();
-  if (device_memory <= 0) {
-    MS_LOG(EXCEPTION) << "'device_memory_capacity' must be positive.";
-  }
-  dev_memory_ = device_memory;
-  DEVICE_MEMORY_CAPACITY = device_memory;
-  MS_LOG(INFO) << "device_memory_capacity: " << DEVICE_MEMORY_CAPACITY << ".";
-  // COST_MODEL_ALPHA
-  auto alpha = CostModelContext::GetInstance()->costmodel_alpha();
-  if (alpha <= 0) {
-    MS_LOG(EXCEPTION) << "'costmodel_alpha' must be positive.";
-  }
-  costmodel_alpha_ = alpha;
-  MS_LOG(INFO) << "costmodel_alpha: " << costmodel_alpha_ << ".";
-  // COST_MODEL_BETA
-  auto beta = CostModelContext::GetInstance()->costmodel_beta();
-  if (beta <= 0) {
-    MS_LOG(EXCEPTION) << "'costmodel_beta' must be positive.";
-  }
-  costmodel_beta_ = beta;
-  MS_LOG(INFO) << "costmodel_beta: " << costmodel_beta_ << ".";
-  // COST_MODEL_GAMMA
-  auto gamma = CostModelContext::GetInstance()->costmodel_gamma();
-  if ((gamma < 0) || (gamma > 1)) {
-    MS_LOG(EXCEPTION) << "'costmodel_gamma' must in [0, 1].";
-  }
-  COST_MODEL_GAMMA = gamma;
-  MS_LOG(INFO) << "costmodel_gamma: " << COST_MODEL_GAMMA << ".";
-  // COST_MODEL_SIMPLIFY_CALCULATION
-  auto simplify = CostModelContext::GetInstance()->costmodel_simplify_cal();
-  COST_MODEL_SIMPLIFY_CALCULATION = simplify;
-  if (COST_MODEL_SIMPLIFY_CALCULATION) {
-    MS_LOG(INFO) << "costmodel_simplify_cal: true.";
-  } else {
-    MS_LOG(INFO) << "costmodel_simplify_cal: false.";
-  }
-  // COST_MODEL_COMMUNI_THRESHOLD
-  auto communi_threshold = CostModelContext::GetInstance()->costmodel_communi_threshold();
-  if (communi_threshold < 0) {
-    MS_LOG(EXCEPTION) << "'costmodel_communi_threshold' must be non-zero.";
-  }
-  COST_MODEL_COMMUNI_THRESHOLD = communi_threshold;
-  MS_LOG(INFO) << "costmodel_communi_threshold: " << COST_MODEL_COMMUNI_THRESHOLD << ".";
-  // COST_MODEL_COMMUNI_CONST
-  auto communi_const = CostModelContext::GetInstance()->costmodel_communi_const();
-  if (communi_const < 0) {
-    MS_LOG(EXCEPTION) << "'costmodel_communi_const' must be non-zero.";
-  }
-  COST_MODEL_COMMUNI_CONST = communi_const;
-  MS_LOG(INFO) << "costmodel_communi_const: " << COST_MODEL_COMMUNI_CONST << ".";
-  // COST_MODEL_COMMUNI_BIAS
-  auto communi_bias = CostModelContext::GetInstance()->costmodel_communi_bias();
-  if (communi_bias < 0) {
-    MS_LOG(EXCEPTION) << "'costmodel_communi_bias' must be non-zero.";
-  }
-  COST_MODEL_COMMUNI_BIAS = communi_bias;
-  MS_LOG(INFO) << "costmodel_communi_bias: " << COST_MODEL_COMMUNI_BIAS << ".";
-  // TENSOR_SLICE_ALIGNMENT_ENABLE
-  auto align_enable = CostModelContext::GetInstance()->tensor_slice_alignment_enable();
-  TENSOR_SLICE_ALIGNMENT_ENABLE = align_enable;
-  if (TENSOR_SLICE_ALIGNMENT_ENABLE) {
-    MS_LOG(INFO) << "tensor_slice_align_enable: true.";
-  } else {
-    MS_LOG(INFO) << "tensor_slice_align_enable: false.";
-  }
-  // TENSOR_SLICE_ALIGNMENT_SIZE
-  auto align_size = CostModelContext::GetInstance()->tensor_slice_alignment_size();
-  if (align_size == 0) {
-    MS_LOG(EXCEPTION) << "'tensor_slice_align_size' must be positive.";
-  }
-  TENSOR_SLICE_ALIGNMENT_SIZE = align_size;
-  MS_LOG(INFO) << "tensor_slice_align_size: " << TENSOR_SLICE_ALIGNMENT_SIZE << ".";
-  // FULLY_USE_DEVICES
-  auto fully_devices = CostModelContext::GetInstance()->fully_use_device();
-  FULLY_USE_DEVICES = fully_devices;
-  if (FULLY_USE_DEVICES) {
-    MS_LOG(INFO) << "fully_use_devices: true.";
-  } else {
-    MS_LOG(INFO) << "fully_use_devices: false.";
-  }
-  // ELEMENTWISE_OP_STRA_FOLLOW
-  auto is_ele_op_follow = CostModelContext::GetInstance()->elementwise_stra_follow();
-  ELEMENTWISE_OP_STRA_FOLLOW = is_ele_op_follow;
-  if (ELEMENTWISE_OP_STRA_FOLLOW) {
-    MS_LOG(INFO) << "elementwise_op_strategy_follow: true.";
-  } else {
-    MS_LOG(INFO) << "elementwise_op_strategy_follow: false.";
-  }
-  // MULTI_SUBGRAPHS
-  auto multi_subgraphs = CostModelContext::GetInstance()->is_multi_subgraphs();
-  MULTI_SUBGRAPHS = multi_subgraphs;
-  if (MULTI_SUBGRAPHS) {
-    MS_LOG(INFO) << "multi_subgraphs: true.";
-  } else {
-    MS_LOG(INFO) << "multi_subgraphs: false.";
-  }
-  auto overwrite = CostModelContext::GetInstance()->triangle_star_strategy_overwrite();
-  TRIANGLE_STAR_STRATEGY_OVERWRITE = overwrite;
-  if (TRIANGLE_STAR_STRATEGY_OVERWRITE) {
-    MS_LOG(INFO) << "triangle_star_strategy_overwrite: true.";
-  } else {
-    MS_LOG(INFO) << "triangle_star_strategy_overwrite: false.";
-  }
-  // RUN_PHASE
-  auto phase = CostModelContext::GetInstance()->run_phase();
-  if (phase != 0 && phase != 1) {
-    MS_LOG(EXCEPTION) << "'run_phase' must be in {0, 1}";
-  }
-  RUN_PHASE = phase;
-  MS_LOG(INFO) << "run_phase: " << RUN_PHASE << ".";
-  auto enable_approx = CostModelContext::GetInstance()->dp_algo_enable_approxi();
-  DP_ALGO_ENABLE_APPROX = enable_approx;
-  if (enable_approx) {
-    MS_LOG(INFO) << "dp_algo_enable_approx: true.";
-  } else {
-    MS_LOG(INFO) << "dp_algo_enable_approx: false.";
-  }
-  auto epsilon = CostModelContext::GetInstance()->dp_algo_approxi_epsilon();
-  if (epsilon <= 0 || epsilon > 1) {
-    MS_LOG(EXCEPTION) << "'epsilon' must be in (0, 1]";
-  }
-  DP_ALGO_APPROX_EPSILON = epsilon;
-  MS_LOG(INFO) << "epsilon: " << epsilon << ".";
-  auto single_loop = CostModelContext::GetInstance()->dp_algo_single_loop();
-  DP_ALGO_SINGLE_LOOP = single_loop;
-  if (single_loop) {
-    MS_LOG(INFO) << "dp_algo_single_loop: true.";
-  } else {
-    MS_LOG(INFO) << "dp_algo_single_loop: false.";
-  }
-}
 void CostGraph::Init() {
   inputs_tensor_name_list_.clear();
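The block removed above mirrored every CostModelContext field into a file-scope global and validated it once; after this patch callers read the singleton directly at the point of use. A hedged sketch of the new access pattern, using only getters that appear in this diff (the function name is illustrative, not part of the patch):

void ReadCostModelParameters() {
  const auto &ctx = CostModelContext::GetInstance();
  MS_EXCEPTION_IF_NULL(ctx);
  const auto gamma = ctx->costmodel_gamma();                  // expected to lie in [0, 1]
  const auto comm_threshold = ctx->costmodel_communi_threshold();
  const auto simplify_cal = ctx->costmodel_simplify_cal();
  const auto run_phase = ctx->run_phase();                    // TRAINING_PHASE or inference
  // Use the values directly where they are needed instead of caching them in globals.
}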
@@ -269,7 +100,6 @@ std::vector<std::shared_ptr<CostGraph>> CostGraph::ConstructConnectedComponents(
     if ((!visited[op]) && op->is_alive()) {
       std::shared_ptr<CostGraph> new_component = std::make_shared<CostGraph>();
       MS_EXCEPTION_IF_NULL(new_component);
-      new_component->SetDeviceMemoryAndCostParameter();
       DFS(op, &visited, new_component);
       connected_compoents_.push_back(new_component);
     }
@@ -336,10 +166,11 @@ CostPtrList CostGraph::CreateFinalCostList(const OperatorInfoPtr &u, const std::
         auto decision =
           std::make_shared<FinalDecision>(u_strategy->strategy_ptr, v_strategy->strategy_ptr, cost1, cost2, cost3);
         auto cost = std::make_shared<Cost>(computation, communication, decision);
+        const auto gamma = CostModelContext::GetInstance()->costmodel_gamma();
         MS_EXCEPTION_IF_NULL(cost);
         cost->communication_without_parameter_ = communication_without_para;
         cost->communication_with_partial_para_ =
-          communication_without_para + COST_MODEL_GAMMA * (communication - communication_without_para);
+          communication_without_para + gamma * (communication - communication_without_para);
         cost->memory_with_reuse_ = memory;
         cost->communication_forward_ = communication_forward;
         ret.push_back(cost);
@@ -353,7 +184,7 @@ CostPtrList CostGraph::CreateFinalCostList(const OperatorInfoPtr &u, const std::
   return ret;
 }
-// Create final cost list for the graph containing a signle node: u
+// Create final cost list for the graph containing a single node: u
 CostPtrList CostGraph::CreateFinalSingleCostList(const OperatorInfoPtr &u) {
   MS_EXCEPTION_IF_NULL(u);
   CostPtrList ret;
@@ -365,11 +196,12 @@ CostPtrList CostGraph::CreateFinalSingleCostList(const OperatorInfoPtr &u) {
     MS_EXCEPTION_IF_NULL(cost1);
     auto decision = std::make_shared<FinalSingleDecision>(u_strategy_ptr, cost1);
     auto new_cost = std::make_shared<Cost>(cost1->computation_cost_, cost1->communication_cost_, decision);
+    const auto gamma = CostModelContext::GetInstance()->costmodel_gamma();
     MS_EXCEPTION_IF_NULL(new_cost);
     new_cost->communication_without_parameter_ = cost1->communication_without_parameter_;
     new_cost->communication_with_partial_para_ =
       cost1->communication_without_parameter_ +
-      COST_MODEL_GAMMA * (cost1->communication_cost_ - cost1->communication_without_parameter_);
+      gamma * (cost1->communication_cost_ - cost1->communication_without_parameter_);
     new_cost->memory_with_reuse_ = cost1->memory_with_reuse_;
     new_cost->communication_forward_ = cost1->communication_forward_;
     ret.push_back(new_cost);
@@ -404,8 +236,10 @@ CostPtr CostGraph::SelectCostWithMinInferenceTime(const CostPtrList &cost_list,
   }
   // Init the returned value with first cost.
   CostPtr ret = after_mem_filter[0];
+  const auto alpha = CostModelContext::GetInstance()->costmodel_alpha();
+  const auto beta = CostModelContext::GetInstance()->costmodel_beta();
-  double minimum = costmodel_alpha_ * ret->computation_cost_ + costmodel_beta_ * ret->communication_forward_;
+  double minimum = alpha * ret->computation_cost_ + beta * ret->communication_forward_;
   MS_LOG(INFO) << "Cost 0: "
                << "memory_cost: " << ret->memory_with_reuse_ << ", computation_cost_: " << ret->computation_cost_
                << ", communication_forward_: " << ret->communication_forward_
@@ -422,8 +256,7 @@ CostPtr CostGraph::SelectCostWithMinInferenceTime(const CostPtrList &cost_list,
                  << ", communication_cost_: " << after_mem_filter[i]->communication_cost_
                  << ", communication_without_parameter_: " << after_mem_filter[i]->communication_without_parameter_
                  << ".";
-    auto tmp = costmodel_alpha_ * after_mem_filter[i]->computation_cost_ +
-               costmodel_beta_ * after_mem_filter[i]->communication_forward_;
+    auto tmp = alpha * after_mem_filter[i]->computation_cost_ + beta * after_mem_filter[i]->communication_forward_;
     MS_LOG(INFO) << "Cost " << i << ": total_cost: " << tmp;
     if (minimum > tmp) {
       minimum = tmp;
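SelectCostWithMinInferenceTime and SelectCostWithMinTrainingTime both minimize a weighted sum of computation and communication over the memory-feasible candidates, with alpha and beta now fetched from CostModelContext. A condensed sketch of that selection loop (CandidateCost is a hypothetical stand-in for the Cost fields involved, not a MindSpore type):

#include <cstddef>
#include <vector>

struct CandidateCost {
  double memory;
  double computation;
  double communication;  // communication_forward_ or communication_with_partial_para_
};

// Returns the index of the cheapest candidate that fits into memory, or -1 if none fits.
int SelectMinCost(const std::vector<CandidateCost> &costs, double memory_capacity, double alpha, double beta) {
  int best = -1;
  double minimum = 0.0;
  for (std::size_t i = 0; i < costs.size(); ++i) {
    if (costs[i].memory > memory_capacity) {
      continue;  // memory filter, corresponding to after_mem_filter
    }
    const double total = alpha * costs[i].computation + beta * costs[i].communication;
    if (best < 0 || total < minimum) {
      minimum = total;
      best = static_cast<int>(i);
    }
  }
  return best;
}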
@@ -458,8 +291,10 @@ CostPtr CostGraph::SelectCostWithMinTrainingTime(const CostPtrList &cost_list, d
   }
   // Init the returned value with first cost.
   CostPtr ret = after_mem_filter[0];
+  const auto alpha = CostModelContext::GetInstance()->costmodel_alpha();
+  const auto beta = CostModelContext::GetInstance()->costmodel_beta();
-  double minimum = costmodel_alpha_ * ret->computation_cost_ + costmodel_beta_ * ret->communication_with_partial_para_;
+  double minimum = alpha * ret->computation_cost_ + beta * ret->communication_with_partial_para_;
   MS_LOG(INFO) << "Cost 0: "
                << "memory_cost: " << ret->memory_with_reuse_ << ", computation_cost_: " << ret->computation_cost_
                << ", communication_with_partial_para_: " << ret->communication_with_partial_para_
@@ -474,8 +309,8 @@ CostPtr CostGraph::SelectCostWithMinTrainingTime(const CostPtrList &cost_list, d
                  << ", communication_cost_: " << after_mem_filter[i]->communication_cost_
                  << ", communication_without_parameter_: " << after_mem_filter[i]->communication_without_parameter_
                  << ".";
-    auto tmp = costmodel_alpha_ * after_mem_filter[i]->computation_cost_ +
-               costmodel_beta_ * after_mem_filter[i]->communication_with_partial_para_;
+    auto tmp =
+      alpha * after_mem_filter[i]->computation_cost_ + beta * after_mem_filter[i]->communication_with_partial_para_;
     MS_LOG(INFO) << "Cost " << i << ": total_cost: " << tmp;
     if (minimum > tmp) {
       minimum = tmp;
@@ -513,14 +348,16 @@ CostPtrList CostGraph::SelectCostListWithMinTrainingTimeMultiple(const std::vect
   }
   std::function<void(size_t)> recursive = [&all_cost_list, &selected_cost_list, &minimum, &ret, &recursive,
-                                           &available_memory, this](size_t k) {
+                                           &available_memory](size_t k) {
+    const auto alpha = CostModelContext::GetInstance()->costmodel_alpha();
+    const auto beta = CostModelContext::GetInstance()->costmodel_beta();
     if (k == all_cost_list.size()) {
       double tmp_memory = 0.0, tmp_minimum = 0.0;
       for (size_t i = 0; i < selected_cost_list.size(); ++i) {
         MS_EXCEPTION_IF_NULL(selected_cost_list[i]);
         tmp_memory += selected_cost_list[i]->memory_with_reuse_;
-        tmp_minimum += costmodel_alpha_ * selected_cost_list[i]->computation_cost_ +
-                       costmodel_beta_ * selected_cost_list[i]->communication_with_partial_para_;
+        tmp_minimum += alpha * selected_cost_list[i]->computation_cost_ +
+                       beta * selected_cost_list[i]->communication_with_partial_para_;
       }
       MS_LOG(INFO) << "tmp_memory: " << tmp_memory << ", tmp_minimum: " << tmp_minimum << ", minimum: " << minimum
                    << ".";
@@ -582,12 +419,12 @@ Status CostGraph::SearchStrategyForMultiNodeFinalGraph(const std::vector<Operato
                  << " operators in a component in the final graph.";
     }
   }
-  //
-  auto selected_cost_list = SelectCostListWithMinTrainingTimeMultiple(all_list, dev_memory_);
+  const auto device_mem_capacity = CostModelContext::GetInstance()->device_memory_capacity();
+  auto selected_cost_list = SelectCostListWithMinTrainingTimeMultiple(all_list, device_mem_capacity);
   for (size_t k = 0; k < selected_cost_list.size(); ++k) {
     auto selected_cost = selected_cost_list[k];
     if (selected_cost == nullptr) {
-      MS_LOG(ERROR) << "No vaild strategy can be found under the current device memory: " << dev_memory_ << ".";
+      MS_LOG(ERROR) << "No valid strategy can be found under the current device memory: " << device_mem_capacity << ".";
       return FAILED;
     }
     MS_EXCEPTION_IF_NULL(connected_components[k]);
@@ -627,6 +464,99 @@ Status CostGraph::SearchStrategyForMultiNodeFinalGraph(const std::vector<Operato
   return SUCCESS;
 }
+Status CostGraph::SearchStrategyForTwoNodeFinalGraph(const std::vector<OperatorInfoPtr> &alive_ops) {
+  // In this case, the final graph should contains exactly 2 nodes.
+  if (alive_ops.empty()) {
+    MS_LOG(INFO) << "0 Operator in the final graph.";
+    return SUCCESS;
+  }
+  OperatorInfoPtr u, v;
+  MS_EXCEPTION_IF_NULL(alive_ops[0]);
+  MS_EXCEPTION_IF_NULL(alive_ops[1]);
+  const auto phase = CostModelContext::GetInstance()->run_phase();
+  const auto device_mem_capacity = CostModelContext::GetInstance()->device_memory_capacity();
+  if (!alive_ops[0]->GetAliveSuccEdges().empty() &&
+      alive_ops[0]->GetAliveSuccEdges()[0]->next_operator().get() == alive_ops[1].get()) {
+    u = alive_ops[0];
+    v = alive_ops[1];
+  } else if (!alive_ops[1]->GetAliveSuccEdges().empty() &&
+             alive_ops[1]->GetAliveSuccEdges()[0]->next_operator().get() == alive_ops[0].get()) {
+    u = alive_ops[1];
+    v = alive_ops[0];
+  } else {
+    if (!alive_ops[0]->GetAliveSuccEdges().empty() || !alive_ops[1]->GetAliveSuccEdges().empty()) {
+      MS_LOG(EXCEPTION) << "The final graph is not the case of u --> v, " << alive_ops[0]->GetAliveSuccEdges().size()
+                        << ", " << alive_ops[1]->GetAliveSuccEdges().size() << ".";
+    } else {
+      // In this case, the final graph consists of two single nodes
+      MS_LOG(INFO) << "There are 2 single nodes in the final graph.";
+      std::vector<CostPtrList> all_list;
+      auto connected_components = ConstructConnectedComponents(alive_ops);
+      MS_LOG(INFO) << "There are " << connected_components.size() << " components in the final graph.";
+      for (size_t i = 0; i < connected_components.size(); ++i) {
+        MS_LOG(INFO) << "There are 1 operator in a component in the final graph.";
+        auto one_component = connected_components[i];
+        MS_EXCEPTION_IF_NULL(one_component);
+        auto cost_list = one_component->CreateFinalSingleCostList(one_component->GetOperators()[0]);
+        all_list.push_back(cost_list);
+      }
+      CostPtrList selected_cost_list;
+      if (phase == TRAINING_PHASE) {
+        // training phase
+        selected_cost_list = SelectCostListWithMinTrainingTimeMultiple(all_list, device_mem_capacity);
+      } else {
+        // inference phase
+        MS_LOG(EXCEPTION) << "Currently, searching strategy for the two-separated-node final graph in the inference "
+                             "phase is not supported.";
+      }
+      for (size_t k = 0; k < selected_cost_list.size(); ++k) {
+        auto selected_cost = selected_cost_list[k];
+        if (selected_cost == nullptr) {
+          MS_LOG(ERROR) << "No valid strategy can be found under the current device memory: " << device_mem_capacity
+                        << ".";
+          return FAILED;
+        }
+        MS_EXCEPTION_IF_NULL(connected_components[k]);
+        auto one_operator = connected_components[k]->GetOperators()[0];
+        MS_EXCEPTION_IF_NULL(selected_cost->decision_ptr_);
+        auto decision = selected_cost->decision_ptr_->cast<FinalSingleDecisionPtr>();
+        MS_EXCEPTION_IF_NULL(decision);
+        one_operator->SetSelectedStrategyAndCost(decision->u_strategy_, decision->u_cost_);
+        MS_LOG(INFO) << "Searching the strategy for the component " << k << " final graph ended.";
+      }
+      return SUCCESS;
+    }
+  }
+  MS_LOG(INFO) << "There are 2 nodes in the final graph.";
+  // In this case, the finale graph is exactly of the form: u --> v
+  MS_EXCEPTION_IF_NULL(u);
+  MS_EXCEPTION_IF_NULL(v);
+  auto e = u->GetAliveSuccEdges()[0];
+  MS_EXCEPTION_IF_NULL(e);
+  auto cost_list = CreateFinalCostList(u, e, v);
+  CostPtr cost = nullptr;
+  if (phase == TRAINING_PHASE) {
+    // training phase
+    cost = SelectCostWithMinTrainingTime(cost_list, device_mem_capacity);
+  } else {
+    MS_LOG(EXCEPTION) << "Currently, searching strategy for the two-connected-node final graph in the inference "
+                         "phase is not supported.";
+  }
+  if (cost == nullptr) {
+    MS_LOG(ERROR) << "No valid strategy can be found under the current device memory: " << device_mem_capacity << ".";
+    return FAILED;
+  }
+  MS_EXCEPTION_IF_NULL(cost->decision_ptr_);
+  auto decision = cost->decision_ptr_->cast<FinalDecisionPtr>();
+  MS_EXCEPTION_IF_NULL(decision);
+  u->SetSelectedStrategyAndCost(decision->u_strategy_, decision->left_cost_);
+  v->SetSelectedStrategyAndCost(decision->v_strategy_, decision->right_cost_);
+  e->set_selected_cost(decision->middle_cost_);
+  MS_LOG(INFO) << "Searching the strategy for the eliminated final graph ended.";
+  return SUCCESS;
+}
 // searching the strategy for the final eliminated graph
 Status CostGraph::SearchStrategy() {
   MS_LOG(INFO) << "Searching the strategy for the eliminated final graph began.";
@@ -637,9 +567,11 @@ Status CostGraph::SearchStrategy() {
       alive_ops.push_back(op);
     }
   });
+  const auto phase = CostModelContext::GetInstance()->run_phase();
+  const auto device_mem_capacity = CostModelContext::GetInstance()->device_memory_capacity();
   if (alive_ops.size() > 2) {
-    if (RUN_PHASE == TRAINING_PHASE) {
+    if (phase == TRAINING_PHASE) {
       // training phase
       return SearchStrategyForMultiNodeFinalGraph(alive_ops);
     } else {
@@ -652,15 +584,15 @@ Status CostGraph::SearchStrategy() {
     OperatorInfoPtr u = alive_ops[0];
     auto cost_list = CreateFinalSingleCostList(u);
     CostPtr cost = nullptr;
-    if (RUN_PHASE == TRAINING_PHASE) {
+    if (phase == TRAINING_PHASE) {
       // training phase
-      cost = SelectCostWithMinTrainingTime(cost_list, dev_memory_);
+      cost = SelectCostWithMinTrainingTime(cost_list, device_mem_capacity);
     } else {
       // inference phase
-      cost = SelectCostWithMinInferenceTime(cost_list, dev_memory_);
+      cost = SelectCostWithMinInferenceTime(cost_list, device_mem_capacity);
     }
     if (cost == nullptr) {
-      MS_LOG(ERROR) << "No vaild strategy can be found under the current device memory: " << dev_memory_ << ".";
+      MS_LOG(ERROR) << "No valid strategy can be found under the current device memory: " << device_mem_capacity << ".";
       return FAILED;
     }
     MS_EXCEPTION_IF_NULL(u);
@@ -671,93 +603,7 @@ Status CostGraph::SearchStrategy() {
     MS_LOG(INFO) << "Searching the strategy for the eliminated final graph ended.";
     return SUCCESS;
   } else {
-    // In this case, the final graph should contains exactly 2 nodes.
-    if (alive_ops.empty()) {
-      MS_LOG(INFO) << "0 Operator in the final graph.";
-      return SUCCESS;
-    }
-    OperatorInfoPtr u, v;
-    MS_EXCEPTION_IF_NULL(alive_ops[0]);
-    MS_EXCEPTION_IF_NULL(alive_ops[1]);
-    if (!alive_ops[0]->GetAliveSuccEdges().empty() &&
-        alive_ops[0]->GetAliveSuccEdges()[0]->next_operator().get() == alive_ops[1].get()) {
-      u = alive_ops[0];
-      v = alive_ops[1];
-    } else if (!alive_ops[1]->GetAliveSuccEdges().empty() &&
-               alive_ops[1]->GetAliveSuccEdges()[0]->next_operator().get() == alive_ops[0].get()) {
-      u = alive_ops[1];
-      v = alive_ops[0];
-    } else {
-      if (!alive_ops[0]->GetAliveSuccEdges().empty() || !alive_ops[1]->GetAliveSuccEdges().empty()) {
-        MS_LOG(EXCEPTION) << "The final graph is not the case of u --> v, " << alive_ops[0]->GetAliveSuccEdges().size()
-                          << ", " << alive_ops[1]->GetAliveSuccEdges().size() << ".";
-      } else {
-        // In this case, the final graph consists of two single nodes
-        MS_LOG(INFO) << "There are 2 single nodes in the final graph.";
-        std::vector<CostPtrList> all_list;
-        auto connected_components = ConstructConnectedComponents(alive_ops);
-        MS_LOG(INFO) << "There are " << connected_components.size() << " components in the final graph.";
-        for (size_t i = 0; i < connected_components.size(); ++i) {
-          MS_LOG(INFO) << "There are 1 operator in a component in the final graph.";
-          auto one_component = connected_components[i];
-          MS_EXCEPTION_IF_NULL(one_component);
-          auto cost_list = one_component->CreateFinalSingleCostList(one_component->GetOperators()[0]);
-          all_list.push_back(cost_list);
-        }
-        CostPtrList selected_cost_list;
-        if (RUN_PHASE == TRAINING_PHASE) {
-          // training phase
-          selected_cost_list = SelectCostListWithMinTrainingTimeMultiple(all_list, dev_memory_);
-        } else {
-          // inference phase
-          MS_LOG(EXCEPTION) << "Currently, searching strategy for the two-separated-node final graph in the inference "
-                               "phase is not supported.";
-        }
-        for (size_t k = 0; k < selected_cost_list.size(); ++k) {
-          auto selected_cost = selected_cost_list[k];
-          if (selected_cost == nullptr) {
-            MS_LOG(ERROR) << "No vaild strategy can be found under the current device memory: " << dev_memory_ << ".";
-            return FAILED;
-          }
-          MS_EXCEPTION_IF_NULL(connected_components[k]);
-          auto one_operator = connected_components[k]->GetOperators()[0];
-          MS_EXCEPTION_IF_NULL(selected_cost->decision_ptr_);
-          auto decision = selected_cost->decision_ptr_->cast<FinalSingleDecisionPtr>();
-          MS_EXCEPTION_IF_NULL(decision);
-          one_operator->SetSelectedStrategyAndCost(decision->u_strategy_, decision->u_cost_);
-          MS_LOG(INFO) << "Searching the strategy for the component " << k << " final graph ended.";
-        }
-        return SUCCESS;
-      }
-    }
-    MS_LOG(INFO) << "There are 2 nodes in the final graph.";
-    // In this case, the finale graph is exactly of the form: u --> v
-    MS_EXCEPTION_IF_NULL(u);
-    MS_EXCEPTION_IF_NULL(v);
-    auto e = u->GetAliveSuccEdges()[0];
-    MS_EXCEPTION_IF_NULL(e);
-    auto cost_list = CreateFinalCostList(u, e, v);
-    CostPtr cost = nullptr;
-    if (RUN_PHASE == TRAINING_PHASE) {
-      // training phase
-      cost = SelectCostWithMinTrainingTime(cost_list, dev_memory_);
-    } else {
-      MS_LOG(EXCEPTION) << "Currently, searching strategy for the two-connected-node final graph in the inference "
-                           "phase is not supported.";
-    }
-    if (cost == nullptr) {
-      MS_LOG(ERROR) << "No vaild strategy can be found under the current device memory: " << dev_memory_ << ".";
-      return FAILED;
-    }
-    MS_EXCEPTION_IF_NULL(cost->decision_ptr_);
-    auto decision = cost->decision_ptr_->cast<FinalDecisionPtr>();
-    MS_EXCEPTION_IF_NULL(decision);
-    u->SetSelectedStrategyAndCost(decision->u_strategy_, decision->left_cost_);
-    v->SetSelectedStrategyAndCost(decision->v_strategy_, decision->right_cost_);
-    e->set_selected_cost(decision->middle_cost_);
-    MS_LOG(INFO) << "Searching the strategy for the eliminated final graph ended.";
-    return SUCCESS;
+    return SearchStrategyForTwoNodeFinalGraph(alive_ops);
   }
 }
@@ -867,10 +713,11 @@ void CostGraph::CreateSourceEliminationSubCostList(StrategyPtr op1_old_stra, con
       op1_cost->communication_without_parameter_ + op2_cost->communication_without_parameter_;
     auto decision = std::make_shared<SourceEliminationDecision>(op1_old_stra, op1_cost, op2_old_stra, op2_cost);
     auto new_cost = std::make_shared<Cost>(computation, communication, decision);
+    const auto gamma = CostModelContext::GetInstance()->costmodel_gamma();
     MS_EXCEPTION_IF_NULL(new_cost);
     new_cost->communication_without_parameter_ = communication_without_para;
     new_cost->communication_with_partial_para_ =
-      communication_without_para + COST_MODEL_GAMMA * (communication - communication_without_para);
+      communication_without_para + gamma * (communication - communication_without_para);
     new_cost->memory_with_reuse_ = memory;
     new_cost->communication_forward_ = communication_forward;
     MS_EXCEPTION_IF_NULL(op1_new_clist);
@@ -879,6 +726,65 @@ void CostGraph::CreateSourceEliminationSubCostList(StrategyPtr op1_old_stra, con
   }
 }
+std::pair<std::vector<EdgePtr>, std::vector<EdgePtr>> UpdateEdgesIncidentToNodes(
+  OperatorInfoPtr op1, std::vector<EdgePtr> *op1_old_succ_edges,
+  std::vector<std::map<CostPtrKey, CostPtrList>> *op1_new_edges_cost, std::vector<EdgePtr> *op1_new_succ_edges,
+  OperatorInfoPtr op2, std::vector<EdgePtr> *op2_old_succ_edges,
+  std::vector<std::map<CostPtrKey, CostPtrList>> *op2_new_edges_cost, std::vector<EdgePtr> *op2_new_succ_edges) {
+  for (size_t i = 0; i < op1_old_succ_edges->size(); ++i) {
+    auto &new_cost_map = op1_new_edges_cost->at(i);
+    auto ith_edge = op1_old_succ_edges->at(i);
+    std::string new_edge_name = op1->name() + OPERATOR_TO_OPERATOR_CONNECTOR + ith_edge->next_operator()->name();
+    std::shared_ptr<Edge> new_edge;
+    if (ith_edge->is_combined()) {
+      std::vector<size_t> output_indexs, input_indexs;
+      output_indexs = ith_edge->prev_op_output_indexs();
+      input_indexs = ith_edge->next_op_input_indexs();
+      new_edge =
+        std::make_shared<Edge>(new_edge_name, op1, ith_edge->next_operator(), output_indexs, input_indexs, true);
+    } else {
+      size_t output_index, input_index;
+      output_index = ith_edge->prev_op_output_index();
+      input_index = ith_edge->next_op_input_index();
+      new_edge =
+        std::make_shared<Edge>(new_edge_name, op1, ith_edge->next_operator(), output_index, input_index, false);
+    }
+    new_edge->SetCostMapAndInputOutput(new_cost_map);
+    // replace the old successive edges with the new ones.
+    op1->ReplaceSuccEdge(ith_edge->next_operator(), new_edge);
+    ith_edge->next_operator()->ReplacePreEdge(op1, new_edge);
+    op1_new_succ_edges->erase(op1_new_succ_edges->begin() + i);
+    op1_new_succ_edges->emplace(op1_new_succ_edges->begin() + i, new_edge);
+  }
+  for (size_t i = 0; i < op2_old_succ_edges->size(); ++i) {
+    auto &new_cost_map = op2_new_edges_cost->at(i);
+    auto ith_edge = op2_old_succ_edges->at(i);
+    const auto &destination = ith_edge->next_operator();
+    std::string new_edge_name = op1->name() + OPERATOR_TO_OPERATOR_CONNECTOR + destination->name();
+    std::shared_ptr<Edge> new_edge;
+    if (ith_edge->is_combined()) {
+      std::vector<size_t> output_indexs, input_indexs;
+      output_indexs = ith_edge->prev_op_output_indexs();
+      input_indexs = ith_edge->next_op_input_indexs();
+      new_edge = std::make_shared<Edge>(new_edge_name, op1, destination, output_indexs, input_indexs, true);
+    } else {
+      size_t output_index, input_index;
+      output_index = ith_edge->prev_op_output_index();
+      input_index = ith_edge->next_op_input_index();
+      new_edge = std::make_shared<Edge>(new_edge_name, op1, destination, output_index, input_index, false);
+    }
+    new_edge->SetCostMapAndInputOutput(new_cost_map);
+    // replace the old successive edges with the new ones.
+    destination->ReplacePreEdge(op2, new_edge);
+    op1->AddSuccEdge(new_edge);
+    op2_new_succ_edges->erase(op2_new_succ_edges->begin() + i);
+    op2_new_succ_edges->emplace(op2_new_succ_edges->begin() + i, new_edge);
+  }
+  return std::make_pair(*op1_new_succ_edges, *op2_new_succ_edges);
+}
 std::pair<std::vector<std::shared_ptr<Edge>>, std::vector<std::shared_ptr<Edge>>> CostGraph::EliminationSources(
   OperatorInfoPtr op1, OperatorInfoPtr op2) {
   MS_EXCEPTION_IF_NULL(op1);
@@ -970,57 +876,9 @@ std::pair<std::vector<std::shared_ptr<Edge>>, std::vector<std::shared_ptr<Edge>>
   op2->SetNotAlive();
   // Update the edges incident to op1, and edges incident to op2
-  for (size_t i = 0; i < op1_old_succ_edges.size(); ++i) {
-    auto &new_cost_map = op1_new_edges_cost[i];
-    auto &ith_edge = op1_old_succ_edges[i];
-    std::string new_edge_name = op1->name() + OPERATOR_TO_OPERATOR_CONNECTOR + ith_edge->next_operator()->name();
-    std::shared_ptr<Edge> new_edge;
-    if (ith_edge->is_combined()) {
-      std::vector<size_t> output_indexs, input_indexs;
-      output_indexs = ith_edge->prev_op_output_indexs();
-      input_indexs = ith_edge->next_op_input_indexs();
-      new_edge =
-        std::make_shared<Edge>(new_edge_name, op1, ith_edge->next_operator(), output_indexs, input_indexs, true);
-    } else {
-      size_t output_index, input_index;
-      output_index = ith_edge->prev_op_output_index();
| input_index = ith_edge->next_op_input_index(); | |||||
| new_edge = | |||||
| std::make_shared<Edge>(new_edge_name, op1, ith_edge->next_operator(), output_index, input_index, false); | |||||
| } | |||||
| new_edge->SetCostMapAndInputOutput(new_cost_map); | |||||
| // replace the old successive edges with the new ones. | |||||
| op1->ReplaceSuccEdge(ith_edge->next_operator(), new_edge); | |||||
| ith_edge->next_operator()->ReplacePreEdge(op1, new_edge); | |||||
| op1_new_succ_edges[i] = new_edge; | |||||
| } | |||||
| for (size_t i = 0; i < op2_old_succ_edges.size(); ++i) { | |||||
| auto &new_cost_map = op2_new_edges_cost[i]; | |||||
| auto &ith_edge = op2_old_succ_edges[i]; | |||||
| const auto &destination = ith_edge->next_operator(); | |||||
| std::string new_edge_name = op1->name() + OPERATOR_TO_OPERATOR_CONNECTOR + destination->name(); | |||||
| std::shared_ptr<Edge> new_edge; | |||||
| if (ith_edge->is_combined()) { | |||||
| std::vector<size_t> output_indexs, input_indexs; | |||||
| output_indexs = ith_edge->prev_op_output_indexs(); | |||||
| input_indexs = ith_edge->next_op_input_indexs(); | |||||
| new_edge = std::make_shared<Edge>(new_edge_name, op1, destination, output_indexs, input_indexs, true); | |||||
| } else { | |||||
| size_t output_index, input_index; | |||||
| output_index = ith_edge->prev_op_output_index(); | |||||
| input_index = ith_edge->next_op_input_index(); | |||||
| new_edge = std::make_shared<Edge>(new_edge_name, op1, destination, output_index, input_index, false); | |||||
| } | |||||
| new_edge->SetCostMapAndInputOutput(new_cost_map); | |||||
| // replace the old successive edges with the new ones. | |||||
| destination->ReplacePreEdge(op2, new_edge); | |||||
| op1->AddSuccEdge(new_edge); | |||||
| op2_new_succ_edges[i] = new_edge; | |||||
| } | |||||
| MS_LOG(INFO) << "Source eliminating node: " << op2->name() << " to node: " << op1->name() + " succeeded."; | MS_LOG(INFO) << "Source eliminating node: " << op2->name() << " to node: " << op1->name() + " succeeded."; | ||||
| return {op1_new_succ_edges, op2_new_succ_edges}; | |||||
| return UpdateEdgesIncidentToNodes(op1, &op1_old_succ_edges, &op1_new_edges_cost, &op1_new_succ_edges, op2, | |||||
| &op2_old_succ_edges, &op2_new_edges_cost, &op2_new_succ_edges); | |||||
| } | } | ||||
| // Check the graph whether a TriangleElimination can be performed | // Check the graph whether a TriangleElimination can be performed | ||||
| @@ -1179,10 +1037,11 @@ void CostGraph::CreateMergeEliminationSubCostList(StrategyPtr op_strategy, const | |||||
| auto decision = | auto decision = | ||||
| std::make_shared<MergeEliminationDecision>(op_strategy, op_cost, edge_cost, tar_op_strategy, tar_cost); | std::make_shared<MergeEliminationDecision>(op_strategy, op_cost, edge_cost, tar_op_strategy, tar_cost); | ||||
| auto new_cost = std::make_shared<Cost>(computation, communication, decision); | auto new_cost = std::make_shared<Cost>(computation, communication, decision); | ||||
| const auto gamma = CostModelContext::GetInstance()->costmodel_gamma(); | |||||
| MS_EXCEPTION_IF_NULL(new_cost); | MS_EXCEPTION_IF_NULL(new_cost); | ||||
| new_cost->communication_without_parameter_ = communication_without_para; | new_cost->communication_without_parameter_ = communication_without_para; | ||||
| new_cost->communication_with_partial_para_ = | new_cost->communication_with_partial_para_ = | ||||
| communication_without_para + COST_MODEL_GAMMA * (communication - communication_without_para); | |||||
| communication_without_para + gamma * (communication - communication_without_para); | |||||
| new_cost->memory_with_reuse_ = memory; | new_cost->memory_with_reuse_ = memory; | ||||
| new_cost->communication_forward_ = communication_forward; | new_cost->communication_forward_ = communication_forward; | ||||
| MS_EXCEPTION_IF_NULL(tar_cost_list_new); | MS_EXCEPTION_IF_NULL(tar_cost_list_new); | ||||
| @@ -1263,9 +1122,10 @@ void CostGraph::CreateContractEliminationSubCostList(StrategyPtr contract_op_str | |||||
| auto decision = std::make_shared<ContractEliminationDecision>(contract_op_stra, contract_op_cost, edge_cost, | auto decision = std::make_shared<ContractEliminationDecision>(contract_op_stra, contract_op_cost, edge_cost, | ||||
| target_op_stra, tar_cost); | target_op_stra, tar_cost); | ||||
| auto new_cost = std::make_shared<Cost>(computation, communication, decision); | auto new_cost = std::make_shared<Cost>(computation, communication, decision); | ||||
| auto gamma = CostModelContext::GetInstance()->costmodel_gamma(); | |||||
| new_cost->communication_without_parameter_ = communication_without_para; | new_cost->communication_without_parameter_ = communication_without_para; | ||||
| new_cost->communication_with_partial_para_ = | new_cost->communication_with_partial_para_ = | ||||
| communication_without_para + COST_MODEL_GAMMA * (communication - communication_without_para); | |||||
| communication_without_para + gamma * (communication - communication_without_para); | |||||
| new_cost->memory_with_reuse_ = memory; | new_cost->memory_with_reuse_ = memory; | ||||
| new_cost->communication_forward_ = communication_forward; | new_cost->communication_forward_ = communication_forward; | ||||
| tar_cost_list_new->emplace_back(std::move(new_cost)); | tar_cost_list_new->emplace_back(std::move(new_cost)); | ||||
| @@ -1338,8 +1198,10 @@ void CostGraph::CreateTriangleEliminationSubCostList(StrategyPtr elimi_op_stra, | |||||
| double new_commu_without = | double new_commu_without = | ||||
| elimi_op_cost->communication_without_parameter_ + left_edge_cost->communication_without_parameter_ + | elimi_op_cost->communication_without_parameter_ + left_edge_cost->communication_without_parameter_ + | ||||
| left_node_cost->communication_without_parameter_ + right_edge_cost->communication_without_parameter_; | left_node_cost->communication_without_parameter_ + right_edge_cost->communication_without_parameter_; | ||||
| const auto triangle_star_stra_overwrite = CostModelContext::GetInstance()->triangle_star_strategy_overwrite(); | |||||
| const auto gamma = CostModelContext::GetInstance()->costmodel_gamma(); | |||||
| if (TRIANGLE_STAR_STRATEGY_OVERWRITE) { | |||||
| if (triangle_star_stra_overwrite) { | |||||
| new_computation += right_op_cost->computation_cost_; | new_computation += right_op_cost->computation_cost_; | ||||
| new_memory += right_op_cost->memory_with_reuse_; | new_memory += right_op_cost->memory_with_reuse_; | ||||
| new_commu_cost += right_op_cost->communication_cost_; | new_commu_cost += right_op_cost->communication_cost_; | ||||
| @@ -1352,8 +1214,7 @@ void CostGraph::CreateTriangleEliminationSubCostList(StrategyPtr elimi_op_stra, | |||||
| left_op_stra, left_node_cost, right_op_stra, right_op_cost); | left_op_stra, left_node_cost, right_op_stra, right_op_cost); | ||||
| auto new_cost = std::make_shared<Cost>(new_computation, new_commu_cost, decision); | auto new_cost = std::make_shared<Cost>(new_computation, new_commu_cost, decision); | ||||
| new_cost->communication_without_parameter_ = new_commu_without; | new_cost->communication_without_parameter_ = new_commu_without; | ||||
| new_cost->communication_with_partial_para_ = | |||||
| new_commu_without + COST_MODEL_GAMMA * (new_commu_cost - new_commu_without); | |||||
| new_cost->communication_with_partial_para_ = new_commu_without + gamma * (new_commu_cost - new_commu_without); | |||||
| new_cost->memory_with_reuse_ = new_memory; | new_cost->memory_with_reuse_ = new_memory; | ||||
| new_cost->communication_forward_ = new_commu_forward; | new_cost->communication_forward_ = new_commu_forward; | ||||
| left_node_clist_new->emplace_back(std::move(new_cost)); | left_node_clist_new->emplace_back(std::move(new_cost)); | ||||
| @@ -1463,6 +1324,8 @@ void CostGraph::CreateStarEliminationSubCostList(const StrategyPtr &first_succ_n | |||||
| memory_cost = merged_node_cost->memory_with_reuse_, commu_cost = merged_node_cost->communication_cost_, | memory_cost = merged_node_cost->memory_with_reuse_, commu_cost = merged_node_cost->communication_cost_, | ||||
| commu_without = merged_node_cost->communication_without_parameter_, | commu_without = merged_node_cost->communication_without_parameter_, | ||||
| commu_forward = merged_node_cost->communication_forward_; | commu_forward = merged_node_cost->communication_forward_; | ||||
| const auto triangle_star_stra_overwrite = CostModelContext::GetInstance()->triangle_star_strategy_overwrite(); | |||||
| const auto gamma = CostModelContext::GetInstance()->costmodel_gamma(); | |||||
| for (size_t i = 0; i < succ_nodes_stras.size(); ++i) { | for (size_t i = 0; i < succ_nodes_stras.size(); ++i) { | ||||
| MS_EXCEPTION_IF_NULL(succ_edges_costs[i]); | MS_EXCEPTION_IF_NULL(succ_edges_costs[i]); | ||||
| if (i == 0) { | if (i == 0) { | ||||
| @@ -1478,7 +1341,7 @@ void CostGraph::CreateStarEliminationSubCostList(const StrategyPtr &first_succ_n | |||||
| commu_cost += succ_edges_costs[i]->communication_cost_; | commu_cost += succ_edges_costs[i]->communication_cost_; | ||||
| commu_forward += succ_edges_costs[i]->communication_forward_; | commu_forward += succ_edges_costs[i]->communication_forward_; | ||||
| commu_without += succ_edges_costs[i]->communication_without_parameter_; | commu_without += succ_edges_costs[i]->communication_without_parameter_; | ||||
| if (TRIANGLE_STAR_STRATEGY_OVERWRITE) { | |||||
| if (triangle_star_stra_overwrite) { | |||||
| computation_cost += succ_nodes_costs[i]->computation_cost_; | computation_cost += succ_nodes_costs[i]->computation_cost_; | ||||
| memory_cost += succ_nodes_costs[i]->memory_with_reuse_; | memory_cost += succ_nodes_costs[i]->memory_with_reuse_; | ||||
| commu_cost += succ_nodes_costs[i]->communication_cost_; | commu_cost += succ_nodes_costs[i]->communication_cost_; | ||||
| @@ -1492,7 +1355,7 @@ void CostGraph::CreateStarEliminationSubCostList(const StrategyPtr &first_succ_n | |||||
| succ_nodes_stras, succ_nodes_costs); | succ_nodes_stras, succ_nodes_costs); | ||||
| auto new_cost = std::make_shared<Cost>(computation_cost, commu_cost, decision); | auto new_cost = std::make_shared<Cost>(computation_cost, commu_cost, decision); | ||||
| new_cost->communication_without_parameter_ = commu_without; | new_cost->communication_without_parameter_ = commu_without; | ||||
| new_cost->communication_with_partial_para_ = commu_without + COST_MODEL_GAMMA * (commu_cost - commu_without); | |||||
| new_cost->communication_with_partial_para_ = commu_without + gamma * (commu_cost - commu_without); | |||||
| new_cost->memory_with_reuse_ = memory_cost; | new_cost->memory_with_reuse_ = memory_cost; | ||||
| new_cost->communication_forward_ = commu_forward; | new_cost->communication_forward_ = commu_forward; | ||||
| first_succ_node_clist_new->emplace_back(std::move(new_cost)); | first_succ_node_clist_new->emplace_back(std::move(new_cost)); | ||||
| @@ -1895,7 +1758,8 @@ Status CostGraph::CorrectOpsMemoryCost() { | |||||
| } | } | ||||
| Status CostGraph::CalculateMemoryCost() { | Status CostGraph::CalculateMemoryCost() { | ||||
| if (RUN_PHASE == TRAINING_PHASE) { | |||||
| const auto phase = CostModelContext::GetInstance()->run_phase(); | |||||
| if (phase == TRAINING_PHASE) { | |||||
| // training phase | // training phase | ||||
| if (ComputeOpsAndEdgesParameterInvolved() == SUCCESS) { | if (ComputeOpsAndEdgesParameterInvolved() == SUCCESS) { | ||||
| // Calculate operators' memory usage | // Calculate operators' memory usage | ||||
| @@ -34,32 +34,12 @@ class CostGraph; | |||||
| using CostGraphPtr = std::shared_ptr<CostGraph>; | using CostGraphPtr = std::shared_ptr<CostGraph>; | ||||
| extern CostGraphPtr entire_costgraph; | extern CostGraphPtr entire_costgraph; | ||||
| extern size_t TOTAL_OPS; | extern size_t TOTAL_OPS; | ||||
| extern double COST_MODEL_GAMMA; | |||||
| extern bool COST_MODEL_SIMPLIFY_CALCULATION; | |||||
| extern double DEVICE_MEMORY_CAPACITY; | |||||
| extern double COST_MODEL_COMMUNI_THRESHOLD; | |||||
| extern double COST_MODEL_COMMUNI_CONST; | |||||
| extern double COST_MODEL_COMMUNI_BIAS; | |||||
| extern bool TENSOR_SLICE_ALIGNMENT_ENABLE; | |||||
| extern size_t TENSOR_SLICE_ALIGNMENT_SIZE; | |||||
| extern bool FULLY_USE_DEVICES; | |||||
| extern bool ELEMENTWISE_OP_STRA_FOLLOW; | |||||
| extern bool MULTI_SUBGRAPHS; | |||||
| extern bool DP_ALGO_ENABLE_APPROX; | |||||
| extern double DP_ALGO_APPROX_EPSILON; | |||||
| extern int64_t RUN_PHASE; | |||||
| extern bool TRIANGLE_STAR_STRATEGY_OVERWRITE; | |||||
| extern bool DP_ALGO_SINGLE_LOOP; | |||||
| class CostGraph { | class CostGraph { | ||||
| // 'CostGraph' consists of Operators and edges between them. An edge is created between two Operators if they have | // 'CostGraph' consists of Operators and edges between them. An edge is created between two Operators if they have | ||||
| // output-input dependency relationship. | // output-input dependency relationship. | ||||
| public: | public: | ||||
| CostGraph() { | |||||
| dev_memory_ = DEFAULT_DEVICE_MEMORY_CAPACITY; | |||||
| costmodel_alpha_ = DEFAULT_COST_MODEL_ALPHA; | |||||
| costmodel_beta_ = DEFAULT_COST_MODEL_BETA_ASCEND; | |||||
| } | |||||
| CostGraph() {} | |||||
| ~CostGraph() = default; | ~CostGraph() = default; | ||||
| void Init(); | void Init(); | ||||
| void AddOperator(const OperatorInfoPtr &op) { ops_.push_back(op); } | void AddOperator(const OperatorInfoPtr &op) { ops_.push_back(op); } | ||||
| @@ -79,8 +59,6 @@ class CostGraph { | |||||
| // An edge is uniquely identified by its name, and its output index and input index. | // An edge is uniquely identified by its name, and its output index and input index. | ||||
| bool IsEdgeInCostGraph(const std::string &, size_t, size_t); | bool IsEdgeInCostGraph(const std::string &, size_t, size_t); | ||||
| void SetDeviceMemoryAndCostParameter(); | |||||
| std::vector<std::shared_ptr<CostGraph>> ConstructConnectedComponents(std::vector<OperatorInfoPtr>); | std::vector<std::shared_ptr<CostGraph>> ConstructConnectedComponents(std::vector<OperatorInfoPtr>); | ||||
| void DFS(const OperatorInfoPtr ¤t_op, std::map<OperatorInfoPtr, bool> *visited, | void DFS(const OperatorInfoPtr ¤t_op, std::map<OperatorInfoPtr, bool> *visited, | ||||
| const std::shared_ptr<CostGraph> &component); | const std::shared_ptr<CostGraph> &component); | ||||
| @@ -91,10 +69,10 @@ class CostGraph { | |||||
| CostPtr SelectCostWithMinTrainingTime(const CostPtrList &cost_list, double memory); | CostPtr SelectCostWithMinTrainingTime(const CostPtrList &cost_list, double memory); | ||||
| CostPtrList SelectCostListWithMinTrainingTimeMultiple(const std::vector<CostPtrList> &all_costlist, double memory); | CostPtrList SelectCostListWithMinTrainingTimeMultiple(const std::vector<CostPtrList> &all_costlist, double memory); | ||||
| Status SearchStrategyForMultiNodeFinalGraph(const std::vector<OperatorInfoPtr> &); | Status SearchStrategyForMultiNodeFinalGraph(const std::vector<OperatorInfoPtr> &); | ||||
| Status SearchStrategyForTwoNodeFinalGraph(const std::vector<OperatorInfoPtr> &); | |||||
| std::vector<std::shared_ptr<Edge>> GetOriginalEdgeBetweenOperators(OperatorInfoPtr u_node, OperatorInfoPtr v_node) { | std::vector<std::shared_ptr<Edge>> GetOriginalEdgeBetweenOperators(OperatorInfoPtr u_node, OperatorInfoPtr v_node) { | ||||
| return edges_[{u_node, v_node}]; | return edges_[{u_node, v_node}]; | ||||
| } | } | ||||
| double GetDeviceMemory() const { return dev_memory_; } | |||||
| // Search the cost_list in the final graph, and determine the optimal one | // Search the cost_list in the final graph, and determine the optimal one | ||||
| Status SearchStrategy(); | Status SearchStrategy(); | ||||
| @@ -194,7 +172,7 @@ class CostGraph { | |||||
| Status InitReshapeStrategy(); | Status InitReshapeStrategy(); | ||||
| Status InitSelectedStrategy(); | Status InitSelectedStrategy(); | ||||
| OperatorInfoPtr FindTmpIdentityByParameterName(std::string &) const; | OperatorInfoPtr FindTmpIdentityByParameterName(std::string &) const; | ||||
| // When TmpIdentity is used by mulitple operators, the corresponding parameter's memory cost should be calculated only | |||||
| // When TmpIdentity is used by multiple operators, the corresponding parameter's memory cost should be calculated only | |||||
| // once (instead of multiple times), this method is used to correct this. | // once (instead of multiple times), this method is used to correct this. | ||||
| Status CorrectOpsMemoryCost(); | Status CorrectOpsMemoryCost(); | ||||
| // When APPROXIMATION is enabled in the DP algorithm, some edges may have no valid strategies. | // When APPROXIMATION is enabled in the DP algorithm, some edges may have no valid strategies. | ||||
| @@ -224,9 +202,6 @@ class CostGraph { | |||||
| // Needed by rec_parser | // Needed by rec_parser | ||||
| std::vector<std::vector<std::string>> inputs_tensor_name_list_; | std::vector<std::vector<std::string>> inputs_tensor_name_list_; | ||||
| std::map<std::string, std::string> tuple_getitem_list_; | std::map<std::string, std::string> tuple_getitem_list_; | ||||
| double dev_memory_; | |||||
| double costmodel_alpha_; | |||||
| double costmodel_beta_; | |||||
| std::vector<OperatorInfoPtr> ops_; | std::vector<OperatorInfoPtr> ops_; | ||||
| std::map<std::pair<OperatorInfoPtr, OperatorInfoPtr>, std::vector<EdgePtr>> edges_; | std::map<std::pair<OperatorInfoPtr, OperatorInfoPtr>, std::vector<EdgePtr>> edges_; | ||||
| std::vector<std::shared_ptr<CostGraph>> connected_compoents_; | std::vector<std::shared_ptr<CostGraph>> connected_compoents_; | ||||
| @@ -47,7 +47,7 @@ void CostModelContext::ResetCostModel() { | |||||
| costmodel_communi_const_ = DEFAULT_COST_MODEL_COMMUNI_CONST; | costmodel_communi_const_ = DEFAULT_COST_MODEL_COMMUNI_CONST; | ||||
| costmodel_communi_bias_ = DEFAULT_COST_MODEL_COMMUNI_BIAS; | costmodel_communi_bias_ = DEFAULT_COST_MODEL_COMMUNI_BIAS; | ||||
| is_multi_subgraphs_ = DEFAULT_IS_MULTI_SUBGRAPHS; | is_multi_subgraphs_ = DEFAULT_IS_MULTI_SUBGRAPHS; | ||||
| run_phase_ = DEFAULT_RUN_PHASE; | |||||
| run_phase_ = TRAINING_PHASE; | |||||
| costmodel_allreduce_fusion_algorithm_ = DEFAULT_COST_MODEL_ALLREDUCE_FUSION_ALGORITHM; | costmodel_allreduce_fusion_algorithm_ = DEFAULT_COST_MODEL_ALLREDUCE_FUSION_ALGORITHM; | ||||
| costmodel_allreduce_fusion_times_ = DEFAULT_COST_MODEL_ALLREDUCE_FUSION_TIMES; | costmodel_allreduce_fusion_times_ = DEFAULT_COST_MODEL_ALLREDUCE_FUSION_TIMES; | ||||
| costmodel_allreduce_fusion_tail_percent_ = DEFAULT_COST_MODEL_ALLREDUCE_FUSION_TAIL_PERCENT; | costmodel_allreduce_fusion_tail_percent_ = DEFAULT_COST_MODEL_ALLREDUCE_FUSION_TAIL_PERCENT; | ||||
| @@ -70,37 +70,115 @@ void CostModelContext::ResetAlgoParameters() { | |||||
| dp_algo_approxi_epsilon_ = DEFAULT_DP_ALGO_APPROX_EPSILON; | dp_algo_approxi_epsilon_ = DEFAULT_DP_ALGO_APPROX_EPSILON; | ||||
| } | } | ||||
| void CostModelContext::PrintCostModel() { | |||||
| MS_LOG(INFO) << "device_memory_capacity: " << device_memory_capacity_ << "."; | |||||
| MS_LOG(INFO) << "costmodel_alpha: " << costmodel_alpha_ << "."; | |||||
| MS_LOG(INFO) << "costmodel_beta: " << costmodel_beta_ << "."; | |||||
| MS_LOG(INFO) << "costmodel_gamma: " << costmodel_gamma_ << "."; | |||||
| MS_LOG(INFO) << "costmodel_simplify_cal: " << costmodel_simplify_cal_ << "."; | |||||
| MS_LOG(INFO) << "costmodel_communi_threshold: " << costmodel_communi_threshold_ << "."; | |||||
| MS_LOG(INFO) << "costmodel_communi_const: " << costmodel_communi_const_ << "."; | |||||
| MS_LOG(INFO) << "costmodel_communi_bias: " << costmodel_communi_bias_ << "."; | |||||
| MS_LOG(INFO) << "is_multi_subgraphs: " << is_multi_subgraphs_ << "."; | |||||
| MS_LOG(INFO) << "triangle_star_strategy_overwrite: " << triangle_star_strategy_overwrite_ << "."; | |||||
| MS_LOG(INFO) << "dp_algo_enable_approxi: " << dp_algo_enable_approxi_ << "."; | |||||
| MS_LOG(INFO) << "dp_algo_approxi_epsilon: " << dp_algo_approxi_epsilon_ << "."; | |||||
| MS_LOG(INFO) << "dp_algo_single_loop: " << dp_algo_single_loop_ << "."; | |||||
| MS_LOG(INFO) << "run_phase: " << run_phase_ << "."; | |||||
| MS_LOG(INFO) << "tensor_slice_alignment_enable: " << tensor_slice_alignment_enable_ << "."; | |||||
| MS_LOG(INFO) << "tensor_slice_align_size: " << tensor_slice_alignment_size_ << "."; | |||||
| MS_LOG(INFO) << "fully_use_device: " << fully_use_device_ << "."; | |||||
| MS_LOG(INFO) << "elementwise_stra_follow: " << elementwise_stra_follow_ << "."; | |||||
| } | |||||
| void CostModelContext::set_costmodel_context_for_device(const std::string &device_target) { | void CostModelContext::set_costmodel_context_for_device(const std::string &device_target) { | ||||
| if (device_target == kGPUDevice) { | if (device_target == kGPUDevice) { | ||||
| costmodel_beta_ = DEFAULT_COST_MODEL_BETA_GPU; | costmodel_beta_ = DEFAULT_COST_MODEL_BETA_GPU; | ||||
| } | } | ||||
| } | } | ||||
| void CostModelContext::set_dp_algo_approxi_epsilon(double epsilon) { dp_algo_approxi_epsilon_ = epsilon; } | |||||
| void CostModelContext::set_dp_algo_approxi_epsilon(double epsilon) { | |||||
| if (epsilon <= 0 || epsilon > 1) { | |||||
| MS_LOG(EXCEPTION) << "'epsilon' must be in (0, 1]"; | |||||
| } | |||||
| dp_algo_approxi_epsilon_ = epsilon; | |||||
| } | |||||
| void CostModelContext::set_dp_algo_enable_approxi(bool approxi) { dp_algo_enable_approxi_ = approxi; } | |||||
| void CostModelContext::set_dp_algo_enable_approxi(bool approxi) { | |||||
| if (approxi) { | |||||
| MS_LOG(INFO) << "dp_algo_enable_approx: true."; | |||||
| } else { | |||||
| MS_LOG(INFO) << "dp_algo_enable_approx: false."; | |||||
| } | |||||
| dp_algo_enable_approxi_ = approxi; | |||||
| } | |||||
| void CostModelContext::set_device_memory_capacity(double dm_capacity) { device_memory_capacity_ = dm_capacity; } | |||||
| void CostModelContext::set_device_memory_capacity(double dm_capacity) { | |||||
| if (dm_capacity <= 0) { | |||||
| MS_LOG(EXCEPTION) << "'device_memory_capacity' must be positive."; | |||||
| } | |||||
| device_memory_capacity_ = dm_capacity; | |||||
| } | |||||
| void CostModelContext::set_costmodel_alpha(double cm_alpha) { costmodel_alpha_ = cm_alpha; } | |||||
| void CostModelContext::set_costmodel_alpha(double cm_alpha) { | |||||
| if (cm_alpha <= 0) { | |||||
| MS_LOG(EXCEPTION) << "'costmodel_alpha' must be positive."; | |||||
| } | |||||
| costmodel_alpha_ = cm_alpha; | |||||
| } | |||||
| void CostModelContext::set_costmodel_beta(double cm_beta) { costmodel_beta_ = cm_beta; } | |||||
| void CostModelContext::set_costmodel_beta(double cm_beta) { | |||||
| if (cm_beta <= 0) { | |||||
| MS_LOG(EXCEPTION) << "'costmodel_beta' must be positive."; | |||||
| } | |||||
| costmodel_beta_ = cm_beta; | |||||
| } | |||||
| void CostModelContext::set_costmodel_gamma(double cm_gamma) { costmodel_gamma_ = cm_gamma; } | |||||
| void CostModelContext::set_costmodel_gamma(double cm_gamma) { | |||||
| if ((cm_gamma < 0) || (cm_gamma > 1)) { | |||||
| MS_LOG(EXCEPTION) << "'costmodel_gamma' must in [0, 1]."; | |||||
| } | |||||
| costmodel_gamma_ = cm_gamma; | |||||
| } | |||||
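A hedged usage sketch of the validated setter and singleton accessor introduced here; the header path is an assumption, and the out-of-range call is only illustrative of the new range check:

#include "frontend/parallel/costmodel_context.h"  // assumed header location for CostModelContext

void ConfigureGammaSketch() {
  auto ctx = mindspore::parallel::CostModelContext::GetInstance();
  ctx->set_costmodel_gamma(0.1);                // accepted: gamma must lie in [0, 1]
  const double gamma = ctx->costmodel_gamma();  // call sites now read through the singleton
  (void)gamma;
  // ctx->set_costmodel_gamma(1.5);             // would raise via MS_LOG(EXCEPTION): outside [0, 1]
}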
| void CostModelContext::set_costmodel_simplify_cal(bool cm_simplify) { costmodel_simplify_cal_ = cm_simplify; } | |||||
| void CostModelContext::set_costmodel_simplify_cal(bool cm_simplify) { | |||||
| if (cm_simplify) { | |||||
| MS_LOG(INFO) << "costmodel_simplify_cal: true."; | |||||
| } else { | |||||
| MS_LOG(INFO) << "costmodel_simplify_cal: false."; | |||||
| } | |||||
| costmodel_simplify_cal_ = cm_simplify; | |||||
| } | |||||
| void CostModelContext::set_costmodel_communi_threshold(double cm_communi_th) { | void CostModelContext::set_costmodel_communi_threshold(double cm_communi_th) { | ||||
| if (cm_communi_th < 0) { | |||||
| MS_LOG(EXCEPTION) << "'costmodel_communi_threshold' must be non-zero."; | |||||
| } | |||||
| costmodel_communi_threshold_ = cm_communi_th; | costmodel_communi_threshold_ = cm_communi_th; | ||||
| } | } | ||||
| void CostModelContext::set_costmodel_communi_const(double cm_communi_const) { | void CostModelContext::set_costmodel_communi_const(double cm_communi_const) { | ||||
| if (cm_communi_const < 0) { | |||||
| MS_LOG(EXCEPTION) << "'costmodel_communi_const' must be non-zero."; | |||||
| } | |||||
| costmodel_communi_const_ = cm_communi_const; | costmodel_communi_const_ = cm_communi_const; | ||||
| } | } | ||||
| void CostModelContext::set_costmodel_communi_bias(double cm_communi_bias) { costmodel_communi_bias_ = cm_communi_bias; } | |||||
| void CostModelContext::set_costmodel_communi_bias(double cm_communi_bias) { | |||||
| if (cm_communi_bias < 0) { | |||||
| MS_LOG(EXCEPTION) << "'costmodel_communi_bias' must be non-zero."; | |||||
| } | |||||
| costmodel_communi_bias_ = cm_communi_bias; | |||||
| } | |||||
| void CostModelContext::set_multi_subgraphs(bool multi_graphs) { is_multi_subgraphs_ = multi_graphs; } | |||||
| void CostModelContext::set_multi_subgraphs(bool multi_graphs) { | |||||
| if (multi_graphs) { | |||||
| MS_LOG(INFO) << "multi_subgraphs: true."; | |||||
| } else { | |||||
| MS_LOG(INFO) << "multi_subgraphs: false."; | |||||
| } | |||||
| is_multi_subgraphs_ = multi_graphs; | |||||
| } | |||||
| void CostModelContext::set_costmodel_allreduce_fusion_algorithm(int64_t algorithm) { | void CostModelContext::set_costmodel_allreduce_fusion_algorithm(int64_t algorithm) { | ||||
| costmodel_allreduce_fusion_algorithm_ = algorithm; | costmodel_allreduce_fusion_algorithm_ = algorithm; | ||||
| } | } | ||||
| @@ -129,25 +207,64 @@ void CostModelContext::set_costmodel_allreduce_fusion_computation_time_parameter | |||||
| costmodel_allreduce_fusion_computation_time_parameter_ = computation_time_parameter; | costmodel_allreduce_fusion_computation_time_parameter_ = computation_time_parameter; | ||||
| } | } | ||||
| void CostModelContext::set_tensor_slice_alignment_enable(bool ts_align) { tensor_slice_alignment_enable_ = ts_align; } | |||||
| void CostModelContext::set_tensor_slice_alignment_enable(bool ts_align) { | |||||
| if (ts_align) { | |||||
| MS_LOG(INFO) << "tensor_slice_align_enable: true."; | |||||
| } else { | |||||
| MS_LOG(INFO) << "tensor_slice_align_enable: false."; | |||||
| } | |||||
| tensor_slice_alignment_enable_ = ts_align; | |||||
| } | |||||
| void CostModelContext::set_tensor_slice_alignment_size(size_t ts_align_size) { | void CostModelContext::set_tensor_slice_alignment_size(size_t ts_align_size) { | ||||
| if (ts_align_size == 0) { | |||||
| MS_LOG(EXCEPTION) << "'tensor_slice_align_size' must be positive."; | |||||
| } | |||||
| tensor_slice_alignment_size_ = ts_align_size; | tensor_slice_alignment_size_ = ts_align_size; | ||||
| } | } | ||||
| void CostModelContext::set_fully_use_device(bool fully_use) { fully_use_device_ = fully_use; } | |||||
| void CostModelContext::set_fully_use_device(bool fully_use) { | |||||
| if (fully_use) { | |||||
| MS_LOG(INFO) << "fully_use_devices: true."; | |||||
| } else { | |||||
| MS_LOG(INFO) << "fully_use_devices: false."; | |||||
| } | |||||
| fully_use_device_ = fully_use; | |||||
| } | |||||
| void CostModelContext::set_elementwise_stra_follow(bool elementwise_follow) { | void CostModelContext::set_elementwise_stra_follow(bool elementwise_follow) { | ||||
| if (elementwise_follow) { | |||||
| MS_LOG(INFO) << "elementwise_op_strategy_follow: true."; | |||||
| } else { | |||||
| MS_LOG(INFO) << "elementwise_op_strategy_follow: false."; | |||||
| } | |||||
| elementwise_stra_follow_ = elementwise_follow; | elementwise_stra_follow_ = elementwise_follow; | ||||
| } | } | ||||
| void CostModelContext::set_triangle_star_strategy_overwrite(bool overwrite) { | void CostModelContext::set_triangle_star_strategy_overwrite(bool overwrite) { | ||||
| if (overwrite) { | |||||
| MS_LOG(INFO) << "triangle_star_strategy_overwrite: true."; | |||||
| } else { | |||||
| MS_LOG(INFO) << "triangle_star_strategy_overwrite: false."; | |||||
| } | |||||
| triangle_star_strategy_overwrite_ = overwrite; | triangle_star_strategy_overwrite_ = overwrite; | ||||
| } | } | ||||
| void CostModelContext::set_run_phase(int64_t phase) { run_phase_ = phase; } | |||||
| void CostModelContext::set_run_phase(int64_t phase) { | |||||
| if (phase != 0 && phase != 1) { | |||||
| MS_LOG(EXCEPTION) << "'run_phase' must be in {0, 1}"; | |||||
| } | |||||
| run_phase_ = phase; | |||||
| } | |||||
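For reference, a short sketch of the run_phase contract this setter now enforces, matching the TRAINING_PHASE (0) and INFERENCE_PHASE (1) defines kept in costmodel_context.h; the calls are illustrative only and assume the same header as the sketch above:

void SelectRunPhaseSketch() {
  auto ctx = mindspore::parallel::CostModelContext::GetInstance();
  ctx->set_run_phase(1);  // INFERENCE_PHASE; 0 (TRAINING_PHASE) remains the reset default
  if (ctx->run_phase() == 0) {
    // training-phase branches, e.g. CalculateMemoryCost() above, take this path
  }
  // ctx->set_run_phase(2);  // would raise: 'run_phase' must be in {0, 1}
}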
| void CostModelContext::set_dp_algo_single_loop(bool single_loop) { dp_algo_single_loop_ = single_loop; } | |||||
| void CostModelContext::set_dp_algo_single_loop(bool single_loop) { | |||||
| if (single_loop) { | |||||
| MS_LOG(INFO) << "dp_algo_single_loop: true."; | |||||
| } else { | |||||
| MS_LOG(INFO) << "dp_algo_single_loop: false."; | |||||
| } | |||||
| dp_algo_single_loop_ = single_loop; | |||||
| } | |||||
| struct CostRegister { | struct CostRegister { | ||||
| CostRegister() { | CostRegister() { | ||||
| @@ -41,7 +41,6 @@ namespace parallel { | |||||
| #define DEFAULT_FULLY_USE_DEVICES true | #define DEFAULT_FULLY_USE_DEVICES true | ||||
| #define DEFAULT_ELEMENTWISE_OP_STRA_FOLLOW false | #define DEFAULT_ELEMENTWISE_OP_STRA_FOLLOW false | ||||
| #define DEFAULT_IS_MULTI_SUBGRAPHS false | #define DEFAULT_IS_MULTI_SUBGRAPHS false | ||||
| #define DEFAULT_RUN_PHASE 0 | |||||
| #define TRAINING_PHASE 0 | #define TRAINING_PHASE 0 | ||||
| #define INFERENCE_PHASE 1 | #define INFERENCE_PHASE 1 | ||||
| #define DEFAULT_TRIANGLE_STAR_STRATEGY_OVERWRITE true; | #define DEFAULT_TRIANGLE_STAR_STRATEGY_OVERWRITE true; | ||||
| @@ -58,6 +57,7 @@ class CostModelContext { | |||||
| void ResetAlgoParameters(); | void ResetAlgoParameters(); | ||||
| static std::shared_ptr<CostModelContext> GetInstance(); | static std::shared_ptr<CostModelContext> GetInstance(); | ||||
| void PrintCostModel(); | |||||
| void set_costmodel_context_for_device(const std::string &); | void set_costmodel_context_for_device(const std::string &); | ||||
| // DEVICE_MEMORY_CAPACITY | // DEVICE_MEMORY_CAPACITY | ||||
| @@ -475,7 +475,8 @@ Status MatMulBase::PrepareStrategy(int64_t stage_id, size_t dev_num, | |||||
| size_t input1_shape_size, mindspore::parallel::StrategyPtr *const sp) { | size_t input1_shape_size, mindspore::parallel::StrategyPtr *const sp) { | ||||
| int64_t product = | int64_t product = | ||||
| std::accumulate(combined_partitions.begin(), combined_partitions.end(), 1, std::multiplies<int64_t>()); | std::accumulate(combined_partitions.begin(), combined_partitions.end(), 1, std::multiplies<int64_t>()); | ||||
| if (!FULLY_USE_DEVICES) { | |||||
| const auto fully_use_device = CostModelContext::GetInstance()->fully_use_device(); | |||||
| if (!fully_use_device) { | |||||
| if (LongToSize(product) > dev_num) { | if (LongToSize(product) > dev_num) { | ||||
| return FAILED; | return FAILED; | ||||
| } | } | ||||
| @@ -564,7 +565,9 @@ void MatMulBase::InitTensorInfoForCost(std::vector<TensorInfo> *relica_inputs_te | |||||
| } | } | ||||
| Status MatMulBase::CheckForTensorSliceValid() const { | Status MatMulBase::CheckForTensorSliceValid() const { | ||||
| if (!TENSOR_SLICE_ALIGNMENT_ENABLE) { | |||||
| const auto align_enable = CostModelContext::GetInstance()->tensor_slice_alignment_enable(); | |||||
| const auto align_size = CostModelContext::GetInstance()->tensor_slice_alignment_size(); | |||||
| if (!align_enable) { | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| if (inputs_tensor_info_.empty()) { | if (inputs_tensor_info_.empty()) { | ||||
| @@ -572,8 +575,8 @@ Status MatMulBase::CheckForTensorSliceValid() const { | |||||
| } | } | ||||
| for (auto &one_input_tensor : inputs_tensor_info_) { | for (auto &one_input_tensor : inputs_tensor_info_) { | ||||
| auto slice_shape = one_input_tensor.slice_shape(); | auto slice_shape = one_input_tensor.slice_shape(); | ||||
| if ((LongToSize(slice_shape[LAST_INDEX(slice_shape.size())]) % TENSOR_SLICE_ALIGNMENT_SIZE != 0) || | |||||
| (LongToSize(slice_shape[SECOND_FROM_END(slice_shape.size())]) % TENSOR_SLICE_ALIGNMENT_SIZE != 0)) { | |||||
| if ((LongToSize(slice_shape[LAST_INDEX(slice_shape.size())]) % align_size != 0) || | |||||
| (LongToSize(slice_shape[SECOND_FROM_END(slice_shape.size())]) % align_size != 0)) { | |||||
| return FAILED; | return FAILED; | ||||
| } | } | ||||
| } | } | ||||
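The alignment rule checked above can be stated compactly: when tensor-slice alignment is enabled, the last two dimensions of every input slice must be multiples of the configured alignment size. A standalone sketch of that check (function name and values are illustrative):

#include <cstdint>
#include <vector>

bool SliceIsAligned(const std::vector<int64_t> &slice_shape, size_t align_size) {
  if (slice_shape.size() < 2 || align_size == 0) {
    return false;
  }
  const auto last = static_cast<uint64_t>(slice_shape[slice_shape.size() - 1]);
  const auto second_last = static_cast<uint64_t>(slice_shape[slice_shape.size() - 2]);
  return (last % align_size == 0) && (second_last % align_size == 0);
}

// SliceIsAligned({64, 256, 512}, 16) == true; SliceIsAligned({64, 250, 512}, 16) == false.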
| @@ -608,12 +611,12 @@ Status MatMulBase::SetCostUnderStrategy(const mindspore::parallel::StrategyPtr & | |||||
| double computation_cost = | double computation_cost = | ||||
| operator_cost()->GetForwardComputationCost(relica_inputs_tensor_vector, outputs_tensor_info_, stage_id); | operator_cost()->GetForwardComputationCost(relica_inputs_tensor_vector, outputs_tensor_info_, stage_id); | ||||
| double communication_cost = operator_cost()->GetCommCost(relica_inputs_tensor_vector, outputs_tensor_info_, stage_id); | double communication_cost = operator_cost()->GetCommCost(relica_inputs_tensor_vector, outputs_tensor_info_, stage_id); | ||||
| const auto gamma = CostModelContext::GetInstance()->costmodel_gamma(); | |||||
| std::shared_ptr<Cost> result = std::make_shared<Cost>(computation_cost, communication_cost); | std::shared_ptr<Cost> result = std::make_shared<Cost>(computation_cost, communication_cost); | ||||
| result->communication_without_parameter_ = | result->communication_without_parameter_ = | ||||
| operator_cost()->GetForwardCommCost(relica_inputs_tensor_vector, outputs_tensor_info_, stage_id); | operator_cost()->GetForwardCommCost(relica_inputs_tensor_vector, outputs_tensor_info_, stage_id); | ||||
| result->communication_with_partial_para_ = | result->communication_with_partial_para_ = | ||||
| result->communication_without_parameter_ + | |||||
| COST_MODEL_GAMMA * (communication_cost - result->communication_without_parameter_); | |||||
| result->communication_without_parameter_ + gamma * (communication_cost - result->communication_without_parameter_); | |||||
| // Breaking ties for preferring data parallelization | // Breaking ties for preferring data parallelization | ||||
| BreakingTiesForPerferringDataParallel(strategy, result); | BreakingTiesForPerferringDataParallel(strategy, result); | ||||
| @@ -839,7 +839,8 @@ Status PrepareStrategyBase(int64_t stage_id, size_t dev_num, const Shapes &input | |||||
| for (auto &input_partition : inputs_partitions) { | for (auto &input_partition : inputs_partitions) { | ||||
| product *= std::accumulate(input_partition.begin(), input_partition.end(), 1, std::multiplies<int64_t>()); | product *= std::accumulate(input_partition.begin(), input_partition.end(), 1, std::multiplies<int64_t>()); | ||||
| } | } | ||||
| if (!FULLY_USE_DEVICES) { | |||||
| const auto fully_use_device = CostModelContext::GetInstance()->fully_use_device(); | |||||
| if (!fully_use_device) { | |||||
| if (LongToSize(product) > dev_num) { | if (LongToSize(product) > dev_num) { | ||||
| return FAILED; | return FAILED; | ||||
| } | } | ||||
| @@ -1202,12 +1203,12 @@ Status OperatorInfo::SetCostUnderStrategyBase(const StrategyPtr &strategy) { | |||||
| double computation_cost = | double computation_cost = | ||||
| operator_cost()->GetForwardComputationCost(inputs_tensor_info_, outputs_tensor_info_, stage_id); | operator_cost()->GetForwardComputationCost(inputs_tensor_info_, outputs_tensor_info_, stage_id); | ||||
| double communication_cost = operator_cost()->GetCommCost(inputs_tensor_info_, outputs_tensor_info_, stage_id); | double communication_cost = operator_cost()->GetCommCost(inputs_tensor_info_, outputs_tensor_info_, stage_id); | ||||
| const auto gamma = CostModelContext::GetInstance()->costmodel_gamma(); | |||||
| std::shared_ptr<Cost> result = std::make_shared<Cost>(computation_cost, communication_cost); | std::shared_ptr<Cost> result = std::make_shared<Cost>(computation_cost, communication_cost); | ||||
| result->communication_without_parameter_ = | result->communication_without_parameter_ = | ||||
| operator_cost()->GetForwardCommCost(inputs_tensor_info_, outputs_tensor_info_, stage_id); | operator_cost()->GetForwardCommCost(inputs_tensor_info_, outputs_tensor_info_, stage_id); | ||||
| result->communication_with_partial_para_ = | result->communication_with_partial_para_ = | ||||
| result->communication_without_parameter_ + | |||||
| COST_MODEL_GAMMA * (communication_cost - result->communication_without_parameter_); | |||||
| result->communication_without_parameter_ + gamma * (communication_cost - result->communication_without_parameter_); | |||||
| // Breaking ties for preferring data parallelization | // Breaking ties for preferring data parallelization | ||||
| BreakingTiesForPerferringDataParallel(strategy, result); | BreakingTiesForPerferringDataParallel(strategy, result); | ||||
| @@ -396,12 +396,12 @@ void ReshapeInfo::SetCostForReshape(const mindspore::parallel::StrategyPtr &stra | |||||
| double computation_cost = | double computation_cost = | ||||
| operator_cost()->GetForwardComputationCost(inputs_tensor_info_, outputs_tensor_info_, stage_id); | operator_cost()->GetForwardComputationCost(inputs_tensor_info_, outputs_tensor_info_, stage_id); | ||||
| double communication_cost = operator_cost()->GetCommCost(inputs_tensor_info_, outputs_tensor_info_, stage_id); | double communication_cost = operator_cost()->GetCommCost(inputs_tensor_info_, outputs_tensor_info_, stage_id); | ||||
| const auto gamma = CostModelContext::GetInstance()->costmodel_gamma(); | |||||
| std::shared_ptr<Cost> result = std::make_shared<Cost>(computation_cost, communication_cost); | std::shared_ptr<Cost> result = std::make_shared<Cost>(computation_cost, communication_cost); | ||||
| result->communication_without_parameter_ = | result->communication_without_parameter_ = | ||||
| operator_cost()->GetForwardCommCost(inputs_tensor_info_, outputs_tensor_info_, stage_id); | operator_cost()->GetForwardCommCost(inputs_tensor_info_, outputs_tensor_info_, stage_id); | ||||
| result->communication_with_partial_para_ = | result->communication_with_partial_para_ = | ||||
| result->communication_without_parameter_ + | |||||
| COST_MODEL_GAMMA * (communication_cost - result->communication_without_parameter_); | |||||
| result->communication_without_parameter_ + gamma * (communication_cost - result->communication_without_parameter_); | |||||
| // Breaking ties for preferring data parallelization | // Breaking ties for preferring data parallelization | ||||
| BreakingTiesForPerferringDataParallel(strategy, result); | BreakingTiesForPerferringDataParallel(strategy, result); | ||||
| @@ -1,5 +1,5 @@ | |||||
| /** | /** | ||||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||||
| * Copyright 2019-2021 Huawei Technologies Co., Ltd | |||||
| * | * | ||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | * Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| * you may not use this file except in compliance with the License. | * you may not use this file except in compliance with the License. | ||||
| @@ -13,6 +13,7 @@ | |||||
| * See the License for the specific language governing permissions and | * See the License for the specific language governing permissions and | ||||
| * limitations under the License. | * limitations under the License. | ||||
| */ | */ | ||||
| #include "frontend/parallel/parallel_stub/executor_manager_stub.h" | #include "frontend/parallel/parallel_stub/executor_manager_stub.h" | ||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace parallel { | namespace parallel { | ||||
| @@ -26,6 +27,5 @@ std::shared_ptr<Executor> ExecutorManager::GetExecutor(const std::string &device | |||||
| executors_[device_key] = executor; | executors_[device_key] = executor; | ||||
| return executor; | return executor; | ||||
| } | } | ||||
| } // namespace parallel | } // namespace parallel | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -1,5 +1,5 @@ | |||||
| /** | /** | ||||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||||
| * Copyright 2019-2021 Huawei Technologies Co., Ltd | |||||
| * | * | ||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | * Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| * you may not use this file except in compliance with the License. | * you may not use this file except in compliance with the License. | ||||
| @@ -13,6 +13,7 @@ | |||||
| * See the License for the specific language governing permissions and | * See the License for the specific language governing permissions and | ||||
| * limitations under the License. | * limitations under the License. | ||||
| */ | */ | ||||
| #ifndef MINDSPORE_CCSRC_PARALLEL_EXECUTOR_MANAGER_STUB_H_ | #ifndef MINDSPORE_CCSRC_PARALLEL_EXECUTOR_MANAGER_STUB_H_ | ||||
| #define MINDSPORE_CCSRC_PARALLEL_EXECUTOR_MANAGER_STUB_H_ | #define MINDSPORE_CCSRC_PARALLEL_EXECUTOR_MANAGER_STUB_H_ | ||||
| #include <set> | #include <set> | ||||
| @@ -1,5 +1,5 @@ | |||||
| /** | /** | ||||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||||
| * Copyright 2019-2021 Huawei Technologies Co., Ltd | |||||
| * | * | ||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | * Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| * you may not use this file except in compliance with the License. | * you may not use this file except in compliance with the License. | ||||
| @@ -13,6 +13,7 @@ | |||||
| * See the License for the specific language governing permissions and | * See the License for the specific language governing permissions and | ||||
| * limitations under the License. | * limitations under the License. | ||||
| */ | */ | ||||
| #ifndef MINDSPORE_CCSRC_PARALLEL_EXECUTOR_STUB_H | #ifndef MINDSPORE_CCSRC_PARALLEL_EXECUTOR_STUB_H | ||||
| #define MINDSPORE_CCSRC_PARALLEL_EXECUTOR_STUB_H | #define MINDSPORE_CCSRC_PARALLEL_EXECUTOR_STUB_H | ||||
| @@ -52,13 +52,18 @@ static bool IsInWhiteList(const CNodePtr &cnode) { | |||||
| return false; | return false; | ||||
| } | } | ||||
| static void SetGradTag(const AnfNodePtr &node, const FuncGraphManagerPtr &manager) { | |||||
| static void SetGradTag(const AnfNodePtr &node, const FuncGraphManagerPtr &manager, size_t curr_depth) { | |||||
| if (curr_depth > MAX_RECURSIVE_DEPTH) { | |||||
| MS_LOG(WARNING) << "When setting the tags for Grad nodes, exceeded the maximum recursion depth: " | |||||
| << MAX_RECURSIVE_DEPTH; | |||||
| return; | |||||
| } | |||||
| const auto &node_users = manager->node_users()[node]; | const auto &node_users = manager->node_users()[node]; | ||||
| for (auto &user_pair : node_users) { | for (auto &user_pair : node_users) { | ||||
| auto user_node = user_pair.first; | auto user_node = user_pair.first; | ||||
| if (!user_node->grad()) { | if (!user_node->grad()) { | ||||
| user_node->set_grad(true); | user_node->set_grad(true); | ||||
| SetGradTag(user_node, manager); | |||||
| SetGradTag(user_node, manager, ++curr_depth); | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
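The same depth-guard shape is applied to the other recursive helpers touched by this patch (FindPreNodes, FindParameterNextLayout, NodeParameterName). A generic sketch of the pattern, with an assumed stand-in for MAX_RECURSIVE_DEPTH:

#include <cstddef>

constexpr size_t kMaxRecursiveDepth = 100000;  // stand-in; the real limit is MAX_RECURSIVE_DEPTH

// Each recursive helper now takes curr_depth, bails out with a warning once the
// limit is exceeded, and passes an incremented depth to its recursive calls.
bool VisitWithDepthGuard(size_t curr_depth) {
  if (curr_depth > kMaxRecursiveDepth) {
    return false;  // the real code logs MS_LOG(WARNING) here instead of recursing further
  }
  // ... per-node work, then VisitWithDepthGuard(curr_depth + 1) for each user node ...
  return true;
}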
| @@ -69,7 +74,7 @@ void PipelineTransformer::LabelRequiredGradCNode() { | |||||
| if (!ParameterRequireGrad(parameter)) { | if (!ParameterRequireGrad(parameter)) { | ||||
| continue; | continue; | ||||
| } | } | ||||
| SetGradTag(parameter, manager_); | |||||
| SetGradTag(parameter, manager_, 0); | |||||
| } | } | ||||
| } | } | ||||
| @@ -234,7 +234,8 @@ void InitCostGraph() { | |||||
| if (entire_costgraph == nullptr) { | if (entire_costgraph == nullptr) { | ||||
| entire_costgraph = std::make_shared<CostGraph>(); | entire_costgraph = std::make_shared<CostGraph>(); | ||||
| } | } | ||||
| entire_costgraph->SetDeviceMemoryAndCostParameter(); | |||||
| MS_EXCEPTION_IF_NULL(CostModelContext::GetInstance()); | |||||
| CostModelContext::GetInstance()->PrintCostModel(); | |||||
| entire_costgraph->Init(); | entire_costgraph->Init(); | ||||
| } | } | ||||
| @@ -252,10 +253,11 @@ void SetStrategyToOperator(const OperatorInfoPtr &operator_info, const Primitive | |||||
| if (prim->name() == RESHAPE) { | if (prim->name() == RESHAPE) { | ||||
| MS_LOG(EXCEPTION) << "Setting strategy for Reshape goes for nothing!"; | MS_LOG(EXCEPTION) << "Setting strategy for Reshape goes for nothing!"; | ||||
| } | } | ||||
| const auto fully_use_devices = CostModelContext::GetInstance()->fully_use_device(); | |||||
| // Set cost for this configured strategy | // Set cost for this configured strategy | ||||
| if (operator_info->SetCostUnderStrategy(strategyPtr) != SUCCESS) { | if (operator_info->SetCostUnderStrategy(strategyPtr) != SUCCESS) { | ||||
| MS_LOG(EXCEPTION) << "Failure: operator " << prim->name() << " SetCostUnderStrategy failed"; | MS_LOG(EXCEPTION) << "Failure: operator " << prim->name() << " SetCostUnderStrategy failed"; | ||||
| } else if (FULLY_USE_DEVICES) { | |||||
| } else if (fully_use_devices) { | |||||
| // If configured to fully use devices, then checking for the user-specified strategy | // If configured to fully use devices, then checking for the user-specified strategy | ||||
| int64_t used_devices = operator_info->used_devices(); | int64_t used_devices = operator_info->used_devices(); | ||||
| MS_EXCEPTION_IF_NULL(g_device_manager); | MS_EXCEPTION_IF_NULL(g_device_manager); | ||||
| @@ -323,7 +325,7 @@ OperatorInfoPtr CreateTheOperatorInfo(const PrimitivePtr &prim, const CNodePtr & | |||||
| operator_info->set_cnode(cnode); | operator_info->set_cnode(cnode); | ||||
| // key of strategy map | // key of strategy map | ||||
| std::string strategy_key_name = ""; | std::string strategy_key_name = ""; | ||||
| auto param_names = NodeParameterName(cnode); | |||||
| auto param_names = NodeParameterName(cnode, -1, 0); | |||||
| if (!param_names.empty()) { | if (!param_names.empty()) { | ||||
| strategy_key_name = prim->name() + "_" + param_names[0].first; | strategy_key_name = prim->name() + "_" + param_names[0].first; | ||||
| } | } | ||||
| @@ -394,7 +396,8 @@ Status ConstructCostGraphNodesByUniqueId(const std::vector<AnfNodePtr> &all_node | |||||
| if (search_cnode == from_cnode_to_info.end()) { | if (search_cnode == from_cnode_to_info.end()) { | ||||
| size_t loop_index = 0; | size_t loop_index = 0; | ||||
| bool is_in_loop = GetLoopIndexFromCNode(cnode, &loop_index); | bool is_in_loop = GetLoopIndexFromCNode(cnode, &loop_index); | ||||
| if (DP_ALGO_SINGLE_LOOP && is_in_loop && (loop_to_ops[loop_index] < operators_in_forloop.size())) { | |||||
| const auto single_loop = CostModelContext::GetInstance()->dp_algo_single_loop(); | |||||
| if (single_loop && is_in_loop && (loop_to_ops[loop_index] < operators_in_forloop.size())) { | |||||
| const auto ¤t_op_ptr = operators_in_forloop[loop_to_ops[loop_index]]; | const auto ¤t_op_ptr = operators_in_forloop[loop_to_ops[loop_index]]; | ||||
| bool is_find_wrong = (current_op_ptr->name().find(VIRTUAL_DATA_SET_INFO) == std::string::npos) && | bool is_find_wrong = (current_op_ptr->name().find(VIRTUAL_DATA_SET_INFO) == std::string::npos) && | ||||
| (current_op_ptr->name().find(BATCH_PARALLEL) == std::string::npos) && | (current_op_ptr->name().find(BATCH_PARALLEL) == std::string::npos) && | ||||
| @@ -430,7 +433,7 @@ Status ConstructCostGraphNodesByUniqueId(const std::vector<AnfNodePtr> &all_node | |||||
| << ", CNode fullname_with_scope: " << cnode->fullname_with_scope() | << ", CNode fullname_with_scope: " << cnode->fullname_with_scope() | ||||
| << " is set OperatorInfo: " << operator_info->name() << ", Primitive: " << prim->name(); | << " is set OperatorInfo: " << operator_info->name() << ", Primitive: " << prim->name(); | ||||
| (void)from_cnode_to_info.emplace(std::make_pair(cnode->UniqueId(), operator_info)); | (void)from_cnode_to_info.emplace(std::make_pair(cnode->UniqueId(), operator_info)); | ||||
| if (DP_ALGO_SINGLE_LOOP && is_in_loop) { | |||||
| if (single_loop && is_in_loop) { | |||||
| operators_in_forloop.push_back(operator_info); | operators_in_forloop.push_back(operator_info); | ||||
| ops_in_a_loop_.insert(operator_info->name()); | ops_in_a_loop_.insert(operator_info->name()); | ||||
| loop_to_ops[loop_index]++; | loop_to_ops[loop_index]++; | ||||
| @@ -511,7 +514,8 @@ Status ConstructCostGraphNodesByUniqueIdTC(const std::vector<AnfNodePtr> &all_no | |||||
| if (search_cnode == from_cnode_to_info.end()) { | if (search_cnode == from_cnode_to_info.end()) { | ||||
| size_t loop_index = 0; | size_t loop_index = 0; | ||||
| bool is_in_loop = GetLoopIndexFromCNode(cnode, &loop_index); | bool is_in_loop = GetLoopIndexFromCNode(cnode, &loop_index); | ||||
| bool is_op_created = DP_ALGO_SINGLE_LOOP && is_in_loop && (loop_to_ops[loop_index] < operators_in_forloop.size()); | |||||
| const auto single_loop = CostModelContext::GetInstance()->dp_algo_single_loop(); | |||||
| bool is_op_created = single_loop && is_in_loop && (loop_to_ops[loop_index] < operators_in_forloop.size()); | |||||
| if (is_op_created) { | if (is_op_created) { | ||||
| const auto ¤t_op_ptr = operators_in_forloop[loop_to_ops[loop_index]]; | const auto ¤t_op_ptr = operators_in_forloop[loop_to_ops[loop_index]]; | ||||
| bool is_find_wrong = (current_op_ptr->name().find(VIRTUAL_DATA_SET_INFO) == std::string::npos) && | bool is_find_wrong = (current_op_ptr->name().find(VIRTUAL_DATA_SET_INFO) == std::string::npos) && | ||||
| @@ -548,7 +552,7 @@ Status ConstructCostGraphNodesByUniqueIdTC(const std::vector<AnfNodePtr> &all_no | |||||
| << ", CNode fullname_with_scope: " << cnode->fullname_with_scope() | << ", CNode fullname_with_scope: " << cnode->fullname_with_scope() | ||||
| << " is set OperatorInfo: " << operator_info->name() << ", Primitive: " << prim->name(); | << " is set OperatorInfo: " << operator_info->name() << ", Primitive: " << prim->name(); | ||||
| (void)from_cnode_to_info.emplace(std::make_pair(cnode->UniqueIdThroughCopy(), operator_info)); | (void)from_cnode_to_info.emplace(std::make_pair(cnode->UniqueIdThroughCopy(), operator_info)); | ||||
| if (DP_ALGO_SINGLE_LOOP && is_in_loop) { | |||||
| if (single_loop && is_in_loop) { | |||||
| operators_in_forloop.push_back(operator_info); | operators_in_forloop.push_back(operator_info); | ||||
| ops_in_a_loop_.insert(operator_info->name()); | ops_in_a_loop_.insert(operator_info->name()); | ||||
| loop_to_ops[loop_index]++; | loop_to_ops[loop_index]++; | ||||
| @@ -581,8 +585,9 @@ void CreateEdgeBetweenTwoOps(const OperatorInfoPtr &prev_op_info, const Operator | |||||
| MS_LOG(INFO) << "The two operators in two separate for-loops, thus skip the edge."; | MS_LOG(INFO) << "The two operators in two separate for-loops, thus skip the edge."; | ||||
| return; | return; | ||||
| } | } | ||||
| const auto stra_follow = CostModelContext::GetInstance()->elementwise_stra_follow(); | |||||
| bool follow_strategy = (prim->name() == RESHAPE) || (prev_prim->name() == RESHAPE) || | bool follow_strategy = (prim->name() == RESHAPE) || (prev_prim->name() == RESHAPE) || | ||||
| (ELEMENTWISE_OP_STRA_FOLLOW && IsElementWiseOperator(prev_prim->name())); | |||||
| (stra_follow && IsElementWiseOperator(prev_prim->name())); | |||||
| if (follow_strategy) { | if (follow_strategy) { | ||||
| // Redistribution in not allowed on the edge. | // Redistribution in not allowed on the edge. | ||||
| // Elementwise operators have the same strategy as their previous operators. | // Elementwise operators have the same strategy as their previous operators. | ||||
| @@ -1031,7 +1036,7 @@ Status ParallelStrategyRecSearch(const std::vector<AnfNodePtr> &all_nodes, const | |||||
| graph = EliminateGraph(graph, eli_list, index_list); | graph = EliminateGraph(graph, eli_list, index_list); | ||||
| size_t num_device = g_device_manager->DeviceNum(); | size_t num_device = g_device_manager->DeviceNum(); | ||||
| double device_memory = entire_costgraph->GetDeviceMemory(); | |||||
| const auto device_memory = CostModelContext::GetInstance()->device_memory_capacity(); | |||||
| if (PartitionForAllDevices(num_device, device_memory, graph) == SUCCESS) { | if (PartitionForAllDevices(num_device, device_memory, graph) == SUCCESS) { | ||||
| MS_LOG(INFO) << "Partition Success With " << num_device << " devices."; | MS_LOG(INFO) << "Partition Success With " << num_device << " devices."; | ||||
| } else { | } else { | ||||
| @@ -1914,7 +1914,11 @@ void SetVirtualDatasetStrategy(const CNodePtr &node) { | |||||
| } | } | ||||
| // find previous parallel care node's next node. | // find previous parallel care node's next node. | ||||
| bool FindPreNodes(const AnfNodePtr &node, vector<std::string> *unique_ids, vector<size_t> *indexes) { | |||||
| bool FindPreNodes(const AnfNodePtr &node, vector<std::string> *unique_ids, vector<size_t> *indexes, size_t curr_depth) { | |||||
| if (curr_depth > MAX_RECURSIVE_DEPTH) { | |||||
| MS_LOG(WARNING) << "When find the previous node, exceeded the maximum recursion depth: " << MAX_RECURSIVE_DEPTH; | |||||
| return false; | |||||
| } | |||||
| MS_EXCEPTION_IF_NULL(unique_ids); | MS_EXCEPTION_IF_NULL(unique_ids); | ||||
| MS_EXCEPTION_IF_NULL(indexes); | MS_EXCEPTION_IF_NULL(indexes); | ||||
| if (!node->isa<CNode>()) { | if (!node->isa<CNode>()) { | ||||
| @@ -1942,7 +1946,7 @@ bool FindPreNodes(const AnfNodePtr &node, vector<std::string> *unique_ids, vecto | |||||
| find = true; | find = true; | ||||
| continue; | continue; | ||||
| } | } | ||||
| if (FindPreNodes(cnode, unique_ids, indexes)) { | |||||
| if (FindPreNodes(cnode, unique_ids, indexes, ++curr_depth)) { | |||||
| find = true; | find = true; | ||||
| continue; | continue; | ||||
| } | } | ||||
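`FindPreNodes` (and the other recursive helpers further down) now carries an explicit `curr_depth` and bails out with a warning once `MAX_RECURSIVE_DEPTH` is exceeded, so a pathological or cyclic graph can no longer overflow the stack. A standalone sketch of the guard, with simplified stand-in node types and a stand-in depth constant:

```cpp
// Standalone sketch of the recursion-depth guard added above. Node and
// MAX_RECURSIVE_DEPTH are simplified stand-ins; only the guard structure
// mirrors the patch.
#include <cstddef>
#include <iostream>
#include <vector>

constexpr size_t MAX_RECURSIVE_DEPTH = 100000;  // stand-in value

struct Node {
  std::vector<const Node *> inputs;
};

bool ReachesLeaf(const Node *node, size_t curr_depth) {
  if (curr_depth > MAX_RECURSIVE_DEPTH) {
    std::cerr << "exceeded the maximum recursion depth: " << MAX_RECURSIVE_DEPTH << std::endl;
    return false;  // give up instead of overflowing the stack
  }
  if (node->inputs.empty()) {
    return true;  // found what we were looking for (a leaf, in this toy example)
  }
  for (const auto *input : node->inputs) {
    // curr_depth + 1 keeps the caller's counter unchanged between iterations.
    if (ReachesLeaf(input, curr_depth + 1)) {
      return true;
    }
  }
  return false;
}

int main() {
  Node leaf;
  Node root;
  root.inputs.push_back(&leaf);
  std::cout << std::boolalpha << ReachesLeaf(&root, 0) << std::endl;  // true
  return 0;
}
```

Note the sketch passes `curr_depth + 1`, while the patch passes `++curr_depth` inside the loop; the pre-increment also charges siblings visited later in the same loop against the budget, so the bound covers the whole traversal rather than a single path.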
| @@ -1954,7 +1958,7 @@ void FindLastNodesUniqueId(const FuncGraphPtr &root, std::vector<std::string> *u | |||||
| std::vector<size_t> *indexes) { | std::vector<size_t> *indexes) { | ||||
| MS_EXCEPTION_IF_NULL(unique_ids); | MS_EXCEPTION_IF_NULL(unique_ids); | ||||
| CNodePtr cnode = root->get_return(); | CNodePtr cnode = root->get_return(); | ||||
| if (!FindPreNodes(cnode, unique_ids, indexes)) { | |||||
| if (!FindPreNodes(cnode, unique_ids, indexes, 0)) { | |||||
| MS_LOG(WARNING) << "cannot find the last parallel care node in eval graph"; | MS_LOG(WARNING) << "cannot find the last parallel care node in eval graph"; | ||||
| } | } | ||||
| } | } | ||||
| @@ -2044,7 +2048,7 @@ void ExtractInformation(const std::vector<AnfNodePtr> &all_nodes, bool is_traini | |||||
| // load strategy checkpoint | // load strategy checkpoint | ||||
| // key of strategy map | // key of strategy map | ||||
| std::string strategy_key_name = ""; | std::string strategy_key_name = ""; | ||||
| auto param_names = NodeParameterName(cnode); | |||||
| auto param_names = NodeParameterName(cnode, -1, 0); | |||||
| if (!param_names.empty()) { | if (!param_names.empty()) { | ||||
| strategy_key_name = prim->name() + "_" + param_names[0].first; | strategy_key_name = prim->name() + "_" + param_names[0].first; | ||||
| } | } | ||||
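For the strategy-checkpoint lookup, the key is simply the primitive name joined to the first parameter name with an underscore. A trivial illustration with hypothetical names (the real values come from the graph):

```cpp
// Illustration only: how the lookup key above is assembled. The primitive and
// parameter names here are hypothetical.
#include <iostream>
#include <string>

int main() {
  const std::string prim_name = "MatMul";                // prim->name()
  const std::string first_param = "network.fc1.weight";  // param_names[0].first
  const std::string strategy_key_name = prim_name + "_" + first_param;
  std::cout << strategy_key_name << std::endl;  // MatMul_network.fc1.weight
  return 0;
}
```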
| @@ -2151,13 +2155,18 @@ std::shared_ptr<TensorLayout> FindPrevParallelCareNodeLayout(const AnfNodePtr &n | |||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| std::shared_ptr<TensorLayout> FindParameterNextLayout(const AnfNodePtr &node) { | |||||
| std::shared_ptr<TensorLayout> FindParameterNextLayout(const AnfNodePtr &node, size_t curr_depth) { | |||||
| if (curr_depth > MAX_RECURSIVE_DEPTH) { | |||||
| MS_LOG(WARNING) << "When finding the next tensor layout for the parameter, exceeded the maximum recursion depth: " | |||||
| << MAX_RECURSIVE_DEPTH; | |||||
| return nullptr; | |||||
| } | |||||
| FuncGraphManagerPtr manager = node->func_graph()->manager(); | FuncGraphManagerPtr manager = node->func_graph()->manager(); | ||||
| MS_EXCEPTION_IF_NULL(manager); | MS_EXCEPTION_IF_NULL(manager); | ||||
| AnfNodeIndexSet node_set = manager->node_users()[node]; | AnfNodeIndexSet node_set = manager->node_users()[node]; | ||||
| for (auto &node_pair : node_set) { | for (auto &node_pair : node_set) { | ||||
| if (IsPrimitiveCNode(node_pair.first, prim::kPrimLoad)) { | if (IsPrimitiveCNode(node_pair.first, prim::kPrimLoad)) { | ||||
| auto layout_param = FindParameterNextLayout(node_pair.first); | |||||
| auto layout_param = FindParameterNextLayout(node_pair.first, ++curr_depth); | |||||
| if (!layout_param) { | if (!layout_param) { | ||||
| continue; | continue; | ||||
| } | } | ||||
| @@ -2184,7 +2193,7 @@ std::shared_ptr<TensorLayout> FindParameterNextLayout(const AnfNodePtr &node) { | |||||
| std::shared_ptr<TensorLayout> CreateParameterLayout(const AnfNodePtr &node) { | std::shared_ptr<TensorLayout> CreateParameterLayout(const AnfNodePtr &node) { | ||||
| // Create DataParallel tensor layout for parameter (support WideDeep). | // Create DataParallel tensor layout for parameter (support WideDeep). | ||||
| auto next_layout = FindParameterNextLayout(node); | |||||
| auto next_layout = FindParameterNextLayout(node, 0); | |||||
| if (next_layout != nullptr) { | if (next_layout != nullptr) { | ||||
| return next_layout; | return next_layout; | ||||
| } | } | ||||
| @@ -2329,14 +2338,19 @@ void ReshapeInit(const std::vector<AnfNodePtr> &all_nodes) { | |||||
| } | } | ||||
| } | } | ||||
| CNodePtr HandleDependLoss(const CNodePtr &cnode) { | |||||
| CNodePtr HandleDependLoss(const CNodePtr &cnode, size_t curr_depth) { | |||||
| if (curr_depth > MAX_RECURSIVE_DEPTH) { | |||||
| MS_LOG(WARNING) << "When handling the loss node of Depend, exceeded the max recursive depth: " | |||||
| << MAX_RECURSIVE_DEPTH; | |||||
| return nullptr; | |||||
| } | |||||
| // Handle return->depend->loss | // Handle return->depend->loss | ||||
| auto prim = GetValueNode<PrimitivePtr>(cnode->input(0)); | auto prim = GetValueNode<PrimitivePtr>(cnode->input(0)); | ||||
| MS_EXCEPTION_IF_NULL(prim); | MS_EXCEPTION_IF_NULL(prim); | ||||
| if (prim->name() == DEPEND) { | if (prim->name() == DEPEND) { | ||||
| auto depend_before = cnode->input(1)->cast<CNodePtr>(); | auto depend_before = cnode->input(1)->cast<CNodePtr>(); | ||||
| MS_EXCEPTION_IF_NULL(depend_before); | MS_EXCEPTION_IF_NULL(depend_before); | ||||
| return HandleDependLoss(depend_before); | |||||
| return HandleDependLoss(depend_before, ++curr_depth); | |||||
| } | } | ||||
| return cnode; | return cnode; | ||||
| } | } | ||||
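`HandleDependLoss` walks return -> Depend -> ... -> loss and now carries the same depth guard. Because the recursion is effectively a linear chain, an iterative loop expresses the same bound without threading an extra parameter; a standalone sketch with simplified stand-in types (not the real AnfIR classes):

```cpp
// Standalone sketch of an iterative alternative to the recursive Depend
// unwrapping above; CNode here is a simplified stand-in, not the real AnfIR type.
#include <cstddef>
#include <iostream>
#include <memory>
#include <string>

constexpr size_t MAX_RECURSIVE_DEPTH = 100000;  // stand-in value

struct CNode {
  std::string prim_name;
  std::shared_ptr<CNode> first_input;  // plays the role of cnode->input(1)
};
using CNodePtr = std::shared_ptr<CNode>;

CNodePtr UnwrapDepend(CNodePtr cnode) {
  // Follow return -> Depend -> ... -> loss, with the same budget the patch
  // enforces through curr_depth.
  for (size_t depth = 0; cnode != nullptr && cnode->prim_name == "Depend"; ++depth) {
    if (depth > MAX_RECURSIVE_DEPTH) {
      std::cerr << "exceeded the maximum depth while unwrapping Depend" << std::endl;
      return nullptr;
    }
    cnode = cnode->first_input;
  }
  return cnode;
}

int main() {
  auto loss = std::make_shared<CNode>(CNode{"SoftmaxCrossEntropyWithLogits", nullptr});
  auto depend = std::make_shared<CNode>(CNode{"Depend", loss});
  auto unwrapped = UnwrapDepend(depend);
  std::cout << (unwrapped ? unwrapped->prim_name : "null") << std::endl;
  return 0;
}
```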
| @@ -2370,7 +2384,7 @@ LossNodeInfo FindLossCNode(const FuncGraphPtr &func_graph) { | |||||
| pre_cnode = pre_cnode->input(1)->cast<CNodePtr>(); | pre_cnode = pre_cnode->input(1)->cast<CNodePtr>(); | ||||
| MS_EXCEPTION_IF_NULL(pre_cnode); | MS_EXCEPTION_IF_NULL(pre_cnode); | ||||
| } | } | ||||
| pre_cnode = HandleDependLoss(pre_cnode); | |||||
| pre_cnode = HandleDependLoss(pre_cnode, 0); | |||||
| auto current_prim = GetValueNode<PrimitivePtr>(pre_cnode->input(0)); | auto current_prim = GetValueNode<PrimitivePtr>(pre_cnode->input(0)); | ||||
| // notice: the GetNext op has no input | // notice: the GetNext op has no input | ||||
| @@ -2792,7 +2806,12 @@ bool IsCohesiveNode(const CNodePtr &cnode) { | |||||
| IsPrimitiveCNode(cnode, prim::kPrimAllGather) || IsPrimitiveCNode(cnode, prim::kPrimMiniStepAllGather); | IsPrimitiveCNode(cnode, prim::kPrimAllGather) || IsPrimitiveCNode(cnode, prim::kPrimMiniStepAllGather); | ||||
| } | } | ||||
| std::vector<std::pair<std::string, int64_t>> NodeParameterName(const CNodePtr &node, int64_t index) { | |||||
| std::vector<std::pair<std::string, int64_t>> NodeParameterName(const CNodePtr &node, int64_t index, size_t curr_depth) { | |||||
| if (curr_depth > MAX_RECURSIVE_DEPTH) { | |||||
| MS_LOG(WARNING) << "When finding the parameters' name of a operator, exceeded the maximum depth: " | |||||
| << MAX_RECURSIVE_DEPTH; | |||||
| return {}; | |||||
| } | |||||
| std::vector<AnfNodePtr> node_inputs{node->inputs()}; | std::vector<AnfNodePtr> node_inputs{node->inputs()}; | ||||
| std::vector<std::pair<std::string, int64_t>> param_names; | std::vector<std::pair<std::string, int64_t>> param_names; | ||||
| for (int64_t i = 0; i < UlongToLong(node_inputs.size()); ++i) { | for (int64_t i = 0; i < UlongToLong(node_inputs.size()); ++i) { | ||||
| @@ -2809,7 +2828,7 @@ std::vector<std::pair<std::string, int64_t>> NodeParameterName(const CNodePtr &n | |||||
| continue; | continue; | ||||
| } | } | ||||
| if (IsCohesiveNode(cnode) && cnode->inputs().size() >= 1) { | if (IsCohesiveNode(cnode) && cnode->inputs().size() >= 1) { | ||||
| auto input_param_names = NodeParameterName(cnode, idx); | |||||
| auto input_param_names = NodeParameterName(cnode, idx, 0); | |||||
| param_names.insert(param_names.end(), input_param_names.begin(), input_param_names.end()); | param_names.insert(param_names.end(), input_param_names.begin(), input_param_names.end()); | ||||
| } | } | ||||
| } | } | ||||
| @@ -2827,7 +2846,7 @@ void CheckpointStrategy(const std::vector<AnfNodePtr> &all_nodes) { | |||||
| if ((cnode == nullptr) || !IsValueNode<Primitive>(cnode->input(0))) { | if ((cnode == nullptr) || !IsValueNode<Primitive>(cnode->input(0))) { | ||||
| continue; | continue; | ||||
| } | } | ||||
| auto param_names = NodeParameterName(cnode); | |||||
| auto param_names = NodeParameterName(cnode, -1, 0); | |||||
| if (param_names.empty()) { | if (param_names.empty()) { | ||||
| continue; | continue; | ||||
| } | } | ||||
| @@ -2950,13 +2969,17 @@ void InsertShapeOp(const CNodePtr &node, const AnfNodePtr &pre_node, const FuncG | |||||
| InsertNode(op, node, 2, pre_node, root, "shape"); | InsertNode(op, node, 2, pre_node, root, "shape"); | ||||
| } | } | ||||
| static AnfNodePtr FindGrad(const CNodePtr &cnode) { | |||||
| static AnfNodePtr FindGrad(const CNodePtr &cnode, size_t curr_depth) { | |||||
| if (curr_depth > MAX_RECURSIVE_DEPTH) { | |||||
| MS_LOG(WARNING) << "When finding Grad nodes, exceeded the maximum recursion depth: " << MAX_RECURSIVE_DEPTH; | |||||
| return nullptr; | |||||
| } | |||||
| for (auto &node : cnode->inputs()) { | for (auto &node : cnode->inputs()) { | ||||
| if (!node->isa<CNode>()) { | if (!node->isa<CNode>()) { | ||||
| continue; | continue; | ||||
| } | } | ||||
| if (!IsPrimitiveCNode(node, prim::kPrimEnvGetItem)) { | if (!IsPrimitiveCNode(node, prim::kPrimEnvGetItem)) { | ||||
| return FindGrad(node->cast<CNodePtr>()); | |||||
| return FindGrad(node->cast<CNodePtr>(), ++curr_depth); | |||||
| } else { | } else { | ||||
| return node; | return node; | ||||
| } | } | ||||
| @@ -2995,7 +3018,7 @@ void HandleRootReshapeAndSaveStrategy(const std::vector<AnfNodePtr> &all_nodes) | |||||
| continue; | continue; | ||||
| } | } | ||||
| auto root = node->func_graph(); | auto root = node->func_graph(); | ||||
| auto grad_node = FindGrad(cnode); | |||||
| auto grad_node = FindGrad(cnode, 0); | |||||
| if (grad_node) { | if (grad_node) { | ||||
| InsertShapeOp(cnode, grad_node, root); | InsertShapeOp(cnode, grad_node, root); | ||||
| } | } | ||||
| @@ -139,7 +139,7 @@ bool IsLastStage(); | |||||
| void ParallelCommunication(const FuncGraphPtr &root, const std::vector<AnfNodePtr> &all_nodes, | void ParallelCommunication(const FuncGraphPtr &root, const std::vector<AnfNodePtr> &all_nodes, | ||||
| const FuncGraphManagerPtr &manager); | const FuncGraphManagerPtr &manager); | ||||
| std::vector<std::pair<std::string, int64_t>> NodeParameterName(const CNodePtr &node, int64_t index = -1); | |||||
| std::vector<std::pair<std::string, int64_t>> NodeParameterName(const CNodePtr &node, int64_t index, size_t curr_depth); | |||||
| void CheckpointStrategy(const std::vector<AnfNodePtr> &all_nodes); | void CheckpointStrategy(const std::vector<AnfNodePtr> &all_nodes); | ||||
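With the default `index = -1` removed from the declaration, every caller must now spell out both arguments, e.g. `NodeParameterName(cnode, -1, 0)` as updated in ExtractInformation and CheckpointStrategy above; any out-of-tree caller that relied on the old default will need the same change.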
| @@ -153,7 +153,6 @@ class TestDPAlgo : public UT::Common { | |||||
| void TestDPAlgo::SetUp() { | void TestDPAlgo::SetUp() { | ||||
| cost_graph = std::make_shared<CostGraph>(); | cost_graph = std::make_shared<CostGraph>(); | ||||
| cost_graph->SetDeviceMemoryAndCostParameter(); | |||||
| RankList dev_list; | RankList dev_list; | ||||
| for (int32_t i = 0; i < 10; i++) { | for (int32_t i = 0; i < 10; i++) { | ||||
| @@ -52,7 +52,6 @@ class TestCostGraph : public UT::Common { | |||||
| }; | }; | ||||
| void TestCostGraph::SetUp() { | void TestCostGraph::SetUp() { | ||||
| cost_graph.SetDeviceMemoryAndCostParameter(); | |||||
| RankList dev_list; | RankList dev_list; | ||||
| for (int32_t i = 0; i < 10; i++) { | for (int32_t i = 0; i < 10; i++) { | ||||
| @@ -305,7 +304,6 @@ TEST_F(TestCostGraph, test_ConstructConnectedComponents) { | |||||
| TEST_F(TestCostGraph, test_SelectCostListWithMinTrainingTimeMultiple) { | TEST_F(TestCostGraph, test_SelectCostListWithMinTrainingTimeMultiple) { | ||||
| CostGraph entire_cost_graph; | CostGraph entire_cost_graph; | ||||
| entire_cost_graph.SetDeviceMemoryAndCostParameter(); | |||||
| double memory = 1024.0; | double memory = 1024.0; | ||||
| CostPtrList clist_1, clist_2; | CostPtrList clist_1, clist_2; | ||||
| std::vector<CostPtrList> all_list; | std::vector<CostPtrList> all_list; | ||||
| @@ -371,7 +369,8 @@ TEST_F(TestCostGraph, test_CreateFinalCostList_AND_Select) { | |||||
| ASSERT_EQ(edge_m1_m2->InitEdgeCost(), SUCCESS); | ASSERT_EQ(edge_m1_m2->InitEdgeCost(), SUCCESS); | ||||
| cost_graph.AddEdge(matmul1, matmul2, edge_m1_m2); | cost_graph.AddEdge(matmul1, matmul2, edge_m1_m2); | ||||
| auto cost_list = cost_graph.CreateFinalCostList(matmul1, edge_m1_m2, matmul2); | auto cost_list = cost_graph.CreateFinalCostList(matmul1, edge_m1_m2, matmul2); | ||||
| cost_graph.SelectCostWithMinInferenceTime(cost_list, cost_graph.GetDeviceMemory()); | |||||
| const auto device_mem_capacity = CostModelContext::GetInstance()->device_memory_capacity(); | |||||
| cost_graph.SelectCostWithMinInferenceTime(cost_list, device_mem_capacity); | |||||
| } | } | ||||
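The test fixtures drop `SetDeviceMemoryAndCostParameter()` from their setup because the memory budget is no longer a per-cost-graph copy: `SelectCostWithMinInferenceTime` is now fed `device_memory_capacity()` straight from the `CostModelContext` singleton, so the tests pick up the context's default capacity unless a test overrides it through the context (presumably via a corresponding setter, which is not shown in this diff).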
| TEST_F(TestCostGraph, test_EliminationOp) { | TEST_F(TestCostGraph, test_EliminationOp) { | ||||