|
|
|
@@ -48,74 +48,6 @@ |
|
|
|
|
|
|
|
namespace mindspore { |
|
|
|
namespace parallel { |
|
|
|
// splittable_op_ will continuously be updated |
|
|
|
std::vector<std::string> splittable_op_ = {MATMUL, |
|
|
|
GELU, |
|
|
|
TANH, |
|
|
|
SOFTMAX, |
|
|
|
LOG_SOFTMAX, |
|
|
|
ACTIVATION, |
|
|
|
PRELU, |
|
|
|
FLOORDIV, |
|
|
|
L2_NORMALIZE, |
|
|
|
TRANSPOSE, |
|
|
|
RESHAPE, |
|
|
|
TENSOR_ADD, |
|
|
|
SUB, |
|
|
|
MUL, |
|
|
|
DIV, |
|
|
|
GREATER, |
|
|
|
MAXPOOL, |
|
|
|
MAXPOOLV2, |
|
|
|
VIRTUAL_DATA_SET, |
|
|
|
SPARSE_SOFTMAX_CROSS_ENTROPY_WITH_LOGITS, |
|
|
|
RELU, |
|
|
|
ONEHOT, |
|
|
|
DROPOUT_DO_MASK, |
|
|
|
REDUCE_MAX, |
|
|
|
REDUCE_MIN, |
|
|
|
ARGMAXWITHVALUE, |
|
|
|
ARGMINWITHVALUE, |
|
|
|
REDUCE_SUM, |
|
|
|
CONV2D, |
|
|
|
FUSE_BATCH_NORM, |
|
|
|
POOLING, |
|
|
|
SOFTMAX_CROSS_ENTROPY_WITH_LOGITS, |
|
|
|
SIGMOID_CROSS_ENTROPY_WITH_LOGITS, |
|
|
|
MAX_POOL_WITH_ARGMAX, |
|
|
|
SIMPLE_MEAN, |
|
|
|
FLATTEN, |
|
|
|
BATCH_NORM, |
|
|
|
LAYER_NORM, |
|
|
|
BIAS_ADD, |
|
|
|
ASSIGN_SUB, |
|
|
|
COS, |
|
|
|
ACOS, |
|
|
|
EXP, |
|
|
|
LOG, |
|
|
|
REDUCE_MEAN, |
|
|
|
REAL_DIV, |
|
|
|
SIGMOID, |
|
|
|
POW, |
|
|
|
MAXIMUM, |
|
|
|
MINIMUM, |
|
|
|
EQUAL, |
|
|
|
NOT_EQUAL, |
|
|
|
LOGICALNOT, |
|
|
|
GATHERV2, |
|
|
|
STRIDEDSLICE, |
|
|
|
SQRT, |
|
|
|
GET_NEXT, |
|
|
|
CAST, |
|
|
|
NEG, |
|
|
|
SQUARE, |
|
|
|
BATCH_MATMUL, |
|
|
|
EXPAND_DIMS, |
|
|
|
SQUEEZE}; |
|
|
|
|
|
|
|
// Operators that act element-wise: output partitioning mirrors input partitioning.
std::vector<std::string> elementwise_op_ = {
  ACTIVATION, GELU, TANH, SOFTMAX, LOG_SOFTMAX, RELU, SQRT, CAST,
  POW,        EXP,  LOG,  COS,     ACOS,        LOGICALNOT, NEG, SQUARE};
|
|
|
|
|
|
|
bool StepAutoParallel(const FuncGraphPtr &root, const opt::OptimizerPtr &) { |
|
|
|
MS_EXCEPTION_IF_NULL(root); |
|
|
|
MS_EXCEPTION_IF_NULL(ParallelContext::GetInstance()); |
|
|
|
@@ -314,14 +246,27 @@ std::vector<TypePtr> ExtractOutputTypeByNode(const CNodePtr &node) { |
|
|
|
} |
|
|
|
|
|
|
|
bool IsElementWiseOperator(const std::string &op_name) { |
|
|
|
auto iter = std::find(elementwise_op_.begin(), elementwise_op_.end(), op_name); |
|
|
|
return (iter != elementwise_op_.end()); |
|
|
|
static const std::set<std::string> elementwise_op = {ACTIVATION, GELU, TANH, SOFTMAX, LOG_SOFTMAX, RELU, |
|
|
|
SQRT, CAST, POW, EXP, LOG, COS, |
|
|
|
ACOS, LOGICALNOT, NEG, SQUARE, SIGMOID}; |
|
|
|
auto iter = elementwise_op.find(op_name); |
|
|
|
return (iter != elementwise_op.end()); |
|
|
|
} |
|
|
|
|
|
|
|
bool IsSplittableOperator(const std::string &op_name) { |
|
|
|
std::vector<std::string>::iterator iter; |
|
|
|
iter = std::find(splittable_op_.begin(), splittable_op_.end(), op_name); |
|
|
|
return (iter != splittable_op_.end()); |
|
|
|
// clang-format off |
|
|
|
static const std::set<std::string> splittable_op = |
|
|
|
{MATMUL, TRANSPOSE, GELU, TANH, SOFTMAX, SUB, MUL, DIV, RESHAPE, GREATER, LOG_SOFTMAX, ACTIVATION, PRELU, |
|
|
|
FLOORDIV, L2_NORMALIZE, TENSOR_ADD, MAXPOOL, MAXPOOLV2, VIRTUAL_DATA_SET, RELU, ONEHOT, DROPOUT_DO_MASK, |
|
|
|
REDUCE_MAX, REDUCE_MIN, ARGMAXWITHVALUE, ARGMINWITHVALUE, REDUCE_SUM, CONV2D, FUSE_BATCH_NORM, POOLING, |
|
|
|
MAX_POOL_WITH_ARGMAX, SIMPLE_MEAN, FLATTEN, BATCH_NORM, LAYER_NORM, BIAS_ADD, ASSIGN_SUB, COS, ACOS, EXP, |
|
|
|
LOG, REDUCE_MEAN, REAL_DIV, SIGMOID, POW, MAXIMUM, MINIMUM, EQUAL, NOT_EQUAL, LOGICALNOT, GATHERV2, SQRT, |
|
|
|
STRIDEDSLICE, GET_NEXT, CAST, NEG, SQUARE, BATCH_MATMUL, EXPAND_DIMS, SQUEEZE, |
|
|
|
SOFTMAX_CROSS_ENTROPY_WITH_LOGITS, SIGMOID_CROSS_ENTROPY_WITH_LOGITS, SPARSE_SOFTMAX_CROSS_ENTROPY_WITH_LOGITS}; |
|
|
|
// clang-format on |
|
|
|
|
|
|
|
auto iter = splittable_op.find(op_name); |
|
|
|
return (iter != splittable_op.end()); |
|
|
|
} |
|
|
|
|
|
|
|
bool IsAutoParallelCareNode(const CNodePtr &cnode) { |
|
|
|
|