|
|
|
@@ -5,11 +5,13 @@ import tempfile |
|
|
|
import unittest |
|
|
|
|
|
|
|
from modelscope.hub.snapshot_download import snapshot_download |
|
|
|
from modelscope.metainfo import Metrics |
|
|
|
from modelscope.models.nlp.sbert_for_sequence_classification import \ |
|
|
|
SbertTextClassfier |
|
|
|
from modelscope.msdatasets import MsDataset |
|
|
|
from modelscope.trainers import build_trainer |
|
|
|
from modelscope.utils.constant import ModelFile |
|
|
|
from modelscope.utils.hub import read_config |
|
|
|
from modelscope.utils.test_utils import test_level |
|
|
|
|
|
|
|
|
|
|
|
@@ -73,6 +75,36 @@ class TestTrainerWithNlp(unittest.TestCase): |
|
|
|
for i in range(10): |
|
|
|
self.assertIn(f'epoch_{i+1}.pth', results_files) |
|
|
|
|
|
|
|
eval_results = trainer.evaluate( |
|
|
|
checkpoint_path=os.path.join(self.tmp_dir, 'epoch_10.pth')) |
|
|
|
self.assertTrue(Metrics.accuracy in eval_results) |
|
|
|
|
|
|
|
@unittest.skipUnless(test_level() >= 1, 'skip test in current test level') |
|
|
|
def test_trainer_with_user_defined_config(self): |
|
|
|
model_id = 'damo/nlp_structbert_sentiment-classification_chinese-base' |
|
|
|
cfg = read_config(model_id, revision='beta') |
|
|
|
cfg.train.max_epochs = 20 |
|
|
|
cfg.train.work_dir = self.tmp_dir |
|
|
|
cfg_file = os.path.join(self.tmp_dir, 'config.json') |
|
|
|
cfg.dump(cfg_file) |
|
|
|
kwargs = dict( |
|
|
|
model=model_id, |
|
|
|
train_dataset=self.dataset, |
|
|
|
eval_dataset=self.dataset, |
|
|
|
cfg_file=cfg_file, |
|
|
|
model_revision='beta') |
|
|
|
|
|
|
|
trainer = build_trainer(default_args=kwargs) |
|
|
|
trainer.train() |
|
|
|
results_files = os.listdir(self.tmp_dir) |
|
|
|
self.assertIn(f'{trainer.timestamp}.log.json', results_files) |
|
|
|
for i in range(20): |
|
|
|
self.assertIn(f'epoch_{i+1}.pth', results_files) |
|
|
|
|
|
|
|
eval_results = trainer.evaluate( |
|
|
|
checkpoint_path=os.path.join(self.tmp_dir, 'epoch_10.pth')) |
|
|
|
self.assertTrue(Metrics.accuracy in eval_results) |
|
|
|
|
|
|
|
@unittest.skipUnless(test_level() >= 2, 'skip test in current test level') |
|
|
|
def test_trainer_with_model_and_args(self): |
|
|
|
tmp_dir = tempfile.TemporaryDirectory().name |
|
|
|
|