From a9c14e4eadd64e30820b689b47f5e2ebc19516f4 Mon Sep 17 00:00:00 2001
From: "yuze.zyz"
Date: Mon, 5 Sep 2022 11:07:48 +0800
Subject: [PATCH] [to #42322933] Support saving the best checkpoint for
 inference

1. Support saving the best checkpoint for inference
2. Fix a bug where the _max_iters field does not exist in the trainer
3. Fix a bug where a function in the lambda_lr field cannot be saved to file
4. Fix a bug where save_pretrained would not be called when training by
   iteration
5. Fix a bug where interval is not passed through from
   BestCkptSaverHook's __init__

Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9972765
---
 modelscope/trainers/hooks/checkpoint_hook.py    | 44 ++++++++++---------
 modelscope/trainers/hooks/hook.py               |  4 +-
 modelscope/utils/checkpoint.py                  | 17 ++++---
 modelscope/utils/config.py                      |  3 ++
 .../trainers/test_finetune_text_generation.py  | 22 +++++-----
 5 files changed, 50 insertions(+), 40 deletions(-)

diff --git a/modelscope/trainers/hooks/checkpoint_hook.py b/modelscope/trainers/hooks/checkpoint_hook.py
index cf7a0f7a..fcd8e982 100644
--- a/modelscope/trainers/hooks/checkpoint_hook.py
+++ b/modelscope/trainers/hooks/checkpoint_hook.py
@@ -27,7 +27,7 @@ class CheckpointHook(Hook):
         save_last (bool): Whether to save the last checkpoint. Default: True.
     """
 
-    PRIORITY = Priority.NORMAL
+    PRIORITY = Priority.LOW
 
     def __init__(self,
                  interval=0,
@@ -75,25 +75,27 @@ class CheckpointHook(Hook):
                 self.save_dir, f'{LogKeys.ITER}_{trainer.iter + 1}.pth')
         save_checkpoint(trainer.model, cur_save_name, trainer.optimizer)
-        self._save_pretrained(trainer)
+        if (self.is_last_epoch(trainer)
+                and self.by_epoch) or (self.is_last_iter(trainer)
+                                       and not self.by_epoch):
+            self._save_pretrained(trainer)
 
     def _save_pretrained(self, trainer):
-        if self.is_last_epoch(trainer) and self.by_epoch:
-            output_dir = os.path.join(self.save_dir,
-                                      ModelFile.TRAIN_OUTPUT_DIR)
-            from modelscope.trainers.parallel.utils import is_parallel
-
-            if is_parallel(trainer.model):
-                model = trainer.model.module
-            else:
-                model = trainer.model
-
-            if hasattr(model, 'save_pretrained'):
-                model.save_pretrained(
-                    output_dir,
-                    ModelFile.TORCH_MODEL_BIN_FILE,
-                    save_function=save_checkpoint,
-                    config=trainer.cfg.to_dict())
+        output_dir = os.path.join(self.save_dir, ModelFile.TRAIN_OUTPUT_DIR)
+        from modelscope.trainers.parallel.utils import is_parallel
+
+        if is_parallel(trainer.model):
+            model = trainer.model.module
+        else:
+            model = trainer.model
+
+        if hasattr(model, 'save_pretrained'):
+            model.save_pretrained(
+                output_dir,
+                ModelFile.TORCH_MODEL_BIN_FILE,
+                save_function=save_checkpoint,
+                config=trainer.cfg.to_dict(),
+                with_meta=False)
 
     def after_train_iter(self, trainer):
         if self.by_epoch:
@@ -133,7 +135,7 @@ class BestCkptSaverHook(CheckpointHook):
         save_dir (str): Output directory to save best checkpoint.
     """
 
-    PRIORITY = Priority.NORMAL
+    PRIORITY = Priority.LOW
     rule_map = {'max': lambda x, y: x > y, 'min': lambda x, y: x < y}
 
     def __init__(self,
@@ -141,9 +143,11 @@
                  rule='max',
                  by_epoch=True,
                  save_optimizer=True,
-                 save_dir=None):
+                 save_dir=None,
+                 interval=0):
         assert rule in ['max', 'min'], 'Only support "max" or "min" rule now.'
         super().__init__(
+            interval=interval,
             by_epoch=by_epoch,
             save_optimizer=save_optimizer,
             save_dir=save_dir,
diff --git a/modelscope/trainers/hooks/hook.py b/modelscope/trainers/hooks/hook.py
index 75cc226c..1c567f1c 100644
--- a/modelscope/trainers/hooks/hook.py
+++ b/modelscope/trainers/hooks/hook.py
@@ -199,14 +199,14 @@ class Hook:
         Whether to reach the last epoch
         Returns: bool
         """
-        return trainer.epoch + 1 == trainer._max_epochs
+        return trainer.epoch + 1 == trainer.max_epochs
 
     def is_last_iter(self, trainer):
         """
         Whether to reach the last iteration in the entire training process
         Returns: bool
         """
-        return trainer.iter + 1 == trainer._max_iters
+        return trainer.iter + 1 == trainer.max_iters
 
     def get_triggered_stages(self):
         trigger_stages = set()
diff --git a/modelscope/utils/checkpoint.py b/modelscope/utils/checkpoint.py
index 8b9d027a..425d3312 100644
--- a/modelscope/utils/checkpoint.py
+++ b/modelscope/utils/checkpoint.py
@@ -40,7 +40,8 @@ def weights_to_cpu(state_dict):
 def save_checkpoint(model: torch.nn.Module,
                     filename: str,
                     optimizer: Optional[Optimizer] = None,
-                    meta: Optional[dict] = None) -> None:
+                    meta: Optional[dict] = None,
+                    with_meta: bool = True) -> None:
     """Save checkpoint to file.
 
     The checkpoint will have 3 fields: ``meta``, ``state_dict`` and
@@ -65,10 +66,14 @@ def save_checkpoint(model: torch.nn.Module,
         # save class name to the meta
         meta.update(CLASSES=model.CLASSES)
 
-    checkpoint = {
-        'meta': meta,
-        'state_dict': weights_to_cpu(model.state_dict())
-    }
+    if with_meta:
+        checkpoint = {
+            'meta': meta,
+            'state_dict': weights_to_cpu(model.state_dict())
+        }
+    else:
+        checkpoint = weights_to_cpu(model.state_dict())
+
     # save optimizer state dict in the checkpoint
     if isinstance(optimizer, Optimizer):
         checkpoint['optimizer'] = optimizer.state_dict()
@@ -141,7 +146,7 @@ def save_pretrained(model,
 
     # Save the ckpt to the save directory
     try:
-        save_function(model, output_ckpt_path)
+        save_function(model, output_ckpt_path, **kwargs)
     except Exception as e:
         raise Exception(
             f'During saving checkpoints, the error of "{type(e).__name__} '
diff --git a/modelscope/utils/config.py b/modelscope/utils/config.py
index 42985db6..7d972118 100644
--- a/modelscope/utils/config.py
+++ b/modelscope/utils/config.py
@@ -9,6 +9,7 @@ import sys
 import tempfile
 import types
 from pathlib import Path
+from types import FunctionType
 from typing import Dict, Union
 
 import addict
@@ -638,6 +639,8 @@ class JSONIteratorEncoder(json.JSONEncoder):
     """
 
     def default(self, obj):
+        if isinstance(obj, FunctionType):
+            return None
         try:
             iterable = iter(obj)
         except TypeError:
diff --git a/tests/trainers/test_finetune_text_generation.py b/tests/trainers/test_finetune_text_generation.py
index 8cdfdf01..a561effe 100644
--- a/tests/trainers/test_finetune_text_generation.py
+++ b/tests/trainers/test_finetune_text_generation.py
@@ -128,15 +128,14 @@ class TestFinetuneTextGeneration(unittest.TestCase):
 
     @unittest.skip
     def test_finetune_cnndm(self):
-        from datasets import load_dataset
-        dataset_dict = load_dataset('ccdv/cnn_dailymail', '3.0.0')
-        train_dataset = dataset_dict['train'] \
-            .rename_columns({'article': 'src_txt', 'highlights': 'tgt_txt'}) \
-            .remove_columns('id')
-        eval_dataset = dataset_dict['validation'] \
-            .rename_columns({'article': 'src_txt', 'highlights': 'tgt_txt'}) \
-            .remove_columns('id')
-        num_warmup_steps = 2000
+        from modelscope.msdatasets import MsDataset
+        dataset_dict = MsDataset.load('dureader_robust_qg')
+        train_dataset = dataset_dict['train'].to_hf_dataset() \
+            .rename_columns({'text1': 'src_txt', 'text2': 'tgt_txt'})
+        eval_dataset = dataset_dict['validation'].to_hf_dataset() \
+            .rename_columns({'text1': 'src_txt', 'text2': 'tgt_txt'})
+        num_warmup_steps = 200
+        os.environ['LOCAL_RANK'] = '0'
 
         def noam_lambda(current_step: int):
             current_step += 1
@@ -154,12 +153,11 @@ class TestFinetuneTextGeneration(unittest.TestCase):
             return cfg
 
         kwargs = dict(
-            model=self.model_id,
+            model='damo/nlp_palm2.0_text-generation_chinese-base',
            train_dataset=train_dataset,
             eval_dataset=eval_dataset,
             work_dir=self.tmp_dir,
-            cfg_modify_fn=cfg_modify_fn,
-            model_revision='beta')
+            cfg_modify_fn=cfg_modify_fn)
         trainer = build_trainer(
             name=Trainers.nlp_base_trainer, default_args=kwargs)
         trainer.train()
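
Note on the new with_meta flag: _save_pretrained now calls save_pretrained
with save_function=save_checkpoint and with_meta=False, so the exported
ModelFile.TORCH_MODEL_BIN_FILE is a bare state dict ready for inference,
while regular epoch/iteration checkpoints keep the meta wrapper for
resuming. A minimal sketch of the two on-disk layouts (simplified;
save_checkpoint_sketch is an illustrative name, and the real function also
fills meta defaults and optimizer state):

import torch

def save_checkpoint_sketch(model, filename, with_meta=True):
    # equivalent of weights_to_cpu: move tensors off the GPU before writing
    state_dict = {k: v.cpu() for k, v in model.state_dict().items()}
    if with_meta:
        # resumable layout: weights nested under 'state_dict' next to 'meta'
        checkpoint = {'meta': {}, 'state_dict': state_dict}
    else:
        # bare layout: the saved file is the state dict itself
        checkpoint = state_dict
    torch.save(checkpoint, filename)

model = torch.nn.Linear(4, 2)
save_checkpoint_sketch(model, 'pytorch_model.bin', with_meta=False)
# the bare layout loads directly, with no unwrapping step
model.load_state_dict(torch.load('pytorch_model.bin'))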
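
Note on the JSONIteratorEncoder change: a function object stored in the
config (the lambda_lr case from the commit message, e.g. a noam-style
schedule) previously made JSON serialization raise TypeError; returning
None for FunctionType values makes such entries serialize as null. A
standalone sketch of the same technique (FunctionTolerantEncoder is an
illustrative name, not the modelscope class):

import json
from types import FunctionType

class FunctionTolerantEncoder(json.JSONEncoder):
    def default(self, obj):
        # functions (including lambdas) become null instead of raising
        if isinstance(obj, FunctionType):
            return None
        return super().default(obj)

cfg = {'lr_scheduler': {'type': 'LambdaLR', 'lr_lambda': lambda step: 1.0}}
print(json.dumps(cfg, cls=FunctionTolerantEncoder))
# prints: {"lr_scheduler": {"type": "LambdaLR", "lr_lambda": null}}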
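
Note on BestCkptSaverHook: interval is now forwarded to
CheckpointHook.__init__ instead of being silently dropped, and the hook
ranks checkpoints with its rule_map. A minimal sketch of that comparison
(is_better and best_metric are illustrative names, not the hook's
internals):

rule_map = {'max': lambda x, y: x > y, 'min': lambda x, y: x < y}

def is_better(current, best, rule='max'):
    # the first observed metric always counts as the best so far
    return best is None or rule_map[rule](current, best)

best_metric = None
for epoch_metric in (0.71, 0.69, 0.75):
    if is_better(epoch_metric, best_metric, rule='max'):
        best_metric = epoch_metric  # the real hook saves a checkpoint here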