You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

_checkpoint.py 16 kB

6 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """Checkpoint related classes and functions."""
  16. import os
  17. import shutil
  18. import stat
  19. import time
  20. import mindspore.context as context
  21. from mindspore import log as logger
  22. from mindspore._checkparam import check_bool, check_string, check_int_non_negative
  23. from mindspore.train._utils import _make_directory
  24. from mindspore.train.serialization import _exec_save_checkpoint, _save_graph
  25. from ._callback import Callback, set_cur_net
# Directory the process was launched from; used as the default save location
# when ModelCheckpoint is constructed without an explicit `directory`.
_cur_dir = os.getcwd()
# Module-level record of the directory the temporary checkpoint file is
# written to; updated by ModelCheckpoint._save_ckpt before each save.
_save_dir = _cur_dir
  28. def _check_file_name_prefix(file_name_prefix):
  29. """
  30. Check file name valid or not.
  31. File name can't include '/'. This file name naming convention only apply to Linux.
  32. """
  33. if not isinstance(file_name_prefix, str) or file_name_prefix.find('/') >= 0:
  34. return False
  35. return True
  36. def _chg_ckpt_file_name_if_same_exist(directory, prefix):
  37. """Check if there is a file with the same name."""
  38. files = os.listdir(directory)
  39. suffix_num = 0
  40. pre_len = len(prefix)
  41. for filename in files:
  42. name_ext = os.path.splitext(filename)
  43. if name_ext[-1] != ".ckpt":
  44. continue
  45. # find same prefix file
  46. if filename.find(prefix) == 0 and not filename[pre_len].isalpha():
  47. # add the max suffix + 1
  48. index = filename[pre_len:].find("-")
  49. if index == 0:
  50. suffix_num = max(suffix_num, 1)
  51. elif index != -1:
  52. num = filename[pre_len+1:pre_len+index]
  53. if num.isdigit():
  54. suffix_num = max(suffix_num, int(num)+1)
  55. if suffix_num != 0:
  56. prefix = prefix + "_" + str(suffix_num)
  57. return prefix
  58. class CheckpointConfig:
  59. """
  60. The config for model checkpoint.
  61. Note:
  62. During the training process, if dataset is transmitted through the data channel,
  63. suggest set save_checkpoint_steps be an integer multiple of loop_size.
  64. Otherwise there may be deviation in the timing of saving checkpoint.
  65. Args:
  66. save_checkpoint_steps (int): Steps to save checkpoint. Default: 1.
  67. save_checkpoint_seconds (int): Seconds to save checkpoint. Default: 0.
  68. Can't be used with save_checkpoint_steps at the same time.
  69. keep_checkpoint_max (int): Maximum step to save checkpoint. Default: 5.
  70. keep_checkpoint_per_n_minutes (int): Keep one checkpoint every n minutes. Default: 0.
  71. Can't be used with keep_checkpoint_max at the same time.
  72. integrated_save (bool): Whether to intergrated save in automatic model parallel scene. Default: True.
  73. Integrated save function is only supported in automatic parallel scene, not supported in manual parallel.
  74. model_type (str): Model type in `normal`, `fusion` or `quant`. Default: "normal".
  75. Raises:
  76. ValueError: If the input_param is None or 0.
  77. Examples:
  78. >>> config = CheckpointConfig()
  79. >>> ckpoint_cb = ModelCheckpoint(prefix="ck_prefix", directory='./', config=config)
  80. >>> model.train(10, dataset, callbacks=ckpoint_cb)
  81. """
  82. def __init__(self,
  83. save_checkpoint_steps=1,
  84. save_checkpoint_seconds=0,
  85. keep_checkpoint_max=5,
  86. keep_checkpoint_per_n_minutes=0,
  87. integrated_save=True,
  88. model_type="normal"):
  89. if not save_checkpoint_steps and not save_checkpoint_seconds and \
  90. not keep_checkpoint_max and not keep_checkpoint_per_n_minutes:
  91. raise ValueError("The input_param can't be all None or 0")
  92. if save_checkpoint_steps:
  93. save_checkpoint_steps = check_int_non_negative(save_checkpoint_steps)
  94. if save_checkpoint_seconds:
  95. save_checkpoint_seconds = check_int_non_negative(save_checkpoint_seconds)
  96. if keep_checkpoint_max:
  97. keep_checkpoint_max = check_int_non_negative(keep_checkpoint_max)
  98. if keep_checkpoint_per_n_minutes:
  99. keep_checkpoint_per_n_minutes = check_int_non_negative(keep_checkpoint_per_n_minutes)
  100. if model_type:
  101. model_type = check_string(model_type, ["normal", "fusion", "quant"])
  102. self._save_checkpoint_steps = save_checkpoint_steps
  103. self._save_checkpoint_seconds = save_checkpoint_seconds
  104. if self._save_checkpoint_steps and self._save_checkpoint_steps > 0:
  105. self._save_checkpoint_seconds = None
  106. self._keep_checkpoint_max = keep_checkpoint_max
  107. self._keep_checkpoint_per_n_minutes = keep_checkpoint_per_n_minutes
  108. if self._keep_checkpoint_max and self._keep_checkpoint_max > 0:
  109. self._keep_checkpoint_per_n_minutes = None
  110. else:
  111. if not self._keep_checkpoint_per_n_minutes or self._keep_checkpoint_per_n_minutes == 0:
  112. self._keep_checkpoint_max = 1
  113. self._model_type = model_type
  114. self._integrated_save = check_bool(integrated_save)
  115. @property
  116. def save_checkpoint_steps(self):
  117. """Get the value of _save_checkpoint_steps."""
  118. return self._save_checkpoint_steps
  119. @property
  120. def save_checkpoint_seconds(self):
  121. """Get the value of _save_checkpoint_seconds."""
  122. return self._save_checkpoint_seconds
  123. @property
  124. def keep_checkpoint_max(self):
  125. """Get the value of _keep_checkpoint_max."""
  126. return self._keep_checkpoint_max
  127. @property
  128. def keep_checkpoint_per_n_minutes(self):
  129. """Get the value of _keep_checkpoint_per_n_minutes."""
  130. return self._keep_checkpoint_per_n_minutes
  131. @property
  132. def integrated_save(self):
  133. """Get the value of _integrated_save."""
  134. return self._integrated_save
  135. @property
  136. def model_type(self):
  137. """Get the value of model_type."""
  138. return self._model_type
  139. def get_checkpoint_policy(self):
  140. """Get the policy of checkpoint."""
  141. checkpoint_policy = {'save_checkpoint_steps': self._save_checkpoint_steps,
  142. 'save_checkpoint_seconds': self._save_checkpoint_seconds,
  143. 'keep_checkpoint_max': self._keep_checkpoint_max,
  144. 'keep_checkpoint_per_n_minutes': self._keep_checkpoint_per_n_minutes,
  145. 'model_type': self._model_type}
  146. return checkpoint_policy
