jiangnana.jnn 3 years ago
parent
commit
6f5b864735
2 changed files with 6 additions and 6 deletions
  1. +4
    -4
      tests/trainers/test_trainer.py
  2. +2
    -2
      tests/trainers/test_trainer_gpu.py

+ 4
- 4
tests/trainers/test_trainer.py View File

@@ -13,7 +13,7 @@ from torch import nn
from torch.optim import SGD from torch.optim import SGD
from torch.optim.lr_scheduler import StepLR from torch.optim.lr_scheduler import StepLR


from modelscope.metainfo import Trainers
from modelscope.metainfo import Metrics, Trainers
from modelscope.metrics.builder import MetricKeys from modelscope.metrics.builder import MetricKeys
from modelscope.msdatasets import MsDataset from modelscope.msdatasets import MsDataset
from modelscope.trainers import build_trainer from modelscope.trainers import build_trainer
@@ -102,7 +102,7 @@ class TrainerTest(unittest.TestCase):
'workers_per_gpu': 1, 'workers_per_gpu': 1,
'shuffle': False 'shuffle': False
}, },
'metrics': ['seq-cls-metric']
'metrics': [Metrics.seq_cls_metric]
} }
} }
config_path = os.path.join(self.tmp_dir, ModelFile.CONFIGURATION) config_path = os.path.join(self.tmp_dir, ModelFile.CONFIGURATION)
@@ -156,7 +156,7 @@ class TrainerTest(unittest.TestCase):
'workers_per_gpu': 1, 'workers_per_gpu': 1,
'shuffle': False 'shuffle': False
}, },
'metrics': ['seq-cls-metric']
'metrics': [Metrics.seq_cls_metric]
} }
} }


@@ -206,7 +206,7 @@ class TrainerTest(unittest.TestCase):
'workers_per_gpu': 1, 'workers_per_gpu': 1,
'shuffle': False 'shuffle': False
}, },
'metrics': ['seq-cls-metric']
'metrics': [Metrics.seq_cls_metric]
} }
} }




+ 2
- 2
tests/trainers/test_trainer_gpu.py View File

@@ -12,7 +12,7 @@ from torch import nn
from torch.optim import SGD from torch.optim import SGD
from torch.optim.lr_scheduler import StepLR from torch.optim.lr_scheduler import StepLR


from modelscope.metainfo import Trainers
from modelscope.metainfo import Metrics, Trainers
from modelscope.metrics.builder import MetricKeys from modelscope.metrics.builder import MetricKeys
from modelscope.trainers import EpochBasedTrainer, build_trainer from modelscope.trainers import EpochBasedTrainer, build_trainer
from modelscope.utils.constant import LogKeys, ModeKeys, ModelFile from modelscope.utils.constant import LogKeys, ModeKeys, ModelFile
@@ -60,7 +60,7 @@ def train_func(work_dir, dist=False):
'workers_per_gpu': 1, 'workers_per_gpu': 1,
'shuffle': False 'shuffle': False
}, },
'metrics': ['seq_cls_metric']
'metrics': [Metrics.seq_cls_metric]
} }
} }




Loading…
Cancel
Save