You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

callback.py 5.7 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129
  1. import os
  2. import torch
  3. import sys
  4. from torch import nn
  5. from fastNLP.core.callback import Callback
  6. from fastNLP.core.utils import _get_model_device
  7. class MyCallback(Callback):
  8. def __init__(self, args):
  9. super(MyCallback, self).__init__()
  10. self.args = args
  11. self.real_step = 0
  12. def on_step_end(self):
  13. if self.step % self.update_every == 0 and self.step > 0:
  14. self.real_step += 1
  15. cur_lr = self.args.max_lr * 100 * min(self.real_step ** (-0.5), self.real_step * self.args.warmup_steps**(-1.5))
  16. for param_group in self.optimizer.param_groups:
  17. param_group['lr'] = cur_lr
  18. if self.real_step % 1000 == 0:
  19. self.pbar.write('Current learning rate is {:.8f}, real_step: {}'.format(cur_lr, self.real_step))
  20. def on_epoch_end(self):
  21. self.pbar.write('Epoch {} is done !!!'.format(self.epoch))
  22. def _save_model(model, model_name, save_dir, only_param=False):
  23. """ 存储不含有显卡信息的 state_dict 或 model
  24. :param model:
  25. :param model_name:
  26. :param save_dir: 保存的 directory
  27. :param only_param:
  28. :return:
  29. """
  30. model_path = os.path.join(save_dir, model_name)
  31. if not os.path.isdir(save_dir):
  32. os.makedirs(save_dir, exist_ok=True)
  33. if isinstance(model, nn.DataParallel):
  34. model = model.module
  35. if only_param:
  36. state_dict = model.state_dict()
  37. for key in state_dict:
  38. state_dict[key] = state_dict[key].cpu()
  39. torch.save(state_dict, model_path)
  40. else:
  41. _model_device = _get_model_device(model)
  42. model.cpu()
  43. torch.save(model, model_path)
  44. model.to(_model_device)
  45. class SaveModelCallback(Callback):
  46. """
  47. 由于Trainer在训练过程中只会保存最佳的模型, 该 callback 可实现多种方式的结果存储。
  48. 会根据训练开始的时间戳在 save_dir 下建立文件夹,在再文件夹下存放多个模型
  49. -save_dir
  50. -2019-07-03-15-06-36
  51. -epoch0step20{metric_key}{evaluate_performance}.pt # metric是给定的metric_key, evaluate_perfomance是性能
  52. -epoch1step40
  53. -2019-07-03-15-10-00
  54. -epoch:0step:20{metric_key}:{evaluate_performance}.pt # metric是给定的metric_key, evaluate_perfomance是性能
  55. :param str save_dir: 将模型存放在哪个目录下,会在该目录下创建以时间戳命名的目录,并存放模型
  56. :param int top: 保存dev表现top多少模型。-1为保存所有模型
  57. :param bool only_param: 是否只保存模型权重
  58. :param save_on_exception: 发生exception时,是否保存一份当时的模型
  59. """
  60. def __init__(self, save_dir, top=5, only_param=False, save_on_exception=False):
  61. super().__init__()
  62. if not os.path.isdir(save_dir):
  63. raise IsADirectoryError("{} is not a directory.".format(save_dir))
  64. self.save_dir = save_dir
  65. if top < 0:
  66. self.top = sys.maxsize
  67. else:
  68. self.top = top
  69. self._ordered_save_models = [] # List[Tuple], Tuple[0]是metric, Tuple[1]是path。metric是依次变好的,所以从头删
  70. self.only_param = only_param
  71. self.save_on_exception = save_on_exception
  72. def on_train_begin(self):
  73. self.save_dir = os.path.join(self.save_dir, self.trainer.start_time)
  74. def on_valid_end(self, eval_result, metric_key, optimizer, is_better_eval):
  75. metric_value = list(eval_result.values())[0][metric_key]
  76. self._save_this_model(metric_value)
  77. def _insert_into_ordered_save_models(self, pair):
  78. # pair:(metric_value, model_name)
  79. # 返回save的模型pair与删除的模型pair. pair中第一个元素是metric的值,第二个元素是模型的名称
  80. index = -1
  81. for _pair in self._ordered_save_models:
  82. if _pair[0]>=pair[0] and self.trainer.increase_better:
  83. break
  84. if not self.trainer.increase_better and _pair[0]<=pair[0]:
  85. break
  86. index += 1
  87. save_pair = None
  88. if len(self._ordered_save_models)<self.top or (len(self._ordered_save_models)>=self.top and index!=-1):
  89. save_pair = pair
  90. self._ordered_save_models.insert(index+1, pair)
  91. delete_pair = None
  92. if len(self._ordered_save_models)>self.top:
  93. delete_pair = self._ordered_save_models.pop(0)
  94. return save_pair, delete_pair
  95. def _save_this_model(self, metric_value):
  96. name = "epoch:{}_step:{}_{}:{:.6f}.pt".format(self.epoch, self.step, self.trainer.metric_key, metric_value)
  97. save_pair, delete_pair = self._insert_into_ordered_save_models((metric_value, name))
  98. if save_pair:
  99. try:
  100. _save_model(self.model, model_name=name, save_dir=self.save_dir, only_param=self.only_param)
  101. except Exception as e:
  102. print(f"The following exception:{e} happens when saves model to {self.save_dir}.")
  103. if delete_pair:
  104. try:
  105. delete_model_path = os.path.join(self.save_dir, delete_pair[1])
  106. if os.path.exists(delete_model_path):
  107. os.remove(delete_model_path)
  108. except Exception as e:
  109. print(f"Fail to delete model {name} at {self.save_dir} caused by exception:{e}.")
  110. def on_exception(self, exception):
  111. if self.save_on_exception:
  112. name = "epoch:{}_step:{}_Exception:{}.pt".format(self.epoch, self.step, exception.__class__.__name__)
  113. _save_model(self.model, model_name=name, save_dir=self.save_dir, only_param=self.only_param)