You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

tutorial_9_callback.rst 3.0 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667
  1. ==============================================================================
  2. Callback 教程
  3. ==============================================================================
  4. 在训练时,我们常常要使用trick来提高模型的性能(如调节学习率),或者要打印训练中的信息。
  5. 这里我们提供Callback类,在Trainer中插入代码,完成一些自定义的操作。
  6. 我们使用和 :doc:`/user/quickstart` 中一样的任务来进行详细的介绍。
  7. 给出一段评价性文字,预测其情感倾向是积极(label=1)、消极(label=0)还是中性(label=2),使用 :class:`~fastNLP.Trainer` 和 :class:`~fastNLP.Tester` 来进行快速训练和测试。
  8. 关于数据处理,Loss和Optimizer的选择可以看其他教程,这里仅在训练时加入学习率衰减。
  9. ---------------------
  10. Callback的构建和使用
  11. ---------------------
  12. 创建Callback
  13. 我们可以继承fastNLP :class:`~fastNLP.Callback` 类来定义自己的Callback。
  14. 这里我们实现一个让学习率线性衰减的Callback。
  15. .. code-block:: python
  16. import fastNLP
  17. class LRDecay(fastNLP.Callback):
  18. def __init__(self):
  19. super(MyCallback, self).__init__()
  20. self.base_lrs = []
  21. self.delta = []
  22. def on_train_begin(self):
  23. # 初始化,仅训练开始时调用
  24. self.base_lrs = [pg['lr'] for pg in self.optimizer.param_groups]
  25. self.delta = [float(lr) / self.n_epochs for lr in self.base_lrs]
  26. def on_epoch_end(self):
  27. # 每个epoch结束时,更新学习率
  28. ep = self.epoch
  29. lrs = [lr - d * ep for lr, d in zip(self.base_lrs, self.delta)]
  30. self.change_lr(lrs)
  31. def change_lr(self, lrs):
  32. for pg, lr in zip(self.optimizer.param_groups, lrs):
  33. pg['lr'] = lr
  34. 这里,:class:`~fastNLP.Callback` 中所有以 ``on_`` 开头的类方法会在 :class:`~fastNLP.Trainer` 的训练中在特定时间调用。
  35. 如 on_train_begin() 会在训练开始时被调用,on_epoch_end() 会在每个 epoch 结束时调用。
  36. 具体有哪些类方法,参见文档。
  37. 另外,为了使用方便,可以在 :class:`~fastNLP.Callback` 内部访问 :class:`~fastNLP.Trainer` 中的属性,如 optimizer, epoch, step,分别对应训练时的优化器,当前epoch数,和当前的总step数。
  38. 具体可访问的属性,参见文档。
  39. 使用Callback
  40. 在定义好 :class:`~fastNLP.Callback` 之后,就能将它传入Trainer的 ``callbacks`` 参数,在实际训练时使用。
  41. .. code-block:: python
  42. """
  43. 数据预处理,模型定义等等
  44. """
  45. trainer = fastNLP.Trainer(
  46. model=model, train_data=train_data, dev_data=dev_data,
  47. optimizer=optimizer, metrics=metrics,
  48. batch_size=10, n_epochs=100,
  49. callbacks=[LRDecay()])
  50. trainer.train()