Browse Source

embedding lookup auto parallel

tags/v0.7.0-beta
yangzhenzhang 5 years ago
parent
commit
6f6a8ae9f0
2 changed files with 9 additions and 1 deletions
  1. +2
    -1
      mindspore/ccsrc/frontend/parallel/step_auto_parallel.cc
  2. +7
    -0
      tests/ut/python/parallel/test_manual_embedding_lookup.py

+ 2
- 1
mindspore/ccsrc/frontend/parallel/step_auto_parallel.cc View File

@@ -264,7 +264,8 @@ bool IsSplittableOperator(const std::string &op_name) {
MAX_POOL_WITH_ARGMAX, SIMPLE_MEAN, FLATTEN, BATCH_NORM, LAYER_NORM, BIAS_ADD, ASSIGN_SUB, COS, ACOS, EXP, MAX_POOL_WITH_ARGMAX, SIMPLE_MEAN, FLATTEN, BATCH_NORM, LAYER_NORM, BIAS_ADD, ASSIGN_SUB, COS, ACOS, EXP,
LOG, REDUCE_MEAN, REAL_DIV, SIGMOID, POW, MAXIMUM, MINIMUM, EQUAL, NOT_EQUAL, LOGICALNOT, GATHERV2, SQRT, CONCAT, LOG, REDUCE_MEAN, REAL_DIV, SIGMOID, POW, MAXIMUM, MINIMUM, EQUAL, NOT_EQUAL, LOGICALNOT, GATHERV2, SQRT, CONCAT,
STRIDEDSLICE, GET_NEXT, CAST, NEG, SQUARE, BATCH_MATMUL, EXPAND_DIMS, SQUEEZE, SPARSE_GATHERV2, TILE, DROPOUT, STRIDEDSLICE, GET_NEXT, CAST, NEG, SQUARE, BATCH_MATMUL, EXPAND_DIMS, SQUEEZE, SPARSE_GATHERV2, TILE, DROPOUT,
SOFTMAX_CROSS_ENTROPY_WITH_LOGITS, SIGMOID_CROSS_ENTROPY_WITH_LOGITS, SPARSE_SOFTMAX_CROSS_ENTROPY_WITH_LOGITS};
SOFTMAX_CROSS_ENTROPY_WITH_LOGITS, SIGMOID_CROSS_ENTROPY_WITH_LOGITS, SPARSE_SOFTMAX_CROSS_ENTROPY_WITH_LOGITS,
EMBEDDING_LOOKUP};
// clang-format on // clang-format on


auto iter = splittable_op.find(op_name); auto iter = splittable_op.find(op_name);


+ 7
- 0
tests/ut/python/parallel/test_manual_embedding_lookup.py View File

@@ -115,6 +115,13 @@ def test_auto_parallel_error():
compile_net(net) compile_net(net)




def test_auto_parallel():
context.set_context(save_graphs=True)
context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=2, global_rank=0)
net = Net(split_string="fake")
compile_net(net)


def test_axis_error(): def test_axis_error():
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=2, global_rank=0) context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=2, global_rank=0)
strategy1 = ((2, 1), (1, 2)) strategy1 = ((2, 1), (1, 2))


Loading…
Cancel
Save