diff --git a/modelscope/metrics/image_denoise_metric.py b/modelscope/metrics/image_denoise_metric.py index 94ec9dc7..c6df8df1 100644 --- a/modelscope/metrics/image_denoise_metric.py +++ b/modelscope/metrics/image_denoise_metric.py @@ -1,7 +1,9 @@ +# The code is modified based on BasicSR metrics: +# https://github.com/XPixelGroup/BasicSR/blob/master/basicsr/metrics/psnr_ssim.py from typing import Dict +import cv2 import numpy as np -from skimage.metrics import peak_signal_noise_ratio, structural_similarity from modelscope.metainfo import Metrics from modelscope.utils.registry import default_group @@ -34,12 +36,138 @@ class ImageDenoiseMetric(Metric): def evaluate(self): psnr_list, ssim_list = [], [] for (pred, label) in zip(self.preds, self.labels): - psnr_list.append( - peak_signal_noise_ratio(label[0], pred[0], data_range=255)) - ssim_list.append( - structural_similarity( - label[0], pred[0], multichannel=True, data_range=255)) + psnr_list.append(calculate_psnr(label[0], pred[0], crop_border=0)) + ssim_list.append(calculate_ssim(label[0], pred[0], crop_border=0)) return { MetricKeys.PSNR: np.mean(psnr_list), MetricKeys.SSIM: np.mean(ssim_list) } + + +def reorder_image(img, input_order='HWC'): + """Reorder images to 'HWC' order. + If the input_order is (h, w), return (h, w, 1); + If the input_order is (c, h, w), return (h, w, c); + If the input_order is (h, w, c), return as it is. + Args: + img (ndarray): Input image. + input_order (str): Whether the input order is 'HWC' or 'CHW'. + If the input image shape is (h, w), input_order will not have + effects. Default: 'HWC'. + Returns: + ndarray: reordered image. + """ + + if input_order not in ['HWC', 'CHW']: + raise ValueError( + f"Wrong input_order {input_order}. Supported input_orders are 'HWC' and 'CHW'" + ) + if len(img.shape) == 2: + img = img[..., None] + if input_order == 'CHW': + img = img.transpose(1, 2, 0) + return img + + +def calculate_psnr(img, img2, crop_border, input_order='HWC', **kwargs): + """Calculate PSNR (Peak Signal-to-Noise Ratio). + Reference: https://en.wikipedia.org/wiki/Peak_signal-to-noise_ratio + Args: + img (ndarray): Images with range [0, 255]. + img2 (ndarray): Images with range [0, 255]. + crop_border (int): Cropped pixels in each edge of an image. These pixels are not involved in the calculation. + input_order (str): Whether the input order is 'HWC' or 'CHW'. Default: 'HWC'. + Returns: + float: PSNR result. + """ + + assert img.shape == img2.shape, ( + f'Image shapes are different: {img.shape}, {img2.shape}.') + if input_order not in ['HWC', 'CHW']: + raise ValueError( + f'Wrong input_order {input_order}. Supported input_orders are "HWC" and "CHW"' + ) + img = reorder_image(img, input_order=input_order) + img2 = reorder_image(img2, input_order=input_order) + + if crop_border != 0: + img = img[crop_border:-crop_border, crop_border:-crop_border, ...] + img2 = img2[crop_border:-crop_border, crop_border:-crop_border, ...] + + img = img.astype(np.float64) + img2 = img2.astype(np.float64) + + mse = np.mean((img - img2)**2) + if mse == 0: + return float('inf') + return 10. * np.log10(255. * 255. / mse) + + +def calculate_ssim(img, img2, crop_border, input_order='HWC', **kwargs): + """Calculate SSIM (structural similarity). + ``Paper: Image quality assessment: From error visibility to structural similarity`` + The results are the same as that of the official released MATLAB code in + https://ece.uwaterloo.ca/~z70wang/research/ssim/. + For three-channel images, SSIM is calculated for each channel and then + averaged. + Args: + img (ndarray): Images with range [0, 255]. + img2 (ndarray): Images with range [0, 255]. + crop_border (int): Cropped pixels in each edge of an image. These pixels are not involved in the calculation. + input_order (str): Whether the input order is 'HWC' or 'CHW'. + Default: 'HWC'. + Returns: + float: SSIM result. + """ + + assert img.shape == img2.shape, ( + f'Image shapes are different: {img.shape}, {img2.shape}.') + if input_order not in ['HWC', 'CHW']: + raise ValueError( + f'Wrong input_order {input_order}. Supported input_orders are "HWC" and "CHW"' + ) + img = reorder_image(img, input_order=input_order) + img2 = reorder_image(img2, input_order=input_order) + + if crop_border != 0: + img = img[crop_border:-crop_border, crop_border:-crop_border, ...] + img2 = img2[crop_border:-crop_border, crop_border:-crop_border, ...] + + img = img.astype(np.float64) + img2 = img2.astype(np.float64) + + ssims = [] + for i in range(img.shape[2]): + ssims.append(_ssim(img[..., i], img2[..., i])) + return np.array(ssims).mean() + + +def _ssim(img, img2): + """Calculate SSIM (structural similarity) for one channel images. + It is called by func:`calculate_ssim`. + Args: + img (ndarray): Images with range [0, 255] with order 'HWC'. + img2 (ndarray): Images with range [0, 255] with order 'HWC'. + Returns: + float: SSIM result. + """ + + c1 = (0.01 * 255)**2 + c2 = (0.03 * 255)**2 + kernel = cv2.getGaussianKernel(11, 1.5) + window = np.outer(kernel, kernel.transpose()) + + mu1 = cv2.filter2D(img, -1, window)[5:-5, + 5:-5] # valid mode for window size 11 + mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5] + mu1_sq = mu1**2 + mu2_sq = mu2**2 + mu1_mu2 = mu1 * mu2 + sigma1_sq = cv2.filter2D(img**2, -1, window)[5:-5, 5:-5] - mu1_sq + sigma2_sq = cv2.filter2D(img2**2, -1, window)[5:-5, 5:-5] - mu2_sq + sigma12 = cv2.filter2D(img * img2, -1, window)[5:-5, 5:-5] - mu1_mu2 + + tmp1 = (2 * mu1_mu2 + c1) * (2 * sigma12 + c2) + tmp2 = (mu1_sq + mu2_sq + c1) * (sigma1_sq + sigma2_sq + c2) + ssim_map = tmp1 / tmp2 + return ssim_map.mean() diff --git a/modelscope/models/cv/image_denoise/nafnet/NAFNet_arch.py b/modelscope/models/cv/image_denoise/nafnet/NAFNet_arch.py index 5b4e8ce1..c4de0729 100644 --- a/modelscope/models/cv/image_denoise/nafnet/NAFNet_arch.py +++ b/modelscope/models/cv/image_denoise/nafnet/NAFNet_arch.py @@ -1,3 +1,8 @@ +# ------------------------------------------------------------------------ +# Modified from https://github.com/megvii-research/NAFNet/blob/main/basicsr/models/archs/NAFNet_arch.py +# Copyright (c) 2022 megvii-model. All Rights Reserved. +# ------------------------------------------------------------------------ + import numpy as np import torch import torch.nn as nn diff --git a/modelscope/models/cv/image_denoise/nafnet/arch_util.py b/modelscope/models/cv/image_denoise/nafnet/arch_util.py index df394dd5..2d406141 100644 --- a/modelscope/models/cv/image_denoise/nafnet/arch_util.py +++ b/modelscope/models/cv/image_denoise/nafnet/arch_util.py @@ -1,3 +1,8 @@ +# ------------------------------------------------------------------------ +# Modified from BasicSR (https://github.com/xinntao/BasicSR) +# Copyright 2018-2020 BasicSR Authors +# ------------------------------------------------------------------------ + import torch import torch.nn as nn diff --git a/modelscope/models/cv/image_denoise/nafnet_for_image_denoise.py b/modelscope/models/cv/image_denoise/nafnet_for_image_denoise.py index c484b37b..a6fbf22f 100644 --- a/modelscope/models/cv/image_denoise/nafnet_for_image_denoise.py +++ b/modelscope/models/cv/image_denoise/nafnet_for_image_denoise.py @@ -1,3 +1,4 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. import os from copy import deepcopy from typing import Any, Dict, Union diff --git a/modelscope/msdatasets/image_denoise_data/data_utils.py b/modelscope/msdatasets/image_denoise_data/data_utils.py deleted file mode 100644 index dd735830..00000000 --- a/modelscope/msdatasets/image_denoise_data/data_utils.py +++ /dev/null @@ -1,152 +0,0 @@ -# ------------------------------------------------------------------------ -# Modified from BasicSR (https://github.com/xinntao/BasicSR) -# Copyright 2018-2020 BasicSR Authors -# ------------------------------------------------------------------------ -import os -from os import path as osp - -import cv2 -import numpy as np -import torch - -from .transforms import mod_crop - - -def img2tensor(imgs, bgr2rgb=True, float32=True): - """Numpy array to tensor. - Args: - imgs (list[ndarray] | ndarray): Input images. - bgr2rgb (bool): Whether to change bgr to rgb. - float32 (bool): Whether to change to float32. - Returns: - list[tensor] | tensor: Tensor images. If returned results only have - one element, just return tensor. - """ - - def _totensor(img, bgr2rgb, float32): - if img.shape[2] == 3 and bgr2rgb: - img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) - img = torch.from_numpy(img.transpose(2, 0, 1)) - if float32: - img = img.float() - return img - - if isinstance(imgs, list): - return [_totensor(img, bgr2rgb, float32) for img in imgs] - else: - return _totensor(imgs, bgr2rgb, float32) - - -def scandir(dir_path, keyword=None, recursive=False, full_path=False): - """Scan a directory to find the interested files. - Args: - dir_path (str): Path of the directory. - keyword (str | tuple(str), optional): File keyword that we are - interested in. Default: None. - recursive (bool, optional): If set to True, recursively scan the - directory. Default: False. - full_path (bool, optional): If set to True, include the dir_path. - Default: False. - Returns: - A generator for all the interested files with relative pathes. - """ - - if (keyword is not None) and not isinstance(keyword, (str, tuple)): - raise TypeError('"suffix" must be a string or tuple of strings') - - root = dir_path - - def _scandir(dir_path, keyword, recursive): - for entry in os.scandir(dir_path): - if not entry.name.startswith('.') and entry.is_file(): - if full_path: - return_path = entry.path - else: - return_path = osp.relpath(entry.path, root) - - if keyword is None: - yield return_path - elif keyword in return_path: - yield return_path - else: - if recursive: - yield from _scandir( - entry.path, keyword=keyword, recursive=recursive) - else: - continue - - return _scandir(dir_path, keyword=keyword, recursive=recursive) - - -def padding(img_lq, img_gt, gt_size): - h, w, _ = img_lq.shape - - h_pad = max(0, gt_size - h) - w_pad = max(0, gt_size - w) - - if h_pad == 0 and w_pad == 0: - return img_lq, img_gt - - img_lq = cv2.copyMakeBorder(img_lq, 0, h_pad, 0, w_pad, cv2.BORDER_REFLECT) - img_gt = cv2.copyMakeBorder(img_gt, 0, h_pad, 0, w_pad, cv2.BORDER_REFLECT) - return img_lq, img_gt - - -def read_img_seq(path, require_mod_crop=False, scale=1): - """Read a sequence of images from a given folder path. - Args: - path (list[str] | str): List of image paths or image folder path. - require_mod_crop (bool): Require mod crop for each image. - Default: False. - scale (int): Scale factor for mod_crop. Default: 1. - Returns: - Tensor: size (t, c, h, w), RGB, [0, 1]. - """ - if isinstance(path, list): - img_paths = path - else: - img_paths = sorted(list(scandir(path, full_path=True))) - imgs = [cv2.imread(v).astype(np.float32) / 255. for v in img_paths] - if require_mod_crop: - imgs = [mod_crop(img, scale) for img in imgs] - imgs = img2tensor(imgs, bgr2rgb=True, float32=True) - imgs = torch.stack(imgs, dim=0) - return imgs - - -def paired_paths_from_folder(folders, keys, filename_tmpl): - """Generate paired paths from folders. - Args: - folders (list[str]): A list of folder path. The order of list should - be [input_folder, gt_folder]. - keys (list[str]): A list of keys identifying folders. The order should - be in consistent with folders, e.g., ['lq', 'gt']. - filename_tmpl (str): Template for each filename. Note that the - template excludes the file extension. Usually the filename_tmpl is - for files in the input folder. - Returns: - list[str]: Returned path list. - """ - assert len(folders) == 2, ( - 'The len of folders should be 2 with [input_folder, gt_folder]. ' - f'But got {len(folders)}') - assert len(keys) == 2, ( - 'The len of keys should be 2 with [input_key, gt_key]. ' - f'But got {len(keys)}') - input_folder, gt_folder = folders - input_key, gt_key = keys - - input_paths = list(scandir(input_folder, keyword='NOISY', recursive=True)) - gt_paths = list(scandir(gt_folder, keyword='GT', recursive=True)) - assert len(input_paths) == len(gt_paths), ( - f'{input_key} and {gt_key} datasets have different number of images: ' - f'{len(input_paths)}, {len(gt_paths)}.') - paths = [] - for idx in range(len(gt_paths)): - gt_path = os.path.join(gt_folder, gt_paths[idx]) - input_path = os.path.join(input_folder, gt_path.replace('GT', 'NOISY')) - - paths.append( - dict([(f'{input_key}_path', input_path), - (f'{gt_key}_path', gt_path)])) - return paths diff --git a/modelscope/msdatasets/image_denoise_data/image_denoise_dataset.py b/modelscope/msdatasets/image_denoise_data/image_denoise_dataset.py deleted file mode 100644 index 96b777e6..00000000 --- a/modelscope/msdatasets/image_denoise_data/image_denoise_dataset.py +++ /dev/null @@ -1,78 +0,0 @@ -import os -from typing import Callable, List, Optional, Tuple, Union - -import cv2 -import numpy as np -from torch.utils import data - -from .data_utils import img2tensor, padding, paired_paths_from_folder -from .transforms import augment, paired_random_crop - - -def default_loader(path): - return cv2.imread(path, cv2.IMREAD_UNCHANGED).astype(np.float32) / 255.0 - - -class PairedImageDataset(data.Dataset): - """Paired image dataset for image restoration. - """ - - def __init__(self, opt, root, is_train): - super(PairedImageDataset, self).__init__() - self.opt = opt - self.is_train = is_train - self.gt_folder, self.lq_folder = os.path.join( - root, opt.dataroot_gt), os.path.join(root, opt.dataroot_lq) - - if opt.filename_tmpl is not None: - self.filename_tmpl = opt.filename_tmpl - else: - self.filename_tmpl = '{}' - self.paths = paired_paths_from_folder([self.lq_folder, self.gt_folder], - ['lq', 'gt'], self.filename_tmpl) - - def __getitem__(self, index): - scale = self.opt.scale - - # Load gt and lq images. Dimension order: HWC; channel order: BGR; - # image range: [0, 1], float32. - gt_path = self.paths[index]['gt_path'] - img_gt = default_loader(gt_path) - lq_path = self.paths[index]['lq_path'] - img_lq = default_loader(lq_path) - - # augmentation for training - # if self.is_train: - gt_size = self.opt.gt_size - # padding - img_gt, img_lq = padding(img_gt, img_lq, gt_size) - - # random crop - img_gt, img_lq = paired_random_crop(img_gt, img_lq, gt_size, scale) - - # flip, rotation - img_gt, img_lq = augment([img_gt, img_lq], self.opt.use_flip, - self.opt.use_rot) - - # BGR to RGB, HWC to CHW, numpy to tensor - img_gt, img_lq = img2tensor([img_gt, img_lq], - bgr2rgb=True, - float32=True) - - return { - 'input': img_lq, - 'target': img_gt, - 'input_path': lq_path, - 'target_path': gt_path - } - - def __len__(self): - return len(self.paths) - - def to_torch_dataset( - self, - columns: Union[str, List[str]] = None, - preprocessors: Union[Callable, List[Callable]] = None, - **format_kwargs, - ): - return self diff --git a/modelscope/msdatasets/image_denoise_data/__init__.py b/modelscope/msdatasets/task_datasets/sidd_image_denoising/__init__.py similarity index 73% rename from modelscope/msdatasets/image_denoise_data/__init__.py rename to modelscope/msdatasets/task_datasets/sidd_image_denoising/__init__.py index ba1d2df8..5376cd7c 100644 --- a/modelscope/msdatasets/image_denoise_data/__init__.py +++ b/modelscope/msdatasets/task_datasets/sidd_image_denoising/__init__.py @@ -4,11 +4,11 @@ from typing import TYPE_CHECKING from modelscope.utils.import_utils import LazyImportModule if TYPE_CHECKING: - from .image_denoise_dataset import PairedImageDataset + from .sidd_image_denoising_dataset import SiddImageDenoisingDataset else: _import_structure = { - 'image_denoise_dataset': ['PairedImageDataset'], + 'sidd_image_denoising_dataset': ['SiddImageDenoisingDataset'], } import sys diff --git a/modelscope/msdatasets/task_datasets/sidd_image_denoising/data_utils.py b/modelscope/msdatasets/task_datasets/sidd_image_denoising/data_utils.py new file mode 100644 index 00000000..33fce4c8 --- /dev/null +++ b/modelscope/msdatasets/task_datasets/sidd_image_denoising/data_utils.py @@ -0,0 +1,46 @@ +# ------------------------------------------------------------------------ +# Modified from BasicSR (https://github.com/xinntao/BasicSR) +# Copyright 2018-2020 BasicSR Authors +# ------------------------------------------------------------------------ + +import cv2 +import torch + + +def img2tensor(imgs, bgr2rgb=True, float32=True): + """Numpy array to tensor. + Args: + imgs (list[ndarray] | ndarray): Input images. + bgr2rgb (bool): Whether to change bgr to rgb. + float32 (bool): Whether to change to float32. + Returns: + list[tensor] | tensor: Tensor images. If returned results only have + one element, just return tensor. + """ + + def _totensor(img, bgr2rgb, float32): + if img.shape[2] == 3 and bgr2rgb: + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) + img = torch.from_numpy(img.transpose(2, 0, 1)) + if float32: + img = img.float() + return img + + if isinstance(imgs, list): + return [_totensor(img, bgr2rgb, float32) for img in imgs] + else: + return _totensor(imgs, bgr2rgb, float32) + + +def padding(img_lq, img_gt, gt_size): + h, w, _ = img_lq.shape + + h_pad = max(0, gt_size - h) + w_pad = max(0, gt_size - w) + + if h_pad == 0 and w_pad == 0: + return img_lq, img_gt + + img_lq = cv2.copyMakeBorder(img_lq, 0, h_pad, 0, w_pad, cv2.BORDER_REFLECT) + img_gt = cv2.copyMakeBorder(img_gt, 0, h_pad, 0, w_pad, cv2.BORDER_REFLECT) + return img_lq, img_gt diff --git a/modelscope/msdatasets/task_datasets/sidd_image_denoising/sidd_image_denoising_dataset.py b/modelscope/msdatasets/task_datasets/sidd_image_denoising/sidd_image_denoising_dataset.py new file mode 100644 index 00000000..3f0cdae0 --- /dev/null +++ b/modelscope/msdatasets/task_datasets/sidd_image_denoising/sidd_image_denoising_dataset.py @@ -0,0 +1,62 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +import cv2 +import numpy as np + +from modelscope.metainfo import Models +from modelscope.msdatasets.task_datasets.builder import TASK_DATASETS +from modelscope.msdatasets.task_datasets.torch_base_dataset import \ + TorchTaskDataset +from modelscope.utils.constant import Tasks +from .data_utils import img2tensor, padding +from .transforms import augment, paired_random_crop + + +def default_loader(path): + return cv2.imread(path, cv2.IMREAD_UNCHANGED).astype(np.float32) / 255.0 + + +@TASK_DATASETS.register_module( + Tasks.image_denoising, module_name=Models.nafnet) +class SiddImageDenoisingDataset(TorchTaskDataset): + """Paired image dataset for image restoration. + """ + + def __init__(self, dataset, opt, is_train): + self.dataset = dataset + self.opt = opt + self.is_train = is_train + + def __len__(self): + return len(self.dataset) + + def __getitem__(self, index): + + # Load gt and lq images. Dimension order: HWC; channel order: BGR; + # image range: [0, 1], float32. + item_dict = self.dataset[index] + gt_path = item_dict['Clean Image:FILE'] + img_gt = default_loader(gt_path) + lq_path = item_dict['Noisy Image:FILE'] + img_lq = default_loader(lq_path) + + # augmentation for training + if self.is_train: + gt_size = self.opt.gt_size + # padding + img_gt, img_lq = padding(img_gt, img_lq, gt_size) + + # random crop + img_gt, img_lq = paired_random_crop( + img_gt, img_lq, gt_size, scale=1) + + # flip, rotation + img_gt, img_lq = augment([img_gt, img_lq], self.opt.use_flip, + self.opt.use_rot) + + # BGR to RGB, HWC to CHW, numpy to tensor + img_gt, img_lq = img2tensor([img_gt, img_lq], + bgr2rgb=True, + float32=True) + + return {'input': img_lq, 'target': img_gt} diff --git a/modelscope/msdatasets/image_denoise_data/transforms.py b/modelscope/msdatasets/task_datasets/sidd_image_denoising/transforms.py similarity index 100% rename from modelscope/msdatasets/image_denoise_data/transforms.py rename to modelscope/msdatasets/task_datasets/sidd_image_denoising/transforms.py diff --git a/modelscope/pipelines/cv/image_denoise_pipeline.py b/modelscope/pipelines/cv/image_denoise_pipeline.py index a11abf36..34ac1e81 100644 --- a/modelscope/pipelines/cv/image_denoise_pipeline.py +++ b/modelscope/pipelines/cv/image_denoise_pipeline.py @@ -105,4 +105,4 @@ class ImageDenoisePipeline(Pipeline): def postprocess(self, input: Dict[str, Any]) -> Dict[str, Any]: output_img = (input['output_tensor'].squeeze(0) * 255).cpu().permute( 1, 2, 0).numpy().astype('uint8') - return {OutputKeys.OUTPUT_IMG: output_img} + return {OutputKeys.OUTPUT_IMG: output_img[:, :, ::-1]} diff --git a/tests/pipelines/test_image_denoise.py b/tests/pipelines/test_image_denoise.py index bf8cfd0f..d95dd343 100644 --- a/tests/pipelines/test_image_denoise.py +++ b/tests/pipelines/test_image_denoise.py @@ -2,8 +2,6 @@ import unittest -from PIL import Image - from modelscope.hub.snapshot_download import snapshot_download from modelscope.models import Model from modelscope.outputs import OutputKeys @@ -20,16 +18,16 @@ class ImageDenoiseTest(unittest.TestCase, DemoCompatibilityCheck): self.task = Tasks.image_denoising self.model_id = 'damo/cv_nafnet_image-denoise_sidd' - demo_image_path = 'data/test/images/noisy-demo-1.png' + demo_image_path = 'https://modelscope.oss-cn-beijing.aliyuncs.com/test/images/noisy-demo-0.png' @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') def test_run_by_direct_model_download(self): cache_path = snapshot_download(self.model_id) pipeline = ImageDenoisePipeline(cache_path) + pipeline.group_key = self.task denoise_img = pipeline( - input=self.demo_image_path)[OutputKeys.OUTPUT_IMG] - denoise_img = Image.fromarray(denoise_img) - w, h = denoise_img.size + input=self.demo_image_path)[OutputKeys.OUTPUT_IMG] # BGR + h, w = denoise_img.shape[:2] print('pipeline: the shape of output_img is {}x{}'.format(h, w)) @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') @@ -37,9 +35,8 @@ class ImageDenoiseTest(unittest.TestCase, DemoCompatibilityCheck): model = Model.from_pretrained(self.model_id) pipeline_ins = pipeline(task=Tasks.image_denoising, model=model) denoise_img = pipeline_ins( - input=self.demo_image_path)[OutputKeys.OUTPUT_IMG] - denoise_img = Image.fromarray(denoise_img) - w, h = denoise_img.size + input=self.demo_image_path)[OutputKeys.OUTPUT_IMG] # BGR + h, w = denoise_img.shape[:2] print('pipeline: the shape of output_img is {}x{}'.format(h, w)) @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') @@ -47,18 +44,16 @@ class ImageDenoiseTest(unittest.TestCase, DemoCompatibilityCheck): pipeline_ins = pipeline( task=Tasks.image_denoising, model=self.model_id) denoise_img = pipeline_ins( - input=self.demo_image_path)[OutputKeys.OUTPUT_IMG] - denoise_img = Image.fromarray(denoise_img) - w, h = denoise_img.size + input=self.demo_image_path)[OutputKeys.OUTPUT_IMG] # BGR + h, w = denoise_img.shape[:2] print('pipeline: the shape of output_img is {}x{}'.format(h, w)) @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') def test_run_with_default_model(self): pipeline_ins = pipeline(task=Tasks.image_denoising) denoise_img = pipeline_ins( - input=self.demo_image_path)[OutputKeys.OUTPUT_IMG] - denoise_img = Image.fromarray(denoise_img) - w, h = denoise_img.size + input=self.demo_image_path)[OutputKeys.OUTPUT_IMG] # BGR + h, w = denoise_img.shape[:2] print('pipeline: the shape of output_img is {}x{}'.format(h, w)) @unittest.skip('demo compatibility test is only enabled on a needed-basis') diff --git a/tests/trainers/test_image_denoise_trainer.py b/tests/trainers/test_image_denoise_trainer.py index 261ee4ed..0bcb8930 100644 --- a/tests/trainers/test_image_denoise_trainer.py +++ b/tests/trainers/test_image_denoise_trainer.py @@ -6,10 +6,12 @@ import unittest from modelscope.hub.snapshot_download import snapshot_download from modelscope.models.cv.image_denoise import NAFNetForImageDenoise -from modelscope.msdatasets.image_denoise_data import PairedImageDataset +from modelscope.msdatasets import MsDataset +from modelscope.msdatasets.task_datasets.sidd_image_denoising import \ + SiddImageDenoisingDataset from modelscope.trainers import build_trainer from modelscope.utils.config import Config -from modelscope.utils.constant import ModelFile +from modelscope.utils.constant import DownloadMode, ModelFile from modelscope.utils.logger import get_logger from modelscope.utils.test_utils import test_level @@ -28,10 +30,20 @@ class ImageDenoiseTrainerTest(unittest.TestCase): self.cache_path = snapshot_download(self.model_id) self.config = Config.from_file( os.path.join(self.cache_path, ModelFile.CONFIGURATION)) - self.dataset_train = PairedImageDataset( - self.config.dataset, self.cache_path, is_train=True) - self.dataset_val = PairedImageDataset( - self.config.dataset, self.cache_path, is_train=False) + dataset_train = MsDataset.load( + 'SIDD', + namespace='huizheng', + split='validation', + download_mode=DownloadMode.REUSE_DATASET_IF_EXISTS)._hf_ds + dataset_val = MsDataset.load( + 'SIDD', + namespace='huizheng', + split='test', + download_mode=DownloadMode.REUSE_DATASET_IF_EXISTS)._hf_ds + self.dataset_train = SiddImageDenoisingDataset( + dataset_train, self.config.dataset, is_train=True) + self.dataset_val = SiddImageDenoisingDataset( + dataset_val, self.config.dataset, is_train=False) def tearDown(self): shutil.rmtree(self.tmp_dir, ignore_errors=True)