Browse Source

fix infer rank list typo

tags/v1.2.0-rc1
Ziyan 4 years ago
parent
commit
2c3b99ce91
2 changed files with 47 additions and 1 deletions
  1. +1
    -1
      mindspore/parallel/_utils.py
  2. +46
    -0
      tests/ut/python/parallel/test_distribute_predict.py

+ 1
- 1
mindspore/parallel/_utils.py View File

@@ -259,7 +259,7 @@ def _infer_rank_list(train_map, predict_map=None):
logger.warning("predict_map does not contain %s", param_name) logger.warning("predict_map does not contain %s", param_name)
continue continue
predict_layout = predict_map[param_name] predict_layout = predict_map[param_name]
dev_num = np.array(predict_layout[0].prod())
dev_num = np.array(predict_layout[0]).prod()
# optimization pass # optimization pass
if _check_same_layout(train_layout, predict_layout): if _check_same_layout(train_layout, predict_layout):
dev_rank = _get_global_rank() dev_rank = _get_global_rank()


+ 46
- 0
tests/ut/python/parallel/test_distribute_predict.py View File

@@ -19,6 +19,7 @@ import mindspore.nn as nn
from mindspore import Tensor, Model from mindspore import Tensor, Model
from mindspore.ops import operations as P from mindspore.ops import operations as P
from mindspore import context from mindspore import context
from mindspore.parallel._utils import _infer_rank_list




class Net(nn.Cell): class Net(nn.Cell):
@@ -71,3 +72,48 @@ def test_edge_case():
context.set_auto_parallel_context(full_batch=True, enable_parallel_optimizer=True) context.set_auto_parallel_context(full_batch=True, enable_parallel_optimizer=True)
with pytest.raises(RuntimeError): with pytest.raises(RuntimeError):
model.predict(inputs) model.predict(inputs)


# standalone predict
def test_infer_rank_list1():
train_map = {'weight': [[4, 8], [-1, 0]]}
predict_map = None
rank_list = _infer_rank_list(train_map, predict_map)["weight"]
assert list(rank_list[0]) == [0, 1, 2, 3, 4, 5, 6, 7]
assert rank_list[1] is False


# similar layout: gpt3 prediction mode
def test_infer_rank_list2():
train_map = {'weight': [[4, 8], [-1, 0]]}
predict_map = {'weight': [[8], [-1, 0]]}
rank_list = _infer_rank_list(train_map, predict_map)
expect_map = {'weight': ([0], True)}
assert rank_list == expect_map


# same layout
def test_infer_rank_list3():
train_map = {'weight': [[4, 8], [-1, 0]]}
predict_map = {'weight': [[4, 8], [-1, 0]]}
rank_list = _infer_rank_list(train_map, predict_map)
expect_map = {'weight': ([0], True)}
assert rank_list == expect_map


# totally different layout
def test_infer_rank_list4():
train_map = {'weight': [[4, 8], [-1, 0]]}
predict_map = {'weight': [[2, 2], [1, 0]]}
rank_list = _infer_rank_list(train_map, predict_map)["weight"]
assert list(rank_list[0]) == [0, 1, 2, 3, 4, 5, 6, 7]
assert rank_list[1] is False


# full shape ckpt
def test_infer_rank_list5():
train_map = {'weight': [[8], [-1, -1]]}
predict_map = {'weight': [[2, 2], [1, 0]]}
rank_list = _infer_rank_list(train_map, predict_map)
expect_map = {'weight': ([0], False)}
assert rank_list == expect_map

Loading…
Cancel
Save