| @@ -5,6 +5,7 @@ import unittest | |||||
| import json | import json | ||||
| from modelscope.metainfo import Trainers | |||||
| from modelscope.msdatasets import MsDataset | from modelscope.msdatasets import MsDataset | ||||
| from modelscope.trainers import build_trainer | from modelscope.trainers import build_trainer | ||||
| from modelscope.utils.constant import DownloadMode, ModelFile | from modelscope.utils.constant import DownloadMode, ModelFile | ||||
| @@ -95,7 +96,7 @@ class TestOfaTrainer(unittest.TestCase): | |||||
| split='test[:20]', | split='test[:20]', | ||||
| download_mode=DownloadMode.REUSE_DATASET_IF_EXISTS), | download_mode=DownloadMode.REUSE_DATASET_IF_EXISTS), | ||||
| cfg_file=config_file) | cfg_file=config_file) | ||||
| trainer = build_trainer(name='ofa', default_args=args) | |||||
| trainer = build_trainer(name=Trainers.ofa, default_args=args) | |||||
| trainer.train() | trainer.train() | ||||
| self.assertIn( | self.assertIn( | ||||