@@ -682,3 +682,14 @@ def get_bprop_broadcast_to(self):
        dx = reshape(reduced_grad, x_shape)
        return (dx,)

    return bprop


@bprop_getters.register(P.ReverseSequence)
def get_bprop_reverse_sequence(self):
    """Generate bprop for ReverseSequence"""
    reverse_sequence_grad = P.ReverseSequence(batch_dim=self.batch_dim_, seq_dim=self.seq_dim_)

    def bprop(x, seq_lengths, out, dout):
        dx = reverse_sequence_grad(dout, seq_lengths)
        return dx, zeros_like(seq_lengths)

    return bprop
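The bprop above simply reverses the incoming gradient with the same seq_lengths and returns zeros for the length vector. A minimal sketch of how this registration gets exercised, assuming the GradOperation composite API of this MindSpore release; the network names and tensor values are illustrative and not part of the patch:

import numpy as np
import mindspore.nn as nn
from mindspore import Tensor, context
from mindspore.ops import operations as P
from mindspore.ops import composite as C

context.set_context(mode=context.GRAPH_MODE)


class ReverseSequenceNet(nn.Cell):
    def __init__(self):
        super(ReverseSequenceNet, self).__init__()
        self.reverse_sequence = P.ReverseSequence(seq_dim=1, batch_dim=0)

    def construct(self, x, seq_lengths):
        return self.reverse_sequence(x, seq_lengths)


class GradNet(nn.Cell):
    def __init__(self, net):
        super(GradNet, self).__init__()
        # Assumption: GradOperation in this release takes a name plus get_all.
        self.grad = C.GradOperation('grad', get_all=True)
        self.net = net

    def construct(self, x, seq_lengths):
        return self.grad(self.net)(x, seq_lengths)


x = Tensor(np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]).astype(np.float32))
seq_lengths = Tensor(np.array([1, 2, 3]).astype(np.int32))
# dx is dout reversed with the same seq_lengths; the seq_lengths input gets zeros_like.
dx, dseq = GradNet(ReverseSequenceNet())(x, seq_lengths)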
@@ -26,3 +26,6 @@ from .expand_dims import _expand_dims_aicpu
from .random_choice_with_mask import _random_choice_with_mask_aicpu
from .pack import _pack_aicpu
from .normal import _normal_aicpu
from .ctcloss import _ctcloss_aicpu
from .reverse_sequence import _reverse_sequence_aicpu
from .crop_and_resize import _crop_and_resize_aicpu
@@ -0,0 +1,69 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""CropAndResize op"""
from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType

crop_and_resize_op_info = AiCPURegOp("CropAndResize") \
    .fusion_type("OPAQUE") \
    .input(0, "image", "required") \
    .input(1, "boxes", "required") \
    .input(2, "box_index", "required") \
    .input(3, "crop_size", "required") \
    .output(0, "y", "required") \
    .attr("method", "str") \
    .attr("extrapolation_value", "float") \
    .dtype_format(DataType.I8_Default, DataType.F32_Default, DataType.I32_Default, DataType.I32_Default,
                  DataType.F32_Default) \
    .dtype_format(DataType.I16_Default, DataType.F32_Default, DataType.I32_Default, DataType.I32_Default,
                  DataType.F32_Default) \
    .dtype_format(DataType.I32_Default, DataType.F32_Default, DataType.I32_Default, DataType.I32_Default,
                  DataType.F32_Default) \
    .dtype_format(DataType.I64_Default, DataType.F32_Default, DataType.I32_Default, DataType.I32_Default,
                  DataType.F32_Default) \
    .dtype_format(DataType.F16_Default, DataType.F32_Default, DataType.I32_Default, DataType.I32_Default,
                  DataType.F32_Default) \
    .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.I32_Default, DataType.I32_Default,
                  DataType.F32_Default) \
    .dtype_format(DataType.F64_Default, DataType.F32_Default, DataType.I32_Default, DataType.I32_Default,
                  DataType.F32_Default) \
    .dtype_format(DataType.U8_Default, DataType.F32_Default, DataType.I32_Default, DataType.I32_Default,
                  DataType.F32_Default) \
    .dtype_format(DataType.U16_Default, DataType.F32_Default, DataType.I32_Default, DataType.I32_Default,
                  DataType.F32_Default) \
    .dtype_format(DataType.I8_NHWC, DataType.F32_NHWC, DataType.I32_NHWC, DataType.I32_NHWC,
                  DataType.F32_NHWC) \
    .dtype_format(DataType.I16_NHWC, DataType.F32_NHWC, DataType.I32_NHWC, DataType.I32_NHWC,
                  DataType.F32_NHWC) \
    .dtype_format(DataType.I32_NHWC, DataType.F32_NHWC, DataType.I32_NHWC, DataType.I32_NHWC,
                  DataType.F32_NHWC) \
    .dtype_format(DataType.I64_NHWC, DataType.F32_NHWC, DataType.I32_NHWC, DataType.I32_NHWC,
                  DataType.F32_NHWC) \
    .dtype_format(DataType.F16_NHWC, DataType.F32_NHWC, DataType.I32_NHWC, DataType.I32_NHWC,
                  DataType.F32_NHWC) \
    .dtype_format(DataType.F32_NHWC, DataType.F32_NHWC, DataType.I32_NHWC, DataType.I32_NHWC,
                  DataType.F32_NHWC) \
    .dtype_format(DataType.F64_NHWC, DataType.F32_NHWC, DataType.I32_NHWC, DataType.I32_NHWC,
                  DataType.F32_NHWC) \
    .dtype_format(DataType.U8_NHWC, DataType.F32_NHWC, DataType.I32_NHWC, DataType.I32_NHWC,
                  DataType.F32_NHWC) \
    .dtype_format(DataType.U16_NHWC, DataType.F32_NHWC, DataType.I32_NHWC, DataType.I32_NHWC,
                  DataType.F32_NHWC) \
    .get_op_info()


@op_info_register(crop_and_resize_op_info)
def _crop_and_resize_aicpu():
    """CropAndResize AiCPU register"""
    return
@@ -0,0 +1,42 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""CTCLoss op"""
from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType

ctcloss_op_info = AiCPURegOp("CTCLoss") \
    .fusion_type("OPAQUE") \
    .input(0, "inputs", "required") \
    .input(1, "labels_indices", "required") \
    .input(2, "labels_values", "required") \
    .input(3, "sequence_length", "required") \
    .output(0, "loss", "required") \
    .output(1, "gradient", "required") \
    .attr("preprocess_collapse_repeated", "bool") \
    .attr("ctc_merge_repeated", "bool") \
    .attr("ignore_longer_outputs_than_inputs", "bool") \
    .dtype_format(DataType.F32_Default, DataType.I64_Default, DataType.I32_Default, DataType.I32_Default,
                  DataType.F32_Default, DataType.F32_Default) \
    .dtype_format(DataType.F64_Default, DataType.I64_Default, DataType.I32_Default, DataType.I32_Default,
                  DataType.F64_Default, DataType.F64_Default) \
    .dtype_format(DataType.F32_NCHW, DataType.I64_NCHW, DataType.I32_NCHW, DataType.I32_NCHW,
                  DataType.F32_NCHW, DataType.F32_NCHW) \
    .dtype_format(DataType.F64_NCHW, DataType.I64_NCHW, DataType.I32_NCHW, DataType.I32_NCHW,
                  DataType.F64_NCHW, DataType.F64_NCHW) \
    .get_op_info()


@op_info_register(ctcloss_op_info)
def _ctcloss_aicpu():
    """CTCLoss AiCPU register"""
    return
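The labels registered above follow the sparse (indices, values) layout familiar from TensorFlow's CTC loss. A small illustrative helper, not part of the patch and with a made-up name, shows how dense per-sample label lists map onto the labels_indices / labels_values inputs:

import numpy as np

def to_sparse_labels(dense_labels):
    """Encode variable-length label sequences as (labels_indices, labels_values)."""
    indices, values = [], []
    for batch_idx, seq in enumerate(dense_labels):
        for pos, label in enumerate(seq):
            indices.append([batch_idx, pos])  # [batch, position within the sequence]
            values.append(label)
    return np.array(indices, np.int64), np.array(values, np.int32)

labels_indices, labels_values = to_sparse_labels([[2], [2]])
# labels_indices -> [[0, 0], [1, 0]], labels_values -> [2, 2],
# i.e. exactly the tensors used by the CTCLoss test further down in this patch.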
@@ -0,0 +1,78 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""ReverseSequence op"""
from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType

reverse_sequence_op_info = AiCPURegOp("ReverseSequence") \
    .fusion_type("OPAQUE") \
    .input(0, "x", "required") \
    .input(1, "seq_lengths", "required") \
    .output(0, "y", "required") \
    .attr("seq_dim", "int") \
    .attr("batch_dim", "int") \
    .dtype_format(DataType.BOOL_Default, DataType.I32_Default, DataType.BOOL_Default) \
    .dtype_format(DataType.I8_Default, DataType.I32_Default, DataType.I8_Default) \
    .dtype_format(DataType.I16_Default, DataType.I32_Default, DataType.I16_Default) \
    .dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \
    .dtype_format(DataType.I64_Default, DataType.I32_Default, DataType.I64_Default) \
    .dtype_format(DataType.U8_Default, DataType.I32_Default, DataType.U8_Default) \
    .dtype_format(DataType.U16_Default, DataType.I32_Default, DataType.U16_Default) \
    .dtype_format(DataType.U32_Default, DataType.I32_Default, DataType.U32_Default) \
    .dtype_format(DataType.U64_Default, DataType.I32_Default, DataType.U64_Default) \
    .dtype_format(DataType.F16_Default, DataType.I32_Default, DataType.F16_Default) \
    .dtype_format(DataType.F32_Default, DataType.I32_Default, DataType.F32_Default) \
    .dtype_format(DataType.F64_Default, DataType.I32_Default, DataType.F64_Default) \
    .dtype_format(DataType.BOOL_NCHW, DataType.I32_NCHW, DataType.BOOL_NCHW) \
    .dtype_format(DataType.I8_NCHW, DataType.I32_NCHW, DataType.I8_NCHW) \
    .dtype_format(DataType.I16_NCHW, DataType.I32_NCHW, DataType.I16_NCHW) \
    .dtype_format(DataType.I32_NCHW, DataType.I32_NCHW, DataType.I32_NCHW) \
    .dtype_format(DataType.I64_NCHW, DataType.I32_NCHW, DataType.I64_NCHW) \
    .dtype_format(DataType.U8_NCHW, DataType.I32_NCHW, DataType.U8_NCHW) \
    .dtype_format(DataType.U16_NCHW, DataType.I32_NCHW, DataType.U16_NCHW) \
    .dtype_format(DataType.U32_NCHW, DataType.I32_NCHW, DataType.U32_NCHW) \
    .dtype_format(DataType.U64_NCHW, DataType.I32_NCHW, DataType.U64_NCHW) \
    .dtype_format(DataType.F16_NCHW, DataType.I32_NCHW, DataType.F16_NCHW) \
    .dtype_format(DataType.F32_NCHW, DataType.I32_NCHW, DataType.F32_NCHW) \
    .dtype_format(DataType.F64_NCHW, DataType.I32_NCHW, DataType.F64_NCHW) \
    .dtype_format(DataType.BOOL_Default, DataType.I64_Default, DataType.BOOL_Default) \
    .dtype_format(DataType.I8_Default, DataType.I64_Default, DataType.I8_Default) \
    .dtype_format(DataType.I16_Default, DataType.I64_Default, DataType.I16_Default) \
    .dtype_format(DataType.I32_Default, DataType.I64_Default, DataType.I32_Default) \
    .dtype_format(DataType.I64_Default, DataType.I64_Default, DataType.I64_Default) \
    .dtype_format(DataType.U8_Default, DataType.I64_Default, DataType.U8_Default) \
    .dtype_format(DataType.U16_Default, DataType.I64_Default, DataType.U16_Default) \
    .dtype_format(DataType.U32_Default, DataType.I64_Default, DataType.U32_Default) \
    .dtype_format(DataType.U64_Default, DataType.I64_Default, DataType.U64_Default) \
    .dtype_format(DataType.F16_Default, DataType.I64_Default, DataType.F16_Default) \
    .dtype_format(DataType.F32_Default, DataType.I64_Default, DataType.F32_Default) \
    .dtype_format(DataType.F64_Default, DataType.I64_Default, DataType.F64_Default) \
    .dtype_format(DataType.BOOL_NCHW, DataType.I64_NCHW, DataType.BOOL_NCHW) \
    .dtype_format(DataType.I8_NCHW, DataType.I64_NCHW, DataType.I8_NCHW) \
    .dtype_format(DataType.I16_NCHW, DataType.I64_NCHW, DataType.I16_NCHW) \
    .dtype_format(DataType.I32_NCHW, DataType.I64_NCHW, DataType.I32_NCHW) \
    .dtype_format(DataType.I64_NCHW, DataType.I64_NCHW, DataType.I64_NCHW) \
    .dtype_format(DataType.U8_NCHW, DataType.I64_NCHW, DataType.U8_NCHW) \
    .dtype_format(DataType.U16_NCHW, DataType.I64_NCHW, DataType.U16_NCHW) \
    .dtype_format(DataType.U32_NCHW, DataType.I64_NCHW, DataType.U32_NCHW) \
    .dtype_format(DataType.U64_NCHW, DataType.I64_NCHW, DataType.U64_NCHW) \
    .dtype_format(DataType.F16_NCHW, DataType.I64_NCHW, DataType.F16_NCHW) \
    .dtype_format(DataType.F32_NCHW, DataType.I64_NCHW, DataType.F32_NCHW) \
    .dtype_format(DataType.F64_NCHW, DataType.I64_NCHW, DataType.F64_NCHW) \
    .get_op_info()


@op_info_register(reverse_sequence_op_info)
def _reverse_sequence_aicpu():
    """ReverseSequence AiCPU register"""
    return
@@ -19,6 +19,7 @@ Primitive operator classes.

A collection of operators to build neural networks or computing functions.
"""
from .image_ops import (CropAndResize)
from .array_ops import (Argmax, Argmin, Cast, Concat, Pack, Unpack,
                        Diag, DiagPart, DType, ExpandDims, Eye,
                        Fill, GatherNd, GatherV2, SparseGatherV2, InvertPermutation,
@@ -30,7 +31,7 @@ from .array_ops import (Argmax, Argmin, Cast, Concat, Pack, Unpack,
                        Squeeze, StridedSlice, Tile, TensorScatterUpdate,
                        Transpose, TruncatedNormal, TupleToArray, UnsortedSegmentMin,
                        UnsortedSegmentSum, SpaceToDepth, DepthToSpace, SpaceToBatch, BatchToSpace,
                        SpaceToBatchND, BatchToSpaceND, BroadcastTo, InplaceUpdate)
                        SpaceToBatchND, BatchToSpaceND, BroadcastTo, InplaceUpdate, ReverseSequence)
from .comm_ops import (AllGather, AllReduce, _AlltoAll, ReduceScatter, Broadcast,
                       _MirrorOperator, ReduceOp, _VirtualDataset,
                       _VirtualDiv, _GetTensorSlice,
@@ -79,6 +80,8 @@ from .other_ops import (Assign, IOU, BoundingBoxDecode, BoundingBoxEncode,
from .thor_ops import *

__all__ = [
    'ReverseSequence',
    'CropAndResize',
    'TensorAdd',
    'Argmax',
    'Argmin',
@@ -2841,3 +2841,52 @@ class InplaceUpdate(PrimitiveWithInfer):
                        Rel.EQ, self.name)
        return x_shape


class ReverseSequence(PrimitiveWithInfer):
    """
    Reverses variable length slices.

    Args:
        seq_dim (int): The dimension along which reversal is performed. Required.
        batch_dim (int): The dimension along which the input is sliced. Default: 0.

    Inputs:
        - **x** (Tensor) - The input to reverse. All number types, including bool, are supported.
        - **seq_lengths** (Tensor) - A 1-D vector with int32 or int64 elements.

    Outputs:
        Reversed tensor with the same shape and data type as the input.

    Examples:
        >>> x = Tensor(np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]), mindspore.float32)
        >>> seq_lengths = Tensor(np.array([1, 2, 3]))
        >>> reverse_sequence = P.ReverseSequence(seq_dim=1)
        >>> output = reverse_sequence(x, seq_lengths)
        [[1 2 3]
         [5 4 6]
         [9 8 7]]
    """

    @prim_attr_register
    def __init__(self, seq_dim, batch_dim=0):
        """init ReverseSequence"""
        self.init_prim_io_names(inputs=['x', 'seq_lengths'], outputs=['y'])
        validator.check_value_type("seq_dim", seq_dim, [int], self.name)
        self.seq_dim_ = seq_dim
        validator.check_value_type("batch_dim", batch_dim, [int], self.name)
        self.batch_dim_ = batch_dim
    def infer_shape(self, x, seq_lengths):
        validator.check("seq_dim", self.seq_dim_, "x rank", len(x), Rel.LT, self.name)
        validator.check("batch_dim", self.batch_dim_, "x rank", len(x), Rel.LT, self.name)
        validator.check("batch_dim", self.batch_dim_, "seq_dim", self.seq_dim_, Rel.NE, self.name)
        validator.check("seq_lengths rank", len(seq_lengths), "expected", 1, Rel.EQ, self.name)
        validator.check("seq_lengths vector size", seq_lengths[0],
                        "input size along batch_dim", x[self.batch_dim_], Rel.EQ, self.name)
        return x
    def infer_dtype(self, x, seq_lengths):
        validator.check_tensor_type_same({"x_dtype": x}, mstype.number_type + (mstype.bool_,), self.name)
        validator.check_tensor_type_same({"seq_lengths_dtype": seq_lengths}, [mstype.int32, mstype.int64], self.name)
        return x
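To make the seq_dim / batch_dim semantics concrete, a small numpy reference follows (illustrative only, not part of the patch); it mirrors what the primitive computes and reproduces the docstring example above:

import numpy as np

def reverse_sequence_np(x, seq_lengths, seq_dim, batch_dim):
    """Reverse the first seq_lengths[i] entries of slice i taken along batch_dim."""
    y = x.copy()
    # Bring batch_dim to axis 0 and seq_dim to axis 1 so the loop below stays simple.
    y = np.moveaxis(y, (batch_dim, seq_dim), (0, 1))
    for i, length in enumerate(seq_lengths):
        y[i, :length] = y[i, :length][::-1].copy()
    return np.moveaxis(y, (0, 1), (batch_dim, seq_dim))

x = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], np.float32)
print(reverse_sequence_np(x, [1, 2, 3], seq_dim=1, batch_dim=0))
# [[1. 2. 3.]
#  [5. 4. 6.]
#  [9. 8. 7.]]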
@@ -0,0 +1,126 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""image_ops"""
from ..._checkparam import Validator as validator
from ..._checkparam import Rel
from ...common import dtype as mstype
from ..primitive import PrimitiveWithInfer, prim_attr_register
class CropAndResize(PrimitiveWithInfer):
    """
    Extracts crops from the input image tensor and resizes them.

    Note:
        Because the output shape depends on crop_size, crop_size must be a constant.

    Args:
        method (str): An optional string specifying the sampling method for resizing.
            It can be "bilinear" or "nearest". Default: "bilinear".
        extrapolation_value (float): An optional value used for extrapolation, when applicable. Default: 0.0.

    Inputs:
        - **x** (Tensor) - The input image must be a 4-D tensor of shape [batch, image_height, image_width, depth].
          Types allowed: int8, int16, int32, int64, float16, float32, float64, uint8, uint16.
        - **boxes** (Tensor) - A 2-D tensor of shape [num_boxes, 4].
          The i-th row of the tensor specifies the coordinates of a box in the box_index[i] image
          and is specified in normalized coordinates [y1, x1, y2, x2]. A normalized coordinate value of y is
          mapped to the image coordinate at y * (image_height - 1), so the [0, 1] interval of normalized image
          height is mapped to [0, image_height - 1] in image height coordinates. y1 > y2 is allowed, in which
          case the sampled crop is an up-down flipped version of the original image. The width dimension is
          treated similarly. Normalized coordinates outside the [0, 1] range are allowed, in which case
          extrapolation_value is used to extrapolate the input image values. Types allowed: float32.
        - **box_index** (Tensor) - A 1-D tensor of shape [num_boxes] with int32 values in [0, batch).
          The value of box_index[i] specifies the image that the i-th box refers to. Types allowed: int32.
        - **crop_size** (Tensor) - A 1-D tensor of 2 elements, [crop_height, crop_width], with int32 type.
          Only constant values are allowed. All cropped image patches are resized to this size.
          The aspect ratio of the image content is not preserved. Both crop_height and crop_width must be positive.

    Outputs:
        A 4-D tensor of shape [num_boxes, crop_height, crop_width, depth] with type float32.

    Examples:
        >>> class CropAndResizeNet(nn.Cell):
        >>>     def __init__(self, crop_size):
        >>>         super(CropAndResizeNet, self).__init__()
        >>>         self.crop_and_resize = P.CropAndResize()
        >>>         self.crop_size = crop_size
        >>>     @ms_function
        >>>     def construct(self, x, boxes, box_index):
        >>>         return self.crop_and_resize(x, boxes, box_index, self.crop_size)
        >>>
        >>> BATCH_SIZE = 1
        >>> NUM_BOXES = 5
        >>> IMAGE_HEIGHT = 256
        >>> IMAGE_WIDTH = 256
        >>> CHANNELS = 3
        >>> image = np.random.normal(size=[BATCH_SIZE, IMAGE_HEIGHT, IMAGE_WIDTH, CHANNELS]).astype(np.float32)
        >>> boxes = np.random.uniform(size=[NUM_BOXES, 4]).astype(np.float32)
        >>> box_index = np.random.uniform(low=0, high=BATCH_SIZE, size=[NUM_BOXES]).astype(np.int32)
        >>> crop_size = np.array([24, 24]).astype(np.int32)
        >>> crop_and_resize = CropAndResizeNet(crop_size=Tensor(crop_size))
        >>> output = crop_and_resize(Tensor(image), Tensor(boxes), Tensor(box_index))
        >>> print(output.asnumpy())
    """
    @prim_attr_register
    def __init__(self, method="bilinear", extrapolation_value=0.0):
        """init CropAndResize"""
        self.init_prim_io_names(inputs=['x', 'boxes', 'box_index', 'crop_size'], outputs=['y'])
        validator.check_value_type("method", method, [str], self.name)
        validator.check_string("method", method, ["bilinear", "nearest"], self.name)
        self.method = method
        validator.check_value_type("extrapolation_value", extrapolation_value, [float], self.name)
        self.extrapolation_value = extrapolation_value

    def __infer__(self, x, boxes, box_index, crop_size):
        # get shape
        x_shape = list(x['shape'])
        boxes_shape = list(boxes['shape'])
        box_index_shape = list(box_index['shape'])
        crop_size_shape = list(crop_size['shape'])
        # get value
        if crop_size['value'] is None:
            raise ValueError(f"For {self.name}, crop_size must be const.")
        crop_size_value = crop_size['value'].asnumpy()
        # get dtype
        x_dtype = x['dtype']
        boxes_dtype = boxes['dtype']
        box_index_dtype = box_index['dtype']
        crop_size_dtype = crop_size['dtype']
        # check dtype
        validator.check_tensor_type_same({"x": x_dtype},
                                         [mstype.int8, mstype.int16, mstype.int32, mstype.int64, mstype.float16,
                                          mstype.float32, mstype.float64, mstype.uint8, mstype.uint16], self.name)
        validator.check_tensor_type_same({"boxes": boxes_dtype}, [mstype.float32], self.name)
        validator.check_tensor_type_same({"box_index": box_index_dtype}, [mstype.int32], self.name)
        validator.check_tensor_type_same({"crop_size": crop_size_dtype}, [mstype.int32], self.name)
        # check input shape rank
        validator.check("x rank", len(x_shape), "expected", 4, Rel.EQ, self.name)
        validator.check("boxes rank", len(boxes_shape), "expected", 2, Rel.EQ, self.name)
        validator.check("box_index rank", len(box_index_shape), "expected", 1, Rel.EQ, self.name)
        validator.check("crop_size rank", len(crop_size_shape), "expected", 1, Rel.EQ, self.name)
        validator.check("boxes dim_0", boxes_shape[0], "box_index dim_0", box_index_shape[0], Rel.EQ, self.name)
        validator.check("boxes dim_1", boxes_shape[1], "expected", 4, Rel.EQ, self.name)
        num_boxes = boxes_shape[0]
        crop_height = crop_size_value[0]
        crop_width = crop_size_value[1]
        depth = x_shape[3]
        return {'shape': (num_boxes, crop_height, crop_width, depth),
                'dtype': mstype.float32,
                'value': None}
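The shape rule implemented by __infer__ is short enough to restate on its own. The sketch below (illustrative only, not part of the patch) makes explicit why crop_size has to be a compile-time constant, and matches the test case later in this patch:

def crop_and_resize_out_shape(x_shape, boxes_shape, crop_size_value):
    """Output shape of CropAndResize: fully determined by boxes, crop_size and image depth."""
    num_boxes = boxes_shape[0]
    crop_height, crop_width = int(crop_size_value[0]), int(crop_size_value[1])
    depth = x_shape[3]
    return (num_boxes, crop_height, crop_width, depth)

# 5 boxes, crop_size [24, 24], 3 channels -> (5, 24, 24, 3), as in test_net_float32 below.
print(crop_and_resize_out_shape([1, 256, 256, 3], [5, 4], [24, 24]))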
@@ -0,0 +1,49 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np

import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.common.api import ms_function
from mindspore.ops import operations as P

context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")


class Net(nn.Cell):
    def __init__(self, crop_size):
        super(Net, self).__init__()
        self.crop_and_resize = P.CropAndResize()
        self.crop_size = crop_size

    @ms_function
    def construct(self, x, boxes, box_index):
        return self.crop_and_resize(x, boxes, box_index, self.crop_size)
def test_net_float32():
    batch_size = 1
    num_boxes = 5
    image_height = 256
    image_width = 256
    channels = 3
    image = np.random.normal(size=[batch_size, image_height, image_width, channels]).astype(np.float32)
    boxes = np.random.uniform(size=[num_boxes, 4]).astype(np.float32)
    box_index = np.random.uniform(low=0, high=batch_size, size=[num_boxes]).astype(np.int32)
    crop_size = np.array([24, 24]).astype(np.int32)
    net = Net(crop_size=Tensor(crop_size))
    output = net(Tensor(image), Tensor(boxes), Tensor(box_index))
    print(output.asnumpy())
@@ -0,0 +1,43 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np

import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.common.api import ms_function
from mindspore.ops import operations as P

context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")


class Net(nn.Cell):
    def __init__(self):
        super(Net, self).__init__()
        self.ctc_loss = P.CTCLoss()

    @ms_function
    def construct(self, inputs, labels_indices, labels_values, sequence_length):
        return self.ctc_loss(inputs, labels_indices, labels_values, sequence_length)
def test_net_float32():
    x = np.random.randn(2, 2, 3).astype(np.float32)
    labels_indices = np.array([[0, 0], [1, 0]]).astype(np.int64)
    labels_values = np.array([2, 2]).astype(np.int32)
    sequence_length = np.array([2, 2]).astype(np.int32)
    net = Net()
    loss, gradient = net(Tensor(x), Tensor(labels_indices), Tensor(labels_values), Tensor(sequence_length))
    print(loss.asnumpy())
    print(gradient.asnumpy())
@@ -0,0 +1,55 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np

import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.common.api import ms_function
from mindspore.ops import operations as P

context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")


class Net(nn.Cell):
    def __init__(self, seq_dim, batch_dim):
        super(Net, self).__init__()
        self.reverse_sequence = P.ReverseSequence(seq_dim=seq_dim, batch_dim=batch_dim)

    @ms_function
    def construct(self, x, seq_lengths):
        return self.reverse_sequence(x, seq_lengths)
def test_net_int8():
    x = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]).astype(np.int8)
    seq_lengths = np.array([1, 2, 3]).astype(np.int32)
    seq_dim = 0
    batch_dim = 1
    net = Net(seq_dim, batch_dim)
    output = net(Tensor(x), Tensor(seq_lengths))
    expected = np.array([[1, 5, 9], [4, 2, 6], [7, 8, 3]]).astype(np.int8)
    assert np.array_equal(output.asnumpy(), expected)


def test_net_int32():
    x = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]).astype(np.int32)
    seq_lengths = np.array([1, 2, 3]).astype(np.int64)
    seq_dim = 1
    batch_dim = 0
    net = Net(seq_dim, batch_dim)
    output = net(Tensor(x), Tensor(seq_lengths))
    expected = np.array([[1, 2, 3], [5, 4, 6], [9, 8, 7]]).astype(np.int32)
    assert np.array_equal(output.asnumpy(), expected)
@@ -1594,6 +1594,11 @@ test_case_array_ops = [
        Tensor(np.arange(16).reshape(2, 4, 2).astype(np.float32))],
        'skip': ['backward'],
    }),
    ('ReverseSequence', {
        'block': P.ReverseSequence(1, 0),
        'desc_inputs': [Tensor(np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]).astype(np.float32)),
                        Tensor(np.array([1, 2, 3]).astype(np.int32))],
        'desc_bprop': [[3, 3]]}),
]

test_case_other_ops = [