You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

callback.py 3.8 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ===========================================================================
  15. """Callback."""
  16. import time
  17. from mindspore.train.callback import ModelCheckpoint
  18. from mindspore.train.callback import CheckpointConfig, Callback
  class ProgressMonitor(Callback):
      """Progress monitor: logs loss and mean throughput during training.

      A log line is emitted every ``args.log_interval`` steps and once at the end
      of every epoch. Apart from its own counters, the monitor reads
      ``args.epoch_cnt``, ``args.logger``, ``args.log_interval``,
      ``args.per_batch_size`` and ``args.steps_per_epoch``. It increments
      ``args.epoch_cnt`` at each epoch end, so that counter is shared with the
      caller.
      """

      def __init__(self, args):
          super(ProgressMonitor, self).__init__()
          self.args = args
          # time.time() timestamps of the current epoch / current timing window.
          self.epoch_start_time = 0
          self.step_start_time = 0
          # 0-based index of the last step of the most recently finished epoch.
          self.globe_step_cnt = 0
          # Number of steps already completed in the current epoch.
          self.local_step_cnt = 0
          # Not used inside this class; kept in case external code reads it.
          self.ckpt_history = []

      def begin(self, run_context):
          """Log the start of training when ``args.epoch_cnt`` is falsy (fresh run)."""
          del run_context  # unused here
          if not self.args.epoch_cnt:
              self.args.logger.info('start network train...')

      def step_begin(self, run_context):
          """Start the throughput timer at the first step of each epoch."""
          del run_context  # unused here
          if self.local_step_cnt == 0:
              self.step_start_time = time.time()

      def step_end(self, run_context):
          '''Callback when step end: log loss and throughput every ``log_interval`` steps.'''
          if self.local_step_cnt % self.args.log_interval == 0 and self.local_step_cnt > 0:
              cb_params = run_context.original_args()
              time_used = time.time() - self.step_start_time
              # Samples processed per second over the last logging window.
              fps_mean = self.args.per_batch_size * self.args.log_interval / time_used
              # 0-based global step index. Derived from epoch_cnt so it stays correct
              # after the first epoch and when epoch_cnt starts above zero; it
              # matches the "last step" index logged by epoch_end.
              cur_step = self.args.epoch_cnt * self.args.steps_per_epoch + self.local_step_cnt
              self.args.logger.info('epoch[{}], iter[{}], loss:{}, mean_wps:{:.2f} wavs/sec'.format(
                  self.args.epoch_cnt, cur_step, cb_params.net_outputs, fps_mean))
              # Restart the timing window after each log line.
              self.step_start_time = time.time()
          self.local_step_cnt += 1

      def epoch_begin(self, run_context):
          """Record the start time of the epoch for the epoch-level throughput."""
          del run_context  # unused here
          self.epoch_start_time = time.time()

      def epoch_end(self, run_context):
          '''Callback when epoch end: log the epoch summary and advance ``args.epoch_cnt``.'''
          cb_params = run_context.original_args()
          # 0-based index of the last step of the epoch that just finished.
          self.globe_step_cnt = self.args.steps_per_epoch * (self.args.epoch_cnt + 1) - 1
          time_used = time.time() - self.epoch_start_time
          fps_mean = self.args.per_batch_size * self.args.steps_per_epoch / time_used
          self.args.logger.info(
              'epoch[{}], iter[{}], loss:{}, mean_wps:{:.2f} wavs/sec'.format(self.args.epoch_cnt, self.globe_step_cnt,
                                                                              cb_params.net_outputs, fps_mean))
          self.args.epoch_cnt += 1
          self.local_step_cnt = 0

      def end(self, run_context):
          """No-op at the end of training."""
          del run_context  # unused here
  def callback_func(args, cb, prefix):
      """Build the callback list used for training.

      Args:
          args: Config object. ``rank_save_ckpt_flag`` decides whether a
              checkpoint callback is added. When it is truthy, ``max_epoch``,
              ``steps_per_epoch``, ``ckpt_interval`` and ``outputs_dir`` are
              also read.
          cb: Callback that is always placed first in the returned list.
          prefix: File-name prefix handed to ``ModelCheckpoint``.

      Returns:
          list: ``[cb]``, or ``[cb, <ModelCheckpoint>]`` when checkpoint
          saving is enabled for this rank.
      """
      if not args.rank_save_ckpt_flag:
          return [cb]

      # Cap the number of retained checkpoints at (total training steps // save interval).
      total_steps = args.max_epoch * args.steps_per_epoch
      ckpt_config = CheckpointConfig(save_checkpoint_steps=args.ckpt_interval,
                                     keep_checkpoint_max=total_steps // args.ckpt_interval)
      saver = ModelCheckpoint(config=ckpt_config, directory=args.outputs_dir, prefix=prefix)
      return [cb, saver]