diff --git a/modelscope/metainfo.py b/modelscope/metainfo.py
index 732f8ffa..fa1605de 100644
--- a/modelscope/metainfo.py
+++ b/modelscope/metainfo.py
@@ -459,3 +459,4 @@ class Datasets(object):
     SegDataset = 'SegDataset'
     DetDataset = 'DetDataset'
     DetImagesMixDataset = 'DetImagesMixDataset'
+    PairedDataset = 'PairedDataset'
diff --git a/modelscope/metrics/image_portrait_enhancement_metric.py b/modelscope/metrics/image_portrait_enhancement_metric.py
index 5a81e956..7d94aade 100644
--- a/modelscope/metrics/image_portrait_enhancement_metric.py
+++ b/modelscope/metrics/image_portrait_enhancement_metric.py
@@ -2,6 +2,7 @@
 # https://github.com/XPixelGroup/BasicSR/blob/master/basicsr/metrics/psnr_ssim.py
 from typing import Dict
 
+import cv2
 import numpy as np
 
 from modelscope.metainfo import Metrics
@@ -37,6 +38,7 @@ class ImagePortraitEnhancementMetric(Metric):
     def add(self, outputs: Dict, inputs: Dict):
         ground_truths = outputs['target']
         eval_results = outputs['pred']
+
         self.preds.extend(eval_results)
         self.targets.extend(ground_truths)
 
