|
|
|
@@ -202,6 +202,12 @@ class Flatten(Cell): |
|
|
|
def construct(self, x): |
|
|
|
return F.reshape(x, (F.shape(x)[0], -1)) |
|
|
|
|
|
|
|
@constexpr |
|
|
|
def check_dense_input_shape(x): |
|
|
|
if len(x) < 2: |
|
|
|
raise ValueError('For Dense, the dimension of input should not be less than 2, while the input dimension is ' |
|
|
|
+ f'{len(x)}.') |
|
|
|
|
|
|
|
class Dense(Cell): |
|
|
|
r""" |
|
|
|
The dense connected layer. |
|
|
|
@@ -291,6 +297,7 @@ class Dense(Cell): |
|
|
|
|
|
|
|
def construct(self, x): |
|
|
|
x_shape = self.shape_op(x) |
|
|
|
check_dense_input_shape(x_shape) |
|
|
|
if len(x_shape) != 2: |
|
|
|
x = self.reshape(x, (-1, x_shape[-1])) |
|
|
|
x = self.matmul(x, self.weight) |
|
|
|
|