|
|
|
@@ -409,15 +409,15 @@ def _raise_dims_rank_error(input_shape, param_name, func_name): |
|
|
|
raise ValueError(f"{func_name} {param_name} should be 3d or 4d, but got shape {input_shape}") |
|
|
|
|
|
|
|
@constexpr |
|
|
|
def _get_bbox(rank, shape, size_h, size_w): |
|
|
|
def _get_bbox(rank, shape, central_fraction): |
|
|
|
"""get bbox start and size for slice""" |
|
|
|
if rank == 3: |
|
|
|
c, h, w = shape |
|
|
|
else: |
|
|
|
n, c, h, w = shape |
|
|
|
|
|
|
|
bbox_h_start = int((float(h) - size_h) / 2) |
|
|
|
bbox_w_start = int((float(w) - size_w) / 2) |
|
|
|
bbox_h_start = int((float(h) - np.float32(h * central_fraction)) / 2) |
|
|
|
bbox_w_start = int((float(w) - np.float32(w * central_fraction)) / 2) |
|
|
|
bbox_h_size = h - bbox_h_start * 2 |
|
|
|
bbox_w_size = w - bbox_w_start * 2 |
|
|
|
|
|
|
|
@@ -454,22 +454,19 @@ class CentralCrop(Cell): |
|
|
|
def __init__(self, central_fraction): |
|
|
|
super(CentralCrop, self).__init__() |
|
|
|
validator.check_value_type("central_fraction", central_fraction, [float], self.cls_name) |
|
|
|
self.central_fraction = validator.check_float_range(central_fraction, 0.0, 1.0, Rel.INC_RIGHT, |
|
|
|
'central_fraction', self.cls_name) |
|
|
|
validator.check_float_range(central_fraction, 0.0, 1.0, Rel.INC_RIGHT, 'central_fraction', self.cls_name) |
|
|
|
self.central_fraction = central_fraction |
|
|
|
self.slice = P.Slice() |
|
|
|
|
|
|
|
def construct(self, image): |
|
|
|
image_shape = F.shape(image) |
|
|
|
rank = len(image_shape) |
|
|
|
h, w = image_shape[-2], image_shape[-1] |
|
|
|
if not rank in (3, 4): |
|
|
|
return _raise_dims_rank_error(image_shape, "image", self.cls_name) |
|
|
|
if self.central_fraction == 1.0: |
|
|
|
return image |
|
|
|
|
|
|
|
size_h = self.central_fraction * h |
|
|
|
size_w = self.central_fraction * w |
|
|
|
bbox_begin, bbox_size = _get_bbox(rank, image_shape, size_h, size_w) |
|
|
|
bbox_begin, bbox_size = _get_bbox(rank, image_shape, self.central_fraction) |
|
|
|
image = self.slice(image, bbox_begin, bbox_size) |
|
|
|
|
|
|
|
return image |