| @@ -152,22 +152,12 @@ class Flatten(Cell): | |||
| def construct(self, x): | |||
| return F.reshape(x, (F.shape(x)[0], -1)) | |||
| def matmul_bias_select(x_shape, in_channel, out_channel): | |||
| """matmul and bias_add selection for different input""" | |||
| x_dim = len(x_shape) | |||
| @constexpr | |||
| def get_broadcast_weight_bias_shape(x_shape, out_channel, in_channel): | |||
| """get broadcast_weight_bias shape""" | |||
| broad_weight_shape = x_shape[:-2] + (out_channel, in_channel) | |||
| broad_bias_shape = x_shape[:-1] + (out_channel,) | |||
| weight_broadcast_to = P.BroadcastTo(broad_weight_shape) | |||
| bias_broadcast_to = P.BroadcastTo(broad_bias_shape) | |||
| if x_dim == 2: | |||
| matmul = P.MatMul(False, True) | |||
| bias_add = P.BiasAdd() | |||
| else: | |||
| matmul = P.BatchMatMul(False, True) | |||
| bias_add = P.TensorAdd() | |||
| return matmul, bias_add, weight_broadcast_to, bias_broadcast_to | |||
| return broad_weight_shape, broad_bias_shape | |||
| class Dense(Cell): | |||
| r""" | |||
| @@ -236,7 +226,11 @@ class Dense(Cell): | |||
| if bias_init.dim() != 1 or bias_init.shape[0] != out_channels: | |||
| raise ValueError("Bias init shape error.") | |||
| self.bias = Parameter(initializer(bias_init, [out_channels]), name="bias") | |||
| self.bias_add = P.BiasAdd() | |||
| self.tensor_add = P.TensorAdd() | |||
| self.matmul = P.MatMul(transpose_b=True) | |||
| self.batch_matmul = P.BatchMatMul(transpose_b=True) | |||
| self.activation = get_activation(activation) if isinstance(activation, str) else activation | |||
| if activation is not None and not isinstance(self.activation, (Cell, Primitive)): | |||
| raise TypeError("The activation must be str or Cell or Primitive,"" but got {}.".format(activation)) | |||
| @@ -245,12 +239,23 @@ class Dense(Cell): | |||
| def construct(self, x): | |||
| x_shape = self.shape_op(x) | |||
| x_dim = len(x_shape) | |||
| matmul, bias_add, weight_broadcast_to, bias_broadcast_to = matmul_bias_select(x_shape, self.in_channels, | |||
| self.out_channels) | |||
| weight = self.weight if x_dim == 2 else weight_broadcast_to(self.weight) | |||
| if x_dim == 2: | |||
| matmul = self.matmul | |||
| bias_add = self.bias_add if self.has_bias else None | |||
| weight = self.weight | |||
| bias = self.bias | |||
| else: | |||
| broad_weight_shape, broad_bias_shape = get_broadcast_weight_bias_shape(x_shape, self.out_channels, | |||
| self.in_channels) | |||
| weight_broadcast_to = P.BroadcastTo(broad_weight_shape) | |||
| bias_broadcast_to = P.BroadcastTo(broad_bias_shape) | |||
| matmul = self.batch_matmul | |||
| bias_add = self.tensor_add if self.has_bias else None | |||
| weight = weight_broadcast_to(self.weight) | |||
| bias = bias_broadcast_to(self.bias) if self.has_bias else self.bias | |||
| x = matmul(x, weight) | |||
| if self.has_bias: | |||
| bias = self.bias if x_dim == 2 else bias_broadcast_to(self.bias) | |||
| x = bias_add(x, bias) | |||
| if self.activation_flag: | |||
| x = self.activation(x) | |||
| @@ -21,34 +21,10 @@ from mindspore import Tensor | |||
| from mindspore.nn import TrainOneStepCell, WithLossCell | |||
| from mindspore.nn.optim import Momentum | |||
| from mindspore.ops import operations as P | |||
| from mindspore.common.initializer import initializer | |||
| from mindspore.common.parameter import Parameter | |||
| context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") | |||
| class DenseDevice(nn.Cell): | |||
| def __init__(self, | |||
| in_channels, | |||
| out_channels, | |||
| device_target='Ascend', | |||
| weight_init='normal', | |||
| bias_init='zeros'): | |||
| super(DenseDevice, self).__init__() | |||
| self.device_target = device_target | |||
| self.weight = Parameter(initializer(weight_init, [out_channels, in_channels]), name="weight") | |||
| self.bias = Parameter(initializer(bias_init, [out_channels]), name="bias") | |||
| self.bias_add = P.BiasAdd() | |||
| self.matmul = P.MatMul(transpose_b=True) | |||
| self.matmul.add_prim_attr("primitive_target", self.device_target) | |||
| self.bias_add.add_prim_attr("primitive_target", self.device_target) | |||
| def construct(self, x): | |||
| x = self.matmul(x, self.weight) | |||
| x = self.bias_add(x, self.bias) | |||
| return x | |||
| class LeNet(nn.Cell): | |||
| def __init__(self): | |||
| super(LeNet, self).__init__() | |||
| @@ -59,9 +35,15 @@ class LeNet(nn.Cell): | |||
| self.conv2 = nn.Conv2d(6, 16, kernel_size=5, stride=1, padding=0, has_bias=False, pad_mode='valid') | |||
| self.pool = nn.MaxPool2d(kernel_size=2, stride=2) | |||
| self.reshape = P.Reshape() | |||
| self.fc1 = DenseDevice(400, 120, device_target='CPU') | |||
| self.fc2 = DenseDevice(120, 84, device_target='CPU') | |||
| self.fc3 = DenseDevice(84, 10, device_target='CPU') | |||
| self.fc1 = nn.Dense(400, 120) | |||
| self.fc1.matmul.add_prim_attr("primitive_target", "CPU") | |||
| self.fc1.bias_add.add_prim_attr("primitive_target", "CPU") | |||
| self.fc2 = nn.Dense(120, 84) | |||
| self.fc2.matmul.add_prim_attr("primitive_target", "CPU") | |||
| self.fc2.bias_add.add_prim_attr("primitive_target", "CPU") | |||
| self.fc3 = nn.Dense(84, 10) | |||
| self.fc3.matmul.add_prim_attr("primitive_target", "CPU") | |||
| self.fc3.bias_add.add_prim_attr("primitive_target", "CPU") | |||
| def construct(self, input_x): | |||
| output = self.conv1(input_x) | |||