diff --git a/mindspore/ccsrc/frontend/parallel/auto_parallel/rec_core/rec_generate_strategy.cc b/mindspore/ccsrc/frontend/parallel/auto_parallel/rec_core/rec_generate_strategy.cc
index c1874ed559..1ab9da3c6b 100644
--- a/mindspore/ccsrc/frontend/parallel/auto_parallel/rec_core/rec_generate_strategy.cc
+++ b/mindspore/ccsrc/frontend/parallel/auto_parallel/rec_core/rec_generate_strategy.cc
@@ -32,7 +32,7 @@ namespace parallel {
 void GenerateStrategy(const std::shared_ptr<Graph> &graph, const std::vector<std::shared_ptr<OperatorInfo>> &ops,
                       const std::shared_ptr<std::vector<std::vector<size_t>>> &eli_list,
                       const std::vector<std::vector<std::string>> &input_tensor_names,
-                      const std::shared_ptr<std::vector<size_t>> &index_list) {
+                      const std::shared_ptr<std::vector<size_t>> &index_list, bool is_training) {
   MS_EXCEPTION_IF_NULL(graph);
   MS_EXCEPTION_IF_NULL(eli_list);
   MS_EXCEPTION_IF_NULL(index_list);
@@ -46,11 +46,18 @@ void GenerateStrategy(const std::shared_ptr<Graph> &graph, const std::vector<st
   GenerateEliminatedOperatorStrategyBackward(ops, input_tensor_names, no_stra_op_list);
   GenerateRemainingOperatorStrategy(graph, ops, input_tensor_names, index_list, no_stra_op_list);
   for (auto &op : ops) {
     // Set user-defined strategy
     auto attrs = op->attrs();
     if (StrategyFound(attrs)) {
       StrategyPtr user_defined_stra = parallel::ExtractStrategy(attrs);
       op->SetSelectedStrategyAndCost(user_defined_stra, op->selected_cost());
     }
+    // Set back to raw strategy for special node in predict/eval
+    if (!is_training) {
+      if ((op->is_last_node()) || (op->type() == "_VirtualDataset")) {
+        SetBackToRawStrategy(op);
+      }
+    }
   }
 }
@@ -486,6 +493,29 @@ Strategys MakeDataParallelStrategy(const std::shared_ptr<Graph> &graph,
   return strategies;
 }
 
+void SetBackToRawStrategy(const std::shared_ptr<OperatorInfo> &op) {
+  StrategyPtr origin_strategy = op->strategy();
+  Strategys strategies;
+
+  for (size_t iter_strategy = 0; iter_strategy < origin_strategy->GetInputDim().size(); iter_strategy++) {
+    Dimensions s;
+    size_t strategy_size = origin_strategy->GetInputDim()[iter_strategy].size();
+    for (size_t dim = 0; dim < strategy_size; dim++) {
+      if (strategy_size >= 1 && strategy_size <= 4) {
+        s.push_back(1);
+      } else if (strategy_size == 0) {
+        s = {};
+      } else {
+        MS_LOG(EXCEPTION) << op->name() << ": Strategy size " << strategy_size << " is unmatched.";
+      }
+    }
+    strategies.push_back(s);
+  }
+
+  StrategyPtr sp = std::make_shared<Strategy>(0, strategies);
+  op->SetSelectedStrategyAndCost(sp, op->selected_cost());
+}
+
 Strategys PrepareStrategy(const std::shared_ptr<Graph> &graph, const std::vector<std::shared_ptr<OperatorInfo>> &ops,
                           const size_t iter_graph, const size_t iter_ops) {
   if (ops.empty()) {
diff --git a/mindspore/ccsrc/frontend/parallel/auto_parallel/rec_core/rec_generate_strategy.h b/mindspore/ccsrc/frontend/parallel/auto_parallel/rec_core/rec_generate_strategy.h
index ee3dd0463f..cee86413c2 100644
--- a/mindspore/ccsrc/frontend/parallel/auto_parallel/rec_core/rec_generate_strategy.h
+++ b/mindspore/ccsrc/frontend/parallel/auto_parallel/rec_core/rec_generate_strategy.h
@@ -30,7 +30,7 @@ namespace parallel {
 void GenerateStrategy(const std::shared_ptr<Graph> &graph, const std::vector<std::shared_ptr<OperatorInfo>> &ops,
                       const std::shared_ptr<std::vector<std::vector<size_t>>> &eli_list,
                       const std::vector<std::vector<std::string>> &input_tensor_names,
-                      const std::shared_ptr<std::vector<size_t>> &index_list);
+                      const std::shared_ptr<std::vector<size_t>> &index_list, bool is_training);
 Strategys PrepareMatMul(const std::shared_ptr<Graph> &graph, const std::vector<std::shared_ptr<OperatorInfo>> &ops,
                         const size_t iter_graph, const size_t iter_ops);
 Strategys PrepareBiasAdd(const std::shared_ptr<Dimensions> &s);
@@ -55,6 +55,7 @@ Strategys CheckDivisible(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
 Strategys MakeDataParallelStrategy(const std::shared_ptr<Graph> &graph,
                                    const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_graph,
                                    const size_t iter_ops);
+void SetBackToRawStrategy(const std::shared_ptr<OperatorInfo> &op);
 Strategys PrepareStrategy(const std::shared_ptr<Graph> &graph, const std::vector<std::shared_ptr<OperatorInfo>> &ops,
                           const size_t iter_graph, const size_t iter_ops);
 void GeneratePartitionedOperatorStrategy(const std::shared_ptr<Graph> &graph,
diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/operator_info.h b/mindspore/ccsrc/frontend/parallel/ops_info/operator_info.h
index 813b64d13d..c382707a88 100644
--- a/mindspore/ccsrc/frontend/parallel/ops_info/operator_info.h
+++ b/mindspore/ccsrc/frontend/parallel/ops_info/operator_info.h
@@ -170,6 +170,8 @@ class OperatorInfo {
   // needed by rec_parser
   void set_type(const std::string &type) { type_ = type; }
   const std::string &type() const { return type_; }
+  void set_last_node_flag(const bool &is_last_node) { is_last_node_ = is_last_node; }
+  const bool &is_last_node() const { return is_last_node_; }
   const std::unordered_map<std::string, ValuePtr> &attrs() const { return attrs_; }
   void set_stage_id(int32_t stage_id) { stage_id_ = stage_id; }
   int32_t stage_id() const { return stage_id_; }
@@ -181,6 +183,7 @@ class OperatorInfo {
 protected:
   // needed by rec_parser
   std::string type_;
+  bool is_last_node_;
   virtual Status CheckStrategy(const StrategyPtr &strategy) = 0;
   virtual Status InferTensorMap() = 0;
   virtual Status InferForwardCommunication() = 0;
diff --git a/mindspore/ccsrc/frontend/parallel/step_auto_parallel.cc b/mindspore/ccsrc/frontend/parallel/step_auto_parallel.cc
index 7baa7f0300..2035feea57 100644
--- a/mindspore/ccsrc/frontend/parallel/step_auto_parallel.cc
+++ b/mindspore/ccsrc/frontend/parallel/step_auto_parallel.cc
@@ -421,6 +421,7 @@ Status ConstructCostGraphNodesByUniqueId(const std::vector<AnfNodePtr> &all_node
       }
       // Needed by rec_parser
       operator_info->set_type(prim->name());
+      operator_info->set_last_node_flag(is_last_nodes);
       std::vector<std::string> inputs_tensor_name = ExtractInputsTensorName(cnode);
 
       entire_costgraph->AddOperator(operator_info);
@@ -523,6 +524,7 @@ Status ConstructCostGraphNodesByUniqueIdTC(const std::vector<AnfNodePtr> &all_no
       }
       // Needed by rec_parser
      operator_info->set_type(prim->name());
+      operator_info->set_last_node_flag(is_last_nodes);
       std::vector<std::string> inputs_tensor_name = ExtractInputsTensorName(cnode);
 
       entire_costgraph->AddOperator(operator_info);
@@ -1037,7 +1039,11 @@ Status ParallelStrategyRecSearch(const std::vector<AnfNodePtr> &all_nodes, const
     return FAILED;
   }
 
-  GenerateStrategy(graph, ops, eli_list, input_tensor_names, index_list);
+  bool is_training = true;
+  if (!root->has_flag(TRAINING)) {
+    is_training = false;
+  }
+  GenerateStrategy(graph, ops, eli_list, input_tensor_names, index_list, is_training);
 
   if (entire_costgraph->InitSelectedStrategy() == SUCCESS) {
     MS_LOG(INFO) << "Init selected strategy succeeded.";
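
For readers skimming the diff, the heart of the change is the new SetBackToRawStrategy function plus the is_training plumbing that triggers it. Below is a minimal standalone sketch of the reset it performs; this is not the MindSpore API: the Strategys/Dimensions aliases as nested int vectors and the free-function signature are assumptions made so the sketch compiles on its own. Each operator input keeps its rank, but every partition number is set back to 1, so no tensor dimension is split across devices.

#include <cstddef>
#include <iostream>
#include <stdexcept>
#include <string>
#include <vector>

using Dimensions = std::vector<int>;        // one partition number per tensor dimension
using Strategys = std::vector<Dimensions>;  // one Dimensions entry per operator input

// Replace a searched strategy with the "raw" (fully replicated) strategy:
// every dimension of every input gets partition number 1.
Strategys SetBackToRawStrategy(const Strategys &origin) {
  Strategys raw;
  for (const Dimensions &input_stra : origin) {
    const size_t strategy_size = input_stra.size();
    if (strategy_size > 4) {  // the diff only handles tensors of rank 0 to 4
      throw std::runtime_error("Strategy size " + std::to_string(strategy_size) + " is unmatched.");
    }
    // Rank 0 (scalar) naturally yields an empty strategy; otherwise all ones.
    raw.emplace_back(Dimensions(strategy_size, 1));
  }
  return raw;
}

int main() {
  // E.g. a MatMul whose search produced (4, 2) x (2, 1) is reset to (1, 1) x (1, 1).
  const Strategys searched = {{4, 2}, {2, 1}};
  for (const Dimensions &input_stra : SetBackToRawStrategy(searched)) {
    for (int partition : input_stra) std::cout << partition << ' ';
    std::cout << '\n';
  }
  return 0;
}

As the diff shows, ParallelStrategyRecSearch derives is_training from the root graph's TRAINING flag, so the reset runs only in predict/eval graphs, and only for the last node and _VirtualDataset nodes; every other operator keeps the strategy found by the recursive search.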