You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

callback.py 25 kB

6 years ago
6 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """Callback related classes and functions."""
  16. import os
  17. import stat
  18. import shutil
  19. import time
  20. import numpy as np
  21. import mindspore.context as context
  22. from mindspore.train.serialization import _exec_save_checkpoint, _fill_param_into_net, _save_graph
  23. from mindspore.train._utils import _make_directory
  24. from mindspore import log as logger
  25. from mindspore._checkparam import check_int_non_negative, check_bool
  26. from mindspore.common.tensor import Tensor
  27. from .summary.summary_record import _cache_summary_tensor_data
# Public API of this module.
__all__ = ["Callback", "LossMonitor", "TimeMonitor", "ModelCheckpoint", "SummaryStep", "CheckpointConfig", "RunContext"]

# Working directory at import time; used as the default checkpoint directory.
_cur_dir = os.getcwd()
# Network currently being checkpointed by the save op (set via _set_cur_net).
_cur_net = None
# Directory the latest checkpoint was written to; updated by ModelCheckpoint._save_ckpt.
_save_dir = _cur_dir
  32. class _CheckpointManager:
  33. """Manage checkpoint files according to train_config of checkpoint."""
  34. def __init__(self):
  35. self._ckpoint_filelist = []
  36. @property
  37. def ckpoint_filelist(self):
  38. """Get all the related checkpoint files managed here."""
  39. return self._ckpoint_filelist
  40. @property
  41. def ckpoint_num(self):
  42. """Get the number of the related checkpoint files managed here."""
  43. return len(self._ckpoint_filelist)
  44. def update_ckpoint_filelist(self, directory, prefix):
  45. """Update the checkpoint file list."""
  46. self._ckpoint_filelist = []
  47. files = os.listdir(directory)
  48. for filename in files:
  49. if os.path.splitext(filename)[-1] == ".ckpt" and filename.startswith(prefix):
  50. mid_name = filename[len(prefix):-5]
  51. flag = True
  52. for char in mid_name:
  53. if char.isalpha():
  54. flag = False
  55. if flag:
  56. self._ckpoint_filelist.append(directory + '/' + filename)
  57. def remove_ckpoint_file(self, file_name):
  58. """Remove the specified checkpoint file from this checkpoint manager and also from the directory."""
  59. try:
  60. os.chmod(file_name, stat.S_IWRITE)
  61. os.remove(file_name)
  62. self._ckpoint_filelist.remove(file_name)
  63. except OSError:
  64. logger.warning("OSError, failed to remove the older ckpt file %s.", file_name)
  65. except ValueError:
  66. logger.warning("ValueError, failed to remove the older ckpt file %s.", file_name)
  67. def remove_oldest_ckpoint_file(self):
  68. """Remove the oldest checkpoint file from this checkpoint manager and also from the directory."""
  69. ckpoint_files = sorted(self._ckpoint_filelist, key=os.path.getmtime)
  70. self.remove_ckpoint_file(ckpoint_files[0])
  71. def keep_one_ckpoint_per_minutes(self, minutes, cur_time):
  72. """Only keep the latest one ckpt file per minutes, remove other files generated in [last_time, cur_time]."""
  73. movs = []
  74. oldest_file = ''
  75. oldest_time = cur_time
  76. for ck_file in self._ckpoint_filelist:
  77. modify_time = os.path.getmtime(ck_file)
  78. if cur_time - modify_time < 60 * minutes:
  79. movs.append(ck_file)
  80. if modify_time < oldest_time:
  81. oldest_time = modify_time
  82. oldest_file = ck_file
  83. for mv_file in movs:
  84. if mv_file == oldest_file:
  85. continue
  86. self.remove_ckpoint_file(mv_file)
  87. def _check_file_name_prefix(file_name_prefix):
  88. """
  89. Check file name valid or not.
  90. File name can't include '/'. This file name naming convention only apply to Linux.
  91. """
  92. if not isinstance(file_name_prefix, str) or file_name_prefix.find('/') >= 0:
  93. return False
  94. return True
  95. def _chg_ckpt_file_name_if_same_exist(directory, prefix):
  96. """Check if there is a file with the same name."""
  97. files = os.listdir(directory)
  98. suffix_num = 0
  99. pre_len = len(prefix)
  100. for filename in files:
  101. name_ext = os.path.splitext(filename)
  102. if name_ext[-1] != ".ckpt":
  103. continue
  104. # find same prefix file
  105. if filename.find(prefix) == 0 and not filename[pre_len].isalpha():
  106. # add the max suffix + 1
  107. index = filename[pre_len:].find("-")
  108. if index == 0:
  109. suffix_num = max(suffix_num, 1)
  110. elif index != -1:
  111. num = filename[pre_len+1:pre_len+index]
  112. if num.isdigit():
  113. suffix_num = max(suffix_num, int(num)+1)
  114. if suffix_num != 0:
  115. prefix = prefix + "_" + str(suffix_num)
  116. return prefix
