You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

callbacks.py 2.6 kB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465
  1. from fastNLP.core.callback import Callback
  2. import torch
  3. from torch import nn
  class OptimizerCallback(Callback):
      """Step a learning-rate scheduler once every ``update_every`` trainer steps.

      Only ``scheduler.step()`` is invoked by this callback; the optimizer passed
      to the constructor is kept on the instance but never stepped here.
      NOTE(review): the period presumably mirrors the trainer's gradient
      accumulation setting so the schedule advances once per real update — confirm.
      """

      def __init__(self, optimizer, scheduler, update_every=4):
          """
          :param optimizer: stored as ``self._optimizer``; not used by this class.
          :param scheduler: object whose ``step()`` is called every ``update_every`` steps.
          :param update_every: stepping period, measured in ``self.step`` units.
          """
          super().__init__()
          self._update_every = update_every
          self.scheduler = scheduler
          self._optimizer = optimizer

      def on_backward_end(self):
          """Advance the scheduler when the current step is a multiple of ``update_every``."""
          if self.step % self._update_every != 0:
              return
          self.scheduler.step()
  class DevCallback(Callback):
      """Run an extra evaluator at every validation and report the best epochs.

      At the start of each validation the wrapped ``tester`` is run; its result is
      treated as the "test" score. The trainer's own validation result, received in
      ``on_valid_end``, is treated as the "dev" score. When training ends, two
      summaries are printed:

      * the epoch with the best test metric, with that epoch's test and dev results;
      * the epoch the trainer flagged as better on dev, with that epoch's test and
        dev results.
      """

      def __init__(self, tester, metric_key='u_f1'):
          """
          :param tester: evaluator run at every validation. It must provide ``test()``,
              ``metrics`` and ``_format_eval_results()``. Its ``verbose`` attribute is
              set to 0 here.
          :param metric_key: key inside the first metric's result dict whose value
              decides the best test epoch.
          """
          super().__init__()
          self.tester = tester
          setattr(tester, 'verbose', 0)
          self.metric_key = metric_key

          # Set in on_valid_begin when the test metric improves; consumed in
          # on_valid_end so the dev result of that same epoch is stored.
          self.record_best = False
          # -inf so the first evaluation is always recorded, even if its metric is 0.
          self.best_eval_value = float('-inf')
          self.best_epoch = None
          self.best_eval_res = None  # test results at the best-test epoch
          self.best_dev_res = None  # dev results at the best-test epoch
          self.test_eval_res = None  # test results of the most recent validation

          # Results at the epoch the trainer reported as better on dev.
          self.dev_epoch = None
          self.best_dev_res_on_dev = None
          self.best_test_res_on_dev = None

      def on_valid_begin(self):
          """Run the tester, update best-test bookkeeping and log the result."""
          eval_res = self.tester.test()
          # Results are keyed by metric class name; only the first metric is compared.
          metric_name = self.tester.metrics[0].__class__.__name__
          metric_value = eval_res[metric_name][self.metric_key]
          if metric_value > self.best_eval_value:
              self.best_eval_value = metric_value
              self.best_epoch = self.epoch
              self.record_best = True
              self.best_eval_res = eval_res
          # Kept for on_valid_end, which pairs it with this epoch's dev result.
          self.test_eval_res = eval_res
          eval_str = "Epoch {}/{}. \n".format(self.epoch, self.n_epochs) + \
              self.tester._format_eval_results(eval_res)
          self.pbar.write(eval_str)

      def on_valid_end(self, eval_result, metric_key, optimizer, is_better_eval):
          """Store the dev result for the best-test epoch and for the best-dev epoch.

          :param eval_result: the trainer's validation (dev) result for this epoch.
          :param metric_key: unused here.
          :param optimizer: unused here.
          :param is_better_eval: True when the trainer judged this dev result better.
          """
          if self.record_best:
              self.best_dev_res = eval_result
              self.record_best = False
          if is_better_eval:
              self.best_dev_res_on_dev = eval_result
              self.best_test_res_on_dev = self.test_eval_res
              self.dev_epoch = self.epoch

      def on_train_end(self):
          """Print the best-test and best-dev summaries.

          A summary is skipped when its epoch was never recorded (e.g. no validation
          ran, or the trainer never flagged an improvement).
          """
          def fmt(res):
              # Guard against a result that was never recorded.
              return self.tester._format_eval_results(res) if res is not None else 'N/A'

          if self.best_epoch is not None:
              print("Got best test performance in epoch:{}\n Test: {}\n Dev:{}\n".format(
                  self.best_epoch, fmt(self.best_eval_res), fmt(self.best_dev_res)))
          if self.dev_epoch is not None:
              print("Got best dev performance in epoch:{}\n Test: {}\n Dev:{}\n".format(
                  self.dev_epoch, fmt(self.best_test_res_on_dev), fmt(self.best_dev_res_on_dev)))