| @@ -22,7 +22,7 @@ from .normalization import BatchNorm1d, BatchNorm2d, LayerNorm | |||
| from .container import SequentialCell, CellList | |||
| from .conv import Conv2d, Conv2dTranspose | |||
| from .lstm import LSTM | |||
| from .basic import Dropout, Flatten, Dense, ClipByNorm, Norm, OneHot | |||
| from .basic import Dropout, Flatten, Dense, ClipByNorm, Norm, OneHot, ImageGradients | |||
| from .embedding import Embedding | |||
| from .pooling import AvgPool2d, MaxPool2d | |||
| @@ -31,7 +31,7 @@ __all__ = ['Softmax', 'LogSoftmax', 'ReLU', 'ReLU6', 'Tanh', 'GELU', 'Sigmoid', | |||
| 'SequentialCell', 'CellList', | |||
| 'Conv2d', 'Conv2dTranspose', | |||
| 'LSTM', | |||
| 'Dropout', 'Flatten', 'Dense', 'ClipByNorm', 'Norm', 'OneHot', | |||
| 'Dropout', 'Flatten', 'Dense', 'ClipByNorm', 'Norm', 'OneHot', 'ImageGradients', | |||
| 'Embedding', | |||
| 'AvgPool2d', 'MaxPool2d', | |||
| ] | |||
| @@ -370,3 +370,48 @@ class OneHot(Cell): | |||
| def construct(self, indices): | |||
| return self.onehot(indices, self.depth, self.on_value, self.off_value) | |||
class ImageGradients(Cell):
    r"""
    Returns two tensors, the first is along the height dimension and the second is along the width dimension.

    Assume an image shape is :math:`h*w`. The gradients along the height and the width are :math:`dy` and :math:`dx`,
    respectively. The last row of :math:`dy` and the last column of :math:`dx` are always zero, so an image
    whose height (or width) is 1 yields an all-zero :math:`dy` (or :math:`dx`).

    .. math::
        dy[i] = \begin{cases} image[i+1, :]-image[i, :], &if\ 0<=i<h-1 \cr
        0, &if\ i==h-1\end{cases}

        dx[i] = \begin{cases} image[:, i+1]-image[:, i], &if\ 0<=i<w-1 \cr
        0, &if\ i==w-1\end{cases}

    Inputs:
        - **images** (Tensor) - The input image data, with format 'NCHW'.

    Outputs:
        - **dy** (Tensor) - vertical image gradients, the same type and shape as input.
        - **dx** (Tensor) - horizontal image gradients, the same type and shape as input.

    Examples:
        >>> net = nn.ImageGradients()
        >>> image = Tensor(np.array([[[[1,2],[3,4]]]]), dtype=mstype.int32)
        >>> net(image)
        [[[[2,2]
           [0,0]]]]
        [[[[1,0]
           [1,0]]]]
    """

    def __init__(self):
        super(ImageGradients, self).__init__()

    def construct(self, images):
        # Unpacking into four names requires a rank-4 (NCHW) input.
        batch_size, depth, height, width = P.Shape()(images)
        dtype = P.DType()(images)

        if height == 1:
            # A single row has no neighbour to subtract: the forward difference
            # would be a zero-size slice, so the gradient is just the zero row.
            dy = P.Fill()(dtype, (batch_size, depth, 1, width), 0)
        else:
            # Forward difference between consecutive rows, padded with a zero
            # last row so that dy keeps the input shape.
            dy = images[:, :, 1:, :] - images[:, :, :height - 1, :]
            dy_last = P.Fill()(dtype, (batch_size, depth, 1, width), 0)
            dy = P.Concat(2)((dy, dy_last))

        if width == 1:
            # Same reasoning as above, along the width axis.
            dx = P.Fill()(dtype, (batch_size, depth, height, 1), 0)
        else:
            # Forward difference between consecutive columns, padded with a
            # zero last column so that dx keeps the input shape.
            dx = images[:, :, :, 1:] - images[:, :, :, :width - 1]
            dx_last = P.Fill()(dtype, (batch_size, depth, height, 1), 0)
            dx = P.Concat(3)((dx, dx_last))

        return dy, dx
| @@ -0,0 +1,62 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| import numpy as np | |||
| import mindspore.nn as nn | |||
| import mindspore.context as context | |||
| import mindspore.common.dtype as mstype | |||
| from mindspore import Tensor | |||
| from mindspore.common.api import ms_function | |||
| context.set_context(device_target="Ascend") | |||
class Net(nn.Cell):
    """Wrapper cell that forwards its input to ``nn.ImageGradients``."""

    def __init__(self):
        super().__init__()
        self.image_gradients = nn.ImageGradients()

    @ms_function
    def construct(self, x):
        # Hand back exactly what the wrapped cell produces for x.
        gradients = self.image_gradients(x)
        return gradients
