You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number; they can include dashes ('-') and can be up to 35 characters long.

callback.py 3.7 kB

4 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108
  1. # Copyright 2021 Huawei Technologies Co., Ltd
  2. #
3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """
  16. Defined callback for DeepSpeech.
  17. """
  18. import time
  19. from mindspore.train.callback import Callback
  20. from mindspore import Tensor
  21. import numpy as np
  22. class TimeMonitor(Callback):
  23. """
  24. Time monitor for calculating cost of each epoch.
  25. Args
  26. data_size (int) step size of an epoch.
  27. """
  28. def __init__(self, data_size):
  29. super(TimeMonitor, self).__init__()
  30. self.data_size = data_size
  31. def epoch_begin(self, run_context):
  32. self.epoch_time = time.time()
  33. def epoch_end(self, run_context):
  34. epoch_mseconds = (time.time() - self.epoch_time) * 1000
  35. per_step_mseconds = epoch_mseconds / self.data_size
  36. print("epoch time: {0}, per step time: {1}".format(epoch_mseconds, per_step_mseconds), flush=True)
  37. def step_begin(self, run_context):
  38. self.step_time = time.time()
  39. def step_end(self, run_context):
  40. step_mseconds = (time.time() - self.step_time) * 1000
  41. print(f"step time {step_mseconds}", flush=True)
  42. class Monitor(Callback):
  43. """
  44. Monitor loss and time.
  45. Args:
  46. lr_init (numpy array): train lr
  47. Returns:
  48. None
  49. """
  50. def __init__(self, lr_init=None):
  51. super(Monitor, self).__init__()
  52. self.lr_init = lr_init
  53. self.lr_init_len = len(lr_init)
  54. def epoch_begin(self, run_context):
  55. self.losses = []
  56. self.epoch_time = time.time()
  57. def epoch_end(self, run_context):
  58. cb_params = run_context.original_args()
  59. epoch_mseconds = (time.time() - self.epoch_time)
  60. per_step_mseconds = epoch_mseconds / cb_params.batch_num
  61. print("epoch time: {:5.3f}, per step time: {:5.3f}, avg loss: {:5.3f}".format(epoch_mseconds,
  62. per_step_mseconds,
  63. np.mean(self.losses)))
  64. def step_begin(self, run_context):
  65. self.step_time = time.time()
  66. def step_end(self, run_context):
  67. """
  68. Args:
  69. run_context:
  70. Returns:
  71. """
  72. cb_params = run_context.original_args()
  73. step_mseconds = (time.time() - self.step_time)
  74. step_loss = cb_params.net_outputs
  75. if isinstance(step_loss, (tuple, list)) and isinstance(step_loss[0], Tensor):
  76. step_loss = step_loss[0]
  77. if isinstance(step_loss, Tensor):
  78. step_loss = np.mean(step_loss.asnumpy())
  79. self.losses.append(step_loss)
  80. cur_step_in_epoch = (cb_params.cur_step_num - 1) % cb_params.batch_num
  81. print("epoch: [{:3d}/{:3d}], step:[{:5d}/{:5d}], loss:[{:5.3f}/{:5.3f}], time:[{:5.3f}], lr:[{:.9f}]".format(
  82. cb_params.cur_epoch_num -
  83. 1, cb_params.epoch_num, cur_step_in_epoch, cb_params.batch_num, step_loss,
  84. np.mean(self.losses), step_mseconds, self.lr_init[cb_params.cur_step_num - 1].asnumpy()))