|
|
@@ -41,6 +41,18 @@ class TestFinetuneMPlug(unittest.TestCase): |
|
|
shutil.rmtree(self.tmp_dir) |
|
|
shutil.rmtree(self.tmp_dir) |
|
|
super().tearDown() |
|
|
super().tearDown() |
|
|
|
|
|
|
|
|
|
|
|
def _cfg_modify_fn(self, cfg): |
|
|
|
|
|
cfg.train.hooks = [{ |
|
|
|
|
|
'type': 'CheckpointHook', |
|
|
|
|
|
'interval': self.max_epochs |
|
|
|
|
|
}, { |
|
|
|
|
|
'type': 'TextLoggerHook', |
|
|
|
|
|
'interval': 1 |
|
|
|
|
|
}, { |
|
|
|
|
|
'type': 'IterTimerHook' |
|
|
|
|
|
}] |
|
|
|
|
|
return cfg |
|
|
|
|
|
|
|
|
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level') |
|
|
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level') |
|
|
def test_trainer_with_caption(self): |
|
|
def test_trainer_with_caption(self): |
|
|
kwargs = dict( |
|
|
kwargs = dict( |
|
|
@@ -48,15 +60,12 @@ class TestFinetuneMPlug(unittest.TestCase): |
|
|
train_dataset=self.train_dataset, |
|
|
train_dataset=self.train_dataset, |
|
|
eval_dataset=self.test_dataset, |
|
|
eval_dataset=self.test_dataset, |
|
|
max_epochs=self.max_epochs, |
|
|
max_epochs=self.max_epochs, |
|
|
work_dir=self.tmp_dir) |
|
|
|
|
|
|
|
|
work_dir=self.tmp_dir, |
|
|
|
|
|
cfg_modify_fn=self._cfg_modify_fn) |
|
|
|
|
|
|
|
|
trainer: EpochBasedTrainer = build_trainer( |
|
|
trainer: EpochBasedTrainer = build_trainer( |
|
|
name=Trainers.nlp_base_trainer, default_args=kwargs) |
|
|
name=Trainers.nlp_base_trainer, default_args=kwargs) |
|
|
trainer.train() |
|
|
trainer.train() |
|
|
results_files = os.listdir(self.tmp_dir) |
|
|
|
|
|
self.assertIn(f'{trainer.timestamp}.log.json', results_files) |
|
|
|
|
|
for i in range(self.max_epochs): |
|
|
|
|
|
self.assertIn(f'epoch_{i+1}.pth', results_files) |
|
|
|
|
|
|
|
|
|
|
|
@unittest.skipUnless(test_level() >= 2, 'skip test in current test level') |
|
|
@unittest.skipUnless(test_level() >= 2, 'skip test in current test level') |
|
|
def test_trainer_with_caption_with_model_and_args(self): |
|
|
def test_trainer_with_caption_with_model_and_args(self): |
|
|
@@ -86,15 +95,12 @@ class TestFinetuneMPlug(unittest.TestCase): |
|
|
train_dataset=self.train_dataset, |
|
|
train_dataset=self.train_dataset, |
|
|
eval_dataset=self.test_dataset, |
|
|
eval_dataset=self.test_dataset, |
|
|
max_epochs=self.max_epochs, |
|
|
max_epochs=self.max_epochs, |
|
|
work_dir=self.tmp_dir) |
|
|
|
|
|
|
|
|
work_dir=self.tmp_dir, |
|
|
|
|
|
cfg_modify_fn=self._cfg_modify_fn) |
|
|
|
|
|
|
|
|
trainer: EpochBasedTrainer = build_trainer( |
|
|
trainer: EpochBasedTrainer = build_trainer( |
|
|
name=Trainers.nlp_base_trainer, default_args=kwargs) |
|
|
name=Trainers.nlp_base_trainer, default_args=kwargs) |
|
|
trainer.train() |
|
|
trainer.train() |
|
|
results_files = os.listdir(self.tmp_dir) |
|
|
|
|
|
self.assertIn(f'{trainer.timestamp}.log.json', results_files) |
|
|
|
|
|
for i in range(self.max_epochs): |
|
|
|
|
|
self.assertIn(f'epoch_{i+1}.pth', results_files) |
|
|
|
|
|
|
|
|
|
|
|
@unittest.skipUnless(test_level() >= 2, 'skip test in current test level') |
|
|
@unittest.skipUnless(test_level() >= 2, 'skip test in current test level') |
|
|
def test_trainer_with_vqa_with_model_and_args(self): |
|
|
def test_trainer_with_vqa_with_model_and_args(self): |
|
|
@@ -124,15 +130,12 @@ class TestFinetuneMPlug(unittest.TestCase): |
|
|
train_dataset=self.train_dataset, |
|
|
train_dataset=self.train_dataset, |
|
|
eval_dataset=self.test_dataset, |
|
|
eval_dataset=self.test_dataset, |
|
|
max_epochs=self.max_epochs, |
|
|
max_epochs=self.max_epochs, |
|
|
work_dir=self.tmp_dir) |
|
|
|
|
|
|
|
|
work_dir=self.tmp_dir, |
|
|
|
|
|
cfg_modify_fn=self._cfg_modify_fn) |
|
|
|
|
|
|
|
|
trainer: EpochBasedTrainer = build_trainer( |
|
|
trainer: EpochBasedTrainer = build_trainer( |
|
|
name=Trainers.nlp_base_trainer, default_args=kwargs) |
|
|
name=Trainers.nlp_base_trainer, default_args=kwargs) |
|
|
trainer.train() |
|
|
trainer.train() |
|
|
results_files = os.listdir(self.tmp_dir) |
|
|
|
|
|
self.assertIn(f'{trainer.timestamp}.log.json', results_files) |
|
|
|
|
|
for i in range(self.max_epochs): |
|
|
|
|
|
self.assertIn(f'epoch_{i+1}.pth', results_files) |
|
|
|
|
|
|
|
|
|
|
|
@unittest.skipUnless(test_level() >= 2, 'skip test in current test level') |
|
|
@unittest.skipUnless(test_level() >= 2, 'skip test in current test level') |
|
|
def test_trainer_with_retrieval_with_model_and_args(self): |
|
|
def test_trainer_with_retrieval_with_model_and_args(self): |
|
|
|