class CheckpointConfig:
    """
    The config for model checkpoint.

    Args:
        save_checkpoint_steps (int): Steps to save checkpoint. Default: 1.
        save_checkpoint_seconds (int): Seconds to save checkpoint. Default: 0.
            Can't be used with save_checkpoint_steps at the same time.
        keep_checkpoint_max (int): Maximum number of checkpoint files to keep. Default: 5.
        keep_checkpoint_per_n_minutes (int): Keep one checkpoint every n minutes. Default: 0.
            Can't be used with keep_checkpoint_max at the same time.
        integrated_save (bool): Whether to apply integrated save in automatic model parallel scene. Default: True.
            Integrated save function is only supported in automatic parallel scene, not supported in manual parallel.

    Raises:
        ValueError: If the input_param is None or 0.

    Examples:
        >>> config = CheckpointConfig()
        >>> ckpoint_cb = ModelCheckpoint(prefix="ck_prefix", directory='./', config=config)
        >>> model.train(10, dataset, callbacks=ckpoint_cb)
    """
    def __init__(self,
                 save_checkpoint_steps=1,
                 save_checkpoint_seconds=0,
                 keep_checkpoint_max=5,
                 keep_checkpoint_per_n_minutes=0,
                 integrated_save=True):

        # At least one trigger/retention setting must be provided.
        if not save_checkpoint_steps and not save_checkpoint_seconds and \
                not keep_checkpoint_max and not keep_checkpoint_per_n_minutes:
            raise ValueError("The input_param can't be all None or 0")

        # Validate every provided (truthy) value as a non-negative integer.
        if save_checkpoint_steps:
            save_checkpoint_steps = check_int_non_negative(save_checkpoint_steps)
        if save_checkpoint_seconds:
            save_checkpoint_seconds = check_int_non_negative(save_checkpoint_seconds)
        if keep_checkpoint_max:
            keep_checkpoint_max = check_int_non_negative(keep_checkpoint_max)
        if keep_checkpoint_per_n_minutes:
            keep_checkpoint_per_n_minutes = check_int_non_negative(keep_checkpoint_per_n_minutes)

        self._save_checkpoint_steps = save_checkpoint_steps
        self._save_checkpoint_seconds = save_checkpoint_seconds
        # Step-based saving takes precedence over time-based saving.
        if self._save_checkpoint_steps and self._save_checkpoint_steps > 0:
            self._save_checkpoint_seconds = None

        self._keep_checkpoint_max = keep_checkpoint_max
        self._keep_checkpoint_per_n_minutes = keep_checkpoint_per_n_minutes
        # Max-count retention takes precedence over per-n-minutes retention;
        # when neither is set, fall back to keeping exactly one checkpoint.
        if self._keep_checkpoint_max and self._keep_checkpoint_max > 0:
            self._keep_checkpoint_per_n_minutes = None
        else:
            if not self._keep_checkpoint_per_n_minutes or self._keep_checkpoint_per_n_minutes == 0:
                self._keep_checkpoint_max = 1

        self._integrated_save = check_bool(integrated_save)

    @property
    def save_checkpoint_steps(self):
        """Get the value of _save_checkpoint_steps."""
        return self._save_checkpoint_steps

    @property
    def save_checkpoint_seconds(self):
        """Get the value of _save_checkpoint_seconds."""
        return self._save_checkpoint_seconds

    @property
    def keep_checkpoint_max(self):
        """Get the value of _keep_checkpoint_max."""
        return self._keep_checkpoint_max

    @property
    def keep_checkpoint_per_n_minutes(self):
        """Get the value of _keep_checkpoint_per_n_minutes."""
        return self._keep_checkpoint_per_n_minutes

    @property
    def integrated_save(self):
        """Get the value of _integrated_save."""
        return self._integrated_save

    def get_checkpoint_policy(self):
        """Get the policy of checkpoint as a plain dict."""
        checkpoint_policy = {'save_checkpoint_steps': self._save_checkpoint_steps,
                             'save_checkpoint_seconds': self._save_checkpoint_seconds,
                             'keep_checkpoint_max': self._keep_checkpoint_max,
                             'keep_checkpoint_per_n_minutes': self._keep_checkpoint_per_n_minutes}
        return checkpoint_policy
  192. def _set_cur_net(net):
  193. """
  194. Set current net for which we are using to save checkpoint.
  195. Args:
  196. net (Cell): train network
  197. """
  198. global _cur_net
  199. _cur_net = net
  200. def _checkpoint_cb_for_save_op(parameter_list):
  201. """
  202. The checkpoint callback function for MindSpore.
  203. Will be executed by checkpoint save op.
  204. Args:
  205. parameter_list (list): Format is like [{"name",name},{"data",value}] and value type is Tensor.
  206. Returns:
  207. bool, true: means save checkpoint success.
  208. """
  209. if _cur_net is None:
  210. logger.warning("_cur_net is None. parameters are not updated.")
  211. return False
  212. logger.info("update parameters in the net.")
  213. _fill_param_into_net(_cur_net, parameter_list)
  214. _set_cur_net(None)
  215. return True
  216. def _summary_cb_for_save_op(summary_list):
  217. """
  218. The summary callback function for MindSpore.
  219. Will be executed by summary op.
  220. Args:
  221. summary_list (list): Format is like [{"name": tag_name, "data": tensor},...] and value is Scalar/Tensor.
  222. Returns:
  223. bool, true: means save summary success.
  224. """
  225. ret = _cache_summary_tensor_data(summary_list)
  226. return ret
  227. def _build_callbacks(callbacks):
  228. """
  229. Contain a list of callback.
  230. Args:
  231. callbacks (list): Callback functions list, Support None, a single Callback object, or a list.
  232. Returns:
  233. List, a list of callback functions.
  234. """
  235. if callbacks:
  236. if isinstance(callbacks, tuple):
  237. raise TypeError("Callbacks cannot be a tuple. Please check it.")
  238. if not isinstance(callbacks, list):
  239. callbacks = [callbacks]
  240. else:
  241. callbacks = []
  242. excute_callbacks = []
  243. for cb in callbacks:
  244. if cb is None or not isinstance(cb, Callback):
  245. raise TypeError("Callback must inheriting base class Callback. Some callback is Wrong. Please check it.")
  246. excute_callbacks.append(cb)
  247. return _ListCallback(excute_callbacks)
