diff --git a/mindspore/parallel/_utils.py b/mindspore/parallel/_utils.py index 0a1b9f2328..e602fa6a3a 100644 --- a/mindspore/parallel/_utils.py +++ b/mindspore/parallel/_utils.py @@ -259,7 +259,7 @@ def _infer_rank_list(train_map, predict_map=None): logger.warning("predict_map does not contain %s", param_name) continue predict_layout = predict_map[param_name] - dev_num = np.array(predict_layout[0].prod()) + dev_num = np.array(predict_layout[0]).prod() # optimization pass if _check_same_layout(train_layout, predict_layout): dev_rank = _get_global_rank() diff --git a/tests/ut/python/parallel/test_distribute_predict.py b/tests/ut/python/parallel/test_distribute_predict.py index acc7596f83..609b3e01cf 100644 --- a/tests/ut/python/parallel/test_distribute_predict.py +++ b/tests/ut/python/parallel/test_distribute_predict.py @@ -19,6 +19,7 @@ import mindspore.nn as nn from mindspore import Tensor, Model from mindspore.ops import operations as P from mindspore import context +from mindspore.parallel._utils import _infer_rank_list class Net(nn.Cell): @@ -71,3 +72,48 @@ def test_edge_case(): context.set_auto_parallel_context(full_batch=True, enable_parallel_optimizer=True) with pytest.raises(RuntimeError): model.predict(inputs) + + +# standalone predict +def test_infer_rank_list1(): + train_map = {'weight': [[4, 8], [-1, 0]]} + predict_map = None + rank_list = _infer_rank_list(train_map, predict_map)["weight"] + assert list(rank_list[0]) == [0, 1, 2, 3, 4, 5, 6, 7] + assert rank_list[1] is False + + +# similar layout: gpt3 prediction mode +def test_infer_rank_list2(): + train_map = {'weight': [[4, 8], [-1, 0]]} + predict_map = {'weight': [[8], [-1, 0]]} + rank_list = _infer_rank_list(train_map, predict_map) + expect_map = {'weight': ([0], True)} + assert rank_list == expect_map + + +# same layout +def test_infer_rank_list3(): + train_map = {'weight': [[4, 8], [-1, 0]]} + predict_map = {'weight': [[4, 8], [-1, 0]]} + rank_list = _infer_rank_list(train_map, predict_map) + expect_map = {'weight': ([0], True)} + assert rank_list == expect_map + + +# totally different layout +def test_infer_rank_list4(): + train_map = {'weight': [[4, 8], [-1, 0]]} + predict_map = {'weight': [[2, 2], [1, 0]]} + rank_list = _infer_rank_list(train_map, predict_map)["weight"] + assert list(rank_list[0]) == [0, 1, 2, 3, 4, 5, 6, 7] + assert rank_list[1] is False + + +# full shape ckpt +def test_infer_rank_list5(): + train_map = {'weight': [[8], [-1, -1]]} + predict_map = {'weight': [[2, 2], [1, 0]]} + rank_list = _infer_rank_list(train_map, predict_map) + expect_map = {'weight': ([0], False)} + assert rank_list == expect_map