You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

utils.py 4.3 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129
  1. from collections.abc import Iterator
  2. from typing import Dict
  3. from fastNLP.core.callbacks import CallbackManager
  4. from .state import TrainerState
  class TrainerEventTrigger:
      """
      Event-forwarding layer between the training loop and the callback manager.

      To avoid writing calls such as 'trainer.callback_manager.on_train_begin' all over the
      training loop, the callback-triggering calls are abstracted into this separate layer
      for 'Trainer'. Some special operations can also be performed here; for example,
      `on_validate_end` is used to notify all 'CheckpointCallback' instances to save the
      model after the current step.

      Every method passes ``self`` as the first argument of the matching ``callback_manager``
      method, followed by the method's own arguments unchanged.
      """
      # Object whose ``on_*`` methods receive every event forwarded by this class.
      callback_manager: CallbackManager
      # State object; ``on_validate_end`` sets its ``save_on_this_step`` attribute.
      trainer_state: TrainerState

      def on_after_trainer_initialized(self, driver):
          """
          Forward the ``on_after_trainer_initialized`` event.

          :param driver: passed through unchanged to the callback manager.
          """
          self.callback_manager.on_after_trainer_initialized(self, driver)

      def on_sanity_check_begin(self):
          """Forward the ``on_sanity_check_begin`` event."""
          self.callback_manager.on_sanity_check_begin(self)

      def on_sanity_check_end(self, sanity_check_res):
          """
          Forward the ``on_sanity_check_end`` event.

          :param sanity_check_res: passed through unchanged to the callback manager.
          """
          self.callback_manager.on_sanity_check_end(self, sanity_check_res)

      def on_train_begin(self):
          """Forward the ``on_train_begin`` event."""
          self.callback_manager.on_train_begin(self)

      def on_train_end(self):
          """Forward the ``on_train_end`` event."""
          self.callback_manager.on_train_end(self)

      def on_train_epoch_begin(self):
          """Forward the ``on_train_epoch_begin`` event."""
          self.callback_manager.on_train_epoch_begin(self)

      def on_train_epoch_end(self):
          """Forward the ``on_train_epoch_end`` event."""
          self.callback_manager.on_train_epoch_end(self)

      def on_fetch_data_begin(self):
          """Forward the ``on_fetch_data_begin`` event."""
          self.callback_manager.on_fetch_data_begin(self)

      def on_fetch_data_end(self):
          """Forward the ``on_fetch_data_end`` event."""
          self.callback_manager.on_fetch_data_end(self)

      def on_train_batch_begin(self, batch, indices=None):
          """
          Forward the ``on_train_batch_begin`` event.

          :param batch: passed through unchanged to the callback manager.
          :param indices: passed through unchanged to the callback manager; ``None`` by default.
          """
          self.callback_manager.on_train_batch_begin(self, batch, indices)

      def on_train_batch_end(self):
          """Forward the ``on_train_batch_end`` event."""
          self.callback_manager.on_train_batch_end(self)

      def on_exception(self, exception):
          """
          Forward the ``on_exception`` event.

          :param exception: passed through unchanged to the callback manager.
          """
          self.callback_manager.on_exception(self, exception)

      def on_save_model(self):
          """Forward the ``on_save_model`` event."""
          self.callback_manager.on_save_model(self)

      def on_load_model(self):
          """Forward the ``on_load_model`` event."""
          self.callback_manager.on_load_model(self)

      def on_save_checkpoint(self) -> Dict:
          """
          Forward the ``on_save_checkpoint`` event.

          :return: whatever ``callback_manager.on_save_checkpoint`` returns.
          """
          checkpoint_states = self.callback_manager.on_save_checkpoint(self)
          return checkpoint_states

      def on_load_checkpoint(self, states):
          """
          Forward the ``on_load_checkpoint`` event.

          :param states: passed through unchanged to the callback manager.
          """
          self.callback_manager.on_load_checkpoint(self, states)

      def on_before_backward(self, outputs):
          """
          Forward the ``on_before_backward`` event.

          :param outputs: passed through unchanged to the callback manager.
          """
          self.callback_manager.on_before_backward(self, outputs)

      def on_after_backward(self):
          """Forward the ``on_after_backward`` event."""
          self.callback_manager.on_after_backward(self)

      def on_before_optimizers_step(self, optimizers):
          """
          Forward the ``on_before_optimizers_step`` event.

          :param optimizers: passed through unchanged to the callback manager.
          """
          self.callback_manager.on_before_optimizers_step(self, optimizers)

      def on_after_optimizers_step(self, optimizers):
          """
          Forward the ``on_after_optimizers_step`` event.

          :param optimizers: passed through unchanged to the callback manager.
          """
          self.callback_manager.on_after_optimizers_step(self, optimizers)

      def on_before_zero_grad(self, optimizers):
          """
          Forward the ``on_before_zero_grad`` event.

          :param optimizers: passed through unchanged to the callback manager.
          """
          self.callback_manager.on_before_zero_grad(self, optimizers)

      def on_after_zero_grad(self, optimizers):
          """
          Forward the ``on_after_zero_grad`` event.

          :param optimizers: passed through unchanged to the callback manager.
          """
          self.callback_manager.on_after_zero_grad(self, optimizers)

      def on_validate_begin(self):
          """Forward the ``on_validate_begin`` event."""
          self.callback_manager.on_validate_begin(self)

      def on_validate_end(self, results):
          """
          Set ``trainer_state.save_on_this_step`` and forward the ``on_validate_end`` event.

          :param results: passed through unchanged to the callback manager.
          """
          # The flag is raised before the callbacks run, so they can already observe it.
          state = self.trainer_state
          state.save_on_this_step = True
          self.callback_manager.on_validate_end(self, results)
  62. class _TruncatedDataLoader:
  63. def __init__(self, dataloader, num_batches: int):
  64. """
  65. 限制
  66. :param dataloader: 可迭代的 dataloader 。
  67. :param num_batches: 迭代多少个 batch 就停止。
  68. """
  69. self.dataloader = dataloader
  70. self._num_batches = min(num_batches, len(dataloader))
  71. self._count = 0
  72. def __len__(self):
  73. r"""
  74. 为了在外部调用 `len` 方法时正确地返回当前会迭代的长度;
  75. """
  76. return self._num_batches
  77. def __iter__(self):
  78. # 将初试的 `dataloader` 转换成一个 `Iterator` 的逻辑应该放在这里,即只有当外界真正的调用 iter(dataloader) 的时候才需要返回一个 Iterator;
  79. # TODO 测试一下
  80. self._iterator = iter(self.dataloader)
  81. self._count = 0
  82. return self
  83. def __next__(self):
  84. if self._count >= self._num_batches:
  85. raise StopIteration
  86. self._count += 1
  87. # 注意 dataloader 数据不足时会自己本身触发 `StopIteration`;
  88. return next(self._iterator)
  89. def __getattr__(self, item):
  90. return getattr(self.dataloader, item)