| @@ -110,6 +110,8 @@ const char kNameSigmoidCrossEntropyWithLogits[] = "SigmoidCrossEntropyWithLogits | |||
| const char kNameSigmoidCrossEntropyWithLogitsGrad[] = "SigmoidCrossEntropyWithLogitsGrad"; | |||
| const char kNameScatterNdD[] = "ScatterNd"; | |||
| const char kNamePadD[] = "Pad"; | |||
| const char kNameMirrorPad[] = "MirrorPad"; | |||
| const char kNameMirrorPadGrad[] = "MirrorPadGrad"; | |||
| const char kNameGatherNd[] = "GatherNd"; | |||
| const char kNameArgmax[] = "Argmax"; | |||
| const char kNameArgmin[] = "Argmin"; | |||
| @@ -256,6 +258,8 @@ std::unordered_map<std::string, OpAdapterDescPtr> &DfGraphConvertor::get_adpt_ma | |||
| {string(kNameSigmoidCrossEntropyWithLogitsGrad), ADPT_DESC(SigmoidCrossEntropyWithLogitsGrad)}, | |||
| {string(kNameScatterNdD), ADPT_DESC(ScatterNdD)}, | |||
| {string(kNamePadD), ADPT_DESC(PadD)}, | |||
| {string(kNameMirrorPad), ADPT_DESC(MirrorPad)}, | |||
| {string(kNameMirrorPadGrad), ADPT_DESC(MirrorPadGrad)}, | |||
| {string(kNameGatherNd), ADPT_DESC(GatherNd)}, | |||
| {string(kNameArgmax), ADPT_DESC(ArgMaxD)}, | |||
| {string(kNameArgmin), ADPT_DESC(ArgMinD)}, | |||
| @@ -596,6 +596,16 @@ INPUT_MAP(PadD) = {{1, INPUT_DESC(x)}}; | |||
| ATTR_MAP(PadD) = {{"paddings", ATTR_DESC(paddings, AnyTraits<std::vector<std::vector<int64_t>>>())}}; | |||
| OUTPUT_MAP(PadD) = {{0, OUTPUT_DESC(y)}}; | |||
| // MirrorPad | |||
| INPUT_MAP(MirrorPad) = {{1, INPUT_DESC(x)}, {2, INPUT_DESC(paddings)}}; | |||
| ATTR_MAP(MirrorPad) = {{"mode", ATTR_DESC(mode, AnyTraits<std::string>())}}; | |||
| OUTPUT_MAP(MirrorPad) = {{0, OUTPUT_DESC(y)}}; | |||
| // MirrorPadGrad | |||
| INPUT_MAP(MirrorPadGrad) = {{1, INPUT_DESC(x)}, {2, INPUT_DESC(paddings)}}; | |||
| ATTR_MAP(MirrorPadGrad) = {{"mode", ATTR_DESC(mode, AnyTraits<std::string>())}}; | |||
| OUTPUT_MAP(MirrorPadGrad) = {{0, OUTPUT_DESC(y)}}; | |||
| // GatherNd | |||
| INPUT_MAP(GatherNd) = {{1, INPUT_DESC(x1)}, {2, INPUT_DESC(x2)}}; | |||
| ATTR_MAP(GatherNd) = EMPTY_ATTR_MAP; | |||
| @@ -155,6 +155,10 @@ DECLARE_OP_USE_INPUT_ATTR(ScatterNdD) | |||
| DECLARE_OP_USE_OUTPUT(ScatterNdD) | |||
| DECLARE_OP_ADAPTER(PadD) | |||
| DECLARE_OP_USE_OUTPUT(PadD) | |||
| DECLARE_OP_ADAPTER(MirrorPad) | |||
| DECLARE_OP_USE_OUTPUT(MirrorPad) | |||
| DECLARE_OP_ADAPTER(MirrorPadGrad) | |||
| DECLARE_OP_USE_OUTPUT(MirrorPadGrad) | |||
| DECLARE_OP_ADAPTER(BoundingBoxEncode) | |||
| DECLARE_OP_USE_OUTPUT(BoundingBoxEncode) | |||
| DECLARE_OP_ADAPTER(BoundingBoxDecode) | |||
| @@ -22,7 +22,7 @@ from .normalization import BatchNorm1d, BatchNorm2d, LayerNorm | |||
| from .container import SequentialCell, CellList | |||
| from .conv import Conv2d, Conv2dTranspose | |||
| from .lstm import LSTM | |||
| from .basic import Dropout, Flatten, Dense, ClipByNorm, Norm, OneHot, ImageGradients | |||
| from .basic import Dropout, Flatten, Dense, ClipByNorm, Norm, OneHot, ImageGradients, Pad | |||
| from .embedding import Embedding | |||
| from .pooling import AvgPool2d, MaxPool2d | |||
| @@ -34,5 +34,5 @@ __all__ = ['Softmax', 'LogSoftmax', 'ReLU', 'ReLU6', 'Tanh', 'GELU', 'Sigmoid', | |||
| 'LSTM', | |||
| 'Dropout', 'Flatten', 'Dense', 'ClipByNorm', 'Norm', 'OneHot', 'ImageGradients', | |||
| 'Embedding', | |||
| 'AvgPool2d', 'MaxPool2d', | |||
| 'AvgPool2d', 'MaxPool2d', 'Pad', | |||
| ] | |||
| @@ -415,3 +415,72 @@ class ImageGradients(Cell): | |||
| dx_last = P.Fill()(P.DType()(images), (batch_size, depth, height, 1), 0) | |||
| dx = P.Concat(3)((dx, dx_last)) | |||
| return dy, dx | |||
| class Pad(Cell): | |||
| """ | |||
| Pads the input tensor according to the paddings and mode. | |||
| Args: | |||
| paddings (tuple): The shape of parameter `paddings` is (N, 2). N is the rank of input data. All elements of | |||
| paddings are int type. For `D` th dimension of input, paddings[D, 0] indicates how many sizes to be | |||
| extended ahead of the `D` th dimension of the input tensor, and paddings[D, 1] indicates how many sizes to | |||
| be extended behind of the `D` th dimension of the input tensor. | |||
| mode (string): Specifies padding mode. The optional values are "CONSTANT", "REFLECT", "SYMMETRIC". | |||
| Default: "CONSTANT". | |||
| Inputs: | |||
| - ** input_x** (Tensor) - The input tensor. | |||
| Outputs: | |||
| Tensor, the tensor after padding. | |||
| - If `mode` is "CONSTANT", it fill the edge with 0, regardless of the values of the `input_x`. | |||
| If the `input_x` is [[1,2,3],[4,5,6],[7,8,9]] and `paddings` is [[1,1],[2,2]], then the | |||
| Outputs is [[0,0,0,0,0,0,0],[0,0,1,2,3,0,0],[0,0,4,5,6,0,0],[0,0,7,8,9,0,0],[0,0,0,0,0,0,0]]. | |||
| - If 'mode` is "REFLECT", it uses a way of symmetrical copying throught the axis of symmetry to fill in, | |||
| symmetry. If the `input_x` is [[1,2,3],[4,5,6],[7,8,9]] and `paddings` is [[1,1],[2,2]], then the | |||
| Outputs is [[6,5,4,5,6,5,4],[3,2,1,2,3,2,1],[6,5,4,5,6,5,4],[9,8,7,8,9,8,7],[6,5,4,5,6,5,4]]. | |||
| - If 'mode' is "SYMMETRIC", the filling method is similar to the "REFLECT". It is also copied | |||
| according to the symmetry axis, except that it includes the symmetry axis. If the `input_x` | |||
| is [[1,2,3],[4,5,6],[7,8,9]] and `paddings` is [[1,1],[2,2]], then the Outputs is | |||
| [[2,1,1,2,3,3,2],[2,1,1,2,3,3,2],[5,4,4,5,6,6,5],[8,7,7,8,9,9,8],[8,7,7,8,9,9,8]]. | |||
| Examples: | |||
| >>> from mindspore import Tensor | |||
| >>> from mindspore.ops import operations as P | |||
| >>> import mindspore.nn as nn | |||
| >>> import numpy as np | |||
| >>> class Net(nn.Cell): | |||
| >>> def __init__(self): | |||
| >>> super(Net, self).__init__() | |||
| >>> self.pad = nn.Pad(paddings=((1,1),(2,2)), mode="CONSTANT") | |||
| >>> def construct(self, x): | |||
| >>> return self.pad(x) | |||
| >>> x = np.random.random(size=(2, 3)).astype(np.float32) | |||
| >>> pad = Net() | |||
| >>> ms_output = pad(Tensor(x)) | |||
| """ | |||
| def __init__(self, paddings, mode="CONSTANT"): | |||
| super(Pad, self).__init__() | |||
| self.mode = mode | |||
| self.paddings = paddings | |||
| validator.check_string('mode', self.mode, ["CONSTANT", "REFLECT", "SYMMETRIC"]) | |||
| if not isinstance(paddings, tuple): | |||
| raise TypeError('Paddings must be tuple type.') | |||
| for item in paddings: | |||
| if len(item) != 2: | |||
| raise ValueError('The shape of paddings must be (n, 2).') | |||
| if mode == "CONSTANT": | |||
| self.pad = P.Pad(self.paddings) | |||
| else: | |||
| self.paddings = Tensor(np.array(self.paddings)) | |||
| self.pad = P.MirrorPad(mode=mode) | |||
| def construct(self, x): | |||
| if self.mode == "CONSTANT": | |||
| x = self.pad(x) | |||
| else: | |||
| x = self.pad(x, self.paddings) | |||
| return x | |||
| @@ -470,6 +470,17 @@ def get_bprop_pad(self): | |||
| return bprop | |||
| @bprop_getters.register(P.MirrorPad) | |||
| def get_bprop_mirror_pad(self): | |||
| """Grad definition for `MirrorPad` operation.""" | |||
| mirror_pad_grad = G.MirrorPadGrad(self.mode) | |||
| def bprop(x, paddings, out, dout): | |||
| dx = mirror_pad_grad(dout, paddings, x) | |||
| return (dx, zeros_like(paddings)) | |||
| return bprop | |||
| @bprop_getters.register(P.ROIAlign) | |||
| def get_bprop_roi_align(self): | |||
| """Grad definition for `ROIAlign` operation.""" | |||
| @@ -59,7 +59,7 @@ from .nn_ops import (LSTM, SGD, Adam, ApplyMomentum, BatchNorm, | |||
| LogSoftmax, | |||
| MaxPool, | |||
| AvgPool, Conv2DBackpropInput, | |||
| MaxPoolWithArgmax, OneHot, Pad, PReLU, ReLU, ReLU6, HSwish, HSigmoid, | |||
| MaxPoolWithArgmax, OneHot, Pad, MirrorPad, PReLU, ReLU, ReLU6, HSwish, HSigmoid, | |||
| ResizeBilinear, Sigmoid, | |||
| SigmoidCrossEntropyWithLogits, | |||
| SmoothL1Loss, Softmax, | |||
| @@ -180,6 +180,7 @@ __all__ = [ | |||
| 'ScatterNd', | |||
| 'ResizeNearestNeighbor', | |||
| 'Pad', | |||
| 'MirrorPad', | |||
| 'GatherNd', | |||
| 'ScatterNdUpdate', | |||
| 'Floor', | |||
| @@ -947,6 +947,24 @@ class TanhGrad(PrimitiveWithInfer): | |||
| return out | |||
| class MirrorPadGrad(PrimitiveWithInfer): | |||
| """Gradients of MirrorPad operation.""" | |||
| @prim_attr_register | |||
| def __init__(self, mode="REFLECT"): | |||
| """init MirrorPad""" | |||
| validator.check_string('mode', mode, ['REFLECT', 'SYMMETRIC']) | |||
| self.mode = mode | |||
| def __infer__(self, dout, paddings, x): | |||
| validator.check_subclass("dout", dout['dtype'], mstype.tensor) | |||
| validator.check_subclass("paddings", paddings['dtype'], mstype.tensor) | |||
| validator.check_subclass("input_x", x['dtype'], mstype.tensor) | |||
| return {'shape': x['shape'], | |||
| 'dtype': dout['dtype'], | |||
| 'value': None} | |||
| class RefToEmbed(Primitive): | |||
| r""" | |||
| Make a key from Ref. | |||
| @@ -2092,6 +2092,7 @@ class Pad(PrimitiveWithInfer): | |||
| for item in paddings: | |||
| if len(item) != 2: | |||
| raise ValueError('The shape of paddings must be (n, 2).') | |||
| self.paddings = paddings | |||
| def infer_shape(self, x): | |||
| paddings = np.array(self.paddings) | |||
| @@ -2104,9 +2105,78 @@ class Pad(PrimitiveWithInfer): | |||
| return y_shape | |||
| def infer_dtype(self, x): | |||
| validator.check_subclass("input_x", x, mstype.tensor) | |||
| return x | |||
| class MirrorPad(PrimitiveWithInfer): | |||
| """ | |||
| Pads the input tensor according to the paddings and mode. | |||
| Args: | |||
| mode (string): Specifies padding mode. The optional values are "REFLECT", "SYMMETRIC". | |||
| Default: "REFLECT". | |||
| Inputs: | |||
| - **input_x** (Tensor) - The input tensor. | |||
| - **paddings** (Tensor) - The paddings tensor. The value of `paddings` is a matrix(list), | |||
| and its shape is (N, 2). N is the rank of input data. All elements of paddings | |||
| are int type. For `D` th dimension of input, paddings[D, 0] indicates how many sizes to be | |||
| extended ahead of the `D` th dimension of the input tensor, and paddings[D, 1] indicates | |||
| how many sizes to be extended behind of the `D` th dimension of the input tensor. | |||
| Outputs: | |||
| Tensor, the tensor after padding. | |||
| - If 'mode` is "REFLECT", it uses a way of symmetrical copying throught the axis of symmetry to fill in, | |||
| symmetry. If the `input_x` is [[1,2,3],[4,5,6],[7,8,9]] and `paddings` is [[1,1],[2,2]], then the | |||
| Outputs is [[6,5,4,5,6,5,4],[3,2,1,2,3,2,1],[6,5,4,5,6,5,4],[9,8,7,8,9,8,7],[6,5,4,5,6,5,4]]. | |||
| - If 'mode' is "SYMMETRIC", the filling method is similar to the "REFLECT". It is also copied | |||
| according to the symmetry axis, except that it includes the symmetry axis. If the `input_x` | |||
| is [[1,2,3],[4,5,6],[7,8,9]] and `paddings` is [[1,1],[2,2]], then the Outputs is | |||
| [[2,1,1,2,3,3,2],[2,1,1,2,3,3,2],[5,4,4,5,6,6,5],[8,7,7,8,9,9,8],[8,7,7,8,9,9,8]]. | |||
| Examples: | |||
| >>> from mindspore import Tensor | |||
| >>> from mindspore.ops import operations as P | |||
| >>> import mindspore.nn as nn | |||
| >>> import numpy as np | |||
| >>> class Net(nn.Cell): | |||
| >>> def __init__(self): | |||
| >>> super(Net, self).__init__() | |||
| >>> self.pad = P.MirrorPad(mode="REFLECT") | |||
| >>> def construct(self, x, paddings): | |||
| >>> return self.pad(x, paddings) | |||
| >>> x = np.random.random(size=(2, 3)).astype(np.float32) | |||
| >>> paddings = Tensor([[1,1],[2,2]]) | |||
| >>> pad = Net() | |||
| >>> ms_output = pad(Tensor(x), paddings) | |||
| """ | |||
| @prim_attr_register | |||
| def __init__(self, mode='REFLECT'): | |||
| """Init Pad""" | |||
| validator.check_string('mode', mode, ['REFLECT', 'SYMMETRIC']) | |||
| self.mode = mode | |||
| def __infer__(self, input_x, paddings): | |||
| validator.check_subclass("input_x", input_x['dtype'], mstype.tensor) | |||
| validator.check_subclass("paddings", paddings['dtype'], mstype.tensor) | |||
| x_shape = list(input_x['shape']) | |||
| paddings_value = paddings['value'].asnumpy() | |||
| paddings_size = paddings_value.size | |||
| validator.check_integer('paddings.shape', paddings_size, len(x_shape) * 2, Rel.EQ) | |||
| if not np.all(paddings_size >= 0): | |||
| raise ValueError('All elements of paddings must be >= 0.') | |||
| y_shape = () | |||
| for i in range(0, int(paddings_size / 2)): | |||
| y_shape += ((x_shape[i] + paddings_value[i, 0] + paddings_value[i, 1]),) | |||
| return {'shape': y_shape, | |||
| 'dtype': input_x['dtype'], | |||
| 'value': None} | |||
| class ROIAlign(PrimitiveWithInfer): | |||
| """ | |||
| Computes Region of Interest (RoI) Align operator. | |||
| @@ -0,0 +1,64 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """ test nn pad """ | |||
| from mindspore import Tensor | |||
| from mindspore.ops import operations as P | |||
| import mindspore.nn as nn | |||
| from mindspore.ops.composite import GradOperation | |||
| from mindspore.common.api import ms_function | |||
| import numpy as np | |||
| import mindspore.context as context | |||
| class Net(nn.Cell): | |||
| def __init__(self, raw_paddings, mode): | |||
| super(Net, self).__init__() | |||
| self.pad = nn.Pad(raw_paddings, mode=mode) | |||
| @ms_function | |||
| def construct(self, x): | |||
| return self.pad(x) | |||
| class Grad(nn.Cell): | |||
| def __init__(self, network): | |||
| super(Grad, self).__init__() | |||
| self.grad = GradOperation(name="get_all", get_all=True, sens_param=True) | |||
| self.network = network | |||
| @ms_function | |||
| def construct(self, x, grads): | |||
| return self.grad(self.network)(x, grads) | |||
| def test_pad_train(): | |||
| mode = 'CONSTANT' | |||
| x = np.random.random(size=(2, 3)).astype(np.float32) | |||
| raw_paddings = ((1, 1), (2, 2)) | |||
| grads = np.random.random(size=(4, 7)).astype(np.float32) | |||
| grad = Grad(Net(raw_paddings, mode)) | |||
| output = grad(Tensor(x), Tensor(grads)) | |||
| print("=================output====================") | |||
| print(output) | |||
| def test_pad_infer(): | |||
| mode = 'CONSTANT' | |||
| x = np.random.random(size=(2, 3)).astype(np.float32) | |||
| raw_paddings = ((1, 1), (2, 2)) | |||
| net = Net(raw_paddings, mode) | |||
| output = net(Tensor(x)) | |||
| print("=================output====================") | |||
| print(output) | |||