diff --git a/modelscope/models/cv/image_portrait_enhancement/image_portrait_enhancement.py b/modelscope/models/cv/image_portrait_enhancement/image_portrait_enhancement.py
index 3650ac7b..26e9e532 100644
--- a/modelscope/models/cv/image_portrait_enhancement/image_portrait_enhancement.py
+++ b/modelscope/models/cv/image_portrait_enhancement/image_portrait_enhancement.py
@@ -35,7 +35,7 @@ class ImagePortraitEnhancement(TorchModel):
         """
         super().__init__(model_dir, *args, **kwargs)
 
-        self.size = 512
+        self.size = 256
         self.style_dim = 512
         self.n_mlp = 8
         self.mean_path_length = 0
@@ -131,9 +131,9 @@ class ImagePortraitEnhancement(TorchModel):
         return path_penalty, path_mean.detach(), path_lengths
 
     @torch.no_grad()
-    def _evaluate_postprocess(self, src: Tensor,
+    def _evaluate_postprocess(self, input: Tensor,
                               target: Tensor) -> Dict[str, list]:
-        preds, _ = self.generator(src)
+        preds, _ = self.generator(input)
         preds = list(torch.split(preds, 1, 0))
         targets = list(torch.split(target, 1, 0))
 
@@ -144,11 +144,11 @@ class ImagePortraitEnhancement(TorchModel):
 
         return {'pred': preds, 'target': targets}
 
-    def _train_forward_d(self, src: Tensor, target: Tensor) -> Tensor:
+    def _train_forward_d(self, input: Tensor, target: Tensor) -> Tensor:
         self.requires_grad(self.generator, False)
         self.requires_grad(self.discriminator, True)
 
-        preds, _ = self.generator(src)
+        preds, _ = self.generator(input)
         fake_pred = self.discriminator(preds)
         real_pred = self.discriminator(target)
 
@@ -156,27 +156,27 @@ class ImagePortraitEnhancement(TorchModel):
 
         return d_loss
 
-    def _train_forward_d_r1(self, src: Tensor, target: Tensor) -> Tensor:
-        src.requires_grad = True
+    def _train_forward_d_r1(self, input: Tensor, target: Tensor) -> Tensor:
+        input.requires_grad = True
         target.requires_grad = True
         real_pred = self.discriminator(target)
         r1_loss = self.d_r1_loss(real_pred, target)
 
         return r1_loss
 
-    def _train_forward_g(self, src: Tensor, target: Tensor) -> Tensor:
+    def _train_forward_g(self, input: Tensor, target: Tensor) -> Tensor:
         self.requires_grad(self.generator, True)
         self.requires_grad(self.discriminator, False)
 
-        preds, _ = self.generator(src)
+        preds, _ = self.generator(input)
         fake_pred = self.discriminator(preds)
 
-        g_loss = self.g_nonsaturating_loss(fake_pred, preds, target, src)
+        g_loss = self.g_nonsaturating_loss(fake_pred, preds, target, input)
 
         return g_loss
 
-    def _train_forward_g_path(self, src: Tensor, target: Tensor) -> Tensor:
-        fake_img, latents = self.generator(src, return_latents=True)
+    def _train_forward_g_path(self, input: Tensor, target: Tensor) -> Tensor:
+        fake_img, latents = self.generator(input, return_latents=True)
 
         path_loss, self.mean_path_length, path_lengths = self.g_path_regularize(
             fake_img, latents, self.mean_path_length)
@@ -184,8 +184,8 @@ class ImagePortraitEnhancement(TorchModel):
         return path_loss
 
     @torch.no_grad()
-    def _inference_forward(self, src: Tensor) -> Dict[str, Tensor]:
-        return {'outputs': (self.generator(src)[0] * 0.5 + 0.5).clamp(0, 1)}
+    def _inference_forward(self, input: Tensor) -> Dict[str, Tensor]:
+        return {'outputs': (self.generator(input)[0] * 0.5 + 0.5).clamp(0, 1)}
 
     def forward(self, input: Dict[str,
                                   Tensor]) -> Dict[str, Union[list, Tensor]]:
diff --git a/modelscope/msdatasets/task_datasets/__init__.py b/modelscope/msdatasets/task_datasets/__init__.py
index 7c31969a..914c41bf 100644
--- a/modelscope/msdatasets/task_datasets/__init__.py
+++ b/modelscope/msdatasets/task_datasets/__init__.py
@@ -27,6 +27,8 @@ else:
         'movie_scene_segmentation': ['MovieSceneSegmentationDataset'],
         'image_inpainting': ['ImageInpaintingDataset'],
         'sidd_image_denoising_dataset': ['SiddImageDenoisingDataset'],
+        'image_portrait_enhancement_dataset':
+        ['ImagePortraitEnhancementDataset'],
     }
 
     import sys
diff --git a/modelscope/msdatasets/task_datasets/image_portrait_enhancement/__init__.py b/modelscope/msdatasets/task_datasets/image_portrait_enhancement/__init__.py
new file mode 100644
index 00000000..4df24fae
--- /dev/null
+++ b/modelscope/msdatasets/task_datasets/image_portrait_enhancement/__init__.py
@@ -0,0 +1,23 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+from typing import TYPE_CHECKING
+
+from modelscope.utils.import_utils import LazyImportModule
+
+if TYPE_CHECKING:
+    from .image_portrait_enhancement_dataset import ImagePortraitEnhancementDataset
+
+else:
+    _import_structure = {
+        'image_portrait_enhancement_dataset':
+        ['ImagePortraitEnhancementDataset'],
+    }
+
+    import sys
+
+    sys.modules[__name__] = LazyImportModule(
+        __name__,
+        globals()['__file__'],
+        _import_structure,
+        module_spec=__spec__,
+        extra_objects={},
+    )
diff --git a/modelscope/msdatasets/task_datasets/image_portrait_enhancement/data_utils.py b/modelscope/msdatasets/task_datasets/image_portrait_enhancement/data_utils.py
new file mode 100644
index 00000000..1133d3c2
--- /dev/null
+++ b/modelscope/msdatasets/task_datasets/image_portrait_enhancement/data_utils.py
@@ -0,0 +1,36 @@
+# ------------------------------------------------------------------------
+# Modified from BasicSR (https://github.com/xinntao/BasicSR)
+# Copyright 2018-2020 BasicSR Authors
+# ------------------------------------------------------------------------
+
+import cv2
+import torch
+
+
+def img2tensor(imgs, bgr2rgb=True, float32=True):
+    """Convert HWC numpy image(s) to CHW torch tensor(s).
+
+    Args:
+        imgs (list[ndarray] | ndarray): Input image(s) in HWC layout.
+        bgr2rgb (bool): Whether to swap 3-channel images from BGR to RGB.
+        float32 (bool): Whether to cast the resulting tensor to float32.
+
+    Returns:
+        list[tensor] | tensor: A list of tensors when ``imgs`` is a list
+            (even a single-element one), otherwise a single tensor.
+    """
+
+    def _totensor(img, bgr2rgb, float32):
+        if img.shape[2] == 3 and bgr2rgb:
+            # cv2.cvtColor does not accept float64 input.
+            if img.dtype == 'float64':
+                img = img.astype('float32')
+            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+        img = torch.from_numpy(img.transpose(2, 0, 1))
+        if float32:
+            img = img.float()
+        return img
+
+    if isinstance(imgs, list):
+        return [_totensor(img, bgr2rgb, float32) for img in imgs]
+    return _totensor(imgs, bgr2rgb, float32)
diff --git a/modelscope/msdatasets/task_datasets/image_portrait_enhancement/image_portrait_enhancement_dataset.py b/modelscope/msdatasets/task_datasets/image_portrait_enhancement/image_portrait_enhancement_dataset.py
new file mode 100644
index 00000000..58d40778
--- /dev/null
+++ b/modelscope/msdatasets/task_datasets/image_portrait_enhancement/image_portrait_enhancement_dataset.py
@@ -0,0 +1,71 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+
+import cv2
+import numpy as np
+
+from modelscope.metainfo import Datasets, Models
+from modelscope.msdatasets.task_datasets.builder import TASK_DATASETS
+from modelscope.msdatasets.task_datasets.torch_base_dataset import \
+    TorchTaskDataset
+from modelscope.utils.constant import Tasks
+from .data_utils import img2tensor
+
+
+def default_loader(path):
+    """Read an image as a float32 BGR HWC array scaled to [0, 1].
+
+    Args:
+        path (str): Path of the image file.
+
+    Raises:
+        FileNotFoundError: If OpenCV cannot read ``path``.
+    """
+    img = cv2.imread(path, cv2.IMREAD_COLOR)
+    if img is None:
+        # cv2.imread reports failure by returning None, not by raising.
+        raise FileNotFoundError('Failed to read image: {}'.format(path))
+    return img.astype(np.float32) / 255.0
+
+
+@TASK_DATASETS.register_module(
+    Tasks.image_portrait_enhancement, module_name=Datasets.PairedDataset)
+class ImagePortraitEnhancementDataset(TorchTaskDataset):
+    """Paired low/high-quality image dataset for portrait enhancement.
+
+    Args:
+        dataset: Sized, indexable collection whose items map the keys
+            ``'hq:FILE'`` and ``'lq:FILE'`` to image paths.
+        is_train (bool): Stored as ``self.is_train``; not otherwise used by
+            this class.
+        gt_size (int): Side length both images are resized to.
+    """
+
+    def __init__(self, dataset, is_train, gt_size=256):
+        self.dataset = dataset
+        self.gt_size = gt_size
+        self.is_train = is_train
+
+    def __len__(self):
+        return len(self.dataset)
+
+    def __getitem__(self, index):
+        """Return the pair at ``index`` as ``{'input': lq, 'target': gt}``.
+
+        Both values are RGB, CHW float32 tensors scaled to [-1, 1].
+        """
+        # Load gt and lq images. Dimension order: HWC; channel order: BGR;
+        # image range: [0, 1], float32.
+        item_dict = self.dataset[index]
+        img_gt = default_loader(item_dict['hq:FILE'])
+        img_lq = default_loader(item_dict['lq:FILE'])
+
+        size = (self.gt_size, self.gt_size)
+        img_gt = cv2.resize(img_gt, size)
+        img_lq = cv2.resize(img_lq, size)
+
+        # BGR to RGB, HWC to CHW, numpy to tensor
+        img_gt, img_lq = img2tensor([img_gt, img_lq],
+                                    bgr2rgb=True,
+                                    float32=True)
+
+        return {'input': (img_lq - 0.5) / 0.5, 'target': (img_gt - 0.5) / 0.5}
diff --git a/tests/trainers/test_image_portrait_enhancement_trainer.py b/tests/trainers/test_image_portrait_enhancement_trainer.py
index 049adf7e..5c47a59b 100644
--- a/tests/trainers/test_image_portrait_enhancement_trainer.py
+++ b/tests/trainers/test_image_portrait_enhancement_trainer.py
@@ -14,52 +14,14 @@ from modelscope.hub.snapshot_download import snapshot_download
 from modelscope.metainfo import Trainers
 from modelscope.models.cv.image_portrait_enhancement import \
     ImagePortraitEnhancement
+from modelscope.msdatasets import MsDataset
+from modelscope.msdatasets.task_datasets.image_portrait_enhancement import \
+    ImagePortraitEnhancementDataset
 from modelscope.trainers import build_trainer
-from modelscope.utils.constant import ModelFile
+from modelscope.utils.constant import DownloadMode, ModelFile
 from modelscope.utils.test_utils import test_level
 
 
-class PairedImageDataset(data.Dataset):
-
-    def __init__(self, root, size=512):
-        super(PairedImageDataset, self).__init__()
-        self.size = size
-        gt_dir = osp.join(root, 'gt')
-        lq_dir = osp.join(root, 'lq')
-        self.gt_filelist = os.listdir(gt_dir)
-        self.gt_filelist = sorted(self.gt_filelist, key=lambda x: int(x[:-4]))
-        self.gt_filelist = [osp.join(gt_dir, f) for f in self.gt_filelist]
-        self.lq_filelist = os.listdir(lq_dir)
-        self.lq_filelist = sorted(self.lq_filelist, key=lambda x: int(x[:-4]))
-        self.lq_filelist = [osp.join(lq_dir, f) for f in self.lq_filelist]
-
-    def _img_to_tensor(self, img):
-        img = torch.from_numpy(img[:, :, [2, 1, 0]]).permute(2, 0, 1).type(
-            torch.float32) / 255.
-        return (img - 0.5) / 0.5
-
-    def __getitem__(self, index):
-        lq = cv2.imread(self.lq_filelist[index])
-        gt = cv2.imread(self.gt_filelist[index])
-        lq = cv2.resize(
-            lq, (self.size, self.size), interpolation=cv2.INTER_CUBIC)
-        gt = cv2.resize(
-            gt, (self.size, self.size), interpolation=cv2.INTER_CUBIC)
-
-        return \
-            {'src': self._img_to_tensor(lq), 'target': self._img_to_tensor(gt)}
-
-    def __len__(self):
-        return len(self.gt_filelist)
-
-    def to_torch_dataset(self,
-                         columns: Union[str, List[str]] = None,
-                         preprocessors: Union[Callable, List[Callable]] = None,
-                         **format_kwargs):
-        # self.preprocessor = preprocessors
-        return self
-
-
 class TestImagePortraitEnhancementTrainer(unittest.TestCase):
 
     def setUp(self):
@@ -70,8 +32,23 @@ class TestImagePortraitEnhancementTrainer(unittest.TestCase):
 
         self.model_id = 'damo/cv_gpen_image-portrait-enhancement'
 
-        self.dataset = PairedImageDataset(
-            './data/test/images/face_enhancement/')
+        dataset_train = MsDataset.load(
+            'image-portrait-enhancement-dataset',
+            namespace='modelscope',
+            subset_name='default',
+            split='test',
+            download_mode=DownloadMode.REUSE_DATASET_IF_EXISTS)._hf_ds
+        dataset_val = MsDataset.load(
+            'image-portrait-enhancement-dataset',
+            namespace='modelscope',
+            subset_name='default',
+            split='test',
+            download_mode=DownloadMode.REUSE_DATASET_IF_EXISTS)._hf_ds
+
+        self.dataset_train = ImagePortraitEnhancementDataset(
+            dataset_train, is_train=True)
+        self.dataset_val = ImagePortraitEnhancementDataset(
+            dataset_val, is_train=False)
 
     def tearDown(self):
         shutil.rmtree(self.tmp_dir, ignore_errors=True)
@@ -81,8 +58,8 @@ class TestImagePortraitEnhancementTrainer(unittest.TestCase):
     def test_trainer(self):
         kwargs = dict(
             model=self.model_id,
-            train_dataset=self.dataset,
-            eval_dataset=self.dataset,
+            train_dataset=self.dataset_train,
+            eval_dataset=self.dataset_val,
             device='gpu',
             work_dir=self.tmp_dir)
 
@@ -101,8 +78,8 @@ class TestImagePortraitEnhancementTrainer(unittest.TestCase):
         kwargs = dict(
             cfg_file=os.path.join(cache_path, ModelFile.CONFIGURATION),
             model=model,
-            train_dataset=self.dataset,
-            eval_dataset=self.dataset,
+            train_dataset=self.dataset_train,
+            eval_dataset=self.dataset_val,
             device='gpu',
             max_epochs=2,
             work_dir=self.tmp_dir)