
[to #42322933] Support saving the best checkpoint for inference

1. Support saving the best checkpoint for inference (see the usage sketch below)
2. Fix a bug where the trainer has no _max_iters field
3. Fix a bug where a function in the lambda_lr field cannot be saved to file
4. Fix a bug where save_pretrained was not called when training by iteration
5. Fix a bug where interval was not passed through BestCkptSaverHook's init
        Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9972765
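
A minimal usage sketch of changes 1 and 5, not part of the diff; `metric_key` is an assumed argument name that does not appear in the hunks below:

from modelscope.trainers.hooks.checkpoint_hook import BestCkptSaverHook

# `interval` is now forwarded to CheckpointHook.__init__ (fix 5);
# rule='max' keeps the checkpoint with the largest metric value.
hook = BestCkptSaverHook(
    metric_key='accuracy',  # assumed argument name, not shown in this diff
    rule='max',
    by_epoch=True,
    save_optimizer=True,
    save_dir='./work_dir',
    interval=1)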
yuze.zyz (author) · yingda.chen (committer) · 3 years ago · branch: master · commit a9c14e4ead
5 changed files with 50 additions and 40 deletions
  1. modelscope/trainers/hooks/checkpoint_hook.py  (+24, -20)
  2. modelscope/trainers/hooks/hook.py  (+2, -2)
  3. modelscope/utils/checkpoint.py  (+11, -6)
  4. modelscope/utils/config.py  (+3, -0)
  5. tests/trainers/test_finetune_text_generation.py  (+10, -12)

modelscope/trainers/hooks/checkpoint_hook.py  (+24, -20)

@@ -27,7 +27,7 @@ class CheckpointHook(Hook):
         save_last (bool): Whether to save the last checkpoint. Default: True.
     """

-    PRIORITY = Priority.NORMAL
+    PRIORITY = Priority.LOW

     def __init__(self,
                  interval=0,
@@ -75,25 +75,27 @@ class CheckpointHook(Hook):
                 self.save_dir, f'{LogKeys.ITER}_{trainer.iter + 1}.pth')

         save_checkpoint(trainer.model, cur_save_name, trainer.optimizer)
-        self._save_pretrained(trainer)
+        if (self.is_last_epoch(trainer)
+                and self.by_epoch) or (self.is_last_iter(trainer)
+                                       and not self.by_epoch):
+            self._save_pretrained(trainer)

     def _save_pretrained(self, trainer):
-        if self.is_last_epoch(trainer) and self.by_epoch:
-            output_dir = os.path.join(self.save_dir,
-                                      ModelFile.TRAIN_OUTPUT_DIR)
-            from modelscope.trainers.parallel.utils import is_parallel
-
-            if is_parallel(trainer.model):
-                model = trainer.model.module
-            else:
-                model = trainer.model
-
-            if hasattr(model, 'save_pretrained'):
-                model.save_pretrained(
-                    output_dir,
-                    ModelFile.TORCH_MODEL_BIN_FILE,
-                    save_function=save_checkpoint,
-                    config=trainer.cfg.to_dict())
+        output_dir = os.path.join(self.save_dir, ModelFile.TRAIN_OUTPUT_DIR)
+        from modelscope.trainers.parallel.utils import is_parallel
+
+        if is_parallel(trainer.model):
+            model = trainer.model.module
+        else:
+            model = trainer.model
+
+        if hasattr(model, 'save_pretrained'):
+            model.save_pretrained(
+                output_dir,
+                ModelFile.TORCH_MODEL_BIN_FILE,
+                save_function=save_checkpoint,
+                config=trainer.cfg.to_dict(),
+                with_meta=False)

     def after_train_iter(self, trainer):
         if self.by_epoch:
@@ -133,7 +135,7 @@ class BestCkptSaverHook(CheckpointHook):
         save_dir (str): Output directory to save best checkpoint.
     """

-    PRIORITY = Priority.NORMAL
+    PRIORITY = Priority.LOW
     rule_map = {'max': lambda x, y: x > y, 'min': lambda x, y: x < y}

     def __init__(self,
@@ -141,9 +143,11 @@ class BestCkptSaverHook(CheckpointHook):
                  rule='max',
                  by_epoch=True,
                  save_optimizer=True,
-                 save_dir=None):
+                 save_dir=None,
+                 interval=0):
        assert rule in ['max', 'min'], 'Only support "max" or "min" rule now.'
        super().__init__(
+            interval=interval,
            by_epoch=by_epoch,
            save_optimizer=save_optimizer,
            save_dir=save_dir,
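
As a sanity check on the new gating above, a self-contained sketch (the helper is hypothetical, mirroring `is_last_epoch`/`is_last_iter`) of when the inference model is exported:

def should_export(by_epoch, epoch, max_epochs, it, max_iters):
    # Mirrors the condition added above: export the inference model
    # once, at the very end of training, in either counting mode.
    if by_epoch:
        return epoch + 1 == max_epochs
    return it + 1 == max_iters

assert should_export(True, 4, 5, 0, 0)       # last epoch -> export
assert not should_export(True, 3, 5, 0, 0)   # mid-training -> skip
assert should_export(False, 0, 0, 99, 100)   # last iteration -> export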


modelscope/trainers/hooks/hook.py  (+2, -2)

@@ -199,14 +199,14 @@ class Hook:
        Whether to reach the last epoch
        Returns: bool
        """
-        return trainer.epoch + 1 == trainer._max_epochs
+        return trainer.epoch + 1 == trainer.max_epochs

    def is_last_iter(self, trainer):
        """
        Whether to reach the last iteration in the entire training process
        Returns: bool
        """
-        return trainer.iter + 1 == trainer._max_iters
+        return trainer.iter + 1 == trainer.max_iters

    def get_triggered_stages(self):
        trigger_stages = set()
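
The underscore-prefixed fields are not part of the trainer's public surface, which is what fix 2 addresses. A hedged sketch (DummyTrainer is hypothetical) of the contract the hook now relies on:

class DummyTrainer:
    """Hypothetical trainer exposing read-only public limits."""

    def __init__(self, max_epochs, max_iters):
        self._max_epochs = max_epochs  # private storage may differ per trainer
        self._max_iters = max_iters

    @property
    def max_epochs(self):
        return self._max_epochs

    @property
    def max_iters(self):
        return self._max_iters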


modelscope/utils/checkpoint.py  (+11, -6)

@@ -40,7 +40,8 @@ def weights_to_cpu(state_dict):
 def save_checkpoint(model: torch.nn.Module,
                     filename: str,
                     optimizer: Optional[Optimizer] = None,
-                    meta: Optional[dict] = None) -> None:
+                    meta: Optional[dict] = None,
+                    with_meta: bool = True) -> None:
     """Save checkpoint to file.

     The checkpoint will have 3 fields: ``meta``, ``state_dict`` and
@@ -65,10 +66,14 @@ def save_checkpoint(model: torch.nn.Module,
         # save class name to the meta
         meta.update(CLASSES=model.CLASSES)

-    checkpoint = {
-        'meta': meta,
-        'state_dict': weights_to_cpu(model.state_dict())
-    }
+    if with_meta:
+        checkpoint = {
+            'meta': meta,
+            'state_dict': weights_to_cpu(model.state_dict())
+        }
+    else:
+        checkpoint = weights_to_cpu(model.state_dict())

     # save optimizer state dict in the checkpoint
     if isinstance(optimizer, Optimizer):
         checkpoint['optimizer'] = optimizer.state_dict()
@@ -141,7 +146,7 @@ def save_pretrained(model,

     # Save the ckpt to the save directory
     try:
-        save_function(model, output_ckpt_path)
+        save_function(model, output_ckpt_path, **kwargs)
     except Exception as e:
         raise Exception(
             f'During saving checkpoints, the error of "{type(e).__name__} '
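
A hedged usage sketch of the new `with_meta` switch (the file name and model here are illustrative only):

import torch
from modelscope.utils.checkpoint import save_checkpoint

model = torch.nn.Linear(4, 2)
# with_meta=False writes a bare state_dict, the layout that plain
# `load_state_dict`-style inference loaders expect.
save_checkpoint(model, 'pytorch_model.bin', with_meta=False)
state = torch.load('pytorch_model.bin')
model.load_state_dict(state)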


modelscope/utils/config.py  (+3, -0)

@@ -9,6 +9,7 @@ import sys
 import tempfile
 import types
 from pathlib import Path
+from types import FunctionType
 from typing import Dict, Union

 import addict
@@ -638,6 +639,8 @@ class JSONIteratorEncoder(json.JSONEncoder):
     """

     def default(self, obj):
+        if isinstance(obj, FunctionType):
+            return None
         try:
             iterable = iter(obj)
         except TypeError:
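
A self-contained sketch of the encoder behavior after this change, which is what makes fix 3 work; DemoEncoder is a stand-in for JSONIteratorEncoder, which additionally handles iterables:

import json
from types import FunctionType

class DemoEncoder(json.JSONEncoder):
    def default(self, obj):
        # Functions (e.g. a lambda_lr closure) serialize as null instead
        # of raising TypeError, so configs with schedulers stay dumpable.
        if isinstance(obj, FunctionType):
            return None
        return super().default(obj)

cfg = {'lr_scheduler': {'type': 'LambdaLR', 'lr_lambda': lambda step: 0.1}}
print(json.dumps(cfg, cls=DemoEncoder))
# -> {"lr_scheduler": {"type": "LambdaLR", "lr_lambda": null}}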


tests/trainers/test_finetune_text_generation.py  (+10, -12)

@@ -128,15 +128,14 @@ class TestFinetuneTextGeneration(unittest.TestCase):

     @unittest.skip
     def test_finetune_cnndm(self):
-        from datasets import load_dataset
-        dataset_dict = load_dataset('ccdv/cnn_dailymail', '3.0.0')
-        train_dataset = dataset_dict['train'] \
-            .rename_columns({'article': 'src_txt', 'highlights': 'tgt_txt'}) \
-            .remove_columns('id')
-        eval_dataset = dataset_dict['validation'] \
-            .rename_columns({'article': 'src_txt', 'highlights': 'tgt_txt'}) \
-            .remove_columns('id')
-        num_warmup_steps = 2000
+        from modelscope.msdatasets import MsDataset
+        dataset_dict = MsDataset.load('dureader_robust_qg')
+        train_dataset = dataset_dict['train'].to_hf_dataset() \
+            .rename_columns({'text1': 'src_txt', 'text2': 'tgt_txt'})
+        eval_dataset = dataset_dict['validation'].to_hf_dataset() \
+            .rename_columns({'text1': 'src_txt', 'text2': 'tgt_txt'})
+        num_warmup_steps = 200
+        os.environ['LOCAL_RANK'] = '0'

         def noam_lambda(current_step: int):
             current_step += 1
@@ -154,12 +153,11 @@ class TestFinetuneTextGeneration(unittest.TestCase):
             return cfg

         kwargs = dict(
-            model=self.model_id,
+            model='damo/nlp_palm2.0_text-generation_chinese-base',
             train_dataset=train_dataset,
             eval_dataset=eval_dataset,
             work_dir=self.tmp_dir,
-            cfg_modify_fn=cfg_modify_fn,
-            model_revision='beta')
+            cfg_modify_fn=cfg_modify_fn)
         trainer = build_trainer(
             name=Trainers.nlp_base_trainer, default_args=kwargs)
         trainer.train()
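
For reference, the classic Noam schedule that the test's `noam_lambda` presumably implements; the exact body is not shown in this hunk, so this is an assumed form with illustrative parameter values:

def noam_lambda(current_step: int,
                num_warmup_steps: int = 200,
                d_model: int = 512) -> float:
    # Standard Noam/Transformer schedule: linear warmup followed by
    # inverse square-root decay.
    current_step += 1
    return (d_model ** -0.5) * min(current_step ** -0.5,
                                   current_step * num_warmup_steps ** -1.5)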

