You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

image.py 20 kB

6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
5 years ago
6 years ago
6 years ago
5 years ago
6 years ago
5 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
5 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
5 years ago
6 years ago
6 years ago
6 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """image"""
  16. import numbers
  17. import numpy as np
  18. import mindspore.common.dtype as mstype
  19. from mindspore.common.tensor import Tensor
  20. from mindspore.ops import operations as P
  21. from mindspore.ops import functional as F
  22. from mindspore.ops.primitive import constexpr
  23. from mindspore._checkparam import Rel, Validator as validator
  24. from .conv import Conv2d
  25. from .container import CellList
  26. from .pooling import AvgPool2d
  27. from .activation import ReLU
  28. from ..cell import Cell
  29. __all__ = ['ImageGradients', 'SSIM', 'MSSSIM', 'PSNR', 'CentralCrop']
class ImageGradients(Cell):
    r"""
    Returns two tensors, the first is along the height dimension and the second is along the width dimension.

    Assume an image shape is :math:`h*w`. The gradients along the height and the width are :math:`dy` and :math:`dx`,
    respectively.

    .. math::
        dy[i] = \begin{cases} image[i+1, :]-image[i, :], &if\ 0<=i<h-1 \cr
        0, &if\ i==h-1\end{cases}

        dx[i] = \begin{cases} image[:, i+1]-image[:, i], &if\ 0<=i<w-1 \cr
        0, &if\ i==w-1\end{cases}

    Inputs:
        - **images** (Tensor) - The input image data, with format 'NCHW'.

    Outputs:
        - **dy** (Tensor) - vertical image gradients, the same type and shape as input.
        - **dx** (Tensor) - horizontal image gradients, the same type and shape as input.

    Examples:
        >>> net = nn.ImageGradients()
        >>> image = Tensor(np.array([[[[1,2],[3,4]]]]), dtype=mstype.int32)
        >>> net(image)
        [[[[2,2]
        [0,0]]]]
        [[[[1,0]
        [1,0]]]]
    """
    def __init__(self):
        super(ImageGradients, self).__init__()

    def construct(self, images):
        # Validate 4-D (NCHW) input; F.depend anchors the check in the graph's
        # execution order so it runs before the image is consumed.
        check = _check_input_4d(F.shape(images), "images", self.cls_name)
        images = F.depend(images, check)
        batch_size, depth, height, width = P.Shape()(images)
        if height == 1:
            # Degenerate single-row image: gradient along H is all zeros.
            dy = P.Fill()(P.DType()(images), (batch_size, depth, 1, width), 0)
        else:
            # Forward difference along H, then append a zero row so the output
            # keeps the input shape.
            dy = images[:, :, 1:, :] - images[:, :, :height - 1, :]
            dy_last = P.Fill()(P.DType()(images), (batch_size, depth, 1, width), 0)
            dy = P.Concat(2)((dy, dy_last))

        if width == 1:
            # Degenerate single-column image: gradient along W is all zeros.
            dx = P.Fill()(P.DType()(images), (batch_size, depth, height, 1), 0)
        else:
            # Forward difference along W, padded with a zero last column.
            dx = images[:, :, :, 1:] - images[:, :, :, :width - 1]
            dx_last = P.Fill()(P.DType()(images), (batch_size, depth, height, 1), 0)
            dx = P.Concat(3)((dx, dx_last))
        return dy, dx
  73. def _convert_img_dtype_to_float32(img, max_val):
  74. """convert img dtype to float32"""
  75. # Ususally max_val is 1.0 or 255, we will do the scaling if max_val > 1.
  76. # We will scale img pixel value if max_val > 1. and just cast otherwise.
  77. ret = F.cast(img, mstype.float32)
  78. max_val = F.scalar_cast(max_val, mstype.float32)
  79. if max_val > 1.:
  80. scale = 1. / max_val
  81. ret = ret * scale
  82. return ret
  83. @constexpr
  84. def _get_dtype_max(dtype):
  85. """get max of the dtype"""
  86. np_type = mstype.dtype_to_nptype(dtype)
  87. if issubclass(np_type, numbers.Integral):
  88. dtype_max = np.float64(np.iinfo(np_type).max)
  89. else:
  90. dtype_max = 1.0
  91. return dtype_max
  92. @constexpr
  93. def _check_input_4d(input_shape, param_name, func_name):
  94. if len(input_shape) != 4:
  95. raise ValueError(f"{func_name} {param_name} should be 4d, but got shape {input_shape}")
  96. return True
  97. @constexpr
  98. def _check_input_filter_size(input_shape, param_name, filter_size, func_name):
  99. _check_input_4d(input_shape, param_name, func_name)
  100. validator.check(param_name + " shape[2]", input_shape[2], "filter_size", filter_size, Rel.GE, func_name)
  101. validator.check(param_name + " shape[3]", input_shape[3], "filter_size", filter_size, Rel.GE, func_name)
