Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9590794master
| @@ -0,0 +1,3 @@ | |||||
| version https://git-lfs.github.com/spec/v1 | |||||
| oid sha256:fa8ab905e8374a0f94b4bfbfc81da14e762c71eaf64bae85bdd03b07cdf884c2 | |||||
| size 859206 | |||||
| @@ -0,0 +1,3 @@ | |||||
| version https://git-lfs.github.com/spec/v1 | |||||
| oid sha256:8cd14710143ba1a912e3ef574d0bf71c7e40bf9897522cba07ecae2567343064 | |||||
| size 850603 | |||||
| @@ -0,0 +1,3 @@ | |||||
| version https://git-lfs.github.com/spec/v1 | |||||
| oid sha256:d7f166ecb3a6913dbd05a1eb271399cbaa731d1074ac03184c13ae245ca66819 | |||||
| size 800380 | |||||
| @@ -0,0 +1,3 @@ | |||||
| version https://git-lfs.github.com/spec/v1 | |||||
| oid sha256:e95d11661485fc0e6f326398f953459dcb3e65b7f4a6c892611266067cf8fe3a | |||||
| size 245773 | |||||
| @@ -0,0 +1,3 @@ | |||||
| version https://git-lfs.github.com/spec/v1 | |||||
| oid sha256:03972400b20b3e6f1d056b359d9c9f12952653a67a73b36018504ce9ee9edf9d | |||||
| size 254261 | |||||
| @@ -16,6 +16,7 @@ class Models(object): | |||||
| nafnet = 'nafnet' | nafnet = 'nafnet' | ||||
| csrnet = 'csrnet' | csrnet = 'csrnet' | ||||
| cascade_mask_rcnn_swin = 'cascade_mask_rcnn_swin' | cascade_mask_rcnn_swin = 'cascade_mask_rcnn_swin' | ||||
| gpen = 'gpen' | |||||
| product_retrieval_embedding = 'product-retrieval-embedding' | product_retrieval_embedding = 'product-retrieval-embedding' | ||||
| # nlp models | # nlp models | ||||
| @@ -91,6 +92,7 @@ class Pipelines(object): | |||||
| image2image_translation = 'image-to-image-translation' | image2image_translation = 'image-to-image-translation' | ||||
| live_category = 'live-category' | live_category = 'live-category' | ||||
| video_category = 'video-category' | video_category = 'video-category' | ||||
| image_portrait_enhancement = 'gpen-image-portrait-enhancement' | |||||
| image_to_image_generation = 'image-to-image-generation' | image_to_image_generation = 'image-to-image-generation' | ||||
| # nlp tasks | # nlp tasks | ||||
| @@ -160,6 +162,7 @@ class Preprocessors(object): | |||||
| image_denoie_preprocessor = 'image-denoise-preprocessor' | image_denoie_preprocessor = 'image-denoise-preprocessor' | ||||
| image_color_enhance_preprocessor = 'image-color-enhance-preprocessor' | image_color_enhance_preprocessor = 'image-color-enhance-preprocessor' | ||||
| image_instance_segmentation_preprocessor = 'image-instance-segmentation-preprocessor' | image_instance_segmentation_preprocessor = 'image-instance-segmentation-preprocessor' | ||||
| image_portrait_enhancement_preprocessor = 'image-portrait-enhancement-preprocessor' | |||||
| # nlp preprocessor | # nlp preprocessor | ||||
| sen_sim_tokenizer = 'sen-sim-tokenizer' | sen_sim_tokenizer = 'sen-sim-tokenizer' | ||||
| @@ -207,3 +210,5 @@ class Metrics(object): | |||||
| text_gen_metric = 'text-gen-metric' | text_gen_metric = 'text-gen-metric' | ||||
| # metrics for image-color-enhance task | # metrics for image-color-enhance task | ||||
| image_color_enhance_metric = 'image-color-enhance-metric' | image_color_enhance_metric = 'image-color-enhance-metric' | ||||
| # metrics for image-portrait-enhancement task | |||||
| image_portrait_enhancement_metric = 'image-portrait-enhancement-metric' | |||||
| @@ -10,6 +10,7 @@ if TYPE_CHECKING: | |||||
| from .image_denoise_metric import ImageDenoiseMetric | from .image_denoise_metric import ImageDenoiseMetric | ||||
| from .image_instance_segmentation_metric import \ | from .image_instance_segmentation_metric import \ | ||||
| ImageInstanceSegmentationCOCOMetric | ImageInstanceSegmentationCOCOMetric | ||||
| from .image_portrait_enhancement_metric import ImagePortraitEnhancementMetric | |||||
| from .sequence_classification_metric import SequenceClassificationMetric | from .sequence_classification_metric import SequenceClassificationMetric | ||||
| from .text_generation_metric import TextGenerationMetric | from .text_generation_metric import TextGenerationMetric | ||||
| @@ -21,6 +22,8 @@ else: | |||||
| 'image_denoise_metric': ['ImageDenoiseMetric'], | 'image_denoise_metric': ['ImageDenoiseMetric'], | ||||
| 'image_instance_segmentation_metric': | 'image_instance_segmentation_metric': | ||||
| ['ImageInstanceSegmentationCOCOMetric'], | ['ImageInstanceSegmentationCOCOMetric'], | ||||
| 'image_portrait_enhancement_metric': | |||||
| ['ImagePortraitEnhancementMetric'], | |||||
| 'sequence_classification_metric': ['SequenceClassificationMetric'], | 'sequence_classification_metric': ['SequenceClassificationMetric'], | ||||
| 'text_generation_metric': ['TextGenerationMetric'], | 'text_generation_metric': ['TextGenerationMetric'], | ||||
| } | } | ||||
| @@ -23,7 +23,9 @@ task_default_metrics = { | |||||
| Tasks.sentiment_classification: [Metrics.seq_cls_metric], | Tasks.sentiment_classification: [Metrics.seq_cls_metric], | ||||
| Tasks.text_generation: [Metrics.text_gen_metric], | Tasks.text_generation: [Metrics.text_gen_metric], | ||||
| Tasks.image_denoising: [Metrics.image_denoise_metric], | Tasks.image_denoising: [Metrics.image_denoise_metric], | ||||
| Tasks.image_color_enhancement: [Metrics.image_color_enhance_metric] | |||||
| Tasks.image_color_enhancement: [Metrics.image_color_enhance_metric], | |||||
| Tasks.image_portrait_enhancement: | |||||
| [Metrics.image_portrait_enhancement_metric], | |||||
| } | } | ||||
| @@ -0,0 +1,47 @@ | |||||
| from typing import Dict | |||||
| import numpy as np | |||||
| from modelscope.metainfo import Metrics | |||||
| from modelscope.utils.registry import default_group | |||||
| from .base import Metric | |||||
| from .builder import METRICS, MetricKeys | |||||
| def calculate_psnr(img, img2): | |||||
| assert img.shape == img2.shape, ( | |||||
| f'Image shapes are different: {img.shape}, {img2.shape}.') | |||||
| img = img.astype(np.float64) | |||||
| img2 = img2.astype(np.float64) | |||||
| mse = np.mean((img - img2)**2) | |||||
| if mse == 0: | |||||
| return float('inf') | |||||
| return 10. * np.log10(255. * 255. / mse) | |||||
| @METRICS.register_module( | |||||
| group_key=default_group, | |||||
| module_name=Metrics.image_portrait_enhancement_metric) | |||||
| class ImagePortraitEnhancementMetric(Metric): | |||||
| """The metric for image-portrait-enhancement task. | |||||
| """ | |||||
| def __init__(self): | |||||
| self.preds = [] | |||||
| self.targets = [] | |||||
| def add(self, outputs: Dict, inputs: Dict): | |||||
| ground_truths = outputs['target'] | |||||
| eval_results = outputs['pred'] | |||||
| self.preds.extend(eval_results) | |||||
| self.targets.extend(ground_truths) | |||||
| def evaluate(self): | |||||
| psnrs = [ | |||||
| calculate_psnr(pred, target) | |||||
| for pred, target in zip(self.preds, self.targets) | |||||
| ] | |||||
| return {MetricKeys.PSNR: sum(psnrs) / len(psnrs)} | |||||
| @@ -3,6 +3,6 @@ from . import (action_recognition, animal_recognition, cartoon, | |||||
| cmdssl_video_embedding, face_detection, face_generation, | cmdssl_video_embedding, face_detection, face_generation, | ||||
| image_classification, image_color_enhance, image_colorization, | image_classification, image_color_enhance, image_colorization, | ||||
| image_denoise, image_instance_segmentation, | image_denoise, image_instance_segmentation, | ||||
| image_to_image_generation, image_to_image_translation, | |||||
| object_detection, product_retrieval_embedding, super_resolution, | |||||
| virual_tryon) | |||||
| image_portrait_enhancement, image_to_image_generation, | |||||
| image_to_image_translation, object_detection, | |||||
| product_retrieval_embedding, super_resolution, virual_tryon) | |||||
| @@ -0,0 +1,22 @@ | |||||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||||
| from typing import TYPE_CHECKING | |||||
| from modelscope.utils.import_utils import LazyImportModule | |||||
| if TYPE_CHECKING: | |||||
| from .image_portrait_enhancement import ImagePortraitEnhancement | |||||
| else: | |||||
| _import_structure = { | |||||
| 'image_portrait_enhancement': ['ImagePortraitEnhancement'] | |||||
| } | |||||
| import sys | |||||
| sys.modules[__name__] = LazyImportModule( | |||||
| __name__, | |||||
| globals()['__file__'], | |||||
| _import_structure, | |||||
| module_spec=__spec__, | |||||
| extra_objects={}, | |||||
| ) | |||||
| @@ -0,0 +1,252 @@ | |||||
| import cv2 | |||||
| import numpy as np | |||||
| from skimage import transform as trans | |||||
| from modelscope.utils.logger import get_logger | |||||
| logger = get_logger() | |||||
| # reference facial points, a list of coordinates (x,y) | |||||
| REFERENCE_FACIAL_POINTS = [[30.29459953, 51.69630051], | |||||
| [65.53179932, 51.50139999], | |||||
| [48.02519989, | |||||
| 71.73660278], [33.54930115, 92.3655014], | |||||
| [62.72990036, 92.20410156]] | |||||
| DEFAULT_CROP_SIZE = (96, 112) | |||||
| def _umeyama(src, dst, estimate_scale=True, scale=1.0): | |||||
| """Estimate N-D similarity transformation with or without scaling. | |||||
| Parameters | |||||
| ---------- | |||||
| src : (M, N) array | |||||
| Source coordinates. | |||||
| dst : (M, N) array | |||||
| Destination coordinates. | |||||
| estimate_scale : bool | |||||
| Whether to estimate scaling factor. | |||||
| Returns | |||||
| ------- | |||||
| T : (N + 1, N + 1) | |||||
| The homogeneous similarity transformation matrix. The matrix contains | |||||
| NaN values only if the problem is not well-conditioned. | |||||
| References | |||||
| ---------- | |||||
| .. [1] "Least-squares estimation of transformation parameters between two | |||||
| point patterns", Shinji Umeyama, PAMI 1991, :DOI:`10.1109/34.88573` | |||||
| """ | |||||
| num = src.shape[0] | |||||
| dim = src.shape[1] | |||||
| # Compute mean of src and dst. | |||||
| src_mean = src.mean(axis=0) | |||||
| dst_mean = dst.mean(axis=0) | |||||
| # Subtract mean from src and dst. | |||||
| src_demean = src - src_mean | |||||
| dst_demean = dst - dst_mean | |||||
| # Eq. (38). | |||||
| A = dst_demean.T @ src_demean / num | |||||
| # Eq. (39). | |||||
| d = np.ones((dim, ), dtype=np.double) | |||||
| if np.linalg.det(A) < 0: | |||||
| d[dim - 1] = -1 | |||||
| T = np.eye(dim + 1, dtype=np.double) | |||||
| U, S, V = np.linalg.svd(A) | |||||
| # Eq. (40) and (43). | |||||
| rank = np.linalg.matrix_rank(A) | |||||
| if rank == 0: | |||||
| return np.nan * T | |||||
| elif rank == dim - 1: | |||||
| if np.linalg.det(U) * np.linalg.det(V) > 0: | |||||
| T[:dim, :dim] = U @ V | |||||
| else: | |||||
| s = d[dim - 1] | |||||
| d[dim - 1] = -1 | |||||
| T[:dim, :dim] = U @ np.diag(d) @ V | |||||
| d[dim - 1] = s | |||||
| else: | |||||
| T[:dim, :dim] = U @ np.diag(d) @ V | |||||
| if estimate_scale: | |||||
| # Eq. (41) and (42). | |||||
| scale = 1.0 / src_demean.var(axis=0).sum() * (S @ d) | |||||
| else: | |||||
| scale = scale | |||||
| T[:dim, dim] = dst_mean - scale * (T[:dim, :dim] @ src_mean.T) | |||||
| T[:dim, :dim] *= scale | |||||
| return T, scale | |||||
| class FaceWarpException(Exception): | |||||
| def __str__(self): | |||||
| return 'In File {}:{}'.format(__file__, super.__str__(self)) | |||||
| def get_reference_facial_points(output_size=None, | |||||
| inner_padding_factor=0.0, | |||||
| outer_padding=(0, 0), | |||||
| default_square=False): | |||||
| ref_5pts = np.array(REFERENCE_FACIAL_POINTS) | |||||
| ref_crop_size = np.array(DEFAULT_CROP_SIZE) | |||||
| # 0) make the inner region a square | |||||
| if default_square: | |||||
| size_diff = max(ref_crop_size) - ref_crop_size | |||||
| ref_5pts += size_diff / 2 | |||||
| ref_crop_size += size_diff | |||||
| if (output_size and output_size[0] == ref_crop_size[0] | |||||
| and output_size[1] == ref_crop_size[1]): | |||||
| return ref_5pts | |||||
| if (inner_padding_factor == 0 and outer_padding == (0, 0)): | |||||
| if output_size is None: | |||||
| logger.info('No paddings to do: return default reference points') | |||||
| return ref_5pts | |||||
| else: | |||||
| raise FaceWarpException( | |||||
| 'No paddings to do, output_size must be None or {}'.format( | |||||
| ref_crop_size)) | |||||
| # check output size | |||||
| if not (0 <= inner_padding_factor <= 1.0): | |||||
| raise FaceWarpException('Not (0 <= inner_padding_factor <= 1.0)') | |||||
| if ((inner_padding_factor > 0 or outer_padding[0] > 0 | |||||
| or outer_padding[1] > 0) and output_size is None): | |||||
| output_size = ref_crop_size * (1 + inner_padding_factor * 2).astype( | |||||
| np.int32) | |||||
| output_size += np.array(outer_padding) | |||||
| logger.info('deduced from paddings, output_size = ', output_size) | |||||
| if not (outer_padding[0] < output_size[0] | |||||
| and outer_padding[1] < output_size[1]): | |||||
| raise FaceWarpException('Not (outer_padding[0] < output_size[0]' | |||||
| 'and outer_padding[1] < output_size[1])') | |||||
| # 1) pad the inner region according inner_padding_factor | |||||
| if inner_padding_factor > 0: | |||||
| size_diff = ref_crop_size * inner_padding_factor * 2 | |||||
| ref_5pts += size_diff / 2 | |||||
| ref_crop_size += np.round(size_diff).astype(np.int32) | |||||
| # 2) resize the padded inner region | |||||
| size_bf_outer_pad = np.array(output_size) - np.array(outer_padding) * 2 | |||||
| if size_bf_outer_pad[0] * ref_crop_size[1] != size_bf_outer_pad[ | |||||
| 1] * ref_crop_size[0]: | |||||
| raise FaceWarpException( | |||||
| 'Must have (output_size - outer_padding)' | |||||
| '= some_scale * (crop_size * (1.0 + inner_padding_factor)') | |||||
| scale_factor = size_bf_outer_pad[0].astype(np.float32) / ref_crop_size[0] | |||||
| ref_5pts = ref_5pts * scale_factor | |||||
| ref_crop_size = size_bf_outer_pad | |||||
| # 3) add outer_padding to make output_size | |||||
| reference_5point = ref_5pts + np.array(outer_padding) | |||||
| ref_crop_size = output_size | |||||
| return reference_5point | |||||
| def get_affine_transform_matrix(src_pts, dst_pts): | |||||
| tfm = np.float32([[1, 0, 0], [0, 1, 0]]) | |||||
| n_pts = src_pts.shape[0] | |||||
| ones = np.ones((n_pts, 1), src_pts.dtype) | |||||
| src_pts_ = np.hstack([src_pts, ones]) | |||||
| dst_pts_ = np.hstack([dst_pts, ones]) | |||||
| A, res, rank, s = np.linalg.lstsq(src_pts_, dst_pts_) | |||||
| if rank == 3: | |||||
| tfm = np.float32([[A[0, 0], A[1, 0], A[2, 0]], | |||||
| [A[0, 1], A[1, 1], A[2, 1]]]) | |||||
| elif rank == 2: | |||||
| tfm = np.float32([[A[0, 0], A[1, 0], 0], [A[0, 1], A[1, 1], 0]]) | |||||
| return tfm | |||||
| def get_params(reference_pts, facial_pts, align_type): | |||||
| ref_pts = np.float32(reference_pts) | |||||
| ref_pts_shp = ref_pts.shape | |||||
| if max(ref_pts_shp) < 3 or min(ref_pts_shp) != 2: | |||||
| raise FaceWarpException( | |||||
| 'reference_pts.shape must be (K,2) or (2,K) and K>2') | |||||
| if ref_pts_shp[0] == 2: | |||||
| ref_pts = ref_pts.T | |||||
| src_pts = np.float32(facial_pts) | |||||
| src_pts_shp = src_pts.shape | |||||
| if max(src_pts_shp) < 3 or min(src_pts_shp) != 2: | |||||
| raise FaceWarpException( | |||||
| 'facial_pts.shape must be (K,2) or (2,K) and K>2') | |||||
| if src_pts_shp[0] == 2: | |||||
| src_pts = src_pts.T | |||||
| if src_pts.shape != ref_pts.shape: | |||||
| raise FaceWarpException( | |||||
| 'facial_pts and reference_pts must have the same shape') | |||||
| if align_type == 'cv2_affine': | |||||
| tfm = cv2.getAffineTransform(src_pts[0:3], ref_pts[0:3]) | |||||
| tfm_inv = cv2.getAffineTransform(ref_pts[0:3], src_pts[0:3]) | |||||
| elif align_type == 'affine': | |||||
| tfm = get_affine_transform_matrix(src_pts, ref_pts) | |||||
| tfm_inv = get_affine_transform_matrix(ref_pts, src_pts) | |||||
| else: | |||||
| params, scale = _umeyama(src_pts, ref_pts) | |||||
| tfm = params[:2, :] | |||||
| params, _ = _umeyama(ref_pts, src_pts, False, scale=1.0 / scale) | |||||
| tfm_inv = params[:2, :] | |||||
| return tfm, tfm_inv | |||||
| def warp_and_crop_face(src_img, | |||||
| facial_pts, | |||||
| reference_pts=None, | |||||
| crop_size=(96, 112), | |||||
| align_type='smilarity'): # smilarity cv2_affine affine | |||||
| reference_pts_112 = get_reference_facial_points((112, 112), 0.25, (0, 0), | |||||
| True) | |||||
| if reference_pts is None: | |||||
| if crop_size[0] == 96 and crop_size[1] == 112: | |||||
| reference_pts = REFERENCE_FACIAL_POINTS | |||||
| else: | |||||
| default_square = True # False | |||||
| inner_padding_factor = 0.25 # 0 | |||||
| outer_padding = (0, 0) | |||||
| output_size = crop_size | |||||
| reference_pts = get_reference_facial_points( | |||||
| output_size, inner_padding_factor, outer_padding, | |||||
| default_square) | |||||
| tfm, tfm_inv = get_params(reference_pts, facial_pts, align_type) | |||||
| tfm_112, tfm_inv_112 = get_params(reference_pts_112, facial_pts, | |||||
| align_type) | |||||
| if src_img is not None: | |||||
| face_img = cv2.warpAffine( | |||||
| src_img, tfm, (crop_size[0], crop_size[1]), flags=3) | |||||
| face_img_112 = cv2.warpAffine(src_img, tfm_112, (112, 112), flags=3) | |||||
| return face_img, face_img_112, tfm_inv | |||||
| else: | |||||
| return tfm, tfm_inv | |||||
| @@ -0,0 +1,57 @@ | |||||
| import os | |||||
| import cv2 | |||||
| import numpy as np | |||||
| import torch | |||||
| from .model_resnet import FaceQuality, ResNet | |||||
| class FQA(object): | |||||
| def __init__(self, backbone_path, quality_path, device='cuda', size=112): | |||||
| self.BACKBONE = ResNet(num_layers=100, feature_dim=512) | |||||
| self.QUALITY = FaceQuality(512 * 7 * 7) | |||||
| self.size = size | |||||
| self.device = device | |||||
| self.load_model(backbone_path, quality_path) | |||||
| def load_model(self, backbone_path, quality_path): | |||||
| checkpoint = torch.load(backbone_path, map_location='cpu') | |||||
| self.load_state_dict(self.BACKBONE, checkpoint) | |||||
| checkpoint = torch.load(quality_path, map_location='cpu') | |||||
| self.load_state_dict(self.QUALITY, checkpoint) | |||||
| self.BACKBONE.to(self.device) | |||||
| self.QUALITY.to(self.device) | |||||
| self.BACKBONE.eval() | |||||
| self.QUALITY.eval() | |||||
| def load_state_dict(self, model, state_dict): | |||||
| all_keys = {k for k in state_dict.keys()} | |||||
| for k in all_keys: | |||||
| if k.startswith('module.'): | |||||
| state_dict[k[7:]] = state_dict.pop(k) | |||||
| model_dict = model.state_dict() | |||||
| pretrained_dict = { | |||||
| k: v | |||||
| for k, v in state_dict.items() | |||||
| if k in model_dict and v.size() == model_dict[k].size() | |||||
| } | |||||
| model_dict.update(pretrained_dict) | |||||
| model.load_state_dict(model_dict) | |||||
| def get_face_quality(self, img): | |||||
| img = torch.from_numpy(img).permute(2, 0, | |||||
| 1).unsqueeze(0).flip(1).cuda() | |||||
| img = (img - 127.5) / 128.0 | |||||
| # extract features & predict quality | |||||
| with torch.no_grad(): | |||||
| feature, fc = self.BACKBONE(img.to(self.device), True) | |||||
| s = self.QUALITY(fc)[0] | |||||
| return s.cpu().numpy()[0], feature.cpu().numpy()[0] | |||||
| @@ -0,0 +1,130 @@ | |||||
| import torch | |||||
| from torch import nn | |||||
| class BottleNeck_IR(nn.Module): | |||||
| def __init__(self, in_channel, out_channel, stride, dim_match): | |||||
| super(BottleNeck_IR, self).__init__() | |||||
| self.res_layer = nn.Sequential( | |||||
| nn.BatchNorm2d(in_channel), | |||||
| nn.Conv2d(in_channel, out_channel, (3, 3), 1, 1, bias=False), | |||||
| nn.BatchNorm2d(out_channel), nn.PReLU(out_channel), | |||||
| nn.Conv2d(out_channel, out_channel, (3, 3), stride, 1, bias=False), | |||||
| nn.BatchNorm2d(out_channel)) | |||||
| if dim_match: | |||||
| self.shortcut_layer = None | |||||
| else: | |||||
| self.shortcut_layer = nn.Sequential( | |||||
| nn.Conv2d( | |||||
| in_channel, | |||||
| out_channel, | |||||
| kernel_size=(1, 1), | |||||
| stride=stride, | |||||
| bias=False), nn.BatchNorm2d(out_channel)) | |||||
| def forward(self, x): | |||||
| shortcut = x | |||||
| res = self.res_layer(x) | |||||
| if self.shortcut_layer is not None: | |||||
| shortcut = self.shortcut_layer(x) | |||||
| return shortcut + res | |||||
| channel_list = [64, 64, 128, 256, 512] | |||||
| def get_layers(num_layers): | |||||
| if num_layers == 34: | |||||
| return [3, 4, 6, 3] | |||||
| if num_layers == 50: | |||||
| return [3, 4, 14, 3] | |||||
| elif num_layers == 100: | |||||
| return [3, 13, 30, 3] | |||||
| elif num_layers == 152: | |||||
| return [3, 8, 36, 3] | |||||
| class ResNet(nn.Module): | |||||
| def __init__(self, | |||||
| num_layers=100, | |||||
| feature_dim=512, | |||||
| drop_ratio=0.4, | |||||
| channel_list=channel_list): | |||||
| super(ResNet, self).__init__() | |||||
| assert num_layers in [34, 50, 100, 152] | |||||
| layers = get_layers(num_layers) | |||||
| block = BottleNeck_IR | |||||
| self.input_layer = nn.Sequential( | |||||
| nn.Conv2d( | |||||
| 3, channel_list[0], (3, 3), stride=1, padding=1, bias=False), | |||||
| nn.BatchNorm2d(channel_list[0]), nn.PReLU(channel_list[0])) | |||||
| self.layer1 = self._make_layer( | |||||
| block, channel_list[0], channel_list[1], layers[0], stride=2) | |||||
| self.layer2 = self._make_layer( | |||||
| block, channel_list[1], channel_list[2], layers[1], stride=2) | |||||
| self.layer3 = self._make_layer( | |||||
| block, channel_list[2], channel_list[3], layers[2], stride=2) | |||||
| self.layer4 = self._make_layer( | |||||
| block, channel_list[3], channel_list[4], layers[3], stride=2) | |||||
| self.output_layer = nn.Sequential( | |||||
| nn.BatchNorm2d(512), nn.Dropout(drop_ratio), nn.Flatten()) | |||||
| self.feature_layer = nn.Sequential( | |||||
| nn.Linear(512 * 7 * 7, feature_dim), nn.BatchNorm1d(feature_dim)) | |||||
| for m in self.modules(): | |||||
| if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear): | |||||
| nn.init.xavier_uniform_(m.weight) | |||||
| if m.bias is not None: | |||||
| nn.init.constant_(m.bias, 0.0) | |||||
| elif isinstance(m, nn.BatchNorm2d) or isinstance( | |||||
| m, nn.BatchNorm1d): | |||||
| nn.init.constant_(m.weight, 1) | |||||
| nn.init.constant_(m.bias, 0) | |||||
| def _make_layer(self, block, in_channel, out_channel, blocks, stride): | |||||
| layers = [] | |||||
| layers.append(block(in_channel, out_channel, stride, False)) | |||||
| for i in range(1, blocks): | |||||
| layers.append(block(out_channel, out_channel, 1, True)) | |||||
| return nn.Sequential(*layers) | |||||
| def forward(self, x, fc=False): | |||||
| x = self.input_layer(x) | |||||
| x = self.layer1(x) | |||||
| x = self.layer2(x) | |||||
| x = self.layer3(x) | |||||
| x = self.layer4(x) | |||||
| x = self.output_layer(x) | |||||
| feature = self.feature_layer(x) | |||||
| if fc: | |||||
| return feature, x | |||||
| return feature | |||||
| class FaceQuality(nn.Module): | |||||
| def __init__(self, feature_dim): | |||||
| super(FaceQuality, self).__init__() | |||||
| self.qualtiy = nn.Sequential( | |||||
| nn.Linear(feature_dim, 512, bias=False), nn.BatchNorm1d(512), | |||||
| nn.ReLU(inplace=True), nn.Linear(512, 2, bias=False), | |||||
| nn.Softmax(dim=1)) | |||||
| for m in self.modules(): | |||||
| if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear): | |||||
| nn.init.xavier_uniform_(m.weight) | |||||
| if m.bias is not None: | |||||
| nn.init.constant_(m.bias, 0.0) | |||||
| elif isinstance(m, nn.BatchNorm2d) or isinstance( | |||||
| m, nn.BatchNorm1d): | |||||
| nn.init.constant_(m.weight, 1) | |||||
| nn.init.constant_(m.bias, 0) | |||||
| def forward(self, x): | |||||
| x = self.qualtiy(x) | |||||
| return x[:, 0:1] | |||||
| @@ -0,0 +1,813 @@ | |||||
| import functools | |||||
| import itertools | |||||
| import math | |||||
| import operator | |||||
| import random | |||||
| import torch | |||||
| from torch import nn | |||||
| from torch.autograd import Function | |||||
| from torch.nn import functional as F | |||||
| from modelscope.models.cv.face_generation.op import (FusedLeakyReLU, | |||||
| fused_leaky_relu, | |||||
| upfirdn2d) | |||||
| class PixelNorm(nn.Module): | |||||
| def __init__(self): | |||||
| super().__init__() | |||||
| def forward(self, input): | |||||
| return input * torch.rsqrt( | |||||
| torch.mean(input**2, dim=1, keepdim=True) + 1e-8) | |||||
| def make_kernel(k): | |||||
| k = torch.tensor(k, dtype=torch.float32) | |||||
| if k.ndim == 1: | |||||
| k = k[None, :] * k[:, None] | |||||
| k /= k.sum() | |||||
| return k | |||||
| class Upsample(nn.Module): | |||||
| def __init__(self, kernel, factor=2): | |||||
| super().__init__() | |||||
| self.factor = factor | |||||
| kernel = make_kernel(kernel) * (factor**2) | |||||
| self.register_buffer('kernel', kernel) | |||||
| p = kernel.shape[0] - factor | |||||
| pad0 = (p + 1) // 2 + factor - 1 | |||||
| pad1 = p // 2 | |||||
| self.pad = (pad0, pad1) | |||||
| def forward(self, input): | |||||
| out = upfirdn2d( | |||||
| input, self.kernel, up=self.factor, down=1, pad=self.pad) | |||||
| return out | |||||
| class Downsample(nn.Module): | |||||
| def __init__(self, kernel, factor=2): | |||||
| super().__init__() | |||||
| self.factor = factor | |||||
| kernel = make_kernel(kernel) | |||||
| self.register_buffer('kernel', kernel) | |||||
| p = kernel.shape[0] - factor | |||||
| pad0 = (p + 1) // 2 | |||||
| pad1 = p // 2 | |||||
| self.pad = (pad0, pad1) | |||||
| def forward(self, input): | |||||
| out = upfirdn2d( | |||||
| input, self.kernel, up=1, down=self.factor, pad=self.pad) | |||||
| return out | |||||
| class Blur(nn.Module): | |||||
| def __init__(self, kernel, pad, upsample_factor=1): | |||||
| super().__init__() | |||||
| kernel = make_kernel(kernel) | |||||
| if upsample_factor > 1: | |||||
| kernel = kernel * (upsample_factor**2) | |||||
| self.register_buffer('kernel', kernel) | |||||
| self.pad = pad | |||||
| def forward(self, input): | |||||
| out = upfirdn2d(input, self.kernel, pad=self.pad) | |||||
| return out | |||||
| class EqualConv2d(nn.Module): | |||||
| def __init__(self, | |||||
| in_channel, | |||||
| out_channel, | |||||
| kernel_size, | |||||
| stride=1, | |||||
| padding=0, | |||||
| bias=True): | |||||
| super().__init__() | |||||
| self.weight = nn.Parameter( | |||||
| torch.randn(out_channel, in_channel, kernel_size, kernel_size)) | |||||
| self.scale = 1 / math.sqrt(in_channel * kernel_size**2) | |||||
| self.stride = stride | |||||
| self.padding = padding | |||||
| if bias: | |||||
| self.bias = nn.Parameter(torch.zeros(out_channel)) | |||||
| else: | |||||
| self.bias = None | |||||
| def forward(self, input): | |||||
| out = F.conv2d( | |||||
| input, | |||||
| self.weight * self.scale, | |||||
| bias=self.bias, | |||||
| stride=self.stride, | |||||
| padding=self.padding, | |||||
| ) | |||||
| return out | |||||
| def __repr__(self): | |||||
| return ( | |||||
| f'{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]},' | |||||
| f' {self.weight.shape[2]}, stride={self.stride}, padding={self.padding})' | |||||
| ) | |||||
| class EqualLinear(nn.Module): | |||||
| def __init__(self, | |||||
| in_dim, | |||||
| out_dim, | |||||
| bias=True, | |||||
| bias_init=0, | |||||
| lr_mul=1, | |||||
| activation=None): | |||||
| super().__init__() | |||||
| self.weight = nn.Parameter(torch.randn(out_dim, in_dim).div_(lr_mul)) | |||||
| if bias: | |||||
| self.bias = nn.Parameter(torch.zeros(out_dim).fill_(bias_init)) | |||||
| else: | |||||
| self.bias = None | |||||
| self.activation = activation | |||||
| self.scale = (1 / math.sqrt(in_dim)) * lr_mul | |||||
| self.lr_mul = lr_mul | |||||
| def forward(self, input): | |||||
| if self.activation: | |||||
| out = F.linear(input, self.weight * self.scale) | |||||
| out = fused_leaky_relu(out, self.bias * self.lr_mul) | |||||
| else: | |||||
| out = F.linear( | |||||
| input, self.weight * self.scale, bias=self.bias * self.lr_mul) | |||||
| return out | |||||
| def __repr__(self): | |||||
| return ( | |||||
| f'{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]})' | |||||
| ) | |||||
| class ScaledLeakyReLU(nn.Module): | |||||
| def __init__(self, negative_slope=0.2): | |||||
| super().__init__() | |||||
| self.negative_slope = negative_slope | |||||
| def forward(self, input): | |||||
| out = F.leaky_relu(input, negative_slope=self.negative_slope) | |||||
| return out * math.sqrt(2) | |||||
| class ModulatedConv2d(nn.Module): | |||||
| def __init__( | |||||
| self, | |||||
| in_channel, | |||||
| out_channel, | |||||
| kernel_size, | |||||
| style_dim, | |||||
| demodulate=True, | |||||
| upsample=False, | |||||
| downsample=False, | |||||
| blur_kernel=[1, 3, 3, 1], | |||||
| ): | |||||
| super().__init__() | |||||
| self.eps = 1e-8 | |||||
| self.kernel_size = kernel_size | |||||
| self.in_channel = in_channel | |||||
| self.out_channel = out_channel | |||||
| self.upsample = upsample | |||||
| self.downsample = downsample | |||||
| if upsample: | |||||
| factor = 2 | |||||
| p = (len(blur_kernel) - factor) - (kernel_size - 1) | |||||
| pad0 = (p + 1) // 2 + factor - 1 | |||||
| pad1 = p // 2 + 1 | |||||
| self.blur = Blur( | |||||
| blur_kernel, pad=(pad0, pad1), upsample_factor=factor) | |||||
| if downsample: | |||||
| factor = 2 | |||||
| p = (len(blur_kernel) - factor) + (kernel_size - 1) | |||||
| pad0 = (p + 1) // 2 | |||||
| pad1 = p // 2 | |||||
| self.blur = Blur(blur_kernel, pad=(pad0, pad1)) | |||||
| fan_in = in_channel * kernel_size**2 | |||||
| self.scale = 1 / math.sqrt(fan_in) | |||||
| self.padding = kernel_size // 2 | |||||
| self.weight = nn.Parameter( | |||||
| torch.randn(1, out_channel, in_channel, kernel_size, kernel_size)) | |||||
| self.modulation = EqualLinear(style_dim, in_channel, bias_init=1) | |||||
| self.demodulate = demodulate | |||||
| def __repr__(self): | |||||
| return ( | |||||
| f'{self.__class__.__name__}({self.in_channel}, {self.out_channel}, {self.kernel_size}, ' | |||||
| f'upsample={self.upsample}, downsample={self.downsample})') | |||||
| def forward(self, input, style): | |||||
| batch, in_channel, height, width = input.shape | |||||
| style = self.modulation(style).view(batch, 1, in_channel, 1, 1) | |||||
| weight = self.scale * self.weight * style | |||||
| if self.demodulate: | |||||
| demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + 1e-8) | |||||
| weight = weight * demod.view(batch, self.out_channel, 1, 1, 1) | |||||
| weight = weight.view(batch * self.out_channel, in_channel, | |||||
| self.kernel_size, self.kernel_size) | |||||
| if self.upsample: | |||||
| input = input.view(1, batch * in_channel, height, width) | |||||
| weight = weight.view(batch, self.out_channel, in_channel, | |||||
| self.kernel_size, self.kernel_size) | |||||
| weight = weight.transpose(1, 2).reshape(batch * in_channel, | |||||
| self.out_channel, | |||||
| self.kernel_size, | |||||
| self.kernel_size) | |||||
| out = F.conv_transpose2d( | |||||
| input, weight, padding=0, stride=2, groups=batch) | |||||
| _, _, height, width = out.shape | |||||
| out = out.view(batch, self.out_channel, height, width) | |||||
| out = self.blur(out) | |||||
| elif self.downsample: | |||||
| input = self.blur(input) | |||||
| _, _, height, width = input.shape | |||||
| input = input.view(1, batch * in_channel, height, width) | |||||
| out = F.conv2d(input, weight, padding=0, stride=2, groups=batch) | |||||
| _, _, height, width = out.shape | |||||
| out = out.view(batch, self.out_channel, height, width) | |||||
| else: | |||||
| input = input.view(1, batch * in_channel, height, width) | |||||
| out = F.conv2d(input, weight, padding=self.padding, groups=batch) | |||||
| _, _, height, width = out.shape | |||||
| out = out.view(batch, self.out_channel, height, width) | |||||
| return out | |||||
| class NoiseInjection(nn.Module): | |||||
| def __init__(self, isconcat=True): | |||||
| super().__init__() | |||||
| self.isconcat = isconcat | |||||
| self.weight = nn.Parameter(torch.zeros(1)) | |||||
| def forward(self, image, noise=None): | |||||
| if noise is None: | |||||
| batch, channel, height, width = image.shape | |||||
| noise = image.new_empty(batch, channel, height, width).normal_() | |||||
| if self.isconcat: | |||||
| return torch.cat((image, self.weight * noise), dim=1) | |||||
| else: | |||||
| return image + self.weight * noise | |||||
| class ConstantInput(nn.Module): | |||||
| def __init__(self, channel, size=4): | |||||
| super().__init__() | |||||
| self.input = nn.Parameter(torch.randn(1, channel, size, size)) | |||||
| def forward(self, input): | |||||
| batch = input.shape[0] | |||||
| out = self.input.repeat(batch, 1, 1, 1) | |||||
| return out | |||||
| class StyledConv(nn.Module): | |||||
| def __init__( | |||||
| self, | |||||
| in_channel, | |||||
| out_channel, | |||||
| kernel_size, | |||||
| style_dim, | |||||
| upsample=False, | |||||
| blur_kernel=[1, 3, 3, 1], | |||||
| demodulate=True, | |||||
| isconcat=True, | |||||
| ): | |||||
| super().__init__() | |||||
| self.conv = ModulatedConv2d( | |||||
| in_channel, | |||||
| out_channel, | |||||
| kernel_size, | |||||
| style_dim, | |||||
| upsample=upsample, | |||||
| blur_kernel=blur_kernel, | |||||
| demodulate=demodulate, | |||||
| ) | |||||
| self.noise = NoiseInjection(isconcat) | |||||
| # self.bias = nn.Parameter(torch.zeros(1, out_channel, 1, 1)) | |||||
| # self.activate = ScaledLeakyReLU(0.2) | |||||
| feat_multiplier = 2 if isconcat else 1 | |||||
| self.activate = FusedLeakyReLU(out_channel * feat_multiplier) | |||||
| def forward(self, input, style, noise=None): | |||||
| out = self.conv(input, style) | |||||
| out = self.noise(out, noise=noise) | |||||
| # out = out + self.bias | |||||
| out = self.activate(out) | |||||
| return out | |||||
| class ToRGB(nn.Module): | |||||
| def __init__(self, | |||||
| in_channel, | |||||
| style_dim, | |||||
| upsample=True, | |||||
| blur_kernel=[1, 3, 3, 1]): | |||||
| super().__init__() | |||||
| if upsample: | |||||
| self.upsample = Upsample(blur_kernel) | |||||
| self.conv = ModulatedConv2d( | |||||
| in_channel, 3, 1, style_dim, demodulate=False) | |||||
| self.bias = nn.Parameter(torch.zeros(1, 3, 1, 1)) | |||||
| def forward(self, input, style, skip=None): | |||||
| out = self.conv(input, style) | |||||
| out = out + self.bias | |||||
| if skip is not None: | |||||
| skip = self.upsample(skip) | |||||
| out = out + skip | |||||
| return out | |||||
| class Generator(nn.Module): | |||||
| def __init__( | |||||
| self, | |||||
| size, | |||||
| style_dim, | |||||
| n_mlp, | |||||
| channel_multiplier=2, | |||||
| blur_kernel=[1, 3, 3, 1], | |||||
| lr_mlp=0.01, | |||||
| isconcat=True, | |||||
| narrow=1, | |||||
| ): | |||||
| super().__init__() | |||||
| self.size = size | |||||
| self.n_mlp = n_mlp | |||||
| self.style_dim = style_dim | |||||
| self.feat_multiplier = 2 if isconcat else 1 | |||||
| layers = [PixelNorm()] | |||||
| for i in range(n_mlp): | |||||
| layers.append( | |||||
| EqualLinear( | |||||
| style_dim, | |||||
| style_dim, | |||||
| lr_mul=lr_mlp, | |||||
| activation='fused_lrelu')) | |||||
| self.style = nn.Sequential(*layers) | |||||
| self.channels = { | |||||
| 4: int(512 * narrow), | |||||
| 8: int(512 * narrow), | |||||
| 16: int(512 * narrow), | |||||
| 32: int(512 * narrow), | |||||
| 64: int(256 * channel_multiplier * narrow), | |||||
| 128: int(128 * channel_multiplier * narrow), | |||||
| 256: int(64 * channel_multiplier * narrow), | |||||
| 512: int(32 * channel_multiplier * narrow), | |||||
| 1024: int(16 * channel_multiplier * narrow), | |||||
| 2048: int(8 * channel_multiplier * narrow) | |||||
| } | |||||
| self.input = ConstantInput(self.channels[4]) | |||||
| self.conv1 = StyledConv( | |||||
| self.channels[4], | |||||
| self.channels[4], | |||||
| 3, | |||||
| style_dim, | |||||
| blur_kernel=blur_kernel, | |||||
| isconcat=isconcat) | |||||
| self.to_rgb1 = ToRGB( | |||||
| self.channels[4] * self.feat_multiplier, style_dim, upsample=False) | |||||
| self.log_size = int(math.log(size, 2)) | |||||
| self.convs = nn.ModuleList() | |||||
| self.upsamples = nn.ModuleList() | |||||
| self.to_rgbs = nn.ModuleList() | |||||
| in_channel = self.channels[4] | |||||
| for i in range(3, self.log_size + 1): | |||||
| out_channel = self.channels[2**i] | |||||
| self.convs.append( | |||||
| StyledConv( | |||||
| in_channel * self.feat_multiplier, | |||||
| out_channel, | |||||
| 3, | |||||
| style_dim, | |||||
| upsample=True, | |||||
| blur_kernel=blur_kernel, | |||||
| isconcat=isconcat, | |||||
| )) | |||||
| self.convs.append( | |||||
| StyledConv( | |||||
| out_channel * self.feat_multiplier, | |||||
| out_channel, | |||||
| 3, | |||||
| style_dim, | |||||
| blur_kernel=blur_kernel, | |||||
| isconcat=isconcat)) | |||||
| self.to_rgbs.append( | |||||
| ToRGB(out_channel * self.feat_multiplier, style_dim)) | |||||
| in_channel = out_channel | |||||
| self.n_latent = self.log_size * 2 - 2 | |||||
| def make_noise(self): | |||||
| device = self.input.input.device | |||||
| noises = [torch.randn(1, 1, 2**2, 2**2, device=device)] | |||||
| for i in range(3, self.log_size + 1): | |||||
| for _ in range(2): | |||||
| noises.append(torch.randn(1, 1, 2**i, 2**i, device=device)) | |||||
| return noises | |||||
| def mean_latent(self, n_latent): | |||||
| latent_in = torch.randn( | |||||
| n_latent, self.style_dim, device=self.input.input.device) | |||||
| latent = self.style(latent_in).mean(0, keepdim=True) | |||||
| return latent | |||||
| def get_latent(self, input): | |||||
| return self.style(input) | |||||
| def forward( | |||||
| self, | |||||
| styles, | |||||
| return_latents=False, | |||||
| inject_index=None, | |||||
| truncation=1, | |||||
| truncation_latent=None, | |||||
| input_is_latent=False, | |||||
| noise=None, | |||||
| ): | |||||
| if not input_is_latent: | |||||
| styles = [self.style(s) for s in styles] | |||||
| if noise is None: | |||||
| ''' | |||||
| noise = [None] * (2 * (self.log_size - 2) + 1) | |||||
| ''' | |||||
| noise = [] | |||||
| batch = styles[0].shape[0] | |||||
| for i in range(self.n_mlp + 1): | |||||
| size = 2**(i + 2) | |||||
| noise.append( | |||||
| torch.randn( | |||||
| batch, | |||||
| self.channels[size], | |||||
| size, | |||||
| size, | |||||
| device=styles[0].device)) | |||||
| if truncation < 1: | |||||
| style_t = [] | |||||
| for style in styles: | |||||
| style_t.append(truncation_latent | |||||
| + truncation * (style - truncation_latent)) | |||||
| styles = style_t | |||||
| if len(styles) < 2: | |||||
| inject_index = self.n_latent | |||||
| latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1) | |||||
| else: | |||||
| if inject_index is None: | |||||
| inject_index = random.randint(1, self.n_latent - 1) | |||||
| latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1) | |||||
| latent2 = styles[1].unsqueeze(1).repeat( | |||||
| 1, self.n_latent - inject_index, 1) | |||||
| latent = torch.cat([latent, latent2], 1) | |||||
| out = self.input(latent) | |||||
| out = self.conv1(out, latent[:, 0], noise=noise[0]) | |||||
| skip = self.to_rgb1(out, latent[:, 1]) | |||||
| i = 1 | |||||
| for conv1, conv2, noise1, noise2, to_rgb in zip( | |||||
| self.convs[::2], self.convs[1::2], noise[1::2], noise[2::2], | |||||
| self.to_rgbs): | |||||
| out = conv1(out, latent[:, i], noise=noise1) | |||||
| out = conv2(out, latent[:, i + 1], noise=noise2) | |||||
| skip = to_rgb(out, latent[:, i + 2], skip) | |||||
| i += 2 | |||||
| image = skip | |||||
| if return_latents: | |||||
| return image, latent | |||||
| else: | |||||
| return image, None | |||||
| class ConvLayer(nn.Sequential): | |||||
| def __init__( | |||||
| self, | |||||
| in_channel, | |||||
| out_channel, | |||||
| kernel_size, | |||||
| downsample=False, | |||||
| blur_kernel=[1, 3, 3, 1], | |||||
| bias=True, | |||||
| activate=True, | |||||
| ): | |||||
| layers = [] | |||||
| if downsample: | |||||
| factor = 2 | |||||
| p = (len(blur_kernel) - factor) + (kernel_size - 1) | |||||
| pad0 = (p + 1) // 2 | |||||
| pad1 = p // 2 | |||||
| layers.append(Blur(blur_kernel, pad=(pad0, pad1))) | |||||
| stride = 2 | |||||
| self.padding = 0 | |||||
| else: | |||||
| stride = 1 | |||||
| self.padding = kernel_size // 2 | |||||
| layers.append( | |||||
| EqualConv2d( | |||||
| in_channel, | |||||
| out_channel, | |||||
| kernel_size, | |||||
| padding=self.padding, | |||||
| stride=stride, | |||||
| bias=bias and not activate, | |||||
| )) | |||||
| if activate: | |||||
| if bias: | |||||
| layers.append(FusedLeakyReLU(out_channel)) | |||||
| else: | |||||
| layers.append(ScaledLeakyReLU(0.2)) | |||||
| super().__init__(*layers) | |||||
| class ResBlock(nn.Module): | |||||
| def __init__(self, in_channel, out_channel, blur_kernel=[1, 3, 3, 1]): | |||||
| super().__init__() | |||||
| self.conv1 = ConvLayer(in_channel, in_channel, 3) | |||||
| self.conv2 = ConvLayer(in_channel, out_channel, 3, downsample=True) | |||||
| self.skip = ConvLayer( | |||||
| in_channel, | |||||
| out_channel, | |||||
| 1, | |||||
| downsample=True, | |||||
| activate=False, | |||||
| bias=False) | |||||
| def forward(self, input): | |||||
| out = self.conv1(input) | |||||
| out = self.conv2(out) | |||||
| skip = self.skip(input) | |||||
| out = (out + skip) / math.sqrt(2) | |||||
| return out | |||||
| class FullGenerator(nn.Module): | |||||
| def __init__( | |||||
| self, | |||||
| size, | |||||
| style_dim, | |||||
| n_mlp, | |||||
| channel_multiplier=2, | |||||
| blur_kernel=[1, 3, 3, 1], | |||||
| lr_mlp=0.01, | |||||
| isconcat=True, | |||||
| narrow=1, | |||||
| ): | |||||
| super().__init__() | |||||
| channels = { | |||||
| 4: int(512 * narrow), | |||||
| 8: int(512 * narrow), | |||||
| 16: int(512 * narrow), | |||||
| 32: int(512 * narrow), | |||||
| 64: int(256 * channel_multiplier * narrow), | |||||
| 128: int(128 * channel_multiplier * narrow), | |||||
| 256: int(64 * channel_multiplier * narrow), | |||||
| 512: int(32 * channel_multiplier * narrow), | |||||
| 1024: int(16 * channel_multiplier * narrow), | |||||
| 2048: int(8 * channel_multiplier * narrow) | |||||
| } | |||||
| self.log_size = int(math.log(size, 2)) | |||||
| self.generator = Generator( | |||||
| size, | |||||
| style_dim, | |||||
| n_mlp, | |||||
| channel_multiplier=channel_multiplier, | |||||
| blur_kernel=blur_kernel, | |||||
| lr_mlp=lr_mlp, | |||||
| isconcat=isconcat, | |||||
| narrow=narrow) | |||||
| conv = [ConvLayer(3, channels[size], 1)] | |||||
| self.ecd0 = nn.Sequential(*conv) | |||||
| in_channel = channels[size] | |||||
| self.names = ['ecd%d' % i for i in range(self.log_size - 1)] | |||||
| for i in range(self.log_size, 2, -1): | |||||
| out_channel = channels[2**(i - 1)] | |||||
| # conv = [ResBlock(in_channel, out_channel, blur_kernel)] | |||||
| conv = [ConvLayer(in_channel, out_channel, 3, downsample=True)] | |||||
| setattr(self, self.names[self.log_size - i + 1], | |||||
| nn.Sequential(*conv)) | |||||
| in_channel = out_channel | |||||
| self.final_linear = nn.Sequential( | |||||
| EqualLinear( | |||||
| channels[4] * 4 * 4, style_dim, activation='fused_lrelu')) | |||||
| def forward( | |||||
| self, | |||||
| inputs, | |||||
| return_latents=False, | |||||
| inject_index=None, | |||||
| truncation=1, | |||||
| truncation_latent=None, | |||||
| input_is_latent=False, | |||||
| ): | |||||
| noise = [] | |||||
| for i in range(self.log_size - 1): | |||||
| ecd = getattr(self, self.names[i]) | |||||
| inputs = ecd(inputs) | |||||
| noise.append(inputs) | |||||
| inputs = inputs.view(inputs.shape[0], -1) | |||||
| outs = self.final_linear(inputs) | |||||
| noise = list( | |||||
| itertools.chain.from_iterable( | |||||
| itertools.repeat(x, 2) for x in noise))[::-1] | |||||
| outs = self.generator([outs], | |||||
| return_latents, | |||||
| inject_index, | |||||
| truncation, | |||||
| truncation_latent, | |||||
| input_is_latent, | |||||
| noise=noise[1:]) | |||||
| return outs | |||||
| class Discriminator(nn.Module): | |||||
| def __init__(self, | |||||
| size, | |||||
| channel_multiplier=2, | |||||
| blur_kernel=[1, 3, 3, 1], | |||||
| narrow=1): | |||||
| super().__init__() | |||||
| channels = { | |||||
| 4: int(512 * narrow), | |||||
| 8: int(512 * narrow), | |||||
| 16: int(512 * narrow), | |||||
| 32: int(512 * narrow), | |||||
| 64: int(256 * channel_multiplier * narrow), | |||||
| 128: int(128 * channel_multiplier * narrow), | |||||
| 256: int(64 * channel_multiplier * narrow), | |||||
| 512: int(32 * channel_multiplier * narrow), | |||||
| 1024: int(16 * channel_multiplier * narrow), | |||||
| 2048: int(8 * channel_multiplier * narrow) | |||||
| } | |||||
| convs = [ConvLayer(3, channels[size], 1)] | |||||
| log_size = int(math.log(size, 2)) | |||||
| in_channel = channels[size] | |||||
| for i in range(log_size, 2, -1): | |||||
| out_channel = channels[2**(i - 1)] | |||||
| convs.append(ResBlock(in_channel, out_channel, blur_kernel)) | |||||
| in_channel = out_channel | |||||
| self.convs = nn.Sequential(*convs) | |||||
| self.stddev_group = 4 | |||||
| self.stddev_feat = 1 | |||||
| self.final_conv = ConvLayer(in_channel + 1, channels[4], 3) | |||||
| self.final_linear = nn.Sequential( | |||||
| EqualLinear( | |||||
| channels[4] * 4 * 4, channels[4], activation='fused_lrelu'), | |||||
| EqualLinear(channels[4], 1), | |||||
| ) | |||||
| def forward(self, input): | |||||
| out = self.convs(input) | |||||
| batch, channel, height, width = out.shape | |||||
| group = min(batch, self.stddev_group) | |||||
| stddev = out.view(group, -1, self.stddev_feat, | |||||
| channel // self.stddev_feat, height, width) | |||||
| stddev = torch.sqrt(stddev.var(0, unbiased=False) + 1e-8) | |||||
| stddev = stddev.mean([2, 3, 4], keepdims=True).squeeze(2) | |||||
| stddev = stddev.repeat(group, 1, height, width) | |||||
| out = torch.cat([out, stddev], 1) | |||||
| out = self.final_conv(out) | |||||
| out = out.view(batch, -1) | |||||
| out = self.final_linear(out) | |||||
| return out | |||||
| @@ -0,0 +1,205 @@ | |||||
| import math | |||||
| import os.path as osp | |||||
| from copy import deepcopy | |||||
| from typing import Any, Dict, List, Union | |||||
| import torch | |||||
| import torch.nn.functional as F | |||||
| from torch import autograd, nn | |||||
| from torch.nn.parallel import DataParallel, DistributedDataParallel | |||||
| from modelscope.metainfo import Models | |||||
| from modelscope.models.base import Tensor, TorchModel | |||||
| from modelscope.models.builder import MODELS | |||||
| from modelscope.utils.config import Config | |||||
| from modelscope.utils.constant import ModelFile, Tasks | |||||
| from modelscope.utils.logger import get_logger | |||||
| from .gpen import Discriminator, FullGenerator | |||||
| from .losses.losses import IDLoss, L1Loss | |||||
| logger = get_logger() | |||||
| __all__ = ['ImagePortraitEnhancement'] | |||||
| @MODELS.register_module( | |||||
| Tasks.image_portrait_enhancement, module_name=Models.gpen) | |||||
| class ImagePortraitEnhancement(TorchModel): | |||||
| def __init__(self, model_dir: str, *args, **kwargs): | |||||
| """initialize the face enhancement model from the `model_dir` path. | |||||
| Args: | |||||
| model_dir (str): the model path. | |||||
| """ | |||||
| super().__init__(model_dir, *args, **kwargs) | |||||
| self.size = 512 | |||||
| self.style_dim = 512 | |||||
| self.n_mlp = 8 | |||||
| self.mean_path_length = 0 | |||||
| self.accum = 0.5**(32 / (10 * 1000)) | |||||
| if torch.cuda.is_available(): | |||||
| self._device = torch.device('cuda') | |||||
| else: | |||||
| self._device = torch.device('cpu') | |||||
| self.l1_loss = L1Loss() | |||||
| self.id_loss = IDLoss(f'{model_dir}/arcface/model_ir_se50.pth', | |||||
| self._device) | |||||
| self.generator = FullGenerator( | |||||
| self.size, self.style_dim, self.n_mlp, | |||||
| isconcat=True).to(self._device) | |||||
| self.g_ema = FullGenerator( | |||||
| self.size, self.style_dim, self.n_mlp, | |||||
| isconcat=True).to(self._device) | |||||
| self.discriminator = Discriminator(self.size).to(self._device) | |||||
| if self.size == 512: | |||||
| self.load_pretrained(model_dir) | |||||
| def load_pretrained(self, model_dir): | |||||
| g_path = f'{model_dir}/{ModelFile.TORCH_MODEL_FILE}' | |||||
| g_dict = torch.load(g_path, map_location=torch.device('cpu')) | |||||
| self.generator.load_state_dict(g_dict) | |||||
| self.g_ema.load_state_dict(g_dict) | |||||
| d_path = f'{model_dir}/net_d.pt' | |||||
| d_dict = torch.load(d_path, map_location=torch.device('cpu')) | |||||
| self.discriminator.load_state_dict(d_dict) | |||||
| logger.info('load model done.') | |||||
| def accumulate(self): | |||||
| par1 = dict(self.g_ema.named_parameters()) | |||||
| par2 = dict(self.generator.named_parameters()) | |||||
| for k in par1.keys(): | |||||
| par1[k].data.mul_(self.accum).add_(1 - self.accum, par2[k].data) | |||||
| def requires_grad(self, model, flag=True): | |||||
| for p in model.parameters(): | |||||
| p.requires_grad = flag | |||||
| def d_logistic_loss(self, real_pred, fake_pred): | |||||
| real_loss = F.softplus(-real_pred) | |||||
| fake_loss = F.softplus(fake_pred) | |||||
| return real_loss.mean() + fake_loss.mean() | |||||
| def d_r1_loss(self, real_pred, real_img): | |||||
| grad_real, = autograd.grad( | |||||
| outputs=real_pred.sum(), inputs=real_img, create_graph=True) | |||||
| grad_penalty = grad_real.pow(2).view(grad_real.shape[0], | |||||
| -1).sum(1).mean() | |||||
| return grad_penalty | |||||
| def g_nonsaturating_loss(self, | |||||
| fake_pred, | |||||
| fake_img=None, | |||||
| real_img=None, | |||||
| input_img=None): | |||||
| loss = F.softplus(-fake_pred).mean() | |||||
| loss_l1 = self.l1_loss(fake_img, real_img) | |||||
| loss_id, __, __ = self.id_loss(fake_img, real_img, input_img) | |||||
| loss_id = 0 | |||||
| loss += 1.0 * loss_l1 + 1.0 * loss_id | |||||
| return loss | |||||
| def g_path_regularize(self, | |||||
| fake_img, | |||||
| latents, | |||||
| mean_path_length, | |||||
| decay=0.01): | |||||
| noise = torch.randn_like(fake_img) / math.sqrt( | |||||
| fake_img.shape[2] * fake_img.shape[3]) | |||||
| grad, = autograd.grad( | |||||
| outputs=(fake_img * noise).sum(), | |||||
| inputs=latents, | |||||
| create_graph=True) | |||||
| path_lengths = torch.sqrt(grad.pow(2).sum(2).mean(1)) | |||||
| path_mean = mean_path_length + decay * ( | |||||
| path_lengths.mean() - mean_path_length) | |||||
| path_penalty = (path_lengths - path_mean).pow(2).mean() | |||||
| return path_penalty, path_mean.detach(), path_lengths | |||||
| @torch.no_grad() | |||||
| def _evaluate_postprocess(self, src: Tensor, | |||||
| target: Tensor) -> Dict[str, list]: | |||||
| preds, _ = self.generator(src) | |||||
| preds = list(torch.split(preds, 1, 0)) | |||||
| targets = list(torch.split(target, 1, 0)) | |||||
| preds = [((pred.data * 0.5 + 0.5) * 255.).squeeze(0).type( | |||||
| torch.uint8).permute(1, 2, 0).cpu().numpy() for pred in preds] | |||||
| targets = [((target.data * 0.5 + 0.5) * 255.).squeeze(0).type( | |||||
| torch.uint8).permute(1, 2, 0).cpu().numpy() for target in targets] | |||||
| return {'pred': preds, 'target': targets} | |||||
| def _train_forward_d(self, src: Tensor, target: Tensor) -> Tensor: | |||||
| self.requires_grad(self.generator, False) | |||||
| self.requires_grad(self.discriminator, True) | |||||
| preds, _ = self.generator(src) | |||||
| fake_pred = self.discriminator(preds) | |||||
| real_pred = self.discriminator(target) | |||||
| d_loss = self.d_logistic_loss(real_pred, fake_pred) | |||||
| return d_loss | |||||
| def _train_forward_d_r1(self, src: Tensor, target: Tensor) -> Tensor: | |||||
| src.requires_grad = True | |||||
| target.requires_grad = True | |||||
| real_pred = self.discriminator(target) | |||||
| r1_loss = self.d_r1_loss(real_pred, target) | |||||
| return r1_loss | |||||
| def _train_forward_g(self, src: Tensor, target: Tensor) -> Tensor: | |||||
| self.requires_grad(self.generator, True) | |||||
| self.requires_grad(self.discriminator, False) | |||||
| preds, _ = self.generator(src) | |||||
| fake_pred = self.discriminator(preds) | |||||
| g_loss = self.g_nonsaturating_loss(fake_pred, preds, target, src) | |||||
| return g_loss | |||||
| def _train_forward_g_path(self, src: Tensor, target: Tensor) -> Tensor: | |||||
| fake_img, latents = self.generator(src, return_latents=True) | |||||
| path_loss, self.mean_path_length, path_lengths = self.g_path_regularize( | |||||
| fake_img, latents, self.mean_path_length) | |||||
| return path_loss | |||||
| @torch.no_grad() | |||||
| def _inference_forward(self, src: Tensor) -> Dict[str, Tensor]: | |||||
| return {'outputs': (self.generator(src)[0] * 0.5 + 0.5).clamp(0, 1)} | |||||
| def forward(self, input: Dict[str, | |||||
| Tensor]) -> Dict[str, Union[list, Tensor]]: | |||||
| """return the result by the model | |||||
| Args: | |||||
| input (Dict[str, Tensor]): the preprocessed data | |||||
| Returns: | |||||
| Dict[str, Union[list, Tensor]]: results | |||||
| """ | |||||
| for key, value in input.items(): | |||||
| input[key] = input[key].to(self._device) | |||||
| if 'target' in input: | |||||
| return self._evaluate_postprocess(**input) | |||||
| else: | |||||
| return self._inference_forward(**input) | |||||
| @@ -0,0 +1,129 @@ | |||||
| from collections import namedtuple | |||||
| import torch | |||||
| from torch.nn import (AdaptiveAvgPool2d, BatchNorm2d, Conv2d, MaxPool2d, | |||||
| Module, PReLU, ReLU, Sequential, Sigmoid) | |||||
| class Flatten(Module): | |||||
| def forward(self, input): | |||||
| return input.view(input.size(0), -1) | |||||
| def l2_norm(input, axis=1): | |||||
| norm = torch.norm(input, 2, axis, True) | |||||
| output = torch.div(input, norm) | |||||
| return output | |||||
| class Bottleneck(namedtuple('Block', ['in_channel', 'depth', 'stride'])): | |||||
| """ A named tuple describing a ResNet block. """ | |||||
| def get_block(in_channel, depth, num_units, stride=2): | |||||
| return [Bottleneck(in_channel, depth, stride) | |||||
| ] + [Bottleneck(depth, depth, 1) for i in range(num_units - 1)] | |||||
| def get_blocks(num_layers): | |||||
| if num_layers == 50: | |||||
| blocks = [ | |||||
| get_block(in_channel=64, depth=64, num_units=3), | |||||
| get_block(in_channel=64, depth=128, num_units=4), | |||||
| get_block(in_channel=128, depth=256, num_units=14), | |||||
| get_block(in_channel=256, depth=512, num_units=3) | |||||
| ] | |||||
| elif num_layers == 100: | |||||
| blocks = [ | |||||
| get_block(in_channel=64, depth=64, num_units=3), | |||||
| get_block(in_channel=64, depth=128, num_units=13), | |||||
| get_block(in_channel=128, depth=256, num_units=30), | |||||
| get_block(in_channel=256, depth=512, num_units=3) | |||||
| ] | |||||
| elif num_layers == 152: | |||||
| blocks = [ | |||||
| get_block(in_channel=64, depth=64, num_units=3), | |||||
| get_block(in_channel=64, depth=128, num_units=8), | |||||
| get_block(in_channel=128, depth=256, num_units=36), | |||||
| get_block(in_channel=256, depth=512, num_units=3) | |||||
| ] | |||||
| else: | |||||
| raise ValueError( | |||||
| 'Invalid number of layers: {}. Must be one of [50, 100, 152]'. | |||||
| format(num_layers)) | |||||
| return blocks | |||||
| class SEModule(Module): | |||||
| def __init__(self, channels, reduction): | |||||
| super(SEModule, self).__init__() | |||||
| self.avg_pool = AdaptiveAvgPool2d(1) | |||||
| self.fc1 = Conv2d( | |||||
| channels, | |||||
| channels // reduction, | |||||
| kernel_size=1, | |||||
| padding=0, | |||||
| bias=False) | |||||
| self.relu = ReLU(inplace=True) | |||||
| self.fc2 = Conv2d( | |||||
| channels // reduction, | |||||
| channels, | |||||
| kernel_size=1, | |||||
| padding=0, | |||||
| bias=False) | |||||
| self.sigmoid = Sigmoid() | |||||
| def forward(self, x): | |||||
| module_input = x | |||||
| x = self.avg_pool(x) | |||||
| x = self.fc1(x) | |||||
| x = self.relu(x) | |||||
| x = self.fc2(x) | |||||
| x = self.sigmoid(x) | |||||
| return module_input * x | |||||
| class bottleneck_IR(Module): | |||||
| def __init__(self, in_channel, depth, stride): | |||||
| super(bottleneck_IR, self).__init__() | |||||
| if in_channel == depth: | |||||
| self.shortcut_layer = MaxPool2d(1, stride) | |||||
| else: | |||||
| self.shortcut_layer = Sequential( | |||||
| Conv2d(in_channel, depth, (1, 1), stride, bias=False), | |||||
| BatchNorm2d(depth)) | |||||
| self.res_layer = Sequential( | |||||
| BatchNorm2d(in_channel), | |||||
| Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False), | |||||
| PReLU(depth), Conv2d(depth, depth, (3, 3), stride, 1, bias=False), | |||||
| BatchNorm2d(depth)) | |||||
| def forward(self, x): | |||||
| shortcut = self.shortcut_layer(x) | |||||
| res = self.res_layer(x) | |||||
| return res + shortcut | |||||
| class bottleneck_IR_SE(Module): | |||||
| def __init__(self, in_channel, depth, stride): | |||||
| super(bottleneck_IR_SE, self).__init__() | |||||
| if in_channel == depth: | |||||
| self.shortcut_layer = MaxPool2d(1, stride) | |||||
| else: | |||||
| self.shortcut_layer = Sequential( | |||||
| Conv2d(in_channel, depth, (1, 1), stride, bias=False), | |||||
| BatchNorm2d(depth)) | |||||
| self.res_layer = Sequential( | |||||
| BatchNorm2d(in_channel), | |||||
| Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False), | |||||
| PReLU(depth), Conv2d(depth, depth, (3, 3), stride, 1, bias=False), | |||||
| BatchNorm2d(depth), SEModule(depth, 16)) | |||||
| def forward(self, x): | |||||
| shortcut = self.shortcut_layer(x) | |||||
| res = self.res_layer(x) | |||||
| return res + shortcut | |||||
| @@ -0,0 +1,90 @@ | |||||
| import torch | |||||
| import torch.nn as nn | |||||
| import torch.nn.functional as F | |||||
| from .model_irse import Backbone | |||||
| class L1Loss(nn.Module): | |||||
| """L1 (mean absolute error, MAE) loss. | |||||
| Args: | |||||
| loss_weight (float): Loss weight for L1 loss. Default: 1.0. | |||||
| reduction (str): Specifies the reduction to apply to the output. | |||||
| Supported choices are 'none' | 'mean' | 'sum'. Default: 'mean'. | |||||
| """ | |||||
| def __init__(self, loss_weight=1.0, reduction='mean'): | |||||
| super(L1Loss, self).__init__() | |||||
| if reduction not in ['none', 'mean', 'sum']: | |||||
| raise ValueError( | |||||
| f'Unsupported reduction mode: {reduction}. Supported ones are: {_reduction_modes}' | |||||
| ) | |||||
| self.loss_weight = loss_weight | |||||
| self.reduction = reduction | |||||
| def forward(self, pred, target, weight=None, **kwargs): | |||||
| """ | |||||
| Args: | |||||
| pred (Tensor): of shape (N, C, H, W). Predicted tensor. | |||||
| target (Tensor): of shape (N, C, H, W). Ground truth tensor. | |||||
| weight (Tensor, optional): of shape (N, C, H, W). Element-wise weights. Default: None. | |||||
| """ | |||||
| return self.loss_weight * F.l1_loss( | |||||
| pred, target, reduction=self.reduction) | |||||
| class IDLoss(nn.Module): | |||||
| def __init__(self, model_path, device='cuda', ckpt_dict=None): | |||||
| super(IDLoss, self).__init__() | |||||
| print('Loading ResNet ArcFace') | |||||
| self.facenet = Backbone( | |||||
| input_size=112, num_layers=50, drop_ratio=0.6, | |||||
| mode='ir_se').to(device) | |||||
| if ckpt_dict is None: | |||||
| self.facenet.load_state_dict( | |||||
| torch.load(model_path, map_location=torch.device('cpu'))) | |||||
| else: | |||||
| self.facenet.load_state_dict(ckpt_dict) | |||||
| self.pool = torch.nn.AdaptiveAvgPool2d((256, 256)) | |||||
| self.face_pool = torch.nn.AdaptiveAvgPool2d((112, 112)) | |||||
| self.facenet.eval() | |||||
| def extract_feats(self, x): | |||||
| _, _, h, w = x.shape | |||||
| assert h == w | |||||
| if h != 256: | |||||
| x = self.pool(x) | |||||
| x = x[:, :, 35:-33, 32:-36] # crop roi | |||||
| x = self.face_pool(x) | |||||
| x_feats = self.facenet(x) | |||||
| return x_feats | |||||
| @torch.no_grad() | |||||
| def forward(self, y_hat, y, x): | |||||
| n_samples = x.shape[0] | |||||
| x_feats = self.extract_feats(x) | |||||
| y_feats = self.extract_feats(y) # Otherwise use the feature from there | |||||
| y_hat_feats = self.extract_feats(y_hat) | |||||
| y_feats = y_feats.detach() | |||||
| loss = 0 | |||||
| sim_improvement = 0 | |||||
| id_logs = [] | |||||
| count = 0 | |||||
| for i in range(n_samples): | |||||
| diff_target = y_hat_feats[i].dot(y_feats[i]) | |||||
| diff_input = y_hat_feats[i].dot(x_feats[i]) | |||||
| diff_views = y_feats[i].dot(x_feats[i]) | |||||
| id_logs.append({ | |||||
| 'diff_target': float(diff_target), | |||||
| 'diff_input': float(diff_input), | |||||
| 'diff_views': float(diff_views) | |||||
| }) | |||||
| loss += 1 - diff_target | |||||
| id_diff = float(diff_target) - float(diff_views) | |||||
| sim_improvement += id_diff | |||||
| count += 1 | |||||
| return loss / count, sim_improvement / count, id_logs | |||||
| @@ -0,0 +1,92 @@ | |||||
| from torch.nn import (BatchNorm1d, BatchNorm2d, Conv2d, Dropout, Linear, | |||||
| Module, PReLU, Sequential) | |||||
| from .helpers import (Flatten, bottleneck_IR, bottleneck_IR_SE, get_blocks, | |||||
| l2_norm) | |||||
| class Backbone(Module): | |||||
| def __init__(self, | |||||
| input_size, | |||||
| num_layers, | |||||
| mode='ir', | |||||
| drop_ratio=0.4, | |||||
| affine=True): | |||||
| super(Backbone, self).__init__() | |||||
| assert input_size in [112, 224], 'input_size should be 112 or 224' | |||||
| assert num_layers in [50, 100, | |||||
| 152], 'num_layers should be 50, 100 or 152' | |||||
| assert mode in ['ir', 'ir_se'], 'mode should be ir or ir_se' | |||||
| blocks = get_blocks(num_layers) | |||||
| if mode == 'ir': | |||||
| unit_module = bottleneck_IR | |||||
| elif mode == 'ir_se': | |||||
| unit_module = bottleneck_IR_SE | |||||
| self.input_layer = Sequential( | |||||
| Conv2d(3, 64, (3, 3), 1, 1, bias=False), BatchNorm2d(64), | |||||
| PReLU(64)) | |||||
| if input_size == 112: | |||||
| self.output_layer = Sequential( | |||||
| BatchNorm2d(512), Dropout(drop_ratio), Flatten(), | |||||
| Linear(512 * 7 * 7, 512), BatchNorm1d(512, affine=affine)) | |||||
| else: | |||||
| self.output_layer = Sequential( | |||||
| BatchNorm2d(512), Dropout(drop_ratio), Flatten(), | |||||
| Linear(512 * 14 * 14, 512), BatchNorm1d(512, affine=affine)) | |||||
| modules = [] | |||||
| for block in blocks: | |||||
| for bottleneck in block: | |||||
| modules.append( | |||||
| unit_module(bottleneck.in_channel, bottleneck.depth, | |||||
| bottleneck.stride)) | |||||
| self.body = Sequential(*modules) | |||||
| def forward(self, x): | |||||
| x = self.input_layer(x) | |||||
| x = self.body(x) | |||||
| x = self.output_layer(x) | |||||
| return l2_norm(x) | |||||
| def IR_50(input_size): | |||||
| """Constructs a ir-50 model.""" | |||||
| model = Backbone( | |||||
| input_size, num_layers=50, mode='ir', drop_ratio=0.4, affine=False) | |||||
| return model | |||||
| def IR_101(input_size): | |||||
| """Constructs a ir-101 model.""" | |||||
| model = Backbone( | |||||
| input_size, num_layers=100, mode='ir', drop_ratio=0.4, affine=False) | |||||
| return model | |||||
| def IR_152(input_size): | |||||
| """Constructs a ir-152 model.""" | |||||
| model = Backbone( | |||||
| input_size, num_layers=152, mode='ir', drop_ratio=0.4, affine=False) | |||||
| return model | |||||
| def IR_SE_50(input_size): | |||||
| """Constructs a ir_se-50 model.""" | |||||
| model = Backbone( | |||||
| input_size, num_layers=50, mode='ir_se', drop_ratio=0.4, affine=False) | |||||
| return model | |||||
| def IR_SE_101(input_size): | |||||
| """Constructs a ir_se-101 model.""" | |||||
| model = Backbone( | |||||
| input_size, num_layers=100, mode='ir_se', drop_ratio=0.4, affine=False) | |||||
| return model | |||||
| def IR_SE_152(input_size): | |||||
| """Constructs a ir_se-152 model.""" | |||||
| model = Backbone( | |||||
| input_size, num_layers=152, mode='ir_se', drop_ratio=0.4, affine=False) | |||||
| return model | |||||
| @@ -0,0 +1,217 @@ | |||||
| import os | |||||
| import cv2 | |||||
| import numpy as np | |||||
| import torch | |||||
| import torch.backends.cudnn as cudnn | |||||
| import torch.nn.functional as F | |||||
| from .models.retinaface import RetinaFace | |||||
| from .utils import PriorBox, decode, decode_landm, py_cpu_nms | |||||
| cfg_re50 = { | |||||
| 'name': 'Resnet50', | |||||
| 'min_sizes': [[16, 32], [64, 128], [256, 512]], | |||||
| 'steps': [8, 16, 32], | |||||
| 'variance': [0.1, 0.2], | |||||
| 'clip': False, | |||||
| 'pretrain': False, | |||||
| 'return_layers': { | |||||
| 'layer2': 1, | |||||
| 'layer3': 2, | |||||
| 'layer4': 3 | |||||
| }, | |||||
| 'in_channel': 256, | |||||
| 'out_channel': 256 | |||||
| } | |||||
| class RetinaFaceDetection(object): | |||||
| def __init__(self, model_path, device='cuda'): | |||||
| torch.set_grad_enabled(False) | |||||
| cudnn.benchmark = True | |||||
| self.model_path = model_path | |||||
| self.device = device | |||||
| self.cfg = cfg_re50 | |||||
| self.net = RetinaFace(cfg=self.cfg) | |||||
| self.load_model() | |||||
| self.net = self.net.to(device) | |||||
| self.mean = torch.tensor([[[[104]], [[117]], [[123]]]]).to(device) | |||||
| def check_keys(self, pretrained_state_dict): | |||||
| ckpt_keys = set(pretrained_state_dict.keys()) | |||||
| model_keys = set(self.net.state_dict().keys()) | |||||
| used_pretrained_keys = model_keys & ckpt_keys | |||||
| assert len( | |||||
| used_pretrained_keys) > 0, 'load NONE from pretrained checkpoint' | |||||
| return True | |||||
| def remove_prefix(self, state_dict, prefix): | |||||
| new_state_dict = dict() | |||||
| # remove unnecessary 'module.' | |||||
| for k, v in state_dict.items(): | |||||
| if k.startswith(prefix): | |||||
| new_state_dict[k[len(prefix):]] = v | |||||
| else: | |||||
| new_state_dict[k] = v | |||||
| return new_state_dict | |||||
| def load_model(self, load_to_cpu=False): | |||||
| pretrained_dict = torch.load( | |||||
| self.model_path, map_location=torch.device('cpu')) | |||||
| if 'state_dict' in pretrained_dict.keys(): | |||||
| pretrained_dict = self.remove_prefix(pretrained_dict['state_dict'], | |||||
| 'module.') | |||||
| else: | |||||
| pretrained_dict = self.remove_prefix(pretrained_dict, 'module.') | |||||
| self.check_keys(pretrained_dict) | |||||
| self.net.load_state_dict(pretrained_dict, strict=False) | |||||
| self.net.eval() | |||||
| def detect(self, | |||||
| img_raw, | |||||
| resize=1, | |||||
| confidence_threshold=0.9, | |||||
| nms_threshold=0.4, | |||||
| top_k=5000, | |||||
| keep_top_k=750, | |||||
| save_image=False): | |||||
| img = np.float32(img_raw) | |||||
| im_height, im_width = img.shape[:2] | |||||
| ss = 1.0 | |||||
| # tricky | |||||
| if max(im_height, im_width) > 1500: | |||||
| ss = 1000.0 / max(im_height, im_width) | |||||
| img = cv2.resize(img, (0, 0), fx=ss, fy=ss) | |||||
| im_height, im_width = img.shape[:2] | |||||
| scale = torch.Tensor( | |||||
| [img.shape[1], img.shape[0], img.shape[1], img.shape[0]]) | |||||
| img -= (104, 117, 123) | |||||
| img = img.transpose(2, 0, 1) | |||||
| img = torch.from_numpy(img).unsqueeze(0) | |||||
| img = img.to(self.device) | |||||
| scale = scale.to(self.device) | |||||
| loc, conf, landms = self.net(img) # forward pass | |||||
| del img | |||||
| priorbox = PriorBox(self.cfg, image_size=(im_height, im_width)) | |||||
| priors = priorbox.forward() | |||||
| priors = priors.to(self.device) | |||||
| prior_data = priors.data | |||||
| boxes = decode(loc.data.squeeze(0), prior_data, self.cfg['variance']) | |||||
| boxes = boxes * scale / resize | |||||
| boxes = boxes.cpu().numpy() | |||||
| scores = conf.squeeze(0).data.cpu().numpy()[:, 1] | |||||
| landms = decode_landm( | |||||
| landms.data.squeeze(0), prior_data, self.cfg['variance']) | |||||
| scale1 = torch.Tensor([ | |||||
| im_width, im_height, im_width, im_height, im_width, im_height, | |||||
| im_width, im_height, im_width, im_height | |||||
| ]) | |||||
| scale1 = scale1.to(self.device) | |||||
| landms = landms * scale1 / resize | |||||
| landms = landms.cpu().numpy() | |||||
| # ignore low scores | |||||
| inds = np.where(scores > confidence_threshold)[0] | |||||
| boxes = boxes[inds] | |||||
| landms = landms[inds] | |||||
| scores = scores[inds] | |||||
| # keep top-K before NMS | |||||
| order = scores.argsort()[::-1][:top_k] | |||||
| boxes = boxes[order] | |||||
| landms = landms[order] | |||||
| scores = scores[order] | |||||
| # do NMS | |||||
| dets = np.hstack((boxes, scores[:, np.newaxis])).astype( | |||||
| np.float32, copy=False) | |||||
| keep = py_cpu_nms(dets, nms_threshold) | |||||
| dets = dets[keep, :] | |||||
| landms = landms[keep] | |||||
| # keep top-K faster NMS | |||||
| dets = dets[:keep_top_k, :] | |||||
| landms = landms[:keep_top_k, :] | |||||
| landms = landms.reshape((-1, 5, 2)) | |||||
| landms = landms.transpose((0, 2, 1)) | |||||
| landms = landms.reshape( | |||||
| -1, | |||||
| 10, | |||||
| ) | |||||
| return dets / ss, landms / ss | |||||
| def detect_tensor(self, | |||||
| img, | |||||
| resize=1, | |||||
| confidence_threshold=0.9, | |||||
| nms_threshold=0.4, | |||||
| top_k=5000, | |||||
| keep_top_k=750, | |||||
| save_image=False): | |||||
| im_height, im_width = img.shape[-2:] | |||||
| ss = 1000 / max(im_height, im_width) | |||||
| img = F.interpolate(img, scale_factor=ss) | |||||
| im_height, im_width = img.shape[-2:] | |||||
| scale = torch.Tensor([im_width, im_height, im_width, | |||||
| im_height]).to(self.device) | |||||
| img -= self.mean | |||||
| loc, conf, landms = self.net(img) # forward pass | |||||
| priorbox = PriorBox(self.cfg, image_size=(im_height, im_width)) | |||||
| priors = priorbox.forward() | |||||
| priors = priors.to(self.device) | |||||
| prior_data = priors.data | |||||
| boxes = decode(loc.data.squeeze(0), prior_data, self.cfg['variance']) | |||||
| boxes = boxes * scale / resize | |||||
| boxes = boxes.cpu().numpy() | |||||
| scores = conf.squeeze(0).data.cpu().numpy()[:, 1] | |||||
| landms = decode_landm( | |||||
| landms.data.squeeze(0), prior_data, self.cfg['variance']) | |||||
| scale1 = torch.Tensor([ | |||||
| img.shape[3], img.shape[2], img.shape[3], img.shape[2], | |||||
| img.shape[3], img.shape[2], img.shape[3], img.shape[2], | |||||
| img.shape[3], img.shape[2] | |||||
| ]) | |||||
| scale1 = scale1.to(self.device) | |||||
| landms = landms * scale1 / resize | |||||
| landms = landms.cpu().numpy() | |||||
| # ignore low scores | |||||
| inds = np.where(scores > confidence_threshold)[0] | |||||
| boxes = boxes[inds] | |||||
| landms = landms[inds] | |||||
| scores = scores[inds] | |||||
| # keep top-K before NMS | |||||
| order = scores.argsort()[::-1][:top_k] | |||||
| boxes = boxes[order] | |||||
| landms = landms[order] | |||||
| scores = scores[order] | |||||
| # do NMS | |||||
| dets = np.hstack((boxes, scores[:, np.newaxis])).astype( | |||||
| np.float32, copy=False) | |||||
| keep = py_cpu_nms(dets, nms_threshold) | |||||
| dets = dets[keep, :] | |||||
| landms = landms[keep] | |||||
| # keep top-K faster NMS | |||||
| dets = dets[:keep_top_k, :] | |||||
| landms = landms[:keep_top_k, :] | |||||
| landms = landms.reshape((-1, 5, 2)) | |||||
| landms = landms.transpose((0, 2, 1)) | |||||
| landms = landms.reshape( | |||||
| -1, | |||||
| 10, | |||||
| ) | |||||
| return dets / ss, landms / ss | |||||
| @@ -0,0 +1,148 @@ | |||||
| import time | |||||
| import torch | |||||
| import torch.nn as nn | |||||
| import torch.nn.functional as F | |||||
| import torchvision.models as models | |||||
| import torchvision.models._utils as _utils | |||||
| from torch.autograd import Variable | |||||
| def conv_bn(inp, oup, stride=1, leaky=0): | |||||
| return nn.Sequential( | |||||
| nn.Conv2d(inp, oup, 3, stride, 1, bias=False), nn.BatchNorm2d(oup), | |||||
| nn.LeakyReLU(negative_slope=leaky, inplace=True)) | |||||
| def conv_bn_no_relu(inp, oup, stride): | |||||
| return nn.Sequential( | |||||
| nn.Conv2d(inp, oup, 3, stride, 1, bias=False), | |||||
| nn.BatchNorm2d(oup), | |||||
| ) | |||||
| def conv_bn1X1(inp, oup, stride, leaky=0): | |||||
| return nn.Sequential( | |||||
| nn.Conv2d(inp, oup, 1, stride, padding=0, bias=False), | |||||
| nn.BatchNorm2d(oup), nn.LeakyReLU(negative_slope=leaky, inplace=True)) | |||||
| def conv_dw(inp, oup, stride, leaky=0.1): | |||||
| return nn.Sequential( | |||||
| nn.Conv2d(inp, inp, 3, stride, 1, groups=inp, bias=False), | |||||
| nn.BatchNorm2d(inp), | |||||
| nn.LeakyReLU(negative_slope=leaky, inplace=True), | |||||
| nn.Conv2d(inp, oup, 1, 1, 0, bias=False), | |||||
| nn.BatchNorm2d(oup), | |||||
| nn.LeakyReLU(negative_slope=leaky, inplace=True), | |||||
| ) | |||||
| class SSH(nn.Module): | |||||
| def __init__(self, in_channel, out_channel): | |||||
| super(SSH, self).__init__() | |||||
| assert out_channel % 4 == 0 | |||||
| leaky = 0 | |||||
| if (out_channel <= 64): | |||||
| leaky = 0.1 | |||||
| self.conv3X3 = conv_bn_no_relu(in_channel, out_channel // 2, stride=1) | |||||
| self.conv5X5_1 = conv_bn( | |||||
| in_channel, out_channel // 4, stride=1, leaky=leaky) | |||||
| self.conv5X5_2 = conv_bn_no_relu( | |||||
| out_channel // 4, out_channel // 4, stride=1) | |||||
| self.conv7X7_2 = conv_bn( | |||||
| out_channel // 4, out_channel // 4, stride=1, leaky=leaky) | |||||
| self.conv7x7_3 = conv_bn_no_relu( | |||||
| out_channel // 4, out_channel // 4, stride=1) | |||||
| def forward(self, input): | |||||
| conv3X3 = self.conv3X3(input) | |||||
| conv5X5_1 = self.conv5X5_1(input) | |||||
| conv5X5 = self.conv5X5_2(conv5X5_1) | |||||
| conv7X7_2 = self.conv7X7_2(conv5X5_1) | |||||
| conv7X7 = self.conv7x7_3(conv7X7_2) | |||||
| out = torch.cat([conv3X3, conv5X5, conv7X7], dim=1) | |||||
| out = F.relu(out) | |||||
| return out | |||||
| class FPN(nn.Module): | |||||
| def __init__(self, in_channels_list, out_channels): | |||||
| super(FPN, self).__init__() | |||||
| leaky = 0 | |||||
| if (out_channels <= 64): | |||||
| leaky = 0.1 | |||||
| self.output1 = conv_bn1X1( | |||||
| in_channels_list[0], out_channels, stride=1, leaky=leaky) | |||||
| self.output2 = conv_bn1X1( | |||||
| in_channels_list[1], out_channels, stride=1, leaky=leaky) | |||||
| self.output3 = conv_bn1X1( | |||||
| in_channels_list[2], out_channels, stride=1, leaky=leaky) | |||||
| self.merge1 = conv_bn(out_channels, out_channels, leaky=leaky) | |||||
| self.merge2 = conv_bn(out_channels, out_channels, leaky=leaky) | |||||
| def forward(self, input): | |||||
| # names = list(input.keys()) | |||||
| input = list(input.values()) | |||||
| output1 = self.output1(input[0]) | |||||
| output2 = self.output2(input[1]) | |||||
| output3 = self.output3(input[2]) | |||||
| up3 = F.interpolate( | |||||
| output3, size=[output2.size(2), output2.size(3)], mode='nearest') | |||||
| output2 = output2 + up3 | |||||
| output2 = self.merge2(output2) | |||||
| up2 = F.interpolate( | |||||
| output2, size=[output1.size(2), output1.size(3)], mode='nearest') | |||||
| output1 = output1 + up2 | |||||
| output1 = self.merge1(output1) | |||||
| out = [output1, output2, output3] | |||||
| return out | |||||
| class MobileNetV1(nn.Module): | |||||
| def __init__(self): | |||||
| super(MobileNetV1, self).__init__() | |||||
| self.stage1 = nn.Sequential( | |||||
| conv_bn(3, 8, 2, leaky=0.1), # 3 | |||||
| conv_dw(8, 16, 1), # 7 | |||||
| conv_dw(16, 32, 2), # 11 | |||||
| conv_dw(32, 32, 1), # 19 | |||||
| conv_dw(32, 64, 2), # 27 | |||||
| conv_dw(64, 64, 1), # 43 | |||||
| ) | |||||
| self.stage2 = nn.Sequential( | |||||
| conv_dw(64, 128, 2), # 43 + 16 = 59 | |||||
| conv_dw(128, 128, 1), # 59 + 32 = 91 | |||||
| conv_dw(128, 128, 1), # 91 + 32 = 123 | |||||
| conv_dw(128, 128, 1), # 123 + 32 = 155 | |||||
| conv_dw(128, 128, 1), # 155 + 32 = 187 | |||||
| conv_dw(128, 128, 1), # 187 + 32 = 219 | |||||
| ) | |||||
| self.stage3 = nn.Sequential( | |||||
| conv_dw(128, 256, 2), # 219 +3 2 = 241 | |||||
| conv_dw(256, 256, 1), # 241 + 64 = 301 | |||||
| ) | |||||
| self.avg = nn.AdaptiveAvgPool2d((1, 1)) | |||||
| self.fc = nn.Linear(256, 1000) | |||||
| def forward(self, x): | |||||
| x = self.stage1(x) | |||||
| x = self.stage2(x) | |||||
| x = self.stage3(x) | |||||
| x = self.avg(x) | |||||
| x = x.view(-1, 256) | |||||
| x = self.fc(x) | |||||
| return x | |||||
| @@ -0,0 +1,144 @@ | |||||
| from collections import OrderedDict | |||||
| import torch | |||||
| import torch.nn as nn | |||||
| import torch.nn.functional as F | |||||
| import torchvision.models as models | |||||
| import torchvision.models._utils as _utils | |||||
| import torchvision.models.detection.backbone_utils as backbone_utils | |||||
| from .net import FPN, SSH, MobileNetV1 | |||||
| class ClassHead(nn.Module): | |||||
| def __init__(self, inchannels=512, num_anchors=3): | |||||
| super(ClassHead, self).__init__() | |||||
| self.num_anchors = num_anchors | |||||
| self.conv1x1 = nn.Conv2d( | |||||
| inchannels, | |||||
| self.num_anchors * 2, | |||||
| kernel_size=(1, 1), | |||||
| stride=1, | |||||
| padding=0) | |||||
| def forward(self, x): | |||||
| out = self.conv1x1(x) | |||||
| out = out.permute(0, 2, 3, 1).contiguous() | |||||
| return out.view(out.shape[0], -1, 2) | |||||
| class BboxHead(nn.Module): | |||||
| def __init__(self, inchannels=512, num_anchors=3): | |||||
| super(BboxHead, self).__init__() | |||||
| self.conv1x1 = nn.Conv2d( | |||||
| inchannels, | |||||
| num_anchors * 4, | |||||
| kernel_size=(1, 1), | |||||
| stride=1, | |||||
| padding=0) | |||||
| def forward(self, x): | |||||
| out = self.conv1x1(x) | |||||
| out = out.permute(0, 2, 3, 1).contiguous() | |||||
| return out.view(out.shape[0], -1, 4) | |||||
| class LandmarkHead(nn.Module): | |||||
| def __init__(self, inchannels=512, num_anchors=3): | |||||
| super(LandmarkHead, self).__init__() | |||||
| self.conv1x1 = nn.Conv2d( | |||||
| inchannels, | |||||
| num_anchors * 10, | |||||
| kernel_size=(1, 1), | |||||
| stride=1, | |||||
| padding=0) | |||||
| def forward(self, x): | |||||
| out = self.conv1x1(x) | |||||
| out = out.permute(0, 2, 3, 1).contiguous() | |||||
| return out.view(out.shape[0], -1, 10) | |||||
| class RetinaFace(nn.Module): | |||||
| def __init__(self, cfg=None): | |||||
| """ | |||||
| :param cfg: Network related settings. | |||||
| """ | |||||
| super(RetinaFace, self).__init__() | |||||
| backbone = None | |||||
| if cfg['name'] == 'Resnet50': | |||||
| backbone = models.resnet50(pretrained=cfg['pretrain']) | |||||
| else: | |||||
| raise Exception('Invalid name') | |||||
| self.body = _utils.IntermediateLayerGetter(backbone, | |||||
| cfg['return_layers']) | |||||
| in_channels_stage2 = cfg['in_channel'] | |||||
| in_channels_list = [ | |||||
| in_channels_stage2 * 2, | |||||
| in_channels_stage2 * 4, | |||||
| in_channels_stage2 * 8, | |||||
| ] | |||||
| out_channels = cfg['out_channel'] | |||||
| self.fpn = FPN(in_channels_list, out_channels) | |||||
| self.ssh1 = SSH(out_channels, out_channels) | |||||
| self.ssh2 = SSH(out_channels, out_channels) | |||||
| self.ssh3 = SSH(out_channels, out_channels) | |||||
| self.ClassHead = self._make_class_head( | |||||
| fpn_num=3, inchannels=cfg['out_channel']) | |||||
| self.BboxHead = self._make_bbox_head( | |||||
| fpn_num=3, inchannels=cfg['out_channel']) | |||||
| self.LandmarkHead = self._make_landmark_head( | |||||
| fpn_num=3, inchannels=cfg['out_channel']) | |||||
| def _make_class_head(self, fpn_num=3, inchannels=64, anchor_num=2): | |||||
| classhead = nn.ModuleList() | |||||
| for i in range(fpn_num): | |||||
| classhead.append(ClassHead(inchannels, anchor_num)) | |||||
| return classhead | |||||
| def _make_bbox_head(self, fpn_num=3, inchannels=64, anchor_num=2): | |||||
| bboxhead = nn.ModuleList() | |||||
| for i in range(fpn_num): | |||||
| bboxhead.append(BboxHead(inchannels, anchor_num)) | |||||
| return bboxhead | |||||
| def _make_landmark_head(self, fpn_num=3, inchannels=64, anchor_num=2): | |||||
| landmarkhead = nn.ModuleList() | |||||
| for i in range(fpn_num): | |||||
| landmarkhead.append(LandmarkHead(inchannels, anchor_num)) | |||||
| return landmarkhead | |||||
| def forward(self, inputs): | |||||
| out = self.body(inputs) | |||||
| # FPN | |||||
| fpn = self.fpn(out) | |||||
| # SSH | |||||
| feature1 = self.ssh1(fpn[0]) | |||||
| feature2 = self.ssh2(fpn[1]) | |||||
| feature3 = self.ssh3(fpn[2]) | |||||
| features = [feature1, feature2, feature3] | |||||
| bbox_regressions = torch.cat( | |||||
| [self.BboxHead[i](feature) for i, feature in enumerate(features)], | |||||
| dim=1) | |||||
| classifications = torch.cat( | |||||
| [self.ClassHead[i](feature) for i, feature in enumerate(features)], | |||||
| dim=1) | |||||
| ldm_regressions = torch.cat( | |||||
| [self.LandmarkHead[i](feat) for i, feat in enumerate(features)], | |||||
| dim=1) | |||||
| output = (bbox_regressions, F.softmax(classifications, | |||||
| dim=-1), ldm_regressions) | |||||
| return output | |||||
| @@ -0,0 +1,123 @@ | |||||
| # -------------------------------------------------------- | |||||
| # Modified from https://github.com/biubug6/Pytorch_Retinaface | |||||
| # -------------------------------------------------------- | |||||
| from itertools import product as product | |||||
| from math import ceil | |||||
| import numpy as np | |||||
| import torch | |||||
| class PriorBox(object): | |||||
| def __init__(self, cfg, image_size=None, phase='train'): | |||||
| super(PriorBox, self).__init__() | |||||
| self.min_sizes = cfg['min_sizes'] | |||||
| self.steps = cfg['steps'] | |||||
| self.clip = cfg['clip'] | |||||
| self.image_size = image_size | |||||
| self.feature_maps = [[ | |||||
| ceil(self.image_size[0] / step), | |||||
| ceil(self.image_size[1] / step) | |||||
| ] for step in self.steps] | |||||
| self.name = 's' | |||||
| def forward(self): | |||||
| anchors = [] | |||||
| for k, f in enumerate(self.feature_maps): | |||||
| min_sizes = self.min_sizes[k] | |||||
| for i, j in product(range(f[0]), range(f[1])): | |||||
| for min_size in min_sizes: | |||||
| s_kx = min_size / self.image_size[1] | |||||
| s_ky = min_size / self.image_size[0] | |||||
| dense_cx = [ | |||||
| x * self.steps[k] / self.image_size[1] | |||||
| for x in [j + 0.5] | |||||
| ] | |||||
| dense_cy = [ | |||||
| y * self.steps[k] / self.image_size[0] | |||||
| for y in [i + 0.5] | |||||
| ] | |||||
| for cy, cx in product(dense_cy, dense_cx): | |||||
| anchors += [cx, cy, s_kx, s_ky] | |||||
| # back to torch land | |||||
| output = torch.Tensor(anchors).view(-1, 4) | |||||
| if self.clip: | |||||
| output.clamp_(max=1, min=0) | |||||
| return output | |||||
| def py_cpu_nms(dets, thresh): | |||||
| """Pure Python NMS baseline.""" | |||||
| x1 = dets[:, 0] | |||||
| y1 = dets[:, 1] | |||||
| x2 = dets[:, 2] | |||||
| y2 = dets[:, 3] | |||||
| scores = dets[:, 4] | |||||
| areas = (x2 - x1 + 1) * (y2 - y1 + 1) | |||||
| order = scores.argsort()[::-1] | |||||
| keep = [] | |||||
| while order.size > 0: | |||||
| i = order[0] | |||||
| keep.append(i) | |||||
| xx1 = np.maximum(x1[i], x1[order[1:]]) | |||||
| yy1 = np.maximum(y1[i], y1[order[1:]]) | |||||
| xx2 = np.minimum(x2[i], x2[order[1:]]) | |||||
| yy2 = np.minimum(y2[i], y2[order[1:]]) | |||||
| w = np.maximum(0.0, xx2 - xx1 + 1) | |||||
| h = np.maximum(0.0, yy2 - yy1 + 1) | |||||
| inter = w * h | |||||
| ovr = inter / (areas[i] + areas[order[1:]] - inter) | |||||
| inds = np.where(ovr <= thresh)[0] | |||||
| order = order[inds + 1] | |||||
| return keep | |||||
| # Adapted from https://github.com/Hakuyume/chainer-ssd | |||||
| def decode(loc, priors, variances): | |||||
| """Decode locations from predictions using priors to undo | |||||
| the encoding we did for offset regression at train time. | |||||
| Args: | |||||
| loc (tensor): location predictions for loc layers, | |||||
| Shape: [num_priors,4] | |||||
| priors (tensor): Prior boxes in center-offset form. | |||||
| Shape: [num_priors,4]. | |||||
| variances: (list[float]) Variances of priorboxes | |||||
| Return: | |||||
| decoded bounding box predictions | |||||
| """ | |||||
| boxes = torch.cat( | |||||
| (priors[:, :2] + loc[:, :2] * variances[0] * priors[:, 2:], | |||||
| priors[:, 2:] * torch.exp(loc[:, 2:] * variances[1])), 1) | |||||
| boxes[:, :2] -= boxes[:, 2:] / 2 | |||||
| boxes[:, 2:] += boxes[:, :2] | |||||
| return boxes | |||||
| def decode_landm(pre, priors, variances): | |||||
| """Decode landm from predictions using priors to undo | |||||
| the encoding we did for offset regression at train time. | |||||
| Args: | |||||
| pre (tensor): landm predictions for loc layers, | |||||
| Shape: [num_priors,10] | |||||
| priors (tensor): Prior boxes in center-offset form. | |||||
| Shape: [num_priors,4]. | |||||
| variances: (list[float]) Variances of priorboxes | |||||
| Return: | |||||
| decoded landm predictions | |||||
| """ | |||||
| a = priors[:, :2] + pre[:, :2] * variances[0] * priors[:, 2:] | |||||
| b = priors[:, :2] + pre[:, 2:4] * variances[0] * priors[:, 2:] | |||||
| c = priors[:, :2] + pre[:, 4:6] * variances[0] * priors[:, 2:] | |||||
| d = priors[:, :2] + pre[:, 6:8] * variances[0] * priors[:, 2:] | |||||
| e = priors[:, :2] + pre[:, 8:10] * variances[0] * priors[:, 2:] | |||||
| landms = torch.cat((a, b, c, d, e), dim=1) | |||||
| return landms | |||||
| @@ -137,6 +137,7 @@ TASK_OUTPUTS = { | |||||
| Tasks.image_colorization: [OutputKeys.OUTPUT_IMG], | Tasks.image_colorization: [OutputKeys.OUTPUT_IMG], | ||||
| Tasks.image_color_enhancement: [OutputKeys.OUTPUT_IMG], | Tasks.image_color_enhancement: [OutputKeys.OUTPUT_IMG], | ||||
| Tasks.image_denoising: [OutputKeys.OUTPUT_IMG], | Tasks.image_denoising: [OutputKeys.OUTPUT_IMG], | ||||
| Tasks.image_portrait_enhancement: [OutputKeys.OUTPUT_IMG], | |||||
| # image generation task result for a single image | # image generation task result for a single image | ||||
| # {"output_img": np.array with shape (h, w, 3)} | # {"output_img": np.array with shape (h, w, 3)} | ||||
| @@ -110,6 +110,9 @@ DEFAULT_MODEL_FOR_PIPELINE = { | |||||
| 'damo/cv_gan_face-image-generation'), | 'damo/cv_gan_face-image-generation'), | ||||
| Tasks.image_super_resolution: (Pipelines.image_super_resolution, | Tasks.image_super_resolution: (Pipelines.image_super_resolution, | ||||
| 'damo/cv_rrdb_image-super-resolution'), | 'damo/cv_rrdb_image-super-resolution'), | ||||
| Tasks.image_portrait_enhancement: | |||||
| (Pipelines.image_portrait_enhancement, | |||||
| 'damo/cv_gpen_image-portrait-enhancement'), | |||||
| Tasks.product_retrieval_embedding: | Tasks.product_retrieval_embedding: | ||||
| (Pipelines.product_retrieval_embedding, | (Pipelines.product_retrieval_embedding, | ||||
| 'damo/cv_resnet50_product-bag-embedding-models'), | 'damo/cv_resnet50_product-bag-embedding-models'), | ||||
| @@ -11,14 +11,15 @@ if TYPE_CHECKING: | |||||
| from .face_detection_pipeline import FaceDetectionPipeline | from .face_detection_pipeline import FaceDetectionPipeline | ||||
| from .face_recognition_pipeline import FaceRecognitionPipeline | from .face_recognition_pipeline import FaceRecognitionPipeline | ||||
| from .face_image_generation_pipeline import FaceImageGenerationPipeline | from .face_image_generation_pipeline import FaceImageGenerationPipeline | ||||
| from .image_classification_pipeline import ImageClassificationPipeline | |||||
| from .image_cartoon_pipeline import ImageCartoonPipeline | from .image_cartoon_pipeline import ImageCartoonPipeline | ||||
| from .image_classification_pipeline import GeneralImageClassificationPipeline | from .image_classification_pipeline import GeneralImageClassificationPipeline | ||||
| from .image_denoise_pipeline import ImageDenoisePipeline | |||||
| from .image_color_enhance_pipeline import ImageColorEnhancePipeline | from .image_color_enhance_pipeline import ImageColorEnhancePipeline | ||||
| from .image_colorization_pipeline import ImageColorizationPipeline | from .image_colorization_pipeline import ImageColorizationPipeline | ||||
| from .image_classification_pipeline import ImageClassificationPipeline | |||||
| from .image_denoise_pipeline import ImageDenoisePipeline | |||||
| from .image_instance_segmentation_pipeline import ImageInstanceSegmentationPipeline | from .image_instance_segmentation_pipeline import ImageInstanceSegmentationPipeline | ||||
| from .image_matting_pipeline import ImageMattingPipeline | from .image_matting_pipeline import ImageMattingPipeline | ||||
| from .image_portrait_enhancement_pipeline import ImagePortraitEnhancementPipeline | |||||
| from .image_style_transfer_pipeline import ImageStyleTransferPipeline | from .image_style_transfer_pipeline import ImageStyleTransferPipeline | ||||
| from .image_super_resolution_pipeline import ImageSuperResolutionPipeline | from .image_super_resolution_pipeline import ImageSuperResolutionPipeline | ||||
| from .image_to_image_generate_pipeline import Image2ImageGenerationePipeline | from .image_to_image_generate_pipeline import Image2ImageGenerationePipeline | ||||
| @@ -46,6 +47,8 @@ else: | |||||
| 'image_instance_segmentation_pipeline': | 'image_instance_segmentation_pipeline': | ||||
| ['ImageInstanceSegmentationPipeline'], | ['ImageInstanceSegmentationPipeline'], | ||||
| 'image_matting_pipeline': ['ImageMattingPipeline'], | 'image_matting_pipeline': ['ImageMattingPipeline'], | ||||
| 'image_portrait_enhancement_pipeline': | |||||
| ['ImagePortraitEnhancementPipeline'], | |||||
| 'image_style_transfer_pipeline': ['ImageStyleTransferPipeline'], | 'image_style_transfer_pipeline': ['ImageStyleTransferPipeline'], | ||||
| 'image_super_resolution_pipeline': ['ImageSuperResolutionPipeline'], | 'image_super_resolution_pipeline': ['ImageSuperResolutionPipeline'], | ||||
| 'image_to_image_translation_pipeline': | 'image_to_image_translation_pipeline': | ||||
| @@ -0,0 +1,216 @@ | |||||
| import math | |||||
| from typing import Any, Dict | |||||
| import cv2 | |||||
| import numpy as np | |||||
| import PIL | |||||
| import torch | |||||
| from scipy.ndimage import gaussian_filter | |||||
| from scipy.spatial.distance import pdist, squareform | |||||
| from modelscope.metainfo import Pipelines | |||||
| from modelscope.models.cv.image_portrait_enhancement import gpen | |||||
| from modelscope.models.cv.image_portrait_enhancement.align_faces import ( | |||||
| get_reference_facial_points, warp_and_crop_face) | |||||
| from modelscope.models.cv.image_portrait_enhancement.eqface import fqa | |||||
| from modelscope.models.cv.image_portrait_enhancement.retinaface import \ | |||||
| detection | |||||
| from modelscope.models.cv.super_resolution import rrdbnet_arch | |||||
| from modelscope.outputs import OutputKeys | |||||
| from modelscope.pipelines.base import Input, Pipeline | |||||
| from modelscope.pipelines.builder import PIPELINES | |||||
| from modelscope.preprocessors import LoadImage, load_image | |||||
| from modelscope.utils.constant import ModelFile, Tasks | |||||
| from modelscope.utils.logger import get_logger | |||||
| logger = get_logger() | |||||
| @PIPELINES.register_module( | |||||
| Tasks.image_portrait_enhancement, | |||||
| module_name=Pipelines.image_portrait_enhancement) | |||||
| class ImagePortraitEnhancementPipeline(Pipeline): | |||||
| def __init__(self, model: str, **kwargs): | |||||
| """ | |||||
| use `model` to create a kws pipeline for prediction | |||||
| Args: | |||||
| model: model id on modelscope hub. | |||||
| """ | |||||
| super().__init__(model=model, **kwargs) | |||||
| if torch.cuda.is_available(): | |||||
| self.device = torch.device('cuda') | |||||
| else: | |||||
| self.device = torch.device('cpu') | |||||
| self.use_sr = True | |||||
| self.size = 512 | |||||
| self.n_mlp = 8 | |||||
| self.channel_multiplier = 2 | |||||
| self.narrow = 1 | |||||
| self.face_enhancer = gpen.FullGenerator( | |||||
| self.size, | |||||
| 512, | |||||
| self.n_mlp, | |||||
| self.channel_multiplier, | |||||
| narrow=self.narrow).to(self.device) | |||||
| gpen_model_path = f'{model}/{ModelFile.TORCH_MODEL_FILE}' | |||||
| self.face_enhancer.load_state_dict( | |||||
| torch.load(gpen_model_path), strict=True) | |||||
| logger.info('load face enhancer model done') | |||||
| self.threshold = 0.9 | |||||
| detector_model_path = f'{model}/face_detection/RetinaFace-R50.pth' | |||||
| self.face_detector = detection.RetinaFaceDetection( | |||||
| detector_model_path, self.device) | |||||
| logger.info('load face detector model done') | |||||
| self.num_feat = 32 | |||||
| self.num_block = 23 | |||||
| self.scale = 2 | |||||
| self.sr_model = rrdbnet_arch.RRDBNet( | |||||
| num_in_ch=3, | |||||
| num_out_ch=3, | |||||
| num_feat=self.num_feat, | |||||
| num_block=self.num_block, | |||||
| num_grow_ch=32, | |||||
| scale=self.scale).to(self.device) | |||||
| sr_model_path = f'{model}/super_resolution/realesrnet_x{self.scale}.pth' | |||||
| self.sr_model.load_state_dict( | |||||
| torch.load(sr_model_path)['params_ema'], strict=True) | |||||
| logger.info('load sr model done') | |||||
| self.fqa_thres = 0.1 | |||||
| self.id_thres = 0.15 | |||||
| self.alpha = 1.0 | |||||
| backbone_model_path = f'{model}/face_quality/eqface_backbone.pth' | |||||
| fqa_model_path = f'{model}/face_quality/eqface_quality.pth' | |||||
| self.eqface = fqa.FQA(backbone_model_path, fqa_model_path, self.device) | |||||
| logger.info('load fqa model done') | |||||
| # the mask for pasting restored faces back | |||||
| self.mask = np.zeros((512, 512, 3), np.float32) | |||||
| cv2.rectangle(self.mask, (26, 26), (486, 486), (1, 1, 1), -1, | |||||
| cv2.LINE_AA) | |||||
| self.mask = cv2.GaussianBlur(self.mask, (101, 101), 4) | |||||
| self.mask = cv2.GaussianBlur(self.mask, (101, 101), 4) | |||||
| def enhance_face(self, img): | |||||
| img = cv2.resize(img, (self.size, self.size)) | |||||
| img_t = self.img2tensor(img) | |||||
| self.face_enhancer.eval() | |||||
| with torch.no_grad(): | |||||
| out, __ = self.face_enhancer(img_t) | |||||
| del img_t | |||||
| out = self.tensor2img(out) | |||||
| return out | |||||
| def img2tensor(self, img, is_norm=True): | |||||
| img_t = torch.from_numpy(img).to(self.device) / 255. | |||||
| if is_norm: | |||||
| img_t = (img_t - 0.5) / 0.5 | |||||
| img_t = img_t.permute(2, 0, 1).unsqueeze(0).flip(1) # BGR->RGB | |||||
| return img_t | |||||
| def tensor2img(self, img_t, pmax=255.0, is_denorm=True, imtype=np.uint8): | |||||
| if is_denorm: | |||||
| img_t = img_t * 0.5 + 0.5 | |||||
| img_t = img_t.squeeze(0).permute(1, 2, 0).flip(2) # RGB->BGR | |||||
| img_np = np.clip(img_t.float().cpu().numpy(), 0, 1) * pmax | |||||
| return img_np.astype(imtype) | |||||
| def preprocess(self, input: Input) -> Dict[str, Any]: | |||||
| img = LoadImage.convert_to_ndarray(input) | |||||
| img_sr = None | |||||
| if self.use_sr: | |||||
| self.sr_model.eval() | |||||
| with torch.no_grad(): | |||||
| img_t = self.img2tensor(img, is_norm=False) | |||||
| img_out = self.sr_model(img_t) | |||||
| img_sr = img_out.squeeze(0).permute(1, 2, 0).flip(2).cpu().clamp_( | |||||
| 0, 1).numpy() | |||||
| img_sr = (img_sr * 255.0).round().astype(np.uint8) | |||||
| img = cv2.resize(img, img_sr.shape[:2][::-1]) | |||||
| result = {'img': img, 'img_sr': img_sr} | |||||
| return result | |||||
| def forward(self, input: Dict[str, Any]) -> Dict[str, Any]: | |||||
| img, img_sr = input['img'], input['img_sr'] | |||||
| img, img_sr = img.cpu().numpy(), img_sr.cpu().numpy() | |||||
| facebs, landms = self.face_detector.detect(img) | |||||
| height, width = img.shape[:2] | |||||
| full_mask = np.zeros(img.shape, dtype=np.float32) | |||||
| full_img = np.zeros(img.shape, dtype=np.uint8) | |||||
| for i, (faceb, facial5points) in enumerate(zip(facebs, landms)): | |||||
| if faceb[4] < self.threshold: | |||||
| continue | |||||
| # fh, fw = (faceb[3] - faceb[1]), (faceb[2] - faceb[0]) | |||||
| facial5points = np.reshape(facial5points, (2, 5)) | |||||
| of, of_112, tfm_inv = warp_and_crop_face( | |||||
| img, facial5points, crop_size=(self.size, self.size)) | |||||
| # detect orig face quality | |||||
| fq_o, fea_o = self.eqface.get_face_quality(of_112) | |||||
| if fq_o < self.fqa_thres: | |||||
| continue | |||||
| # enhance the face | |||||
| ef = self.enhance_face(of) | |||||
| # detect enhanced face quality | |||||
| ss = self.size // 256 | |||||
| ef_112 = cv2.resize(ef[35 * ss:-33 * ss, 32 * ss:-36 * ss], | |||||
| (112, 112)) # crop roi | |||||
| fq_e, fea_e = self.eqface.get_face_quality(ef_112) | |||||
| dist = squareform(pdist([fea_o, fea_e], 'cosine')).mean() | |||||
| if dist > self.id_thres: | |||||
| continue | |||||
| # blending parameter | |||||
| fq = max(1., (fq_o - self.fqa_thres)) | |||||
| fq = (1 - 2 * dist) * (1.0 / (1 + math.exp(-(2 * fq - 1)))) | |||||
| # blend face | |||||
| ef = cv2.addWeighted(ef, fq * self.alpha, of, 1 - fq * self.alpha, | |||||
| 0.0) | |||||
| tmp_mask = self.mask | |||||
| tmp_mask = cv2.resize(tmp_mask, ef.shape[:2]) | |||||
| tmp_mask = cv2.warpAffine( | |||||
| tmp_mask, tfm_inv, (width, height), flags=3) | |||||
| tmp_img = cv2.warpAffine(ef, tfm_inv, (width, height), flags=3) | |||||
| mask = np.clip(tmp_mask - full_mask, 0, 1) | |||||
| full_mask[np.where(mask > 0)] = tmp_mask[np.where(mask > 0)] | |||||
| full_img[np.where(mask > 0)] = tmp_img[np.where(mask > 0)] | |||||
| if self.use_sr and img_sr is not None: | |||||
| out_img = cv2.convertScaleAbs(img_sr * (1 - full_mask) | |||||
| + full_img * full_mask) | |||||
| else: | |||||
| out_img = cv2.convertScaleAbs(img * (1 - full_mask) | |||||
| + full_img * full_mask) | |||||
| return {OutputKeys.OUTPUT_IMG: out_img.astype(np.uint8)} | |||||
| def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: | |||||
| return inputs | |||||
| @@ -163,6 +163,32 @@ class ImageDenoisePreprocessor(Preprocessor): | |||||
| return data | return data | ||||
| @PREPROCESSORS.register_module( | |||||
| Fields.cv, | |||||
| module_name=Preprocessors.image_portrait_enhancement_preprocessor) | |||||
| class ImagePortraitEnhancementPreprocessor(Preprocessor): | |||||
| def __init__(self, model_dir: str, *args, **kwargs): | |||||
| """ | |||||
| Args: | |||||
| model_dir (str): model path | |||||
| """ | |||||
| super().__init__(*args, **kwargs) | |||||
| self.model_dir: str = model_dir | |||||
| def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]: | |||||
| """process the raw input data | |||||
| Args: | |||||
| data Dict[str, Any] | |||||
| Returns: | |||||
| Dict[str, Any]: the preprocessed data | |||||
| """ | |||||
| return data | |||||
| @PREPROCESSORS.register_module( | @PREPROCESSORS.register_module( | ||||
| Fields.cv, | Fields.cv, | ||||
| module_name=Preprocessors.image_instance_segmentation_preprocessor) | module_name=Preprocessors.image_instance_segmentation_preprocessor) | ||||
| @@ -1,6 +1,7 @@ | |||||
| from .base import DummyTrainer | from .base import DummyTrainer | ||||
| from .builder import build_trainer | from .builder import build_trainer | ||||
| from .cv import ImageInstanceSegmentationTrainer | |||||
| from .cv import (ImageInstanceSegmentationTrainer, | |||||
| ImagePortraitEnhancementTrainer) | |||||
| from .multi_modal import CLIPTrainer | from .multi_modal import CLIPTrainer | ||||
| from .nlp import SequenceClassificationTrainer | from .nlp import SequenceClassificationTrainer | ||||
| from .trainer import EpochBasedTrainer | from .trainer import EpochBasedTrainer | ||||
| @@ -1,2 +1,3 @@ | |||||
| from .image_instance_segmentation_trainer import \ | from .image_instance_segmentation_trainer import \ | ||||
| ImageInstanceSegmentationTrainer | ImageInstanceSegmentationTrainer | ||||
| from .image_portrait_enhancement_trainer import ImagePortraitEnhancementTrainer | |||||
| @@ -0,0 +1,148 @@ | |||||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||||
| from collections.abc import Mapping | |||||
| import torch | |||||
| from torch import distributed as dist | |||||
| from modelscope.trainers.builder import TRAINERS | |||||
| from modelscope.trainers.optimizer.builder import build_optimizer | |||||
| from modelscope.trainers.trainer import EpochBasedTrainer | |||||
| from modelscope.utils.constant import ModeKeys | |||||
| from modelscope.utils.logger import get_logger | |||||
| @TRAINERS.register_module(module_name='gpen') | |||||
| class ImagePortraitEnhancementTrainer(EpochBasedTrainer): | |||||
| def train_step(self, model, inputs): | |||||
| """ Perform a training step on a batch of inputs. | |||||
| Subclass and override to inject custom behavior. | |||||
| Args: | |||||
| model (`TorchModel`): The model to train. | |||||
| inputs (`Dict[str, Union[torch.Tensor, Any]]`): | |||||
| The inputs and targets of the model. | |||||
| The dictionary will be unpacked before being fed to the model. Most models expect the targets under the | |||||
| argument `labels`. Check your model's documentation for all accepted arguments. | |||||
| Return: | |||||
| `torch.Tensor`: The tensor with training loss on this batch. | |||||
| """ | |||||
| # EvaluationHook will do evaluate and change mode to val, return to train mode | |||||
| # TODO: find more pretty way to change mode | |||||
| self.d_reg_every = self.cfg.train.get('d_reg_every', 16) | |||||
| self.g_reg_every = self.cfg.train.get('g_reg_every', 4) | |||||
| self.path_regularize = self.cfg.train.get('path_regularize', 2) | |||||
| self.r1 = self.cfg.train.get('r1', 10) | |||||
| train_outputs = dict() | |||||
| self._mode = ModeKeys.TRAIN | |||||
| inputs = self.collate_fn(inputs) | |||||
| # call model forward but not __call__ to skip postprocess | |||||
| if isinstance(inputs, Mapping): | |||||
| d_loss = model._train_forward_d(**inputs) | |||||
| else: | |||||
| d_loss = model._train_forward_d(inputs) | |||||
| train_outputs['d_loss'] = d_loss | |||||
| model.discriminator.zero_grad() | |||||
| d_loss.backward() | |||||
| self.optimizer_d.step() | |||||
| if self._iter % self.d_reg_every == 0: | |||||
| if isinstance(inputs, Mapping): | |||||
| r1_loss = model._train_forward_d_r1(**inputs) | |||||
| else: | |||||
| r1_loss = model._train_forward_d_r1(inputs) | |||||
| train_outputs['r1_loss'] = r1_loss | |||||
| model.discriminator.zero_grad() | |||||
| (self.r1 / 2 * r1_loss * self.d_reg_every).backward() | |||||
| self.optimizer_d.step() | |||||
| if isinstance(inputs, Mapping): | |||||
| g_loss = model._train_forward_g(**inputs) | |||||
| else: | |||||
| g_loss = model._train_forward_g(inputs) | |||||
| train_outputs['g_loss'] = g_loss | |||||
| model.generator.zero_grad() | |||||
| g_loss.backward() | |||||
| self.optimizer.step() | |||||
| path_loss = 0 | |||||
| if self._iter % self.g_reg_every == 0: | |||||
| if isinstance(inputs, Mapping): | |||||
| path_loss = model._train_forward_g_path(**inputs) | |||||
| else: | |||||
| path_loss = model._train_forward_g_path(inputs) | |||||
| train_outputs['path_loss'] = path_loss | |||||
| model.generator.zero_grad() | |||||
| weighted_path_loss = self.path_regularize * self.g_reg_every * path_loss | |||||
| weighted_path_loss.backward() | |||||
| self.optimizer.step() | |||||
| model.accumulate() | |||||
| if not isinstance(train_outputs, dict): | |||||
| raise TypeError('"model.forward()" must return a dict') | |||||
| # add model output info to log | |||||
| if 'log_vars' not in train_outputs: | |||||
| default_keys_pattern = ['loss'] | |||||
| match_keys = set([]) | |||||
| for key_p in default_keys_pattern: | |||||
| match_keys.update( | |||||
| [key for key in train_outputs.keys() if key_p in key]) | |||||
| log_vars = {} | |||||
| for key in match_keys: | |||||
| value = train_outputs.get(key, None) | |||||
| if value is not None: | |||||
| if dist.is_available() and dist.is_initialized(): | |||||
| value = value.data.clone() | |||||
| dist.all_reduce(value.div_(dist.get_world_size())) | |||||
| log_vars.update({key: value.item()}) | |||||
| self.log_buffer.update(log_vars) | |||||
| else: | |||||
| self.log_buffer.update(train_outputs['log_vars']) | |||||
| self.train_outputs = train_outputs | |||||
| def create_optimizer_and_scheduler(self): | |||||
| """ Create optimizer and lr scheduler | |||||
| We provide a default implementation, if you want to customize your own optimizer | |||||
| and lr scheduler, you can either pass a tuple through trainer init function or | |||||
| subclass this class and override this method. | |||||
| """ | |||||
| optimizer, lr_scheduler = self.optimizers | |||||
| if optimizer is None: | |||||
| optimizer_cfg = self.cfg.train.get('optimizer', None) | |||||
| else: | |||||
| optimizer_cfg = None | |||||
| optimizer_d_cfg = self.cfg.train.get('optimizer_d', None) | |||||
| optim_options = {} | |||||
| if optimizer_cfg is not None: | |||||
| optim_options = optimizer_cfg.pop('options', {}) | |||||
| optimizer = build_optimizer( | |||||
| self.model.generator, cfg=optimizer_cfg) | |||||
| if optimizer_d_cfg is not None: | |||||
| optimizer_d = build_optimizer( | |||||
| self.model.discriminator, cfg=optimizer_d_cfg) | |||||
| lr_options = {} | |||||
| self.optimizer = optimizer | |||||
| self.lr_scheduler = lr_scheduler | |||||
| self.optimizer_d = optimizer_d | |||||
| return self.optimizer, self.lr_scheduler, optim_options, lr_options | |||||
| @@ -14,5 +14,5 @@ __all__ = [ | |||||
| 'Hook', 'HOOKS', 'CheckpointHook', 'EvaluationHook', 'LrSchedulerHook', | 'Hook', 'HOOKS', 'CheckpointHook', 'EvaluationHook', 'LrSchedulerHook', | ||||
| 'OptimizerHook', 'Priority', 'build_hook', 'TextLoggerHook', | 'OptimizerHook', 'Priority', 'build_hook', 'TextLoggerHook', | ||||
| 'IterTimerHook', 'TorchAMPOptimizerHook', 'ApexAMPOptimizerHook', | 'IterTimerHook', 'TorchAMPOptimizerHook', 'ApexAMPOptimizerHook', | ||||
| 'BestCkptSaverHook' | |||||
| 'BestCkptSaverHook', 'NoneOptimizerHook', 'NoneLrSchedulerHook' | |||||
| ] | ] | ||||
| @@ -115,3 +115,18 @@ class PlateauLrSchedulerHook(LrSchedulerHook): | |||||
| self.warmup_lr_scheduler.step(metrics=metrics) | self.warmup_lr_scheduler.step(metrics=metrics) | ||||
| else: | else: | ||||
| trainer.lr_scheduler.step(metrics=metrics) | trainer.lr_scheduler.step(metrics=metrics) | ||||
| @HOOKS.register_module() | |||||
| class NoneLrSchedulerHook(LrSchedulerHook): | |||||
| PRIORITY = Priority.LOW # should be after EvaluationHook | |||||
| def __init__(self, by_epoch=True, warmup=None) -> None: | |||||
| super().__init__(by_epoch=by_epoch, warmup=warmup) | |||||
| def before_run(self, trainer): | |||||
| return | |||||
| def after_train_epoch(self, trainer): | |||||
| return | |||||
| @@ -200,3 +200,19 @@ class ApexAMPOptimizerHook(OptimizerHook): | |||||
| trainer.optimizer.step() | trainer.optimizer.step() | ||||
| trainer.optimizer.zero_grad() | trainer.optimizer.zero_grad() | ||||
| @HOOKS.register_module() | |||||
| class NoneOptimizerHook(OptimizerHook): | |||||
| def __init__(self, cumulative_iters=1, grad_clip=None, loss_keys='loss'): | |||||
| super(NoneOptimizerHook, self).__init__( | |||||
| grad_clip=grad_clip, loss_keys=loss_keys) | |||||
| self.cumulative_iters = cumulative_iters | |||||
| def before_run(self, trainer): | |||||
| return | |||||
| def after_train_iter(self, trainer): | |||||
| return | |||||
| @@ -43,6 +43,7 @@ class CVTasks(object): | |||||
| image_colorization = 'image-colorization' | image_colorization = 'image-colorization' | ||||
| image_color_enhancement = 'image-color-enhancement' | image_color_enhancement = 'image-color-enhancement' | ||||
| image_denoising = 'image-denoising' | image_denoising = 'image-denoising' | ||||
| image_portrait_enhancement = 'image-portrait-enhancement' | |||||
| # image generation | # image generation | ||||
| image_to_image_translation = 'image-to-image-translation' | image_to_image_translation = 'image-to-image-translation' | ||||
| @@ -0,0 +1,43 @@ | |||||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||||
| import os | |||||
| import os.path as osp | |||||
| import unittest | |||||
| import cv2 | |||||
| from modelscope.msdatasets import MsDataset | |||||
| from modelscope.outputs import OutputKeys | |||||
| from modelscope.pipelines import pipeline | |||||
| from modelscope.pipelines.base import Pipeline | |||||
| from modelscope.utils.constant import Tasks | |||||
| from modelscope.utils.test_utils import test_level | |||||
| class ImagePortraitEnhancementTest(unittest.TestCase): | |||||
| def setUp(self) -> None: | |||||
| self.model_id = 'damo/cv_gpen_image-portrait-enhancement' | |||||
| self.test_image = 'data/test/images/Solvay_conference_1927.png' | |||||
| def pipeline_inference(self, pipeline: Pipeline, test_image: str): | |||||
| result = pipeline(test_image) | |||||
| if result is not None: | |||||
| cv2.imwrite('result.png', result[OutputKeys.OUTPUT_IMG]) | |||||
| print(f'Output written to {osp.abspath("result.png")}') | |||||
| else: | |||||
| raise Exception('Testing failed: invalid output') | |||||
| @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') | |||||
| def test_run_modelhub(self): | |||||
| face_enhancement = pipeline( | |||||
| Tasks.image_portrait_enhancement, model=self.model_id) | |||||
| self.pipeline_inference(face_enhancement, self.test_image) | |||||
| @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') | |||||
| def test_run_modelhub_default_model(self): | |||||
| face_enhancement = pipeline(Tasks.image_portrait_enhancement) | |||||
| self.pipeline_inference(face_enhancement, self.test_image) | |||||
| if __name__ == '__main__': | |||||
| unittest.main() | |||||
| @@ -0,0 +1,119 @@ | |||||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||||
| import os | |||||
| import os.path as osp | |||||
| import shutil | |||||
| import tempfile | |||||
| import unittest | |||||
| from typing import Callable, List, Optional, Tuple, Union | |||||
| import cv2 | |||||
| import torch | |||||
| from torch.utils import data as data | |||||
| from modelscope.hub.snapshot_download import snapshot_download | |||||
| from modelscope.models.cv.image_portrait_enhancement import \ | |||||
| ImagePortraitEnhancement | |||||
| from modelscope.trainers import build_trainer | |||||
| from modelscope.utils.constant import ModelFile | |||||
| from modelscope.utils.test_utils import test_level | |||||
| class TestImagePortraitEnhancementTrainer(unittest.TestCase): | |||||
| def setUp(self): | |||||
| print(('Testing %s.%s' % (type(self).__name__, self._testMethodName))) | |||||
| self.tmp_dir = tempfile.TemporaryDirectory().name | |||||
| if not os.path.exists(self.tmp_dir): | |||||
| os.makedirs(self.tmp_dir) | |||||
| self.model_id = 'damo/cv_gpen_image-portrait-enhancement' | |||||
| class PairedImageDataset(data.Dataset): | |||||
| def __init__(self, root, size=512): | |||||
| super(PairedImageDataset, self).__init__() | |||||
| self.size = size | |||||
| gt_dir = osp.join(root, 'gt') | |||||
| lq_dir = osp.join(root, 'lq') | |||||
| self.gt_filelist = os.listdir(gt_dir) | |||||
| self.gt_filelist = sorted( | |||||
| self.gt_filelist, key=lambda x: int(x[:-4])) | |||||
| self.gt_filelist = [ | |||||
| osp.join(gt_dir, f) for f in self.gt_filelist | |||||
| ] | |||||
| self.lq_filelist = os.listdir(lq_dir) | |||||
| self.lq_filelist = sorted( | |||||
| self.lq_filelist, key=lambda x: int(x[:-4])) | |||||
| self.lq_filelist = [ | |||||
| osp.join(lq_dir, f) for f in self.lq_filelist | |||||
| ] | |||||
| def _img_to_tensor(self, img): | |||||
| img = torch.from_numpy(img[:, :, [2, 1, 0]]).permute( | |||||
| 2, 0, 1).type(torch.float32) / 255. | |||||
| return (img - 0.5) / 0.5 | |||||
| def __getitem__(self, index): | |||||
| lq = cv2.imread(self.lq_filelist[index]) | |||||
| gt = cv2.imread(self.gt_filelist[index]) | |||||
| lq = cv2.resize( | |||||
| lq, (self.size, self.size), interpolation=cv2.INTER_CUBIC) | |||||
| gt = cv2.resize( | |||||
| gt, (self.size, self.size), interpolation=cv2.INTER_CUBIC) | |||||
| return \ | |||||
| {'src': self._img_to_tensor(lq), 'target': self._img_to_tensor(gt)} | |||||
| def __len__(self): | |||||
| return len(self.gt_filelist) | |||||
| def to_torch_dataset(self, | |||||
| columns: Union[str, List[str]] = None, | |||||
| preprocessors: Union[Callable, | |||||
| List[Callable]] = None, | |||||
| **format_kwargs): | |||||
| # self.preprocessor = preprocessors | |||||
| return self | |||||
| self.dataset = PairedImageDataset( | |||||
| './data/test/images/face_enhancement/') | |||||
| def tearDown(self): | |||||
| shutil.rmtree(self.tmp_dir, ignore_errors=True) | |||||
| super().tearDown() | |||||
| @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | |||||
| def test_trainer(self): | |||||
| kwargs = dict( | |||||
| model=self.model_id, | |||||
| train_dataset=self.dataset, | |||||
| eval_dataset=self.dataset, | |||||
| device='gpu', | |||||
| work_dir=self.tmp_dir) | |||||
| trainer = build_trainer(name='gpen', default_args=kwargs) | |||||
| trainer.train() | |||||
| @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') | |||||
| def test_trainer_with_model_and_args(self): | |||||
| tmp_dir = tempfile.TemporaryDirectory().name | |||||
| if not os.path.exists(tmp_dir): | |||||
| os.makedirs(tmp_dir) | |||||
| cache_path = snapshot_download(self.model_id) | |||||
| model = ImagePortraitEnhancement.from_pretrained(cache_path) | |||||
| kwargs = dict( | |||||
| cfg_file=os.path.join(cache_path, ModelFile.CONFIGURATION), | |||||
| model=model, | |||||
| train_dataset=self.dataset, | |||||
| eval_dataset=self.dataset, | |||||
| device='gpu', | |||||
| max_epochs=2, | |||||
| work_dir=self.tmp_dir) | |||||
| trainer = build_trainer(name='gpen', default_args=kwargs) | |||||
| trainer.train() | |||||
| if __name__ == '__main__': | |||||
| unittest.main() | |||||