1. Fix a bug in the trainer's progress bar
2. Fix a bug where the trainer does not support datasets declared in the config file
3. Add feature: support resuming training from a checkpoint file (see the usage sketch below)
4. Add feature: support a fixed filename when saving the best checkpoint
5. Fix a bug where id2label is missing from the config file after fine-tuning NLP models
6. Fix some other bugs
Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10138906
Target branch: master
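The user-facing entry point for item 3 is the new `checkpoint_path` argument of `EpochBasedTrainer.train()`. A minimal usage sketch, modeled on the `test_trainer_with_continue_train` test added at the bottom of this diff (the model id and checkpoint name are placeholders):

import os

from modelscope.trainers import build_trainer

kwargs = dict(model='damo/nlp_structbert_sentence-similarity_chinese-base')

# First run: CheckpointHook writes epoch_*.pth / iter_*.pth into trainer.work_dir.
trainer = build_trainer(default_args=kwargs)
trainer.train()

# Later run: pass a saved checkpoint to continue training. CheckpointHook restores
# the model, optimizer and lr_scheduler states, the epoch/iter counters, the RNG
# states and the state of stateful hooks such as BestCkptSaverHook.
trainer = build_trainer(default_args=kwargs)
trainer.train(checkpoint_path=os.path.join(trainer.work_dir, 'epoch_1.pth'))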
@@ -14,9 +14,9 @@ from .builder import METRICS, MetricKeys
 @METRICS.register_module(
     group_key=default_group, module_name=Metrics.seq_cls_metric)
 class SequenceClassificationMetric(Metric):
-    """The metric computation class for sequence classification classes.
+    """The metric computation class for sequence classification tasks.

-    This metric class calculates accuracy for the whole input batches.
+    This metric class calculates accuracy of the whole input batches.
     """

     def __init__(self, *args, **kwargs):
@@ -1,14 +1,16 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
 import os
+import random
+import json
+import numpy as np
+import torch
 from modelscope import __version__
 from modelscope.metainfo import Hooks
-from modelscope.utils.checkpoint import save_checkpoint
+from modelscope.utils.checkpoint import load_checkpoint, save_checkpoint
 from modelscope.utils.constant import LogKeys, ModelFile
 from modelscope.utils.logger import get_logger
-from modelscope.utils.torch_utils import is_master
+from modelscope.utils.torch_utils import get_dist_info, is_master
 from .builder import HOOKS
 from .hook import Hook
 from .priority import Priority
@@ -25,6 +27,7 @@ class CheckpointHook(Hook):
         save_optimizer (bool): Whether to save optimizer state dict. Default: True.
         save_dir (str): The directory to save checkpoints. If is None, use `trainer.work_dir`
         save_last (bool): Whether to save the last checkpoint. Default: True.
+        checkpoint_file (str): The checkpoint file to be loaded.
     """

     PRIORITY = Priority.LOW
@@ -34,12 +37,16 @@ class CheckpointHook(Hook):
                  by_epoch=True,
                  save_optimizer=True,
                  save_dir=None,
-                 save_last=True):
+                 save_last=True,
+                 checkpoint_file=None):
         self.interval = interval
         self.by_epoch = by_epoch
         self.save_optimizer = save_optimizer
         self.save_dir = save_dir
+        self.checkpoint_file = checkpoint_file
         self.save_last = save_last
+        self.rng_state = None
+        self.need_load_rng_state = False

     def before_run(self, trainer):
         if not self.save_dir:
@@ -56,6 +63,34 @@ class CheckpointHook(Hook):
         if is_master():
             self.logger.info(f'Checkpoints will be saved to {self.save_dir}')

+        if self.checkpoint_file is not None and os.path.isfile(
+                self.checkpoint_file):
+            meta = self.load_checkpoint(self.checkpoint_file, trainer)
+            self.rng_state = meta.get('rng_state')
+            self.need_load_rng_state = True
+
+    def before_train_epoch(self, trainer):
+        if self.need_load_rng_state:
+            if self.rng_state is not None:
+                random.setstate(self.rng_state['random'])
+                np.random.set_state(self.rng_state['numpy'])
+                torch.random.set_rng_state(self.rng_state['cpu'])
+                if torch.cuda.is_available():
+                    torch.cuda.random.set_rng_state_all(self.rng_state['cuda'])
+                self.need_load_rng_state = False
+            else:
+                self.logger.warn(
+                    'Random state cannot be found in checkpoint file, '
+                    'this may cause a random data order or model initialization.'
+                )
+
+        self.rng_state = {
+            'random': random.getstate(),
+            'numpy': np.random.get_state(),
+            'cpu': torch.random.get_rng_state(),
+            'cuda': torch.cuda.get_rng_state_all(),
+        }
+
     def after_train_epoch(self, trainer):
         if not self.by_epoch:
             return
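The RNG bookkeeping above is what lets a resumed run reproduce the data order of the interrupted one. A standalone sketch of the same idea, shown only as an illustration rather than the hook's exact code (the key names mirror the hook; the capture side here is additionally guarded for CPU-only machines):

import random

import numpy as np
import torch

def capture_rng_state():
    # Snapshot the Python, NumPy and PyTorch RNGs so they can be stored in a checkpoint.
    state = {
        'random': random.getstate(),
        'numpy': np.random.get_state(),
        'cpu': torch.random.get_rng_state(),
    }
    if torch.cuda.is_available():
        state['cuda'] = torch.cuda.get_rng_state_all()
    return state

def restore_rng_state(state):
    # Restore the snapshot before the next epoch so shuffling continues deterministically.
    random.setstate(state['random'])
    np.random.set_state(state['numpy'])
    torch.random.set_rng_state(state['cpu'])
    if torch.cuda.is_available() and 'cuda' in state:
        torch.cuda.random.set_rng_state_all(state['cuda'])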
@@ -66,6 +101,39 @@ class CheckpointHook(Hook):
                     f'Saving checkpoint at {trainer.epoch + 1} epoch')
             self._save_checkpoint(trainer)

+    @classmethod
+    def load_checkpoint(cls, filename, trainer):
+        from modelscope.trainers.parallel.utils import is_parallel
+        if is_parallel(trainer.model):
+            model = trainer.model.module
+        else:
+            model = trainer.model
+        meta = load_checkpoint(filename, model, trainer.optimizer,
+                               trainer.lr_scheduler)
+        trainer._epoch = meta.get('epoch', trainer._epoch)
+        trainer._iter = meta.get('iter', trainer._iter)
+        trainer._inner_iter = meta.get('inner_iter', trainer._inner_iter)
+        for i, hook in enumerate(trainer.hooks):
+            # hook: Hook
+            key = f'{hook.__class__}-{i}'
+            if key in meta:
+                hook.load_state_dict(meta[key])
+            else:
+                trainer.logger(
+                    f'The state_dict of hook {hook.__class__} at index {i} is not found in the checkpoint file.'
+                )
+        version = meta.get('modelscope')
+        if version != __version__:
+            trainer.logger(
+                f'The modelscope version of loaded checkpoint does not match the runtime version. '
+                f'The saved version: {version}, runtime version: {__version__}'
+            )
+        trainer.logger(
+            f'Checkpoint {filename} saving time: {meta.get("time")}')
+        return meta
+
     def _save_checkpoint(self, trainer):
         if self.by_epoch:
             cur_save_name = os.path.join(
@@ -74,7 +142,21 @@ class CheckpointHook(Hook):
             cur_save_name = os.path.join(
                 self.save_dir, f'{LogKeys.ITER}_{trainer.iter + 1}.pth')
-        save_checkpoint(trainer.model, cur_save_name, trainer.optimizer)
+        meta = {
+            'epoch': trainer.epoch,
+            'iter': trainer.iter + 1,
+            'inner_iter': trainer.inner_iter + 1,
+            'rng_state': self.rng_state,
+        }
+        for i, hook in enumerate(trainer.hooks):
+            meta[f'{hook.__class__}-{i}'] = hook.state_dict()
+
+        save_checkpoint(
+            trainer.model,
+            cur_save_name,
+            trainer.optimizer,
+            trainer.lr_scheduler,
+            meta=meta)
         if (self.is_last_epoch(trainer)
                 and self.by_epoch) or (self.is_last_iter(trainer)
                                        and not self.by_epoch):
@@ -144,6 +226,7 @@ class BestCkptSaverHook(CheckpointHook):
                  by_epoch=True,
                  save_optimizer=True,
                  save_dir=None,
+                 save_file_name=None,
                  interval=0):
         assert rule in ['max', 'min'], 'Only support "max" or "min" rule now.'
         super().__init__(
@@ -179,16 +262,44 @@ class BestCkptSaverHook(CheckpointHook):
         return False

     def _save_checkpoint(self, trainer):
-        if self.by_epoch:
-            cur_save_name = os.path.join(
-                self.save_dir,
-                f'best_{LogKeys.EPOCH}{trainer.epoch + 1}_{self.metric_key}{self._best_metric}.pth'
-            )
-        else:
-            cur_save_name = os.path.join(
-                self.save_dir,
-                f'best_{LogKeys.ITER}{trainer.iter + 1}_{self.metric_key}{self._best_metric}.pth'
-            )
-        save_checkpoint(trainer.model, cur_save_name, trainer.optimizer)
+        cur_save_name = self.save_file_name
+        if cur_save_name is None:
+            if self.by_epoch:
+                cur_save_name = os.path.join(
+                    self.save_dir,
+                    f'best_{LogKeys.EPOCH}{trainer.epoch + 1}_{self.metric_key}{self._best_metric}.pth'
+                )
+            else:
+                cur_save_name = os.path.join(
+                    self.save_dir,
+                    f'best_{LogKeys.ITER}{trainer.iter + 1}_{self.metric_key}{self._best_metric}.pth'
+                )
+
+        meta = {
+            'epoch': trainer.epoch,
+            'iter': trainer.iter + 1,
+            'inner_iter': trainer.inner_iter + 1,
+            'rng_state': self.rng_state,
+        }
+        for i, hook in enumerate(trainer.hooks):
+            meta[f'{hook.__class__}-{i}'] = hook.state_dict()
+        if os.path.isfile(cur_save_name):
+            os.remove(cur_save_name)
+        save_checkpoint(trainer.model, cur_save_name, trainer.optimizer,
+                        trainer.lr_scheduler, meta)
         self._best_ckpt_file = cur_save_name
         self._save_pretrained(trainer)
+
+    def state_dict(self):
+        return {
+            'best_metric': self._best_metric,
+        }
+
+    def load_state_dict(self, state_dict):
+        if state_dict is not None and len(state_dict) > 0:
+            self._best_metric = state_dict.get('best_metric')
+        else:
+            self.logger.warn(
+                'The state_dict is not available, the best metric value will be affected.'
+            )
@@ -215,3 +215,9 @@ class Hook:
                     trigger_stages.add(stage)
         return [stage for stage in Hook.stages if stage in trigger_stages]
+
+    def state_dict(self):
+        return {}
+
+    def load_state_dict(self, state_dict):
+        pass
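With these two no-op defaults on the base class, any hook can opt into checkpointing by overriding them; `BestCkptSaverHook` above does this for its best metric value. A hypothetical custom hook, shown only as a sketch of the protocol (its state ends up in the checkpoint meta under the `f'{hook.__class__}-{index}'` key built by `CheckpointHook`):

from modelscope.trainers.hooks.hook import Hook

class IterCounterHook(Hook):
    """Hypothetical hook that keeps a counter across resumed runs."""

    def __init__(self):
        self.seen_iters = 0

    def after_train_iter(self, trainer):
        # Runs at the after_train_iter stage invoked by the training loop.
        self.seen_iters += 1

    def state_dict(self):
        # Saved into the checkpoint meta by CheckpointHook._save_checkpoint.
        return {'seen_iters': self.seen_iters}

    def load_state_dict(self, state_dict):
        # Restored by CheckpointHook.load_checkpoint when training resumes.
        self.seen_iters = state_dict.get('seen_iters', 0)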
@@ -4,6 +4,7 @@ import logging
 from torch.nn.utils import clip_grad

 from modelscope.metainfo import Hooks
+from modelscope.outputs import OutputKeys
 from modelscope.trainers.hooks.builder import HOOKS
 from modelscope.trainers.hooks.hook import Hook
 from modelscope.trainers.hooks.priority import Priority
@@ -27,7 +28,7 @@ class OptimizerHook(Hook):
     def __init__(self,
                  cumulative_iters=1,
                  grad_clip=None,
-                 loss_keys='loss') -> None:
+                 loss_keys=OutputKeys.LOSS) -> None:
         if isinstance(loss_keys, str):
             loss_keys = [loss_keys]
         assert isinstance(loss_keys, (tuple, list))
@@ -28,10 +28,10 @@ class BaseWarmup(_LRScheduler):
         return self.base_scheduler.get_lr()

     def state_dict(self):
-        self.base_scheduler.state_dict()
+        return self.base_scheduler.state_dict()

     def load_state_dict(self, state_dict):
-        self.base_scheduler.load_state_dict(state_dict)
+        return self.base_scheduler.load_state_dict(state_dict)

     def scale(self):
         """Scale the learning rates.
@@ -1,6 +1,7 @@
 import os
-from typing import Callable, Dict, Optional, Tuple, Union
+from typing import Callable, Optional, Tuple, Union

+import numpy as np
 import torch
 from torch import nn
 from torch.utils.data import Dataset
@@ -11,9 +12,10 @@ from modelscope.metrics.builder import build_metric
 from modelscope.models.base import Model, TorchModel
 from modelscope.msdatasets import MsDataset
 from modelscope.preprocessors import Preprocessor, build_preprocessor
-from modelscope.utils.config import Config, ConfigDict
+from modelscope.utils.config import Config
 from modelscope.utils.constant import (DEFAULT_MODEL_REVISION, ModeKeys,
                                        ModelFile, Tasks)
+from modelscope.utils.hub import parse_label_mapping
 from .base import TRAINERS
 from .trainer import EpochBasedTrainer
@@ -81,19 +83,32 @@ class NlpEpochBasedTrainer(EpochBasedTrainer):
             assert cfg_file is not None, 'Config file should not be None if model is an nn.Module class'
             model_dir = os.path.dirname(cfg_file)

+        self.label2id = None
+        self.id2label = None
+        self.num_labels = None
         self.cfg_modify_fn = cfg_modify_fn
         self.cfg = self.rebuild_config(Config.from_file(cfg_file))
-        try:
-            labels = self.cfg.dataset.train.labels
-        except AttributeError:
-            labels = None
-        self.label2id = None
-        self.num_labels = None
-        if labels is not None and len(labels) > 0:
-            self.label2id = {label: idx for idx, label in enumerate(labels)}
-            self.id2label = {idx: label for idx, label in enumerate(labels)}
-            self.num_labels = len(labels)
+
+        label2id = parse_label_mapping(model_dir)
+        if label2id is not None:
+            self.label2id = label2id
+            self.id2label = {id: label for label, id in label2id.items()}
+            self.num_labels = len(label2id)
+        else:
+            try:
+                labels = self.cfg.dataset.train.labels
+                if labels is not None and len(labels) > 0:
+                    self.label2id = {
+                        label: idx
+                        for idx, label in enumerate(labels)
+                    }
+                    self.id2label = {
+                        idx: label
+                        for idx, label in enumerate(labels)
+                    }
+                    self.num_labels = len(labels)
+            except AttributeError:
+                pass

         def build_dataset_keys(cfg):
             if cfg is not None:
@@ -130,7 +145,13 @@ class NlpEpochBasedTrainer(EpochBasedTrainer):
     def rebuild_config(self, cfg: Config):
         if self.cfg_modify_fn is not None:
-            return self.cfg_modify_fn(cfg)
+            cfg = self.cfg_modify_fn(cfg)
+        if not hasattr(cfg.model, 'label2id') and not hasattr(
+                cfg.model, 'id2label'):
+            if self.id2label is not None:
+                cfg.model['id2label'] = self.id2label
+            if self.label2id is not None:
+                cfg.model['label2id'] = self.label2id
         return cfg

     def build_model(self) -> Union[nn.Module, TorchModel]:
@@ -203,6 +224,9 @@ class VecoTrainer(NlpEpochBasedTrainer):
         """
         from modelscope.msdatasets.task_datasets import VecoDataset
+        if checkpoint_path is not None and os.path.isfile(checkpoint_path):
+            from modelscope.trainers.hooks import CheckpointHook
+            CheckpointHook.load_checkpoint(checkpoint_path, self)
         self.model.eval()
         self._mode = ModeKeys.EVAL
         metric_values = {}
@@ -223,12 +247,10 @@ class VecoTrainer(NlpEpochBasedTrainer):
                 self.eval_dataset, **self.cfg.evaluation.get('dataloader', {}))
             self.data_loader = self.eval_dataloader
-            metric_classes = [
-                build_metric(metric, default_args={'trainer': self})
-                for metric in self.metrics
-            ]
-            self.evaluation_loop(self.eval_dataloader, checkpoint_path,
-                                 metric_classes)
+            metric_classes = [build_metric(metric) for metric in self.metrics]
+            for m in metric_classes:
+                m.trainer = self
+            self.evaluation_loop(self.eval_dataloader, metric_classes)

             for m_idx, metric_cls in enumerate(metric_classes):
                 if f'eval_dataset[{idx}]' not in metric_values:
@@ -242,4 +264,8 @@ class VecoTrainer(NlpEpochBasedTrainer):
             else:
                 break

+        for metric_name in self.metrics:
+            metric_values[metric_name] = np.average(
+                [m[metric_name] for m in metric_values.values()])
+
         return metric_values
@@ -1,6 +1,5 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
 import os
-import random
 import time
 from collections.abc import Mapping
 from distutils.version import LooseVersion
@@ -8,7 +7,6 @@ from functools import partial
 from typing import Callable, Dict, List, Optional, Sequence, Tuple, Union

 import json
-import numpy as np
 import torch
 from torch import distributed as dist
 from torch import nn
@@ -425,8 +423,16 @@ class EpochBasedTrainer(BaseTrainer):
             metrics = [metrics]
         return metrics

-    def train(self, *args, **kwargs):
-        self.model.train()
+    def set_checkpoint_file_to_hook(self, checkpoint_path):
+        if checkpoint_path is not None and os.path.isfile(checkpoint_path):
+            from modelscope.trainers.hooks import CheckpointHook
+            checkpoint_hooks = list(
+                filter(lambda hook: isinstance(hook, CheckpointHook),
+                       self.hooks))
+            for hook in checkpoint_hooks:
+                hook.checkpoint_file = checkpoint_path
+
+    def train(self, checkpoint_path=None, *args, **kwargs):
         self._mode = ModeKeys.TRAIN

         if self.train_dataset is None:
@@ -442,13 +448,17 @@ class EpochBasedTrainer(BaseTrainer):
         self.register_optimizers_hook()
         self.register_hook_from_cfg(self.cfg.train.hooks)

+        self.set_checkpoint_file_to_hook(checkpoint_path)
+        self.model.train()
+
         self.train_loop(self.train_dataloader)

     def evaluate(self, checkpoint_path=None):
+        if checkpoint_path is not None and os.path.isfile(checkpoint_path):
+            from modelscope.trainers.hooks import CheckpointHook
+            CheckpointHook.load_checkpoint(checkpoint_path, self)
         self.model.eval()
         self._mode = ModeKeys.EVAL
         if self.eval_dataset is None:
             self.eval_dataloader = self.get_eval_data_loader()
         else:
@@ -462,8 +472,9 @@ class EpochBasedTrainer(BaseTrainer):
         metric_classes = [build_metric(metric) for metric in self.metrics]
         for m in metric_classes:
             m.trainer = self
         metric_values = self.evaluation_loop(self.eval_dataloader,
-                                             checkpoint_path, metric_classes)
+                                             metric_classes)
         self._metric_values = metric_values
         return metric_values
@@ -631,18 +642,13 @@ class EpochBasedTrainer(BaseTrainer):
             if hasattr(data_cfg, 'name'):
                 dataset = MsDataset.load(
                     dataset_name=data_cfg.name,
-                    split=data_cfg.split,
-                    subset_name=data_cfg.subset_name if hasattr(
-                        data_cfg, 'subset_name') else None,
-                    hub=data_cfg.hub
-                    if hasattr(data_cfg, 'hub') else Hubs.modelscope,
                     **data_cfg,
                 )
                 cfg = ConfigDict(type=self.cfg.model.type, mode=mode)
                 torch_dataset = dataset.to_torch_dataset(
                     task_data_config=cfg,
                     task_name=self.cfg.task,
-                    preprocessors=self.preprocessor)
+                    preprocessors=preprocessor)
             else:
                 torch_dataset = build_task_dataset(data_cfg, self.cfg.task)
             dataset = self.to_task_dataset(torch_dataset, mode)
@@ -802,19 +808,22 @@ class EpochBasedTrainer(BaseTrainer):
         """ Training loop used by `EpochBasedTrainer.train()`
         """
         self.invoke_hook(TrainerStages.before_run)
-        self._epoch = 0
         kwargs = {}
         self.model.train()
         for _ in range(self._epoch, self._max_epochs):
             self.invoke_hook(TrainerStages.before_train_epoch)
             time.sleep(2)  # Prevent possible deadlock during epoch transition
             for i, data_batch in enumerate(data_loader):
+                if i < self.inner_iter:
+                    # inner_iter may be read out from the checkpoint file, so skip the trained iters in the epoch.
+                    continue
                 data_batch = to_device(data_batch, self.device)
                 self.data_batch = data_batch
                 self._inner_iter = i
                 self.invoke_hook(TrainerStages.before_train_iter)
                 self.train_step(self.model, data_batch, **kwargs)
                 self.invoke_hook(TrainerStages.after_train_iter)
+                # Value changed after the hooks are invoked, do not move them above the invoke_hook code.
                 del self.data_batch
                 self._iter += 1
                 self._mode = ModeKeys.TRAIN
@@ -823,12 +832,14 @@ class EpochBasedTrainer(BaseTrainer):
                     break

             self.invoke_hook(TrainerStages.after_train_epoch)
+            # Value changed after the hooks are invoked, do not move them above the invoke_hook code.
+            self._inner_iter = 0
             self._epoch += 1
             time.sleep(1)  # wait for some hooks like loggers to finish
         self.invoke_hook(TrainerStages.after_run)

-    def evaluation_loop(self, data_loader, checkpoint_path, metric_classes):
+    def evaluation_loop(self, data_loader, metric_classes):
         """ Evaluation loop used by `EpochBasedTrainer.evaluate()`.
         """
@@ -841,7 +852,7 @@ class EpochBasedTrainer(BaseTrainer):
                 tmpdir=None,
                 gpu_collect=False,
                 metric_classes=metric_classes,
-                data_loader_iters_per_gpu=self.iters_per_epoch)
+                data_loader_iters_per_gpu=self._eval_iters_per_epoch)
         else:
             from modelscope.trainers.utils.inference import single_gpu_test
             metric_values = single_gpu_test(
@@ -849,7 +860,7 @@ class EpochBasedTrainer(BaseTrainer):
                 data_loader,
                 device=self.device,
                 metric_classes=metric_classes,
-                data_loader_iters=self.iters_per_epoch)
+                data_loader_iters=self._eval_iters_per_epoch)

         self._inner_iter = self.iters_per_epoch - 1  # start from index 0
@@ -8,14 +8,17 @@ from shutil import copytree, ignore_patterns, rmtree
 from typing import Callable, List, Optional, Union

 import json
-import numpy as np
 import torch
 from torch.optim import Optimizer
+from torch.optim.lr_scheduler import _LRScheduler

 from modelscope import __version__
 from modelscope.fileio import File, LocalStorage
 from modelscope.utils.config import JSONIteratorEncoder
 from modelscope.utils.constant import ConfigFields, ModelFile
+from modelscope.utils.logger import get_logger
+
+logger = get_logger(__name__)

 storage = LocalStorage()
@@ -40,24 +43,27 @@ def weights_to_cpu(state_dict):
 def save_checkpoint(model: torch.nn.Module,
                     filename: str,
                     optimizer: Optional[Optimizer] = None,
+                    lr_scheduler: Optional[_LRScheduler] = None,
                     meta: Optional[dict] = None,
                     with_meta: bool = True) -> None:
     """Save checkpoint to file.

     The checkpoint will have 3 fields: ``meta``, ``state_dict`` and
-    ``optimizer``. By default ``meta`` will contain version and time info.
+    ``optimizer``. By default, ``meta`` will contain version and time info.

     Args:
         model (Module): Module whose params are to be saved.
         filename (str): Checkpoint filename.
         optimizer (:obj:`Optimizer`, optional): Optimizer to be saved.
+        lr_scheduler(:obj:`_LRScheduler`, optional): LRScheduler to be saved.
         meta (dict, optional): Metadata to be saved in checkpoint.
+        with_meta (bool, optional):
     """
     if meta is None:
         meta = {}
     elif not isinstance(meta, dict):
         raise TypeError(f'meta must be a dict or None, but got {type(meta)}')
-    meta.update(modescope=__version__, time=time.asctime())
+    meta.update(modelscope=__version__, time=time.asctime())
@@ -71,22 +77,69 @@ def save_checkpoint(model: torch.nn.Module,
             'meta': meta,
             'state_dict': weights_to_cpu(model.state_dict())
         }
+        # save optimizer state dict in the checkpoint
+        if isinstance(optimizer, Optimizer):
+            checkpoint['optimizer'] = optimizer.state_dict()
+        elif isinstance(optimizer, dict):
+            checkpoint['optimizer'] = {}
+            for name, optim in optimizer.items():
+                checkpoint['optimizer'][name] = optim.state_dict()
+
+        # save lr_scheduler state dict in the checkpoint
+        assert isinstance(lr_scheduler, _LRScheduler), \
+            f'lr_scheduler to be saved should be a subclass of _LRScheduler, current is : {lr_scheduler.__class__}'
+        checkpoint['lr_scheduler'] = lr_scheduler.state_dict()
     else:
         checkpoint = weights_to_cpu(model.state_dict())

-    # save optimizer state dict in the checkpoint
-    if isinstance(optimizer, Optimizer):
-        checkpoint['optimizer'] = optimizer.state_dict()
-    elif isinstance(optimizer, dict):
-        checkpoint['optimizer'] = {}
-        for name, optim in optimizer.items():
-            checkpoint['optimizer'][name] = optim.state_dict()

     with io.BytesIO() as f:
         torch.save(checkpoint, f)
         File.write(f.getvalue(), filename)
+def load_checkpoint(filename,
+                    model,
+                    optimizer: Optimizer = None,
+                    lr_scheduler: _LRScheduler = None):
+    if not os.path.exists(filename):
+        raise ValueError(f'Checkpoint file {filename} does not exist!')
+    checkpoint = torch.load(filename, map_location='cpu')
+
+    if optimizer is not None:
+        if 'optimizer' in checkpoint:
+            if isinstance(optimizer, Optimizer):
+                optimizer.load_state_dict(checkpoint['optimizer'])
+            elif isinstance(optimizer, dict):
+                optimizer_dict = checkpoint['optimizer']
+                for key, optimizer_ins in optimizer.items():
+                    if key in optimizer_dict:
+                        optimizer_ins.load_state_dict(optimizer_dict[key])
+                    else:
+                        logger.warn(
+                            f'The state dict of optimizer {key} cannot be found in checkpoint file: {filename}'
+                        )
+        else:
+            logger.warn(
+                f'The state dict of optimizer cannot be found in checkpoint file: {filename}'
+            )
+
+    if lr_scheduler is not None:
+        if 'lr_scheduler' in checkpoint:
+            lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
+        else:
+            logger.warn(
+                f'The state dict of lr_scheduler cannot be found in checkpoint file: {filename}'
+            )
+
+    state_dict = checkpoint if 'state_dict' not in checkpoint else checkpoint[
+        'state_dict']
+    model.load_state_dict(state_dict)
+
+    if 'meta' in checkpoint:
+        return checkpoint.get('meta', {})
+
 def save_pretrained(model,
                     target_folder: Union[str, os.PathLike],
                     save_checkpoint_name: str = None,
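For context, the new `load_checkpoint` is the counterpart of the extended `save_checkpoint`: the file carries `meta`, `state_dict`, `optimizer` and `lr_scheduler` fields, and loading restores the states in place and returns the meta dict. A small standalone sketch (the file name and toy model are placeholders):

import torch

from modelscope.utils.checkpoint import load_checkpoint, save_checkpoint

model = torch.nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1)

# meta is extended with the modelscope version and save time automatically.
save_checkpoint(
    model, 'demo.pth', optimizer, lr_scheduler, meta={'epoch': 0, 'iter': 10})

# Restores the model/optimizer/lr_scheduler states and returns the meta dict.
meta = load_checkpoint('demo.pth', model, optimizer, lr_scheduler)
print(meta['epoch'], meta['iter'], meta.get('modelscope'))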
@@ -299,19 +299,23 @@ class MsRegressTool(RegressTool):
                          file_name,
                          level='config',
                          compare_fn=None,
-                         ignore_keys=None):
+                         ignore_keys=None,
+                         compare_random=True,
+                         lazy_stop_callback=None):

-        def lazy_stop_callback():
-            from modelscope.trainers.hooks.hook import Hook, Priority
-
-            class EarlyStopHook(Hook):
-                PRIORITY = Priority.VERY_LOW
-
-                def after_iter(self, trainer):
-                    raise MsRegressTool.EarlyStopError('Test finished.')
-
-            trainer.register_hook(EarlyStopHook())
+        if lazy_stop_callback is None:
+
+            def lazy_stop_callback():
+                from modelscope.trainers.hooks.hook import Hook, Priority
+
+                class EarlyStopHook(Hook):
+                    PRIORITY = Priority.VERY_LOW
+
+                    def after_iter(self, trainer):
+                        raise MsRegressTool.EarlyStopError('Test finished.')
+
+                trainer.register_hook(EarlyStopHook())

         def _train_loop(trainer, *args, **kwargs):
             with self.monitor_module_train(
@@ -320,6 +324,7 @@ class MsRegressTool(RegressTool):
                     level,
                     compare_fn=compare_fn,
                     ignore_keys=ignore_keys,
+                    compare_random=compare_random,
                     lazy_stop_callback=lazy_stop_callback):
                 try:
                     return trainer.train_loop_origin(*args, **kwargs)
@@ -1,8 +1,5 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
 # Part of the implementation is borrowed from huggingface/transformers.
-from collections.abc import Mapping
-
-import numpy as np


 def torch_nested_numpify(tensors):
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:2df2a5f3cdfc6dded52d31a8e97d9a9c41a803cb6d46dee709c51872eda37b21
-size 151830
@@ -11,7 +11,8 @@ from modelscope.models.nlp.sequence_classification import \
     SbertForSequenceClassification
 from modelscope.msdatasets import MsDataset
 from modelscope.pipelines import pipeline
-from modelscope.trainers import build_trainer
+from modelscope.trainers import EpochBasedTrainer, build_trainer
+from modelscope.utils.config import Config
 from modelscope.utils.constant import ModelFile, Tasks
 from modelscope.utils.hub import read_config
 from modelscope.utils.test_utils import test_level
@@ -119,6 +120,90 @@ class TestTrainerWithNlp(unittest.TestCase):
             checkpoint_path=os.path.join(self.tmp_dir, 'epoch_10.pth'))
         self.assertTrue(Metrics.accuracy in eval_results)

+    @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
+    def test_trainer_with_configured_datasets(self):
+        model_id = 'damo/nlp_structbert_sentence-similarity_chinese-base'
+        cfg: Config = read_config(model_id)
+        cfg.train.max_epochs = 20
+        cfg.train.work_dir = self.tmp_dir
+        cfg.dataset = {
+            'train': {
+                'name': 'afqmc_small',
+                'split': 'train',
+                'namespace': 'userxiaoming'
+            },
+            'val': {
+                'name': 'afqmc_small',
+                'split': 'train',
+                'namespace': 'userxiaoming'
+            },
+        }
+        cfg_file = os.path.join(self.tmp_dir, 'config.json')
+        cfg.dump(cfg_file)
+        kwargs = dict(model=model_id, cfg_file=cfg_file)
+
+        trainer = build_trainer(default_args=kwargs)
+        trainer.train()
+        results_files = os.listdir(self.tmp_dir)
+        self.assertIn(f'{trainer.timestamp}.log.json', results_files)
+        for i in range(cfg.train.max_epochs):
+            self.assertIn(f'epoch_{i+1}.pth', results_files)
+
+        eval_results = trainer.evaluate(
+            checkpoint_path=os.path.join(self.tmp_dir, 'epoch_10.pth'))
+        self.assertTrue(Metrics.accuracy in eval_results)
+
+    @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
+    def test_trainer_with_continue_train(self):
+        from modelscope.utils.regress_test_utils import MsRegressTool
+        model_id = 'damo/nlp_structbert_sentence-similarity_chinese-base'
+        cfg: Config = read_config(model_id)
+        cfg.train.max_epochs = 3
+        cfg.train.work_dir = self.tmp_dir
+        cfg_file = os.path.join(self.tmp_dir, 'config.json')
+        cfg.dump(cfg_file)
+        dataset = MsDataset.load('clue', subset_name='afqmc', split='train')
+        dataset = dataset.to_hf_dataset().select(range(128))
+        kwargs = dict(
+            model=model_id,
+            train_dataset=dataset,
+            eval_dataset=dataset,
+            cfg_file=cfg_file)
+
+        regress_tool = MsRegressTool(baseline=True)
+        trainer: EpochBasedTrainer = build_trainer(default_args=kwargs)
+
+        def lazy_stop_callback():
+            from modelscope.trainers.hooks.hook import Hook, Priority
+
+            class EarlyStopHook(Hook):
+                PRIORITY = Priority.VERY_LOW
+
+                def after_iter(self, trainer):
+                    if trainer.iter == 12:
+                        raise MsRegressTool.EarlyStopError('Test finished.')
+
+            if 'EarlyStopHook' not in [
+                    hook.__class__.__name__ for hook in trainer.hooks
+            ]:
+                trainer.register_hook(EarlyStopHook())
+
+        with regress_tool.monitor_ms_train(
+                trainer,
+                'trainer_continue_train',
+                level='strict',
+                lazy_stop_callback=lazy_stop_callback):
+            trainer.train()
+        results_files = os.listdir(self.tmp_dir)
+        self.assertIn(f'{trainer.timestamp}.log.json', results_files)
+
+        trainer = build_trainer(default_args=kwargs)
+        regress_tool = MsRegressTool(baseline=False)
+        with regress_tool.monitor_ms_train(
+                trainer, 'trainer_continue_train', level='strict'):
+            trainer.train(os.path.join(self.tmp_dir, 'iter_12.pth'))
+
     @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
     def test_trainer_with_model_and_args(self):
         tmp_dir = tempfile.TemporaryDirectory().name