
[to #44340132] fix: ci case run out of gpu memory

master · mulin.lyh · 3 years ago · commit 12698b31a0
12 changed files with 433 additions and 150 deletions
  1. .dev_scripts/ci_container_test.sh (+8 -1)
  2. .dev_scripts/dockerci.sh (+2 -1)
  3. modelscope/msdatasets/ms_dataset.py (+44 -39)
  4. modelscope/pipelines/cv/ocr_utils/ops.py (+6 -0)
  5. modelscope/trainers/trainer.py (+2 -1)
  6. modelscope/utils/device.py (+4 -7)
  7. tests/isolated_cases.txt (+6 -0)
  8. tests/pipelines/test_multi_modal_embedding.py (+3 -4)
  9. tests/run.py (+278 -6)
  10. tests/trainers/test_image_color_enhance_trainer.py (+35 -41)
  11. tests/trainers/test_image_portrait_enhancement_trainer.py (+41 -47)
  12. tests/trainers/test_trainer.py (+4 -3)

.dev_scripts/ci_container_test.sh (+8 -1)

@@ -19,4 +19,11 @@ fi
# test with install
python setup.py install

python tests/run.py
if [ $# -eq 0 ]; then
ci_command="python tests/run.py --subprocess"
else
ci_command="$@"
fi
echo "Running case with command: $ci_command"
$ci_command
#python tests/run.py --isolated_cases test_text_to_speech.py test_multi_modal_embedding.py test_ofa_tasks.py test_video_summarization.py

.dev_scripts/dockerci.sh (+2 -1)

@@ -7,7 +7,8 @@ gpus='7 6 5 4 3 2 1 0'
cpu_sets='0-7 8-15 16-23 24-30 31-37 38-44 45-51 52-58'
cpu_sets_arr=($cpu_sets)
is_get_file_lock=false
CI_COMMAND=${CI_COMMAND:-'bash .dev_scripts/ci_container_test.sh'}
CI_COMMAND=${CI_COMMAND:-bash .dev_scripts/ci_container_test.sh $RUN_CASE_COMMAND}
echo "ci command: $CI_COMMAND"
for gpu in $gpus
do
exec {lock_fd}>"/tmp/gpu$gpu" || exit 1


modelscope/msdatasets/ms_dataset.py (+44 -39)

@@ -1,9 +1,11 @@
import math
import os
from typing import (Any, Callable, Dict, Iterable, List, Mapping, Optional,
Sequence, Union)

import json
import numpy as np
import torch
from datasets import Dataset, DatasetDict
from datasets import load_dataset as hf_load_dataset
from datasets.config import TF_AVAILABLE, TORCH_AVAILABLE
@@ -40,6 +42,46 @@ def format_list(para) -> List:
return para


class MsIterableDataset(torch.utils.data.IterableDataset):

def __init__(self, dataset: Iterable, preprocessor_list, retained_columns,
columns):
super(MsIterableDataset).__init__()
self.dataset = dataset
self.preprocessor_list = preprocessor_list
self.retained_columns = retained_columns
self.columns = columns

def __len__(self):
return len(self.dataset)

def __iter__(self):
worker_info = torch.utils.data.get_worker_info()
if worker_info is None: # single-process data loading
iter_start = 0
iter_end = len(self.dataset)
else: # in a worker process
per_worker = math.ceil(
len(self.dataset) / float(worker_info.num_workers))
worker_id = worker_info.id
iter_start = worker_id * per_worker
iter_end = min(iter_start + per_worker, len(self.dataset))

for idx in range(iter_start, iter_end):
item_dict = self.dataset[idx]
res = {
k: np.array(item_dict[k])
for k in self.columns if k in self.retained_columns
}
for preprocessor in self.preprocessor_list:
res.update({
k: np.array(v)
for k, v in preprocessor(item_dict).items()
if k in self.retained_columns
})
yield res


class MsDataset:
"""
ModelScope Dataset (aka, MsDataset) is backed by a huggingface Dataset to
@@ -318,45 +360,8 @@ class MsDataset:
continue
retained_columns.append(k)

import math
import torch

class MsIterableDataset(torch.utils.data.IterableDataset):

def __init__(self, dataset: Iterable):
super(MsIterableDataset).__init__()
self.dataset = dataset

def __len__(self):
return len(self.dataset)

def __iter__(self):
worker_info = torch.utils.data.get_worker_info()
if worker_info is None: # single-process data loading
iter_start = 0
iter_end = len(self.dataset)
else: # in a worker process
per_worker = math.ceil(
len(self.dataset) / float(worker_info.num_workers))
worker_id = worker_info.id
iter_start = worker_id * per_worker
iter_end = min(iter_start + per_worker, len(self.dataset))

for idx in range(iter_start, iter_end):
item_dict = self.dataset[idx]
res = {
k: np.array(item_dict[k])
for k in columns if k in retained_columns
}
for preprocessor in preprocessor_list:
res.update({
k: np.array(v)
for k, v in preprocessor(item_dict).items()
if k in retained_columns
})
yield res

return MsIterableDataset(self._hf_ds)
return MsIterableDataset(self._hf_ds, preprocessor_list,
retained_columns, columns)

def to_torch_dataset(
self,

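Note on the sharding logic added above: when the dataset is consumed through a torch DataLoader with several workers, each worker iterates only its own contiguous slice computed from get_worker_info(). A minimal self-contained sketch of that behaviour (the toy dataset and values are illustrative, not code from the repository):

import math
import torch
from torch.utils.data import DataLoader, IterableDataset

class ToyShardedDataset(IterableDataset):
    """Indexable data exposed through the same sharded __iter__ pattern."""

    def __init__(self, items):
        self.items = items

    def __len__(self):
        return len(self.items)

    def __iter__(self):
        info = torch.utils.data.get_worker_info()
        if info is None:  # single-process data loading: iterate everything
            start, end = 0, len(self.items)
        else:  # worker process: take a contiguous, non-overlapping slice
            per_worker = math.ceil(len(self.items) / float(info.num_workers))
            start = info.id * per_worker
            end = min(start + per_worker, len(self.items))
        for idx in range(start, end):
            yield self.items[idx]

if __name__ == '__main__':
    loader = DataLoader(ToyShardedDataset(list(range(10))), num_workers=2)
    print(sorted(int(x) for x in loader))  # every item appears exactly once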

modelscope/pipelines/cv/ocr_utils/ops.py (+6 -0)

@@ -1,8 +1,10 @@
import math
import os
import shutil
import sys
import uuid

import absl.flags as absl_flags
import cv2
import numpy as np
import tensorflow as tf
@@ -12,6 +14,10 @@ from . import utils
if tf.__version__ >= '2.0':
tf = tf.compat.v1

# Parse sys.argv with known_only=True so tf/absl skips unrecognized arguments,
# fixing: absl.flags._exceptions.UnrecognizedFlagError:
# Unknown command line flag 'OCRDetectionPipeline ...'
absl_flags.FLAGS(sys.argv, known_only=True)
FLAGS = tf.app.flags.FLAGS
tf.app.flags.DEFINE_string('weight_init_method', 'xavier',
'Weight initialization method')
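The known_only parse added above relies on standard absl behaviour: flags that were never defined are skipped instead of raising. A minimal standalone sketch (the unknown argument below is illustrative):

import absl.flags as absl_flags

absl_flags.DEFINE_string('weight_init_method', 'xavier',
                         'Weight initialization method')

argv = ['prog', '--weight_init_method=normal', '--OCRDetectionPipeline=1']
absl_flags.FLAGS(argv, known_only=True)  # the unknown flag is silently skipped
print(absl_flags.FLAGS.weight_init_method)  # -> normal
# Without known_only=True, the same call raises
# absl.flags._exceptions.UnrecognizedFlagError for the unknown argument.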


modelscope/trainers/trainer.py (+2 -1)

@@ -312,7 +312,8 @@ class EpochBasedTrainer(BaseTrainer):
else ConfigDict(type=None, mode=mode)
return datasets.to_torch_dataset(
task_data_config=cfg,
task_name=self.cfg.task,
task_name=self.cfg.task
if hasattr(self.cfg, ConfigFields.task) else None,
preprocessors=preprocessor)
elif isinstance(datasets, List) and isinstance(
datasets[0], MsDataset):


modelscope/utils/device.py (+4 -7)

@@ -8,12 +8,6 @@ from modelscope.utils.logger import get_logger

logger = get_logger()

if is_tf_available():
import tensorflow as tf

if is_torch_available():
import torch


def verify_device(device_name):
""" Verify device is valid, device should be either cpu, cuda, gpu, cuda:X or gpu:X.
@@ -63,6 +57,7 @@ def device_placement(framework, device_name='gpu:0'):
device_type, device_id = verify_device(device_name)

if framework == Frameworks.tf:
import tensorflow as tf
if device_type == Devices.gpu and not tf.test.is_gpu_available():
logger.warning(
'tensorflow cuda is not available, using cpu instead.')
@@ -76,6 +71,7 @@ def device_placement(framework, device_name='gpu:0'):
yield

elif framework == Frameworks.torch:
import torch
if device_type == Devices.gpu:
if torch.cuda.is_available():
torch.cuda.set_device(f'cuda:{device_id}')
@@ -86,12 +82,13 @@ def device_placement(framework, device_name='gpu:0'):
yield


def create_device(device_name) -> torch.DeviceObjType:
def create_device(device_name):
""" create torch device

Args:
device_name (str): cpu, gpu, gpu:0, cuda:0 etc.
"""
import torch
device_type, device_id = verify_device(device_name)
use_cuda = False
if device_type == Devices.gpu:

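The change above moves the framework imports from module level into the functions, so importing modelscope.utils.device no longer imports torch and tensorflow unconditionally. A minimal standalone sketch of the lazy-import pattern (not the repository function itself):

def create_device(device_name: str = 'cpu'):
    import torch  # imported on first use, not when the module is imported
    if device_name.startswith(('gpu', 'cuda')) and torch.cuda.is_available():
        _, _, idx = device_name.partition(':')
        return torch.device(f'cuda:{idx or 0}')
    return torch.device('cpu')

print(create_device('cpu'))    # cpu
print(create_device('gpu:0'))  # cuda:0 when a GPU is available, otherwise cpu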

tests/isolated_cases.txt (+6 -0)

@@ -0,0 +1,6 @@
test_text_to_speech.py
test_multi_modal_embedding.py
test_ofa_tasks.py
test_video_summarization.py
test_dialog_modeling.py
test_csanmt_translation.py

tests/pipelines/test_multi_modal_embedding.py (+3 -4)

@@ -31,11 +31,10 @@ class MultiModalEmbeddingTest(unittest.TestCase):

@unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
def test_run_with_model_from_modelhub(self):
model = Model.from_pretrained(self.model_id)
model = Model.from_pretrained(
self.model_id, revision=self.model_version)
pipeline_multi_modal_embedding = pipeline(
task=Tasks.multi_modal_embedding,
model=model,
model_revision=self.model_version)
task=Tasks.multi_modal_embedding, model=model)
text_embedding = pipeline_multi_modal_embedding(
self.test_input)[OutputKeys.TEXT_EMBEDDING]
print('l1-norm: {}'.format(


tests/run.py (+278 -6)

@@ -2,11 +2,20 @@
# Copyright (c) Alibaba, Inc. and its affiliates.

import argparse
import datetime
import multiprocessing
import os
import subprocess
import sys
import tempfile
import unittest
from fnmatch import fnmatch
from multiprocessing.managers import BaseManager
from pathlib import Path
from unittest import TestResult, TextTestResult

import pandas
# NOTICE: Tensorflow 1.15 does not seem fully compatible with pytorch.
# A segmentation fault may be raised by the pytorch cpp library
# if 'import tensorflow' appears before 'import torch'.
@@ -19,6 +28,227 @@ from modelscope.utils.test_utils import set_test_level, test_level
logger = get_logger()


def test_cases_result_to_df(result_list):
table_header = [
'Name', 'Result', 'Info', 'Start time', 'Stop time',
'Time cost(seconds)'
]
df = pandas.DataFrame(
result_list, columns=table_header).sort_values(
by=['Start time'], ascending=True)
return df


def statistics_test_result(df):
total_cases = df.shape[0]
# yapf: disable
success_cases = df.loc[df['Result'] == 'Success'].shape[0]
error_cases = df.loc[df['Result'] == 'Error'].shape[0]
failures_cases = df.loc[df['Result'] == 'Failures'].shape[0]
expected_failure_cases = df.loc[df['Result'] == 'ExpectedFailures'].shape[0]
unexpected_success_cases = df.loc[df['Result'] == 'UnexpectedSuccesses'].shape[0]
skipped_cases = df.loc[df['Result'] == 'Skipped'].shape[0]
# yapf: enable

if failures_cases > 0 or \
error_cases > 0 or \
unexpected_success_cases > 0:
result = 'FAILED'
else:
result = 'SUCCESS'
result_msg = '%s (Runs=%s,success=%s,failures=%s,errors=%s,\
skipped=%s,expected failures=%s,unexpected successes=%s)' % (
result, total_cases, success_cases, failures_cases, error_cases,
skipped_cases, expected_failure_cases, unexpected_success_cases)

print(result_msg)
if result == 'FAILED':
sys.exit(1)


def gather_test_suites_in_files(test_dir, case_file_list, list_tests):
test_suite = unittest.TestSuite()
for case in case_file_list:
test_case = unittest.defaultTestLoader.discover(
start_dir=test_dir, pattern=case)
test_suite.addTest(test_case)
if hasattr(test_case, '__iter__'):
for subcase in test_case:
if list_tests:
print(subcase)
else:
if list_tests:
print(test_case)
return test_suite


def gather_test_suites_files(test_dir, pattern):
case_file_list = []
for dirpath, dirnames, filenames in os.walk(test_dir):
for file in filenames:
if fnmatch(file, pattern):
case_file_list.append(file)
return case_file_list


def collect_test_results(case_results):
result_list = [
] # each item is Case, Result, Start time, Stop time, Time cost
for case_result in case_results.successes:
result_list.append(
(case_result.test_full_name, 'Success', '', case_result.start_time,
case_result.stop_time, case_result.time_cost))
for case_result in case_results.errors:
result_list.append(
(case_result[0].test_full_name, 'Error', case_result[1],
case_result[0].start_time, case_result[0].stop_time,
case_result[0].time_cost))
for case_result in case_results.skipped:
result_list.append(
(case_result[0].test_full_name, 'Skipped', case_result[1],
case_result[0].start_time, case_result[0].stop_time,
case_result[0].time_cost))
for case_result in case_results.expectedFailures:
result_list.append(
(case_result[0].test_full_name, 'ExpectedFailures', case_result[1],
case_result[0].start_time, case_result[0].stop_time,
case_result[0].time_cost))
for case_result in case_results.failures:
result_list.append(
(case_result[0].test_full_name, 'Failures', case_result[1],
case_result[0].start_time, case_result[0].stop_time,
case_result[0].time_cost))
for case_result in case_results.unexpectedSuccesses:
result_list.append((case_result.test_full_name, 'UnexpectedSuccesses',
'', case_result.start_time, case_result.stop_time,
case_result.time_cost))
return result_list


class TestSuiteRunner:

def run(self, msg_queue, test_dir, test_suite_file):
test_suite = unittest.TestSuite()
test_case = unittest.defaultTestLoader.discover(
start_dir=test_dir, pattern=test_suite_file)
test_suite.addTest(test_case)
runner = TimeCostTextTestRunner()
test_suite_result = runner.run(test_suite)
msg_queue.put(collect_test_results(test_suite_result))


def run_command_with_popen(cmd):
with subprocess.Popen(
cmd,
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
bufsize=1,
encoding='utf8') as sub_process:
for line in iter(sub_process.stdout.readline, ''):
sys.stdout.write(line)


def run_in_subprocess(args):
# only the cases listed in args.isolated_cases run in separate subprocesses; all others run in the current process
test_suite_files = gather_test_suites_files(
os.path.abspath(args.test_dir), args.pattern)

if args.subprocess: # run all case in subprocess
isolated_cases = test_suite_files
else:
isolated_cases = []
with open(args.isolated_cases, 'r') as f:
for line in f:
if line.strip() in test_suite_files:
isolated_cases.append(line.strip())

if not args.list_tests:
with tempfile.TemporaryDirectory() as temp_result_dir:
for test_suite_file in isolated_cases: # run case in subprocess
cmd = [
'python', 'tests/run.py', '--pattern', test_suite_file,
'--result_dir', temp_result_dir
]
run_command_with_popen(cmd)
result_dfs = []
# run remain cases in a process.
remain_suite_files = [
item for item in test_suite_files if item not in isolated_cases
]
test_suite = gather_test_suites_in_files(args.test_dir,
remain_suite_files,
args.list_tests)
if test_suite.countTestCases() > 0:
runner = TimeCostTextTestRunner()
result = runner.run(test_suite)
result = collect_test_results(result)
df = test_cases_result_to_df(result)
result_dfs.append(df)

# collect test results
result_path = Path(temp_result_dir)
for result in result_path.iterdir():
if Path.is_file(result):
df = pandas.read_pickle(result)
result_dfs.append(df)

result_pd = pandas.concat(
result_dfs) # merge result of every test suite.
print_table_result(result_pd)
print_abnormal_case_info(result_pd)
statistics_test_result(result_pd)


def get_object_full_name(obj):
klass = obj.__class__
module = klass.__module__
if module == 'builtins':
return klass.__qualname__
return module + '.' + klass.__qualname__


class TimeCostTextTestResult(TextTestResult):
"""Record test case time used!"""

def __init__(self, stream, descriptions, verbosity):
self.successes = []
return super(TimeCostTextTestResult,
self).__init__(stream, descriptions, verbosity)

def startTest(self, test):
test.start_time = datetime.datetime.now()
test.test_full_name = get_object_full_name(
test) + '.' + test._testMethodName
self.stream.writeln('Test case: %s start at: %s' %
(test.test_full_name, test.start_time))

return super(TimeCostTextTestResult, self).startTest(test)

def stopTest(self, test):
TextTestResult.stopTest(self, test)
test.stop_time = datetime.datetime.now()
test.time_cost = (test.stop_time - test.start_time).total_seconds()
self.stream.writeln(
'Test case: %s stop at: %s, cost time: %s(seconds)' %
(test.test_full_name, test.stop_time, test.time_cost))
super(TimeCostTextTestResult, self).stopTest(test)

def addSuccess(self, test):
self.successes.append(test)
super(TextTestResult, self).addSuccess(test)


class TimeCostTextTestRunner(unittest.runner.TextTestRunner):
resultclass = TimeCostTextTestResult

def run(self, test):
return super(TimeCostTextTestRunner, self).run(test)

def _makeResult(self):
result = super(TimeCostTextTestRunner, self)._makeResult()
return result


def gather_test_cases(test_dir, pattern, list_tests):
case_list = []
for dirpath, dirnames, filenames in os.walk(test_dir):
@@ -42,16 +272,40 @@ def gather_test_cases(test_dir, pattern, list_tests):
return test_suite


def print_abnormal_case_info(df):
df = df.loc[(df['Result'] == 'Error') | (df['Result'] == 'Failures')]
for _, row in df.iterrows():
print('Case %s run result: %s, msg:\n%s' %
(row['Name'], row['Result'], row['Info']))


def print_table_result(df):
df = df.loc[df['Result'] != 'Skipped']
df = df.drop('Info', axis=1)
formatters = {
'Name': '{{:<{}s}}'.format(df['Name'].str.len().max()).format,
'Result': '{{:<{}s}}'.format(df['Result'].str.len().max()).format,
}
with pandas.option_context('display.max_rows', None, 'display.max_columns',
None, 'display.width', None):
print(df.to_string(justify='left', formatters=formatters, index=False))


def main(args):
runner = unittest.TextTestRunner()
runner = TimeCostTextTestRunner()
test_suite = gather_test_cases(
os.path.abspath(args.test_dir), args.pattern, args.list_tests)
if not args.list_tests:
result = runner.run(test_suite)
if len(result.failures) > 0:
sys.exit(len(result.failures))
if len(result.errors) > 0:
sys.exit(len(result.errors))
result = collect_test_results(result)
df = test_cases_result_to_df(result)
if args.result_dir is not None:
file_name = str(int(datetime.datetime.now().timestamp() * 1000))
df.to_pickle(os.path.join(args.result_dir, file_name))
else:
print_table_result(df)
print_abnormal_case_info(df)
statistics_test_result(df)


if __name__ == '__main__':
@@ -66,6 +320,18 @@ if __name__ == '__main__':
'--level', default=0, type=int, help='2 -- all, 1 -- p1, 0 -- p0')
parser.add_argument(
'--disable_profile', action='store_true', help='disable profiling')
parser.add_argument(
'--isolated_cases',
default=None,
help='specified isolated cases config file')
parser.add_argument(
'--subprocess',
action='store_true',
help='run all test suite in subprocess')
parser.add_argument(
'--result_dir',
default=None,
help='Save result to directory, internal use only')
args = parser.parse_args()
set_test_level(args.level)
logger.info(f'TEST LEVEL: {test_level()}')
@@ -73,4 +339,10 @@ if __name__ == '__main__':
from utils import profiler
logger.info('enable profile ...')
profiler.enable()
main(args)
if args.isolated_cases is not None and args.subprocess:
print('isolated_cases and subprocess conflict')
sys.exit(1)
elif args.isolated_cases is not None or args.subprocess:
run_in_subprocess(args)
else:
main(args)
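A minimal sketch (file names and timestamps are illustrative) of the result-merging scheme implemented above: every isolated subprocess writes its per-case DataFrame into --result_dir with pandas.to_pickle, and the parent process reads each file back and concatenates them before printing the summary and statistics.

import tempfile
from pathlib import Path

import pandas

columns = ['Name', 'Result', 'Info', 'Start time', 'Stop time',
           'Time cost(seconds)']

with tempfile.TemporaryDirectory() as result_dir:
    # Each subprocess would write one pickle named by a millisecond timestamp.
    pandas.DataFrame([('test_a.TestA.test_run', 'Success', '',
                       '10:00:00', '10:00:05', 5.0)],
                     columns=columns).to_pickle(Path(result_dir) / '1700000000000')
    pandas.DataFrame([('test_b.TestB.test_run', 'Failures', 'boom',
                       '10:00:01', '10:00:03', 2.0)],
                     columns=columns).to_pickle(Path(result_dir) / '1700000000001')

    # The parent process concatenates everything it finds in result_dir.
    merged = pandas.concat(
        pandas.read_pickle(p) for p in Path(result_dir).iterdir())
    print(merged.to_string(index=False))

In this commit's CI script the runner is invoked as python tests/run.py --subprocess, so every suite file gets its own process; --isolated_cases instead takes a file listing only the suites to isolate, and tests/isolated_cases.txt added here appears intended for that use.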

tests/trainers/test_image_color_enhance_trainer.py (+35 -41)

@@ -17,6 +17,41 @@ from modelscope.utils.constant import ModelFile
from modelscope.utils.test_utils import test_level


class PairedImageDataset(data.Dataset):

def __init__(self, root):
super(PairedImageDataset, self).__init__()
gt_dir = osp.join(root, 'gt')
lq_dir = osp.join(root, 'lq')
self.gt_filelist = os.listdir(gt_dir)
self.gt_filelist = sorted(self.gt_filelist, key=lambda x: int(x[:-4]))
self.gt_filelist = [osp.join(gt_dir, f) for f in self.gt_filelist]
self.lq_filelist = os.listdir(lq_dir)
self.lq_filelist = sorted(self.lq_filelist, key=lambda x: int(x[:-4]))
self.lq_filelist = [osp.join(lq_dir, f) for f in self.lq_filelist]

def _img_to_tensor(self, img):
return torch.from_numpy(img[:, :, [2, 1, 0]]).permute(2, 0, 1).type(
torch.float32) / 255.

def __getitem__(self, index):
lq = cv2.imread(self.lq_filelist[index])
gt = cv2.imread(self.gt_filelist[index])
lq = cv2.resize(lq, (256, 256), interpolation=cv2.INTER_CUBIC)
gt = cv2.resize(gt, (256, 256), interpolation=cv2.INTER_CUBIC)
return \
{'src': self._img_to_tensor(lq), 'target': self._img_to_tensor(gt)}

def __len__(self):
return len(self.gt_filelist)

def to_torch_dataset(self,
columns: Union[str, List[str]] = None,
preprocessors: Union[Callable, List[Callable]] = None,
**format_kwargs):
return self


class TestImageColorEnhanceTrainer(unittest.TestCase):

def setUp(self):
@@ -27,47 +62,6 @@ class TestImageColorEnhanceTrainer(unittest.TestCase):

self.model_id = 'damo/cv_csrnet_image-color-enhance-models'

class PairedImageDataset(data.Dataset):

def __init__(self, root):
super(PairedImageDataset, self).__init__()
gt_dir = osp.join(root, 'gt')
lq_dir = osp.join(root, 'lq')
self.gt_filelist = os.listdir(gt_dir)
self.gt_filelist = sorted(
self.gt_filelist, key=lambda x: int(x[:-4]))
self.gt_filelist = [
osp.join(gt_dir, f) for f in self.gt_filelist
]
self.lq_filelist = os.listdir(lq_dir)
self.lq_filelist = sorted(
self.lq_filelist, key=lambda x: int(x[:-4]))
self.lq_filelist = [
osp.join(lq_dir, f) for f in self.lq_filelist
]

def _img_to_tensor(self, img):
return torch.from_numpy(img[:, :, [2, 1, 0]]).permute(
2, 0, 1).type(torch.float32) / 255.

def __getitem__(self, index):
lq = cv2.imread(self.lq_filelist[index])
gt = cv2.imread(self.gt_filelist[index])
lq = cv2.resize(lq, (256, 256), interpolation=cv2.INTER_CUBIC)
gt = cv2.resize(gt, (256, 256), interpolation=cv2.INTER_CUBIC)
return \
{'src': self._img_to_tensor(lq), 'target': self._img_to_tensor(gt)}

def __len__(self):
return len(self.gt_filelist)

def to_torch_dataset(self,
columns: Union[str, List[str]] = None,
preprocessors: Union[Callable,
List[Callable]] = None,
**format_kwargs):
return self

self.dataset = PairedImageDataset(
'./data/test/images/image_color_enhance/')



tests/trainers/test_image_portrait_enhancement_trainer.py (+41 -47)

@@ -19,6 +19,47 @@ from modelscope.utils.constant import ModelFile
from modelscope.utils.test_utils import test_level


class PairedImageDataset(data.Dataset):

def __init__(self, root, size=512):
super(PairedImageDataset, self).__init__()
self.size = size
gt_dir = osp.join(root, 'gt')
lq_dir = osp.join(root, 'lq')
self.gt_filelist = os.listdir(gt_dir)
self.gt_filelist = sorted(self.gt_filelist, key=lambda x: int(x[:-4]))
self.gt_filelist = [osp.join(gt_dir, f) for f in self.gt_filelist]
self.lq_filelist = os.listdir(lq_dir)
self.lq_filelist = sorted(self.lq_filelist, key=lambda x: int(x[:-4]))
self.lq_filelist = [osp.join(lq_dir, f) for f in self.lq_filelist]

def _img_to_tensor(self, img):
img = torch.from_numpy(img[:, :, [2, 1, 0]]).permute(2, 0, 1).type(
torch.float32) / 255.
return (img - 0.5) / 0.5

def __getitem__(self, index):
lq = cv2.imread(self.lq_filelist[index])
gt = cv2.imread(self.gt_filelist[index])
lq = cv2.resize(
lq, (self.size, self.size), interpolation=cv2.INTER_CUBIC)
gt = cv2.resize(
gt, (self.size, self.size), interpolation=cv2.INTER_CUBIC)

return \
{'src': self._img_to_tensor(lq), 'target': self._img_to_tensor(gt)}

def __len__(self):
return len(self.gt_filelist)

def to_torch_dataset(self,
columns: Union[str, List[str]] = None,
preprocessors: Union[Callable, List[Callable]] = None,
**format_kwargs):
# self.preprocessor = preprocessors
return self


class TestImagePortraitEnhancementTrainer(unittest.TestCase):

def setUp(self):
@@ -29,53 +70,6 @@ class TestImagePortraitEnhancementTrainer(unittest.TestCase):

self.model_id = 'damo/cv_gpen_image-portrait-enhancement'

class PairedImageDataset(data.Dataset):

def __init__(self, root, size=512):
super(PairedImageDataset, self).__init__()
self.size = size
gt_dir = osp.join(root, 'gt')
lq_dir = osp.join(root, 'lq')
self.gt_filelist = os.listdir(gt_dir)
self.gt_filelist = sorted(
self.gt_filelist, key=lambda x: int(x[:-4]))
self.gt_filelist = [
osp.join(gt_dir, f) for f in self.gt_filelist
]
self.lq_filelist = os.listdir(lq_dir)
self.lq_filelist = sorted(
self.lq_filelist, key=lambda x: int(x[:-4]))
self.lq_filelist = [
osp.join(lq_dir, f) for f in self.lq_filelist
]

def _img_to_tensor(self, img):
img = torch.from_numpy(img[:, :, [2, 1, 0]]).permute(
2, 0, 1).type(torch.float32) / 255.
return (img - 0.5) / 0.5

def __getitem__(self, index):
lq = cv2.imread(self.lq_filelist[index])
gt = cv2.imread(self.gt_filelist[index])
lq = cv2.resize(
lq, (self.size, self.size), interpolation=cv2.INTER_CUBIC)
gt = cv2.resize(
gt, (self.size, self.size), interpolation=cv2.INTER_CUBIC)

return \
{'src': self._img_to_tensor(lq), 'target': self._img_to_tensor(gt)}

def __len__(self):
return len(self.gt_filelist)

def to_torch_dataset(self,
columns: Union[str, List[str]] = None,
preprocessors: Union[Callable,
List[Callable]] = None,
**format_kwargs):
# self.preprocessor = preprocessors
return self

self.dataset = PairedImageDataset(
'./data/test/images/face_enhancement/')



tests/trainers/test_trainer.py (+4 -3)

@@ -16,6 +16,7 @@ from modelscope.metainfo import Metrics, Trainers
from modelscope.metrics.builder import MetricKeys
from modelscope.models.base import Model
from modelscope.trainers import build_trainer
from modelscope.trainers.base import DummyTrainer
from modelscope.utils.constant import LogKeys, ModeKeys, ModelFile
from modelscope.utils.test_utils import create_dummy_test_dataset, test_level

@@ -264,7 +265,7 @@ class TrainerTest(unittest.TestCase):
{
LogKeys.MODE: ModeKeys.EVAL,
LogKeys.EPOCH: 1,
LogKeys.ITER: 20
LogKeys.ITER: 10
}, json.loads(lines[2]))
self.assertDictContainsSubset(
{
@@ -284,7 +285,7 @@ class TrainerTest(unittest.TestCase):
{
LogKeys.MODE: ModeKeys.EVAL,
LogKeys.EPOCH: 2,
LogKeys.ITER: 20
LogKeys.ITER: 10
}, json.loads(lines[5]))
self.assertDictContainsSubset(
{
@@ -304,7 +305,7 @@ class TrainerTest(unittest.TestCase):
{
LogKeys.MODE: ModeKeys.EVAL,
LogKeys.EPOCH: 3,
LogKeys.ITER: 20
LogKeys.ITER: 10
}, json.loads(lines[8]))
self.assertIn(f'{LogKeys.EPOCH}_1.pth', results_files)
self.assertIn(f'{LogKeys.EPOCH}_2.pth', results_files)

