From 3af96e8d66af41784d0d7b0c70dfdb2687fdb084 Mon Sep 17 00:00:00 2001
From: changzherui
Date: Fri, 7 May 2021 15:31:07 +0800
Subject: [PATCH] modify load_dic_ckpt for r1.2

---
Reviewer notes (text between the three-dash line and the diff is not part of
the commit message):
- shard_size is computed with floor division because it is used as a slice
  bound and as a range() step, both of which reject float values.
- _check_predict_strategy stops at the first malformed entry, so the intended
  ValueError is raised instead of an unrelated unpacking/indexing error.
- The new train_strategy_filename argument is described in the docstring.
- Hunk headers and the diffstat were recomputed for the lines added above.

 mindspore/train/serialization.py | 143 +++++++------------------
 1 file changed, 61 insertions(+), 82 deletions(-)

diff --git a/mindspore/train/serialization.py b/mindspore/train/serialization.py
index ade65aef51..fc384f390b 100644
--- a/mindspore/train/serialization.py
+++ b/mindspore/train/serialization.py
@@ -19,6 +19,7 @@ import stat
 import math
 import shutil
 from threading import Thread, Lock
+from collections import defaultdict
 
 import numpy as np
 import mindspore.nn as nn
@@ -1062,38 +1063,29 @@ def merge_sliced_parameter(sliced_parameters, strategy=None):
     return merged_parameter
 
 
-def load_distributed_checkpoint(network, checkpoint_filenames, predict_strategy=None):
+def load_distributed_checkpoint(network, checkpoint_filenames, predict_strategy=None, train_strategy_filename=None):
     """
     Load checkpoint into net for distributed predication.
 
     Args:
         network (Cell): Network for distributed predication.
-        checkpoint_filenames (list(str)): The name of Checkpoint files
-            in order of rank id.
-        predict_strategy (Optional(dict)): Strategy of predication process, whose key
-            is parameter name, and value is a list or a tuple that the first four
-            elements are [dev_matrix, tensor_map, param_split_shape, field]. If None,
-            it means that the predication process just uses single device.
-            Default: None.
+        checkpoint_filenames (list[str]): The name of Checkpoint files in order of rank id.
+        predict_strategy (dict): Strategy of predication process, whose key is parameter name, and value is a list or
+            a tuple that the first four elements are [dev_matrix, tensor_map, param_split_shape, field]. If None,
+            it means that the predication process just uses single device. Default: None.
+        train_strategy_filename (str): Path of the training strategy file. If None, the file configured as
+            "strategy_ckpt_load_file" in the auto parallel context is used. Default: None.
 
     Raises:
         TypeError: The type of inputs do not match the requirements.
         ValueError: Failed to load checkpoint into net.
     """
     network = Validator.check_isinstance("network", network, nn.Cell)
+    _check_checkpoint_file(checkpoint_filenames)
+    _check_predict_strategy(predict_strategy)
 
-    for index, filename in enumerate(checkpoint_filenames):
-        if not isinstance(filename, str) or not os.path.exists(filename) \
-                or filename[-5:] != ".ckpt" or os.path.getsize(filename) == 0:
-            raise ValueError(f"Please make sure that the {filename} at index {index} is a valid checkpoint file.")
-
-    if not _check_predict_strategy(predict_strategy):
-        raise ValueError(f"Please make sure that the key of predict_strategy is str, "
-                         f"and the value is a list or a tuple that the first four elements are "
-                         f"dev_matrix (list[int]), tensor_map (list[int]), "
-                         f"param_split_shape (list[int]) and field_size (zero).")
-
-    train_strategy_filename = context.get_auto_parallel_context("strategy_ckpt_load_file")
+    if train_strategy_filename is None:
+        train_strategy_filename = context.get_auto_parallel_context("strategy_ckpt_load_file")
     _train_strategy = build_searched_strategy(train_strategy_filename)
     train_strategy = _convert_to_list(_train_strategy)
 
