| @@ -79,7 +79,7 @@ class TestOfaTrainer(unittest.TestCase): | |||||
| with open(config_file, 'w') as writer: | with open(config_file, 'w') as writer: | ||||
| json.dump(self.finetune_cfg, writer) | json.dump(self.finetune_cfg, writer) | ||||
| pretrained_model = 'damo/ofa_image-caption_coco_large_en' | |||||
| pretrained_model = 'damo/ofa_image-caption_coco_distilled_en' | |||||
| args = dict( | args = dict( | ||||
| model=pretrained_model, | model=pretrained_model, | ||||
| work_dir=WORKSPACE, | work_dir=WORKSPACE, | ||||
| @@ -97,8 +97,8 @@ class TestOfaTrainer(unittest.TestCase): | |||||
| trainer.train() | trainer.train() | ||||
| self.assertIn(ModelFile.TORCH_MODEL_BIN_FILE, | self.assertIn(ModelFile.TORCH_MODEL_BIN_FILE, | ||||
| os.path.join(WORKSPACE, 'output')) | |||||
| shutil.rmtree(WORKSPACE) | |||||
| os.listdir(os.path.join(WORKSPACE, 'output'))) | |||||
| # shutil.rmtree(WORKSPACE) | |||||
| if __name__ == '__main__': | if __name__ == '__main__': | ||||