|
- # Copyright 2020 Huawei Technologies Co., Ltd
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ============================================================================
- """Model and parameters serialization."""
- import os
- import stat
- import numpy as np
-
- import mindspore.nn as nn
- import mindspore.context as context
- from mindspore import log as logger
- from mindspore.train.checkpoint_pb2 import Checkpoint
- from mindspore.common.tensor import Tensor
- from mindspore.common.initializer import initializer
- from mindspore.common.parameter import Parameter
- from mindspore.common.api import _executor
- from mindspore.common import dtype as mstype
- from mindspore._checkparam import check_input_data
-
# Names that make up the public interface of this module.
__all__ = ["save_checkpoint", "load_checkpoint", "load_param_into_net", "export"]

# Checkpoint dtype tag -> MindSpore dtype used when rebuilding a Tensor.
tensor_to_ms_type = {
    "Int8": mstype.int8,
    "Int16": mstype.int16,
    "Int32": mstype.int32,
    "Int64": mstype.int64,
    "Float16": mstype.float16,
    "Float32": mstype.float32,
    "Float64": mstype.float64,
}

# Checkpoint dtype tag -> NumPy dtype used when decoding the raw tensor bytes.
tensor_to_np_type = {
    "Int8": np.int8,
    "Int16": np.int16,
    "Int32": np.int32,
    "Int64": np.int64,
    "Float16": np.float16,
    "Float32": np.float32,
    "Float64": np.float64,
}
-
def _special_process_par(par, new_par):
    """
    Processes the special condition.

    Like (12,2048,1,1)->(12,2048), this case is caused by GE 4 dimensions tensor.
    When `new_par` has more dimensions than `par` and every extra trailing
    dimension is 1, the data of `new_par` is reshaped to the shape of `par`
    and written into `par`.

    Args:
        par (Parameter): Parameter of the net, updated in place on success.
        new_par (Parameter): Parameter holding the data to be loaded.

    Returns:
        bool, True if `par` was updated with the reshaped data, False if the
        shapes do not match the special condition.
    """
    par_shape = par.data.shape()
    new_par_shape = new_par.data.shape()
    par_shape_len = len(par_shape)

    # Only a parameter with additional trailing dimensions can qualify.
    if len(new_par_shape) <= par_shape_len:
        return False

    # Every extra dimension has to be 1. This must also reject the case where
    # only the very last extra dimension differs from 1, otherwise the reshape
    # below is attempted on data that cannot fit.
    if any(dim != 1 for dim in new_par_shape[par_shape_len:]):
        return False

    new_val = new_par.data.asnumpy().reshape(par_shape)
    par.set_parameter_data(Tensor(new_val, par.data.dtype()))
    return True
-
-
def _update_param(param, new_param):
    """
    Updates param's data from new_param's data.

    Args:
        param (Parameter): Parameter of the net, updated in place.
        new_param (Parameter): Parameter that provides the new data.

    Raises:
        RuntimeError: The dtype or shape of the two parameters cannot be
            reconciled, or `new_param` holds a Tensor while `param` does not.
    """
    net_is_tensor = isinstance(param.data, Tensor)
    new_is_tensor = isinstance(new_param.data, Tensor)

    # Tensor <- Tensor: dtype must match, shape must match or be reducible.
    if net_is_tensor and new_is_tensor:
        if param.data.dtype() != new_param.data.dtype():
            logger.error("Failed to combine the net and the parameters for param %s.", param.name)
            raise RuntimeError("Net parameters {} type({}) different from parameter_dict's({})"
                               .format(param.name, param.data.dtype(), new_param.data.dtype()))

        if param.data.shape() == new_param.data.shape():
            param.set_parameter_data(new_param.data)
            return

        # Shapes differ: the helper updates `param` itself when it succeeds.
        if _special_process_par(param, new_param):
            return

        logger.error("Failed to combine the net and the parameters for param %s.", param.name)
        raise RuntimeError("Net parameters {} shape({}) different from parameter_dict's({})"
                           .format(param.name, param.data.shape(), new_param.data.shape()))

    # Tensor <- non-Tensor: only accepted for shape (1,) or ().
    if net_is_tensor:
        if param.data.shape() not in ((1,), ()):
            logger.error("Failed to combine the net and the parameters for param %s.", param.name)
            raise RuntimeError("Net parameters {} shape({}) is not (1,), inconsitent with parameter_dict's(scalar)."
                               .format(param.name, param.data.shape()))
        param.set_parameter_data(initializer(new_param.data, param.data.shape(), param.data.dtype()))
        return

    # non-Tensor <- Tensor: never allowed.
    if new_is_tensor:
        logger.error("Failed to combine the net and the parameters for param %s.", param.name)
        raise RuntimeError("Net parameters {} type({}) different from parameter_dict's({})"
                           .format(param.name, type(param.data), type(new_param.data)))

    # non-Tensor <- non-Tensor: convert the new value to the current python type.
    param.set_parameter_data(type(param.data)(new_param.data))
