You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

Callback.py 4.8 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141
  1. #!/usr/bin/python
  2. # -*- coding: utf-8 -*-
  3. # __author__="Danqing Wang"
  4. #
  5. # Licensed under the Apache License, Version 2.0 (the "License");
  6. # you may not use this file except in compliance with the License.
  7. # You may obtain a copy of the License at
  8. #
  9. # http://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing, software
  12. # distributed under the License is distributed on an "AS IS" BASIS,
  13. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. # See the License for the specific language governing permissions and
  15. # limitations under the License.
  16. # ==============================================================================
  17. import os
  18. import sys
  19. import time
  20. import numpy as np
  21. import torch
  22. from fastNLP.core.const import Const
  23. from fastNLP.io.model_io import ModelSaver
  24. from fastNLP.core.callback import Callback, EarlyStopError
  25. from fastNLP.core._logger import logger
  26. class TrainCallback(Callback):
  27. def __init__(self, hps, patience=3, quit_all=True):
  28. super().__init__()
  29. self._hps = hps
  30. self.patience = patience
  31. self.wait = 0
  32. self.train_loss = 0.0
  33. self.prev_train_avg_loss = 1000.0
  34. self.train_dir = os.path.join(self._hps.save_root, "train")
  35. if type(quit_all) != bool:
  36. raise ValueError("In KeyBoardInterrupt, quit_all arguemnt must be a bool.")
  37. self.quit_all = quit_all
  38. def on_epoch_begin(self):
  39. self.epoch_start_time = time.time()
  40. self.model.Train = True
  41. def on_backward_begin(self, loss):
  42. """
  43. :param loss: []
  44. :return:
  45. """
  46. if not (np.isfinite(loss.data)).numpy():
  47. logger.error("train Loss is not finite. Stopping.")
  48. logger.info(loss)
  49. for name, param in self.model.named_parameters():
  50. if param.requires_grad:
  51. logger.info(name)
  52. logger.info(param.grad.data.sum())
  53. raise Exception("train Loss is not finite. Stopping.")
  54. self.train_loss += loss.data
  55. def on_backward_end(self):
  56. if self._hps.grad_clip:
  57. torch.nn.utils.clip_grad_norm_(self.model.parameters(), self._hps.max_grad_norm)
  58. torch.cuda.empty_cache()
  59. def on_epoch_end(self):
  60. epoch_avg_loss = self.train_loss / self.n_steps
  61. logger.info(' | end of epoch {:3d} | time: {:5.2f}s | train loss: {:5.6f}'
  62. .format(self.epoch, (time.time() - self.epoch_start_time), epoch_avg_loss))
  63. if self.prev_train_avg_loss < epoch_avg_loss:
  64. save_file = os.path.join(self.train_dir, "earlystop.pkl")
  65. self.save_model(save_file)
  66. else:
  67. self.prev_train_avg_loss = epoch_avg_loss
  68. self.train_loss = 0.0
  69. # save epoch
  70. save_file = os.path.join(self.train_dir, "epoch_%d.pkl" % self.epoch)
  71. self.save_model(save_file)
  72. def on_valid_begin(self):
  73. self.valid_start_time = time.time()
  74. self.model.Train = False
  75. def on_valid_end(self, eval_result, metric_key, optimizer, is_better_eval):
  76. logger.info(' | end of valid {:3d} | time: {:5.2f}s | '
  77. .format(self.epoch, (time.time() - self.valid_start_time)))
  78. # early stop
  79. if not is_better_eval:
  80. if self.wait == self.patience:
  81. train_dir = os.path.join(self._hps.save_root, "train")
  82. save_file = os.path.join(train_dir, "earlystop.pkl")
  83. self.save_model(save_file)
  84. raise EarlyStopError("Early stopping raised.")
  85. else:
  86. self.wait += 1
  87. else:
  88. self.wait = 0
  89. # lr descent
  90. if self._hps.lr_descent:
  91. new_lr = max(5e-6, self._hps.lr / (self.epoch + 1))
  92. for param_group in list(optimizer.param_groups):
  93. param_group['lr'] = new_lr
  94. logger.info("[INFO] The learning rate now is %f", new_lr)
  95. def on_exception(self, exception):
  96. if isinstance(exception, KeyboardInterrupt):
  97. logger.error("[Error] Caught keyboard interrupt on worker. Stopping supervisor...")
  98. save_file = os.path.join(self.train_dir, "earlystop.pkl")
  99. self.save_model(save_file)
  100. if self.quit_all is True:
  101. sys.exit(0) # 直接退出程序
  102. else:
  103. pass
  104. else:
  105. raise exception # 抛出陌生Error
  106. def save_model(self, save_file):
  107. saver = ModelSaver(save_file)
  108. saver.save_pytorch(self.model)
  109. logger.info('[INFO] Saving model to %s', save_file)