Browse Source

!11215 [AutoParallel] adjust the strategy of the network's last op according to whether it runs inference or training

From: @ch-l
Reviewed-by: @kisnwang,@stsuteng
Signed-off-by: @stsuteng
tags/v1.2.0-rc1
mindspore-ci-bot Gitee 4 years ago
parent
commit
294ce9ddb8
4 changed files with 43 additions and 3 deletions
  1. +31
    -1
      mindspore/ccsrc/frontend/parallel/auto_parallel/rec_core/rec_generate_strategy.cc
  2. +2
    -1
      mindspore/ccsrc/frontend/parallel/auto_parallel/rec_core/rec_generate_strategy.h
  3. +3
    -0
      mindspore/ccsrc/frontend/parallel/ops_info/operator_info.h
  4. +7
    -1
      mindspore/ccsrc/frontend/parallel/step_auto_parallel.cc

+ 31
- 1
mindspore/ccsrc/frontend/parallel/auto_parallel/rec_core/rec_generate_strategy.cc View File

@@ -32,7 +32,7 @@ namespace parallel {
void GenerateStrategy(const std::shared_ptr<Graph> &graph, const std::vector<std::shared_ptr<OperatorInfo>> &ops, void GenerateStrategy(const std::shared_ptr<Graph> &graph, const std::vector<std::shared_ptr<OperatorInfo>> &ops,
const std::shared_ptr<std::vector<std::vector<size_t>>> &eli_list, const std::shared_ptr<std::vector<std::vector<size_t>>> &eli_list,
const std::vector<std::vector<std::string>> &input_tensor_names, const std::vector<std::vector<std::string>> &input_tensor_names,
const std::shared_ptr<std::vector<size_t>> &index_list) {
const std::shared_ptr<std::vector<size_t>> &index_list, bool is_training) {
MS_EXCEPTION_IF_NULL(graph); MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(eli_list); MS_EXCEPTION_IF_NULL(eli_list);
MS_EXCEPTION_IF_NULL(index_list); MS_EXCEPTION_IF_NULL(index_list);
@@ -46,11 +46,18 @@ void GenerateStrategy(const std::shared_ptr<Graph> &graph, const std::vector<std
GenerateRemainingOperatorStrategy(graph, ops, input_tensor_names, index_list, no_stra_op_list); GenerateRemainingOperatorStrategy(graph, ops, input_tensor_names, index_list, no_stra_op_list);


for (auto &op : ops) { for (auto &op : ops) {
// Set user-defined strategy
auto attrs = op->attrs(); auto attrs = op->attrs();
if (StrategyFound(attrs)) { if (StrategyFound(attrs)) {
StrategyPtr user_defined_stra = parallel::ExtractStrategy(attrs); StrategyPtr user_defined_stra = parallel::ExtractStrategy(attrs);
op->SetSelectedStrategyAndCost(user_defined_stra, op->selected_cost()); op->SetSelectedStrategyAndCost(user_defined_stra, op->selected_cost());
} }
// Set back to raw strategy for special node in predict/eval
if (!is_training) {
if ((op->is_last_node()) || (op->type() == "_VirtualDataset")) {
SetBackToRawStrategy(op);
}
}
} }
} }


@@ -486,6 +493,29 @@ Strategys MakeDataParallelStrategy(const std::shared_ptr<Graph> &graph,
return strategies; return strategies;
} }


void SetBackToRawStrategy(const std::shared_ptr<OperatorInfo> &op) {
StrategyPtr origin_strategy = op->strategy();
Strategys strategies;

for (size_t iter_strategy = 0; iter_strategy < origin_strategy->GetInputDim().size(); iter_strategy++) {
Dimensions s;
size_t strategy_size = origin_strategy->GetInputDim()[iter_strategy].size();
for (size_t dim = 0; dim < strategy_size; dim++) {
if (strategy_size >= 1 && strategy_size <= 4) {
s.push_back(1);
} else if (strategy_size == 0) {
s = {};
} else {
MS_LOG(EXCEPTION) << op->name() << ": Strategy size " << strategy_size << " is unmatched.";
}
}
strategies.push_back(s);
}

StrategyPtr sp = std::make_shared<Strategy>(0, strategies);
op->SetSelectedStrategyAndCost(sp, op->selected_cost());
}

Strategys PrepareStrategy(const std::shared_ptr<Graph> &graph, const std::vector<std::shared_ptr<OperatorInfo>> &ops, Strategys PrepareStrategy(const std::shared_ptr<Graph> &graph, const std::vector<std::shared_ptr<OperatorInfo>> &ops,
const size_t iter_graph, const size_t iter_ops) { const size_t iter_graph, const size_t iter_ops) {
if (ops.empty()) { if (ops.empty()) {


+ 2
- 1
mindspore/ccsrc/frontend/parallel/auto_parallel/rec_core/rec_generate_strategy.h View File

@@ -30,7 +30,7 @@ namespace parallel {
void GenerateStrategy(const std::shared_ptr<Graph> &graph, const std::vector<std::shared_ptr<OperatorInfo>> &ops, void GenerateStrategy(const std::shared_ptr<Graph> &graph, const std::vector<std::shared_ptr<OperatorInfo>> &ops,
const std::shared_ptr<std::vector<std::vector<size_t>>> &eli_list, const std::shared_ptr<std::vector<std::vector<size_t>>> &eli_list,
const std::vector<std::vector<std::string>> &input_tensor_names, const std::vector<std::vector<std::string>> &input_tensor_names,
const std::shared_ptr<std::vector<size_t>> &index_list);
const std::shared_ptr<std::vector<size_t>> &index_list, bool is_training);
Strategys PrepareMatMul(const std::shared_ptr<Graph> &graph, const std::vector<std::shared_ptr<OperatorInfo>> &ops, Strategys PrepareMatMul(const std::shared_ptr<Graph> &graph, const std::vector<std::shared_ptr<OperatorInfo>> &ops,
const size_t iter_graph, const size_t iter_ops); const size_t iter_graph, const size_t iter_ops);
Strategys PrepareBiasAdd(const std::shared_ptr<Dimensions> &s); Strategys PrepareBiasAdd(const std::shared_ptr<Dimensions> &s);
@@ -55,6 +55,7 @@ Strategys CheckDivisible(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
Strategys MakeDataParallelStrategy(const std::shared_ptr<Graph> &graph, Strategys MakeDataParallelStrategy(const std::shared_ptr<Graph> &graph,
const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_graph, const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_graph,
const size_t iter_ops); const size_t iter_ops);
void SetBackToRawStrategy(const std::shared_ptr<OperatorInfo> &op);
Strategys PrepareStrategy(const std::shared_ptr<Graph> &graph, const std::vector<std::shared_ptr<OperatorInfo>> &ops, Strategys PrepareStrategy(const std::shared_ptr<Graph> &graph, const std::vector<std::shared_ptr<OperatorInfo>> &ops,
const size_t iter_graph, const size_t iter_ops); const size_t iter_graph, const size_t iter_ops);
void GeneratePartitionedOperatorStrategy(const std::shared_ptr<Graph> &graph, void GeneratePartitionedOperatorStrategy(const std::shared_ptr<Graph> &graph,


+ 3
- 0
mindspore/ccsrc/frontend/parallel/ops_info/operator_info.h View File

@@ -170,6 +170,8 @@ class OperatorInfo {
// needed by rec_parser // needed by rec_parser
void set_type(const std::string &type) { type_ = type; } void set_type(const std::string &type) { type_ = type; }
const std::string &type() const { return type_; } const std::string &type() const { return type_; }
void set_last_node_flag(const bool &is_last_node) { is_last_node_ = is_last_node; }
const bool &is_last_node() const { return is_last_node_; }
const std::unordered_map<std::string, ValuePtr> &attrs() const { return attrs_; } const std::unordered_map<std::string, ValuePtr> &attrs() const { return attrs_; }
void set_stage_id(int32_t stage_id) { stage_id_ = stage_id; } void set_stage_id(int32_t stage_id) { stage_id_ = stage_id; }
int32_t stage_id() const { return stage_id_; } int32_t stage_id() const { return stage_id_; }
@@ -181,6 +183,7 @@ class OperatorInfo {
protected: protected:
// needed by rec_parser // needed by rec_parser
std::string type_; std::string type_;
bool is_last_node_;
virtual Status CheckStrategy(const StrategyPtr &strategy) = 0; virtual Status CheckStrategy(const StrategyPtr &strategy) = 0;
virtual Status InferTensorMap() = 0; virtual Status InferTensorMap() = 0;
virtual Status InferForwardCommunication() = 0; virtual Status InferForwardCommunication() = 0;


+ 7
- 1
mindspore/ccsrc/frontend/parallel/step_auto_parallel.cc View File

@@ -421,6 +421,7 @@ Status ConstructCostGraphNodesByUniqueId(const std::vector<AnfNodePtr> &all_node
} }
// Needed by rec_parser // Needed by rec_parser
operator_info->set_type(prim->name()); operator_info->set_type(prim->name());
operator_info->set_last_node_flag(is_last_nodes);
std::vector<std::string> inputs_tensor_name = ExtractInputsTensorName(cnode); std::vector<std::string> inputs_tensor_name = ExtractInputsTensorName(cnode);


entire_costgraph->AddOperator(operator_info); entire_costgraph->AddOperator(operator_info);
@@ -523,6 +524,7 @@ Status ConstructCostGraphNodesByUniqueIdTC(const std::vector<AnfNodePtr> &all_no
} }
// Needed by rec_parser // Needed by rec_parser
operator_info->set_type(prim->name()); operator_info->set_type(prim->name());
operator_info->set_last_node_flag(is_last_nodes);
std::vector<std::string> inputs_tensor_name = ExtractInputsTensorName(cnode); std::vector<std::string> inputs_tensor_name = ExtractInputsTensorName(cnode);


entire_costgraph->AddOperator(operator_info); entire_costgraph->AddOperator(operator_info);
@@ -1037,7 +1039,11 @@ Status ParallelStrategyRecSearch(const std::vector<AnfNodePtr> &all_nodes, const
return FAILED; return FAILED;
} }


GenerateStrategy(graph, ops, eli_list, input_tensor_names, index_list);
bool is_training = true;
if (!root->has_flag(TRAINING)) {
is_training = false;
}
GenerateStrategy(graph, ops, eli_list, input_tensor_names, index_list, is_training);


if (entire_costgraph->InitSelectedStrategy() == SUCCESS) { if (entire_costgraph->InitSelectedStrategy() == SUCCESS) {
MS_LOG(INFO) << "Init selected strategy succeeded."; MS_LOG(INFO) << "Init selected strategy succeeded.";


Loading…
Cancel
Save