Browse Source

fix bug of nn.CentralCrop.

tags/v1.1.0
liuxiao93 5 years ago
parent
commit
cf7c9a65f9
2 changed files with 7 additions and 10 deletions
  1. +6
    -9
      mindspore/nn/layer/image.py
  2. +1
    -1
      mindspore/nn/layer/pooling.py

+ 6
- 9
mindspore/nn/layer/image.py View File

@@ -409,15 +409,15 @@ def _raise_dims_rank_error(input_shape, param_name, func_name):
raise ValueError(f"{func_name} {param_name} should be 3d or 4d, but got shape {input_shape}")

@constexpr
def _get_bbox(rank, shape, size_h, size_w):
def _get_bbox(rank, shape, central_fraction):
"""get bbox start and size for slice"""
if rank == 3:
c, h, w = shape
else:
n, c, h, w = shape

bbox_h_start = int((float(h) - size_h) / 2)
bbox_w_start = int((float(w) - size_w) / 2)
bbox_h_start = int((float(h) - np.float32(h * central_fraction)) / 2)
bbox_w_start = int((float(w) - np.float32(w * central_fraction)) / 2)
bbox_h_size = h - bbox_h_start * 2
bbox_w_size = w - bbox_w_start * 2

@@ -454,22 +454,19 @@ class CentralCrop(Cell):
def __init__(self, central_fraction):
super(CentralCrop, self).__init__()
validator.check_value_type("central_fraction", central_fraction, [float], self.cls_name)
self.central_fraction = validator.check_float_range(central_fraction, 0.0, 1.0, Rel.INC_RIGHT,
'central_fraction', self.cls_name)
validator.check_float_range(central_fraction, 0.0, 1.0, Rel.INC_RIGHT, 'central_fraction', self.cls_name)
self.central_fraction = central_fraction
self.slice = P.Slice()

def construct(self, image):
image_shape = F.shape(image)
rank = len(image_shape)
h, w = image_shape[-2], image_shape[-1]
if not rank in (3, 4):
return _raise_dims_rank_error(image_shape, "image", self.cls_name)
if self.central_fraction == 1.0:
return image

size_h = self.central_fraction * h
size_w = self.central_fraction * w
bbox_begin, bbox_size = _get_bbox(rank, image_shape, size_h, size_w)
bbox_begin, bbox_size = _get_bbox(rank, image_shape, self.central_fraction)
image = self.slice(image, bbox_begin, bbox_size)

return image

+ 1
- 1
mindspore/nn/layer/pooling.py View File

@@ -182,7 +182,7 @@ class MaxPool1d(_PoolNd):
Tensor of shape :math:`(N, C, L_{out}))`.

Examples:
>>> max_pool = nn.MaxPool1d(kernel_size=3, strides=1)
>>> max_pool = nn.MaxPool1d(kernel_size=3, stride=1)
>>> x = Tensor(np.random.randint(0, 10, [1, 2, 4]), mindspore.float32)
>>> output = max_pool(x)
>>> output.shape


Loading…
Cancel
Save