
from typing import Union, Optional, List, Callable, Dict, Sequence, BinaryIO, IO
from functools import partial
from collections import defaultdict
import copy
from contextlib import contextmanager
from dataclasses import is_dataclass
import os
from pathlib import Path
import io

__all__ = [
    'Trainer',
]

from .loops import Loop, TrainBatchLoop
from .utils import State, TrainerState
from .utils.utils import check_evaluate_every
from .evaluator import Evaluator
from fastNLP.core.controllers.utils.utils import TrainerEventTrigger, _TruncatedDataLoader
from fastNLP.core.callbacks import Callback, CallbackManager, Events, EventsList
from fastNLP.core.callbacks.callback import _CallbackWrapper
from fastNLP.core.callbacks.callback_events import _SingleEventState
from fastNLP.core.callbacks.progress_callback import choose_progress_callback
from fastNLP.core.drivers import Driver
from fastNLP.core.drivers.utils import choose_driver
from fastNLP.core.utils import get_fn_arg_names, match_and_substitute_params, nullcontext
from fastNLP.core.utils.utils import _check_valid_parameters_number
from fastNLP.envs import rank_zero_call
from fastNLP.core.log import logger
from fastNLP.envs import FASTNLP_MODEL_FILENAME, FASTNLP_CHECKPOINT_FILENAME
from fastNLP.core.utils.exceptions import EarlyStopException


class Trainer(TrainerEventTrigger):
    _custom_callbacks: dict = defaultdict(list)

    def __init__(
            self,
            model,
            driver,
            train_dataloader,
            optimizers,
            device: Optional[Union[int, List[int], str]] = "cpu",
            n_epochs: int = 20,
            evaluate_dataloaders=None,
            batch_step_fn: Optional[Callable] = None,
            evaluate_batch_step_fn: Optional[Callable] = None,
            train_fn: Optional[str] = None,
            evaluate_fn: Optional[str] = None,
            callbacks: Union[List[Callback], Callback, None] = None,
            metrics: Optional[dict] = None,
            evaluate_every: Optional[Union[int, Callable]] = -1,
            input_mapping: Optional[Union[Callable, Dict]] = None,
            output_mapping: Optional[Union[Callable, Dict]] = None,
            model_wo_auto_param_call: bool = False,
            accumulation_steps: int = 1,
            fp16: bool = False,
            monitor: Union[str, Callable] = None,
            larger_better: bool = True,
            marker: Optional[str] = None,
            **kwargs
    ):
        r"""
        ``Trainer`` is fastNLP's dedicated controller for training models. It supports several different driver
        backends, including not only the frequently used DDP but also domestic frameworks such as jittor. The new
        fastNLP also adds a convenient callback decorator and supports user-customized training loops; with this
        trainer, users only need to implement the model itself and can hand the training-level logic entirely over
        to fastNLP.

        :param model: the model to be trained; pytorch is currently supported;
        :param driver: the concrete driver used to train the model; should be one of ["torch", "torch_ddp", ];
            training modes for domestic frameworks such as jittor and paddle will be added later. "torch" means
            training on cpu or a single gpu;
        :param train_dataloader: the training dataset; note that it must be one single dataset, not a List or Dict;
        :param optimizers: the optimizer(s) needed for training; either a single optimizer instance or a List of
            optimizers;
        :param device: the device(s) used for training. Note that when this parameter is None, fastNLP does not move
            the model or the data between devices, but you can implement the data migration yourself through the
            parameters `input_mapping` and `output_mapping` (by passing two data-processing functions), or add the
            key "data_device" to kwargs to have us move the data to the specified device for you (this situation
            should only arise when the user builds the DDP processes before instantiating the Trainer).
            The accepted inputs for device are:
            1. str: ['cpu', 'cuda', 'cuda:0', 'cuda:1', ...], meaning the cpu, the first visible GPU, the first
               visible GPU and the second visible GPU, respectively;
            2. torch.device: load the model onto that torch.device;
            3. int: train on the gpu whose device_id equals this value; if the value is -1, all visible gpus are
               used and the driver becomes `TorchDDPDriver`;
            4. list(int): use this form to set more than one device; when `device` is a list, `TorchDDPDriver` is
               used by default;
            5. None: do not touch the model at all;
        :param n_epochs: the total number of training epochs; defaults to 20;
        :param evaluate_dataloaders: the validation dataset(s); either a single dataset or several datasets; when
            several, it must be a Dict; defaults to None;
        :param batch_step_fn: customizes what is executed for every train batch. The function should accept the two
            parameters `trainer` and `batch` and does not need a return value; see the batch_step_fn in
            fastNLP.core.controllers.loops.train_batch_loop.TrainBatchLoop for reference.
        :param evaluate_batch_step_fn: customizes what is executed for every evaluate batch. The function should
            accept the two parameters `evaluator` and `batch` and does not need a return value; see the
            batch_step_fn in fastNLP.core.controllers.loops.evaluate_batch_loop.EvaluateBatchLoop for reference.
        :param train_fn: controls which of the model's methods the `Trainer` calls in the forward pass, e.g.
            `train_step` or `forward`; defaults to None, in which case `train_step` is used as the forward function
            if the model defines it, and the model's default forward function otherwise.
        :param evaluate_fn: controls the mode of the `Evaluator` built into the `Trainer`; should be None or a
            string, used the same way as train_fn. Note that this parameter is passed straight to the built-in
            Evaluator (when it is not None); when it is None, the model is first searched for an `evaluate_step`
            method, and the forward function is used if none exists.
        :param callbacks: the callback classes triggered during training; should be a list whose elements all
            inherit from the `Callback` class;
        :param metrics: should be a dict whose keys are the monitor names, e.g. {"acc1": AccMetric(), "acc2": AccMetric()};
        :param evaluate_every: a negative number, a positive number, or a function; a negative number means
            evaluate once every that many epochs; a positive number means evaluate once every that many batches; a
            function is a user-supplied control over the evaluation frequency: it receives the current trainer
            object as its argument and returns a bool, where True means evaluation should run; it is called after
            every batch to decide whether to evaluate.
        :param input_mapping: a dict or a function describing how a fetched training batch is mapped at each step.
            If it is a dict and the batch is also a `Dict`, the keys of the batch that also appear in input_mapping
            are renamed to the values of the corresponding input_mapping keys; if the batch is a `dataclass`, it is
            first converted to a Dict and then transformed as above; for any other batch type an error is raised
            directly. If input_mapping is a function, the fetched batch is passed to that function unprocessed.
            Note that this parameter is also passed into the `Evaluator`; you can therefore use it to move the
            training batch onto the right device (e.g. when the parameter `device` is None).
            If train and evaluate need different input mappings, set train_input_mapping and evaluate_input_mapping
            instead.
        :param output_mapping: a dict or a function; works like input_mapping but transforms the output. If
            output_mapping is a function, the model's output is passed to it directly; if it is a `Dict`, the
            output must be of `Dict` or `dataclass` type: a `Dict` output has its keys that also appear in
            output_mapping renamed to the values of the corresponding output_mapping keys, and a `dataclass` output
            is first converted to a Dict and then transformed as above.
            If train and evaluate need different output mappings, set train_output_mapping and
            evaluate_output_mapping instead.
        :param model_wo_auto_param_call: whether to disable calling our auto_param_call during training, which
            automatically matches batch entries to the forward function's parameters. When this value is False and
            the batch is a dict, the objects the forward function needs are extracted from the batch and passed into
            the forward function; when it is True, the batch is passed through to the model untouched. Note that
            this applies to `train_step`, `evaluate_step` and `test_step`;
        :param accumulation_steps: the number of gradient accumulation steps, i.e. the optimizer steps once every
            that many batches; defaults to 1;
        :param fp16: whether to enable mixed-precision training; defaults to False;
        :param monitor: the default monitor metric name when evaluate_dataloaders is given. Callbacks that take a
            monitor parameter and do not set it at initialization use this value. If no exactly matching name is
            found in the evaluation results, the shortest-common-string heuristic picks the closest name as the
            monitor. A function may also be passed; it receives the evaluation results (a dict) and returns a float
            as the monitor value.
        :param larger_better: whether a larger monitor value is better.
        :param marker: marks a Trainer instance so that when the user calls the `Trainer.on` function, the callback
            function can be tied to a concrete 'trainer' instance; defaults to None;
        :param kwargs: other parameters that may be needed:
            torch_non_blocking: the non_blocking argument of the `to` method of pytorch tensors;
            data_device: when the user's model device (the parameter model_device inside the Driver) is None, the
                data is moved to data_device; note that data_device only takes effect when model_device is None;
            torch_ddp_kwargs: the parameters used to initialize pytorch's DistributedDataParallel; only used for
                pytorch ddp training. For example, pass {'find_unused_parameters': True} to fix errors caused by
                parameters that do not take part in the forward pass.
            set_grad_to_none: whether to set grads to None after every optimizer update during training;
            use_dist_sampler: whether to use a distributed sampler. With multiple cards, the distributed sampler
                automatically decides which samples each card reads, so that within one epoch the samples of all
                cards add up to exactly one full pass over the dataset. The default follows whether the driver is
                distributed.
            evaluate_use_dist_sampler: whether the Evaluator replaces the dataloader's sampler with a distributed
                sampler when running distributed; defaults to True;
            output_from_new_proc: a string describing how to handle the output streams of the other processes of a
                multi-process driver; should be one of ["all", "ignore", "only_error"]; any other value is treated
                as a folder name, and the output streams of the other ranks are redirected into log files saved in
                that folder; defaults to "only_error";
            progress_bar: how progress is displayed; currently one of [None, 'raw', 'rich', 'auto'] or a
                RichCallback/RawTextCallback object. Defaults to 'auto', which uses a RichCallback if the current
                terminal is detected to be interactive and a RawTextCallback otherwise. To customize the progress
                bar, e.g. its print frequency, pass a RichCallback or RawTextCallback object.
            train_input_mapping: same as input_mapping, but only used during train. Mutually exclusive with
                input_mapping.
            train_output_mapping: same as output_mapping, but only used during train. Mutually exclusive with
                output_mapping.
            evaluate_input_mapping: same as input_mapping, but only used during evaluate. Mutually exclusive with
                input_mapping.
            evaluate_output_mapping: same as output_mapping, but only used during evaluate. Mutually exclusive with
                output_mapping.
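
        A minimal usage sketch (``MyModel``, ``train_dl``, ``val_dl`` and ``Accuracy`` below are placeholder
        names, not part of this module)::

            import torch

            model = MyModel()  # assumed to implement train_step / evaluate_step
            optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
            trainer = Trainer(
                model=model,
                driver="torch",               # cpu or single-gpu training
                device=0,
                train_dataloader=train_dl,
                optimizers=optimizer,
                evaluate_dataloaders=val_dl,
                metrics={"acc": Accuracy()},  # placeholder metric object
                n_epochs=10,
                evaluate_every=-1,            # evaluate once per epoch
            )
            trainer.run()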
        """
        self.model = model
        self.marker = marker
        if isinstance(driver, str):
            self.driver_name = driver
        else:
            self.driver_name = driver.__class__.__name__
        self.device = device
        if train_dataloader is None:
            raise ValueError("Parameter `train_dataloader` cannot be None.")
        self.train_dataloader = train_dataloader
        self.evaluate_dataloaders = evaluate_dataloaders
        self.optimizers = optimizers
        self.fp16 = fp16

        train_input_mapping = kwargs.get('train_input_mapping', None)
        train_output_mapping = kwargs.get('train_output_mapping', None)
        evaluate_input_mapping = kwargs.get('evaluate_input_mapping', None)
        evaluate_output_mapping = kwargs.get('evaluate_output_mapping', None)

        train_input_mapping, train_output_mapping, evaluate_input_mapping, evaluate_output_mapping = \
            _get_input_output_mapping(input_mapping, output_mapping, train_input_mapping, train_output_mapping,
                                      evaluate_input_mapping, evaluate_output_mapping)

        self.input_mapping = train_input_mapping
        self.output_mapping = train_output_mapping
        self.evaluate_fn = evaluate_fn

        self.batch_step_fn = batch_step_fn
        if batch_step_fn is not None:
            _check_valid_parameters_number(batch_step_fn, ['trainer', 'batch'], fn_name='batch_step_fn')
            self.check_batch_step_fn = partial(self._check_callback_called_legality, check_mode=True)
        else:
            self.check_batch_step_fn = lambda *args, **kwargs: ...
        # This flag records whether `train_batch_loop` has been checked. When the user replaces the default
        # `train_batch_loop` through attribute assignment, we need to verify that the custom loop calls the
        # callback functions correctly and updates the `trainer_state` correctly.
        # The default is True, meaning the default `train_batch_loop` is already verified and needs no check.
        # The check only runs once, after the first epoch finishes; later epochs are not checked again.
        self.has_checked_train_batch_loop = True
        self._train_batch_loop = TrainBatchLoop(batch_step_fn=batch_step_fn)

        if not isinstance(accumulation_steps, int):
            raise ValueError("Parameter `accumulation_steps` can only be `int` type.")
        elif accumulation_steps < 1:
            raise ValueError("Parameter `accumulation_steps` must be greater than 0.")
        self.accumulation_steps = accumulation_steps

        # todo The rough idea: each driver reports its own parameters (matching its initializer), and the
        #  trainer/evaluator checks at initialization whether the parameters it holds agree with the driver's,
        #  warning the user about any mismatch. This can probably be left for later.
        self.driver = choose_driver(
            model=model,
            driver=driver,
            train_dataloader=train_dataloader,
            optimizers=optimizers,
            device=device,
            n_epochs=n_epochs,
            evaluate_dataloaders=evaluate_dataloaders,
            batch_step_fn=batch_step_fn,
            evaluate_batch_step_fn=evaluate_batch_step_fn,
            evaluate_fn=evaluate_fn,
            callbacks=callbacks,
            metrics=metrics,
            evaluate_every=evaluate_every,
            input_mapping=evaluate_input_mapping,
            output_mapping=evaluate_output_mapping,
            model_wo_auto_param_call=model_wo_auto_param_call,
            accumulation_steps=accumulation_steps,
            fp16=fp16,
            marker=marker,
            **kwargs
        )
        self.driver.set_optimizers(optimizers=optimizers)
        # Choose the ProgressBarCallback according to the progress_bar parameter.
        progress_bar_callback = choose_progress_callback(kwargs.get('progress_bar', 'auto'))
        if progress_bar_callback is not None:
            if callbacks is None:
                callbacks = []
            elif not isinstance(callbacks, Sequence):
                callbacks = [callbacks]
            callbacks = list(callbacks) + [progress_bar_callback]
        else:
            rank_zero_call(logger.warning)("No progress bar is provided, so no information will be output "
                                           "during training.")

        # Initialize the callback manager.
        self.callback_manager = CallbackManager(callbacks)
        # Add all function-style callbacks.
        self._fetch_matched_fn_callbacks()
        # Add all class-style callbacks.
        self.callback_manager.initialize_class_callbacks()

        # Initialize the state, including both the user-facing interface and the one we use internally.
        self.state = State()
        self.trainer_state = TrainerState(
            n_epochs=n_epochs,
            cur_epoch_idx=0,
            global_forward_batches=0,
            batch_idx_in_epoch=0,
            num_batches_per_epoch=None,  # initialized inside the concrete train_batch_loop;
            total_batches=None
        )

        """ Set up the internal Evaluator """
        if metrics is None and evaluate_dataloaders is not None:
            raise ValueError("You have set 'evaluate_dataloaders' but forgot to set 'metrics'.")
        if metrics is not None and evaluate_dataloaders is None:
            raise ValueError("You have set 'metrics' but forgot to set 'evaluate_dataloaders'.")
        self.metrics = metrics
        self.evaluate_every = evaluate_every

        self.driver.setup()
        self.driver.barrier()
        use_dist_sampler = kwargs.get("use_dist_sampler", self.driver.is_distributed())
        if use_dist_sampler:
            _dist_sampler = "dist"
        else:
            _dist_sampler = None

        self.evaluator = None
        self.monitor = monitor
        self.larger_better = larger_better
        if metrics is not None and evaluate_dataloaders is not None:
            check_evaluate_every(evaluate_every)
            progress_bar = kwargs.get('progress_bar', 'auto')
            if not (isinstance(progress_bar, str) or progress_bar is None):
                # progress_bar should be a ProgressCallback object here; use its name.
                progress_bar = progress_bar.name
            self.evaluator = Evaluator(model=model, dataloaders=evaluate_dataloaders, metrics=metrics,
                                       driver=self.driver, device=device, evaluate_batch_step_fn=evaluate_batch_step_fn,
                                       evaluate_fn=evaluate_fn, input_mapping=evaluate_input_mapping,
                                       output_mapping=evaluate_output_mapping, fp16=fp16, verbose=0,
                                       use_dist_sampler=kwargs.get("evaluate_use_dist_sampler", None),
                                       progress_bar=progress_bar)

        if train_fn is not None and not isinstance(train_fn, str):
            raise TypeError("Parameter `train_fn` can only be `str` type when it is not None.")
        self._train_step, self._train_step_signature_fn = self.driver.get_model_call_fn("train_step" if train_fn is None else train_fn)
        self.train_fn = train_fn

        self.dataloader = self.train_dataloader
        self.driver.set_deterministic_dataloader(self.dataloader)

        self.dataloader = self.driver.set_dist_repro_dataloader(dataloader=self.train_dataloader, dist=_dist_sampler,
                                                                reproducible=self.callback_manager._need_reproducible_sampler)

        self.set_grad_to_none = kwargs.get("set_grad_to_none", True)
        self.evaluate_batch_step_fn = evaluate_batch_step_fn
        self.kwargs = kwargs

        self.on_after_trainer_initialized(self.driver)
        self.driver.barrier()

    def run(self, num_train_batch_per_epoch: int = -1, num_eval_batch_per_dl: int = -1,
            num_eval_sanity_batch: int = 2, resume_from: str = None, resume_training: bool = True,
            catch_KeyboardInterrupt=None):
        """
        Note that for the very first run of checkpoint-resumed training, i.e. when no resumable checkpoint has been
        saved yet, `resume_from` should be left as None and ModelCheckpoint should be used to save the resumable
        files.

        :param num_train_batch_per_epoch: how many batches to run per epoch before stopping; -1 means the number of
            batches is determined by the dataloader.
        :param num_eval_batch_per_dl: how many batches of each evaluate dataloader to run before stopping; -1 means
            the number of batches is determined by the dataloader.
        :param num_eval_sanity_batch: how many evaluation batches to run before training as a sanity check that
            evaluation works; 0 disables the check.
        :param resume_from: the path from which to restore the trainer's state.
        :param resume_training: whether to restore the training state stored in the checkpoint. If False, only the
            states of the model and the optimizers are restored.
        :param catch_KeyboardInterrupt: whether to catch KeyboardInterrupt; if caught, no exception is raised and
            the code after trainer.run() continues to run. By default, non-distributed drivers catch it and
            distributed drivers do not (they cannot catch it).
        :return:
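
        A sketch of resuming from a previously saved checkpoint (``"ckpt_dir"`` is a placeholder path)::

            # earlier run: a resumable checkpoint was written, e.g. via trainer.save("ckpt_dir")
            # later run: rebuild an identically configured Trainer, then
            trainer.run(resume_from="ckpt_dir", resume_training=True)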
        """
        if catch_KeyboardInterrupt is None:
            catch_KeyboardInterrupt = not self.driver.is_distributed()
        else:
            if self.driver.is_distributed():
                if catch_KeyboardInterrupt:
                    logger.warning("Parameter `catch_KeyboardInterrupt` can only be False when you are using a "
                                   "multi-device driver. It will be set to False.")
                catch_KeyboardInterrupt = False

        self._set_num_eval_batch_per_dl(num_eval_batch_per_dl)

        if resume_from is not None:
            if os.path.exists(resume_from):
                self.load(resume_from, resume_training=resume_training)
            else:
                raise FileNotFoundError("You are using `resume_from`, but we cannot find the specified file.")

        if self.evaluator is not None and num_eval_sanity_batch > 0:
            logger.info(f"Running evaluator sanity check for {num_eval_sanity_batch} batches.")
            self.on_sanity_check_begin()
            sanity_check_res = self.evaluator.run(num_eval_batch_per_dl=num_eval_sanity_batch)
            self.on_sanity_check_end(sanity_check_res)

        if num_train_batch_per_epoch != -1:
            self.dataloader = _TruncatedDataLoader(self.dataloader, num_train_batch_per_epoch)

        self.num_batches_per_epoch = len(self.dataloader)
        self.total_batches = self.num_batches_per_epoch * self.n_epochs
        self.global_forward_batches = self.num_batches_per_epoch * self.cur_epoch_idx + self.batch_idx_in_epoch

        try:
            self.on_train_begin()
            self.driver.barrier()
            self.driver.zero_grad(self.set_grad_to_none)
            while self.cur_epoch_idx < self.n_epochs:
                # Guards against saving again before the current epoch finishes after a Trainer.load.
                self.start_batch_idx_in_epoch = self.trainer_state.batch_idx_in_epoch
                self.driver.set_model_mode("train")
                self.on_train_epoch_begin()
                self.driver.set_sampler_epoch(self.dataloader, self.cur_epoch_idx)
                self.train_batch_loop.run(self, self.dataloader)
                if not self.has_checked_train_batch_loop:
                    self._check_train_batch_loop_legality()
                self.cur_epoch_idx += 1
                self.on_train_epoch_end()
                self.driver.barrier()
                self.epoch_evaluate()
                self.driver.barrier()
        except EarlyStopException as e:
            logger.info(f"Catch early stop exception: {e.msg}.")
            self.on_exception(e)
        except KeyboardInterrupt as e:
            self.driver.on_exception()
            self.on_exception(e)
            if not catch_KeyboardInterrupt:
                raise e
        except BaseException as e:
            self.driver.on_exception()
            self.on_exception(e)
            raise e
        finally:
            self.on_train_end()
            self.driver.barrier()

    def _set_num_eval_batch_per_dl(self, num_eval_batch_per_dl):
        def _evaluate_fn(trainer: Trainer, evaluate_fn: Callable) -> None:
            trainer.on_evaluate_begin()
            _evaluate_res: dict = evaluate_fn()
            trainer.on_evaluate_end(_evaluate_res)

        if self.evaluator is not None:
            self.run_evaluate = partial(_evaluate_fn, self, partial(self.evaluator.run, num_eval_batch_per_dl))

    def step_evaluate(self):
        """
        Called after each batch finishes; runs evaluate according to the configuration.

        :return:
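
        A sketch of a callable ``evaluate_every`` that would trigger this (per the Trainer docstring, it receives
        the trainer and returns a bool; the frequency below is just an example)::

            def evaluate_every(trainer):
                # evaluate once every 100 forward batches
                return trainer.global_forward_batches % 100 == 0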
        """
        if self.evaluator is not None:
            if callable(self.evaluate_every):
                if self.evaluate_every(self):
                    self.run_evaluate()
            elif self.evaluate_every > 0 and self.global_forward_batches % self.evaluate_every == 0:
                self.run_evaluate()

    def epoch_evaluate(self):
        """
        Called after each epoch finishes; runs evaluate according to the configuration.

        :return:
        """
        if self.evaluator is not None:
            if isinstance(self.evaluate_every, int) and self.evaluate_every < 0:
                evaluate_every = -self.evaluate_every
                if self.cur_epoch_idx % evaluate_every == 0:
                    self.run_evaluate()

    def add_callback_fn(self, event: Optional[Union[Events, EventsList]], fn: Callable):
        r"""
        After a trainer instance has been initialized, the user can use this function to conveniently add callback
        functions; this is handled by a concrete trainer instance, so no `mark` parameter is needed.

        :param event: the specific callback moment; the user must specify which callback moment this function
            belongs to;
        :param fn: the concrete callback function;
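
        A usage sketch (the exact ``Events`` member spelling should be checked against fastNLP.core.callbacks; the
        callback body is a placeholder)::

            def on_train_begin(trainer):
                print("training is about to start")

            trainer.add_callback_fn(Events.on_train_begin, on_train_begin)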
        """
        if not isinstance(event, (_SingleEventState, EventsList)):
            raise ValueError("Parameter `event` should only be `Events` or `EventsList` type.")

        _custom_callback = _CallbackWrapper(event, fn)
        self.callback_manager.dissect_one_callback(_custom_callback)

    @classmethod
    def on(cls, event: Optional[Union[Events, EventsList]], marker: Optional[str] = None):
        r"""
        A decorator that conveniently turns a function into a callback function, allowing control over the training
        process. Note that if you use this decorator to add callbacks to your training, the code that registers the
        callback function must run before the `Trainer` is instantiated.

        :param event: the specific callback moment; the user must specify which callback moment this function
            belongs to;
        :param marker: marks which concrete trainer instances this callback function belongs to. Two special cases:
            1. when `marker` is None (the default), the callback function only belongs to the nearest trainer
            instance created below this code; 2. when `marker` is 'all', the callback function is used by every
            trainer instance;
        :return: the original function;
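
        A usage sketch (the exact ``Events`` member spelling should be checked against fastNLP.core.callbacks;
        ``my_marker`` is a placeholder)::

            @Trainer.on(Events.on_train_epoch_end, marker="my_marker")
            def report_epoch(trainer):
                print(f"finished epoch {trainer.cur_epoch_idx}")

            trainer = Trainer(..., marker="my_marker")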
        """
        def wrapper(fn: Callable) -> Callable:
            cls._custom_callbacks[marker].append((event, fn))
            # Skip the first parameter (self) of the corresponding Callback method when checking fn's signature.
            callback_fn_args = get_fn_arg_names(getattr(Callback, event.value))[1:]
            _check_valid_parameters_number(fn, callback_fn_args)
            return fn

        return wrapper

    def _fetch_matched_fn_callbacks(self):
        """
        Function callbacks added through the decorator are stored on a class attribute, so after a concrete trainer
        instance has been initialized we need to fetch the callback functions belonging to it from the Trainer's
        class attribute and add them to the callback_manager.
        """
        _own_callbacks: List = copy.deepcopy(self._custom_callbacks["all"])
        _own_callbacks.extend(self._custom_callbacks[None])
        self._custom_callbacks[None] = []
        if self.marker is not None:
            if len(self._custom_callbacks[self.marker]) == 0:
                logger.info(f"You have set `trainer.marker = {self.marker}`, but no callback function matching "
                            f"`{self.marker}` was added through the function `Trainer.on`.")
            _own_callbacks += self._custom_callbacks[self.marker]
        for each_callback in _own_callbacks:
            self.add_callback_fn(*each_callback)

    def _check_callback_called_legality(self, check_mode: bool = True):
        """
        1. When this function is called:
            when checking 'batch_step_fn', it should be called at the end of the while loop in
            'train_batch_loop.run';
            when checking 'TrainBatchLoop', it should be called at the end of every epoch;
        2. What this function does:
            it checks whether a user-customized batch_step_fn / TrainBatchLoop calls the callback functions
            correctly. More precisely, when the user has actually customized any of the callback_fns
            ("on_before_backward", "on_after_backward", "on_before_optimizers_step", "on_after_optimizers_step",
            "on_before_zero_grad", "on_after_zero_grad") /
            ("on_fetch_data_begin", "on_fetch_data_end", "on_train_batch_begin", "on_train_batch_end",
            "on_before_backward", "on_after_backward", "on_before_optimizers_step", "on_after_optimizers_step",
            "on_before_zero_grad", "on_after_zero_grad")
            and has also customized a batch_step_fn / TrainBatchLoop, they may have forgotten to call those
            callback functions inside their own batch_step_fn; this function detects exactly that behavior.
            Note that it is only called when batch_step_fn is not None or when the TrainBatchLoop has been replaced.

        :param check_mode: decides whether this function checks 'batch_step_fn' or 'TrainBatchLoop'; True means
            check 'batch_step_fn', False means check 'TrainBatchLoop';
        """
        if check_mode:
            callbacks = ("on_before_backward", "on_after_backward", "on_before_optimizers_step", "on_after_optimizers_step",
                         "on_before_zero_grad", "on_after_zero_grad")
        else:
            callbacks = ("on_fetch_data_begin", "on_fetch_data_end", "on_train_batch_begin", "on_train_batch_end",
                         "on_before_backward", "on_after_backward", "on_before_optimizers_step", "on_after_optimizers_step",
                         "on_before_zero_grad", "on_after_zero_grad")
        _not_called_callback_fns = []
        for each_callback_fn in callbacks:
            if each_callback_fn in self.callback_manager.callback_fns:
                if self.callback_manager.callback_counter[each_callback_fn] == 0:
                    _not_called_callback_fns.append(each_callback_fn)

        if check_mode:
            logger.warning("You have customized your 'batch_step_fn' in the 'train_batch_loop' and also use these "
                           f"callback_fns: {_not_called_callback_fns}, but it seems that "
                           "you don't call the corresponding callback hooks explicitly in your 'batch_step_fn'.")
            # 'batch_step_fn' only needs to be checked after the first step, so after the first check we replace
            # check_batch_step_fn with a no-op.
            self.check_batch_step_fn = lambda *args, **kwargs: ...
        else:
            logger.warning("You have customized your 'TrainBatchLoop' and also use these callback_fns: "
                           f"{_not_called_callback_fns}, but it seems that "
                           "you don't call the corresponding callback hooks explicitly in your 'TrainBatchLoop'.")

    def _check_train_batch_loop_legality(self):
        r"""
        Checks whether the user-customized `train_batch_loop` calls the callback functions correctly and updates
        the `trainer_state` correctly. It only runs after the user has replaced the default `TrainBatchLoop` object
        with their own `train_batch_loop` through attribute assignment, and it only runs once, after the first
        epoch.
        """
        # 1. Check whether the customized `train_batch_loop` calls the callback functions correctly.
        self._check_callback_called_legality(check_mode=False)

        # 2. Check whether the customized `train_batch_loop` updates the `trainer_state` correctly.
        # Since this check only runs after the first epoch, verifying these `trainer_state` values is enough.
        if self.batch_idx_in_epoch == 0:
            logger.warning("You have customized your `train_batch_loop`, but it seems that you forgot to update "
                           "`trainer_state.batch_idx_in_epoch` during training. See the original class "
                           "`TrainBatchLoop`.")
        if self.global_forward_batches == 0:
            logger.warning("You have customized your `train_batch_loop`, but it seems that you forgot to update "
                           "`trainer_state.global_forward_batches` during training. See the original class "
                           "`TrainBatchLoop`.")
        self.has_checked_train_batch_loop = True

    """ Some properties the Trainer needs """
    @property
    def driver(self):
        return self._driver

    @driver.setter
    def driver(self, driver: Driver):
        self._driver = driver

    @property
    def train_batch_loop(self):
        return self._train_batch_loop

    @train_batch_loop.setter
    def train_batch_loop(self, loop: Loop):
        self.has_checked_train_batch_loop = False
        if self.batch_step_fn is not None:
            logger.warning("`batch_step_fn` was customized in the Trainer initialization, it will be ignored "
                           "when the `train_batch_loop` is also customized.")
            # When the user customizes the TrainBatchLoop there is no need to check batch_step_fn separately,
            # because that function is guaranteed to be ignored.
            self.check_batch_step_fn = lambda *args, **kwargs: ...
        self._train_batch_loop = loop

    def save_model(self, folder: Union[str, os.PathLike, BinaryIO, io.BytesIO], only_state_dict: bool = False,
                   model_save_fn: Optional[Callable] = None, **kwargs):
        r"""
        Helper for saving the model; the actual saving is implemented by the concrete driver.

        :param folder: the folder to save the model into. If no model_save_fn is passed, a fastnlp_model.pkl.tar
            file is created inside this folder.
        :param only_state_dict: only effective when model_save_fn is None; whether to save only the model's
            `state_dict`;
        :param model_save_fn: a user-provided function that replaces this function's own saving logic;
        :param kwargs:
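
        A usage sketch (``"model_dir"`` is a placeholder path)::

            trainer.save_model(folder="model_dir", only_state_dict=True)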
        """
        self.on_save_model()
        self.driver.barrier()
        if not isinstance(folder, (io.BytesIO, BinaryIO)):
            if model_save_fn is not None:
                if not callable(model_save_fn):
                    raise ValueError("Parameter `model_save_fn` should be `Callable` type when it is not None.")
                rank_zero_call(model_save_fn)(folder)
            else:
                if isinstance(folder, str):
                    folder = Path(folder)
                self.driver.save_model(folder.joinpath(FASTNLP_MODEL_FILENAME), only_state_dict, **kwargs)
        else:
            if model_save_fn is not None:
                raise RuntimeError("It is not allowed to specify a `model_save_fn` parameter with `folder` being "
                                   "`io.BytesIO` type.")
            self.driver.save_model(folder, only_state_dict, **kwargs)
        self.driver.barrier()

    def load_model(self, folder: Union[str, Path, BinaryIO, io.BytesIO], only_state_dict: bool = False,
                   model_load_fn: Optional[Callable] = None, **kwargs):
        """
        Load the model.

        :param folder: the folder to read the model from; by default the fastnlp_model.pkl.tar file inside it is
            read. When model_load_fn is not None, the folder is passed to model_load_fn directly.
        :param only_state_dict: whether the file to read contains only the model weights. Meaningless when
            model_load_fn is not None.
        :param model_load_fn: a callable that takes a folder as its argument and returns nothing.
        :param kwargs:
        :return:
        """
        self.on_load_model()
        self.driver.barrier()
        if not isinstance(folder, (io.BytesIO, BinaryIO)):
            try:
                if model_load_fn is not None:
                    if not callable(model_load_fn):
                        raise ValueError("Parameter `model_load_fn` should be `Callable` type when it is not None.")
                    model_load_fn(folder)
                else:
                    if isinstance(folder, str):
                        folder = Path(folder)
                    self.driver.load_model(folder.joinpath(FASTNLP_MODEL_FILENAME), only_state_dict, **kwargs)
            except FileNotFoundError as e:
                if FASTNLP_MODEL_FILENAME not in os.listdir(folder):
                    logger.error(f"fastNLP model checkpoint file:{FASTNLP_MODEL_FILENAME} is not found in {folder}.")
                raise e
        else:
            if model_load_fn is not None:
                raise RuntimeError("It is not allowed to specify a `model_load_fn` parameter with `folder` being "
                                   "`io.BytesIO` type.")
            self.driver.load_model(folder, only_state_dict, **kwargs)
        self.driver.barrier()

    def save(self, folder: Union[str, Path], only_state_dict: bool = True, model_save_fn: Optional[Callable] = None, **kwargs):
        r"""
        Saves the Trainer state for checkpoint-resumed training.

        :param folder: the folder to save into; two files are created inside it: fastnlp_checkpoint.pkl.tar and
            fastnlp_model.pkl.tar. If model_save_fn is not None, there is no fastnlp_model.pkl.tar file.
        :param only_state_dict: effective when model_save_fn is None; whether to save only the model weights.
        :param model_save_fn: if saving the model needs special handling, pass this function to customize the
            saving process. It should accept a folder (in fact the folder parameter above) and need not return
            anything.
        :param kwargs:
        :return:
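
        A sketch of checkpointing and resuming (``"ckpt_dir"`` is a placeholder path)::

            trainer.save(folder="ckpt_dir")
            # later, with an identically configured Trainer:
            trainer.load(folder="ckpt_dir", resume_training=True)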
        """
        self.driver.barrier()

        # The states to save:
        # 1. the callback states and the filter states of every callback's concrete callback functions;
        # 2. the trainer_state;
        states = {"callback_states": self.on_save_checkpoint(),
                  "trainer_state": self.trainer_state.state_dict(),
                  'num_consumed_batches': self.batch_idx_in_epoch - getattr(self, 'start_batch_idx_in_epoch', 0)
                  }

        if isinstance(folder, str):
            folder = Path(folder)
        if model_save_fn is not None:
            if not callable(model_save_fn):
                raise ValueError("Parameter `model_save_fn` should be `Callable` type when it is not None.")
            rank_zero_call(model_save_fn)(folder)
            self.driver.save(folder=folder, dataloader=self.dataloader, states=states, should_save_model=False, **kwargs)
        else:
            self.driver.save(folder=folder, dataloader=self.dataloader, states=states,
                             only_state_dict=only_state_dict, should_save_model=True, **kwargs)
        self.driver.barrier()

    def load(self, folder: str, resume_training: bool = True, only_state_dict: bool = True,
             model_load_fn: Optional[Callable] = None, **kwargs):
        r"""
        Loads the Trainer state for checkpoint-resumed training.

        Note that in fastNLP the saving and loading logic for resumable training are separate, so a user may load a
        resumable state and never save one afterwards; in that case the dataloader's sampler is not necessarily
        replaced by our ReproducibleSampler.

        Note that resuming from a single card onto multiple cards is currently not supported.

        :param folder: the path of the folder holding the resumable states;
        :param resume_training: whether to resume from the last batch, as opposed to only the most recent epoch;
            note that if resume_training=False, only the states of the model and optimizers are loaded, while the
            values of all other objects are reset according to the user's Trainer initialization;
        :param only_state_dict: whether the saved model contains only the weights.
        :param model_load_fn: the model loading function to use; it should take a folder as its parameter and
            return nothing.
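
        A usage sketch (``"ckpt_dir"`` is a placeholder path; the Trainer must be constructed with the same
        configuration as the one that saved the checkpoint)::

            trainer.load(folder="ckpt_dir", resume_training=True)
            trainer.run()  # continues from the restored state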
        """
        self.driver.barrier()
        if isinstance(folder, str):
            folder = Path(folder)

        dataloader = self.dataloader
        if not resume_training:
            dataloader = None
        try:
            if model_load_fn is not None:
                if not callable(model_load_fn):
                    raise ValueError("Parameter `model_load_fn` should be `Callable`.")
                model_load_fn(folder)
                states = self.driver.load(folder=folder, dataloader=dataloader, should_load_model=False, **kwargs)
            else:
                states = self.driver.load(folder=folder, dataloader=dataloader, only_state_dict=only_state_dict, should_load_model=True, **kwargs)
        except FileNotFoundError as e:
            if FASTNLP_CHECKPOINT_FILENAME not in os.listdir(folder) and FASTNLP_MODEL_FILENAME in os.listdir(folder):
                logger.error("It seems that you are trying to load the trainer checkpoint from a model checkpoint folder.")
            elif FASTNLP_CHECKPOINT_FILENAME not in os.listdir(folder):
                logger.error(f"fastNLP Trainer checkpoint file:{FASTNLP_CHECKPOINT_FILENAME} is not found in {folder}.")
            raise e

        if not resume_training:
            return

        self.dataloader = states.pop('dataloader')

        # 1. Restore the trainer_state.
        self.trainer_state.load_state_dict(states["trainer_state"])

        # 2. Adjust trainer_state.batch_idx_in_epoch.
        # The sampler here is a RandomSampler-like sampler, not a batch_sampler. The principle is that
        # 'the number of batches still to come' + 'batch_idx_in_epoch' must equal 'the total number of batches of
        # an uninterrupted run'. Since 'the number of batches still to come' is determined by how many samples
        # remain, the equality can only be satisfied by adjusting 'batch_idx_in_epoch'.
        self.trainer_state.batch_idx_in_epoch = states.pop('batch_idx_in_epoch')
        # Guards against saving again before the current epoch finishes after a Trainer.load.
        self.start_batch_idx_in_epoch = self.trainer_state.batch_idx_in_epoch

        # 3. Restore the states of all callbacks.
        self.on_load_checkpoint(states["callback_states"])
        self.driver.barrier()

    """ These four functions make it convenient for users to customize their own batch_step_fn (which replaces the
    batch_step_fn inside train_batch_loop) """
    def train_step(self, batch):
        with self.driver.auto_cast():
            outputs = self.driver.model_call(batch, self._train_step, self._train_step_signature_fn)
            outputs = match_and_substitute_params(self.output_mapping, outputs)
            return outputs

    def backward(self, outputs):
        self.on_before_backward(outputs)
        loss = self.extract_loss_from_outputs(outputs)
        loss = loss / self.accumulation_steps
        # with self.get_no_sync_context():
        #     self.driver.backward(loss)
        self.driver.backward(loss)
        self.on_after_backward()

    def zero_grad(self):
        if (self.global_forward_batches + 1) % self.accumulation_steps == 0:
            self.on_before_zero_grad(self.optimizers)
            self.driver.zero_grad(self.set_grad_to_none)
            self.on_after_zero_grad(self.optimizers)

    def step(self):
        if (self.global_forward_batches + 1) % self.accumulation_steps == 0:
            self.on_before_optimizers_step(self.optimizers)
            self.driver.step()
            self.on_after_optimizers_step(self.optimizers)

    def move_data_to_device(self, batch):
        return self.driver.move_data_to_device(batch)

    @staticmethod
    def extract_loss_from_outputs(outputs):
        r"""
        Extracts the `loss` object from the user model's outputs;
        currently `outputs` may be a 'Dict' or a 'dataclass';

        :return: the extracted `loss` object; when running with the 'pytorch' `Driver`, this is a tensor;
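
        A small sketch (``loss_tensor`` and ``logits`` are placeholders)::

            outputs = {"loss": loss_tensor, "logits": logits}
            loss = Trainer.extract_loss_from_outputs(outputs)  # returns loss_tensor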
        """
        if isinstance(outputs, Dict):
            try:
                loss = outputs["loss"]
            except KeyError:
                raise KeyError(f"We cannot find `loss` from your model output(with keys:{outputs.keys()}). Please either "
                               f"directly return it from your model or use `output_mapping` to prepare it.")
        elif is_dataclass(outputs):
            try:
                loss = outputs.loss
            except AttributeError:
                raise AttributeError("We cannot find `loss` from your model output. Please either directly return it from"
                                     " your model or use `output_mapping` to prepare it.")
        else:
            raise ValueError("The `outputs` from your model could only be of `dataclass` or `Dict` type. Or you can use "
                             "the parameter `output_mapping` to prepare loss.")

        return loss

    @contextmanager
    def get_no_sync_context(self):
        r"""
        Used with gradient accumulation under DDP: gradients do not need to be synchronized during the first
        `accumulation_steps` - 1 batches, so this context is used to suppress the synchronization;

        :return: a no_sync context;
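
        A sketch of how a custom ``batch_step_fn`` might use it (mirrors the commented-out call in
        ``Trainer.backward`` above)::

            with trainer.get_no_sync_context():
                trainer.driver.backward(loss)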
        """
        if (self.global_forward_batches + 1) % self.accumulation_steps != 0:
            _no_sync_context = self.driver.get_model_no_sync_context()
        else:
            _no_sync_context = nullcontext
        with _no_sync_context():
            yield

    """ trainer state property """
    @property
    def n_epochs(self) -> int:
        return self.trainer_state.n_epochs

    @n_epochs.setter
    def n_epochs(self, n_epochs: int):
        self.trainer_state.n_epochs = n_epochs

    @property
    def cur_epoch_idx(self) -> int:
        return self.trainer_state.cur_epoch_idx

    @cur_epoch_idx.setter
    def cur_epoch_idx(self, cur_epoch_idx: int):
        self.trainer_state.cur_epoch_idx = cur_epoch_idx

    @property
    def global_forward_batches(self) -> int:
        return self.trainer_state.global_forward_batches

    @global_forward_batches.setter
    def global_forward_batches(self, global_forward_batches: int):
        self.trainer_state.global_forward_batches = global_forward_batches

    @property
    def batch_idx_in_epoch(self) -> int:
        return self.trainer_state.batch_idx_in_epoch

    @batch_idx_in_epoch.setter
    def batch_idx_in_epoch(self, batch_idx_in_epoch: int):
        self.trainer_state.batch_idx_in_epoch = batch_idx_in_epoch

    @property
    def num_batches_per_epoch(self) -> int:
        return self.trainer_state.num_batches_per_epoch

    @num_batches_per_epoch.setter
    def num_batches_per_epoch(self, num_batches_per_epoch: int):
        self.trainer_state.num_batches_per_epoch = num_batches_per_epoch

    @property
    def total_batches(self) -> int:
        return self.trainer_state.total_batches

    @total_batches.setter
    def total_batches(self, total_batches: int):
        self.trainer_state.total_batches = total_batches

    """ driver property """
    @property
    def model_device(self):
        return self.driver.model_device

    @property
    def data_device(self):
        return self.driver.data_device

    """ dataloader property """
    @property
    def train_dataloader(self):
        return self._train_dataloader

    @train_dataloader.setter
    def train_dataloader(self, train_dataloader):
        self._train_dataloader = train_dataloader

    @property
    def evaluate_dataloaders(self):
        return self._evaluate_dataloaders

    @evaluate_dataloaders.setter
    def evaluate_dataloaders(self, evaluate_dataloaders):
        self._evaluate_dataloaders = evaluate_dataloaders


def _get_input_output_mapping(input_mapping, output_mapping, train_input_mapping, train_output_mapping,
                              evaluate_input_mapping, evaluate_output_mapping):
    if train_input_mapping is not None and input_mapping is not None:
        raise ValueError("Parameter `input_mapping` and `train_input_mapping` cannot be set simultaneously.")
    if evaluate_input_mapping is not None and input_mapping is not None:
        raise ValueError("Parameter `input_mapping` and `evaluate_input_mapping` cannot be set simultaneously.")
    if train_output_mapping is not None and output_mapping is not None:
        raise ValueError("Parameter `output_mapping` and `train_output_mapping` cannot be set simultaneously.")
    if evaluate_output_mapping is not None and output_mapping is not None:
        raise ValueError("Parameter `output_mapping` and `evaluate_output_mapping` cannot be set simultaneously.")

    if train_input_mapping is None:
        train_input_mapping = input_mapping
    if evaluate_input_mapping is None:
        evaluate_input_mapping = input_mapping

    if train_output_mapping is None:
        train_output_mapping = output_mapping
    if evaluate_output_mapping is None:
        evaluate_output_mapping = output_mapping

    return train_input_mapping, train_output_mapping, evaluate_input_mapping, evaluate_output_mapping