From e5eb05adc4a5908debdc391ea34bde8e1a30d473 Mon Sep 17 00:00:00 2001 From: wanyiming Date: Wed, 3 Feb 2021 10:50:43 +0800 Subject: [PATCH] checkdense --- mindspore/nn/layer/basic.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/mindspore/nn/layer/basic.py b/mindspore/nn/layer/basic.py index 6b081a0722..053b822841 100644 --- a/mindspore/nn/layer/basic.py +++ b/mindspore/nn/layer/basic.py @@ -195,6 +195,12 @@ class Flatten(Cell): def construct(self, x): return F.reshape(x, (F.shape(x)[0], -1)) +@constexpr +def check_dense_input_shape(x): + if len(x) < 2: + raise ValueError('For Dense, the dimension of input should not be less than 2, while the input dimension is ' + + f'{len(x)}.') + class Dense(Cell): r""" The dense connected layer. @@ -278,6 +284,7 @@ class Dense(Cell): def construct(self, x): x_shape = self.shape_op(x) + check_dense_input_shape(x_shape) if len(x_shape) != 2: x = self.reshape(x, (-1, x_shape[-1])) x = self.matmul(x, self.weight)