diff --git a/.dev_scripts/ci_container_test.sh b/.dev_scripts/ci_container_test.sh
index 98e9f88d..2f18aff7 100644
--- a/.dev_scripts/ci_container_test.sh
+++ b/.dev_scripts/ci_container_test.sh
@@ -19,4 +19,11 @@ fi
 
 # test with install
 python setup.py install
-python tests/run.py
+if [ $# -eq 0 ]; then
+    ci_command="python tests/run.py --subprocess"
+else
+    ci_command="$@"
+fi
+echo "Running case with command: $ci_command"
+$ci_command
+#python tests/run.py --isolated_cases test_text_to_speech.py test_multi_modal_embedding.py test_ofa_tasks.py test_video_summarization.py
diff --git a/.dev_scripts/dockerci.sh b/.dev_scripts/dockerci.sh
index 95dd0e1a..dbb79514 100644
--- a/.dev_scripts/dockerci.sh
+++ b/.dev_scripts/dockerci.sh
@@ -7,7 +7,8 @@ gpus='7 6 5 4 3 2 1 0'
 cpu_sets='0-7 8-15 16-23 24-30 31-37 38-44 45-51 52-58'
 cpu_sets_arr=($cpu_sets)
 is_get_file_lock=false
-CI_COMMAND=${CI_COMMAND:-'bash .dev_scripts/ci_container_test.sh'}
+CI_COMMAND=${CI_COMMAND:-bash .dev_scripts/ci_container_test.sh $RUN_CASE_COMMAND}
+echo "ci command: $CI_COMMAND"
 for gpu in $gpus
 do
   exec {lock_fd}>"/tmp/gpu$gpu" || exit 1
diff --git a/modelscope/msdatasets/ms_dataset.py b/modelscope/msdatasets/ms_dataset.py
index b5527734..338c6333 100644
--- a/modelscope/msdatasets/ms_dataset.py
+++ b/modelscope/msdatasets/ms_dataset.py
@@ -1,9 +1,11 @@
+import math
 import os
 from typing import (Any, Callable, Dict, Iterable, List, Mapping, Optional,
                     Sequence, Union)
 
 import json
 import numpy as np
+import torch
 from datasets import Dataset, DatasetDict
 from datasets import load_dataset as hf_load_dataset
 from datasets.config import TF_AVAILABLE, TORCH_AVAILABLE
@@ -40,6 +42,46 @@ def format_list(para) -> List:
     return para
 
 
+class MsIterableDataset(torch.utils.data.IterableDataset):
+
+    def __init__(self, dataset: Iterable, preprocessor_list, retained_columns,
+                 columns):
+        super(MsIterableDataset).__init__()
+        self.dataset = dataset
+        self.preprocessor_list = preprocessor_list
+        self.retained_columns = retained_columns
+        self.columns = columns
+
+    def __len__(self):
+        return len(self.dataset)
+
+    def __iter__(self):
+        worker_info = torch.utils.data.get_worker_info()
+        if worker_info is None:  # single-process data loading
+            iter_start = 0
+            iter_end = len(self.dataset)
+        else:  # in a worker process
+            per_worker = math.ceil(
+                len(self.dataset) / float(worker_info.num_workers))
+            worker_id = worker_info.id
+            iter_start = worker_id * per_worker
+            iter_end = min(iter_start + per_worker, len(self.dataset))
+
+        for idx in range(iter_start, iter_end):
+            item_dict = self.dataset[idx]
+            res = {
+                k: np.array(item_dict[k])
+                for k in self.columns if k in self.retained_columns
+            }
+            for preprocessor in self.preprocessor_list:
+                res.update({
+                    k: np.array(v)
+                    for k, v in preprocessor(item_dict).items()
+                    if k in self.retained_columns
+                })
+            yield res
+
+
 class MsDataset:
     """
     ModelScope Dataset (aka, MsDataset) is backed by a huggingface Dataset to
@@ -318,45 +360,8 @@ class MsDataset:
                 continue
             retained_columns.append(k)
 
-        import math
-        import torch
-
-        class MsIterableDataset(torch.utils.data.IterableDataset):
-
-            def __init__(self, dataset: Iterable):
-                super(MsIterableDataset).__init__()
-                self.dataset = dataset
-
-            def __len__(self):
-                return len(self.dataset)
-
-            def __iter__(self):
-                worker_info = torch.utils.data.get_worker_info()
-                if worker_info is None:  # single-process data loading
-                    iter_start = 0
-                    iter_end = len(self.dataset)
-                else:  # in a worker process
-                    per_worker = math.ceil(
-                        len(self.dataset) / float(worker_info.num_workers))
-                    worker_id = worker_info.id
-                    iter_start = worker_id * per_worker
-                    iter_end = min(iter_start + per_worker, len(self.dataset))
-
-                for idx in range(iter_start, iter_end):
-                    item_dict = self.dataset[idx]
-                    res = {
-                        k: np.array(item_dict[k])
-                        for k in columns if k in retained_columns
-                    }
-                    for preprocessor in preprocessor_list:
-                        res.update({
-                            k: np.array(v)
-                            for k, v in preprocessor(item_dict).items()
-                            if k in retained_columns
-                        })
-                    yield res
-
-        return MsIterableDataset(self._hf_ds)
+        return MsIterableDataset(self._hf_ds, preprocessor_list,
+                                 retained_columns, columns)
 
     def to_torch_dataset(
             self,
diff --git a/modelscope/pipelines/cv/ocr_utils/ops.py b/modelscope/pipelines/cv/ocr_utils/ops.py
index eeab36a0..09807b10 100644
--- a/modelscope/pipelines/cv/ocr_utils/ops.py
+++ b/modelscope/pipelines/cv/ocr_utils/ops.py
@@ -1,8 +1,10 @@
 import math
 import os
 import shutil
+import sys
 import uuid
 
+import absl.flags as absl_flags
 import cv2
 import numpy as np
 import tensorflow as tf
@@ -12,6 +14,10 @@ from . import utils
 if tf.__version__ >= '2.0':
     tf = tf.compat.v1
 
+# skip parse sys.argv in tf, so fix bug:
+# absl.flags._exceptions.UnrecognizedFlagError:
+# Unknown command line flag 'OCRDetectionPipeline: Unknown command line flag
+absl_flags.FLAGS(sys.argv, known_only=True)
 FLAGS = tf.app.flags.FLAGS
 tf.app.flags.DEFINE_string('weight_init_method', 'xavier',
                            'Weight initialization method')
diff --git a/modelscope/trainers/trainer.py b/modelscope/trainers/trainer.py
index 290478cb..614b728a 100644
--- a/modelscope/trainers/trainer.py
+++ b/modelscope/trainers/trainer.py
@@ -312,7 +312,8 @@ class EpochBasedTrainer(BaseTrainer):
                 else ConfigDict(type=None, mode=mode)
             return datasets.to_torch_dataset(
                 task_data_config=cfg,
-                task_name=self.cfg.task,
+                task_name=self.cfg.task
+                if hasattr(self.cfg, ConfigFields.task) else None,
                 preprocessors=preprocessor)
         elif isinstance(datasets, List) and isinstance(
                 datasets[0], MsDataset):
diff --git a/modelscope/utils/device.py b/modelscope/utils/device.py
index aa8fda66..77e23122 100644
--- a/modelscope/utils/device.py
+++ b/modelscope/utils/device.py
@@ -8,12 +8,6 @@ from modelscope.utils.logger import get_logger
 
 logger = get_logger()
 
-if is_tf_available():
-    import tensorflow as tf
-
-if is_torch_available():
-    import torch
-
 
 def verify_device(device_name):
     """ Verify device is valid, device should be either cpu, cuda, gpu, cuda:X or gpu:X.
@@ -63,6 +57,7 @@ def device_placement(framework, device_name='gpu:0'):
     device_type, device_id = verify_device(device_name)
 
     if framework == Frameworks.tf:
+        import tensorflow as tf
         if device_type == Devices.gpu and not tf.test.is_gpu_available():
             logger.warning(
                 'tensorflow cuda is not available, using cpu instead.')
@@ -76,6 +71,7 @@ def device_placement(framework, device_name='gpu:0'):
             yield
     elif framework == Frameworks.torch:
+        import torch
         if device_type == Devices.gpu:
             if torch.cuda.is_available():
                 torch.cuda.set_device(f'cuda:{device_id}')
@@ -86,12 +82,13 @@ def device_placement(framework, device_name='gpu:0'):
         yield
 
 
-def create_device(device_name) -> torch.DeviceObjType:
+def create_device(device_name):
    """ create torch device
 
     Args:
         device_name (str): cpu, gpu, gpu:0, cuda:0 etc.
""" + import torch device_type, device_id = verify_device(device_name) use_cuda = False if device_type == Devices.gpu: diff --git a/tests/isolated_cases.txt b/tests/isolated_cases.txt new file mode 100644 index 00000000..be85142a --- /dev/null +++ b/tests/isolated_cases.txt @@ -0,0 +1,6 @@ + test_text_to_speech.py + test_multi_modal_embedding.py + test_ofa_tasks.py + test_video_summarization.py + test_dialog_modeling.py + test_csanmt_translation.py diff --git a/tests/pipelines/test_multi_modal_embedding.py b/tests/pipelines/test_multi_modal_embedding.py index 6152f279..f94e31fa 100644 --- a/tests/pipelines/test_multi_modal_embedding.py +++ b/tests/pipelines/test_multi_modal_embedding.py @@ -31,11 +31,10 @@ class MultiModalEmbeddingTest(unittest.TestCase): @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') def test_run_with_model_from_modelhub(self): - model = Model.from_pretrained(self.model_id) + model = Model.from_pretrained( + self.model_id, revision=self.model_version) pipeline_multi_modal_embedding = pipeline( - task=Tasks.multi_modal_embedding, - model=model, - model_revision=self.model_version) + task=Tasks.multi_modal_embedding, model=model) text_embedding = pipeline_multi_modal_embedding( self.test_input)[OutputKeys.TEXT_EMBEDDING] print('l1-norm: {}'.format( diff --git a/tests/run.py b/tests/run.py index 27af7fe5..1a601eda 100644 --- a/tests/run.py +++ b/tests/run.py @@ -2,11 +2,20 @@ # Copyright (c) Alibaba, Inc. and its affiliates. import argparse +import datetime +import multiprocessing import os +import subprocess import sys +import tempfile import unittest from fnmatch import fnmatch +from multiprocessing.managers import BaseManager +from pathlib import Path +from turtle import shape +from unittest import TestResult, TextTestResult +import pandas # NOTICE: Tensorflow 1.15 seems not so compatible with pytorch. # A segmentation fault may be raise by pytorch cpp library # if 'import tensorflow' in front of 'import torch'. 
@@ -19,6 +28,227 @@ from modelscope.utils.test_utils import set_test_level, test_level
 
 logger = get_logger()
 
 
+def test_cases_result_to_df(result_list):
+    table_header = [
+        'Name', 'Result', 'Info', 'Start time', 'Stop time',
+        'Time cost(seconds)'
+    ]
+    df = pandas.DataFrame(
+        result_list, columns=table_header).sort_values(
+            by=['Start time'], ascending=True)
+    return df
+
+
+def statistics_test_result(df):
+    total_cases = df.shape[0]
+    # yapf: disable
+    success_cases = df.loc[df['Result'] == 'Success'].shape[0]
+    error_cases = df.loc[df['Result'] == 'Error'].shape[0]
+    failures_cases = df.loc[df['Result'] == 'Failures'].shape[0]
+    expected_failure_cases = df.loc[df['Result'] == 'ExpectedFailures'].shape[0]
+    unexpected_success_cases = df.loc[df['Result'] == 'UnexpectedSuccesses'].shape[0]
+    skipped_cases = df.loc[df['Result'] == 'Skipped'].shape[0]
+    # yapf: enable
+
+    if failures_cases > 0 or \
+       error_cases > 0 or \
+       unexpected_success_cases > 0:
+        result = 'FAILED'
+    else:
+        result = 'SUCCESS'
+    result_msg = '%s (Runs=%s,success=%s,failures=%s,errors=%s,\
+        skipped=%s,expected failures=%s,unexpected successes=%s)' % (
+        result, total_cases, success_cases, failures_cases, error_cases,
+        skipped_cases, expected_failure_cases, unexpected_success_cases)
+
+    print(result_msg)
+    if result == 'FAILED':
+        sys.exit(1)
+
+
+def gather_test_suites_in_files(test_dir, case_file_list, list_tests):
+    test_suite = unittest.TestSuite()
+    for case in case_file_list:
+        test_case = unittest.defaultTestLoader.discover(
+            start_dir=test_dir, pattern=case)
+        test_suite.addTest(test_case)
+        if hasattr(test_case, '__iter__'):
+            for subcase in test_case:
+                if list_tests:
+                    print(subcase)
+        else:
+            if list_tests:
+                print(test_case)
+    return test_suite
+
+
+def gather_test_suites_files(test_dir, pattern):
+    case_file_list = []
+    for dirpath, dirnames, filenames in os.walk(test_dir):
+        for file in filenames:
+            if fnmatch(file, pattern):
+                case_file_list.append(file)
+    return case_file_list
+
+
+def collect_test_results(case_results):
+    result_list = [
+    ]  # each item is Case, Result, Start time, Stop time, Time cost
+    for case_result in case_results.successes:
+        result_list.append(
+            (case_result.test_full_name, 'Success', '', case_result.start_time,
+             case_result.stop_time, case_result.time_cost))
+    for case_result in case_results.errors:
+        result_list.append(
+            (case_result[0].test_full_name, 'Error', case_result[1],
+             case_result[0].start_time, case_result[0].stop_time,
+             case_result[0].time_cost))
+    for case_result in case_results.skipped:
+        result_list.append(
+            (case_result[0].test_full_name, 'Skipped', case_result[1],
+             case_result[0].start_time, case_result[0].stop_time,
+             case_result[0].time_cost))
+    for case_result in case_results.expectedFailures:
+        result_list.append(
+            (case_result[0].test_full_name, 'ExpectedFailures', case_result[1],
+             case_result[0].start_time, case_result[0].stop_time,
+             case_result[0].time_cost))
+    for case_result in case_results.failures:
+        result_list.append(
+            (case_result[0].test_full_name, 'Failures', case_result[1],
+             case_result[0].start_time, case_result[0].stop_time,
+             case_result[0].time_cost))
+    for case_result in case_results.unexpectedSuccesses:
+        result_list.append((case_result.test_full_name, 'UnexpectedSuccesses',
+                            '', case_result.start_time, case_result.stop_time,
+                            case_result.time_cost))
+    return result_list
+
+
+class TestSuiteRunner:
+
+    def run(self, msg_queue, test_dir, test_suite_file):
+        test_suite = unittest.TestSuite()
+        test_case = unittest.defaultTestLoader.discover(
+            start_dir=test_dir, pattern=test_suite_file)
+        test_suite.addTest(test_case)
+        runner = TimeCostTextTestRunner()
+        test_suite_result = runner.run(test_suite)
+        msg_queue.put(collect_test_results(test_suite_result))
+
+
+def run_command_with_popen(cmd):
+    with subprocess.Popen(
+            cmd,
+            stdout=subprocess.PIPE,
+            stderr=subprocess.STDOUT,
+            bufsize=1,
+            encoding='utf8') as sub_process:
+        for line in iter(sub_process.stdout.readline, ''):
+            sys.stdout.write(line)
+
+
+def run_in_subprocess(args):
+    # only the cases listed in args.isolated_cases run in a separate subprocess, all other cases run in the current process
+    test_suite_files = gather_test_suites_files(
+        os.path.abspath(args.test_dir), args.pattern)
+
+    if args.subprocess:  # run all cases in subprocesses
+        isolated_cases = test_suite_files
+    else:
+        isolated_cases = []
+        with open(args.isolated_cases, 'r') as f:
+            for line in f:
+                if line.strip() in test_suite_files:
+                    isolated_cases.append(line.strip())
+
+    if not args.list_tests:
+        with tempfile.TemporaryDirectory() as temp_result_dir:
+            for test_suite_file in isolated_cases:  # run case in subprocess
+                cmd = [
+                    'python', 'tests/run.py', '--pattern', test_suite_file,
+                    '--result_dir', temp_result_dir
+                ]
+                run_command_with_popen(cmd)
+            result_dfs = []
+            # run remaining cases in the current process.
+            remain_suite_files = [
+                item for item in test_suite_files if item not in isolated_cases
+            ]
+            test_suite = gather_test_suites_in_files(args.test_dir,
+                                                     remain_suite_files,
+                                                     args.list_tests)
+            if test_suite.countTestCases() > 0:
+                runner = TimeCostTextTestRunner()
+                result = runner.run(test_suite)
+                result = collect_test_results(result)
+                df = test_cases_result_to_df(result)
+                result_dfs.append(df)
+
+            # collect test results
+            result_path = Path(temp_result_dir)
+            for result in result_path.iterdir():
+                if Path.is_file(result):
+                    df = pandas.read_pickle(result)
+                    result_dfs.append(df)
+
+            result_pd = pandas.concat(
+                result_dfs)  # merge result of every test suite.
+            print_table_result(result_pd)
+            print_abnormal_case_info(result_pd)
+            statistics_test_result(result_pd)
+
+
+def get_object_full_name(obj):
+    klass = obj.__class__
+    module = klass.__module__
+    if module == 'builtins':
+        return klass.__qualname__
+    return module + '.' + klass.__qualname__
+
+
+class TimeCostTextTestResult(TextTestResult):
+    """Record test case time used!"""
+
+    def __init__(self, stream, descriptions, verbosity):
+        self.successes = []
+        return super(TimeCostTextTestResult,
+                     self).__init__(stream, descriptions, verbosity)
+
+    def startTest(self, test):
+        test.start_time = datetime.datetime.now()
+        test.test_full_name = get_object_full_name(
+            test) + '.' + test._testMethodName
+        self.stream.writeln('Test case: %s start at: %s' %
+                            (test.test_full_name, test.start_time))
+
+        return super(TimeCostTextTestResult, self).startTest(test)
+
+    def stopTest(self, test):
+        TextTestResult.stopTest(self, test)
+        test.stop_time = datetime.datetime.now()
+        test.time_cost = (test.stop_time - test.start_time).total_seconds()
+        self.stream.writeln(
+            'Test case: %s stop at: %s, cost time: %s(seconds)' %
+            (test.test_full_name, test.stop_time, test.time_cost))
+        super(TimeCostTextTestResult, self).stopTest(test)
+
+    def addSuccess(self, test):
+        self.successes.append(test)
+        super(TextTestResult, self).addSuccess(test)
+
+
+class TimeCostTextTestRunner(unittest.runner.TextTestRunner):
+    resultclass = TimeCostTextTestResult
+
+    def run(self, test):
+        return super(TimeCostTextTestRunner, self).run(test)
+
+    def _makeResult(self):
+        result = super(TimeCostTextTestRunner, self)._makeResult()
+        return result
+
+
 def gather_test_cases(test_dir, pattern, list_tests):
     case_list = []
     for dirpath, dirnames, filenames in os.walk(test_dir):
@@ -42,16 +272,40 @@ def gather_test_cases(test_dir, pattern, list_tests):
     return test_suite
 
 
+def print_abnormal_case_info(df):
+    df = df.loc[(df['Result'] == 'Error') | (df['Result'] == 'Failures')]
+    for _, row in df.iterrows():
+        print('Case %s run result: %s, msg:\n%s' %
+              (row['Name'], row['Result'], row['Info']))
+
+
+def print_table_result(df):
+    df = df.loc[df['Result'] != 'Skipped']
+    df = df.drop('Info', axis=1)
+    formatters = {
+        'Name': '{{:<{}s}}'.format(df['Name'].str.len().max()).format,
+        'Result': '{{:<{}s}}'.format(df['Result'].str.len().max()).format,
+    }
+    with pandas.option_context('display.max_rows', None, 'display.max_columns',
+                               None, 'display.width', None):
+        print(df.to_string(justify='left', formatters=formatters, index=False))
+
+
 def main(args):
-    runner = unittest.TextTestRunner()
+    runner = TimeCostTextTestRunner()
     test_suite = gather_test_cases(
         os.path.abspath(args.test_dir), args.pattern, args.list_tests)
     if not args.list_tests:
         result = runner.run(test_suite)
-        if len(result.failures) > 0:
-            sys.exit(len(result.failures))
-        if len(result.errors) > 0:
-            sys.exit(len(result.errors))
+        result = collect_test_results(result)
+        df = test_cases_result_to_df(result)
+        if args.result_dir is not None:
+            file_name = str(int(datetime.datetime.now().timestamp() * 1000))
+            df.to_pickle(os.path.join(args.result_dir, file_name))
+        else:
+            print_table_result(df)
+            print_abnormal_case_info(df)
+            statistics_test_result(df)
 
 
 if __name__ == '__main__':
@@ -66,6 +320,18 @@
     parser.add_argument(
         '--level', default=0, type=int, help='2 -- all, 1 -- p1, 0 -- p0')
     parser.add_argument(
        '--disable_profile', action='store_true', help='disable profiling')
+    parser.add_argument(
+        '--isolated_cases',
+        default=None,
+        help='specified isolated cases config file')
+    parser.add_argument(
+        '--subprocess',
+        action='store_true',
+        help='run each test suite in a separate subprocess')
+    parser.add_argument(
+        '--result_dir',
+        default=None,
+        help='Save result to directory, internal use only')
     args = parser.parse_args()
     set_test_level(args.level)
     logger.info(f'TEST LEVEL: {test_level()}')
@@ -73,4 +339,10 @@
         from utils import profiler
         logger.info('enable profile ...')
         profiler.enable()
-    main(args)
+    if args.isolated_cases is not None and args.subprocess:
+        print('isolated_cases and subprocess conflict')
+        sys.exit(1)
+    elif args.isolated_cases is not None or args.subprocess:
+        run_in_subprocess(args)
+    else:
+        main(args)
diff --git a/tests/trainers/test_image_color_enhance_trainer.py b/tests/trainers/test_image_color_enhance_trainer.py
index f1dcbe51..34d84cd2 100644
--- a/tests/trainers/test_image_color_enhance_trainer.py
+++ b/tests/trainers/test_image_color_enhance_trainer.py
@@ -17,6 +17,41 @@ from modelscope.utils.constant import ModelFile
 from modelscope.utils.test_utils import test_level
 
 
+class PairedImageDataset(data.Dataset):
+
+    def __init__(self, root):
+        super(PairedImageDataset, self).__init__()
+        gt_dir = osp.join(root, 'gt')
+        lq_dir = osp.join(root, 'lq')
+        self.gt_filelist = os.listdir(gt_dir)
+        self.gt_filelist = sorted(self.gt_filelist, key=lambda x: int(x[:-4]))
+        self.gt_filelist = [osp.join(gt_dir, f) for f in self.gt_filelist]
+        self.lq_filelist = os.listdir(lq_dir)
+        self.lq_filelist = sorted(self.lq_filelist, key=lambda x: int(x[:-4]))
+        self.lq_filelist = [osp.join(lq_dir, f) for f in self.lq_filelist]
+
+    def _img_to_tensor(self, img):
+        return torch.from_numpy(img[:, :, [2, 1, 0]]).permute(2, 0, 1).type(
+            torch.float32) / 255.
+
+    def __getitem__(self, index):
+        lq = cv2.imread(self.lq_filelist[index])
+        gt = cv2.imread(self.gt_filelist[index])
+        lq = cv2.resize(lq, (256, 256), interpolation=cv2.INTER_CUBIC)
+        gt = cv2.resize(gt, (256, 256), interpolation=cv2.INTER_CUBIC)
+        return \
+            {'src': self._img_to_tensor(lq), 'target': self._img_to_tensor(gt)}
+
+    def __len__(self):
+        return len(self.gt_filelist)
+
+    def to_torch_dataset(self,
+                         columns: Union[str, List[str]] = None,
+                         preprocessors: Union[Callable, List[Callable]] = None,
+                         **format_kwargs):
+        return self
+
+
 class TestImageColorEnhanceTrainer(unittest.TestCase):
 
     def setUp(self):
@@ -27,47 +62,6 @@ class TestImageColorEnhanceTrainer(unittest.TestCase):
 
         self.model_id = 'damo/cv_csrnet_image-color-enhance-models'
 
-        class PairedImageDataset(data.Dataset):
-
-            def __init__(self, root):
-                super(PairedImageDataset, self).__init__()
-                gt_dir = osp.join(root, 'gt')
-                lq_dir = osp.join(root, 'lq')
-                self.gt_filelist = os.listdir(gt_dir)
-                self.gt_filelist = sorted(
-                    self.gt_filelist, key=lambda x: int(x[:-4]))
-                self.gt_filelist = [
-                    osp.join(gt_dir, f) for f in self.gt_filelist
-                ]
-                self.lq_filelist = os.listdir(lq_dir)
-                self.lq_filelist = sorted(
-                    self.lq_filelist, key=lambda x: int(x[:-4]))
-                self.lq_filelist = [
-                    osp.join(lq_dir, f) for f in self.lq_filelist
-                ]
-
-            def _img_to_tensor(self, img):
-                return torch.from_numpy(img[:, :, [2, 1, 0]]).permute(
-                    2, 0, 1).type(torch.float32) / 255.
-
-            def __getitem__(self, index):
-                lq = cv2.imread(self.lq_filelist[index])
-                gt = cv2.imread(self.gt_filelist[index])
-                lq = cv2.resize(lq, (256, 256), interpolation=cv2.INTER_CUBIC)
-                gt = cv2.resize(gt, (256, 256), interpolation=cv2.INTER_CUBIC)
-                return \
-                    {'src': self._img_to_tensor(lq), 'target': self._img_to_tensor(gt)}
-
-            def __len__(self):
-                return len(self.gt_filelist)
-
-            def to_torch_dataset(self,
-                                 columns: Union[str, List[str]] = None,
-                                 preprocessors: Union[Callable,
-                                                      List[Callable]] = None,
-                                 **format_kwargs):
-                return self
-
         self.dataset = PairedImageDataset(
             './data/test/images/image_color_enhance/')
diff --git a/tests/trainers/test_image_portrait_enhancement_trainer.py b/tests/trainers/test_image_portrait_enhancement_trainer.py
index dc450ff0..049adf7e 100644
--- a/tests/trainers/test_image_portrait_enhancement_trainer.py
+++ b/tests/trainers/test_image_portrait_enhancement_trainer.py
@@ -19,6 +19,47 @@ from modelscope.utils.constant import ModelFile
 from modelscope.utils.test_utils import test_level
 
 
+class PairedImageDataset(data.Dataset):
+
+    def __init__(self, root, size=512):
+        super(PairedImageDataset, self).__init__()
+        self.size = size
+        gt_dir = osp.join(root, 'gt')
+        lq_dir = osp.join(root, 'lq')
+        self.gt_filelist = os.listdir(gt_dir)
+        self.gt_filelist = sorted(self.gt_filelist, key=lambda x: int(x[:-4]))
+        self.gt_filelist = [osp.join(gt_dir, f) for f in self.gt_filelist]
+        self.lq_filelist = os.listdir(lq_dir)
+        self.lq_filelist = sorted(self.lq_filelist, key=lambda x: int(x[:-4]))
+        self.lq_filelist = [osp.join(lq_dir, f) for f in self.lq_filelist]
+
+    def _img_to_tensor(self, img):
+        img = torch.from_numpy(img[:, :, [2, 1, 0]]).permute(2, 0, 1).type(
+            torch.float32) / 255.
+        return (img - 0.5) / 0.5
+
+    def __getitem__(self, index):
+        lq = cv2.imread(self.lq_filelist[index])
+        gt = cv2.imread(self.gt_filelist[index])
+        lq = cv2.resize(
+            lq, (self.size, self.size), interpolation=cv2.INTER_CUBIC)
+        gt = cv2.resize(
+            gt, (self.size, self.size), interpolation=cv2.INTER_CUBIC)
+
+        return \
+            {'src': self._img_to_tensor(lq), 'target': self._img_to_tensor(gt)}
+
+    def __len__(self):
+        return len(self.gt_filelist)
+
+    def to_torch_dataset(self,
+                         columns: Union[str, List[str]] = None,
+                         preprocessors: Union[Callable, List[Callable]] = None,
+                         **format_kwargs):
+        # self.preprocessor = preprocessors
+        return self
+
+
 class TestImagePortraitEnhancementTrainer(unittest.TestCase):
 
     def setUp(self):
@@ -29,53 +70,6 @@ class TestImagePortraitEnhancementTrainer(unittest.TestCase):
 
         self.model_id = 'damo/cv_gpen_image-portrait-enhancement'
 
-        class PairedImageDataset(data.Dataset):
-
-            def __init__(self, root, size=512):
-                super(PairedImageDataset, self).__init__()
-                self.size = size
-                gt_dir = osp.join(root, 'gt')
-                lq_dir = osp.join(root, 'lq')
-                self.gt_filelist = os.listdir(gt_dir)
-                self.gt_filelist = sorted(
-                    self.gt_filelist, key=lambda x: int(x[:-4]))
-                self.gt_filelist = [
-                    osp.join(gt_dir, f) for f in self.gt_filelist
-                ]
-                self.lq_filelist = os.listdir(lq_dir)
-                self.lq_filelist = sorted(
-                    self.lq_filelist, key=lambda x: int(x[:-4]))
-                self.lq_filelist = [
-                    osp.join(lq_dir, f) for f in self.lq_filelist
-                ]
-
-            def _img_to_tensor(self, img):
-                img = torch.from_numpy(img[:, :, [2, 1, 0]]).permute(
-                    2, 0, 1).type(torch.float32) / 255.
-                return (img - 0.5) / 0.5
-
-            def __getitem__(self, index):
-                lq = cv2.imread(self.lq_filelist[index])
-                gt = cv2.imread(self.gt_filelist[index])
-                lq = cv2.resize(
-                    lq, (self.size, self.size), interpolation=cv2.INTER_CUBIC)
-                gt = cv2.resize(
-                    gt, (self.size, self.size), interpolation=cv2.INTER_CUBIC)
-
-                return \
-                    {'src': self._img_to_tensor(lq), 'target': self._img_to_tensor(gt)}
-
-            def __len__(self):
-                return len(self.gt_filelist)
-
-            def to_torch_dataset(self,
-                                 columns: Union[str, List[str]] = None,
-                                 preprocessors: Union[Callable,
-                                                      List[Callable]] = None,
-                                 **format_kwargs):
-                # self.preprocessor = preprocessors
-                return self
-
         self.dataset = PairedImageDataset(
             './data/test/images/face_enhancement/')
diff --git a/tests/trainers/test_trainer.py b/tests/trainers/test_trainer.py
index be29844d..17fa97f9 100644
--- a/tests/trainers/test_trainer.py
+++ b/tests/trainers/test_trainer.py
@@ -16,6 +16,7 @@ from modelscope.metainfo import Metrics, Trainers
 from modelscope.metrics.builder import MetricKeys
 from modelscope.models.base import Model
 from modelscope.trainers import build_trainer
+from modelscope.trainers.base import DummyTrainer
 from modelscope.utils.constant import LogKeys, ModeKeys, ModelFile
 from modelscope.utils.test_utils import create_dummy_test_dataset, test_level
 
@@ -264,7 +265,7 @@ class TrainerTest(unittest.TestCase):
             {
                 LogKeys.MODE: ModeKeys.EVAL,
                 LogKeys.EPOCH: 1,
-                LogKeys.ITER: 20
+                LogKeys.ITER: 10
             }, json.loads(lines[2]))
         self.assertDictContainsSubset(
             {
@@ -284,7 +285,7 @@ class TrainerTest(unittest.TestCase):
             {
                 LogKeys.MODE: ModeKeys.EVAL,
                 LogKeys.EPOCH: 2,
-                LogKeys.ITER: 20
+                LogKeys.ITER: 10
             }, json.loads(lines[5]))
         self.assertDictContainsSubset(
             {
@@ -304,7 +305,7 @@ class TrainerTest(unittest.TestCase):
             {
                 LogKeys.MODE: ModeKeys.EVAL,
                 LogKeys.EPOCH: 3,
-                LogKeys.ITER: 20
+                LogKeys.ITER: 10
             }, json.loads(lines[8]))
         self.assertIn(f'{LogKeys.EPOCH}_1.pth', results_files)
         self.assertIn(f'{LogKeys.EPOCH}_2.pth', results_files)
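Usage sketch (illustrative, not part of the patch). The commands below show how the new test-runner options added in tests/run.py and the CI scripts above would typically be invoked; the paths come from this diff, and the last example assumes dockerci.sh is launched from the repository root so that RUN_CASE_COMMAND is forwarded to ci_container_test.sh as shown in the diff.

    # Run every test suite file in its own subprocess
    # (the default ci_container_test.sh falls back to when it gets no arguments):
    python tests/run.py --subprocess

    # Run only the suites listed in tests/isolated_cases.txt in separate
    # subprocesses; the remaining suites run in the current process:
    python tests/run.py --isolated_cases tests/isolated_cases.txt

    # Override the test command used by the docker CI wrapper:
    RUN_CASE_COMMAND='python tests/run.py --subprocess' bash .dev_scripts/dockerci.sh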