From 90cee010721274d2ec8bf7f6ebea798491bd369c Mon Sep 17 00:00:00 2001 From: caozhou Date: Thu, 12 Nov 2020 17:10:39 +0800 Subject: [PATCH] load distributed ckpt for predict --- mindspore/train/serialization.py | 208 ++++++++++++++++++++++++++++++- 1 file changed, 204 insertions(+), 4 deletions(-) diff --git a/mindspore/train/serialization.py b/mindspore/train/serialization.py index 0d872e9dbb..855b449f0b 100644 --- a/mindspore/train/serialization.py +++ b/mindspore/train/serialization.py @@ -18,12 +18,12 @@ import stat import math from threading import Thread, Lock import numpy as np - import mindspore.nn as nn +import mindspore.context as context from mindspore import log as logger from mindspore.train.checkpoint_pb2 import Checkpoint from mindspore.train.print_pb2 import Print -from mindspore.train.node_strategy_pb2 import ParallelStrategyMap +from mindspore.train.node_strategy_pb2 import ParallelStrategyMap, ParallelLayouts from mindspore.common.tensor import Tensor from mindspore.common.initializer import initializer from mindspore.common.parameter import Parameter @@ -31,7 +31,8 @@ from mindspore.common.api import _executor from mindspore.common import dtype as mstype from mindspore._checkparam import check_input_data, Validator from mindspore.compression.export import quant_export -import mindspore.context as context +from mindspore.parallel._tensor import _load_tensor +from mindspore.parallel._utils import _infer_rank_list, _remove_repeated_slices tensor_to_ms_type = {"Int8": mstype.int8, "Uint8": mstype.uint8, "Int16": mstype.int16, "Uint16": mstype.uint16, @@ -711,7 +712,7 @@ def _merge_param_with_strategy(sliced_data, parameter_name, strategy, is_even): param_split_shape = list(layout.param_split_shape[0].dim) field_size = int(layout.field) except BaseException as e: - raise ValueError(f"{e.__str__()}. please make sure that strategy matches the node_strategy.proto.") + raise ValueError(f"{e.__str__()}. 
def load_distributed_checkpoint(network, checkpoint_filenames, predict_strategy=None):
    """
    Load checkpoint into net for distributed prediction.

    Args:
        network (Cell): Network for distributed prediction.
        checkpoint_filenames (list[str]): The name of Checkpoint files in order of rank id.
        predict_strategy (dict): Strategy of prediction process, whose key is parameter name, and
            value is a list that the first four elements are [dev_matrix, tensor_map,
            param_split_shape, field]. If None, it means that the prediction process just uses
            single device. Default: None.

    Raises:
        TypeError: The type of inputs do not match the requirements.
        ValueError: Failed to load checkpoint into net.
    """
    network = Validator.check_isinstance("network", network, nn.Cell)

    # Every checkpoint file must be an existing, non-empty ".ckpt" file.
    for index, filename in enumerate(checkpoint_filenames):
        if not isinstance(filename, str) or not os.path.exists(filename) \
                or not filename.endswith(".ckpt") or os.path.getsize(filename) == 0:
            # FIX: the message previously contained a garbled "(unknown)" placeholder;
            # report the offending file name so the error is actionable.
            raise ValueError(f"Please make sure that the {filename} at index {index} "
                             f"is a valid checkpoint file.")

    if not _check_predict_strategy(predict_strategy):
        raise ValueError(f"Please make sure that the key of predict_strategy is str, "
                         f"and the value is a list that the first four elements are "
                         f"dev_matrix (list[int]), tensor_map (list[int]), "
                         f"param_split_shape (list[int]) and field_size (zero).")

    train_strategy_filename = context.get_auto_parallel_context("strategy_ckpt_load_file")
    _train_strategy = build_searched_strategy(train_strategy_filename)
    train_strategy = _convert_to_list(_train_strategy)

    # The training device count is the product of the device-matrix dims; it must match
    # the number of checkpoint files (one file per training rank).
    train_dev_count = 1
    for dim in train_strategy[list(train_strategy.keys())[0]][0]:
        train_dev_count *= dim
    if train_dev_count != len(checkpoint_filenames):
        raise ValueError(
            f"The length of checkpoint_filenames should be equal to the device count of training process. "
            f"The length is {len(checkpoint_filenames)} but the device count is {train_dev_count}.")

    # Map each parameter to the list of training ranks whose slices are needed.
    rank_list = _infer_rank_list(train_strategy, predict_strategy)

    param_dict = {}
    for _, param in network.parameters_and_names():
        if param.name not in rank_list.keys():
            continue
        # Gather this parameter's slices from every rank that holds one.
        sliced_params = []
        for rank in rank_list[param.name]:
            sliced_param = _load_single_param(checkpoint_filenames[rank], param.name)
            sliced_params.append(sliced_param)
        if len(sliced_params) == 1:
            # Single slice: nothing to merge.
            split_param = sliced_params[0]
        else:
            # Drop repeated slices before merging, then re-split for the predict layout.
            param_unique_strategy = _remove_repeated_slices(train_strategy[param.name])
            _param_unique_strategy = _convert_to_layout(param.name, param_unique_strategy)
            split_param = _merge_and_split(sliced_params, _param_unique_strategy, predict_strategy)
        param_dict[param.name] = split_param

    load_param_into_net(network, param_dict)


def _check_predict_strategy(predict_strategy):
    """Check that predict_strategy is None or a dict of str -> [dev_matrix, tensor_map,
    param_split_shape, field_size] with the expected element types."""
    def _check_int_list(arg):
        # True only for a list whose elements are all ints (an empty list passes).
        if not isinstance(arg, list):
            return False
        for item in arg:
            if not isinstance(item, int):
                return False
        return True

    # None means single-device prediction; nothing to validate.
    if predict_strategy is None:
        return True

    predict_strategy = Validator.check_isinstance("predict_strategy", predict_strategy, dict)
    for key in predict_strategy.keys():
        if not isinstance(key, str) or not isinstance(predict_strategy[key], list) or len(predict_strategy[key]) < 4:
            return False
        dev_matrix, tensor_map, param_split_shape, field_size = predict_strategy[key][:4]
        if not _check_int_list(dev_matrix) or not _check_int_list(tensor_map) or \
                not (_check_int_list(param_split_shape) or not param_split_shape) or \
                not (isinstance(field_size, int) and field_size == 0):
            return False
    return True


def _convert_to_list(strategy):
    """Convert each ParallelLayouts proto in `strategy` to a plain list
    [dev_mat, tensor_map, param_split_shape, field_size], keyed by parameter name."""
    train_map = {}
    for param_name in strategy.keys():
        try:
            layout = strategy.get(param_name)
            dev_mat = list(layout.dev_matrix[0].dim)
            tensor_map = list(layout.tensor_map[0].dim)
            param_split_shape = list(layout.param_split_shape[0].dim)
            field_size = int(layout.field)
            train_map[param_name] = [dev_mat, tensor_map, param_split_shape, field_size]
        except BaseException as e:
            # A malformed proto (missing repeated fields) lands here.
            raise ValueError(f"{e.__str__()}. Please make sure that strategy matches the node_strategy.proto.")
    return train_map


def _convert_to_layout(param_name, tensor_layout):
    """Convert a [dev_mat, tensor_map, param_split_shape, field] list back into a
    {param_name: ParallelLayouts} dict (the inverse of _convert_to_list)."""
    strategy = {}
    try:
        layout = ParallelLayouts()
        layout.field = tensor_layout[3]

        dev_matrix = layout.dev_matrix.add()
        for item in tensor_layout[0]:
            dev_matrix.dim.append(item)

        tensor_map = layout.tensor_map.add()
        for item in tensor_layout[1]:
            tensor_map.dim.append(item)

        param_split_shape = layout.param_split_shape.add()
        for item in tensor_layout[2]:
            param_split_shape.dim.append(item)
    except BaseException as e:
        raise ValueError("Convert failed. " + e.__str__())

    strategy[param_name] = layout
    return strategy


def _merge_and_split(sliced_params, train_strategy, predict_strategy):
    """Merge sliced parameter and split it according to the predict strategy.

    Returns the merged Parameter unchanged when predict_strategy is None
    (single-device prediction needs the full tensor)."""
    merged_param = merge_sliced_parameter(sliced_params, train_strategy)
    if predict_strategy is None:
        return merged_param
    param_name = merged_param.name
    tensor_layout = predict_strategy[param_name]
    # Re-slice the merged tensor for this device using [dev_matrix, tensor_map].
    split_tensor = _load_tensor(merged_param.data, tensor_layout[0], tensor_layout[1])
    requires_grad = merged_param.requires_grad
    layerwise_parallel = merged_param.layerwise_parallel
    split_param = Parameter(split_tensor, param_name, requires_grad, layerwise_parallel)
    return split_param


def _load_single_param(ckpt_file_name, param_name):
    """Load one named parameter from a checkpoint file.

    Raises ValueError when the file is unreadable or the parameter is absent,
    RuntimeError when parsing the tensor data fails."""
    logger.info("Execute the process of loading checkpoint files.")
    checkpoint_list = Checkpoint()

    try:
        with open(ckpt_file_name, "rb") as f:
            pb_content = f.read()
        checkpoint_list.ParseFromString(pb_content)
    except BaseException as e:
        logger.error("Failed to read the checkpoint file `%s`, please check the correct of the file.",
                     ckpt_file_name)
        raise ValueError(e.__str__())

    parameter = None
    try:
        param_data_list = []
        for element_id, element in enumerate(checkpoint_list.value):
            if element.tag != param_name:
                continue
            data = element.tensor.tensor_content
            data_type = element.tensor.tensor_type
            np_type = tensor_to_np_type[data_type]
            ms_type = tensor_to_ms_type[data_type]
            element_data = np.frombuffer(data, np_type)
            param_data_list.append(element_data)
            # A large tensor may be stored as several consecutive elements with the
            # same tag; concatenate once the last chunk of this tag is reached.
            if (element_id == len(checkpoint_list.value) - 1) or \
                    (element.tag != checkpoint_list.value[element_id + 1].tag):
                param_data = np.concatenate(param_data_list, axis=0)
                param_data_list.clear()
                dims = element.tensor.dims
                if dims == [0]:
                    # Scalar stored with dims [0]: unwrap to a Python scalar.
                    if 'Float' in data_type:
                        param_data = float(param_data[0])
                    elif 'Int' in data_type:
                        param_data = int(param_data[0])
                    parameter = Parameter(Tensor(param_data, ms_type), name=element.tag)
                elif dims == [1]:
                    parameter = Parameter(Tensor(param_data, ms_type), name=element.tag)
                else:
                    param_value = param_data.reshape(list(dims))
                    parameter = Parameter(Tensor(param_value, ms_type), name=element.tag)
                break
        logger.info("Loading checkpoint files process is finished.")

    except BaseException as e:
        logger.error("Failed to load the checkpoint file `%s`.", ckpt_file_name)
        raise RuntimeError(e.__str__())

    if parameter is None:
        raise ValueError(f"There is no parameter named {param_name} in this checkpoint file {ckpt_file_name}, "
                         f"please check parameter name or checkpoint file.")
    return parameter