1. Fix a bug in the trainer's progress bar
2. Fix a bug where the trainer does not support datasets configured in the config file
3. Add feature: support resuming training from a checkpoint file (see the usage sketch below)
4. Add feature: support a fixed filename when saving the best checkpoint
5. Fix a bug where id2label is missing from the config file after fine-tuning NLP models
6. Fix some other bugs
Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10138906
Branch: master
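A minimal usage sketch of the resume-from-checkpoint feature, adapted from the test_trainer_with_continue_train test added in this PR; the work directory and the checkpoint filenames are illustrative.

import os

from modelscope.msdatasets import MsDataset
from modelscope.trainers import build_trainer
from modelscope.utils.hub import read_config

model_id = 'damo/nlp_structbert_sentence-similarity_chinese-base'
work_dir = './tmp_work_dir'
os.makedirs(work_dir, exist_ok=True)

cfg = read_config(model_id)
cfg.train.max_epochs = 3
cfg.train.work_dir = work_dir
cfg_file = os.path.join(work_dir, 'config.json')
cfg.dump(cfg_file)

dataset = MsDataset.load('clue', subset_name='afqmc', split='train')
dataset = dataset.to_hf_dataset().select(range(128))

kwargs = dict(
    model=model_id,
    train_dataset=dataset,
    eval_dataset=dataset,
    cfg_file=cfg_file)

# First run: CheckpointHook now stores optimizer, lr_scheduler, RNG and
# per-hook states in the checkpoint meta alongside the model weights.
trainer = build_trainer(default_args=kwargs)
trainer.train()

# If the first run was interrupted, pass a saved checkpoint to train() to
# continue from that epoch/iteration; evaluate() accepts the same argument.
trainer = build_trainer(default_args=kwargs)
trainer.train(os.path.join(work_dir, 'epoch_1.pth'))
metrics = trainer.evaluate(
    checkpoint_path=os.path.join(work_dir, 'epoch_3.pth'))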
| @@ -14,9 +14,9 @@ from .builder import METRICS, MetricKeys | |||
| @METRICS.register_module( | |||
| group_key=default_group, module_name=Metrics.seq_cls_metric) | |||
| class SequenceClassificationMetric(Metric): | |||
| """The metric computation class for sequence classification classes. | |||
| """The metric computation class for sequence classification tasks. | |||
| This metric class calculates accuracy for the whole input batches. | |||
| This metric class calculates accuracy of the whole input batches. | |||
| """ | |||
| def __init__(self, *args, **kwargs): | |||
| @@ -1,14 +1,16 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| import os | |||
| import random | |||
| import json | |||
| import numpy as np | |||
| import torch | |||
| from modelscope import __version__ | |||
| from modelscope.metainfo import Hooks | |||
| from modelscope.utils.checkpoint import save_checkpoint | |||
| from modelscope.utils.checkpoint import load_checkpoint, save_checkpoint | |||
| from modelscope.utils.constant import LogKeys, ModelFile | |||
| from modelscope.utils.logger import get_logger | |||
| from modelscope.utils.torch_utils import is_master | |||
| from modelscope.utils.torch_utils import get_dist_info, is_master | |||
| from .builder import HOOKS | |||
| from .hook import Hook | |||
| from .priority import Priority | |||
| @@ -25,6 +27,7 @@ class CheckpointHook(Hook): | |||
| save_optimizer (bool): Whether to save optimizer state dict. Default: True. | |||
| save_dir (str): The directory to save checkpoints. If is None, use `trainer.work_dir` | |||
| save_last (bool): Whether to save the last checkpoint. Default: True. | |||
| checkpoint_file (str): The checkpoint file to be loaded. | |||
| """ | |||
| PRIORITY = Priority.LOW | |||
| @@ -34,12 +37,16 @@ class CheckpointHook(Hook): | |||
| by_epoch=True, | |||
| save_optimizer=True, | |||
| save_dir=None, | |||
| save_last=True): | |||
| save_last=True, | |||
| checkpoint_file=None): | |||
| self.interval = interval | |||
| self.by_epoch = by_epoch | |||
| self.save_optimizer = save_optimizer | |||
| self.save_dir = save_dir | |||
| self.checkpoint_file = checkpoint_file | |||
| self.save_last = save_last | |||
| self.rng_state = None | |||
| self.need_load_rng_state = False | |||
| def before_run(self, trainer): | |||
| if not self.save_dir: | |||
| @@ -56,6 +63,34 @@ class CheckpointHook(Hook): | |||
| if is_master(): | |||
| self.logger.info(f'Checkpoints will be saved to {self.save_dir}') | |||
| if self.checkpoint_file is not None and os.path.isfile( | |||
| self.checkpoint_file): | |||
| meta = self.load_checkpoint(self.checkpoint_file, trainer) | |||
| self.rng_state = meta.get('rng_state') | |||
| self.need_load_rng_state = True | |||
| def before_train_epoch(self, trainer): | |||
| if self.need_load_rng_state: | |||
| if self.rng_state is not None: | |||
| random.setstate(self.rng_state['random']) | |||
| np.random.set_state(self.rng_state['numpy']) | |||
| torch.random.set_rng_state(self.rng_state['cpu']) | |||
| if torch.cuda.is_available(): | |||
| torch.cuda.random.set_rng_state_all(self.rng_state['cuda']) | |||
| self.need_load_rng_state = False | |||
| else: | |||
| self.logger.warn( | |||
| 'Random state cannot be found in checkpoint file, ' | |||
| 'this may cause a random data order or model initialization.' | |||
| ) | |||
| self.rng_state = { | |||
| 'random': random.getstate(), | |||
| 'numpy': np.random.get_state(), | |||
| 'cpu': torch.random.get_rng_state(), | |||
| 'cuda': torch.cuda.get_rng_state_all(), | |||
| } | |||
| def after_train_epoch(self, trainer): | |||
| if not self.by_epoch: | |||
| return | |||
| @@ -66,6 +101,39 @@ class CheckpointHook(Hook): | |||
| f'Saving checkpoint at {trainer.epoch + 1} epoch') | |||
| self._save_checkpoint(trainer) | |||
| @classmethod | |||
| def load_checkpoint(cls, filename, trainer): | |||
| from modelscope.trainers.parallel.utils import is_parallel | |||
| if is_parallel(trainer.model): | |||
| model = trainer.model.module | |||
| else: | |||
| model = trainer.model | |||
| meta = load_checkpoint(filename, model, trainer.optimizer, | |||
| trainer.lr_scheduler) | |||
| trainer._epoch = meta.get('epoch', trainer._epoch) | |||
| trainer._iter = meta.get('iter', trainer._iter) | |||
| trainer._inner_iter = meta.get('inner_iter', trainer._inner_iter) | |||
| for i, hook in enumerate(trainer.hooks): | |||
| # hook: Hook | |||
| key = f'{hook.__class__}-{i}' | |||
| if key in meta: | |||
| hook.load_state_dict(meta[key]) | |||
| else: | |||
| trainer.logger( | |||
| f'The state_dict of hook {hook.__class__} at index {i} is not found in the checkpoint file.' | |||
| ) | |||
| version = meta.get('modelscope') | |||
| if version != __version__: | |||
| trainer.logger( | |||
| f'The modelscope version of loaded checkpoint does not match the runtime version. ' | |||
| f'The saved version: {version}, runtime version: {__version__}' | |||
| ) | |||
| trainer.logger( | |||
| f'Checkpoint {filename} saving time: {meta.get("time")}') | |||
| return meta | |||
| def _save_checkpoint(self, trainer): | |||
| if self.by_epoch: | |||
| cur_save_name = os.path.join( | |||
| @@ -74,7 +142,21 @@ class CheckpointHook(Hook): | |||
| cur_save_name = os.path.join( | |||
| self.save_dir, f'{LogKeys.ITER}_{trainer.iter + 1}.pth') | |||
| save_checkpoint(trainer.model, cur_save_name, trainer.optimizer) | |||
| meta = { | |||
| 'epoch': trainer.epoch, | |||
| 'iter': trainer.iter + 1, | |||
| 'inner_iter': trainer.inner_iter + 1, | |||
| 'rng_state': self.rng_state, | |||
| } | |||
| for i, hook in enumerate(trainer.hooks): | |||
| meta[f'{hook.__class__}-{i}'] = hook.state_dict() | |||
| save_checkpoint( | |||
| trainer.model, | |||
| cur_save_name, | |||
| trainer.optimizer, | |||
| trainer.lr_scheduler, | |||
| meta=meta) | |||
| if (self.is_last_epoch(trainer) | |||
| and self.by_epoch) or (self.is_last_iter(trainer) | |||
| and not self.by_epoch): | |||
| @@ -144,6 +226,7 @@ class BestCkptSaverHook(CheckpointHook): | |||
| by_epoch=True, | |||
| save_optimizer=True, | |||
| save_dir=None, | |||
| save_file_name=None, | |||
| interval=0): | |||
| assert rule in ['max', 'min'], 'Only support "max" or "min" rule now.' | |||
| super().__init__( | |||
| @@ -179,16 +262,44 @@ class BestCkptSaverHook(CheckpointHook): | |||
| return False | |||
| def _save_checkpoint(self, trainer): | |||
| if self.by_epoch: | |||
| cur_save_name = os.path.join( | |||
| self.save_dir, | |||
| f'best_{LogKeys.EPOCH}{trainer.epoch + 1}_{self.metric_key}{self._best_metric}.pth' | |||
| ) | |||
| else: | |||
| cur_save_name = os.path.join( | |||
| self.save_dir, | |||
| f'best_{LogKeys.ITER}{trainer.iter + 1}_{self.metric_key}{self._best_metric}.pth' | |||
| ) | |||
| save_checkpoint(trainer.model, cur_save_name, trainer.optimizer) | |||
| cur_save_name = self.save_file_name | |||
| if cur_save_name is None: | |||
| if self.by_epoch: | |||
| cur_save_name = os.path.join( | |||
| self.save_dir, | |||
| f'best_{LogKeys.EPOCH}{trainer.epoch + 1}_{self.metric_key}{self._best_metric}.pth' | |||
| ) | |||
| else: | |||
| cur_save_name = os.path.join( | |||
| self.save_dir, | |||
| f'best_{LogKeys.ITER}{trainer.iter + 1}_{self.metric_key}{self._best_metric}.pth' | |||
| ) | |||
| meta = { | |||
| 'epoch': trainer.epoch, | |||
| 'iter': trainer.iter + 1, | |||
| 'inner_iter': trainer.inner_iter + 1, | |||
| 'rng_state': self.rng_state, | |||
| } | |||
| for i, hook in enumerate(trainer.hooks): | |||
| meta[f'{hook.__class__}-{i}'] = hook.state_dict() | |||
| if os.path.isfile(cur_save_name): | |||
| os.remove(cur_save_name) | |||
| save_checkpoint(trainer.model, cur_save_name, trainer.optimizer, | |||
| trainer.lr_scheduler, meta) | |||
| self._best_ckpt_file = cur_save_name | |||
| self._save_pretrained(trainer) | |||
| def state_dict(self): | |||
| return { | |||
| 'best_metric': self._best_metric, | |||
| } | |||
| def load_state_dict(self, state_dict): | |||
| if state_dict is not None and len(state_dict) > 0: | |||
| self._best_metric = state_dict.get('best_metric') | |||
| else: | |||
| self.logger.warn( | |||
| 'The state_dict is not available, the best metric value will be affected.' | |||
| ) | |||
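With the new save_file_name argument, the best checkpoint can always be written to one fixed file instead of a name that encodes the epoch/iteration and metric value. A plausible configuration sketch; the hook type name, metric key and rule follow the usual ModelScope hook-config conventions and are not taken verbatim from this diff:

cfg.train.hooks = [
    # ... other hooks ...
    {
        'type': 'BestCkptSaverHook',
        'metric_key': 'accuracy',
        'rule': 'max',
        # New in this PR: overwrite a single file with the current best checkpoint.
        'save_file_name': 'best_model.pth',
    },
]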
| @@ -215,3 +215,9 @@ class Hook: | |||
| trigger_stages.add(stage) | |||
| return [stage for stage in Hook.stages if stage in trigger_stages] | |||
| def state_dict(self): | |||
| return {} | |||
| def load_state_dict(self, state_dict): | |||
| pass | |||
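The two new base-class methods above give every hook a state persistence protocol: CheckpointHook collects each hook.state_dict() into the checkpoint meta and feeds it back through load_state_dict() when training resumes. A hypothetical custom hook using the protocol (not part of this PR):

from modelscope.trainers.hooks.hook import Hook

class CounterHook(Hook):
    """Counts training iterations and survives a resumed run."""

    def __init__(self):
        self.seen_iters = 0

    def after_train_iter(self, trainer):
        self.seen_iters += 1

    def state_dict(self):
        # Stored by CheckpointHook in the checkpoint meta under a key derived
        # from the hook class and its index in trainer.hooks.
        return {'seen_iters': self.seen_iters}

    def load_state_dict(self, state_dict):
        self.seen_iters = state_dict.get('seen_iters', 0)

# Attach it like any other hook, e.g. trainer.register_hook(CounterHook()).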
| @@ -4,6 +4,7 @@ import logging | |||
| from torch.nn.utils import clip_grad | |||
| from modelscope.metainfo import Hooks | |||
| from modelscope.outputs import OutputKeys | |||
| from modelscope.trainers.hooks.builder import HOOKS | |||
| from modelscope.trainers.hooks.hook import Hook | |||
| from modelscope.trainers.hooks.priority import Priority | |||
| @@ -27,7 +28,7 @@ class OptimizerHook(Hook): | |||
| def __init__(self, | |||
| cumulative_iters=1, | |||
| grad_clip=None, | |||
| loss_keys='loss') -> None: | |||
| loss_keys=OutputKeys.LOSS) -> None: | |||
| if isinstance(loss_keys, str): | |||
| loss_keys = [loss_keys] | |||
| assert isinstance(loss_keys, (tuple, list)) | |||
| @@ -28,10 +28,10 @@ class BaseWarmup(_LRScheduler): | |||
| return self.base_scheduler.get_lr() | |||
| def state_dict(self): | |||
| self.base_scheduler.state_dict() | |||
| return self.base_scheduler.state_dict() | |||
| def load_state_dict(self, state_dict): | |||
| self.base_scheduler.load_state_dict(state_dict) | |||
| return self.base_scheduler.load_state_dict(state_dict) | |||
| def scale(self): | |||
| """Scale the learning rates. | |||
| @@ -1,6 +1,7 @@ | |||
| import os | |||
| from typing import Callable, Dict, Optional, Tuple, Union | |||
| from typing import Callable, Optional, Tuple, Union | |||
| import numpy as np | |||
| import torch | |||
| from torch import nn | |||
| from torch.utils.data import Dataset | |||
| @@ -11,9 +12,10 @@ from modelscope.metrics.builder import build_metric | |||
| from modelscope.models.base import Model, TorchModel | |||
| from modelscope.msdatasets import MsDataset | |||
| from modelscope.preprocessors import Preprocessor, build_preprocessor | |||
| from modelscope.utils.config import Config, ConfigDict | |||
| from modelscope.utils.config import Config | |||
| from modelscope.utils.constant import (DEFAULT_MODEL_REVISION, ModeKeys, | |||
| ModelFile, Tasks) | |||
| from modelscope.utils.hub import parse_label_mapping | |||
| from .base import TRAINERS | |||
| from .trainer import EpochBasedTrainer | |||
| @@ -81,19 +83,32 @@ class NlpEpochBasedTrainer(EpochBasedTrainer): | |||
| assert cfg_file is not None, 'Config file should not be None if model is an nn.Module class' | |||
| model_dir = os.path.dirname(cfg_file) | |||
| self.label2id = None | |||
| self.id2label = None | |||
| self.num_labels = None | |||
| self.cfg_modify_fn = cfg_modify_fn | |||
| self.cfg = self.rebuild_config(Config.from_file(cfg_file)) | |||
| try: | |||
| labels = self.cfg.dataset.train.labels | |||
| except AttributeError: | |||
| labels = None | |||
| self.label2id = None | |||
| self.num_labels = None | |||
| if labels is not None and len(labels) > 0: | |||
| self.label2id = {label: idx for idx, label in enumerate(labels)} | |||
| self.id2label = {idx: label for idx, label in enumerate(labels)} | |||
| self.num_labels = len(labels) | |||
| label2id = parse_label_mapping(model_dir) | |||
| if label2id is not None: | |||
| self.label2id = label2id | |||
| self.id2label = {id: label for label, id in label2id.items()} | |||
| self.num_labels = len(label2id) | |||
| else: | |||
| try: | |||
| labels = self.cfg.dataset.train.labels | |||
| if labels is not None and len(labels) > 0: | |||
| self.label2id = { | |||
| label: idx | |||
| for idx, label in enumerate(labels) | |||
| } | |||
| self.id2label = { | |||
| idx: label | |||
| for idx, label in enumerate(labels) | |||
| } | |||
| self.num_labels = len(labels) | |||
| except AttributeError: | |||
| pass | |||
| def build_dataset_keys(cfg): | |||
| if cfg is not None: | |||
| @@ -130,7 +145,13 @@ class NlpEpochBasedTrainer(EpochBasedTrainer): | |||
| def rebuild_config(self, cfg: Config): | |||
| if self.cfg_modify_fn is not None: | |||
| return self.cfg_modify_fn(cfg) | |||
| cfg = self.cfg_modify_fn(cfg) | |||
| if not hasattr(cfg.model, 'label2id') and not hasattr( | |||
| cfg.model, 'id2label'): | |||
| if self.id2label is not None: | |||
| cfg.model['id2label'] = self.id2label | |||
| if self.label2id is not None: | |||
| cfg.model['label2id'] = self.label2id | |||
| return cfg | |||
| def build_model(self) -> Union[nn.Module, TorchModel]: | |||
| @@ -203,6 +224,9 @@ class VecoTrainer(NlpEpochBasedTrainer): | |||
| """ | |||
| from modelscope.msdatasets.task_datasets import VecoDataset | |||
| if checkpoint_path is not None and os.path.isfile(checkpoint_path): | |||
| from modelscope.trainers.hooks import CheckpointHook | |||
| CheckpointHook.load_checkpoint(checkpoint_path, self) | |||
| self.model.eval() | |||
| self._mode = ModeKeys.EVAL | |||
| metric_values = {} | |||
| @@ -223,12 +247,10 @@ class VecoTrainer(NlpEpochBasedTrainer): | |||
| self.eval_dataset, **self.cfg.evaluation.get('dataloader', {})) | |||
| self.data_loader = self.eval_dataloader | |||
| metric_classes = [ | |||
| build_metric(metric, default_args={'trainer': self}) | |||
| for metric in self.metrics | |||
| ] | |||
| self.evaluation_loop(self.eval_dataloader, checkpoint_path, | |||
| metric_classes) | |||
| metric_classes = [build_metric(metric) for metric in self.metrics] | |||
| for m in metric_classes: | |||
| m.trainer = self | |||
| self.evaluation_loop(self.eval_dataloader, metric_classes) | |||
| for m_idx, metric_cls in enumerate(metric_classes): | |||
| if f'eval_dataset[{idx}]' not in metric_values: | |||
| @@ -242,4 +264,8 @@ class VecoTrainer(NlpEpochBasedTrainer): | |||
| else: | |||
| break | |||
| for metric_name in self.metrics: | |||
| metric_values[metric_name] = np.average( | |||
| [m[metric_name] for m in metric_values.values()]) | |||
| return metric_values | |||
| @@ -1,6 +1,5 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| import os | |||
| import random | |||
| import time | |||
| from collections.abc import Mapping | |||
| from distutils.version import LooseVersion | |||
| @@ -8,7 +7,6 @@ from functools import partial | |||
| from typing import Callable, Dict, List, Optional, Sequence, Tuple, Union | |||
| import json | |||
| import numpy as np | |||
| import torch | |||
| from torch import distributed as dist | |||
| from torch import nn | |||
| @@ -425,8 +423,16 @@ class EpochBasedTrainer(BaseTrainer): | |||
| metrics = [metrics] | |||
| return metrics | |||
| def train(self, *args, **kwargs): | |||
| self.model.train() | |||
| def set_checkpoint_file_to_hook(self, checkpoint_path): | |||
| if checkpoint_path is not None and os.path.isfile(checkpoint_path): | |||
| from modelscope.trainers.hooks import CheckpointHook | |||
| checkpoint_hooks = list( | |||
| filter(lambda hook: isinstance(hook, CheckpointHook), | |||
| self.hooks)) | |||
| for hook in checkpoint_hooks: | |||
| hook.checkpoint_file = checkpoint_path | |||
| def train(self, checkpoint_path=None, *args, **kwargs): | |||
| self._mode = ModeKeys.TRAIN | |||
| if self.train_dataset is None: | |||
| @@ -442,13 +448,17 @@ class EpochBasedTrainer(BaseTrainer): | |||
| self.register_optimizers_hook() | |||
| self.register_hook_from_cfg(self.cfg.train.hooks) | |||
| self.set_checkpoint_file_to_hook(checkpoint_path) | |||
| self.model.train() | |||
| self.train_loop(self.train_dataloader) | |||
| def evaluate(self, checkpoint_path=None): | |||
| if checkpoint_path is not None and os.path.isfile(checkpoint_path): | |||
| from modelscope.trainers.hooks import CheckpointHook | |||
| CheckpointHook.load_checkpoint(checkpoint_path, self) | |||
| self.model.eval() | |||
| self._mode = ModeKeys.EVAL | |||
| if self.eval_dataset is None: | |||
| self.eval_dataloader = self.get_eval_data_loader() | |||
| else: | |||
| @@ -462,8 +472,9 @@ class EpochBasedTrainer(BaseTrainer): | |||
| metric_classes = [build_metric(metric) for metric in self.metrics] | |||
| for m in metric_classes: | |||
| m.trainer = self | |||
| metric_values = self.evaluation_loop(self.eval_dataloader, | |||
| checkpoint_path, metric_classes) | |||
| metric_classes) | |||
| self._metric_values = metric_values | |||
| return metric_values | |||
| @@ -631,18 +642,13 @@ class EpochBasedTrainer(BaseTrainer): | |||
| if hasattr(data_cfg, 'name'): | |||
| dataset = MsDataset.load( | |||
| dataset_name=data_cfg.name, | |||
| split=data_cfg.split, | |||
| subset_name=data_cfg.subset_name if hasattr( | |||
| data_cfg, 'subset_name') else None, | |||
| hub=data_cfg.hub | |||
| if hasattr(data_cfg, 'hub') else Hubs.modelscope, | |||
| **data_cfg, | |||
| ) | |||
| cfg = ConfigDict(type=self.cfg.model.type, mode=mode) | |||
| torch_dataset = dataset.to_torch_dataset( | |||
| task_data_config=cfg, | |||
| task_name=self.cfg.task, | |||
| preprocessors=self.preprocessor) | |||
| preprocessors=preprocessor) | |||
| else: | |||
| torch_dataset = build_task_dataset(data_cfg, self.cfg.task) | |||
| dataset = self.to_task_dataset(torch_dataset, mode) | |||
| @@ -802,19 +808,22 @@ class EpochBasedTrainer(BaseTrainer): | |||
| """ Training loop used by `EpochBasedTrainer.train()` | |||
| """ | |||
| self.invoke_hook(TrainerStages.before_run) | |||
| self._epoch = 0 | |||
| kwargs = {} | |||
| self.model.train() | |||
| for _ in range(self._epoch, self._max_epochs): | |||
| self.invoke_hook(TrainerStages.before_train_epoch) | |||
| time.sleep(2) # Prevent possible deadlock during epoch transition | |||
| for i, data_batch in enumerate(data_loader): | |||
| if i < self.inner_iter: | |||
| # inner_iter may be read out from the checkpoint file, so skip the trained iters in the epoch. | |||
| continue | |||
| data_batch = to_device(data_batch, self.device) | |||
| self.data_batch = data_batch | |||
| self._inner_iter = i | |||
| self.invoke_hook(TrainerStages.before_train_iter) | |||
| self.train_step(self.model, data_batch, **kwargs) | |||
| self.invoke_hook(TrainerStages.after_train_iter) | |||
| # Value changed after the hooks are invoked, do not move them above the invoke_hook code. | |||
| del self.data_batch | |||
| self._iter += 1 | |||
| self._mode = ModeKeys.TRAIN | |||
| @@ -823,12 +832,14 @@ class EpochBasedTrainer(BaseTrainer): | |||
| break | |||
| self.invoke_hook(TrainerStages.after_train_epoch) | |||
| # Value changed after the hooks are invoked, do not move them above the invoke_hook code. | |||
| self._inner_iter = 0 | |||
| self._epoch += 1 | |||
| time.sleep(1) # wait for some hooks like loggers to finish | |||
| self.invoke_hook(TrainerStages.after_run) | |||
| def evaluation_loop(self, data_loader, checkpoint_path, metric_classes): | |||
| def evaluation_loop(self, data_loader, metric_classes): | |||
| """ Evaluation loop used by `EpochBasedTrainer.evaluate()`. | |||
| """ | |||
| @@ -841,7 +852,7 @@ class EpochBasedTrainer(BaseTrainer): | |||
| tmpdir=None, | |||
| gpu_collect=False, | |||
| metric_classes=metric_classes, | |||
| data_loader_iters_per_gpu=self.iters_per_epoch) | |||
| data_loader_iters_per_gpu=self._eval_iters_per_epoch) | |||
| else: | |||
| from modelscope.trainers.utils.inference import single_gpu_test | |||
| metric_values = single_gpu_test( | |||
| @@ -849,7 +860,7 @@ class EpochBasedTrainer(BaseTrainer): | |||
| data_loader, | |||
| device=self.device, | |||
| metric_classes=metric_classes, | |||
| data_loader_iters=self.iters_per_epoch) | |||
| data_loader_iters=self._eval_iters_per_epoch) | |||
| self._inner_iter = self.iters_per_epoch - 1 # start from index 0 | |||
| @@ -8,14 +8,17 @@ from shutil import copytree, ignore_patterns, rmtree | |||
| from typing import Callable, List, Optional, Union | |||
| import json | |||
| import numpy as np | |||
| import torch | |||
| from torch.optim import Optimizer | |||
| from torch.optim.lr_scheduler import _LRScheduler | |||
| from modelscope import __version__ | |||
| from modelscope.fileio import File, LocalStorage | |||
| from modelscope.utils.config import JSONIteratorEncoder | |||
| from modelscope.utils.constant import ConfigFields, ModelFile | |||
| from modelscope.utils.logger import get_logger | |||
| logger = get_logger(__name__) | |||
| storage = LocalStorage() | |||
| @@ -40,24 +43,27 @@ def weights_to_cpu(state_dict): | |||
| def save_checkpoint(model: torch.nn.Module, | |||
| filename: str, | |||
| optimizer: Optional[Optimizer] = None, | |||
| lr_scheduler: Optional[_LRScheduler] = None, | |||
| meta: Optional[dict] = None, | |||
| with_meta: bool = True) -> None: | |||
| """Save checkpoint to file. | |||
| The checkpoint will have 3 fields: ``meta``, ``state_dict`` and | |||
| ``optimizer``. By default ``meta`` will contain version and time info. | |||
| ``optimizer``. By default, ``meta`` will contain version and time info. | |||
| Args: | |||
| model (Module): Module whose params are to be saved. | |||
| filename (str): Checkpoint filename. | |||
| optimizer (:obj:`Optimizer`, optional): Optimizer to be saved. | |||
| lr_scheduler(:obj:`_LRScheduler`, optional): LRScheduler to be saved. | |||
| meta (dict, optional): Metadata to be saved in checkpoint. | |||
| with_meta (bool, optional): | |||
| """ | |||
| if meta is None: | |||
| meta = {} | |||
| elif not isinstance(meta, dict): | |||
| raise TypeError(f'meta must be a dict or None, but got {type(meta)}') | |||
| meta.update(modescope=__version__, time=time.asctime()) | |||
| meta.update(modelscope=__version__, time=time.asctime()) | |||
| if isinstance(model, torch.nn.parallel.DistributedDataParallel): | |||
| model = model.module | |||
| @@ -71,22 +77,69 @@ def save_checkpoint(model: torch.nn.Module, | |||
| 'meta': meta, | |||
| 'state_dict': weights_to_cpu(model.state_dict()) | |||
| } | |||
| # save optimizer state dict in the checkpoint | |||
| if isinstance(optimizer, Optimizer): | |||
| checkpoint['optimizer'] = optimizer.state_dict() | |||
| elif isinstance(optimizer, dict): | |||
| checkpoint['optimizer'] = {} | |||
| for name, optim in optimizer.items(): | |||
| checkpoint['optimizer'][name] = optim.state_dict() | |||
| # save lr_scheduler state dict in the checkpoint | |||
| assert isinstance(lr_scheduler, _LRScheduler), \ | |||
| f'lr_scheduler to be saved should be a subclass of _LRScheduler, current is : {lr_scheduler.__class__}' | |||
| checkpoint['lr_scheduler'] = lr_scheduler.state_dict() | |||
| else: | |||
| checkpoint = weights_to_cpu(model.state_dict()) | |||
| # save optimizer state dict in the checkpoint | |||
| if isinstance(optimizer, Optimizer): | |||
| checkpoint['optimizer'] = optimizer.state_dict() | |||
| elif isinstance(optimizer, dict): | |||
| checkpoint['optimizer'] = {} | |||
| for name, optim in optimizer.items(): | |||
| checkpoint['optimizer'][name] = optim.state_dict() | |||
| with io.BytesIO() as f: | |||
| torch.save(checkpoint, f) | |||
| File.write(f.getvalue(), filename) | |||
| def load_checkpoint(filename, | |||
| model, | |||
| optimizer: Optimizer = None, | |||
| lr_scheduler: _LRScheduler = None): | |||
| if not os.path.exists(filename): | |||
| raise ValueError(f'Checkpoint file {filename} does not exist!') | |||
| checkpoint = torch.load(filename, map_location='cpu') | |||
| if optimizer is not None: | |||
| if 'optimizer' in checkpoint: | |||
| if isinstance(optimizer, Optimizer): | |||
| optimizer.load_state_dict(checkpoint['optimizer']) | |||
| elif isinstance(optimizer, dict): | |||
| optimizer_dict = checkpoint['optimizer'] | |||
| for key, optimizer_ins in optimizer.items(): | |||
| if key in optimizer_dict: | |||
| optimizer_ins.load_state_dict(optimizer_dict[key]) | |||
| else: | |||
| logger.warn( | |||
| f'The state dict of optimizer {key} cannot be found in checkpoint file: {filename}' | |||
| ) | |||
| else: | |||
| logger.warn( | |||
| f'The state dict of optimizer cannot be found in checkpoint file: {filename}' | |||
| ) | |||
| if lr_scheduler is not None: | |||
| if 'lr_scheduler' in checkpoint: | |||
| lr_scheduler.load_state_dict(checkpoint['lr_scheduler']) | |||
| else: | |||
| logger.warn( | |||
| f'The state dict of lr_scheduler cannot be found in checkpoint file: {filename}' | |||
| ) | |||
| state_dict = checkpoint if 'state_dict' not in checkpoint else checkpoint[ | |||
| 'state_dict'] | |||
| model.load_state_dict(state_dict) | |||
| if 'meta' in checkpoint: | |||
| return checkpoint.get('meta', {}) | |||
| def save_pretrained(model, | |||
| target_folder: Union[str, os.PathLike], | |||
| save_checkpoint_name: str = None, | |||
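A minimal round-trip sketch for the extended save_checkpoint/load_checkpoint helpers above; the toy model, optimizer and path are placeholders:

import torch
from torch.optim import SGD
from torch.optim.lr_scheduler import StepLR

from modelscope.utils.checkpoint import load_checkpoint, save_checkpoint

model = torch.nn.Linear(4, 2)
optimizer = SGD(model.parameters(), lr=0.1)
lr_scheduler = StepLR(optimizer, step_size=1)

# Writes {'meta': ..., 'state_dict': ..., 'optimizer': ..., 'lr_scheduler': ...}.
save_checkpoint(model, '/tmp/demo.pth', optimizer, lr_scheduler,
                meta={'epoch': 1, 'iter': 100})

# Restores the three objects in place and returns the saved meta dict.
meta = load_checkpoint('/tmp/demo.pth', model, optimizer, lr_scheduler)
print(meta['epoch'], meta['iter'], meta['modelscope'])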
| @@ -299,19 +299,23 @@ class MsRegressTool(RegressTool): | |||
| file_name, | |||
| level='config', | |||
| compare_fn=None, | |||
| ignore_keys=None): | |||
| ignore_keys=None, | |||
| compare_random=True, | |||
| lazy_stop_callback=None): | |||
| def lazy_stop_callback(): | |||
| if lazy_stop_callback is None: | |||
| from modelscope.trainers.hooks.hook import Hook, Priority | |||
| def lazy_stop_callback(): | |||
| class EarlyStopHook(Hook): | |||
| PRIORITY = Priority.VERY_LOW | |||
| from modelscope.trainers.hooks.hook import Hook, Priority | |||
| def after_iter(self, trainer): | |||
| raise MsRegressTool.EarlyStopError('Test finished.') | |||
| class EarlyStopHook(Hook): | |||
| PRIORITY = Priority.VERY_LOW | |||
| trainer.register_hook(EarlyStopHook()) | |||
| def after_iter(self, trainer): | |||
| raise MsRegressTool.EarlyStopError('Test finished.') | |||
| trainer.register_hook(EarlyStopHook()) | |||
| def _train_loop(trainer, *args, **kwargs): | |||
| with self.monitor_module_train( | |||
| @@ -320,6 +324,7 @@ class MsRegressTool(RegressTool): | |||
| level, | |||
| compare_fn=compare_fn, | |||
| ignore_keys=ignore_keys, | |||
| compare_random=compare_random, | |||
| lazy_stop_callback=lazy_stop_callback): | |||
| try: | |||
| return trainer.train_loop_origin(*args, **kwargs) | |||
| @@ -1,8 +1,5 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| # Part of the implementation is borrowed from huggingface/transformers. | |||
| from collections.abc import Mapping | |||
| import numpy as np | |||
| def torch_nested_numpify(tensors): | |||
| @@ -1,3 +0,0 @@ | |||
| version https://git-lfs.github.com/spec/v1 | |||
| oid sha256:2df2a5f3cdfc6dded52d31a8e97d9a9c41a803cb6d46dee709c51872eda37b21 | |||
| size 151830 | |||
| @@ -11,7 +11,8 @@ from modelscope.models.nlp.sequence_classification import \ | |||
| SbertForSequenceClassification | |||
| from modelscope.msdatasets import MsDataset | |||
| from modelscope.pipelines import pipeline | |||
| from modelscope.trainers import build_trainer | |||
| from modelscope.trainers import EpochBasedTrainer, build_trainer | |||
| from modelscope.utils.config import Config | |||
| from modelscope.utils.constant import ModelFile, Tasks | |||
| from modelscope.utils.hub import read_config | |||
| from modelscope.utils.test_utils import test_level | |||
| @@ -119,6 +120,90 @@ class TestTrainerWithNlp(unittest.TestCase): | |||
| checkpoint_path=os.path.join(self.tmp_dir, 'epoch_10.pth')) | |||
| self.assertTrue(Metrics.accuracy in eval_results) | |||
| @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') | |||
| def test_trainer_with_configured_datasets(self): | |||
| model_id = 'damo/nlp_structbert_sentence-similarity_chinese-base' | |||
| cfg: Config = read_config(model_id) | |||
| cfg.train.max_epochs = 20 | |||
| cfg.train.work_dir = self.tmp_dir | |||
| cfg.dataset = { | |||
| 'train': { | |||
| 'name': 'afqmc_small', | |||
| 'split': 'train', | |||
| 'namespace': 'userxiaoming' | |||
| }, | |||
| 'val': { | |||
| 'name': 'afqmc_small', | |||
| 'split': 'train', | |||
| 'namespace': 'userxiaoming' | |||
| }, | |||
| } | |||
| cfg_file = os.path.join(self.tmp_dir, 'config.json') | |||
| cfg.dump(cfg_file) | |||
| kwargs = dict(model=model_id, cfg_file=cfg_file) | |||
| trainer = build_trainer(default_args=kwargs) | |||
| trainer.train() | |||
| results_files = os.listdir(self.tmp_dir) | |||
| self.assertIn(f'{trainer.timestamp}.log.json', results_files) | |||
| for i in range(cfg.train.max_epochs): | |||
| self.assertIn(f'epoch_{i+1}.pth', results_files) | |||
| eval_results = trainer.evaluate( | |||
| checkpoint_path=os.path.join(self.tmp_dir, 'epoch_10.pth')) | |||
| self.assertTrue(Metrics.accuracy in eval_results) | |||
| @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') | |||
| def test_trainer_with_continue_train(self): | |||
| from modelscope.utils.regress_test_utils import MsRegressTool | |||
| model_id = 'damo/nlp_structbert_sentence-similarity_chinese-base' | |||
| cfg: Config = read_config(model_id) | |||
| cfg.train.max_epochs = 3 | |||
| cfg.train.work_dir = self.tmp_dir | |||
| cfg_file = os.path.join(self.tmp_dir, 'config.json') | |||
| cfg.dump(cfg_file) | |||
| dataset = MsDataset.load('clue', subset_name='afqmc', split='train') | |||
| dataset = dataset.to_hf_dataset().select(range(128)) | |||
| kwargs = dict( | |||
| model=model_id, | |||
| train_dataset=dataset, | |||
| eval_dataset=dataset, | |||
| cfg_file=cfg_file) | |||
| regress_tool = MsRegressTool(baseline=True) | |||
| trainer: EpochBasedTrainer = build_trainer(default_args=kwargs) | |||
| def lazy_stop_callback(): | |||
| from modelscope.trainers.hooks.hook import Hook, Priority | |||
| class EarlyStopHook(Hook): | |||
| PRIORITY = Priority.VERY_LOW | |||
| def after_iter(self, trainer): | |||
| if trainer.iter == 12: | |||
| raise MsRegressTool.EarlyStopError('Test finished.') | |||
| if 'EarlyStopHook' not in [ | |||
| hook.__class__.__name__ for hook in trainer.hooks | |||
| ]: | |||
| trainer.register_hook(EarlyStopHook()) | |||
| with regress_tool.monitor_ms_train( | |||
| trainer, | |||
| 'trainer_continue_train', | |||
| level='strict', | |||
| lazy_stop_callback=lazy_stop_callback): | |||
| trainer.train() | |||
| results_files = os.listdir(self.tmp_dir) | |||
| self.assertIn(f'{trainer.timestamp}.log.json', results_files) | |||
| trainer = build_trainer(default_args=kwargs) | |||
| regress_tool = MsRegressTool(baseline=False) | |||
| with regress_tool.monitor_ms_train( | |||
| trainer, 'trainer_continue_train', level='strict'): | |||
| trainer.train(os.path.join(self.tmp_dir, 'iter_12.pth')) | |||
| @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') | |||
| def test_trainer_with_model_and_args(self): | |||
| tmp_dir = tempfile.TemporaryDirectory().name | |||