Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9590794master
| @@ -0,0 +1,3 @@ | |||
| version https://git-lfs.github.com/spec/v1 | |||
| oid sha256:fa8ab905e8374a0f94b4bfbfc81da14e762c71eaf64bae85bdd03b07cdf884c2 | |||
| size 859206 | |||
| @@ -0,0 +1,3 @@ | |||
| version https://git-lfs.github.com/spec/v1 | |||
| oid sha256:8cd14710143ba1a912e3ef574d0bf71c7e40bf9897522cba07ecae2567343064 | |||
| size 850603 | |||
| @@ -0,0 +1,3 @@ | |||
| version https://git-lfs.github.com/spec/v1 | |||
| oid sha256:d7f166ecb3a6913dbd05a1eb271399cbaa731d1074ac03184c13ae245ca66819 | |||
| size 800380 | |||
| @@ -0,0 +1,3 @@ | |||
| version https://git-lfs.github.com/spec/v1 | |||
| oid sha256:e95d11661485fc0e6f326398f953459dcb3e65b7f4a6c892611266067cf8fe3a | |||
| size 245773 | |||
| @@ -0,0 +1,3 @@ | |||
| version https://git-lfs.github.com/spec/v1 | |||
| oid sha256:03972400b20b3e6f1d056b359d9c9f12952653a67a73b36018504ce9ee9edf9d | |||
| size 254261 | |||
| @@ -16,6 +16,7 @@ class Models(object): | |||
| nafnet = 'nafnet' | |||
| csrnet = 'csrnet' | |||
| cascade_mask_rcnn_swin = 'cascade_mask_rcnn_swin' | |||
| gpen = 'gpen' | |||
| product_retrieval_embedding = 'product-retrieval-embedding' | |||
| # nlp models | |||
| @@ -91,6 +92,7 @@ class Pipelines(object): | |||
| image2image_translation = 'image-to-image-translation' | |||
| live_category = 'live-category' | |||
| video_category = 'video-category' | |||
| image_portrait_enhancement = 'gpen-image-portrait-enhancement' | |||
| image_to_image_generation = 'image-to-image-generation' | |||
| # nlp tasks | |||
| @@ -160,6 +162,7 @@ class Preprocessors(object): | |||
| image_denoie_preprocessor = 'image-denoise-preprocessor' | |||
| image_color_enhance_preprocessor = 'image-color-enhance-preprocessor' | |||
| image_instance_segmentation_preprocessor = 'image-instance-segmentation-preprocessor' | |||
| image_portrait_enhancement_preprocessor = 'image-portrait-enhancement-preprocessor' | |||
| # nlp preprocessor | |||
| sen_sim_tokenizer = 'sen-sim-tokenizer' | |||
| @@ -207,3 +210,5 @@ class Metrics(object): | |||
| text_gen_metric = 'text-gen-metric' | |||
| # metrics for image-color-enhance task | |||
| image_color_enhance_metric = 'image-color-enhance-metric' | |||
| # metrics for image-portrait-enhancement task | |||
| image_portrait_enhancement_metric = 'image-portrait-enhancement-metric' | |||
| @@ -10,6 +10,7 @@ if TYPE_CHECKING: | |||
| from .image_denoise_metric import ImageDenoiseMetric | |||
| from .image_instance_segmentation_metric import \ | |||
| ImageInstanceSegmentationCOCOMetric | |||
| from .image_portrait_enhancement_metric import ImagePortraitEnhancementMetric | |||
| from .sequence_classification_metric import SequenceClassificationMetric | |||
| from .text_generation_metric import TextGenerationMetric | |||
| @@ -21,6 +22,8 @@ else: | |||
| 'image_denoise_metric': ['ImageDenoiseMetric'], | |||
| 'image_instance_segmentation_metric': | |||
| ['ImageInstanceSegmentationCOCOMetric'], | |||
| 'image_portrait_enhancement_metric': | |||
| ['ImagePortraitEnhancementMetric'], | |||
| 'sequence_classification_metric': ['SequenceClassificationMetric'], | |||
| 'text_generation_metric': ['TextGenerationMetric'], | |||
| } | |||
| @@ -23,7 +23,9 @@ task_default_metrics = { | |||
| Tasks.sentiment_classification: [Metrics.seq_cls_metric], | |||
| Tasks.text_generation: [Metrics.text_gen_metric], | |||
| Tasks.image_denoising: [Metrics.image_denoise_metric], | |||
| Tasks.image_color_enhancement: [Metrics.image_color_enhance_metric] | |||
| Tasks.image_color_enhancement: [Metrics.image_color_enhance_metric], | |||
| Tasks.image_portrait_enhancement: | |||
| [Metrics.image_portrait_enhancement_metric], | |||
| } | |||
| @@ -0,0 +1,47 @@ | |||
| from typing import Dict | |||
| import numpy as np | |||
| from modelscope.metainfo import Metrics | |||
| from modelscope.utils.registry import default_group | |||
| from .base import Metric | |||
| from .builder import METRICS, MetricKeys | |||
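| # PSNR for 8-bit images: 10 * log10(255^2 / MSE); identical images yield inf. | |||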
| def calculate_psnr(img, img2): | |||
| assert img.shape == img2.shape, ( | |||
| f'Image shapes are different: {img.shape}, {img2.shape}.') | |||
| img = img.astype(np.float64) | |||
| img2 = img2.astype(np.float64) | |||
| mse = np.mean((img - img2)**2) | |||
| if mse == 0: | |||
| return float('inf') | |||
| return 10. * np.log10(255. * 255. / mse) | |||
| @METRICS.register_module( | |||
| group_key=default_group, | |||
| module_name=Metrics.image_portrait_enhancement_metric) | |||
| class ImagePortraitEnhancementMetric(Metric): | |||
| """The metric for image-portrait-enhancement task. | |||
| """ | |||
| def __init__(self): | |||
| self.preds = [] | |||
| self.targets = [] | |||
| def add(self, outputs: Dict, inputs: Dict): | |||
| ground_truths = outputs['target'] | |||
| eval_results = outputs['pred'] | |||
| self.preds.extend(eval_results) | |||
| self.targets.extend(ground_truths) | |||
| def evaluate(self): | |||
| psnrs = [ | |||
| calculate_psnr(pred, target) | |||
| for pred, target in zip(self.preds, self.targets) | |||
| ] | |||
| return {MetricKeys.PSNR: sum(psnrs) / len(psnrs)} | |||
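| # Minimal usage sketch (assumes uint8 HxWx3 numpy arrays `pred` and `target`): | |||
| #   metric = ImagePortraitEnhancementMetric() | |||
| #   metric.add({'pred': [pred], 'target': [target]}, {}) | |||
| #   metric.evaluate()  # -> {MetricKeys.PSNR: ...} | |||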
| @@ -3,6 +3,6 @@ from . import (action_recognition, animal_recognition, cartoon, | |||
| cmdssl_video_embedding, face_detection, face_generation, | |||
| image_classification, image_color_enhance, image_colorization, | |||
| image_denoise, image_instance_segmentation, | |||
| image_to_image_generation, image_to_image_translation, | |||
| object_detection, product_retrieval_embedding, super_resolution, | |||
| virual_tryon) | |||
| image_portrait_enhancement, image_to_image_generation, | |||
| image_to_image_translation, object_detection, | |||
| product_retrieval_embedding, super_resolution, virual_tryon) | |||
| @@ -0,0 +1,22 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| from typing import TYPE_CHECKING | |||
| from modelscope.utils.import_utils import LazyImportModule | |||
| if TYPE_CHECKING: | |||
| from .image_portrait_enhancement import ImagePortraitEnhancement | |||
| else: | |||
| _import_structure = { | |||
| 'image_portrait_enhancement': ['ImagePortraitEnhancement'] | |||
| } | |||
| import sys | |||
| sys.modules[__name__] = LazyImportModule( | |||
| __name__, | |||
| globals()['__file__'], | |||
| _import_structure, | |||
| module_spec=__spec__, | |||
| extra_objects={}, | |||
| ) | |||
| @@ -0,0 +1,252 @@ | |||
| import cv2 | |||
| import numpy as np | |||
| from skimage import transform as trans | |||
| from modelscope.utils.logger import get_logger | |||
| logger = get_logger() | |||
| # reference facial points, a list of coordinates (x,y) | |||
| REFERENCE_FACIAL_POINTS = [[30.29459953, 51.69630051], | |||
| [65.53179932, 51.50139999], | |||
| [48.02519989, | |||
| 71.73660278], [33.54930115, 92.3655014], | |||
| [62.72990036, 92.20410156]] | |||
| DEFAULT_CROP_SIZE = (96, 112) | |||
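| # The five points are the eye centres, nose tip and mouth corners in the | |||
| # canonical 96x112 face crop. | |||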
| def _umeyama(src, dst, estimate_scale=True, scale=1.0): | |||
| """Estimate N-D similarity transformation with or without scaling. | |||
| Parameters | |||
| ---------- | |||
| src : (M, N) array | |||
| Source coordinates. | |||
| dst : (M, N) array | |||
| Destination coordinates. | |||
| estimate_scale : bool | |||
| Whether to estimate scaling factor. | |||
| Returns | |||
| ------- | |||
| T : (N + 1, N + 1) | |||
| The homogeneous similarity transformation matrix. The matrix contains | |||
| NaN values only if the problem is not well-conditioned. | |||
| References | |||
| ---------- | |||
| .. [1] "Least-squares estimation of transformation parameters between two | |||
| point patterns", Shinji Umeyama, PAMI 1991, :DOI:`10.1109/34.88573` | |||
| """ | |||
| num = src.shape[0] | |||
| dim = src.shape[1] | |||
| # Compute mean of src and dst. | |||
| src_mean = src.mean(axis=0) | |||
| dst_mean = dst.mean(axis=0) | |||
| # Subtract mean from src and dst. | |||
| src_demean = src - src_mean | |||
| dst_demean = dst - dst_mean | |||
| # Eq. (38). | |||
| A = dst_demean.T @ src_demean / num | |||
| # Eq. (39). | |||
| d = np.ones((dim, ), dtype=np.double) | |||
| if np.linalg.det(A) < 0: | |||
| d[dim - 1] = -1 | |||
| T = np.eye(dim + 1, dtype=np.double) | |||
| U, S, V = np.linalg.svd(A) | |||
| # Eq. (40) and (43). | |||
| rank = np.linalg.matrix_rank(A) | |||
| if rank == 0: | |||
| return np.nan * T | |||
| elif rank == dim - 1: | |||
| if np.linalg.det(U) * np.linalg.det(V) > 0: | |||
| T[:dim, :dim] = U @ V | |||
| else: | |||
| s = d[dim - 1] | |||
| d[dim - 1] = -1 | |||
| T[:dim, :dim] = U @ np.diag(d) @ V | |||
| d[dim - 1] = s | |||
| else: | |||
| T[:dim, :dim] = U @ np.diag(d) @ V | |||
| if estimate_scale: | |||
| # Eq. (41) and (42). | |||
| scale = 1.0 / src_demean.var(axis=0).sum() * (S @ d) | |||
| else: | |||
| scale = scale | |||
| T[:dim, dim] = dst_mean - scale * (T[:dim, :dim] @ src_mean.T) | |||
| T[:dim, :dim] *= scale | |||
| return T, scale | |||
| class FaceWarpException(Exception): | |||
| def __str__(self): | |||
| return 'In File {}:{}'.format(__file__, super().__str__()) | |||
| def get_reference_facial_points(output_size=None, | |||
| inner_padding_factor=0.0, | |||
| outer_padding=(0, 0), | |||
| default_square=False): | |||
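| # Scale and pad the canonical 96x112 five-point template so it matches the | |||
| # requested output_size: inner_padding_factor grows the face region, | |||
| # outer_padding adds a border around it. | |||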
| ref_5pts = np.array(REFERENCE_FACIAL_POINTS) | |||
| ref_crop_size = np.array(DEFAULT_CROP_SIZE) | |||
| # 0) make the inner region a square | |||
| if default_square: | |||
| size_diff = max(ref_crop_size) - ref_crop_size | |||
| ref_5pts += size_diff / 2 | |||
| ref_crop_size += size_diff | |||
| if (output_size and output_size[0] == ref_crop_size[0] | |||
| and output_size[1] == ref_crop_size[1]): | |||
| return ref_5pts | |||
| if (inner_padding_factor == 0 and outer_padding == (0, 0)): | |||
| if output_size is None: | |||
| logger.info('No paddings to do: return default reference points') | |||
| return ref_5pts | |||
| else: | |||
| raise FaceWarpException( | |||
| 'No paddings to do, output_size must be None or {}'.format( | |||
| ref_crop_size)) | |||
| # check output size | |||
| if not (0 <= inner_padding_factor <= 1.0): | |||
| raise FaceWarpException('Not (0 <= inner_padding_factor <= 1.0)') | |||
| if ((inner_padding_factor > 0 or outer_padding[0] > 0 | |||
| or outer_padding[1] > 0) and output_size is None): | |||
| output_size = (ref_crop_size * (1 + inner_padding_factor * 2)).astype(np.int32) | |||
| output_size += np.array(outer_padding) | |||
| logger.info('deduced from paddings, output_size = %s', output_size) | |||
| if not (outer_padding[0] < output_size[0] | |||
| and outer_padding[1] < output_size[1]): | |||
| raise FaceWarpException('Not (outer_padding[0] < output_size[0]' | |||
| 'and outer_padding[1] < output_size[1])') | |||
| # 1) pad the inner region according inner_padding_factor | |||
| if inner_padding_factor > 0: | |||
| size_diff = ref_crop_size * inner_padding_factor * 2 | |||
| ref_5pts += size_diff / 2 | |||
| ref_crop_size += np.round(size_diff).astype(np.int32) | |||
| # 2) resize the padded inner region | |||
| size_bf_outer_pad = np.array(output_size) - np.array(outer_padding) * 2 | |||
| if size_bf_outer_pad[0] * ref_crop_size[1] != size_bf_outer_pad[ | |||
| 1] * ref_crop_size[0]: | |||
| raise FaceWarpException( | |||
| 'Must have (output_size - outer_padding)' | |||
| '= some_scale * (crop_size * (1.0 + inner_padding_factor)') | |||
| scale_factor = size_bf_outer_pad[0].astype(np.float32) / ref_crop_size[0] | |||
| ref_5pts = ref_5pts * scale_factor | |||
| ref_crop_size = size_bf_outer_pad | |||
| # 3) add outer_padding to make output_size | |||
| reference_5point = ref_5pts + np.array(outer_padding) | |||
| ref_crop_size = output_size | |||
| return reference_5point | |||
| def get_affine_transform_matrix(src_pts, dst_pts): | |||
| tfm = np.float32([[1, 0, 0], [0, 1, 0]]) | |||
| n_pts = src_pts.shape[0] | |||
| ones = np.ones((n_pts, 1), src_pts.dtype) | |||
| src_pts_ = np.hstack([src_pts, ones]) | |||
| dst_pts_ = np.hstack([dst_pts, ones]) | |||
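| # Solve src_pts_ @ A ~= dst_pts_ in the least-squares sense; the first two | |||
| # columns of A, transposed, give the 2x3 affine matrix. | |||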
| A, res, rank, s = np.linalg.lstsq(src_pts_, dst_pts_, rcond=None) | |||
| if rank == 3: | |||
| tfm = np.float32([[A[0, 0], A[1, 0], A[2, 0]], | |||
| [A[0, 1], A[1, 1], A[2, 1]]]) | |||
| elif rank == 2: | |||
| tfm = np.float32([[A[0, 0], A[1, 0], 0], [A[0, 1], A[1, 1], 0]]) | |||
| return tfm | |||
| def get_params(reference_pts, facial_pts, align_type): | |||
| ref_pts = np.float32(reference_pts) | |||
| ref_pts_shp = ref_pts.shape | |||
| if max(ref_pts_shp) < 3 or min(ref_pts_shp) != 2: | |||
| raise FaceWarpException( | |||
| 'reference_pts.shape must be (K,2) or (2,K) and K>2') | |||
| if ref_pts_shp[0] == 2: | |||
| ref_pts = ref_pts.T | |||
| src_pts = np.float32(facial_pts) | |||
| src_pts_shp = src_pts.shape | |||
| if max(src_pts_shp) < 3 or min(src_pts_shp) != 2: | |||
| raise FaceWarpException( | |||
| 'facial_pts.shape must be (K,2) or (2,K) and K>2') | |||
| if src_pts_shp[0] == 2: | |||
| src_pts = src_pts.T | |||
| if src_pts.shape != ref_pts.shape: | |||
| raise FaceWarpException( | |||
| 'facial_pts and reference_pts must have the same shape') | |||
| if align_type == 'cv2_affine': | |||
| tfm = cv2.getAffineTransform(src_pts[0:3], ref_pts[0:3]) | |||
| tfm_inv = cv2.getAffineTransform(ref_pts[0:3], src_pts[0:3]) | |||
| elif align_type == 'affine': | |||
| tfm = get_affine_transform_matrix(src_pts, ref_pts) | |||
| tfm_inv = get_affine_transform_matrix(ref_pts, src_pts) | |||
| else: | |||
| params, scale = _umeyama(src_pts, ref_pts) | |||
| tfm = params[:2, :] | |||
| params, _ = _umeyama(ref_pts, src_pts, False, scale=1.0 / scale) | |||
| tfm_inv = params[:2, :] | |||
| return tfm, tfm_inv | |||
| def warp_and_crop_face(src_img, | |||
| facial_pts, | |||
| reference_pts=None, | |||
| crop_size=(96, 112), | |||
| align_type='similarity'):  # 'similarity', 'cv2_affine' or 'affine' | |||
| reference_pts_112 = get_reference_facial_points((112, 112), 0.25, (0, 0), | |||
| True) | |||
| if reference_pts is None: | |||
| if crop_size[0] == 96 and crop_size[1] == 112: | |||
| reference_pts = REFERENCE_FACIAL_POINTS | |||
| else: | |||
| default_square = True # False | |||
| inner_padding_factor = 0.25 # 0 | |||
| outer_padding = (0, 0) | |||
| output_size = crop_size | |||
| reference_pts = get_reference_facial_points( | |||
| output_size, inner_padding_factor, outer_padding, | |||
| default_square) | |||
| tfm, tfm_inv = get_params(reference_pts, facial_pts, align_type) | |||
| tfm_112, tfm_inv_112 = get_params(reference_pts_112, facial_pts, | |||
| align_type) | |||
| if src_img is not None: | |||
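| # flags=3 selects cv2.INTER_AREA interpolation. | |||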
| face_img = cv2.warpAffine( | |||
| src_img, tfm, (crop_size[0], crop_size[1]), flags=3) | |||
| face_img_112 = cv2.warpAffine(src_img, tfm_112, (112, 112), flags=3) | |||
| return face_img, face_img_112, tfm_inv | |||
| else: | |||
| return tfm, tfm_inv | |||
| @@ -0,0 +1,57 @@ | |||
| import os | |||
| import cv2 | |||
| import numpy as np | |||
| import torch | |||
| from .model_resnet import FaceQuality, ResNet | |||
| class FQA(object): | |||
| def __init__(self, backbone_path, quality_path, device='cuda', size=112): | |||
| self.BACKBONE = ResNet(num_layers=100, feature_dim=512) | |||
| self.QUALITY = FaceQuality(512 * 7 * 7) | |||
| self.size = size | |||
| self.device = device | |||
| self.load_model(backbone_path, quality_path) | |||
| def load_model(self, backbone_path, quality_path): | |||
| checkpoint = torch.load(backbone_path, map_location='cpu') | |||
| self.load_state_dict(self.BACKBONE, checkpoint) | |||
| checkpoint = torch.load(quality_path, map_location='cpu') | |||
| self.load_state_dict(self.QUALITY, checkpoint) | |||
| self.BACKBONE.to(self.device) | |||
| self.QUALITY.to(self.device) | |||
| self.BACKBONE.eval() | |||
| self.QUALITY.eval() | |||
| def load_state_dict(self, model, state_dict): | |||
| all_keys = {k for k in state_dict.keys()} | |||
| for k in all_keys: | |||
| if k.startswith('module.'): | |||
| state_dict[k[7:]] = state_dict.pop(k) | |||
| model_dict = model.state_dict() | |||
| pretrained_dict = { | |||
| k: v | |||
| for k, v in state_dict.items() | |||
| if k in model_dict and v.size() == model_dict[k].size() | |||
| } | |||
| model_dict.update(pretrained_dict) | |||
| model.load_state_dict(model_dict) | |||
| def get_face_quality(self, img): | |||
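| # Expects an aligned HxWx3 face crop; flip(1) reverses the channel order | |||
| # (BGR -> RGB for OpenCV inputs) and pixels are normalised to roughly [-1, 1]. | |||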
| img = torch.from_numpy(img).permute(2, 0, 1).unsqueeze(0).flip(1).to(self.device) | |||
| img = (img - 127.5) / 128.0 | |||
| # extract features & predict quality | |||
| with torch.no_grad(): | |||
| feature, fc = self.BACKBONE(img.to(self.device), True) | |||
| s = self.QUALITY(fc)[0] | |||
| return s.cpu().numpy()[0], feature.cpu().numpy()[0] | |||
| @@ -0,0 +1,130 @@ | |||
| import torch | |||
| from torch import nn | |||
| class BottleNeck_IR(nn.Module): | |||
| def __init__(self, in_channel, out_channel, stride, dim_match): | |||
| super(BottleNeck_IR, self).__init__() | |||
| self.res_layer = nn.Sequential( | |||
| nn.BatchNorm2d(in_channel), | |||
| nn.Conv2d(in_channel, out_channel, (3, 3), 1, 1, bias=False), | |||
| nn.BatchNorm2d(out_channel), nn.PReLU(out_channel), | |||
| nn.Conv2d(out_channel, out_channel, (3, 3), stride, 1, bias=False), | |||
| nn.BatchNorm2d(out_channel)) | |||
| if dim_match: | |||
| self.shortcut_layer = None | |||
| else: | |||
| self.shortcut_layer = nn.Sequential( | |||
| nn.Conv2d( | |||
| in_channel, | |||
| out_channel, | |||
| kernel_size=(1, 1), | |||
| stride=stride, | |||
| bias=False), nn.BatchNorm2d(out_channel)) | |||
| def forward(self, x): | |||
| shortcut = x | |||
| res = self.res_layer(x) | |||
| if self.shortcut_layer is not None: | |||
| shortcut = self.shortcut_layer(x) | |||
| return shortcut + res | |||
| channel_list = [64, 64, 128, 256, 512] | |||
| def get_layers(num_layers): | |||
| if num_layers == 34: | |||
| return [3, 4, 6, 3] | |||
| if num_layers == 50: | |||
| return [3, 4, 14, 3] | |||
| elif num_layers == 100: | |||
| return [3, 13, 30, 3] | |||
| elif num_layers == 152: | |||
| return [3, 8, 36, 3] | |||
| class ResNet(nn.Module): | |||
| def __init__(self, | |||
| num_layers=100, | |||
| feature_dim=512, | |||
| drop_ratio=0.4, | |||
| channel_list=channel_list): | |||
| super(ResNet, self).__init__() | |||
| assert num_layers in [34, 50, 100, 152] | |||
| layers = get_layers(num_layers) | |||
| block = BottleNeck_IR | |||
| self.input_layer = nn.Sequential( | |||
| nn.Conv2d( | |||
| 3, channel_list[0], (3, 3), stride=1, padding=1, bias=False), | |||
| nn.BatchNorm2d(channel_list[0]), nn.PReLU(channel_list[0])) | |||
| self.layer1 = self._make_layer( | |||
| block, channel_list[0], channel_list[1], layers[0], stride=2) | |||
| self.layer2 = self._make_layer( | |||
| block, channel_list[1], channel_list[2], layers[1], stride=2) | |||
| self.layer3 = self._make_layer( | |||
| block, channel_list[2], channel_list[3], layers[2], stride=2) | |||
| self.layer4 = self._make_layer( | |||
| block, channel_list[3], channel_list[4], layers[3], stride=2) | |||
| self.output_layer = nn.Sequential( | |||
| nn.BatchNorm2d(512), nn.Dropout(drop_ratio), nn.Flatten()) | |||
| self.feature_layer = nn.Sequential( | |||
| nn.Linear(512 * 7 * 7, feature_dim), nn.BatchNorm1d(feature_dim)) | |||
| for m in self.modules(): | |||
| if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear): | |||
| nn.init.xavier_uniform_(m.weight) | |||
| if m.bias is not None: | |||
| nn.init.constant_(m.bias, 0.0) | |||
| elif isinstance(m, nn.BatchNorm2d) or isinstance( | |||
| m, nn.BatchNorm1d): | |||
| nn.init.constant_(m.weight, 1) | |||
| nn.init.constant_(m.bias, 0) | |||
| def _make_layer(self, block, in_channel, out_channel, blocks, stride): | |||
| layers = [] | |||
| layers.append(block(in_channel, out_channel, stride, False)) | |||
| for i in range(1, blocks): | |||
| layers.append(block(out_channel, out_channel, 1, True)) | |||
| return nn.Sequential(*layers) | |||
| def forward(self, x, fc=False): | |||
| x = self.input_layer(x) | |||
| x = self.layer1(x) | |||
| x = self.layer2(x) | |||
| x = self.layer3(x) | |||
| x = self.layer4(x) | |||
| x = self.output_layer(x) | |||
| feature = self.feature_layer(x) | |||
| if fc: | |||
| return feature, x | |||
| return feature | |||
| class FaceQuality(nn.Module): | |||
| def __init__(self, feature_dim): | |||
| super(FaceQuality, self).__init__() | |||
| self.quality = nn.Sequential( | |||
| nn.Linear(feature_dim, 512, bias=False), nn.BatchNorm1d(512), | |||
| nn.ReLU(inplace=True), nn.Linear(512, 2, bias=False), | |||
| nn.Softmax(dim=1)) | |||
| for m in self.modules(): | |||
| if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear): | |||
| nn.init.xavier_uniform_(m.weight) | |||
| if m.bias is not None: | |||
| nn.init.constant_(m.bias, 0.0) | |||
| elif isinstance(m, nn.BatchNorm2d) or isinstance( | |||
| m, nn.BatchNorm1d): | |||
| nn.init.constant_(m.weight, 1) | |||
| nn.init.constant_(m.bias, 0) | |||
| def forward(self, x): | |||
| x = self.quality(x) | |||
| return x[:, 0:1] | |||
| @@ -0,0 +1,813 @@ | |||
| import functools | |||
| import itertools | |||
| import math | |||
| import operator | |||
| import random | |||
| import torch | |||
| from torch import nn | |||
| from torch.autograd import Function | |||
| from torch.nn import functional as F | |||
| from modelscope.models.cv.face_generation.op import (FusedLeakyReLU, | |||
| fused_leaky_relu, | |||
| upfirdn2d) | |||
| class PixelNorm(nn.Module): | |||
| def __init__(self): | |||
| super().__init__() | |||
| def forward(self, input): | |||
| return input * torch.rsqrt( | |||
| torch.mean(input**2, dim=1, keepdim=True) + 1e-8) | |||
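| # make_kernel expands a 1-D binomial filter (e.g. [1, 3, 3, 1]) into a | |||
| # normalised 2-D blur kernel shared by Upsample, Downsample and Blur below. | |||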
| def make_kernel(k): | |||
| k = torch.tensor(k, dtype=torch.float32) | |||
| if k.ndim == 1: | |||
| k = k[None, :] * k[:, None] | |||
| k /= k.sum() | |||
| return k | |||
| class Upsample(nn.Module): | |||
| def __init__(self, kernel, factor=2): | |||
| super().__init__() | |||
| self.factor = factor | |||
| kernel = make_kernel(kernel) * (factor**2) | |||
| self.register_buffer('kernel', kernel) | |||
| p = kernel.shape[0] - factor | |||
| pad0 = (p + 1) // 2 + factor - 1 | |||
| pad1 = p // 2 | |||
| self.pad = (pad0, pad1) | |||
| def forward(self, input): | |||
| out = upfirdn2d( | |||
| input, self.kernel, up=self.factor, down=1, pad=self.pad) | |||
| return out | |||
| class Downsample(nn.Module): | |||
| def __init__(self, kernel, factor=2): | |||
| super().__init__() | |||
| self.factor = factor | |||
| kernel = make_kernel(kernel) | |||
| self.register_buffer('kernel', kernel) | |||
| p = kernel.shape[0] - factor | |||
| pad0 = (p + 1) // 2 | |||
| pad1 = p // 2 | |||
| self.pad = (pad0, pad1) | |||
| def forward(self, input): | |||
| out = upfirdn2d( | |||
| input, self.kernel, up=1, down=self.factor, pad=self.pad) | |||
| return out | |||
| class Blur(nn.Module): | |||
| def __init__(self, kernel, pad, upsample_factor=1): | |||
| super().__init__() | |||
| kernel = make_kernel(kernel) | |||
| if upsample_factor > 1: | |||
| kernel = kernel * (upsample_factor**2) | |||
| self.register_buffer('kernel', kernel) | |||
| self.pad = pad | |||
| def forward(self, input): | |||
| out = upfirdn2d(input, self.kernel, pad=self.pad) | |||
| return out | |||
| class EqualConv2d(nn.Module): | |||
| def __init__(self, | |||
| in_channel, | |||
| out_channel, | |||
| kernel_size, | |||
| stride=1, | |||
| padding=0, | |||
| bias=True): | |||
| super().__init__() | |||
| self.weight = nn.Parameter( | |||
| torch.randn(out_channel, in_channel, kernel_size, kernel_size)) | |||
| self.scale = 1 / math.sqrt(in_channel * kernel_size**2) | |||
| self.stride = stride | |||
| self.padding = padding | |||
| if bias: | |||
| self.bias = nn.Parameter(torch.zeros(out_channel)) | |||
| else: | |||
| self.bias = None | |||
| def forward(self, input): | |||
| out = F.conv2d( | |||
| input, | |||
| self.weight * self.scale, | |||
| bias=self.bias, | |||
| stride=self.stride, | |||
| padding=self.padding, | |||
| ) | |||
| return out | |||
| def __repr__(self): | |||
| return ( | |||
| f'{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]},' | |||
| f' {self.weight.shape[2]}, stride={self.stride}, padding={self.padding})' | |||
| ) | |||
| class EqualLinear(nn.Module): | |||
| def __init__(self, | |||
| in_dim, | |||
| out_dim, | |||
| bias=True, | |||
| bias_init=0, | |||
| lr_mul=1, | |||
| activation=None): | |||
| super().__init__() | |||
| self.weight = nn.Parameter(torch.randn(out_dim, in_dim).div_(lr_mul)) | |||
| if bias: | |||
| self.bias = nn.Parameter(torch.zeros(out_dim).fill_(bias_init)) | |||
| else: | |||
| self.bias = None | |||
| self.activation = activation | |||
| self.scale = (1 / math.sqrt(in_dim)) * lr_mul | |||
| self.lr_mul = lr_mul | |||
| def forward(self, input): | |||
| if self.activation: | |||
| out = F.linear(input, self.weight * self.scale) | |||
| out = fused_leaky_relu(out, self.bias * self.lr_mul) | |||
| else: | |||
| out = F.linear( | |||
| input, self.weight * self.scale, bias=self.bias * self.lr_mul) | |||
| return out | |||
| def __repr__(self): | |||
| return ( | |||
| f'{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]})' | |||
| ) | |||
| class ScaledLeakyReLU(nn.Module): | |||
| def __init__(self, negative_slope=0.2): | |||
| super().__init__() | |||
| self.negative_slope = negative_slope | |||
| def forward(self, input): | |||
| out = F.leaky_relu(input, negative_slope=self.negative_slope) | |||
| return out * math.sqrt(2) | |||
| class ModulatedConv2d(nn.Module): | |||
| def __init__( | |||
| self, | |||
| in_channel, | |||
| out_channel, | |||
| kernel_size, | |||
| style_dim, | |||
| demodulate=True, | |||
| upsample=False, | |||
| downsample=False, | |||
| blur_kernel=[1, 3, 3, 1], | |||
| ): | |||
| super().__init__() | |||
| self.eps = 1e-8 | |||
| self.kernel_size = kernel_size | |||
| self.in_channel = in_channel | |||
| self.out_channel = out_channel | |||
| self.upsample = upsample | |||
| self.downsample = downsample | |||
| if upsample: | |||
| factor = 2 | |||
| p = (len(blur_kernel) - factor) - (kernel_size - 1) | |||
| pad0 = (p + 1) // 2 + factor - 1 | |||
| pad1 = p // 2 + 1 | |||
| self.blur = Blur( | |||
| blur_kernel, pad=(pad0, pad1), upsample_factor=factor) | |||
| if downsample: | |||
| factor = 2 | |||
| p = (len(blur_kernel) - factor) + (kernel_size - 1) | |||
| pad0 = (p + 1) // 2 | |||
| pad1 = p // 2 | |||
| self.blur = Blur(blur_kernel, pad=(pad0, pad1)) | |||
| fan_in = in_channel * kernel_size**2 | |||
| self.scale = 1 / math.sqrt(fan_in) | |||
| self.padding = kernel_size // 2 | |||
| self.weight = nn.Parameter( | |||
| torch.randn(1, out_channel, in_channel, kernel_size, kernel_size)) | |||
| self.modulation = EqualLinear(style_dim, in_channel, bias_init=1) | |||
| self.demodulate = demodulate | |||
| def __repr__(self): | |||
| return ( | |||
| f'{self.__class__.__name__}({self.in_channel}, {self.out_channel}, {self.kernel_size}, ' | |||
| f'upsample={self.upsample}, downsample={self.downsample})') | |||
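| # StyleGAN2 weight (de)modulation: the per-sample style scales the conv | |||
| # weights, demod renormalises them, and the batch is folded into the channel | |||
| # dimension so a single grouped convolution processes all samples at once. | |||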
| def forward(self, input, style): | |||
| batch, in_channel, height, width = input.shape | |||
| style = self.modulation(style).view(batch, 1, in_channel, 1, 1) | |||
| weight = self.scale * self.weight * style | |||
| if self.demodulate: | |||
| demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + 1e-8) | |||
| weight = weight * demod.view(batch, self.out_channel, 1, 1, 1) | |||
| weight = weight.view(batch * self.out_channel, in_channel, | |||
| self.kernel_size, self.kernel_size) | |||
| if self.upsample: | |||
| input = input.view(1, batch * in_channel, height, width) | |||
| weight = weight.view(batch, self.out_channel, in_channel, | |||
| self.kernel_size, self.kernel_size) | |||
| weight = weight.transpose(1, 2).reshape(batch * in_channel, | |||
| self.out_channel, | |||
| self.kernel_size, | |||
| self.kernel_size) | |||
| out = F.conv_transpose2d( | |||
| input, weight, padding=0, stride=2, groups=batch) | |||
| _, _, height, width = out.shape | |||
| out = out.view(batch, self.out_channel, height, width) | |||
| out = self.blur(out) | |||
| elif self.downsample: | |||
| input = self.blur(input) | |||
| _, _, height, width = input.shape | |||
| input = input.view(1, batch * in_channel, height, width) | |||
| out = F.conv2d(input, weight, padding=0, stride=2, groups=batch) | |||
| _, _, height, width = out.shape | |||
| out = out.view(batch, self.out_channel, height, width) | |||
| else: | |||
| input = input.view(1, batch * in_channel, height, width) | |||
| out = F.conv2d(input, weight, padding=self.padding, groups=batch) | |||
| _, _, height, width = out.shape | |||
| out = out.view(batch, self.out_channel, height, width) | |||
| return out | |||
| class NoiseInjection(nn.Module): | |||
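| # With isconcat=True the noise is concatenated as extra channels instead of | |||
| # added, which is why downstream layers scale widths by feat_multiplier. | |||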
| def __init__(self, isconcat=True): | |||
| super().__init__() | |||
| self.isconcat = isconcat | |||
| self.weight = nn.Parameter(torch.zeros(1)) | |||
| def forward(self, image, noise=None): | |||
| if noise is None: | |||
| batch, channel, height, width = image.shape | |||
| noise = image.new_empty(batch, channel, height, width).normal_() | |||
| if self.isconcat: | |||
| return torch.cat((image, self.weight * noise), dim=1) | |||
| else: | |||
| return image + self.weight * noise | |||
| class ConstantInput(nn.Module): | |||
| def __init__(self, channel, size=4): | |||
| super().__init__() | |||
| self.input = nn.Parameter(torch.randn(1, channel, size, size)) | |||
| def forward(self, input): | |||
| batch = input.shape[0] | |||
| out = self.input.repeat(batch, 1, 1, 1) | |||
| return out | |||
| class StyledConv(nn.Module): | |||
| def __init__( | |||
| self, | |||
| in_channel, | |||
| out_channel, | |||
| kernel_size, | |||
| style_dim, | |||
| upsample=False, | |||
| blur_kernel=[1, 3, 3, 1], | |||
| demodulate=True, | |||
| isconcat=True, | |||
| ): | |||
| super().__init__() | |||
| self.conv = ModulatedConv2d( | |||
| in_channel, | |||
| out_channel, | |||
| kernel_size, | |||
| style_dim, | |||
| upsample=upsample, | |||
| blur_kernel=blur_kernel, | |||
| demodulate=demodulate, | |||
| ) | |||
| self.noise = NoiseInjection(isconcat) | |||
| # self.bias = nn.Parameter(torch.zeros(1, out_channel, 1, 1)) | |||
| # self.activate = ScaledLeakyReLU(0.2) | |||
| feat_multiplier = 2 if isconcat else 1 | |||
| self.activate = FusedLeakyReLU(out_channel * feat_multiplier) | |||
| def forward(self, input, style, noise=None): | |||
| out = self.conv(input, style) | |||
| out = self.noise(out, noise=noise) | |||
| # out = out + self.bias | |||
| out = self.activate(out) | |||
| return out | |||
| class ToRGB(nn.Module): | |||
| def __init__(self, | |||
| in_channel, | |||
| style_dim, | |||
| upsample=True, | |||
| blur_kernel=[1, 3, 3, 1]): | |||
| super().__init__() | |||
| if upsample: | |||
| self.upsample = Upsample(blur_kernel) | |||
| self.conv = ModulatedConv2d( | |||
| in_channel, 3, 1, style_dim, demodulate=False) | |||
| self.bias = nn.Parameter(torch.zeros(1, 3, 1, 1)) | |||
| def forward(self, input, style, skip=None): | |||
| out = self.conv(input, style) | |||
| out = out + self.bias | |||
| if skip is not None: | |||
| skip = self.upsample(skip) | |||
| out = out + skip | |||
| return out | |||
| class Generator(nn.Module): | |||
| def __init__( | |||
| self, | |||
| size, | |||
| style_dim, | |||
| n_mlp, | |||
| channel_multiplier=2, | |||
| blur_kernel=[1, 3, 3, 1], | |||
| lr_mlp=0.01, | |||
| isconcat=True, | |||
| narrow=1, | |||
| ): | |||
| super().__init__() | |||
| self.size = size | |||
| self.n_mlp = n_mlp | |||
| self.style_dim = style_dim | |||
| self.feat_multiplier = 2 if isconcat else 1 | |||
| layers = [PixelNorm()] | |||
| for i in range(n_mlp): | |||
| layers.append( | |||
| EqualLinear( | |||
| style_dim, | |||
| style_dim, | |||
| lr_mul=lr_mlp, | |||
| activation='fused_lrelu')) | |||
| self.style = nn.Sequential(*layers) | |||
| self.channels = { | |||
| 4: int(512 * narrow), | |||
| 8: int(512 * narrow), | |||
| 16: int(512 * narrow), | |||
| 32: int(512 * narrow), | |||
| 64: int(256 * channel_multiplier * narrow), | |||
| 128: int(128 * channel_multiplier * narrow), | |||
| 256: int(64 * channel_multiplier * narrow), | |||
| 512: int(32 * channel_multiplier * narrow), | |||
| 1024: int(16 * channel_multiplier * narrow), | |||
| 2048: int(8 * channel_multiplier * narrow) | |||
| } | |||
| self.input = ConstantInput(self.channels[4]) | |||
| self.conv1 = StyledConv( | |||
| self.channels[4], | |||
| self.channels[4], | |||
| 3, | |||
| style_dim, | |||
| blur_kernel=blur_kernel, | |||
| isconcat=isconcat) | |||
| self.to_rgb1 = ToRGB( | |||
| self.channels[4] * self.feat_multiplier, style_dim, upsample=False) | |||
| self.log_size = int(math.log(size, 2)) | |||
| self.convs = nn.ModuleList() | |||
| self.upsamples = nn.ModuleList() | |||
| self.to_rgbs = nn.ModuleList() | |||
| in_channel = self.channels[4] | |||
| for i in range(3, self.log_size + 1): | |||
| out_channel = self.channels[2**i] | |||
| self.convs.append( | |||
| StyledConv( | |||
| in_channel * self.feat_multiplier, | |||
| out_channel, | |||
| 3, | |||
| style_dim, | |||
| upsample=True, | |||
| blur_kernel=blur_kernel, | |||
| isconcat=isconcat, | |||
| )) | |||
| self.convs.append( | |||
| StyledConv( | |||
| out_channel * self.feat_multiplier, | |||
| out_channel, | |||
| 3, | |||
| style_dim, | |||
| blur_kernel=blur_kernel, | |||
| isconcat=isconcat)) | |||
| self.to_rgbs.append( | |||
| ToRGB(out_channel * self.feat_multiplier, style_dim)) | |||
| in_channel = out_channel | |||
| self.n_latent = self.log_size * 2 - 2 | |||
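| # n_latent style vectors are consumed per forward pass: one for conv1, one | |||
| # for to_rgb1 and two for every further resolution doubling. | |||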
| def make_noise(self): | |||
| device = self.input.input.device | |||
| noises = [torch.randn(1, 1, 2**2, 2**2, device=device)] | |||
| for i in range(3, self.log_size + 1): | |||
| for _ in range(2): | |||
| noises.append(torch.randn(1, 1, 2**i, 2**i, device=device)) | |||
| return noises | |||
| def mean_latent(self, n_latent): | |||
| latent_in = torch.randn( | |||
| n_latent, self.style_dim, device=self.input.input.device) | |||
| latent = self.style(latent_in).mean(0, keepdim=True) | |||
| return latent | |||
| def get_latent(self, input): | |||
| return self.style(input) | |||
| def forward( | |||
| self, | |||
| styles, | |||
| return_latents=False, | |||
| inject_index=None, | |||
| truncation=1, | |||
| truncation_latent=None, | |||
| input_is_latent=False, | |||
| noise=None, | |||
| ): | |||
| if not input_is_latent: | |||
| styles = [self.style(s) for s in styles] | |||
| if noise is None: | |||
| ''' | |||
| noise = [None] * (2 * (self.log_size - 2) + 1) | |||
| ''' | |||
| noise = [] | |||
| batch = styles[0].shape[0] | |||
| for i in range(self.n_mlp + 1): | |||
| size = 2**(i + 2) | |||
| noise.append( | |||
| torch.randn( | |||
| batch, | |||
| self.channels[size], | |||
| size, | |||
| size, | |||
| device=styles[0].device)) | |||
| if truncation < 1: | |||
| style_t = [] | |||
| for style in styles: | |||
| style_t.append(truncation_latent | |||
| + truncation * (style - truncation_latent)) | |||
| styles = style_t | |||
| if len(styles) < 2: | |||
| inject_index = self.n_latent | |||
| latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1) | |||
| else: | |||
| if inject_index is None: | |||
| inject_index = random.randint(1, self.n_latent - 1) | |||
| latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1) | |||
| latent2 = styles[1].unsqueeze(1).repeat( | |||
| 1, self.n_latent - inject_index, 1) | |||
| latent = torch.cat([latent, latent2], 1) | |||
| out = self.input(latent) | |||
| out = self.conv1(out, latent[:, 0], noise=noise[0]) | |||
| skip = self.to_rgb1(out, latent[:, 1]) | |||
| i = 1 | |||
| for conv1, conv2, noise1, noise2, to_rgb in zip( | |||
| self.convs[::2], self.convs[1::2], noise[1::2], noise[2::2], | |||
| self.to_rgbs): | |||
| out = conv1(out, latent[:, i], noise=noise1) | |||
| out = conv2(out, latent[:, i + 1], noise=noise2) | |||
| skip = to_rgb(out, latent[:, i + 2], skip) | |||
| i += 2 | |||
| image = skip | |||
| if return_latents: | |||
| return image, latent | |||
| else: | |||
| return image, None | |||
| class ConvLayer(nn.Sequential): | |||
| def __init__( | |||
| self, | |||
| in_channel, | |||
| out_channel, | |||
| kernel_size, | |||
| downsample=False, | |||
| blur_kernel=[1, 3, 3, 1], | |||
| bias=True, | |||
| activate=True, | |||
| ): | |||
| layers = [] | |||
| if downsample: | |||
| factor = 2 | |||
| p = (len(blur_kernel) - factor) + (kernel_size - 1) | |||
| pad0 = (p + 1) // 2 | |||
| pad1 = p // 2 | |||
| layers.append(Blur(blur_kernel, pad=(pad0, pad1))) | |||
| stride = 2 | |||
| self.padding = 0 | |||
| else: | |||
| stride = 1 | |||
| self.padding = kernel_size // 2 | |||
| layers.append( | |||
| EqualConv2d( | |||
| in_channel, | |||
| out_channel, | |||
| kernel_size, | |||
| padding=self.padding, | |||
| stride=stride, | |||
| bias=bias and not activate, | |||
| )) | |||
| if activate: | |||
| if bias: | |||
| layers.append(FusedLeakyReLU(out_channel)) | |||
| else: | |||
| layers.append(ScaledLeakyReLU(0.2)) | |||
| super().__init__(*layers) | |||
| class ResBlock(nn.Module): | |||
| def __init__(self, in_channel, out_channel, blur_kernel=[1, 3, 3, 1]): | |||
| super().__init__() | |||
| self.conv1 = ConvLayer(in_channel, in_channel, 3) | |||
| self.conv2 = ConvLayer(in_channel, out_channel, 3, downsample=True) | |||
| self.skip = ConvLayer( | |||
| in_channel, | |||
| out_channel, | |||
| 1, | |||
| downsample=True, | |||
| activate=False, | |||
| bias=False) | |||
| def forward(self, input): | |||
| out = self.conv1(input) | |||
| out = self.conv2(out) | |||
| skip = self.skip(input) | |||
| out = (out + skip) / math.sqrt(2) | |||
| return out | |||
| class FullGenerator(nn.Module): | |||
| def __init__( | |||
| self, | |||
| size, | |||
| style_dim, | |||
| n_mlp, | |||
| channel_multiplier=2, | |||
| blur_kernel=[1, 3, 3, 1], | |||
| lr_mlp=0.01, | |||
| isconcat=True, | |||
| narrow=1, | |||
| ): | |||
| super().__init__() | |||
| channels = { | |||
| 4: int(512 * narrow), | |||
| 8: int(512 * narrow), | |||
| 16: int(512 * narrow), | |||
| 32: int(512 * narrow), | |||
| 64: int(256 * channel_multiplier * narrow), | |||
| 128: int(128 * channel_multiplier * narrow), | |||
| 256: int(64 * channel_multiplier * narrow), | |||
| 512: int(32 * channel_multiplier * narrow), | |||
| 1024: int(16 * channel_multiplier * narrow), | |||
| 2048: int(8 * channel_multiplier * narrow) | |||
| } | |||
| self.log_size = int(math.log(size, 2)) | |||
| self.generator = Generator( | |||
| size, | |||
| style_dim, | |||
| n_mlp, | |||
| channel_multiplier=channel_multiplier, | |||
| blur_kernel=blur_kernel, | |||
| lr_mlp=lr_mlp, | |||
| isconcat=isconcat, | |||
| narrow=narrow) | |||
| conv = [ConvLayer(3, channels[size], 1)] | |||
| self.ecd0 = nn.Sequential(*conv) | |||
| in_channel = channels[size] | |||
| self.names = ['ecd%d' % i for i in range(self.log_size - 1)] | |||
| for i in range(self.log_size, 2, -1): | |||
| out_channel = channels[2**(i - 1)] | |||
| # conv = [ResBlock(in_channel, out_channel, blur_kernel)] | |||
| conv = [ConvLayer(in_channel, out_channel, 3, downsample=True)] | |||
| setattr(self, self.names[self.log_size - i + 1], | |||
| nn.Sequential(*conv)) | |||
| in_channel = out_channel | |||
| self.final_linear = nn.Sequential( | |||
| EqualLinear( | |||
| channels[4] * 4 * 4, style_dim, activation='fused_lrelu')) | |||
| def forward( | |||
| self, | |||
| inputs, | |||
| return_latents=False, | |||
| inject_index=None, | |||
| truncation=1, | |||
| truncation_latent=None, | |||
| input_is_latent=False, | |||
| ): | |||
| noise = [] | |||
| for i in range(self.log_size - 1): | |||
| ecd = getattr(self, self.names[i]) | |||
| inputs = ecd(inputs) | |||
| noise.append(inputs) | |||
| inputs = inputs.view(inputs.shape[0], -1) | |||
| outs = self.final_linear(inputs) | |||
| noise = list( | |||
| itertools.chain.from_iterable( | |||
| itertools.repeat(x, 2) for x in noise))[::-1] | |||
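| # The encoder feature maps are duplicated and reversed so the decoder consumes | |||
| # them (fine to coarse becomes coarse to fine) as its per-layer noise inputs, | |||
| # turning FullGenerator into a U-shaped encoder-decoder (the GPEN design). | |||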
| outs = self.generator([outs], | |||
| return_latents, | |||
| inject_index, | |||
| truncation, | |||
| truncation_latent, | |||
| input_is_latent, | |||
| noise=noise[1:]) | |||
| return outs | |||
| class Discriminator(nn.Module): | |||
| def __init__(self, | |||
| size, | |||
| channel_multiplier=2, | |||
| blur_kernel=[1, 3, 3, 1], | |||
| narrow=1): | |||
| super().__init__() | |||
| channels = { | |||
| 4: int(512 * narrow), | |||
| 8: int(512 * narrow), | |||
| 16: int(512 * narrow), | |||
| 32: int(512 * narrow), | |||
| 64: int(256 * channel_multiplier * narrow), | |||
| 128: int(128 * channel_multiplier * narrow), | |||
| 256: int(64 * channel_multiplier * narrow), | |||
| 512: int(32 * channel_multiplier * narrow), | |||
| 1024: int(16 * channel_multiplier * narrow), | |||
| 2048: int(8 * channel_multiplier * narrow) | |||
| } | |||
| convs = [ConvLayer(3, channels[size], 1)] | |||
| log_size = int(math.log(size, 2)) | |||
| in_channel = channels[size] | |||
| for i in range(log_size, 2, -1): | |||
| out_channel = channels[2**(i - 1)] | |||
| convs.append(ResBlock(in_channel, out_channel, blur_kernel)) | |||
| in_channel = out_channel | |||
| self.convs = nn.Sequential(*convs) | |||
| self.stddev_group = 4 | |||
| self.stddev_feat = 1 | |||
| self.final_conv = ConvLayer(in_channel + 1, channels[4], 3) | |||
| self.final_linear = nn.Sequential( | |||
| EqualLinear( | |||
| channels[4] * 4 * 4, channels[4], activation='fused_lrelu'), | |||
| EqualLinear(channels[4], 1), | |||
| ) | |||
| def forward(self, input): | |||
| out = self.convs(input) | |||
| batch, channel, height, width = out.shape | |||
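| # Minibatch standard deviation: append the per-group feature std as an extra | |||
| # channel so the discriminator can penalise low sample diversity. | |||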
| group = min(batch, self.stddev_group) | |||
| stddev = out.view(group, -1, self.stddev_feat, | |||
| channel // self.stddev_feat, height, width) | |||
| stddev = torch.sqrt(stddev.var(0, unbiased=False) + 1e-8) | |||
| stddev = stddev.mean([2, 3, 4], keepdims=True).squeeze(2) | |||
| stddev = stddev.repeat(group, 1, height, width) | |||
| out = torch.cat([out, stddev], 1) | |||
| out = self.final_conv(out) | |||
| out = out.view(batch, -1) | |||
| out = self.final_linear(out) | |||
| return out | |||
| @@ -0,0 +1,205 @@ | |||
| import math | |||
| import os.path as osp | |||
| from copy import deepcopy | |||
| from typing import Any, Dict, List, Union | |||
| import torch | |||
| import torch.nn.functional as F | |||
| from torch import autograd, nn | |||
| from torch.nn.parallel import DataParallel, DistributedDataParallel | |||
| from modelscope.metainfo import Models | |||
| from modelscope.models.base import Tensor, TorchModel | |||
| from modelscope.models.builder import MODELS | |||
| from modelscope.utils.config import Config | |||
| from modelscope.utils.constant import ModelFile, Tasks | |||
| from modelscope.utils.logger import get_logger | |||
| from .gpen import Discriminator, FullGenerator | |||
| from .losses.losses import IDLoss, L1Loss | |||
| logger = get_logger() | |||
| __all__ = ['ImagePortraitEnhancement'] | |||
| @MODELS.register_module( | |||
| Tasks.image_portrait_enhancement, module_name=Models.gpen) | |||
| class ImagePortraitEnhancement(TorchModel): | |||
| def __init__(self, model_dir: str, *args, **kwargs): | |||
| """initialize the face enhancement model from the `model_dir` path. | |||
| Args: | |||
| model_dir (str): the model path. | |||
| """ | |||
| super().__init__(model_dir, *args, **kwargs) | |||
| self.size = 512 | |||
| self.style_dim = 512 | |||
| self.n_mlp = 8 | |||
| self.mean_path_length = 0 | |||
| self.accum = 0.5**(32 / (10 * 1000)) | |||
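| # EMA decay for g_ema, the exponential moving average copy of the generator. | |||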
| if torch.cuda.is_available(): | |||
| self._device = torch.device('cuda') | |||
| else: | |||
| self._device = torch.device('cpu') | |||
| self.l1_loss = L1Loss() | |||
| self.id_loss = IDLoss(f'{model_dir}/arcface/model_ir_se50.pth', | |||
| self._device) | |||
| self.generator = FullGenerator( | |||
| self.size, self.style_dim, self.n_mlp, | |||
| isconcat=True).to(self._device) | |||
| self.g_ema = FullGenerator( | |||
| self.size, self.style_dim, self.n_mlp, | |||
| isconcat=True).to(self._device) | |||
| self.discriminator = Discriminator(self.size).to(self._device) | |||
| if self.size == 512: | |||
| self.load_pretrained(model_dir) | |||
| def load_pretrained(self, model_dir): | |||
| g_path = f'{model_dir}/{ModelFile.TORCH_MODEL_FILE}' | |||
| g_dict = torch.load(g_path, map_location=torch.device('cpu')) | |||
| self.generator.load_state_dict(g_dict) | |||
| self.g_ema.load_state_dict(g_dict) | |||
| d_path = f'{model_dir}/net_d.pt' | |||
| d_dict = torch.load(d_path, map_location=torch.device('cpu')) | |||
| self.discriminator.load_state_dict(d_dict) | |||
| logger.info('load model done.') | |||
| def accumulate(self): | |||
| par1 = dict(self.g_ema.named_parameters()) | |||
| par2 = dict(self.generator.named_parameters()) | |||
| for k in par1.keys(): | |||
| par1[k].data.mul_(self.accum).add_(par2[k].data, alpha=1 - self.accum) | |||
| def requires_grad(self, model, flag=True): | |||
| for p in model.parameters(): | |||
| p.requires_grad = flag | |||
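| # Non-saturating logistic GAN losses as in StyleGAN2: softplus(-real) + | |||
| # softplus(fake) for the discriminator, softplus(-fake) for the generator. | |||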
| def d_logistic_loss(self, real_pred, fake_pred): | |||
| real_loss = F.softplus(-real_pred) | |||
| fake_loss = F.softplus(fake_pred) | |||
| return real_loss.mean() + fake_loss.mean() | |||
| def d_r1_loss(self, real_pred, real_img): | |||
| grad_real, = autograd.grad( | |||
| outputs=real_pred.sum(), inputs=real_img, create_graph=True) | |||
| grad_penalty = grad_real.pow(2).view(grad_real.shape[0], | |||
| -1).sum(1).mean() | |||
| return grad_penalty | |||
| def g_nonsaturating_loss(self, | |||
| fake_pred, | |||
| fake_img=None, | |||
| real_img=None, | |||
| input_img=None): | |||
| loss = F.softplus(-fake_pred).mean() | |||
| loss_l1 = self.l1_loss(fake_img, real_img) | |||
| loss_id, __, __ = self.id_loss(fake_img, real_img, input_img) | |||
| loss_id = 0 | |||
| loss += 1.0 * loss_l1 + 1.0 * loss_id | |||
| return loss | |||
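| # Path-length regularisation from StyleGAN2: keep the norm of J^T y (the | |||
| # generator Jacobian contracted with random noise) close to its running mean. | |||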
| def g_path_regularize(self, | |||
| fake_img, | |||
| latents, | |||
| mean_path_length, | |||
| decay=0.01): | |||
| noise = torch.randn_like(fake_img) / math.sqrt( | |||
| fake_img.shape[2] * fake_img.shape[3]) | |||
| grad, = autograd.grad( | |||
| outputs=(fake_img * noise).sum(), | |||
| inputs=latents, | |||
| create_graph=True) | |||
| path_lengths = torch.sqrt(grad.pow(2).sum(2).mean(1)) | |||
| path_mean = mean_path_length + decay * ( | |||
| path_lengths.mean() - mean_path_length) | |||
| path_penalty = (path_lengths - path_mean).pow(2).mean() | |||
| return path_penalty, path_mean.detach(), path_lengths | |||
| @torch.no_grad() | |||
| def _evaluate_postprocess(self, src: Tensor, | |||
| target: Tensor) -> Dict[str, list]: | |||
| preds, _ = self.generator(src) | |||
| preds = list(torch.split(preds, 1, 0)) | |||
| targets = list(torch.split(target, 1, 0)) | |||
| preds = [((pred.data * 0.5 + 0.5) * 255.).squeeze(0).type( | |||
| torch.uint8).permute(1, 2, 0).cpu().numpy() for pred in preds] | |||
| targets = [((target.data * 0.5 + 0.5) * 255.).squeeze(0).type( | |||
| torch.uint8).permute(1, 2, 0).cpu().numpy() for target in targets] | |||
| return {'pred': preds, 'target': targets} | |||
| def _train_forward_d(self, src: Tensor, target: Tensor) -> Tensor: | |||
| self.requires_grad(self.generator, False) | |||
| self.requires_grad(self.discriminator, True) | |||
| preds, _ = self.generator(src) | |||
| fake_pred = self.discriminator(preds) | |||
| real_pred = self.discriminator(target) | |||
| d_loss = self.d_logistic_loss(real_pred, fake_pred) | |||
| return d_loss | |||
| def _train_forward_d_r1(self, src: Tensor, target: Tensor) -> Tensor: | |||
| src.requires_grad = True | |||
| target.requires_grad = True | |||
| real_pred = self.discriminator(target) | |||
| r1_loss = self.d_r1_loss(real_pred, target) | |||
| return r1_loss | |||
| def _train_forward_g(self, src: Tensor, target: Tensor) -> Tensor: | |||
| self.requires_grad(self.generator, True) | |||
| self.requires_grad(self.discriminator, False) | |||
| preds, _ = self.generator(src) | |||
| fake_pred = self.discriminator(preds) | |||
| g_loss = self.g_nonsaturating_loss(fake_pred, preds, target, src) | |||
| return g_loss | |||
| def _train_forward_g_path(self, src: Tensor, target: Tensor) -> Tensor: | |||
| fake_img, latents = self.generator(src, return_latents=True) | |||
| path_loss, self.mean_path_length, path_lengths = self.g_path_regularize( | |||
| fake_img, latents, self.mean_path_length) | |||
| return path_loss | |||
| @torch.no_grad() | |||
| def _inference_forward(self, src: Tensor) -> Dict[str, Tensor]: | |||
| return {'outputs': (self.generator(src)[0] * 0.5 + 0.5).clamp(0, 1)} | |||
| def forward(self, input: Dict[str, | |||
| Tensor]) -> Dict[str, Union[list, Tensor]]: | |||
| """return the result by the model | |||
| Args: | |||
| input (Dict[str, Tensor]): the preprocessed data | |||
| Returns: | |||
| Dict[str, Union[list, Tensor]]: results | |||
| """ | |||
| for key, value in input.items(): | |||
| input[key] = input[key].to(self._device) | |||
| if 'target' in input: | |||
| return self._evaluate_postprocess(**input) | |||
| else: | |||
| return self._inference_forward(**input) | |||
| @@ -0,0 +1,129 @@ | |||
| from collections import namedtuple | |||
| import torch | |||
| from torch.nn import (AdaptiveAvgPool2d, BatchNorm2d, Conv2d, MaxPool2d, | |||
| Module, PReLU, ReLU, Sequential, Sigmoid) | |||
| class Flatten(Module): | |||
| def forward(self, input): | |||
| return input.view(input.size(0), -1) | |||
| def l2_norm(input, axis=1): | |||
| norm = torch.norm(input, 2, axis, True) | |||
| output = torch.div(input, norm) | |||
| return output | |||
| class Bottleneck(namedtuple('Block', ['in_channel', 'depth', 'stride'])): | |||
| """ A named tuple describing a ResNet block. """ | |||
| def get_block(in_channel, depth, num_units, stride=2): | |||
| return [Bottleneck(in_channel, depth, stride) | |||
| ] + [Bottleneck(depth, depth, 1) for i in range(num_units - 1)] | |||
| def get_blocks(num_layers): | |||
| if num_layers == 50: | |||
| blocks = [ | |||
| get_block(in_channel=64, depth=64, num_units=3), | |||
| get_block(in_channel=64, depth=128, num_units=4), | |||
| get_block(in_channel=128, depth=256, num_units=14), | |||
| get_block(in_channel=256, depth=512, num_units=3) | |||
| ] | |||
| elif num_layers == 100: | |||
| blocks = [ | |||
| get_block(in_channel=64, depth=64, num_units=3), | |||
| get_block(in_channel=64, depth=128, num_units=13), | |||
| get_block(in_channel=128, depth=256, num_units=30), | |||
| get_block(in_channel=256, depth=512, num_units=3) | |||
| ] | |||
| elif num_layers == 152: | |||
| blocks = [ | |||
| get_block(in_channel=64, depth=64, num_units=3), | |||
| get_block(in_channel=64, depth=128, num_units=8), | |||
| get_block(in_channel=128, depth=256, num_units=36), | |||
| get_block(in_channel=256, depth=512, num_units=3) | |||
| ] | |||
| else: | |||
| raise ValueError( | |||
| 'Invalid number of layers: {}. Must be one of [50, 100, 152]'. | |||
| format(num_layers)) | |||
| return blocks | |||
| class SEModule(Module): | |||
| def __init__(self, channels, reduction): | |||
| super(SEModule, self).__init__() | |||
| self.avg_pool = AdaptiveAvgPool2d(1) | |||
| self.fc1 = Conv2d( | |||
| channels, | |||
| channels // reduction, | |||
| kernel_size=1, | |||
| padding=0, | |||
| bias=False) | |||
| self.relu = ReLU(inplace=True) | |||
| self.fc2 = Conv2d( | |||
| channels // reduction, | |||
| channels, | |||
| kernel_size=1, | |||
| padding=0, | |||
| bias=False) | |||
| self.sigmoid = Sigmoid() | |||
| def forward(self, x): | |||
| module_input = x | |||
| x = self.avg_pool(x) | |||
| x = self.fc1(x) | |||
| x = self.relu(x) | |||
| x = self.fc2(x) | |||
| x = self.sigmoid(x) | |||
| return module_input * x | |||
| class bottleneck_IR(Module): | |||
| def __init__(self, in_channel, depth, stride): | |||
| super(bottleneck_IR, self).__init__() | |||
| if in_channel == depth: | |||
| self.shortcut_layer = MaxPool2d(1, stride) | |||
| else: | |||
| self.shortcut_layer = Sequential( | |||
| Conv2d(in_channel, depth, (1, 1), stride, bias=False), | |||
| BatchNorm2d(depth)) | |||
| self.res_layer = Sequential( | |||
| BatchNorm2d(in_channel), | |||
| Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False), | |||
| PReLU(depth), Conv2d(depth, depth, (3, 3), stride, 1, bias=False), | |||
| BatchNorm2d(depth)) | |||
| def forward(self, x): | |||
| shortcut = self.shortcut_layer(x) | |||
| res = self.res_layer(x) | |||
| return res + shortcut | |||
| class bottleneck_IR_SE(Module): | |||
| def __init__(self, in_channel, depth, stride): | |||
| super(bottleneck_IR_SE, self).__init__() | |||
| if in_channel == depth: | |||
| self.shortcut_layer = MaxPool2d(1, stride) | |||
| else: | |||
| self.shortcut_layer = Sequential( | |||
| Conv2d(in_channel, depth, (1, 1), stride, bias=False), | |||
| BatchNorm2d(depth)) | |||
| self.res_layer = Sequential( | |||
| BatchNorm2d(in_channel), | |||
| Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False), | |||
| PReLU(depth), Conv2d(depth, depth, (3, 3), stride, 1, bias=False), | |||
| BatchNorm2d(depth), SEModule(depth, 16)) | |||
| def forward(self, x): | |||
| shortcut = self.shortcut_layer(x) | |||
| res = self.res_layer(x) | |||
| return res + shortcut | |||
| @@ -0,0 +1,90 @@ | |||
| import torch | |||
| import torch.nn as nn | |||
| import torch.nn.functional as F | |||
| from .model_irse import Backbone | |||
| class L1Loss(nn.Module): | |||
| """L1 (mean absolute error, MAE) loss. | |||
| Args: | |||
| loss_weight (float): Loss weight for L1 loss. Default: 1.0. | |||
| reduction (str): Specifies the reduction to apply to the output. | |||
| Supported choices are 'none' | 'mean' | 'sum'. Default: 'mean'. | |||
| """ | |||
| def __init__(self, loss_weight=1.0, reduction='mean'): | |||
| super(L1Loss, self).__init__() | |||
| if reduction not in ['none', 'mean', 'sum']: | |||
| raise ValueError( | |||
| f'Unsupported reduction mode: {reduction}. Supported ones are: none, mean, sum' | |||
| ) | |||
| self.loss_weight = loss_weight | |||
| self.reduction = reduction | |||
| def forward(self, pred, target, weight=None, **kwargs): | |||
| """ | |||
| Args: | |||
| pred (Tensor): of shape (N, C, H, W). Predicted tensor. | |||
| target (Tensor): of shape (N, C, H, W). Ground truth tensor. | |||
| weight (Tensor, optional): of shape (N, C, H, W). Element-wise weights. Default: None. | |||
| """ | |||
| return self.loss_weight * F.l1_loss( | |||
| pred, target, reduction=self.reduction) | |||
| class IDLoss(nn.Module): | |||
| def __init__(self, model_path, device='cuda', ckpt_dict=None): | |||
| super(IDLoss, self).__init__() | |||
| print('Loading ResNet ArcFace') | |||
| self.facenet = Backbone( | |||
| input_size=112, num_layers=50, drop_ratio=0.6, | |||
| mode='ir_se').to(device) | |||
| if ckpt_dict is None: | |||
| self.facenet.load_state_dict( | |||
| torch.load(model_path, map_location=torch.device('cpu'))) | |||
| else: | |||
| self.facenet.load_state_dict(ckpt_dict) | |||
| self.pool = torch.nn.AdaptiveAvgPool2d((256, 256)) | |||
| self.face_pool = torch.nn.AdaptiveAvgPool2d((112, 112)) | |||
| self.facenet.eval() | |||
| def extract_feats(self, x): | |||
| _, _, h, w = x.shape | |||
| assert h == w | |||
| if h != 256: | |||
| x = self.pool(x) | |||
| x = x[:, :, 35:-33, 32:-36] # crop roi | |||
| x = self.face_pool(x) | |||
| x_feats = self.facenet(x) | |||
| return x_feats | |||
| @torch.no_grad() | |||
| def forward(self, y_hat, y, x): | |||
| n_samples = x.shape[0] | |||
| x_feats = self.extract_feats(x) | |||
| y_feats = self.extract_feats(y) # Otherwise use the feature from there | |||
| y_hat_feats = self.extract_feats(y_hat) | |||
| y_feats = y_feats.detach() | |||
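| # ArcFace embeddings are L2-normalised, so each dot product below is a cosine | |||
| # similarity; the identity loss is 1 - cos(restored, ground truth). | |||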
| loss = 0 | |||
| sim_improvement = 0 | |||
| id_logs = [] | |||
| count = 0 | |||
| for i in range(n_samples): | |||
| diff_target = y_hat_feats[i].dot(y_feats[i]) | |||
| diff_input = y_hat_feats[i].dot(x_feats[i]) | |||
| diff_views = y_feats[i].dot(x_feats[i]) | |||
| id_logs.append({ | |||
| 'diff_target': float(diff_target), | |||
| 'diff_input': float(diff_input), | |||
| 'diff_views': float(diff_views) | |||
| }) | |||
| loss += 1 - diff_target | |||
| id_diff = float(diff_target) - float(diff_views) | |||
| sim_improvement += id_diff | |||
| count += 1 | |||
| return loss / count, sim_improvement / count, id_logs | |||
| @@ -0,0 +1,92 @@ | |||
| from torch.nn import (BatchNorm1d, BatchNorm2d, Conv2d, Dropout, Linear, | |||
| Module, PReLU, Sequential) | |||
| from .helpers import (Flatten, bottleneck_IR, bottleneck_IR_SE, get_blocks, | |||
| l2_norm) | |||
| class Backbone(Module): | |||
| def __init__(self, | |||
| input_size, | |||
| num_layers, | |||
| mode='ir', | |||
| drop_ratio=0.4, | |||
| affine=True): | |||
| super(Backbone, self).__init__() | |||
| assert input_size in [112, 224], 'input_size should be 112 or 224' | |||
| assert num_layers in [50, 100, | |||
| 152], 'num_layers should be 50, 100 or 152' | |||
| assert mode in ['ir', 'ir_se'], 'mode should be ir or ir_se' | |||
| blocks = get_blocks(num_layers) | |||
| if mode == 'ir': | |||
| unit_module = bottleneck_IR | |||
| elif mode == 'ir_se': | |||
| unit_module = bottleneck_IR_SE | |||
| self.input_layer = Sequential( | |||
| Conv2d(3, 64, (3, 3), 1, 1, bias=False), BatchNorm2d(64), | |||
| PReLU(64)) | |||
| if input_size == 112: | |||
| self.output_layer = Sequential( | |||
| BatchNorm2d(512), Dropout(drop_ratio), Flatten(), | |||
| Linear(512 * 7 * 7, 512), BatchNorm1d(512, affine=affine)) | |||
| else: | |||
| self.output_layer = Sequential( | |||
| BatchNorm2d(512), Dropout(drop_ratio), Flatten(), | |||
| Linear(512 * 14 * 14, 512), BatchNorm1d(512, affine=affine)) | |||
| modules = [] | |||
| for block in blocks: | |||
| for bottleneck in block: | |||
| modules.append( | |||
| unit_module(bottleneck.in_channel, bottleneck.depth, | |||
| bottleneck.stride)) | |||
| self.body = Sequential(*modules) | |||
| def forward(self, x): | |||
| x = self.input_layer(x) | |||
| x = self.body(x) | |||
| x = self.output_layer(x) | |||
| return l2_norm(x) | |||
| def IR_50(input_size): | |||
| """Constructs a ir-50 model.""" | |||
| model = Backbone( | |||
| input_size, num_layers=50, mode='ir', drop_ratio=0.4, affine=False) | |||
| return model | |||
| def IR_101(input_size): | |||
| """Constructs a ir-101 model.""" | |||
| model = Backbone( | |||
| input_size, num_layers=100, mode='ir', drop_ratio=0.4, affine=False) | |||
| return model | |||
| def IR_152(input_size): | |||
| """Constructs a ir-152 model.""" | |||
| model = Backbone( | |||
| input_size, num_layers=152, mode='ir', drop_ratio=0.4, affine=False) | |||
| return model | |||
| def IR_SE_50(input_size): | |||
| """Constructs a ir_se-50 model.""" | |||
| model = Backbone( | |||
| input_size, num_layers=50, mode='ir_se', drop_ratio=0.4, affine=False) | |||
| return model | |||
| def IR_SE_101(input_size): | |||
| """Constructs a ir_se-101 model.""" | |||
| model = Backbone( | |||
| input_size, num_layers=100, mode='ir_se', drop_ratio=0.4, affine=False) | |||
| return model | |||
| def IR_SE_152(input_size): | |||
| """Constructs a ir_se-152 model.""" | |||
| model = Backbone( | |||
| input_size, num_layers=152, mode='ir_se', drop_ratio=0.4, affine=False) | |||
| return model | |||
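Reviewer note: a quick shape check for the ArcFace backbone above, assuming l2_norm from .helpers returns unit-norm rows (which is how IDLoss consumes the embeddings). Illustrative only.

import torch

# ir_se-50 should map a 112x112 aligned face crop to one L2-normalised 512-d embedding.
net = IR_SE_50(input_size=112)
net.eval()
with torch.no_grad():
    emb = net(torch.rand(1, 3, 112, 112))
print(emb.shape, float(emb.norm()))  # expected: torch.Size([1, 512]) and a norm close to 1.0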
| @@ -0,0 +1,217 @@ | |||
| import os | |||
| import cv2 | |||
| import numpy as np | |||
| import torch | |||
| import torch.backends.cudnn as cudnn | |||
| import torch.nn.functional as F | |||
| from .models.retinaface import RetinaFace | |||
| from .utils import PriorBox, decode, decode_landm, py_cpu_nms | |||
| cfg_re50 = { | |||
| 'name': 'Resnet50', | |||
| 'min_sizes': [[16, 32], [64, 128], [256, 512]], | |||
| 'steps': [8, 16, 32], | |||
| 'variance': [0.1, 0.2], | |||
| 'clip': False, | |||
| 'pretrain': False, | |||
| 'return_layers': { | |||
| 'layer2': 1, | |||
| 'layer3': 2, | |||
| 'layer4': 3 | |||
| }, | |||
| 'in_channel': 256, | |||
| 'out_channel': 256 | |||
| } | |||
| class RetinaFaceDetection(object): | |||
| def __init__(self, model_path, device='cuda'): | |||
| torch.set_grad_enabled(False) | |||
| cudnn.benchmark = True | |||
| self.model_path = model_path | |||
| self.device = device | |||
| self.cfg = cfg_re50 | |||
| self.net = RetinaFace(cfg=self.cfg) | |||
| self.load_model() | |||
| self.net = self.net.to(device) | |||
| self.mean = torch.tensor([[[[104]], [[117]], [[123]]]]).to(device) | |||
| def check_keys(self, pretrained_state_dict): | |||
| ckpt_keys = set(pretrained_state_dict.keys()) | |||
| model_keys = set(self.net.state_dict().keys()) | |||
| used_pretrained_keys = model_keys & ckpt_keys | |||
| assert len( | |||
| used_pretrained_keys) > 0, 'no keys matched between the model and the pretrained checkpoint' | |||
| return True | |||
| def remove_prefix(self, state_dict, prefix): | |||
| new_state_dict = dict() | |||
| # strip the 'module.' prefix added by DataParallel/DistributedDataParallel wrappers | |||
| for k, v in state_dict.items(): | |||
| if k.startswith(prefix): | |||
| new_state_dict[k[len(prefix):]] = v | |||
| else: | |||
| new_state_dict[k] = v | |||
| return new_state_dict | |||
| def load_model(self, load_to_cpu=False): | |||
| pretrained_dict = torch.load( | |||
| self.model_path, map_location=torch.device('cpu')) | |||
| if 'state_dict' in pretrained_dict.keys(): | |||
| pretrained_dict = self.remove_prefix(pretrained_dict['state_dict'], | |||
| 'module.') | |||
| else: | |||
| pretrained_dict = self.remove_prefix(pretrained_dict, 'module.') | |||
| self.check_keys(pretrained_dict) | |||
| self.net.load_state_dict(pretrained_dict, strict=False) | |||
| self.net.eval() | |||
| def detect(self, | |||
| img_raw, | |||
| resize=1, | |||
| confidence_threshold=0.9, | |||
| nms_threshold=0.4, | |||
| top_k=5000, | |||
| keep_top_k=750, | |||
| save_image=False): | |||
| img = np.float32(img_raw) | |||
| im_height, im_width = img.shape[:2] | |||
| ss = 1.0 | |||
| # downscale very large images (longer side > 1500 px) to ~1000 px to keep detection fast | |||
| if max(im_height, im_width) > 1500: | |||
| ss = 1000.0 / max(im_height, im_width) | |||
| img = cv2.resize(img, (0, 0), fx=ss, fy=ss) | |||
| im_height, im_width = img.shape[:2] | |||
| scale = torch.Tensor( | |||
| [img.shape[1], img.shape[0], img.shape[1], img.shape[0]]) | |||
| img -= (104, 117, 123) | |||
| img = img.transpose(2, 0, 1) | |||
| img = torch.from_numpy(img).unsqueeze(0) | |||
| img = img.to(self.device) | |||
| scale = scale.to(self.device) | |||
| loc, conf, landms = self.net(img) # forward pass | |||
| del img | |||
| priorbox = PriorBox(self.cfg, image_size=(im_height, im_width)) | |||
| priors = priorbox.forward() | |||
| priors = priors.to(self.device) | |||
| prior_data = priors.data | |||
| boxes = decode(loc.data.squeeze(0), prior_data, self.cfg['variance']) | |||
| boxes = boxes * scale / resize | |||
| boxes = boxes.cpu().numpy() | |||
| scores = conf.squeeze(0).data.cpu().numpy()[:, 1] | |||
| landms = decode_landm( | |||
| landms.data.squeeze(0), prior_data, self.cfg['variance']) | |||
| scale1 = torch.Tensor([ | |||
| im_width, im_height, im_width, im_height, im_width, im_height, | |||
| im_width, im_height, im_width, im_height | |||
| ]) | |||
| scale1 = scale1.to(self.device) | |||
| landms = landms * scale1 / resize | |||
| landms = landms.cpu().numpy() | |||
| # ignore low scores | |||
| inds = np.where(scores > confidence_threshold)[0] | |||
| boxes = boxes[inds] | |||
| landms = landms[inds] | |||
| scores = scores[inds] | |||
| # keep top-K before NMS | |||
| order = scores.argsort()[::-1][:top_k] | |||
| boxes = boxes[order] | |||
| landms = landms[order] | |||
| scores = scores[order] | |||
| # do NMS | |||
| dets = np.hstack((boxes, scores[:, np.newaxis])).astype( | |||
| np.float32, copy=False) | |||
| keep = py_cpu_nms(dets, nms_threshold) | |||
| dets = dets[keep, :] | |||
| landms = landms[keep] | |||
| # keep top-K after NMS | |||
| dets = dets[:keep_top_k, :] | |||
| landms = landms[:keep_top_k, :] | |||
| landms = landms.reshape((-1, 5, 2)) | |||
| landms = landms.transpose((0, 2, 1)) | |||
| landms = landms.reshape( | |||
| -1, | |||
| 10, | |||
| ) | |||
| return dets / ss, landms / ss | |||
| def detect_tensor(self, | |||
| img, | |||
| resize=1, | |||
| confidence_threshold=0.9, | |||
| nms_threshold=0.4, | |||
| top_k=5000, | |||
| keep_top_k=750, | |||
| save_image=False): | |||
| im_height, im_width = img.shape[-2:] | |||
| ss = 1000 / max(im_height, im_width) | |||
| img = F.interpolate(img, scale_factor=ss) | |||
| im_height, im_width = img.shape[-2:] | |||
| scale = torch.Tensor([im_width, im_height, im_width, | |||
| im_height]).to(self.device) | |||
| img -= self.mean | |||
| loc, conf, landms = self.net(img) # forward pass | |||
| priorbox = PriorBox(self.cfg, image_size=(im_height, im_width)) | |||
| priors = priorbox.forward() | |||
| priors = priors.to(self.device) | |||
| prior_data = priors.data | |||
| boxes = decode(loc.data.squeeze(0), prior_data, self.cfg['variance']) | |||
| boxes = boxes * scale / resize | |||
| boxes = boxes.cpu().numpy() | |||
| scores = conf.squeeze(0).data.cpu().numpy()[:, 1] | |||
| landms = decode_landm( | |||
| landms.data.squeeze(0), prior_data, self.cfg['variance']) | |||
| scale1 = torch.Tensor([ | |||
| img.shape[3], img.shape[2], img.shape[3], img.shape[2], | |||
| img.shape[3], img.shape[2], img.shape[3], img.shape[2], | |||
| img.shape[3], img.shape[2] | |||
| ]) | |||
| scale1 = scale1.to(self.device) | |||
| landms = landms * scale1 / resize | |||
| landms = landms.cpu().numpy() | |||
| # ignore low scores | |||
| inds = np.where(scores > confidence_threshold)[0] | |||
| boxes = boxes[inds] | |||
| landms = landms[inds] | |||
| scores = scores[inds] | |||
| # keep top-K before NMS | |||
| order = scores.argsort()[::-1][:top_k] | |||
| boxes = boxes[order] | |||
| landms = landms[order] | |||
| scores = scores[order] | |||
| # do NMS | |||
| dets = np.hstack((boxes, scores[:, np.newaxis])).astype( | |||
| np.float32, copy=False) | |||
| keep = py_cpu_nms(dets, nms_threshold) | |||
| dets = dets[keep, :] | |||
| landms = landms[keep] | |||
| # keep top-K after NMS | |||
| dets = dets[:keep_top_k, :] | |||
| landms = landms[:keep_top_k, :] | |||
| landms = landms.reshape((-1, 5, 2)) | |||
| landms = landms.transpose((0, 2, 1)) | |||
| landms = landms.reshape( | |||
| -1, | |||
| 10, | |||
| ) | |||
| return dets / ss, landms / ss | |||
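Reviewer note: a minimal usage sketch for RetinaFaceDetection.detect; the checkpoint filename and image path are placeholders, and a CUDA device is assumed.

import cv2

detector = RetinaFaceDetection('RetinaFace-R50.pth', device='cuda')
img_bgr = cv2.imread('portrait.jpg')  # OpenCV gives H x W x 3 in BGR order, matching the
                                      # (104, 117, 123) mean subtracted inside detect()
dets, landms = detector.detect(img_bgr, confidence_threshold=0.9)
# dets: (num_faces, 5) rows of x1, y1, x2, y2, score
# landms: (num_faces, 10) rows of five (x, y) facial landmarks
for det in dets:
    print('face box', det[:4], 'score', det[4])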
| @@ -0,0 +1,148 @@ | |||
| import time | |||
| import torch | |||
| import torch.nn as nn | |||
| import torch.nn.functional as F | |||
| import torchvision.models as models | |||
| import torchvision.models._utils as _utils | |||
| from torch.autograd import Variable | |||
| def conv_bn(inp, oup, stride=1, leaky=0): | |||
| return nn.Sequential( | |||
| nn.Conv2d(inp, oup, 3, stride, 1, bias=False), nn.BatchNorm2d(oup), | |||
| nn.LeakyReLU(negative_slope=leaky, inplace=True)) | |||
| def conv_bn_no_relu(inp, oup, stride): | |||
| return nn.Sequential( | |||
| nn.Conv2d(inp, oup, 3, stride, 1, bias=False), | |||
| nn.BatchNorm2d(oup), | |||
| ) | |||
| def conv_bn1X1(inp, oup, stride, leaky=0): | |||
| return nn.Sequential( | |||
| nn.Conv2d(inp, oup, 1, stride, padding=0, bias=False), | |||
| nn.BatchNorm2d(oup), nn.LeakyReLU(negative_slope=leaky, inplace=True)) | |||
| def conv_dw(inp, oup, stride, leaky=0.1): | |||
| return nn.Sequential( | |||
| nn.Conv2d(inp, inp, 3, stride, 1, groups=inp, bias=False), | |||
| nn.BatchNorm2d(inp), | |||
| nn.LeakyReLU(negative_slope=leaky, inplace=True), | |||
| nn.Conv2d(inp, oup, 1, 1, 0, bias=False), | |||
| nn.BatchNorm2d(oup), | |||
| nn.LeakyReLU(negative_slope=leaky, inplace=True), | |||
| ) | |||
| class SSH(nn.Module): | |||
| def __init__(self, in_channel, out_channel): | |||
| super(SSH, self).__init__() | |||
| assert out_channel % 4 == 0 | |||
| leaky = 0 | |||
| if (out_channel <= 64): | |||
| leaky = 0.1 | |||
| self.conv3X3 = conv_bn_no_relu(in_channel, out_channel // 2, stride=1) | |||
| self.conv5X5_1 = conv_bn( | |||
| in_channel, out_channel // 4, stride=1, leaky=leaky) | |||
| self.conv5X5_2 = conv_bn_no_relu( | |||
| out_channel // 4, out_channel // 4, stride=1) | |||
| self.conv7X7_2 = conv_bn( | |||
| out_channel // 4, out_channel // 4, stride=1, leaky=leaky) | |||
| self.conv7x7_3 = conv_bn_no_relu( | |||
| out_channel // 4, out_channel // 4, stride=1) | |||
| def forward(self, input): | |||
| conv3X3 = self.conv3X3(input) | |||
| conv5X5_1 = self.conv5X5_1(input) | |||
| conv5X5 = self.conv5X5_2(conv5X5_1) | |||
| conv7X7_2 = self.conv7X7_2(conv5X5_1) | |||
| conv7X7 = self.conv7x7_3(conv7X7_2) | |||
| out = torch.cat([conv3X3, conv5X5, conv7X7], dim=1) | |||
| out = F.relu(out) | |||
| return out | |||
| class FPN(nn.Module): | |||
| def __init__(self, in_channels_list, out_channels): | |||
| super(FPN, self).__init__() | |||
| leaky = 0 | |||
| if (out_channels <= 64): | |||
| leaky = 0.1 | |||
| self.output1 = conv_bn1X1( | |||
| in_channels_list[0], out_channels, stride=1, leaky=leaky) | |||
| self.output2 = conv_bn1X1( | |||
| in_channels_list[1], out_channels, stride=1, leaky=leaky) | |||
| self.output3 = conv_bn1X1( | |||
| in_channels_list[2], out_channels, stride=1, leaky=leaky) | |||
| self.merge1 = conv_bn(out_channels, out_channels, leaky=leaky) | |||
| self.merge2 = conv_bn(out_channels, out_channels, leaky=leaky) | |||
| def forward(self, input): | |||
| # names = list(input.keys()) | |||
| input = list(input.values()) | |||
| output1 = self.output1(input[0]) | |||
| output2 = self.output2(input[1]) | |||
| output3 = self.output3(input[2]) | |||
| up3 = F.interpolate( | |||
| output3, size=[output2.size(2), output2.size(3)], mode='nearest') | |||
| output2 = output2 + up3 | |||
| output2 = self.merge2(output2) | |||
| up2 = F.interpolate( | |||
| output2, size=[output1.size(2), output1.size(3)], mode='nearest') | |||
| output1 = output1 + up2 | |||
| output1 = self.merge1(output1) | |||
| out = [output1, output2, output3] | |||
| return out | |||
| class MobileNetV1(nn.Module): | |||
| def __init__(self): | |||
| super(MobileNetV1, self).__init__() | |||
| self.stage1 = nn.Sequential( | |||
| conv_bn(3, 8, 2, leaky=0.1), # 3 | |||
| conv_dw(8, 16, 1), # 7 | |||
| conv_dw(16, 32, 2), # 11 | |||
| conv_dw(32, 32, 1), # 19 | |||
| conv_dw(32, 64, 2), # 27 | |||
| conv_dw(64, 64, 1), # 43 | |||
| ) | |||
| self.stage2 = nn.Sequential( | |||
| conv_dw(64, 128, 2), # 43 + 16 = 59 | |||
| conv_dw(128, 128, 1), # 59 + 32 = 91 | |||
| conv_dw(128, 128, 1), # 91 + 32 = 123 | |||
| conv_dw(128, 128, 1), # 123 + 32 = 155 | |||
| conv_dw(128, 128, 1), # 155 + 32 = 187 | |||
| conv_dw(128, 128, 1), # 187 + 32 = 219 | |||
| ) | |||
| self.stage3 = nn.Sequential( | |||
| conv_dw(128, 256, 2), # 219 +3 2 = 241 | |||
| conv_dw(256, 256, 1), # 241 + 64 = 301 | |||
| ) | |||
| self.avg = nn.AdaptiveAvgPool2d((1, 1)) | |||
| self.fc = nn.Linear(256, 1000) | |||
| def forward(self, x): | |||
| x = self.stage1(x) | |||
| x = self.stage2(x) | |||
| x = self.stage3(x) | |||
| x = self.avg(x) | |||
| x = x.view(-1, 256) | |||
| x = self.fc(x) | |||
| return x | |||
| @@ -0,0 +1,144 @@ | |||
| from collections import OrderedDict | |||
| import torch | |||
| import torch.nn as nn | |||
| import torch.nn.functional as F | |||
| import torchvision.models as models | |||
| import torchvision.models._utils as _utils | |||
| import torchvision.models.detection.backbone_utils as backbone_utils | |||
| from .net import FPN, SSH, MobileNetV1 | |||
| class ClassHead(nn.Module): | |||
| def __init__(self, inchannels=512, num_anchors=3): | |||
| super(ClassHead, self).__init__() | |||
| self.num_anchors = num_anchors | |||
| self.conv1x1 = nn.Conv2d( | |||
| inchannels, | |||
| self.num_anchors * 2, | |||
| kernel_size=(1, 1), | |||
| stride=1, | |||
| padding=0) | |||
| def forward(self, x): | |||
| out = self.conv1x1(x) | |||
| out = out.permute(0, 2, 3, 1).contiguous() | |||
| return out.view(out.shape[0], -1, 2) | |||
| class BboxHead(nn.Module): | |||
| def __init__(self, inchannels=512, num_anchors=3): | |||
| super(BboxHead, self).__init__() | |||
| self.conv1x1 = nn.Conv2d( | |||
| inchannels, | |||
| num_anchors * 4, | |||
| kernel_size=(1, 1), | |||
| stride=1, | |||
| padding=0) | |||
| def forward(self, x): | |||
| out = self.conv1x1(x) | |||
| out = out.permute(0, 2, 3, 1).contiguous() | |||
| return out.view(out.shape[0], -1, 4) | |||
| class LandmarkHead(nn.Module): | |||
| def __init__(self, inchannels=512, num_anchors=3): | |||
| super(LandmarkHead, self).__init__() | |||
| self.conv1x1 = nn.Conv2d( | |||
| inchannels, | |||
| num_anchors * 10, | |||
| kernel_size=(1, 1), | |||
| stride=1, | |||
| padding=0) | |||
| def forward(self, x): | |||
| out = self.conv1x1(x) | |||
| out = out.permute(0, 2, 3, 1).contiguous() | |||
| return out.view(out.shape[0], -1, 10) | |||
| class RetinaFace(nn.Module): | |||
| def __init__(self, cfg=None): | |||
| """ | |||
| :param cfg: Network related settings. | |||
| """ | |||
| super(RetinaFace, self).__init__() | |||
| backbone = None | |||
| if cfg['name'] == 'Resnet50': | |||
| backbone = models.resnet50(pretrained=cfg['pretrain']) | |||
| else: | |||
| raise Exception('Invalid name') | |||
| self.body = _utils.IntermediateLayerGetter(backbone, | |||
| cfg['return_layers']) | |||
| in_channels_stage2 = cfg['in_channel'] | |||
| in_channels_list = [ | |||
| in_channels_stage2 * 2, | |||
| in_channels_stage2 * 4, | |||
| in_channels_stage2 * 8, | |||
| ] | |||
| out_channels = cfg['out_channel'] | |||
| self.fpn = FPN(in_channels_list, out_channels) | |||
| self.ssh1 = SSH(out_channels, out_channels) | |||
| self.ssh2 = SSH(out_channels, out_channels) | |||
| self.ssh3 = SSH(out_channels, out_channels) | |||
| self.ClassHead = self._make_class_head( | |||
| fpn_num=3, inchannels=cfg['out_channel']) | |||
| self.BboxHead = self._make_bbox_head( | |||
| fpn_num=3, inchannels=cfg['out_channel']) | |||
| self.LandmarkHead = self._make_landmark_head( | |||
| fpn_num=3, inchannels=cfg['out_channel']) | |||
| def _make_class_head(self, fpn_num=3, inchannels=64, anchor_num=2): | |||
| classhead = nn.ModuleList() | |||
| for i in range(fpn_num): | |||
| classhead.append(ClassHead(inchannels, anchor_num)) | |||
| return classhead | |||
| def _make_bbox_head(self, fpn_num=3, inchannels=64, anchor_num=2): | |||
| bboxhead = nn.ModuleList() | |||
| for i in range(fpn_num): | |||
| bboxhead.append(BboxHead(inchannels, anchor_num)) | |||
| return bboxhead | |||
| def _make_landmark_head(self, fpn_num=3, inchannels=64, anchor_num=2): | |||
| landmarkhead = nn.ModuleList() | |||
| for i in range(fpn_num): | |||
| landmarkhead.append(LandmarkHead(inchannels, anchor_num)) | |||
| return landmarkhead | |||
| def forward(self, inputs): | |||
| out = self.body(inputs) | |||
| # FPN | |||
| fpn = self.fpn(out) | |||
| # SSH | |||
| feature1 = self.ssh1(fpn[0]) | |||
| feature2 = self.ssh2(fpn[1]) | |||
| feature3 = self.ssh3(fpn[2]) | |||
| features = [feature1, feature2, feature3] | |||
| bbox_regressions = torch.cat( | |||
| [self.BboxHead[i](feature) for i, feature in enumerate(features)], | |||
| dim=1) | |||
| classifications = torch.cat( | |||
| [self.ClassHead[i](feature) for i, feature in enumerate(features)], | |||
| dim=1) | |||
| ldm_regressions = torch.cat( | |||
| [self.LandmarkHead[i](feat) for i, feat in enumerate(features)], | |||
| dim=1) | |||
| output = (bbox_regressions, F.softmax(classifications, | |||
| dim=-1), ldm_regressions) | |||
| return output | |||
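Reviewer note: a shape check for the three RetinaFace heads, using only the config fields the constructor actually reads; 'pretrain': False avoids downloading torchvision weights. Illustrative only, not part of the PR.

import torch

cfg = {
    'name': 'Resnet50',
    'pretrain': False,
    'return_layers': {'layer2': 1, 'layer3': 2, 'layer4': 3},
    'in_channel': 256,
    'out_channel': 256,
}
net = RetinaFace(cfg=cfg)
net.eval()
with torch.no_grad():
    boxes, scores, landmarks = net(torch.rand(1, 3, 640, 640))
# With strides (8, 16, 32) and 2 anchors per location: 2 * (80*80 + 40*40 + 20*20) = 16800 priors.
print(boxes.shape, scores.shape, landmarks.shape)  # (1, 16800, 4), (1, 16800, 2), (1, 16800, 10)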
| @@ -0,0 +1,123 @@ | |||
| # -------------------------------------------------------- | |||
| # Modified from https://github.com/biubug6/Pytorch_Retinaface | |||
| # -------------------------------------------------------- | |||
| from itertools import product as product | |||
| from math import ceil | |||
| import numpy as np | |||
| import torch | |||
| class PriorBox(object): | |||
| def __init__(self, cfg, image_size=None, phase='train'): | |||
| super(PriorBox, self).__init__() | |||
| self.min_sizes = cfg['min_sizes'] | |||
| self.steps = cfg['steps'] | |||
| self.clip = cfg['clip'] | |||
| self.image_size = image_size | |||
| self.feature_maps = [[ | |||
| ceil(self.image_size[0] / step), | |||
| ceil(self.image_size[1] / step) | |||
| ] for step in self.steps] | |||
| self.name = 's' | |||
| def forward(self): | |||
| anchors = [] | |||
| for k, f in enumerate(self.feature_maps): | |||
| min_sizes = self.min_sizes[k] | |||
| for i, j in product(range(f[0]), range(f[1])): | |||
| for min_size in min_sizes: | |||
| s_kx = min_size / self.image_size[1] | |||
| s_ky = min_size / self.image_size[0] | |||
| dense_cx = [ | |||
| x * self.steps[k] / self.image_size[1] | |||
| for x in [j + 0.5] | |||
| ] | |||
| dense_cy = [ | |||
| y * self.steps[k] / self.image_size[0] | |||
| for y in [i + 0.5] | |||
| ] | |||
| for cy, cx in product(dense_cy, dense_cx): | |||
| anchors += [cx, cy, s_kx, s_ky] | |||
| # back to torch land | |||
| output = torch.Tensor(anchors).view(-1, 4) | |||
| if self.clip: | |||
| output.clamp_(max=1, min=0) | |||
| return output | |||
| def py_cpu_nms(dets, thresh): | |||
| """Pure Python NMS baseline.""" | |||
| x1 = dets[:, 0] | |||
| y1 = dets[:, 1] | |||
| x2 = dets[:, 2] | |||
| y2 = dets[:, 3] | |||
| scores = dets[:, 4] | |||
| areas = (x2 - x1 + 1) * (y2 - y1 + 1) | |||
| order = scores.argsort()[::-1] | |||
| keep = [] | |||
| while order.size > 0: | |||
| i = order[0] | |||
| keep.append(i) | |||
| xx1 = np.maximum(x1[i], x1[order[1:]]) | |||
| yy1 = np.maximum(y1[i], y1[order[1:]]) | |||
| xx2 = np.minimum(x2[i], x2[order[1:]]) | |||
| yy2 = np.minimum(y2[i], y2[order[1:]]) | |||
| w = np.maximum(0.0, xx2 - xx1 + 1) | |||
| h = np.maximum(0.0, yy2 - yy1 + 1) | |||
| inter = w * h | |||
| ovr = inter / (areas[i] + areas[order[1:]] - inter) | |||
| inds = np.where(ovr <= thresh)[0] | |||
| order = order[inds + 1] | |||
| return keep | |||
| # Adapted from https://github.com/Hakuyume/chainer-ssd | |||
| def decode(loc, priors, variances): | |||
| """Decode locations from predictions using priors to undo | |||
| the encoding we did for offset regression at train time. | |||
| Args: | |||
| loc (tensor): location predictions for loc layers, | |||
| Shape: [num_priors,4] | |||
| priors (tensor): Prior boxes in center-offset form. | |||
| Shape: [num_priors,4]. | |||
| variances: (list[float]) Variances of priorboxes | |||
| Return: | |||
| decoded bounding box predictions | |||
| """ | |||
| boxes = torch.cat( | |||
| (priors[:, :2] + loc[:, :2] * variances[0] * priors[:, 2:], | |||
| priors[:, 2:] * torch.exp(loc[:, 2:] * variances[1])), 1) | |||
| boxes[:, :2] -= boxes[:, 2:] / 2 | |||
| boxes[:, 2:] += boxes[:, :2] | |||
| return boxes | |||
| def decode_landm(pre, priors, variances): | |||
| """Decode landm from predictions using priors to undo | |||
| the encoding we did for offset regression at train time. | |||
| Args: | |||
| pre (tensor): landm predictions for loc layers, | |||
| Shape: [num_priors,10] | |||
| priors (tensor): Prior boxes in center-offset form. | |||
| Shape: [num_priors,4]. | |||
| variances: (list[float]) Variances of priorboxes | |||
| Return: | |||
| decoded landm predictions | |||
| """ | |||
| a = priors[:, :2] + pre[:, :2] * variances[0] * priors[:, 2:] | |||
| b = priors[:, :2] + pre[:, 2:4] * variances[0] * priors[:, 2:] | |||
| c = priors[:, :2] + pre[:, 4:6] * variances[0] * priors[:, 2:] | |||
| d = priors[:, :2] + pre[:, 6:8] * variances[0] * priors[:, 2:] | |||
| e = priors[:, :2] + pre[:, 8:10] * variances[0] * priors[:, 2:] | |||
| landms = torch.cat((a, b, c, d, e), dim=1) | |||
| return landms | |||
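Reviewer note: a tiny worked example for py_cpu_nms above; the boxes and scores are made up.

import numpy as np

# Two strongly overlapping boxes plus one isolated box. At an IoU threshold of 0.4,
# NMS keeps the higher-scoring box of the overlapping pair and the isolated box.
dets = np.array([
    [10, 10, 60, 60, 0.95],      # x1, y1, x2, y2, score
    [12, 12, 58, 58, 0.80],      # IoU with the first box is about 0.85 -> suppressed
    [100, 100, 150, 150, 0.90],  # no overlap -> kept
], dtype=np.float32)
print(py_cpu_nms(dets, thresh=0.4))  # expected indices: [0, 2]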
| @@ -137,6 +137,7 @@ TASK_OUTPUTS = { | |||
| Tasks.image_colorization: [OutputKeys.OUTPUT_IMG], | |||
| Tasks.image_color_enhancement: [OutputKeys.OUTPUT_IMG], | |||
| Tasks.image_denoising: [OutputKeys.OUTPUT_IMG], | |||
| Tasks.image_portrait_enhancement: [OutputKeys.OUTPUT_IMG], | |||
| # image generation task result for a single image | |||
| # {"output_img": np.array with shape (h, w, 3)} | |||
| @@ -110,6 +110,9 @@ DEFAULT_MODEL_FOR_PIPELINE = { | |||
| 'damo/cv_gan_face-image-generation'), | |||
| Tasks.image_super_resolution: (Pipelines.image_super_resolution, | |||
| 'damo/cv_rrdb_image-super-resolution'), | |||
| Tasks.image_portrait_enhancement: | |||
| (Pipelines.image_portrait_enhancement, | |||
| 'damo/cv_gpen_image-portrait-enhancement'), | |||
| Tasks.product_retrieval_embedding: | |||
| (Pipelines.product_retrieval_embedding, | |||
| 'damo/cv_resnet50_product-bag-embedding-models'), | |||
| @@ -11,14 +11,15 @@ if TYPE_CHECKING: | |||
| from .face_detection_pipeline import FaceDetectionPipeline | |||
| from .face_recognition_pipeline import FaceRecognitionPipeline | |||
| from .face_image_generation_pipeline import FaceImageGenerationPipeline | |||
| from .image_classification_pipeline import ImageClassificationPipeline | |||
| from .image_cartoon_pipeline import ImageCartoonPipeline | |||
| from .image_classification_pipeline import GeneralImageClassificationPipeline | |||
| from .image_denoise_pipeline import ImageDenoisePipeline | |||
| from .image_color_enhance_pipeline import ImageColorEnhancePipeline | |||
| from .image_colorization_pipeline import ImageColorizationPipeline | |||
| from .image_classification_pipeline import ImageClassificationPipeline | |||
| from .image_denoise_pipeline import ImageDenoisePipeline | |||
| from .image_instance_segmentation_pipeline import ImageInstanceSegmentationPipeline | |||
| from .image_matting_pipeline import ImageMattingPipeline | |||
| from .image_portrait_enhancement_pipeline import ImagePortraitEnhancementPipeline | |||
| from .image_style_transfer_pipeline import ImageStyleTransferPipeline | |||
| from .image_super_resolution_pipeline import ImageSuperResolutionPipeline | |||
| from .image_to_image_generate_pipeline import Image2ImageGenerationePipeline | |||
| @@ -46,6 +47,8 @@ else: | |||
| 'image_instance_segmentation_pipeline': | |||
| ['ImageInstanceSegmentationPipeline'], | |||
| 'image_matting_pipeline': ['ImageMattingPipeline'], | |||
| 'image_portrait_enhancement_pipeline': | |||
| ['ImagePortraitEnhancementPipeline'], | |||
| 'image_style_transfer_pipeline': ['ImageStyleTransferPipeline'], | |||
| 'image_super_resolution_pipeline': ['ImageSuperResolutionPipeline'], | |||
| 'image_to_image_translation_pipeline': | |||
| @@ -0,0 +1,216 @@ | |||
| import math | |||
| from typing import Any, Dict | |||
| import cv2 | |||
| import numpy as np | |||
| import PIL | |||
| import torch | |||
| from scipy.ndimage import gaussian_filter | |||
| from scipy.spatial.distance import pdist, squareform | |||
| from modelscope.metainfo import Pipelines | |||
| from modelscope.models.cv.image_portrait_enhancement import gpen | |||
| from modelscope.models.cv.image_portrait_enhancement.align_faces import ( | |||
| get_reference_facial_points, warp_and_crop_face) | |||
| from modelscope.models.cv.image_portrait_enhancement.eqface import fqa | |||
| from modelscope.models.cv.image_portrait_enhancement.retinaface import \ | |||
| detection | |||
| from modelscope.models.cv.super_resolution import rrdbnet_arch | |||
| from modelscope.outputs import OutputKeys | |||
| from modelscope.pipelines.base import Input, Pipeline | |||
| from modelscope.pipelines.builder import PIPELINES | |||
| from modelscope.preprocessors import LoadImage, load_image | |||
| from modelscope.utils.constant import ModelFile, Tasks | |||
| from modelscope.utils.logger import get_logger | |||
| logger = get_logger() | |||
| @PIPELINES.register_module( | |||
| Tasks.image_portrait_enhancement, | |||
| module_name=Pipelines.image_portrait_enhancement) | |||
| class ImagePortraitEnhancementPipeline(Pipeline): | |||
| def __init__(self, model: str, **kwargs): | |||
| """ | |||
| use `model` to create an image portrait enhancement pipeline for prediction | |||
| Args: | |||
| model: model id on modelscope hub. | |||
| """ | |||
| super().__init__(model=model, **kwargs) | |||
| if torch.cuda.is_available(): | |||
| self.device = torch.device('cuda') | |||
| else: | |||
| self.device = torch.device('cpu') | |||
| self.use_sr = True | |||
| self.size = 512 | |||
| self.n_mlp = 8 | |||
| self.channel_multiplier = 2 | |||
| self.narrow = 1 | |||
| self.face_enhancer = gpen.FullGenerator( | |||
| self.size, | |||
| 512, | |||
| self.n_mlp, | |||
| self.channel_multiplier, | |||
| narrow=self.narrow).to(self.device) | |||
| gpen_model_path = f'{model}/{ModelFile.TORCH_MODEL_FILE}' | |||
| self.face_enhancer.load_state_dict( | |||
| torch.load(gpen_model_path), strict=True) | |||
| logger.info('load face enhancer model done') | |||
| self.threshold = 0.9 | |||
| detector_model_path = f'{model}/face_detection/RetinaFace-R50.pth' | |||
| self.face_detector = detection.RetinaFaceDetection( | |||
| detector_model_path, self.device) | |||
| logger.info('load face detector model done') | |||
| self.num_feat = 32 | |||
| self.num_block = 23 | |||
| self.scale = 2 | |||
| self.sr_model = rrdbnet_arch.RRDBNet( | |||
| num_in_ch=3, | |||
| num_out_ch=3, | |||
| num_feat=self.num_feat, | |||
| num_block=self.num_block, | |||
| num_grow_ch=32, | |||
| scale=self.scale).to(self.device) | |||
| sr_model_path = f'{model}/super_resolution/realesrnet_x{self.scale}.pth' | |||
| self.sr_model.load_state_dict( | |||
| torch.load(sr_model_path)['params_ema'], strict=True) | |||
| logger.info('load sr model done') | |||
| self.fqa_thres = 0.1 | |||
| self.id_thres = 0.15 | |||
| self.alpha = 1.0 | |||
| backbone_model_path = f'{model}/face_quality/eqface_backbone.pth' | |||
| fqa_model_path = f'{model}/face_quality/eqface_quality.pth' | |||
| self.eqface = fqa.FQA(backbone_model_path, fqa_model_path, self.device) | |||
| logger.info('load fqa model done') | |||
| # the mask for pasting restored faces back | |||
| self.mask = np.zeros((512, 512, 3), np.float32) | |||
| cv2.rectangle(self.mask, (26, 26), (486, 486), (1, 1, 1), -1, | |||
| cv2.LINE_AA) | |||
| # blur the blend mask twice to feather its edges before pasting restored faces back | |||
| self.mask = cv2.GaussianBlur(self.mask, (101, 101), 4) | |||
| self.mask = cv2.GaussianBlur(self.mask, (101, 101), 4) | |||
| def enhance_face(self, img): | |||
| img = cv2.resize(img, (self.size, self.size)) | |||
| img_t = self.img2tensor(img) | |||
| self.face_enhancer.eval() | |||
| with torch.no_grad(): | |||
| out, __ = self.face_enhancer(img_t) | |||
| del img_t | |||
| out = self.tensor2img(out) | |||
| return out | |||
| def img2tensor(self, img, is_norm=True): | |||
| img_t = torch.from_numpy(img).to(self.device) / 255. | |||
| if is_norm: | |||
| img_t = (img_t - 0.5) / 0.5 | |||
| img_t = img_t.permute(2, 0, 1).unsqueeze(0).flip(1) # BGR->RGB | |||
| return img_t | |||
| def tensor2img(self, img_t, pmax=255.0, is_denorm=True, imtype=np.uint8): | |||
| if is_denorm: | |||
| img_t = img_t * 0.5 + 0.5 | |||
| img_t = img_t.squeeze(0).permute(1, 2, 0).flip(2) # RGB->BGR | |||
| img_np = np.clip(img_t.float().cpu().numpy(), 0, 1) * pmax | |||
| return img_np.astype(imtype) | |||
| def preprocess(self, input: Input) -> Dict[str, Any]: | |||
| img = LoadImage.convert_to_ndarray(input) | |||
| img_sr = None | |||
| if self.use_sr: | |||
| self.sr_model.eval() | |||
| with torch.no_grad(): | |||
| img_t = self.img2tensor(img, is_norm=False) | |||
| img_out = self.sr_model(img_t) | |||
| img_sr = img_out.squeeze(0).permute(1, 2, 0).flip(2).cpu().clamp_( | |||
| 0, 1).numpy() | |||
| img_sr = (img_sr * 255.0).round().astype(np.uint8) | |||
| img = cv2.resize(img, img_sr.shape[:2][::-1]) | |||
| result = {'img': img, 'img_sr': img_sr} | |||
| return result | |||
| def forward(self, input: Dict[str, Any]) -> Dict[str, Any]: | |||
| img, img_sr = input['img'], input['img_sr'] | |||
| img, img_sr = img.cpu().numpy(), img_sr.cpu().numpy() | |||
| facebs, landms = self.face_detector.detect(img) | |||
| height, width = img.shape[:2] | |||
| full_mask = np.zeros(img.shape, dtype=np.float32) | |||
| full_img = np.zeros(img.shape, dtype=np.uint8) | |||
| for i, (faceb, facial5points) in enumerate(zip(facebs, landms)): | |||
| if faceb[4] < self.threshold: | |||
| continue | |||
| # fh, fw = (faceb[3] - faceb[1]), (faceb[2] - faceb[0]) | |||
| facial5points = np.reshape(facial5points, (2, 5)) | |||
| of, of_112, tfm_inv = warp_and_crop_face( | |||
| img, facial5points, crop_size=(self.size, self.size)) | |||
| # detect orig face quality | |||
| fq_o, fea_o = self.eqface.get_face_quality(of_112) | |||
| if fq_o < self.fqa_thres: | |||
| continue | |||
| # enhance the face | |||
| ef = self.enhance_face(of) | |||
| # detect enhanced face quality | |||
| ss = self.size // 256 | |||
| ef_112 = cv2.resize(ef[35 * ss:-33 * ss, 32 * ss:-36 * ss], | |||
| (112, 112)) # crop roi | |||
| fq_e, fea_e = self.eqface.get_face_quality(ef_112) | |||
| dist = squareform(pdist([fea_o, fea_e], 'cosine')).mean() | |||
| if dist > self.id_thres: | |||
| continue | |||
| # blending parameter | |||
| fq = max(1., (fq_o - self.fqa_thres)) | |||
| fq = (1 - 2 * dist) * (1.0 / (1 + math.exp(-(2 * fq - 1)))) | |||
| # blend face | |||
| ef = cv2.addWeighted(ef, fq * self.alpha, of, 1 - fq * self.alpha, | |||
| 0.0) | |||
| tmp_mask = self.mask | |||
| tmp_mask = cv2.resize(tmp_mask, ef.shape[:2]) | |||
| tmp_mask = cv2.warpAffine( | |||
| tmp_mask, tfm_inv, (width, height), flags=3) | |||
| tmp_img = cv2.warpAffine(ef, tfm_inv, (width, height), flags=3) | |||
| mask = np.clip(tmp_mask - full_mask, 0, 1) | |||
| full_mask[np.where(mask > 0)] = tmp_mask[np.where(mask > 0)] | |||
| full_img[np.where(mask > 0)] = tmp_img[np.where(mask > 0)] | |||
| if self.use_sr and img_sr is not None: | |||
| out_img = cv2.convertScaleAbs(img_sr * (1 - full_mask) | |||
| + full_img * full_mask) | |||
| else: | |||
| out_img = cv2.convertScaleAbs(img * (1 - full_mask) | |||
| + full_img * full_mask) | |||
| return {OutputKeys.OUTPUT_IMG: out_img.astype(np.uint8)} | |||
| def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: | |||
| return inputs | |||
| @@ -163,6 +163,32 @@ class ImageDenoisePreprocessor(Preprocessor): | |||
| return data | |||
| @PREPROCESSORS.register_module( | |||
| Fields.cv, | |||
| module_name=Preprocessors.image_portrait_enhancement_preprocessor) | |||
| class ImagePortraitEnhancementPreprocessor(Preprocessor): | |||
| def __init__(self, model_dir: str, *args, **kwargs): | |||
| """ | |||
| Args: | |||
| model_dir (str): model path | |||
| """ | |||
| super().__init__(*args, **kwargs) | |||
| self.model_dir: str = model_dir | |||
| def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]: | |||
| """process the raw input data | |||
| Args: | |||
| data Dict[str, Any] | |||
| Returns: | |||
| Dict[str, Any]: the preprocessed data | |||
| """ | |||
| return data | |||
| @PREPROCESSORS.register_module( | |||
| Fields.cv, | |||
| module_name=Preprocessors.image_instance_segmentation_preprocessor) | |||
| @@ -1,6 +1,7 @@ | |||
| from .base import DummyTrainer | |||
| from .builder import build_trainer | |||
| from .cv import ImageInstanceSegmentationTrainer | |||
| from .cv import (ImageInstanceSegmentationTrainer, | |||
| ImagePortraitEnhancementTrainer) | |||
| from .multi_modal import CLIPTrainer | |||
| from .nlp import SequenceClassificationTrainer | |||
| from .trainer import EpochBasedTrainer | |||
| @@ -1,2 +1,3 @@ | |||
| from .image_instance_segmentation_trainer import \ | |||
| ImageInstanceSegmentationTrainer | |||
| from .image_portrait_enhancement_trainer import ImagePortraitEnhancementTrainer | |||
| @@ -0,0 +1,148 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| from collections.abc import Mapping | |||
| import torch | |||
| from torch import distributed as dist | |||
| from modelscope.trainers.builder import TRAINERS | |||
| from modelscope.trainers.optimizer.builder import build_optimizer | |||
| from modelscope.trainers.trainer import EpochBasedTrainer | |||
| from modelscope.utils.constant import ModeKeys | |||
| from modelscope.utils.logger import get_logger | |||
| @TRAINERS.register_module(module_name='gpen') | |||
| class ImagePortraitEnhancementTrainer(EpochBasedTrainer): | |||
| def train_step(self, model, inputs): | |||
| """ Perform a training step on a batch of inputs. | |||
| Subclass and override to inject custom behavior. | |||
| Args: | |||
| model (`TorchModel`): The model to train. | |||
| inputs (`Dict[str, Union[torch.Tensor, Any]]`): | |||
| The inputs and targets of the model. | |||
| The dictionary will be unpacked before being fed to the model. Most models expect the targets under the | |||
| argument `labels`. Check your model's documentation for all accepted arguments. | |||
| Return: | |||
| `torch.Tensor`: The tensor with training loss on this batch. | |||
| """ | |||
| # EvaluationHook runs evaluation and switches the mode to val; switch back to train mode here | |||
| # TODO: find a cleaner way to restore the mode | |||
| self.d_reg_every = self.cfg.train.get('d_reg_every', 16) | |||
| self.g_reg_every = self.cfg.train.get('g_reg_every', 4) | |||
| self.path_regularize = self.cfg.train.get('path_regularize', 2) | |||
| self.r1 = self.cfg.train.get('r1', 10) | |||
| train_outputs = dict() | |||
| self._mode = ModeKeys.TRAIN | |||
| inputs = self.collate_fn(inputs) | |||
| # call model forward but not __call__ to skip postprocess | |||
| if isinstance(inputs, Mapping): | |||
| d_loss = model._train_forward_d(**inputs) | |||
| else: | |||
| d_loss = model._train_forward_d(inputs) | |||
| train_outputs['d_loss'] = d_loss | |||
| model.discriminator.zero_grad() | |||
| d_loss.backward() | |||
| self.optimizer_d.step() | |||
| if self._iter % self.d_reg_every == 0: | |||
| if isinstance(inputs, Mapping): | |||
| r1_loss = model._train_forward_d_r1(**inputs) | |||
| else: | |||
| r1_loss = model._train_forward_d_r1(inputs) | |||
| train_outputs['r1_loss'] = r1_loss | |||
| model.discriminator.zero_grad() | |||
| (self.r1 / 2 * r1_loss * self.d_reg_every).backward() | |||
| self.optimizer_d.step() | |||
| if isinstance(inputs, Mapping): | |||
| g_loss = model._train_forward_g(**inputs) | |||
| else: | |||
| g_loss = model._train_forward_g(inputs) | |||
| train_outputs['g_loss'] = g_loss | |||
| model.generator.zero_grad() | |||
| g_loss.backward() | |||
| self.optimizer.step() | |||
| path_loss = 0 | |||
| if self._iter % self.g_reg_every == 0: | |||
| if isinstance(inputs, Mapping): | |||
| path_loss = model._train_forward_g_path(**inputs) | |||
| else: | |||
| path_loss = model._train_forward_g_path(inputs) | |||
| train_outputs['path_loss'] = path_loss | |||
| model.generator.zero_grad() | |||
| weighted_path_loss = self.path_regularize * self.g_reg_every * path_loss | |||
| weighted_path_loss.backward() | |||
| self.optimizer.step() | |||
| model.accumulate() | |||
| if not isinstance(train_outputs, dict): | |||
| raise TypeError('"model.forward()" must return a dict') | |||
| # add model output info to log | |||
| if 'log_vars' not in train_outputs: | |||
| default_keys_pattern = ['loss'] | |||
| match_keys = set([]) | |||
| for key_p in default_keys_pattern: | |||
| match_keys.update( | |||
| [key for key in train_outputs.keys() if key_p in key]) | |||
| log_vars = {} | |||
| for key in match_keys: | |||
| value = train_outputs.get(key, None) | |||
| if value is not None: | |||
| if dist.is_available() and dist.is_initialized(): | |||
| value = value.data.clone() | |||
| dist.all_reduce(value.div_(dist.get_world_size())) | |||
| log_vars.update({key: value.item()}) | |||
| self.log_buffer.update(log_vars) | |||
| else: | |||
| self.log_buffer.update(train_outputs['log_vars']) | |||
| self.train_outputs = train_outputs | |||
| def create_optimizer_and_scheduler(self): | |||
| """ Create optimizer and lr scheduler | |||
| We provide a default implementation; if you want to customize the optimizer | |||
| and lr scheduler, either pass a tuple through the trainer init function or | |||
| subclass this class and override this method. | |||
| """ | |||
| optimizer, lr_scheduler = self.optimizers | |||
| if optimizer is None: | |||
| optimizer_cfg = self.cfg.train.get('optimizer', None) | |||
| else: | |||
| optimizer_cfg = None | |||
| optimizer_d_cfg = self.cfg.train.get('optimizer_d', None) | |||
| optim_options = {} | |||
| if optimizer_cfg is not None: | |||
| optim_options = optimizer_cfg.pop('options', {}) | |||
| optimizer = build_optimizer( | |||
| self.model.generator, cfg=optimizer_cfg) | |||
| if optimizer_d_cfg is not None: | |||
| optimizer_d = build_optimizer( | |||
| self.model.discriminator, cfg=optimizer_d_cfg) | |||
| lr_options = {} | |||
| self.optimizer = optimizer | |||
| self.lr_scheduler = lr_scheduler | |||
| self.optimizer_d = optimizer_d | |||
| return self.optimizer, self.lr_scheduler, optim_options, lr_options | |||
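Reviewer note: for readers of this trainer, a hypothetical sketch (Python dict form) of the `train` section of the model configuration that the accessors above imply. Only the key names read in train_step and create_optimizer_and_scheduler come from this diff; the optimizer field layout and all values are assumptions, not the configuration shipped with damo/cv_gpen_image-portrait-enhancement.

train_cfg = {
    'd_reg_every': 16,      # apply the R1 penalty to the discriminator every 16 iterations
    'g_reg_every': 4,       # apply path-length regularisation to the generator every 4 iterations
    'r1': 10,               # R1 gradient-penalty weight
    'path_regularize': 2,   # path-length regularisation weight
    'optimizer': {'type': 'Adam', 'lr': 2e-3},    # generator optimizer (assumed fields)
    'optimizer_d': {'type': 'Adam', 'lr': 2e-3},  # discriminator optimizer (assumed fields)
}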
| @@ -14,5 +14,5 @@ __all__ = [ | |||
| 'Hook', 'HOOKS', 'CheckpointHook', 'EvaluationHook', 'LrSchedulerHook', | |||
| 'OptimizerHook', 'Priority', 'build_hook', 'TextLoggerHook', | |||
| 'IterTimerHook', 'TorchAMPOptimizerHook', 'ApexAMPOptimizerHook', | |||
| 'BestCkptSaverHook' | |||
| 'BestCkptSaverHook', 'NoneOptimizerHook', 'NoneLrSchedulerHook' | |||
| ] | |||
| @@ -115,3 +115,18 @@ class PlateauLrSchedulerHook(LrSchedulerHook): | |||
| self.warmup_lr_scheduler.step(metrics=metrics) | |||
| else: | |||
| trainer.lr_scheduler.step(metrics=metrics) | |||
| @HOOKS.register_module() | |||
| class NoneLrSchedulerHook(LrSchedulerHook): | |||
| PRIORITY = Priority.LOW # should be after EvaluationHook | |||
| def __init__(self, by_epoch=True, warmup=None) -> None: | |||
| super().__init__(by_epoch=by_epoch, warmup=warmup) | |||
| def before_run(self, trainer): | |||
| return | |||
| def after_train_epoch(self, trainer): | |||
| return | |||
| @@ -200,3 +200,19 @@ class ApexAMPOptimizerHook(OptimizerHook): | |||
| trainer.optimizer.step() | |||
| trainer.optimizer.zero_grad() | |||
| @HOOKS.register_module() | |||
| class NoneOptimizerHook(OptimizerHook): | |||
| def __init__(self, cumulative_iters=1, grad_clip=None, loss_keys='loss'): | |||
| super(NoneOptimizerHook, self).__init__( | |||
| grad_clip=grad_clip, loss_keys=loss_keys) | |||
| self.cumulative_iters = cumulative_iters | |||
| def before_run(self, trainer): | |||
| return | |||
| def after_train_iter(self, trainer): | |||
| return | |||
| @@ -43,6 +43,7 @@ class CVTasks(object): | |||
| image_colorization = 'image-colorization' | |||
| image_color_enhancement = 'image-color-enhancement' | |||
| image_denoising = 'image-denoising' | |||
| image_portrait_enhancement = 'image-portrait-enhancement' | |||
| # image generation | |||
| image_to_image_translation = 'image-to-image-translation' | |||
| @@ -0,0 +1,43 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| import os | |||
| import os.path as osp | |||
| import unittest | |||
| import cv2 | |||
| from modelscope.msdatasets import MsDataset | |||
| from modelscope.outputs import OutputKeys | |||
| from modelscope.pipelines import pipeline | |||
| from modelscope.pipelines.base import Pipeline | |||
| from modelscope.utils.constant import Tasks | |||
| from modelscope.utils.test_utils import test_level | |||
| class ImagePortraitEnhancementTest(unittest.TestCase): | |||
| def setUp(self) -> None: | |||
| self.model_id = 'damo/cv_gpen_image-portrait-enhancement' | |||
| self.test_image = 'data/test/images/Solvay_conference_1927.png' | |||
| def pipeline_inference(self, pipeline: Pipeline, test_image: str): | |||
| result = pipeline(test_image) | |||
| if result is not None: | |||
| cv2.imwrite('result.png', result[OutputKeys.OUTPUT_IMG]) | |||
| print(f'Output written to {osp.abspath("result.png")}') | |||
| else: | |||
| raise Exception('Testing failed: invalid output') | |||
| @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') | |||
| def test_run_modelhub(self): | |||
| face_enhancement = pipeline( | |||
| Tasks.image_portrait_enhancement, model=self.model_id) | |||
| self.pipeline_inference(face_enhancement, self.test_image) | |||
| @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') | |||
| def test_run_modelhub_default_model(self): | |||
| face_enhancement = pipeline(Tasks.image_portrait_enhancement) | |||
| self.pipeline_inference(face_enhancement, self.test_image) | |||
| if __name__ == '__main__': | |||
| unittest.main() | |||
| @@ -0,0 +1,119 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| import os | |||
| import os.path as osp | |||
| import shutil | |||
| import tempfile | |||
| import unittest | |||
| from typing import Callable, List, Optional, Tuple, Union | |||
| import cv2 | |||
| import torch | |||
| from torch.utils import data as data | |||
| from modelscope.hub.snapshot_download import snapshot_download | |||
| from modelscope.models.cv.image_portrait_enhancement import \ | |||
| ImagePortraitEnhancement | |||
| from modelscope.trainers import build_trainer | |||
| from modelscope.utils.constant import ModelFile | |||
| from modelscope.utils.test_utils import test_level | |||
| class TestImagePortraitEnhancementTrainer(unittest.TestCase): | |||
| def setUp(self): | |||
| print(('Testing %s.%s' % (type(self).__name__, self._testMethodName))) | |||
| self.tmp_dir = tempfile.TemporaryDirectory().name | |||
| if not os.path.exists(self.tmp_dir): | |||
| os.makedirs(self.tmp_dir) | |||
| self.model_id = 'damo/cv_gpen_image-portrait-enhancement' | |||
| class PairedImageDataset(data.Dataset): | |||
| def __init__(self, root, size=512): | |||
| super(PairedImageDataset, self).__init__() | |||
| self.size = size | |||
| gt_dir = osp.join(root, 'gt') | |||
| lq_dir = osp.join(root, 'lq') | |||
| self.gt_filelist = os.listdir(gt_dir) | |||
| self.gt_filelist = sorted( | |||
| self.gt_filelist, key=lambda x: int(x[:-4])) | |||
| self.gt_filelist = [ | |||
| osp.join(gt_dir, f) for f in self.gt_filelist | |||
| ] | |||
| self.lq_filelist = os.listdir(lq_dir) | |||
| self.lq_filelist = sorted( | |||
| self.lq_filelist, key=lambda x: int(x[:-4])) | |||
| self.lq_filelist = [ | |||
| osp.join(lq_dir, f) for f in self.lq_filelist | |||
| ] | |||
| def _img_to_tensor(self, img): | |||
| img = torch.from_numpy(img[:, :, [2, 1, 0]]).permute( | |||
| 2, 0, 1).type(torch.float32) / 255. | |||
| return (img - 0.5) / 0.5 | |||
| def __getitem__(self, index): | |||
| lq = cv2.imread(self.lq_filelist[index]) | |||
| gt = cv2.imread(self.gt_filelist[index]) | |||
| lq = cv2.resize( | |||
| lq, (self.size, self.size), interpolation=cv2.INTER_CUBIC) | |||
| gt = cv2.resize( | |||
| gt, (self.size, self.size), interpolation=cv2.INTER_CUBIC) | |||
| return \ | |||
| {'src': self._img_to_tensor(lq), 'target': self._img_to_tensor(gt)} | |||
| def __len__(self): | |||
| return len(self.gt_filelist) | |||
| def to_torch_dataset(self, | |||
| columns: Union[str, List[str]] = None, | |||
| preprocessors: Union[Callable, | |||
| List[Callable]] = None, | |||
| **format_kwargs): | |||
| # self.preprocessor = preprocessors | |||
| return self | |||
| self.dataset = PairedImageDataset( | |||
| './data/test/images/face_enhancement/') | |||
| def tearDown(self): | |||
| shutil.rmtree(self.tmp_dir, ignore_errors=True) | |||
| super().tearDown() | |||
| @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | |||
| def test_trainer(self): | |||
| kwargs = dict( | |||
| model=self.model_id, | |||
| train_dataset=self.dataset, | |||
| eval_dataset=self.dataset, | |||
| device='gpu', | |||
| work_dir=self.tmp_dir) | |||
| trainer = build_trainer(name='gpen', default_args=kwargs) | |||
| trainer.train() | |||
| @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') | |||
| def test_trainer_with_model_and_args(self): | |||
| tmp_dir = tempfile.TemporaryDirectory().name | |||
| if not os.path.exists(tmp_dir): | |||
| os.makedirs(tmp_dir) | |||
| cache_path = snapshot_download(self.model_id) | |||
| model = ImagePortraitEnhancement.from_pretrained(cache_path) | |||
| kwargs = dict( | |||
| cfg_file=os.path.join(cache_path, ModelFile.CONFIGURATION), | |||
| model=model, | |||
| train_dataset=self.dataset, | |||
| eval_dataset=self.dataset, | |||
| device='gpu', | |||
| max_epochs=2, | |||
| work_dir=self.tmp_dir) | |||
| trainer = build_trainer(name='gpen', default_args=kwargs) | |||
| trainer.train() | |||
| if __name__ == '__main__': | |||
| unittest.main() | |||