Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9440706 * init * merge master
@@ -49,6 +49,7 @@ class Pipelines(object):
     action_recognition = 'TAdaConv_action-recognition'
     animal_recognation = 'resnet101-animal_recog'
     cmdssl_video_embedding = 'cmdssl-r2p1d_video_embedding'
+    image_super_resolution = 'rrdb-image-super-resolution'
     face_image_generation = 'gan-face-image-generation'
     style_transfer = 'AAMS-style-transfer'
@@ -0,0 +1,226 @@ (new file: arch_util.py)
import collections.abc
import math
import warnings
from itertools import repeat

import torch
import torchvision
from torch import nn as nn
from torch.nn import functional as F
from torch.nn import init as init
from torch.nn.modules.batchnorm import _BatchNorm

@torch.no_grad()
def default_init_weights(module_list, scale=1, bias_fill=0, **kwargs):
    """Initialize network weights.

    Args:
        module_list (list[nn.Module] | nn.Module): Modules to be initialized.
        scale (float): Scale initialized weights, especially for residual
            blocks. Default: 1.
        bias_fill (float): The value to fill bias. Default: 0.
        kwargs (dict): Other arguments for initialization function.
    """
    if not isinstance(module_list, list):
        module_list = [module_list]
    for module in module_list:
        for m in module.modules():
            if isinstance(m, nn.Conv2d):
                init.kaiming_normal_(m.weight, **kwargs)
                m.weight.data *= scale
                if m.bias is not None:
                    m.bias.data.fill_(bias_fill)
            elif isinstance(m, nn.Linear):
                init.kaiming_normal_(m.weight, **kwargs)
                m.weight.data *= scale
                if m.bias is not None:
                    m.bias.data.fill_(bias_fill)
            elif isinstance(m, _BatchNorm):
                init.constant_(m.weight, 1)
                if m.bias is not None:
                    m.bias.data.fill_(bias_fill)
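
A typical call, for reference (reviewer sketch, not part of the patch; the extra keyword is simply forwarded to init.kaiming_normal_):

    conv = nn.Conv2d(64, 64, 3, 1, 1)
    default_init_weights(conv, scale=0.1, bias_fill=0, mode='fan_in')
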
def make_layer(basic_block, num_basic_block, **kwarg):
    """Make layers by stacking the same blocks.

    Args:
        basic_block (nn.Module): nn.Module class for the basic block.
        num_basic_block (int): Number of blocks.

    Returns:
        nn.Sequential: Stacked blocks in nn.Sequential.
    """
    layers = []
    for _ in range(num_basic_block):
        layers.append(basic_block(**kwarg))
    return nn.Sequential(*layers)

class ResidualBlockNoBN(nn.Module):
    """Residual block without BN.

    It has a style of:
        ---Conv-ReLU-Conv-+-
         |________________|

    Args:
        num_feat (int): Channel number of intermediate features.
            Default: 64.
        res_scale (float): Residual scale. Default: 1.
        pytorch_init (bool): If set to True, use pytorch default init,
            otherwise, use default_init_weights. Default: False.
    """

    def __init__(self, num_feat=64, res_scale=1, pytorch_init=False):
        super(ResidualBlockNoBN, self).__init__()
        self.res_scale = res_scale
        self.conv1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=True)
        self.conv2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=True)
        self.relu = nn.ReLU(inplace=True)
        if not pytorch_init:
            default_init_weights([self.conv1, self.conv2], 0.1)

    def forward(self, x):
        identity = x
        out = self.conv2(self.relu(self.conv1(x)))
        return identity + out * self.res_scale
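
Sanity check: stacking these with make_layer preserves (n, c, h, w), since all convs are 3x3, stride 1, pad 1 at constant width (reviewer sketch, not part of the patch):

    trunk = make_layer(ResidualBlockNoBN, 4, num_feat=64)
    x = torch.randn(1, 64, 32, 32)
    assert trunk(x).shape == x.shape
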
class Upsample(nn.Sequential):
    """Upsample module.

    Args:
        scale (int): Scale factor. Supported scales: 2^n and 3.
        num_feat (int): Channel number of intermediate features.
    """

    def __init__(self, scale, num_feat):
        m = []
        if (scale & (scale - 1)) == 0:  # scale = 2^n
            for _ in range(int(math.log(scale, 2))):
                m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1))
                m.append(nn.PixelShuffle(2))
        elif scale == 3:
            m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1))
            m.append(nn.PixelShuffle(3))
        else:
            raise ValueError(
                f'scale {scale} is not supported. Supported scales: 2^n and 3.'
            )
        super(Upsample, self).__init__(*m)
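
The (scale & (scale - 1)) == 0 test is the standard power-of-two bit trick, so a scale of 2^n unrolls into n conv + PixelShuffle(2) pairs. Quick shape check (reviewer sketch, not part of the patch):

    up = Upsample(scale=4, num_feat=64)  # two conv + PixelShuffle(2) stages
    y = up(torch.randn(1, 64, 16, 16))
    assert y.shape == (1, 64, 64, 64)  # spatial x4, channels unchanged
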
def flow_warp(x,
              flow,
              interp_mode='bilinear',
              padding_mode='zeros',
              align_corners=True):
    """Warp an image or feature map with optical flow.

    Args:
        x (Tensor): Tensor with size (n, c, h, w).
        flow (Tensor): Tensor with size (n, h, w, 2); values are offsets in
            pixels (not normalized).
        interp_mode (str): 'nearest' or 'bilinear'. Default: 'bilinear'.
        padding_mode (str): 'zeros' or 'border' or 'reflection'.
            Default: 'zeros'.
        align_corners (bool): Before pytorch 1.3, the default value is
            align_corners=True. After pytorch 1.3, the default value is
            align_corners=False. Here, we use True as default.

    Returns:
        Tensor: Warped image or feature map.
    """
    assert x.size()[-2:] == flow.size()[1:3]
    _, _, h, w = x.size()
    # create mesh grid
    grid_y, grid_x = torch.meshgrid(
        torch.arange(0, h).type_as(x),
        torch.arange(0, w).type_as(x))
    grid = torch.stack((grid_x, grid_y), 2).float()  # W(x), H(y), 2
    grid.requires_grad = False

    vgrid = grid + flow
    # scale grid to [-1, 1]
    vgrid_x = 2.0 * vgrid[:, :, :, 0] / max(w - 1, 1) - 1.0
    vgrid_y = 2.0 * vgrid[:, :, :, 1] / max(h - 1, 1) - 1.0
    vgrid_scaled = torch.stack((vgrid_x, vgrid_y), dim=3)
    output = F.grid_sample(
        x,
        vgrid_scaled,
        mode=interp_mode,
        padding_mode=padding_mode,
        align_corners=align_corners)
    # TODO, what if align_corners=False
    return output
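
With an all-zero flow the sampling grid is the identity, so flow_warp should return its input up to interpolation error; a cheap correctness probe (reviewer sketch, not part of the patch):

    x = torch.randn(1, 3, 8, 8)
    warped = flow_warp(x, torch.zeros(1, 8, 8, 2))
    assert torch.allclose(warped, x, atol=1e-5)
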
def resize_flow(flow,
                size_type,
                sizes,
                interp_mode='bilinear',
                align_corners=False):
    """Resize a flow according to ratio or shape.

    Args:
        flow (Tensor): Precomputed flow. shape [N, 2, H, W].
        size_type (str): 'ratio' or 'shape'.
        sizes (list[int | float]): the ratio for resizing or the final output
            shape.
            1) The order of ratio should be [ratio_h, ratio_w]. For
            downsampling, the ratio should be smaller than 1.0 (i.e., ratio
            < 1.0). For upsampling, the ratio should be larger than 1.0 (i.e.,
            ratio > 1.0).
            2) The order of output_size should be [out_h, out_w].
        interp_mode (str): The mode of interpolation for resizing.
            Default: 'bilinear'.
        align_corners (bool): Whether align corners. Default: False.

    Returns:
        Tensor: Resized flow.
    """
    _, _, flow_h, flow_w = flow.size()
    if size_type == 'ratio':
        output_h, output_w = int(flow_h * sizes[0]), int(flow_w * sizes[1])
    elif size_type == 'shape':
        output_h, output_w = sizes[0], sizes[1]
    else:
        raise ValueError(
            f'Size type should be ratio or shape, but got type {size_type}.')

    input_flow = flow.clone()
    ratio_h = output_h / flow_h
    ratio_w = output_w / flow_w
    input_flow[:, 0, :, :] *= ratio_w
    input_flow[:, 1, :, :] *= ratio_h
    resized_flow = F.interpolate(
        input=input_flow,
        size=(output_h, output_w),
        mode=interp_mode,
        align_corners=align_corners)
    return resized_flow
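
Because flow values are displacements in pixels, resizing must rescale the vectors along with the grid; doubling the resolution doubles each displacement (reviewer sketch, not part of the patch):

    flow = torch.ones(1, 2, 4, 4)  # 1-pixel displacement everywhere
    big = resize_flow(flow, 'ratio', [2.0, 2.0])
    assert big.shape == (1, 2, 8, 8)
    assert torch.allclose(big, torch.full_like(big, 2.0))
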
# TODO: may write a cpp file
def pixel_unshuffle(x, scale):
    """Pixel unshuffle.

    Args:
        x (Tensor): Input feature with shape (b, c, hh, hw).
        scale (int): Downsample ratio.

    Returns:
        Tensor: the pixel unshuffled feature.
    """
    b, c, hh, hw = x.size()
    out_channel = c * (scale**2)
    assert hh % scale == 0 and hw % scale == 0
    h = hh // scale
    w = hw // scale
    x_view = x.view(b, c, h, scale, w, scale)
    return x_view.permute(0, 1, 3, 5, 2, 4).reshape(b, out_channel, h, w)
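
This is the exact inverse of nn.PixelShuffle: spatial resolution is traded for channels, and shuffling the result back recovers the input bit-for-bit (reviewer sketch, not part of the patch; torch >= 1.8 also ships nn.PixelUnshuffle):

    x = torch.randn(1, 3, 8, 8)
    y = pixel_unshuffle(x, scale=2)  # -> (1, 12, 4, 4)
    assert torch.equal(nn.PixelShuffle(2)(y), x)
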
@@ -0,0 +1,129 @@ (new file: rrdbnet_arch.py)
import torch
from torch import nn as nn
from torch.nn import functional as F

from .arch_util import default_init_weights, make_layer, pixel_unshuffle


class ResidualDenseBlock(nn.Module):
    """Residual Dense Block.

    Used in RRDB block in ESRGAN.

    Args:
        num_feat (int): Channel number of intermediate features.
        num_grow_ch (int): Channels for each growth.
    """

    def __init__(self, num_feat=64, num_grow_ch=32):
        super(ResidualDenseBlock, self).__init__()
        self.conv1 = nn.Conv2d(num_feat, num_grow_ch, 3, 1, 1)
        self.conv2 = nn.Conv2d(num_feat + num_grow_ch, num_grow_ch, 3, 1, 1)
        self.conv3 = nn.Conv2d(num_feat + 2 * num_grow_ch, num_grow_ch, 3, 1,
                               1)
        self.conv4 = nn.Conv2d(num_feat + 3 * num_grow_ch, num_grow_ch, 3, 1,
                               1)
        self.conv5 = nn.Conv2d(num_feat + 4 * num_grow_ch, num_feat, 3, 1, 1)
        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)

        # initialization
        default_init_weights(
            [self.conv1, self.conv2, self.conv3, self.conv4, self.conv5], 0.1)

    def forward(self, x):
        x1 = self.lrelu(self.conv1(x))
        x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1)))
        x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1)))
        x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1)))
        x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
        # Empirically, we use 0.2 to scale the residual for better performance
        return x5 * 0.2 + x
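
The dense connectivity is why conv_k takes num_feat + (k - 1) * num_grow_ch input channels; conv5 projects back to num_feat so the scaled residual add lines up. Shape check (reviewer sketch, not part of the patch):

    rdb = ResidualDenseBlock(num_feat=64, num_grow_ch=32)
    x = torch.randn(1, 64, 16, 16)
    assert rdb(x).shape == x.shape  # channel growth is internal only
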
class RRDB(nn.Module):
    """Residual in Residual Dense Block.

    Used in RRDB-Net in ESRGAN.

    Args:
        num_feat (int): Channel number of intermediate features.
        num_grow_ch (int): Channels for each growth.
    """

    def __init__(self, num_feat, num_grow_ch=32):
        super(RRDB, self).__init__()
        self.rdb1 = ResidualDenseBlock(num_feat, num_grow_ch)
        self.rdb2 = ResidualDenseBlock(num_feat, num_grow_ch)
        self.rdb3 = ResidualDenseBlock(num_feat, num_grow_ch)

    def forward(self, x):
        out = self.rdb1(x)
        out = self.rdb2(out)
        out = self.rdb3(out)
        # Empirically, we use 0.2 to scale the residual for better performance
        return out * 0.2 + x

class RRDBNet(nn.Module):
    """Networks consisting of Residual in Residual Dense Block, which is used
    in ESRGAN.

    ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks.

    We extend ESRGAN for scale x2 and scale x1.
    Note: This is one option for scale 1, scale 2 in RRDBNet. We first employ
    the pixel-unshuffle (an inverse operation of pixelshuffle) to reduce the
    spatial size and enlarge the channel size before feeding inputs into the
    main ESRGAN architecture.

    Args:
        num_in_ch (int): Channel number of inputs.
        num_out_ch (int): Channel number of outputs.
        scale (int): Upscaling factor. Supported: 1, 2 and 4. Default: 4.
        num_feat (int): Channel number of intermediate features.
            Default: 64.
        num_block (int): Block number in the trunk network. Default: 23.
        num_grow_ch (int): Channels for each growth. Default: 32.
    """

    def __init__(self,
                 num_in_ch,
                 num_out_ch,
                 scale=4,
                 num_feat=64,
                 num_block=23,
                 num_grow_ch=32):
        super(RRDBNet, self).__init__()
        self.scale = scale
        if scale == 2:
            num_in_ch = num_in_ch * 4
        elif scale == 1:
            num_in_ch = num_in_ch * 16
        self.conv_first = nn.Conv2d(num_in_ch, num_feat, 3, 1, 1)
        self.body = make_layer(
            RRDB, num_block, num_feat=num_feat, num_grow_ch=num_grow_ch)
        self.conv_body = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
        # upsample
        self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
        self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
        self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
        self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)

        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)

    def forward(self, x):
        if self.scale == 2:
            feat = pixel_unshuffle(x, scale=2)
        elif self.scale == 1:
            feat = pixel_unshuffle(x, scale=4)
        else:
            feat = x
        feat = self.conv_first(feat)
        body_feat = self.conv_body(self.body(feat))
        feat = feat + body_feat
        # upsample
        feat = self.lrelu(
            self.conv_up1(F.interpolate(feat, scale_factor=2, mode='nearest')))
        feat = self.lrelu(
            self.conv_up2(F.interpolate(feat, scale_factor=2, mode='nearest')))
        out = self.conv_last(self.lrelu(self.conv_hr(feat)))
        return out
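
End to end, the two nearest-neighbor x2 stages are fixed, so the input-side pixel_unshuffle is what realizes scale 1 and 2; output size is always input size times scale (reviewer sketch, not part of the patch; num_block is reduced to keep the check fast):

    net = RRDBNet(num_in_ch=3, num_out_ch=3, scale=2, num_feat=64, num_block=2)
    y = net(torch.randn(1, 3, 32, 32))
    assert y.shape == (1, 3, 64, 64)  # unshuffle to 16x16, then two x2 stages
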
@@ -70,6 +70,7 @@ TASK_OUTPUTS = {
     Tasks.image_editing: [OutputKeys.OUTPUT_IMG],
     Tasks.image_matting: [OutputKeys.OUTPUT_IMG],
     Tasks.image_generation: [OutputKeys.OUTPUT_IMG],
+    Tasks.image_restoration: [OutputKeys.OUTPUT_IMG],
     # action recognition result for single video
     # {
@@ -6,6 +6,7 @@ try:
     from .action_recognition_pipeline import ActionRecognitionPipeline
     from .animal_recog_pipeline import AnimalRecogPipeline
     from .cmdssl_video_embedding_pipleline import CMDSSLVideoEmbeddingPipeline
+    from .image_super_resolution_pipeline import ImageSuperResolutionPipeline
     from .face_image_generation_pipeline import FaceImageGenerationPipeline
 except ModuleNotFoundError as e:
     if str(e) == "No module named 'torch'":
@@ -0,0 +1,77 @@ (new file: image_super_resolution_pipeline.py)
from typing import Any, Dict

import cv2
import numpy as np
import PIL
import torch

from modelscope.metainfo import Pipelines
from modelscope.models.cv.super_resolution import rrdbnet_arch
from modelscope.outputs import OutputKeys
from modelscope.pipelines.base import Input
from modelscope.preprocessors import load_image
from modelscope.utils.constant import ModelFile, Tasks
from modelscope.utils.logger import get_logger
from ..base import Pipeline
from ..builder import PIPELINES

logger = get_logger()

@PIPELINES.register_module(
    Tasks.image_restoration, module_name=Pipelines.image_super_resolution)
class ImageSuperResolutionPipeline(Pipeline):

    def __init__(self, model: str):
        """Use `model` to create an image super-resolution pipeline for
        prediction.

        Args:
            model: model id on modelscope hub.
        """
        super().__init__(model=model)
        self.num_feat = 64
        self.num_block = 23
        self.scale = 4
        self.sr_model = rrdbnet_arch.RRDBNet(
            num_in_ch=3,
            num_out_ch=3,
            num_feat=self.num_feat,
            num_block=self.num_block,
            num_grow_ch=32,
            scale=self.scale)
        model_path = f'{self.model}/{ModelFile.TORCH_MODEL_FILE}'
        self.sr_model.load_state_dict(torch.load(model_path), strict=True)
        logger.info('load model done')

    def preprocess(self, input: Input) -> Dict[str, Any]:
        if isinstance(input, str):
            img = np.array(load_image(input))
        elif isinstance(input, PIL.Image.Image):
            img = np.array(input.convert('RGB'))
        elif isinstance(input, np.ndarray):
            if len(input.shape) == 2:
                input = cv2.cvtColor(input, cv2.COLOR_GRAY2BGR)
            # flip BGR to RGB; copy so strides stay positive for from_numpy
            img = input[:, :, ::-1].copy()
        else:
            raise TypeError(f'input should be either str, PIL.Image,'
                            f' np.ndarray, but got {type(input)}')

        img = torch.from_numpy(img).permute(2, 0, 1).unsqueeze(0) / 255.
        result = {'img': img}
        return result

    def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
        self.sr_model.eval()
        with torch.no_grad():
            out = self.sr_model(input['img'])
        out = out.squeeze(0).permute(1, 2, 0).flip(2)  # CHW -> HWC, RGB -> BGR
        out_img = np.clip(out.float().cpu().numpy(), 0, 1) * 255

        return {OutputKeys.OUTPUT_IMG: out_img.astype(np.uint8)}

    def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        return inputs
@@ -27,6 +27,7 @@ class CVTasks(object):
     ocr_detection = 'ocr-detection'
     action_recognition = 'action-recognition'
     video_embedding = 'video-embedding'
+    image_restoration = 'image-restoration'
     style_transfer = 'style-transfer'
| @@ -0,0 +1,37 @@ | |||||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||||
| import os | |||||
| import os.path as osp | |||||
| import unittest | |||||
| import cv2 | |||||
| from modelscope.msdatasets import MsDataset | |||||
| from modelscope.outputs import OutputKeys | |||||
| from modelscope.pipelines import pipeline | |||||
| from modelscope.pipelines.base import Pipeline | |||||
| from modelscope.utils.constant import Tasks | |||||
| from modelscope.utils.test_utils import test_level | |||||
| class ImageSuperResolutionTest(unittest.TestCase): | |||||
| def setUp(self) -> None: | |||||
| self.model_id = 'damo/cv_rrdb_image-super-resolution' | |||||
| self.img = 'data/test/images/dogs.jpg' | |||||
| def pipeline_inference(self, pipeline: Pipeline, img: str): | |||||
| result = pipeline(img) | |||||
| if result is not None: | |||||
| cv2.imwrite('result.png', result[OutputKeys.OUTPUT_IMG]) | |||||
| print(f'Output written to {osp.abspath("result.png")}') | |||||
| @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') | |||||
| def test_run_modelhub(self): | |||||
| super_resolution = pipeline( | |||||
| Tasks.image_restoration, model=self.model_id) | |||||
| self.pipeline_inference(super_resolution, self.img) | |||||
| if __name__ == '__main__': | |||||
| unittest.main() | |||||