diff --git a/mindspore/ccsrc/frontend/parallel/auto_parallel/rec_core/rec_generate_strategy.cc b/mindspore/ccsrc/frontend/parallel/auto_parallel/rec_core/rec_generate_strategy.cc index f00da76532..6716acd3b3 100644 --- a/mindspore/ccsrc/frontend/parallel/auto_parallel/rec_core/rec_generate_strategy.cc +++ b/mindspore/ccsrc/frontend/parallel/auto_parallel/rec_core/rec_generate_strategy.cc @@ -25,6 +25,7 @@ #include "frontend/parallel/auto_parallel/rec_core/rec_partition.h" #include "frontend/parallel/ops_info/operator_info.h" #include "frontend/parallel/strategy.h" +#include "frontend/parallel/step_parallel.h" namespace mindspore { namespace parallel { @@ -43,6 +44,14 @@ void GenerateStrategy(const std::shared_ptr &graph, const std::vectorattrs(); + if (StrategyFound(attrs)) { + StrategyPtr user_defined_stra = parallel::ExtractStrategy(attrs); + op->SetSelectedStrategyAndCost(user_defined_stra, op->selected_cost()); + } + } } Strategys PrepareMatMul(const std::shared_ptr &graph, const std::vector> &ops,