Browse Source

[to #44340132] fix: ci case run out of gpu memory

master
mulin.lyh 3 years ago
parent
commit
12698b31a0
12 changed files with 433 additions and 150 deletions
  1. +8
    -1
      .dev_scripts/ci_container_test.sh
  2. +2
    -1
      .dev_scripts/dockerci.sh
  3. +44
    -39
      modelscope/msdatasets/ms_dataset.py
  4. +6
    -0
      modelscope/pipelines/cv/ocr_utils/ops.py
  5. +2
    -1
      modelscope/trainers/trainer.py
  6. +4
    -7
      modelscope/utils/device.py
  7. +6
    -0
      tests/isolated_cases.txt
  8. +3
    -4
      tests/pipelines/test_multi_modal_embedding.py
  9. +278
    -6
      tests/run.py
  10. +35
    -41
      tests/trainers/test_image_color_enhance_trainer.py
  11. +41
    -47
      tests/trainers/test_image_portrait_enhancement_trainer.py
  12. +4
    -3
      tests/trainers/test_trainer.py

+ 8
- 1
.dev_scripts/ci_container_test.sh View File

@@ -19,4 +19,11 @@ fi
# test with install # test with install
python setup.py install python setup.py install


python tests/run.py
if [ $# -eq 0 ]; then
ci_command="python tests/run.py --subprocess"
else
ci_command="$@"
fi
echo "Running case with command: $ci_command"
$ci_command
#python tests/run.py --isolated_cases test_text_to_speech.py test_multi_modal_embedding.py test_ofa_tasks.py test_video_summarization.py

+ 2
- 1
.dev_scripts/dockerci.sh View File

@@ -7,7 +7,8 @@ gpus='7 6 5 4 3 2 1 0'
cpu_sets='0-7 8-15 16-23 24-30 31-37 38-44 45-51 52-58' cpu_sets='0-7 8-15 16-23 24-30 31-37 38-44 45-51 52-58'
cpu_sets_arr=($cpu_sets) cpu_sets_arr=($cpu_sets)
is_get_file_lock=false is_get_file_lock=false
CI_COMMAND=${CI_COMMAND:-'bash .dev_scripts/ci_container_test.sh'}
CI_COMMAND=${CI_COMMAND:-bash .dev_scripts/ci_container_test.sh $RUN_CASE_COMMAND}
echo "ci command: $CI_COMMAND"
for gpu in $gpus for gpu in $gpus
do do
exec {lock_fd}>"/tmp/gpu$gpu" || exit 1 exec {lock_fd}>"/tmp/gpu$gpu" || exit 1


+ 44
- 39
modelscope/msdatasets/ms_dataset.py View File

@@ -1,9 +1,11 @@
import math
import os import os
from typing import (Any, Callable, Dict, Iterable, List, Mapping, Optional, from typing import (Any, Callable, Dict, Iterable, List, Mapping, Optional,
Sequence, Union) Sequence, Union)


import json import json
import numpy as np import numpy as np
import torch
from datasets import Dataset, DatasetDict from datasets import Dataset, DatasetDict
from datasets import load_dataset as hf_load_dataset from datasets import load_dataset as hf_load_dataset
from datasets.config import TF_AVAILABLE, TORCH_AVAILABLE from datasets.config import TF_AVAILABLE, TORCH_AVAILABLE
@@ -40,6 +42,46 @@ def format_list(para) -> List:
return para return para




class MsIterableDataset(torch.utils.data.IterableDataset):

def __init__(self, dataset: Iterable, preprocessor_list, retained_columns,
columns):
super(MsIterableDataset).__init__()
self.dataset = dataset
self.preprocessor_list = preprocessor_list
self.retained_columns = retained_columns
self.columns = columns

def __len__(self):
return len(self.dataset)

def __iter__(self):
worker_info = torch.utils.data.get_worker_info()
if worker_info is None: # single-process data loading
iter_start = 0
iter_end = len(self.dataset)
else: # in a worker process
per_worker = math.ceil(
len(self.dataset) / float(worker_info.num_workers))
worker_id = worker_info.id
iter_start = worker_id * per_worker
iter_end = min(iter_start + per_worker, len(self.dataset))

for idx in range(iter_start, iter_end):
item_dict = self.dataset[idx]
res = {
k: np.array(item_dict[k])
for k in self.columns if k in self.retained_columns
}
for preprocessor in self.preprocessor_list:
res.update({
k: np.array(v)
for k, v in preprocessor(item_dict).items()
if k in self.retained_columns
})
yield res


class MsDataset: class MsDataset:
""" """
ModelScope Dataset (aka, MsDataset) is backed by a huggingface Dataset to ModelScope Dataset (aka, MsDataset) is backed by a huggingface Dataset to
@@ -318,45 +360,8 @@ class MsDataset:
continue continue
retained_columns.append(k) retained_columns.append(k)


import math
import torch

class MsIterableDataset(torch.utils.data.IterableDataset):

def __init__(self, dataset: Iterable):
super(MsIterableDataset).__init__()
self.dataset = dataset

def __len__(self):
return len(self.dataset)

def __iter__(self):
worker_info = torch.utils.data.get_worker_info()
if worker_info is None: # single-process data loading
iter_start = 0
iter_end = len(self.dataset)
else: # in a worker process
per_worker = math.ceil(
len(self.dataset) / float(worker_info.num_workers))
worker_id = worker_info.id
iter_start = worker_id * per_worker
iter_end = min(iter_start + per_worker, len(self.dataset))

for idx in range(iter_start, iter_end):
item_dict = self.dataset[idx]
res = {
k: np.array(item_dict[k])
for k in columns if k in retained_columns
}
for preprocessor in preprocessor_list:
res.update({
k: np.array(v)
for k, v in preprocessor(item_dict).items()
if k in retained_columns
})
yield res

return MsIterableDataset(self._hf_ds)
return MsIterableDataset(self._hf_ds, preprocessor_list,
retained_columns, columns)


def to_torch_dataset( def to_torch_dataset(
self, self,


+ 6
- 0
modelscope/pipelines/cv/ocr_utils/ops.py View File

@@ -1,8 +1,10 @@
import math import math
import os import os
import shutil import shutil
import sys
import uuid import uuid


import absl.flags as absl_flags
import cv2 import cv2
import numpy as np import numpy as np
import tensorflow as tf import tensorflow as tf
@@ -12,6 +14,10 @@ from . import utils
if tf.__version__ >= '2.0': if tf.__version__ >= '2.0':
tf = tf.compat.v1 tf = tf.compat.v1


# skip parse sys.argv in tf, so fix bug:
# absl.flags._exceptions.UnrecognizedFlagError:
# Unknown command line flag 'OCRDetectionPipeline: Unknown command line flag
absl_flags.FLAGS(sys.argv, known_only=True)
FLAGS = tf.app.flags.FLAGS FLAGS = tf.app.flags.FLAGS
tf.app.flags.DEFINE_string('weight_init_method', 'xavier', tf.app.flags.DEFINE_string('weight_init_method', 'xavier',
'Weight initialization method') 'Weight initialization method')


+ 2
- 1
modelscope/trainers/trainer.py View File

@@ -312,7 +312,8 @@ class EpochBasedTrainer(BaseTrainer):
else ConfigDict(type=None, mode=mode) else ConfigDict(type=None, mode=mode)
return datasets.to_torch_dataset( return datasets.to_torch_dataset(
task_data_config=cfg, task_data_config=cfg,
task_name=self.cfg.task,
task_name=self.cfg.task
if hasattr(self.cfg, ConfigFields.task) else None,
preprocessors=preprocessor) preprocessors=preprocessor)
elif isinstance(datasets, List) and isinstance( elif isinstance(datasets, List) and isinstance(
datasets[0], MsDataset): datasets[0], MsDataset):


+ 4
- 7
modelscope/utils/device.py View File

@@ -8,12 +8,6 @@ from modelscope.utils.logger import get_logger


logger = get_logger() logger = get_logger()


if is_tf_available():
import tensorflow as tf

if is_torch_available():
import torch



def verify_device(device_name): def verify_device(device_name):
""" Verify device is valid, device should be either cpu, cuda, gpu, cuda:X or gpu:X. """ Verify device is valid, device should be either cpu, cuda, gpu, cuda:X or gpu:X.
@@ -63,6 +57,7 @@ def device_placement(framework, device_name='gpu:0'):
device_type, device_id = verify_device(device_name) device_type, device_id = verify_device(device_name)


if framework == Frameworks.tf: if framework == Frameworks.tf:
import tensorflow as tf
if device_type == Devices.gpu and not tf.test.is_gpu_available(): if device_type == Devices.gpu and not tf.test.is_gpu_available():
logger.warning( logger.warning(
'tensorflow cuda is not available, using cpu instead.') 'tensorflow cuda is not available, using cpu instead.')
@@ -76,6 +71,7 @@ def device_placement(framework, device_name='gpu:0'):
yield yield


elif framework == Frameworks.torch: elif framework == Frameworks.torch:
import torch
if device_type == Devices.gpu: if device_type == Devices.gpu:
if torch.cuda.is_available(): if torch.cuda.is_available():
torch.cuda.set_device(f'cuda:{device_id}') torch.cuda.set_device(f'cuda:{device_id}')
@@ -86,12 +82,13 @@ def device_placement(framework, device_name='gpu:0'):
yield yield




def create_device(device_name) -> torch.DeviceObjType:
def create_device(device_name):
""" create torch device """ create torch device


Args: Args:
device_name (str): cpu, gpu, gpu:0, cuda:0 etc. device_name (str): cpu, gpu, gpu:0, cuda:0 etc.
""" """
import torch
device_type, device_id = verify_device(device_name) device_type, device_id = verify_device(device_name)
use_cuda = False use_cuda = False
if device_type == Devices.gpu: if device_type == Devices.gpu:


+ 6
- 0
tests/isolated_cases.txt View File

@@ -0,0 +1,6 @@
test_text_to_speech.py
test_multi_modal_embedding.py
test_ofa_tasks.py
test_video_summarization.py
test_dialog_modeling.py
test_csanmt_translation.py

+ 3
- 4
tests/pipelines/test_multi_modal_embedding.py View File

@@ -31,11 +31,10 @@ class MultiModalEmbeddingTest(unittest.TestCase):


@unittest.skipUnless(test_level() >= 1, 'skip test in current test level') @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
def test_run_with_model_from_modelhub(self): def test_run_with_model_from_modelhub(self):
model = Model.from_pretrained(self.model_id)
model = Model.from_pretrained(
self.model_id, revision=self.model_version)
pipeline_multi_modal_embedding = pipeline( pipeline_multi_modal_embedding = pipeline(
task=Tasks.multi_modal_embedding,
model=model,
model_revision=self.model_version)
task=Tasks.multi_modal_embedding, model=model)
text_embedding = pipeline_multi_modal_embedding( text_embedding = pipeline_multi_modal_embedding(
self.test_input)[OutputKeys.TEXT_EMBEDDING] self.test_input)[OutputKeys.TEXT_EMBEDDING]
print('l1-norm: {}'.format( print('l1-norm: {}'.format(


+ 278
- 6
tests/run.py View File

@@ -2,11 +2,20 @@
# Copyright (c) Alibaba, Inc. and its affiliates. # Copyright (c) Alibaba, Inc. and its affiliates.


import argparse import argparse
import datetime
import multiprocessing
import os import os
import subprocess
import sys import sys
import tempfile
import unittest import unittest
from fnmatch import fnmatch from fnmatch import fnmatch
from multiprocessing.managers import BaseManager
from pathlib import Path
from turtle import shape
from unittest import TestResult, TextTestResult


import pandas
# NOTICE: Tensorflow 1.15 seems not so compatible with pytorch. # NOTICE: Tensorflow 1.15 seems not so compatible with pytorch.
# A segmentation fault may be raise by pytorch cpp library # A segmentation fault may be raise by pytorch cpp library
# if 'import tensorflow' in front of 'import torch'. # if 'import tensorflow' in front of 'import torch'.
@@ -19,6 +28,227 @@ from modelscope.utils.test_utils import set_test_level, test_level
logger = get_logger() logger = get_logger()




def test_cases_result_to_df(result_list):
table_header = [
'Name', 'Result', 'Info', 'Start time', 'Stop time',
'Time cost(seconds)'
]
df = pandas.DataFrame(
result_list, columns=table_header).sort_values(
by=['Start time'], ascending=True)
return df


def statistics_test_result(df):
total_cases = df.shape[0]
# yapf: disable
success_cases = df.loc[df['Result'] == 'Success'].shape[0]
error_cases = df.loc[df['Result'] == 'Error'].shape[0]
failures_cases = df.loc[df['Result'] == 'Failures'].shape[0]
expected_failure_cases = df.loc[df['Result'] == 'ExpectedFailures'].shape[0]
unexpected_success_cases = df.loc[df['Result'] == 'UnexpectedSuccesses'].shape[0]
skipped_cases = df.loc[df['Result'] == 'Skipped'].shape[0]
# yapf: enable

if failures_cases > 0 or \
error_cases > 0 or \
unexpected_success_cases > 0:
result = 'FAILED'
else:
result = 'SUCCESS'
result_msg = '%s (Runs=%s,success=%s,failures=%s,errors=%s,\
skipped=%s,expected failures=%s,unexpected successes=%s)' % (
result, total_cases, success_cases, failures_cases, error_cases,
skipped_cases, expected_failure_cases, unexpected_success_cases)

print(result_msg)
if result == 'FAILED':
sys.exit(1)


def gather_test_suites_in_files(test_dir, case_file_list, list_tests):
test_suite = unittest.TestSuite()
for case in case_file_list:
test_case = unittest.defaultTestLoader.discover(
start_dir=test_dir, pattern=case)
test_suite.addTest(test_case)
if hasattr(test_case, '__iter__'):
for subcase in test_case:
if list_tests:
print(subcase)
else:
if list_tests:
print(test_case)
return test_suite


def gather_test_suites_files(test_dir, pattern):
case_file_list = []
for dirpath, dirnames, filenames in os.walk(test_dir):
for file in filenames:
if fnmatch(file, pattern):
case_file_list.append(file)
return case_file_list


def collect_test_results(case_results):
result_list = [
] # each item is Case, Result, Start time, Stop time, Time cost
for case_result in case_results.successes:
result_list.append(
(case_result.test_full_name, 'Success', '', case_result.start_time,
case_result.stop_time, case_result.time_cost))
for case_result in case_results.errors:
result_list.append(
(case_result[0].test_full_name, 'Error', case_result[1],
case_result[0].start_time, case_result[0].stop_time,
case_result[0].time_cost))
for case_result in case_results.skipped:
result_list.append(
(case_result[0].test_full_name, 'Skipped', case_result[1],
case_result[0].start_time, case_result[0].stop_time,
case_result[0].time_cost))
for case_result in case_results.expectedFailures:
result_list.append(
(case_result[0].test_full_name, 'ExpectedFailures', case_result[1],
case_result[0].start_time, case_result[0].stop_time,
case_result[0].time_cost))
for case_result in case_results.failures:
result_list.append(
(case_result[0].test_full_name, 'Failures', case_result[1],
case_result[0].start_time, case_result[0].stop_time,
case_result[0].time_cost))
for case_result in case_results.unexpectedSuccesses:
result_list.append((case_result.test_full_name, 'UnexpectedSuccesses',
'', case_result.start_time, case_result.stop_time,
case_result.time_cost))
return result_list


class TestSuiteRunner:

def run(self, msg_queue, test_dir, test_suite_file):
test_suite = unittest.TestSuite()
test_case = unittest.defaultTestLoader.discover(
start_dir=test_dir, pattern=test_suite_file)
test_suite.addTest(test_case)
runner = TimeCostTextTestRunner()
test_suite_result = runner.run(test_suite)
msg_queue.put(collect_test_results(test_suite_result))


def run_command_with_popen(cmd):
with subprocess.Popen(
cmd,
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
bufsize=1,
encoding='utf8') as sub_process:
for line in iter(sub_process.stdout.readline, ''):
sys.stdout.write(line)


def run_in_subprocess(args):
# only case args.isolated_cases run in subporcess, all other run in a subprocess
test_suite_files = gather_test_suites_files(
os.path.abspath(args.test_dir), args.pattern)

if args.subprocess: # run all case in subprocess
isolated_cases = test_suite_files
else:
isolated_cases = []
with open(args.isolated_cases, 'r') as f:
for line in f:
if line.strip() in test_suite_files:
isolated_cases.append(line.strip())

if not args.list_tests:
with tempfile.TemporaryDirectory() as temp_result_dir:
for test_suite_file in isolated_cases: # run case in subprocess
cmd = [
'python', 'tests/run.py', '--pattern', test_suite_file,
'--result_dir', temp_result_dir
]
run_command_with_popen(cmd)
result_dfs = []
# run remain cases in a process.
remain_suite_files = [
item for item in test_suite_files if item not in isolated_cases
]
test_suite = gather_test_suites_in_files(args.test_dir,
remain_suite_files,
args.list_tests)
if test_suite.countTestCases() > 0:
runner = TimeCostTextTestRunner()
result = runner.run(test_suite)
result = collect_test_results(result)
df = test_cases_result_to_df(result)
result_dfs.append(df)

# collect test results
result_path = Path(temp_result_dir)
for result in result_path.iterdir():
if Path.is_file(result):
df = pandas.read_pickle(result)
result_dfs.append(df)

result_pd = pandas.concat(
result_dfs) # merge result of every test suite.
print_table_result(result_pd)
print_abnormal_case_info(result_pd)
statistics_test_result(result_pd)


def get_object_full_name(obj):
klass = obj.__class__
module = klass.__module__
if module == 'builtins':
return klass.__qualname__
return module + '.' + klass.__qualname__


class TimeCostTextTestResult(TextTestResult):
"""Record test case time used!"""

def __init__(self, stream, descriptions, verbosity):
self.successes = []
return super(TimeCostTextTestResult,
self).__init__(stream, descriptions, verbosity)

def startTest(self, test):
test.start_time = datetime.datetime.now()
test.test_full_name = get_object_full_name(
test) + '.' + test._testMethodName
self.stream.writeln('Test case: %s start at: %s' %
(test.test_full_name, test.start_time))

return super(TimeCostTextTestResult, self).startTest(test)

def stopTest(self, test):
TextTestResult.stopTest(self, test)
test.stop_time = datetime.datetime.now()
test.time_cost = (test.stop_time - test.start_time).total_seconds()
self.stream.writeln(
'Test case: %s stop at: %s, cost time: %s(seconds)' %
(test.test_full_name, test.stop_time, test.time_cost))
super(TimeCostTextTestResult, self).stopTest(test)

def addSuccess(self, test):
self.successes.append(test)
super(TextTestResult, self).addSuccess(test)


class TimeCostTextTestRunner(unittest.runner.TextTestRunner):
resultclass = TimeCostTextTestResult

def run(self, test):
return super(TimeCostTextTestRunner, self).run(test)

def _makeResult(self):
result = super(TimeCostTextTestRunner, self)._makeResult()
return result


def gather_test_cases(test_dir, pattern, list_tests): def gather_test_cases(test_dir, pattern, list_tests):
case_list = [] case_list = []
for dirpath, dirnames, filenames in os.walk(test_dir): for dirpath, dirnames, filenames in os.walk(test_dir):
@@ -42,16 +272,40 @@ def gather_test_cases(test_dir, pattern, list_tests):
return test_suite return test_suite




def print_abnormal_case_info(df):
df = df.loc[(df['Result'] == 'Error') | (df['Result'] == 'Failures')]
for _, row in df.iterrows():
print('Case %s run result: %s, msg:\n%s' %
(row['Name'], row['Result'], row['Info']))


def print_table_result(df):
df = df.loc[df['Result'] != 'Skipped']
df = df.drop('Info', axis=1)
formatters = {
'Name': '{{:<{}s}}'.format(df['Name'].str.len().max()).format,
'Result': '{{:<{}s}}'.format(df['Result'].str.len().max()).format,
}
with pandas.option_context('display.max_rows', None, 'display.max_columns',
None, 'display.width', None):
print(df.to_string(justify='left', formatters=formatters, index=False))


def main(args): def main(args):
runner = unittest.TextTestRunner()
runner = TimeCostTextTestRunner()
test_suite = gather_test_cases( test_suite = gather_test_cases(
os.path.abspath(args.test_dir), args.pattern, args.list_tests) os.path.abspath(args.test_dir), args.pattern, args.list_tests)
if not args.list_tests: if not args.list_tests:
result = runner.run(test_suite) result = runner.run(test_suite)
if len(result.failures) > 0:
sys.exit(len(result.failures))
if len(result.errors) > 0:
sys.exit(len(result.errors))
result = collect_test_results(result)
df = test_cases_result_to_df(result)
if args.result_dir is not None:
file_name = str(int(datetime.datetime.now().timestamp() * 1000))
df.to_pickle(os.path.join(args.result_dir, file_name))
else:
print_table_result(df)
print_abnormal_case_info(df)
statistics_test_result(df)




if __name__ == '__main__': if __name__ == '__main__':
@@ -66,6 +320,18 @@ if __name__ == '__main__':
'--level', default=0, type=int, help='2 -- all, 1 -- p1, 0 -- p0') '--level', default=0, type=int, help='2 -- all, 1 -- p1, 0 -- p0')
parser.add_argument( parser.add_argument(
'--disable_profile', action='store_true', help='disable profiling') '--disable_profile', action='store_true', help='disable profiling')
parser.add_argument(
'--isolated_cases',
default=None,
help='specified isolated cases config file')
parser.add_argument(
'--subprocess',
action='store_true',
help='run all test suite in subprocess')
parser.add_argument(
'--result_dir',
default=None,
help='Save result to directory, internal use only')
args = parser.parse_args() args = parser.parse_args()
set_test_level(args.level) set_test_level(args.level)
logger.info(f'TEST LEVEL: {test_level()}') logger.info(f'TEST LEVEL: {test_level()}')
@@ -73,4 +339,10 @@ if __name__ == '__main__':
from utils import profiler from utils import profiler
logger.info('enable profile ...') logger.info('enable profile ...')
profiler.enable() profiler.enable()
main(args)
if args.isolated_cases is not None or args.subprocess:
run_in_subprocess(args)
elif args.isolated_cases is not None and args.subprocess:
print('isolated_cases and subporcess conflict')
sys.exit(1)
else:
main(args)

+ 35
- 41
tests/trainers/test_image_color_enhance_trainer.py View File

@@ -17,6 +17,41 @@ from modelscope.utils.constant import ModelFile
from modelscope.utils.test_utils import test_level from modelscope.utils.test_utils import test_level




class PairedImageDataset(data.Dataset):

def __init__(self, root):
super(PairedImageDataset, self).__init__()
gt_dir = osp.join(root, 'gt')
lq_dir = osp.join(root, 'lq')
self.gt_filelist = os.listdir(gt_dir)
self.gt_filelist = sorted(self.gt_filelist, key=lambda x: int(x[:-4]))
self.gt_filelist = [osp.join(gt_dir, f) for f in self.gt_filelist]
self.lq_filelist = os.listdir(lq_dir)
self.lq_filelist = sorted(self.lq_filelist, key=lambda x: int(x[:-4]))
self.lq_filelist = [osp.join(lq_dir, f) for f in self.lq_filelist]

def _img_to_tensor(self, img):
return torch.from_numpy(img[:, :, [2, 1, 0]]).permute(2, 0, 1).type(
torch.float32) / 255.

def __getitem__(self, index):
lq = cv2.imread(self.lq_filelist[index])
gt = cv2.imread(self.gt_filelist[index])
lq = cv2.resize(lq, (256, 256), interpolation=cv2.INTER_CUBIC)
gt = cv2.resize(gt, (256, 256), interpolation=cv2.INTER_CUBIC)
return \
{'src': self._img_to_tensor(lq), 'target': self._img_to_tensor(gt)}

def __len__(self):
return len(self.gt_filelist)

def to_torch_dataset(self,
columns: Union[str, List[str]] = None,
preprocessors: Union[Callable, List[Callable]] = None,
**format_kwargs):
return self


class TestImageColorEnhanceTrainer(unittest.TestCase): class TestImageColorEnhanceTrainer(unittest.TestCase):


def setUp(self): def setUp(self):
@@ -27,47 +62,6 @@ class TestImageColorEnhanceTrainer(unittest.TestCase):


self.model_id = 'damo/cv_csrnet_image-color-enhance-models' self.model_id = 'damo/cv_csrnet_image-color-enhance-models'


class PairedImageDataset(data.Dataset):

def __init__(self, root):
super(PairedImageDataset, self).__init__()
gt_dir = osp.join(root, 'gt')
lq_dir = osp.join(root, 'lq')
self.gt_filelist = os.listdir(gt_dir)
self.gt_filelist = sorted(
self.gt_filelist, key=lambda x: int(x[:-4]))
self.gt_filelist = [
osp.join(gt_dir, f) for f in self.gt_filelist
]
self.lq_filelist = os.listdir(lq_dir)
self.lq_filelist = sorted(
self.lq_filelist, key=lambda x: int(x[:-4]))
self.lq_filelist = [
osp.join(lq_dir, f) for f in self.lq_filelist
]

def _img_to_tensor(self, img):
return torch.from_numpy(img[:, :, [2, 1, 0]]).permute(
2, 0, 1).type(torch.float32) / 255.

def __getitem__(self, index):
lq = cv2.imread(self.lq_filelist[index])
gt = cv2.imread(self.gt_filelist[index])
lq = cv2.resize(lq, (256, 256), interpolation=cv2.INTER_CUBIC)
gt = cv2.resize(gt, (256, 256), interpolation=cv2.INTER_CUBIC)
return \
{'src': self._img_to_tensor(lq), 'target': self._img_to_tensor(gt)}

def __len__(self):
return len(self.gt_filelist)

def to_torch_dataset(self,
columns: Union[str, List[str]] = None,
preprocessors: Union[Callable,
List[Callable]] = None,
**format_kwargs):
return self

self.dataset = PairedImageDataset( self.dataset = PairedImageDataset(
'./data/test/images/image_color_enhance/') './data/test/images/image_color_enhance/')




+ 41
- 47
tests/trainers/test_image_portrait_enhancement_trainer.py View File

@@ -19,6 +19,47 @@ from modelscope.utils.constant import ModelFile
from modelscope.utils.test_utils import test_level from modelscope.utils.test_utils import test_level




class PairedImageDataset(data.Dataset):

def __init__(self, root, size=512):
super(PairedImageDataset, self).__init__()
self.size = size
gt_dir = osp.join(root, 'gt')
lq_dir = osp.join(root, 'lq')
self.gt_filelist = os.listdir(gt_dir)
self.gt_filelist = sorted(self.gt_filelist, key=lambda x: int(x[:-4]))
self.gt_filelist = [osp.join(gt_dir, f) for f in self.gt_filelist]
self.lq_filelist = os.listdir(lq_dir)
self.lq_filelist = sorted(self.lq_filelist, key=lambda x: int(x[:-4]))
self.lq_filelist = [osp.join(lq_dir, f) for f in self.lq_filelist]

def _img_to_tensor(self, img):
img = torch.from_numpy(img[:, :, [2, 1, 0]]).permute(2, 0, 1).type(
torch.float32) / 255.
return (img - 0.5) / 0.5

def __getitem__(self, index):
lq = cv2.imread(self.lq_filelist[index])
gt = cv2.imread(self.gt_filelist[index])
lq = cv2.resize(
lq, (self.size, self.size), interpolation=cv2.INTER_CUBIC)
gt = cv2.resize(
gt, (self.size, self.size), interpolation=cv2.INTER_CUBIC)

return \
{'src': self._img_to_tensor(lq), 'target': self._img_to_tensor(gt)}

def __len__(self):
return len(self.gt_filelist)

def to_torch_dataset(self,
columns: Union[str, List[str]] = None,
preprocessors: Union[Callable, List[Callable]] = None,
**format_kwargs):
# self.preprocessor = preprocessors
return self


class TestImagePortraitEnhancementTrainer(unittest.TestCase): class TestImagePortraitEnhancementTrainer(unittest.TestCase):


def setUp(self): def setUp(self):
@@ -29,53 +70,6 @@ class TestImagePortraitEnhancementTrainer(unittest.TestCase):


self.model_id = 'damo/cv_gpen_image-portrait-enhancement' self.model_id = 'damo/cv_gpen_image-portrait-enhancement'


class PairedImageDataset(data.Dataset):

def __init__(self, root, size=512):
super(PairedImageDataset, self).__init__()
self.size = size
gt_dir = osp.join(root, 'gt')
lq_dir = osp.join(root, 'lq')
self.gt_filelist = os.listdir(gt_dir)
self.gt_filelist = sorted(
self.gt_filelist, key=lambda x: int(x[:-4]))
self.gt_filelist = [
osp.join(gt_dir, f) for f in self.gt_filelist
]
self.lq_filelist = os.listdir(lq_dir)
self.lq_filelist = sorted(
self.lq_filelist, key=lambda x: int(x[:-4]))
self.lq_filelist = [
osp.join(lq_dir, f) for f in self.lq_filelist
]

def _img_to_tensor(self, img):
img = torch.from_numpy(img[:, :, [2, 1, 0]]).permute(
2, 0, 1).type(torch.float32) / 255.
return (img - 0.5) / 0.5

def __getitem__(self, index):
lq = cv2.imread(self.lq_filelist[index])
gt = cv2.imread(self.gt_filelist[index])
lq = cv2.resize(
lq, (self.size, self.size), interpolation=cv2.INTER_CUBIC)
gt = cv2.resize(
gt, (self.size, self.size), interpolation=cv2.INTER_CUBIC)

return \
{'src': self._img_to_tensor(lq), 'target': self._img_to_tensor(gt)}

def __len__(self):
return len(self.gt_filelist)

def to_torch_dataset(self,
columns: Union[str, List[str]] = None,
preprocessors: Union[Callable,
List[Callable]] = None,
**format_kwargs):
# self.preprocessor = preprocessors
return self

self.dataset = PairedImageDataset( self.dataset = PairedImageDataset(
'./data/test/images/face_enhancement/') './data/test/images/face_enhancement/')




+ 4
- 3
tests/trainers/test_trainer.py View File

@@ -16,6 +16,7 @@ from modelscope.metainfo import Metrics, Trainers
from modelscope.metrics.builder import MetricKeys from modelscope.metrics.builder import MetricKeys
from modelscope.models.base import Model from modelscope.models.base import Model
from modelscope.trainers import build_trainer from modelscope.trainers import build_trainer
from modelscope.trainers.base import DummyTrainer
from modelscope.utils.constant import LogKeys, ModeKeys, ModelFile from modelscope.utils.constant import LogKeys, ModeKeys, ModelFile
from modelscope.utils.test_utils import create_dummy_test_dataset, test_level from modelscope.utils.test_utils import create_dummy_test_dataset, test_level


@@ -264,7 +265,7 @@ class TrainerTest(unittest.TestCase):
{ {
LogKeys.MODE: ModeKeys.EVAL, LogKeys.MODE: ModeKeys.EVAL,
LogKeys.EPOCH: 1, LogKeys.EPOCH: 1,
LogKeys.ITER: 20
LogKeys.ITER: 10
}, json.loads(lines[2])) }, json.loads(lines[2]))
self.assertDictContainsSubset( self.assertDictContainsSubset(
{ {
@@ -284,7 +285,7 @@ class TrainerTest(unittest.TestCase):
{ {
LogKeys.MODE: ModeKeys.EVAL, LogKeys.MODE: ModeKeys.EVAL,
LogKeys.EPOCH: 2, LogKeys.EPOCH: 2,
LogKeys.ITER: 20
LogKeys.ITER: 10
}, json.loads(lines[5])) }, json.loads(lines[5]))
self.assertDictContainsSubset( self.assertDictContainsSubset(
{ {
@@ -304,7 +305,7 @@ class TrainerTest(unittest.TestCase):
{ {
LogKeys.MODE: ModeKeys.EVAL, LogKeys.MODE: ModeKeys.EVAL,
LogKeys.EPOCH: 3, LogKeys.EPOCH: 3,
LogKeys.ITER: 20
LogKeys.ITER: 10
}, json.loads(lines[8])) }, json.loads(lines[8]))
self.assertIn(f'{LogKeys.EPOCH}_1.pth', results_files) self.assertIn(f'{LogKeys.EPOCH}_1.pth', results_files)
self.assertIn(f'{LogKeys.EPOCH}_2.pth', results_files) self.assertIn(f'{LogKeys.EPOCH}_2.pth', results_files)


Loading…
Cancel
Save