From: @xiaoda_zh Reviewed-by: @yangzhenzhang,@stsuteng Signed-off-by: @stsutengpull/15777/MERGE
| @@ -23,7 +23,8 @@ | |||
| namespace mindspore { | |||
| namespace parallel { | |||
| void Simplify(CostPtrList *clist_ptrs) { | |||
| if (RUN_PHASE == TRAINING_PHASE) { | |||
| const auto run_phase = CostModelContext::GetInstance()->run_phase(); | |||
| if (run_phase == TRAINING_PHASE) { | |||
| // training phase | |||
| SimplifyForDecreasingCommunicationWithPartialPara(clist_ptrs); | |||
| } else { | |||
| @@ -35,7 +36,8 @@ void SimplifyForDecreasingCommunicationForward(CostPtrList *clist_ptrs) { | |||
| // Sort the cost_list with the computation_cost_ increasing, and communication_forward decreasing order. This method | |||
| // excludes the cost with greater computation_cost_ and greater communication_forward. | |||
| // E.g. clist_ptrs = {<100, 20>, <200, 10>, <300, 50>}. After this method, clist_ptrs = {<200, 10>, <100, 20>} | |||
| if (!COST_MODEL_SIMPLIFY_CALCULATION) { | |||
| const auto simplify_cal = CostModelContext::GetInstance()->costmodel_simplify_cal(); | |||
| if (!simplify_cal) { | |||
| return; | |||
| } | |||
| MS_EXCEPTION_IF_NULL(clist_ptrs); | |||
| @@ -57,7 +59,8 @@ void SimplifyForDecreasingCommunicationForward(CostPtrList *clist_ptrs) { | |||
| void SimplifyForDecreasingCommunicationWithPartialPara(CostPtrList *clist_ptrs) { | |||
| // Sort the cost_list with the computation_cost_ increasing, and communication_with_partial_para_cost decreasing | |||
| // order. This method excludes the cost with greater computation_cost_ and greater communication_without_para_cost. | |||
| if (!COST_MODEL_SIMPLIFY_CALCULATION) { | |||
| const auto simplify_cal = CostModelContext::GetInstance()->costmodel_simplify_cal(); | |||
| if (!simplify_cal) { | |||
| return; | |||
| } | |||
| MS_EXCEPTION_IF_NULL(clist_ptrs); | |||
| @@ -78,19 +81,23 @@ void SimplifyForDecreasingCommunicationWithPartialPara(CostPtrList *clist_ptrs) | |||
| void RefineForPracticalCost(const CostPtr &origin_cost, bool is_redistribution) { | |||
| MS_EXCEPTION_IF_NULL(origin_cost); | |||
| const auto comm_threshold = CostModelContext::GetInstance()->costmodel_communi_threshold(); | |||
| const auto comm_const = CostModelContext::GetInstance()->costmodel_communi_const(); | |||
| const auto comm_bias = CostModelContext::GetInstance()->costmodel_communi_bias(); | |||
| const auto gamma = CostModelContext::GetInstance()->costmodel_gamma(); | |||
| if (is_redistribution) { | |||
| // Redistribution cost | |||
| if ((origin_cost->communication_redis_forward_ > EPS) && | |||
| (origin_cost->communication_redis_forward_ <= COST_MODEL_COMMUNI_THRESHOLD)) { | |||
| origin_cost->communication_redis_forward_ = COST_MODEL_COMMUNI_CONST; | |||
| } else if (origin_cost->communication_redis_forward_ > COST_MODEL_COMMUNI_THRESHOLD) { | |||
| origin_cost->communication_redis_forward_ += COST_MODEL_COMMUNI_BIAS; | |||
| (origin_cost->communication_redis_forward_ <= comm_threshold)) { | |||
| origin_cost->communication_redis_forward_ = comm_const; | |||
| } else if (origin_cost->communication_redis_forward_ > comm_threshold) { | |||
| origin_cost->communication_redis_forward_ += comm_bias; | |||
| } | |||
| if ((origin_cost->communication_redis_backward_ > EPS) && | |||
| (origin_cost->communication_redis_backward_ <= COST_MODEL_COMMUNI_THRESHOLD)) { | |||
| origin_cost->communication_redis_backward_ = COST_MODEL_COMMUNI_CONST; | |||
| } else if (origin_cost->communication_redis_backward_ > COST_MODEL_COMMUNI_THRESHOLD) { | |||
| origin_cost->communication_redis_backward_ += COST_MODEL_COMMUNI_BIAS; | |||
| (origin_cost->communication_redis_backward_ <= comm_threshold)) { | |||
| origin_cost->communication_redis_backward_ = comm_const; | |||
| } else if (origin_cost->communication_redis_backward_ > comm_threshold) { | |||
| origin_cost->communication_redis_backward_ += comm_bias; | |||
| } | |||
| origin_cost->communication_cost_ = | |||
| origin_cost->communication_redis_forward_ + origin_cost->communication_redis_backward_; | |||
| @@ -104,18 +111,17 @@ void RefineForPracticalCost(const CostPtr &origin_cost, bool is_redistribution) | |||
| } | |||
| // forward cost | |||
| if ((origin_cost->communication_without_parameter_ > EPS) && | |||
| (origin_cost->communication_without_parameter_ <= COST_MODEL_COMMUNI_THRESHOLD)) { | |||
| origin_cost->communication_without_parameter_ = COST_MODEL_COMMUNI_CONST; | |||
| } else if (origin_cost->communication_without_parameter_ > COST_MODEL_COMMUNI_THRESHOLD) { | |||
| origin_cost->communication_without_parameter_ += COST_MODEL_COMMUNI_BIAS; | |||
| (origin_cost->communication_without_parameter_ <= comm_threshold)) { | |||
| origin_cost->communication_without_parameter_ = comm_const; | |||
| } else if (origin_cost->communication_without_parameter_ > comm_threshold) { | |||
| origin_cost->communication_without_parameter_ += comm_bias; | |||
| } | |||
| // total | |||
| if (origin_cost->communication_cost_ > EPS) { | |||
| origin_cost->communication_cost_ = origin_cost->communication_without_parameter_ + backward; | |||
| } | |||
| if (origin_cost->communication_with_partial_para_ > EPS) { | |||
| origin_cost->communication_with_partial_para_ = | |||
| origin_cost->communication_without_parameter_ + COST_MODEL_GAMMA * backward; | |||
| origin_cost->communication_with_partial_para_ = origin_cost->communication_without_parameter_ + gamma * backward; | |||
| } | |||
| } | |||
| } | |||
| @@ -24,6 +24,7 @@ | |||
| #include <vector> | |||
| #include "frontend/parallel/strategy.h" | |||
| #include "frontend/parallel/tensor_layout/tensor_info.h" | |||
| #include "frontend/parallel/costmodel_context.h" | |||
| namespace mindspore { | |||
| namespace parallel { | |||
| @@ -44,8 +45,8 @@ using RedistributionOpListPtr = std::shared_ptr<std::pair<OperatorVector, OutPut | |||
| struct Cost { | |||
| Cost(); | |||
| Cost(double computation, double commuication, const std::shared_ptr<Decision> &decision_ = nullptr) | |||
| : computation_cost_(computation), communication_cost_(commuication), decision_ptr_(std::move(decision_)) { | |||
| Cost(double computation, double communication, const std::shared_ptr<Decision> &decision_ = nullptr) | |||
| : computation_cost_(computation), communication_cost_(communication), decision_ptr_(std::move(decision_)) { | |||
| memory_with_reuse_ = 0.0; | |||
| communication_without_parameter_ = 0.0; | |||
| communication_with_partial_para_ = 0.0; | |||
| @@ -273,7 +274,7 @@ struct TriangleEliminationDecision : public Decision { | |||
| * | |||
| * v <--- u ---> w ==> v w In the original graph, u has 0 incoming edges, and multiple outgoing edges. | |||
| * In addition, v and w have other complicated connections, resulting in v and w can not be performed other | |||
| * eliminations. After the StarElimination, u is merged into v, and the resulting graph is splitted into multiple | |||
| * eliminations. After the StarElimination, u is merged into v, and the resulting graph is split into multiple | |||
| * connected components. | |||
| * NOTE: this elimination MUST be performed only when the above 5 operation cannot be applied. | |||
| */ | |||
| @@ -132,18 +132,14 @@ Status GetStrategy(const CostGraphPtr &graph) { | |||
| Status RecoverStrategy(std::vector<EliminationPtr> eliminations) { | |||
| std::vector<EliminationPtr>::reverse_iterator rit; | |||
| const auto triangle_star_overwrite = CostModelContext::GetInstance()->triangle_star_strategy_overwrite(); | |||
| for (rit = eliminations.rbegin(); rit != eliminations.rend(); ++rit) { | |||
| if ((*rit)->isa<OpElimination>()) { | |||
| auto elimination = (*rit)->cast<OpEliminationPtr>(); | |||
| auto e = elimination->new_edge_; | |||
| auto w = elimination->op_; | |||
| MS_EXCEPTION_IF_NULL(e); | |||
| MS_EXCEPTION_IF_NULL(w); | |||
| auto left_edge = elimination->left_edge_; | |||
| auto right_edge = elimination->right_edge_; | |||
| MS_EXCEPTION_IF_NULL(left_edge); | |||
| MS_EXCEPTION_IF_NULL(right_edge); | |||
| auto decision = e->selected_cost()->decision_ptr_->cast<OpEliminationDecisionPtr>(); | |||
| w->SetSelectedStrategyAndCost(decision->op_strategy_, decision->middle_cost_); | |||
| left_edge->set_selected_cost(decision->left_cost_); | |||
| @@ -201,12 +197,12 @@ Status RecoverStrategy(std::vector<EliminationPtr> eliminations) { | |||
| right_edge->set_selected_cost(decision->right_edge_cost_); | |||
| // 'left_node' recovers the strategy. | |||
| left_node->SetSelectedStrategyAndCost(decision->left_node_strategy_, decision->left_node_cost_); | |||
| if (TRIANGLE_STAR_STRATEGY_OVERWRITE) { | |||
| if (triangle_star_overwrite) { | |||
| // 'right_node' recovers the strategy. | |||
| MS_LOG(INFO) << "Overwrite the right-node: " << right_node->name() << " in recovering triangle elimination."; | |||
| right_node->SetSelectedStrategyAndCost(decision->right_node_strategy_, decision->right_node_cost_); | |||
| } else { | |||
| // In this case, 'right_node' is not overwriten strategy, and it checks strategy consistency. | |||
| // In this case, 'right_node' is not overwritten strategy, and it checks strategy consistency. | |||
| right_node->CheckSelectedStrategy(decision->right_node_strategy_); | |||
| } | |||
| MS_LOG(INFO) << "Recover triangleElimination succeeded."; | |||
| @@ -215,7 +211,7 @@ Status RecoverStrategy(std::vector<EliminationPtr> eliminations) { | |||
| auto merged_node = elimination->eliminated_node_; | |||
| auto succ_edges = elimination->succ_edges_; | |||
| auto succ_nodes = elimination->succ_ops_; | |||
| // decision is hided in succ_nodes[0] | |||
| // decision is hidden in succ_nodes[0] | |||
| auto decision = succ_nodes[0]->selected_cost()->decision_ptr_->cast<StarEliminationDecisionPtr>(); | |||
| merged_node->SetSelectedStrategyAndCost(decision->eliminated_op_strategy_, decision->eliminated_op_cost_); | |||
| @@ -228,7 +224,7 @@ Status RecoverStrategy(std::vector<EliminationPtr> eliminations) { | |||
| // Star is eliminated into 'succ_nodes[0]' | |||
| succ_nodes[0]->SetSelectedStrategyAndCost(decision->succ_ops_stra_list_[0], decision->succ_ops_cost_list_[0]); | |||
| for (size_t k = 1; k < succ_nodes.size(); ++k) { | |||
| if (TRIANGLE_STAR_STRATEGY_OVERWRITE) { | |||
| if (triangle_star_overwrite) { | |||
| // 'succ_nodes[k]' is overwritten strategy and cost. | |||
| succ_nodes[k]->SetSelectedStrategyAndCost(decision->succ_ops_stra_list_[k], decision->succ_ops_cost_list_[k]); | |||
| } else { | |||
| @@ -90,11 +90,13 @@ Status Edge::InitEdgeCost() { | |||
| } | |||
| } | |||
| if (!has_available_cost) { | |||
| if (FULLY_USE_DEVICES) { | |||
| const auto fully_use = CostModelContext::GetInstance()->fully_use_device(); | |||
| const auto stra_follow = CostModelContext::GetInstance()->elementwise_stra_follow(); | |||
| if (fully_use) { | |||
| MS_LOG(EXCEPTION) << "Generating cost for edge: " << edge_name_ | |||
| << " failed, it may be caused by setting 'fully_use_devices' true. Try to set " | |||
| "'fully_use_devices' false."; | |||
| } else if (ELEMENTWISE_OP_STRA_FOLLOW) { | |||
| } else if (stra_follow) { | |||
| MS_LOG(EXCEPTION) << "Generating cost for edge: " << edge_name_ | |||
| << " failed, it may be caused by setting 'elementwise_op_strategy_follow' true. " | |||
| "Try to set 'elementwise_op_strategy_follow' false."; | |||
| @@ -130,6 +132,7 @@ Status Edge::GetRedistributionCost(const TensorLayout &prev_op_output_layout, co | |||
| double backward_comm_cost = tensor_redistribution.backward_comm_cost(); | |||
| double computation_cost = tensor_redistribution.computation_cost(); | |||
| double mem_cost = tensor_redistribution.memory_cost(); | |||
| const auto gamma = CostModelContext::GetInstance()->costmodel_gamma(); | |||
| // Now AllGather, ReduceScatter, AlltoAll don't support bool type | |||
| MS_EXCEPTION_IF_NULL(type); | |||
| @@ -142,7 +145,7 @@ Status Edge::GetRedistributionCost(const TensorLayout &prev_op_output_layout, co | |||
| (*cost)->communication_without_parameter_ = type_length * comm_cost; | |||
| (*cost)->communication_with_partial_para_ = | |||
| (*cost)->communication_without_parameter_ + | |||
| COST_MODEL_GAMMA * ((*cost)->communication_cost_ - (*cost)->communication_without_parameter_); | |||
| gamma * ((*cost)->communication_cost_ - (*cost)->communication_without_parameter_); | |||
| (*cost)->communication_redis_forward_ = type_length * forward_comm_cost; | |||
| (*cost)->communication_redis_backward_ = type_length * backward_comm_cost; | |||
| (*cost)->memory_with_reuse_ = mem_cost; | |||
| @@ -173,13 +176,14 @@ CostPtrList Edge::CreateEdgeEliminationCostList(const StrategyPtr &output_st_ptr | |||
| std::function<void(size_t, double, double, double, double, double)> recursive = | |||
| [&](size_t k, double computation, double memory, double communication, double communication_without_para, | |||
| double communication_forward) { | |||
| const auto gamma = CostModelContext::GetInstance()->costmodel_gamma(); | |||
| if (k == edges.size()) { | |||
| auto decision = std::make_shared<EdgeEliminationDecision>(selected_cost_list); | |||
| CostPtr new_cost = std::make_shared<Cost>(computation, communication); | |||
| MS_EXCEPTION_IF_NULL(new_cost); | |||
| new_cost->communication_without_parameter_ = communication_without_para; | |||
| new_cost->communication_with_partial_para_ = | |||
| communication_without_para + COST_MODEL_GAMMA * (communication - communication_without_para); | |||
| communication_without_para + gamma * (communication - communication_without_para); | |||
| new_cost->memory_with_reuse_ = memory; | |||
| new_cost->communication_forward_ = communication_forward; | |||
| new_cost->decision_ptr_ = decision; | |||
| @@ -242,10 +246,11 @@ void Edge::CreateOpEliminationSubCostList(StrategyPtr op_strategy, const CostPtr | |||
| auto decision = std::make_shared<OpEliminationDecision>(op_strategy, left_cost, middle_cost, right_cost); | |||
| auto cost = std::make_shared<Cost>(computation, communication, decision); | |||
| const auto gamma = CostModelContext::GetInstance()->costmodel_gamma(); | |||
| MS_EXCEPTION_IF_NULL(cost); | |||
| cost->communication_without_parameter_ = communication_without_para; | |||
| cost->communication_with_partial_para_ = | |||
| communication_without_para + COST_MODEL_GAMMA * (communication - communication_without_para); | |||
| communication_without_para + gamma * (communication - communication_without_para); | |||
| cost->memory_with_reuse_ = memory_cost; | |||
| cost->communication_forward_ = communication_forward; | |||
| ret_cost_list->emplace_back(std::move(cost)); | |||
| @@ -28,175 +28,6 @@ namespace mindspore { | |||
| namespace parallel { | |||
| CostGraphPtr entire_costgraph = nullptr; | |||
| size_t TOTAL_OPS = 0; | |||
| double COST_MODEL_GAMMA = DEFAULT_COST_MODEL_GAMMA; | |||
| bool COST_MODEL_SIMPLIFY_CALCULATION = DEFAULT_COST_MODEL_SIMPLIFY_CALCULATION; | |||
| double DEVICE_MEMORY_CAPACITY = DEFAULT_DEVICE_MEMORY_CAPACITY; | |||
| double COST_MODEL_COMMUNI_THRESHOLD = DEFAULT_COST_MODEL_COMMUNI_THRESHOLD; | |||
| double COST_MODEL_COMMUNI_CONST = DEFAULT_COST_MODEL_COMMUNI_CONST; | |||
| double COST_MODEL_COMMUNI_BIAS = DEFAULT_COST_MODEL_COMMUNI_BIAS; | |||
| bool TENSOR_SLICE_ALIGNMENT_ENABLE = DEFAULT_TENSOR_SLICE_ALIGNMENT_ENABLE; | |||
| size_t TENSOR_SLICE_ALIGNMENT_SIZE = DEFAULT_TENSOR_SLICE_ALIGNMENT_SIZE; | |||
| bool FULLY_USE_DEVICES = DEFAULT_FULLY_USE_DEVICES; | |||
| bool ELEMENTWISE_OP_STRA_FOLLOW = DEFAULT_ELEMENTWISE_OP_STRA_FOLLOW; | |||
| bool MULTI_SUBGRAPHS = DEFAULT_IS_MULTI_SUBGRAPHS; | |||
| int64_t RUN_PHASE = DEFAULT_RUN_PHASE; | |||
| bool TRIANGLE_STAR_STRATEGY_OVERWRITE = DEFAULT_TRIANGLE_STAR_STRATEGY_OVERWRITE; | |||
| bool DP_ALGO_ENABLE_APPROX = DEFAULT_DP_ALGO_ENABLE_APPROX; | |||
| double DP_ALGO_APPROX_EPSILON = DEFAULT_DP_ALGO_APPROX_EPSILON; | |||
| bool DP_ALGO_SINGLE_LOOP = DEFAULT_DP_ALGO_SINGLE_LOOP; | |||
| void CostGraph::SetDeviceMemoryAndCostParameter() { | |||
| MS_EXCEPTION_IF_NULL(CostModelContext::GetInstance()); | |||
| // DEVICE_MEMORY_CAPACITY | |||
| auto device_memory = CostModelContext::GetInstance()->device_memory_capacity(); | |||
| if (device_memory <= 0) { | |||
| MS_LOG(EXCEPTION) << "'device_memory_capacity' must be positive."; | |||
| } | |||
| dev_memory_ = device_memory; | |||
| DEVICE_MEMORY_CAPACITY = device_memory; | |||
| MS_LOG(INFO) << "device_memory_capacity: " << DEVICE_MEMORY_CAPACITY << "."; | |||
| // COST_MODEL_ALPHA | |||
| auto alpha = CostModelContext::GetInstance()->costmodel_alpha(); | |||
| if (alpha <= 0) { | |||
| MS_LOG(EXCEPTION) << "'costmodel_alpha' must be positive."; | |||
| } | |||
| costmodel_alpha_ = alpha; | |||
| MS_LOG(INFO) << "costmodel_alpha: " << costmodel_alpha_ << "."; | |||
| // COST_MODEL_BETA | |||
| auto beta = CostModelContext::GetInstance()->costmodel_beta(); | |||
| if (beta <= 0) { | |||
| MS_LOG(EXCEPTION) << "'costmodel_beta' must be positive."; | |||
| } | |||
| costmodel_beta_ = beta; | |||
| MS_LOG(INFO) << "costmodel_beta: " << costmodel_beta_ << "."; | |||
| // COST_MODEL_GAMMA | |||
| auto gamma = CostModelContext::GetInstance()->costmodel_gamma(); | |||
| if ((gamma < 0) || (gamma > 1)) { | |||
| MS_LOG(EXCEPTION) << "'costmodel_gamma' must in [0, 1]."; | |||
| } | |||
| COST_MODEL_GAMMA = gamma; | |||
| MS_LOG(INFO) << "costmodel_gamma: " << COST_MODEL_GAMMA << "."; | |||
| // COST_MODEL_SIMPLIFY_CALCULATION | |||
| auto simplify = CostModelContext::GetInstance()->costmodel_simplify_cal(); | |||
| COST_MODEL_SIMPLIFY_CALCULATION = simplify; | |||
| if (COST_MODEL_SIMPLIFY_CALCULATION) { | |||
| MS_LOG(INFO) << "costmodel_simplify_cal: true."; | |||
| } else { | |||
| MS_LOG(INFO) << "costmodel_simplify_cal: false."; | |||
| } | |||
| // COST_MODEL_COMMUNI_THRESHOLD | |||
| auto communi_threshold = CostModelContext::GetInstance()->costmodel_communi_threshold(); | |||
| if (communi_threshold < 0) { | |||
| MS_LOG(EXCEPTION) << "'costmodel_communi_threshold' must be non-zero."; | |||
| } | |||
| COST_MODEL_COMMUNI_THRESHOLD = communi_threshold; | |||
| MS_LOG(INFO) << "costmodel_communi_threshold: " << COST_MODEL_COMMUNI_THRESHOLD << "."; | |||
| // COST_MODEL_COMMUNI_CONST | |||
| auto communi_const = CostModelContext::GetInstance()->costmodel_communi_const(); | |||
| if (communi_const < 0) { | |||
| MS_LOG(EXCEPTION) << "'costmodel_communi_const' must be non-zero."; | |||
| } | |||
| COST_MODEL_COMMUNI_CONST = communi_const; | |||
| MS_LOG(INFO) << "costmodel_communi_const: " << COST_MODEL_COMMUNI_CONST << "."; | |||
| // COST_MODEL_COMMUNI_BIAS | |||
| auto communi_bias = CostModelContext::GetInstance()->costmodel_communi_bias(); | |||
| if (communi_bias < 0) { | |||
| MS_LOG(EXCEPTION) << "'costmodel_communi_bias' must be non-zero."; | |||
| } | |||
| COST_MODEL_COMMUNI_BIAS = communi_bias; | |||
| MS_LOG(INFO) << "costmodel_communi_bias: " << COST_MODEL_COMMUNI_BIAS << "."; | |||
| // TENSOR_SLICE_ALIGNMENT_ENABLE | |||
| auto align_enable = CostModelContext::GetInstance()->tensor_slice_alignment_enable(); | |||
| TENSOR_SLICE_ALIGNMENT_ENABLE = align_enable; | |||
| if (TENSOR_SLICE_ALIGNMENT_ENABLE) { | |||
| MS_LOG(INFO) << "tensor_slice_align_enable: true."; | |||
| } else { | |||
| MS_LOG(INFO) << "tensor_slice_align_enable: false."; | |||
| } | |||
| // TENSOR_SLICE_ALIGNMENT_SIZE | |||
| auto align_size = CostModelContext::GetInstance()->tensor_slice_alignment_size(); | |||
| if (align_size == 0) { | |||
| MS_LOG(EXCEPTION) << "'tensor_slice_align_size' must be positive."; | |||
| } | |||
| TENSOR_SLICE_ALIGNMENT_SIZE = align_size; | |||
| MS_LOG(INFO) << "tensor_slice_align_size: " << TENSOR_SLICE_ALIGNMENT_SIZE << "."; | |||
| // FULLY_USE_DEVICES | |||
| auto fully_devices = CostModelContext::GetInstance()->fully_use_device(); | |||
| FULLY_USE_DEVICES = fully_devices; | |||
| if (FULLY_USE_DEVICES) { | |||
| MS_LOG(INFO) << "fully_use_devices: true."; | |||
| } else { | |||
| MS_LOG(INFO) << "fully_use_devices: false."; | |||
| } | |||
| // ELEMENTWISE_OP_STRA_FOLLOW | |||
| auto is_ele_op_follow = CostModelContext::GetInstance()->elementwise_stra_follow(); | |||
| ELEMENTWISE_OP_STRA_FOLLOW = is_ele_op_follow; | |||
| if (ELEMENTWISE_OP_STRA_FOLLOW) { | |||
| MS_LOG(INFO) << "elementwise_op_strategy_follow: true."; | |||
| } else { | |||
| MS_LOG(INFO) << "elementwise_op_strategy_follow: false."; | |||
| } | |||
| // MULTI_SUBGRAPHS | |||
| auto multi_subgraphs = CostModelContext::GetInstance()->is_multi_subgraphs(); | |||
| MULTI_SUBGRAPHS = multi_subgraphs; | |||
| if (MULTI_SUBGRAPHS) { | |||
| MS_LOG(INFO) << "multi_subgraphs: true."; | |||
| } else { | |||
| MS_LOG(INFO) << "multi_subgraphs: false."; | |||
| } | |||
| auto overwrite = CostModelContext::GetInstance()->triangle_star_strategy_overwrite(); | |||
| TRIANGLE_STAR_STRATEGY_OVERWRITE = overwrite; | |||
| if (TRIANGLE_STAR_STRATEGY_OVERWRITE) { | |||
| MS_LOG(INFO) << "triangle_star_strategy_overwrite: true."; | |||
| } else { | |||
| MS_LOG(INFO) << "triangle_star_strategy_overwrite: false."; | |||
| } | |||
| // RUN_PHASE | |||
| auto phase = CostModelContext::GetInstance()->run_phase(); | |||
| if (phase != 0 && phase != 1) { | |||
| MS_LOG(EXCEPTION) << "'run_phase' must be in {0, 1}"; | |||
| } | |||
| RUN_PHASE = phase; | |||
| MS_LOG(INFO) << "run_phase: " << RUN_PHASE << "."; | |||
| auto enable_approx = CostModelContext::GetInstance()->dp_algo_enable_approxi(); | |||
| DP_ALGO_ENABLE_APPROX = enable_approx; | |||
| if (enable_approx) { | |||
| MS_LOG(INFO) << "dp_algo_enable_approx: true."; | |||
| } else { | |||
| MS_LOG(INFO) << "dp_algo_enable_approx: false."; | |||
| } | |||
| auto epsilon = CostModelContext::GetInstance()->dp_algo_approxi_epsilon(); | |||
| if (epsilon <= 0 || epsilon > 1) { | |||
| MS_LOG(EXCEPTION) << "'epsilon' must be in (0, 1]"; | |||
| } | |||
| DP_ALGO_APPROX_EPSILON = epsilon; | |||
| MS_LOG(INFO) << "epsilon: " << epsilon << "."; | |||
| auto single_loop = CostModelContext::GetInstance()->dp_algo_single_loop(); | |||
| DP_ALGO_SINGLE_LOOP = single_loop; | |||
| if (single_loop) { | |||
| MS_LOG(INFO) << "dp_algo_single_loop: true."; | |||
| } else { | |||
| MS_LOG(INFO) << "dp_algo_single_loop: false."; | |||
| } | |||
| } | |||
| void CostGraph::Init() { | |||
| inputs_tensor_name_list_.clear(); | |||
| @@ -269,7 +100,6 @@ std::vector<std::shared_ptr<CostGraph>> CostGraph::ConstructConnectedComponents( | |||
| if ((!visited[op]) && op->is_alive()) { | |||
| std::shared_ptr<CostGraph> new_component = std::make_shared<CostGraph>(); | |||
| MS_EXCEPTION_IF_NULL(new_component); | |||
| new_component->SetDeviceMemoryAndCostParameter(); | |||
| DFS(op, &visited, new_component); | |||
| connected_compoents_.push_back(new_component); | |||
| } | |||
| @@ -336,10 +166,11 @@ CostPtrList CostGraph::CreateFinalCostList(const OperatorInfoPtr &u, const std:: | |||
| auto decision = | |||
| std::make_shared<FinalDecision>(u_strategy->strategy_ptr, v_strategy->strategy_ptr, cost1, cost2, cost3); | |||
| auto cost = std::make_shared<Cost>(computation, communication, decision); | |||
| const auto gamma = CostModelContext::GetInstance()->costmodel_gamma(); | |||
| MS_EXCEPTION_IF_NULL(cost); | |||
| cost->communication_without_parameter_ = communication_without_para; | |||
| cost->communication_with_partial_para_ = | |||
| communication_without_para + COST_MODEL_GAMMA * (communication - communication_without_para); | |||
| communication_without_para + gamma * (communication - communication_without_para); | |||
| cost->memory_with_reuse_ = memory; | |||
| cost->communication_forward_ = communication_forward; | |||
| ret.push_back(cost); | |||
| @@ -353,7 +184,7 @@ CostPtrList CostGraph::CreateFinalCostList(const OperatorInfoPtr &u, const std:: | |||
| return ret; | |||
| } | |||
| // Create final cost list for the graph containing a signle node: u | |||
| // Create final cost list for the graph containing a single node: u | |||
| CostPtrList CostGraph::CreateFinalSingleCostList(const OperatorInfoPtr &u) { | |||
| MS_EXCEPTION_IF_NULL(u); | |||
| CostPtrList ret; | |||
| @@ -365,11 +196,12 @@ CostPtrList CostGraph::CreateFinalSingleCostList(const OperatorInfoPtr &u) { | |||
| MS_EXCEPTION_IF_NULL(cost1); | |||
| auto decision = std::make_shared<FinalSingleDecision>(u_strategy_ptr, cost1); | |||
| auto new_cost = std::make_shared<Cost>(cost1->computation_cost_, cost1->communication_cost_, decision); | |||
| const auto gamma = CostModelContext::GetInstance()->costmodel_gamma(); | |||
| MS_EXCEPTION_IF_NULL(new_cost); | |||
| new_cost->communication_without_parameter_ = cost1->communication_without_parameter_; | |||
| new_cost->communication_with_partial_para_ = | |||
| cost1->communication_without_parameter_ + | |||
| COST_MODEL_GAMMA * (cost1->communication_cost_ - cost1->communication_without_parameter_); | |||
| gamma * (cost1->communication_cost_ - cost1->communication_without_parameter_); | |||
| new_cost->memory_with_reuse_ = cost1->memory_with_reuse_; | |||
| new_cost->communication_forward_ = cost1->communication_forward_; | |||
| ret.push_back(new_cost); | |||
| @@ -404,8 +236,10 @@ CostPtr CostGraph::SelectCostWithMinInferenceTime(const CostPtrList &cost_list, | |||
| } | |||
| // Init the returned value with first cost. | |||
| CostPtr ret = after_mem_filter[0]; | |||
| const auto alpha = CostModelContext::GetInstance()->costmodel_alpha(); | |||
| const auto beta = CostModelContext::GetInstance()->costmodel_beta(); | |||
| double minimum = costmodel_alpha_ * ret->computation_cost_ + costmodel_beta_ * ret->communication_forward_; | |||
| double minimum = alpha * ret->computation_cost_ + beta * ret->communication_forward_; | |||
| MS_LOG(INFO) << "Cost 0: " | |||
| << "memory_cost: " << ret->memory_with_reuse_ << ", computation_cost_: " << ret->computation_cost_ | |||
| << ", communication_forward_: " << ret->communication_forward_ | |||
| @@ -422,8 +256,7 @@ CostPtr CostGraph::SelectCostWithMinInferenceTime(const CostPtrList &cost_list, | |||
| << ", communication_cost_: " << after_mem_filter[i]->communication_cost_ | |||
| << ", communication_without_parameter_: " << after_mem_filter[i]->communication_without_parameter_ | |||
| << "."; | |||
| auto tmp = costmodel_alpha_ * after_mem_filter[i]->computation_cost_ + | |||
| costmodel_beta_ * after_mem_filter[i]->communication_forward_; | |||
| auto tmp = alpha * after_mem_filter[i]->computation_cost_ + beta * after_mem_filter[i]->communication_forward_; | |||
| MS_LOG(INFO) << "Cost " << i << ": total_cost: " << tmp; | |||
| if (minimum > tmp) { | |||
| minimum = tmp; | |||
| @@ -458,8 +291,10 @@ CostPtr CostGraph::SelectCostWithMinTrainingTime(const CostPtrList &cost_list, d | |||
| } | |||
| // Init the returned value with first cost. | |||
| CostPtr ret = after_mem_filter[0]; | |||
| const auto alpha = CostModelContext::GetInstance()->costmodel_alpha(); | |||
| const auto beta = CostModelContext::GetInstance()->costmodel_beta(); | |||
| double minimum = costmodel_alpha_ * ret->computation_cost_ + costmodel_beta_ * ret->communication_with_partial_para_; | |||
| double minimum = alpha * ret->computation_cost_ + beta * ret->communication_with_partial_para_; | |||
| MS_LOG(INFO) << "Cost 0: " | |||
| << "memory_cost: " << ret->memory_with_reuse_ << ", computation_cost_: " << ret->computation_cost_ | |||
| << ", communication_with_partial_para_: " << ret->communication_with_partial_para_ | |||
| @@ -474,8 +309,8 @@ CostPtr CostGraph::SelectCostWithMinTrainingTime(const CostPtrList &cost_list, d | |||
| << ", communication_cost_: " << after_mem_filter[i]->communication_cost_ | |||
| << ", communication_without_parameter_: " << after_mem_filter[i]->communication_without_parameter_ | |||
| << "."; | |||
| auto tmp = costmodel_alpha_ * after_mem_filter[i]->computation_cost_ + | |||
| costmodel_beta_ * after_mem_filter[i]->communication_with_partial_para_; | |||
| auto tmp = | |||
| alpha * after_mem_filter[i]->computation_cost_ + beta * after_mem_filter[i]->communication_with_partial_para_; | |||
| MS_LOG(INFO) << "Cost " << i << ": total_cost: " << tmp; | |||
| if (minimum > tmp) { | |||
| minimum = tmp; | |||
| @@ -513,14 +348,16 @@ CostPtrList CostGraph::SelectCostListWithMinTrainingTimeMultiple(const std::vect | |||
| } | |||
| std::function<void(size_t)> recursive = [&all_cost_list, &selected_cost_list, &minimum, &ret, &recursive, | |||
| &available_memory, this](size_t k) { | |||
| &available_memory](size_t k) { | |||
| const auto alpha = CostModelContext::GetInstance()->costmodel_alpha(); | |||
| const auto beta = CostModelContext::GetInstance()->costmodel_beta(); | |||
| if (k == all_cost_list.size()) { | |||
| double tmp_memory = 0.0, tmp_minimum = 0.0; | |||
| for (size_t i = 0; i < selected_cost_list.size(); ++i) { | |||
| MS_EXCEPTION_IF_NULL(selected_cost_list[i]); | |||
| tmp_memory += selected_cost_list[i]->memory_with_reuse_; | |||
| tmp_minimum += costmodel_alpha_ * selected_cost_list[i]->computation_cost_ + | |||
| costmodel_beta_ * selected_cost_list[i]->communication_with_partial_para_; | |||
| tmp_minimum += alpha * selected_cost_list[i]->computation_cost_ + | |||
| beta * selected_cost_list[i]->communication_with_partial_para_; | |||
| } | |||
| MS_LOG(INFO) << "tmp_memory: " << tmp_memory << ", tmp_minimum: " << tmp_minimum << ", minimum: " << minimum | |||
| << "."; | |||
| @@ -582,12 +419,12 @@ Status CostGraph::SearchStrategyForMultiNodeFinalGraph(const std::vector<Operato | |||
| << " operators in a component in the final graph."; | |||
| } | |||
| } | |||
| // | |||
| auto selected_cost_list = SelectCostListWithMinTrainingTimeMultiple(all_list, dev_memory_); | |||
| const auto device_mem_capacity = CostModelContext::GetInstance()->device_memory_capacity(); | |||
| auto selected_cost_list = SelectCostListWithMinTrainingTimeMultiple(all_list, device_mem_capacity); | |||
| for (size_t k = 0; k < selected_cost_list.size(); ++k) { | |||
| auto selected_cost = selected_cost_list[k]; | |||
| if (selected_cost == nullptr) { | |||
| MS_LOG(ERROR) << "No vaild strategy can be found under the current device memory: " << dev_memory_ << "."; | |||
| MS_LOG(ERROR) << "No valid strategy can be found under the current device memory: " << device_mem_capacity << "."; | |||
| return FAILED; | |||
| } | |||
| MS_EXCEPTION_IF_NULL(connected_components[k]); | |||
| @@ -627,6 +464,99 @@ Status CostGraph::SearchStrategyForMultiNodeFinalGraph(const std::vector<Operato | |||
| return SUCCESS; | |||
| } | |||
| Status CostGraph::SearchStrategyForTwoNodeFinalGraph(const std::vector<OperatorInfoPtr> &alive_ops) { | |||
| // In this case, the final graph should contains exactly 2 nodes. | |||
| if (alive_ops.empty()) { | |||
| MS_LOG(INFO) << "0 Operator in the final graph."; | |||
| return SUCCESS; | |||
| } | |||
| OperatorInfoPtr u, v; | |||
| MS_EXCEPTION_IF_NULL(alive_ops[0]); | |||
| MS_EXCEPTION_IF_NULL(alive_ops[1]); | |||
| const auto phase = CostModelContext::GetInstance()->run_phase(); | |||
| const auto device_mem_capacity = CostModelContext::GetInstance()->device_memory_capacity(); | |||
| if (!alive_ops[0]->GetAliveSuccEdges().empty() && | |||
| alive_ops[0]->GetAliveSuccEdges()[0]->next_operator().get() == alive_ops[1].get()) { | |||
| u = alive_ops[0]; | |||
| v = alive_ops[1]; | |||
| } else if (!alive_ops[1]->GetAliveSuccEdges().empty() && | |||
| alive_ops[1]->GetAliveSuccEdges()[0]->next_operator().get() == alive_ops[0].get()) { | |||
| u = alive_ops[1]; | |||
| v = alive_ops[0]; | |||
| } else { | |||
| if (!alive_ops[0]->GetAliveSuccEdges().empty() || !alive_ops[1]->GetAliveSuccEdges().empty()) { | |||
| MS_LOG(EXCEPTION) << "The final graph is not the case of u --> v, " << alive_ops[0]->GetAliveSuccEdges().size() | |||
| << ", " << alive_ops[1]->GetAliveSuccEdges().size() << "."; | |||
| } else { | |||
| // In this case, the final graph consists of two single nodes | |||
| MS_LOG(INFO) << "There are 2 single nodes in the final graph."; | |||
| std::vector<CostPtrList> all_list; | |||
| auto connected_components = ConstructConnectedComponents(alive_ops); | |||
| MS_LOG(INFO) << "There are " << connected_components.size() << " components in the final graph."; | |||
| for (size_t i = 0; i < connected_components.size(); ++i) { | |||
| MS_LOG(INFO) << "There are 1 operator in a component in the final graph."; | |||
| auto one_component = connected_components[i]; | |||
| MS_EXCEPTION_IF_NULL(one_component); | |||
| auto cost_list = one_component->CreateFinalSingleCostList(one_component->GetOperators()[0]); | |||
| all_list.push_back(cost_list); | |||
| } | |||
| CostPtrList selected_cost_list; | |||
| if (phase == TRAINING_PHASE) { | |||
| // training phase | |||
| selected_cost_list = SelectCostListWithMinTrainingTimeMultiple(all_list, device_mem_capacity); | |||
| } else { | |||
| // inference phase | |||
| MS_LOG(EXCEPTION) << "Currently, searching strategy for the two-separated-node final graph in the inference " | |||
| "phase is not supported."; | |||
| } | |||
| for (size_t k = 0; k < selected_cost_list.size(); ++k) { | |||
| auto selected_cost = selected_cost_list[k]; | |||
| if (selected_cost == nullptr) { | |||
| MS_LOG(ERROR) << "No valid strategy can be found under the current device memory: " << device_mem_capacity | |||
| << "."; | |||
| return FAILED; | |||
| } | |||
| MS_EXCEPTION_IF_NULL(connected_components[k]); | |||
| auto one_operator = connected_components[k]->GetOperators()[0]; | |||
| MS_EXCEPTION_IF_NULL(selected_cost->decision_ptr_); | |||
| auto decision = selected_cost->decision_ptr_->cast<FinalSingleDecisionPtr>(); | |||
| MS_EXCEPTION_IF_NULL(decision); | |||
| one_operator->SetSelectedStrategyAndCost(decision->u_strategy_, decision->u_cost_); | |||
| MS_LOG(INFO) << "Searching the strategy for the component " << k << " final graph ended."; | |||
| } | |||
| return SUCCESS; | |||
| } | |||
| } | |||
| MS_LOG(INFO) << "There are 2 nodes in the final graph."; | |||
| // In this case, the finale graph is exactly of the form: u --> v | |||
| MS_EXCEPTION_IF_NULL(u); | |||
| MS_EXCEPTION_IF_NULL(v); | |||
| auto e = u->GetAliveSuccEdges()[0]; | |||
| MS_EXCEPTION_IF_NULL(e); | |||
| auto cost_list = CreateFinalCostList(u, e, v); | |||
| CostPtr cost = nullptr; | |||
| if (phase == TRAINING_PHASE) { | |||
| // training phase | |||
| cost = SelectCostWithMinTrainingTime(cost_list, device_mem_capacity); | |||
| } else { | |||
| MS_LOG(EXCEPTION) << "Currently, searching strategy for the two-connected-node final graph in the inference " | |||
| "phase is not supported."; | |||
| } | |||
| if (cost == nullptr) { | |||
| MS_LOG(ERROR) << "No valid strategy can be found under the current device memory: " << device_mem_capacity << "."; | |||
| return FAILED; | |||
| } | |||
| MS_EXCEPTION_IF_NULL(cost->decision_ptr_); | |||
| auto decision = cost->decision_ptr_->cast<FinalDecisionPtr>(); | |||
| MS_EXCEPTION_IF_NULL(decision); | |||
| u->SetSelectedStrategyAndCost(decision->u_strategy_, decision->left_cost_); | |||
| v->SetSelectedStrategyAndCost(decision->v_strategy_, decision->right_cost_); | |||
| e->set_selected_cost(decision->middle_cost_); | |||
| MS_LOG(INFO) << "Searching the strategy for the eliminated final graph ended."; | |||
| return SUCCESS; | |||
| } | |||
| // searching the strategy for the final eliminated graph | |||
| Status CostGraph::SearchStrategy() { | |||
| MS_LOG(INFO) << "Searching the strategy for the eliminated final graph began."; | |||
| @@ -637,9 +567,11 @@ Status CostGraph::SearchStrategy() { | |||
| alive_ops.push_back(op); | |||
| } | |||
| }); | |||
| const auto phase = CostModelContext::GetInstance()->run_phase(); | |||
| const auto device_mem_capacity = CostModelContext::GetInstance()->device_memory_capacity(); | |||
| if (alive_ops.size() > 2) { | |||
| if (RUN_PHASE == TRAINING_PHASE) { | |||
| if (phase == TRAINING_PHASE) { | |||
| // training phase | |||
| return SearchStrategyForMultiNodeFinalGraph(alive_ops); | |||
| } else { | |||
| @@ -652,15 +584,15 @@ Status CostGraph::SearchStrategy() { | |||
| OperatorInfoPtr u = alive_ops[0]; | |||
| auto cost_list = CreateFinalSingleCostList(u); | |||
| CostPtr cost = nullptr; | |||
| if (RUN_PHASE == TRAINING_PHASE) { | |||
| if (phase == TRAINING_PHASE) { | |||
| // training phase | |||
| cost = SelectCostWithMinTrainingTime(cost_list, dev_memory_); | |||
| cost = SelectCostWithMinTrainingTime(cost_list, device_mem_capacity); | |||
| } else { | |||
| // inference phase | |||
| cost = SelectCostWithMinInferenceTime(cost_list, dev_memory_); | |||
| cost = SelectCostWithMinInferenceTime(cost_list, device_mem_capacity); | |||
| } | |||
| if (cost == nullptr) { | |||
| MS_LOG(ERROR) << "No vaild strategy can be found under the current device memory: " << dev_memory_ << "."; | |||
| MS_LOG(ERROR) << "No valid strategy can be found under the current device memory: " << device_mem_capacity << "."; | |||
| return FAILED; | |||
| } | |||
| MS_EXCEPTION_IF_NULL(u); | |||
| @@ -671,93 +603,7 @@ Status CostGraph::SearchStrategy() { | |||
| MS_LOG(INFO) << "Searching the strategy for the eliminated final graph ended."; | |||
| return SUCCESS; | |||
| } else { | |||
| // In this case, the final graph should contains exactly 2 nodes. | |||
| if (alive_ops.empty()) { | |||
| MS_LOG(INFO) << "0 Operator in the final graph."; | |||
| return SUCCESS; | |||
| } | |||
| OperatorInfoPtr u, v; | |||
| MS_EXCEPTION_IF_NULL(alive_ops[0]); | |||
| MS_EXCEPTION_IF_NULL(alive_ops[1]); | |||
| if (!alive_ops[0]->GetAliveSuccEdges().empty() && | |||
| alive_ops[0]->GetAliveSuccEdges()[0]->next_operator().get() == alive_ops[1].get()) { | |||
| u = alive_ops[0]; | |||
| v = alive_ops[1]; | |||
| } else if (!alive_ops[1]->GetAliveSuccEdges().empty() && | |||
| alive_ops[1]->GetAliveSuccEdges()[0]->next_operator().get() == alive_ops[0].get()) { | |||
| u = alive_ops[1]; | |||
| v = alive_ops[0]; | |||
| } else { | |||
| if (!alive_ops[0]->GetAliveSuccEdges().empty() || !alive_ops[1]->GetAliveSuccEdges().empty()) { | |||
| MS_LOG(EXCEPTION) << "The final graph is not the case of u --> v, " << alive_ops[0]->GetAliveSuccEdges().size() | |||
| << ", " << alive_ops[1]->GetAliveSuccEdges().size() << "."; | |||
| } else { | |||
| // In this case, the final graph consists of two single nodes | |||
| MS_LOG(INFO) << "There are 2 single nodes in the final graph."; | |||
| std::vector<CostPtrList> all_list; | |||
| auto connected_components = ConstructConnectedComponents(alive_ops); | |||
| MS_LOG(INFO) << "There are " << connected_components.size() << " components in the final graph."; | |||
| for (size_t i = 0; i < connected_components.size(); ++i) { | |||
| MS_LOG(INFO) << "There are 1 operator in a component in the final graph."; | |||
| auto one_component = connected_components[i]; | |||
| MS_EXCEPTION_IF_NULL(one_component); | |||
| auto cost_list = one_component->CreateFinalSingleCostList(one_component->GetOperators()[0]); | |||
| all_list.push_back(cost_list); | |||
| } | |||
| CostPtrList selected_cost_list; | |||
| if (RUN_PHASE == TRAINING_PHASE) { | |||
| // training phase | |||
| selected_cost_list = SelectCostListWithMinTrainingTimeMultiple(all_list, dev_memory_); | |||
| } else { | |||
| // inference phase | |||
| MS_LOG(EXCEPTION) << "Currently, searching strategy for the two-separated-node final graph in the inference " | |||
| "phase is not supported."; | |||
| } | |||
| for (size_t k = 0; k < selected_cost_list.size(); ++k) { | |||
| auto selected_cost = selected_cost_list[k]; | |||
| if (selected_cost == nullptr) { | |||
| MS_LOG(ERROR) << "No vaild strategy can be found under the current device memory: " << dev_memory_ << "."; | |||
| return FAILED; | |||
| } | |||
| MS_EXCEPTION_IF_NULL(connected_components[k]); | |||
| auto one_operator = connected_components[k]->GetOperators()[0]; | |||
| MS_EXCEPTION_IF_NULL(selected_cost->decision_ptr_); | |||
| auto decision = selected_cost->decision_ptr_->cast<FinalSingleDecisionPtr>(); | |||
| MS_EXCEPTION_IF_NULL(decision); | |||
| one_operator->SetSelectedStrategyAndCost(decision->u_strategy_, decision->u_cost_); | |||
| MS_LOG(INFO) << "Searching the strategy for the component " << k << " final graph ended."; | |||
| } | |||
| return SUCCESS; | |||
| } | |||
| } | |||
| MS_LOG(INFO) << "There are 2 nodes in the final graph."; | |||
| // In this case, the finale graph is exactly of the form: u --> v | |||
| MS_EXCEPTION_IF_NULL(u); | |||
| MS_EXCEPTION_IF_NULL(v); | |||
| auto e = u->GetAliveSuccEdges()[0]; | |||
| MS_EXCEPTION_IF_NULL(e); | |||
| auto cost_list = CreateFinalCostList(u, e, v); | |||
| CostPtr cost = nullptr; | |||
| if (RUN_PHASE == TRAINING_PHASE) { | |||
| // training phase | |||
| cost = SelectCostWithMinTrainingTime(cost_list, dev_memory_); | |||
| } else { | |||
| MS_LOG(EXCEPTION) << "Currently, searching strategy for the two-connected-node final graph in the inference " | |||
| "phase is not supported."; | |||
| } | |||
| if (cost == nullptr) { | |||
| MS_LOG(ERROR) << "No vaild strategy can be found under the current device memory: " << dev_memory_ << "."; | |||
| return FAILED; | |||
| } | |||
| MS_EXCEPTION_IF_NULL(cost->decision_ptr_); | |||
| auto decision = cost->decision_ptr_->cast<FinalDecisionPtr>(); | |||
| MS_EXCEPTION_IF_NULL(decision); | |||
| u->SetSelectedStrategyAndCost(decision->u_strategy_, decision->left_cost_); | |||
| v->SetSelectedStrategyAndCost(decision->v_strategy_, decision->right_cost_); | |||
| e->set_selected_cost(decision->middle_cost_); | |||
| MS_LOG(INFO) << "Searching the strategy for the eliminated final graph ended."; | |||
| return SUCCESS; | |||
| return SearchStrategyForTwoNodeFinalGraph(alive_ops); | |||
| } | |||
| } | |||
| @@ -867,10 +713,11 @@ void CostGraph::CreateSourceEliminationSubCostList(StrategyPtr op1_old_stra, con | |||
| op1_cost->communication_without_parameter_ + op2_cost->communication_without_parameter_; | |||
| auto decision = std::make_shared<SourceEliminationDecision>(op1_old_stra, op1_cost, op2_old_stra, op2_cost); | |||
| auto new_cost = std::make_shared<Cost>(computation, communication, decision); | |||
| const auto gamma = CostModelContext::GetInstance()->costmodel_gamma(); | |||
| MS_EXCEPTION_IF_NULL(new_cost); | |||
| new_cost->communication_without_parameter_ = communication_without_para; | |||
| new_cost->communication_with_partial_para_ = | |||
| communication_without_para + COST_MODEL_GAMMA * (communication - communication_without_para); | |||
| communication_without_para + gamma * (communication - communication_without_para); | |||
| new_cost->memory_with_reuse_ = memory; | |||
| new_cost->communication_forward_ = communication_forward; | |||
| MS_EXCEPTION_IF_NULL(op1_new_clist); | |||
| @@ -879,6 +726,65 @@ void CostGraph::CreateSourceEliminationSubCostList(StrategyPtr op1_old_stra, con | |||
| } | |||
| } | |||
| std::pair<std::vector<EdgePtr>, std::vector<EdgePtr>> UpdateEdgesIncidentToNodes( | |||
| OperatorInfoPtr op1, std::vector<EdgePtr> *op1_old_succ_edges, | |||
| std::vector<std::map<CostPtrKey, CostPtrList>> *op1_new_edges_cost, std::vector<EdgePtr> *op1_new_succ_edges, | |||
| OperatorInfoPtr op2, std::vector<EdgePtr> *op2_old_succ_edges, | |||
| std::vector<std::map<CostPtrKey, CostPtrList>> *op2_new_edges_cost, std::vector<EdgePtr> *op2_new_succ_edges) { | |||
| for (size_t i = 0; i < op1_old_succ_edges->size(); ++i) { | |||
| auto &new_cost_map = op1_new_edges_cost->at(i); | |||
| auto ith_edge = op1_old_succ_edges->at(i); | |||
| std::string new_edge_name = op1->name() + OPERATOR_TO_OPERATOR_CONNECTOR + ith_edge->next_operator()->name(); | |||
| std::shared_ptr<Edge> new_edge; | |||
| if (ith_edge->is_combined()) { | |||
| std::vector<size_t> output_indexs, input_indexs; | |||
| output_indexs = ith_edge->prev_op_output_indexs(); | |||
| input_indexs = ith_edge->next_op_input_indexs(); | |||
| new_edge = | |||
| std::make_shared<Edge>(new_edge_name, op1, ith_edge->next_operator(), output_indexs, input_indexs, true); | |||
| } else { | |||
| size_t output_index, input_index; | |||
| output_index = ith_edge->prev_op_output_index(); | |||
| input_index = ith_edge->next_op_input_index(); | |||
| new_edge = | |||
| std::make_shared<Edge>(new_edge_name, op1, ith_edge->next_operator(), output_index, input_index, false); | |||
| } | |||
| new_edge->SetCostMapAndInputOutput(new_cost_map); | |||
| // replace the old successive edges with the new ones. | |||
| op1->ReplaceSuccEdge(ith_edge->next_operator(), new_edge); | |||
| ith_edge->next_operator()->ReplacePreEdge(op1, new_edge); | |||
| op1_new_succ_edges->erase(op1_new_succ_edges->begin() + i); | |||
| op1_new_succ_edges->emplace(op1_new_succ_edges->begin() + i, new_edge); | |||
| } | |||
| for (size_t i = 0; i < op2_old_succ_edges->size(); ++i) { | |||
| auto &new_cost_map = op2_new_edges_cost->at(i); | |||
| auto ith_edge = op2_old_succ_edges->at(i); | |||
| const auto &destination = ith_edge->next_operator(); | |||
| std::string new_edge_name = op1->name() + OPERATOR_TO_OPERATOR_CONNECTOR + destination->name(); | |||
| std::shared_ptr<Edge> new_edge; | |||
| if (ith_edge->is_combined()) { | |||
| std::vector<size_t> output_indexs, input_indexs; | |||
| output_indexs = ith_edge->prev_op_output_indexs(); | |||
| input_indexs = ith_edge->next_op_input_indexs(); | |||
| new_edge = std::make_shared<Edge>(new_edge_name, op1, destination, output_indexs, input_indexs, true); | |||
| } else { | |||
| size_t output_index, input_index; | |||
| output_index = ith_edge->prev_op_output_index(); | |||
| input_index = ith_edge->next_op_input_index(); | |||
| new_edge = std::make_shared<Edge>(new_edge_name, op1, destination, output_index, input_index, false); | |||
| } | |||
| new_edge->SetCostMapAndInputOutput(new_cost_map); | |||
| // replace the old successive edges with the new ones. | |||
| destination->ReplacePreEdge(op2, new_edge); | |||
| op1->AddSuccEdge(new_edge); | |||
| op2_new_succ_edges->erase(op2_new_succ_edges->begin() + i); | |||
| op2_new_succ_edges->emplace(op2_new_succ_edges->begin() + i, new_edge); | |||
| } | |||
| return std::make_pair(*op1_new_succ_edges, *op2_new_succ_edges); | |||
| } | |||
| std::pair<std::vector<std::shared_ptr<Edge>>, std::vector<std::shared_ptr<Edge>>> CostGraph::EliminationSources( | |||
| OperatorInfoPtr op1, OperatorInfoPtr op2) { | |||
| MS_EXCEPTION_IF_NULL(op1); | |||
| @@ -970,57 +876,9 @@ std::pair<std::vector<std::shared_ptr<Edge>>, std::vector<std::shared_ptr<Edge>> | |||
| op2->SetNotAlive(); | |||
| // Update the edges incident to op1, and edges incident to op2 | |||
| for (size_t i = 0; i < op1_old_succ_edges.size(); ++i) { | |||
| auto &new_cost_map = op1_new_edges_cost[i]; | |||
| auto &ith_edge = op1_old_succ_edges[i]; | |||
| std::string new_edge_name = op1->name() + OPERATOR_TO_OPERATOR_CONNECTOR + ith_edge->next_operator()->name(); | |||
| std::shared_ptr<Edge> new_edge; | |||
| if (ith_edge->is_combined()) { | |||
| std::vector<size_t> output_indexs, input_indexs; | |||
| output_indexs = ith_edge->prev_op_output_indexs(); | |||
| input_indexs = ith_edge->next_op_input_indexs(); | |||
| new_edge = | |||
| std::make_shared<Edge>(new_edge_name, op1, ith_edge->next_operator(), output_indexs, input_indexs, true); | |||
| } else { | |||
| size_t output_index, input_index; | |||
| output_index = ith_edge->prev_op_output_index(); | |||
| input_index = ith_edge->next_op_input_index(); | |||
| new_edge = | |||
| std::make_shared<Edge>(new_edge_name, op1, ith_edge->next_operator(), output_index, input_index, false); | |||
| } | |||
| new_edge->SetCostMapAndInputOutput(new_cost_map); | |||
| // replace the old successive edges with the new ones. | |||
| op1->ReplaceSuccEdge(ith_edge->next_operator(), new_edge); | |||
| ith_edge->next_operator()->ReplacePreEdge(op1, new_edge); | |||
| op1_new_succ_edges[i] = new_edge; | |||
| } | |||
| for (size_t i = 0; i < op2_old_succ_edges.size(); ++i) { | |||
| auto &new_cost_map = op2_new_edges_cost[i]; | |||
| auto &ith_edge = op2_old_succ_edges[i]; | |||
| const auto &destination = ith_edge->next_operator(); | |||
| std::string new_edge_name = op1->name() + OPERATOR_TO_OPERATOR_CONNECTOR + destination->name(); | |||
| std::shared_ptr<Edge> new_edge; | |||
| if (ith_edge->is_combined()) { | |||
| std::vector<size_t> output_indexs, input_indexs; | |||
| output_indexs = ith_edge->prev_op_output_indexs(); | |||
| input_indexs = ith_edge->next_op_input_indexs(); | |||
| new_edge = std::make_shared<Edge>(new_edge_name, op1, destination, output_indexs, input_indexs, true); | |||
| } else { | |||
| size_t output_index, input_index; | |||
| output_index = ith_edge->prev_op_output_index(); | |||
| input_index = ith_edge->next_op_input_index(); | |||
| new_edge = std::make_shared<Edge>(new_edge_name, op1, destination, output_index, input_index, false); | |||
| } | |||
| new_edge->SetCostMapAndInputOutput(new_cost_map); | |||
| // replace the old successive edges with the new ones. | |||
| destination->ReplacePreEdge(op2, new_edge); | |||
| op1->AddSuccEdge(new_edge); | |||
| op2_new_succ_edges[i] = new_edge; | |||
| } | |||
| MS_LOG(INFO) << "Source eliminating node: " << op2->name() << " to node: " << op1->name() + " succeeded."; | |||
| return {op1_new_succ_edges, op2_new_succ_edges}; | |||
| return UpdateEdgesIncidentToNodes(op1, &op1_old_succ_edges, &op1_new_edges_cost, &op1_new_succ_edges, op2, | |||
| &op2_old_succ_edges, &op2_new_edges_cost, &op2_new_succ_edges); | |||
| } | |||
| // Check the graph whether a TriangleElimination can be performed | |||
| @@ -1179,10 +1037,11 @@ void CostGraph::CreateMergeEliminationSubCostList(StrategyPtr op_strategy, const | |||
| auto decision = | |||
| std::make_shared<MergeEliminationDecision>(op_strategy, op_cost, edge_cost, tar_op_strategy, tar_cost); | |||
| auto new_cost = std::make_shared<Cost>(computation, communication, decision); | |||
| const auto gamma = CostModelContext::GetInstance()->costmodel_gamma(); | |||
| MS_EXCEPTION_IF_NULL(new_cost); | |||
| new_cost->communication_without_parameter_ = communication_without_para; | |||
| new_cost->communication_with_partial_para_ = | |||
| communication_without_para + COST_MODEL_GAMMA * (communication - communication_without_para); | |||
| communication_without_para + gamma * (communication - communication_without_para); | |||
| new_cost->memory_with_reuse_ = memory; | |||
| new_cost->communication_forward_ = communication_forward; | |||
| MS_EXCEPTION_IF_NULL(tar_cost_list_new); | |||
| @@ -1263,9 +1122,10 @@ void CostGraph::CreateContractEliminationSubCostList(StrategyPtr contract_op_str | |||
| auto decision = std::make_shared<ContractEliminationDecision>(contract_op_stra, contract_op_cost, edge_cost, | |||
| target_op_stra, tar_cost); | |||
| auto new_cost = std::make_shared<Cost>(computation, communication, decision); | |||
| auto gamma = CostModelContext::GetInstance()->costmodel_gamma(); | |||
| new_cost->communication_without_parameter_ = communication_without_para; | |||
| new_cost->communication_with_partial_para_ = | |||
| communication_without_para + COST_MODEL_GAMMA * (communication - communication_without_para); | |||
| communication_without_para + gamma * (communication - communication_without_para); | |||
| new_cost->memory_with_reuse_ = memory; | |||
| new_cost->communication_forward_ = communication_forward; | |||
| tar_cost_list_new->emplace_back(std::move(new_cost)); | |||
| @@ -1338,8 +1198,10 @@ void CostGraph::CreateTriangleEliminationSubCostList(StrategyPtr elimi_op_stra, | |||
| double new_commu_without = | |||
| elimi_op_cost->communication_without_parameter_ + left_edge_cost->communication_without_parameter_ + | |||
| left_node_cost->communication_without_parameter_ + right_edge_cost->communication_without_parameter_; | |||
| const auto triangle_star_stra_overwrite = CostModelContext::GetInstance()->triangle_star_strategy_overwrite(); | |||
| const auto gamma = CostModelContext::GetInstance()->costmodel_gamma(); | |||
| if (TRIANGLE_STAR_STRATEGY_OVERWRITE) { | |||
| if (triangle_star_stra_overwrite) { | |||
| new_computation += right_op_cost->computation_cost_; | |||
| new_memory += right_op_cost->memory_with_reuse_; | |||
| new_commu_cost += right_op_cost->communication_cost_; | |||
| @@ -1352,8 +1214,7 @@ void CostGraph::CreateTriangleEliminationSubCostList(StrategyPtr elimi_op_stra, | |||
| left_op_stra, left_node_cost, right_op_stra, right_op_cost); | |||
| auto new_cost = std::make_shared<Cost>(new_computation, new_commu_cost, decision); | |||
| new_cost->communication_without_parameter_ = new_commu_without; | |||
| new_cost->communication_with_partial_para_ = | |||
| new_commu_without + COST_MODEL_GAMMA * (new_commu_cost - new_commu_without); | |||
| new_cost->communication_with_partial_para_ = new_commu_without + gamma * (new_commu_cost - new_commu_without); | |||
| new_cost->memory_with_reuse_ = new_memory; | |||
| new_cost->communication_forward_ = new_commu_forward; | |||
| left_node_clist_new->emplace_back(std::move(new_cost)); | |||
| @@ -1463,6 +1324,8 @@ void CostGraph::CreateStarEliminationSubCostList(const StrategyPtr &first_succ_n | |||
| memory_cost = merged_node_cost->memory_with_reuse_, commu_cost = merged_node_cost->communication_cost_, | |||
| commu_without = merged_node_cost->communication_without_parameter_, | |||
| commu_forward = merged_node_cost->communication_forward_; | |||
| const auto triangle_star_stra_overwrite = CostModelContext::GetInstance()->triangle_star_strategy_overwrite(); | |||
| const auto gamma = CostModelContext::GetInstance()->costmodel_gamma(); | |||
| for (size_t i = 0; i < succ_nodes_stras.size(); ++i) { | |||
| MS_EXCEPTION_IF_NULL(succ_edges_costs[i]); | |||
| if (i == 0) { | |||
| @@ -1478,7 +1341,7 @@ void CostGraph::CreateStarEliminationSubCostList(const StrategyPtr &first_succ_n | |||
| commu_cost += succ_edges_costs[i]->communication_cost_; | |||
| commu_forward += succ_edges_costs[i]->communication_forward_; | |||
| commu_without += succ_edges_costs[i]->communication_without_parameter_; | |||
| if (TRIANGLE_STAR_STRATEGY_OVERWRITE) { | |||
| if (triangle_star_stra_overwrite) { | |||
| computation_cost += succ_nodes_costs[i]->computation_cost_; | |||
| memory_cost += succ_nodes_costs[i]->memory_with_reuse_; | |||
| commu_cost += succ_nodes_costs[i]->communication_cost_; | |||
| @@ -1492,7 +1355,7 @@ void CostGraph::CreateStarEliminationSubCostList(const StrategyPtr &first_succ_n | |||
| succ_nodes_stras, succ_nodes_costs); | |||
| auto new_cost = std::make_shared<Cost>(computation_cost, commu_cost, decision); | |||
| new_cost->communication_without_parameter_ = commu_without; | |||
| new_cost->communication_with_partial_para_ = commu_without + COST_MODEL_GAMMA * (commu_cost - commu_without); | |||
| new_cost->communication_with_partial_para_ = commu_without + gamma * (commu_cost - commu_without); | |||
| new_cost->memory_with_reuse_ = memory_cost; | |||
| new_cost->communication_forward_ = commu_forward; | |||
| first_succ_node_clist_new->emplace_back(std::move(new_cost)); | |||
| @@ -1895,7 +1758,8 @@ Status CostGraph::CorrectOpsMemoryCost() { | |||
| } | |||
| Status CostGraph::CalculateMemoryCost() { | |||
| if (RUN_PHASE == TRAINING_PHASE) { | |||
| const auto phase = CostModelContext::GetInstance()->run_phase(); | |||
| if (phase == TRAINING_PHASE) { | |||
| // training phase | |||
| if (ComputeOpsAndEdgesParameterInvolved() == SUCCESS) { | |||
| // Calculate operators' memory usage | |||
| @@ -34,32 +34,12 @@ class CostGraph; | |||
| using CostGraphPtr = std::shared_ptr<CostGraph>; | |||
| extern CostGraphPtr entire_costgraph; | |||
| extern size_t TOTAL_OPS; | |||
| extern double COST_MODEL_GAMMA; | |||
| extern bool COST_MODEL_SIMPLIFY_CALCULATION; | |||
| extern double DEVICE_MEMORY_CAPACITY; | |||
| extern double COST_MODEL_COMMUNI_THRESHOLD; | |||
| extern double COST_MODEL_COMMUNI_CONST; | |||
| extern double COST_MODEL_COMMUNI_BIAS; | |||
| extern bool TENSOR_SLICE_ALIGNMENT_ENABLE; | |||
| extern size_t TENSOR_SLICE_ALIGNMENT_SIZE; | |||
| extern bool FULLY_USE_DEVICES; | |||
| extern bool ELEMENTWISE_OP_STRA_FOLLOW; | |||
| extern bool MULTI_SUBGRAPHS; | |||
| extern bool DP_ALGO_ENABLE_APPROX; | |||
| extern double DP_ALGO_APPROX_EPSILON; | |||
| extern int64_t RUN_PHASE; | |||
| extern bool TRIANGLE_STAR_STRATEGY_OVERWRITE; | |||
| extern bool DP_ALGO_SINGLE_LOOP; | |||
| class CostGraph { | |||
| // 'CostGraph' consists of Operators and edges between them. An edge is created between two Operators if they have | |||
| // output-input dependency relationship. | |||
| public: | |||
| CostGraph() { | |||
| dev_memory_ = DEFAULT_DEVICE_MEMORY_CAPACITY; | |||
| costmodel_alpha_ = DEFAULT_COST_MODEL_ALPHA; | |||
| costmodel_beta_ = DEFAULT_COST_MODEL_BETA_ASCEND; | |||
| } | |||
| CostGraph() {} | |||
| ~CostGraph() = default; | |||
| void Init(); | |||
| void AddOperator(const OperatorInfoPtr &op) { ops_.push_back(op); } | |||
| @@ -79,8 +59,6 @@ class CostGraph { | |||
| // An edge is uniquely identified by its name, and its output index and input index. | |||
| bool IsEdgeInCostGraph(const std::string &, size_t, size_t); | |||
| void SetDeviceMemoryAndCostParameter(); | |||
| std::vector<std::shared_ptr<CostGraph>> ConstructConnectedComponents(std::vector<OperatorInfoPtr>); | |||
| void DFS(const OperatorInfoPtr ¤t_op, std::map<OperatorInfoPtr, bool> *visited, | |||
| const std::shared_ptr<CostGraph> &component); | |||
| @@ -91,10 +69,10 @@ class CostGraph { | |||
| CostPtr SelectCostWithMinTrainingTime(const CostPtrList &cost_list, double memory); | |||
| CostPtrList SelectCostListWithMinTrainingTimeMultiple(const std::vector<CostPtrList> &all_costlist, double memory); | |||
| Status SearchStrategyForMultiNodeFinalGraph(const std::vector<OperatorInfoPtr> &); | |||
| Status SearchStrategyForTwoNodeFinalGraph(const std::vector<OperatorInfoPtr> &); | |||
| std::vector<std::shared_ptr<Edge>> GetOriginalEdgeBetweenOperators(OperatorInfoPtr u_node, OperatorInfoPtr v_node) { | |||
| return edges_[{u_node, v_node}]; | |||
| } | |||
| double GetDeviceMemory() const { return dev_memory_; } | |||
| // Search the cost_list in the final graph, and determine the optimal one | |||
| Status SearchStrategy(); | |||
| @@ -194,7 +172,7 @@ class CostGraph { | |||
| Status InitReshapeStrategy(); | |||
| Status InitSelectedStrategy(); | |||
| OperatorInfoPtr FindTmpIdentityByParameterName(std::string &) const; | |||
| // When TmpIdentity is used by mulitple operators, the corresponding parameter's memory cost should be calculated only | |||
| // When TmpIdentity is used by multiple operators, the corresponding parameter's memory cost should be calculated only | |||
| // once (instead of multiple times), this method is used to correct this. | |||
| Status CorrectOpsMemoryCost(); | |||
| // When APPROXIMATION is enabled in the DP algorithm, some edges may have no valid strategies. | |||
| @@ -224,9 +202,6 @@ class CostGraph { | |||
| // Needed by rec_parser | |||
| std::vector<std::vector<std::string>> inputs_tensor_name_list_; | |||
| std::map<std::string, std::string> tuple_getitem_list_; | |||
| double dev_memory_; | |||
| double costmodel_alpha_; | |||
| double costmodel_beta_; | |||
| std::vector<OperatorInfoPtr> ops_; | |||
| std::map<std::pair<OperatorInfoPtr, OperatorInfoPtr>, std::vector<EdgePtr>> edges_; | |||
| std::vector<std::shared_ptr<CostGraph>> connected_compoents_; | |||
| @@ -47,7 +47,7 @@ void CostModelContext::ResetCostModel() { | |||
| costmodel_communi_const_ = DEFAULT_COST_MODEL_COMMUNI_CONST; | |||
| costmodel_communi_bias_ = DEFAULT_COST_MODEL_COMMUNI_BIAS; | |||
| is_multi_subgraphs_ = DEFAULT_IS_MULTI_SUBGRAPHS; | |||
| run_phase_ = DEFAULT_RUN_PHASE; | |||
| run_phase_ = TRAINING_PHASE; | |||
| costmodel_allreduce_fusion_algorithm_ = DEFAULT_COST_MODEL_ALLREDUCE_FUSION_ALGORITHM; | |||
| costmodel_allreduce_fusion_times_ = DEFAULT_COST_MODEL_ALLREDUCE_FUSION_TIMES; | |||
| costmodel_allreduce_fusion_tail_percent_ = DEFAULT_COST_MODEL_ALLREDUCE_FUSION_TAIL_PERCENT; | |||
| @@ -70,37 +70,115 @@ void CostModelContext::ResetAlgoParameters() { | |||
| dp_algo_approxi_epsilon_ = DEFAULT_DP_ALGO_APPROX_EPSILON; | |||
| } | |||
| void CostModelContext::PrintCostModel() { | |||
| MS_LOG(INFO) << "device_memory_capacity: " << device_memory_capacity_ << "."; | |||
| MS_LOG(INFO) << "costmodel_alpha: " << costmodel_alpha_ << "."; | |||
| MS_LOG(INFO) << "costmodel_beta: " << costmodel_beta_ << "."; | |||
| MS_LOG(INFO) << "costmodel_gamma: " << costmodel_gamma_ << "."; | |||
| MS_LOG(INFO) << "costmodel_simplify_cal: " << costmodel_simplify_cal_ << "."; | |||
| MS_LOG(INFO) << "costmodel_communi_threshold: " << costmodel_communi_threshold_ << "."; | |||
| MS_LOG(INFO) << "costmodel_communi_const: " << costmodel_communi_const_ << "."; | |||
| MS_LOG(INFO) << "costmodel_communi_bias: " << costmodel_communi_bias_ << "."; | |||
| MS_LOG(INFO) << "is_multi_subgraphs: " << is_multi_subgraphs_ << "."; | |||
| MS_LOG(INFO) << "triangle_star_strategy_overwrite: " << triangle_star_strategy_overwrite_ << "."; | |||
| MS_LOG(INFO) << "dp_algo_enable_approxi: " << dp_algo_enable_approxi_ << "."; | |||
| MS_LOG(INFO) << "dp_algo_approxi_epsilon: " << dp_algo_approxi_epsilon_ << "."; | |||
| MS_LOG(INFO) << "dp_algo_single_loop: " << dp_algo_single_loop_ << "."; | |||
| MS_LOG(INFO) << "run_phase: " << run_phase_ << "."; | |||
| MS_LOG(INFO) << "tensor_slice_alignment_enable: " << tensor_slice_alignment_enable_ << "."; | |||
| MS_LOG(INFO) << "tensor_slice_align_size: " << tensor_slice_alignment_size_ << "."; | |||
| MS_LOG(INFO) << "fully_use_device: " << fully_use_device_ << "."; | |||
| MS_LOG(INFO) << "elementwise_stra_follow: " << elementwise_stra_follow_ << "."; | |||
| } | |||
| void CostModelContext::set_costmodel_context_for_device(const std::string &device_target) { | |||
| if (device_target == kGPUDevice) { | |||
| costmodel_beta_ = DEFAULT_COST_MODEL_BETA_GPU; | |||
| } | |||
| } | |||
| void CostModelContext::set_dp_algo_approxi_epsilon(double epsilon) { dp_algo_approxi_epsilon_ = epsilon; } | |||
| void CostModelContext::set_dp_algo_approxi_epsilon(double epsilon) { | |||
| if (epsilon <= 0 || epsilon > 1) { | |||
| MS_LOG(EXCEPTION) << "'epsilon' must be in (0, 1]"; | |||
| } | |||
| dp_algo_approxi_epsilon_ = epsilon; | |||
| } | |||
| void CostModelContext::set_dp_algo_enable_approxi(bool approxi) { dp_algo_enable_approxi_ = approxi; } | |||
| void CostModelContext::set_dp_algo_enable_approxi(bool approxi) { | |||
| if (approxi) { | |||
| MS_LOG(INFO) << "dp_algo_enable_approx: true."; | |||
| } else { | |||
| MS_LOG(INFO) << "dp_algo_enable_approx: false."; | |||
| } | |||
| dp_algo_enable_approxi_ = approxi; | |||
| } | |||
| void CostModelContext::set_device_memory_capacity(double dm_capacity) { device_memory_capacity_ = dm_capacity; } | |||
| void CostModelContext::set_device_memory_capacity(double dm_capacity) { | |||
| if (dm_capacity <= 0) { | |||
| MS_LOG(EXCEPTION) << "'device_memory_capacity' must be positive."; | |||
| } | |||
| device_memory_capacity_ = dm_capacity; | |||
| } | |||
| void CostModelContext::set_costmodel_alpha(double cm_alpha) { costmodel_alpha_ = cm_alpha; } | |||
| void CostModelContext::set_costmodel_alpha(double cm_alpha) { | |||
| if (cm_alpha <= 0) { | |||
| MS_LOG(EXCEPTION) << "'costmodel_alpha' must be positive."; | |||
| } | |||
| costmodel_alpha_ = cm_alpha; | |||
| } | |||
| void CostModelContext::set_costmodel_beta(double cm_beta) { costmodel_beta_ = cm_beta; } | |||
| void CostModelContext::set_costmodel_beta(double cm_beta) { | |||
| if (cm_beta <= 0) { | |||
| MS_LOG(EXCEPTION) << "'costmodel_beta' must be positive."; | |||
| } | |||
| costmodel_beta_ = cm_beta; | |||
| } | |||
| void CostModelContext::set_costmodel_gamma(double cm_gamma) { costmodel_gamma_ = cm_gamma; } | |||
| void CostModelContext::set_costmodel_gamma(double cm_gamma) { | |||
| if ((cm_gamma < 0) || (cm_gamma > 1)) { | |||
| MS_LOG(EXCEPTION) << "'costmodel_gamma' must in [0, 1]."; | |||
| } | |||
| costmodel_gamma_ = cm_gamma; | |||
| } | |||
| void CostModelContext::set_costmodel_simplify_cal(bool cm_simplify) { costmodel_simplify_cal_ = cm_simplify; } | |||
| void CostModelContext::set_costmodel_simplify_cal(bool cm_simplify) { | |||
| if (cm_simplify) { | |||
| MS_LOG(INFO) << "costmodel_simplify_cal: true."; | |||
| } else { | |||
| MS_LOG(INFO) << "costmodel_simplify_cal: false."; | |||
| } | |||
| costmodel_simplify_cal_ = cm_simplify; | |||
| } | |||
| void CostModelContext::set_costmodel_communi_threshold(double cm_communi_th) { | |||
| if (cm_communi_th < 0) { | |||
| MS_LOG(EXCEPTION) << "'costmodel_communi_threshold' must be non-zero."; | |||
| } | |||
| costmodel_communi_threshold_ = cm_communi_th; | |||
| } | |||
| void CostModelContext::set_costmodel_communi_const(double cm_communi_const) { | |||
| if (cm_communi_const < 0) { | |||
| MS_LOG(EXCEPTION) << "'costmodel_communi_const' must be non-zero."; | |||
| } | |||
| costmodel_communi_const_ = cm_communi_const; | |||
| } | |||
| void CostModelContext::set_costmodel_communi_bias(double cm_communi_bias) { costmodel_communi_bias_ = cm_communi_bias; } | |||
| void CostModelContext::set_costmodel_communi_bias(double cm_communi_bias) { | |||
| if (cm_communi_bias < 0) { | |||
| MS_LOG(EXCEPTION) << "'costmodel_communi_bias' must be non-zero."; | |||
| } | |||
| costmodel_communi_bias_ = cm_communi_bias; | |||
| } | |||
| void CostModelContext::set_multi_subgraphs(bool multi_graphs) { is_multi_subgraphs_ = multi_graphs; } | |||
| void CostModelContext::set_multi_subgraphs(bool multi_graphs) { | |||
| if (multi_graphs) { | |||
| MS_LOG(INFO) << "multi_subgraphs: true."; | |||
| } else { | |||
| MS_LOG(INFO) << "multi_subgraphs: false."; | |||
| } | |||
| is_multi_subgraphs_ = multi_graphs; | |||
| } | |||
| void CostModelContext::set_costmodel_allreduce_fusion_algorithm(int64_t algorithm) { | |||
| costmodel_allreduce_fusion_algorithm_ = algorithm; | |||
| } | |||
| @@ -129,25 +207,64 @@ void CostModelContext::set_costmodel_allreduce_fusion_computation_time_parameter | |||
| costmodel_allreduce_fusion_computation_time_parameter_ = computation_time_parameter; | |||
| } | |||
| void CostModelContext::set_tensor_slice_alignment_enable(bool ts_align) { tensor_slice_alignment_enable_ = ts_align; } | |||
| void CostModelContext::set_tensor_slice_alignment_enable(bool ts_align) { | |||
| if (ts_align) { | |||
| MS_LOG(INFO) << "tensor_slice_align_enable: true."; | |||
| } else { | |||
| MS_LOG(INFO) << "tensor_slice_align_enable: false."; | |||
| } | |||
| tensor_slice_alignment_enable_ = ts_align; | |||
| } | |||
| void CostModelContext::set_tensor_slice_alignment_size(size_t ts_align_size) { | |||
| if (ts_align_size == 0) { | |||
| MS_LOG(EXCEPTION) << "'tensor_slice_align_size' must be positive."; | |||
| } | |||
| tensor_slice_alignment_size_ = ts_align_size; | |||
| } | |||
| void CostModelContext::set_fully_use_device(bool fully_use) { fully_use_device_ = fully_use; } | |||
| void CostModelContext::set_fully_use_device(bool fully_use) { | |||
| if (fully_use) { | |||
| MS_LOG(INFO) << "fully_use_devices: true."; | |||
| } else { | |||
| MS_LOG(INFO) << "fully_use_devices: false."; | |||
| } | |||
| fully_use_device_ = fully_use; | |||
| } | |||
| void CostModelContext::set_elementwise_stra_follow(bool elementwise_follow) { | |||
| if (elementwise_follow) { | |||
| MS_LOG(INFO) << "elementwise_op_strategy_follow: true."; | |||
| } else { | |||
| MS_LOG(INFO) << "elementwise_op_strategy_follow: false."; | |||
| } | |||
| elementwise_stra_follow_ = elementwise_follow; | |||
| } | |||
| void CostModelContext::set_triangle_star_strategy_overwrite(bool overwrite) { | |||
| if (overwrite) { | |||
| MS_LOG(INFO) << "triangle_star_strategy_overwrite: true."; | |||
| } else { | |||
| MS_LOG(INFO) << "triangle_star_strategy_overwrite: false."; | |||
| } | |||
| triangle_star_strategy_overwrite_ = overwrite; | |||
| } | |||
| void CostModelContext::set_run_phase(int64_t phase) { run_phase_ = phase; } | |||
| void CostModelContext::set_run_phase(int64_t phase) { | |||
| if (phase != 0 && phase != 1) { | |||
| MS_LOG(EXCEPTION) << "'run_phase' must be in {0, 1}"; | |||
| } | |||
| run_phase_ = phase; | |||
| } | |||
| void CostModelContext::set_dp_algo_single_loop(bool single_loop) { dp_algo_single_loop_ = single_loop; } | |||
| void CostModelContext::set_dp_algo_single_loop(bool single_loop) { | |||
| if (single_loop) { | |||
| MS_LOG(INFO) << "dp_algo_single_loop: true."; | |||
| } else { | |||
| MS_LOG(INFO) << "dp_algo_single_loop: false."; | |||
| } | |||
| dp_algo_single_loop_ = single_loop; | |||
| } | |||
| struct CostRegister { | |||
| CostRegister() { | |||
| @@ -41,7 +41,6 @@ namespace parallel { | |||
| #define DEFAULT_FULLY_USE_DEVICES true | |||
| #define DEFAULT_ELEMENTWISE_OP_STRA_FOLLOW false | |||
| #define DEFAULT_IS_MULTI_SUBGRAPHS false | |||
| #define DEFAULT_RUN_PHASE 0 | |||
| #define TRAINING_PHASE 0 | |||
| #define INFERENCE_PHASE 1 | |||
| #define DEFAULT_TRIANGLE_STAR_STRATEGY_OVERWRITE true; | |||
| @@ -58,6 +57,7 @@ class CostModelContext { | |||
| void ResetAlgoParameters(); | |||
| static std::shared_ptr<CostModelContext> GetInstance(); | |||
| void PrintCostModel(); | |||
| void set_costmodel_context_for_device(const std::string &); | |||
| // DEVICE_MEMORY_CAPACITY | |||
| @@ -475,7 +475,8 @@ Status MatMulBase::PrepareStrategy(int64_t stage_id, size_t dev_num, | |||
| size_t input1_shape_size, mindspore::parallel::StrategyPtr *const sp) { | |||
| int64_t product = | |||
| std::accumulate(combined_partitions.begin(), combined_partitions.end(), 1, std::multiplies<int64_t>()); | |||
| if (!FULLY_USE_DEVICES) { | |||
| const auto fully_use_device = CostModelContext::GetInstance()->fully_use_device(); | |||
| if (!fully_use_device) { | |||
| if (LongToSize(product) > dev_num) { | |||
| return FAILED; | |||
| } | |||
| @@ -564,7 +565,9 @@ void MatMulBase::InitTensorInfoForCost(std::vector<TensorInfo> *relica_inputs_te | |||
| } | |||
| Status MatMulBase::CheckForTensorSliceValid() const { | |||
| if (!TENSOR_SLICE_ALIGNMENT_ENABLE) { | |||
| const auto align_enable = CostModelContext::GetInstance()->tensor_slice_alignment_enable(); | |||
| const auto align_size = CostModelContext::GetInstance()->tensor_slice_alignment_size(); | |||
| if (!align_enable) { | |||
| return SUCCESS; | |||
| } | |||
| if (inputs_tensor_info_.empty()) { | |||
| @@ -572,8 +575,8 @@ Status MatMulBase::CheckForTensorSliceValid() const { | |||
| } | |||
| for (auto &one_input_tensor : inputs_tensor_info_) { | |||
| auto slice_shape = one_input_tensor.slice_shape(); | |||
| if ((LongToSize(slice_shape[LAST_INDEX(slice_shape.size())]) % TENSOR_SLICE_ALIGNMENT_SIZE != 0) || | |||
| (LongToSize(slice_shape[SECOND_FROM_END(slice_shape.size())]) % TENSOR_SLICE_ALIGNMENT_SIZE != 0)) { | |||
| if ((LongToSize(slice_shape[LAST_INDEX(slice_shape.size())]) % align_size != 0) || | |||
| (LongToSize(slice_shape[SECOND_FROM_END(slice_shape.size())]) % align_size != 0)) { | |||
| return FAILED; | |||
| } | |||
| } | |||
| @@ -608,12 +611,12 @@ Status MatMulBase::SetCostUnderStrategy(const mindspore::parallel::StrategyPtr & | |||
| double computation_cost = | |||
| operator_cost()->GetForwardComputationCost(relica_inputs_tensor_vector, outputs_tensor_info_, stage_id); | |||
| double communication_cost = operator_cost()->GetCommCost(relica_inputs_tensor_vector, outputs_tensor_info_, stage_id); | |||
| const auto gamma = CostModelContext::GetInstance()->costmodel_gamma(); | |||
| std::shared_ptr<Cost> result = std::make_shared<Cost>(computation_cost, communication_cost); | |||
| result->communication_without_parameter_ = | |||
| operator_cost()->GetForwardCommCost(relica_inputs_tensor_vector, outputs_tensor_info_, stage_id); | |||
| result->communication_with_partial_para_ = | |||
| result->communication_without_parameter_ + | |||
| COST_MODEL_GAMMA * (communication_cost - result->communication_without_parameter_); | |||
| result->communication_without_parameter_ + gamma * (communication_cost - result->communication_without_parameter_); | |||
| // Breaking ties for preferring data parallelization | |||
| BreakingTiesForPerferringDataParallel(strategy, result); | |||
| @@ -839,7 +839,8 @@ Status PrepareStrategyBase(int64_t stage_id, size_t dev_num, const Shapes &input | |||
| for (auto &input_partition : inputs_partitions) { | |||
| product *= std::accumulate(input_partition.begin(), input_partition.end(), 1, std::multiplies<int64_t>()); | |||
| } | |||
| if (!FULLY_USE_DEVICES) { | |||
| const auto fully_use_device = CostModelContext::GetInstance()->fully_use_device(); | |||
| if (!fully_use_device) { | |||
| if (LongToSize(product) > dev_num) { | |||
| return FAILED; | |||
| } | |||
| @@ -1202,12 +1203,12 @@ Status OperatorInfo::SetCostUnderStrategyBase(const StrategyPtr &strategy) { | |||
| double computation_cost = | |||
| operator_cost()->GetForwardComputationCost(inputs_tensor_info_, outputs_tensor_info_, stage_id); | |||
| double communication_cost = operator_cost()->GetCommCost(inputs_tensor_info_, outputs_tensor_info_, stage_id); | |||
| const auto gamma = CostModelContext::GetInstance()->costmodel_gamma(); | |||
| std::shared_ptr<Cost> result = std::make_shared<Cost>(computation_cost, communication_cost); | |||
| result->communication_without_parameter_ = | |||
| operator_cost()->GetForwardCommCost(inputs_tensor_info_, outputs_tensor_info_, stage_id); | |||
| result->communication_with_partial_para_ = | |||
| result->communication_without_parameter_ + | |||
| COST_MODEL_GAMMA * (communication_cost - result->communication_without_parameter_); | |||
| result->communication_without_parameter_ + gamma * (communication_cost - result->communication_without_parameter_); | |||
| // Breaking ties for preferring data parallelization | |||
| BreakingTiesForPerferringDataParallel(strategy, result); | |||
| @@ -396,12 +396,12 @@ void ReshapeInfo::SetCostForReshape(const mindspore::parallel::StrategyPtr &stra | |||
| double computation_cost = | |||
| operator_cost()->GetForwardComputationCost(inputs_tensor_info_, outputs_tensor_info_, stage_id); | |||
| double communication_cost = operator_cost()->GetCommCost(inputs_tensor_info_, outputs_tensor_info_, stage_id); | |||
| const auto gamma = CostModelContext::GetInstance()->costmodel_gamma(); | |||
| std::shared_ptr<Cost> result = std::make_shared<Cost>(computation_cost, communication_cost); | |||
| result->communication_without_parameter_ = | |||
| operator_cost()->GetForwardCommCost(inputs_tensor_info_, outputs_tensor_info_, stage_id); | |||
| result->communication_with_partial_para_ = | |||
| result->communication_without_parameter_ + | |||
| COST_MODEL_GAMMA * (communication_cost - result->communication_without_parameter_); | |||
| result->communication_without_parameter_ + gamma * (communication_cost - result->communication_without_parameter_); | |||
| // Breaking ties for preferring data parallelization | |||
| BreakingTiesForPerferringDataParallel(strategy, result); | |||
| @@ -1,5 +1,5 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * Copyright 2019-2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| @@ -13,6 +13,7 @@ | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "frontend/parallel/parallel_stub/executor_manager_stub.h" | |||
| namespace mindspore { | |||
| namespace parallel { | |||
| @@ -26,6 +27,5 @@ std::shared_ptr<Executor> ExecutorManager::GetExecutor(const std::string &device | |||
| executors_[device_key] = executor; | |||
| return executor; | |||
| } | |||
| } // namespace parallel | |||
| } // namespace mindspore | |||
| @@ -1,5 +1,5 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * Copyright 2019-2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| @@ -13,6 +13,7 @@ | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_PARALLEL_EXECUTOR_MANAGER_STUB_H_ | |||
| #define MINDSPORE_CCSRC_PARALLEL_EXECUTOR_MANAGER_STUB_H_ | |||
| #include <set> | |||
| @@ -1,5 +1,5 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * Copyright 2019-2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| @@ -13,6 +13,7 @@ | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_PARALLEL_EXECUTOR_STUB_H | |||
| #define MINDSPORE_CCSRC_PARALLEL_EXECUTOR_STUB_H | |||
| @@ -52,13 +52,18 @@ static bool IsInWhiteList(const CNodePtr &cnode) { | |||
| return false; | |||
| } | |||
| static void SetGradTag(const AnfNodePtr &node, const FuncGraphManagerPtr &manager) { | |||
| static void SetGradTag(const AnfNodePtr &node, const FuncGraphManagerPtr &manager, size_t curr_depth) { | |||
| if (curr_depth > MAX_RECURSIVE_DEPTH) { | |||
| MS_LOG(WARNING) << "When setting the tags for Grad nodes, exceeded the maximum recursion depth: " | |||
| << MAX_RECURSIVE_DEPTH; | |||
| return; | |||
| } | |||
| const auto &node_users = manager->node_users()[node]; | |||
| for (auto &user_pair : node_users) { | |||
| auto user_node = user_pair.first; | |||
| if (!user_node->grad()) { | |||
| user_node->set_grad(true); | |||
| SetGradTag(user_node, manager); | |||
| SetGradTag(user_node, manager, ++curr_depth); | |||
| } | |||
| } | |||
| } | |||
| @@ -69,7 +74,7 @@ void PipelineTransformer::LabelRequiredGradCNode() { | |||
| if (!ParameterRequireGrad(parameter)) { | |||
| continue; | |||
| } | |||
| SetGradTag(parameter, manager_); | |||
| SetGradTag(parameter, manager_, 0); | |||
| } | |||
| } | |||
| @@ -234,7 +234,8 @@ void InitCostGraph() { | |||
| if (entire_costgraph == nullptr) { | |||
| entire_costgraph = std::make_shared<CostGraph>(); | |||
| } | |||
| entire_costgraph->SetDeviceMemoryAndCostParameter(); | |||
| MS_EXCEPTION_IF_NULL(CostModelContext::GetInstance()); | |||
| CostModelContext::GetInstance()->PrintCostModel(); | |||
| entire_costgraph->Init(); | |||
| } | |||
| @@ -252,10 +253,11 @@ void SetStrategyToOperator(const OperatorInfoPtr &operator_info, const Primitive | |||
| if (prim->name() == RESHAPE) { | |||
| MS_LOG(EXCEPTION) << "Setting strategy for Reshape goes for nothing!"; | |||
| } | |||
| const auto fully_use_devices = CostModelContext::GetInstance()->fully_use_device(); | |||
| // Set cost for this configured strategy | |||
| if (operator_info->SetCostUnderStrategy(strategyPtr) != SUCCESS) { | |||
| MS_LOG(EXCEPTION) << "Failure: operator " << prim->name() << " SetCostUnderStrategy failed"; | |||
| } else if (FULLY_USE_DEVICES) { | |||
| } else if (fully_use_devices) { | |||
| // If configured to fully use devices, then checking for the user-specified strategy | |||
| int64_t used_devices = operator_info->used_devices(); | |||
| MS_EXCEPTION_IF_NULL(g_device_manager); | |||
| @@ -323,7 +325,7 @@ OperatorInfoPtr CreateTheOperatorInfo(const PrimitivePtr &prim, const CNodePtr & | |||
| operator_info->set_cnode(cnode); | |||
| // key of strategy map | |||
| std::string strategy_key_name = ""; | |||
| auto param_names = NodeParameterName(cnode); | |||
| auto param_names = NodeParameterName(cnode, -1, 0); | |||
| if (!param_names.empty()) { | |||
| strategy_key_name = prim->name() + "_" + param_names[0].first; | |||
| } | |||
| @@ -394,7 +396,8 @@ Status ConstructCostGraphNodesByUniqueId(const std::vector<AnfNodePtr> &all_node | |||
| if (search_cnode == from_cnode_to_info.end()) { | |||
| size_t loop_index = 0; | |||
| bool is_in_loop = GetLoopIndexFromCNode(cnode, &loop_index); | |||
| if (DP_ALGO_SINGLE_LOOP && is_in_loop && (loop_to_ops[loop_index] < operators_in_forloop.size())) { | |||
| const auto single_loop = CostModelContext::GetInstance()->dp_algo_single_loop(); | |||
| if (single_loop && is_in_loop && (loop_to_ops[loop_index] < operators_in_forloop.size())) { | |||
| const auto ¤t_op_ptr = operators_in_forloop[loop_to_ops[loop_index]]; | |||
| bool is_find_wrong = (current_op_ptr->name().find(VIRTUAL_DATA_SET_INFO) == std::string::npos) && | |||
| (current_op_ptr->name().find(BATCH_PARALLEL) == std::string::npos) && | |||
| @@ -430,7 +433,7 @@ Status ConstructCostGraphNodesByUniqueId(const std::vector<AnfNodePtr> &all_node | |||
| << ", CNode fullname_with_scope: " << cnode->fullname_with_scope() | |||
| << " is set OperatorInfo: " << operator_info->name() << ", Primitive: " << prim->name(); | |||
| (void)from_cnode_to_info.emplace(std::make_pair(cnode->UniqueId(), operator_info)); | |||
| if (DP_ALGO_SINGLE_LOOP && is_in_loop) { | |||
| if (single_loop && is_in_loop) { | |||
| operators_in_forloop.push_back(operator_info); | |||
| ops_in_a_loop_.insert(operator_info->name()); | |||
| loop_to_ops[loop_index]++; | |||
| @@ -511,7 +514,8 @@ Status ConstructCostGraphNodesByUniqueIdTC(const std::vector<AnfNodePtr> &all_no | |||
| if (search_cnode == from_cnode_to_info.end()) { | |||
| size_t loop_index = 0; | |||
| bool is_in_loop = GetLoopIndexFromCNode(cnode, &loop_index); | |||
| bool is_op_created = DP_ALGO_SINGLE_LOOP && is_in_loop && (loop_to_ops[loop_index] < operators_in_forloop.size()); | |||
| const auto single_loop = CostModelContext::GetInstance()->dp_algo_single_loop(); | |||
| bool is_op_created = single_loop && is_in_loop && (loop_to_ops[loop_index] < operators_in_forloop.size()); | |||
| if (is_op_created) { | |||
| const auto ¤t_op_ptr = operators_in_forloop[loop_to_ops[loop_index]]; | |||
| bool is_find_wrong = (current_op_ptr->name().find(VIRTUAL_DATA_SET_INFO) == std::string::npos) && | |||
| @@ -548,7 +552,7 @@ Status ConstructCostGraphNodesByUniqueIdTC(const std::vector<AnfNodePtr> &all_no | |||
| << ", CNode fullname_with_scope: " << cnode->fullname_with_scope() | |||
| << " is set OperatorInfo: " << operator_info->name() << ", Primitive: " << prim->name(); | |||
| (void)from_cnode_to_info.emplace(std::make_pair(cnode->UniqueIdThroughCopy(), operator_info)); | |||
| if (DP_ALGO_SINGLE_LOOP && is_in_loop) { | |||
| if (single_loop && is_in_loop) { | |||
| operators_in_forloop.push_back(operator_info); | |||
| ops_in_a_loop_.insert(operator_info->name()); | |||
| loop_to_ops[loop_index]++; | |||
| @@ -581,8 +585,9 @@ void CreateEdgeBetweenTwoOps(const OperatorInfoPtr &prev_op_info, const Operator | |||
| MS_LOG(INFO) << "The two operators in two separate for-loops, thus skip the edge."; | |||
| return; | |||
| } | |||
| const auto stra_follow = CostModelContext::GetInstance()->elementwise_stra_follow(); | |||
| bool follow_strategy = (prim->name() == RESHAPE) || (prev_prim->name() == RESHAPE) || | |||
| (ELEMENTWISE_OP_STRA_FOLLOW && IsElementWiseOperator(prev_prim->name())); | |||
| (stra_follow && IsElementWiseOperator(prev_prim->name())); | |||
| if (follow_strategy) { | |||
| // Redistribution in not allowed on the edge. | |||
| // Elementwise operators have the same strategy as their previous operators. | |||
| @@ -1031,7 +1036,7 @@ Status ParallelStrategyRecSearch(const std::vector<AnfNodePtr> &all_nodes, const | |||
| graph = EliminateGraph(graph, eli_list, index_list); | |||
| size_t num_device = g_device_manager->DeviceNum(); | |||
| double device_memory = entire_costgraph->GetDeviceMemory(); | |||
| const auto device_memory = CostModelContext::GetInstance()->device_memory_capacity(); | |||
| if (PartitionForAllDevices(num_device, device_memory, graph) == SUCCESS) { | |||
| MS_LOG(INFO) << "Partition Success With " << num_device << " devices."; | |||
| } else { | |||
| @@ -1914,7 +1914,11 @@ void SetVirtualDatasetStrategy(const CNodePtr &node) { | |||
| } | |||
| // find previous parallel care node's next node. | |||
| bool FindPreNodes(const AnfNodePtr &node, vector<std::string> *unique_ids, vector<size_t> *indexes) { | |||
| bool FindPreNodes(const AnfNodePtr &node, vector<std::string> *unique_ids, vector<size_t> *indexes, size_t curr_depth) { | |||
| if (curr_depth > MAX_RECURSIVE_DEPTH) { | |||
| MS_LOG(WARNING) << "When find the previous node, exceeded the maximum recursion depth: " << MAX_RECURSIVE_DEPTH; | |||
| return false; | |||
| } | |||
| MS_EXCEPTION_IF_NULL(unique_ids); | |||
| MS_EXCEPTION_IF_NULL(indexes); | |||
| if (!node->isa<CNode>()) { | |||
| @@ -1942,7 +1946,7 @@ bool FindPreNodes(const AnfNodePtr &node, vector<std::string> *unique_ids, vecto | |||
| find = true; | |||
| continue; | |||
| } | |||
| if (FindPreNodes(cnode, unique_ids, indexes)) { | |||
| if (FindPreNodes(cnode, unique_ids, indexes, ++curr_depth)) { | |||
| find = true; | |||
| continue; | |||
| } | |||
| @@ -1954,7 +1958,7 @@ void FindLastNodesUniqueId(const FuncGraphPtr &root, std::vector<std::string> *u | |||
| std::vector<size_t> *indexes) { | |||
| MS_EXCEPTION_IF_NULL(unique_ids); | |||
| CNodePtr cnode = root->get_return(); | |||
| if (!FindPreNodes(cnode, unique_ids, indexes)) { | |||
| if (!FindPreNodes(cnode, unique_ids, indexes, 0)) { | |||
| MS_LOG(WARNING) << "cannot find the last parallel care node in eval graph"; | |||
| } | |||
| } | |||
| @@ -2044,7 +2048,7 @@ void ExtractInformation(const std::vector<AnfNodePtr> &all_nodes, bool is_traini | |||
| // load strategy checkpoint | |||
| // key of strategy map | |||
| std::string strategy_key_name = ""; | |||
| auto param_names = NodeParameterName(cnode); | |||
| auto param_names = NodeParameterName(cnode, -1, 0); | |||
| if (!param_names.empty()) { | |||
| strategy_key_name = prim->name() + "_" + param_names[0].first; | |||
| } | |||
| @@ -2151,13 +2155,18 @@ std::shared_ptr<TensorLayout> FindPrevParallelCareNodeLayout(const AnfNodePtr &n | |||
| return nullptr; | |||
| } | |||
| std::shared_ptr<TensorLayout> FindParameterNextLayout(const AnfNodePtr &node) { | |||
| std::shared_ptr<TensorLayout> FindParameterNextLayout(const AnfNodePtr &node, size_t curr_depth) { | |||
| if (curr_depth > MAX_RECURSIVE_DEPTH) { | |||
| MS_LOG(WARNING) << "When finding the next tensor layout for the parameter, exceeded the maximum recursion depth: " | |||
| << MAX_RECURSIVE_DEPTH; | |||
| return nullptr; | |||
| } | |||
| FuncGraphManagerPtr manager = node->func_graph()->manager(); | |||
| MS_EXCEPTION_IF_NULL(manager); | |||
| AnfNodeIndexSet node_set = manager->node_users()[node]; | |||
| for (auto &node_pair : node_set) { | |||
| if (IsPrimitiveCNode(node_pair.first, prim::kPrimLoad)) { | |||
| auto layout_param = FindParameterNextLayout(node_pair.first); | |||
| auto layout_param = FindParameterNextLayout(node_pair.first, ++curr_depth); | |||
| if (!layout_param) { | |||
| continue; | |||
| } | |||
| @@ -2184,7 +2193,7 @@ std::shared_ptr<TensorLayout> FindParameterNextLayout(const AnfNodePtr &node) { | |||
| std::shared_ptr<TensorLayout> CreateParameterLayout(const AnfNodePtr &node) { | |||
| // Create DataParallel tensor layout for parameter(support WideDeep). | |||
| auto next_layout = FindParameterNextLayout(node); | |||
| auto next_layout = FindParameterNextLayout(node, 0); | |||
| if (next_layout != nullptr) { | |||
| return next_layout; | |||
| } | |||
| @@ -2329,14 +2338,19 @@ void ReshapeInit(const std::vector<AnfNodePtr> &all_nodes) { | |||
| } | |||
| } | |||
| CNodePtr HandleDependLoss(const CNodePtr &cnode) { | |||
| CNodePtr HandleDependLoss(const CNodePtr &cnode, size_t curr_depth) { | |||
| if (curr_depth > MAX_RECURSIVE_DEPTH) { | |||
| MS_LOG(WARNING) << "When handling the loss node of Depend, exceeded the max recursive depth: " | |||
| << MAX_RECURSIVE_DEPTH; | |||
| return nullptr; | |||
| } | |||
| // Handle return->depend->loss | |||
| auto prim = GetValueNode<PrimitivePtr>(cnode->input(0)); | |||
| MS_EXCEPTION_IF_NULL(prim); | |||
| if (prim->name() == DEPEND) { | |||
| auto depend_before = cnode->input(1)->cast<CNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(depend_before); | |||
| return HandleDependLoss(depend_before); | |||
| return HandleDependLoss(depend_before, ++curr_depth); | |||
| } | |||
| return cnode; | |||
| } | |||
| @@ -2370,7 +2384,7 @@ LossNodeInfo FindLossCNode(const FuncGraphPtr &func_graph) { | |||
| pre_cnode = pre_cnode->input(1)->cast<CNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(pre_cnode); | |||
| } | |||
| pre_cnode = HandleDependLoss(pre_cnode); | |||
| pre_cnode = HandleDependLoss(pre_cnode, 0); | |||
| auto current_prim = GetValueNode<PrimitivePtr>(pre_cnode->input(0)); | |||
| // notice: the GetNext op has not input | |||
| @@ -2792,7 +2806,12 @@ bool IsCohesiveNode(const CNodePtr &cnode) { | |||
| IsPrimitiveCNode(cnode, prim::kPrimAllGather) || IsPrimitiveCNode(cnode, prim::kPrimMiniStepAllGather); | |||
| } | |||
| std::vector<std::pair<std::string, int64_t>> NodeParameterName(const CNodePtr &node, int64_t index) { | |||
| std::vector<std::pair<std::string, int64_t>> NodeParameterName(const CNodePtr &node, int64_t index, size_t curr_depth) { | |||
| if (curr_depth > MAX_RECURSIVE_DEPTH) { | |||
| MS_LOG(WARNING) << "When finding the parameters' name of a operator, exceeded the maximum depth: " | |||
| << MAX_RECURSIVE_DEPTH; | |||
| return {}; | |||
| } | |||
| std::vector<AnfNodePtr> node_inputs{node->inputs()}; | |||
| std::vector<std::pair<std::string, int64_t>> param_names; | |||
| for (int64_t i = 0; i < UlongToLong(node_inputs.size()); ++i) { | |||
| @@ -2809,7 +2828,7 @@ std::vector<std::pair<std::string, int64_t>> NodeParameterName(const CNodePtr &n | |||
| continue; | |||
| } | |||
| if (IsCohesiveNode(cnode) && cnode->inputs().size() >= 1) { | |||
| auto input_param_names = NodeParameterName(cnode, idx); | |||
| auto input_param_names = NodeParameterName(cnode, idx, 0); | |||
| param_names.insert(param_names.end(), input_param_names.begin(), input_param_names.end()); | |||
| } | |||
| } | |||
| @@ -2827,7 +2846,7 @@ void CheckpointStrategy(const std::vector<AnfNodePtr> &all_nodes) { | |||
| if ((cnode == nullptr) || !IsValueNode<Primitive>(cnode->input(0))) { | |||
| continue; | |||
| } | |||
| auto param_names = NodeParameterName(cnode); | |||
| auto param_names = NodeParameterName(cnode, -1, 0); | |||
| if (param_names.empty()) { | |||
| continue; | |||
| } | |||
| @@ -2950,13 +2969,17 @@ void InsertShapeOp(const CNodePtr &node, const AnfNodePtr &pre_node, const FuncG | |||
| InsertNode(op, node, 2, pre_node, root, "shape"); | |||
| } | |||
| static AnfNodePtr FindGrad(const CNodePtr &cnode) { | |||
| static AnfNodePtr FindGrad(const CNodePtr &cnode, size_t curr_depth) { | |||
| if (curr_depth > MAX_RECURSIVE_DEPTH) { | |||
| MS_LOG(WARNING) << "When finding Grad nodes, exceeded the maximum recursion depth: " << MAX_RECURSIVE_DEPTH; | |||
| return nullptr; | |||
| } | |||
| for (auto &node : cnode->inputs()) { | |||
| if (!node->isa<CNode>()) { | |||
| continue; | |||
| } | |||
| if (!IsPrimitiveCNode(node, prim::kPrimEnvGetItem)) { | |||
| return FindGrad(node->cast<CNodePtr>()); | |||
| return FindGrad(node->cast<CNodePtr>(), ++curr_depth); | |||
| } else { | |||
| return node; | |||
| } | |||
| @@ -2995,7 +3018,7 @@ void HandleRootReshapeAndSaveStrategy(const std::vector<AnfNodePtr> &all_nodes) | |||
| continue; | |||
| } | |||
| auto root = node->func_graph(); | |||
| auto grad_node = FindGrad(cnode); | |||
| auto grad_node = FindGrad(cnode, 0); | |||
| if (grad_node) { | |||
| InsertShapeOp(cnode, grad_node, root); | |||
| } | |||
| @@ -139,7 +139,7 @@ bool IsLastStage(); | |||
| void ParallelCommunication(const FuncGraphPtr &root, const std::vector<AnfNodePtr> &all_nodes, | |||
| const FuncGraphManagerPtr &manager); | |||
| std::vector<std::pair<std::string, int64_t>> NodeParameterName(const CNodePtr &node, int64_t index = -1); | |||
| std::vector<std::pair<std::string, int64_t>> NodeParameterName(const CNodePtr &node, int64_t index, size_t curr_depth); | |||
| void CheckpointStrategy(const std::vector<AnfNodePtr> &all_nodes); | |||
| @@ -153,7 +153,6 @@ class TestDPAlgo : public UT::Common { | |||
| void TestDPAlgo::SetUp() { | |||
| cost_graph = std::make_shared<CostGraph>(); | |||
| cost_graph->SetDeviceMemoryAndCostParameter(); | |||
| RankList dev_list; | |||
| for (int32_t i = 0; i < 10; i++) { | |||
| @@ -52,7 +52,6 @@ class TestCostGraph : public UT::Common { | |||
| }; | |||
| void TestCostGraph::SetUp() { | |||
| cost_graph.SetDeviceMemoryAndCostParameter(); | |||
| RankList dev_list; | |||
| for (int32_t i = 0; i < 10; i++) { | |||
| @@ -305,7 +304,6 @@ TEST_F(TestCostGraph, test_ConstructConnectedComponents) { | |||
| TEST_F(TestCostGraph, test_SelectCostListWithMinTrainingTimeMultiple) { | |||
| CostGraph entire_cost_graph; | |||
| entire_cost_graph.SetDeviceMemoryAndCostParameter(); | |||
| double memory = 1024.0; | |||
| CostPtrList clist_1, clist_2; | |||
| std::vector<CostPtrList> all_list; | |||
| @@ -371,7 +369,8 @@ TEST_F(TestCostGraph, test_CreateFinalCostList_AND_Select) { | |||
| ASSERT_EQ(edge_m1_m2->InitEdgeCost(), SUCCESS); | |||
| cost_graph.AddEdge(matmul1, matmul2, edge_m1_m2); | |||
| auto cost_list = cost_graph.CreateFinalCostList(matmul1, edge_m1_m2, matmul2); | |||
| cost_graph.SelectCostWithMinInferenceTime(cost_list, cost_graph.GetDeviceMemory()); | |||
| const auto device_mem_capacity = CostModelContext::GetInstance()->device_memory_capacity(); | |||
| cost_graph.SelectCostWithMinInferenceTime(cost_list, device_mem_capacity); | |||
| } | |||
| TEST_F(TestCostGraph, test_EliminationOp) { | |||