| @@ -19,6 +19,7 @@ import stat | |||
| import math | |||
| import shutil | |||
| from threading import Thread, Lock | |||
| from collections import defaultdict | |||
| import numpy as np | |||
| import mindspore.nn as nn | |||
| @@ -1062,38 +1063,27 @@ def merge_sliced_parameter(sliced_parameters, strategy=None): | |||
| return merged_parameter | |||
| def load_distributed_checkpoint(network, checkpoint_filenames, predict_strategy=None): | |||
| def load_distributed_checkpoint(network, checkpoint_filenames, predict_strategy=None, train_strategy_filename=None): | |||
| """ | |||
| Load checkpoint into net for distributed predication. | |||
| Args: | |||
| network (Cell): Network for distributed predication. | |||
| checkpoint_filenames (list(str)): The name of Checkpoint files | |||
| in order of rank id. | |||
| predict_strategy (Optional(dict)): Strategy of predication process, whose key | |||
| is parameter name, and value is a list or a tuple that the first four | |||
| elements are [dev_matrix, tensor_map, param_split_shape, field]. If None, | |||
| it means that the predication process just uses single device. | |||
| Default: None. | |||
| checkpoint_filenames (list[str]): The name of Checkpoint files in order of rank id. | |||
| predict_strategy (dict): Strategy of predication process, whose key is parameter name, and value is a list or | |||
| a tuple that the first four elements are [dev_matrix, tensor_map, param_split_shape, field]. If None, | |||
| it means that the predication process just uses single device. Default: None. | |||
| Raises: | |||
| TypeError: The type of inputs do not match the requirements. | |||
| ValueError: Failed to load checkpoint into net. | |||
| """ | |||
| network = Validator.check_isinstance("network", network, nn.Cell) | |||
| _check_checkpoint_file(checkpoint_filenames) | |||
| _check_predict_strategy(predict_strategy) | |||
| for index, filename in enumerate(checkpoint_filenames): | |||
| if not isinstance(filename, str) or not os.path.exists(filename) \ | |||
| or filename[-5:] != ".ckpt" or os.path.getsize(filename) == 0: | |||
| raise ValueError(f"Please make sure that the {filename} at index {index} is a valid checkpoint file.") | |||
| if not _check_predict_strategy(predict_strategy): | |||
| raise ValueError(f"Please make sure that the key of predict_strategy is str, " | |||
| f"and the value is a list or a tuple that the first four elements are " | |||
| f"dev_matrix (list[int]), tensor_map (list[int]), " | |||
| f"param_split_shape (list[int]) and field_size (zero).") | |||
| train_strategy_filename = context.get_auto_parallel_context("strategy_ckpt_load_file") | |||
| if train_strategy_filename is None: | |||
| train_strategy_filename = context.get_auto_parallel_context("strategy_ckpt_load_file") | |||
| _train_strategy = build_searched_strategy(train_strategy_filename) | |||
| train_strategy = _convert_to_list(_train_strategy) | |||
| @@ -1107,6 +1097,12 @@ def load_distributed_checkpoint(network, checkpoint_filenames, predict_strategy= | |||
| rank_list = _infer_rank_list(train_strategy, predict_strategy) | |||
| param_total_dict = defaultdict(dict) | |||
| for file_index, file_name in enumerate(checkpoint_filenames): | |||
| ckpt_dict = load_checkpoint(file_name) | |||
| for param_name, param in ckpt_dict.items(): | |||
| param_total_dict[param_name][file_index] = param | |||
| param_dict = {} | |||
| for _, param in network.parameters_and_names(): | |||
| sliced_params = [] | |||
| @@ -1114,8 +1110,31 @@ def load_distributed_checkpoint(network, checkpoint_filenames, predict_strategy= | |||
| continue | |||
| param_rank = rank_list[param.name][0] | |||
| skip_merge_split = rank_list[param.name][1] | |||
| shard_stride = train_strategy[param.name][4] | |||
| if train_strategy[param.name][5]: | |||
| shard_size = len(checkpoint_filenames) / shard_stride / train_strategy[param.name][5] | |||
| else: | |||
| shard_size = 0 | |||
| for rank in param_rank: | |||
| sliced_param = _load_single_param(checkpoint_filenames[rank], param.name) | |||
| param_total_list = list(range(0, len(checkpoint_filenames))) | |||
| if shard_size > 0: | |||
| shard_total_list = [param_total_list[i:i + shard_size] for i in | |||
| range(0, len(checkpoint_filenames), shard_size)] | |||
| param_total_list = shard_total_list[rank // shard_size] | |||
| if shard_stride > 0: | |||
| param_stride = [] | |||
| # merge pre parameter | |||
| param_index = param_total_list[0:param_total_list.index(rank) + 1][::-1][::shard_stride] | |||
| param_index.extend(param_total_list[param_total_list.index(rank):][::shard_stride]) | |||
| param_index = list(set(param_index)) | |||
| param_index.sort() | |||
| for rank_num in param_index: | |||
| param_stride.append(param_total_dict[param.name][rank_num].data.asnumpy()) | |||
| sliced_param = Parameter(Tensor(np.concatenate(param_stride)), name=param.name) | |||
| else: | |||
| sliced_param = param_total_dict[param.name][rank] | |||
| sliced_params.append(sliced_param) | |||
| if skip_merge_split: | |||
| split_param = sliced_params[0] | |||
| @@ -1139,19 +1158,31 @@ def _check_predict_strategy(predict_strategy): | |||
| return True | |||
| if predict_strategy is None: | |||
| return True | |||
| return | |||
| flag = True | |||
| predict_strategy = Validator.check_isinstance("predict_strategy", predict_strategy, dict) | |||
| for key in predict_strategy.keys(): | |||
| if not isinstance(key, str) or not isinstance(predict_strategy[key], (list, tuple)) \ | |||
| or len(predict_strategy[key]) < 4: | |||
| return False | |||
| flag = False | |||
| dev_matrix, tensor_map, param_split_shape, field_size = predict_strategy[key][:4] | |||
| if not _check_int_list(dev_matrix) or not _check_int_list(tensor_map) or \ | |||
| not (_check_int_list(param_split_shape) or not param_split_shape) or \ | |||
| not (isinstance(field_size, int) and field_size == 0): | |||
| return False | |||
| return True | |||
| flag = False | |||
| if not flag: | |||
| raise ValueError(f"Please make sure that the key of predict_strategy is str, " | |||
| f"and the value is a list or a tuple that the first four elements are " | |||
| f"dev_matrix (list[int]), tensor_map (list[int]), " | |||
| f"param_split_shape (list[int]) and field_size (zero).") | |||
| def _check_checkpoint_file(checkpoint_filenames): | |||
| """Check checkpoint file name.""" | |||
| for index, filename in enumerate(checkpoint_filenames): | |||
| if not isinstance(filename, str) or not os.path.exists(filename) \ | |||
| or filename[-5:] != ".ckpt" or os.path.getsize(filename) == 0: | |||
| raise ValueError(f"Please make sure that the {filename} at index {index} is a valid checkpoint file.") | |||
| def _convert_to_list(strategy): | |||
| @@ -1207,59 +1238,3 @@ def _merge_and_split(sliced_params, train_strategy, predict_strategy): | |||
| layerwise_parallel = merged_param.layerwise_parallel | |||
| split_param = Parameter(split_tensor, param_name, requires_grad, layerwise_parallel) | |||
| return split_param | |||
| def _load_single_param(ckpt_file_name, param_name): | |||
| """Load a parameter from checkpoint.""" | |||
| checkpoint_list = Checkpoint() | |||
| try: | |||
| with open(ckpt_file_name, "rb") as f: | |||
| pb_content = f.read() | |||
| checkpoint_list.ParseFromString(pb_content) | |||
| except BaseException as e: | |||
| logger.error("Failed to read the checkpoint file `%s` during load single parameter," | |||
| " please check the correct of the file.", ckpt_file_name) | |||
| raise ValueError(e.__str__()) | |||
| parameter = None | |||
| try: | |||
| param_data_list = [] | |||
| for element_id, element in enumerate(checkpoint_list.value): | |||
| if element.tag != param_name: | |||
| continue | |||
| data = element.tensor.tensor_content | |||
| data_type = element.tensor.tensor_type | |||
| np_type = tensor_to_np_type[data_type] | |||
| ms_type = tensor_to_ms_type[data_type] | |||
| element_data = np.frombuffer(data, np_type) | |||
| param_data_list.append(element_data) | |||
| if (element_id == len(checkpoint_list.value) - 1) or \ | |||
| (element.tag != checkpoint_list.value[element_id + 1].tag): | |||
| param_data = np.concatenate((param_data_list), axis=0) | |||
| param_data_list.clear() | |||
| dims = element.tensor.dims | |||
| if dims == [0]: | |||
| if 'Float' in data_type: | |||
| param_data = float(param_data[0]) | |||
| elif 'Int' in data_type: | |||
| param_data = int(param_data[0]) | |||
| parameter = Parameter(Tensor(param_data, ms_type), name=element.tag) | |||
| elif dims == [1]: | |||
| parameter = Parameter(Tensor(param_data, ms_type), name=element.tag) | |||
| else: | |||
| param_dim = [] | |||
| for dim in dims: | |||
| param_dim.append(dim) | |||
| param_value = param_data.reshape(param_dim) | |||
| parameter = Parameter(Tensor(param_value, ms_type), name=element.tag) | |||
| break | |||
| except BaseException as e: | |||
| logger.error("Failed to load the checkpoint file `%s`.", ckpt_file_name) | |||
| raise RuntimeError(e.__str__()) | |||
| if parameter is None: | |||
| raise ValueError(f"There is no parameter named {param_name} in this checkpoint file {ckpt_file_name}, " | |||
| f"please check parameter name or checkpoint file.") | |||
| return parameter | |||