1. Support the `ReduceLROnPlateau` lr scheduler, and add `PlateauLrSchedulerHook` for it
2. Support custom `optimizer_hook` and `lr_scheduler_hook` configuration
3. Remove the save-best-checkpoint logic from `EvaluationHook` and replace it with a dedicated `BestCkptSaverHook`
4. Make `evaluation_loop` return metric values directly, and move metric computation into `single_gpu_test` and `multi_gpu_test`
Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9584322
* [to #43627720] support ReduceLROnPlateau and fix lr scheduler
Target branch: master
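Taken together, the new hooks are driven entirely from the trainer configuration. A minimal sketch of such a config (key names mirror the hooks and test configs in this diff; the metric name and hyperparameter values are illustrative, not prescribed):

```python
# Illustrative configuration wiring up the new hooks; values are examples only.
cfg = {
    'task': 'image_classification',
    'train': {
        'lr_scheduler': {
            'type': 'ReduceLROnPlateau',  # now registered alongside _LRScheduler subclasses
            'mode': 'max',
            'factor': 0.1,
            'patience': 2,
        },
        # Required when lr_scheduler is ReduceLROnPlateau;
        # register_optimizers_hook raises a ValueError otherwise.
        'lr_scheduler_hook': {
            'type': 'PlateauLrSchedulerHook',
            'metric_key': 'accuracy',
        },
        'hooks': [
            {'type': 'EvaluationHook', 'interval': 1},
            # Keeps only the best checkpoint, e.g. best_epoch2_accuracy0.5.pth
            {'type': 'BestCkptSaverHook', 'metric_key': 'accuracy', 'rule': 'max'},
        ],
    },
}
```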
@@ -1,6 +1,6 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
 from .builder import HOOKS, build_hook
-from .checkpoint_hook import CheckpointHook
+from .checkpoint_hook import BestCkptSaverHook, CheckpointHook
 from .evaluation_hook import EvaluationHook
 from .hook import Hook
 from .iter_timer_hook import IterTimerHook
@@ -13,5 +13,6 @@ from .priority import Priority
 __all__ = [
     'Hook', 'HOOKS', 'CheckpointHook', 'EvaluationHook', 'LrSchedulerHook',
     'OptimizerHook', 'Priority', 'build_hook', 'TextLoggerHook',
-    'IterTimerHook', 'TorchAMPOptimizerHook', 'ApexAMPOptimizerHook'
+    'IterTimerHook', 'TorchAMPOptimizerHook', 'ApexAMPOptimizerHook',
+    'BestCkptSaverHook'
 ]
@@ -42,6 +42,9 @@ class CheckpointHook(Hook):
         if not self.save_dir:
             self.save_dir = trainer.work_dir

+        if not os.path.exists(self.save_dir) and is_master():
+            os.makedirs(self.save_dir)
+
         if not hasattr(trainer, 'logger'):
             self.logger = get_logger(__name__)
         else:
@@ -93,3 +96,72 @@ class CheckpointHook(Hook):
                     and check_last(trainer)):
             return True
         return False
+
+
+@HOOKS.register_module()
+class BestCkptSaverHook(CheckpointHook):
+    """Save best checkpoints hook.
+
+    Args:
+        metric_key (str): Metric key used to select the best score.
+        rule (str): Comparison rule for the best score.
+            Supports "max" and "min". If rule is "max", the checkpoint with the maximum `metric_key`
+            will be saved; if rule is "min", the checkpoint with the minimum `metric_key` will be saved.
+        by_epoch (bool): Save best checkpoints by epoch or by iteration.
+        save_optimizer (bool): Whether to save optimizer state dict. Default: True.
+        save_dir (str): Output directory to save the best checkpoint.
+    """
+
+    PRIORITY = Priority.NORMAL
+    rule_map = {'max': lambda x, y: x > y, 'min': lambda x, y: x < y}
+
+    def __init__(self,
+                 metric_key,
+                 rule='max',
+                 by_epoch=True,
+                 save_optimizer=True,
+                 save_dir=None):
+        assert rule in ['max', 'min'], 'Only support "max" or "min" rule now.'
+        super().__init__(
+            by_epoch=by_epoch,
+            save_optimizer=save_optimizer,
+            save_dir=save_dir,
+        )
+        self.metric_key = metric_key
+        self.rule = rule
+        self._best_metric = None
+        self._best_ckpt_file = None
+
+    def _should_save(self, trainer):
+        return self._is_best_metric(trainer.metric_values)
+
+    def _is_best_metric(self, metric_values):
+        if metric_values is None:
+            return False
+
+        if self.metric_key not in metric_values:
+            raise ValueError(
+                f'Cannot find metric_key: {self.metric_key} in {metric_values}')
+
+        if self._best_metric is None:
+            self._best_metric = metric_values[self.metric_key]
+            return True
+        else:
+            compare_fn = self.rule_map[self.rule]
+            if compare_fn(metric_values[self.metric_key], self._best_metric):
+                self._best_metric = metric_values[self.metric_key]
+                return True
+        return False
+
+    def _save_checkpoint(self, trainer):
+        if self.by_epoch:
+            cur_save_name = os.path.join(
+                self.save_dir,
+                f'best_{LogKeys.EPOCH}{trainer.epoch + 1}_{self.metric_key}{self._best_metric}.pth'
+            )
+        else:
+            cur_save_name = os.path.join(
+                self.save_dir,
+                f'best_{LogKeys.ITER}{trainer.iter + 1}_{self.metric_key}{self._best_metric}.pth'
+            )
+        save_checkpoint(trainer.model, cur_save_name, trainer.optimizer)
+        self._best_ckpt_file = cur_save_name
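The comparison bookkeeping above is compact enough to summarize standalone; a sketch (the helper name `is_best` is made up for illustration):

```python
# Sketch of BestCkptSaverHook's best-metric comparison, mirroring rule_map
# and _is_best_metric above; `is_best` is a hypothetical helper name.
rule_map = {'max': lambda x, y: x > y, 'min': lambda x, y: x < y}

def is_best(new_value, best_value, rule='max'):
    """Return True if new_value should become the new best under `rule`."""
    if best_value is None:
        return True  # the first metric value always becomes the best
    return rule_map[rule](new_value, best_value)

assert is_best(0.5, None)
assert not is_best(0.4, 0.5, rule='max')
assert is_best(0.4, 0.5, rule='min')
```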
@@ -1,13 +1,6 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
-import os
-
-from modelscope.utils.checkpoint import save_checkpoint
-from modelscope.utils.constant import LogKeys
-from modelscope.utils.logger import get_logger
-from modelscope.utils.torch_utils import get_dist_info
 from .builder import HOOKS
 from .hook import Hook
-from .priority import Priority


 @HOOKS.register_module()
@@ -18,56 +11,13 @@ class EvaluationHook(Hook):
         by_epoch (bool): Evaluate by epoch or by iteration.
         start_idx (int | None, optional): The epoch/iterations validation begins.
             Default: None, validate every interval epochs/iterations from scratch.
-        save_best_ckpt (bool): Whether save the best checkpoint during evaluation.
-        monitor_key (str): Monitor key to compare rule for best score, only valid when `save_best_ckpt` is true.
-        rule (str): Comparison rule for best score, only valid when `save_best_ckpt` is true.
-            Support "max" and "min". If rule is "max", the checkpoint at the maximum `monitor_key`
-            will be saved, If rule is "min", the checkpoint at the minimum `monitor_key` will be saved.
-        out_dir (str): Output directory to save best checkpoint.
     """
-    PRIORITY = Priority.NORMAL
-    rule_map = {'max': lambda x, y: x > y, 'min': lambda x, y: x < y}
-
-    def __init__(self,
-                 interval=1,
-                 by_epoch=True,
-                 start_idx=None,
-                 save_best_ckpt=False,
-                 monitor_key=None,
-                 rule='max',
-                 out_dir=None):
+
+    def __init__(self, interval=1, by_epoch=True, start_idx=None):
         assert interval > 0, 'interval must be a positive number'
-        if save_best_ckpt:
-            assert monitor_key is not None, 'Must provide `monitor_key` when `save_best_ckpt` is True.'
-            assert rule in ['max',
-                            'min'], 'Only support "max" or "min" rule now.'
         self.interval = interval
         self.start_idx = start_idx
         self.by_epoch = by_epoch
-        self.save_best_ckpt = save_best_ckpt
-        self.monitor_key = monitor_key
-        self.rule = rule
-        self.out_dir = out_dir
-        self._best_metric = None
-        self._best_ckpt_file = None
-
-    def before_run(self, trainer):
-        if not self.out_dir:
-            self.out_dir = trainer.work_dir
-
-        if not os.path.exists(self.out_dir):
-            rank, _ = get_dist_info()
-            if rank == 0:
-                os.makedirs(self.out_dir)
-
-        if self.save_best_ckpt:
-            if not hasattr(trainer, 'logger'):
-                self.logger = get_logger(__name__)
-            else:
-                self.logger = trainer.logger
-            self.logger.info(
-                f'Best checkpoint will be saved to {self.out_dir}')

     def after_train_iter(self, trainer):
         """Called after every training iter to evaluate the results."""
@@ -87,46 +37,6 @@ class EvaluationHook(Hook):
         trainer.log_buffer.ready = True

-        if self.save_best_ckpt and self._is_best_metric(eval_res):
-            # remove the previous best model and save the latest best model
-            if self._best_ckpt_file is not None and os.path.exists(
-                    self._best_ckpt_file):
-                os.remove(self._best_ckpt_file)
-            self._save_checkpoint(trainer)
-
-    def _is_best_metric(self, eval_res):
-        if self.monitor_key not in eval_res:
-            raise ValueError(
-                f'Not find monitor_key: {self.monitor_key} in {eval_res}')
-        if self._best_metric is None:
-            self._best_metric = eval_res[self.monitor_key]
-            return True
-        else:
-            compare_fn = self.rule_map[self.rule]
-            if compare_fn(eval_res[self.monitor_key], self._best_metric):
-                self._best_metric = eval_res[self.monitor_key]
-                return True
-        return False
-
-    def _save_checkpoint(self, trainer):
-        if self.by_epoch:
-            cur_save_name = os.path.join(
-                self.out_dir,
-                f'best_{LogKeys.EPOCH}{trainer.epoch + 1}_{self.monitor_key}{self._best_metric}.pth'
-            )
-        else:
-            cur_save_name = os.path.join(
-                self.out_dir,
-                f'best_{LogKeys.ITER}{trainer.iter + 1}_{self.monitor_key}{self._best_metric}.pth'
-            )
-        rank, _ = get_dist_info()
-        if rank == 0:
-            save_checkpoint(trainer.model, cur_save_name, trainer.optimizer)
-            self._best_ckpt_file = cur_save_name

     def _should_evaluate(self, trainer):
         """Judge whether to perform evaluation.
@@ -1,6 +1,8 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
 from modelscope.trainers.lrscheduler.builder import build_lr_scheduler
 from modelscope.utils.constant import LogKeys
+from modelscope.utils.logger import get_logger
+from modelscope.utils.torch_utils import is_master
 from .builder import HOOKS
 from .hook import Hook
 from .priority import Priority
@@ -50,12 +52,14 @@ class LrSchedulerHook(Hook):
         trainer.log_buffer.output[LogKeys.LR] = self._get_log_lr(trainer)

     def before_train_epoch(self, trainer):
+        trainer.log_buffer.output[LogKeys.LR] = self._get_log_lr(trainer)
+
+    def after_train_epoch(self, trainer):
         if self.by_epoch:
             if self.warmup_lr_scheduler is not None:
                 self.warmup_lr_scheduler.step()
             else:
                 trainer.lr_scheduler.step()
-        trainer.log_buffer.output[LogKeys.LR] = self._get_log_lr(trainer)

     def _get_log_lr(self, trainer):
         cur_lr = self.get_current_lr(trainer)
@@ -70,3 +74,44 @@ class LrSchedulerHook(Hook):
             lr.update({k: lr_[0]})
         return lr
+
+
+@HOOKS.register_module()
+class PlateauLrSchedulerHook(LrSchedulerHook):
+    """Lr scheduler hook for `ReduceLROnPlateau`.
+
+    Args:
+        metric_key (str): Metric key returned from `trainer.metric_values`;
+            its value is passed to `ReduceLROnPlateau.step`.
+        by_epoch (bool): Whether lr changes by epoch.
+        warmup (dict): Warmup config.
+    """
+    PRIORITY = Priority.LOW  # should be invoked after EvaluationHook
+
+    def __init__(self, metric_key, by_epoch=True, warmup=None) -> None:
+        super().__init__(by_epoch=by_epoch, warmup=warmup)
+        self.metric_key = metric_key
+
+    def before_run(self, trainer):
+        super().before_run(trainer)
+        if not hasattr(trainer, 'logger'):
+            self.logger = get_logger(__name__)
+        else:
+            self.logger = trainer.logger
+
+    def after_train_epoch(self, trainer):
+        # handle the case where the evaluation interval is greater than 1
+        if trainer.metric_values is None:
+            if is_master():
+                self.logger.warning(
+                    f'Current epoch {trainer.epoch} has no evaluation metric values, skip lr_scheduler.step()!'
+                )
+            return
+
+        metrics = trainer.metric_values[self.metric_key]
+        if self.by_epoch:
+            if self.warmup_lr_scheduler is not None:
+                self.warmup_lr_scheduler.step(metrics=metrics)
+            else:
+                trainer.lr_scheduler.step(metrics=metrics)
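Unlike the `_LRScheduler` subclasses, `ReduceLROnPlateau.step` expects the monitored value, which is why the hook reads `trainer.metric_values[self.metric_key]` before stepping. The equivalent plain-PyTorch loop, for reference (the metric sequence is made up):

```python
import torch
from torch.optim.lr_scheduler import ReduceLROnPlateau

model = torch.nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
scheduler = ReduceLROnPlateau(optimizer, mode='max', factor=0.1, patience=2)

for accuracy in [0.1, 0.1, 0.1, 0.1, 0.3]:  # one evaluation result per epoch
    # ... train an epoch, then evaluate; `accuracy` plays the role of
    # trainer.metric_values[metric_key] in PlateauLrSchedulerHook.
    scheduler.step(accuracy)  # lr drops once `patience` epochs pass without improvement
```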
@@ -40,7 +40,8 @@ def register_torch_lr_scheduler():
     members = inspect.getmembers(lr_scheduler)
     for name, obj in members:
-        if inspect.isclass(obj) and issubclass(obj, _LRScheduler):
+        if (inspect.isclass(obj) and issubclass(
+                obj, _LRScheduler)) or name in ['ReduceLROnPlateau']:
             LR_SCHEDULER.register_module(module_name=name, module_cls=obj)
@@ -52,12 +52,12 @@ class BaseWarmup(_LRScheduler):
         for i, group in enumerate(self.optimizer.param_groups):
             group['lr'] *= scale_value[i]

-    def step(self, epoch=None):
+    def step(self, *args, **kwargs):
         """
         When ``self.base_scheduler._step_count`` is less than ``self.warmup_iters``, multiply lr by scale
         """
         if self.base_scheduler._step_count > self.warmup_iters:
-            return self.base_scheduler.step(epoch=epoch)
+            return self.base_scheduler.step(*args, **kwargs)

         for group, lr in zip(self.optimizer.param_groups, self.base_lrs):
             group['lr'] = lr

@@ -66,7 +66,7 @@ class BaseWarmup(_LRScheduler):
         if self._is_init_step:
             self._is_init_step = False
         else:
-            self.base_scheduler.step(epoch=epoch)
+            self.base_scheduler.step(*args, **kwargs)

         self.scale()
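The switch from `step(self, epoch=None)` to `step(self, *args, **kwargs)` is what lets the warmup wrapper sit in front of any base scheduler, including `ReduceLROnPlateau`, whose `step` takes `metrics` rather than `epoch`. A toy delegation sketch (a hypothetical class, not the actual `BaseWarmup` implementation):

```python
# Toy sketch: forwarding *args/**kwargs keeps the wrapper agnostic to the
# base scheduler's step() signature. Illustrative only.
class SchedulerWrapper:

    def __init__(self, base_scheduler):
        self.base_scheduler = base_scheduler

    def step(self, *args, **kwargs):
        # Works both for StepLR-style step() and ReduceLROnPlateau.step(metrics).
        return self.base_scheduler.step(*args, **kwargs)
```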
@@ -7,6 +7,7 @@ from distutils.version import LooseVersion
 from functools import partial
 from typing import Callable, List, Optional, Tuple, Union

+import json
 import numpy as np
 import torch
 from addict import Dict
@@ -135,6 +136,7 @@ class EpochBasedTrainer(BaseTrainer):
         self.data_collator = data_collator if data_collator is not None else torch_default_data_collator
         self.metrics = self.get_metrics()
+        self._metric_values = None
         self.optimizers = optimizers
         self.logger = get_logger(log_level=self.cfg.get('log_level', 'INFO'))
         self._mode = ModeKeys.TRAIN
@@ -322,17 +324,16 @@ class EpochBasedTrainer(BaseTrainer):
             **self.cfg.evaluation.get('dataloader', {}))
         self.data_loader = self.eval_dataloader
         metric_classes = [build_metric(metric) for metric in self.metrics]
-        self.evaluation_loop(self.eval_dataloader, checkpoint_path,
-                             metric_classes)
-
-        rank, world_size = get_dist_info()
-        metric_values = {}
-        if rank == 0:
-            for metric_cls in metric_classes:
-                metric_values.update(metric_cls.evaluate())
-        if world_size > 1:
-            metric_values = broadcast(metric_values, 0)
+        metric_values = self.evaluation_loop(self.eval_dataloader,
+                                             checkpoint_path, metric_classes)
+
+        self._metric_values = metric_values
         return metric_values

+    @property
+    def metric_values(self):
+        return self._metric_values
+
     def build_model(self) -> Union[nn.Module, TorchModel]:
         """ Instantiate a pytorch model and return.
@@ -530,8 +531,6 @@
         We provide a default implementation, if you want to customize your own optimizer
         and lr scheduler, you can either pass a tuple through trainer init function or
         subclass this class and override this method.
-
-
         """
         optimizer, lr_scheduler = self.optimizers
         if optimizer is None:
@@ -563,22 +562,38 @@ class EpochBasedTrainer(BaseTrainer):
     def register_optimizers_hook(self):
         """ Register optimizer hook and lr scheduler hook.
         """
-        optimizer, lr_scheduler = self.optimizers
-        opti_error_msg = 'optimizers should be a tuple of `torch.optim.Optimizer`'\
-            ' and `torch.optim.lr_scheduler._LRScheduler`'
-        if optimizer is not None:
-            assert isinstance(optimizer, torch.optim.Optimizer), opti_error_msg
-        if lr_scheduler is not None:
-            assert isinstance(
-                lr_scheduler,
-                torch.optim.lr_scheduler._LRScheduler), opti_error_msg
-
-        _, _, optim_options, lr_options = self.create_optimizer_and_scheduler()
-        lr_hook = dict(type='LrSchedulerHook', **lr_options)
-
-        if self.use_fp16:
-            optim_hook = dict(type='TorchAMPOptimizerHook', **optim_options)
-        else:
-            optim_hook = dict(type='OptimizerHook', **optim_options)
+        _, lr_scheduler, optim_options, lr_options = self.create_optimizer_and_scheduler(
+        )
+        optim_hook = self.cfg.train.get('optimizer_hook', None)
+        lr_hook = self.cfg.train.get('lr_scheduler_hook', None)
+
+        # adapt to `ReduceLROnPlateau`
+        from torch.optim.lr_scheduler import ReduceLROnPlateau
+        if isinstance(lr_scheduler, ReduceLROnPlateau) and lr_hook is None:
+            plateau_cfg = {
+                'train': {
+                    'lr_scheduler_hook': {
+                        'type': 'PlateauLrSchedulerHook',
+                        'metric_key':
+                        'Metric Key used for PlateauLrSchedulerHook'
+                    }
+                }
+            }
+            plateau_cfg = json.dumps(
+                plateau_cfg, sort_keys=False, indent=4, separators=(',', ':'))
+            raise ValueError(
+                'Must add `lr_scheduler_hook` to configuration for `ReduceLROnPlateau` lr scheduler as follows:'
+                + '\n' + plateau_cfg)
+
+        if lr_hook is None:
+            lr_hook = dict(type='LrSchedulerHook', **lr_options)
+        if optim_hook is None:
+            if self.use_fp16:
+                optim_hook = dict(
+                    type='TorchAMPOptimizerHook', **optim_options)
+            else:
+                optim_hook = dict(type='OptimizerHook', **optim_options)

         self.register_hook_from_cfg([lr_hook, optim_hook])
@@ -692,7 +707,7 @@ class EpochBasedTrainer(BaseTrainer):
         """
         if self._dist:
             from modelscope.trainers.utils.inference import multi_gpu_test
-            multi_gpu_test(
+            metric_values = multi_gpu_test(
                 self.model,
                 data_loader,
                 tmpdir=None,

@@ -701,12 +716,14 @@ class EpochBasedTrainer(BaseTrainer):
                 metric_classes=metric_classes)
         else:
             from modelscope.trainers.utils.inference import single_gpu_test
-            single_gpu_test(
+            metric_values = single_gpu_test(
                 self.model,
                 data_loader,
                 data_collate_fn=self.collate_fn,
                 metric_classes=metric_classes)

+        return metric_values
+
     def register_hook(self, hook: Hook) -> None:
         """Register a hook into the hook list.
@@ -10,7 +10,8 @@ import torch
 from torch import distributed as dist
 from tqdm import tqdm

-from modelscope.utils.torch_utils import get_dist_info, is_master, make_tmp_dir
+from modelscope.utils.torch_utils import (broadcast, get_dist_info, is_master,
+                                          make_tmp_dir)
 from modelscope.utils.utils import if_func_receive_dict_inputs
@@ -51,6 +52,12 @@
             for _ in range(batch_size):
                 pbar.update()

+    metric_values = {}
+    for metric_cls in metric_classes:
+        metric_values.update(metric_cls.evaluate())
+
+    return metric_values
+

 def multi_gpu_test(model,
                    data_loader,
@@ -132,6 +139,15 @@ def multi_gpu_test(model,
             for metric_cls in metric_classes:
                 metric_cls.add(results[i], data_list[i])

+    metric_values = {}
+    if rank == 0:
+        for metric_cls in metric_classes:
+            metric_values.update(metric_cls.evaluate())
+
+    if world_size > 1:
+        metric_values = broadcast(metric_values, 0)
+
+    return metric_values
+

 def collect_results_cpu(result_part, size, tmpdir=None):
     """Collect results under cpu mode.
@@ -56,7 +56,8 @@ class Registry(object):
     def _register_module(self,
                          group_key=default_group,
                          module_name=None,
-                         module_cls=None):
+                         module_cls=None,
+                         force=False):
         assert isinstance(group_key,
                           str), 'group_key is required and must be str'

@@ -69,7 +70,7 @@ class Registry(object):
         if module_name is None:
             module_name = module_cls.__name__

-        if module_name in self._modules[group_key]:
+        if module_name in self._modules[group_key] and not force:
             raise KeyError(f'{module_name} is already registered in '
                            f'{self._name}[{group_key}]')
         self._modules[group_key][module_name] = module_cls

@@ -78,7 +79,8 @@ class Registry(object):
     def register_module(self,
                         group_key: str = default_group,
                         module_name: str = None,
-                        module_cls: type = None):
+                        module_cls: type = None,
+                        force=False):
         """ Register module

         Example:

@@ -102,6 +104,8 @@ class Registry(object):
                 default group name is 'default'
             module_name: Module name
             module_cls: Module class object
+            force (bool, optional): Whether to override an existing class with
+                the same name. Default: False.
         """

         if not (module_name is None or isinstance(module_name, str)):

@@ -111,7 +115,8 @@ class Registry(object):
             self._register_module(
                 group_key=group_key,
                 module_name=module_name,
-                module_cls=module_cls)
+                module_cls=module_cls,
+                force=force)
             return module_cls

         # if module_cls is None, should return a decorator function

@@ -119,7 +124,8 @@ class Registry(object):
             self._register_module(
                 group_key=group_key,
                 module_name=module_name,
-                module_cls=module_cls)
+                module_cls=module_cls,
+                force=force)
             return module_cls

         return _register
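The new `force` flag exists mainly so the tests below can re-register `DummyMetric` from several test modules without tripping the duplicate-name `KeyError`; usage mirrors the test code:

```python
from modelscope.metrics.builder import METRICS
from modelscope.utils.registry import default_group

# Without force=True, a second registration under the same name raises KeyError.
@METRICS.register_module(
    group_key=default_group, module_name='DummyMetric', force=True)
class DummyMetric:

    def add(self, *args, **kwargs):
        pass

    def evaluate(self):
        return {'accuracy': 0.5}
```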
@@ -99,7 +99,7 @@ class TensorboardHookTest(unittest.TestCase):
                 ea.Scalars(LogKeys.LR)[i].value, 0.01, delta=0.001)
         for i in range(5, 10):
             self.assertAlmostEqual(
-                ea.Scalars(LogKeys.LR)[i].value, 0.001, delta=0.0001)
+                ea.Scalars(LogKeys.LR)[i].value, 0.01, delta=0.0001)


 if __name__ == '__main__':
@@ -9,10 +9,31 @@ import numpy as np
 import torch
 from torch import nn

+from modelscope.metrics.builder import METRICS, MetricKeys
 from modelscope.trainers import build_trainer
 from modelscope.utils.constant import LogKeys, ModelFile
+from modelscope.utils.registry import default_group
 from modelscope.utils.test_utils import create_dummy_test_dataset

+
+def create_dummy_metric():
+    _global_iter = 0
+
+    @METRICS.register_module(
+        group_key=default_group, module_name='DummyMetric', force=True)
+    class DummyMetric:
+
+        _fake_acc_by_epoch = {1: 0.1, 2: 0.5, 3: 0.2}
+
+        def add(*args, **kwargs):
+            pass
+
+        def evaluate(self):
+            global _global_iter
+            _global_iter += 1
+            return {MetricKeys.ACCURACY: self._fake_acc_by_epoch[_global_iter]}
+
+
 dummy_dataset = create_dummy_test_dataset(
     np.random.random(size=(5, )), np.random.randint(0, 4, (1, )), 20)
@@ -39,12 +60,16 @@ class CheckpointHookTest(unittest.TestCase):
         self.tmp_dir = tempfile.TemporaryDirectory().name
         if not os.path.exists(self.tmp_dir):
             os.makedirs(self.tmp_dir)
+        create_dummy_metric()

     def tearDown(self):
         super().tearDown()
         shutil.rmtree(self.tmp_dir)

     def test_checkpoint_hook(self):
+        global _global_iter
+        _global_iter = 0
+
         json_cfg = {
             'task': 'image_classification',
             'train': {
@@ -98,5 +123,80 @@ class CheckpointHookTest(unittest.TestCase):
         self.assertIn(f'{LogKeys.EPOCH}_2.pth', results_files)


+class BestCkptSaverHookTest(unittest.TestCase):
+
+    def setUp(self):
+        print(('Testing %s.%s' % (type(self).__name__, self._testMethodName)))
+        self.tmp_dir = tempfile.TemporaryDirectory().name
+        if not os.path.exists(self.tmp_dir):
+            os.makedirs(self.tmp_dir)
+        create_dummy_metric()
+
+    def tearDown(self):
+        super().tearDown()
+        shutil.rmtree(self.tmp_dir)
+
+    def test_best_checkpoint_hook(self):
+        global _global_iter
+        _global_iter = 0
+
+        json_cfg = {
+            'task': 'image_classification',
+            'train': {
+                'work_dir':
+                self.tmp_dir,
+                'dataloader': {
+                    'batch_size_per_gpu': 2,
+                    'workers_per_gpu': 1
+                },
+                'optimizer': {
+                    'type': 'SGD',
+                    'lr': 0.01
+                },
+                'lr_scheduler': {
+                    'type': 'StepLR',
+                    'step_size': 2
+                },
+                'hooks': [{
+                    'type': 'BestCkptSaverHook',
+                    'metric_key': MetricKeys.ACCURACY,
+                    'rule': 'min'
+                }, {
+                    'type': 'EvaluationHook',
+                    'interval': 1,
+                }]
+            },
+            'evaluation': {
+                'dataloader': {
+                    'batch_size_per_gpu': 2,
+                    'workers_per_gpu': 1,
+                    'shuffle': False
+                },
+                'metrics': ['DummyMetric']
+            }
+        }
+
+        config_path = os.path.join(self.tmp_dir, ModelFile.CONFIGURATION)
+        with open(config_path, 'w') as f:
+            json.dump(json_cfg, f)
+
+        trainer_name = 'EpochBasedTrainer'
+        kwargs = dict(
+            cfg_file=config_path,
+            model=DummyModel(),
+            data_collator=None,
+            train_dataset=dummy_dataset,
+            eval_dataset=dummy_dataset,
+            max_epochs=3)
+        trainer = build_trainer(trainer_name, kwargs)
+        trainer.train()
+        results_files = os.listdir(self.tmp_dir)
+        self.assertIn(f'{LogKeys.EPOCH}_1.pth', results_files)
+        self.assertIn(f'{LogKeys.EPOCH}_2.pth', results_files)
+        self.assertIn(f'{LogKeys.EPOCH}_3.pth', results_files)
+        self.assertIn(f'best_{LogKeys.EPOCH}1_{MetricKeys.ACCURACY}0.1.pth',
+                      results_files)
+

 if __name__ == '__main__':
     unittest.main()
@@ -15,21 +15,18 @@ from modelscope.utils.constant import LogKeys, ModelFile
 from modelscope.utils.registry import default_group
 from modelscope.utils.test_utils import create_dummy_test_dataset

-_global_iter = 0
-
-
-@METRICS.register_module(group_key=default_group, module_name='DummyMetric')
-class DummyMetric:
-
-    _fake_acc_by_epoch = {1: 0.1, 2: 0.5, 3: 0.2}
-
-    def add(*args, **kwargs):
-        pass
-
-    def evaluate(self):
-        global _global_iter
-        _global_iter += 1
-        return {MetricKeys.ACCURACY: self._fake_acc_by_epoch[_global_iter]}
+
+def create_dummy_metric():
+
+    @METRICS.register_module(
+        group_key=default_group, module_name='DummyMetric', force=True)
+    class DummyMetric:
+
+        def add(*args, **kwargs):
+            pass
+
+        def evaluate(self):
+            return {MetricKeys.ACCURACY: 0.5}


 dummy_dataset = create_dummy_test_dataset(
@@ -58,80 +55,17 @@ class EvaluationHookTest(unittest.TestCase):
         self.tmp_dir = tempfile.TemporaryDirectory().name
         if not os.path.exists(self.tmp_dir):
             os.makedirs(self.tmp_dir)
+        create_dummy_metric()

     def tearDown(self):
         super().tearDown()
         shutil.rmtree(self.tmp_dir)

-    def test_best_ckpt_rule_max(self):
-        global _global_iter
-        _global_iter = 0
-        json_cfg = {
-            'task': 'image_classification',
-            'train': {
-                'work_dir':
-                self.tmp_dir,
-                'dataloader': {
-                    'batch_size_per_gpu': 2,
-                    'workers_per_gpu': 1
-                },
-                'optimizer': {
-                    'type': 'SGD',
-                    'lr': 0.01,
-                },
-                'lr_scheduler': {
-                    'type': 'StepLR',
-                    'step_size': 2,
-                },
-                'hooks': [{
-                    'type': 'EvaluationHook',
-                    'interval': 1,
-                    'save_best_ckpt': True,
-                    'monitor_key': MetricKeys.ACCURACY
-                }]
-            },
-            'evaluation': {
-                'dataloader': {
-                    'batch_size_per_gpu': 2,
-                    'workers_per_gpu': 1,
-                    'shuffle': False
-                },
-                'metrics': ['DummyMetric']
-            }
-        }
-
-        config_path = os.path.join(self.tmp_dir, ModelFile.CONFIGURATION)
-        with open(config_path, 'w') as f:
-            json.dump(json_cfg, f)
-
-        trainer_name = 'EpochBasedTrainer'
-        kwargs = dict(
-            cfg_file=config_path,
-            model=DummyModel(),
-            data_collator=None,
-            train_dataset=dummy_dataset,
-            eval_dataset=dummy_dataset,
-            max_epochs=3)
-        trainer = build_trainer(trainer_name, kwargs)
-        trainer.train()
-        results_files = os.listdir(self.tmp_dir)
-        self.assertIn(f'{LogKeys.EPOCH}_1.pth', results_files)
-        self.assertIn(f'{LogKeys.EPOCH}_2.pth', results_files)
-        self.assertIn(f'{LogKeys.EPOCH}_3.pth', results_files)
-        self.assertIn(f'best_{LogKeys.EPOCH}2_{MetricKeys.ACCURACY}0.5.pth',
-                      results_files)
-
-    def test_best_ckpt_rule_min(self):
-        global _global_iter
-        _global_iter = 0
+    def test_evaluation_hook(self):
         json_cfg = {
             'task': 'image_classification',
             'train': {
-                'work_dir':
-                self.tmp_dir,
+                'work_dir': self.tmp_dir,
                 'dataloader': {
                     'batch_size_per_gpu': 2,
                     'workers_per_gpu': 1

@@ -147,10 +81,6 @@ class EvaluationHookTest(unittest.TestCase):
                 'hooks': [{
                     'type': 'EvaluationHook',
                     'interval': 1,
-                    'save_best_ckpt': True,
-                    'monitor_key': 'accuracy',
-                    'rule': 'min',
-                    'out_dir': os.path.join(self.tmp_dir, 'best_ckpt')
                 }]
             },
             'evaluation': {

@@ -174,16 +104,11 @@ class EvaluationHookTest(unittest.TestCase):
             data_collator=None,
             train_dataset=dummy_dataset,
             eval_dataset=dummy_dataset,
-            max_epochs=3)
+            max_epochs=1)
         trainer = build_trainer(trainer_name, kwargs)
         trainer.train()
-        results_files = os.listdir(self.tmp_dir)
-        self.assertIn(f'{LogKeys.EPOCH}_1.pth', results_files)
-        self.assertIn(f'{LogKeys.EPOCH}_2.pth', results_files)
-        self.assertIn(f'{LogKeys.EPOCH}_3.pth', results_files)
-        self.assertIn(f'best_{LogKeys.EPOCH}1_{MetricKeys.ACCURACY}0.1.pth',
-                      os.listdir(os.path.join(self.tmp_dir, 'best_ckpt')))
+        self.assertDictEqual(trainer.metric_values, {'accuracy': 0.5})


 if __name__ == '__main__':
@@ -9,16 +9,36 @@ import numpy as np
 import torch
 from torch import nn
 from torch.optim import SGD
-from torch.optim.lr_scheduler import MultiStepLR
+from torch.optim.lr_scheduler import MultiStepLR, ReduceLROnPlateau

+from modelscope.metrics.builder import METRICS, MetricKeys
 from modelscope.trainers import build_trainer
 from modelscope.utils.constant import LogKeys, ModelFile, TrainerStages
+from modelscope.utils.registry import default_group
 from modelscope.utils.test_utils import create_dummy_test_dataset

 dummy_dataset = create_dummy_test_dataset(
     np.random.random(size=(5, )), np.random.randint(0, 4, (1, )), 10)

+
+def create_dummy_metric():
+    _global_iter = 0
+
+    @METRICS.register_module(
+        group_key=default_group, module_name='DummyMetric', force=True)
+    class DummyMetric:
+
+        _fake_acc_by_epoch = {1: 0.1, 2: 0.1, 3: 0.1, 4: 0.1, 5: 0.3}
+
+        def add(*args, **kwargs):
+            pass
+
+        def evaluate(self):
+            global _global_iter
+            _global_iter += 1
+            return {MetricKeys.ACCURACY: self._fake_acc_by_epoch[_global_iter]}
+
+
 class DummyModel(nn.Module):

     def __init__(self):
@@ -41,12 +61,16 @@ class LrSchedulerHookTest(unittest.TestCase):
         self.tmp_dir = tempfile.TemporaryDirectory().name
         if not os.path.exists(self.tmp_dir):
             os.makedirs(self.tmp_dir)
+        create_dummy_metric()

     def tearDown(self):
         super().tearDown()
         shutil.rmtree(self.tmp_dir)

     def test_lr_scheduler_hook(self):
+        global _global_iter
+        _global_iter = 0
+
         json_cfg = {
             'task': 'image_classification',
             'train': {
@@ -85,25 +109,26 @@ class LrSchedulerHookTest(unittest.TestCase):
             trainer.invoke_hook(TrainerStages.before_train_epoch)
             for _, data_batch in enumerate(train_dataloader):
                 trainer.invoke_hook(TrainerStages.before_train_iter)
+                trainer.train_step(trainer.model, data_batch)
+                trainer.invoke_hook(TrainerStages.after_train_iter)
                 log_lrs.append(trainer.log_buffer.output[LogKeys.LR])
                 optim_lrs.append(optimizer.param_groups[0]['lr'])
-                trainer.train_step(trainer.model, data_batch)
-                trainer.invoke_hook(TrainerStages.after_train_iter)

             trainer.invoke_hook(TrainerStages.after_train_epoch)
             trainer._epoch += 1
         trainer.invoke_hook(TrainerStages.after_run)

         iters = 5
-        target_lrs = [0.01] * iters * 1 + [0.001] * iters * 2 + [0.0001
-                                                                 ] * iters * 2
+        target_lrs = [0.01] * iters * 2 + [0.001] * iters * 2 + [0.0001
+                                                                 ] * iters * 1

         self.assertListEqual(log_lrs, target_lrs)
         self.assertListEqual(optim_lrs, target_lrs)

     def test_warmup_lr_scheduler_hook(self):
+        global _global_iter
+        _global_iter = 0
+
         json_cfg = {
             'task': 'image_classification',
             'train': {
@@ -156,22 +181,118 @@ class LrSchedulerHookTest(unittest.TestCase):
             trainer.invoke_hook(TrainerStages.before_train_epoch)
             for _, data_batch in enumerate(train_dataloader):
                 trainer.invoke_hook(TrainerStages.before_train_iter)
+                trainer.train_step(trainer.model, data_batch)
+                trainer.invoke_hook(TrainerStages.after_train_iter)
                 log_lrs.append(round(trainer.log_buffer.output[LogKeys.LR], 5))
                 optim_lrs.append(
                     round(trainer.optimizer.param_groups[0]['lr'], 5))
-                trainer.train_step(trainer.model, data_batch)
-                trainer.invoke_hook(TrainerStages.after_train_iter)

             trainer.invoke_hook(TrainerStages.after_train_epoch)
         trainer.invoke_hook(TrainerStages.after_run)

         iters = 5
-        target_lrs = [0.004] * iters * 1 + [0.007] * iters * 1 + [
-            0.01
-        ] * iters * 1 + [0.001] * iters * 2 + [0.0001] * iters * 2
+        target_lrs = [0.001] * iters * 1 + [0.004] * iters * 1 + [
+            0.007
+        ] * iters * 1 + [0.01] * iters * 1 + [0.001] * iters * 2 + [
+            0.0001
+        ] * iters * 1

         self.assertListEqual(log_lrs, target_lrs)
         self.assertListEqual(optim_lrs, target_lrs)
+
+
+class PlateauLrSchedulerHookTest(unittest.TestCase):
+
+    def setUp(self):
+        print(('Testing %s.%s' % (type(self).__name__, self._testMethodName)))
+        self.tmp_dir = tempfile.TemporaryDirectory().name
+        if not os.path.exists(self.tmp_dir):
+            os.makedirs(self.tmp_dir)
+        create_dummy_metric()
+
+    def tearDown(self):
+        super().tearDown()
+        shutil.rmtree(self.tmp_dir)
+
+    def test_plateau_lr_scheduler_hook(self):
+        global _global_iter
+        _global_iter = 0
+
+        json_cfg = {
+            'task': 'image_classification',
+            'train': {
+                'work_dir': self.tmp_dir,
+                'dataloader': {
+                    'batch_size_per_gpu': 2,
+                    'workers_per_gpu': 1
+                },
+                'lr_scheduler': {
+                    'type': 'ReduceLROnPlateau',
+                    'mode': 'max',
+                    'factor': 0.1,
+                    'patience': 2,
+                },
+                'lr_scheduler_hook': {
+                    'type': 'PlateauLrSchedulerHook',
+                    'metric_key': MetricKeys.ACCURACY
+                },
+                'hooks': [{
+                    'type': 'EvaluationHook',
+                    'interval': 1
+                }]
+            },
+            'evaluation': {
+                'dataloader': {
+                    'batch_size_per_gpu': 2,
+                    'workers_per_gpu': 1,
+                    'shuffle': False
+                },
+                'metrics': ['DummyMetric']
+            }
+        }
+
+        config_path = os.path.join(self.tmp_dir, ModelFile.CONFIGURATION)
+        with open(config_path, 'w') as f:
+            json.dump(json_cfg, f)
+
+        model = DummyModel()
+        optimizer = SGD(model.parameters(), lr=0.01)
+        trainer_name = 'EpochBasedTrainer'
+        kwargs = dict(
+            cfg_file=config_path,
+            model=model,
+            train_dataset=dummy_dataset,
+            eval_dataset=dummy_dataset,
+            optimizers=(optimizer, None),
+            max_epochs=5)
+
+        trainer = build_trainer(trainer_name, kwargs)
+        train_dataloader = trainer._build_dataloader_with_dataset(
+            trainer.train_dataset, **trainer.cfg.train.get('dataloader', {}))
+        trainer.data_loader = train_dataloader
+        trainer.register_optimizers_hook()
+        trainer.register_hook_from_cfg(trainer.cfg.train.hooks)
+        trainer.invoke_hook(TrainerStages.before_run)
+
+        log_lrs = []
+        optim_lrs = []
+        for _ in range(trainer._epoch, trainer._max_epochs):
+            trainer.invoke_hook(TrainerStages.before_train_epoch)
+            for _, data_batch in enumerate(train_dataloader):
+                trainer.invoke_hook(TrainerStages.before_train_iter)
+                trainer.train_step(trainer.model, data_batch)
+                trainer.invoke_hook(TrainerStages.after_train_iter)
+
+                log_lrs.append(trainer.log_buffer.output[LogKeys.LR])
+                optim_lrs.append(optimizer.param_groups[0]['lr'])
+
+            trainer.invoke_hook(TrainerStages.after_train_epoch)
+            trainer._epoch += 1
+        trainer.invoke_hook(TrainerStages.after_run)

+        iters = 5
+        target_lrs = [0.01] * iters * 4 + [0.001] * iters * 1
+
+        self.assertListEqual(log_lrs, target_lrs)
+        self.assertListEqual(optim_lrs, target_lrs)
@@ -19,13 +19,6 @@ from modelscope.trainers import build_trainer
 from modelscope.utils.constant import LogKeys, ModeKeys, ModelFile
 from modelscope.utils.test_utils import create_dummy_test_dataset, test_level

-
-class DummyMetric:
-
-    def __call__(self, ground_truth, predict_results):
-        return {'accuracy': 0.5}
-
-
 dummy_dataset_small = create_dummy_test_dataset(
     np.random.random(size=(5, )), np.random.randint(0, 4, (1, )), 20)
@@ -265,14 +258,14 @@ class TrainerTest(unittest.TestCase):
                 LogKeys.MODE: ModeKeys.TRAIN,
                 LogKeys.EPOCH: 2,
                 LogKeys.ITER: 10,
-                LogKeys.LR: 0.001
+                LogKeys.LR: 0.01
             }, json.loads(lines[3]))
         self.assertDictContainsSubset(
             {
                 LogKeys.MODE: ModeKeys.TRAIN,
                 LogKeys.EPOCH: 2,
                 LogKeys.ITER: 20,
-                LogKeys.LR: 0.001
+                LogKeys.LR: 0.01
             }, json.loads(lines[4]))
         self.assertDictContainsSubset(
             {
@@ -18,13 +18,6 @@ from modelscope.utils.constant import LogKeys, ModeKeys, ModelFile
 from modelscope.utils.test_utils import (DistributedTestCase,
                                          create_dummy_test_dataset, test_level)

-
-class DummyMetric:
-
-    def __call__(self, ground_truth, predict_results):
-        return {'accuracy': 0.5}
-
-
 dummy_dataset_small = create_dummy_test_dataset(
     np.random.random(size=(5, )), np.random.randint(0, 4, (1, )), 20)
@@ -141,14 +134,14 @@ class TrainerTestSingleGpu(unittest.TestCase):
                 LogKeys.MODE: ModeKeys.TRAIN,
                 LogKeys.EPOCH: 2,
                 LogKeys.ITER: 10,
-                LogKeys.LR: 0.001
+                LogKeys.LR: 0.01
             }, json.loads(lines[3]))
         self.assertDictContainsSubset(
             {
                 LogKeys.MODE: ModeKeys.TRAIN,
                 LogKeys.EPOCH: 2,
                 LogKeys.ITER: 20,
-                LogKeys.LR: 0.001
+                LogKeys.LR: 0.01
             }, json.loads(lines[4]))
         self.assertDictContainsSubset(
             {
@@ -229,7 +222,7 @@ class TrainerTestMultiGpus(DistributedTestCase):
                 LogKeys.MODE: ModeKeys.TRAIN,
                 LogKeys.EPOCH: 2,
                 LogKeys.ITER: 10,
-                LogKeys.LR: 0.001
+                LogKeys.LR: 0.01
             }, json.loads(lines[2]))
         self.assertDictContainsSubset(
             {