| @@ -48,6 +48,7 @@ __all__ = [ | |||||
| 'prepare_jittor_dataloader', | 'prepare_jittor_dataloader', | ||||
| 'prepare_paddle_dataloader', | 'prepare_paddle_dataloader', | ||||
| 'prepare_torch_dataloader', | 'prepare_torch_dataloader', | ||||
| "prepare_dataloader", | |||||
| # dataset | # dataset | ||||
| 'DataSet', | 'DataSet', | ||||
| @@ -32,7 +32,7 @@ class CheckpointCallback(Callback): | |||||
| model_save_fn 为 None ,则以上每个 folder 中,将生成 fastnlp_model.pkl.tar 文件。 | model_save_fn 为 None ,则以上每个 folder 中,将生成 fastnlp_model.pkl.tar 文件。 | ||||
| 若 model_save_fn 不为 None,则 fastNLP 将 folder 绝对路径传递给该函数,fastNLP 在该 folder 下不进行模型保存。 | 若 model_save_fn 不为 None,则 fastNLP 将 folder 绝对路径传递给该函数,fastNLP 在该 folder 下不进行模型保存。 | ||||
| :param monitor: 监控的 metric 值。如果在 evaluation 结果中没有找到完全一致的名称,将使用 最短公共字符串算法 找到最匹配 | |||||
| :param monitor: 监控的 metric 值。如果在 evaluation 结果中没有找到完全一致的名称,将使用 最长公共字符串算法 找到最匹配 | |||||
| 的那个作为 monitor 。如果为 None,将尝试使用 Trainer 设置的 monitor 。也可以传入一个函数,接受参数为 evaluation 的结 | 的那个作为 monitor 。如果为 None,将尝试使用 Trainer 设置的 monitor 。也可以传入一个函数,接受参数为 evaluation 的结 | ||||
| 果(字典类型),返回一个 float 值作为 monitor 的结果,如果当前结果中没有相关的 monitor 值请返回 None 。 | 果(字典类型),返回一个 float 值作为 monitor 的结果,如果当前结果中没有相关的 monitor 值请返回 None 。 | ||||
| :param folder: 保存的文件夹,fastNLP 将在该文件下以时间戳创建子文件夹,并在里面保存。因此不同次运行可以将被保存到不同的 | :param folder: 保存的文件夹,fastNLP 将在该文件下以时间戳创建子文件夹,并在里面保存。因此不同次运行可以将被保存到不同的 | ||||
| @@ -12,7 +12,7 @@ class EarlyStopCallback(HasMonitorCallback): | |||||
| def __init__(self, monitor:Union[str, Callable]=None, larger_better:bool=True, patience:int=10): | def __init__(self, monitor:Union[str, Callable]=None, larger_better:bool=True, patience:int=10): | ||||
| """ | """ | ||||
| :param str monitor: 监控的 metric 值。如果在 evaluation 结果中没有找到完全一致的名称,将使用 最短公共字符串算法 找到最匹配 | |||||
| :param str monitor: 监控的 metric 值。如果在 evaluation 结果中没有找到完全一致的名称,将使用 最长公共字符串算法 找到最匹配 | |||||
| 的那个作为 monitor 。如果为 None,将尝试使用 Trainer 设置的 monitor 。也可以传入一个函数,接受参数为 evaluation 的结 | 的那个作为 monitor 。如果为 None,将尝试使用 Trainer 设置的 monitor 。也可以传入一个函数,接受参数为 evaluation 的结 | ||||
| 果(字典类型),返回一个 float 值作为 monitor 的结果。 | 果(字典类型),返回一个 float 值作为 monitor 的结果。 | ||||
| :param larger_better: monitor 的值是否是越大越好。 | :param larger_better: monitor 的值是否是越大越好。 | ||||
| @@ -34,7 +34,7 @@ class ResultsMonitor: | |||||
| """ | """ | ||||
| 可用于监控某个数值,并通过 is_better_results() 等接口实现检测结果是否变得更好了。 | 可用于监控某个数值,并通过 is_better_results() 等接口实现检测结果是否变得更好了。 | ||||
| :param monitor: 监控的 metric 值。如果在 evaluation 结果中没有找到完全一致的名称,将使用 最短公共字符串算法 找到最匹配 | |||||
| :param monitor: 监控的 metric 值。如果在 evaluation 结果中没有找到完全一致的名称,将使用 最长公共字符串算法 找到最匹配 | |||||
| 的那个作为 monitor 。如果为 None,将尝试使用 Trainer 设置的 monitor 。也可以传入一个函数,接受参数为 evaluation 的结 | 的那个作为 monitor 。如果为 None,将尝试使用 Trainer 设置的 monitor 。也可以传入一个函数,接受参数为 evaluation 的结 | ||||
| 果(字典类型),返回一个 float 值作为 monitor 的结果,如果当前结果中没有相关的 monitor 值请返回 None 。 | 果(字典类型),返回一个 float 值作为 monitor 的结果,如果当前结果中没有相关的 monitor 值请返回 None 。 | ||||
| :param larger_better: monitor 是否时越大越好 | :param larger_better: monitor 是否时越大越好 | ||||
| @@ -171,7 +171,7 @@ class HasMonitorCallback(ResultsMonitor, Callback): | |||||
| 该 callback 不直接进行使用,作为其它相关 callback 的父类使用,如果 callback 有使用 monitor 可以继承该函数里面实现了 | 该 callback 不直接进行使用,作为其它相关 callback 的父类使用,如果 callback 有使用 monitor 可以继承该函数里面实现了 | ||||
| (1)判断monitor合法性;(2)在需要时, 根据trainer的monitor设置自己的monitor名称。 | (1)判断monitor合法性;(2)在需要时, 根据trainer的monitor设置自己的monitor名称。 | ||||
| :param monitor: 监控的 metric 值。如果在 evaluation 结果中没有找到完全一致的名称,将使用 最短公共字符串算法 找到最匹配 | |||||
| :param monitor: 监控的 metric 值。如果在 evaluation 结果中没有找到完全一致的名称,将使用 最长公共字符串算法 找到最匹配 | |||||
| 的那个作为 monitor 。如果为 None,将尝试使用 Trainer 设置的 monitor 。也可以传入一个函数,接受参数为 evaluation 的结 | 的那个作为 monitor 。如果为 None,将尝试使用 Trainer 设置的 monitor 。也可以传入一个函数,接受参数为 evaluation 的结 | ||||
| 果(字典类型),返回一个 float 值作为 monitor 的结果,如果当前结果中没有相关的 monitor 值请返回 None 。 | 果(字典类型),返回一个 float 值作为 monitor 的结果,如果当前结果中没有相关的 monitor 值请返回 None 。 | ||||
| :param larger_better: monitor 是否时越大越好 | :param larger_better: monitor 是否时越大越好 | ||||
| @@ -209,7 +209,7 @@ class ExecuteOnceBetterMonitor(HasMonitorCallback): | |||||
| """ | """ | ||||
| 当监控的 monitor 结果更好的时候,调用 execute_fn 函数。 | 当监控的 monitor 结果更好的时候,调用 execute_fn 函数。 | ||||
| :param monitor: 监控的 metric 值。如果在 evaluation 结果中没有找到完全一致的名称,将使用 最短公共字符串算法 找到最匹配 | |||||
| :param monitor: 监控的 metric 值。如果在 evaluation 结果中没有找到完全一致的名称,将使用 最长公共字符串算法 找到最匹配 | |||||
| 的那个作为 monitor 。如果为 None,将尝试使用 Trainer 设置的 monitor 。也可以传入一个函数,接受参数为 evaluation 的结 | 的那个作为 monitor 。如果为 None,将尝试使用 Trainer 设置的 monitor 。也可以传入一个函数,接受参数为 evaluation 的结 | ||||
| 果(字典类型),返回一个 float 值作为 monitor 的结果,如果当前结果中没有相关的 monitor 值请返回 None 。 | 果(字典类型),返回一个 float 值作为 monitor 的结果,如果当前结果中没有相关的 monitor 值请返回 None 。 | ||||
| :param larger_better: monitor 是否时越大越好 | :param larger_better: monitor 是否时越大越好 | ||||
| @@ -21,7 +21,7 @@ class LoadBestModelCallback(HasMonitorCallback): | |||||
| """ | """ | ||||
| 保存最佳的 monitor 值最佳的模型,并在训练结束的时候重新加载模型。仅在训练正常结束的时候才能加载最好的模型。 | 保存最佳的 monitor 值最佳的模型,并在训练结束的时候重新加载模型。仅在训练正常结束的时候才能加载最好的模型。 | ||||
| :param str monitor: 监控的 metric 值。如果在 evaluation 结果中没有找到完全一致的名称,将使用 最短公共字符串算法 找到最匹配 | |||||
| :param str monitor: 监控的 metric 值。如果在 evaluation 结果中没有找到完全一致的名称,将使用 最长公共字符串算法 找到最匹配 | |||||
| 的那个作为 monitor 。如果为 None,将尝试使用 Trainer 设置的 monitor 。也可以传入一个函数,接受参数为 evaluation 的结 | 的那个作为 monitor 。如果为 None,将尝试使用 Trainer 设置的 monitor 。也可以传入一个函数,接受参数为 evaluation 的结 | ||||
| 果(字典类型),返回一个 float 值作为 monitor 的结果,如果当前结果中没有相关的 monitor 值请返回 None 。 | 果(字典类型),返回一个 float 值作为 monitor 的结果,如果当前结果中没有相关的 monitor 值请返回 None 。 | ||||
| :param larger_better: 该 metric 值是否是越大越好。 | :param larger_better: 该 metric 值是否是越大越好。 | ||||
| @@ -37,7 +37,7 @@ class MoreEvaluateCallback(HasMonitorCallback): | |||||
| 一个 bool 值,返回为 True 说明需要进行 evaluate ;将在每个 batch 结束后调用该函数判断是否需要 evaluate 。 | 一个 bool 值,返回为 True 说明需要进行 evaluate ;将在每个 batch 结束后调用该函数判断是否需要 evaluate 。 | ||||
| :param watch_monitor: 这个值用来表示监控的 Trainer 中的 evaluate 结果的,当该值不为 None ,evaluate_every 失效。本参数的 | :param watch_monitor: 这个值用来表示监控的 Trainer 中的 evaluate 结果的,当该值不为 None ,evaluate_every 失效。本参数的 | ||||
| 意义是,当检测到 Trainer 中 evaluate results 的 {watch_monitor} 的结果更好时,则进行一次 evaluate 。该参数有两种 | 意义是,当检测到 Trainer 中 evaluate results 的 {watch_monitor} 的结果更好时,则进行一次 evaluate 。该参数有两种 | ||||
| 取值: (1) str 类型,监控的 metric 值。如果在 evaluation 结果中没有找到完全一致的名称,将使用 最短公共字符串算法 找到最 | |||||
| 取值: (1) str 类型,监控的 metric 值。如果在 evaluation 结果中没有找到完全一致的名称,将使用 最长公共字符串算法 找到最 | |||||
| 匹配的那个作为 monitor ; (2) 也可以传入一个函数,接受参数为 evaluation 的结果(字典类型),返回一个 float 值作为 monitor | 匹配的那个作为 monitor ; (2) 也可以传入一个函数,接受参数为 evaluation 的结果(字典类型),返回一个 float 值作为 monitor | ||||
| 的结果,如果当前结果中没有相关的monitor 值请返回 None 。 | 的结果,如果当前结果中没有相关的monitor 值请返回 None 。 | ||||
| :param watch_monitor_larger_better: watch_monitor 是否越大越好。 | :param watch_monitor_larger_better: watch_monitor 是否越大越好。 | ||||
| @@ -46,7 +46,7 @@ class RichCallback(ProgressCallback): | |||||
| :param print_every: 多少个 batch 更新一次显示。 | :param print_every: 多少个 batch 更新一次显示。 | ||||
| :param loss_round_ndigit: 显示的 loss 保留多少位有效数字 | :param loss_round_ndigit: 显示的 loss 保留多少位有效数字 | ||||
| :param monitor: 当检测到这个key的结果更好时,会打印出不同的颜色进行提示。监控的 metric 值。如果在 evaluation 结果中没有找到 | :param monitor: 当检测到这个key的结果更好时,会打印出不同的颜色进行提示。监控的 metric 值。如果在 evaluation 结果中没有找到 | ||||
| 完全一致的名称,将使用 最短公共字符串算法 找到最匹配的那个作为 monitor 。如果为 None,将尝试使用 Trainer 设置的 monitor | |||||
| 完全一致的名称,将使用 最长公共字符串算法 找到最匹配的那个作为 monitor 。如果为 None,将尝试使用 Trainer 设置的 monitor | |||||
| 。也可以传入一个函数,接受参数为 evaluation 的结果(字典类型),返回一个 float 值作为 monitor 的结果,如果当前结果中没有 | 。也可以传入一个函数,接受参数为 evaluation 的结果(字典类型),返回一个 float 值作为 monitor 的结果,如果当前结果中没有 | ||||
| 相关的 monitor 值请返回 None 。 | 相关的 monitor 值请返回 None 。 | ||||
| :param larger_better: 是否是 monitor 的结果越大越好。 | :param larger_better: 是否是 monitor 的结果越大越好。 | ||||
| @@ -141,7 +141,7 @@ class RawTextCallback(ProgressCallback): | |||||
| :param print_every: 多少个 batch 更新一次显示。 | :param print_every: 多少个 batch 更新一次显示。 | ||||
| :param loss_round_ndigit: 显示的 loss 保留多少位有效数字 | :param loss_round_ndigit: 显示的 loss 保留多少位有效数字 | ||||
| :param monitor: 当检测到这个key的结果更好时,会打印出不同的颜色进行提示。监控的 metric 值。如果在 evaluation 结果中没有找到 | :param monitor: 当检测到这个key的结果更好时,会打印出不同的颜色进行提示。监控的 metric 值。如果在 evaluation 结果中没有找到 | ||||
| 完全一致的名称,将使用 最短公共字符串算法 找到最匹配的那个作为 monitor 。如果为 None,将尝试使用 Trainer 设置的 monitor | |||||
| 完全一致的名称,将使用 最长公共字符串算法 找到最匹配的那个作为 monitor 。如果为 None,将尝试使用 Trainer 设置的 monitor | |||||
| 。也可以传入一个函数,接受参数为 evaluation 的结果(字典类型),返回一个 float 值作为 monitor 的结果,如果当前结果中没有 | 。也可以传入一个函数,接受参数为 evaluation 的结果(字典类型),返回一个 float 值作为 monitor 的结果,如果当前结果中没有 | ||||
| 相关的 monitor 值请返回 None 。 | 相关的 monitor 值请返回 None 。 | ||||
| :param larger_better: 是否是monitor的结果越大越好。 | :param larger_better: 是否是monitor的结果越大越好。 | ||||
| @@ -183,7 +183,7 @@ class TopkSaver(ResultsMonitor, Saver): | |||||
| :param topk: 保存 topk 多少的模型,-1 为保存所有模型;0 为都不保存;大于 0 的数为保存 topk 个。 | :param topk: 保存 topk 多少的模型,-1 为保存所有模型;0 为都不保存;大于 0 的数为保存 topk 个。 | ||||
| :param monitor: 监控哪个指标判断是否是 topk 的。监控的 metric 值。如果在 evaluation 结果中没有找到完全一致的名称,将使用 | :param monitor: 监控哪个指标判断是否是 topk 的。监控的 metric 值。如果在 evaluation 结果中没有找到完全一致的名称,将使用 | ||||
| 最短公共字符串算法 找到最匹配的那个作为 monitor 。如果为 None,将尝试使用 Trainer 设置的 monitor 。也可以传入一个函数, | |||||
| 最长公共字符串算法 找到最匹配的那个作为 monitor 。如果为 None,将尝试使用 Trainer 设置的 monitor 。也可以传入一个函数, | |||||
| 接受参数为 evaluation 的结果(字典类型),返回一个 float 值作为 monitor 的结果,如果当前结果中没有相关的 monitor 值请 | 接受参数为 evaluation 的结果(字典类型),返回一个 float 值作为 monitor 的结果,如果当前结果中没有相关的 monitor 值请 | ||||
| 返回 None 。 | 返回 None 。 | ||||
| :param larger_better: 该 monitor 是否越大越好。 | :param larger_better: 该 monitor 是否越大越好。 | ||||
| @@ -6,19 +6,20 @@ from typing import List, Union, Dict, Callable, Sequence, Mapping | |||||
| import os | import os | ||||
| import sys | import sys | ||||
| import inspect | import inspect | ||||
| import re | |||||
| from fastNLP.core.log import logger | from fastNLP.core.log import logger | ||||
| from .padders.get_padder import get_padder | from .padders.get_padder import get_padder | ||||
| from ...envs import SUPPORT_BACKENDS | |||||
| import re | |||||
| from .packer_unpacker import SequencePackerUnpacker, SinglePackerUnpacker, MappingPackerUnpacker, \ | from .packer_unpacker import SequencePackerUnpacker, SinglePackerUnpacker, MappingPackerUnpacker, \ | ||||
| NestedMappingPackerUnpacker | NestedMappingPackerUnpacker | ||||
| sequence_idx_str = re.compile(r'^_\d+$') # 形如_0, _1 | sequence_idx_str = re.compile(r'^_\d+$') # 形如_0, _1 | ||||
| SUPPORTED_BACKENDS = ['torch', 'jittor', 'paddle', 'numpy', 'raw', 'auto', None] | SUPPORTED_BACKENDS = ['torch', 'jittor', 'paddle', 'numpy', 'raw', 'auto', None] | ||||
| CHECK_BACKEND = ['torch', 'jittor', 'paddle'] # backend 为 auto 时 检查是否是这些 backend | |||||
| # 由于 jittor DataLoader 存在自动的 to_jittor 的转换,所以只需要 collate 成为 numpy 就行 | |||||
| AUTO_BACKEND_MAPPING = {'jittor': 'numpy'} | |||||
| def _get_backend() -> str: | def _get_backend() -> str: | ||||
| """ | """ | ||||
| @@ -40,7 +41,7 @@ def _get_backend() -> str: | |||||
| catch_backend = [] | catch_backend = [] | ||||
| try: | try: | ||||
| file = module.__file__ | file = module.__file__ | ||||
| for backend in CHECK_BACKEND: | |||||
| for backend in SUPPORT_BACKENDS: | |||||
| if f'{os.sep}site-packages{os.sep}{backend}' in file: | if f'{os.sep}site-packages{os.sep}{backend}' in file: | ||||
| catch_backend = [backend, file] | catch_backend = [backend, file] | ||||
| except: | except: | ||||
| @@ -62,10 +63,10 @@ def _get_backend() -> str: | |||||
| break | break | ||||
| if len(catch_backend): | if len(catch_backend): | ||||
| logger.debug(f"Find a file named:{catch_backend[1]} from stack contains backend:{catch_backend[0]}.") | logger.debug(f"Find a file named:{catch_backend[1]} from stack contains backend:{catch_backend[0]}.") | ||||
| return catch_backend[0] | |||||
| return AUTO_BACKEND_MAPPING.get(catch_backend[0], catch_backend[0]) | |||||
| # 方式 (2) | # 方式 (2) | ||||
| for backend in CHECK_BACKEND: | |||||
| for backend in SUPPORT_BACKENDS: | |||||
| if backend in sys.modules: | if backend in sys.modules: | ||||
| logger.debug(f"sys.modules contains backend:{backend}.") | logger.debug(f"sys.modules contains backend:{backend}.") | ||||
| return backend | return backend | ||||
| @@ -30,7 +30,8 @@ if _NEED_IMPORT_PADDLE: | |||||
| } | } | ||||
| from .padder import Padder | from .padder import Padder | ||||
| from .utils import is_number_or_numpy_number, is_number, is_numpy_number_dtype, get_shape, is_numpy_generic_class | |||||
| from .utils import is_number_or_numpy_number, is_number, is_numpy_number_dtype, is_numpy_generic_class, \ | |||||
| get_padded_numpy_array | |||||
| from .exceptions import * | from .exceptions import * | ||||
| @@ -54,7 +55,6 @@ def is_paddle_dtype_str(dtype): | |||||
| return False | return False | ||||
| def _get_dtype(ele_dtype, dtype, class_name): | def _get_dtype(ele_dtype, dtype, class_name): | ||||
| if not (ele_dtype is None or is_number_or_numpy_number(ele_dtype) or is_paddle_tensor(ele_dtype) or is_paddle_dtype_str(ele_dtype)): | if not (ele_dtype is None or is_number_or_numpy_number(ele_dtype) or is_paddle_tensor(ele_dtype) or is_paddle_dtype_str(ele_dtype)): | ||||
| raise EleDtypeUnsupportedError(f"`{class_name}` only supports padding python numbers " | raise EleDtypeUnsupportedError(f"`{class_name}` only supports padding python numbers " | ||||
| @@ -131,7 +131,7 @@ class PaddleTensorPadder(Padder): | |||||
| def pad(batch_field, pad_val, dtype): | def pad(batch_field, pad_val, dtype): | ||||
| try: | try: | ||||
| if not isinstance(batch_field[0], paddle.Tensor): | if not isinstance(batch_field[0], paddle.Tensor): | ||||
| batch_field = [paddle.to_tensor(field.tolist(), dtype=dtype) for field in batch_field] | |||||
| batch_field = [np.array(field.tolist()) for field in batch_field] | |||||
| else: | else: | ||||
| if dtype is None: | if dtype is None: | ||||
| dtype = batch_field[0].dtype | dtype = batch_field[0].dtype | ||||
| @@ -141,46 +141,14 @@ class PaddleTensorPadder(Padder): | |||||
| shapes = [field.shape for field in batch_field] | shapes = [field.shape for field in batch_field] | ||||
| max_shape = [len(batch_field)] + [max(*_) for _ in zip(*shapes)] | max_shape = [len(batch_field)] + [max(*_) for _ in zip(*shapes)] | ||||
| tensor = paddle.full(max_shape, fill_value=pad_val, dtype=dtype) | |||||
| array = np.full(max_shape, fill_value=pad_val) | |||||
| for i, field in enumerate(batch_field): | for i, field in enumerate(batch_field): | ||||
| slices = (i, ) + tuple(slice(0, s) for s in shapes[i]) | slices = (i, ) + tuple(slice(0, s) for s in shapes[i]) | ||||
| tensor[slices] = field | |||||
| array[slices] = field | |||||
| tensor = paddle.to_tensor(array, dtype=dtype) | |||||
| return tensor | return tensor | ||||
| def fill_tensor(batch_field, padded_batch, dtype): | |||||
| """ | |||||
| 将 batch_field 中的值填入到 tensor 中。 | |||||
| :param batch_field: 需要填充进入 array 中的内容 | |||||
| :param padded_batch: 待填充的 tensor | |||||
| :param dtype: 数据的类别 | |||||
| :return: | |||||
| """ | |||||
| if padded_batch.ndim == 2: | |||||
| for i, content_i in enumerate(batch_field): | |||||
| padded_batch[i, :len(content_i)] = paddle.to_tensor(content_i, dtype=dtype) | |||||
| elif padded_batch.ndim == 3: | |||||
| for i, content_i in enumerate(batch_field): | |||||
| for j, content_ii in enumerate(content_i): | |||||
| padded_batch[i, j, :len(content_ii)] = paddle.to_tensor(content_ii, dtype=dtype) | |||||
| elif padded_batch.ndim == 4: | |||||
| try: # 应该是图像,所以直接应该就 ok 了。 | |||||
| padded_batch = np.array(batch_field) | |||||
| except: | |||||
| for i, content_i in enumerate(batch_field): | |||||
| for j, content_ii in enumerate(content_i): | |||||
| for k, content_iii in enumerate(content_ii): | |||||
| padded_batch[i, j, k, :len(content_iii)] = paddle.to_tensor(content_iii, dtype=dtype) | |||||
| elif padded_batch.ndim == 1: | |||||
| padded_batch[:] = paddle.to_tensor(batch_field, dtype=dtype) | |||||
| else: | |||||
| raise RuntimeError("fastNLP does not support padding for more than 3 dimensions. If you need this, please " | |||||
| "report.") | |||||
| return padded_batch | |||||
| def get_padded_paddle_tensor(batch_field, dtype=None, pad_val=0): | def get_padded_paddle_tensor(batch_field, dtype=None, pad_val=0): | ||||
| """ | """ | ||||
| 例如: | 例如: | ||||
| @@ -192,7 +160,6 @@ def get_padded_paddle_tensor(batch_field, dtype=None, pad_val=0): | |||||
| :param pad_val: pad 的 value | :param pad_val: pad 的 value | ||||
| :return: | :return: | ||||
| """ | """ | ||||
| shapes = get_shape(batch_field) | |||||
| tensor = paddle.to_tensor(np.full(shape=shapes, fill_value=pad_val), dtype=dtype) | |||||
| tensor = fill_tensor(batch_field, tensor, dtype=dtype) | |||||
| array = get_padded_numpy_array(batch_field=batch_field, dtype=None, pad_val=pad_val) | |||||
| tensor = paddle.to_tensor(array, dtype=dtype) | |||||
| return tensor | return tensor | ||||
| @@ -51,23 +51,25 @@ class Evaluator: | |||||
| 为 False,那么我们会将 batch 直接透传给 forward 函数。注意上述逻辑同样应用于 `train_step`, `evaluate_step` 和 `test_step`; | 为 False,那么我们会将 batch 直接透传给 forward 函数。注意上述逻辑同样应用于 `train_step`, `evaluate_step` 和 `test_step`; | ||||
| :param fp16: 是否使用 fp16 。 | :param fp16: 是否使用 fp16 。 | ||||
| :param verbose: 是否打印 evaluate 的结果。 | :param verbose: 是否打印 evaluate 的结果。 | ||||
| :param \**kwargs: | |||||
| See below | |||||
| :kwargs: | :kwargs: | ||||
| * *model_use_eval_mode* (``bool``) -- | |||||
| 是否在 evaluate 的时候将 model 的状态设置成 eval 状态。在 eval 状态下,model 的 | |||||
| dropout 与 batch normalization 将会关闭。默认为True。如果为 False,fastNLP 不会对 model 的 evaluate 状态做任何设置。无论 | |||||
| 该值是什么,fastNLP 都会在 evaluate 接受后将 model 的状态设置为 train 。 | |||||
| * *use_dist_sampler* -- | |||||
| 是否使用分布式evaluate的方式。仅当 driver 为分布式类型时,该参数才有效。默认为根据 driver 是否支持 | |||||
| 分布式进行设置。如果为True,将使得每个进程上的 dataloader 自动使用不同数据,所有进程的数据并集是整个数据集。 | |||||
| * *output_from_new_proc* -- | |||||
| 应当为一个字符串,表示在多进程的 driver 中其它进程的输出流应当被做如何处理;其值应当为以下之一: | |||||
| ["all", "ignore", "only_error"];当该参数的值不是以上值时,该值应当表示一个文件夹的名字,我们会将其他 rank 的输出流重定向到 | |||||
| log 文件中,然后将 log 文件保存在通过该参数值设定的文件夹中;默认为 "only_error"; | |||||
| * *progress_bar* -- | |||||
| evaluate 的时候显示的 progress bar 。目前支持三种 [None, 'raw', 'rich', 'auto'], auto 表示如果检测 | |||||
| 到当前terminal为交互型则使用 rich,否则使用 raw。 | |||||
| * *torch_kwargs* -- 用于在指定 ``driver`` 为 'torch' 时设定具体 driver 实例的一些参数: | |||||
| * ddp_kwargs -- 用于在使用 ``TorchDDPDriver`` 时指定 ``DistributedDataParallel`` 初始化时的参数;例如传入 | |||||
| {'find_unused_parameters': True} 来解决有参数不参与前向运算导致的报错等; | |||||
| * torch_non_blocking -- 表示用于 pytorch 的 tensor 的 to 方法的参数 non_blocking; | |||||
| * *model_use_eval_mode* (``bool``) -- | |||||
| 是否在 evaluate 的时候将 model 的状态设置成 eval 状态。在 eval 状态下,model 的 | |||||
| dropout 与 batch normalization 将会关闭。默认为True。如果为 False,fastNLP 不会对 model 的 evaluate 状态做任何设置。无论 | |||||
| 该值是什么,fastNLP 都会在 evaluate 接受后将 model 的状态设置为 train 。 | |||||
| * *use_dist_sampler* -- | |||||
| 是否使用分布式evaluate的方式。仅当 driver 为分布式类型时,该参数才有效。默认为根据 driver 是否支持 | |||||
| 分布式进行设置。如果为True,将使得每个进程上的 dataloader 自动使用不同数据,所有进程的数据并集是整个数据集。 | |||||
| * *output_from_new_proc* -- | |||||
| 应当为一个字符串,表示在多进程的 driver 中其它进程的输出流应当被做如何处理;其值应当为以下之一: | |||||
| ["all", "ignore", "only_error"];当该参数的值不是以上值时,该值应当表示一个文件夹的名字,我们会将其他 rank 的输出流重定向到 | |||||
| log 文件中,然后将 log 文件保存在通过该参数值设定的文件夹中;默认为 "only_error"; | |||||
| * *progress_bar* -- | |||||
| evaluate 的时候显示的 progress bar 。目前支持三种 [None, 'raw', 'rich', 'auto'], auto 表示如果检测 | |||||
| 到当前terminal为交互型则使用 rich,否则使用 raw。 | |||||
| """ | """ | ||||
| self.model = model | self.model = model | ||||
| @@ -159,6 +161,7 @@ class Evaluator: | |||||
| self.reset() | self.reset() | ||||
| self.driver.barrier() | self.driver.barrier() | ||||
| except BaseException as e: | except BaseException as e: | ||||
| self.driver.on_exception() | |||||
| raise e | raise e | ||||
| finally: | finally: | ||||
| self.finally_progress_bar() | self.finally_progress_bar() | ||||
| @@ -67,20 +67,28 @@ class Trainer(TrainerEventTrigger): | |||||
| 要自己实现模型部分,而将训练层面的逻辑完全地交给 fastNLP; | 要自己实现模型部分,而将训练层面的逻辑完全地交给 fastNLP; | ||||
| :param model: 训练所需要的模型,目前支持 pytorch; | :param model: 训练所需要的模型,目前支持 pytorch; | ||||
| :param driver: 训练模型所使用的具体的驱动模式,应当为以下选择中的一个:["torch", "torch_ddp", ],之后我们会加入 jittor、paddle 等 | |||||
| 国产框架的训练模式;其中 "torch" 表示使用 cpu 或者单张 gpu 进行训练 | |||||
| :param driver: 训练模型所使用的具体的驱动模式,应当为以下选择中的一个:["torch",],之后我们会加入 jittor、paddle 等 | |||||
| 国产框架的训练模式;其中 "torch" 表示使用 ``TorchSingleDriver`` 或者 ``TorchDDPDriver``,具体使用哪一种取决于参数 ``device`` | |||||
| 的设置; | |||||
| :param train_dataloader: 训练数据集,注意其必须是单独的一个数据集,不能是 List 或者 Dict; | :param train_dataloader: 训练数据集,注意其必须是单独的一个数据集,不能是 List 或者 Dict; | ||||
| :param optimizers: 训练所需要的优化器;可以是单独的一个优化器实例,也可以是多个优化器组成的 List; | :param optimizers: 训练所需要的优化器;可以是单独的一个优化器实例,也可以是多个优化器组成的 List; | ||||
| :param device: 该参数用来指定具体训练时使用的机器;注意当该参数为 None 时,fastNLP 不会将模型和数据进行设备之间的移动处理,但是你 | |||||
| 可以通过参数 `input_mapping` 和 `output_mapping` 来实现设备之间数据迁移的工作(通过这两个参数传入两个处理数据的函数);同时你也 | |||||
| 可以通过在 kwargs 添加参数 "data_device" 来让我们帮助您将数据迁移到指定的机器上(注意这种情况理应只出现在用户在 Trainer 实例化前 | |||||
| 自己构造 DDP 的多进程场景); | |||||
| device 的可选输入如下所示: | |||||
| 1. 可选输入:str: ['cpu', 'cuda', 'cuda:0', 'cuda:1', ...] 依次为'cpu'中, 可见的第一个GPU中, 可见的第一个GPU中, 可见的第二个GPU中; | |||||
| 2. torch.device:将模型装载到torch.device上; | |||||
| 3. int: 将使用device_id为该值的gpu进行训练;如果值为 -1,那么默认使用全部的显卡,此时是 `TorchDDPDriver`; | |||||
| 4. list(int):如果多于1个device,应当通过该种方式进行设定;当 `device` 为一个 list 时,我们默认使用 `TorchDDPDriver`; | |||||
| 5. None: 为None则不对模型进行任何处理; | |||||
| :param device: 该参数用来指定具体训练时使用的机器;注意当该参数仅当您通过 `torch.distributed.launch/run` 启动时可以为 None, | |||||
| 此时 fastNLP 不会对模型和数据进行设备之间的移动处理,但是你可以通过参数 `input_mapping` 和 `output_mapping` 来实现设备之间 | |||||
| 数据迁移的工作(通过这两个参数传入两个处理数据的函数);同时你也可以通过在 kwargs 添加参数 "data_device" 来让我们帮助您将数据 | |||||
| 迁移到指定的机器上(注意这种情况理应只出现在用户在 Trainer 实例化前自己构造 DDP 的场景); | |||||
| device 的可选输入如下所示: | |||||
| * *str*: 例如 'cpu', 'cuda', 'cuda:0', 'cuda:1' 等; | |||||
| * *torch.device*: 将模型装载到 ``torch.device`` 上; | |||||
| * *int*: 将使用 ``device_id`` 为该值的 ``gpu`` 进行训练;如果值为 -1,那么默认使用全部的显卡,此时使用的 driver 实例是 `TorchDDPDriver`; | |||||
| * *list(int)*: 如果多于 1 个device,应当通过该种方式进行设定;注意此时我们一定会使用 ``TorchDDPDriver``,不管您传入的列表的长度是 1 还是其它值; | |||||
| * *None*: 为None则不对模型进行任何处理; | |||||
| .. node:: | |||||
| 如果希望使用 ``TorchDDPDriver`` | |||||
| :param n_epochs: 训练总共的 epoch 的数量,默认为 20; | :param n_epochs: 训练总共的 epoch 的数量,默认为 20; | ||||
| :param evaluate_dataloaders: 验证数据集,其可以是单独的一个数据集,也可以是多个数据集;当为多个数据集时,注意其必须是 Dict;默认 | :param evaluate_dataloaders: 验证数据集,其可以是单独的一个数据集,也可以是多个数据集;当为多个数据集时,注意其必须是 Dict;默认 | ||||
| @@ -117,19 +125,20 @@ class Trainer(TrainerEventTrigger): | |||||
| :param accumulation_steps: 梯度累积的步数,表示每隔几个 batch 优化器迭代一次;默认为 1; | :param accumulation_steps: 梯度累积的步数,表示每隔几个 batch 优化器迭代一次;默认为 1; | ||||
| :param fp16: 是否开启混合精度训练;默认为 False; | :param fp16: 是否开启混合精度训练;默认为 False; | ||||
| :param monitor: 当存在 evaluate_dataloaders 时,默认的 monitor metric 的名字。传入的 callback 如果有 monitor 参数且没有 | :param monitor: 当存在 evaluate_dataloaders 时,默认的 monitor metric 的名字。传入的 callback 如果有 monitor 参数且没有 | ||||
| 在 callback 初始化设定的,将采取这个值。如果在 evaluation 结果中没有找到完全一致的名称,将使用 最短公共字符串算法 找到最匹配 | |||||
| 在 callback 初始化设定的,将采取这个值。如果在 evaluation 结果中没有找到完全一致的名称,将使用 最长公共字符串算法 找到最匹配 | |||||
| 的那个作为 monitor 。也可以传入一个函数,接受参数为 evaluation 的结果(字典类型),返回一个 float 值作为 monitor 的结果。 | 的那个作为 monitor 。也可以传入一个函数,接受参数为 evaluation 的结果(字典类型),返回一个 float 值作为 monitor 的结果。 | ||||
| 如果 evaluate_dataloaders 与 metrics 没有提供,该参数无意义。 | 如果 evaluate_dataloaders 与 metrics 没有提供,该参数无意义。 | ||||
| :param larger_better: monitor 的值是否是越大越好。 | :param larger_better: monitor 的值是否是越大越好。 | ||||
| :param marker: 用于标记一个 Trainer 实例,从而在用户调用 `Trainer.on` 函数时,标记该 callback 函数属于哪一个具体的 'trainer' 实例;默认为 None; | :param marker: 用于标记一个 Trainer 实例,从而在用户调用 `Trainer.on` 函数时,标记该 callback 函数属于哪一个具体的 'trainer' 实例;默认为 None; | ||||
| :param kwargs: 一些其它的可能需要的参数,见下方的说明 | :param kwargs: 一些其它的可能需要的参数,见下方的说明 | ||||
| :kwargs: | :kwargs: | ||||
| * *torch_non_blocking* -- 表示用于 pytorch 的 tensor 的 to 方法的参数 non_blocking; | |||||
| * *torch_kwargs* -- 用于在指定 ``driver`` 为 'torch' 时设定具体 driver 实例的一些参数: | |||||
| * ddp_kwargs -- 用于在使用 ``TorchDDPDriver`` 时指定 ``DistributedDataParallel`` 初始化时的参数;例如传入 | |||||
| {'find_unused_parameters': True} 来解决有参数不参与前向运算导致的报错等; | |||||
| * set_grad_to_none -- 是否在训练过程中在每一次 optimizer 更新后将 grad 置为 None; | |||||
| * torch_non_blocking -- 表示用于 pytorch 的 tensor 的 to 方法的参数 non_blocking; | |||||
| * *data_device* -- 表示如果用户的模型 device (在 Driver 中对应为参数 model_device)为 None 时,我们会将数据迁移到 data_device 上; | * *data_device* -- 表示如果用户的模型 device (在 Driver 中对应为参数 model_device)为 None 时,我们会将数据迁移到 data_device 上; | ||||
| 注意如果 model_device 为 None,那么 data_device 不会起作用; | |||||
| * *torch_ddp_kwargs* -- 用于配置 pytorch 的 DistributedDataParallel 初始化时的参数;仅用于 pytorch ddp 训练。例如传入 | |||||
| {'find_unused_parameters': True} 来解决有有参数不参与前向运算导致的报错等。 | |||||
| * *set_grad_to_none* -- 是否在训练过程中在每一次 optimizer 更新后将 grad 置为 None; | |||||
| 注意如果 model_device 为 None,那么 data_device 不会起作用; | |||||
| * *use_dist_sampler* -- 表示是否使用分布式的 sampler 。在多卡时,分布式 sampler 将自动决定每张卡上读取的 sample ,使得一个epoch | * *use_dist_sampler* -- 表示是否使用分布式的 sampler 。在多卡时,分布式 sampler 将自动决定每张卡上读取的 sample ,使得一个epoch | ||||
| 内所有卡的 sample 加起来为一整个数据集的 sample。默认会根据 driver 是否为分布式进行设置。 | 内所有卡的 sample 加起来为一整个数据集的 sample。默认会根据 driver 是否为分布式进行设置。 | ||||
| * *evaluate_use_dist_sampler* -- 表示在 Evaluator 中在使用 分布式 的时候是否将 dataloader 的 sampler 替换为分布式的 sampler;默认为 True; | * *evaluate_use_dist_sampler* -- 表示在 Evaluator 中在使用 分布式 的时候是否将 dataloader 的 sampler 替换为分布式的 sampler;默认为 True; | ||||
| @@ -143,6 +152,8 @@ class Trainer(TrainerEventTrigger): | |||||
| * *train_output_mapping* -- 与 output_mapping 一致,但是只用于 train 中。与 output_mapping 互斥。 | * *train_output_mapping* -- 与 output_mapping 一致,但是只用于 train 中。与 output_mapping 互斥。 | ||||
| * *evaluate_input_mapping* -- 与 input_mapping 一致,但是只用于 evaluate 中。与 input_mapping 互斥。 | * *evaluate_input_mapping* -- 与 input_mapping 一致,但是只用于 evaluate 中。与 input_mapping 互斥。 | ||||
| * *evaluate_output_mapping* -- 与 output_mapping 一致,但是只用于 evaluate 中。与 output_mapping 互斥。 | * *evaluate_output_mapping* -- 与 output_mapping 一致,但是只用于 evaluate 中。与 output_mapping 互斥。 | ||||
| """ | """ | ||||
| self.model = model | self.model = model | ||||
| self.marker = marker | self.marker = marker | ||||
| @@ -205,8 +216,8 @@ class Trainer(TrainerEventTrigger): | |||||
| callbacks=callbacks, | callbacks=callbacks, | ||||
| metrics=metrics, | metrics=metrics, | ||||
| evaluate_every=evaluate_every, | evaluate_every=evaluate_every, | ||||
| input_mapping=evaluate_input_mapping, | |||||
| output_mapping=evaluate_output_mapping, | |||||
| input_mapping=train_input_mapping, | |||||
| output_mapping=train_output_mapping, | |||||
| model_wo_auto_param_call=model_wo_auto_param_call, | model_wo_auto_param_call=model_wo_auto_param_call, | ||||
| accumulation_steps=accumulation_steps, | accumulation_steps=accumulation_steps, | ||||
| fp16=fp16, | fp16=fp16, | ||||
| @@ -263,8 +274,8 @@ class Trainer(TrainerEventTrigger): | |||||
| progress_bar = progress_bar.name | progress_bar = progress_bar.name | ||||
| self.evaluator = Evaluator(model=model, dataloaders=evaluate_dataloaders, metrics=metrics, | self.evaluator = Evaluator(model=model, dataloaders=evaluate_dataloaders, metrics=metrics, | ||||
| driver=self.driver, device=device, evaluate_batch_step_fn=evaluate_batch_step_fn, | driver=self.driver, device=device, evaluate_batch_step_fn=evaluate_batch_step_fn, | ||||
| evaluate_fn=evaluate_fn, input_mapping=input_mapping, | |||||
| output_mapping=output_mapping, fp16=fp16, verbose=0, | |||||
| evaluate_fn=evaluate_fn, input_mapping=evaluate_input_mapping, | |||||
| output_mapping=evaluate_output_mapping, fp16=fp16, verbose=0, | |||||
| use_dist_sampler=kwargs.get("evaluate_use_dist_sampler", None), | use_dist_sampler=kwargs.get("evaluate_use_dist_sampler", None), | ||||
| progress_bar=progress_bar) | progress_bar=progress_bar) | ||||
| @@ -279,7 +290,8 @@ class Trainer(TrainerEventTrigger): | |||||
| self.dataloader = self.driver.set_dist_repro_dataloader(dataloader=self.train_dataloader, dist=_dist_sampler, | self.dataloader = self.driver.set_dist_repro_dataloader(dataloader=self.train_dataloader, dist=_dist_sampler, | ||||
| reproducible=self.callback_manager._need_reproducible_sampler) | reproducible=self.callback_manager._need_reproducible_sampler) | ||||
| self.set_grad_to_none = kwargs.get("set_grad_to_none", True) | |||||
| _torch_kwargs = kwargs.get("torch_kwargs", {}) | |||||
| self.set_grad_to_none = _torch_kwargs.get("set_grad_to_none", True) | |||||
| self.evaluate_batch_step_fn = evaluate_batch_step_fn | self.evaluate_batch_step_fn = evaluate_batch_step_fn | ||||
| self.kwargs = kwargs | self.kwargs = kwargs | ||||
| @@ -360,6 +372,14 @@ class Trainer(TrainerEventTrigger): | |||||
| self.on_exception(e) | self.on_exception(e) | ||||
| if not catch_KeyboardInterrupt: | if not catch_KeyboardInterrupt: | ||||
| raise e | raise e | ||||
| except RuntimeError as e: | |||||
| if 'torch' in self.driver_name.lower(): # 如果是 torch ,需要检测一下 find_unused_parameters | |||||
| if 'find_unused_parameters' in e.args[0]: | |||||
| logger.error("You may need to pass `torch_ddp_kwargs={'find_unused_parameters': True}` in the " | |||||
| "Trainer initialization to avoid this error.") | |||||
| self.driver.on_exception() | |||||
| self.on_exception(e) | |||||
| raise e | |||||
| except BaseException as e: | except BaseException as e: | ||||
| self.driver.on_exception() | self.driver.on_exception() | ||||
| self.on_exception(e) | self.on_exception(e) | ||||
| @@ -5,10 +5,13 @@ __all__ = [ | |||||
| 'JittorDataLoader', | 'JittorDataLoader', | ||||
| 'prepare_jittor_dataloader', | 'prepare_jittor_dataloader', | ||||
| 'prepare_paddle_dataloader', | 'prepare_paddle_dataloader', | ||||
| 'prepare_torch_dataloader' | |||||
| 'prepare_torch_dataloader', | |||||
| "prepare_dataloader" | |||||
| ] | ] | ||||
| from .mix_dataloader import MixDataLoader | from .mix_dataloader import MixDataLoader | ||||
| from .jittor_dataloader import JittorDataLoader, prepare_jittor_dataloader | from .jittor_dataloader import JittorDataLoader, prepare_jittor_dataloader | ||||
| from .torch_dataloader import TorchDataLoader, prepare_torch_dataloader | from .torch_dataloader import TorchDataLoader, prepare_torch_dataloader | ||||
| from .paddle_dataloader import PaddleDataLoader, prepare_paddle_dataloader | from .paddle_dataloader import PaddleDataLoader, prepare_paddle_dataloader | ||||
| from .prepare_dataloader import prepare_dataloader | |||||
| @@ -4,6 +4,7 @@ __all__ = [ | |||||
| ] | ] | ||||
| from typing import Callable, Optional, List, Union | from typing import Callable, Optional, List, Union | ||||
| from copy import deepcopy | |||||
| from fastNLP.envs.imports import _NEED_IMPORT_JITTOR | from fastNLP.envs.imports import _NEED_IMPORT_JITTOR | ||||
| @@ -70,10 +71,12 @@ class JittorDataLoader: | |||||
| if isinstance(collate_fn, str): | if isinstance(collate_fn, str): | ||||
| if collate_fn == "auto": | if collate_fn == "auto": | ||||
| if isinstance(self.dataset.dataset, FDataSet): | if isinstance(self.dataset.dataset, FDataSet): | ||||
| self.collate_fn = self.dataset.dataset.collator | |||||
| self.collate_fn.set_backend(backend="jittor") | |||||
| self.collate_fn = deepcopy(self.dataset.dataset.collator) | |||||
| # jittor 比较特殊,只需要保证返回 numpy.array, 其Dataloader会转为jt.var | |||||
| self.collate_fn.set_backend(backend="numpy") | |||||
| else: | else: | ||||
| self.collate_fn = Collator(backend="jittor") | |||||
| # jittor 比较特殊,只需要保证返回 numpy.array, 其Dataloader会转为jt.var | |||||
| self.collate_fn = Collator(backend="numpy") | |||||
| else: | else: | ||||
| raise ValueError(f"collate_fn: {collate_fn} must be 'auto'") | raise ValueError(f"collate_fn: {collate_fn} must be 'auto'") | ||||
| elif isinstance(collate_fn, Callable): | elif isinstance(collate_fn, Callable): | ||||
| @@ -4,6 +4,7 @@ __all__ = [ | |||||
| ] | ] | ||||
| from typing import Callable, List, Optional, Union, Dict, Sequence | from typing import Callable, List, Optional, Union, Dict, Sequence | ||||
| from copy import deepcopy | |||||
| from fastNLP.envs.imports import _NEED_IMPORT_PADDLE | from fastNLP.envs.imports import _NEED_IMPORT_PADDLE | ||||
| @@ -68,7 +69,7 @@ class PaddleDataLoader(DataLoader): | |||||
| if isinstance(collate_fn, str): | if isinstance(collate_fn, str): | ||||
| if collate_fn == 'auto': | if collate_fn == 'auto': | ||||
| if isinstance(dataset.dataset, FDataSet): | if isinstance(dataset.dataset, FDataSet): | ||||
| collate_fn = dataset.dataset.collator | |||||
| collate_fn = deepcopy(dataset.dataset.collator) | |||||
| collate_fn.set_backend(backend="paddle") | collate_fn.set_backend(backend="paddle") | ||||
| else: | else: | ||||
| collate_fn = Collator(backend="paddle") | collate_fn = Collator(backend="paddle") | ||||
| @@ -0,0 +1,114 @@ | |||||
| __all__ = [ | |||||
| 'prepare_dataloader' | |||||
| ] | |||||
| from typing import Union, Callable | |||||
| import os | |||||
| import sys | |||||
| from ..samplers import RandomBatchSampler, RandomSampler | |||||
| from .torch_dataloader import prepare_torch_dataloader | |||||
| from .paddle_dataloader import prepare_paddle_dataloader | |||||
| from .jittor_dataloader import prepare_jittor_dataloader | |||||
| from ...envs import FASTNLP_BACKEND, SUPPORT_BACKENDS, _module_available | |||||
| from ..log import logger | |||||
| def prepare_dataloader(dataset, batch_size: int = 16, shuffle: bool = False, drop_last: bool = False, | |||||
| collate_fn: Union[Callable, str, None] = 'auto', num_workers: int = 0, | |||||
| seed: int = 0, backend: str = 'auto'): | |||||
| """ | |||||
| 自动创建合适的 ``DataLoader`` 对象。例如,检测当当前环境是 ``torch`` 的,则返回 ``TorchDataLoader`` , 是 ``paddle`` 的则 | |||||
| 返回 ``PaddleDataLoader`` 。如果有更多需要定制的参数,请直接使用对应的 ``prepare`` 函数,例如 | |||||
| :func:`~fastNLP.prepare_torch_dataloader` 或 :func:`~fastNLP.prepare_paddle_dataloader` 等。 | |||||
| :param dataset: 实现 __getitem__() 和 __len__() 的对象;或这种对象的序列;或字典。 | |||||
| * 为单个数据集对象时 | |||||
| 返回一个 DataLoader 。 | |||||
| * 为数据集对象序列时 | |||||
| 返回一个序列的 DataLoader 。 | |||||
| * 为字典型 或 :class:`~fastNLP.io.DataBundle` 数据时,返回 `Dict` 类型的数据。 | |||||
| 返回一个字典 。 | |||||
| :param batch_size: 批次大小。 | |||||
| :param shuffle: 是否打乱数据集。 | |||||
| :param drop_last: 当最后一个 batch 不足 batch_size 数量的是否,是否丢弃。 | |||||
| :param collate_fn: 用于处理一个 batch 的函数,一般包括 padding 和转为 tensor。有以下三种取值: | |||||
| * 为 ``auto`` 时 | |||||
| 使用 :class:`~fastNLP.Collator` 进行 padding 和 转tensor 。 | |||||
| * 为 ``Callable`` 时 | |||||
| 应当接受一个 ``batch`` 的数据作为参数,同时输出一个对象 。 | |||||
| * 为 ``None`` 时 | |||||
| 使用各个框架的 DataLoader 的默认 ``collate_fn`` 。 | |||||
| :param num_workers: 使用多少进程进行数据的 fetch 。 | |||||
| :param seed: 使用的随机数种子。 | |||||
| :param backend: 当前支持 ``["auto", "torch", "paddle", "jittor"]`` 四种类型。 | |||||
| * 为 ``auto`` 时 | |||||
| 首先(1) 根据环境变量 "FASTNLP_BACKEND" 进行判断;如果没有设置则,(2)通过当前 ``sys.modules`` 中已经 import 的 | |||||
| ``backend`` 进行判定。如果以上均无法判定,则报错。如果找到了 ``backend`` ,则按照下述的方式处理。 | |||||
| * 为 ``torch`` 时 | |||||
| 使用 :func:`~fastNLP.prepare_torch_dataloader` 。 | |||||
| * 为 ``paddle`` 时 | |||||
| 使用 :func:`~fastNLP.prepare_paddle_dataloader` 。 | |||||
| * 为 ``jittor`` 时 | |||||
| 使用 :func:`~fastNLP.prepare_jittor_dataloader` 。 | |||||
| :return | |||||
| """ | |||||
| if backend == 'auto': | |||||
| backend = _get_backend() | |||||
| if backend == 'torch': | |||||
| batch_sampler = RandomBatchSampler(dataset=dataset, batch_size=batch_size, shuffle=shuffle, | |||||
| drop_last=drop_last, seed=seed) | |||||
| return prepare_torch_dataloader(ds_or_db=dataset, batch_sampler=batch_sampler, collate_fn=collate_fn, | |||||
| num_workers=num_workers, shuffle=False, sampler=None) | |||||
| elif backend == 'paddle': | |||||
| batch_sampler = RandomBatchSampler(dataset=dataset, batch_size=batch_size, shuffle=shuffle, | |||||
| drop_last=drop_last, seed=seed) | |||||
| return prepare_paddle_dataloader(ds_or_db=dataset, batch_sampler=batch_sampler, collate_fn=collate_fn, | |||||
| num_workers=num_workers) | |||||
| elif backend == 'jittor': | |||||
| sampler = RandomSampler(dataset=dataset, shuffle=shuffle, seed=seed) | |||||
| prepare_jittor_dataloader(ds_or_db=dataset, sampler=sampler, collate_fn=collate_fn, | |||||
| num_workers=num_workers, batch_size=batch_size, shuffle=shuffle, | |||||
| drop_last=drop_last) | |||||
| else: | |||||
| raise ValueError(f"Currently we do not support backend:{backend}.") | |||||
| def _check_module(module): | |||||
| """ | |||||
| 检查该 module 是否含有 某个 backend 的特征 | |||||
| :param module: module 对象 | |||||
| :return: | |||||
| """ | |||||
| try: | |||||
| file = module.__file__ | |||||
| for backend in SUPPORT_BACKENDS: | |||||
| if f'{os.sep}site-packages{os.sep}{backend}' in file: | |||||
| return backend | |||||
| except: | |||||
| pass | |||||
| return None | |||||
| def _get_backend(): | |||||
| if os.environ.get(FASTNLP_BACKEND, None) != None: | |||||
| backend = os.environ.get(FASTNLP_BACKEND) | |||||
| logger.debug(f"Get Dataloader backend:{backend} from os.environ") | |||||
| else: | |||||
| available_backends = set() | |||||
| for module in sys.modules.values(): | |||||
| _backend = _check_module(module) | |||||
| if _backend: | |||||
| available_backends.add(_backend) | |||||
| if len(available_backends) == 1: | |||||
| backend = available_backends.pop() | |||||
| logger.debug(f"Get Dataloader backend:{backend} from sys.modules.") | |||||
| else: | |||||
| raise RuntimeError("Fail to detect dataloader backend automatically, please set it manually.") | |||||
| return backend | |||||
| @@ -4,7 +4,7 @@ __all__ = [ | |||||
| ] | ] | ||||
| from typing import Optional, Callable, Sequence, Union, Tuple, Dict, Mapping, List | from typing import Optional, Callable, Sequence, Union, Tuple, Dict, Mapping, List | ||||
| import inspect | |||||
| from copy import deepcopy | |||||
| from fastNLP.core.dataset import DataSet | from fastNLP.core.dataset import DataSet | ||||
| from fastNLP.core.collators import Collator | from fastNLP.core.collators import Collator | ||||
| @@ -84,7 +84,7 @@ class TorchDataLoader(DataLoader): | |||||
| if isinstance(collate_fn, str): | if isinstance(collate_fn, str): | ||||
| if collate_fn == 'auto': | if collate_fn == 'auto': | ||||
| if isinstance(dataset.dataset, DataSet): # 使用了 fastnlp dataset | if isinstance(dataset.dataset, DataSet): # 使用了 fastnlp dataset | ||||
| collate_fn = dataset.dataset.collator | |||||
| collate_fn = deepcopy(dataset.dataset.collator) | |||||
| collate_fn.set_backend(backend="torch") | collate_fn.set_backend(backend="torch") | ||||
| else: | else: | ||||
| collate_fn = Collator(backend="torch") | collate_fn = Collator(backend="torch") | ||||
| @@ -250,26 +250,15 @@ def prepare_torch_dataloader(ds_or_db, | |||||
| elif isinstance(ds_or_db, Sequence): | elif isinstance(ds_or_db, Sequence): | ||||
| dl_bundle = [] | dl_bundle = [] | ||||
| for idx, ds in enumerate(ds_or_db): | for idx, ds in enumerate(ds_or_db): | ||||
| if idx == 0: | |||||
| dl_bundle.append( | |||||
| TorchDataLoader(dataset=ds, batch_size=batch_size, | |||||
| shuffle=shuffle, sampler=sampler, batch_sampler=batch_sampler, | |||||
| num_workers=num_workers, collate_fn=collate_fn, pin_memory=pin_memory, | |||||
| drop_last=drop_last, timeout=timeout, worker_init_fn=worker_init_fn, | |||||
| multiprocessing_context=multiprocessing_context, generator=generator, | |||||
| prefetch_factor=prefetch_factor, persistent_workers=persistent_workers, | |||||
| ) | |||||
| ) | |||||
| else: | |||||
| dl_bundle.append( | |||||
| TorchDataLoader(dataset=ds, batch_size=batch_size, | |||||
| shuffle=shuffle, sampler=sampler, batch_sampler=batch_sampler, | |||||
| num_workers=num_workers, collate_fn=collate_fn, pin_memory=pin_memory, | |||||
| drop_last=drop_last, timeout=timeout, worker_init_fn=worker_init_fn, | |||||
| multiprocessing_context=multiprocessing_context, generator=generator, | |||||
| prefetch_factor=prefetch_factor, persistent_workers=persistent_workers, | |||||
| ) | |||||
| ) | |||||
| dl_bundle.append( | |||||
| TorchDataLoader(dataset=ds, batch_size=batch_size, | |||||
| shuffle=shuffle, sampler=sampler, batch_sampler=batch_sampler, | |||||
| num_workers=num_workers, collate_fn=collate_fn, pin_memory=pin_memory, | |||||
| drop_last=drop_last, timeout=timeout, worker_init_fn=worker_init_fn, | |||||
| multiprocessing_context=multiprocessing_context, generator=generator, | |||||
| prefetch_factor=prefetch_factor, persistent_workers=persistent_workers, | |||||
| ) | |||||
| ) | |||||
| return dl_bundle | return dl_bundle | ||||
| elif isinstance(ds_or_db, Mapping): | elif isinstance(ds_or_db, Mapping): | ||||
| @@ -17,7 +17,7 @@ def choose_driver(model, driver: Union[str, Driver], device: Optional[Union[int, | |||||
| if isinstance(driver, Driver): | if isinstance(driver, Driver): | ||||
| return driver | return driver | ||||
| if driver in {"torch", "torch_ddp", "fairscale"}: | |||||
| if driver in {"torch", "fairscale"}: | |||||
| from fastNLP.core.drivers.torch_driver.initialize_torch_driver import initialize_torch_driver | from fastNLP.core.drivers.torch_driver.initialize_torch_driver import initialize_torch_driver | ||||
| return initialize_torch_driver(driver, device, model, **kwargs) | return initialize_torch_driver(driver, device, model, **kwargs) | ||||
| elif driver in {"jittor"}: | elif driver in {"jittor"}: | ||||
| @@ -27,5 +27,5 @@ def choose_driver(model, driver: Union[str, Driver], device: Optional[Union[int, | |||||
| from fastNLP.core.drivers.paddle_driver.initialize_paddle_driver import initialize_paddle_driver | from fastNLP.core.drivers.paddle_driver.initialize_paddle_driver import initialize_paddle_driver | ||||
| return initialize_paddle_driver(driver, device, model, **kwargs) | return initialize_paddle_driver(driver, device, model, **kwargs) | ||||
| else: | else: | ||||
| raise ValueError("Parameter `driver` can only be one of these values: ['torch', 'torch_ddp', 'fairscale', " | |||||
| raise ValueError("Parameter `driver` can only be one of these values: ['torch', 'fairscale', " | |||||
| "'jittor', 'paddle', 'fleet'].") | "'jittor', 'paddle', 'fleet'].") | ||||
| @@ -285,7 +285,7 @@ class PaddleFleetDriver(PaddleDriver): | |||||
| self.world_size = int(os.environ.get("PADDLE_TRAINERS_NUM")) | self.world_size = int(os.environ.get("PADDLE_TRAINERS_NUM")) | ||||
| self.global_rank = int(os.environ.get("PADDLE_TRAINER_ID")) | self.global_rank = int(os.environ.get("PADDLE_TRAINER_ID")) | ||||
| reset_seed() | reset_seed() | ||||
| logger.info(f"\nworld size, global rank: {self.world_size}, {self.global_rank}\n") | |||||
| logger.info(f"World size: {self.world_size}, Global rank: {self.global_rank}") | |||||
| if not parallel_helper._is_parallel_ctx_initialized(): | if not parallel_helper._is_parallel_ctx_initialized(): | ||||
| fleet.init(self.role_maker, self.is_collective, self.strategy) | fleet.init(self.role_maker, self.is_collective, self.strategy) | ||||
| @@ -220,7 +220,7 @@ class TorchDDPDriver(TorchDriver): | |||||
| self.world_size = None # int(os.environ.get("WORLD_SIZE")) len(self.parallel_device) | self.world_size = None # int(os.environ.get("WORLD_SIZE")) len(self.parallel_device) | ||||
| self.global_rank = 0 | self.global_rank = 0 | ||||
| self._ddp_kwargs = kwargs.get("torch_ddp_kwargs", {}) | |||||
| self._ddp_kwargs = self._torch_kwargs.get("ddp_kwargs", {}) | |||||
| check_user_specific_params(self._ddp_kwargs, DistributedDataParallel.__init__) | check_user_specific_params(self._ddp_kwargs, DistributedDataParallel.__init__) | ||||
| if len(self.model._buffers) != 0 and self._ddp_kwargs.get("broadcast_buffers", None) is None: | if len(self.model._buffers) != 0 and self._ddp_kwargs.get("broadcast_buffers", None) is None: | ||||
| logger.info("Notice your model has buffers and you are using `TorchDDPDriver`, but you do not set " | logger.info("Notice your model has buffers and you are using `TorchDDPDriver`, but you do not set " | ||||
| @@ -251,7 +251,7 @@ class TorchDDPDriver(TorchDriver): | |||||
| self.world_size = int(os.environ.get("WORLD_SIZE")) | self.world_size = int(os.environ.get("WORLD_SIZE")) | ||||
| self.global_rank = int(os.environ.get("RANK")) | self.global_rank = int(os.environ.get("RANK")) | ||||
| reset_seed() | reset_seed() | ||||
| logger.info(f"World size:{self.world_size}, Global rank:{self.global_rank}") | |||||
| logger.info(f"World size: {self.world_size}, Global rank: {self.global_rank}") | |||||
| if not dist.is_initialized(): | if not dist.is_initialized(): | ||||
| dist.init_process_group( | dist.init_process_group( | ||||
| @@ -32,7 +32,7 @@ def initialize_torch_driver(driver: str, device: Optional[Union[str, "torch.devi | |||||
| "`os.environ['LOCAL_RANK']`.") | "`os.environ['LOCAL_RANK']`.") | ||||
| return TorchDDPDriver(model, torch.device(f"cuda:{os.environ['LOCAL_RANK']}"), True, **kwargs) | return TorchDDPDriver(model, torch.device(f"cuda:{os.environ['LOCAL_RANK']}"), True, **kwargs) | ||||
| if driver not in {"torch", "torch_ddp", "fairscale"}: | |||||
| if driver not in {"torch", "fairscale"}: | |||||
| raise ValueError("Parameter `driver` can only be one of these values: ['torch', 'torch_ddp', 'fairscale'].") | raise ValueError("Parameter `driver` can only be one of these values: ['torch', 'torch_ddp', 'fairscale'].") | ||||
| _could_use_device_num = torch.cuda.device_count() | _could_use_device_num = torch.cuda.device_count() | ||||
| @@ -61,22 +61,9 @@ def initialize_torch_driver(driver: str, device: Optional[Union[str, "torch.devi | |||||
| elif device is not None and not isinstance(device, torch.device): | elif device is not None and not isinstance(device, torch.device): | ||||
| raise ValueError("Parameter `device` is wrong type, please check our documentation for the right use.") | raise ValueError("Parameter `device` is wrong type, please check our documentation for the right use.") | ||||
| if driver == "torch": | |||||
| if driver == "torch": # single, ddp, 直接启动。 | |||||
| if not isinstance(device, List): | if not isinstance(device, List): | ||||
| return TorchSingleDriver(model, device, **kwargs) | return TorchSingleDriver(model, device, **kwargs) | ||||
| else: | |||||
| logger.info("Notice you are using `torch` driver but your chosen `device` are multi gpus, we will use " | |||||
| "`TorchDDPDriver` by default. But if you mean using `TorchDDPDriver`, you should choose parameter" | |||||
| "`driver` as `TorchDDPDriver`.") | |||||
| return TorchDDPDriver(model, device, **kwargs) | |||||
| elif driver == "torch_ddp": | |||||
| if device is not None and not isinstance(device, List): | |||||
| if device.type == 'cpu': | |||||
| raise ValueError("You are using `torch_ddp` driver, but your chosen `device` is 'cpu'.") | |||||
| logger.info("Notice you are using `torch_ddp` driver, but your chosen `device` is only one gpu, we will " | |||||
| "still use `TorchDDPDriver` for you, but if you mean using `torch_ddp`, you should " | |||||
| "choose `torch` driver.") | |||||
| return TorchDDPDriver(model, [device], **kwargs) | |||||
| else: | else: | ||||
| return TorchDDPDriver(model, device, **kwargs) | return TorchDDPDriver(model, device, **kwargs) | ||||
| elif driver == "fairscale": | elif driver == "fairscale": | ||||
| @@ -49,8 +49,9 @@ class TorchDriver(Driver): | |||||
| self.auto_cast, _grad_scaler = _build_fp16_env(dummy=not fp16) | self.auto_cast, _grad_scaler = _build_fp16_env(dummy=not fp16) | ||||
| self.grad_scaler = _grad_scaler() | self.grad_scaler = _grad_scaler() | ||||
| self._torch_kwargs = kwargs.get("torch_kwargs", {}) | |||||
| # 用来设置 `torch_move_data_to_device` 中的 `non_blocking` 参数; | # 用来设置 `torch_move_data_to_device` 中的 `non_blocking` 参数; | ||||
| self.non_blocking = kwargs.get("torch_non_blocking", True) | |||||
| self.non_blocking = self._torch_kwargs.get("torch_non_blocking", True) | |||||
| # 用来设置是否关闭 auto_param_call 中的参数匹配问题; | # 用来设置是否关闭 auto_param_call 中的参数匹配问题; | ||||
| self.wo_auto_param_call = kwargs.get("model_wo_auto_param_call", False) | self.wo_auto_param_call = kwargs.get("model_wo_auto_param_call", False) | ||||
| @@ -22,6 +22,8 @@ import numpy as np | |||||
| from pathlib import Path | from pathlib import Path | ||||
| from fastNLP.core.log import logger | from fastNLP.core.log import logger | ||||
| from ...envs import SUPPORT_BACKENDS | |||||
| __all__ = [ | __all__ = [ | ||||
| 'get_fn_arg_names', | 'get_fn_arg_names', | ||||
| @@ -13,11 +13,13 @@ from tests.helpers.datasets.torch_data import TorchNormalDataset_Classification, | |||||
| from tests.helpers.callbacks.helper_callbacks import RecordLossCallback, RecordMetricCallback | from tests.helpers.callbacks.helper_callbacks import RecordLossCallback, RecordMetricCallback | ||||
| from tests.helpers.utils import magic_argv_env_context | from tests.helpers.utils import magic_argv_env_context | ||||
| from fastNLP.envs.imports import _NEED_IMPORT_TORCH | from fastNLP.envs.imports import _NEED_IMPORT_TORCH | ||||
| if _NEED_IMPORT_TORCH: | if _NEED_IMPORT_TORCH: | ||||
| from torch.optim import SGD | from torch.optim import SGD | ||||
| from torch.utils.data import DataLoader | from torch.utils.data import DataLoader | ||||
| import torch.distributed as dist | import torch.distributed as dist | ||||
| @dataclass | @dataclass | ||||
| class NormalClassificationTrainTorchConfig: | class NormalClassificationTrainTorchConfig: | ||||
| num_labels: int = 2 | num_labels: int = 2 | ||||
| @@ -101,7 +103,8 @@ def model_and_optimizers(request): | |||||
| # 测试一下普通的情况; | # 测试一下普通的情况; | ||||
| @pytest.mark.torch | @pytest.mark.torch | ||||
| @pytest.mark.parametrize("driver,device", [("torch", "cpu"), ("torch", 1), ("torch", [0, 1])]) # ("torch", "cpu"), ("torch", 1), ("torch", [0, 1]) | |||||
| @pytest.mark.parametrize("driver,device", [("torch", "cpu"), ("torch", 1), | |||||
| ("torch", [0, 1])]) # ("torch", "cpu"), ("torch", 1), ("torch", [0, 1]) | |||||
| @pytest.mark.parametrize("evaluate_every", [-3, -1, 100]) | @pytest.mark.parametrize("evaluate_every", [-3, -1, 100]) | ||||
| @magic_argv_env_context | @magic_argv_env_context | ||||
| def test_trainer_torch_with_evaluator( | def test_trainer_torch_with_evaluator( | ||||
| @@ -173,6 +176,7 @@ def test_trainer_torch_with_evaluator_fp16_accumulation_steps( | |||||
| if dist.is_initialized(): | if dist.is_initialized(): | ||||
| dist.destroy_process_group() | dist.destroy_process_group() | ||||
| @pytest.mark.torch | @pytest.mark.torch | ||||
| @pytest.mark.parametrize("driver,device", [("torch", 'cpu')]) # ("torch", [0, 1]),("torch", 1) | @pytest.mark.parametrize("driver,device", [("torch", 'cpu')]) # ("torch", [0, 1]),("torch", 1) | ||||
| @magic_argv_env_context | @magic_argv_env_context | ||||
| @@ -182,7 +186,6 @@ def test_trainer_validate_every( | |||||
| device, | device, | ||||
| n_epochs=6, | n_epochs=6, | ||||
| ): | ): | ||||
| def validate_every(trainer): | def validate_every(trainer): | ||||
| if trainer.global_forward_batches % 10 == 0: | if trainer.global_forward_batches % 10 == 0: | ||||
| print("\nfastNLP test validate every.\n") | print("\nfastNLP test validate every.\n") | ||||
| @@ -234,7 +237,7 @@ def test_trainer_on( | |||||
| device=device, | device=device, | ||||
| optimizers=model_and_optimizers.optimizers, | optimizers=model_and_optimizers.optimizers, | ||||
| train_dataloader=model_and_optimizers.train_dataloader, | train_dataloader=model_and_optimizers.train_dataloader, | ||||
| evaluate_dataloaders={"dl":model_and_optimizers.evaluate_dataloaders}, | |||||
| evaluate_dataloaders={"dl": model_and_optimizers.evaluate_dataloaders}, | |||||
| input_mapping=model_and_optimizers.input_mapping, | input_mapping=model_and_optimizers.input_mapping, | ||||
| output_mapping=model_and_optimizers.output_mapping, | output_mapping=model_and_optimizers.output_mapping, | ||||
| metrics=model_and_optimizers.metrics, | metrics=model_and_optimizers.metrics, | ||||
| @@ -243,10 +246,94 @@ def test_trainer_on( | |||||
| evaluate_every=-1 | evaluate_every=-1 | ||||
| ) | ) | ||||
| trainer.run() | |||||
| trainer.run() | |||||
| @pytest.mark.torch | |||||
| @pytest.mark.parametrize("driver,device", [("torch", 'cpu'), ("torch", 0)]) # ("torch", [0, 1]),("torch", 1) | |||||
| @magic_argv_env_context | |||||
| def test_trainer_specific_params_1( | |||||
| model_and_optimizers: TrainerParameters, | |||||
| driver, | |||||
| device, | |||||
| n_epochs=2, | |||||
| ): | |||||
| """ | |||||
| 测试一些特殊的参数是否能够正确地传递; | |||||
| """ | |||||
| trainer = Trainer( | |||||
| model=model_and_optimizers.model, | |||||
| driver=driver, | |||||
| device=device, | |||||
| optimizers=model_and_optimizers.optimizers, | |||||
| train_dataloader=model_and_optimizers.train_dataloader, | |||||
| evaluate_dataloaders={"dl": model_and_optimizers.evaluate_dataloaders}, | |||||
| input_mapping=model_and_optimizers.input_mapping, | |||||
| output_mapping=model_and_optimizers.output_mapping, | |||||
| metrics=model_and_optimizers.metrics, | |||||
| n_epochs=n_epochs, | |||||
| output_from_new_proc="all", | |||||
| evaluate_every=-1, | |||||
| model_wo_auto_param_call=True, | |||||
| torch_kwargs={ | |||||
| "torch_non_blocking": False, | |||||
| "set_grad_to_none": True | |||||
| } | |||||
| ) | |||||
| assert trainer.set_grad_to_none is True | |||||
| assert trainer.driver.non_blocking is False | |||||
| assert trainer.driver.wo_auto_param_call is True | |||||
| @pytest.mark.torch | |||||
| @pytest.mark.parametrize("driver,device", [("torch", [0, 1])]) # ("torch", [0, 1]),("torch", 1) | |||||
| @magic_argv_env_context | |||||
| def test_trainer_specific_params_2( | |||||
| model_and_optimizers: TrainerParameters, | |||||
| driver, | |||||
| device, | |||||
| n_epochs=2, | |||||
| ): | |||||
| """ | |||||
| 测试一些特殊的参数是否能够正确地传递; | |||||
| """ | |||||
| trainer = Trainer( | |||||
| model=model_and_optimizers.model, | |||||
| driver=driver, | |||||
| device=device, | |||||
| optimizers=model_and_optimizers.optimizers, | |||||
| train_dataloader=model_and_optimizers.train_dataloader, | |||||
| evaluate_dataloaders={"dl": model_and_optimizers.evaluate_dataloaders}, | |||||
| input_mapping=model_and_optimizers.input_mapping, | |||||
| output_mapping=model_and_optimizers.output_mapping, | |||||
| metrics=model_and_optimizers.metrics, | |||||
| n_epochs=n_epochs, | |||||
| output_from_new_proc="all", | |||||
| evaluate_every=-1, | |||||
| model_wo_auto_param_call=True, | |||||
| torch_kwargs={ | |||||
| "ddp_kwargs": { | |||||
| "broadcast_buffers": True, | |||||
| "find_unused_parameters": True | |||||
| }, | |||||
| "torch_non_blocking": False, | |||||
| "set_grad_to_none": True | |||||
| } | |||||
| ) | |||||
| assert trainer.set_grad_to_none is True | |||||
| assert trainer.driver.non_blocking is False | |||||
| assert trainer.driver.wo_auto_param_call is True | |||||
| assert trainer.driver.output_from_new_proc == "all" | |||||
| _ddp_kwargs = trainer.driver._ddp_kwargs | |||||
| assert _ddp_kwargs.get("broadcast_buffers") is True | |||||
| assert _ddp_kwargs.get("find_unused_parameters") is True | |||||
| @@ -0,0 +1,13 @@ | |||||
| import pytest | |||||
| from fastNLP import prepare_dataloader | |||||
| from fastNLP import DataSet | |||||
| @pytest.mark.torch | |||||
| def test_torch(): | |||||
| import torch | |||||
| ds = DataSet({"x": [[1, 2], [2, 3, 4], [4, 5, 6, 7]] * 10, "y": [1, 0, 1] * 10}) | |||||
| dl = prepare_dataloader(ds, batch_size=2, shuffle=True) | |||||
| for batch in dl: | |||||
| assert isinstance(batch['x'], torch.Tensor) | |||||