|
|
|
@@ -23,7 +23,7 @@ from mindspore._checkparam import Validator as validator |
|
|
|
from mindspore._checkparam import Rel |
|
|
|
from ..cell import Cell |
|
|
|
|
|
|
|
__all__ = ['ImageGradients', 'SSIM', 'PSNR'] |
|
|
|
__all__ = ['ImageGradients', 'SSIM', 'PSNR', 'CentralCrop'] |
|
|
|
|
|
|
|
class ImageGradients(Cell): |
|
|
|
r""" |
|
|
|
@@ -259,3 +259,71 @@ class PSNR(Cell): |
|
|
|
psnr = 10 * P.Log()(F.square(max_val) / mse) / F.scalar_log(10.0) |
|
|
|
|
|
|
|
return psnr |
|
|
|
|
|
|
|
|
|
|
|
@constexpr |
|
|
|
def _check_input_3d_or_4d(input_shape, param_name, func_name): |
|
|
|
"""check input 3d or 4d""" |
|
|
|
if len(input_shape) != 3 and len(input_shape) != 4: |
|
|
|
raise ValueError(f"{func_name} {param_name} should be 3d or 4d, but got shape {input_shape}") |
|
|
|
return True |
|
|
|
|
|
|
|
@constexpr |
|
|
|
def _get_bbox(rank, shape, central_fraction): |
|
|
|
"""get bbox start and size for slice""" |
|
|
|
if rank == 3: |
|
|
|
c, h, w = shape |
|
|
|
else: |
|
|
|
n, c, h, w = shape |
|
|
|
|
|
|
|
bbox_h_start = int((float(h) - float(h) * central_fraction) / 2) |
|
|
|
bbox_w_start = int((float(w) - float(w) * central_fraction) / 2) |
|
|
|
bbox_h_size = h - bbox_h_start * 2 |
|
|
|
bbox_w_size = w - bbox_w_start * 2 |
|
|
|
|
|
|
|
if rank == 3: |
|
|
|
bbox_begin = (0, bbox_h_start, bbox_w_start) |
|
|
|
bbox_size = (c, bbox_h_size, bbox_w_size) |
|
|
|
else: |
|
|
|
bbox_begin = (0, 0, bbox_h_start, bbox_w_start) |
|
|
|
bbox_size = (n, c, bbox_h_size, bbox_w_size) |
|
|
|
|
|
|
|
return bbox_begin, bbox_size |
|
|
|
|
|
|
|
class CentralCrop(Cell): |
|
|
|
""" |
|
|
|
Crop the centeral region of the images with the central_fraction. |
|
|
|
|
|
|
|
Args: |
|
|
|
central_fraction (float): Fraction of size to crop. It must be float and in range (0.0, 1.0]. |
|
|
|
|
|
|
|
Inputs: |
|
|
|
- **image** (Tensor) - A 3-D tensor of shape [C, H, W], or a 4-D tensor of shape [N, C, H, W]. |
|
|
|
|
|
|
|
Outputs: |
|
|
|
Tensor, 3-D or 4-D float tensor, according to the input. |
|
|
|
|
|
|
|
Examples: |
|
|
|
>>> net = nn.CentralCrop(central_fraction=0.5) |
|
|
|
>>> image = Tensor(np.random.random((4, 3, 4, 4)), mindspore.float32) |
|
|
|
>>> output = net(image) |
|
|
|
""" |
|
|
|
|
|
|
|
def __init__(self, central_fraction): |
|
|
|
super(CentralCrop, self).__init__() |
|
|
|
validator.check_value_type("central_fraction", central_fraction, [float], self.cls_name) |
|
|
|
self.central_fraction = validator.check_number_range('central_fraction', central_fraction, |
|
|
|
0.0, 1.0, Rel.INC_RIGHT, self.cls_name) |
|
|
|
self.slice = P.Slice() |
|
|
|
|
|
|
|
def construct(self, image): |
|
|
|
image_shape = F.shape(image) |
|
|
|
rank = len(image_shape) |
|
|
|
_check_input_3d_or_4d(image_shape, "image", self.cls_name) |
|
|
|
if self.central_fraction == 1.0: |
|
|
|
return image |
|
|
|
|
|
|
|
bbox_begin, bbox_size = _get_bbox(rank, image_shape, self.central_fraction) |
|
|
|
image = self.slice(image, bbox_begin, bbox_size) |
|
|
|
|
|
|
|
return image |