| @@ -5,10 +5,12 @@ __all__ = [ | |||||
| 'JittorDataLoader', | 'JittorDataLoader', | ||||
| 'prepare_jittor_dataloader', | 'prepare_jittor_dataloader', | ||||
| 'prepare_paddle_dataloader', | 'prepare_paddle_dataloader', | ||||
| 'prepare_torch_dataloader' | |||||
| 'prepare_torch_dataloader', | |||||
| 'indice_collate_wrapper' | |||||
| ] | ] | ||||
| from .mix_dataloader import MixDataLoader | from .mix_dataloader import MixDataLoader | ||||
| from .jittor_dataloader import JittorDataLoader, prepare_jittor_dataloader | from .jittor_dataloader import JittorDataLoader, prepare_jittor_dataloader | ||||
| from .torch_dataloader import TorchDataLoader, prepare_torch_dataloader | from .torch_dataloader import TorchDataLoader, prepare_torch_dataloader | ||||
| from .paddle_dataloader import PaddleDataLoader, prepare_paddle_dataloader | from .paddle_dataloader import PaddleDataLoader, prepare_paddle_dataloader | ||||
| from .utils import indice_collate_wrapper | |||||
| @@ -12,7 +12,7 @@ if _NEED_IMPORT_JITTOR: | |||||
| from jittor.dataset import Dataset | from jittor.dataset import Dataset | ||||
| else: | else: | ||||
| from fastNLP.core.dataset import DataSet as Dataset | from fastNLP.core.dataset import DataSet as Dataset | ||||
| from fastNLP.core.utils.jittor_utils import jittor_collate_wraps | |||||
| from fastNLP.core.collators import Collator | from fastNLP.core.collators import Collator | ||||
| from fastNLP.core.dataloaders.utils import indice_collate_wrapper | from fastNLP.core.dataloaders.utils import indice_collate_wrapper | ||||
| from fastNLP.core.dataset import DataSet as FDataSet | from fastNLP.core.dataset import DataSet as FDataSet | ||||
| @@ -1,3 +1,8 @@ | |||||
| __all__ = [ | |||||
| "indice_collate_wrapper" | |||||
| ] | |||||
| def indice_collate_wrapper(func): | def indice_collate_wrapper(func): | ||||
| """ | """ | ||||
| 其功能是封装一层collate_fn,将dataset取到的tuple数据分离开,将idx打包为indices。 | 其功能是封装一层collate_fn,将dataset取到的tuple数据分离开,将idx打包为indices。 | ||||
| @@ -42,7 +42,6 @@ class TestJittor: | |||||
| jtl = JittorDataLoader(dataset, keep_numpy_array=True, batch_size=4) | jtl = JittorDataLoader(dataset, keep_numpy_array=True, batch_size=4) | ||||
| # jtl.set_pad_val('x', 'y') | # jtl.set_pad_val('x', 'y') | ||||
| # jtl.set_input('x') | # jtl.set_input('x') | ||||
| print(str(jittor.Var([0]))) | |||||
| for batch in jtl: | for batch in jtl: | ||||
| print(batch) | print(batch) | ||||
| print(jtl.get_batch_indices()) | print(jtl.get_batch_indices()) | ||||