-
-
def save_checkpoint(parameter_list, ckpoint_file_name):
    """
    Saves checkpoint info to a specified file.

    Each parameter is stored with its name, its dtype (as a string), its
    flattened raw bytes and its dimensions. A parameter whose shape is `()`
    is stored with the single dimension 0. After writing, the file permission
    is set to owner-read-only.

    Args:
        parameter_list (list): Parameters list, each element is a dict
                               like {"name":xx, "type":xx, "shape":xx, "data":xx}.
        ckpoint_file_name (str): Checkpoint file name.

    Raises:
        RuntimeError: Failed to save the Checkpoint file.
    """
    logger.info("Execute save checkpoint process.")
    checkpoint_list = Checkpoint()

    try:
        for param in parameter_list:
            param_value = checkpoint_list.value.add()
            param_value.tag = param["name"]
            param_tensor = param_value.tensor
            tensor = param["data"]
            # tobytes() replaces the deprecated ndarray.tostring() alias.
            param_tensor.tensor_content = tensor.asnumpy().reshape(-1).tobytes()
            param_tensor.tensor_type = str(tensor.dtype())

            shape = tensor.shape()
            if shape == ():
                # A scalar is recorded with the single dimension 0.
                param_tensor.dims.append(0)
            else:
                for dim in shape:
                    param_tensor.dims.append(dim)

        # A checkpoint written earlier by this function is read-only; make it
        # writable again so that saving to the same file name does not fail.
        if os.path.isfile(ckpoint_file_name):
            os.chmod(ckpoint_file_name, stat.S_IWUSR | stat.S_IRUSR)
        with open(ckpoint_file_name, "wb") as f:
            f.write(checkpoint_list.SerializeToString())
        os.chmod(ckpoint_file_name, stat.S_IRUSR)

    except Exception as e:
        # Exception (not BaseException) so that KeyboardInterrupt/SystemExit
        # are not converted into a RuntimeError.
        logger.error("Failed to save the checkpoint file %s.", ckpoint_file_name)
        raise RuntimeError(str(e)) from e
    logger.info("Save checkpoint process finish.")
-
-
def load_checkpoint(ckpoint_file_name, net=None):
    """
    Loads checkpoint info from a specified file.

    Args:
        ckpoint_file_name (str): Checkpoint file name, must end with ".ckpt".
        net (Cell): Cell network. If given, the loaded parameters are also
            loaded into it. Default: None

    Returns:
        Dict, key is parameter name, value is a Parameter.

    Raises:
        ValueError: Checkpoint file name is incorrect, or the file is empty
            or cannot be read and parsed.
        RuntimeError: The parsed content cannot be converted to parameters.
    """
    if not isinstance(ckpoint_file_name, str):
        raise ValueError("The ckpoint_file_name must be String.")

    if not os.path.exists(ckpoint_file_name) or not ckpoint_file_name.endswith(".ckpt"):
        raise ValueError("Please input the correct checkpoint file name.")

    if os.path.getsize(ckpoint_file_name) == 0:
        raise ValueError("The checkpoint file may be empty, please make sure enter the correct file name.")

    logger.info("Execute load checkpoint process.")
    checkpoint_list = Checkpoint()

    try:
        with open(ckpoint_file_name, "rb") as f:
            pb_content = f.read()
        checkpoint_list.ParseFromString(pb_content)
    except Exception as e:
        logger.error("Failed to read the checkpoint file %s, please check the correct of the file.", ckpoint_file_name)
        raise ValueError(str(e)) from e

    parameter_dict = {}

    try:
        for element in checkpoint_list.value:
            data = element.tensor.tensor_content
            data_type = element.tensor.tensor_type
            np_type = tensor_to_np_type[data_type]
            ms_type = tensor_to_ms_type[data_type]
            # np.fromstring is deprecated for binary input. frombuffer returns
            # a read-only view of the bytes, so copy to get an independent,
            # writable array like fromstring produced.
            param_data = np.frombuffer(data, np_type).copy()
            dims = element.tensor.dims

            if dims == [0]:
                # Dimension 0 is how save_checkpoint records a scalar:
                # rebuild it from a plain python number.
                if 'Float' in data_type:
                    param_data = float(param_data[0])
                elif 'Int' in data_type:
                    param_data = int(param_data[0])
                param_value = param_data
            elif dims == [1]:
                param_value = param_data
            else:
                param_value = param_data.reshape(list(dims))
            parameter_dict[element.tag] = Parameter(Tensor(param_value, ms_type), name=element.tag)

        logger.info("Load checkpoint process finish.")

    except Exception as e:
        logger.error("Failed to load the checkpoint file %s.", ckpoint_file_name)
        raise RuntimeError(str(e)) from e

    # Compare against None explicitly instead of relying on the truthiness of
    # the network object.
    if net is not None:
        load_param_into_net(net, parameter_dict)

    return parameter_dict
