| @@ -195,6 +195,12 @@ class Flatten(Cell): | |||||
| def construct(self, x): | def construct(self, x): | ||||
| return F.reshape(x, (F.shape(x)[0], -1)) | return F.reshape(x, (F.shape(x)[0], -1)) | ||||
| @constexpr | |||||
| def check_dense_input_shape(x): | |||||
| if len(x) < 2: | |||||
| raise ValueError('For Dense, the dimension of input should not be less than 2, while the input dimension is ' | |||||
| + f'{len(x)}.') | |||||
| class Dense(Cell): | class Dense(Cell): | ||||
| r""" | r""" | ||||
| The dense connected layer. | The dense connected layer. | ||||
| @@ -278,6 +284,7 @@ class Dense(Cell): | |||||
| def construct(self, x): | def construct(self, x): | ||||
| x_shape = self.shape_op(x) | x_shape = self.shape_op(x) | ||||
| check_dense_input_shape(x_shape) | |||||
| if len(x_shape) != 2: | if len(x_shape) != 2: | ||||
| x = self.reshape(x, (-1, x_shape[-1])) | x = self.reshape(x, (-1, x_shape[-1])) | ||||
| x = self.matmul(x, self.weight) | x = self.matmul(x, self.weight) | ||||