|
- # Copyright 2020 Huawei Technologies Co., Ltd
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ===========================================================================
- """Callback."""
- import time
-
- from mindspore.train.callback import ModelCheckpoint
- from mindspore.train.callback import CheckpointConfig, Callback
-
-
class ProgressMonitor(Callback):
    """Training callback that logs network outputs (as "loss") and throughput.

    A log line is emitted every ``args.log_interval`` steps within an epoch
    and once at the end of every epoch. After each epoch the monitor
    increments ``args.epoch_cnt``, so the counter lives on the shared
    ``args`` object rather than on this callback.

    Args:
        args: run configuration. Read: ``epoch_cnt``, ``log_interval``,
            ``per_batch_size``, ``steps_per_epoch`` and ``logger`` (anything
            with an ``info(str)`` method). Written: ``epoch_cnt``.
    """

    def __init__(self, args):
        super().__init__()
        self.args = args
        self.epoch_start_time = 0
        self.step_start_time = 0
        # Set at each epoch end to the 0-based global index of that epoch's
        # last step; 0 before the first epoch finishes.
        self.globe_step_cnt = 0
        # 0-based index of the next step within the current epoch.
        self.local_step_cnt = 0
        # Steps completed since ``step_start_time`` was last reset; used as
        # the numerator of the per-window throughput.
        self._steps_since_log = 0
        # Not read or written elsewhere in this class; kept for external users.
        self.ckpt_history = []

    def begin(self, run_context):
        """Announce the start of training when starting from epoch 0."""
        del run_context  # Unused.
        if not self.args.epoch_cnt:
            self.args.logger.info('start network train...')

    def step_begin(self, run_context):
        """Start the throughput timer at the first step of each epoch."""
        del run_context  # Unused.
        if self.local_step_cnt == 0:
            self.step_start_time = time.time()
            self._steps_since_log = 0

    def step_end(self, run_context):
        """Count the finished step and log every ``args.log_interval`` steps.

        A non-positive ``log_interval`` disables per-step logging instead of
        raising ``ZeroDivisionError`` on the modulo.
        """
        self._steps_since_log += 1
        interval = self.args.log_interval
        if interval > 0 and self.local_step_cnt > 0 and self.local_step_cnt % interval == 0:
            cb_params = run_context.original_args()
            time_used = time.time() - self.step_start_time
            # Divide by the steps actually timed: the first window of an epoch
            # starts at step 0 and therefore spans ``interval + 1`` steps.
            fps_mean = self.args.per_batch_size * self._steps_since_log / time_used
            # Step k of epoch e is global step e * steps_per_epoch + k, the same
            # numbering epoch_end uses for an epoch's last step.
            cur_step = self.args.steps_per_epoch * self.args.epoch_cnt + self.local_step_cnt
            self.args.logger.info('epoch[{}], iter[{}], loss:{}, mean_wps:{:.2f} wavs/sec'.format(
                self.args.epoch_cnt, cur_step, cb_params.net_outputs, fps_mean))
            self.step_start_time = time.time()
            self._steps_since_log = 0
        self.local_step_cnt += 1

    def epoch_begin(self, run_context):
        """Record the wall-clock start time of the epoch."""
        del run_context  # Unused.
        self.epoch_start_time = time.time()

    def epoch_end(self, run_context):
        """Log whole-epoch throughput, then advance ``args.epoch_cnt``.

        Also resets the in-epoch step counter so the next epoch's first
        ``step_begin`` restarts the throughput timer.
        """
        cb_params = run_context.original_args()
        # 0-based global index of this epoch's last step.
        self.globe_step_cnt = self.args.steps_per_epoch * (self.args.epoch_cnt + 1) - 1

        time_used = time.time() - self.epoch_start_time
        fps_mean = self.args.per_batch_size * self.args.steps_per_epoch / time_used
        self.args.logger.info(
            'epoch[{}], iter[{}], loss:{}, mean_wps:{:.2f} wavs/sec'.format(
                self.args.epoch_cnt, self.globe_step_cnt, cb_params.net_outputs, fps_mean))
        self.args.epoch_cnt += 1
        self.local_step_cnt = 0

    def end(self, run_context):
        """No action is taken when training ends."""
        del run_context  # Unused.
-
-
def callback_func(args, cb, prefix):
    """Build the list of callbacks handed to training.

    Args:
        args: run configuration. Always reads ``rank_save_ckpt_flag``; when
            that flag is truthy also reads ``max_epoch``, ``steps_per_epoch``,
            ``ckpt_interval`` and ``outputs_dir``.
        cb: callback placed first in the returned list.
        prefix: checkpoint file-name prefix passed to ``ModelCheckpoint``.

    Returns:
        list: ``[cb]``, with a ``ModelCheckpoint`` appended when this rank
        saves checkpoints.
    """
    if not args.rank_save_ckpt_flag:
        return [cb]

    total_steps = args.max_epoch * args.steps_per_epoch
    # Retention limit equals the number of save points over the whole run,
    # presumably so that no checkpoint is rotated out.
    config = CheckpointConfig(save_checkpoint_steps=args.ckpt_interval,
                              keep_checkpoint_max=total_steps // args.ckpt_interval)
    saver = ModelCheckpoint(config=config, directory=args.outputs_dir, prefix=prefix)
    return [cb, saver]
|