-
-
def load_param_into_net(net, parameter_dict):
    """
    Loads parameters into network.

    Parameters of the net are matched by name against `parameter_dict`.
    Parameters that are not found are handed to a second matching step that
    tries to strip a common name prefix; the ones still unmatched afterwards
    are only reported in the log.

    Args:
        net (Cell): Cell network.
        parameter_dict (dict): Parameter dict.

    Raises:
        TypeError: Argument is not a Cell, or parameter_dict is not a Parameter dict.
    """
    if not isinstance(net, nn.Cell):
        logger.error("Failed to combine the net and the parameters.")
        msg = ("Argument net should be a Cell, but got {}.".format(type(net)))
        raise TypeError(msg)

    if not isinstance(parameter_dict, dict):
        logger.error("Failed to combine the net and the parameters.")
        msg = ("Argument parameter_dict should be a dict, but got {}.".format(type(parameter_dict)))
        raise TypeError(msg)

    logger.info("Execute load parameter into net process.")
    # Layerwise parallel parameter data loaded from a checkpoint file is the
    # complete (merged) data and needs to be split first. The network is
    # walked once and names are looked up in the dict, instead of walking the
    # whole network again for every dict entry.
    sliced_names = set()
    for _, param in net.parameters_and_names():
        if param.name in parameter_dict and param.layerwise_parallel and param.name not in sliced_names:
            _load_tensor_for_layerwise(parameter_dict[param.name], param)
            # Each dict entry is split at most once.
            sliced_names.add(param.name)

    param_not_load = []
    for _, param in net.parameters_and_names():
        if param.name in parameter_dict:
            new_param = parameter_dict[param.name]
            if not isinstance(new_param, Parameter):
                logger.error("Failed to combine the net and the parameters.")
                msg = ("Argument parameter_dict element should be a Parameter, but got {}.".format(type(new_param)))
                raise TypeError(msg)
            _update_param(param, new_param)
        else:
            param_not_load.append(param.name)

    if param_not_load:
        # Removes from `param_not_load` every name it manages to load.
        _load_dismatch_prefix_params(net, parameter_dict, param_not_load)

    logger.debug("Params not matched(in net but not in parameter_dict):")
    for param_name in param_not_load:
        logger.debug("%s", param_name)

    logger.info("Load parameter into net finish, {} parameters has not been loaded.".format(len(param_not_load)))
