Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9440706 * init * merge master
@@ -49,6 +49,7 @@ class Pipelines(object):
    action_recognition = 'TAdaConv_action-recognition'
    animal_recognation = 'resnet101-animal_recog'
    cmdssl_video_embedding = 'cmdssl-r2p1d_video_embedding'
    image_super_resolution = 'rrdb-image-super-resolution'
    face_image_generation = 'gan-face-image-generation'
    style_transfer = 'AAMS-style-transfer'
@@ -0,0 +1,226 @@
import collections.abc
import math
import warnings
from itertools import repeat

import torch
import torchvision
from torch import nn as nn
from torch.nn import functional as F
from torch.nn import init as init
from torch.nn.modules.batchnorm import _BatchNorm


@torch.no_grad()
def default_init_weights(module_list, scale=1, bias_fill=0, **kwargs):
    """Initialize network weights.

    Args:
        module_list (list[nn.Module] | nn.Module): Modules to be initialized.
        scale (float): Scale initialized weights, especially for residual
            blocks. Default: 1.
        bias_fill (float): The value to fill bias. Default: 0.
        kwargs (dict): Other arguments for initialization function.
    """
    if not isinstance(module_list, list):
        module_list = [module_list]
    for module in module_list:
        for m in module.modules():
            if isinstance(m, nn.Conv2d):
                init.kaiming_normal_(m.weight, **kwargs)
                m.weight.data *= scale
                if m.bias is not None:
                    m.bias.data.fill_(bias_fill)
            elif isinstance(m, nn.Linear):
                init.kaiming_normal_(m.weight, **kwargs)
                m.weight.data *= scale
                if m.bias is not None:
                    m.bias.data.fill_(bias_fill)
            elif isinstance(m, _BatchNorm):
                init.constant_(m.weight, 1)
                if m.bias is not None:
                    m.bias.data.fill_(bias_fill)
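

# A minimal usage sketch (not part of the original diff; the helper name
# _example_init is ours): residual-style networks typically pass scale=0.1 so
# the freshly initialized residual branch starts close to an identity mapping.
def _example_init():
    convs = [nn.Conv2d(64, 64, 3, 1, 1), nn.Conv2d(64, 64, 3, 1, 1)]
    default_init_weights(convs, scale=0.1, bias_fill=0)
    # each conv now holds Kaiming-normal weights scaled by 0.1 and zero biases
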
def make_layer(basic_block, num_basic_block, **kwarg):
    """Make layers by stacking the same blocks.

    Args:
        basic_block (nn.Module): nn.Module class for the basic block.
        num_basic_block (int): Number of blocks.

    Returns:
        nn.Sequential: Stacked blocks in nn.Sequential.
    """
    layers = []
    for _ in range(num_basic_block):
        layers.append(basic_block(**kwarg))
    return nn.Sequential(*layers)


class ResidualBlockNoBN(nn.Module):
    """Residual block without BN.

    It has a style of:
        ---Conv-ReLU-Conv-+-
         |________________|

    Args:
        num_feat (int): Channel number of intermediate features.
            Default: 64.
        res_scale (float): Residual scale. Default: 1.
        pytorch_init (bool): If set to True, use pytorch default init,
            otherwise, use default_init_weights. Default: False.
    """

    def __init__(self, num_feat=64, res_scale=1, pytorch_init=False):
        super(ResidualBlockNoBN, self).__init__()
        self.res_scale = res_scale
        self.conv1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=True)
        self.conv2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=True)
        self.relu = nn.ReLU(inplace=True)
        if not pytorch_init:
            default_init_weights([self.conv1, self.conv2], 0.1)

    def forward(self, x):
        identity = x
        out = self.conv2(self.relu(self.conv1(x)))
        return identity + out * self.res_scale
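

# Illustrative sketch (the _example_trunk name is ours, not from the diff):
# make_layer forwards its kwargs to every block constructor, so a trunk of
# identical residual blocks is one call; shapes are preserved throughout.
def _example_trunk():
    trunk = make_layer(ResidualBlockNoBN, 8, num_feat=64)
    x = torch.randn(1, 64, 32, 32)
    assert trunk(x).shape == (1, 64, 32, 32)
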
class Upsample(nn.Sequential):
    """Upsample module.

    Args:
        scale (int): Scale factor. Supported scales: 2^n and 3.
        num_feat (int): Channel number of intermediate features.
    """

    def __init__(self, scale, num_feat):
        m = []
        if (scale & (scale - 1)) == 0:  # scale = 2^n
            for _ in range(int(math.log(scale, 2))):
                m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1))
                m.append(nn.PixelShuffle(2))
        elif scale == 3:
            m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1))
            m.append(nn.PixelShuffle(3))
        else:
            raise ValueError(
                f'scale {scale} is not supported. Supported scales: 2^n and 3.')
        super(Upsample, self).__init__(*m)
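

# Shape sketch (illustrative, name ours): each Conv2d quadruples the channels
# and PixelShuffle(2) trades them back for a 2x spatial upscale, so scale=4
# stacks two conv+shuffle pairs and the channel count is unchanged end to end.
def _example_upsample():
    up = Upsample(scale=4, num_feat=64)
    x = torch.randn(1, 64, 16, 16)
    assert up(x).shape == (1, 64, 64, 64)
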
def flow_warp(x,
              flow,
              interp_mode='bilinear',
              padding_mode='zeros',
              align_corners=True):
    """Warp an image or feature map with optical flow.

    Args:
        x (Tensor): Tensor with size (n, c, h, w).
        flow (Tensor): Tensor with size (n, h, w, 2), normal value.
        interp_mode (str): 'nearest' or 'bilinear'. Default: 'bilinear'.
        padding_mode (str): 'zeros' or 'border' or 'reflection'.
            Default: 'zeros'.
        align_corners (bool): Before pytorch 1.3, the default value is
            align_corners=True. After pytorch 1.3, the default value is
            align_corners=False. Here, we use True as the default.

    Returns:
        Tensor: Warped image or feature map.
    """
    assert x.size()[-2:] == flow.size()[1:3]
    _, _, h, w = x.size()
    # create mesh grid
    grid_y, grid_x = torch.meshgrid(
        torch.arange(0, h).type_as(x),
        torch.arange(0, w).type_as(x))
    grid = torch.stack((grid_x, grid_y), 2).float()  # W(x), H(y), 2
    grid.requires_grad = False

    vgrid = grid + flow
    # scale grid to [-1, 1] as required by F.grid_sample
    vgrid_x = 2.0 * vgrid[:, :, :, 0] / max(w - 1, 1) - 1.0
    vgrid_y = 2.0 * vgrid[:, :, :, 1] / max(h - 1, 1) - 1.0
    vgrid_scaled = torch.stack((vgrid_x, vgrid_y), dim=3)
    output = F.grid_sample(
        x,
        vgrid_scaled,
        mode=interp_mode,
        padding_mode=padding_mode,
        align_corners=align_corners)
    # TODO: what if align_corners=False
    return output
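

# Sanity-check sketch (illustrative, name ours): a zero flow field maps every
# output pixel back onto itself, so warping should reproduce the input up to
# interpolation error.
def _example_flow_warp():
    x = torch.randn(2, 8, 16, 16)
    flow = torch.zeros(2, 16, 16, 2)
    assert torch.allclose(flow_warp(x, flow), x, atol=1e-5)
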
def resize_flow(flow,
                size_type,
                sizes,
                interp_mode='bilinear',
                align_corners=False):
    """Resize a flow according to ratio or shape.

    Args:
        flow (Tensor): Precomputed flow. Shape [N, 2, H, W].
        size_type (str): 'ratio' or 'shape'.
        sizes (list[int | float]): the ratio for resizing or the final output
            shape.
            1) The order of ratio should be [ratio_h, ratio_w]. For
                downsampling, the ratio should be smaller than 1.0 (i.e.,
                ratio < 1.0). For upsampling, the ratio should be larger than
                1.0 (i.e., ratio > 1.0).
            2) The order of output_size should be [out_h, out_w].
        interp_mode (str): The mode of interpolation for resizing.
            Default: 'bilinear'.
        align_corners (bool): Whether to align corners. Default: False.

    Returns:
        Tensor: Resized flow.
    """
    _, _, flow_h, flow_w = flow.size()
    if size_type == 'ratio':
        output_h, output_w = int(flow_h * sizes[0]), int(flow_w * sizes[1])
    elif size_type == 'shape':
        output_h, output_w = sizes[0], sizes[1]
    else:
        raise ValueError(
            f'Size type should be ratio or shape, but got type {size_type}.')

    input_flow = flow.clone()
    ratio_h = output_h / flow_h
    ratio_w = output_w / flow_w
    # flow values are displacements in pixels, so rescale them with the map
    input_flow[:, 0, :, :] *= ratio_w
    input_flow[:, 1, :, :] *= ratio_h
    resized_flow = F.interpolate(
        input=input_flow,
        size=(output_h, output_w),
        mode=interp_mode,
        align_corners=align_corners)
    return resized_flow
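

# Sketch (illustrative, name ours): because flow is measured in pixels,
# doubling the spatial size must also double the stored displacements.
def _example_resize_flow():
    flow = torch.ones(1, 2, 8, 8)  # uniform 1-pixel displacement
    resized = resize_flow(flow, 'ratio', [2.0, 2.0])
    assert resized.shape == (1, 2, 16, 16)
    assert torch.allclose(resized, torch.full((1, 2, 16, 16), 2.0))
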
# TODO: may write a cpp file
def pixel_unshuffle(x, scale):
    """Pixel unshuffle.

    Args:
        x (Tensor): Input feature with shape (b, c, hh, hw).
        scale (int): Downsample ratio.

    Returns:
        Tensor: The pixel-unshuffled feature.
    """
    b, c, hh, hw = x.size()
    out_channel = c * (scale**2)
    assert hh % scale == 0 and hw % scale == 0
    h = hh // scale
    w = hw // scale
    x_view = x.view(b, c, h, scale, w, scale)
    return x_view.permute(0, 1, 3, 5, 2, 4).reshape(b, out_channel, h, w)
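

# Shape sketch (illustrative, name ours): pixel_unshuffle folds each
# scale x scale spatial block into channels, mapping (b, c, H, W) to
# (b, c * scale**2, H // scale, W // scale).
def _example_pixel_unshuffle():
    x = torch.randn(1, 3, 8, 8)
    assert pixel_unshuffle(x, scale=2).shape == (1, 12, 4, 4)
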
@@ -0,0 +1,129 @@
import torch
from torch import nn as nn
from torch.nn import functional as F

from .arch_util import default_init_weights, make_layer, pixel_unshuffle


class ResidualDenseBlock(nn.Module):
    """Residual Dense Block.

    Used in RRDB block in ESRGAN.

    Args:
        num_feat (int): Channel number of intermediate features.
        num_grow_ch (int): Channels for each growth.
    """

    def __init__(self, num_feat=64, num_grow_ch=32):
        super(ResidualDenseBlock, self).__init__()
        self.conv1 = nn.Conv2d(num_feat, num_grow_ch, 3, 1, 1)
        self.conv2 = nn.Conv2d(num_feat + num_grow_ch, num_grow_ch, 3, 1, 1)
        self.conv3 = nn.Conv2d(num_feat + 2 * num_grow_ch, num_grow_ch, 3, 1,
                               1)
        self.conv4 = nn.Conv2d(num_feat + 3 * num_grow_ch, num_grow_ch, 3, 1,
                               1)
        self.conv5 = nn.Conv2d(num_feat + 4 * num_grow_ch, num_feat, 3, 1, 1)
        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)

        # initialization
        default_init_weights(
            [self.conv1, self.conv2, self.conv3, self.conv4, self.conv5], 0.1)

    def forward(self, x):
        x1 = self.lrelu(self.conv1(x))
        x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1)))
        x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1)))
        x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1)))
        x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
        # Empirically, we use 0.2 to scale the residual for better performance
        return x5 * 0.2 + x
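

# Channel bookkeeping sketch (illustrative, name ours): with num_feat=64 and
# num_grow_ch=32, conv1..conv5 see 64, 96, 128, 160 and 192 input channels as
# the dense concatenations grow, and conv5 projects back to 64 so the scaled
# residual can be added to x.
def _example_rdb():
    block = ResidualDenseBlock(num_feat=64, num_grow_ch=32)
    x = torch.randn(1, 64, 16, 16)
    assert block(x).shape == x.shape
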
class RRDB(nn.Module):
    """Residual in Residual Dense Block.

    Used in RRDB-Net in ESRGAN.

    Args:
        num_feat (int): Channel number of intermediate features.
        num_grow_ch (int): Channels for each growth.
    """

    def __init__(self, num_feat, num_grow_ch=32):
        super(RRDB, self).__init__()
        self.rdb1 = ResidualDenseBlock(num_feat, num_grow_ch)
        self.rdb2 = ResidualDenseBlock(num_feat, num_grow_ch)
        self.rdb3 = ResidualDenseBlock(num_feat, num_grow_ch)

    def forward(self, x):
        out = self.rdb1(x)
        out = self.rdb2(out)
        out = self.rdb3(out)
        # Empirically, we use 0.2 to scale the residual for better performance
        return out * 0.2 + x


class RRDBNet(nn.Module):
    """Network consisting of Residual in Residual Dense Blocks, which is used
    in ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks.

    We extend ESRGAN for scale x2 and scale x1.
    Note: This is one option for scale 1 and scale 2 in RRDBNet. We first
    employ pixel-unshuffle (an inverse operation of pixelshuffle) to reduce
    the spatial size and enlarge the channel size before feeding inputs into
    the main ESRGAN architecture.

    Args:
        num_in_ch (int): Channel number of inputs.
        num_out_ch (int): Channel number of outputs.
        scale (int): Upsampling scale factor. Supported: 1, 2, 4. Default: 4.
        num_feat (int): Channel number of intermediate features.
            Default: 64.
        num_block (int): Block number in the trunk network. Default: 23.
        num_grow_ch (int): Channels for each growth. Default: 32.
    """

    def __init__(self,
                 num_in_ch,
                 num_out_ch,
                 scale=4,
                 num_feat=64,
                 num_block=23,
                 num_grow_ch=32):
        super(RRDBNet, self).__init__()
        self.scale = scale
        if scale == 2:
            num_in_ch = num_in_ch * 4
        elif scale == 1:
            num_in_ch = num_in_ch * 16
        self.conv_first = nn.Conv2d(num_in_ch, num_feat, 3, 1, 1)
        self.body = make_layer(
            RRDB, num_block, num_feat=num_feat, num_grow_ch=num_grow_ch)
        self.conv_body = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
        # upsample
        self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
        self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
        self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
        self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)

    def forward(self, x):
        if self.scale == 2:
            feat = pixel_unshuffle(x, scale=2)
        elif self.scale == 1:
            feat = pixel_unshuffle(x, scale=4)
        else:
            feat = x
        feat = self.conv_first(feat)
        body_feat = self.conv_body(self.body(feat))
        feat = feat + body_feat
        # upsample
        feat = self.lrelu(
            self.conv_up1(F.interpolate(feat, scale_factor=2, mode='nearest')))
        feat = self.lrelu(
            self.conv_up2(F.interpolate(feat, scale_factor=2, mode='nearest')))
        out = self.conv_last(self.lrelu(self.conv_hr(feat)))
        return out
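

# End-to-end sketch (illustrative, name ours): with the default scale=4 the
# network maps (b, 3, h, w) to (b, 3, 4h, 4w). For scale=2 or 1, forward first
# pixel-unshuffles the input so the same two nearest-neighbor x2 upsampling
# stages still yield the requested overall scale.
def _example_rrdbnet():
    net = RRDBNet(num_in_ch=3, num_out_ch=3, scale=4, num_feat=64,
                  num_block=1, num_grow_ch=32)
    x = torch.randn(1, 3, 16, 16)
    assert net(x).shape == (1, 3, 64, 64)
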
@@ -70,6 +70,7 @@ TASK_OUTPUTS = {
    Tasks.image_editing: [OutputKeys.OUTPUT_IMG],
    Tasks.image_matting: [OutputKeys.OUTPUT_IMG],
    Tasks.image_generation: [OutputKeys.OUTPUT_IMG],
    Tasks.image_restoration: [OutputKeys.OUTPUT_IMG],

    # action recognition result for single video
    # {
@@ -6,6 +6,7 @@ try:
    from .action_recognition_pipeline import ActionRecognitionPipeline
    from .animal_recog_pipeline import AnimalRecogPipeline
    from .cmdssl_video_embedding_pipleline import CMDSSLVideoEmbeddingPipeline
    from .image_super_resolution_pipeline import ImageSuperResolutionPipeline
    from .face_image_generation_pipeline import FaceImageGenerationPipeline
except ModuleNotFoundError as e:
    if str(e) == "No module named 'torch'":
@@ -0,0 +1,77 @@
from typing import Any, Dict

import cv2
import numpy as np
import PIL
import torch

from modelscope.metainfo import Pipelines
from modelscope.models.cv.super_resolution import rrdbnet_arch
from modelscope.outputs import OutputKeys
from modelscope.pipelines.base import Input
from modelscope.preprocessors import load_image
from modelscope.utils.constant import ModelFile, Tasks
from modelscope.utils.logger import get_logger
from ..base import Pipeline
from ..builder import PIPELINES

logger = get_logger()


@PIPELINES.register_module(
    Tasks.image_restoration, module_name=Pipelines.image_super_resolution)
class ImageSuperResolutionPipeline(Pipeline):

    def __init__(self, model: str):
        """
        Use `model` to create an image super-resolution pipeline for
        prediction.

        Args:
            model: model id on modelscope hub.
        """
        super().__init__(model=model)
        self.num_feat = 64
        self.num_block = 23
        self.scale = 4
        self.sr_model = rrdbnet_arch.RRDBNet(
            num_in_ch=3,
            num_out_ch=3,
            num_feat=self.num_feat,
            num_block=self.num_block,
            num_grow_ch=32,
            scale=self.scale)

        model_path = f'{self.model}/{ModelFile.TORCH_MODEL_FILE}'
        self.sr_model.load_state_dict(torch.load(model_path), strict=True)
        logger.info('load model done')

    def preprocess(self, input: Input) -> Dict[str, Any]:
        if isinstance(input, str):
            img = np.array(load_image(input))
        elif isinstance(input, PIL.Image.Image):
            img = np.array(input.convert('RGB'))
        elif isinstance(input, np.ndarray):
            if len(input.shape) == 2:
                input = cv2.cvtColor(input, cv2.COLOR_GRAY2BGR)
            # assume BGR input; copy so the reversed view is contiguous
            img = input[:, :, ::-1].copy()  # in rgb order
        else:
            raise TypeError(f'input should be either str, PIL.Image or'
                            f' np.ndarray, but got {type(input)}')

        img = torch.from_numpy(img).permute(2, 0, 1).unsqueeze(0) / 255.
        result = {'img': img}
        return result

    def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
        self.sr_model.eval()
        with torch.no_grad():
            out = self.sr_model(input['img'])
            out = out.squeeze(0).permute(1, 2, 0).flip(2)  # CHW RGB -> HWC BGR
        out_img = np.clip(out.float().cpu().numpy(), 0, 1) * 255
        return {OutputKeys.OUTPUT_IMG: out_img.astype(np.uint8)}

    def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        return inputs
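

# Usage sketch (illustrative; the model id is the one used by the unit test
# below, and the input path is hypothetical): the pipeline returns an HWC
# uint8 array in BGR channel order, ready for cv2.imwrite.
#
#   sr = pipeline(Tasks.image_restoration,
#                 model='damo/cv_rrdb_image-super-resolution')
#   result = sr('input.jpg')
#   cv2.imwrite('sr.png', result[OutputKeys.OUTPUT_IMG])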
@@ -27,6 +27,7 @@ class CVTasks(object):
    ocr_detection = 'ocr-detection'
    action_recognition = 'action-recognition'
    video_embedding = 'video-embedding'
    image_restoration = 'image-restoration'
    style_transfer = 'style-transfer'
@@ -0,0 +1,37 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os
import os.path as osp
import unittest

import cv2

from modelscope.msdatasets import MsDataset
from modelscope.outputs import OutputKeys
from modelscope.pipelines import pipeline
from modelscope.pipelines.base import Pipeline
from modelscope.utils.constant import Tasks
from modelscope.utils.test_utils import test_level


class ImageSuperResolutionTest(unittest.TestCase):

    def setUp(self) -> None:
        self.model_id = 'damo/cv_rrdb_image-super-resolution'
        self.img = 'data/test/images/dogs.jpg'

    def pipeline_inference(self, pipeline: Pipeline, img: str):
        result = pipeline(img)
        if result is not None:
            cv2.imwrite('result.png', result[OutputKeys.OUTPUT_IMG])
            print(f'Output written to {osp.abspath("result.png")}')

    @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
    def test_run_modelhub(self):
        super_resolution = pipeline(
            Tasks.image_restoration, model=self.model_id)
        self.pipeline_inference(super_resolution, self.img)


if __name__ == '__main__':
    unittest.main()