From: @liangzhibo Reviewed-by: Signed-off-by:tags/v1.1.0
| @@ -35,7 +35,7 @@ from .activation import get_activation | |||
| __all__ = ['Dropout', 'Flatten', 'Dense', 'ClipByNorm', 'Norm', 'OneHot', 'Pad', 'Unfold', | |||
| 'MatrixDiag', 'MatrixDiagPart', 'MatrixSetDiag'] | |||
| 'Tril', 'Triu', 'MatrixDiag', 'MatrixDiagPart', 'MatrixSetDiag'] | |||
| class Dropout(Cell): | |||
| @@ -601,6 +601,80 @@ class Unfold(Cell): | |||
| return result | |||
| @constexpr | |||
| def tril(x_shape, x_dtype, k): | |||
| Validator.check_int(len(x_shape), 1, Rel.GE, "x rank", "tril") | |||
| Validator.check_is_int(k, "k value", "tril") | |||
| mask = np.tril(np.ones(x_shape), k) | |||
| return Tensor(mask, x_dtype) | |||
| class Tril(Cell): | |||
| """ | |||
| Returns a tensor with elements above the kth diagonal zeroed. | |||
| Inputs: | |||
| - **x** (Tensor) - The input tensor. | |||
| - **k** (Int) - The index of diagonal. Default: 0 | |||
| Outputs: | |||
| Tensor, has the same type as input `x`. | |||
| Examples: | |||
| >>> x = Tensor(np.array([[1, 2], [3, 4]])) | |||
| >>> tril = nn.Tril() | |||
| >>> result = tril(x) | |||
| >>> print(result) | |||
| [[1 0] | |||
| [3 4]] | |||
| """ | |||
| def __init__(self): | |||
| super(Tril, self).__init__() | |||
| self.dtype = P.DType() | |||
| self.mul = P.Mul() | |||
| def construct(self, x, k=0): | |||
| assist = tril(x.shape, self.dtype(x), k) | |||
| return self.mul(x, assist) | |||
| @constexpr | |||
| def triu(x_shape, x_dtype, k): | |||
| Validator.check_int(len(x_shape), 1, Rel.GE, "x rank", "triu") | |||
| Validator.check_is_int(k, "k value", "triu") | |||
| mask = np.triu(np.ones(x_shape), k) | |||
| return Tensor(mask, x_dtype) | |||
| class Triu(Cell): | |||
| """ | |||
| Returns a tensor with elements below the kth diagonal zeroed. | |||
| Inputs: | |||
| - **x** (Tensor) - The input tensor. | |||
| - **k** (Int) - The index of diagonal. Default: 0 | |||
| Outputs: | |||
| Tensor, has the same type as input `x`. | |||
| Examples: | |||
| >>> x = Tensor(np.array([[1, 2], [3, 4]])) | |||
| >>> tril = nn.Tril() | |||
| >>> result = tril(x) | |||
| >>> print(result) | |||
| [[1 2] | |||
| [0 4]] | |||
| """ | |||
| def __init__(self): | |||
| super(Triu, self).__init__() | |||
| self.dtype = P.DType() | |||
| self.mul = P.Mul() | |||
| def construct(self, x, k=0): | |||
| assist = triu(x.shape, self.dtype(x), k) | |||
| return self.mul(x, assist) | |||
| @constexpr | |||
| def _get_matrix_diag_assist(x_shape, x_dtype): | |||
| Validator.check_int(len(x_shape), 1, Rel.GE, "x rank", "_get_matrix_diag_assist") | |||
| @@ -0,0 +1,108 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """ | |||
| test nn.Tril() | |||
| """ | |||
| import numpy as np | |||
| import mindspore.nn as nn | |||
| from mindspore import Tensor | |||
| from mindspore import context | |||
| context.set_context(mode=context.GRAPH_MODE) | |||
| def test_tril(): | |||
| class Net(nn.Cell): | |||
| def __init__(self): | |||
| super(Net, self).__init__() | |||
| self.value = Tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) | |||
| def construct(self): | |||
| tril = nn.Tril() | |||
| return tril(self.value, 0) | |||
| net = Net() | |||
| out = net() | |||
| assert np.sum(out.asnumpy()) == 34 | |||
| def test_tril_1(): | |||
| class Net(nn.Cell): | |||
| def __init__(self): | |||
| super(Net, self).__init__() | |||
| self.value = Tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) | |||
| def construct(self): | |||
| tril = nn.Tril() | |||
| return tril(self.value, 1) | |||
| net = Net() | |||
| out = net() | |||
| assert np.sum(out.asnumpy()) == 42 | |||
| def test_tril_2(): | |||
| class Net(nn.Cell): | |||
| def __init__(self): | |||
| super(Net, self).__init__() | |||
| self.value = Tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) | |||
| def construct(self): | |||
| tril = nn.Tril() | |||
| return tril(self.value, -1) | |||
| net = Net() | |||
| out = net() | |||
| assert np.sum(out.asnumpy()) == 19 | |||
| def test_tril_parameter(): | |||
| class Net(nn.Cell): | |||
| def __init__(self): | |||
| super(Net, self).__init__() | |||
| def construct(self, x): | |||
| tril = nn.Tril() | |||
| return tril(x, 0) | |||
| net = Net() | |||
| net(Tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])) | |||
| def test_tril_parameter_1(): | |||
| class Net(nn.Cell): | |||
| def __init__(self): | |||
| super(Net, self).__init__() | |||
| def construct(self, x): | |||
| tril = nn.Tril() | |||
| return tril(x, 1) | |||
| net = Net() | |||
| net(Tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])) | |||
| def test_tril_parameter_2(): | |||
| class Net(nn.Cell): | |||
| def __init__(self): | |||
| super(Net, self).__init__() | |||
| def construct(self, x): | |||
| tril = nn.Tril() | |||
| return tril(x, -1) | |||
| net = Net() | |||
| net(Tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])) | |||
| @@ -0,0 +1,108 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """ | |||
| test nn.Triu() | |||
| """ | |||
| import numpy as np | |||
| import mindspore.nn as nn | |||
| from mindspore import Tensor | |||
| from mindspore import context | |||
| context.set_context(mode=context.GRAPH_MODE) | |||
| def test_triu(): | |||
| class Net(nn.Cell): | |||
| def __init__(self): | |||
| super(Net, self).__init__() | |||
| self.value = Tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) | |||
| def construct(self): | |||
| triu = nn.Triu() | |||
| return triu(self.value, 0) | |||
| net = Net() | |||
| out = net() | |||
| assert np.sum(out.asnumpy()) == 26 | |||
| def test_triu_1(): | |||
| class Net(nn.Cell): | |||
| def __init__(self): | |||
| super(Net, self).__init__() | |||
| self.value = Tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) | |||
| def construct(self): | |||
| triu = nn.Triu() | |||
| return triu(self.value, 1) | |||
| net = Net() | |||
| out = net() | |||
| assert np.sum(out.asnumpy()) == 11 | |||
| def test_triu_2(): | |||
| class Net(nn.Cell): | |||
| def __init__(self): | |||
| super(Net, self).__init__() | |||
| self.value = Tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) | |||
| def construct(self): | |||
| triu = nn.Triu() | |||
| return triu(self.value, -1) | |||
| net = Net() | |||
| out = net() | |||
| assert np.sum(out.asnumpy()) == 38 | |||
| def test_triu_parameter(): | |||
| class Net(nn.Cell): | |||
| def __init__(self): | |||
| super(Net, self).__init__() | |||
| def construct(self, x): | |||
| triu = nn.Triu() | |||
| return triu(x, 0) | |||
| net = Net() | |||
| net(Tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])) | |||
| def test_triu_parameter_1(): | |||
| class Net(nn.Cell): | |||
| def __init__(self): | |||
| super(Net, self).__init__() | |||
| def construct(self, x): | |||
| triu = nn.Triu() | |||
| return triu(x, 1) | |||
| net = Net() | |||
| net(Tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])) | |||
| def test_triu_parameter_2(): | |||
| class Net(nn.Cell): | |||
| def __init__(self): | |||
| super(Net, self).__init__() | |||
| def construct(self, x): | |||
| triu = nn.Triu() | |||
| return triu(x, -1) | |||
| net = Net() | |||
| net(Tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])) | |||