diff --git a/mindspore/ccsrc/frontend/parallel/auto_parallel/rec_core/rec_generate_strategy.cc b/mindspore/ccsrc/frontend/parallel/auto_parallel/rec_core/rec_generate_strategy.cc
index 85a5813b5b..0af3e1c464 100644
--- a/mindspore/ccsrc/frontend/parallel/auto_parallel/rec_core/rec_generate_strategy.cc
+++ b/mindspore/ccsrc/frontend/parallel/auto_parallel/rec_core/rec_generate_strategy.cc
@@ -300,6 +300,38 @@ Strategys PrepareL2Normalize(const std::vector<std::shared_ptr<OperatorInfo>> &o
   return strategies;
 }
 
+Strategys PrepareAxisRelatedStrategy(const std::shared_ptr<Graph> &graph,
+                                     const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_graph,
+                                     const size_t iter_ops) {
+  Strategys strategies = MakeRecSearchStrategy(graph, ops, iter_graph, iter_ops);
+  if (strategies.size() < 1) {
+    MS_LOG(EXCEPTION) << ops[iter_ops]->name() << ": get empty Strategy.";
+  }
+
+  int64_t axis = -1;
+  auto iter = ops[iter_ops]->attrs().find(AXIS);
+  if (iter != ops[iter_ops]->attrs().end()) {
+    MS_EXCEPTION_IF_NULL(iter->second);
+    if (iter->second->isa<Int64Imm>()) {
+      axis = iter->second->cast<Int64ImmPtr>()->value();
+    } else {
+      MS_LOG(EXCEPTION) << ops[iter_ops]->name() << ": The value of axis is not int64_t.";
+    }
+  }
+
+  if (axis < 0) {
+    int64_t input_dim = SizeToLong(ops[iter_ops]->inputs_tensor_info()[0].shape().size());
+    axis = input_dim + axis;
+  }
+
+  if (strategies[0][axis] != 1) {
+    strategies[0][axis] = 1;
+    MS_LOG(INFO) << ops[iter_ops]->name() << ": adjust strategy to 1 on axis " << axis;
+  }
+
+  return strategies;
+}
+
 Strategys MakeRecSearchStrategy(const std::shared_ptr<Graph> &graph,
                                 const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_graph,
                                 const size_t iter_ops) {
@@ -437,6 +469,8 @@ Strategys PrepareStrategy(const std::shared_ptr<Graph> &graph, const std::vector
     return PrepareMatMul(graph, ops, iter_graph, iter_ops);
   } else if (type == ONEHOT) {
     return PrepareOneHot(graph, ops, iter_graph, iter_ops);
+  } else if (type == SOFTMAX) {
+    return PrepareAxisRelatedStrategy(graph, ops, iter_graph, iter_ops);
   } else if ((type == SPARSE_SOFTMAX_CROSS_ENTROPY_WITH_LOGITS) || (type == "_VirtualDataset") ||
              (type == "FusedBatchNormEx") || (type == "Dropout")) {
     return MakeDataParallelStrategy(graph, ops, iter_graph, iter_ops);
diff --git a/mindspore/ccsrc/frontend/parallel/auto_parallel/rec_core/rec_generate_strategy.h b/mindspore/ccsrc/frontend/parallel/auto_parallel/rec_core/rec_generate_strategy.h
index 97bb83dec8..ee3dd0463f 100644
--- a/mindspore/ccsrc/frontend/parallel/auto_parallel/rec_core/rec_generate_strategy.h
+++ b/mindspore/ccsrc/frontend/parallel/auto_parallel/rec_core/rec_generate_strategy.h
@@ -36,6 +36,9 @@ Strategys PrepareMatMul(const std::shared_ptr<Graph> &graph, const std::vector<s
 Strategys PrepareBiasAdd(const std::shared_ptr<Dimensions> &s);
 Strategys PrepareOneHot(const std::shared_ptr<Graph> &graph, const std::vector<std::shared_ptr<OperatorInfo>> &ops,
                         const size_t iter_graph, const size_t iter_ops);
+Strategys PrepareAxisRelatedStrategy(const std::shared_ptr<Graph> &graph,
+                                     const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_graph,
+                                     const size_t iter_ops);
 Strategys PrepareGatherV2(const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_ops, Dimensions s);
 Strategys PrepareGatherV2P(const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_ops, Dimensions s);
 Dimensions PrepareGatherV2POutputStrategy(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
diff --git a/mindspore/ccsrc/frontend/parallel/auto_parallel/rec_core/rec_parse_graph.h b/mindspore/ccsrc/frontend/parallel/auto_parallel/rec_core/rec_parse_graph.h
index fd66258fdb..7e361a21cc 100644
--- a/mindspore/ccsrc/frontend/parallel/auto_parallel/rec_core/rec_parse_graph.h
+++ b/mindspore/ccsrc/frontend/parallel/auto_parallel/rec_core/rec_parse_graph.h
@@ -73,9 +73,9 @@ const std::map<std::string, OperatorType> DictOpType{
   {PRELU, OperatorType::kRecPReLU},
   // Elm-wise OP
   {TRANSPOSE, OperatorType::kRecElmWiseOp},
-  {TRANSPOSE, OperatorType::kRecElmWiseOp},
   {L2_NORMALIZE, OperatorType::kRecElmWiseOp},
   {TENSOR_ADD, OperatorType::kRecElmWiseOp},
+  {TENSOR_DOT, OperatorType::kRecElmWiseOp},
   {SUB, OperatorType::kRecElmWiseOp},
   {MUL, OperatorType::kRecElmWiseOp},
   {DIV, OperatorType::kRecElmWiseOp},
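Note on the new PrepareAxisRelatedStrategy (not part of the patch itself): it takes the splitting strategy proposed by MakeRecSearchStrategy, reads the operator's AXIS attribute, normalizes a negative axis against the input rank, and then forces the partition number on that axis to 1, so Softmax is never split along the dimension it normalizes over; PrepareStrategy only routes SOFTMAX to it. The snippet below is a minimal standalone sketch of that axis adjustment; DimStrategy and KeepAxisUnsplit are hypothetical names used only for illustration and do not exist in MindSpore.

    // Illustrative sketch only: normalize a possibly negative axis against the
    // input rank, then keep that axis unsplit by forcing its partition count to 1.
    #include <cstdint>
    #include <iostream>
    #include <stdexcept>
    #include <vector>

    using DimStrategy = std::vector<int64_t>;  // partition number per dimension

    DimStrategy KeepAxisUnsplit(DimStrategy strategy, int64_t axis) {
      const int64_t rank = static_cast<int64_t>(strategy.size());
      if (axis < 0) {
        axis += rank;  // e.g. axis = -1 on a 4-D input becomes axis = 3
      }
      if (axis < 0 || axis >= rank) {
        throw std::out_of_range("axis is outside the input rank");
      }
      if (strategy[axis] != 1) {
        strategy[axis] = 1;  // do not partition this axis across devices
      }
      return strategy;
    }

    int main() {
      DimStrategy proposed{2, 1, 1, 4};                      // search split dims 0 and 3
      DimStrategy adjusted = KeepAxisUnsplit(proposed, -1);  // prints: 2 1 1 1
      for (int64_t d : adjusted) {
        std::cout << d << ' ';
      }
      std::cout << '\n';
      return 0;
    }

In the actual patch the adjustment is applied only to strategies[0], i.e. the strategy of the operator's first input, which is the tensor the Softmax axis refers to.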