You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

callback.py 25 kB

6 years ago
6 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """Callback related classes and functions."""
  16. import os
  17. import stat
  18. import shutil
  19. import time
  20. import numpy as np
  21. import mindspore.context as context
  22. from mindspore.train.serialization import _exec_save_checkpoint, _fill_param_into_net, _save_graph
  23. from mindspore.train._utils import _make_directory
  24. from mindspore import log as logger
  25. from mindspore._checkparam import check_int_non_negative, check_bool
  26. from mindspore.common.tensor import Tensor
  27. from .summary.summary_record import _cache_summary_tensor_data
__all__ = ["Callback", "LossMonitor", "TimeMonitor", "ModelCheckpoint", "SummaryStep", "CheckpointConfig", "RunContext"]

# Directory the process was launched from; the default checkpoint directory.
_cur_dir = os.getcwd()
# Network cached for the GE checkpoint save op (set by _set_cur_net and used
# by _checkpoint_cb_for_save_op to fill saved parameters back into the net).
_cur_net = None
# Directory the most recent checkpoint was written to; defaults to _cur_dir.
_save_dir = _cur_dir
  32. class _CheckpointManager:
  33. """Manage checkpoint files according to train_config of checkpoint."""
  34. def __init__(self):
  35. self._ckpoint_filelist = []
  36. @property
  37. def ckpoint_filelist(self):
  38. """Get all the related checkpoint files managed here."""
  39. return self._ckpoint_filelist
  40. @property
  41. def ckpoint_num(self):
  42. """Get the number of the related checkpoint files managed here."""
  43. return len(self._ckpoint_filelist)
  44. def update_ckpoint_filelist(self, directory, prefix):
  45. """Update the checkpoint file list."""
  46. self._ckpoint_filelist = []
  47. files = os.listdir(directory)
  48. for filename in files:
  49. if os.path.splitext(filename)[-1] == ".ckpt" and filename.startswith(prefix):
  50. mid_name = filename[len(prefix):-5]
  51. flag = True
  52. for char in mid_name:
  53. if char.isalpha():
  54. flag = False
  55. if flag:
  56. self._ckpoint_filelist.append(directory + '/' + filename)
  57. def remove_ckpoint_file(self, file_name):
  58. """Remove the specified checkpoint file from this checkpoint manager and also from the directory."""
  59. try:
  60. os.chmod(file_name, stat.S_IWRITE)
  61. os.remove(file_name)
  62. self._ckpoint_filelist.remove(file_name)
  63. except OSError:
  64. logger.warning("OSError, failed to remove the older ckpt file %s.", file_name)
  65. except ValueError:
  66. logger.warning("ValueError, failed to remove the older ckpt file %s.", file_name)
  67. def remove_oldest_ckpoint_file(self):
  68. """Remove the oldest checkpoint file from this checkpoint manager and also from the directory."""
  69. ckpoint_files = sorted(self._ckpoint_filelist, key=os.path.getmtime)
  70. self.remove_ckpoint_file(ckpoint_files[0])
  71. def keep_one_ckpoint_per_minutes(self, minutes, cur_time):
  72. """Only keep the latest one ckpt file per minutes, remove other files generated in [last_time, cur_time]."""
  73. movs = []
  74. oldest_file = ''
  75. oldest_time = cur_time
  76. for ck_file in self._ckpoint_filelist:
  77. modify_time = os.path.getmtime(ck_file)
  78. if cur_time - modify_time < 60 * minutes:
  79. movs.append(ck_file)
  80. if modify_time < oldest_time:
  81. oldest_time = modify_time
  82. oldest_file = ck_file
  83. for mv_file in movs:
  84. if mv_file == oldest_file:
  85. continue
  86. self.remove_ckpoint_file(mv_file)
  87. def _check_file_name_prefix(file_name_prefix):
  88. """
  89. Check file name valid or not.
  90. File name can't include '/'. This file name naming convention only apply to Linux.
  91. """
  92. if not isinstance(file_name_prefix, str) or file_name_prefix.find('/') >= 0:
  93. return False
  94. return True
  95. def _chg_ckpt_file_name_if_same_exist(directory, prefix):
  96. """Check if there is a file with the same name."""
  97. files = os.listdir(directory)
  98. suffix_num = 0
  99. pre_len = len(prefix)
  100. for filename in files:
  101. name_ext = os.path.splitext(filename)
  102. if name_ext[-1] != ".ckpt":
  103. continue
  104. # find same prefix file
  105. if filename.find(prefix) == 0 and not filename[pre_len].isalpha():
  106. # add the max suffix + 1
  107. index = filename[pre_len:].find("-")
  108. if index == 0:
  109. suffix_num = max(suffix_num, 1)
  110. elif index != -1:
  111. num = filename[pre_len+1:pre_len+index]
  112. if num.isdigit():
  113. suffix_num = max(suffix_num, int(num)+1)
  114. if suffix_num != 0:
  115. prefix = prefix + "_" + str(suffix_num)
  116. return prefix
  117. class CheckpointConfig:
  118. """
  119. The config for model checkpoint.
  120. Note:
  121. During the training process, if dataset is transmitted through the data channel,
  122. suggest set save_checkpoint_steps be an integer multiple of loop_size.
  123. Otherwise there may be deviation in the timing of saving checkpoint.
  124. Args:
  125. save_checkpoint_steps (int): Steps to save checkpoint. Default: 1.
  126. save_checkpoint_seconds (int): Seconds to save checkpoint. Default: 0.
  127. Can't be used with save_checkpoint_steps at the same time.
  128. keep_checkpoint_max (int): Maximum step to save checkpoint. Default: 5.
  129. keep_checkpoint_per_n_minutes (int): Keep one checkpoint every n minutes. Default: 0.
  130. Can't be used with keep_checkpoint_max at the same time.
  131. integrated_save (bool): Whether to intergrated save in automatic model parallel scene. Default: True.
  132. Integrated save function is only supported in automatic parallel scene, not supported in manual parallel.
  133. Raises:
  134. ValueError: If the input_param is None or 0.
  135. Examples:
  136. >>> config = CheckpointConfig()
  137. >>> ckpoint_cb = ModelCheckpoint(prefix="ck_prefix", directory='./', config=config)
  138. >>> model.train(10, dataset, callbacks=ckpoint_cb)
  139. """
  140. def __init__(self,
  141. save_checkpoint_steps=1,
  142. save_checkpoint_seconds=0,
  143. keep_checkpoint_max=5,
  144. keep_checkpoint_per_n_minutes=0,
  145. integrated_save=True):
  146. if not save_checkpoint_steps and not save_checkpoint_seconds and \
  147. not keep_checkpoint_max and not keep_checkpoint_per_n_minutes:
  148. raise ValueError("The input_param can't be all None or 0")
  149. if save_checkpoint_steps:
  150. save_checkpoint_steps = check_int_non_negative(save_checkpoint_steps)
  151. if save_checkpoint_seconds:
  152. save_checkpoint_seconds = check_int_non_negative(save_checkpoint_seconds)
  153. if keep_checkpoint_max:
  154. keep_checkpoint_max = check_int_non_negative(keep_checkpoint_max)
  155. if keep_checkpoint_per_n_minutes:
  156. keep_checkpoint_per_n_minutes = check_int_non_negative(keep_checkpoint_per_n_minutes)
  157. self._save_checkpoint_steps = save_checkpoint_steps
  158. self._save_checkpoint_seconds = save_checkpoint_seconds
  159. if self._save_checkpoint_steps and self._save_checkpoint_steps > 0:
  160. self._save_checkpoint_seconds = None
  161. self._keep_checkpoint_max = keep_checkpoint_max
  162. self._keep_checkpoint_per_n_minutes = keep_checkpoint_per_n_minutes
  163. if self._keep_checkpoint_max and self._keep_checkpoint_max > 0:
  164. self._keep_checkpoint_per_n_minutes = None
  165. else:
  166. if not self._keep_checkpoint_per_n_minutes or self._keep_checkpoint_per_n_minutes == 0:
  167. self._keep_checkpoint_max = 1
  168. self._integrated_save = check_bool(integrated_save)
  169. @property
  170. def save_checkpoint_steps(self):
  171. """Get the value of _save_checkpoint_steps."""
  172. return self._save_checkpoint_steps
  173. @property
  174. def save_checkpoint_seconds(self):
  175. """Get the value of _save_checkpoint_seconds."""
  176. return self._save_checkpoint_seconds
  177. @property
  178. def keep_checkpoint_max(self):
  179. """Get the value of _keep_checkpoint_max."""
  180. return self._keep_checkpoint_max
  181. @property
  182. def keep_checkpoint_per_n_minutes(self):
  183. """Get the value of _keep_checkpoint_per_n_minutes."""
  184. return self._keep_checkpoint_per_n_minutes
  185. @property
  186. def integrated_save(self):
  187. """Get the value of _integrated_save."""
  188. return self._integrated_save
  189. def get_checkpoint_policy(self):
  190. """Get the policy of checkpoint."""
  191. checkpoint_policy = {'save_checkpoint_steps': self._save_checkpoint_steps,
  192. 'save_checkpoint_seconds': self._save_checkpoint_seconds,
  193. 'keep_checkpoint_max': self._keep_checkpoint_max,
  194. 'keep_checkpoint_per_n_minutes': self._keep_checkpoint_per_n_minutes}
  195. return checkpoint_policy
  196. def _set_cur_net(net):
  197. """
  198. Set current net for which we are using to save checkpoint.
  199. Args:
  200. net (Cell): train network
  201. """
  202. global _cur_net
  203. _cur_net = net
  204. def _checkpoint_cb_for_save_op(parameter_list):
  205. """
  206. The checkpoint callback function for MindSpore.
  207. Will be executed by checkpoint save op.
  208. Args:
  209. parameter_list (list): Format is like [{"name",name},{"data",value}] and value type is Tensor.
  210. Returns:
  211. bool, true: means save checkpoint success.
  212. """
  213. if _cur_net is None:
  214. logger.warning("_cur_net is None. parameters are not updated.")
  215. return False
  216. logger.info("update parameters in the net.")
  217. _fill_param_into_net(_cur_net, parameter_list)
  218. _set_cur_net(None)
  219. return True
  220. def _summary_cb_for_save_op(summary_list):
  221. """
  222. The summary callback function for MindSpore.
  223. Will be executed by summary op.
  224. Args:
  225. summary_list (list): Format is like [{"name": tag_name, "data": tensor},...] and value is Scalar/Tensor.
  226. Returns:
  227. bool, true: means save summary success.
  228. """
  229. ret = _cache_summary_tensor_data(summary_list)
  230. return ret
  231. def _build_callbacks(callbacks):
  232. """
  233. Contain a list of callback.
  234. Args:
  235. callbacks (list): Callback functions list, Support None, a single Callback object, or a list.
  236. Returns:
  237. List, a list of callback functions.
  238. """
  239. if callbacks:
  240. if isinstance(callbacks, tuple):
  241. raise TypeError("Callbacks cannot be a tuple. Please check it.")
  242. if not isinstance(callbacks, list):
  243. callbacks = [callbacks]
  244. else:
  245. callbacks = []
  246. excute_callbacks = []
  247. for cb in callbacks:
  248. if cb is None or not isinstance(cb, Callback):
  249. raise TypeError("Callback must inheriting base class Callback. Some callback is Wrong. Please check it.")
  250. excute_callbacks.append(cb)
  251. return _ListCallback(excute_callbacks)
  252. class _ListCallback:
  253. """
  254. Sequential execution of callback functions.
  255. Execute Callback functions at certain points.
  256. Args:
  257. callbacks (list): Callback functions list.
  258. """
  259. def __init__(self, callbacks):
  260. super(_ListCallback, self).__init__()
  261. self._callbacks = callbacks
  262. def begin(self, run_context):
  263. """Called once before network training."""
  264. for cb in self._callbacks:
  265. cb.begin(run_context)
  266. def epoch_begin(self, run_context):
  267. """Called before each epoch begin."""
  268. for cb in self._callbacks:
  269. cb.epoch_begin(run_context)
  270. def epoch_end(self, run_context):
  271. """Called after each epoch finished."""
  272. for cb in self._callbacks:
  273. cb.epoch_end(run_context)
  274. def step_begin(self, run_context):
  275. """Called before each epoch begin."""
  276. for cb in self._callbacks:
  277. cb.step_begin(run_context)
  278. def step_end(self, run_context):
  279. """Called after each step finished."""
  280. for cb in self._callbacks:
  281. cb.step_end(run_context)
  282. def end(self, run_context):
  283. """Called once after network training."""
  284. for cb in self._callbacks:
  285. cb.end(run_context)
