| @@ -19,4 +19,11 @@ fi | |||
| # test with install | |||
| python setup.py install | |||
| python tests/run.py | |||
| if [ $# -eq 0 ]; then | |||
| ci_command="python tests/run.py --subprocess" | |||
| else | |||
| ci_command="$@" | |||
| fi | |||
| echo "Running case with command: $ci_command" | |||
| $ci_command | |||
| #python tests/run.py --isolated_cases test_text_to_speech.py test_multi_modal_embedding.py test_ofa_tasks.py test_video_summarization.py | |||
| @@ -7,7 +7,8 @@ gpus='7 6 5 4 3 2 1 0' | |||
| cpu_sets='0-7 8-15 16-23 24-30 31-37 38-44 45-51 52-58' | |||
| cpu_sets_arr=($cpu_sets) | |||
| is_get_file_lock=false | |||
| CI_COMMAND=${CI_COMMAND:-'bash .dev_scripts/ci_container_test.sh'} | |||
| CI_COMMAND=${CI_COMMAND:-bash .dev_scripts/ci_container_test.sh $RUN_CASE_COMMAND} | |||
| echo "ci command: $CI_COMMAND" | |||
| for gpu in $gpus | |||
| do | |||
| exec {lock_fd}>"/tmp/gpu$gpu" || exit 1 | |||
| @@ -1,9 +1,11 @@ | |||
| import math | |||
| import os | |||
| from typing import (Any, Callable, Dict, Iterable, List, Mapping, Optional, | |||
| Sequence, Union) | |||
| import json | |||
| import numpy as np | |||
| import torch | |||
| from datasets import Dataset, DatasetDict | |||
| from datasets import load_dataset as hf_load_dataset | |||
| from datasets.config import TF_AVAILABLE, TORCH_AVAILABLE | |||
| @@ -40,6 +42,46 @@ def format_list(para) -> List: | |||
| return para | |||
| class MsIterableDataset(torch.utils.data.IterableDataset): | |||
| def __init__(self, dataset: Iterable, preprocessor_list, retained_columns, | |||
| columns): | |||
| super().__init__() | |||
| self.dataset = dataset | |||
| self.preprocessor_list = preprocessor_list | |||
| self.retained_columns = retained_columns | |||
| self.columns = columns | |||
| def __len__(self): | |||
| return len(self.dataset) | |||
| def __iter__(self): | |||
| worker_info = torch.utils.data.get_worker_info() | |||
| if worker_info is None: # single-process data loading | |||
| iter_start = 0 | |||
| iter_end = len(self.dataset) | |||
| else: # in a worker process | |||
| per_worker = math.ceil( | |||
| len(self.dataset) / float(worker_info.num_workers)) | |||
| worker_id = worker_info.id | |||
| iter_start = worker_id * per_worker | |||
| iter_end = min(iter_start + per_worker, len(self.dataset)) | |||
| for idx in range(iter_start, iter_end): | |||
| item_dict = self.dataset[idx] | |||
| res = { | |||
| k: np.array(item_dict[k]) | |||
| for k in self.columns if k in self.retained_columns | |||
| } | |||
| for preprocessor in self.preprocessor_list: | |||
| res.update({ | |||
| k: np.array(v) | |||
| for k, v in preprocessor(item_dict).items() | |||
| if k in self.retained_columns | |||
| }) | |||
| yield res | |||
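| # Illustrative sketch, not part of the patch: the split in __iter__ above gives | |||
| # each DataLoader worker a contiguous, non-overlapping index range; the helper | |||
| # name _worker_shards below is hypothetical and only mirrors that arithmetic. | |||
| def _worker_shards(num_samples, num_workers): | |||
|     per_worker = math.ceil(num_samples / float(num_workers)) | |||
|     return [(w * per_worker, min((w + 1) * per_worker, num_samples)) | |||
|             for w in range(num_workers)] | |||
| # e.g. _worker_shards(10, 3) == [(0, 4), (4, 8), (8, 10)], so every sample is | |||
| # visited exactly once across the three workers. | |||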
| class MsDataset: | |||
| """ | |||
| ModelScope Dataset (aka, MsDataset) is backed by a huggingface Dataset to | |||
| @@ -318,45 +360,8 @@ class MsDataset: | |||
| continue | |||
| retained_columns.append(k) | |||
| import math | |||
| import torch | |||
| class MsIterableDataset(torch.utils.data.IterableDataset): | |||
| def __init__(self, dataset: Iterable): | |||
| super(MsIterableDataset).__init__() | |||
| self.dataset = dataset | |||
| def __len__(self): | |||
| return len(self.dataset) | |||
| def __iter__(self): | |||
| worker_info = torch.utils.data.get_worker_info() | |||
| if worker_info is None: # single-process data loading | |||
| iter_start = 0 | |||
| iter_end = len(self.dataset) | |||
| else: # in a worker process | |||
| per_worker = math.ceil( | |||
| len(self.dataset) / float(worker_info.num_workers)) | |||
| worker_id = worker_info.id | |||
| iter_start = worker_id * per_worker | |||
| iter_end = min(iter_start + per_worker, len(self.dataset)) | |||
| for idx in range(iter_start, iter_end): | |||
| item_dict = self.dataset[idx] | |||
| res = { | |||
| k: np.array(item_dict[k]) | |||
| for k in columns if k in retained_columns | |||
| } | |||
| for preprocessor in preprocessor_list: | |||
| res.update({ | |||
| k: np.array(v) | |||
| for k, v in preprocessor(item_dict).items() | |||
| if k in retained_columns | |||
| }) | |||
| yield res | |||
| return MsIterableDataset(self._hf_ds) | |||
| return MsIterableDataset(self._hf_ds, preprocessor_list, | |||
| retained_columns, columns) | |||
| def to_torch_dataset( | |||
| self, | |||
| @@ -1,8 +1,10 @@ | |||
| import math | |||
| import os | |||
| import shutil | |||
| import sys | |||
| import uuid | |||
| import absl.flags as absl_flags | |||
| import cv2 | |||
| import numpy as np | |||
| import tensorflow as tf | |||
| @@ -12,6 +14,10 @@ from . import utils | |||
| if tf.__version__ >= '2.0': | |||
| tf = tf.compat.v1 | |||
| # Skip parsing sys.argv in tf to avoid the bug: | |||
| # absl.flags._exceptions.UnrecognizedFlagError: | |||
| # Unknown command line flag 'OCRDetectionPipeline' | |||
| absl_flags.FLAGS(sys.argv, known_only=True) | |||
| FLAGS = tf.app.flags.FLAGS | |||
| tf.app.flags.DEFINE_string('weight_init_method', 'xavier', | |||
| 'Weight initialization method') | |||
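| # Illustrative note, not part of the patch: with known_only=True absl parses only | |||
| # the flags it has registered and returns the leftover argv entries instead of | |||
| # raising UnrecognizedFlagError, e.g. the following call does not raise even | |||
| # though the flag is undefined (the flag name here is made up): | |||
| #   absl_flags.FLAGS(['prog', '--not_a_registered_flag=1'], known_only=True) | |||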
| @@ -312,7 +312,8 @@ class EpochBasedTrainer(BaseTrainer): | |||
| else ConfigDict(type=None, mode=mode) | |||
| return datasets.to_torch_dataset( | |||
| task_data_config=cfg, | |||
| task_name=self.cfg.task, | |||
| task_name=self.cfg.task | |||
| if hasattr(self.cfg, ConfigFields.task) else None, | |||
| preprocessors=preprocessor) | |||
| elif isinstance(datasets, List) and isinstance( | |||
| datasets[0], MsDataset): | |||
| @@ -8,12 +8,6 @@ from modelscope.utils.logger import get_logger | |||
| logger = get_logger() | |||
| if is_tf_available(): | |||
| import tensorflow as tf | |||
| if is_torch_available(): | |||
| import torch | |||
| def verify_device(device_name): | |||
| """ Verify device is valid, device should be either cpu, cuda, gpu, cuda:X or gpu:X. | |||
| @@ -63,6 +57,7 @@ def device_placement(framework, device_name='gpu:0'): | |||
| device_type, device_id = verify_device(device_name) | |||
| if framework == Frameworks.tf: | |||
| import tensorflow as tf | |||
| if device_type == Devices.gpu and not tf.test.is_gpu_available(): | |||
| logger.warning( | |||
| 'tensorflow cuda is not available, using cpu instead.') | |||
| @@ -76,6 +71,7 @@ def device_placement(framework, device_name='gpu:0'): | |||
| yield | |||
| elif framework == Frameworks.torch: | |||
| import torch | |||
| if device_type == Devices.gpu: | |||
| if torch.cuda.is_available(): | |||
| torch.cuda.set_device(f'cuda:{device_id}') | |||
| @@ -86,12 +82,13 @@ def device_placement(framework, device_name='gpu:0'): | |||
| yield | |||
| def create_device(device_name) -> torch.DeviceObjType: | |||
| def create_device(device_name): | |||
| """ create torch device | |||
| Args: | |||
| device_name (str): cpu, gpu, gpu:0, cuda:0 etc. | |||
| """ | |||
| import torch | |||
| device_type, device_id = verify_device(device_name) | |||
| use_cuda = False | |||
| if device_type == Devices.gpu: | |||
| @@ -0,0 +1,6 @@ | |||
| test_text_to_speech.py | |||
| test_multi_modal_embedding.py | |||
| test_ofa_tasks.py | |||
| test_video_summarization.py | |||
| test_dialog_modeling.py | |||
| test_csanmt_translation.py | |||
| @@ -31,11 +31,10 @@ class MultiModalEmbeddingTest(unittest.TestCase): | |||
| @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') | |||
| def test_run_with_model_from_modelhub(self): | |||
| model = Model.from_pretrained(self.model_id) | |||
| model = Model.from_pretrained( | |||
| self.model_id, revision=self.model_version) | |||
| pipeline_multi_modal_embedding = pipeline( | |||
| task=Tasks.multi_modal_embedding, | |||
| model=model, | |||
| model_revision=self.model_version) | |||
| task=Tasks.multi_modal_embedding, model=model) | |||
| text_embedding = pipeline_multi_modal_embedding( | |||
| self.test_input)[OutputKeys.TEXT_EMBEDDING] | |||
| print('l1-norm: {}'.format( | |||
| @@ -2,11 +2,20 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| import argparse | |||
| import datetime | |||
| import multiprocessing | |||
| import os | |||
| import subprocess | |||
| import sys | |||
| import tempfile | |||
| import unittest | |||
| from fnmatch import fnmatch | |||
| from multiprocessing.managers import BaseManager | |||
| from pathlib import Path | |||
| from unittest import TestResult, TextTestResult | |||
| import pandas | |||
| # NOTICE: Tensorflow 1.15 does not seem fully compatible with pytorch. | |||
| # A segmentation fault may be raised by the pytorch cpp library | |||
| # if 'import tensorflow' comes before 'import torch'. | |||
| @@ -19,6 +28,227 @@ from modelscope.utils.test_utils import set_test_level, test_level | |||
| logger = get_logger() | |||
| def test_cases_result_to_df(result_list): | |||
| table_header = [ | |||
| 'Name', 'Result', 'Info', 'Start time', 'Stop time', | |||
| 'Time cost(seconds)' | |||
| ] | |||
| df = pandas.DataFrame( | |||
| result_list, columns=table_header).sort_values( | |||
| by=['Start time'], ascending=True) | |||
| return df | |||
| def statistics_test_result(df): | |||
| total_cases = df.shape[0] | |||
| # yapf: disable | |||
| success_cases = df.loc[df['Result'] == 'Success'].shape[0] | |||
| error_cases = df.loc[df['Result'] == 'Error'].shape[0] | |||
| failures_cases = df.loc[df['Result'] == 'Failures'].shape[0] | |||
| expected_failure_cases = df.loc[df['Result'] == 'ExpectedFailures'].shape[0] | |||
| unexpected_success_cases = df.loc[df['Result'] == 'UnexpectedSuccesses'].shape[0] | |||
| skipped_cases = df.loc[df['Result'] == 'Skipped'].shape[0] | |||
| # yapf: enable | |||
| if failures_cases > 0 or \ | |||
| error_cases > 0 or \ | |||
| unexpected_success_cases > 0: | |||
| result = 'FAILED' | |||
| else: | |||
| result = 'SUCCESS' | |||
| result_msg = '%s (Runs=%s,success=%s,failures=%s,errors=%s,\ | |||
| skipped=%s,expected failures=%s,unexpected successes=%s)' % ( | |||
| result, total_cases, success_cases, failures_cases, error_cases, | |||
| skipped_cases, expected_failure_cases, unexpected_success_cases) | |||
| print(result_msg) | |||
| if result == 'FAILED': | |||
| sys.exit(1) | |||
| def gather_test_suites_in_files(test_dir, case_file_list, list_tests): | |||
| test_suite = unittest.TestSuite() | |||
| for case in case_file_list: | |||
| test_case = unittest.defaultTestLoader.discover( | |||
| start_dir=test_dir, pattern=case) | |||
| test_suite.addTest(test_case) | |||
| if hasattr(test_case, '__iter__'): | |||
| for subcase in test_case: | |||
| if list_tests: | |||
| print(subcase) | |||
| else: | |||
| if list_tests: | |||
| print(test_case) | |||
| return test_suite | |||
| def gather_test_suites_files(test_dir, pattern): | |||
| case_file_list = [] | |||
| for dirpath, dirnames, filenames in os.walk(test_dir): | |||
| for file in filenames: | |||
| if fnmatch(file, pattern): | |||
| case_file_list.append(file) | |||
| return case_file_list | |||
| def collect_test_results(case_results): | |||
| result_list = [ | |||
| ] # each item is Case, Result, Start time, Stop time, Time cost | |||
| for case_result in case_results.successes: | |||
| result_list.append( | |||
| (case_result.test_full_name, 'Success', '', case_result.start_time, | |||
| case_result.stop_time, case_result.time_cost)) | |||
| for case_result in case_results.errors: | |||
| result_list.append( | |||
| (case_result[0].test_full_name, 'Error', case_result[1], | |||
| case_result[0].start_time, case_result[0].stop_time, | |||
| case_result[0].time_cost)) | |||
| for case_result in case_results.skipped: | |||
| result_list.append( | |||
| (case_result[0].test_full_name, 'Skipped', case_result[1], | |||
| case_result[0].start_time, case_result[0].stop_time, | |||
| case_result[0].time_cost)) | |||
| for case_result in case_results.expectedFailures: | |||
| result_list.append( | |||
| (case_result[0].test_full_name, 'ExpectedFailures', case_result[1], | |||
| case_result[0].start_time, case_result[0].stop_time, | |||
| case_result[0].time_cost)) | |||
| for case_result in case_results.failures: | |||
| result_list.append( | |||
| (case_result[0].test_full_name, 'Failures', case_result[1], | |||
| case_result[0].start_time, case_result[0].stop_time, | |||
| case_result[0].time_cost)) | |||
| for case_result in case_results.unexpectedSuccesses: | |||
| result_list.append((case_result.test_full_name, 'UnexpectedSuccesses', | |||
| '', case_result.start_time, case_result.stop_time, | |||
| case_result.time_cost)) | |||
| return result_list | |||
| class TestSuiteRunner: | |||
| def run(self, msg_queue, test_dir, test_suite_file): | |||
| test_suite = unittest.TestSuite() | |||
| test_case = unittest.defaultTestLoader.discover( | |||
| start_dir=test_dir, pattern=test_suite_file) | |||
| test_suite.addTest(test_case) | |||
| runner = TimeCostTextTestRunner() | |||
| test_suite_result = runner.run(test_suite) | |||
| msg_queue.put(collect_test_results(test_suite_result)) | |||
| def run_command_with_popen(cmd): | |||
| with subprocess.Popen( | |||
| cmd, | |||
| stdout=subprocess.PIPE, | |||
| stderr=subprocess.STDOUT, | |||
| bufsize=1, | |||
| encoding='utf8') as sub_process: | |||
| for line in iter(sub_process.stdout.readline, ''): | |||
| sys.stdout.write(line) | |||
| def run_in_subprocess(args): | |||
| # Cases listed in args.isolated_cases (or all cases when --subprocess is set) | |||
| # each run in a separate subprocess; the remaining cases run together in the | |||
| # current process. | |||
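| # Illustrative invocations, not part of the patch: | |||
| #   python tests/run.py --subprocess            -> run every suite in its own subprocess | |||
| #   python tests/run.py --isolated_cases <file> -> isolate only the suites listed in <file>, | |||
| #                                                  run the rest together in this process | |||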
| test_suite_files = gather_test_suites_files( | |||
| os.path.abspath(args.test_dir), args.pattern) | |||
| if args.subprocess: # run all case in subprocess | |||
| isolated_cases = test_suite_files | |||
| else: | |||
| isolated_cases = [] | |||
| with open(args.isolated_cases, 'r') as f: | |||
| for line in f: | |||
| if line.strip() in test_suite_files: | |||
| isolated_cases.append(line.strip()) | |||
| if not args.list_tests: | |||
| with tempfile.TemporaryDirectory() as temp_result_dir: | |||
| for test_suite_file in isolated_cases: # run case in subprocess | |||
| cmd = [ | |||
| 'python', 'tests/run.py', '--pattern', test_suite_file, | |||
| '--result_dir', temp_result_dir | |||
| ] | |||
| run_command_with_popen(cmd) | |||
| result_dfs = [] | |||
| # run the remaining cases in the current process. | |||
| remain_suite_files = [ | |||
| item for item in test_suite_files if item not in isolated_cases | |||
| ] | |||
| test_suite = gather_test_suites_in_files(args.test_dir, | |||
| remain_suite_files, | |||
| args.list_tests) | |||
| if test_suite.countTestCases() > 0: | |||
| runner = TimeCostTextTestRunner() | |||
| result = runner.run(test_suite) | |||
| result = collect_test_results(result) | |||
| df = test_cases_result_to_df(result) | |||
| result_dfs.append(df) | |||
| # collect test results | |||
| result_path = Path(temp_result_dir) | |||
| for result in result_path.iterdir(): | |||
| if Path.is_file(result): | |||
| df = pandas.read_pickle(result) | |||
| result_dfs.append(df) | |||
| result_pd = pandas.concat( | |||
| result_dfs) # merge result of every test suite. | |||
| print_table_result(result_pd) | |||
| print_abnormal_case_info(result_pd) | |||
| statistics_test_result(result_pd) | |||
| def get_object_full_name(obj): | |||
| klass = obj.__class__ | |||
| module = klass.__module__ | |||
| if module == 'builtins': | |||
| return klass.__qualname__ | |||
| return module + '.' + klass.__qualname__ | |||
| class TimeCostTextTestResult(TextTestResult): | |||
| """Record test case time used!""" | |||
| def __init__(self, stream, descriptions, verbosity): | |||
| self.successes = [] | |||
| return super(TimeCostTextTestResult, | |||
| self).__init__(stream, descriptions, verbosity) | |||
| def startTest(self, test): | |||
| test.start_time = datetime.datetime.now() | |||
| test.test_full_name = get_object_full_name( | |||
| test) + '.' + test._testMethodName | |||
| self.stream.writeln('Test case: %s start at: %s' % | |||
| (test.test_full_name, test.start_time)) | |||
| return super(TimeCostTextTestResult, self).startTest(test) | |||
| def stopTest(self, test): | |||
| TextTestResult.stopTest(self, test) | |||
| test.stop_time = datetime.datetime.now() | |||
| test.time_cost = (test.stop_time - test.start_time).total_seconds() | |||
| self.stream.writeln( | |||
| 'Test case: %s stop at: %s, cost time: %s(seconds)' % | |||
| (test.test_full_name, test.stop_time, test.time_cost)) | |||
| super(TimeCostTextTestResult, self).stopTest(test) | |||
| def addSuccess(self, test): | |||
| self.successes.append(test) | |||
| super(TextTestResult, self).addSuccess(test) | |||
| class TimeCostTextTestRunner(unittest.runner.TextTestRunner): | |||
| resultclass = TimeCostTextTestResult | |||
| def run(self, test): | |||
| return super(TimeCostTextTestRunner, self).run(test) | |||
| def _makeResult(self): | |||
| result = super(TimeCostTextTestRunner, self)._makeResult() | |||
| return result | |||
| def gather_test_cases(test_dir, pattern, list_tests): | |||
| case_list = [] | |||
| for dirpath, dirnames, filenames in os.walk(test_dir): | |||
| @@ -42,16 +272,40 @@ def gather_test_cases(test_dir, pattern, list_tests): | |||
| return test_suite | |||
| def print_abnormal_case_info(df): | |||
| df = df.loc[(df['Result'] == 'Error') | (df['Result'] == 'Failures')] | |||
| for _, row in df.iterrows(): | |||
| print('Case %s run result: %s, msg:\n%s' % | |||
| (row['Name'], row['Result'], row['Info'])) | |||
| def print_table_result(df): | |||
| df = df.loc[df['Result'] != 'Skipped'] | |||
| df = df.drop('Info', axis=1) | |||
| formatters = { | |||
| 'Name': '{{:<{}s}}'.format(df['Name'].str.len().max()).format, | |||
| 'Result': '{{:<{}s}}'.format(df['Result'].str.len().max()).format, | |||
| } | |||
| with pandas.option_context('display.max_rows', None, 'display.max_columns', | |||
| None, 'display.width', None): | |||
| print(df.to_string(justify='left', formatters=formatters, index=False)) | |||
| def main(args): | |||
| runner = unittest.TextTestRunner() | |||
| runner = TimeCostTextTestRunner() | |||
| test_suite = gather_test_cases( | |||
| os.path.abspath(args.test_dir), args.pattern, args.list_tests) | |||
| if not args.list_tests: | |||
| result = runner.run(test_suite) | |||
| if len(result.failures) > 0: | |||
| sys.exit(len(result.failures)) | |||
| if len(result.errors) > 0: | |||
| sys.exit(len(result.errors)) | |||
| result = collect_test_results(result) | |||
| df = test_cases_result_to_df(result) | |||
| if args.result_dir is not None: | |||
| file_name = str(int(datetime.datetime.now().timestamp() * 1000)) | |||
| df.to_pickle(os.path.join(args.result_dir, file_name)) | |||
| else: | |||
| print_table_result(df) | |||
| print_abnormal_case_info(df) | |||
| statistics_test_result(df) | |||
| if __name__ == '__main__': | |||
| @@ -66,6 +320,18 @@ if __name__ == '__main__': | |||
| '--level', default=0, type=int, help='2 -- all, 1 -- p1, 0 -- p0') | |||
| parser.add_argument( | |||
| '--disable_profile', action='store_true', help='disable profiling') | |||
| parser.add_argument( | |||
| '--isolated_cases', | |||
| default=None, | |||
| help='specified isolated cases config file') | |||
| parser.add_argument( | |||
| '--subprocess', | |||
| action='store_true', | |||
| help='run all test suite in subprocess') | |||
| parser.add_argument( | |||
| '--result_dir', | |||
| default=None, | |||
| help='Save result to directory, internal use only') | |||
| args = parser.parse_args() | |||
| set_test_level(args.level) | |||
| logger.info(f'TEST LEVEL: {test_level()}') | |||
| @@ -73,4 +339,10 @@ if __name__ == '__main__': | |||
| from utils import profiler | |||
| logger.info('enable profile ...') | |||
| profiler.enable() | |||
| main(args) | |||
| if args.isolated_cases is not None and args.subprocess: | |||
| print('--isolated_cases and --subprocess conflict') | |||
| sys.exit(1) | |||
| elif args.isolated_cases is not None or args.subprocess: | |||
| run_in_subprocess(args) | |||
| else: | |||
| main(args) | |||
| @@ -17,6 +17,41 @@ from modelscope.utils.constant import ModelFile | |||
| from modelscope.utils.test_utils import test_level | |||
| class PairedImageDataset(data.Dataset): | |||
| def __init__(self, root): | |||
| super(PairedImageDataset, self).__init__() | |||
| gt_dir = osp.join(root, 'gt') | |||
| lq_dir = osp.join(root, 'lq') | |||
| self.gt_filelist = os.listdir(gt_dir) | |||
| self.gt_filelist = sorted(self.gt_filelist, key=lambda x: int(x[:-4])) | |||
| self.gt_filelist = [osp.join(gt_dir, f) for f in self.gt_filelist] | |||
| self.lq_filelist = os.listdir(lq_dir) | |||
| self.lq_filelist = sorted(self.lq_filelist, key=lambda x: int(x[:-4])) | |||
| self.lq_filelist = [osp.join(lq_dir, f) for f in self.lq_filelist] | |||
| def _img_to_tensor(self, img): | |||
| return torch.from_numpy(img[:, :, [2, 1, 0]]).permute(2, 0, 1).type( | |||
| torch.float32) / 255. | |||
| def __getitem__(self, index): | |||
| lq = cv2.imread(self.lq_filelist[index]) | |||
| gt = cv2.imread(self.gt_filelist[index]) | |||
| lq = cv2.resize(lq, (256, 256), interpolation=cv2.INTER_CUBIC) | |||
| gt = cv2.resize(gt, (256, 256), interpolation=cv2.INTER_CUBIC) | |||
| return \ | |||
| {'src': self._img_to_tensor(lq), 'target': self._img_to_tensor(gt)} | |||
| def __len__(self): | |||
| return len(self.gt_filelist) | |||
| def to_torch_dataset(self, | |||
| columns: Union[str, List[str]] = None, | |||
| preprocessors: Union[Callable, List[Callable]] = None, | |||
| **format_kwargs): | |||
| return self | |||
| class TestImageColorEnhanceTrainer(unittest.TestCase): | |||
| def setUp(self): | |||
| @@ -27,47 +62,6 @@ class TestImageColorEnhanceTrainer(unittest.TestCase): | |||
| self.model_id = 'damo/cv_csrnet_image-color-enhance-models' | |||
| class PairedImageDataset(data.Dataset): | |||
| def __init__(self, root): | |||
| super(PairedImageDataset, self).__init__() | |||
| gt_dir = osp.join(root, 'gt') | |||
| lq_dir = osp.join(root, 'lq') | |||
| self.gt_filelist = os.listdir(gt_dir) | |||
| self.gt_filelist = sorted( | |||
| self.gt_filelist, key=lambda x: int(x[:-4])) | |||
| self.gt_filelist = [ | |||
| osp.join(gt_dir, f) for f in self.gt_filelist | |||
| ] | |||
| self.lq_filelist = os.listdir(lq_dir) | |||
| self.lq_filelist = sorted( | |||
| self.lq_filelist, key=lambda x: int(x[:-4])) | |||
| self.lq_filelist = [ | |||
| osp.join(lq_dir, f) for f in self.lq_filelist | |||
| ] | |||
| def _img_to_tensor(self, img): | |||
| return torch.from_numpy(img[:, :, [2, 1, 0]]).permute( | |||
| 2, 0, 1).type(torch.float32) / 255. | |||
| def __getitem__(self, index): | |||
| lq = cv2.imread(self.lq_filelist[index]) | |||
| gt = cv2.imread(self.gt_filelist[index]) | |||
| lq = cv2.resize(lq, (256, 256), interpolation=cv2.INTER_CUBIC) | |||
| gt = cv2.resize(gt, (256, 256), interpolation=cv2.INTER_CUBIC) | |||
| return \ | |||
| {'src': self._img_to_tensor(lq), 'target': self._img_to_tensor(gt)} | |||
| def __len__(self): | |||
| return len(self.gt_filelist) | |||
| def to_torch_dataset(self, | |||
| columns: Union[str, List[str]] = None, | |||
| preprocessors: Union[Callable, | |||
| List[Callable]] = None, | |||
| **format_kwargs): | |||
| return self | |||
| self.dataset = PairedImageDataset( | |||
| './data/test/images/image_color_enhance/') | |||
| @@ -19,6 +19,47 @@ from modelscope.utils.constant import ModelFile | |||
| from modelscope.utils.test_utils import test_level | |||
| class PairedImageDataset(data.Dataset): | |||
| def __init__(self, root, size=512): | |||
| super(PairedImageDataset, self).__init__() | |||
| self.size = size | |||
| gt_dir = osp.join(root, 'gt') | |||
| lq_dir = osp.join(root, 'lq') | |||
| self.gt_filelist = os.listdir(gt_dir) | |||
| self.gt_filelist = sorted(self.gt_filelist, key=lambda x: int(x[:-4])) | |||
| self.gt_filelist = [osp.join(gt_dir, f) for f in self.gt_filelist] | |||
| self.lq_filelist = os.listdir(lq_dir) | |||
| self.lq_filelist = sorted(self.lq_filelist, key=lambda x: int(x[:-4])) | |||
| self.lq_filelist = [osp.join(lq_dir, f) for f in self.lq_filelist] | |||
| def _img_to_tensor(self, img): | |||
| img = torch.from_numpy(img[:, :, [2, 1, 0]]).permute(2, 0, 1).type( | |||
| torch.float32) / 255. | |||
| return (img - 0.5) / 0.5 | |||
| def __getitem__(self, index): | |||
| lq = cv2.imread(self.lq_filelist[index]) | |||
| gt = cv2.imread(self.gt_filelist[index]) | |||
| lq = cv2.resize( | |||
| lq, (self.size, self.size), interpolation=cv2.INTER_CUBIC) | |||
| gt = cv2.resize( | |||
| gt, (self.size, self.size), interpolation=cv2.INTER_CUBIC) | |||
| return \ | |||
| {'src': self._img_to_tensor(lq), 'target': self._img_to_tensor(gt)} | |||
| def __len__(self): | |||
| return len(self.gt_filelist) | |||
| def to_torch_dataset(self, | |||
| columns: Union[str, List[str]] = None, | |||
| preprocessors: Union[Callable, List[Callable]] = None, | |||
| **format_kwargs): | |||
| # self.preprocessor = preprocessors | |||
| return self | |||
| class TestImagePortraitEnhancementTrainer(unittest.TestCase): | |||
| def setUp(self): | |||
| @@ -29,53 +70,6 @@ class TestImagePortraitEnhancementTrainer(unittest.TestCase): | |||
| self.model_id = 'damo/cv_gpen_image-portrait-enhancement' | |||
| class PairedImageDataset(data.Dataset): | |||
| def __init__(self, root, size=512): | |||
| super(PairedImageDataset, self).__init__() | |||
| self.size = size | |||
| gt_dir = osp.join(root, 'gt') | |||
| lq_dir = osp.join(root, 'lq') | |||
| self.gt_filelist = os.listdir(gt_dir) | |||
| self.gt_filelist = sorted( | |||
| self.gt_filelist, key=lambda x: int(x[:-4])) | |||
| self.gt_filelist = [ | |||
| osp.join(gt_dir, f) for f in self.gt_filelist | |||
| ] | |||
| self.lq_filelist = os.listdir(lq_dir) | |||
| self.lq_filelist = sorted( | |||
| self.lq_filelist, key=lambda x: int(x[:-4])) | |||
| self.lq_filelist = [ | |||
| osp.join(lq_dir, f) for f in self.lq_filelist | |||
| ] | |||
| def _img_to_tensor(self, img): | |||
| img = torch.from_numpy(img[:, :, [2, 1, 0]]).permute( | |||
| 2, 0, 1).type(torch.float32) / 255. | |||
| return (img - 0.5) / 0.5 | |||
| def __getitem__(self, index): | |||
| lq = cv2.imread(self.lq_filelist[index]) | |||
| gt = cv2.imread(self.gt_filelist[index]) | |||
| lq = cv2.resize( | |||
| lq, (self.size, self.size), interpolation=cv2.INTER_CUBIC) | |||
| gt = cv2.resize( | |||
| gt, (self.size, self.size), interpolation=cv2.INTER_CUBIC) | |||
| return \ | |||
| {'src': self._img_to_tensor(lq), 'target': self._img_to_tensor(gt)} | |||
| def __len__(self): | |||
| return len(self.gt_filelist) | |||
| def to_torch_dataset(self, | |||
| columns: Union[str, List[str]] = None, | |||
| preprocessors: Union[Callable, | |||
| List[Callable]] = None, | |||
| **format_kwargs): | |||
| # self.preprocessor = preprocessors | |||
| return self | |||
| self.dataset = PairedImageDataset( | |||
| './data/test/images/face_enhancement/') | |||
| @@ -16,6 +16,7 @@ from modelscope.metainfo import Metrics, Trainers | |||
| from modelscope.metrics.builder import MetricKeys | |||
| from modelscope.models.base import Model | |||
| from modelscope.trainers import build_trainer | |||
| from modelscope.trainers.base import DummyTrainer | |||
| from modelscope.utils.constant import LogKeys, ModeKeys, ModelFile | |||
| from modelscope.utils.test_utils import create_dummy_test_dataset, test_level | |||
| @@ -264,7 +265,7 @@ class TrainerTest(unittest.TestCase): | |||
| { | |||
| LogKeys.MODE: ModeKeys.EVAL, | |||
| LogKeys.EPOCH: 1, | |||
| LogKeys.ITER: 20 | |||
| LogKeys.ITER: 10 | |||
| }, json.loads(lines[2])) | |||
| self.assertDictContainsSubset( | |||
| { | |||
| @@ -284,7 +285,7 @@ class TrainerTest(unittest.TestCase): | |||
| { | |||
| LogKeys.MODE: ModeKeys.EVAL, | |||
| LogKeys.EPOCH: 2, | |||
| LogKeys.ITER: 20 | |||
| LogKeys.ITER: 10 | |||
| }, json.loads(lines[5])) | |||
| self.assertDictContainsSubset( | |||
| { | |||
| @@ -304,7 +305,7 @@ class TrainerTest(unittest.TestCase): | |||
| { | |||
| LogKeys.MODE: ModeKeys.EVAL, | |||
| LogKeys.EPOCH: 3, | |||
| LogKeys.ITER: 20 | |||
| LogKeys.ITER: 10 | |||
| }, json.loads(lines[8])) | |||
| self.assertIn(f'{LogKeys.EPOCH}_1.pth', results_files) | |||
| self.assertIn(f'{LogKeys.EPOCH}_2.pth', results_files) | |||