Merge pull request !156 from Xiaoda/implementing-memory-calculation-in-auto-parallel
@@ -23,8 +23,8 @@
 namespace mindspore {
 namespace parallel {
 void Simplify(CostPtrList* clist_ptrs) {
-  // Sort the cost_list with the memory_cost increasing, and communication_cost decreasing order. This method
-  // excludes the cost with greater memory_cost and greater communication_cost.
+  // Sort the cost_list with the computation_cost_ increasing, and communication_cost decreasing order. This method
+  // excludes the cost with greater computation_cost_ and greater communication_cost.
   // E.g. clist_ptrs = {<100, 20>, <200, 10>, <300, 50>}. After this method, clist_ptrs = {<200, 10>, <100, 20>}
   if (!COST_MODEL_SIMPLIFY_CALCULATION) {
     return;
@@ -33,7 +33,7 @@ void Simplify(CostPtrList* clist_ptrs) {
   std::vector<size_t> id(clist_ptrs->size());
   std::iota(id.begin(), id.end(), size_t(0));
   std::sort(id.begin(), id.end(), [&clist_ptrs](size_t x, size_t y) {
-    return clist_ptrs->at(x)->memory_cost_ < clist_ptrs->at(y)->memory_cost_;
+    return clist_ptrs->at(x)->computation_cost_ < clist_ptrs->at(y)->computation_cost_;
   });
   CostPtrList ret;
   for (size_t i = 0; i < clist_ptrs->size(); ++i) {
@@ -45,8 +45,8 @@ void Simplify(CostPtrList* clist_ptrs) {
 }
 void SimplifyForDreasingCommunicationWithPartialPara(CostPtrList* clist_ptrs) {
-  // Sort the cost_list with the memory_cost increasing, and communication_with_partial_para_cost decreasing order.
-  // This method excludes the cost with greater memory_cost and greater communication_without_para_cost.
+  // Sort the cost_list with the computation_cost_ increasing, and communication_with_partial_para_cost decreasing
+  // order. This method excludes the cost with greater computation_cost_ and greater communication_without_para_cost.
   if (!COST_MODEL_SIMPLIFY_CALCULATION) {
     return;
   }
@@ -54,7 +54,7 @@ void SimplifyForDreasingCommunicationWithPartialPara(CostPtrList* clist_ptrs) {
   std::vector<size_t> id(clist_ptrs->size());
   std::iota(id.begin(), id.end(), size_t(0));
   std::sort(id.begin(), id.end(), [&clist_ptrs](size_t x, size_t y) {
-    return clist_ptrs->at(x)->memory_cost_ < clist_ptrs->at(y)->memory_cost_;
+    return clist_ptrs->at(x)->computation_cost_ < clist_ptrs->at(y)->computation_cost_;
   });
   CostPtrList ret;
   for (size_t i = 0; i < clist_ptrs->size(); ++i) {
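Both Simplify routines above apply the same dominance filter over the renamed field: sort the candidates by computation_cost_ ascending and drop any candidate that is worse on both the computation and the communication axis. The following is a minimal, self-contained sketch of that idea, not code from this patch; SimpleCost and PruneDominated are illustrative stand-ins for the real types, and the real code may order the surviving entries differently.

// Illustrative sketch of the dominance filter performed by Simplify() after the rename:
// sort candidate costs by computation_cost_ ascending, then keep only those whose
// communication_cost_ strictly improves on everything kept so far.
#include <algorithm>
#include <limits>
#include <memory>
#include <numeric>
#include <vector>

struct SimpleCost {
  double computation_cost_;
  double communication_cost_;
};
using SimpleCostPtr = std::shared_ptr<SimpleCost>;

std::vector<SimpleCostPtr> PruneDominated(const std::vector<SimpleCostPtr>& clist) {
  std::vector<size_t> id(clist.size());
  std::iota(id.begin(), id.end(), size_t(0));
  std::sort(id.begin(), id.end(), [&clist](size_t x, size_t y) {
    return clist[x]->computation_cost_ < clist[y]->computation_cost_;
  });
  std::vector<SimpleCostPtr> ret;
  double best_comm = std::numeric_limits<double>::max();
  for (size_t i : id) {
    // A candidate survives only if it communicates less than every cheaper-to-compute
    // candidate already kept; otherwise it is dominated on both axes.
    if (clist[i]->communication_cost_ < best_comm) {
      best_comm = clist[i]->communication_cost_;
      ret.push_back(clist[i]);
    }
  }
  return ret;
}

With the example from the comment above, {<100, 20>, <200, 10>, <300, 50>}, the <300, 50> candidate is excluded because <200, 10> beats it on both computation and communication.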
@@ -44,14 +44,18 @@ using RedistributionOpListPtr = std::shared_ptr<std::pair<OperatorVector, OutPut
 struct Cost {
   Cost();
-  Cost(double memory, double commuication, const std::shared_ptr<Decision>& decision_ = nullptr)
-      : memory_cost_(memory), communication_cost_(commuication), decision_ptr_(std::move(decision_)) {
+  Cost(double computation, double commuication, const std::shared_ptr<Decision>& decision_ = nullptr)
+      : computation_cost_(computation), communication_cost_(commuication), decision_ptr_(std::move(decision_)) {
+    memory_with_reuse_ = 0.0;
     communication_without_parameter_ = 0.0;
     communication_with_partial_para_ = 0.0;
     communication_redis_forward_ = 0.0;
     communication_redis_backward_ = 0.0;
   }
-  double memory_cost_;
+  // 'memory_with_reuse_' calculates the peak memory usage in a training phase
+  double memory_with_reuse_;
+  // 'computation_cost_' models the training time of an iteration in a training phase
+  double computation_cost_;
   // 'communication_cost_' includes communications from operators (forward and backward) and edges
   double communication_cost_;
   // communication_without_parameter_ = communication_cost_ - (backward communication from operators)
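After this change the constructor's first argument populates computation_cost_, while peak memory is tracked separately in the new memory_with_reuse_ field. The hunks that follow repeatedly derive communication_with_partial_para_ from the two communication fields; a hedged sketch of that recurring derivation, using a trimmed stand-in for the Cost struct and an assumed gamma value, is shown below.

// Illustrative sketch only: a trimmed stand-in for the Cost struct after the rename,
// plus the derivation of communication_with_partial_para_ that recurs in the hunks below.
// The COST_MODEL_GAMMA value here is assumed for illustration; the real one comes from
// the cost-model configuration.
constexpr double COST_MODEL_GAMMA = 0.1;  // assumed value

struct CostSketch {
  double computation_cost_ = 0.0;        // models training time of one iteration
  double memory_with_reuse_ = 0.0;       // models peak memory usage in a training phase
  double communication_cost_ = 0.0;      // forward + backward communication
  double communication_without_parameter_ = 0.0;
  double communication_with_partial_para_ = 0.0;
};

// The pattern repeated throughout the elimination routines:
void FillPartialPara(CostSketch* c) {
  c->communication_with_partial_para_ =
      c->communication_without_parameter_ +
      COST_MODEL_GAMMA * (c->communication_cost_ - c->communication_without_parameter_);
}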
@@ -35,7 +35,7 @@ namespace parallel {
 // interpretation of 6 operations in costmodel.h.
 // Phase 2: Search the cost_list in the final graph, and determine the optimal one
 //   Create the cost_list for the final graph, and choose the optimal one: one the minimum quantity
-//   COST_MODEL_ALPHA * memory_cost + COST_MODEL_BETA * communication_cost
+//   COST_MODEL_ALPHA * computation_cost + COST_MODEL_BETA * communication_cost
 // Phase 3: Recover the original CostGraph, the determine strategy for each operator
 //   After determining the optimal cost for the final graph, the algorithm recovers the original graph by applying
 //   the 4 operations in the reverse order in the Phase 1. Because each operation decision contains the strategy,
@@ -69,7 +69,7 @@ Status Edge::InitEdgeCost() {
     MS_LOG(EXCEPTION) << "Failure: redistribution cost calculation failed";
   }
   MS_EXCEPTION_IF_NULL(cost);
-  MS_LOG(DEBUG) << "The redistribution cost: memory_cost: " << cost->memory_cost_
+  MS_LOG(DEBUG) << "The redistribution cost: computation_cost: " << cost->computation_cost_
                 << ", communication_cost: " << cost->communication_cost_
                 << ", communication_without_parameter_: " << cost->communication_without_parameter_
                 << ", communication_with_partial_para_: " << cost->communication_with_partial_para_ << ".";
@@ -117,9 +117,9 @@ Status Edge::GetRedistributionCost(const TensorLayout& prev_op_output_layout, co
   double comm_cost = tensor_redistribution.comm_cost();
   double forward_comm_cost = tensor_redistribution.forward_comm_cost();
   double backward_comm_cost = tensor_redistribution.backward_comm_cost();
-  double mem_cost = tensor_redistribution.mem_cost();
-  *cost = std::make_shared<Cost>(type_length * mem_cost, type_length * comm_cost);
+  double computation_cost = tensor_redistribution.computation_cost();
+  *cost = std::make_shared<Cost>(type_length * computation_cost, type_length * comm_cost);
   (*cost)->communication_without_parameter_ = type_length * comm_cost;
   (*cost)->communication_with_partial_para_ =
     (*cost)->communication_without_parameter_ +
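The hunk above now feeds the redistribution's computation cost, rather than its memory cost, into the edge's Cost, with both components scaled by the element byte length. A rough sketch of that assembly follows; it is not taken from the patch, MakeEdgeCost and the stand-in struct are illustrative, and type_length is the byte width of one element (e.g. 4 for float32).

// Hedged sketch of how an edge cost is assembled from a tensor redistribution.
// The redistribution reports its costs in elements; they are converted to bytes here.
#include <memory>

struct EdgeCostSketch {
  double computation_cost_;
  double communication_cost_;
  double communication_without_parameter_ = 0.0;
  EdgeCostSketch(double computation, double communication)
      : computation_cost_(computation), communication_cost_(communication) {}
};

std::shared_ptr<EdgeCostSketch> MakeEdgeCost(double computation_cost, double comm_cost, double type_length) {
  auto cost = std::make_shared<EdgeCostSketch>(type_length * computation_cost, type_length * comm_cost);
  // Without the backward communication of parameters, the communication is unchanged here;
  // the partial-parameter variant is derived from it afterwards, as in the other hunks.
  cost->communication_without_parameter_ = type_length * comm_cost;
  return cost;
}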
@@ -150,26 +150,26 @@ CostPtrList Edge::CreateEdgeEliminationCostList(const StrategyPtr& output_st_ptr
   (void)std::transform(edges.begin(), edges.end(), all_cost_list.begin(), LocalGetCostList);
   CostPtrList selected_cost_list(all_cost_list.size(), nullptr);
-  std::function<void(size_t, double, double, double)> recursive = [&](size_t k, double memory, double communication,
-                                                                      double communication_without_para) {
-    if (k == edges.size()) {
-      auto decision = std::make_shared<EdgeEliminationDecision>(selected_cost_list);
-      CostPtr new_cost = std::make_shared<Cost>(memory, communication);
-      MS_EXCEPTION_IF_NULL(new_cost);
-      new_cost->communication_without_parameter_ = communication_without_para;
-      new_cost->communication_with_partial_para_ =
-        communication_without_para + COST_MODEL_GAMMA * (communication - communication_without_para);
-      new_cost->decision_ptr_ = decision;
-      result.push_back(new_cost);
-      return;
-    }
-    for (auto& c : all_cost_list[k]) {
-      MS_EXCEPTION_IF_NULL(c);
-      selected_cost_list[k] = c;
-      recursive(k + 1, memory + c->memory_cost_, communication + c->communication_cost_,
-                communication_without_para + c->communication_without_parameter_);
-    }
-  };
+  std::function<void(size_t, double, double, double)> recursive =
+    [&](size_t k, double computation, double communication, double communication_without_para) {
+      if (k == edges.size()) {
+        auto decision = std::make_shared<EdgeEliminationDecision>(selected_cost_list);
+        CostPtr new_cost = std::make_shared<Cost>(computation, communication);
+        MS_EXCEPTION_IF_NULL(new_cost);
+        new_cost->communication_without_parameter_ = communication_without_para;
+        new_cost->communication_with_partial_para_ =
+          communication_without_para + COST_MODEL_GAMMA * (communication - communication_without_para);
+        new_cost->decision_ptr_ = decision;
+        result.push_back(new_cost);
+        return;
+      }
+      for (auto& c : all_cost_list[k]) {
+        MS_EXCEPTION_IF_NULL(c);
+        selected_cost_list[k] = c;
+        recursive(k + 1, computation + c->computation_cost_, communication + c->communication_cost_,
+                  communication_without_para + c->communication_without_parameter_);
+      }
+    };
   recursive(0, 0, 0, 0);
   SimplifyForDreasingCommunicationWithPartialPara(&result);
   return result;
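The reindented lambda above enumerates one candidate cost per parallel edge and accumulates the running totals along the way. A simplified, self-contained sketch of that Cartesian-product enumeration follows; the types and names are stand-ins, it tracks only computation and communication, and the real code also carries communication_without_parameter_ and prunes the result with SimplifyForDreasingCommunicationWithPartialPara afterwards.

// Illustrative sketch of the enumeration pattern used by CreateEdgeEliminationCostList:
// pick one candidate cost per parallel edge, accumulate the totals, and emit one
// combined candidate per combination.
#include <functional>
#include <vector>

struct CandidateCost {
  double computation_cost_;
  double communication_cost_;
};

std::vector<CandidateCost> CombinePerEdgeCandidates(
    const std::vector<std::vector<CandidateCost>>& all_cost_list) {
  std::vector<CandidateCost> result;
  std::function<void(size_t, double, double)> recursive =
      [&](size_t k, double computation, double communication) {
        if (k == all_cost_list.size()) {
          // One full combination chosen: record its accumulated totals.
          result.push_back({computation, communication});
          return;
        }
        for (const auto& c : all_cost_list[k]) {
          recursive(k + 1, computation + c.computation_cost_, communication + c.communication_cost_);
        }
      };
  recursive(0, 0.0, 0.0);
  return result;
}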
@@ -203,7 +203,8 @@ void Edge::CreateOpEliminationSubCostList(StrategyPtr op_strategy, const CostPtr
     MS_EXCEPTION_IF_NULL(middle_cost);
     for (auto& right_cost : right_cost_list) {
       MS_EXCEPTION_IF_NULL(right_cost);
-      double memory = left_cost->memory_cost_ + middle_cost->memory_cost_ + right_cost->memory_cost_;
+      double computation =
+        left_cost->computation_cost_ + middle_cost->computation_cost_ + right_cost->computation_cost_;
       double communication =
         left_cost->communication_cost_ + middle_cost->communication_cost_ + right_cost->communication_cost_;
       double communication_without_para = left_cost->communication_without_parameter_ +
@@ -211,7 +212,7 @@ void Edge::CreateOpEliminationSubCostList(StrategyPtr op_strategy, const CostPtr
                                           right_cost->communication_without_parameter_;
       auto decision = std::make_shared<OpEliminationDecision>(op_strategy, left_cost, middle_cost, right_cost);
-      auto cost = std::make_shared<Cost>(memory, communication, decision);
+      auto cost = std::make_shared<Cost>(computation, communication, decision);
       MS_EXCEPTION_IF_NULL(cost);
       cost->communication_without_parameter_ = communication_without_para;
       cost->communication_with_partial_para_ =
@@ -133,7 +133,7 @@ class Edge {
   void set_parameter_involve(int para_invol) { is_output_parameter_involve_ = para_invol; }
   // When the input of a operator contains WEIGHT or a output from other operators involving WEIGHT, then these input
   // should stay in memory until it is used in the backward phase, which is kept in memory at the end of forward phase.
-  Status CorrectStrategyCostForMemoryReuse() const { return SUCCESS; }
+  Status CalculateMemoryCost() const { return SUCCESS; }
 private:
   std::string edge_name_;
@@ -247,7 +247,7 @@ CostPtrList CostGraph::CreateFinalCostList(const OperatorInfoPtr& u, const std::
       MS_EXCEPTION_IF_NULL(cost1);
       MS_EXCEPTION_IF_NULL(cost2);
       MS_EXCEPTION_IF_NULL(cost3);
-      double memory = cost1->memory_cost_ + cost2->memory_cost_ + cost3->memory_cost_;
+      double computation = cost1->computation_cost_ + cost2->computation_cost_ + cost3->computation_cost_;
       double commmunication =
         cost1->communication_cost_ + cost2->communication_cost_ + cost3->communication_cost_;
       double communication_without_para = cost1->communication_without_parameter_ +
@@ -255,7 +255,7 @@ CostPtrList CostGraph::CreateFinalCostList(const OperatorInfoPtr& u, const std::
                                           cost3->communication_without_parameter_;
       auto decision =
         std::make_shared<FinalDecision>(u_strategy->strategy_ptr, v_strategy->strategy_ptr, cost1, cost2, cost3);
-      auto cost = std::make_shared<Cost>(memory, commmunication, decision);
+      auto cost = std::make_shared<Cost>(computation, commmunication, decision);
       MS_EXCEPTION_IF_NULL(cost);
       cost->communication_without_parameter_ = communication_without_para;
       cost->communication_with_partial_para_ =
@@ -282,7 +282,7 @@ CostPtrList CostGraph::CreateFinalSingleCostList(const OperatorInfoPtr& u) {
     for (const auto& cost1 : clist1) {
       MS_EXCEPTION_IF_NULL(cost1);
       auto decision = std::make_shared<FinalSingleDecision>(u_strategy_ptr, cost1);
-      auto new_cost = std::make_shared<Cost>(cost1->memory_cost_, cost1->communication_cost_, decision);
+      auto new_cost = std::make_shared<Cost>(cost1->computation_cost_, cost1->communication_cost_, decision);
       MS_EXCEPTION_IF_NULL(new_cost);
       new_cost->communication_without_parameter_ = cost1->communication_without_parameter_;
       new_cost->communication_with_partial_para_ =
@@ -297,12 +297,12 @@ CostPtrList CostGraph::CreateFinalSingleCostList(const OperatorInfoPtr& u) {
 }
 CostPtr CostGraph::SelectCostWithMemoryConstraint(const CostPtrList& cost_list, double memory) {
-  if (cost_list.empty() || cost_list[0]->memory_cost_ >= memory) {
+  if (cost_list.empty() || cost_list[0]->computation_cost_ >= memory) {
     return nullptr;
   }
   std::function<CostPtr(CostPtr, const CostPtr&)> LocalCompare = [&](CostPtr init, const CostPtr& cost_x) {
     MS_EXCEPTION_IF_NULL(cost_x);
-    if (init == nullptr || cost_x->memory_cost_ < memory) {
+    if (init == nullptr || cost_x->computation_cost_ < memory) {
       init = cost_x;
     }
     return init;
@@ -313,36 +313,36 @@ CostPtr CostGraph::SelectCostWithMemoryConstraint(const CostPtrList& cost_list,
 CostPtr CostGraph::SelectCostWithMinTrainingTime(const CostPtrList& cost_list, double memory) {
   // Select the cost with minimum training time. Currently, the training time is modeled as =
-  // costmodel_alpha_ * memory_cost + costmodel_beta_ * communication_with_partial_para_
+  // costmodel_alpha_ * computation_cost + costmodel_beta_ * communication_with_partial_para_
   if (cost_list.empty()) {
     MS_LOG(ERROR) << "Final cost list is null.";
     return nullptr;
   }
   CostPtr ret = cost_list[0];
   MS_EXCEPTION_IF_NULL(ret);
-  if (ret->memory_cost_ >= memory) {
-    MS_LOG(ERROR) << "No available cost; the minimum cost is " << ret->memory_cost_
+  if (ret->computation_cost_ >= memory) {
+    MS_LOG(ERROR) << "No available cost; the minimum cost is " << ret->computation_cost_
                   << ", the memory capacity is: " << memory << ".";
     return nullptr;
   }
-  double minimum = costmodel_alpha_ * ret->memory_cost_ + costmodel_beta_ * ret->communication_with_partial_para_;
-  MS_LOG(INFO) << "minimum: " << minimum << ", memory_cost_: " << ret->memory_cost_
+  double minimum = costmodel_alpha_ * ret->computation_cost_ + costmodel_beta_ * ret->communication_with_partial_para_;
+  MS_LOG(INFO) << "minimum: " << minimum << ", computation_cost_: " << ret->computation_cost_
               << ", communication_with_partial_para_: " << ret->communication_with_partial_para_
              << ", communication_cost_: " << ret->communication_cost_
              << ", communication_without_parameter_: " << ret->communication_without_parameter_ << ".";
   for (size_t i = 1; i < cost_list.size(); ++i) {
     MS_EXCEPTION_IF_NULL(cost_list[i]);
-    if (cost_list[i]->memory_cost_ >= memory) {
-      MS_LOG(INFO) << "cost_list " << i << " memory_cost_: " << cost_list[i]->memory_cost_
+    if (cost_list[i]->computation_cost_ >= memory) {
+      MS_LOG(INFO) << "cost_list " << i << " computation_cost_: " << cost_list[i]->computation_cost_
                    << ", is larger than the memory capacity: " << memory << ".";
       break;
     }
-    MS_LOG(INFO) << "cost_list " << i << " memory_cost_: " << cost_list[i]->memory_cost_
+    MS_LOG(INFO) << "cost_list " << i << " computation_cost_: " << cost_list[i]->computation_cost_
                 << ", communication_with_partial_para_: " << cost_list[i]->communication_with_partial_para_
                 << ", communication_cost_: " << cost_list[i]->communication_cost_
                 << ", communication_without_parameter_: " << cost_list[i]->communication_without_parameter_ << ".";
-    auto tmp =
-      costmodel_alpha_ * cost_list[i]->memory_cost_ + costmodel_beta_ * cost_list[i]->communication_with_partial_para_;
+    auto tmp = costmodel_alpha_ * cost_list[i]->computation_cost_ +
               costmodel_beta_ * cost_list[i]->communication_with_partial_para_;
     MS_LOG(INFO) << "tmp: " << tmp;
     if (minimum > tmp) {
       minimum = tmp;
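With the rename, SelectCostWithMinTrainingTime keeps the same selection rule: among candidates whose tracked cost fits the given capacity, minimize costmodel_alpha_ * computation_cost_ + costmodel_beta_ * communication_with_partial_para_. A hedged sketch of that rule is below, using stand-in names; it scans the whole list, whereas the real code can break early because the list is already sorted, and the alpha/beta weights come from the cost-model configuration.

// Illustrative sketch of the min-training-time selection rule after the rename.
#include <vector>

struct TrainCost {
  double computation_cost_;
  double communication_with_partial_para_;
};

int SelectMinTrainingTime(const std::vector<TrainCost>& cost_list, double capacity,
                          double alpha, double beta) {
  int best = -1;
  double minimum = 0.0;
  for (size_t i = 0; i < cost_list.size(); ++i) {
    // Mirror the guard in the diff: a candidate whose tracked cost exceeds the capacity is skipped.
    if (cost_list[i].computation_cost_ >= capacity) {
      continue;
    }
    double tmp = alpha * cost_list[i].computation_cost_ +
                 beta * cost_list[i].communication_with_partial_para_;
    if (best < 0 || tmp < minimum) {
      minimum = tmp;
      best = static_cast<int>(i);
    }
  }
  return best;  // index of the chosen cost, or -1 if none fits
}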
@@ -363,8 +363,8 @@ CostPtrList CostGraph::SelectCostListWithMinTrainingTimeMultiple(const std::vect
       MS_LOG(ERROR) << "The cost list " << i << " is empty.";
       return ret;
     } else {
-      total_memory += all_cost_list[i][0]->memory_cost_;
-      minimum += costmodel_alpha_ * all_cost_list[i][0]->memory_cost_ +
+      total_memory += all_cost_list[i][0]->computation_cost_;
+      minimum += costmodel_alpha_ * all_cost_list[i][0]->computation_cost_ +
                 costmodel_beta_ * all_cost_list[i][0]->communication_with_partial_para_;
       ret[i] = all_cost_list[i][0];
     }
@@ -381,8 +381,8 @@ CostPtrList CostGraph::SelectCostListWithMinTrainingTimeMultiple(const std::vect
     double tmp_memory = 0.0, tmp_minimum = 0.0;
     for (size_t i = 0; i < selected_cost_list.size(); ++i) {
       MS_EXCEPTION_IF_NULL(selected_cost_list[i]);
-      tmp_memory += selected_cost_list[i]->memory_cost_;
-      tmp_minimum += costmodel_alpha_ * selected_cost_list[i]->memory_cost_ +
+      tmp_memory += selected_cost_list[i]->computation_cost_;
+      tmp_minimum += costmodel_alpha_ * selected_cost_list[i]->computation_cost_ +
                     costmodel_beta_ * selected_cost_list[i]->communication_with_partial_para_;
     }
     MS_LOG(INFO) << "tmp_memory: " << tmp_memory << ", tmp_minimum: " << tmp_minimum << ", minimum: " << minimum
@@ -394,6 +394,7 @@ CostPtrList CostGraph::SelectCostListWithMinTrainingTimeMultiple(const std::vect
       }
       return;
     }
     MS_LOG(DEBUG) << "The value minimum: " << minimum << ", available_memory: " << available_memory << ".";
     for (auto& c : all_cost_list[k]) {
       selected_cost_list[k] = c;
@@ -814,7 +815,7 @@ void CostGraph::CreateMergeEliminationSubCostList(StrategyPtr op_strategy, const
   for (size_t k = 0; k < tar_cost_list.size(); ++k) {
     auto& tar_cost = tar_cost_list[k];
     MS_EXCEPTION_IF_NULL(tar_cost);
-    double memory = op_cost->memory_cost_ + edge_cost->memory_cost_ + tar_cost->memory_cost_;
+    double computation = op_cost->computation_cost_ + edge_cost->computation_cost_ + tar_cost->computation_cost_;
     double communication =
       op_cost->communication_cost_ + edge_cost->communication_cost_ + tar_cost->communication_cost_;
     double communication_without_para = op_cost->communication_without_parameter_ +
@@ -823,7 +824,7 @@ void CostGraph::CreateMergeEliminationSubCostList(StrategyPtr op_strategy, const
     auto decision =
       std::make_shared<MergeEliminationDecision>(op_strategy, op_cost, edge_cost, tar_op_strategy, tar_cost);
-    auto new_cost = std::make_shared<Cost>(memory, communication, decision);
+    auto new_cost = std::make_shared<Cost>(computation, communication, decision);
     MS_EXCEPTION_IF_NULL(new_cost);
     new_cost->communication_without_parameter_ = communication_without_para;
     new_cost->communication_with_partial_para_ =
@@ -891,7 +892,8 @@ void CostGraph::CreateContractEliminationSubCostList(StrategyPtr contract_op_str
   for (size_t k = 0; k < tar_cost_list.size(); ++k) {
     auto& tar_cost = tar_cost_list[k];
     MS_EXCEPTION_IF_NULL(tar_cost);
-    double memory = contract_op_cost->memory_cost_ + edge_cost->memory_cost_ + tar_cost->memory_cost_;
+    double computation =
+      contract_op_cost->computation_cost_ + edge_cost->computation_cost_ + tar_cost->computation_cost_;
     double communication =
       contract_op_cost->communication_cost_ + edge_cost->communication_cost_ + tar_cost->communication_cost_;
     double communication_without_para = contract_op_cost->communication_without_parameter_ +
@@ -900,7 +902,7 @@ void CostGraph::CreateContractEliminationSubCostList(StrategyPtr contract_op_str
     auto decision = std::make_shared<ContractEliminationDecision>(contract_op_stra, contract_op_cost, edge_cost,
                                                                   target_op_stra, tar_cost);
-    auto new_cost = std::make_shared<Cost>(memory, communication, decision);
+    auto new_cost = std::make_shared<Cost>(computation, communication, decision);
     new_cost->communication_without_parameter_ = communication_without_para;
     new_cost->communication_with_partial_para_ =
       communication_without_para + COST_MODEL_GAMMA * (communication - communication_without_para);
@@ -963,9 +965,9 @@ void CostGraph::CreateTriangleEliminationSubCostList(StrategyPtr elimi_op_stra,
     MS_EXCEPTION_IF_NULL(left_edge_cost);
     for (auto& left_node_cost : left_node_clist_origin) {
       MS_EXCEPTION_IF_NULL(left_node_cost);
-      double new_memory_cost = elimi_op_cost->memory_cost_ + left_edge_cost->memory_cost_ +
-                               left_node_cost->memory_cost_ + right_edge_cost->memory_cost_ +
-                               right_op_cost->memory_cost_;
+      double new_computation = elimi_op_cost->computation_cost_ + left_edge_cost->computation_cost_ +
+                               left_node_cost->computation_cost_ + right_edge_cost->computation_cost_ +
+                               right_op_cost->computation_cost_;
       double new_commu_cost = elimi_op_cost->communication_cost_ + left_edge_cost->communication_cost_ +
                               left_node_cost->communication_cost_ + right_edge_cost->communication_cost_ +
                               right_op_cost->communication_cost_;
@@ -977,7 +979,7 @@ void CostGraph::CreateTriangleEliminationSubCostList(StrategyPtr elimi_op_stra,
       auto decision =
         std::make_shared<TriangleEliminationDecision>(elimi_op_stra, elimi_op_cost, left_edge_cost, right_edge_cost,
                                                       left_op_stra, left_node_cost, right_op_stra, right_op_cost);
-      auto new_cost = std::make_shared<Cost>(new_memory_cost, new_commu_cost, decision);
+      auto new_cost = std::make_shared<Cost>(new_computation, new_commu_cost, decision);
       new_cost->communication_without_parameter_ = new_commu_without;
       new_cost->communication_with_partial_para_ =
         new_commu_without + COST_MODEL_GAMMA * (new_commu_cost - new_commu_without);
@@ -1082,11 +1084,12 @@ void CostGraph::CreateStarEliminationSubCostList(const StrategyPtr& first_succ_n
   succ_edges_costs[0] = first_succ_edge_cost;
   succ_nodes_costs[0] = first_succ_node_cost;
-  double memory_cost = merged_node_cost->memory_cost_, commu_cost = merged_node_cost->communication_cost_,
+  double computation_cost = merged_node_cost->computation_cost_,
+         commu_cost = merged_node_cost->communication_cost_,
          commu_without = merged_node_cost->communication_without_parameter_;
   for (size_t i = 0; i < succ_nodes_stras.size(); ++i) {
     MS_EXCEPTION_IF_NULL(succ_edges_costs[i]);
-    memory_cost += succ_edges_costs[i]->memory_cost_ + succ_nodes_costs[i]->memory_cost_;
+    computation_cost += succ_edges_costs[i]->computation_cost_ + succ_nodes_costs[i]->computation_cost_;
     commu_cost += succ_edges_costs[i]->communication_cost_ + succ_nodes_costs[i]->communication_cost_;
     commu_without += succ_edges_costs[i]->communication_without_parameter_ +
                      succ_nodes_costs[i]->communication_without_parameter_;
@@ -1094,7 +1097,7 @@ void CostGraph::CreateStarEliminationSubCostList(const StrategyPtr& first_succ_n
   auto decision = std::make_shared<StarEliminationDecision>(merged_op_stra, merged_node_cost, succ_edges_costs,
                                                             succ_nodes_stras, succ_nodes_costs);
-  auto new_cost = std::make_shared<Cost>(memory_cost, commu_cost, decision);
+  auto new_cost = std::make_shared<Cost>(computation_cost, commu_cost, decision);
   new_cost->communication_without_parameter_ = commu_without;
   new_cost->communication_with_partial_para_ = commu_without + COST_MODEL_GAMMA * (commu_cost - commu_without);
   first_succ_node_clist_new->emplace_back(std::move(new_cost));
@@ -1210,36 +1213,6 @@ Status CostGraph::InitSelectedStrategy() {
   return SUCCESS;
 }
-Status CostGraph::CorrectOpsStrategyCostForMultiOutputUse() {
-  for (auto& op : ops_) {
-    MS_EXCEPTION_IF_NULL(op);
-    if (op->GetAliveSuccEdges().size() > 1) {
-      // Filter out the case of a output being used by multiple operators
-      std::map<size_t, int> output_count;
-      for (size_t i = 0; i < op->GetAliveSuccEdges().size(); ++i) {
-        auto output_index = op->GetAliveSuccEdges()[i]->prev_op_output_index();
-        output_count[output_index]++;
-      }
-      for (size_t i = 0; i < op->GetAliveSuccEdges().size(); ++i) {
-        auto output_index = op->GetAliveSuccEdges()[i]->prev_op_output_index();
-        if (output_count[output_index] <= 1) {
-          continue;
-        }
-        auto next_op = op->GetAliveSuccEdges()[i]->next_operator();
-        MS_EXCEPTION_IF_NULL(next_op);
-        auto input_index = op->GetAliveSuccEdges()[i]->next_op_input_index();
-        if (next_op->CorrectStrategyCostForMultiOutputUse(input_index) != SUCCESS) {
-          MS_LOG(ERROR) << "The operator name: " << op->name() << ", the next operator name: " << next_op->name()
-                        << ", the output_index: " << output_index << ", the input_index: " << input_index << ".";
-          return FAILED;
-        }
-        output_count[output_index]--;
-      }
-    }
-  }
-  return SUCCESS;
-}
 Status CostGraph::ComputeOpsAndEdgesParameterInvolved() {
   for (auto& op : ops_) {
     MS_EXCEPTION_IF_NULL(op);
@@ -1252,23 +1225,23 @@ Status CostGraph::ComputeOpsAndEdgesParameterInvolved() {
   return SUCCESS;
 }
-Status CostGraph::CorrectOpsStrategyCostForMemoryReuse() {
+Status CostGraph::CalculateOpsMemoryCost() {
   for (auto& op : ops_) {
     MS_EXCEPTION_IF_NULL(op);
-    if (op->CorrectStrategyCostForMemoryReuse() != SUCCESS) {
-      MS_LOG(ERROR) << "Correcting Operator: " << op->name() << " cost for memory reuse failed.";
+    if (op->CalculateMemoryCost() != SUCCESS) {
+      MS_LOG(ERROR) << "Calculate Operator: " << op->name() << " cost for memory usage failed.";
       return FAILED;
     }
   }
   return SUCCESS;
 }
-Status CostGraph::CorrectEdgesStrategyCostForMemoryReuse() {
+Status CostGraph::CalculateEdgesMemoryCost() {
   for (auto& edge_pair : edges_) {
     const auto& edges = edge_pair.second;
     for (auto& one_edge : edges) {
-      if (one_edge->CorrectStrategyCostForMemoryReuse() != SUCCESS) {
-        MS_LOG(ERROR) << "Correcting Edge: " << one_edge->edge_name() << " cost for memory reuse failed.";
+      if (one_edge->CalculateMemoryCost() != SUCCESS) {
+        MS_LOG(ERROR) << "Calculate Edge: " << one_edge->edge_name() << " cost for memory usage failed.";
        return FAILED;
      }
    }
@@ -175,16 +175,12 @@ class CostGraph {
   void CreateStarEliminationSubCostList(const StrategyPtr&, const CostPtrList&, const CostPtrList&, const StrategyPtr&,
                                         const CostPtrList&, std::vector<StrategyPtr>, CostPtrList&, CostPtrList&,
                                         CostPtrList*);
-  // When a output of a operator is being used by multiple operators, the memory cost of this part should be calculated
-  // only once. This method is for correcting the 'strategy_cost_' for operators
-  Status CorrectOpsStrategyCostForMultiOutputUse();
   // When the input of a operator is neither a WEIGHT, nor a output of a subsequent operator involving WEIGHT, then
   // the memory cost can be resused.
-  Status CorrectOpsStrategyCostForMemoryReuse();
+  Status CalculateOpsMemoryCost();
   // When the input of the edge is neither a WEIGHT, nor a output of a subsequent operator involving WEIGHT, then
   // the memory cost can be resused.
-  Status CorrectEdgesStrategyCostForMemoryReuse();
+  Status CalculateEdgesMemoryCost();
   Status ComputeOpsAndEdgesParameterInvolved();
   std::vector<OperatorInfoPtr> GetOperators() const { return ops_; }
@@ -74,8 +74,8 @@ double MatMulCost::GetBackwardCommCost(const std::vector<TensorInfo>& inputs, co
 // Return the per device memory cost in the forward phase. The cost is calculated according to the bytes
 // this operator uses
-double MatMulCost::GetForwardMemoryCost(const std::vector<TensorInfo>& inputs, const std::vector<TensorInfo>& outputs,
-                                        const int32_t&) const {
+double MatMulCost::GetForwardComputationCost(const std::vector<TensorInfo>& inputs,
+                                             const std::vector<TensorInfo>& outputs, const int32_t&) const {
   // In forward phase, the memory cost = slice(A) + slice(B) + (0 or 1) allreduce(slice(C))
   double result = 0.0;
   TensorInfo output0 = outputs[0];
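The operator cost functions renamed in this and the following hunks all reduce to the same building block: the number of elements in a slice (ListProduct of the slice shape) times the byte length of one element, with an extra allreduce term when a parameter is involved. A small illustrative sketch of that building block, not taken from the patch and using stand-in names, is:

// Hedged sketch of the recurring computation in the operator cost functions:
// bytes = product of the slice shape's dimensions * byte length of one element.
#include <cstdint>
#include <vector>

using ShapeSketch = std::vector<int64_t>;

double ListProductSketch(const ShapeSketch& shape) {
  double product = 1.0;
  for (int64_t dim : shape) {
    product *= static_cast<double>(dim);
  }
  return product;
}

double SliceBytes(const ShapeSketch& slice_shape, double type_length) {
  // e.g. a [64, 128] float32 slice -> 64 * 128 * 4 bytes
  return ListProductSketch(slice_shape) * type_length;
}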
@@ -93,8 +93,8 @@ double MatMulCost::GetForwardMemoryCost(const std::vector<TensorInfo>& inputs, c
 // Return the per device memory cost in the forward phase. The cost is calculated according to the bytes
 // this operator uses
-double MatMulCost::GetBackwardMemoryCost(const std::vector<TensorInfo>& inputs, const std::vector<TensorInfo>&,
-                                         const int32_t& stage_id) const {
+double MatMulCost::GetBackwardComputationCost(const std::vector<TensorInfo>& inputs, const std::vector<TensorInfo>&,
+                                              const int32_t& stage_id) const {
   // In backward phase, the memory cost = (0 or 1) allreduce(slice(B))
   double result = 0.0;
   if (is_parameter_[1]) {
@@ -147,8 +147,8 @@ double ActivationCost::GetBackwardCommCost(const std::vector<TensorInfo>& inputs
 // Return the per memory cost in the forward phase. The cost is calculated according to the bytes
 // this operator uses
-double ActivationCost::GetForwardMemoryCost(const std::vector<TensorInfo>& inputs, const std::vector<TensorInfo>&,
-                                            const int32_t&) const {
+double ActivationCost::GetForwardComputationCost(const std::vector<TensorInfo>& inputs, const std::vector<TensorInfo>&,
+                                                 const int32_t&) const {
   TensorInfo input0_info = inputs[0];
   Shape input0_slice_shape = input0_info.slice_shape();
   return ListProduct(input0_slice_shape) * static_cast<double>(inputs_type_lengths_[0]);
@@ -156,8 +156,8 @@ double ActivationCost::GetForwardMemoryCost(const std::vector<TensorInfo>& input
 // Return the per memory cost in the forward phase. The cost is calculated according to the bytes
 // this operator uses
-double ActivationCost::GetBackwardMemoryCost(const std::vector<TensorInfo>&, const std::vector<TensorInfo>&,
-                                             const int32_t&) const {
+double ActivationCost::GetBackwardComputationCost(const std::vector<TensorInfo>&, const std::vector<TensorInfo>&,
+                                                  const int32_t&) const {
   return 0.0;
 }
@@ -191,8 +191,8 @@ double SoftmaxCost::GetBackwardCommCost(const std::vector<TensorInfo>& inputs, c
 // Return the per memory cost in the forward phase. The cost is calculated according to the bytes
 // this operator uses
-double SoftmaxCost::GetForwardMemoryCost(const std::vector<TensorInfo>& inputs, const std::vector<TensorInfo>&,
-                                         const int32_t&) const {
+double SoftmaxCost::GetForwardComputationCost(const std::vector<TensorInfo>& inputs, const std::vector<TensorInfo>&,
+                                              const int32_t&) const {
   // In the forward phase, the memory cost = slice(A)
   TensorInfo input0 = inputs[0];
   Shape input0_slice_shape = input0.slice_shape();
@@ -201,8 +201,9 @@ double SoftmaxCost::GetForwardMemoryCost(const std::vector<TensorInfo>& inputs,
 // Return the per memory cost in the forward phase. The cost is calculated according to the bytes
 // this operator uses
-double SoftmaxCost::GetBackwardMemoryCost(const std::vector<mindspore::parallel::TensorInfo>&,
-                                          const std::vector<mindspore::parallel::TensorInfo>&, const int32_t&) const {
+double SoftmaxCost::GetBackwardComputationCost(const std::vector<mindspore::parallel::TensorInfo>&,
+                                               const std::vector<mindspore::parallel::TensorInfo>&,
+                                               const int32_t&) const {
   return 0.0;
 }
@@ -222,9 +223,9 @@ double TmpIdentityCost::GetBackwardCommCost(const std::vector<mindspore::paralle
 // Return the per memory cost in the forward phase. The cost is calculated according to the bytes
 // this operator uses
-double TmpIdentityCost::GetForwardMemoryCost(const std::vector<mindspore::parallel::TensorInfo>& inputs,
-                                             const std::vector<mindspore::parallel::TensorInfo>&,
-                                             const int32_t&) const {
+double TmpIdentityCost::GetForwardComputationCost(const std::vector<mindspore::parallel::TensorInfo>& inputs,
+                                                  const std::vector<mindspore::parallel::TensorInfo>&,
+                                                  const int32_t&) const {
   TensorInfo input0_info = inputs[0];
   Shape input0_slice_shape = input0_info.slice_shape();
   return ListProduct(input0_slice_shape) * static_cast<double>(inputs_type_lengths_[0]);
@@ -232,15 +233,15 @@ double TmpIdentityCost::GetForwardMemoryCost(const std::vector<mindspore::parall
 // Return the per memory cost in the backward phase. The cost is calculated according to the bytes
 // this operator uses
-double TmpIdentityCost::GetBackwardMemoryCost(const std::vector<mindspore::parallel::TensorInfo>&,
-                                              const std::vector<mindspore::parallel::TensorInfo>&,
-                                              const int32_t&) const {
+double TmpIdentityCost::GetBackwardComputationCost(const std::vector<mindspore::parallel::TensorInfo>&,
+                                                   const std::vector<mindspore::parallel::TensorInfo>&,
+                                                   const int32_t&) const {
   return 0.0;
 }
-double BatchParallelCost::GetForwardMemoryCost(const std::vector<mindspore::parallel::TensorInfo>& inputs,
-                                               const std::vector<mindspore::parallel::TensorInfo>&,
-                                               const int32_t&) const {
+double BatchParallelCost::GetForwardComputationCost(const std::vector<mindspore::parallel::TensorInfo>& inputs,
+                                                    const std::vector<mindspore::parallel::TensorInfo>&,
+                                                    const int32_t&) const {
   double cost = 0.0;
   for (size_t i = 0; i < inputs.size(); ++i) {
     cost += ListProduct(inputs[i].slice_shape()) * static_cast<double>(inputs_type_lengths_[i]);
@@ -248,9 +249,9 @@ double BatchParallelCost::GetForwardMemoryCost(const std::vector<mindspore::para
   return cost;
 }
-double BatchParallelCost::GetBackwardMemoryCost(const std::vector<mindspore::parallel::TensorInfo>&,
-                                                const std::vector<mindspore::parallel::TensorInfo>&,
-                                                const int32_t&) const {
+double BatchParallelCost::GetBackwardComputationCost(const std::vector<mindspore::parallel::TensorInfo>&,
+                                                     const std::vector<mindspore::parallel::TensorInfo>&,
+                                                     const int32_t&) const {
   return 0.0;
 }
@@ -285,8 +286,8 @@ double PReLUCost::GetBackwardCommCost(const std::vector<TensorInfo>& inputs, con
 // Return the per memory cost in the forward phase. The cost is calculated according to the bytes
 // this operator uses
-double PReLUCost::GetForwardMemoryCost(const std::vector<TensorInfo>& inputs, const std::vector<TensorInfo>&,
-                                       const int32_t&) const {
+double PReLUCost::GetForwardComputationCost(const std::vector<TensorInfo>& inputs, const std::vector<TensorInfo>&,
+                                            const int32_t&) const {
   // In forward phase, the memory cost = slice(A) + slice(B)
   Shape input0_slice_shape = inputs[0].slice_shape();
   Shape input1_slice_shape = inputs[1].slice_shape();
@@ -297,9 +298,9 @@ double PReLUCost::GetForwardMemoryCost(const std::vector<TensorInfo>& inputs, co
 // Return the per memory cost in the backward phase. The cost is calculated according to the bytes
 // this operator uses
-double PReLUCost::GetBackwardMemoryCost(const std::vector<mindspore::parallel::TensorInfo>& inputs,
-                                        const std::vector<mindspore::parallel::TensorInfo>&,
-                                        const int32_t& stage_id) const {
+double PReLUCost::GetBackwardComputationCost(const std::vector<mindspore::parallel::TensorInfo>& inputs,
+                                             const std::vector<mindspore::parallel::TensorInfo>&,
+                                             const int32_t& stage_id) const {
   // In backward phase, the memory cost = (0 or 1) allreduce(slice(B))
   double result = 0.0;
   if (is_parameter_[1]) {
@@ -338,8 +339,8 @@ double OneHotCost::GetBackwardCommCost(const std::vector<TensorInfo>&, const std
 // Return the per memory cost in the forward phase. The cost is calculated according to the bytes
 // this operator uses
-double OneHotCost::GetForwardMemoryCost(const std::vector<TensorInfo>& inputs, const std::vector<TensorInfo>&,
-                                        const int32_t&) const {
+double OneHotCost::GetForwardComputationCost(const std::vector<TensorInfo>& inputs, const std::vector<TensorInfo>&,
+                                             const int32_t&) const {
   // In onehot's forward phase, the memory cost = slice(A)
   Shape input0_slice_shape = inputs[0].slice_shape();
   return ListProduct(input0_slice_shape) * static_cast<double>(inputs_type_lengths_[0]);
@@ -347,8 +348,8 @@ double OneHotCost::GetForwardMemoryCost(const std::vector<TensorInfo>& inputs, c
 // Return the per memory cost in the backward phase. The cost is calculated according to the bytes
 // this operator uses
-double OneHotCost::GetBackwardMemoryCost(const std::vector<TensorInfo>&, const std::vector<TensorInfo>&,
-                                         const int32_t&) const {
+double OneHotCost::GetBackwardComputationCost(const std::vector<TensorInfo>&, const std::vector<TensorInfo>&,
+                                              const int32_t&) const {
   return 0.0;
 }
@@ -368,8 +369,9 @@ double SoftmaxCrossEntropyWithLogitsCost::GetBackwardCommCost(const std::vector<
 // Return the per memory cost in the forward phase. The cost is calculated according to the bytes
 // this operator uses
-double SoftmaxCrossEntropyWithLogitsCost::GetForwardMemoryCost(const std::vector<TensorInfo>& inputs,
-                                                               const std::vector<TensorInfo>&, const int32_t&) const {
+double SoftmaxCrossEntropyWithLogitsCost::GetForwardComputationCost(const std::vector<TensorInfo>& inputs,
+                                                                    const std::vector<TensorInfo>&,
+                                                                    const int32_t&) const {
   // In forward phase, the memory cost = slice(A) + slice(B)
   Shape input0_slice_shape = inputs[0].slice_shape();
   Shape input1_slice_shape = inputs[1].slice_shape();
@@ -380,8 +382,9 @@ double SoftmaxCrossEntropyWithLogitsCost::GetForwardMemoryCost(const std::vector
 // Return the per memory cost in the backward phase. The cost is calculated according to the bytes
 // this operator uses
-double SoftmaxCrossEntropyWithLogitsCost::GetBackwardMemoryCost(const std::vector<TensorInfo>&,
-                                                                const std::vector<TensorInfo>&, const int32_t&) const {
+double SoftmaxCrossEntropyWithLogitsCost::GetBackwardComputationCost(const std::vector<TensorInfo>&,
+                                                                     const std::vector<TensorInfo>&,
+                                                                     const int32_t&) const {
   return 0.0;
 }
| @@ -409,8 +412,8 @@ double ReshapeCost::GetBackwardCommCost(const std::vector<TensorInfo>&, const st | |||||
| // Return the per memory cost in the forward phase. The cost is calculated according to the bytes | // Return the per memory cost in the forward phase. The cost is calculated according to the bytes | ||||
| // this operator uses | // this operator uses | ||||
| double ReshapeCost::GetForwardMemoryCost(const std::vector<TensorInfo>& inputs, const std::vector<TensorInfo>& outputs, | |||||
| const int32_t& stage_id) const { | |||||
| double ReshapeCost::GetForwardComputationCost(const std::vector<TensorInfo>& inputs, | |||||
| const std::vector<TensorInfo>& outputs, const int32_t& stage_id) const { | |||||
| CheckGlobalDeviceManager(); | CheckGlobalDeviceManager(); | ||||
| MS_EXCEPTION_IF_NULL(g_device_manager); | MS_EXCEPTION_IF_NULL(g_device_manager); | ||||
| RankList dev_list = g_device_manager->GetDeviceListByStageId(stage_id); | RankList dev_list = g_device_manager->GetDeviceListByStageId(stage_id); | ||||
| @@ -421,26 +424,27 @@ double ReshapeCost::GetForwardMemoryCost(const std::vector<TensorInfo>& inputs, | |||||
| if (tensor_redistribution.ComputeCost() == FAILED) { | if (tensor_redistribution.ComputeCost() == FAILED) { | ||||
| MS_LOG(EXCEPTION) << "Failure: tensor_redistribution ComputeCost failed."; | MS_LOG(EXCEPTION) << "Failure: tensor_redistribution ComputeCost failed."; | ||||
| } | } | ||||
| return (inputs_type_lengths_[0] * tensor_redistribution.mem_cost()); | |||||
| return (inputs_type_lengths_[0] * tensor_redistribution.computation_cost()); | |||||
| } | } | ||||
| // Return the per memory cost in the backward phase. The cost is calculated according to the bytes | // Return the per memory cost in the backward phase. The cost is calculated according to the bytes | ||||
| // this operator uses | // this operator uses | ||||
| double ReshapeCost::GetBackwardMemoryCost(const std::vector<mindspore::parallel::TensorInfo>&, | |||||
| const std::vector<mindspore::parallel::TensorInfo>&, const int32_t&) const { | |||||
| double ReshapeCost::GetBackwardComputationCost(const std::vector<mindspore::parallel::TensorInfo>&, | |||||
| const std::vector<mindspore::parallel::TensorInfo>&, | |||||
| const int32_t&) const { | |||||
| return 0.0; | return 0.0; | ||||
| } | } | ||||
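Unlike the element-wise costs, ReshapeCost above delegates its forward computation cost to the tensor-redistribution machinery (see the TensorRedistribution changes further down in this diff). A hedged sketch of that flow, assuming a redistribution object that has already been initialized with the from/to layouts and device list:

    // Sketch only: mirrors ReshapeCost::GetForwardComputationCost above.
    double ReshapeForwardCostSketch(TensorRedistribution& redistribution, size_t input_type_len) {
      if (redistribution.ComputeCost() == FAILED) {
        return 0.0;  // the real code raises MS_LOG(EXCEPTION) here
      }
      // The "computation" of a Reshape is the redistribution work, scaled by the element size in bytes.
      return input_type_len * redistribution.computation_cost();
    }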
| double ArithmeticCost::GetForwardMemoryCost(const std::vector<TensorInfo>& inputs, const std::vector<TensorInfo>&, | |||||
| const int32_t&) const { | |||||
| double ArithmeticCost::GetForwardComputationCost(const std::vector<TensorInfo>& inputs, const std::vector<TensorInfo>&, | |||||
| const int32_t&) const { | |||||
| double result; | double result; | ||||
| result = ListProduct(inputs[0].slice_shape()) * static_cast<double>(inputs_type_lengths_[0]) + | result = ListProduct(inputs[0].slice_shape()) * static_cast<double>(inputs_type_lengths_[0]) + | ||||
| ListProduct(inputs[1].slice_shape()) * static_cast<double>(inputs_type_lengths_[1]); | ListProduct(inputs[1].slice_shape()) * static_cast<double>(inputs_type_lengths_[1]); | ||||
| return result; | return result; | ||||
| } | } | ||||
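As a quick worked example of the formula above (all numbers hypothetical, not taken from the diff): with both input slice shapes equal to {64, 128} and 4-byte element types,

    // result = ListProduct({64, 128}) * 4 + ListProduct({64, 128}) * 4
    //        = 64 * 128 * 4 + 64 * 128 * 4
    //        = 32768 + 32768 = 65536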
| double ArithmeticCost::GetBackwardMemoryCost(const std::vector<TensorInfo>& inputs, const std::vector<TensorInfo>&, | |||||
| const int32_t& stage_id) const { | |||||
| double ArithmeticCost::GetBackwardComputationCost(const std::vector<TensorInfo>& inputs, const std::vector<TensorInfo>&, | |||||
| const int32_t& stage_id) const { | |||||
| double result = 0.0; | double result = 0.0; | ||||
| CheckGlobalDeviceManager(); | CheckGlobalDeviceManager(); | ||||
| MS_EXCEPTION_IF_NULL(g_device_manager); | MS_EXCEPTION_IF_NULL(g_device_manager); | ||||
| @@ -533,15 +537,15 @@ double L2NormalizeCost::GetBackwardCommCost(const std::vector<TensorInfo>& input | |||||
| return result; | return result; | ||||
| } | } | ||||
| double L2NormalizeCost::GetForwardMemoryCost(const std::vector<TensorInfo>& inputs, const std::vector<TensorInfo>&, | |||||
| const int32_t&) const { | |||||
| double L2NormalizeCost::GetForwardComputationCost(const std::vector<TensorInfo>& inputs, const std::vector<TensorInfo>&, | |||||
| const int32_t&) const { | |||||
| TensorInfo input0_info = inputs[0]; | TensorInfo input0_info = inputs[0]; | ||||
| Shape input0_slice_shape = input0_info.slice_shape(); | Shape input0_slice_shape = input0_info.slice_shape(); | ||||
| return ListProduct(input0_slice_shape) * static_cast<double>(inputs_type_lengths_[0]); | return ListProduct(input0_slice_shape) * static_cast<double>(inputs_type_lengths_[0]); | ||||
| } | } | ||||
| double L2NormalizeCost::GetBackwardMemoryCost(const std::vector<TensorInfo>& inputs, const std::vector<TensorInfo>&, | |||||
| const int32_t& stage_id) const { | |||||
| double L2NormalizeCost::GetBackwardComputationCost(const std::vector<TensorInfo>& inputs, | |||||
| const std::vector<TensorInfo>&, const int32_t& stage_id) const { | |||||
| double result = 0.0; | double result = 0.0; | ||||
| if (is_parameter_[0]) { | if (is_parameter_[0]) { | ||||
| @@ -618,8 +622,9 @@ double ReduceMethodCost::GetBackwardCommCost(const std::vector<TensorInfo>& inpu | |||||
| return result; | return result; | ||||
| } | } | ||||
| double ReduceMethodCost::GetForwardMemoryCost(const std::vector<TensorInfo>& inputs, | |||||
| const std::vector<TensorInfo>& outputs, const int32_t& stage_id) const { | |||||
| double ReduceMethodCost::GetForwardComputationCost(const std::vector<TensorInfo>& inputs, | |||||
| const std::vector<TensorInfo>& outputs, | |||||
| const int32_t& stage_id) const { | |||||
| double result = 0.0; | double result = 0.0; | ||||
| TensorInfo input0 = inputs[0]; | TensorInfo input0 = inputs[0]; | ||||
| TensorInfo output0 = outputs[0]; | TensorInfo output0 = outputs[0]; | ||||
| @@ -640,8 +645,9 @@ double ReduceMethodCost::GetForwardMemoryCost(const std::vector<TensorInfo>& inp | |||||
| return result; | return result; | ||||
| } | } | ||||
| double ReduceMeanCost::GetForwardMemoryCost(const std::vector<TensorInfo>& inputs, | |||||
| const std::vector<TensorInfo>& outputs, const int32_t& stage_id) const { | |||||
| double ReduceMeanCost::GetForwardComputationCost(const std::vector<TensorInfo>& inputs, | |||||
| const std::vector<TensorInfo>& outputs, | |||||
| const int32_t& stage_id) const { | |||||
| double result = 0.0; | double result = 0.0; | ||||
| TensorInfo input0 = inputs[0]; | TensorInfo input0 = inputs[0]; | ||||
| TensorInfo output0 = outputs[0]; | TensorInfo output0 = outputs[0]; | ||||
| @@ -65,12 +65,12 @@ class OperatorCost { | |||||
| virtual double GetBackwardCommCost(const std::vector<TensorInfo>& inputs, const std::vector<TensorInfo>& outputs, | virtual double GetBackwardCommCost(const std::vector<TensorInfo>& inputs, const std::vector<TensorInfo>& outputs, | ||||
| const int32_t& stage_id) const = 0; | const int32_t& stage_id) const = 0; | ||||
| // per device computation cost | // per device computation cost | ||||
| virtual double GetMemoryCost(const std::vector<TensorInfo>& inputs, const std::vector<TensorInfo>& outputs, | |||||
| const int32_t& stage_id) const = 0; | |||||
| virtual double GetForwardMemoryCost(const std::vector<TensorInfo>& inputs, const std::vector<TensorInfo>& outputs, | |||||
| const int32_t& stage_id) const = 0; | |||||
| virtual double GetBackwardMemoryCost(const std::vector<TensorInfo>& inputs, const std::vector<TensorInfo>& outputs, | |||||
| const int32_t& stage_id) const = 0; | |||||
| virtual double GetComputationCost(const std::vector<TensorInfo>& inputs, const std::vector<TensorInfo>& outputs, | |||||
| const int32_t& stage_id) const = 0; | |||||
| virtual double GetForwardComputationCost(const std::vector<TensorInfo>& inputs, | |||||
| const std::vector<TensorInfo>& outputs, const int32_t& stage_id) const = 0; | |||||
| virtual double GetBackwardComputationCost(const std::vector<TensorInfo>& inputs, | |||||
| const std::vector<TensorInfo>& outputs, const int32_t& stage_id) const = 0; | |||||
| protected: | protected: | ||||
| // for each input in 'inputs_', there is a bool variable indicating whether the corresponding input is a parameter | // for each input in 'inputs_', there is a bool variable indicating whether the corresponding input is a parameter | ||||
| @@ -96,14 +96,14 @@ class MatMulCost : public OperatorCost { | |||||
| const int32_t& stage_id) const override; | const int32_t& stage_id) const override; | ||||
| // per device computation cost | // per device computation cost | ||||
| double GetMemoryCost(const std::vector<TensorInfo>& inputs, const std::vector<TensorInfo>& outputs, | |||||
| const int32_t& stage_id) const override { | |||||
| return GetForwardMemoryCost(inputs, outputs, stage_id) + GetBackwardMemoryCost(inputs, outputs, stage_id); | |||||
| } | |||||
| double GetForwardMemoryCost(const std::vector<TensorInfo>& inputs, const std::vector<TensorInfo>& outputs, | |||||
| const int32_t& stage_id) const override; | |||||
| double GetBackwardMemoryCost(const std::vector<TensorInfo>& inputs, const std::vector<TensorInfo>& outputs, | |||||
| const int32_t& stage_id) const override; | |||||
| double GetComputationCost(const std::vector<TensorInfo>& inputs, const std::vector<TensorInfo>& outputs, | |||||
| const int32_t& stage_id) const override { | |||||
| return GetForwardComputationCost(inputs, outputs, stage_id) + GetBackwardComputationCost(inputs, outputs, stage_id); | |||||
| } | |||||
| double GetForwardComputationCost(const std::vector<TensorInfo>& inputs, const std::vector<TensorInfo>& outputs, | |||||
| const int32_t& stage_id) const override; | |||||
| double GetBackwardComputationCost(const std::vector<TensorInfo>& inputs, const std::vector<TensorInfo>& outputs, | |||||
| const int32_t& stage_id) const override; | |||||
| }; | }; | ||||
| using MatMulCostPtr = std::shared_ptr<MatMulCost>; | using MatMulCostPtr = std::shared_ptr<MatMulCost>; | ||||
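The renamed interface keeps the same shape across every concrete class below: each one implements the forward and backward computation costs and defines GetComputationCost as their sum, exactly as MatMulCost does above. A hypothetical new cost class following that pattern might look like this sketch (names illustrative; it assumes OperatorCost, TensorInfo, ListProduct and the protected inputs_type_lengths_ member are in scope, and the comm-cost overrides are omitted for brevity):

    // Hypothetical sketch only -- not part of the diff.
    class MyOpCost : public OperatorCost {
     public:
      double GetComputationCost(const std::vector<TensorInfo>& inputs, const std::vector<TensorInfo>& outputs,
                                const int32_t& stage_id) const override {
        return GetForwardComputationCost(inputs, outputs, stage_id) +
               GetBackwardComputationCost(inputs, outputs, stage_id);
      }
      double GetForwardComputationCost(const std::vector<TensorInfo>& inputs, const std::vector<TensorInfo>&,
                                       const int32_t&) const override {
        // e.g. the byte size of the first input slice
        return ListProduct(inputs[0].slice_shape()) * static_cast<double>(inputs_type_lengths_[0]);
      }
      double GetBackwardComputationCost(const std::vector<TensorInfo>&, const std::vector<TensorInfo>&,
                                        const int32_t&) const override {
        return 0.0;  // no extra backward computation in this sketch
      }
      // GetCommCost / GetForwardCommCost / GetBackwardCommCost overrides omitted here.
    };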
| @@ -121,14 +121,14 @@ class ActivationCost : public OperatorCost { | |||||
| const int32_t& stage_id) const override; | const int32_t& stage_id) const override; | ||||
| double GetBackwardCommCost(const std::vector<TensorInfo>& inputs, const std::vector<TensorInfo>& outputs, | double GetBackwardCommCost(const std::vector<TensorInfo>& inputs, const std::vector<TensorInfo>& outputs, | ||||
| const int32_t& stage_id) const override; | const int32_t& stage_id) const override; | ||||
| double GetMemoryCost(const std::vector<TensorInfo>& inputs, const std::vector<TensorInfo>& outputs, | |||||
| const int32_t& stage_id) const override { | |||||
| return GetForwardMemoryCost(inputs, outputs, stage_id) + GetBackwardMemoryCost(inputs, outputs, stage_id); | |||||
| } | |||||
| double GetForwardMemoryCost(const std::vector<TensorInfo>& inputs, const std::vector<TensorInfo>& outputs, | |||||
| const int32_t& stage_id) const override; | |||||
| double GetBackwardMemoryCost(const std::vector<TensorInfo>& inputs, const std::vector<TensorInfo>& outputs, | |||||
| const int32_t& stage_id) const override; | |||||
| double GetComputationCost(const std::vector<TensorInfo>& inputs, const std::vector<TensorInfo>& outputs, | |||||
| const int32_t& stage_id) const override { | |||||
| return GetForwardComputationCost(inputs, outputs, stage_id) + GetBackwardComputationCost(inputs, outputs, stage_id); | |||||
| } | |||||
| double GetForwardComputationCost(const std::vector<TensorInfo>& inputs, const std::vector<TensorInfo>& outputs, | |||||
| const int32_t& stage_id) const override; | |||||
| double GetBackwardComputationCost(const std::vector<TensorInfo>& inputs, const std::vector<TensorInfo>& outputs, | |||||
| const int32_t& stage_id) const override; | |||||
| }; | }; | ||||
| using ActivationCostPtr = std::shared_ptr<ActivationCost>; | using ActivationCostPtr = std::shared_ptr<ActivationCost>; | ||||
| @@ -146,14 +146,14 @@ class SoftmaxCost : public OperatorCost { | |||||
| const int32_t& stage_id) const override; | const int32_t& stage_id) const override; | ||||
| double GetBackwardCommCost(const std::vector<TensorInfo>& inputs, const std::vector<TensorInfo>& outputs, | double GetBackwardCommCost(const std::vector<TensorInfo>& inputs, const std::vector<TensorInfo>& outputs, | ||||
| const int32_t& stage_id) const override; | const int32_t& stage_id) const override; | ||||
| double GetMemoryCost(const std::vector<TensorInfo>& inputs, const std::vector<TensorInfo>& outputs, | |||||
| const int32_t& stage_id) const override { | |||||
| return GetForwardMemoryCost(inputs, outputs, stage_id) + GetBackwardMemoryCost(inputs, outputs, stage_id); | |||||
| } | |||||
| double GetForwardMemoryCost(const std::vector<TensorInfo>& inputs, const std::vector<TensorInfo>& outputs, | |||||
| const int32_t& stage_id) const override; | |||||
| double GetBackwardMemoryCost(const std::vector<TensorInfo>& inputs, const std::vector<TensorInfo>& outputs, | |||||
| const int32_t&) const override; | |||||
| double GetComputationCost(const std::vector<TensorInfo>& inputs, const std::vector<TensorInfo>& outputs, | |||||
| const int32_t& stage_id) const override { | |||||
| return GetForwardComputationCost(inputs, outputs, stage_id) + GetBackwardComputationCost(inputs, outputs, stage_id); | |||||
| } | |||||
| double GetForwardComputationCost(const std::vector<TensorInfo>& inputs, const std::vector<TensorInfo>& outputs, | |||||
| const int32_t& stage_id) const override; | |||||
| double GetBackwardComputationCost(const std::vector<TensorInfo>& inputs, const std::vector<TensorInfo>& outputs, | |||||
| const int32_t&) const override; | |||||
| }; | }; | ||||
| using SoftmaxCostPtr = std::shared_ptr<SoftmaxCost>; | using SoftmaxCostPtr = std::shared_ptr<SoftmaxCost>; | ||||
| @@ -171,14 +171,14 @@ class TmpIdentityCost : public OperatorCost { | |||||
| const int32_t& stage_id) const override; | const int32_t& stage_id) const override; | ||||
| double GetBackwardCommCost(const std::vector<TensorInfo>& inputs, const std::vector<TensorInfo>& outputs, | double GetBackwardCommCost(const std::vector<TensorInfo>& inputs, const std::vector<TensorInfo>& outputs, | ||||
| const int32_t& stage_id) const override; | const int32_t& stage_id) const override; | ||||
| double GetMemoryCost(const std::vector<TensorInfo>& inputs, const std::vector<TensorInfo>& outputs, | |||||
| const int32_t& stage_id) const override { | |||||
| return GetForwardMemoryCost(inputs, outputs, stage_id) + GetBackwardMemoryCost(inputs, outputs, stage_id); | |||||
| } | |||||
| double GetForwardMemoryCost(const std::vector<TensorInfo>& inputs, const std::vector<TensorInfo>& outputs, | |||||
| const int32_t& stage_id) const override; | |||||
| double GetBackwardMemoryCost(const std::vector<TensorInfo>& inputs, const std::vector<TensorInfo>& outputs, | |||||
| const int32_t& stage_id) const override; | |||||
| double GetComputationCost(const std::vector<TensorInfo>& inputs, const std::vector<TensorInfo>& outputs, | |||||
| const int32_t& stage_id) const override { | |||||
| return GetForwardComputationCost(inputs, outputs, stage_id) + GetBackwardComputationCost(inputs, outputs, stage_id); | |||||
| } | |||||
| double GetForwardComputationCost(const std::vector<TensorInfo>& inputs, const std::vector<TensorInfo>& outputs, | |||||
| const int32_t& stage_id) const override; | |||||
| double GetBackwardComputationCost(const std::vector<TensorInfo>& inputs, const std::vector<TensorInfo>& outputs, | |||||
| const int32_t& stage_id) const override; | |||||
| }; | }; | ||||
| using TmpIdentityCostPtr = std::shared_ptr<TmpIdentityCost>; | using TmpIdentityCostPtr = std::shared_ptr<TmpIdentityCost>; | ||||
| @@ -199,14 +199,14 @@ class BatchParallelCost : public OperatorCost { | |||||
| const int32_t&) const override { | const int32_t&) const override { | ||||
| return 0.0; | return 0.0; | ||||
| } | } | ||||
| double GetMemoryCost(const std::vector<TensorInfo>& inputs, const std::vector<TensorInfo>& outputs, | |||||
| const int32_t& stage_id) const override { | |||||
| return GetForwardMemoryCost(inputs, outputs, stage_id) + GetBackwardMemoryCost(inputs, outputs, stage_id); | |||||
| double GetComputationCost(const std::vector<TensorInfo>& inputs, const std::vector<TensorInfo>& outputs, | |||||
| const int32_t& stage_id) const override { | |||||
| return GetForwardComputationCost(inputs, outputs, stage_id) + GetBackwardComputationCost(inputs, outputs, stage_id); | |||||
| } | } | ||||
| double GetForwardMemoryCost(const std::vector<TensorInfo>& inputs, const std::vector<TensorInfo>& outputs, | |||||
| const int32_t& stage_id) const override; | |||||
| double GetBackwardMemoryCost(const std::vector<TensorInfo>& inputs, const std::vector<TensorInfo>& outputs, | |||||
| const int32_t& stage_id) const override; | |||||
| double GetForwardComputationCost(const std::vector<TensorInfo>& inputs, const std::vector<TensorInfo>& outputs, | |||||
| const int32_t& stage_id) const override; | |||||
| double GetBackwardComputationCost(const std::vector<TensorInfo>& inputs, const std::vector<TensorInfo>& outputs, | |||||
| const int32_t& stage_id) const override; | |||||
| }; | }; | ||||
| using BatchParallelCostPtr = std::shared_ptr<BatchParallelCost>; | using BatchParallelCostPtr = std::shared_ptr<BatchParallelCost>; | ||||
| @@ -227,16 +227,16 @@ class VirtualDatasetCost : public OperatorCost { | |||||
| const int32_t&) const override { | const int32_t&) const override { | ||||
| return 0.0; | return 0.0; | ||||
| } | } | ||||
| double GetMemoryCost(const std::vector<TensorInfo>& inputs, const std::vector<TensorInfo>& outputs, | |||||
| const int32_t& stage_id) const override { | |||||
| return GetForwardMemoryCost(inputs, outputs, stage_id) + GetBackwardMemoryCost(inputs, outputs, stage_id); | |||||
| double GetComputationCost(const std::vector<TensorInfo>& inputs, const std::vector<TensorInfo>& outputs, | |||||
| const int32_t& stage_id) const override { | |||||
| return GetForwardComputationCost(inputs, outputs, stage_id) + GetBackwardComputationCost(inputs, outputs, stage_id); | |||||
| } | } | ||||
| double GetForwardMemoryCost(const std::vector<TensorInfo>&, const std::vector<TensorInfo>&, | |||||
| const int32_t&) const override { | |||||
| double GetForwardComputationCost(const std::vector<TensorInfo>&, const std::vector<TensorInfo>&, | |||||
| const int32_t&) const override { | |||||
| return 0.0; | return 0.0; | ||||
| } | } | ||||
| double GetBackwardMemoryCost(const std::vector<TensorInfo>&, const std::vector<TensorInfo>&, | |||||
| const int32_t&) const override { | |||||
| double GetBackwardComputationCost(const std::vector<TensorInfo>&, const std::vector<TensorInfo>&, | |||||
| const int32_t&) const override { | |||||
| return 0.0; | return 0.0; | ||||
| } | } | ||||
| }; | }; | ||||
| @@ -259,18 +259,18 @@ class GeneratorBaseCost : public OperatorCost { | |||||
| const int32_t&) const override { | const int32_t&) const override { | ||||
| return 0.0; | return 0.0; | ||||
| } | } | ||||
| double GetMemoryCost(const std::vector<TensorInfo>& inputs, const std::vector<TensorInfo>& outputs, | |||||
| const int32_t& stage_id) const override { | |||||
| return GetForwardMemoryCost(inputs, outputs, stage_id) + GetBackwardMemoryCost(inputs, outputs, stage_id); | |||||
| double GetComputationCost(const std::vector<TensorInfo>& inputs, const std::vector<TensorInfo>& outputs, | |||||
| const int32_t& stage_id) const override { | |||||
| return GetForwardComputationCost(inputs, outputs, stage_id) + GetBackwardComputationCost(inputs, outputs, stage_id); | |||||
| } | } | ||||
| // Inputs vector is empty for generator ops. | // Inputs vector is empty for generator ops. | ||||
| double GetForwardMemoryCost(const std::vector<TensorInfo>&, const std::vector<TensorInfo>&, | |||||
| const int32_t&) const override { | |||||
| double GetForwardComputationCost(const std::vector<TensorInfo>&, const std::vector<TensorInfo>&, | |||||
| const int32_t&) const override { | |||||
| return 0.0; | return 0.0; | ||||
| } | } | ||||
| // Generator ops don't have backward steps. | // Generator ops don't have backward steps. | ||||
| double GetBackwardMemoryCost(const std::vector<TensorInfo>&, const std::vector<TensorInfo>&, | |||||
| const int32_t&) const override { | |||||
| double GetBackwardComputationCost(const std::vector<TensorInfo>&, const std::vector<TensorInfo>&, | |||||
| const int32_t&) const override { | |||||
| return 0.0; | return 0.0; | ||||
| } | } | ||||
| }; | }; | ||||
| @@ -292,14 +292,14 @@ class PReLUCost : public OperatorCost { | |||||
| const int32_t& stage_id) const override; | const int32_t& stage_id) const override; | ||||
| // per device computation cost | // per device computation cost | ||||
| double GetMemoryCost(const std::vector<TensorInfo>& inputs, const std::vector<TensorInfo>& outputs, | |||||
| const int32_t& stage_id) const override { | |||||
| return GetForwardMemoryCost(inputs, outputs, stage_id) + GetBackwardMemoryCost(inputs, outputs, stage_id); | |||||
| } | |||||
| double GetForwardMemoryCost(const std::vector<TensorInfo>& inputs, const std::vector<TensorInfo>& outputs, | |||||
| const int32_t& stage_id) const override; | |||||
| double GetBackwardMemoryCost(const std::vector<TensorInfo>& inputs, const std::vector<TensorInfo>& outputs, | |||||
| const int32_t& stage_id) const override; | |||||
| double GetComputationCost(const std::vector<TensorInfo>& inputs, const std::vector<TensorInfo>& outputs, | |||||
| const int32_t& stage_id) const override { | |||||
| return GetForwardComputationCost(inputs, outputs, stage_id) + GetBackwardComputationCost(inputs, outputs, stage_id); | |||||
| } | |||||
| double GetForwardComputationCost(const std::vector<TensorInfo>& inputs, const std::vector<TensorInfo>& outputs, | |||||
| const int32_t& stage_id) const override; | |||||
| double GetBackwardComputationCost(const std::vector<TensorInfo>& inputs, const std::vector<TensorInfo>& outputs, | |||||
| const int32_t& stage_id) const override; | |||||
| }; | }; | ||||
| using PReLUCostPtr = std::shared_ptr<PReLUCost>; | using PReLUCostPtr = std::shared_ptr<PReLUCost>; | ||||
| @@ -319,14 +319,14 @@ class OneHotCost : public OperatorCost { | |||||
| const int32_t& stage_id) const override; | const int32_t& stage_id) const override; | ||||
| // per device computation cost | // per device computation cost | ||||
| double GetMemoryCost(const std::vector<TensorInfo>& inputs, const std::vector<TensorInfo>& outputs, | |||||
| const int32_t& stage_id) const override { | |||||
| return GetForwardMemoryCost(inputs, outputs, stage_id) + GetBackwardMemoryCost(inputs, outputs, stage_id); | |||||
| } | |||||
| double GetForwardMemoryCost(const std::vector<TensorInfo>& inputs, const std::vector<TensorInfo>& outputs, | |||||
| const int32_t& stage_id) const override; | |||||
| double GetBackwardMemoryCost(const std::vector<TensorInfo>& inputs, const std::vector<TensorInfo>& outputs, | |||||
| const int32_t& stage_id) const override; | |||||
| double GetComputationCost(const std::vector<TensorInfo>& inputs, const std::vector<TensorInfo>& outputs, | |||||
| const int32_t& stage_id) const override { | |||||
| return GetForwardComputationCost(inputs, outputs, stage_id) + GetBackwardComputationCost(inputs, outputs, stage_id); | |||||
| } | |||||
| double GetForwardComputationCost(const std::vector<TensorInfo>& inputs, const std::vector<TensorInfo>& outputs, | |||||
| const int32_t& stage_id) const override; | |||||
| double GetBackwardComputationCost(const std::vector<TensorInfo>& inputs, const std::vector<TensorInfo>& outputs, | |||||
| const int32_t& stage_id) const override; | |||||
| }; | }; | ||||
| using OneHotCostPtr = std::shared_ptr<OneHotCost>; | using OneHotCostPtr = std::shared_ptr<OneHotCost>; | ||||
| @@ -346,14 +346,14 @@ class SoftmaxCrossEntropyWithLogitsCost : public OperatorCost { | |||||
| const int32_t& stage_id) const override; | const int32_t& stage_id) const override; | ||||
| // per device computation cost | // per device computation cost | ||||
| double GetMemoryCost(const std::vector<TensorInfo>& inputs, const std::vector<TensorInfo>& outputs, | |||||
| const int32_t& stage_id) const override { | |||||
| return GetForwardMemoryCost(inputs, outputs, stage_id) + GetBackwardMemoryCost(inputs, outputs, stage_id); | |||||
| } | |||||
| double GetForwardMemoryCost(const std::vector<TensorInfo>& inputs, const std::vector<TensorInfo>& outputs, | |||||
| const int32_t& stage_id) const override; | |||||
| double GetBackwardMemoryCost(const std::vector<TensorInfo>& inputs, const std::vector<TensorInfo>& outputs, | |||||
| const int32_t& stage_id) const override; | |||||
| double GetComputationCost(const std::vector<TensorInfo>& inputs, const std::vector<TensorInfo>& outputs, | |||||
| const int32_t& stage_id) const override { | |||||
| return GetForwardComputationCost(inputs, outputs, stage_id) + GetBackwardComputationCost(inputs, outputs, stage_id); | |||||
| } | |||||
| double GetForwardComputationCost(const std::vector<TensorInfo>& inputs, const std::vector<TensorInfo>& outputs, | |||||
| const int32_t& stage_id) const override; | |||||
| double GetBackwardComputationCost(const std::vector<TensorInfo>& inputs, const std::vector<TensorInfo>& outputs, | |||||
| const int32_t& stage_id) const override; | |||||
| }; | }; | ||||
| using SoftmaxCrossEntropyWithLogitsCostPtr = std::shared_ptr<SoftmaxCrossEntropyWithLogitsCost>; | using SoftmaxCrossEntropyWithLogitsCostPtr = std::shared_ptr<SoftmaxCrossEntropyWithLogitsCost>; | ||||
| @@ -376,16 +376,16 @@ class ReshapeCost : public OperatorCost { | |||||
| const int32_t& stage_id) const override; | const int32_t& stage_id) const override; | ||||
| // per device computation cost | // per device computation cost | ||||
| double GetMemoryCost(const std::vector<TensorInfo>& inputs, const std::vector<TensorInfo>& outputs, | |||||
| const int32_t& stage_id) const override { | |||||
| return GetForwardMemoryCost(inputs, outputs, stage_id) + GetBackwardMemoryCost(inputs, outputs, stage_id); | |||||
| double GetComputationCost(const std::vector<TensorInfo>& inputs, const std::vector<TensorInfo>& outputs, | |||||
| const int32_t& stage_id) const override { | |||||
| return GetForwardComputationCost(inputs, outputs, stage_id) + GetBackwardComputationCost(inputs, outputs, stage_id); | |||||
| } | } | ||||
| double GetForwardMemoryCost(const std::vector<TensorInfo>& inputs, const std::vector<TensorInfo>& outputs, | |||||
| const int32_t& stage_id) const override; | |||||
| double GetForwardComputationCost(const std::vector<TensorInfo>& inputs, const std::vector<TensorInfo>& outputs, | |||||
| const int32_t& stage_id) const override; | |||||
| double GetBackwardMemoryCost(const std::vector<TensorInfo>& inputs, const std::vector<TensorInfo>& outputs, | |||||
| const int32_t& stage_id) const override; | |||||
| double GetBackwardComputationCost(const std::vector<TensorInfo>& inputs, const std::vector<TensorInfo>& outputs, | |||||
| const int32_t& stage_id) const override; | |||||
| }; | }; | ||||
| using ReshapeCostPtr = std::shared_ptr<ReshapeCost>; | using ReshapeCostPtr = std::shared_ptr<ReshapeCost>; | ||||
| @@ -405,14 +405,14 @@ class ArithmeticCost : public OperatorCost { | |||||
| double GetBackwardCommCost(const std::vector<TensorInfo>&, const std::vector<TensorInfo>&, | double GetBackwardCommCost(const std::vector<TensorInfo>&, const std::vector<TensorInfo>&, | ||||
| const int32_t&) const override; | const int32_t&) const override; | ||||
| double GetMemoryCost(const std::vector<TensorInfo>& inputs, const std::vector<TensorInfo>& outputs, | |||||
| const int32_t& stage_id) const override { | |||||
| return GetForwardMemoryCost(inputs, outputs, stage_id) + GetBackwardMemoryCost(inputs, outputs, stage_id); | |||||
| double GetComputationCost(const std::vector<TensorInfo>& inputs, const std::vector<TensorInfo>& outputs, | |||||
| const int32_t& stage_id) const override { | |||||
| return GetForwardComputationCost(inputs, outputs, stage_id) + GetBackwardComputationCost(inputs, outputs, stage_id); | |||||
| } | } | ||||
| double GetForwardMemoryCost(const std::vector<TensorInfo>& inputs, const std::vector<TensorInfo>& outputs, | |||||
| const int32_t& stage_id) const override; | |||||
| double GetBackwardMemoryCost(const std::vector<TensorInfo>& inputs, const std::vector<TensorInfo>& outputs, | |||||
| const int32_t& stage_id) const override; | |||||
| double GetForwardComputationCost(const std::vector<TensorInfo>& inputs, const std::vector<TensorInfo>& outputs, | |||||
| const int32_t& stage_id) const override; | |||||
| double GetBackwardComputationCost(const std::vector<TensorInfo>& inputs, const std::vector<TensorInfo>& outputs, | |||||
| const int32_t& stage_id) const override; | |||||
| }; | }; | ||||
| using ArithmeticCostPtr = std::shared_ptr<ArithmeticCost>; | using ArithmeticCostPtr = std::shared_ptr<ArithmeticCost>; | ||||
| @@ -431,14 +431,14 @@ class L2NormalizeCost : public OperatorCost { | |||||
| } | } | ||||
| double GetBackwardCommCost(const std::vector<TensorInfo>& inputs, const std::vector<TensorInfo>& outputs, | double GetBackwardCommCost(const std::vector<TensorInfo>& inputs, const std::vector<TensorInfo>& outputs, | ||||
| const int32_t& stage_id) const override; | const int32_t& stage_id) const override; | ||||
| double GetMemoryCost(const std::vector<TensorInfo>& inputs, const std::vector<TensorInfo>& outputs, | |||||
| const int32_t& stage_id) const override { | |||||
| return GetForwardMemoryCost(inputs, outputs, stage_id) + GetBackwardMemoryCost(inputs, outputs, stage_id); | |||||
| } | |||||
| double GetForwardMemoryCost(const std::vector<TensorInfo>& inputs, const std::vector<TensorInfo>& outputs, | |||||
| const int32_t& stage_id) const override; | |||||
| double GetBackwardMemoryCost(const std::vector<TensorInfo>& inputs, const std::vector<TensorInfo>& outputs, | |||||
| const int32_t& stage_id) const override; | |||||
| double GetComputationCost(const std::vector<TensorInfo>& inputs, const std::vector<TensorInfo>& outputs, | |||||
| const int32_t& stage_id) const override { | |||||
| return GetForwardComputationCost(inputs, outputs, stage_id) + GetBackwardComputationCost(inputs, outputs, stage_id); | |||||
| } | |||||
| double GetForwardComputationCost(const std::vector<TensorInfo>& inputs, const std::vector<TensorInfo>& outputs, | |||||
| const int32_t& stage_id) const override; | |||||
| double GetBackwardComputationCost(const std::vector<TensorInfo>& inputs, const std::vector<TensorInfo>& outputs, | |||||
| const int32_t& stage_id) const override; | |||||
| }; | }; | ||||
| using L2NormalizeCostPtr = std::shared_ptr<L2NormalizeCost>; | using L2NormalizeCostPtr = std::shared_ptr<L2NormalizeCost>; | ||||
| @@ -455,14 +455,14 @@ class ReduceMethodCost : public OperatorCost { | |||||
| const int32_t& stage_id) const override; | const int32_t& stage_id) const override; | ||||
| double GetBackwardCommCost(const std::vector<TensorInfo>& inputs, const std::vector<TensorInfo>& outputs, | double GetBackwardCommCost(const std::vector<TensorInfo>& inputs, const std::vector<TensorInfo>& outputs, | ||||
| const int32_t& stage_id) const override; | const int32_t& stage_id) const override; | ||||
| double GetMemoryCost(const std::vector<TensorInfo>& inputs, const std::vector<TensorInfo>& outputs, | |||||
| const int32_t& stage_id) const override { | |||||
| return GetForwardMemoryCost(inputs, outputs, stage_id) + GetBackwardCommCost(inputs, outputs, stage_id); | |||||
| } | |||||
| double GetForwardMemoryCost(const std::vector<TensorInfo>& inputs, const std::vector<TensorInfo>& outputs, | |||||
| const int32_t& stage_id) const override; | |||||
| double GetBackwardMemoryCost(const std::vector<TensorInfo>&, const std::vector<TensorInfo>&, | |||||
| const int32_t&) const override { | |||||
| double GetComputationCost(const std::vector<TensorInfo>& inputs, const std::vector<TensorInfo>& outputs, | |||||
| const int32_t& stage_id) const override { | |||||
| return GetForwardComputationCost(inputs, outputs, stage_id) + GetBackwardCommCost(inputs, outputs, stage_id); | |||||
| } | |||||
| double GetForwardComputationCost(const std::vector<TensorInfo>& inputs, const std::vector<TensorInfo>& outputs, | |||||
| const int32_t& stage_id) const override; | |||||
| double GetBackwardComputationCost(const std::vector<TensorInfo>&, const std::vector<TensorInfo>&, | |||||
| const int32_t&) const override { | |||||
| return 0.0; | return 0.0; | ||||
| } | } | ||||
| void set_cross_batch(bool cb) { cross_batch_ = cb; } | void set_cross_batch(bool cb) { cross_batch_ = cb; } | ||||
| @@ -477,8 +477,8 @@ class ReduceMeanCost : public ReduceMethodCost { | |||||
| ReduceMeanCost() = default; | ReduceMeanCost() = default; | ||||
| ~ReduceMeanCost() override = default; | ~ReduceMeanCost() override = default; | ||||
| double GetForwardMemoryCost(const std::vector<TensorInfo>& inputs, const std::vector<TensorInfo>& outputs, | |||||
| const int32_t& stage_id) const override; | |||||
| double GetForwardComputationCost(const std::vector<TensorInfo>& inputs, const std::vector<TensorInfo>& outputs, | |||||
| const int32_t& stage_id) const override; | |||||
| }; | }; | ||||
| using ReduceMeanCostPtr = std::shared_ptr<ReduceMeanCost>; | using ReduceMeanCostPtr = std::shared_ptr<ReduceMeanCost>; | ||||
| @@ -499,18 +499,18 @@ class GetNextCost : public OperatorCost { | |||||
| const int32_t&) const override { | const int32_t&) const override { | ||||
| return 0.0; | return 0.0; | ||||
| } | } | ||||
| double GetMemoryCost(const std::vector<TensorInfo>& inputs, const std::vector<TensorInfo>& outputs, | |||||
| const int32_t& stage_id) const override { | |||||
| return GetForwardMemoryCost(inputs, outputs, stage_id) + GetBackwardMemoryCost(inputs, outputs, stage_id); | |||||
| double GetComputationCost(const std::vector<TensorInfo>& inputs, const std::vector<TensorInfo>& outputs, | |||||
| const int32_t& stage_id) const override { | |||||
| return GetForwardComputationCost(inputs, outputs, stage_id) + GetBackwardComputationCost(inputs, outputs, stage_id); | |||||
| } | } | ||||
| // Inputs vector is empty for generator ops. | // Inputs vector is empty for generator ops. | ||||
| double GetForwardMemoryCost(const std::vector<TensorInfo>&, const std::vector<TensorInfo>&, | |||||
| const int32_t&) const override { | |||||
| double GetForwardComputationCost(const std::vector<TensorInfo>&, const std::vector<TensorInfo>&, | |||||
| const int32_t&) const override { | |||||
| return 0.0; | return 0.0; | ||||
| } | } | ||||
| // Generator ops don't have backward steps. | // Generator ops don't have backward steps. | ||||
| double GetBackwardMemoryCost(const std::vector<TensorInfo>&, const std::vector<TensorInfo>&, | |||||
| const int32_t&) const override { | |||||
| double GetBackwardComputationCost(const std::vector<TensorInfo>&, const std::vector<TensorInfo>&, | |||||
| const int32_t&) const override { | |||||
| return 0.0; | return 0.0; | ||||
| } | } | ||||
| }; | }; | ||||
| @@ -592,10 +592,10 @@ Status MatMulBase::SetCostUnderStrategy(const mindspore::parallel::StrategyPtr& | |||||
| int32_t stage_id = strategy->GetInputStage(); | int32_t stage_id = strategy->GetInputStage(); | ||||
| // Here, we use the original outputs_, because we only use the slice size of the output tensor. | // Here, we use the original outputs_, because we only use the slice size of the output tensor. | ||||
| // It does not matter whether the output tensor is transposed or not. | // It does not matter whether the output tensor is transposed or not. | ||||
| double memory_cost = | |||||
| matmulcost_ptr->GetForwardMemoryCost(relica_inputs_tensor_vector, outputs_tensor_info_, stage_id); | |||||
| double computation_cost = | |||||
| matmulcost_ptr->GetForwardComputationCost(relica_inputs_tensor_vector, outputs_tensor_info_, stage_id); | |||||
| double communication_cost = matmulcost_ptr->GetCommCost(relica_inputs_tensor_vector, outputs_tensor_info_, stage_id); | double communication_cost = matmulcost_ptr->GetCommCost(relica_inputs_tensor_vector, outputs_tensor_info_, stage_id); | ||||
| std::shared_ptr<Cost> result = std::make_shared<Cost>(memory_cost, communication_cost); | |||||
| std::shared_ptr<Cost> result = std::make_shared<Cost>(computation_cost, communication_cost); | |||||
| result->communication_without_parameter_ = | result->communication_without_parameter_ = | ||||
| matmulcost_ptr->GetForwardCommCost(relica_inputs_tensor_vector, outputs_tensor_info_, stage_id); | matmulcost_ptr->GetForwardCommCost(relica_inputs_tensor_vector, outputs_tensor_info_, stage_id); | ||||
| result->communication_with_partial_para_ = | result->communication_with_partial_para_ = | ||||
| @@ -604,7 +604,7 @@ Status MatMulBase::SetCostUnderStrategy(const mindspore::parallel::StrategyPtr& | |||||
| // Breaking ties for preferring data parallelization | // Breaking ties for preferring data parallelization | ||||
| BreakingTiesForPerferringDataParallel(strategy, result); | BreakingTiesForPerferringDataParallel(strategy, result); | ||||
| MS_LOG(DEBUG) << name_ << " : memory_cost: " << result->memory_cost_ | |||||
| MS_LOG(DEBUG) << name_ << " : computation_cost: " << result->computation_cost_ | |||||
| << ", communication_cost: " << result->communication_cost_ | << ", communication_cost: " << result->communication_cost_ | ||||
| << ", communication_without_parameter_: " << result->communication_without_parameter_ | << ", communication_without_parameter_: " << result->communication_without_parameter_ | ||||
| << ", communication_with_partial_para_: " << result->communication_with_partial_para_; | << ", communication_with_partial_para_: " << result->communication_with_partial_para_; | ||||
| @@ -1034,9 +1034,10 @@ Status OperatorInfo::SetCostUnderStrategyBase(const StrategyPtr& strategy) { | |||||
| return FAILED; | return FAILED; | ||||
| } | } | ||||
| int32_t stage_id = strategy->GetInputStage(); | int32_t stage_id = strategy->GetInputStage(); | ||||
| double memory_cost = GetOperatorCost()->GetForwardMemoryCost(inputs_tensor_info_, outputs_tensor_info_, stage_id); | |||||
| double computation_cost = | |||||
| GetOperatorCost()->GetForwardComputationCost(inputs_tensor_info_, outputs_tensor_info_, stage_id); | |||||
| double communication_cost = GetOperatorCost()->GetCommCost(inputs_tensor_info_, outputs_tensor_info_, stage_id); | double communication_cost = GetOperatorCost()->GetCommCost(inputs_tensor_info_, outputs_tensor_info_, stage_id); | ||||
| std::shared_ptr<Cost> result = std::make_shared<Cost>(memory_cost, communication_cost); | |||||
| std::shared_ptr<Cost> result = std::make_shared<Cost>(computation_cost, communication_cost); | |||||
| result->communication_without_parameter_ = | result->communication_without_parameter_ = | ||||
| GetOperatorCost()->GetForwardCommCost(inputs_tensor_info_, outputs_tensor_info_, stage_id); | GetOperatorCost()->GetForwardCommCost(inputs_tensor_info_, outputs_tensor_info_, stage_id); | ||||
| result->communication_with_partial_para_ = | result->communication_with_partial_para_ = | ||||
| @@ -1056,22 +1057,6 @@ Status OperatorInfo::SetCostUnderStrategyBase(const StrategyPtr& strategy) { | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
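For reference, a condensed sketch of how SetCostUnderStrategyBase assembles a Cost entry from the renamed accessors. It assumes the project's Cost, OperatorCost and TensorInfo types; the communication_with_partial_para_ assignment is only noted in a comment because the diff cuts its computation off.

    std::shared_ptr<Cost> MakeStrategyCostSketch(const std::shared_ptr<OperatorCost>& op_cost,
                                                 const std::vector<TensorInfo>& inputs,
                                                 const std::vector<TensorInfo>& outputs, int32_t stage_id) {
      double computation_cost = op_cost->GetForwardComputationCost(inputs, outputs, stage_id);
      double communication_cost = op_cost->GetCommCost(inputs, outputs, stage_id);
      auto result = std::make_shared<Cost>(computation_cost, communication_cost);
      result->communication_without_parameter_ = op_cost->GetForwardCommCost(inputs, outputs, stage_id);
      // communication_with_partial_para_ is also filled in by the real code (not shown in the hunk).
      return result;
    }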
| Status OperatorInfo::CorrectStrategyCostForMultiOutputUse(size_t input_index) { | |||||
| for (auto& swc : strategy_cost_) { | |||||
| double parameter_memory_cost = ListProduct(swc->inputs_ptr[input_index].slice_shape()) * | |||||
| static_cast<double>(GetOperatorCost()->inputs_type_lengths()[input_index]); | |||||
| // remove the parameter memory cost | |||||
| swc->cost_list[0]->memory_cost_ -= parameter_memory_cost; | |||||
| if (swc->cost_list[0]->memory_cost_ < -1) { | |||||
| MS_LOG(ERROR) << "The memory cost after correction is " << swc->cost_list[0]->memory_cost_ | |||||
| << ", the parameter_memory_cost is " << parameter_memory_cost; | |||||
| return FAILED; | |||||
| } | |||||
| } | |||||
| corrected_input_indices_.push_back(input_index); | |||||
| return SUCCESS; | |||||
| } | |||||
| int OperatorInfo::ComputeOpAndPrevEdgeParameterInvolved() { | int OperatorInfo::ComputeOpAndPrevEdgeParameterInvolved() { | ||||
| if (is_output_parameter_involve_ != -1) { | if (is_output_parameter_involve_ != -1) { | ||||
| return is_output_parameter_involve_; | return is_output_parameter_involve_; | ||||
| @@ -1217,7 +1202,7 @@ void OperatorInfo::BreakingTiesForPerferringDataParallel(const StrategyPtr& stra | |||||
| CheckGlobalDeviceManager(); | CheckGlobalDeviceManager(); | ||||
| auto total_device_num = g_device_manager->GetDeviceListByStageId(stra->GetInputStage()).size(); | auto total_device_num = g_device_manager->GetDeviceListByStageId(stra->GetInputStage()).size(); | ||||
| if (IntToSize(stra->GetInputDim()[0][0]) == total_device_num) { | if (IntToSize(stra->GetInputDim()[0][0]) == total_device_num) { | ||||
| cost->memory_cost_ -= 1.0; | |||||
| cost->computation_cost_ -= 1.0; | |||||
| cost->communication_cost_ -= 1.0; | cost->communication_cost_ -= 1.0; | ||||
| cost->communication_with_partial_para_ -= 1.0; | cost->communication_with_partial_para_ -= 1.0; | ||||
| cost->communication_without_parameter_ -= 1.0; | cost->communication_without_parameter_ -= 1.0; | ||||
| @@ -1226,7 +1211,7 @@ void OperatorInfo::BreakingTiesForPerferringDataParallel(const StrategyPtr& stra | |||||
| } | } | ||||
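A minimal sketch of the tie-breaking rule above: when the first dimension of the strategy equals the total device number (i.e. pure data parallelism), every cost field is nudged down by 1.0 so that this strategy wins when costs are otherwise equal. The function name and parameters are illustrative only.

    void PreferDataParallelSketch(const std::shared_ptr<Cost>& cost, size_t first_dim, size_t total_device_num) {
      if (first_dim == total_device_num) {
        cost->computation_cost_ -= 1.0;
        cost->communication_cost_ -= 1.0;
        cost->communication_with_partial_para_ -= 1.0;
        cost->communication_without_parameter_ -= 1.0;
      }
    }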
| double OperatorInfo::GetForwardMemoryCostFromCNode() { | double OperatorInfo::GetForwardMemoryCostFromCNode() { | ||||
| return GetOperatorCost()->GetForwardMemoryCost(inputs_tensor_info_, outputs_tensor_info_, 0); | |||||
| return GetOperatorCost()->GetForwardComputationCost(inputs_tensor_info_, outputs_tensor_info_, 0); | |||||
| } | } | ||||
| } // namespace parallel | } // namespace parallel | ||||
| @@ -87,13 +87,9 @@ class OperatorInfo { | |||||
| // is checked | // is checked | ||||
| Status SetCostUnderStrategyBase(const StrategyPtr& strategy); | Status SetCostUnderStrategyBase(const StrategyPtr& strategy); | ||||
| std::vector<std::shared_ptr<StrategyWithCost>> GetStrategyCost() { return strategy_cost_; } | std::vector<std::shared_ptr<StrategyWithCost>> GetStrategyCost() { return strategy_cost_; } | ||||
| // In the case of a Parameter (or a output) being used by multiple operators, the memory cost induced by | |||||
| // the parameter (or a output) should be calculated only once. This method is used to | |||||
| // remove this part from the 'strategy_cost_'. | |||||
| Status CorrectStrategyCostForMultiOutputUse(size_t input_index); | |||||
| // When an input of an operator contains WEIGHT, or is an output of another operator involving WEIGHT, that input | // When an input of an operator contains WEIGHT, or is an output of another operator involving WEIGHT, that input | ||||
| // should stay in memory until it is used in the backward phase, i.e. it is kept in memory at the end of the forward phase. | // should stay in memory until it is used in the backward phase, i.e. it is kept in memory at the end of the forward phase. | ||||
| Status CorrectStrategyCostForMemoryReuse() const { return SUCCESS; } | |||||
| Status CalculateMemoryCost() const { return SUCCESS; } | |||||
| int ComputeOpAndPrevEdgeParameterInvolved(); | int ComputeOpAndPrevEdgeParameterInvolved(); | ||||
| ForwardOp forward_op() const { return forward_op_; } | ForwardOp forward_op() const { return forward_op_; } | ||||
| @@ -387,7 +387,7 @@ OperatorInfoPtr CreateTheOperatorInfo(const PrimitivePtr &prim, const CNodePtr & | |||||
| operator_info->set_outputs_dtype(cnode->Type()); | operator_info->set_outputs_dtype(cnode->Type()); | ||||
| operator_info->set_cnode(cnode); | operator_info->set_cnode(cnode); | ||||
| // If no strategy has been configured for this operator, then candidate strategies are generated for | // If no strategy has been configured for this operator, then candidate strategies are generated for | ||||
| // auto-strategy searchingm if this primitive is Cast, we ignore the user-specified strategy | |||||
| // auto-strategy searching; if this primitive is CAST, we ignore the user-specified strategy | |||||
| if (!StrategyFound(attrs) || prim->name() == CAST) { | if (!StrategyFound(attrs) || prim->name() == CAST) { | ||||
| // Compute split_flag_list_, indicating which input has batch dimension. This is ONLY used for preparation for | // Compute split_flag_list_, indicating which input has batch dimension. This is ONLY used for preparation for | ||||
| // BatchParallelInfo operator | // BatchParallelInfo operator | ||||
| @@ -600,13 +600,7 @@ void ConstructCostGraphEdges(const std::vector<AnfNodePtr> &all_nodes) { | |||||
| } | } | ||||
| MS_LOG(INFO) << "Successfully created " << edge_count << " edges for: " << cnode->operator_info()->name(); | MS_LOG(INFO) << "Successfully created " << edge_count << " edges for: " << cnode->operator_info()->name(); | ||||
| } | } | ||||
| // For the case of a output being used by multiple subsequent operators, the output induced memory cost should be | |||||
| // calculated only once. This method is for correct the operators' memory cost calculation. | |||||
| if (entire_costgraph->CorrectOpsStrategyCostForMultiOutputUse() != SUCCESS) { | |||||
| MS_LOG(EXCEPTION) << "Correcting strategy_cost_ for operators failed."; | |||||
| } else { | |||||
| MS_LOG(INFO) << "Correcting strategy_cost_ for operators succeeded."; | |||||
| } | |||||
| MS_LOG(INFO) << "Constructing edges for cost graph ends."; | MS_LOG(INFO) << "Constructing edges for cost graph ends."; | ||||
| } | } | ||||
| @@ -803,14 +797,6 @@ void AugmentCostGraph(const std::vector<AnfNodePtr> &all_nodes) { | |||||
| std::shared_ptr<Edge> edge_ptr = std::make_shared<Edge>( | std::shared_ptr<Edge> edge_ptr = std::make_shared<Edge>( | ||||
| edge_name, tmp_identity_ptr, target_cnode->operator_info(), 0, input_index - 1, false, true); | edge_name, tmp_identity_ptr, target_cnode->operator_info(), 0, input_index - 1, false, true); | ||||
| // Correct the memory calculation for a parameter being used by multiple operators. The parameter is calculated | |||||
| // only once | |||||
| if (target_cnode->operator_info()->CorrectStrategyCostForMultiOutputUse(IntToSize(input_index - 1)) != SUCCESS) { | |||||
| MS_LOG(EXCEPTION) << "Correcting strategy_cost_ failed : " << prim->name(); | |||||
| } else { | |||||
| MS_LOG(INFO) << "Correcting strategy_cost_ succeeded. " << prim->name(); | |||||
| } | |||||
| if (edge_ptr->InitEdgeCost() != SUCCESS) { | if (edge_ptr->InitEdgeCost() != SUCCESS) { | ||||
| MS_LOG(EXCEPTION) << "Edge cost initialization failed"; | MS_LOG(EXCEPTION) << "Edge cost initialization failed"; | ||||
| } | } | ||||
| @@ -840,7 +826,7 @@ Status ParallelStrategySearch(const std::vector<AnfNodePtr> &all_nodes, const Fu | |||||
| // taking care of the case of a single Parameter being used by multiple operators. Create a TmpIdentity | // taking care of the case of a single Parameter being used by multiple operators. Create a TmpIdentity | ||||
| // operator for this Parameter, and add an edge for the use of this Parameter by each | // operator for this Parameter, and add an edge for the use of this Parameter by each | ||||
| // subsequent operator; | // subsequent operator; | ||||
| // Step 3.1: Correct the memory calculation for memory reuse | |||||
| // Step 3.1: Calculate memory usage | |||||
| // Step 4: Run the Dynamic Programming algorithm: | // Step 4: Run the Dynamic Programming algorithm: | ||||
| // in this process, cost is calculated based on not only the operators, but also the edges. Here, the edge | // in this process, cost is calculated based on not only the operators, but also the edges. Here, the edge | ||||
| // cost is caused by the redistribution of an operator's output tensor layout to the next operator's input | // cost is caused by the redistribution of an operator's output tensor layout to the next operator's input | ||||
| @@ -867,14 +853,14 @@ Status ParallelStrategySearch(const std::vector<AnfNodePtr> &all_nodes, const Fu | |||||
| MS_LOG(INFO) << "After the augmenting procedure, there are " << entire_costgraph->GetOperators().size() | MS_LOG(INFO) << "After the augmenting procedure, there are " << entire_costgraph->GetOperators().size() | ||||
| << " operators, and " << entire_costgraph->GetNumPairs() << " edges."; | << " operators, and " << entire_costgraph->GetNumPairs() << " edges."; | ||||
| // Step 3.1: Correcting calculation for memory reuse | |||||
| // Step 3.1: Calculate the memory usage | |||||
| if (entire_costgraph->ComputeOpsAndEdgesParameterInvolved() == SUCCESS) { | if (entire_costgraph->ComputeOpsAndEdgesParameterInvolved() == SUCCESS) { | ||||
| // Correcting operators' memory usage | |||||
| if (entire_costgraph->CorrectOpsStrategyCostForMemoryReuse() != SUCCESS) { | |||||
| // Calculate operators' memory usage | |||||
| if (entire_costgraph->CalculateOpsMemoryCost() != SUCCESS) { | |||||
| MS_LOG(EXCEPTION) << "Correcting operators' cost for memory reuse failed."; | MS_LOG(EXCEPTION) << "Correcting operators' cost for memory reuse failed."; | ||||
| } | } | ||||
| // Correcting edges' memory usage | |||||
| if (entire_costgraph->CorrectEdgesStrategyCostForMemoryReuse() != SUCCESS) { | |||||
| // Calculate edges' memory usage | |||||
| if (entire_costgraph->CalculateEdgesMemoryCost() != SUCCESS) { | |||||
| MS_LOG(EXCEPTION) << "Correcting edges' cost for memory reuse failed."; | MS_LOG(EXCEPTION) << "Correcting edges' cost for memory reuse failed."; | ||||
| } | } | ||||
| } else { | } else { | ||||
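The Step 3.1 flow above boils down to three calls on the cost graph. A hedged sketch of the ordering follows; `entire_costgraph` is the module-level cost graph used above, its exact type is not shown in the diff, and the real code throws MS_LOG(EXCEPTION) on failure rather than returning early.

    Status CalculateMemorySketch() {
      if (entire_costgraph->ComputeOpsAndEdgesParameterInvolved() != SUCCESS) {
        return FAILED;
      }
      if (entire_costgraph->CalculateOpsMemoryCost() != SUCCESS) {   // operators' memory usage
        return FAILED;
      }
      return entire_costgraph->CalculateEdgesMemoryCost();           // edges' memory usage
    }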
| @@ -144,7 +144,7 @@ Status TensorRedistribution::ComputeCost() { | |||||
| MS_LOG(ERROR) << "Failure: InferTensorRedistribution failed"; | MS_LOG(ERROR) << "Failure: InferTensorRedistribution failed"; | ||||
| return Status::FAILED; | return Status::FAILED; | ||||
| } | } | ||||
| // Compute redistribution communication cost and memory cost | |||||
| // Compute redistribution communication cost and computation cost | |||||
| for (auto& op_cost : operator_list_) { | for (auto& op_cost : operator_list_) { | ||||
| OperatorR op = op_cost.first; | OperatorR op = op_cost.first; | ||||
| Shape slice_shape = op_cost.second; | Shape slice_shape = op_cost.second; | ||||
| @@ -154,14 +154,14 @@ Status TensorRedistribution::ComputeCost() { | |||||
| if (str == PERMUTE_BY_AXIS) { | if (str == PERMUTE_BY_AXIS) { | ||||
| // The shape does not change after PermuteByAxis operation. | // The shape does not change after PermuteByAxis operation. | ||||
| // communication cost = all_to_all + all_to_all = 2 * slice_shape | // communication cost = all_to_all + all_to_all = 2 * slice_shape | ||||
| // memory cost = slice_shape | |||||
| // computation cost = slice_shape | |||||
| forward_comm_cost_ += prod; | forward_comm_cost_ += prod; | ||||
| backward_comm_cost_ += prod; | backward_comm_cost_ += prod; | ||||
| comm_cost_ += 2.0 * prod; | comm_cost_ += 2.0 * prod; | ||||
| mem_cost_ += prod; | |||||
| computation_cost_ += prod; | |||||
| } else if (str == CONCAT_BY_AXIS) { | } else if (str == CONCAT_BY_AXIS) { | ||||
| // communication cost = all_gather + reduce_scatter = before_slice_shape + after_slice_shape | // communication cost = all_gather + reduce_scatter = before_slice_shape + after_slice_shape | ||||
| // memory cost = before_slice_shape | |||||
| // computation cost = before_slice_shape | |||||
| if (op.second.size() < 3) { | if (op.second.size() < 3) { | ||||
| MS_LOG(ERROR) << "op.second size should not be less than 3!"; | MS_LOG(ERROR) << "op.second size should not be less than 3!"; | ||||
| return Status::FAILED; | return Status::FAILED; | ||||
| @@ -173,22 +173,22 @@ Status TensorRedistribution::ComputeCost() { | |||||
| comm_cost_ += prod * (dev_num + 1.0); | comm_cost_ += prod * (dev_num + 1.0); | ||||
| int32_t concat_dim = op.second[0]; | int32_t concat_dim = op.second[0]; | ||||
| if (concat_dim == 0) { | if (concat_dim == 0) { | ||||
| // memory cost = all_gather | |||||
| mem_cost_ += prod; | |||||
| // computation cost = all_gather | |||||
| computation_cost_ += prod; | |||||
| } else { | } else { | ||||
| // memory cost = all_gather + split + concat | |||||
| mem_cost_ += (prod + prod * dev_num + prod * dev_num); | |||||
| // computation cost = all_gather + split + concat | |||||
| computation_cost_ += (prod + prod * dev_num + prod * dev_num); | |||||
| } | } | ||||
| } else { | } else { | ||||
| // There is only memory cost in SplitByAxis. | |||||
| // memory cost = before_slice_shape | |||||
| mem_cost_ += prod; | |||||
| // There is only computation cost in SplitByAxis. | |||||
| // computation cost = before_slice_shape | |||||
| computation_cost_ += prod; | |||||
| } | } | ||||
| } | } | ||||
| if (reshape_flag()) { | if (reshape_flag()) { | ||||
| Shape prev_slice_shape = from_.slice_shape().array(); | Shape prev_slice_shape = from_.slice_shape().array(); | ||||
| double prev_prod = std::accumulate(prev_slice_shape.begin(), prev_slice_shape.end(), 1, std::multiplies<int>()); | double prev_prod = std::accumulate(prev_slice_shape.begin(), prev_slice_shape.end(), 1, std::multiplies<int>()); | ||||
| mem_cost_ += 2.0 * prev_prod; | |||||
| computation_cost_ += 2.0 * prev_prod; | |||||
| } | } | ||||
| return Status::SUCCESS; | return Status::SUCCESS; | ||||
| } | } | ||||
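To make the accumulation rules above concrete, here is a hypothetical trace for a single redistribution step whose slice element product is prod = 1024 on dev_num = 8 devices (all numbers illustrative, not from the diff):

    // PERMUTE_BY_AXIS:  comm_cost_ += 2 * 1024 = 2048,  computation_cost_ += 1024
    // CONCAT_BY_AXIS:   comm_cost_ += 1024 * (8 + 1.0) = 9216
    //   concat_dim == 0: computation_cost_ += 1024                          (all_gather only)
    //   concat_dim != 0: computation_cost_ += 1024 + 1024*8 + 1024*8 = 17408 (all_gather + split + concat)
    // SplitByAxis:       computation_cost_ += 1024   (no communication term)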
| @@ -41,7 +41,7 @@ class TensorRedistribution { | |||||
| comm_cost_(0.0), | comm_cost_(0.0), | ||||
| forward_comm_cost_(0.0), | forward_comm_cost_(0.0), | ||||
| backward_comm_cost_(0.0), | backward_comm_cost_(0.0), | ||||
| mem_cost_(0.0), | |||||
| computation_cost_(0.0), | |||||
| construct_op_flag_(construct_op_flag), | construct_op_flag_(construct_op_flag), | ||||
| keep_reshape_(keep_reshape) {} | keep_reshape_(keep_reshape) {} | ||||
| Status Init(const TensorLayout& from, const TensorLayout& to, const RankList& dev_list); | Status Init(const TensorLayout& from, const TensorLayout& to, const RankList& dev_list); | ||||
| @@ -51,7 +51,7 @@ class TensorRedistribution { | |||||
| bool reshape_flag() const { return reshape_flag_; } | bool reshape_flag() const { return reshape_flag_; } | ||||
| Status ComputeCost(); | Status ComputeCost(); | ||||
| double comm_cost() const { return comm_cost_; } | double comm_cost() const { return comm_cost_; } | ||||
| double mem_cost() const { return mem_cost_; } | |||||
| double computation_cost() const { return computation_cost_; } | |||||
| double forward_comm_cost() const { return forward_comm_cost_; } | double forward_comm_cost() const { return forward_comm_cost_; } | ||||
| double backward_comm_cost() const { return backward_comm_cost_; } | double backward_comm_cost() const { return backward_comm_cost_; } | ||||
| @@ -66,10 +66,13 @@ class TensorRedistribution { | |||||
| RankList dev_list_; | RankList dev_list_; | ||||
| OperatorList operator_list_; | OperatorList operator_list_; | ||||
| bool reshape_flag_; | bool reshape_flag_; | ||||
| // communication cost | |||||
| double comm_cost_; | double comm_cost_; | ||||
| // forward communication cost | |||||
| double forward_comm_cost_; | double forward_comm_cost_; | ||||
| // backward communication cost | |||||
| double backward_comm_cost_; | double backward_comm_cost_; | ||||
| double mem_cost_; | |||||
| double computation_cost_; | |||||
| bool construct_op_flag_; | bool construct_op_flag_; | ||||
| bool keep_reshape_; | bool keep_reshape_; | ||||
| }; | }; | ||||
| @@ -322,8 +322,8 @@ TEST_F(TestCostGraph, test_SelectCostListWithMinTrainingTimeMultiple) { | |||||
| auto ret_list = entire_cost_graph.SelectCostListWithMinTrainingTimeMultiple(all_list, memory); | auto ret_list = entire_cost_graph.SelectCostListWithMinTrainingTimeMultiple(all_list, memory); | ||||
| ASSERT_EQ(ret_list.size(), 2); | ASSERT_EQ(ret_list.size(), 2); | ||||
| ASSERT_DOUBLE_EQ(ret_list[0]->memory_cost_, 10); | |||||
| ASSERT_DOUBLE_EQ(ret_list[1]->memory_cost_, 1010); | |||||
| ASSERT_DOUBLE_EQ(ret_list[0]->computation_cost_, 10); | |||||
| ASSERT_DOUBLE_EQ(ret_list[1]->computation_cost_, 1010); | |||||
| } | } | ||||
| TEST_F(TestCostGraph, test_CheckOpElimination) { | TEST_F(TestCostGraph, test_CheckOpElimination) { | ||||
| @@ -76,8 +76,8 @@ TEST_F(TestMatMulCost, test_CostGeneration) { | |||||
| mmcost_.SetInputAndOutputTypeLength(inputs_length, outputs_length); | mmcost_.SetInputAndOutputTypeLength(inputs_length, outputs_length); | ||||
| mmcost_.GetForwardCommCost(inputs, outputs, 0); | mmcost_.GetForwardCommCost(inputs, outputs, 0); | ||||
| mmcost_.GetBackwardCommCost(inputs, outputs, 0); | mmcost_.GetBackwardCommCost(inputs, outputs, 0); | ||||
| mmcost_.GetForwardMemoryCost(inputs, outputs, 0); | |||||
| mmcost_.GetBackwardMemoryCost(inputs, outputs, 0); | |||||
| mmcost_.GetForwardComputationCost(inputs, outputs, 0); | |||||
| mmcost_.GetBackwardComputationCost(inputs, outputs, 0); | |||||
| } | } | ||||
| class TestActivationCost : public UT::Common { | class TestActivationCost : public UT::Common { | ||||
| @@ -128,8 +128,8 @@ TEST_F(TestActivationCost, test_CostGeneration) { | |||||
| std::vector<size_t> inputs_length = {4, 4}; | std::vector<size_t> inputs_length = {4, 4}; | ||||
| std::vector<size_t> outputs_length = {4}; | std::vector<size_t> outputs_length = {4}; | ||||
| ac_cost_.SetInputAndOutputTypeLength(inputs_length, outputs_length); | ac_cost_.SetInputAndOutputTypeLength(inputs_length, outputs_length); | ||||
| ac_cost_.GetForwardMemoryCost(inputs, outputs, 0); | |||||
| ac_cost_.GetBackwardMemoryCost(inputs, outputs, 0); | |||||
| ac_cost_.GetForwardComputationCost(inputs, outputs, 0); | |||||
| ac_cost_.GetBackwardComputationCost(inputs, outputs, 0); | |||||
| } | } | ||||
| class TestPReLUCost : public UT::Common { | class TestPReLUCost : public UT::Common { | ||||
| @@ -184,8 +184,8 @@ TEST_F(TestPReLUCost, test_CostGeneration) { | |||||
| prelu_cost_.SetInputAndOutputTypeLength(inputs_length, outputs_length); | prelu_cost_.SetInputAndOutputTypeLength(inputs_length, outputs_length); | ||||
| double BCC, FMC, GMC; | double BCC, FMC, GMC; | ||||
| BCC = prelu_cost_.GetBackwardCommCost(inputs, outputs, 0); | BCC = prelu_cost_.GetBackwardCommCost(inputs, outputs, 0); | ||||
| FMC = prelu_cost_.GetForwardMemoryCost(inputs, outputs, 0); | |||||
| GMC = prelu_cost_.GetBackwardMemoryCost(inputs, outputs, 0); | |||||
| FMC = prelu_cost_.GetForwardComputationCost(inputs, outputs, 0); | |||||
| GMC = prelu_cost_.GetBackwardComputationCost(inputs, outputs, 0); | |||||
| ASSERT_EQ(BCC, 32 * 4); | ASSERT_EQ(BCC, 32 * 4); | ||||
| ASSERT_EQ(FMC, 8 * 32 * 8 * 8 * 4 + 32 * 4); | ASSERT_EQ(FMC, 8 * 32 * 8 * 8 * 4 + 32 * 4); | ||||
| ASSERT_EQ(GMC, 128); | ASSERT_EQ(GMC, 128); | ||||
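Reading the three asserts above back through the renamed cost functions, and assuming the PReLU input slice is {8, 32, 8, 8} and the weight slice is {32} with 4-byte element types (the shapes are inferred from the expected values, so treat them as assumptions):

    // FMC = ListProduct({8, 32, 8, 8}) * 4 + ListProduct({32}) * 4 = 65536 + 128 = 65664
    // GMC = ListProduct({32}) * 4 = 128   (backward computation induced by the weight slice)
    // BCC = 32 * 4 = 128                  (backward communication for the same weight slice)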
| @@ -84,8 +84,8 @@ TEST_F(TestActivation, test_activation_strategies) { | |||||
| act_ptr_->InitForCostModel(sp); | act_ptr_->InitForCostModel(sp); | ||||
| std::vector<TensorInfo> inputs_info = act_ptr_->inputs_tensor_info(); | std::vector<TensorInfo> inputs_info = act_ptr_->inputs_tensor_info(); | ||||
| std::vector<TensorInfo> outputs_info = act_ptr_->outputs_tensor_info(); | std::vector<TensorInfo> outputs_info = act_ptr_->outputs_tensor_info(); | ||||
| ASSERT_DOUBLE_EQ(act_ptr_->GetOperatorCost()->GetMemoryCost(inputs_info, outputs_info, sp->GetInputStage()), | |||||
| cost.memory_cost_); | |||||
| ASSERT_DOUBLE_EQ(act_ptr_->GetOperatorCost()->GetComputationCost(inputs_info, outputs_info, sp->GetInputStage()), | |||||
| cost.computation_cost_); | |||||
| ASSERT_DOUBLE_EQ(act_ptr_->GetOperatorCost()->GetCommCost(inputs_info, outputs_info, sp->GetInputStage()), | ASSERT_DOUBLE_EQ(act_ptr_->GetOperatorCost()->GetCommCost(inputs_info, outputs_info, sp->GetInputStage()), | ||||
| cost.communication_cost_); | cost.communication_cost_); | ||||
| } | } | ||||
| @@ -109,8 +109,8 @@ TEST_F(TestActivation, test_softmax_strategies) { | |||||
| soft_ptr_->InitForCostModel(sp); | soft_ptr_->InitForCostModel(sp); | ||||
| std::vector<TensorInfo> inputs_info = soft_ptr_->inputs_tensor_info(); | std::vector<TensorInfo> inputs_info = soft_ptr_->inputs_tensor_info(); | ||||
| std::vector<TensorInfo> outputs_info = soft_ptr_->outputs_tensor_info(); | std::vector<TensorInfo> outputs_info = soft_ptr_->outputs_tensor_info(); | ||||
| ASSERT_DOUBLE_EQ(soft_ptr_->GetOperatorCost()->GetMemoryCost(inputs_info, outputs_info, sp->GetInputStage()), | |||||
| cost.memory_cost_); | |||||
| ASSERT_DOUBLE_EQ(soft_ptr_->GetOperatorCost()->GetComputationCost(inputs_info, outputs_info, sp->GetInputStage()), | |||||
| cost.computation_cost_); | |||||
| ASSERT_DOUBLE_EQ(soft_ptr_->GetOperatorCost()->GetCommCost(inputs_info, outputs_info, sp->GetInputStage()), | ASSERT_DOUBLE_EQ(soft_ptr_->GetOperatorCost()->GetCommCost(inputs_info, outputs_info, sp->GetInputStage()), | ||||
| cost.communication_cost_); | cost.communication_cost_); | ||||
| } | } | ||||
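Both activation hunks above, like the matmul and identity hunks below, follow the same pattern: recompute the computation and communication cost from the operator's cost model and compare it with the cost recorded for the candidate strategy. A condensed sketch of that pattern, where only the accessor and field names come from this diff and the rest is assumed test scaffolding:

// Condensed sketch of the repeated assertion pattern; op, sp and cost are
// assumed to be provided by the surrounding gtest fixture.
#include <gtest/gtest.h>
#include <vector>

template <typename OperatorInfoPtr, typename StrategyPtr, typename Cost>
void CheckStrategyCost(const OperatorInfoPtr& op, const StrategyPtr& sp, const Cost& cost) {
  auto inputs_info = op->inputs_tensor_info();
  auto outputs_info = op->outputs_tensor_info();
  // The recomputed costs must match what strategy generation recorded.
  ASSERT_DOUBLE_EQ(
      op->GetOperatorCost()->GetComputationCost(inputs_info, outputs_info, sp->GetInputStage()),
      cost.computation_cost_);
  ASSERT_DOUBLE_EQ(
      op->GetOperatorCost()->GetCommCost(inputs_info, outputs_info, sp->GetInputStage()),
      cost.communication_cost_);
}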
| @@ -569,8 +569,8 @@ TEST_F(TestMatmulInfo, test_GenerateStrategies1) { | |||||
| matmul1->InitForCostModel(sp); | matmul1->InitForCostModel(sp); | ||||
| std::vector<TensorInfo> inputs_info = matmul1->inputs_tensor_info(); | std::vector<TensorInfo> inputs_info = matmul1->inputs_tensor_info(); | ||||
| std::vector<TensorInfo> outputs_info = matmul1->outputs_tensor_info(); | std::vector<TensorInfo> outputs_info = matmul1->outputs_tensor_info(); | ||||
| ASSERT_DOUBLE_EQ(matmul1->GetOperatorCost()->GetMemoryCost(inputs_info, outputs_info, sp->GetInputStage()), | |||||
| cost.memory_cost_); | |||||
| ASSERT_DOUBLE_EQ(matmul1->GetOperatorCost()->GetComputationCost(inputs_info, outputs_info, sp->GetInputStage()), | |||||
| cost.computation_cost_); | |||||
| break; | break; | ||||
| } | } | ||||
| } | } | ||||
| @@ -599,8 +599,8 @@ TEST_F(TestMatmulInfo, test_GenerateStrategies2) { | |||||
| TensorInfo replica_input1_info(tly, input1_shape, input1_slice_shape); | TensorInfo replica_input1_info(tly, input1_shape, input1_slice_shape); | ||||
| replica_inputs_info.push_back(replica_input1_info); | replica_inputs_info.push_back(replica_input1_info); | ||||
| ASSERT_DOUBLE_EQ(matmul3->GetOperatorCost()->GetMemoryCost(replica_inputs_info, outputs_info, sp->GetInputStage()), | |||||
| cost.memory_cost_); | |||||
| ASSERT_DOUBLE_EQ(matmul3->GetOperatorCost()->GetComputationCost(replica_inputs_info, outputs_info, sp->GetInputStage()), | |||||
| cost.computation_cost_); | |||||
| break; | break; | ||||
| } | } | ||||
| } | } | ||||
| @@ -188,8 +188,8 @@ TEST_F(TestTensorAddInfo, GenerateStrategies) { | |||||
| tensor_add->InitForCostModel(sp); | tensor_add->InitForCostModel(sp); | ||||
| std::vector<TensorInfo> inputs_info = tensor_add->inputs_tensor_info(); | std::vector<TensorInfo> inputs_info = tensor_add->inputs_tensor_info(); | ||||
| std::vector<TensorInfo> outputs_info = tensor_add->outputs_tensor_info(); | std::vector<TensorInfo> outputs_info = tensor_add->outputs_tensor_info(); | ||||
| double memory_cost0 = tensor_add->GetOperatorCost()->GetMemoryCost(inputs_info, outputs_info, sp->GetInputStage()); | |||||
| double memory_cost1 = cost.memory_cost_; | |||||
| double memory_cost0 = tensor_add->GetOperatorCost()->GetComputationCost(inputs_info, outputs_info, sp->GetInputStage()); | |||||
| double memory_cost1 = cost.computation_cost_; | |||||
| bool memory = memory_cost0 - memory_cost1 <= 1.0; | bool memory = memory_cost0 - memory_cost1 <= 1.0; | ||||
| double comm_cost0 = tensor_add->GetOperatorCost()->GetCommCost(inputs_info, outputs_info, sp->GetInputStage()); | double comm_cost0 = tensor_add->GetOperatorCost()->GetCommCost(inputs_info, outputs_info, sp->GetInputStage()); | ||||
| @@ -210,8 +210,8 @@ TEST_F(TestTensorAddInfo, GenerateStrategies1) { | |||||
| tensor_add1->InitForCostModel(sp); | tensor_add1->InitForCostModel(sp); | ||||
| std::vector<TensorInfo> inputs_info = tensor_add1->inputs_tensor_info(); | std::vector<TensorInfo> inputs_info = tensor_add1->inputs_tensor_info(); | ||||
| std::vector<TensorInfo> outputs_info = tensor_add1->outputs_tensor_info(); | std::vector<TensorInfo> outputs_info = tensor_add1->outputs_tensor_info(); | ||||
| double memory_cost0 = tensor_add1->GetOperatorCost()->GetMemoryCost(inputs_info, outputs_info, sp->GetInputStage()); | |||||
| double memory_cost1 = cost.memory_cost_; | |||||
| double memory_cost0 = tensor_add1->GetOperatorCost()->GetComputationCost(inputs_info, outputs_info, sp->GetInputStage()); | |||||
| double memory_cost1 = cost.computation_cost_; | |||||
| bool memory = memory_cost0 - memory_cost1 <= 1.0; | bool memory = memory_cost0 - memory_cost1 <= 1.0; | ||||
| double comm_cost0 = tensor_add1->GetOperatorCost()->GetCommCost(inputs_info, outputs_info, sp->GetInputStage()); | double comm_cost0 = tensor_add1->GetOperatorCost()->GetCommCost(inputs_info, outputs_info, sp->GetInputStage()); | ||||
| @@ -145,8 +145,8 @@ TEST_F(TestTmpIdentityInfo, test_generate_strategies) { | |||||
| identity_ptr->Init(sp); | identity_ptr->Init(sp); | ||||
| std::vector<TensorInfo> inputs_info = identity_ptr->inputs_tensor_info(); | std::vector<TensorInfo> inputs_info = identity_ptr->inputs_tensor_info(); | ||||
| std::vector<TensorInfo> outputs_info = identity_ptr->outputs_tensor_info(); | std::vector<TensorInfo> outputs_info = identity_ptr->outputs_tensor_info(); | ||||
| ASSERT_DOUBLE_EQ(identity_ptr->GetOperatorCost()->GetMemoryCost(inputs_info, outputs_info, sp->GetInputStage()), | |||||
| cost.memory_cost_); | |||||
| ASSERT_DOUBLE_EQ(identity_ptr->GetOperatorCost()->GetComputationCost(inputs_info, outputs_info, sp->GetInputStage()), | |||||
| cost.computation_cost_); | |||||
| ASSERT_DOUBLE_EQ(identity_ptr->GetOperatorCost()->GetCommCost(inputs_info, outputs_info, sp->GetInputStage()), | ASSERT_DOUBLE_EQ(identity_ptr->GetOperatorCost()->GetCommCost(inputs_info, outputs_info, sp->GetInputStage()), | ||||
| cost.communication_cost_); | cost.communication_cost_); | ||||
| } | } | ||||
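Unlike the exact ASSERT_DOUBLE_EQ checks above, the two TensorAdd hunks only require the recomputed computation cost not to exceed the recorded one by more than 1.0. A tiny helper capturing that check (the 1.0 slack comes from the diff; the helper itself is an assumption, not part of the PR):

// Sketch of the tolerance check used in the TensorAdd tests: the recomputed
// cost may exceed the recorded value by at most `slack`.
inline bool CostWithinSlack(double recomputed, double recorded, double slack = 1.0) {
  return recomputed - recorded <= slack;
}

// Usage mirroring the diff (the final assertion on `memory` is outside this hunk):
//   bool memory = CostWithinSlack(memory_cost0, memory_cost1);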