@constexpr
def _check_input_dtype(input_dtype, param_name, allow_dtypes, cls_name):
    """Validate that `input_dtype` is one of `allow_dtypes` (delegates to the validator)."""
    validator.check_type_name(param_name, input_dtype, allow_dtypes, cls_name)
  105. def _conv2d(in_channels, out_channels, kernel_size, weight, stride=1, padding=0):
  106. return Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride,
  107. weight_init=weight, padding=padding, pad_mode="valid")
  108. def _create_window(size, sigma):
  109. x_data, y_data = np.mgrid[-size // 2 + 1:size // 2 + 1, -size // 2 + 1:size // 2 + 1]
  110. x_data = np.expand_dims(x_data, axis=-1).astype(np.float32)
  111. x_data = np.expand_dims(x_data, axis=-1) ** 2
  112. y_data = np.expand_dims(y_data, axis=-1).astype(np.float32)
  113. y_data = np.expand_dims(y_data, axis=-1) ** 2
  114. sigma = 2 * sigma ** 2
  115. g = np.exp(-(x_data + y_data) / sigma)
  116. return np.transpose(g / np.sum(g), (2, 3, 0, 1))
  117. def _split_img(x):
  118. _, c, _, _ = F.shape(x)
  119. img_split = P.Split(1, c)
  120. output = img_split(x)
  121. return output, c
  122. def _compute_per_channel_loss(c1, c2, img1, img2, conv):
  123. """computes ssim index between img1 and img2 per single channel"""
  124. dot_img = img1 * img2
  125. mu1 = conv(img1)
  126. mu2 = conv(img2)
  127. mu1_sq = mu1 * mu1
  128. mu2_sq = mu2 * mu2
  129. mu1_mu2 = mu1 * mu2
  130. sigma1_tmp = conv(img1 * img1)
  131. sigma1_sq = sigma1_tmp - mu1_sq
  132. sigma2_tmp = conv(img2 * img2)
  133. sigma2_sq = sigma2_tmp - mu2_sq
  134. sigma12_tmp = conv(dot_img)
  135. sigma12 = sigma12_tmp - mu1_mu2
  136. a = (2 * mu1_mu2 + c1)
  137. b = (mu1_sq + mu2_sq + c1)
  138. v1 = 2 * sigma12 + c2
  139. v2 = sigma1_sq + sigma2_sq + c2
  140. ssim = (a * v1) / (b * v2)
  141. cs = v1 / v2
  142. return ssim, cs
  143. def _compute_multi_channel_loss(c1, c2, img1, img2, conv, concat, mean):
  144. """computes ssim index between img1 and img2 per color channel"""
  145. split_img1, c = _split_img(img1)
  146. split_img2, _ = _split_img(img2)
  147. multi_ssim = ()
  148. multi_cs = ()
  149. for i in range(c):
  150. ssim_per_channel, cs_per_channel = _compute_per_channel_loss(c1, c2, split_img1[i], split_img2[i], conv)
  151. multi_ssim += (ssim_per_channel,)
  152. multi_cs += (cs_per_channel,)
  153. multi_ssim = concat(multi_ssim)
  154. multi_cs = concat(multi_cs)
  155. ssim = mean(multi_ssim, (2, 3))
  156. cs = mean(multi_cs, (2, 3))
  157. return ssim, cs
class SSIM(Cell):
    r"""
    Returns SSIM index between img1 and img2.

    Its implementation is based on Wang, Z., Bovik, A. C., Sheikh, H. R., & Simoncelli, E. P. (2004). `Image quality
    assessment: from error visibility to structural similarity <https://ieeexplore.ieee.org/document/1284395>`_.
    IEEE transactions on image processing.

    .. math::

        l(x,y)&=\frac{2\mu_x\mu_y+C_1}{\mu_x^2+\mu_y^2+C_1}, C_1=(K_1L)^2.\\
        c(x,y)&=\frac{2\sigma_x\sigma_y+C_2}{\sigma_x^2+\sigma_y^2+C_2}, C_2=(K_2L)^2.\\
        s(x,y)&=\frac{\sigma_{xy}+C_3}{\sigma_x\sigma_y+C_3}, C_3=C_2/2.\\
        SSIM(x,y)&=l*c*s\\&=\frac{(2\mu_x\mu_y+C_1)(2\sigma_{xy}+C_2)}{(\mu_x^2+\mu_y^2+C_1)(\sigma_x^2+\sigma_y^2+C_2)}.

    Args:
        max_val (Union[int, float]): The dynamic range of the pixel values (255 for 8-bit grayscale images).
            Default: 1.0.
        filter_size (int): The size of the Gaussian filter. Default: 11. The value must be greater than or equal to 1.
        filter_sigma (float): The standard deviation of Gaussian kernel. Default: 1.5. The value must be greater than 0.
        k1 (float): The constant used to generate c1 in the luminance comparison function. Default: 0.01.
        k2 (float): The constant used to generate c2 in the contrast comparison function. Default: 0.03.

    Inputs:
        - **img1** (Tensor) - The first image batch with format 'NCHW'. It must be the same shape and dtype as img2.
        - **img2** (Tensor) - The second image batch with format 'NCHW'. It must be the same shape and dtype as img1.

    Outputs:
        Tensor, has the same dtype as img1. It is a 1-D tensor with shape N, where N is the batch num of img1.

    Examples:
        >>> net = nn.SSIM()
        >>> img1 = Tensor(np.random.random((1,3,16,16)), mindspore.float32)
        >>> img2 = Tensor(np.random.random((1,3,16,16)), mindspore.float32)
        >>> ssim = net(img1, img2)
        [0.12174469]
    """
    def __init__(self, max_val=1.0, filter_size=11, filter_sigma=1.5, k1=0.01, k2=0.03):
        super(SSIM, self).__init__()
        validator.check_value_type('max_val', max_val, [int, float], self.cls_name)
        validator.check_number('max_val', max_val, 0.0, Rel.GT, self.cls_name)
        self.max_val = max_val
        self.filter_size = validator.check_int(filter_size, 1, Rel.GE, 'filter_size', self.cls_name)
        self.filter_sigma = validator.check_positive_float(filter_sigma, 'filter_sigma', self.cls_name)
        self.k1 = validator.check_value_type('k1', k1, [float], self.cls_name)
        self.k2 = validator.check_value_type('k2', k2, [float], self.cls_name)
        # Fixed Gaussian window used as a non-trainable conv weight.
        window = _create_window(filter_size, filter_sigma)
        self.conv = _conv2d(1, 1, filter_size, Tensor(window))
        self.conv.weight.requires_grad = False
        self.reduce_mean = P.ReduceMean()
        self.concat = P.Concat(axis=1)

    def construct(self, img1, img2):
        # Only float16/float32 inputs are accepted, images must share shape and
        # dtype, and H/W must be at least filter_size for the valid-pad conv.
        _check_input_dtype(F.dtype(img1), "img1", [mstype.float32, mstype.float16], self.cls_name)
        _check_input_filter_size(F.shape(img1), "img1", self.filter_size, self.cls_name)
        P.SameTypeShape()(img1, img2)
        # Convert both images (and max_val) to float32 in a [0, 1] pixel range.
        dtype_max_val = _get_dtype_max(F.dtype(img1))
        max_val = F.scalar_cast(self.max_val, F.dtype(img1))
        max_val = _convert_img_dtype_to_float32(max_val, dtype_max_val)
        img1 = _convert_img_dtype_to_float32(img1, dtype_max_val)
        img2 = _convert_img_dtype_to_float32(img2, dtype_max_val)

        # Stabilization constants from the SSIM paper: C1=(k1*L)^2, C2=(k2*L)^2.
        c1 = (self.k1 * max_val) ** 2
        c2 = (self.k2 * max_val) ** 2

        ssim_ave_channel, _ = _compute_multi_channel_loss(c1, c2, img1, img2, self.conv, self.concat, self.reduce_mean)
        # Average over channels: one SSIM value per batch element.
        loss = self.reduce_mean(ssim_ave_channel, -1)

        return loss
  216. def _downsample(img1, img2, op):
  217. a = op(img1)
  218. b = op(img2)
  219. return a, b
class MSSSIM(Cell):
    r"""
    Returns MS-SSIM index between img1 and img2.

    Its implementation is based on Wang, Zhou, Eero P. Simoncelli, and Alan C. Bovik. `Multiscale structural similarity
    for image quality assessment <https://ieeexplore.ieee.org/document/1292216>`_.
    Signals, Systems and Computers, 2004.

    .. math::

        l(x,y)&=\frac{2\mu_x\mu_y+C_1}{\mu_x^2+\mu_y^2+C_1}, C_1=(K_1L)^2.\\
        c(x,y)&=\frac{2\sigma_x\sigma_y+C_2}{\sigma_x^2+\sigma_y^2+C_2}, C_2=(K_2L)^2.\\
        s(x,y)&=\frac{\sigma_{xy}+C_3}{\sigma_x\sigma_y+C_3}, C_3=C_2/2.\\
        MSSSIM(x,y)&=l^\alpha_M*{\prod_{1\leq j\leq M} (c^\beta_j*s^\gamma_j)}.

    Args:
        max_val (Union[int, float]): The dynamic range of the pixel values (255 for 8-bit grayscale images).
            Default: 1.0.
        power_factors (Union[tuple, list]): Iterable of weights for each scale.
            Default: (0.0448, 0.2856, 0.3001, 0.2363, 0.1333). Default values obtained by Wang et al.
        filter_size (int): The size of the Gaussian filter. Default: 11.
        filter_sigma (float): The standard deviation of Gaussian kernel. Default: 1.5.
        k1 (float): The constant used to generate c1 in the luminance comparison function. Default: 0.01.
        k2 (float): The constant used to generate c2 in the contrast comparison function. Default: 0.03.

    Inputs:
        - **img1** (Tensor) - The first image batch with format 'NCHW'. It must be the same shape and dtype as img2.
        - **img2** (Tensor) - The second image batch with format 'NCHW'. It must be the same shape and dtype as img1.

    Outputs:
        Tensor, the value is in range [0, 1]. It is a 1-D tensor with shape N, where N is the batch num of img1.

    Examples:
        >>> net = nn.MSSSIM(power_factors=(0.033, 0.033, 0.033))
        >>> img1 = Tensor(np.random.random((1,3,128,128)))
        >>> img2 = Tensor(np.random.random((1,3,128,128)))
        >>> msssim = net(img1, img2)
    """
    def __init__(self, max_val=1.0, power_factors=(0.0448, 0.2856, 0.3001, 0.2363, 0.1333), filter_size=11,
                 filter_sigma=1.5, k1=0.01, k2=0.03):
        super(MSSSIM, self).__init__()
        validator.check_value_type('max_val', max_val, [int, float], self.cls_name)
        validator.check_number('max_val', max_val, 0.0, Rel.GT, self.cls_name)
        self.max_val = max_val
        validator.check_value_type('power_factors', power_factors, [tuple, list], self.cls_name)
        self.filter_size = validator.check_int(filter_size, 1, Rel.GE, 'filter_size', self.cls_name)
        self.filter_sigma = validator.check_positive_float(filter_sigma, 'filter_sigma', self.cls_name)
        self.k1 = validator.check_value_type('k1', k1, [float], self.cls_name)
        self.k2 = validator.check_value_type('k2', k2, [float], self.cls_name)
        window = _create_window(filter_size, filter_sigma)
        # One scale per power factor; each scale gets its own frozen Gaussian conv.
        self.level = len(power_factors)
        self.conv = []
        for i in range(self.level):
            self.conv.append(_conv2d(1, 1, filter_size, Tensor(window)))
            self.conv[i].weight.requires_grad = False
        self.multi_convs_list = CellList(self.conv)
        self.weight_tensor = Tensor(power_factors, mstype.float32)
        # 2x2 average pooling halves H/W between scales.
        self.avg_pool = AvgPool2d(kernel_size=2, stride=2, pad_mode='valid')
        self.relu = ReLU()
        self.reduce_mean = P.ReduceMean()
        self.prod = P.ReduceProd()
        self.pow = P.Pow()
        self.pack = P.Pack(axis=-1)
        self.concat = P.Concat(axis=1)

    def construct(self, img1, img2):
        _check_input_4d(F.shape(img1), "img1", self.cls_name)
        _check_input_4d(F.shape(img2), "img2", self.cls_name)
        _check_input_dtype(F.dtype(img1), 'img1', mstype.number_type, self.cls_name)
        P.SameTypeShape()(img1, img2)
        # Convert both images (and max_val) to float32 in a [0, 1] pixel range.
        dtype_max_val = _get_dtype_max(F.dtype(img1))
        max_val = F.scalar_cast(self.max_val, F.dtype(img1))
        max_val = _convert_img_dtype_to_float32(max_val, dtype_max_val)
        img1 = _convert_img_dtype_to_float32(img1, dtype_max_val)
        img2 = _convert_img_dtype_to_float32(img2, dtype_max_val)

        c1 = (self.k1 * max_val) ** 2
        c2 = (self.k2 * max_val) ** 2

        sim = ()
        mcs = ()

        # Per scale: collect the (ReLU-clamped) contrast-structure term, then
        # halve the resolution. After the loop, `sim` holds the full SSIM of
        # the last (coarsest) scale.
        for i in range(self.level):
            sim, cs = _compute_multi_channel_loss(c1, c2, img1, img2,
                                                  self.multi_convs_list[i], self.concat, self.reduce_mean)
            mcs += (self.relu(cs),)
            img1, img2 = _downsample(img1, img2, self.avg_pool)

        # Drop the last scale's cs: its SSIM value (`sim`) is used instead,
        # matching the l^alpha_M term of the MS-SSIM formula.
        mcs = mcs[0:-1:1]
        mcs_and_ssim = self.pack(mcs + (self.relu(sim),))
        # Raise each scale's term to its power factor and multiply across scales.
        mcs_and_ssim = self.pow(mcs_and_ssim, self.weight_tensor)
        ms_ssim = self.prod(mcs_and_ssim, -1)
        loss = self.reduce_mean(ms_ssim, -1)

        return loss
class PSNR(Cell):
    r"""
    Returns Peak Signal-to-Noise Ratio of two image batches.

    It produces a PSNR value for each image in batch.
    Assume inputs are :math:`I` and :math:`K`, both with shape :math:`h*w`.
    :math:`MAX` represents the dynamic range of pixel values.

    .. math::

        MSE&=\frac{1}{hw}\sum\limits_{i=0}^{h-1}\sum\limits_{j=0}^{w-1}[I(i,j)-K(i,j)]^2\\
        PSNR&=10*log_{10}(\frac{MAX^2}{MSE})

    Args:
        max_val (Union[int, float]): The dynamic range of the pixel values (255 for 8-bit grayscale images).
            The value must be greater than 0. Default: 1.0.

    Inputs:
        - **img1** (Tensor) - The first image batch with format 'NCHW'. It must be the same shape and dtype as img2.
        - **img2** (Tensor) - The second image batch with format 'NCHW'. It must be the same shape and dtype as img1.

    Outputs:
        Tensor, with dtype mindspore.float32. It is a 1-D tensor with shape N, where N is the batch num of img1.

    Examples:
        >>> net = nn.PSNR()
        >>> img1 = Tensor(np.random.random((1,3,16,16)))
        >>> img2 = Tensor(np.random.random((1,3,16,16)))
        >>> psnr = net(img1, img2)
        [7.8297315]
    """
    def __init__(self, max_val=1.0):
        super(PSNR, self).__init__()
        validator.check_value_type('max_val', max_val, [int, float], self.cls_name)
        validator.check_number('max_val', max_val, 0.0, Rel.GT, self.cls_name)
        self.max_val = max_val

    def construct(self, img1, img2):
        _check_input_4d(F.shape(img1), "img1", self.cls_name)
        _check_input_4d(F.shape(img2), "img2", self.cls_name)
        P.SameTypeShape()(img1, img2)
        # Convert both images (and max_val) to float32 in a [0, 1] pixel range.
        dtype_max_val = _get_dtype_max(F.dtype(img1))
        max_val = F.scalar_cast(self.max_val, F.dtype(img1))
        max_val = _convert_img_dtype_to_float32(max_val, dtype_max_val)
        img1 = _convert_img_dtype_to_float32(img1, dtype_max_val)
        img2 = _convert_img_dtype_to_float32(img2, dtype_max_val)

        # Mean squared error over C, H, W (one value per batch element).
        mse = P.ReduceMean()(F.square(img1 - img2), (-3, -2, -1))
        # PSNR = 10 * log10(MAX^2 / MSE); log10 built from natural log.
        psnr = 10 * P.Log()(F.square(max_val) / mse) / F.scalar_log(10.0)

        return psnr
  343. @constexpr
  344. def _raise_dims_rank_error(input_shape, param_name, func_name):
  345. """raise error if input is not 3d or 4d"""
  346. raise ValueError(f"{func_name} {param_name} should be 3d or 4d, but got shape {input_shape}")
  347. @constexpr
  348. def _get_bbox(rank, shape, size_h, size_w):
  349. """get bbox start and size for slice"""
  350. if rank == 3:
  351. c, h, w = shape
  352. else:
  353. n, c, h, w = shape
  354. bbox_h_start = int((float(h) - size_h) / 2)
  355. bbox_w_start = int((float(w) - size_w) / 2)
  356. bbox_h_size = h - bbox_h_start * 2
  357. bbox_w_size = w - bbox_w_start * 2
  358. if rank == 3:
  359. bbox_begin = (0, bbox_h_start, bbox_w_start)
  360. bbox_size = (c, bbox_h_size, bbox_w_size)
  361. else:
  362. bbox_begin = (0, 0, bbox_h_start, bbox_w_start)
  363. bbox_size = (n, c, bbox_h_size, bbox_w_size)
  364. return bbox_begin, bbox_size
  365. class CentralCrop(Cell):
  366. """
  367. Crop the centeral region of the images with the central_fraction.
  368. Args:
  369. central_fraction (float): Fraction of size to crop. It must be float and in range (0.0, 1.0].
  370. Inputs:
  371. - **image** (Tensor) - A 3-D tensor of shape [C, H, W], or a 4-D tensor of shape [N, C, H, W].
  372. Outputs:
  373. Tensor, 3-D or 4-D float tensor, according to the input.
  374. Examples:
  375. >>> net = nn.CentralCrop(central_fraction=0.5)
  376. >>> image = Tensor(np.random.random((4, 3, 4, 4)), mindspore.float32)
  377. >>> output = net(image)
  378. """
  379. def __init__(self, central_fraction):
  380. super(CentralCrop, self).__init__()
  381. validator.check_value_type("central_fraction", central_fraction, [float], self.cls_name)
  382. self.central_fraction = validator.check_float_range(central_fraction, 0.0, 1.0, Rel.INC_RIGHT,
  383. 'central_fraction', self.cls_name)
  384. self.slice = P.Slice()
  385. def construct(self, image):
  386. image_shape = F.shape(image)
  387. rank = len(image_shape)
  388. h, w = image_shape[-2], image_shape[-1]
  389. if not rank in (3, 4):
  390. return _raise_dims_rank_error(image_shape, "image", self.cls_name)
  391. if self.central_fraction == 1.0:
  392. return image
  393. size_h = self.central_fraction * h
  394. size_w = self.central_fraction * w
  395. bbox_begin, bbox_size = _get_bbox(rank, image_shape, size_h, size_w)
  396. image = self.slice(image, bbox_begin, bbox_size)
  397. return image