| @@ -267,11 +267,9 @@ class PSNR(Cell): | |||||
| @constexpr | @constexpr | ||||
| def _check_input_3d_or_4d(input_shape, param_name, func_name): | |||||
| """check input 3d or 4d""" | |||||
| if len(input_shape) != 3 and len(input_shape) != 4: | |||||
| raise ValueError(f"{func_name} {param_name} should be 3d or 4d, but got shape {input_shape}") | |||||
| return True | |||||
| def _raise_dims_rank_error(input_shape, param_name, func_name): | |||||
| """raise error if input is not 3d or 4d""" | |||||
| raise ValueError(f"{func_name} {param_name} should be 3d or 4d, but got shape {input_shape}") | |||||
| @constexpr | @constexpr | ||||
| def _get_bbox(rank, shape, central_fraction): | def _get_bbox(rank, shape, central_fraction): | ||||
| @@ -281,6 +279,7 @@ def _get_bbox(rank, shape, central_fraction): | |||||
| else: | else: | ||||
| n, c, h, w = shape | n, c, h, w = shape | ||||
| central_fraction = central_fraction.asnumpy()[0] | |||||
| bbox_h_start = int((float(h) - float(h) * central_fraction) / 2) | bbox_h_start = int((float(h) - float(h) * central_fraction) / 2) | ||||
| bbox_w_start = int((float(w) - float(w) * central_fraction) / 2) | bbox_w_start = int((float(w) - float(w) * central_fraction) / 2) | ||||
| bbox_h_size = h - bbox_h_start * 2 | bbox_h_size = h - bbox_h_start * 2 | ||||
| @@ -319,16 +318,18 @@ class CentralCrop(Cell): | |||||
| validator.check_value_type("central_fraction", central_fraction, [float], self.cls_name) | validator.check_value_type("central_fraction", central_fraction, [float], self.cls_name) | ||||
| self.central_fraction = validator.check_number_range('central_fraction', central_fraction, | self.central_fraction = validator.check_number_range('central_fraction', central_fraction, | ||||
| 0.0, 1.0, Rel.INC_RIGHT, self.cls_name) | 0.0, 1.0, Rel.INC_RIGHT, self.cls_name) | ||||
| self.central_fraction_tensor = Tensor(np.array([central_fraction]).astype(np.float64)) | |||||
| self.slice = P.Slice() | self.slice = P.Slice() | ||||
| def construct(self, image): | def construct(self, image): | ||||
| image_shape = F.shape(image) | image_shape = F.shape(image) | ||||
| rank = len(image_shape) | rank = len(image_shape) | ||||
| _check_input_3d_or_4d(image_shape, "image", self.cls_name) | |||||
| if not rank in (3, 4): | |||||
| return _raise_dims_rank_error(image_shape, "image", self.cls_name) | |||||
| if self.central_fraction == 1.0: | if self.central_fraction == 1.0: | ||||
| return image | return image | ||||
| bbox_begin, bbox_size = _get_bbox(rank, image_shape, self.central_fraction) | |||||
| bbox_begin, bbox_size = _get_bbox(rank, image_shape, self.central_fraction_tensor) | |||||
| image = self.slice(image, bbox_begin, bbox_size) | image = self.slice(image, bbox_begin, bbox_size) | ||||
| return image | return image | ||||