
callback.py 8.0 kB

# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
  15. """custom callbacks for ema and loss"""
  16. from copy import deepcopy
  17. import numpy as np
  18. from mindspore.train.callback import Callback
  19. from mindspore.common.parameter import Parameter
  20. from mindspore.train.serialization import save_checkpoint
  21. from mindspore.nn import Loss, Top1CategoricalAccuracy, Top5CategoricalAccuracy
  22. from mindspore.train.model import Model
  23. from mindspore import Tensor
  24. def load_nparray_into_net(net, array_dict):
  25. """
  26. Loads dictionary of numpy arrays into network.
  27. Args:
  28. net (Cell): Cell network.
  29. array_dict (dict): dictionary of numpy array format model weights.
  30. """
  31. param_not_load = []
  32. for _, param in net.parameters_and_names():
  33. if param.name in array_dict:
  34. new_param = array_dict[param.name]
  35. param.set_data(Parameter(Tensor(deepcopy(new_param)), name=param.name))
  36. else:
  37. param_not_load.append(param.name)
  38. return param_not_load
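
# Usage sketch (illustrative, not part of the original file): given a network
# `net` and EMA shadow weights kept as numpy arrays,
#     shadow = {name: p.data.asnumpy() for name, p in net.parameters_and_names()}
#     missing = load_nparray_into_net(net, shadow)
# overwrites every matching parameter in place; `missing` lists the parameter
# names that were absent from the dictionary.
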
class EmaEvalCallBack(Callback):
    """
    Callback that evaluates the network and its EMA copy and saves model
    checkpoints at the end of each training epoch.

    Args:
        network: TinyNet network instance.
        ema_network: step-wise exponential moving average of ``network``.
        eval_dataset: the evaluation dataset.
        loss_fn: loss function used to build the evaluation models.
        decay (float): EMA decay rate. Default: 0.999.
        save_epoch (int): how often (in epochs) to save checkpoints. Default: 1.
        dataset_sink_mode (bool): whether to use dataset sink mode. Default: True.
        start_epoch (int): epoch from which training starts/resumes. Default: 0.
    """

    def __init__(self, network, ema_network, eval_dataset, loss_fn, decay=0.999,
                 save_epoch=1, dataset_sink_mode=True, start_epoch=0):
        super(EmaEvalCallBack, self).__init__()
        self.network = network
        self.ema_network = ema_network
        self.eval_dataset = eval_dataset
        self.loss_fn = loss_fn
        self.decay = decay
        self.save_epoch = save_epoch
        self.shadow = {}
        self.ema_accuracy = {}
        self.best_ema_accuracy = 0
        self.best_accuracy = 0
        self.best_ema_epoch = 0
        self.best_epoch = 0
        self._start_epoch = start_epoch
        self.eval_metrics = {'Validation-Loss': Loss(),
                             'Top1-Acc': Top1CategoricalAccuracy(),
                             'Top5-Acc': Top5CategoricalAccuracy()}
        self.dataset_sink_mode = dataset_sink_mode

    def begin(self, run_context):
        """Initialize the EMA shadow parameters from the current weights."""
        for _, param in self.network.parameters_and_names():
            self.shadow[param.name] = deepcopy(param.data.asnumpy())

    def step_end(self, run_context):
        """Update the EMA shadow parameters after every training step."""
        for _, param in self.network.parameters_and_names():
            new_average = (1.0 - self.decay) * param.data.asnumpy().copy() + \
                self.decay * self.shadow[param.name]
            self.shadow[param.name] = new_average

    def epoch_end(self, run_context):
        """Evaluate the model and the EMA model at the end of each epoch."""
        cb_params = run_context.original_args()
        cur_epoch = cb_params.cur_epoch_num + self._start_epoch - 1
        save_ckpt = (cur_epoch % self.save_epoch == 0)
        # Copy the EMA shadow weights into the EMA network before evaluating it.
        load_nparray_into_net(self.ema_network, self.shadow)
        model = Model(self.network, loss_fn=self.loss_fn, metrics=self.eval_metrics)
        model_ema = Model(self.ema_network, loss_fn=self.loss_fn,
                          metrics=self.eval_metrics)
        acc = model.eval(
            self.eval_dataset, dataset_sink_mode=self.dataset_sink_mode)
        ema_acc = model_ema.eval(
            self.eval_dataset, dataset_sink_mode=self.dataset_sink_mode)
        print("Model Accuracy:", acc)
        print("EMA-Model Accuracy:", ema_acc)
        output = [{"name": k, "data": Tensor(v)}
                  for k, v in self.shadow.items()]
        self.ema_accuracy[cur_epoch] = ema_acc["Top1-Acc"]
        if self.best_ema_accuracy < ema_acc["Top1-Acc"]:
            self.best_ema_accuracy = ema_acc["Top1-Acc"]
            self.best_ema_epoch = cur_epoch
            save_checkpoint(output, "ema_best.ckpt")
        if self.best_accuracy < acc["Top1-Acc"]:
            self.best_accuracy = acc["Top1-Acc"]
            self.best_epoch = cur_epoch
        print("Best Model Accuracy: %s, at epoch %s" %
              (self.best_accuracy, self.best_epoch))
        print("Best EMA-Model Accuracy: %s, at epoch %s" %
              (self.best_ema_accuracy, self.best_ema_epoch))
        if save_ckpt:
            # Save the EMA-model checkpoints.
            ckpt = "{}-{}.ckpt".format("ema", cur_epoch)
            save_checkpoint(output, ckpt)
            save_checkpoint(output, "ema_last.ckpt")
            # Save the model checkpoint.
            save_checkpoint(cb_params.train_network, "last.ckpt")
        print("Top 10 EMA-Model Accuracies:")
        count = 0
        for epoch in sorted(self.ema_accuracy, key=self.ema_accuracy.get,
                            reverse=True):
            if count == 10:
                break
            print("epoch: %s, Top-1: %s" % (epoch, self.ema_accuracy[epoch]))
            count += 1

class LossMonitor(Callback):
    """
    Monitor the loss during training.

    If the loss is NaN or INF, training is terminated.

    Note:
        If ``per_print_times`` is 0, the loss is not printed.

    Args:
        lr_array (numpy.ndarray): scheduled learning rate, indexed by global step.
        total_epochs (int): total number of epochs for training.
        per_print_times (int): print the loss every ``per_print_times`` steps.
            Default: 1.
        start_epoch (int): epoch from which training starts, used when resuming
            from a certain epoch.

    Raises:
        ValueError: If ``per_print_times`` is not an integer or is less than zero.
    """

    def __init__(self, lr_array, total_epochs, per_print_times=1, start_epoch=0):
        super(LossMonitor, self).__init__()
        if not isinstance(per_print_times, int) or per_print_times < 0:
            raise ValueError("per_print_times must be an int and >= 0.")
        self._per_print_times = per_print_times
        self._lr_array = lr_array
        self._total_epochs = total_epochs
        self._start_epoch = start_epoch

    def step_end(self, run_context):
        """Log epoch, step, loss, and learning rate."""
        cb_params = run_context.original_args()
        loss = cb_params.net_outputs
        cur_epoch_num = cb_params.cur_epoch_num + self._start_epoch - 1
        if isinstance(loss, (tuple, list)):
            if isinstance(loss[0], Tensor) and isinstance(loss[0].asnumpy(), np.ndarray):
                loss = loss[0]
        if isinstance(loss, Tensor) and isinstance(loss.asnumpy(), np.ndarray):
            loss = np.mean(loss.asnumpy())
        global_step = cb_params.cur_step_num - 1
        cur_step_in_epoch = global_step % cb_params.batch_num + 1
        if isinstance(loss, float) and (np.isnan(loss) or np.isinf(loss)):
            raise ValueError("epoch: {} step: {}. Invalid loss, terminating training.".format(
                cur_epoch_num, cur_step_in_epoch))
        if self._per_print_times != 0 and cur_step_in_epoch % self._per_print_times == 0:
            print("epoch: %s/%s, step: %s/%s, loss is %s, learning rate: %s"
                  % (cur_epoch_num, self._total_epochs, cur_step_in_epoch,
                     cb_params.batch_num, loss, self._lr_array[global_step]),
                  flush=True)
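
# ----------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the original file): wiring both
# callbacks into MindSpore's Model.train. `build_tinynet`, `train_dataset`,
# `eval_dataset`, and `lr_array` are assumed placeholders for the project's
# own model constructor, datasets, and learning-rate schedule; training sink
# mode is disabled here so that step_end fires on every step.
#
#     import mindspore.nn as nn
#     from mindspore.train.model import Model
#
#     network = build_tinynet()        # hypothetical model constructor
#     ema_network = build_tinynet()    # identical copy that receives EMA weights
#     loss_fn = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
#     opt = nn.Momentum(network.trainable_params(),
#                       learning_rate=Tensor(lr_array), momentum=0.9)
#     model = Model(network, loss_fn=loss_fn, optimizer=opt)
#     callbacks = [
#         LossMonitor(lr_array, total_epochs=450, per_print_times=100),
#         EmaEvalCallBack(network, ema_network, eval_dataset, loss_fn,
#                         decay=0.999, save_epoch=1),
#     ]
#     model.train(450, train_dataset, callbacks=callbacks,
#                 dataset_sink_mode=False)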