diff --git a/mindspore/ccsrc/frontend/parallel/auto_parallel/costmodel.cc b/mindspore/ccsrc/frontend/parallel/auto_parallel/costmodel.cc index 531a5cd7f6..713905798d 100644 --- a/mindspore/ccsrc/frontend/parallel/auto_parallel/costmodel.cc +++ b/mindspore/ccsrc/frontend/parallel/auto_parallel/costmodel.cc @@ -23,7 +23,8 @@ namespace mindspore { namespace parallel { void Simplify(CostPtrList *clist_ptrs) { - if (RUN_PHASE == TRAINING_PHASE) { + const auto run_phase = CostModelContext::GetInstance()->run_phase(); + if (run_phase == TRAINING_PHASE) { // training phase SimplifyForDecreasingCommunicationWithPartialPara(clist_ptrs); } else { @@ -35,7 +36,8 @@ void SimplifyForDecreasingCommunicationForward(CostPtrList *clist_ptrs) { // Sort the cost_list with the computation_cost_ increasing, and communication_forward decreasing order. This method // excludes the cost with greater computation_cost_ and greater communication_forward. // E.g. clist_ptrs = {<100, 20>, <200, 10>, <300, 50>}. After this method, clist_ptrs = {<200, 10>, <100, 20>} - if (!COST_MODEL_SIMPLIFY_CALCULATION) { + const auto simplify_cal = CostModelContext::GetInstance()->costmodel_simplify_cal(); + if (!simplify_cal) { return; } MS_EXCEPTION_IF_NULL(clist_ptrs); @@ -57,7 +59,8 @@ void SimplifyForDecreasingCommunicationForward(CostPtrList *clist_ptrs) { void SimplifyForDecreasingCommunicationWithPartialPara(CostPtrList *clist_ptrs) { // Sort the cost_list with the computation_cost_ increasing, and communication_with_partial_para_cost decreasing // order. This method excludes the cost with greater computation_cost_ and greater communication_without_para_cost. - if (!COST_MODEL_SIMPLIFY_CALCULATION) { + const auto simplify_cal = CostModelContext::GetInstance()->costmodel_simplify_cal(); + if (!simplify_cal) { return; } MS_EXCEPTION_IF_NULL(clist_ptrs); @@ -78,19 +81,23 @@ void SimplifyForDecreasingCommunicationWithPartialPara(CostPtrList *clist_ptrs) void RefineForPracticalCost(const CostPtr &origin_cost, bool is_redistribution) { MS_EXCEPTION_IF_NULL(origin_cost); + const auto comm_threshold = CostModelContext::GetInstance()->costmodel_communi_threshold(); + const auto comm_const = CostModelContext::GetInstance()->costmodel_communi_const(); + const auto comm_bias = CostModelContext::GetInstance()->costmodel_communi_bias(); + const auto gamma = CostModelContext::GetInstance()->costmodel_gamma(); if (is_redistribution) { // Redistribution cost if ((origin_cost->communication_redis_forward_ > EPS) && - (origin_cost->communication_redis_forward_ <= COST_MODEL_COMMUNI_THRESHOLD)) { - origin_cost->communication_redis_forward_ = COST_MODEL_COMMUNI_CONST; - } else if (origin_cost->communication_redis_forward_ > COST_MODEL_COMMUNI_THRESHOLD) { - origin_cost->communication_redis_forward_ += COST_MODEL_COMMUNI_BIAS; + (origin_cost->communication_redis_forward_ <= comm_threshold)) { + origin_cost->communication_redis_forward_ = comm_const; + } else if (origin_cost->communication_redis_forward_ > comm_threshold) { + origin_cost->communication_redis_forward_ += comm_bias; } if ((origin_cost->communication_redis_backward_ > EPS) && - (origin_cost->communication_redis_backward_ <= COST_MODEL_COMMUNI_THRESHOLD)) { - origin_cost->communication_redis_backward_ = COST_MODEL_COMMUNI_CONST; - } else if (origin_cost->communication_redis_backward_ > COST_MODEL_COMMUNI_THRESHOLD) { - origin_cost->communication_redis_backward_ += COST_MODEL_COMMUNI_BIAS; + (origin_cost->communication_redis_backward_ <= comm_threshold)) { + 
origin_cost->communication_redis_backward_ = comm_const; + } else if (origin_cost->communication_redis_backward_ > comm_threshold) { + origin_cost->communication_redis_backward_ += comm_bias; } origin_cost->communication_cost_ = origin_cost->communication_redis_forward_ + origin_cost->communication_redis_backward_; @@ -104,18 +111,17 @@ void RefineForPracticalCost(const CostPtr &origin_cost, bool is_redistribution) } // forward cost if ((origin_cost->communication_without_parameter_ > EPS) && - (origin_cost->communication_without_parameter_ <= COST_MODEL_COMMUNI_THRESHOLD)) { - origin_cost->communication_without_parameter_ = COST_MODEL_COMMUNI_CONST; - } else if (origin_cost->communication_without_parameter_ > COST_MODEL_COMMUNI_THRESHOLD) { - origin_cost->communication_without_parameter_ += COST_MODEL_COMMUNI_BIAS; + (origin_cost->communication_without_parameter_ <= comm_threshold)) { + origin_cost->communication_without_parameter_ = comm_const; + } else if (origin_cost->communication_without_parameter_ > comm_threshold) { + origin_cost->communication_without_parameter_ += comm_bias; } // total if (origin_cost->communication_cost_ > EPS) { origin_cost->communication_cost_ = origin_cost->communication_without_parameter_ + backward; } if (origin_cost->communication_with_partial_para_ > EPS) { - origin_cost->communication_with_partial_para_ = - origin_cost->communication_without_parameter_ + COST_MODEL_GAMMA * backward; + origin_cost->communication_with_partial_para_ = origin_cost->communication_without_parameter_ + gamma * backward; } } } diff --git a/mindspore/ccsrc/frontend/parallel/auto_parallel/costmodel.h b/mindspore/ccsrc/frontend/parallel/auto_parallel/costmodel.h index 14aaa265a2..ad595b895f 100644 --- a/mindspore/ccsrc/frontend/parallel/auto_parallel/costmodel.h +++ b/mindspore/ccsrc/frontend/parallel/auto_parallel/costmodel.h @@ -24,6 +24,7 @@ #include #include "frontend/parallel/strategy.h" #include "frontend/parallel/tensor_layout/tensor_info.h" +#include "frontend/parallel/costmodel_context.h" namespace mindspore { namespace parallel { @@ -44,8 +45,8 @@ using RedistributionOpListPtr = std::shared_ptr &decision_ = nullptr) - : computation_cost_(computation), communication_cost_(commuication), decision_ptr_(std::move(decision_)) { + Cost(double computation, double communication, const std::shared_ptr &decision_ = nullptr) + : computation_cost_(computation), communication_cost_(communication), decision_ptr_(std::move(decision_)) { memory_with_reuse_ = 0.0; communication_without_parameter_ = 0.0; communication_with_partial_para_ = 0.0; @@ -273,7 +274,7 @@ struct TriangleEliminationDecision : public Decision { * * v <--- u ---> w ==> v w In the original graph, u has 0 incoming edges, and multiple outgoing edges. * In addition, v and w have other complicated connections, resulting in v and w can not be performed other - * eliminations. After the StarElimination, u is merged into v, and the resulting graph is splitted into multiple + * eliminations. After the StarElimination, u is merged into v, and the resulting graph is split into multiple * connected components. * NOTE: this elimination MUST be performed only when the above 5 operation cannot be applied. 
*/ diff --git a/mindspore/ccsrc/frontend/parallel/auto_parallel/dp_algo_costmodel.cc b/mindspore/ccsrc/frontend/parallel/auto_parallel/dp_algo_costmodel.cc index c017244cfb..036b18143f 100644 --- a/mindspore/ccsrc/frontend/parallel/auto_parallel/dp_algo_costmodel.cc +++ b/mindspore/ccsrc/frontend/parallel/auto_parallel/dp_algo_costmodel.cc @@ -132,18 +132,14 @@ Status GetStrategy(const CostGraphPtr &graph) { Status RecoverStrategy(std::vector eliminations) { std::vector::reverse_iterator rit; - + const auto triangle_star_overwrite = CostModelContext::GetInstance()->triangle_star_strategy_overwrite(); for (rit = eliminations.rbegin(); rit != eliminations.rend(); ++rit) { if ((*rit)->isa()) { auto elimination = (*rit)->cast(); auto e = elimination->new_edge_; auto w = elimination->op_; - MS_EXCEPTION_IF_NULL(e); - MS_EXCEPTION_IF_NULL(w); auto left_edge = elimination->left_edge_; auto right_edge = elimination->right_edge_; - MS_EXCEPTION_IF_NULL(left_edge); - MS_EXCEPTION_IF_NULL(right_edge); auto decision = e->selected_cost()->decision_ptr_->cast(); w->SetSelectedStrategyAndCost(decision->op_strategy_, decision->middle_cost_); left_edge->set_selected_cost(decision->left_cost_); @@ -201,12 +197,12 @@ Status RecoverStrategy(std::vector eliminations) { right_edge->set_selected_cost(decision->right_edge_cost_); // 'left_node' recovers the strategy. left_node->SetSelectedStrategyAndCost(decision->left_node_strategy_, decision->left_node_cost_); - if (TRIANGLE_STAR_STRATEGY_OVERWRITE) { + if (triangle_star_overwrite) { // 'right_node' recovers the strategy. MS_LOG(INFO) << "Overwrite the right-node: " << right_node->name() << " in recovering triangle elimination."; right_node->SetSelectedStrategyAndCost(decision->right_node_strategy_, decision->right_node_cost_); } else { - // In this case, 'right_node' is not overwriten strategy, and it checks strategy consistency. + // In this case, 'right_node' is not overwritten strategy, and it checks strategy consistency. right_node->CheckSelectedStrategy(decision->right_node_strategy_); } MS_LOG(INFO) << "Recover triangleElimination succeeded."; @@ -215,7 +211,7 @@ Status RecoverStrategy(std::vector eliminations) { auto merged_node = elimination->eliminated_node_; auto succ_edges = elimination->succ_edges_; auto succ_nodes = elimination->succ_ops_; - // decision is hided in succ_nodes[0] + // decision is hidden in succ_nodes[0] auto decision = succ_nodes[0]->selected_cost()->decision_ptr_->cast(); merged_node->SetSelectedStrategyAndCost(decision->eliminated_op_strategy_, decision->eliminated_op_cost_); @@ -228,7 +224,7 @@ Status RecoverStrategy(std::vector eliminations) { // Star is eliminated into 'succ_nodes[0]' succ_nodes[0]->SetSelectedStrategyAndCost(decision->succ_ops_stra_list_[0], decision->succ_ops_cost_list_[0]); for (size_t k = 1; k < succ_nodes.size(); ++k) { - if (TRIANGLE_STAR_STRATEGY_OVERWRITE) { + if (triangle_star_overwrite) { // 'succ_nodes[k]' is overwritten strategy and cost. 
succ_nodes[k]->SetSelectedStrategyAndCost(decision->succ_ops_stra_list_[k], decision->succ_ops_cost_list_[k]); } else { diff --git a/mindspore/ccsrc/frontend/parallel/auto_parallel/edge_costmodel.cc b/mindspore/ccsrc/frontend/parallel/auto_parallel/edge_costmodel.cc index f33353ebce..034f86e411 100644 --- a/mindspore/ccsrc/frontend/parallel/auto_parallel/edge_costmodel.cc +++ b/mindspore/ccsrc/frontend/parallel/auto_parallel/edge_costmodel.cc @@ -90,11 +90,13 @@ Status Edge::InitEdgeCost() { } } if (!has_available_cost) { - if (FULLY_USE_DEVICES) { + const auto fully_use = CostModelContext::GetInstance()->fully_use_device(); + const auto stra_follow = CostModelContext::GetInstance()->elementwise_stra_follow(); + if (fully_use) { MS_LOG(EXCEPTION) << "Generating cost for edge: " << edge_name_ << " failed, it may be caused by setting 'fully_use_devices' true. Try to set " "'fully_use_devices' false."; - } else if (ELEMENTWISE_OP_STRA_FOLLOW) { + } else if (stra_follow) { MS_LOG(EXCEPTION) << "Generating cost for edge: " << edge_name_ << " failed, it may be caused by setting 'elementwise_op_strategy_follow' true. " "Try to set 'elementwise_op_strategy_follow' false."; @@ -130,6 +132,7 @@ Status Edge::GetRedistributionCost(const TensorLayout &prev_op_output_layout, co double backward_comm_cost = tensor_redistribution.backward_comm_cost(); double computation_cost = tensor_redistribution.computation_cost(); double mem_cost = tensor_redistribution.memory_cost(); + const auto gamma = CostModelContext::GetInstance()->costmodel_gamma(); // Now AllGather, ReduceScatter, AlltoAll don't support bool type MS_EXCEPTION_IF_NULL(type); @@ -142,7 +145,7 @@ Status Edge::GetRedistributionCost(const TensorLayout &prev_op_output_layout, co (*cost)->communication_without_parameter_ = type_length * comm_cost; (*cost)->communication_with_partial_para_ = (*cost)->communication_without_parameter_ + - COST_MODEL_GAMMA * ((*cost)->communication_cost_ - (*cost)->communication_without_parameter_); + gamma * ((*cost)->communication_cost_ - (*cost)->communication_without_parameter_); (*cost)->communication_redis_forward_ = type_length * forward_comm_cost; (*cost)->communication_redis_backward_ = type_length * backward_comm_cost; (*cost)->memory_with_reuse_ = mem_cost; @@ -173,13 +176,14 @@ CostPtrList Edge::CreateEdgeEliminationCostList(const StrategyPtr &output_st_ptr std::function recursive = [&](size_t k, double computation, double memory, double communication, double communication_without_para, double communication_forward) { + const auto gamma = CostModelContext::GetInstance()->costmodel_gamma(); if (k == edges.size()) { auto decision = std::make_shared(selected_cost_list); CostPtr new_cost = std::make_shared(computation, communication); MS_EXCEPTION_IF_NULL(new_cost); new_cost->communication_without_parameter_ = communication_without_para; new_cost->communication_with_partial_para_ = - communication_without_para + COST_MODEL_GAMMA * (communication - communication_without_para); + communication_without_para + gamma * (communication - communication_without_para); new_cost->memory_with_reuse_ = memory; new_cost->communication_forward_ = communication_forward; new_cost->decision_ptr_ = decision; @@ -242,10 +246,11 @@ void Edge::CreateOpEliminationSubCostList(StrategyPtr op_strategy, const CostPtr auto decision = std::make_shared(op_strategy, left_cost, middle_cost, right_cost); auto cost = std::make_shared(computation, communication, decision); + const auto gamma = CostModelContext::GetInstance()->costmodel_gamma(); 
MS_EXCEPTION_IF_NULL(cost); cost->communication_without_parameter_ = communication_without_para; cost->communication_with_partial_para_ = - communication_without_para + COST_MODEL_GAMMA * (communication - communication_without_para); + communication_without_para + gamma * (communication - communication_without_para); cost->memory_with_reuse_ = memory_cost; cost->communication_forward_ = communication_forward; ret_cost_list->emplace_back(std::move(cost)); diff --git a/mindspore/ccsrc/frontend/parallel/auto_parallel/graph_costmodel.cc b/mindspore/ccsrc/frontend/parallel/auto_parallel/graph_costmodel.cc index 86a83284e2..3a907804d7 100644 --- a/mindspore/ccsrc/frontend/parallel/auto_parallel/graph_costmodel.cc +++ b/mindspore/ccsrc/frontend/parallel/auto_parallel/graph_costmodel.cc @@ -28,175 +28,6 @@ namespace mindspore { namespace parallel { CostGraphPtr entire_costgraph = nullptr; size_t TOTAL_OPS = 0; -double COST_MODEL_GAMMA = DEFAULT_COST_MODEL_GAMMA; -bool COST_MODEL_SIMPLIFY_CALCULATION = DEFAULT_COST_MODEL_SIMPLIFY_CALCULATION; -double DEVICE_MEMORY_CAPACITY = DEFAULT_DEVICE_MEMORY_CAPACITY; -double COST_MODEL_COMMUNI_THRESHOLD = DEFAULT_COST_MODEL_COMMUNI_THRESHOLD; -double COST_MODEL_COMMUNI_CONST = DEFAULT_COST_MODEL_COMMUNI_CONST; -double COST_MODEL_COMMUNI_BIAS = DEFAULT_COST_MODEL_COMMUNI_BIAS; -bool TENSOR_SLICE_ALIGNMENT_ENABLE = DEFAULT_TENSOR_SLICE_ALIGNMENT_ENABLE; -size_t TENSOR_SLICE_ALIGNMENT_SIZE = DEFAULT_TENSOR_SLICE_ALIGNMENT_SIZE; -bool FULLY_USE_DEVICES = DEFAULT_FULLY_USE_DEVICES; -bool ELEMENTWISE_OP_STRA_FOLLOW = DEFAULT_ELEMENTWISE_OP_STRA_FOLLOW; -bool MULTI_SUBGRAPHS = DEFAULT_IS_MULTI_SUBGRAPHS; -int64_t RUN_PHASE = DEFAULT_RUN_PHASE; -bool TRIANGLE_STAR_STRATEGY_OVERWRITE = DEFAULT_TRIANGLE_STAR_STRATEGY_OVERWRITE; -bool DP_ALGO_ENABLE_APPROX = DEFAULT_DP_ALGO_ENABLE_APPROX; -double DP_ALGO_APPROX_EPSILON = DEFAULT_DP_ALGO_APPROX_EPSILON; -bool DP_ALGO_SINGLE_LOOP = DEFAULT_DP_ALGO_SINGLE_LOOP; - -void CostGraph::SetDeviceMemoryAndCostParameter() { - MS_EXCEPTION_IF_NULL(CostModelContext::GetInstance()); - - // DEVICE_MEMORY_CAPACITY - auto device_memory = CostModelContext::GetInstance()->device_memory_capacity(); - if (device_memory <= 0) { - MS_LOG(EXCEPTION) << "'device_memory_capacity' must be positive."; - } - dev_memory_ = device_memory; - DEVICE_MEMORY_CAPACITY = device_memory; - MS_LOG(INFO) << "device_memory_capacity: " << DEVICE_MEMORY_CAPACITY << "."; - - // COST_MODEL_ALPHA - auto alpha = CostModelContext::GetInstance()->costmodel_alpha(); - if (alpha <= 0) { - MS_LOG(EXCEPTION) << "'costmodel_alpha' must be positive."; - } - costmodel_alpha_ = alpha; - MS_LOG(INFO) << "costmodel_alpha: " << costmodel_alpha_ << "."; - - // COST_MODEL_BETA - auto beta = CostModelContext::GetInstance()->costmodel_beta(); - if (beta <= 0) { - MS_LOG(EXCEPTION) << "'costmodel_beta' must be positive."; - } - costmodel_beta_ = beta; - MS_LOG(INFO) << "costmodel_beta: " << costmodel_beta_ << "."; - - // COST_MODEL_GAMMA - auto gamma = CostModelContext::GetInstance()->costmodel_gamma(); - if ((gamma < 0) || (gamma > 1)) { - MS_LOG(EXCEPTION) << "'costmodel_gamma' must in [0, 1]."; - } - COST_MODEL_GAMMA = gamma; - MS_LOG(INFO) << "costmodel_gamma: " << COST_MODEL_GAMMA << "."; - - // COST_MODEL_SIMPLIFY_CALCULATION - auto simplify = CostModelContext::GetInstance()->costmodel_simplify_cal(); - COST_MODEL_SIMPLIFY_CALCULATION = simplify; - if (COST_MODEL_SIMPLIFY_CALCULATION) { - MS_LOG(INFO) << "costmodel_simplify_cal: true."; - } else { - MS_LOG(INFO) << 
"costmodel_simplify_cal: false."; - } - - // COST_MODEL_COMMUNI_THRESHOLD - auto communi_threshold = CostModelContext::GetInstance()->costmodel_communi_threshold(); - if (communi_threshold < 0) { - MS_LOG(EXCEPTION) << "'costmodel_communi_threshold' must be non-zero."; - } - COST_MODEL_COMMUNI_THRESHOLD = communi_threshold; - MS_LOG(INFO) << "costmodel_communi_threshold: " << COST_MODEL_COMMUNI_THRESHOLD << "."; - - // COST_MODEL_COMMUNI_CONST - auto communi_const = CostModelContext::GetInstance()->costmodel_communi_const(); - if (communi_const < 0) { - MS_LOG(EXCEPTION) << "'costmodel_communi_const' must be non-zero."; - } - COST_MODEL_COMMUNI_CONST = communi_const; - MS_LOG(INFO) << "costmodel_communi_const: " << COST_MODEL_COMMUNI_CONST << "."; - - // COST_MODEL_COMMUNI_BIAS - auto communi_bias = CostModelContext::GetInstance()->costmodel_communi_bias(); - if (communi_bias < 0) { - MS_LOG(EXCEPTION) << "'costmodel_communi_bias' must be non-zero."; - } - COST_MODEL_COMMUNI_BIAS = communi_bias; - MS_LOG(INFO) << "costmodel_communi_bias: " << COST_MODEL_COMMUNI_BIAS << "."; - - // TENSOR_SLICE_ALIGNMENT_ENABLE - auto align_enable = CostModelContext::GetInstance()->tensor_slice_alignment_enable(); - TENSOR_SLICE_ALIGNMENT_ENABLE = align_enable; - if (TENSOR_SLICE_ALIGNMENT_ENABLE) { - MS_LOG(INFO) << "tensor_slice_align_enable: true."; - } else { - MS_LOG(INFO) << "tensor_slice_align_enable: false."; - } - - // TENSOR_SLICE_ALIGNMENT_SIZE - auto align_size = CostModelContext::GetInstance()->tensor_slice_alignment_size(); - if (align_size == 0) { - MS_LOG(EXCEPTION) << "'tensor_slice_align_size' must be positive."; - } - TENSOR_SLICE_ALIGNMENT_SIZE = align_size; - MS_LOG(INFO) << "tensor_slice_align_size: " << TENSOR_SLICE_ALIGNMENT_SIZE << "."; - - // FULLY_USE_DEVICES - auto fully_devices = CostModelContext::GetInstance()->fully_use_device(); - FULLY_USE_DEVICES = fully_devices; - if (FULLY_USE_DEVICES) { - MS_LOG(INFO) << "fully_use_devices: true."; - } else { - MS_LOG(INFO) << "fully_use_devices: false."; - } - - // ELEMENTWISE_OP_STRA_FOLLOW - auto is_ele_op_follow = CostModelContext::GetInstance()->elementwise_stra_follow(); - ELEMENTWISE_OP_STRA_FOLLOW = is_ele_op_follow; - if (ELEMENTWISE_OP_STRA_FOLLOW) { - MS_LOG(INFO) << "elementwise_op_strategy_follow: true."; - } else { - MS_LOG(INFO) << "elementwise_op_strategy_follow: false."; - } - - // MULTI_SUBGRAPHS - auto multi_subgraphs = CostModelContext::GetInstance()->is_multi_subgraphs(); - MULTI_SUBGRAPHS = multi_subgraphs; - if (MULTI_SUBGRAPHS) { - MS_LOG(INFO) << "multi_subgraphs: true."; - } else { - MS_LOG(INFO) << "multi_subgraphs: false."; - } - - auto overwrite = CostModelContext::GetInstance()->triangle_star_strategy_overwrite(); - TRIANGLE_STAR_STRATEGY_OVERWRITE = overwrite; - if (TRIANGLE_STAR_STRATEGY_OVERWRITE) { - MS_LOG(INFO) << "triangle_star_strategy_overwrite: true."; - } else { - MS_LOG(INFO) << "triangle_star_strategy_overwrite: false."; - } - - // RUN_PHASE - auto phase = CostModelContext::GetInstance()->run_phase(); - if (phase != 0 && phase != 1) { - MS_LOG(EXCEPTION) << "'run_phase' must be in {0, 1}"; - } - RUN_PHASE = phase; - MS_LOG(INFO) << "run_phase: " << RUN_PHASE << "."; - - auto enable_approx = CostModelContext::GetInstance()->dp_algo_enable_approxi(); - DP_ALGO_ENABLE_APPROX = enable_approx; - if (enable_approx) { - MS_LOG(INFO) << "dp_algo_enable_approx: true."; - } else { - MS_LOG(INFO) << "dp_algo_enable_approx: false."; - } - - auto epsilon = 
CostModelContext::GetInstance()->dp_algo_approxi_epsilon(); - if (epsilon <= 0 || epsilon > 1) { - MS_LOG(EXCEPTION) << "'epsilon' must be in (0, 1]"; - } - DP_ALGO_APPROX_EPSILON = epsilon; - MS_LOG(INFO) << "epsilon: " << epsilon << "."; - - auto single_loop = CostModelContext::GetInstance()->dp_algo_single_loop(); - DP_ALGO_SINGLE_LOOP = single_loop; - if (single_loop) { - MS_LOG(INFO) << "dp_algo_single_loop: true."; - } else { - MS_LOG(INFO) << "dp_algo_single_loop: false."; - } -} void CostGraph::Init() { inputs_tensor_name_list_.clear(); @@ -269,7 +100,6 @@ std::vector> CostGraph::ConstructConnectedComponents( if ((!visited[op]) && op->is_alive()) { std::shared_ptr new_component = std::make_shared(); MS_EXCEPTION_IF_NULL(new_component); - new_component->SetDeviceMemoryAndCostParameter(); DFS(op, &visited, new_component); connected_compoents_.push_back(new_component); } @@ -336,10 +166,11 @@ CostPtrList CostGraph::CreateFinalCostList(const OperatorInfoPtr &u, const std:: auto decision = std::make_shared(u_strategy->strategy_ptr, v_strategy->strategy_ptr, cost1, cost2, cost3); auto cost = std::make_shared(computation, communication, decision); + const auto gamma = CostModelContext::GetInstance()->costmodel_gamma(); MS_EXCEPTION_IF_NULL(cost); cost->communication_without_parameter_ = communication_without_para; cost->communication_with_partial_para_ = - communication_without_para + COST_MODEL_GAMMA * (communication - communication_without_para); + communication_without_para + gamma * (communication - communication_without_para); cost->memory_with_reuse_ = memory; cost->communication_forward_ = communication_forward; ret.push_back(cost); @@ -353,7 +184,7 @@ CostPtrList CostGraph::CreateFinalCostList(const OperatorInfoPtr &u, const std:: return ret; } -// Create final cost list for the graph containing a signle node: u +// Create final cost list for the graph containing a single node: u CostPtrList CostGraph::CreateFinalSingleCostList(const OperatorInfoPtr &u) { MS_EXCEPTION_IF_NULL(u); CostPtrList ret; @@ -365,11 +196,12 @@ CostPtrList CostGraph::CreateFinalSingleCostList(const OperatorInfoPtr &u) { MS_EXCEPTION_IF_NULL(cost1); auto decision = std::make_shared(u_strategy_ptr, cost1); auto new_cost = std::make_shared(cost1->computation_cost_, cost1->communication_cost_, decision); + const auto gamma = CostModelContext::GetInstance()->costmodel_gamma(); MS_EXCEPTION_IF_NULL(new_cost); new_cost->communication_without_parameter_ = cost1->communication_without_parameter_; new_cost->communication_with_partial_para_ = cost1->communication_without_parameter_ + - COST_MODEL_GAMMA * (cost1->communication_cost_ - cost1->communication_without_parameter_); + gamma * (cost1->communication_cost_ - cost1->communication_without_parameter_); new_cost->memory_with_reuse_ = cost1->memory_with_reuse_; new_cost->communication_forward_ = cost1->communication_forward_; ret.push_back(new_cost); @@ -404,8 +236,10 @@ CostPtr CostGraph::SelectCostWithMinInferenceTime(const CostPtrList &cost_list, } // Init the returned value with first cost. 
CostPtr ret = after_mem_filter[0]; + const auto alpha = CostModelContext::GetInstance()->costmodel_alpha(); + const auto beta = CostModelContext::GetInstance()->costmodel_beta(); - double minimum = costmodel_alpha_ * ret->computation_cost_ + costmodel_beta_ * ret->communication_forward_; + double minimum = alpha * ret->computation_cost_ + beta * ret->communication_forward_; MS_LOG(INFO) << "Cost 0: " << "memory_cost: " << ret->memory_with_reuse_ << ", computation_cost_: " << ret->computation_cost_ << ", communication_forward_: " << ret->communication_forward_ @@ -422,8 +256,7 @@ CostPtr CostGraph::SelectCostWithMinInferenceTime(const CostPtrList &cost_list, << ", communication_cost_: " << after_mem_filter[i]->communication_cost_ << ", communication_without_parameter_: " << after_mem_filter[i]->communication_without_parameter_ << "."; - auto tmp = costmodel_alpha_ * after_mem_filter[i]->computation_cost_ + - costmodel_beta_ * after_mem_filter[i]->communication_forward_; + auto tmp = alpha * after_mem_filter[i]->computation_cost_ + beta * after_mem_filter[i]->communication_forward_; MS_LOG(INFO) << "Cost " << i << ": total_cost: " << tmp; if (minimum > tmp) { minimum = tmp; @@ -458,8 +291,10 @@ CostPtr CostGraph::SelectCostWithMinTrainingTime(const CostPtrList &cost_list, d } // Init the returned value with first cost. CostPtr ret = after_mem_filter[0]; + const auto alpha = CostModelContext::GetInstance()->costmodel_alpha(); + const auto beta = CostModelContext::GetInstance()->costmodel_beta(); - double minimum = costmodel_alpha_ * ret->computation_cost_ + costmodel_beta_ * ret->communication_with_partial_para_; + double minimum = alpha * ret->computation_cost_ + beta * ret->communication_with_partial_para_; MS_LOG(INFO) << "Cost 0: " << "memory_cost: " << ret->memory_with_reuse_ << ", computation_cost_: " << ret->computation_cost_ << ", communication_with_partial_para_: " << ret->communication_with_partial_para_ @@ -474,8 +309,8 @@ CostPtr CostGraph::SelectCostWithMinTrainingTime(const CostPtrList &cost_list, d << ", communication_cost_: " << after_mem_filter[i]->communication_cost_ << ", communication_without_parameter_: " << after_mem_filter[i]->communication_without_parameter_ << "."; - auto tmp = costmodel_alpha_ * after_mem_filter[i]->computation_cost_ + - costmodel_beta_ * after_mem_filter[i]->communication_with_partial_para_; + auto tmp = + alpha * after_mem_filter[i]->computation_cost_ + beta * after_mem_filter[i]->communication_with_partial_para_; MS_LOG(INFO) << "Cost " << i << ": total_cost: " << tmp; if (minimum > tmp) { minimum = tmp; @@ -513,14 +348,16 @@ CostPtrList CostGraph::SelectCostListWithMinTrainingTimeMultiple(const std::vect } std::function recursive = [&all_cost_list, &selected_cost_list, &minimum, &ret, &recursive, - &available_memory, this](size_t k) { + &available_memory](size_t k) { + const auto alpha = CostModelContext::GetInstance()->costmodel_alpha(); + const auto beta = CostModelContext::GetInstance()->costmodel_beta(); if (k == all_cost_list.size()) { double tmp_memory = 0.0, tmp_minimum = 0.0; for (size_t i = 0; i < selected_cost_list.size(); ++i) { MS_EXCEPTION_IF_NULL(selected_cost_list[i]); tmp_memory += selected_cost_list[i]->memory_with_reuse_; - tmp_minimum += costmodel_alpha_ * selected_cost_list[i]->computation_cost_ + - costmodel_beta_ * selected_cost_list[i]->communication_with_partial_para_; + tmp_minimum += alpha * selected_cost_list[i]->computation_cost_ + + beta * selected_cost_list[i]->communication_with_partial_para_; } MS_LOG(INFO) << 
"tmp_memory: " << tmp_memory << ", tmp_minimum: " << tmp_minimum << ", minimum: " << minimum << "."; @@ -582,12 +419,12 @@ Status CostGraph::SearchStrategyForMultiNodeFinalGraph(const std::vectordevice_memory_capacity(); + auto selected_cost_list = SelectCostListWithMinTrainingTimeMultiple(all_list, device_mem_capacity); for (size_t k = 0; k < selected_cost_list.size(); ++k) { auto selected_cost = selected_cost_list[k]; if (selected_cost == nullptr) { - MS_LOG(ERROR) << "No vaild strategy can be found under the current device memory: " << dev_memory_ << "."; + MS_LOG(ERROR) << "No valid strategy can be found under the current device memory: " << device_mem_capacity << "."; return FAILED; } MS_EXCEPTION_IF_NULL(connected_components[k]); @@ -627,6 +464,99 @@ Status CostGraph::SearchStrategyForMultiNodeFinalGraph(const std::vector &alive_ops) { + // In this case, the final graph should contains exactly 2 nodes. + if (alive_ops.empty()) { + MS_LOG(INFO) << "0 Operator in the final graph."; + return SUCCESS; + } + OperatorInfoPtr u, v; + MS_EXCEPTION_IF_NULL(alive_ops[0]); + MS_EXCEPTION_IF_NULL(alive_ops[1]); + const auto phase = CostModelContext::GetInstance()->run_phase(); + const auto device_mem_capacity = CostModelContext::GetInstance()->device_memory_capacity(); + if (!alive_ops[0]->GetAliveSuccEdges().empty() && + alive_ops[0]->GetAliveSuccEdges()[0]->next_operator().get() == alive_ops[1].get()) { + u = alive_ops[0]; + v = alive_ops[1]; + } else if (!alive_ops[1]->GetAliveSuccEdges().empty() && + alive_ops[1]->GetAliveSuccEdges()[0]->next_operator().get() == alive_ops[0].get()) { + u = alive_ops[1]; + v = alive_ops[0]; + } else { + if (!alive_ops[0]->GetAliveSuccEdges().empty() || !alive_ops[1]->GetAliveSuccEdges().empty()) { + MS_LOG(EXCEPTION) << "The final graph is not the case of u --> v, " << alive_ops[0]->GetAliveSuccEdges().size() + << ", " << alive_ops[1]->GetAliveSuccEdges().size() << "."; + } else { + // In this case, the final graph consists of two single nodes + MS_LOG(INFO) << "There are 2 single nodes in the final graph."; + std::vector all_list; + auto connected_components = ConstructConnectedComponents(alive_ops); + MS_LOG(INFO) << "There are " << connected_components.size() << " components in the final graph."; + for (size_t i = 0; i < connected_components.size(); ++i) { + MS_LOG(INFO) << "There are 1 operator in a component in the final graph."; + auto one_component = connected_components[i]; + MS_EXCEPTION_IF_NULL(one_component); + auto cost_list = one_component->CreateFinalSingleCostList(one_component->GetOperators()[0]); + all_list.push_back(cost_list); + } + CostPtrList selected_cost_list; + if (phase == TRAINING_PHASE) { + // training phase + selected_cost_list = SelectCostListWithMinTrainingTimeMultiple(all_list, device_mem_capacity); + } else { + // inference phase + MS_LOG(EXCEPTION) << "Currently, searching strategy for the two-separated-node final graph in the inference " + "phase is not supported."; + } + for (size_t k = 0; k < selected_cost_list.size(); ++k) { + auto selected_cost = selected_cost_list[k]; + if (selected_cost == nullptr) { + MS_LOG(ERROR) << "No valid strategy can be found under the current device memory: " << device_mem_capacity + << "."; + return FAILED; + } + MS_EXCEPTION_IF_NULL(connected_components[k]); + auto one_operator = connected_components[k]->GetOperators()[0]; + MS_EXCEPTION_IF_NULL(selected_cost->decision_ptr_); + auto decision = selected_cost->decision_ptr_->cast(); + MS_EXCEPTION_IF_NULL(decision); + 
one_operator->SetSelectedStrategyAndCost(decision->u_strategy_, decision->u_cost_); + MS_LOG(INFO) << "Searching the strategy for the component " << k << " final graph ended."; + } + + return SUCCESS; + } + } + MS_LOG(INFO) << "There are 2 nodes in the final graph."; + // In this case, the final graph is exactly of the form: u --> v + MS_EXCEPTION_IF_NULL(u); + MS_EXCEPTION_IF_NULL(v); + auto e = u->GetAliveSuccEdges()[0]; + MS_EXCEPTION_IF_NULL(e); + auto cost_list = CreateFinalCostList(u, e, v); + CostPtr cost = nullptr; + if (phase == TRAINING_PHASE) { + // training phase + cost = SelectCostWithMinTrainingTime(cost_list, device_mem_capacity); + } else { + MS_LOG(EXCEPTION) << "Currently, searching strategy for the two-connected-node final graph in the inference " + "phase is not supported."; + } + if (cost == nullptr) { + MS_LOG(ERROR) << "No valid strategy can be found under the current device memory: " << device_mem_capacity << "."; + return FAILED; + } + MS_EXCEPTION_IF_NULL(cost->decision_ptr_); + auto decision = cost->decision_ptr_->cast(); + MS_EXCEPTION_IF_NULL(decision); + u->SetSelectedStrategyAndCost(decision->u_strategy_, decision->left_cost_); + v->SetSelectedStrategyAndCost(decision->v_strategy_, decision->right_cost_); + e->set_selected_cost(decision->middle_cost_); + MS_LOG(INFO) << "Searching the strategy for the eliminated final graph ended."; + return SUCCESS; +} + // searching the strategy for the final eliminated graph Status CostGraph::SearchStrategy() { MS_LOG(INFO) << "Searching the strategy for the eliminated final graph began."; @@ -637,9 +567,11 @@ Status CostGraph::SearchStrategy() { alive_ops.push_back(op); } }); + const auto phase = CostModelContext::GetInstance()->run_phase(); + const auto device_mem_capacity = CostModelContext::GetInstance()->device_memory_capacity(); if (alive_ops.size() > 2) { - if (RUN_PHASE == TRAINING_PHASE) { + if (phase == TRAINING_PHASE) { // training phase return SearchStrategyForMultiNodeFinalGraph(alive_ops); } else { @@ -652,15 +584,15 @@ Status CostGraph::SearchStrategy() { OperatorInfoPtr u = alive_ops[0]; auto cost_list = CreateFinalSingleCostList(u); CostPtr cost = nullptr; - if (RUN_PHASE == TRAINING_PHASE) { + if (phase == TRAINING_PHASE) { // training phase - cost = SelectCostWithMinTrainingTime(cost_list, dev_memory_); + cost = SelectCostWithMinTrainingTime(cost_list, device_mem_capacity); } else { // inference phase - cost = SelectCostWithMinInferenceTime(cost_list, dev_memory_); + cost = SelectCostWithMinInferenceTime(cost_list, device_mem_capacity); } if (cost == nullptr) { - MS_LOG(ERROR) << "No vaild strategy can be found under the current device memory: " << dev_memory_ << "."; + MS_LOG(ERROR) << "No valid strategy can be found under the current device memory: " << device_mem_capacity << "."; return FAILED; } MS_EXCEPTION_IF_NULL(u); @@ -671,93 +603,7 @@ Status CostGraph::SearchStrategy() { MS_LOG(INFO) << "Searching the strategy for the eliminated final graph ended."; return SUCCESS; } else { - // In this case, the final graph should contains exactly 2 nodes. 
- if (alive_ops.empty()) { - MS_LOG(INFO) << "0 Operator in the final graph."; - return SUCCESS; - } - OperatorInfoPtr u, v; - MS_EXCEPTION_IF_NULL(alive_ops[0]); - MS_EXCEPTION_IF_NULL(alive_ops[1]); - if (!alive_ops[0]->GetAliveSuccEdges().empty() && - alive_ops[0]->GetAliveSuccEdges()[0]->next_operator().get() == alive_ops[1].get()) { - u = alive_ops[0]; - v = alive_ops[1]; - } else if (!alive_ops[1]->GetAliveSuccEdges().empty() && - alive_ops[1]->GetAliveSuccEdges()[0]->next_operator().get() == alive_ops[0].get()) { - u = alive_ops[1]; - v = alive_ops[0]; - } else { - if (!alive_ops[0]->GetAliveSuccEdges().empty() || !alive_ops[1]->GetAliveSuccEdges().empty()) { - MS_LOG(EXCEPTION) << "The final graph is not the case of u --> v, " << alive_ops[0]->GetAliveSuccEdges().size() - << ", " << alive_ops[1]->GetAliveSuccEdges().size() << "."; - } else { - // In this case, the final graph consists of two single nodes - MS_LOG(INFO) << "There are 2 single nodes in the final graph."; - std::vector all_list; - auto connected_components = ConstructConnectedComponents(alive_ops); - MS_LOG(INFO) << "There are " << connected_components.size() << " components in the final graph."; - for (size_t i = 0; i < connected_components.size(); ++i) { - MS_LOG(INFO) << "There are 1 operator in a component in the final graph."; - auto one_component = connected_components[i]; - MS_EXCEPTION_IF_NULL(one_component); - auto cost_list = one_component->CreateFinalSingleCostList(one_component->GetOperators()[0]); - all_list.push_back(cost_list); - } - CostPtrList selected_cost_list; - if (RUN_PHASE == TRAINING_PHASE) { - // training phase - selected_cost_list = SelectCostListWithMinTrainingTimeMultiple(all_list, dev_memory_); - } else { - // inference phase - MS_LOG(EXCEPTION) << "Currently, searching strategy for the two-separated-node final graph in the inference " - "phase is not supported."; - } - for (size_t k = 0; k < selected_cost_list.size(); ++k) { - auto selected_cost = selected_cost_list[k]; - if (selected_cost == nullptr) { - MS_LOG(ERROR) << "No vaild strategy can be found under the current device memory: " << dev_memory_ << "."; - return FAILED; - } - MS_EXCEPTION_IF_NULL(connected_components[k]); - auto one_operator = connected_components[k]->GetOperators()[0]; - MS_EXCEPTION_IF_NULL(selected_cost->decision_ptr_); - auto decision = selected_cost->decision_ptr_->cast(); - MS_EXCEPTION_IF_NULL(decision); - one_operator->SetSelectedStrategyAndCost(decision->u_strategy_, decision->u_cost_); - MS_LOG(INFO) << "Searching the strategy for the component " << k << " final graph ended."; - } - - return SUCCESS; - } - } - MS_LOG(INFO) << "There are 2 nodes in the final graph."; - // In this case, the finale graph is exactly of the form: u --> v - MS_EXCEPTION_IF_NULL(u); - MS_EXCEPTION_IF_NULL(v); - auto e = u->GetAliveSuccEdges()[0]; - MS_EXCEPTION_IF_NULL(e); - auto cost_list = CreateFinalCostList(u, e, v); - CostPtr cost = nullptr; - if (RUN_PHASE == TRAINING_PHASE) { - // training phase - cost = SelectCostWithMinTrainingTime(cost_list, dev_memory_); - } else { - MS_LOG(EXCEPTION) << "Currently, searching strategy for the two-connected-node final graph in the inference " - "phase is not supported."; - } - if (cost == nullptr) { - MS_LOG(ERROR) << "No vaild strategy can be found under the current device memory: " << dev_memory_ << "."; - return FAILED; - } - MS_EXCEPTION_IF_NULL(cost->decision_ptr_); - auto decision = cost->decision_ptr_->cast(); - MS_EXCEPTION_IF_NULL(decision); - 
u->SetSelectedStrategyAndCost(decision->u_strategy_, decision->left_cost_); - v->SetSelectedStrategyAndCost(decision->v_strategy_, decision->right_cost_); - e->set_selected_cost(decision->middle_cost_); - MS_LOG(INFO) << "Searching the strategy for the eliminated final graph ended."; - return SUCCESS; + return SearchStrategyForTwoNodeFinalGraph(alive_ops); } } @@ -867,10 +713,11 @@ void CostGraph::CreateSourceEliminationSubCostList(StrategyPtr op1_old_stra, con op1_cost->communication_without_parameter_ + op2_cost->communication_without_parameter_; auto decision = std::make_shared(op1_old_stra, op1_cost, op2_old_stra, op2_cost); auto new_cost = std::make_shared(computation, communication, decision); + const auto gamma = CostModelContext::GetInstance()->costmodel_gamma(); MS_EXCEPTION_IF_NULL(new_cost); new_cost->communication_without_parameter_ = communication_without_para; new_cost->communication_with_partial_para_ = - communication_without_para + COST_MODEL_GAMMA * (communication - communication_without_para); + communication_without_para + gamma * (communication - communication_without_para); new_cost->memory_with_reuse_ = memory; new_cost->communication_forward_ = communication_forward; MS_EXCEPTION_IF_NULL(op1_new_clist); @@ -879,6 +726,65 @@ void CostGraph::CreateSourceEliminationSubCostList(StrategyPtr op1_old_stra, con } } +std::pair, std::vector> UpdateEdgesIncidentToNodes( + OperatorInfoPtr op1, std::vector *op1_old_succ_edges, + std::vector> *op1_new_edges_cost, std::vector *op1_new_succ_edges, + OperatorInfoPtr op2, std::vector *op2_old_succ_edges, + std::vector> *op2_new_edges_cost, std::vector *op2_new_succ_edges) { + for (size_t i = 0; i < op1_old_succ_edges->size(); ++i) { + auto &new_cost_map = op1_new_edges_cost->at(i); + auto ith_edge = op1_old_succ_edges->at(i); + + std::string new_edge_name = op1->name() + OPERATOR_TO_OPERATOR_CONNECTOR + ith_edge->next_operator()->name(); + std::shared_ptr new_edge; + if (ith_edge->is_combined()) { + std::vector output_indexs, input_indexs; + output_indexs = ith_edge->prev_op_output_indexs(); + input_indexs = ith_edge->next_op_input_indexs(); + new_edge = + std::make_shared(new_edge_name, op1, ith_edge->next_operator(), output_indexs, input_indexs, true); + } else { + size_t output_index, input_index; + output_index = ith_edge->prev_op_output_index(); + input_index = ith_edge->next_op_input_index(); + new_edge = + std::make_shared(new_edge_name, op1, ith_edge->next_operator(), output_index, input_index, false); + } + new_edge->SetCostMapAndInputOutput(new_cost_map); + // replace the old successive edges with the new ones. 
+ op1->ReplaceSuccEdge(ith_edge->next_operator(), new_edge); + ith_edge->next_operator()->ReplacePreEdge(op1, new_edge); + op1_new_succ_edges->erase(op1_new_succ_edges->begin() + i); + op1_new_succ_edges->emplace(op1_new_succ_edges->begin() + i, new_edge); + } + for (size_t i = 0; i < op2_old_succ_edges->size(); ++i) { + auto &new_cost_map = op2_new_edges_cost->at(i); + auto ith_edge = op2_old_succ_edges->at(i); + const auto &destination = ith_edge->next_operator(); + + std::string new_edge_name = op1->name() + OPERATOR_TO_OPERATOR_CONNECTOR + destination->name(); + std::shared_ptr new_edge; + if (ith_edge->is_combined()) { + std::vector output_indexs, input_indexs; + output_indexs = ith_edge->prev_op_output_indexs(); + input_indexs = ith_edge->next_op_input_indexs(); + new_edge = std::make_shared(new_edge_name, op1, destination, output_indexs, input_indexs, true); + } else { + size_t output_index, input_index; + output_index = ith_edge->prev_op_output_index(); + input_index = ith_edge->next_op_input_index(); + new_edge = std::make_shared(new_edge_name, op1, destination, output_index, input_index, false); + } + new_edge->SetCostMapAndInputOutput(new_cost_map); + // replace the old successive edges with the new ones. + destination->ReplacePreEdge(op2, new_edge); + op1->AddSuccEdge(new_edge); + op2_new_succ_edges->erase(op2_new_succ_edges->begin() + i); + op2_new_succ_edges->emplace(op2_new_succ_edges->begin() + i, new_edge); + } + return std::make_pair(*op1_new_succ_edges, *op2_new_succ_edges); +} + std::pair>, std::vector>> CostGraph::EliminationSources( OperatorInfoPtr op1, OperatorInfoPtr op2) { MS_EXCEPTION_IF_NULL(op1); @@ -970,57 +876,9 @@ std::pair>, std::vector> op2->SetNotAlive(); // Update the edges incident to op1, and edges incident to op2 - for (size_t i = 0; i < op1_old_succ_edges.size(); ++i) { - auto &new_cost_map = op1_new_edges_cost[i]; - auto &ith_edge = op1_old_succ_edges[i]; - - std::string new_edge_name = op1->name() + OPERATOR_TO_OPERATOR_CONNECTOR + ith_edge->next_operator()->name(); - std::shared_ptr new_edge; - if (ith_edge->is_combined()) { - std::vector output_indexs, input_indexs; - output_indexs = ith_edge->prev_op_output_indexs(); - input_indexs = ith_edge->next_op_input_indexs(); - new_edge = - std::make_shared(new_edge_name, op1, ith_edge->next_operator(), output_indexs, input_indexs, true); - } else { - size_t output_index, input_index; - output_index = ith_edge->prev_op_output_index(); - input_index = ith_edge->next_op_input_index(); - new_edge = - std::make_shared(new_edge_name, op1, ith_edge->next_operator(), output_index, input_index, false); - } - new_edge->SetCostMapAndInputOutput(new_cost_map); - // replace the old successive edges with the new ones. 
- op1->ReplaceSuccEdge(ith_edge->next_operator(), new_edge); - ith_edge->next_operator()->ReplacePreEdge(op1, new_edge); - op1_new_succ_edges[i] = new_edge; - } - for (size_t i = 0; i < op2_old_succ_edges.size(); ++i) { - auto &new_cost_map = op2_new_edges_cost[i]; - auto &ith_edge = op2_old_succ_edges[i]; - const auto &destination = ith_edge->next_operator(); - - std::string new_edge_name = op1->name() + OPERATOR_TO_OPERATOR_CONNECTOR + destination->name(); - std::shared_ptr new_edge; - if (ith_edge->is_combined()) { - std::vector output_indexs, input_indexs; - output_indexs = ith_edge->prev_op_output_indexs(); - input_indexs = ith_edge->next_op_input_indexs(); - new_edge = std::make_shared(new_edge_name, op1, destination, output_indexs, input_indexs, true); - } else { - size_t output_index, input_index; - output_index = ith_edge->prev_op_output_index(); - input_index = ith_edge->next_op_input_index(); - new_edge = std::make_shared(new_edge_name, op1, destination, output_index, input_index, false); - } - new_edge->SetCostMapAndInputOutput(new_cost_map); - // replace the old successive edges with the new ones. - destination->ReplacePreEdge(op2, new_edge); - op1->AddSuccEdge(new_edge); - op2_new_succ_edges[i] = new_edge; - } MS_LOG(INFO) << "Source eliminating node: " << op2->name() << " to node: " << op1->name() + " succeeded."; - return {op1_new_succ_edges, op2_new_succ_edges}; + return UpdateEdgesIncidentToNodes(op1, &op1_old_succ_edges, &op1_new_edges_cost, &op1_new_succ_edges, op2, + &op2_old_succ_edges, &op2_new_edges_cost, &op2_new_succ_edges); } // Check the graph whether a TriangleElimination can be performed @@ -1179,10 +1037,11 @@ void CostGraph::CreateMergeEliminationSubCostList(StrategyPtr op_strategy, const auto decision = std::make_shared(op_strategy, op_cost, edge_cost, tar_op_strategy, tar_cost); auto new_cost = std::make_shared(computation, communication, decision); + const auto gamma = CostModelContext::GetInstance()->costmodel_gamma(); MS_EXCEPTION_IF_NULL(new_cost); new_cost->communication_without_parameter_ = communication_without_para; new_cost->communication_with_partial_para_ = - communication_without_para + COST_MODEL_GAMMA * (communication - communication_without_para); + communication_without_para + gamma * (communication - communication_without_para); new_cost->memory_with_reuse_ = memory; new_cost->communication_forward_ = communication_forward; MS_EXCEPTION_IF_NULL(tar_cost_list_new); @@ -1263,9 +1122,10 @@ void CostGraph::CreateContractEliminationSubCostList(StrategyPtr contract_op_str auto decision = std::make_shared(contract_op_stra, contract_op_cost, edge_cost, target_op_stra, tar_cost); auto new_cost = std::make_shared(computation, communication, decision); + auto gamma = CostModelContext::GetInstance()->costmodel_gamma(); new_cost->communication_without_parameter_ = communication_without_para; new_cost->communication_with_partial_para_ = - communication_without_para + COST_MODEL_GAMMA * (communication - communication_without_para); + communication_without_para + gamma * (communication - communication_without_para); new_cost->memory_with_reuse_ = memory; new_cost->communication_forward_ = communication_forward; tar_cost_list_new->emplace_back(std::move(new_cost)); @@ -1338,8 +1198,10 @@ void CostGraph::CreateTriangleEliminationSubCostList(StrategyPtr elimi_op_stra, double new_commu_without = elimi_op_cost->communication_without_parameter_ + left_edge_cost->communication_without_parameter_ + left_node_cost->communication_without_parameter_ + 
right_edge_cost->communication_without_parameter_; + const auto triangle_star_stra_overwrite = CostModelContext::GetInstance()->triangle_star_strategy_overwrite(); + const auto gamma = CostModelContext::GetInstance()->costmodel_gamma(); - if (TRIANGLE_STAR_STRATEGY_OVERWRITE) { + if (triangle_star_stra_overwrite) { new_computation += right_op_cost->computation_cost_; new_memory += right_op_cost->memory_with_reuse_; new_commu_cost += right_op_cost->communication_cost_; @@ -1352,8 +1214,7 @@ void CostGraph::CreateTriangleEliminationSubCostList(StrategyPtr elimi_op_stra, left_op_stra, left_node_cost, right_op_stra, right_op_cost); auto new_cost = std::make_shared(new_computation, new_commu_cost, decision); new_cost->communication_without_parameter_ = new_commu_without; - new_cost->communication_with_partial_para_ = - new_commu_without + COST_MODEL_GAMMA * (new_commu_cost - new_commu_without); + new_cost->communication_with_partial_para_ = new_commu_without + gamma * (new_commu_cost - new_commu_without); new_cost->memory_with_reuse_ = new_memory; new_cost->communication_forward_ = new_commu_forward; left_node_clist_new->emplace_back(std::move(new_cost)); @@ -1463,6 +1324,8 @@ void CostGraph::CreateStarEliminationSubCostList(const StrategyPtr &first_succ_n memory_cost = merged_node_cost->memory_with_reuse_, commu_cost = merged_node_cost->communication_cost_, commu_without = merged_node_cost->communication_without_parameter_, commu_forward = merged_node_cost->communication_forward_; + const auto triangle_star_stra_overwrite = CostModelContext::GetInstance()->triangle_star_strategy_overwrite(); + const auto gamma = CostModelContext::GetInstance()->costmodel_gamma(); for (size_t i = 0; i < succ_nodes_stras.size(); ++i) { MS_EXCEPTION_IF_NULL(succ_edges_costs[i]); if (i == 0) { @@ -1478,7 +1341,7 @@ void CostGraph::CreateStarEliminationSubCostList(const StrategyPtr &first_succ_n commu_cost += succ_edges_costs[i]->communication_cost_; commu_forward += succ_edges_costs[i]->communication_forward_; commu_without += succ_edges_costs[i]->communication_without_parameter_; - if (TRIANGLE_STAR_STRATEGY_OVERWRITE) { + if (triangle_star_stra_overwrite) { computation_cost += succ_nodes_costs[i]->computation_cost_; memory_cost += succ_nodes_costs[i]->memory_with_reuse_; commu_cost += succ_nodes_costs[i]->communication_cost_; @@ -1492,7 +1355,7 @@ void CostGraph::CreateStarEliminationSubCostList(const StrategyPtr &first_succ_n succ_nodes_stras, succ_nodes_costs); auto new_cost = std::make_shared(computation_cost, commu_cost, decision); new_cost->communication_without_parameter_ = commu_without; - new_cost->communication_with_partial_para_ = commu_without + COST_MODEL_GAMMA * (commu_cost - commu_without); + new_cost->communication_with_partial_para_ = commu_without + gamma * (commu_cost - commu_without); new_cost->memory_with_reuse_ = memory_cost; new_cost->communication_forward_ = commu_forward; first_succ_node_clist_new->emplace_back(std::move(new_cost)); @@ -1895,7 +1758,8 @@ Status CostGraph::CorrectOpsMemoryCost() { } Status CostGraph::CalculateMemoryCost() { - if (RUN_PHASE == TRAINING_PHASE) { + const auto phase = CostModelContext::GetInstance()->run_phase(); + if (phase == TRAINING_PHASE) { // training phase if (ComputeOpsAndEdgesParameterInvolved() == SUCCESS) { // Calculate operators' memory usage diff --git a/mindspore/ccsrc/frontend/parallel/auto_parallel/graph_costmodel.h b/mindspore/ccsrc/frontend/parallel/auto_parallel/graph_costmodel.h index 2e48af620b..4c14385117 100644 --- 
a/mindspore/ccsrc/frontend/parallel/auto_parallel/graph_costmodel.h +++ b/mindspore/ccsrc/frontend/parallel/auto_parallel/graph_costmodel.h @@ -34,32 +34,12 @@ class CostGraph; using CostGraphPtr = std::shared_ptr; extern CostGraphPtr entire_costgraph; extern size_t TOTAL_OPS; -extern double COST_MODEL_GAMMA; -extern bool COST_MODEL_SIMPLIFY_CALCULATION; -extern double DEVICE_MEMORY_CAPACITY; -extern double COST_MODEL_COMMUNI_THRESHOLD; -extern double COST_MODEL_COMMUNI_CONST; -extern double COST_MODEL_COMMUNI_BIAS; -extern bool TENSOR_SLICE_ALIGNMENT_ENABLE; -extern size_t TENSOR_SLICE_ALIGNMENT_SIZE; -extern bool FULLY_USE_DEVICES; -extern bool ELEMENTWISE_OP_STRA_FOLLOW; -extern bool MULTI_SUBGRAPHS; -extern bool DP_ALGO_ENABLE_APPROX; -extern double DP_ALGO_APPROX_EPSILON; -extern int64_t RUN_PHASE; -extern bool TRIANGLE_STAR_STRATEGY_OVERWRITE; -extern bool DP_ALGO_SINGLE_LOOP; class CostGraph { // 'CostGraph' consists of Operators and edges between them. An edge is created between two Operators if they have // output-input dependency relationship. public: - CostGraph() { - dev_memory_ = DEFAULT_DEVICE_MEMORY_CAPACITY; - costmodel_alpha_ = DEFAULT_COST_MODEL_ALPHA; - costmodel_beta_ = DEFAULT_COST_MODEL_BETA_ASCEND; - } + CostGraph() {} ~CostGraph() = default; void Init(); void AddOperator(const OperatorInfoPtr &op) { ops_.push_back(op); } @@ -79,8 +59,6 @@ class CostGraph { // An edge is uniquely identified by its name, and its output index and input index. bool IsEdgeInCostGraph(const std::string &, size_t, size_t); - void SetDeviceMemoryAndCostParameter(); - std::vector> ConstructConnectedComponents(std::vector); void DFS(const OperatorInfoPtr ¤t_op, std::map *visited, const std::shared_ptr &component); @@ -91,10 +69,10 @@ class CostGraph { CostPtr SelectCostWithMinTrainingTime(const CostPtrList &cost_list, double memory); CostPtrList SelectCostListWithMinTrainingTimeMultiple(const std::vector &all_costlist, double memory); Status SearchStrategyForMultiNodeFinalGraph(const std::vector &); + Status SearchStrategyForTwoNodeFinalGraph(const std::vector &); std::vector> GetOriginalEdgeBetweenOperators(OperatorInfoPtr u_node, OperatorInfoPtr v_node) { return edges_[{u_node, v_node}]; } - double GetDeviceMemory() const { return dev_memory_; } // Search the cost_list in the final graph, and determine the optimal one Status SearchStrategy(); @@ -194,7 +172,7 @@ class CostGraph { Status InitReshapeStrategy(); Status InitSelectedStrategy(); OperatorInfoPtr FindTmpIdentityByParameterName(std::string &) const; - // When TmpIdentity is used by mulitple operators, the corresponding parameter's memory cost should be calculated only + // When TmpIdentity is used by multiple operators, the corresponding parameter's memory cost should be calculated only // once (instead of multiple times), this method is used to correct this. Status CorrectOpsMemoryCost(); // When APPROXIMATION is enabled in the DP algorithm, some edges may have no valid strategies. 
@@ -224,9 +202,6 @@ class CostGraph { // Needed by rec_parser std::vector> inputs_tensor_name_list_; std::map tuple_getitem_list_; - double dev_memory_; - double costmodel_alpha_; - double costmodel_beta_; std::vector ops_; std::map, std::vector> edges_; std::vector> connected_compoents_; diff --git a/mindspore/ccsrc/frontend/parallel/costmodel_context.cc b/mindspore/ccsrc/frontend/parallel/costmodel_context.cc index 183f8b2627..c04c0f7bf5 100644 --- a/mindspore/ccsrc/frontend/parallel/costmodel_context.cc +++ b/mindspore/ccsrc/frontend/parallel/costmodel_context.cc @@ -47,7 +47,7 @@ void CostModelContext::ResetCostModel() { costmodel_communi_const_ = DEFAULT_COST_MODEL_COMMUNI_CONST; costmodel_communi_bias_ = DEFAULT_COST_MODEL_COMMUNI_BIAS; is_multi_subgraphs_ = DEFAULT_IS_MULTI_SUBGRAPHS; - run_phase_ = DEFAULT_RUN_PHASE; + run_phase_ = TRAINING_PHASE; costmodel_allreduce_fusion_algorithm_ = DEFAULT_COST_MODEL_ALLREDUCE_FUSION_ALGORITHM; costmodel_allreduce_fusion_times_ = DEFAULT_COST_MODEL_ALLREDUCE_FUSION_TIMES; costmodel_allreduce_fusion_tail_percent_ = DEFAULT_COST_MODEL_ALLREDUCE_FUSION_TAIL_PERCENT; @@ -70,37 +70,115 @@ void CostModelContext::ResetAlgoParameters() { dp_algo_approxi_epsilon_ = DEFAULT_DP_ALGO_APPROX_EPSILON; } +void CostModelContext::PrintCostModel() { + MS_LOG(INFO) << "device_memory_capacity: " << device_memory_capacity_ << "."; + MS_LOG(INFO) << "costmodel_alpha: " << costmodel_alpha_ << "."; + MS_LOG(INFO) << "costmodel_beta: " << costmodel_beta_ << "."; + MS_LOG(INFO) << "costmodel_gamma: " << costmodel_gamma_ << "."; + MS_LOG(INFO) << "costmodel_simplify_cal: " << costmodel_simplify_cal_ << "."; + MS_LOG(INFO) << "costmodel_communi_threshold: " << costmodel_communi_threshold_ << "."; + MS_LOG(INFO) << "costmodel_communi_const: " << costmodel_communi_const_ << "."; + MS_LOG(INFO) << "costmodel_communi_bias: " << costmodel_communi_bias_ << "."; + MS_LOG(INFO) << "is_multi_subgraphs: " << is_multi_subgraphs_ << "."; + MS_LOG(INFO) << "triangle_star_strategy_overwrite: " << triangle_star_strategy_overwrite_ << "."; + MS_LOG(INFO) << "dp_algo_enable_approxi: " << dp_algo_enable_approxi_ << "."; + MS_LOG(INFO) << "dp_algo_approxi_epsilon: " << dp_algo_approxi_epsilon_ << "."; + MS_LOG(INFO) << "dp_algo_single_loop: " << dp_algo_single_loop_ << "."; + MS_LOG(INFO) << "run_phase: " << run_phase_ << "."; + MS_LOG(INFO) << "tensor_slice_alignment_enable: " << tensor_slice_alignment_enable_ << "."; + MS_LOG(INFO) << "tensor_slice_align_size: " << tensor_slice_alignment_size_ << "."; + MS_LOG(INFO) << "fully_use_device: " << fully_use_device_ << "."; + MS_LOG(INFO) << "elementwise_stra_follow: " << elementwise_stra_follow_ << "."; +} + void CostModelContext::set_costmodel_context_for_device(const std::string &device_target) { if (device_target == kGPUDevice) { costmodel_beta_ = DEFAULT_COST_MODEL_BETA_GPU; } } -void CostModelContext::set_dp_algo_approxi_epsilon(double epsilon) { dp_algo_approxi_epsilon_ = epsilon; } +void CostModelContext::set_dp_algo_approxi_epsilon(double epsilon) { + if (epsilon <= 0 || epsilon > 1) { + MS_LOG(EXCEPTION) << "'epsilon' must be in (0, 1]"; + } + dp_algo_approxi_epsilon_ = epsilon; +} -void CostModelContext::set_dp_algo_enable_approxi(bool approxi) { dp_algo_enable_approxi_ = approxi; } +void CostModelContext::set_dp_algo_enable_approxi(bool approxi) { + if (approxi) { + MS_LOG(INFO) << "dp_algo_enable_approx: true."; + } else { + MS_LOG(INFO) << "dp_algo_enable_approx: false."; + } + dp_algo_enable_approxi_ = approxi; +} -void 
CostModelContext::set_device_memory_capacity(double dm_capacity) { device_memory_capacity_ = dm_capacity; } +void CostModelContext::set_device_memory_capacity(double dm_capacity) { + if (dm_capacity <= 0) { + MS_LOG(EXCEPTION) << "'device_memory_capacity' must be positive."; + } + device_memory_capacity_ = dm_capacity; +} -void CostModelContext::set_costmodel_alpha(double cm_alpha) { costmodel_alpha_ = cm_alpha; } +void CostModelContext::set_costmodel_alpha(double cm_alpha) { + if (cm_alpha <= 0) { + MS_LOG(EXCEPTION) << "'costmodel_alpha' must be positive."; + } + costmodel_alpha_ = cm_alpha; +} -void CostModelContext::set_costmodel_beta(double cm_beta) { costmodel_beta_ = cm_beta; } +void CostModelContext::set_costmodel_beta(double cm_beta) { + if (cm_beta <= 0) { + MS_LOG(EXCEPTION) << "'costmodel_beta' must be positive."; + } + costmodel_beta_ = cm_beta; +} -void CostModelContext::set_costmodel_gamma(double cm_gamma) { costmodel_gamma_ = cm_gamma; } +void CostModelContext::set_costmodel_gamma(double cm_gamma) { + if ((cm_gamma < 0) || (cm_gamma > 1)) { + MS_LOG(EXCEPTION) << "'costmodel_gamma' must be in [0, 1]."; + } + costmodel_gamma_ = cm_gamma; +} -void CostModelContext::set_costmodel_simplify_cal(bool cm_simplify) { costmodel_simplify_cal_ = cm_simplify; } +void CostModelContext::set_costmodel_simplify_cal(bool cm_simplify) { + if (cm_simplify) { + MS_LOG(INFO) << "costmodel_simplify_cal: true."; + } else { + MS_LOG(INFO) << "costmodel_simplify_cal: false."; + } + costmodel_simplify_cal_ = cm_simplify; +} void CostModelContext::set_costmodel_communi_threshold(double cm_communi_th) { + if (cm_communi_th < 0) { + MS_LOG(EXCEPTION) << "'costmodel_communi_threshold' must be non-negative."; + } costmodel_communi_threshold_ = cm_communi_th; } void CostModelContext::set_costmodel_communi_const(double cm_communi_const) { + if (cm_communi_const < 0) { + MS_LOG(EXCEPTION) << "'costmodel_communi_const' must be non-negative."; + } costmodel_communi_const_ = cm_communi_const; } -void CostModelContext::set_costmodel_communi_bias(double cm_communi_bias) { costmodel_communi_bias_ = cm_communi_bias; } +void CostModelContext::set_costmodel_communi_bias(double cm_communi_bias) { + if (cm_communi_bias < 0) { + MS_LOG(EXCEPTION) << "'costmodel_communi_bias' must be non-negative."; + } + costmodel_communi_bias_ = cm_communi_bias; +} -void CostModelContext::set_multi_subgraphs(bool multi_graphs) { is_multi_subgraphs_ = multi_graphs; } +void CostModelContext::set_multi_subgraphs(bool multi_graphs) { + if (multi_graphs) { + MS_LOG(INFO) << "multi_subgraphs: true."; + } else { + MS_LOG(INFO) << "multi_subgraphs: false."; + } + is_multi_subgraphs_ = multi_graphs; +} void CostModelContext::set_costmodel_allreduce_fusion_algorithm(int64_t algorithm) { costmodel_allreduce_fusion_algorithm_ = algorithm; } @@ -129,25 +207,64 @@ void CostModelContext::set_costmodel_allreduce_fusion_computation_time_parameter costmodel_allreduce_fusion_computation_time_parameter_ = computation_time_parameter; } -void CostModelContext::set_tensor_slice_alignment_enable(bool ts_align) { tensor_slice_alignment_enable_ = ts_align; } +void CostModelContext::set_tensor_slice_alignment_enable(bool ts_align) { + if (ts_align) { + MS_LOG(INFO) << "tensor_slice_align_enable: true."; + } else { + MS_LOG(INFO) << "tensor_slice_align_enable: false."; + } + tensor_slice_alignment_enable_ = ts_align; +} void CostModelContext::set_tensor_slice_alignment_size(size_t ts_align_size) { + if (ts_align_size == 0) { + MS_LOG(EXCEPTION) << "'tensor_slice_align_size' 
must be positive."; + } tensor_slice_alignment_size_ = ts_align_size; } -void CostModelContext::set_fully_use_device(bool fully_use) { fully_use_device_ = fully_use; } +void CostModelContext::set_fully_use_device(bool fully_use) { + if (fully_use) { + MS_LOG(INFO) << "fully_use_devices: true."; + } else { + MS_LOG(INFO) << "fully_use_devices: false."; + } + fully_use_device_ = fully_use; +} void CostModelContext::set_elementwise_stra_follow(bool elementwise_follow) { + if (elementwise_follow) { + MS_LOG(INFO) << "elementwise_op_strategy_follow: true."; + } else { + MS_LOG(INFO) << "elementwise_op_strategy_follow: false."; + } elementwise_stra_follow_ = elementwise_follow; } void CostModelContext::set_triangle_star_strategy_overwrite(bool overwrite) { + if (overwrite) { + MS_LOG(INFO) << "triangle_star_strategy_overwrite: true."; + } else { + MS_LOG(INFO) << "triangle_star_strategy_overwrite: false."; + } triangle_star_strategy_overwrite_ = overwrite; } -void CostModelContext::set_run_phase(int64_t phase) { run_phase_ = phase; } +void CostModelContext::set_run_phase(int64_t phase) { + if (phase != 0 && phase != 1) { + MS_LOG(EXCEPTION) << "'run_phase' must be in {0, 1}"; + } + run_phase_ = phase; +} -void CostModelContext::set_dp_algo_single_loop(bool single_loop) { dp_algo_single_loop_ = single_loop; } +void CostModelContext::set_dp_algo_single_loop(bool single_loop) { + if (single_loop) { + MS_LOG(INFO) << "dp_algo_single_loop: true."; + } else { + MS_LOG(INFO) << "dp_algo_single_loop: false."; + } + dp_algo_single_loop_ = single_loop; +} struct CostRegister { CostRegister() { diff --git a/mindspore/ccsrc/frontend/parallel/costmodel_context.h b/mindspore/ccsrc/frontend/parallel/costmodel_context.h index c3dc0a479b..a8257d27b9 100644 --- a/mindspore/ccsrc/frontend/parallel/costmodel_context.h +++ b/mindspore/ccsrc/frontend/parallel/costmodel_context.h @@ -41,7 +41,6 @@ namespace parallel { #define DEFAULT_FULLY_USE_DEVICES true #define DEFAULT_ELEMENTWISE_OP_STRA_FOLLOW false #define DEFAULT_IS_MULTI_SUBGRAPHS false -#define DEFAULT_RUN_PHASE 0 #define TRAINING_PHASE 0 #define INFERENCE_PHASE 1 #define DEFAULT_TRIANGLE_STAR_STRATEGY_OVERWRITE true; @@ -58,6 +57,7 @@ class CostModelContext { void ResetAlgoParameters(); static std::shared_ptr GetInstance(); + void PrintCostModel(); void set_costmodel_context_for_device(const std::string &); // DEVICE_MEMORY_CAPACITY diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/matmul_info.cc b/mindspore/ccsrc/frontend/parallel/ops_info/matmul_info.cc index bf54c27d35..e500e5aedc 100644 --- a/mindspore/ccsrc/frontend/parallel/ops_info/matmul_info.cc +++ b/mindspore/ccsrc/frontend/parallel/ops_info/matmul_info.cc @@ -475,7 +475,8 @@ Status MatMulBase::PrepareStrategy(int64_t stage_id, size_t dev_num, size_t input1_shape_size, mindspore::parallel::StrategyPtr *const sp) { int64_t product = std::accumulate(combined_partitions.begin(), combined_partitions.end(), 1, std::multiplies()); - if (!FULLY_USE_DEVICES) { + const auto fully_use_device = CostModelContext::GetInstance()->fully_use_device(); + if (!fully_use_device) { if (LongToSize(product) > dev_num) { return FAILED; } @@ -564,7 +565,9 @@ void MatMulBase::InitTensorInfoForCost(std::vector *relica_inputs_te } Status MatMulBase::CheckForTensorSliceValid() const { - if (!TENSOR_SLICE_ALIGNMENT_ENABLE) { + const auto align_enable = CostModelContext::GetInstance()->tensor_slice_alignment_enable(); + const auto align_size = CostModelContext::GetInstance()->tensor_slice_alignment_size(); + if 
(!align_enable) { return SUCCESS; } if (inputs_tensor_info_.empty()) { @@ -572,8 +575,8 @@ Status MatMulBase::CheckForTensorSliceValid() const { } for (auto &one_input_tensor : inputs_tensor_info_) { auto slice_shape = one_input_tensor.slice_shape(); - if ((LongToSize(slice_shape[LAST_INDEX(slice_shape.size())]) % TENSOR_SLICE_ALIGNMENT_SIZE != 0) || - (LongToSize(slice_shape[SECOND_FROM_END(slice_shape.size())]) % TENSOR_SLICE_ALIGNMENT_SIZE != 0)) { + if ((LongToSize(slice_shape[LAST_INDEX(slice_shape.size())]) % align_size != 0) || + (LongToSize(slice_shape[SECOND_FROM_END(slice_shape.size())]) % align_size != 0)) { return FAILED; } } @@ -608,12 +611,12 @@ Status MatMulBase::SetCostUnderStrategy(const mindspore::parallel::StrategyPtr & double computation_cost = operator_cost()->GetForwardComputationCost(relica_inputs_tensor_vector, outputs_tensor_info_, stage_id); double communication_cost = operator_cost()->GetCommCost(relica_inputs_tensor_vector, outputs_tensor_info_, stage_id); + const auto gamma = CostModelContext::GetInstance()->costmodel_gamma(); std::shared_ptr result = std::make_shared(computation_cost, communication_cost); result->communication_without_parameter_ = operator_cost()->GetForwardCommCost(relica_inputs_tensor_vector, outputs_tensor_info_, stage_id); result->communication_with_partial_para_ = - result->communication_without_parameter_ + - COST_MODEL_GAMMA * (communication_cost - result->communication_without_parameter_); + result->communication_without_parameter_ + gamma * (communication_cost - result->communication_without_parameter_); // Breaking ties for preferring data parallelization BreakingTiesForPerferringDataParallel(strategy, result); diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/operator_info.cc b/mindspore/ccsrc/frontend/parallel/ops_info/operator_info.cc index d4cce6f6be..ca0dfbf881 100644 --- a/mindspore/ccsrc/frontend/parallel/ops_info/operator_info.cc +++ b/mindspore/ccsrc/frontend/parallel/ops_info/operator_info.cc @@ -839,7 +839,8 @@ Status PrepareStrategyBase(int64_t stage_id, size_t dev_num, const Shapes &input for (auto &input_partition : inputs_partitions) { product *= std::accumulate(input_partition.begin(), input_partition.end(), 1, std::multiplies()); } - if (!FULLY_USE_DEVICES) { + const auto fully_use_device = CostModelContext::GetInstance()->fully_use_device(); + if (!fully_use_device) { if (LongToSize(product) > dev_num) { return FAILED; } @@ -1202,12 +1203,12 @@ Status OperatorInfo::SetCostUnderStrategyBase(const StrategyPtr &strategy) { double computation_cost = operator_cost()->GetForwardComputationCost(inputs_tensor_info_, outputs_tensor_info_, stage_id); double communication_cost = operator_cost()->GetCommCost(inputs_tensor_info_, outputs_tensor_info_, stage_id); + const auto gamma = CostModelContext::GetInstance()->costmodel_gamma(); std::shared_ptr result = std::make_shared(computation_cost, communication_cost); result->communication_without_parameter_ = operator_cost()->GetForwardCommCost(inputs_tensor_info_, outputs_tensor_info_, stage_id); result->communication_with_partial_para_ = - result->communication_without_parameter_ + - COST_MODEL_GAMMA * (communication_cost - result->communication_without_parameter_); + result->communication_without_parameter_ + gamma * (communication_cost - result->communication_without_parameter_); // Breaking ties for preferring data parallelization BreakingTiesForPerferringDataParallel(strategy, result); diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/reshape_info.cc 
b/mindspore/ccsrc/frontend/parallel/ops_info/reshape_info.cc index d2fb83397a..fbbc42eaea 100644 --- a/mindspore/ccsrc/frontend/parallel/ops_info/reshape_info.cc +++ b/mindspore/ccsrc/frontend/parallel/ops_info/reshape_info.cc @@ -396,12 +396,12 @@ void ReshapeInfo::SetCostForReshape(const mindspore::parallel::StrategyPtr &stra double computation_cost = operator_cost()->GetForwardComputationCost(inputs_tensor_info_, outputs_tensor_info_, stage_id); double communication_cost = operator_cost()->GetCommCost(inputs_tensor_info_, outputs_tensor_info_, stage_id); + const auto gamma = CostModelContext::GetInstance()->costmodel_gamma(); std::shared_ptr result = std::make_shared(computation_cost, communication_cost); result->communication_without_parameter_ = operator_cost()->GetForwardCommCost(inputs_tensor_info_, outputs_tensor_info_, stage_id); result->communication_with_partial_para_ = - result->communication_without_parameter_ + - COST_MODEL_GAMMA * (communication_cost - result->communication_without_parameter_); + result->communication_without_parameter_ + gamma * (communication_cost - result->communication_without_parameter_); // Breaking ties for preferring data parallelization BreakingTiesForPerferringDataParallel(strategy, result); diff --git a/mindspore/ccsrc/frontend/parallel/parallel_stub/executor_manager_stub.cc b/mindspore/ccsrc/frontend/parallel/parallel_stub/executor_manager_stub.cc index 2c056bc451..9707ad956b 100644 --- a/mindspore/ccsrc/frontend/parallel/parallel_stub/executor_manager_stub.cc +++ b/mindspore/ccsrc/frontend/parallel/parallel_stub/executor_manager_stub.cc @@ -1,5 +1,5 @@ /** - * Copyright 2021 Huawei Technologies Co., Ltd + * Copyright 2019-2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + #include "frontend/parallel/parallel_stub/executor_manager_stub.h" namespace mindspore { namespace parallel { @@ -26,6 +27,5 @@ std::shared_ptr ExecutorManager::GetExecutor(const std::string &device executors_[device_key] = executor; return executor; } - } // namespace parallel } // namespace mindspore diff --git a/mindspore/ccsrc/frontend/parallel/parallel_stub/executor_manager_stub.h b/mindspore/ccsrc/frontend/parallel/parallel_stub/executor_manager_stub.h index 67a771262b..6a987d0ba5 100644 --- a/mindspore/ccsrc/frontend/parallel/parallel_stub/executor_manager_stub.h +++ b/mindspore/ccsrc/frontend/parallel/parallel_stub/executor_manager_stub.h @@ -1,5 +1,5 @@ /** - * Copyright 2021 Huawei Technologies Co., Ltd + * Copyright 2019-2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. 
*/ + #ifndef MINDSPORE_CCSRC_PARALLEL_EXECUTOR_MANAGER_STUB_H_ #define MINDSPORE_CCSRC_PARALLEL_EXECUTOR_MANAGER_STUB_H_ #include diff --git a/mindspore/ccsrc/frontend/parallel/parallel_stub/executor_stub.h b/mindspore/ccsrc/frontend/parallel/parallel_stub/executor_stub.h index bb655f9178..6d84d68d90 100644 --- a/mindspore/ccsrc/frontend/parallel/parallel_stub/executor_stub.h +++ b/mindspore/ccsrc/frontend/parallel/parallel_stub/executor_stub.h @@ -1,5 +1,5 @@ /** - * Copyright 2021 Huawei Technologies Co., Ltd + * Copyright 2019-2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + #ifndef MINDSPORE_CCSRC_PARALLEL_EXECUTOR_STUB_H #define MINDSPORE_CCSRC_PARALLEL_EXECUTOR_STUB_H diff --git a/mindspore/ccsrc/frontend/parallel/pipeline_transformer/pipeline_transformer.cc b/mindspore/ccsrc/frontend/parallel/pipeline_transformer/pipeline_transformer.cc index 40e0ba43f3..72a53e51ab 100644 --- a/mindspore/ccsrc/frontend/parallel/pipeline_transformer/pipeline_transformer.cc +++ b/mindspore/ccsrc/frontend/parallel/pipeline_transformer/pipeline_transformer.cc @@ -52,13 +52,18 @@ static bool IsInWhiteList(const CNodePtr &cnode) { return false; } -static void SetGradTag(const AnfNodePtr &node, const FuncGraphManagerPtr &manager) { +static void SetGradTag(const AnfNodePtr &node, const FuncGraphManagerPtr &manager, size_t curr_depth) { + if (curr_depth > MAX_RECURSIVE_DEPTH) { + MS_LOG(WARNING) << "When setting the tags for Grad nodes, exceeded the maximum recursion depth: " + << MAX_RECURSIVE_DEPTH; + return; + } const auto &node_users = manager->node_users()[node]; for (auto &user_pair : node_users) { auto user_node = user_pair.first; if (!user_node->grad()) { user_node->set_grad(true); - SetGradTag(user_node, manager); + SetGradTag(user_node, manager, ++curr_depth); } } } @@ -69,7 +74,7 @@ void PipelineTransformer::LabelRequiredGradCNode() { if (!ParameterRequireGrad(parameter)) { continue; } - SetGradTag(parameter, manager_); + SetGradTag(parameter, manager_, 0); } } diff --git a/mindspore/ccsrc/frontend/parallel/step_auto_parallel.cc b/mindspore/ccsrc/frontend/parallel/step_auto_parallel.cc index 17995be170..8088f07db0 100644 --- a/mindspore/ccsrc/frontend/parallel/step_auto_parallel.cc +++ b/mindspore/ccsrc/frontend/parallel/step_auto_parallel.cc @@ -234,7 +234,8 @@ void InitCostGraph() { if (entire_costgraph == nullptr) { entire_costgraph = std::make_shared(); } - entire_costgraph->SetDeviceMemoryAndCostParameter(); + MS_EXCEPTION_IF_NULL(CostModelContext::GetInstance()); + CostModelContext::GetInstance()->PrintCostModel(); entire_costgraph->Init(); } @@ -252,10 +253,11 @@ void SetStrategyToOperator(const OperatorInfoPtr &operator_info, const Primitive if (prim->name() == RESHAPE) { MS_LOG(EXCEPTION) << "Setting strategy for Reshape goes for nothing!"; } + const auto fully_use_devices = CostModelContext::GetInstance()->fully_use_device(); // Set cost for this configured strategy if (operator_info->SetCostUnderStrategy(strategyPtr) != SUCCESS) { MS_LOG(EXCEPTION) << "Failure: operator " << prim->name() << " SetCostUnderStrategy failed"; - } else if (FULLY_USE_DEVICES) { + } else if (fully_use_devices) { // If configured to fully use devices, then checking for the user-specified strategy int64_t used_devices = operator_info->used_devices(); 
MS_EXCEPTION_IF_NULL(g_device_manager); @@ -323,7 +325,7 @@ OperatorInfoPtr CreateTheOperatorInfo(const PrimitivePtr &prim, const CNodePtr & operator_info->set_cnode(cnode); // key of strategy map std::string strategy_key_name = ""; - auto param_names = NodeParameterName(cnode); + auto param_names = NodeParameterName(cnode, -1, 0); if (!param_names.empty()) { strategy_key_name = prim->name() + "_" + param_names[0].first; } @@ -394,7 +396,8 @@ Status ConstructCostGraphNodesByUniqueId(const std::vector &all_node if (search_cnode == from_cnode_to_info.end()) { size_t loop_index = 0; bool is_in_loop = GetLoopIndexFromCNode(cnode, &loop_index); - if (DP_ALGO_SINGLE_LOOP && is_in_loop && (loop_to_ops[loop_index] < operators_in_forloop.size())) { + const auto single_loop = CostModelContext::GetInstance()->dp_algo_single_loop(); + if (single_loop && is_in_loop && (loop_to_ops[loop_index] < operators_in_forloop.size())) { const auto ¤t_op_ptr = operators_in_forloop[loop_to_ops[loop_index]]; bool is_find_wrong = (current_op_ptr->name().find(VIRTUAL_DATA_SET_INFO) == std::string::npos) && (current_op_ptr->name().find(BATCH_PARALLEL) == std::string::npos) && @@ -430,7 +433,7 @@ Status ConstructCostGraphNodesByUniqueId(const std::vector &all_node << ", CNode fullname_with_scope: " << cnode->fullname_with_scope() << " is set OperatorInfo: " << operator_info->name() << ", Primitive: " << prim->name(); (void)from_cnode_to_info.emplace(std::make_pair(cnode->UniqueId(), operator_info)); - if (DP_ALGO_SINGLE_LOOP && is_in_loop) { + if (single_loop && is_in_loop) { operators_in_forloop.push_back(operator_info); ops_in_a_loop_.insert(operator_info->name()); loop_to_ops[loop_index]++; @@ -511,7 +514,8 @@ Status ConstructCostGraphNodesByUniqueIdTC(const std::vector &all_no if (search_cnode == from_cnode_to_info.end()) { size_t loop_index = 0; bool is_in_loop = GetLoopIndexFromCNode(cnode, &loop_index); - bool is_op_created = DP_ALGO_SINGLE_LOOP && is_in_loop && (loop_to_ops[loop_index] < operators_in_forloop.size()); + const auto single_loop = CostModelContext::GetInstance()->dp_algo_single_loop(); + bool is_op_created = single_loop && is_in_loop && (loop_to_ops[loop_index] < operators_in_forloop.size()); if (is_op_created) { const auto ¤t_op_ptr = operators_in_forloop[loop_to_ops[loop_index]]; bool is_find_wrong = (current_op_ptr->name().find(VIRTUAL_DATA_SET_INFO) == std::string::npos) && @@ -548,7 +552,7 @@ Status ConstructCostGraphNodesByUniqueIdTC(const std::vector &all_no << ", CNode fullname_with_scope: " << cnode->fullname_with_scope() << " is set OperatorInfo: " << operator_info->name() << ", Primitive: " << prim->name(); (void)from_cnode_to_info.emplace(std::make_pair(cnode->UniqueIdThroughCopy(), operator_info)); - if (DP_ALGO_SINGLE_LOOP && is_in_loop) { + if (single_loop && is_in_loop) { operators_in_forloop.push_back(operator_info); ops_in_a_loop_.insert(operator_info->name()); loop_to_ops[loop_index]++; @@ -581,8 +585,9 @@ void CreateEdgeBetweenTwoOps(const OperatorInfoPtr &prev_op_info, const Operator MS_LOG(INFO) << "The two operators in two separate for-loops, thus skip the edge."; return; } + const auto stra_follow = CostModelContext::GetInstance()->elementwise_stra_follow(); bool follow_strategy = (prim->name() == RESHAPE) || (prev_prim->name() == RESHAPE) || - (ELEMENTWISE_OP_STRA_FOLLOW && IsElementWiseOperator(prev_prim->name())); + (stra_follow && IsElementWiseOperator(prev_prim->name())); if (follow_strategy) { // Redistribution in not allowed on the edge. 
// Elementwise operators have the same strategy as their previous operators. @@ -1031,7 +1036,7 @@ Status ParallelStrategyRecSearch(const std::vector &all_nodes, const graph = EliminateGraph(graph, eli_list, index_list); size_t num_device = g_device_manager->DeviceNum(); - double device_memory = entire_costgraph->GetDeviceMemory(); + const auto device_memory = CostModelContext::GetInstance()->device_memory_capacity(); if (PartitionForAllDevices(num_device, device_memory, graph) == SUCCESS) { MS_LOG(INFO) << "Partition Success With " << num_device << " devices."; } else { diff --git a/mindspore/ccsrc/frontend/parallel/step_parallel.cc b/mindspore/ccsrc/frontend/parallel/step_parallel.cc index bcc68ad76c..ab948f1142 100644 --- a/mindspore/ccsrc/frontend/parallel/step_parallel.cc +++ b/mindspore/ccsrc/frontend/parallel/step_parallel.cc @@ -1914,7 +1914,11 @@ void SetVirtualDatasetStrategy(const CNodePtr &node) { } // find previous parallel care node's next node. -bool FindPreNodes(const AnfNodePtr &node, vector *unique_ids, vector *indexes) { +bool FindPreNodes(const AnfNodePtr &node, vector *unique_ids, vector *indexes, size_t curr_depth) { + if (curr_depth > MAX_RECURSIVE_DEPTH) { + MS_LOG(WARNING) << "When finding the previous node, exceeded the maximum recursion depth: " << MAX_RECURSIVE_DEPTH; + return false; + } MS_EXCEPTION_IF_NULL(unique_ids); MS_EXCEPTION_IF_NULL(indexes); if (!node->isa()) { @@ -1942,7 +1946,7 @@ bool FindPreNodes(const AnfNodePtr &node, vector *unique_ids, vecto find = true; continue; } - if (FindPreNodes(cnode, unique_ids, indexes)) { + if (FindPreNodes(cnode, unique_ids, indexes, ++curr_depth)) { find = true; continue; } @@ -1954,7 +1958,7 @@ void FindLastNodesUniqueId(const FuncGraphPtr &root, std::vector *u std::vector *indexes) { MS_EXCEPTION_IF_NULL(unique_ids); CNodePtr cnode = root->get_return(); - if (!FindPreNodes(cnode, unique_ids, indexes)) { + if (!FindPreNodes(cnode, unique_ids, indexes, 0)) { MS_LOG(WARNING) << "cannot find the last parallel care node in eval graph"; } } @@ -2044,7 +2048,7 @@ void ExtractInformation(const std::vector &all_nodes, bool is_traini // load strategy checkpoint // key of strategy map std::string strategy_key_name = ""; - auto param_names = NodeParameterName(cnode); + auto param_names = NodeParameterName(cnode, -1, 0); if (!param_names.empty()) { strategy_key_name = prim->name() + "_" + param_names[0].first; } @@ -2151,13 +2155,18 @@ std::shared_ptr FindPrevParallelCareNodeLayout(const AnfNodePtr &n return nullptr; } -std::shared_ptr FindParameterNextLayout(const AnfNodePtr &node) { +std::shared_ptr FindParameterNextLayout(const AnfNodePtr &node, size_t curr_depth) { + if (curr_depth > MAX_RECURSIVE_DEPTH) { + MS_LOG(WARNING) << "When finding the next tensor layout for the parameter, exceeded the maximum recursion depth: " + << MAX_RECURSIVE_DEPTH; + return nullptr; + } FuncGraphManagerPtr manager = node->func_graph()->manager(); MS_EXCEPTION_IF_NULL(manager); AnfNodeIndexSet node_set = manager->node_users()[node]; for (auto &node_pair : node_set) { if (IsPrimitiveCNode(node_pair.first, prim::kPrimLoad)) { - auto layout_param = FindParameterNextLayout(node_pair.first); + auto layout_param = FindParameterNextLayout(node_pair.first, ++curr_depth); if (!layout_param) { continue; } @@ -2184,7 +2193,7 @@ std::shared_ptr FindParameterNextLayout(const AnfNodePtr &node) { std::shared_ptr CreateParameterLayout(const AnfNodePtr &node) { // Create DataParallel tensor layout for parameter(support WideDeep). - auto next_layout = FindParameterNextLayout(node); + auto next_layout = FindParameterNextLayout(node, 0); if (next_layout != nullptr) { return next_layout; } @@ -2329,14 +2338,19 @@ void ReshapeInit(const std::vector &all_nodes) { } } -CNodePtr HandleDependLoss(const CNodePtr &cnode) { +CNodePtr HandleDependLoss(const CNodePtr &cnode, size_t curr_depth) { + if (curr_depth > MAX_RECURSIVE_DEPTH) { + MS_LOG(WARNING) << "When handling the loss node of Depend, exceeded the maximum recursion depth: " + << MAX_RECURSIVE_DEPTH; + return nullptr; + } // Handle return->depend->loss auto prim = GetValueNode(cnode->input(0)); MS_EXCEPTION_IF_NULL(prim); if (prim->name() == DEPEND) { auto depend_before = cnode->input(1)->cast(); MS_EXCEPTION_IF_NULL(depend_before); - return HandleDependLoss(depend_before); + return HandleDependLoss(depend_before, ++curr_depth); } return cnode; } @@ -2370,7 +2384,7 @@ LossNodeInfo FindLossCNode(const FuncGraphPtr &func_graph) { pre_cnode = pre_cnode->input(1)->cast(); MS_EXCEPTION_IF_NULL(pre_cnode); } - pre_cnode = HandleDependLoss(pre_cnode); + pre_cnode = HandleDependLoss(pre_cnode, 0); auto current_prim = GetValueNode(pre_cnode->input(0)); // notice: the GetNext op has not input @@ -2792,7 +2806,12 @@ bool IsCohesiveNode(const CNodePtr &cnode) { IsPrimitiveCNode(cnode, prim::kPrimAllGather) || IsPrimitiveCNode(cnode, prim::kPrimMiniStepAllGather); } -std::vector> NodeParameterName(const CNodePtr &node, int64_t index) { +std::vector> NodeParameterName(const CNodePtr &node, int64_t index, size_t curr_depth) { + if (curr_depth > MAX_RECURSIVE_DEPTH) { + MS_LOG(WARNING) << "When finding the parameter names of an operator, exceeded the maximum recursion depth: " + << MAX_RECURSIVE_DEPTH; + return {}; + } std::vector node_inputs{node->inputs()}; std::vector> param_names; for (int64_t i = 0; i < UlongToLong(node_inputs.size()); ++i) { @@ -2809,7 +2828,7 @@ std::vector> NodeParameterName(const CNodePtr &n continue; } if (IsCohesiveNode(cnode) && cnode->inputs().size() >= 1) { - auto input_param_names = NodeParameterName(cnode, idx); + auto input_param_names = NodeParameterName(cnode, idx, 0); param_names.insert(param_names.end(), input_param_names.begin(), input_param_names.end()); } } @@ -2827,7 +2846,7 @@ void CheckpointStrategy(const std::vector &all_nodes) { if ((cnode == nullptr) || !IsValueNode(cnode->input(0))) { continue; } - auto param_names = NodeParameterName(cnode); + auto param_names = NodeParameterName(cnode, -1, 0); if (param_names.empty()) { continue; } @@ -2950,13 +2969,17 @@ void InsertShapeOp(const CNodePtr &node, const AnfNodePtr &pre_node, const FuncG InsertNode(op, node, 2, pre_node, root, "shape"); } -static AnfNodePtr FindGrad(const CNodePtr &cnode) { +static AnfNodePtr FindGrad(const CNodePtr &cnode, size_t curr_depth) { + if (curr_depth > MAX_RECURSIVE_DEPTH) { + MS_LOG(WARNING) << "When finding Grad nodes, exceeded the maximum recursion depth: " << MAX_RECURSIVE_DEPTH; + return nullptr; + } for (auto &node : cnode->inputs()) { if (!node->isa()) { continue; } if (!IsPrimitiveCNode(node, prim::kPrimEnvGetItem)) { - return FindGrad(node->cast()); + return FindGrad(node->cast(), ++curr_depth); } else { return node; } @@ -2995,7 +3018,7 @@ void HandleRootReshapeAndSaveStrategy(const std::vector &all_nodes) continue; } auto root = node->func_graph(); - auto grad_node = FindGrad(cnode); + auto grad_node = FindGrad(cnode, 0); if (grad_node) { InsertShapeOp(cnode, grad_node, root); } diff --git a/mindspore/ccsrc/frontend/parallel/step_parallel.h 
b/mindspore/ccsrc/frontend/parallel/step_parallel.h index 5c0298f31e..34ae08b6c4 100644 --- a/mindspore/ccsrc/frontend/parallel/step_parallel.h +++ b/mindspore/ccsrc/frontend/parallel/step_parallel.h @@ -139,7 +139,7 @@ bool IsLastStage(); void ParallelCommunication(const FuncGraphPtr &root, const std::vector &all_nodes, const FuncGraphManagerPtr &manager); -std::vector> NodeParameterName(const CNodePtr &node, int64_t index = -1); +std::vector> NodeParameterName(const CNodePtr &node, int64_t index, size_t curr_depth); void CheckpointStrategy(const std::vector &all_nodes); diff --git a/tests/ut/cpp/parallel/auto_parallel/dp_algo_test.cc b/tests/ut/cpp/parallel/auto_parallel/dp_algo_test.cc index 8840e76d85..ceab5a433b 100644 --- a/tests/ut/cpp/parallel/auto_parallel/dp_algo_test.cc +++ b/tests/ut/cpp/parallel/auto_parallel/dp_algo_test.cc @@ -153,7 +153,6 @@ class TestDPAlgo : public UT::Common { void TestDPAlgo::SetUp() { cost_graph = std::make_shared(); - cost_graph->SetDeviceMemoryAndCostParameter(); RankList dev_list; for (int32_t i = 0; i < 10; i++) { diff --git a/tests/ut/cpp/parallel/auto_parallel/graph_costmodel_test.cc b/tests/ut/cpp/parallel/auto_parallel/graph_costmodel_test.cc index f471775327..38946d8025 100644 --- a/tests/ut/cpp/parallel/auto_parallel/graph_costmodel_test.cc +++ b/tests/ut/cpp/parallel/auto_parallel/graph_costmodel_test.cc @@ -52,7 +52,6 @@ class TestCostGraph : public UT::Common { }; void TestCostGraph::SetUp() { - cost_graph.SetDeviceMemoryAndCostParameter(); RankList dev_list; for (int32_t i = 0; i < 10; i++) { @@ -305,7 +304,6 @@ TEST_F(TestCostGraph, test_ConstructConnectedComponents) { TEST_F(TestCostGraph, test_SelectCostListWithMinTrainingTimeMultiple) { CostGraph entire_cost_graph; - entire_cost_graph.SetDeviceMemoryAndCostParameter(); double memory = 1024.0; CostPtrList clist_1, clist_2; std::vector all_list; @@ -371,7 +369,8 @@ TEST_F(TestCostGraph, test_CreateFinalCostList_AND_Select) { ASSERT_EQ(edge_m1_m2->InitEdgeCost(), SUCCESS); cost_graph.AddEdge(matmul1, matmul2, edge_m1_m2); auto cost_list = cost_graph.CreateFinalCostList(matmul1, edge_m1_m2, matmul2); - cost_graph.SelectCostWithMinInferenceTime(cost_list, cost_graph.GetDeviceMemory()); + const auto device_mem_capacity = CostModelContext::GetInstance()->device_memory_capacity(); + cost_graph.SelectCostWithMinInferenceTime(cost_list, device_mem_capacity); } TEST_F(TestCostGraph, test_EliminationOp) {
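// ---------------------------------------------------------------------------
// Editorial sketch (not part of the patch): the recurring pattern the hunks
// above migrate to is "read the knob from the CostModelContext singleton"
// instead of a compile-time macro. The getter costmodel_gamma() and the
// weighting formula below are taken from the changed call sites; the helper
// name and namespace wrapper are hypothetical, for illustration only.
// ---------------------------------------------------------------------------
// #include "frontend/parallel/costmodel_context.h"
//
// namespace mindspore {
// namespace parallel {
// // Combine forward-only and total communication cost the same way the
// // SetCostUnderStrategy call sites now do, weighting the parameter-related
// // part by the configurable gamma (validated to lie in [0, 1] by
// // set_costmodel_gamma) rather than the old COST_MODEL_GAMMA macro.
// double CommWithPartialPara(double comm_without_parameter, double comm_total) {
//   const auto gamma = CostModelContext::GetInstance()->costmodel_gamma();
//   return comm_without_parameter + gamma * (comm_total - comm_without_parameter);
// }
// }  // namespace parallel
// }  // namespace mindspore
// ---------------------------------------------------------------------------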