Merge pull request !681 from zhaozhenlong/psnr-check-same-shapetags/v0.3.0-alpha
| @@ -95,6 +95,11 @@ def _gauss_kernel_helper(filter_size): | |||
| g = Tensor(g) | |||
| return filter_size, g | |||
| @constexpr | |||
| def _check_input_4d(input_shape, param_name, func_name): | |||
| if len(input_shape) != 4: | |||
| raise ValueError(f"{func_name} {param_name} should be 4d, but got shape {input_shape}") | |||
| return True | |||
| class SSIM(Cell): | |||
| r""" | |||
| @@ -146,6 +151,9 @@ class SSIM(Cell): | |||
| self.mean = P.DepthwiseConv2dNative(channel_multiplier=1, kernel_size=filter_size) | |||
| def construct(self, img1, img2): | |||
| _check_input_4d(F.shape(img1), "img1", "SSIM") | |||
| _check_input_4d(F.shape(img2), "img2", "SSIM") | |||
| P.SameTypeShape()(img1, img2) | |||
| max_val = _convert_img_dtype_to_float32(self.max_val, self.max_val) | |||
| img1 = _convert_img_dtype_to_float32(img1, self.max_val) | |||
| img2 = _convert_img_dtype_to_float32(img2, self.max_val) | |||
| @@ -236,6 +244,9 @@ class PSNR(Cell): | |||
| self.max_val = max_val | |||
| def construct(self, img1, img2): | |||
| _check_input_4d(F.shape(img1), "img1", "PSNR") | |||
| _check_input_4d(F.shape(img2), "img2", "PSNR") | |||
| P.SameTypeShape()(img1, img2) | |||
| max_val = _convert_img_dtype_to_float32(self.max_val, self.max_val) | |||
| img1 = _convert_img_dtype_to_float32(img1, self.max_val) | |||
| img2 = _convert_img_dtype_to_float32(img2, self.max_val) | |||
| @@ -18,10 +18,12 @@ test psnr | |||
| import numpy as np | |||
| import pytest | |||
| import mindspore.nn as nn | |||
| from mindspore.common import dtype as mstype | |||
| from mindspore.common.api import _executor | |||
| from mindspore import Tensor | |||
| class PSNRNet(nn.Cell): | |||
| def __init__(self, max_val=1.0): | |||
| super(PSNRNet, self).__init__() | |||
| @@ -59,3 +61,38 @@ def test_psnr_max_val_zero(): | |||
| max_val = 0 | |||
| with pytest.raises(ValueError): | |||
| net = PSNRNet(max_val) | |||
| def test_psnr_different_shape(): | |||
| shape_1 = (8, 3, 16, 16) | |||
| shape_2 = (8, 3, 8, 8) | |||
| img1 = Tensor(np.random.random(shape_1)) | |||
| img2 = Tensor(np.random.random(shape_2)) | |||
| net = PSNRNet() | |||
| with pytest.raises(ValueError): | |||
| _executor.compile(net, img1, img2) | |||
| def test_psnr_different_dtype(): | |||
| dtype_1 = mstype.float32 | |||
| dtype_2 = mstype.float16 | |||
| img1 = Tensor(np.random.random((8, 3, 16, 16)), dtype=dtype_1) | |||
| img2 = Tensor(np.random.random((8, 3, 16, 16)), dtype=dtype_2) | |||
| net = PSNRNet() | |||
| with pytest.raises(TypeError): | |||
| _executor.compile(net, img1, img2) | |||
| def test_psnr_invalid_5d_input(): | |||
| shape_1 = (8, 3, 16, 16) | |||
| shape_2 = (8, 3, 8, 8) | |||
| invalid_shape = (8, 3, 16, 16, 1) | |||
| img1 = Tensor(np.random.random(shape_1)) | |||
| invalid_img1 = Tensor(np.random.random(invalid_shape)) | |||
| img2 = Tensor(np.random.random(shape_2)) | |||
| invalid_img2 = Tensor(np.random.random(invalid_shape)) | |||
| net = PSNRNet() | |||
| with pytest.raises(ValueError): | |||
| _executor.compile(net, invalid_img1, img2) | |||
| with pytest.raises(ValueError): | |||
| _executor.compile(net, img1, invalid_img2) | |||
| with pytest.raises(ValueError): | |||
| _executor.compile(net, invalid_img1, invalid_img2) | |||
| @@ -18,6 +18,7 @@ test ssim | |||
| import numpy as np | |||
| import pytest | |||
| import mindspore.nn as nn | |||
| import mindspore.common.dtype as mstype | |||
| from mindspore.common.api import _executor | |||
| from mindspore import Tensor | |||
| @@ -93,3 +94,38 @@ def test_ssim_k1_k2_wrong_value(): | |||
| net = SSIMNet(k2=0.0) | |||
| with pytest.raises(ValueError): | |||
| net = SSIMNet(k2=-1.0) | |||
| def test_ssim_different_shape(): | |||
| shape_1 = (8, 3, 16, 16) | |||
| shape_2 = (8, 3, 8, 8) | |||
| img1 = Tensor(np.random.random(shape_1)) | |||
| img2 = Tensor(np.random.random(shape_2)) | |||
| net = SSIMNet() | |||
| with pytest.raises(ValueError): | |||
| _executor.compile(net, img1, img2) | |||
| def test_ssim_different_dtype(): | |||
| dtype_1 = mstype.float32 | |||
| dtype_2 = mstype.float16 | |||
| img1 = Tensor(np.random.random((8, 3, 16, 16)), dtype=dtype_1) | |||
| img2 = Tensor(np.random.random((8, 3, 16, 16)), dtype=dtype_2) | |||
| net = SSIMNet() | |||
| with pytest.raises(TypeError): | |||
| _executor.compile(net, img1, img2) | |||
| def test_ssim_invalid_5d_input(): | |||
| shape_1 = (8, 3, 16, 16) | |||
| shape_2 = (8, 3, 8, 8) | |||
| invalid_shape = (8, 3, 16, 16, 1) | |||
| img1 = Tensor(np.random.random(shape_1)) | |||
| invalid_img1 = Tensor(np.random.random(invalid_shape)) | |||
| img2 = Tensor(np.random.random(shape_2)) | |||
| invalid_img2 = Tensor(np.random.random(invalid_shape)) | |||
| net = SSIMNet() | |||
| with pytest.raises(ValueError): | |||
| _executor.compile(net, invalid_img1, img2) | |||
| with pytest.raises(ValueError): | |||
| _executor.compile(net, img1, invalid_img2) | |||
| with pytest.raises(ValueError): | |||
| _executor.compile(net, invalid_img1, invalid_img2) | |||