Browse Source

!211 add CropAndResize op

Merge pull request !211 from xutianchun/crop_and_resize
tags/v0.6.0-beta
mindspore-ci-bot Gitee 5 years ago
parent
commit
bc13d6f7f8
4 changed files with 199 additions and 1 deletions
  1. +1
    -0
      mindspore/ops/_op_impl/aicpu/__init__.py
  2. +69
    -0
      mindspore/ops/_op_impl/aicpu/crop_and_resize.py
  3. +3
    -1
      mindspore/ops/operations/__init__.py
  4. +126
    -0
      mindspore/ops/operations/image_ops.py

+ 1
- 0
mindspore/ops/_op_impl/aicpu/__init__.py View File

@@ -29,3 +29,4 @@ from .rnnt_loss import _rnnt_loss_aicpu
from .random_categorical import _random_categorical_aicpu
from .reverse_sequence import _reverse_sequence_aicpu
from .pack import _pack_aicpu
from .crop_and_resize import _crop_and_resize_aicpu

+ 69
- 0
mindspore/ops/_op_impl/aicpu/crop_and_resize.py View File

@@ -0,0 +1,69 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""CropAndResize op"""
from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType
# CropAndResize AiCPU registration. The image input accepts many dtypes while
# boxes are always float32, box_index/crop_size always int32, and the output
# float32 — so the dtype_format table is generated from the image dtype lists
# instead of being written out long-hand.
_IMAGE_DTYPES_DEFAULT = (
    DataType.I8_Default, DataType.I16_Default, DataType.I32_Default,
    DataType.I64_Default, DataType.F16_Default, DataType.F32_Default,
    DataType.F64_Default, DataType.U8_Default, DataType.U16_Default,
)
_IMAGE_DTYPES_NHWC = (
    DataType.I8_NHWC, DataType.I16_NHWC, DataType.I32_NHWC,
    DataType.I64_NHWC, DataType.F16_NHWC, DataType.F32_NHWC,
    DataType.F64_NHWC, DataType.U8_NHWC, DataType.U16_NHWC,
)


def _build_crop_and_resize_op_info():
    """Assemble the op-info registration object for CropAndResize."""
    reg = AiCPURegOp("CropAndResize") \
        .fusion_type("OPAQUE") \
        .input(0, "image", "required") \
        .input(1, "boxes", "required") \
        .input(2, "box_index", "required") \
        .input(3, "crop_size", "required") \
        .output(0, "y", "required") \
        .attr("method", "str") \
        .attr("extrapolation_value", "float")
    # Default-format variants, one per supported image dtype.
    for image_dtype in _IMAGE_DTYPES_DEFAULT:
        reg = reg.dtype_format(image_dtype, DataType.F32_Default, DataType.I32_Default,
                               DataType.I32_Default, DataType.F32_Default)
    # NHWC-format variants, same dtype coverage.
    for image_dtype in _IMAGE_DTYPES_NHWC:
        reg = reg.dtype_format(image_dtype, DataType.F32_NHWC, DataType.I32_NHWC,
                               DataType.I32_NHWC, DataType.F32_NHWC)
    return reg.get_op_info()


crop_and_resize_op_info = _build_crop_and_resize_op_info()


@op_info_register(crop_and_resize_op_info)
def _crop_and_resize_aicpu():
    """Register the CropAndResize operator with the AiCPU backend."""

+ 3
- 1
mindspore/ops/operations/__init__.py View File

@@ -19,6 +19,7 @@ Primitive operator classes.
A collection of operators to build nerual networks or computing functions.
"""

