| @@ -11,37 +11,39 @@ from .callback_event import Event, Filter | |||
| class Callback: | |||
| r""" | |||
| 实际使用的 callback 类,不管是我们 fastNLP 默认提供的一些 callback 类,还是用户自己定制的 callback 类,都应该继承该基类; | |||
| callback 调用时机顺序大概如下 | |||
| Trainer.__init__(): | |||
| on_after_trainer_initialized(trainer, driver) | |||
| Trainer.run(): | |||
| if num_eval_sanity_batch>0: | |||
| on_sanity_check_begin(trainer) # 如果设置了num_eval_sanity_batch | |||
| on_sanity_check_end(trainer, sanity_check_res) | |||
| try: | |||
| on_train_begin(trainer) | |||
| while cur_epoch_idx < n_epochs: | |||
| on_train_epoch_begin(trainer) | |||
| while batch_idx_in_epoch<=num_batches_per_epoch: | |||
| on_fetch_data_begin(trainer) | |||
| batch = next(dataloader) | |||
| on_fetch_data_end(trainer) | |||
| on_train_batch_begin(trainer, batch, indices) | |||
| on_before_backward(trainer, outputs) # 其中 outputs 是经过 output_mapping(如果设置了) 后的,否则即为 model 的输出。 | |||
| on_after_backward(trainer) | |||
| on_before_zero_grad(trainer, optimizers) # 实际调用受到 accumulation_steps 影响 | |||
| on_after_zero_grad(trainer, optimizers) # 实际调用受到 accumulation_steps 影响 | |||
| on_before_optimizers_step(trainer, optimizers) # 实际调用受到 accumulation_steps 影响 | |||
| on_after_optimizers_step(trainer, optimizers) # 实际调用受到 accumulation_steps 影响 | |||
| on_train_batch_end(trainer) | |||
| on_train_epoch_end(trainer) | |||
| except BaseException: | |||
| self.on_exception(trainer, exception) | |||
| finally: | |||
| on_train_end(trainer) | |||
| callback 调用时机顺序大概如下:: | |||
| Trainer.__init__(): | |||
| on_after_trainer_initialized(trainer, driver) | |||
| Trainer.run(): | |||
| if num_eval_sanity_batch>0: | |||
| on_sanity_check_begin(trainer) # 如果设置了num_eval_sanity_batch | |||
| on_sanity_check_end(trainer, sanity_check_res) | |||
| try: | |||
| on_train_begin(trainer) | |||
| while cur_epoch_idx < n_epochs: | |||
| on_train_epoch_begin(trainer) | |||
| while batch_idx_in_epoch<=num_batches_per_epoch: | |||
| on_fetch_data_begin(trainer) | |||
| batch = next(dataloader) | |||
| on_fetch_data_end(trainer) | |||
| on_train_batch_begin(trainer, batch, indices) | |||
| on_before_backward(trainer, outputs) # 其中 outputs 是经过 output_mapping(如果设置了) 后的,否则即为 model 的输出。 | |||
| on_after_backward(trainer) | |||
| on_before_zero_grad(trainer, optimizers) # 实际调用受到 accumulation_steps 影响 | |||
| on_after_zero_grad(trainer, optimizers) # 实际调用受到 accumulation_steps 影响 | |||
| on_before_optimizers_step(trainer, optimizers) # 实际调用受到 accumulation_steps 影响 | |||
| on_after_optimizers_step(trainer, optimizers) # 实际调用受到 accumulation_steps 影响 | |||
| on_train_batch_end(trainer) | |||
| on_train_epoch_end(trainer) | |||
| except BaseException: | |||
| self.on_exception(trainer, exception) | |||
| finally: | |||
| on_train_end(trainer) | |||
| 其它 callback 例如 on_evaluate_begin(trainer)/on_evaluate_end(trainer, results)/on_save_model(trainer)/ | |||
| on_load_model(trainer)/on_save_checkpoint(trainer)/on_load_checkpoint(trainer)将根据需要在Trainer.run()中特定 | |||
| 的时间调用。 | |||
| on_load_model(trainer)/on_save_checkpoint(trainer)/on_load_checkpoint(trainer)将根据需要在Trainer.run()中特定 | |||
| 的时间调用。 | |||
| """ | |||
| def on_after_trainer_initialized(self, trainer, driver): | |||
| @@ -123,8 +125,8 @@ class Callback: | |||
| def on_train_batch_begin(self, trainer, batch, indices): | |||
| r""" | |||
| 在取得数据,执行完 input_mapping (如果 Trainer 传有该参数),并且移动 batch 中的 tensor 到了指定设备。 | |||
| 其中 batch 中的数据格式要么是 Dataloader 返回的每个 batch 的格式;要么是 input_mapping 之后的内容。 | |||
| 如果 batch 是 dict 类型,直接增删其中的 key 或 修改其中的 value 会影响到输入到 model 的中的 batch 数据。 | |||
| 其中 batch 中的数据格式要么是 Dataloader 返回的每个 batch 的格式;要么是 input_mapping 之后的内容。 | |||
| 如果 batch 是 dict 类型,直接增删其中的 key 或 修改其中的 value 会影响到输入到 model 的中的 batch 数据。 | |||
| :param trainer: `fastNLP.Trainer` | |||
| :param batch: batch 的数据,已经经过 input_mapping (如果有) 以及 移动到指定设备 。 | |||
| @@ -136,8 +138,8 @@ class Callback: | |||
| def on_train_batch_end(self, trainer): | |||
| """ | |||
| 完成一个 batch 的训练(forward)、梯度回传(backward)、梯度更新(step)、梯度置零、batch_idx_in_epoch与 | |||
| global_forward_batches累计加1操作。其中梯度更新】梯度置零操作会考虑 accumulation_steps ,所以不一定在当前 batch 会 | |||
| 执行。 | |||
| global_forward_batches累计加1操作。其中梯度更新】梯度置零操作会考虑 accumulation_steps ,所以不一定在当前 batch 会 | |||
| 执行。 | |||
| :param trainer: | |||
| :return: | |||
| @@ -184,7 +186,7 @@ class Callback: | |||
| def on_load_checkpoint(self, trainer, states: Optional[Dict]): | |||
| r""" | |||
| 当 Trainer 要恢复 checkpoint 的时候触发( Trainer 与 Driver 已经加载好自身的状态),参数 states 为 on_save_checkpoint() | |||
| 的返回值。 | |||
| 的返回值。 | |||
| :param trainer: | |||
| :param states: | |||
| @@ -205,7 +207,7 @@ class Callback: | |||
| def on_after_backward(self, trainer): | |||
| """ | |||
| 在 backward 后执行。在多卡场景下,由于 accumulation_steps 的影响,仅在需要真正 update 参数那次梯度回传才会触发梯度同步, | |||
| 因此在多卡且使用 accumulation_steps 时,可能存在某些 step 各卡上梯度不一致的问题。 | |||
| 因此在多卡且使用 accumulation_steps 时,可能存在某些 step 各卡上梯度不一致的问题。 | |||
| :param trainer: | |||
| :return: | |||
| @@ -255,7 +257,7 @@ class Callback: | |||
| def on_evaluate_begin(self, trainer): | |||
| """ | |||
| 在将要进行 evaluate 时调用。如果是设置的以 step 数量 或 自定义地 决定 evaluate 的频率,该接口是在 on_train_batch_end 之后 | |||
| 进行调用。如果是以 epoch 数量决定调用,该接口是在 on_train_epoch_end 之后调用。 | |||
| 进行调用。如果是以 epoch 数量决定调用,该接口是在 on_train_epoch_end 之后调用。 | |||
| :param trainer: | |||
| :return: | |||
| @@ -294,7 +296,7 @@ class Callback: | |||
| class _CallbackWrapper(Callback): | |||
| """ | |||
| 对于用户使用函数修饰器加入的 callback 函数,使用该 _CallbackWrapper 类为其进行定制,这一个类只保留用户的 | |||
| 这一个 callback 函数; | |||
| 这一个 callback 函数; | |||
| """ | |||
| def __init__(self, event: Event, fn: Callable): | |||
| r""" | |||
| @@ -42,7 +42,7 @@ class Event: | |||
| :param int every: 触发了多少次,才真正运行一次。 | |||
| :param bool once: 是否只在第一次运行后就不再执行了。 | |||
| :param Callable filter_fn: 输入参数的应该为 (filter, trainer),其中 filter 对象中包含了 filter.num_called 和 | |||
| filter.num_executed 两个变了分别获取当前被调用了多少次,真正执行了多少次。trainer 对象即为当前正在运行的 Trainer 。 | |||
| filter.num_executed 两个变量分别获取当前被调用了多少次,真正执行了多少次。trainer 对象即为当前正在运行的 Trainer 。 | |||
| """ | |||
| self.every = every | |||
| self.once = once | |||
| @@ -59,6 +59,7 @@ class Event: | |||
| 当 Trainer 运行到 on_after_trainer_initialized 时 | |||
| 以下三个参数互斥,只能设置其中一个。默认为行为等同于 every=1 。默认为 | |||
| :param int every: 触发了多少次,才真正运行一次。 | |||
| :param bool once: 是否只在第一次运行后就不再执行了。 | |||
| :param Callable filter_fn: 输入参数的应该为 (filter, trainer),其中 filter 对象中包含了 filter.num_called 和 | |||
| @@ -74,6 +75,7 @@ class Event: | |||
| 当 Trainer 运行到 on_sanity_check_begin 时 | |||
| 以下三个参数互斥,只能设置其中一个。默认为行为等同于 every=1 。 | |||
| :param int every: 触发了多少次,才真正运行一次。 | |||
| :param bool once: 是否只在第一次运行后就不再执行了。 | |||
| :param Callable filter_fn: 输入参数的应该为 (filter, trainer),其中 filter 对象中包含了 filter.num_called 和 | |||
| @@ -89,6 +91,7 @@ class Event: | |||
| 当 Trainer 运行到 on_sanity_check_end 时 | |||
| 以下三个参数互斥,只能设置其中一个。默认为行为等同于 every=1 。 | |||
| :param int every: 触发了多少次,才真正运行一次。 | |||
| :param bool once: 是否只在第一次运行后就不再执行了。 | |||
| :param Callable filter_fn: 输入参数的应该为 (filter, trainer),其中 filter 对象中包含了 filter.num_called 和 | |||
| @@ -104,6 +107,7 @@ class Event: | |||
| 当 Trainer 运行到 on_train_begin 时 | |||
| 以下三个参数互斥,只能设置其中一个。默认为行为等同于 every=1 。 | |||
| :param int every: 触发了多少次,才真正运行一次。 | |||
| :param bool once: 是否只在第一次运行后就不再执行了。 | |||
| :param Callable filter_fn: 输入参数的应该为 (filter, trainer),其中 filter 对象中包含了 filter.num_called 和 | |||
| @@ -119,6 +123,7 @@ class Event: | |||
| 当 Trainer 运行到 on_train_end 时 | |||
| 以下三个参数互斥,只能设置其中一个。默认为行为等同于 every=1 。 | |||
| :param int every: 触发了多少次,才真正运行一次。 | |||
| :param bool once: 是否只在第一次运行后就不再执行了。 | |||
| :param Callable filter_fn: 输入参数的应该为 (filter, trainer),其中 filter 对象中包含了 filter.num_called 和 | |||
| @@ -134,6 +139,7 @@ class Event: | |||
| 当 Trainer 运行到 on_train_epoch_begin 时 | |||
| 以下三个参数互斥,只能设置其中一个。默认为行为等同于 every=1 。 | |||
| :param int every: 触发了多少次,才真正运行一次。 | |||
| :param bool once: 是否只在第一次运行后就不再执行了。 | |||
| :param Callable filter_fn: 输入参数的应该为 (filter, trainer),其中 filter 对象中包含了 filter.num_called 和 | |||
| @@ -149,6 +155,7 @@ class Event: | |||
| 当 Trainer 运行到 on_train_epoch_end 时 | |||
| 以下三个参数互斥,只能设置其中一个。默认为行为等同于 every=1 。 | |||
| :param int every: 触发了多少次,才真正运行一次。 | |||
| :param bool once: 是否只在第一次运行后就不再执行了。 | |||
| :param Callable filter_fn: 输入参数的应该为 (filter, trainer),其中 filter 对象中包含了 filter.num_called 和 | |||
| @@ -164,6 +171,7 @@ class Event: | |||
| 当 Trainer 运行到 on_fetch_data_begin 时 | |||
| 以下三个参数互斥,只能设置其中一个。默认为行为等同于 every=1 。 | |||
| :param int every: 触发了多少次,才真正运行一次。 | |||
| :param bool once: 是否只在第一次运行后就不再执行了。 | |||
| :param Callable filter_fn: 输入参数的应该为 (filter, trainer),其中 filter 对象中包含了 filter.num_called 和 | |||
| @@ -179,6 +187,7 @@ class Event: | |||
| 当 Trainer 运行到 on_fetch_data_end 时 | |||
| 以下三个参数互斥,只能设置其中一个。默认为行为等同于 every=1 。 | |||
| :param int every: 触发了多少次,才真正运行一次。 | |||
| :param bool once: 是否只在第一次运行后就不再执行了。 | |||
| :param Callable filter_fn: 输入参数的应该为 (filter, trainer),其中 filter 对象中包含了 filter.num_called 和 | |||
| @@ -194,6 +203,7 @@ class Event: | |||
| 当 Trainer 运行到 on_train_batch_begin 时 | |||
| 以下三个参数互斥,只能设置其中一个。默认为行为等同于 every=1 。 | |||
| :param int every: 触发了多少次,才真正运行一次。 | |||
| :param bool once: 是否只在第一次运行后就不再执行了。 | |||
| :param Callable filter_fn: 输入参数的应该为 (filter, trainer),其中 filter 对象中包含了 filter.num_called 和 | |||
| @@ -209,6 +219,7 @@ class Event: | |||
| 当 Trainer 运行到 on_train_batch_end 时 | |||
| 以下三个参数互斥,只能设置其中一个。默认为行为等同于 every=1 。 | |||
| :param int every: 触发了多少次,才真正运行一次。 | |||
| :param bool once: 是否只在第一次运行后就不再执行了。 | |||
| :param Callable filter_fn: 输入参数的应该为 (filter, trainer),其中 filter 对象中包含了 filter.num_called 和 | |||
| @@ -224,6 +235,7 @@ class Event: | |||
| 当 Trainer 运行到 on_exception 时 | |||
| 以下三个参数互斥,只能设置其中一个。默认为行为等同于 every=1 。 | |||
| :param int every: 触发了多少次,才真正运行一次。 | |||
| :param bool once: 是否只在第一次运行后就不再执行了。 | |||
| :param Callable filter_fn: 输入参数的应该为 (filter, trainer),其中 filter 对象中包含了 filter.num_called 和 | |||
| @@ -239,6 +251,7 @@ class Event: | |||
| 当 Trainer 运行到 on_save_model 时 | |||
| 以下三个参数互斥,只能设置其中一个。默认为行为等同于 every=1 。 | |||
| :param int every: 触发了多少次,才真正运行一次。 | |||
| :param bool once: 是否只在第一次运行后就不再执行了。 | |||
| :param Callable filter_fn: 输入参数的应该为 (filter, trainer),其中 filter 对象中包含了 filter.num_called 和 | |||
| @@ -254,6 +267,7 @@ class Event: | |||
| 当 Trainer 运行到 on_load_model 时 | |||
| 以下三个参数互斥,只能设置其中一个。默认为行为等同于 every=1 。 | |||
| :param int every: 触发了多少次,才真正运行一次。 | |||
| :param bool once: 是否只在第一次运行后就不再执行了。 | |||
| :param Callable filter_fn: 输入参数的应该为 (filter, trainer),其中 filter 对象中包含了 filter.num_called 和 | |||
| @@ -269,6 +283,7 @@ class Event: | |||
| 当 Trainer 运行到 on_save_checkpoint 时 | |||
| 以下三个参数互斥,只能设置其中一个。默认为行为等同于 every=1 。 | |||
| :param int every: 触发了多少次,才真正运行一次。 | |||
| :param bool once: 是否只在第一次运行后就不再执行了。 | |||
| :param Callable filter_fn: 输入参数的应该为 (filter, trainer),其中 filter 对象中包含了 filter.num_called 和 | |||
| @@ -284,6 +299,7 @@ class Event: | |||
| 当 Trainer 运行到 on_load_checkpoint 时 | |||
| 以下三个参数互斥,只能设置其中一个。默认为行为等同于 every=1 。 | |||
| :param int every: 触发了多少次,才真正运行一次。 | |||
| :param bool once: 是否只在第一次运行后就不再执行了。 | |||
| :param Callable filter_fn: 输入参数的应该为 (filter, trainer),其中 filter 对象中包含了 filter.num_called 和 | |||
| @@ -299,6 +315,7 @@ class Event: | |||
| 当 Trainer 运行到 on_load_checkpoint 时 | |||
| 以下三个参数互斥,只能设置其中一个。默认为行为等同于 every=1 。 | |||
| :param int every: 触发了多少次,才真正运行一次。 | |||
| :param bool once: 是否只在第一次运行后就不再执行了。 | |||
| :param Callable filter_fn: 输入参数的应该为 (filter, trainer),其中 filter 对象中包含了 filter.num_called 和 | |||
| @@ -314,6 +331,7 @@ class Event: | |||
| 当 Trainer 运行到 on_before_backward 时 | |||
| 以下三个参数互斥,只能设置其中一个。默认为行为等同于 every=1 。 | |||
| :param int every: 触发了多少次,才真正运行一次。 | |||
| :param bool once: 是否只在第一次运行后就不再执行了。 | |||
| :param Callable filter_fn: 输入参数的应该为 (filter, trainer),其中 filter 对象中包含了 filter.num_called 和 | |||
| @@ -329,6 +347,7 @@ class Event: | |||
| 当 Trainer 运行到 on_after_backward 时 | |||
| 以下三个参数互斥,只能设置其中一个。默认为行为等同于 every=1 。 | |||
| :param int every: 触发了多少次,才真正运行一次。 | |||
| :param bool once: 是否只在第一次运行后就不再执行了。 | |||
| :param Callable filter_fn: 输入参数的应该为 (filter, trainer),其中 filter 对象中包含了 filter.num_called 和 | |||
| @@ -344,6 +363,7 @@ class Event: | |||
| 当 Trainer 运行到 on_before_optimizers_step 时 | |||
| 以下三个参数互斥,只能设置其中一个。默认为行为等同于 every=1 。 | |||
| :param int every: 触发了多少次,才真正运行一次。 | |||
| :param bool once: 是否只在第一次运行后就不再执行了。 | |||
| :param Callable filter_fn: 输入参数的应该为 (filter, trainer),其中 filter 对象中包含了 filter.num_called 和 | |||
| @@ -359,6 +379,7 @@ class Event: | |||
| 当 Trainer 运行到 on_after_optimizers_step 时 | |||
| 以下三个参数互斥,只能设置其中一个。默认为行为等同于 every=1 。 | |||
| :param int every: 触发了多少次,才真正运行一次。 | |||
| :param bool once: 是否只在第一次运行后就不再执行了。 | |||
| :param Callable filter_fn: 输入参数的应该为 (filter, trainer),其中 filter 对象中包含了 filter.num_called 和 | |||
| @@ -374,6 +395,7 @@ class Event: | |||
| 当 Trainer 运行到 on_before_zero_grad 时 | |||
| 以下三个参数互斥,只能设置其中一个。默认为行为等同于 every=1 。 | |||
| :param int every: 触发了多少次,才真正运行一次。 | |||
| :param bool once: 是否只在第一次运行后就不再执行了。 | |||
| :param Callable filter_fn: 输入参数的应该为 (filter, trainer),其中 filter 对象中包含了 filter.num_called 和 | |||
| @@ -389,6 +411,7 @@ class Event: | |||
| 当 Trainer 运行到 on_after_zero_grad 时 | |||
| 以下三个参数互斥,只能设置其中一个。默认为行为等同于 every=1 。 | |||
| :param int every: 触发了多少次,才真正运行一次。 | |||
| :param bool once: 是否只在第一次运行后就不再执行了。 | |||
| :param Callable filter_fn: 输入参数的应该为 (filter, trainer),其中 filter 对象中包含了 filter.num_called 和 | |||
| @@ -404,6 +427,7 @@ class Event: | |||
| 当 Trainer 运行到 on_evaluate_begin 时 | |||
| 以下三个参数互斥,只能设置其中一个。默认为行为等同于 every=1 。 | |||
| :param int every: 触发了多少次,才真正运行一次。 | |||
| :param bool once: 是否只在第一次运行后就不再执行了。 | |||
| :param Callable filter_fn: 输入参数的应该为 (filter, trainer),其中 filter 对象中包含了 filter.num_called 和 | |||
| @@ -419,6 +443,7 @@ class Event: | |||
| 当 Trainer 运行到 on_evaluate_end 时 | |||
| 以下三个参数互斥,只能设置其中一个。默认为行为等同于 every=1 。 | |||
| :param int every: 触发了多少次,才真正运行一次。 | |||
| :param bool once: 是否只在第一次运行后就不再执行了。 | |||
| :param Callable filter_fn: 输入参数的应该为 (filter, trainer),其中 filter 对象中包含了 filter.num_called 和 | |||
| @@ -110,7 +110,7 @@ class CallbackManager: | |||
| def initialize_class_callbacks(self): | |||
| r""" | |||
| 在实际的运行过程中,我们是将具体的一个 callback 实例拆分为单独的一个个 callback 函数,然后将它们加在一个字典里,该字典的键值就是 | |||
| 一个个 callback 时机,也就是 `Event` 的类别; | |||
| 一个个 callback 时机,也就是 `Event` 的类别; | |||
| 如果一个 callback 类的 callback 函数并不具备任何作用,我们实际并不会将其加在字典当中; | |||
| :param callbacks: | |||
| @@ -150,7 +150,8 @@ class CallbackManager: | |||
| 断点重训应当保存的状态; | |||
| 2. 每一个具体的 callback 函数的 filter 的状态; | |||
| :return: 一个包含上述内容的字典; | |||
| :return: 一个包含上述内容的字典:: | |||
| { | |||
| "callback_name_1": { | |||
| "states": {...}, | |||
| @@ -19,15 +19,15 @@ class CheckpointCallback(Callback): | |||
| only_state_dict: bool = True, model_save_fn: Optional[Callable] = None, save_object: str = 'model', | |||
| save_evaluate_results=True, **kwargs): | |||
| """ | |||
| 保存模型 checkpoint 的 callback ,其保存的文件目录以及文件名命名规则如下 | |||
| - folder/ | |||
| - YYYY-mm-dd-HH_MM_SS_fffff/ # 自动根据当前脚本的启动时间创建的 | |||
| - {save_object}-epoch_{epoch_idx}/ # 满足 every_n_epochs 条件保存的模型 | |||
| - {save_object}-epoch_{epoch_idx}-batch_{global_batch_idx}/ # 满足 every_n_batches 保存的模型 | |||
| - {save_object}-last/ # 最后一个 epoch 的保存 | |||
| - {save_object}-epoch_{epoch_idx}-batch_{global_batch_idx}-exception_{exception_type}/ # exception时保存。 | |||
| - {save_object}-epoch_{epoch_idx}-batch_{global_batch_idx}-{monitor}_{monitor_value}/ # 满足topk条件存储文件名 | |||
| 保存模型 checkpoint 的 callback ,其保存的文件目录以及文件名命名规则如下:: | |||
| - folder/ | |||
| - YYYY-mm-dd-HH_MM_SS_fffff/ # 自动根据当前脚本的启动时间创建的 | |||
| - {save_object}-epoch_{epoch_idx}/ # 满足 every_n_epochs 条件保存的模型 | |||
| - {save_object}-epoch_{epoch_idx}-batch_{global_batch_idx}/ # 满足 every_n_batches 保存的模型 | |||
| - {save_object}-last/ # 最后一个 epoch 的保存 | |||
| - {save_object}-epoch_{epoch_idx}-batch_{global_batch_idx}-exception_{exception_type}/ # exception时保存。 | |||
| - {save_object}-epoch_{epoch_idx}-batch_{global_batch_idx}-{monitor}_{monitor_value}/ # 满足topk条件存储文件名 | |||
| model_save_fn 为 None ,则以上每个 folder 中,将生成 fastnlp_model.pkl.tar 文件。 | |||
| 若 model_save_fn 不为 None,则 fastNLP 将 folder 绝对路径传递给该函数,fastNLP 在该 folder 下不进行模型保存。 | |||
| @@ -78,11 +78,11 @@ class MonitorUtility: | |||
| return monitor_value | |||
| # 第一次运行 | |||
| if isinstance(self.monitor, str) and self._real_monitor == self.monitor and use_monitor != self.monitor: | |||
| logger.warning(f"We can not find `{self.monitor}` in the evaluation result (with keys as {list(results.keys())}), " | |||
| f"we use the `{use_monitor}` as the monitor for `{self.__class__.__name__}`.") | |||
| logger.rank_zero_warning(f"We can not find `{self.monitor}` in the evaluation result (with keys as " | |||
| f"{list(results.keys())}), we use the `{use_monitor}` as the monitor.", once=True) | |||
| # 检测到此次和上次不同。 | |||
| elif isinstance(self.monitor, str) and self._real_monitor != self.monitor and use_monitor != self._real_monitor: | |||
| logger.warning(f"Change of monitor detected for `{self.__class__.__name__}`. " | |||
| logger.rank_zero_warning(f"Change of monitor detected for `{self.__class__.__name__}`. " | |||
| f"The expected monitor is:`{self.monitor}`, last used monitor is:" | |||
| f"`{self._real_monitor}` and current monitor is:`{use_monitor}`. Please consider using a " | |||
| f"customized monitor function when the evaluation results are varying between validation.") | |||
| @@ -20,14 +20,15 @@ class MoreEvaluateCallback(HasMonitorCallback): | |||
| **kwargs): | |||
| """ | |||
| 当评测时需要调用不同的 evaluate_fn (例如在大部分生成任务中,一般使用训练 loss 作为训练过程中的 evaluate ;但同时在训练到 | |||
| 一定 epoch 数量之后,会让 model 生成的完整的数据评测 bleu 等。此刻就可能需要两种不同的 evaluate_fn ),只使用 Trainer | |||
| 无法满足需求,可以通过调用本 callback 进行。如果需要根据本 callback 中的评测结果进行模型保存,请传入 topk 以及 | |||
| topk_monitor 等相关参数。可以通过 evaluate_every 或 watch_monitor 控制触发进行 evaluate 的条件。 | |||
| 如果设置了 evaluate 结果更好就保存的话,将按如下文件结构进行保存 | |||
| - folder/ | |||
| - YYYY-mm-dd-HH_MM_SS_fffff/ # 自动根据当前脚本的启动时间创建的 | |||
| - {save_object}-epoch_{epoch_idx}-batch_{global_batch_idx}-{topk_monitor}_{monitor_value}/ # 满足topk条件存储文件名 | |||
| 一定 epoch 数量之后,会让 model 生成的完整的数据评测 bleu 等。此刻就可能需要两种不同的 evaluate_fn ),只使用 Trainer | |||
| 无法满足需求,可以通过调用本 callback 进行。如果需要根据本 callback 中的评测结果进行模型保存,请传入 topk 以及 | |||
| topk_monitor 等相关参数。可以通过 evaluate_every 或 watch_monitor 控制触发进行 evaluate 的条件。 | |||
| 如果设置了 evaluate 结果更好就保存的话,将按如下文件结构进行保存:: | |||
| - folder/ | |||
| - YYYY-mm-dd-HH_MM_SS_fffff/ # 自动根据当前脚本的启动时间创建的 | |||
| - {save_object}-epoch_{epoch_idx}-batch_{global_batch_idx}-{topk_monitor}_{monitor_value}/ # 满足topk条件存储文件名 | |||
| :param dataloaders: 需要评估的数据 | |||
| :param metrics: 使用的 metrics 。 | |||
| @@ -19,10 +19,11 @@ class Saver: | |||
| def __init__(self, folder:str=None, save_object:str='model', only_state_dict:bool=True, | |||
| model_save_fn:Callable=None, **kwargs): | |||
| """ | |||
| 执行保存的对象。保存的文件组织结构为 | |||
| - folder # 当前初始化的参数 | |||
| - YYYY-mm-dd-HH_MM_SS_fffff/ # 自动根据当前脚本的启动时间创建的 | |||
| - folder_name # 由 save() 调用时传入。 | |||
| 执行保存的对象。保存的文件组织结构为:: | |||
| - folder # 当前初始化的参数 | |||
| - YYYY-mm-dd-HH_MM_SS_fffff/ # 自动根据当前脚本的启动时间创建的 | |||
| - folder_name # 由 save() 调用时传入。 | |||
| :param folder: 保存在哪个文件夹下,默认为当前 folder 下。 | |||
| :param save_object: 可选 ['trainer', 'model'],表示在保存时的保存对象为 trainer+model 还是 只是model 。 | |||
| @@ -32,7 +33,7 @@ class Saver: | |||
| :param kwargs: 更多需要传递给 Trainer.save() 或者 Trainer.save_model() 接口的参数。 | |||
| """ | |||
| if folder is None: | |||
| logger.warning( | |||
| logger.rank_zero_warning( | |||
| "Parameter `folder` is None, and we will use the current work directory to find and load your model.") | |||
| folder = Path.cwd() | |||
| folder = Path(folder) | |||
| @@ -53,10 +54,11 @@ class Saver: | |||
| @rank_zero_call | |||
| def save(self, trainer, folder_name): | |||
| """ | |||
| 执行保存的函数,将数据保存在 | |||
| - folder/ | |||
| - YYYY-mm-dd-HH_MM_SS_fffff/ # 自动根据当前脚本的启动时间创建的 | |||
| - folder_name # 当前函数参数 | |||
| 执行保存的函数,将数据保存在:: | |||
| - folder/ | |||
| - YYYY-mm-dd-HH_MM_SS_fffff/ # 自动根据当前脚本的启动时间创建的 | |||
| - folder_name # 当前函数参数 | |||
| :param trainer: Trainer 对象 | |||
| :param folder_name: 保存的 folder 名称,将被创建。 | |||
| @@ -129,8 +131,8 @@ class TopkQueue: | |||
| def push(self, key, value) -> Optional[Tuple[Union[str, None], Union[float, None]]]: | |||
| """ | |||
| 将 key/value 推入 topk 的 queue 中,以 value 为标准,如果满足 topk 则保留此次推入的信息,同时如果新推入的数据将之前的数据给 | |||
| 挤出了 topk ,则会返回被挤出的 (key, value);如果返回为 (None, None),说明满足 topk 且没有数据被挤出。如果不满足 topk ,则返回 | |||
| 推入的 (key, value) 本身。这里排序只根据 value 是否更大了判断,因此如果有的情况是越小越好,请在输入前取负号。 | |||
| 挤出了 topk ,则会返回被挤出的 (key, value);如果返回为 (None, None),说明满足 topk 且没有数据被挤出。如果不满足 topk ,则返回 | |||
| 推入的 (key, value) 本身。这里排序只根据 value 是否更大了判断,因此如果有的情况是越小越好,请在输入前取负号。 | |||
| :param str key: | |||
| :param float value: 如果为 None, 则不做任何操作。 | |||
| @@ -173,10 +175,11 @@ class TopkSaver(MonitorUtility, Saver): | |||
| only_state_dict:bool=True, model_save_fn:Callable=None, save_evaluate_results:bool=True, | |||
| **kwargs): | |||
| """ | |||
| 用来识别 topk 模型并保存,也可以仅当一个保存 Saver 使用。保存路径为 | |||
| - folder/ | |||
| - YYYY-mm-dd-HH_MM_SS_fffff/ # 自动根据当前脚本的启动时间创建的 | |||
| - {save_object}-epoch_{epoch_idx}-batch_{global_batch_idx}-{topk_monitor}_{monitor_value}/ # 满足topk条件存储文件名 | |||
| 用来识别 topk 模型并保存,也可以仅当一个保存 Saver 使用。保存路径为:: | |||
| - folder/ | |||
| - YYYY-mm-dd-HH_MM_SS_fffff/ # 自动根据当前脚本的启动时间创建的 | |||
| - {save_object}-epoch_{epoch_idx}-batch_{global_batch_idx}-{topk_monitor}_{monitor_value}/ # 满足topk条件存储文件名 | |||
| :param topk: 保存 topk 多少的模型,-1 为保存所有模型;0 为都不保存;大于 0 的数为保存 topk 个。 | |||
| :param monitor: 监控哪个指标判断是否是 topk 的。监控的 metric 值。如果在 evaluation 结果中没有找到完全一致的名称,将使用 | |||
| @@ -208,7 +211,7 @@ class TopkSaver(MonitorUtility, Saver): | |||
| def save_topk(self, trainer, results: Dict) -> Optional[str]: | |||
| """ | |||
| 根据 results 是否满足 topk 的相关设定决定是否保存,如果发生了保存,将返回保存的文件夹。如果返回为 None ,则说明此次没有满足 | |||
| topk 要求,没有发生保存。 | |||
| topk 要求,没有发生保存。 | |||
| :param trainer: | |||
| :param results: evaluate 的结果。 | |||
| @@ -10,8 +10,7 @@ class TorchGradClipCallback(Callback): | |||
| 在每次 optimizer update 之前将 parameter 进行 clip | |||
| :param float clip_value: 将gradient 限制到[-clip_value, clip_value]。clip_value应该为正数 | |||
| :param str clip_type: 支持'norm', 'value' | |||
| 两种:: | |||
| :param str clip_type: 支持'norm', 'value'两种:: | |||
| 1 'norm', 将gradient的norm rescale到[-clip_value, clip_value] | |||
| @@ -1,4 +1,4 @@ | |||
| from typing import Optional, Union | |||
| from typing import Optional, Union, Tuple | |||
| import os | |||
| from fastNLP.core.log.logger import logger | |||
| @@ -6,10 +6,10 @@ from difflib import SequenceMatcher | |||
| from fastNLP.core.utils.utils import _get_fun_msg | |||
| def _get_monitor_value(monitor: Union[callable, str], real_monitor: Optional[str], res: dict) ->(str, float): | |||
| def _get_monitor_value(monitor: Union[callable, str], real_monitor: Optional[str], res: dict) ->Tuple[str, float]: | |||
| """ | |||
| 从res中寻找 monitor 并返回。如果 monitor 没找到则尝试用 _real_monitor ,若 _real_monitor 为 None 则尝试使用 monitor 的值进行 | |||
| 匹配。 | |||
| 匹配。 | |||
| :param monitor: | |||
| :param real_monitor: | |||
| @@ -12,8 +12,8 @@ from .padders.get_padder import get_padder | |||
| import re | |||
| from .utils import unpack_batch_mapping, unpack_batch_nested_mapping, pack_batch_nested_mapping, unpack_batch_sequence, \ | |||
| pack_batch_sequence | |||
| from .packer_unpacker import SequencePackerUnpacker, SinglePackerUnpacker, MappingPackerUnpacker, \ | |||
| NestedMappingPackerUnpacker | |||
| sequence_idx_str = re.compile(r'^_\d+$') # 形如_0, _1 | |||
| SUPPORTED_BACKENDS = ['torch', 'jittor', 'paddle', 'numpy', 'raw', 'auto', None] | |||
| @@ -84,8 +84,8 @@ class Collator: | |||
| def __init__(self, backend='auto'): | |||
| """ | |||
| 用于 pad 数据的对象。会自动将所有能够 pad (由 fastNLP 根据数据判定能否 pad )的数据都进行 pad 操作,默认 pad 的值为 0。 | |||
| 可使用 set_pad() 函数调整。如果有些 field 不想输出,可以使用 set_ignore() 函数进行设置。Collator 在第一次进行 pad 的 | |||
| 时候自动根据设置以及数据情况,为每个 field 获取一个 padder ,在之后的每次调用中,都将使用对应的 Padder 给对应的 field 。 | |||
| 可使用 set_pad() 函数调整。如果有些 field 不想输出,可以使用 set_ignore() 函数进行设置。Collator 在第一次进行 pad 的 | |||
| 时候自动根据设置以及数据情况,为每个 field 获取一个 padder ,在之后的每次调用中,都将使用对应的 Padder 给对应的 field 。 | |||
| :param backend: 对于可以 pad 的 field,使用哪种 tensor,支持 ['torch','jittor','paddle','numpy','raw', auto, None]。 | |||
| 若为 'auto' ,则在进行 pad 的时候会根据调用的环境决定其 backend 。该参数对不能进行 pad 的数据没用影响,不能 pad | |||
| @@ -101,8 +101,7 @@ class Collator: | |||
| def __call__(self, batch)->Union[List, Dict]: | |||
| """ | |||
| batch可能存在三种可能性 | |||
| List[Dict], List[List], List[Sample] | |||
| batch可能存在三种可能性:List[Dict], List[List], List[Sample] | |||
| 第一步:使用 unpack_batch_func 将相同 field 的内容打包到一个 list 中。 | |||
| 第二步:使用每个 field 各自的 padder 进行 pad 。 | |||
| @@ -126,46 +125,36 @@ class Collator: | |||
| logger.debug(f"Since batch[0] has type:{type(batch[0])}, so the batch_data_type " | |||
| f"is `{self.batch_data_type}`.") | |||
| if self.batch_data_type == 's': | |||
| self.unpack_batch_func = lambda batch, ignore_fields: {'_single': batch} # 不需要做任何调整 | |||
| self.pack_batch_func = lambda x: x['_single'] | |||
| self.packer_unpacker = SinglePackerUnpacker() # 不需要做任何调整 | |||
| elif self.batch_data_type == 'l': | |||
| self.unpack_batch_func = unpack_batch_sequence | |||
| self.pack_batch_func = pack_batch_sequence | |||
| self.packer_unpacker = SequencePackerUnpacker() | |||
| elif self.batch_data_type == 'd': | |||
| if any([isinstance(v, Mapping) for v in batch[0].values()]): # 可能存在 nested 的dict。{'a': {'b': xx}}->{('a', 'b'): value} | |||
| self.unpack_batch_func = unpack_batch_nested_mapping | |||
| self.pack_batch_func = pack_batch_nested_mapping | |||
| self.packer_unpacker = NestedMappingPackerUnpacker() | |||
| else: | |||
| self.unpack_batch_func = unpack_batch_mapping | |||
| self.pack_batch_func = lambda x:x | |||
| self.packer_unpacker = MappingPackerUnpacker() | |||
| if self.unpack_batch_func is unpack_batch_nested_mapping: # 比较特殊,需要防止继续往下延伸 | |||
| unpack_batch: Dict = self.unpack_batch_func(batch, self.ignore_fields, set(self.input_fields.keys())) | |||
| else: | |||
| unpack_batch:Dict = self.unpack_batch_func(batch, self.ignore_fields) # 将各自 field 组成 batch 形式。 | |||
| # 将 batch 中各个 field 组成自己的 batch;同时忽略处于 ignore_fields 中的数据。 | |||
| unpack_batch = self.packer_unpacker.unpack_batch(batch, self.ignore_fields, self.input_fields) | |||
| pad_batch = {} | |||
| if len(self.padders)==0: # 第一次运行,准备 padder | |||
| if self.backend == 'auto': # 如果 backend 为 auto ,则尝试通过调用栈等自动获取 backend 。 | |||
| self.backend = _get_backend() | |||
| for key in unpack_batch.keys(): | |||
| if key not in self.input_fields and key not in self.ignore_fields: | |||
| self.input_fields[key] = {'pad_val': 0, 'dtype': None, 'backend': self.backend} | |||
| elif key in self.input_fields and self.input_fields[key]['backend'] == 'auto': | |||
| self.input_fields[key]['backend'] = self.backend | |||
| for field_name, setting in self.input_fields.items(): | |||
| pad_fn = setting.get('pad_fn', None) | |||
| for field_name, batch_field in unpack_batch.items(): | |||
| setting = self.input_fields.get(field_name, {'backend': self.backend, 'pad_val': 0 , | |||
| 'dtype': None, 'pad_fn': None}) | |||
| pad_fn = setting['pad_fn'] | |||
| if callable(pad_fn): | |||
| padder = pad_fn | |||
| else: | |||
| backend = self.backend if setting['backend'] == 'auto' else setting['backend'] | |||
| batch_field = unpack_batch.get(field_name) | |||
| padder = get_padder(batch_field=batch_field, pad_val=setting['pad_val'], | |||
| dtype=setting['dtype'], backend=backend, | |||
| field_name=field_name) | |||
| self.padders[field_name] = padder | |||
| if self.batch_data_type == 'l': | |||
| self.padders = dict(sorted(self.padders.items(), key=lambda x:int(x[0][1:]))) # sort, 这样 _0, _1 能够保持顺序 | |||
| @@ -173,7 +162,7 @@ class Collator: | |||
| batch = unpack_batch.get(key) | |||
| pad_batch[key] = padder(batch) | |||
| return self.pack_batch_func(pad_batch) # 根据情况恢复成与输入一致的类型 | |||
| return self.packer_unpacker.pack_batch(pad_batch) # 根据情况恢复成与输入一致的类型 | |||
| def set_pad(self, field_name:Union[str, tuple], pad_val:Union[int, float, None]=0, dtype=None, backend='auto', | |||
| pad_fn:Callable=None) -> "Collator": | |||
| @@ -195,16 +184,17 @@ class Collator: | |||
| 形式,输出将被直接作为结果输出。 | |||
| :return: 返回 Collator 自身 | |||
| """ | |||
| self.padders.clear() # 重新生成 | |||
| self._renew() | |||
| if self.batch_data_type is not None: | |||
| if self.batch_data_type == 's': | |||
| logger.debug("Set as single field mode.") | |||
| self.input_fields.clear() | |||
| elif self.batch_data_type == 'd': | |||
| if self.batch_data_type == 's': | |||
| logger.debug("Set as single field mode.") | |||
| self.input_fields.clear() | |||
| elif self.batch_data_type == 'd': | |||
| if isinstance(field_name, str): | |||
| assert sequence_idx_str.match(field_name) is None, f"Field name:{field_name} will be recognized as list " \ | |||
| f"index, but other field is set as dict mode." | |||
| elif self.batch_data_type == 'l': | |||
| elif self.batch_data_type == 'l': | |||
| if isinstance(field_name, str): | |||
| assert sequence_idx_str.match(field_name) is not None, f"Other field is set as list mode. But the new " \ | |||
| f"field name is {field_name}." | |||
| @@ -215,8 +205,40 @@ class Collator: | |||
| else: | |||
| self.batch_data_type = 'd' | |||
| if field_name in self.ignore_fields: | |||
| logger.warning(f"Field:{field_name} has been set as ignored before. It will not be ignored afterwards.") | |||
| # 检测是否已经设置了,主要需要考虑它的父亲节点的情况 | |||
| ignore_fields = [(field, field) if isinstance(field, tuple) else ((field,), field) | |||
| for field in self.ignore_fields] | |||
| input_field_names = [(field, field) if isinstance(field, tuple) else ((field,), field) | |||
| for field in self.input_fields.keys()] | |||
| if isinstance(field_name, tuple): | |||
| _field_name = field_name | |||
| else: | |||
| _field_name = (field_name,) | |||
| for field, o_field in ignore_fields: | |||
| d = _compare_tuple(field, _field_name) | |||
| if d is None: | |||
| continue | |||
| if d == 0: | |||
| logger.rank_zero_warning(f"Field:`{field_name}` has been set as ignored before. It will not be " | |||
| f"ignored afterwards.") | |||
| self.ignore_fields.remove(o_field) | |||
| if d > 0: | |||
| raise KeyError(f"Cannot set `{field_name}` as input, since its children `{o_field}` has been set " | |||
| f"as ignore field.") | |||
| if d < 0: | |||
| raise KeyError(f"Cannot set `{field_name}` as input, since its parent `{o_field}` has been set " | |||
| f"as ignore field.") | |||
| for field, o_field in input_field_names: | |||
| d = _compare_tuple(field, _field_name) | |||
| if d is None: | |||
| continue | |||
| if d > 0: | |||
| raise KeyError(f"Cannot set `{field_name}` as input, since its children `{o_field}` has been set " | |||
| f"pad.") | |||
| if d < 0: | |||
| raise KeyError(f"Cannot set `{field_name}` as input, since its parent `{o_field}` has been set " | |||
| f"pad.") | |||
| if backend is None: | |||
| backend = self.backend | |||
| else: | |||
| @@ -235,13 +257,14 @@ class Collator: | |||
| :return: | |||
| """ | |||
| assert backend in SUPPORTED_BACKENDS | |||
| self.padders.clear() | |||
| self._renew() | |||
| self.backend = backend | |||
| def set_ignore(self, *field_names) -> "Collator": | |||
| """ | |||
| 如果有的内容不希望输出,可以在此处进行设置,被设置的 field 将在 batch 的输出中被忽略。 | |||
| Ex:: | |||
| Example:: | |||
| collator.set_ignore('field1', 'field2') | |||
| :param field_names: 需要忽略的 field 的名称。如果 Dataset 的 __getitem__ 方法返回的是 dict 类型的,则可以直接使用对应的 | |||
| @@ -249,400 +272,56 @@ class Collator: | |||
| __getitem__ 返回的是 Sequence 类型的,则可以使用 '_0', '_1' 表示序列中第 0 或 1 个元素。 | |||
| :return: 返回 Collator 自身 | |||
| """ | |||
| for field_name in field_names: | |||
| if field_name in self.input_fields: | |||
| self.input_fields.pop(field_name) | |||
| logger.warning(f"Field:{field_name} has been set as input before. It will be ignored afterwards.") | |||
| self.padders.pop(field_name, None) # 如果由的话,将它的 padder 扔掉。 | |||
| self.ignore_fields.add(field_name) | |||
| self._renew() | |||
| input_field_names = [(field, field) if isinstance(field, tuple) else ((field,), field) | |||
| for field in self.input_fields.keys()] | |||
| # 需要考虑父节点之类的情况 | |||
| for field in field_names: | |||
| if not isinstance(field, tuple): | |||
| _field = (field,) | |||
| else: | |||
| _field = field | |||
| for _field_name, o_field_name in input_field_names: | |||
| d = _compare_tuple(_field, _field_name) | |||
| if d is None: | |||
| continue | |||
| if d == 0: | |||
| self.input_fields.pop(o_field_name) | |||
| logger.rank_zero_warning(f"Field:{o_field_name} has been set as pad before. It will be ignored afterwards.") | |||
| if d < 0: | |||
| self.input_fields.pop(o_field_name) | |||
| logger.rank_zero_warning(f"Field:{o_field_name} has been set as pad before. It will be ignored afterwards.") | |||
| if d > 0: | |||
| raise KeyError(f"Cannot ignore {field} since its parent key {o_field_name} has been set as pad.") | |||
| self.ignore_fields.add(field) | |||
| return self | |||
| def _renew(self): | |||
| self.packer_unpacker = None | |||
| self.padders.clear() | |||
| # | |||
| # from abc import ABCMeta, abstractmethod | |||
| # from typing import Any, Dict, List, Callable, Union, Tuple | |||
| # from numbers import Number | |||
| # import warnings | |||
| # | |||
| # import numpy as np | |||
| # | |||
| # from fastNLP.envs.imports import _NEED_IMPORT_PADDLE, _NEED_IMPORT_TORCH | |||
| # | |||
| # if _NEED_IMPORT_PADDLE: | |||
| # import paddle | |||
| # | |||
| # if _NEED_IMPORT_TORCH: | |||
| # import torch | |||
| # | |||
| # | |||
| # class ApplyResultException(Exception): | |||
| # def __init__(self, msg, index=None): | |||
| # super().__init__(msg) | |||
| # self.msg = msg | |||
| # self.index = index # 标示在哪个数据遭遇到问题了 | |||
| # | |||
| # | |||
| # class SetInputOrTargetException(Exception): | |||
| # def __init__(self, msg, index=None, field_name=None): | |||
| # super().__init__(msg) | |||
| # self.msg = msg | |||
| # self.index = index # 标示在哪个数据遭遇到问题了 | |||
| # self.field_name = field_name # 标示当前 field 的名称 | |||
| # | |||
| # | |||
| # def _get_ele_type_and_dim(cell: Any, dim=0) -> Tuple[Any, int]: | |||
| # r""" | |||
| # 识别cell的类别与dimension的数量 | |||
| # | |||
| # numpy scalar type:https://docs.scipy.org/doc/numpy-1.13.0/reference/arrays.scalars.html | |||
| # :param cell: | |||
| # :param dim: | |||
| # :return: | |||
| # """ | |||
| # if isinstance(cell, (str, Number, np.bool_)): | |||
| # if hasattr(cell, 'dtype'): | |||
| # return cell.dtype.type, dim | |||
| # return type(cell), dim | |||
| # | |||
| # elif isinstance(cell, list): | |||
| # dim += 1 | |||
| # res = [_get_ele_type_and_dim(cell_i, dim) for cell_i in cell] | |||
| # types = set([i for i, j in res]) | |||
| # dims = set([j for i, j in res]) | |||
| # if len(types) > 1: | |||
| # raise SetInputOrTargetException("Mixed types detected: {}.".format(list(types))) | |||
| # elif len(types) == 0: | |||
| # raise SetInputOrTargetException("Empty value encountered.") | |||
| # if len(dims) > 1: | |||
| # raise SetInputOrTargetException("Mixed dimension detected: {}.".format(list(dims))) | |||
| # return types.pop(), dims.pop() | |||
| # | |||
| # elif isinstance(cell, torch.Tensor): | |||
| # return cell.dtype, cell.dim() + dim # 如果是 torch.mean 的结果是0 | |||
| # | |||
| # elif isinstance(cell, paddle.Tensor): | |||
| # return cell.dtype, cell.dim() + dim | |||
| # | |||
| # elif isinstance(cell, np.ndarray): | |||
| # if cell.dtype != np.dtype('O'): # 如果不是 object 的话说明是 well-formatted 的了 | |||
| # return cell.dtype.type, cell.ndim + dim # dtype.type 返回的会是 np.int32, np.float 等 | |||
| # # 否则需要继续往下 iterate | |||
| # dim += 1 | |||
| # res = [_get_ele_type_and_dim(cell_i, dim) for cell_i in cell] | |||
| # types = set([i for i, j in res]) | |||
| # dims = set([j for i, j in res]) | |||
| # if len(types) > 1: | |||
| # raise SetInputOrTargetException("Mixed types detected: {}.".format(list(types))) | |||
| # elif len(types) == 0: | |||
| # raise SetInputOrTargetException("Empty value encountered.") | |||
| # if len(dims) > 1: | |||
| # raise SetInputOrTargetException("Mixed dimension detected: {}.".format(list(dims))) | |||
| # return types.pop(), dims.pop() | |||
| # | |||
| # else: # 包含 tuple, set, dict 以及其它的类型 | |||
| # raise SetInputOrTargetException(f"Cannot process type:{type(cell)}.") | |||
| # | |||
| # | |||
| # def _get_ds_type_dim(ds: dict): | |||
| # # 获取数据集第一行的 field 内部函数的类型和维度 | |||
| # field_dtype, field_dim = {}, {} | |||
| # for field_name, field_content in ds.items(): | |||
| # type_0, dim_0 = _get_ele_type_and_dim(field_content) | |||
| # field_dtype[field_name], field_dim[field_name] = type_0, dim_0 | |||
| # return field_dtype, field_dim | |||
| # | |||
| # | |||
| # class Collator(metaclass=ABCMeta): | |||
| # r""" | |||
| # 辅助DataLoader管理collate_fn的类 | |||
| # | |||
| # """ | |||
| # | |||
| # def __init__(self): | |||
| # super(Collator, self).__init__() | |||
| # self.collate_fn = [] | |||
| # | |||
| # @abstractmethod | |||
| # def __call__(self, ins_lst: List) -> Any: | |||
| # raise NotImplementedError | |||
| # | |||
| # @abstractmethod | |||
| # def set_pad_val(self, *field_names: str, value=0): | |||
| # raise NotImplementedError | |||
| # | |||
| # | |||
| # class _MultiCollator: | |||
| # """ | |||
| # 管理所有collator的容器, | |||
| # 遵循覆盖原则,后加入的collate_fn会覆盖之前处理的数据。 | |||
| # """ | |||
| # | |||
| # def __init__(self, collate_fns: Union[Callable, List[Callable], None]): | |||
| # | |||
| # if collate_fns is None: | |||
| # collate_fns = [] | |||
| # | |||
| # if isinstance(collate_fns, Callable): | |||
| # collate_fns = [collate_fns] | |||
| # | |||
| # self._collators: list = collate_fns | |||
| # | |||
| # def __call__(self, ins_lst) -> Dict: | |||
| # out, list_out = {}, [] | |||
| # for idx, _collate_fn in enumerate(self._collators): | |||
| # res = _collate_fn(ins_lst) | |||
| # if isinstance(res, Dict): | |||
| # out.update(res) | |||
| # else: | |||
| # list_out.append(res) | |||
| # # else: | |||
| # # raise ValueError(f"the return type of collate_fn {idx} is {type(res)}, but require is dict") | |||
| # if len(out) > 0 and len(list_out) > 0: | |||
| # raise ValueError("the return of collate_fns is not the same, must be dict or list") | |||
| # if len(list_out) == 1: | |||
| # list_out = list_out[-1] | |||
| # # print(list_out) | |||
| # return out if len(out) > 0 else list_out | |||
| # | |||
| # def get_collators(self): | |||
| # return self._collators | |||
| # | |||
| # def add_collator(self, collator: Callable): | |||
| # self._collators.append(collator) | |||
| # | |||
| # def set_as_numpy(self, as_numpy: bool): | |||
| # """ | |||
| # 存在AutoCollator时,as_numpy控制其返回值的类型 | |||
| # | |||
| # :param as_numpy: | |||
| # :return: | |||
| # """ | |||
| # for collator in self._collators: | |||
| # if isinstance(collator, AutoCollator): | |||
| # collator.set_as_numpy(as_numpy) | |||
| # return self | |||
| # | |||
| # def set_pad_val(self, *field_names, val=0): | |||
| # """ | |||
| # 存在AutoCollator时,设置field_name的padding值 | |||
| # | |||
| # :param field_names: 数据集的field名 | |||
| # :param val: padding的值 | |||
| # :return: | |||
| # """ | |||
| # flag = True | |||
| # for collator in self._collators: | |||
| # if isinstance(collator, AutoCollator): | |||
| # collator.set_pad_val(*field_names, val=val) | |||
| # flag = False | |||
| # if flag: | |||
| # warnings.warn("AutoCollator is remove, set_padding is unavailable!!") | |||
| # return self | |||
| # | |||
| # def set_input(self, *field_names): | |||
| # """ | |||
| # 设置AutoCollator需要的field_names,未被设置默认过滤掉 | |||
| # | |||
| # :param field_names: | |||
| # :return: | |||
| # """ | |||
| # flag = True | |||
| # for collator in self._collators: | |||
| # if isinstance(collator, AutoCollator): | |||
| # collator.set_input(*field_names) | |||
| # flag = False | |||
| # if flag: | |||
| # warnings.warn("AutoCollator is removed, set_input is unavailable!!") | |||
| # return self | |||
| # | |||
| # | |||
| # class AutoCollator(Collator): | |||
| # | |||
| # def __init__(self, as_numpy: bool): | |||
| # super(AutoCollator, self).__init__() | |||
| # self.pad_field_value = {} # field padding 自定义的 padding 值, 默认为0 | |||
| # self.need_inputs = set() # 需要的 field name | |||
| # self.field_dtypes = None # 每列数据单元的 dtype 类型 | |||
| # self.field_dims = None # 每列数据单元维度 | |||
| # self.as_numpy = as_numpy | |||
| # | |||
| # def __call__(self, ins_lst: List[Dict]) -> dict: | |||
| # if len(self.need_inputs) == 0: | |||
| # raise ValueError({"set_inputs is None, you should use set_inputs method first!!"}) | |||
| # # TODO 这里应该是先 check 有哪些需要 padding,然后check这些是否是可以pad的 | |||
| # | |||
| # # 第一种情况,设置了 set_input 的值 | |||
| # # 第二种情况, 根据数据的类型的判断是否 padding | |||
| # if self.field_dtypes is None and self.field_dims is None: | |||
| # field_dtypes, field_dims = {}, {} | |||
| # for key, value in ins_lst[0].items(): | |||
| # if key in self.need_inputs and self.pad_field_value.get(key, 0) is not None: | |||
| # field_dtypes[key], field_dims[key] = _get_ele_type_and_dim(value) | |||
| # self.field_dtypes = field_dtypes | |||
| # self.field_dims = field_dims | |||
| # | |||
| # pack_ins_lst, pad_ins_lst = {field_name: [] | |||
| # for field_name in ins_lst[0].keys() if field_name in self.need_inputs}, {} | |||
| # # 将 list 列表内数据按列名打包 | |||
| # for per_ins in ins_lst: | |||
| # for field_name, _field_content in per_ins.items(): | |||
| # if field_name in self.need_inputs: | |||
| # pack_ins_lst[field_name].append(_field_content) | |||
| # | |||
| # pad_field_kv = {field_name: 0 for field_name in self.need_inputs} | |||
| # pad_field_kv.update(self.pad_field_value) | |||
| # self.pad_field_value = pad_field_kv | |||
| # | |||
| # if len(self.pad_field_value.keys()) > 0: | |||
| # # 去掉不需要 pad 的列,如果 set_input 的列不存在则忽略 | |||
| # non_pad_field_names = [] | |||
| # for k, v in self.pad_field_value.items(): | |||
| # if v is None: | |||
| # non_pad_field_names.append(k) | |||
| # | |||
| # # drop_field_names = list(set(list(ins_lst[0].keys())) - set(drop_fields)) | |||
| # for field_name in non_pad_field_names: | |||
| # field_array = pack_ins_lst.pop(field_name) | |||
| # pad_ins_lst[field_name] = np.array(field_array) | |||
| # | |||
| # for field_name, field_array in pack_ins_lst.items(): | |||
| # content = pad_content(field_array, field_name, self.field_dtypes[field_name], | |||
| # self.field_dims[field_name], | |||
| # self.pad_field_value[field_name], | |||
| # as_numpy=self.as_numpy) | |||
| # pad_ins_lst[field_name] = content | |||
| # | |||
| # # else: | |||
| # # # 取出每列的数据,根据类型判断是否能 pad | |||
| # # for field_name, field_array in pack_ins_lst.items(): | |||
| # # pad_field_array = pad_content(field_array, field_name, self.field_dtypes[field_name], | |||
| # # self.field_dims[field_name], | |||
| # # pad_val=0, as_numpy=self.as_numpy) | |||
| # # pad_ins_lst[field_name] = pad_field_array | |||
| # | |||
| # return pad_ins_lst | |||
| # | |||
| # def set_pad_val(self, *field_names, val=0): | |||
| # for field_name in field_names: | |||
| # self.pad_field_value[field_name] = val | |||
| # | |||
| # def set_as_numpy(self, as_numpy: bool): | |||
| # self.as_numpy = as_numpy | |||
| # | |||
| # def set_input(self, *field_names): | |||
| # for field_name in field_names: | |||
| # self.need_inputs.add(field_name) | |||
| # | |||
| # | |||
| # def pad_content(content, field_name: str, field_type, field_dim: int, pad_val: int, as_numpy: bool): | |||
| # | |||
| # if field_type: | |||
| # # 不处理, 返回 np.array 类型 | |||
| # if field_dim > 3: | |||
| # return np.array(content) | |||
| # # 元素类型为数值类型 np.int64, np.float64, int, float 等 | |||
| # if isinstance(field_type, type) and \ | |||
| # (issubclass(field_type, np.number) or issubclass(field_type, Number)): | |||
| # if field_dim == 0: | |||
| # array = np.array(content, dtype=field_type) | |||
| # elif field_dim == 1: | |||
| # max_len = max(map(len, content)) | |||
| # array = np.full((len(content), max_len), pad_val, dtype=field_type) | |||
| # for i, content_i in enumerate(content): | |||
| # array[i, :len(content_i)] = content_i | |||
| # elif field_dim == 2: | |||
| # max_len = max(map(len, content)) | |||
| # max_word_len = max([max([len(content_ii) for content_ii in content_i]) for | |||
| # content_i in content]) | |||
| # array = np.full((len(content), max_len, max_word_len), pad_val, dtype=field_type) | |||
| # for i, content_i in enumerate(content): | |||
| # for j, content_ii in enumerate(content_i): | |||
| # array[i, j, :len(content_ii)] = content_ii | |||
| # else: | |||
| # shape = np.shape(content) | |||
| # if len(shape) == 4: # 说明各 dimension 是相同的大小 | |||
| # array = np.array(content, dtype=field_type) | |||
| # else: | |||
| # raise RuntimeError( | |||
| # f"Field:{field_name} has 3 dimensions, every sample should have the same shape.") | |||
| # if as_numpy is False: | |||
| # array = torch.tensor(array) | |||
| # return array | |||
| # # 元素类型为数值类型 torch.float 等 | |||
| # elif str(field_type).startswith('torch'): | |||
| # if field_dim == 0: | |||
| # tensor = torch.tensor(content).to(field_type) | |||
| # elif field_dim == 1: | |||
| # max_len = max(map(len, content)) | |||
| # tensor = torch.full((len(content), max_len), fill_value=pad_val, dtype=field_type) | |||
| # for i, content_i in enumerate(content): | |||
| # tensor[i, :len(content_i)] = content_i.clone().detach() | |||
| # elif field_dim == 2: | |||
| # max_len = max(map(len, content)) | |||
| # max_word_len = max([max([len(content_ii) for content_ii in content_i]) for | |||
| # content_i in content]) | |||
| # tensor = torch.full((len(content), max_len, max_word_len), fill_value=pad_val, | |||
| # dtype=field_type) | |||
| # for i, content_i in enumerate(content): | |||
| # for j, content_ii in enumerate(content_i): | |||
| # tensor[i, j, :len(content_ii)] = content_ii.clone().detach() | |||
| # else: | |||
| # shapes = set([np.shape(content_i) for content_i in content]) | |||
| # if len(shapes) > 1: | |||
| # raise RuntimeError( | |||
| # f"Field:{field_name} has 3 dimensions, every sample should have the same shape.") | |||
| # shape = shapes.pop() | |||
| # if len(shape) == 3: | |||
| # tensor = torch.full([len(content)] + list(shape), fill_value=pad_val, | |||
| # dtype=field_type) | |||
| # for i, content_i in enumerate(content): | |||
| # tensor[i] = content_i.clone().detach().to(field_type) | |||
| # else: | |||
| # raise RuntimeError( | |||
| # f"Field:{field_name} has 3 dimensions, every sample should have the same shape.") | |||
| # return tensor | |||
| # # TODO 增加jittor/paddle? | |||
| # elif str(field_type).startswith('paddle'): | |||
| # if field_dim == 0: | |||
| # tensor = paddle.Tensor(content).to(field_type) | |||
| # elif field_dim == 1: | |||
| # max_len = max(map(len, content)) | |||
| # tensor = paddle.full((len(content), max_len), fill_value=pad_val, dtype=field_type) | |||
| # for i, content_i in enumerate(content): | |||
| # tensor[i, :len(content_i)] = content_i.clone().detach() | |||
| # elif field_dim == 2: | |||
| # max_len = max(map(len, content)) | |||
| # max_word_len = max([max([len(content_ii) for content_ii in content_i]) for | |||
| # content_i in content]) | |||
| # tensor = paddle.full((len(content), max_len, max_word_len), fill_value=pad_val, | |||
| # dtype=field_type) | |||
| # for i, content_i in enumerate(content): | |||
| # for j, content_ii in enumerate(content_i): | |||
| # tensor[i, j, :len(content_ii)] = content_ii.clone().detach() | |||
| # else: | |||
| # shapes = set([np.shape(content_i) for content_i in content]) | |||
| # if len(shapes) > 1: | |||
| # raise RuntimeError( | |||
| # f"Field:{field_name} has 3 dimensions, every sample should have the same shape.") | |||
| # shape = shapes.pop() | |||
| # if len(shape) == 3: | |||
| # tensor = paddle.full([len(content)] + list(shape), fill_value=pad_val, | |||
| # dtype=field_type) | |||
| # for i, content_i in enumerate(content): | |||
| # tensor[i] = content_i.clone().detach().to(field_type) | |||
| # else: | |||
| # raise RuntimeError( | |||
| # f"Field:{field_name} has 3 dimensions, every sample should have the same shape.") | |||
| # return tensor | |||
| # | |||
| # else: | |||
| # return np.array(content) # 不进行任何操作 | |||
| # else: | |||
| # return np.array(content) | |||
| def _compare_tuple(t1, t2): | |||
| """ | |||
| 检测 t1 和 t2 的关系。 | |||
| 例如 (1, ) 和 (1, ) 关系为 0,表示两者完全没有差异 | |||
| 例如 (1, ) 和 (2, ) 关系为 None,表示完全不同 | |||
| 例如 (1, 2, 3) 和 (1, ) 关系为 2,表示前者比后者长 2 位 | |||
| 但 例如 (1, 2, 3) 和 (2, ) 关系为 None,因为它们从前往后的key 不一样 | |||
| 例如 (1, 2, 3) 和 (1, 3) 关系为 None,因为它们从前往后的key 不一样 | |||
| 例如 (1, ) 和 (1, 2, 3) 关系为 -2,表示后者比前者长 2 位 | |||
| 但 例如 (2, ) 和 (1, 2, 3) 关系为 None,因为它们从前往后的key 不一样 | |||
| 例如 (1, 3) 和 (1, 2, 3) 关系为 None,因为它们从前往后的key 不一样 | |||
| :param t1: | |||
| :param t2: | |||
| :return: None 没有关系; 0 两者完全一样; >0 t1比t2长,<0 t2比t1长 | |||
| """ | |||
| if t1 == t2: | |||
| return 0 | |||
| for _t1, _t2 in zip(t1, t2): # 会按照最短的计算 | |||
| if _t1 != _t2: | |||
| return None | |||
| return len(t1) - len(t2) | |||
| @@ -3,7 +3,7 @@ from functools import reduce | |||
| from typing import Sequence, Mapping, Dict | |||
| class MappingPackerUnPacker: | |||
| class MappingPackerUnpacker: | |||
| @staticmethod | |||
| def unpack_batch(batch:Sequence[Mapping], ignore_fields:set, input_fields:Dict)->Dict: | |||
| """ | |||
| @@ -53,8 +53,9 @@ class NestedMappingPackerUnpacker: | |||
| @staticmethod | |||
| def pack_batch(batch): | |||
| if len(batch) == 0: | |||
| return [] | |||
| dicts = [] | |||
| for key, value in batch.items(): | |||
| if not isinstance(key, tuple): | |||
| key = [key] | |||
| @@ -65,30 +66,38 @@ class NestedMappingPackerUnpacker: | |||
| return reduce(_merge_dict, dicts) | |||
| class | |||
| class SequencePackerUnpacker: | |||
| @staticmethod | |||
| def unpack_batch(batch:Sequence[Sequence], ignore_fields, input_fields)->Dict: | |||
| """ | |||
| 将 Sequence[Sequence] 转为 Mapping 。例如 [[[1, 2], 2], [[3], 2]] -> {'_0': [[1, 2], [3]], '_1': [1, 2]} | |||
| :param batch: | |||
| :param ignore_fields: 需要忽略的field | |||
| :return: | |||
| """ | |||
| dict_batch = defaultdict(list) | |||
| for sample in batch: | |||
| for i, content in enumerate(sample): | |||
| field_name = f'_{i}' | |||
| if field_name in ignore_fields: | |||
| continue | |||
| dict_batch[field_name].append(content) | |||
| return dict_batch | |||
| def unpack_batch_nested_mapping(batch:Sequence[Mapping], ignore_fields:set, stop_deep_fields:set)->Dict: | |||
| """ | |||
| 将 nested 的 dict 中的内容展开到一个 flat dict 中 | |||
| @staticmethod | |||
| def pack_batch(batch): | |||
| return list(batch.values()) | |||
| :param batch: | |||
| :param ignore_fields: 需要忽略的 field 。 | |||
| :param stop_deep_fields: 不需要继续往下衍射的 | |||
| :return: | |||
| """ | |||
| dict_batch = defaultdict(list) | |||
| for sample in batch: | |||
| for key, value in sample.items(): | |||
| if key in ignore_fields: | |||
| continue | |||
| if isinstance(value, Mapping) and key not in stop_deep_fields: | |||
| _dict_batch = _unpack_batch_nested_mapping(value, ignore_fields, stop_deep_fields, _parent=(key,)) | |||
| for key, value in _dict_batch.items(): | |||
| dict_batch[key].append(value) | |||
| else: | |||
| dict_batch[key].append(value) | |||
| return dict_batch | |||
| class SinglePackerUnpacker: | |||
| @staticmethod | |||
| def unpack_batch(batch:Sequence[Sequence], ignore_fields, input_fields): | |||
| return {'_single': batch} | |||
| @staticmethod | |||
| def pack_batch(batch): | |||
| return batch['_single'] | |||
| def _unpack_batch_nested_mapping(value, ignore_fields, stop_deep_fields, _parent)->Dict: | |||
| @@ -136,25 +145,3 @@ def _merge_dict(a, b, path=None): | |||
| else: | |||
| a[key] = b[key] | |||
| return a | |||
| def unpack_batch_sequence(batch:Sequence[Sequence], ignore_fields)->Dict: | |||
| """ | |||
| 将 Sequence[Sequence] 转为 Mapping 。例如 [[[1, 2], 2], [[3], 2]] -> {'_0': [[1, 2], [3]], '_1': [1, 2]} | |||
| :param batch: | |||
| :param ignore_fields: 需要忽略的field | |||
| :return: | |||
| """ | |||
| dict_batch = defaultdict(list) | |||
| for sample in batch: | |||
| for i, content in enumerate(sample): | |||
| field_name = f'_{i}' | |||
| if field_name in ignore_fields: | |||
| continue | |||
| dict_batch[field_name].append(content) | |||
| return dict_batch | |||
| def pack_batch_sequence(batch:Mapping)->Sequence: | |||
| return list(batch.values()) | |||
| @@ -112,16 +112,19 @@ class TorchTensorPadder(Padder): | |||
| @staticmethod | |||
| def pad(batch_field, pad_val, dtype): | |||
| device = None | |||
| try: | |||
| if not isinstance(batch_field[0], torch.Tensor): | |||
| batch_field = [torch.tensor(field.tolist(), dtype=dtype) for field in batch_field] | |||
| else: | |||
| device = batch_field[0].device | |||
| except AttributeError: | |||
| raise RuntimeError(f"If the field is not a torch.Tensor (it is {type(batch_field[0])}), " | |||
| f"it must have tolist() method.") | |||
| shapes = [field.shape for field in batch_field] | |||
| max_shape = [len(batch_field)] + [max(*_) for _ in zip(*shapes)] | |||
| tensor = torch.full(max_shape, fill_value=pad_val, dtype=dtype) | |||
| tensor = torch.full(max_shape, fill_value=pad_val, dtype=dtype, device=device) | |||
| for i, field in enumerate(batch_field): | |||
| slices = (i, ) + tuple(slice(0, s) for s in shapes[i]) | |||
| tensor[slices] = field | |||
| @@ -221,7 +221,7 @@ class Evaluator: | |||
| @evaluate_batch_loop.setter | |||
| def evaluate_batch_loop(self, loop: Loop): | |||
| if self.evaluate_batch_step_fn is not None: | |||
| logger.warning("`evaluate_batch_step_fn` was customized in the Evaluator initialization, it will be ignored " | |||
| logger.rank_zero_warning("`evaluate_batch_step_fn` was customized in the Evaluator initialization, it will be ignored " | |||
| "when the `evaluate_batch_loop` is also customized.") | |||
| self._evaluate_batch_loop = loop | |||
| @@ -7,6 +7,7 @@ from dataclasses import is_dataclass | |||
| import os | |||
| from pathlib import Path | |||
| import io | |||
| import inspect | |||
| __all__ = [ | |||
| 'Trainer', | |||
| @@ -62,8 +63,8 @@ class Trainer(TrainerEventTrigger): | |||
| ): | |||
| r""" | |||
| `Trainer` 是 fastNLP 用于训练模型的专门的训练器,其支持多种不同的驱动模式,不仅包括最为经常使用的 DDP,而且还支持 jittor 等国产 | |||
| 的训练框架;新版的 fastNLP 新加入了方便的 callback 函数修饰器,并且支持定制用户自己特定的训练循环过程;通过使用该训练器,用户只需 | |||
| 要自己实现模型部分,而将训练层面的逻辑完全地交给 fastNLP; | |||
| 的训练框架;新版的 fastNLP 新加入了方便的 callback 函数修饰器,并且支持定制用户自己特定的训练循环过程;通过使用该训练器,用户只需 | |||
| 要自己实现模型部分,而将训练层面的逻辑完全地交给 fastNLP; | |||
| :param model: 训练所需要的模型,目前支持 pytorch; | |||
| :param driver: 训练模型所使用的具体的驱动模式,应当为以下选择中的一个:["torch", "torch_ddp", ],之后我们会加入 jittor、paddle | |||
| @@ -305,7 +306,7 @@ class Trainer(TrainerEventTrigger): | |||
| else: | |||
| if self.driver.is_distributed(): | |||
| if catch_KeyboardInterrupt: | |||
| logger.warning("Parameter `catch_KeyboardInterrupt` can only be False when you are using multi-device " | |||
| logger.rank_zero_warning("Parameter `catch_KeyboardInterrupt` can only be False when you are using multi-device " | |||
| "driver. And we are gonna to set it to False.") | |||
| catch_KeyboardInterrupt = False | |||
| @@ -535,7 +536,7 @@ class Trainer(TrainerEventTrigger): | |||
| _not_called_callback_fns.append(each_callback_fn) | |||
| if check_mode: | |||
| logger.warning("You have customized your 'batch_step_fn' in the 'train_batch_loop' and also use these " | |||
| logger.rank_zero_warning("You have customized your 'batch_step_fn' in the 'train_batch_loop' and also use these " | |||
| f"callback_fns: {_not_called_callback_fns}, but it seems that" | |||
| "you don't call the corresponding callback hook explicitly in your 'batch_step_fn'.") | |||
| # 对于 'batch_step_fn' 来讲,其只需要在第一次的 step 后进行检测即可,因此在第一次检测后将 check_batch_step_fn 置为 pass | |||
| @@ -56,12 +56,12 @@ class TrainerState: | |||
| 我们保存的state大部分上是 trainer 断点重训 需要重新加载的; | |||
| 专属于 `Trainer` 的状态记载的类; | |||
| n_epochs: 训练过程中总共的 epoch 的数量; | |||
| cur_epoch_idx: 当前正在运行第几个 epoch; | |||
| global_forward_batches: 当前模型总共 forward 了多少个 step; | |||
| batch_idx_in_epoch: 训练中在当前 epoch 的第几个 step; | |||
| num_batches_per_epoch: 每一个 epoch 会 forward 多少个 step; | |||
| total_batches: 完整训练过程会 forward 的 step 数量,注意 total_batches = total_batches * n_epochs; | |||
| :param n_epochs: 训练过程中总共的 epoch 的数量; | |||
| :param cur_epoch_idx: 当前正在运行第几个 epoch; | |||
| :param global_forward_batches: 当前模型总共 forward 了多少个 step; | |||
| :param batch_idx_in_epoch: 训练中在当前 epoch 的第几个 step; | |||
| :param num_batches_per_epoch: 每一个 epoch 会 forward 多少个 step; | |||
| :param total_batches: 完整训练过程会 forward 的 step 数量,注意 total_batches = total_batches * n_epochs; | |||
| """ | |||
| n_epochs: Optional[int] = None # 无论如何重新算 | |||
| @@ -139,7 +139,7 @@ class JittorDataLoader: | |||
| def set_ignore(self, *field_names) -> "JittorDataLoader": | |||
| """ | |||
| 如果有的内容不希望输出,可以在此处进行设置,被设置的 field 将在 batch 的输出中被忽略。 | |||
| Ex:: | |||
| Example:: | |||
| collator.set_ignore('field1', 'field2') | |||
| :param field_names: 需要忽略的 field 的名称。如果 Dataset 的 __getitem__ 方法返回的是 dict 类型的,则可以直接使用对应的 | |||
| @@ -143,7 +143,7 @@ class PaddleDataLoader(DataLoader): | |||
| def set_ignore(self, *field_names) -> Collator: | |||
| """ | |||
| 如果有的内容不希望输出,可以在此处进行设置,被设置的 field 将在 batch 的输出中被忽略。 | |||
| Ex:: | |||
| Example:: | |||
| collator.set_ignore('field1', 'field2') | |||
| :param field_names: 需要忽略的 field 的名称。如果 Dataset 的 __getitem__ 方法返回的是 dict 类型的,则可以直接使用对应的 | |||
| @@ -152,7 +152,7 @@ class TorchDataLoader(DataLoader): | |||
| def set_ignore(self, *field_names) -> Collator: | |||
| """ | |||
| 如果有的内容不希望输出,可以在此处进行设置,被设置的 field 将在 batch 的输出中被忽略。 | |||
| Ex:: | |||
| Example:: | |||
| collator.set_ignore('field1', 'field2') | |||
| :param field_names: 需要忽略的 field 的名称。如果 Dataset 的 __getitem__ 方法返回的是 dict 类型的,则可以直接使用对应的 | |||
| @@ -9,22 +9,18 @@ __all__ = [ | |||
| import _pickle as pickle | |||
| from copy import deepcopy | |||
| from typing import Optional, List, Callable, Union, Dict, Any, Mapping | |||
| from functools import partial | |||
| from types import LambdaType | |||
| import sys | |||
| import time | |||
| import numpy as np | |||
| from threading import Thread | |||
| try: | |||
| import multiprocessing as mp | |||
| except: | |||
| pass | |||
| from .field import FieldArray | |||
| from .instance import Instance | |||
| from fastNLP.core.utils.utils import pretty_table_printer | |||
| from fastNLP.core.utils.utils import pretty_table_printer, deprecated | |||
| from fastNLP.core.collators import Collator | |||
| from fastNLP.core.utils.rich_progress import f_rich_progress | |||
| from fastNLP.core.log import logger | |||
| from ..log import logger | |||
| class ApplyResultException(Exception): | |||
| @@ -35,14 +31,13 @@ class ApplyResultException(Exception): | |||
| def _apply_single(ds=None, _apply_field=None, func: Optional[Callable] = None, show_progress_bar: bool = True, | |||
| pipe=None, desc: str = None) -> list: | |||
| desc: str = None) -> list: | |||
| """ | |||
| 对数据集进行处理封装函数,以便多进程使用 | |||
| :param ds: 数据集 | |||
| :param _apply_field: 需要处理数据集的field_name | |||
| :param func: 用户自定义的func | |||
| :param pipe: 管道 | |||
| :param desc: 进度条的描述字符 | |||
| :param show_progress_bar: 是否展示子进程进度条 | |||
| :return: | |||
| @@ -60,8 +55,6 @@ def _apply_single(ds=None, _apply_field=None, func: Optional[Callable] = None, s | |||
| results.append(func(ins[_apply_field])) | |||
| else: | |||
| results.append(func(ins)) | |||
| if pipe is not None: | |||
| pipe.send([idx + 1]) | |||
| if show_progress_bar: | |||
| f_rich_progress.update(pg_main, advance=1) | |||
| @@ -75,31 +68,36 @@ def _apply_single(ds=None, _apply_field=None, func: Optional[Callable] = None, s | |||
| return results | |||
| def _progress_bar(parent, total_len: int, desc: str = None, show_progress_bar: bool = True) -> None: | |||
| def _multi_proc(ds, _apply_field, func, counter, queue): | |||
| """ | |||
| 多进程下显示主进程的进度条 | |||
| 对数据集进行处理封装函数,以便多进程使用 | |||
| :param parent: 进程管道 | |||
| :param total_len: 数据集总长度 | |||
| :param desc: 进度条描述符 | |||
| :param show_progress_bar: 是否展示进度条 | |||
| :param ds: 数据集 | |||
| :param _apply_field: 需要处理数据集的field_name | |||
| :param func: 用户自定义的func | |||
| :param counter: 计数器 | |||
| :param queue: 多进程时,将结果输入到这个 queue 中 | |||
| :return: | |||
| """ | |||
| desc = desc if desc else "Main" | |||
| main_pro = f_rich_progress.add_task(description=desc, total=total_len, visible=show_progress_bar) | |||
| # pb_main = tqdm(total=total_len, desc=desc, position=0) | |||
| nums = 0 | |||
| while True: | |||
| msg = parent.recv()[0] | |||
| if msg is not None: | |||
| f_rich_progress.update(main_pro, advance=1) | |||
| nums += 1 | |||
| if nums == total_len: | |||
| break | |||
| f_rich_progress.destroy_task(main_pro) | |||
| # pb_main.close() | |||
| idx = -1 | |||
| import contextlib | |||
| with contextlib.redirect_stdout(None): # 避免打印触发 rich 的锁 | |||
| logger.set_stdout(stdout='raw') | |||
| results = [] | |||
| try: | |||
| for idx, ins in enumerate(ds): | |||
| if _apply_field is not None: | |||
| res = func(ins[_apply_field]) | |||
| else: | |||
| res = func(ins) | |||
| results.append(res) | |||
| with counter.get_lock(): | |||
| counter.value += 1 | |||
| except BaseException as e: | |||
| if idx != -1: | |||
| logger.error("Exception happens at the `{}`th instance.".format(idx)) | |||
| raise e | |||
| queue.put(pickle.dumps(results)) | |||
| class DataSet: | |||
| @@ -114,7 +112,7 @@ class DataSet: | |||
| 每个元素应该为具有相同field的 :class:`~fastNLP.Instance` 。 | |||
| """ | |||
| self.field_arrays = {} | |||
| self._collator = Collator(backend="numpy") | |||
| self._collator = Collator() | |||
| if data is not None: | |||
| if isinstance(data, Dict): | |||
| length_set = set() | |||
| @@ -127,7 +125,6 @@ class DataSet: | |||
| for ins in data: | |||
| assert isinstance(ins, Instance), "Must be Instance type, not {}.".format(type(ins)) | |||
| self.append(ins) | |||
| else: | |||
| raise ValueError("data only be dict or list type.") | |||
| @@ -263,7 +260,7 @@ class DataSet: | |||
| try: | |||
| self.field_arrays[name].append(field) | |||
| except Exception as e: | |||
| print(f"Cannot append to field:{name}.") | |||
| logger.error(f"Cannot append to field:{name}.") | |||
| raise e | |||
| def add_fieldarray(self, field_name: str, fieldarray: FieldArray) -> None: | |||
| @@ -399,17 +396,17 @@ class DataSet: | |||
| raise KeyError("DataSet has no field named {}.".format(field_name)) | |||
| return self | |||
| def apply_field(self, func: Union[Callable], field_name: str = None, | |||
| def apply_field(self, func: Callable, field_name: str = None, | |||
| new_field_name: str = None, num_proc: int = 0, | |||
| progress_desc: str = None, show_progress_bar: bool = True): | |||
| r""" | |||
| 将 DataSet 中的每个 instance 中的名为 `field_name` 的 field 传给 func,并获取它的返回值。 | |||
| :param num_proc: 进程的数量 | |||
| :param field_name: 传入 func 的是哪个 field。 | |||
| :param func: input是 instance 中名为 `field_name` 的 field 的内容。 | |||
| :param new_field_name: 将 func 返回的内容放入到 `new_field_name` 这个 field 中,如果名称与已有的 field 相同,则覆 | |||
| 盖之前的 field。如果为 None 则不创建新的 field。 | |||
| :param num_proc: 进程的数量。请注意,由于python语言的特性,多少进程就会导致多少倍内存的增长。 | |||
| :param progress_desc: progress_desc 的值,默认为 Main | |||
| :param show_progress_bar: 是否展示进度条,默认展示进度条 | |||
| """ | |||
| @@ -435,13 +432,13 @@ class DataSet: | |||
| func 可以返回一个或多个 field 上的结果。 | |||
| .. note:: | |||
| ``apply_field_more`` 与 ``apply_field`` 的区别参考 :method:`~fastNLP.DataSet.apply_more` 中关于 ``apply_more`` 与 | |||
| ``apply_field_more`` 与 ``apply_field`` 的区别参考 :meth:`~fastNLP.DataSet.apply_more` 中关于 ``apply_more`` 与 | |||
| ``apply`` 区别的介绍。 | |||
| :param num_proc: 进程的数量 | |||
| :param field_name: 传入func的是哪个field。 | |||
| :param func: 参数是 ``DataSet`` 中的 ``Instance`` ,返回值是一个字典,key 是field 的名字,value 是对应的结果 | |||
| :param modify_fields: 是否用结果修改 `DataSet` 中的 `Field`, 默认为 True | |||
| :param num_proc: 进程的数量。请注意,由于python语言的特性,多少进程就会导致多少倍内存的增长。 | |||
| :param show_progress_bar: 是否显示进度条,默认展示 | |||
| :param progress_desc: 当show_progress_bar为True时,可以显示当前正在处理的进度条描述字符 | |||
| :return Dict[str:Field]: 返回一个字典 | |||
| @@ -469,9 +466,7 @@ class DataSet: | |||
| except Exception as e: | |||
| if idx != -1: | |||
| if isinstance(e, ApplyResultException): | |||
| print(e.msg) | |||
| print("Exception happens at the `{}`th instance.".format(idx + 1)) | |||
| logger.error("Exception happens at the `{}`th instance.".format(idx + 1)) | |||
| raise e | |||
| if modify_fields is True: | |||
| @@ -484,24 +479,25 @@ class DataSet: | |||
| show_progress_bar: bool = True, _apply_field: str = None, | |||
| progress_desc: str = 'Main') -> list: | |||
| """ | |||
| :param num_proc: 进程的数量 | |||
| :param num_proc: 进程的数量。请注意,由于python语言的特性,多少进程就会导致多少倍内存的增长。 | |||
| :param func: 用户自定义处理函数,参数是 ``DataSet`` 中的 ``Instance`` | |||
| :param _apply_field: 需要传进去func的数据集的field_name | |||
| :param show_progress_bar: 是否展示progress进度条,默认为展示 | |||
| :param progress_desc: 进度条的描述字符,默认为'Main | |||
| """ | |||
| if isinstance(func, LambdaType) and num_proc>1 and func.__name__ == "<lambda>": | |||
| raise ("Lambda function does not support multiple processes, please set `num_proc=0`.") | |||
| if num_proc>1 and sys.platform in ('win32', 'msys', 'cygwin'): | |||
| raise RuntimeError("Your platform does not support multiprocessing with fork, please set `num_proc=0`") | |||
| if num_proc == 0: | |||
| if num_proc < 2: | |||
| results = _apply_single(ds=self, _apply_field=_apply_field, func=func, | |||
| desc=progress_desc, show_progress_bar=show_progress_bar) | |||
| else: | |||
| # TODO 1. desc这个需要修改一下,应该把 subprocess 的 desc 修改一下。修改成Process 1 / Process 2 | |||
| results = [] | |||
| if num_proc > len(self): | |||
| num_proc = len(self) | |||
| print( | |||
| f"num_proc must be <= {len(self)}. Reducing num_proc to {num_proc} for dataset of size {len(self)}." | |||
| ) | |||
| import multiprocessing as mp | |||
| ctx = mp.get_context('fork') | |||
| num_proc = min(num_proc, len(self)) | |||
| # 划分数据集 | |||
| shard_len = len(self) // num_proc | |||
| num_left_sample = len(self) % num_proc | |||
| @@ -511,24 +507,32 @@ class DataSet: | |||
| end = shard_len + int(_i<num_left_sample) + start | |||
| shard_data.append(self[start:end]) | |||
| start = end | |||
| # 配置管道,线程以实现 main progress 能够实时更新。 | |||
| parent, child = mp.Pipe() | |||
| main_thread = Thread(target=_progress_bar, args=(parent, len(self), progress_desc, | |||
| show_progress_bar)) | |||
| partial_single_map = partial(_apply_single, _apply_field=_apply_field, func=func, | |||
| pipe=child, show_progress_bar=False) | |||
| # 开启进程池,线程 | |||
| main_thread.start() | |||
| pool = mp.Pool(processes=num_proc) | |||
| pool_outs = [pool.apply_async(partial_single_map, kwds={'ds': ds}) | |||
| for proc_id, ds in enumerate(shard_data)] | |||
| pool.close() | |||
| pool.join() | |||
| main_thread.join() | |||
| for async_result in pool_outs: | |||
| data = async_result.get() | |||
| results.extend(data) | |||
| # 配置共享参数,线程以实现 main progress 能够实时更新。 | |||
| counter = ctx.Value('i', 0, lock=True) | |||
| pool = [] | |||
| queues = [] | |||
| results = [] | |||
| for i in range(num_proc): | |||
| queue = ctx.SimpleQueue() | |||
| proc = ctx.Process(target=_multi_proc, args=(shard_data[i], _apply_field, func, counter, queue)) | |||
| proc.start() | |||
| pool.append(proc) | |||
| queues.append(queue) | |||
| total_len = len(self) | |||
| task_id = f_rich_progress.add_task(description=progress_desc, total=total_len, visible=show_progress_bar) | |||
| last_count = -1 | |||
| while counter.value < total_len or last_count == -1: | |||
| while counter.value == last_count: | |||
| time.sleep(0.1) | |||
| advance = counter.value - last_count | |||
| last_count = counter.value | |||
| f_rich_progress.update(task_id, advance=advance, refresh=True) | |||
| for idx, proc in enumerate(pool): | |||
| results.extend(pickle.loads(queues[idx].get())) | |||
| proc.join() | |||
| f_rich_progress.destroy_task(task_id) | |||
| return results | |||
| def apply_more(self, func: Callable = None, modify_fields: bool = True, | |||
| @@ -548,12 +552,11 @@ class DataSet: | |||
| :param modify_fields: 是否用结果修改 ``DataSet`` 中的 ``Field`` , 默认为 True | |||
| :param func: 参数是 ``DataSet`` 中的 ``Instance`` ,返回值是一个字典,key 是field 的名字,value 是对应的结果 | |||
| :param num_proc: 进程的数量 | |||
| :param show_progress_bar: 是否使用tqd显示预处理进度 | |||
| :param num_proc: 进程的数量。请注意,由于python语言的特性,多少进程就会导致多少倍内存的增长。 | |||
| :param progress_desc: 当show_progress_bar为True时,可以显示当前正在处理的进度条名称 | |||
| :return Dict[str:Field]: 返回一个字典 | |||
| """ | |||
| # 返回 dict , 检查是否一直相同 | |||
| assert callable(func), "The func you provide is not callable." | |||
| assert callable(func), "The func is not callable." | |||
| assert len(self) != 0, "Null DataSet cannot use apply()." | |||
| assert num_proc >= 0, "num_proc must >= 0" | |||
| idx = -1 | |||
| @@ -577,9 +580,7 @@ class DataSet: | |||
| except Exception as e: | |||
| if idx != -1: | |||
| if isinstance(e, ApplyResultException): | |||
| print(e.msg) | |||
| print("Exception happens at the `{}`th instance.".format(idx + 1)) | |||
| logger.error("Exception happens at the `{}`th instance.".format(idx + 1)) | |||
| raise e | |||
| if modify_fields is True: | |||
| @@ -595,7 +596,7 @@ class DataSet: | |||
| :param func: 参数是 ``DataSet`` 中的 ``Instance`` ,返回值是一个字典,key 是field 的名字,value 是对应的结果 | |||
| :param new_field_name: 将func返回的内容放入到 `new_field_name` 这个field中,如果名称与已有的field相同,则覆 | |||
| 盖之前的field。如果为None则不创建新的field。 | |||
| :param num_proc: 进程的数量。 | |||
| :param num_proc: 进程的数量。请注意,由于python语言的特性,多少进程就会导致多少倍内存的增长。 | |||
| :param show_progress_bar: 是否显示进度条。 | |||
| :param progress_desc: progress bar 显示的值,默认为空。 | |||
| """ | |||
| @@ -665,8 +666,7 @@ class DataSet: | |||
| np.random.shuffle(all_indices) | |||
| split = int(ratio * len(self)) | |||
| if split == 0: | |||
| error_msg = f'Dev DataSet has {split} instance after split.' | |||
| print(error_msg) | |||
| error_msg = f'Dev DataSet has `{split}` instance after split.' | |||
| raise IndexError(error_msg) | |||
| dev_indices = all_indices[:split] | |||
| train_indices = all_indices[split:] | |||
| @@ -776,35 +776,3 @@ class DataSet: | |||
| if self._collator is None: | |||
| self._collator = Collator() | |||
| return self._collator | |||
| if __name__ == '__main__': | |||
| # from fastNLP import DataSet | |||
| # if __name__=='__main__': | |||
| # data = DataSet({'x': ['xxxxas1w xw zxw xz', 'xxxxas1w xw zxw xz'] * 100, 'y': [0, 1] * 100}) | |||
| # data.apply_field(lambda x: len(x), field_name='x', new_field_name='len_x', num_proc=2, show_progress_bar=True) | |||
| import multiprocess as mp | |||
| # from fastNLP.core.dataset.dataset import _apply_single, _progress_bar | |||
| from functools import partial | |||
| from threading import Thread | |||
| shard_data = [DataSet({'x': ['xxxxas1w xw zxw xz', 'xxxxas1w xw zxw xz'] * 100, 'y': [0, 1] * 100}), | |||
| DataSet({'x': ['xxxxas1w xw zxw xz', 'xxxxas1w xw zxw xz'] * 100, 'y': [0, 1] * 100})] | |||
| parent, chid = mp.Pipe() | |||
| partial_single_map = partial(_apply_single, _apply_field='x', func=lambda x: len(x), | |||
| pipe=chid, show_progress_bar=False) | |||
| thread = Thread(target=_progress_bar, args=(parent, 400, 'main')) | |||
| thread.start() | |||
| pool = mp.Pool(processes=6) | |||
| pool_outs = [pool.apply_async(partial_single_map, kwds={'ds': ds}) | |||
| for proc_id, ds in enumerate(shard_data)] | |||
| pool.close() | |||
| pool.join() | |||
| thread.join() | |||
| results = [] | |||
| for async_result in pool_outs: | |||
| data = async_result.get() | |||
| results.extend(data) | |||
| print(results) | |||
| @@ -17,6 +17,8 @@ class Instance(Mapping): | |||
| Instance是fastNLP中对应一个sample的类。每个sample在fastNLP中是一个Instance对象。 | |||
| Instance一般与 :class:`~fastNLP.DataSet` 一起使用, Instance的初始化如下面的Example所示:: | |||
| instance = Instance() # 请补充完整 | |||
| """ | |||
| def __init__(self, **fields): | |||
| @@ -49,8 +49,8 @@ class Driver(ABC): | |||
| 不同 gpu 上出现重复;为 'unrepeatdist' 时,表示该 dataloader 应该保证所有 gpu 上迭代出来的数据合并起来应该刚好等于原始的 | |||
| 数据,允许不同 gpu 上 batch 的数量不一致。其中 trainer 中 kwargs 的参数 `use_dist_sampler` 为 True 时,该值为 "dist"; | |||
| 否则为 None ,evaluator 中的 kwargs 的参数 `use_dist_sampler` 为 True 时,该值为 "unrepeatdist",否则为 None; | |||
| 注意当 dist 为 ReproducibleSampler, ReproducibleBatchSampler 时,是断点重训加载时 driver.load 函数在调用; | |||
| 当 dist 为 str 或者 None 时,是 trainer 在初始化时调用该函数; | |||
| 注意当 dist 为 ReproducibleSampler, ReproducibleBatchSampler 时,是断点重训加载时 driver.load 函数在调用; | |||
| 当 dist 为 str 或者 None 时,是 trainer 在初始化时调用该函数; | |||
| :param reproducible: 如果为 False ,不要做任何考虑;如果为 True ,需要保证返回的 dataloader 可以保存当前的迭代状态,使得 | |||
| 可以可以加载。 | |||
| @@ -66,13 +66,13 @@ class Driver(ABC): | |||
| def set_deterministic_dataloader(self, dataloader): | |||
| r""" | |||
| 为了确定性训练要对 dataloader 进行修改,保证在确定随机数种子后,每次重新训练得到的结果是一样的;例如对于 torch 的 dataloader,其 | |||
| 需要将 worker_init_fn 替换; | |||
| 需要将 worker_init_fn 替换; | |||
| """ | |||
| def set_sampler_epoch(self, dataloader, cur_epoch_idx): | |||
| r""" | |||
| 对于分布式的 sampler,例如 torch 的 DistributedSampler,其需要在每一个 epoch 前设置随机数种子,来保证每一个进程上的 shuffle 是一样的; | |||
| dataloader 中可能真正发挥作用的是 batch_sampler 也可能是 sampler。 | |||
| dataloader 中可能真正发挥作用的是 batch_sampler 也可能是 sampler。 | |||
| :param dataloader: 需要设置 epoch 的 dataloader 。 | |||
| :param cur_epoch_idx: 当前是第几个 epoch; | |||
| @@ -101,17 +101,17 @@ class Driver(ABC): | |||
| 之所以设置该函数的目的在于希望将具体的 model_call function 从 driver 中抽离出来,然后将其附着在 Trainer 或者 Evaluator 身上; | |||
| 这样是因为在新版的设计中,使用 model 的哪种方法来进行 `train step` 或者 `evaluate step` 是通过额外的参数 `train_fn` 和 | |||
| `evaluate_fn` 来确定的,而二者又分别是通过 Trainer 和 Evaluator 来控制的;因此不能将确定具体的 `train step fn` 和 | |||
| `evaluate step fn` 的逻辑放在每一个 driver 的初始化的时候(因此在 Trainer 初始化第一个 driver 时,Evaluator 还没有初始化,但是 | |||
| `evaluate step fn` 的确定却需要 Evaluator 的初始化),因此我们将这一逻辑抽象到这一函数当中; | |||
| `evaluate_fn` 来确定的,而二者又分别是通过 Trainer 和 Evaluator 来控制的;因此不能将确定具体的 `train step fn` 和 | |||
| `evaluate step fn` 的逻辑放在每一个 driver 的初始化的时候(因此在 Trainer 初始化第一个 driver 时,Evaluator 还没有初始化,但是 | |||
| `evaluate step fn` 的确定却需要 Evaluator 的初始化),因此我们将这一逻辑抽象到这一函数当中; | |||
| 这一函数应当通过参数 `fn` 来判断应当返回的实际的调用的函数,具体逻辑如下所示: | |||
| 1. 如果 fn == "train_step" or "evaluate_step",那么对传入的模型进行检测,如果模型没有定义方法 `fn`,则默认调用模型的 `forward` | |||
| 函数,然后给出 warning; | |||
| 2. 如果 fn 是其他字符串,那么如果模型没有定义方法 `fn` 则直接报错; | |||
| 注意不同的 driver 需要做额外的检测处理,例如在 DDPDriver 中,当传入的模型本身就是 DistributedDataParallel 中,我们只能调用模型的 | |||
| forward 函数,因此需要额外的 warning;这一点特别需要注意的问题在于 driver 自己在 setup 时也会对模型进行改变(DDPDriver),因此 | |||
| 可能需要额外标记最初传入 driver 的模型是哪种形式的; | |||
| forward 函数,因此需要额外的 warning;这一点特别需要注意的问题在于 driver 自己在 setup 时也会对模型进行改变(DDPDriver),因此 | |||
| 可能需要额外标记最初传入 driver 的模型是哪种形式的; | |||
| :param fn: 应当为一个字符串,该函数通过该字符串判断要返回模型的哪种方法; | |||
| :return: 返回一个元组,包含两个函数,用于在调用 driver.model_call 时传入; | |||
| @@ -202,7 +202,7 @@ class Driver(ABC): | |||
| def get_model_no_sync_context(self): | |||
| r""" | |||
| 返回一个用于关闭多进程之间 model 中的自动互相同步操作的 context 上下文对象;只有多卡的 driver 需要单独实现该函数, | |||
| 单卡的 driver 不需要; | |||
| 单卡的 driver 不需要; | |||
| :return: 返回一个类似于 DistributedDataParallel(model).no_sync 的 context 上下文对象; | |||
| """ | |||
| @@ -273,7 +273,7 @@ class Driver(ABC): | |||
| def load(self, folder: Union[str, Path], dataloader, only_state_dict: bool =True, should_load_model: bool = True, **kwargs) -> Dict: | |||
| r""" | |||
| 断点重训的加载函数,注意该函数会负责读取数据,并且恢复 optimizers , fp16 的 state_dict 和 模型(根据 should_load_model )和; | |||
| 其它在 Driver.save() 函数中执行的保存操作,然后将一个 state 字典返回给 trainer ( 内容为Driver.save() 接受到的 states )。 | |||
| 其它在 Driver.save() 函数中执行的保存操作,然后将一个 state 字典返回给 trainer ( 内容为Driver.save() 接受到的 states )。 | |||
| 该函数应该在所有 rank 上执行。 | |||
| @@ -302,7 +302,7 @@ class Driver(ABC): | |||
| def tensor_to_numeric(tensor, reduce: Optional[str]=None): | |||
| r""" | |||
| 将一个 `tensor` 对象(仅处理当前 driver 使用的 tensor 即可)转换为 python 的 `numeric` 对象;如果 tensor 只包含一个 | |||
| 元素则返回 float 或 int 。 | |||
| 元素则返回 float 或 int 。 | |||
| :param tensor: 需要被转换的 `tensor` 对象 | |||
| :param reduce: 可选 ['sum', 'max', 'mea', 'min'],如果不为 None 将使用该 reduce 方法来处理当前 tensor 再返回 | |||
| @@ -323,7 +323,7 @@ class Driver(ABC): | |||
| """ | |||
| 保证用户拿到的模型一定是最原始的模型; | |||
| 注意因为我们把保存模型的主要逻辑和代码移到了 `Driver` 中,因此在 `save_model` 函数中,一定要先调用此函数来保证我们保存的模型一定是 | |||
| 最为原始的模型; | |||
| 最为原始的模型; | |||
| 需要注意用户本身传入的模型就是经过类似 `torch.nn.DataParallel` 或者 `torch.nn.parallel.DistributedDataParallel` 包裹的模型, | |||
| 因此在该函数内需要先判断模型的类别; | |||
| @@ -335,7 +335,7 @@ class Driver(ABC): | |||
| r""" | |||
| 用来将模型转移到指定的 device 上; | |||
| 之所以写成 `staticmethod`,是因为一方面在 `Driver` 中我们要使用 `unwrap_model` 来拿到最原始的模型,另一方面,在 `save_model` | |||
| 中,我们需要先将模型移到 cpu 后,又再移到 gpu 上,因此不适宜在该函数内部调用 `unwrap_model`,而是将 model 作为该函数的参数; | |||
| 中,我们需要先将模型移到 cpu 后,又再移到 gpu 上,因此不适宜在该函数内部调用 `unwrap_model`,而是将 model 作为该函数的参数; | |||
| """ | |||
| @abstractmethod | |||
| @@ -373,7 +373,7 @@ class Driver(ABC): | |||
| def on_exception(self): | |||
| """ | |||
| 该函数用于在训练或者预测过程中出现错误时正确地关掉其它的进程,这一点是通过在多进程 driver 调用 open_subprocess 的时候将每一个进程 | |||
| 的 pid 记录下来,然后在出现错误后,由出现错误的进程手动地将其它进程 kill 掉; | |||
| 的 pid 记录下来,然后在出现错误后,由出现错误的进程手动地将其它进程 kill 掉; | |||
| 因此,每一个多进程 driver 如果想要该函数能够正确地执行,其需要在自己的 open_subprocess(开启多进程的函数)中正确地记录每一个进程的 | |||
| pid 的信息; | |||
| @@ -399,7 +399,7 @@ class Driver(ABC): | |||
| def broadcast_object(self, obj, src:int=0, group=None, **kwargs): | |||
| """ | |||
| 从 src 端将 obj 对象(可能是 tensor ,可能是 object )broadcast 到其它所有进程。如果是非 tensor 的对象会尝试使用 pickle 进行打包进行 | |||
| 传输,然后再 dst 处再加载回来。仅在分布式的 driver 中有实际意义。 | |||
| 传输,然后再 dst 处再加载回来。仅在分布式的 driver 中有实际意义。 | |||
| :param obj: obj,可能是 Tensor 或 嵌套类型的数据 | |||
| :param int src: source 的 global rank 。 | |||
| @@ -415,7 +415,7 @@ class Driver(ABC): | |||
| def all_gather(self, obj, group)->List: | |||
| """ | |||
| 将 obj 互相传送到其它所有的 rank 上,其中 obj 可能是 Tensor,也可能是嵌套结构的 object 。如果不是基础类型的数据,尝试通过 | |||
| pickle 进行序列化,接收到之后再反序列化。 | |||
| pickle 进行序列化,接收到之后再反序列化。 | |||
| :param obj: 可以是 float/int/bool/np.ndarray/{}/[]/Tensor等。 | |||
| :param group: | |||
| @@ -171,7 +171,7 @@ def fastnlp_paddle_all_gather(obj: Any, device=None, group=None) ->List: | |||
| """ | |||
| 实现任何类型的数据都使用该接口可以进行 all_gather 操作。对于非 tensor 类型的数据,通过 pickle 序列化再反序列化的方式进行传输。 | |||
| example: | |||
| example:: | |||
| obj = { | |||
| 'a': [1, 1], | |||
| 'b': [[1, 2], [1, 2]], | |||
| @@ -379,13 +379,6 @@ class PaddleFleetDriver(PaddleDriver): | |||
| self._has_fleetwrapped = True | |||
| def on_exception(self): | |||
| """ | |||
| 该函数用于在训练或者预测过程中出现错误时正确地关掉其它的进程,这一点是通过在多进程 driver 调用 open_subprocess 的时候将每一个进程 | |||
| 的 pid 记录下来,然后在出现错误后,由出现错误的进程手动地将其它进程 kill 掉; | |||
| 因此,每一个多进程 driver 如果想要该函数能够正确地执行,其需要在自己的 open_subprocess(开启多进程的函数)中正确地记录每一个进程的 | |||
| pid 的信息; | |||
| """ | |||
| rank_zero_rm(self.gloo_rendezvous_dir) | |||
| super().on_exception() | |||
| @@ -420,17 +413,6 @@ class PaddleFleetDriver(PaddleDriver): | |||
| return self.model_device | |||
| def model_call(self, batch, fn: Callable, signature_fn: Optional[Callable]) -> Dict: | |||
| """ | |||
| 通过调用 `fn` 来实现训练时的前向传播过程; | |||
| 注意 Trainer 和 Evaluator 会调用该函数来实现网络的前向传播过程,其中传入该函数的参数 `fn` 是函数 `get_model_call_fn` 所返回的 | |||
| 函数; | |||
| :param batch: 当前的一个 batch 的数据;可以为字典或者其它类型; | |||
| :param fn: 调用该函数进行一次计算。 | |||
| :param signature_fn: 由 Trainer 传入的用于网络前向传播一次的签名函数,因为当 batch 是一个 Dict 的时候,我们会自动调用 auto_param_call | |||
| 函数,而一些被包裹的模型需要暴露其真正的函数签名,例如 DistributedDataParallel 的调用函数是 forward,但是需要其函数签名为 model.module.forward; | |||
| :return: 返回由 `fn` 返回的结果(应当为一个 dict 或者 dataclass,但是不需要我们去检查); | |||
| """ | |||
| if self._has_fleetwrapped: | |||
| return self.model(batch, fastnlp_fn=fn, fastnlp_signature_fn=signature_fn, | |||
| wo_auto_param_call=self.wo_auto_param_call) | |||
| @@ -441,27 +423,6 @@ class PaddleFleetDriver(PaddleDriver): | |||
| return fn(batch) | |||
| def get_model_call_fn(self, fn: str) -> Tuple: | |||
| """ | |||
| 该函数会接受 Trainer 的 train_fn 或者 Evaluator 的 evaluate_fn,返回一个实际用于调用 driver.model_call 时传入的函数参数; | |||
| 该函数会在 Trainer 和 Evaluator 在 driver.setup 函数之后调用; | |||
| 之所以设置该函数的目的在于希望将具体的 model_call function 从 driver 中抽离出来,然后将其附着在 Trainer 或者 Evaluator 身上; | |||
| 这样是因为在新版的设计中,使用 model 的哪种方法来进行 `train step` 或者 `evaluate step` 是通过额外的参数 `train_fn` 和 | |||
| `evaluate_fn` 来确定的,而二者又分别是通过 Trainer 和 Evaluator 来控制的;因此不能将确定具体的 `train step fn` 和 | |||
| `evaluate step fn` 的逻辑放在每一个 driver 的初始化的时候(因此在 Trainer 初始化第一个 driver 时,Evaluator 还没有初始化,但是 | |||
| `evaluate step fn` 的确定却需要 Evaluator 的初始化),因此我们将这一逻辑抽象到这一函数当中; | |||
| 这一函数应当通过参数 `fn` 来判断应当返回的实际的调用的函数,具体逻辑如下所示: | |||
| 1. 如果 fn == "train_step" or "evaluate_step",那么对传入的模型进行检测,如果模型没有定义方法 `fn`,则默认调用模型的 `forward` | |||
| 函数,然后给出 warning; | |||
| 2. 如果 fn 是其他字符串,那么如果模型没有定义方法 `fn` 则直接报错; | |||
| 注意不同的 driver 需要做额外的检测处理,例如在 DDPDriver 中,当传入的模型本身就是 DistributedDataParallel 中,我们只能调用模型的 | |||
| forward 函数,因此需要额外的 warning;这一点特别需要注意的问题在于 driver 自己在 setup 时也会对模型进行改变(DDPDriver),因此 | |||
| 可能需要额外标记最初传入 driver 的模型是哪种形式的; | |||
| :param fn: 应当为一个字符串,该函数通过该字符串判断要返回模型的哪种方法; | |||
| :return: 返回一个元组,包含两个函数,用于在调用 driver.model_call 时传入; | |||
| """ | |||
| model = self.unwrap_model() | |||
| if self._has_fleetwrapped: | |||
| if hasattr(model, fn): | |||
| @@ -487,24 +448,6 @@ class PaddleFleetDriver(PaddleDriver): | |||
| def set_dist_repro_dataloader(self, dataloader, dist: Optional[Union[str, ReproducibleSampler, ReproduceBatchSampler]], | |||
| reproducible: bool = False): | |||
| r""" | |||
| 根据输入的 dataloader 得到一个 支持分布式 (distributed) 与 可复现的 (reproducible) 的 dataloader。 | |||
| :param dataloader: 根据 dataloader 设置其对应的分布式版本以及可复现版本 | |||
| :param dist: 应当为一个字符串,其值应当为以下之一:[None, "dist", "unrepeatdist"];为 None 时,表示不需要考虑当前 dataloader | |||
| 切换为分布式状态;为 'dist' 时,表示该 dataloader 应该保证每个 gpu 上返回的 batch 的数量是一样多的,允许出现少量 sample ,在 | |||
| 不同 gpu 上出现重复;为 'unrepeatdist' 时,表示该 dataloader 应该保证所有 gpu 上迭代出来的数据合并起来应该刚好等于原始的 | |||
| 数据,允许不同 gpu 上 batch 的数量不一致。其中 trainer 中 kwargs 的参数 `use_dist_sampler` 为 True 时,该值为 "dist"; | |||
| 否则为 None ,evaluator 中的 kwargs 的参数 `use_dist_sampler` 为 True 时,该值为 "unrepeatdist",否则为 None; | |||
| 注意当 dist 为 ReproducibleSampler, ReproducibleBatchSampler 时,是断点重训加载时 driver.load 函数在调用; | |||
| 当 dist 为 str 或者 None 时,是 trainer 在初始化时调用该函数; | |||
| :param reproducible: 如果为 False ,不要做任何考虑;如果为 True ,需要保证返回的 dataloader 可以保存当前的迭代状态,使得 | |||
| 可以可以加载。 | |||
| :return: 应当返回一个被替换 sampler 后的新的 dataloader 对象 (注意此处一定需要返回一个新的 dataloader 对象) ;此外, | |||
| 如果传入的 dataloader 中是 ReproducibleSampler 或者 ReproducibleBatchSampler 需要重新初始化一个放入返回的 | |||
| dataloader 中。如果 dist 为空,且 reproducible 为 False,可直接返回原对象。 | |||
| """ | |||
| # 暂时不支持iterableDataset | |||
| assert dataloader.dataset_kind != _DatasetKind.ITER, \ | |||
| "FastNLP does not support `IteratorDataset` now." | |||
| @@ -619,43 +562,9 @@ class PaddleFleetDriver(PaddleDriver): | |||
| f"not {type(each_optimizer)}.") | |||
| def broadcast_object(self, obj, src:int=0, group=None, **kwargs): | |||
| """ | |||
| 从 src 端将 obj 对象(可能是 tensor ,可能是 object )发送到 dst 处。如果是非 tensor 的对象会尝试使用 pickle 进行打包进行 | |||
| 传输,然后再 dst 处再加载回来。仅在分布式的 driver 中有实际意义。 | |||
| :param obj: obj,可能是 Tensor 或 嵌套类型的数据 | |||
| :param int src: source 的 global rank 。 | |||
| :param int dst: target 的 global rank,可以是多个目标 rank | |||
| :param group: 所属的 group | |||
| :param kwargs: | |||
| :return: 如果当前不是分布式 driver 直接返回输入的 obj 。如果当前 rank 是接收端(其 global rank 包含在了 dst 中),则返回 | |||
| 接收到的参数;如果是 source 端则返回发射的内容;既不是发送端、又不是接收端,则返回 None 。 | |||
| """ | |||
| # 因为设置了CUDA_VISIBLE_DEVICES,可能会引起错误 | |||
| device = get_device_from_visible(self.data_device) | |||
| return fastnlp_paddle_broadcast_object(obj, src, device=device, group=group) | |||
| def all_gather(self, obj, group=None) -> List: | |||
| """ | |||
| 将 obj 互相传送到其它所有的 rank 上,其中 obj 可能是 Tensor,也可能是嵌套结构的 object 。如果不是基础类型的数据,尝试通过 | |||
| pickle 进行序列化,接收到之后再反序列化。 | |||
| example: | |||
| obj = { | |||
| 'a': [1, 1], | |||
| 'b': [[1, 2], [1, 2]], | |||
| 'c': { | |||
| 'd': [1, 2] | |||
| } | |||
| } | |||
| -> | |||
| [ | |||
| {'a': 1, 'b':[1, 2], 'c':{'d': 1}}, | |||
| {'a': 1, 'b':[1, 2], 'c':{'d': 2}} | |||
| ] | |||
| :param obj: 需要传输的对象,在每个rank上都应该保持相同的结构。 | |||
| :param group: | |||
| :return: | |||
| """ | |||
| return fastnlp_paddle_all_gather(obj, group=group) | |||
| @@ -69,7 +69,7 @@ def initialize_paddle_driver(driver: str, device: Optional[Union[str, int, List[ | |||
| if not isinstance(device, List): | |||
| return PaddleSingleDriver(model, device, **kwargs) | |||
| else: | |||
| logger.warning("Notice you are using `paddle` driver but your chosen `device` are multi gpus, we will use" | |||
| logger.rank_zero_warning("Notice you are using `paddle` driver but your chosen `device` are multi gpus, we will use" | |||
| "`Fleetriver` by default. But if you mean using `PaddleFleetDriver`, you should choose parameter" | |||
| "`driver` as `PaddleFleetDriver`.") | |||
| return PaddleFleetDriver(model, device, **kwargs) | |||
| @@ -77,7 +77,7 @@ def initialize_paddle_driver(driver: str, device: Optional[Union[str, int, List[ | |||
| if not isinstance(device, List): | |||
| if device == "cpu": | |||
| raise ValueError("You are using `fleet` driver, but your chosen `device` is 'cpu'.") | |||
| logger.warning("Notice you are using `fleet` driver, but your chosen `device` is only one gpu, we will" | |||
| logger.rank_zero_warning("Notice you are using `fleet` driver, but your chosen `device` is only one gpu, we will" | |||
| "still use `PaddleFleetDriver` for you, but if you mean using `PaddleSingleDriver`, you should " | |||
| "choose `paddle` driver.") | |||
| return PaddleFleetDriver(model, [device], **kwargs) | |||
| @@ -47,7 +47,7 @@ if _NEED_IMPORT_PADDLE: | |||
| class PaddleDriver(Driver): | |||
| r""" | |||
| Paddle框架的Driver,包括实现单卡训练的`PaddleSingleDriver`和分布式训练的`PaddleFleetDriver`。 | |||
| Paddle框架的Driver,包括实现单卡训练的 `PaddleSingleDriver` 和分布式训练的 `PaddleFleetDriver`。 | |||
| """ | |||
| def __init__(self, model, fp16: Optional[bool] = False, **kwargs): | |||
| if not isinstance(model, paddle.nn.Layer): | |||
| @@ -72,7 +72,7 @@ class PaddleDriver(Driver): | |||
| :param set_to_none: 用来判断是否需要将梯度直接置为 None;Paddle中这个参数无效。 | |||
| """ | |||
| if set_to_none: | |||
| logger.warning_once("Parameter `set_to_none` does nothing in paddle since grad cannot be set directly.") | |||
| logger.rank_zero_warning("Parameter `set_to_none` does nothing in paddle since grad cannot be set directly.") | |||
| for optimizer in self.optimizers: | |||
| optimizer.clear_grad() | |||
| @@ -131,8 +131,7 @@ class PaddleDriver(Driver): | |||
| @staticmethod | |||
| def tensor_to_numeric(tensor, reduce=None): | |||
| r""" | |||
| 将一个 `tensor` 对象(类型为 `paddle.Tensor` )转换为 python 的 `numeric` 对象;如果 tensor 只包含一个 | |||
| 元素则返回 float 或 int 。 | |||
| 将一个 `tensor` 对象(类型为 `paddle.Tensor` )转换为 python 的 `numeric` 对象;如果 tensor 只包含一个元素则返回 float 或 int 。 | |||
| :param tensor: 需要被转换的 `tensor` 对象 | |||
| :param reduce: 可选 ['sum', 'max', 'mea', 'min'],如果不为 None 将使用该 reduce 方法来处理当前 tensor 再返回 | |||
| @@ -158,11 +157,6 @@ class PaddleDriver(Driver): | |||
| ) | |||
| def set_model_mode(self, mode: str): | |||
| r""" | |||
| 设置模型为 `train` / `eval` 的模式;目的是为切换模型训练和推理(会关闭dropout等)模式; | |||
| :param mode: 应为二者之一:["train", "eval"]; | |||
| """ | |||
| assert mode in {"train", "eval"} | |||
| getattr(self.model, mode)() | |||
| @@ -179,7 +173,6 @@ class PaddleDriver(Driver): | |||
| 可以通过 InputSpec 或者示例 Tensor 进行描述。详细的可以参考 paddle 关于`paddle.jit.save` | |||
| 的文档: | |||
| https://www.paddlepaddle.org.cn/documentation/docs/zh/api/paddle/jit/save_cn.html#save | |||
| :return: | |||
| """ | |||
| model = self.unwrap_model() | |||
| if isinstance(filepath, Path): | |||
| @@ -196,12 +189,12 @@ class PaddleDriver(Driver): | |||
| def load_model(self, filepath: str, only_state_dict: bool = True, **kwargs): | |||
| r""" | |||
| 加载模型的函数;注意函数 `load` 是用来进行断点重训的函数; | |||
| 加载模型的函数;将 filepath 中的模型加载并赋值给当前 model 。 | |||
| :param filepath: 需要被加载的对象的文件位置(需要包括文件名); | |||
| :param only_state_dict: 是否加载state_dict,默认为True。 | |||
| :param kwargs: | |||
| :return: | |||
| :param load_state_dict: 保存的文件是否只是模型的权重,还是完整的模型。即便是保存的完整的模型,此处也只能使用尝试加载filepath | |||
| 模型中的权重到自身模型,而不会直接替代当前 Driver 中的模型。 | |||
| :return: 返回加载指定文件后的结果; | |||
| """ | |||
| model = self.unwrap_model() | |||
| if isinstance(filepath, Path): | |||
| @@ -216,22 +209,6 @@ class PaddleDriver(Driver): | |||
| @rank_zero_call | |||
| def save(self, folder: Path, states: Dict, dataloader, only_state_dict: bool = True, should_save_model: bool = True, **kwargs): | |||
| r""" | |||
| 断点重训的保存函数,该函数会负责保存模型和 optimizers 的 state_dict; | |||
| 需要注意 driver 应当是无状态的,即不管什么时候调用 driver 的接口函数,其返回的结果应该都是一样的;因此,断点重训不需要保存 driver | |||
| 本身自己的任何状态;而每一个 driver 实例需要在该函数中实现保存模型和 optimizers 的 state_dict 的逻辑;同时妥善存储传入的 | |||
| states 中的内容(主要用于恢复 Trainer ,Callback 等) | |||
| 需要保证该函数只在 global rank 0 上运行 | |||
| :param folder: 保存断点重训的状态的文件名; | |||
| :param states: 由 trainer 传入的一个字典,其中已经包含了为了实现断点重训所需要保存的其它对象的状态,Driver 应该只需要保存 | |||
| 该对象即可, Driver 应该不需要理解该对象,同时在 driver.load() 的时候,需要将 states 返回回去,load()返回的值与这里的 | |||
| 传入的值保持一致。 | |||
| :param dataloader: 正在使用的 dataloader,需要保存里面的状态使得之后可以从当前迭代的位置恢复。 | |||
| :param only_state_dict: 是否只保存模型的参数,当 should_save_model 为 False ,该参数无效。 | |||
| :param should_save_model: 是否应该保存模型,如果为False,Driver 将不负责 model 的保存。 | |||
| :return: | |||
| """ | |||
| # 传入的 dataloader 参数是 trainer 的 dataloader 属性,因为 driver 的所有 dataloader 我们是不会去改变它的,而是通过改变 | |||
| # trainer.dataloader 来改变 dataloader 的状态,从而适配训练或者评测环境; | |||
| @@ -256,7 +233,7 @@ class PaddleDriver(Driver): | |||
| if dataloader_args.batch_size is not None: | |||
| num_consumed_batches = num_consumed_batches * dataloader_args.batch_size | |||
| else: # 有可能 batch_size 为 None,就只有损失精度了 | |||
| logger.warning("fastNLP cannot get batch_size, we have to save based on `num_consumed_samples`, " | |||
| logger.rank_zero_warning("fastNLP cannot get batch_size, we have to save based on `num_consumed_samples`, " | |||
| "it may cause missing some samples when reload.") | |||
| num_consumed_batches = sampler_states['num_consumed_samples'] | |||
| sampler_states['num_consumed_samples'] = num_consumed_samples_array[num_consumed_batches] | |||
| @@ -266,7 +243,7 @@ class PaddleDriver(Driver): | |||
| sampler_states['num_consumed_samples'] = sampler.num_replicas * dataloader_args.batch_size \ | |||
| * num_consumed_batches | |||
| else: | |||
| logger.warning("fastNLP cannot get batch_size, we have to save based on `num_consumed_samples`, " | |||
| logger.rank_zero_warning("fastNLP cannot get batch_size, we have to save based on `num_consumed_samples`, " | |||
| "it may cause missing some samples when reload.") | |||
| else: | |||
| raise RuntimeError( | |||
| @@ -329,7 +306,7 @@ class PaddleDriver(Driver): | |||
| self.grad_scaler.load_state_dict(grad_scaler_state_dict) | |||
| logger.debug("Load grad_scaler state dict...") | |||
| elif not isinstance(self.grad_scaler, DummyGradScaler): | |||
| logger.warning(f"Checkpoint {folder} is not trained with fp16=True, while resume to a fp16=True training, " | |||
| logger.rank_zero_warning(f"Checkpoint {folder} is not trained with fp16=True, while resume to a fp16=True training, " | |||
| f"the training process may be unstable.") | |||
| # 4. 恢复 sampler 的状态; | |||
| @@ -422,19 +399,10 @@ class PaddleDriver(Driver): | |||
| random.seed(stdlib_seed) | |||
| def set_deterministic_dataloader(self, dataloader): | |||
| r""" | |||
| 为了确定性训练要对 dataloader 进行修改,保证在确定随机数种子后,每次重新训练得到的结果是一样的; | |||
| 作用是替换 datalaoder 的 `worker_init_fn`。 | |||
| """ | |||
| if int(os.environ.get(FASTNLP_SEED_WORKERS, 0)) and dataloader.worker_init_fn is None: | |||
| dataloader.worker_init_fn = partial(self.worker_init_function, rank=self.global_rank) | |||
| def set_sampler_epoch(self, dataloader: "DataLoader", cur_epoch_idx): | |||
| r""" | |||
| 对于分布式的 sampler,dataloader 需要在每一个 epoch 前设置随机数种子,来保证每一个进程上的 shuffle 是一样的; | |||
| :param cur_epoch_idx: 当前是第几个 epoch; | |||
| """ | |||
| if callable(getattr(dataloader.batch_sampler, "set_epoch", None)): | |||
| dataloader.batch_sampler.set_epoch(cur_epoch_idx) | |||
| @@ -38,7 +38,7 @@ class PaddleSingleDriver(PaddleDriver): | |||
| if isinstance(model, DataParallel): | |||
| raise ValueError("`paddle.DataParallel` is not supported in `PaddleSingleDriver`") | |||
| cuda_visible_devices = os.environ.get(USER_CUDA_VISIBLE_DEVICES, None) | |||
| cuda_visible_devices = os.getenv(USER_CUDA_VISIBLE_DEVICES) | |||
| if cuda_visible_devices is None: | |||
| raise RuntimeError("`USER_CUDA_VISIBLE_DEVICES` cannot be None, please check if you have set " | |||
| "`FASTNLP_BACKEND` to 'paddle' before using FastNLP.") | |||
| @@ -73,44 +73,12 @@ class PaddleSingleDriver(PaddleDriver): | |||
| self.model.to(device) | |||
| def model_call(self, batch, fn: Callable, signature_fn: Optional[Callable]) -> Dict: | |||
| """ | |||
| 通过调用 `fn` 来实现训练时的前向传播过程; | |||
| 注意 Trainer 和 Evaluator 会调用该函数来实现网络的前向传播过程,其中传入该函数的参数 `fn` 是函数 `get_model_call_fn` 所返回的 | |||
| 函数; | |||
| :param batch: 当前的一个 batch 的数据;可以为字典或者其它类型; | |||
| :param fn: 调用该函数进行一次计算。 | |||
| :param signature_fn: 由 Trainer 传入的用于网络前向传播一次的签名函数,因为当 batch 是一个 Dict 的时候,我们会自动调用 auto_param_call | |||
| 函数,而一些被包裹的模型需要暴露其真正的函数签名,例如 DistributedDataParallel 的调用函数是 forward,但是需要其函数签名为 model.module.forward; | |||
| :return: 返回由 `fn` 返回的结果(应当为一个 dict 或者 dataclass,但是不需要我们去检查); | |||
| """ | |||
| if isinstance(batch, Dict) and not self.wo_auto_param_call: | |||
| return auto_param_call(fn, batch, signature_fn=signature_fn) | |||
| else: | |||
| return fn(batch) | |||
| def get_model_call_fn(self, fn: str) -> Tuple: | |||
| """ | |||
| 该函数会接受 Trainer 的 train_fn 或者 Evaluator 的 evaluate_fn,返回一个实际用于调用 driver.model_call 时传入的函数参数; | |||
| 该函数会在 Trainer 和 Evaluator 在 driver.setup 函数之后调用; | |||
| 之所以设置该函数的目的在于希望将具体的 model_call function 从 driver 中抽离出来,然后将其附着在 Trainer 或者 Evaluator 身上; | |||
| 这样是因为在新版的设计中,使用 model 的哪种方法来进行 `train step` 或者 `evaluate step` 是通过额外的参数 `train_fn` 和 | |||
| `evaluate_fn` 来确定的,而二者又分别是通过 Trainer 和 Evaluator 来控制的;因此不能将确定具体的 `train step fn` 和 | |||
| `evaluate step fn` 的逻辑放在每一个 driver 的初始化的时候(因此在 Trainer 初始化第一个 driver 时,Evaluator 还没有初始化,但是 | |||
| `evaluate step fn` 的确定却需要 Evaluator 的初始化),因此我们将这一逻辑抽象到这一函数当中; | |||
| 这一函数应当通过参数 `fn` 来判断应当返回的实际的调用的函数,具体逻辑如下所示: | |||
| 1. 如果 fn == "train_step" or "evaluate_step",那么对传入的模型进行检测,如果模型没有定义方法 `fn`,则默认调用模型的 `forward` | |||
| 函数,然后给出 warning; | |||
| 2. 如果 fn 是其他字符串,那么如果模型没有定义方法 `fn` 则直接报错; | |||
| 注意不同的 driver 需要做额外的检测处理,例如在 DDPDriver 中,当传入的模型本身就是 DistributedDataParallel 中,我们只能调用模型的 | |||
| forward 函数,因此需要额外的 warning;这一点特别需要注意的问题在于 driver 自己在 setup 时也会对模型进行改变(DDPDriver),因此 | |||
| 可能需要额外标记最初传入 driver 的模型是哪种形式的; | |||
| :param fn: 应当为一个字符串,该函数通过该字符串判断要返回模型的哪种方法; | |||
| :return: 返回一个元组,包含两个函数,用于在调用 driver.model_call 时传入; | |||
| """ | |||
| if hasattr(self.model, fn): | |||
| fn = getattr(self.model, fn) | |||
| if not callable(fn): | |||
| @@ -125,24 +93,6 @@ class PaddleSingleDriver(PaddleDriver): | |||
| def set_dist_repro_dataloader(self, dataloader, dist: Union[str, ReproducibleBatchSampler, ReproducibleSampler]=None, | |||
| reproducible: bool = False): | |||
| r""" | |||
| 根据输入的 dataloader 得到一个 支持分布式 (distributed) 与 可复现的 (reproducible) 的 dataloader。 | |||
| :param dataloader: 根据 dataloader 设置其对应的分布式版本以及可复现版本 | |||
| :param dist: 应当为一个字符串,其值应当为以下之一:[None, "dist", "unrepeatdist"];为 None 时,表示不需要考虑当前 dataloader | |||
| 切换为分布式状态;为 'dist' 时,表示该 dataloader 应该保证每个 gpu 上返回的 batch 的数量是一样多的,允许出现少量 sample ,在 | |||
| 不同 gpu 上出现重复;为 'unrepeatdist' 时,表示该 dataloader 应该保证所有 gpu 上迭代出来的数据合并起来应该刚好等于原始的 | |||
| 数据,允许不同 gpu 上 batch 的数量不一致。其中 trainer 中 kwargs 的参数 `use_dist_sampler` 为 True 时,该值为 "dist"; | |||
| 否则为 None ,evaluator 中的 kwargs 的参数 `use_dist_sampler` 为 True 时,该值为 "unrepeatdist",否则为 None; | |||
| 注意当 dist 为 ReproducibleSampler, ReproducibleBatchSampler 时,是断点重训加载时 driver.load 函数在调用; | |||
| 当 dist 为 str 或者 None 时,是 trainer 在初始化时调用该函数; | |||
| :param reproducible: 如果为 False ,不要做任何考虑;如果为 True ,需要保证返回的 dataloader 可以保存当前的迭代状态,使得 | |||
| 可以可以加载。 | |||
| :return: 应当返回一个被替换 sampler 后的新的 dataloader 对象 (注意此处一定需要返回一个新的 dataloader 对象) ;此外, | |||
| 如果传入的 dataloader 中是 ReproducibleSampler 或者 ReproducibleBatchSampler 需要重新初始化一个放入返回的 | |||
| dataloader 中。如果 dist 为空,且 reproducible 为 False,可直接返回原对象。 | |||
| """ | |||
| # 暂时不支持iterableDataset | |||
| assert dataloader.dataset_kind != _DatasetKind.ITER, \ | |||
| @@ -187,7 +137,7 @@ class PaddleSingleDriver(PaddleDriver): | |||
| @property | |||
| def data_device(self): | |||
| """ | |||
| 单卡模式不支持 data_device; | |||
| 返回数据所在的设备。由于单卡模式不支持 data_device,因此返回的是 model_device | |||
| """ | |||
| return self.model_device | |||
| @@ -51,7 +51,7 @@ def paddle_seed_everything(seed: Optional[int] = None, workers: bool = False) -> | |||
| seed = int(seed) | |||
| if not (min_seed_value <= seed <= max_seed_value): | |||
| logger.warning("Your seed value is two big or two small for numpy, we will choose a random seed for " | |||
| logger.rank_zero_warning("Your seed value is two big or two small for numpy, we will choose a random seed for " | |||
| "you.") | |||
| # rank_zero_warn(f"{seed} is not in bounds, numpy accepts from {min_seed_value} to {max_seed_value}") | |||
| @@ -197,7 +197,7 @@ class TorchDriver(Driver): | |||
| if dataloader_args.batch_size is not None: | |||
| num_consumed_batches = num_consumed_batches * dataloader_args.batch_size | |||
| else: # 有可能 batch_size 为 None,就只有损失精度了 | |||
| logger.warning("fastNLP cannot get batch_size, we have to save based on `num_consumed_samples`, " | |||
| logger.rank_zero_warning("fastNLP cannot get batch_size, we have to save based on `num_consumed_samples`, " | |||
| "it may cause missing some samples when reload.") | |||
| num_consumed_batches = sampler_states['num_consumed_samples'] | |||
| sampler_states['num_consumed_samples'] = num_consumed_samples_array[num_consumed_batches] | |||
| @@ -207,7 +207,7 @@ class TorchDriver(Driver): | |||
| sampler_states['num_consumed_samples'] = sampler.num_replicas * dataloader_args.batch_size \ | |||
| * num_consumed_batches | |||
| else: | |||
| logger.warning("fastNLP cannot get batch_size, we have to save based on `num_consumed_samples`, " | |||
| logger.rank_zero_warning("fastNLP cannot get batch_size, we have to save based on `num_consumed_samples`, " | |||
| "it may cause missing some samples when reload.") | |||
| states['sampler_states'] = sampler_states | |||
| @@ -60,7 +60,7 @@ def torch_seed_everything(seed: Optional[int] = None, workers: bool = False) -> | |||
| seed = int(seed) | |||
| if not (min_seed_value <= seed <= max_seed_value): | |||
| logger.warning("Your seed value is two big or two small for numpy, we will choose a random seed for you.") | |||
| logger.rank_zero_warning("Your seed value is two big or two small for numpy, we will choose a random seed for you.") | |||
| seed = _select_seed_randomly(min_seed_value, max_seed_value) | |||
| @@ -162,7 +162,7 @@ def _build_fp16_env(dummy=False): | |||
| if not torch.cuda.is_available(): | |||
| raise RuntimeError("No cuda") | |||
| if torch.cuda.get_device_capability(0)[0] < 7: | |||
| logger.warning( | |||
| logger.rank_zero_warning( | |||
| "NOTE: your device does NOT support faster training with fp16, " | |||
| "please switch to FP32 which is likely to be faster" | |||
| ) | |||
| @@ -124,22 +124,26 @@ class FastNLPLogger(logging.Logger, metaclass=LoggerSingleton): | |||
| self._log(WARNING, msg, args, **kwargs) | |||
| self._warning_msgs.add(msg) | |||
| def rank_zero_warning(self, msg, *args, **kwargs): | |||
| def rank_zero_warning(self, msg, *args, once=False, **kwargs): | |||
| """ | |||
| 只在 rank 0 上 warning 。 | |||
| :param msg: | |||
| :param args: | |||
| :param once: 是否只 warning 一次 | |||
| :param kwargs: | |||
| :return: | |||
| """ | |||
| if os.environ.get(FASTNLP_GLOBAL_RANK, '0') == '0': | |||
| if msg not in self._warning_msgs: | |||
| if self.isEnabledFor(WARNING): | |||
| # kwargs = self._add_rank_info(kwargs) | |||
| self._log(WARNING, msg, args, **kwargs) | |||
| if once: | |||
| if msg in self._warning_msgs: | |||
| return | |||
| self._warning_msgs.add(msg) | |||
| if self.isEnabledFor(WARNING): | |||
| kwargs = self._add_rank_info(kwargs) | |||
| self._log(WARNING, msg, args, **kwargs) | |||
| def warn(self, msg, *args, **kwargs): | |||
| if self.isEnabledFor(WARNING): | |||
| kwargs = self._add_rank_info(kwargs) | |||
| @@ -304,6 +308,7 @@ def _set_stdout_handler(_logger, stdout='raw', level='INFO'): | |||
| break | |||
| if stream_handler is not None: | |||
| _logger.removeHandler(stream_handler) | |||
| del stream_handler | |||
| # Stream Handler | |||
| if stdout == 'raw': | |||
| @@ -9,9 +9,8 @@ def print(*args, sep=' ', end='\n', file=None, flush=False): | |||
| """ | |||
| 用来重定向 print 函数至 logger.info 的函数。 | |||
| Example: | |||
| Example:: | |||
| from fastNLP import print | |||
| print("This is a test") # 等价于调用了 logger.info("This is a test") | |||
| :param args: 需要打印的内容 | |||
| @@ -21,5 +20,5 @@ def print(*args, sep=' ', end='\n', file=None, flush=False): | |||
| :param flush: 该参数无意义。 | |||
| :return: | |||
| """ | |||
| line = sep.join(args) | |||
| line = sep.join(map(str, args)) | |||
| logger.info(line) | |||
| @@ -8,7 +8,7 @@ from fastNLP.core.samplers.unrepeated_sampler import UnrepeatedSampler, Unrepeat | |||
| def conversion_between_reproducible_and_unrepeated_sampler(sampler): | |||
| """ | |||
| 将 sampler 替换成其对应的 reproducible 版本或 unrepeated 版本。如果输入是 UnrepeatedSampler 但是没找到对应的 | |||
| ReproducibleSampler, | |||
| ReproducibleSampler, | |||
| :param sampler: | |||
| :return: | |||
| @@ -299,7 +299,7 @@ class RandomBatchSampler(ReproducibleBatchSampler): | |||
| def total_size(self): | |||
| """ | |||
| 这个变量代表的含义是当前这个sampler会最终产生出的index数量(包括了其它rank的),因为replica和pad的原因,这个值可能等于、 | |||
| 大于或者小于len(dataset) | |||
| 大于或者小于len(dataset) | |||
| :return: | |||
| """ | |||
| @@ -367,7 +367,7 @@ class BucketedBatchSampler(ReproducibleBatchSampler): | |||
| shuffle: bool = True, drop_last: bool = False, seed: int = 0, **kwargs): | |||
| """ | |||
| 首先按照 sample 的长度排序,然后按照 batch_size*num_batch_per_bucket 为一个桶的大小,sample 只会在这个桶内进行组合,这样 | |||
| 每个 batch 中的 padding 数量会比较少 (因为桶内的数据的长度都接近)。 | |||
| 每个 batch 中的 padding 数量会比较少 (因为桶内的数据的长度都接近)。 | |||
| :param dataset: 实现了 __len__ 方法的数据容器。 | |||
| :param length: 如果为 List,应当与 dataset 有一样的长度,表示 dataset 中每个元素的数量;仅当传入的 dataset 为 fastNLP 的 | |||
| @@ -440,7 +440,7 @@ class BucketedBatchSampler(ReproducibleBatchSampler): | |||
| def total_size(self): | |||
| """ | |||
| 这个变量代表的含义是当前这个sampler会最终产生出的index数量(包括了其它rank的),因为replica和pad的原因,这个值可能等于、 | |||
| 大于或者小于len(dataset) | |||
| 大于或者小于len(dataset) | |||
| :return: | |||
| """ | |||
| @@ -19,7 +19,7 @@ class ReproducibleSampler: | |||
| 可复现的 Sampler 对象。 | |||
| 注意所有继承 `ReproducibleSampler` 的类的 `__init__` 方法中都需要加入参数 `**kwargs`,用来使我们再断点重训时重新实例化这个 sampler | |||
| 或者 batch_sampler;注意,所有在 init 中初始化的变量,都不能含有 _ 下横线作为开头;所有不在 init 中设置的变量都必须以下横线开头。 | |||
| 或者 batch_sampler;注意,所有在 init 中初始化的变量,都不能含有 _ 下横线作为开头;所有不在 init 中设置的变量都必须以下横线开头。 | |||
| """ | |||
| def __init__(self, **kwargs): | |||
| @@ -87,7 +87,7 @@ class RandomSampler(ReproducibleSampler): | |||
| def __iter__(self): | |||
| r""" | |||
| 当前使用num_consumed_samples做法会在交替使用的时候遇到问题; | |||
| Example: | |||
| Example:: | |||
| >>> sampler = RandomSampler() | |||
| >>> iter1 = iter(sampler) | |||
| >>> iter2 = iter(sampler) | |||
| @@ -99,7 +99,7 @@ class UnrepeatedSortedSampler(UnrepeatedRandomSampler): | |||
| def __init__(self, dataset, length:Union[str, List], **kwargs): | |||
| """ | |||
| 将 dataset 中的数据根据 length 从长到短进行迭代,并且保证在多卡场景下数据不重复。本 sampler 可能导致各个机器上的 | |||
| batch 数量不完全一致。 | |||
| batch 数量不完全一致。 | |||
| :param dataset: 实现了 __len__ 方法的数据容器。 | |||
| :param length: 如果为 List,应当与 dataset 有一样的长度,表示 dataset 中每个元素的数量;仅当传入的 dataset 为 fastNLP 的 | |||
| @@ -35,7 +35,7 @@ class NumConsumedSamplesArray: | |||
| def __init__(self, buffer_size=2000, num_consumed_samples=0): | |||
| """ | |||
| 保留 buffer_size 个 num_consumed_samples 数据,可以索引得到某个 index 下的 num_consumed_samples 多少 | |||
| ex: | |||
| Example:: | |||
| array = NumConsumedSamplesArray(buffer_size=3) | |||
| for i in range(10): | |||
| array.push(i) | |||
| @@ -15,6 +15,7 @@ __all__ = [ | |||
| from fastNLP.core.log.logger import logger | |||
| from fastNLP.core.log.highlighter import ColorHighlighter | |||
| from .utils import _get_fun_msg | |||
| class FuncCallVisitor(ast.NodeVisitor): | |||
| @@ -306,7 +307,7 @@ def cache_results(_cache_fp, _hash_param=True, _refresh=False, _verbose=1, _chec | |||
| if verbose == 1: | |||
| logger.info("Read cache from {} (Saved on {}).".format(cache_filepath, save_time)) | |||
| if check_hash and old_hash_code != new_hash_code: | |||
| logger.warning(f"The function `{func.__name__}` is different from its last cache (Save on {save_time}). The " | |||
| logger.warning(f"The function {_get_fun_msg(func)} is different from its last cache (Save on {save_time}). The " | |||
| f"difference may caused by the sourcecode change.", | |||
| extra={'highlighter': ColorHighlighter('red')}) | |||
| refresh_flag = False | |||
| @@ -17,7 +17,8 @@ from .utils import apply_to_collection | |||
| class TorchTransferableDataType(ABC): | |||
| """ | |||
| A custom type for data that can be moved to a torch device via `.to(...)`. | |||
| Example: | |||
| Example:: | |||
| >>> isinstance(dict, TorchTransferableDataType) | |||
| False | |||
| >>> isinstance(torch.rand(2, 3), TorchTransferableDataType) | |||
| @@ -52,11 +52,11 @@ def auto_param_call(fn: Callable, *args, signature_fn: Optional[Callable] = None | |||
| mapping: Optional[Dict[AnyStr, AnyStr]] = None) -> Any: | |||
| r""" | |||
| 该函数会根据输入函数的形参名从*args(因此都需要是dict类型)中找到匹配的值进行调用,如果传入的数据与fn的形参不匹配,可以通过mapping | |||
| 参数进行转换。mapping参数中的一对(key,value)表示以这个key在*args中找到值,并将这个值传递给形参名为value的参数。 | |||
| 参数进行转换。mapping参数中的一对(key,value)表示以这个key在*args中找到值,并将这个值传递给形参名为value的参数。 | |||
| 1.该函数用来提供给用户根据字符串匹配从而实现自动调用; | |||
| 2.注意 mapping 默认为 None,如果你希望指定输入和运行函数的参数的对应方式,那么你应当让 mapping 为一个这样的字典传入进来; | |||
| 如果 mapping 不为 None,那么我们一定会先使用 mapping 将输入的字典的 keys 修改过来,因此请务必亲自检查 mapping 的正确性; | |||
| 如果 mapping 不为 None,那么我们一定会先使用 mapping 将输入的字典的 keys 修改过来,因此请务必亲自检查 mapping 的正确性; | |||
| 3.如果输入的函数的参数有默认值,那么如果之后的输入中没有该参数对应的值,我们就会使用该参数对应的默认值,否则也会使用之后的输入的值; | |||
| 4.如果输入的函数是一个 `partial` 函数,情况同 '3.',即和默认参数的情况相同; | |||
| @@ -68,7 +68,7 @@ def auto_param_call(fn: Callable, *args, signature_fn: Optional[Callable] = None | |||
| :return: 返回 `fn` 运行的结果; | |||
| Examples: | |||
| Examples:: | |||
| >>> # 1 | |||
| >>> loss_fn = CrossEntropyLoss() # 如果其需要的参数为 def CrossEntropyLoss(y, pred); | |||
| >>> batch = {"x": 20, "y": 1} | |||
| @@ -190,7 +190,7 @@ def _get_fun_msg(fn, with_fp=True)->str: | |||
| def _check_valid_parameters_number(fn, expected_params:List[str], fn_name=None): | |||
| """ | |||
| 检查一个函数是否需要 expected_params 参数(检测数量是否匹配)。除掉 self (如果是method),给定默认值的参数等。如果匹配不上,就会 | |||
| 进行报错。 | |||
| 进行报错。 | |||
| :param fn: 需要检测的函数,可以是 method 或者 function 。 | |||
| :param expected_params: 期待应该支持的参数。 | |||
| @@ -200,29 +200,25 @@ def _check_valid_parameters_number(fn, expected_params:List[str], fn_name=None): | |||
| if fn_name is not None: | |||
| assert callable(fn), f"`{fn_name}` should be callable, instead of `{type(fn)}`." | |||
| parameters = list(inspect.signature(fn).parameters.values()) | |||
| if inspect.ismethod(fn): | |||
| if len(parameters)>0 and parameters[0].name == 'self': | |||
| parameters = parameters[1:] # 去掉self | |||
| no_var_param = True # 没有 * 这种参数 | |||
| number_param_need_value = 0 | |||
| for param in parameters: | |||
| if param.kind is param.VAR_POSITIONAL: | |||
| no_var_param = False | |||
| elif param.kind is param.VAR_KEYWORD: | |||
| no_var_param = False | |||
| else: | |||
| if param.default is param.empty: | |||
| number_param_need_value += 1 | |||
| if len(parameters)<len(expected_params) and no_var_param: | |||
| raise RuntimeError(f"The function:{_get_fun_msg(fn)} accepts {len(parameters)} parameters, " | |||
| f"but {len(expected_params)} parameters:{expected_params} will be provided.") | |||
| if number_param_need_value>len(expected_params): | |||
| raise RuntimeError(f"The function:{_get_fun_msg(fn)} expects {len(parameters)} parameters, but only" | |||
| f" {len(expected_params)} parameters:{expected_params} will be provided.") | |||
| try: | |||
| args = [] | |||
| kwargs = {} | |||
| name = '' | |||
| if isinstance(fn, functools.partial) and not hasattr(fn, '__name__'): | |||
| name = 'partial:' | |||
| f = fn.func | |||
| while isinstance(f, functools.partial): | |||
| name += 'partial:' | |||
| f = f.func | |||
| fn.__name__ = name + f.__name__ | |||
| inspect.getcallargs(fn, *args, *expected_params, **kwargs) | |||
| if name: # 如果一开始没有name的,需要给人家删除掉 | |||
| delattr(fn, '__name__') | |||
| except TypeError as e: | |||
| logger.error(f"The function:{_get_fun_msg(fn)} will be provided with parameters:{expected_params}. " | |||
| f"The following exception will happen.") | |||
| raise e | |||
| def check_user_specific_params(user_params: Dict, fn: Callable): | |||
| @@ -239,7 +235,7 @@ def check_user_specific_params(user_params: Dict, fn: Callable): | |||
| fn_arg_names = get_fn_arg_names(fn) | |||
| for arg_name, arg_value in user_params.items(): | |||
| if arg_name not in fn_arg_names: | |||
| logger.warning(f"Notice your specific parameter `{arg_name}` is not used by function `{fn.__name__}`.") | |||
| logger.rank_zero_warning(f"Notice your specific parameter `{arg_name}` is not used by function `{fn.__name__}`.") | |||
| return user_params | |||
| @@ -20,7 +20,7 @@ def is_cur_env_distributed() -> bool: | |||
| """ | |||
| 单卡模式该函数一定返回 False; | |||
| 注意进程 0 在多卡的训练模式下前后的值是不一样的,例如在开启多卡的 driver 之前,在进程 0 上的该函数返回 False;但是在开启后,在进程 0 上 | |||
| 的该函数返回的值是 True; | |||
| 的该函数返回的值是 True; | |||
| 多卡模式下除了进程 0 外的其它进程返回的值一定是 True; | |||
| """ | |||
| return FASTNLP_GLOBAL_RANK in os.environ | |||
| @@ -34,12 +34,14 @@ def rank_zero_call(fn: Callable): | |||
| """ | |||
| 通过该函数包裹的函数,在单卡模式下该方法不影响任何东西,在多卡状态下仅会在 global rank 为 0 的进程下执行。使用方式有两种 | |||
| # 使用方式1 | |||
| 使用方式1:: | |||
| @rank_zero_call | |||
| def save_model(): | |||
| do_something # will only run in global rank 0 | |||
| # 使用方式2 | |||
| 使用方式2:: | |||
| def add(a, b): | |||
| return a+b | |||
| rank_zero_call(add)(1, 2) | |||
| @@ -103,7 +105,7 @@ def all_rank_call_context(): | |||
| def rank_zero_rm(path: Optional[Union[str, Path]]): | |||
| """ | |||
| 这个是因为在分布式文件系统中可能会发生错误,rank0下发删除成功后就运行走了,但实际的删除需要rank0的机器发送到远程文件系统再去执行,这个时候 | |||
| 在rank0那里,确实已经删除成功了,但是在远程文件系统那里这个操作还没完成,rank1读取的时候还是读取到存在这个文件; | |||
| 在rank0那里,确实已经删除成功了,但是在远程文件系统那里这个操作还没完成,rank1读取的时候还是读取到存在这个文件; | |||
| 该函数会保证所有进程都检测到 path 删除之后才退出,请保证不同进程上 path 是完全一样的,否则会陷入死锁状态。 | |||
| :param path: | |||
| @@ -223,7 +223,7 @@ class DataBundle: | |||
| def apply_field(self, func: Callable, field_name: str, new_field_name: str, num_proc: int = 0, | |||
| ignore_miss_dataset: bool = True, progress_desc: str = '', show_progress_bar: bool = True): | |||
| r""" | |||
| 对 :class:`~fastNLP.io.DataBundle` 中所有的dataset使用 :method:`~fastNLP.DataSet.apply_field` 方法 | |||
| 对 :class:`~fastNLP.io.DataBundle` 中所有的dataset使用 :meth:`~fastNLP.DataSet.apply_field` 方法 | |||
| :param callable func: input是instance中名为 `field_name` 的field的内容。 | |||
| :param str field_name: 传入func的是哪个field。 | |||
| @@ -231,8 +231,8 @@ class DataBundle: | |||
| 盖之前的field。如果为None则不创建新的field。 | |||
| :param bool ignore_miss_dataset: 当某个field名称在某个dataset不存在时,如果为True,则直接忽略该DataSet; | |||
| 如果为False,则报错 | |||
| :param ignore_miss_dataset: | |||
| :param num_proc: | |||
| :param num_proc: 进程的数量。请注意,由于python语言的特性,多少进程就会导致多少倍内存的增长。 | |||
| :param ignore_miss_dataset: 如果 dataset 没有 {field_name} ,就直接跳过这个 dataset 。 | |||
| :param progress_desc 当show_progress_barm为True时,可以显示当前tqdm正在处理的名称 | |||
| :param show_progress_bar 是否显示tqdm进度条 | |||
| @@ -251,20 +251,20 @@ class DataBundle: | |||
| def apply_field_more(self, func: Callable, field_name: str, num_proc: int = 0, modify_fields=True, | |||
| ignore_miss_dataset=True, progress_desc: str = '', show_progress_bar: bool = True): | |||
| r""" | |||
| 对 :class:`~fastNLP.io.DataBundle` 中所有的 dataset 使用 :method:`~fastNLP.DataSet.apply_field_more` 方法 | |||
| 对 :class:`~fastNLP.io.DataBundle` 中所有的 dataset 使用 :meth:`~fastNLP.DataSet.apply_field_more` 方法 | |||
| .. note:: | |||
| ``apply_field_more`` 与 ``apply_field`` 的区别参考 :method:`fastNLP.DataSet.apply_more` 中关于 ``apply_more`` 与 | |||
| ``apply_field_more`` 与 ``apply_field`` 的区别参考 :meth:`fastNLP.DataSet.apply_more` 中关于 ``apply_more`` 与 | |||
| ``apply`` 区别的介绍。 | |||
| :param callable func: 参数是 ``DataSet`` 中的 ``Instance`` ,返回值是一个字典,key 是field 的名字,value 是对应的结果 | |||
| :param str field_name: 传入func的是哪个field。 | |||
| :param bool modify_fields: 是否用结果修改 `DataSet` 中的 `Field`, 默认为 True | |||
| :param num_proc: 进程的数量。请注意,由于python语言的特性,多少进程就会导致多少倍内存的增长。 | |||
| :param bool ignore_miss_dataset: 当某个field名称在某个dataset不存在时,如果为True,则直接忽略该DataSet; | |||
| 如果为False,则报错 | |||
| :param show_progress_bar: 是否显示tqdm进度条 | |||
| :param progress_desc: 当show_progress_barm为True时,可以显示当前tqdm正在处理的名称 | |||
| :param num_proc: | |||
| :return Dict[str:Dict[str:Field]]: 返回一个字典套字典,第一层的 key 是 dataset 的名字,第二层的 key 是 field 的名字 | |||
| @@ -283,19 +283,18 @@ class DataBundle: | |||
| return res | |||
| def apply(self, func: Callable, new_field_name: str, num_proc: int = 0, | |||
| progress_desc: str = '', show_progress_bar: bool = True, _apply_field: str = None): | |||
| progress_desc: str = '', show_progress_bar: bool = True): | |||
| r""" | |||
| 对 :class:`~fastNLP.io.DataBundle` 中所有的 dataset 使用 :method:`~fastNLP.DataSet.apply` 方法 | |||
| 对 :class:`~fastNLP.io.DataBundle` 中所有的 dataset 使用 :meth:`~fastNLP.DataSet.apply` 方法 | |||
| 对DataBundle中所有的dataset使用apply方法 | |||
| :param callable func: input是instance中名为 `field_name` 的field的内容。 | |||
| :param str new_field_name: 将func返回的内容放入到 `new_field_name` 这个field中,如果名称与已有的field相同,则覆 | |||
| 盖之前的field。如果为None则不创建新的field。 | |||
| :param _apply_field: | |||
| :param num_proc: 进程的数量。请注意,由于python语言的特性,多少进程就会导致多少倍内存的增长。 | |||
| :param show_progress_bar: 是否显示tqd进度条 | |||
| :param progress_desc: 当show_progress_bar为True时,可以显示当前tqd正在处理的名称 | |||
| :param num_proc | |||
| """ | |||
| _progress_desc = progress_desc | |||
| @@ -303,23 +302,23 @@ class DataBundle: | |||
| if _progress_desc: | |||
| progress_desc = _progress_desc + f' for `{name}`' | |||
| dataset.apply(func, new_field_name=new_field_name, num_proc=num_proc, show_progress_bar=show_progress_bar, | |||
| progress_desc=progress_desc, _apply_field=_apply_field) | |||
| progress_desc=progress_desc) | |||
| return self | |||
| def apply_more(self, func: Callable, modify_fields=True, num_proc: int = 0, | |||
| progress_desc: str = '', show_progress_bar: bool = True): | |||
| r""" | |||
| 对 :class:`~fastNLP.io.DataBundle` 中所有的 dataset 使用 :method:`~fastNLP.DataSet.apply_more` 方法 | |||
| 对 :class:`~fastNLP.io.DataBundle` 中所有的 dataset 使用 :meth:`~fastNLP.DataSet.apply_more` 方法 | |||
| .. note:: | |||
| ``apply_more`` 与 ``apply`` 的区别参考 :method:`fastNLP.DataSet.apply_more` 中关于 ``apply_more`` 与 | |||
| ``apply_more`` 与 ``apply`` 的区别参考 :meth:`fastNLP.DataSet.apply_more` 中关于 ``apply_more`` 与 | |||
| ``apply`` 区别的介绍。 | |||
| :param callable func: 参数是 ``DataSet`` 中的 ``Instance`` ,返回值是一个字典,key 是field 的名字,value 是对应的结果 | |||
| :param bool modify_fields: 是否用结果修改 ``DataSet`` 中的 ``Field`` , 默认为 True | |||
| :param num_proc: 进程的数量。请注意,由于python语言的特性,多少进程就会导致多少倍内存的增长。 | |||
| :param show_progress_bar: 是否显示tqd进度条 | |||
| :param progress_desc: 当show_progress_bar为True时,可以显示当前tqd正在处理的名称 | |||
| :param num_proc | |||
| :return Dict[str:Dict[str:Field]]: 返回一个字典套字典,第一层的 key 是 dataset 的名字,第二层的 key 是 field 的名字 | |||
| """ | |||
| @@ -359,7 +358,8 @@ class DataBundle: | |||
| def set_ignore(self, *field_names) -> "DataBundle": | |||
| """ | |||
| 如果有的内容不希望输出,可以在此处进行设置,被设置的 field 将在 batch 的输出中被忽略。 | |||
| Ex:: | |||
| Example:: | |||
| collator.set_ignore('field1', 'field2') | |||
| :param field_names: 需要忽略的 field 的名称。如果 Dataset 的 __getitem__ 方法返回的是 dict 类型的,则可以直接使用对应的 | |||
| @@ -1,4 +1,7 @@ | |||
| r"""undocumented""" | |||
| r""" | |||
| .. todo:: | |||
| doc | |||
| """ | |||
| __all__ = [ | |||
| "ExtCNNDMLoader" | |||
| @@ -19,9 +22,9 @@ class ExtCNNDMLoader(JsonLoader): | |||
| .. csv-table:: | |||
| :header: "text", "summary", "label", "publication" | |||
| ["I got new tires from them and... ","..."], ["The new tires...","..."], [0, 1], "cnndm" | |||
| ["Don't waste your time. We had two...","..."], ["Time is precious","..."], [1], "cnndm" | |||
| ["..."], ["..."], [], "cnndm" | |||
| "['I got new tires from them and... ','...']", "['The new tires...','...']", "[0, 1]", "cnndm" | |||
| "['Don't waste your time. We had two...','...']", "['Time is precious','...']", "[1]", "cnndm" | |||
| "["..."]", "["..."]", "[]", "cnndm" | |||
| """ | |||
| @@ -87,7 +87,7 @@ class CLSBasePipe(Pipe): | |||
| def process_from_file(self, paths) -> DataBundle: | |||
| r""" | |||
| 传入文件路径,生成处理好的DataBundle对象。paths支持的路径形式可以参考 ::method:`fastNLP.io.Loader.load()` | |||
| 传入文件路径,生成处理好的DataBundle对象。paths支持的路径形式可以参考 ::meth:`fastNLP.io.Loader.load()` | |||
| :param paths: | |||
| :return: DataBundle | |||
| @@ -164,7 +164,7 @@ class GraphBuilderBase: | |||
| def build_graph_from_file(self, path: str): | |||
| r""" | |||
| 传入文件路径,生成处理好的scipy_sparse_matrix对象。paths支持的路径形式可以参考 ::method:`fastNLP.io.Loader.load()` | |||
| 传入文件路径,生成处理好的scipy_sparse_matrix对象。paths支持的路径形式可以参考 ::meth:`fastNLP.io.Loader.load()` | |||
| :param path: | |||
| :return: scipy_sparse_matrix | |||
| @@ -33,7 +33,7 @@ class Pipe: | |||
| def process_from_file(self, paths: str) -> DataBundle: | |||
| r""" | |||
| 传入文件路径,生成处理好的DataBundle对象。paths支持的路径形式可以参考 ::method:`fastNLP.io.Loader.load()` | |||
| 传入文件路径,生成处理好的DataBundle对象。paths支持的路径形式可以参考 ::meth:`fastNLP.io.Loader.load()` | |||
| :param str paths: | |||
| :return: DataBundle | |||
| @@ -53,7 +53,7 @@ class ExtCNNDMPipe(Pipe): | |||
| :param data_bundle: | |||
| :return: 处理得到的数据包括 | |||
| .. csv-table:: | |||
| .. csv-table:: | |||
| :header: "text_wd", "words", "seq_len", "target" | |||
| [["I","got",..."."],...,["..."]], [[54,89,...,5],...,[9,43,..,0]], [1,1,...,0], [0,1,...,0] | |||
| @@ -40,6 +40,7 @@ class MixModule: | |||
| def named_parameters(self, prefix='', recurse: bool=True, backend=None): | |||
| """ | |||
| 返回模型的名字和参数 | |||
| :param prefix: 输出时在参数名前加上的前缀 | |||
| :param recurse: 是否递归地输出参数 | |||
| :param backend: `backend`=`None`时,将所有模型和张量的参数返回; | |||
| @@ -68,6 +69,7 @@ class MixModule: | |||
| def parameters(self, recurse: bool = True, backend: str = None): | |||
| """ | |||
| 返回模型的参数 | |||
| :param recurse: | |||
| :param backend: `backend`=`None`时,将所有模型和张量的参数返回; | |||
| `backend`=`torch`时,返回`torch`的参数; | |||
| @@ -129,7 +131,9 @@ class MixModule: | |||
| def state_dict(self, backend: str = None) -> Dict: | |||
| """ | |||
| 返回模型的state_dict。 | |||
| NOTE: torch的destination参数会在将来删除,因此不提供destination参数 | |||
| .. note:: torch的destination参数会在将来删除,因此不提供destination参数 | |||
| :param backend: `backend`=`None`时,将所有模型和张量的state dict返回; | |||
| `backend`=`torch`时,返回`torch`的state dict; | |||
| `backend`=`paddle`时,返回`paddle`的state dict。 | |||
| @@ -156,6 +156,7 @@ def _torch2jittor(torch_tensor: 'torch.Tensor', no_gradient: bool = None) -> 'ji | |||
| def torch2paddle(torch_in: Any, target_device: str = None, no_gradient: bool = None) -> Any: | |||
| """ | |||
| 递归地将输入中包含的torch张量转换为paddle张量 | |||
| :param torch_in: 要转换的包含torch.Tensor类型的变量 | |||
| :param target_device: 是否将转换后的张量迁移到特定设备上, | |||
| 输入为`None`时,和输入的张量相同, | |||
| @@ -176,6 +177,7 @@ def torch2paddle(torch_in: Any, target_device: str = None, no_gradient: bool = N | |||
| def paddle2torch(paddle_in: Any, target_device: str = None, no_gradient: bool = None) -> Any: | |||
| """ | |||
| 递归地将输入中包含的paddle张量转换为torch张量 | |||
| :param torch_in: 要转换的包含paddle.Tensor类型的变量 | |||
| :param target_device: 是否将转换后的张量迁移到特定设备上, | |||
| 输入为`None`时,和输入的张量相同, | |||
| @@ -196,6 +198,7 @@ def paddle2torch(paddle_in: Any, target_device: str = None, no_gradient: bool = | |||
| def jittor2torch(jittor_in: Any, target_device: str = None, no_gradient: bool = None) -> Any: | |||
| """ | |||
| 递归地将输入中包含的jittor变量转换为torch张量 | |||
| :param jittor_in: 要转换的jittor变量 | |||
| :param target_device: 是否将转换后的张量迁移到特定设备上,输入为`None`时,默认为cuda:0。 | |||
| :param no_gradient: 是否保留原张量的梯度。为`None`时,新的张量与输入张量保持一致; | |||
| @@ -215,6 +218,7 @@ def jittor2torch(jittor_in: Any, target_device: str = None, no_gradient: bool = | |||
| def torch2jittor(torch_in: Any, no_gradient: bool = None) -> Any: | |||
| """ | |||
| 递归地将输入中包含的torch张量转换为jittor变量 | |||
| :param torch_tensor: 要转换的torch张量 | |||
| :param no_gradient: 是否保留原张量的梯度。为`None`时,新的张量与输入张量保持一致; | |||
| 为`True`时,全部不保留梯度;为`False`时,全部保留梯度。 | |||
| @@ -5,6 +5,7 @@ import pytest | |||
| from fastNLP.envs.imports import _NEED_IMPORT_TORCH, _NEED_IMPORT_PADDLE, _NEED_IMPORT_JITTOR | |||
| from fastNLP.core.collators.collator import Collator | |||
| from ...helpers.utils import Capturing | |||
| def _assert_equal(d1, d2): | |||
| @@ -42,7 +43,6 @@ def findListDiff(d1, d2): | |||
| class TestCollator: | |||
| @pytest.mark.torch | |||
| def test_run(self): | |||
| dict_batch = [{ | |||
| @@ -286,8 +286,83 @@ class TestCollator: | |||
| 'c': [1, 1]}} | |||
| findDictDiff(raw_pad_batch, pad_batch) | |||
| def test_raise(self, capsys): | |||
| from fastNLP.core.log import logger | |||
| logger.set_stdout('raw') | |||
| # 对于 nested 的情况 | |||
| collator = Collator(backend='numpy') | |||
| data = [[1, 2], [2, 3]] | |||
| collator.set_pad('_0') | |||
| collator.set_pad('_0') | |||
| print(collator(data)) | |||
| with Capturing() as out: | |||
| collator.set_ignore('_0') | |||
| assert '_0' in out[0] | |||
| data = [{1: {2: 2, 3: 3}}] | |||
| collator = Collator() | |||
| collator.set_pad((1, 2)) | |||
| collator.set_pad((1, 3)) | |||
| with Capturing() as out: | |||
| collator.set_ignore(1) | |||
| assert '(1, 2)' in out[0] and '(1, 3)' in out[0] | |||
| assert len(collator(data))==0 | |||
| collator = Collator() | |||
| collator.set_ignore((1, 2)) | |||
| with pytest.raises(KeyError): | |||
| collator.set_pad(1) | |||
| collator = Collator() | |||
| collator.set_ignore(1) | |||
| with pytest.raises(KeyError): | |||
| collator.set_pad((1, 2)) | |||
| @pytest.mark.torch | |||
| def test_torch_dl(): | |||
| from fastNLP import TorchDataLoader | |||
| from fastNLP import DataSet | |||
| import numpy as np | |||
| import torch | |||
| ds = DataSet({ | |||
| 'x': [1, 2], 'y': [[1,2], [3]], 'z':[np.ones((1, 2)), np.ones((2, 3))], | |||
| 'i': [{'j': [1, 2]}, {'j': [3]}], 'j': ['a', 'b'] | |||
| }) | |||
| dl = TorchDataLoader(ds, batch_size=2) | |||
| batch = next(iter(dl)) | |||
| assert 'x' in batch and 'y' in batch and 'z' in batch and 'i' in batch and 'j' in batch | |||
| assert isinstance(batch['z'], torch.Tensor) | |||
| assert isinstance(batch['j'], list) | |||
| assert isinstance(batch['i']['j'], torch.Tensor) | |||
| dl.set_ignore('x') | |||
| batch = next(iter(dl)) | |||
| assert 'x' not in batch and 'y' in batch and 'z' in batch | |||
| dl.set_pad('y', pad_val=None) | |||
| batch = next(iter(dl)) | |||
| assert 'x' not in batch and 'y' in batch and 'z' in batch | |||
| assert isinstance(batch['y'], list) | |||
| assert len(batch['y'][0])!=len(batch['y'][1]) # 没有 pad | |||
| dl.set_pad(('i', 'j'), pad_val=None) | |||
| batch = next(iter(dl)) | |||
| assert 'x' not in batch and 'y' in batch and 'z' in batch | |||
| assert isinstance(batch['y'], list) | |||
| assert len(batch['y'][0])!=len(batch['y'][1]) # 没有 pad | |||
| assert isinstance(batch['i']['j'], list) | |||
| assert len(batch['i']['j'][0])!=len(batch['i']['j'][1]) # 没有 pad | |||
| with pytest.raises(KeyError): | |||
| dl.set_pad('i', pad_val=None) | |||
| def test_compare_tuple(): | |||
| from fastNLP.core.collators.collator import _compare_tuple | |||
| for t1, t2, t in zip([(1,), (1, 2, 3), (1,), (1, 2)], | |||
| [(1, 2, 3), (1,), (2,), (1, 3)], | |||
| [-2, 2, None, None]): | |||
| assert _compare_tuple(t1, t2) == t | |||
| @@ -1,36 +1,36 @@ | |||
| from fastNLP.core.collators.utils import * | |||
| from fastNLP.core.collators.packer_unpacker import * | |||
| def test_unpack_batch_mapping(): | |||
| batch = [{'a': [1, 2], 'b': 1}, {'a': [3], 'b': 2}] | |||
| assert unpack_batch_mapping(batch, {})=={'a': [[1, 2], [3]], 'b': [1, 2]} | |||
| assert MappingPackerUnpacker.unpack_batch(batch, {}, {})=={'a': [[1, 2], [3]], 'b': [1, 2]} | |||
| def test_unpack_batch_nested_mapping(): | |||
| batch = [{'a': [1, 2], 'b': 1, 'c': {'c': 1}}, {'a': [3], 'b': 2, 'c': {'c': 2}}] | |||
| assert unpack_batch_nested_mapping(batch, {}, {}) == {'a': [[1, 2], [3]], 'b': [1, 2], ('c','c'): [1, 2]} | |||
| assert NestedMappingPackerUnpacker.unpack_batch(batch, {}, {}) == {'a': [[1, 2], [3]], 'b': [1, 2], ('c','c'): [1, 2]} | |||
| batch = [{'a': [1, 2], 'b': 1, 'c': {'c': {'c': 1}}}, {'a': [3], 'b': 2, 'c': {'c': {'c': 2}}}] | |||
| assert unpack_batch_nested_mapping(batch, {}, {}) == {'a': [[1, 2], [3]], 'b': [1, 2], ('c', 'c', 'c'): [1, 2]} | |||
| assert NestedMappingPackerUnpacker.unpack_batch(batch, {}, {}) == {'a': [[1, 2], [3]], 'b': [1, 2], ('c', 'c', 'c'): [1, 2]} | |||
| batch = [{'a': [1, 2], 'b': 1, 'c': {'c': {'c': 1, 'd':[1, 1]}, 'd': [1]}}, | |||
| {'a': [3], 'b': 2, 'c': {'c': {'c': 2, 'd': [2, 2]}, 'd': [2, 2]}}] | |||
| assert unpack_batch_nested_mapping(batch, {}, {}) == {'a': [[1, 2], [3]], 'b': [1, 2], ('c', 'c', 'c'): [1, 2], | |||
| assert NestedMappingPackerUnpacker.unpack_batch(batch, {}, {}) == {'a': [[1, 2], [3]], 'b': [1, 2], ('c', 'c', 'c'): [1, 2], | |||
| ('c','c', 'd'):[[1, 1], [2, 2]], ('c', 'd'): [[1], [2, 2]]} | |||
| def test_pack_batch_nested_mapping(): | |||
| batch = {'a': [[1, 2], [3]], 'b': [1, 2], ('c', 'c', 'c'): [1, 2], | |||
| ('c', 'c', 'd'):[[1, 1], [2, 2]], ('c', 'd'): [[1], [2, 2]]} | |||
| new_batch = pack_batch_nested_mapping(batch) | |||
| new_batch = NestedMappingPackerUnpacker.pack_batch(batch) | |||
| assert new_batch == {'a': [[1, 2], [3]], 'b': [1, 2], | |||
| 'c': {'c':{'c': [1, 2], 'd': [[1, 1], [2, 2]]}, 'd':[[1], [2, 2]]}} | |||
| def test_unpack_batch_sequence(): | |||
| batch = [[1, 2, 3], [2, 4, 6]] | |||
| new_batch = unpack_batch_sequence(batch, {}) | |||
| new_batch = SequencePackerUnpacker.unpack_batch(batch, {}, {}) | |||
| assert new_batch == {'_0': [1, 2], '_1': [2, 4], '_2': [3, 6]} | |||
| @@ -0,0 +1,22 @@ | |||
| import pytest | |||
| from fastNLP import Trainer, Event | |||
| def test_on(): | |||
| with pytest.raises(TypeError): | |||
| @Trainer.on(Event.on_before_backward()) | |||
| def before_backend(): | |||
| pass | |||
| @Trainer.on(Event.on_before_backward()) | |||
| def before_backend(*args): | |||
| pass | |||
| with pytest.raises(TypeError): | |||
| @Trainer.on(Event.on_before_backward()) | |||
| def before_backend(*args, s): | |||
| pass | |||
| @Trainer.on(Event.on_before_backward()) | |||
| def before_backend(*args, s=2): | |||
| pass | |||
| @@ -210,4 +210,43 @@ def test_trainer_validate_every( | |||
| dist.destroy_process_group() | |||
| @pytest.mark.torch | |||
| @pytest.mark.parametrize("driver,device", [("torch", 'cpu')]) # ("torch", [0, 1]),("torch", 1) | |||
| @magic_argv_env_context | |||
| def test_trainer_on( | |||
| model_and_optimizers: TrainerParameters, | |||
| driver, | |||
| device, | |||
| n_epochs=2, | |||
| ): | |||
| from fastNLP import Event | |||
| @Trainer.on(Event.on_before_backward()) | |||
| def before_backend(trainer, outputs): | |||
| pass | |||
| @Trainer.on(Event.on_before_backward()) | |||
| def before_backend_2(*args): | |||
| pass | |||
| trainer = Trainer( | |||
| model=model_and_optimizers.model, | |||
| driver=driver, | |||
| device=device, | |||
| optimizers=model_and_optimizers.optimizers, | |||
| train_dataloader=model_and_optimizers.train_dataloader, | |||
| evaluate_dataloaders=model_and_optimizers.evaluate_dataloaders, | |||
| input_mapping=model_and_optimizers.input_mapping, | |||
| output_mapping=model_and_optimizers.output_mapping, | |||
| metrics=model_and_optimizers.metrics, | |||
| n_epochs=n_epochs, | |||
| output_from_new_proc="all", | |||
| evaluate_every=-1 | |||
| ) | |||
| trainer.run() | |||
| @@ -124,9 +124,8 @@ class TestCheckNumberOfParameters: | |||
| # 无默认值,多了报错 | |||
| def validate_every(trainer, other): | |||
| pass | |||
| with pytest.raises(RuntimeError) as exc_info: | |||
| with pytest.raises(TypeError) as exc_info: | |||
| _check_valid_parameters_number(validate_every, expected_params=['trainer']) | |||
| assert "2 parameters" in exc_info.value.args[0] | |||
| print(exc_info.value.args[0]) | |||
| # 有默认值ok | |||
| @@ -137,19 +136,18 @@ class TestCheckNumberOfParameters: | |||
| # 参数多了 | |||
| def validate_every(trainer): | |||
| pass | |||
| with pytest.raises(RuntimeError) as exc_info: | |||
| with pytest.raises(TypeError) as exc_info: | |||
| _check_valid_parameters_number(validate_every, expected_params=['trainer', 'other']) | |||
| assert "accepts 1 parameters" in exc_info.value.args[0] | |||
| print(exc_info.value.args[0]) | |||
| # 使用partial | |||
| def validate_every(trainer, other): | |||
| pass | |||
| _check_valid_parameters_number(partial(validate_every, other=1), expected_params=['trainer']) | |||
| _check_valid_parameters_number(partial(validate_every, other=1), expected_params=['trainer', 'other']) | |||
| with pytest.raises(RuntimeError) as exc_info: | |||
| with pytest.raises(TypeError): | |||
| _check_valid_parameters_number(partial(validate_every, other=1), expected_params=['trainer', 'other']) | |||
| with pytest.raises(TypeError) as exc_info: | |||
| _check_valid_parameters_number(partial(validate_every, other=1), expected_params=['trainer', 'other', 'more']) | |||
| assert 'accepts 2 parameters' in exc_info.value.args[0] | |||
| print(exc_info.value.args[0]) | |||
| # 如果存在 *args 或 *kwargs 不报错多的 | |||
| @@ -159,7 +157,8 @@ class TestCheckNumberOfParameters: | |||
| def validate_every(trainer, **kwargs): | |||
| pass | |||
| _check_valid_parameters_number(partial(validate_every, trainer=1), expected_params=['trainer', 'other', 'more']) | |||
| with pytest.raises(TypeError): | |||
| _check_valid_parameters_number(partial(validate_every, trainer=1), expected_params=['trainer', 'other', 'more']) | |||
| # class 的方法删掉self | |||
| class InnerClass: | |||
| @@ -173,10 +172,8 @@ class TestCheckNumberOfParameters: | |||
| pass | |||
| inner = InnerClass() | |||
| with pytest.raises(RuntimeError) as exc_info: | |||
| with pytest.raises(TypeError) as exc_info: | |||
| _check_valid_parameters_number(inner.demo, expected_params=['trainer', 'other', 'more']) | |||
| assert 'accepts 1 parameters' in exc_info.value.args[0] | |||
| _check_valid_parameters_number(inner.demo, expected_params=['trainer']) | |||