Merge pull request !232 from Xiaoda/model-memory-cost-in-auto-parallel
tags/v0.2.0-alpha
@@ -207,15 +207,13 @@ struct ContractEliminationDecision : public Decision {
*/
struct TriangleEliminationDecision : public Decision {
TriangleEliminationDecision(StrategyPtr elimi_stra, CostPtr elimi_op_cost, CostPtr l_edge_cost, CostPtr r_edge_cost,
StrategyPtr left_stra, CostPtr l_node_cost, StrategyPtr right_stra, CostPtr r_node_cost)
StrategyPtr left_stra, CostPtr l_node_cost)
: eliminated_op_strategy_(std::move(elimi_stra)),
eliminated_op_cost_(std::move(elimi_op_cost)),
left_edge_cost_(std::move(l_edge_cost)),
right_edge_cost_(std::move(r_edge_cost)),
left_node_strategy_(std::move(left_stra)),
left_node_cost_(std::move(l_node_cost)),
right_node_strategy_(std::move(right_stra)),
right_node_cost_(std::move(r_node_cost)) {
left_node_cost_(std::move(l_node_cost)) {
type_ = DecisionType::TRIANGLE_ELIMINATION;
}
@@ -225,8 +223,6 @@ struct TriangleEliminationDecision : public Decision {
CostPtr right_edge_cost_;
StrategyPtr left_node_strategy_;
CostPtr left_node_cost_;
StrategyPtr right_node_strategy_;
CostPtr right_node_cost_;
MS_DECLARE_PARENT(TriangleEliminationDecision, Decision);
};
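The triangle shape these fields describe can be sketched with the names used in the code below; the ASCII picture is an illustration, not part of the commit:

```cpp
// The triangle being eliminated:
//
//            eliminated_node
//            /             \
//      left_edge        right_edge
//          /                 \
//   left_node --l_r_edge--> right_node
//
// The triangle is folded into left_node, so after this change the decision
// records only left_node's strategy and cost (see the recovery comment
// further below: only 'left_node' is needed to recover the strategy).
```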
@@ -76,7 +76,6 @@ Status GetStrategy(const CostGraphPtr& graph) {
auto l_r_edge = triangle_pair.second;
auto left_node = l_r_edge->prev_operator();
auto right_node = l_r_edge->next_operator();
auto left_edge = eliminated_node->GetAliveSuccEdges()[0];
auto right_edge = eliminated_node->GetAliveSuccEdges()[1];
MS_EXCEPTION_IF_NULL(left_edge);
@@ -86,8 +85,7 @@ Status GetStrategy(const CostGraphPtr& graph) {
right_edge = tmp;
}
auto left_node_cpy = graph->EliminationTriangle(eliminated_node, l_r_edge);
auto elimi =
std::make_shared<TriangleElimination>(eliminated_node, left_edge, left_node_cpy, right_edge, right_node);
auto elimi = std::make_shared<TriangleElimination>(eliminated_node, left_edge, left_node_cpy, right_edge);
eliminations.emplace_back(std::move(elimi));
}
auto star_center = graph->CheckStarElimination();
@@ -183,14 +181,13 @@ Status RecoverStrategy(std::vector<EliminationPtr> eliminations) {
auto left_edge = elimination->left_edge_;
auto eliminated_node = elimination->eliminated_node_;
auto right_edge = elimination->right_edge_;
auto right_node = elimination->right_node_;
auto decision = left_node->selected_cost()->decision_ptr_->cast<TriangleEliminationDecisionPtr>();
eliminated_node->SetSelectedStrategyAndCost(decision->eliminated_op_strategy_, decision->eliminated_op_cost_);
left_edge->set_selected_cost(decision->left_edge_cost_);
right_edge->set_selected_cost(decision->right_edge_cost_);
// Since Triangle is eliminated into 'left_node', only 'left_node' is needed to recover the strategy.
left_node->SetSelectedStrategyAndCost(decision->left_node_strategy_, decision->left_node_cost_);
right_node->SetSelectedStrategyAndCost(decision->right_node_strategy_, decision->right_node_cost_);
MS_LOG(INFO) << "Recover triangleElimination succeeded.";
} else if ((*rit)->isa<StarElimination>()) {
auto elimination = (*rit)->cast<StarEliminationPtr>();
@@ -204,9 +201,11 @@ Status RecoverStrategy(std::vector<EliminationPtr> eliminations) {
for (size_t i = 0; i < succ_edges.size(); ++i) {
succ_edges[i]->set_selected_cost(decision->succ_edges_cost_list_[i]);
}
for (size_t j = 0; j < succ_nodes.size(); ++j) {
succ_nodes[j]->SetSelectedStrategyAndCost(decision->succ_ops_stra_list_[j], decision->succ_ops_cost_list_[j]);
}
MS_EXCEPTION_IF_NULL(succ_nodes[0]);
MS_EXCEPTION_IF_NULL(decision->succ_ops_stra_list_[0]);
MS_EXCEPTION_IF_NULL(decision->succ_ops_cost_list_[0]);
// Since Star is eliminated into 'succ_nodes[0]', only 'succ_nodes[0]' is needed to recover the strategy.
succ_nodes[0]->SetSelectedStrategyAndCost(decision->succ_ops_stra_list_[0], decision->succ_ops_cost_list_[0]);
MS_LOG(INFO) << "Recover starElimination succeeded.";
} else {
MS_LOG(ERROR) << "Unknown Elimination type.";
@@ -102,20 +102,17 @@ struct ContractElimination : public Elimination {
// Triangle Elimination
struct TriangleElimination : public Elimination {
TriangleElimination(OperatorInfoPtr elim_node, EdgePtr l_edge, OperatorInfoPtr l_node, EdgePtr r_edge,
OperatorInfoPtr r_node)
TriangleElimination(OperatorInfoPtr elim_node, EdgePtr l_edge, OperatorInfoPtr l_node, EdgePtr r_edge)
: Elimination(nullptr, Elimination::EliminationType::TRIANGLE),
eliminated_node_(std::move(elim_node)),
left_edge_(std::move(l_edge)),
left_node_(std::move(l_node)),
right_edge_(std::move(r_edge)),
right_node_(std::move(r_node)) {}
right_edge_(std::move(r_edge)) {}
OperatorInfoPtr eliminated_node_;
EdgePtr left_edge_;
OperatorInfoPtr left_node_;
EdgePtr right_edge_;
OperatorInfoPtr right_node_;
MS_DECLARE_PARENT(TriangleElimination, Elimination);
};
@@ -119,6 +119,7 @@ Status Edge::GetRedistributionCost(const TensorLayout& prev_op_output_layout, co
double forward_comm_cost = tensor_redistribution.forward_comm_cost();
double backward_comm_cost = tensor_redistribution.backward_comm_cost();
double computation_cost = tensor_redistribution.computation_cost();
double mem_cost = tensor_redistribution.memory_cost();
// Now AllGather, ReduceScatter, AlltoAll don't support bool type
MS_EXCEPTION_IF_NULL(type);
@@ -134,6 +135,7 @@ Status Edge::GetRedistributionCost(const TensorLayout& prev_op_output_layout, co
COST_MODEL_GAMMA * ((*cost)->communication_cost_ - (*cost)->communication_without_parameter_);
(*cost)->communication_redis_forward_ = type_length * forward_comm_cost;
(*cost)->communication_redis_backward_ = type_length * backward_comm_cost;
(*cost)->memory_with_reuse_ = mem_cost;
return Status::SUCCESS;
}
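The assignments above follow a pattern that recurs throughout this commit whenever a Cost is assembled: communication is blended with a GAMMA factor, and, new here, memory_with_reuse_ is carried along. A minimal standalone sketch of that aggregation, with a simplified Cost struct rather than the real class:

```cpp
// Simplified sketch of the recurring Cost aggregation in this commit.
struct CostSketch {
  double computation_cost_ = 0.0;
  double communication_cost_ = 0.0;
  double communication_without_parameter_ = 0.0;
  double communication_with_partial_para_ = 0.0;
  double memory_with_reuse_ = 0.0;
};

CostSketch Combine(const CostSketch& a, const CostSketch& b, double gamma /* COST_MODEL_GAMMA */) {
  CostSketch c;
  c.computation_cost_ = a.computation_cost_ + b.computation_cost_;
  c.communication_cost_ = a.communication_cost_ + b.communication_cost_;
  c.communication_without_parameter_ =
      a.communication_without_parameter_ + b.communication_without_parameter_;
  // The blended figure is recomputed from the sums rather than summed itself.
  c.communication_with_partial_para_ =
      c.communication_without_parameter_ +
      gamma * (c.communication_cost_ - c.communication_without_parameter_);
  // New in this commit: peak memory (with reuse) is accumulated alongside.
  c.memory_with_reuse_ = a.memory_with_reuse_ + b.memory_with_reuse_;
  return c;
}
```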
@@ -158,8 +160,8 @@ CostPtrList Edge::CreateEdgeEliminationCostList(const StrategyPtr& output_st_ptr
(void)std::transform(edges.begin(), edges.end(), all_cost_list.begin(), LocalGetCostList);
CostPtrList selected_cost_list(all_cost_list.size(), nullptr);
std::function<void(size_t, double, double, double)> recursive =
[&](size_t k, double computation, double communication, double communication_without_para) {
std::function<void(size_t, double, double, double, double)> recursive =
[&](size_t k, double computation, double memory, double communication, double communication_without_para) {
if (k == edges.size()) {
auto decision = std::make_shared<EdgeEliminationDecision>(selected_cost_list);
CostPtr new_cost = std::make_shared<Cost>(computation, communication);
@@ -167,6 +169,7 @@ CostPtrList Edge::CreateEdgeEliminationCostList(const StrategyPtr& output_st_ptr
new_cost->communication_without_parameter_ = communication_without_para;
new_cost->communication_with_partial_para_ =
communication_without_para + COST_MODEL_GAMMA * (communication - communication_without_para);
new_cost->memory_with_reuse_ = memory;
new_cost->decision_ptr_ = decision;
result.push_back(new_cost);
return;
@@ -174,11 +177,12 @@ CostPtrList Edge::CreateEdgeEliminationCostList(const StrategyPtr& output_st_ptr
for (auto& c : all_cost_list[k]) {
MS_EXCEPTION_IF_NULL(c);
selected_cost_list[k] = c;
recursive(k + 1, computation + c->computation_cost_, communication + c->communication_cost_,
recursive(k + 1, computation + c->computation_cost_, memory + c->memory_with_reuse_,
communication + c->communication_cost_,
communication_without_para + c->communication_without_parameter_);
}
};
recursive(0, 0, 0, 0);
recursive(0, 0.0, 0.0, 0.0, 0.0);
SimplifyForDreasingCommunicationWithPartialPara(&result);
return result;
}
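The recursive lambda above enumerates one cost per parallel edge, i.e. a Cartesian product over the per-edge cost lists, carrying running sums that now include memory. A stripped-down, self-contained version of the same recursion shape, with toy types rather than the MindSpore API:

```cpp
#include <functional>
#include <vector>

// Toy Cartesian-product recursion: pick one value per list, carry a running
// sum, and emit the total for every combination. The lambda above does the
// same with five running sums over shared-pointer cost lists.
void ForEachCombination(const std::vector<std::vector<double>>& lists, size_t k,
                        double acc, const std::function<void(double)>& emit) {
  if (k == lists.size()) {
    emit(acc);
    return;
  }
  for (double v : lists[k]) {
    ForEachCombination(lists, k + 1, acc + v, emit);
  }
}
```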
@@ -218,6 +222,8 @@ void Edge::CreateOpEliminationSubCostList(StrategyPtr op_strategy, const CostPtr
double communication_without_para = left_cost->communication_without_parameter_ +
middle_cost->communication_without_parameter_ +
right_cost->communication_without_parameter_;
double memory_cost =
left_cost->memory_with_reuse_ + middle_cost->memory_with_reuse_ + right_cost->memory_with_reuse_;
auto decision = std::make_shared<OpEliminationDecision>(op_strategy, left_cost, middle_cost, right_cost);
auto cost = std::make_shared<Cost>(computation, communication, decision);
@@ -225,6 +231,7 @@ void Edge::CreateOpEliminationSubCostList(StrategyPtr op_strategy, const CostPtr
cost->communication_without_parameter_ = communication_without_para;
cost->communication_with_partial_para_ =
communication_without_para + COST_MODEL_GAMMA * (communication - communication_without_para);
cost->memory_with_reuse_ = memory_cost;
ret_cost_list->emplace_back(std::move(cost));
}
}
@@ -267,5 +274,24 @@ void Edge::OpEliminationSetNewCost(const EdgePtr& e1, const OperatorInfoPtr& op,
MS_LOG(EXCEPTION) << "Creating edge: " << edge_name_ << " failed.";
}
}
Status Edge::CalculateMemoryCost() {
if (is_output_parameter_involve_ == -1) {
MS_LOG(ERROR) << "is_output_parameter_involve_ is unset.";
return FAILED;
}
if (is_output_parameter_involve_ == 0) {
// In this case, the tensor redistribution along this edge is definitely NOT parameter-involved, so it is
// unnecessary to keep it in memory.
for (auto& cost_kv : cost_map_) {
auto& cost_v = cost_kv.second;
if (!cost_v.empty()) {
cost_v[0]->memory_with_reuse_ = 0;
}
}
}
return SUCCESS;
}
} // namespace parallel
} // namespace mindspore
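Edge::CalculateMemoryCost above zeroes the memory figure only when the edge is known not to be parameter-involved. A compact sketch of that rule over a toy cost map (types assumed; the real cost_map_ keys strategy pairs to cost lists):

```cpp
#include <map>
#include <memory>
#include <vector>

struct ToyCost { double memory_with_reuse_ = 0.0; };
using ToyCostPtr = std::shared_ptr<ToyCost>;

// If the redistribution along an edge is not parameter-involved, its buffers
// need not survive into the backward phase; mirroring the diff above, only
// the first cost in each bucket is zeroed.
void ZeroMemoryIfNotParameterInvolved(std::map<int, std::vector<ToyCostPtr>>* cost_map,
                                      int is_output_parameter_involve) {
  if (is_output_parameter_involve != 0) return;  // -1: unset, 1: involved -> keep
  for (auto& kv : *cost_map) {
    auto& costs = kv.second;
    if (!costs.empty()) costs[0]->memory_with_reuse_ = 0.0;
  }
}
```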
@@ -133,7 +133,7 @@ class Edge {
void set_parameter_involve(int para_invol) { is_output_parameter_involve_ = para_invol; }
// When an input of an operator is a WEIGHT, or is an output of other operators involving a WEIGHT, that input
// should stay in memory until it is used in the backward phase, i.e., it is still held in memory at the end of the forward phase.
Status CalculateMemoryCost() const { return SUCCESS; }
Status CalculateMemoryCost();
private:
std::string edge_name_;
@@ -248,6 +248,7 @@ CostPtrList CostGraph::CreateFinalCostList(const OperatorInfoPtr& u, const std::
MS_EXCEPTION_IF_NULL(cost2);
MS_EXCEPTION_IF_NULL(cost3);
double computation = cost1->computation_cost_ + cost2->computation_cost_ + cost3->computation_cost_;
double memory = cost1->memory_with_reuse_ + cost2->memory_with_reuse_ + cost3->memory_with_reuse_;
double communication =
cost1->communication_cost_ + cost2->communication_cost_ + cost3->communication_cost_;
double communication_without_para = cost1->communication_without_parameter_ +
@@ -260,6 +261,7 @@ CostPtrList CostGraph::CreateFinalCostList(const OperatorInfoPtr& u, const std::
cost->communication_without_parameter_ = communication_without_para;
cost->communication_with_partial_para_ =
communication_without_para + COST_MODEL_GAMMA * (communication - communication_without_para);
cost->memory_with_reuse_ = memory;
ret.push_back(cost);
}
}
@@ -288,6 +290,7 @@ CostPtrList CostGraph::CreateFinalSingleCostList(const OperatorInfoPtr& u) {
new_cost->communication_with_partial_para_ =
cost1->communication_without_parameter_ +
COST_MODEL_GAMMA * (cost1->communication_cost_ - cost1->communication_without_parameter_);
new_cost->memory_with_reuse_ = cost1->memory_with_reuse_;
ret.push_back(new_cost);
}
}
@@ -297,9 +300,14 @@ CostPtrList CostGraph::CreateFinalSingleCostList(const OperatorInfoPtr& u) {
}
CostPtr CostGraph::SelectCostWithMemoryConstraint(const CostPtrList& cost_list, double memory) {
if (cost_list.empty() || cost_list[0]->computation_cost_ >= memory) {
return nullptr;
CostPtrList after_mem_filter;
// Keep only the valid costs
for (auto& a_cost : cost_list) {
if (a_cost->memory_with_reuse_ <= memory) {
after_mem_filter.emplace_back(std::move(a_cost));
}
}
std::function<CostPtr(CostPtr, const CostPtr&)> LocalCompare = [&](CostPtr init, const CostPtr& cost_x) {
MS_EXCEPTION_IF_NULL(cost_x);
if (init == nullptr || cost_x->computation_cost_ < memory) {
@@ -308,7 +316,7 @@ CostPtr CostGraph::SelectCostWithMemoryConstraint(const CostPtrList& cost_list,
return init;
};
CostPtr ret = nullptr;
return std::accumulate(cost_list.begin(), cost_list.end(), ret, LocalCompare);
return std::accumulate(after_mem_filter.begin(), after_mem_filter.end(), ret, LocalCompare);
}
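The rewritten selection above is a filter-then-fold: drop candidates whose memory_with_reuse_ exceeds the device capacity, then fold the survivors with std::accumulate. A self-contained sketch of the pattern (hypothetical types; the fold here simply keeps the cheapest computation cost among the survivors):

```cpp
#include <memory>
#include <numeric>
#include <vector>

struct CostS { double computation_cost_ = 0.0; double memory_with_reuse_ = 0.0; };
using CostSPtr = std::shared_ptr<CostS>;

CostSPtr SelectWithMemoryConstraint(const std::vector<CostSPtr>& cost_list, double memory) {
  // 1) Keep only the memory-feasible candidates.
  std::vector<CostSPtr> feasible;
  for (const auto& c : cost_list) {
    if (c->memory_with_reuse_ <= memory) feasible.push_back(c);
  }
  // 2) Fold the survivors, keeping the cheapest computation cost.
  auto cheaper = [](CostSPtr best, const CostSPtr& c) {
    return (best == nullptr || c->computation_cost_ < best->computation_cost_) ? c : best;
  };
  return std::accumulate(feasible.begin(), feasible.end(), CostSPtr(nullptr), cheaper);
}
```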
CostPtr CostGraph::SelectCostWithMinTrainingTime(const CostPtrList& cost_list, double memory) {
@@ -318,36 +326,46 @@ CostPtr CostGraph::SelectCostWithMinTrainingTime(const CostPtrList& cost_list, d
MS_LOG(ERROR) << "Final cost list is null.";
return nullptr;
}
CostPtr ret = cost_list[0];
MS_EXCEPTION_IF_NULL(ret);
if (ret->computation_cost_ >= memory) {
MS_LOG(ERROR) << "No available cost; the minimum cost is " << ret->computation_cost_
CostPtrList after_mem_filter;
double minimum_memory = DBL_MAX;
// Keep only the valid costs.
for (auto& a_cost : cost_list) {
if (a_cost->memory_with_reuse_ <= memory) {
after_mem_filter.emplace_back(std::move(a_cost));
} else if (a_cost->memory_with_reuse_ < minimum_memory) {
minimum_memory = a_cost->memory_with_reuse_;
}
}
if (after_mem_filter.empty()) {
MS_LOG(ERROR) << "No available cost. The minimum memory cost is: " << minimum_memory
<< ", the memory capacity is: " << memory << ".";
return nullptr;
}
// Init the returned value with the first cost.
CostPtr ret = after_mem_filter[0];
double minimum = costmodel_alpha_ * ret->computation_cost_ + costmodel_beta_ * ret->communication_with_partial_para_;
MS_LOG(INFO) << "minimum: " << minimum << ", computation_cost_: " << ret->computation_cost_
MS_LOG(INFO) << "Cost 0: "
<< "memory_cost: " << ret->memory_with_reuse_ << ", computation_cost_: " << ret->computation_cost_
<< ", communication_with_partial_para_: " << ret->communication_with_partial_para_
<< ", communication_cost_: " << ret->communication_cost_
<< ", communication_without_parameter_: " << ret->communication_without_parameter_ << ".";
for (size_t i = 1; i < cost_list.size(); ++i) {
MS_EXCEPTION_IF_NULL(cost_list[i]);
if (cost_list[i]->computation_cost_ >= memory) {
MS_LOG(INFO) << "cost_list " << i << " computation_cost_: " << cost_list[i]->computation_cost_
<< ", is larger than the memory capacity: " << memory << ".";
break;
}
MS_LOG(INFO) << "cost_list " << i << " computation_cost_: " << cost_list[i]->computation_cost_
<< ", communication_with_partial_para_: " << cost_list[i]->communication_with_partial_para_
<< ", communication_cost_: " << cost_list[i]->communication_cost_
<< ", communication_without_parameter_: " << cost_list[i]->communication_without_parameter_ << ".";
auto tmp = costmodel_alpha_ * cost_list[i]->computation_cost_ +
costmodel_beta_ * cost_list[i]->communication_with_partial_para_;
MS_LOG(INFO) << "tmp: " << tmp;
MS_LOG(INFO) << "Cost 0: total_cost: " << minimum;
for (size_t i = 1; i < after_mem_filter.size(); ++i) {
MS_EXCEPTION_IF_NULL(after_mem_filter[i]);
MS_LOG(INFO) << "Cost " << i << ": memory_cost: " << after_mem_filter[i]->memory_with_reuse_
<< ", computation_cost_: " << after_mem_filter[i]->computation_cost_
<< ", communication_with_partial_para_: " << after_mem_filter[i]->communication_with_partial_para_
<< ", communication_cost_: " << after_mem_filter[i]->communication_cost_
<< ", communication_without_parameter_: " << after_mem_filter[i]->communication_without_parameter_
<< ".";
auto tmp = costmodel_alpha_ * after_mem_filter[i]->computation_cost_ +
costmodel_beta_ * after_mem_filter[i]->communication_with_partial_para_;
MS_LOG(INFO) << "Cost " << i << ": total_cost: " << tmp;
if (minimum > tmp) {
minimum = tmp;
ret = cost_list[i];
MS_LOG(INFO) << "selected: " << i;
ret = after_mem_filter[i];
MS_LOG(INFO) << "Selected: " << i;
}
}
return ret;
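Per the code above, the quantity being minimized is a weighted sum of computation and partial-parameter communication, taken only over memory-feasible candidates. As a one-line helper (illustrative, not part of the commit):

```cpp
// Training-time objective used to rank memory-feasible candidates:
//   alpha * computation + beta * communication_with_partial_para
double TrainingTimeObjective(double computation_cost, double communication_with_partial_para,
                             double alpha /* costmodel_alpha_ */, double beta /* costmodel_beta_ */) {
  return alpha * computation_cost + beta * communication_with_partial_para;
}
```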
@@ -356,17 +374,21 @@ CostPtr CostGraph::SelectCostWithMinTrainingTime(const CostPtrList& cost_list, d
CostPtrList CostGraph::SelectCostListWithMinTrainingTimeMultiple(const std::vector<CostPtrList>& all_cost_list,
double available_memory) {
CostPtrList selected_cost_list(all_cost_list.size(), nullptr);
double minimum = 0.0, total_memory = 0.0;
double minimum = DBL_MAX, total_memory = 0.0;
CostPtrList ret(all_cost_list.size(), nullptr);
// Check whether valid costs exist.
for (size_t i = 0; i < all_cost_list.size(); ++i) {
if (all_cost_list[i][0] == nullptr) {
MS_LOG(ERROR) << "The cost list " << i << " is empty.";
return ret;
} else {
total_memory += all_cost_list[i][0]->computation_cost_;
minimum += costmodel_alpha_ * all_cost_list[i][0]->computation_cost_ +
costmodel_beta_ * all_cost_list[i][0]->communication_with_partial_para_;
ret[i] = all_cost_list[i][0];
double memory_i_cost = DBL_MAX;
for (size_t j = 0; j < all_cost_list[i].size(); ++j) {
if (all_cost_list[i][j]->memory_with_reuse_ < memory_i_cost) {
memory_i_cost = all_cost_list[i][j]->memory_with_reuse_;
}
}
total_memory += memory_i_cost;
}
}
if (total_memory >= available_memory) {
@@ -381,7 +403,7 @@ CostPtrList CostGraph::SelectCostListWithMinTrainingTimeMultiple(const std::vect
double tmp_memory = 0.0, tmp_minimum = 0.0;
for (size_t i = 0; i < selected_cost_list.size(); ++i) {
MS_EXCEPTION_IF_NULL(selected_cost_list[i]);
tmp_memory += selected_cost_list[i]->computation_cost_;
tmp_memory += selected_cost_list[i]->memory_with_reuse_;
tmp_minimum += costmodel_alpha_ * selected_cost_list[i]->computation_cost_ +
costmodel_beta_ * selected_cost_list[i]->communication_with_partial_para_;
}
@@ -816,6 +838,7 @@ void CostGraph::CreateMergeEliminationSubCostList(StrategyPtr op_strategy, const
auto& tar_cost = tar_cost_list[k];
MS_EXCEPTION_IF_NULL(tar_cost);
double computation = op_cost->computation_cost_ + edge_cost->computation_cost_ + tar_cost->computation_cost_;
double memory = op_cost->memory_with_reuse_ + edge_cost->memory_with_reuse_ + tar_cost->memory_with_reuse_;
double communication =
op_cost->communication_cost_ + edge_cost->communication_cost_ + tar_cost->communication_cost_;
double communication_without_para = op_cost->communication_without_parameter_ +
@@ -829,6 +852,7 @@ void CostGraph::CreateMergeEliminationSubCostList(StrategyPtr op_strategy, const
new_cost->communication_without_parameter_ = communication_without_para;
new_cost->communication_with_partial_para_ =
communication_without_para + COST_MODEL_GAMMA * (communication - communication_without_para);
new_cost->memory_with_reuse_ = memory;
MS_EXCEPTION_IF_NULL(tar_cost_list_new);
tar_cost_list_new->emplace_back(std::move(new_cost));
}
@@ -894,6 +918,8 @@ void CostGraph::CreateContractEliminationSubCostList(StrategyPtr contract_op_str
MS_EXCEPTION_IF_NULL(tar_cost);
double computation =
contract_op_cost->computation_cost_ + edge_cost->computation_cost_ + tar_cost->computation_cost_;
double memory =
contract_op_cost->memory_with_reuse_ + edge_cost->memory_with_reuse_ + tar_cost->memory_with_reuse_;
double communication =
contract_op_cost->communication_cost_ + edge_cost->communication_cost_ + tar_cost->communication_cost_;
double communication_without_para = contract_op_cost->communication_without_parameter_ +
@@ -906,6 +932,7 @@ void CostGraph::CreateContractEliminationSubCostList(StrategyPtr contract_op_str
new_cost->communication_without_parameter_ = communication_without_para;
new_cost->communication_with_partial_para_ =
communication_without_para + COST_MODEL_GAMMA * (communication - communication_without_para);
new_cost->memory_with_reuse_ = memory;
tar_cost_list_new->emplace_back(std::move(new_cost));
}
}
@@ -966,23 +993,22 @@ void CostGraph::CreateTriangleEliminationSubCostList(StrategyPtr elimi_op_stra,
for (auto& left_node_cost : left_node_clist_origin) {
MS_EXCEPTION_IF_NULL(left_node_cost);
double new_computation = elimi_op_cost->computation_cost_ + left_edge_cost->computation_cost_ +
left_node_cost->computation_cost_ + right_edge_cost->computation_cost_ +
right_op_cost->computation_cost_;
left_node_cost->computation_cost_ + right_edge_cost->computation_cost_;
double new_memory = elimi_op_cost->memory_with_reuse_ + left_edge_cost->memory_with_reuse_ +
left_node_cost->memory_with_reuse_ + right_edge_cost->memory_with_reuse_;
double new_commu_cost = elimi_op_cost->communication_cost_ + left_edge_cost->communication_cost_ +
left_node_cost->communication_cost_ + right_edge_cost->communication_cost_ +
right_op_cost->communication_cost_;
left_node_cost->communication_cost_ + right_edge_cost->communication_cost_;
double new_commu_without =
elimi_op_cost->communication_without_parameter_ + left_edge_cost->communication_without_parameter_ +
left_node_cost->communication_without_parameter_ + right_edge_cost->communication_without_parameter_ +
right_op_cost->communication_without_parameter_;
left_node_cost->communication_without_parameter_ + right_edge_cost->communication_without_parameter_;
auto decision =
std::make_shared<TriangleEliminationDecision>(elimi_op_stra, elimi_op_cost, left_edge_cost, right_edge_cost,
left_op_stra, left_node_cost, right_op_stra, right_op_cost);
auto decision = std::make_shared<TriangleEliminationDecision>(elimi_op_stra, elimi_op_cost, left_edge_cost,
right_edge_cost, left_op_stra, left_node_cost);
auto new_cost = std::make_shared<Cost>(new_computation, new_commu_cost, decision);
new_cost->communication_without_parameter_ = new_commu_without;
new_cost->communication_with_partial_para_ =
new_commu_without + COST_MODEL_GAMMA * (new_commu_cost - new_commu_without);
new_cost->memory_with_reuse_ = new_memory;
left_node_clist_new->emplace_back(std::move(new_cost));
}
}
@@ -1085,14 +1111,22 @@ void CostGraph::CreateStarEliminationSubCostList(const StrategyPtr& first_succ_n
succ_nodes_costs[0] = first_succ_node_cost;
double computation_cost = merged_node_cost->computation_cost_,
commu_cost = merged_node_cost->communication_cost_,
memory_cost = merged_node_cost->memory_with_reuse_, commu_cost = merged_node_cost->communication_cost_,
commu_without = merged_node_cost->communication_without_parameter_;
for (size_t i = 0; i < succ_nodes_stras.size(); ++i) {
MS_EXCEPTION_IF_NULL(succ_edges_costs[i]);
computation_cost += succ_edges_costs[i]->computation_cost_ + succ_nodes_costs[i]->computation_cost_;
commu_cost += succ_edges_costs[i]->communication_cost_ + succ_nodes_costs[i]->communication_cost_;
commu_without += succ_edges_costs[i]->communication_without_parameter_ +
succ_nodes_costs[i]->communication_without_parameter_;
if (i == 0) {
computation_cost += succ_edges_costs[i]->computation_cost_ + succ_nodes_costs[i]->computation_cost_;
memory_cost += succ_edges_costs[i]->memory_with_reuse_ + succ_nodes_costs[i]->memory_with_reuse_;
commu_cost += succ_edges_costs[i]->communication_cost_ + succ_nodes_costs[i]->communication_cost_;
commu_without += succ_edges_costs[i]->communication_without_parameter_ +
succ_nodes_costs[i]->communication_without_parameter_;
} else {
computation_cost += succ_edges_costs[i]->computation_cost_;
memory_cost += succ_edges_costs[i]->memory_with_reuse_;
commu_cost += succ_edges_costs[i]->communication_cost_;
commu_without += succ_edges_costs[i]->communication_without_parameter_;
}
}
auto decision = std::make_shared<StarEliminationDecision>(merged_op_stra, merged_node_cost, succ_edges_costs,
@@ -1100,6 +1134,7 @@ void CostGraph::CreateStarEliminationSubCostList(const StrategyPtr& first_succ_n
auto new_cost = std::make_shared<Cost>(computation_cost, commu_cost, decision);
new_cost->communication_without_parameter_ = commu_without;
new_cost->communication_with_partial_para_ = commu_without + COST_MODEL_GAMMA * (commu_cost - commu_without);
new_cost->memory_with_reuse_ = memory_cost;
first_succ_node_clist_new->emplace_back(std::move(new_cost));
}
}
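The branch on i == 0 above mirrors the recovery path: every successor edge is always charged, but among the successor nodes only succ_nodes[0], which absorbs the star, contributes its node cost. A toy version of the accumulation, with invented names:

```cpp
#include <utility>
#include <vector>

// edge_node[i] = {edge cost, node cost} for successor i. Edge costs always
// accumulate; node costs only for the first successor, which absorbs the star.
double StarComputationCost(double center_cost,
                           const std::vector<std::pair<double, double>>& edge_node) {
  double total = center_cost;
  for (size_t i = 0; i < edge_node.size(); ++i) {
    total += edge_node[i].first;
    if (i == 0) total += edge_node[i].second;
  }
  return total;
}
```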
@@ -1259,5 +1294,35 @@ OperatorInfoPtr CostGraph::FindTmpIdentityByParameterName(std::string& p_name) c
}
return nullptr;
}
Status CostGraph::CorrectOpsMemoryCost() {
for (auto& one_op : ops_) {
if ((one_op->name().find(IDENTITY_INFO) != std::string::npos) && (one_op->is_output_parameter_involve() == 1)) {
if (one_op->GetAliveSuccEdges().size() > 1) {
// Handle the case where the TmpIdentity is used by multiple operators.
std::map<size_t, int> output_count;
for (size_t i = 0; i < one_op->GetAliveSuccEdges().size(); ++i) {
auto output_index = one_op->GetAliveSuccEdges()[i]->prev_op_output_index();
output_count[output_index]++;
}
for (size_t i = 0; i < one_op->GetAliveSuccEdges().size(); ++i) {
auto output_index = one_op->GetAliveSuccEdges()[i]->prev_op_output_index();
if (output_count[output_index] <= 1) {
continue;
}
auto next_op = one_op->GetAliveSuccEdges()[i]->next_operator();
MS_EXCEPTION_IF_NULL(next_op);
auto input_index = one_op->GetAliveSuccEdges()[i]->next_op_input_index();
if (next_op->CorrectMemoryCost(input_index) != SUCCESS) {
MS_LOG(ERROR) << "The operator name: " << one_op->name() << ", the next operator name: " << next_op->name()
<< ", the output_index: " << output_index << ", the input_index: " << input_index << ".";
return FAILED;
}
output_count[output_index]--;
}
}
}
}
return SUCCESS;
}
} // namespace parallel
} // namespace mindspore
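CorrectOpsMemoryCost above asks every consumer after the first of a shared TmpIdentity output to drop the duplicated parameter contribution. The double counting it repairs, as a standalone toy with invented names and numbers:

```cpp
// A parameter slice of `param_bytes` reaches `consumers` operators through one
// TmpIdentity output. A naive per-consumer sum counts it `consumers` times;
// the correction removes `consumers - 1` copies so it is counted exactly once.
double NaiveSharedParamMemory(double param_bytes, int consumers) {
  return param_bytes * consumers;
}
double CorrectedSharedParamMemory(double param_bytes, int consumers) {
  return consumers > 0 ? param_bytes : 0.0;
}
```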
@@ -187,6 +187,9 @@ class CostGraph {
size_t GetNumPairs() const { return edges_.size(); }
Status InitSelectedStrategy();
OperatorInfoPtr FindTmpIdentityByParameterName(std::string&) const;
// When TmpIdentity is used by multiple operators, the corresponding parameter's memory cost should be calculated
// only once (instead of multiple times); this method corrects that.
Status CorrectOpsMemoryCost();
// Needed by rec_parser
void add_inputs_tensor_name(const std::vector<std::string>& inputs_tensor_name) {
inputs_tensor_name_list_.push_back(inputs_tensor_name);
@@ -17,6 +17,7 @@
#include "parallel/auto_parallel/operator_costmodel.h"
#include <random>
#include <algorithm>
#include "parallel/device_matrix.h"
#include "parallel/tensor_layout/tensor_redistribution.h"
@@ -24,12 +25,44 @@ namespace mindspore {
namespace parallel {
void OperatorCost::set_is_parameter(const std::vector<bool>& is_parameter) { is_parameter_ = is_parameter; }
void OperatorCost::set_is_parameter_involve(const std::vector<bool>& is_parameter_inv) {
is_parameter_involve_ = is_parameter_inv;
}
void OperatorCost::set_output_parameter_involve(int output_para) { output_parameter_involve_ = output_para; }
void OperatorCost::SetInputAndOutputTypeLength(const std::vector<size_t>& input_lengths,
const std::vector<size_t>& output_lengths) {
inputs_type_lengths_ = input_lengths;
outputs_type_lengths_ = output_lengths;
}
double OperatorCost::GetMemoryCost(const std::vector<TensorInfo>& inputs,
const std::vector<TensorInfo>& outputs) const {
double result = 0.0;
if (output_parameter_involve_ == 1) {
// When this operator has multiple outputs, they all contribute to the memory.
for (size_t i = 0; i < outputs.size(); ++i) {
result += ListProduct(outputs[i].slice_shape()) * static_cast<double>(outputs_type_lengths_[i]);
}
bool is_any_para_inv =
std::any_of(is_parameter_involve_.begin(), is_parameter_involve_.end(), [](bool value) { return value; });
if (is_any_para_inv) {
for (size_t i = 0; i < inputs.size(); ++i) {
if (is_parameter_[i]) {
result += ListProduct(inputs[i].slice_shape()) * static_cast<double>(inputs_type_lengths_[i]);
} else if (inputs_related_ && (!is_parameter_involve_[i])) {
// When the inputs of this operator are related and are not themselves parameter-involved, they are
// still included in the memory cost.
result += ListProduct(inputs[i].slice_shape()) * static_cast<double>(inputs_type_lengths_[i]);
}
}
}
}
return result;
}
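To make the rule above concrete, a self-contained worked example (shapes and byte widths invented): when the output is parameter-involved, memory is the byte size of each output slice plus each parameter input slice.

```cpp
#include <cstddef>
#include <cstdint>
#include <iostream>
#include <vector>

// Bytes occupied by one tensor slice: product of slice dims times element size.
double SliceBytes(const std::vector<int64_t>& slice_shape, std::size_t type_len) {
  double n = 1.0;
  for (int64_t d : slice_shape) n *= static_cast<double>(d);
  return n * static_cast<double>(type_len);
}

int main() {
  // Assumed float32 output slice [512, 128] plus a float32 parameter slice [128, 128]:
  // 512*128*4 + 128*128*4 = 262144 + 65536 = 327680 bytes.
  std::cout << SliceBytes({512, 128}, 4) + SliceBytes({128, 128}, 4) << " bytes\n";
  return 0;
}
```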
// return the per device communication cost in the forward phase.
double MatMulCost::GetForwardCommCost(const std::vector<TensorInfo>& inputs, const std::vector<TensorInfo>& outputs,
const int32_t&) const {
@@ -72,11 +105,11 @@ double MatMulCost::GetBackwardCommCost(const std::vector<TensorInfo>& inputs, co
return result;
}
// Return the per device memory cost in the forward phase. The cost is calculated according to the bytes
// Return the per device computation cost in the forward phase. The cost is calculated according to the bytes
// this operator uses
double MatMulCost::GetForwardComputationCost(const std::vector<TensorInfo>& inputs,
const std::vector<TensorInfo>& outputs, const int32_t&) const {
// In forward phase, the memory cost = slice(A) + slice(B) + (0 or 1) allreduce(slice(C))
// In forward phase, the computation cost = slice(A) + slice(B) + (0 or 1) allreduce(slice(C))
double result = 0.0;
TensorInfo output0 = outputs[0];
Shape input0_slice_shape = inputs[0].slice_shape();
@@ -91,11 +124,11 @@ double MatMulCost::GetForwardComputationCost(const std::vector<TensorInfo>& inpu
return result;
}
// Return the per device memory cost in the forward phase. The cost is calculated according to the bytes
// Return the per device computation cost in the forward phase. The cost is calculated according to the bytes
// this operator uses
double MatMulCost::GetBackwardComputationCost(const std::vector<TensorInfo>& inputs, const std::vector<TensorInfo>&,
const int32_t& stage_id) const {
// In backward phase, the memory cost = (0 or 1) allreduce(slice(B))
// In backward phase, the computation cost = (0 or 1) allreduce(slice(B))
double result = 0.0;
if (is_parameter_[1]) {
TensorInfo input1 = inputs[1]; // tensor B
@@ -145,7 +178,7 @@ double ActivationCost::GetBackwardCommCost(const std::vector<TensorInfo>& inputs
return result;
}
// Return the per memory cost in the forward phase. The cost is calculated according to the bytes
// Return the per device computation cost in the forward phase. The cost is calculated according to the bytes
// this operator uses
double ActivationCost::GetForwardComputationCost(const std::vector<TensorInfo>& inputs, const std::vector<TensorInfo>&,
const int32_t&) const {
@@ -154,7 +187,7 @@ double ActivationCost::GetForwardComputationCost(const std::vector<TensorInfo>&
return ListProduct(input0_slice_shape) * static_cast<double>(inputs_type_lengths_[0]);
}
// Return the per memory cost in the forward phase. The cost is calculated according to the bytes
// Return the per device computation cost in the forward phase. The cost is calculated according to the bytes
// this operator uses
double ActivationCost::GetBackwardComputationCost(const std::vector<TensorInfo>&, const std::vector<TensorInfo>&,
const int32_t&) const {
@@ -189,17 +222,17 @@ double SoftmaxCost::GetBackwardCommCost(const std::vector<TensorInfo>& inputs, c
return result;
}
// Return the per memory cost in the forward phase. The cost is calculated according to the bytes
// Return the per device computation cost in the forward phase. The cost is calculated according to the bytes
// this operator uses
double SoftmaxCost::GetForwardComputationCost(const std::vector<TensorInfo>& inputs, const std::vector<TensorInfo>&,
const int32_t&) const {
// In the forward phase, the memory cost = slice(A)
// In the forward phase, the computation cost = slice(A)
TensorInfo input0 = inputs[0];
Shape input0_slice_shape = input0.slice_shape();
return ListProduct(input0_slice_shape) * static_cast<double>(inputs_type_lengths_[0]);
}
// Return the per memory cost in the forward phase. The cost is calculated according to the bytes
// Return the per device computation cost in the forward phase. The cost is calculated according to the bytes
// this operator uses
double SoftmaxCost::GetBackwardComputationCost(const std::vector<mindspore::parallel::TensorInfo>&,
const std::vector<mindspore::parallel::TensorInfo>&,
@@ -221,17 +254,15 @@ double TmpIdentityCost::GetBackwardCommCost(const std::vector<mindspore::paralle
return 0.0;
}
// Return the per memory cost in the forward phase. The cost is calculated according to the bytes
// Return the per device computation cost in the forward phase. The cost is calculated according to the bytes
// this operator uses
double TmpIdentityCost::GetForwardComputationCost(const std::vector<mindspore::parallel::TensorInfo>& inputs,
double TmpIdentityCost::GetForwardComputationCost(const std::vector<mindspore::parallel::TensorInfo>&,
const std::vector<mindspore::parallel::TensorInfo>&,
const int32_t&) const {
TensorInfo input0_info = inputs[0];
Shape input0_slice_shape = input0_info.slice_shape();
return ListProduct(input0_slice_shape) * static_cast<double>(inputs_type_lengths_[0]);
return 0.0;
}
// Return the per memory cost in the backward phase. The cost is calculated according to the bytes
// Return the per device computation cost in the backward phase. The cost is calculated according to the bytes
// this operator uses
double TmpIdentityCost::GetBackwardComputationCost(const std::vector<mindspore::parallel::TensorInfo>&,
const std::vector<mindspore::parallel::TensorInfo>&,
@@ -239,6 +270,11 @@ double TmpIdentityCost::GetBackwardComputationCost(const std::vector<mindspore::
return 0.0;
}
// Return the per device PEAK memory cost contributed by this operator in a training iteration.
double TmpIdentityCost::GetMemoryCost(const std::vector<TensorInfo>&, const std::vector<TensorInfo>&) const {
return 0.0;
}
double BatchParallelCost::GetForwardComputationCost(const std::vector<mindspore::parallel::TensorInfo>& inputs,
const std::vector<mindspore::parallel::TensorInfo>&,
const int32_t&) const {
@@ -284,11 +320,11 @@ double PReLUCost::GetBackwardCommCost(const std::vector<TensorInfo>& inputs, con
return result;
}
// Return the per memory cost in the forward phase. The cost is calculated according to the bytes
// Return the per device computation cost in the forward phase. The cost is calculated according to the bytes
// this operator uses
double PReLUCost::GetForwardComputationCost(const std::vector<TensorInfo>& inputs, const std::vector<TensorInfo>&,
const int32_t&) const {
// In forward phase, the memory cost = slice(A) + slice(B)
// In forward phase, the computation cost = slice(A) + slice(B)
Shape input0_slice_shape = inputs[0].slice_shape();
Shape input1_slice_shape = inputs[1].slice_shape();
double result = ListProduct(input0_slice_shape) * static_cast<double>(inputs_type_lengths_[0]) +
@@ -296,12 +332,12 @@ double PReLUCost::GetForwardComputationCost(const std::vector<TensorInfo>& input
return result;
}
// Return the per memory cost in the backward phase. The cost is calculated according to the bytes
// Return the per device computation cost in the backward phase. The cost is calculated according to the bytes
// this operator uses
double PReLUCost::GetBackwardComputationCost(const std::vector<mindspore::parallel::TensorInfo>& inputs,
const std::vector<mindspore::parallel::TensorInfo>&,
const int32_t& stage_id) const {
// In backward phase, the memory cost = (0 or 1) allreduce(slice(B))
// In backward phase, the computation cost = (0 or 1) allreduce(slice(B))
double result = 0.0;
if (is_parameter_[1]) {
TensorInfo input1 = inputs[1]; // tensor B
@@ -337,16 +373,16 @@ double OneHotCost::GetBackwardCommCost(const std::vector<TensorInfo>&, const std
return 0.0;
}
// Return the per memory cost in the forward phase. The cost is calculated according to the bytes
// Return the per device computation cost in the forward phase. The cost is calculated according to the bytes
// this operator uses
double OneHotCost::GetForwardComputationCost(const std::vector<TensorInfo>& inputs, const std::vector<TensorInfo>&,
const int32_t&) const {
// In onehot's forward phase, the memory cost = slice(A)
// In onehot's forward phase, the computation cost = slice(A)
Shape input0_slice_shape = inputs[0].slice_shape();
return ListProduct(input0_slice_shape) * static_cast<double>(inputs_type_lengths_[0]);
}
// Return the per memory cost in the backward phase. The cost is calculated according to the bytes
// Return the per device computation cost in the backward phase. The cost is calculated according to the bytes
// this operator uses
double OneHotCost::GetBackwardComputationCost(const std::vector<TensorInfo>&, const std::vector<TensorInfo>&,
const int32_t&) const {
@@ -367,12 +403,12 @@ double SoftmaxCrossEntropyWithLogitsCost::GetBackwardCommCost(const std::vector<
return 0.0;
}
// Return the per memory cost in the forward phase. The cost is calculated according to the bytes
// Return the per device computation cost in the forward phase. The cost is calculated according to the bytes
// this operator uses
double SoftmaxCrossEntropyWithLogitsCost::GetForwardComputationCost(const std::vector<TensorInfo>& inputs,
const std::vector<TensorInfo>&,
const int32_t&) const {
// In forward phase, the memory cost = slice(A) + slice(B)
// In forward phase, the computation cost = slice(A) + slice(B)
Shape input0_slice_shape = inputs[0].slice_shape();
Shape input1_slice_shape = inputs[1].slice_shape();
double result = ListProduct(input0_slice_shape) * static_cast<double>(inputs_type_lengths_[0]) +
@@ -380,7 +416,7 @@ double SoftmaxCrossEntropyWithLogitsCost::GetForwardComputationCost(const std::v
return result;
}
// Return the per memory cost in the backward phase. The cost is calculated according to the bytes
// Return the per device computation cost in the backward phase. The cost is calculated according to the bytes
// this operator uses
double SoftmaxCrossEntropyWithLogitsCost::GetBackwardComputationCost(const std::vector<TensorInfo>&,
const std::vector<TensorInfo>&,
@@ -410,7 +446,7 @@ double ReshapeCost::GetBackwardCommCost(const std::vector<TensorInfo>&, const st
return 0.0;
}
// Return the per memory cost in the forward phase. The cost is calculated according to the bytes
// Return the per device computation cost in the forward phase. The cost is calculated according to the bytes
// this operator uses
double ReshapeCost::GetForwardComputationCost(const std::vector<TensorInfo>& inputs,
const std::vector<TensorInfo>& outputs, const int32_t& stage_id) const {
@@ -427,7 +463,7 @@ double ReshapeCost::GetForwardComputationCost(const std::vector<TensorInfo>& inp
return (inputs_type_lengths_[0] * tensor_redistribution.computation_cost());
}
// Return the per memory cost in the backward phase. The cost is calculated according to the bytes
// Return the per device computation cost in the backward phase. The cost is calculated according to the bytes
// this operator uses
double ReshapeCost::GetBackwardComputationCost(const std::vector<mindspore::parallel::TensorInfo>&,
const std::vector<mindspore::parallel::TensorInfo>&,
@@ -43,10 +43,20 @@ double ListProduct(std::vector<T> vec) {
// entries times the length of each entry's data type
class OperatorCost {
public:
OperatorCost() {
explicit OperatorCost(bool is_inputs_related) : inputs_related_(is_inputs_related) {
// this is only for the case when set_is_parameter() and SetInputAndOutputTypeLength() are not invoked
for (size_t i = 0; i < MAXIMUM_INPUT_NUMBER; ++i) {
is_parameter_.push_back(false);
is_parameter_involve_.push_back(false);
inputs_type_lengths_.push_back(DEFAULT_DATA_TYPE_LENGTH);
outputs_type_lengths_.push_back(DEFAULT_DATA_TYPE_LENGTH);
}
}
OperatorCost() : inputs_related_(false) {
// this is only for the case when set_is_parameter() and SetInputAndOutputTypeLength() are not invoked
for (size_t i = 0; i < MAXIMUM_INPUT_NUMBER; ++i) {
is_parameter_.push_back(false);
is_parameter_involve_.push_back(false);
inputs_type_lengths_.push_back(DEFAULT_DATA_TYPE_LENGTH);
outputs_type_lengths_.push_back(DEFAULT_DATA_TYPE_LENGTH);
}
@@ -54,6 +64,8 @@ class OperatorCost {
virtual ~OperatorCost() = default;
void set_is_parameter(const std::vector<bool>& is_parameter);
void set_is_parameter_involve(const std::vector<bool>&);
void set_output_parameter_involve(int);
void SetInputAndOutputTypeLength(const std::vector<size_t>& input_lengths, const std::vector<size_t>& output_lengths);
std::vector<size_t> inputs_type_lengths() const { return inputs_type_lengths_; }
std::vector<size_t> outputs_type_lengths() const { return outputs_type_lengths_; }
@@ -72,8 +84,19 @@ class OperatorCost {
const std::vector<TensorInfo>& outputs, const int32_t& stage_id) const = 0;
virtual double GetBackwardComputationCost(const std::vector<TensorInfo>& inputs,
const std::vector<TensorInfo>& outputs, const int32_t& stage_id) const = 0;
// per device PEAK memory cost in a training iteration
// Typically, the PEAK memory cost contributed by an operator is its output (if the output is parameter-involved),
// plus necessary inputs.
virtual double GetMemoryCost(const std::vector<TensorInfo>& inputs, const std::vector<TensorInfo>& outputs) const;
protected:
// For each input in 'inputs_', the bool is true if the corresponding input is a parameter, or an output of a
// preceding operator that has a parameter as input.
std::vector<bool> is_parameter_involve_;
int output_parameter_involve_ = -1; // -1: unset; 0: not parameter_involved; 1: parameter_involved
// Whether the inputs are related. For example, TensorAdd's two inputs are independent (not related), while
// Mul's two inputs are dependent (related).
bool inputs_related_;
// for each input in 'inputs_', a bool indicating whether the corresponding input is a parameter
std::vector<bool> is_parameter_;
// for each input and output, the following record the number of bytes of each element
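One way to see why Mul's inputs are "related" while TensorAdd's are not: multiplication's backward pass needs the peer input, addition's does not, so Mul's inputs must stay resident. A minimal gradient sketch, independent of MindSpore:

```cpp
// d(a*b)/da = b: Mul's backward needs the saved peer input, so both inputs
// stay in memory through the backward phase.
double MulBackwardWrtA(double grad_out, double saved_b) { return grad_out * saved_b; }
// d(a+b)/da = 1: Add's backward needs neither input.
double AddBackwardWrtA(double grad_out) { return grad_out; }
```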
| @@ -85,7 +108,8 @@ using OperatorCostPtr = std::shared_ptr<OperatorCost>; | |||
| class MatMulCost : public OperatorCost { | |||
| public: | |||
| MatMulCost() = default; | |||
| explicit MatMulCost(bool is_inputs_related) : OperatorCost(is_inputs_related) {} | |||
| MatMulCost() : OperatorCost(true) {} | |||
| ~MatMulCost() override = default; | |||
| // per device communication cost | |||
| @@ -108,12 +132,12 @@ class MatMulCost : public OperatorCost { | |||
| double GetBackwardComputationCost(const std::vector<TensorInfo>& inputs, const std::vector<TensorInfo>& outputs, | |||
| const int32_t& stage_id) const override; | |||
| }; | |||
| using MatMulCostPtr = std::shared_ptr<MatMulCost>; | |||
| class ActivationCost : public OperatorCost { | |||
| public: | |||
| ActivationCost() = default; | |||
| explicit ActivationCost(bool is_inputs_related) : OperatorCost(is_inputs_related) {} | |||
| ActivationCost() : OperatorCost(false) {} | |||
| ~ActivationCost() override = default; | |||
| double GetCommCost(const std::vector<TensorInfo>& inputs, const std::vector<TensorInfo>& outputs, | |||
| @@ -133,14 +157,14 @@ class ActivationCost : public OperatorCost { | |||
| double GetBackwardComputationCost(const std::vector<TensorInfo>& inputs, const std::vector<TensorInfo>& outputs, | |||
| const int32_t& stage_id) const override; | |||
| }; | |||
| using ActivationCostPtr = std::shared_ptr<ActivationCost>; | |||
| using TransposeCost = ActivationCost; | |||
| using TransposeCostPtr = std::shared_ptr<TransposeCost>; | |||
| class SoftmaxCost : public OperatorCost { | |||
| public: | |||
| SoftmaxCost() = default; | |||
| explicit SoftmaxCost(bool is_inputs_related) : OperatorCost(is_inputs_related) {} | |||
| SoftmaxCost() : OperatorCost(false) {} | |||
| ~SoftmaxCost() override = default; | |||
| double GetCommCost(const std::vector<TensorInfo>& inputs, const std::vector<TensorInfo>& outputs, | |||
| @@ -160,12 +184,12 @@ class SoftmaxCost : public OperatorCost { | |||
| double GetBackwardComputationCost(const std::vector<TensorInfo>& inputs, const std::vector<TensorInfo>& outputs, | |||
| const int32_t&) const override; | |||
| }; | |||
| using SoftmaxCostPtr = std::shared_ptr<SoftmaxCost>; | |||
| class TmpIdentityCost : public OperatorCost { | |||
| public: | |||
| TmpIdentityCost() = default; | |||
| explicit TmpIdentityCost(bool is_inputs_related) : OperatorCost(is_inputs_related) {} | |||
| TmpIdentityCost() : OperatorCost(false) {} | |||
| ~TmpIdentityCost() override = default; | |||
| double GetCommCost(const std::vector<TensorInfo>& inputs, const std::vector<TensorInfo>& outputs, | |||
| @@ -184,12 +208,15 @@ class TmpIdentityCost : public OperatorCost { | |||
| const int32_t& stage_id) const override; | |||
| double GetBackwardComputationCost(const std::vector<TensorInfo>& inputs, const std::vector<TensorInfo>& outputs, | |||
| const int32_t& stage_id) const override; | |||
| // per device PEAK memory cost in a training iteration | |||
| double GetMemoryCost(const std::vector<TensorInfo>& inputs, const std::vector<TensorInfo>& outputs) const override; | |||
| }; | |||
| using TmpIdentityCostPtr = std::shared_ptr<TmpIdentityCost>; | |||
| class BatchParallelCost : public OperatorCost { | |||
| public: | |||
| BatchParallelCost() = default; | |||
| explicit BatchParallelCost(bool is_inputs_related) : OperatorCost(is_inputs_related) {} | |||
| BatchParallelCost() : OperatorCost(false) {} | |||
| ~BatchParallelCost() override = default; | |||
| double GetCommCost(const std::vector<TensorInfo>& inputs, const std::vector<TensorInfo>& outputs, | |||
| @@ -217,7 +244,8 @@ using BatchParallelCostPtr = std::shared_ptr<BatchParallelCost>; | |||
| class VirtualDatasetCost : public OperatorCost { | |||
| public: | |||
| VirtualDatasetCost() = default; | |||
| explicit VirtualDatasetCost(bool is_inputs_related) : OperatorCost(is_inputs_related) {} | |||
| VirtualDatasetCost() : OperatorCost(false) {} | |||
| ~VirtualDatasetCost() override = default; | |||
| double GetCommCost(const std::vector<TensorInfo>& inputs, const std::vector<TensorInfo>& outputs, | |||
| @@ -244,12 +272,17 @@ class VirtualDatasetCost : public OperatorCost { | |||
| const int32_t&) const override { | |||
| return 0.0; | |||
| } | |||
| // per device PEAK memory cost in a training iteration | |||
| double GetMemoryCost(const std::vector<TensorInfo>& inputs, const std::vector<TensorInfo>& outputs) const override { | |||
| return 0.0; | |||
| } | |||
| }; | |||
| using VirtualDatasetCostPtr = std::shared_ptr<VirtualDatasetCost>; | |||
| class GeneratorBaseCost : public OperatorCost { | |||
| public: | |||
| GeneratorBaseCost() = default; | |||
| explicit GeneratorBaseCost(bool is_inputs_related) : OperatorCost(is_inputs_related) {} | |||
| GeneratorBaseCost() : OperatorCost(false) {} | |||
| ~GeneratorBaseCost() override = default; | |||
| double GetCommCost(const std::vector<TensorInfo>& inputs, const std::vector<TensorInfo>& outputs, | |||
| @@ -283,7 +316,8 @@ using GeneratorBaseCostPtr = std::shared_ptr<GeneratorBaseCost>; | |||
| class PReLUCost : public OperatorCost { | |||
| public: | |||
| PReLUCost() = default; | |||
| explicit PReLUCost(bool is_inputs_related) : OperatorCost(is_inputs_related) {} | |||
| PReLUCost() : OperatorCost(true) {} | |||
| ~PReLUCost() override = default; | |||
| // per device communication cost | |||
| @@ -310,7 +344,8 @@ using PReLUCostPtr = std::shared_ptr<PReLUCost>; | |||
| class OneHotCost : public OperatorCost { | |||
| public: | |||
| OneHotCost() = default; | |||
| explicit OneHotCost(bool is_inputs_related) : OperatorCost(is_inputs_related) {} | |||
| OneHotCost() : OperatorCost(true) {} | |||
| ~OneHotCost() override = default; | |||
| // per device communication cost | |||
| @@ -337,7 +372,8 @@ using OneHotCostPtr = std::shared_ptr<OneHotCost>; | |||
| class SoftmaxCrossEntropyWithLogitsCost : public OperatorCost { | |||
| public: | |||
| SoftmaxCrossEntropyWithLogitsCost() = default; | |||
| explicit SoftmaxCrossEntropyWithLogitsCost(bool is_inputs_related) : OperatorCost(is_inputs_related) {} | |||
| SoftmaxCrossEntropyWithLogitsCost() : OperatorCost(false) {} | |||
| ~SoftmaxCrossEntropyWithLogitsCost() override = default; | |||
| // per device communication cost | |||
| @@ -364,7 +400,8 @@ using SoftmaxCrossEntropyWithLogitsCostPtr = std::shared_ptr<SoftmaxCrossEntropy | |||
| class ReshapeCost : public OperatorCost { | |||
| public: | |||
| ReshapeCost() = default; | |||
| explicit ReshapeCost(bool is_inputs_related) : OperatorCost(is_inputs_related) {} | |||
| ReshapeCost() : OperatorCost(true) {} | |||
| ~ReshapeCost() override = default; | |||
| @@ -396,7 +433,8 @@ using ReshapeCostPtr = std::shared_ptr<ReshapeCost>; | |||
| class ArithmeticCost : public OperatorCost { | |||
| public: | |||
| ArithmeticCost() = default; | |||
| explicit ArithmeticCost(bool is_inputs_related) : OperatorCost(is_inputs_related) {} | |||
| ArithmeticCost() : OperatorCost(false) {} | |||
| ~ArithmeticCost() override = default; | |||
| double GetCommCost(const std::vector<TensorInfo>& inputs, const std::vector<TensorInfo>& outputs, | |||
| @@ -425,7 +463,8 @@ using BiasAddCostPtr = std::shared_ptr<BiasAddCost>; | |||
| class ReduceMethodCost : public OperatorCost { | |||
| public: | |||
| ReduceMethodCost() = default; | |||
| explicit ReduceMethodCost(bool is_inputs_related) : OperatorCost(is_inputs_related) {} | |||
| ReduceMethodCost() : OperatorCost(true) {} | |||
| ~ReduceMethodCost() override = default; | |||
| double GetCommCost(const std::vector<TensorInfo>& inputs, const std::vector<TensorInfo>& outputs, | |||
| @@ -455,7 +494,8 @@ using ReduceMethodCostPtr = std::shared_ptr<ReduceMethodCost>; | |||
| class ReduceMeanCost : public ReduceMethodCost { | |||
| public: | |||
| ReduceMeanCost() = default; | |||
| explicit ReduceMeanCost(bool is_inputs_related) : ReduceMethodCost(is_inputs_related) {} | |||
| ReduceMeanCost() : ReduceMethodCost(true) {} | |||
| ~ReduceMeanCost() override = default; | |||
| double GetForwardComputationCost(const std::vector<TensorInfo>& inputs, const std::vector<TensorInfo>& outputs, | |||
| @@ -465,7 +505,8 @@ using ReduceMeanCostPtr = std::shared_ptr<ReduceMeanCost>; | |||
| class GetNextCost : public OperatorCost { | |||
| public: | |||
| GetNextCost() = default; | |||
| explicit GetNextCost(bool is_inputs_related) : OperatorCost(is_inputs_related) {} | |||
| GetNextCost() : OperatorCost(false) {} | |||
| ~GetNextCost() override = default; | |||
| double GetCommCost(const std::vector<TensorInfo>& inputs, const std::vector<TensorInfo>& outputs, | |||
| @@ -499,7 +540,8 @@ using GetNextCostPtr = std::shared_ptr<GetNextCost>; | |||
| class DropOutCost : public OperatorCost { | |||
| public: | |||
| DropOutCost() = default; | |||
| explicit DropOutCost(bool is_inputs_related) : OperatorCost(is_inputs_related) {} | |||
| DropOutCost() : OperatorCost(true) {} | |||
| ~DropOutCost() override = default; | |||
| double GetCommCost(const std::vector<TensorInfo>& inputs, const std::vector<TensorInfo>& outputs, | |||
| @@ -530,7 +572,8 @@ using DropOutCostPtr = std::shared_ptr<DropOutCost>; | |||
| class GatherV2Cost : public OperatorCost { | |||
| public: | |||
| GatherV2Cost() = default; | |||
| explicit GatherV2Cost(bool is_inputs_related) : OperatorCost(is_inputs_related) {} | |||
| GatherV2Cost() : OperatorCost(true) {} | |||
| ~GatherV2Cost() override = default; | |||
| double GetCommCost(const std::vector<TensorInfo>& inputs, const std::vector<TensorInfo>& outputs, | |||
@@ -51,7 +51,7 @@ class Activation : public ActivationBase {
public:
Activation(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape,
const PrimitiveAttrs& attrs)
: ActivationBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<ActivationCost>()) {}
: ActivationBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<ActivationCost>(false)) {}
~Activation() override = default;
Status GenerateStrategies(int32_t stage_id) override;
Status SetCostUnderStrategy(const StrategyPtr& strategy) override;
@@ -102,7 +102,7 @@ class Softmax : public ActivationBase {
public:
explicit Softmax(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape,
const PrimitiveAttrs& attrs)
: ActivationBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<SoftmaxCost>()) {}
: ActivationBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<SoftmaxCost>(false)) {}
~Softmax() override = default;
Status GenerateStrategies(int32_t stage_id) override;
Status SetCostUnderStrategy(const StrategyPtr& strategy) override;
@@ -32,8 +32,8 @@ namespace parallel {
class ArithmeticBase : public OperatorInfo {
public:
ArithmeticBase(const std::string& operator_name, const Shapes& inputs_shape, const Shapes& outputs_shape,
const PrimitiveAttrs& attrs)
: OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs, std::make_shared<ArithmeticCost>()) {}
const PrimitiveAttrs& attrs, OperatorCostPtr cost)
: OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs, cost) {}
~ArithmeticBase() override = default;
Status Init(const StrategyPtr& strategy) override;
Status InitForCostModel(const StrategyPtr& strategy) override;
@@ -56,7 +56,7 @@ class ArithmeticBase : public OperatorInfo {
class SubInfo : public ArithmeticBase {
public:
SubInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape, const PrimitiveAttrs& attrs)
: ArithmeticBase(name, inputs_shape, outputs_shape, attrs) {}
: ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<ArithmeticCost>(false)) {}
~SubInfo() override = default;
};
@@ -64,21 +64,21 @@ class TensorAddInfo : public ArithmeticBase {
public:
TensorAddInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape,
const PrimitiveAttrs& attrs)
: ArithmeticBase(name, inputs_shape, outputs_shape, attrs) {}
: ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<ArithmeticCost>(false)) {}
~TensorAddInfo() override = default;
};
class MulInfo : public ArithmeticBase {
public:
MulInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape, const PrimitiveAttrs& attrs)
: ArithmeticBase(name, inputs_shape, outputs_shape, attrs) {}
: ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<ArithmeticCost>(true)) {}
~MulInfo() override = default;
};
class DivInfo : public ArithmeticBase {
public:
DivInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape, const PrimitiveAttrs& attrs)
: ArithmeticBase(name, inputs_shape, outputs_shape, attrs) {}
: ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<ArithmeticCost>(true)) {}
~DivInfo() override = default;
};
@@ -86,7 +86,7 @@ class RealDivInfo : public ArithmeticBase {
public:
RealDivInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape,
const PrimitiveAttrs& attrs)
: ArithmeticBase(name, inputs_shape, outputs_shape, attrs) {}
: ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<ArithmeticCost>(true)) {}
~RealDivInfo() override = default;
};
@@ -94,14 +94,14 @@ class FloorDivInfo : public ArithmeticBase {
public:
FloorDivInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape,
const PrimitiveAttrs& attrs)
: ArithmeticBase(name, inputs_shape, outputs_shape, attrs) {}
: ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<ArithmeticCost>(true)) {}
~FloorDivInfo() override = default;
};
class PowInfo : public ArithmeticBase {
public:
PowInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape, const PrimitiveAttrs& attrs)
: ArithmeticBase(name, inputs_shape, outputs_shape, attrs) {}
: ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<ArithmeticCost>(true)) {}
~PowInfo() override = default;
};
@@ -109,7 +109,7 @@ class GreaterInfo : public ArithmeticBase {
public:
GreaterInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape,
const PrimitiveAttrs& attrs)
: ArithmeticBase(name, inputs_shape, outputs_shape, attrs) {}
: ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<ArithmeticCost>(false)) {}
~GreaterInfo() override = default;
};
@@ -117,7 +117,7 @@ class AssignSubInfo : public ArithmeticBase {
public:
AssignSubInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape,
const PrimitiveAttrs& attrs)
: ArithmeticBase(name, inputs_shape, outputs_shape, attrs) {}
: ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<ArithmeticCost>(false)) {}
~AssignSubInfo() override = default;
};
} // namespace parallel
@@ -29,9 +29,13 @@ namespace mindspore {
namespace parallel {
class BatchParallelInfo : public OperatorInfo {
public:
BatchParallelInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape,
const PrimitiveAttrs& attrs, OperatorCostPtr cost)
: OperatorInfo(name, inputs_shape, outputs_shape, attrs, cost), dev_num_(1) {}
BatchParallelInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape,
const PrimitiveAttrs& attrs)
: OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<BatchParallelCost>()), dev_num_(1) {}
: OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<BatchParallelCost>(false)),
dev_num_(1) {}
~BatchParallelInfo() override = default;
Status Init(const StrategyPtr& strategy) override;
@@ -58,7 +62,7 @@ class SparseSoftmaxCrossEntropyWithLogitsInfo : public BatchParallelInfo {
public:
SparseSoftmaxCrossEntropyWithLogitsInfo(const std::string& name, const Shapes& inputs_shape,
const Shapes& outputs_shape, const PrimitiveAttrs& attrs)
: BatchParallelInfo(name, inputs_shape, outputs_shape, attrs) {}
: BatchParallelInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<BatchParallelCost>(true)) {}
~SparseSoftmaxCrossEntropyWithLogitsInfo() override = default;
void ReComputeBatchSplitFlagList() override;
};
@@ -34,7 +34,7 @@ class BiasAddInfo : public OperatorInfo {
public:
BiasAddInfo(const std::string& operator_name, const Shapes& inputs_shape, const Shapes& outputs_shape,
const PrimitiveAttrs& attrs)
: OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs, std::make_shared<BiasAddCost>()) {}
: OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs, std::make_shared<BiasAddCost>(false)) {}
~BiasAddInfo() override = default;
Status Init(const StrategyPtr& strategy) override;
@@ -18,6 +18,7 @@
#define MINDSPORE_CCSRC_PARALLEL_OPS_INFO_COMPARISON_FUNCTION_INFO_H_
#include <string>
#include <memory>
#include <unordered_map>
#include <vector>
#include "ir/value.h"
@@ -31,7 +32,7 @@ class EqualInfo : public ArithmeticBase {
public:
EqualInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape,
const PrimitiveAttrs& attrs)
: ArithmeticBase(name, inputs_shape, outputs_shape, attrs) {}
: ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<ArithmeticCost>(false)) {}
~EqualInfo() override = default;
};
@@ -39,7 +40,7 @@ class NotEqualInfo : public ArithmeticBase {
public:
NotEqualInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape,
const PrimitiveAttrs& attrs)
: ArithmeticBase(name, inputs_shape, outputs_shape, attrs) {}
: ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<ArithmeticCost>(false)) {}
~NotEqualInfo() override = default;
};
@@ -47,7 +48,7 @@ class MaximumInfo : public ArithmeticBase {
public:
MaximumInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape,
const PrimitiveAttrs& attrs)
: ArithmeticBase(name, inputs_shape, outputs_shape, attrs) {}
: ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<ArithmeticCost>(true)) {}
~MaximumInfo() override = default;
};
@@ -55,7 +56,7 @@ class MinimumInfo : public ArithmeticBase {
public:
MinimumInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape,
const PrimitiveAttrs& attrs)
: ArithmeticBase(name, inputs_shape, outputs_shape, attrs) {}
: ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<ArithmeticCost>(true)) {}
~MinimumInfo() override = default;
};
} // namespace parallel
@@ -33,7 +33,7 @@ class DropoutDoMaskInfo : public OperatorInfo {
public:
DropoutDoMaskInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape,
const PrimitiveAttrs& attrs)
: OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<DropOutCost>()) {}
: OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<DropOutCost>(true)) {}
~DropoutDoMaskInfo() override = default;
Status Init(const StrategyPtr& strategy) override;
@@ -32,7 +32,7 @@ class GetNextInfo : public OperatorInfo {
public:
GetNextInfo(const std::string &operator_name, const Shapes &inputs_shape, const Shapes &outputs_shape,
const PrimitiveAttrs &attrs)
: OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs, std::make_shared<GetNextCost>()) {}
: OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs, std::make_shared<GetNextCost>(false)) {}
~GetNextInfo() override = default;
Status Init(const StrategyPtr &strategy) override;
@@ -36,7 +36,8 @@ class SoftmaxCrossEntropyWithLogitsInfo : public OperatorInfo {
public:
SoftmaxCrossEntropyWithLogitsInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape,
const PrimitiveAttrs& attrs)
: OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<SoftmaxCrossEntropyWithLogitsCost>()) {}
: OperatorInfo(name, inputs_shape, outputs_shape, attrs,
std::make_shared<SoftmaxCrossEntropyWithLogitsCost>(false)) {}
~SoftmaxCrossEntropyWithLogitsInfo() override = default;
Status Init(const StrategyPtr& strategy) override;
Status InitForCostModel(const StrategyPtr& strategy) override;
@@ -593,11 +593,11 @@ Status MatMulBase::SetCostUnderStrategy(const mindspore::parallel::StrategyPtr&
// Here, we use the original outputs_, because we only use the slice size of the output tensor.
// It does not matter whether the output tensor is transposed or not.
double computation_cost =
cost()->GetForwardComputationCost(relica_inputs_tensor_vector, outputs_tensor_info_, stage_id);
double communication_cost = cost()->GetCommCost(relica_inputs_tensor_vector, outputs_tensor_info_, stage_id);
operator_cost()->GetForwardComputationCost(relica_inputs_tensor_vector, outputs_tensor_info_, stage_id);
double communication_cost = operator_cost()->GetCommCost(relica_inputs_tensor_vector, outputs_tensor_info_, stage_id);
std::shared_ptr<Cost> result = std::make_shared<Cost>(computation_cost, communication_cost);
result->communication_without_parameter_ =
cost()->GetForwardCommCost(relica_inputs_tensor_vector, outputs_tensor_info_, stage_id);
operator_cost()->GetForwardCommCost(relica_inputs_tensor_vector, outputs_tensor_info_, stage_id);
result->communication_with_partial_para_ =
result->communication_without_parameter_ +
COST_MODEL_GAMMA * (communication_cost - result->communication_without_parameter_);
@@ -34,7 +34,7 @@ class MatMulBase : public OperatorInfo {
public:
MatMulBase(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape,
const PrimitiveAttrs& attrs)
: OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<MatMulCost>()) {}
: OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<MatMulCost>(true)) {}
~MatMulBase() override = default;
Status Init(const StrategyPtr& strategy) override;
@@ -33,7 +33,7 @@ class OneHotInfo : public OperatorInfo {
public:
OneHotInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape,
const PrimitiveAttrs& attrs)
: OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<OneHotCost>()) {}
: OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<OneHotCost>(false)) {}
~OneHotInfo() override = default;
Status Init(const StrategyPtr& strategy) override;
Status InitForCostModel(const StrategyPtr& strategy) override;
@@ -1035,11 +1035,12 @@ Status OperatorInfo::SetCostUnderStrategyBase(const StrategyPtr& strategy) {
return FAILED;
}
int32_t stage_id = strategy->GetInputStage();
double computation_cost = cost()->GetForwardComputationCost(inputs_tensor_info_, outputs_tensor_info_, stage_id);
double communication_cost = cost()->GetCommCost(inputs_tensor_info_, outputs_tensor_info_, stage_id);
double computation_cost =
operator_cost()->GetForwardComputationCost(inputs_tensor_info_, outputs_tensor_info_, stage_id);
double communication_cost = operator_cost()->GetCommCost(inputs_tensor_info_, outputs_tensor_info_, stage_id);
std::shared_ptr<Cost> result = std::make_shared<Cost>(computation_cost, communication_cost);
result->communication_without_parameter_ =
cost()->GetForwardCommCost(inputs_tensor_info_, outputs_tensor_info_, stage_id);
operator_cost()->GetForwardCommCost(inputs_tensor_info_, outputs_tensor_info_, stage_id);
result->communication_with_partial_para_ =
result->communication_without_parameter_ +
COST_MODEL_GAMMA * (communication_cost - result->communication_without_parameter_);
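The partial-parameter figure above interpolates between the parameter-free communication cost and the total communication cost, weighted by COST_MODEL_GAMMA. A worked example of the formula with assumed numbers (the gamma value here is a stand-in, not the model's actual default):

// Hypothetical values for illustration only.
double communication_cost = 100.0;              // total communication cost
double communication_without_parameter = 60.0;  // communication excluding parameter tensors
const double kGamma = 0.1;                      // stand-in for COST_MODEL_GAMMA
// communication_with_partial_para = 60 + 0.1 * (100 - 60) = 64
double communication_with_partial_para =
    communication_without_parameter + kGamma * (communication_cost - communication_without_parameter);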
@@ -1096,7 +1097,38 @@ Status OperatorInfo::set_is_parameter(const std::vector<bool>& is_parameter) {
return FAILED;
}
is_parameter_ = is_parameter;
cost()->set_is_parameter(is_parameter);
operator_cost()->set_is_parameter(is_parameter);
return SUCCESS;
}
Status OperatorInfo::CalculateMemoryCost() {
// First, pass 'is_parameter_involve_' and 'is_output_parameter_involve_' to OperatorCost; both are needed to
// calculate the memory cost.
if (is_parameter_involve_.size() != is_parameter_.size()) {
MS_LOG(ERROR) << "'is_parameter_' and 'is_parameter_involve_' do not have the same number of elements.";
return FAILED;
}
operator_cost()->set_is_parameter_involve(is_parameter_involve_);
operator_cost()->set_output_parameter_involve(is_output_parameter_involve_);
// Set the memory cost in the 'strategy_cost_'
for (auto& swc : strategy_cost_) {
auto mem_cost = operator_cost()->GetMemoryCost(swc->inputs_ptr, swc->outputs_ptr);
swc->cost_list[0]->memory_with_reuse_ = mem_cost;
}
return SUCCESS;
}
Status OperatorInfo::CorrectMemoryCost(size_t input_index) {
for (auto& swc : strategy_cost_) {
double parameter_mem_cost = ListProduct(swc->inputs_ptr[input_index].slice_shape()) *
static_cast<double>(operator_cost()->inputs_type_lengths()[input_index]);
swc->cost_list[0]->memory_with_reuse_ -= parameter_mem_cost;
if (swc->cost_list[0]->memory_with_reuse_ < 0) {
MS_LOG(ERROR) << "The memory cost after correction is: " << swc->cost_list[0]->memory_with_reuse_
<< ", the parameter memory cost is: " << parameter_mem_cost;
return FAILED;
}
}
return SUCCESS;
}
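CorrectMemoryCost deducts the slice-sized parameter tensor from each strategy's memory estimate, so a parameter read by several operators is charged only once. A self-contained sketch of that subtraction under assumed shapes (SliceElements stands in for ListProduct; all names and numbers here are illustrative, not from the diff):

#include <cstdint>
#include <vector>

// Stand-in for ListProduct: product of the slice dimensions.
int64_t SliceElements(const std::vector<int64_t>& slice_shape) {
  int64_t prod = 1;
  for (int64_t d : slice_shape) prod *= d;
  return prod;
}

int main() {
  std::vector<int64_t> param_slice = {128, 64};  // assumed parameter slice shape
  size_t type_length = 4;                        // e.g. float32
  double memory_with_reuse = 200000.0;           // assumed current estimate
  // 128 * 64 * 4 = 32768 units are deducted for the duplicated parameter.
  double parameter_mem_cost =
      static_cast<double>(SliceElements(param_slice)) * static_cast<double>(type_length);
  memory_with_reuse -= parameter_mem_cost;  // 200000 - 32768 = 167232
  return memory_with_reuse < 0 ? 1 : 0;     // mirrors the FAILED check above
}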
@@ -1193,7 +1225,7 @@ Status OperatorInfo::SetInputAndOutputTypeLength(const std::vector<size_t>& inpu
}
inputs_type_lengths_ = input_lengths;
outputs_type_lengths_ = output_lengths;
cost()->SetInputAndOutputTypeLength(input_lengths, output_lengths);
operator_cost()->SetInputAndOutputTypeLength(input_lengths, output_lengths);
return SUCCESS;
}
@@ -1221,7 +1253,7 @@ void OperatorInfo::BreakingTiesForPerferringDataParallel(const StrategyPtr& stra
}
double OperatorInfo::GetForwardMemoryCostFromCNode() {
return cost()->GetForwardComputationCost(inputs_tensor_info_, outputs_tensor_info_, 0);
return operator_cost()->GetForwardComputationCost(inputs_tensor_info_, outputs_tensor_info_, 0);
}
} // namespace parallel
@@ -60,7 +60,7 @@ class OperatorInfo {
outputs_shape_(std::move(outputs_shape)),
attrs_(std::move(attrs)),
is_alive_(true),
cost_(cost),
operator_cost_(cost),
outputs_type_() {
std::vector<bool> not_parameteter(inputs_shape_.size(), false);
is_parameter_ = not_parameteter;
@@ -83,8 +83,8 @@ class OperatorInfo {
// Given the stage_id (which indicates the number of devices),
// generate all strategies for this operator
virtual Status GenerateStrategies(int32_t stage_id) = 0;
const OperatorCostPtr& cost() const { return cost_; }
void set_cost(const OperatorCostPtr& cost) { cost_ = cost; }
const OperatorCostPtr& operator_cost() const { return operator_cost_; }
void set_cost(const OperatorCostPtr& cost) { operator_cost_ = cost; }
virtual Status SetCostUnderStrategy(const StrategyPtr& strategy) = 0;
virtual std::shared_ptr<std::vector<std::vector<int32_t>>> GenerateBatchStrategies();
@@ -98,7 +98,7 @@ class OperatorInfo {
std::vector<std::shared_ptr<StrategyWithCost>> GetStrategyCost() { return strategy_cost_; }
// When an input of an operator is a WEIGHT, or an output of other operators involving WEIGHT, that input
// should stay in memory until it is used in the backward phase, i.e. it is still held at the end of the forward phase.
Status CalculateMemoryCost() const { return SUCCESS; }
Status CalculateMemoryCost();
int ComputeOpAndPrevEdgeParameterInvolved();
ForwardOp forward_op() const { return forward_op_; }
@@ -125,7 +125,7 @@ class OperatorInfo {
void ReplaceSuccEdge(const std::shared_ptr<OperatorInfo>& op, const std::shared_ptr<Edge>& new_edge);
void ReplacePreEdges(const std::shared_ptr<OperatorInfo>& op, const std::shared_ptr<Edge>& new_edge);
void ReplaceSuccEdges(const std::shared_ptr<OperatorInfo>& op, const std::shared_ptr<Edge>& new_edge);
std::vector<size_t> GetOutputTypeLengths() const { return cost()->outputs_type_lengths(); }
std::vector<size_t> GetOutputTypeLengths() const { return operator_cost()->outputs_type_lengths(); }
void SetSelectedStrategyAndCost(const StrategyPtr& s_strategy, const CostPtr& cost) {
selected_strategy_ = s_strategy;
selected_cost_ = cost;
@@ -142,6 +142,10 @@ class OperatorInfo {
void set_strategy(const StrategyPtr& strategy) { strategy_ = strategy; }
void set_refkey_parameter_name(std::string p_name) { refkey_parameter_name_ = std::move(p_name); }
const std::string& refkey_parameter_name() const { return refkey_parameter_name_; }
// When the output of a Parameter (require_grad) is used by multiple operators, the Parameter's cost is calculated
// multiple times. This method corrects that, ensuring the cost is calculated only once.
Status CorrectMemoryCost(size_t input_index);
int is_output_parameter_involve() const { return is_output_parameter_involve_; }
int used_devices() const { return used_devices_; }
// needed by rec_parser
void set_type(const std::string& type) { type_ = type; }
@@ -234,7 +238,7 @@ class OperatorInfo {
int32_t used_devices_ = -1;
private:
OperatorCostPtr cost_;
OperatorCostPtr operator_cost_;
std::vector<TypePtr> outputs_type_;
};
@@ -35,7 +35,7 @@ class PReLUInfo : public OperatorInfo {
public:
PReLUInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape,
const PrimitiveAttrs& attrs)
: OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<PReLUCost>()) {}
: OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<PReLUCost>(true)) {}
~PReLUInfo() override = default;
Status Init(const StrategyPtr& strategy) override;
Status InitForCostModel(const StrategyPtr& strategy) override;
@@ -109,7 +109,7 @@ Status ReduceMethod::GetAttrs() {
}
cross_batch_ = cross_batch_iter->second->cast<BoolImmPtr>()->value();
}
auto reducemethodcost = std::dynamic_pointer_cast<ReduceMethodCost>(cost());
auto reducemethodcost = std::dynamic_pointer_cast<ReduceMethodCost>(operator_cost());
if (reducemethodcost == nullptr) {
MS_LOG(ERROR) << "Cost cast to ReduceMethodCostPtr failed!";
return FAILED;
@@ -34,7 +34,7 @@ class ReduceMethod : public OperatorInfo {
public:
ReduceMethod(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
const PrimitiveAttrs &attrs)
: OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<ReduceMethodCost>()) {}
: OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<ReduceMethodCost>(true)) {}
~ReduceMethod() override = default;
Status Init(const StrategyPtr &strategy) override;
@@ -36,7 +36,7 @@ class ReshapeInfo : public OperatorInfo {
public:
ReshapeInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape,
const PrimitiveAttrs& attrs)
: OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<ReshapeCost>()),
: OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<ReshapeCost>(false)),
dev_num_(0),
input_layout_set_flag_(false),
output_layout_set_flag_(false) {}
@@ -34,7 +34,7 @@ class TmpIdentityInfo : public OperatorInfo {
public:
TmpIdentityInfo(const Shapes& inputs_shape, const Shapes& outputs_shape, const PrimitiveAttrs& attrs,
const std::string& name = IDENTITY_INFO)
: OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<TmpIdentityCost>()) {}
: OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<TmpIdentityCost>(false)) {}
~TmpIdentityInfo() override = default;
Status Init(const StrategyPtr& strategy) override;
@@ -35,7 +35,7 @@ class TransposeInfo : public OperatorInfo {
public:
TransposeInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape,
const PrimitiveAttrs& attrs)
: OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<TransposeCost>()) {}
: OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<TransposeCost>(false)) {}
~TransposeInfo() override = default;
Status Init(const StrategyPtr& strategy) override;
Status InitForCostModel(const StrategyPtr& strategy) override;
@@ -32,7 +32,7 @@ class VirtualDatasetInfo : public OperatorInfo {
public:
VirtualDatasetInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape,
const PrimitiveAttrs& attrs)
: OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<VirtualDatasetCost>()) {}
: OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<VirtualDatasetCost>(false)) {}
~VirtualDatasetInfo() override = default;
Status Init(const StrategyPtr& strategy) override;
Status InitForCostModel(const StrategyPtr& strategy) override;
@@ -874,11 +874,15 @@ Status ParallelStrategySearch(const std::vector<AnfNodePtr> &all_nodes, const Fu
if (entire_costgraph->ComputeOpsAndEdgesParameterInvolved() == SUCCESS) {
// Calculate operators' memory usage
if (entire_costgraph->CalculateOpsMemoryCost() != SUCCESS) {
MS_LOG(EXCEPTION) << "Correcting operators' cost for memory reuse failed.";
MS_LOG(EXCEPTION) << "Calculating operators' memory cost failed.";
}
// Calculate edges' memory usage
if (entire_costgraph->CalculateEdgesMemoryCost() != SUCCESS) {
MS_LOG(EXCEPTION) << "Correcting edges' cost for memory reuse failed.";
MS_LOG(EXCEPTION) << "Calculating edges' memory cost failed.";
}
// Correct memory usage caused by TmpIdentity
if (entire_costgraph->CorrectOpsMemoryCost() != SUCCESS) {
MS_LOG(EXCEPTION) << "Correcting operators' memory cost failed.";
}
} else {
MS_LOG(EXCEPTION) << "Computing operators' parameter_involved failed.";
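The three calls above run in a fixed order: operator memory first, then edge (redistribution) memory, then the TmpIdentity correction. A hedged sketch of that sequencing, with a hypothetical driver function (only the CostGraph method names are taken from the diff; error handling simplified):

// Assumed driver logic; ModelMemory is illustrative, not part of the diff.
Status ModelMemory(const CostGraphPtr& graph) {
  if (graph->CalculateOpsMemoryCost() != SUCCESS) return FAILED;    // per-operator memory
  if (graph->CalculateEdgesMemoryCost() != SUCCESS) return FAILED;  // per-edge memory
  return graph->CorrectOpsMemoryCost();  // deduct double-counted TmpIdentity parameters
}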
@@ -159,6 +159,7 @@ Status TensorRedistribution::ComputeCost() {
backward_comm_cost_ += prod;
comm_cost_ += 2.0 * prod;
computation_cost_ += prod;
memory_cost_ += prod;
} else if (str == CONCAT_BY_AXIS) {
// communication cost = all_gather + reduce_scatter = before_slice_shape + after_slice_shape
// computation cost = before_slice_shape
@@ -175,20 +176,25 @@ Status TensorRedistribution::ComputeCost() {
if (concat_dim == 0) {
// computation cost = all_gather
computation_cost_ += prod;
memory_cost_ += prod * dev_num;
} else {
// computation cost = all_gather + split + concat
computation_cost_ += (prod + prod * dev_num + prod * dev_num);
memory_cost_ += (prod * dev_num + prod * dev_num + prod);
}
} else {
// There is only computation cost in SplitByAxis.
// computation cost = before_slice_shape
computation_cost_ += prod;
// This addition may be erroneous
memory_cost_ += prod;
}
}
if (reshape_flag()) {
Shape prev_slice_shape = from_.slice_shape().array();
double prev_prod = std::accumulate(prev_slice_shape.begin(), prev_slice_shape.end(), 1, std::multiplies<int>());
computation_cost_ += 2.0 * prev_prod;
memory_cost_ += 2.0 * prev_prod;
}
return Status::SUCCESS;
}
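For the CONCAT_BY_AXIS branch with concat_dim != 0, the cost is charged as all_gather + split + concat. A worked numeric example under assumed values (prod = 1024 slice elements, dev_num = 8; numbers purely illustrative):

// computation: all_gather (prod) + split (prod * dev_num) + concat (prod * dev_num)
//   = 1024 + 8192 + 8192 = 17408
// memory: two gathered tensors (prod * dev_num each) plus the final slice (prod)
//   = 8192 + 8192 + 1024 = 17408
double prod = 1024.0, dev_num = 8.0;
double computation = prod + prod * dev_num + prod * dev_num;  // 17408
double memory = prod * dev_num + prod * dev_num + prod;       // 17408
// For concat_dim == 0, only the gathered tensor is charged: memory += prod * dev_num = 8192.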
@@ -42,6 +42,7 @@ class TensorRedistribution {
forward_comm_cost_(0.0),
backward_comm_cost_(0.0),
computation_cost_(0.0),
memory_cost_(0.0),
construct_op_flag_(construct_op_flag),
keep_reshape_(keep_reshape) {}
Status Init(const TensorLayout& from, const TensorLayout& to, const RankList& dev_list);
@@ -54,6 +55,7 @@ class TensorRedistribution {
double computation_cost() const { return computation_cost_; }
double forward_comm_cost() const { return forward_comm_cost_; }
double backward_comm_cost() const { return backward_comm_cost_; }
double memory_cost() const { return memory_cost_; }
private:
Status InferReshape(const TensorLayout& from_layout, const TensorLayout& to_layout,
@@ -72,7 +74,12 @@ class TensorRedistribution {
double forward_comm_cost_;
// backward communication cost
double backward_comm_cost_;
// computation_cost_ models the time spent on computation in this tensor redistribution, which is calculated from
// the inputs.
double computation_cost_;
// memory_cost_ models the PEAK memory cost in a training iteration contributed by this tensor redistribution,
// which is calculated from the outputs.
double memory_cost_;
bool construct_op_flag_;
bool keep_reshape_;
};
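After ComputeCost succeeds, the accessors expose the accumulated figures, with memory_cost() newly added alongside the existing three. A minimal usage sketch, assuming from_layout, to_layout and dev_list are valid objects set up elsewhere:

// from_layout, to_layout and dev_list are assumed; only the API shape comes from the diff.
TensorRedistribution redistribution;
if (redistribution.Init(from_layout, to_layout, dev_list) == Status::SUCCESS &&
    redistribution.ComputeCost() == Status::SUCCESS) {
  double compute = redistribution.computation_cost();  // time-like cost, from the inputs
  double memory = redistribution.memory_cost();        // peak-memory cost, from the outputs
}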
@@ -84,9 +84,9 @@ TEST_F(TestActivation, test_activation_strategies) {
act_ptr_->InitForCostModel(sp);
std::vector<TensorInfo> inputs_info = act_ptr_->inputs_tensor_info();
std::vector<TensorInfo> outputs_info = act_ptr_->outputs_tensor_info();
ASSERT_DOUBLE_EQ(act_ptr_->cost()->GetComputationCost(inputs_info, outputs_info, sp->GetInputStage()),
ASSERT_DOUBLE_EQ(act_ptr_->operator_cost()->GetComputationCost(inputs_info, outputs_info, sp->GetInputStage()),
cost.computation_cost_);
ASSERT_DOUBLE_EQ(act_ptr_->cost()->GetCommCost(inputs_info, outputs_info, sp->GetInputStage()),
ASSERT_DOUBLE_EQ(act_ptr_->operator_cost()->GetCommCost(inputs_info, outputs_info, sp->GetInputStage()),
cost.communication_cost_);
}
}
@@ -109,9 +109,9 @@ TEST_F(TestActivation, test_softmax_strategies) {
soft_ptr_->InitForCostModel(sp);
std::vector<TensorInfo> inputs_info = soft_ptr_->inputs_tensor_info();
std::vector<TensorInfo> outputs_info = soft_ptr_->outputs_tensor_info();
ASSERT_DOUBLE_EQ(soft_ptr_->cost()->GetComputationCost(inputs_info, outputs_info, sp->GetInputStage()),
ASSERT_DOUBLE_EQ(soft_ptr_->operator_cost()->GetComputationCost(inputs_info, outputs_info, sp->GetInputStage()),
cost.computation_cost_);
ASSERT_DOUBLE_EQ(soft_ptr_->cost()->GetCommCost(inputs_info, outputs_info, sp->GetInputStage()),
ASSERT_DOUBLE_EQ(soft_ptr_->operator_cost()->GetCommCost(inputs_info, outputs_info, sp->GetInputStage()),
cost.communication_cost_);
}
}
@@ -569,7 +569,7 @@ TEST_F(TestMatmulInfo, test_GenerateStrategies1) {
matmul1->InitForCostModel(sp);
std::vector<TensorInfo> inputs_info = matmul1->inputs_tensor_info();
std::vector<TensorInfo> outputs_info = matmul1->outputs_tensor_info();
ASSERT_DOUBLE_EQ(matmul1->cost()->GetComputationCost(inputs_info, outputs_info, sp->GetInputStage()),
ASSERT_DOUBLE_EQ(matmul1->operator_cost()->GetComputationCost(inputs_info, outputs_info, sp->GetInputStage()),
cost.computation_cost_);
break;
}
@@ -599,7 +599,7 @@ TEST_F(TestMatmulInfo, test_GenerateStrategies2) {
TensorInfo replica_input1_info(tly, input1_shape, input1_slice_shape);
replica_inputs_info.push_back(replica_input1_info);
ASSERT_DOUBLE_EQ(matmul3->cost()->GetComputationCost(replica_inputs_info, outputs_info, sp->GetInputStage()),
ASSERT_DOUBLE_EQ(matmul3->operator_cost()->GetComputationCost(replica_inputs_info, outputs_info, sp->GetInputStage()),
cost.computation_cost_);
break;
}
@@ -188,11 +188,11 @@ TEST_F(TestTensorAddInfo, GenerateStrategies) {
tensor_add->InitForCostModel(sp);
std::vector<TensorInfo> inputs_info = tensor_add->inputs_tensor_info();
std::vector<TensorInfo> outputs_info = tensor_add->outputs_tensor_info();
double memory_cost0 = tensor_add->cost()->GetComputationCost(inputs_info, outputs_info, sp->GetInputStage());
double memory_cost0 = tensor_add->operator_cost()->GetComputationCost(inputs_info, outputs_info, sp->GetInputStage());
double memory_cost1 = cost.computation_cost_;
bool memory = memory_cost0 - memory_cost1 <= 1.0;
double comm_cost0 = tensor_add->cost()->GetCommCost(inputs_info, outputs_info, sp->GetInputStage());
double comm_cost0 = tensor_add->operator_cost()->GetCommCost(inputs_info, outputs_info, sp->GetInputStage());
double comm_cost1 = cost.communication_cost_;
bool comm = comm_cost0 - comm_cost1 <= 1.0;
@@ -210,11 +210,11 @@ TEST_F(TestTensorAddInfo, GenerateStrategies1) {
tensor_add1->InitForCostModel(sp);
std::vector<TensorInfo> inputs_info = tensor_add1->inputs_tensor_info();
std::vector<TensorInfo> outputs_info = tensor_add1->outputs_tensor_info();
double memory_cost0 = tensor_add1->cost()->GetComputationCost(inputs_info, outputs_info, sp->GetInputStage());
double memory_cost0 = tensor_add1->operator_cost()->GetComputationCost(inputs_info, outputs_info, sp->GetInputStage());
double memory_cost1 = cost.computation_cost_;
bool memory = memory_cost0 - memory_cost1 <= 1.0;
double comm_cost0 = tensor_add1->cost()->GetCommCost(inputs_info, outputs_info, sp->GetInputStage());
double comm_cost0 = tensor_add1->operator_cost()->GetCommCost(inputs_info, outputs_info, sp->GetInputStage());
double comm_cost1 = cost.communication_cost_;
bool comm = comm_cost0 - comm_cost1 <= 1.0;
@@ -145,9 +145,9 @@ TEST_F(TestTmpIdentityInfo, test_generate_strategies) {
identity_ptr->Init(sp);
std::vector<TensorInfo> inputs_info = identity_ptr->inputs_tensor_info();
std::vector<TensorInfo> outputs_info = identity_ptr->outputs_tensor_info();
ASSERT_DOUBLE_EQ(identity_ptr->cost()->GetComputationCost(inputs_info, outputs_info, sp->GetInputStage()),
ASSERT_DOUBLE_EQ(identity_ptr->operator_cost()->GetComputationCost(inputs_info, outputs_info, sp->GetInputStage()),
cost.computation_cost_);
ASSERT_DOUBLE_EQ(identity_ptr->cost()->GetCommCost(inputs_info, outputs_info, sp->GetInputStage()),
ASSERT_DOUBLE_EQ(identity_ptr->operator_cost()->GetCommCost(inputs_info, outputs_info, sp->GetInputStage()),
cost.communication_cost_);
}
}