diff --git a/tests/trainers/test_trainer.py b/tests/trainers/test_trainer.py index 03b13674..b7639024 100644 --- a/tests/trainers/test_trainer.py +++ b/tests/trainers/test_trainer.py @@ -13,7 +13,7 @@ from torch import nn from torch.optim import SGD from torch.optim.lr_scheduler import StepLR -from modelscope.metainfo import Trainers +from modelscope.metainfo import Metrics, Trainers from modelscope.metrics.builder import MetricKeys from modelscope.msdatasets import MsDataset from modelscope.trainers import build_trainer @@ -102,7 +102,7 @@ class TrainerTest(unittest.TestCase): 'workers_per_gpu': 1, 'shuffle': False }, - 'metrics': ['seq-cls-metric'] + 'metrics': [Metrics.seq_cls_metric] } } config_path = os.path.join(self.tmp_dir, ModelFile.CONFIGURATION) @@ -156,7 +156,7 @@ class TrainerTest(unittest.TestCase): 'workers_per_gpu': 1, 'shuffle': False }, - 'metrics': ['seq-cls-metric'] + 'metrics': [Metrics.seq_cls_metric] } } @@ -206,7 +206,7 @@ class TrainerTest(unittest.TestCase): 'workers_per_gpu': 1, 'shuffle': False }, - 'metrics': ['seq-cls-metric'] + 'metrics': [Metrics.seq_cls_metric] } } diff --git a/tests/trainers/test_trainer_gpu.py b/tests/trainers/test_trainer_gpu.py index 6502a68d..30390a68 100644 --- a/tests/trainers/test_trainer_gpu.py +++ b/tests/trainers/test_trainer_gpu.py @@ -12,7 +12,7 @@ from torch import nn from torch.optim import SGD from torch.optim.lr_scheduler import StepLR -from modelscope.metainfo import Trainers +from modelscope.metainfo import Metrics, Trainers from modelscope.metrics.builder import MetricKeys from modelscope.trainers import EpochBasedTrainer, build_trainer from modelscope.utils.constant import LogKeys, ModeKeys, ModelFile @@ -60,7 +60,7 @@ def train_func(work_dir, dist=False): 'workers_per_gpu': 1, 'shuffle': False }, - 'metrics': ['seq_cls_metric'] + 'metrics': [Metrics.seq_cls_metric] } }