diff --git a/modelscope/models/multi_modal/mplug/modeling_mplug.py b/modelscope/models/multi_modal/mplug/modeling_mplug.py
index 78f60f9b..f469c218 100755
--- a/modelscope/models/multi_modal/mplug/modeling_mplug.py
+++ b/modelscope/models/multi_modal/mplug/modeling_mplug.py
@@ -1867,11 +1867,13 @@ class MPlug(PreTrainedModel):
                                        ModelFile.TORCH_MODEL_BIN_FILE)
         checkpoint = torch.load(checkpoint_path, map_location='cpu')
         if 'model' in checkpoint:
-            state_dict = checkpoint['model']
-        else:
-            state_dict = checkpoint['module']
+            checkpoint = checkpoint['model']
+        checkpoint = {
+            k.replace('model.', ''): v
+            for k, v in checkpoint.items()
+        }
 
-        msg = model.load_state_dict(state_dict, strict=False)
+        msg = model.load_state_dict(checkpoint, strict=False)
         print('load checkpoint from %s' % checkpoint_path)
         print(msg)
         return model
diff --git a/modelscope/models/nlp/gpt3/modeling_gpt3.py b/modelscope/models/nlp/gpt3/modeling_gpt3.py
index 4e30f697..69e9ba7c 100644
--- a/modelscope/models/nlp/gpt3/modeling_gpt3.py
+++ b/modelscope/models/nlp/gpt3/modeling_gpt3.py
@@ -339,5 +339,9 @@ class GPT3Model(PreTrainedModel):
         state_dict_file = os.path.join(pretrained_model_name_or_path,
                                        ModelFile.TORCH_MODEL_BIN_FILE)
         state_dict = torch.load(state_dict_file)
+        state_dict = {
+            k.replace('model.language_model', 'language_model'): v
+            for k, v in state_dict.items()
+        }
         model.load_state_dict(state_dict)
         return model
diff --git a/modelscope/models/nlp/palm_v2/modeling_palm.py b/modelscope/models/nlp/palm_v2/modeling_palm.py
index ff6fd732..99b00454 100644
--- a/modelscope/models/nlp/palm_v2/modeling_palm.py
+++ b/modelscope/models/nlp/palm_v2/modeling_palm.py
@@ -592,11 +592,11 @@ class AbsSummarizer(PalmPreTrainedModel):  # Model
         self.generator.dense.weight = self.decoder.embeddings.weight
 
         if checkpoint is not None:
-            for key in list(checkpoint['model'].keys()):
-                checkpoint['model'][key.replace('module.',
-                                                '')] = checkpoint['model'][key]
-            msg = self.load_state_dict(checkpoint['model'], strict=False)
-            print(msg)
+            if 'model' in checkpoint:
+                checkpoint = checkpoint['model']
+            for key in list(checkpoint.keys()):
+                checkpoint[key.replace('model.palm.', '')] = checkpoint[key]
+            self.load_state_dict(checkpoint, strict=False)
         else:
             for module in self.decoder.modules():
                 if isinstance(module, (nn.Linear, nn.Embedding)):
@@ -734,7 +734,7 @@ class PalmForConditionalGeneration(PalmPreTrainedModel):
         return addict.Dict(loss=loss)
 
 
-class Translator(nn.Module):
+class Translator(object):
     """
     Uses a model to translate a batch of sentences.
     """
@@ -1298,8 +1298,8 @@ class Translator(nn.Module):
 
         return results
 
-    def forward(self, input_ids: torch.Tensor,
-                attention_mask: torch.Tensor) -> Dict[str, torch.Tensor]:
+    def __call__(self, input_ids: torch.Tensor,
+                 attention_mask: torch.Tensor) -> Dict[str, torch.Tensor]:
         batch = self.Batch(
             batch_size=input_ids.size()[0],
             src=input_ids,
diff --git a/tests/trainers/test_finetune_mplug.py b/tests/trainers/test_finetune_mplug.py
index b46dbf45..72196fba 100644
--- a/tests/trainers/test_finetune_mplug.py
+++ b/tests/trainers/test_finetune_mplug.py
@@ -41,6 +41,18 @@ class TestFinetuneMPlug(unittest.TestCase):
         shutil.rmtree(self.tmp_dir)
         super().tearDown()
 
+    def _cfg_modify_fn(self, cfg):
+        cfg.train.hooks = [{
+            'type': 'CheckpointHook',
+            'interval': self.max_epochs
+        }, {
+            'type': 'TextLoggerHook',
+            'interval': 1
+        }, {
+            'type': 'IterTimerHook'
+        }]
+        return cfg
+
     @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
     def test_trainer_with_caption(self):
         kwargs = dict(
@@ -48,15 +60,12 @@ class TestFinetuneMPlug(unittest.TestCase):
             train_dataset=self.train_dataset,
             eval_dataset=self.test_dataset,
             max_epochs=self.max_epochs,
-            work_dir=self.tmp_dir)
+            work_dir=self.tmp_dir,
+            cfg_modify_fn=self._cfg_modify_fn)
 
         trainer: EpochBasedTrainer = build_trainer(
             name=Trainers.nlp_base_trainer, default_args=kwargs)
         trainer.train()
-        results_files = os.listdir(self.tmp_dir)
-        self.assertIn(f'{trainer.timestamp}.log.json', results_files)
-        for i in range(self.max_epochs):
-            self.assertIn(f'epoch_{i+1}.pth', results_files)
 
     @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
     def test_trainer_with_caption_with_model_and_args(self):
@@ -86,15 +95,12 @@ class TestFinetuneMPlug(unittest.TestCase):
             train_dataset=self.train_dataset,
             eval_dataset=self.test_dataset,
             max_epochs=self.max_epochs,
-            work_dir=self.tmp_dir)
+            work_dir=self.tmp_dir,
+            cfg_modify_fn=self._cfg_modify_fn)
 
         trainer: EpochBasedTrainer = build_trainer(
             name=Trainers.nlp_base_trainer, default_args=kwargs)
         trainer.train()
-        results_files = os.listdir(self.tmp_dir)
-        self.assertIn(f'{trainer.timestamp}.log.json', results_files)
-        for i in range(self.max_epochs):
-            self.assertIn(f'epoch_{i+1}.pth', results_files)
 
     @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
     def test_trainer_with_vqa_with_model_and_args(self):
@@ -124,15 +130,12 @@ class TestFinetuneMPlug(unittest.TestCase):
             train_dataset=self.train_dataset,
             eval_dataset=self.test_dataset,
             max_epochs=self.max_epochs,
-            work_dir=self.tmp_dir)
+            work_dir=self.tmp_dir,
+            cfg_modify_fn=self._cfg_modify_fn)
 
         trainer: EpochBasedTrainer = build_trainer(
             name=Trainers.nlp_base_trainer, default_args=kwargs)
         trainer.train()
-        results_files = os.listdir(self.tmp_dir)
-        self.assertIn(f'{trainer.timestamp}.log.json', results_files)
-        for i in range(self.max_epochs):
-            self.assertIn(f'epoch_{i+1}.pth', results_files)
 
     @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
     def test_trainer_with_retrieval_with_model_and_args(self):