
[to #43259593] Add image-color-enhance model, pipeline and trainer

Add image-color-enhance model, pipeline and trainer
        Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9483118
master
wenqi.oywq yingda.chen 3 years ago
parent
commit
da20fb66e9
26 changed files with 794 additions and 1 deletion
  1. +3 -0    data/test/images/image_color_enhance.png
  2. +3 -0    data/test/images/image_color_enhance/gt/1.png
  3. +3 -0    data/test/images/image_color_enhance/gt/2.png
  4. +3 -0    data/test/images/image_color_enhance/gt/3.png
  5. +3 -0    data/test/images/image_color_enhance/gt/4.png
  6. +3 -0    data/test/images/image_color_enhance/lq/1.png
  7. +3 -0    data/test/images/image_color_enhance/lq/2.png
  8. +3 -0    data/test/images/image_color_enhance/lq/3.png
  9. +3 -0    data/test/images/image_color_enhance/lq/4.png
 10. +5 -0    modelscope/metainfo.py
 11. +1 -0    modelscope/metrics/__init__.py
 12. +3 -0    modelscope/metrics/builder.py
 13. +258 -0  modelscope/metrics/image_color_enhance_metric.py
 14. +2 -0    modelscope/models/cv/__init__.py
 15. +0 -0    modelscope/models/cv/image_color_enhance/__init__.py
 16. +110 -0  modelscope/models/cv/image_color_enhance/csrnet.py
 17. +109 -0  modelscope/models/cv/image_color_enhance/image_color_enhance.py
 18. +6 -0    modelscope/outputs.py
 19. +2 -0    modelscope/pipelines/builder.py
 20. +1 -0    modelscope/pipelines/cv/__init__.py
 21. +74 -0   modelscope/pipelines/cv/image_color_enhance_pipeline.py
 22. +1 -0    modelscope/preprocessors/__init__.py
 23. +37 -1   modelscope/preprocessors/image.py
 24. +1 -0    modelscope/utils/constant.py
 25. +42 -0   tests/pipelines/test_image_color_enhance.py
 26. +115 -0  tests/trainers/test_image_color_enhance_trainer.py
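
For orientation, a minimal usage sketch of the pipeline registered by this commit; the model id, test image path and channel handling are taken from tests/pipelines/test_image_color_enhance.py below, so nothing here is invented beyond the output file name:

import cv2

from modelscope.outputs import OutputKeys
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

# Build the color-enhancement pipeline from the hub model registered in builder.py.
enhancer = pipeline(Tasks.image_color_enhance,
                    model='damo/cv_csrnet_image-color-enhance-models')
result = enhancer('data/test/images/image_color_enhance.png')
# OUTPUT_IMG is an HWC uint8 RGB array; reverse the channels to BGR for OpenCV.
cv2.imwrite('result.png', result[OutputKeys.OUTPUT_IMG][:, :, ::-1])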

+3 -0  data/test/images/image_color_enhance.png

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:d64493ea0643b30129eaedacb2db9ca233c2d9c0d69209ff6d464d3cae4b4a5b
size 950676

+3 -0  data/test/images/image_color_enhance/gt/1.png

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:d64493ea0643b30129eaedacb2db9ca233c2d9c0d69209ff6d464d3cae4b4a5b
size 950676

+3 -0  data/test/images/image_color_enhance/gt/2.png

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:0a4a8a60501976b2c5e753814a346519ef6faff052b53359cf44b4e597e62aaf
size 902214

+3 -0  data/test/images/image_color_enhance/gt/3.png

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:326c5e3907926a4af6fec382050026d505d78aab8c5f2e0ecc85ac863abbb94c
size 856195

+3 -0  data/test/images/image_color_enhance/gt/4.png

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:455f364c008be76a392085e7590b9050a628853a9df1e608a40c75a15bc41c5f
size 951993

+3 -0  data/test/images/image_color_enhance/lq/1.png

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:f806b26557317f856e7583fb128713579df3354016b368ef32791b283e3be051
size 932493

+3 -0  data/test/images/image_color_enhance/lq/2.png

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:0ec66811ec4f1ec8735b7f0eb897100f80939ba5dc150028fa91bfcd15b5164c
size 896481

+3 -0  data/test/images/image_color_enhance/lq/3.png

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:d9517b185b0cffc0c830270fd52551e145054daa00c704ed4132589b24ab46e9
size 828266

+3 -0  data/test/images/image_color_enhance/lq/4.png

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:5a233195949ed1c3db9c9a182baf3d8f014620d28bab823aa4d4cc203e602bc6
size 927552

+5 -0  modelscope/metainfo.py

@@ -10,6 +10,7 @@ class Models(object):
    Model name should only contain model info but not task info.
    """
    # vision models
    csrnet = 'csrnet'

    # nlp models
    bert = 'bert'
@@ -60,6 +61,7 @@ class Pipelines(object):
    action_recognition = 'TAdaConv_action-recognition'
    animal_recognation = 'resnet101-animal_recog'
    cmdssl_video_embedding = 'cmdssl-r2p1d_video_embedding'
    image_color_enhance = 'csrnet-image-color-enhance'
    virtual_tryon = 'virtual_tryon'
    image_colorization = 'unet-image-colorization'
    image_super_resolution = 'rrdb-image-super-resolution'
@@ -121,6 +123,7 @@ class Preprocessors(object):

    # cv preprocessor
    load_image = 'load-image'
    image_color_enhance_preprocessor = 'image-color-enhance-preprocessor'

    # nlp preprocessor
    sen_sim_tokenizer = 'sen-sim-tokenizer'
@@ -160,3 +163,5 @@ class Metrics(object):
    token_cls_metric = 'token-cls-metric'
    # metrics for text-generation task
    text_gen_metric = 'text-gen-metric'
    # metrics for image-color-enhance task
    image_color_enhance_metric = 'image-color-enhance-metric'

+1 -0  modelscope/metrics/__init__.py

@@ -1,4 +1,5 @@
from .base import Metric
from .builder import METRICS, build_metric, task_default_metrics
from .image_color_enhance_metric import ImageColorEnhanceMetric
from .sequence_classification_metric import SequenceClassificationMetric
from .text_generation_metric import TextGenerationMetric

+3 -0  modelscope/metrics/builder.py

@@ -13,12 +13,15 @@ class MetricKeys(object):
    F1 = 'f1'
    PRECISION = 'precision'
    RECALL = 'recall'
    PSNR = 'psnr'
    SSIM = 'ssim'


task_default_metrics = {
    Tasks.sentence_similarity: [Metrics.seq_cls_metric],
    Tasks.sentiment_classification: [Metrics.seq_cls_metric],
    Tasks.text_generation: [Metrics.text_gen_metric],
    Tasks.image_color_enhance: [Metrics.image_color_enhance_metric]
}






+258 -0  modelscope/metrics/image_color_enhance_metric.py

@@ -0,0 +1,258 @@
# The code is modified based on BasicSR metrics:
# https://github.com/XPixelGroup/BasicSR/tree/master/basicsr/metrics

from typing import Dict

import cv2
import numpy as np

from ..metainfo import Metrics
from ..utils.registry import default_group
from .base import Metric
from .builder import METRICS, MetricKeys


def _convert_input_type_range(img):
    """Convert the input image to np.float32 with range [0, 1].

    Range-conversion helper from BasicSR, required by :func:`bgr2ycbcr`.
    """
    img_type = img.dtype
    img = img.astype(np.float32)
    if img_type == np.uint8:
        img /= 255.
    elif img_type != np.float32:
        raise TypeError('The img type should be np.float32 or np.uint8, '
                        f'but got {img_type}')
    return img


def _convert_output_type_range(img, dst_type):
    """Convert the image back to `dst_type` with the corresponding range.

    Range-conversion helper from BasicSR, required by :func:`bgr2ycbcr`.
    """
    if dst_type not in (np.uint8, np.float32):
        raise TypeError('The dst_type should be np.float32 or np.uint8, '
                        f'but got {dst_type}')
    if dst_type == np.uint8:
        img = img.round()
    else:
        img /= 255.
    return img.astype(dst_type)

def bgr2ycbcr(img, y_only=False):
"""Convert a BGR image to YCbCr image.

The bgr version of rgb2ycbcr.
It implements the ITU-R BT.601 conversion for standard-definition
television. See more details in
https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion.

It differs from a similar function in cv2.cvtColor: `BGR <-> YCrCb`.
In OpenCV, it implements a JPEG conversion. See more details in
https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion.

Args:
img (ndarray): The input image. It accepts:
1. np.uint8 type with range [0, 255];
2. np.float32 type with range [0, 1].
y_only (bool): Whether to only return Y channel. Default: False.

Returns:
ndarray: The converted YCbCr image. The output image has the same type
and range as input image.
"""
img_type = img.dtype
img = _convert_input_type_range(img)
if y_only:
out_img = np.dot(img, [24.966, 128.553, 65.481]) + 16.0
else:
out_img = np.matmul(
img, [[24.966, 112.0, -18.214], [128.553, -74.203, -93.786],
[65.481, -37.797, 112.0]]) + [16, 128, 128]
out_img = _convert_output_type_range(out_img, img_type)
return out_img


def reorder_image(img, input_order='HWC'):
"""Reorder images to 'HWC' order.

If the input_order is (h, w), return (h, w, 1);
If the input_order is (c, h, w), return (h, w, c);
If the input_order is (h, w, c), return as it is.

Args:
img (ndarray): Input image.
input_order (str): Whether the input order is 'HWC' or 'CHW'.
If the input image shape is (h, w), input_order will not have
effects. Default: 'HWC'.

Returns:
ndarray: reordered image.
"""

if input_order not in ['HWC', 'CHW']:
raise ValueError(
f"Wrong input_order {input_order}. Supported input_orders are 'HWC' and 'CHW'"
)
if len(img.shape) == 2:
img = img[..., None]
if input_order == 'CHW':
img = img.transpose(1, 2, 0)
return img


def to_y_channel(img):
"""Change to Y channel of YCbCr.

Args:
img (ndarray): Images with range [0, 255].

Returns:
(ndarray): Images with range [0, 255] (float type) without round.
"""
img = img.astype(np.float32) / 255.
if img.ndim == 3 and img.shape[2] == 3:
img = bgr2ycbcr(img, y_only=True)
img = img[..., None]
return img * 255.


def _ssim(img, img2):
"""Calculate SSIM (structural similarity) for one channel images.

It is called by func:`calculate_ssim`.

Args:
img (ndarray): Images with range [0, 255] with order 'HWC'.
img2 (ndarray): Images with range [0, 255] with order 'HWC'.

Returns:
float: SSIM result.
"""

c1 = (0.01 * 255)**2
c2 = (0.03 * 255)**2
kernel = cv2.getGaussianKernel(11, 1.5)
window = np.outer(kernel, kernel.transpose())

mu1 = cv2.filter2D(img, -1, window)[5:-5,
5:-5] # valid mode for window size 11
mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5]
mu1_sq = mu1**2
mu2_sq = mu2**2
mu1_mu2 = mu1 * mu2
sigma1_sq = cv2.filter2D(img**2, -1, window)[5:-5, 5:-5] - mu1_sq
sigma2_sq = cv2.filter2D(img2**2, -1, window)[5:-5, 5:-5] - mu2_sq
sigma12 = cv2.filter2D(img * img2, -1, window)[5:-5, 5:-5] - mu1_mu2

tmp1 = (2 * mu1_mu2 + c1) * (2 * sigma12 + c2)
tmp2 = (mu1_sq + mu2_sq + c1) * (sigma1_sq + sigma2_sq + c2)
ssim_map = tmp1 / tmp2

return ssim_map.mean()


def calculate_psnr(img,
img2,
crop_border,
input_order='HWC',
test_y_channel=False,
**kwargs):
"""Calculate PSNR (Peak Signal-to-Noise Ratio).

Ref: https://en.wikipedia.org/wiki/Peak_signal-to-noise_ratio

Args:
img (ndarray): Images with range [0, 255].
img2 (ndarray): Images with range [0, 255].
crop_border (int): Cropped pixels in each edge of an image. These pixels are not involved in the calculation.
input_order (str): Whether the input order is 'HWC' or 'CHW'. Default: 'HWC'.
test_y_channel (bool): Test on Y channel of YCbCr. Default: False.

Returns:
float: PSNR result.
"""

assert img.shape == img2.shape, (
f'Image shapes are different: {img.shape}, {img2.shape}.')
if input_order not in ['HWC', 'CHW']:
raise ValueError(
f'Wrong input_order {input_order}. Supported input_orders are "HWC" and "CHW"'
)
img = reorder_image(img, input_order=input_order)
img2 = reorder_image(img2, input_order=input_order)

if crop_border != 0:
img = img[crop_border:-crop_border, crop_border:-crop_border, ...]
img2 = img2[crop_border:-crop_border, crop_border:-crop_border, ...]

if test_y_channel:
img = to_y_channel(img)
img2 = to_y_channel(img2)

img = img.astype(np.float64)
img2 = img2.astype(np.float64)

mse = np.mean((img - img2)**2)
if mse == 0:
return float('inf')
return 10. * np.log10(255. * 255. / mse)


def calculate_ssim(img,
img2,
crop_border,
input_order='HWC',
test_y_channel=False,
**kwargs):
"""Calculate SSIM (structural similarity).

Ref:
Image quality assessment: From error visibility to structural similarity

The results are the same as that of the official released MATLAB code in
https://ece.uwaterloo.ca/~z70wang/research/ssim/.

For three-channel images, SSIM is calculated for each channel and then
averaged.

Args:
img (ndarray): Images with range [0, 255].
img2 (ndarray): Images with range [0, 255].
crop_border (int): Cropped pixels in each edge of an image. These pixels are not involved in the calculation.
input_order (str): Whether the input order is 'HWC' or 'CHW'.
Default: 'HWC'.
test_y_channel (bool): Test on Y channel of YCbCr. Default: False.

Returns:
float: SSIM result.
"""

assert img.shape == img2.shape, (
f'Image shapes are different: {img.shape}, {img2.shape}.')
if input_order not in ['HWC', 'CHW']:
raise ValueError(
f'Wrong input_order {input_order}. Supported input_orders are "HWC" and "CHW"'
)
img = reorder_image(img, input_order=input_order)
img2 = reorder_image(img2, input_order=input_order)

if crop_border != 0:
img = img[crop_border:-crop_border, crop_border:-crop_border, ...]
img2 = img2[crop_border:-crop_border, crop_border:-crop_border, ...]

if test_y_channel:
img = to_y_channel(img)
img2 = to_y_channel(img2)

img = img.astype(np.float64)
img2 = img2.astype(np.float64)

ssims = []
for i in range(img.shape[2]):
ssims.append(_ssim(img[..., i], img2[..., i]))
return np.array(ssims).mean()


@METRICS.register_module(
group_key=default_group, module_name=Metrics.image_color_enhance_metric)
class ImageColorEnhanceMetric(Metric):
"""The metric computation class for image color enhance classes.
"""

def __init__(self):
self.preds = []
self.targets = []

def add(self, outputs: Dict, inputs: Dict):
ground_truths = outputs['target']
eval_results = outputs['pred']
self.preds.extend(eval_results)
self.targets.extend(ground_truths)

def evaluate(self):
psnrs = [
calculate_psnr(pred, target, 2, test_y_channel=False)
for pred, target in zip(self.preds, self.targets)
]
ssims = [
calculate_ssim(pred, target, 2, test_y_channel=False)
for pred, target in zip(self.preds, self.targets)
]
return {
MetricKeys.PSNR: sum(psnrs) / len(psnrs),
MetricKeys.SSIM: sum(ssims) / len(ssims)
}
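
As a sanity check on the metric above, a small sketch of how it consumes the 'pred'/'target' lists produced by ImageColorEnhance._evaluate_postprocess; the constant test arrays and expected values are illustrative only (PSNR for a uniform difference of 2 is 10 * log10(255^2 / 4) ≈ 42.1 dB):

import numpy as np

from modelscope.metrics import ImageColorEnhanceMetric

metric = ImageColorEnhanceMetric()
# Predictions and targets are HWC uint8 arrays in [0, 255].
pred = np.full((64, 64, 3), 128, dtype=np.uint8)
target = np.full((64, 64, 3), 130, dtype=np.uint8)
metric.add(outputs={'pred': [pred], 'target': [target]}, inputs={})
print(metric.evaluate())  # PSNR ~= 42.1 dB, SSIM close to 1.0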

+2 -0  modelscope/models/cv/__init__.py

@@ -0,0 +1,2 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from .image_color_enhance.image_color_enhance import ImageColorEnhance

+0 -0  modelscope/models/cv/image_color_enhance/__init__.py


+110 -0  modelscope/models/cv/image_color_enhance/csrnet.py

@@ -0,0 +1,110 @@
import functools
import math

import torch
import torch.nn as nn
import torch.nn.functional as F


class Condition(nn.Module):

def __init__(self, in_nc=3, nf=32):
super(Condition, self).__init__()
stride = 2
pad = 0
self.pad = nn.ZeroPad2d(1)
self.conv1 = nn.Conv2d(in_nc, nf, 7, stride, pad, bias=True)
self.conv2 = nn.Conv2d(nf, nf, 3, stride, pad, bias=True)
self.conv3 = nn.Conv2d(nf, nf, 3, stride, pad, bias=True)
self.act = nn.ReLU(inplace=True)

def forward(self, x):
conv1_out = self.act(self.conv1(self.pad(x)))
conv2_out = self.act(self.conv2(self.pad(conv1_out)))
conv3_out = self.act(self.conv3(self.pad(conv2_out)))
out = torch.mean(conv3_out, dim=[2, 3], keepdim=False)

return out


# 3layers with control
class CSRNet(nn.Module):

def __init__(self, in_nc=3, out_nc=3, base_nf=64, cond_nf=32):
super(CSRNet, self).__init__()

self.base_nf = base_nf
self.out_nc = out_nc

self.cond_net = Condition(in_nc=in_nc, nf=cond_nf)

self.cond_scale1 = nn.Linear(cond_nf, base_nf, bias=True)
self.cond_scale2 = nn.Linear(cond_nf, base_nf, bias=True)
self.cond_scale3 = nn.Linear(cond_nf, 3, bias=True)

self.cond_shift1 = nn.Linear(cond_nf, base_nf, bias=True)
self.cond_shift2 = nn.Linear(cond_nf, base_nf, bias=True)
self.cond_shift3 = nn.Linear(cond_nf, 3, bias=True)

self.conv1 = nn.Conv2d(in_nc, base_nf, 1, 1, bias=True)
self.conv2 = nn.Conv2d(base_nf, base_nf, 1, 1, bias=True)
self.conv3 = nn.Conv2d(base_nf, out_nc, 1, 1, bias=True)

self.act = nn.ReLU(inplace=True)

def forward(self, x):
cond = self.cond_net(x)

scale1 = self.cond_scale1(cond)
shift1 = self.cond_shift1(cond)

scale2 = self.cond_scale2(cond)
shift2 = self.cond_shift2(cond)

scale3 = self.cond_scale3(cond)
shift3 = self.cond_shift3(cond)

out = self.conv1(x)
out = out * scale1.view(-1, self.base_nf, 1, 1) + shift1.view(
-1, self.base_nf, 1, 1) + out
out = self.act(out)

out = self.conv2(out)
out = out * scale2.view(-1, self.base_nf, 1, 1) + shift2.view(
-1, self.base_nf, 1, 1) + out
out = self.act(out)

out = self.conv3(out)
out = out * scale3.view(-1, self.out_nc, 1, 1) + shift3.view(
-1, self.out_nc, 1, 1) + out
return out


class L1Loss(nn.Module):
"""L1 (mean absolute error, MAE) loss.

Args:
loss_weight (float): Loss weight for L1 loss. Default: 1.0.
reduction (str): Specifies the reduction to apply to the output.
Supported choices are 'none' | 'mean' | 'sum'. Default: 'mean'.
"""

def __init__(self, loss_weight=1.0, reduction='mean'):
    super(L1Loss, self).__init__()
    if reduction not in ['none', 'mean', 'sum']:
        raise ValueError(f'Unsupported reduction mode: {reduction}. '
                         "Supported ones are: 'none', 'mean' and 'sum'")

    self.loss_weight = loss_weight
    self.reduction = reduction

def forward(self, pred, target, weight=None, **kwargs):
"""
Args:
pred (Tensor): of shape (N, C, H, W). Predicted tensor.
target (Tensor): of shape (N, C, H, W). Ground truth tensor.
weight (Tensor, optional): of shape (N, C, H, W). Element-wise weights. Default: None.
"""
return self.loss_weight * F.l1_loss(
pred, target, reduction=self.reduction)
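
A brief shape check for the network above: the condition branch pools to a global vector while the main path uses only 1x1 convolutions, so spatial resolution is preserved; the input size below is arbitrary:

import torch

from modelscope.models.cv.image_color_enhance.csrnet import CSRNet

net = CSRNet(in_nc=3, out_nc=3, base_nf=64, cond_nf=32)
x = torch.rand(2, 3, 128, 96)  # NCHW batch with values in [0, 1]
with torch.no_grad():
    y = net(x)
print(y.shape)  # torch.Size([2, 3, 128, 96])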

+109 -0  modelscope/models/cv/image_color_enhance/image_color_enhance.py

@@ -0,0 +1,109 @@
import os.path as osp
from copy import deepcopy
from typing import Dict, Union

import torch
from torch.nn.parallel import DataParallel, DistributedDataParallel

from modelscope.metainfo import Models
from modelscope.models.base import Tensor, TorchModel
from modelscope.models.builder import MODELS
from modelscope.utils.constant import ModelFile, Tasks
from modelscope.utils.logger import get_logger
from .csrnet import CSRNet, L1Loss

logger = get_logger()

__all__ = ['ImageColorEnhance']


@MODELS.register_module(Tasks.image_color_enhance, module_name=Models.csrnet)
class ImageColorEnhance(TorchModel):

def __init__(self, model_dir: str, *args, **kwargs):
"""initialize the image color enhance model from the `model_dir` path.

Args:
model_dir (str): the model path.
"""
super().__init__(model_dir, *args, **kwargs)

model_path = osp.join(model_dir, ModelFile.TORCH_MODEL_FILE)

self.loss = L1Loss()
self.model = CSRNet()
if torch.cuda.is_available():
self._device = torch.device('cuda')
else:
self._device = torch.device('cpu')
self.model = self.model.to(self._device)

self.model = self.load_pretrained(self.model, model_path)

if self.training:
self.model.train()
else:
self.model.eval()

def load_pretrained(self, net, load_path, strict=True, param_key='params'):
if isinstance(net, (DataParallel, DistributedDataParallel)):
net = net.module
load_net = torch.load(
load_path, map_location=lambda storage, loc: storage)
if param_key is not None:
if param_key not in load_net and 'params' in load_net:
param_key = 'params'
logger.info(
f'Loading: {param_key} does not exist, use params.')
if param_key in load_net:
load_net = load_net[param_key]
logger.info(
f'Loading {net.__class__.__name__} model from {load_path}, with param key: [{param_key}].'
)
# remove unnecessary 'module.'
for k, v in deepcopy(load_net).items():
if k.startswith('module.'):
load_net[k[7:]] = v
load_net.pop(k)
net.load_state_dict(load_net, strict=strict)
logger.info('load model done.')
return net

def _evaluate_postprocess(self, src: Tensor,
target: Tensor) -> Dict[str, list]:
preds = self.model(src)
preds = list(torch.split(preds, 1, 0))
targets = list(torch.split(target, 1, 0))

preds = [(pred.data * 255.).squeeze(0).type(torch.uint8).permute(
1, 2, 0).cpu().numpy() for pred in preds]
targets = [(target.data * 255.).squeeze(0).type(torch.uint8).permute(
1, 2, 0).cpu().numpy() for target in targets]

return {'pred': preds, 'target': targets}

def _train_forward(self, src: Tensor, target: Tensor) -> Dict[str, Tensor]:
preds = self.model(src)
return {'loss': self.loss(preds, target)}

def _inference_forward(self, src: Tensor) -> Dict[str, Tensor]:
return {'outputs': self.model(src).clamp(0, 1)}

def forward(self, input: Dict[str,
Tensor]) -> Dict[str, Union[list, Tensor]]:
"""return the result by the model

Args:
input (Dict[str, Tensor]): the preprocessed data

Returns:
Dict[str, Union[list, Tensor]]: results
"""
for key, value in input.items():
input[key] = input[key].to(self._device)
if self.training:
return self._train_forward(**input)
elif 'target' in input:
return self._evaluate_postprocess(**input)
else:
return self._inference_forward(**input)
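
The wrapper's forward dispatches on the training flag and on whether a 'target' tensor is present; a minimal sketch of the three call patterns, assuming the checkpoint is fetched from the hub with the model id used in the tests:

import torch

from modelscope.hub.snapshot_download import snapshot_download
from modelscope.models.cv.image_color_enhance.image_color_enhance import \
    ImageColorEnhance

model_dir = snapshot_download('damo/cv_csrnet_image-color-enhance-models')
model = ImageColorEnhance.from_pretrained(model_dir)

lq = torch.rand(1, 3, 256, 256)  # low-quality batch in [0, 1]
gt = torch.rand(1, 3, 256, 256)  # ground-truth batch in [0, 1]

model.eval()
print(model({'src': lq}).keys())                # inference: dict with 'outputs'
print(model({'src': lq, 'target': gt}).keys())  # evaluation: 'pred' and 'target' lists

model.train()
print(model({'src': lq, 'target': gt}))         # training: {'loss': <L1 loss tensor>}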

+6 -0  modelscope/outputs.py

@@ -105,6 +105,12 @@ TASK_OUTPUTS = {
    # }
    Tasks.video_embedding: [OutputKeys.VIDEO_EMBEDDING],

    # image_color_enhance result for a single sample
    # {
    #   "output_img": np.ndarray with shape [height, width, 3], uint8
    # }
    Tasks.image_color_enhance: [OutputKeys.OUTPUT_IMG],

    # ============ nlp tasks ===================

    # text classification result for single sample


+2 -0  modelscope/pipelines/builder.py

@@ -70,6 +70,8 @@ DEFAULT_MODEL_FOR_PIPELINE = {
    Tasks.text_to_image_synthesis:
    (Pipelines.text_to_image_synthesis,
     'damo/cv_imagen_text-to-image-synthesis_tiny'),
    Tasks.image_color_enhance: (Pipelines.image_color_enhance,
                                'damo/cv_csrnet_image-color-enhance-models'),
    Tasks.virtual_tryon: (Pipelines.virtual_tryon,
                          'damo/cv_daflow_virtual-tryon_base'),
    Tasks.image_colorization: (Pipelines.image_colorization,


+1 -0  modelscope/pipelines/cv/__init__.py

@@ -6,6 +6,7 @@ try:
    from .action_recognition_pipeline import ActionRecognitionPipeline
    from .animal_recog_pipeline import AnimalRecogPipeline
    from .cmdssl_video_embedding_pipleline import CMDSSLVideoEmbeddingPipeline
    from .image_color_enhance_pipeline import ImageColorEnhancePipeline
    from .virtual_tryon_pipeline import VirtualTryonPipeline
    from .image_colorization_pipeline import ImageColorizationPipeline
    from .image_super_resolution_pipeline import ImageSuperResolutionPipeline


+74 -0  modelscope/pipelines/cv/image_color_enhance_pipeline.py

@@ -0,0 +1,74 @@
from typing import Any, Dict, Optional, Union

import cv2
import numpy as np
import torch
from PIL import Image
from torchvision import transforms

from modelscope.metainfo import Pipelines
from modelscope.models.base import Model
from modelscope.models.cv.image_color_enhance.image_color_enhance import \
ImageColorEnhance
from modelscope.outputs import OutputKeys
from modelscope.pipelines.base import Input
from modelscope.preprocessors import (ImageColorEnhanceFinetunePreprocessor,
load_image)
from modelscope.utils.constant import ModelFile, Tasks
from modelscope.utils.logger import get_logger
from ..base import Pipeline
from ..builder import PIPELINES

logger = get_logger()


@PIPELINES.register_module(
Tasks.image_color_enhance, module_name=Pipelines.image_color_enhance)
class ImageColorEnhancePipeline(Pipeline):

def __init__(self,
model: Union[ImageColorEnhance, str],
preprocessor: Optional[
ImageColorEnhanceFinetunePreprocessor] = None,
**kwargs):
"""
use `model` and `preprocessor` to create a kws pipeline for prediction
Args:
model: model id on modelscope hub.
"""
model = model if isinstance(
model, ImageColorEnhance) else Model.from_pretrained(model)
model.eval()
super().__init__(model=model, preprocessor=preprocessor, **kwargs)

if torch.cuda.is_available():
self._device = torch.device('cuda')
else:
self._device = torch.device('cpu')

def preprocess(self, input: Input) -> Dict[str, Any]:
    if isinstance(input, str):
        img = load_image(input)
    elif isinstance(input, Image.Image):
        img = input.convert('RGB')
    elif isinstance(input, np.ndarray):
        if len(input.shape) == 2:
            input = cv2.cvtColor(input, cv2.COLOR_GRAY2RGB)
        img = Image.fromarray(input.astype('uint8')).convert('RGB')
    else:
        raise TypeError(f'input should be either str, PIL.Image or'
                        f' np.ndarray, but got {type(input)}')

    test_transforms = transforms.Compose([transforms.ToTensor()])
    img = test_transforms(img)
    result = {'src': img.unsqueeze(0).to(self._device)}
    return result

@torch.no_grad()
def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
return super().forward(input)

def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
output_img = (inputs['outputs'].squeeze(0) * 255.).type(
torch.uint8).cpu().permute(1, 2, 0).numpy()
return {OutputKeys.OUTPUT_IMG: output_img}

+1 -0  modelscope/preprocessors/__init__.py

@@ -19,6 +19,7 @@ try:
    from .space.dialog_intent_prediction_preprocessor import *  # noqa F403
    from .space.dialog_modeling_preprocessor import *  # noqa F403
    from .space.dialog_state_tracking_preprocessor import *  # noqa F403
    from .image import ImageColorEnhanceFinetunePreprocessor
except ModuleNotFoundError as e:
    if str(e) == "No module named 'tensorflow'":
        print(TENSORFLOW_IMPORT_ERROR.format('tts'))


+37 -1  modelscope/preprocessors/image.py

@@ -1,12 +1,15 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import io
-from typing import Dict, Union
+from typing import Any, Dict, Union

import torch
from PIL import Image, ImageOps

from modelscope.fileio import File
from modelscope.metainfo import Preprocessors
from modelscope.utils.constant import Fields
from modelscope.utils.type_assert import type_assert
from .base import Preprocessor
from .builder import PREPROCESSORS



@@ -66,3 +69,36 @@ def load_image(image_path_or_url: str) -> Image.Image:
""" """
loader = LoadImage() loader = LoadImage()
return loader(image_path_or_url)['img'] return loader(image_path_or_url)['img']


@PREPROCESSORS.register_module(
Fields.cv, module_name=Preprocessors.image_color_enhance_preprocessor)
class ImageColorEnhanceFinetunePreprocessor(Preprocessor):

def __init__(self, model_dir: str, *args, **kwargs):
"""preprocess the data from the `model_dir` path

Args:
model_dir (str): model path
"""

super().__init__(*args, **kwargs)
self.model_dir: str = model_dir

@type_assert(object, object)
def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
    """Process the raw input data.

    Args:
        data (Dict[str, Any]): a dict of image tensors from the dataset,
            e.g. with 'src' and (for training/evaluation) 'target' keys.

    Returns:
        Dict[str, Any]: the preprocessed data, returned unchanged here.
    """
    return data

+1 -0  modelscope/utils/constant.py

@@ -27,6 +27,7 @@ class CVTasks(object):
    ocr_detection = 'ocr-detection'
    action_recognition = 'action-recognition'
    video_embedding = 'video-embedding'
    image_color_enhance = 'image-color-enhance'
    virtual_tryon = 'virtual-tryon'
    image_colorization = 'image-colorization'
    face_image_generation = 'face-image-generation'


+42 -0  tests/pipelines/test_image_color_enhance.py

@@ -0,0 +1,42 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os
import os.path as osp
import unittest

import cv2

from modelscope.outputs import OutputKeys
from modelscope.pipelines import pipeline
from modelscope.pipelines.base import Pipeline
from modelscope.utils.constant import Tasks
from modelscope.utils.test_utils import test_level


class ImageColorEnhanceTest(unittest.TestCase):

def setUp(self) -> None:
self.model_id = 'damo/cv_csrnet_image-color-enhance-models'

def pipeline_inference(self, pipeline: Pipeline, input_location: str):
result = pipeline(input_location)
if result is not None:
cv2.imwrite('result.png', result[OutputKeys.OUTPUT_IMG][:, :,
[2, 1, 0]])
print(f'Output written to {osp.abspath("result.png")}')

@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
def test_run_modelhub(self):
img_color_enhance = pipeline(
Tasks.image_color_enhance, model=self.model_id)
self.pipeline_inference(img_color_enhance,
'data/test/images/image_color_enhance.png')

@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
def test_run_modelhub_default_model(self):
img_color_enhance = pipeline(Tasks.image_color_enhance)
self.pipeline_inference(img_color_enhance,
'data/test/images/image_color_enhance.png')


if __name__ == '__main__':
unittest.main()

+115 -0  tests/trainers/test_image_color_enhance_trainer.py

@@ -0,0 +1,115 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os
import os.path as osp
import shutil
import tempfile
import unittest
from typing import Callable, List, Optional, Tuple, Union

import cv2
import torch
from torch.utils import data as data

from modelscope.hub.snapshot_download import snapshot_download
from modelscope.models.cv.image_color_enhance.image_color_enhance import \
ImageColorEnhance
from modelscope.trainers import build_trainer
from modelscope.utils.constant import ModelFile
from modelscope.utils.test_utils import test_level


class TestImageColorEnhanceTrainer(unittest.TestCase):

def setUp(self):
print(('Testing %s.%s' % (type(self).__name__, self._testMethodName)))
self.tmp_dir = tempfile.TemporaryDirectory().name
if not os.path.exists(self.tmp_dir):
os.makedirs(self.tmp_dir)

self.model_id = 'damo/cv_csrnet_image-color-enhance-models'

class PairedImageDataset(data.Dataset):

def __init__(self, root):
super(PairedImageDataset, self).__init__()
gt_dir = osp.join(root, 'gt')
lq_dir = osp.join(root, 'lq')
self.gt_filelist = os.listdir(gt_dir)
self.gt_filelist = sorted(
self.gt_filelist, key=lambda x: int(x[:-4]))
self.gt_filelist = [
osp.join(gt_dir, f) for f in self.gt_filelist
]
self.lq_filelist = os.listdir(lq_dir)
self.lq_filelist = sorted(
self.lq_filelist, key=lambda x: int(x[:-4]))
self.lq_filelist = [
osp.join(lq_dir, f) for f in self.lq_filelist
]

def _img_to_tensor(self, img):
return torch.from_numpy(img[:, :, [2, 1, 0]]).permute(
2, 0, 1).type(torch.float32) / 255.

def __getitem__(self, index):
lq = cv2.imread(self.lq_filelist[index])
gt = cv2.imread(self.gt_filelist[index])
lq = cv2.resize(lq, (256, 256), interpolation=cv2.INTER_CUBIC)
gt = cv2.resize(gt, (256, 256), interpolation=cv2.INTER_CUBIC)
return \
{'src': self._img_to_tensor(lq), 'target': self._img_to_tensor(gt)}

def __len__(self):
return len(self.gt_filelist)

def to_torch_dataset(self,
columns: Union[str, List[str]] = None,
preprocessors: Union[Callable,
List[Callable]] = None,
**format_kwargs):
return self

self.dataset = PairedImageDataset(
'./data/test/images/image_color_enhance/')

def tearDown(self):
shutil.rmtree(self.tmp_dir, ignore_errors=True)
super().tearDown()

@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
def test_trainer(self):
kwargs = dict(
model=self.model_id,
train_dataset=self.dataset,
eval_dataset=self.dataset,
work_dir=self.tmp_dir)

trainer = build_trainer(default_args=kwargs)
trainer.train()
results_files = os.listdir(self.tmp_dir)
self.assertIn(f'{trainer.timestamp}.log.json', results_files)
for i in range(3):
self.assertIn(f'epoch_{i+1}.pth', results_files)

@unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
def test_trainer_with_model_and_args(self):
cache_path = snapshot_download(self.model_id)
model = ImageColorEnhance.from_pretrained(cache_path)
kwargs = dict(
cfg_file=os.path.join(cache_path, ModelFile.CONFIGURATION),
model=model,
train_dataset=self.dataset,
eval_dataset=self.dataset,
max_epochs=2,
work_dir=self.tmp_dir)

trainer = build_trainer(default_args=kwargs)
trainer.train()
results_files = os.listdir(self.tmp_dir)
self.assertIn(f'{trainer.timestamp}.log.json', results_files)
for i in range(2):
self.assertIn(f'epoch_{i+1}.pth', results_files)


if __name__ == '__main__':
unittest.main()
