From da20fb66e90d96f911402b3bec3d1cf4bf8e292f Mon Sep 17 00:00:00 2001 From: "wenqi.oywq" Date: Mon, 25 Jul 2022 13:30:11 +0800 Subject: [PATCH] =?UTF-8?q?[to=20#43259593]=E6=B7=BB=E5=8A=A0image-color-e?= =?UTF-8?q?nhance,=20pipeline=20and=20trainer?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 添加image-color-enhance, pipeline and trainer Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9483118 --- data/test/images/image_color_enhance.png | 3 + data/test/images/image_color_enhance/gt/1.png | 3 + data/test/images/image_color_enhance/gt/2.png | 3 + data/test/images/image_color_enhance/gt/3.png | 3 + data/test/images/image_color_enhance/gt/4.png | 3 + data/test/images/image_color_enhance/lq/1.png | 3 + data/test/images/image_color_enhance/lq/2.png | 3 + data/test/images/image_color_enhance/lq/3.png | 3 + data/test/images/image_color_enhance/lq/4.png | 3 + modelscope/metainfo.py | 5 + modelscope/metrics/__init__.py | 1 + modelscope/metrics/builder.py | 3 + .../metrics/image_color_enhance_metric.py | 258 ++++++++++++++++++ modelscope/models/cv/__init__.py | 2 + .../models/cv/image_color_enhance/__init__.py | 0 .../models/cv/image_color_enhance/csrnet.py | 110 ++++++++ .../image_color_enhance.py | 109 ++++++++ modelscope/outputs.py | 6 + modelscope/pipelines/builder.py | 2 + modelscope/pipelines/cv/__init__.py | 1 + .../cv/image_color_enhance_pipeline.py | 74 +++++ modelscope/preprocessors/__init__.py | 1 + modelscope/preprocessors/image.py | 38 ++- modelscope/utils/constant.py | 1 + tests/pipelines/test_image_color_enhance.py | 42 +++ .../test_image_color_enhance_trainer.py | 115 ++++++++ 26 files changed, 794 insertions(+), 1 deletion(-) create mode 100644 data/test/images/image_color_enhance.png create mode 100644 data/test/images/image_color_enhance/gt/1.png create mode 100644 data/test/images/image_color_enhance/gt/2.png create mode 100644 data/test/images/image_color_enhance/gt/3.png create mode 100644 data/test/images/image_color_enhance/gt/4.png create mode 100644 data/test/images/image_color_enhance/lq/1.png create mode 100644 data/test/images/image_color_enhance/lq/2.png create mode 100644 data/test/images/image_color_enhance/lq/3.png create mode 100644 data/test/images/image_color_enhance/lq/4.png create mode 100644 modelscope/metrics/image_color_enhance_metric.py create mode 100644 modelscope/models/cv/image_color_enhance/__init__.py create mode 100644 modelscope/models/cv/image_color_enhance/csrnet.py create mode 100644 modelscope/models/cv/image_color_enhance/image_color_enhance.py create mode 100644 modelscope/pipelines/cv/image_color_enhance_pipeline.py create mode 100644 tests/pipelines/test_image_color_enhance.py create mode 100644 tests/trainers/test_image_color_enhance_trainer.py diff --git a/data/test/images/image_color_enhance.png b/data/test/images/image_color_enhance.png new file mode 100644 index 00000000..ffb4d188 --- /dev/null +++ b/data/test/images/image_color_enhance.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d64493ea0643b30129eaedacb2db9ca233c2d9c0d69209ff6d464d3cae4b4a5b +size 950676 diff --git a/data/test/images/image_color_enhance/gt/1.png b/data/test/images/image_color_enhance/gt/1.png new file mode 100644 index 00000000..ffb4d188 --- /dev/null +++ b/data/test/images/image_color_enhance/gt/1.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d64493ea0643b30129eaedacb2db9ca233c2d9c0d69209ff6d464d3cae4b4a5b +size 950676 diff --git 
a/data/test/images/image_color_enhance/gt/2.png b/data/test/images/image_color_enhance/gt/2.png new file mode 100644 index 00000000..a84f2543 --- /dev/null +++ b/data/test/images/image_color_enhance/gt/2.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:0a4a8a60501976b2c5e753814a346519ef6faff052b53359cf44b4e597e62aaf +size 902214 diff --git a/data/test/images/image_color_enhance/gt/3.png b/data/test/images/image_color_enhance/gt/3.png new file mode 100644 index 00000000..dc04f4bc --- /dev/null +++ b/data/test/images/image_color_enhance/gt/3.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:326c5e3907926a4af6fec382050026d505d78aab8c5f2e0ecc85ac863abbb94c +size 856195 diff --git a/data/test/images/image_color_enhance/gt/4.png b/data/test/images/image_color_enhance/gt/4.png new file mode 100644 index 00000000..4e888582 --- /dev/null +++ b/data/test/images/image_color_enhance/gt/4.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:455f364c008be76a392085e7590b9050a628853a9df1e608a40c75a15bc41c5f +size 951993 diff --git a/data/test/images/image_color_enhance/lq/1.png b/data/test/images/image_color_enhance/lq/1.png new file mode 100644 index 00000000..a9641037 --- /dev/null +++ b/data/test/images/image_color_enhance/lq/1.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f806b26557317f856e7583fb128713579df3354016b368ef32791b283e3be051 +size 932493 diff --git a/data/test/images/image_color_enhance/lq/2.png b/data/test/images/image_color_enhance/lq/2.png new file mode 100644 index 00000000..79176bd1 --- /dev/null +++ b/data/test/images/image_color_enhance/lq/2.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:0ec66811ec4f1ec8735b7f0eb897100f80939ba5dc150028fa91bfcd15b5164c +size 896481 diff --git a/data/test/images/image_color_enhance/lq/3.png b/data/test/images/image_color_enhance/lq/3.png new file mode 100644 index 00000000..93f52409 --- /dev/null +++ b/data/test/images/image_color_enhance/lq/3.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d9517b185b0cffc0c830270fd52551e145054daa00c704ed4132589b24ab46e9 +size 828266 diff --git a/data/test/images/image_color_enhance/lq/4.png b/data/test/images/image_color_enhance/lq/4.png new file mode 100644 index 00000000..6a1f659a --- /dev/null +++ b/data/test/images/image_color_enhance/lq/4.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:5a233195949ed1c3db9c9a182baf3d8f014620d28bab823aa4d4cc203e602bc6 +size 927552 diff --git a/modelscope/metainfo.py b/modelscope/metainfo.py index 511b786d..345a49a6 100644 --- a/modelscope/metainfo.py +++ b/modelscope/metainfo.py @@ -10,6 +10,7 @@ class Models(object): Model name should only contain model info but not task info. 
""" # vision models + csrnet = 'csrnet' # nlp models bert = 'bert' @@ -60,6 +61,7 @@ class Pipelines(object): action_recognition = 'TAdaConv_action-recognition' animal_recognation = 'resnet101-animal_recog' cmdssl_video_embedding = 'cmdssl-r2p1d_video_embedding' + image_color_enhance = 'csrnet-image-color-enhance' virtual_tryon = 'virtual_tryon' image_colorization = 'unet-image-colorization' image_super_resolution = 'rrdb-image-super-resolution' @@ -121,6 +123,7 @@ class Preprocessors(object): # cv preprocessor load_image = 'load-image' + image_color_enhance_preprocessor = 'image-color-enhance-preprocessor' # nlp preprocessor sen_sim_tokenizer = 'sen-sim-tokenizer' @@ -160,3 +163,5 @@ class Metrics(object): token_cls_metric = 'token-cls-metric' # metrics for text-generation task text_gen_metric = 'text-gen-metric' + # metrics for image-color-enhance task + image_color_enhance_metric = 'image-color-enhance-metric' diff --git a/modelscope/metrics/__init__.py b/modelscope/metrics/__init__.py index 9a0ca94a..0c7dec95 100644 --- a/modelscope/metrics/__init__.py +++ b/modelscope/metrics/__init__.py @@ -1,4 +1,5 @@ from .base import Metric from .builder import METRICS, build_metric, task_default_metrics +from .image_color_enhance_metric import ImageColorEnhanceMetric from .sequence_classification_metric import SequenceClassificationMetric from .text_generation_metric import TextGenerationMetric diff --git a/modelscope/metrics/builder.py b/modelscope/metrics/builder.py index 2d60f9b9..6cb8be7d 100644 --- a/modelscope/metrics/builder.py +++ b/modelscope/metrics/builder.py @@ -13,12 +13,15 @@ class MetricKeys(object): F1 = 'f1' PRECISION = 'precision' RECALL = 'recall' + PSNR = 'psnr' + SSIM = 'ssim' task_default_metrics = { Tasks.sentence_similarity: [Metrics.seq_cls_metric], Tasks.sentiment_classification: [Metrics.seq_cls_metric], Tasks.text_generation: [Metrics.text_gen_metric], + Tasks.image_color_enhance: [Metrics.image_color_enhance_metric] } diff --git a/modelscope/metrics/image_color_enhance_metric.py b/modelscope/metrics/image_color_enhance_metric.py new file mode 100644 index 00000000..df6534c5 --- /dev/null +++ b/modelscope/metrics/image_color_enhance_metric.py @@ -0,0 +1,258 @@ +# The code is modified based on BasicSR metrics: +# https://github.com/XPixelGroup/BasicSR/tree/master/basicsr/metrics + +from typing import Dict + +import cv2 +import numpy as np + +from ..metainfo import Metrics +from ..utils.registry import default_group +from .base import Metric +from .builder import METRICS, MetricKeys + + +def bgr2ycbcr(img, y_only=False): + """Convert a BGR image to YCbCr image. + + The bgr version of rgb2ycbcr. + It implements the ITU-R BT.601 conversion for standard-definition + television. See more details in + https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion. + + It differs from a similar function in cv2.cvtColor: `BGR <-> YCrCb`. + In OpenCV, it implements a JPEG conversion. See more details in + https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion. + + Args: + img (ndarray): The input image. It accepts: + 1. np.uint8 type with range [0, 255]; + 2. np.float32 type with range [0, 1]. + y_only (bool): Whether to only return Y channel. Default: False. + + Returns: + ndarray: The converted YCbCr image. The output image has the same type + and range as input image. 
+ """ + img_type = img.dtype + img = _convert_input_type_range(img) + if y_only: + out_img = np.dot(img, [24.966, 128.553, 65.481]) + 16.0 + else: + out_img = np.matmul( + img, [[24.966, 112.0, -18.214], [128.553, -74.203, -93.786], + [65.481, -37.797, 112.0]]) + [16, 128, 128] + out_img = _convert_output_type_range(out_img, img_type) + return out_img + + +def reorder_image(img, input_order='HWC'): + """Reorder images to 'HWC' order. + + If the input_order is (h, w), return (h, w, 1); + If the input_order is (c, h, w), return (h, w, c); + If the input_order is (h, w, c), return as it is. + + Args: + img (ndarray): Input image. + input_order (str): Whether the input order is 'HWC' or 'CHW'. + If the input image shape is (h, w), input_order will not have + effects. Default: 'HWC'. + + Returns: + ndarray: reordered image. + """ + + if input_order not in ['HWC', 'CHW']: + raise ValueError( + f"Wrong input_order {input_order}. Supported input_orders are 'HWC' and 'CHW'" + ) + if len(img.shape) == 2: + img = img[..., None] + if input_order == 'CHW': + img = img.transpose(1, 2, 0) + return img + + +def to_y_channel(img): + """Change to Y channel of YCbCr. + + Args: + img (ndarray): Images with range [0, 255]. + + Returns: + (ndarray): Images with range [0, 255] (float type) without round. + """ + img = img.astype(np.float32) / 255. + if img.ndim == 3 and img.shape[2] == 3: + img = bgr2ycbcr(img, y_only=True) + img = img[..., None] + return img * 255. + + +def _ssim(img, img2): + """Calculate SSIM (structural similarity) for one channel images. + + It is called by func:`calculate_ssim`. + + Args: + img (ndarray): Images with range [0, 255] with order 'HWC'. + img2 (ndarray): Images with range [0, 255] with order 'HWC'. + + Returns: + float: SSIM result. + """ + + c1 = (0.01 * 255)**2 + c2 = (0.03 * 255)**2 + kernel = cv2.getGaussianKernel(11, 1.5) + window = np.outer(kernel, kernel.transpose()) + + mu1 = cv2.filter2D(img, -1, window)[5:-5, + 5:-5] # valid mode for window size 11 + mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5] + mu1_sq = mu1**2 + mu2_sq = mu2**2 + mu1_mu2 = mu1 * mu2 + sigma1_sq = cv2.filter2D(img**2, -1, window)[5:-5, 5:-5] - mu1_sq + sigma2_sq = cv2.filter2D(img2**2, -1, window)[5:-5, 5:-5] - mu2_sq + sigma12 = cv2.filter2D(img * img2, -1, window)[5:-5, 5:-5] - mu1_mu2 + + tmp1 = (2 * mu1_mu2 + c1) * (2 * sigma12 + c2) + tmp2 = (mu1_sq + mu2_sq + c1) * (sigma1_sq + sigma2_sq + c2) + ssim_map = tmp1 / tmp2 + + return ssim_map.mean() + + +def calculate_psnr(img, + img2, + crop_border, + input_order='HWC', + test_y_channel=False, + **kwargs): + """Calculate PSNR (Peak Signal-to-Noise Ratio). + + Ref: https://en.wikipedia.org/wiki/Peak_signal-to-noise_ratio + + Args: + img (ndarray): Images with range [0, 255]. + img2 (ndarray): Images with range [0, 255]. + crop_border (int): Cropped pixels in each edge of an image. These pixels are not involved in the calculation. + input_order (str): Whether the input order is 'HWC' or 'CHW'. Default: 'HWC'. + test_y_channel (bool): Test on Y channel of YCbCr. Default: False. + + Returns: + float: PSNR result. + """ + + assert img.shape == img2.shape, ( + f'Image shapes are different: {img.shape}, {img2.shape}.') + if input_order not in ['HWC', 'CHW']: + raise ValueError( + f'Wrong input_order {input_order}. 
Supported input_orders are "HWC" and "CHW"' + ) + img = reorder_image(img, input_order=input_order) + img2 = reorder_image(img2, input_order=input_order) + + if crop_border != 0: + img = img[crop_border:-crop_border, crop_border:-crop_border, ...] + img2 = img2[crop_border:-crop_border, crop_border:-crop_border, ...] + + if test_y_channel: + img = to_y_channel(img) + img2 = to_y_channel(img2) + + img = img.astype(np.float64) + img2 = img2.astype(np.float64) + + mse = np.mean((img - img2)**2) + if mse == 0: + return float('inf') + return 10. * np.log10(255. * 255. / mse) + + +def calculate_ssim(img, + img2, + crop_border, + input_order='HWC', + test_y_channel=False, + **kwargs): + """Calculate SSIM (structural similarity). + + Ref: + Image quality assessment: From error visibility to structural similarity + + The results are the same as that of the official released MATLAB code in + https://ece.uwaterloo.ca/~z70wang/research/ssim/. + + For three-channel images, SSIM is calculated for each channel and then + averaged. + + Args: + img (ndarray): Images with range [0, 255]. + img2 (ndarray): Images with range [0, 255]. + crop_border (int): Cropped pixels in each edge of an image. These pixels are not involved in the calculation. + input_order (str): Whether the input order is 'HWC' or 'CHW'. + Default: 'HWC'. + test_y_channel (bool): Test on Y channel of YCbCr. Default: False. + + Returns: + float: SSIM result. + """ + + assert img.shape == img2.shape, ( + f'Image shapes are different: {img.shape}, {img2.shape}.') + if input_order not in ['HWC', 'CHW']: + raise ValueError( + f'Wrong input_order {input_order}. Supported input_orders are "HWC" and "CHW"' + ) + img = reorder_image(img, input_order=input_order) + img2 = reorder_image(img2, input_order=input_order) + + if crop_border != 0: + img = img[crop_border:-crop_border, crop_border:-crop_border, ...] + img2 = img2[crop_border:-crop_border, crop_border:-crop_border, ...] + + if test_y_channel: + img = to_y_channel(img) + img2 = to_y_channel(img2) + + img = img.astype(np.float64) + img2 = img2.astype(np.float64) + + ssims = [] + for i in range(img.shape[2]): + ssims.append(_ssim(img[..., i], img2[..., i])) + return np.array(ssims).mean() + + +@METRICS.register_module( + group_key=default_group, module_name=Metrics.image_color_enhance_metric) +class ImageColorEnhanceMetric(Metric): + """The metric computation class for image color enhance classes. + """ + + def __init__(self): + self.preds = [] + self.targets = [] + + def add(self, outputs: Dict, inputs: Dict): + ground_truths = outputs['target'] + eval_results = outputs['pred'] + self.preds.extend(eval_results) + self.targets.extend(ground_truths) + + def evaluate(self): + psnrs = [ + calculate_psnr(pred, target, 2, test_y_channel=False) + for pred, target in zip(self.preds, self.targets) + ] + ssims = [ + calculate_ssim(pred, target, 2, test_y_channel=False) + for pred, target in zip(self.preds, self.targets) + ] + return { + MetricKeys.PSNR: sum(psnrs) / len(psnrs), + MetricKeys.SSIM: sum(ssims) / len(ssims) + } diff --git a/modelscope/models/cv/__init__.py b/modelscope/models/cv/__init__.py index e69de29b..e15152c3 100644 --- a/modelscope/models/cv/__init__.py +++ b/modelscope/models/cv/__init__.py @@ -0,0 +1,2 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. 
+from .image_color_enhance.image_color_enhance import ImageColorEnhance diff --git a/modelscope/models/cv/image_color_enhance/__init__.py b/modelscope/models/cv/image_color_enhance/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/modelscope/models/cv/image_color_enhance/csrnet.py b/modelscope/models/cv/image_color_enhance/csrnet.py new file mode 100644 index 00000000..782cd528 --- /dev/null +++ b/modelscope/models/cv/image_color_enhance/csrnet.py @@ -0,0 +1,110 @@ +import functools +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class Condition(nn.Module): + + def __init__(self, in_nc=3, nf=32): + super(Condition, self).__init__() + stride = 2 + pad = 0 + self.pad = nn.ZeroPad2d(1) + self.conv1 = nn.Conv2d(in_nc, nf, 7, stride, pad, bias=True) + self.conv2 = nn.Conv2d(nf, nf, 3, stride, pad, bias=True) + self.conv3 = nn.Conv2d(nf, nf, 3, stride, pad, bias=True) + self.act = nn.ReLU(inplace=True) + + def forward(self, x): + conv1_out = self.act(self.conv1(self.pad(x))) + conv2_out = self.act(self.conv2(self.pad(conv1_out))) + conv3_out = self.act(self.conv3(self.pad(conv2_out))) + out = torch.mean(conv3_out, dim=[2, 3], keepdim=False) + + return out + + +# 3 layers with control +class CSRNet(nn.Module): + + def __init__(self, in_nc=3, out_nc=3, base_nf=64, cond_nf=32): + super(CSRNet, self).__init__() + + self.base_nf = base_nf + self.out_nc = out_nc + + self.cond_net = Condition(in_nc=in_nc, nf=cond_nf) + + self.cond_scale1 = nn.Linear(cond_nf, base_nf, bias=True) + self.cond_scale2 = nn.Linear(cond_nf, base_nf, bias=True) + self.cond_scale3 = nn.Linear(cond_nf, 3, bias=True) + + self.cond_shift1 = nn.Linear(cond_nf, base_nf, bias=True) + self.cond_shift2 = nn.Linear(cond_nf, base_nf, bias=True) + self.cond_shift3 = nn.Linear(cond_nf, 3, bias=True) + + self.conv1 = nn.Conv2d(in_nc, base_nf, 1, 1, bias=True) + self.conv2 = nn.Conv2d(base_nf, base_nf, 1, 1, bias=True) + self.conv3 = nn.Conv2d(base_nf, out_nc, 1, 1, bias=True) + + self.act = nn.ReLU(inplace=True) + + def forward(self, x): + cond = self.cond_net(x) + + scale1 = self.cond_scale1(cond) + shift1 = self.cond_shift1(cond) + + scale2 = self.cond_scale2(cond) + shift2 = self.cond_shift2(cond) + + scale3 = self.cond_scale3(cond) + shift3 = self.cond_shift3(cond) + + out = self.conv1(x) + out = out * scale1.view(-1, self.base_nf, 1, 1) + shift1.view( + -1, self.base_nf, 1, 1) + out + out = self.act(out) + + out = self.conv2(out) + out = out * scale2.view(-1, self.base_nf, 1, 1) + shift2.view( + -1, self.base_nf, 1, 1) + out + out = self.act(out) + + out = self.conv3(out) + out = out * scale3.view(-1, self.out_nc, 1, 1) + shift3.view( + -1, self.out_nc, 1, 1) + out + return out + + +class L1Loss(nn.Module): + """L1 (mean absolute error, MAE) loss. + + Args: + loss_weight (float): Loss weight for L1 loss. Default: 1.0. + reduction (str): Specifies the reduction to apply to the output. + Supported choices are 'none' | 'mean' | 'sum'. Default: 'mean'. + """ + + def __init__(self, loss_weight=1.0, reduction='mean'): + super(L1Loss, self).__init__() + if reduction not in ['none', 'mean', 'sum']: + raise ValueError( + f"Unsupported reduction mode: {reduction}. Supported ones are: 'none', 'mean', 'sum'" + ) + + self.loss_weight = loss_weight + self.reduction = reduction + + def forward(self, pred, target, weight=None, **kwargs): + """ + Args: + pred (Tensor): of shape (N, C, H, W). Predicted tensor. + target (Tensor): of shape (N, C, H, W). Ground truth tensor.
+ weight (Tensor, optional): of shape (N, C, H, W). Element-wise weights. Default: None. + """ + return self.loss_weight * F.l1_loss( + pred, target, reduction=self.reduction) diff --git a/modelscope/models/cv/image_color_enhance/image_color_enhance.py b/modelscope/models/cv/image_color_enhance/image_color_enhance.py new file mode 100644 index 00000000..d142e682 --- /dev/null +++ b/modelscope/models/cv/image_color_enhance/image_color_enhance.py @@ -0,0 +1,109 @@ +import os.path as osp +from copy import deepcopy +from typing import Dict, Union + +import torch +from torch.nn.parallel import DataParallel, DistributedDataParallel + +from modelscope.metainfo import Models +from modelscope.models.base import Tensor, TorchModel +from modelscope.models.builder import MODELS +from modelscope.utils.constant import ModelFile, Tasks +from modelscope.utils.logger import get_logger +from .csrnet import CSRNet, L1Loss + +logger = get_logger() + +__all__ = ['ImageColorEnhance'] + + +@MODELS.register_module(Tasks.image_color_enhance, module_name=Models.csrnet) +class ImageColorEnhance(TorchModel): + + def __init__(self, model_dir: str, *args, **kwargs): + """initialize the image color enhance model from the `model_dir` path. + + Args: + model_dir (str): the model path. + """ + super().__init__(model_dir, *args, **kwargs) + + model_path = osp.join(model_dir, ModelFile.TORCH_MODEL_FILE) + + self.loss = L1Loss() + self.model = CSRNet() + if torch.cuda.is_available(): + self._device = torch.device('cuda') + else: + self._device = torch.device('cpu') + self.model = self.model.to(self._device) + + self.model = self.load_pretrained(self.model, model_path) + + if self.training: + self.model.train() + else: + self.model.eval() + + def load_pretrained(self, net, load_path, strict=True, param_key='params'): + if isinstance(net, (DataParallel, DistributedDataParallel)): + net = net.module + load_net = torch.load( + load_path, map_location=lambda storage, loc: storage) + if param_key is not None: + if param_key not in load_net and 'params' in load_net: + param_key = 'params' + logger.info( + f'Loading: {param_key} does not exist, use params.') + if param_key in load_net: + load_net = load_net[param_key] + logger.info( + f'Loading {net.__class__.__name__} model from {load_path}, with param key: [{param_key}].' + ) + # remove unnecessary 'module.' 
+ for k, v in deepcopy(load_net).items(): + if k.startswith('module.'): + load_net[k[7:]] = v + load_net.pop(k) + net.load_state_dict(load_net, strict=strict) + logger.info('load model done.') + return net + + def _evaluate_postprocess(self, src: Tensor, + target: Tensor) -> Dict[str, list]: + preds = self.model(src) + preds = list(torch.split(preds, 1, 0)) + targets = list(torch.split(target, 1, 0)) + + preds = [(pred.data * 255.).squeeze(0).type(torch.uint8).permute( + 1, 2, 0).cpu().numpy() for pred in preds] + targets = [(target.data * 255.).squeeze(0).type(torch.uint8).permute( + 1, 2, 0).cpu().numpy() for target in targets] + + return {'pred': preds, 'target': targets} + + def _train_forward(self, src: Tensor, target: Tensor) -> Dict[str, Tensor]: + preds = self.model(src) + return {'loss': self.loss(preds, target)} + + def _inference_forward(self, src: Tensor) -> Dict[str, Tensor]: + return {'outputs': self.model(src).clamp(0, 1)} + + def forward(self, input: Dict[str, + Tensor]) -> Dict[str, Union[list, Tensor]]: + """return the result by the model + + Args: + input (Dict[str, Tensor]): the preprocessed data + + Returns: + Dict[str, Union[list, Tensor]]: results + """ + for key, value in input.items(): + input[key] = input[key].to(self._device) + if self.training: + return self._train_forward(**input) + elif 'target' in input: + return self._evaluate_postprocess(**input) + else: + return self._inference_forward(**input) diff --git a/modelscope/outputs.py b/modelscope/outputs.py index 99463385..f3a40824 100644 --- a/modelscope/outputs.py +++ b/modelscope/outputs.py @@ -105,6 +105,12 @@ TASK_OUTPUTS = { # } Tasks.video_embedding: [OutputKeys.VIDEO_EMBEDDING], + # image_color_enhance result for a single sample + # { + # "output_img": np.ndarray with shape [height, width, 3], uint8 + # } + Tasks.image_color_enhance: [OutputKeys.OUTPUT_IMG], + # ============ nlp tasks =================== # text classification result for single sample diff --git a/modelscope/pipelines/builder.py b/modelscope/pipelines/builder.py index 3c23e15e..072a5a00 100644 --- a/modelscope/pipelines/builder.py +++ b/modelscope/pipelines/builder.py @@ -70,6 +70,8 @@ DEFAULT_MODEL_FOR_PIPELINE = { Tasks.text_to_image_synthesis: (Pipelines.text_to_image_synthesis, 'damo/cv_imagen_text-to-image-synthesis_tiny'), + Tasks.image_color_enhance: (Pipelines.image_color_enhance, + 'damo/cv_csrnet_image-color-enhance-models'), Tasks.virtual_tryon: (Pipelines.virtual_tryon, 'damo/cv_daflow_virtual-tryon_base'), Tasks.image_colorization: (Pipelines.image_colorization, diff --git a/modelscope/pipelines/cv/__init__.py b/modelscope/pipelines/cv/__init__.py index c1c1acdb..006cb92c 100644 --- a/modelscope/pipelines/cv/__init__.py +++ b/modelscope/pipelines/cv/__init__.py @@ -6,6 +6,7 @@ try: from .action_recognition_pipeline import ActionRecognitionPipeline from .animal_recog_pipeline import AnimalRecogPipeline from .cmdssl_video_embedding_pipleline import CMDSSLVideoEmbeddingPipeline + from .image_color_enhance_pipeline import ImageColorEnhancePipeline from .virtual_tryon_pipeline import VirtualTryonPipeline from .image_colorization_pipeline import ImageColorizationPipeline from .image_super_resolution_pipeline import ImageSuperResolutionPipeline diff --git a/modelscope/pipelines/cv/image_color_enhance_pipeline.py b/modelscope/pipelines/cv/image_color_enhance_pipeline.py new file mode 100644 index 00000000..506488f3 --- /dev/null +++ b/modelscope/pipelines/cv/image_color_enhance_pipeline.py @@ -0,0 +1,74 @@ +from typing import Any, 
Dict, Optional, Union + +import cv2 +import numpy as np +import torch +from PIL import Image +from torchvision import transforms + +from modelscope.metainfo import Pipelines +from modelscope.models.base import Model +from modelscope.models.cv.image_color_enhance.image_color_enhance import \ + ImageColorEnhance +from modelscope.outputs import OutputKeys +from modelscope.pipelines.base import Input +from modelscope.preprocessors import (ImageColorEnhanceFinetunePreprocessor, + load_image) +from modelscope.utils.constant import ModelFile, Tasks +from modelscope.utils.logger import get_logger +from ..base import Pipeline +from ..builder import PIPELINES + +logger = get_logger() + + +@PIPELINES.register_module( + Tasks.image_color_enhance, module_name=Pipelines.image_color_enhance) +class ImageColorEnhancePipeline(Pipeline): + + def __init__(self, + model: Union[ImageColorEnhance, str], + preprocessor: Optional[ + ImageColorEnhanceFinetunePreprocessor] = None, + **kwargs): + """ + use `model` and `preprocessor` to create an image color enhance pipeline for prediction + Args: + model: model id on modelscope hub. + """ + model = model if isinstance( + model, ImageColorEnhance) else Model.from_pretrained(model) + model.eval() + super().__init__(model=model, preprocessor=preprocessor, **kwargs) + + if torch.cuda.is_available(): + self._device = torch.device('cuda') + else: + self._device = torch.device('cpu') + + def preprocess(self, input: Input) -> Dict[str, Any]: + if isinstance(input, str): + img = load_image(input) + elif isinstance(input, Image.Image): + img = input.convert('RGB') + elif isinstance(input, np.ndarray): + if len(input.shape) == 2: + input = cv2.cvtColor(input, cv2.COLOR_GRAY2RGB) + img = Image.fromarray(input.astype('uint8')).convert('RGB') + else: + raise TypeError(f'input should be either str, PIL.Image,' + f' np.array, but got {type(input)}') + + test_transforms = transforms.Compose([transforms.ToTensor()]) + img = test_transforms(img) + result = {'src': img.unsqueeze(0).to(self._device)} + return result + + @torch.no_grad() + def forward(self, input: Dict[str, Any]) -> Dict[str, Any]: + return super().forward(input) + + def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: + output_img = (inputs['outputs'].squeeze(0) * 255.).type( + torch.uint8).cpu().permute(1, 2, 0).numpy() + return {OutputKeys.OUTPUT_IMG: output_img} diff --git a/modelscope/preprocessors/__init__.py b/modelscope/preprocessors/__init__.py index a2e3ee42..38b67276 100644 --- a/modelscope/preprocessors/__init__.py +++ b/modelscope/preprocessors/__init__.py @@ -19,6 +19,7 @@ try: from .space.dialog_intent_prediction_preprocessor import * # noqa F403 from .space.dialog_modeling_preprocessor import * # noqa F403 from .space.dialog_state_tracking_preprocessor import * # noqa F403 + from .image import ImageColorEnhanceFinetunePreprocessor except ModuleNotFoundError as e: if str(e) == "No module named 'tensorflow'": print(TENSORFLOW_IMPORT_ERROR.format('tts')) diff --git a/modelscope/preprocessors/image.py b/modelscope/preprocessors/image.py index fcad3c0e..85afb5b8 100644 --- a/modelscope/preprocessors/image.py +++ b/modelscope/preprocessors/image.py @@ -1,12 +1,15 @@ # Copyright (c) Alibaba, Inc. and its affiliates.
import io -from typing import Dict, Union +from typing import Any, Dict, Union +import torch from PIL import Image, ImageOps from modelscope.fileio import File from modelscope.metainfo import Preprocessors from modelscope.utils.constant import Fields +from modelscope.utils.type_assert import type_assert +from .base import Preprocessor from .builder import PREPROCESSORS @@ -66,3 +69,36 @@ def load_image(image_path_or_url: str) -> Image.Image: """ loader = LoadImage() return loader(image_path_or_url)['img'] + + +@PREPROCESSORS.register_module( + Fields.cv, module_name=Preprocessors.image_color_enhance_preprocessor) +class ImageColorEnhanceFinetunePreprocessor(Preprocessor): + + def __init__(self, model_dir: str, *args, **kwargs): + """preprocess the data from the `model_dir` path + + Args: + model_dir (str): model path + """ + + super().__init__(*args, **kwargs) + self.model_dir: str = model_dir + + @type_assert(object, object) + def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]: + """process the raw input data + + Args: + data (Dict[str, Any]): a dict of image tensors produced by the + dataset, e.g. the low-quality input under 'src' and, for + training or evaluation, the ground truth under 'target'. + Example: + {'src': <Tensor>, 'target': <Tensor>} + The current implementation returns the data unchanged. + + Returns: + Dict[str, Any]: the preprocessed data + """ + + return data diff --git a/modelscope/utils/constant.py b/modelscope/utils/constant.py index 977160d9..4adb48f0 100644 --- a/modelscope/utils/constant.py +++ b/modelscope/utils/constant.py @@ -27,6 +27,7 @@ class CVTasks(object): ocr_detection = 'ocr-detection' action_recognition = 'action-recognition' video_embedding = 'video-embedding' + image_color_enhance = 'image-color-enhance' virtual_tryon = 'virtual-tryon' image_colorization = 'image-colorization' face_image_generation = 'face-image-generation' diff --git a/tests/pipelines/test_image_color_enhance.py b/tests/pipelines/test_image_color_enhance.py new file mode 100644 index 00000000..ae22d65e --- /dev/null +++ b/tests/pipelines/test_image_color_enhance.py @@ -0,0 +1,42 @@ +# Copyright (c) Alibaba, Inc. and its affiliates.
+import os +import os.path as osp +import unittest + +import cv2 + +from modelscope.outputs import OutputKeys +from modelscope.pipelines import pipeline +from modelscope.pipelines.base import Pipeline +from modelscope.utils.constant import Tasks +from modelscope.utils.test_utils import test_level + + +class ImageColorEnhanceTest(unittest.TestCase): + + def setUp(self) -> None: + self.model_id = 'damo/cv_csrnet_image-color-enhance-models' + + def pipeline_inference(self, pipeline: Pipeline, input_location: str): + result = pipeline(input_location) + if result is not None: + cv2.imwrite('result.png', result[OutputKeys.OUTPUT_IMG][:, :, + [2, 1, 0]]) + print(f'Output written to {osp.abspath("result.png")}') + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_run_modelhub(self): + img_color_enhance = pipeline( + Tasks.image_color_enhance, model=self.model_id) + self.pipeline_inference(img_color_enhance, + 'data/test/images/image_color_enhance.png') + + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') + def test_run_modelhub_default_model(self): + img_color_enhance = pipeline(Tasks.image_color_enhance) + self.pipeline_inference(img_color_enhance, + 'data/test/images/image_color_enhance.png') + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/trainers/test_image_color_enhance_trainer.py b/tests/trainers/test_image_color_enhance_trainer.py new file mode 100644 index 00000000..d44b3cfd --- /dev/null +++ b/tests/trainers/test_image_color_enhance_trainer.py @@ -0,0 +1,115 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import os +import os.path as osp +import shutil +import tempfile +import unittest +from typing import Callable, List, Optional, Tuple, Union + +import cv2 +import torch +from torch.utils import data as data + +from modelscope.hub.snapshot_download import snapshot_download +from modelscope.models.cv.image_color_enhance.image_color_enhance import \ + ImageColorEnhance +from modelscope.trainers import build_trainer +from modelscope.utils.constant import ModelFile +from modelscope.utils.test_utils import test_level + + +class TestImageColorEnhanceTrainer(unittest.TestCase): + + def setUp(self): + print(('Testing %s.%s' % (type(self).__name__, self._testMethodName))) + self.tmp_dir = tempfile.TemporaryDirectory().name + if not os.path.exists(self.tmp_dir): + os.makedirs(self.tmp_dir) + + self.model_id = 'damo/cv_csrnet_image-color-enhance-models' + + class PairedImageDataset(data.Dataset): + + def __init__(self, root): + super(PairedImageDataset, self).__init__() + gt_dir = osp.join(root, 'gt') + lq_dir = osp.join(root, 'lq') + self.gt_filelist = os.listdir(gt_dir) + self.gt_filelist = sorted( + self.gt_filelist, key=lambda x: int(x[:-4])) + self.gt_filelist = [ + osp.join(gt_dir, f) for f in self.gt_filelist + ] + self.lq_filelist = os.listdir(lq_dir) + self.lq_filelist = sorted( + self.lq_filelist, key=lambda x: int(x[:-4])) + self.lq_filelist = [ + osp.join(lq_dir, f) for f in self.lq_filelist + ] + + def _img_to_tensor(self, img): + return torch.from_numpy(img[:, :, [2, 1, 0]]).permute( + 2, 0, 1).type(torch.float32) / 255. 
+ + def __getitem__(self, index): + lq = cv2.imread(self.lq_filelist[index]) + gt = cv2.imread(self.gt_filelist[index]) + lq = cv2.resize(lq, (256, 256), interpolation=cv2.INTER_CUBIC) + gt = cv2.resize(gt, (256, 256), interpolation=cv2.INTER_CUBIC) + return \ + {'src': self._img_to_tensor(lq), 'target': self._img_to_tensor(gt)} + + def __len__(self): + return len(self.gt_filelist) + + def to_torch_dataset(self, + columns: Union[str, List[str]] = None, + preprocessors: Union[Callable, + List[Callable]] = None, + **format_kwargs): + return self + + self.dataset = PairedImageDataset( + './data/test/images/image_color_enhance/') + + def tearDown(self): + shutil.rmtree(self.tmp_dir, ignore_errors=True) + super().tearDown() + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_trainer(self): + kwargs = dict( + model=self.model_id, + train_dataset=self.dataset, + eval_dataset=self.dataset, + work_dir=self.tmp_dir) + + trainer = build_trainer(default_args=kwargs) + trainer.train() + results_files = os.listdir(self.tmp_dir) + self.assertIn(f'{trainer.timestamp}.log.json', results_files) + for i in range(3): + self.assertIn(f'epoch_{i+1}.pth', results_files) + + @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') + def test_trainer_with_model_and_args(self): + cache_path = snapshot_download(self.model_id) + model = ImageColorEnhance.from_pretrained(cache_path) + kwargs = dict( + cfg_file=os.path.join(cache_path, ModelFile.CONFIGURATION), + model=model, + train_dataset=self.dataset, + eval_dataset=self.dataset, + max_epochs=2, + work_dir=self.tmp_dir) + + trainer = build_trainer(default_args=kwargs) + trainer.train() + results_files = os.listdir(self.tmp_dir) + self.assertIn(f'{trainer.timestamp}.log.json', results_files) + for i in range(2): + self.assertIn(f'epoch_{i+1}.pth', results_files) + + +if __name__ == '__main__': + unittest.main()
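
Reviewer note: a minimal usage sketch of how the new task is expected to be called once this patch is applied, distilled from tests/pipelines/test_image_color_enhance.py above. The model id and the test image path come from the patch itself; the output file name is arbitrary, and access to 'damo/cv_csrnet_image-color-enhance-models' on the model hub is assumed.

import cv2

from modelscope.outputs import OutputKeys
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

# Build the pipeline; omitting `model` falls back to the default model
# registered for the task in modelscope/pipelines/builder.py.
img_color_enhance = pipeline(
    Tasks.image_color_enhance,
    model='damo/cv_csrnet_image-color-enhance-models')

# Accepts an image path, PIL.Image or np.ndarray (see
# ImageColorEnhancePipeline.preprocess); returns an RGB uint8 array of
# shape [height, width, 3] under OutputKeys.OUTPUT_IMG.
result = img_color_enhance('data/test/images/image_color_enhance.png')
cv2.imwrite('enhanced.png', result[OutputKeys.OUTPUT_IMG][:, :, [2, 1, 0]])

Evaluation through the trainer reports PSNR and SSIM via the new image-color-enhance-metric, as exercised in tests/trainers/test_image_color_enhance_trainer.py.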