You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

callback.py 4.1 kB

5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """
  16. Defined callback for DeepFM.
  17. """
  18. import time
  19. from mindspore.train.callback import Callback
  20. def add_write(file_path, out_str):
  21. with open(file_path, 'a+', encoding='utf-8') as file_out:
  22. file_out.write(out_str + '\n')
  23. class EvalCallBack(Callback):
  24. """
  25. Monitor the loss in training.
  26. If the loss is NAN or INF terminating training.
  27. Note
  28. If per_print_times is 0 do not print loss.
  29. """
  30. def __init__(self, model, eval_dataset, auc_metric, eval_file_path):
  31. super(EvalCallBack, self).__init__()
  32. self.model = model
  33. self.eval_dataset = eval_dataset
  34. self.aucMetric = auc_metric
  35. self.aucMetric.clear()
  36. self.eval_file_path = eval_file_path
  37. def epoch_end(self, run_context):
  38. start_time = time.time()
  39. out = self.model.eval(self.eval_dataset)
  40. eval_time = int(time.time() - start_time)
  41. time_str = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
  42. out_str = "{} EvalCallBack metric{}; eval_time{}s".format(
  43. time_str, out.values(), eval_time)
  44. print(out_str)
  45. add_write(self.eval_file_path, out_str)
  46. class LossCallBack(Callback):
  47. """
  48. Monitor the loss in training.
  49. If the loss is NAN or INF terminating training.
  50. Note
  51. If per_print_times is 0 do not print loss.
  52. Args
  53. loss_file_path (str) The file absolute path, to save as loss_file;
  54. per_print_times (int) Print loss every times. Default 1.
  55. """
  56. def __init__(self, loss_file_path, per_print_times=1):
  57. super(LossCallBack, self).__init__()
  58. if not isinstance(per_print_times, int) or per_print_times < 0:
  59. raise ValueError("print_step must be int and >= 0.")
  60. self.loss_file_path = loss_file_path
  61. self._per_print_times = per_print_times
  62. def step_end(self, run_context):
  63. cb_params = run_context.original_args()
  64. loss = cb_params.net_outputs.asnumpy()
  65. cur_step_in_epoch = (cb_params.cur_step_num - 1) % cb_params.batch_num + 1
  66. cur_num = cb_params.cur_step_num
  67. if self._per_print_times != 0 and cur_num % self._per_print_times == 0:
  68. with open(self.loss_file_path, "a+") as loss_file:
  69. time_str = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
  70. loss_file.write("{} epoch: {} step: {}, loss is {}\n".format(
  71. time_str, cb_params.cur_epoch_num, cur_step_in_epoch, loss))
  72. print("epoch: {} step: {}, loss is {}\n".format(
  73. cb_params.cur_epoch_num, cur_step_in_epoch, loss))
  74. class TimeMonitor(Callback):
  75. """
  76. Time monitor for calculating cost of each epoch.
  77. Args
  78. data_size (int) step size of an epoch.
  79. """
  80. def __init__(self, data_size):
  81. super(TimeMonitor, self).__init__()
  82. self.data_size = data_size
  83. def epoch_begin(self, run_context):
  84. self.epoch_time = time.time()
  85. def epoch_end(self, run_context):
  86. epoch_mseconds = (time.time() - self.epoch_time) * 1000
  87. per_step_mseconds = epoch_mseconds / self.data_size
  88. print("epoch time: {0}, per step time: {1}".format(epoch_mseconds, per_step_mseconds), flush=True)
  89. def step_begin(self, run_context):
  90. self.step_time = time.time()
  91. def step_end(self, run_context):
  92. step_mseconds = (time.time() - self.step_time) * 1000
  93. print(f"step time {step_mseconds}", flush=True)