From 1a22fa02228f0884bcb48bdaccc4f90a24c85009 Mon Sep 17 00:00:00 2001 From: "jiangnana.jnn" Date: Fri, 2 Sep 2022 14:06:08 +0800 Subject: [PATCH] fix trainer unittest Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9970626 * fix trainer unittest --- tests/trainers/test_trainer.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/tests/trainers/test_trainer.py b/tests/trainers/test_trainer.py index 17fa97f9..86909f74 100644 --- a/tests/trainers/test_trainer.py +++ b/tests/trainers/test_trainer.py @@ -17,7 +17,7 @@ from modelscope.metrics.builder import MetricKeys from modelscope.models.base import Model from modelscope.trainers import build_trainer from modelscope.trainers.base import DummyTrainer -from modelscope.utils.constant import LogKeys, ModeKeys, ModelFile +from modelscope.utils.constant import LogKeys, ModeKeys, ModelFile, Tasks from modelscope.utils.test_utils import create_dummy_test_dataset, test_level @@ -67,6 +67,7 @@ class TrainerTest(unittest.TestCase): @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') def test_train_0(self): json_cfg = { + 'task': Tasks.image_classification, 'train': { 'work_dir': self.tmp_dir, @@ -141,6 +142,7 @@ class TrainerTest(unittest.TestCase): @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') def test_train_1(self): json_cfg = { + 'task': Tasks.image_classification, 'train': { 'work_dir': self.tmp_dir, @@ -201,6 +203,7 @@ class TrainerTest(unittest.TestCase): @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') def test_train_with_default_config(self): json_cfg = { + 'task': Tasks.image_classification, 'train': { 'work_dir': self.tmp_dir, 'dataloader': { @@ -319,6 +322,7 @@ class TrainerTest(unittest.TestCase): @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') def test_train_with_iters_per_epoch(self): json_cfg = { + 'task': Tasks.image_classification, 'train': { 'work_dir': self.tmp_dir, 'dataloader': {