From 2608c17bee1c5a6cc2c28b2a7213866c8bbdb3e6 Mon Sep 17 00:00:00 2001 From: l00591931 Date: Tue, 8 Dec 2020 11:18:04 +0800 Subject: [PATCH] Fix bug of tril/triu and enumerate under pynative --- mindspore/common/tensor.py | 4 +++- mindspore/nn/layer/basic.py | 4 ++-- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/mindspore/common/tensor.py b/mindspore/common/tensor.py index 9f8de3f49d..5115ea1ff5 100644 --- a/mindspore/common/tensor.py +++ b/mindspore/common/tensor.py @@ -176,6 +176,8 @@ class Tensor(Tensor_): return out def __getitem__(self, index): + if isinstance(index, int) and index >= self.shape[0]: + raise IndexError("index {} is out of bounds for axis 0 with size {}".format(index, self.shape[0])) out = tensor_operator_registry.get('__getitem__')(self, index) return out @@ -319,7 +321,7 @@ class Tensor(Tensor_): Args: shape (Tensor): The input tensor. The shape of input tensor must obey - the broadcasting rule. + the broadcasting rule. Returns: Tensor, has the same dimension as input tensor. diff --git a/mindspore/nn/layer/basic.py b/mindspore/nn/layer/basic.py index f660ce67ab..51b21fcc8f 100644 --- a/mindspore/nn/layer/basic.py +++ b/mindspore/nn/layer/basic.py @@ -725,7 +725,7 @@ class Tril(Cell): def construct(self, x, k=0): assist = tril(x.shape, self.dtype(x), k) - result = self.mul(self.cast(x, mstype.int32), self.cast(assist, mstype.int32)) + result = self.mul(self.cast(x, mstype.float32), self.cast(assist, mstype.float32)) return self.cast(result, self.dtype(x)) @@ -767,7 +767,7 @@ class Triu(Cell): def construct(self, x, k=0): assist = triu(x.shape, self.dtype(x), k) - result = self.mul(self.cast(x, mstype.int32), self.cast(assist, mstype.int32)) + result = self.mul(self.cast(x, mstype.float32), self.cast(assist, mstype.float32)) return self.cast(result, self.dtype(x))