From cf7c9a65f9bee7a42576b495890f3eb958bac7e4 Mon Sep 17 00:00:00 2001 From: liuxiao93 Date: Fri, 13 Nov 2020 10:56:43 +0800 Subject: [PATCH] fix bug of nn.CentralCrop. --- mindspore/nn/layer/image.py | 15 ++++++--------- mindspore/nn/layer/pooling.py | 2 +- 2 files changed, 7 insertions(+), 10 deletions(-) diff --git a/mindspore/nn/layer/image.py b/mindspore/nn/layer/image.py index ba7d7d6f4a..a8cacfd960 100644 --- a/mindspore/nn/layer/image.py +++ b/mindspore/nn/layer/image.py @@ -409,15 +409,15 @@ def _raise_dims_rank_error(input_shape, param_name, func_name): raise ValueError(f"{func_name} {param_name} should be 3d or 4d, but got shape {input_shape}") @constexpr -def _get_bbox(rank, shape, size_h, size_w): +def _get_bbox(rank, shape, central_fraction): """get bbox start and size for slice""" if rank == 3: c, h, w = shape else: n, c, h, w = shape - bbox_h_start = int((float(h) - size_h) / 2) - bbox_w_start = int((float(w) - size_w) / 2) + bbox_h_start = int((float(h) - np.float32(h * central_fraction)) / 2) + bbox_w_start = int((float(w) - np.float32(w * central_fraction)) / 2) bbox_h_size = h - bbox_h_start * 2 bbox_w_size = w - bbox_w_start * 2 @@ -454,22 +454,19 @@ class CentralCrop(Cell): def __init__(self, central_fraction): super(CentralCrop, self).__init__() validator.check_value_type("central_fraction", central_fraction, [float], self.cls_name) - self.central_fraction = validator.check_float_range(central_fraction, 0.0, 1.0, Rel.INC_RIGHT, - 'central_fraction', self.cls_name) + validator.check_float_range(central_fraction, 0.0, 1.0, Rel.INC_RIGHT, 'central_fraction', self.cls_name) + self.central_fraction = central_fraction self.slice = P.Slice() def construct(self, image): image_shape = F.shape(image) rank = len(image_shape) - h, w = image_shape[-2], image_shape[-1] if not rank in (3, 4): return _raise_dims_rank_error(image_shape, "image", self.cls_name) if self.central_fraction == 1.0: return image - size_h = self.central_fraction * h - size_w = self.central_fraction * w - bbox_begin, bbox_size = _get_bbox(rank, image_shape, size_h, size_w) + bbox_begin, bbox_size = _get_bbox(rank, image_shape, self.central_fraction) image = self.slice(image, bbox_begin, bbox_size) return image diff --git a/mindspore/nn/layer/pooling.py b/mindspore/nn/layer/pooling.py index 7cab152545..9beb5fc80b 100644 --- a/mindspore/nn/layer/pooling.py +++ b/mindspore/nn/layer/pooling.py @@ -182,7 +182,7 @@ class MaxPool1d(_PoolNd): Tensor of shape :math:`(N, C, L_{out}))`. Examples: - >>> max_pool = nn.MaxPool1d(kernel_size=3, strides=1) + >>> max_pool = nn.MaxPool1d(kernel_size=3, stride=1) >>> x = Tensor(np.random.randint(0, 10, [1, 2, 4]), mindspore.float32) >>> output = max_pool(x) >>> output.shape