|
|
|
@@ -25,6 +25,7 @@ |
|
|
|
#include "frontend/parallel/auto_parallel/rec_core/rec_partition.h" |
|
|
|
#include "frontend/parallel/ops_info/operator_info.h" |
|
|
|
#include "frontend/parallel/strategy.h" |
|
|
|
#include "frontend/parallel/step_parallel.h" |
|
|
|
|
|
|
|
namespace mindspore { |
|
|
|
namespace parallel { |
|
|
|
@@ -43,6 +44,14 @@ void GenerateStrategy(const std::shared_ptr<Graph> &graph, const std::vector<std |
|
|
|
GenerateEliminatedOperatorStrategyForward(graph, ops, input_tensor_names, index_list, no_stra_op_list); |
|
|
|
GenerateEliminatedOperatorStrategyBackward(ops, input_tensor_names, no_stra_op_list); |
|
|
|
GenerateRemainingOperatorStrategy(graph, ops, input_tensor_names, index_list, no_stra_op_list); |
|
|
|
|
|
|
|
for (auto &op : ops) { |
|
|
|
auto attrs = op->attrs(); |
|
|
|
if (StrategyFound(attrs)) { |
|
|
|
StrategyPtr user_defined_stra = parallel::ExtractStrategy(attrs); |
|
|
|
op->SetSelectedStrategyAndCost(user_defined_stra, op->selected_cost()); |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
Strategys PrepareMatMul(const std::shared_ptr<Graph> &graph, const std::vector<std::shared_ptr<OperatorInfo>> &ops, |
|
|
|
|