diff --git a/mindspore/ccsrc/frontend/parallel/step_auto_parallel.cc b/mindspore/ccsrc/frontend/parallel/step_auto_parallel.cc index 08ef6fad83..cbd285c214 100644 --- a/mindspore/ccsrc/frontend/parallel/step_auto_parallel.cc +++ b/mindspore/ccsrc/frontend/parallel/step_auto_parallel.cc @@ -264,7 +264,8 @@ bool IsSplittableOperator(const std::string &op_name) { MAX_POOL_WITH_ARGMAX, SIMPLE_MEAN, FLATTEN, BATCH_NORM, LAYER_NORM, BIAS_ADD, ASSIGN_SUB, COS, ACOS, EXP, LOG, REDUCE_MEAN, REAL_DIV, SIGMOID, POW, MAXIMUM, MINIMUM, EQUAL, NOT_EQUAL, LOGICALNOT, GATHERV2, SQRT, CONCAT, STRIDEDSLICE, GET_NEXT, CAST, NEG, SQUARE, BATCH_MATMUL, EXPAND_DIMS, SQUEEZE, SPARSE_GATHERV2, TILE, DROPOUT, - SOFTMAX_CROSS_ENTROPY_WITH_LOGITS, SIGMOID_CROSS_ENTROPY_WITH_LOGITS, SPARSE_SOFTMAX_CROSS_ENTROPY_WITH_LOGITS}; + SOFTMAX_CROSS_ENTROPY_WITH_LOGITS, SIGMOID_CROSS_ENTROPY_WITH_LOGITS, SPARSE_SOFTMAX_CROSS_ENTROPY_WITH_LOGITS, + EMBEDDING_LOOKUP}; // clang-format on auto iter = splittable_op.find(op_name); diff --git a/tests/ut/python/parallel/test_manual_embedding_lookup.py b/tests/ut/python/parallel/test_manual_embedding_lookup.py index 0c8c038e0e..945296dcec 100644 --- a/tests/ut/python/parallel/test_manual_embedding_lookup.py +++ b/tests/ut/python/parallel/test_manual_embedding_lookup.py @@ -115,6 +115,13 @@ def test_auto_parallel_error(): compile_net(net) +def test_auto_parallel(): + context.set_context(save_graphs=True) + context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=2, global_rank=0) + net = Net(split_string="fake") + compile_net(net) + + def test_axis_error(): context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=2, global_rank=0) strategy1 = ((2, 1), (1, 2))