Browse Source

!8661 Add nn.Tril() and nn.Triu()

From: @liangzhibo
Reviewed-by: 
Signed-off-by:
tags/v1.1.0
mindspore-ci-bot Gitee 5 years ago
parent
commit
fa2185b8ef
3 changed files with 291 additions and 1 deletions
  1. +75
    -1
      mindspore/nn/layer/basic.py
  2. +108
    -0
      tests/ut/python/nn/test_tril.py
  3. +108
    -0
      tests/ut/python/nn/test_triu.py

+ 75
- 1
mindspore/nn/layer/basic.py View File

@@ -35,7 +35,7 @@ from .activation import get_activation




__all__ = ['Dropout', 'Flatten', 'Dense', 'ClipByNorm', 'Norm', 'OneHot', 'Pad', 'Unfold', __all__ = ['Dropout', 'Flatten', 'Dense', 'ClipByNorm', 'Norm', 'OneHot', 'Pad', 'Unfold',
'MatrixDiag', 'MatrixDiagPart', 'MatrixSetDiag']
'Tril', 'Triu', 'MatrixDiag', 'MatrixDiagPart', 'MatrixSetDiag']




class Dropout(Cell): class Dropout(Cell):
@@ -601,6 +601,80 @@ class Unfold(Cell):
return result return result




@constexpr
def tril(x_shape, x_dtype, k):
Validator.check_int(len(x_shape), 1, Rel.GE, "x rank", "tril")
Validator.check_is_int(k, "k value", "tril")
mask = np.tril(np.ones(x_shape), k)
return Tensor(mask, x_dtype)


class Tril(Cell):
"""
Returns a tensor with elements above the kth diagonal zeroed.

Inputs:
- **x** (Tensor) - The input tensor.
- **k** (Int) - The index of diagonal. Default: 0

Outputs:
Tensor, has the same type as input `x`.

Examples:
>>> x = Tensor(np.array([[1, 2], [3, 4]]))
>>> tril = nn.Tril()
>>> result = tril(x)
>>> print(result)
[[1 0]
[3 4]]
"""
def __init__(self):
super(Tril, self).__init__()
self.dtype = P.DType()
self.mul = P.Mul()

def construct(self, x, k=0):
assist = tril(x.shape, self.dtype(x), k)
return self.mul(x, assist)


@constexpr
def triu(x_shape, x_dtype, k):
Validator.check_int(len(x_shape), 1, Rel.GE, "x rank", "triu")
Validator.check_is_int(k, "k value", "triu")
mask = np.triu(np.ones(x_shape), k)
return Tensor(mask, x_dtype)


class Triu(Cell):
"""
Returns a tensor with elements below the kth diagonal zeroed.

Inputs:
- **x** (Tensor) - The input tensor.
- **k** (Int) - The index of diagonal. Default: 0

Outputs:
Tensor, has the same type as input `x`.

Examples:
>>> x = Tensor(np.array([[1, 2], [3, 4]]))
>>> tril = nn.Tril()
>>> result = tril(x)
>>> print(result)
[[1 2]
[0 4]]
"""
def __init__(self):
super(Triu, self).__init__()
self.dtype = P.DType()
self.mul = P.Mul()

def construct(self, x, k=0):
assist = triu(x.shape, self.dtype(x), k)
return self.mul(x, assist)


@constexpr @constexpr
def _get_matrix_diag_assist(x_shape, x_dtype): def _get_matrix_diag_assist(x_shape, x_dtype):
Validator.check_int(len(x_shape), 1, Rel.GE, "x rank", "_get_matrix_diag_assist") Validator.check_int(len(x_shape), 1, Rel.GE, "x rank", "_get_matrix_diag_assist")


+ 108
- 0
tests/ut/python/nn/test_tril.py View File

@@ -0,0 +1,108 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
test nn.Tril()
"""
import numpy as np

import mindspore.nn as nn
from mindspore import Tensor
from mindspore import context

context.set_context(mode=context.GRAPH_MODE)


def test_tril():
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
self.value = Tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])

def construct(self):
tril = nn.Tril()
return tril(self.value, 0)

net = Net()
out = net()
assert np.sum(out.asnumpy()) == 34


def test_tril_1():
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
self.value = Tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])

def construct(self):
tril = nn.Tril()
return tril(self.value, 1)

net = Net()
out = net()
assert np.sum(out.asnumpy()) == 42


def test_tril_2():
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
self.value = Tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])

def construct(self):
tril = nn.Tril()
return tril(self.value, -1)

net = Net()
out = net()
assert np.sum(out.asnumpy()) == 19


def test_tril_parameter():
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()

def construct(self, x):
tril = nn.Tril()
return tril(x, 0)

net = Net()
net(Tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]]))


def test_tril_parameter_1():
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()

def construct(self, x):
tril = nn.Tril()
return tril(x, 1)

net = Net()
net(Tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]]))


def test_tril_parameter_2():
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()

def construct(self, x):
tril = nn.Tril()
return tril(x, -1)

net = Net()
net(Tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]]))

+ 108
- 0
tests/ut/python/nn/test_triu.py View File

@@ -0,0 +1,108 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
test nn.Triu()
"""
import numpy as np

import mindspore.nn as nn
from mindspore import Tensor
from mindspore import context

context.set_context(mode=context.GRAPH_MODE)


def test_triu():
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
self.value = Tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])

def construct(self):
triu = nn.Triu()
return triu(self.value, 0)

net = Net()
out = net()
assert np.sum(out.asnumpy()) == 26


def test_triu_1():
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
self.value = Tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])

def construct(self):
triu = nn.Triu()
return triu(self.value, 1)

net = Net()
out = net()
assert np.sum(out.asnumpy()) == 11


def test_triu_2():
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
self.value = Tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])

def construct(self):
triu = nn.Triu()
return triu(self.value, -1)

net = Net()
out = net()
assert np.sum(out.asnumpy()) == 38


def test_triu_parameter():
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()

def construct(self, x):
triu = nn.Triu()
return triu(x, 0)

net = Net()
net(Tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]]))


def test_triu_parameter_1():
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()

def construct(self, x):
triu = nn.Triu()
return triu(x, 1)

net = Net()
net(Tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]]))


def test_triu_parameter_2():
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()

def construct(self, x):
triu = nn.Triu()
return triu(x, -1)

net = Net()
net(Tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]]))

Loading…
Cancel
Save