Browse Source

!12031 Add shape check to Dense

From: @wanyiming
Reviewed-by: @zh_qh,@kingxian
Signed-off-by: @kingxian
tags/v1.2.0-rc1
mindspore-ci-bot Gitee 5 years ago
parent
commit
5e29ea8184
1 changed files with 7 additions and 0 deletions
  1. +7
    -0
      mindspore/nn/layer/basic.py

+ 7
- 0
mindspore/nn/layer/basic.py View File

@@ -202,6 +202,12 @@ class Flatten(Cell):
def construct(self, x):
return F.reshape(x, (F.shape(x)[0], -1))

@constexpr
def check_dense_input_shape(x):
if len(x) < 2:
raise ValueError('For Dense, the dimension of input should not be less than 2, while the input dimension is '
+ f'{len(x)}.')

class Dense(Cell):
r"""
The dense connected layer.
@@ -291,6 +297,7 @@ class Dense(Cell):

def construct(self, x):
x_shape = self.shape_op(x)
check_dense_input_shape(x_shape)
if len(x_shape) != 2:
x = self.reshape(x, (-1, x_shape[-1]))
x = self.matmul(x, self.weight)


Loading…
Cancel
Save