|
|
|
@@ -910,9 +910,9 @@ def load_distributed_checkpoint(network, checkpoint_filenames, predict_strategy= |
|
|
|
Args: |
|
|
|
network (Cell): Network for distributed predication. |
|
|
|
checkpoint_filenames (list[str]): The name of Checkpoint files in order of rank id. |
|
|
|
predict_strategy (dict): Strategy of predication process, whose key is parameter name, and |
|
|
|
value is a list that the first four elements are [dev_matrix, tensor_map, param_split_shape, field]. |
|
|
|
If None, it means that the predication process just uses single device. Default: None. |
|
|
|
predict_strategy (dict): Strategy of predication process, whose key is parameter name, and value is a list or |
|
|
|
a tuple that the first four elements are [dev_matrix, tensor_map, param_split_shape, field]. If None, |
|
|
|
it means that the predication process just uses single device. Default: None. |
|
|
|
|
|
|
|
Raises: |
|
|
|
TypeError: The type of inputs do not match the requirements. |
|
|
|
@@ -927,7 +927,7 @@ def load_distributed_checkpoint(network, checkpoint_filenames, predict_strategy= |
|
|
|
|
|
|
|
if not _check_predict_strategy(predict_strategy): |
|
|
|
raise ValueError(f"Please make sure that the key of predict_strategy is str, " |
|
|
|
f"and the value is a list that the first four elements are " |
|
|
|
f"and the value is a list or a tuple that the first four elements are " |
|
|
|
f"dev_matrix (list[int]), tensor_map (list[int]), " |
|
|
|
f"param_split_shape (list[int]) and field_size (zero).") |
|
|
|
|
|
|
|
@@ -980,7 +980,8 @@ def _check_predict_strategy(predict_strategy): |
|
|
|
|
|
|
|
predict_strategy = Validator.check_isinstance("predict_strategy", predict_strategy, dict) |
|
|
|
for key in predict_strategy.keys(): |
|
|
|
if not isinstance(key, str) or not isinstance(predict_strategy[key], list) or len(predict_strategy[key]) < 4: |
|
|
|
if not isinstance(key, str) or not isinstance(predict_strategy[key], (list, tuple)) \ |
|
|
|
or len(predict_strategy[key]) < 4: |
|
|
|
return False |
|
|
|
dev_matrix, tensor_map, param_split_shape, field_size = predict_strategy[key][:4] |
|
|
|
if not _check_int_list(dev_matrix) or not _check_int_list(tensor_map) or \ |
|
|
|
|