Add image-color-enhance model, pipeline and trainer
Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9483118
master
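Reviewer note: a minimal usage sketch of the new pipeline added in this PR, mirroring the pipeline test further below (assumes the `damo/cv_csrnet_image-color-enhance-models` weights are available on the hub and a test image exists at the given path):

import cv2
from modelscope.outputs import OutputKeys
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

# Build the pipeline from the CSRNet model registered in this PR.
enhancer = pipeline(Tasks.image_color_enhance,
                    model='damo/cv_csrnet_image-color-enhance-models')
# Input may be an image path/URL, a PIL.Image or a np.ndarray.
result = enhancer('data/test/images/image_color_enhance.png')
# OUTPUT_IMG is an RGB uint8 HWC array; flip to BGR before writing with OpenCV.
cv2.imwrite('result.png', result[OutputKeys.OUTPUT_IMG][:, :, ::-1])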
| @@ -0,0 +1,3 @@ | |||||
| version https://git-lfs.github.com/spec/v1 | |||||
| oid sha256:d64493ea0643b30129eaedacb2db9ca233c2d9c0d69209ff6d464d3cae4b4a5b | |||||
| size 950676 | |||||
| @@ -0,0 +1,3 @@ | |||||
| version https://git-lfs.github.com/spec/v1 | |||||
| oid sha256:d64493ea0643b30129eaedacb2db9ca233c2d9c0d69209ff6d464d3cae4b4a5b | |||||
| size 950676 | |||||
| @@ -0,0 +1,3 @@ | |||||
| version https://git-lfs.github.com/spec/v1 | |||||
| oid sha256:0a4a8a60501976b2c5e753814a346519ef6faff052b53359cf44b4e597e62aaf | |||||
| size 902214 | |||||
| @@ -0,0 +1,3 @@ | |||||
| version https://git-lfs.github.com/spec/v1 | |||||
| oid sha256:326c5e3907926a4af6fec382050026d505d78aab8c5f2e0ecc85ac863abbb94c | |||||
| size 856195 | |||||
| @@ -0,0 +1,3 @@ | |||||
| version https://git-lfs.github.com/spec/v1 | |||||
| oid sha256:455f364c008be76a392085e7590b9050a628853a9df1e608a40c75a15bc41c5f | |||||
| size 951993 | |||||
| @@ -0,0 +1,3 @@ | |||||
| version https://git-lfs.github.com/spec/v1 | |||||
| oid sha256:f806b26557317f856e7583fb128713579df3354016b368ef32791b283e3be051 | |||||
| size 932493 | |||||
| @@ -0,0 +1,3 @@ | |||||
| version https://git-lfs.github.com/spec/v1 | |||||
| oid sha256:0ec66811ec4f1ec8735b7f0eb897100f80939ba5dc150028fa91bfcd15b5164c | |||||
| size 896481 | |||||
| @@ -0,0 +1,3 @@ | |||||
| version https://git-lfs.github.com/spec/v1 | |||||
| oid sha256:d9517b185b0cffc0c830270fd52551e145054daa00c704ed4132589b24ab46e9 | |||||
| size 828266 | |||||
| @@ -0,0 +1,3 @@ | |||||
| version https://git-lfs.github.com/spec/v1 | |||||
| oid sha256:5a233195949ed1c3db9c9a182baf3d8f014620d28bab823aa4d4cc203e602bc6 | |||||
| size 927552 | |||||
| @@ -10,6 +10,7 @@ class Models(object): | |||||
| Model name should only contain model info but not task info. | Model name should only contain model info but not task info. | ||||
| """ | """ | ||||
| # vision models | # vision models | ||||
| csrnet = 'csrnet' | |||||
| # nlp models | # nlp models | ||||
| bert = 'bert' | bert = 'bert' | ||||
| @@ -60,6 +61,7 @@ class Pipelines(object): | |||||
| action_recognition = 'TAdaConv_action-recognition' | action_recognition = 'TAdaConv_action-recognition' | ||||
| animal_recognation = 'resnet101-animal_recog' | animal_recognation = 'resnet101-animal_recog' | ||||
| cmdssl_video_embedding = 'cmdssl-r2p1d_video_embedding' | cmdssl_video_embedding = 'cmdssl-r2p1d_video_embedding' | ||||
| image_color_enhance = 'csrnet-image-color-enhance' | |||||
| virtual_tryon = 'virtual_tryon' | virtual_tryon = 'virtual_tryon' | ||||
| image_colorization = 'unet-image-colorization' | image_colorization = 'unet-image-colorization' | ||||
| image_super_resolution = 'rrdb-image-super-resolution' | image_super_resolution = 'rrdb-image-super-resolution' | ||||
| @@ -121,6 +123,7 @@ class Preprocessors(object): | |||||
| # cv preprocessor | # cv preprocessor | ||||
| load_image = 'load-image' | load_image = 'load-image' | ||||
| image_color_enhance_preprocessor = 'image-color-enhance-preprocessor' | |||||
| # nlp preprocessor | # nlp preprocessor | ||||
| sen_sim_tokenizer = 'sen-sim-tokenizer' | sen_sim_tokenizer = 'sen-sim-tokenizer' | ||||
| @@ -160,3 +163,5 @@ class Metrics(object): | |||||
| token_cls_metric = 'token-cls-metric' | token_cls_metric = 'token-cls-metric' | ||||
| # metrics for text-generation task | # metrics for text-generation task | ||||
| text_gen_metric = 'text-gen-metric' | text_gen_metric = 'text-gen-metric' | ||||
| # metrics for image-color-enhance task | |||||
| image_color_enhance_metric = 'image-color-enhance-metric' | |||||
| @@ -1,4 +1,5 @@ | |||||
| from .base import Metric | from .base import Metric | ||||
| from .builder import METRICS, build_metric, task_default_metrics | from .builder import METRICS, build_metric, task_default_metrics | ||||
| from .image_color_enhance_metric import ImageColorEnhanceMetric | |||||
| from .sequence_classification_metric import SequenceClassificationMetric | from .sequence_classification_metric import SequenceClassificationMetric | ||||
| from .text_generation_metric import TextGenerationMetric | from .text_generation_metric import TextGenerationMetric | ||||
| @@ -13,12 +13,15 @@ class MetricKeys(object): | |||||
| F1 = 'f1' | F1 = 'f1' | ||||
| PRECISION = 'precision' | PRECISION = 'precision' | ||||
| RECALL = 'recall' | RECALL = 'recall' | ||||
| PSNR = 'psnr' | |||||
| SSIM = 'ssim' | |||||
| task_default_metrics = { | task_default_metrics = { | ||||
| Tasks.sentence_similarity: [Metrics.seq_cls_metric], | Tasks.sentence_similarity: [Metrics.seq_cls_metric], | ||||
| Tasks.sentiment_classification: [Metrics.seq_cls_metric], | Tasks.sentiment_classification: [Metrics.seq_cls_metric], | ||||
| Tasks.text_generation: [Metrics.text_gen_metric], | Tasks.text_generation: [Metrics.text_gen_metric], | ||||
| Tasks.image_color_enhance: [Metrics.image_color_enhance_metric] | |||||
| } | } | ||||
| @@ -0,0 +1,258 @@ | |||||
| # The code is adapted from BasicSR metrics: | |||||
| # https://github.com/XPixelGroup/BasicSR/tree/master/basicsr/metrics | |||||
| from typing import Dict | |||||
| import cv2 | |||||
| import numpy as np | |||||
| from ..metainfo import Metrics | |||||
| from ..utils.registry import default_group | |||||
| from .base import Metric | |||||
| from .builder import METRICS, MetricKeys | |||||
| def bgr2ycbcr(img, y_only=False): | |||||
| """Convert a BGR image to YCbCr image. | |||||
| The bgr version of rgb2ycbcr. | |||||
| It implements the ITU-R BT.601 conversion for standard-definition | |||||
| television. See more details in | |||||
| https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion. | |||||
| It differs from a similar function in cv2.cvtColor: `BGR <-> YCrCb`. | |||||
| In OpenCV, it implements a JPEG conversion. See more details in | |||||
| https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion. | |||||
| Args: | |||||
| img (ndarray): The input image. It accepts: | |||||
| 1. np.uint8 type with range [0, 255]; | |||||
| 2. np.float32 type with range [0, 1]. | |||||
| y_only (bool): Whether to only return Y channel. Default: False. | |||||
| Returns: | |||||
| ndarray: The converted YCbCr image. The output image has the same type | |||||
| and range as input image. | |||||
| """ | |||||
| img_type = img.dtype | |||||
| img = _convert_input_type_range(img) | |||||
| if y_only: | |||||
| out_img = np.dot(img, [24.966, 128.553, 65.481]) + 16.0 | |||||
| else: | |||||
| out_img = np.matmul( | |||||
| img, [[24.966, 112.0, -18.214], [128.553, -74.203, -93.786], | |||||
| [65.481, -37.797, 112.0]]) + [16, 128, 128] | |||||
| out_img = _convert_output_type_range(out_img, img_type) | |||||
| return out_img | |||||
| def reorder_image(img, input_order='HWC'): | |||||
| """Reorder images to 'HWC' order. | |||||
| If the input_order is (h, w), return (h, w, 1); | |||||
| If the input_order is (c, h, w), return (h, w, c); | |||||
| If the input_order is (h, w, c), return as it is. | |||||
| Args: | |||||
| img (ndarray): Input image. | |||||
| input_order (str): Whether the input order is 'HWC' or 'CHW'. | |||||
| If the input image shape is (h, w), input_order will not have | |||||
| effects. Default: 'HWC'. | |||||
| Returns: | |||||
| ndarray: reordered image. | |||||
| """ | |||||
| if input_order not in ['HWC', 'CHW']: | |||||
| raise ValueError( | |||||
| f"Wrong input_order {input_order}. Supported input_orders are 'HWC' and 'CHW'" | |||||
| ) | |||||
| if len(img.shape) == 2: | |||||
| img = img[..., None] | |||||
| if input_order == 'CHW': | |||||
| img = img.transpose(1, 2, 0) | |||||
| return img | |||||
| def to_y_channel(img): | |||||
| """Change to Y channel of YCbCr. | |||||
| Args: | |||||
| img (ndarray): Images with range [0, 255]. | |||||
| Returns: | |||||
| (ndarray): Images with range [0, 255] (float type) without round. | |||||
| """ | |||||
| img = img.astype(np.float32) / 255. | |||||
| if img.ndim == 3 and img.shape[2] == 3: | |||||
| img = bgr2ycbcr(img, y_only=True) | |||||
| img = img[..., None] | |||||
| return img * 255. | |||||
| def _ssim(img, img2): | |||||
| """Calculate SSIM (structural similarity) for one channel images. | |||||
| It is called by func:`calculate_ssim`. | |||||
| Args: | |||||
| img (ndarray): Images with range [0, 255] with order 'HWC'. | |||||
| img2 (ndarray): Images with range [0, 255] with order 'HWC'. | |||||
| Returns: | |||||
| float: SSIM result. | |||||
| """ | |||||
| c1 = (0.01 * 255)**2 | |||||
| c2 = (0.03 * 255)**2 | |||||
| kernel = cv2.getGaussianKernel(11, 1.5) | |||||
| window = np.outer(kernel, kernel.transpose()) | |||||
| mu1 = cv2.filter2D(img, -1, window)[5:-5, | |||||
| 5:-5] # valid mode for window size 11 | |||||
| mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5] | |||||
| mu1_sq = mu1**2 | |||||
| mu2_sq = mu2**2 | |||||
| mu1_mu2 = mu1 * mu2 | |||||
| sigma1_sq = cv2.filter2D(img**2, -1, window)[5:-5, 5:-5] - mu1_sq | |||||
| sigma2_sq = cv2.filter2D(img2**2, -1, window)[5:-5, 5:-5] - mu2_sq | |||||
| sigma12 = cv2.filter2D(img * img2, -1, window)[5:-5, 5:-5] - mu1_mu2 | |||||
| tmp1 = (2 * mu1_mu2 + c1) * (2 * sigma12 + c2) | |||||
| tmp2 = (mu1_sq + mu2_sq + c1) * (sigma1_sq + sigma2_sq + c2) | |||||
| ssim_map = tmp1 / tmp2 | |||||
| return ssim_map.mean() | |||||
| def calculate_psnr(img, | |||||
| img2, | |||||
| crop_border, | |||||
| input_order='HWC', | |||||
| test_y_channel=False, | |||||
| **kwargs): | |||||
| """Calculate PSNR (Peak Signal-to-Noise Ratio). | |||||
| Ref: https://en.wikipedia.org/wiki/Peak_signal-to-noise_ratio | |||||
| Args: | |||||
| img (ndarray): Images with range [0, 255]. | |||||
| img2 (ndarray): Images with range [0, 255]. | |||||
| crop_border (int): Cropped pixels in each edge of an image. These pixels are not involved in the calculation. | |||||
| input_order (str): Whether the input order is 'HWC' or 'CHW'. Default: 'HWC'. | |||||
| test_y_channel (bool): Test on Y channel of YCbCr. Default: False. | |||||
| Returns: | |||||
| float: PSNR result. | |||||
| """ | |||||
| assert img.shape == img2.shape, ( | |||||
| f'Image shapes are different: {img.shape}, {img2.shape}.') | |||||
| if input_order not in ['HWC', 'CHW']: | |||||
| raise ValueError( | |||||
| f'Wrong input_order {input_order}. Supported input_orders are "HWC" and "CHW"' | |||||
| ) | |||||
| img = reorder_image(img, input_order=input_order) | |||||
| img2 = reorder_image(img2, input_order=input_order) | |||||
| if crop_border != 0: | |||||
| img = img[crop_border:-crop_border, crop_border:-crop_border, ...] | |||||
| img2 = img2[crop_border:-crop_border, crop_border:-crop_border, ...] | |||||
| if test_y_channel: | |||||
| img = to_y_channel(img) | |||||
| img2 = to_y_channel(img2) | |||||
| img = img.astype(np.float64) | |||||
| img2 = img2.astype(np.float64) | |||||
| mse = np.mean((img - img2)**2) | |||||
| if mse == 0: | |||||
| return float('inf') | |||||
| return 10. * np.log10(255. * 255. / mse) | |||||
| def calculate_ssim(img, | |||||
| img2, | |||||
| crop_border, | |||||
| input_order='HWC', | |||||
| test_y_channel=False, | |||||
| **kwargs): | |||||
| """Calculate SSIM (structural similarity). | |||||
| Ref: | |||||
| Image quality assessment: From error visibility to structural similarity | |||||
| The results are the same as that of the official released MATLAB code in | |||||
| https://ece.uwaterloo.ca/~z70wang/research/ssim/. | |||||
| For three-channel images, SSIM is calculated for each channel and then | |||||
| averaged. | |||||
| Args: | |||||
| img (ndarray): Images with range [0, 255]. | |||||
| img2 (ndarray): Images with range [0, 255]. | |||||
| crop_border (int): Cropped pixels in each edge of an image. These pixels are not involved in the calculation. | |||||
| input_order (str): Whether the input order is 'HWC' or 'CHW'. | |||||
| Default: 'HWC'. | |||||
| test_y_channel (bool): Test on Y channel of YCbCr. Default: False. | |||||
| Returns: | |||||
| float: SSIM result. | |||||
| """ | |||||
| assert img.shape == img2.shape, ( | |||||
| f'Image shapes are different: {img.shape}, {img2.shape}.') | |||||
| if input_order not in ['HWC', 'CHW']: | |||||
| raise ValueError( | |||||
| f'Wrong input_order {input_order}. Supported input_orders are "HWC" and "CHW"' | |||||
| ) | |||||
| img = reorder_image(img, input_order=input_order) | |||||
| img2 = reorder_image(img2, input_order=input_order) | |||||
| if crop_border != 0: | |||||
| img = img[crop_border:-crop_border, crop_border:-crop_border, ...] | |||||
| img2 = img2[crop_border:-crop_border, crop_border:-crop_border, ...] | |||||
| if test_y_channel: | |||||
| img = to_y_channel(img) | |||||
| img2 = to_y_channel(img2) | |||||
| img = img.astype(np.float64) | |||||
| img2 = img2.astype(np.float64) | |||||
| ssims = [] | |||||
| for i in range(img.shape[2]): | |||||
| ssims.append(_ssim(img[..., i], img2[..., i])) | |||||
| return np.array(ssims).mean() | |||||
| @METRICS.register_module( | |||||
| group_key=default_group, module_name=Metrics.image_color_enhance_metric) | |||||
| class ImageColorEnhanceMetric(Metric): | |||||
| """The metric computation class for image color enhance classes. | |||||
| """ | |||||
| def __init__(self): | |||||
| self.preds = [] | |||||
| self.targets = [] | |||||
| def add(self, outputs: Dict, inputs: Dict): | |||||
| ground_truths = outputs['target'] | |||||
| eval_results = outputs['pred'] | |||||
| self.preds.extend(eval_results) | |||||
| self.targets.extend(ground_truths) | |||||
| def evaluate(self): | |||||
| psnrs = [ | |||||
| calculate_psnr(pred, target, 2, test_y_channel=False) | |||||
| for pred, target in zip(self.preds, self.targets) | |||||
| ] | |||||
| ssims = [ | |||||
| calculate_ssim(pred, target, 2, test_y_channel=False) | |||||
| for pred, target in zip(self.preds, self.targets) | |||||
| ] | |||||
| return { | |||||
| MetricKeys.PSNR: sum(psnrs) / len(psnrs), | |||||
| MetricKeys.SSIM: sum(ssims) / len(ssims) | |||||
| } | |||||
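Reviewer note: `bgr2ycbcr` above calls `_convert_input_type_range` and `_convert_output_type_range`, which do not appear in this hunk. If they were not carried over from BasicSR's metric utilities, the metric will fail at runtime; a minimal sketch of the two helpers, following the BasicSR convention (float32 images in [0, 1], uint8 images in [0, 255]), would be:

def _convert_input_type_range(img):
    # Normalize the input to float32 in the [0, 1] range.
    img_type = img.dtype
    img = img.astype(np.float32)
    if img_type == np.uint8:
        img /= 255.
    elif img_type != np.float32:
        raise TypeError(f'The img type should be np.float32 or np.uint8, but got {img_type}')
    return img

def _convert_output_type_range(img, dst_type):
    # Convert back to the type and range of the original input image.
    if dst_type not in (np.uint8, np.float32):
        raise TypeError(f'The dst_type should be np.float32 or np.uint8, but got {dst_type}')
    if dst_type == np.uint8:
        img = img.round()
    else:
        img /= 255.
    return img.astype(dst_type)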
| @@ -0,0 +1,2 @@ | |||||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||||
| from .image_color_enhance.image_color_enhance import ImageColorEnhance | |||||
| @@ -0,0 +1,110 @@ | |||||
| import functools | |||||
| import math | |||||
| import torch | |||||
| import torch.nn as nn | |||||
| import torch.nn.functional as F | |||||
| class Condition(nn.Module): | |||||
| def __init__(self, in_nc=3, nf=32): | |||||
| super(Condition, self).__init__() | |||||
| stride = 2 | |||||
| pad = 0 | |||||
| self.pad = nn.ZeroPad2d(1) | |||||
| self.conv1 = nn.Conv2d(in_nc, nf, 7, stride, pad, bias=True) | |||||
| self.conv2 = nn.Conv2d(nf, nf, 3, stride, pad, bias=True) | |||||
| self.conv3 = nn.Conv2d(nf, nf, 3, stride, pad, bias=True) | |||||
| self.act = nn.ReLU(inplace=True) | |||||
| def forward(self, x): | |||||
| conv1_out = self.act(self.conv1(self.pad(x))) | |||||
| conv2_out = self.act(self.conv2(self.pad(conv1_out))) | |||||
| conv3_out = self.act(self.conv3(self.pad(conv2_out))) | |||||
| out = torch.mean(conv3_out, dim=[2, 3], keepdim=False) | |||||
| return out | |||||
| # 3layers with control | |||||
| class CSRNet(nn.Module): | |||||
| def __init__(self, in_nc=3, out_nc=3, base_nf=64, cond_nf=32): | |||||
| super(CSRNet, self).__init__() | |||||
| self.base_nf = base_nf | |||||
| self.out_nc = out_nc | |||||
| self.cond_net = Condition(in_nc=in_nc, nf=cond_nf) | |||||
| self.cond_scale1 = nn.Linear(cond_nf, base_nf, bias=True) | |||||
| self.cond_scale2 = nn.Linear(cond_nf, base_nf, bias=True) | |||||
| self.cond_scale3 = nn.Linear(cond_nf, 3, bias=True) | |||||
| self.cond_shift1 = nn.Linear(cond_nf, base_nf, bias=True) | |||||
| self.cond_shift2 = nn.Linear(cond_nf, base_nf, bias=True) | |||||
| self.cond_shift3 = nn.Linear(cond_nf, 3, bias=True) | |||||
| self.conv1 = nn.Conv2d(in_nc, base_nf, 1, 1, bias=True) | |||||
| self.conv2 = nn.Conv2d(base_nf, base_nf, 1, 1, bias=True) | |||||
| self.conv3 = nn.Conv2d(base_nf, out_nc, 1, 1, bias=True) | |||||
| self.act = nn.ReLU(inplace=True) | |||||
| def forward(self, x): | |||||
| cond = self.cond_net(x) | |||||
| scale1 = self.cond_scale1(cond) | |||||
| shift1 = self.cond_shift1(cond) | |||||
| scale2 = self.cond_scale2(cond) | |||||
| shift2 = self.cond_shift2(cond) | |||||
| scale3 = self.cond_scale3(cond) | |||||
| shift3 = self.cond_shift3(cond) | |||||
| out = self.conv1(x) | |||||
| out = out * scale1.view(-1, self.base_nf, 1, 1) + shift1.view( | |||||
| -1, self.base_nf, 1, 1) + out | |||||
| out = self.act(out) | |||||
| out = self.conv2(out) | |||||
| out = out * scale2.view(-1, self.base_nf, 1, 1) + shift2.view( | |||||
| -1, self.base_nf, 1, 1) + out | |||||
| out = self.act(out) | |||||
| out = self.conv3(out) | |||||
| out = out * scale3.view(-1, self.out_nc, 1, 1) + shift3.view( | |||||
| -1, self.out_nc, 1, 1) + out | |||||
| return out | |||||
| class L1Loss(nn.Module): | |||||
| """L1 (mean absolute error, MAE) loss. | |||||
| Args: | |||||
| loss_weight (float): Loss weight for L1 loss. Default: 1.0. | |||||
| reduction (str): Specifies the reduction to apply to the output. | |||||
| Supported choices are 'none' | 'mean' | 'sum'. Default: 'mean'. | |||||
| """ | |||||
| def __init__(self, loss_weight=1.0, reduction='mean'): | |||||
| super(L1Loss, self).__init__() | |||||
| if reduction not in ['none', 'mean', 'sum']: | |||||
| raise ValueError( | |||||
| f"Unsupported reduction mode: {reduction}. Supported ones are: 'none', 'mean' and 'sum'." | |||||
| ) | |||||
| self.loss_weight = loss_weight | |||||
| self.reduction = reduction | |||||
| def forward(self, pred, target, weight=None, **kwargs): | |||||
| """ | |||||
| Args: | |||||
| pred (Tensor): of shape (N, C, H, W). Predicted tensor. | |||||
| target (Tensor): of shape (N, C, H, W). Ground truth tensor. | |||||
| weight (Tensor, optional): of shape (N, C, H, W). Element-wise weights. Default: None. | |||||
| """ | |||||
| return self.loss_weight * F.l1_loss( | |||||
| pred, target, reduction=self.reduction) | |||||
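Reviewer note: a quick standalone smoke test for the network and loss above (not part of the PR; random tensors only, no pretrained weights):

import torch

net = CSRNet(in_nc=3, out_nc=3, base_nf=64, cond_nf=32)
loss_fn = L1Loss(loss_weight=1.0, reduction='mean')
lq = torch.rand(2, 3, 256, 256)   # "low quality" input batch
gt = torch.rand(2, 3, 256, 256)   # ground-truth batch
pred = net(lq)                    # 1x1 convs plus a global condition vector keep the spatial size
assert pred.shape == gt.shape
print(loss_fn(pred, gt))          # scalar L1 loss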
| @@ -0,0 +1,109 @@ | |||||
| import os.path as osp | |||||
| from copy import deepcopy | |||||
| from typing import Dict, Union | |||||
| import torch | |||||
| from torch.nn.parallel import DataParallel, DistributedDataParallel | |||||
| from modelscope.metainfo import Models | |||||
| from modelscope.models.base import Tensor, TorchModel | |||||
| from modelscope.models.builder import MODELS | |||||
| from modelscope.utils.constant import ModelFile, Tasks | |||||
| from modelscope.utils.logger import get_logger | |||||
| from .csrnet import CSRNet, L1Loss | |||||
| logger = get_logger() | |||||
| __all__ = ['ImageColorEnhance'] | |||||
| @MODELS.register_module(Tasks.image_color_enhance, module_name=Models.csrnet) | |||||
| class ImageColorEnhance(TorchModel): | |||||
| def __init__(self, model_dir: str, *args, **kwargs): | |||||
| """initialize the image color enhance model from the `model_dir` path. | |||||
| Args: | |||||
| model_dir (str): the model path. | |||||
| """ | |||||
| super().__init__(model_dir, *args, **kwargs) | |||||
| model_path = osp.join(model_dir, ModelFile.TORCH_MODEL_FILE) | |||||
| self.loss = L1Loss() | |||||
| self.model = CSRNet() | |||||
| if torch.cuda.is_available(): | |||||
| self._device = torch.device('cuda') | |||||
| else: | |||||
| self._device = torch.device('cpu') | |||||
| self.model = self.model.to(self._device) | |||||
| self.model = self.load_pretrained(self.model, model_path) | |||||
| if self.training: | |||||
| self.model.train() | |||||
| else: | |||||
| self.model.eval() | |||||
| def load_pretrained(self, net, load_path, strict=True, param_key='params'): | |||||
| if isinstance(net, (DataParallel, DistributedDataParallel)): | |||||
| net = net.module | |||||
| load_net = torch.load( | |||||
| load_path, map_location=lambda storage, loc: storage) | |||||
| if param_key is not None: | |||||
| if param_key not in load_net and 'params' in load_net: | |||||
| param_key = 'params' | |||||
| logger.info( | |||||
| f'Loading: {param_key} does not exist, use params.') | |||||
| if param_key in load_net: | |||||
| load_net = load_net[param_key] | |||||
| logger.info( | |||||
| f'Loading {net.__class__.__name__} model from {load_path}, with param key: [{param_key}].' | |||||
| ) | |||||
| # remove unnecessary 'module.' | |||||
| for k, v in deepcopy(load_net).items(): | |||||
| if k.startswith('module.'): | |||||
| load_net[k[7:]] = v | |||||
| load_net.pop(k) | |||||
| net.load_state_dict(load_net, strict=strict) | |||||
| logger.info('load model done.') | |||||
| return net | |||||
| def _evaluate_postprocess(self, src: Tensor, | |||||
| target: Tensor) -> Dict[str, list]: | |||||
| preds = self.model(src) | |||||
| preds = list(torch.split(preds, 1, 0)) | |||||
| targets = list(torch.split(target, 1, 0)) | |||||
| preds = [(pred.data * 255.).squeeze(0).type(torch.uint8).permute( | |||||
| 1, 2, 0).cpu().numpy() for pred in preds] | |||||
| targets = [(target.data * 255.).squeeze(0).type(torch.uint8).permute( | |||||
| 1, 2, 0).cpu().numpy() for target in targets] | |||||
| return {'pred': preds, 'target': targets} | |||||
| def _train_forward(self, src: Tensor, target: Tensor) -> Dict[str, Tensor]: | |||||
| preds = self.model(src) | |||||
| return {'loss': self.loss(preds, target)} | |||||
| def _inference_forward(self, src: Tensor) -> Dict[str, Tensor]: | |||||
| return {'outputs': self.model(src).clamp(0, 1)} | |||||
| def forward(self, input: Dict[str, | |||||
| Tensor]) -> Dict[str, Union[list, Tensor]]: | |||||
| """return the result by the model | |||||
| Args: | |||||
| input (Dict[str, Tensor]): the preprocessed data | |||||
| Returns: | |||||
| Dict[str, Union[list, Tensor]]: results | |||||
| """ | |||||
| for key, value in input.items(): | |||||
| input[key] = input[key].to(self._device) | |||||
| if self.training: | |||||
| return self._train_forward(**input) | |||||
| elif 'target' in input: | |||||
| return self._evaluate_postprocess(**input) | |||||
| else: | |||||
| return self._inference_forward(**input) | |||||
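Reviewer note: a rough sketch of how `forward` routes between training, evaluation and inference (assumes the `damo/cv_csrnet_image-color-enhance-models` snapshot can be downloaded, as in the trainer test below; `forward` is called directly to show the three return shapes):

import torch
from modelscope.hub.snapshot_download import snapshot_download
from modelscope.models.cv.image_color_enhance.image_color_enhance import ImageColorEnhance

model_dir = snapshot_download('damo/cv_csrnet_image-color-enhance-models')
model = ImageColorEnhance.from_pretrained(model_dir)
src = torch.rand(1, 3, 256, 256)
target = torch.rand(1, 3, 256, 256)

model.eval()
print(model.forward({'src': src}).keys())                     # {'outputs'}: enhanced image, clamped to [0, 1]
print(model.forward({'src': src, 'target': target}).keys())   # {'pred', 'target'}: uint8 HWC arrays for the metric
model.train()
print(model.forward({'src': src, 'target': target}).keys())   # {'loss'}: L1 loss used by the trainer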
| @@ -105,6 +105,12 @@ TASK_OUTPUTS = { | |||||
| # } | # } | ||||
| Tasks.video_embedding: [OutputKeys.VIDEO_EMBEDDING], | Tasks.video_embedding: [OutputKeys.VIDEO_EMBEDDING], | ||||
| # image_color_enhance result for a single sample | |||||
| # { | |||||
| # "output_img": np.ndarray with shape [height, width, 3], uint8 | |||||
| # } | |||||
| Tasks.image_color_enhance: [OutputKeys.OUTPUT_IMG], | |||||
| # ============ nlp tasks =================== | # ============ nlp tasks =================== | ||||
| # text classification result for single sample | # text classification result for single sample | ||||
| @@ -70,6 +70,8 @@ DEFAULT_MODEL_FOR_PIPELINE = { | |||||
| Tasks.text_to_image_synthesis: | Tasks.text_to_image_synthesis: | ||||
| (Pipelines.text_to_image_synthesis, | (Pipelines.text_to_image_synthesis, | ||||
| 'damo/cv_imagen_text-to-image-synthesis_tiny'), | 'damo/cv_imagen_text-to-image-synthesis_tiny'), | ||||
| Tasks.image_color_enhance: (Pipelines.image_color_enhance, | |||||
| 'damo/cv_csrnet_image-color-enhance-models'), | |||||
| Tasks.virtual_tryon: (Pipelines.virtual_tryon, | Tasks.virtual_tryon: (Pipelines.virtual_tryon, | ||||
| 'damo/cv_daflow_virtual-tryon_base'), | 'damo/cv_daflow_virtual-tryon_base'), | ||||
| Tasks.image_colorization: (Pipelines.image_colorization, | Tasks.image_colorization: (Pipelines.image_colorization, | ||||
| @@ -6,6 +6,7 @@ try: | |||||
| from .action_recognition_pipeline import ActionRecognitionPipeline | from .action_recognition_pipeline import ActionRecognitionPipeline | ||||
| from .animal_recog_pipeline import AnimalRecogPipeline | from .animal_recog_pipeline import AnimalRecogPipeline | ||||
| from .cmdssl_video_embedding_pipleline import CMDSSLVideoEmbeddingPipeline | from .cmdssl_video_embedding_pipleline import CMDSSLVideoEmbeddingPipeline | ||||
| from .image_color_enhance_pipeline import ImageColorEnhancePipeline | |||||
| from .virtual_tryon_pipeline import VirtualTryonPipeline | from .virtual_tryon_pipeline import VirtualTryonPipeline | ||||
| from .image_colorization_pipeline import ImageColorizationPipeline | from .image_colorization_pipeline import ImageColorizationPipeline | ||||
| from .image_super_resolution_pipeline import ImageSuperResolutionPipeline | from .image_super_resolution_pipeline import ImageSuperResolutionPipeline | ||||
| @@ -0,0 +1,74 @@ | |||||
| from typing import Any, Dict, Optional, Union | |||||
| import cv2 | |||||
| import numpy as np | |||||
| import torch | |||||
| from PIL import Image | |||||
| from torchvision import transforms | |||||
| from modelscope.metainfo import Pipelines | |||||
| from modelscope.models.base import Model | |||||
| from modelscope.models.cv.image_color_enhance.image_color_enhance import \ | |||||
| ImageColorEnhance | |||||
| from modelscope.outputs import OutputKeys | |||||
| from modelscope.pipelines.base import Input | |||||
| from modelscope.preprocessors import (ImageColorEnhanceFinetunePreprocessor, | |||||
| load_image) | |||||
| from modelscope.utils.constant import ModelFile, Tasks | |||||
| from modelscope.utils.logger import get_logger | |||||
| from ..base import Pipeline | |||||
| from ..builder import PIPELINES | |||||
| logger = get_logger() | |||||
| @PIPELINES.register_module( | |||||
| Tasks.image_color_enhance, module_name=Pipelines.image_color_enhance) | |||||
| class ImageColorEnhancePipeline(Pipeline): | |||||
| def __init__(self, | |||||
| model: Union[ImageColorEnhance, str], | |||||
| preprocessor: Optional[ | |||||
| ImageColorEnhanceFinetunePreprocessor] = None, | |||||
| **kwargs): | |||||
| """ | |||||
| use `model` and `preprocessor` to create an image color enhance pipeline for prediction | |||||
| Args: | |||||
| model: model id on modelscope hub, or an `ImageColorEnhance` instance. | |||||
| """ | |||||
| model = model if isinstance( | |||||
| model, ImageColorEnhance) else Model.from_pretrained(model) | |||||
| model.eval() | |||||
| super().__init__(model=model, preprocessor=preprocessor, **kwargs) | |||||
| if torch.cuda.is_available(): | |||||
| self._device = torch.device('cuda') | |||||
| else: | |||||
| self._device = torch.device('cpu') | |||||
| def preprocess(self, input: Input) -> Dict[str, Any]: | |||||
| if isinstance(input, str): | |||||
| img = load_image(input) | |||||
| elif isinstance(input, Image.Image): | |||||
| img = input.convert('RGB') | |||||
| elif isinstance(input, np.ndarray): | |||||
| if len(input.shape) == 2: | |||||
| input = cv2.cvtColor(input, cv2.COLOR_GRAY2RGB) | |||||
| img = Image.fromarray(input.astype('uint8')).convert('RGB') | |||||
| else: | |||||
| raise TypeError(f'input should be either str, PIL.Image.Image' | |||||
| f' or np.ndarray, but got {type(input)}') | |||||
| test_transforms = transforms.Compose([transforms.ToTensor()]) | |||||
| img = test_transforms(img) | |||||
| result = {'src': img.unsqueeze(0).to(self._device)} | |||||
| return result | |||||
| @torch.no_grad() | |||||
| def forward(self, input: Dict[str, Any]) -> Dict[str, Any]: | |||||
| return super().forward(input) | |||||
| def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: | |||||
| output_img = (inputs['outputs'].squeeze(0) * 255.).type( | |||||
| torch.uint8).cpu().permute(1, 2, 0).numpy() | |||||
| return {OutputKeys.OUTPUT_IMG: output_img} | |||||
| @@ -19,6 +19,7 @@ try: | |||||
| from .space.dialog_intent_prediction_preprocessor import * # noqa F403 | from .space.dialog_intent_prediction_preprocessor import * # noqa F403 | ||||
| from .space.dialog_modeling_preprocessor import * # noqa F403 | from .space.dialog_modeling_preprocessor import * # noqa F403 | ||||
| from .space.dialog_state_tracking_preprocessor import * # noqa F403 | from .space.dialog_state_tracking_preprocessor import * # noqa F403 | ||||
| from .image import ImageColorEnhanceFinetunePreprocessor | |||||
| except ModuleNotFoundError as e: | except ModuleNotFoundError as e: | ||||
| if str(e) == "No module named 'tensorflow'": | if str(e) == "No module named 'tensorflow'": | ||||
| print(TENSORFLOW_IMPORT_ERROR.format('tts')) | print(TENSORFLOW_IMPORT_ERROR.format('tts')) | ||||
| @@ -1,12 +1,15 @@ | |||||
| # Copyright (c) Alibaba, Inc. and its affiliates. | # Copyright (c) Alibaba, Inc. and its affiliates. | ||||
| import io | import io | ||||
| from typing import Dict, Union | |||||
| from typing import Any, Dict, Union | |||||
| import torch | |||||
| from PIL import Image, ImageOps | from PIL import Image, ImageOps | ||||
| from modelscope.fileio import File | from modelscope.fileio import File | ||||
| from modelscope.metainfo import Preprocessors | from modelscope.metainfo import Preprocessors | ||||
| from modelscope.utils.constant import Fields | from modelscope.utils.constant import Fields | ||||
| from modelscope.utils.type_assert import type_assert | |||||
| from .base import Preprocessor | |||||
| from .builder import PREPROCESSORS | from .builder import PREPROCESSORS | ||||
| @@ -66,3 +69,36 @@ def load_image(image_path_or_url: str) -> Image.Image: | |||||
| """ | """ | ||||
| loader = LoadImage() | loader = LoadImage() | ||||
| return loader(image_path_or_url)['img'] | return loader(image_path_or_url)['img'] | ||||
| @PREPROCESSORS.register_module( | |||||
| Fields.cv, module_name=Preprocessors.image_color_enhance_preprocessor) | |||||
| class ImageColorEnhanceFinetunePreprocessor(Preprocessor): | |||||
| def __init__(self, model_dir: str, *args, **kwargs): | |||||
| """preprocess the data from the `model_dir` path | |||||
| Args: | |||||
| model_dir (str): model path | |||||
| """ | |||||
| super().__init__(*args, **kwargs) | |||||
| self.model_dir: str = model_dir | |||||
| @type_assert(object, object) | |||||
| def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]: | |||||
| """process the raw input data | |||||
| Args: | |||||
| data (Dict[str, Any]): the input data, e.g. the 'src' and 'target' image tensors. | |||||
| It is returned unchanged, since no extra transform is required for this task. | |||||
| Returns: | |||||
| Dict[str, Any]: the preprocessed data | |||||
| """ | |||||
| return data | |||||
| @@ -27,6 +27,7 @@ class CVTasks(object): | |||||
| ocr_detection = 'ocr-detection' | ocr_detection = 'ocr-detection' | ||||
| action_recognition = 'action-recognition' | action_recognition = 'action-recognition' | ||||
| video_embedding = 'video-embedding' | video_embedding = 'video-embedding' | ||||
| image_color_enhance = 'image-color-enhance' | |||||
| virtual_tryon = 'virtual-tryon' | virtual_tryon = 'virtual-tryon' | ||||
| image_colorization = 'image-colorization' | image_colorization = 'image-colorization' | ||||
| face_image_generation = 'face-image-generation' | face_image_generation = 'face-image-generation' | ||||
| @@ -0,0 +1,42 @@ | |||||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||||
| import os | |||||
| import os.path as osp | |||||
| import unittest | |||||
| import cv2 | |||||
| from modelscope.outputs import OutputKeys | |||||
| from modelscope.pipelines import pipeline | |||||
| from modelscope.pipelines.base import Pipeline | |||||
| from modelscope.utils.constant import Tasks | |||||
| from modelscope.utils.test_utils import test_level | |||||
| class ImageColorEnhanceTest(unittest.TestCase): | |||||
| def setUp(self) -> None: | |||||
| self.model_id = 'damo/cv_csrnet_image-color-enhance-models' | |||||
| def pipeline_inference(self, pipeline: Pipeline, input_location: str): | |||||
| result = pipeline(input_location) | |||||
| if result is not None: | |||||
| cv2.imwrite('result.png', result[OutputKeys.OUTPUT_IMG][:, :, | |||||
| [2, 1, 0]]) | |||||
| print(f'Output written to {osp.abspath("result.png")}') | |||||
| @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | |||||
| def test_run_modelhub(self): | |||||
| img_color_enhance = pipeline( | |||||
| Tasks.image_color_enhance, model=self.model_id) | |||||
| self.pipeline_inference(img_color_enhance, | |||||
| 'data/test/images/image_color_enhance.png') | |||||
| @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') | |||||
| def test_run_modelhub_default_model(self): | |||||
| img_color_enhance = pipeline(Tasks.image_color_enhance) | |||||
| self.pipeline_inference(img_color_enhance, | |||||
| 'data/test/images/image_color_enhance.png') | |||||
| if __name__ == '__main__': | |||||
| unittest.main() | |||||
| @@ -0,0 +1,115 @@ | |||||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||||
| import os | |||||
| import os.path as osp | |||||
| import shutil | |||||
| import tempfile | |||||
| import unittest | |||||
| from typing import Callable, List, Optional, Tuple, Union | |||||
| import cv2 | |||||
| import torch | |||||
| from torch.utils import data as data | |||||
| from modelscope.hub.snapshot_download import snapshot_download | |||||
| from modelscope.models.cv.image_color_enhance.image_color_enhance import \ | |||||
| ImageColorEnhance | |||||
| from modelscope.trainers import build_trainer | |||||
| from modelscope.utils.constant import ModelFile | |||||
| from modelscope.utils.test_utils import test_level | |||||
| class TestImageColorEnhanceTrainer(unittest.TestCase): | |||||
| def setUp(self): | |||||
| print(('Testing %s.%s' % (type(self).__name__, self._testMethodName))) | |||||
| self.tmp_dir = tempfile.TemporaryDirectory().name | |||||
| if not os.path.exists(self.tmp_dir): | |||||
| os.makedirs(self.tmp_dir) | |||||
| self.model_id = 'damo/cv_csrnet_image-color-enhance-models' | |||||
| class PairedImageDataset(data.Dataset): | |||||
| def __init__(self, root): | |||||
| super(PairedImageDataset, self).__init__() | |||||
| gt_dir = osp.join(root, 'gt') | |||||
| lq_dir = osp.join(root, 'lq') | |||||
| self.gt_filelist = os.listdir(gt_dir) | |||||
| self.gt_filelist = sorted( | |||||
| self.gt_filelist, key=lambda x: int(x[:-4])) | |||||
| self.gt_filelist = [ | |||||
| osp.join(gt_dir, f) for f in self.gt_filelist | |||||
| ] | |||||
| self.lq_filelist = os.listdir(lq_dir) | |||||
| self.lq_filelist = sorted( | |||||
| self.lq_filelist, key=lambda x: int(x[:-4])) | |||||
| self.lq_filelist = [ | |||||
| osp.join(lq_dir, f) for f in self.lq_filelist | |||||
| ] | |||||
| def _img_to_tensor(self, img): | |||||
| return torch.from_numpy(img[:, :, [2, 1, 0]]).permute( | |||||
| 2, 0, 1).type(torch.float32) / 255. | |||||
| def __getitem__(self, index): | |||||
| lq = cv2.imread(self.lq_filelist[index]) | |||||
| gt = cv2.imread(self.gt_filelist[index]) | |||||
| lq = cv2.resize(lq, (256, 256), interpolation=cv2.INTER_CUBIC) | |||||
| gt = cv2.resize(gt, (256, 256), interpolation=cv2.INTER_CUBIC) | |||||
| return \ | |||||
| {'src': self._img_to_tensor(lq), 'target': self._img_to_tensor(gt)} | |||||
| def __len__(self): | |||||
| return len(self.gt_filelist) | |||||
| def to_torch_dataset(self, | |||||
| columns: Union[str, List[str]] = None, | |||||
| preprocessors: Union[Callable, | |||||
| List[Callable]] = None, | |||||
| **format_kwargs): | |||||
| return self | |||||
| self.dataset = PairedImageDataset( | |||||
| './data/test/images/image_color_enhance/') | |||||
| def tearDown(self): | |||||
| shutil.rmtree(self.tmp_dir, ignore_errors=True) | |||||
| super().tearDown() | |||||
| @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | |||||
| def test_trainer(self): | |||||
| kwargs = dict( | |||||
| model=self.model_id, | |||||
| train_dataset=self.dataset, | |||||
| eval_dataset=self.dataset, | |||||
| work_dir=self.tmp_dir) | |||||
| trainer = build_trainer(default_args=kwargs) | |||||
| trainer.train() | |||||
| results_files = os.listdir(self.tmp_dir) | |||||
| self.assertIn(f'{trainer.timestamp}.log.json', results_files) | |||||
| for i in range(3): | |||||
| self.assertIn(f'epoch_{i+1}.pth', results_files) | |||||
| @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') | |||||
| def test_trainer_with_model_and_args(self): | |||||
| cache_path = snapshot_download(self.model_id) | |||||
| model = ImageColorEnhance.from_pretrained(cache_path) | |||||
| kwargs = dict( | |||||
| cfg_file=os.path.join(cache_path, ModelFile.CONFIGURATION), | |||||
| model=model, | |||||
| train_dataset=self.dataset, | |||||
| eval_dataset=self.dataset, | |||||
| max_epochs=2, | |||||
| work_dir=self.tmp_dir) | |||||
| trainer = build_trainer(default_args=kwargs) | |||||
| trainer.train() | |||||
| results_files = os.listdir(self.tmp_dir) | |||||
| self.assertIn(f'{trainer.timestamp}.log.json', results_files) | |||||
| for i in range(2): | |||||
| self.assertIn(f'epoch_{i+1}.pth', results_files) | |||||
| if __name__ == '__main__': | |||||
| unittest.main() | |||||