| @@ -47,6 +47,7 @@ from fastNLP.core.collators.collator import Collator | |||
| from fastNLP.core.dataloaders.utils import indice_collate_wrapper | |||
| from fastNLP.core.dataset import DataSet as FDataSet | |||
| from fastNLP.core.samplers import ReproducibleBatchSampler, RandomBatchSampler | |||
| from ..utils import _match_param | |||
| class _PaddleDataset(Dataset): | |||
| @@ -154,14 +155,17 @@ class PaddleDataLoader(DataLoader): | |||
| else: | |||
| raise ValueError(f"collate_fn: {collate_fn} must be 'auto'") | |||
| super(PaddleDataLoader, self).__init__(dataset=dataset, feed_list=feed_list, places=places, | |||
| return_list=return_list, batch_sampler=batch_sampler, | |||
| batch_size=batch_size, shuffle=shuffle, drop_last=drop_last, | |||
| collate_fn=collate_fn, num_workers=num_workers, | |||
| use_buffer_reader=use_buffer_reader, use_shared_memory=use_shared_memory, | |||
| timeout=timeout, worker_init_fn=worker_init_fn, | |||
| persistent_workers=persistent_workers) | |||
| dl_kwargs = _match_param(PaddleDataLoader.__init__, DataLoader.__init__, DataLoader.__name__) | |||
| if dl_kwargs is None: | |||
| super(PaddleDataLoader, self).__init__(dataset=dataset, feed_list=feed_list, places=places, | |||
| return_list=return_list, batch_sampler=batch_sampler, | |||
| batch_size=batch_size, shuffle=shuffle, drop_last=drop_last, | |||
| collate_fn=collate_fn, num_workers=num_workers, | |||
| use_buffer_reader=use_buffer_reader, use_shared_memory=use_shared_memory, | |||
| timeout=timeout, worker_init_fn=worker_init_fn, | |||
| persistent_workers=persistent_workers) | |||
| else: | |||
| super().__init__(**dl_kwargs) | |||
| # _collate_fn = _MultiCollator(AutoCollator(as_numpy=True)) | |||
| # if collate_fn is not None: | |||
| # _collate_fn.add_collator(collate_fn) | |||
| @@ -11,6 +11,7 @@ from fastNLP.core.collators import Collator | |||
| from fastNLP.core.dataloaders.utils import indice_collate_wrapper | |||
| from fastNLP.envs.imports import _NEED_IMPORT_TORCH | |||
| from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler, UnrepeatedSampler, RandomSampler | |||
| from ..utils import _match_param | |||
| if _NEED_IMPORT_TORCH: | |||
| from torch.utils.data import DataLoader, Sampler | |||
| @@ -96,12 +97,16 @@ class TorchDataLoader(DataLoader): | |||
| else: | |||
| raise ValueError(f"collate_fn: {collate_fn} must be 'auto'") | |||
| super().__init__(dataset=dataset, batch_size=batch_size, shuffle=shuffle, sampler=sampler, | |||
| batch_sampler=batch_sampler, num_workers=num_workers, collate_fn=collate_fn, | |||
| pin_memory=pin_memory, drop_last=drop_last, timeout=timeout, worker_init_fn=worker_init_fn, | |||
| multiprocessing_context=multiprocessing_context, generator=generator, | |||
| prefetch_factor=prefetch_factor, | |||
| persistent_workers=persistent_workers) | |||
| dl_kwargs = _match_param(TorchDataLoader.__init__, DataLoader.__init__, fn_name=DataLoader.__name__) | |||
| if dl_kwargs is None: | |||
| super().__init__(dataset=dataset, batch_size=batch_size, shuffle=shuffle, sampler=sampler, | |||
| batch_sampler=batch_sampler, num_workers=num_workers, collate_fn=collate_fn, | |||
| pin_memory=pin_memory, drop_last=drop_last, timeout=timeout, worker_init_fn=worker_init_fn, | |||
| multiprocessing_context=multiprocessing_context, generator=generator, | |||
| prefetch_factor=prefetch_factor, | |||
| persistent_workers=persistent_workers) | |||
| else: | |||
| super().__init__(**dl_kwargs) | |||
| self.cur_batch_indices = None | |||
| @@ -1,4 +1,9 @@ | |||
| from typing import Callable | |||
| import inspect | |||
| import ast | |||
| from ..log import logger | |||
| from ..utils.cache_results import get_func_calls, truncate_start_blanks | |||
| __all__ = [ | |||
| "indice_collate_wrapper" | |||
| ] | |||
| @@ -25,6 +30,72 @@ def indice_collate_wrapper(func:Callable): | |||
| return _indice_collate_wrapper | |||
| def _match_param(fun, call_fn:Callable, fn_name:str=None): | |||
| """ | |||
| 在调用 _match_param 的函数(就是 fun )中会调用 call_fn 这个函数。由于 fun 中支持的函数比 call_fn 更多,例如低版本的 | |||
| :class:`~.fastNLP.TorchDataLoader` 中支持的参数,在torch 1.6 版本的 DataLoader 就不支持,但在高版本的 torch 中是支持的 | |||
| 因此,这里需要根据当前版本的 DataLoader 判定出适合传入 DataLoader 进行初始化的参数,并且在不支持但又被设置的参数上进行 | |||
| warning 。 | |||
| :param fun: 调用函数本身 | |||
| :param call_fn: | |||
| :param fn_name: 方便报错的用的函数 | |||
| :return: | |||
| """ | |||
| try: | |||
| if fn_name is None: | |||
| try: | |||
| fn_name = call_fn.__name__ | |||
| except: | |||
| fn_name = str(call_fn) | |||
| last_frame = inspect.currentframe().f_back | |||
| # 调用 _match_param 的函数名称,获取默认的参数值 | |||
| fun_default_params = {} | |||
| fun_parameters = inspect.signature(fun) | |||
| for name, fun_param in fun_parameters.parameters.items(): | |||
| if fun_param.default is not fun_param.empty: | |||
| fun_default_params[name] = fun_param.default | |||
| # 获取实际传入的参数值 | |||
| param_names, args_name, kwargs_name, values = inspect.getargvalues(last_frame) | |||
| if args_name is not None: | |||
| raise RuntimeError("Function does not support positional arguments, such as: fun(*args).") | |||
| kwargs = values.get(kwargs_name, {}) | |||
| for param in param_names: | |||
| if param not in values: | |||
| value = fun_default_params.get(param) | |||
| else: | |||
| value = values[param] | |||
| kwargs[param] = value | |||
| # 根据需要实际需要调用的 call_fn 的参数进行匹配 | |||
| call_fn_parameters = inspect.signature(call_fn) | |||
| call_fn_kwargs = {} | |||
| has_kwargs = False | |||
| for name, param in call_fn_parameters.parameters.items(): | |||
| if name == 'self': | |||
| continue | |||
| if param.kind in (param.POSITIONAL_OR_KEYWORD, param.KEYWORD_ONLY): # 最前面的 args | |||
| call_fn_kwargs[name] = param.default | |||
| if param.kind == param.VAR_KEYWORD: | |||
| has_kwargs = True | |||
| # 组装得到最终的参数 | |||
| call_kwargs = {} | |||
| for name, value in kwargs.items(): | |||
| if name in call_fn_kwargs or has_kwargs: # 如果存在在里面,或者包含了 kwargs 就直接运行 | |||
| call_kwargs[name] = value | |||
| # 如果不在需要调用的函数里面,同时又是非默认值 | |||
| elif name not in call_fn_kwargs and name in fun_default_params and fun_default_params[name]!=value: | |||
| logger.rank_zero_warning(f"Parameter:{name} is not supported for {fn_name}.") | |||
| return call_kwargs | |||
| except BaseException as e: | |||
| logger.debug(f"Exception happens when match parameters for {fn_name}: {e}") | |||
| return None | |||
| if __name__ == '__main__': | |||
| def demo(*args, **kwargs): | |||
| pass | |||
| @@ -0,0 +1,39 @@ | |||
| import pytest | |||
| from fastNLP.core.dataloaders.utils import _match_param | |||
| from fastNLP import logger | |||
| from tests.helpers.utils import recover_logger, Capturing | |||
| def demo(): | |||
| pass | |||
| def test_no_args(): | |||
| def f(*args, a, b, **kwarg): | |||
| c = 100 | |||
| call_kwargs = _match_param(f, demo) | |||
| with pytest.raises(RuntimeError): | |||
| f(a=1, b=2) | |||
| def f(a, *args, b, **kwarg): | |||
| c = 100 | |||
| call_kwargs = _match_param(f, demo) | |||
| with pytest.raises(RuntimeError): | |||
| f(a=1, b=2) | |||
| @recover_logger | |||
| def test_warning(): | |||
| logger.set_stdout('raw') | |||
| def f1(a, b): | |||
| return 1 | |||
| def f2(a, b, c=2): | |||
| kwargs = _match_param(f2, f1) | |||
| return f1(*kwargs) | |||
| with Capturing() as out: | |||
| f2(a=1, b=2, c=3) | |||
| assert 'Parameter:c' in out[0] # 传入了需要 warning | |||
| assert f2(1, 2) == 1 | |||
| @@ -5,6 +5,9 @@ from fastNLP.core.dataset import DataSet | |||
| from fastNLP.io.data_bundle import DataBundle | |||
| from fastNLP.envs.imports import _NEED_IMPORT_TORCH | |||
| from fastNLP.core import Trainer | |||
| from pkg_resources import parse_version | |||
| from tests.helpers.utils import Capturing, recover_logger | |||
| from fastNLP import logger | |||
| if _NEED_IMPORT_TORCH: | |||
| import torch | |||
| @@ -128,3 +131,33 @@ class TestFdl: | |||
| dl = DataLoader(MyDatset(), collate_fn=collate_batch) | |||
| for batch in dl: | |||
| print(batch) | |||
| @recover_logger | |||
| def test_version_16(self): | |||
| if parse_version(torch.__version__) >= parse_version('1.7'): | |||
| pytest.skip("Torch version larger than 1.7") | |||
| logger.set_stdout() | |||
| ds = DataSet({"x": [[1, 2], [2, 3, 4], [4, 5, 6, 7]] * 10, "y": [1, 0, 1] * 10}) | |||
| with Capturing() as out: | |||
| dl = TorchDataLoader(ds, prefetch_factor=3, shuffle=False) | |||
| for idx, batch in enumerate(dl): | |||
| assert len(batch['x'])==1 | |||
| assert batch['x'][0].tolist() == ds[idx]['x'] | |||
| assert 'Parameter:prefetch_factor' in out[0] | |||
| @recover_logger | |||
| def test_version_111(self): | |||
| if parse_version(torch.__version__) <= parse_version('1.7'): | |||
| pytest.skip("Torch version smaller than 1.7") | |||
| logger.set_stdout() | |||
| ds = DataSet({"x": [[1, 2], [2, 3, 4], [4, 5, 6, 7]] * 10, "y": [1, 0, 1] * 10}) | |||
| with Capturing() as out: | |||
| dl = TorchDataLoader(ds, num_workers=2, prefetch_factor=3, shuffle=False) | |||
| for idx, batch in enumerate(dl): | |||
| assert len(batch['x'])==1 | |||
| assert batch['x'][0].tolist() == ds[idx]['x'] | |||
| assert 'Parameter:prefetch_factor' not in out[0] | |||