# Copyright 2020 Huawei Technologies Co., Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # less required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================ import time import numpy as np from mindspore.train.callback import Callback from mindspore.common.tensor import Tensor class StepLossTimeMonitor(Callback): def __init__(self, batch_size, per_print_times=1): super(StepLossTimeMonitor, self).__init__() if not isinstance(per_print_times, int) or per_print_times < 0: raise ValueError("print_step must be int and >= 0.") self._per_print_times = per_print_times self.batch_size = batch_size def step_begin(self, run_context): self.step_time = time.time() def step_end(self, run_context): step_seconds = time.time() - self.step_time step_fps = self.batch_size*1.0/step_seconds cb_params = run_context.original_args() loss = cb_params.net_outputs if isinstance(loss, (tuple, list)): if isinstance(loss[0], Tensor) and isinstance(loss[0].asnumpy(), np.ndarray): loss = loss[0] if isinstance(loss, Tensor) and isinstance(loss.asnumpy(), np.ndarray): loss = np.mean(loss.asnumpy()) cur_step_in_epoch = (cb_params.cur_step_num - 1) % cb_params.batch_num + 1 if isinstance(loss, float) and (np.isnan(loss) or np.isinf(loss)): raise ValueError("epoch: {} step: {}. Invalid loss, terminating training.".format( cb_params.cur_epoch_num, cur_step_in_epoch)) if self._per_print_times != 0 and cb_params.cur_step_num % self._per_print_times == 0: # TEST print("step: %s, loss is %s, fps is %s" % (cur_step_in_epoch, loss, step_fps), flush=True)