Browse Source

fix bug of nn.CentralCrop.

tags/v1.1.0
liuxiao93 5 years ago
parent
commit
cf7c9a65f9
2 changed files with 7 additions and 10 deletions
  1. +6
    -9
      mindspore/nn/layer/image.py
  2. +1
    -1
      mindspore/nn/layer/pooling.py

+ 6
- 9
mindspore/nn/layer/image.py View File

@@ -409,15 +409,15 @@ def _raise_dims_rank_error(input_shape, param_name, func_name):
raise ValueError(f"{func_name} {param_name} should be 3d or 4d, but got shape {input_shape}") raise ValueError(f"{func_name} {param_name} should be 3d or 4d, but got shape {input_shape}")


@constexpr @constexpr
def _get_bbox(rank, shape, size_h, size_w):
def _get_bbox(rank, shape, central_fraction):
"""get bbox start and size for slice""" """get bbox start and size for slice"""
if rank == 3: if rank == 3:
c, h, w = shape c, h, w = shape
else: else:
n, c, h, w = shape n, c, h, w = shape


bbox_h_start = int((float(h) - size_h) / 2)
bbox_w_start = int((float(w) - size_w) / 2)
bbox_h_start = int((float(h) - np.float32(h * central_fraction)) / 2)
bbox_w_start = int((float(w) - np.float32(w * central_fraction)) / 2)
bbox_h_size = h - bbox_h_start * 2 bbox_h_size = h - bbox_h_start * 2
bbox_w_size = w - bbox_w_start * 2 bbox_w_size = w - bbox_w_start * 2


@@ -454,22 +454,19 @@ class CentralCrop(Cell):
def __init__(self, central_fraction): def __init__(self, central_fraction):
super(CentralCrop, self).__init__() super(CentralCrop, self).__init__()
validator.check_value_type("central_fraction", central_fraction, [float], self.cls_name) validator.check_value_type("central_fraction", central_fraction, [float], self.cls_name)
self.central_fraction = validator.check_float_range(central_fraction, 0.0, 1.0, Rel.INC_RIGHT,
'central_fraction', self.cls_name)
validator.check_float_range(central_fraction, 0.0, 1.0, Rel.INC_RIGHT, 'central_fraction', self.cls_name)
self.central_fraction = central_fraction
self.slice = P.Slice() self.slice = P.Slice()


def construct(self, image): def construct(self, image):
image_shape = F.shape(image) image_shape = F.shape(image)
rank = len(image_shape) rank = len(image_shape)
h, w = image_shape[-2], image_shape[-1]
if not rank in (3, 4): if not rank in (3, 4):
return _raise_dims_rank_error(image_shape, "image", self.cls_name) return _raise_dims_rank_error(image_shape, "image", self.cls_name)
if self.central_fraction == 1.0: if self.central_fraction == 1.0:
return image return image


size_h = self.central_fraction * h
size_w = self.central_fraction * w
bbox_begin, bbox_size = _get_bbox(rank, image_shape, size_h, size_w)
bbox_begin, bbox_size = _get_bbox(rank, image_shape, self.central_fraction)
image = self.slice(image, bbox_begin, bbox_size) image = self.slice(image, bbox_begin, bbox_size)


return image return image

+ 1
- 1
mindspore/nn/layer/pooling.py View File

@@ -182,7 +182,7 @@ class MaxPool1d(_PoolNd):
Tensor of shape :math:`(N, C, L_{out}))`. Tensor of shape :math:`(N, C, L_{out}))`.


Examples: Examples:
>>> max_pool = nn.MaxPool1d(kernel_size=3, strides=1)
>>> max_pool = nn.MaxPool1d(kernel_size=3, stride=1)
>>> x = Tensor(np.random.randint(0, 10, [1, 2, 4]), mindspore.float32) >>> x = Tensor(np.random.randint(0, 10, [1, 2, 4]), mindspore.float32)
>>> output = max_pool(x) >>> output = max_pool(x)
>>> output.shape >>> output.shape


Loading…
Cancel
Save