@@ -153,6 +153,22 @@ class Flatten(Cell):
         return F.reshape(x, (F.shape(x)[0], -1))
 
 
+def matmul_bias_select(x_shape, in_channel, out_channel):
+    """matmul and bias_add selection for different input"""
+    x_dim = len(x_shape)
+    broad_weight_shape = x_shape[:-2] + (out_channel, in_channel)
+    broad_bias_shape = x_shape[:-1] + (out_channel,)
+    weight_broadcast_to = P.BroadcastTo(broad_weight_shape)
+    bias_broadcast_to = P.BroadcastTo(broad_bias_shape)
+    if x_dim == 2:
+        matmul = P.MatMul(False, True)
+        bias_add = P.BiasAdd()
+    else:
+        matmul = P.BatchMatMul(False, True)
+        bias_add = P.TensorAdd()
+    return matmul, bias_add, weight_broadcast_to, bias_broadcast_to
+
+
 class Dense(Cell):
     r"""
     The dense connected layer.
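Note: for inputs of rank greater than 2 the helper keeps MatMul/BiasAdd only in the 2-D case and otherwise switches to BatchMatMul/TensorAdd, broadcasting the (out_channel, in_channel) weight and the (out_channel,) bias up to the leading batch dimensions of x. A quick sketch of the shape arithmetic it performs for a rank-3 input (the concrete sizes below are illustrative assumptions, not taken from the patch):

    # Illustrative shapes only: a 3-D input of shape (batch, rows, in_channel).
    x_shape = (8, 32, 64)
    in_channel, out_channel = 64, 16

    broad_weight_shape = x_shape[:-2] + (out_channel, in_channel)   # (8, 16, 64)
    broad_bias_shape = x_shape[:-1] + (out_channel,)                # (8, 32, 16)

    # BatchMatMul(False, True): (8, 32, 64) x (8, 16, 64) with the last two
    # axes of the weight transposed -> (8, 32, 16); TensorAdd then adds the
    # broadcast bias element-wise.
    print(broad_weight_shape, broad_bias_shape)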
@@ -206,6 +222,7 @@ class Dense(Cell):
         self.in_channels = Validator.check_positive_int(in_channels)
         self.out_channels = Validator.check_positive_int(out_channels)
         self.has_bias = Validator.check_bool(has_bias)
+        self.shape_op = P.Shape()
         if isinstance(weight_init, Tensor):
             if weight_init.dim() != 2 or weight_init.shape[0] != out_channels or \
| @@ -219,28 +236,33 @@ class Dense(Cell): | |||
| if bias_init.dim() != 1 or bias_init.shape[0] != out_channels: | |||
| raise ValueError("Bias init shape error.") | |||
| self.bias = Parameter(initializer(bias_init, [out_channels]), name="bias") | |||
| self.bias_add = P.BiasAdd() | |||
| self.matmul = P.MatMul(transpose_b=True) | |||
| self.activation = get_activation(activation) if isinstance(activation, str) else activation | |||
| if activation is not None and not isinstance(self.activation, (Cell, Primitive)): | |||
| raise TypeError("The activation must be str or Cell or Primitive,"" but got {}.".format(activation)) | |||
| self.activation_flag = self.activation is not None | |||
| def construct(self, x): | |||
| x = self.matmul(x, self.weight) | |||
| x_shape = self.shape_op(x) | |||
| x_dim = len(x_shape) | |||
| matmul, bias_add, weight_broadcast_to, bias_broadcast_to = matmul_bias_select(x_shape, self.in_channels, | |||
| self.out_channels) | |||
| weight = self.weight if x_dim == 2 else weight_broadcast_to(self.weight) | |||
| x = matmul(x, weight) | |||
| if self.has_bias: | |||
| x = self.bias_add(x, self.bias) | |||
| bias = self.bias if x_dim == 2 else bias_broadcast_to(self.bias) | |||
| x = bias_add(x, bias) | |||
| if self.activation_flag: | |||
| x = self.activation(x) | |||
| return x | |||
| def extend_repr(self): | |||
| s = 'input_channels={}, output_channels={}'.format(self.in_channels, self.out_channels) | |||
| if self.has_bias: | |||
| s += ', has_bias={}'.format(self.has_bias) | |||
| if self.activation_flag: | |||
| s += ', activation={}'.fomat(self.activation) | |||
| s += ', activation={}'.format(self.activation) | |||
| return s | |||
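With this change a Dense cell accepts inputs of rank greater than 2: the last axis must still equal in_channels, and every leading axis is treated as a batch dimension. A minimal usage sketch, assuming a build that includes this patch (the shapes and the PyNative setup are illustrative, not part of the patch):

    import numpy as np
    import mindspore.nn as nn
    from mindspore import Tensor, context

    context.set_context(mode=context.PYNATIVE_MODE)

    dense = nn.Dense(64, 16)                                   # weight (16, 64), bias (16,)
    x = Tensor(np.random.randn(8, 32, 64).astype(np.float32))  # rank-3 input
    y = dense(x)
    print(y.asnumpy().shape)                                   # expected: (8, 32, 16)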
@@ -21,10 +21,34 @@ from mindspore import Tensor
 from mindspore.nn import TrainOneStepCell, WithLossCell
 from mindspore.nn.optim import Momentum
 from mindspore.ops import operations as P
+from mindspore.common.initializer import initializer
+from mindspore.common.parameter import Parameter
 
 context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
 
 
+class DenseDevice(nn.Cell):
+    def __init__(self,
+                 in_channels,
+                 out_channels,
+                 device_target='Ascend',
+                 weight_init='normal',
+                 bias_init='zeros'):
+        super(DenseDevice, self).__init__()
+        self.device_target = device_target
+        self.weight = Parameter(initializer(weight_init, [out_channels, in_channels]), name="weight")
+        self.bias = Parameter(initializer(bias_init, [out_channels]), name="bias")
+        self.bias_add = P.BiasAdd()
+        self.matmul = P.MatMul(transpose_b=True)
+        self.matmul.add_prim_attr("primitive_target", self.device_target)
+        self.bias_add.add_prim_attr("primitive_target", self.device_target)
+
+    def construct(self, x):
+        x = self.matmul(x, self.weight)
+        x = self.bias_add(x, self.bias)
+        return x
+
+
 class LeNet(nn.Cell):
     def __init__(self):
         super(LeNet, self).__init__()
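DenseDevice recreates the behaviour this test used to get from Dense directly: since Dense no longer exposes matmul/bias_add attributes, the test defines its own fully connected cell whose primitives carry a primitive_target attribute, so they are dispatched to the CPU while the rest of the network runs on the default device. A rough usage sketch, assuming an environment where this heterogeneous dispatch is actually available:

    import numpy as np
    from mindspore import Tensor

    # Runs its MatMul/BiasAdd on CPU even though the surrounding graph is
    # compiled for the device chosen via context.set_context above.
    fc = DenseDevice(400, 120, device_target='CPU')
    out = fc(Tensor(np.random.randn(2, 400).astype(np.float32)))
    print(out.asnumpy().shape)   # (2, 120)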
@@ -35,15 +59,9 @@ class LeNet(nn.Cell):
         self.conv2 = nn.Conv2d(6, 16, kernel_size=5, stride=1, padding=0, has_bias=False, pad_mode='valid')
         self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
         self.reshape = P.Reshape()
-        self.fc1 = nn.Dense(400, 120)
-        self.fc1.matmul.add_prim_attr("primitive_target", "CPU")
-        self.fc1.bias_add.add_prim_attr("primitive_target", "CPU")
-        self.fc2 = nn.Dense(120, 84)
-        self.fc2.matmul.add_prim_attr("primitive_target", "CPU")
-        self.fc2.bias_add.add_prim_attr("primitive_target", "CPU")
-        self.fc3 = nn.Dense(84, 10)
-        self.fc3.matmul.add_prim_attr("primitive_target", "CPU")
-        self.fc3.bias_add.add_prim_attr("primitive_target", "CPU")
+        self.fc1 = DenseDevice(400, 120, device_target='CPU')
+        self.fc2 = DenseDevice(120, 84, device_target='CPU')
+        self.fc3 = DenseDevice(84, 10, device_target='CPU')
 
     def construct(self, input_x):
         output = self.conv1(input_x)
@@ -38,3 +38,11 @@ def test_net():
     output = net(Tensor(x))
     print(x)
     print(output.asnumpy())
+
+
+def test_net_ND():
+    x = np.random.randn(2, 332, 2048).astype(np.float32)
+    net = Net()
+    output = net(Tensor(x))
+    print(x)
+    print(output.asnumpy())
@@ -49,3 +49,10 @@ def test_net():
     net = Grad(Net())
     output = net(Tensor(x), Tensor(sens))
     print(output.asnumpy())
+
+
+def test_net_ND():
+    x = np.random.randn(2, 32, 2048).astype(np.float32)
+    sens = np.random.randn(2, 32, 1001).astype(np.float32)
+    net = Grad(Net())
+    output = net(Tensor(x), Tensor(sens))
+    print(output.asnumpy())
@@ -128,6 +128,47 @@ def test_dx():
     assert np.all(-diff < error)
 
 
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_dx_ND():
+    x = np.array([[[0.1, 0.2, 0.3, 0.4],
+                   [0.1, 0.2, 0.3, 0.4],
+                   [0.1, 0.2, 0.3, 0.4]],
+                  [[0.1, 0.2, 0.3, 0.4],
+                   [0.1, 0.2, 0.3, 0.4],
+                   [0.1, 0.2, 0.3, 0.4]]
+                  ]).astype(np.float32)
+    dy = np.array([[[1, 1],
+                    [1, 1],
+                    [1, 1]],
+                   [[1, 1],
+                    [1, 1],
+                    [1, 1]]]).astype(np.float32)
+    dx_expect = np.array([[[1.1, 1.8, 1.1, 1.1],
+                           [1.1, 1.8, 1.1, 1.1],
+                           [1.1, 1.8, 1.1, 1.1]],
+                          [[1.1, 1.8, 1.1, 1.1],
+                           [1.1, 1.8, 1.1, 1.1],
+                           [1.1, 1.8, 1.1, 1.1]]
+                          ]).astype(np.float32)
+    error = np.ones(shape=[2, 3, 4]) * 1.0e-6
+
+    context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
+    net = GradData(DenseNet())
+    dx = net(Tensor(x), Tensor(dy))
+    diff = dx[0].asnumpy() - dx_expect
+    assert np.all(diff < error)
+    assert np.all(-diff < error)
+
+    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
+    net = GradData(DenseNet())
+    dx = net(Tensor(x), Tensor(dy))
+    diff = dx[0].asnumpy() - dx_expect
+    assert np.all(diff < error)
+    assert np.all(-diff < error)
+
+
 @pytest.mark.level0
 @pytest.mark.platform_x86_gpu_training
 @pytest.mark.env_onecard
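The expected values in the test_dw_ND hunk below can be checked by hand using only the x and dy arrays defined there: dy is all ones, so the weight gradient accumulates x over both the batch and the row axes, and the bias gradient is simply the count of contributing rows per output channel. A small NumPy verification (the einsum index names are my own, not from the patch):

    import numpy as np

    x = np.broadcast_to(np.array([0.1, 0.2, 0.3, 0.4], np.float32), (2, 3, 4))  # same values as the test
    dy = np.ones((2, 3, 2), dtype=np.float32)

    dw = np.einsum('bno,bni->oi', dy, x)   # accumulate over batch (b) and row (n) axes
    db = dy.sum(axis=(0, 1))

    print(dw)   # ~[[0.6 1.2 1.8 2.4] [0.6 1.2 1.8 2.4]] == 2 * [0.3 0.6 0.9 1.2]
    print(db)   # [6. 6.]                                 == 2 * [3 3]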
@@ -165,6 +206,49 @@ def test_dw():
     assert np.all(-diff < db_error)
 
 
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_dw_ND():
+    x = np.array([[[0.1, 0.2, 0.3, 0.4],
+                   [0.1, 0.2, 0.3, 0.4],
+                   [0.1, 0.2, 0.3, 0.4]],
+                  [[0.1, 0.2, 0.3, 0.4],
+                   [0.1, 0.2, 0.3, 0.4],
+                   [0.1, 0.2, 0.3, 0.4]]]).astype(np.float32)
+    dy = np.array([[[1, 1],
+                    [1, 1],
+                    [1, 1]],
+                   [[1, 1],
+                    [1, 1],
+                    [1, 1]]]).astype(np.float32)
+    dw_expect = 2 * np.array([[0.3, 0.6, 0.9, 1.2],
+                              [0.3, 0.6, 0.9, 1.2]]).astype(np.float32)
+    dw_error = np.ones(shape=[2, 4]) * 1.0e-6
+    db_expect = 2 * np.array([3, 3]).astype(np.float32)
+    db_error = np.ones(shape=[2]) * 1.0e-6
+
+    context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
+    net = GradWeight(DenseNet())
+    dw, db = net(Tensor(x), Tensor(dy))
+    diff = dw.asnumpy() - dw_expect
+    assert np.all(diff < dw_error)
+    assert np.all(-diff < dw_error)
+    diff = db.asnumpy() - db_expect
+    assert np.all(diff < db_error)
+    assert np.all(-diff < db_error)
+
+    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
+    net = GradWeight(DenseNet())
+    dw, db = net(Tensor(x), Tensor(dy))
+    diff = dw.asnumpy() - dw_expect
+    assert np.all(diff < dw_error)
+    assert np.all(-diff < dw_error)
+    diff = db.asnumpy() - db_expect
+    assert np.all(diff < db_error)
+    assert np.all(-diff < db_error)
+
+
 class Grad(nn.Cell):
     def __init__(self, network):
         super(Grad, self).__init__()