You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

_loss_monitor.py 2.4 kB

5 years ago
5 years ago
5 years ago
1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """LossMonitor Callback class."""
  16. import numpy as np
  17. from mindspore.common.tensor import Tensor
  18. from ._callback import Callback
  19. class LossMonitor(Callback):
  20. """
  21. Monitor the loss in training.
  22. If the loss is NAN or INF, it will terminate training.
  23. Note:
  24. If per_print_times is 0, do not print loss.
  25. Args:
  26. per_print_times (int): Print the loss each every time. Default: 1.
  27. Raises:
  28. ValueError: If print_step is not an integer or less than zero.
  29. """
  30. def __init__(self, per_print_times=1):
  31. super(LossMonitor, self).__init__()
  32. if not isinstance(per_print_times, int) or per_print_times < 0:
  33. raise ValueError("print_step must be int and >= 0.")
  34. self._per_print_times = per_print_times
  35. def step_end(self, run_context):
  36. cb_params = run_context.original_args()
  37. loss = cb_params.net_outputs
  38. if isinstance(loss, (tuple, list)):
  39. if isinstance(loss[0], Tensor) and isinstance(loss[0].asnumpy(), np.ndarray):
  40. loss = loss[0]
  41. if isinstance(loss, Tensor) and isinstance(loss.asnumpy(), np.ndarray):
  42. loss = np.mean(loss.asnumpy())
  43. cur_step_in_epoch = (cb_params.cur_step_num - 1) % cb_params.batch_num + 1
  44. if isinstance(loss, float) and (np.isnan(loss) or np.isinf(loss)):
  45. raise ValueError("epoch: {} step: {}. Invalid loss, terminating training.".format(
  46. cb_params.cur_epoch_num, cur_step_in_epoch))
  47. if self._per_print_times != 0 and cb_params.cur_step_num % self._per_print_times == 0:
  48. print("epoch: %s step: %s, loss is %s" % (cb_params.cur_epoch_num, cur_step_in_epoch, loss), flush=True)