| @@ -19,6 +19,7 @@ import stat | |||||
| import math | import math | ||||
| import shutil | import shutil | ||||
| from threading import Thread, Lock | from threading import Thread, Lock | ||||
| from collections import defaultdict | |||||
| import numpy as np | import numpy as np | ||||
| import mindspore.nn as nn | import mindspore.nn as nn | ||||
| @@ -1062,38 +1063,27 @@ def merge_sliced_parameter(sliced_parameters, strategy=None): | |||||
| return merged_parameter | return merged_parameter | ||||
| def load_distributed_checkpoint(network, checkpoint_filenames, predict_strategy=None): | |||||
| def load_distributed_checkpoint(network, checkpoint_filenames, predict_strategy=None, train_strategy_filename=None): | |||||
| """ | """ | ||||
| Load checkpoint into net for distributed predication. | Load checkpoint into net for distributed predication. | ||||
| Args: | Args: | ||||
| network (Cell): Network for distributed predication. | network (Cell): Network for distributed predication. | ||||
| checkpoint_filenames (list(str)): The name of Checkpoint files | |||||
| in order of rank id. | |||||
| predict_strategy (Optional(dict)): Strategy of predication process, whose key | |||||
| is parameter name, and value is a list or a tuple that the first four | |||||
| elements are [dev_matrix, tensor_map, param_split_shape, field]. If None, | |||||
| it means that the predication process just uses single device. | |||||
| Default: None. | |||||
| checkpoint_filenames (list[str]): The name of Checkpoint files in order of rank id. | |||||
| predict_strategy (dict): Strategy of predication process, whose key is parameter name, and value is a list or | |||||
| a tuple that the first four elements are [dev_matrix, tensor_map, param_split_shape, field]. If None, | |||||
| it means that the predication process just uses single device. Default: None. | |||||
| Raises: | Raises: | ||||
| TypeError: The type of inputs do not match the requirements. | TypeError: The type of inputs do not match the requirements. | ||||
| ValueError: Failed to load checkpoint into net. | ValueError: Failed to load checkpoint into net. | ||||
| """ | """ | ||||
| network = Validator.check_isinstance("network", network, nn.Cell) | network = Validator.check_isinstance("network", network, nn.Cell) | ||||
| _check_checkpoint_file(checkpoint_filenames) | |||||
| _check_predict_strategy(predict_strategy) | |||||
| for index, filename in enumerate(checkpoint_filenames): | |||||
| if not isinstance(filename, str) or not os.path.exists(filename) \ | |||||
| or filename[-5:] != ".ckpt" or os.path.getsize(filename) == 0: | |||||
| raise ValueError(f"Please make sure that the {filename} at index {index} is a valid checkpoint file.") | |||||
| if not _check_predict_strategy(predict_strategy): | |||||
| raise ValueError(f"Please make sure that the key of predict_strategy is str, " | |||||
| f"and the value is a list or a tuple that the first four elements are " | |||||
| f"dev_matrix (list[int]), tensor_map (list[int]), " | |||||
| f"param_split_shape (list[int]) and field_size (zero).") | |||||
| train_strategy_filename = context.get_auto_parallel_context("strategy_ckpt_load_file") | |||||
| if train_strategy_filename is None: | |||||
| train_strategy_filename = context.get_auto_parallel_context("strategy_ckpt_load_file") | |||||
| _train_strategy = build_searched_strategy(train_strategy_filename) | _train_strategy = build_searched_strategy(train_strategy_filename) | ||||
| train_strategy = _convert_to_list(_train_strategy) | train_strategy = _convert_to_list(_train_strategy) | ||||
| @@ -1107,6 +1097,12 @@ def load_distributed_checkpoint(network, checkpoint_filenames, predict_strategy= | |||||
| rank_list = _infer_rank_list(train_strategy, predict_strategy) | rank_list = _infer_rank_list(train_strategy, predict_strategy) | ||||
| param_total_dict = defaultdict(dict) | |||||
| for file_index, file_name in enumerate(checkpoint_filenames): | |||||
| ckpt_dict = load_checkpoint(file_name) | |||||
| for param_name, param in ckpt_dict.items(): | |||||
| param_total_dict[param_name][file_index] = param | |||||
| param_dict = {} | param_dict = {} | ||||
| for _, param in network.parameters_and_names(): | for _, param in network.parameters_and_names(): | ||||
| sliced_params = [] | sliced_params = [] | ||||
| @@ -1114,8 +1110,31 @@ def load_distributed_checkpoint(network, checkpoint_filenames, predict_strategy= | |||||
| continue | continue | ||||
| param_rank = rank_list[param.name][0] | param_rank = rank_list[param.name][0] | ||||
| skip_merge_split = rank_list[param.name][1] | skip_merge_split = rank_list[param.name][1] | ||||
| shard_stride = train_strategy[param.name][4] | |||||
| if train_strategy[param.name][5]: | |||||
| shard_size = len(checkpoint_filenames) / shard_stride / train_strategy[param.name][5] | |||||
| else: | |||||
| shard_size = 0 | |||||
| for rank in param_rank: | for rank in param_rank: | ||||
| sliced_param = _load_single_param(checkpoint_filenames[rank], param.name) | |||||
| param_total_list = list(range(0, len(checkpoint_filenames))) | |||||
| if shard_size > 0: | |||||
| shard_total_list = [param_total_list[i:i + shard_size] for i in | |||||
| range(0, len(checkpoint_filenames), shard_size)] | |||||
| param_total_list = shard_total_list[rank // shard_size] | |||||
| if shard_stride > 0: | |||||
| param_stride = [] | |||||
| # merge pre parameter | |||||
| param_index = param_total_list[0:param_total_list.index(rank) + 1][::-1][::shard_stride] | |||||
| param_index.extend(param_total_list[param_total_list.index(rank):][::shard_stride]) | |||||
| param_index = list(set(param_index)) | |||||
| param_index.sort() | |||||
| for rank_num in param_index: | |||||
| param_stride.append(param_total_dict[param.name][rank_num].data.asnumpy()) | |||||
| sliced_param = Parameter(Tensor(np.concatenate(param_stride)), name=param.name) | |||||
| else: | |||||
| sliced_param = param_total_dict[param.name][rank] | |||||
| sliced_params.append(sliced_param) | sliced_params.append(sliced_param) | ||||
| if skip_merge_split: | if skip_merge_split: | ||||
| split_param = sliced_params[0] | split_param = sliced_params[0] | ||||
| @@ -1139,19 +1158,31 @@ def _check_predict_strategy(predict_strategy): | |||||
| return True | return True | ||||
| if predict_strategy is None: | if predict_strategy is None: | ||||
| return True | |||||
| return | |||||
| flag = True | |||||
| predict_strategy = Validator.check_isinstance("predict_strategy", predict_strategy, dict) | predict_strategy = Validator.check_isinstance("predict_strategy", predict_strategy, dict) | ||||
| for key in predict_strategy.keys(): | for key in predict_strategy.keys(): | ||||
| if not isinstance(key, str) or not isinstance(predict_strategy[key], (list, tuple)) \ | if not isinstance(key, str) or not isinstance(predict_strategy[key], (list, tuple)) \ | ||||
| or len(predict_strategy[key]) < 4: | or len(predict_strategy[key]) < 4: | ||||
| return False | |||||
| flag = False | |||||
| dev_matrix, tensor_map, param_split_shape, field_size = predict_strategy[key][:4] | dev_matrix, tensor_map, param_split_shape, field_size = predict_strategy[key][:4] | ||||
| if not _check_int_list(dev_matrix) or not _check_int_list(tensor_map) or \ | if not _check_int_list(dev_matrix) or not _check_int_list(tensor_map) or \ | ||||
| not (_check_int_list(param_split_shape) or not param_split_shape) or \ | not (_check_int_list(param_split_shape) or not param_split_shape) or \ | ||||
| not (isinstance(field_size, int) and field_size == 0): | not (isinstance(field_size, int) and field_size == 0): | ||||
| return False | |||||
| return True | |||||
| flag = False | |||||
| if not flag: | |||||
| raise ValueError(f"Please make sure that the key of predict_strategy is str, " | |||||
| f"and the value is a list or a tuple that the first four elements are " | |||||
| f"dev_matrix (list[int]), tensor_map (list[int]), " | |||||
| f"param_split_shape (list[int]) and field_size (zero).") | |||||
| def _check_checkpoint_file(checkpoint_filenames): | |||||
| """Check checkpoint file name.""" | |||||
| for index, filename in enumerate(checkpoint_filenames): | |||||
| if not isinstance(filename, str) or not os.path.exists(filename) \ | |||||
| or filename[-5:] != ".ckpt" or os.path.getsize(filename) == 0: | |||||
| raise ValueError(f"Please make sure that the {filename} at index {index} is a valid checkpoint file.") | |||||
| def _convert_to_list(strategy): | def _convert_to_list(strategy): | ||||
| @@ -1207,59 +1238,3 @@ def _merge_and_split(sliced_params, train_strategy, predict_strategy): | |||||
| layerwise_parallel = merged_param.layerwise_parallel | layerwise_parallel = merged_param.layerwise_parallel | ||||
| split_param = Parameter(split_tensor, param_name, requires_grad, layerwise_parallel) | split_param = Parameter(split_tensor, param_name, requires_grad, layerwise_parallel) | ||||
| return split_param | return split_param | ||||
| def _load_single_param(ckpt_file_name, param_name): | |||||
| """Load a parameter from checkpoint.""" | |||||
| checkpoint_list = Checkpoint() | |||||
| try: | |||||
| with open(ckpt_file_name, "rb") as f: | |||||
| pb_content = f.read() | |||||
| checkpoint_list.ParseFromString(pb_content) | |||||
| except BaseException as e: | |||||
| logger.error("Failed to read the checkpoint file `%s` during load single parameter," | |||||
| " please check the correct of the file.", ckpt_file_name) | |||||
| raise ValueError(e.__str__()) | |||||
| parameter = None | |||||
| try: | |||||
| param_data_list = [] | |||||
| for element_id, element in enumerate(checkpoint_list.value): | |||||
| if element.tag != param_name: | |||||
| continue | |||||
| data = element.tensor.tensor_content | |||||
| data_type = element.tensor.tensor_type | |||||
| np_type = tensor_to_np_type[data_type] | |||||
| ms_type = tensor_to_ms_type[data_type] | |||||
| element_data = np.frombuffer(data, np_type) | |||||
| param_data_list.append(element_data) | |||||
| if (element_id == len(checkpoint_list.value) - 1) or \ | |||||
| (element.tag != checkpoint_list.value[element_id + 1].tag): | |||||
| param_data = np.concatenate((param_data_list), axis=0) | |||||
| param_data_list.clear() | |||||
| dims = element.tensor.dims | |||||
| if dims == [0]: | |||||
| if 'Float' in data_type: | |||||
| param_data = float(param_data[0]) | |||||
| elif 'Int' in data_type: | |||||
| param_data = int(param_data[0]) | |||||
| parameter = Parameter(Tensor(param_data, ms_type), name=element.tag) | |||||
| elif dims == [1]: | |||||
| parameter = Parameter(Tensor(param_data, ms_type), name=element.tag) | |||||
| else: | |||||
| param_dim = [] | |||||
| for dim in dims: | |||||
| param_dim.append(dim) | |||||
| param_value = param_data.reshape(param_dim) | |||||
| parameter = Parameter(Tensor(param_value, ms_type), name=element.tag) | |||||
| break | |||||
| except BaseException as e: | |||||
| logger.error("Failed to load the checkpoint file `%s`.", ckpt_file_name) | |||||
| raise RuntimeError(e.__str__()) | |||||
| if parameter is None: | |||||
| raise ValueError(f"There is no parameter named {param_name} in this checkpoint file {ckpt_file_name}, " | |||||
| f"please check parameter name or checkpoint file.") | |||||
| return parameter | |||||