diff --git a/tests/trainers/test_ofa_trainer.py b/tests/trainers/test_ofa_trainer.py index 786599bb..46dc5c8b 100644 --- a/tests/trainers/test_ofa_trainer.py +++ b/tests/trainers/test_ofa_trainer.py @@ -79,7 +79,7 @@ class TestOfaTrainer(unittest.TestCase): with open(config_file, 'w') as writer: json.dump(self.finetune_cfg, writer) - pretrained_model = 'damo/ofa_image-caption_coco_large_en' + pretrained_model = 'damo/ofa_image-caption_coco_distilled_en' args = dict( model=pretrained_model, work_dir=WORKSPACE, @@ -97,8 +97,8 @@ class TestOfaTrainer(unittest.TestCase): trainer.train() self.assertIn(ModelFile.TORCH_MODEL_BIN_FILE, - os.path.join(WORKSPACE, 'output')) - shutil.rmtree(WORKSPACE) + os.listdir(os.path.join(WORKSPACE, 'output'))) + # shutil.rmtree(WORKSPACE) if __name__ == '__main__':