diff --git a/mindspore/nn/layer/basic.py b/mindspore/nn/layer/basic.py
index 5de8bf80e9..5e747a800d 100644
--- a/mindspore/nn/layer/basic.py
+++ b/mindspore/nn/layer/basic.py
@@ -195,15 +195,6 @@ class Flatten(Cell):
     def construct(self, x):
         return F.reshape(x, (F.shape(x)[0], -1))
 
-
-@constexpr
-def get_broadcast_weight_bias_shape(x_shape, out_channel, in_channel):
-    """get broadcast_weight_bias shape"""
-    broad_weight_shape = x_shape[:-2] + (out_channel, in_channel)
-    broad_bias_shape = x_shape[:-1] + (out_channel,)
-    return broad_weight_shape, broad_bias_shape
-
-
 class Dense(Cell):
     r"""
     The dense connected layer.
@@ -249,7 +240,7 @@ class Dense(Cell):
         (2, 4)
     """
 
-    @cell_attr_register(attrs=['has_bias', 'activation', 'in_channels', 'out_channels'])
+    @cell_attr_register(attrs=['has_bias', 'activation'])
     def __init__(self,
                  in_channels,
                  out_channels,
@@ -261,8 +252,10 @@
 
         self.in_channels = Validator.check_positive_int(in_channels)
         self.out_channels = Validator.check_positive_int(out_channels)
         self.has_bias = Validator.check_bool(has_bias)
+        self.reshape = P.Reshape()
         self.shape_op = P.Shape()
+
         if isinstance(weight_init, Tensor):
             if weight_init.ndim != 2 or weight_init.shape[0] != out_channels or \
                     weight_init.shape[1] != in_channels:
@@ -276,10 +269,8 @@
                 raise ValueError("Bias init shape error.")
             self.bias = Parameter(initializer(bias_init, [out_channels]), name="bias")
             self.bias_add = P.BiasAdd()
-            self.tensor_add = P.TensorAdd()
 
         self.matmul = P.MatMul(transpose_b=True)
-        self.batch_matmul = P.BatchMatMul(transpose_b=True)
         self.activation = get_activation(activation) if isinstance(activation, str) else activation
         if activation is not None and not isinstance(self.activation, (Cell, Primitive)):
             raise TypeError("The activation must be str or Cell or Primitive,"" but got {}.".format(activation))
@@ -287,27 +278,16 @@
 
     def construct(self, x):
         x_shape = self.shape_op(x)
-        x_dim = len(x_shape)
-        if x_dim == 2:
-            matmul = self.matmul
-            bias_add = self.bias_add if self.has_bias else None
-            weight = self.weight
-            bias = self.bias
-        else:
-            broad_weight_shape, broad_bias_shape = get_broadcast_weight_bias_shape(x_shape, self.out_channels,
-                                                                                   self.in_channels)
-            weight_broadcast_to = P.BroadcastTo(broad_weight_shape)
-            bias_broadcast_to = P.BroadcastTo(broad_bias_shape)
-            matmul = self.batch_matmul
-            bias_add = self.tensor_add if self.has_bias else None
-            weight = weight_broadcast_to(self.weight)
-            bias = bias_broadcast_to(self.bias) if self.has_bias else self.bias
-
-        x = matmul(x, weight)
+        if len(x_shape) != 2:
+            x = self.reshape(x, (-1, x_shape[-1]))
+        x = self.matmul(x, self.weight)
         if self.has_bias:
-            x = bias_add(x, bias)
+            x = self.bias_add(x, self.bias)
         if self.activation_flag:
             x = self.activation(x)
+        if len(x_shape) != 2:
+            out_shape = x_shape[:-1] + (-1,)
+            x = self.reshape(x, out_shape)
         return x
 
     def extend_repr(self):
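
Note on the rewrite: it relies on the identity that a dense layer applied to an N-D input can be computed by flattening all leading dimensions into one batch axis, running a single 2-D matmul against the (out_channels, in_channels) weight, and restoring the leading dimensions afterwards, which avoids materializing broadcast copies of the weight and bias for every batch element. Below is a minimal NumPy sketch of that equivalence; the shapes are illustrative, not taken from the patch.

import numpy as np

# Illustrative shapes: leading batch dims (3, 5), in_channels=4, out_channels=2.
x = np.random.randn(3, 5, 4).astype(np.float32)
weight = np.random.randn(2, 4).astype(np.float32)   # (out_channels, in_channels)
bias = np.random.randn(2).astype(np.float32)

# Removed path: broadcast the transposed weight across the batch dim and use a
# batched matmul, mirroring P.BroadcastTo + P.BatchMatMul(transpose_b=True) + P.TensorAdd.
old = np.matmul(x, np.broadcast_to(weight.T, (3, 4, 2))) + bias

# New path: flatten the leading dims, do one 2-D matmul, add the bias, then
# restore the shape, mirroring P.Reshape + P.MatMul(transpose_b=True) + P.BiasAdd.
flat = x.reshape(-1, x.shape[-1])                             # (15, 4)
new = (flat @ weight.T + bias).reshape(x.shape[:-1] + (-1,))  # (3, 5, 2)

assert np.allclose(old, new, atol=1e-5)

One detail of the new construct: the activation is applied to the flattened 2-D tensor before the trailing reshape. Since the flattening leaves the last axis (out_channels) untouched, this matches applying the activation after the reshape for element-wise activations and for those acting along the last axis.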