Browse Source

override pipeline by tasks name after finetune done, avoid case like fill mask pipeline with a text cls task

Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10554512
master
zhangzhicheng.zzc yingda.chen 3 years ago
parent
commit
9df3f5c41f
1 changed files with 6 additions and 1 deletions
  1. +6
    -1
      modelscope/trainers/hooks/checkpoint_hook.py

+ 6
- 1
modelscope/trainers/hooks/checkpoint_hook.py View File

@@ -172,12 +172,17 @@ class CheckpointHook(Hook):
else:
model = trainer.model

config = trainer.cfg.to_dict()
# override pipeline by tasks name after finetune done,
# avoid case like fill mask pipeline with a text cls task
config['pipeline'] = {'type': config['task']}

if hasattr(model, 'save_pretrained'):
model.save_pretrained(
output_dir,
ModelFile.TORCH_MODEL_BIN_FILE,
save_function=save_checkpoint,
config=trainer.cfg.to_dict(),
config=config,
with_meta=False)

def after_train_iter(self, trainer):


Loading…
Cancel
Save