| @@ -25,6 +25,7 @@ | |||||
| #include "frontend/parallel/auto_parallel/rec_core/rec_partition.h" | #include "frontend/parallel/auto_parallel/rec_core/rec_partition.h" | ||||
| #include "frontend/parallel/ops_info/operator_info.h" | #include "frontend/parallel/ops_info/operator_info.h" | ||||
| #include "frontend/parallel/strategy.h" | #include "frontend/parallel/strategy.h" | ||||
| #include "frontend/parallel/step_parallel.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace parallel { | namespace parallel { | ||||
| @@ -43,6 +44,14 @@ void GenerateStrategy(const std::shared_ptr<Graph> &graph, const std::vector<std | |||||
| GenerateEliminatedOperatorStrategyForward(graph, ops, input_tensor_names, index_list, no_stra_op_list); | GenerateEliminatedOperatorStrategyForward(graph, ops, input_tensor_names, index_list, no_stra_op_list); | ||||
| GenerateEliminatedOperatorStrategyBackward(ops, input_tensor_names, no_stra_op_list); | GenerateEliminatedOperatorStrategyBackward(ops, input_tensor_names, no_stra_op_list); | ||||
| GenerateRemainingOperatorStrategy(graph, ops, input_tensor_names, index_list, no_stra_op_list); | GenerateRemainingOperatorStrategy(graph, ops, input_tensor_names, index_list, no_stra_op_list); | ||||
| for (auto &op : ops) { | |||||
| auto attrs = op->attrs(); | |||||
| if (StrategyFound(attrs)) { | |||||
| StrategyPtr user_defined_stra = parallel::ExtractStrategy(attrs); | |||||
| op->SetSelectedStrategyAndCost(user_defined_stra, op->selected_cost()); | |||||
| } | |||||
| } | |||||
| } | } | ||||
| Strategys PrepareMatMul(const std::shared_ptr<Graph> &graph, const std::vector<std::shared_ptr<OperatorInfo>> &ops, | Strategys PrepareMatMul(const std::shared_ptr<Graph> &graph, const std::vector<std::shared_ptr<OperatorInfo>> &ops, | ||||