From 4819fd569d6142e4accbf68d555012da16f59a33 Mon Sep 17 00:00:00 2001
From: "jiangnana.jnn"
Date: Tue, 16 Aug 2022 12:05:09 +0800
Subject: [PATCH] [to #43850241] adapt to torch IterableDataset

Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9761129

* adapt to torch IterableDataset
---
 modelscope/trainers/hooks/hook.py              |  2 +-
 .../trainers/hooks/logger/text_logger_hook.py  |  2 +-
 modelscope/trainers/trainer.py                 | 37 ++++++++++++++++++
 3 files changed, 38 insertions(+), 3 deletions(-)

diff --git a/modelscope/trainers/hooks/hook.py b/modelscope/trainers/hooks/hook.py
index 3a58557b..75cc226c 100644
--- a/modelscope/trainers/hooks/hook.py
+++ b/modelscope/trainers/hooks/hook.py
@@ -192,7 +192,7 @@ class Hook:
         Whether to reach the end of every epoch
         Returns: bool
         """
-        return trainer.inner_iter + 1 == len(trainer.data_loader)
+        return trainer.inner_iter + 1 == trainer.iters_per_epoch
 
     def is_last_epoch(self, trainer):
         """
diff --git a/modelscope/trainers/hooks/logger/text_logger_hook.py b/modelscope/trainers/hooks/logger/text_logger_hook.py
index a204284c..6629a0c9 100644
--- a/modelscope/trainers/hooks/logger/text_logger_hook.py
+++ b/modelscope/trainers/hooks/logger/text_logger_hook.py
@@ -93,7 +93,7 @@ class TextLoggerHook(LoggerHook):
             lr_str = f'{lr_key}: {log_dict[lr_key]:.3e}'
 
         if self.by_epoch:
-            log_str = f'{epoch_key} [{log_dict[epoch_key]}][{log_dict[iter_key]}/{len(trainer.data_loader)}]\t'
+            log_str = f'{epoch_key} [{log_dict[epoch_key]}][{log_dict[iter_key]}/{trainer.iters_per_epoch}]\t'
         else:
             log_str = f'{iter_key} [{log_dict[iter_key]}/{trainer.max_iters}]\t'
         log_str += f'{lr_str}, '
diff --git a/modelscope/trainers/trainer.py b/modelscope/trainers/trainer.py
index b275bba4..e68fc383 100644
--- a/modelscope/trainers/trainer.py
+++ b/modelscope/trainers/trainer.py
@@ -180,6 +180,16 @@ class EpochBasedTrainer(BaseTrainer):
         else:
             self._max_epochs = kwargs['max_epochs']
 
+        self._train_iters_per_epoch = kwargs.get('train_iters_per_epoch', None)
+        self._eval_iters_per_epoch = kwargs.get('val_iters_per_epoch', None)
+        if self._train_iters_per_epoch is None and hasattr(
+                self.cfg.train, 'train_iters_per_epoch'):
+            self._train_iters_per_epoch = self.cfg.train.train_iters_per_epoch
+        if self._eval_iters_per_epoch is None and hasattr(
+                self.cfg, 'evaluation') and hasattr(self.cfg.evaluation,
+                                                    'val_iters_per_epoch'):
+            self._eval_iters_per_epoch = self.cfg.evaluation.val_iters_per_epoch
+
         self.use_fp16 = kwargs.get('use_fp16', False)
 
         # TODO @wenmeng.zwm add seed init fn
@@ -236,7 +246,32 @@ class EpochBasedTrainer(BaseTrainer):
     @property
     def max_iters(self):
         """int: Maximum training iterations."""
-        return self._max_epochs * len(self.data_loader)
+        return self._max_epochs * self.iters_per_epoch
+
+    @property
+    def iters_per_epoch(self):
+        """int: Total iterations of one epoch."""
+
+        def _get_data_len(data_loader):
+            try:
+                return len(data_loader)
+            except Exception as e:
+                self.logger.error(e)
+                raise ValueError(
+                    'Please implement ``__len__`` method for your dataset, '
+                    'or add `train_iters_per_epoch` and `val_iters_per_epoch` '
+                    'to your configuration file or kwargs')
+
+        if self.mode == ModeKeys.TRAIN:
+            if self._train_iters_per_epoch is not None:
+                return self._train_iters_per_epoch
+            else:
+                return _get_data_len(self.data_loader)
+        elif self.mode == ModeKeys.EVAL:
+            if self._eval_iters_per_epoch is not None:
+                return self._eval_iters_per_epoch
+            else:
+                return _get_data_len(self.data_loader)
 
     def to_task_dataset(self, datasets: Union[Dataset, List[Dataset]],
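
A minimal usage sketch (not part of the diff above): a torch IterableDataset does not implement __len__, so len(self.data_loader) typically raises a TypeError and the trainer can no longer infer how many iterations make up an epoch; the new `train_iters_per_epoch` / `val_iters_per_epoch` options supply that number explicitly. The `build_trainer` call shown in the trailing comment is an assumption made only to illustrate the new kwargs and may not match the actual ModelScope entry point.

    import torch
    from torch.utils.data import IterableDataset


    class StreamingDataset(IterableDataset):
        """A streaming-style dataset without __len__()."""

        def __iter__(self):
            # Stand-in for an unbounded or lazily read data source.
            for i in range(1000):
                yield {'input': torch.randn(4), 'label': i % 2}


    ds = StreamingDataset()
    # len(ds) raises TypeError, so the epoch length must be given explicitly,
    # e.g. (hypothetical call, shown only to illustrate the new kwargs):
    #
    #   trainer = build_trainer(default_args=dict(
    #       train_dataset=ds,
    #       train_iters_per_epoch=1000,
    #       val_iters_per_epoch=100,
    #   ))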