Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10111615master
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:46db348eae61448f1668ce282caec21375e96c3268d53da44aa67ec32cbf4fa5
size 2747938
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:709c1828ed2d56badf2f19a40194da9a5e5e6db2fb73ef55d047407f49bc7a15
size 27616
@@ -27,6 +27,7 @@ class Models(object):
    face_2d_keypoints = 'face-2d-keypoints'
    panoptic_segmentation = 'swinL-panoptic-segmentation'
    image_reid_person = 'passvitb'
    image_inpainting = 'FFTInpainting'
    video_summarization = 'pgl-video-summarization'
    swinL_semantic_segmentation = 'swinL-semantic-segmentation'
    vitadapter_semantic_segmentation = 'vitadapter-semantic-segmentation'
@@ -179,6 +180,7 @@ class Pipelines(object):
    video_summarization = 'googlenet_pgl_video_summarization'
    image_semantic_segmentation = 'image-semantic-segmentation'
    image_reid_person = 'passvitb-image-reid-person'
    image_inpainting = 'fft-inpainting'
    text_driven_segmentation = 'text-driven-segmentation'
    movie_scene_segmentation = 'resnet50-bert-movie-scene-segmentation'
    shop_segmentation = 'shop-segmentation'
@@ -264,6 +266,7 @@ class Trainers(object):
    image_portrait_enhancement = 'image-portrait-enhancement'
    video_summarization = 'video-summarization'
    movie_scene_segmentation = 'movie-scene-segmentation'
    image_inpainting = 'image-inpainting'

    # nlp trainers
    bert_sentiment_analysis = 'bert-sentiment-analysis'
@@ -363,6 +366,8 @@ class Metrics(object):
    video_summarization_metric = 'video-summarization-metric'
    # metric for movie-scene-segmentation task
    movie_scene_segmentation_metric = 'movie-scene-segmentation-metric'
    # metric for inpainting task
    image_inpainting_metric = 'image-inpainting-metric'


class Optimizers(object):
@@ -17,6 +17,7 @@ if TYPE_CHECKING:
    from .token_classification_metric import TokenClassificationMetric
    from .video_summarization_metric import VideoSummarizationMetric
    from .movie_scene_segmentation_metric import MovieSceneSegmentationMetric
    from .image_inpainting_metric import ImageInpaintingMetric

else:
    _import_structure = {
@@ -34,6 +35,7 @@ else:
        'token_classification_metric': ['TokenClassificationMetric'],
        'video_summarization_metric': ['VideoSummarizationMetric'],
        'movie_scene_segmentation_metric': ['MovieSceneSegmentationMetric'],
        'image_inpainting_metric': ['ImageInpaintingMetric'],
    }

    import sys
@@ -18,6 +18,7 @@ class MetricKeys(object):
    SSIM = 'ssim'
    AVERAGE_LOSS = 'avg_loss'
    FScore = 'fscore'
    FID = 'fid'
    BLEU_1 = 'bleu-1'
    BLEU_4 = 'bleu-4'
    ROUGE_1 = 'rouge-1'
@@ -39,6 +40,7 @@ task_default_metrics = {
    Tasks.image_captioning: [Metrics.text_gen_metric],
    Tasks.visual_question_answering: [Metrics.text_gen_metric],
    Tasks.movie_scene_segmentation: [Metrics.movie_scene_segmentation_metric],
    Tasks.image_inpainting: [Metrics.image_inpainting_metric],
}
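Taken together, these constants wire the new task end to end: the task key selects the default metric above, the pipeline key selects the pipeline implementation, and the model key selects the registered model class. A minimal usage sketch; the model id placeholder and the input dict keys are assumptions, since the pipeline implementation itself is not shown in this diff:

```python
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

# Tasks.image_inpainting routes to Pipelines.image_inpainting ('fft-inpainting'),
# which instantiates Models.image_inpainting ('FFTInpainting') under the hood.
inpainter = pipeline(Tasks.image_inpainting, model='<inpainting-model-id>')  # placeholder id
result = inpainter({'img': 'input.png', 'mask': 'mask.png'})  # input keys are an assumption
```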
| @@ -0,0 +1,210 @@ | |||||
| """ | |||||
| Part of the implementation is borrowed and modified from LaMa, publicly available at | |||||
| https://github.com/saic-mdal/lama | |||||
| """ | |||||
| from typing import Dict | |||||
| import numpy as np | |||||
| import torch | |||||
| import torch.nn.functional as F | |||||
| from scipy import linalg | |||||
| from modelscope.metainfo import Metrics | |||||
| from modelscope.models.cv.image_inpainting.modules.inception import InceptionV3 | |||||
| from modelscope.utils.registry import default_group | |||||
| from modelscope.utils.tensor_utils import (torch_nested_detach, | |||||
| torch_nested_numpify) | |||||
| from .base import Metric | |||||
| from .builder import METRICS, MetricKeys | |||||
| def fid_calculate_activation_statistics(act): | |||||
| mu = np.mean(act, axis=0) | |||||
| sigma = np.cov(act, rowvar=False) | |||||
| return mu, sigma | |||||
| def calculate_frechet_distance(activations_pred, activations_target, eps=1e-6): | |||||
| mu1, sigma1 = fid_calculate_activation_statistics(activations_pred) | |||||
| mu2, sigma2 = fid_calculate_activation_statistics(activations_target) | |||||
| diff = mu1 - mu2 | |||||
| # Product might be almost singular | |||||
| covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False) | |||||
| if not np.isfinite(covmean).all(): | |||||
| offset = np.eye(sigma1.shape[0]) * eps | |||||
| covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset)) | |||||
| # Numerical error might give slight imaginary component | |||||
| if np.iscomplexobj(covmean): | |||||
| # if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3): | |||||
| if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-2): | |||||
| m = np.max(np.abs(covmean.imag)) | |||||
| raise ValueError('Imaginary component {}'.format(m)) | |||||
| covmean = covmean.real | |||||
| tr_covmean = np.trace(covmean) | |||||
| return (diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) | |||||
| - 2 * tr_covmean) | |||||
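For reference, `calculate_frechet_distance` is the closed-form squared Fréchet distance between the two Gaussians fitted to the activations:

```latex
d^2\bigl(\mathcal{N}(\mu_1,\Sigma_1),\,\mathcal{N}(\mu_2,\Sigma_2)\bigr)
    = \lVert\mu_1-\mu_2\rVert_2^2
    + \operatorname{Tr}\bigl(\Sigma_1+\Sigma_2-2(\Sigma_1\Sigma_2)^{1/2}\bigr)
```

The `eps` offset on the diagonals keeps the matrix square root finite when `sigma1.dot(sigma2)` is nearly singular, which happens whenever the sample count is small relative to the 2048-dimensional activations.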
class FIDScore(torch.nn.Module):

    def __init__(self, dims=2048, eps=1e-6):
        super().__init__()
        if getattr(FIDScore, '_MODEL', None) is None:
            block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[dims]
            FIDScore._MODEL = InceptionV3([block_idx]).eval()
        self.model = FIDScore._MODEL
        self.eps = eps
        self.reset()

    def forward(self, pred_batch, target_batch, mask=None):
        activations_pred = self._get_activations(pred_batch)
        activations_target = self._get_activations(target_batch)
        self.activations_pred.append(activations_pred.detach().cpu())
        self.activations_target.append(activations_target.detach().cpu())

    def get_value(self):
        activations_pred, activations_target = (self.activations_pred,
                                                self.activations_target)
        activations_pred = torch.cat(activations_pred).cpu().numpy()
        activations_target = torch.cat(activations_target).cpu().numpy()

        total_distance = calculate_frechet_distance(
            activations_pred, activations_target, eps=self.eps)
        self.reset()
        return total_distance

    def reset(self):
        self.activations_pred = []
        self.activations_target = []

    def _get_activations(self, batch):
        activations = self.model(batch)[0]
        if activations.shape[2] != 1 or activations.shape[3] != 1:
            assert False, \
                'We should not have got here, because Inception always scales inputs to 299x299'
        activations = activations.squeeze(-1).squeeze(-1)
        return activations


class SSIM(torch.nn.Module):
    """SSIM. Modified from:
    https://github.com/Po-Hsun-Su/pytorch-ssim/blob/master/pytorch_ssim/__init__.py
    """

    def __init__(self, window_size=11, size_average=True):
        super().__init__()
        self.window_size = window_size
        self.size_average = size_average
        self.channel = 1
        self.register_buffer('window',
                             self._create_window(window_size, self.channel))

    def forward(self, img1, img2):
        assert len(img1.shape) == 4

        channel = img1.size()[1]
        if channel == self.channel and self.window.data.type(
        ) == img1.data.type():
            window = self.window
        else:
            window = self._create_window(self.window_size, channel)
            window = window.type_as(img1)
            self.window = window
            self.channel = channel

        return self._ssim(img1, img2, window, self.window_size, channel,
                          self.size_average)

    def _gaussian(self, window_size, sigma):
        gauss = torch.Tensor([
            np.exp(-(x - (window_size // 2))**2 / float(2 * sigma**2))
            for x in range(window_size)
        ])
        return gauss / gauss.sum()

    def _create_window(self, window_size, channel):
        _1D_window = self._gaussian(window_size, 1.5).unsqueeze(1)
        _2D_window = _1D_window.mm(
            _1D_window.t()).float().unsqueeze(0).unsqueeze(0)
        return _2D_window.expand(channel, 1, window_size,
                                 window_size).contiguous()

    def _ssim(self,
              img1,
              img2,
              window,
              window_size,
              channel,
              size_average=True):
        mu1 = F.conv2d(
            img1, window, padding=(window_size // 2), groups=channel)
        mu2 = F.conv2d(
            img2, window, padding=(window_size // 2), groups=channel)

        mu1_sq = mu1.pow(2)
        mu2_sq = mu2.pow(2)
        mu1_mu2 = mu1 * mu2

        sigma1_sq = F.conv2d(
            img1 * img1, window, padding=(window_size // 2),
            groups=channel) - mu1_sq
        sigma2_sq = F.conv2d(
            img2 * img2, window, padding=(window_size // 2),
            groups=channel) - mu2_sq
        sigma12 = F.conv2d(
            img1 * img2, window, padding=(window_size // 2),
            groups=channel) - mu1_mu2

        C1 = 0.01**2
        C2 = 0.03**2

        ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / \
            ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2))

        if size_average:
            return ssim_map.mean()
        return ssim_map.mean(1).mean(1).mean(1)

    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                              missing_keys, unexpected_keys, error_msgs):
        # the gaussian window is recreated on the fly, so nothing is loaded
        return


@METRICS.register_module(
    group_key=default_group, module_name=Metrics.image_inpainting_metric)
class ImageInpaintingMetric(Metric):
    """The metric computation class for the image inpainting task."""

    def __init__(self):
        self.preds = []
        self.targets = []
        self.SSIM = SSIM(window_size=11, size_average=False).eval()
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.FID = FIDScore().to(device)

    def add(self, outputs: Dict, inputs: Dict):
        pred = outputs['inpainted']
        target = inputs['image']
        self.preds.append(torch_nested_detach(pred))
        self.targets.append(torch_nested_detach(target))

    def evaluate(self):
        ssim_list = []
        for (pred, target) in zip(self.preds, self.targets):
            ssim_list.append(self.SSIM(pred, target))
            # accumulate Inception activations for every batch, not just the last
            self.FID(pred, target)
        ssim_list = torch_nested_numpify(ssim_list)
        fid = self.FID.get_value()
        return {MetricKeys.SSIM: np.mean(ssim_list), MetricKeys.FID: fid}
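A CPU-only smoke test of the metric on random tensors (shapes are arbitrary; with this few samples the FID number is statistically meaningless and only exercises the code path):

```python
import torch

from modelscope.metrics import ImageInpaintingMetric

metric = ImageInpaintingMetric()  # builds an InceptionV3 internally for FID
for _ in range(4):
    outputs = {'inpainted': torch.rand(2, 3, 256, 256)}
    inputs = {'image': torch.rand(2, 3, 256, 256)}
    metric.add(outputs, inputs)
print(metric.evaluate())  # {'ssim': ..., 'fid': ...}
```

Note that `add` keeps tensors on whatever device the trainer produced them, while `FID` is moved to CUDA when available, so on a GPU machine the two must end up on the same device.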
@@ -5,13 +5,14 @@ from . import (action_recognition, animal_recognition, body_2d_keypoints,
                body_3d_keypoints, cartoon, cmdssl_video_embedding,
                crowd_counting, face_2d_keypoints, face_detection,
                face_generation, image_classification, image_color_enhance,
-               image_colorization, image_denoise, image_instance_segmentation,
-               image_panoptic_segmentation, image_portrait_enhancement,
-               image_reid_person, image_semantic_segmentation,
-               image_to_image_generation, image_to_image_translation,
-               movie_scene_segmentation, object_detection,
-               product_retrieval_embedding, realtime_object_detection,
-               salient_detection, shop_segmentation, super_resolution,
-               video_single_object_tracking, video_summarization, virual_tryon)
+               image_colorization, image_denoise, image_inpainting,
+               image_instance_segmentation, image_panoptic_segmentation,
+               image_portrait_enhancement, image_reid_person,
+               image_semantic_segmentation, image_to_image_generation,
+               image_to_image_translation, movie_scene_segmentation,
+               object_detection, product_retrieval_embedding,
+               realtime_object_detection, salient_detection, shop_segmentation,
+               super_resolution, video_single_object_tracking,
+               video_summarization, virual_tryon)
 # yapf: enable
@@ -1,3 +1,5 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+
 import os
 from typing import Any, Dict, Optional, Union
@@ -1,10 +1,10 @@
-# ------------------------------------------------------------------------------
-# Copyright (c) Microsoft
-# Licensed under the MIT License.
-# Written by Bin Xiao (Bin.Xiao@microsoft.com)
-# Modified by Ke Sun (sunk@mail.ustc.edu.cn)
-# https://github.com/HRNet/HRNet-Image-Classification/blob/master/lib/models/cls_hrnet.py
-# ------------------------------------------------------------------------------
+"""
+Copyright (c) Microsoft
+Licensed under the MIT License.
+Written by Bin Xiao (Bin.Xiao@microsoft.com)
+Modified by Ke Sun (sunk@mail.ustc.edu.cn)
+https://github.com/HRNet/HRNet-Image-Classification/blob/master/lib/models/cls_hrnet.py
+"""
 import functools
 import logging
@@ -0,0 +1,22 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from typing import TYPE_CHECKING

from modelscope.utils.import_utils import LazyImportModule

if TYPE_CHECKING:
    from .model import FFTInpainting
else:
    _import_structure = {
        'model': ['FFTInpainting'],
    }

    import sys

    sys.modules[__name__] = LazyImportModule(
        __name__,
        globals()['__file__'],
        _import_structure,
        module_spec=__spec__,
        extra_objects={},
    )
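This mirrors the lazy-import convention used across the repo: instead of importing `model.py` (and its torch-heavy dependencies) eagerly, the package swaps itself for a `LazyImportModule` that resolves names from `_import_structure` on first attribute access:

```python
# Nothing heavy is imported at this point; '.model' is only loaded when
# FFTInpainting is first accessed through the lazy module.
from modelscope.models.cv.image_inpainting import FFTInpainting
```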
| @@ -0,0 +1,75 @@ | |||||
| """ | |||||
| Part of the implementation is borrowed and modified from LaMa, publicly available at | |||||
| https://github.com/saic-mdal/lama | |||||
| """ | |||||
| from typing import Dict, Tuple | |||||
| import torch | |||||
| import torch.nn as nn | |||||
| import torch.nn.functional as F | |||||
| from modelscope.utils.logger import get_logger | |||||
| from .modules.adversarial import NonSaturatingWithR1 | |||||
| from .modules.ffc import FFCResNetGenerator | |||||
| from .modules.perceptual import ResNetPL | |||||
| from .modules.pix2pixhd import NLayerDiscriminator | |||||
| LOGGER = get_logger() | |||||
| class BaseInpaintingTrainingModule(nn.Module): | |||||
| def __init__(self, | |||||
| model_dir='', | |||||
| use_ddp=True, | |||||
| predict_only=False, | |||||
| visualize_each_iters=100, | |||||
| average_generator=False, | |||||
| generator_avg_beta=0.999, | |||||
| average_generator_start_step=30000, | |||||
| average_generator_period=10, | |||||
| store_discr_outputs_for_vis=False, | |||||
| **kwargs): | |||||
| super().__init__() | |||||
| LOGGER.info( | |||||
| f'BaseInpaintingTrainingModule init called, predict_only is {predict_only}' | |||||
| ) | |||||
| self.generator = FFCResNetGenerator() | |||||
| self.use_ddp = use_ddp | |||||
| if not predict_only: | |||||
| self.discriminator = NLayerDiscriminator() | |||||
| self.adversarial_loss = NonSaturatingWithR1( | |||||
| weight=10, | |||||
| gp_coef=0.001, | |||||
| mask_as_fake_target=True, | |||||
| allow_scale_mask=True) | |||||
| self.average_generator = average_generator | |||||
| self.generator_avg_beta = generator_avg_beta | |||||
| self.average_generator_start_step = average_generator_start_step | |||||
| self.average_generator_period = average_generator_period | |||||
| self.generator_average = None | |||||
| self.last_generator_averaging_step = -1 | |||||
| self.store_discr_outputs_for_vis = store_discr_outputs_for_vis | |||||
| self.loss_l1 = nn.L1Loss(reduction='none') | |||||
| self.loss_resnet_pl = ResNetPL(weight=30, weights_path=model_dir) | |||||
| self.visualize_each_iters = visualize_each_iters | |||||
| LOGGER.info('BaseInpaintingTrainingModule init done') | |||||
| def forward(self, batch: Dict[str, | |||||
| torch.Tensor]) -> Dict[str, torch.Tensor]: | |||||
| """Pass data through generator and obtain at leas 'predicted_image' and 'inpainted' keys""" | |||||
| raise NotImplementedError() | |||||
| def generator_loss(self, | |||||
| batch) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]: | |||||
| raise NotImplementedError() | |||||
| def discriminator_loss( | |||||
| self, batch) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]: | |||||
| raise NotImplementedError() | |||||
@@ -0,0 +1,210 @@
"""
Part of the implementation is borrowed and modified from LaMa, publicly available at
https://github.com/saic-mdal/lama
"""
import bisect

import torch
import torch.nn.functional as F

from modelscope.utils.logger import get_logger
from .base import BaseInpaintingTrainingModule
from .modules.feature_matching import feature_matching_loss, masked_l1_loss

LOGGER = get_logger()


def set_requires_grad(module, value):
    for param in module.parameters():
        param.requires_grad = value


def add_prefix_to_keys(dct, prefix):
    return {prefix + k: v for k, v in dct.items()}


class LinearRamp:

    def __init__(self, start_value=0, end_value=1, start_iter=-1, end_iter=0):
        self.start_value = start_value
        self.end_value = end_value
        self.start_iter = start_iter
        self.end_iter = end_iter

    def __call__(self, i):
        if i < self.start_iter:
            return self.start_value
        if i >= self.end_iter:
            return self.end_value
        part = (i - self.start_iter) / (self.end_iter - self.start_iter)
        return self.start_value * (1 - part) + self.end_value * part


class LadderRamp:

    def __init__(self, start_iters, values):
        self.start_iters = start_iters
        self.values = values
        assert len(values) == len(start_iters) + 1, (len(values),
                                                     len(start_iters))

    def __call__(self, i):
        segment_i = bisect.bisect_right(self.start_iters, i)
        return self.values[segment_i]


def get_ramp(kind='ladder', **kwargs):
    if kind == 'linear':
        return LinearRamp(**kwargs)
    if kind == 'ladder':
        return LadderRamp(**kwargs)
    raise ValueError(f'Unexpected ramp kind: {kind}')
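A quick illustration of the two ramp schedules, using the classes defined above (the values are arbitrary):

```python
# LinearRamp interpolates between start_value and end_value on
# [start_iter, end_iter); outside that window it clamps to the endpoints.
ramp = LinearRamp(start_value=0.0, end_value=1.0, start_iter=100, end_iter=200)
assert ramp(50) == 0.0 and ramp(150) == 0.5 and ramp(300) == 1.0

# LadderRamp is a step function: len(values) must be len(start_iters) + 1.
ladder = LadderRamp(start_iters=[1000, 5000], values=[0.1, 0.5, 1.0])
assert ladder(0) == 0.1 and ladder(2000) == 0.5 and ladder(9000) == 1.0
```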
class DefaultInpaintingTrainingModule(BaseInpaintingTrainingModule):

    def __init__(self,
                 model_dir='',
                 predict_only=False,
                 concat_mask=True,
                 rescale_scheduler_kwargs=None,
                 image_to_discriminator='predicted_image',
                 add_noise_kwargs=None,
                 noise_fill_hole=False,
                 const_area_crop_kwargs=None,
                 distance_weighter_kwargs=None,
                 distance_weighted_mask_for_discr=False,
                 fake_fakes_proba=0,
                 fake_fakes_generator_kwargs=None,
                 **kwargs):
        super().__init__(model_dir=model_dir, predict_only=predict_only)
        self.concat_mask = concat_mask
        self.rescale_size_getter = get_ramp(
            **rescale_scheduler_kwargs
        ) if rescale_scheduler_kwargs is not None else None
        self.image_to_discriminator = image_to_discriminator
        self.add_noise_kwargs = add_noise_kwargs
        self.noise_fill_hole = noise_fill_hole
        self.const_area_crop_kwargs = const_area_crop_kwargs
        self.refine_mask_for_losses = None
        self.distance_weighted_mask_for_discr = distance_weighted_mask_for_discr
        self.feature_matching_weight = 100
        self.losses_l1_weight_known = 10
        self.losses_l1_weight_missing = 0
        self.fake_fakes_proba = fake_fakes_proba

    def forward(self, batch):
        img = batch['image']
        mask = batch['mask']
        masked_img = img * (1 - mask)
        if self.concat_mask:
            masked_img = torch.cat([masked_img, mask], dim=1)
        batch['predicted_image'] = self.generator(masked_img)
        batch['inpainted'] = mask * batch['predicted_image'] + (
            1 - mask) * batch['image']
        batch['mask_for_losses'] = mask
        return batch
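Note the mask convention in `forward`: `mask == 1` marks pixels to synthesize, and the composited `inpainted` output keeps ground-truth pixels wherever `mask == 0`. A toy check, assuming `module` is an already-constructed `DefaultInpaintingTrainingModule`:

```python
import torch

batch = {
    'image': torch.rand(1, 3, 256, 256),
    'mask': (torch.rand(1, 1, 256, 256) > 0.5).float(),
}
out = module(batch)
# Known pixels (mask == 0) must pass through the compositing step unchanged.
known = (1 - batch['mask']).bool().expand_as(batch['image'])
assert torch.equal(out['inpainted'][known], batch['image'][known])
```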
    def generator_loss(self, batch):
        img = batch['image']
        predicted_img = batch[self.image_to_discriminator]
        original_mask = batch['mask']
        supervised_mask = batch['mask_for_losses']

        # L1
        l1_value = masked_l1_loss(predicted_img, img, supervised_mask,
                                  self.losses_l1_weight_known,
                                  self.losses_l1_weight_missing)

        total_loss = l1_value
        metrics = dict(gen_l1=l1_value)

        # discriminator
        # adversarial_loss calls backward by itself
        mask_for_discr = supervised_mask if self.distance_weighted_mask_for_discr else original_mask
        self.adversarial_loss.pre_generator_step(
            real_batch=img,
            fake_batch=predicted_img,
            generator=self.generator,
            discriminator=self.discriminator)
        discr_real_pred, discr_real_features = self.discriminator(img)
        discr_fake_pred, discr_fake_features = self.discriminator(
            predicted_img)
        adv_gen_loss, adv_metrics = self.adversarial_loss.generator_loss(
            real_batch=img,
            fake_batch=predicted_img,
            discr_real_pred=discr_real_pred,
            discr_fake_pred=discr_fake_pred,
            mask=mask_for_discr)
        total_loss = total_loss + adv_gen_loss
        metrics['gen_adv'] = adv_gen_loss
        metrics.update(add_prefix_to_keys(adv_metrics, 'adv_'))

        # feature matching
        if self.feature_matching_weight > 0:
            need_mask_in_fm = False
            mask_for_fm = supervised_mask if need_mask_in_fm else None
            fm_value = feature_matching_loss(
                discr_fake_features, discr_real_features,
                mask=mask_for_fm) * self.feature_matching_weight
            total_loss = total_loss + fm_value
            metrics['gen_fm'] = fm_value

        if self.loss_resnet_pl is not None:
            resnet_pl_value = self.loss_resnet_pl(predicted_img, img)
            total_loss = total_loss + resnet_pl_value
            metrics['gen_resnet_pl'] = resnet_pl_value

        return total_loss, metrics

    def discriminator_loss(self, batch):
        total_loss = 0
        metrics = {}

        predicted_img = batch[self.image_to_discriminator].detach()
        self.adversarial_loss.pre_discriminator_step(
            real_batch=batch['image'],
            fake_batch=predicted_img,
            generator=self.generator,
            discriminator=self.discriminator)
        discr_real_pred, discr_real_features = self.discriminator(
            batch['image'])
        discr_fake_pred, discr_fake_features = self.discriminator(
            predicted_img)
        adv_discr_loss, adv_metrics = self.adversarial_loss.discriminator_loss(
            real_batch=batch['image'],
            fake_batch=predicted_img,
            discr_real_pred=discr_real_pred,
            discr_fake_pred=discr_fake_pred,
            mask=batch['mask'])
        total_loss = (total_loss + adv_discr_loss) * 0.1
        metrics['discr_adv'] = adv_discr_loss
        metrics.update(add_prefix_to_keys(adv_metrics, 'adv_'))

        return total_loss, metrics

    def _do_step(self, batch, optimizer_idx=None):
        if optimizer_idx == 0:  # step for generator
            set_requires_grad(self.generator, True)
            set_requires_grad(self.discriminator, False)
        elif optimizer_idx == 1:  # step for discriminator
            set_requires_grad(self.generator, False)
            set_requires_grad(self.discriminator, True)

        batch = self(batch)
        total_loss = 0
        if optimizer_idx is None or optimizer_idx == 0:  # step for generator
            total_loss, metrics = self.generator_loss(batch)
        elif optimizer_idx is None or optimizer_idx == 1:  # step for discriminator
            total_loss, metrics = self.discriminator_loss(batch)

        result = dict(loss=total_loss)
        return result
@@ -0,0 +1,36 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os
from typing import Any, Dict, Optional, Union

import torch

from modelscope.metainfo import Models
from modelscope.models.base.base_torch_model import TorchModel
from modelscope.models.builder import MODELS
from modelscope.utils.constant import ModelFile, Tasks
from modelscope.utils.logger import get_logger

LOGGER = get_logger()


@MODELS.register_module(
    Tasks.image_inpainting, module_name=Models.image_inpainting)
class FFTInpainting(TorchModel):

    def __init__(self, model_dir: str, **kwargs):
        super().__init__(model_dir, **kwargs)
        from .default import DefaultInpaintingTrainingModule
        pretrained = kwargs.get('pretrained', True)
        predict_only = kwargs.get('predict_only', False)
        net = DefaultInpaintingTrainingModule(
            model_dir=model_dir, predict_only=predict_only)
        if pretrained:
            path = os.path.join(model_dir, ModelFile.TORCH_MODEL_FILE)
            LOGGER.info(f'loading pretrained model from {path}')
            state = torch.load(path, map_location='cpu')
            net.load_state_dict(state, strict=False)
        self.model = net

    def forward(self, inputs):
        return self.model(inputs)
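A hedged inference sketch for the wrapper above; the directory is a placeholder and must contain the checkpoint named by `ModelFile.TORCH_MODEL_FILE`:

```python
import torch

from modelscope.models.cv.image_inpainting import FFTInpainting

model = FFTInpainting(model_dir='/path/to/model', predict_only=True)
model.eval()
batch = {
    'image': torch.rand(1, 3, 256, 256),
    'mask': torch.zeros(1, 1, 256, 256),  # empty mask: nothing to inpaint
}
with torch.no_grad():
    out = model(batch)
print(out['inpainted'].shape)  # torch.Size([1, 3, 256, 256])
```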
@@ -0,0 +1,2 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from .base import ModelBuilder
@@ -0,0 +1,380 @@
"""
Part of the implementation is borrowed and modified from LaMa, publicly available at
https://github.com/saic-mdal/lama
"""
import os

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.modules import BatchNorm2d

from . import resnet

NUM_CLASS = 150


# Model Builder
class ModelBuilder:
    # custom weights initialization
    @staticmethod
    def weights_init(m):
        classname = m.__class__.__name__
        if classname.find('Conv') != -1:
            nn.init.kaiming_normal_(m.weight.data)
        elif classname.find('BatchNorm') != -1:
            m.weight.data.fill_(1.)
            m.bias.data.fill_(1e-4)

    @staticmethod
    def build_encoder(arch='resnet50dilated',
                      fc_dim=512,
                      weights='',
                      model_dir=''):
        pretrained = True if len(weights) == 0 else False
        arch = arch.lower()
        if arch == 'resnet50dilated':
            orig_resnet = resnet.__dict__['resnet50'](
                pretrained=pretrained, model_dir=model_dir)
            net_encoder = ResnetDilated(orig_resnet, dilate_scale=8)
        elif arch == 'resnet50':
            orig_resnet = resnet.__dict__['resnet50'](
                pretrained=pretrained, model_dir=model_dir)
            net_encoder = Resnet(orig_resnet)
        else:
            raise Exception('Architecture undefined!')

        # encoders are usually pretrained
        # net_encoder.apply(ModelBuilder.weights_init)
        if len(weights) > 0:
            print('Loading weights for net_encoder')
            net_encoder.load_state_dict(
                torch.load(weights, map_location=lambda storage, loc: storage),
                strict=False)
        return net_encoder

    @staticmethod
    def build_decoder(arch='ppm_deepsup',
                      fc_dim=512,
                      num_class=NUM_CLASS,
                      weights='',
                      use_softmax=False,
                      drop_last_conv=False):
        arch = arch.lower()
        if arch == 'ppm_deepsup':
            net_decoder = PPMDeepsup(
                num_class=num_class,
                fc_dim=fc_dim,
                use_softmax=use_softmax,
                drop_last_conv=drop_last_conv)
        elif arch == 'c1_deepsup':
            net_decoder = C1DeepSup(
                num_class=num_class,
                fc_dim=fc_dim,
                use_softmax=use_softmax,
                drop_last_conv=drop_last_conv)
        else:
            raise Exception('Architecture undefined!')

        net_decoder.apply(ModelBuilder.weights_init)
        if len(weights) > 0:
            print('Loading weights for net_decoder')
            net_decoder.load_state_dict(
                torch.load(weights, map_location=lambda storage, loc: storage),
                strict=False)
        return net_decoder

    @staticmethod
    def get_decoder(weights_path, arch_encoder, arch_decoder, fc_dim,
                    drop_last_conv, *arts, **kwargs):
        path = os.path.join(
            weights_path, 'ade20k',
            f'ade20k-{arch_encoder}-{arch_decoder}/decoder_epoch_20.pth')
        return ModelBuilder.build_decoder(
            arch=arch_decoder,
            fc_dim=fc_dim,
            weights=path,
            use_softmax=True,
            drop_last_conv=drop_last_conv)

    @staticmethod
    def get_encoder(weights_path, arch_encoder, arch_decoder, fc_dim,
                    segmentation, *arts, **kwargs):
        if segmentation:
            path = os.path.join(
                weights_path, 'ade20k',
                f'ade20k-{arch_encoder}-{arch_decoder}/encoder_epoch_20.pth')
        else:
            path = ''
        return ModelBuilder.build_encoder(
            arch=arch_encoder,
            fc_dim=fc_dim,
            weights=path,
            model_dir=weights_path)
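The `get_encoder`/`get_decoder` helpers are presumably what the perceptual-loss code calls; the checkpoint layout `ade20k/ade20k-<encoder>-<decoder>/{encoder,decoder}_epoch_20.pth` follows directly from the path construction above. A sketch with a placeholder weights directory:

```python
encoder = ModelBuilder.get_encoder(
    weights_path='/path/to/weights',  # placeholder directory
    arch_encoder='resnet50dilated',
    arch_decoder='ppm_deepsup',
    fc_dim=2048,
    segmentation=True)  # loads .../encoder_epoch_20.pth when True
```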
def conv3x3_bn_relu(in_planes, out_planes, stride=1):
    return nn.Sequential(
        nn.Conv2d(
            in_planes,
            out_planes,
            kernel_size=3,
            stride=stride,
            padding=1,
            bias=False),
        BatchNorm2d(out_planes),
        nn.ReLU(inplace=True),
    )


# pyramid pooling, deep supervision
class PPMDeepsup(nn.Module):

    def __init__(self,
                 num_class=NUM_CLASS,
                 fc_dim=4096,
                 use_softmax=False,
                 pool_scales=(1, 2, 3, 6),
                 drop_last_conv=False):
        super().__init__()
        self.use_softmax = use_softmax
        self.drop_last_conv = drop_last_conv

        self.ppm = []
        for scale in pool_scales:
            self.ppm.append(
                nn.Sequential(
                    nn.AdaptiveAvgPool2d(scale),
                    nn.Conv2d(fc_dim, 512, kernel_size=1, bias=False),
                    BatchNorm2d(512), nn.ReLU(inplace=True)))
        self.ppm = nn.ModuleList(self.ppm)
        self.cbr_deepsup = conv3x3_bn_relu(fc_dim // 2, fc_dim // 4, 1)

        self.conv_last = nn.Sequential(
            nn.Conv2d(
                fc_dim + len(pool_scales) * 512,
                512,
                kernel_size=3,
                padding=1,
                bias=False), BatchNorm2d(512), nn.ReLU(inplace=True),
            nn.Dropout2d(0.1), nn.Conv2d(512, num_class, kernel_size=1))
        self.conv_last_deepsup = nn.Conv2d(fc_dim // 4, num_class, 1, 1, 0)
        self.dropout_deepsup = nn.Dropout2d(0.1)

    def forward(self, conv_out, segSize=None):
        conv5 = conv_out[-1]

        input_size = conv5.size()
        ppm_out = [conv5]
        for pool_scale in self.ppm:
            ppm_out.append(
                nn.functional.interpolate(
                    pool_scale(conv5), (input_size[2], input_size[3]),
                    mode='bilinear',
                    align_corners=False))
        ppm_out = torch.cat(ppm_out, 1)

        if self.drop_last_conv:
            return ppm_out
        else:
            x = self.conv_last(ppm_out)

            if self.use_softmax:  # is True during inference
                x = nn.functional.interpolate(
                    x, size=segSize, mode='bilinear', align_corners=False)
                x = nn.functional.softmax(x, dim=1)
                return x

            # deep sup
            conv4 = conv_out[-2]
            _ = self.cbr_deepsup(conv4)
            _ = self.dropout_deepsup(_)
            _ = self.conv_last_deepsup(_)

            x = nn.functional.log_softmax(x, dim=1)
            _ = nn.functional.log_softmax(_, dim=1)
            return (x, _)


class Resnet(nn.Module):

    def __init__(self, orig_resnet):
        super(Resnet, self).__init__()

        # take pretrained resnet, except AvgPool and FC
        self.conv1 = orig_resnet.conv1
        self.bn1 = orig_resnet.bn1
        self.relu1 = orig_resnet.relu1
        self.conv2 = orig_resnet.conv2
        self.bn2 = orig_resnet.bn2
        self.relu2 = orig_resnet.relu2
        self.conv3 = orig_resnet.conv3
        self.bn3 = orig_resnet.bn3
        self.relu3 = orig_resnet.relu3
        self.maxpool = orig_resnet.maxpool
        self.layer1 = orig_resnet.layer1
        self.layer2 = orig_resnet.layer2
        self.layer3 = orig_resnet.layer3
        self.layer4 = orig_resnet.layer4

    def forward(self, x, return_feature_maps=False):
        conv_out = []

        x = self.relu1(self.bn1(self.conv1(x)))
        x = self.relu2(self.bn2(self.conv2(x)))
        x = self.relu3(self.bn3(self.conv3(x)))
        x = self.maxpool(x)

        x = self.layer1(x)
        conv_out.append(x)
        x = self.layer2(x)
        conv_out.append(x)
        x = self.layer3(x)
        conv_out.append(x)
        x = self.layer4(x)
        conv_out.append(x)

        if return_feature_maps:
            return conv_out
        return [x]


# Resnet Dilated
class ResnetDilated(nn.Module):

    def __init__(self, orig_resnet, dilate_scale=8):
        super().__init__()
        from functools import partial

        if dilate_scale == 8:
            orig_resnet.layer3.apply(partial(self._nostride_dilate, dilate=2))
            orig_resnet.layer4.apply(partial(self._nostride_dilate, dilate=4))
        elif dilate_scale == 16:
            orig_resnet.layer4.apply(partial(self._nostride_dilate, dilate=2))

        # take pretrained resnet, except AvgPool and FC
        self.conv1 = orig_resnet.conv1
        self.bn1 = orig_resnet.bn1
        self.relu1 = orig_resnet.relu1
        self.conv2 = orig_resnet.conv2
        self.bn2 = orig_resnet.bn2
        self.relu2 = orig_resnet.relu2
        self.conv3 = orig_resnet.conv3
        self.bn3 = orig_resnet.bn3
        self.relu3 = orig_resnet.relu3
        self.maxpool = orig_resnet.maxpool
        self.layer1 = orig_resnet.layer1
        self.layer2 = orig_resnet.layer2
        self.layer3 = orig_resnet.layer3
        self.layer4 = orig_resnet.layer4

    def _nostride_dilate(self, m, dilate):
        classname = m.__class__.__name__
        if classname.find('Conv') != -1:
            # the convolution with stride
            if m.stride == (2, 2):
                m.stride = (1, 1)
                if m.kernel_size == (3, 3):
                    m.dilation = (dilate // 2, dilate // 2)
                    m.padding = (dilate // 2, dilate // 2)
            # other convolutions
            else:
                if m.kernel_size == (3, 3):
                    m.dilation = (dilate, dilate)
                    m.padding = (dilate, dilate)

    def forward(self, x, return_feature_maps=False):
        conv_out = []

        x = self.relu1(self.bn1(self.conv1(x)))
        x = self.relu2(self.bn2(self.conv2(x)))
        x = self.relu3(self.bn3(self.conv3(x)))
        x = self.maxpool(x)

        x = self.layer1(x)
        conv_out.append(x)
        x = self.layer2(x)
        conv_out.append(x)
        x = self.layer3(x)
        conv_out.append(x)
        x = self.layer4(x)
        conv_out.append(x)

        if return_feature_maps:
            return conv_out
        return [x]


# last conv, deep supervision
class C1DeepSup(nn.Module):

    def __init__(self,
                 num_class=150,
                 fc_dim=2048,
                 use_softmax=False,
                 drop_last_conv=False):
        super(C1DeepSup, self).__init__()
        self.use_softmax = use_softmax
        self.drop_last_conv = drop_last_conv

        self.cbr = conv3x3_bn_relu(fc_dim, fc_dim // 4, 1)
        self.cbr_deepsup = conv3x3_bn_relu(fc_dim // 2, fc_dim // 4, 1)

        # last conv
        self.conv_last = nn.Conv2d(fc_dim // 4, num_class, 1, 1, 0)
        self.conv_last_deepsup = nn.Conv2d(fc_dim // 4, num_class, 1, 1, 0)

    def forward(self, conv_out, segSize=None):
        conv5 = conv_out[-1]
        x = self.cbr(conv5)

        if self.drop_last_conv:
            return x
        else:
            x = self.conv_last(x)

            if self.use_softmax:  # is True during inference
                x = nn.functional.interpolate(
                    x, size=segSize, mode='bilinear', align_corners=False)
                x = nn.functional.softmax(x, dim=1)
                return x

            # deep sup
            conv4 = conv_out[-2]
            _ = self.cbr_deepsup(conv4)
            _ = self.conv_last_deepsup(_)

            x = nn.functional.log_softmax(x, dim=1)
            _ = nn.functional.log_softmax(_, dim=1)
            return (x, _)


# last conv
class C1(nn.Module):

    def __init__(self, num_class=150, fc_dim=2048, use_softmax=False):
        super(C1, self).__init__()
        self.use_softmax = use_softmax

        self.cbr = conv3x3_bn_relu(fc_dim, fc_dim // 4, 1)

        # last conv
        self.conv_last = nn.Conv2d(fc_dim // 4, num_class, 1, 1, 0)

    def forward(self, conv_out, segSize=None):
        conv5 = conv_out[-1]
        x = self.cbr(conv5)
        x = self.conv_last(x)

        if self.use_softmax:  # is True during inference
            x = nn.functional.interpolate(
                x, size=segSize, mode='bilinear', align_corners=False)
            x = nn.functional.softmax(x, dim=1)
        else:
            x = nn.functional.log_softmax(x, dim=1)

        return x
@@ -0,0 +1,183 @@
"""
Part of the implementation is borrowed and modified from LaMa, publicly available at
https://github.com/saic-mdal/lama
"""
import math
import os

import torch
import torch.nn as nn
from torch.nn import BatchNorm2d

__all__ = ['ResNet', 'resnet50']


def conv3x3(in_planes, out_planes, stride=1):
    '3x3 convolution with padding'
    return nn.Conv2d(
        in_planes,
        out_planes,
        kernel_size=3,
        stride=stride,
        padding=1,
        bias=False)


class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(BasicBlock, self).__init__()
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = BatchNorm2d(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out


class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(
            planes,
            planes,
            kernel_size=3,
            stride=stride,
            padding=1,
            bias=False)
        self.bn2 = BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
        self.bn3 = BatchNorm2d(planes * 4)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out


class ResNet(nn.Module):

    def __init__(self, block, layers, num_classes=1000):
        self.inplanes = 128
        super(ResNet, self).__init__()
        self.conv1 = conv3x3(3, 64, stride=2)
        self.bn1 = BatchNorm2d(64)
        self.relu1 = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(64, 64)
        self.bn2 = BatchNorm2d(64)
        self.relu2 = nn.ReLU(inplace=True)
        self.conv3 = conv3x3(64, 128)
        self.bn3 = BatchNorm2d(128)
        self.relu3 = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
        self.avgpool = nn.AvgPool2d(7, stride=1)
        self.fc = nn.Linear(512 * block.expansion, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def _make_layer(self, block, planes, blocks, stride=1):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(
                    self.inplanes,
                    planes * block.expansion,
                    kernel_size=1,
                    stride=stride,
                    bias=False),
                BatchNorm2d(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample))
        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes))

        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.relu1(self.bn1(self.conv1(x)))
        x = self.relu2(self.bn2(self.conv2(x)))
        x = self.relu3(self.bn3(self.conv3(x)))
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)

        return x


def resnet50(pretrained=False, model_dir='', **kwargs):
    """Constructs a ResNet-50 model.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
    if pretrained:
        cached_file = os.path.join(model_dir, 'resnet50-imagenet.pth')
        model.load_state_dict(
            torch.load(cached_file, map_location='cpu'), strict=False)
    return model
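Note this is the deep-stem ResNet-50 variant: three stacked 3x3 convolutions replace the usual single 7x7 stem, which is why `inplanes` starts at 128. A quick shape check:

```python
import torch

net = resnet50(pretrained=False)
out = net(torch.rand(1, 3, 224, 224))
print(out.shape)  # torch.Size([1, 1000])
```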
@@ -0,0 +1,167 @@
"""
Part of the implementation is borrowed and modified from LaMa, publicly available at
https://github.com/saic-mdal/lama
"""
from typing import Dict, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F


class BaseAdversarialLoss:

    def pre_generator_step(self, real_batch: torch.Tensor,
                           fake_batch: torch.Tensor, generator: nn.Module,
                           discriminator: nn.Module):
        """
        Prepare for generator step
        :param real_batch: Tensor, a batch of real samples
        :param fake_batch: Tensor, a batch of samples produced by generator
        :param generator:
        :param discriminator:
        :return: None
        """

    def pre_discriminator_step(self, real_batch: torch.Tensor,
                               fake_batch: torch.Tensor, generator: nn.Module,
                               discriminator: nn.Module):
        """
        Prepare for discriminator step
        :param real_batch: Tensor, a batch of real samples
        :param fake_batch: Tensor, a batch of samples produced by generator
        :param generator:
        :param discriminator:
        :return: None
        """

    def generator_loss(self, real_batch: torch.Tensor, fake_batch: torch.Tensor,
                       discr_real_pred: torch.Tensor, discr_fake_pred: torch.Tensor,
                       mask: Optional[torch.Tensor] = None) \
            -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        """
        Calculate generator loss
        :param real_batch: Tensor, a batch of real samples
        :param fake_batch: Tensor, a batch of samples produced by generator
        :param discr_real_pred: Tensor, discriminator output for real_batch
        :param discr_fake_pred: Tensor, discriminator output for fake_batch
        :param mask: Tensor, actual mask, which was at input of generator when making fake_batch
        :return: total generator loss along with some values that might be interesting to log
        """
        raise NotImplementedError

    def discriminator_loss(self, real_batch: torch.Tensor, fake_batch: torch.Tensor,
                           discr_real_pred: torch.Tensor, discr_fake_pred: torch.Tensor,
                           mask: Optional[torch.Tensor] = None) \
            -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        """
        Calculate discriminator loss and call .backward() on it
        :param real_batch: Tensor, a batch of real samples
        :param fake_batch: Tensor, a batch of samples produced by generator
        :param discr_real_pred: Tensor, discriminator output for real_batch
        :param discr_fake_pred: Tensor, discriminator output for fake_batch
        :param mask: Tensor, actual mask, which was at input of generator when making fake_batch
        :return: total discriminator loss along with some values that might be interesting to log
        """
        raise NotImplementedError

    def interpolate_mask(self, mask, shape):
        assert mask is not None
        assert self.allow_scale_mask or shape == mask.shape[-2:]
        if shape != mask.shape[-2:] and self.allow_scale_mask:
            if self.mask_scale_mode == 'maxpool':
                mask = F.adaptive_max_pool2d(mask, shape)
            else:
                mask = F.interpolate(
                    mask, size=shape, mode=self.mask_scale_mode)
        return mask


def make_r1_gp(discr_real_pred, real_batch):
    if torch.is_grad_enabled():
        grad_real = torch.autograd.grad(
            outputs=discr_real_pred.sum(),
            inputs=real_batch,
            create_graph=True)[0]
        grad_penalty = (grad_real.view(grad_real.shape[0],
                                       -1).norm(2, dim=1)**2).mean()
    else:
        grad_penalty = 0
    real_batch.requires_grad = False

    return grad_penalty
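`make_r1_gp` computes the R1 gradient penalty of Mescheder et al. (2018): the squared gradient norm of the discriminator at real samples,

```latex
R_1 = \mathbb{E}_{x \sim p_{\mathrm{real}}}\bigl[\lVert\nabla_x D(x)\rVert_2^2\bigr]
```

which is why `pre_discriminator_step` in `NonSaturatingWithR1` below sets `requires_grad = True` on the real batch: without it, the `torch.autograd.grad` call here would fail.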
class NonSaturatingWithR1(BaseAdversarialLoss):

    def __init__(self,
                 gp_coef=5,
                 weight=1,
                 mask_as_fake_target=False,
                 allow_scale_mask=False,
                 mask_scale_mode='nearest',
                 extra_mask_weight_for_gen=0,
                 use_unmasked_for_gen=True,
                 use_unmasked_for_discr=True):
        self.gp_coef = gp_coef
        self.weight = weight
        # use for discr => use for gen;
        # otherwise we teach only the discr to pay attention to very small difference
        assert use_unmasked_for_gen or (not use_unmasked_for_discr)
        # mask as target => use unmasked for discr:
        # if we don't care about unmasked regions at all
        # then it doesn't matter if the value of mask_as_fake_target is true or false
        assert use_unmasked_for_discr or (not mask_as_fake_target)
        self.use_unmasked_for_gen = use_unmasked_for_gen
        self.use_unmasked_for_discr = use_unmasked_for_discr
        self.mask_as_fake_target = mask_as_fake_target
        self.allow_scale_mask = allow_scale_mask
        self.mask_scale_mode = mask_scale_mode
        self.extra_mask_weight_for_gen = extra_mask_weight_for_gen

    def generator_loss(self, real_batch: torch.Tensor, fake_batch: torch.Tensor,
                       discr_real_pred: torch.Tensor, discr_fake_pred: torch.Tensor,
                       mask=None) \
            -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        fake_loss = F.softplus(-discr_fake_pred)
        if (self.mask_as_fake_target and self.extra_mask_weight_for_gen > 0) or \
                not self.use_unmasked_for_gen:  # == if masked region should be treated differently
            mask = self.interpolate_mask(mask, discr_fake_pred.shape[-2:])
            if not self.use_unmasked_for_gen:
                fake_loss = fake_loss * mask
            else:
                pixel_weights = 1 + mask * self.extra_mask_weight_for_gen
                fake_loss = fake_loss * pixel_weights

        return fake_loss.mean() * self.weight, dict()

    def pre_discriminator_step(self, real_batch: torch.Tensor,
                               fake_batch: torch.Tensor, generator: nn.Module,
                               discriminator: nn.Module):
        real_batch.requires_grad = True

    def discriminator_loss(self, real_batch: torch.Tensor, fake_batch: torch.Tensor,
                           discr_real_pred: torch.Tensor, discr_fake_pred: torch.Tensor,
                           mask=None) \
            -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        real_loss = F.softplus(-discr_real_pred)
        grad_penalty = make_r1_gp(discr_real_pred, real_batch) * self.gp_coef
        fake_loss = F.softplus(discr_fake_pred)

        if not self.use_unmasked_for_discr or self.mask_as_fake_target:
            # == if masked region should be treated differently
            mask = self.interpolate_mask(mask, discr_fake_pred.shape[-2:])
            # use_unmasked_for_discr=False only makes sense for fakes;
            # for reals there is no difference between the two regions
            fake_loss = fake_loss * mask
            if self.mask_as_fake_target:
                fake_loss = fake_loss + (1
                                         - mask) * F.softplus(-discr_fake_pred)

        sum_discr_loss = real_loss + grad_penalty + fake_loss
        metrics = dict(
            discr_real_out=discr_real_pred.mean(),
            discr_fake_out=discr_fake_pred.mean(),
            discr_real_gp=grad_penalty)

        return sum_discr_loss.mean(), metrics
@@ -0,0 +1,45 @@
"""
Part of the implementation is borrowed and modified from LaMa, publicly available at
https://github.com/saic-mdal/lama
"""
from typing import List

import torch
import torch.nn.functional as F


def masked_l2_loss(pred, target, mask, weight_known, weight_missing):
    per_pixel_l2 = F.mse_loss(pred, target, reduction='none')
    pixel_weights = mask * weight_missing + (1 - mask) * weight_known
    return (pixel_weights * per_pixel_l2).mean()


def masked_l1_loss(pred, target, mask, weight_known, weight_missing):
    per_pixel_l1 = F.l1_loss(pred, target, reduction='none')
    pixel_weights = mask * weight_missing + (1 - mask) * weight_known
    return (pixel_weights * per_pixel_l1).mean()


def feature_matching_loss(fake_features: List[torch.Tensor],
                          target_features: List[torch.Tensor],
                          mask=None):
    if mask is None:
        res = torch.stack([
            F.mse_loss(fake_feat, target_feat)
            for fake_feat, target_feat in zip(fake_features, target_features)
        ]).mean()
    else:
        res = 0
        norm = 0
        for fake_feat, target_feat in zip(fake_features, target_features):
            cur_mask = F.interpolate(
                mask,
                size=fake_feat.shape[-2:],
                mode='bilinear',
                align_corners=False)
            error_weights = 1 - cur_mask
            cur_val = ((fake_feat - target_feat).pow(2) * error_weights).mean()
            res = res + cur_val
            norm += 1
        res = res / norm
    return res
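A small numeric check of `masked_l1_loss` under the weights `DefaultInpaintingTrainingModule` uses (`weight_known=10`, `weight_missing=0`), which make the L1 term supervise only known pixels and leave the holes to the adversarial and perceptual terms:

```python
import torch

pred = torch.zeros(1, 1, 2, 2)
target = torch.ones(1, 1, 2, 2)
mask = torch.tensor([[[[1., 0.], [0., 0.]]]])  # top-left pixel is the hole

# per-pixel |pred - target| is 1 everywhere; pixel weights are
# [0, 10, 10, 10], so the mean is 30 / 4 = 7.5
loss = masked_l1_loss(pred, target, mask, weight_known=10, weight_missing=0)
assert loss.item() == 7.5
```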
| @@ -0,0 +1,588 @@ | |||||
| """ | |||||
| Part of the implementation is borrowed and modified from LaMa, publicly available at | |||||
| https://github.com/saic-mdal/lama | |||||
| """ | |||||
| import numpy as np | |||||
| import torch | |||||
| import torch.nn as nn | |||||
| import torch.nn.functional as F | |||||
| from kornia.geometry.transform import rotate | |||||
| def get_activation(kind='tanh'): | |||||
| if kind == 'tanh': | |||||
| return nn.Tanh() | |||||
| if kind == 'sigmoid': | |||||
| return nn.Sigmoid() | |||||
| if kind is False: | |||||
| return nn.Identity() | |||||
| raise ValueError(f'Unknown activation kind {kind}') | |||||
| class SELayer(nn.Module): | |||||
| def __init__(self, channel, reduction=16): | |||||
| super(SELayer, self).__init__() | |||||
| self.avg_pool = nn.AdaptiveAvgPool2d(1) | |||||
| self.fc = nn.Sequential( | |||||
| nn.Linear(channel, channel // reduction, bias=False), | |||||
| nn.ReLU(inplace=True), | |||||
| nn.Linear(channel // reduction, channel, bias=False), nn.Sigmoid()) | |||||
| def forward(self, x): | |||||
| b, c, _, _ = x.size() | |||||
| y = self.avg_pool(x).view(b, c) | |||||
| y = self.fc(y).view(b, c, 1, 1) | |||||
| res = x * y.expand_as(x) | |||||
| return res | |||||
| class FourierUnit(nn.Module): | |||||
| def __init__(self, | |||||
| in_channels, | |||||
| out_channels, | |||||
| groups=1, | |||||
| spatial_scale_factor=None, | |||||
| spatial_scale_mode='bilinear', | |||||
| spectral_pos_encoding=False, | |||||
| use_se=False, | |||||
| se_kwargs=None, | |||||
| ffc3d=False, | |||||
| fft_norm='ortho'): | |||||
| # bn_layer not used | |||||
| super(FourierUnit, self).__init__() | |||||
| self.groups = groups | |||||
| self.conv_layer = torch.nn.Conv2d( | |||||
| in_channels=in_channels * 2 + (2 if spectral_pos_encoding else 0), | |||||
| out_channels=out_channels * 2, | |||||
| kernel_size=1, | |||||
| stride=1, | |||||
| padding=0, | |||||
| groups=self.groups, | |||||
| bias=False) | |||||
| self.bn = torch.nn.BatchNorm2d(out_channels * 2) | |||||
| self.relu = torch.nn.ReLU(inplace=True) | |||||
| # squeeze and excitation block | |||||
| self.use_se = use_se | |||||
| if use_se: | |||||
| if se_kwargs is None: | |||||
| se_kwargs = {} | |||||
| self.se = SELayer(self.conv_layer.in_channels, **se_kwargs) | |||||
| self.spatial_scale_factor = spatial_scale_factor | |||||
| self.spatial_scale_mode = spatial_scale_mode | |||||
| self.spectral_pos_encoding = spectral_pos_encoding | |||||
| self.ffc3d = ffc3d | |||||
| self.fft_norm = fft_norm | |||||
| def forward(self, x): | |||||
| batch = x.shape[0] | |||||
| if self.spatial_scale_factor is not None: | |||||
| orig_size = x.shape[-2:] | |||||
| x = F.interpolate( | |||||
| x, | |||||
| scale_factor=self.spatial_scale_factor, | |||||
| mode=self.spatial_scale_mode, | |||||
| align_corners=False) | |||||
| # (batch, c, h, w/2+1, 2) | |||||
| fft_dim = (-3, -2, -1) if self.ffc3d else (-2, -1) | |||||
| ffted = torch.fft.rfftn(x, dim=fft_dim, norm=self.fft_norm) | |||||
| ffted = torch.stack((ffted.real, ffted.imag), dim=-1) | |||||
| ffted = ffted.permute(0, 1, 4, 2, | |||||
| 3).contiguous() # (batch, c, 2, h, w/2+1) | |||||
| ffted = ffted.view(( | |||||
| batch, | |||||
| -1, | |||||
| ) + ffted.size()[3:]) | |||||
| if self.spectral_pos_encoding: | |||||
| height, width = ffted.shape[-2:] | |||||
| coords_vert = torch.linspace(0, 1, | |||||
| height)[None, None, :, None].expand( | |||||
| batch, 1, height, width).to(ffted) | |||||
| coords_hor = torch.linspace(0, 1, | |||||
| width)[None, None, None, :].expand( | |||||
| batch, 1, height, width).to(ffted) | |||||
| ffted = torch.cat((coords_vert, coords_hor, ffted), dim=1) | |||||
| if self.use_se: | |||||
| ffted = self.se(ffted) | |||||
| ffted = self.conv_layer(ffted) # (batch, c*2, h, w/2+1) | |||||
| ffted = self.relu(self.bn(ffted)) | |||||
| ffted = ffted.view(( | |||||
| batch, | |||||
| -1, | |||||
| 2, | |||||
| ) + ffted.size()[2:]).permute( | |||||
| 0, 1, 3, 4, 2).contiguous() # (batch, c, h, w/2+1, 2) | |||||
| ffted = torch.complex(ffted[..., 0], ffted[..., 1]) | |||||
| ifft_shape_slice = x.shape[-3:] if self.ffc3d else x.shape[-2:] | |||||
| output = torch.fft.irfftn( | |||||
| ffted, s=ifft_shape_slice, dim=fft_dim, norm=self.fft_norm) | |||||
| if self.spatial_scale_factor is not None: | |||||
| output = F.interpolate( | |||||
| output, | |||||
| size=orig_size, | |||||
| mode=self.spatial_scale_mode, | |||||
| align_corners=False) | |||||
| return output | |||||
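| # A small shape-check sketch, assuming arbitrary sizes: FourierUnit maps | |||||
| # (B, C_in, H, W) to (B, C_out, H, W) by running a 1x1 conv over the | |||||
| # stacked real/imag parts of the rfft and then inverting the transform. | |||||
| def _fourier_unit_shape_check(): | |||||
|     fu = FourierUnit(in_channels=8, out_channels=8) | |||||
|     x = torch.rand(2, 8, 32, 32) | |||||
|     out = fu(x) | |||||
|     assert out.shape == x.shape | |||||
|     return out.shape | |||||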
| class SpectralTransform(nn.Module): | |||||
| def __init__(self, | |||||
| in_channels, | |||||
| out_channels, | |||||
| stride=1, | |||||
| groups=1, | |||||
| enable_lfu=True, | |||||
| **fu_kwargs): | |||||
| # bn_layer not used | |||||
| super(SpectralTransform, self).__init__() | |||||
| self.enable_lfu = enable_lfu | |||||
| if stride == 2: | |||||
| self.downsample = nn.AvgPool2d(kernel_size=(2, 2), stride=2) | |||||
| else: | |||||
| self.downsample = nn.Identity() | |||||
| self.stride = stride | |||||
| self.conv1 = nn.Sequential( | |||||
| nn.Conv2d( | |||||
| in_channels, | |||||
| out_channels // 2, | |||||
| kernel_size=1, | |||||
| groups=groups, | |||||
| bias=False), nn.BatchNorm2d(out_channels // 2), | |||||
| nn.ReLU(inplace=True)) | |||||
| self.fu = FourierUnit(out_channels // 2, out_channels // 2, groups, | |||||
| **fu_kwargs) | |||||
| if self.enable_lfu: | |||||
| self.lfu = FourierUnit(out_channels // 2, out_channels // 2, | |||||
| groups) | |||||
| self.conv2 = torch.nn.Conv2d( | |||||
| out_channels // 2, | |||||
| out_channels, | |||||
| kernel_size=1, | |||||
| groups=groups, | |||||
| bias=False) | |||||
| def forward(self, x): | |||||
| x = self.downsample(x) | |||||
| x = self.conv1(x) | |||||
| output = self.fu(x) | |||||
| if self.enable_lfu: | |||||
| n, c, h, w = x.shape | |||||
| split_no = 2 | |||||
| split_s = h // split_no | |||||
| xs = torch.cat( | |||||
| torch.split(x[:, :c // 4], split_s, dim=-2), | |||||
| dim=1).contiguous() | |||||
| xs = torch.cat( | |||||
| torch.split(xs, split_s, dim=-1), dim=1).contiguous() | |||||
| xs = self.lfu(xs) | |||||
| xs = xs.repeat(1, 1, split_no, split_no).contiguous() | |||||
| else: | |||||
| xs = 0 | |||||
| output = self.conv2(x + output + xs) | |||||
| return output | |||||
| class LearnableSpatialTransformWrapper(nn.Module): | |||||
| def __init__(self, | |||||
| impl, | |||||
| pad_coef=0.5, | |||||
| angle_init_range=80, | |||||
| train_angle=True): | |||||
| super().__init__() | |||||
| self.impl = impl | |||||
| self.angle = torch.rand(1) * angle_init_range | |||||
| if train_angle: | |||||
| self.angle = nn.Parameter(self.angle, requires_grad=True) | |||||
| self.pad_coef = pad_coef | |||||
| def forward(self, x): | |||||
| if torch.is_tensor(x): | |||||
| return self.inverse_transform(self.impl(self.transform(x)), x) | |||||
| elif isinstance(x, tuple): | |||||
| x_trans = tuple(self.transform(elem) for elem in x) | |||||
| y_trans = self.impl(x_trans) | |||||
| return tuple( | |||||
| self.inverse_transform(elem, orig_x) | |||||
| for elem, orig_x in zip(y_trans, x)) | |||||
| else: | |||||
| raise ValueError(f'Unexpected input type {type(x)}') | |||||
| def transform(self, x): | |||||
| height, width = x.shape[2:] | |||||
| pad_h, pad_w = int(height * self.pad_coef), int(width * self.pad_coef) | |||||
| x_padded = F.pad(x, [pad_w, pad_w, pad_h, pad_h], mode='reflect') | |||||
| x_padded_rotated = rotate(x_padded, angle=self.angle.to(x_padded)) | |||||
| return x_padded_rotated | |||||
| def inverse_transform(self, y_padded_rotated, orig_x): | |||||
| height, width = orig_x.shape[2:] | |||||
| pad_h, pad_w = int(height * self.pad_coef), int(width * self.pad_coef) | |||||
| y_padded = rotate( | |||||
| y_padded_rotated, angle=-self.angle.to(y_padded_rotated)) | |||||
| y_height, y_width = y_padded.shape[2:] | |||||
| y = y_padded[:, :, pad_h:y_height - pad_h, pad_w:y_width - pad_w] | |||||
| return y | |||||
| class FFC(nn.Module): | |||||
| def __init__(self, | |||||
| in_channels, | |||||
| out_channels, | |||||
| kernel_size, | |||||
| ratio_gin, | |||||
| ratio_gout, | |||||
| stride=1, | |||||
| padding=0, | |||||
| dilation=1, | |||||
| groups=1, | |||||
| bias=False, | |||||
| enable_lfu=True, | |||||
| padding_type='reflect', | |||||
| gated=False, | |||||
| **spectral_kwargs): | |||||
| super(FFC, self).__init__() | |||||
| assert stride == 1 or stride == 2, 'Stride should be 1 or 2.' | |||||
| self.stride = stride | |||||
| in_cg = int(in_channels * ratio_gin) | |||||
| in_cl = in_channels - in_cg | |||||
| out_cg = int(out_channels * ratio_gout) | |||||
| out_cl = out_channels - out_cg | |||||
| self.ratio_gin = ratio_gin | |||||
| self.ratio_gout = ratio_gout | |||||
| self.global_in_num = in_cg | |||||
| module = nn.Identity if in_cl == 0 or out_cl == 0 else nn.Conv2d | |||||
| self.convl2l = module( | |||||
| in_cl, | |||||
| out_cl, | |||||
| kernel_size, | |||||
| stride, | |||||
| padding, | |||||
| dilation, | |||||
| groups, | |||||
| bias, | |||||
| padding_mode=padding_type) | |||||
| module = nn.Identity if in_cl == 0 or out_cg == 0 else nn.Conv2d | |||||
| self.convl2g = module( | |||||
| in_cl, | |||||
| out_cg, | |||||
| kernel_size, | |||||
| stride, | |||||
| padding, | |||||
| dilation, | |||||
| groups, | |||||
| bias, | |||||
| padding_mode=padding_type) | |||||
| module = nn.Identity if in_cg == 0 or out_cl == 0 else nn.Conv2d | |||||
| self.convg2l = module( | |||||
| in_cg, | |||||
| out_cl, | |||||
| kernel_size, | |||||
| stride, | |||||
| padding, | |||||
| dilation, | |||||
| groups, | |||||
| bias, | |||||
| padding_mode=padding_type) | |||||
| module = nn.Identity if in_cg == 0 or out_cg == 0 else SpectralTransform | |||||
| self.convg2g = module(in_cg, out_cg, stride, | |||||
| 1 if groups == 1 else groups // 2, enable_lfu, | |||||
| **spectral_kwargs) | |||||
| self.gated = gated | |||||
| module = nn.Identity if in_cg == 0 or out_cl == 0 or not self.gated else nn.Conv2d | |||||
| self.gate = module(in_channels, 2, 1) | |||||
| def forward(self, x): | |||||
| x_l, x_g = x if type(x) is tuple else (x, 0) | |||||
| out_xl, out_xg = 0, 0 | |||||
| if self.gated: | |||||
| total_input_parts = [x_l] | |||||
| if torch.is_tensor(x_g): | |||||
| total_input_parts.append(x_g) | |||||
| total_input = torch.cat(total_input_parts, dim=1) | |||||
| gates = torch.sigmoid(self.gate(total_input)) | |||||
| g2l_gate, l2g_gate = gates.chunk(2, dim=1) | |||||
| else: | |||||
| g2l_gate, l2g_gate = 1, 1 | |||||
| if self.ratio_gout != 1: | |||||
| out_xl = self.convl2l(x_l) + self.convg2l(x_g) * g2l_gate | |||||
| if self.ratio_gout != 0: | |||||
| out_xg = self.convl2g(x_l) * l2g_gate + self.convg2g(x_g) | |||||
| return out_xl, out_xg | |||||
| class FFC_BN_ACT(nn.Module): | |||||
| def __init__(self, | |||||
| in_channels, | |||||
| out_channels, | |||||
| kernel_size, | |||||
| ratio_gin, | |||||
| ratio_gout, | |||||
| stride=1, | |||||
| padding=0, | |||||
| dilation=1, | |||||
| groups=1, | |||||
| bias=False, | |||||
| norm_layer=nn.BatchNorm2d, | |||||
| activation_layer=nn.Identity, | |||||
| padding_type='reflect', | |||||
| enable_lfu=True, | |||||
| **kwargs): | |||||
| super(FFC_BN_ACT, self).__init__() | |||||
| self.ffc = FFC( | |||||
| in_channels, | |||||
| out_channels, | |||||
| kernel_size, | |||||
| ratio_gin, | |||||
| ratio_gout, | |||||
| stride, | |||||
| padding, | |||||
| dilation, | |||||
| groups, | |||||
| bias, | |||||
| enable_lfu, | |||||
| padding_type=padding_type, | |||||
| **kwargs) | |||||
| lnorm = nn.Identity if ratio_gout == 1 else norm_layer | |||||
| gnorm = nn.Identity if ratio_gout == 0 else norm_layer | |||||
| global_channels = int(out_channels * ratio_gout) | |||||
| self.bn_l = lnorm(out_channels - global_channels) | |||||
| self.bn_g = gnorm(global_channels) | |||||
| lact = nn.Identity if ratio_gout == 1 else activation_layer | |||||
| gact = nn.Identity if ratio_gout == 0 else activation_layer | |||||
| self.act_l = lact(inplace=True) | |||||
| self.act_g = gact(inplace=True) | |||||
| def forward(self, x): | |||||
| x_l, x_g = self.ffc(x) | |||||
| x_l = self.act_l(self.bn_l(x_l)) | |||||
| x_g = self.act_g(self.bn_g(x_g)) | |||||
| return x_l, x_g | |||||
| class FFCResnetBlock(nn.Module): | |||||
| def __init__(self, | |||||
| dim, | |||||
| padding_type, | |||||
| norm_layer, | |||||
| activation_layer=nn.ReLU, | |||||
| dilation=1, | |||||
| spatial_transform_kwargs=None, | |||||
| inline=False, | |||||
| **conv_kwargs): | |||||
| super().__init__() | |||||
| self.conv1 = FFC_BN_ACT( | |||||
| dim, | |||||
| dim, | |||||
| kernel_size=3, | |||||
| padding=dilation, | |||||
| dilation=dilation, | |||||
| norm_layer=norm_layer, | |||||
| activation_layer=activation_layer, | |||||
| padding_type=padding_type, | |||||
| **conv_kwargs) | |||||
| self.conv2 = FFC_BN_ACT( | |||||
| dim, | |||||
| dim, | |||||
| kernel_size=3, | |||||
| padding=dilation, | |||||
| dilation=dilation, | |||||
| norm_layer=norm_layer, | |||||
| activation_layer=activation_layer, | |||||
| padding_type=padding_type, | |||||
| **conv_kwargs) | |||||
| if spatial_transform_kwargs is not None: | |||||
| self.conv1 = LearnableSpatialTransformWrapper( | |||||
| self.conv1, **spatial_transform_kwargs) | |||||
| self.conv2 = LearnableSpatialTransformWrapper( | |||||
| self.conv2, **spatial_transform_kwargs) | |||||
| self.inline = inline | |||||
| def forward(self, x): | |||||
| if self.inline: | |||||
| x_l = x[:, :-self.conv1.ffc.global_in_num] | |||||
| x_g = x[:, -self.conv1.ffc.global_in_num:] | |||||
| else: | |||||
| x_l, x_g = x if type(x) is tuple else (x, 0) | |||||
| id_l, id_g = x_l, x_g | |||||
| x_l, x_g = self.conv1((x_l, x_g)) | |||||
| x_l, x_g = self.conv2((x_l, x_g)) | |||||
| x_l, x_g = id_l + x_l, id_g + x_g | |||||
| out = x_l, x_g | |||||
| if self.inline: | |||||
| out = torch.cat(out, dim=1) | |||||
| return out | |||||
| class ConcatTupleLayer(nn.Module): | |||||
| def forward(self, x): | |||||
| assert isinstance(x, tuple) | |||||
| x_l, x_g = x | |||||
| assert torch.is_tensor(x_l) or torch.is_tensor(x_g) | |||||
| if not torch.is_tensor(x_g): | |||||
| return x_l | |||||
| return torch.cat(x, dim=1) | |||||
| class FFCResNetGenerator(nn.Module): | |||||
| def __init__(self, | |||||
| input_nc=4, | |||||
| output_nc=3, | |||||
| ngf=64, | |||||
| n_downsampling=3, | |||||
| n_blocks=18, | |||||
| norm_layer=nn.BatchNorm2d, | |||||
| padding_type='reflect', | |||||
| activation_layer=nn.ReLU, | |||||
| up_norm_layer=nn.BatchNorm2d, | |||||
| up_activation=nn.ReLU(True), | |||||
| init_conv_kwargs={ | |||||
| 'ratio_gin': 0, | |||||
| 'ratio_gout': 0, | |||||
| 'enable_lfu': False | |||||
| }, | |||||
| downsample_conv_kwargs={ | |||||
| 'ratio_gin': 0, | |||||
| 'ratio_gout': 0, | |||||
| 'enable_lfu': False | |||||
| }, | |||||
| resnet_conv_kwargs={ | |||||
| 'ratio_gin': 0.75, | |||||
| 'ratio_gout': 0.75, | |||||
| 'enable_lfu': False | |||||
| }, | |||||
| spatial_transform_layers=None, | |||||
| spatial_transform_kwargs={}, | |||||
| add_out_act='sigmoid', | |||||
| max_features=1024, | |||||
| out_ffc=False, | |||||
| out_ffc_kwargs={}): | |||||
| assert (n_blocks >= 0) | |||||
| super().__init__() | |||||
| model = [ | |||||
| nn.ReflectionPad2d(3), | |||||
| FFC_BN_ACT( | |||||
| input_nc, | |||||
| ngf, | |||||
| kernel_size=7, | |||||
| padding=0, | |||||
| norm_layer=norm_layer, | |||||
| activation_layer=activation_layer, | |||||
| **init_conv_kwargs) | |||||
| ] | |||||
| # downsample | |||||
| for i in range(n_downsampling): | |||||
| mult = 2**i | |||||
| if i == n_downsampling - 1: | |||||
| cur_conv_kwargs = dict(downsample_conv_kwargs) | |||||
| cur_conv_kwargs['ratio_gout'] = resnet_conv_kwargs.get( | |||||
| 'ratio_gin', 0) | |||||
| else: | |||||
| cur_conv_kwargs = downsample_conv_kwargs | |||||
| model += [ | |||||
| FFC_BN_ACT( | |||||
| min(max_features, ngf * mult), | |||||
| min(max_features, ngf * mult * 2), | |||||
| kernel_size=3, | |||||
| stride=2, | |||||
| padding=1, | |||||
| norm_layer=norm_layer, | |||||
| activation_layer=activation_layer, | |||||
| **cur_conv_kwargs) | |||||
| ] | |||||
| mult = 2**n_downsampling | |||||
| feats_num_bottleneck = min(max_features, ngf * mult) | |||||
| # resnet blocks | |||||
| for i in range(n_blocks): | |||||
| cur_resblock = FFCResnetBlock( | |||||
| feats_num_bottleneck, | |||||
| padding_type=padding_type, | |||||
| activation_layer=activation_layer, | |||||
| norm_layer=norm_layer, | |||||
| **resnet_conv_kwargs) | |||||
| if spatial_transform_layers is not None and i in spatial_transform_layers: | |||||
| cur_resblock = LearnableSpatialTransformWrapper( | |||||
| cur_resblock, **spatial_transform_kwargs) | |||||
| model += [cur_resblock] | |||||
| model += [ConcatTupleLayer()] | |||||
| # upsample | |||||
| for i in range(n_downsampling): | |||||
| mult = 2**(n_downsampling - i) | |||||
| model += [ | |||||
| nn.ConvTranspose2d( | |||||
| min(max_features, ngf * mult), | |||||
| min(max_features, int(ngf * mult / 2)), | |||||
| kernel_size=3, | |||||
| stride=2, | |||||
| padding=1, | |||||
| output_padding=1), | |||||
| up_norm_layer(min(max_features, int(ngf * mult / 2))), | |||||
| up_activation | |||||
| ] | |||||
| if out_ffc: | |||||
| model += [ | |||||
| FFCResnetBlock( | |||||
| ngf, | |||||
| padding_type=padding_type, | |||||
| activation_layer=activation_layer, | |||||
| norm_layer=norm_layer, | |||||
| inline=True, | |||||
| **out_ffc_kwargs) | |||||
| ] | |||||
| model += [ | |||||
| nn.ReflectionPad2d(3), | |||||
| nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0) | |||||
| ] | |||||
| if add_out_act: | |||||
| model.append( | |||||
| get_activation('tanh' if add_out_act is True else add_out_act)) | |||||
| self.model = nn.Sequential(*model) | |||||
| def forward(self, input): | |||||
| return self.model(input) | |||||
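| # A minimal smoke-test sketch for the generator, assuming the LaMa input | |||||
| # convention (RGB image concatenated with a binary mask, 4 channels in); | |||||
| # n_blocks is reduced here only to keep the example fast. | |||||
| def _generator_smoke_test(): | |||||
|     gen = FFCResNetGenerator(input_nc=4, output_nc=3, n_blocks=2) | |||||
|     img = torch.rand(1, 3, 256, 256) | |||||
|     mask = (torch.rand(1, 1, 256, 256) > 0.5).float() | |||||
|     inp = torch.cat([img * (1 - mask), mask], dim=1) | |||||
|     out = gen(inp)  # (1, 3, 256, 256), sigmoid-activated by default | |||||
|     return out.shape | |||||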
| @@ -0,0 +1,324 @@ | |||||
| """ | |||||
| Part of the implementation is borrowed and modified from LaMa, publicly available at | |||||
| https://github.com/saic-mdal/lama | |||||
| """ | |||||
| import torch | |||||
| import torch.nn as nn | |||||
| import torch.nn.functional as F | |||||
| from torchvision import models | |||||
| from modelscope.utils.logger import get_logger | |||||
| try: | |||||
| from torchvision.models.utils import load_state_dict_from_url | |||||
| except ImportError: | |||||
| from torch.utils.model_zoo import load_url as load_state_dict_from_url | |||||
| # Inception weights ported to Pytorch from | |||||
| # http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz | |||||
| FID_WEIGHTS_URL = 'https://github.com/mseitzer/pytorch-fid/releases/download/' \ | |||||
| 'fid_weights/pt_inception-2015-12-05-6726825d.pth' | |||||
| LOGGER = get_logger() | |||||
| class InceptionV3(nn.Module): | |||||
| """Pretrained InceptionV3 network returning feature maps""" | |||||
| # Index of default block of inception to return, | |||||
| # corresponds to output of final average pooling | |||||
| DEFAULT_BLOCK_INDEX = 3 | |||||
| # Maps feature dimensionality to their output blocks indices | |||||
| BLOCK_INDEX_BY_DIM = { | |||||
| 64: 0, # First max pooling features | |||||
| 192: 1, # Second max pooling features | |||||
| 768: 2, # Pre-aux classifier features | |||||
| 2048: 3 # Final average pooling features | |||||
| } | |||||
| def __init__(self, | |||||
| output_blocks=[DEFAULT_BLOCK_INDEX], | |||||
| resize_input=True, | |||||
| normalize_input=True, | |||||
| requires_grad=False, | |||||
| use_fid_inception=True): | |||||
| """Build pretrained InceptionV3 | |||||
| Parameters | |||||
| ---------- | |||||
| output_blocks : list of int | |||||
| Indices of blocks to return features of. Possible values are: | |||||
| - 0: corresponds to output of first max pooling | |||||
| - 1: corresponds to output of second max pooling | |||||
| - 2: corresponds to output which is fed to aux classifier | |||||
| - 3: corresponds to output of final average pooling | |||||
| resize_input : bool | |||||
| If true, bilinearly resizes input to width and height 299 before | |||||
| feeding input to model. As the network without fully connected | |||||
| layers is fully convolutional, it should be able to handle inputs | |||||
| of arbitrary size, so resizing might not be strictly needed | |||||
| normalize_input : bool | |||||
| If true, scales the input from range (0, 1) to the range the | |||||
| pretrained Inception network expects, namely (-1, 1) | |||||
| requires_grad : bool | |||||
| If true, parameters of the model require gradients. Possibly useful | |||||
| for finetuning the network | |||||
| use_fid_inception : bool | |||||
| If true, uses the pretrained Inception model used in Tensorflow's | |||||
| FID implementation. If false, uses the pretrained Inception model | |||||
| available in torchvision. The FID Inception model has different | |||||
| weights and a slightly different structure from torchvision's | |||||
| Inception model. If you want to compute FID scores, you are | |||||
| strongly advised to set this parameter to true to get comparable | |||||
| results. | |||||
| """ | |||||
| super(InceptionV3, self).__init__() | |||||
| self.resize_input = resize_input | |||||
| self.normalize_input = normalize_input | |||||
| self.output_blocks = sorted(output_blocks) | |||||
| self.last_needed_block = max(output_blocks) | |||||
| assert self.last_needed_block <= 3, \ | |||||
| 'Last possible output block index is 3' | |||||
| self.blocks = nn.ModuleList() | |||||
| if use_fid_inception: | |||||
| inception = fid_inception_v3() | |||||
| else: | |||||
| inception = models.inception_v3(pretrained=True) | |||||
| # Block 0: input to maxpool1 | |||||
| block0 = [ | |||||
| inception.Conv2d_1a_3x3, inception.Conv2d_2a_3x3, | |||||
| inception.Conv2d_2b_3x3, | |||||
| nn.MaxPool2d(kernel_size=3, stride=2) | |||||
| ] | |||||
| self.blocks.append(nn.Sequential(*block0)) | |||||
| # Block 1: maxpool1 to maxpool2 | |||||
| if self.last_needed_block >= 1: | |||||
| block1 = [ | |||||
| inception.Conv2d_3b_1x1, inception.Conv2d_4a_3x3, | |||||
| nn.MaxPool2d(kernel_size=3, stride=2) | |||||
| ] | |||||
| self.blocks.append(nn.Sequential(*block1)) | |||||
| # Block 2: maxpool2 to aux classifier | |||||
| if self.last_needed_block >= 2: | |||||
| block2 = [ | |||||
| inception.Mixed_5b, | |||||
| inception.Mixed_5c, | |||||
| inception.Mixed_5d, | |||||
| inception.Mixed_6a, | |||||
| inception.Mixed_6b, | |||||
| inception.Mixed_6c, | |||||
| inception.Mixed_6d, | |||||
| inception.Mixed_6e, | |||||
| ] | |||||
| self.blocks.append(nn.Sequential(*block2)) | |||||
| # Block 3: aux classifier to final avgpool | |||||
| if self.last_needed_block >= 3: | |||||
| block3 = [ | |||||
| inception.Mixed_7a, inception.Mixed_7b, inception.Mixed_7c, | |||||
| nn.AdaptiveAvgPool2d(output_size=(1, 1)) | |||||
| ] | |||||
| self.blocks.append(nn.Sequential(*block3)) | |||||
| for param in self.parameters(): | |||||
| param.requires_grad = requires_grad | |||||
| def forward(self, inp): | |||||
| """Get Inception feature maps | |||||
| Parameters | |||||
| ---------- | |||||
| inp : torch.Tensor | |||||
| Input tensor of shape Bx3xHxW. Values are expected to be in | |||||
| range (0, 1) | |||||
| Returns | |||||
| ------- | |||||
| List of torch.Tensor, corresponding to the selected output | |||||
| blocks, sorted ascending by index | |||||
| """ | |||||
| outp = [] | |||||
| x = inp | |||||
| if self.resize_input: | |||||
| x = F.interpolate( | |||||
| x, size=(299, 299), mode='bilinear', align_corners=False) | |||||
| if self.normalize_input: | |||||
| x = 2 * x - 1 # Scale from range (0, 1) to range (-1, 1) | |||||
| for idx, block in enumerate(self.blocks): | |||||
| x = block(x) | |||||
| if idx in self.output_blocks: | |||||
| outp.append(x) | |||||
| if idx == self.last_needed_block: | |||||
| break | |||||
| return outp | |||||
| def fid_inception_v3(): | |||||
| """Build pretrained Inception model for FID computation | |||||
| The Inception model for FID computation uses a different set of weights | |||||
| and has a slightly different structure than torchvision's Inception. | |||||
| This method first constructs torchvision's Inception and then patches the | |||||
| necessary parts that are different in the FID Inception model. | |||||
| """ | |||||
| LOGGER.info('fid_inception_v3 called') | |||||
| inception = models.inception_v3( | |||||
| num_classes=1008, aux_logits=False, pretrained=False) | |||||
| LOGGER.info('models.inception_v3 done') | |||||
| inception.Mixed_5b = FIDInceptionA(192, pool_features=32) | |||||
| inception.Mixed_5c = FIDInceptionA(256, pool_features=64) | |||||
| inception.Mixed_5d = FIDInceptionA(288, pool_features=64) | |||||
| inception.Mixed_6b = FIDInceptionC(768, channels_7x7=128) | |||||
| inception.Mixed_6c = FIDInceptionC(768, channels_7x7=160) | |||||
| inception.Mixed_6d = FIDInceptionC(768, channels_7x7=160) | |||||
| inception.Mixed_6e = FIDInceptionC(768, channels_7x7=192) | |||||
| inception.Mixed_7b = FIDInceptionE_1(1280) | |||||
| inception.Mixed_7c = FIDInceptionE_2(2048) | |||||
| LOGGER.info('fid_inception_v3 patching done') | |||||
| state_dict = load_state_dict_from_url(FID_WEIGHTS_URL, progress=True) | |||||
| LOGGER.info('fid_inception_v3 weights downloaded') | |||||
| inception.load_state_dict(state_dict) | |||||
| LOGGER.info('fid_inception_v3 weights loaded into model') | |||||
| return inception | |||||
| class FIDInceptionA(models.inception.InceptionA): | |||||
| """InceptionA block patched for FID computation""" | |||||
| def __init__(self, in_channels, pool_features): | |||||
| super(FIDInceptionA, self).__init__(in_channels, pool_features) | |||||
| def forward(self, x): | |||||
| branch1x1 = self.branch1x1(x) | |||||
| branch5x5 = self.branch5x5_1(x) | |||||
| branch5x5 = self.branch5x5_2(branch5x5) | |||||
| branch3x3dbl = self.branch3x3dbl_1(x) | |||||
| branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl) | |||||
| branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl) | |||||
| # Patch: Tensorflow's average pool does not use the padded zeros in | |||||
| # its average calculation | |||||
| branch_pool = F.avg_pool2d( | |||||
| x, kernel_size=3, stride=1, padding=1, count_include_pad=False) | |||||
| branch_pool = self.branch_pool(branch_pool) | |||||
| outputs = [branch1x1, branch5x5, branch3x3dbl, branch_pool] | |||||
| return torch.cat(outputs, 1) | |||||
| class FIDInceptionC(models.inception.InceptionC): | |||||
| """InceptionC block patched for FID computation""" | |||||
| def __init__(self, in_channels, channels_7x7): | |||||
| super(FIDInceptionC, self).__init__(in_channels, channels_7x7) | |||||
| def forward(self, x): | |||||
| branch1x1 = self.branch1x1(x) | |||||
| branch7x7 = self.branch7x7_1(x) | |||||
| branch7x7 = self.branch7x7_2(branch7x7) | |||||
| branch7x7 = self.branch7x7_3(branch7x7) | |||||
| branch7x7dbl = self.branch7x7dbl_1(x) | |||||
| branch7x7dbl = self.branch7x7dbl_2(branch7x7dbl) | |||||
| branch7x7dbl = self.branch7x7dbl_3(branch7x7dbl) | |||||
| branch7x7dbl = self.branch7x7dbl_4(branch7x7dbl) | |||||
| branch7x7dbl = self.branch7x7dbl_5(branch7x7dbl) | |||||
| # Patch: Tensorflow's average pool does not use the padded zeros in | |||||
| # its average calculation | |||||
| branch_pool = F.avg_pool2d( | |||||
| x, kernel_size=3, stride=1, padding=1, count_include_pad=False) | |||||
| branch_pool = self.branch_pool(branch_pool) | |||||
| outputs = [branch1x1, branch7x7, branch7x7dbl, branch_pool] | |||||
| return torch.cat(outputs, 1) | |||||
| class FIDInceptionE_1(models.inception.InceptionE): | |||||
| """First InceptionE block patched for FID computation""" | |||||
| def __init__(self, in_channels): | |||||
| super(FIDInceptionE_1, self).__init__(in_channels) | |||||
| def forward(self, x): | |||||
| branch1x1 = self.branch1x1(x) | |||||
| branch3x3 = self.branch3x3_1(x) | |||||
| branch3x3 = [ | |||||
| self.branch3x3_2a(branch3x3), | |||||
| self.branch3x3_2b(branch3x3), | |||||
| ] | |||||
| branch3x3 = torch.cat(branch3x3, 1) | |||||
| branch3x3dbl = self.branch3x3dbl_1(x) | |||||
| branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl) | |||||
| branch3x3dbl = [ | |||||
| self.branch3x3dbl_3a(branch3x3dbl), | |||||
| self.branch3x3dbl_3b(branch3x3dbl), | |||||
| ] | |||||
| branch3x3dbl = torch.cat(branch3x3dbl, 1) | |||||
| # Patch: Tensorflow's average pool does not use the padded zeros in | |||||
| # its average calculation | |||||
| branch_pool = F.avg_pool2d( | |||||
| x, kernel_size=3, stride=1, padding=1, count_include_pad=False) | |||||
| branch_pool = self.branch_pool(branch_pool) | |||||
| outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool] | |||||
| return torch.cat(outputs, 1) | |||||
| class FIDInceptionE_2(models.inception.InceptionE): | |||||
| """Second InceptionE block patched for FID computation""" | |||||
| def __init__(self, in_channels): | |||||
| super(FIDInceptionE_2, self).__init__(in_channels) | |||||
| def forward(self, x): | |||||
| branch1x1 = self.branch1x1(x) | |||||
| branch3x3 = self.branch3x3_1(x) | |||||
| branch3x3 = [ | |||||
| self.branch3x3_2a(branch3x3), | |||||
| self.branch3x3_2b(branch3x3), | |||||
| ] | |||||
| branch3x3 = torch.cat(branch3x3, 1) | |||||
| branch3x3dbl = self.branch3x3dbl_1(x) | |||||
| branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl) | |||||
| branch3x3dbl = [ | |||||
| self.branch3x3dbl_3a(branch3x3dbl), | |||||
| self.branch3x3dbl_3b(branch3x3dbl), | |||||
| ] | |||||
| branch3x3dbl = torch.cat(branch3x3dbl, 1) | |||||
| # Patch: The FID Inception model uses max pooling instead of average | |||||
| # pooling. This is likely an error in this specific Inception | |||||
| # implementation, as other Inception models use average pooling here | |||||
| # (which matches the description in the paper). | |||||
| branch_pool = F.max_pool2d(x, kernel_size=3, stride=1, padding=1) | |||||
| branch_pool = self.branch_pool(branch_pool) | |||||
| outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool] | |||||
| return torch.cat(outputs, 1) | |||||
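| # A minimal sketch of feature extraction for FID, assuming inputs in (0, 1) | |||||
| # as the docstring above requires; the FID weights are downloaded on first use. | |||||
| def _inception_pool_features(images): | |||||
|     """images: float tensor of shape (B, 3, H, W), values in (0, 1).""" | |||||
|     block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[2048] | |||||
|     model = InceptionV3(output_blocks=[block_idx]).eval() | |||||
|     with torch.no_grad(): | |||||
|         feats = model(images)[0]  # (B, 2048, 1, 1) | |||||
|     return feats.reshape(images.shape[0], -1) | |||||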
| @@ -0,0 +1,47 @@ | |||||
| """ | |||||
| Part of the implementation is borrowed and modified from LaMa, publicly available at | |||||
| https://github.com/saic-mdal/lama | |||||
| """ | |||||
| import torch | |||||
| import torch.nn as nn | |||||
| import torch.nn.functional as F | |||||
| import torchvision | |||||
| from .ade20k import ModelBuilder | |||||
| IMAGENET_MEAN = torch.FloatTensor([0.485, 0.456, 0.406])[None, :, None, None] | |||||
| IMAGENET_STD = torch.FloatTensor([0.229, 0.224, 0.225])[None, :, None, None] | |||||
| class ResNetPL(nn.Module): | |||||
| def __init__(self, | |||||
| weight=1, | |||||
| weights_path=None, | |||||
| arch_encoder='resnet50dilated', | |||||
| segmentation=True): | |||||
| super().__init__() | |||||
| self.impl = ModelBuilder.get_encoder( | |||||
| weights_path=weights_path, | |||||
| arch_encoder=arch_encoder, | |||||
| arch_decoder='ppm_deepsup', | |||||
| fc_dim=2048, | |||||
| segmentation=segmentation) | |||||
| self.impl.eval() | |||||
| for w in self.impl.parameters(): | |||||
| w.requires_grad_(False) | |||||
| self.weight = weight | |||||
| def forward(self, pred, target): | |||||
| pred = (pred - IMAGENET_MEAN.to(pred)) / IMAGENET_STD.to(pred) | |||||
| target = (target - IMAGENET_MEAN.to(target)) / IMAGENET_STD.to(target) | |||||
| pred_feats = self.impl(pred, return_feature_maps=True) | |||||
| target_feats = self.impl(target, return_feature_maps=True) | |||||
| result = torch.stack([ | |||||
| F.mse_loss(cur_pred, cur_target) | |||||
| for cur_pred, cur_target in zip(pred_feats, target_feats) | |||||
| ]).sum() * self.weight | |||||
| return result | |||||
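| # A minimal usage sketch, assuming ADE20k encoder weights are available at a | |||||
| # placeholder path; the weight magnitude is only an illustrative assumption. | |||||
| def _perceptual_loss_example(pred, target, weights_path='/path/to/ade20k'): | |||||
|     loss_fn = ResNetPL(weight=30, weights_path=weights_path) | |||||
|     return loss_fn(pred, target)  # scalar tensor | |||||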
| @@ -0,0 +1,75 @@ | |||||
| """ | |||||
| The implementation is adopted from | |||||
| https://github.com/NVIDIA/pix2pixHD/blob/master/models/networks.py | |||||
| """ | |||||
| import numpy as np | |||||
| import torch.nn as nn | |||||
| # Defines the PatchGAN discriminator with the specified arguments. | |||||
| class NLayerDiscriminator(nn.Module): | |||||
| def __init__( | |||||
| self, | |||||
| input_nc=3, | |||||
| ndf=64, | |||||
| n_layers=4, | |||||
| norm_layer=nn.BatchNorm2d, | |||||
| ): | |||||
| super().__init__() | |||||
| self.n_layers = n_layers | |||||
| kw = 4 | |||||
| padw = int(np.ceil((kw - 1.0) / 2)) | |||||
| sequence = [[ | |||||
| nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), | |||||
| nn.LeakyReLU(0.2, True) | |||||
| ]] | |||||
| nf = ndf | |||||
| for n in range(1, n_layers): | |||||
| nf_prev = nf | |||||
| nf = min(nf * 2, 512) | |||||
| cur_model = [] | |||||
| cur_model += [ | |||||
| nn.Conv2d(nf_prev, nf, kernel_size=kw, stride=2, padding=padw), | |||||
| norm_layer(nf), | |||||
| nn.LeakyReLU(0.2, True) | |||||
| ] | |||||
| sequence.append(cur_model) | |||||
| nf_prev = nf | |||||
| nf = min(nf * 2, 512) | |||||
| cur_model = [] | |||||
| cur_model += [ | |||||
| nn.Conv2d(nf_prev, nf, kernel_size=kw, stride=1, padding=padw), | |||||
| norm_layer(nf), | |||||
| nn.LeakyReLU(0.2, True) | |||||
| ] | |||||
| sequence.append(cur_model) | |||||
| sequence += [[ | |||||
| nn.Conv2d(nf, 1, kernel_size=kw, stride=1, padding=padw) | |||||
| ]] | |||||
| for n in range(len(sequence)): | |||||
| setattr(self, 'model' + str(n), nn.Sequential(*sequence[n])) | |||||
| def get_all_activations(self, x): | |||||
| res = [x] | |||||
| for n in range(self.n_layers + 2): | |||||
| model = getattr(self, 'model' + str(n)) | |||||
| res.append(model(res[-1])) | |||||
| return res[1:] | |||||
| def forward(self, x): | |||||
| act = self.get_all_activations(x) | |||||
| return act[-1], act[:-1] | |||||
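| # A minimal smoke-test sketch, assuming illustrative sizes: forward() returns | |||||
| # the final patch logits plus the n_layers + 1 intermediate activations that | |||||
| # feature matching losses consume. | |||||
| def _discriminator_smoke_test(): | |||||
|     import torch  # not imported at module level; needed only for this sketch | |||||
|     disc = NLayerDiscriminator(input_nc=3) | |||||
|     x = torch.rand(2, 3, 256, 256) | |||||
|     logits, feats = disc(x) | |||||
|     return logits.shape, [f.shape for f in feats] | |||||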
| @@ -0,0 +1,393 @@ | |||||
| ''' | |||||
| Part of the implementation is borrowed and modified from LaMa, publicly available at | |||||
| https://github.com/saic-mdal/lama | |||||
| ''' | |||||
| import cv2 | |||||
| import numpy as np | |||||
| import torch | |||||
| import torch.nn as nn | |||||
| from kornia.filters import gaussian_blur2d | |||||
| from kornia.geometry.transform import resize | |||||
| from kornia.morphology import erosion | |||||
| from torch.nn import functional as F | |||||
| from torch.optim import SGD, Adam | |||||
| from tqdm import tqdm | |||||
| from .modules.ffc import FFCResnetBlock | |||||
| def move_to_device(obj, device): | |||||
| if isinstance(obj, nn.Module): | |||||
| return obj.to(device) | |||||
| if torch.is_tensor(obj): | |||||
| return obj.to(device) | |||||
| if isinstance(obj, (tuple, list)): | |||||
| return [move_to_device(el, device) for el in obj] | |||||
| if isinstance(obj, dict): | |||||
| return {name: move_to_device(val, device) for name, val in obj.items()} | |||||
| raise ValueError(f'Unexpected type {type(obj)}') | |||||
| def ceil_modulo(x, mod): | |||||
| if x % mod == 0: | |||||
| return x | |||||
| return (x // mod + 1) * mod | |||||
| def pad_tensor_to_modulo(img, mod): | |||||
| batch_size, channels, height, width = img.shape | |||||
| out_height = ceil_modulo(height, mod) | |||||
| out_width = ceil_modulo(width, mod) | |||||
| return F.pad( | |||||
| img, | |||||
| pad=(0, out_width - width, 0, out_height - height), | |||||
| mode='reflect') | |||||
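| # A mini-example with assumed sizes: mod is typically the generator's total | |||||
| # downsampling factor, e.g. mod=8 pads (1, 3, 250, 333) to (1, 3, 256, 336). | |||||
| def _pad_example(): | |||||
|     x = torch.rand(1, 3, 250, 333) | |||||
|     y = pad_tensor_to_modulo(x, 8) | |||||
|     assert y.shape[-2:] == (256, 336) | |||||
|     return y.shape | |||||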
| def _pyrdown(im: torch.Tensor, downsize: tuple = None): | |||||
| """downscale the image""" | |||||
| if downsize is None: | |||||
| downsize = (im.shape[2] // 2, im.shape[3] // 2) | |||||
| assert im.shape[1] == 3, 'Expected shape for the input to be (n,3,height,width)' | |||||
| im = gaussian_blur2d(im, kernel_size=(5, 5), sigma=(1.0, 1.0)) | |||||
| im = F.interpolate(im, size=downsize, mode='bilinear', align_corners=False) | |||||
| return im | |||||
| def _pyrdown_mask(mask: torch.Tensor, | |||||
| downsize: tuple = None, | |||||
| eps: float = 1e-8, | |||||
| blur_mask: bool = True, | |||||
| round_up: bool = True): | |||||
| """downscale the mask tensor | |||||
| Parameters | |||||
| ---------- | |||||
| mask : torch.Tensor | |||||
| mask of size (B, 1, H, W) | |||||
| downsize : tuple, optional | |||||
| size to downscale to. If None, image is downscaled to half, by default None | |||||
| eps : float, optional | |||||
| threshold value for binarizing the mask, by default 1e-8 | |||||
| blur_mask : bool, optional | |||||
| if True, apply gaussian filter before downscaling, by default True | |||||
| round_up : bool, optional | |||||
| if True, values above eps are set to 1; otherwise, values below 1 - eps are set to 0. By default True | |||||
| Returns | |||||
| ------- | |||||
| torch.Tensor | |||||
| downscaled mask | |||||
| """ | |||||
| if downsize is None: | |||||
| downsize = (mask.shape[2] // 2, mask.shape[3] // 2) | |||||
| assert mask.shape[1] == 1, 'Expected shape for the input to be (n,1,height,width)' | |||||
| if blur_mask is True: | |||||
| mask = gaussian_blur2d(mask, kernel_size=(5, 5), sigma=(1.0, 1.0)) | |||||
| mask = F.interpolate( | |||||
| mask, size=downsize, mode='bilinear', align_corners=False) | |||||
| else: | |||||
| mask = F.interpolate( | |||||
| mask, size=downsize, mode='bilinear', align_corners=False) | |||||
| if round_up: | |||||
| mask[mask >= eps] = 1 | |||||
| mask[mask < eps] = 0 | |||||
| else: | |||||
| mask[mask >= 1.0 - eps] = 1 | |||||
| mask[mask < 1.0 - eps] = 0 | |||||
| return mask | |||||
| def _erode_mask(mask: torch.Tensor, | |||||
| ekernel: torch.Tensor = None, | |||||
| eps: float = 1e-8): | |||||
| """erode the mask, and set gray pixels to 0""" | |||||
| if ekernel is not None: | |||||
| mask = erosion(mask, ekernel) | |||||
| mask[mask >= 1.0 - eps] = 1 | |||||
| mask[mask < 1.0 - eps] = 0 | |||||
| return mask | |||||
| def _l1_loss(pred: torch.Tensor, | |||||
| pred_downscaled: torch.Tensor, | |||||
| ref: torch.Tensor, | |||||
| mask: torch.Tensor, | |||||
| mask_downscaled: torch.Tensor, | |||||
| image: torch.Tensor, | |||||
| on_pred: bool = True): | |||||
| """l1 loss on src pixels, and downscaled predictions if on_pred=True""" | |||||
| loss = torch.mean(torch.abs(pred[mask < 1e-8] - image[mask < 1e-8])) | |||||
| if on_pred: | |||||
| loss += torch.mean( | |||||
| torch.abs(pred_downscaled[mask_downscaled >= 1e-8] | |||||
| - ref[mask_downscaled >= 1e-8])) | |||||
| return loss | |||||
| def _infer(image: torch.Tensor, | |||||
| mask: torch.Tensor, | |||||
| forward_front: nn.Module, | |||||
| forward_rears: nn.Module, | |||||
| ref_lower_res: torch.Tensor, | |||||
| orig_shape: tuple, | |||||
| devices: list, | |||||
| scale_ind: int, | |||||
| n_iters: int = 15, | |||||
| lr: float = 0.002): | |||||
| """Performs inference with refinement at a given scale. | |||||
| Parameters | |||||
| ---------- | |||||
| image : torch.Tensor | |||||
| input image to be inpainted, of size (1,3,H,W) | |||||
| mask : torch.Tensor | |||||
| input inpainting mask, of size (1,1,H,W) | |||||
| forward_front : nn.Module | |||||
| the front part of the inpainting network | |||||
| forward_rears : nn.Module | |||||
| the rear part of the inpainting network | |||||
| ref_lower_res : torch.Tensor | |||||
| the inpainting at previous scale, used as reference image | |||||
| orig_shape : tuple | |||||
| shape of the original input image before padding | |||||
| devices : list | |||||
| list of available devices | |||||
| scale_ind : int | |||||
| the scale index | |||||
| n_iters : int, optional | |||||
| number of iterations of refinement, by default 15 | |||||
| lr : float, optional | |||||
| learning rate, by default 0.002 | |||||
| Returns | |||||
| ------- | |||||
| torch.Tensor | |||||
| inpainted image | |||||
| """ | |||||
| masked_image = image * (1 - mask) | |||||
| masked_image = torch.cat([masked_image, mask], dim=1) | |||||
| mask = mask.repeat(1, 3, 1, 1) | |||||
| if ref_lower_res is not None: | |||||
| ref_lower_res = ref_lower_res.detach() | |||||
| with torch.no_grad(): | |||||
| z1, z2 = forward_front(masked_image) | |||||
| # Inference | |||||
| mask = mask.to(devices[-1]) | |||||
| ekernel = torch.from_numpy( | |||||
| cv2.getStructuringElement(cv2.MORPH_ELLIPSE, | |||||
| (15, 15)).astype(bool)).float() | |||||
| ekernel = ekernel.to(devices[-1]) | |||||
| image = image.to(devices[-1]) | |||||
| z1, z2 = z1.detach().to(devices[0]), z2.detach().to(devices[0]) | |||||
| z1.requires_grad, z2.requires_grad = True, True | |||||
| optimizer = Adam([z1, z2], lr=lr) | |||||
| pbar = tqdm(range(n_iters), leave=False) | |||||
| for idi in pbar: | |||||
| optimizer.zero_grad() | |||||
| input_feat = (z1, z2) | |||||
| for idd, forward_rear in enumerate(forward_rears): | |||||
| output_feat = forward_rear(input_feat) | |||||
| if idd < len(devices) - 1: | |||||
| midz1, midz2 = output_feat | |||||
| midz1, midz2 = midz1.to(devices[idd + 1]), midz2.to( | |||||
| devices[idd + 1]) | |||||
| input_feat = (midz1, midz2) | |||||
| else: | |||||
| pred = output_feat | |||||
| if ref_lower_res is None: | |||||
| break | |||||
| losses = {} | |||||
| # scaled loss with downsampler | |||||
| pred_downscaled = _pyrdown(pred[:, :, :orig_shape[0], :orig_shape[1]]) | |||||
| mask_downscaled = _pyrdown_mask( | |||||
| mask[:, :1, :orig_shape[0], :orig_shape[1]], | |||||
| blur_mask=False, | |||||
| round_up=False) | |||||
| mask_downscaled = _erode_mask(mask_downscaled, ekernel=ekernel) | |||||
| mask_downscaled = mask_downscaled.repeat(1, 3, 1, 1) | |||||
| losses['ms_l1'] = _l1_loss( | |||||
| pred, | |||||
| pred_downscaled, | |||||
| ref_lower_res, | |||||
| mask, | |||||
| mask_downscaled, | |||||
| image, | |||||
| on_pred=True) | |||||
| loss = sum(losses.values()) | |||||
| pbar.set_description( | |||||
| 'Refining scale {} using scale {} ...current loss: {:.4f}'.format( | |||||
| scale_ind + 1, scale_ind, loss.item())) | |||||
| if idi < n_iters - 1: | |||||
| loss.backward() | |||||
| optimizer.step() | |||||
| del pred_downscaled | |||||
| del loss | |||||
| del pred | |||||
| # "pred" is the prediction after Plug-n-Play module | |||||
| inpainted = mask * pred + (1 - mask) * image | |||||
| inpainted = inpainted.detach().cpu() | |||||
| return inpainted | |||||
| def _get_image_mask_pyramid(batch: dict, min_side: int, max_scales: int, | |||||
| px_budget: int): | |||||
| """Build the image mask pyramid | |||||
| Parameters | |||||
| ---------- | |||||
| batch : dict | |||||
| batch containing image, mask, etc | |||||
| min_side : int | |||||
| minimum side length to limit the number of scales of the pyramid | |||||
| max_scales : int | |||||
| maximum number of scales allowed | |||||
| px_budget : int | |||||
| the product H*W cannot exceed this budget, because of resource constraints | |||||
| Returns | |||||
| ------- | |||||
| tuple | |||||
| image-mask pyramid in the form of list of images and list of masks | |||||
| """ | |||||
| assert batch['image'].shape[0] == 1, 'refiner works only on batches of size 1!' | |||||
| h, w = batch['unpad_to_size'] | |||||
| h, w = h[0].item(), w[0].item() | |||||
| image = batch['image'][..., :h, :w] | |||||
| mask = batch['mask'][..., :h, :w] | |||||
| if h * w > px_budget: | |||||
| # resize | |||||
| ratio = np.sqrt(px_budget / float(h * w)) | |||||
| h_orig, w_orig = h, w | |||||
| h, w = int(h * ratio), int(w * ratio) | |||||
| print(f'Original image too large for refinement! ' | |||||
| f'Resizing {(h_orig, w_orig)} to {(h, w)}...') | |||||
| image = resize( | |||||
| image, (h, w), interpolation='bilinear', align_corners=False) | |||||
| mask = resize( | |||||
| mask, (h, w), interpolation='bilinear', align_corners=False) | |||||
| mask[mask > 1e-8] = 1 | |||||
| breadth = min(h, w) | |||||
| n_scales = min(1 + int(round(max(0, np.log2(breadth / min_side)))), | |||||
| max_scales) | |||||
| ls_images = [] | |||||
| ls_masks = [] | |||||
| ls_images.append(image) | |||||
| ls_masks.append(mask) | |||||
| for _ in range(n_scales - 1): | |||||
| image_p = _pyrdown(ls_images[-1]) | |||||
| mask_p = _pyrdown_mask(ls_masks[-1]) | |||||
| ls_images.append(image_p) | |||||
| ls_masks.append(mask_p) | |||||
| # reverse the lists because we want the lowest resolution image as index 0 | |||||
| return ls_images[::-1], ls_masks[::-1] | |||||
| def refine_predict(batch: dict, inpainter: nn.Module, gpu_ids: str, | |||||
| modulo: int, n_iters: int, lr: float, min_side: int, | |||||
| max_scales: int, px_budget: int): | |||||
| """Refines the inpainting of the network | |||||
| Parameters | |||||
| ---------- | |||||
| batch : dict | |||||
| image-mask batch, currently we assume the batchsize to be 1 | |||||
| inpainter : nn.Module | |||||
| the inpainting neural network | |||||
| gpu_ids : str | |||||
| the GPU ids of the machine to use. If only single GPU, use: "0," | |||||
| modulo : int | |||||
| pad the image to ensure dimension % modulo == 0 | |||||
| n_iters : int | |||||
| number of iterations of refinement for each scale | |||||
| lr : float | |||||
| learning rate | |||||
| min_side : int | |||||
| all sides of image on all scales should be >= min_side / sqrt(2) | |||||
| max_scales : int | |||||
| max number of downscaling scales for the image-mask pyramid | |||||
| px_budget : int | |||||
| pixels budget. Any image will be resized to satisfy height*width <= px_budget | |||||
| Returns | |||||
| ------- | |||||
| torch.Tensor | |||||
| inpainted image of size (1,3,H,W) | |||||
| """ | |||||
| inpainter = inpainter.model | |||||
| assert not inpainter.training | |||||
| assert not inpainter.add_noise_kwargs | |||||
| assert inpainter.concat_mask | |||||
| gpu_ids = [ | |||||
| f'cuda:{gpuid}' for gpuid in gpu_ids.replace(' ', '').split(',') | |||||
| if gpuid.isdigit() | |||||
| ] | |||||
| n_resnet_blocks = 0 | |||||
| first_resblock_ind = 0 | |||||
| found_first_resblock = False | |||||
| for idl in range(len(inpainter.generator.model)): | |||||
| if isinstance(inpainter.generator.model[idl], FFCResnetBlock): | |||||
| n_resnet_blocks += 1 | |||||
| found_first_resblock = True | |||||
| elif not found_first_resblock: | |||||
| first_resblock_ind += 1 | |||||
| resblocks_per_gpu = n_resnet_blocks // len(gpu_ids) | |||||
| devices = [torch.device(gpu_id) for gpu_id in gpu_ids] | |||||
| # split the model into front, and rear parts | |||||
| forward_front = inpainter.generator.model[0:first_resblock_ind] | |||||
| forward_front.to(devices[0]) | |||||
| forward_rears = [] | |||||
| for idd in range(len(gpu_ids)): | |||||
| if idd < len(gpu_ids) - 1: | |||||
| forward_rears.append( | |||||
| inpainter.generator.model[first_resblock_ind | |||||
| + resblocks_per_gpu | |||||
| * (idd):first_resblock_ind | |||||
| + resblocks_per_gpu * (idd + 1)]) | |||||
| else: | |||||
| forward_rears.append( | |||||
| inpainter.generator.model[first_resblock_ind | |||||
| + resblocks_per_gpu * (idd):]) | |||||
| forward_rears[idd].to(devices[idd]) | |||||
| ls_images, ls_masks = _get_image_mask_pyramid(batch, min_side, max_scales, | |||||
| px_budget) | |||||
| image_inpainted = None | |||||
| for ids, (image, mask) in enumerate(zip(ls_images, ls_masks)): | |||||
| orig_shape = image.shape[2:] | |||||
| image = pad_tensor_to_modulo(image, modulo) | |||||
| mask = pad_tensor_to_modulo(mask, modulo) | |||||
| mask[mask >= 1e-8] = 1.0 | |||||
| mask[mask < 1e-8] = 0.0 | |||||
| image, mask = move_to_device(image, devices[0]), move_to_device( | |||||
| mask, devices[0]) | |||||
| if image_inpainted is not None: | |||||
| image_inpainted = move_to_device(image_inpainted, devices[-1]) | |||||
| image_inpainted = _infer(image, mask, forward_front, forward_rears, | |||||
| image_inpainted, orig_shape, devices, ids, | |||||
| n_iters, lr) | |||||
| image_inpainted = image_inpainted[:, :, :orig_shape[0], :orig_shape[1]] | |||||
| # detach everything to save resources | |||||
| image = image.detach().cpu() | |||||
| mask = mask.detach().cpu() | |||||
| return image_inpainted | |||||
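| # A minimal calling sketch for refine_predict; the hyperparameter values | |||||
| # below follow the published LaMa refinement defaults and are assumptions, | |||||
| # not values read from this repository's configs. | |||||
| def _refine_example(batch, inpainter): | |||||
|     # batch: {'image': (1,3,H,W), 'mask': (1,1,H,W), 'unpad_to_size': (h, w)} | |||||
|     return refine_predict( | |||||
|         batch, | |||||
|         inpainter, | |||||
|         gpu_ids='0,', | |||||
|         modulo=8, | |||||
|         n_iters=15, | |||||
|         lr=0.002, | |||||
|         min_side=512, | |||||
|         max_scales=3, | |||||
|         px_budget=1800000) | |||||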
| @@ -11,6 +11,7 @@ if TYPE_CHECKING: | |||||
| from .image_instance_segmentation_coco_dataset import ImageInstanceSegmentationCocoDataset | from .image_instance_segmentation_coco_dataset import ImageInstanceSegmentationCocoDataset | ||||
| from .movie_scene_segmentation import MovieSceneSegmentationDataset | from .movie_scene_segmentation import MovieSceneSegmentationDataset | ||||
| from .video_summarization_dataset import VideoSummarizationDataset | from .video_summarization_dataset import VideoSummarizationDataset | ||||
| from .image_inpainting import ImageInpaintingDataset | |||||
| from .passage_ranking_dataset import PassageRankingDataset | from .passage_ranking_dataset import PassageRankingDataset | ||||
| else: | else: | ||||
| @@ -24,6 +25,7 @@ else: | |||||
| ['ImageInstanceSegmentationCocoDataset'], | ['ImageInstanceSegmentationCocoDataset'], | ||||
| 'video_summarization_dataset': ['VideoSummarizationDataset'], | 'video_summarization_dataset': ['VideoSummarizationDataset'], | ||||
| 'movie_scene_segmentation': ['MovieSceneSegmentationDataset'], | 'movie_scene_segmentation': ['MovieSceneSegmentationDataset'], | ||||
| 'image_inpainting': ['ImageInpaintingDataset'], | |||||
| } | } | ||||
| import sys | import sys | ||||
| @@ -0,0 +1,2 @@ | |||||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||||
| from .image_inpainting_dataset import ImageInpaintingDataset | |||||
| @@ -0,0 +1,100 @@ | |||||
| """ | |||||
| The implementation is borrowed from LaMa, | |||||
| publicly available at https://github.com/saic-mdal/lama | |||||
| """ | |||||
| import imgaug.augmenters as iaa | |||||
| from albumentations import DualIAATransform, to_tuple | |||||
| class IAAAffine2(DualIAATransform): | |||||
| """Place a regular grid of points on the input and randomly move the neighbourhood of these point around | |||||
| via affine transformations. | |||||
| Note: This class introduce interpolation artifacts to mask if it has values other than {0;1} | |||||
| Args: | |||||
| p (float): probability of applying the transform. Default: 0.5. | |||||
| Targets: | |||||
| image, mask | |||||
| """ | |||||
| def __init__( | |||||
| self, | |||||
| scale=(0.7, 1.3), | |||||
| translate_percent=None, | |||||
| translate_px=None, | |||||
| rotate=0.0, | |||||
| shear=(-0.1, 0.1), | |||||
| order=1, | |||||
| cval=0, | |||||
| mode='reflect', | |||||
| always_apply=False, | |||||
| p=0.5, | |||||
| ): | |||||
| super(IAAAffine2, self).__init__(always_apply, p) | |||||
| self.scale = dict(x=scale, y=scale) | |||||
| self.translate_percent = to_tuple(translate_percent, 0) | |||||
| self.translate_px = to_tuple(translate_px, 0) | |||||
| self.rotate = to_tuple(rotate) | |||||
| self.shear = dict(x=shear, y=shear) | |||||
| self.order = order | |||||
| self.cval = cval | |||||
| self.mode = mode | |||||
| @property | |||||
| def processor(self): | |||||
| return iaa.Affine( | |||||
| self.scale, | |||||
| self.translate_percent, | |||||
| self.translate_px, | |||||
| self.rotate, | |||||
| self.shear, | |||||
| self.order, | |||||
| self.cval, | |||||
| self.mode, | |||||
| ) | |||||
| def get_transform_init_args_names(self): | |||||
| return ('scale', 'translate_percent', 'translate_px', 'rotate', | |||||
| 'shear', 'order', 'cval', 'mode') | |||||
| class IAAPerspective2(DualIAATransform): | |||||
| """Perform a random four point perspective transform of the input. | |||||
| Note: This class introduces interpolation artifacts to the mask if it has values other than {0, 1} | |||||
| Args: | |||||
| scale ((float, float)): standard deviation of the normal distributions. These are used to sample | |||||
| the random distances of the subimage's corners from the full image's corners. Default: (0.05, 0.1). | |||||
| p (float): probability of applying the transform. Default: 0.5. | |||||
| Targets: | |||||
| image, mask | |||||
| """ | |||||
| def __init__(self, | |||||
| scale=(0.05, 0.1), | |||||
| keep_size=True, | |||||
| always_apply=False, | |||||
| p=0.5, | |||||
| order=1, | |||||
| cval=0, | |||||
| mode='replicate'): | |||||
| super(IAAPerspective2, self).__init__(always_apply, p) | |||||
| self.scale = to_tuple(scale, 1.0) | |||||
| self.keep_size = keep_size | |||||
| self.cval = cval | |||||
| self.mode = mode | |||||
| @property | |||||
| def processor(self): | |||||
| return iaa.PerspectiveTransform( | |||||
| self.scale, | |||||
| keep_size=self.keep_size, | |||||
| mode=self.mode, | |||||
| cval=self.cval) | |||||
| def get_transform_init_args_names(self): | |||||
| return ('scale', 'keep_size') | |||||
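| # A minimal composition sketch, assuming parameter ranges similar to LaMa's | |||||
| # default geometric augmentations; the exact values here are illustrative. | |||||
| def _build_geometric_augs(): | |||||
|     import albumentations as A | |||||
|     return A.Compose([ | |||||
|         IAAPerspective2(scale=(0.0, 0.06)), | |||||
|         IAAAffine2(scale=(0.7, 1.3), rotate=(-40, 40), shear=(-0.1, 0.1)), | |||||
|     ]) | |||||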
| @@ -0,0 +1,337 @@ | |||||
| """ | |||||
| Part of the implementation is borrowed and modified from LaMa, | |||||
| publicly available at https://github.com/saic-mdal/lama | |||||
| """ | |||||
| import glob | |||||
| import os | |||||
| import os.path as osp | |||||
| import random | |||||
| from enum import Enum | |||||
| import albumentations as A | |||||
| import cv2 | |||||
| import json | |||||
| import numpy as np | |||||
| import torch | |||||
| from modelscope.metainfo import Models | |||||
| from modelscope.msdatasets.task_datasets.builder import TASK_DATASETS | |||||
| from modelscope.msdatasets.task_datasets.torch_base_dataset import \ | |||||
| TorchTaskDataset | |||||
| from modelscope.utils.constant import Tasks | |||||
| from modelscope.utils.logger import get_logger | |||||
| from .aug import IAAAffine2, IAAPerspective2 | |||||
| LOGGER = get_logger() | |||||
| class LinearRamp: | |||||
| def __init__(self, start_value=0, end_value=1, start_iter=-1, end_iter=0): | |||||
| self.start_value = start_value | |||||
| self.end_value = end_value | |||||
| self.start_iter = start_iter | |||||
| self.end_iter = end_iter | |||||
| def __call__(self, i): | |||||
| if i < self.start_iter: | |||||
| return self.start_value | |||||
| if i >= self.end_iter: | |||||
| return self.end_value | |||||
| part = (i - self.start_iter) / (self.end_iter - self.start_iter) | |||||
| return self.start_value * (1 - part) + self.end_value * part | |||||
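| # A mini-example of the ramp, with assumed endpoints: | |||||
| # LinearRamp(0, 1, 0, 10)(5) evaluates to 0.5. | |||||
| def _ramp_example(): | |||||
|     ramp = LinearRamp(start_value=0, end_value=1, start_iter=0, end_iter=10) | |||||
|     return [ramp(i) for i in (0, 5, 10)]  # [0.0, 0.5, 1] | |||||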
| class DrawMethod(Enum): | |||||
| LINE = 'line' | |||||
| CIRCLE = 'circle' | |||||
| SQUARE = 'square' | |||||
| def make_random_superres_mask(shape, | |||||
| min_step=2, | |||||
| max_step=4, | |||||
| min_width=1, | |||||
| max_width=3): | |||||
| height, width = shape | |||||
| mask = np.zeros((height, width), np.float32) | |||||
| step_x = np.random.randint(min_step, max_step + 1) | |||||
| width_x = np.random.randint(min_width, min(step_x, max_width + 1)) | |||||
| offset_x = np.random.randint(0, step_x) | |||||
| step_y = np.random.randint(min_step, max_step + 1) | |||||
| width_y = np.random.randint(min_width, min(step_y, max_width + 1)) | |||||
| offset_y = np.random.randint(0, step_y) | |||||
| for dy in range(width_y): | |||||
| mask[offset_y + dy::step_y] = 1 | |||||
| for dx in range(width_x): | |||||
| mask[:, offset_x + dx::step_x] = 1 | |||||
| return mask[None, ...] | |||||
| class RandomSuperresMaskGenerator: | |||||
| def __init__(self, **kwargs): | |||||
| self.kwargs = kwargs | |||||
| def __call__(self, img, iter_i=None): | |||||
| return make_random_superres_mask(img.shape[1:], **self.kwargs) | |||||
| def make_random_rectangle_mask(shape, | |||||
| margin=10, | |||||
| bbox_min_size=30, | |||||
| bbox_max_size=100, | |||||
| min_times=0, | |||||
| max_times=3): | |||||
| height, width = shape | |||||
| mask = np.zeros((height, width), np.float32) | |||||
| bbox_max_size = min(bbox_max_size, height - margin * 2, width - margin * 2) | |||||
| times = np.random.randint(min_times, max_times + 1) | |||||
| for i in range(times): | |||||
| box_width = np.random.randint(bbox_min_size, bbox_max_size) | |||||
| box_height = np.random.randint(bbox_min_size, bbox_max_size) | |||||
| start_x = np.random.randint(margin, width - margin - box_width + 1) | |||||
| start_y = np.random.randint(margin, height - margin - box_height + 1) | |||||
| mask[start_y:start_y + box_height, start_x:start_x + box_width] = 1 | |||||
| return mask[None, ...] | |||||
| class RandomRectangleMaskGenerator: | |||||
| def __init__(self, | |||||
| margin=10, | |||||
| bbox_min_size=30, | |||||
| bbox_max_size=100, | |||||
| min_times=0, | |||||
| max_times=3, | |||||
| ramp_kwargs=None): | |||||
| self.margin = margin | |||||
| self.bbox_min_size = bbox_min_size | |||||
| self.bbox_max_size = bbox_max_size | |||||
| self.min_times = min_times | |||||
| self.max_times = max_times | |||||
| self.ramp = LinearRamp( | |||||
| **ramp_kwargs) if ramp_kwargs is not None else None | |||||
| def __call__(self, img, iter_i=None, raw_image=None): | |||||
| coef = self.ramp(iter_i) if (self.ramp is not None) and ( | |||||
| iter_i is not None) else 1 | |||||
| cur_bbox_max_size = int(self.bbox_min_size + 1 | |||||
| + (self.bbox_max_size - self.bbox_min_size) | |||||
| * coef) | |||||
| cur_max_times = int(self.min_times | |||||
| + (self.max_times - self.min_times) * coef) | |||||
| return make_random_rectangle_mask( | |||||
| img.shape[1:], | |||||
| margin=self.margin, | |||||
| bbox_min_size=self.bbox_min_size, | |||||
| bbox_max_size=cur_bbox_max_size, | |||||
| min_times=self.min_times, | |||||
| max_times=cur_max_times) | |||||
| def make_random_irregular_mask(shape, | |||||
| max_angle=4, | |||||
| max_len=60, | |||||
| max_width=20, | |||||
| min_times=0, | |||||
| max_times=10, | |||||
| draw_method=DrawMethod.LINE): | |||||
| draw_method = DrawMethod(draw_method) | |||||
| height, width = shape | |||||
| mask = np.zeros((height, width), np.float32) | |||||
| times = np.random.randint(min_times, max_times + 1) | |||||
| for i in range(times): | |||||
| start_x = np.random.randint(width) | |||||
| start_y = np.random.randint(height) | |||||
| for j in range(1 + np.random.randint(5)): | |||||
| angle = 0.01 + np.random.randint(max_angle) | |||||
| if i % 2 == 0: | |||||
| angle = 2 * np.pi - angle | |||||
| length = 10 + np.random.randint(max_len) | |||||
| brush_w = 5 + np.random.randint(max_width) | |||||
| end_x = np.clip( | |||||
| (start_x + length * np.sin(angle)).astype(np.int32), 0, width) | |||||
| end_y = np.clip( | |||||
| (start_y + length * np.cos(angle)).astype(np.int32), 0, height) | |||||
| if draw_method == DrawMethod.LINE: | |||||
| cv2.line(mask, (start_x, start_y), (end_x, end_y), 1.0, | |||||
| brush_w) | |||||
| elif draw_method == DrawMethod.CIRCLE: | |||||
| cv2.circle( | |||||
| mask, (start_x, start_y), | |||||
| radius=brush_w, | |||||
| color=1., | |||||
| thickness=-1) | |||||
| elif draw_method == DrawMethod.SQUARE: | |||||
| radius = brush_w // 2 | |||||
| mask[start_y - radius:start_y + radius, | |||||
| start_x - radius:start_x + radius] = 1 | |||||
| start_x, start_y = end_x, end_y | |||||
| return mask[None, ...] | |||||
| class RandomIrregularMaskGenerator: | |||||
| def __init__(self, | |||||
| max_angle=4, | |||||
| max_len=60, | |||||
| max_width=20, | |||||
| min_times=0, | |||||
| max_times=10, | |||||
| ramp_kwargs=None, | |||||
| draw_method=DrawMethod.LINE): | |||||
| self.max_angle = max_angle | |||||
| self.max_len = max_len | |||||
| self.max_width = max_width | |||||
| self.min_times = min_times | |||||
| self.max_times = max_times | |||||
| self.draw_method = draw_method | |||||
| self.ramp = LinearRamp( | |||||
| **ramp_kwargs) if ramp_kwargs is not None else None | |||||
| def __call__(self, img, iter_i=None, raw_image=None): | |||||
| coef = self.ramp(iter_i) if (self.ramp is not None) and ( | |||||
| iter_i is not None) else 1 | |||||
| cur_max_len = int(max(1, self.max_len * coef)) | |||||
| cur_max_width = int(max(1, self.max_width * coef)) | |||||
| cur_max_times = int(self.min_times + 1 | |||||
| + (self.max_times - self.min_times) * coef) | |||||
| return make_random_irregular_mask( | |||||
| img.shape[1:], | |||||
| max_angle=self.max_angle, | |||||
| max_len=cur_max_len, | |||||
| max_width=cur_max_width, | |||||
| min_times=self.min_times, | |||||
| max_times=cur_max_times, | |||||
| draw_method=self.draw_method) | |||||
| class MixedMaskGenerator: | |||||
| def __init__(self, | |||||
| irregular_proba=1 / 3, | |||||
| irregular_kwargs=None, | |||||
| box_proba=1 / 3, | |||||
| box_kwargs=None, | |||||
| segm_proba=1 / 3, | |||||
| segm_kwargs=None, | |||||
| squares_proba=0, | |||||
| squares_kwargs=None, | |||||
| superres_proba=0, | |||||
| superres_kwargs=None, | |||||
| outpainting_proba=0, | |||||
| outpainting_kwargs=None, | |||||
| invert_proba=0): | |||||
| self.probas = [] | |||||
| self.gens = [] | |||||
| if irregular_proba > 0: | |||||
| self.probas.append(irregular_proba) | |||||
| if irregular_kwargs is None: | |||||
| irregular_kwargs = {} | |||||
| else: | |||||
| irregular_kwargs = dict(irregular_kwargs) | |||||
| irregular_kwargs['draw_method'] = DrawMethod.LINE | |||||
| self.gens.append(RandomIrregularMaskGenerator(**irregular_kwargs)) | |||||
| if box_proba > 0: | |||||
| self.probas.append(box_proba) | |||||
| if box_kwargs is None: | |||||
| box_kwargs = {} | |||||
| self.gens.append(RandomRectangleMaskGenerator(**box_kwargs)) | |||||
| if squares_proba > 0: | |||||
| self.probas.append(squares_proba) | |||||
| if squares_kwargs is None: | |||||
| squares_kwargs = {} | |||||
| else: | |||||
| squares_kwargs = dict(squares_kwargs) | |||||
| squares_kwargs['draw_method'] = DrawMethod.SQUARE | |||||
| self.gens.append(RandomIrregularMaskGenerator(**squares_kwargs)) | |||||
| if superres_proba > 0: | |||||
| self.probas.append(superres_proba) | |||||
| if superres_kwargs is None: | |||||
| superres_kwargs = {} | |||||
| self.gens.append(RandomSuperresMaskGenerator(**superres_kwargs)) | |||||
| self.probas = np.array(self.probas, dtype='float32') | |||||
| self.probas /= self.probas.sum() | |||||
| self.invert_proba = invert_proba | |||||
| def __call__(self, img, iter_i=None, raw_image=None): | |||||
| kind = np.random.choice(len(self.probas), p=self.probas) | |||||
| gen = self.gens[kind] | |||||
| result = gen(img, iter_i=iter_i, raw_image=raw_image) | |||||
| if self.invert_proba > 0 and random.random() < self.invert_proba: | |||||
| result = 1 - result | |||||
| return result | |||||
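| # Example (minimal sketch; the arguments are illustrative): | |||||
| #   gen = MixedMaskGenerator(irregular_proba=0.5, box_proba=0.5) | |||||
| #   mask = gen(np.zeros((3, 256, 256), np.float32))  # (1, 256, 256) | |||||
| # Note that segm_proba and outpainting_proba are accepted above but have | |||||
| # no matching generator branch in this version, so their probability mass | |||||
| # is silently dropped by the renormalization. | |||||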
| def get_transforms(test_mode, out_size): | |||||
| if not test_mode: | |||||
| transform = A.Compose([ | |||||
| IAAPerspective2(scale=(0.0, 0.06)), | |||||
| IAAAffine2(scale=(0.7, 1.3), rotate=(-40, 40), shear=(-0.1, 0.1)), | |||||
| A.PadIfNeeded(min_height=out_size, min_width=out_size), | |||||
| A.OpticalDistortion(), | |||||
| A.RandomCrop(height=out_size, width=out_size), | |||||
| A.HorizontalFlip(), | |||||
| A.CLAHE(), | |||||
| A.RandomBrightnessContrast( | |||||
| brightness_limit=0.2, contrast_limit=0.2), | |||||
| A.HueSaturationValue( | |||||
| hue_shift_limit=5, sat_shift_limit=30, val_shift_limit=5), | |||||
| A.ToFloat() | |||||
| ]) | |||||
| else: | |||||
| transform = A.Compose([ | |||||
| A.PadIfNeeded(min_height=out_size, min_width=out_size), | |||||
| A.CenterCrop(height=out_size, width=out_size), | |||||
| A.ToFloat() | |||||
| ]) | |||||
| return transform | |||||
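| # Both branches end with A.ToFloat(), so the dataset below receives HWC | |||||
| # float32 images scaled to [0, 1] before it transposes them to CHW. | |||||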
| @TASK_DATASETS.register_module( | |||||
| Tasks.image_inpainting, module_name=Models.image_inpainting) | |||||
| class ImageInpaintingDataset(TorchTaskDataset): | |||||
| def __init__(self, **kwargs): | |||||
| split_config = kwargs['split_config'] | |||||
| LOGGER.info(kwargs) | |||||
| mode = kwargs.get('test_mode', False) | |||||
| self.data_root = next(iter(split_config.values())) | |||||
| if not osp.exists(self.data_root): | |||||
| self.data_root = osp.dirname(self.data_root) | |||||
| assert osp.exists(self.data_root) | |||||
| mask_gen_kwargs = kwargs.get('mask_gen_kwargs', {}) | |||||
| out_size = kwargs.get('out_size', 256) | |||||
| self.mask_generator = MixedMaskGenerator(**mask_gen_kwargs) | |||||
| self.transform = get_transforms(mode, out_size) | |||||
| self.in_files = sorted( | |||||
| list( | |||||
| glob.glob( | |||||
| osp.join(self.data_root, '**', '*.jpg'), recursive=True)) | |||||
| + list( | |||||
| glob.glob( | |||||
| osp.join(self.data_root, '**', '*.png'), recursive=True))) | |||||
| self.iter_i = 0 | |||||
| def __len__(self): | |||||
| return len(self.in_files) | |||||
| def __getitem__(self, index): | |||||
| path = self.in_files[index] | |||||
| img = cv2.imread(path) | |||||
| img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) | |||||
| img = self.transform(image=img)['image'] | |||||
| img = np.transpose(img, (2, 0, 1)) | |||||
| # TODO: maybe generate mask before augmentations? slower, but better for segmentation-based masks | |||||
| mask = self.mask_generator(img, iter_i=self.iter_i) | |||||
| self.iter_i += 1 | |||||
| return dict(image=img, mask=mask) | |||||
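| # Example of the resulting sample format (minimal sketch; the path and | |||||
| # kwargs are illustrative): | |||||
| #   ds = ImageInpaintingDataset(split_config={'train': '/data/places'}, | |||||
| #                               test_mode=False, out_size=256) | |||||
| #   item = ds[0] | |||||
| #   item['image']  # float32, (3, 256, 256), values in [0, 1] | |||||
| #   item['mask']   # float32, (1, 256, 256), 1 = region to inpaint | |||||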
| @@ -177,6 +177,7 @@ TASK_OUTPUTS = { | |||||
| Tasks.image_denoising: [OutputKeys.OUTPUT_IMG], | Tasks.image_denoising: [OutputKeys.OUTPUT_IMG], | ||||
| Tasks.image_portrait_enhancement: [OutputKeys.OUTPUT_IMG], | Tasks.image_portrait_enhancement: [OutputKeys.OUTPUT_IMG], | ||||
| Tasks.crowd_counting: [OutputKeys.SCORES, OutputKeys.OUTPUT_IMG], | Tasks.crowd_counting: [OutputKeys.SCORES, OutputKeys.OUTPUT_IMG], | ||||
| Tasks.image_inpainting: [OutputKeys.OUTPUT_IMG], | |||||
| # image generation task result for a single image | # image generation task result for a single image | ||||
| # {"output_img": np.array with shape (h, w, 3)} | # {"output_img": np.array with shape (h, w, 3)} | ||||
| @@ -181,6 +181,8 @@ DEFAULT_MODEL_FOR_PIPELINE = { | |||||
| 'damo/cv_resnet50-bert_video-scene-segmentation_movienet'), | 'damo/cv_resnet50-bert_video-scene-segmentation_movienet'), | ||||
| Tasks.shop_segmentation: (Pipelines.shop_segmentation, | Tasks.shop_segmentation: (Pipelines.shop_segmentation, | ||||
| 'damo/cv_vitb16_segmentation_shop-seg'), | 'damo/cv_vitb16_segmentation_shop-seg'), | ||||
| Tasks.image_inpainting: (Pipelines.image_inpainting, | |||||
| 'damo/cv_fft_inpainting_lama'), | |||||
| Tasks.video_inpainting: (Pipelines.video_inpainting, | Tasks.video_inpainting: (Pipelines.video_inpainting, | ||||
| 'damo/cv_video-inpainting'), | 'damo/cv_video-inpainting'), | ||||
| Tasks.hand_static: (Pipelines.hand_static, | Tasks.hand_static: (Pipelines.hand_static, | ||||
| @@ -35,6 +35,7 @@ if TYPE_CHECKING: | |||||
| from .image_super_resolution_pipeline import ImageSuperResolutionPipeline | from .image_super_resolution_pipeline import ImageSuperResolutionPipeline | ||||
| from .image_to_image_generate_pipeline import Image2ImageGenerationPipeline | from .image_to_image_generate_pipeline import Image2ImageGenerationPipeline | ||||
| from .image_to_image_translation_pipeline import Image2ImageTranslationPipeline | from .image_to_image_translation_pipeline import Image2ImageTranslationPipeline | ||||
| from .image_inpainting_pipeline import ImageInpaintingPipeline | |||||
| from .product_retrieval_embedding_pipeline import ProductRetrievalEmbeddingPipeline | from .product_retrieval_embedding_pipeline import ProductRetrievalEmbeddingPipeline | ||||
| from .realtime_object_detection_pipeline import RealtimeObjectDetectionPipeline | from .realtime_object_detection_pipeline import RealtimeObjectDetectionPipeline | ||||
| from .live_category_pipeline import LiveCategoryPipeline | from .live_category_pipeline import LiveCategoryPipeline | ||||
| @@ -99,6 +100,7 @@ else: | |||||
| 'live_category_pipeline': ['LiveCategoryPipeline'], | 'live_category_pipeline': ['LiveCategoryPipeline'], | ||||
| 'image_to_image_generation_pipeline': | 'image_to_image_generation_pipeline': | ||||
| ['Image2ImageGenerationPipeline'], | ['Image2ImageGenerationPipeline'], | ||||
| 'image_inpainting_pipeline': ['ImageInpaintingPipeline'], | |||||
| 'ocr_detection_pipeline': ['OCRDetectionPipeline'], | 'ocr_detection_pipeline': ['OCRDetectionPipeline'], | ||||
| 'ocr_recognition_pipeline': ['OCRRecognitionPipeline'], | 'ocr_recognition_pipeline': ['OCRRecognitionPipeline'], | ||||
| 'skin_retouching_pipeline': ['SkinRetouchingPipeline'], | 'skin_retouching_pipeline': ['SkinRetouchingPipeline'], | ||||
| @@ -0,0 +1,146 @@ | |||||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||||
| from typing import Any, Dict | |||||
| import cv2 | |||||
| import numpy as np | |||||
| import PIL | |||||
| import torch | |||||
| import torch.nn as nn | |||||
| from torch.utils.data._utils.collate import default_collate | |||||
| from modelscope.metainfo import Pipelines | |||||
| from modelscope.models.cv.image_inpainting import FFTInpainting | |||||
| from modelscope.models.cv.image_inpainting.refinement import refine_predict | |||||
| from modelscope.outputs import OutputKeys | |||||
| from modelscope.pipelines.base import Input, Pipeline | |||||
| from modelscope.pipelines.builder import PIPELINES | |||||
| from modelscope.preprocessors.image import LoadImage | |||||
| from modelscope.utils.constant import Tasks | |||||
| from modelscope.utils.logger import get_logger | |||||
| logger = get_logger() | |||||
| @PIPELINES.register_module( | |||||
| Tasks.image_inpainting, module_name=Pipelines.image_inpainting) | |||||
| class ImageInpaintingPipeline(Pipeline): | |||||
| def __init__(self, | |||||
| model: str, | |||||
| pad_out_to_modulo=8, | |||||
| refine=False, | |||||
| **kwargs): | |||||
| """ | |||||
| model: model id on modelscope hub. | |||||
| """ | |||||
| assert isinstance(model, str), 'model must be a single str' | |||||
| super().__init__(model=model, auto_collate=False, **kwargs) | |||||
| self.refine = refine | |||||
| logger.info(f'loading model from dir {model}') | |||||
| self.infer_model = FFTInpainting(model, predict_only=True) | |||||
| if not self.refine: | |||||
| self.infer_model.to(self.device) | |||||
| self.infer_model.eval() | |||||
| logger.info(f'loading model done, refinement is set to {self.refine}') | |||||
| self.pad_out_to_modulo = pad_out_to_modulo | |||||
| def move_to_device(self, obj, device): | |||||
| if isinstance(obj, nn.Module): | |||||
| return obj.to(device) | |||||
| if torch.is_tensor(obj): | |||||
| return obj.to(device) | |||||
| if isinstance(obj, (tuple, list)): | |||||
| return [self.move_to_device(el, device) for el in obj] | |||||
| if isinstance(obj, dict): | |||||
| return { | |||||
| name: self.move_to_device(val, device) | |||||
| for name, val in obj.items() | |||||
| } | |||||
| raise ValueError(f'Unexpected type {type(obj)}') | |||||
| def transforms(self, img): | |||||
| if img.ndim == 3: | |||||
| img = np.transpose(img, (2, 0, 1)) | |||||
| out_img = img.astype('float32') / 255 | |||||
| return out_img | |||||
| def ceil_modulo(self, x, mod): | |||||
| if x % mod == 0: | |||||
| return x | |||||
| return (x // mod + 1) * mod | |||||
| def pad_img_to_modulo(self, img, mod): | |||||
| channels, height, width = img.shape | |||||
| out_height = self.ceil_modulo(height, mod) | |||||
| out_width = self.ceil_modulo(width, mod) | |||||
| return np.pad( | |||||
| img, ((0, 0), (0, out_height - height), (0, out_width - width)), | |||||
| mode='symmetric') | |||||
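| # e.g. ceil_modulo(300, 8) == 304, so a (3, 300, 500) image is padded on | |||||
| # the bottom/right to (3, 304, 504) with mirror ('symmetric') padding | |||||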
| def preprocess(self, input: Input) -> Dict[str, Any]: | |||||
| if isinstance(input, str): | |||||
| image_name, mask_name = input.split('+') | |||||
| img = LoadImage.convert_to_ndarray(image_name) | |||||
| img = self.transforms(img) | |||||
| mask = np.array(LoadImage(mode='L')(mask_name)['img']) | |||||
| mask = self.transforms(mask) | |||||
| elif isinstance(input, PIL.Image.Image): | |||||
| img = input.crop((0, 0, int(input.width / 2), input.height)) | |||||
| img = self.transforms(np.array(img)) | |||||
| mask = input.crop((int(input.width / 2), 0, input.width, | |||||
| input.height)).convert('L') | |||||
| mask = self.transforms(np.array(mask)) | |||||
| else: | |||||
| raise TypeError('input should be either str or PIL.Image') | |||||
| result = dict(image=img, mask=mask[None, ...]) | |||||
| if self.pad_out_to_modulo is not None and self.pad_out_to_modulo > 1: | |||||
| result['unpad_to_size'] = result['image'].shape[1:] | |||||
| result['image'] = self.pad_img_to_modulo(result['image'], | |||||
| self.pad_out_to_modulo) | |||||
| result['mask'] = self.pad_img_to_modulo(result['mask'], | |||||
| self.pad_out_to_modulo) | |||||
| # Pipeline runs forward() under torch.no_grad() by default, so we run | |||||
| # inference here in preprocess(), where gradients are still available | |||||
| # for the refinement path. | |||||
| result = self.perform_inference(result) | |||||
| return result | |||||
| def forward(self, input: Dict[str, Any]) -> Dict[str, Any]: | |||||
| return {OutputKeys.OUTPUT_IMG: input} | |||||
| def perform_inference(self, data): | |||||
| batch = default_collate([data]) | |||||
| if self.refine: | |||||
| assert 'unpad_to_size' in batch, 'Unpadded size is required for the refinement' | |||||
| assert 'cuda' in str(self.device), 'GPU is required for refinement' | |||||
| gpu_ids = str(self.device).split(':')[-1] | |||||
| cur_res = refine_predict( | |||||
| batch, | |||||
| self.infer_model, | |||||
| gpu_ids=gpu_ids, | |||||
| modulo=self.pad_out_to_modulo, | |||||
| n_iters=15, | |||||
| lr=0.002, | |||||
| min_side=512, | |||||
| max_scales=3, | |||||
| px_budget=900000) | |||||
| cur_res = cur_res[0].permute(1, 2, 0).detach().cpu().numpy() | |||||
| else: | |||||
| with torch.no_grad(): | |||||
| batch = self.move_to_device(batch, self.device) | |||||
| batch['mask'] = (batch['mask'] > 0) * 1  # binarize the mask | |||||
| batch = self.infer_model(batch) | |||||
| cur_res = batch['inpainted'][0].permute( | |||||
| 1, 2, 0).detach().cpu().numpy() | |||||
| unpad_to_size = batch.get('unpad_to_size', None) | |||||
| if unpad_to_size is not None: | |||||
| orig_height, orig_width = unpad_to_size | |||||
| cur_res = cur_res[:orig_height, :orig_width] | |||||
| cur_res = np.clip(cur_res * 255, 0, 255).astype('uint8') | |||||
| cur_res = cv2.cvtColor(cur_res, cv2.COLOR_RGB2BGR) | |||||
| return cur_res | |||||
| def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: | |||||
| return inputs | |||||
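A minimal usage sketch of the new pipeline (file names are illustrative; the 'image+mask' string convention matches preprocess above):
| import cv2 | |||||
| from modelscope.outputs import OutputKeys | |||||
| from modelscope.pipelines import pipeline | |||||
| from modelscope.utils.constant import Tasks | |||||
| inpainting = pipeline(Tasks.image_inpainting, model='damo/cv_fft_inpainting_lama') | |||||
| result = inpainting('image.png+mask.png')  # nonzero mask pixels get filled | |||||
| cv2.imwrite('result.png', result[OutputKeys.OUTPUT_IMG])  # BGR uint8 | |||||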
| @@ -9,7 +9,7 @@ if TYPE_CHECKING: | |||||
| from .builder import build_trainer | from .builder import build_trainer | ||||
| from .cv import (ImageInstanceSegmentationTrainer, | from .cv import (ImageInstanceSegmentationTrainer, | ||||
| ImagePortraitEnhancementTrainer, | ImagePortraitEnhancementTrainer, | ||||
| MovieSceneSegmentationTrainer) | |||||
| MovieSceneSegmentationTrainer, ImageInpaintingTrainer) | |||||
| from .multi_modal import CLIPTrainer | from .multi_modal import CLIPTrainer | ||||
| from .nlp import SequenceClassificationTrainer, PassageRankingTrainer | from .nlp import SequenceClassificationTrainer, PassageRankingTrainer | ||||
| from .nlp_trainer import NlpEpochBasedTrainer, VecoTrainer | from .nlp_trainer import NlpEpochBasedTrainer, VecoTrainer | ||||
| @@ -22,7 +22,8 @@ else: | |||||
| 'builder': ['build_trainer'], | 'builder': ['build_trainer'], | ||||
| 'cv': [ | 'cv': [ | ||||
| 'ImageInstanceSegmentationTrainer', | 'ImageInstanceSegmentationTrainer', | ||||
| 'ImagePortraitEnhancementTrainer', 'MovieSceneSegmentationTrainer' | |||||
| 'ImagePortraitEnhancementTrainer', 'MovieSceneSegmentationTrainer', | |||||
| 'ImageInpaintingTrainer' | |||||
| ], | ], | ||||
| 'multi_modal': ['CLIPTrainer'], | 'multi_modal': ['CLIPTrainer'], | ||||
| 'nlp': ['SequenceClassificationTrainer', 'PassageRankingTrainer'], | 'nlp': ['SequenceClassificationTrainer', 'PassageRankingTrainer'], | ||||
| @@ -8,6 +8,7 @@ if TYPE_CHECKING: | |||||
| ImageInstanceSegmentationTrainer | ImageInstanceSegmentationTrainer | ||||
| from .image_portrait_enhancement_trainer import ImagePortraitEnhancementTrainer | from .image_portrait_enhancement_trainer import ImagePortraitEnhancementTrainer | ||||
| from .movie_scene_segmentation_trainer import MovieSceneSegmentationTrainer | from .movie_scene_segmentation_trainer import MovieSceneSegmentationTrainer | ||||
| from .image_inpainting_trainer import ImageInpaintingTrainer | |||||
| else: | else: | ||||
| _import_structure = { | _import_structure = { | ||||
| @@ -15,7 +16,8 @@ else: | |||||
| ['ImageInstanceSegmentationTrainer'], | ['ImageInstanceSegmentationTrainer'], | ||||
| 'image_portrait_enhancement_trainer': | 'image_portrait_enhancement_trainer': | ||||
| ['ImagePortraitEnhancementTrainer'], | ['ImagePortraitEnhancementTrainer'], | ||||
| 'movie_scene_segmentation_trainer': ['MovieSceneSegmentationTrainer'] | |||||
| 'movie_scene_segmentation_trainer': ['MovieSceneSegmentationTrainer'], | |||||
| 'image_inpainting_trainer': ['ImageInpaintingTrainer'] | |||||
| } | } | ||||
| import sys | import sys | ||||
| @@ -0,0 +1,111 @@ | |||||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||||
| import time | |||||
| from collections.abc import Mapping | |||||
| from torch import distributed as dist | |||||
| from modelscope.metainfo import Trainers | |||||
| from modelscope.trainers.builder import TRAINERS | |||||
| from modelscope.trainers.trainer import EpochBasedTrainer | |||||
| from modelscope.utils.constant import (DEFAULT_MODEL_REVISION, ConfigFields, | |||||
| ConfigKeys, Hubs, ModeKeys, ModelFile, | |||||
| Tasks, TrainerStages) | |||||
| from modelscope.utils.data_utils import to_device | |||||
| from modelscope.utils.file_utils import func_receive_dict_inputs | |||||
| @TRAINERS.register_module(module_name=Trainers.image_inpainting) | |||||
| class ImageInpaintingTrainer(EpochBasedTrainer): | |||||
| def __init__(self, *args, **kwargs): | |||||
| super().__init__(*args, **kwargs) | |||||
| def train(self, *args, **kwargs): | |||||
| super().train(*args, **kwargs) | |||||
| def evaluate(self, *args, **kwargs): | |||||
| metric_values = super().evaluate(*args, **kwargs) | |||||
| return metric_values | |||||
| def prediction_step(self, model, inputs): | |||||
| pass | |||||
| def train_loop(self, data_loader): | |||||
| """ Training loop used by `EpochBasedTrainer.train()` | |||||
| """ | |||||
| self.invoke_hook(TrainerStages.before_run) | |||||
| self._epoch = 0 | |||||
| self.model.train() | |||||
| for _ in range(self._epoch, self._max_epochs): | |||||
| self.invoke_hook(TrainerStages.before_train_epoch) | |||||
| for i, data_batch in enumerate(data_loader): | |||||
| data_batch = to_device(data_batch, self.device) | |||||
| self.data_batch = data_batch | |||||
| self._inner_iter = i | |||||
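| # each batch is stepped twice, once per optimizer index (for this | |||||
| # adversarial inpainting model, presumably the generator and the | |||||
| # discriminator) | |||||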
| for idx in range(2): | |||||
| self.invoke_hook(TrainerStages.before_train_iter) | |||||
| self.train_step(self.model, data_batch, idx) | |||||
| self.invoke_hook(TrainerStages.after_train_iter) | |||||
| del self.data_batch | |||||
| self._iter += 1 | |||||
| self._mode = ModeKeys.TRAIN | |||||
| if i + 1 >= self.iters_per_epoch: | |||||
| break | |||||
| self.invoke_hook(TrainerStages.after_train_epoch) | |||||
| self._epoch += 1 | |||||
| self.invoke_hook(TrainerStages.after_run) | |||||
| def train_step(self, model, inputs, idx): | |||||
| """ Perform a training step on a batch of inputs. | |||||
| Subclass and override to inject custom behavior. | |||||
| Args: | |||||
| model (`TorchModel`): The model to train. | |||||
| inputs (`Dict[str, Union[torch.Tensor, Any]]`): | |||||
| The inputs and targets of the model. | |||||
| The dictionary will be unpacked before being fed to the model. Most models expect the targets under the | |||||
| argument `labels`. Check your model's documentation for all accepted arguments. | |||||
| Return: | |||||
| `torch.Tensor`: The tensor with training loss on this batch. | |||||
| """ | |||||
| # EvaluationHook may have run evaluation and switched the mode to val; | |||||
| # switch back to train mode here | |||||
| # TODO: find a cleaner way to manage the mode | |||||
| model.train() | |||||
| self._mode = ModeKeys.TRAIN | |||||
| # call model forward but not __call__ to skip postprocess | |||||
| if isinstance(inputs, | |||||
| Mapping) and not func_receive_dict_inputs(model.forward): | |||||
| train_outputs = model.model._do_step(**inputs, optimizer_idx=idx) | |||||
| else: | |||||
| train_outputs = model.model._do_step(inputs, optimizer_idx=idx) | |||||
| if not isinstance(train_outputs, dict): | |||||
| raise TypeError('model._do_step() must return a dict') | |||||
| # add model output info to log | |||||
| if 'log_vars' not in train_outputs: | |||||
| default_keys_pattern = ['loss'] | |||||
| match_keys = set([]) | |||||
| for key_p in default_keys_pattern: | |||||
| match_keys.update( | |||||
| [key for key in train_outputs.keys() if key_p in key]) | |||||
| log_vars = {} | |||||
| for key in match_keys: | |||||
| value = train_outputs.get(key, None) | |||||
| if value is not None: | |||||
| if dist.is_available() and dist.is_initialized(): | |||||
| value = value.data.clone() | |||||
| dist.all_reduce(value.div_(dist.get_world_size())) | |||||
| log_vars.update({key: value.item()}) | |||||
| self.log_buffer.update(log_vars) | |||||
| else: | |||||
| self.log_buffer.update(train_outputs['log_vars']) | |||||
| self.train_outputs = train_outputs | |||||
| @@ -47,6 +47,8 @@ class CVTasks(object): | |||||
| face_emotion = 'face-emotion' | face_emotion = 'face-emotion' | ||||
| product_segmentation = 'product-segmentation' | product_segmentation = 'product-segmentation' | ||||
| crowd_counting = 'crowd-counting' | |||||
| # image editing | # image editing | ||||
| skin_retouching = 'skin-retouching' | skin_retouching = 'skin-retouching' | ||||
| image_super_resolution = 'image-super-resolution' | image_super_resolution = 'image-super-resolution' | ||||
| @@ -54,6 +56,7 @@ class CVTasks(object): | |||||
| image_color_enhancement = 'image-color-enhancement' | image_color_enhancement = 'image-color-enhancement' | ||||
| image_denoising = 'image-denoising' | image_denoising = 'image-denoising' | ||||
| image_portrait_enhancement = 'image-portrait-enhancement' | image_portrait_enhancement = 'image-portrait-enhancement' | ||||
| image_inpainting = 'image-inpainting' | |||||
| # image generation | # image generation | ||||
| image_to_image_translation = 'image-to-image-translation' | image_to_image_translation = 'image-to-image-translation' | ||||
| @@ -72,7 +75,6 @@ class CVTasks(object): | |||||
| video_category = 'video-category' | video_category = 'video-category' | ||||
| video_embedding = 'video-embedding' | video_embedding = 'video-embedding' | ||||
| virtual_try_on = 'virtual-try-on' | virtual_try_on = 'virtual-try-on' | ||||
| crowd_counting = 'crowd-counting' | |||||
| movie_scene_segmentation = 'movie-scene-segmentation' | movie_scene_segmentation = 'movie-scene-segmentation' | ||||
| # video editing | # video editing | ||||
| @@ -7,6 +7,8 @@ ffmpeg-python>=0.2.0 | |||||
| ftfy | ftfy | ||||
| imageio>=2.9.0 | imageio>=2.9.0 | ||||
| imageio-ffmpeg>=0.4.2 | imageio-ffmpeg>=0.4.2 | ||||
| imgaug>=0.4.0 | |||||
| kornia>=0.5.0 | |||||
| lmdb | lmdb | ||||
| lpips | lpips | ||||
| ml_collections | ml_collections | ||||
| @@ -0,0 +1,77 @@ | |||||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||||
| import unittest | |||||
| import cv2 | |||||
| import torch | |||||
| from PIL import Image | |||||
| from modelscope.outputs import OutputKeys | |||||
| from modelscope.pipelines import pipeline | |||||
| from modelscope.utils.constant import Tasks | |||||
| from modelscope.utils.logger import get_logger | |||||
| from modelscope.utils.test_utils import test_level | |||||
| logger = get_logger() | |||||
| class ImageInpaintingTest(unittest.TestCase): | |||||
| def setUp(self) -> None: | |||||
| self.input_location = 'data/test/images/image_inpainting/image_inpainting.png' | |||||
| self.input_mask_location = 'data/test/images/image_inpainting/image_inpainting_mask.png' | |||||
| self.model_id = 'damo/cv_fft_inpainting_lama' | |||||
| def save_result(self, result): | |||||
| vis_img = result[OutputKeys.OUTPUT_IMG] | |||||
| cv2.imwrite('result.png', vis_img) | |||||
| @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | |||||
| def test_inpainting(self): | |||||
| inpainting = pipeline(Tasks.image_inpainting, model=self.model_id) | |||||
| result = inpainting(self.input_location + '+' | |||||
| + self.input_mask_location) | |||||
| if result: | |||||
| self.save_result(result) | |||||
| else: | |||||
| raise ValueError('process error') | |||||
| @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | |||||
| @unittest.skipIf(not torch.cuda.is_available(), 'cuda unittest') | |||||
| def test_inpainting_with_refinement(self): | |||||
| # for high-resolution inputs, setting refine=True usually gives better results | |||||
| inpainting = pipeline( | |||||
| Tasks.image_inpainting, model=self.model_id, refine=True) | |||||
| result = inpainting(self.input_location + '+' | |||||
| + self.input_mask_location) | |||||
| if result: | |||||
| self.save_result(result) | |||||
| else: | |||||
| raise ValueError('process error') | |||||
| @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') | |||||
| def test_inpainting_with_image(self): | |||||
| inpainting = pipeline(Tasks.image_inpainting, model=self.model_id) | |||||
| img = Image.open(self.input_location).convert('RGB') | |||||
| mask = Image.open(self.input_mask_location).convert('RGB') | |||||
| img_new = Image.new('RGB', (img.width + mask.width, img.height)) | |||||
| img_new.paste(img, (0, 0)) | |||||
| img_new.paste(mask, (img.width, 0)) | |||||
| result = inpainting(img_new) | |||||
| if result: | |||||
| self.save_result(result) | |||||
| else: | |||||
| raise ValueError('process error') | |||||
| @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') | |||||
| def test_inpainting_with_default_task(self): | |||||
| inpainting = pipeline(Tasks.image_inpainting) | |||||
| result = inpainting(self.input_location + '+' | |||||
| + self.input_mask_location) | |||||
| if result: | |||||
| self.save_result(result) | |||||
| else: | |||||
| raise ValueError('process error') | |||||
| if __name__ == '__main__': | |||||
| unittest.main() | |||||
| @@ -10,6 +10,7 @@ isolated: # test cases that may require excessive amount of GPU memory, which | |||||
| - test_easycv_trainer.py | - test_easycv_trainer.py | ||||
| - test_segformer.py | - test_segformer.py | ||||
| - test_segmentation_pipeline.py | - test_segmentation_pipeline.py | ||||
| - test_image_inpainting.py | |||||
| envs: | envs: | ||||
| default: # default env, case not in other env will in default, pytorch. | default: # default env, case not in other env will in default, pytorch. | ||||
| @@ -0,0 +1,84 @@ | |||||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||||
| import os | |||||
| import shutil | |||||
| import tempfile | |||||
| import unittest | |||||
| from modelscope.hub.snapshot_download import snapshot_download | |||||
| from modelscope.metainfo import Trainers | |||||
| from modelscope.models.cv.image_inpainting import FFTInpainting | |||||
| from modelscope.msdatasets import MsDataset | |||||
| from modelscope.trainers import build_trainer | |||||
| from modelscope.utils.config import Config, ConfigDict | |||||
| from modelscope.utils.constant import ModelFile | |||||
| from modelscope.utils.logger import get_logger | |||||
| from modelscope.utils.test_utils import test_level | |||||
| logger = get_logger() | |||||
| class ImageInpaintingTrainerTest(unittest.TestCase): | |||||
| def setUp(self): | |||||
| print('Testing %s.%s' % (type(self).__name__, self._testMethodName)) | |||||
| # mkdtemp creates the directory and keeps it until tearDown removes it; | |||||
| # tempfile.TemporaryDirectory().name would delete the directory as soon | |||||
| # as the object is garbage collected | |||||
| self.tmp_dir = tempfile.mkdtemp() | |||||
| self.model_id = 'damo/cv_fft_inpainting_lama' | |||||
| self.cache_path = snapshot_download(self.model_id) | |||||
| cfg = Config.from_file( | |||||
| os.path.join(self.cache_path, ModelFile.CONFIGURATION)) | |||||
| train_data_cfg = ConfigDict( | |||||
| name='PlacesToydataset', | |||||
| split='train', | |||||
| mask_gen_kwargs=cfg.dataset.mask_gen_kwargs, | |||||
| out_size=cfg.dataset.train_out_size, | |||||
| test_mode=False) | |||||
| test_data_cfg = ConfigDict( | |||||
| name='PlacesToydataset', | |||||
| split='test', | |||||
| mask_gen_kwargs=cfg.dataset.mask_gen_kwargs, | |||||
| out_size=cfg.dataset.val_out_size, | |||||
| test_mode=True) | |||||
| self.train_dataset = MsDataset.load( | |||||
| dataset_name=train_data_cfg.name, | |||||
| split=train_data_cfg.split, | |||||
| mask_gen_kwargs=train_data_cfg.mask_gen_kwargs, | |||||
| out_size=train_data_cfg.out_size, | |||||
| test_mode=train_data_cfg.test_mode) | |||||
| assert next( | |||||
| iter(self.train_dataset.config_kwargs['split_config'].values())) | |||||
| self.test_dataset = MsDataset.load( | |||||
| dataset_name=test_data_cfg.name, | |||||
| split=test_data_cfg.split, | |||||
| mask_gen_kwargs=test_data_cfg.mask_gen_kwargs, | |||||
| out_size=test_data_cfg.out_size, | |||||
| test_mode=test_data_cfg.test_mode) | |||||
| assert next( | |||||
| iter(self.test_dataset.config_kwargs['split_config'].values())) | |||||
| def tearDown(self): | |||||
| shutil.rmtree(self.tmp_dir, ignore_errors=True) | |||||
| super().tearDown() | |||||
| @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') | |||||
| def test_trainer(self): | |||||
| kwargs = dict( | |||||
| model=self.model_id, | |||||
| train_dataset=self.train_dataset, | |||||
| eval_dataset=self.test_dataset) | |||||
| trainer = build_trainer( | |||||
| name=Trainers.image_inpainting, default_args=kwargs) | |||||
| trainer.train() | |||||
| results_files = os.listdir(trainer.work_dir) | |||||
| self.assertIn(f'{trainer.timestamp}.log.json', results_files) | |||||
| if __name__ == '__main__': | |||||
| unittest.main() | |||||