-
-
def _load_dismatch_prefix_params(net, parameter_dict, param_not_load):
    """
    Tries to load the parameters that were not matched by their exact name.

    Repeatedly looks for a prefix such that prefix + net parameter name is a
    key of `parameter_dict`, loads every parameter that matches with this
    prefix and removes its name from `param_not_load`.

    Args:
        net (Cell): Cell network.
        parameter_dict (dict): Parameter dict.
        param_not_load (list): Names of the net parameters not loaded yet,
            must not be empty. Modified in place.
    """
    prefix_name = ""
    longest_name = param_not_load[0]
    while prefix_name != longest_name and param_not_load:
        logger.debug("Count: {} parameters has not been loaded, try to load continue.".format(len(param_not_load)))
        # The longest pending name doubles as the "no prefix found" marker.
        longest_name = max(param_not_load, key=len)
        prefix_name = longest_name
        for pending_name in param_not_load:
            for dict_name in parameter_dict:
                if not dict_name.endswith(pending_name):
                    continue
                candidate = dict_name[:-len(pending_name)]
                # Keep the candidate unless the current prefix is strictly shorter.
                if len(candidate) <= len(prefix_name):
                    prefix_name = candidate

        if prefix_name == longest_name:
            # Nothing found in this round; the loop condition ends the search.
            continue

        logger.info("Remove parameter prefix name: {}, continue to load.".format(prefix_name))
        for _, param in net.parameters_and_names():
            prefixed_name = prefix_name + param.name
            if param.name in param_not_load and prefixed_name in parameter_dict:
                _update_param(param, parameter_dict[prefixed_name])
                param_not_load.remove(param.name)
-
-
def _save_graph(network, file_name):
    """
    Writes the graph proto of the network to a file.

    Nothing is written when the network returns an empty graph proto. The
    written file is made readable and writable for the owner only.

    Args:
        network (Cell): Obtain a pipeline through network for saving graph.
        file_name (str): Graph file name into which the graph will be saved.
    """
    logger.info("Execute save the graph process.")

    serialized_graph = network.get_func_graph_proto()
    if not serialized_graph:
        return

    with open(file_name, "wb") as graph_file:
        graph_file.write(serialized_graph)
    os.chmod(file_name, stat.S_IWUSR | stat.S_IRUSR)
-
-
def _exec_save_checkpoint(train_network, ckpoint_file_name, integrated_save=True):
    """
    Saves checkpoint for 'ms' backend.

    Collects the parameters of the network by name, wraps non-Tensor data
    into a Tensor and passes the resulting list to `save_checkpoint`.

    Args:
        train_network (Network): The train network for training.
        ckpoint_file_name (str): The name of checkpoint file.
        integrated_save (bool): Whether to merge the parameter slices before
            saving in automatic model parallel scene.
    """
    named_params = {param.name: param for _, param in train_network.parameters_and_names()}

    param_list = []
    for name, param in named_params.items():
        raw_data = param.data
        tensor = raw_data if isinstance(raw_data, Tensor) else Tensor(raw_data)

        # In the automatic model parallel scenario some parameters are split
        # across the devices and have to be combined before saving.
        if integrated_save and name in train_network.parameter_layout_dict:
            tensor = _get_merged_param_data(train_network, name, tensor)

        param_list.append({"name": name, "data": tensor})

    save_checkpoint(param_list, ckpoint_file_name)
-
-
def _get_merged_param_data(net, param_name, param_data):
    """
    Gets the merged data(tensor) from tensor slice, by device arrangement and tensor map.

    Args:
        net (Cell): MindSpore network.
        param_name (str): The name of the parameter to be combined.
        param_data (Tensor): The parameter data on the local device,
            a slice of the whole parameter data.

    Returns:
        Tensor, the combined tensor holding the whole data. `param_data` is
        returned unchanged when the layout has fewer than two entries or when
        every entry of the tensor map is -1.
    """
    layout = net.parameter_layout_dict[param_name]
    if len(layout) < 2:
        logger.info("layout dict does not contain the key %s", param_name)
        return param_data

    dev_mat, tensor_map = layout[0], layout[1]

    from mindspore.parallel._cell_wrapper import get_allgather_cell
    from mindspore.parallel._tensor import _reshape_param_data
    # Any dim different from -1 means the parameter is split and needs to be merged.
    if not any(dim != -1 for dim in tensor_map):
        return param_data

    gathered = get_allgather_cell()(param_data)
    return _reshape_param_data(gathered, dev_mat, tensor_map)