class _ListCallback:
    """
    Sequential execution of callback functions.

    Execute Callback functions at certain points.

    Args:
        callbacks (list): Callback functions list.
    """
    def __init__(self, callbacks):
        super(_ListCallback, self).__init__()
        self._callbacks = callbacks

    def begin(self, run_context):
        """Called once before network training."""
        for cb in self._callbacks:
            cb.begin(run_context)

    def epoch_begin(self, run_context):
        """Called before each epoch begin."""
        for cb in self._callbacks:
            cb.epoch_begin(run_context)

    def epoch_end(self, run_context):
        """Called after each epoch finished."""
        for cb in self._callbacks:
            cb.epoch_end(run_context)

    def step_begin(self, run_context):
        """Called before each step begin."""
        for cb in self._callbacks:
            cb.step_begin(run_context)

    def step_end(self, run_context):
        """Called after each step finished."""
        for cb in self._callbacks:
            cb.step_end(run_context)

    def end(self, run_context):
        """Called once after network training."""
        for cb in self._callbacks:
            cb.end(run_context)
class Callback:
    """
    Abstract base class used to build a callback function.

    Callback function will execution some operating to the current step or epoch.

    Examples:
        >>> class Print_info(Callback):
        >>>     def step_end(self, run_context):
        >>>         cb_params = run_context.original_args()
        >>>         print(cb_params.cur_epoch_num)
        >>>         print(cb_params.cur_step_num)
        >>>
        >>> print_cb = Print_info()
        >>> model.train(epoch, dataset, callbacks=print_cb)
    """
    def __init__(self):
        pass

    def begin(self, run_context):
        """
        Called once before the network executing.

        Args:
            run_context (RunContext): Include some information of the model.
        """

    def epoch_begin(self, run_context):
        """
        Called before each epoch beginning.

        Args:
            run_context (RunContext): Include some information of the model.
        """

    def epoch_end(self, run_context):
        """
        Called after each epoch finished.

        Args:
            run_context (RunContext): Include some information of the model.
        """

    def step_begin(self, run_context):
        """
        Called before each step beginning.

        Args:
            run_context (RunContext): Include some information of the model.
        """

    def step_end(self, run_context):
        """
        Called after each step finished.

        Args:
            run_context (RunContext): Include some information of the model.
        """

    def end(self, run_context):
        """
        Called once after network training.

        Args:
            run_context (RunContext): Include some information of the model.
        """
  334. class SummaryStep(Callback):
  335. """
  336. The summary callback class.
  337. Args:
  338. summary (Object): Summary recode object.
  339. flush_step (int): Number of interval steps to execute. Default: 10.
  340. """
  341. def __init__(self, summary, flush_step=10):
  342. super(SummaryStep, self).__init__()
  343. if not isinstance(flush_step, int) or isinstance(flush_step, bool) or flush_step <= 0:
  344. raise ValueError("`flush_step` should be int and greater than 0")
  345. self._summary = summary
  346. self._flush_step = flush_step
  347. def step_end(self, run_context):
  348. """
  349. Save summary.
  350. Args:
  351. run_context (RunContext): Context of the train running.
  352. """
  353. cb_params = run_context.original_args()
  354. if cb_params.cur_step_num % self._flush_step == 0:
  355. self._summary.record(cb_params.cur_step_num, cb_params.train_network)
  356. @property
  357. def summary_file_name(self):
  358. return self._summary.full_file_name
  359. class _InternalCallbackParam(dict):
  360. """Internal callback object's parameters."""
  361. def __getattr__(self, key):
  362. return self[key]
  363. def __setattr__(self, key, value):
  364. self[key] = value
  365. class RunContext:
  366. """
  367. Provides information about the model.
  368. Run call being made. Provides information about original request to model function.
  369. callback objects can stop the loop by calling request_stop() of run_context.
  370. Args:
  371. original_args (dict): Holding the related information of model etc.
  372. """
  373. def __init__(self, original_args):
  374. if not isinstance(original_args, dict):
  375. raise TypeError("The arg of RunContext should be dict type.")
  376. self._original_args = original_args
  377. self._stop_requested = False
  378. def original_args(self):
  379. """
  380. Get the _original_args object.
  381. Returns:
  382. Dict, a object holding the original arguments of model.
  383. """
  384. return self._original_args
  385. def request_stop(self):
  386. """
  387. Sets stop requested during training.
  388. Callbacks can use this function to request stop of iterations.
  389. model.train() checks whether this is called or not.
  390. """
  391. self._stop_requested = True
  392. def get_stop_requested(self):
  393. """
  394. Returns whether a stop is requested or not.
  395. Returns:
  396. bool, if true, model.train() stops iterations.
  397. """
  398. return self._stop_requested
