From: @ch-l
Reviewed-by: @kisnwang, @stsuteng
Signed-off-by: @stsuteng
Tags: v1.2.0-rc1
@@ -32,7 +32,7 @@ namespace parallel {
 void GenerateStrategy(const std::shared_ptr<Graph> &graph, const std::vector<std::shared_ptr<OperatorInfo>> &ops,
                       const std::shared_ptr<std::vector<std::vector<size_t>>> &eli_list,
                       const std::vector<std::vector<std::string>> &input_tensor_names,
-                      const std::shared_ptr<std::vector<size_t>> &index_list) {
+                      const std::shared_ptr<std::vector<size_t>> &index_list, bool is_training) {
   MS_EXCEPTION_IF_NULL(graph);
   MS_EXCEPTION_IF_NULL(eli_list);
   MS_EXCEPTION_IF_NULL(index_list);
@@ -46,11 +46,18 @@ void GenerateStrategy(const std::shared_ptr<Graph> &graph, const std::vector<std
   GenerateRemainingOperatorStrategy(graph, ops, input_tensor_names, index_list, no_stra_op_list);
   for (auto &op : ops) {
+    // Set user-defined strategy
     auto attrs = op->attrs();
     if (StrategyFound(attrs)) {
       StrategyPtr user_defined_stra = parallel::ExtractStrategy(attrs);
       op->SetSelectedStrategyAndCost(user_defined_stra, op->selected_cost());
     }
+    // Set back to raw strategy for special node in predict/eval
+    if (!is_training) {
+      if ((op->is_last_node()) || (op->type() == "_VirtualDataset")) {
+        SetBackToRawStrategy(op);
+      }
+    }
   }
 }
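
The new branch only fires outside training, and only for two kinds of operators. Read in isolation it amounts to the small predicate below, a minimal sketch whose helper name (NeedsRawStrategy) is illustrative and not part of the patch:

#include <cassert>
#include <string>

// Returns true when an operator should fall back to the raw (all-ones)
// strategy: only outside training, and only for the graph's last node or
// a _VirtualDataset operator.
bool NeedsRawStrategy(bool is_training, bool is_last_node, const std::string &op_type) {
  if (is_training) {
    return false;
  }
  return is_last_node || op_type == "_VirtualDataset";
}

int main() {
  assert(!NeedsRawStrategy(true, true, "MatMul"));            // training: never reset
  assert(NeedsRawStrategy(false, true, "MatMul"));            // eval: last node is reset
  assert(NeedsRawStrategy(false, false, "_VirtualDataset"));  // eval: virtual dataset is reset
  assert(!NeedsRawStrategy(false, false, "MatMul"));          // eval: ordinary op keeps its strategy
  return 0;
}

Every other operator keeps whatever strategy the recursive search or the user provided.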
@@ -486,6 +493,29 @@ Strategys MakeDataParallelStrategy(const std::shared_ptr<Graph> &graph,
   return strategies;
 }
 
+void SetBackToRawStrategy(const std::shared_ptr<OperatorInfo> &op) {
+  StrategyPtr origin_strategy = op->strategy();
+  Strategys strategies;
+  for (size_t iter_strategy = 0; iter_strategy < origin_strategy->GetInputDim().size(); iter_strategy++) {
+    Dimensions s;
+    size_t strategy_size = origin_strategy->GetInputDim()[iter_strategy].size();
+    for (size_t dim = 0; dim < strategy_size; dim++) {
+      if (strategy_size >= 1 && strategy_size <= 4) {
+        s.push_back(1);
+      } else if (strategy_size == 0) {
+        s = {};
+      } else {
+        MS_LOG(EXCEPTION) << op->name() << ": Strategy size " << strategy_size << " is unmatched.";
+      }
+    }
+    strategies.push_back(s);
+  }
+  StrategyPtr sp = std::make_shared<Strategy>(0, strategies);
+  op->SetSelectedStrategyAndCost(sp, op->selected_cost());
+}
+
 Strategys PrepareStrategy(const std::shared_ptr<Graph> &graph, const std::vector<std::shared_ptr<OperatorInfo>> &ops,
                           const size_t iter_graph, const size_t iter_ops) {
   if (ops.empty()) {
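
SetBackToRawStrategy rewrites whatever strategy was selected so that every dimension of every input is cut into exactly one slice, meaning the operator runs unpartitioned. Below is a standalone sketch of that effect, using plain std::vector aliases in place of MindSpore's Strategy/OperatorInfo types (the real function additionally throws for inputs with more than four dimensions):

#include <cstdint>
#include <iostream>
#include <vector>

using Dimensions = std::vector<int64_t>;
using Strategys = std::vector<Dimensions>;

// Reset every partition number to 1 so the operator is no longer split.
Strategys BackToRaw(const Strategys &origin) {
  Strategys raw;
  for (const auto &dims : origin) {
    // Each input keeps its rank; every dimension's cut count becomes 1.
    raw.emplace_back(Dimensions(dims.size(), 1));
  }
  return raw;
}

int main() {
  Strategys selected = {{2, 4}, {4, 1}};  // e.g. a MatMul split across 8 devices
  for (const auto &dims : BackToRaw(selected)) {
    for (auto d : dims) {
      std::cout << d << ' ';
    }
    std::cout << '\n';  // prints "1 1" twice: the raw, unpartitioned strategy
  }
  return 0;
}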
@@ -30,7 +30,7 @@ namespace parallel {
 void GenerateStrategy(const std::shared_ptr<Graph> &graph, const std::vector<std::shared_ptr<OperatorInfo>> &ops,
                       const std::shared_ptr<std::vector<std::vector<size_t>>> &eli_list,
                       const std::vector<std::vector<std::string>> &input_tensor_names,
-                      const std::shared_ptr<std::vector<size_t>> &index_list);
+                      const std::shared_ptr<std::vector<size_t>> &index_list, bool is_training);
 Strategys PrepareMatMul(const std::shared_ptr<Graph> &graph, const std::vector<std::shared_ptr<OperatorInfo>> &ops,
                         const size_t iter_graph, const size_t iter_ops);
 Strategys PrepareBiasAdd(const std::shared_ptr<Dimensions> &s);
@@ -55,6 +55,7 @@ Strategys CheckDivisible(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
 Strategys MakeDataParallelStrategy(const std::shared_ptr<Graph> &graph,
                                    const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_graph,
                                    const size_t iter_ops);
+void SetBackToRawStrategy(const std::shared_ptr<OperatorInfo> &op);
 Strategys PrepareStrategy(const std::shared_ptr<Graph> &graph, const std::vector<std::shared_ptr<OperatorInfo>> &ops,
                           const size_t iter_graph, const size_t iter_ops);
 void GeneratePartitionedOperatorStrategy(const std::shared_ptr<Graph> &graph,
@@ -170,6 +170,8 @@ class OperatorInfo {
   // needed by rec_parser
   void set_type(const std::string &type) { type_ = type; }
   const std::string &type() const { return type_; }
+  void set_last_node_flag(const bool &is_last_node) { is_last_node_ = is_last_node; }
+  const bool &is_last_node() const { return is_last_node_; }
   const std::unordered_map<std::string, ValuePtr> &attrs() const { return attrs_; }
   void set_stage_id(int32_t stage_id) { stage_id_ = stage_id; }
   int32_t stage_id() const { return stage_id_; }
@@ -181,6 +183,7 @@ class OperatorInfo {
  protected:
   // needed by rec_parser
   std::string type_;
+  bool is_last_node_;
   virtual Status CheckStrategy(const StrategyPtr &strategy) = 0;
   virtual Status InferTensorMap() = 0;
   virtual Status InferForwardCommunication() = 0;
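
The new accessor pair follows the existing set_type/type() pattern on OperatorInfo. A stripped-down illustration on a minimal class is shown below; the default initializer for the flag is an addition in this sketch (the patch declares the member without one) and only documents that false is meant to read as "not the last node":

#include <cassert>

// Minimal stand-in for OperatorInfo showing just the new last-node flag.
class MiniOperatorInfo {
 public:
  void set_last_node_flag(const bool &is_last_node) { is_last_node_ = is_last_node; }
  const bool &is_last_node() const { return is_last_node_; }

 private:
  // Default-initialized here so an unset flag reads as false; the patch
  // declares the member without an initializer.
  bool is_last_node_ = false;
};

int main() {
  MiniOperatorInfo op;
  assert(!op.is_last_node());
  op.set_last_node_flag(true);
  assert(op.is_last_node());
  return 0;
}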
@@ -421,6 +421,7 @@ Status ConstructCostGraphNodesByUniqueId(const std::vector<AnfNodePtr> &all_node
   }
   // Needed by rec_parser
   operator_info->set_type(prim->name());
+  operator_info->set_last_node_flag(is_last_nodes);
   std::vector<std::string> inputs_tensor_name = ExtractInputsTensorName(cnode);
   entire_costgraph->AddOperator(operator_info);
@@ -523,6 +524,7 @@ Status ConstructCostGraphNodesByUniqueIdTC(const std::vector<AnfNodePtr> &all_no
   }
   // Needed by rec_parser
   operator_info->set_type(prim->name());
+  operator_info->set_last_node_flag(is_last_nodes);
   std::vector<std::string> inputs_tensor_name = ExtractInputsTensorName(cnode);
   entire_costgraph->AddOperator(operator_info);
@@ -1037,7 +1039,11 @@ Status ParallelStrategyRecSearch(const std::vector<AnfNodePtr> &all_nodes, const
     return FAILED;
   }
-  GenerateStrategy(graph, ops, eli_list, input_tensor_names, index_list);
+  bool is_training = true;
+  if (!root->has_flag(TRAINING)) {
+    is_training = false;
+  }
+  GenerateStrategy(graph, ops, eli_list, input_tensor_names, index_list, is_training);
   if (entire_costgraph->InitSelectedStrategy() == SUCCESS) {
     MS_LOG(INFO) << "Init selected strategy succeeded.";
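
The training flag is derived from the root graph before GenerateStrategy is called. The sketch below shows that the four-line derivation in the patch is equivalent to a single assignment; the graph type, flag string, and names are stand-ins, not MindSpore code:

#include <cassert>
#include <set>
#include <string>

// Sketch only: a stand-in graph type with a has_flag() query.
struct FakeRootGraph {
  std::set<std::string> flags;
  bool has_flag(const std::string &flag) const { return flags.count(flag) > 0; }
};

int main() {
  const std::string TRAINING = "training";  // illustrative flag value
  FakeRootGraph eval_root;                  // no TRAINING flag, i.e. a predict/eval graph

  // Form used in the patch:
  bool is_training = true;
  if (!eval_root.has_flag(TRAINING)) {
    is_training = false;
  }

  // Equivalent single-assignment form:
  bool is_training_direct = eval_root.has_flag(TRAINING);

  assert(is_training == is_training_direct);
  return 0;
}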