class ModelCheckpoint(Callback):
    """
    The checkpoint callback class.

    It is called to combine with train process and save the model and network
    parameters after traning.

    Args:
        prefix (str): Checkpoint files names prefix. Default: "CKP".
        directory (str): Folder path into which checkpoint files will be saved. Default: None.
        config (CheckpointConfig): Checkpoint strategy config. Default: None.

    Raises:
        ValueError: If the prefix is invalid.
        TypeError: If the config is not CheckpointConfig type.
    """

    def __init__(self, prefix='CKP', directory=None, config=None):
        super(ModelCheckpoint, self).__init__()
        self._latest_ckpt_file_name = ""
        self._init_time = time.time()
        self._last_time = time.time()
        self._last_time_for_keep = time.time()
        self._last_triggered_step = 0

        # Reject prefixes containing '/', which would escape the directory.
        if _check_file_name_prefix(prefix):
            self._prefix = prefix
        else:
            raise ValueError("Prefix {} for checkpoint file name invalid, "
                             "please check and correct it and then continue.".format(prefix))

        # Default to the process's start directory when none is given.
        if directory:
            self._directory = _make_directory(directory)
        else:
            self._directory = _cur_dir

        if config is None:
            self._config = CheckpointConfig()
        else:
            if not isinstance(config, CheckpointConfig):
                raise TypeError("config should be CheckpointConfig type.")
            self._config = config

        # De-conflict the prefix against checkpoint files already present
        # in the target directory, so a rerun never overwrites old files.
        self._manager = CheckpointManager()
        self._prefix = _chg_ckpt_file_name_if_same_exist(self._directory, self._prefix)
        self._graph_saved = False

    def step_end(self, run_context):
        """
        Save the checkpoint at the end of step.

        Args:
            run_context (RunContext): Context of the train running.
        """
        cb_params = run_context.original_args()
        # Save the graph file only once, on the first step_end call.
        if not self._graph_saved:
            graph_file_name = os.path.join(self._directory, self._prefix + '-graph.meta')
            _save_graph(cb_params.train_network, graph_file_name)
            self._graph_saved = True
        self._save_ckpt(cb_params, self._config.model_type)

    def end(self, run_context):
        """
        Save the last checkpoint after training finished.

        Args:
            run_context (RunContext): Context of the train running.
        """
        cb_params = run_context.original_args()
        # Force a final save regardless of the configured step/seconds trigger.
        _to_save_last_ckpt = True
        self._save_ckpt(cb_params, self._config.model_type, _to_save_last_ckpt)

        # Imported lazily to avoid a module-load-time dependency on the
        # parallel package; releases allgather resources after training.
        from mindspore.parallel._cell_wrapper import destroy_allgather_cell
        destroy_allgather_cell()

    def _check_save_ckpt(self, cb_params, force_to_save):
        """Return True if a checkpoint should be saved now (step- or time-triggered)."""
        if self._config.save_checkpoint_steps and self._config.save_checkpoint_steps > 0:
            # Step-based trigger: enough steps elapsed since the last save.
            if cb_params.cur_step_num >= self._last_triggered_step + self._config.save_checkpoint_steps \
                    or force_to_save is True:
                return True
        elif self._config.save_checkpoint_seconds and self._config.save_checkpoint_seconds > 0:
            # Time-based trigger: enough wall-clock seconds elapsed.
            self._cur_time = time.time()
            if (self._cur_time - self._last_time) > self._config.save_checkpoint_seconds or force_to_save is True:
                self._last_time = self._cur_time
                return True

        return False

    def _save_ckpt(self, cb_params, model_type, force_to_save=False):
        """Save checkpoint files, pruning old ones per the retention policy."""
        # Avoid duplicate saves for the same global step (e.g. step_end then end).
        if cb_params.cur_step_num == self._last_triggered_step:
            return

        save_ckpt = self._check_save_ckpt(cb_params, force_to_save)
        step_num_in_epoch = (cb_params.cur_step_num - 1) % cb_params.batch_num + 1

        if save_ckpt:
            cur_ckpoint_file = self._prefix + "-" + str(cb_params.cur_epoch_num) + "_" \
                + str(step_num_in_epoch) + ".ckpt"
            # update checkpoint file list.
            self._manager.update_ckpoint_filelist(self._directory, self._prefix)
            # keep checkpoint files number equal max number.
            if self._config.keep_checkpoint_max and 0 < self._config.keep_checkpoint_max <= self._manager.ckpoint_num:
                self._manager.remove_oldest_ckpoint_file()
            elif self._config.keep_checkpoint_per_n_minutes and self._config.keep_checkpoint_per_n_minutes > 0:
                self._cur_time_for_keep = time.time()
                if (self._cur_time_for_keep - self._last_time_for_keep) \
                        < self._config.keep_checkpoint_per_n_minutes * 60:
                    self._manager.keep_one_ckpoint_per_minutes(self._config.keep_checkpoint_per_n_minutes,
                                                               self._cur_time_for_keep)

            # generate the new checkpoint file and rename it.
            # The save goes to a pid-tagged temp file first, then is moved to
            # the final name, so readers never see a partially-written file.
            global _save_dir
            _save_dir = self._directory
            cur_file = os.path.join(self._directory, cur_ckpoint_file)
            tmp_ckpt_file_name_for_cur_process = str(os.getpid()) + "-" + 'parameters.ckpt'
            gen_file = os.path.join(_save_dir, tmp_ckpt_file_name_for_cur_process)
            self._last_time_for_keep = time.time()
            self._last_triggered_step = cb_params.cur_step_num

            # Under GE the checkpoint graph must be executed before saving.
            if context.get_context("enable_ge"):
                set_cur_net(cb_params.train_network)
                cb_params.train_network.exec_checkpoint_graph()

            _exec_save_checkpoint(cb_params.train_network, gen_file, model_type, self._config.integrated_save)

            if os.path.exists(gen_file):
                shutil.move(gen_file, cur_file)
            self._latest_ckpt_file_name = cur_file

    @property
    def latest_ckpt_file_name(self):
        """Return the latest checkpoint path and file name."""
        return self._latest_ckpt_file_name
  260. class CheckpointManager:
  261. """Manage checkpoint files according to train_config of checkpoint."""
  262. def __init__(self):
  263. self._ckpoint_filelist = []
  264. @property
  265. def ckpoint_filelist(self):
  266. """Get all the related checkpoint files managed here."""
  267. return self._ckpoint_filelist
  268. @property
  269. def ckpoint_num(self):
  270. """Get the number of the related checkpoint files managed here."""
  271. return len(self._ckpoint_filelist)
  272. def update_ckpoint_filelist(self, directory, prefix):
  273. """Update the checkpoint file list."""
  274. self._ckpoint_filelist = []
  275. files = os.listdir(directory)
  276. for filename in files:
  277. if os.path.splitext(filename)[-1] == ".ckpt" and filename.startswith(prefix):
  278. mid_name = filename[len(prefix):-5]
  279. flag = True
  280. for char in mid_name:
  281. if char.isalpha():
  282. flag = False
  283. if flag:
  284. self._ckpoint_filelist.append(directory + '/' + filename)
  285. def remove_ckpoint_file(self, file_name):
  286. """Remove the specified checkpoint file from this checkpoint manager and also from the directory."""
  287. try:
  288. os.chmod(file_name, stat.S_IWRITE)
  289. os.remove(file_name)
  290. self._ckpoint_filelist.remove(file_name)
  291. except OSError:
  292. logger.warning("OSError, failed to remove the older ckpt file %s.", file_name)
  293. except ValueError:
  294. logger.warning("ValueError, failed to remove the older ckpt file %s.", file_name)
  295. def remove_oldest_ckpoint_file(self):
  296. """Remove the oldest checkpoint file from this checkpoint manager and also from the directory."""
  297. ckpoint_files = sorted(self._ckpoint_filelist, key=os.path.getmtime)
  298. self.remove_ckpoint_file(ckpoint_files[0])
  299. def keep_one_ckpoint_per_minutes(self, minutes, cur_time):
  300. """Only keep the latest one ckpt file per minutes, remove other files generated in [last_time, cur_time]."""
  301. movs = []
  302. oldest_file = ''
  303. oldest_time = cur_time
  304. for ck_file in self._ckpoint_filelist:
  305. modify_time = os.path.getmtime(ck_file)
  306. if cur_time - modify_time < 60 * minutes:
  307. movs.append(ck_file)
  308. if modify_time < oldest_time:
  309. oldest_time = modify_time
  310. oldest_file = ck_file
  311. for mv_file in movs:
  312. if mv_file == oldest_file:
  313. continue
  314. self.remove_ckpoint_file(mv_file)