Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9491966 (branch: master)
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:403034182fa320130dae0d75b92e85e0850771378e674d65455c403a4958e29c
size 170716
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:ebd5dacad9b75ef80f87eb785d7818421dadb63257da0e91e123766c5913f855
size 149971
@@ -10,6 +10,7 @@ class Models(object):
    Model name should only contain model info but not task info.
    """
    # vision models
    nafnet = 'nafnet'
    csrnet = 'csrnet'
    cascade_mask_rcnn_swin = 'cascade_mask_rcnn_swin'
@@ -59,6 +60,7 @@ class Pipelines(object):
    """
    # vision tasks
    image_matting = 'unet-image-matting'
    image_denoise = 'nafnet-image-denoise'
    person_image_cartoon = 'unet-person-image-cartoon'
    ocr_detection = 'resnet18-ocr-detection'
    action_recognition = 'TAdaConv_action-recognition'
@@ -132,6 +134,7 @@ class Preprocessors(object):
    # cv preprocessor
    load_image = 'load-image'
    image_denoise_preprocessor = 'image-denoise-preprocessor'
    image_color_enhance_preprocessor = 'image-color-enhance-preprocessor'
    image_instance_segmentation_preprocessor = 'image-instance-segmentation-preprocessor'
@@ -167,6 +170,9 @@ class Metrics(object):
    # accuracy
    accuracy = 'accuracy'
    # metrics for image denoise task
    image_denoise_metric = 'image-denoise-metric'
    # metric for image instance segmentation task
    image_ins_seg_coco_metric = 'image-ins-seg-coco-metric'
    # metrics for sequence classification task
@@ -1,6 +1,7 @@
from .base import Metric
from .builder import METRICS, build_metric, task_default_metrics
from .image_color_enhance_metric import ImageColorEnhanceMetric
from .image_denoise_metric import ImageDenoiseMetric
from .image_instance_segmentation_metric import \
    ImageInstanceSegmentationCOCOMetric
from .sequence_classification_metric import SequenceClassificationMetric
@@ -22,6 +22,7 @@ task_default_metrics = {
    Tasks.sentence_similarity: [Metrics.seq_cls_metric],
    Tasks.sentiment_classification: [Metrics.seq_cls_metric],
    Tasks.text_generation: [Metrics.text_gen_metric],
    Tasks.image_denoise: [Metrics.image_denoise_metric],
    Tasks.image_color_enhance: [Metrics.image_color_enhance_metric]
}
@@ -0,0 +1,45 @@
from typing import Dict

import numpy as np
from skimage.metrics import peak_signal_noise_ratio, structural_similarity

from modelscope.metainfo import Metrics
from modelscope.utils.registry import default_group
from modelscope.utils.tensor_utils import (torch_nested_detach,
                                           torch_nested_numpify)
from .base import Metric
from .builder import METRICS, MetricKeys


@METRICS.register_module(
    group_key=default_group, module_name=Metrics.image_denoise_metric)
class ImageDenoiseMetric(Metric):
    """The metric computation class for the image denoise task."""

    pred_name = 'pred'
    label_name = 'target'

    def __init__(self):
        self.preds = []
        self.labels = []

    def add(self, outputs: Dict, inputs: Dict):
        ground_truths = outputs[ImageDenoiseMetric.label_name]
        eval_results = outputs[ImageDenoiseMetric.pred_name]
        self.preds.append(
            torch_nested_numpify(torch_nested_detach(eval_results)))
        self.labels.append(
            torch_nested_numpify(torch_nested_detach(ground_truths)))

    def evaluate(self):
        psnr_list, ssim_list = [], []
        for (pred, label) in zip(self.preds, self.labels):
            psnr_list.append(
                peak_signal_noise_ratio(label[0], pred[0], data_range=255))
            ssim_list.append(
                structural_similarity(
                    label[0], pred[0], multichannel=True, data_range=255))
        return {
            MetricKeys.PSNR: np.mean(psnr_list),
            MetricKeys.SSIM: np.mean(ssim_list)
        }
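For orientation, the metric simply averages per-image skimage PSNR/SSIM over everything accumulated via add(); below is a minimal standalone sketch with dummy uint8 images (the names `clean` and `noisy` are illustrative, not part of this change):

import numpy as np
from skimage.metrics import peak_signal_noise_ratio, structural_similarity

# two dummy HWC uint8 images standing in for one evaluated sample
clean = np.random.randint(0, 256, (64, 64, 3), dtype=np.uint8)
noisy = np.clip(clean.astype(np.int16)
                + np.random.randint(-10, 10, clean.shape), 0, 255).astype(np.uint8)

psnr = peak_signal_noise_ratio(clean, noisy, data_range=255)
ssim = structural_similarity(clean, noisy, multichannel=True, data_range=255)
print(psnr, ssim)  # ImageDenoiseMetric reports the mean of these over all added samples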
@@ -22,6 +22,7 @@ except ModuleNotFoundError as e:
try:
    from .multi_modal import OfaForImageCaptioning
    from .cv import NAFNetForImageDenoise
    from .nlp import (BertForMaskedLM, BertForSequenceClassification,
                      SbertForNLI, SbertForSentenceSimilarity,
                      SbertForSentimentClassification,
@@ -1,2 +1,3 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from .image_color_enhance.image_color_enhance import ImageColorEnhance
from .image_denoise.nafnet_for_image_denoise import *  # noqa F403
@@ -0,0 +1,233 @@
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from .arch_util import LayerNorm2d


class SimpleGate(nn.Module):

    def forward(self, x):
        x1, x2 = x.chunk(2, dim=1)
        return x1 * x2


class NAFBlock(nn.Module):

    def __init__(self, c, DW_Expand=2, FFN_Expand=2, drop_out_rate=0.):
        super().__init__()
        dw_channel = c * DW_Expand
        self.conv1 = nn.Conv2d(
            in_channels=c,
            out_channels=dw_channel,
            kernel_size=1,
            padding=0,
            stride=1,
            groups=1,
            bias=True)
        self.conv2 = nn.Conv2d(
            in_channels=dw_channel,
            out_channels=dw_channel,
            kernel_size=3,
            padding=1,
            stride=1,
            groups=dw_channel,
            bias=True)
        self.conv3 = nn.Conv2d(
            in_channels=dw_channel // 2,
            out_channels=c,
            kernel_size=1,
            padding=0,
            stride=1,
            groups=1,
            bias=True)
        # Simplified Channel Attention
        self.sca = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(
                in_channels=dw_channel // 2,
                out_channels=dw_channel // 2,
                kernel_size=1,
                padding=0,
                stride=1,
                groups=1,
                bias=True),
        )
        # SimpleGate
        self.sg = SimpleGate()
        ffn_channel = FFN_Expand * c
        self.conv4 = nn.Conv2d(
            in_channels=c,
            out_channels=ffn_channel,
            kernel_size=1,
            padding=0,
            stride=1,
            groups=1,
            bias=True)
        self.conv5 = nn.Conv2d(
            in_channels=ffn_channel // 2,
            out_channels=c,
            kernel_size=1,
            padding=0,
            stride=1,
            groups=1,
            bias=True)
        self.norm1 = LayerNorm2d(c)
        self.norm2 = LayerNorm2d(c)
        self.dropout1 = nn.Dropout(
            drop_out_rate) if drop_out_rate > 0. else nn.Identity()
        self.dropout2 = nn.Dropout(
            drop_out_rate) if drop_out_rate > 0. else nn.Identity()
        self.beta = nn.Parameter(torch.zeros((1, c, 1, 1)), requires_grad=True)
        self.gamma = nn.Parameter(
            torch.zeros((1, c, 1, 1)), requires_grad=True)

    def forward(self, inp):
        x = inp
        x = self.norm1(x)
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.sg(x)
        x = x * self.sca(x)
        x = self.conv3(x)
        x = self.dropout1(x)
        y = inp + x * self.beta
        x = self.conv4(self.norm2(y))
        x = self.sg(x)
        x = self.conv5(x)
        x = self.dropout2(x)
        return y + x * self.gamma


class NAFNet(nn.Module):

    def __init__(self,
                 img_channel=3,
                 width=16,
                 middle_blk_num=1,
                 enc_blk_nums=[],
                 dec_blk_nums=[]):
        super().__init__()
        self.intro = nn.Conv2d(
            in_channels=img_channel,
            out_channels=width,
            kernel_size=3,
            padding=1,
            stride=1,
            groups=1,
            bias=True)
        self.ending = nn.Conv2d(
            in_channels=width,
            out_channels=img_channel,
            kernel_size=3,
            padding=1,
            stride=1,
            groups=1,
            bias=True)
        self.encoders = nn.ModuleList()
        self.decoders = nn.ModuleList()
        self.middle_blks = nn.ModuleList()
        self.ups = nn.ModuleList()
        self.downs = nn.ModuleList()
        chan = width
        for num in enc_blk_nums:
            self.encoders.append(
                nn.Sequential(*[NAFBlock(chan) for _ in range(num)]))
            self.downs.append(nn.Conv2d(chan, 2 * chan, 2, 2))
            chan = chan * 2
        self.middle_blks = \
            nn.Sequential(
                *[NAFBlock(chan) for _ in range(middle_blk_num)]
            )
        for num in dec_blk_nums:
            self.ups.append(
                nn.Sequential(
                    nn.Conv2d(chan, chan * 2, 1, bias=False),
                    nn.PixelShuffle(2)))
            chan = chan // 2
            self.decoders.append(
                nn.Sequential(*[NAFBlock(chan) for _ in range(num)]))
        self.padder_size = 2**len(self.encoders)

    def forward(self, inp):
        B, C, H, W = inp.shape
        inp = self.check_image_size(inp)
        x = self.intro(inp)
        encs = []
        for encoder, down in zip(self.encoders, self.downs):
            x = encoder(x)
            encs.append(x)
            x = down(x)
        x = self.middle_blks(x)
        for decoder, up, enc_skip in zip(self.decoders, self.ups, encs[::-1]):
            x = up(x)
            x = x + enc_skip
            x = decoder(x)
        x = self.ending(x)
        x = x + inp
        return x[:, :, :H, :W]

    def check_image_size(self, x):
        _, _, h, w = x.size()
        mod_pad_h = (self.padder_size
                     - h % self.padder_size) % self.padder_size
        mod_pad_w = (self.padder_size
                     - w % self.padder_size) % self.padder_size
        x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h))
        return x


class PSNRLoss(nn.Module):

    def __init__(self, loss_weight=1.0, reduction='mean', toY=False):
        super(PSNRLoss, self).__init__()
        assert reduction == 'mean'
        self.loss_weight = loss_weight
        self.scale = 10 / np.log(10)
        self.toY = toY
        self.coef = torch.tensor([65.481, 128.553, 24.966]).reshape(1, 3, 1, 1)
        self.first = True

    def forward(self, pred, target):
        assert len(pred.size()) == 4
        if self.toY:
            if self.first:
                self.coef = self.coef.to(pred.device)
                self.first = False
            pred = (pred * self.coef).sum(dim=1).unsqueeze(dim=1) + 16.
            target = (target * self.coef).sum(dim=1).unsqueeze(dim=1) + 16.
            pred, target = pred / 255., target / 255.
        return self.loss_weight * self.scale * torch.log((
            (pred - target)**2).mean(dim=(1, 2, 3)) + 1e-8).mean()
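As a quick sanity check of the UNet-style layout and the pad-then-crop behaviour in forward/check_image_size, a hedged sketch (the width and block counts are illustrative, not the released SIDD configuration):

import torch

net = NAFNet(img_channel=3, width=16, middle_blk_num=1,
             enc_blk_nums=[1, 1], dec_blk_nums=[1, 1])
# 30 is not a multiple of padder_size (2**2 == 4): the input is padded
# internally and the output is cropped back to the original resolution
with torch.no_grad():
    out = net(torch.rand(1, 3, 30, 30))
print(out.shape)  # torch.Size([1, 3, 30, 30])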
@@ -0,0 +1,42 @@
import torch
import torch.nn as nn


class LayerNormFunction(torch.autograd.Function):

    @staticmethod
    def forward(ctx, x, weight, bias, eps):
        ctx.eps = eps
        N, C, H, W = x.size()
        mu = x.mean(1, keepdim=True)
        var = (x - mu).pow(2).mean(1, keepdim=True)
        y = (x - mu) / (var + eps).sqrt()
        ctx.save_for_backward(y, var, weight)
        y = weight.view(1, C, 1, 1) * y + bias.view(1, C, 1, 1)
        return y

    @staticmethod
    def backward(ctx, grad_output):
        eps = ctx.eps
        N, C, H, W = grad_output.size()
        y, var, weight = ctx.saved_variables
        g = grad_output * weight.view(1, C, 1, 1)
        mean_g = g.mean(dim=1, keepdim=True)
        mean_gy = (g * y).mean(dim=1, keepdim=True)
        gx = 1. / torch.sqrt(var + eps) * (g - y * mean_gy - mean_g)
        return gx, (grad_output * y).sum(dim=3).sum(dim=2).sum(
            dim=0), grad_output.sum(dim=3).sum(dim=2).sum(dim=0), None


class LayerNorm2d(nn.Module):

    def __init__(self, channels, eps=1e-6):
        super(LayerNorm2d, self).__init__()
        self.register_parameter('weight', nn.Parameter(torch.ones(channels)))
        self.register_parameter('bias', nn.Parameter(torch.zeros(channels)))
        self.eps = eps

    def forward(self, x):
        return LayerNormFunction.apply(x, self.weight, self.bias, self.eps)
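A small hedged check of what LayerNorm2d computes: at initialisation (weight=1, bias=0) it should match a plain normalisation over the channel dimension at every spatial location.

import torch

ln = LayerNorm2d(8, eps=1e-6)
x = torch.randn(2, 8, 4, 4)
mu = x.mean(1, keepdim=True)
var = (x - mu).pow(2).mean(1, keepdim=True)
ref = (x - mu) / (var + 1e-6).sqrt()
print(torch.allclose(ln(x), ref, atol=1e-5))  # True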
@@ -0,0 +1,119 @@
import os
from copy import deepcopy
from typing import Any, Dict, Union

import numpy as np
import torch.cuda
from torch.nn.parallel import DataParallel, DistributedDataParallel

from modelscope.metainfo import Models
from modelscope.models.base import Tensor
from modelscope.models.base.base_torch_model import TorchModel
from modelscope.models.builder import MODELS
from modelscope.utils.config import Config
from modelscope.utils.constant import ModelFile, Tasks
from modelscope.utils.logger import get_logger
from .nafnet.NAFNet_arch import NAFNet, PSNRLoss

logger = get_logger()

__all__ = ['NAFNetForImageDenoise']


@MODELS.register_module(Tasks.image_denoise, module_name=Models.nafnet)
class NAFNetForImageDenoise(TorchModel):

    def __init__(self, model_dir: str, *args, **kwargs):
        """Initialize the image denoise model from the `model_dir` path.

        Args:
            model_dir (str): the model path.
        """
        super().__init__(model_dir, *args, **kwargs)
        self.model_dir = model_dir
        self.config = Config.from_file(
            os.path.join(self.model_dir, ModelFile.CONFIGURATION))
        model_path = os.path.join(model_dir, ModelFile.TORCH_MODEL_FILE)
        self.model = NAFNet(**self.config.model.network_g)
        self.loss = PSNRLoss()
        if torch.cuda.is_available():
            self._device = torch.device('cuda')
        else:
            self._device = torch.device('cpu')
        self.model = self.model.to(self._device)
        self.model = self._load_pretrained(self.model, model_path)
        if self.training:
            self.model.train()
        else:
            self.model.eval()

    def _load_pretrained(self,
                         net,
                         load_path,
                         strict=True,
                         param_key='params'):
        if isinstance(net, (DataParallel, DistributedDataParallel)):
            net = net.module
        load_net = torch.load(
            load_path, map_location=lambda storage, loc: storage)
        if param_key is not None:
            if param_key not in load_net and 'params' in load_net:
                param_key = 'params'
                logger.info(
                    f'Loading: {param_key} does not exist, use params.')
            if param_key in load_net:
                load_net = load_net[param_key]
        logger.info(
            f'Loading {net.__class__.__name__} model from {load_path}, with param key: [{param_key}].'
        )
        # remove unnecessary 'module.' prefix
        for k, v in deepcopy(load_net).items():
            if k.startswith('module.'):
                load_net[k[7:]] = v
                load_net.pop(k)
        net.load_state_dict(load_net, strict=strict)
        logger.info('load model done.')
        return net

    def _train_forward(self, input: Tensor,
                       target: Tensor) -> Dict[str, Tensor]:
        preds = self.model(input)
        return {'loss': self.loss(preds, target)}

    def _inference_forward(self, input: Tensor) -> Dict[str, Tensor]:
        return {'outputs': self.model(input).clamp(0, 1)}

    def _evaluate_postprocess(self, input: Tensor,
                              target: Tensor) -> Dict[str, list]:
        preds = self.model(input)
        preds = list(torch.split(preds, 1, 0))
        targets = list(torch.split(target, 1, 0))
        preds = [(pred.data * 255.).squeeze(0).permute(
            1, 2, 0).cpu().numpy().astype(np.uint8) for pred in preds]
        targets = [(target.data * 255.).squeeze(0).permute(
            1, 2, 0).cpu().numpy().astype(np.uint8) for target in targets]
        return {'pred': preds, 'target': targets}

    def forward(self, inputs: Dict[str,
                                   Tensor]) -> Dict[str, Union[list, Tensor]]:
        """Return the result computed by the model.

        Args:
            inputs (Dict[str, Tensor]): the preprocessed data

        Returns:
            Dict[str, Tensor]: results
        """
        for key, value in inputs.items():
            inputs[key] = inputs[key].to(self._device)
        if self.training:
            return self._train_forward(**inputs)
        elif 'target' in inputs:
            return self._evaluate_postprocess(**inputs)
        else:
            return self._inference_forward(**inputs)
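For reference, forward() dispatches on the module's training flag and on whether a 'target' tensor is present; a hedged sketch of the three paths (the local model directory is hypothetical and must contain the configuration file plus the torch checkpoint):

import torch

model = NAFNetForImageDenoise('/path/to/cv_nafnet_image-denoise_sidd')  # hypothetical local dir
noisy = torch.rand(1, 3, 64, 64)
clean = torch.rand(1, 3, 64, 64)

out = model({'input': noisy})                    # inference: {'outputs': tensor clamped to [0, 1]}
res = model({'input': noisy, 'target': clean})   # evaluation: {'pred': [...], 'target': [...]} as HWC uint8 arrays
model.train()
loss = model({'input': noisy, 'target': clean})  # training: {'loss': PSNR loss value}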
@@ -0,0 +1,152 @@
# ------------------------------------------------------------------------
# Modified from BasicSR (https://github.com/xinntao/BasicSR)
# Copyright 2018-2020 BasicSR Authors
# ------------------------------------------------------------------------
import os
from os import path as osp

import cv2
import numpy as np
import torch

from .transforms import mod_crop


def img2tensor(imgs, bgr2rgb=True, float32=True):
    """Numpy array to tensor.

    Args:
        imgs (list[ndarray] | ndarray): Input images.
        bgr2rgb (bool): Whether to change bgr to rgb.
        float32 (bool): Whether to change to float32.

    Returns:
        list[tensor] | tensor: Tensor images. If the result only has one
            element, just return the tensor.
    """

    def _totensor(img, bgr2rgb, float32):
        if img.shape[2] == 3 and bgr2rgb:
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        img = torch.from_numpy(img.transpose(2, 0, 1))
        if float32:
            img = img.float()
        return img

    if isinstance(imgs, list):
        return [_totensor(img, bgr2rgb, float32) for img in imgs]
    else:
        return _totensor(imgs, bgr2rgb, float32)


def scandir(dir_path, keyword=None, recursive=False, full_path=False):
    """Scan a directory to find the interested files.

    Args:
        dir_path (str): Path of the directory.
        keyword (str | tuple(str), optional): File keyword that we are
            interested in. Default: None.
        recursive (bool, optional): If set to True, recursively scan the
            directory. Default: False.
        full_path (bool, optional): If set to True, include the dir_path.
            Default: False.

    Returns:
        A generator for all the interested files with relative paths.
    """
    if (keyword is not None) and not isinstance(keyword, (str, tuple)):
        raise TypeError('"keyword" must be a string or tuple of strings')
    root = dir_path

    def _scandir(dir_path, keyword, recursive):
        for entry in os.scandir(dir_path):
            if not entry.name.startswith('.') and entry.is_file():
                if full_path:
                    return_path = entry.path
                else:
                    return_path = osp.relpath(entry.path, root)
                if keyword is None:
                    yield return_path
                elif keyword in return_path:
                    yield return_path
            else:
                if recursive:
                    yield from _scandir(
                        entry.path, keyword=keyword, recursive=recursive)
                else:
                    continue

    return _scandir(dir_path, keyword=keyword, recursive=recursive)
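A hedged usage sketch of scandir (the directory name is illustrative): list every file whose path contains 'NOISY' under a dataset root.

for rel_path in scandir('data/SIDD', keyword='NOISY', recursive=True):
    print(rel_path)  # paths relative to 'data/SIDD'; pass full_path=True for the full paths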
def padding(img_lq, img_gt, gt_size):
    h, w, _ = img_lq.shape

    h_pad = max(0, gt_size - h)
    w_pad = max(0, gt_size - w)

    if h_pad == 0 and w_pad == 0:
        return img_lq, img_gt

    img_lq = cv2.copyMakeBorder(img_lq, 0, h_pad, 0, w_pad, cv2.BORDER_REFLECT)
    img_gt = cv2.copyMakeBorder(img_gt, 0, h_pad, 0, w_pad, cv2.BORDER_REFLECT)
    return img_lq, img_gt


def read_img_seq(path, require_mod_crop=False, scale=1):
    """Read a sequence of images from a given folder path.

    Args:
        path (list[str] | str): List of image paths or image folder path.
        require_mod_crop (bool): Require mod crop for each image.
            Default: False.
        scale (int): Scale factor for mod_crop. Default: 1.

    Returns:
        Tensor: size (t, c, h, w), RGB, [0, 1].
    """
    if isinstance(path, list):
        img_paths = path
    else:
        img_paths = sorted(list(scandir(path, full_path=True)))
    imgs = [cv2.imread(v).astype(np.float32) / 255. for v in img_paths]
    if require_mod_crop:
        imgs = [mod_crop(img, scale) for img in imgs]
    imgs = img2tensor(imgs, bgr2rgb=True, float32=True)
    imgs = torch.stack(imgs, dim=0)
    return imgs
def paired_paths_from_folder(folders, keys, filename_tmpl):
    """Generate paired paths from folders.

    Args:
        folders (list[str]): A list of folder paths. The order of the list
            should be [input_folder, gt_folder].
        keys (list[str]): A list of keys identifying folders. The order should
            be consistent with folders, e.g., ['lq', 'gt'].
        filename_tmpl (str): Template for each filename. Note that the
            template excludes the file extension. Usually the filename_tmpl is
            for files in the input folder.

    Returns:
        list[dict]: Returned path list.
    """
    assert len(folders) == 2, (
        'The len of folders should be 2 with [input_folder, gt_folder]. '
        f'But got {len(folders)}')
    assert len(keys) == 2, (
        'The len of keys should be 2 with [input_key, gt_key]. '
        f'But got {len(keys)}')
    input_folder, gt_folder = folders
    input_key, gt_key = keys

    input_paths = list(scandir(input_folder, keyword='NOISY', recursive=True))
    gt_paths = list(scandir(gt_folder, keyword='GT', recursive=True))
    assert len(input_paths) == len(gt_paths), (
        f'{input_key} and {gt_key} datasets have different number of images: '
        f'{len(input_paths)}, {len(gt_paths)}.')
    paths = []
    for idx in range(len(gt_paths)):
        gt_path = os.path.join(gt_folder, gt_paths[idx])
        input_path = os.path.join(input_folder, gt_path.replace('GT', 'NOISY'))
        paths.append(
            dict([(f'{input_key}_path', input_path),
                  (f'{gt_key}_path', gt_path)]))
    return paths
@@ -0,0 +1,78 @@
import os
from typing import Callable, List, Optional, Tuple, Union

import cv2
import numpy as np
from torch.utils import data

from .data_utils import img2tensor, padding, paired_paths_from_folder
from .transforms import augment, paired_random_crop


def default_loader(path):
    return cv2.imread(path, cv2.IMREAD_UNCHANGED).astype(np.float32) / 255.0


class PairedImageDataset(data.Dataset):
    """Paired image dataset for image restoration."""

    def __init__(self, opt, root, is_train):
        super(PairedImageDataset, self).__init__()
        self.opt = opt
        self.is_train = is_train
        self.gt_folder, self.lq_folder = os.path.join(
            root, opt.dataroot_gt), os.path.join(root, opt.dataroot_lq)
        if opt.filename_tmpl is not None:
            self.filename_tmpl = opt.filename_tmpl
        else:
            self.filename_tmpl = '{}'
        self.paths = paired_paths_from_folder([self.lq_folder, self.gt_folder],
                                              ['lq', 'gt'], self.filename_tmpl)

    def __getitem__(self, index):
        scale = self.opt.scale

        # Load gt and lq images. Dimension order: HWC; channel order: BGR;
        # image range: [0, 1], float32.
        gt_path = self.paths[index]['gt_path']
        img_gt = default_loader(gt_path)
        lq_path = self.paths[index]['lq_path']
        img_lq = default_loader(lq_path)

        # augmentation for training
        gt_size = self.opt.gt_size
        # padding
        img_gt, img_lq = padding(img_gt, img_lq, gt_size)
        # random crop
        img_gt, img_lq = paired_random_crop(img_gt, img_lq, gt_size, scale)
        # flip, rotation
        img_gt, img_lq = augment([img_gt, img_lq], self.opt.use_flip,
                                 self.opt.use_rot)

        # BGR to RGB, HWC to CHW, numpy to tensor
        img_gt, img_lq = img2tensor([img_gt, img_lq],
                                    bgr2rgb=True,
                                    float32=True)
        return {
            'input': img_lq,
            'target': img_gt,
            'input_path': lq_path,
            'target_path': gt_path
        }

    def __len__(self):
        return len(self.paths)

    def to_torch_dataset(
            self,
            columns: Union[str, List[str]] = None,
            preprocessors: Union[Callable, List[Callable]] = None,
            **format_kwargs,
    ):
        return self
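A hedged construction sketch: the option block mirrors the attribute names this class actually reads (dataroot_gt, dataroot_lq, filename_tmpl, scale, gt_size, use_flip, use_rot), while the concrete values and the dataset root are illustrative.

from types import SimpleNamespace

opt = SimpleNamespace(dataroot_gt='Data', dataroot_lq='Data', filename_tmpl=None,
                      scale=1, gt_size=128, use_flip=True, use_rot=True)
ds = PairedImageDataset(opt, root='/path/to/SIDD_subset', is_train=True)  # hypothetical root
sample = ds[0]
print(sample['input'].shape, sample['target'].shape)  # CHW float tensors in [0, 1]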
@@ -0,0 +1,96 @@
# Modified from https://github.com/megvii-research/NAFNet/blob/main/basicsr/data/transforms.py
import random


def mod_crop(img, scale):
    """Mod crop images, used during testing.

    Args:
        img (ndarray): Input image.
        scale (int): Scale factor.

    Returns:
        ndarray: Result image.
    """
    img = img.copy()
    if img.ndim in (2, 3):
        h, w = img.shape[0], img.shape[1]
        h_remainder, w_remainder = h % scale, w % scale
        img = img[:h - h_remainder, :w - w_remainder, ...]
    else:
        raise ValueError(f'Wrong img ndim: {img.ndim}.')
    return img


def paired_random_crop(img_gts, img_lqs, gt_patch_size, scale):
    """Paired random crop.

    It crops lists of lq and gt images with corresponding locations.

    Args:
        img_gts (list[ndarray] | ndarray): GT images.
        img_lqs (list[ndarray] | ndarray): LQ images.
        gt_patch_size (int): GT patch size.
        scale (int): Scale factor.

    Returns:
        list[ndarray] | ndarray: GT images and LQ images.
    """
    if not isinstance(img_gts, list):
        img_gts = [img_gts]
    if not isinstance(img_lqs, list):
        img_lqs = [img_lqs]

    h_lq, w_lq, _ = img_lqs[0].shape
    h_gt, w_gt, _ = img_gts[0].shape
    lq_patch_size = gt_patch_size // scale

    # randomly choose top and left coordinates for lq patch
    top = random.randint(0, h_lq - lq_patch_size)
    left = random.randint(0, w_lq - lq_patch_size)

    # crop lq patch
    img_lqs = [
        v[top:top + lq_patch_size, left:left + lq_patch_size, ...]
        for v in img_lqs
    ]

    # crop corresponding gt patch
    top_gt, left_gt = int(top * scale), int(left * scale)
    img_gts = [
        v[top_gt:top_gt + gt_patch_size, left_gt:left_gt + gt_patch_size, ...]
        for v in img_gts
    ]
    if len(img_gts) == 1:
        img_gts = img_gts[0]
    if len(img_lqs) == 1:
        img_lqs = img_lqs[0]
    return img_gts, img_lqs


def augment(imgs, hflip=True, rotation=True, vflip=False):
    """Augment: horizontal flips | rotate.

    All the images in the list use the same augmentation.
    """
    hflip = hflip and random.random() < 0.5
    if vflip or rotation:
        vflip = random.random() < 0.5
    rot90 = rotation and random.random() < 0.5

    def _augment(img):
        if hflip:  # horizontal
            img = img[:, ::-1, :].copy()
        if vflip:  # vertical
            img = img[::-1, :, :].copy()
        if rot90:
            img = img.transpose(1, 0, 2)
        return img

    if not isinstance(imgs, list):
        imgs = [imgs]
    imgs = [_augment(img) for img in imgs]
    if len(imgs) == 1:
        imgs = imgs[0]
    return imgs
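A hedged sketch of the two training transforms working together on a synthetic GT/LQ pair (shapes are illustrative):

import numpy as np

gt = np.random.rand(256, 256, 3).astype(np.float32)
lq = np.random.rand(256, 256, 3).astype(np.float32)

gt_patch, lq_patch = paired_random_crop(gt, lq, gt_patch_size=128, scale=1)
gt_patch, lq_patch = augment([gt_patch, lq_patch], hflip=True, rotation=True)
print(gt_patch.shape, lq_patch.shape)  # (128, 128, 3) each, flipped/rotated identically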
@@ -74,6 +74,7 @@ TASK_OUTPUTS = {
    Tasks.image_editing: [OutputKeys.OUTPUT_IMG],
    Tasks.image_matting: [OutputKeys.OUTPUT_IMG],
    Tasks.image_generation: [OutputKeys.OUTPUT_IMG],
    Tasks.image_denoise: [OutputKeys.OUTPUT_IMG],
    Tasks.image_colorization: [OutputKeys.OUTPUT_IMG],
    Tasks.face_image_generation: [OutputKeys.OUTPUT_IMG],
    Tasks.image_super_resolution: [OutputKeys.OUTPUT_IMG],
@@ -35,6 +35,8 @@ DEFAULT_MODEL_FOR_PIPELINE = {
    ),  # TODO: revise back after passing the pr
    Tasks.image_matting: (Pipelines.image_matting,
                          'damo/cv_unet_image-matting'),
    Tasks.image_denoise: (Pipelines.image_denoise,
                          'damo/cv_nafnet_image-denoise_sidd'),
    Tasks.text_classification: (Pipelines.sentiment_analysis,
                                'damo/bert-base-sst2'),
    Tasks.text_generation: (Pipelines.text_generation,
@@ -6,6 +6,7 @@ try:
    from .action_recognition_pipeline import ActionRecognitionPipeline
    from .animal_recog_pipeline import AnimalRecogPipeline
    from .cmdssl_video_embedding_pipleline import CMDSSLVideoEmbeddingPipeline
    from .image_denoise_pipeline import ImageDenoisePipeline
    from .image_color_enhance_pipeline import ImageColorEnhancePipeline
    from .virtual_tryon_pipeline import VirtualTryonPipeline
    from .image_colorization_pipeline import ImageColorizationPipeline
@@ -0,0 +1,111 @@
from typing import Any, Dict, Optional, Union

import cv2
import numpy as np
import torch
from PIL import Image
from torchvision import transforms

from modelscope.metainfo import Pipelines
from modelscope.models import Model
from modelscope.models.cv import NAFNetForImageDenoise
from modelscope.outputs import OutputKeys
from modelscope.pipelines.base import Input
from modelscope.preprocessors import ImageDenoisePreprocessor, LoadImage
from modelscope.utils.constant import Tasks
from modelscope.utils.logger import get_logger
from ..base import Pipeline
from ..builder import PIPELINES

logger = get_logger()

__all__ = ['ImageDenoisePipeline']


@PIPELINES.register_module(
    Tasks.image_denoise, module_name=Pipelines.image_denoise)
class ImageDenoisePipeline(Pipeline):

    def __init__(self,
                 model: Union[NAFNetForImageDenoise, str],
                 preprocessor: Optional[ImageDenoisePreprocessor] = None,
                 **kwargs):
        """Use `model` and `preprocessor` to create a cv image denoise pipeline for prediction.

        Args:
            model: a model instance or a model id on the ModelScope hub.
        """
        model = model if isinstance(
            model, NAFNetForImageDenoise) else Model.from_pretrained(model)
        model.eval()
        super().__init__(model=model, preprocessor=preprocessor, **kwargs)
        self.config = model.config
        if torch.cuda.is_available():
            self._device = torch.device('cuda')
        else:
            self._device = torch.device('cpu')
        self.model = model
        logger.info('load image denoise model done')

    def preprocess(self, input: Input) -> Dict[str, Any]:
        img = LoadImage.convert_to_img(input)
        test_transforms = transforms.Compose([transforms.ToTensor()])
        img = test_transforms(img)
        result = {'img': img.unsqueeze(0).to(self._device)}
        return result

    def crop_process(self, input):
        output = torch.zeros_like(input)  # [1, C, H, W]
        # split the full-resolution image into a grid of roughly 512px tiles,
        # run each tile (plus a 16px overlap of context) through the model and
        # write the central crop of each result back into the output tensor
        ih, iw = input.shape[-2:]
        crop_rows, crop_cols = max(ih // 512, 1), max(iw // 512, 1)
        overlap = 16
        step_h, step_w = ih // crop_rows, iw // crop_cols
        for y in range(crop_rows):
            for x in range(crop_cols):
                crop_y = step_h * y
                crop_x = step_w * x
                crop_h = step_h if y < crop_rows - 1 else ih - crop_y
                crop_w = step_w if x < crop_cols - 1 else iw - crop_x
                crop_frames = input[:, :,
                                    max(0, crop_y - overlap
                                        ):min(crop_y + crop_h + overlap, ih),
                                    max(0, crop_x - overlap
                                        ):min(crop_x + crop_w
                                              + overlap, iw)].contiguous()
                h_start = overlap if max(0, crop_y - overlap) > 0 else 0
                w_start = overlap if max(0, crop_x - overlap) > 0 else 0
                h_end = h_start + crop_h if min(crop_y + crop_h
                                                + overlap, ih) < ih else ih
                w_end = w_start + crop_w if min(crop_x + crop_w
                                                + overlap, iw) < iw else iw
                output[:, :, crop_y:crop_y + crop_h,
                       crop_x:crop_x + crop_w] = self.model._inference_forward(
                           crop_frames)['outputs'][:, :, h_start:h_end,
                                                   w_start:w_end]
        return output

    def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:

        def set_phase(model, is_train):
            if is_train:
                model.train()
            else:
                model.eval()

        is_train = False
        set_phase(self.model, is_train)
        with torch.no_grad():
            output = self.crop_process(input['img'])  # output Tensor
        return {'output_tensor': output}

    def postprocess(self, input: Dict[str, Any]) -> Dict[str, Any]:
        output_img = (input['output_tensor'].squeeze(0) * 255).cpu().permute(
            1, 2, 0).numpy().astype('uint8')
        return {OutputKeys.OUTPUT_IMG: output_img}
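End-to-end, the pipeline is reached through the standard pipeline() factory; a short usage sketch matching the unit test further below:

from modelscope.outputs import OutputKeys
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

denoise = pipeline(Tasks.image_denoise, model='damo/cv_nafnet_image-denoise_sidd')
result = denoise('data/test/images/noisy-demo-1.png')
denoised = result[OutputKeys.OUTPUT_IMG]  # HWC uint8 ndarray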
@@ -21,6 +21,7 @@ try:
    from .space.dialog_state_tracking_preprocessor import *  # noqa F403
    from .image import ImageColorEnhanceFinetunePreprocessor
    from .image import ImageInstanceSegmentationPreprocessor
    from .image import ImageDenoisePreprocessor
except ModuleNotFoundError as e:
    if str(e) == "No module named 'tensorflow'":
        print(TENSORFLOW_IMPORT_ERROR.format('tts'))
@@ -138,6 +138,31 @@ class ImageColorEnhanceFinetunePreprocessor(Preprocessor):

        return data


@PREPROCESSORS.register_module(
    Fields.cv, module_name=Preprocessors.image_denoise_preprocessor)
class ImageDenoisePreprocessor(Preprocessor):

    def __init__(self, model_dir: str, *args, **kwargs):
        """
        Args:
            model_dir (str): model path
        """
        super().__init__(*args, **kwargs)
        self.model_dir: str = model_dir

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Process the raw input data.

        Args:
            data (Dict[str, Any]): the raw input data

        Returns:
            Dict[str, Any]: the preprocessed data
        """
        return data


@PREPROCESSORS.register_module(
    Fields.cv,
    module_name=Preprocessors.image_instance_segmentation_preprocessor)
@@ -24,6 +24,7 @@ class CVTasks(object):
    image_editing = 'image-editing'
    image_generation = 'image-generation'
    image_matting = 'image-matting'
    image_denoise = 'image-denoise'
    ocr_detection = 'ocr-detection'
    action_recognition = 'action-recognition'
    video_embedding = 'video-embedding'
@@ -0,0 +1,59 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import unittest

from PIL import Image

from modelscope.hub.snapshot_download import snapshot_download
from modelscope.models import Model
from modelscope.outputs import OutputKeys
from modelscope.pipelines import ImageDenoisePipeline, pipeline
from modelscope.utils.constant import Tasks
from modelscope.utils.test_utils import test_level


class ImageDenoiseTest(unittest.TestCase):
    model_id = 'damo/cv_nafnet_image-denoise_sidd'
    demo_image_path = 'data/test/images/noisy-demo-1.png'

    @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
    def test_run_by_direct_model_download(self):
        cache_path = snapshot_download(self.model_id)
        pipeline = ImageDenoisePipeline(cache_path)
        denoise_img = pipeline(
            input=self.demo_image_path)[OutputKeys.OUTPUT_IMG]
        denoise_img = Image.fromarray(denoise_img)
        w, h = denoise_img.size
        print('pipeline: the shape of output_img is {}x{}'.format(h, w))

    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
    def test_run_with_model_from_modelhub(self):
        model = Model.from_pretrained(self.model_id)
        pipeline_ins = pipeline(task=Tasks.image_denoise, model=model)
        denoise_img = pipeline_ins(
            input=self.demo_image_path)[OutputKeys.OUTPUT_IMG]
        denoise_img = Image.fromarray(denoise_img)
        w, h = denoise_img.size
        print('pipeline: the shape of output_img is {}x{}'.format(h, w))

    @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
    def test_run_with_model_name(self):
        pipeline_ins = pipeline(task=Tasks.image_denoise, model=self.model_id)
        denoise_img = pipeline_ins(
            input=self.demo_image_path)[OutputKeys.OUTPUT_IMG]
        denoise_img = Image.fromarray(denoise_img)
        w, h = denoise_img.size
        print('pipeline: the shape of output_img is {}x{}'.format(h, w))

    @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
    def test_run_with_default_model(self):
        pipeline_ins = pipeline(task=Tasks.image_denoise)
        denoise_img = pipeline_ins(
            input=self.demo_image_path)[OutputKeys.OUTPUT_IMG]
        denoise_img = Image.fromarray(denoise_img)
        w, h = denoise_img.size
        print('pipeline: the shape of output_img is {}x{}'.format(h, w))


if __name__ == '__main__':
    unittest.main()
@@ -0,0 +1,74 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os
import shutil
import tempfile
import unittest

from modelscope.hub.snapshot_download import snapshot_download
from modelscope.models import NAFNetForImageDenoise
from modelscope.msdatasets.image_denoise_data.image_denoise_dataset import \
    PairedImageDataset
from modelscope.trainers import build_trainer
from modelscope.utils.config import Config
from modelscope.utils.constant import ModelFile
from modelscope.utils.logger import get_logger
from modelscope.utils.test_utils import test_level

logger = get_logger()


class ImageDenoiseTrainerTest(unittest.TestCase):

    def setUp(self):
        print(('Testing %s.%s' % (type(self).__name__, self._testMethodName)))
        self.tmp_dir = tempfile.TemporaryDirectory().name
        if not os.path.exists(self.tmp_dir):
            os.makedirs(self.tmp_dir)
        self.model_id = 'damo/cv_nafnet_image-denoise_sidd'
        self.cache_path = snapshot_download(self.model_id)
        self.config = Config.from_file(
            os.path.join(self.cache_path, ModelFile.CONFIGURATION))
        self.dataset_train = PairedImageDataset(
            self.config.dataset, self.cache_path, is_train=True)
        self.dataset_val = PairedImageDataset(
            self.config.dataset, self.cache_path, is_train=False)

    def tearDown(self):
        shutil.rmtree(self.tmp_dir, ignore_errors=True)
        super().tearDown()

    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
    def test_trainer(self):
        kwargs = dict(
            model=self.model_id,
            train_dataset=self.dataset_train,
            eval_dataset=self.dataset_val,
            work_dir=self.tmp_dir)
        trainer = build_trainer(default_args=kwargs)
        trainer.train()
        results_files = os.listdir(self.tmp_dir)
        self.assertIn(f'{trainer.timestamp}.log.json', results_files)
        for i in range(2):
            self.assertIn(f'epoch_{i+1}.pth', results_files)

    @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
    def test_trainer_with_model_and_args(self):
        model = NAFNetForImageDenoise.from_pretrained(self.cache_path)
        kwargs = dict(
            cfg_file=os.path.join(self.cache_path, ModelFile.CONFIGURATION),
            model=model,
            train_dataset=self.dataset_train,
            eval_dataset=self.dataset_val,
            max_epochs=2,
            work_dir=self.tmp_dir)
        trainer = build_trainer(default_args=kwargs)
        trainer.train()
        results_files = os.listdir(self.tmp_dir)
        self.assertIn(f'{trainer.timestamp}.log.json', results_files)
        for i in range(2):
            self.assertIn(f'epoch_{i+1}.pth', results_files)


if __name__ == '__main__':
    unittest.main()