| @@ -0,0 +1,14 @@ | |||||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||||
| DEFAULT_CONFIG = { | |||||
| 'train': { | |||||
| 'hooks': [{ | |||||
| 'type': 'CheckpointHook', | |||||
| 'interval': 1 | |||||
| }, { | |||||
| 'type': 'TextLoggerHook', | |||||
| 'interval': 10 | |||||
| }, { | |||||
| 'type': 'IterTimerHook' | |||||
| }] | |||||
| } | |||||
| } | |||||
| @@ -6,11 +6,12 @@ from .hook import Hook | |||||
| from .iter_timer_hook import IterTimerHook | from .iter_timer_hook import IterTimerHook | ||||
| from .logger.text_logger_hook import TextLoggerHook | from .logger.text_logger_hook import TextLoggerHook | ||||
| from .lr_scheduler_hook import LrSchedulerHook | from .lr_scheduler_hook import LrSchedulerHook | ||||
| from .optimizer_hook import OptimizerHook | |||||
| from .optimizer_hook import (ApexAMPOptimizerHook, OptimizerHook, | |||||
| TorchAMPOptimizerHook) | |||||
| from .priority import Priority | from .priority import Priority | ||||
| __all__ = [ | __all__ = [ | ||||
| 'Hook', 'HOOKS', 'CheckpointHook', 'EvaluationHook', 'LrSchedulerHook', | 'Hook', 'HOOKS', 'CheckpointHook', 'EvaluationHook', 'LrSchedulerHook', | ||||
| 'OptimizerHook', 'Priority', 'build_hook', 'TextLoggerHook', | 'OptimizerHook', 'Priority', 'build_hook', 'TextLoggerHook', | ||||
| 'IterTimerHook' | |||||
| 'IterTimerHook', 'TorchAMPOptimizerHook', 'ApexAMPOptimizerHook' | |||||
| ] | ] | ||||
| @@ -3,6 +3,7 @@ import os | |||||
| from modelscope import __version__ | from modelscope import __version__ | ||||
| from modelscope.utils.checkpoint import save_checkpoint | from modelscope.utils.checkpoint import save_checkpoint | ||||
| from modelscope.utils.constant import LogKeys | |||||
| from modelscope.utils.logger import get_logger | from modelscope.utils.logger import get_logger | ||||
| from modelscope.utils.torch_utils import get_dist_info | from modelscope.utils.torch_utils import get_dist_info | ||||
| from .builder import HOOKS | from .builder import HOOKS | ||||
| @@ -58,11 +59,11 @@ class CheckpointHook(Hook): | |||||
| def _save_checkpoint(self, trainer): | def _save_checkpoint(self, trainer): | ||||
| if self.by_epoch: | if self.by_epoch: | ||||
| cur_save_name = os.path.join(self.save_dir, | |||||
| f'epoch_{trainer.epoch + 1}.pth') | |||||
| cur_save_name = os.path.join( | |||||
| self.save_dir, f'{LogKeys.EPOCH}_{trainer.epoch + 1}.pth') | |||||
| else: | else: | ||||
| cur_save_name = os.path.join(self.save_dir, | |||||
| f'iter_{trainer.epoch + 1}.pth') | |||||
| cur_save_name = os.path.join( | |||||
| self.save_dir, f'{LogKeys.ITER}_{trainer.iter + 1}.pth') | |||||
| rank, _ = get_dist_info() | rank, _ = get_dist_info() | ||||
| if rank == 0: | if rank == 0: | ||||
| @@ -1,4 +1,10 @@ | |||||
| # Copyright (c) Alibaba, Inc. and its affiliates. | # Copyright (c) Alibaba, Inc. and its affiliates. | ||||
| import os | |||||
| from modelscope.utils.checkpoint import save_checkpoint | |||||
| from modelscope.utils.constant import LogKeys | |||||
| from modelscope.utils.logger import get_logger | |||||
| from modelscope.utils.torch_utils import get_dist_info | |||||
| from .builder import HOOKS | from .builder import HOOKS | ||||
| from .hook import Hook | from .hook import Hook | ||||
| from .priority import Priority | from .priority import Priority | ||||
| @@ -12,17 +18,56 @@ class EvaluationHook(Hook): | |||||
| by_epoch (bool): Evaluate by epoch or by iteration. | by_epoch (bool): Evaluate by epoch or by iteration. | ||||
| start_idx (int | None, optional): The epoch/iterations validation begins. | start_idx (int | None, optional): The epoch/iterations validation begins. | ||||
| Default: None, validate every interval epochs/iterations from scratch. | Default: None, validate every interval epochs/iterations from scratch. | ||||
| save_best_ckpt (bool): Whether save the best checkpoint during evaluation. | |||||
| monitor_key (str): Monitor key to compare rule for best score, only valid when `save_best_ckpt` is true. | |||||
| rule (str): Comparison rule for best score, only valid when `save_best_ckpt` is true. | |||||
| Support "max" and "min". If rule is "max", the checkpoint at the maximum `monitor_key` | |||||
| will be saved, If rule is "min", the checkpoint at the minimum `monitor_key` will be saved. | |||||
| out_dir (str): Output directory to save best checkpoint. | |||||
| """ | """ | ||||
| PRIORITY = Priority.NORMAL | PRIORITY = Priority.NORMAL | ||||
| rule_map = {'max': lambda x, y: x > y, 'min': lambda x, y: x < y} | |||||
| def __init__(self, interval=1, by_epoch=True, start_idx=None): | |||||
| def __init__(self, | |||||
| interval=1, | |||||
| by_epoch=True, | |||||
| start_idx=None, | |||||
| save_best_ckpt=False, | |||||
| monitor_key=None, | |||||
| rule='max', | |||||
| out_dir=None): | |||||
| assert interval > 0, 'interval must be a positive number' | assert interval > 0, 'interval must be a positive number' | ||||
| if save_best_ckpt: | |||||
| assert monitor_key is not None, 'Must provide `monitor_key` when `save_best_ckpt` is True.' | |||||
| assert rule in ['max', | |||||
| 'min'], 'Only support "max" or "min" rule now.' | |||||
| self.interval = interval | self.interval = interval | ||||
| self.start_idx = start_idx | self.start_idx = start_idx | ||||
| self.by_epoch = by_epoch | self.by_epoch = by_epoch | ||||
| self.save_best_ckpt = save_best_ckpt | |||||
| self.monitor_key = monitor_key | |||||
| self.rule = rule | |||||
| self.out_dir = out_dir | |||||
| self._best_metric = None | |||||
| self._best_ckpt_file = None | |||||
| def before_run(self, trainer): | |||||
| if not self.out_dir: | |||||
| self.out_dir = trainer.work_dir | |||||
| if not os.path.exists(self.out_dir): | |||||
| rank, _ = get_dist_info() | |||||
| if rank == 0: | |||||
| os.makedirs(self.out_dir) | |||||
| if self.save_best_ckpt: | |||||
| if not hasattr(trainer, 'logger'): | |||||
| self.logger = get_logger(__name__) | |||||
| else: | |||||
| self.logger = trainer.logger | |||||
| self.logger.info( | |||||
| f'Best checkpoint will be saved to {self.out_dir}') | |||||
| def after_train_iter(self, trainer): | def after_train_iter(self, trainer): | ||||
| """Called after every training iter to evaluate the results.""" | """Called after every training iter to evaluate the results.""" | ||||
| @@ -42,6 +87,46 @@ class EvaluationHook(Hook): | |||||
| trainer.log_buffer.ready = True | trainer.log_buffer.ready = True | ||||
| if self.save_best_ckpt and self._is_best_metric(eval_res): | |||||
| # remove the previous best model and save the latest best model | |||||
| if self._best_ckpt_file is not None and os.path.exists( | |||||
| self._best_ckpt_file): | |||||
| os.remove(self._best_ckpt_file) | |||||
| self._save_checkpoint(trainer) | |||||
| def _is_best_metric(self, eval_res): | |||||
| if self.monitor_key not in eval_res: | |||||
| raise ValueError( | |||||
| f'Not find monitor_key: {self.monitor_key} in {eval_res}') | |||||
| if self._best_metric is None: | |||||
| self._best_metric = eval_res[self.monitor_key] | |||||
| return True | |||||
| else: | |||||
| compare_fn = self.rule_map[self.rule] | |||||
| if compare_fn(eval_res[self.monitor_key], self._best_metric): | |||||
| self._best_metric = eval_res[self.monitor_key] | |||||
| return True | |||||
| return False | |||||
| def _save_checkpoint(self, trainer): | |||||
| if self.by_epoch: | |||||
| cur_save_name = os.path.join( | |||||
| self.out_dir, | |||||
| f'best_{LogKeys.EPOCH}{trainer.epoch + 1}_{self.monitor_key}{self._best_metric}.pth' | |||||
| ) | |||||
| else: | |||||
| cur_save_name = os.path.join( | |||||
| self.out_dir, | |||||
| f'best_{LogKeys.ITER}{trainer.iter + 1}_{self.monitor_key}{self._best_metric}.pth' | |||||
| ) | |||||
| rank, _ = get_dist_info() | |||||
| if rank == 0: | |||||
| save_checkpoint(trainer.model, cur_save_name, trainer.optimizer) | |||||
| self._best_ckpt_file = cur_save_name | |||||
| def _should_evaluate(self, trainer): | def _should_evaluate(self, trainer): | ||||
| """Judge whether to perform evaluation. | """Judge whether to perform evaluation. | ||||
| @@ -1,5 +1,6 @@ | |||||
| # Copyright (c) OpenMMLab. All rights reserved. | # Copyright (c) OpenMMLab. All rights reserved. | ||||
| # Copyright (c) Alibaba, Inc. and its affiliates. | # Copyright (c) Alibaba, Inc. and its affiliates. | ||||
| from modelscope.utils.constant import TrainerStages | |||||
| from modelscope.utils.import_utils import is_method_overridden | from modelscope.utils.import_utils import is_method_overridden | ||||
| from .priority import Priority | from .priority import Priority | ||||
| @@ -9,11 +10,12 @@ class Hook: | |||||
| The Hook base class of any modelscope trainer. You can build your own hook inherited from this class. | The Hook base class of any modelscope trainer. You can build your own hook inherited from this class. | ||||
| """ | """ | ||||
| # TODO @jiangnana.jnn use constant variable for stages | |||||
| stages = ('before_run', 'before_train_epoch', 'before_train_iter', | |||||
| 'after_train_iter', 'after_train_epoch', 'before_val_epoch', | |||||
| 'before_val_iter', 'after_val_iter', 'after_val_epoch', | |||||
| 'after_run') | |||||
| stages = (TrainerStages.before_run, TrainerStages.before_train_epoch, | |||||
| TrainerStages.before_train_iter, TrainerStages.after_train_iter, | |||||
| TrainerStages.after_train_epoch, TrainerStages.before_val_epoch, | |||||
| TrainerStages.before_val_iter, TrainerStages.after_val_iter, | |||||
| TrainerStages.after_val_epoch, TrainerStages.after_run) | |||||
| PRIORITY = Priority.NORMAL | PRIORITY = Priority.NORMAL | ||||
| def before_run(self, trainer): | def before_run(self, trainer): | ||||
| @@ -171,6 +173,13 @@ class Hook: | |||||
| """ | """ | ||||
| return (trainer.epoch + 1) % n == 0 if n > 0 else False | return (trainer.epoch + 1) % n == 0 if n > 0 else False | ||||
| def every_n_inner_iters(self, runner, n): | |||||
| """ | |||||
| Whether to reach every ``n`` iterations at every epoch | |||||
| Returns: bool | |||||
| """ | |||||
| return (runner.inner_iter + 1) % n == 0 if n > 0 else False | |||||
| def every_n_iters(self, trainer, n): | def every_n_iters(self, trainer, n): | ||||
| """ | """ | ||||
| Whether to reach every ``n`` iterations | Whether to reach every ``n`` iterations | ||||
| @@ -1,6 +1,7 @@ | |||||
| # Copyright (c) Alibaba, Inc. and its affiliates. | # Copyright (c) Alibaba, Inc. and its affiliates. | ||||
| import time | import time | ||||
| from modelscope.utils.constant import LogKeys | |||||
| from .builder import HOOKS | from .builder import HOOKS | ||||
| from .hook import Hook | from .hook import Hook | ||||
| from .priority import Priority | from .priority import Priority | ||||
| @@ -15,8 +16,9 @@ class IterTimerHook(Hook): | |||||
| def before_iter(self, trainer): | def before_iter(self, trainer): | ||||
| trainer.log_buffer.update( | trainer.log_buffer.update( | ||||
| {'data_load_time': time.time() - self.start_time}) | |||||
| {LogKeys.DATA_LOAD_TIME: time.time() - self.start_time}) | |||||
| def after_iter(self, trainer): | def after_iter(self, trainer): | ||||
| trainer.log_buffer.update({'time': time.time() - self.start_time}) | |||||
| trainer.log_buffer.update( | |||||
| {LogKeys.ITER_TIME: time.time() - self.start_time}) | |||||
| self.start_time = time.time() | self.start_time = time.time() | ||||
| @@ -1,6 +1,7 @@ | |||||
| # Copyright (c) Alibaba, Inc. and its affiliates. | # Copyright (c) Alibaba, Inc. and its affiliates. | ||||
| from modelscope.trainers.utils.log_buffer import LogBuffer | from modelscope.trainers.utils.log_buffer import LogBuffer | ||||
| from .base import LoggerHook | from .base import LoggerHook | ||||
| from .tensorboard_hook import TensorboardHook | |||||
| from .text_logger_hook import TextLoggerHook | from .text_logger_hook import TextLoggerHook | ||||
| __all__ = ['TextLoggerHook', 'LoggerHook', 'LogBuffer'] | |||||
| __all__ = ['TextLoggerHook', 'LoggerHook', 'LogBuffer', 'TensorboardHook'] | |||||
| @@ -7,6 +7,7 @@ import numpy as np | |||||
| import torch | import torch | ||||
| from modelscope.trainers.hooks.hook import Hook | from modelscope.trainers.hooks.hook import Hook | ||||
| from modelscope.utils.constant import ModeKeys | |||||
| from ..priority import Priority | from ..priority import Priority | ||||
| @@ -60,15 +61,12 @@ class LoggerHook(Hook): | |||||
| return False | return False | ||||
| def get_epoch(self, trainer): | def get_epoch(self, trainer): | ||||
| if trainer.mode == 'train': | |||||
| if trainer.mode in [ModeKeys.TRAIN, ModeKeys.EVAL]: | |||||
| epoch = trainer.epoch + 1 | epoch = trainer.epoch + 1 | ||||
| elif trainer.mode == 'val': | |||||
| # normal val mode | |||||
| # trainer.epoch += 1 has been done before val workflow | |||||
| epoch = trainer.epoch | |||||
| else: | else: | ||||
| raise ValueError(f"trainer mode should be 'train' or 'val', " | |||||
| f'but got {trainer.mode}') | |||||
| raise ValueError( | |||||
| f'trainer mode should be {ModeKeys.TRAIN} or {ModeKeys.EVAL}, ' | |||||
| f'but got {trainer.mode}') | |||||
| return epoch | return epoch | ||||
| def get_iter(self, trainer, inner_iter=False): | def get_iter(self, trainer, inner_iter=False): | ||||
| @@ -89,7 +87,7 @@ class LoggerHook(Hook): | |||||
| trainer.log_buffer.clear() # clear logs of last epoch | trainer.log_buffer.clear() # clear logs of last epoch | ||||
| def after_train_iter(self, trainer): | def after_train_iter(self, trainer): | ||||
| if self.by_epoch and self.every_n_epochs(trainer, self.interval): | |||||
| if self.by_epoch and self.every_n_inner_iters(trainer, self.interval): | |||||
| trainer.log_buffer.average(self.interval) | trainer.log_buffer.average(self.interval) | ||||
| elif not self.by_epoch and self.every_n_iters(trainer, self.interval): | elif not self.by_epoch and self.every_n_iters(trainer, self.interval): | ||||
| trainer.log_buffer.average(self.interval) | trainer.log_buffer.average(self.interval) | ||||
| @@ -0,0 +1,68 @@ | |||||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||||
| import os | |||||
| from modelscope.trainers.hooks.builder import HOOKS | |||||
| from modelscope.utils.constant import LogKeys | |||||
| from modelscope.utils.torch_utils import master_only | |||||
| from .base import LoggerHook | |||||
| @HOOKS.register_module() | |||||
| class TensorboardHook(LoggerHook): | |||||
| """TensorBoard hook for visualization. | |||||
| Args: | |||||
| out_dir: output directory to save tensorboard files | |||||
| interval (int): Logging interval (every k iterations). | |||||
| ignore_last (bool): Ignore the log of last iterations in each epoch | |||||
| if less than `interval`. | |||||
| reset_flag (bool): Whether to clear the output buffer after logging. | |||||
| by_epoch (bool): Whether EpochBasedtrainer is used. | |||||
| skip_keys (list): list of keys which will not add to tensorboard | |||||
| """ | |||||
| def __init__(self, | |||||
| out_dir=None, | |||||
| interval=10, | |||||
| ignore_last=True, | |||||
| reset_flag=False, | |||||
| by_epoch=True, | |||||
| skip_keys=[LogKeys.ITER_TIME, LogKeys.DATA_LOAD_TIME]): | |||||
| super(TensorboardHook, self).__init__( | |||||
| interval=interval, | |||||
| ignore_last=ignore_last, | |||||
| reset_flag=reset_flag, | |||||
| by_epoch=by_epoch) | |||||
| self.out_dir = out_dir | |||||
| self.skip_keys = skip_keys | |||||
| @master_only | |||||
| def before_run(self, trainer): | |||||
| super(TensorboardHook, self).before_run(trainer) | |||||
| try: | |||||
| from torch.utils.tensorboard import SummaryWriter | |||||
| except ImportError as e: | |||||
| raise ImportError( | |||||
| e.msg + ' ' | |||||
| 'Please pip install tensorboard by ``pip install future tensorboard`` ' | |||||
| 'or upgrade version by ``pip install future tensorboard --upgrade``.' | |||||
| ) | |||||
| if self.out_dir is None: | |||||
| self.out_dir = os.path.join(trainer.work_dir, 'tensorboard_output') | |||||
| self.writer = SummaryWriter(self.out_dir) | |||||
| @master_only | |||||
| def log(self, trainer): | |||||
| for key, val in trainer.log_buffer.output.items(): | |||||
| if key in self.skip_keys: | |||||
| continue | |||||
| if isinstance(val, str): | |||||
| self.writer.add_text(key, val, self.get_iter(trainer)) | |||||
| elif self.is_scalar(val): | |||||
| self.writer.add_scalar(key, val, self.get_iter(trainer)) | |||||
| else: | |||||
| pass | |||||
| @master_only | |||||
| def after_run(self, trainer): | |||||
| self.writer.close() | |||||
| @@ -8,6 +8,7 @@ import json | |||||
| import torch | import torch | ||||
| from torch import distributed as dist | from torch import distributed as dist | ||||
| from modelscope.utils.constant import LogKeys, ModeKeys | |||||
| from modelscope.utils.torch_utils import get_dist_info | from modelscope.utils.torch_utils import get_dist_info | ||||
| from ..builder import HOOKS | from ..builder import HOOKS | ||||
| from .base import LoggerHook | from .base import LoggerHook | ||||
| @@ -72,44 +73,53 @@ class TextLoggerHook(LoggerHook): | |||||
| return mem_mb.item() | return mem_mb.item() | ||||
| def _log_info(self, log_dict, trainer): | def _log_info(self, log_dict, trainer): | ||||
| if log_dict['mode'] == 'train': | |||||
| if isinstance(log_dict['lr'], dict): | |||||
| lr_key = LogKeys.LR | |||||
| epoch_key = LogKeys.EPOCH | |||||
| iter_key = LogKeys.ITER | |||||
| mode_key = LogKeys.MODE | |||||
| iter_time_key = LogKeys.ITER_TIME | |||||
| data_load_time_key = LogKeys.DATA_LOAD_TIME | |||||
| eta_key = LogKeys.ETA | |||||
| if log_dict[mode_key] == ModeKeys.TRAIN: | |||||
| if isinstance(log_dict[lr_key], dict): | |||||
| lr_str = [] | lr_str = [] | ||||
| for k, val in log_dict['lr'].items(): | |||||
| lr_str.append(f'lr_{k}: {val:.3e}') | |||||
| for k, val in log_dict[lr_key].items(): | |||||
| lr_str.append(f'{lr_key}_{k}: {val:.3e}') | |||||
| lr_str = ' '.join(lr_str) | lr_str = ' '.join(lr_str) | ||||
| else: | else: | ||||
| lr_str = f'lr: {log_dict["lr"]:.3e}' | |||||
| lr_str = f'{lr_key}: {log_dict[lr_key]:.3e}' | |||||
| if self.by_epoch: | if self.by_epoch: | ||||
| log_str = f'Epoch [{log_dict["epoch"]}][{log_dict["iter"]}/{len(trainer.data_loader)}]\t' | |||||
| log_str = f'{epoch_key} [{log_dict[epoch_key]}][{log_dict[iter_key]}/{len(trainer.data_loader)}]\t' | |||||
| else: | else: | ||||
| log_str = f'Iter [{log_dict["iter"]}/{trainer.max_iters}]\t' | |||||
| log_str = f'{iter_key} [{log_dict[iter_key]}/{trainer.max_iters}]\t' | |||||
| log_str += f'{lr_str}, ' | log_str += f'{lr_str}, ' | ||||
| self._logged_keys.extend(['lr', 'mode', 'iter', 'epoch']) | |||||
| self._logged_keys.extend([lr_key, mode_key, iter_key, epoch_key]) | |||||
| if 'time' in log_dict.keys(): | |||||
| self.time_sec_tot += (log_dict['time'] * self.interval) | |||||
| if iter_time_key in log_dict.keys(): | |||||
| self.time_sec_tot += (log_dict[iter_time_key] * self.interval) | |||||
| time_sec_avg = self.time_sec_tot / ( | time_sec_avg = self.time_sec_tot / ( | ||||
| trainer.iter - self.start_iter + 1) | trainer.iter - self.start_iter + 1) | ||||
| eta_sec = time_sec_avg * (trainer.max_iters - trainer.iter - 1) | eta_sec = time_sec_avg * (trainer.max_iters - trainer.iter - 1) | ||||
| eta_str = str(datetime.timedelta(seconds=int(eta_sec))) | eta_str = str(datetime.timedelta(seconds=int(eta_sec))) | ||||
| log_str += f'eta: {eta_str}, ' | |||||
| log_str += f'time: {log_dict["time"]:.3f}, data_load_time: {log_dict["data_load_time"]:.3f}, ' | |||||
| log_str += f'{eta_key}: {eta_str}, ' | |||||
| log_str += f'{iter_time_key}: {log_dict[iter_time_key]:.3f}, ' | |||||
| log_str += f'{data_load_time_key}: {log_dict[data_load_time_key]:.3f}, ' | |||||
| self._logged_keys.extend([ | self._logged_keys.extend([ | ||||
| 'time', | |||||
| 'data_load_time', | |||||
| iter_time_key, | |||||
| data_load_time_key, | |||||
| ]) | ]) | ||||
| else: | else: | ||||
| # val/test time | # val/test time | ||||
| # here 1000 is the length of the val dataloader | # here 1000 is the length of the val dataloader | ||||
| # by epoch: Epoch[val] [4][1000] | |||||
| # by iter: Iter[val] [1000] | |||||
| # by epoch: epoch[val] [4][1000] | |||||
| # by iter: iter[val] [1000] | |||||
| if self.by_epoch: | if self.by_epoch: | ||||
| log_str = f'Epoch({log_dict["mode"]}) [{log_dict["epoch"]}][{log_dict["iter"]}]\t' | |||||
| log_str = f'{epoch_key}({log_dict[mode_key]}) [{log_dict[epoch_key]}][{log_dict[iter_key]}]\t' | |||||
| else: | else: | ||||
| log_str = f'Iter({log_dict["mode"]}) [{log_dict["iter"]}]\t' | |||||
| self._logged_keys.extend(['mode', 'iter', 'epoch']) | |||||
| log_str = f'{iter_key}({log_dict[mode_key]}) [{log_dict[iter_key]}]\t' | |||||
| self._logged_keys.extend([mode_key, iter_key, epoch_key]) | |||||
| log_items = [] | log_items = [] | ||||
| for name, val in log_dict.items(): | for name, val in log_dict.items(): | ||||
| @@ -150,7 +160,7 @@ class TextLoggerHook(LoggerHook): | |||||
| # statistic memory | # statistic memory | ||||
| if torch.cuda.is_available(): | if torch.cuda.is_available(): | ||||
| log_dict['memory'] = self._get_max_memory(trainer) | |||||
| log_dict[LogKeys.MEMORY] = self._get_max_memory(trainer) | |||||
| log_dict = dict(log_dict, **trainer.log_buffer.output) | log_dict = dict(log_dict, **trainer.log_buffer.output) | ||||
| @@ -1,5 +1,6 @@ | |||||
| # Copyright (c) Alibaba, Inc. and its affiliates. | # Copyright (c) Alibaba, Inc. and its affiliates. | ||||
| from modelscope.trainers.lrscheduler.builder import build_lr_scheduler | from modelscope.trainers.lrscheduler.builder import build_lr_scheduler | ||||
| from modelscope.utils.constant import LogKeys | |||||
| from .builder import HOOKS | from .builder import HOOKS | ||||
| from .hook import Hook | from .hook import Hook | ||||
| from .priority import Priority | from .priority import Priority | ||||
| @@ -46,7 +47,7 @@ class LrSchedulerHook(Hook): | |||||
| return lr | return lr | ||||
| def before_train_iter(self, trainer): | def before_train_iter(self, trainer): | ||||
| trainer.log_buffer.output['lr'] = self._get_log_lr(trainer) | |||||
| trainer.log_buffer.output[LogKeys.LR] = self._get_log_lr(trainer) | |||||
| def before_train_epoch(self, trainer): | def before_train_epoch(self, trainer): | ||||
| if self.by_epoch: | if self.by_epoch: | ||||
| @@ -54,7 +55,7 @@ class LrSchedulerHook(Hook): | |||||
| self.warmup_lr_scheduler.step() | self.warmup_lr_scheduler.step() | ||||
| else: | else: | ||||
| trainer.lr_scheduler.step() | trainer.lr_scheduler.step() | ||||
| trainer.log_buffer.output['lr'] = self._get_log_lr(trainer) | |||||
| trainer.log_buffer.output[LogKeys.LR] = self._get_log_lr(trainer) | |||||
| def _get_log_lr(self, trainer): | def _get_log_lr(self, trainer): | ||||
| cur_lr = self.get_current_lr(trainer) | cur_lr = self.get_current_lr(trainer) | ||||
| @@ -1,4 +1,6 @@ | |||||
| # Copyright (c) Alibaba, Inc. and its affiliates. | # Copyright (c) Alibaba, Inc. and its affiliates. | ||||
| import logging | |||||
| from torch.nn.utils import clip_grad | from torch.nn.utils import clip_grad | ||||
| from .builder import HOOKS | from .builder import HOOKS | ||||
| @@ -8,14 +10,28 @@ from .priority import Priority | |||||
| @HOOKS.register_module() | @HOOKS.register_module() | ||||
| class OptimizerHook(Hook): | class OptimizerHook(Hook): | ||||
| """Optimizer hook | |||||
| Args: | |||||
| cumulative_iters (int): interval of gradients accumulation. Default: 1 | |||||
| grad_clip (dict): Default None. Containing keys: | |||||
| max_norm (float or int): max norm of the gradients | |||||
| norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for infinity norm. | |||||
| More details please refer to `torch.nn.utils.clip_grad.clip_grad_norm_` | |||||
| loss_keys (str | list): keys list of loss | |||||
| """ | |||||
| PRIORITY = Priority.ABOVE_NORMAL | PRIORITY = Priority.ABOVE_NORMAL | ||||
| def __init__(self, grad_clip=None, loss_keys='loss') -> None: | |||||
| def __init__(self, | |||||
| cumulative_iters=1, | |||||
| grad_clip=None, | |||||
| loss_keys='loss') -> None: | |||||
| if isinstance(loss_keys, str): | if isinstance(loss_keys, str): | ||||
| loss_keys = [loss_keys] | loss_keys = [loss_keys] | ||||
| assert isinstance(loss_keys, (tuple, list)) | assert isinstance(loss_keys, (tuple, list)) | ||||
| self.loss_keys = loss_keys | self.loss_keys = loss_keys | ||||
| self.cumulative_iters = cumulative_iters | |||||
| self.grad_clip = grad_clip | self.grad_clip = grad_clip | ||||
| def clip_grads(self, params, **clip_args): | def clip_grads(self, params, **clip_args): | ||||
| @@ -24,14 +40,163 @@ class OptimizerHook(Hook): | |||||
| if len(params) > 0: | if len(params) > 0: | ||||
| return clip_grad.clip_grad_norm_(params, **clip_args) | return clip_grad.clip_grad_norm_(params, **clip_args) | ||||
| def after_train_iter(self, trainer): | |||||
| def before_run(self, trainer): | |||||
| trainer.optimizer.zero_grad() | trainer.optimizer.zero_grad() | ||||
| def after_train_iter(self, trainer): | |||||
| for k in self.loss_keys: | for k in self.loss_keys: | ||||
| trainer.train_outputs[k] /= self.cumulative_iters | |||||
| trainer.train_outputs[k].backward() | trainer.train_outputs[k].backward() | ||||
| clip_args = self.grad_clip | |||||
| if clip_args is not None: | |||||
| self.clip_grads(trainer.model.parameters(), **clip_args) | |||||
| if self.every_n_iters(trainer, self.cumulative_iters): | |||||
| if self.grad_clip is not None: | |||||
| self.clip_grads(trainer.model.parameters(), **self.grad_clip) | |||||
| trainer.optimizer.step() | |||||
| trainer.optimizer.zero_grad() | |||||
| @HOOKS.register_module() | |||||
| class TorchAMPOptimizerHook(OptimizerHook): | |||||
| """Fp16 optimizer, if torch version is less than 1.6.0, | |||||
| you must install apex (https://www.github.com/nvidia/apex) else use torch.cuda.amp by default | |||||
| Args: | |||||
| cumulative_iters (int): interval of gradients accumulation. Default: 1 | |||||
| grad_clip (dict): Default None. Containing keys: | |||||
| max_norm (float or int): max norm of the gradients | |||||
| norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for infinity norm. | |||||
| More details please refer to `torch.nn.utils.clip_grad.clip_grad_norm_` | |||||
| loss_keys (str | list): keys list of loss | |||||
| loss_scale (float | dict): grade scale config. If loss_scale is a float, | |||||
| static loss scaling will be used with the specified scale. | |||||
| It can also be a dict containing arguments of GradScalar. For Pytorch >= 1.6, | |||||
| we use official torch.cuda.amp.GradScaler. | |||||
| please refer to: https://pytorch.org/docs/stable/amp.html#torch.cuda.amp.GradScaler for the parameters. | |||||
| """ | |||||
| def __init__(self, | |||||
| cumulative_iters=1, | |||||
| grad_clip=None, | |||||
| loss_keys='loss', | |||||
| loss_scale={}): | |||||
| super(TorchAMPOptimizerHook, self).__init__( | |||||
| grad_clip=grad_clip, loss_keys=loss_keys) | |||||
| self.cumulative_iters = cumulative_iters | |||||
| self._scale_update_param = None | |||||
| from torch.cuda import amp | |||||
| if isinstance(loss_scale, float): | |||||
| self._scale_update_param = loss_scale | |||||
| self.scaler = amp.GradScaler(init_scale=loss_scale) | |||||
| elif isinstance(loss_scale, dict): | |||||
| self.scaler = amp.GradScaler(**loss_scale) | |||||
| else: | |||||
| raise ValueError( | |||||
| '`loss_scale` type must be in [float, dict], but got {loss_scale}' | |||||
| ) | |||||
| def before_run(self, trainer): | |||||
| logging.info('open fp16') | |||||
| trainer.optimizer.zero_grad() | |||||
| if hasattr(trainer.model, 'module'): | |||||
| self._ori_model_forward = trainer.model.module.forward | |||||
| self._model = trainer.model.module | |||||
| else: | |||||
| self._ori_model_forward = trainer.model.forward | |||||
| self._model = trainer.model | |||||
| self.ori_model_forward = trainer.model.forward | |||||
| def before_train_iter(self, trainer): | |||||
| from torch.cuda import amp | |||||
| setattr(self._model, 'forward', amp.autocast()(self._model.forward)) | |||||
| def after_train_iter(self, trainer): | |||||
| for k in self.loss_keys: | |||||
| trainer.train_outputs[k] /= self.cumulative_iters | |||||
| for k in self.loss_keys: | |||||
| self.scaler.scale(trainer.train_outputs[k]).backward() | |||||
| if self.every_n_iters(trainer, self.cumulative_iters): | |||||
| self.scaler.unscale_(trainer.optimizer) | |||||
| if self.grad_clip is not None: | |||||
| self.clip_grads(trainer.model.parameters(), **self.grad_clip) | |||||
| self.scaler.step(trainer.optimizer) | |||||
| self.scaler.update(self._scale_update_param) | |||||
| trainer.optimizer.zero_grad() | |||||
| setattr(self._model, 'forward', self._ori_model_forward) | |||||
| @HOOKS.register_module() | |||||
| class ApexAMPOptimizerHook(OptimizerHook): | |||||
| """Fp16 optimizer, if torch version is less than 1.6.0, | |||||
| you must install apex (https://www.github.com/nvidia/apex) else use torch.cuda.amp by default | |||||
| Args: | |||||
| cumulative_iters (int): interval of gradients accumulation. Default: 1 | |||||
| grad_clip (dict): Default None. Containing keys: | |||||
| max_norm (float or int): max norm of the gradients | |||||
| norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for infinity norm. | |||||
| More details please refer to `torch.nn.utils.clip_grad.clip_grad_norm_` | |||||
| loss_keys (str | list): keys list of loss | |||||
| opt_level (str): "O0" and "O3" are not true mixed precision, | |||||
| but they are useful for establishing accuracy and speed baselines, respectively. | |||||
| "O1" and "O2" are different implementations of mixed precision. | |||||
| Try both, and see what gives the best speedup and accuracy for your model. | |||||
| """ | |||||
| def __init__(self, | |||||
| cumulative_iters=1, | |||||
| grad_clip=None, | |||||
| loss_keys='loss', | |||||
| opt_level='O1'): | |||||
| super(ApexAMPOptimizerHook, self).__init__( | |||||
| grad_clip=grad_clip, loss_keys=loss_keys) | |||||
| self.cumulative_iters = cumulative_iters | |||||
| self.opt_level = opt_level | |||||
| try: | |||||
| from apex import amp | |||||
| except ImportError: | |||||
| raise ValueError( | |||||
| 'apex not installed, please install apex from https://www.github.com/nvidia/apex.' | |||||
| ) | |||||
| def before_run(self, trainer): | |||||
| from apex import amp | |||||
| logging.info('open fp16') | |||||
| # TODO: fix it should initialze amp with model not wrapper by DDP or DP | |||||
| if hasattr(trainer.model, 'module'): | |||||
| trainer.model, trainer.optimizer = amp.initialize( | |||||
| trainer.model.module, | |||||
| trainer.optimizer, | |||||
| opt_level=self.opt_level) | |||||
| else: | |||||
| trainer.model, trainer.optimizer = amp.initialize( | |||||
| trainer.model, trainer.optimizer, opt_level=self.opt_level) | |||||
| trainer.optimizer.zero_grad() | |||||
| def after_train_iter(self, trainer): | |||||
| for k in self.loss_keys: | |||||
| trainer.train_outputs[k] /= self.cumulative_iters | |||||
| from apex import amp | |||||
| for k in self.loss_keys: | |||||
| with amp.scale_loss(trainer.train_outputs[k], | |||||
| trainer.optimizer) as scaled_loss: | |||||
| scaled_loss.backward() | |||||
| if self.every_n_iters(trainer, self.cumulative_iters): | |||||
| if self.grad_clip is not None: | |||||
| self.clip_grads(trainer.model.parameters(), **self.grad_clip) | |||||
| trainer.optimizer.step() | |||||
| trainer.optimizer.step() | |||||
| trainer.optimizer.zero_grad() | |||||
| @@ -26,14 +26,16 @@ from modelscope.trainers.hooks.builder import HOOKS | |||||
| from modelscope.trainers.hooks.priority import Priority, get_priority | from modelscope.trainers.hooks.priority import Priority, get_priority | ||||
| from modelscope.trainers.lrscheduler.builder import build_lr_scheduler | from modelscope.trainers.lrscheduler.builder import build_lr_scheduler | ||||
| from modelscope.trainers.optimizer.builder import build_optimizer | from modelscope.trainers.optimizer.builder import build_optimizer | ||||
| from modelscope.utils.config import ConfigDict | |||||
| from modelscope.utils.constant import Hubs, ModelFile, Tasks | |||||
| from modelscope.utils.config import Config, ConfigDict | |||||
| from modelscope.utils.constant import (Hubs, ModeKeys, ModelFile, Tasks, | |||||
| TrainerStages) | |||||
| from modelscope.utils.logger import get_logger | from modelscope.utils.logger import get_logger | ||||
| from modelscope.utils.registry import build_from_cfg | from modelscope.utils.registry import build_from_cfg | ||||
| from modelscope.utils.tensor_utils import torch_default_data_collator | from modelscope.utils.tensor_utils import torch_default_data_collator | ||||
| from modelscope.utils.torch_utils import get_dist_info | from modelscope.utils.torch_utils import get_dist_info | ||||
| from .base import BaseTrainer | from .base import BaseTrainer | ||||
| from .builder import TRAINERS | from .builder import TRAINERS | ||||
| from .default_config import DEFAULT_CONFIG | |||||
| from .hooks.hook import Hook | from .hooks.hook import Hook | ||||
| @@ -97,6 +99,10 @@ class EpochBasedTrainer(BaseTrainer): | |||||
| self.model = model | self.model = model | ||||
| super().__init__(cfg_file, arg_parse_fn) | super().__init__(cfg_file, arg_parse_fn) | ||||
| # add default config | |||||
| self.cfg.merge_from_dict(self._get_default_config(), force=False) | |||||
| if 'work_dir' in kwargs: | if 'work_dir' in kwargs: | ||||
| self.work_dir = kwargs['work_dir'] | self.work_dir = kwargs['work_dir'] | ||||
| else: | else: | ||||
| @@ -112,14 +118,14 @@ class EpochBasedTrainer(BaseTrainer): | |||||
| self.device = int( | self.device = int( | ||||
| os.environ['LOCAL_RANK']) if 'LOCAL_RANK' in os.environ else None | os.environ['LOCAL_RANK']) if 'LOCAL_RANK' in os.environ else None | ||||
| self.train_dataset = self.to_task_dataset( | self.train_dataset = self.to_task_dataset( | ||||
| train_dataset, mode='train', preprocessor=self.preprocessor) | |||||
| train_dataset, mode=ModeKeys.TRAIN, preprocessor=self.preprocessor) | |||||
| self.eval_dataset = self.to_task_dataset( | self.eval_dataset = self.to_task_dataset( | ||||
| eval_dataset, mode='eval', preprocessor=self.preprocessor) | |||||
| eval_dataset, mode=ModeKeys.EVAL, preprocessor=self.preprocessor) | |||||
| self.data_collator = data_collator if data_collator is not None else torch_default_data_collator | self.data_collator = data_collator if data_collator is not None else torch_default_data_collator | ||||
| self.metrics = self.get_metrics() | self.metrics = self.get_metrics() | ||||
| self.optimizers = optimizers | self.optimizers = optimizers | ||||
| self.logger = get_logger(log_level=self.cfg.get('log_level', 'INFO')) | self.logger = get_logger(log_level=self.cfg.get('log_level', 'INFO')) | ||||
| self._mode = 'train' | |||||
| self._mode = ModeKeys.TRAIN | |||||
| self._hooks: List[Hook] = [] | self._hooks: List[Hook] = [] | ||||
| self._epoch = 0 | self._epoch = 0 | ||||
| self._iter = 0 | self._iter = 0 | ||||
| @@ -132,6 +138,8 @@ class EpochBasedTrainer(BaseTrainer): | |||||
| else: | else: | ||||
| self._max_epochs = kwargs['max_epochs'] | self._max_epochs = kwargs['max_epochs'] | ||||
| self.use_fp16 = kwargs.get('use_fp16', False) | |||||
| # TODO @wenmeng.zwm add seed init fn | # TODO @wenmeng.zwm add seed init fn | ||||
| self._seed = 0 | self._seed = 0 | ||||
| @@ -245,7 +253,7 @@ class EpochBasedTrainer(BaseTrainer): | |||||
| def train(self, *args, **kwargs): | def train(self, *args, **kwargs): | ||||
| self.model.train() | self.model.train() | ||||
| self._mode = 'train' | |||||
| self._mode = ModeKeys.TRAIN | |||||
| if self.train_dataset is None: | if self.train_dataset is None: | ||||
| self.train_dataloader = self.get_train_dataloader() | self.train_dataloader = self.get_train_dataloader() | ||||
| @@ -261,7 +269,7 @@ class EpochBasedTrainer(BaseTrainer): | |||||
| def evaluate(self, checkpoint_path=None): | def evaluate(self, checkpoint_path=None): | ||||
| self.model.eval() | self.model.eval() | ||||
| self._mode = 'val' | |||||
| self._mode = ModeKeys.EVAL | |||||
| if self.eval_dataset is None: | if self.eval_dataset is None: | ||||
| self.eval_dataloader = self.get_eval_data_loader() | self.eval_dataloader = self.get_eval_data_loader() | ||||
| @@ -329,7 +337,7 @@ class EpochBasedTrainer(BaseTrainer): | |||||
| # EvaluationHook will do evaluate and change mode to val, return to train mode | # EvaluationHook will do evaluate and change mode to val, return to train mode | ||||
| # TODO: find more pretty way to change mode | # TODO: find more pretty way to change mode | ||||
| model.train() | model.train() | ||||
| self._mode = 'train' | |||||
| self._mode = ModeKeys.TRAIN | |||||
| inputs = self.collate_fn(inputs) | inputs = self.collate_fn(inputs) | ||||
| if isinstance(inputs, dict): | if isinstance(inputs, dict): | ||||
| train_outputs = model.forward(**inputs) | train_outputs = model.forward(**inputs) | ||||
| @@ -394,7 +402,8 @@ class EpochBasedTrainer(BaseTrainer): | |||||
| """ | """ | ||||
| train_data = self.cfg.dataset.train | train_data = self.cfg.dataset.train | ||||
| if self.train_dataset is None: | if self.train_dataset is None: | ||||
| self.train_dataset = self.build_dataset(train_data, mode='train') | |||||
| self.train_dataset = self.build_dataset( | |||||
| train_data, mode=ModeKeys.TRAIN) | |||||
| data_loader = self._build_dataloader_with_dataset( | data_loader = self._build_dataloader_with_dataset( | ||||
| self.train_dataset, **self.cfg.train.get('dataloader', {})) | self.train_dataset, **self.cfg.train.get('dataloader', {})) | ||||
| @@ -409,7 +418,8 @@ class EpochBasedTrainer(BaseTrainer): | |||||
| """ | """ | ||||
| val_data = self.cfg.dataset.val | val_data = self.cfg.dataset.val | ||||
| if self.eval_dataset is None: | if self.eval_dataset is None: | ||||
| self.eval_dataset = self.build_dataset(val_data, mode='eval') | |||||
| self.eval_dataset = self.build_dataset( | |||||
| val_data, mode=ModeKeys.TRAIN) | |||||
| batch_size = self.cfg.evaluation.batch_size | batch_size = self.cfg.evaluation.batch_size | ||||
| workers = self.cfg.evaluation.workers | workers = self.cfg.evaluation.workers | ||||
| @@ -492,7 +502,10 @@ class EpochBasedTrainer(BaseTrainer): | |||||
| _, _, optim_options, lr_options = self.create_optimizer_and_scheduler() | _, _, optim_options, lr_options = self.create_optimizer_and_scheduler() | ||||
| lr_hook = dict(type='LrSchedulerHook', **lr_options) | lr_hook = dict(type='LrSchedulerHook', **lr_options) | ||||
| optim_hook = dict(type='OptimizerHook', **optim_options) | |||||
| if self.use_fp16: | |||||
| optim_hook = dict(type='TorchAMPOptimizerHook', **optim_options) | |||||
| else: | |||||
| optim_hook = dict(type='OptimizerHook', **optim_options) | |||||
| self.register_hook_from_cfg([lr_hook, optim_hook]) | self.register_hook_from_cfg([lr_hook, optim_hook]) | ||||
| @@ -578,26 +591,26 @@ class EpochBasedTrainer(BaseTrainer): | |||||
| def train_loop(self, data_loader): | def train_loop(self, data_loader): | ||||
| """ Training loop used by `EpochBasedTrainer.train()` | """ Training loop used by `EpochBasedTrainer.train()` | ||||
| """ | """ | ||||
| self.invoke_hook('before_run') | |||||
| self.invoke_hook(TrainerStages.before_run) | |||||
| self._epoch = 0 | self._epoch = 0 | ||||
| kwargs = {} | kwargs = {} | ||||
| for _ in range(self._epoch, self._max_epochs): | for _ in range(self._epoch, self._max_epochs): | ||||
| self.invoke_hook('before_train_epoch') | |||||
| self.invoke_hook(TrainerStages.before_train_epoch) | |||||
| time.sleep(2) # Prevent possible deadlock during epoch transition | time.sleep(2) # Prevent possible deadlock during epoch transition | ||||
| for i, data_batch in enumerate(data_loader): | for i, data_batch in enumerate(data_loader): | ||||
| self.data_batch = data_batch | self.data_batch = data_batch | ||||
| self._inner_iter = i | self._inner_iter = i | ||||
| self.invoke_hook('before_train_iter') | |||||
| self.invoke_hook(TrainerStages.before_train_iter) | |||||
| self.train_step(self.model, data_batch, **kwargs) | self.train_step(self.model, data_batch, **kwargs) | ||||
| self.invoke_hook('after_train_iter') | |||||
| self.invoke_hook(TrainerStages.after_train_iter) | |||||
| del self.data_batch | del self.data_batch | ||||
| self._iter += 1 | self._iter += 1 | ||||
| self.invoke_hook('after_train_epoch') | |||||
| self.invoke_hook(TrainerStages.after_train_epoch) | |||||
| self._epoch += 1 | self._epoch += 1 | ||||
| time.sleep(1) # wait for some hooks like loggers to finish | time.sleep(1) # wait for some hooks like loggers to finish | ||||
| self.invoke_hook('after_run') | |||||
| self.invoke_hook(TrainerStages.after_run) | |||||
| def evaluation_loop(self, data_loader, checkpoint_path, metric_classes): | def evaluation_loop(self, data_loader, checkpoint_path, metric_classes): | ||||
| """ Evaluation loop used by `EpochBasedTrainer.evaluate()`. | """ Evaluation loop used by `EpochBasedTrainer.evaluate()`. | ||||
| @@ -693,6 +706,9 @@ class EpochBasedTrainer(BaseTrainer): | |||||
| stage_hook_infos.append(info) | stage_hook_infos.append(info) | ||||
| return '\n'.join(stage_hook_infos) | return '\n'.join(stage_hook_infos) | ||||
| def _get_default_config(self): | |||||
| return DEFAULT_CONFIG | |||||
| def worker_init_fn(worker_id, num_workers, rank, seed): | def worker_init_fn(worker_id, num_workers, rank, seed): | ||||
| # The seed of each worker equals to | # The seed of each worker equals to | ||||
| @@ -20,9 +20,9 @@ def single_gpu_test(model, | |||||
| """Test model with a single gpu. | """Test model with a single gpu. | ||||
| Args: | Args: | ||||
| data_collate_fn: An optional data_collate_fn before fed into the model | |||||
| model (nn.Module): Model to be tested. | model (nn.Module): Model to be tested. | ||||
| data_loader (nn.Dataloader): Pytorch data loader. | data_loader (nn.Dataloader): Pytorch data loader. | ||||
| data_collate_fn: An optional data_collate_fn before fed into the model | |||||
| metric_classes(List): List of Metric class that uses to collect metrics | metric_classes(List): List of Metric class that uses to collect metrics | ||||
| Returns: | Returns: | ||||
| @@ -62,10 +62,10 @@ def multi_gpu_test(model, | |||||
| Args: | Args: | ||||
| model (nn.Module): Model to be tested. | model (nn.Module): Model to be tested. | ||||
| data_loader (nn.Dataloader): Pytorch data loader. | data_loader (nn.Dataloader): Pytorch data loader. | ||||
| data_collate_fn: An optional data_collate_fn before fed into the model | |||||
| tmpdir (str): Path of directory to save the temporary results from | tmpdir (str): Path of directory to save the temporary results from | ||||
| different gpus under cpu mode. | different gpus under cpu mode. | ||||
| gpu_collect (bool): Option to use either gpu or cpu to collect results. | gpu_collect (bool): Option to use either gpu or cpu to collect results. | ||||
| data_collate_fn: An optional data_collate_fn before fed into the model | |||||
| metric_classes(List): List of Metric class that uses to collect metrics | metric_classes(List): List of Metric class that uses to collect metrics | ||||
| Returns: | Returns: | ||||
| @@ -1,6 +1,5 @@ | |||||
| # Copyright (c) Alibaba, Inc. and its affiliates. | # Copyright (c) Alibaba, Inc. and its affiliates. | ||||
| import ast | |||||
| import copy | import copy | ||||
| import os | import os | ||||
| import os.path as osp | import os.path as osp | ||||
| @@ -9,24 +8,15 @@ import shutil | |||||
| import sys | import sys | ||||
| import tempfile | import tempfile | ||||
| import types | import types | ||||
| import uuid | |||||
| from importlib import import_module | |||||
| from pathlib import Path | from pathlib import Path | ||||
| from typing import Dict | from typing import Dict | ||||
| import addict | import addict | ||||
| from yapf.yapflib.yapf_api import FormatCode | from yapf.yapflib.yapf_api import FormatCode | ||||
| from modelscope.utils.import_utils import (import_modules, | |||||
| import_modules_from_file, | |||||
| validate_py_syntax) | |||||
| from modelscope.utils.import_utils import import_modules_from_file | |||||
| from modelscope.utils.logger import get_logger | from modelscope.utils.logger import get_logger | ||||
| if platform.system() == 'Windows': | |||||
| import regex as re # type: ignore | |||||
| else: | |||||
| import re # type: ignore | |||||
| logger = get_logger() | logger = get_logger() | ||||
| BASE_KEY = '_base_' | BASE_KEY = '_base_' | ||||
| @@ -380,8 +370,8 @@ class Config: | |||||
| file_format = file.split('.')[-1] | file_format = file.split('.')[-1] | ||||
| return dump(cfg_dict, file=file, file_format=file_format) | return dump(cfg_dict, file=file, file_format=file_format) | ||||
| def merge_from_dict(self, options, allow_list_keys=True): | |||||
| """Merge list into cfg_dict. | |||||
| def merge_from_dict(self, options, allow_list_keys=True, force=True): | |||||
| """Merge dict into cfg_dict. | |||||
| Merge the dict parsed by MultipleKVAction into this cfg. | Merge the dict parsed by MultipleKVAction into this cfg. | ||||
| @@ -392,9 +382,9 @@ class Config: | |||||
| >>> cfg.merge_from_dict(options) | >>> cfg.merge_from_dict(options) | ||||
| >>> cfg_dict = super(Config, self).__getattribute__('_cfg_dict') | >>> cfg_dict = super(Config, self).__getattribute__('_cfg_dict') | ||||
| >>> assert cfg_dict == dict( | >>> assert cfg_dict == dict( | ||||
| ... model=dict(backbone=dict(depth=50, with_cp=True))) | |||||
| ... model=dict(backbone=dict(type='ResNet', depth=50, with_cp=True))) | |||||
| >>> # Merge list element | |||||
| >>> # Merge list element for replace target index | |||||
| >>> cfg = Config(dict(pipeline=[ | >>> cfg = Config(dict(pipeline=[ | ||||
| ... dict(type='Resize'), dict(type='RandomDistortion')])) | ... dict(type='Resize'), dict(type='RandomDistortion')])) | ||||
| >>> options = dict(pipeline={'0': dict(type='MyResize')}) | >>> options = dict(pipeline={'0': dict(type='MyResize')}) | ||||
| @@ -403,12 +393,38 @@ class Config: | |||||
| >>> assert cfg_dict == dict(pipeline=[ | >>> assert cfg_dict == dict(pipeline=[ | ||||
| ... dict(type='MyResize'), dict(type='RandomDistortion')]) | ... dict(type='MyResize'), dict(type='RandomDistortion')]) | ||||
| >>> # Merge list element for replace args and add to list, only support list of type dict with key ``type``, | |||||
| >>> # if you add new list element, the list does not guarantee the order, | |||||
| >>> # it is only suitable for the case where the order of the list is not concerned. | |||||
| >>> cfg = Config(dict(pipeline=[ | |||||
| ... dict(type='Resize', size=224), dict(type='RandomDistortion')])) | |||||
| >>> options = dict(pipeline=[dict(type='Resize', size=256), dict(type='RandomFlip')]) | |||||
| >>> cfg.merge_from_dict(options, allow_list_keys=True) | |||||
| >>> cfg_dict = super(Config, self).__getattribute__('_cfg_dict') | |||||
| >>> assert cfg_dict == dict(pipeline=[ | |||||
| ... dict(type='Resize', size=256), dict(type='RandomDistortion'), dict(type='RandomFlip')]) | |||||
| >>> # force usage | |||||
| >>> options = {'model.backbone.depth': 18, | |||||
| ... 'model.backbone.with_cp':True} | |||||
| >>> cfg = Config(dict(model=dict(backbone=dict(type='ResNet', depth=50)))) | |||||
| >>> cfg.merge_from_dict(options, force=False) | |||||
| >>> cfg_dict = super(Config, self).__getattribute__('_cfg_dict') | |||||
| >>> assert cfg_dict == dict( | |||||
| ... model=dict(backbone=dict(type='ResNet', depth=50, with_cp=True))) | |||||
| Args: | Args: | ||||
| options (dict): dict of configs to merge from. | options (dict): dict of configs to merge from. | ||||
| allow_list_keys (bool): If True, int string keys (e.g. '0', '1') | allow_list_keys (bool): If True, int string keys (e.g. '0', '1') | ||||
| are allowed in ``options`` and will replace the element of the | are allowed in ``options`` and will replace the element of the | ||||
| corresponding index in the config if the config is a list. | corresponding index in the config if the config is a list. | ||||
| Or you can directly replace args for list or add new list element, | |||||
| only support list of type dict with key ``type``, | |||||
| but if you add new list element, the list does not guarantee the order, | |||||
| It is only suitable for the case where the order of the list is not concerned. | |||||
| Default: True. | Default: True. | ||||
| force (bool): If True, existing key-value will be replaced by new given. | |||||
| If False, existing key-value will not be updated. | |||||
| """ | """ | ||||
| option_cfg_dict = {} | option_cfg_dict = {} | ||||
| for full_key, v in options.items(): | for full_key, v in options.items(): | ||||
| @@ -424,7 +440,122 @@ class Config: | |||||
| super(Config, self).__setattr__( | super(Config, self).__setattr__( | ||||
| '_cfg_dict', | '_cfg_dict', | ||||
| Config._merge_a_into_b( | Config._merge_a_into_b( | ||||
| option_cfg_dict, cfg_dict, allow_list_keys=allow_list_keys)) | |||||
| option_cfg_dict, | |||||
| cfg_dict, | |||||
| allow_list_keys=allow_list_keys, | |||||
| force=force)) | |||||
| @staticmethod | |||||
| def _merge_a_into_b(a, b, allow_list_keys=False, force=True): | |||||
| """merge dict ``a`` into dict ``b`` (non-inplace). | |||||
| Values in ``a`` will overwrite ``b``. ``b`` is copied first to avoid | |||||
| in-place modifications. | |||||
| Args: | |||||
| a (dict): The source dict to be merged into ``b``. | |||||
| b (dict): The origin dict to be fetch keys from ``a``. | |||||
| allow_list_keys (bool): If True, int string keys (e.g. '0', '1') | |||||
| are allowed in source ``a`` and will replace the element of the | |||||
| corresponding index in b if b is a list. Default: False. | |||||
| force (bool): If True, existing key-value will be replaced by new given. | |||||
| If False, existing key-value will not be updated. | |||||
| Returns: | |||||
| dict: The modified dict of ``b`` using ``a``. | |||||
| Examples: | |||||
| # Normally merge a into b. | |||||
| >>> Config._merge_a_into_b( | |||||
| ... dict(obj=dict(a=2)), dict(obj=dict(a=1))) | |||||
| {'obj': {'a': 2}} | |||||
| # Delete b first and merge a into b. | |||||
| >>> Config._merge_a_into_b( | |||||
| ... dict(obj=dict(_delete_=True, a=2)), dict(obj=dict(a=1))) | |||||
| {'obj': {'a': 2}} | |||||
| # b is a list | |||||
| >>> Config._merge_a_into_b( | |||||
| ... {'0': dict(a=2)}, [dict(a=1), dict(b=2)], True) | |||||
| [{'a': 2}, {'b': 2}] | |||||
| # value of a and b are both list, only support list of type dict with key ``type``, | |||||
| # You can directly replace args for list or add new list element, | |||||
| # but if you add new list element, the list does not guarantee the order, | |||||
| # it is only suitable for the case where the order of the list is not concerned. | |||||
| >>> Config._merge_a_into_b( | |||||
| ... {'k': [dict(a=2), dict(c=3)]}, {'k': [dict(a=1), dict(b=2)]}, True) | |||||
| {'k': [dict(a=2), dict(b=2), dict(c=3)]} | |||||
| # force is False | |||||
| >>> Config._merge_a_into_b( | |||||
| ... dict(obj=dict(a=2, b=2)), dict(obj=dict(a=1))), True, force=False) | |||||
| {'obj': {'a': 1, b=2}} | |||||
| """ | |||||
| b = b.copy() | |||||
| for k, v in a.items(): | |||||
| if allow_list_keys and k.isdigit() and isinstance(b, list): | |||||
| k = int(k) | |||||
| if len(b) <= k: | |||||
| raise KeyError(f'Index {k} exceeds the length of list {b}') | |||||
| b[k] = Config._merge_a_into_b( | |||||
| v, b[k], allow_list_keys, force=force) | |||||
| elif allow_list_keys and isinstance(v, list) and k in b: | |||||
| if not isinstance(b[k], list): | |||||
| raise ValueError( | |||||
| f'type mismatch {type(v)} and {type(b[k])} between a and b for key {k}' | |||||
| ) | |||||
| _is_dict_with_type = True | |||||
| for list_i in b[k] + v: | |||||
| if not isinstance(list_i, dict) or 'type' not in list_i: | |||||
| if k not in b or force: | |||||
| b[k] = v | |||||
| _is_dict_with_type = False | |||||
| if _is_dict_with_type: | |||||
| res_list = [] | |||||
| added_index_bk, added_index_v = [], [] | |||||
| for i, b_li in enumerate(b[k]): | |||||
| for j, a_lj in enumerate(v): | |||||
| if a_lj['type'] == b_li['type']: | |||||
| res_list.append( | |||||
| Config._merge_a_into_b( | |||||
| a_lj, | |||||
| b_li, | |||||
| allow_list_keys, | |||||
| force=force)) | |||||
| added_index_v.append(j) | |||||
| added_index_bk.append(i) | |||||
| break | |||||
| rest_bk = [ | |||||
| b[k][i] for i in range(len(b[k])) | |||||
| if i not in added_index_bk | |||||
| ] | |||||
| rest_v = [ | |||||
| v[i] for i in range(len(v)) if i not in added_index_v | |||||
| ] | |||||
| rest = rest_bk + rest_v | |||||
| res_list += [ | |||||
| Config._merge_a_into_b( | |||||
| rest[i], {}, allow_list_keys, force=force) | |||||
| for i in range(len(rest)) | |||||
| ] | |||||
| b[k] = res_list | |||||
| elif isinstance(v, | |||||
| dict) and k in b and not v.pop(DELETE_KEY, False): | |||||
| allowed_types = (dict, list) if allow_list_keys else dict | |||||
| if not isinstance(b[k], allowed_types): | |||||
| raise TypeError( | |||||
| f'{k}={v} in child config cannot inherit from base ' | |||||
| f'because {k} is a dict in the child config but is of ' | |||||
| f'type {type(b[k])} in base config. You may set ' | |||||
| f'`{DELETE_KEY}=True` to ignore the base config') | |||||
| b[k] = Config._merge_a_into_b( | |||||
| v, b[k], allow_list_keys, force=force) | |||||
| else: | |||||
| if k not in b or force: | |||||
| b[k] = v | |||||
| return b | |||||
| def to_dict(self) -> Dict: | def to_dict(self) -> Dict: | ||||
| """ Convert Config object to python dict | """ Convert Config object to python dict | ||||
| @@ -163,3 +163,33 @@ PYTORCH = 'pytorch' | |||||
| DEFAULT_MODEL_REVISION = 'master' | DEFAULT_MODEL_REVISION = 'master' | ||||
| DEFAULT_DATASET_REVISION = 'master' | DEFAULT_DATASET_REVISION = 'master' | ||||
| class ModeKeys: | |||||
| TRAIN = 'train' | |||||
| EVAL = 'eval' | |||||
| class LogKeys: | |||||
| ITER = 'iter' | |||||
| ITER_TIME = 'iter_time' | |||||
| EPOCH = 'epoch' | |||||
| LR = 'lr' # learning rate | |||||
| MODE = 'mode' | |||||
| DATA_LOAD_TIME = 'data_load_time' | |||||
| ETA = 'eta' # estimated time of arrival | |||||
| MEMORY = 'memory' | |||||
| LOSS = 'loss' | |||||
| class TrainerStages: | |||||
| before_run = 'before_run' | |||||
| before_train_epoch = 'before_train_epoch' | |||||
| before_train_iter = 'before_train_iter' | |||||
| after_train_iter = 'after_train_iter' | |||||
| after_train_epoch = 'after_train_epoch' | |||||
| before_val_epoch = 'before_val_epoch' | |||||
| before_val_iter = 'before_val_iter' | |||||
| after_val_iter = 'after_val_iter' | |||||
| after_val_epoch = 'after_val_epoch' | |||||
| after_run = 'after_run' | |||||
| @@ -0,0 +1,112 @@ | |||||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||||
| import glob | |||||
| import os | |||||
| import shutil | |||||
| import tempfile | |||||
| import unittest | |||||
| from abc import ABCMeta | |||||
| import json | |||||
| import torch | |||||
| from torch import nn | |||||
| from torch.utils.data import Dataset | |||||
| from modelscope.trainers import build_trainer | |||||
| from modelscope.utils.constant import LogKeys, ModelFile | |||||
| class DummyDataset(Dataset, metaclass=ABCMeta): | |||||
| def __len__(self): | |||||
| return 20 | |||||
| def __getitem__(self, idx): | |||||
| return dict(feat=torch.rand((5, )), label=torch.randint(0, 4, (1, ))) | |||||
| class DummyModel(nn.Module): | |||||
| def __init__(self): | |||||
| super().__init__() | |||||
| self.linear = nn.Linear(5, 4) | |||||
| self.bn = nn.BatchNorm1d(4) | |||||
| def forward(self, feat, labels): | |||||
| x = self.linear(feat) | |||||
| x = self.bn(x) | |||||
| loss = torch.sum(x) | |||||
| return dict(logits=x, loss=loss) | |||||
| class TensorboardHookTest(unittest.TestCase): | |||||
| def setUp(self): | |||||
| print(('Testing %s.%s' % (type(self).__name__, self._testMethodName))) | |||||
| self.tmp_dir = tempfile.TemporaryDirectory().name | |||||
| if not os.path.exists(self.tmp_dir): | |||||
| os.makedirs(self.tmp_dir) | |||||
| def tearDown(self): | |||||
| super().tearDown() | |||||
| shutil.rmtree(self.tmp_dir) | |||||
| def test_tensorboard_hook(self): | |||||
| json_cfg = { | |||||
| 'task': 'image_classification', | |||||
| 'train': { | |||||
| 'work_dir': self.tmp_dir, | |||||
| 'dataloader': { | |||||
| 'batch_size_per_gpu': 2, | |||||
| 'workers_per_gpu': 1 | |||||
| }, | |||||
| 'optimizer': { | |||||
| 'type': 'SGD', | |||||
| 'lr': 0.01 | |||||
| }, | |||||
| 'lr_scheduler': { | |||||
| 'type': 'StepLR', | |||||
| 'step_size': 2, | |||||
| }, | |||||
| 'hooks': [{ | |||||
| 'type': 'TensorboardHook', | |||||
| 'interval': 2 | |||||
| }] | |||||
| } | |||||
| } | |||||
| config_path = os.path.join(self.tmp_dir, ModelFile.CONFIGURATION) | |||||
| with open(config_path, 'w') as f: | |||||
| json.dump(json_cfg, f) | |||||
| trainer_name = 'EpochBasedTrainer' | |||||
| kwargs = dict( | |||||
| cfg_file=config_path, | |||||
| model=DummyModel(), | |||||
| data_collator=None, | |||||
| train_dataset=DummyDataset(), | |||||
| max_epochs=2) | |||||
| trainer = build_trainer(trainer_name, kwargs) | |||||
| trainer.train() | |||||
| tb_out_dir = os.path.join(self.tmp_dir, 'tensorboard_output') | |||||
| events_files = glob.glob( | |||||
| os.path.join(tb_out_dir, 'events.out.tfevents.*')) | |||||
| self.assertEqual(len(events_files), 1) | |||||
| from tensorboard.backend.event_processing.event_accumulator import EventAccumulator | |||||
| ea = EventAccumulator(events_files[0]) | |||||
| ea.Reload() | |||||
| self.assertEqual(len(ea.Scalars(LogKeys.LOSS)), 10) | |||||
| self.assertEqual(len(ea.Scalars(LogKeys.LR)), 10) | |||||
| for i in range(5): | |||||
| self.assertAlmostEqual( | |||||
| ea.Scalars(LogKeys.LR)[i].value, 0.01, delta=0.001) | |||||
| for i in range(5, 10): | |||||
| self.assertAlmostEqual( | |||||
| ea.Scalars(LogKeys.LR)[i].value, 0.001, delta=0.0001) | |||||
| if __name__ == '__main__': | |||||
| unittest.main() | |||||
| @@ -11,7 +11,7 @@ from torch import nn | |||||
| from torch.utils.data import Dataset | from torch.utils.data import Dataset | ||||
| from modelscope.trainers import build_trainer | from modelscope.trainers import build_trainer | ||||
| from modelscope.utils.constant import ModelFile | |||||
| from modelscope.utils.constant import LogKeys, ModelFile | |||||
| class DummyDataset(Dataset, metaclass=ABCMeta): | class DummyDataset(Dataset, metaclass=ABCMeta): | ||||
| @@ -100,8 +100,8 @@ class CheckpointHookTest(unittest.TestCase): | |||||
| trainer = build_trainer(trainer_name, kwargs) | trainer = build_trainer(trainer_name, kwargs) | ||||
| trainer.train() | trainer.train() | ||||
| results_files = os.listdir(self.tmp_dir) | results_files = os.listdir(self.tmp_dir) | ||||
| self.assertIn('epoch_1.pth', results_files) | |||||
| self.assertIn('epoch_2.pth', results_files) | |||||
| self.assertIn(f'{LogKeys.EPOCH}_1.pth', results_files) | |||||
| self.assertIn(f'{LogKeys.EPOCH}_2.pth', results_files) | |||||
| if __name__ == '__main__': | if __name__ == '__main__': | ||||
| @@ -0,0 +1,195 @@ | |||||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||||
| import os | |||||
| import shutil | |||||
| import tempfile | |||||
| import unittest | |||||
| from abc import ABCMeta | |||||
| import json | |||||
| import torch | |||||
| from torch import nn | |||||
| from torch.utils.data import Dataset | |||||
| from modelscope.metrics.builder import METRICS, MetricKeys | |||||
| from modelscope.trainers import build_trainer | |||||
| from modelscope.utils.constant import LogKeys, ModelFile | |||||
| from modelscope.utils.registry import default_group | |||||
| _global_iter = 0 | |||||
| @METRICS.register_module(group_key=default_group, module_name='DummyMetric') | |||||
| class DummyMetric: | |||||
| _fake_acc_by_epoch = {1: 0.1, 2: 0.5, 3: 0.2} | |||||
| def add(*args, **kwargs): | |||||
| pass | |||||
| def evaluate(self): | |||||
| global _global_iter | |||||
| _global_iter += 1 | |||||
| return {MetricKeys.ACCURACY: self._fake_acc_by_epoch[_global_iter]} | |||||
| class DummyDataset(Dataset, metaclass=ABCMeta): | |||||
| def __len__(self): | |||||
| return 20 | |||||
| def __getitem__(self, idx): | |||||
| return dict(feat=torch.rand((5, )), label=torch.randint(0, 4, (1, ))) | |||||
| class DummyModel(nn.Module): | |||||
| def __init__(self): | |||||
| super().__init__() | |||||
| self.linear = nn.Linear(5, 4) | |||||
| self.bn = nn.BatchNorm1d(4) | |||||
| def forward(self, feat, labels): | |||||
| x = self.linear(feat) | |||||
| x = self.bn(x) | |||||
| loss = torch.sum(x) | |||||
| return dict(logits=x, loss=loss) | |||||
| class EvaluationHookTest(unittest.TestCase): | |||||
| def setUp(self): | |||||
| print(('Testing %s.%s' % (type(self).__name__, self._testMethodName))) | |||||
| self.tmp_dir = tempfile.TemporaryDirectory().name | |||||
| if not os.path.exists(self.tmp_dir): | |||||
| os.makedirs(self.tmp_dir) | |||||
| def tearDown(self): | |||||
| super().tearDown() | |||||
| shutil.rmtree(self.tmp_dir) | |||||
| def test_best_ckpt_rule_max(self): | |||||
| global _global_iter | |||||
| _global_iter = 0 | |||||
| json_cfg = { | |||||
| 'task': 'image_classification', | |||||
| 'train': { | |||||
| 'work_dir': | |||||
| self.tmp_dir, | |||||
| 'dataloader': { | |||||
| 'batch_size_per_gpu': 2, | |||||
| 'workers_per_gpu': 1 | |||||
| }, | |||||
| 'optimizer': { | |||||
| 'type': 'SGD', | |||||
| 'lr': 0.01, | |||||
| }, | |||||
| 'lr_scheduler': { | |||||
| 'type': 'StepLR', | |||||
| 'step_size': 2, | |||||
| }, | |||||
| 'hooks': [{ | |||||
| 'type': 'EvaluationHook', | |||||
| 'interval': 1, | |||||
| 'save_best_ckpt': True, | |||||
| 'monitor_key': MetricKeys.ACCURACY | |||||
| }] | |||||
| }, | |||||
| 'evaluation': { | |||||
| 'dataloader': { | |||||
| 'batch_size_per_gpu': 2, | |||||
| 'workers_per_gpu': 1, | |||||
| 'shuffle': False | |||||
| }, | |||||
| 'metrics': ['DummyMetric'] | |||||
| } | |||||
| } | |||||
| config_path = os.path.join(self.tmp_dir, ModelFile.CONFIGURATION) | |||||
| with open(config_path, 'w') as f: | |||||
| json.dump(json_cfg, f) | |||||
| trainer_name = 'EpochBasedTrainer' | |||||
| kwargs = dict( | |||||
| cfg_file=config_path, | |||||
| model=DummyModel(), | |||||
| data_collator=None, | |||||
| train_dataset=DummyDataset(), | |||||
| eval_dataset=DummyDataset(), | |||||
| max_epochs=3) | |||||
| trainer = build_trainer(trainer_name, kwargs) | |||||
| trainer.train() | |||||
| results_files = os.listdir(self.tmp_dir) | |||||
| self.assertIn(f'{LogKeys.EPOCH}_1.pth', results_files) | |||||
| self.assertIn(f'{LogKeys.EPOCH}_2.pth', results_files) | |||||
| self.assertIn(f'{LogKeys.EPOCH}_3.pth', results_files) | |||||
| self.assertIn(f'best_{LogKeys.EPOCH}2_{MetricKeys.ACCURACY}0.5.pth', | |||||
| results_files) | |||||
| def test_best_ckpt_rule_min(self): | |||||
| global _global_iter | |||||
| _global_iter = 0 | |||||
| json_cfg = { | |||||
| 'task': 'image_classification', | |||||
| 'train': { | |||||
| 'work_dir': | |||||
| self.tmp_dir, | |||||
| 'dataloader': { | |||||
| 'batch_size_per_gpu': 2, | |||||
| 'workers_per_gpu': 1 | |||||
| }, | |||||
| 'optimizer': { | |||||
| 'type': 'SGD', | |||||
| 'lr': 0.01, | |||||
| }, | |||||
| 'lr_scheduler': { | |||||
| 'type': 'StepLR', | |||||
| 'step_size': 2, | |||||
| }, | |||||
| 'hooks': [{ | |||||
| 'type': 'EvaluationHook', | |||||
| 'interval': 1, | |||||
| 'save_best_ckpt': True, | |||||
| 'monitor_key': 'accuracy', | |||||
| 'rule': 'min', | |||||
| 'out_dir': os.path.join(self.tmp_dir, 'best_ckpt') | |||||
| }] | |||||
| }, | |||||
| 'evaluation': { | |||||
| 'dataloader': { | |||||
| 'batch_size_per_gpu': 2, | |||||
| 'workers_per_gpu': 1, | |||||
| 'shuffle': False | |||||
| }, | |||||
| 'metrics': ['DummyMetric'] | |||||
| } | |||||
| } | |||||
| config_path = os.path.join(self.tmp_dir, ModelFile.CONFIGURATION) | |||||
| with open(config_path, 'w') as f: | |||||
| json.dump(json_cfg, f) | |||||
| trainer_name = 'EpochBasedTrainer' | |||||
| kwargs = dict( | |||||
| cfg_file=config_path, | |||||
| model=DummyModel(), | |||||
| data_collator=None, | |||||
| train_dataset=DummyDataset(), | |||||
| eval_dataset=DummyDataset(), | |||||
| max_epochs=3) | |||||
| trainer = build_trainer(trainer_name, kwargs) | |||||
| trainer.train() | |||||
| results_files = os.listdir(self.tmp_dir) | |||||
| self.assertIn(f'{LogKeys.EPOCH}_1.pth', results_files) | |||||
| self.assertIn(f'{LogKeys.EPOCH}_2.pth', results_files) | |||||
| self.assertIn(f'{LogKeys.EPOCH}_3.pth', results_files) | |||||
| self.assertIn(f'best_{LogKeys.EPOCH}1_{MetricKeys.ACCURACY}0.1.pth', | |||||
| os.listdir(os.path.join(self.tmp_dir, 'best_ckpt'))) | |||||
| if __name__ == '__main__': | |||||
| unittest.main() | |||||
| @@ -13,7 +13,7 @@ from torch.optim.lr_scheduler import MultiStepLR | |||||
| from torch.utils.data import Dataset | from torch.utils.data import Dataset | ||||
| from modelscope.trainers import build_trainer | from modelscope.trainers import build_trainer | ||||
| from modelscope.utils.constant import ModelFile | |||||
| from modelscope.utils.constant import LogKeys, ModelFile, TrainerStages | |||||
| class DummyDataset(Dataset, metaclass=ABCMeta): | class DummyDataset(Dataset, metaclass=ABCMeta): | ||||
| @@ -66,7 +66,7 @@ class LrSchedulerHookTest(unittest.TestCase): | |||||
| } | } | ||||
| } | } | ||||
| config_path = os.path.join(self.tmp_dir, 'config.json') | |||||
| config_path = os.path.join(self.tmp_dir, ModelFile.CONFIGURATION) | |||||
| with open(config_path, 'w') as f: | with open(config_path, 'w') as f: | ||||
| json.dump(json_cfg, f) | json.dump(json_cfg, f) | ||||
| @@ -86,23 +86,23 @@ class LrSchedulerHookTest(unittest.TestCase): | |||||
| trainer.train_dataset, **trainer.cfg.train.get('dataloader', {})) | trainer.train_dataset, **trainer.cfg.train.get('dataloader', {})) | ||||
| trainer.register_optimizers_hook() | trainer.register_optimizers_hook() | ||||
| trainer.invoke_hook('before_run') | |||||
| trainer.invoke_hook(TrainerStages.before_run) | |||||
| log_lrs = [] | log_lrs = [] | ||||
| optim_lrs = [] | optim_lrs = [] | ||||
| for _ in range(trainer._epoch, trainer._max_epochs): | for _ in range(trainer._epoch, trainer._max_epochs): | ||||
| trainer.invoke_hook('before_train_epoch') | |||||
| trainer.invoke_hook(TrainerStages.before_train_epoch) | |||||
| for _, data_batch in enumerate(train_dataloader): | for _, data_batch in enumerate(train_dataloader): | ||||
| trainer.invoke_hook('before_train_iter') | |||||
| trainer.invoke_hook(TrainerStages.before_train_iter) | |||||
| log_lrs.append(trainer.log_buffer.output['lr']) | |||||
| log_lrs.append(trainer.log_buffer.output[LogKeys.LR]) | |||||
| optim_lrs.append(optimizer.param_groups[0]['lr']) | optim_lrs.append(optimizer.param_groups[0]['lr']) | ||||
| trainer.train_step(trainer.model, data_batch) | trainer.train_step(trainer.model, data_batch) | ||||
| trainer.invoke_hook('after_train_iter') | |||||
| trainer.invoke_hook(TrainerStages.after_train_iter) | |||||
| trainer.invoke_hook('after_train_epoch') | |||||
| trainer.invoke_hook(TrainerStages.after_train_epoch) | |||||
| trainer._epoch += 1 | trainer._epoch += 1 | ||||
| trainer.invoke_hook('after_run') | |||||
| trainer.invoke_hook(TrainerStages.after_run) | |||||
| iters = 5 | iters = 5 | ||||
| target_lrs = [0.01] * iters * 1 + [0.001] * iters * 2 + [0.0001 | target_lrs = [0.01] * iters * 1 + [0.001] * iters * 2 + [0.0001 | ||||
| @@ -157,23 +157,23 @@ class LrSchedulerHookTest(unittest.TestCase): | |||||
| trainer.train_dataset, **trainer.cfg.train.get('dataloader', {})) | trainer.train_dataset, **trainer.cfg.train.get('dataloader', {})) | ||||
| trainer.register_optimizers_hook() | trainer.register_optimizers_hook() | ||||
| trainer.invoke_hook('before_run') | |||||
| trainer.invoke_hook(TrainerStages.before_run) | |||||
| log_lrs = [] | log_lrs = [] | ||||
| optim_lrs = [] | optim_lrs = [] | ||||
| for _ in range(trainer._epoch, trainer._max_epochs): | for _ in range(trainer._epoch, trainer._max_epochs): | ||||
| trainer.invoke_hook('before_train_epoch') | |||||
| trainer.invoke_hook(TrainerStages.before_train_epoch) | |||||
| for _, data_batch in enumerate(train_dataloader): | for _, data_batch in enumerate(train_dataloader): | ||||
| trainer.invoke_hook('before_train_iter') | |||||
| trainer.invoke_hook(TrainerStages.before_train_iter) | |||||
| log_lrs.append(round(trainer.log_buffer.output['lr'], 5)) | |||||
| log_lrs.append(round(trainer.log_buffer.output[LogKeys.LR], 5)) | |||||
| optim_lrs.append( | optim_lrs.append( | ||||
| round(trainer.optimizer.param_groups[0]['lr'], 5)) | round(trainer.optimizer.param_groups[0]['lr'], 5)) | ||||
| trainer.train_step(trainer.model, data_batch) | trainer.train_step(trainer.model, data_batch) | ||||
| trainer.invoke_hook('after_train_iter') | |||||
| trainer.invoke_hook(TrainerStages.after_train_iter) | |||||
| trainer.invoke_hook('after_train_epoch') | |||||
| trainer.invoke_hook('after_run') | |||||
| trainer.invoke_hook(TrainerStages.after_train_epoch) | |||||
| trainer.invoke_hook(TrainerStages.after_run) | |||||
| iters = 5 | iters = 5 | ||||
| target_lrs = [0.004] * iters * 1 + [0.007] * iters * 1 + [ | target_lrs = [0.004] * iters * 1 + [0.007] * iters * 1 + [ | ||||
| @@ -0,0 +1,184 @@ | |||||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||||
| import os | |||||
| import shutil | |||||
| import tempfile | |||||
| import unittest | |||||
| from abc import ABCMeta | |||||
| import json | |||||
| import torch | |||||
| from torch import nn | |||||
| from torch.optim import SGD | |||||
| from torch.optim.lr_scheduler import MultiStepLR | |||||
| from torch.utils.data import Dataset | |||||
| from modelscope.trainers import build_trainer | |||||
| from modelscope.utils.constant import ModelFile, TrainerStages | |||||
| class DummyDataset(Dataset, metaclass=ABCMeta): | |||||
| """Base Dataset | |||||
| """ | |||||
| def __len__(self): | |||||
| return 10 | |||||
| def __getitem__(self, idx): | |||||
| return dict(feat=torch.rand((2, 2)), label=torch.randint(0, 2, (1, ))) | |||||
| class DummyModel(nn.Module): | |||||
| def __init__(self): | |||||
| super().__init__() | |||||
| self.linear = nn.Linear(2, 2) | |||||
| self.bn = nn.BatchNorm1d(2) | |||||
| def forward(self, feat, labels): | |||||
| x = self.linear(feat) | |||||
| x = self.bn(x) | |||||
| loss = torch.sum(x) | |||||
| return dict(logits=x, loss=loss) | |||||
| class OptimizerHookTest(unittest.TestCase): | |||||
| def setUp(self): | |||||
| print(('Testing %s.%s' % (type(self).__name__, self._testMethodName))) | |||||
| self.tmp_dir = tempfile.TemporaryDirectory().name | |||||
| if not os.path.exists(self.tmp_dir): | |||||
| os.makedirs(self.tmp_dir) | |||||
| def tearDown(self): | |||||
| super().tearDown() | |||||
| shutil.rmtree(self.tmp_dir) | |||||
| def test_optimizer_hook(self): | |||||
| json_cfg = { | |||||
| 'task': 'image_classification', | |||||
| 'train': { | |||||
| 'work_dir': self.tmp_dir, | |||||
| 'dataloader': { | |||||
| 'batch_size_per_gpu': 2, | |||||
| 'workers_per_gpu': 1 | |||||
| } | |||||
| } | |||||
| } | |||||
| config_path = os.path.join(self.tmp_dir, ModelFile.CONFIGURATION) | |||||
| with open(config_path, 'w') as f: | |||||
| json.dump(json_cfg, f) | |||||
| model = DummyModel() | |||||
| optimizer = SGD(model.parameters(), lr=0.01) | |||||
| lr_scheduler = MultiStepLR(optimizer, milestones=[1, 2]) | |||||
| trainer_name = 'EpochBasedTrainer' | |||||
| kwargs = dict( | |||||
| cfg_file=config_path, | |||||
| model=model, | |||||
| train_dataset=DummyDataset(), | |||||
| optimizers=(optimizer, lr_scheduler), | |||||
| max_epochs=2) | |||||
| trainer = build_trainer(trainer_name, kwargs) | |||||
| train_dataloader = trainer._build_dataloader_with_dataset( | |||||
| trainer.train_dataset, **trainer.cfg.train.get('dataloader', {})) | |||||
| trainer.register_optimizers_hook() | |||||
| trainer.invoke_hook(TrainerStages.before_run) | |||||
| for _ in range(trainer._epoch, trainer._max_epochs): | |||||
| trainer.invoke_hook(TrainerStages.before_train_epoch) | |||||
| for _, data_batch in enumerate(train_dataloader): | |||||
| trainer.invoke_hook(TrainerStages.before_train_iter) | |||||
| trainer.train_step(trainer.model, data_batch) | |||||
| trainer.invoke_hook(TrainerStages.after_train_iter) | |||||
| self.assertEqual( | |||||
| len(trainer.optimizer.param_groups[0]['params']), 4) | |||||
| for i in range(4): | |||||
| self.assertTrue(trainer.optimizer.param_groups[0]['params'] | |||||
| [i].requires_grad) | |||||
| trainer.invoke_hook(TrainerStages.after_train_epoch) | |||||
| trainer._epoch += 1 | |||||
| trainer.invoke_hook(TrainerStages.after_run) | |||||
| class TorchAMPOptimizerHookTest(unittest.TestCase): | |||||
| def setUp(self): | |||||
| print(('Testing %s.%s' % (type(self).__name__, self._testMethodName))) | |||||
| self.tmp_dir = tempfile.TemporaryDirectory().name | |||||
| if not os.path.exists(self.tmp_dir): | |||||
| os.makedirs(self.tmp_dir) | |||||
| def tearDown(self): | |||||
| super().tearDown() | |||||
| shutil.rmtree(self.tmp_dir) | |||||
| def test_amp_optimizer_hook(self): | |||||
| json_cfg = { | |||||
| 'task': 'image_classification', | |||||
| 'train': { | |||||
| 'work_dir': self.tmp_dir, | |||||
| 'dataloader': { | |||||
| 'batch_size_per_gpu': 2, | |||||
| 'workers_per_gpu': 1 | |||||
| } | |||||
| } | |||||
| } | |||||
| config_path = os.path.join(self.tmp_dir, ModelFile.CONFIGURATION) | |||||
| with open(config_path, 'w') as f: | |||||
| json.dump(json_cfg, f) | |||||
| model = DummyModel().cuda() | |||||
| optimizer = SGD(model.parameters(), lr=0.01) | |||||
| lr_scheduler = MultiStepLR(optimizer, milestones=[1, 2]) | |||||
| trainer_name = 'EpochBasedTrainer' | |||||
| kwargs = dict( | |||||
| cfg_file=config_path, | |||||
| model=model, | |||||
| train_dataset=DummyDataset(), | |||||
| optimizers=(optimizer, lr_scheduler), | |||||
| max_epochs=2, | |||||
| use_fp16=True) | |||||
| trainer = build_trainer(trainer_name, kwargs) | |||||
| train_dataloader = trainer._build_dataloader_with_dataset( | |||||
| trainer.train_dataset, **trainer.cfg.train.get('dataloader', {})) | |||||
| trainer.register_optimizers_hook() | |||||
| trainer.invoke_hook(TrainerStages.before_run) | |||||
| for _ in range(trainer._epoch, trainer._max_epochs): | |||||
| trainer.invoke_hook(TrainerStages.before_train_epoch) | |||||
| for _, data_batch in enumerate(train_dataloader): | |||||
| for k, v in data_batch.items(): | |||||
| data_batch[k] = v.cuda() | |||||
| trainer.invoke_hook(TrainerStages.before_train_iter) | |||||
| trainer.train_step(trainer.model, data_batch) | |||||
| trainer.invoke_hook(TrainerStages.after_train_iter) | |||||
| self.assertEqual(trainer.train_outputs['logits'].dtype, | |||||
| torch.float16) | |||||
| # test if `after_train_iter`, whether the model is reset to fp32 | |||||
| trainer.train_step(trainer.model, data_batch) | |||||
| self.assertEqual(trainer.train_outputs['logits'].dtype, | |||||
| torch.float32) | |||||
| self.assertEqual( | |||||
| len(trainer.optimizer.param_groups[0]['params']), 4) | |||||
| for i in range(4): | |||||
| self.assertTrue(trainer.optimizer.param_groups[0]['params'] | |||||
| [i].requires_grad) | |||||
| trainer.invoke_hook(TrainerStages.after_train_epoch) | |||||
| trainer._epoch += 1 | |||||
| trainer.invoke_hook(TrainerStages.after_run) | |||||
| if __name__ == '__main__': | |||||
| unittest.main() | |||||
| @@ -13,7 +13,7 @@ from torch.optim.lr_scheduler import MultiStepLR | |||||
| from torch.utils.data import Dataset | from torch.utils.data import Dataset | ||||
| from modelscope.trainers import build_trainer | from modelscope.trainers import build_trainer | ||||
| from modelscope.utils.constant import ModelFile | |||||
| from modelscope.utils.constant import LogKeys, ModelFile, TrainerStages | |||||
| class DummyDataset(Dataset, metaclass=ABCMeta): | class DummyDataset(Dataset, metaclass=ABCMeta): | ||||
| @@ -89,39 +89,43 @@ class IterTimerHookTest(unittest.TestCase): | |||||
| trainer.train_dataset, **trainer.cfg.train.get('dataloader', {})) | trainer.train_dataset, **trainer.cfg.train.get('dataloader', {})) | ||||
| trainer.register_optimizers_hook() | trainer.register_optimizers_hook() | ||||
| trainer.register_hook_from_cfg(trainer.cfg.train.hooks) | trainer.register_hook_from_cfg(trainer.cfg.train.hooks) | ||||
| trainer.invoke_hook('before_run') | |||||
| trainer.data_loader = train_dataloader | |||||
| trainer.invoke_hook(TrainerStages.before_run) | |||||
| for i in range(trainer._epoch, trainer._max_epochs): | for i in range(trainer._epoch, trainer._max_epochs): | ||||
| trainer.invoke_hook('before_train_epoch') | |||||
| trainer.invoke_hook(TrainerStages.before_train_epoch) | |||||
| for _, data_batch in enumerate(train_dataloader): | for _, data_batch in enumerate(train_dataloader): | ||||
| trainer.invoke_hook('before_train_iter') | |||||
| trainer.invoke_hook(TrainerStages.before_train_iter) | |||||
| trainer.train_step(trainer.model, data_batch) | trainer.train_step(trainer.model, data_batch) | ||||
| trainer.invoke_hook('after_train_iter') | |||||
| trainer.invoke_hook(TrainerStages.after_train_iter) | |||||
| self.assertIn('data_load_time', trainer.log_buffer.val_history) | |||||
| self.assertIn('time', trainer.log_buffer.val_history) | |||||
| self.assertIn('loss', trainer.log_buffer.val_history) | |||||
| self.assertIn(LogKeys.DATA_LOAD_TIME, | |||||
| trainer.log_buffer.val_history) | |||||
| self.assertIn(LogKeys.ITER_TIME, | |||||
| trainer.log_buffer.val_history) | |||||
| self.assertIn(LogKeys.LOSS, trainer.log_buffer.val_history) | |||||
| trainer.invoke_hook('after_train_epoch') | |||||
| trainer.invoke_hook(TrainerStages.after_train_epoch) | |||||
| target_len = 5 * (i + 1) | |||||
| target_len = 5 | |||||
| self.assertEqual( | self.assertEqual( | ||||
| len(trainer.log_buffer.val_history['data_load_time']), | |||||
| len(trainer.log_buffer.val_history[LogKeys.DATA_LOAD_TIME]), | |||||
| target_len) | target_len) | ||||
| self.assertEqual( | self.assertEqual( | ||||
| len(trainer.log_buffer.val_history['time']), target_len) | |||||
| len(trainer.log_buffer.val_history[LogKeys.ITER_TIME]), | |||||
| target_len) | |||||
| self.assertEqual( | self.assertEqual( | ||||
| len(trainer.log_buffer.val_history['loss']), target_len) | |||||
| len(trainer.log_buffer.val_history[LogKeys.LOSS]), target_len) | |||||
| self.assertEqual( | self.assertEqual( | ||||
| len(trainer.log_buffer.n_history['data_load_time']), | |||||
| len(trainer.log_buffer.n_history[LogKeys.DATA_LOAD_TIME]), | |||||
| target_len) | target_len) | ||||
| self.assertEqual( | self.assertEqual( | ||||
| len(trainer.log_buffer.n_history['time']), target_len) | |||||
| len(trainer.log_buffer.n_history[LogKeys.ITER_TIME]), | |||||
| target_len) | |||||
| self.assertEqual( | self.assertEqual( | ||||
| len(trainer.log_buffer.n_history['loss']), target_len) | |||||
| len(trainer.log_buffer.n_history[LogKeys.LOSS]), target_len) | |||||
| trainer.invoke_hook('after_run') | |||||
| trainer.invoke_hook(TrainerStages.after_run) | |||||
| if __name__ == '__main__': | if __name__ == '__main__': | ||||
| @@ -12,17 +12,12 @@ from torch.optim import SGD | |||||
| from torch.optim.lr_scheduler import StepLR | from torch.optim.lr_scheduler import StepLR | ||||
| from torch.utils.data import Dataset | from torch.utils.data import Dataset | ||||
| from modelscope.metrics.builder import MetricKeys | |||||
| from modelscope.trainers import build_trainer | from modelscope.trainers import build_trainer | ||||
| from modelscope.utils.constant import ModelFile | |||||
| from modelscope.utils.constant import LogKeys, ModeKeys, ModelFile | |||||
| from modelscope.utils.test_utils import test_level | from modelscope.utils.test_utils import test_level | ||||
| class DummyMetric: | |||||
| def __call__(self, ground_truth, predict_results): | |||||
| return {'accuracy': 0.5} | |||||
| class DummyDataset(Dataset, metaclass=ABCMeta): | class DummyDataset(Dataset, metaclass=ABCMeta): | ||||
| """Base Dataset | """Base Dataset | ||||
| """ | """ | ||||
| @@ -130,9 +125,9 @@ class TrainerTest(unittest.TestCase): | |||||
| results_files = os.listdir(self.tmp_dir) | results_files = os.listdir(self.tmp_dir) | ||||
| self.assertIn(f'{trainer.timestamp}.log.json', results_files) | self.assertIn(f'{trainer.timestamp}.log.json', results_files) | ||||
| self.assertIn('epoch_1.pth', results_files) | |||||
| self.assertIn('epoch_2.pth', results_files) | |||||
| self.assertIn('epoch_3.pth', results_files) | |||||
| self.assertIn(f'{LogKeys.EPOCH}_1.pth', results_files) | |||||
| self.assertIn(f'{LogKeys.EPOCH}_2.pth', results_files) | |||||
| self.assertIn(f'{LogKeys.EPOCH}_3.pth', results_files) | |||||
| @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') | @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') | ||||
| def test_train_1(self): | def test_train_1(self): | ||||
| @@ -167,7 +162,7 @@ class TrainerTest(unittest.TestCase): | |||||
| } | } | ||||
| } | } | ||||
| config_path = os.path.join(self.tmp_dir, 'config.json') | |||||
| config_path = os.path.join(self.tmp_dir, ModelFile.CONFIGURATION) | |||||
| with open(config_path, 'w') as f: | with open(config_path, 'w') as f: | ||||
| json.dump(json_cfg, f) | json.dump(json_cfg, f) | ||||
| @@ -189,9 +184,133 @@ class TrainerTest(unittest.TestCase): | |||||
| results_files = os.listdir(self.tmp_dir) | results_files = os.listdir(self.tmp_dir) | ||||
| self.assertIn(f'{trainer.timestamp}.log.json', results_files) | self.assertIn(f'{trainer.timestamp}.log.json', results_files) | ||||
| self.assertIn('epoch_1.pth', results_files) | |||||
| self.assertIn('epoch_2.pth', results_files) | |||||
| self.assertIn('epoch_3.pth', results_files) | |||||
| self.assertIn(f'{LogKeys.EPOCH}_1.pth', results_files) | |||||
| self.assertIn(f'{LogKeys.EPOCH}_2.pth', results_files) | |||||
| self.assertIn(f'{LogKeys.EPOCH}_3.pth', results_files) | |||||
| @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') | |||||
| def test_train_with_default_config(self): | |||||
| json_cfg = { | |||||
| 'train': { | |||||
| 'work_dir': self.tmp_dir, | |||||
| 'dataloader': { | |||||
| 'batch_size_per_gpu': 2, | |||||
| 'workers_per_gpu': 1 | |||||
| }, | |||||
| 'hooks': [{ | |||||
| 'type': 'EvaluationHook', | |||||
| 'interval': 1 | |||||
| }] | |||||
| }, | |||||
| 'evaluation': { | |||||
| 'dataloader': { | |||||
| 'batch_size_per_gpu': 2, | |||||
| 'workers_per_gpu': 1, | |||||
| 'shuffle': False | |||||
| }, | |||||
| 'metrics': ['seq_cls_metric'] | |||||
| } | |||||
| } | |||||
| class _DummyDataset(DummyDataset): | |||||
| """Base Dataset | |||||
| """ | |||||
| def __len__(self): | |||||
| return 40 | |||||
| config_path = os.path.join(self.tmp_dir, ModelFile.CONFIGURATION) | |||||
| with open(config_path, 'w') as f: | |||||
| json.dump(json_cfg, f) | |||||
| model = DummyModel() | |||||
| optimmizer = SGD(model.parameters(), lr=0.01) | |||||
| lr_scheduler = StepLR(optimmizer, 2) | |||||
| trainer_name = 'EpochBasedTrainer' | |||||
| kwargs = dict( | |||||
| cfg_file=config_path, | |||||
| model=model, | |||||
| data_collator=None, | |||||
| train_dataset=_DummyDataset(), | |||||
| eval_dataset=DummyDataset(), | |||||
| optimizers=(optimmizer, lr_scheduler), | |||||
| max_epochs=3) | |||||
| trainer = build_trainer(trainer_name, kwargs) | |||||
| trainer.train() | |||||
| results_files = os.listdir(self.tmp_dir) | |||||
| json_file = os.path.join(self.tmp_dir, f'{trainer.timestamp}.log.json') | |||||
| with open(json_file, 'r') as f: | |||||
| lines = [i.strip() for i in f.readlines()] | |||||
| self.assertDictContainsSubset( | |||||
| { | |||||
| LogKeys.MODE: ModeKeys.TRAIN, | |||||
| LogKeys.EPOCH: 1, | |||||
| LogKeys.ITER: 10, | |||||
| LogKeys.LR: 0.01 | |||||
| }, json.loads(lines[0])) | |||||
| self.assertDictContainsSubset( | |||||
| { | |||||
| LogKeys.MODE: ModeKeys.TRAIN, | |||||
| LogKeys.EPOCH: 1, | |||||
| LogKeys.ITER: 20, | |||||
| LogKeys.LR: 0.01 | |||||
| }, json.loads(lines[1])) | |||||
| self.assertDictContainsSubset( | |||||
| { | |||||
| LogKeys.MODE: ModeKeys.EVAL, | |||||
| LogKeys.EPOCH: 1, | |||||
| LogKeys.ITER: 20 | |||||
| }, json.loads(lines[2])) | |||||
| self.assertDictContainsSubset( | |||||
| { | |||||
| LogKeys.MODE: ModeKeys.TRAIN, | |||||
| LogKeys.EPOCH: 2, | |||||
| LogKeys.ITER: 10, | |||||
| LogKeys.LR: 0.001 | |||||
| }, json.loads(lines[3])) | |||||
| self.assertDictContainsSubset( | |||||
| { | |||||
| LogKeys.MODE: ModeKeys.TRAIN, | |||||
| LogKeys.EPOCH: 2, | |||||
| LogKeys.ITER: 20, | |||||
| LogKeys.LR: 0.001 | |||||
| }, json.loads(lines[4])) | |||||
| self.assertDictContainsSubset( | |||||
| { | |||||
| LogKeys.MODE: ModeKeys.EVAL, | |||||
| LogKeys.EPOCH: 2, | |||||
| LogKeys.ITER: 20 | |||||
| }, json.loads(lines[5])) | |||||
| self.assertDictContainsSubset( | |||||
| { | |||||
| LogKeys.MODE: ModeKeys.TRAIN, | |||||
| LogKeys.EPOCH: 3, | |||||
| LogKeys.ITER: 10, | |||||
| LogKeys.LR: 0.001 | |||||
| }, json.loads(lines[6])) | |||||
| self.assertDictContainsSubset( | |||||
| { | |||||
| LogKeys.MODE: ModeKeys.TRAIN, | |||||
| LogKeys.EPOCH: 3, | |||||
| LogKeys.ITER: 20, | |||||
| LogKeys.LR: 0.001 | |||||
| }, json.loads(lines[7])) | |||||
| self.assertDictContainsSubset( | |||||
| { | |||||
| LogKeys.MODE: ModeKeys.EVAL, | |||||
| LogKeys.EPOCH: 3, | |||||
| LogKeys.ITER: 20 | |||||
| }, json.loads(lines[8])) | |||||
| self.assertIn(f'{LogKeys.EPOCH}_1.pth', results_files) | |||||
| self.assertIn(f'{LogKeys.EPOCH}_2.pth', results_files) | |||||
| self.assertIn(f'{LogKeys.EPOCH}_3.pth', results_files) | |||||
| for i in [0, 1, 3, 4, 6, 7]: | |||||
| self.assertIn(LogKeys.DATA_LOAD_TIME, lines[i]) | |||||
| self.assertIn(LogKeys.ITER_TIME, lines[i]) | |||||
| for i in [2, 5, 8]: | |||||
| self.assertIn(MetricKeys.ACCURACY, lines[i]) | |||||
| class DummyTrainerTest(unittest.TestCase): | class DummyTrainerTest(unittest.TestCase): | ||||
| @@ -1,5 +1,6 @@ | |||||
| # Copyright (c) Alibaba, Inc. and its affiliates. | # Copyright (c) Alibaba, Inc. and its affiliates. | ||||
| import argparse | import argparse | ||||
| import copy | |||||
| import tempfile | import tempfile | ||||
| import unittest | import unittest | ||||
| @@ -77,6 +78,148 @@ class ConfigTest(unittest.TestCase): | |||||
| self.assertEqual(args.optimizer, 'Adam') | self.assertEqual(args.optimizer, 'Adam') | ||||
| self.assertEqual(args.save_checkpoint_epochs, 20) | self.assertEqual(args.save_checkpoint_epochs, 20) | ||||
| def test_merge_from_dict(self): | |||||
| base_cfg = copy.deepcopy(obj) | |||||
| base_cfg.update({'dict_list': [dict(l1=1), dict(l2=2)]}) | |||||
| cfg = Config(base_cfg) | |||||
| merge_dict = { | |||||
| 'a': 2, | |||||
| 'b.d': 'ee', | |||||
| 'b.c': [3, 3, 3], | |||||
| 'dict_list': { | |||||
| '0': dict(l1=3) | |||||
| }, | |||||
| 'c': 'test' | |||||
| } | |||||
| cfg1 = copy.deepcopy(cfg) | |||||
| cfg1.merge_from_dict(merge_dict) | |||||
| self.assertDictEqual( | |||||
| cfg1._cfg_dict, { | |||||
| 'a': 2, | |||||
| 'b': { | |||||
| 'c': [3, 3, 3], | |||||
| 'd': 'ee' | |||||
| }, | |||||
| 'dict_list': [dict(l1=3), dict(l2=2)], | |||||
| 'c': 'test' | |||||
| }) | |||||
| cfg2 = copy.deepcopy(cfg) | |||||
| cfg2.merge_from_dict(merge_dict, force=False) | |||||
| self.assertDictEqual( | |||||
| cfg2._cfg_dict, { | |||||
| 'a': 1, | |||||
| 'b': { | |||||
| 'c': [1, 2, 3], | |||||
| 'd': 'dd' | |||||
| }, | |||||
| 'dict_list': [dict(l1=1), dict(l2=2)], | |||||
| 'c': 'test' | |||||
| }) | |||||
| def test_merge_from_dict_with_list(self): | |||||
| base_cfg = { | |||||
| 'a': | |||||
| 1, | |||||
| 'b': { | |||||
| 'c': [1, 2, 3], | |||||
| 'd': 'dd' | |||||
| }, | |||||
| 'dict_list': [dict(type='l1', v=1), | |||||
| dict(type='l2', v=2)], | |||||
| 'dict_list2': [ | |||||
| dict( | |||||
| type='l1', | |||||
| v=[dict(type='l1_1', v=1), | |||||
| dict(type='l1_2', v=2)]), | |||||
| dict(type='l2', v=2) | |||||
| ] | |||||
| } | |||||
| cfg = Config(base_cfg) | |||||
| merge_dict_for_list = { | |||||
| 'a': | |||||
| 2, | |||||
| 'b.c': [3, 3, 3], | |||||
| 'b.d': | |||||
| 'ee', | |||||
| 'dict_list': [dict(type='l1', v=8), | |||||
| dict(type='l3', v=8)], | |||||
| 'dict_list2': [ | |||||
| dict( | |||||
| type='l1', | |||||
| v=[ | |||||
| dict(type='l1_1', v=8), | |||||
| dict(type='l1_2', v=2), | |||||
| dict(type='l1_3', v=8), | |||||
| ]), | |||||
| dict(type='l2', v=8) | |||||
| ], | |||||
| 'c': | |||||
| 'test' | |||||
| } | |||||
| cfg1 = copy.deepcopy(cfg) | |||||
| cfg1.merge_from_dict(merge_dict_for_list, force=False) | |||||
| self.assertDictEqual( | |||||
| cfg1._cfg_dict, { | |||||
| 'a': | |||||
| 1, | |||||
| 'b': { | |||||
| 'c': [1, 2, 3], | |||||
| 'd': 'dd' | |||||
| }, | |||||
| 'dict_list': [ | |||||
| dict(type='l1', v=1), | |||||
| dict(type='l2', v=2), | |||||
| dict(type='l3', v=8) | |||||
| ], | |||||
| 'dict_list2': [ | |||||
| dict( | |||||
| type='l1', | |||||
| v=[ | |||||
| dict(type='l1_1', v=1), | |||||
| dict(type='l1_2', v=2), | |||||
| dict(type='l1_3', v=8), | |||||
| ]), | |||||
| dict(type='l2', v=2) | |||||
| ], | |||||
| 'c': | |||||
| 'test' | |||||
| }) | |||||
| cfg2 = copy.deepcopy(cfg) | |||||
| cfg2.merge_from_dict(merge_dict_for_list, force=True) | |||||
| self.assertDictEqual( | |||||
| cfg2._cfg_dict, { | |||||
| 'a': | |||||
| 2, | |||||
| 'b': { | |||||
| 'c': [3, 3, 3], | |||||
| 'd': 'ee' | |||||
| }, | |||||
| 'dict_list': [ | |||||
| dict(type='l1', v=8), | |||||
| dict(type='l2', v=2), | |||||
| dict(type='l3', v=8) | |||||
| ], | |||||
| 'dict_list2': [ | |||||
| dict( | |||||
| type='l1', | |||||
| v=[ | |||||
| dict(type='l1_1', v=8), | |||||
| dict(type='l1_2', v=2), | |||||
| dict(type='l1_3', v=8), | |||||
| ]), | |||||
| dict(type='l2', v=8) | |||||
| ], | |||||
| 'c': | |||||
| 'test' | |||||
| }) | |||||
| if __name__ == '__main__': | if __name__ == '__main__': | ||||
| unittest.main() | unittest.main() | ||||