diff --git a/data/test/images/Solvay_conference_1927.png b/data/test/images/Solvay_conference_1927.png new file mode 100755 index 00000000..0c97101d --- /dev/null +++ b/data/test/images/Solvay_conference_1927.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:fa8ab905e8374a0f94b4bfbfc81da14e762c71eaf64bae85bdd03b07cdf884c2 +size 859206 diff --git a/data/test/images/face_enhancement/gt/000000.jpg b/data/test/images/face_enhancement/gt/000000.jpg new file mode 100644 index 00000000..13c18e3b --- /dev/null +++ b/data/test/images/face_enhancement/gt/000000.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:8cd14710143ba1a912e3ef574d0bf71c7e40bf9897522cba07ecae2567343064 +size 850603 diff --git a/data/test/images/face_enhancement/gt/000001.jpg b/data/test/images/face_enhancement/gt/000001.jpg new file mode 100644 index 00000000..d0b7afc0 --- /dev/null +++ b/data/test/images/face_enhancement/gt/000001.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d7f166ecb3a6913dbd05a1eb271399cbaa731d1074ac03184c13ae245ca66819 +size 800380 diff --git a/data/test/images/face_enhancement/lq/000000.png b/data/test/images/face_enhancement/lq/000000.png new file mode 100644 index 00000000..8503d219 --- /dev/null +++ b/data/test/images/face_enhancement/lq/000000.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e95d11661485fc0e6f326398f953459dcb3e65b7f4a6c892611266067cf8fe3a +size 245773 diff --git a/data/test/images/face_enhancement/lq/000001.png b/data/test/images/face_enhancement/lq/000001.png new file mode 100644 index 00000000..9afb2a0e --- /dev/null +++ b/data/test/images/face_enhancement/lq/000001.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:03972400b20b3e6f1d056b359d9c9f12952653a67a73b36018504ce9ee9edf9d +size 254261 diff --git a/modelscope/metainfo.py b/modelscope/metainfo.py index d35d6e86..2a7a9c0a 100644 --- a/modelscope/metainfo.py +++ b/modelscope/metainfo.py @@ -16,6 +16,7 @@ class Models(object): nafnet = 'nafnet' csrnet = 'csrnet' cascade_mask_rcnn_swin = 'cascade_mask_rcnn_swin' + gpen = 'gpen' product_retrieval_embedding = 'product-retrieval-embedding' # nlp models @@ -91,6 +92,7 @@ class Pipelines(object): image2image_translation = 'image-to-image-translation' live_category = 'live-category' video_category = 'video-category' + image_portrait_enhancement = 'gpen-image-portrait-enhancement' image_to_image_generation = 'image-to-image-generation' # nlp tasks @@ -160,6 +162,7 @@ class Preprocessors(object): image_denoie_preprocessor = 'image-denoise-preprocessor' image_color_enhance_preprocessor = 'image-color-enhance-preprocessor' image_instance_segmentation_preprocessor = 'image-instance-segmentation-preprocessor' + image_portrait_enhancement_preprocessor = 'image-portrait-enhancement-preprocessor' # nlp preprocessor sen_sim_tokenizer = 'sen-sim-tokenizer' @@ -207,3 +210,5 @@ class Metrics(object): text_gen_metric = 'text-gen-metric' # metrics for image-color-enhance task image_color_enhance_metric = 'image-color-enhance-metric' + # metrics for image-portrait-enhancement task + image_portrait_enhancement_metric = 'image-portrait-enhancement-metric' diff --git a/modelscope/metrics/__init__.py b/modelscope/metrics/__init__.py index 54ee705a..c632a9bd 100644 --- a/modelscope/metrics/__init__.py +++ b/modelscope/metrics/__init__.py @@ -10,6 +10,7 @@ if TYPE_CHECKING: from .image_denoise_metric import ImageDenoiseMetric from .image_instance_segmentation_metric import \ 
ImageInstanceSegmentationCOCOMetric + from .image_portrait_enhancement_metric import ImagePortraitEnhancementMetric from .sequence_classification_metric import SequenceClassificationMetric from .text_generation_metric import TextGenerationMetric @@ -21,6 +22,8 @@ else: 'image_denoise_metric': ['ImageDenoiseMetric'], 'image_instance_segmentation_metric': ['ImageInstanceSegmentationCOCOMetric'], + 'image_portrait_enhancement_metric': + ['ImagePortraitEnhancementMetric'], 'sequence_classification_metric': ['SequenceClassificationMetric'], 'text_generation_metric': ['TextGenerationMetric'], } diff --git a/modelscope/metrics/builder.py b/modelscope/metrics/builder.py index 5b9f962e..4df856f2 100644 --- a/modelscope/metrics/builder.py +++ b/modelscope/metrics/builder.py @@ -23,7 +23,9 @@ task_default_metrics = { Tasks.sentiment_classification: [Metrics.seq_cls_metric], Tasks.text_generation: [Metrics.text_gen_metric], Tasks.image_denoising: [Metrics.image_denoise_metric], - Tasks.image_color_enhancement: [Metrics.image_color_enhance_metric] + Tasks.image_color_enhancement: [Metrics.image_color_enhance_metric], + Tasks.image_portrait_enhancement: + [Metrics.image_portrait_enhancement_metric], } diff --git a/modelscope/metrics/image_portrait_enhancement_metric.py b/modelscope/metrics/image_portrait_enhancement_metric.py new file mode 100644 index 00000000..b8412b9e --- /dev/null +++ b/modelscope/metrics/image_portrait_enhancement_metric.py @@ -0,0 +1,47 @@ +from typing import Dict + +import numpy as np + +from modelscope.metainfo import Metrics +from modelscope.utils.registry import default_group +from .base import Metric +from .builder import METRICS, MetricKeys + + +def calculate_psnr(img, img2): + assert img.shape == img2.shape, ( + f'Image shapes are different: {img.shape}, {img2.shape}.') + + img = img.astype(np.float64) + img2 = img2.astype(np.float64) + + mse = np.mean((img - img2)**2) + if mse == 0: + return float('inf') + return 10. * np.log10(255. * 255. / mse) + + +@METRICS.register_module( + group_key=default_group, + module_name=Metrics.image_portrait_enhancement_metric) +class ImagePortraitEnhancementMetric(Metric): + """The metric for image-portrait-enhancement task. + """ + + def __init__(self): + self.preds = [] + self.targets = [] + + def add(self, outputs: Dict, inputs: Dict): + ground_truths = outputs['target'] + eval_results = outputs['pred'] + self.preds.extend(eval_results) + self.targets.extend(ground_truths) + + def evaluate(self): + psnrs = [ + calculate_psnr(pred, target) + for pred, target in zip(self.preds, self.targets) + ] + + return {MetricKeys.PSNR: sum(psnrs) / len(psnrs)} diff --git a/modelscope/models/cv/__init__.py b/modelscope/models/cv/__init__.py index 3a8a0e55..beeb0994 100644 --- a/modelscope/models/cv/__init__.py +++ b/modelscope/models/cv/__init__.py @@ -3,6 +3,6 @@ from . 
import (action_recognition, animal_recognition, cartoon, cmdssl_video_embedding, face_detection, face_generation, image_classification, image_color_enhance, image_colorization, image_denoise, image_instance_segmentation, - image_to_image_generation, image_to_image_translation, - object_detection, product_retrieval_embedding, super_resolution, - virual_tryon) + image_portrait_enhancement, image_to_image_generation, + image_to_image_translation, object_detection, + product_retrieval_embedding, super_resolution, virual_tryon) diff --git a/modelscope/models/cv/image_portrait_enhancement/__init__.py b/modelscope/models/cv/image_portrait_enhancement/__init__.py new file mode 100644 index 00000000..4014bb15 --- /dev/null +++ b/modelscope/models/cv/image_portrait_enhancement/__init__.py @@ -0,0 +1,22 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from typing import TYPE_CHECKING + +from modelscope.utils.import_utils import LazyImportModule + +if TYPE_CHECKING: + from .image_portrait_enhancement import ImagePortraitEnhancement + +else: + _import_structure = { + 'image_portrait_enhancement': ['ImagePortraitEnhancement'] + } + + import sys + + sys.modules[__name__] = LazyImportModule( + __name__, + globals()['__file__'], + _import_structure, + module_spec=__spec__, + extra_objects={}, + ) diff --git a/modelscope/models/cv/image_portrait_enhancement/align_faces.py b/modelscope/models/cv/image_portrait_enhancement/align_faces.py new file mode 100755 index 00000000..776b06d8 --- /dev/null +++ b/modelscope/models/cv/image_portrait_enhancement/align_faces.py @@ -0,0 +1,252 @@ +import cv2 +import numpy as np +from skimage import transform as trans + +from modelscope.utils.logger import get_logger + +logger = get_logger() + +# reference facial points, a list of coordinates (x,y) +REFERENCE_FACIAL_POINTS = [[30.29459953, 51.69630051], + [65.53179932, 51.50139999], + [48.02519989, + 71.73660278], [33.54930115, 92.3655014], + [62.72990036, 92.20410156]] + +DEFAULT_CROP_SIZE = (96, 112) + + +def _umeyama(src, dst, estimate_scale=True, scale=1.0): + """Estimate N-D similarity transformation with or without scaling. + Parameters + ---------- + src : (M, N) array + Source coordinates. + dst : (M, N) array + Destination coordinates. + estimate_scale : bool + Whether to estimate scaling factor. + Returns + ------- + T : (N + 1, N + 1) + The homogeneous similarity transformation matrix. The matrix contains + NaN values only if the problem is not well-conditioned. + References + ---------- + .. [1] "Least-squares estimation of transformation parameters between two + point patterns", Shinji Umeyama, PAMI 1991, :DOI:`10.1109/34.88573` + """ + + num = src.shape[0] + dim = src.shape[1] + + # Compute mean of src and dst. + src_mean = src.mean(axis=0) + dst_mean = dst.mean(axis=0) + + # Subtract mean from src and dst. + src_demean = src - src_mean + dst_demean = dst - dst_mean + + # Eq. (38). + A = dst_demean.T @ src_demean / num + + # Eq. (39). + d = np.ones((dim, ), dtype=np.double) + if np.linalg.det(A) < 0: + d[dim - 1] = -1 + + T = np.eye(dim + 1, dtype=np.double) + + U, S, V = np.linalg.svd(A) + + # Eq. (40) and (43). + rank = np.linalg.matrix_rank(A) + if rank == 0: + return np.nan * T + elif rank == dim - 1: + if np.linalg.det(U) * np.linalg.det(V) > 0: + T[:dim, :dim] = U @ V + else: + s = d[dim - 1] + d[dim - 1] = -1 + T[:dim, :dim] = U @ np.diag(d) @ V + d[dim - 1] = s + else: + T[:dim, :dim] = U @ np.diag(d) @ V + + if estimate_scale: + # Eq. (41) and (42). 
+ scale = 1.0 / src_demean.var(axis=0).sum() * (S @ d) + else: + scale = scale + + T[:dim, dim] = dst_mean - scale * (T[:dim, :dim] @ src_mean.T) + T[:dim, :dim] *= scale + + return T, scale + + +class FaceWarpException(Exception): + + def __str__(self): + return 'In File {}:{}'.format(__file__, super.__str__(self)) + + +def get_reference_facial_points(output_size=None, + inner_padding_factor=0.0, + outer_padding=(0, 0), + default_square=False): + ref_5pts = np.array(REFERENCE_FACIAL_POINTS) + ref_crop_size = np.array(DEFAULT_CROP_SIZE) + + # 0) make the inner region a square + if default_square: + size_diff = max(ref_crop_size) - ref_crop_size + ref_5pts += size_diff / 2 + ref_crop_size += size_diff + + if (output_size and output_size[0] == ref_crop_size[0] + and output_size[1] == ref_crop_size[1]): + return ref_5pts + + if (inner_padding_factor == 0 and outer_padding == (0, 0)): + if output_size is None: + logger.info('No paddings to do: return default reference points') + return ref_5pts + else: + raise FaceWarpException( + 'No paddings to do, output_size must be None or {}'.format( + ref_crop_size)) + + # check output size + if not (0 <= inner_padding_factor <= 1.0): + raise FaceWarpException('Not (0 <= inner_padding_factor <= 1.0)') + + if ((inner_padding_factor > 0 or outer_padding[0] > 0 + or outer_padding[1] > 0) and output_size is None): + output_size = ref_crop_size * (1 + inner_padding_factor * 2).astype( + np.int32) + output_size += np.array(outer_padding) + logger.info('deduced from paddings, output_size = ', output_size) + + if not (outer_padding[0] < output_size[0] + and outer_padding[1] < output_size[1]): + raise FaceWarpException('Not (outer_padding[0] < output_size[0]' + 'and outer_padding[1] < output_size[1])') + + # 1) pad the inner region according inner_padding_factor + if inner_padding_factor > 0: + size_diff = ref_crop_size * inner_padding_factor * 2 + ref_5pts += size_diff / 2 + ref_crop_size += np.round(size_diff).astype(np.int32) + + # 2) resize the padded inner region + size_bf_outer_pad = np.array(output_size) - np.array(outer_padding) * 2 + + if size_bf_outer_pad[0] * ref_crop_size[1] != size_bf_outer_pad[ + 1] * ref_crop_size[0]: + raise FaceWarpException( + 'Must have (output_size - outer_padding)' + '= some_scale * (crop_size * (1.0 + inner_padding_factor)') + + scale_factor = size_bf_outer_pad[0].astype(np.float32) / ref_crop_size[0] + ref_5pts = ref_5pts * scale_factor + ref_crop_size = size_bf_outer_pad + + # 3) add outer_padding to make output_size + reference_5point = ref_5pts + np.array(outer_padding) + ref_crop_size = output_size + + return reference_5point + + +def get_affine_transform_matrix(src_pts, dst_pts): + tfm = np.float32([[1, 0, 0], [0, 1, 0]]) + n_pts = src_pts.shape[0] + ones = np.ones((n_pts, 1), src_pts.dtype) + src_pts_ = np.hstack([src_pts, ones]) + dst_pts_ = np.hstack([dst_pts, ones]) + + A, res, rank, s = np.linalg.lstsq(src_pts_, dst_pts_) + + if rank == 3: + tfm = np.float32([[A[0, 0], A[1, 0], A[2, 0]], + [A[0, 1], A[1, 1], A[2, 1]]]) + elif rank == 2: + tfm = np.float32([[A[0, 0], A[1, 0], 0], [A[0, 1], A[1, 1], 0]]) + + return tfm + + +def get_params(reference_pts, facial_pts, align_type): + ref_pts = np.float32(reference_pts) + ref_pts_shp = ref_pts.shape + if max(ref_pts_shp) < 3 or min(ref_pts_shp) != 2: + raise FaceWarpException( + 'reference_pts.shape must be (K,2) or (2,K) and K>2') + + if ref_pts_shp[0] == 2: + ref_pts = ref_pts.T + + src_pts = np.float32(facial_pts) + src_pts_shp = src_pts.shape + if max(src_pts_shp) < 
3 or min(src_pts_shp) != 2: + raise FaceWarpException( + 'facial_pts.shape must be (K,2) or (2,K) and K>2') + + if src_pts_shp[0] == 2: + src_pts = src_pts.T + + if src_pts.shape != ref_pts.shape: + raise FaceWarpException( + 'facial_pts and reference_pts must have the same shape') + + if align_type == 'cv2_affine': + tfm = cv2.getAffineTransform(src_pts[0:3], ref_pts[0:3]) + tfm_inv = cv2.getAffineTransform(ref_pts[0:3], src_pts[0:3]) + elif align_type == 'affine': + tfm = get_affine_transform_matrix(src_pts, ref_pts) + tfm_inv = get_affine_transform_matrix(ref_pts, src_pts) + else: + params, scale = _umeyama(src_pts, ref_pts) + tfm = params[:2, :] + + params, _ = _umeyama(ref_pts, src_pts, False, scale=1.0 / scale) + tfm_inv = params[:2, :] + + return tfm, tfm_inv + + +def warp_and_crop_face(src_img, + facial_pts, + reference_pts=None, + crop_size=(96, 112), + align_type='smilarity'): # smilarity cv2_affine affine + + reference_pts_112 = get_reference_facial_points((112, 112), 0.25, (0, 0), + True) + if reference_pts is None: + if crop_size[0] == 96 and crop_size[1] == 112: + reference_pts = REFERENCE_FACIAL_POINTS + else: + default_square = True # False + inner_padding_factor = 0.25 # 0 + outer_padding = (0, 0) + output_size = crop_size + reference_pts = get_reference_facial_points( + output_size, inner_padding_factor, outer_padding, + default_square) + + tfm, tfm_inv = get_params(reference_pts, facial_pts, align_type) + tfm_112, tfm_inv_112 = get_params(reference_pts_112, facial_pts, + align_type) + + if src_img is not None: + face_img = cv2.warpAffine( + src_img, tfm, (crop_size[0], crop_size[1]), flags=3) + face_img_112 = cv2.warpAffine(src_img, tfm_112, (112, 112), flags=3) + + return face_img, face_img_112, tfm_inv + else: + return tfm, tfm_inv diff --git a/modelscope/models/cv/image_portrait_enhancement/eqface/fqa.py b/modelscope/models/cv/image_portrait_enhancement/eqface/fqa.py new file mode 100755 index 00000000..936bed9a --- /dev/null +++ b/modelscope/models/cv/image_portrait_enhancement/eqface/fqa.py @@ -0,0 +1,57 @@ +import os + +import cv2 +import numpy as np +import torch + +from .model_resnet import FaceQuality, ResNet + + +class FQA(object): + + def __init__(self, backbone_path, quality_path, device='cuda', size=112): + self.BACKBONE = ResNet(num_layers=100, feature_dim=512) + self.QUALITY = FaceQuality(512 * 7 * 7) + self.size = size + self.device = device + + self.load_model(backbone_path, quality_path) + + def load_model(self, backbone_path, quality_path): + checkpoint = torch.load(backbone_path, map_location='cpu') + self.load_state_dict(self.BACKBONE, checkpoint) + + checkpoint = torch.load(quality_path, map_location='cpu') + self.load_state_dict(self.QUALITY, checkpoint) + + self.BACKBONE.to(self.device) + self.QUALITY.to(self.device) + self.BACKBONE.eval() + self.QUALITY.eval() + + def load_state_dict(self, model, state_dict): + all_keys = {k for k in state_dict.keys()} + for k in all_keys: + if k.startswith('module.'): + state_dict[k[7:]] = state_dict.pop(k) + model_dict = model.state_dict() + pretrained_dict = { + k: v + for k, v in state_dict.items() + if k in model_dict and v.size() == model_dict[k].size() + } + + model_dict.update(pretrained_dict) + model.load_state_dict(model_dict) + + def get_face_quality(self, img): + img = torch.from_numpy(img).permute(2, 0, + 1).unsqueeze(0).flip(1).cuda() + img = (img - 127.5) / 128.0 + + # extract features & predict quality + with torch.no_grad(): + feature, fc = self.BACKBONE(img.to(self.device), True) + s = 
self.QUALITY(fc)[0] + + return s.cpu().numpy()[0], feature.cpu().numpy()[0] diff --git a/modelscope/models/cv/image_portrait_enhancement/eqface/model_resnet.py b/modelscope/models/cv/image_portrait_enhancement/eqface/model_resnet.py new file mode 100644 index 00000000..ea3c4f2a --- /dev/null +++ b/modelscope/models/cv/image_portrait_enhancement/eqface/model_resnet.py @@ -0,0 +1,130 @@ +import torch +from torch import nn + + +class BottleNeck_IR(nn.Module): + + def __init__(self, in_channel, out_channel, stride, dim_match): + super(BottleNeck_IR, self).__init__() + self.res_layer = nn.Sequential( + nn.BatchNorm2d(in_channel), + nn.Conv2d(in_channel, out_channel, (3, 3), 1, 1, bias=False), + nn.BatchNorm2d(out_channel), nn.PReLU(out_channel), + nn.Conv2d(out_channel, out_channel, (3, 3), stride, 1, bias=False), + nn.BatchNorm2d(out_channel)) + if dim_match: + self.shortcut_layer = None + else: + self.shortcut_layer = nn.Sequential( + nn.Conv2d( + in_channel, + out_channel, + kernel_size=(1, 1), + stride=stride, + bias=False), nn.BatchNorm2d(out_channel)) + + def forward(self, x): + shortcut = x + res = self.res_layer(x) + + if self.shortcut_layer is not None: + shortcut = self.shortcut_layer(x) + + return shortcut + res + + +channel_list = [64, 64, 128, 256, 512] + + +def get_layers(num_layers): + if num_layers == 34: + return [3, 4, 6, 3] + if num_layers == 50: + return [3, 4, 14, 3] + elif num_layers == 100: + return [3, 13, 30, 3] + elif num_layers == 152: + return [3, 8, 36, 3] + + +class ResNet(nn.Module): + + def __init__(self, + num_layers=100, + feature_dim=512, + drop_ratio=0.4, + channel_list=channel_list): + super(ResNet, self).__init__() + assert num_layers in [34, 50, 100, 152] + layers = get_layers(num_layers) + block = BottleNeck_IR + + self.input_layer = nn.Sequential( + nn.Conv2d( + 3, channel_list[0], (3, 3), stride=1, padding=1, bias=False), + nn.BatchNorm2d(channel_list[0]), nn.PReLU(channel_list[0])) + self.layer1 = self._make_layer( + block, channel_list[0], channel_list[1], layers[0], stride=2) + self.layer2 = self._make_layer( + block, channel_list[1], channel_list[2], layers[1], stride=2) + self.layer3 = self._make_layer( + block, channel_list[2], channel_list[3], layers[2], stride=2) + self.layer4 = self._make_layer( + block, channel_list[3], channel_list[4], layers[3], stride=2) + + self.output_layer = nn.Sequential( + nn.BatchNorm2d(512), nn.Dropout(drop_ratio), nn.Flatten()) + self.feature_layer = nn.Sequential( + nn.Linear(512 * 7 * 7, feature_dim), nn.BatchNorm1d(feature_dim)) + + for m in self.modules(): + if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear): + nn.init.xavier_uniform_(m.weight) + if m.bias is not None: + nn.init.constant_(m.bias, 0.0) + elif isinstance(m, nn.BatchNorm2d) or isinstance( + m, nn.BatchNorm1d): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + + def _make_layer(self, block, in_channel, out_channel, blocks, stride): + layers = [] + layers.append(block(in_channel, out_channel, stride, False)) + for i in range(1, blocks): + layers.append(block(out_channel, out_channel, 1, True)) + return nn.Sequential(*layers) + + def forward(self, x, fc=False): + x = self.input_layer(x) + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.layer4(x) + x = self.output_layer(x) + feature = self.feature_layer(x) + if fc: + return feature, x + return feature + + +class FaceQuality(nn.Module): + + def __init__(self, feature_dim): + super(FaceQuality, self).__init__() + self.qualtiy = nn.Sequential( + 
nn.Linear(feature_dim, 512, bias=False), nn.BatchNorm1d(512), + nn.ReLU(inplace=True), nn.Linear(512, 2, bias=False), + nn.Softmax(dim=1)) + for m in self.modules(): + if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear): + nn.init.xavier_uniform_(m.weight) + if m.bias is not None: + nn.init.constant_(m.bias, 0.0) + elif isinstance(m, nn.BatchNorm2d) or isinstance( + m, nn.BatchNorm1d): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + + def forward(self, x): + x = self.qualtiy(x) + return x[:, 0:1] diff --git a/modelscope/models/cv/image_portrait_enhancement/gpen.py b/modelscope/models/cv/image_portrait_enhancement/gpen.py new file mode 100755 index 00000000..2e21dbc0 --- /dev/null +++ b/modelscope/models/cv/image_portrait_enhancement/gpen.py @@ -0,0 +1,813 @@ +import functools +import itertools +import math +import operator +import random + +import torch +from torch import nn +from torch.autograd import Function +from torch.nn import functional as F + +from modelscope.models.cv.face_generation.op import (FusedLeakyReLU, + fused_leaky_relu, + upfirdn2d) + + +class PixelNorm(nn.Module): + + def __init__(self): + super().__init__() + + def forward(self, input): + return input * torch.rsqrt( + torch.mean(input**2, dim=1, keepdim=True) + 1e-8) + + +def make_kernel(k): + k = torch.tensor(k, dtype=torch.float32) + + if k.ndim == 1: + k = k[None, :] * k[:, None] + + k /= k.sum() + + return k + + +class Upsample(nn.Module): + + def __init__(self, kernel, factor=2): + super().__init__() + + self.factor = factor + kernel = make_kernel(kernel) * (factor**2) + self.register_buffer('kernel', kernel) + + p = kernel.shape[0] - factor + + pad0 = (p + 1) // 2 + factor - 1 + pad1 = p // 2 + + self.pad = (pad0, pad1) + + def forward(self, input): + out = upfirdn2d( + input, self.kernel, up=self.factor, down=1, pad=self.pad) + + return out + + +class Downsample(nn.Module): + + def __init__(self, kernel, factor=2): + super().__init__() + + self.factor = factor + kernel = make_kernel(kernel) + self.register_buffer('kernel', kernel) + + p = kernel.shape[0] - factor + + pad0 = (p + 1) // 2 + pad1 = p // 2 + + self.pad = (pad0, pad1) + + def forward(self, input): + out = upfirdn2d( + input, self.kernel, up=1, down=self.factor, pad=self.pad) + + return out + + +class Blur(nn.Module): + + def __init__(self, kernel, pad, upsample_factor=1): + super().__init__() + + kernel = make_kernel(kernel) + + if upsample_factor > 1: + kernel = kernel * (upsample_factor**2) + + self.register_buffer('kernel', kernel) + + self.pad = pad + + def forward(self, input): + out = upfirdn2d(input, self.kernel, pad=self.pad) + + return out + + +class EqualConv2d(nn.Module): + + def __init__(self, + in_channel, + out_channel, + kernel_size, + stride=1, + padding=0, + bias=True): + super().__init__() + + self.weight = nn.Parameter( + torch.randn(out_channel, in_channel, kernel_size, kernel_size)) + self.scale = 1 / math.sqrt(in_channel * kernel_size**2) + + self.stride = stride + self.padding = padding + + if bias: + self.bias = nn.Parameter(torch.zeros(out_channel)) + + else: + self.bias = None + + def forward(self, input): + out = F.conv2d( + input, + self.weight * self.scale, + bias=self.bias, + stride=self.stride, + padding=self.padding, + ) + + return out + + def __repr__(self): + return ( + f'{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]},' + f' {self.weight.shape[2]}, stride={self.stride}, padding={self.padding})' + ) + + +class EqualLinear(nn.Module): + + def __init__(self, + in_dim, + 
out_dim, + bias=True, + bias_init=0, + lr_mul=1, + activation=None): + super().__init__() + + self.weight = nn.Parameter(torch.randn(out_dim, in_dim).div_(lr_mul)) + + if bias: + self.bias = nn.Parameter(torch.zeros(out_dim).fill_(bias_init)) + + else: + self.bias = None + + self.activation = activation + + self.scale = (1 / math.sqrt(in_dim)) * lr_mul + self.lr_mul = lr_mul + + def forward(self, input): + if self.activation: + out = F.linear(input, self.weight * self.scale) + out = fused_leaky_relu(out, self.bias * self.lr_mul) + + else: + out = F.linear( + input, self.weight * self.scale, bias=self.bias * self.lr_mul) + + return out + + def __repr__(self): + return ( + f'{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]})' + ) + + +class ScaledLeakyReLU(nn.Module): + + def __init__(self, negative_slope=0.2): + super().__init__() + + self.negative_slope = negative_slope + + def forward(self, input): + out = F.leaky_relu(input, negative_slope=self.negative_slope) + + return out * math.sqrt(2) + + +class ModulatedConv2d(nn.Module): + + def __init__( + self, + in_channel, + out_channel, + kernel_size, + style_dim, + demodulate=True, + upsample=False, + downsample=False, + blur_kernel=[1, 3, 3, 1], + ): + super().__init__() + + self.eps = 1e-8 + self.kernel_size = kernel_size + self.in_channel = in_channel + self.out_channel = out_channel + self.upsample = upsample + self.downsample = downsample + + if upsample: + factor = 2 + p = (len(blur_kernel) - factor) - (kernel_size - 1) + pad0 = (p + 1) // 2 + factor - 1 + pad1 = p // 2 + 1 + + self.blur = Blur( + blur_kernel, pad=(pad0, pad1), upsample_factor=factor) + + if downsample: + factor = 2 + p = (len(blur_kernel) - factor) + (kernel_size - 1) + pad0 = (p + 1) // 2 + pad1 = p // 2 + + self.blur = Blur(blur_kernel, pad=(pad0, pad1)) + + fan_in = in_channel * kernel_size**2 + self.scale = 1 / math.sqrt(fan_in) + self.padding = kernel_size // 2 + + self.weight = nn.Parameter( + torch.randn(1, out_channel, in_channel, kernel_size, kernel_size)) + + self.modulation = EqualLinear(style_dim, in_channel, bias_init=1) + + self.demodulate = demodulate + + def __repr__(self): + return ( + f'{self.__class__.__name__}({self.in_channel}, {self.out_channel}, {self.kernel_size}, ' + f'upsample={self.upsample}, downsample={self.downsample})') + + def forward(self, input, style): + batch, in_channel, height, width = input.shape + + style = self.modulation(style).view(batch, 1, in_channel, 1, 1) + weight = self.scale * self.weight * style + + if self.demodulate: + demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + 1e-8) + weight = weight * demod.view(batch, self.out_channel, 1, 1, 1) + + weight = weight.view(batch * self.out_channel, in_channel, + self.kernel_size, self.kernel_size) + + if self.upsample: + input = input.view(1, batch * in_channel, height, width) + weight = weight.view(batch, self.out_channel, in_channel, + self.kernel_size, self.kernel_size) + weight = weight.transpose(1, 2).reshape(batch * in_channel, + self.out_channel, + self.kernel_size, + self.kernel_size) + out = F.conv_transpose2d( + input, weight, padding=0, stride=2, groups=batch) + _, _, height, width = out.shape + out = out.view(batch, self.out_channel, height, width) + out = self.blur(out) + + elif self.downsample: + input = self.blur(input) + _, _, height, width = input.shape + input = input.view(1, batch * in_channel, height, width) + out = F.conv2d(input, weight, padding=0, stride=2, groups=batch) + _, _, height, width = out.shape + out = out.view(batch, 
self.out_channel, height, width) + + else: + input = input.view(1, batch * in_channel, height, width) + out = F.conv2d(input, weight, padding=self.padding, groups=batch) + _, _, height, width = out.shape + out = out.view(batch, self.out_channel, height, width) + + return out + + +class NoiseInjection(nn.Module): + + def __init__(self, isconcat=True): + super().__init__() + + self.isconcat = isconcat + self.weight = nn.Parameter(torch.zeros(1)) + + def forward(self, image, noise=None): + if noise is None: + batch, channel, height, width = image.shape + noise = image.new_empty(batch, channel, height, width).normal_() + + if self.isconcat: + return torch.cat((image, self.weight * noise), dim=1) + else: + return image + self.weight * noise + + +class ConstantInput(nn.Module): + + def __init__(self, channel, size=4): + super().__init__() + + self.input = nn.Parameter(torch.randn(1, channel, size, size)) + + def forward(self, input): + batch = input.shape[0] + out = self.input.repeat(batch, 1, 1, 1) + + return out + + +class StyledConv(nn.Module): + + def __init__( + self, + in_channel, + out_channel, + kernel_size, + style_dim, + upsample=False, + blur_kernel=[1, 3, 3, 1], + demodulate=True, + isconcat=True, + ): + super().__init__() + + self.conv = ModulatedConv2d( + in_channel, + out_channel, + kernel_size, + style_dim, + upsample=upsample, + blur_kernel=blur_kernel, + demodulate=demodulate, + ) + + self.noise = NoiseInjection(isconcat) + # self.bias = nn.Parameter(torch.zeros(1, out_channel, 1, 1)) + # self.activate = ScaledLeakyReLU(0.2) + feat_multiplier = 2 if isconcat else 1 + self.activate = FusedLeakyReLU(out_channel * feat_multiplier) + + def forward(self, input, style, noise=None): + out = self.conv(input, style) + out = self.noise(out, noise=noise) + # out = out + self.bias + out = self.activate(out) + + return out + + +class ToRGB(nn.Module): + + def __init__(self, + in_channel, + style_dim, + upsample=True, + blur_kernel=[1, 3, 3, 1]): + super().__init__() + + if upsample: + self.upsample = Upsample(blur_kernel) + + self.conv = ModulatedConv2d( + in_channel, 3, 1, style_dim, demodulate=False) + self.bias = nn.Parameter(torch.zeros(1, 3, 1, 1)) + + def forward(self, input, style, skip=None): + out = self.conv(input, style) + out = out + self.bias + + if skip is not None: + skip = self.upsample(skip) + + out = out + skip + + return out + + +class Generator(nn.Module): + + def __init__( + self, + size, + style_dim, + n_mlp, + channel_multiplier=2, + blur_kernel=[1, 3, 3, 1], + lr_mlp=0.01, + isconcat=True, + narrow=1, + ): + super().__init__() + + self.size = size + self.n_mlp = n_mlp + self.style_dim = style_dim + self.feat_multiplier = 2 if isconcat else 1 + + layers = [PixelNorm()] + + for i in range(n_mlp): + layers.append( + EqualLinear( + style_dim, + style_dim, + lr_mul=lr_mlp, + activation='fused_lrelu')) + + self.style = nn.Sequential(*layers) + + self.channels = { + 4: int(512 * narrow), + 8: int(512 * narrow), + 16: int(512 * narrow), + 32: int(512 * narrow), + 64: int(256 * channel_multiplier * narrow), + 128: int(128 * channel_multiplier * narrow), + 256: int(64 * channel_multiplier * narrow), + 512: int(32 * channel_multiplier * narrow), + 1024: int(16 * channel_multiplier * narrow), + 2048: int(8 * channel_multiplier * narrow) + } + + self.input = ConstantInput(self.channels[4]) + self.conv1 = StyledConv( + self.channels[4], + self.channels[4], + 3, + style_dim, + blur_kernel=blur_kernel, + isconcat=isconcat) + self.to_rgb1 = ToRGB( + self.channels[4] * 
self.feat_multiplier, style_dim, upsample=False) + + self.log_size = int(math.log(size, 2)) + + self.convs = nn.ModuleList() + self.upsamples = nn.ModuleList() + self.to_rgbs = nn.ModuleList() + + in_channel = self.channels[4] + + for i in range(3, self.log_size + 1): + out_channel = self.channels[2**i] + + self.convs.append( + StyledConv( + in_channel * self.feat_multiplier, + out_channel, + 3, + style_dim, + upsample=True, + blur_kernel=blur_kernel, + isconcat=isconcat, + )) + + self.convs.append( + StyledConv( + out_channel * self.feat_multiplier, + out_channel, + 3, + style_dim, + blur_kernel=blur_kernel, + isconcat=isconcat)) + + self.to_rgbs.append( + ToRGB(out_channel * self.feat_multiplier, style_dim)) + + in_channel = out_channel + + self.n_latent = self.log_size * 2 - 2 + + def make_noise(self): + device = self.input.input.device + + noises = [torch.randn(1, 1, 2**2, 2**2, device=device)] + + for i in range(3, self.log_size + 1): + for _ in range(2): + noises.append(torch.randn(1, 1, 2**i, 2**i, device=device)) + + return noises + + def mean_latent(self, n_latent): + latent_in = torch.randn( + n_latent, self.style_dim, device=self.input.input.device) + latent = self.style(latent_in).mean(0, keepdim=True) + + return latent + + def get_latent(self, input): + return self.style(input) + + def forward( + self, + styles, + return_latents=False, + inject_index=None, + truncation=1, + truncation_latent=None, + input_is_latent=False, + noise=None, + ): + if not input_is_latent: + styles = [self.style(s) for s in styles] + + if noise is None: + ''' + noise = [None] * (2 * (self.log_size - 2) + 1) + ''' + noise = [] + batch = styles[0].shape[0] + for i in range(self.n_mlp + 1): + size = 2**(i + 2) + noise.append( + torch.randn( + batch, + self.channels[size], + size, + size, + device=styles[0].device)) + + if truncation < 1: + style_t = [] + + for style in styles: + style_t.append(truncation_latent + + truncation * (style - truncation_latent)) + + styles = style_t + + if len(styles) < 2: + inject_index = self.n_latent + + latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1) + + else: + if inject_index is None: + inject_index = random.randint(1, self.n_latent - 1) + + latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1) + latent2 = styles[1].unsqueeze(1).repeat( + 1, self.n_latent - inject_index, 1) + + latent = torch.cat([latent, latent2], 1) + + out = self.input(latent) + out = self.conv1(out, latent[:, 0], noise=noise[0]) + + skip = self.to_rgb1(out, latent[:, 1]) + + i = 1 + for conv1, conv2, noise1, noise2, to_rgb in zip( + self.convs[::2], self.convs[1::2], noise[1::2], noise[2::2], + self.to_rgbs): + out = conv1(out, latent[:, i], noise=noise1) + out = conv2(out, latent[:, i + 1], noise=noise2) + skip = to_rgb(out, latent[:, i + 2], skip) + + i += 2 + + image = skip + + if return_latents: + return image, latent + + else: + return image, None + + +class ConvLayer(nn.Sequential): + + def __init__( + self, + in_channel, + out_channel, + kernel_size, + downsample=False, + blur_kernel=[1, 3, 3, 1], + bias=True, + activate=True, + ): + layers = [] + + if downsample: + factor = 2 + p = (len(blur_kernel) - factor) + (kernel_size - 1) + pad0 = (p + 1) // 2 + pad1 = p // 2 + + layers.append(Blur(blur_kernel, pad=(pad0, pad1))) + + stride = 2 + self.padding = 0 + + else: + stride = 1 + self.padding = kernel_size // 2 + + layers.append( + EqualConv2d( + in_channel, + out_channel, + kernel_size, + padding=self.padding, + stride=stride, + bias=bias and not activate, + )) + + if 
activate: + if bias: + layers.append(FusedLeakyReLU(out_channel)) + + else: + layers.append(ScaledLeakyReLU(0.2)) + + super().__init__(*layers) + + +class ResBlock(nn.Module): + + def __init__(self, in_channel, out_channel, blur_kernel=[1, 3, 3, 1]): + super().__init__() + + self.conv1 = ConvLayer(in_channel, in_channel, 3) + self.conv2 = ConvLayer(in_channel, out_channel, 3, downsample=True) + + self.skip = ConvLayer( + in_channel, + out_channel, + 1, + downsample=True, + activate=False, + bias=False) + + def forward(self, input): + out = self.conv1(input) + out = self.conv2(out) + + skip = self.skip(input) + out = (out + skip) / math.sqrt(2) + + return out + + +class FullGenerator(nn.Module): + + def __init__( + self, + size, + style_dim, + n_mlp, + channel_multiplier=2, + blur_kernel=[1, 3, 3, 1], + lr_mlp=0.01, + isconcat=True, + narrow=1, + ): + super().__init__() + channels = { + 4: int(512 * narrow), + 8: int(512 * narrow), + 16: int(512 * narrow), + 32: int(512 * narrow), + 64: int(256 * channel_multiplier * narrow), + 128: int(128 * channel_multiplier * narrow), + 256: int(64 * channel_multiplier * narrow), + 512: int(32 * channel_multiplier * narrow), + 1024: int(16 * channel_multiplier * narrow), + 2048: int(8 * channel_multiplier * narrow) + } + + self.log_size = int(math.log(size, 2)) + self.generator = Generator( + size, + style_dim, + n_mlp, + channel_multiplier=channel_multiplier, + blur_kernel=blur_kernel, + lr_mlp=lr_mlp, + isconcat=isconcat, + narrow=narrow) + + conv = [ConvLayer(3, channels[size], 1)] + self.ecd0 = nn.Sequential(*conv) + in_channel = channels[size] + + self.names = ['ecd%d' % i for i in range(self.log_size - 1)] + for i in range(self.log_size, 2, -1): + out_channel = channels[2**(i - 1)] + # conv = [ResBlock(in_channel, out_channel, blur_kernel)] + conv = [ConvLayer(in_channel, out_channel, 3, downsample=True)] + setattr(self, self.names[self.log_size - i + 1], + nn.Sequential(*conv)) + in_channel = out_channel + self.final_linear = nn.Sequential( + EqualLinear( + channels[4] * 4 * 4, style_dim, activation='fused_lrelu')) + + def forward( + self, + inputs, + return_latents=False, + inject_index=None, + truncation=1, + truncation_latent=None, + input_is_latent=False, + ): + noise = [] + for i in range(self.log_size - 1): + ecd = getattr(self, self.names[i]) + inputs = ecd(inputs) + noise.append(inputs) + inputs = inputs.view(inputs.shape[0], -1) + outs = self.final_linear(inputs) + noise = list( + itertools.chain.from_iterable( + itertools.repeat(x, 2) for x in noise))[::-1] + outs = self.generator([outs], + return_latents, + inject_index, + truncation, + truncation_latent, + input_is_latent, + noise=noise[1:]) + return outs + + +class Discriminator(nn.Module): + + def __init__(self, + size, + channel_multiplier=2, + blur_kernel=[1, 3, 3, 1], + narrow=1): + super().__init__() + + channels = { + 4: int(512 * narrow), + 8: int(512 * narrow), + 16: int(512 * narrow), + 32: int(512 * narrow), + 64: int(256 * channel_multiplier * narrow), + 128: int(128 * channel_multiplier * narrow), + 256: int(64 * channel_multiplier * narrow), + 512: int(32 * channel_multiplier * narrow), + 1024: int(16 * channel_multiplier * narrow), + 2048: int(8 * channel_multiplier * narrow) + } + + convs = [ConvLayer(3, channels[size], 1)] + + log_size = int(math.log(size, 2)) + + in_channel = channels[size] + + for i in range(log_size, 2, -1): + out_channel = channels[2**(i - 1)] + + convs.append(ResBlock(in_channel, out_channel, blur_kernel)) + + in_channel = out_channel + + 
self.convs = nn.Sequential(*convs) + + self.stddev_group = 4 + self.stddev_feat = 1 + + self.final_conv = ConvLayer(in_channel + 1, channels[4], 3) + self.final_linear = nn.Sequential( + EqualLinear( + channels[4] * 4 * 4, channels[4], activation='fused_lrelu'), + EqualLinear(channels[4], 1), + ) + + def forward(self, input): + out = self.convs(input) + + batch, channel, height, width = out.shape + group = min(batch, self.stddev_group) + stddev = out.view(group, -1, self.stddev_feat, + channel // self.stddev_feat, height, width) + stddev = torch.sqrt(stddev.var(0, unbiased=False) + 1e-8) + stddev = stddev.mean([2, 3, 4], keepdims=True).squeeze(2) + stddev = stddev.repeat(group, 1, height, width) + out = torch.cat([out, stddev], 1) + + out = self.final_conv(out) + + out = out.view(batch, -1) + out = self.final_linear(out) + return out diff --git a/modelscope/models/cv/image_portrait_enhancement/image_portrait_enhancement.py b/modelscope/models/cv/image_portrait_enhancement/image_portrait_enhancement.py new file mode 100644 index 00000000..3250d393 --- /dev/null +++ b/modelscope/models/cv/image_portrait_enhancement/image_portrait_enhancement.py @@ -0,0 +1,205 @@ +import math +import os.path as osp +from copy import deepcopy +from typing import Any, Dict, List, Union + +import torch +import torch.nn.functional as F +from torch import autograd, nn +from torch.nn.parallel import DataParallel, DistributedDataParallel + +from modelscope.metainfo import Models +from modelscope.models.base import Tensor, TorchModel +from modelscope.models.builder import MODELS +from modelscope.utils.config import Config +from modelscope.utils.constant import ModelFile, Tasks +from modelscope.utils.logger import get_logger +from .gpen import Discriminator, FullGenerator +from .losses.losses import IDLoss, L1Loss + +logger = get_logger() + +__all__ = ['ImagePortraitEnhancement'] + + +@MODELS.register_module( + Tasks.image_portrait_enhancement, module_name=Models.gpen) +class ImagePortraitEnhancement(TorchModel): + + def __init__(self, model_dir: str, *args, **kwargs): + """initialize the face enhancement model from the `model_dir` path. + + Args: + model_dir (str): the model path. 
+ """ + super().__init__(model_dir, *args, **kwargs) + + self.size = 512 + self.style_dim = 512 + self.n_mlp = 8 + self.mean_path_length = 0 + self.accum = 0.5**(32 / (10 * 1000)) + + if torch.cuda.is_available(): + self._device = torch.device('cuda') + else: + self._device = torch.device('cpu') + + self.l1_loss = L1Loss() + self.id_loss = IDLoss(f'{model_dir}/arcface/model_ir_se50.pth', + self._device) + self.generator = FullGenerator( + self.size, self.style_dim, self.n_mlp, + isconcat=True).to(self._device) + self.g_ema = FullGenerator( + self.size, self.style_dim, self.n_mlp, + isconcat=True).to(self._device) + self.discriminator = Discriminator(self.size).to(self._device) + + if self.size == 512: + self.load_pretrained(model_dir) + + def load_pretrained(self, model_dir): + g_path = f'{model_dir}/{ModelFile.TORCH_MODEL_FILE}' + g_dict = torch.load(g_path, map_location=torch.device('cpu')) + self.generator.load_state_dict(g_dict) + self.g_ema.load_state_dict(g_dict) + + d_path = f'{model_dir}/net_d.pt' + d_dict = torch.load(d_path, map_location=torch.device('cpu')) + self.discriminator.load_state_dict(d_dict) + + logger.info('load model done.') + + def accumulate(self): + par1 = dict(self.g_ema.named_parameters()) + par2 = dict(self.generator.named_parameters()) + + for k in par1.keys(): + par1[k].data.mul_(self.accum).add_(1 - self.accum, par2[k].data) + + def requires_grad(self, model, flag=True): + for p in model.parameters(): + p.requires_grad = flag + + def d_logistic_loss(self, real_pred, fake_pred): + real_loss = F.softplus(-real_pred) + fake_loss = F.softplus(fake_pred) + + return real_loss.mean() + fake_loss.mean() + + def d_r1_loss(self, real_pred, real_img): + grad_real, = autograd.grad( + outputs=real_pred.sum(), inputs=real_img, create_graph=True) + grad_penalty = grad_real.pow(2).view(grad_real.shape[0], + -1).sum(1).mean() + + return grad_penalty + + def g_nonsaturating_loss(self, + fake_pred, + fake_img=None, + real_img=None, + input_img=None): + loss = F.softplus(-fake_pred).mean() + loss_l1 = self.l1_loss(fake_img, real_img) + loss_id, __, __ = self.id_loss(fake_img, real_img, input_img) + loss_id = 0 + loss += 1.0 * loss_l1 + 1.0 * loss_id + + return loss + + def g_path_regularize(self, + fake_img, + latents, + mean_path_length, + decay=0.01): + noise = torch.randn_like(fake_img) / math.sqrt( + fake_img.shape[2] * fake_img.shape[3]) + grad, = autograd.grad( + outputs=(fake_img * noise).sum(), + inputs=latents, + create_graph=True) + path_lengths = torch.sqrt(grad.pow(2).sum(2).mean(1)) + + path_mean = mean_path_length + decay * ( + path_lengths.mean() - mean_path_length) + + path_penalty = (path_lengths - path_mean).pow(2).mean() + + return path_penalty, path_mean.detach(), path_lengths + + @torch.no_grad() + def _evaluate_postprocess(self, src: Tensor, + target: Tensor) -> Dict[str, list]: + preds, _ = self.generator(src) + preds = list(torch.split(preds, 1, 0)) + targets = list(torch.split(target, 1, 0)) + + preds = [((pred.data * 0.5 + 0.5) * 255.).squeeze(0).type( + torch.uint8).permute(1, 2, 0).cpu().numpy() for pred in preds] + targets = [((target.data * 0.5 + 0.5) * 255.).squeeze(0).type( + torch.uint8).permute(1, 2, 0).cpu().numpy() for target in targets] + + return {'pred': preds, 'target': targets} + + def _train_forward_d(self, src: Tensor, target: Tensor) -> Tensor: + self.requires_grad(self.generator, False) + self.requires_grad(self.discriminator, True) + + preds, _ = self.generator(src) + fake_pred = self.discriminator(preds) + real_pred = 
self.discriminator(target) + + d_loss = self.d_logistic_loss(real_pred, fake_pred) + + return d_loss + + def _train_forward_d_r1(self, src: Tensor, target: Tensor) -> Tensor: + src.requires_grad = True + target.requires_grad = True + real_pred = self.discriminator(target) + r1_loss = self.d_r1_loss(real_pred, target) + + return r1_loss + + def _train_forward_g(self, src: Tensor, target: Tensor) -> Tensor: + self.requires_grad(self.generator, True) + self.requires_grad(self.discriminator, False) + + preds, _ = self.generator(src) + fake_pred = self.discriminator(preds) + + g_loss = self.g_nonsaturating_loss(fake_pred, preds, target, src) + + return g_loss + + def _train_forward_g_path(self, src: Tensor, target: Tensor) -> Tensor: + fake_img, latents = self.generator(src, return_latents=True) + + path_loss, self.mean_path_length, path_lengths = self.g_path_regularize( + fake_img, latents, self.mean_path_length) + + return path_loss + + @torch.no_grad() + def _inference_forward(self, src: Tensor) -> Dict[str, Tensor]: + return {'outputs': (self.generator(src)[0] * 0.5 + 0.5).clamp(0, 1)} + + def forward(self, input: Dict[str, + Tensor]) -> Dict[str, Union[list, Tensor]]: + """return the result by the model + + Args: + input (Dict[str, Tensor]): the preprocessed data + + Returns: + Dict[str, Union[list, Tensor]]: results + """ + for key, value in input.items(): + input[key] = input[key].to(self._device) + + if 'target' in input: + return self._evaluate_postprocess(**input) + else: + return self._inference_forward(**input) diff --git a/modelscope/models/cv/image_portrait_enhancement/losses/helpers.py b/modelscope/models/cv/image_portrait_enhancement/losses/helpers.py new file mode 100644 index 00000000..35ca202f --- /dev/null +++ b/modelscope/models/cv/image_portrait_enhancement/losses/helpers.py @@ -0,0 +1,129 @@ +from collections import namedtuple + +import torch +from torch.nn import (AdaptiveAvgPool2d, BatchNorm2d, Conv2d, MaxPool2d, + Module, PReLU, ReLU, Sequential, Sigmoid) + + +class Flatten(Module): + + def forward(self, input): + return input.view(input.size(0), -1) + + +def l2_norm(input, axis=1): + norm = torch.norm(input, 2, axis, True) + output = torch.div(input, norm) + return output + + +class Bottleneck(namedtuple('Block', ['in_channel', 'depth', 'stride'])): + """ A named tuple describing a ResNet block. """ + + +def get_block(in_channel, depth, num_units, stride=2): + return [Bottleneck(in_channel, depth, stride) + ] + [Bottleneck(depth, depth, 1) for i in range(num_units - 1)] + + +def get_blocks(num_layers): + if num_layers == 50: + blocks = [ + get_block(in_channel=64, depth=64, num_units=3), + get_block(in_channel=64, depth=128, num_units=4), + get_block(in_channel=128, depth=256, num_units=14), + get_block(in_channel=256, depth=512, num_units=3) + ] + elif num_layers == 100: + blocks = [ + get_block(in_channel=64, depth=64, num_units=3), + get_block(in_channel=64, depth=128, num_units=13), + get_block(in_channel=128, depth=256, num_units=30), + get_block(in_channel=256, depth=512, num_units=3) + ] + elif num_layers == 152: + blocks = [ + get_block(in_channel=64, depth=64, num_units=3), + get_block(in_channel=64, depth=128, num_units=8), + get_block(in_channel=128, depth=256, num_units=36), + get_block(in_channel=256, depth=512, num_units=3) + ] + else: + raise ValueError( + 'Invalid number of layers: {}. Must be one of [50, 100, 152]'. 
+ format(num_layers)) + return blocks + + +class SEModule(Module): + + def __init__(self, channels, reduction): + super(SEModule, self).__init__() + self.avg_pool = AdaptiveAvgPool2d(1) + self.fc1 = Conv2d( + channels, + channels // reduction, + kernel_size=1, + padding=0, + bias=False) + self.relu = ReLU(inplace=True) + self.fc2 = Conv2d( + channels // reduction, + channels, + kernel_size=1, + padding=0, + bias=False) + self.sigmoid = Sigmoid() + + def forward(self, x): + module_input = x + x = self.avg_pool(x) + x = self.fc1(x) + x = self.relu(x) + x = self.fc2(x) + x = self.sigmoid(x) + return module_input * x + + +class bottleneck_IR(Module): + + def __init__(self, in_channel, depth, stride): + super(bottleneck_IR, self).__init__() + if in_channel == depth: + self.shortcut_layer = MaxPool2d(1, stride) + else: + self.shortcut_layer = Sequential( + Conv2d(in_channel, depth, (1, 1), stride, bias=False), + BatchNorm2d(depth)) + self.res_layer = Sequential( + BatchNorm2d(in_channel), + Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False), + PReLU(depth), Conv2d(depth, depth, (3, 3), stride, 1, bias=False), + BatchNorm2d(depth)) + + def forward(self, x): + shortcut = self.shortcut_layer(x) + res = self.res_layer(x) + return res + shortcut + + +class bottleneck_IR_SE(Module): + + def __init__(self, in_channel, depth, stride): + super(bottleneck_IR_SE, self).__init__() + if in_channel == depth: + self.shortcut_layer = MaxPool2d(1, stride) + else: + self.shortcut_layer = Sequential( + Conv2d(in_channel, depth, (1, 1), stride, bias=False), + BatchNorm2d(depth)) + self.res_layer = Sequential( + BatchNorm2d(in_channel), + Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False), + PReLU(depth), Conv2d(depth, depth, (3, 3), stride, 1, bias=False), + BatchNorm2d(depth), SEModule(depth, 16)) + + def forward(self, x): + shortcut = self.shortcut_layer(x) + res = self.res_layer(x) + return res + shortcut diff --git a/modelscope/models/cv/image_portrait_enhancement/losses/losses.py b/modelscope/models/cv/image_portrait_enhancement/losses/losses.py new file mode 100644 index 00000000..8934eee7 --- /dev/null +++ b/modelscope/models/cv/image_portrait_enhancement/losses/losses.py @@ -0,0 +1,90 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .model_irse import Backbone + + +class L1Loss(nn.Module): + """L1 (mean absolute error, MAE) loss. + + Args: + loss_weight (float): Loss weight for L1 loss. Default: 1.0. + reduction (str): Specifies the reduction to apply to the output. + Supported choices are 'none' | 'mean' | 'sum'. Default: 'mean'. + """ + + def __init__(self, loss_weight=1.0, reduction='mean'): + super(L1Loss, self).__init__() + if reduction not in ['none', 'mean', 'sum']: + raise ValueError( + f'Unsupported reduction mode: {reduction}. Supported ones are: {_reduction_modes}' + ) + + self.loss_weight = loss_weight + self.reduction = reduction + + def forward(self, pred, target, weight=None, **kwargs): + """ + Args: + pred (Tensor): of shape (N, C, H, W). Predicted tensor. + target (Tensor): of shape (N, C, H, W). Ground truth tensor. + weight (Tensor, optional): of shape (N, C, H, W). Element-wise weights. Default: None. 
+ """ + return self.loss_weight * F.l1_loss( + pred, target, reduction=self.reduction) + + +class IDLoss(nn.Module): + + def __init__(self, model_path, device='cuda', ckpt_dict=None): + super(IDLoss, self).__init__() + print('Loading ResNet ArcFace') + self.facenet = Backbone( + input_size=112, num_layers=50, drop_ratio=0.6, + mode='ir_se').to(device) + if ckpt_dict is None: + self.facenet.load_state_dict( + torch.load(model_path, map_location=torch.device('cpu'))) + else: + self.facenet.load_state_dict(ckpt_dict) + self.pool = torch.nn.AdaptiveAvgPool2d((256, 256)) + self.face_pool = torch.nn.AdaptiveAvgPool2d((112, 112)) + self.facenet.eval() + + def extract_feats(self, x): + _, _, h, w = x.shape + assert h == w + if h != 256: + x = self.pool(x) + x = x[:, :, 35:-33, 32:-36] # crop roi + x = self.face_pool(x) + x_feats = self.facenet(x) + return x_feats + + @torch.no_grad() + def forward(self, y_hat, y, x): + n_samples = x.shape[0] + x_feats = self.extract_feats(x) + y_feats = self.extract_feats(y) # Otherwise use the feature from there + y_hat_feats = self.extract_feats(y_hat) + y_feats = y_feats.detach() + loss = 0 + sim_improvement = 0 + id_logs = [] + count = 0 + for i in range(n_samples): + diff_target = y_hat_feats[i].dot(y_feats[i]) + diff_input = y_hat_feats[i].dot(x_feats[i]) + diff_views = y_feats[i].dot(x_feats[i]) + id_logs.append({ + 'diff_target': float(diff_target), + 'diff_input': float(diff_input), + 'diff_views': float(diff_views) + }) + loss += 1 - diff_target + id_diff = float(diff_target) - float(diff_views) + sim_improvement += id_diff + count += 1 + + return loss / count, sim_improvement / count, id_logs diff --git a/modelscope/models/cv/image_portrait_enhancement/losses/model_irse.py b/modelscope/models/cv/image_portrait_enhancement/losses/model_irse.py new file mode 100644 index 00000000..3b87d7fd --- /dev/null +++ b/modelscope/models/cv/image_portrait_enhancement/losses/model_irse.py @@ -0,0 +1,92 @@ +from torch.nn import (BatchNorm1d, BatchNorm2d, Conv2d, Dropout, Linear, + Module, PReLU, Sequential) + +from .helpers import (Flatten, bottleneck_IR, bottleneck_IR_SE, get_blocks, + l2_norm) + + +class Backbone(Module): + + def __init__(self, + input_size, + num_layers, + mode='ir', + drop_ratio=0.4, + affine=True): + super(Backbone, self).__init__() + assert input_size in [112, 224], 'input_size should be 112 or 224' + assert num_layers in [50, 100, + 152], 'num_layers should be 50, 100 or 152' + assert mode in ['ir', 'ir_se'], 'mode should be ir or ir_se' + blocks = get_blocks(num_layers) + if mode == 'ir': + unit_module = bottleneck_IR + elif mode == 'ir_se': + unit_module = bottleneck_IR_SE + self.input_layer = Sequential( + Conv2d(3, 64, (3, 3), 1, 1, bias=False), BatchNorm2d(64), + PReLU(64)) + if input_size == 112: + self.output_layer = Sequential( + BatchNorm2d(512), Dropout(drop_ratio), Flatten(), + Linear(512 * 7 * 7, 512), BatchNorm1d(512, affine=affine)) + else: + self.output_layer = Sequential( + BatchNorm2d(512), Dropout(drop_ratio), Flatten(), + Linear(512 * 14 * 14, 512), BatchNorm1d(512, affine=affine)) + + modules = [] + for block in blocks: + for bottleneck in block: + modules.append( + unit_module(bottleneck.in_channel, bottleneck.depth, + bottleneck.stride)) + self.body = Sequential(*modules) + + def forward(self, x): + x = self.input_layer(x) + x = self.body(x) + x = self.output_layer(x) + return l2_norm(x) + + +def IR_50(input_size): + """Constructs a ir-50 model.""" + model = Backbone( + input_size, num_layers=50, mode='ir', drop_ratio=0.4, 
affine=False) + return model + + +def IR_101(input_size): + """Constructs a ir-101 model.""" + model = Backbone( + input_size, num_layers=100, mode='ir', drop_ratio=0.4, affine=False) + return model + + +def IR_152(input_size): + """Constructs a ir-152 model.""" + model = Backbone( + input_size, num_layers=152, mode='ir', drop_ratio=0.4, affine=False) + return model + + +def IR_SE_50(input_size): + """Constructs a ir_se-50 model.""" + model = Backbone( + input_size, num_layers=50, mode='ir_se', drop_ratio=0.4, affine=False) + return model + + +def IR_SE_101(input_size): + """Constructs a ir_se-101 model.""" + model = Backbone( + input_size, num_layers=100, mode='ir_se', drop_ratio=0.4, affine=False) + return model + + +def IR_SE_152(input_size): + """Constructs a ir_se-152 model.""" + model = Backbone( + input_size, num_layers=152, mode='ir_se', drop_ratio=0.4, affine=False) + return model diff --git a/modelscope/models/cv/image_portrait_enhancement/retinaface/detection.py b/modelscope/models/cv/image_portrait_enhancement/retinaface/detection.py new file mode 100755 index 00000000..c294438a --- /dev/null +++ b/modelscope/models/cv/image_portrait_enhancement/retinaface/detection.py @@ -0,0 +1,217 @@ +import os + +import cv2 +import numpy as np +import torch +import torch.backends.cudnn as cudnn +import torch.nn.functional as F + +from .models.retinaface import RetinaFace +from .utils import PriorBox, decode, decode_landm, py_cpu_nms + +cfg_re50 = { + 'name': 'Resnet50', + 'min_sizes': [[16, 32], [64, 128], [256, 512]], + 'steps': [8, 16, 32], + 'variance': [0.1, 0.2], + 'clip': False, + 'pretrain': False, + 'return_layers': { + 'layer2': 1, + 'layer3': 2, + 'layer4': 3 + }, + 'in_channel': 256, + 'out_channel': 256 +} + + +class RetinaFaceDetection(object): + + def __init__(self, model_path, device='cuda'): + torch.set_grad_enabled(False) + cudnn.benchmark = True + self.model_path = model_path + self.device = device + self.cfg = cfg_re50 + self.net = RetinaFace(cfg=self.cfg) + self.load_model() + self.net = self.net.to(device) + + self.mean = torch.tensor([[[[104]], [[117]], [[123]]]]).to(device) + + def check_keys(self, pretrained_state_dict): + ckpt_keys = set(pretrained_state_dict.keys()) + model_keys = set(self.net.state_dict().keys()) + used_pretrained_keys = model_keys & ckpt_keys + assert len( + used_pretrained_keys) > 0, 'load NONE from pretrained checkpoint' + return True + + def remove_prefix(self, state_dict, prefix): + new_state_dict = dict() + # remove unnecessary 'module.' 
+ for k, v in state_dict.items(): + if k.startswith(prefix): + new_state_dict[k[len(prefix):]] = v + else: + new_state_dict[k] = v + return new_state_dict + + def load_model(self, load_to_cpu=False): + pretrained_dict = torch.load( + self.model_path, map_location=torch.device('cpu')) + if 'state_dict' in pretrained_dict.keys(): + pretrained_dict = self.remove_prefix(pretrained_dict['state_dict'], + 'module.') + else: + pretrained_dict = self.remove_prefix(pretrained_dict, 'module.') + self.check_keys(pretrained_dict) + self.net.load_state_dict(pretrained_dict, strict=False) + self.net.eval() + + def detect(self, + img_raw, + resize=1, + confidence_threshold=0.9, + nms_threshold=0.4, + top_k=5000, + keep_top_k=750, + save_image=False): + img = np.float32(img_raw) + + im_height, im_width = img.shape[:2] + ss = 1.0 + # tricky + if max(im_height, im_width) > 1500: + ss = 1000.0 / max(im_height, im_width) + img = cv2.resize(img, (0, 0), fx=ss, fy=ss) + im_height, im_width = img.shape[:2] + + scale = torch.Tensor( + [img.shape[1], img.shape[0], img.shape[1], img.shape[0]]) + img -= (104, 117, 123) + img = img.transpose(2, 0, 1) + img = torch.from_numpy(img).unsqueeze(0) + img = img.to(self.device) + scale = scale.to(self.device) + + loc, conf, landms = self.net(img) # forward pass + del img + + priorbox = PriorBox(self.cfg, image_size=(im_height, im_width)) + priors = priorbox.forward() + priors = priors.to(self.device) + prior_data = priors.data + boxes = decode(loc.data.squeeze(0), prior_data, self.cfg['variance']) + boxes = boxes * scale / resize + boxes = boxes.cpu().numpy() + scores = conf.squeeze(0).data.cpu().numpy()[:, 1] + landms = decode_landm( + landms.data.squeeze(0), prior_data, self.cfg['variance']) + scale1 = torch.Tensor([ + im_width, im_height, im_width, im_height, im_width, im_height, + im_width, im_height, im_width, im_height + ]) + scale1 = scale1.to(self.device) + landms = landms * scale1 / resize + landms = landms.cpu().numpy() + + # ignore low scores + inds = np.where(scores > confidence_threshold)[0] + boxes = boxes[inds] + landms = landms[inds] + scores = scores[inds] + + # keep top-K before NMS + order = scores.argsort()[::-1][:top_k] + boxes = boxes[order] + landms = landms[order] + scores = scores[order] + + # do NMS + dets = np.hstack((boxes, scores[:, np.newaxis])).astype( + np.float32, copy=False) + keep = py_cpu_nms(dets, nms_threshold) + dets = dets[keep, :] + landms = landms[keep] + + # keep top-K faster NMS + dets = dets[:keep_top_k, :] + landms = landms[:keep_top_k, :] + + landms = landms.reshape((-1, 5, 2)) + landms = landms.transpose((0, 2, 1)) + landms = landms.reshape( + -1, + 10, + ) + return dets / ss, landms / ss + + def detect_tensor(self, + img, + resize=1, + confidence_threshold=0.9, + nms_threshold=0.4, + top_k=5000, + keep_top_k=750, + save_image=False): + im_height, im_width = img.shape[-2:] + ss = 1000 / max(im_height, im_width) + img = F.interpolate(img, scale_factor=ss) + im_height, im_width = img.shape[-2:] + scale = torch.Tensor([im_width, im_height, im_width, + im_height]).to(self.device) + img -= self.mean + + loc, conf, landms = self.net(img) # forward pass + + priorbox = PriorBox(self.cfg, image_size=(im_height, im_width)) + priors = priorbox.forward() + priors = priors.to(self.device) + prior_data = priors.data + boxes = decode(loc.data.squeeze(0), prior_data, self.cfg['variance']) + boxes = boxes * scale / resize + boxes = boxes.cpu().numpy() + scores = conf.squeeze(0).data.cpu().numpy()[:, 1] + landms = decode_landm( + 
landms.data.squeeze(0), prior_data, self.cfg['variance']) + scale1 = torch.Tensor([ + img.shape[3], img.shape[2], img.shape[3], img.shape[2], + img.shape[3], img.shape[2], img.shape[3], img.shape[2], + img.shape[3], img.shape[2] + ]) + scale1 = scale1.to(self.device) + landms = landms * scale1 / resize + landms = landms.cpu().numpy() + + # ignore low scores + inds = np.where(scores > confidence_threshold)[0] + boxes = boxes[inds] + landms = landms[inds] + scores = scores[inds] + + # keep top-K before NMS + order = scores.argsort()[::-1][:top_k] + boxes = boxes[order] + landms = landms[order] + scores = scores[order] + + # do NMS + dets = np.hstack((boxes, scores[:, np.newaxis])).astype( + np.float32, copy=False) + keep = py_cpu_nms(dets, nms_threshold) + dets = dets[keep, :] + landms = landms[keep] + + # keep top-K faster NMS + dets = dets[:keep_top_k, :] + landms = landms[:keep_top_k, :] + + landms = landms.reshape((-1, 5, 2)) + landms = landms.transpose((0, 2, 1)) + landms = landms.reshape( + -1, + 10, + ) + return dets / ss, landms / ss diff --git a/modelscope/models/cv/image_portrait_enhancement/retinaface/models/__init__.py b/modelscope/models/cv/image_portrait_enhancement/retinaface/models/__init__.py new file mode 100755 index 00000000..e69de29b diff --git a/modelscope/models/cv/image_portrait_enhancement/retinaface/models/net.py b/modelscope/models/cv/image_portrait_enhancement/retinaface/models/net.py new file mode 100755 index 00000000..0546e0bb --- /dev/null +++ b/modelscope/models/cv/image_portrait_enhancement/retinaface/models/net.py @@ -0,0 +1,148 @@ +import time + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torchvision.models as models +import torchvision.models._utils as _utils +from torch.autograd import Variable + + +def conv_bn(inp, oup, stride=1, leaky=0): + return nn.Sequential( + nn.Conv2d(inp, oup, 3, stride, 1, bias=False), nn.BatchNorm2d(oup), + nn.LeakyReLU(negative_slope=leaky, inplace=True)) + + +def conv_bn_no_relu(inp, oup, stride): + return nn.Sequential( + nn.Conv2d(inp, oup, 3, stride, 1, bias=False), + nn.BatchNorm2d(oup), + ) + + +def conv_bn1X1(inp, oup, stride, leaky=0): + return nn.Sequential( + nn.Conv2d(inp, oup, 1, stride, padding=0, bias=False), + nn.BatchNorm2d(oup), nn.LeakyReLU(negative_slope=leaky, inplace=True)) + + +def conv_dw(inp, oup, stride, leaky=0.1): + return nn.Sequential( + nn.Conv2d(inp, inp, 3, stride, 1, groups=inp, bias=False), + nn.BatchNorm2d(inp), + nn.LeakyReLU(negative_slope=leaky, inplace=True), + nn.Conv2d(inp, oup, 1, 1, 0, bias=False), + nn.BatchNorm2d(oup), + nn.LeakyReLU(negative_slope=leaky, inplace=True), + ) + + +class SSH(nn.Module): + + def __init__(self, in_channel, out_channel): + super(SSH, self).__init__() + assert out_channel % 4 == 0 + leaky = 0 + if (out_channel <= 64): + leaky = 0.1 + self.conv3X3 = conv_bn_no_relu(in_channel, out_channel // 2, stride=1) + + self.conv5X5_1 = conv_bn( + in_channel, out_channel // 4, stride=1, leaky=leaky) + self.conv5X5_2 = conv_bn_no_relu( + out_channel // 4, out_channel // 4, stride=1) + + self.conv7X7_2 = conv_bn( + out_channel // 4, out_channel // 4, stride=1, leaky=leaky) + self.conv7x7_3 = conv_bn_no_relu( + out_channel // 4, out_channel // 4, stride=1) + + def forward(self, input): + conv3X3 = self.conv3X3(input) + + conv5X5_1 = self.conv5X5_1(input) + conv5X5 = self.conv5X5_2(conv5X5_1) + + conv7X7_2 = self.conv7X7_2(conv5X5_1) + conv7X7 = self.conv7x7_3(conv7X7_2) + + out = torch.cat([conv3X3, conv5X5, conv7X7], dim=1) + out = 
F.relu(out) + return out + + +class FPN(nn.Module): + + def __init__(self, in_channels_list, out_channels): + super(FPN, self).__init__() + leaky = 0 + if (out_channels <= 64): + leaky = 0.1 + self.output1 = conv_bn1X1( + in_channels_list[0], out_channels, stride=1, leaky=leaky) + self.output2 = conv_bn1X1( + in_channels_list[1], out_channels, stride=1, leaky=leaky) + self.output3 = conv_bn1X1( + in_channels_list[2], out_channels, stride=1, leaky=leaky) + + self.merge1 = conv_bn(out_channels, out_channels, leaky=leaky) + self.merge2 = conv_bn(out_channels, out_channels, leaky=leaky) + + def forward(self, input): + # names = list(input.keys()) + input = list(input.values()) + + output1 = self.output1(input[0]) + output2 = self.output2(input[1]) + output3 = self.output3(input[2]) + + up3 = F.interpolate( + output3, size=[output2.size(2), output2.size(3)], mode='nearest') + output2 = output2 + up3 + output2 = self.merge2(output2) + + up2 = F.interpolate( + output2, size=[output1.size(2), output1.size(3)], mode='nearest') + output1 = output1 + up2 + output1 = self.merge1(output1) + + out = [output1, output2, output3] + return out + + +class MobileNetV1(nn.Module): + + def __init__(self): + super(MobileNetV1, self).__init__() + self.stage1 = nn.Sequential( + conv_bn(3, 8, 2, leaky=0.1), # 3 + conv_dw(8, 16, 1), # 7 + conv_dw(16, 32, 2), # 11 + conv_dw(32, 32, 1), # 19 + conv_dw(32, 64, 2), # 27 + conv_dw(64, 64, 1), # 43 + ) + self.stage2 = nn.Sequential( + conv_dw(64, 128, 2), # 43 + 16 = 59 + conv_dw(128, 128, 1), # 59 + 32 = 91 + conv_dw(128, 128, 1), # 91 + 32 = 123 + conv_dw(128, 128, 1), # 123 + 32 = 155 + conv_dw(128, 128, 1), # 155 + 32 = 187 + conv_dw(128, 128, 1), # 187 + 32 = 219 + ) + self.stage3 = nn.Sequential( + conv_dw(128, 256, 2), # 219 +3 2 = 241 + conv_dw(256, 256, 1), # 241 + 64 = 301 + ) + self.avg = nn.AdaptiveAvgPool2d((1, 1)) + self.fc = nn.Linear(256, 1000) + + def forward(self, x): + x = self.stage1(x) + x = self.stage2(x) + x = self.stage3(x) + x = self.avg(x) + x = x.view(-1, 256) + x = self.fc(x) + return x diff --git a/modelscope/models/cv/image_portrait_enhancement/retinaface/models/retinaface.py b/modelscope/models/cv/image_portrait_enhancement/retinaface/models/retinaface.py new file mode 100755 index 00000000..af1d706d --- /dev/null +++ b/modelscope/models/cv/image_portrait_enhancement/retinaface/models/retinaface.py @@ -0,0 +1,144 @@ +from collections import OrderedDict + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torchvision.models as models +import torchvision.models._utils as _utils +import torchvision.models.detection.backbone_utils as backbone_utils + +from .net import FPN, SSH, MobileNetV1 + + +class ClassHead(nn.Module): + + def __init__(self, inchannels=512, num_anchors=3): + super(ClassHead, self).__init__() + self.num_anchors = num_anchors + self.conv1x1 = nn.Conv2d( + inchannels, + self.num_anchors * 2, + kernel_size=(1, 1), + stride=1, + padding=0) + + def forward(self, x): + out = self.conv1x1(x) + out = out.permute(0, 2, 3, 1).contiguous() + + return out.view(out.shape[0], -1, 2) + + +class BboxHead(nn.Module): + + def __init__(self, inchannels=512, num_anchors=3): + super(BboxHead, self).__init__() + self.conv1x1 = nn.Conv2d( + inchannels, + num_anchors * 4, + kernel_size=(1, 1), + stride=1, + padding=0) + + def forward(self, x): + out = self.conv1x1(x) + out = out.permute(0, 2, 3, 1).contiguous() + + return out.view(out.shape[0], -1, 4) + + +class LandmarkHead(nn.Module): + + def __init__(self, 
inchannels=512, num_anchors=3): + super(LandmarkHead, self).__init__() + self.conv1x1 = nn.Conv2d( + inchannels, + num_anchors * 10, + kernel_size=(1, 1), + stride=1, + padding=0) + + def forward(self, x): + out = self.conv1x1(x) + out = out.permute(0, 2, 3, 1).contiguous() + + return out.view(out.shape[0], -1, 10) + + +class RetinaFace(nn.Module): + + def __init__(self, cfg=None): + """ + :param cfg: Network related settings. + """ + super(RetinaFace, self).__init__() + backbone = None + if cfg['name'] == 'Resnet50': + backbone = models.resnet50(pretrained=cfg['pretrain']) + else: + raise Exception('Invalid name') + + self.body = _utils.IntermediateLayerGetter(backbone, + cfg['return_layers']) + in_channels_stage2 = cfg['in_channel'] + in_channels_list = [ + in_channels_stage2 * 2, + in_channels_stage2 * 4, + in_channels_stage2 * 8, + ] + out_channels = cfg['out_channel'] + self.fpn = FPN(in_channels_list, out_channels) + self.ssh1 = SSH(out_channels, out_channels) + self.ssh2 = SSH(out_channels, out_channels) + self.ssh3 = SSH(out_channels, out_channels) + + self.ClassHead = self._make_class_head( + fpn_num=3, inchannels=cfg['out_channel']) + self.BboxHead = self._make_bbox_head( + fpn_num=3, inchannels=cfg['out_channel']) + self.LandmarkHead = self._make_landmark_head( + fpn_num=3, inchannels=cfg['out_channel']) + + def _make_class_head(self, fpn_num=3, inchannels=64, anchor_num=2): + classhead = nn.ModuleList() + for i in range(fpn_num): + classhead.append(ClassHead(inchannels, anchor_num)) + return classhead + + def _make_bbox_head(self, fpn_num=3, inchannels=64, anchor_num=2): + bboxhead = nn.ModuleList() + for i in range(fpn_num): + bboxhead.append(BboxHead(inchannels, anchor_num)) + return bboxhead + + def _make_landmark_head(self, fpn_num=3, inchannels=64, anchor_num=2): + landmarkhead = nn.ModuleList() + for i in range(fpn_num): + landmarkhead.append(LandmarkHead(inchannels, anchor_num)) + return landmarkhead + + def forward(self, inputs): + out = self.body(inputs) + + # FPN + fpn = self.fpn(out) + + # SSH + feature1 = self.ssh1(fpn[0]) + feature2 = self.ssh2(fpn[1]) + feature3 = self.ssh3(fpn[2]) + features = [feature1, feature2, feature3] + + bbox_regressions = torch.cat( + [self.BboxHead[i](feature) for i, feature in enumerate(features)], + dim=1) + classifications = torch.cat( + [self.ClassHead[i](feature) for i, feature in enumerate(features)], + dim=1) + ldm_regressions = torch.cat( + [self.LandmarkHead[i](feat) for i, feat in enumerate(features)], + dim=1) + + output = (bbox_regressions, F.softmax(classifications, + dim=-1), ldm_regressions) + return output diff --git a/modelscope/models/cv/image_portrait_enhancement/retinaface/utils.py b/modelscope/models/cv/image_portrait_enhancement/retinaface/utils.py new file mode 100755 index 00000000..60c9e2dd --- /dev/null +++ b/modelscope/models/cv/image_portrait_enhancement/retinaface/utils.py @@ -0,0 +1,123 @@ +# -------------------------------------------------------- +# Modified from https://github.com/biubug6/Pytorch_Retinaface +# -------------------------------------------------------- + +from itertools import product as product +from math import ceil + +import numpy as np +import torch + + +class PriorBox(object): + + def __init__(self, cfg, image_size=None, phase='train'): + super(PriorBox, self).__init__() + self.min_sizes = cfg['min_sizes'] + self.steps = cfg['steps'] + self.clip = cfg['clip'] + self.image_size = image_size + self.feature_maps = [[ + ceil(self.image_size[0] / step), + ceil(self.image_size[1] / step) + ] 
for step in self.steps] + self.name = 's' + + def forward(self): + anchors = [] + for k, f in enumerate(self.feature_maps): + min_sizes = self.min_sizes[k] + for i, j in product(range(f[0]), range(f[1])): + for min_size in min_sizes: + s_kx = min_size / self.image_size[1] + s_ky = min_size / self.image_size[0] + dense_cx = [ + x * self.steps[k] / self.image_size[1] + for x in [j + 0.5] + ] + dense_cy = [ + y * self.steps[k] / self.image_size[0] + for y in [i + 0.5] + ] + for cy, cx in product(dense_cy, dense_cx): + anchors += [cx, cy, s_kx, s_ky] + + # back to torch land + output = torch.Tensor(anchors).view(-1, 4) + if self.clip: + output.clamp_(max=1, min=0) + return output + + +def py_cpu_nms(dets, thresh): + """Pure Python NMS baseline.""" + x1 = dets[:, 0] + y1 = dets[:, 1] + x2 = dets[:, 2] + y2 = dets[:, 3] + scores = dets[:, 4] + + areas = (x2 - x1 + 1) * (y2 - y1 + 1) + order = scores.argsort()[::-1] + + keep = [] + while order.size > 0: + i = order[0] + keep.append(i) + xx1 = np.maximum(x1[i], x1[order[1:]]) + yy1 = np.maximum(y1[i], y1[order[1:]]) + xx2 = np.minimum(x2[i], x2[order[1:]]) + yy2 = np.minimum(y2[i], y2[order[1:]]) + + w = np.maximum(0.0, xx2 - xx1 + 1) + h = np.maximum(0.0, yy2 - yy1 + 1) + inter = w * h + ovr = inter / (areas[i] + areas[order[1:]] - inter) + + inds = np.where(ovr <= thresh)[0] + order = order[inds + 1] + + return keep + + +# Adapted from https://github.com/Hakuyume/chainer-ssd +def decode(loc, priors, variances): + """Decode locations from predictions using priors to undo + the encoding we did for offset regression at train time. + Args: + loc (tensor): location predictions for loc layers, + Shape: [num_priors,4] + priors (tensor): Prior boxes in center-offset form. + Shape: [num_priors,4]. + variances: (list[float]) Variances of priorboxes + Return: + decoded bounding box predictions + """ + + boxes = torch.cat( + (priors[:, :2] + loc[:, :2] * variances[0] * priors[:, 2:], + priors[:, 2:] * torch.exp(loc[:, 2:] * variances[1])), 1) + boxes[:, :2] -= boxes[:, 2:] / 2 + boxes[:, 2:] += boxes[:, :2] + return boxes + + +def decode_landm(pre, priors, variances): + """Decode landm from predictions using priors to undo + the encoding we did for offset regression at train time. + Args: + pre (tensor): landm predictions for loc layers, + Shape: [num_priors,10] + priors (tensor): Prior boxes in center-offset form. + Shape: [num_priors,4]. 
+ variances: (list[float]) Variances of priorboxes + Return: + decoded landm predictions + """ + a = priors[:, :2] + pre[:, :2] * variances[0] * priors[:, 2:] + b = priors[:, :2] + pre[:, 2:4] * variances[0] * priors[:, 2:] + c = priors[:, :2] + pre[:, 4:6] * variances[0] * priors[:, 2:] + d = priors[:, :2] + pre[:, 6:8] * variances[0] * priors[:, 2:] + e = priors[:, :2] + pre[:, 8:10] * variances[0] * priors[:, 2:] + landms = torch.cat((a, b, c, d, e), dim=1) + return landms diff --git a/modelscope/outputs.py b/modelscope/outputs.py index 20254416..e4d7e373 100644 --- a/modelscope/outputs.py +++ b/modelscope/outputs.py @@ -137,6 +137,7 @@ TASK_OUTPUTS = { Tasks.image_colorization: [OutputKeys.OUTPUT_IMG], Tasks.image_color_enhancement: [OutputKeys.OUTPUT_IMG], Tasks.image_denoising: [OutputKeys.OUTPUT_IMG], + Tasks.image_portrait_enhancement: [OutputKeys.OUTPUT_IMG], # image generation task result for a single image # {"output_img": np.array with shape (h, w, 3)} diff --git a/modelscope/pipelines/builder.py b/modelscope/pipelines/builder.py index 4cf1924b..14a3de1e 100644 --- a/modelscope/pipelines/builder.py +++ b/modelscope/pipelines/builder.py @@ -110,6 +110,9 @@ DEFAULT_MODEL_FOR_PIPELINE = { 'damo/cv_gan_face-image-generation'), Tasks.image_super_resolution: (Pipelines.image_super_resolution, 'damo/cv_rrdb_image-super-resolution'), + Tasks.image_portrait_enhancement: + (Pipelines.image_portrait_enhancement, + 'damo/cv_gpen_image-portrait-enhancement'), Tasks.product_retrieval_embedding: (Pipelines.product_retrieval_embedding, 'damo/cv_resnet50_product-bag-embedding-models'), diff --git a/modelscope/pipelines/cv/__init__.py b/modelscope/pipelines/cv/__init__.py index 38593a65..3c7f6092 100644 --- a/modelscope/pipelines/cv/__init__.py +++ b/modelscope/pipelines/cv/__init__.py @@ -11,14 +11,15 @@ if TYPE_CHECKING: from .face_detection_pipeline import FaceDetectionPipeline from .face_recognition_pipeline import FaceRecognitionPipeline from .face_image_generation_pipeline import FaceImageGenerationPipeline - from .image_classification_pipeline import ImageClassificationPipeline from .image_cartoon_pipeline import ImageCartoonPipeline from .image_classification_pipeline import GeneralImageClassificationPipeline - from .image_denoise_pipeline import ImageDenoisePipeline from .image_color_enhance_pipeline import ImageColorEnhancePipeline from .image_colorization_pipeline import ImageColorizationPipeline + from .image_classification_pipeline import ImageClassificationPipeline + from .image_denoise_pipeline import ImageDenoisePipeline from .image_instance_segmentation_pipeline import ImageInstanceSegmentationPipeline from .image_matting_pipeline import ImageMattingPipeline + from .image_portrait_enhancement_pipeline import ImagePortraitEnhancementPipeline from .image_style_transfer_pipeline import ImageStyleTransferPipeline from .image_super_resolution_pipeline import ImageSuperResolutionPipeline from .image_to_image_generate_pipeline import Image2ImageGenerationePipeline @@ -46,6 +47,8 @@ else: 'image_instance_segmentation_pipeline': ['ImageInstanceSegmentationPipeline'], 'image_matting_pipeline': ['ImageMattingPipeline'], + 'image_portrait_enhancement_pipeline': + ['ImagePortraitEnhancementPipeline'], 'image_style_transfer_pipeline': ['ImageStyleTransferPipeline'], 'image_super_resolution_pipeline': ['ImageSuperResolutionPipeline'], 'image_to_image_translation_pipeline': diff --git a/modelscope/pipelines/cv/image_portrait_enhancement_pipeline.py 
b/modelscope/pipelines/cv/image_portrait_enhancement_pipeline.py new file mode 100644 index 00000000..de012221 --- /dev/null +++ b/modelscope/pipelines/cv/image_portrait_enhancement_pipeline.py @@ -0,0 +1,216 @@ +import math +from typing import Any, Dict + +import cv2 +import numpy as np +import PIL +import torch +from scipy.ndimage import gaussian_filter +from scipy.spatial.distance import pdist, squareform + +from modelscope.metainfo import Pipelines +from modelscope.models.cv.image_portrait_enhancement import gpen +from modelscope.models.cv.image_portrait_enhancement.align_faces import ( + get_reference_facial_points, warp_and_crop_face) +from modelscope.models.cv.image_portrait_enhancement.eqface import fqa +from modelscope.models.cv.image_portrait_enhancement.retinaface import \ + detection +from modelscope.models.cv.super_resolution import rrdbnet_arch +from modelscope.outputs import OutputKeys +from modelscope.pipelines.base import Input, Pipeline +from modelscope.pipelines.builder import PIPELINES +from modelscope.preprocessors import LoadImage, load_image +from modelscope.utils.constant import ModelFile, Tasks +from modelscope.utils.logger import get_logger + +logger = get_logger() + + +@PIPELINES.register_module( + Tasks.image_portrait_enhancement, + module_name=Pipelines.image_portrait_enhancement) +class ImagePortraitEnhancementPipeline(Pipeline): + + def __init__(self, model: str, **kwargs): + """ + use `model` to create a kws pipeline for prediction + Args: + model: model id on modelscope hub. + """ + super().__init__(model=model, **kwargs) + if torch.cuda.is_available(): + self.device = torch.device('cuda') + else: + self.device = torch.device('cpu') + self.use_sr = True + + self.size = 512 + self.n_mlp = 8 + self.channel_multiplier = 2 + self.narrow = 1 + self.face_enhancer = gpen.FullGenerator( + self.size, + 512, + self.n_mlp, + self.channel_multiplier, + narrow=self.narrow).to(self.device) + + gpen_model_path = f'{model}/{ModelFile.TORCH_MODEL_FILE}' + self.face_enhancer.load_state_dict( + torch.load(gpen_model_path), strict=True) + + logger.info('load face enhancer model done') + + self.threshold = 0.9 + detector_model_path = f'{model}/face_detection/RetinaFace-R50.pth' + self.face_detector = detection.RetinaFaceDetection( + detector_model_path, self.device) + + logger.info('load face detector model done') + + self.num_feat = 32 + self.num_block = 23 + self.scale = 2 + self.sr_model = rrdbnet_arch.RRDBNet( + num_in_ch=3, + num_out_ch=3, + num_feat=self.num_feat, + num_block=self.num_block, + num_grow_ch=32, + scale=self.scale).to(self.device) + + sr_model_path = f'{model}/super_resolution/realesrnet_x{self.scale}.pth' + self.sr_model.load_state_dict( + torch.load(sr_model_path)['params_ema'], strict=True) + + logger.info('load sr model done') + + self.fqa_thres = 0.1 + self.id_thres = 0.15 + self.alpha = 1.0 + backbone_model_path = f'{model}/face_quality/eqface_backbone.pth' + fqa_model_path = f'{model}/face_quality/eqface_quality.pth' + self.eqface = fqa.FQA(backbone_model_path, fqa_model_path, self.device) + + logger.info('load fqa model done') + + # the mask for pasting restored faces back + self.mask = np.zeros((512, 512, 3), np.float32) + cv2.rectangle(self.mask, (26, 26), (486, 486), (1, 1, 1), -1, + cv2.LINE_AA) + self.mask = cv2.GaussianBlur(self.mask, (101, 101), 4) + self.mask = cv2.GaussianBlur(self.mask, (101, 101), 4) + + def enhance_face(self, img): + img = cv2.resize(img, (self.size, self.size)) + img_t = self.img2tensor(img) + + self.face_enhancer.eval() 
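+        # inference only: the generator stays in eval mode and no gradients are tracked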
+ with torch.no_grad(): + out, __ = self.face_enhancer(img_t) + del img_t + + out = self.tensor2img(out) + + return out + + def img2tensor(self, img, is_norm=True): + img_t = torch.from_numpy(img).to(self.device) / 255. + if is_norm: + img_t = (img_t - 0.5) / 0.5 + img_t = img_t.permute(2, 0, 1).unsqueeze(0).flip(1) # BGR->RGB + return img_t + + def tensor2img(self, img_t, pmax=255.0, is_denorm=True, imtype=np.uint8): + if is_denorm: + img_t = img_t * 0.5 + 0.5 + img_t = img_t.squeeze(0).permute(1, 2, 0).flip(2) # RGB->BGR + img_np = np.clip(img_t.float().cpu().numpy(), 0, 1) * pmax + + return img_np.astype(imtype) + + def preprocess(self, input: Input) -> Dict[str, Any]: + img = LoadImage.convert_to_ndarray(input) + + img_sr = None + if self.use_sr: + self.sr_model.eval() + with torch.no_grad(): + img_t = self.img2tensor(img, is_norm=False) + img_out = self.sr_model(img_t) + + img_sr = img_out.squeeze(0).permute(1, 2, 0).flip(2).cpu().clamp_( + 0, 1).numpy() + img_sr = (img_sr * 255.0).round().astype(np.uint8) + + img = cv2.resize(img, img_sr.shape[:2][::-1]) + + result = {'img': img, 'img_sr': img_sr} + return result + + def forward(self, input: Dict[str, Any]) -> Dict[str, Any]: + img, img_sr = input['img'], input['img_sr'] + img, img_sr = img.cpu().numpy(), img_sr.cpu().numpy() + facebs, landms = self.face_detector.detect(img) + + height, width = img.shape[:2] + full_mask = np.zeros(img.shape, dtype=np.float32) + full_img = np.zeros(img.shape, dtype=np.uint8) + + for i, (faceb, facial5points) in enumerate(zip(facebs, landms)): + if faceb[4] < self.threshold: + continue + # fh, fw = (faceb[3] - faceb[1]), (faceb[2] - faceb[0]) + + facial5points = np.reshape(facial5points, (2, 5)) + + of, of_112, tfm_inv = warp_and_crop_face( + img, facial5points, crop_size=(self.size, self.size)) + + # detect orig face quality + fq_o, fea_o = self.eqface.get_face_quality(of_112) + if fq_o < self.fqa_thres: + continue + + # enhance the face + ef = self.enhance_face(of) + + # detect enhanced face quality + ss = self.size // 256 + ef_112 = cv2.resize(ef[35 * ss:-33 * ss, 32 * ss:-36 * ss], + (112, 112)) # crop roi + fq_e, fea_e = self.eqface.get_face_quality(ef_112) + dist = squareform(pdist([fea_o, fea_e], 'cosine')).mean() + if dist > self.id_thres: + continue + + # blending parameter + fq = max(1., (fq_o - self.fqa_thres)) + fq = (1 - 2 * dist) * (1.0 / (1 + math.exp(-(2 * fq - 1)))) + + # blend face + ef = cv2.addWeighted(ef, fq * self.alpha, of, 1 - fq * self.alpha, + 0.0) + + tmp_mask = self.mask + tmp_mask = cv2.resize(tmp_mask, ef.shape[:2]) + tmp_mask = cv2.warpAffine( + tmp_mask, tfm_inv, (width, height), flags=3) + + tmp_img = cv2.warpAffine(ef, tfm_inv, (width, height), flags=3) + + mask = np.clip(tmp_mask - full_mask, 0, 1) + full_mask[np.where(mask > 0)] = tmp_mask[np.where(mask > 0)] + full_img[np.where(mask > 0)] = tmp_img[np.where(mask > 0)] + + if self.use_sr and img_sr is not None: + out_img = cv2.convertScaleAbs(img_sr * (1 - full_mask) + + full_img * full_mask) + else: + out_img = cv2.convertScaleAbs(img * (1 - full_mask) + + full_img * full_mask) + + return {OutputKeys.OUTPUT_IMG: out_img.astype(np.uint8)} + + def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: + return inputs diff --git a/modelscope/preprocessors/image.py b/modelscope/preprocessors/image.py index f3007f95..775514a2 100644 --- a/modelscope/preprocessors/image.py +++ b/modelscope/preprocessors/image.py @@ -163,6 +163,32 @@ class ImageDenoisePreprocessor(Preprocessor): return data 
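+# Note: the portrait-enhancement preprocessor below is a pass-through; image
+# loading, super-resolution and resizing are handled in the pipeline itself.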
+@PREPROCESSORS.register_module( + Fields.cv, + module_name=Preprocessors.image_portrait_enhancement_preprocessor) +class ImagePortraitEnhancementPreprocessor(Preprocessor): + + def __init__(self, model_dir: str, *args, **kwargs): + """ + + Args: + model_dir (str): model path + """ + super().__init__(*args, **kwargs) + self.model_dir: str = model_dir + + def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]: + """process the raw input data + + Args: + data Dict[str, Any] + + Returns: + Dict[str, Any]: the preprocessed data + """ + return data + + @PREPROCESSORS.register_module( Fields.cv, module_name=Preprocessors.image_instance_segmentation_preprocessor) diff --git a/modelscope/trainers/__init__.py b/modelscope/trainers/__init__.py index f32a33c6..350bab61 100644 --- a/modelscope/trainers/__init__.py +++ b/modelscope/trainers/__init__.py @@ -1,6 +1,7 @@ from .base import DummyTrainer from .builder import build_trainer -from .cv import ImageInstanceSegmentationTrainer +from .cv import (ImageInstanceSegmentationTrainer, + ImagePortraitEnhancementTrainer) from .multi_modal import CLIPTrainer from .nlp import SequenceClassificationTrainer from .trainer import EpochBasedTrainer diff --git a/modelscope/trainers/cv/__init__.py b/modelscope/trainers/cv/__init__.py index 07b1646d..36d64af7 100644 --- a/modelscope/trainers/cv/__init__.py +++ b/modelscope/trainers/cv/__init__.py @@ -1,2 +1,3 @@ from .image_instance_segmentation_trainer import \ ImageInstanceSegmentationTrainer +from .image_portrait_enhancement_trainer import ImagePortraitEnhancementTrainer diff --git a/modelscope/trainers/cv/image_portrait_enhancement_trainer.py b/modelscope/trainers/cv/image_portrait_enhancement_trainer.py new file mode 100644 index 00000000..67c94213 --- /dev/null +++ b/modelscope/trainers/cv/image_portrait_enhancement_trainer.py @@ -0,0 +1,148 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from collections.abc import Mapping + +import torch +from torch import distributed as dist + +from modelscope.trainers.builder import TRAINERS +from modelscope.trainers.optimizer.builder import build_optimizer +from modelscope.trainers.trainer import EpochBasedTrainer +from modelscope.utils.constant import ModeKeys +from modelscope.utils.logger import get_logger + + +@TRAINERS.register_module(module_name='gpen') +class ImagePortraitEnhancementTrainer(EpochBasedTrainer): + + def train_step(self, model, inputs): + """ Perform a training step on a batch of inputs. + + Subclass and override to inject custom behavior. + + Args: + model (`TorchModel`): The model to train. + inputs (`Dict[str, Union[torch.Tensor, Any]]`): + The inputs and targets of the model. + + The dictionary will be unpacked before being fed to the model. Most models expect the targets under the + argument `labels`. Check your model's documentation for all accepted arguments. + + Return: + `torch.Tensor`: The tensor with training loss on this batch. 
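+
+            Note: each call alternates a discriminator update (with lazy R1
+            regularization every `d_reg_every` iterations) and a generator
+            update (with path-length regularization every `g_reg_every`
+            iterations), following the StyleGAN2-style schedule that GPEN uses.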
+ """ + # EvaluationHook will do evaluate and change mode to val, return to train mode + # TODO: find more pretty way to change mode + self.d_reg_every = self.cfg.train.get('d_reg_every', 16) + self.g_reg_every = self.cfg.train.get('g_reg_every', 4) + self.path_regularize = self.cfg.train.get('path_regularize', 2) + self.r1 = self.cfg.train.get('r1', 10) + + train_outputs = dict() + self._mode = ModeKeys.TRAIN + inputs = self.collate_fn(inputs) + # call model forward but not __call__ to skip postprocess + if isinstance(inputs, Mapping): + d_loss = model._train_forward_d(**inputs) + else: + d_loss = model._train_forward_d(inputs) + train_outputs['d_loss'] = d_loss + + model.discriminator.zero_grad() + d_loss.backward() + self.optimizer_d.step() + + if self._iter % self.d_reg_every == 0: + + if isinstance(inputs, Mapping): + r1_loss = model._train_forward_d_r1(**inputs) + else: + r1_loss = model._train_forward_d_r1(inputs) + train_outputs['r1_loss'] = r1_loss + + model.discriminator.zero_grad() + (self.r1 / 2 * r1_loss * self.d_reg_every).backward() + + self.optimizer_d.step() + + if isinstance(inputs, Mapping): + g_loss = model._train_forward_g(**inputs) + else: + g_loss = model._train_forward_g(inputs) + train_outputs['g_loss'] = g_loss + + model.generator.zero_grad() + g_loss.backward() + self.optimizer.step() + + path_loss = 0 + if self._iter % self.g_reg_every == 0: + if isinstance(inputs, Mapping): + path_loss = model._train_forward_g_path(**inputs) + else: + path_loss = model._train_forward_g_path(inputs) + train_outputs['path_loss'] = path_loss + + model.generator.zero_grad() + weighted_path_loss = self.path_regularize * self.g_reg_every * path_loss + + weighted_path_loss.backward() + + self.optimizer.step() + + model.accumulate() + + if not isinstance(train_outputs, dict): + raise TypeError('"model.forward()" must return a dict') + + # add model output info to log + if 'log_vars' not in train_outputs: + default_keys_pattern = ['loss'] + match_keys = set([]) + for key_p in default_keys_pattern: + match_keys.update( + [key for key in train_outputs.keys() if key_p in key]) + + log_vars = {} + for key in match_keys: + value = train_outputs.get(key, None) + if value is not None: + if dist.is_available() and dist.is_initialized(): + value = value.data.clone() + dist.all_reduce(value.div_(dist.get_world_size())) + log_vars.update({key: value.item()}) + self.log_buffer.update(log_vars) + else: + self.log_buffer.update(train_outputs['log_vars']) + + self.train_outputs = train_outputs + + def create_optimizer_and_scheduler(self): + """ Create optimizer and lr scheduler + + We provide a default implementation, if you want to customize your own optimizer + and lr scheduler, you can either pass a tuple through trainer init function or + subclass this class and override this method. 
+ + + """ + optimizer, lr_scheduler = self.optimizers + if optimizer is None: + optimizer_cfg = self.cfg.train.get('optimizer', None) + else: + optimizer_cfg = None + optimizer_d_cfg = self.cfg.train.get('optimizer_d', None) + + optim_options = {} + if optimizer_cfg is not None: + optim_options = optimizer_cfg.pop('options', {}) + optimizer = build_optimizer( + self.model.generator, cfg=optimizer_cfg) + if optimizer_d_cfg is not None: + optimizer_d = build_optimizer( + self.model.discriminator, cfg=optimizer_d_cfg) + + lr_options = {} + self.optimizer = optimizer + self.lr_scheduler = lr_scheduler + self.optimizer_d = optimizer_d + return self.optimizer, self.lr_scheduler, optim_options, lr_options diff --git a/modelscope/trainers/hooks/__init__.py b/modelscope/trainers/hooks/__init__.py index 6a581759..ff55da09 100644 --- a/modelscope/trainers/hooks/__init__.py +++ b/modelscope/trainers/hooks/__init__.py @@ -14,5 +14,5 @@ __all__ = [ 'Hook', 'HOOKS', 'CheckpointHook', 'EvaluationHook', 'LrSchedulerHook', 'OptimizerHook', 'Priority', 'build_hook', 'TextLoggerHook', 'IterTimerHook', 'TorchAMPOptimizerHook', 'ApexAMPOptimizerHook', - 'BestCkptSaverHook' + 'BestCkptSaverHook', 'NoneOptimizerHook', 'NoneLrSchedulerHook' ] diff --git a/modelscope/trainers/hooks/lr_scheduler_hook.py b/modelscope/trainers/hooks/lr_scheduler_hook.py index d84ca3b9..cf3a16e7 100644 --- a/modelscope/trainers/hooks/lr_scheduler_hook.py +++ b/modelscope/trainers/hooks/lr_scheduler_hook.py @@ -115,3 +115,18 @@ class PlateauLrSchedulerHook(LrSchedulerHook): self.warmup_lr_scheduler.step(metrics=metrics) else: trainer.lr_scheduler.step(metrics=metrics) + + +@HOOKS.register_module() +class NoneLrSchedulerHook(LrSchedulerHook): + + PRIORITY = Priority.LOW # should be after EvaluationHook + + def __init__(self, by_epoch=True, warmup=None) -> None: + super().__init__(by_epoch=by_epoch, warmup=warmup) + + def before_run(self, trainer): + return + + def after_train_epoch(self, trainer): + return diff --git a/modelscope/trainers/hooks/optimizer_hook.py b/modelscope/trainers/hooks/optimizer_hook.py index 32d58f40..294a06a6 100644 --- a/modelscope/trainers/hooks/optimizer_hook.py +++ b/modelscope/trainers/hooks/optimizer_hook.py @@ -200,3 +200,19 @@ class ApexAMPOptimizerHook(OptimizerHook): trainer.optimizer.step() trainer.optimizer.zero_grad() + + +@HOOKS.register_module() +class NoneOptimizerHook(OptimizerHook): + + def __init__(self, cumulative_iters=1, grad_clip=None, loss_keys='loss'): + + super(NoneOptimizerHook, self).__init__( + grad_clip=grad_clip, loss_keys=loss_keys) + self.cumulative_iters = cumulative_iters + + def before_run(self, trainer): + return + + def after_train_iter(self, trainer): + return diff --git a/modelscope/utils/constant.py b/modelscope/utils/constant.py index 20311fba..ef3c4f4f 100644 --- a/modelscope/utils/constant.py +++ b/modelscope/utils/constant.py @@ -43,6 +43,7 @@ class CVTasks(object): image_colorization = 'image-colorization' image_color_enhancement = 'image-color-enhancement' image_denoising = 'image-denoising' + image_portrait_enhancement = 'image-portrait-enhancement' # image generation image_to_image_translation = 'image-to-image-translation' diff --git a/tests/pipelines/test_image_portrait_enhancement.py b/tests/pipelines/test_image_portrait_enhancement.py new file mode 100644 index 00000000..64b84db6 --- /dev/null +++ b/tests/pipelines/test_image_portrait_enhancement.py @@ -0,0 +1,43 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. 
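+# Basic inference tests for the GPEN based image-portrait-enhancement pipeline.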
+import os +import os.path as osp +import unittest + +import cv2 + +from modelscope.msdatasets import MsDataset +from modelscope.outputs import OutputKeys +from modelscope.pipelines import pipeline +from modelscope.pipelines.base import Pipeline +from modelscope.utils.constant import Tasks +from modelscope.utils.test_utils import test_level + + +class ImagePortraitEnhancementTest(unittest.TestCase): + + def setUp(self) -> None: + self.model_id = 'damo/cv_gpen_image-portrait-enhancement' + self.test_image = 'data/test/images/Solvay_conference_1927.png' + + def pipeline_inference(self, pipeline: Pipeline, test_image: str): + result = pipeline(test_image) + if result is not None: + cv2.imwrite('result.png', result[OutputKeys.OUTPUT_IMG]) + print(f'Output written to {osp.abspath("result.png")}') + else: + raise Exception('Testing failed: invalid output') + + @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') + def test_run_modelhub(self): + face_enhancement = pipeline( + Tasks.image_portrait_enhancement, model=self.model_id) + self.pipeline_inference(face_enhancement, self.test_image) + + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') + def test_run_modelhub_default_model(self): + face_enhancement = pipeline(Tasks.image_portrait_enhancement) + self.pipeline_inference(face_enhancement, self.test_image) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/trainers/test_image_portrait_enhancement_trainer.py b/tests/trainers/test_image_portrait_enhancement_trainer.py new file mode 100644 index 00000000..3de78347 --- /dev/null +++ b/tests/trainers/test_image_portrait_enhancement_trainer.py @@ -0,0 +1,119 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import os +import os.path as osp +import shutil +import tempfile +import unittest +from typing import Callable, List, Optional, Tuple, Union + +import cv2 +import torch +from torch.utils import data as data + +from modelscope.hub.snapshot_download import snapshot_download +from modelscope.models.cv.image_portrait_enhancement import \ + ImagePortraitEnhancement +from modelscope.trainers import build_trainer +from modelscope.utils.constant import ModelFile +from modelscope.utils.test_utils import test_level + + +class TestImagePortraitEnhancementTrainer(unittest.TestCase): + + def setUp(self): + print(('Testing %s.%s' % (type(self).__name__, self._testMethodName))) + self.tmp_dir = tempfile.TemporaryDirectory().name + if not os.path.exists(self.tmp_dir): + os.makedirs(self.tmp_dir) + + self.model_id = 'damo/cv_gpen_image-portrait-enhancement' + + class PairedImageDataset(data.Dataset): + + def __init__(self, root, size=512): + super(PairedImageDataset, self).__init__() + self.size = size + gt_dir = osp.join(root, 'gt') + lq_dir = osp.join(root, 'lq') + self.gt_filelist = os.listdir(gt_dir) + self.gt_filelist = sorted( + self.gt_filelist, key=lambda x: int(x[:-4])) + self.gt_filelist = [ + osp.join(gt_dir, f) for f in self.gt_filelist + ] + self.lq_filelist = os.listdir(lq_dir) + self.lq_filelist = sorted( + self.lq_filelist, key=lambda x: int(x[:-4])) + self.lq_filelist = [ + osp.join(lq_dir, f) for f in self.lq_filelist + ] + + def _img_to_tensor(self, img): + img = torch.from_numpy(img[:, :, [2, 1, 0]]).permute( + 2, 0, 1).type(torch.float32) / 255. 
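+                # pixels are scaled to [0, 1] above and shifted to [-1, 1] here,
+                # matching the normalization used in the inference pipeline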
+ return (img - 0.5) / 0.5 + + def __getitem__(self, index): + lq = cv2.imread(self.lq_filelist[index]) + gt = cv2.imread(self.gt_filelist[index]) + lq = cv2.resize( + lq, (self.size, self.size), interpolation=cv2.INTER_CUBIC) + gt = cv2.resize( + gt, (self.size, self.size), interpolation=cv2.INTER_CUBIC) + + return \ + {'src': self._img_to_tensor(lq), 'target': self._img_to_tensor(gt)} + + def __len__(self): + return len(self.gt_filelist) + + def to_torch_dataset(self, + columns: Union[str, List[str]] = None, + preprocessors: Union[Callable, + List[Callable]] = None, + **format_kwargs): + # self.preprocessor = preprocessors + return self + + self.dataset = PairedImageDataset( + './data/test/images/face_enhancement/') + + def tearDown(self): + shutil.rmtree(self.tmp_dir, ignore_errors=True) + super().tearDown() + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_trainer(self): + kwargs = dict( + model=self.model_id, + train_dataset=self.dataset, + eval_dataset=self.dataset, + device='gpu', + work_dir=self.tmp_dir) + + trainer = build_trainer(name='gpen', default_args=kwargs) + trainer.train() + + @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') + def test_trainer_with_model_and_args(self): + tmp_dir = tempfile.TemporaryDirectory().name + if not os.path.exists(tmp_dir): + os.makedirs(tmp_dir) + + cache_path = snapshot_download(self.model_id) + model = ImagePortraitEnhancement.from_pretrained(cache_path) + kwargs = dict( + cfg_file=os.path.join(cache_path, ModelFile.CONFIGURATION), + model=model, + train_dataset=self.dataset, + eval_dataset=self.dataset, + device='gpu', + max_epochs=2, + work_dir=self.tmp_dir) + + trainer = build_trainer(name='gpen', default_args=kwargs) + trainer.train() + + +if __name__ == '__main__': + unittest.main()