class Callback:
    """
    Abstract base class used to build a callback function.

    Callback functions are executed at specific points of the training loop:
    once at the beginning and end of training, and around every epoch and
    every step. Every hook defaults to a no-op; subclasses override only the
    hooks they need.

    Examples:
        >>> class Print_info(Callback):
        >>>     def step_end(self, run_context):
        >>>         cb_params = run_context.original_args()
        >>>         print(cb_params.cur_epoch_num)
        >>>         print(cb_params.cur_step_num)
        >>>
        >>> print_cb = Print_info()
        >>> model.train(epoch, dataset, callbacks=print_cb)
    """

    def __init__(self):
        pass

    def begin(self, run_context):
        """
        Called once before the network starts executing.

        Args:
            run_context (RunContext): Include some information of the model.
        """

    def epoch_begin(self, run_context):
        """
        Called before each epoch begins.

        Args:
            run_context (RunContext): Include some information of the model.
        """

    def epoch_end(self, run_context):
        """
        Called after each epoch finished.

        Args:
            run_context (RunContext): Include some information of the model.
        """

    def step_begin(self, run_context):
        """
        Called before each step begins.

        Args:
            run_context (RunContext): Include some information of the model.
        """

    def step_end(self, run_context):
        """
        Called after each step finished.

        Args:
            run_context (RunContext): Include some information of the model.
        """

    def end(self, run_context):
        """
        Called once after network training.

        Args:
            run_context (RunContext): Include some information of the model.
        """
  338. class SummaryStep(Callback):
  339. """
  340. The summary callback class.
  341. Args:
  342. summary (Object): Summary recode object.
  343. flush_step (int): Number of interval steps to execute. Default: 10.
  344. """
  345. def __init__(self, summary, flush_step=10):
  346. super(SummaryStep, self).__init__()
  347. if not isinstance(flush_step, int) or isinstance(flush_step, bool) or flush_step <= 0:
  348. raise ValueError("`flush_step` should be int and greater than 0")
  349. self._summary = summary
  350. self._flush_step = flush_step
  351. def step_end(self, run_context):
  352. """
  353. Save summary.
  354. Args:
  355. run_context (RunContext): Context of the train running.
  356. """
  357. cb_params = run_context.original_args()
  358. if cb_params.cur_step_num % self._flush_step == 0:
  359. self._summary.record(cb_params.cur_step_num, cb_params.train_network)
  360. @property
  361. def summary_file_name(self):
  362. return self._summary.full_file_name
  363. class _InternalCallbackParam(dict):
  364. """Internal callback object's parameters."""
  365. def __getattr__(self, key):
  366. return self[key]
  367. def __setattr__(self, key, value):
  368. self[key] = value
  369. class RunContext:
  370. """
  371. Provides information about the model.
  372. Run call being made. Provides information about original request to model function.
  373. callback objects can stop the loop by calling request_stop() of run_context.
  374. Args:
  375. original_args (dict): Holding the related information of model etc.
  376. """
  377. def __init__(self, original_args):
  378. if not isinstance(original_args, dict):
  379. raise TypeError("The arg of RunContext should be dict type.")
  380. self._original_args = original_args
  381. self._stop_requested = False
  382. def original_args(self):
  383. """
  384. Get the _original_args object.
  385. Returns:
  386. Dict, a object holding the original arguments of model.
  387. """
  388. return self._original_args
  389. def request_stop(self):
  390. """
  391. Sets stop requested during training.
  392. Callbacks can use this function to request stop of iterations.
  393. model.train() checks whether this is called or not.
  394. """
  395. self._stop_requested = True
  396. def get_stop_requested(self):
  397. """
  398. Returns whether a stop is requested or not.
  399. Returns:
  400. bool, if true, model.train() stops iterations.
  401. """
  402. return self._stop_requested
