From: @liangzhibo
Tag: tags/v1.1.0
@@ -35,7 +35,7 @@ from .activation import get_activation

 __all__ = ['Dropout', 'Flatten', 'Dense', 'ClipByNorm', 'Norm', 'OneHot', 'Pad', 'Unfold',
-           'MatrixDiag', 'MatrixDiagPart', 'MatrixSetDiag']
+           'Tril', 'Triu', 'MatrixDiag', 'MatrixDiagPart', 'MatrixSetDiag']

 class Dropout(Cell):

@@ -601,6 +601,80 @@ class Unfold(Cell):
        return result
@constexpr
def tril(x_shape, x_dtype, k):
    Validator.check_int(len(x_shape), 1, Rel.GE, "x rank", "tril")
    Validator.check_is_int(k, "k value", "tril")
    mask = np.tril(np.ones(x_shape), k)
    return Tensor(mask, x_dtype)


class Tril(Cell):
    """
    Returns a tensor with elements above the kth diagonal zeroed.

    Inputs:
        - **x** (Tensor) - The input tensor.
        - **k** (Int) - The index of the diagonal. Default: 0.

    Outputs:
        Tensor, has the same type as input `x`.

    Examples:
        >>> x = Tensor(np.array([[1, 2], [3, 4]]))
        >>> tril = nn.Tril()
        >>> result = tril(x)
        >>> print(result)
        [[1 0]
         [3 4]]
    """

    def __init__(self):
        super(Tril, self).__init__()
        self.dtype = P.DType()
        self.mul = P.Mul()

    def construct(self, x, k=0):
        assist = tril(x.shape, self.dtype(x), k)
        return self.mul(x, assist)
@constexpr
def triu(x_shape, x_dtype, k):
    Validator.check_int(len(x_shape), 1, Rel.GE, "x rank", "triu")
    Validator.check_is_int(k, "k value", "triu")
    mask = np.triu(np.ones(x_shape), k)
    return Tensor(mask, x_dtype)


class Triu(Cell):
    """
    Returns a tensor with elements below the kth diagonal zeroed.

    Inputs:
        - **x** (Tensor) - The input tensor.
        - **k** (Int) - The index of the diagonal. Default: 0.

    Outputs:
        Tensor, has the same type as input `x`.

    Examples:
        >>> x = Tensor(np.array([[1, 2], [3, 4]]))
        >>> triu = nn.Triu()
        >>> result = triu(x)
        >>> print(result)
        [[1 2]
         [0 4]]
    """

    def __init__(self):
        super(Triu, self).__init__()
        self.dtype = P.DType()
        self.mul = P.Mul()

    def construct(self, x, k=0):
        assist = triu(x.shape, self.dtype(x), k)
        return self.mul(x, assist)
@constexpr
def _get_matrix_diag_assist(x_shape, x_dtype):
    Validator.check_int(len(x_shape), 1, Rel.GE, "x rank", "_get_matrix_diag_assist")
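Both new cells follow the same pattern: a @constexpr helper builds a 0/1 mask with NumPy when the graph is compiled, and construct simply multiplies the input by that mask elementwise. A minimal standalone sketch of the masking idea in plain NumPy (illustration only, not part of the diff; it reuses the 3x3 matrix from the tests below):

    import numpy as np

    x = np.array([[1, 2, 3],
                  [4, 5, 6],
                  [7, 8, 9]])

    # Ones on and below (tril) or on and above (triu) the k-th diagonal, zeros elsewhere.
    lower_mask = np.tril(np.ones(x.shape, dtype=x.dtype), 0)
    upper_mask = np.triu(np.ones(x.shape, dtype=x.dtype), 0)

    print((x * lower_mask).sum())  # 34, the value asserted in test_tril()
    print((x * upper_mask).sum())  # 26, the value asserted in test_triu()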
@@ -0,0 +1,108 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
test nn.Tril()
"""
import numpy as np
import mindspore.nn as nn
from mindspore import Tensor
from mindspore import context

context.set_context(mode=context.GRAPH_MODE)


def test_tril():
    class Net(nn.Cell):
        def __init__(self):
            super(Net, self).__init__()
            self.value = Tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])

        def construct(self):
            tril = nn.Tril()
            return tril(self.value, 0)

    net = Net()
    out = net()
    assert np.sum(out.asnumpy()) == 34


def test_tril_1():
    class Net(nn.Cell):
        def __init__(self):
            super(Net, self).__init__()
            self.value = Tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])

        def construct(self):
            tril = nn.Tril()
            return tril(self.value, 1)

    net = Net()
    out = net()
    assert np.sum(out.asnumpy()) == 42


def test_tril_2():
    class Net(nn.Cell):
        def __init__(self):
            super(Net, self).__init__()
            self.value = Tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])

        def construct(self):
            tril = nn.Tril()
            return tril(self.value, -1)

    net = Net()
    out = net()
    assert np.sum(out.asnumpy()) == 19


def test_tril_parameter():
    class Net(nn.Cell):
        def __init__(self):
            super(Net, self).__init__()

        def construct(self, x):
            tril = nn.Tril()
            return tril(x, 0)

    net = Net()
    net(Tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]]))


def test_tril_parameter_1():
    class Net(nn.Cell):
        def __init__(self):
            super(Net, self).__init__()

        def construct(self, x):
            tril = nn.Tril()
            return tril(x, 1)

    net = Net()
    net(Tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]]))


def test_tril_parameter_2():
    class Net(nn.Cell):
        def __init__(self):
            super(Net, self).__init__()

        def construct(self, x):
            tril = nn.Tril()
            return tril(x, -1)

    net = Net()
    net(Tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]]))
@@ -0,0 +1,108 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
test nn.Triu()
"""
import numpy as np
import mindspore.nn as nn
from mindspore import Tensor
from mindspore import context

context.set_context(mode=context.GRAPH_MODE)


def test_triu():
    class Net(nn.Cell):
        def __init__(self):
            super(Net, self).__init__()
            self.value = Tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])

        def construct(self):
            triu = nn.Triu()
            return triu(self.value, 0)

    net = Net()
    out = net()
    assert np.sum(out.asnumpy()) == 26


def test_triu_1():
    class Net(nn.Cell):
        def __init__(self):
            super(Net, self).__init__()
            self.value = Tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])

        def construct(self):
            triu = nn.Triu()
            return triu(self.value, 1)

    net = Net()
    out = net()
    assert np.sum(out.asnumpy()) == 11


def test_triu_2():
    class Net(nn.Cell):
        def __init__(self):
            super(Net, self).__init__()
            self.value = Tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])

        def construct(self):
            triu = nn.Triu()
            return triu(self.value, -1)

    net = Net()
    out = net()
    assert np.sum(out.asnumpy()) == 38


def test_triu_parameter():
    class Net(nn.Cell):
        def __init__(self):
            super(Net, self).__init__()

        def construct(self, x):
            triu = nn.Triu()
            return triu(x, 0)

    net = Net()
    net(Tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]]))


def test_triu_parameter_1():
    class Net(nn.Cell):
        def __init__(self):
            super(Net, self).__init__()

        def construct(self, x):
            triu = nn.Triu()
            return triu(x, 1)

    net = Net()
    net(Tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]]))


def test_triu_parameter_2():
    class Net(nn.Cell):
        def __init__(self):
            super(Net, self).__init__()

        def construct(self, x):
            triu = nn.Triu()
            return triu(x, -1)

    net = Net()
    net(Tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]]))
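As a quick usage sketch (mirroring the docstring Examples and the tests above; the expected sums are the ones asserted in test_tril() and test_triu_1()):

    import numpy as np
    import mindspore.nn as nn
    from mindspore import Tensor

    x = Tensor(np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]))

    tril = nn.Tril()
    triu = nn.Triu()

    # The second argument k shifts which diagonal the kept triangle starts from.
    print(tril(x, 0))  # lower triangle, elements sum to 34
    print(triu(x, 1))  # strict upper triangle, elements sum to 11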