Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10453584master
| @@ -459,3 +459,4 @@ class Datasets(object): | |||
| SegDataset = 'SegDataset' | |||
| DetDataset = 'DetDataset' | |||
| DetImagesMixDataset = 'DetImagesMixDataset' | |||
| PairedDataset = 'PairedDataset' | |||
| @@ -2,6 +2,7 @@ | |||
| # https://github.com/XPixelGroup/BasicSR/blob/master/basicsr/metrics/psnr_ssim.py | |||
| from typing import Dict | |||
| import cv2 | |||
| import numpy as np | |||
| from modelscope.metainfo import Metrics | |||
| @@ -37,6 +38,7 @@ class ImagePortraitEnhancementMetric(Metric): | |||
| def add(self, outputs: Dict, inputs: Dict): | |||
| ground_truths = outputs['target'] | |||
| eval_results = outputs['pred'] | |||
| self.preds.extend(eval_results) | |||
| self.targets.extend(ground_truths) | |||
| @@ -35,7 +35,7 @@ class ImagePortraitEnhancement(TorchModel): | |||
| """ | |||
| super().__init__(model_dir, *args, **kwargs) | |||
| self.size = 512 | |||
| self.size = 256 | |||
| self.style_dim = 512 | |||
| self.n_mlp = 8 | |||
| self.mean_path_length = 0 | |||
| @@ -131,9 +131,9 @@ class ImagePortraitEnhancement(TorchModel): | |||
| return path_penalty, path_mean.detach(), path_lengths | |||
| @torch.no_grad() | |||
| def _evaluate_postprocess(self, src: Tensor, | |||
| def _evaluate_postprocess(self, input: Tensor, | |||
| target: Tensor) -> Dict[str, list]: | |||
| preds, _ = self.generator(src) | |||
| preds, _ = self.generator(input) | |||
| preds = list(torch.split(preds, 1, 0)) | |||
| targets = list(torch.split(target, 1, 0)) | |||
| @@ -144,11 +144,11 @@ class ImagePortraitEnhancement(TorchModel): | |||
| return {'pred': preds, 'target': targets} | |||
| def _train_forward_d(self, src: Tensor, target: Tensor) -> Tensor: | |||
| def _train_forward_d(self, input: Tensor, target: Tensor) -> Tensor: | |||
| self.requires_grad(self.generator, False) | |||
| self.requires_grad(self.discriminator, True) | |||
| preds, _ = self.generator(src) | |||
| preds, _ = self.generator(input) | |||
| fake_pred = self.discriminator(preds) | |||
| real_pred = self.discriminator(target) | |||
| @@ -156,27 +156,27 @@ class ImagePortraitEnhancement(TorchModel): | |||
| return d_loss | |||
| def _train_forward_d_r1(self, src: Tensor, target: Tensor) -> Tensor: | |||
| src.requires_grad = True | |||
| def _train_forward_d_r1(self, input: Tensor, target: Tensor) -> Tensor: | |||
| input.requires_grad = True | |||
| target.requires_grad = True | |||
| real_pred = self.discriminator(target) | |||
| r1_loss = self.d_r1_loss(real_pred, target) | |||
| return r1_loss | |||
| def _train_forward_g(self, src: Tensor, target: Tensor) -> Tensor: | |||
| def _train_forward_g(self, input: Tensor, target: Tensor) -> Tensor: | |||
| self.requires_grad(self.generator, True) | |||
| self.requires_grad(self.discriminator, False) | |||
| preds, _ = self.generator(src) | |||
| preds, _ = self.generator(input) | |||
| fake_pred = self.discriminator(preds) | |||
| g_loss = self.g_nonsaturating_loss(fake_pred, preds, target, src) | |||
| g_loss = self.g_nonsaturating_loss(fake_pred, preds, target, input) | |||
| return g_loss | |||
| def _train_forward_g_path(self, src: Tensor, target: Tensor) -> Tensor: | |||
| fake_img, latents = self.generator(src, return_latents=True) | |||
| def _train_forward_g_path(self, input: Tensor, target: Tensor) -> Tensor: | |||
| fake_img, latents = self.generator(input, return_latents=True) | |||
| path_loss, self.mean_path_length, path_lengths = self.g_path_regularize( | |||
| fake_img, latents, self.mean_path_length) | |||
| @@ -184,8 +184,8 @@ class ImagePortraitEnhancement(TorchModel): | |||
| return path_loss | |||
| @torch.no_grad() | |||
| def _inference_forward(self, src: Tensor) -> Dict[str, Tensor]: | |||
| return {'outputs': (self.generator(src)[0] * 0.5 + 0.5).clamp(0, 1)} | |||
| def _inference_forward(self, input: Tensor) -> Dict[str, Tensor]: | |||
| return {'outputs': (self.generator(input)[0] * 0.5 + 0.5).clamp(0, 1)} | |||
| def forward(self, input: Dict[str, | |||
| Tensor]) -> Dict[str, Union[list, Tensor]]: | |||
| @@ -27,6 +27,8 @@ else: | |||
| 'movie_scene_segmentation': ['MovieSceneSegmentationDataset'], | |||
| 'image_inpainting': ['ImageInpaintingDataset'], | |||
| 'sidd_image_denoising_dataset': ['SiddImageDenoisingDataset'], | |||
| 'image_portrait_enhancement_dataset': | |||
| ['ImagePortraitEnhancementDataset'], | |||
| } | |||
| import sys | |||
| @@ -0,0 +1,23 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| from typing import TYPE_CHECKING | |||
| from modelscope.utils.import_utils import LazyImportModule | |||
| if TYPE_CHECKING: | |||
| from .image_portrait_enhancement_dataset import ImagePortraitEnhancementDataset | |||
| else: | |||
| _import_structure = { | |||
| 'image_portrait_enhancement_dataset': | |||
| ['ImagePortraitEnhancementDataset'], | |||
| } | |||
| import sys | |||
| sys.modules[__name__] = LazyImportModule( | |||
| __name__, | |||
| globals()['__file__'], | |||
| _import_structure, | |||
| module_spec=__spec__, | |||
| extra_objects={}, | |||
| ) | |||
| @@ -0,0 +1,32 @@ | |||
| # ------------------------------------------------------------------------ | |||
| # Modified from BasicSR (https://github.com/xinntao/BasicSR) | |||
| # Copyright 2018-2020 BasicSR Authors | |||
| # ------------------------------------------------------------------------ | |||
| import cv2 | |||
| import torch | |||
def img2tensor(imgs, bgr2rgb=True, float32=True):
    """Convert HWC numpy image(s) to CHW torch tensor(s).

    Args:
        imgs (list[ndarray] | ndarray): Input image(s) in HWC layout.
        bgr2rgb (bool): Whether to reverse the channel order of
            3-channel images (BGR -> RGB).
        float32 (bool): Whether to cast the result to float32.

    Returns:
        list[tensor] | tensor: A list of tensors when a list is given,
            otherwise a single tensor.
    """

    def _totensor(img, bgr2rgb, float32):
        # HWC -> CHW; the transposed view has positive strides, which is
        # all torch.from_numpy requires.
        tensor = torch.from_numpy(img.transpose(2, 0, 1))
        if bgr2rgb and tensor.shape[0] == 3:
            # Swap the channels on the tensor instead of calling
            # cv2.cvtColor, which rejects some depths (e.g. float64).
            # Indexing with a list yields a new tensor, so the caller's
            # array is never modified.
            tensor = tensor[[2, 1, 0]]
        if float32:
            tensor = tensor.float()
        return tensor

    if isinstance(imgs, list):
        return [_totensor(img, bgr2rgb, float32) for img in imgs]
    return _totensor(imgs, bgr2rgb, float32)
| @@ -0,0 +1,51 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| import cv2 | |||
| import numpy as np | |||
| from modelscope.metainfo import Datasets, Models | |||
| from modelscope.msdatasets.task_datasets.builder import TASK_DATASETS | |||
| from modelscope.msdatasets.task_datasets.torch_base_dataset import \ | |||
| TorchTaskDataset | |||
| from modelscope.utils.constant import Tasks | |||
| from .data_utils import img2tensor | |||
def default_loader(path):
    """Read an image file as a float32 array scaled to [0, 1].

    Args:
        path (str): Path of the image file to read.

    Returns:
        numpy.ndarray: The image decoded with ``cv2.IMREAD_COLOR``,
            converted to float32 and divided by 255.

    Raises:
        OSError: If OpenCV cannot read or decode the file.
    """
    img = cv2.imread(path, cv2.IMREAD_COLOR)
    if img is None:
        # cv2.imread reports failure by returning None rather than raising;
        # fail here with the offending path instead of an AttributeError.
        raise OSError('Failed to read image file: {}'.format(path))
    return img.astype(np.float32) / 255.0
@TASK_DATASETS.register_module(
    Tasks.image_portrait_enhancement, module_name=Datasets.PairedDataset)
class ImagePortraitEnhancementDataset(TorchTaskDataset):
    """Paired image dataset for image portrait enhancement.

    Every record of the wrapped dataset must provide the path of the
    high-quality image under ``'hq:FILE'`` and the path of the low-quality
    image under ``'lq:FILE'``.

    Args:
        dataset: Sized, indexable collection of such records.
        is_train (bool): Stored on the instance; not otherwise used by
            this class.
        gt_size (int): Side length that both images are resized to.
            Defaults to 256, the previously hard-coded value.
    """

    def __init__(self, dataset, is_train, gt_size=256):
        self.dataset = dataset
        self.gt_size = gt_size
        self.is_train = is_train

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, index):
        """Return the ``index``-th pair as ``{'input': lq, 'target': gt}``.

        Both values are tensors produced by ``img2tensor`` (BGR to RGB,
        HWC to CHW) and rescaled from [0, 1] to [-1, 1].
        """
        # Load gt and lq images. Dimension order: HWC; channel order: BGR;
        # image range: [0, 1], float32.
        item_dict = self.dataset[index]
        img_gt = default_loader(item_dict['hq:FILE'])
        img_lq = default_loader(item_dict['lq:FILE'])

        # Bring both images to the same square resolution.
        target_size = (self.gt_size, self.gt_size)
        img_gt = cv2.resize(img_gt, target_size)
        img_lq = cv2.resize(img_lq, target_size)

        # BGR to RGB, HWC to CHW, numpy to tensor
        img_gt, img_lq = img2tensor([img_gt, img_lq],
                                    bgr2rgb=True,
                                    float32=True)

        # Map [0, 1] to [-1, 1].
        return {'input': (img_lq - 0.5) / 0.5, 'target': (img_gt - 0.5) / 0.5}
| @@ -14,52 +14,14 @@ from modelscope.hub.snapshot_download import snapshot_download | |||
| from modelscope.metainfo import Trainers | |||
| from modelscope.models.cv.image_portrait_enhancement import \ | |||
| ImagePortraitEnhancement | |||
| from modelscope.msdatasets import MsDataset | |||
| from modelscope.msdatasets.task_datasets.image_portrait_enhancement import \ | |||
| ImagePortraitEnhancementDataset | |||
| from modelscope.trainers import build_trainer | |||
| from modelscope.utils.constant import ModelFile | |||
| from modelscope.utils.constant import DownloadMode, ModelFile | |||
| from modelscope.utils.test_utils import test_level | |||
class PairedImageDataset(data.Dataset):
    """Paired low-/high-quality image dataset read from a directory.

    ``root`` must contain two sub-directories, ``gt`` and ``lq``, whose
    files are named by an integer plus an extension (e.g. ``12.png``).
    Files are paired by their position after numeric sorting.

    Args:
        root (str): Directory holding the ``gt`` and ``lq`` folders.
        size (int): Side length both images are resized to.
    """

    def __init__(self, root, size=512):
        super(PairedImageDataset, self).__init__()
        self.size = size
        self.gt_filelist = self._list_images(osp.join(root, 'gt'))
        self.lq_filelist = self._list_images(osp.join(root, 'lq'))

    @staticmethod
    def _list_images(directory):
        """Return the files of ``directory`` sorted by their numeric stem."""
        # splitext copes with extensions of any length ('.png', '.jpeg'),
        # unlike slicing off a fixed four characters.
        names = sorted(
            os.listdir(directory),
            key=lambda name: int(osp.splitext(name)[0]))
        return [osp.join(directory, name) for name in names]

    def _img_to_tensor(self, img):
        """Convert an HWC BGR image in [0, 255] to a CHW RGB tensor in [-1, 1]."""
        img = torch.from_numpy(img[:, :, [2, 1, 0]]).permute(2, 0, 1).type(
            torch.float32) / 255.
        return (img - 0.5) / 0.5

    def _load(self, path):
        """Read ``path`` and resize it to ``size`` x ``size`` (bicubic)."""
        img = cv2.imread(path)
        if img is None:
            # cv2.imread returns None on failure instead of raising.
            raise OSError('Failed to read image file: {}'.format(path))
        return cv2.resize(
            img, (self.size, self.size), interpolation=cv2.INTER_CUBIC)

    def __getitem__(self, index):
        lq = self._load(self.lq_filelist[index])
        gt = self._load(self.gt_filelist[index])
        return \
            {'src': self._img_to_tensor(lq), 'target': self._img_to_tensor(gt)}

    def __len__(self):
        return len(self.gt_filelist)

    def to_torch_dataset(self,
                         columns: Union[str, List[str]] = None,
                         preprocessors: Union[Callable, List[Callable]] = None,
                         **format_kwargs):
        """Return ``self``; all arguments are accepted and ignored."""
        return self
| class TestImagePortraitEnhancementTrainer(unittest.TestCase): | |||
| def setUp(self): | |||
| @@ -70,8 +32,23 @@ class TestImagePortraitEnhancementTrainer(unittest.TestCase): | |||
| self.model_id = 'damo/cv_gpen_image-portrait-enhancement' | |||
| self.dataset = PairedImageDataset( | |||
| './data/test/images/face_enhancement/') | |||
| dataset_train = MsDataset.load( | |||
| 'image-portrait-enhancement-dataset', | |||
| namespace='modelscope', | |||
| subset_name='default', | |||
| split='test', | |||
| download_mode=DownloadMode.REUSE_DATASET_IF_EXISTS)._hf_ds | |||
| dataset_val = MsDataset.load( | |||
| 'image-portrait-enhancement-dataset', | |||
| namespace='modelscope', | |||
| subset_name='default', | |||
| split='test', | |||
| download_mode=DownloadMode.REUSE_DATASET_IF_EXISTS)._hf_ds | |||
| self.dataset_train = ImagePortraitEnhancementDataset( | |||
| dataset_train, is_train=True) | |||
| self.dataset_val = ImagePortraitEnhancementDataset( | |||
| dataset_val, is_train=False) | |||
| def tearDown(self): | |||
| shutil.rmtree(self.tmp_dir, ignore_errors=True) | |||
| @@ -81,8 +58,8 @@ class TestImagePortraitEnhancementTrainer(unittest.TestCase): | |||
| def test_trainer(self): | |||
| kwargs = dict( | |||
| model=self.model_id, | |||
| train_dataset=self.dataset, | |||
| eval_dataset=self.dataset, | |||
| train_dataset=self.dataset_train, | |||
| eval_dataset=self.dataset_val, | |||
| device='gpu', | |||
| work_dir=self.tmp_dir) | |||
| @@ -101,8 +78,8 @@ class TestImagePortraitEnhancementTrainer(unittest.TestCase): | |||
| kwargs = dict( | |||
| cfg_file=os.path.join(cache_path, ModelFile.CONFIGURATION), | |||
| model=model, | |||
| train_dataset=self.dataset, | |||
| eval_dataset=self.dataset, | |||
| train_dataset=self.dataset_train, | |||
| eval_dataset=self.dataset_val, | |||
| device='gpu', | |||
| max_epochs=2, | |||
| work_dir=self.tmp_dir) | |||