# Copyright 2020 Huawei Technologies Co., Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================ """custom callbacks for ema and loss""" from copy import deepcopy import numpy as np from mindspore.train.callback import Callback from mindspore.common.parameter import Parameter from mindspore.train.serialization import save_checkpoint from mindspore.nn import Loss, Top1CategoricalAccuracy, Top5CategoricalAccuracy from mindspore.train.model import Model from mindspore import Tensor def load_nparray_into_net(net, array_dict): """ Loads dictionary of numpy arrays into network. Args: net (Cell): Cell network. array_dict (dict): dictionary of numpy array format model weights. """ param_not_load = [] for _, param in net.parameters_and_names(): if param.name in array_dict: new_param = array_dict[param.name] param.set_data(Parameter(Tensor(deepcopy(new_param)), name=param.name)) else: param_not_load.append(param.name) return param_not_load class EmaEvalCallBack(Callback): """ Call back that will evaluate the model and save model checkpoint at the end of training epoch. Args: network: tinynet network instance. ema_network: step-wise exponential moving average of network. eval_dataset: the evaluation daatset. decay (float): ema decay. save_epoch (int): defines how often to save checkpoint. dataset_sink_mode (bool): whether to use data sink mode. start_epoch (int): which epoch to start/resume training. """ def __init__(self, network, ema_network, eval_dataset, loss_fn, decay=0.999, save_epoch=1, dataset_sink_mode=True, start_epoch=0): self.network = network self.ema_network = ema_network self.eval_dataset = eval_dataset self.loss_fn = loss_fn self.decay = decay self.save_epoch = save_epoch self.shadow = {} self.ema_accuracy = {} self.best_ema_accuracy = 0 self.best_accuracy = 0 self.best_ema_epoch = 0 self.best_epoch = 0 self._start_epoch = start_epoch self.eval_metrics = {'Validation-Loss': Loss(), 'Top1-Acc': Top1CategoricalAccuracy(), 'Top5-Acc': Top5CategoricalAccuracy()} self.dataset_sink_mode = dataset_sink_mode def begin(self, run_context): """Initialize the EMA parameters """ for _, param in self.network.parameters_and_names(): self.shadow[param.name] = deepcopy(param.data.asnumpy()) def step_end(self, run_context): """Update the EMA parameters""" for _, param in self.network.parameters_and_names(): new_average = (1.0 - self.decay) * param.data.asnumpy().copy() + \ self.decay * self.shadow[param.name] self.shadow[param.name] = new_average def epoch_end(self, run_context): """evaluate the model and ema-model at the end of each epoch""" cb_params = run_context.original_args() cur_epoch = cb_params.cur_epoch_num + self._start_epoch - 1 save_ckpt = (cur_epoch % self.save_epoch == 0) load_nparray_into_net(self.ema_network, self.shadow) model = Model(self.network, loss_fn=self.loss_fn, metrics=self.eval_metrics) model_ema = Model(self.ema_network, loss_fn=self.loss_fn, metrics=self.eval_metrics) acc = model.eval( self.eval_dataset, dataset_sink_mode=self.dataset_sink_mode) ema_acc = model_ema.eval( self.eval_dataset, dataset_sink_mode=self.dataset_sink_mode) print("Model Accuracy:", acc) print("EMA-Model Accuracy:", ema_acc) output = [{"name": k, "data": Tensor(v)} for k, v in self.shadow.items()] self.ema_accuracy[cur_epoch] = ema_acc["Top1-Acc"] if self.best_ema_accuracy < ema_acc["Top1-Acc"]: self.best_ema_accuracy = ema_acc["Top1-Acc"] self.best_ema_epoch = cur_epoch save_checkpoint(output, "ema_best.ckpt") if self.best_accuracy < acc["Top1-Acc"]: self.best_accuracy = acc["Top1-Acc"] self.best_epoch = cur_epoch print("Best Model Accuracy: %s, at epoch %s" % (self.best_accuracy, self.best_epoch)) print("Best EMA-Model Accuracy: %s, at epoch %s" % (self.best_ema_accuracy, self.best_ema_epoch)) if save_ckpt: # Save the ema_model checkpoints ckpt = "{}-{}.ckpt".format("ema", cur_epoch) save_checkpoint(output, ckpt) save_checkpoint(output, "ema_last.ckpt") # Save the model checkpoints save_checkpoint(cb_params.train_network, "last.ckpt") print("Top 10 EMA-Model Accuracies: ") count = 0 for epoch in sorted(self.ema_accuracy, key=self.ema_accuracy.get, reverse=True): if count == 10: break print("epoch: %s, Top-1: %s)" % (epoch, self.ema_accuracy[epoch])) count += 1 class LossMonitor(Callback): """ Monitor the loss in training. If the loss is NAN or INF, it will terminate training. Note: If per_print_times is 0, do not print loss. Args: lr_array (numpy.array): scheduled learning rate. total_epochs (int): Total number of epochs for training. per_print_times (int): Print the loss every time. Default: 1. start_epoch (int): which epoch to start, used when resume from a certain epoch. Raises: ValueError: If print_step is not an integer or less than zero. """ def __init__(self, lr_array, total_epochs, per_print_times=1, start_epoch=0): super(LossMonitor, self).__init__() if not isinstance(per_print_times, int) or per_print_times < 0: raise ValueError("print_step must be int and >= 0.") self._per_print_times = per_print_times self._lr_array = lr_array self._total_epochs = total_epochs self._start_epoch = start_epoch def step_end(self, run_context): """log epoch, step, loss and learning rate""" cb_params = run_context.original_args() loss = cb_params.net_outputs cur_epoch_num = cb_params.cur_epoch_num + self._start_epoch - 1 if isinstance(loss, (tuple, list)): if isinstance(loss[0], Tensor) and isinstance(loss[0].asnumpy(), np.ndarray): loss = loss[0] if isinstance(loss, Tensor) and isinstance(loss.asnumpy(), np.ndarray): loss = np.mean(loss.asnumpy()) global_step = cb_params.cur_step_num - 1 cur_step_in_epoch = global_step % cb_params.batch_num + 1 if isinstance(loss, float) and (np.isnan(loss) or np.isinf(loss)): raise ValueError("epoch: {} step: {}. Invalid loss, terminating training.".format( cur_epoch_num, cur_step_in_epoch)) if self._per_print_times != 0 and cur_step_in_epoch % self._per_print_times == 0: print("epoch: %s/%s, step: %s/%s, loss is %s, learning rate: %s" % (cur_epoch_num, self._total_epochs, cur_step_in_epoch, cb_params.batch_num, loss, self._lr_array[global_step]), flush=True)