Add image-color-enhance model, pipeline and trainer
Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9483118
master
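For context, a minimal usage sketch of what this change enables, mirroring the new pipeline test added below (the model id, task constant and output key come from this PR; the input path is the test image it adds):

    import cv2
    from modelscope.outputs import OutputKeys
    from modelscope.pipelines import pipeline
    from modelscope.utils.constant import Tasks

    # build the new image-color-enhance pipeline backed by the CSRNet model
    ice = pipeline(Tasks.image_color_enhance,
                   model='damo/cv_csrnet_image-color-enhance-models')
    result = ice('data/test/images/image_color_enhance.png')
    # OUTPUT_IMG is an RGB uint8 array of shape [height, width, 3]; reorder to BGR for cv2
    cv2.imwrite('result.png', result[OutputKeys.OUTPUT_IMG][:, :, [2, 1, 0]])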
| @@ -0,0 +1,3 @@ | |||
| version https://git-lfs.github.com/spec/v1 | |||
| oid sha256:d64493ea0643b30129eaedacb2db9ca233c2d9c0d69209ff6d464d3cae4b4a5b | |||
| size 950676 | |||
| @@ -0,0 +1,3 @@ | |||
| version https://git-lfs.github.com/spec/v1 | |||
| oid sha256:d64493ea0643b30129eaedacb2db9ca233c2d9c0d69209ff6d464d3cae4b4a5b | |||
| size 950676 | |||
| @@ -0,0 +1,3 @@ | |||
| version https://git-lfs.github.com/spec/v1 | |||
| oid sha256:0a4a8a60501976b2c5e753814a346519ef6faff052b53359cf44b4e597e62aaf | |||
| size 902214 | |||
| @@ -0,0 +1,3 @@ | |||
| version https://git-lfs.github.com/spec/v1 | |||
| oid sha256:326c5e3907926a4af6fec382050026d505d78aab8c5f2e0ecc85ac863abbb94c | |||
| size 856195 | |||
| @@ -0,0 +1,3 @@ | |||
| version https://git-lfs.github.com/spec/v1 | |||
| oid sha256:455f364c008be76a392085e7590b9050a628853a9df1e608a40c75a15bc41c5f | |||
| size 951993 | |||
| @@ -0,0 +1,3 @@ | |||
| version https://git-lfs.github.com/spec/v1 | |||
| oid sha256:f806b26557317f856e7583fb128713579df3354016b368ef32791b283e3be051 | |||
| size 932493 | |||
| @@ -0,0 +1,3 @@ | |||
| version https://git-lfs.github.com/spec/v1 | |||
| oid sha256:0ec66811ec4f1ec8735b7f0eb897100f80939ba5dc150028fa91bfcd15b5164c | |||
| size 896481 | |||
| @@ -0,0 +1,3 @@ | |||
| version https://git-lfs.github.com/spec/v1 | |||
| oid sha256:d9517b185b0cffc0c830270fd52551e145054daa00c704ed4132589b24ab46e9 | |||
| size 828266 | |||
| @@ -0,0 +1,3 @@ | |||
| version https://git-lfs.github.com/spec/v1 | |||
| oid sha256:5a233195949ed1c3db9c9a182baf3d8f014620d28bab823aa4d4cc203e602bc6 | |||
| size 927552 | |||
| @@ -10,6 +10,7 @@ class Models(object): | |||
| Model name should only contain model info but not task info. | |||
| """ | |||
| # vision models | |||
| csrnet = 'csrnet' | |||
| # nlp models | |||
| bert = 'bert' | |||
| @@ -60,6 +61,7 @@ class Pipelines(object): | |||
| action_recognition = 'TAdaConv_action-recognition' | |||
| animal_recognation = 'resnet101-animal_recog' | |||
| cmdssl_video_embedding = 'cmdssl-r2p1d_video_embedding' | |||
| image_color_enhance = 'csrnet-image-color-enhance' | |||
| virtual_tryon = 'virtual_tryon' | |||
| image_colorization = 'unet-image-colorization' | |||
| image_super_resolution = 'rrdb-image-super-resolution' | |||
| @@ -121,6 +123,7 @@ class Preprocessors(object): | |||
| # cv preprocessor | |||
| load_image = 'load-image' | |||
| image_color_enhance_preprocessor = 'image-color-enhance-preprocessor' | |||
| # nlp preprocessor | |||
| sen_sim_tokenizer = 'sen-sim-tokenizer' | |||
| @@ -160,3 +163,5 @@ class Metrics(object): | |||
| token_cls_metric = 'token-cls-metric' | |||
| # metrics for text-generation task | |||
| text_gen_metric = 'text-gen-metric' | |||
| # metrics for image-color-enhance task | |||
| image_color_enhance_metric = 'image-color-enhance-metric' | |||
| @@ -1,4 +1,5 @@ | |||
| from .base import Metric | |||
| from .builder import METRICS, build_metric, task_default_metrics | |||
| from .image_color_enhance_metric import ImageColorEnhanceMetric | |||
| from .sequence_classification_metric import SequenceClassificationMetric | |||
| from .text_generation_metric import TextGenerationMetric | |||
| @@ -13,12 +13,15 @@ class MetricKeys(object): | |||
| F1 = 'f1' | |||
| PRECISION = 'precision' | |||
| RECALL = 'recall' | |||
| PSNR = 'psnr' | |||
| SSIM = 'ssim' | |||
| task_default_metrics = { | |||
| Tasks.sentence_similarity: [Metrics.seq_cls_metric], | |||
| Tasks.sentiment_classification: [Metrics.seq_cls_metric], | |||
| Tasks.text_generation: [Metrics.text_gen_metric], | |||
| Tasks.image_color_enhance: [Metrics.image_color_enhance_metric] | |||
| } | |||
| @@ -0,0 +1,258 @@ | |||
| # The code is modified based on BasicSR metrics: | |||
| # https://github.com/XPixelGroup/BasicSR/tree/master/basicsr/metrics | |||
| from typing import Dict | |||
| import cv2 | |||
| import numpy as np | |||
| from ..metainfo import Metrics | |||
| from ..utils.registry import default_group | |||
| from .base import Metric | |||
| from .builder import METRICS, MetricKeys | |||
| def bgr2ycbcr(img, y_only=False): | |||
| """Convert a BGR image to YCbCr image. | |||
| The bgr version of rgb2ycbcr. | |||
| It implements the ITU-R BT.601 conversion for standard-definition | |||
| television. See more details in | |||
| https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion. | |||
| It differs from a similar function in cv2.cvtColor: `BGR <-> YCrCb`. | |||
| In OpenCV, it implements a JPEG conversion. See more details in | |||
| https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion. | |||
| Args: | |||
| img (ndarray): The input image. It accepts: | |||
| 1. np.uint8 type with range [0, 255]; | |||
| 2. np.float32 type with range [0, 1]. | |||
| y_only (bool): Whether to only return Y channel. Default: False. | |||
| Returns: | |||
| ndarray: The converted YCbCr image. The output image has the same type | |||
| and range as input image. | |||
| """ | |||
| img_type = img.dtype | |||
| img = _convert_input_type_range(img) | |||
| if y_only: | |||
| out_img = np.dot(img, [24.966, 128.553, 65.481]) + 16.0 | |||
| else: | |||
| out_img = np.matmul( | |||
| img, [[24.966, 112.0, -18.214], [128.553, -74.203, -93.786], | |||
| [65.481, -37.797, 112.0]]) + [16, 128, 128] | |||
| out_img = _convert_output_type_range(out_img, img_type) | |||
| return out_img | |||
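As a quick sanity check of the coefficients above (given in BGR order), the y_only branch is the standard BT.601 luma transform Y = 16 + 65.481*R + 128.553*G + 24.966*B for inputs scaled to [0, 1], so a pure-white pixel maps to Y = 16 + 219 = 235, the nominal video-range peak; the _convert_*_type_range helpers this function relies on handle scaling uint8 inputs into and out of that range.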
| def reorder_image(img, input_order='HWC'): | |||
| """Reorder images to 'HWC' order. | |||
| If the input_order is (h, w), return (h, w, 1); | |||
| If the input_order is (c, h, w), return (h, w, c); | |||
| If the input_order is (h, w, c), return as it is. | |||
| Args: | |||
| img (ndarray): Input image. | |||
| input_order (str): Whether the input order is 'HWC' or 'CHW'. | |||
| If the input image shape is (h, w), input_order will not have | |||
| effects. Default: 'HWC'. | |||
| Returns: | |||
| ndarray: reordered image. | |||
| """ | |||
| if input_order not in ['HWC', 'CHW']: | |||
| raise ValueError( | |||
| f"Wrong input_order {input_order}. Supported input_orders are 'HWC' and 'CHW'" | |||
| ) | |||
| if len(img.shape) == 2: | |||
| img = img[..., None] | |||
| if input_order == 'CHW': | |||
| img = img.transpose(1, 2, 0) | |||
| return img | |||
| def to_y_channel(img): | |||
| """Change to Y channel of YCbCr. | |||
| Args: | |||
| img (ndarray): Images with range [0, 255]. | |||
| Returns: | |||
| (ndarray): Images with range [0, 255] (float type) without round. | |||
| """ | |||
| img = img.astype(np.float32) / 255. | |||
| if img.ndim == 3 and img.shape[2] == 3: | |||
| img = bgr2ycbcr(img, y_only=True) | |||
| img = img[..., None] | |||
| return img * 255. | |||
| def _ssim(img, img2): | |||
| """Calculate SSIM (structural similarity) for one channel images. | |||
| It is called by func:`calculate_ssim`. | |||
| Args: | |||
| img (ndarray): Images with range [0, 255] with order 'HWC'. | |||
| img2 (ndarray): Images with range [0, 255] with order 'HWC'. | |||
| Returns: | |||
| float: SSIM result. | |||
| """ | |||
| c1 = (0.01 * 255)**2 | |||
| c2 = (0.03 * 255)**2 | |||
| kernel = cv2.getGaussianKernel(11, 1.5) | |||
| window = np.outer(kernel, kernel.transpose()) | |||
| mu1 = cv2.filter2D(img, -1, window)[5:-5, | |||
| 5:-5] # valid mode for window size 11 | |||
| mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5] | |||
| mu1_sq = mu1**2 | |||
| mu2_sq = mu2**2 | |||
| mu1_mu2 = mu1 * mu2 | |||
| sigma1_sq = cv2.filter2D(img**2, -1, window)[5:-5, 5:-5] - mu1_sq | |||
| sigma2_sq = cv2.filter2D(img2**2, -1, window)[5:-5, 5:-5] - mu2_sq | |||
| sigma12 = cv2.filter2D(img * img2, -1, window)[5:-5, 5:-5] - mu1_mu2 | |||
| tmp1 = (2 * mu1_mu2 + c1) * (2 * sigma12 + c2) | |||
| tmp2 = (mu1_sq + mu2_sq + c1) * (sigma1_sq + sigma2_sq + c2) | |||
| ssim_map = tmp1 / tmp2 | |||
| return ssim_map.mean() | |||
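For reference, the local statistics above implement the standard single-scale SSIM index

    SSIM(x, y) = ((2*mu_x*mu_y + c1) * (2*sigma_xy + c2))
                 / ((mu_x^2 + mu_y^2 + c1) * (sigma_x^2 + sigma_y^2 + c2))

with c1 = (k1*L)^2, c2 = (k2*L)^2, k1 = 0.01, k2 = 0.03 and dynamic range L = 255, evaluated under an 11x11 Gaussian window (sigma = 1.5) and averaged over the valid region.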
| def calculate_psnr(img, | |||
| img2, | |||
| crop_border, | |||
| input_order='HWC', | |||
| test_y_channel=False, | |||
| **kwargs): | |||
| """Calculate PSNR (Peak Signal-to-Noise Ratio). | |||
| Ref: https://en.wikipedia.org/wiki/Peak_signal-to-noise_ratio | |||
| Args: | |||
| img (ndarray): Images with range [0, 255]. | |||
| img2 (ndarray): Images with range [0, 255]. | |||
| crop_border (int): Cropped pixels in each edge of an image. These pixels are not involved in the calculation. | |||
| input_order (str): Whether the input order is 'HWC' or 'CHW'. Default: 'HWC'. | |||
| test_y_channel (bool): Test on Y channel of YCbCr. Default: False. | |||
| Returns: | |||
| float: PSNR result. | |||
| """ | |||
| assert img.shape == img2.shape, ( | |||
| f'Image shapes are different: {img.shape}, {img2.shape}.') | |||
| if input_order not in ['HWC', 'CHW']: | |||
| raise ValueError( | |||
| f'Wrong input_order {input_order}. Supported input_orders are "HWC" and "CHW"' | |||
| ) | |||
| img = reorder_image(img, input_order=input_order) | |||
| img2 = reorder_image(img2, input_order=input_order) | |||
| if crop_border != 0: | |||
| img = img[crop_border:-crop_border, crop_border:-crop_border, ...] | |||
| img2 = img2[crop_border:-crop_border, crop_border:-crop_border, ...] | |||
| if test_y_channel: | |||
| img = to_y_channel(img) | |||
| img2 = to_y_channel(img2) | |||
| img = img.astype(np.float64) | |||
| img2 = img2.astype(np.float64) | |||
| mse = np.mean((img - img2)**2) | |||
| if mse == 0: | |||
| return float('inf') | |||
| return 10. * np.log10(255. * 255. / mse) | |||
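Since the return value is 10 * log10(255^2 / MSE), identical images give inf, a mean squared error of 1.0 gives 10 * log10(65025) ≈ 48.13 dB, and the maximal error MSE = 255^2 gives 0 dB.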
| def calculate_ssim(img, | |||
| img2, | |||
| crop_border, | |||
| input_order='HWC', | |||
| test_y_channel=False, | |||
| **kwargs): | |||
| """Calculate SSIM (structural similarity). | |||
| Ref: | |||
| Image quality assessment: From error visibility to structural similarity | |||
| The results are the same as that of the official released MATLAB code in | |||
| https://ece.uwaterloo.ca/~z70wang/research/ssim/. | |||
| For three-channel images, SSIM is calculated for each channel and then | |||
| averaged. | |||
| Args: | |||
| img (ndarray): Images with range [0, 255]. | |||
| img2 (ndarray): Images with range [0, 255]. | |||
| crop_border (int): Cropped pixels in each edge of an image. These pixels are not involved in the calculation. | |||
| input_order (str): Whether the input order is 'HWC' or 'CHW'. | |||
| Default: 'HWC'. | |||
| test_y_channel (bool): Test on Y channel of YCbCr. Default: False. | |||
| Returns: | |||
| float: SSIM result. | |||
| """ | |||
| assert img.shape == img2.shape, ( | |||
| f'Image shapes are different: {img.shape}, {img2.shape}.') | |||
| if input_order not in ['HWC', 'CHW']: | |||
| raise ValueError( | |||
| f'Wrong input_order {input_order}. Supported input_orders are "HWC" and "CHW"' | |||
| ) | |||
| img = reorder_image(img, input_order=input_order) | |||
| img2 = reorder_image(img2, input_order=input_order) | |||
| if crop_border != 0: | |||
| img = img[crop_border:-crop_border, crop_border:-crop_border, ...] | |||
| img2 = img2[crop_border:-crop_border, crop_border:-crop_border, ...] | |||
| if test_y_channel: | |||
| img = to_y_channel(img) | |||
| img2 = to_y_channel(img2) | |||
| img = img.astype(np.float64) | |||
| img2 = img2.astype(np.float64) | |||
| ssims = [] | |||
| for i in range(img.shape[2]): | |||
| ssims.append(_ssim(img[..., i], img2[..., i])) | |||
| return np.array(ssims).mean() | |||
| @METRICS.register_module( | |||
| group_key=default_group, module_name=Metrics.image_color_enhance_metric) | |||
| class ImageColorEnhanceMetric(Metric): | |||
| """The metric computation class for image color enhance classes. | |||
| """ | |||
| def __init__(self): | |||
| self.preds = [] | |||
| self.targets = [] | |||
| def add(self, outputs: Dict, inputs: Dict): | |||
| ground_truths = outputs['target'] | |||
| eval_results = outputs['pred'] | |||
| self.preds.extend(eval_results) | |||
| self.targets.extend(ground_truths) | |||
| def evaluate(self): | |||
| psnrs = [ | |||
| calculate_psnr(pred, target, 2, test_y_channel=False) | |||
| for pred, target in zip(self.preds, self.targets) | |||
| ] | |||
| ssims = [ | |||
| calculate_ssim(pred, target, 2, test_y_channel=False) | |||
| for pred, target in zip(self.preds, self.targets) | |||
| ] | |||
| return { | |||
| MetricKeys.PSNR: sum(psnrs) / len(psnrs), | |||
| MetricKeys.SSIM: sum(ssims) / len(ssims) | |||
| } | |||
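A small, hypothetical smoke test of the metric class (the uint8 HWC arrays mimic what the model's _evaluate_postprocess produces; it assumes the Metric base class needs no constructor arguments):

    import numpy as np
    from modelscope.metrics import ImageColorEnhanceMetric

    metric = ImageColorEnhanceMetric()
    pred = np.random.randint(0, 256, (64, 64, 3), dtype=np.uint8)
    target = np.clip(pred.astype(np.int16) + 5, 0, 255).astype(np.uint8)
    metric.add(outputs={'pred': [pred], 'target': [target]}, inputs={})
    print(metric.evaluate())  # -> {'psnr': ..., 'ssim': ...}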
| @@ -0,0 +1,2 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| from .image_color_enhance.image_color_enhance import ImageColorEnhance | |||
| @@ -0,0 +1,110 @@ | |||
| import functools | |||
| import math | |||
| import torch | |||
| import torch.nn as nn | |||
| import torch.nn.functional as F | |||
| class Condition(nn.Module): | |||
| def __init__(self, in_nc=3, nf=32): | |||
| super(Condition, self).__init__() | |||
| stride = 2 | |||
| pad = 0 | |||
| self.pad = nn.ZeroPad2d(1) | |||
| self.conv1 = nn.Conv2d(in_nc, nf, 7, stride, pad, bias=True) | |||
| self.conv2 = nn.Conv2d(nf, nf, 3, stride, pad, bias=True) | |||
| self.conv3 = nn.Conv2d(nf, nf, 3, stride, pad, bias=True) | |||
| self.act = nn.ReLU(inplace=True) | |||
| def forward(self, x): | |||
| conv1_out = self.act(self.conv1(self.pad(x))) | |||
| conv2_out = self.act(self.conv2(self.pad(conv1_out))) | |||
| conv3_out = self.act(self.conv3(self.pad(conv2_out))) | |||
| out = torch.mean(conv3_out, dim=[2, 3], keepdim=False) | |||
| return out | |||
| # 3layers with control | |||
| class CSRNet(nn.Module): | |||
| def __init__(self, in_nc=3, out_nc=3, base_nf=64, cond_nf=32): | |||
| super(CSRNet, self).__init__() | |||
| self.base_nf = base_nf | |||
| self.out_nc = out_nc | |||
| self.cond_net = Condition(in_nc=in_nc, nf=cond_nf) | |||
| self.cond_scale1 = nn.Linear(cond_nf, base_nf, bias=True) | |||
| self.cond_scale2 = nn.Linear(cond_nf, base_nf, bias=True) | |||
| self.cond_scale3 = nn.Linear(cond_nf, 3, bias=True) | |||
| self.cond_shift1 = nn.Linear(cond_nf, base_nf, bias=True) | |||
| self.cond_shift2 = nn.Linear(cond_nf, base_nf, bias=True) | |||
| self.cond_shift3 = nn.Linear(cond_nf, 3, bias=True) | |||
| self.conv1 = nn.Conv2d(in_nc, base_nf, 1, 1, bias=True) | |||
| self.conv2 = nn.Conv2d(base_nf, base_nf, 1, 1, bias=True) | |||
| self.conv3 = nn.Conv2d(base_nf, out_nc, 1, 1, bias=True) | |||
| self.act = nn.ReLU(inplace=True) | |||
| def forward(self, x): | |||
| cond = self.cond_net(x) | |||
| scale1 = self.cond_scale1(cond) | |||
| shift1 = self.cond_shift1(cond) | |||
| scale2 = self.cond_scale2(cond) | |||
| shift2 = self.cond_shift2(cond) | |||
| scale3 = self.cond_scale3(cond) | |||
| shift3 = self.cond_shift3(cond) | |||
| out = self.conv1(x) | |||
| out = out * scale1.view(-1, self.base_nf, 1, 1) + shift1.view( | |||
| -1, self.base_nf, 1, 1) + out | |||
| out = self.act(out) | |||
| out = self.conv2(out) | |||
| out = out * scale2.view(-1, self.base_nf, 1, 1) + shift2.view( | |||
| -1, self.base_nf, 1, 1) + out | |||
| out = self.act(out) | |||
| out = self.conv3(out) | |||
| out = out * scale3.view(-1, self.out_nc, 1, 1) + shift3.view( | |||
| -1, self.out_nc, 1, 1) + out | |||
| return out | |||
| class L1Loss(nn.Module): | |||
| """L1 (mean absolute error, MAE) loss. | |||
| Args: | |||
| loss_weight (float): Loss weight for L1 loss. Default: 1.0. | |||
| reduction (str): Specifies the reduction to apply to the output. | |||
| Supported choices are 'none' | 'mean' | 'sum'. Default: 'mean'. | |||
| """ | |||
| def __init__(self, loss_weight=1.0, reduction='mean'): | |||
| super(L1Loss, self).__init__() | |||
| _reduction_modes = ['none', 'mean', 'sum'] | |||
| if reduction not in _reduction_modes: | |||
| raise ValueError( | |||
| f'Unsupported reduction mode: {reduction}. Supported ones are: {_reduction_modes}' | |||
| ) | |||
| self.loss_weight = loss_weight | |||
| self.reduction = reduction | |||
| def forward(self, pred, target, weight=None, **kwargs): | |||
| """ | |||
| Args: | |||
| pred (Tensor): of shape (N, C, H, W). Predicted tensor. | |||
| target (Tensor): of shape (N, C, H, W). Ground truth tensor. | |||
| weight (Tensor, optional): of shape (N, C, H, W). Element-wise weights; accepted for API compatibility but not used by this implementation. Default: None. | |||
| """ | |||
| return self.loss_weight * F.l1_loss( | |||
| pred, target, reduction=self.reduction) | |||
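A quick shape check of the network and loss above (illustrative only): the condition branch global-average-pools its features to a (N, cond_nf) vector, and all main-branch convolutions are 1x1, so the output keeps the input resolution.

    import torch

    net = CSRNet(in_nc=3, out_nc=3, base_nf=64, cond_nf=32)
    x = torch.rand(2, 3, 64, 64)                # a batch of RGB images in [0, 1]
    y = net(x)                                  # -> torch.Size([2, 3, 64, 64])
    loss = L1Loss()(y, torch.rand_like(y))      # scalar L1 training loss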
| @@ -0,0 +1,109 @@ | |||
| import os.path as osp | |||
| from copy import deepcopy | |||
| from typing import Dict, Union | |||
| import torch | |||
| from torch.nn.parallel import DataParallel, DistributedDataParallel | |||
| from modelscope.metainfo import Models | |||
| from modelscope.models.base import Tensor, TorchModel | |||
| from modelscope.models.builder import MODELS | |||
| from modelscope.utils.constant import ModelFile, Tasks | |||
| from modelscope.utils.logger import get_logger | |||
| from .csrnet import CSRNet, L1Loss | |||
| logger = get_logger() | |||
| __all__ = ['ImageColorEnhance'] | |||
| @MODELS.register_module(Tasks.image_color_enhance, module_name=Models.csrnet) | |||
| class ImageColorEnhance(TorchModel): | |||
| def __init__(self, model_dir: str, *args, **kwargs): | |||
| """initialize the image color enhance model from the `model_dir` path. | |||
| Args: | |||
| model_dir (str): the model path. | |||
| """ | |||
| super().__init__(model_dir, *args, **kwargs) | |||
| model_path = osp.join(model_dir, ModelFile.TORCH_MODEL_FILE) | |||
| self.loss = L1Loss() | |||
| self.model = CSRNet() | |||
| if torch.cuda.is_available(): | |||
| self._device = torch.device('cuda') | |||
| else: | |||
| self._device = torch.device('cpu') | |||
| self.model = self.model.to(self._device) | |||
| self.model = self.load_pretrained(self.model, model_path) | |||
| if self.training: | |||
| self.model.train() | |||
| else: | |||
| self.model.eval() | |||
| def load_pretrained(self, net, load_path, strict=True, param_key='params'): | |||
| if isinstance(net, (DataParallel, DistributedDataParallel)): | |||
| net = net.module | |||
| load_net = torch.load( | |||
| load_path, map_location=lambda storage, loc: storage) | |||
| if param_key is not None: | |||
| if param_key not in load_net and 'params' in load_net: | |||
| param_key = 'params' | |||
| logger.info( | |||
| 'Requested param key not found in checkpoint, falling back to "params".') | |||
| if param_key in load_net: | |||
| load_net = load_net[param_key] | |||
| logger.info( | |||
| f'Loading {net.__class__.__name__} model from {load_path}, with param key: [{param_key}].' | |||
| ) | |||
| # remove unnecessary 'module.' | |||
| for k, v in deepcopy(load_net).items(): | |||
| if k.startswith('module.'): | |||
| load_net[k[7:]] = v | |||
| load_net.pop(k) | |||
| net.load_state_dict(load_net, strict=strict) | |||
| logger.info('load model done.') | |||
| return net | |||
| def _evaluate_postprocess(self, src: Tensor, | |||
| target: Tensor) -> Dict[str, list]: | |||
| preds = self.model(src) | |||
| preds = list(torch.split(preds, 1, 0)) | |||
| targets = list(torch.split(target, 1, 0)) | |||
| preds = [(pred.data * 255.).squeeze(0).type(torch.uint8).permute( | |||
| 1, 2, 0).cpu().numpy() for pred in preds] | |||
| targets = [(target.data * 255.).squeeze(0).type(torch.uint8).permute( | |||
| 1, 2, 0).cpu().numpy() for target in targets] | |||
| return {'pred': preds, 'target': targets} | |||
| def _train_forward(self, src: Tensor, target: Tensor) -> Dict[str, Tensor]: | |||
| preds = self.model(src) | |||
| return {'loss': self.loss(preds, target)} | |||
| def _inference_forward(self, src: Tensor) -> Dict[str, Tensor]: | |||
| return {'outputs': self.model(src).clamp(0, 1)} | |||
| def forward(self, input: Dict[str, | |||
| Tensor]) -> Dict[str, Union[list, Tensor]]: | |||
| """return the result by the model | |||
| Args: | |||
| input (Dict[str, Tensor]): the preprocessed data | |||
| Returns: | |||
| Dict[str, Union[list, Tensor]]: results | |||
| """ | |||
| for key, value in input.items(): | |||
| input[key] = input[key].to(self._device) | |||
| if self.training: | |||
| return self._train_forward(**input) | |||
| elif 'target' in input: | |||
| return self._evaluate_postprocess(**input) | |||
| else: | |||
| return self._inference_forward(**input) | |||
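The forward method above dispatches on the training flag and on the presence of a 'target' key; a hedged sketch of the three call patterns, assuming `model_dir` is a placeholder path that already holds a downloaded checkpoint (ModelFile.TORCH_MODEL_FILE) and that TorchModel exposes the usual nn.Module train()/eval() switches:

    import torch

    model = ImageColorEnhance(model_dir)                 # model_dir is a placeholder
    src = torch.rand(1, 3, 256, 256)
    target = torch.rand(1, 3, 256, 256)

    model.eval()
    out = model.forward({'src': src})                    # -> {'outputs': tensor clamped to [0, 1]}
    out = model.forward({'src': src, 'target': target})  # -> {'pred': [...], 'target': [...]} as uint8 HWC arrays
    model.train()
    out = model.forward({'src': src, 'target': target})  # -> {'loss': L1 loss tensor}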
| @@ -105,6 +105,12 @@ TASK_OUTPUTS = { | |||
| # } | |||
| Tasks.video_embedding: [OutputKeys.VIDEO_EMBEDDING], | |||
| # image_color_enhance result for a single sample | |||
| # { | |||
| # "output_img": np.ndarray with shape [height, width, 3], uint8 | |||
| # } | |||
| Tasks.image_color_enhance: [OutputKeys.OUTPUT_IMG], | |||
| # ============ nlp tasks =================== | |||
| # text classification result for single sample | |||
| @@ -70,6 +70,8 @@ DEFAULT_MODEL_FOR_PIPELINE = { | |||
| Tasks.text_to_image_synthesis: | |||
| (Pipelines.text_to_image_synthesis, | |||
| 'damo/cv_imagen_text-to-image-synthesis_tiny'), | |||
| Tasks.image_color_enhance: (Pipelines.image_color_enhance, | |||
| 'damo/cv_csrnet_image-color-enhance-models'), | |||
| Tasks.virtual_tryon: (Pipelines.virtual_tryon, | |||
| 'damo/cv_daflow_virtual-tryon_base'), | |||
| Tasks.image_colorization: (Pipelines.image_colorization, | |||
| @@ -6,6 +6,7 @@ try: | |||
| from .action_recognition_pipeline import ActionRecognitionPipeline | |||
| from .animal_recog_pipeline import AnimalRecogPipeline | |||
| from .cmdssl_video_embedding_pipleline import CMDSSLVideoEmbeddingPipeline | |||
| from .image_color_enhance_pipeline import ImageColorEnhancePipeline | |||
| from .virtual_tryon_pipeline import VirtualTryonPipeline | |||
| from .image_colorization_pipeline import ImageColorizationPipeline | |||
| from .image_super_resolution_pipeline import ImageSuperResolutionPipeline | |||
| @@ -0,0 +1,74 @@ | |||
| from typing import Any, Dict, Optional, Union | |||
| import cv2 | |||
| import numpy as np | |||
| import torch | |||
| from PIL import Image | |||
| from torchvision import transforms | |||
| from modelscope.metainfo import Pipelines | |||
| from modelscope.models.base import Model | |||
| from modelscope.models.cv.image_color_enhance.image_color_enhance import \ | |||
| ImageColorEnhance | |||
| from modelscope.outputs import OutputKeys | |||
| from modelscope.pipelines.base import Input | |||
| from modelscope.preprocessors import (ImageColorEnhanceFinetunePreprocessor, | |||
| load_image) | |||
| from modelscope.utils.constant import ModelFile, Tasks | |||
| from modelscope.utils.logger import get_logger | |||
| from ..base import Pipeline | |||
| from ..builder import PIPELINES | |||
| logger = get_logger() | |||
| @PIPELINES.register_module( | |||
| Tasks.image_color_enhance, module_name=Pipelines.image_color_enhance) | |||
| class ImageColorEnhancePipeline(Pipeline): | |||
| def __init__(self, | |||
| model: Union[ImageColorEnhance, str], | |||
| preprocessor: Optional[ | |||
| ImageColorEnhanceFinetunePreprocessor] = None, | |||
| **kwargs): | |||
| """ | |||
| use `model` and `preprocessor` to create an image color enhance pipeline for prediction | |||
| Args: | |||
| model: an ImageColorEnhance instance or a model id on the ModelScope hub. | |||
| """ | |||
| model = model if isinstance( | |||
| model, ImageColorEnhance) else Model.from_pretrained(model) | |||
| model.eval() | |||
| super().__init__(model=model, preprocessor=preprocessor, **kwargs) | |||
| if torch.cuda.is_available(): | |||
| self._device = torch.device('cuda') | |||
| else: | |||
| self._device = torch.device('cpu') | |||
| def preprocess(self, input: Input) -> Dict[str, Any]: | |||
| if isinstance(input, str): | |||
| img = load_image(input) | |||
| elif isinstance(input, Image.Image): | |||
| img = input.convert('RGB') | |||
| elif isinstance(input, np.ndarray): | |||
| if len(input.shape) == 2: | |||
| input = cv2.cvtColor(input, cv2.COLOR_GRAY2RGB) | |||
| img = Image.fromarray(input.astype('uint8')).convert('RGB') | |||
| else: | |||
| raise TypeError(f'input should be either str, PIL.Image,' | |||
| f' np.ndarray, but got {type(input)}') | |||
| test_transforms = transforms.Compose([transforms.ToTensor()]) | |||
| img = test_transforms(img) | |||
| result = {'src': img.unsqueeze(0).to(self._device)} | |||
| return result | |||
| @torch.no_grad() | |||
| def forward(self, input: Dict[str, Any]) -> Dict[str, Any]: | |||
| return super().forward(input) | |||
| def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: | |||
| output_img = (inputs['outputs'].squeeze(0) * 255.).type( | |||
| torch.uint8).cpu().permute(1, 2, 0).numpy() | |||
| return {OutputKeys.OUTPUT_IMG: output_img} | |||
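Besides a file path or URL, the preprocess above also accepts in-memory images; a short hedged sketch (model id as elsewhere in this PR, the random array is purely illustrative):

    import numpy as np
    from PIL import Image
    from modelscope.pipelines import pipeline
    from modelscope.utils.constant import Tasks

    ice = pipeline(Tasks.image_color_enhance,
                   model='damo/cv_csrnet_image-color-enhance-models')
    arr = (np.random.rand(128, 128, 3) * 255).astype(np.uint8)   # any HWC uint8 RGB array
    out_from_array = ice(arr)                                    # np.ndarray input
    out_from_pil = ice(Image.fromarray(arr))                     # PIL.Image input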
| @@ -19,6 +19,7 @@ try: | |||
| from .space.dialog_intent_prediction_preprocessor import * # noqa F403 | |||
| from .space.dialog_modeling_preprocessor import * # noqa F403 | |||
| from .space.dialog_state_tracking_preprocessor import * # noqa F403 | |||
| from .image import ImageColorEnhanceFinetunePreprocessor | |||
| except ModuleNotFoundError as e: | |||
| if str(e) == "No module named 'tensorflow'": | |||
| print(TENSORFLOW_IMPORT_ERROR.format('tts')) | |||
| @@ -1,12 +1,15 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| import io | |||
| from typing import Dict, Union | |||
| from typing import Any, Dict, Union | |||
| import torch | |||
| from PIL import Image, ImageOps | |||
| from modelscope.fileio import File | |||
| from modelscope.metainfo import Preprocessors | |||
| from modelscope.utils.constant import Fields | |||
| from modelscope.utils.type_assert import type_assert | |||
| from .base import Preprocessor | |||
| from .builder import PREPROCESSORS | |||
| @@ -66,3 +69,36 @@ def load_image(image_path_or_url: str) -> Image.Image: | |||
| """ | |||
| loader = LoadImage() | |||
| return loader(image_path_or_url)['img'] | |||
| @PREPROCESSORS.register_module( | |||
| Fields.cv, module_name=Preprocessors.image_color_enhance_preprocessor) | |||
| class ImageColorEnhanceFinetunePreprocessor(Preprocessor): | |||
| def __init__(self, model_dir: str, *args, **kwargs): | |||
| """preprocess the data from the `model_dir` path | |||
| Args: | |||
| model_dir (str): model path | |||
| """ | |||
| super().__init__(*args, **kwargs) | |||
| self.model_dir: str = model_dir | |||
| @type_assert(object, object) | |||
| def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]: | |||
| """process the raw input data | |||
| Args: | |||
| data (tuple): [sentence1, sentence2] | |||
| sentence1 (str): a sentence | |||
| Example: | |||
| 'you are so handsome.' | |||
| sentence2 (str): a sentence | |||
| Example: | |||
| 'you are so beautiful.' | |||
| Returns: | |||
| Dict[str, Any]: the preprocessed data | |||
| """ | |||
| return data | |||
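Because the preprocessor is a pass-through, the finetune dataset is expected to yield ready-to-use tensors; the trainer test added in this PR produces items of the form

    {'src': float tensor of shape (3, H, W) in [0, 1],       # low-quality input
     'target': float tensor of shape (3, H, W) in [0, 1]}    # ground-truth enhanced image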
| @@ -27,6 +27,7 @@ class CVTasks(object): | |||
| ocr_detection = 'ocr-detection' | |||
| action_recognition = 'action-recognition' | |||
| video_embedding = 'video-embedding' | |||
| image_color_enhance = 'image-color-enhance' | |||
| virtual_tryon = 'virtual-tryon' | |||
| image_colorization = 'image-colorization' | |||
| face_image_generation = 'face-image-generation' | |||
| @@ -0,0 +1,42 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| import os | |||
| import os.path as osp | |||
| import unittest | |||
| import cv2 | |||
| from modelscope.outputs import OutputKeys | |||
| from modelscope.pipelines import pipeline | |||
| from modelscope.pipelines.base import Pipeline | |||
| from modelscope.utils.constant import Tasks | |||
| from modelscope.utils.test_utils import test_level | |||
| class ImageColorEnhanceTest(unittest.TestCase): | |||
| def setUp(self) -> None: | |||
| self.model_id = 'damo/cv_csrnet_image-color-enhance-models' | |||
| def pipeline_inference(self, pipeline: Pipeline, input_location: str): | |||
| result = pipeline(input_location) | |||
| if result is not None: | |||
| cv2.imwrite('result.png', result[OutputKeys.OUTPUT_IMG][:, :, | |||
| [2, 1, 0]]) | |||
| print(f'Output written to {osp.abspath("result.png")}') | |||
| @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | |||
| def test_run_modelhub(self): | |||
| img_color_enhance = pipeline( | |||
| Tasks.image_color_enhance, model=self.model_id) | |||
| self.pipeline_inference(img_color_enhance, | |||
| 'data/test/images/image_color_enhance.png') | |||
| @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') | |||
| def test_run_modelhub_default_model(self): | |||
| img_color_enhance = pipeline(Tasks.image_color_enhance) | |||
| self.pipeline_inference(img_color_enhance, | |||
| 'data/test/images/image_color_enhance.png') | |||
| if __name__ == '__main__': | |||
| unittest.main() | |||
| @@ -0,0 +1,115 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| import os | |||
| import os.path as osp | |||
| import shutil | |||
| import tempfile | |||
| import unittest | |||
| from typing import Callable, List, Optional, Tuple, Union | |||
| import cv2 | |||
| import torch | |||
| from torch.utils import data as data | |||
| from modelscope.hub.snapshot_download import snapshot_download | |||
| from modelscope.models.cv.image_color_enhance.image_color_enhance import \ | |||
| ImageColorEnhance | |||
| from modelscope.trainers import build_trainer | |||
| from modelscope.utils.constant import ModelFile | |||
| from modelscope.utils.test_utils import test_level | |||
| class TestImageColorEnhanceTrainer(unittest.TestCase): | |||
| def setUp(self): | |||
| print('Testing %s.%s' % (type(self).__name__, self._testMethodName)) | |||
| self.tmp_dir = tempfile.TemporaryDirectory().name | |||
| if not os.path.exists(self.tmp_dir): | |||
| os.makedirs(self.tmp_dir) | |||
| self.model_id = 'damo/cv_csrnet_image-color-enhance-models' | |||
| class PairedImageDataset(data.Dataset): | |||
| def __init__(self, root): | |||
| super(PairedImageDataset, self).__init__() | |||
| gt_dir = osp.join(root, 'gt') | |||
| lq_dir = osp.join(root, 'lq') | |||
| self.gt_filelist = os.listdir(gt_dir) | |||
| self.gt_filelist = sorted( | |||
| self.gt_filelist, key=lambda x: int(x[:-4])) | |||
| self.gt_filelist = [ | |||
| osp.join(gt_dir, f) for f in self.gt_filelist | |||
| ] | |||
| self.lq_filelist = os.listdir(lq_dir) | |||
| self.lq_filelist = sorted( | |||
| self.lq_filelist, key=lambda x: int(x[:-4])) | |||
| self.lq_filelist = [ | |||
| osp.join(lq_dir, f) for f in self.lq_filelist | |||
| ] | |||
| def _img_to_tensor(self, img): | |||
| return torch.from_numpy(img[:, :, [2, 1, 0]]).permute( | |||
| 2, 0, 1).type(torch.float32) / 255. | |||
| def __getitem__(self, index): | |||
| lq = cv2.imread(self.lq_filelist[index]) | |||
| gt = cv2.imread(self.gt_filelist[index]) | |||
| lq = cv2.resize(lq, (256, 256), interpolation=cv2.INTER_CUBIC) | |||
| gt = cv2.resize(gt, (256, 256), interpolation=cv2.INTER_CUBIC) | |||
| return \ | |||
| {'src': self._img_to_tensor(lq), 'target': self._img_to_tensor(gt)} | |||
| def __len__(self): | |||
| return len(self.gt_filelist) | |||
| def to_torch_dataset(self, | |||
| columns: Union[str, List[str]] = None, | |||
| preprocessors: Union[Callable, | |||
| List[Callable]] = None, | |||
| **format_kwargs): | |||
| return self | |||
| self.dataset = PairedImageDataset( | |||
| './data/test/images/image_color_enhance/') | |||
| def tearDown(self): | |||
| shutil.rmtree(self.tmp_dir, ignore_errors=True) | |||
| super().tearDown() | |||
| @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | |||
| def test_trainer(self): | |||
| kwargs = dict( | |||
| model=self.model_id, | |||
| train_dataset=self.dataset, | |||
| eval_dataset=self.dataset, | |||
| work_dir=self.tmp_dir) | |||
| trainer = build_trainer(default_args=kwargs) | |||
| trainer.train() | |||
| results_files = os.listdir(self.tmp_dir) | |||
| self.assertIn(f'{trainer.timestamp}.log.json', results_files) | |||
| for i in range(3): | |||
| self.assertIn(f'epoch_{i+1}.pth', results_files) | |||
| @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') | |||
| def test_trainer_with_model_and_args(self): | |||
| cache_path = snapshot_download(self.model_id) | |||
| model = ImageColorEnhance.from_pretrained(cache_path) | |||
| kwargs = dict( | |||
| cfg_file=os.path.join(cache_path, ModelFile.CONFIGURATION), | |||
| model=model, | |||
| train_dataset=self.dataset, | |||
| eval_dataset=self.dataset, | |||
| max_epochs=2, | |||
| work_dir=self.tmp_dir) | |||
| trainer = build_trainer(default_args=kwargs) | |||
| trainer.train() | |||
| results_files = os.listdir(self.tmp_dir) | |||
| self.assertIn(f'{trainer.timestamp}.log.json', results_files) | |||
| for i in range(2): | |||
| self.assertIn(f'epoch_{i+1}.pth', results_files) | |||
| if __name__ == '__main__': | |||
| unittest.main() | |||