Browse Source

[to #42322933]update msdatasets for image-portrait-enhancement training

Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10453584
master
baiguan.yt yingda.chen 3 years ago
parent
commit
533ab3df63
8 changed files with 150 additions and 62 deletions
  1. +1
    -0
      modelscope/metainfo.py
  2. +2
    -0
      modelscope/metrics/image_portrait_enhancement_metric.py
  3. +14
    -14
      modelscope/models/cv/image_portrait_enhancement/image_portrait_enhancement.py
  4. +2
    -0
      modelscope/msdatasets/task_datasets/__init__.py
  5. +23
    -0
      modelscope/msdatasets/task_datasets/image_portrait_enhancement/__init__.py
  6. +32
    -0
      modelscope/msdatasets/task_datasets/image_portrait_enhancement/data_utils.py
  7. +51
    -0
      modelscope/msdatasets/task_datasets/image_portrait_enhancement/image_portrait_enhancement_dataset.py
  8. +25
    -48
      tests/trainers/test_image_portrait_enhancement_trainer.py

+ 1
- 0
modelscope/metainfo.py View File

@@ -459,3 +459,4 @@ class Datasets(object):
SegDataset = 'SegDataset'
DetDataset = 'DetDataset'
DetImagesMixDataset = 'DetImagesMixDataset'
PairedDataset = 'PairedDataset'

+ 2
- 0
modelscope/metrics/image_portrait_enhancement_metric.py View File

@@ -2,6 +2,7 @@
# https://github.com/XPixelGroup/BasicSR/blob/master/basicsr/metrics/psnr_ssim.py
from typing import Dict

import cv2
import numpy as np

from modelscope.metainfo import Metrics
@@ -37,6 +38,7 @@ class ImagePortraitEnhancementMetric(Metric):
def add(self, outputs: Dict, inputs: Dict):
ground_truths = outputs['target']
eval_results = outputs['pred']

self.preds.extend(eval_results)
self.targets.extend(ground_truths)



+ 14
- 14
modelscope/models/cv/image_portrait_enhancement/image_portrait_enhancement.py View File

@@ -35,7 +35,7 @@ class ImagePortraitEnhancement(TorchModel):
"""
super().__init__(model_dir, *args, **kwargs)

self.size = 512
self.size = 256
self.style_dim = 512
self.n_mlp = 8
self.mean_path_length = 0
@@ -131,9 +131,9 @@ class ImagePortraitEnhancement(TorchModel):
return path_penalty, path_mean.detach(), path_lengths

@torch.no_grad()
def _evaluate_postprocess(self, src: Tensor,
def _evaluate_postprocess(self, input: Tensor,
target: Tensor) -> Dict[str, list]:
preds, _ = self.generator(src)
preds, _ = self.generator(input)
preds = list(torch.split(preds, 1, 0))
targets = list(torch.split(target, 1, 0))

@@ -144,11 +144,11 @@ class ImagePortraitEnhancement(TorchModel):

return {'pred': preds, 'target': targets}

def _train_forward_d(self, src: Tensor, target: Tensor) -> Tensor:
def _train_forward_d(self, input: Tensor, target: Tensor) -> Tensor:
self.requires_grad(self.generator, False)
self.requires_grad(self.discriminator, True)

preds, _ = self.generator(src)
preds, _ = self.generator(input)
fake_pred = self.discriminator(preds)
real_pred = self.discriminator(target)

@@ -156,27 +156,27 @@ class ImagePortraitEnhancement(TorchModel):

return d_loss

def _train_forward_d_r1(self, src: Tensor, target: Tensor) -> Tensor:
src.requires_grad = True
def _train_forward_d_r1(self, input: Tensor, target: Tensor) -> Tensor:
input.requires_grad = True
target.requires_grad = True
real_pred = self.discriminator(target)
r1_loss = self.d_r1_loss(real_pred, target)

return r1_loss

def _train_forward_g(self, src: Tensor, target: Tensor) -> Tensor:
def _train_forward_g(self, input: Tensor, target: Tensor) -> Tensor:
self.requires_grad(self.generator, True)
self.requires_grad(self.discriminator, False)

preds, _ = self.generator(src)
preds, _ = self.generator(input)
fake_pred = self.discriminator(preds)

g_loss = self.g_nonsaturating_loss(fake_pred, preds, target, src)
g_loss = self.g_nonsaturating_loss(fake_pred, preds, target, input)

return g_loss

def _train_forward_g_path(self, src: Tensor, target: Tensor) -> Tensor:
fake_img, latents = self.generator(src, return_latents=True)
def _train_forward_g_path(self, input: Tensor, target: Tensor) -> Tensor:
fake_img, latents = self.generator(input, return_latents=True)

path_loss, self.mean_path_length, path_lengths = self.g_path_regularize(
fake_img, latents, self.mean_path_length)
@@ -184,8 +184,8 @@ class ImagePortraitEnhancement(TorchModel):
return path_loss

@torch.no_grad()
def _inference_forward(self, src: Tensor) -> Dict[str, Tensor]:
return {'outputs': (self.generator(src)[0] * 0.5 + 0.5).clamp(0, 1)}
def _inference_forward(self, input: Tensor) -> Dict[str, Tensor]:
return {'outputs': (self.generator(input)[0] * 0.5 + 0.5).clamp(0, 1)}

def forward(self, input: Dict[str,
Tensor]) -> Dict[str, Union[list, Tensor]]:


+ 2
- 0
modelscope/msdatasets/task_datasets/__init__.py View File

@@ -27,6 +27,8 @@ else:
'movie_scene_segmentation': ['MovieSceneSegmentationDataset'],
'image_inpainting': ['ImageInpaintingDataset'],
'sidd_image_denoising_dataset': ['SiddImageDenoisingDataset'],
'image_portrait_enhancement_dataset':
['ImagePortraitEnhancementDataset'],
}
import sys



+ 23
- 0
modelscope/msdatasets/task_datasets/image_portrait_enhancement/__init__.py View File

@@ -0,0 +1,23 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from typing import TYPE_CHECKING

from modelscope.utils.import_utils import LazyImportModule

if TYPE_CHECKING:
from .image_portrait_enhancement_dataset import ImagePortraitEnhancementDataset

else:
_import_structure = {
'image_portrait_enhancement_dataset':
['ImagePortraitEnhancementDataset'],
}

import sys

sys.modules[__name__] = LazyImportModule(
__name__,
globals()['__file__'],
_import_structure,
module_spec=__spec__,
extra_objects={},
)

+ 32
- 0
modelscope/msdatasets/task_datasets/image_portrait_enhancement/data_utils.py View File

@@ -0,0 +1,32 @@
# ------------------------------------------------------------------------
# Modified from BasicSR (https://github.com/xinntao/BasicSR)
# Copyright 2018-2020 BasicSR Authors
# ------------------------------------------------------------------------

import cv2
import torch


def img2tensor(imgs, bgr2rgb=True, float32=True):
"""Numpy array to tensor.
Args:
imgs (list[ndarray] | ndarray): Input images.
bgr2rgb (bool): Whether to change bgr to rgb.
float32 (bool): Whether to change to float32.
Returns:
list[tensor] | tensor: Tensor images. If returned results only have
one element, just return tensor.
"""

def _totensor(img, bgr2rgb, float32):
if img.shape[2] == 3 and bgr2rgb:
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
img = torch.from_numpy(img.transpose(2, 0, 1))
if float32:
img = img.float()
return img

if isinstance(imgs, list):
return [_totensor(img, bgr2rgb, float32) for img in imgs]
else:
return _totensor(imgs, bgr2rgb, float32)

+ 51
- 0
modelscope/msdatasets/task_datasets/image_portrait_enhancement/image_portrait_enhancement_dataset.py View File

@@ -0,0 +1,51 @@
# Copyright (c) Alibaba, Inc. and its affiliates.

import cv2
import numpy as np

from modelscope.metainfo import Datasets, Models
from modelscope.msdatasets.task_datasets.builder import TASK_DATASETS
from modelscope.msdatasets.task_datasets.torch_base_dataset import \
TorchTaskDataset
from modelscope.utils.constant import Tasks
from .data_utils import img2tensor


def default_loader(path):
return cv2.imread(path, cv2.IMREAD_COLOR).astype(np.float32) / 255.0


@TASK_DATASETS.register_module(
Tasks.image_portrait_enhancement, module_name=Datasets.PairedDataset)
class ImagePortraitEnhancementDataset(TorchTaskDataset):
"""Paired image dataset for image portrait enhancement.
"""

def __init__(self, dataset, is_train):
self.dataset = dataset
self.gt_size = 256
self.is_train = is_train

def __len__(self):
return len(self.dataset)

def __getitem__(self, index):

# Load gt and lq images. Dimension order: HWC; channel order: BGR;
# image range: [0, 1], float32.
item_dict = self.dataset[index]
gt_path = item_dict['hq:FILE']
img_gt = default_loader(gt_path)
lq_path = item_dict['lq:FILE']
img_lq = default_loader(lq_path)

gt_size = self.gt_size
img_gt = cv2.resize(img_gt, (gt_size, gt_size))
img_lq = cv2.resize(img_lq, (gt_size, gt_size))

# BGR to RGB, HWC to CHW, numpy to tensor
img_gt, img_lq = img2tensor([img_gt, img_lq],
bgr2rgb=True,
float32=True)

return {'input': (img_lq - 0.5) / 0.5, 'target': (img_gt - 0.5) / 0.5}

+ 25
- 48
tests/trainers/test_image_portrait_enhancement_trainer.py View File

@@ -14,52 +14,14 @@ from modelscope.hub.snapshot_download import snapshot_download
from modelscope.metainfo import Trainers
from modelscope.models.cv.image_portrait_enhancement import \
ImagePortraitEnhancement
from modelscope.msdatasets import MsDataset
from modelscope.msdatasets.task_datasets.image_portrait_enhancement import \
ImagePortraitEnhancementDataset
from modelscope.trainers import build_trainer
from modelscope.utils.constant import ModelFile
from modelscope.utils.constant import DownloadMode, ModelFile
from modelscope.utils.test_utils import test_level


class PairedImageDataset(data.Dataset):

def __init__(self, root, size=512):
super(PairedImageDataset, self).__init__()
self.size = size
gt_dir = osp.join(root, 'gt')
lq_dir = osp.join(root, 'lq')
self.gt_filelist = os.listdir(gt_dir)
self.gt_filelist = sorted(self.gt_filelist, key=lambda x: int(x[:-4]))
self.gt_filelist = [osp.join(gt_dir, f) for f in self.gt_filelist]
self.lq_filelist = os.listdir(lq_dir)
self.lq_filelist = sorted(self.lq_filelist, key=lambda x: int(x[:-4]))
self.lq_filelist = [osp.join(lq_dir, f) for f in self.lq_filelist]

def _img_to_tensor(self, img):
img = torch.from_numpy(img[:, :, [2, 1, 0]]).permute(2, 0, 1).type(
torch.float32) / 255.
return (img - 0.5) / 0.5

def __getitem__(self, index):
lq = cv2.imread(self.lq_filelist[index])
gt = cv2.imread(self.gt_filelist[index])
lq = cv2.resize(
lq, (self.size, self.size), interpolation=cv2.INTER_CUBIC)
gt = cv2.resize(
gt, (self.size, self.size), interpolation=cv2.INTER_CUBIC)

return \
{'src': self._img_to_tensor(lq), 'target': self._img_to_tensor(gt)}

def __len__(self):
return len(self.gt_filelist)

def to_torch_dataset(self,
columns: Union[str, List[str]] = None,
preprocessors: Union[Callable, List[Callable]] = None,
**format_kwargs):
# self.preprocessor = preprocessors
return self


class TestImagePortraitEnhancementTrainer(unittest.TestCase):

def setUp(self):
@@ -70,8 +32,23 @@ class TestImagePortraitEnhancementTrainer(unittest.TestCase):

self.model_id = 'damo/cv_gpen_image-portrait-enhancement'

self.dataset = PairedImageDataset(
'./data/test/images/face_enhancement/')
dataset_train = MsDataset.load(
'image-portrait-enhancement-dataset',
namespace='modelscope',
subset_name='default',
split='test',
download_mode=DownloadMode.REUSE_DATASET_IF_EXISTS)._hf_ds
dataset_val = MsDataset.load(
'image-portrait-enhancement-dataset',
namespace='modelscope',
subset_name='default',
split='test',
download_mode=DownloadMode.REUSE_DATASET_IF_EXISTS)._hf_ds

self.dataset_train = ImagePortraitEnhancementDataset(
dataset_train, is_train=True)
self.dataset_val = ImagePortraitEnhancementDataset(
dataset_val, is_train=False)

def tearDown(self):
shutil.rmtree(self.tmp_dir, ignore_errors=True)
@@ -81,8 +58,8 @@ class TestImagePortraitEnhancementTrainer(unittest.TestCase):
def test_trainer(self):
kwargs = dict(
model=self.model_id,
train_dataset=self.dataset,
eval_dataset=self.dataset,
train_dataset=self.dataset_train,
eval_dataset=self.dataset_val,
device='gpu',
work_dir=self.tmp_dir)

@@ -101,8 +78,8 @@ class TestImagePortraitEnhancementTrainer(unittest.TestCase):
kwargs = dict(
cfg_file=os.path.join(cache_path, ModelFile.CONFIGURATION),
model=model,
train_dataset=self.dataset,
eval_dataset=self.dataset,
train_dataset=self.dataset_train,
eval_dataset=self.dataset_val,
device='gpu',
max_epochs=2,
work_dir=self.tmp_dir)


Loading…
Cancel
Save