|
|
|
@@ -195,15 +195,6 @@ class Flatten(Cell): |
|
|
|
def construct(self, x): |
|
|
|
return F.reshape(x, (F.shape(x)[0], -1)) |
|
|
|
|
|
|
|
|
|
|
|
@constexpr |
|
|
|
def get_broadcast_weight_bias_shape(x_shape, out_channel, in_channel): |
|
|
|
"""get broadcast_weight_bias shape""" |
|
|
|
broad_weight_shape = x_shape[:-2] + (out_channel, in_channel) |
|
|
|
broad_bias_shape = x_shape[:-1] + (out_channel,) |
|
|
|
return broad_weight_shape, broad_bias_shape |
|
|
|
|
|
|
|
|
|
|
|
class Dense(Cell): |
|
|
|
r""" |
|
|
|
The dense connected layer. |
|
|
|
@@ -249,7 +240,7 @@ class Dense(Cell): |
|
|
|
(2, 4) |
|
|
|
""" |
|
|
|
|
|
|
|
@cell_attr_register(attrs=['has_bias', 'activation', 'in_channels', 'out_channels']) |
|
|
|
@cell_attr_register(attrs=['has_bias', 'activation']) |
|
|
|
def __init__(self, |
|
|
|
in_channels, |
|
|
|
out_channels, |
|
|
|
@@ -261,8 +252,10 @@ class Dense(Cell): |
|
|
|
self.in_channels = Validator.check_positive_int(in_channels) |
|
|
|
self.out_channels = Validator.check_positive_int(out_channels) |
|
|
|
self.has_bias = Validator.check_bool(has_bias) |
|
|
|
self.reshape = P.Reshape() |
|
|
|
self.shape_op = P.Shape() |
|
|
|
|
|
|
|
|
|
|
|
if isinstance(weight_init, Tensor): |
|
|
|
if weight_init.ndim != 2 or weight_init.shape[0] != out_channels or \ |
|
|
|
weight_init.shape[1] != in_channels: |
|
|
|
@@ -276,10 +269,8 @@ class Dense(Cell): |
|
|
|
raise ValueError("Bias init shape error.") |
|
|
|
self.bias = Parameter(initializer(bias_init, [out_channels]), name="bias") |
|
|
|
self.bias_add = P.BiasAdd() |
|
|
|
self.tensor_add = P.TensorAdd() |
|
|
|
|
|
|
|
self.matmul = P.MatMul(transpose_b=True) |
|
|
|
self.batch_matmul = P.BatchMatMul(transpose_b=True) |
|
|
|
self.activation = get_activation(activation) if isinstance(activation, str) else activation |
|
|
|
if activation is not None and not isinstance(self.activation, (Cell, Primitive)): |
|
|
|
raise TypeError("The activation must be str or Cell or Primitive,"" but got {}.".format(activation)) |
|
|
|
@@ -287,27 +278,16 @@ class Dense(Cell): |
|
|
|
|
|
|
|
def construct(self, x): |
|
|
|
x_shape = self.shape_op(x) |
|
|
|
x_dim = len(x_shape) |
|
|
|
if x_dim == 2: |
|
|
|
matmul = self.matmul |
|
|
|
bias_add = self.bias_add if self.has_bias else None |
|
|
|
weight = self.weight |
|
|
|
bias = self.bias |
|
|
|
else: |
|
|
|
broad_weight_shape, broad_bias_shape = get_broadcast_weight_bias_shape(x_shape, self.out_channels, |
|
|
|
self.in_channels) |
|
|
|
weight_broadcast_to = P.BroadcastTo(broad_weight_shape) |
|
|
|
bias_broadcast_to = P.BroadcastTo(broad_bias_shape) |
|
|
|
matmul = self.batch_matmul |
|
|
|
bias_add = self.tensor_add if self.has_bias else None |
|
|
|
weight = weight_broadcast_to(self.weight) |
|
|
|
bias = bias_broadcast_to(self.bias) if self.has_bias else self.bias |
|
|
|
|
|
|
|
x = matmul(x, weight) |
|
|
|
if len(x_shape) != 2: |
|
|
|
x = self.reshape(x, (-1, x_shape[-1])) |
|
|
|
x = self.matmul(x, self.weight) |
|
|
|
if self.has_bias: |
|
|
|
x = bias_add(x, bias) |
|
|
|
x = self.bias_add(x, self.bias) |
|
|
|
if self.activation_flag: |
|
|
|
x = self.activation(x) |
|
|
|
if len(x_shape) != 2: |
|
|
|
out_shape = x_shape[:-1] + (-1,) |
|
|
|
x = self.reshape(x, out_shape) |
|
|
|
return x |
|
|
|
|
|
|
|
def extend_repr(self): |
|
|
|
|