
[to #42322933] Support saving the best checkpoint for inference

1. Support saving the best checkpoint for inference
2. Fix a bug where the _max_iters field does not exist on the trainer
3. Fix a bug where a function stored in the lambda_lr field cannot be saved to file
4. Fix a bug where save_pretrained is not called when training by iteration
5. Fix a bug where interval is not passed on from BestCkptSaverHook's init
Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9972765

Branch: master
yuze.zyz  yingda.chen · 3 years ago
Parent commit: a9c14e4ead

5 changed files with 50 additions and 40 deletions
  1. modelscope/trainers/hooks/checkpoint_hook.py          +24  -20
  2. modelscope/trainers/hooks/hook.py                       +2   -2
  3. modelscope/utils/checkpoint.py                         +11   -6
  4. modelscope/utils/config.py                              +3   -0
  5. tests/trainers/test_finetune_text_generation.py        +10  -12

modelscope/trainers/hooks/checkpoint_hook.py  (+24, -20)

@@ -27,7 +27,7 @@ class CheckpointHook(Hook):
         save_last (bool): Whether to save the last checkpoint. Default: True.
     """
 
-    PRIORITY = Priority.NORMAL
+    PRIORITY = Priority.LOW
 
     def __init__(self,
                  interval=0,
@@ -75,25 +75,27 @@ class CheckpointHook(Hook):
                 self.save_dir, f'{LogKeys.ITER}_{trainer.iter + 1}.pth')
 
         save_checkpoint(trainer.model, cur_save_name, trainer.optimizer)
-        self._save_pretrained(trainer)
+        if (self.is_last_epoch(trainer)
+                and self.by_epoch) or (self.is_last_iter(trainer)
+                                       and not self.by_epoch):
+            self._save_pretrained(trainer)
 
     def _save_pretrained(self, trainer):
-        if self.is_last_epoch(trainer) and self.by_epoch:
-            output_dir = os.path.join(self.save_dir,
-                                      ModelFile.TRAIN_OUTPUT_DIR)
-            from modelscope.trainers.parallel.utils import is_parallel
-
-            if is_parallel(trainer.model):
-                model = trainer.model.module
-            else:
-                model = trainer.model
-
-            if hasattr(model, 'save_pretrained'):
-                model.save_pretrained(
-                    output_dir,
-                    ModelFile.TORCH_MODEL_BIN_FILE,
-                    save_function=save_checkpoint,
-                    config=trainer.cfg.to_dict())
+        output_dir = os.path.join(self.save_dir, ModelFile.TRAIN_OUTPUT_DIR)
+        from modelscope.trainers.parallel.utils import is_parallel
+
+        if is_parallel(trainer.model):
+            model = trainer.model.module
+        else:
+            model = trainer.model
+
+        if hasattr(model, 'save_pretrained'):
+            model.save_pretrained(
+                output_dir,
+                ModelFile.TORCH_MODEL_BIN_FILE,
+                save_function=save_checkpoint,
+                config=trainer.cfg.to_dict(),
+                with_meta=False)
 
     def after_train_iter(self, trainer):
         if self.by_epoch:
@@ -133,7 +135,7 @@ class BestCkptSaverHook(CheckpointHook):
         save_dir (str): Output directory to save best checkpoint.
     """
 
-    PRIORITY = Priority.NORMAL
+    PRIORITY = Priority.LOW
     rule_map = {'max': lambda x, y: x > y, 'min': lambda x, y: x < y}
 
     def __init__(self,
@@ -141,9 +143,11 @@ class BestCkptSaverHook(CheckpointHook):
                 rule='max',
                 by_epoch=True,
                 save_optimizer=True,
-                 save_dir=None):
+                 save_dir=None,
+                 interval=0):
        assert rule in ['max', 'min'], 'Only support "max" or "min" rule now.'
        super().__init__(
+            interval=interval,
            by_epoch=by_epoch,
            save_optimizer=save_optimizer,
            save_dir=save_dir,

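Taken together, these changes make CheckpointHook write an inference-ready copy of the model once training finishes (last epoch when training by epoch, last iteration otherwise), and let BestCkptSaverHook forward an interval to its parent constructor instead of dropping it. A minimal sketch of how the two hooks might be declared in a trainer configuration after this commit; the metric_key argument and its value are assumptions for illustration, since the diff does not show the full signature:

    # Hedged sketch of hook entries in a trainer cfg. 'interval' on
    # BestCkptSaverHook is the argument added by this commit; 'metric_key'
    # and the metric name are illustrative assumptions.
    hooks = [
        dict(type='CheckpointHook', interval=1, by_epoch=True),
        dict(
            type='BestCkptSaverHook',
            metric_key='accuracy',  # assumed argument; depends on the task
            rule='max',
            by_epoch=True,
            interval=1),
    ]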

modelscope/trainers/hooks/hook.py  (+2, -2)

@@ -199,14 +199,14 @@ class Hook:
         Whether to reach the last epoch
         Returns: bool
         """
-        return trainer.epoch + 1 == trainer._max_epochs
+        return trainer.epoch + 1 == trainer.max_epochs
 
     def is_last_iter(self, trainer):
         """
         Whether to reach the last iteration in the entire training process
         Returns: bool
         """
-        return trainer.iter + 1 == trainer._max_iters
+        return trainer.iter + 1 == trainer.max_iters
 
     def get_triggered_stages(self):
         trigger_stages = set()

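The fix here swaps the private _max_epochs/_max_iters attributes for the trainer's public max_epochs/max_iters, since not every trainer stores the private fields (commit message item 2). A small illustrative class, not part of the library, showing why the public properties are the safer access path:

    class IllustrativeTrainer:
        """Toy trainer: max_iters is derived on demand, so _max_iters never exists."""

        def __init__(self, max_epochs: int, iters_per_epoch: int):
            self._max_epochs = max_epochs
            self._iters_per_epoch = iters_per_epoch
            self.epoch = 0
            self.iter = 0

        @property
        def max_epochs(self) -> int:
            return self._max_epochs

        @property
        def max_iters(self) -> int:
            # Computed from epochs * iterations per epoch; reading
            # trainer._max_iters directly, as the old hook code did,
            # would raise AttributeError on such a trainer.
            return self._max_epochs * self._iters_per_epoch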

modelscope/utils/checkpoint.py  (+11, -6)

@@ -40,7 +40,8 @@ def weights_to_cpu(state_dict):
 def save_checkpoint(model: torch.nn.Module,
                     filename: str,
                     optimizer: Optional[Optimizer] = None,
-                    meta: Optional[dict] = None) -> None:
+                    meta: Optional[dict] = None,
+                    with_meta: bool = True) -> None:
     """Save checkpoint to file.
 
     The checkpoint will have 3 fields: ``meta``, ``state_dict`` and
@@ -65,10 +66,14 @@ def save_checkpoint(model: torch.nn.Module,
         # save class name to the meta
         meta.update(CLASSES=model.CLASSES)
 
-    checkpoint = {
-        'meta': meta,
-        'state_dict': weights_to_cpu(model.state_dict())
-    }
+    if with_meta:
+        checkpoint = {
+            'meta': meta,
+            'state_dict': weights_to_cpu(model.state_dict())
+        }
+    else:
+        checkpoint = weights_to_cpu(model.state_dict())
 
     # save optimizer state dict in the checkpoint
     if isinstance(optimizer, Optimizer):
         checkpoint['optimizer'] = optimizer.state_dict()
@@ -141,7 +146,7 @@ def save_pretrained(model,
 
     # Save the ckpt to the save directory
     try:
-        save_function(model, output_ckpt_path)
+        save_function(model, output_ckpt_path, **kwargs)
     except Exception as e:
         raise Exception(
             f'During saving checkpoints, the error of "{type(e).__name__} '

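The new with_meta flag lets save_pretrained (through its save_function) write a bare state_dict for inference rather than the wrapped training checkpoint. A quick sketch of the two output formats, assuming the behaviour shown in the hunks above; the file names are arbitrary:

    import torch

    from modelscope.utils.checkpoint import save_checkpoint

    model = torch.nn.Linear(4, 2)

    # Training checkpoint: a dict with 'meta' and 'state_dict' (plus the
    # optimizer state when an optimizer is passed in).
    save_checkpoint(model, 'epoch_1.pth', meta={'epoch': 1})

    # Inference weights: the raw state_dict only, directly loadable.
    save_checkpoint(model, 'pytorch_model.bin', with_meta=False)
    model.load_state_dict(torch.load('pytorch_model.bin'))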

modelscope/utils/config.py  (+3, -0)

@@ -9,6 +9,7 @@ import sys
 import tempfile
 import types
 from pathlib import Path
+from types import FunctionType
 from typing import Dict, Union
 
 import addict
@@ -638,6 +639,8 @@ class JSONIteratorEncoder(json.JSONEncoder):
     """
 
     def default(self, obj):
+        if isinstance(obj, FunctionType):
+            return None
         try:
             iterable = iter(obj)
         except TypeError:

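This encoder change is what allows a config containing a plain Python function (for example an lr lambda) to be dumped to JSON: callables are written out as null instead of raising TypeError (commit message item 3). A standalone re-implementation for illustration only, not the library class; the config content is made up:

    import json
    from types import FunctionType


    class IteratorEncoder(json.JSONEncoder):
        """Simplified stand-in for the patched JSONIteratorEncoder."""

        def default(self, obj):
            if isinstance(obj, FunctionType):
                return None  # functions such as lr lambdas become null
            try:
                iterable = iter(obj)
            except TypeError:
                pass
            else:
                return list(iterable)
            return super().default(obj)


    cfg = {'lr_scheduler': {'type': 'LambdaLR', 'lr_lambda': lambda step: 0.1}}
    print(json.dumps(cfg, cls=IteratorEncoder))
    # -> {"lr_scheduler": {"type": "LambdaLR", "lr_lambda": null}}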

tests/trainers/test_finetune_text_generation.py  (+10, -12)

@@ -128,15 +128,14 @@ class TestFinetuneTextGeneration(unittest.TestCase):
 
     @unittest.skip
     def test_finetune_cnndm(self):
-        from datasets import load_dataset
-        dataset_dict = load_dataset('ccdv/cnn_dailymail', '3.0.0')
-        train_dataset = dataset_dict['train'] \
-            .rename_columns({'article': 'src_txt', 'highlights': 'tgt_txt'}) \
-            .remove_columns('id')
-        eval_dataset = dataset_dict['validation'] \
-            .rename_columns({'article': 'src_txt', 'highlights': 'tgt_txt'}) \
-            .remove_columns('id')
-        num_warmup_steps = 2000
+        from modelscope.msdatasets import MsDataset
+        dataset_dict = MsDataset.load('dureader_robust_qg')
+        train_dataset = dataset_dict['train'].to_hf_dataset() \
+            .rename_columns({'text1': 'src_txt', 'text2': 'tgt_txt'})
+        eval_dataset = dataset_dict['validation'].to_hf_dataset() \
+            .rename_columns({'text1': 'src_txt', 'text2': 'tgt_txt'})
+        num_warmup_steps = 200
+        os.environ['LOCAL_RANK'] = '0'
 
         def noam_lambda(current_step: int):
             current_step += 1
@@ -154,12 +153,11 @@ class TestFinetuneTextGeneration(unittest.TestCase):
             return cfg
 
         kwargs = dict(
-            model=self.model_id,
+            model='damo/nlp_palm2.0_text-generation_chinese-base',
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
            work_dir=self.tmp_dir,
-            cfg_modify_fn=cfg_modify_fn,
-            model_revision='beta')
+            cfg_modify_fn=cfg_modify_fn)
        trainer = build_trainer(
            name=Trainers.nlp_base_trainer, default_args=kwargs)
        trainer.train()

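The noam_lambda defined in this test is exactly the kind of function that ends up inside the trainer's config, which is why the FunctionType handling above is needed when the config is written to file. A hedged sketch of such a cfg_modify_fn; the field names are assumptions based on this test and are not confirmed by the diff:

    def cfg_modify_fn(cfg):
        # Assumed field layout; the diff only shows that a cfg_modify_fn is
        # passed to the trainer. The plain function stored under 'lr_lambda'
        # is what the JSONIteratorEncoder change now serializes as null.
        cfg.train.lr_scheduler = {
            'type': 'LambdaLR',
            'lr_lambda': noam_lambda,
            'options': {'by_epoch': False},
        }
        cfg.train.max_epochs = 3
        return cfg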
