@@ -172,12 +172,17 @@ class CheckpointHook(Hook):
         else:
             model = trainer.model
 
+        config = trainer.cfg.to_dict()
+        # override pipeline by tasks name after finetune done,
+        # avoid case like fill mask pipeline with a text cls task
+        config['pipeline'] = {'type': config['task']}
+
         if hasattr(model, 'save_pretrained'):
             model.save_pretrained(
                 output_dir,
                 ModelFile.TORCH_MODEL_BIN_FILE,
                 save_function=save_checkpoint,
-                config=trainer.cfg.to_dict(),
+                config=config,
                 with_meta=False)
 
     def after_train_iter(self, trainer):
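For context, the effect of the added override can be sketched with plain dicts, independent of the trainer. This is a minimal sketch: the dict below stands in for trainer.cfg.to_dict(), and the task/pipeline values are hypothetical examples, not taken from the patch.

    # Stand-in for trainer.cfg.to_dict() after finetuning; values are hypothetical.
    finetuned_cfg = {
        'task': 'text-classification',      # task the model was finetuned on
        'pipeline': {'type': 'fill-mask'},  # pipeline inherited from the pretrained base
    }

    config = dict(finetuned_cfg)
    # Re-point the pipeline at the finetuned task, so the saved checkpoint is
    # not paired with the stale fill-mask pipeline.
    config['pipeline'] = {'type': config['task']}

    assert config['pipeline'] == {'type': 'text-classification'}

As the in-code comment notes, without this override the configuration written next to the checkpoint would still name the base model's pipeline (e.g. fill-mask) even though the model was finetuned for a different task such as text classification.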