You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

serialization.py 22 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """Model and parameters serialization."""
  16. import os
  17. import stat
  18. import numpy as np
  19. import mindspore.nn as nn
  20. import mindspore.context as context
  21. from mindspore import log as logger
  22. from mindspore.train.checkpoint_pb2 import Checkpoint
  23. from mindspore.train.print_pb2 import Print
  24. from mindspore.common.tensor import Tensor
  25. from mindspore.common.initializer import initializer
  26. from mindspore.common.parameter import Parameter
  27. from mindspore.common.api import _executor
  28. from mindspore.common import dtype as mstype
  29. from mindspore._checkparam import check_input_data
  30. __all__ = ["save_checkpoint", "load_checkpoint", "load_param_into_net", "export", "parse_print"]
  31. tensor_to_ms_type = {"Int8": mstype.int8, "Uint8": mstype.uint8, "Int16": mstype.int16, "Uint16": mstype.uint16,
  32. "Int32": mstype.int32, "Uint32": mstype.uint32, "Int64": mstype.int64, "Uint64": mstype.uint64,
  33. "Float16": mstype.float16, "Float32": mstype.float32, "Float64": mstype.float64,
  34. "Bool": mstype.bool_}
  35. tensor_to_np_type = {"Int8": np.int8, "Uint8": np.uint8, "Int16": np.int16, "Uint16": np.uint16,
  36. "Int32": np.int32, "Uint32": np.uint32, "Int64": np.int64, "Uint64": np.uint64,
  37. "Float16": np.float16, "Float32": np.float32, "Float64": np.float64, "Bool": np.bool_}
  38. ModelType = ["normal", "fusion", "quant"]
  39. def _special_process_par(par, new_par):
  40. """
  41. Processes the special condition.
  42. Like (12,2048,1,1)->(12,2048), this case is caused by GE 4 dimensions tensor.
  43. """
  44. par_shape_len = len(par.data.shape)
  45. new_par_shape_len = len(new_par.data.shape)
  46. delta_len = new_par_shape_len - par_shape_len
  47. delta_i = 0
  48. for delta_i in range(delta_len):
  49. if new_par.data.shape[par_shape_len + delta_i] != 1:
  50. break
  51. if delta_i == delta_len - 1:
  52. new_val = new_par.data.asnumpy()
  53. new_val = new_val.reshape(par.data.shape)
  54. par.set_parameter_data(Tensor(new_val, par.data.dtype))
  55. return True
  56. return False
  57. def _update_param(param, new_param):
  58. """Updates param's data from new_param's data."""
  59. if isinstance(param.data, Tensor) and isinstance(new_param.data, Tensor):
  60. if param.data.dtype != new_param.data.dtype:
  61. logger.error("Failed to combine the net and the parameters for param %s.", param.name)
  62. msg = ("Net parameters {} type({}) different from parameter_dict's({})"
  63. .format(param.name, param.data.dtype, new_param.data.dtype))
  64. raise RuntimeError(msg)
  65. if param.data.shape != new_param.data.shape:
  66. if not _special_process_par(param, new_param):
  67. logger.error("Failed to combine the net and the parameters for param %s.", param.name)
  68. msg = ("Net parameters {} shape({}) different from parameter_dict's({})"
  69. .format(param.name, param.data.shape, new_param.data.shape))
  70. raise RuntimeError(msg)
  71. return
  72. param.set_parameter_data(new_param.data)
  73. return
  74. if isinstance(param.data, Tensor) and not isinstance(new_param.data, Tensor):
  75. if param.data.shape != (1,) and param.data.shape != ():
  76. logger.error("Failed to combine the net and the parameters for param %s.", param.name)
  77. msg = ("Net parameters {} shape({}) is not (1,), inconsitent with parameter_dict's(scalar)."
  78. .format(param.name, param.data.shape))
  79. raise RuntimeError(msg)
  80. param.set_parameter_data(initializer(new_param.data, param.data.shape, param.data.dtype))
  81. elif isinstance(new_param.data, Tensor) and not isinstance(param.data, Tensor):
  82. logger.error("Failed to combine the net and the parameters for param %s.", param.name)
  83. msg = ("Net parameters {} type({}) different from parameter_dict's({})"
  84. .format(param.name, type(param.data), type(new_param.data)))
  85. raise RuntimeError(msg)
  86. else:
  87. param.set_parameter_data(type(param.data)(new_param.data))
  88. def save_checkpoint(parameter_list, ckpt_file_name, model_type="normal"):
  89. """
  90. Saves checkpoint info to a specified file.
  91. Args:
  92. parameter_list (list): Parameters list, each element is a dict
  93. like {"name":xx, "type":xx, "shape":xx, "data":xx}.
  94. ckpt_file_name (str): Checkpoint file name.
  95. model_type (str): The name of model type. Default: "normal".
  96. Raises:
  97. RuntimeError: Failed to save the Checkpoint file.
  98. """
  99. logger.info("Execute save checkpoint process.")
  100. checkpoint_list = Checkpoint()
  101. checkpoint_list.model_type = model_type
  102. try:
  103. for param in parameter_list:
  104. param_value = checkpoint_list.value.add()
  105. param_value.tag = param["name"]
  106. param_tensor = param_value.tensor
  107. if isinstance(param["data"], Parameter):
  108. param["data"].init_data()
  109. param_data = param["data"].asnumpy().reshape(-1)
  110. param_tensor.tensor_content = param_data.tostring()
  111. param_tensor.tensor_type = str(param["data"].dtype)
  112. if param['data'].shape == ():
  113. param_tensor.dims.append(0)
  114. else:
  115. for dim in param['data'].shape:
  116. param_tensor.dims.append(dim)
  117. with open(ckpt_file_name, "wb") as f:
  118. f.write(checkpoint_list.SerializeToString())
  119. os.chmod(ckpt_file_name, stat.S_IRUSR)
  120. except BaseException as e:
  121. logger.error("Failed to save the checkpoint file %s.", ckpt_file_name)
  122. raise RuntimeError(e.__str__())
  123. logger.info("Save checkpoint process finish.")
  124. def load_checkpoint(ckpt_file_name, model_type="normal", net=None):
  125. """
  126. Loads checkpoint info from a specified file.
  127. Args:
  128. ckpt_file_name (str): Checkpoint file name.
  129. model_type (str): The name of model type in `normal`, `fusion` or `quant`. Default: "normal".
  130. net (Cell): Cell network. Default: None
  131. Returns:
  132. Dict, key is parameter name, value is a Parameter.
  133. Raises:
  134. ValueError: Checkpoint file is incorrect.
  135. """
  136. if not isinstance(ckpt_file_name, str):
  137. raise ValueError("The ckpt_file_name must be string.")
  138. if model_type not in ModelType:
  139. raise ValueError(f"The model_type is not in {ModelType}.")
  140. if not os.path.exists(ckpt_file_name) or ckpt_file_name[-5:] != ".ckpt":
  141. raise ValueError("Please input the correct checkpoint file name.")
  142. if os.path.getsize(ckpt_file_name) == 0:
  143. raise ValueError("The checkpoint file may be empty, please make sure enter the correct file name.")
  144. logger.info("Execute load checkpoint process.")
  145. checkpoint_list = Checkpoint()
  146. try:
  147. with open(ckpt_file_name, "rb") as f:
  148. pb_content = f.read()
  149. checkpoint_list.ParseFromString(pb_content)
  150. except BaseException as e:
  151. logger.error("Failed to read the checkpoint file `%s`, please check the correct of the file.", ckpt_file_name)
  152. raise ValueError(e.__str__())
  153. parameter_dict = {}
  154. if checkpoint_list.model_type:
  155. if model_type != checkpoint_list.model_type:
  156. raise KeyError("Checkpoint file model type({}) is not equal to input model type({}).".format(
  157. checkpoint_list.model_type, model_type))
  158. try:
  159. for element in checkpoint_list.value:
  160. data = element.tensor.tensor_content
  161. data_type = element.tensor.tensor_type
  162. np_type = tensor_to_np_type[data_type]
  163. ms_type = tensor_to_ms_type[data_type]
  164. param_data = np.fromstring(data, np_type)
  165. dims = element.tensor.dims
  166. if dims == [0]:
  167. if 'Float' in data_type:
  168. param_data = float(param_data[0])
  169. elif 'Int' in data_type:
  170. param_data = int(param_data[0])
  171. parameter_dict[element.tag] = Parameter(Tensor(param_data, ms_type), name=element.tag)
  172. elif dims == [1]:
  173. parameter_dict[element.tag] = Parameter(Tensor(param_data, ms_type), name=element.tag)
  174. else:
  175. param_dim = []
  176. for dim in dims:
  177. param_dim.append(dim)
  178. param_value = param_data.reshape(param_dim)
  179. parameter_dict[element.tag] = Parameter(Tensor(param_value, ms_type), name=element.tag)
  180. logger.info("Load checkpoint process finish.")
  181. except BaseException as e:
  182. logger.error("Failed to load the checkpoint file `%s`.", ckpt_file_name)
  183. raise RuntimeError(e.__str__())
  184. if net:
  185. load_param_into_net(net, parameter_dict)
  186. return parameter_dict
  187. def load_param_into_net(net, parameter_dict):
  188. """
  189. Loads parameters into network.
  190. Args:
  191. net (Cell): Cell network.
  192. parameter_dict (dict): Parameter dict.
  193. Raises:
  194. TypeError: Argument is not a Cell, or parameter_dict is not a Parameter dict.
  195. """
  196. if not isinstance(net, nn.Cell):
  197. logger.error("Failed to combine the net and the parameters.")
  198. msg = ("Argument net should be a Cell, but got {}.".format(type(net)))
  199. raise TypeError(msg)
  200. if not isinstance(parameter_dict, dict):
  201. logger.error("Failed to combine the net and the parameters.")
  202. msg = ("Argument parameter_dict should be a dict, but got {}.".format(type(parameter_dict)))
  203. raise TypeError(msg)
  204. logger.info("Execute load parameter into net process.")
  205. net.init_parameters_data()
  206. param_not_load = []
  207. for _, param in net.parameters_and_names():
  208. if param.name in parameter_dict:
  209. new_param = parameter_dict[param.name]
  210. if not isinstance(new_param, Parameter):
  211. logger.error("Failed to combine the net and the parameters.")
  212. msg = ("Argument parameter_dict element should be a Parameter, but got {}.".format(type(new_param)))
  213. raise TypeError(msg)
  214. param.init_data()
  215. _update_param(param, new_param)
  216. else:
  217. param_not_load.append(param.name)
  218. if param_not_load:
  219. _load_dismatch_prefix_params(net, parameter_dict, param_not_load)
  220. logger.debug("Params not matched(in net but not in parameter_dict):")
  221. for param_name in param_not_load:
  222. logger.debug("%s", param_name)
  223. logger.info("Load parameter into net finish, {} parameters has not been loaded.".format(len(param_not_load)))
  224. def _load_dismatch_prefix_params(net, parameter_dict, param_not_load):
  225. """When some net parameter did not load, try to continue load."""
  226. prefix_name = ""
  227. longest_name = param_not_load[0]
  228. while prefix_name != longest_name and param_not_load:
  229. logger.debug("Count: {} parameters has not been loaded, try to load continue.".format(len(param_not_load)))
  230. prefix_name = longest_name
  231. for net_param_name in param_not_load:
  232. for dict_name in parameter_dict:
  233. if dict_name.endswith(net_param_name):
  234. prefix_name = dict_name[:-len(net_param_name)]
  235. break
  236. if prefix_name != longest_name:
  237. break
  238. if prefix_name != longest_name:
  239. logger.warning("Remove parameter prefix name: {}, continue to load.".format(prefix_name))
  240. for _, param in net.parameters_and_names():
  241. new_param_name = prefix_name + param.name
  242. if param.name in param_not_load and new_param_name in parameter_dict:
  243. new_param = parameter_dict[new_param_name]
  244. _update_param(param, new_param)
  245. param_not_load.remove(param.name)
  246. def _save_graph(network, file_name):
  247. """
  248. Saves the graph of network to a file.
  249. Args:
  250. network (Cell): Obtain a pipeline through network for saving graph.
  251. file_name (str): Graph file name into which the graph will be saved.
  252. """
  253. logger.info("Execute save the graph process.")
  254. graph_proto = network.get_func_graph_proto()
  255. if graph_proto:
  256. with open(file_name, "wb") as f:
  257. f.write(graph_proto)
  258. os.chmod(file_name, stat.S_IWUSR | stat.S_IRUSR)
  259. def _exec_save_checkpoint(train_network, ckpt_file_name, model_type="normal", integrated_save=True):
  260. """
  261. Saves checkpoint for 'ms' backend.
  262. Args:
  263. train_network (Network): The train network for training.
  264. ckpt_file_name (str): The name of checkpoint file.
  265. model_type (str): The name of model type in `normal`, `fusion` or `quant`. Default: "normal".
  266. integrated_save (bool): Whether to integrated save in automatic model parallel scene.
  267. """
  268. param_dict = {}
  269. for _, param in train_network.parameters_and_names():
  270. param_dict[param.name] = param
  271. param_list = []
  272. for (key, value) in param_dict.items():
  273. each_param = {"name": key}
  274. value.init_data()
  275. if isinstance(value.data, Tensor):
  276. param_data = value.data
  277. else:
  278. param_data = Tensor(value.data)
  279. # in automatic model parallel scenario, some parameters were spliteds to all the devices,
  280. # which should be combined before saving
  281. if integrated_save and key in train_network.parameter_layout_dict:
  282. param_data = _get_merged_param_data(train_network, key, param_data)
  283. each_param["data"] = param_data
  284. param_list.append(each_param)
  285. save_checkpoint(param_list, ckpt_file_name, model_type)
  286. def _get_merged_param_data(net, param_name, param_data):
  287. """
  288. Gets the merged data(tensor) from tensor slice, by device arrangement and tensor map.
  289. Args:
  290. net (Cell): MindSpore network.
  291. param_name(str): The parameter name, which to be combined.
  292. param_data(Tensor):The parameter data on the local device,
  293. It was a slice of the whole parameter data.
  294. Returns:
  295. Tensor, the combined tensor which with the whole data value.
  296. """
  297. layout = []
  298. layout = net.parameter_layout_dict[param_name]
  299. if len(layout) < 2:
  300. logger.info("layout dict does not contain the key %s", param_name)
  301. return param_data
  302. dev_mat = layout[0]
  303. tensor_map = layout[1]
  304. from mindspore.parallel._cell_wrapper import get_allgather_cell
  305. from mindspore.parallel._tensor import _reshape_param_data
  306. # while any dim is not equal to -1, means param is splited and needs to be merged
  307. for dim in tensor_map:
  308. if dim != -1:
  309. allgather_net = get_allgather_cell()
  310. param_data = allgather_net(param_data)
  311. return _reshape_param_data(param_data, dev_mat, tensor_map)
  312. return param_data
  313. def _fill_param_into_net(net, parameter_list):
  314. """
  315. Fills parameter_list into net.
  316. Args:
  317. net (Cell): train network.
  318. parameter_list (list): parameters list from ge callback.
  319. """
  320. parameter_dict = {}
  321. for each_param in parameter_list:
  322. param_name = each_param["name"]
  323. if isinstance(each_param["data"], Parameter):
  324. each_param["data"].init_data()
  325. np_val = each_param["data"].asnumpy()
  326. if np_val.shape == (1,):
  327. parameter_dict[param_name] = Parameter(np_val, name=param_name)
  328. elif np_val.shape == ():
  329. parameter_dict[param_name] = Parameter(Tensor(np_val.tolist(), mstype.pytype_to_dtype(np_val.dtype)),
  330. name=param_name)
  331. else:
  332. parameter_dict[param_name] = Parameter(Tensor(np_val), name=param_name)
  333. load_param_into_net(net, parameter_dict)
  334. def export(net, *inputs, file_name, file_format='GEIR'):
  335. """
  336. Exports MindSpore predict model to file in specified format.
  337. Args:
  338. net (Cell): MindSpore network.
  339. inputs (Tensor): Inputs of the `net`.
  340. file_name (str): File name of model to export.
  341. file_format (str): MindSpore currently supports 'GEIR', 'ONNX' 'LITE' and 'BINARY' format for exported model.
  342. - GEIR: Graph Engine Intermidiate Representation. An intermidiate representation format of
  343. Ascend model.
  344. - ONNX: Open Neural Network eXchange. An open format built to represent machine learning models.
  345. - LITE: Huawei model format for mobile. A lite model only for the MindSpore Lite
  346. - BINARY: Binary format for model. An intermidiate representation format for models.
  347. """
  348. logger.info("exporting model file:%s format:%s.", file_name, file_format)
  349. check_input_data(*inputs, data_class=Tensor)
  350. supported_formats = ['GEIR', 'ONNX', 'LITE', 'BINARY']
  351. if file_format not in supported_formats:
  352. raise ValueError(f'Illegal file format {file_format}, it must be one of {supported_formats}')
  353. # switch network mode to infer when it is training
  354. is_training = net.training
  355. if is_training:
  356. net.set_train(mode=False)
  357. # export model
  358. if file_format == 'GEIR':
  359. _executor.compile(net, *inputs, phase='export')
  360. _executor.export(net, file_name, file_format)
  361. elif file_format == 'ONNX': # file_format is 'ONNX'
  362. # NOTICE: the pahse name `export_onnx` is used for judging whether is exporting onnx in the compile pipeline,
  363. # do not change it to other values.
  364. phase_name = 'export_onnx'
  365. graph_id, _ = _executor.compile(net, *inputs, phase=phase_name, do_convert=False)
  366. onnx_stream = _executor._get_func_graph_proto(graph_id)
  367. with open(file_name, 'wb') as f:
  368. os.chmod(file_name, stat.S_IWUSR | stat.S_IRUSR)
  369. f.write(onnx_stream)
  370. elif file_format == 'BINARY': # file_format is 'BINARY'
  371. phase_name = 'export_binary'
  372. graph_id, _ = _executor.compile(net, *inputs, phase=phase_name, do_convert=False)
  373. onnx_stream = _executor._get_func_graph_proto(graph_id, 'binary_ir')
  374. with open(file_name, 'wb') as f:
  375. os.chmod(file_name, stat.S_IWUSR | stat.S_IRUSR)
  376. f.write(onnx_stream)
  377. elif file_format == 'LITE': # file_format is 'LITE'
  378. context.set_context(save_ms_model=True, save_ms_model_path=file_name)
  379. net(*inputs)
  380. # restore network training mode
  381. if is_training:
  382. net.set_train(mode=True)
  383. def parse_print(print_file_name):
  384. """
  385. Loads Print data from a specified file.
  386. Args:
  387. print_file_name (str): The file name of save print data.
  388. Returns:
  389. List, element of list is Tensor.
  390. Raises:
  391. ValueError: Print file is incorrect.
  392. """
  393. if not os.path.realpath(print_file_name):
  394. raise ValueError("Please input the correct print file name.")
  395. if os.path.getsize(print_file_name) == 0:
  396. raise ValueError("The print file may be empty, please make sure enter the correct file name.")
  397. logger.info("Execute load print process.")
  398. print_list = Print()
  399. try:
  400. with open(print_file_name, "rb") as f:
  401. pb_content = f.read()
  402. print_list.ParseFromString(pb_content)
  403. except BaseException as e:
  404. logger.error("Failed to read the print file %s, please check the correct of the file.", print_file_name)
  405. raise ValueError(e.__str__())
  406. tensor_list = []
  407. try:
  408. for print_ in print_list.value:
  409. # String type
  410. if print_.HasField("desc"):
  411. tensor_list.append(print_.desc)
  412. elif print_.HasField("tensor"):
  413. dims = print_.tensor.dims
  414. data_type = print_.tensor.tensor_type
  415. data = print_.tensor.tensor_content
  416. np_type = tensor_to_np_type[data_type]
  417. param_data = np.fromstring(data, np_type)
  418. ms_type = tensor_to_ms_type[data_type]
  419. param_dim = []
  420. for dim in dims:
  421. param_dim.append(dim)
  422. if param_dim:
  423. param_value = param_data.reshape(param_dim)
  424. tensor_list.append(Tensor(param_value, ms_type))
  425. # Scale type
  426. else:
  427. data_type_ = data_type.lower()
  428. if 'float' in data_type_:
  429. param_data = float(param_data[0])
  430. elif 'int' in data_type_:
  431. param_data = int(param_data[0])
  432. elif 'bool' in data_type_:
  433. param_data = bool(param_data[0])
  434. tensor_list.append(Tensor(param_data, ms_type))
  435. except BaseException as e:
  436. logger.error("Failed to load the print file %s.", print_list)
  437. raise RuntimeError(e.__str__())
  438. return tensor_list