| @@ -18,12 +18,12 @@ import stat | |||||
| import math | import math | ||||
| from threading import Thread, Lock | from threading import Thread, Lock | ||||
| import numpy as np | import numpy as np | ||||
| import mindspore.nn as nn | import mindspore.nn as nn | ||||
| import mindspore.context as context | |||||
| from mindspore import log as logger | from mindspore import log as logger | ||||
| from mindspore.train.checkpoint_pb2 import Checkpoint | from mindspore.train.checkpoint_pb2 import Checkpoint | ||||
| from mindspore.train.print_pb2 import Print | from mindspore.train.print_pb2 import Print | ||||
| from mindspore.train.node_strategy_pb2 import ParallelStrategyMap | |||||
| from mindspore.train.node_strategy_pb2 import ParallelStrategyMap, ParallelLayouts | |||||
| from mindspore.common.tensor import Tensor | from mindspore.common.tensor import Tensor | ||||
| from mindspore.common.initializer import initializer | from mindspore.common.initializer import initializer | ||||
| from mindspore.common.parameter import Parameter | from mindspore.common.parameter import Parameter | ||||
| @@ -31,7 +31,8 @@ from mindspore.common.api import _executor | |||||
| from mindspore.common import dtype as mstype | from mindspore.common import dtype as mstype | ||||
| from mindspore._checkparam import check_input_data, Validator | from mindspore._checkparam import check_input_data, Validator | ||||
| from mindspore.compression.export import quant_export | from mindspore.compression.export import quant_export | ||||
| import mindspore.context as context | |||||
| from mindspore.parallel._tensor import _load_tensor | |||||
| from mindspore.parallel._utils import _infer_rank_list, _remove_repeated_slices | |||||
| tensor_to_ms_type = {"Int8": mstype.int8, "Uint8": mstype.uint8, "Int16": mstype.int16, "Uint16": mstype.uint16, | tensor_to_ms_type = {"Int8": mstype.int8, "Uint8": mstype.uint8, "Int16": mstype.int16, "Uint16": mstype.uint16, | ||||
| @@ -711,7 +712,7 @@ def _merge_param_with_strategy(sliced_data, parameter_name, strategy, is_even): | |||||
| param_split_shape = list(layout.param_split_shape[0].dim) | param_split_shape = list(layout.param_split_shape[0].dim) | ||||
| field_size = int(layout.field) | field_size = int(layout.field) | ||||
| except BaseException as e: | except BaseException as e: | ||||
| raise ValueError(f"{e.__str__()}. please make sure that strategy matches the node_strategy.proto.") | |||||
| raise ValueError(f"{e.__str__()}. Please make sure that strategy matches the node_strategy.proto.") | |||||
| device_count = 1 | device_count = 1 | ||||
| for dim in dev_mat: | for dim in dev_mat: | ||||
| @@ -897,3 +898,202 @@ def merge_sliced_parameter(sliced_parameters, strategy=None): | |||||
| merged_parameter = Parameter(merged_tensor, parameter_name, requires_grad, layerwise_parallel) | merged_parameter = Parameter(merged_tensor, parameter_name, requires_grad, layerwise_parallel) | ||||
| return merged_parameter | return merged_parameter | ||||
def load_distributed_checkpoint(network, checkpoint_filenames, predict_strategy=None):
    """
    Load checkpoint into net for distributed prediction.

    Args:
        network (Cell): Network for distributed prediction.
        checkpoint_filenames (list[str]): The name of Checkpoint files in order of rank id.
        predict_strategy (dict): Strategy of prediction process, whose key is parameter name, and
            value is a list that the first four elements are [dev_matrix, tensor_map,
            param_split_shape, field]. If None, it means that the prediction process just uses
            a single device. Default: None.

    Raises:
        TypeError: The type of inputs do not match the requirements.
        ValueError: Failed to load checkpoint into net.
    """
    network = Validator.check_isinstance("network", network, nn.Cell)

    # Every checkpoint file must be an existing, non-empty .ckpt file.
    for index, filename in enumerate(checkpoint_filenames):
        if not isinstance(filename, str) or not os.path.exists(filename) \
                or not filename.endswith(".ckpt") or os.path.getsize(filename) == 0:
            raise ValueError(f"Please make sure that the {filename} at index {index} is a valid checkpoint file.")

    if not _check_predict_strategy(predict_strategy):
        raise ValueError(f"Please make sure that the key of predict_strategy is str, "
                         f"and the value is a list that the first four elements are "
                         f"dev_matrix (list[int]), tensor_map (list[int]), "
                         f"param_split_shape (list[int]) and field_size (zero).")

    train_strategy_filename = context.get_auto_parallel_context("strategy_ckpt_load_file")
    _train_strategy = build_searched_strategy(train_strategy_filename)
    train_strategy = _convert_to_list(_train_strategy)

    # The product of the first parameter's device matrix gives the training device count;
    # one checkpoint file is expected per training rank.
    train_dev_count = 1
    for dim in train_strategy[list(train_strategy.keys())[0]][0]:
        train_dev_count *= dim
    if train_dev_count != len(checkpoint_filenames):
        raise ValueError(
            f"The length of checkpoint_filenames should be equal to the device count of training process. "
            f"The length is {len(checkpoint_filenames)} but the device count is {train_dev_count}.")

    rank_list = _infer_rank_list(train_strategy, predict_strategy)

    param_dict = {}
    for _, param in network.parameters_and_names():
        sliced_params = []
        # Only parameters the inferred rank list covers can be restored.
        if param.name not in rank_list:
            continue
        param_rank = rank_list[param.name]
        # Collect the slice of this parameter saved by each relevant training rank.
        for rank in param_rank:
            sliced_param = _load_single_param(checkpoint_filenames[rank], param.name)
            sliced_params.append(sliced_param)
        if len(sliced_params) == 1:
            split_param = sliced_params[0]
        else:
            # Drop repeated slices, then merge the rest and re-split them
            # according to the predict strategy.
            param_unique_strategy = _remove_repeated_slices(train_strategy[param.name])
            _param_unique_strategy = _convert_to_layout(param.name, param_unique_strategy)
            split_param = _merge_and_split(sliced_params, _param_unique_strategy, predict_strategy)
        param_dict[param.name] = split_param
    load_param_into_net(network, param_dict)
| def _check_predict_strategy(predict_strategy): | |||||
| """Check predict strategy.""" | |||||
| def _check_int_list(arg): | |||||
| if not isinstance(arg, list): | |||||
| return False | |||||
| for item in arg: | |||||
| if not isinstance(item, int): | |||||
| return False | |||||
| return True | |||||
| if predict_strategy is None: | |||||
| return True | |||||
| predict_strategy = Validator.check_isinstance("predict_strategy", predict_strategy, dict) | |||||
| for key in predict_strategy.keys(): | |||||
| if not isinstance(key, str) or not isinstance(predict_strategy[key], list) or len(predict_strategy[key]) < 4: | |||||
| return False | |||||
| dev_matrix, tensor_map, param_split_shape, field_size = predict_strategy[key][:4] | |||||
| if not _check_int_list(dev_matrix) or not _check_int_list(tensor_map) or \ | |||||
| not (_check_int_list(param_split_shape) or not param_split_shape) or \ | |||||
| not (isinstance(field_size, int) and field_size == 0): | |||||
| return False | |||||
| return True | |||||
| def _convert_to_list(strategy): | |||||
| """Convert ParallelLayouts object to specified list.""" | |||||
| train_map = {} | |||||
| for param_name in strategy.keys(): | |||||
| try: | |||||
| layout = strategy.get(param_name) | |||||
| dev_mat = list(layout.dev_matrix[0].dim) | |||||
| tensor_map = list(layout.tensor_map[0].dim) | |||||
| param_split_shape = list(layout.param_split_shape[0].dim) | |||||
| field_size = int(layout.field) | |||||
| train_map[param_name] = [dev_mat, tensor_map, param_split_shape, field_size] | |||||
| except BaseException as e: | |||||
| raise ValueError(f"{e.__str__()}. Please make sure that strategy matches the node_strategy.proto.") | |||||
| return train_map | |||||
def _convert_to_layout(param_name, tensor_layout):
    """Build a {param_name: ParallelLayouts} mapping from a flattened strategy list."""
    try:
        layout = ParallelLayouts()
        layout.field = tensor_layout[3]
        # Fill each repeated sub-message (dev_matrix, tensor_map, param_split_shape)
        # from the corresponding entry of the flattened list.
        for target, values in ((layout.dev_matrix, tensor_layout[0]),
                               (layout.tensor_map, tensor_layout[1]),
                               (layout.param_split_shape, tensor_layout[2])):
            message = target.add()
            for value in values:
                message.dim.append(value)
    except BaseException as e:
        raise ValueError("Convert failed. " + e.__str__())
    return {param_name: layout}
def _merge_and_split(sliced_params, train_strategy, predict_strategy):
    """Merge sliced parameter and split it according to the predict strategy."""
    merged = merge_sliced_parameter(sliced_params, train_strategy)
    if predict_strategy is None:
        # Single-device prediction: use the fully merged parameter as-is.
        return merged
    # Re-slice the merged tensor for this device; the predict layout's first two
    # entries are the device matrix and the tensor map.
    layout = predict_strategy[merged.name]
    sliced_tensor = _load_tensor(merged.data, layout[0], layout[1])
    return Parameter(sliced_tensor, merged.name, merged.requires_grad, merged.layerwise_parallel)
def _load_single_param(ckpt_file_name, param_name):
    """
    Load one named parameter from a protobuf checkpoint file.

    Args:
        ckpt_file_name (str): Path of the checkpoint (.ckpt) file.
        param_name (str): Name (tag) of the parameter to extract.

    Returns:
        Parameter, the parameter rebuilt from the checkpoint contents.

    Raises:
        ValueError: The file cannot be read or parsed, or no element with
            `param_name` exists in the checkpoint.
        RuntimeError: The matching element's tensor data cannot be decoded.
    """
    logger.info("Execute the process of loading checkpoint files.")
    checkpoint_list = Checkpoint()

    # Read the whole file and parse it as a Checkpoint protobuf message.
    try:
        with open(ckpt_file_name, "rb") as f:
            pb_content = f.read()
        checkpoint_list.ParseFromString(pb_content)
    except BaseException as e:
        logger.error("Failed to read the checkpoint file `%s`, please check the correct of the file.", ckpt_file_name)
        raise ValueError(e.__str__())

    parameter = None
    try:
        param_data_list = []
        for element_id, element in enumerate(checkpoint_list.value):
            # Skip every element that does not belong to the requested parameter.
            if element.tag != param_name:
                continue
            data = element.tensor.tensor_content
            data_type = element.tensor.tensor_type
            np_type = tensor_to_np_type[data_type]
            ms_type = tensor_to_ms_type[data_type]
            element_data = np.frombuffer(data, np_type)
            param_data_list.append(element_data)
            # A large parameter may be stored as several consecutive elements with
            # the same tag; keep accumulating until the next element has a
            # different tag or the element list ends, then build the Parameter.
            if (element_id == len(checkpoint_list.value) - 1) or \
                    (element.tag != checkpoint_list.value[element_id + 1].tag):
                param_data = np.concatenate((param_data_list), axis=0)
                param_data_list.clear()
                dims = element.tensor.dims
                if dims == [0]:
                    # dims [0] marks a scalar: unwrap the single element as a
                    # Python float/int before wrapping it in a Tensor.
                    if 'Float' in data_type:
                        param_data = float(param_data[0])
                    elif 'Int' in data_type:
                        param_data = int(param_data[0])
                    parameter = Parameter(Tensor(param_data, ms_type), name=element.tag)
                elif dims == [1]:
                    # 1-D data: the flat buffer is already the right shape.
                    parameter = Parameter(Tensor(param_data, ms_type), name=element.tag)
                else:
                    # Multi-dimensional data: reshape the flat buffer to `dims`.
                    param_dim = []
                    for dim in dims:
                        param_dim.append(dim)
                    param_value = param_data.reshape(param_dim)
                    parameter = Parameter(Tensor(param_value, ms_type), name=element.tag)
                # The requested parameter is complete; no need to scan further.
                break
        logger.info("Loading checkpoint files process is finished.")
    except BaseException as e:
        logger.error("Failed to load the checkpoint file `%s`.", ckpt_file_name)
        raise RuntimeError(e.__str__())

    if parameter is None:
        raise ValueError(f"There is no parameter named {param_name} in this checkpoint file {ckpt_file_name}, "
                         f"please check parameter name or checkpoint file.")
    return parameter