class ModelCheckpoint(Callback):
    """
    The checkpoint callback class.

    It is called to combine with train process and save the model and network
    parameters after training.

    Args:
        prefix (str): Checkpoint files names prefix. Default: "CKP".
        directory (str): Folder path into which checkpoint files will be saved. Default: None.
        config (CheckpointConfig): Checkpoint strategy config. Default: None.

    Raises:
        ValueError: If the prefix is invalid.
        TypeError: If the config is not CheckpointConfig type.
    """

    def __init__(self, prefix='CKP', directory=None, config=None):
        super(ModelCheckpoint, self).__init__()
        self._latest_ckpt_file_name = ""
        self._init_time = time.time()
        self._last_time = time.time()
        self._last_time_for_keep = time.time()
        # Step number at which a checkpoint was last written.
        self._last_triggered_step = 0
        if _check_file_name_prefix(prefix):
            self._prefix = prefix
        else:
            raise ValueError("Prefix {} for checkpoint file name invalid, "
                             "please check and correct it and then continue.".format(prefix))
        if directory:
            self._directory = _make_directory(directory)
        else:
            # Default to the directory the process was started from.
            self._directory = _cur_dir
        if config is None:
            self._config = CheckpointConfig()
        else:
            if not isinstance(config, CheckpointConfig):
                raise TypeError("config should be CheckpointConfig type.")
            self._config = config
        # get existing checkpoint files
        self._manager = _CheckpointManager()
        # Rename the prefix if needed so files from a previous run are not overwritten.
        self._prefix = _chg_ckpt_file_name_if_same_exist(self._directory, self._prefix)
        self._graph_saved = False

    def step_end(self, run_context):
        """
        Save the checkpoint at the end of step.

        Args:
            run_context (RunContext): Context of the train running.
        """
        cb_params = run_context.original_args()
        # save graph (only once)
        if not self._graph_saved:
            graph_file_name = os.path.join(self._directory, self._prefix + '-graph.meta')
            _save_graph(cb_params.train_network, graph_file_name)
            self._graph_saved = True
        self._save_ckpt(cb_params)

    def end(self, run_context):
        """
        Save the last checkpoint after training finished.

        Args:
            run_context (RunContext): Context of the train running.
        """
        cb_params = run_context.original_args()
        # Force a final save regardless of the step/second policy.
        _to_save_last_ckpt = True
        self._save_ckpt(cb_params, _to_save_last_ckpt)
        # NOTE(review): local import - presumably to avoid a circular import
        # at module load time; confirm before hoisting to the file top.
        from mindspore.parallel._cell_wrapper import destroy_allgather_cell
        destroy_allgather_cell()

    def _check_save_ckpt(self, cb_params, force_to_save):
        """Check whether to save a checkpoint now, per the step- or time-based policy."""
        if self._config.save_checkpoint_steps and self._config.save_checkpoint_steps > 0:
            # Step-based policy: save once save_checkpoint_steps steps have
            # elapsed since the last save, or when forced.
            if cb_params.cur_step_num >= self._last_triggered_step + self._config.save_checkpoint_steps \
                    or force_to_save is True:
                return True
        elif self._config.save_checkpoint_seconds and self._config.save_checkpoint_seconds > 0:
            # Time-based policy: save once enough wall-clock time has passed,
            # or when forced.
            self._cur_time = time.time()
            if (self._cur_time - self._last_time) > self._config.save_checkpoint_seconds or force_to_save is True:
                self._last_time = self._cur_time
                return True
        return False

    def _save_ckpt(self, cb_params, force_to_save=False):
        """Save checkpoint files."""
        # Skip if a checkpoint was already written for this exact step.
        if cb_params.cur_step_num == self._last_triggered_step:
            return
        save_ckpt = self._check_save_ckpt(cb_params, force_to_save)
        step_num_in_epoch = (cb_params.cur_step_num - 1) % cb_params.batch_num + 1
        if save_ckpt:
            # File name pattern: "<prefix>-<epoch>_<step-in-epoch>.ckpt".
            cur_ckpoint_file = self._prefix + "-" + str(cb_params.cur_epoch_num) + "_" \
                + str(step_num_in_epoch) + ".ckpt"
            # update checkpoint file list.
            self._manager.update_ckpoint_filelist(self._directory, self._prefix)
            # keep checkpoint files number equal max number.
            if self._config.keep_checkpoint_max and 0 < self._config.keep_checkpoint_max <= self._manager.ckpoint_num:
                self._manager.remove_oldest_ckpoint_file()
            elif self._config.keep_checkpoint_per_n_minutes and self._config.keep_checkpoint_per_n_minutes > 0:
                self._cur_time_for_keep = time.time()
                if (self._cur_time_for_keep - self._last_time_for_keep) \
                        < self._config.keep_checkpoint_per_n_minutes * 60:
                    self._manager.keep_one_ckpoint_per_minutes(self._config.keep_checkpoint_per_n_minutes,
                                                               self._cur_time_for_keep)
            # generate the new checkpoint file and rename it.
            global _save_dir
            _save_dir = self._directory
            cur_file = os.path.join(self._directory, cur_ckpoint_file)
            # Write to a per-process temp file first, then move into place, so
            # a partially-written checkpoint never has the final name.
            tmp_ckpt_file_name_for_cur_process = str(os.getpid()) + "-" + 'parameters.ckpt'
            gen_file = os.path.join(_save_dir, tmp_ckpt_file_name_for_cur_process)
            self._last_time_for_keep = time.time()
            self._last_triggered_step = cb_params.cur_step_num
            if context.get_context("enable_ge"):
                # In GE mode, cache the net so the save-op callback can fill
                # parameters back into it (see _checkpoint_cb_for_save_op).
                _set_cur_net(cb_params.train_network)
                cb_params.train_network.exec_checkpoint_graph()
            _exec_save_checkpoint(cb_params.train_network, gen_file, self._config.integrated_save)
            if os.path.exists(gen_file):
                shutil.move(gen_file, cur_file)
            self._latest_ckpt_file_name = cur_file

    @property
    def latest_ckpt_file_name(self):
        """Return the latest checkpoint path and file name."""
        return self._latest_ckpt_file_name
  516. class LossMonitor(Callback):
  517. """
  518. Monitor the loss in training.
  519. If the loss is NAN or INF, it will terminate training.
  520. Note:
  521. If per_print_times is 0 do not print loss.
  522. Args:
  523. per_print_times (int): Print loss every times. Default: 1.
  524. Raises:
  525. ValueError: If print_step is not int or less than zero.
  526. """
  527. def __init__(self, per_print_times=1):
  528. super(LossMonitor, self).__init__()
  529. if not isinstance(per_print_times, int) or per_print_times < 0:
  530. raise ValueError("print_step must be int and >= 0.")
  531. self._per_print_times = per_print_times
  532. def step_end(self, run_context):
  533. cb_params = run_context.original_args()
  534. loss = cb_params.net_outputs
  535. if isinstance(loss, (tuple, list)):
  536. if isinstance(loss[0], Tensor) and isinstance(loss[0].asnumpy(), np.ndarray):
  537. loss = loss[0]
  538. if isinstance(loss, Tensor) and isinstance(loss.asnumpy(), np.ndarray):
  539. loss = np.mean(loss.asnumpy())
  540. cur_step_in_epoch = (cb_params.cur_step_num - 1) % cb_params.batch_num + 1
  541. if isinstance(loss, float) and (np.isnan(loss) or np.isinf(loss)):
  542. raise ValueError("epoch: {} step: {}. Invalid loss, terminating training."
  543. .format(cb_params.cur_epoch_num, cur_step_in_epoch))
  544. if self._per_print_times != 0 and cb_params.cur_step_num % self._per_print_times == 0:
  545. print("epoch: %s step: %s, loss is %s" % (cb_params.cur_epoch_num, cur_step_in_epoch, loss), flush=True)
  546. class TimeMonitor(Callback):
  547. """Time Monitor."""
  548. def __init__(self, data_size):
  549. super(TimeMonitor, self).__init__()
  550. self.data_size = data_size
  551. def epoch_begin(self, run_context):
  552. self.epoch_time = time.time()
  553. def epoch_end(self, run_context):
  554. epoch_mseconds = (time.time() - self.epoch_time) * 1000
  555. per_step_mseconds = epoch_mseconds / self.data_size
  556. print("epoch time: {0}, per step time: {1}".format(epoch_mseconds, per_step_mseconds), flush=True)