@@ -1107,6 +1099,12 @@ def load_distributed_checkpoint(network, checkpoint_filenames, predict_strategy=
 
     rank_list = _infer_rank_list(train_strategy, predict_strategy)
 
+    param_total_dict = defaultdict(dict)
+    for file_index, file_name in enumerate(checkpoint_filenames):
+        ckpt_dict = load_checkpoint(file_name)
+        for param_name, param in ckpt_dict.items():
+            param_total_dict[param_name][file_index] = param
+
     param_dict = {}
     for _, param in network.parameters_and_names():
         sliced_params = []
@@ -1114,8 +1112,32 @@ def load_distributed_checkpoint(network, checkpoint_filenames, predict_strategy=
             continue
         param_rank = rank_list[param.name][0]
         skip_merge_split = rank_list[param.name][1]
+        shard_stride = train_strategy[param.name][4]
+        if train_strategy[param.name][5]:
+            # Floor division: shard_size is used below as a slice bound and as a range() step.
+            shard_size = len(checkpoint_filenames) // shard_stride // train_strategy[param.name][5]
+        else:
+            shard_size = 0
         for rank in param_rank:
-            sliced_param = _load_single_param(checkpoint_filenames[rank], param.name)
+            param_total_list = list(range(0, len(checkpoint_filenames)))
+            if shard_size > 0:
+                shard_total_list = [param_total_list[i:i + shard_size] for i in
+                                    range(0, len(checkpoint_filenames), shard_size)]
+                param_total_list = shard_total_list[rank // shard_size]
+            if shard_stride > 0:
+                param_stride = []
+                # merge pre parameter
+                param_index = param_total_list[0:param_total_list.index(rank) + 1][::-1][::shard_stride]
+                param_index.extend(param_total_list[param_total_list.index(rank):][::shard_stride])
+                param_index = list(set(param_index))
+                param_index.sort()
+                for rank_num in param_index:
+                    param_stride.append(param_total_dict[param.name][rank_num].data.asnumpy())
+
+                sliced_param = Parameter(Tensor(np.concatenate(param_stride)), name=param.name)
+            else:
+                sliced_param = param_total_dict[param.name][rank]
+
             sliced_params.append(sliced_param)
         if skip_merge_split:
             split_param = sliced_params[0]
@@ -1139,19 +1161,32 @@ def _check_predict_strategy(predict_strategy):
         return True
 
     if predict_strategy is None:
-        return True
-
+        return
+    flag = True
     predict_strategy = Validator.check_isinstance("predict_strategy", predict_strategy, dict)
     for key in predict_strategy.keys():
         if not isinstance(key, str) or not isinstance(predict_strategy[key], (list, tuple)) \
                 or len(predict_strategy[key]) < 4:
-            return False
+            flag = False
+            break  # skip the unpack below, which would fail on a malformed entry
         dev_matrix, tensor_map, param_split_shape, field_size = predict_strategy[key][:4]
         if not _check_int_list(dev_matrix) or not _check_int_list(tensor_map) or \
                 not (_check_int_list(param_split_shape) or not param_split_shape) or \
                 not (isinstance(field_size, int) and field_size == 0):
-            return False
-    return True
+            flag = False
+    if not flag:
+        raise ValueError(f"Please make sure that the key of predict_strategy is str, "
+                         f"and the value is a list or a tuple that the first four elements are "
+                         f"dev_matrix (list[int]), tensor_map (list[int]), "
+                         f"param_split_shape (list[int]) and field_size (zero).")
+
+
+def _check_checkpoint_file(checkpoint_filenames):
+    """Check checkpoint file name."""
+    for index, filename in enumerate(checkpoint_filenames):
+        if not isinstance(filename, str) or not os.path.exists(filename) \
+                or filename[-5:] != ".ckpt" or os.path.getsize(filename) == 0:
+            raise ValueError(f"Please make sure that the {filename} at index {index} is a valid checkpoint file.")
 
 
 def _convert_to_list(strategy):
@@ -1207,59 +1242,3 @@ def _merge_and_split(sliced_params, train_strategy, predict_strategy):
     layerwise_parallel = merged_param.layerwise_parallel
     split_param = Parameter(split_tensor, param_name, requires_grad, layerwise_parallel)
     return split_param
-
-
-def _load_single_param(ckpt_file_name, param_name):
-    """Load a parameter from checkpoint."""
-    checkpoint_list = Checkpoint()
-
-    try:
-        with open(ckpt_file_name, "rb") as f:
-            pb_content = f.read()
-        checkpoint_list.ParseFromString(pb_content)
-    except BaseException as e:
-        logger.error("Failed to read the checkpoint file `%s` during load single parameter,"
-                     " please check the correct of the file.", ckpt_file_name)
-        raise ValueError(e.__str__())
-
-    parameter = None
-    try:
-        param_data_list = []
-        for element_id, element in enumerate(checkpoint_list.value):
-            if element.tag != param_name:
-                continue
-            data = element.tensor.tensor_content
-            data_type = element.tensor.tensor_type
-            np_type = tensor_to_np_type[data_type]
-            ms_type = tensor_to_ms_type[data_type]
-            element_data = np.frombuffer(data, np_type)
-            param_data_list.append(element_data)
-            if (element_id == len(checkpoint_list.value) - 1) or \
-                    (element.tag != checkpoint_list.value[element_id + 1].tag):
-                param_data = np.concatenate((param_data_list), axis=0)
-                param_data_list.clear()
-                dims = element.tensor.dims
-                if dims == [0]:
-                    if 'Float' in data_type:
-                        param_data = float(param_data[0])
-                    elif 'Int' in data_type:
-                        param_data = int(param_data[0])
-                    parameter = Parameter(Tensor(param_data, ms_type), name=element.tag)
-                elif dims == [1]:
-                    parameter = Parameter(Tensor(param_data, ms_type), name=element.tag)
-                else:
-                    param_dim = []
-                    for dim in dims:
-                        param_dim.append(dim)
-                    param_value = param_data.reshape(param_dim)
-                    parameter = Parameter(Tensor(param_value, ms_type), name=element.tag)
-                break
-
-    except BaseException as e:
-        logger.error("Failed to load the checkpoint file `%s`.", ckpt_file_name)
-        raise RuntimeError(e.__str__())
-
-    if parameter is None:
-        raise ValueError(f"There is no parameter named {param_name} in this checkpoint file {ckpt_file_name}, "
-                         f"please check parameter name or checkpoint file.")
-    return parameter