
!859 [Auto parallel] Support searching strategy for inference phase

Merge pull request !859 from Xiaoda/support-inferring-phase-parallel-strategy-searching
tags/v0.3.0-alpha
mindspore-ci-bot · 5 years ago
commit 6351f9c837
15 changed files with 255 additions and 45 deletions
  1. +14  -4   mindspore/ccsrc/parallel/auto_parallel/costmodel.cc
  2. +8   -3   mindspore/ccsrc/parallel/auto_parallel/costmodel.h
  3. +13  -6   mindspore/ccsrc/parallel/auto_parallel/edge_costmodel.cc
  4. +123 -25  mindspore/ccsrc/parallel/auto_parallel/graph_costmodel.cc
  5. +6   -1   mindspore/ccsrc/parallel/auto_parallel/graph_costmodel.h
  6. +3   -0   mindspore/ccsrc/parallel/costmodel_context.cc
  7. +6   -0   mindspore/ccsrc/parallel/costmodel_context.h
  8. +1   -0   mindspore/ccsrc/parallel/ops_info/matmul_info.cc
  9. +1   -0   mindspore/ccsrc/parallel/ops_info/operator_info.cc
  10. +3  -3   mindspore/ccsrc/parallel/tensor_layout/tensor_redistribution.h
  11. +2  -0   mindspore/ccsrc/pipeline/init.cc
  12. +32 -2   mindspore/parallel/_cost_model_context.py
  13. +1  -1   tests/ut/cpp/parallel/auto_parallel/graph_costmodel_test.cc
  14. +6  -0   tests/ut/python/parallel/__init__.py
  15. +36 -0   tests/ut/python/parallel/test_auto_parallel_inference.py

mindspore/ccsrc/parallel/auto_parallel/costmodel.cc  (+14 -4)

@@ -23,8 +23,17 @@
 namespace mindspore {
 namespace parallel {
 void Simplify(CostPtrList *clist_ptrs) {
-  // Sort the cost_list with the computation_cost_ increasing, and communication_cost decreasing order. This method
-  // excludes the cost with greater computation_cost_ and greater communication_cost.
+  if (RUN_PHASE == TRAINING_PHASE) {
+    // training phase
+    SimplifyForDecreasingCommunicationWithPartialPara(clist_ptrs);
+  } else {
+    // inference phase
+    SimplifyForDecreasingCommunicationForward(clist_ptrs);
+  }
+}
+void SimplifyForDecreasingCommunicationForward(CostPtrList *clist_ptrs) {
+  // Sort the cost_list with the computation_cost_ increasing, and communication_forward decreasing order. This method
+  // excludes the cost with greater computation_cost_ and greater communication_forward.
   // E.g. clist_ptrs = {<100, 20>, <200, 10>, <300, 50>}. After this method, clist_ptrs = {<200, 10>, <100, 20>}
   if (!COST_MODEL_SIMPLIFY_CALCULATION) {
     return;
@@ -37,14 +46,15 @@ void Simplify(CostPtrList *clist_ptrs) {
   });
   CostPtrList ret;
   for (size_t i = 0; i < clist_ptrs->size(); ++i) {
-    if ((ret.size() == size_t(0)) || (clist_ptrs->at(id[i])->communication_cost_ < ret.back()->communication_cost_)) {
+    if ((ret.size() == size_t(0)) ||
+        (clist_ptrs->at(id[i])->communication_forward_ < ret.back()->communication_forward_)) {
       ret.emplace_back(std::move(clist_ptrs->at(id[i])));
     }
   }
   *clist_ptrs = std::move(ret);
 }

-void SimplifyForDreasingCommunicationWithPartialPara(CostPtrList *clist_ptrs) {
+void SimplifyForDecreasingCommunicationWithPartialPara(CostPtrList *clist_ptrs) {
   // Sort the cost_list with the computation_cost_ increasing, and communication_with_partial_para_cost decreasing
   // order. This method excludes the cost with greater computation_cost_ and greater communication_without_para_cost.
   if (!COST_MODEL_SIMPLIFY_CALCULATION) {
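The two Simplify variants above apply the same Pareto-style pruning and differ only in which communication field they compare: communication_with_partial_para_ in the training phase and the new communication_forward_ in the inference phase. A minimal Python sketch of the pruning rule, using hypothetical (computation, communication) pairs instead of the real CostPtr objects:

# Keep only non-dominated costs: after sorting by computation cost (ascending),
# a candidate survives only if its communication cost is smaller than that of
# the last candidate kept so far.
def simplify(cost_list):
    ordered = sorted(cost_list, key=lambda cost: cost[0])
    kept = []
    for cost in ordered:
        if not kept or cost[1] < kept[-1][1]:
            kept.append(cost)
    return kept

# The example from the comment above: <300, 50> is dominated and removed,
# while <100, 20> and <200, 10> survive.
print(simplify([(100, 20), (200, 10), (300, 50)]))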


mindspore/ccsrc/parallel/auto_parallel/costmodel.h  (+8 -3)

@@ -51,18 +51,22 @@ struct Cost {
     communication_with_partial_para_ = 0.0;
     communication_redis_forward_ = 0.0;
     communication_redis_backward_ = 0.0;
+    communication_forward_ = 0.0;
   }
   // 'memory_with_reuse_' calculates the peak memory usage in a training phase
   double memory_with_reuse_;
-  // 'computation_cost_' models the training time of an iteration in a training phase
+  // 'computation_cost_' models the training time of an iteration in a training phase. Currently, this is calculated
+  // by ONLY forward phase
   double computation_cost_;
-  // 'communication_cost_' includes communications from operators (forward and backward) and edges
+  // 'communication_cost_' includes communications from operators (forward and backward) and edges (redistribution)
   double communication_cost_;
   // communication_without_parameter_ = communication_cost_ - (backward communication from operators)
   double communication_without_parameter_;
   // communication_with_partial_para_ =
   // communication_without_parameter_ + COST_MODEL_GAMMA * (communication_cost_ - communication_without_parameter_ )
   double communication_with_partial_para_;
+  // communication_forward_ = communication cost from operators (only forward phase) and forward redistribution.
+  double communication_forward_;
   double communication_redis_forward_;
   double communication_redis_backward_;
   std::shared_ptr<Decision> decision_ptr_;
@@ -296,7 +300,8 @@ using FinalDecisionPtr = std::shared_ptr<FinalDecision>;
 using FinalSingleDecisionPtr = std::shared_ptr<FinalSingleDecision>;

 void Simplify(CostPtrList *clist);
-void SimplifyForDreasingCommunicationWithPartialPara(CostPtrList *clist);
+void SimplifyForDecreasingCommunicationForward(CostPtrList *clist);
+void SimplifyForDecreasingCommunicationWithPartialPara(CostPtrList *clist);
 void RefineForPracticalCost(const CostPtr &, bool is_redistribution);
 } // namespace parallel
 } // namespace mindspore
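Taken together, the comments above spell out how the communication fields relate to each other. A small Python illustration of that bookkeeping, with made-up numbers (the names mirror the C++ struct members; the gamma value is illustrative):

COST_MODEL_GAMMA = 0.1                   # illustrative value, configured via costmodel_gamma in practice
communication_cost = 12.0                # forward + backward communication of operators and edges
communication_without_parameter = 7.0    # communication_cost_ minus backward communication from operators

# communication_with_partial_para_, as defined in the comment above
communication_with_partial_para = (communication_without_parameter +
                                   COST_MODEL_GAMMA * (communication_cost - communication_without_parameter))
print(communication_with_partial_para)   # 7.5

# communication_forward_ (added by this PR) counts only forward-phase operator
# communication plus forward redistribution; it is the quantity the inference
# phase minimizes, and it is zero-initialized like the other fields.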


mindspore/ccsrc/parallel/auto_parallel/edge_costmodel.cc  (+13 -6)

@@ -76,6 +76,7 @@ Status Edge::InitEdgeCost() {
                 << ", communication_with_partial_para_: " << cost->communication_with_partial_para_ << ".";
       // refine communication cost calculation for practice
       RefineForPracticalCost(cost, true);
+      cost->communication_forward_ = cost->communication_redis_forward_;
       CostPtrKey ck = {target_output_str, target_input_str};
       CostPtrList cl;
       cl.push_back(cost);
@@ -160,8 +161,9 @@ CostPtrList Edge::CreateEdgeEliminationCostList(const StrategyPtr &output_st_ptr
   (void)std::transform(edges.begin(), edges.end(), all_cost_list.begin(), LocalGetCostList);

   CostPtrList selected_cost_list(all_cost_list.size(), nullptr);
-  std::function<void(size_t, double, double, double, double)> recursive =
-    [&](size_t k, double computation, double memory, double communication, double communication_without_para) {
+  std::function<void(size_t, double, double, double, double, double)> recursive =
+    [&](size_t k, double computation, double memory, double communication, double communication_without_para,
+        double communication_forward) {
       if (k == edges.size()) {
         auto decision = std::make_shared<EdgeEliminationDecision>(selected_cost_list);
         CostPtr new_cost = std::make_shared<Cost>(computation, communication);
@@ -170,6 +172,7 @@ CostPtrList Edge::CreateEdgeEliminationCostList(const StrategyPtr &output_st_ptr
         new_cost->communication_with_partial_para_ =
           communication_without_para + COST_MODEL_GAMMA * (communication - communication_without_para);
         new_cost->memory_with_reuse_ = memory;
+        new_cost->communication_forward_ = communication_forward;
         new_cost->decision_ptr_ = decision;
         result.push_back(new_cost);
         return;
@@ -179,11 +182,12 @@ CostPtrList Edge::CreateEdgeEliminationCostList(const StrategyPtr &output_st_ptr
         selected_cost_list[k] = c;
         recursive(k + 1, computation + c->computation_cost_, memory + c->memory_with_reuse_,
                   communication + c->communication_cost_,
-                  communication_without_para + c->communication_without_parameter_);
+                  communication_without_para + c->communication_without_parameter_,
+                  communication_forward + c->communication_forward_);
       }
     };
-  recursive(0, 0.0, 0.0, 0.0, 0.0);
-  SimplifyForDreasingCommunicationWithPartialPara(&result);
+  recursive(0, 0.0, 0.0, 0.0, 0.0, 0.0);
+  Simplify(&result);
   return result;
 }

@@ -219,6 +223,8 @@ void Edge::CreateOpEliminationSubCostList(StrategyPtr op_strategy, const CostPtr
     left_cost->computation_cost_ + middle_cost->computation_cost_ + right_cost->computation_cost_;
   double communication =
     left_cost->communication_cost_ + middle_cost->communication_cost_ + right_cost->communication_cost_;
+  double communication_forward =
+    left_cost->communication_forward_ + middle_cost->communication_forward_ + right_cost->communication_forward_;
   double communication_without_para = left_cost->communication_without_parameter_ +
                                       middle_cost->communication_without_parameter_ +
                                       right_cost->communication_without_parameter_;
@@ -232,6 +238,7 @@ void Edge::CreateOpEliminationSubCostList(StrategyPtr op_strategy, const CostPtr
   cost->communication_with_partial_para_ =
     communication_without_para + COST_MODEL_GAMMA * (communication - communication_without_para);
   cost->memory_with_reuse_ = memory_cost;
+  cost->communication_forward_ = communication_forward;
   ret_cost_list->emplace_back(std::move(cost));
 }
 }
@@ -251,7 +258,7 @@ CostPtrList Edge::CreateOpEliminationCostList(const EdgePtr &e1, const StrategyP
     CreateOpEliminationSubCostList(middle_strategy, e1->GetCostList(output_st_ptr, middle_strategy),
                                    op_strategy->cost_list, e2->GetCostList(middle_strategy, input_st_ptr), &result);
   }
-  SimplifyForDreasingCommunicationWithPartialPara(&result);
+  Simplify(&result);
   return result;
 }
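Every elimination above combines child costs the same way: the additive fields (computation, memory, communication, and now communication_forward_) are summed component-wise before the GAMMA adjustment is applied. A toy Python illustration of that accumulation, with dictionaries standing in for the Cost objects:

def combine(*costs):
    # Sum each additive field across the costs being merged (left edge, middle op, right edge, ...).
    fields = ("computation", "memory", "communication", "communication_forward")
    return {field: sum(cost[field] for cost in costs) for field in fields}

left = {"computation": 2.0, "memory": 1.0, "communication": 0.5, "communication_forward": 0.2}
middle = {"computation": 3.0, "memory": 2.0, "communication": 0.0, "communication_forward": 0.0}
right = {"computation": 1.0, "memory": 1.0, "communication": 0.3, "communication_forward": 0.3}
print(combine(left, middle, right))  # communication_forward is propagated like every other field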




mindspore/ccsrc/parallel/auto_parallel/graph_costmodel.cc  (+123 -25)

@@ -38,6 +38,8 @@ bool TENSOR_SLICE_ALIGNMENT_ENABLE = DEFAULT_TENSOR_SLICE_ALIGNMENT_ENABLE;
 size_t TENSOR_SLICE_ALIGNMENT_SIZE = DEFAULT_TENSOR_SLICE_ALIGNMENT_SIZE;
 bool FULLY_USE_DEVICES = DEFAULT_FULLY_USE_DEVICES;
 bool ELEMENTWISE_OP_STRA_FOLLOW = DEFAULT_ELEMENTWISE_OP_STRA_FOLLOW;
+bool MULTI_SUBGRAPHS = DEFAULT_IS_MULTI_SUBGRAPHS;
+int32_t RUN_PHASE = DEFAULT_RUN_PHASE;

 void CostGraph::SetDeviceMemoryAndCostParameter() {
   MS_EXCEPTION_IF_NULL(CostModelContext::GetInstance());
@@ -142,6 +144,23 @@ void CostGraph::SetDeviceMemoryAndCostParameter() {
   } else {
     MS_LOG(INFO) << "elementwise_op_strategy_follow: false.";
   }
+
+  // MULTI_SUBGRAPHS
+  auto multi_subgraphs = CostModelContext::GetInstance()->is_multi_subgraphs();
+  MULTI_SUBGRAPHS = multi_subgraphs;
+  if (MULTI_SUBGRAPHS) {
+    MS_LOG(INFO) << "multi_subgraphs: true.";
+  } else {
+    MS_LOG(INFO) << "multi_subgraphs: false.";
+  }
+
+  // RUN_PHASE
+  auto phase = CostModelContext::GetInstance()->run_phase();
+  if (phase != 0 && phase != 1) {
+    MS_LOG(EXCEPTION) << "'run_phase' must be in {0, 1}";
+  }
+  RUN_PHASE = phase;
+  MS_LOG(INFO) << "run_phase: " << RUN_PHASE << ".";
 }

 void CostGraph::RemoveOperator(const OperatorInfoPtr &op) {
@@ -249,19 +268,21 @@ CostPtrList CostGraph::CreateFinalCostList(const OperatorInfoPtr &u, const std::
         MS_EXCEPTION_IF_NULL(cost3);
         double computation = cost1->computation_cost_ + cost2->computation_cost_ + cost3->computation_cost_;
         double memory = cost1->memory_with_reuse_ + cost2->memory_with_reuse_ + cost3->memory_with_reuse_;
-        double commmunication =
-          cost1->communication_cost_ + cost2->communication_cost_ + cost3->communication_cost_;
+        double communication = cost1->communication_cost_ + cost2->communication_cost_ + cost3->communication_cost_;
+        double communication_forward =
+          cost1->communication_forward_ + cost2->communication_forward_ + cost3->communication_forward_;
         double communication_without_para = cost1->communication_without_parameter_ +
                                             cost2->communication_without_parameter_ +
                                             cost3->communication_without_parameter_;
         auto decision =
           std::make_shared<FinalDecision>(u_strategy->strategy_ptr, v_strategy->strategy_ptr, cost1, cost2, cost3);
-        auto cost = std::make_shared<Cost>(computation, commmunication, decision);
+        auto cost = std::make_shared<Cost>(computation, communication, decision);
         MS_EXCEPTION_IF_NULL(cost);
         cost->communication_without_parameter_ = communication_without_para;
         cost->communication_with_partial_para_ =
-          communication_without_para + COST_MODEL_GAMMA * (commmunication - communication_without_para);
+          communication_without_para + COST_MODEL_GAMMA * (communication - communication_without_para);
         cost->memory_with_reuse_ = memory;
+        cost->communication_forward_ = communication_forward;
         ret.push_back(cost);
       }
     }
@@ -269,7 +290,7 @@
     }
   }

-  SimplifyForDreasingCommunicationWithPartialPara(&ret);
+  Simplify(&ret);
   return ret;
 }

@@ -291,32 +312,67 @@ CostPtrList CostGraph::CreateFinalSingleCostList(const OperatorInfoPtr &u) {
         cost1->communication_without_parameter_ +
         COST_MODEL_GAMMA * (cost1->communication_cost_ - cost1->communication_without_parameter_);
       new_cost->memory_with_reuse_ = cost1->memory_with_reuse_;
+      new_cost->communication_forward_ = cost1->communication_forward_;
       ret.push_back(new_cost);
     }
   }

-  SimplifyForDreasingCommunicationWithPartialPara(&ret);
+  Simplify(&ret);
   return ret;
 }

-CostPtr CostGraph::SelectCostWithMemoryConstraint(const CostPtrList &cost_list, double memory) {
+CostPtr CostGraph::SelectCostWithMinInferenceTime(const CostPtrList &cost_list, double memory) {
+  // Select the cost with minimum inference time. Currently, the inference time is modeled as =
+  // costmodel_alpha_ * computation_cost + costmodel_beta_ * communication_forward_
+  if (cost_list.empty()) {
+    MS_LOG(ERROR) << "Final cost list is null.";
+    return nullptr;
+  }
   CostPtrList after_mem_filter;
-  // Filter out the valid costs
+  double minimum_memory = DBL_MAX;
+  // Filter out the valid costs.
   for (auto &a_cost : cost_list) {
     if (a_cost->memory_with_reuse_ <= memory) {
       after_mem_filter.emplace_back(std::move(a_cost));
+    } else if (a_cost->memory_with_reuse_ < minimum_memory) {
+      minimum_memory = a_cost->memory_with_reuse_;
     }
   }
+  if (after_mem_filter.empty()) {
+    MS_LOG(ERROR) << "No available cost. The minimum memory cost is: " << minimum_memory
+                  << ", the memory capacity is: " << memory << ".";
+    return nullptr;
+  }
+  // Init the returned value with first cost.
+  CostPtr ret = after_mem_filter[0];

-  std::function<CostPtr(CostPtr, const CostPtr &)> LocalCompare = [&](CostPtr init, const CostPtr &cost_x) {
-    MS_EXCEPTION_IF_NULL(cost_x);
-    if (init == nullptr || cost_x->computation_cost_ < memory) {
-      init = cost_x;
-    }
-    return init;
-  };
-  CostPtr ret = nullptr;
-  return std::accumulate(after_mem_filter.begin(), after_mem_filter.end(), ret, LocalCompare);
+  double minimum = costmodel_alpha_ * ret->computation_cost_ + costmodel_beta_ * ret->communication_forward_;
+  MS_LOG(INFO) << "Cost 0: "
+               << "memory_cost: " << ret->memory_with_reuse_ << ", computation_cost_: " << ret->computation_cost_
+               << ", communication_forward_: " << ret->communication_forward_
+               << ", communication_with_partial_para_: " << ret->communication_with_partial_para_
+               << ", communication_cost_: " << ret->communication_cost_
+               << ", communication_without_parameter_: " << ret->communication_without_parameter_ << ".";
+  MS_LOG(INFO) << "Cost 0: totoal_cost: " << minimum;
+  for (size_t i = 1; i < after_mem_filter.size(); ++i) {
+    MS_EXCEPTION_IF_NULL(after_mem_filter[i]);
+    MS_LOG(INFO) << "Cost " << i << ": memory_cost: " << after_mem_filter[i]->memory_with_reuse_
+                 << ", computation_cost_: " << after_mem_filter[i]->computation_cost_
+                 << ", communication_forward_: " << after_mem_filter[i]->communication_forward_
+                 << ", communication_with_partial_para_: " << after_mem_filter[i]->communication_with_partial_para_
+                 << ", communication_cost_: " << after_mem_filter[i]->communication_cost_
+                 << ", communication_without_parameter_: " << after_mem_filter[i]->communication_without_parameter_
+                 << ".";
+    auto tmp = costmodel_alpha_ * after_mem_filter[i]->computation_cost_ +
+               costmodel_beta_ * after_mem_filter[i]->communication_forward_;
+    MS_LOG(INFO) << "Cost " << i << ": total_cost: " << tmp;
+    if (minimum > tmp) {
+      minimum = tmp;
+      ret = after_mem_filter[i];
+      MS_LOG(INFO) << "Selected: " << i;
+    }
+  }
+  return ret;
 }

 CostPtr CostGraph::SelectCostWithMinTrainingTime(const CostPtrList &cost_list, double memory) {
@@ -524,12 +580,26 @@ Status CostGraph::SearchStrategy() {
   });

   if (alive_ops.size() > 2) {
-    return SearchStrategyForMultiNodeFinalGraph(alive_ops);
+    if (RUN_PHASE == TRAINING_PHASE) {
+      // training phase
+      return SearchStrategyForMultiNodeFinalGraph(alive_ops);
+    } else {
+      // inference phase
+      MS_LOG(EXCEPTION)
+        << "Currently, searching strategy for the multi-node final graph in inference phase is not supported.";
+    }
   } else if (alive_ops.size() == 1) {
     MS_LOG(INFO) << "There are 1 single node in the final graph.";
     OperatorInfoPtr u = alive_ops[0];
     auto cost_list = CreateFinalSingleCostList(u);
-    auto cost = SelectCostWithMinTrainingTime(cost_list, dev_memory_);
+    CostPtr cost = nullptr;
+    if (RUN_PHASE == TRAINING_PHASE) {
+      // training phase
+      cost = SelectCostWithMinTrainingTime(cost_list, dev_memory_);
+    } else {
+      // inference phase
+      cost = SelectCostWithMinInferenceTime(cost_list, dev_memory_);
+    }
     if (cost == nullptr) {
       MS_LOG(ERROR) << "No vaild strategy can be found under the current device memory: " << dev_memory_ << ".";
       return FAILED;
@@ -575,7 +645,15 @@ Status CostGraph::SearchStrategy() {
     auto cost_list = one_component->CreateFinalSingleCostList(one_component->GetOperators()[0]);
     all_list.push_back(cost_list);
   }
-  auto selected_cost_list = SelectCostListWithMinTrainingTimeMultiple(all_list, dev_memory_);
+  CostPtrList selected_cost_list;
+  if (RUN_PHASE == TRAINING_PHASE) {
+    // training phase
+    selected_cost_list = SelectCostListWithMinTrainingTimeMultiple(all_list, dev_memory_);
+  } else {
+    // inference phase
+    MS_LOG(EXCEPTION) << "Currently, searching strategy for the two-separated-node final graph in the inference "
+                         "phase is not supported.";
+  }
   for (size_t k = 0; k < selected_cost_list.size(); ++k) {
     auto selected_cost = selected_cost_list[k];
     if (selected_cost == nullptr) {
@@ -601,7 +679,14 @@ Status CostGraph::SearchStrategy() {
     auto e = u->GetAliveSuccEdges()[0];
     MS_EXCEPTION_IF_NULL(e);
     auto cost_list = CreateFinalCostList(u, e, v);
-    auto cost = SelectCostWithMinTrainingTime(cost_list, dev_memory_);
+    CostPtr cost = nullptr;
+    if (RUN_PHASE == TRAINING_PHASE) {
+      // training phase
+      cost = SelectCostWithMinTrainingTime(cost_list, dev_memory_);
+    } else {
+      MS_LOG(EXCEPTION) << "Currently, searching strategy for the two-connected-node final graph in the inference "
+                           "phase is not supported.";
+    }
     if (cost == nullptr) {
       MS_LOG(ERROR) << "No vaild strategy can be found under the current device memory: " << dev_memory_ << ".";
       return FAILED;
@@ -841,6 +926,8 @@ void CostGraph::CreateMergeEliminationSubCostList(StrategyPtr op_strategy, const
   double memory = op_cost->memory_with_reuse_ + edge_cost->memory_with_reuse_ + tar_cost->memory_with_reuse_;
   double communication =
     op_cost->communication_cost_ + edge_cost->communication_cost_ + tar_cost->communication_cost_;
+  double communication_forward =
+    op_cost->communication_forward_ + edge_cost->communication_forward_ + tar_cost->communication_forward_;
   double communication_without_para = op_cost->communication_without_parameter_ +
                                       edge_cost->communication_without_parameter_ +
                                       tar_cost->communication_without_parameter_;
@@ -853,6 +940,7 @@ void CostGraph::CreateMergeEliminationSubCostList(StrategyPtr op_strategy, const
   new_cost->communication_with_partial_para_ =
     communication_without_para + COST_MODEL_GAMMA * (communication - communication_without_para);
   new_cost->memory_with_reuse_ = memory;
+  new_cost->communication_forward_ = communication_forward;
   MS_EXCEPTION_IF_NULL(tar_cost_list_new);
   tar_cost_list_new->emplace_back(std::move(new_cost));
 }
@@ -885,7 +973,7 @@ OperatorInfoPtr CostGraph::EliminationMerge(const OperatorInfoPtr &op) {

     CreateMergeEliminationSubCostList(op_stra, op_clist, edge_clist, tar_stra, tar_clist_origin, &tar_clist_new);
   }
-  SimplifyForDreasingCommunicationWithPartialPara(&tar_clist_new);
+  Simplify(&tar_clist_new);
   // Set the new costlist w.r.t the strategy
   tar_stra_cost->cost_list = tar_clist_new;
   if ((!valid) && (!tar_clist_new.empty())) {
@@ -922,6 +1010,8 @@ void CostGraph::CreateContractEliminationSubCostList(StrategyPtr contract_op_str
     contract_op_cost->memory_with_reuse_ + edge_cost->memory_with_reuse_ + tar_cost->memory_with_reuse_;
   double communication =
     contract_op_cost->communication_cost_ + edge_cost->communication_cost_ + tar_cost->communication_cost_;
+  double communication_forward = contract_op_cost->communication_forward_ + edge_cost->communication_forward_ +
+                                 tar_cost->communication_forward_;
   double communication_without_para = contract_op_cost->communication_without_parameter_ +
                                       edge_cost->communication_without_parameter_ +
                                       tar_cost->communication_without_parameter_;
@@ -933,6 +1023,7 @@ void CostGraph::CreateContractEliminationSubCostList(StrategyPtr contract_op_str
   new_cost->communication_with_partial_para_ =
     communication_without_para + COST_MODEL_GAMMA * (communication - communication_without_para);
   new_cost->memory_with_reuse_ = memory;
+  new_cost->communication_forward_ = communication_forward;
   tar_cost_list_new->emplace_back(std::move(new_cost));
 }
 }
@@ -962,7 +1053,7 @@ OperatorInfoPtr CostGraph::EliminationContract(const OperatorInfoPtr &op) {

     CreateContractEliminationSubCostList(op_stra, op_clist, edge_clist, tar_stra, tar_clist_origin, &tar_clist_new);
   }
-  SimplifyForDreasingCommunicationWithPartialPara(&tar_clist_new);
+  Simplify(&tar_clist_new);
   // Set the new costlist w.r.t the strategy
   tar_stra_cost->cost_list = tar_clist_new;
   if ((!valid) && (!tar_clist_new.empty())) {
@@ -998,6 +1089,8 @@ void CostGraph::CreateTriangleEliminationSubCostList(StrategyPtr elimi_op_stra,
     left_node_cost->memory_with_reuse_ + right_edge_cost->memory_with_reuse_;
   double new_commu_cost = elimi_op_cost->communication_cost_ + left_edge_cost->communication_cost_ +
                           left_node_cost->communication_cost_ + right_edge_cost->communication_cost_;
+  double new_commu_forward = elimi_op_cost->communication_forward_ + left_edge_cost->communication_forward_ +
+                             left_node_cost->communication_forward_ + right_edge_cost->communication_forward_;
   double new_commu_without =
     elimi_op_cost->communication_without_parameter_ + left_edge_cost->communication_without_parameter_ +
     left_node_cost->communication_without_parameter_ + right_edge_cost->communication_without_parameter_;
@@ -1009,6 +1102,7 @@ void CostGraph::CreateTriangleEliminationSubCostList(StrategyPtr elimi_op_stra,
   new_cost->communication_with_partial_para_ =
     new_commu_without + COST_MODEL_GAMMA * (new_commu_cost - new_commu_without);
   new_cost->memory_with_reuse_ = new_memory;
+  new_cost->communication_forward_ = new_commu_forward;
   left_node_clist_new->emplace_back(std::move(new_cost));
 }
 }
@@ -1079,7 +1173,7 @@ OperatorInfoPtr CostGraph::EliminationTriangle(const OperatorInfoPtr &elimi_op,
                                        &left_node_clist_new);
     }
   }
-  SimplifyForDreasingCommunicationWithPartialPara(&left_node_clist_new);
+  Simplify(&left_node_clist_new);
   // Set the new costlist w.r.t the strategy
   left_node_stra_cost->cost_list = left_node_clist_new;
   if ((!valid) && (!left_node_clist_new.empty())) {
@@ -1112,19 +1206,22 @@ void CostGraph::CreateStarEliminationSubCostList(const StrategyPtr &first_succ_n

     double computation_cost = merged_node_cost->computation_cost_,
            memory_cost = merged_node_cost->memory_with_reuse_, commu_cost = merged_node_cost->communication_cost_,
-           commu_without = merged_node_cost->communication_without_parameter_;
+           commu_without = merged_node_cost->communication_without_parameter_,
+           commu_forward = merged_node_cost->communication_forward_;
     for (size_t i = 0; i < succ_nodes_stras.size(); ++i) {
       MS_EXCEPTION_IF_NULL(succ_edges_costs[i]);
       if (i == 0) {
         computation_cost += succ_edges_costs[i]->computation_cost_ + succ_nodes_costs[i]->computation_cost_;
         memory_cost += succ_edges_costs[i]->memory_with_reuse_ + succ_nodes_costs[i]->memory_with_reuse_;
         commu_cost += succ_edges_costs[i]->communication_cost_ + succ_nodes_costs[i]->communication_cost_;
+        commu_forward += succ_edges_costs[i]->communication_forward_ + succ_nodes_costs[i]->communication_forward_;
         commu_without += succ_edges_costs[i]->communication_without_parameter_ +
                          succ_nodes_costs[i]->communication_without_parameter_;
       } else {
         computation_cost += succ_edges_costs[i]->computation_cost_;
         memory_cost += succ_edges_costs[i]->memory_with_reuse_;
         commu_cost += succ_edges_costs[i]->communication_cost_;
+        commu_forward += succ_edges_costs[i]->communication_forward_;
         commu_without += succ_edges_costs[i]->communication_without_parameter_;
       }
     }
@@ -1135,6 +1232,7 @@ void CostGraph::CreateStarEliminationSubCostList(const StrategyPtr &first_succ_n
     new_cost->communication_without_parameter_ = commu_without;
     new_cost->communication_with_partial_para_ = commu_without + COST_MODEL_GAMMA * (commu_cost - commu_without);
     new_cost->memory_with_reuse_ = memory_cost;
+    new_cost->communication_forward_ = commu_forward;
     first_succ_node_clist_new->emplace_back(std::move(new_cost));
   }
 }
@@ -1220,7 +1318,7 @@ std::vector<std::shared_ptr<Edge>> CostGraph::EliminationStar(const OperatorInfo
     CreateStarEliminationCostList(succ_edges, first_succ_node_stra, first_succ_node_clist, first_succ_edge_clist,
                                   merged_op_stra, merged_op_clist, &first_succ_node_clist_new);
   }
-  SimplifyForDreasingCommunicationWithPartialPara(&first_succ_node_clist_new);
+  Simplify(&first_succ_node_clist_new);
   // Set the new costlist w.r.t the strategy
   first_succ_node_stra_cost->cost_list = first_succ_node_clist_new;
   if ((!valid) && (!first_succ_node_clist_new.empty())) {
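SelectCostWithMinInferenceTime above first discards every strategy whose peak memory exceeds the device capacity and then minimizes costmodel_alpha_ * computation_cost_ + costmodel_beta_ * communication_forward_. A condensed Python sketch of that selection rule (dictionaries stand in for the Cost objects; the alpha and beta values here are illustrative, not the configured defaults):

def select_cost_with_min_inference_time(cost_list, memory, alpha=1.0, beta=400.0):
    # Drop every candidate whose peak memory exceeds the device capacity.
    feasible = [cost for cost in cost_list if cost["memory"] <= memory]
    if not feasible:
        return None  # the C++ version also reports the smallest infeasible memory
    # Modeled inference time: alpha * computation + beta * forward communication.
    return min(feasible, key=lambda cost: alpha * cost["computation"] + beta * cost["communication_forward"])

candidates = [
    {"memory": 16.0, "computation": 10.0, "communication_forward": 0.0},  # filtered out by the budget below
    {"memory": 8.0, "computation": 12.0, "communication_forward": 0.5},
    {"memory": 6.0, "computation": 20.0, "communication_forward": 0.0},
]
# The third candidate wins: 20.0 < 12.0 + 400.0 * 0.5.
print(select_cost_with_min_inference_time(candidates, memory=12.0))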


mindspore/ccsrc/parallel/auto_parallel/graph_costmodel.h  (+6 -1)

@@ -45,6 +45,9 @@ namespace parallel {
 #define DEFAULT_FULLY_USE_DEVICES true
 #define DEFAULT_ELEMENTWISE_OP_STRA_FOLLOW false
 #define DEFAULT_IS_MULTI_SUBGRAPHS false
+#define DEFAULT_RUN_PHASE 0
+#define TRAINING_PHASE 0
+#define INFERENCE_PHASE 1

 class CostGraph;
 using CostGraphPtr = std::shared_ptr<CostGraph>;
@@ -60,6 +63,8 @@ extern bool TENSOR_SLICE_ALIGNMENT_ENABLE;
 extern size_t TENSOR_SLICE_ALIGNMENT_SIZE;
 extern bool FULLY_USE_DEVICES;
 extern bool ELEMENTWISE_OP_STRA_FOLLOW;
+extern bool MULTI_SUBGRAPHS;
+extern int32_t RUN_PHASE;

 class CostGraph {
   // 'CostGraph' consists of Operators and edges between them. An edge is created between two Operators if they have
@@ -98,7 +103,7 @@ class CostGraph {

   CostPtrList CreateFinalCostList(const OperatorInfoPtr &u, const EdgePtr &e, const OperatorInfoPtr &v);
   CostPtrList CreateFinalSingleCostList(const OperatorInfoPtr &u);
-  CostPtr SelectCostWithMemoryConstraint(const CostPtrList &cost_list, double memory);
+  CostPtr SelectCostWithMinInferenceTime(const CostPtrList &cost_list, double memory);
   CostPtr SelectCostWithMinTrainingTime(const CostPtrList &cost_list, double memory);
   CostPtrList SelectCostListWithMinTrainingTimeMultiple(const std::vector<CostPtrList> &all_costlist, double memory);
   Status SearchStrategyForMultiNodeFinalGraph(const std::vector<OperatorInfoPtr> &);


mindspore/ccsrc/parallel/costmodel_context.cc  (+3 -0)

@@ -47,6 +47,7 @@ void CostModelContext::ResetCostModel() {
   costmodel_communi_const_ = DEFAULT_COST_MODEL_COMMUNI_CONST;
   costmodel_communi_bias_ = DEFAULT_COST_MODEL_COMMUNI_BIAS;
   is_multi_subgraphs_ = DEFAULT_IS_MULTI_SUBGRAPHS;
+  run_phase_ = DEFAULT_RUN_PHASE;
   costmodel_allreduce_fusion_algorithm_ = DEFAULT_COST_MODEL_ALLREDUCE_FUSION_ALGORITHM;
   costmodel_allreduce_fusion_times_ = DEFAULT_COST_MODEL_ALLREDUCE_FUSION_TIMES;
   costmodel_allreduce_fusion_tail_percent_ = DEFAULT_COST_MODEL_ALLREDUCE_FUSION_TAIL_PERCENT;
@@ -125,5 +126,7 @@ void CostModelContext::set_fully_use_device(bool fully_use) { fully_use_device_
 void CostModelContext::set_elementwise_stra_follow(bool elementwise_follow) {
   elementwise_stra_follow_ = elementwise_follow;
 }
+
+void CostModelContext::set_run_phase(int32_t phase) { run_phase_ = phase; }
 } // namespace parallel
 } // namespace mindspore

mindspore/ccsrc/parallel/costmodel_context.h  (+6 -0)

@@ -113,6 +113,9 @@ class CostModelContext {
   void set_elementwise_stra_follow(bool);
   bool elementwise_stra_follow() const { return elementwise_stra_follow_; }

+  void set_run_phase(int32_t);
+  int32_t run_phase() const { return run_phase_; }
+
  private:
   CostModelContext();
   static std::shared_ptr<CostModelContext> cm_context_inst_;
@@ -141,8 +144,11 @@ class CostModelContext {
   // COST_MODEL_COMMUNI_BIAS
   double costmodel_communi_bias_;

+  // MULTI_SUBGRAPHS
   bool is_multi_subgraphs_;

+  int32_t run_phase_;  // 0: 'training', 1: 'inference'
+
   int32_t costmodel_allreduce_fusion_algorithm_;

   int32_t costmodel_allreduce_fusion_times_;


mindspore/ccsrc/parallel/ops_info/matmul_info.cc  (+1 -0)

@@ -610,6 +610,7 @@ Status MatMulBase::SetCostUnderStrategy(const mindspore::parallel::StrategyPtr &
                << ", communication_with_partial_para_: " << result->communication_with_partial_para_;
   // refine communication cost calculation for practice
   RefineForPracticalCost(result, false);
+  result->communication_forward_ = result->communication_without_parameter_;

   std::shared_ptr<StrategyWithCost> swc =
     std::make_shared<StrategyWithCost>(strategy, inputs_tensor_info_, outputs_tensor_info_);


mindspore/ccsrc/parallel/ops_info/operator_info.cc  (+1 -0)

@@ -1049,6 +1049,7 @@ Status OperatorInfo::SetCostUnderStrategyBase(const StrategyPtr &strategy) {
   BreakingTiesForPerferringDataParallel(strategy, result);
   // refine communication cost calculation for practice
   RefineForPracticalCost(result, false);
+  result->communication_forward_ = result->communication_without_parameter_;

   std::shared_ptr<StrategyWithCost> swc =
     std::make_shared<StrategyWithCost>(strategy, inputs_tensor_info_, outputs_tensor_info_);


mindspore/ccsrc/parallel/tensor_layout/tensor_redistribution.h  (+3 -3)

@@ -69,16 +69,16 @@ class TensorRedistribution {
   RankList dev_list_;
   OperatorList operator_list_;
   bool reshape_flag_;
-  // communication cost
+  // communication cost, which is the sum of forward communication cost and backward communication cost
   double comm_cost_;
   // forward communication cost
   double forward_comm_cost_;
   // backward communication cost
   double backward_comm_cost_;
   // computation_cost models the time spending on computing in this tensor redistribution, which is calculated by the
-  // inputs.
+  // inputs. This is calculated ONLY for forward phase.
   double computation_cost_;
-  // memory_cost models the PEAK memory cost in a traning iteration contributed by this tensor redistribution, which is
+  // memory_cost models the PEAK memory cost in a training iteration contributed by this tensor redistribution, which is
   // calculated by the outputs.
   double memory_cost_;
   bool construct_op_flag_;


mindspore/ccsrc/pipeline/init.cc  (+2 -0)

@@ -228,6 +228,8 @@ PYBIND11_MODULE(_c_expression, m) {
          "Get the parameter cost_model_communi_bias of the DP algorithm.")
     .def("set_multi_subgraphs", &CostModelContext::set_multi_subgraphs, "Set the parameter is_multi_subgraphs.")
     .def("get_multi_subgraphs", &CostModelContext::is_multi_subgraphs, "Get the parameter is_multi_subgraphs.")
+    .def("set_run_phase", &CostModelContext::set_run_phase, "Set the flag run_phase.")
+    .def("get_run_phase", &CostModelContext::run_phase, "Get the flag run_phase.")
     .def("set_costmodel_allreduce_fusion_algorithm", &CostModelContext::set_costmodel_allreduce_fusion_algorithm,
          "Set the parameter gradient AllReduce fusion algorithm.")
     .def("get_costmodel_allreduce_fusion_algorithm", &CostModelContext::costmodel_allreduce_fusion_algorithm,


mindspore/parallel/_cost_model_context.py  (+32 -2)

@@ -239,6 +239,33 @@ class _CostModelContext:
             raise ValueError("Context handle is none in context!!!")
         return self._context_handle.get_multi_subgraphs()

+    def set_run_phase(self, phase):
+        """
+        Set the flag of running phase: training (0) or inference (1)
+
+        Args:
+            phase (int): A parameter indicating which phase is running.
+
+        Raises:
+            ValueError: If context handle is none, or phase is not in {0, 1}.
+        """
+        if self._context_handle is None:
+            raise ValueError("Context handle is none in context!!!")
+        if phase not in (0, 1):
+            raise ValueError("The argument of set_run_phase() must be '0' or '1', but got {}".format(phase))
+        self._context_handle.set_run_phase(phase)
+
+    def get_run_phase(self):
+        """
+        Get the flag of running phase.
+
+        Raises:
+            ValueError: If context handle is none.
+        """
+        if self._context_handle is None:
+            raise ValueError("Context handle is none in context!!!")
+        return self._context_handle.get_run_phase()
+
     def set_costmodel_allreduce_fusion_algorithm(self, algorithm):
         """
         Set costmodel allreduce fusion algorithm.
@@ -453,6 +480,7 @@ set_cost_model_context_func_map = {
     "costmodel_communi_const": cost_model_context().set_costmodel_communi_const,
     "costmodel_communi_bias": cost_model_context().set_costmodel_communi_bias,
     "multi_subgraphs": cost_model_context().set_multi_subgraphs,
+    "run_phase": cost_model_context().set_run_phase,
     "costmodel_allreduce_fusion_algorithm": cost_model_context().set_costmodel_allreduce_fusion_algorithm,
     "costmodel_allreduce_fusion_times": cost_model_context().set_costmodel_allreduce_fusion_times,
     "costmodel_allreduce_fusion_tail_percent": cost_model_context().set_costmodel_allreduce_fusion_tail_percent,
@@ -473,7 +501,8 @@ get_cost_model_context_func_map = {
     "costmodel_communi_threshold": cost_model_context().get_costmodel_communi_threshold,
     "costmodel_communi_const": cost_model_context().get_costmodel_communi_const,
     "costmodel_communi_bias": cost_model_context().get_costmodel_communi_bias,
-    "multi_subgraphs": cost_model_context().get_multi_subgraphs(),
+    "multi_subgraphs": cost_model_context().get_multi_subgraphs,
+    "run_phase": cost_model_context().get_run_phase,
     "costmodel_allreduce_fusion_algorithm": cost_model_context().get_costmodel_allreduce_fusion_algorithm,
     "costmodel_allreduce_fusion_times": cost_model_context().get_costmodel_allreduce_fusion_times,
     "costmodel_allreduce_fusion_tail_percent": cost_model_context().get_costmodel_allreduce_fusion_tail_percent,
@@ -488,7 +517,7 @@ get_cost_model_context_func_map = {

 @args_type_check(device_memory_capacity=float, costmodel_alpha=float, costmodel_beta=float, costmodel_gamma=float,
                  costmodel_communi_threshold=float, costmodel_communi_const=float, costmodel_communi_bias=float,
-                 multi_subgraphs=bool,
+                 multi_subgraphs=bool, run_phase=int,
                  costmodel_allreduce_fusion_algorithm=int, costmodel_allreduce_fusion_times=int,
                  costmodel_allreduce_fusion_tail_percent=float, costmodel_allreduce_fusion_tail_time=float,
                  costmodel_allreduce_fusion_allreduce_inherent_time=float,
@@ -510,6 +539,7 @@ def set_cost_model_context(**kwargs):
         costmodel_communi_const (float): A parameter used in adjusting communication calculation for practice.
         costmodel_communi_bias (float): A parameter used in adjusting communication calculation for practice.
         multi_subgraphs (bool): A parameter used in marking the flag of ANF graph containing multiple subgraphs.
+        run_phase (int): A parameter indicating which phase is running: training (0) or inference (1). Default: 0.
         costmodel_allreduce_fusion_algorithm (int): The allreduce fusion algorithm.
             0: bypass allreduce fusion;
             1: only use backward computation time to group allreduce;
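With the setter registered in set_cost_model_context_func_map above, switching the strategy search to the inference phase is a one-liner from user code; the new unit test at the end of this PR does exactly this. A minimal usage sketch:

from mindspore import context
from mindspore.parallel._cost_model_context import set_cost_model_context

# Run the auto-parallel strategy search with the inference-phase cost model.
context.set_auto_parallel_context(parallel_mode="auto_parallel")
set_cost_model_context(run_phase=1)   # 0: training (the default), 1: inference

# Any value outside {0, 1} is rejected by _CostModelContext.set_run_phase with a ValueError.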


tests/ut/cpp/parallel/auto_parallel/graph_costmodel_test.cc  (+1 -1)

@@ -371,7 +371,7 @@ TEST_F(TestCostGraph, test_CreateFinalCostList_AND_Select) {
   ASSERT_EQ(edge_m1_m2->InitEdgeCost(), SUCCESS);
   cost_graph.AddEdge(matmul1, matmul2, edge_m1_m2);
   auto cost_list = cost_graph.CreateFinalCostList(matmul1, edge_m1_m2, matmul2);
-  cost_graph.SelectCostWithMemoryConstraint(cost_list, cost_graph.GetDeviceMemory());
+  cost_graph.SelectCostWithMinInferenceTime(cost_list, cost_graph.GetDeviceMemory());
 }

 TEST_F(TestCostGraph, test_EliminationOp) {


tests/ut/python/parallel/__init__.py  (+6 -0)

@@ -14,15 +14,21 @@

 import mindspore.context as context
 from mindspore.parallel._auto_parallel_context import auto_parallel_context
+from mindspore.parallel._cost_model_context import reset_cost_model_context
+from mindspore.parallel.algo_parameter_config import reset_algo_parameters
 from mindspore.parallel._utils import _reset_op_id


 def setup_module(module):
     auto_parallel_context().set_enable_all_reduce_fusion(enable_all_reduce_fusion=True)
     context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False)
+    reset_cost_model_context()
+    reset_algo_parameters()
     _reset_op_id()


 def teardown_module():
     context.reset_auto_parallel_context()
+    reset_cost_model_context()
+    reset_algo_parameters()
     _reset_op_id()

tests/ut/python/parallel/test_auto_parallel_inference.py  (+36 -0)

@@ -0,0 +1,36 @@
+import numpy as np
+
+import mindspore.nn as nn
+from mindspore import Tensor, context
+from mindspore.ops import operations as P
+from mindspore.nn import WithLossCell, TrainOneStepCell
+from mindspore.nn import Momentum
+from mindspore.parallel._cost_model_context import set_cost_model_context
+
+class Net(nn.Cell):
+    def __init__(self, input_ch, out_ch):
+        super(Net, self).__init__()
+        self.dense = nn.Dense(input_ch, out_ch)
+        self.relu = P.ReLU()
+
+    def construct(self, x):
+        x = self.dense(x)
+        x = self.relu(x)
+        return x
+
+def test_inference_phase():
+    context.set_auto_parallel_context(device_num=8, global_rank=0)
+    context.set_auto_parallel_context(parallel_mode="auto_parallel")
+    set_cost_model_context(run_phase=1)
+
+    net = Net(512, 128)
+    predict = Tensor(np.ones([64, 512]).astype(np.float32) * 0.001)
+    label = Tensor(np.ones([64, 128]).astype(np.float32))
+
+    loss = nn.SoftmaxCrossEntropyWithLogits()
+    optimizer = Momentum(params=net.trainable_params(), learning_rate=0.1, momentum=0.9)
+    net_with_loss = WithLossCell(net, loss)
+    train_network = TrainOneStepCell(net_with_loss, optimizer)
+    train_network.set_train()
+
+    output = train_network(predict, label)
