diff --git a/modelscope/metrics/image_denoise_metric.py b/modelscope/metrics/image_denoise_metric.py index c6df8df1..1692f299 100644 --- a/modelscope/metrics/image_denoise_metric.py +++ b/modelscope/metrics/image_denoise_metric.py @@ -1,14 +1,16 @@ -# The code is modified based on BasicSR metrics: -# https://github.com/XPixelGroup/BasicSR/blob/master/basicsr/metrics/psnr_ssim.py +# ------------------------------------------------------------------------ +# Copyright (c) Alibaba, Inc. and its affiliates. +# ------------------------------------------------------------------------ +# modified from https://github.com/megvii-research/NAFNet/blob/main/basicsr/metrics/psnr_ssim.py +# ------------------------------------------------------------------------ from typing import Dict import cv2 import numpy as np +import torch from modelscope.metainfo import Metrics from modelscope.utils.registry import default_group -from modelscope.utils.tensor_utils import (torch_nested_detach, - torch_nested_numpify) from .base import Metric from .builder import METRICS, MetricKeys @@ -22,16 +24,15 @@ class ImageDenoiseMetric(Metric): label_name = 'target' def __init__(self): + super(ImageDenoiseMetric, self).__init__() self.preds = [] self.labels = [] def add(self, outputs: Dict, inputs: Dict): ground_truths = outputs[ImageDenoiseMetric.label_name] eval_results = outputs[ImageDenoiseMetric.pred_name] - self.preds.append( - torch_nested_numpify(torch_nested_detach(eval_results))) - self.labels.append( - torch_nested_numpify(torch_nested_detach(ground_truths))) + self.preds.append(eval_results) + self.labels.append(ground_truths) def evaluate(self): psnr_list, ssim_list = [], [] @@ -69,80 +70,117 @@ def reorder_image(img, input_order='HWC'): return img -def calculate_psnr(img, img2, crop_border, input_order='HWC', **kwargs): +def calculate_psnr(img1, img2, crop_border, input_order='HWC'): """Calculate PSNR (Peak Signal-to-Noise Ratio). - Reference: https://en.wikipedia.org/wiki/Peak_signal-to-noise_ratio + Ref: https://en.wikipedia.org/wiki/Peak_signal-to-noise_ratio Args: - img (ndarray): Images with range [0, 255]. - img2 (ndarray): Images with range [0, 255]. - crop_border (int): Cropped pixels in each edge of an image. These pixels are not involved in the calculation. - input_order (str): Whether the input order is 'HWC' or 'CHW'. Default: 'HWC'. + img1 (ndarray/tensor): Images with range [0, 255]/[0, 1]. + img2 (ndarray/tensor): Images with range [0, 255]/[0, 1]. + crop_border (int): Cropped pixels in each edge of an image. These + pixels are not involved in the PSNR calculation. + input_order (str): Whether the input order is 'HWC' or 'CHW'. + Default: 'HWC'. + test_y_channel (bool): Test on Y channel of YCbCr. Default: False. Returns: - float: PSNR result. + float: psnr result. """ - assert img.shape == img2.shape, ( - f'Image shapes are different: {img.shape}, {img2.shape}.') + assert img1.shape == img2.shape, ( + f'Image shapes are differnet: {img1.shape}, {img2.shape}.') if input_order not in ['HWC', 'CHW']: raise ValueError( - f'Wrong input_order {input_order}. Supported input_orders are "HWC" and "CHW"' - ) - img = reorder_image(img, input_order=input_order) + f'Wrong input_order {input_order}. Supported input_orders are ' + '"HWC" and "CHW"') + if type(img1) == torch.Tensor: + if len(img1.shape) == 4: + img1 = img1.squeeze(0) + img1 = img1.detach().cpu().numpy().transpose(1, 2, 0) + if type(img2) == torch.Tensor: + if len(img2.shape) == 4: + img2 = img2.squeeze(0) + img2 = img2.detach().cpu().numpy().transpose(1, 2, 0) + + img1 = reorder_image(img1, input_order=input_order) img2 = reorder_image(img2, input_order=input_order) + img1 = img1.astype(np.float64) + img2 = img2.astype(np.float64) if crop_border != 0: - img = img[crop_border:-crop_border, crop_border:-crop_border, ...] + img1 = img1[crop_border:-crop_border, crop_border:-crop_border, ...] img2 = img2[crop_border:-crop_border, crop_border:-crop_border, ...] - img = img.astype(np.float64) - img2 = img2.astype(np.float64) + def _psnr(img1, img2): + + mse = np.mean((img1 - img2)**2) + if mse == 0: + return float('inf') + max_value = 1. if img1.max() <= 1 else 255. + return 20. * np.log10(max_value / np.sqrt(mse)) - mse = np.mean((img - img2)**2) - if mse == 0: - return float('inf') - return 10. * np.log10(255. * 255. / mse) + return _psnr(img1, img2) -def calculate_ssim(img, img2, crop_border, input_order='HWC', **kwargs): +def calculate_ssim(img1, img2, crop_border, input_order='HWC', ssim3d=True): """Calculate SSIM (structural similarity). - ``Paper: Image quality assessment: From error visibility to structural similarity`` + Ref: + Image quality assessment: From error visibility to structural similarity The results are the same as that of the official released MATLAB code in https://ece.uwaterloo.ca/~z70wang/research/ssim/. For three-channel images, SSIM is calculated for each channel and then averaged. Args: - img (ndarray): Images with range [0, 255]. + img1 (ndarray): Images with range [0, 255]. img2 (ndarray): Images with range [0, 255]. - crop_border (int): Cropped pixels in each edge of an image. These pixels are not involved in the calculation. + crop_border (int): Cropped pixels in each edge of an image. These + pixels are not involved in the SSIM calculation. input_order (str): Whether the input order is 'HWC' or 'CHW'. Default: 'HWC'. + test_y_channel (bool): Test on Y channel of YCbCr. Default: False. Returns: - float: SSIM result. + float: ssim result. """ - assert img.shape == img2.shape, ( - f'Image shapes are different: {img.shape}, {img2.shape}.') + assert img1.shape == img2.shape, ( + f'Image shapes are differnet: {img1.shape}, {img2.shape}.') if input_order not in ['HWC', 'CHW']: raise ValueError( - f'Wrong input_order {input_order}. Supported input_orders are "HWC" and "CHW"' - ) - img = reorder_image(img, input_order=input_order) + f'Wrong input_order {input_order}. Supported input_orders are ' + '"HWC" and "CHW"') + + if type(img1) == torch.Tensor: + if len(img1.shape) == 4: + img1 = img1.squeeze(0) + img1 = img1.detach().cpu().numpy().transpose(1, 2, 0) + if type(img2) == torch.Tensor: + if len(img2.shape) == 4: + img2 = img2.squeeze(0) + img2 = img2.detach().cpu().numpy().transpose(1, 2, 0) + + img1 = reorder_image(img1, input_order=input_order) img2 = reorder_image(img2, input_order=input_order) + img1 = img1.astype(np.float64) + img2 = img2.astype(np.float64) + if crop_border != 0: - img = img[crop_border:-crop_border, crop_border:-crop_border, ...] + img1 = img1[crop_border:-crop_border, crop_border:-crop_border, ...] img2 = img2[crop_border:-crop_border, crop_border:-crop_border, ...] - img = img.astype(np.float64) - img2 = img2.astype(np.float64) + def _cal_ssim(img1, img2): + ssims = [] + + max_value = 1 if img1.max() <= 1 else 255 + with torch.no_grad(): + final_ssim = _ssim_3d(img1, img2, max_value) if ssim3d else _ssim( + img1, img2, max_value) + ssims.append(final_ssim) - ssims = [] - for i in range(img.shape[2]): - ssims.append(_ssim(img[..., i], img2[..., i])) - return np.array(ssims).mean() + return np.array(ssims).mean() + return _cal_ssim(img1, img2) -def _ssim(img, img2): + +def _ssim(img, img2, max_value): """Calculate SSIM (structural similarity) for one channel images. It is called by func:`calculate_ssim`. Args: @@ -152,8 +190,11 @@ def _ssim(img, img2): float: SSIM result. """ - c1 = (0.01 * 255)**2 - c2 = (0.03 * 255)**2 + c1 = (0.01 * max_value)**2 + c2 = (0.03 * max_value)**2 + + img = img.astype(np.float64) + img2 = img2.astype(np.float64) kernel = cv2.getGaussianKernel(11, 1.5) window = np.outer(kernel, kernel.transpose()) @@ -171,3 +212,61 @@ def _ssim(img, img2): tmp2 = (mu1_sq + mu2_sq + c1) * (sigma1_sq + sigma2_sq + c2) ssim_map = tmp1 / tmp2 return ssim_map.mean() + + +def _3d_gaussian_calculator(img, conv3d): + out = conv3d(img.unsqueeze(0).unsqueeze(0)).squeeze(0).squeeze(0) + return out + + +def _generate_3d_gaussian_kernel(): + kernel = cv2.getGaussianKernel(11, 1.5) + window = np.outer(kernel, kernel.transpose()) + kernel_3 = cv2.getGaussianKernel(11, 1.5) + kernel = torch.tensor(np.stack([window * k for k in kernel_3], axis=0)) + conv3d = torch.nn.Conv3d( + 1, + 1, (11, 11, 11), + stride=1, + padding=(5, 5, 5), + bias=False, + padding_mode='replicate') + conv3d.weight.requires_grad = False + conv3d.weight[0, 0, :, :, :] = kernel + return conv3d + + +def _ssim_3d(img1, img2, max_value): + assert len(img1.shape) == 3 and len(img2.shape) == 3 + """Calculate SSIM (structural similarity) for one channel images. + It is called by func:`calculate_ssim`. + Args: + img1 (ndarray): Images with range [0, 255]/[0, 1] with order 'HWC'. + img2 (ndarray): Images with range [0, 255]/[0, 1] with order 'HWC'. + Returns: + float: ssim result. + """ + C1 = (0.01 * max_value)**2 + C2 = (0.03 * max_value)**2 + img1 = img1.astype(np.float64) + img2 = img2.astype(np.float64) + + kernel = _generate_3d_gaussian_kernel().cuda() + + img1 = torch.tensor(img1).float().cuda() + img2 = torch.tensor(img2).float().cuda() + + mu1 = _3d_gaussian_calculator(img1, kernel) + mu2 = _3d_gaussian_calculator(img2, kernel) + + mu1_sq = mu1**2 + mu2_sq = mu2**2 + mu1_mu2 = mu1 * mu2 + sigma1_sq = _3d_gaussian_calculator(img1**2, kernel) - mu1_sq + sigma2_sq = _3d_gaussian_calculator(img2**2, kernel) - mu2_sq + sigma12 = _3d_gaussian_calculator(img1 * img2, kernel) - mu1_mu2 + + tmp1 = (2 * mu1_mu2 + C1) * (2 * sigma12 + C2) + tmp2 = (mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2) + ssim_map = tmp1 / tmp2 + return float(ssim_map.mean()) diff --git a/modelscope/models/cv/image_denoise/nafnet_for_image_denoise.py b/modelscope/models/cv/image_denoise/nafnet_for_image_denoise.py index a6fbf22f..4e8fc0ed 100644 --- a/modelscope/models/cv/image_denoise/nafnet_for_image_denoise.py +++ b/modelscope/models/cv/image_denoise/nafnet_for_image_denoise.py @@ -3,7 +3,6 @@ import os from copy import deepcopy from typing import Any, Dict, Union -import numpy as np import torch.cuda from torch.nn.parallel import DataParallel, DistributedDataParallel @@ -78,13 +77,8 @@ class NAFNetForImageDenoise(TorchModel): def _evaluate_postprocess(self, input: Tensor, target: Tensor) -> Dict[str, list]: preds = self.model(input) - preds = list(torch.split(preds, 1, 0)) - targets = list(torch.split(target, 1, 0)) - - preds = [(pred.data * 255.).squeeze(0).permute( - 1, 2, 0).cpu().numpy().astype(np.uint8) for pred in preds] - targets = [(target.data * 255.).squeeze(0).permute( - 1, 2, 0).cpu().numpy().astype(np.uint8) for target in targets] + preds = list(torch.split(preds.clamp(0, 1), 1, 0)) + targets = list(torch.split(target.clamp(0, 1), 1, 0)) return {'pred': preds, 'target': targets} diff --git a/modelscope/msdatasets/task_datasets/__init__.py b/modelscope/msdatasets/task_datasets/__init__.py index 35c060f0..7c31969a 100644 --- a/modelscope/msdatasets/task_datasets/__init__.py +++ b/modelscope/msdatasets/task_datasets/__init__.py @@ -26,6 +26,7 @@ else: 'video_summarization_dataset': ['VideoSummarizationDataset'], 'movie_scene_segmentation': ['MovieSceneSegmentationDataset'], 'image_inpainting': ['ImageInpaintingDataset'], + 'sidd_image_denoising_dataset': ['SiddImageDenoisingDataset'], } import sys