diff --git a/mindspore/nn/layer/basic.py b/mindspore/nn/layer/basic.py index e82acc8bce..d9d1724286 100644 --- a/mindspore/nn/layer/basic.py +++ b/mindspore/nn/layer/basic.py @@ -202,6 +202,12 @@ class Flatten(Cell): def construct(self, x): return F.reshape(x, (F.shape(x)[0], -1)) +@constexpr +def check_dense_input_shape(x): + if len(x) < 2: + raise ValueError('For Dense, the dimension of input should not be less than 2, while the input dimension is ' + + f'{len(x)}.') + class Dense(Cell): r""" The dense connected layer. @@ -291,6 +297,7 @@ class Dense(Cell): def construct(self, x): x_shape = self.shape_op(x) + check_dense_input_shape(x_shape) if len(x_shape) != 2: x = self.reshape(x, (-1, x_shape[-1])) x = self.matmul(x, self.weight)