Browse Source

!9722 add tuple check

From: @caozhou_huawei
Reviewed-by: @zhunaipan,@zhoufeng54
Signed-off-by: @zhoufeng54
tags/v1.1.0
mindspore-ci-bot Gitee 5 years ago
parent
commit
7c50c8052a
1 changed files with 6 additions and 5 deletions
  1. +6
    -5
      mindspore/train/serialization.py

+ 6
- 5
mindspore/train/serialization.py View File

@@ -910,9 +910,9 @@ def load_distributed_checkpoint(network, checkpoint_filenames, predict_strategy=
Args:
network (Cell): Network for distributed predication.
checkpoint_filenames (list[str]): The name of Checkpoint files in order of rank id.
predict_strategy (dict): Strategy of predication process, whose key is parameter name, and
value is a list that the first four elements are [dev_matrix, tensor_map, param_split_shape, field].
If None, it means that the predication process just uses single device. Default: None.
predict_strategy (dict): Strategy of predication process, whose key is parameter name, and value is a list or
a tuple that the first four elements are [dev_matrix, tensor_map, param_split_shape, field]. If None,
it means that the predication process just uses single device. Default: None.

Raises:
TypeError: The type of inputs do not match the requirements.
@@ -927,7 +927,7 @@ def load_distributed_checkpoint(network, checkpoint_filenames, predict_strategy=

if not _check_predict_strategy(predict_strategy):
raise ValueError(f"Please make sure that the key of predict_strategy is str, "
f"and the value is a list that the first four elements are "
f"and the value is a list or a tuple that the first four elements are "
f"dev_matrix (list[int]), tensor_map (list[int]), "
f"param_split_shape (list[int]) and field_size (zero).")

@@ -980,7 +980,8 @@ def _check_predict_strategy(predict_strategy):

predict_strategy = Validator.check_isinstance("predict_strategy", predict_strategy, dict)
for key in predict_strategy.keys():
if not isinstance(key, str) or not isinstance(predict_strategy[key], list) or len(predict_strategy[key]) < 4:
if not isinstance(key, str) or not isinstance(predict_strategy[key], (list, tuple)) \
or len(predict_strategy[key]) < 4:
return False
dev_matrix, tensor_map, param_split_shape, field_size = predict_strategy[key][:4]
if not _check_int_list(dev_matrix) or not _check_int_list(tensor_map) or \


Loading…
Cancel
Save