You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

Callback.py 4.7 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135
  1. #!/usr/bin/python
  2. # -*- coding: utf-8 -*-
  3. # __author__="Danqing Wang"
  4. #
  5. # Licensed under the Apache License, Version 2.0 (the "License");
  6. # you may not use this file except in compliance with the License.
  7. # You may obtain a copy of the License at
  8. #
  9. # http://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing, software
  12. # distributed under the License is distributed on an "AS IS" BASIS,
  13. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. # See the License for the specific language governing permissions and
  15. # limitations under the License.
  16. # ==============================================================================
  17. import os
  18. import sys
  19. import time
  20. import numpy as np
  21. import torch
  22. from fastNLP.core.const import Const
  23. from fastNLP.io.model_io import ModelSaver
  24. from fastNLP.core.callback import Callback, EarlyStopError
  25. from tools.logger import *
  26. class TrainCallback(Callback):
  27. def __init__(self, hps, patience=3, quit_all=True):
  28. super().__init__()
  29. self._hps = hps
  30. self.patience = patience
  31. self.wait = 0
  32. if type(quit_all) != bool:
  33. raise ValueError("In KeyBoardInterrupt, quit_all arguemnt must be a bool.")
  34. self.quit_all = quit_all
  35. def on_epoch_begin(self):
  36. self.epoch_start_time = time.time()
  37. # def on_loss_begin(self, batch_y, predict_y):
  38. # """
  39. #
  40. # :param batch_y: dict
  41. # input_len: [batch, N]
  42. # :param predict_y: dict
  43. # p_sent: [batch, N, 2]
  44. # :return:
  45. # """
  46. # input_len = batch_y[Const.INPUT_LEN]
  47. # batch_y[Const.TARGET] = batch_y[Const.TARGET] * ((1 - input_len) * -100)
  48. # # predict_y["p_sent"] = predict_y["p_sent"] * input_len.unsqueeze(-1)
  49. # # logger.debug(predict_y["p_sent"][0:5,:,:])
  50. def on_backward_begin(self, loss):
  51. """
  52. :param loss: []
  53. :return:
  54. """
  55. if not (np.isfinite(loss.data)).numpy():
  56. logger.error("train Loss is not finite. Stopping.")
  57. logger.info(loss)
  58. for name, param in self.model.named_parameters():
  59. if param.requires_grad:
  60. logger.info(name)
  61. logger.info(param.grad.data.sum())
  62. raise Exception("train Loss is not finite. Stopping.")
  63. def on_backward_end(self):
  64. if self._hps.grad_clip:
  65. torch.nn.utils.clip_grad_norm_(self.model.parameters(), self._hps.max_grad_norm)
  66. def on_epoch_end(self):
  67. logger.info(' | end of epoch {:3d} | time: {:5.2f}s | '
  68. .format(self.epoch, (time.time() - self.epoch_start_time)))
  69. def on_valid_begin(self):
  70. self.valid_start_time = time.time()
  71. def on_valid_end(self, eval_result, metric_key, optimizer, is_better_eval):
  72. logger.info(' | end of valid {:3d} | time: {:5.2f}s | '
  73. .format(self.epoch, (time.time() - self.valid_start_time)))
  74. # early stop
  75. if not is_better_eval:
  76. if self.wait == self.patience:
  77. train_dir = os.path.join(self._hps.save_root, "train")
  78. save_file = os.path.join(train_dir, "earlystop.pkl")
  79. saver = ModelSaver(save_file)
  80. saver.save_pytorch(self.model)
  81. logger.info('[INFO] Saving early stop model to %s', save_file)
  82. raise EarlyStopError("Early stopping raised.")
  83. else:
  84. self.wait += 1
  85. else:
  86. self.wait = 0
  87. # lr descent
  88. if self._hps.lr_descent:
  89. new_lr = max(5e-6, self._hps.lr / (self.epoch + 1))
  90. for param_group in list(optimizer.param_groups):
  91. param_group['lr'] = new_lr
  92. logger.info("[INFO] The learning rate now is %f", new_lr)
  93. def on_exception(self, exception):
  94. if isinstance(exception, KeyboardInterrupt):
  95. logger.error("[Error] Caught keyboard interrupt on worker. Stopping supervisor...")
  96. train_dir = os.path.join(self._hps.save_root, "train")
  97. save_file = os.path.join(train_dir, "earlystop.pkl")
  98. saver = ModelSaver(save_file)
  99. saver.save_pytorch(self.model)
  100. logger.info('[INFO] Saving early stop model to %s', save_file)
  101. if self.quit_all is True:
  102. sys.exit(0) # 直接退出程序
  103. else:
  104. pass
  105. else:
  106. raise exception # 抛出陌生Error