From: @ch-l Reviewed-by: @kisnwang,@stsuteng Signed-off-by: @stsutengtags/v1.2.0-rc1
| @@ -32,7 +32,7 @@ namespace parallel { | |||
| void GenerateStrategy(const std::shared_ptr<Graph> &graph, const std::vector<std::shared_ptr<OperatorInfo>> &ops, | |||
| const std::shared_ptr<std::vector<std::vector<size_t>>> &eli_list, | |||
| const std::vector<std::vector<std::string>> &input_tensor_names, | |||
| const std::shared_ptr<std::vector<size_t>> &index_list) { | |||
| const std::shared_ptr<std::vector<size_t>> &index_list, bool is_training) { | |||
| MS_EXCEPTION_IF_NULL(graph); | |||
| MS_EXCEPTION_IF_NULL(eli_list); | |||
| MS_EXCEPTION_IF_NULL(index_list); | |||
| @@ -46,11 +46,18 @@ void GenerateStrategy(const std::shared_ptr<Graph> &graph, const std::vector<std | |||
| GenerateRemainingOperatorStrategy(graph, ops, input_tensor_names, index_list, no_stra_op_list); | |||
| for (auto &op : ops) { | |||
| // Set user-defined strategy | |||
| auto attrs = op->attrs(); | |||
| if (StrategyFound(attrs)) { | |||
| StrategyPtr user_defined_stra = parallel::ExtractStrategy(attrs); | |||
| op->SetSelectedStrategyAndCost(user_defined_stra, op->selected_cost()); | |||
| } | |||
| // Set back to raw strategy for special node in predict/eval | |||
| if (!is_training) { | |||
| if ((op->is_last_node()) || (op->type() == "_VirtualDataset")) { | |||
| SetBackToRawStrategy(op); | |||
| } | |||
| } | |||
| } | |||
| } | |||
| @@ -486,6 +493,29 @@ Strategys MakeDataParallelStrategy(const std::shared_ptr<Graph> &graph, | |||
| return strategies; | |||
| } | |||
| void SetBackToRawStrategy(const std::shared_ptr<OperatorInfo> &op) { | |||
| StrategyPtr origin_strategy = op->strategy(); | |||
| Strategys strategies; | |||
| for (size_t iter_strategy = 0; iter_strategy < origin_strategy->GetInputDim().size(); iter_strategy++) { | |||
| Dimensions s; | |||
| size_t strategy_size = origin_strategy->GetInputDim()[iter_strategy].size(); | |||
| for (size_t dim = 0; dim < strategy_size; dim++) { | |||
| if (strategy_size >= 1 && strategy_size <= 4) { | |||
| s.push_back(1); | |||
| } else if (strategy_size == 0) { | |||
| s = {}; | |||
| } else { | |||
| MS_LOG(EXCEPTION) << op->name() << ": Strategy size " << strategy_size << " is unmatched."; | |||
| } | |||
| } | |||
| strategies.push_back(s); | |||
| } | |||
| StrategyPtr sp = std::make_shared<Strategy>(0, strategies); | |||
| op->SetSelectedStrategyAndCost(sp, op->selected_cost()); | |||
| } | |||
| Strategys PrepareStrategy(const std::shared_ptr<Graph> &graph, const std::vector<std::shared_ptr<OperatorInfo>> &ops, | |||
| const size_t iter_graph, const size_t iter_ops) { | |||
| if (ops.empty()) { | |||
| @@ -30,7 +30,7 @@ namespace parallel { | |||
| void GenerateStrategy(const std::shared_ptr<Graph> &graph, const std::vector<std::shared_ptr<OperatorInfo>> &ops, | |||
| const std::shared_ptr<std::vector<std::vector<size_t>>> &eli_list, | |||
| const std::vector<std::vector<std::string>> &input_tensor_names, | |||
| const std::shared_ptr<std::vector<size_t>> &index_list); | |||
| const std::shared_ptr<std::vector<size_t>> &index_list, bool is_training); | |||
| Strategys PrepareMatMul(const std::shared_ptr<Graph> &graph, const std::vector<std::shared_ptr<OperatorInfo>> &ops, | |||
| const size_t iter_graph, const size_t iter_ops); | |||
| Strategys PrepareBiasAdd(const std::shared_ptr<Dimensions> &s); | |||
| @@ -55,6 +55,7 @@ Strategys CheckDivisible(const std::vector<std::shared_ptr<OperatorInfo>> &ops, | |||
| Strategys MakeDataParallelStrategy(const std::shared_ptr<Graph> &graph, | |||
| const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_graph, | |||
| const size_t iter_ops); | |||
| void SetBackToRawStrategy(const std::shared_ptr<OperatorInfo> &op); | |||
| Strategys PrepareStrategy(const std::shared_ptr<Graph> &graph, const std::vector<std::shared_ptr<OperatorInfo>> &ops, | |||
| const size_t iter_graph, const size_t iter_ops); | |||
| void GeneratePartitionedOperatorStrategy(const std::shared_ptr<Graph> &graph, | |||
| @@ -170,6 +170,8 @@ class OperatorInfo { | |||
| // needed by rec_parser | |||
| void set_type(const std::string &type) { type_ = type; } | |||
| const std::string &type() const { return type_; } | |||
| void set_last_node_flag(const bool &is_last_node) { is_last_node_ = is_last_node; } | |||
| const bool &is_last_node() const { return is_last_node_; } | |||
| const std::unordered_map<std::string, ValuePtr> &attrs() const { return attrs_; } | |||
| void set_stage_id(int32_t stage_id) { stage_id_ = stage_id; } | |||
| int32_t stage_id() const { return stage_id_; } | |||
| @@ -181,6 +183,7 @@ class OperatorInfo { | |||
| protected: | |||
| // needed by rec_parser | |||
| std::string type_; | |||
| bool is_last_node_; | |||
| virtual Status CheckStrategy(const StrategyPtr &strategy) = 0; | |||
| virtual Status InferTensorMap() = 0; | |||
| virtual Status InferForwardCommunication() = 0; | |||
| @@ -421,6 +421,7 @@ Status ConstructCostGraphNodesByUniqueId(const std::vector<AnfNodePtr> &all_node | |||
| } | |||
| // Needed by rec_parser | |||
| operator_info->set_type(prim->name()); | |||
| operator_info->set_last_node_flag(is_last_nodes); | |||
| std::vector<std::string> inputs_tensor_name = ExtractInputsTensorName(cnode); | |||
| entire_costgraph->AddOperator(operator_info); | |||
| @@ -523,6 +524,7 @@ Status ConstructCostGraphNodesByUniqueIdTC(const std::vector<AnfNodePtr> &all_no | |||
| } | |||
| // Needed by rec_parser | |||
| operator_info->set_type(prim->name()); | |||
| operator_info->set_last_node_flag(is_last_nodes); | |||
| std::vector<std::string> inputs_tensor_name = ExtractInputsTensorName(cnode); | |||
| entire_costgraph->AddOperator(operator_info); | |||
| @@ -1037,7 +1039,11 @@ Status ParallelStrategyRecSearch(const std::vector<AnfNodePtr> &all_nodes, const | |||
| return FAILED; | |||
| } | |||
| GenerateStrategy(graph, ops, eli_list, input_tensor_names, index_list); | |||
| bool is_training = true; | |||
| if (!root->has_flag(TRAINING)) { | |||
| is_training = false; | |||
| } | |||
| GenerateStrategy(graph, ops, eli_list, input_tensor_names, index_list, is_training); | |||
| if (entire_costgraph->InitSelectedStrategy() == SUCCESS) { | |||
| MS_LOG(INFO) << "Init selected strategy succeeded."; | |||