Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9791200
fix trainer about iters_per_epoch (target branch: master)

This review fixes EpochBasedTrainer and the inference utilities so that training and evaluation can be capped at a fixed number of iterations per epoch (iters_per_epoch), which is required when the dataset is an IterableDataset without __len__.
@@ -254,7 +254,7 @@ class EpochBasedTrainer(BaseTrainer):
         def _get_data_len(data_loader):
             try:
-                return len(self.data_loader)
+                return len(data_loader)
             except Exception as e:
                 self.logger.error(e)
                 raise ValueError(
@@ -266,12 +266,12 @@ class EpochBasedTrainer(BaseTrainer):
             if self._train_iters_per_epoch is not None:
                 return self._train_iters_per_epoch
             else:
-                return _get_data_len(self.data_loader)
+                return _get_data_len(self.train_dataloader)
         elif self.mode == ModeKeys.EVAL:
             if self._eval_iters_per_epoch is not None:
                 return self._eval_iters_per_epoch
             else:
-                return _get_data_len(self.data_loader)
+                return _get_data_len(self.eval_dataloader)

     def to_task_dataset(self,
                         datasets: Union[Dataset, List[Dataset]],
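
Note: the hunk above makes the iters_per_epoch property read the length from the mode-specific dataloader instead of self.data_loader, and lets an explicit override win. A minimal sketch of that resolution order, written as a hypothetical standalone helper (not the trainer's actual method):

# Sketch only: explicit *_iters_per_epoch override wins, otherwise fall back to len().
def resolve_iters_per_epoch(iters_per_epoch_override, data_loader):
    if iters_per_epoch_override is not None:
        return iters_per_epoch_override
    try:
        return len(data_loader)  # works for map-style datasets
    except TypeError:
        # IterableDataset has no __len__, so an explicit override is required
        raise ValueError('Provide train_iters_per_epoch / val_iters_per_epoch '
                         'when the dataset has no __len__')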
@@ -761,6 +761,9 @@ class EpochBasedTrainer(BaseTrainer):
                 del self.data_batch
             self._iter += 1
+            if i + 1 >= self.iters_per_epoch:
+                break

         self.invoke_hook(TrainerStages.after_train_epoch)
         self._epoch += 1
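
Note: with an IterableDataset-backed loader the epoch never ends at a fixed boundary on its own, so the added break is what closes the epoch. A minimal sketch of the capped loop shape, with hypothetical names rather than the trainer's internals:

# Sketch only: cap each epoch at iters_per_epoch batches.
for epoch in range(max_epochs):
    for i, batch in enumerate(train_dataloader):
        train_step(batch)                 # placeholder for the real training step
        if i + 1 >= iters_per_epoch:      # stop the epoch explicitly; an iterable
            break                         # dataset would otherwise keep yielding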
@@ -779,14 +782,18 @@ class EpochBasedTrainer(BaseTrainer):
                 device=self.device,
                 tmpdir=None,
                 gpu_collect=False,
-                metric_classes=metric_classes)
+                metric_classes=metric_classes,
+                data_loader_iters_per_gpu=self.iters_per_epoch)
         else:
             from modelscope.trainers.utils.inference import single_gpu_test
             metric_values = single_gpu_test(
                 self.model,
                 data_loader,
                 device=self.device,
-                metric_classes=metric_classes)
+                metric_classes=metric_classes,
+                data_loader_iters=self.iters_per_epoch)

+        self._inner_iter = self.iters_per_epoch - 1  # start from index 0
         return metric_values
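
Note: from user code, the effect of these hunks is that evaluation also stops after iters_per_epoch batches when the overrides are given. A hedged sketch of building such a trainer, mirroring the test added later in this review (config path, model and datasets are placeholders):

# Sketch only: trainer kwargs that cap both loops per epoch.
trainer = build_trainer(
    Trainers.default,
    dict(
        cfg_file=config_path,
        model=model,
        train_dataset=my_iterable_train_set,   # placeholder, no __len__
        eval_dataset=my_iterable_eval_set,     # placeholder, no __len__
        train_iters_per_epoch=20,              # caps the train loop per epoch
        val_iters_per_epoch=10,                # caps evaluate() / single_gpu_test
        max_epochs=3))
trainer.train()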
@@ -1,5 +1,6 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 # Copyright (c) Alibaba, Inc. and its affiliates.
+import logging
 import os
 import pickle
 import shutil
@@ -16,22 +17,42 @@ from modelscope.utils.torch_utils import (broadcast, get_dist_info, is_master,
                                           make_tmp_dir)


-def single_gpu_test(model, data_loader, device, metric_classes=None):
+def single_gpu_test(model,
+                    data_loader,
+                    device,
+                    metric_classes=None,
+                    data_loader_iters=None):
     """Test model with a single gpu.

     Args:
         model (nn.Module): Model to be tested.
         data_loader (nn.Dataloader): Pytorch data loader.
-        device: (str | torch.device): The target device for the data.
-        metric_classes(List): List of Metric class that uses to collect metrics
+        device (str | torch.device): The target device for the data.
+        metric_classes (List): List of Metric class that uses to collect metrics
+        data_loader_iters (int): Used when dataset has no attribute __len__ or only load part of dataset.

     Returns:
         list: The prediction results.
     """
     model.eval()
     dataset = data_loader.dataset
-    with tqdm(total=len(dataset), desc='test samples') as pbar:
-        for data in data_loader:
+    progress_with_iters = False
+    if data_loader_iters is None:
+        try:
+            data_len = len(dataset)
+        except Exception as e:
+            logging.error(e)
+            raise ValueError(
+                'Please implement ``__len__`` method for your dataset, or provide ``data_loader_iters``'
+            )
+        desc = 'Total test samples'
+    else:
+        progress_with_iters = True
+        data_len = data_loader_iters
+        desc = 'Test iterations'
+    with tqdm(total=data_len, desc=desc) as pbar:
+        for i, data in enumerate(data_loader):
             data = to_device(data, device)
             with torch.no_grad():
                 if isinstance(data, Mapping) and not func_receive_dict_inputs(
@@ -43,13 +64,19 @@ def single_gpu_test(model, data_loader, device, metric_classes=None):
             for metric_cls in metric_classes:
                 metric_cls.add(result, data)

-            if isinstance(data, dict):
-                batch_size = len(next(iter(data.values())))
+            if progress_with_iters:
+                batch_size = 1  # iteration count
             else:
-                batch_size = len(data)
+                if isinstance(data, dict):
+                    batch_size = len(next(iter(data.values())))
+                else:
+                    batch_size = len(data)
             for _ in range(batch_size):
                 pbar.update()

+            if progress_with_iters and (i + 1) >= data_len:
+                break
+
     metric_values = {}
     for metric_cls in metric_classes:
         metric_values.update(metric_cls.evaluate())
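
Note: a hedged sketch of calling the updated single_gpu_test directly on a loader whose dataset has no __len__ (the model, loader and metric object are placeholders for illustration):

# Sketch only: evaluate exactly 10 batches, counting iterations in the progress bar.
metric_values = single_gpu_test(
    model,
    eval_data_loader,                  # backed by an IterableDataset, no __len__
    device='cpu',
    metric_classes=[accuracy_metric],  # any Metric instances with add()/evaluate()
    data_loader_iters=10)              # stop after 10 batches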
@@ -62,7 +89,8 @@ def multi_gpu_test(model,
                    device,
                    tmpdir=None,
                    gpu_collect=False,
-                   metric_classes=None):
+                   metric_classes=None,
+                   data_loader_iters_per_gpu=None):
     """Test model with multiple gpus.

     This method tests model with multiple gpus and collects the results
@@ -79,7 +107,7 @@ def multi_gpu_test(model,
             different gpus under cpu mode.
         gpu_collect (bool): Option to use either gpu or cpu to collect results.
         metric_classes(List): List of Metric class that uses to collect metrics
+        data_loader_iters_per_gpu (int): Used when dataset has no attribute __len__ or only load part of dataset.

     Returns:
         list: The prediction results.
     """
@@ -87,14 +115,30 @@ def multi_gpu_test(model,
     results = []
     data_list = []
     dataset = data_loader.dataset
+    rank, world_size = get_dist_info()
+    time.sleep(2)  # This line can prevent deadlock problem in some cases.
+    progress_with_iters = False
+    if data_loader_iters_per_gpu is None:
+        try:
+            data_len = len(dataset)
+            total_samples = data_len
+        except Exception as e:
+            logging.error(e)
+            raise ValueError(
+                'Please implement ``__len__`` method for your dataset, or provide ``data_loader_iters_per_gpu``'
+            )
+        desc = 'Total test samples with multi gpus'
+    else:
+        total_samples = 0
+        progress_with_iters = True
+        data_len = data_loader_iters_per_gpu * world_size
+        desc = 'Total test iterations with multi gpus'
-    rank, world_size = get_dist_info()
-    time.sleep(2)  # This line can prevent deadlock problem in some cases.
     count = 0
-    with tqdm(total=len(dataset), desc='test samples with multi gpus') as pbar:
-        for _, data in enumerate(data_loader):
+    with tqdm(total=data_len, desc=desc) as pbar:
+        for i, data in enumerate(data_loader):
             data = to_device(data, device)
             data_list.append(data)
             with torch.no_grad():
@@ -110,24 +154,32 @@ def multi_gpu_test(model,
                 batch_size = len(next(iter(data.values())))
             else:
                 batch_size = len(data)
+            if progress_with_iters:
+                total_samples += batch_size * world_size
+                batch_size = 1  # iteration count
             batch_size_all = batch_size * world_size
             count += batch_size_all
-            if count > len(dataset):
-                batch_size_all = len(dataset) - (count - batch_size_all)
+            if count > data_len:
+                batch_size_all = data_len - (count - batch_size_all)
             for _ in range(batch_size_all):
                 pbar.update()

+            if progress_with_iters and (i + 1) >= data_len:
+                break
+
     # TODO: allgather data list may cost a lot of memory and needs to be redesigned
     # collect results and data from all ranks
     if gpu_collect:
-        results = collect_results_gpu(results, len(dataset))
-        data_list = collect_results_gpu(data_list, len(dataset))
+        results = collect_results_gpu(results, total_samples)
+        data_list = collect_results_gpu(data_list, total_samples)
     else:
         if tmpdir is None:
             tmpdir = make_tmp_dir()
-        results = collect_results_cpu(results, len(dataset),
+        results = collect_results_cpu(results, total_samples,
                                       os.path.join(tmpdir, 'predict'))
-        data_list = collect_results_cpu(data_list, len(dataset),
+        data_list = collect_results_cpu(data_list, total_samples,
                                         os.path.join(tmpdir, 'groundtruth'))

     if is_master():
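
Note: when the iteration cap is active, the progress bar counts iterations but result collection still needs a sample count, which is why total_samples is accumulated and passed to collect_results_gpu/cpu instead of len(dataset). A small worked example of the per-iteration bookkeeping above, with illustrative numbers (2 GPUs, per-rank batch size of 4):

# Illustrative numbers only, showing what one loop iteration contributes.
world_size = 2
batch_size = 4                               # samples in this rank's batch
total_samples = 0

total_samples += batch_size * world_size     # 8 samples gathered across ranks this step
batch_size = 1                               # progress is now counted in iterations
batch_size_all = batch_size * world_size     # 2 progress-bar ticks, one per rank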
@@ -84,6 +84,7 @@ class IterTimerHookTest(unittest.TestCase):
         trainer.register_optimizers_hook()
         trainer.register_hook_from_cfg(trainer.cfg.train.hooks)
         trainer.data_loader = train_dataloader
+        trainer.train_dataloader = train_dataloader
         trainer.invoke_hook(TrainerStages.before_run)
         for i in range(trainer._epoch, trainer._max_epochs):
             trainer.invoke_hook(TrainerStages.before_train_epoch)
@@ -10,6 +10,7 @@ import torch
 from torch import nn
 from torch.optim import SGD
 from torch.optim.lr_scheduler import StepLR
+from torch.utils.data import IterableDataset

 from modelscope.metainfo import Metrics, Trainers
 from modelscope.metrics.builder import MetricKeys
@@ -17,6 +18,16 @@ from modelscope.trainers import build_trainer
 from modelscope.utils.constant import LogKeys, ModeKeys, ModelFile
 from modelscope.utils.test_utils import create_dummy_test_dataset, test_level


+class DummyIterableDataset(IterableDataset):
+
+    def __iter__(self):
+        feat = np.random.random(size=(5, )).astype(np.float32)
+        labels = np.random.randint(0, 4, (1, ))
+        iterations = [{'feat': feat, 'labels': labels}] * 500
+        return iter(iterations)
+
+
 dummy_dataset_small = create_dummy_test_dataset(
     np.random.random(size=(5, )), np.random.randint(0, 4, (1, )), 20)
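
Note: DummyIterableDataset deliberately defines no __len__, which is what forces the trainer down the iters_per_epoch path. A quick illustrative check of that property with a plain PyTorch DataLoader (assumption: standard DataLoader behavior for iterable datasets):

# Illustrative only: len() on a loader over an IterableDataset raises TypeError.
from torch.utils.data import DataLoader

loader = DataLoader(DummyIterableDataset(), batch_size=2)
try:
    len(loader)
except TypeError as e:
    print('no __len__, iters_per_epoch must be given:', e)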
@@ -303,6 +314,124 @@ class TrainerTest(unittest.TestCase):
         for i in [2, 5, 8]:
             self.assertIn(MetricKeys.ACCURACY, lines[i])

+    @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
+    def test_train_with_iters_per_epoch(self):
+        json_cfg = {
+            'train': {
+                'work_dir': self.tmp_dir,
+                'dataloader': {
+                    'batch_size_per_gpu': 2,
+                    'workers_per_gpu': 1
+                },
+                'hooks': [{
+                    'type': 'EvaluationHook',
+                    'interval': 1
+                }]
+            },
+            'evaluation': {
+                'dataloader': {
+                    'batch_size_per_gpu': 2,
+                    'workers_per_gpu': 1,
+                    'shuffle': False
+                },
+                'metrics': [Metrics.seq_cls_metric]
+            }
+        }
+
+        config_path = os.path.join(self.tmp_dir, ModelFile.CONFIGURATION)
+        with open(config_path, 'w') as f:
+            json.dump(json_cfg, f)
+
+        model = DummyModel()
+        optimmizer = SGD(model.parameters(), lr=0.01)
+        lr_scheduler = StepLR(optimmizer, 2)
+        trainer_name = Trainers.default
+        kwargs = dict(
+            cfg_file=config_path,
+            model=model,
+            data_collator=None,
+            optimizers=(optimmizer, lr_scheduler),
+            train_dataset=DummyIterableDataset(),
+            eval_dataset=DummyIterableDataset(),
+            train_iters_per_epoch=20,
+            val_iters_per_epoch=10,
+            max_epochs=3,
+            device='cpu')
+
+        trainer = build_trainer(trainer_name, kwargs)
+        trainer.train()
+        results_files = os.listdir(self.tmp_dir)
+        json_file = os.path.join(self.tmp_dir, f'{trainer.timestamp}.log.json')
+        with open(json_file, 'r') as f:
+            lines = [i.strip() for i in f.readlines()]
+        self.assertDictContainsSubset(
+            {
+                LogKeys.MODE: ModeKeys.TRAIN,
+                LogKeys.EPOCH: 1,
+                LogKeys.ITER: 10,
+                LogKeys.LR: 0.01
+            }, json.loads(lines[0]))
+        self.assertDictContainsSubset(
+            {
+                LogKeys.MODE: ModeKeys.TRAIN,
+                LogKeys.EPOCH: 1,
+                LogKeys.ITER: 20,
+                LogKeys.LR: 0.01
+            }, json.loads(lines[1]))
+        self.assertDictContainsSubset(
+            {
+                LogKeys.MODE: ModeKeys.EVAL,
+                LogKeys.EPOCH: 1,
+                LogKeys.ITER: 10
+            }, json.loads(lines[2]))
+        self.assertDictContainsSubset(
+            {
+                LogKeys.MODE: ModeKeys.TRAIN,
+                LogKeys.EPOCH: 2,
+                LogKeys.ITER: 10,
+                LogKeys.LR: 0.01
+            }, json.loads(lines[3]))
+        self.assertDictContainsSubset(
+            {
+                LogKeys.MODE: ModeKeys.TRAIN,
+                LogKeys.EPOCH: 2,
+                LogKeys.ITER: 20,
+                LogKeys.LR: 0.01
+            }, json.loads(lines[4]))
+        self.assertDictContainsSubset(
+            {
+                LogKeys.MODE: ModeKeys.EVAL,
+                LogKeys.EPOCH: 2,
+                LogKeys.ITER: 10
+            }, json.loads(lines[5]))
+        self.assertDictContainsSubset(
+            {
+                LogKeys.MODE: ModeKeys.TRAIN,
+                LogKeys.EPOCH: 3,
+                LogKeys.ITER: 10,
+                LogKeys.LR: 0.001
+            }, json.loads(lines[6]))
+        self.assertDictContainsSubset(
+            {
+                LogKeys.MODE: ModeKeys.TRAIN,
+                LogKeys.EPOCH: 3,
+                LogKeys.ITER: 20,
+                LogKeys.LR: 0.001
+            }, json.loads(lines[7]))
+        self.assertDictContainsSubset(
+            {
+                LogKeys.MODE: ModeKeys.EVAL,
+                LogKeys.EPOCH: 3,
+                LogKeys.ITER: 10
+            }, json.loads(lines[8]))
+        self.assertIn(f'{LogKeys.EPOCH}_1.pth', results_files)
+        self.assertIn(f'{LogKeys.EPOCH}_2.pth', results_files)
+        self.assertIn(f'{LogKeys.EPOCH}_3.pth', results_files)
+        for i in [0, 1, 3, 4, 6, 7]:
+            self.assertIn(LogKeys.DATA_LOAD_TIME, lines[i])
+            self.assertIn(LogKeys.ITER_TIME, lines[i])
+        for i in [2, 5, 8]:
+            self.assertIn(MetricKeys.ACCURACY, lines[i])


 class DummyTrainerTest(unittest.TestCase):
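
Note: for reference, the line indices asserted in the test above follow from its configuration; the counts below are derived from the assertions themselves (the train entries at iteration 10 and 20 suggest a text-logger interval of 10 iterations, which is an assumption, not something stated in this diff):

# Rough accounting of the asserted log lines (illustrative only).
epochs = 3
train_logs_per_epoch = 20 // 10        # entries at iter 10 and iter 20
eval_logs_per_epoch = 1                # EvaluationHook with interval 1
total_lines = epochs * (train_logs_per_epoch + eval_logs_per_epoch)  # 9, i.e. lines[0..8]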
@@ -11,6 +11,7 @@ import torch
 from torch import nn
 from torch.optim import SGD
 from torch.optim.lr_scheduler import StepLR
+from torch.utils.data import IterableDataset

 from modelscope.metainfo import Metrics, Trainers
 from modelscope.metrics.builder import MetricKeys
@@ -19,6 +20,16 @@ from modelscope.utils.constant import LogKeys, ModeKeys, ModelFile
 from modelscope.utils.test_utils import (DistributedTestCase,
                                          create_dummy_test_dataset, test_level)


+class DummyIterableDataset(IterableDataset):
+
+    def __iter__(self):
+        feat = np.random.random(size=(5, )).astype(np.float32)
+        labels = np.random.randint(0, 4, (1, ))
+        iterations = [{'feat': feat, 'labels': labels}] * 500
+        return iter(iterations)
+
+
 dummy_dataset_small = create_dummy_test_dataset(
     np.random.random(size=(5, )), np.random.randint(0, 4, (1, )), 20)

@@ -41,7 +52,7 @@ class DummyModel(nn.Module):
         return dict(logits=x, loss=loss)


-def train_func(work_dir, dist=False):
+def train_func(work_dir, dist=False, iterable_dataset=False, **kwargs):
     json_cfg = {
         'train': {
             'work_dir': work_dir,
@@ -72,18 +83,25 @@ def train_func(work_dir, dist=False):
     optimmizer = SGD(model.parameters(), lr=0.01)
     lr_scheduler = StepLR(optimmizer, 2)
     trainer_name = Trainers.default
-    kwargs = dict(
+    if iterable_dataset:
+        train_dataset = DummyIterableDataset()
+        eval_dataset = DummyIterableDataset()
+    else:
+        train_dataset = dummy_dataset_big
+        eval_dataset = dummy_dataset_small
+    _kwargs = dict(
         cfg_file=config_path,
         model=model,
         data_collator=None,
-        train_dataset=dummy_dataset_big,
-        eval_dataset=dummy_dataset_small,
+        train_dataset=train_dataset,
+        eval_dataset=eval_dataset,
         optimizers=(optimmizer, lr_scheduler),
         max_epochs=3,
         device='gpu',
-        launcher='pytorch' if dist else None)
+        launcher='pytorch' if dist else None,
+        **kwargs)

-    trainer = build_trainer(trainer_name, kwargs)
+    trainer = build_trainer(trainer_name, _kwargs)
     trainer.train()

@@ -253,6 +271,28 @@ class TrainerTestMultiGpus(DistributedTestCase):
         for i in [1, 3, 5]:
             self.assertIn(MetricKeys.ACCURACY, lines[i])

+    # TODO: support iters_per_epoch for dist mode
+    @unittest.skipIf(True, 'need to adapt to DistributedSampler')
+    def test_multi_gpus_with_iters_per_epoch(self):
+        self.start(
+            train_func,
+            num_gpus=2,
+            work_dir=self.tmp_dir,
+            dist=True,
+            iterable_dataset=True,
+            train_iters_per_epoch=20,
+            val_iters_per_epoch=10,
+        )
+        results_files = os.listdir(self.tmp_dir)
+        json_files = glob.glob(os.path.join(self.tmp_dir, '*.log.json'))
+        self.assertEqual(len(json_files), 1)
+        with open(json_files[0], 'r') as f:
+            lines = [i.strip() for i in f.readlines()]
+        print(results_files, lines)


 if __name__ == '__main__':
     unittest.main()