From 4cdd0c23eb589d05ad9c53ffda33865b1e3bbb0b Mon Sep 17 00:00:00 2001
From: "yuze.zyz" 
Date: Mon, 19 Sep 2022 17:05:35 +0800
Subject: [PATCH] [to #42322933] Refactor and fix some bugs

1. Fix a bug in the trainer's progress bar
2. Fix a bug where the trainer did not support datasets defined in the config file
3. Add feature: support resuming training from a checkpoint file
4. Add feature: support a fixed filename when saving the best checkpoint
5. Fix a bug where id2label was missing from the config file after finetuning NLP models
6. Fix some other bugs

Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10138906
---
 .../metrics/sequence_classification_metric.py |   4 +-
 modelscope/trainers/hooks/checkpoint_hook.py  | 143 ++++++++++++++++--
 modelscope/trainers/hooks/hook.py             |   6 +
 modelscope/trainers/hooks/optimizer/base.py   |   3 +-
 .../trainers/lrscheduler/warmup/base.py       |   4 +-
 modelscope/trainers/nlp_trainer.py            |  64 +++++---
 modelscope/trainers/trainer.py                |  43 ++++--
 modelscope/utils/checkpoint.py                |  75 +++++++--
 modelscope/utils/regress_test_utils.py        |  21 ++-
 modelscope/utils/tensor_utils.py              |   3 -
 .../data/test/regression/sbert-base-tnews.bin |   3 -
 tests/trainers/test_trainer_with_nlp.py       |  87 ++++++++++-
 12 files changed, 374 insertions(+), 82 deletions(-)
 delete mode 100644 tests/trainers/data/test/regression/sbert-base-tnews.bin

diff --git a/modelscope/metrics/sequence_classification_metric.py b/modelscope/metrics/sequence_classification_metric.py
index 83cb39ca..d795d8a2 100644
--- a/modelscope/metrics/sequence_classification_metric.py
+++ b/modelscope/metrics/sequence_classification_metric.py
@@ -14,9 +14,9 @@ from .builder import METRICS, MetricKeys
 @METRICS.register_module(
     group_key=default_group, module_name=Metrics.seq_cls_metric)
 class SequenceClassificationMetric(Metric):
-    """The metric computation class for sequence classification classes.
+    """The metric computation class for sequence classification tasks.
 
-    This metric class calculates accuracy for the whole input batches.
+    This metric class calculates accuracy over all input batches.
     """
 
     def __init__(self, *args, **kwargs):
diff --git a/modelscope/trainers/hooks/checkpoint_hook.py b/modelscope/trainers/hooks/checkpoint_hook.py
index fcd8e982..a9b793d4 100644
--- a/modelscope/trainers/hooks/checkpoint_hook.py
+++ b/modelscope/trainers/hooks/checkpoint_hook.py
@@ -1,14 +1,16 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
 import os
+import random
 
-import json
+import numpy as np
+import torch
 
 from modelscope import __version__
 from modelscope.metainfo import Hooks
-from modelscope.utils.checkpoint import save_checkpoint
+from modelscope.utils.checkpoint import load_checkpoint, save_checkpoint
 from modelscope.utils.constant import LogKeys, ModelFile
 from modelscope.utils.logger import get_logger
-from modelscope.utils.torch_utils import is_master
+from modelscope.utils.torch_utils import get_dist_info, is_master
 from .builder import HOOKS
 from .hook import Hook
 from .priority import Priority
@@ -25,6 +27,7 @@ class CheckpointHook(Hook):
         save_optimizer (bool): Whether to save optimizer state dict. Default: True.
         save_dir (str): The directory to save checkpoints. If is None, use `trainer.work_dir`
         save_last (bool): Whether to save the last checkpoint. Default: True.
+        checkpoint_file (str): The checkpoint file to resume training from. Default: None.
""" PRIORITY = Priority.LOW @@ -34,12 +37,16 @@ class CheckpointHook(Hook): by_epoch=True, save_optimizer=True, save_dir=None, - save_last=True): + save_last=True, + checkpoint_file=None): self.interval = interval self.by_epoch = by_epoch self.save_optimizer = save_optimizer self.save_dir = save_dir + self.checkpoint_file = checkpoint_file self.save_last = save_last + self.rng_state = None + self.need_load_rng_state = False def before_run(self, trainer): if not self.save_dir: @@ -56,6 +63,34 @@ class CheckpointHook(Hook): if is_master(): self.logger.info(f'Checkpoints will be saved to {self.save_dir}') + if self.checkpoint_file is not None and os.path.isfile( + self.checkpoint_file): + meta = self.load_checkpoint(self.checkpoint_file, trainer) + self.rng_state = meta.get('rng_state') + self.need_load_rng_state = True + + def before_train_epoch(self, trainer): + if self.need_load_rng_state: + if self.rng_state is not None: + random.setstate(self.rng_state['random']) + np.random.set_state(self.rng_state['numpy']) + torch.random.set_rng_state(self.rng_state['cpu']) + if torch.cuda.is_available(): + torch.cuda.random.set_rng_state_all(self.rng_state['cuda']) + self.need_load_rng_state = False + else: + self.logger.warn( + 'Random state cannot be found in checkpoint file, ' + 'this may cause a random data order or model initialization.' + ) + + self.rng_state = { + 'random': random.getstate(), + 'numpy': np.random.get_state(), + 'cpu': torch.random.get_rng_state(), + 'cuda': torch.cuda.get_rng_state_all(), + } + def after_train_epoch(self, trainer): if not self.by_epoch: return @@ -66,6 +101,39 @@ class CheckpointHook(Hook): f'Saving checkpoint at {trainer.epoch + 1} epoch') self._save_checkpoint(trainer) + @classmethod + def load_checkpoint(cls, filename, trainer): + from modelscope.trainers.parallel.utils import is_parallel + if is_parallel(trainer.model): + model = trainer.model.module + else: + model = trainer.model + meta = load_checkpoint(filename, model, trainer.optimizer, + trainer.lr_scheduler) + trainer._epoch = meta.get('epoch', trainer._epoch) + trainer._iter = meta.get('iter', trainer._iter) + trainer._inner_iter = meta.get('inner_iter', trainer._inner_iter) + + for i, hook in enumerate(trainer.hooks): + # hook: Hook + key = f'{hook.__class__}-{i}' + if key in meta: + hook.load_state_dict(meta[key]) + else: + trainer.logger( + f'The state_dict of hook {hook.__class__} at index {i} is not found in the checkpoint file.' + ) + + version = meta.get('modelscope') + if version != __version__: + trainer.logger( + f'The modelscope version of loaded checkpoint does not match the runtime version. 
+                f'The saved version: {version}, runtime version: {__version__}'
+            )
+        trainer.logger.info(
+            f'Checkpoint {filename} saving time: {meta.get("time")}')
+        return meta
+
     def _save_checkpoint(self, trainer):
         if self.by_epoch:
             cur_save_name = os.path.join(
@@ -74,7 +142,21 @@ class CheckpointHook(Hook):
             cur_save_name = os.path.join(
                 self.save_dir, f'{LogKeys.ITER}_{trainer.iter + 1}.pth')

-        save_checkpoint(trainer.model, cur_save_name, trainer.optimizer)
+        meta = {
+            'epoch': trainer.epoch,
+            'iter': trainer.iter + 1,
+            'inner_iter': trainer.inner_iter + 1,
+            'rng_state': self.rng_state,
+        }
+        for i, hook in enumerate(trainer.hooks):
+            meta[f'{hook.__class__}-{i}'] = hook.state_dict()
+
+        save_checkpoint(
+            trainer.model,
+            cur_save_name,
+            trainer.optimizer,
+            trainer.lr_scheduler,
+            meta=meta)

         if (self.is_last_epoch(trainer)
                 and self.by_epoch) or (self.is_last_iter(trainer)
                                        and not self.by_epoch):
@@ -144,6 +226,7 @@ class BestCkptSaverHook(CheckpointHook):
                  by_epoch=True,
                  save_optimizer=True,
                  save_dir=None,
+                 save_file_name=None,
                  interval=0):
         assert rule in ['max', 'min'], 'Only support "max" or "min" rule now.'
         super().__init__(
@@ -179,16 +262,44 @@ class BestCkptSaverHook(CheckpointHook):
             return False

     def _save_checkpoint(self, trainer):
-        if self.by_epoch:
-            cur_save_name = os.path.join(
-                self.save_dir,
-                f'best_{LogKeys.EPOCH}{trainer.epoch + 1}_{self.metric_key}{self._best_metric}.pth'
-            )
-        else:
-            cur_save_name = os.path.join(
-                self.save_dir,
-                f'best_{LogKeys.ITER}{trainer.iter + 1}_{self.metric_key}{self._best_metric}.pth'
-            )
-        save_checkpoint(trainer.model, cur_save_name, trainer.optimizer)
+        cur_save_name = self.save_file_name
+        if cur_save_name is None:
+            if self.by_epoch:
+                cur_save_name = os.path.join(
+                    self.save_dir,
+                    f'best_{LogKeys.EPOCH}{trainer.epoch + 1}_{self.metric_key}{self._best_metric}.pth'
+                )
+            else:
+                cur_save_name = os.path.join(
+                    self.save_dir,
+                    f'best_{LogKeys.ITER}{trainer.iter + 1}_{self.metric_key}{self._best_metric}.pth'
+                )
+
+        meta = {
+            'epoch': trainer.epoch,
+            'iter': trainer.iter + 1,
+            'inner_iter': trainer.inner_iter + 1,
+            'rng_state': self.rng_state,
+        }
+        for i, hook in enumerate(trainer.hooks):
+            meta[f'{hook.__class__}-{i}'] = hook.state_dict()
+
+        if os.path.isfile(cur_save_name):
+            os.remove(cur_save_name)
+        save_checkpoint(trainer.model, cur_save_name, trainer.optimizer,
+                        trainer.lr_scheduler, meta)
         self._best_ckpt_file = cur_save_name
         self._save_pretrained(trainer)
+
+    def state_dict(self):
+        return {
+            'best_metric': self._best_metric,
+        }
+
+    def load_state_dict(self, state_dict):
+        if state_dict is not None and len(state_dict) > 0:
+            self._best_metric = state_dict.get('best_metric')
+        else:
+            self.logger.warn(
+                'The state_dict is not available; the best metric value cannot be restored.'
+            )
diff --git a/modelscope/trainers/hooks/hook.py b/modelscope/trainers/hooks/hook.py
index 1c567f1c..d3805be8 100644
--- a/modelscope/trainers/hooks/hook.py
+++ b/modelscope/trainers/hooks/hook.py
@@ -215,3 +215,9 @@ class Hook:
                 trigger_stages.add(stage)

         return [stage for stage in Hook.stages if stage in trigger_stages]
+
+    def state_dict(self):
+        return {}
+
+    def load_state_dict(self, state_dict):
+        pass
diff --git a/modelscope/trainers/hooks/optimizer/base.py b/modelscope/trainers/hooks/optimizer/base.py
index dffad6ea..8c61dfdb 100644
--- a/modelscope/trainers/hooks/optimizer/base.py
+++ b/modelscope/trainers/hooks/optimizer/base.py
@@ -4,6 +4,7 @@ import logging
 from torch.nn.utils import clip_grad

 from modelscope.metainfo import Hooks
+from modelscope.outputs import OutputKeys
 from modelscope.trainers.hooks.builder import HOOKS
 from modelscope.trainers.hooks.hook import Hook
 from modelscope.trainers.hooks.priority import Priority
@@ -27,7 +28,7 @@ class OptimizerHook(Hook):
     def __init__(self,
                  cumulative_iters=1,
                  grad_clip=None,
-                 loss_keys='loss') -> None:
+                 loss_keys=OutputKeys.LOSS) -> None:
         if isinstance(loss_keys, str):
             loss_keys = [loss_keys]
         assert isinstance(loss_keys, (tuple, list))
diff --git a/modelscope/trainers/lrscheduler/warmup/base.py b/modelscope/trainers/lrscheduler/warmup/base.py
index 81497817..4b066281 100644
--- a/modelscope/trainers/lrscheduler/warmup/base.py
+++ b/modelscope/trainers/lrscheduler/warmup/base.py
@@ -28,10 +28,10 @@ class BaseWarmup(_LRScheduler):
         return self.base_scheduler.get_lr()

     def state_dict(self):
-        self.base_scheduler.state_dict()
+        return self.base_scheduler.state_dict()

     def load_state_dict(self, state_dict):
-        self.base_scheduler.load_state_dict(state_dict)
+        return self.base_scheduler.load_state_dict(state_dict)

     def scale(self):
         """Scale the learning rates.
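To make the new hook-state protocol concrete, here is a minimal sketch (an illustration, not part of the patch) of how per-hook state round-trips through the checkpoint `meta` dict. The `{hook.__class__}-{i}` key scheme mirrors the `CheckpointHook` hunks above; `trainer` stands for any object exposing a `hooks` list, and the two helper names are hypothetical.

    # Sketch only: the per-hook state protocol introduced above.
    def collect_hook_states(trainer) -> dict:
        # Saving: each hook contributes its state under a class-and-index key;
        # stateless hooks contribute the default empty dict from Hook.state_dict().
        meta = {}
        for i, hook in enumerate(trainer.hooks):
            meta[f'{hook.__class__}-{i}'] = hook.state_dict()
        return meta

    def restore_hook_states(trainer, meta: dict) -> None:
        # Loading: each hook receives back exactly the slice it saved,
        # e.g. BestCkptSaverHook restores its best metric value.
        for i, hook in enumerate(trainer.hooks):
            key = f'{hook.__class__}-{i}'
            if key in meta:
                hook.load_state_dict(meta[key])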
diff --git a/modelscope/trainers/nlp_trainer.py b/modelscope/trainers/nlp_trainer.py
index 3692b486..4a14be31 100644
--- a/modelscope/trainers/nlp_trainer.py
+++ b/modelscope/trainers/nlp_trainer.py
@@ -1,6 +1,7 @@
 import os
-from typing import Callable, Dict, Optional, Tuple, Union
+from typing import Callable, Optional, Tuple, Union

+import numpy as np
 import torch
 from torch import nn
 from torch.utils.data import Dataset
@@ -11,9 +12,10 @@ from modelscope.metrics.builder import build_metric
 from modelscope.models.base import Model, TorchModel
 from modelscope.msdatasets import MsDataset
 from modelscope.preprocessors import Preprocessor, build_preprocessor
-from modelscope.utils.config import Config, ConfigDict
+from modelscope.utils.config import Config
 from modelscope.utils.constant import (DEFAULT_MODEL_REVISION, ModeKeys,
                                        ModelFile, Tasks)
+from modelscope.utils.hub import parse_label_mapping
 from .base import TRAINERS
 from .trainer import EpochBasedTrainer

@@ -81,19 +83,32 @@ class NlpEpochBasedTrainer(EpochBasedTrainer):
             assert cfg_file is not None, 'Config file should not be None if model is an nn.Module class'
             model_dir = os.path.dirname(cfg_file)

+        self.label2id = None
+        self.id2label = None
+        self.num_labels = None
         self.cfg_modify_fn = cfg_modify_fn
         self.cfg = self.rebuild_config(Config.from_file(cfg_file))
-        try:
-            labels = self.cfg.dataset.train.labels
-        except AttributeError:
-            labels = None
-        self.label2id = None
-        self.num_labels = None
-        if labels is not None and len(labels) > 0:
-            self.label2id = {label: idx for idx, label in enumerate(labels)}
-            self.id2label = {idx: label for idx, label in enumerate(labels)}
-            self.num_labels = len(labels)
+        label2id = parse_label_mapping(model_dir)
+        if label2id is not None:
+            self.label2id = label2id
+            self.id2label = {id: label for label, id in label2id.items()}
+            self.num_labels = len(label2id)
+        else:
+            try:
+                labels = self.cfg.dataset.train.labels
+                if labels is not None and len(labels) > 0:
+                    self.label2id = {
+                        label: idx
+                        for idx, label in enumerate(labels)
+                    }
+                    self.id2label = {
+                        idx: label
+                        for idx, label in enumerate(labels)
+                    }
+                    self.num_labels = len(labels)
+            except AttributeError:
+                pass

         def build_dataset_keys(cfg):
             if cfg is not None:
@@ -130,7 +145,13 @@ class NlpEpochBasedTrainer(EpochBasedTrainer):

     def rebuild_config(self, cfg: Config):
         if self.cfg_modify_fn is not None:
-            return self.cfg_modify_fn(cfg)
+            cfg = self.cfg_modify_fn(cfg)
+        if not hasattr(cfg.model, 'label2id') and not hasattr(
+                cfg.model, 'id2label'):
+            if self.id2label is not None:
+                cfg.model['id2label'] = self.id2label
+            if self.label2id is not None:
+                cfg.model['label2id'] = self.label2id
         return cfg

     def build_model(self) -> Union[nn.Module, TorchModel]:
@@ -203,6 +224,9 @@ class VecoTrainer(NlpEpochBasedTrainer):
         """
         from modelscope.msdatasets.task_datasets import VecoDataset
+        if checkpoint_path is not None and os.path.isfile(checkpoint_path):
+            from modelscope.trainers.hooks import CheckpointHook
+            CheckpointHook.load_checkpoint(checkpoint_path, self)
         self.model.eval()
         self._mode = ModeKeys.EVAL
         metric_values = {}
@@ -223,12 +247,10 @@ class VecoTrainer(NlpEpochBasedTrainer):
                 self.eval_dataset, **self.cfg.evaluation.get('dataloader', {}))
             self.data_loader = self.eval_dataloader

-            metric_classes = [
-                build_metric(metric, default_args={'trainer': self})
-                for metric in self.metrics
-            ]
-            self.evaluation_loop(self.eval_dataloader, checkpoint_path,
-                                 metric_classes)
+            metric_classes = [build_metric(metric) for metric in self.metrics]
+            for m in metric_classes:
+                m.trainer = self
+            self.evaluation_loop(self.eval_dataloader, metric_classes)

             for m_idx, metric_cls in enumerate(metric_classes):
                 if f'eval_dataset[{idx}]' not in metric_values:
@@ -242,4 +264,8 @@ class VecoTrainer(NlpEpochBasedTrainer):
             else:
                 break

+        # Snapshot the per-dataset results before aggregating, so the scalar
+        # averages added below do not feed into later metrics' averages.
+        dataset_values = list(metric_values.values())
+        for metric_name in self.metrics:
+            metric_values[metric_name] = np.average(
+                [m[metric_name] for m in dataset_values])
+
         return metric_values
diff --git a/modelscope/trainers/trainer.py b/modelscope/trainers/trainer.py
index d771d9d6..69645d07 100644
--- a/modelscope/trainers/trainer.py
+++ b/modelscope/trainers/trainer.py
@@ -1,6 +1,5 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
 import os
-import random
 import time
 from collections.abc import Mapping
 from distutils.version import LooseVersion
@@ -8,7 +7,6 @@ from functools import partial
 from typing import Callable, Dict, List, Optional, Sequence, Tuple, Union

 import json
-import numpy as np
 import torch
 from torch import distributed as dist
 from torch import nn
@@ -425,8 +423,16 @@ class EpochBasedTrainer(BaseTrainer):
             metrics = [metrics]
         return metrics

-    def train(self, *args, **kwargs):
-        self.model.train()
+    def set_checkpoint_file_to_hook(self, checkpoint_path):
+        if checkpoint_path is not None and os.path.isfile(checkpoint_path):
+            from modelscope.trainers.hooks import CheckpointHook
+            checkpoint_hooks = list(
+                filter(lambda hook: isinstance(hook, CheckpointHook),
+                       self.hooks))
+            for hook in checkpoint_hooks:
+                hook.checkpoint_file = checkpoint_path
+
+    def train(self, checkpoint_path=None, *args, **kwargs):
         self._mode = ModeKeys.TRAIN

         if self.train_dataset is None:
@@ -442,13 +448,17 @@ class EpochBasedTrainer(BaseTrainer):
         self.register_optimizers_hook()
         self.register_hook_from_cfg(self.cfg.train.hooks)

+        self.set_checkpoint_file_to_hook(checkpoint_path)
+        self.model.train()
         self.train_loop(self.train_dataloader)

     def evaluate(self, checkpoint_path=None):
+        if checkpoint_path is not None and os.path.isfile(checkpoint_path):
+            from modelscope.trainers.hooks import CheckpointHook
+            CheckpointHook.load_checkpoint(checkpoint_path, self)
         self.model.eval()
         self._mode = ModeKeys.EVAL
-
         if self.eval_dataset is None:
             self.eval_dataloader = self.get_eval_data_loader()
         else:
@@ -462,8 +472,9 @@ class EpochBasedTrainer(BaseTrainer):
         metric_classes = [build_metric(metric) for metric in self.metrics]
         for m in metric_classes:
             m.trainer = self
+
         metric_values = self.evaluation_loop(self.eval_dataloader,
-                                             checkpoint_path, metric_classes)
+                                             metric_classes)

         self._metric_values = metric_values
         return metric_values
@@ -631,18 +642,13 @@ class EpochBasedTrainer(BaseTrainer):
             if hasattr(data_cfg, 'name'):
                 dataset = MsDataset.load(
                     dataset_name=data_cfg.name,
-                    split=data_cfg.split,
-                    subset_name=data_cfg.subset_name if hasattr(
-                        data_cfg, 'subset_name') else None,
-                    hub=data_cfg.hub
-                    if hasattr(data_cfg, 'hub') else Hubs.modelscope,
                     **data_cfg,
                 )
                 cfg = ConfigDict(type=self.cfg.model.type, mode=mode)
                 torch_dataset = dataset.to_torch_dataset(
                     task_data_config=cfg,
                     task_name=self.cfg.task,
-                    preprocessors=self.preprocessor)
+                    preprocessors=preprocessor)
             else:
                 torch_dataset = build_task_dataset(data_cfg, self.cfg.task)
             dataset = self.to_task_dataset(torch_dataset, mode)
@@ -802,19 +808,22 @@ class EpochBasedTrainer(BaseTrainer):
         """ Training loop used by `EpochBasedTrainer.train()`
         """
         self.invoke_hook(TrainerStages.before_run)
-        self._epoch = 0
         kwargs = {}
         self.model.train()
         for _ in range(self._epoch, self._max_epochs):
             self.invoke_hook(TrainerStages.before_train_epoch)
             time.sleep(2)  # Prevent possible deadlock during epoch transition
             for i, data_batch in enumerate(data_loader):
+                if i < self.inner_iter:
+                    # inner_iter may have been restored from the checkpoint
+                    # file; skip the iterations already trained in this epoch.
+                    continue
                 data_batch = to_device(data_batch, self.device)
                 self.data_batch = data_batch
                 self._inner_iter = i
                 self.invoke_hook(TrainerStages.before_train_iter)
                 self.train_step(self.model, data_batch, **kwargs)
                 self.invoke_hook(TrainerStages.after_train_iter)
+                # These values change after the hooks are invoked; do not move
+                # them above the invoke_hook call.
                 del self.data_batch
                 self._iter += 1
                 self._mode = ModeKeys.TRAIN
@@ -823,12 +832,14 @@ class EpochBasedTrainer(BaseTrainer):
                     break

             self.invoke_hook(TrainerStages.after_train_epoch)
+            # This value changes after the hooks are invoked; do not move it
+            # above the invoke_hook call.
+            self._inner_iter = 0
             self._epoch += 1
             time.sleep(1)  # wait for some hooks like loggers to finish

         self.invoke_hook(TrainerStages.after_run)

-    def evaluation_loop(self, data_loader, checkpoint_path, metric_classes):
+    def evaluation_loop(self, data_loader, metric_classes):
         """ Evaluation loop used by `EpochBasedTrainer.evaluate()`.

         """
@@ -841,7 +852,7 @@ class EpochBasedTrainer(BaseTrainer):
                 tmpdir=None,
                 gpu_collect=False,
                 metric_classes=metric_classes,
-                data_loader_iters_per_gpu=self.iters_per_epoch)
+                data_loader_iters_per_gpu=self._eval_iters_per_epoch)
         else:
             from modelscope.trainers.utils.inference import single_gpu_test
             metric_values = single_gpu_test(
@@ -849,7 +860,7 @@ class EpochBasedTrainer(BaseTrainer):
                 data_loader,
                 device=self.device,
                 metric_classes=metric_classes,
-                data_loader_iters=self.iters_per_epoch)
+                data_loader_iters=self._eval_iters_per_epoch)

         self._inner_iter = self.iters_per_epoch - 1  # start from index 0

diff --git a/modelscope/utils/checkpoint.py b/modelscope/utils/checkpoint.py
index 425d3312..8d8c2b2f 100644
--- a/modelscope/utils/checkpoint.py
+++ b/modelscope/utils/checkpoint.py
@@ -8,14 +8,17 @@ from shutil import copytree, ignore_patterns, rmtree
 from typing import Callable, List, Optional, Union

 import json
-import numpy as np
 import torch
 from torch.optim import Optimizer
+from torch.optim.lr_scheduler import _LRScheduler

 from modelscope import __version__
 from modelscope.fileio import File, LocalStorage
 from modelscope.utils.config import JSONIteratorEncoder
 from modelscope.utils.constant import ConfigFields, ModelFile
+from modelscope.utils.logger import get_logger
+
+logger = get_logger(__name__)

 storage = LocalStorage()

@@ -40,24 +43,27 @@ def weights_to_cpu(state_dict):
 def save_checkpoint(model: torch.nn.Module,
                     filename: str,
                     optimizer: Optional[Optimizer] = None,
+                    lr_scheduler: Optional[_LRScheduler] = None,
                     meta: Optional[dict] = None,
                     with_meta: bool = True) -> None:
     """Save checkpoint to file.

-    The checkpoint will have 3 fields: ``meta``, ``state_dict`` and
-    ``optimizer``. By default ``meta`` will contain version and time info.
+    The checkpoint will have 4 fields: ``meta``, ``state_dict``,
+    ``optimizer`` and ``lr_scheduler``. By default, ``meta`` will contain
+    version and time info.

     Args:
         model (Module): Module whose params are to be saved.
         filename (str): Checkpoint filename.
         optimizer (:obj:`Optimizer`, optional): Optimizer to be saved.
+        lr_scheduler (:obj:`_LRScheduler`, optional): LRScheduler to be saved.
         meta (dict, optional): Metadata to be saved in checkpoint.
+        with_meta (bool, optional): Whether to wrap the weights in a dict
+            together with the ``meta``, ``optimizer`` and ``lr_scheduler``
+            fields. Default: True.
     """
     if meta is None:
         meta = {}
     elif not isinstance(meta, dict):
         raise TypeError(f'meta must be a dict or None, but got {type(meta)}')
-    meta.update(modescope=__version__, time=time.asctime())
+    meta.update(modelscope=__version__, time=time.asctime())

     if isinstance(model, torch.nn.parallel.DistributedDataParallel):
         model = model.module
@@ -71,22 +77,69 @@ def save_checkpoint(model: torch.nn.Module,
             'meta': meta,
             'state_dict': weights_to_cpu(model.state_dict())
         }
+
+        # save optimizer state dict in the checkpoint
+        if isinstance(optimizer, Optimizer):
+            checkpoint['optimizer'] = optimizer.state_dict()
+        elif isinstance(optimizer, dict):
+            checkpoint['optimizer'] = {}
+            for name, optim in optimizer.items():
+                checkpoint['optimizer'][name] = optim.state_dict()
+
+        # save lr_scheduler state dict in the checkpoint
+        if lr_scheduler is not None:
+            assert isinstance(lr_scheduler, _LRScheduler), \
+                f'lr_scheduler to be saved should be a subclass of _LRScheduler, current is: {lr_scheduler.__class__}'
+            checkpoint['lr_scheduler'] = lr_scheduler.state_dict()
     else:
         checkpoint = weights_to_cpu(model.state_dict())

-    # save optimizer state dict in the checkpoint
-    if isinstance(optimizer, Optimizer):
-        checkpoint['optimizer'] = optimizer.state_dict()
-    elif isinstance(optimizer, dict):
-        checkpoint['optimizer'] = {}
-        for name, optim in optimizer.items():
-            checkpoint['optimizer'][name] = optim.state_dict()
-
     with io.BytesIO() as f:
         torch.save(checkpoint, f)
         File.write(f.getvalue(), filename)


+def load_checkpoint(filename,
+                    model,
+                    optimizer: Optimizer = None,
+                    lr_scheduler: _LRScheduler = None):
+    if not os.path.exists(filename):
+        raise ValueError(f'Checkpoint file {filename} does not exist!')
+    checkpoint = torch.load(filename, map_location='cpu')
+
+    if optimizer is not None:
+        if 'optimizer' in checkpoint:
+            if isinstance(optimizer, Optimizer):
+                optimizer.load_state_dict(checkpoint['optimizer'])
+            elif isinstance(optimizer, dict):
+                optimizer_dict = checkpoint['optimizer']
+                for key, optimizer_ins in optimizer.items():
+                    if key in optimizer_dict:
+                        optimizer_ins.load_state_dict(optimizer_dict[key])
+                    else:
+                        logger.warn(
+                            f'The state dict of optimizer {key} cannot be found in checkpoint file: {filename}'
+                        )
+        else:
+            logger.warn(
+                f'The state dict of optimizer cannot be found in checkpoint file: {filename}'
+            )
+
+    if lr_scheduler is not None:
+        if 'lr_scheduler' in checkpoint:
+            lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
+        else:
+            logger.warn(
+                f'The state dict of lr_scheduler cannot be found in checkpoint file: {filename}'
+            )
+
+    state_dict = checkpoint if 'state_dict' not in checkpoint else checkpoint[
+        'state_dict']
+    model.load_state_dict(state_dict)
+
+    # Always return a dict so callers can safely call `.get` on the result.
+    return checkpoint.get('meta', {})
+
+
 def save_pretrained(model,
                     target_folder: Union[str, os.PathLike],
                     save_checkpoint_name: str = None,
diff --git a/modelscope/utils/regress_test_utils.py b/modelscope/utils/regress_test_utils.py
index 82267447..95d2beea 100644
--- a/modelscope/utils/regress_test_utils.py
+++ b/modelscope/utils/regress_test_utils.py
@@ -299,19 +299,23 @@ class MsRegressTool(RegressTool):
                           file_name,
                           level='config',
                           compare_fn=None,
-                          ignore_keys=None):
+                          ignore_keys=None,
+                          compare_random=True,
+                          lazy_stop_callback=None):

-        def lazy_stop_callback():
+        if lazy_stop_callback is None:

-            from modelscope.trainers.hooks.hook import Hook, Priority
+            def lazy_stop_callback():

-            class EarlyStopHook(Hook):
-                PRIORITY = Priority.VERY_LOW
+                from modelscope.trainers.hooks.hook import Hook, Priority

-            def after_iter(self, trainer):
-                raise MsRegressTool.EarlyStopError('Test finished.')
+                class EarlyStopHook(Hook):
+                    PRIORITY = Priority.VERY_LOW

-            trainer.register_hook(EarlyStopHook())
+                    def after_iter(self, trainer):
+                        raise MsRegressTool.EarlyStopError('Test finished.')
+
+                trainer.register_hook(EarlyStopHook())

         def _train_loop(trainer, *args, **kwargs):
             with self.monitor_module_train(
@@ -320,6 +324,7 @@ class MsRegressTool(RegressTool):
                     level,
                     compare_fn=compare_fn,
                     ignore_keys=ignore_keys,
+                    compare_random=compare_random,
                     lazy_stop_callback=lazy_stop_callback):
                 try:
                     return trainer.train_loop_origin(*args, **kwargs)
diff --git a/modelscope/utils/tensor_utils.py b/modelscope/utils/tensor_utils.py
index 7889d944..b438e476 100644
--- a/modelscope/utils/tensor_utils.py
+++ b/modelscope/utils/tensor_utils.py
@@ -1,8 +1,5 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
 # Part of the implementation is borrowed from huggingface/transformers.
-from collections.abc import Mapping
-
-import numpy as np


 def torch_nested_numpify(tensors):
diff --git a/tests/trainers/data/test/regression/sbert-base-tnews.bin b/tests/trainers/data/test/regression/sbert-base-tnews.bin
deleted file mode 100644
index 3a06d49c..00000000
--- a/tests/trainers/data/test/regression/sbert-base-tnews.bin
+++ /dev/null
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:2df2a5f3cdfc6dded52d31a8e97d9a9c41a803cb6d46dee709c51872eda37b21
-size 151830
diff --git a/tests/trainers/test_trainer_with_nlp.py b/tests/trainers/test_trainer_with_nlp.py
index 2cf1c152..6030ada9 100644
--- a/tests/trainers/test_trainer_with_nlp.py
+++ b/tests/trainers/test_trainer_with_nlp.py
@@ -11,7 +11,8 @@ from modelscope.models.nlp.sequence_classification import \
     SbertForSequenceClassification
 from modelscope.msdatasets import MsDataset
 from modelscope.pipelines import pipeline
-from modelscope.trainers import build_trainer
+from modelscope.trainers import EpochBasedTrainer, build_trainer
+from modelscope.utils.config import Config
 from modelscope.utils.constant import ModelFile, Tasks
 from modelscope.utils.hub import read_config
 from modelscope.utils.test_utils import test_level
@@ -119,6 +120,90 @@ class TestTrainerWithNlp(unittest.TestCase):
                 checkpoint_path=os.path.join(self.tmp_dir, 'epoch_10.pth'))
         self.assertTrue(Metrics.accuracy in eval_results)

+    @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
+    def test_trainer_with_configured_datasets(self):
+        model_id = 'damo/nlp_structbert_sentence-similarity_chinese-base'
+        cfg: Config = read_config(model_id)
+        cfg.train.max_epochs = 20
+        cfg.train.work_dir = self.tmp_dir
+        cfg.dataset = {
+            'train': {
+                'name': 'afqmc_small',
+                'split': 'train',
+                'namespace': 'userxiaoming'
+            },
+            'val': {
+                'name': 'afqmc_small',
+                'split': 'train',
+                'namespace': 'userxiaoming'
+            },
+        }
+        cfg_file = os.path.join(self.tmp_dir, 'config.json')
+        cfg.dump(cfg_file)
+        kwargs = dict(model=model_id, cfg_file=cfg_file)
+
+        trainer = build_trainer(default_args=kwargs)
+        trainer.train()
+        results_files = os.listdir(self.tmp_dir)
+        self.assertIn(f'{trainer.timestamp}.log.json', results_files)
+        for i in range(cfg.train.max_epochs):
+            self.assertIn(f'epoch_{i+1}.pth', results_files)
+
+        eval_results = trainer.evaluate(
+            checkpoint_path=os.path.join(self.tmp_dir, 'epoch_10.pth'))
+        self.assertTrue(Metrics.accuracy in eval_results)
+
+    @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
+    def test_trainer_with_continue_train(self):
+        from modelscope.utils.regress_test_utils import MsRegressTool
+        model_id = 'damo/nlp_structbert_sentence-similarity_chinese-base'
+        cfg: Config = read_config(model_id)
+        cfg.train.max_epochs = 3
+        cfg.train.work_dir = self.tmp_dir
+        cfg_file = os.path.join(self.tmp_dir, 'config.json')
+        cfg.dump(cfg_file)
+        dataset = MsDataset.load('clue', subset_name='afqmc', split='train')
+        dataset = dataset.to_hf_dataset().select(range(128))
+        kwargs = dict(
+            model=model_id,
+            train_dataset=dataset,
+            eval_dataset=dataset,
+            cfg_file=cfg_file)
+
+        regress_tool = MsRegressTool(baseline=True)
+        trainer: EpochBasedTrainer = build_trainer(default_args=kwargs)
+
+        def lazy_stop_callback():
+            from modelscope.trainers.hooks.hook import Hook, Priority
+
+            class EarlyStopHook(Hook):
+                PRIORITY = Priority.VERY_LOW
+
+                def after_iter(self, trainer):
+                    if trainer.iter == 12:
+                        raise MsRegressTool.EarlyStopError('Test finished.')
+
+            if 'EarlyStopHook' not in [
+                    hook.__class__.__name__ for hook in trainer.hooks
+            ]:
+                trainer.register_hook(EarlyStopHook())
+
+        with regress_tool.monitor_ms_train(
+                trainer,
+                'trainer_continue_train',
+                level='strict',
+                lazy_stop_callback=lazy_stop_callback):
+            trainer.train()
+
+        results_files = os.listdir(self.tmp_dir)
+        self.assertIn(f'{trainer.timestamp}.log.json', results_files)
+
+        trainer = build_trainer(default_args=kwargs)
+        regress_tool = MsRegressTool(baseline=False)
+        with regress_tool.monitor_ms_train(
+                trainer, 'trainer_continue_train', level='strict'):
+            trainer.train(os.path.join(self.tmp_dir, 'iter_12.pth'))
+
     @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
     def test_trainer_with_model_and_args(self):
         tmp_dir = tempfile.TemporaryDirectory().name
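For context, here is a hedged usage sketch of the resume-training feature this patch introduces. It mirrors `test_trainer_with_continue_train` above; the checkpoint filename is an assumption for illustration, and `trainer.work_dir` follows the `CheckpointHook` docstring.

    # Usage sketch (assumed checkpoint name; model id and dataset follow the test above).
    import os

    from modelscope.msdatasets import MsDataset
    from modelscope.trainers import build_trainer

    dataset = MsDataset.load('clue', subset_name='afqmc', split='train')
    kwargs = dict(
        model='damo/nlp_structbert_sentence-similarity_chinese-base',
        train_dataset=dataset,
        eval_dataset=dataset)
    trainer = build_trainer(default_args=kwargs)

    # First run: CheckpointHook saves the weights plus optimizer, lr_scheduler,
    # RNG state and per-hook state into epoch_{n}.pth / iter_{n}.pth files.
    trainer.train()

    # Later run: passing a checkpoint file restores the epoch/iter counters,
    # RNG state and all stateful hooks before training continues.
    trainer.train(checkpoint_path=os.path.join(trainer.work_dir, 'epoch_3.pth'))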