diff --git a/modelscope/msdatasets/ms_dataset.py b/modelscope/msdatasets/ms_dataset.py index 93d28ca7..80fdb8d8 100644 --- a/modelscope/msdatasets/ms_dataset.py +++ b/modelscope/msdatasets/ms_dataset.py @@ -57,6 +57,9 @@ class MsDataset: def __getitem__(self, key): return self._hf_ds[key] + def __len__(self): + return len(self._hf_ds) + @classmethod def from_hf_dataset(cls, hf_ds: Union[Dataset, DatasetDict], @@ -223,6 +226,7 @@ class MsDataset: retained_columns.append(k) import torch + import math class MsIterableDataset(torch.utils.data.IterableDataset): @@ -230,8 +234,23 @@ class MsDataset: super(MsIterableDataset).__init__() self.dataset = dataset + def __len__(self): + return len(self.dataset) + def __iter__(self): - for item_dict in self.dataset: + worker_info = torch.utils.data.get_worker_info() + if worker_info is None: # single-process data loading + iter_start = 0 + iter_end = len(self.dataset) + else: # in a worker process + per_worker = math.ceil( + len(self.dataset) / float(worker_info.num_workers)) + worker_id = worker_info.id + iter_start = worker_id * per_worker + iter_end = min(iter_start + per_worker, len(self.dataset)) + + for idx in range(iter_start, iter_end): + item_dict = self.dataset[idx] res = { k: np.array(item_dict[k]) for k in columns if k in retained_columns @@ -273,7 +292,8 @@ class MsDataset: 'The function to_torch_dataset requires pytorch to be installed' ) if preprocessors is not None: - return self.to_torch_dataset_with_processors(preprocessors) + return self.to_torch_dataset_with_processors( + preprocessors, columns=columns) else: self._hf_ds.reset_format() self._hf_ds.set_format( diff --git a/modelscope/trainers/trainer.py b/modelscope/trainers/trainer.py index 6a178104..6a08ffa7 100644 --- a/modelscope/trainers/trainer.py +++ b/modelscope/trainers/trainer.py @@ -21,7 +21,6 @@ from modelscope.models.base_torch import TorchModel from modelscope.msdatasets.ms_dataset import MsDataset from modelscope.preprocessors import build_preprocessor from modelscope.preprocessors.base import Preprocessor -from modelscope.task_datasets import TorchTaskDataset, build_task_dataset from modelscope.trainers.hooks.builder import HOOKS from modelscope.trainers.hooks.priority import Priority, get_priority from modelscope.trainers.lrscheduler.builder import build_lr_scheduler @@ -49,7 +48,7 @@ class EpochBasedTrainer(BaseTrainer): or a model id. If model is None, build_model method will be called. data_collator (`Callable`, *optional*): The function to use to form a batch from a list of elements of `train_dataset` or `eval_dataset`. - train_dataset (`torch.utils.data.Dataset` or `torch.utils.data.IterableDataset`, *optional*): + train_dataset (`MsDataset`, *optional*): The dataset to use for training. Note that if it's a `torch.utils.data.IterableDataset` with some randomization and you are training in a @@ -117,10 +116,10 @@ class EpochBasedTrainer(BaseTrainer): # TODO how to fill device option? self.device = int( os.environ['LOCAL_RANK']) if 'LOCAL_RANK' in os.environ else None - self.train_dataset = self.to_task_dataset( - train_dataset, mode=ModeKeys.TRAIN, preprocessor=self.preprocessor) - self.eval_dataset = self.to_task_dataset( - eval_dataset, mode=ModeKeys.EVAL, preprocessor=self.preprocessor) + self.train_dataset = train_dataset.to_torch_dataset( + preprocessors=self.preprocessor) if train_dataset else None + self.eval_dataset = eval_dataset.to_torch_dataset( + preprocessors=self.preprocessor) if eval_dataset else None self.data_collator = data_collator if data_collator is not None else torch_default_data_collator self.metrics = self.get_metrics() self.optimizers = optimizers @@ -179,38 +178,6 @@ class EpochBasedTrainer(BaseTrainer): """int: Maximum training iterations.""" return self._max_epochs * len(self.data_loader) - def to_task_dataset(self, - datasets: Tuple[Dataset, List[Dataset]], - mode: str, - preprocessor: Optional[Preprocessor] = None): - """Build the task specific dataset processor for this trainer. - - Returns: The task dataset processor for the task. If no result for the very model-type and task, - the default TaskDataset will be returned. - """ - try: - if not datasets: - return datasets - if isinstance(datasets, TorchTaskDataset): - return datasets - task_dataset = build_task_dataset( - ConfigDict({ - **self.cfg.model, - 'mode': mode, - 'preprocessor': preprocessor, - 'datasets': datasets, - }), getattr(self.cfg, 'task', None)) - return task_dataset - except Exception: - if isinstance(datasets, (List, Tuple)) or preprocessor is not None: - return TorchTaskDataset( - datasets, - mode=mode, - preprocessor=preprocessor, - **(self.cfg.model if hasattr(self.cfg, 'model') else {})) - else: - return datasets - def build_preprocessor(self) -> Preprocessor: """Build the preprocessor. @@ -448,9 +415,7 @@ class EpochBasedTrainer(BaseTrainer): ) torch_dataset = dataset.to_torch_dataset( preprocessors=self.preprocessor, ) - dataset = self.to_task_dataset( - torch_dataset, mode, preprocessor=self.preprocessor) - return dataset + return torch_dataset def create_optimizer_and_scheduler(self): """ Create optimizer and lr scheduler diff --git a/modelscope/utils/tensor_utils.py b/modelscope/utils/tensor_utils.py index a80ca6cd..93041425 100644 --- a/modelscope/utils/tensor_utils.py +++ b/modelscope/utils/tensor_utils.py @@ -60,6 +60,8 @@ def torch_default_data_collator(features): ) and v is not None and not isinstance(v, str): if isinstance(v, torch.Tensor): batch[k] = torch.stack([f[k] for f in features]) + elif isinstance(v, list): + batch[k] = torch.stack([d for f in features for d in f[k]]) else: batch[k] = torch.tensor([f[k] for f in features]) elif isinstance(first, tuple): diff --git a/modelscope/utils/test_utils.py b/modelscope/utils/test_utils.py index 95e63dba..0ca58c4e 100644 --- a/modelscope/utils/test_utils.py +++ b/modelscope/utils/test_utils.py @@ -4,8 +4,12 @@ import os import unittest +import numpy as np +from datasets import Dataset from datasets.config import TF_AVAILABLE, TORCH_AVAILABLE +from modelscope.msdatasets import MsDataset + TEST_LEVEL = 2 TEST_LEVEL_STR = 'TEST_LEVEL' @@ -33,3 +37,8 @@ def require_torch(test_case): def set_test_level(level: int): global TEST_LEVEL TEST_LEVEL = level + + +def create_dummy_test_dataset(feat, label, num): + return MsDataset.from_hf_dataset( + Dataset.from_dict(dict(feat=[feat] * num, label=[label] * num))) diff --git a/tests/msdatasets/test_ms_dataset.py b/tests/msdatasets/test_ms_dataset.py index 50767fd8..08a05e9c 100644 --- a/tests/msdatasets/test_ms_dataset.py +++ b/tests/msdatasets/test_ms_dataset.py @@ -34,20 +34,15 @@ class MsDatasetTest(unittest.TestCase): @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') def test_ds_basic(self): - ms_ds_full = MsDataset.load('squad', namespace='damotest') - ms_ds_full_hf = hfdata.load_dataset('squad') - ms_ds_train = MsDataset.load( - 'squad', namespace='damotest', split='train') - ms_ds_train_hf = hfdata.load_dataset('squad', split='train') - ms_image_train = MsDataset.from_hf_dataset( - hfdata.load_dataset('beans', split='train')) - self.assertEqual(ms_ds_full['train'][0], ms_ds_full_hf['train'][0]) - self.assertEqual(ms_ds_full['validation'][0], - ms_ds_full_hf['validation'][0]) - self.assertEqual(ms_ds_train[0], ms_ds_train_hf[0]) - print(next(iter(ms_ds_full['train']))) - print(next(iter(ms_ds_train))) - print(next(iter(ms_image_train))) + ms_ds_full = MsDataset.load( + 'xcopa', subset_name='translation-et', namespace='damotest') + ms_ds = MsDataset.load( + 'xcopa', + subset_name='translation-et', + namespace='damotest', + split='test') + print(next(iter(ms_ds_full['test']))) + print(next(iter(ms_ds))) @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') @require_torch @@ -56,10 +51,13 @@ class MsDatasetTest(unittest.TestCase): nlp_model = Model.from_pretrained(model_id) preprocessor = SequenceClassificationPreprocessor( nlp_model.model_dir, - first_sequence='context', + first_sequence='premise', second_sequence=None) ms_ds_train = MsDataset.load( - 'squad', namespace='damotest', split='train') + 'xcopa', + subset_name='translation-et', + namespace='damotest', + split='test') pt_dataset = ms_ds_train.to_torch_dataset(preprocessors=preprocessor) import torch dataloader = torch.utils.data.DataLoader(pt_dataset, batch_size=5) @@ -74,10 +72,13 @@ class MsDatasetTest(unittest.TestCase): nlp_model = Model.from_pretrained(model_id) preprocessor = SequenceClassificationPreprocessor( nlp_model.model_dir, - first_sequence='context', + first_sequence='premise', second_sequence=None) ms_ds_train = MsDataset.load( - 'squad', namespace='damotest', split='train') + 'xcopa', + subset_name='translation-et', + namespace='damotest', + split='test') tf_dataset = ms_ds_train.to_tf_dataset( batch_size=5, shuffle=True, @@ -89,10 +90,9 @@ class MsDatasetTest(unittest.TestCase): @require_torch def test_to_torch_dataset_img(self): ms_image_train = MsDataset.load( - 'beans', namespace='damotest', split='train') + 'fixtures_image_utils', namespace='damotest', split='test') pt_dataset = ms_image_train.to_torch_dataset( - preprocessors=ImgPreprocessor( - image_path='image_file_path', label='labels')) + preprocessors=ImgPreprocessor(image_path='file')) import torch dataloader = torch.utils.data.DataLoader(pt_dataset, batch_size=5) print(next(iter(dataloader))) @@ -103,13 +103,13 @@ class MsDatasetTest(unittest.TestCase): import tensorflow as tf tf.compat.v1.enable_eager_execution() ms_image_train = MsDataset.load( - 'beans', namespace='damotest', split='train') + 'fixtures_image_utils', namespace='damotest', split='test') tf_dataset = ms_image_train.to_tf_dataset( batch_size=5, shuffle=True, - preprocessors=ImgPreprocessor(image_path='image_file_path'), + preprocessors=ImgPreprocessor(image_path='file'), drop_remainder=True, - label_cols='labels') + ) print(next(iter(tf_dataset))) diff --git a/tests/pipelines/test_image_matting.py b/tests/pipelines/test_image_matting.py index 538041a5..6f13dce7 100644 --- a/tests/pipelines/test_image_matting.py +++ b/tests/pipelines/test_image_matting.py @@ -64,10 +64,13 @@ class ImageMattingTest(unittest.TestCase): @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') def test_run_with_modelscope_dataset(self): dataset = MsDataset.load( - 'beans', namespace='damotest', split='train', target='image') + 'fixtures_image_utils', + namespace='damotest', + split='test', + target='file') img_matting = pipeline(Tasks.image_matting, model=self.model_id) result = img_matting(dataset) - for i in range(10): + for i in range(2): cv2.imwrite(f'result_{i}.png', next(result)[OutputKeys.OUTPUT_IMG]) print( f'Output written to dir: {osp.dirname(osp.abspath("result_0.png"))}' diff --git a/tests/pipelines/test_text_classification.py b/tests/pipelines/test_text_classification.py index 1bf9f7ca..cacb09e7 100644 --- a/tests/pipelines/test_text_classification.py +++ b/tests/pipelines/test_text_classification.py @@ -51,11 +51,11 @@ class SequenceClassificationTest(unittest.TestCase): task=Tasks.text_classification, model=self.model_id) result = text_classification( MsDataset.load( - 'glue', - subset_name='sst2', - split='train', - target='sentence', - hub=Hubs.huggingface)) + 'xcopa', + subset_name='translation-et', + namespace='damotest', + split='test', + target='premise')) self.printDataset(result) @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') @@ -63,28 +63,11 @@ class SequenceClassificationTest(unittest.TestCase): text_classification = pipeline(task=Tasks.text_classification) result = text_classification( MsDataset.load( - 'glue', - subset_name='sst2', - split='train', - target='sentence', - hub=Hubs.huggingface)) - self.printDataset(result) - - @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') - def test_run_with_dataset(self): - model = Model.from_pretrained(self.model_id) - preprocessor = SequenceClassificationPreprocessor( - model.model_dir, first_sequence='sentence', second_sequence=None) - text_classification = pipeline( - Tasks.text_classification, model=model, preprocessor=preprocessor) - # loaded from huggingface dataset - dataset = MsDataset.load( - 'glue', - subset_name='sst2', - split='train', - target='sentence', - hub=Hubs.huggingface) - result = text_classification(dataset) + 'xcopa', + subset_name='translation-et', + namespace='damotest', + split='test', + target='premise')) self.printDataset(result) @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') @@ -92,11 +75,11 @@ class SequenceClassificationTest(unittest.TestCase): text_classification = pipeline(task=Tasks.text_classification) # loaded from modelscope dataset dataset = MsDataset.load( - 'squad', + 'xcopa', + subset_name='translation-et', namespace='damotest', - split='train', - target='context', - hub=Hubs.modelscope) + split='test', + target='premise') result = text_classification(dataset) self.printDataset(result) diff --git a/tests/trainers/hooks/logger/test_tensorboard_hook.py b/tests/trainers/hooks/logger/test_tensorboard_hook.py index 1d3c0e76..dc4a5e83 100644 --- a/tests/trainers/hooks/logger/test_tensorboard_hook.py +++ b/tests/trainers/hooks/logger/test_tensorboard_hook.py @@ -4,24 +4,18 @@ import os import shutil import tempfile import unittest -from abc import ABCMeta import json +import numpy as np import torch from torch import nn -from torch.utils.data import Dataset from modelscope.trainers import build_trainer from modelscope.utils.constant import LogKeys, ModelFile +from modelscope.utils.test_utils import create_dummy_test_dataset - -class DummyDataset(Dataset, metaclass=ABCMeta): - - def __len__(self): - return 20 - - def __getitem__(self, idx): - return dict(feat=torch.rand((5, )), label=torch.randint(0, 4, (1, ))) +dummy_dataset = create_dummy_test_dataset( + np.random.random(size=(5, )), np.random.randint(0, 4, (1, )), 20) class DummyModel(nn.Module): @@ -84,7 +78,7 @@ class TensorboardHookTest(unittest.TestCase): cfg_file=config_path, model=DummyModel(), data_collator=None, - train_dataset=DummyDataset(), + train_dataset=dummy_dataset, max_epochs=2) trainer = build_trainer(trainer_name, kwargs) diff --git a/tests/trainers/hooks/test_checkpoint_hook.py b/tests/trainers/hooks/test_checkpoint_hook.py index afb68869..8375a001 100644 --- a/tests/trainers/hooks/test_checkpoint_hook.py +++ b/tests/trainers/hooks/test_checkpoint_hook.py @@ -3,24 +3,18 @@ import os import shutil import tempfile import unittest -from abc import ABCMeta import json +import numpy as np import torch from torch import nn -from torch.utils.data import Dataset from modelscope.trainers import build_trainer from modelscope.utils.constant import LogKeys, ModelFile +from modelscope.utils.test_utils import create_dummy_test_dataset - -class DummyDataset(Dataset, metaclass=ABCMeta): - - def __len__(self): - return 20 - - def __getitem__(self, idx): - return dict(feat=torch.rand((5, )), label=torch.randint(0, 4, (1, ))) +dummy_dataset = create_dummy_test_dataset( + np.random.random(size=(5, )), np.random.randint(0, 4, (1, )), 20) class DummyModel(nn.Module): @@ -94,7 +88,7 @@ class CheckpointHookTest(unittest.TestCase): cfg_file=config_path, model=DummyModel(), data_collator=None, - train_dataset=DummyDataset(), + train_dataset=dummy_dataset, max_epochs=2) trainer = build_trainer(trainer_name, kwargs) diff --git a/tests/trainers/hooks/test_evaluation_hook.py b/tests/trainers/hooks/test_evaluation_hook.py index 4d13b2e0..ad225aed 100644 --- a/tests/trainers/hooks/test_evaluation_hook.py +++ b/tests/trainers/hooks/test_evaluation_hook.py @@ -3,17 +3,17 @@ import os import shutil import tempfile import unittest -from abc import ABCMeta import json +import numpy as np import torch from torch import nn -from torch.utils.data import Dataset from modelscope.metrics.builder import METRICS, MetricKeys from modelscope.trainers import build_trainer from modelscope.utils.constant import LogKeys, ModelFile from modelscope.utils.registry import default_group +from modelscope.utils.test_utils import create_dummy_test_dataset _global_iter = 0 @@ -32,13 +32,8 @@ class DummyMetric: return {MetricKeys.ACCURACY: self._fake_acc_by_epoch[_global_iter]} -class DummyDataset(Dataset, metaclass=ABCMeta): - - def __len__(self): - return 20 - - def __getitem__(self, idx): - return dict(feat=torch.rand((5, )), label=torch.randint(0, 4, (1, ))) +dummy_dataset = create_dummy_test_dataset( + np.random.random(size=(5, )), np.random.randint(0, 4, (1, )), 20) class DummyModel(nn.Module): @@ -115,8 +110,8 @@ class EvaluationHookTest(unittest.TestCase): cfg_file=config_path, model=DummyModel(), data_collator=None, - train_dataset=DummyDataset(), - eval_dataset=DummyDataset(), + train_dataset=dummy_dataset, + eval_dataset=dummy_dataset, max_epochs=3) trainer = build_trainer(trainer_name, kwargs) @@ -177,8 +172,8 @@ class EvaluationHookTest(unittest.TestCase): cfg_file=config_path, model=DummyModel(), data_collator=None, - train_dataset=DummyDataset(), - eval_dataset=DummyDataset(), + train_dataset=dummy_dataset, + eval_dataset=dummy_dataset, max_epochs=3) trainer = build_trainer(trainer_name, kwargs) diff --git a/tests/trainers/hooks/test_lr_scheduler_hook.py b/tests/trainers/hooks/test_lr_scheduler_hook.py index 575edfd7..ddf4b3fd 100644 --- a/tests/trainers/hooks/test_lr_scheduler_hook.py +++ b/tests/trainers/hooks/test_lr_scheduler_hook.py @@ -3,28 +3,20 @@ import os import shutil import tempfile import unittest -from abc import ABCMeta import json +import numpy as np import torch from torch import nn from torch.optim import SGD from torch.optim.lr_scheduler import MultiStepLR -from torch.utils.data import Dataset from modelscope.trainers import build_trainer from modelscope.utils.constant import LogKeys, ModelFile, TrainerStages +from modelscope.utils.test_utils import create_dummy_test_dataset - -class DummyDataset(Dataset, metaclass=ABCMeta): - """Base Dataset - """ - - def __len__(self): - return 10 - - def __getitem__(self, idx): - return dict(feat=torch.rand((5, )), label=torch.randint(0, 4, (1, ))) +dummy_dataset = create_dummy_test_dataset( + np.random.random(size=(5, )), np.random.randint(0, 4, (1, )), 10) class DummyModel(nn.Module): @@ -77,7 +69,7 @@ class LrSchedulerHookTest(unittest.TestCase): kwargs = dict( cfg_file=config_path, model=model, - train_dataset=DummyDataset(), + train_dataset=dummy_dataset, optimizers=(optimizer, lr_scheduler), max_epochs=5) @@ -148,7 +140,7 @@ class LrSchedulerHookTest(unittest.TestCase): kwargs = dict( cfg_file=config_path, model=model, - train_dataset=DummyDataset(), + train_dataset=dummy_dataset, # optimizers=(optimmizer, lr_scheduler), max_epochs=7) diff --git a/tests/trainers/hooks/test_optimizer_hook.py b/tests/trainers/hooks/test_optimizer_hook.py index 98dbfef5..a1ceb503 100644 --- a/tests/trainers/hooks/test_optimizer_hook.py +++ b/tests/trainers/hooks/test_optimizer_hook.py @@ -3,28 +3,20 @@ import os import shutil import tempfile import unittest -from abc import ABCMeta import json +import numpy as np import torch from torch import nn from torch.optim import SGD from torch.optim.lr_scheduler import MultiStepLR -from torch.utils.data import Dataset from modelscope.trainers import build_trainer from modelscope.utils.constant import ModelFile, TrainerStages +from modelscope.utils.test_utils import create_dummy_test_dataset - -class DummyDataset(Dataset, metaclass=ABCMeta): - """Base Dataset - """ - - def __len__(self): - return 10 - - def __getitem__(self, idx): - return dict(feat=torch.rand((2, 2)), label=torch.randint(0, 2, (1, ))) +dummy_dataset = create_dummy_test_dataset( + np.random.random(size=(2, 2)), np.random.randint(0, 2, (1, )), 10) class DummyModel(nn.Module): @@ -76,7 +68,7 @@ class OptimizerHookTest(unittest.TestCase): kwargs = dict( cfg_file=config_path, model=model, - train_dataset=DummyDataset(), + train_dataset=dummy_dataset, optimizers=(optimizer, lr_scheduler), max_epochs=2) @@ -140,7 +132,7 @@ class TorchAMPOptimizerHookTest(unittest.TestCase): kwargs = dict( cfg_file=config_path, model=model, - train_dataset=DummyDataset(), + train_dataset=dummy_dataset, optimizers=(optimizer, lr_scheduler), max_epochs=2, use_fp16=True) diff --git a/tests/trainers/hooks/test_timer_hook.py b/tests/trainers/hooks/test_timer_hook.py index 5fafbfbb..d92b5f89 100644 --- a/tests/trainers/hooks/test_timer_hook.py +++ b/tests/trainers/hooks/test_timer_hook.py @@ -3,28 +3,20 @@ import os import shutil import tempfile import unittest -from abc import ABCMeta import json +import numpy as np import torch from torch import nn from torch.optim import SGD from torch.optim.lr_scheduler import MultiStepLR -from torch.utils.data import Dataset from modelscope.trainers import build_trainer from modelscope.utils.constant import LogKeys, ModelFile, TrainerStages +from modelscope.utils.test_utils import create_dummy_test_dataset - -class DummyDataset(Dataset, metaclass=ABCMeta): - """Base Dataset - """ - - def __len__(self): - return 10 - - def __getitem__(self, idx): - return dict(feat=torch.rand((5, )), label=torch.randint(0, 4, (1, ))) +dummy_dataset = create_dummy_test_dataset( + np.random.random(size=(5, )), np.random.randint(0, 4, (1, )), 10) class DummyModel(nn.Module): @@ -80,7 +72,7 @@ class IterTimerHookTest(unittest.TestCase): kwargs = dict( cfg_file=config_path, model=model, - train_dataset=DummyDataset(), + train_dataset=dummy_dataset, optimizers=(optimizer, lr_scheduler), max_epochs=5) diff --git a/tests/trainers/test_trainer.py b/tests/trainers/test_trainer.py index a949c6ec..97cb94b8 100644 --- a/tests/trainers/test_trainer.py +++ b/tests/trainers/test_trainer.py @@ -6,27 +6,31 @@ import unittest from abc import ABCMeta import json +import numpy as np import torch +from datasets import Dataset from torch import nn from torch.optim import SGD from torch.optim.lr_scheduler import StepLR -from torch.utils.data import Dataset from modelscope.metrics.builder import MetricKeys +from modelscope.msdatasets import MsDataset from modelscope.trainers import build_trainer from modelscope.utils.constant import LogKeys, ModeKeys, ModelFile -from modelscope.utils.test_utils import test_level +from modelscope.utils.test_utils import create_dummy_test_dataset, test_level -class DummyDataset(Dataset, metaclass=ABCMeta): - """Base Dataset - """ +class DummyMetric: - def __len__(self): - return 20 + def __call__(self, ground_truth, predict_results): + return {'accuracy': 0.5} - def __getitem__(self, idx): - return dict(feat=torch.rand((5, )), label=torch.randint(0, 4, (1, ))) + +dummy_dataset_small = create_dummy_test_dataset( + np.random.random(size=(5, )), np.random.randint(0, 4, (1, )), 20) + +dummy_dataset_big = create_dummy_test_dataset( + np.random.random(size=(5, )), np.random.randint(0, 4, (1, )), 40) class DummyModel(nn.Module): @@ -116,8 +120,8 @@ class TrainerTest(unittest.TestCase): cfg_file=config_path, model=DummyModel(), data_collator=None, - train_dataset=DummyDataset(), - eval_dataset=DummyDataset(), + train_dataset=dummy_dataset_small, + eval_dataset=dummy_dataset_small, max_epochs=3) trainer = build_trainer(trainer_name, kwargs) @@ -174,8 +178,8 @@ class TrainerTest(unittest.TestCase): cfg_file=config_path, model=model, data_collator=None, - train_dataset=DummyDataset(), - eval_dataset=DummyDataset(), + train_dataset=dummy_dataset_small, + eval_dataset=dummy_dataset_small, optimizers=(optimmizer, lr_scheduler), max_epochs=3) @@ -212,13 +216,6 @@ class TrainerTest(unittest.TestCase): } } - class _DummyDataset(DummyDataset): - """Base Dataset - """ - - def __len__(self): - return 40 - config_path = os.path.join(self.tmp_dir, ModelFile.CONFIGURATION) with open(config_path, 'w') as f: json.dump(json_cfg, f) @@ -231,8 +228,8 @@ class TrainerTest(unittest.TestCase): cfg_file=config_path, model=model, data_collator=None, - train_dataset=_DummyDataset(), - eval_dataset=DummyDataset(), + train_dataset=dummy_dataset_big, + eval_dataset=dummy_dataset_small, optimizers=(optimmizer, lr_scheduler), max_epochs=3) diff --git a/tests/trainers/test_trainer_with_nlp.py b/tests/trainers/test_trainer_with_nlp.py index a20bf97f..6deaaa5f 100644 --- a/tests/trainers/test_trainer_with_nlp.py +++ b/tests/trainers/test_trainer_with_nlp.py @@ -34,13 +34,7 @@ class TestTrainerWithNlp(unittest.TestCase): 'label': [0, 1, 1] } dataset = Dataset.from_dict(dataset_dict) - - class MsDatasetDummy(MsDataset): - - def __len__(self): - return len(self._hf_ds) - - self.dataset = MsDatasetDummy(dataset) + self.dataset = MsDataset.from_hf_dataset(dataset) def tearDown(self): shutil.rmtree(self.tmp_dir)