
fix distributed predict

tags/v1.1.0
Ziyan, 5 years ago
commit e7a24611f4
3 changed files with 27 additions and 10 deletions
  1. mindspore/parallel/_utils.py  +20 -6
  2. mindspore/train/model.py  +4 -2
  3. mindspore/train/serialization.py  +3 -2

mindspore/parallel/_utils.py  +20 -6

@@ -218,6 +218,11 @@ def _check_similar_layout(tensor_layout1, tensor_layout2):
     return True


+def _check_same_layout(tensor_layout1, tensor_layout2):
+    """check if two tensor layouts are same"""
+    return tensor_layout1[0] == tensor_layout2[0] and tensor_layout1[1] == tensor_layout2[1]
+
+
 def _remove_repeated_slices(tensor_layout):
     """generate unrepeated tensor layout"""
     import copy
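A layout here is the per-parameter tuple recorded by the parallel strategy; judging from the [0]/[1] indexing in the new helper, position 0 is the device matrix and position 1 is the tensor map (an assumption made for illustration only). A minimal sketch of what the helper treats as "same":

# Sketch only: the (dev_matrix, tensor_map) shape of a layout tuple is an
# assumption inferred from the [0]/[1] indexing in _check_same_layout.
def check_same_layout(layout1, layout2):
    return layout1[0] == layout2[0] and layout1[1] == layout2[1]

layout_a = ([4, 2], [1, 0])   # device matrix 4x2, tensor map [1, 0]
layout_b = ([4, 2], [1, 0])   # identical sharding
layout_c = ([8], [0, -1])     # different device matrix and tensor map

print(check_same_layout(layout_a, layout_b))  # True: slices line up 1:1
print(check_same_layout(layout_a, layout_c))  # False: resharding is needed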
@@ -236,9 +241,14 @@ def _infer_rank_list(train_map, predict_map=None):
     ret = {}
     for param_name in train_map:
         train_layout = train_map[param_name]
-        new_train_layout = _remove_repeated_slices(train_layout)
+        predict_layout = predict_map[param_name]
         train_dev_mat = train_layout[0]
         dev_num = np.array(train_dev_mat).prod()
+        if _check_same_layout(train_layout, predict_layout):
+            dev_rank = _get_global_rank()
+            ret[param_name] = ([dev_rank], True)
+            continue
+        new_train_layout = _remove_repeated_slices(train_layout)
         array = np.arange(dev_num).reshape(train_dev_mat)
         index = ()
         for i in new_train_layout[0]:
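The context lines above rely on a compact NumPy trick: ranks are arranged in an array shaped like the training device matrix, dimensions that only replicate (extent 1 after _remove_repeated_slices) are pinned to index 0, and flattening what remains yields one representative rank per distinct slice. A runnable sketch with a hypothetical 2x4 device matrix whose second dimension only replicates:

import numpy as np

train_dev_mat = [2, 4]   # 8 devices arranged 2x4
dedup_dev_mat = [2, 1]   # stand-in for new_train_layout[0]: dim 2 replicates
dev_num = np.array(train_dev_mat).prod()

array = np.arange(dev_num).reshape(train_dev_mat)
index = ()
for i in dedup_dev_mat:
    # replicated dim: any rank will do, take index 0; sharded dim: keep all
    index = index + ((0,) if i == 1 else (slice(None),))

print(array)                   # [[0 1 2 3] [4 5 6 7]]
print(array[index].flatten())  # [0 4]: one rank per distinct slice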
@@ -248,16 +258,20 @@ def _infer_rank_list(train_map, predict_map=None):
                 index = index + (slice(None),)
         rank_list = array[index].flatten()
         if not predict_map:
-            ret[param_name] = rank_list
+            ret[param_name] = (rank_list, False)
             continue
         if param_name not in predict_map:
             logger.warning("predict_map does not contain %s", param_name)
             continue
-        predict_layout = predict_map[param_name]
         # optimization pass
         if _check_similar_layout(train_layout, predict_layout):
-            dev_rank = _get_global_rank()
-            ret[param_name] = [rank_list[dev_rank]]
+            if len(rank_list) == 1:
+                ret[param_name] = (rank_list, True)
+            elif len(rank_list) == dev_num:
+                dev_rank = _get_global_rank()
+                ret[param_name] = ([rank_list[dev_rank]], True)
+            else:
+                ret[param_name] = (rank_list, False)
         else:
             ret[param_name] = (rank_list, False)
     return ret
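Taken together, _infer_rank_list now returns a (rank_list, skip_merge_split) pair per parameter instead of a bare rank list, which is exactly what the serialization change below consumes. A hypothetical consumer, with invented parameter names:

# Hypothetical result dict; the parameter names are made up for illustration.
infer_result = {
    "dense.weight": ([3], True),         # same/similar layout: one slice suffices
    "embedding.table": ([0, 4], False),  # layouts differ: merge, then re-split
}

for name, (rank_list, skip_merge_split) in infer_result.items():
    if skip_merge_split:
        print(f"{name}: load the slice from rank {rank_list[0]} as-is")
    else:
        print(f"{name}: merge slices from ranks {rank_list}, then split")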

mindspore/train/model.py  +4 -2

@@ -746,18 +746,20 @@ class Model:
         """
         Generate parameter layout for the predict network in auto or semi auto parallel mode.

-        Data could be a single tensor, a list of tensor, or a tuple of tensor.
+        Data could be a single tensor or multiple tensors.

         Note:
             Batch data should be put together in one tensor.

         Args:
-            predict_data (Tensor): Tensor of predict data. can be array, list or tuple.
+            predict_data (Tensor): One tensor or multiple tensors of predict data.

         Returns:
             parameter_layout_dict (dict): Parameter layout dictionary used for load distributed checkpoint

         Examples:
+            >>> context.set_context(mode=context.GRAPH_MODE)
+            >>> context.set_auto_parallel_context(full_batch=True, parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL)
             >>> input_data = Tensor(np.random.randint(0, 255, [1, 3, 224, 224]), mindspore.float32)
             >>> model = Model(Net())
             >>> model.infer_predict_layout(input_data)
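For context, a hedged end-to-end sketch of how the expanded docstring example fits with the loader change below; the network is a stand-in and the checkpoint file names are placeholders, not part of this commit:

import numpy as np
import mindspore
from mindspore import context, nn, Tensor
from mindspore.context import ParallelMode
from mindspore.train import Model
from mindspore.train.serialization import load_distributed_checkpoint

class Net(nn.Cell):
    """Stand-in predict network, invented for this sketch."""
    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()
        self.dense = nn.Dense(3 * 224 * 224, 10)
    def construct(self, x):
        return self.dense(self.flatten(x))

context.set_context(mode=context.GRAPH_MODE)
context.set_auto_parallel_context(full_batch=True,
                                  parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL)

network = Net()
model = Model(network)
input_data = Tensor(np.random.randint(0, 255, [1, 3, 224, 224]), mindspore.float32)

# 1. Infer how the predict network shards each parameter.
layout_dict = model.infer_predict_layout(input_data)

# 2. Restore the sliced training checkpoints; merging/splitting now happens
#    only when train and predict layouts actually differ.
ckpt_files = ["rank_{}.ckpt".format(i) for i in range(8)]  # placeholder paths
load_distributed_checkpoint(network, ckpt_files, layout_dict)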


mindspore/train/serialization.py  +3 -2

@@ -950,11 +950,12 @@ def load_distributed_checkpoint(network, checkpoint_filenames, predict_strategy=
         sliced_params = []
         if param.name not in rank_list.keys():
             continue
-        param_rank = rank_list[param.name]
+        param_rank = rank_list[param.name][0]
+        skip_merge_split = rank_list[param.name][1]
         for rank in param_rank:
             sliced_param = _load_single_param(checkpoint_filenames[rank], param.name)
             sliced_params.append(sliced_param)
-        if len(sliced_params) == 1:
+        if skip_merge_split:
             split_param = sliced_params[0]
         else:
             param_unique_strategy = _remove_repeated_slices(train_strategy[param.name])
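In plain terms, the loader now trusts the skip_merge_split flag computed by _infer_rank_list instead of guessing from len(sliced_params) == 1, which conflated "one slice was read" with "no resharding is needed". A minimal sketch of the branch, with hypothetical helpers standing in for the MindSpore internals:

# Sketch with hypothetical helpers; load_slice and merge_and_split stand in
# for MindSpore internals (_load_single_param and the merge/split pass).
def load_param(param_name, rank_entry, checkpoint_filenames,
               load_slice, merge_and_split):
    rank_list, skip_merge_split = rank_entry   # new two-field entry per param
    sliced = [load_slice(checkpoint_filenames[rank], param_name)
              for rank in rank_list]
    if skip_merge_split:
        # The single slice already matches this rank's predict layout.
        return sliced[0]
    # Otherwise rebuild the full tensor and re-split it for predict.
    return merge_and_split(sliced)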