from .image_ops import (CropAndResize)
from .array_ops import (Argmax, Argmin, Cast, Concat, Pack, Unpack,
Diag, DiagPart, DType, ExpandDims, Eye,
Fill, GatherNd, GatherV2, InvertPermutation,
@@ -287,7 +288,8 @@ __all__ = [
"BesselI1e",
"Atan",
"Atanh",
"BasicLSTMCell"
"BasicLSTMCell",
"CropAndResize"
]

__all__.extend(_quant_ops.__all__)


+ 126
- 0
mindspore/ops/operations/image_ops.py View File

@@ -0,0 +1,126 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""image_ops"""
from ..._checkparam import Validator as validator
from ..._checkparam import Rel
from ...common import dtype as mstype
from ..primitive import PrimitiveWithInfer, prim_attr_register


class CropAndResize(PrimitiveWithInfer):
    """
    Extracts crops from the input image tensor and resizes them.

    Note:
        In case that the output shape depends on crop_size, the crop_size should be constant.

    Args:
        method (str): An optional string specifying the sampling method for resizing.
            It can be either "bilinear" or "nearest" and defaults to "bilinear".
        extrapolation_value (float): An optional float defaults to 0. Value used for extrapolation, when applicable.

    Inputs:
        - **x** (Tensor) - The input image must be a 4-D tensor of shape [batch, image_height, image_width, depth].
          Types allowed: int8, int16, int32, int64, float16, float32, float64, uint8, uint16.
        - **boxes** (Tensor) - A 2-D tensor of shape [num_boxes, 4].
          The i-th row of the tensor specifies the coordinates of a box in the box_ind[i] image
          and is specified in normalized coordinates [y1, x1, y2, x2]. A normalized coordinate value of y is mapped to
          the image coordinate at y * (image_height - 1), so as the [0, 1] interval of normalized image height is
          mapped to [0, image_height - 1] in image height coordinates. We do allow y1 > y2, in which case the sampled
          crop is an up-down flipped version of the original image. The width dimension is treated similarly.
          Normalized coordinates outside the [0, 1] range are allowed, in which case we use extrapolation_value to
          extrapolate the input image values. Types allowed: float32.
        - **box_index** (Tensor) - A 1-D tensor of shape [num_boxes] with int32 values in [0, batch).
          The value of box_ind[i] specifies the image that the i-th box refers to. Types allowed: int32.
        - **crop_size** (Tensor) - Only constant value is allowed. Types allowed: int32.
          A 1-D tensor of 2 elements, size = [crop_height, crop_width].
          All cropped image patches are resized to this size. The aspect ratio of the image content is not preserved.
          Both crop_height and crop_width need to be positive.

    Outputs:
        A 4-D tensor of shape [num_boxes, crop_height, crop_width, depth] with type: float32.

    Examples:
        >>> class CropAndResizeNet(nn.Cell):
        >>>     def __init__(self, crop_size):
        >>>         super(CropAndResizeNet, self).__init__()
        >>>         self.crop_and_resize = P.CropAndResize()
        >>>         self.crop_size = crop_size
        >>>     @ms_function
        >>>     def construct(self, x, boxes, box_index):
        >>>         return self.crop_and_resize(x, boxes, box_index, self.crop_size)
        >>>
        >>> BATCH_SIZE = 1
        >>> NUM_BOXES = 5
        >>> IMAGE_HEIGHT = 256
        >>> IMAGE_WIDTH = 256
        >>> CHANNELS = 3
        >>> image = np.random.normal(size=[BATCH_SIZE, IMAGE_HEIGHT, IMAGE_WIDTH, CHANNELS]).astype(np.float32)
        >>> boxes = np.random.uniform(shape=[NUM_BOXES, 4]).astype(np.float32)
        >>> box_index = np.random.uniform(shape=[NUM_BOXES], low=0, high=BATCH_SIZE).astype(np.int32)
        >>> crop_size = np.array([24, 24]).astype(np.int32)
        >>> crop_and_resize = CropAndResizeNet(crop_size=Tensor(crop_size))
        >>> output = crop_and_resize(Tensor(image), Tensor(boxes), Tensor(box_index))
        >>> print(output.asnumpy())
    """

    @prim_attr_register
    def __init__(self, method="bilinear", extrapolation_value=0.0):
        """init CropAndResize"""
        self.init_prim_io_names(inputs=['x', 'boxes', 'box_index', 'crop_size'], outputs=['y'])
        validator.check_value_type("method", method, [str], self.name)
        validator.check_string("method", method, ["bilinear", "nearest"], self.name)
        self.method = method
        validator.check_value_type("extrapolation_value", extrapolation_value, [float], self.name)
        self.extrapolation_value = extrapolation_value

    def __infer__(self, x, boxes, box_index, crop_size):
        """Infer output shape/dtype; crop_size must be a constant so the shape is static."""
        # get shape
        x_shape = list(x['shape'])
        boxes_shape = list(boxes['shape'])
        box_index_shape = list(box_index['shape'])
        crop_size_shape = list(crop_size['shape'])
        # get value: the output shape depends on crop_size, so it cannot be a runtime tensor
        if crop_size['value'] is None:
            raise ValueError(f"For {self.name}, crop_size must be const.")
        crop_size_value = crop_size['value'].asnumpy()
        # get dtype
        x_dtype = x['dtype']
        boxes_dtype = boxes['dtype']
        box_index_dtype = box_index['dtype']
        crop_size_dtype = crop_size['dtype']
        # check dtype
        validator.check_tensor_type_same({"x": x_dtype},
                                         [mstype.int8, mstype.int16, mstype.int32, mstype.int64, mstype.float16,
                                          mstype.float32, mstype.float64, mstype.uint8, mstype.uint16], self.name)
        validator.check_tensor_type_same({"boxes": boxes_dtype}, [mstype.float32], self.name)
        validator.check_tensor_type_same({"box_index": box_index_dtype}, [mstype.int32], self.name)
        validator.check_tensor_type_same({"crop_size": crop_size_dtype}, [mstype.int32], self.name)
        # check input shape rank
        validator.check("x rank", len(x_shape), "expected", 4, Rel.EQ, self.name)
        validator.check("boxes rank", len(boxes_shape), "expected", 2, Rel.EQ, self.name)
        validator.check("box_index rank", len(box_index_shape), "expected", 1, Rel.EQ, self.name)
        validator.check("crop_size rank", len(crop_size_shape), "expected", 1, Rel.EQ, self.name)
        # crop_size must hold exactly [crop_height, crop_width]
        validator.check("crop_size dim_0", crop_size_shape[0], "expected", 2, Rel.EQ, self.name)
        # one box_index entry per box; each box is [y1, x1, y2, x2]
        validator.check("boxes dim_0", boxes_shape[0], "box_index dim_0", box_index_shape[0], Rel.EQ, self.name)
        validator.check("boxes dim_1", boxes_shape[1], "expected", 4, Rel.EQ, self.name)

        num_boxes = boxes_shape[0]
        # cast numpy scalars to Python ints so the shape tuple holds plain ints
        crop_height = int(crop_size_value[0])
        crop_width = int(crop_size_value[1])
        # the docstring requires both crop dimensions to be positive
        validator.check("crop_height", crop_height, "minimum", 0, Rel.GT, self.name)
        validator.check("crop_width", crop_width, "minimum", 0, Rel.GT, self.name)
        depth = x_shape[3]
        return {'shape': (num_boxes, crop_height, crop_width, depth),
                'dtype': mstype.float32,
                'value': None}

Loading…
Cancel
Save