class ModelCheckpoint(Callback):
    """
    The checkpoint callback class.

    It is called to combine with train process and save the model and network parameters after training.

    Args:
        prefix (str): Checkpoint files names prefix. Default: "CKP".
        directory (str): Folder path into which checkpoint files will be saved. Default: None.
        config (CheckpointConfig): Checkpoint strategy config. Default: None.

    Raises:
        ValueError: If the prefix is invalid.
        TypeError: If the config is not CheckpointConfig type.
    """
    def __init__(self, prefix='CKP', directory=None, config=None):
        super(ModelCheckpoint, self).__init__()
        self._latest_ckpt_file_name = ""
        self._init_time = time.time()
        self._last_time = time.time()
        self._last_time_for_keep = time.time()
        self._last_triggered_step = 0

        if _check_file_name_prefix(prefix):
            self._prefix = prefix
        else:
            raise ValueError("Prefix {} for checkpoint file name invalid, "
                             "please check and correct it and then continue.".format(prefix))

        if directory:
            self._directory = _make_directory(directory)
        else:
            # Default to the working directory captured at module import.
            self._directory = _cur_dir

        if config is None:
            self._config = CheckpointConfig()
        else:
            if not isinstance(config, CheckpointConfig):
                raise TypeError("config should be CheckpointConfig type.")
            self._config = config

        # get existing checkpoint files; the prefix is bumped with a "_<n>"
        # suffix if files with the same prefix already exist in the directory.
        self._manager = _CheckpointManager()
        self._prefix = _chg_ckpt_file_name_if_same_exist(self._directory, self._prefix)
        self._graph_saved = False

    def step_end(self, run_context):
        """
        Save the checkpoint at the end of step.

        Args:
            run_context (RunContext): Context of the train running.
        """
        cb_params = run_context.original_args()
        # save graph (only once)
        if not self._graph_saved:
            graph_file_name = os.path.join(self._directory, self._prefix + '-graph.meta')
            _save_graph(cb_params.train_network, graph_file_name)
            self._graph_saved = True
        self._save_ckpt(cb_params)

    def end(self, run_context):
        """
        Save the last checkpoint after training finished.

        Args:
            run_context (RunContext): Context of the train running.
        """
        cb_params = run_context.original_args()
        # Force a final save regardless of the configured step/second trigger.
        _to_save_last_ckpt = True
        self._save_ckpt(cb_params, _to_save_last_ckpt)

        # NOTE(review): imported locally rather than at module top level —
        # presumably to avoid an import cycle; confirm before moving it.
        from mindspore.parallel._cell_wrapper import destroy_allgather_cell
        destroy_allgather_cell()

    def _check_save_ckpt(self, cb_params, force_to_save):
        """Check whether save checkpoint files or not (step- or time-triggered)."""
        if self._config.save_checkpoint_steps and self._config.save_checkpoint_steps > 0:
            if cb_params.cur_step_num >= self._last_triggered_step + self._config.save_checkpoint_steps \
                    or force_to_save is True:
                return True
        elif self._config.save_checkpoint_seconds and self._config.save_checkpoint_seconds > 0:
            self._cur_time = time.time()
            if (self._cur_time - self._last_time) > self._config.save_checkpoint_seconds or force_to_save is True:
                self._last_time = self._cur_time
                return True

        return False

    def _save_ckpt(self, cb_params, force_to_save=False):
        """Save checkpoint files, applying the configured retention policy first."""
        # Avoid saving twice for the same step (e.g. step_end followed by end).
        if cb_params.cur_step_num == self._last_triggered_step:
            return

        save_ckpt = self._check_save_ckpt(cb_params, force_to_save)
        step_num_in_epoch = (cb_params.cur_step_num - 1) % cb_params.batch_num + 1

        if save_ckpt:
            cur_ckpoint_file = self._prefix + "-" + str(cb_params.cur_epoch_num) + "_" \
                + str(step_num_in_epoch) + ".ckpt"
            # update checkpoint file list.
            self._manager.update_ckpoint_filelist(self._directory, self._prefix)
            # keep checkpoint files number equal max number.
            if self._config.keep_checkpoint_max and 0 < self._config.keep_checkpoint_max <= self._manager.ckpoint_num:
                self._manager.remove_oldest_ckpoint_file()
            elif self._config.keep_checkpoint_per_n_minutes and self._config.keep_checkpoint_per_n_minutes > 0:
                self._cur_time_for_keep = time.time()
                if (self._cur_time_for_keep - self._last_time_for_keep) \
                        < self._config.keep_checkpoint_per_n_minutes * 60:
                    self._manager.keep_one_ckpoint_per_minutes(self._config.keep_checkpoint_per_n_minutes,
                                                               self._cur_time_for_keep)

            # generate the new checkpoint file and rename it.
            global _save_dir
            _save_dir = self._directory
            cur_file = os.path.join(self._directory, cur_ckpoint_file)
            # Write to a per-process temp name first, then move into place below.
            tmp_ckpt_file_name_for_cur_process = str(os.getpid()) + "-" + 'parameters.ckpt'
            gen_file = os.path.join(_save_dir, tmp_ckpt_file_name_for_cur_process)
            self._last_time_for_keep = time.time()
            self._last_triggered_step = cb_params.cur_step_num

            if context.get_context("enable_ge"):
                _set_cur_net(cb_params.train_network)
                cb_params.train_network.exec_checkpoint_graph()

            _exec_save_checkpoint(cb_params.train_network, gen_file, self._config.integrated_save)

            if os.path.exists(gen_file):
                shutil.move(gen_file, cur_file)
            self._latest_ckpt_file_name = cur_file

    @property
    def latest_ckpt_file_name(self):
        """Return the latest checkpoint path and file name."""
        return self._latest_ckpt_file_name
  512. class LossMonitor(Callback):
  513. """
  514. Monitor the loss in training.
  515. If the loss is NAN or INF, it will terminate training.
  516. Note:
  517. If per_print_times is 0 do not print loss.
  518. Args:
  519. per_print_times (int): Print loss every times. Default: 1.
  520. Raises:
  521. ValueError: If print_step is not int or less than zero.
  522. """
  523. def __init__(self, per_print_times=1):
  524. super(LossMonitor, self).__init__()
  525. if not isinstance(per_print_times, int) or per_print_times < 0:
  526. raise ValueError("print_step must be int and >= 0.")
  527. self._per_print_times = per_print_times
  528. def step_end(self, run_context):
  529. cb_params = run_context.original_args()
  530. loss = cb_params.net_outputs
  531. if isinstance(loss, (tuple, list)):
  532. if isinstance(loss[0], Tensor) and isinstance(loss[0].asnumpy(), np.ndarray):
  533. loss = loss[0]
  534. if isinstance(loss, Tensor) and isinstance(loss.asnumpy(), np.ndarray):
  535. loss = np.mean(loss.asnumpy())
  536. cur_step_in_epoch = (cb_params.cur_step_num - 1) % cb_params.batch_num + 1
  537. if isinstance(loss, float) and (np.isnan(loss) or np.isinf(loss)):
  538. raise ValueError("epoch: {} step: {}. Invalid loss, terminating training."
  539. .format(cb_params.cur_epoch_num, cur_step_in_epoch))
  540. if self._per_print_times != 0 and cb_params.cur_step_num % self._per_print_times == 0:
  541. print("epoch: %s step: %s, loss is %s" % (cb_params.cur_epoch_num, cur_step_in_epoch, loss), flush=True)
  542. class TimeMonitor(Callback):
  543. """Time Monitor."""
  544. def __init__(self, data_size):
  545. super(TimeMonitor, self).__init__()
  546. self.data_size = data_size
  547. def epoch_begin(self, run_context):
  548. self.epoch_time = time.time()
  549. def epoch_end(self, run_context):
  550. epoch_mseconds = (time.time() - self.epoch_time) * 1000
  551. per_step_mseconds = epoch_mseconds / self.data_size
  552. print("epoch time: {0}, per step time: {1}".format(epoch_mseconds, per_step_mseconds), flush=True)
  553. def step_begin(self, run_context):
  554. self.step_time = time.time()
  555. def step_end(self, run_context):
  556. step_mseconds = (time.time() - self.step_time) * 1000
  557. print('step time', step_mseconds, flush=True)