Browse Source

!10209 [AutoParallell] take manually configured strategies in consideration

From: @ch-l
Reviewed-by: @kisnwang,@stsuteng
Signed-off-by: @stsuteng
tags/v1.1.0
mindspore-ci-bot Gitee 5 years ago
parent
commit
84f9a2bdde
1 changed files with 9 additions and 0 deletions
  1. +9
    -0
      mindspore/ccsrc/frontend/parallel/auto_parallel/rec_core/rec_generate_strategy.cc

+ 9
- 0
mindspore/ccsrc/frontend/parallel/auto_parallel/rec_core/rec_generate_strategy.cc View File

@@ -25,6 +25,7 @@
#include "frontend/parallel/auto_parallel/rec_core/rec_partition.h" #include "frontend/parallel/auto_parallel/rec_core/rec_partition.h"
#include "frontend/parallel/ops_info/operator_info.h" #include "frontend/parallel/ops_info/operator_info.h"
#include "frontend/parallel/strategy.h" #include "frontend/parallel/strategy.h"
#include "frontend/parallel/step_parallel.h"


namespace mindspore { namespace mindspore {
namespace parallel { namespace parallel {
@@ -43,6 +44,14 @@ void GenerateStrategy(const std::shared_ptr<Graph> &graph, const std::vector<std
GenerateEliminatedOperatorStrategyForward(graph, ops, input_tensor_names, index_list, no_stra_op_list); GenerateEliminatedOperatorStrategyForward(graph, ops, input_tensor_names, index_list, no_stra_op_list);
GenerateEliminatedOperatorStrategyBackward(ops, input_tensor_names, no_stra_op_list); GenerateEliminatedOperatorStrategyBackward(ops, input_tensor_names, no_stra_op_list);
GenerateRemainingOperatorStrategy(graph, ops, input_tensor_names, index_list, no_stra_op_list); GenerateRemainingOperatorStrategy(graph, ops, input_tensor_names, index_list, no_stra_op_list);

for (auto &op : ops) {
auto attrs = op->attrs();
if (StrategyFound(attrs)) {
StrategyPtr user_defined_stra = parallel::ExtractStrategy(attrs);
op->SetSelectedStrategyAndCost(user_defined_stra, op->selected_cost());
}
}
} }


Strategys PrepareMatMul(const std::shared_ptr<Graph> &graph, const std::vector<std::shared_ptr<OperatorInfo>> &ops, Strategys PrepareMatMul(const std::shared_ptr<Graph> &graph, const std::vector<std::shared_ptr<OperatorInfo>> &ops,


Loading…
Cancel
Save