-
-
def _load_tensor_for_layerwise(new_param, old_param):
    """
    Replaces parameters with sliced tensors by layerwise parallel strategies.

    When the shapes of the two parameters already agree nothing is done.
    Otherwise the data of `new_param` is replaced by the tensor returned by
    `_load_tensor` for a one-dimensional device matrix of the group size and
    a tensor map that is 0 for dim 0 and -1 for every other dim.

    Args:
        new_param (Parameter): The new layerwise parallel parameter, will be loaded into net.
        old_param (Parameter): The current parameter in the net.

    Raises:
        TypeError: The data of one of the parameters is not a Tensor.
    """
    if not isinstance(new_param.data, Tensor) or not isinstance(old_param.data, Tensor):
        logger.error("Failed to combine the net and the parameters.")
        msg = ("layerwise parallel parameter should be a Tensor, but got {}.".format(type(new_param.data)))
        raise TypeError(msg)

    if old_param.data.shape() == new_param.data.shape():
        return

    from mindspore.parallel._tensor import _load_tensor
    from mindspore.communication.management import get_group_size
    dev_mat = [get_group_size()]
    shape = new_param.data.shape()
    # dim 0 set 0, others set -1. The list has to be created here: it was
    # previously appended to without ever being defined (NameError).
    tensor_map = [0 if axis == 0 else -1 for axis in range(len(shape))]

    new_tensor = _load_tensor(new_param.data, dev_mat, tensor_map)
    new_param.set_parameter_data(new_tensor)
-
-
def _fill_param_into_net(net, parameter_list):
    """
    Fills parameter_list into net.

    Every entry is converted to a Parameter keyed by its name and the
    resulting dict is loaded into the net with `load_param_into_net`.

    Args:
        net (Cell): train network.
        parameter_list (list): parameters list from ge callback.
    """
    parameter_dict = {}
    for entry in parameter_list:
        name = entry["name"]
        array = entry["data"].asnumpy()
        if array.shape == (1,):
            # One-element arrays are passed on as the numpy array itself.
            payload = array
        elif array.shape == ():
            # Zero-dim arrays become a Tensor built from the python scalar.
            payload = Tensor(array.tolist(), mstype.pytype_to_dtype(array.dtype))
        else:
            payload = Tensor(array)
        parameter_dict[name] = Parameter(payload, name=name)

    load_param_into_net(net, parameter_dict)
-
-
def export(net, *inputs, file_name, file_format='GEIR'):
    """
    Exports MindSpore predict model to file in specified format.

    If the network is in training mode it is switched to inference mode for
    the export and switched back afterwards, also when the export fails.

    Args:
        net (Cell): MindSpore network.
        inputs (Tensor): Inputs of the `net`.
        file_name (str): File name of model to export. Keyword-only.
        file_format (str): MindSpore currently supports 'GEIR', 'ONNX' and 'LITE' format for exported model.
            Keyword-only. Default: 'GEIR'.

            - GEIR: Graph Engine Intermediate Representation. An intermediate representation format of
              Ascend model.
            - ONNX: Open Neural Network eXchange. An open format built to represent machine learning models.
            - LITE: Huawei model format for mobile. A lite model only for the MindSpore Lite

    Raises:
        ValueError: `file_format` is not one of the supported formats.
    """
    logger.info("exporting model file:%s format:%s.", file_name, file_format)
    check_input_data(*inputs, data_class=Tensor)

    supported_formats = ['GEIR', 'ONNX', 'LITE']
    if file_format not in supported_formats:
        raise ValueError(f'Illegal file format {file_format}, it must be one of {supported_formats}')
    # switch network mode to infer when it is training
    is_training = net.training
    if is_training:
        net.set_train(mode=False)
    try:
        # export model
        if file_format == 'GEIR':
            _executor.compile(net, *inputs, phase='export')
            _executor.export(net, file_name, file_format)
        elif file_format == 'ONNX':  # file_format is 'ONNX'
            phase_name = 'export_onnx'
            graph_id, _ = _executor.compile(net, *inputs, phase=phase_name)
            onnx_stream = _executor._get_func_graph_proto(graph_id)
            with open(file_name, 'wb') as f:
                # restrict the file to owner read/write before the model is written
                os.chmod(file_name, stat.S_IWUSR | stat.S_IRUSR)
                f.write(onnx_stream)
        elif file_format == 'LITE':  # file_format is 'LITE'
            context.set_context(save_ms_model=True, save_ms_model_path=file_name)
            net(*inputs)
    finally:
        # restore network training mode, even if the export raised, so the
        # caller's network is not silently left in inference mode
        if is_training:
            net.set_train(mode=True)
|