1. Support saving the best checkpoint for inference
2. Fix a bug where the _max_iters field does not exist in the trainer
3. Fix a bug where the function stored in the lambda_lr field cannot be saved to file
4. Fix a bug where save_pretrained is never called when training by iteration
5. Fix a bug where interval is not passed on from BestCkptSaverHook's __init__
Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9972765
master
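For context, a minimal sketch of how the reworked hooks could be enabled from a trainer config via cfg_modify_fn (the same mechanism the test below uses). The hook type and the interval parameter come from this diff; the metric_key value, the EvaluationHook entry and the cfg.train.hooks layout are illustrative assumptions, not part of the change:

def cfg_modify_fn(cfg):
    # Illustrative hook setup; metric name and intervals are placeholders.
    cfg.train.hooks = [
        {
            'type': 'BestCkptSaverHook',
            'metric_key': 'rouge-l',  # assumed metric name, depends on the evaluator
            'rule': 'max',            # keep the checkpoint with the highest metric value
            'interval': 1             # now actually forwarded to CheckpointHook (fix 5)
        },
        {
            'type': 'EvaluationHook',
            'interval': 1
        }
    ]
    return cfg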
@@ -27,7 +27,7 @@ class CheckpointHook(Hook):
         save_last (bool): Whether to save the last checkpoint. Default: True.
     """
-    PRIORITY = Priority.NORMAL
+    PRIORITY = Priority.LOW

     def __init__(self,
                  interval=0,
@@ -75,25 +75,27 @@ class CheckpointHook(Hook):
                 self.save_dir, f'{LogKeys.ITER}_{trainer.iter + 1}.pth')
         save_checkpoint(trainer.model, cur_save_name, trainer.optimizer)
-        self._save_pretrained(trainer)
+        if (self.is_last_epoch(trainer)
+                and self.by_epoch) or (self.is_last_iter(trainer)
+                                       and not self.by_epoch):
+            self._save_pretrained(trainer)

     def _save_pretrained(self, trainer):
-        if self.is_last_epoch(trainer) and self.by_epoch:
-            output_dir = os.path.join(self.save_dir,
-                                      ModelFile.TRAIN_OUTPUT_DIR)
-            from modelscope.trainers.parallel.utils import is_parallel
-            if is_parallel(trainer.model):
-                model = trainer.model.module
-            else:
-                model = trainer.model
-            if hasattr(model, 'save_pretrained'):
-                model.save_pretrained(
-                    output_dir,
-                    ModelFile.TORCH_MODEL_BIN_FILE,
-                    save_function=save_checkpoint,
-                    config=trainer.cfg.to_dict())
+        output_dir = os.path.join(self.save_dir, ModelFile.TRAIN_OUTPUT_DIR)
+        from modelscope.trainers.parallel.utils import is_parallel
+        if is_parallel(trainer.model):
+            model = trainer.model.module
+        else:
+            model = trainer.model
+        if hasattr(model, 'save_pretrained'):
+            model.save_pretrained(
+                output_dir,
+                ModelFile.TORCH_MODEL_BIN_FILE,
+                save_function=save_checkpoint,
+                config=trainer.cfg.to_dict(),
+                with_meta=False)

     def after_train_iter(self, trainer):
         if self.by_epoch:
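Spelled out, the export condition introduced above behaves like the following hypothetical helper (not part of the diff); this is what fix 4 in the description refers to:

def should_export_for_inference(hook, trainer):
    # Epoch-based training: export the inference model once, after the final epoch.
    if hook.by_epoch:
        return hook.is_last_epoch(trainer)
    # Iteration-based training: export once, after the final iteration.
    # Previously the epoch check lived inside _save_pretrained itself,
    # so iteration-based runs never exported a model.
    return hook.is_last_iter(trainer)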
@@ -133,7 +135,7 @@ class BestCkptSaverHook(CheckpointHook):
         save_dir (str): Output directory to save best checkpoint.
     """
-    PRIORITY = Priority.NORMAL
+    PRIORITY = Priority.LOW
     rule_map = {'max': lambda x, y: x > y, 'min': lambda x, y: x < y}

     def __init__(self,
@@ -141,9 +143,11 @@ class BestCkptSaverHook(CheckpointHook):
                  rule='max',
                  by_epoch=True,
                  save_optimizer=True,
-                 save_dir=None):
+                 save_dir=None,
+                 interval=0):
         assert rule in ['max', 'min'], 'Only support "max" or "min" rule now.'
         super().__init__(
+            interval=interval,
             by_epoch=by_epoch,
             save_optimizer=save_optimizer,
             save_dir=save_dir,
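A small usage sketch of the widened constructor; metric_key is assumed to be the remaining BestCkptSaverHook parameter that this hunk does not show:

hook = BestCkptSaverHook(
    metric_key='accuracy',   # assumed parameter name, compared via rule_map[rule]
    rule='max',              # 'max' keeps the largest metric value, 'min' the smallest
    by_epoch=True,
    save_optimizer=True,
    save_dir='./work_dir',
    interval=1)              # previously silently dropped, now passed to the parent class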
@@ -199,14 +199,14 @@ class Hook:
         Whether to reach the last epoch
         Returns: bool
         """
-        return trainer.epoch + 1 == trainer._max_epochs
+        return trainer.epoch + 1 == trainer.max_epochs

     def is_last_iter(self, trainer):
         """
         Whether to reach the last iteration in the entire training process
         Returns: bool
         """
-        return trainer.iter + 1 == trainer._max_iters
+        return trainer.iter + 1 == trainer.max_iters

     def get_triggered_stages(self):
         trigger_stages = set()
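A minimal sketch of a custom hook relying on the helpers fixed above; it assumes the trainer exposes public max_epochs/max_iters attributes (the point of this hunk) rather than the private _max_* fields:

class LastStepLoggerHook(Hook):
    # Hypothetical hook, for illustration only; assumes Hook is importable
    # from the trainer hooks module.
    def after_train_iter(self, trainer):
        if self.is_last_iter(trainer):
            print(f'reached the last of {trainer.max_iters} iterations')

    def after_train_epoch(self, trainer):
        if self.is_last_epoch(trainer):
            print(f'reached the last of {trainer.max_epochs} epochs')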
@@ -40,7 +40,8 @@ def weights_to_cpu(state_dict):
 def save_checkpoint(model: torch.nn.Module,
                     filename: str,
                     optimizer: Optional[Optimizer] = None,
-                    meta: Optional[dict] = None) -> None:
+                    meta: Optional[dict] = None,
+                    with_meta: bool = True) -> None:
     """Save checkpoint to file.

     The checkpoint will have 3 fields: ``meta``, ``state_dict`` and
@@ -65,10 +66,14 @@ def save_checkpoint(model: torch.nn.Module,
         # save class name to the meta
         meta.update(CLASSES=model.CLASSES)

-    checkpoint = {
-        'meta': meta,
-        'state_dict': weights_to_cpu(model.state_dict())
-    }
+    if with_meta:
+        checkpoint = {
+            'meta': meta,
+            'state_dict': weights_to_cpu(model.state_dict())
+        }
+    else:
+        checkpoint = weights_to_cpu(model.state_dict())
     # save optimizer state dict in the checkpoint
     if isinstance(optimizer, Optimizer):
         checkpoint['optimizer'] = optimizer.state_dict()
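The practical difference the new with_meta flag makes, as a hedged sketch (model, optimizer and file names are placeholders):

import torch

# Default behaviour: a training checkpoint wrapping meta + state_dict (+ optimizer).
save_checkpoint(model, 'epoch_3.pth', optimizer)

# Behaviour used by _save_pretrained: a bare state_dict that inference code
# can load directly, without unwrapping checkpoint['state_dict'].
save_checkpoint(model, 'pytorch_model.bin', with_meta=False)

state_dict = torch.load('pytorch_model.bin', map_location='cpu')
model.load_state_dict(state_dict)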
@@ -141,7 +146,7 @@ def save_pretrained(model,
     # Save the ckpt to the save directory
     try:
-        save_function(model, output_ckpt_path)
+        save_function(model, output_ckpt_path, **kwargs)
     except Exception as e:
         raise Exception(
             f'During saving checkpoints, the error of "{type(e).__name__} '
@@ -9,6 +9,7 @@ import sys
 import tempfile
 import types
 from pathlib import Path
+from types import FunctionType
 from typing import Dict, Union

 import addict
@@ -638,6 +639,8 @@ class JSONIteratorEncoder(json.JSONEncoder):
     """

     def default(self, obj):
+        if isinstance(obj, FunctionType):
+            return None
         try:
             iterable = iter(obj)
         except TypeError:
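Why this matters for fix 3: with the FunctionType guard, dumping a config that carries an lr-scheduler lambda yields null instead of raising. The snippet below is a self-contained approximation of the encoder; only the FunctionType branch is copied from the diff, the rest of the body and the config layout are reconstructed for illustration:

import json
from types import FunctionType


class JSONIteratorEncoder(json.JSONEncoder):

    def default(self, obj):
        if isinstance(obj, FunctionType):
            return None  # functions/lambdas serialize as null instead of raising TypeError
        try:
            iterable = iter(obj)
        except TypeError:
            pass
        else:
            return list(iterable)
        return super().default(obj)


cfg = {'lr_scheduler': {'lambda_lr': lambda step: min(1.0, step / 100)}}
print(json.dumps(cfg, cls=JSONIteratorEncoder))
# -> {"lr_scheduler": {"lambda_lr": null}}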
@@ -128,15 +128,14 @@ class TestFinetuneTextGeneration(unittest.TestCase):
     @unittest.skip
     def test_finetune_cnndm(self):
-        from datasets import load_dataset
-        dataset_dict = load_dataset('ccdv/cnn_dailymail', '3.0.0')
-        train_dataset = dataset_dict['train'] \
-            .rename_columns({'article': 'src_txt', 'highlights': 'tgt_txt'}) \
-            .remove_columns('id')
-        eval_dataset = dataset_dict['validation'] \
-            .rename_columns({'article': 'src_txt', 'highlights': 'tgt_txt'}) \
-            .remove_columns('id')
-        num_warmup_steps = 2000
+        from modelscope.msdatasets import MsDataset
+        dataset_dict = MsDataset.load('dureader_robust_qg')
+        train_dataset = dataset_dict['train'].to_hf_dataset() \
+            .rename_columns({'text1': 'src_txt', 'text2': 'tgt_txt'})
+        eval_dataset = dataset_dict['validation'].to_hf_dataset() \
+            .rename_columns({'text1': 'src_txt', 'text2': 'tgt_txt'})
+        num_warmup_steps = 200
+        os.environ['LOCAL_RANK'] = '0'

         def noam_lambda(current_step: int):
             current_step += 1
@@ -154,12 +153,11 @@ class TestFinetuneTextGeneration(unittest.TestCase):
             return cfg

         kwargs = dict(
-            model=self.model_id,
+            model='damo/nlp_palm2.0_text-generation_chinese-base',
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
            work_dir=self.tmp_dir,
-            cfg_modify_fn=cfg_modify_fn,
-            model_revision='beta')
+            cfg_modify_fn=cfg_modify_fn)

         trainer = build_trainer(
             name=Trainers.nlp_base_trainer, default_args=kwargs)
         trainer.train()