def test_image_gradients():
    """Check dy and dx for one 1 x 1 x 2 x 2 int32 image against hand-computed values."""
    image = Tensor(np.array([[[[1, 2], [3, 4]]]]), dtype=mstype.int32)
    expected_dy = np.array([[[[2, 2], [0, 0]]]]).astype(np.int32)
    expected_dx = np.array([[[[1, 0], [1, 0]]]]).astype(np.int32)
    net = Net()
    dy, dx = net(image)
    # array_equal compares shapes as well as values; the previous
    # "np.any(a - b) == False" form broadcast the operands, so an output of the
    # wrong shape could still pass.
    assert np.array_equal(dx.asnumpy(), expected_dx)
    assert np.array_equal(dy.asnumpy(), expected_dy)
def test_image_gradients_multi_channel_depth():
    """Check dy and dx for a 4 x 2 x 2 x 2 int32 batch against hand-computed values."""
    # 4 x 2 x 2 x 2
    dtype = mstype.int32
    image = Tensor(np.array([[[[1, 2], [3, 4]], [[5, 6], [7, 8]]],
                             [[[3, 5], [7, 9]], [[11, 13], [15, 17]]],
                             [[[5, 10], [15, 20]], [[25, 30], [35, 40]]],
                             [[[10, 20], [30, 40]], [[50, 60], [70, 80]]]]), dtype=dtype)
    expected_dy = Tensor(np.array([[[[2, 2], [0, 0]], [[2, 2], [0, 0]]],
                                   [[[4, 4], [0, 0]], [[4, 4], [0, 0]]],
                                   [[[10, 10], [0, 0]], [[10, 10], [0, 0]]],
                                   [[[20, 20], [0, 0]], [[20, 20], [0, 0]]]]), dtype=dtype)
    expected_dx = Tensor(np.array([[[[1, 0], [1, 0]], [[1, 0], [1, 0]]],
                                   [[[2, 0], [2, 0]], [[2, 0], [2, 0]]],
                                   [[[5, 0], [5, 0]], [[5, 0], [5, 0]]],
                                   [[[10, 0], [10, 0]], [[10, 0], [10, 0]]]]), dtype=dtype)
    net = Net()
    dy, dx = net(image)
    # array_equal compares shapes as well as values; the previous
    # "np.any(a - b) == False" form broadcast the operands, so an output of the
    # wrong shape could still pass.
    assert np.array_equal(dx.asnumpy(), expected_dx.asnumpy())
    assert np.array_equal(dy.asnumpy(), expected_dy.asnumpy())
| @@ -0,0 +1,49 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """ test image gradients """ | |||
| import numpy as np | |||
| import mindspore.nn as nn | |||
| import mindspore.context as context | |||
| import mindspore.common.dtype as mstype | |||
| from mindspore import Tensor | |||
| from mindspore.common.api import _executor | |||
| from mindspore.common.api import ms_function | |||
| context.set_context(device_target="Ascend") | |||
class Net(nn.Cell):
    """Wrapper cell used by the compile tests; delegates to ``nn.ImageGradients``."""

    def __init__(self):
        super().__init__()
        self.image_gradients = nn.ImageGradients()

    @ms_function
    def construct(self, x):
        # Hand back exactly what the wrapped cell produces for x.
        result = self.image_gradients(x)
        return result
def test_compile():
    """Compile the wrapper net for a single-channel 1 x 1 x 2 x 2 int32 image."""
    # input shape 1 x 1 x 2 x 2
    pixels = np.array([[[[1, 2], [3, 4]]]])
    image = Tensor(pixels, dtype=mstype.int32)
    net = Net()
    _executor.compile(net, image)
def test_compile_multi_channel():
    """Compile the wrapper net for a 4 x 2 x 2 x 2 int32 batch."""
    # input shape 4 x 2 x 2 x 2
    pixels = np.array([
        [[[1, 2], [3, 4]], [[5, 6], [7, 8]]],
        [[[3, 5], [7, 9]], [[11, 13], [15, 17]]],
        [[[5, 10], [15, 20]], [[25, 30], [35, 40]]],
        [[[10, 20], [30, 40]], [[50, 60], [70, 80]]],
    ])
    image = Tensor(pixels, dtype=mstype.int32)
    net = Net()
    _executor.compile(net, image)