|
|
|
@@ -19,6 +19,7 @@ import mindspore.nn as nn |
|
|
|
from mindspore import Tensor, Model |
|
|
|
from mindspore.ops import operations as P |
|
|
|
from mindspore import context |
|
|
|
from mindspore.parallel._utils import _infer_rank_list |
|
|
|
|
|
|
|
|
|
|
|
class Net(nn.Cell): |
|
|
|
@@ -71,3 +72,48 @@ def test_edge_case(): |
|
|
|
context.set_auto_parallel_context(full_batch=True, enable_parallel_optimizer=True) |
|
|
|
with pytest.raises(RuntimeError): |
|
|
|
model.predict(inputs) |
|
|
|
|
|
|
|
|
|
|
|
# standalone predict |
|
|
|
def test_infer_rank_list1(): |
|
|
|
train_map = {'weight': [[4, 8], [-1, 0]]} |
|
|
|
predict_map = None |
|
|
|
rank_list = _infer_rank_list(train_map, predict_map)["weight"] |
|
|
|
assert list(rank_list[0]) == [0, 1, 2, 3, 4, 5, 6, 7] |
|
|
|
assert rank_list[1] is False |
|
|
|
|
|
|
|
|
|
|
|
# similar layout: gpt3 prediction mode |
|
|
|
def test_infer_rank_list2(): |
|
|
|
train_map = {'weight': [[4, 8], [-1, 0]]} |
|
|
|
predict_map = {'weight': [[8], [-1, 0]]} |
|
|
|
rank_list = _infer_rank_list(train_map, predict_map) |
|
|
|
expect_map = {'weight': ([0], True)} |
|
|
|
assert rank_list == expect_map |
|
|
|
|
|
|
|
|
|
|
|
# same layout |
|
|
|
def test_infer_rank_list3(): |
|
|
|
train_map = {'weight': [[4, 8], [-1, 0]]} |
|
|
|
predict_map = {'weight': [[4, 8], [-1, 0]]} |
|
|
|
rank_list = _infer_rank_list(train_map, predict_map) |
|
|
|
expect_map = {'weight': ([0], True)} |
|
|
|
assert rank_list == expect_map |
|
|
|
|
|
|
|
|
|
|
|
# totally different layout |
|
|
|
def test_infer_rank_list4(): |
|
|
|
train_map = {'weight': [[4, 8], [-1, 0]]} |
|
|
|
predict_map = {'weight': [[2, 2], [1, 0]]} |
|
|
|
rank_list = _infer_rank_list(train_map, predict_map)["weight"] |
|
|
|
assert list(rank_list[0]) == [0, 1, 2, 3, 4, 5, 6, 7] |
|
|
|
assert rank_list[1] is False |
|
|
|
|
|
|
|
|
|
|
|
# full shape ckpt |
|
|
|
def test_infer_rank_list5(): |
|
|
|
train_map = {'weight': [[8], [-1, -1]]} |
|
|
|
predict_map = {'weight': [[2, 2], [1, 0]]} |
|
|
|
rank_list = _infer_rank_list(train_map, predict_map) |
|
|
|
expect_map = {'weight': ([0], False)} |
|
|
|
assert rank_list == expect_map |