You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

image.py 11 kB

6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """image"""
  16. import numpy as np
  17. import mindspore.common.dtype as mstype
  18. from mindspore.common.tensor import Tensor
  19. from mindspore.ops import operations as P
  20. from mindspore.ops import functional as F
  21. from mindspore.ops.primitive import constexpr
  22. from mindspore._checkparam import Validator as validator
  23. from mindspore._checkparam import Rel
  24. from ..cell import Cell
  25. class ImageGradients(Cell):
  26. r"""
  27. Returns two tensors, the first is along the height dimension and the second is along the width dimension.
  28. Assume an image shape is :math:`h*w`. The gradients along the height and the width are :math:`dy` and :math:`dx`,
  29. respectively.
  30. .. math::
  31. dy[i] = \begin{cases} image[i+1, :]-image[i, :], &if\ 0<=i<h-1 \cr
  32. 0, &if\ i==h-1\end{cases}
  33. dx[i] = \begin{cases} image[:, i+1]-image[:, i], &if\ 0<=i<w-1 \cr
  34. 0, &if\ i==w-1\end{cases}
  35. Inputs:
  36. - **images** (Tensor) - The input image data, with format 'NCHW'.
  37. Outputs:
  38. - **dy** (Tensor) - vertical image gradients, the same type and shape as input.
  39. - **dx** (Tensor) - horizontal image gradients, the same type and shape as input.
  40. Examples:
  41. >>> net = nn.ImageGradients()
  42. >>> image = Tensor(np.array([[[[1,2],[3,4]]]]), dtype=mstype.int32)
  43. >>> net(image)
  44. [[[[2,2]
  45. [0,0]]]]
  46. [[[[1,0]
  47. [1,0]]]]
  48. """
  49. def __init__(self):
  50. super(ImageGradients, self).__init__()
  51. def construct(self, images):
  52. batch_size, depth, height, width = P.Shape()(images)
  53. dy = images[:, :, 1:, :] - images[:, :, :height - 1, :]
  54. dy_last = P.Fill()(P.DType()(images), (batch_size, depth, 1, width), 0)
  55. dy = P.Concat(2)((dy, dy_last))
  56. dx = images[:, :, :, 1:] - images[:, :, :, :width - 1]
  57. dx_last = P.Fill()(P.DType()(images), (batch_size, depth, height, 1), 0)
  58. dx = P.Concat(3)((dx, dx_last))
  59. return dy, dx
  60. def _convert_img_dtype_to_float32(img, max_val):
  61. """convert img dtype to float32"""
  62. # Ususally max_val is 1.0 or 255, we will do the scaling if max_val > 1.
  63. # We will scale img pixel value if max_val > 1. and just cast otherwise.
  64. ret = F.cast(img, mstype.float32)
  65. max_val = F.scalar_cast(max_val, mstype.float32)
  66. if max_val > 1.:
  67. scale = 1. / max_val
  68. ret = ret * scale
  69. return ret
  70. @constexpr
  71. def _gauss_kernel_helper(filter_size):
  72. """gauss kernel helper"""
  73. filter_size = F.scalar_cast(filter_size, mstype.int32)
  74. coords = ()
  75. for i in range(filter_size):
  76. i_cast = F.scalar_cast(i, mstype.float32)
  77. offset = F.scalar_cast(filter_size-1, mstype.float32)/2.0
  78. element = i_cast-offset
  79. coords = coords+(element,)
  80. g = np.square(coords).astype(np.float32)
  81. g = Tensor(g)
  82. return filter_size, g
  83. @constexpr
  84. def _check_input_4d(input_shape, param_name, func_name):
  85. if len(input_shape) != 4:
  86. raise ValueError(f"{func_name} {param_name} should be 4d, but got shape {input_shape}")
  87. return True
  88. class SSIM(Cell):
  89. r"""
  90. Returns SSIM index between img1 and img2.
  91. Its implementation is based on Wang, Z., Bovik, A. C., Sheikh, H. R., & Simoncelli, E. P. (2004). `Image quality
  92. assessment: from error visibility to structural similarity <https://ieeexplore.ieee.org/document/1284395>`_.
  93. IEEE transactions on image processing.
  94. .. math::
  95. l(x,y)&=\frac{2\mu_x\mu_y+C_1}{\mu_x^2+\mu_y^2+C_1}, C_1=(K_1L)^2.\\
  96. c(x,y)&=\frac{2\sigma_x\sigma_y+C_2}{\sigma_x^2+\sigma_y^2+C_2}, C_2=(K_2L)^2.\\
  97. s(x,y)&=\frac{\sigma_{xy}+C_3}{\sigma_x\sigma_y+C_3}, C_3=C_2/2.\\
  98. SSIM(x,y)&=l*c*s\\&=\frac{(2\mu_x\mu_y+C_1)(2\sigma_{xy}+C_2}{(\mu_x^2+\mu_y^2+C_1)(\sigma_x^2+\sigma_y^2+C_2)}.
  99. Args:
  100. max_val (Union[int, float]): The dynamic range of the pixel values (255 for 8-bit grayscale images).
  101. Default: 1.0.
  102. filter_size (int): The size of the Gaussian filter. Default: 11.
  103. filter_sigma (float): The standard deviation of Gaussian kernel. Default: 1.5.
  104. k1 (float): The constant used to generate c1 in the luminance comparison function. Default: 0.01.
  105. k2 (float): The constant used to generate c2 in the contrast comparison function. Default: 0.03.
  106. Inputs:
  107. - **img1** (Tensor) - The first image batch with format 'NCHW'. It should be the same shape and dtype as img2.
  108. - **img2** (Tensor) - The second image batch with format 'NCHW'. It should be the same shape and dtype as img1.
  109. Outputs:
  110. Tensor, has the same dtype as img1. It is a 1-D tensor with shape N, where N is the batch num of img1.
  111. Examples:
  112. >>> net = nn.SSIM()
  113. >>> img1 = Tensor(np.random.random((1,3,16,16)))
  114. >>> img2 = Tensor(np.random.random((1,3,16,16)))
  115. >>> ssim = net(img1, img2)
  116. """
  117. def __init__(self, max_val=1.0, filter_size=11, filter_sigma=1.5, k1=0.01, k2=0.03):
  118. super(SSIM, self).__init__()
  119. validator.check_value_type('max_val', max_val, [int, float], self.cls_name)
  120. validator.check_number('max_val', max_val, 0.0, Rel.GT, self.cls_name)
  121. self.max_val = max_val
  122. self.filter_size = validator.check_integer('filter_size', filter_size, 1, Rel.GE, self.cls_name)
  123. self.filter_sigma = validator.check_float_positive('filter_sigma', filter_sigma, self.cls_name)
  124. validator.check_value_type('k1', k1, [float], self.cls_name)
  125. self.k1 = validator.check_number_range('k1', k1, 0.0, 1.0, Rel.INC_NEITHER, self.cls_name)
  126. validator.check_value_type('k2', k2, [float], self.cls_name)
  127. self.k2 = validator.check_number_range('k2', k2, 0.0, 1.0, Rel.INC_NEITHER, self.cls_name)
  128. self.mean = P.DepthwiseConv2dNative(channel_multiplier=1, kernel_size=filter_size)
  129. def construct(self, img1, img2):
  130. _check_input_4d(F.shape(img1), "img1", "SSIM")
  131. _check_input_4d(F.shape(img2), "img2", "SSIM")
  132. P.SameTypeShape()(img1, img2)
  133. max_val = _convert_img_dtype_to_float32(self.max_val, self.max_val)
  134. img1 = _convert_img_dtype_to_float32(img1, self.max_val)
  135. img2 = _convert_img_dtype_to_float32(img2, self.max_val)
  136. kernel = self._fspecial_gauss(self.filter_size, self.filter_sigma)
  137. kernel = P.Tile()(kernel, (1, P.Shape()(img1)[1], 1, 1))
  138. mean_ssim = self._calculate_mean_ssim(img1, img2, kernel, max_val, self.k1, self.k2)
  139. return mean_ssim
  140. def _calculate_mean_ssim(self, x, y, kernel, max_val, k1, k2):
  141. """calculate mean ssim"""
  142. c1 = (k1 * max_val) * (k1 * max_val)
  143. c2 = (k2 * max_val) * (k2 * max_val)
  144. # SSIM luminance formula
  145. # (2 * mean_{x} * mean_{y} + c1) / (mean_{x}**2 + mean_{y}**2 + c1)
  146. mean_x = self.mean(x, kernel)
  147. mean_y = self.mean(y, kernel)
  148. square_sum = F.square(mean_x)+F.square(mean_y)
  149. luminance = (2*mean_x*mean_y+c1)/(square_sum+c1)
  150. # SSIM contrast*structure formula (when c3 = c2/2)
  151. # (2 * conv_{xy} + c2) / (conv_{xx} + conv_{yy} + c2), equals to
  152. # (2 * (mean_{xy} - mean_{x}*mean_{y}) + c2) / (mean_{xx}-mean_{x}**2 + mean_{yy}-mean_{y}**2 + c2)
  153. mean_xy = self.mean(x*y, kernel)
  154. mean_square_add = self.mean(F.square(x)+F.square(y), kernel)
  155. cs = (2*(mean_xy-mean_x*mean_y)+c2)/(mean_square_add-square_sum+c2)
  156. # SSIM formula
  157. # luminance * cs
  158. ssim = luminance*cs
  159. mean_ssim = P.ReduceMean()(ssim, (-3, -2, -1))
  160. return mean_ssim
  161. def _fspecial_gauss(self, filter_size, filter_sigma):
  162. """get gauss kernel"""
  163. filter_size, g = _gauss_kernel_helper(filter_size)
  164. square_sigma_scale = -0.5/(filter_sigma * filter_sigma)
  165. g = g*square_sigma_scale
  166. g = F.reshape(g, (1, -1))+F.reshape(g, (-1, 1))
  167. g = F.reshape(g, (1, -1))
  168. g = P.Softmax()(g)
  169. ret = F.reshape(g, (1, 1, filter_size, filter_size))
  170. return ret
  171. class PSNR(Cell):
  172. r"""
  173. Returns Peak Signal-to-Noise Ratio of two image batches.
  174. It produces a PSNR value for each image in batch.
  175. Assume inputs are :math:`I` and :math:`K`, both with shape :math:`h*w`.
  176. :math:`MAX` represents the dynamic range of pixel values.
  177. .. math::
  178. MSE&=\frac{1}{hw}\sum\limits_{i=0}^{h-1}\sum\limits_{j=0}^{w-1}[I(i,j)-K(i,j)]^2\\
  179. PSNR&=10*log_{10}(\frac{MAX^2}{MSE})
  180. Args:
  181. max_val (Union[int, float]): The dynamic range of the pixel values (255 for 8-bit grayscale images).
  182. Default: 1.0.
  183. Inputs:
  184. - **img1** (Tensor) - The first image batch with format 'NCHW'. It should be the same shape and dtype as img2.
  185. - **img2** (Tensor) - The second image batch with format 'NCHW'. It should be the same shape and dtype as img1.
  186. Outputs:
  187. Tensor, with dtype mindspore.float32. It is a 1-D tensor with shape N, where N is the batch num of img1.
  188. Examples:
  189. >>> net = nn.PSNR()
  190. >>> img1 = Tensor(np.random.random((1,3,16,16)))
  191. >>> img2 = Tensor(np.random.random((1,3,16,16)))
  192. >>> psnr = net(img1, img2)
  193. """
  194. def __init__(self, max_val=1.0):
  195. super(PSNR, self).__init__()
  196. validator.check_value_type('max_val', max_val, [int, float], self.cls_name)
  197. validator.check_number('max_val', max_val, 0.0, Rel.GT, self.cls_name)
  198. self.max_val = max_val
  199. def construct(self, img1, img2):
  200. _check_input_4d(F.shape(img1), "img1", "PSNR")
  201. _check_input_4d(F.shape(img2), "img2", "PSNR")
  202. P.SameTypeShape()(img1, img2)
  203. max_val = _convert_img_dtype_to_float32(self.max_val, self.max_val)
  204. img1 = _convert_img_dtype_to_float32(img1, self.max_val)
  205. img2 = _convert_img_dtype_to_float32(img2, self.max_val)
  206. mse = P.ReduceMean()(F.square(img1 - img2), (-3, -2, -1))
  207. # 10*log_10(max_val^2/MSE)
  208. psnr = 10 * P.Log()(F.square(max_val) / mse) / F.scalar_log(10.0)
  209. return psnr