diff --git a/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_generate_strategy.cc b/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_generate_strategy.cc index 124b64fb94..c63bb64f59 100644 --- a/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_generate_strategy.cc +++ b/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_generate_strategy.cc @@ -421,6 +421,20 @@ std::vector> GenerateStrategiesFromStrategy(const std::vect if (ops[iter_ops]->type() == ONEHOT) { return PrepareOneHot(s); } + + auto dev_num = g_device_manager->DeviceNum(); + size_t cut_num = 1; + for (size_t i = 0; i < s.size(); i++) { + cut_num *= s[i]; + } + if (cut_num < dev_num) { + size_t diff = dev_num / cut_num; + if (s[0] * diff > dev_num) { + MS_LOG(EXCEPTION) << "Failure: Can not continue to partition in the N-dimension of the element-wise operator."; + } + s[0] = s[0] * diff; + } + for (size_t i = 0; i < (size_t)ops[iter_ops]->inputs_tensor_info().size(); i++) { if (ops[iter_ops]->inputs_tensor_info()[i].shape().size() == 0) { stra.push_back(s_empty); @@ -537,6 +551,11 @@ void GenerateEliminatedOperatorStrategyBackward(const std::vectorat(iter_list); std::vector> stra; std::vector s = CopyOutgoingOperatorInputStrategy(ops, input_tensor_names, iter_ops); + if (s.size() == 0) { + for (size_t i = 0; i < ops[iter_ops]->inputs_tensor_info()[0].shape().size(); i++) { + s.push_back(1); + } + } if (ops[iter_ops]->type() == SQUEEZE) { s = ModifyStrategyIfSqueezeOutgoing(ops, iter_ops, s); } diff --git a/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_parse_graph.h b/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_parse_graph.h index 6af1deea9c..f3b0fbe247 100644 --- a/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_parse_graph.h +++ b/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_parse_graph.h @@ -34,26 +34,72 @@ const std::map DictOpType{ {MAXPOOL, OperatorType::kRecPooling}, {MAXPOOLV2, OperatorType::kRecPooling}, {SIMPLE_MEAN, OperatorType::kRecPooling}, - {TENSOR_ADD, OperatorType::kRecElmWiseOp}, {RESHAPE, OperatorType::kRecReshape}, {BIAS_ADD, OperatorType::kRecBiasAdd}, - {RELU, OperatorType::kRecReLU}, {BATCH_NORM, OperatorType::kRecBatchNorm}, {FUSE_BATCH_NORM, OperatorType::kRecBatchNorm}, - {SOFTMAX_CROSS_ENTROPY_WITH_LOGITS, OperatorType::kRecSparseSoftmaxCrossEntropyWithLogits}, {SPARSE_SOFTMAX_CROSS_ENTROPY_WITH_LOGITS, OperatorType::kRecSparseSoftmaxCrossEntropyWithLogits}, {ONEHOT, OperatorType::kRecOneHot}, - {LOG, OperatorType::kRecLog}, - {EXP, OperatorType::kRecExp}, - {SUB, OperatorType::kRecElmWiseOp}, - {MUL, OperatorType::kRecElmWiseOp}, - {DIV, OperatorType::kRecElmWiseOp}, {SQUEEZE, OperatorType::kRecSqueeze}, {CAST, OperatorType::kRecCast}, {REDUCE_SUM, OperatorType::kRecReduce}, {REDUCE_MAX, OperatorType::kRecReduce}, {REDUCE_MIN, OperatorType::kRecReduce}, - {REDUCE_MEAN, OperatorType::kRecReduce}}; + {REDUCE_MEAN, OperatorType::kRecReduce}, + + {RELU, OperatorType::kRecReLU}, + {"ReLU6", OperatorType::kRecReLU}, + {"ReLUV2", OperatorType::kRecReLU}, + {SIGMOID, OperatorType::kRecReLU}, + {SIGMOID_CROSS_ENTROPY_WITH_LOGITS, OperatorType::kRecReLU}, + {"HSigmoid", OperatorType::kRecReLU}, + {GELU, OperatorType::kRecReLU}, + {TANH, OperatorType::kRecReLU}, + + {TENSOR_ADD, OperatorType::kRecElmWiseOp}, + {SUB, OperatorType::kRecElmWiseOp}, + {MUL, OperatorType::kRecElmWiseOp}, + {DIV, OperatorType::kRecElmWiseOp}, + {REAL_DIV, OperatorType::kRecElmWiseOp}, + {SOFTMAX, OperatorType::kRecElmWiseOp}, + {LOG_SOFTMAX, OperatorType::kRecElmWiseOp}, + {SOFTMAX_CROSS_ENTROPY_WITH_LOGITS, OperatorType::kRecElmWiseOp}, + {SQRT, OperatorType::kRecElmWiseOp}, + {NEG, OperatorType::kRecElmWiseOp}, + {POW, OperatorType::kRecElmWiseOp}, + {EXP, OperatorType::kRecElmWiseOp}, + {LOG, OperatorType::kRecElmWiseOp}, + {COS, OperatorType::kRecElmWiseOp}, + {ACOS, OperatorType::kRecElmWiseOp}, + {LOGICALNOT, OperatorType::kRecElmWiseOp}, + {"LogicalAnd", OperatorType::kRecElmWiseOp}, + {"LogicalOr", OperatorType::kRecElmWiseOp}, + {SQUARE, OperatorType::kRecElmWiseOp}, + {"Abs", OperatorType::kRecElmWiseOp}, + {"Acosh", OperatorType::kRecElmWiseOp}, + {"AddN", OperatorType::kRecElmWiseOp}, + {"Atan2", OperatorType::kRecElmWiseOp}, + {"Erf", OperatorType::kRecElmWiseOp}, + {"Floor", OperatorType::kRecElmWiseOp}, + {FLOORDIV, OperatorType::kRecElmWiseOp}, + {"FloorMod", OperatorType::kRecElmWiseOp}, + {GREATER, OperatorType::kRecElmWiseOp}, + {"GreaterEqual", OperatorType::kRecElmWiseOp}, + {"HSwish", OperatorType::kRecElmWiseOp}, + {"Less", OperatorType::kRecElmWiseOp}, + {"LessEqual", OperatorType::kRecElmWiseOp}, + {MAXIMUM, OperatorType::kRecElmWiseOp}, + {MINIMUM, OperatorType::kRecElmWiseOp}, + {EQUAL, OperatorType::kRecElmWiseOp}, + {NOT_EQUAL, OperatorType::kRecElmWiseOp}, + {"Reciprocal", OperatorType::kRecElmWiseOp}, + {"Round", OperatorType::kRecElmWiseOp}, + {"Rsqrt", OperatorType::kRecElmWiseOp}, + {"Sign", OperatorType::kRecElmWiseOp}, + {"Sin", OperatorType::kRecElmWiseOp}, + {ASSIGN, OperatorType::kRecElmWiseOp}, + {ASSIGN_SUB, OperatorType::kRecElmWiseOp}, + {"AssignAdd", OperatorType::kRecElmWiseOp}}; const TensorParam MakeTensor(int n, int c, int h, int w);