1. Support saving the best checkpoint for inference
2. Fix a bug that _max_iters field does not exist in trainer
3. Fix a bug that function in lambda_lr field cannot be saved to file
4. Fix a bug that save_pretrained would not be called by iterating
5. Fix a bug that interval is not passed from BestCkptHook's init
Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9972765
master
| @@ -27,7 +27,7 @@ class CheckpointHook(Hook): | |||
| save_last (bool): Whether to save the last checkpoint. Default: True. | |||
| """ | |||
| PRIORITY = Priority.NORMAL | |||
| PRIORITY = Priority.LOW | |||
| def __init__(self, | |||
| interval=0, | |||
| @@ -75,25 +75,27 @@ class CheckpointHook(Hook): | |||
| self.save_dir, f'{LogKeys.ITER}_{trainer.iter + 1}.pth') | |||
| save_checkpoint(trainer.model, cur_save_name, trainer.optimizer) | |||
| self._save_pretrained(trainer) | |||
| if (self.is_last_epoch(trainer) | |||
| and self.by_epoch) or (self.is_last_iter(trainer) | |||
| and not self.by_epoch): | |||
| self._save_pretrained(trainer) | |||
| def _save_pretrained(self, trainer): | |||
| if self.is_last_epoch(trainer) and self.by_epoch: | |||
| output_dir = os.path.join(self.save_dir, | |||
| ModelFile.TRAIN_OUTPUT_DIR) | |||
| from modelscope.trainers.parallel.utils import is_parallel | |||
| if is_parallel(trainer.model): | |||
| model = trainer.model.module | |||
| else: | |||
| model = trainer.model | |||
| if hasattr(model, 'save_pretrained'): | |||
| model.save_pretrained( | |||
| output_dir, | |||
| ModelFile.TORCH_MODEL_BIN_FILE, | |||
| save_function=save_checkpoint, | |||
| config=trainer.cfg.to_dict()) | |||
| output_dir = os.path.join(self.save_dir, ModelFile.TRAIN_OUTPUT_DIR) | |||
| from modelscope.trainers.parallel.utils import is_parallel | |||
| if is_parallel(trainer.model): | |||
| model = trainer.model.module | |||
| else: | |||
| model = trainer.model | |||
| if hasattr(model, 'save_pretrained'): | |||
| model.save_pretrained( | |||
| output_dir, | |||
| ModelFile.TORCH_MODEL_BIN_FILE, | |||
| save_function=save_checkpoint, | |||
| config=trainer.cfg.to_dict(), | |||
| with_meta=False) | |||
| def after_train_iter(self, trainer): | |||
| if self.by_epoch: | |||
| @@ -133,7 +135,7 @@ class BestCkptSaverHook(CheckpointHook): | |||
| save_dir (str): Output directory to save best checkpoint. | |||
| """ | |||
| PRIORITY = Priority.NORMAL | |||
| PRIORITY = Priority.LOW | |||
| rule_map = {'max': lambda x, y: x > y, 'min': lambda x, y: x < y} | |||
| def __init__(self, | |||
| @@ -141,9 +143,11 @@ class BestCkptSaverHook(CheckpointHook): | |||
| rule='max', | |||
| by_epoch=True, | |||
| save_optimizer=True, | |||
| save_dir=None): | |||
| save_dir=None, | |||
| interval=0): | |||
| assert rule in ['max', 'min'], 'Only support "max" or "min" rule now.' | |||
| super().__init__( | |||
| interval=interval, | |||
| by_epoch=by_epoch, | |||
| save_optimizer=save_optimizer, | |||
| save_dir=save_dir, | |||
| @@ -199,14 +199,14 @@ class Hook: | |||
| Whether to reach the last epoch | |||
| Returns: bool | |||
| """ | |||
| return trainer.epoch + 1 == trainer._max_epochs | |||
| return trainer.epoch + 1 == trainer.max_epochs | |||
| def is_last_iter(self, trainer): | |||
| """ | |||
| Whether to reach the last iteration in the entire training process | |||
| Returns: bool | |||
| """ | |||
| return trainer.iter + 1 == trainer._max_iters | |||
| return trainer.iter + 1 == trainer.max_iters | |||
| def get_triggered_stages(self): | |||
| trigger_stages = set() | |||
| @@ -40,7 +40,8 @@ def weights_to_cpu(state_dict): | |||
| def save_checkpoint(model: torch.nn.Module, | |||
| filename: str, | |||
| optimizer: Optional[Optimizer] = None, | |||
| meta: Optional[dict] = None) -> None: | |||
| meta: Optional[dict] = None, | |||
| with_meta: bool = True) -> None: | |||
| """Save checkpoint to file. | |||
| The checkpoint will have 3 fields: ``meta``, ``state_dict`` and | |||
| @@ -65,10 +66,14 @@ def save_checkpoint(model: torch.nn.Module, | |||
| # save class name to the meta | |||
| meta.update(CLASSES=model.CLASSES) | |||
| checkpoint = { | |||
| 'meta': meta, | |||
| 'state_dict': weights_to_cpu(model.state_dict()) | |||
| } | |||
| if with_meta: | |||
| checkpoint = { | |||
| 'meta': meta, | |||
| 'state_dict': weights_to_cpu(model.state_dict()) | |||
| } | |||
| else: | |||
| checkpoint = weights_to_cpu(model.state_dict()) | |||
| # save optimizer state dict in the checkpoint | |||
| if isinstance(optimizer, Optimizer): | |||
| checkpoint['optimizer'] = optimizer.state_dict() | |||
| @@ -141,7 +146,7 @@ def save_pretrained(model, | |||
| # Save the ckpt to the save directory | |||
| try: | |||
| save_function(model, output_ckpt_path) | |||
| save_function(model, output_ckpt_path, **kwargs) | |||
| except Exception as e: | |||
| raise Exception( | |||
| f'During saving checkpoints, the error of "{type(e).__name__} ' | |||
| @@ -9,6 +9,7 @@ import sys | |||
| import tempfile | |||
| import types | |||
| from pathlib import Path | |||
| from types import FunctionType | |||
| from typing import Dict, Union | |||
| import addict | |||
| @@ -638,6 +639,8 @@ class JSONIteratorEncoder(json.JSONEncoder): | |||
| """ | |||
| def default(self, obj): | |||
| if isinstance(obj, FunctionType): | |||
| return None | |||
| try: | |||
| iterable = iter(obj) | |||
| except TypeError: | |||
| @@ -128,15 +128,14 @@ class TestFinetuneTextGeneration(unittest.TestCase): | |||
| @unittest.skip | |||
| def test_finetune_cnndm(self): | |||
| from datasets import load_dataset | |||
| dataset_dict = load_dataset('ccdv/cnn_dailymail', '3.0.0') | |||
| train_dataset = dataset_dict['train'] \ | |||
| .rename_columns({'article': 'src_txt', 'highlights': 'tgt_txt'}) \ | |||
| .remove_columns('id') | |||
| eval_dataset = dataset_dict['validation'] \ | |||
| .rename_columns({'article': 'src_txt', 'highlights': 'tgt_txt'}) \ | |||
| .remove_columns('id') | |||
| num_warmup_steps = 2000 | |||
| from modelscope.msdatasets import MsDataset | |||
| dataset_dict = MsDataset.load('dureader_robust_qg') | |||
| train_dataset = dataset_dict['train'].to_hf_dataset() \ | |||
| .rename_columns({'text1': 'src_txt', 'text2': 'tgt_txt'}) | |||
| eval_dataset = dataset_dict['validation'].to_hf_dataset() \ | |||
| .rename_columns({'text1': 'src_txt', 'text2': 'tgt_txt'}) | |||
| num_warmup_steps = 200 | |||
| os.environ['LOCAL_RANK'] = '0' | |||
| def noam_lambda(current_step: int): | |||
| current_step += 1 | |||
| @@ -154,12 +153,11 @@ class TestFinetuneTextGeneration(unittest.TestCase): | |||
| return cfg | |||
| kwargs = dict( | |||
| model=self.model_id, | |||
| model='damo/nlp_palm2.0_text-generation_chinese-base', | |||
| train_dataset=train_dataset, | |||
| eval_dataset=eval_dataset, | |||
| work_dir=self.tmp_dir, | |||
| cfg_modify_fn=cfg_modify_fn, | |||
| model_revision='beta') | |||
| cfg_modify_fn=cfg_modify_fn) | |||
| trainer = build_trainer( | |||
| name=Trainers.nlp_base_trainer, default_args=kwargs) | |||
| trainer.train() | |||