From 3c7a3b6693ecc07e8e30bcfa53b471a479f4004f Mon Sep 17 00:00:00 2001 From: wanyiming Date: Tue, 3 Nov 2020 17:34:21 +0800 Subject: [PATCH] modify_dense --- mindspore/nn/layer/basic.py | 32 +++++-- .../st/host_device/test_host_device_lenet.py | 36 ++++++-- tests/st/ops/ascend/test_dense.py | 8 ++ tests/st/ops/ascend/test_dense_grad.py | 7 ++ tests/st/ops/gpu/test_dense_op.py | 84 +++++++++++++++++++ 5 files changed, 153 insertions(+), 14 deletions(-) diff --git a/mindspore/nn/layer/basic.py b/mindspore/nn/layer/basic.py index dc17bec715..173b6ea387 100644 --- a/mindspore/nn/layer/basic.py +++ b/mindspore/nn/layer/basic.py @@ -153,6 +153,22 @@ class Flatten(Cell): return F.reshape(x, (F.shape(x)[0], -1)) +def matmul_bias_select(x_shape, in_channel, out_channel): + """matmul and bias_add selection for different input""" + x_dim = len(x_shape) + broad_weight_shape = x_shape[:-2] + (out_channel, in_channel) + broad_bias_shape = x_shape[:-1] + (out_channel,) + weight_broadcast_to = P.BroadcastTo(broad_weight_shape) + bias_broadcast_to = P.BroadcastTo(broad_bias_shape) + if x_dim == 2: + matmul = P.MatMul(False, True) + bias_add = P.BiasAdd() + else: + matmul = P.BatchMatMul(False, True) + bias_add = P.TensorAdd() + return matmul, bias_add, weight_broadcast_to, bias_broadcast_to + + class Dense(Cell): r""" The dense connected layer. @@ -206,6 +222,7 @@ class Dense(Cell): self.in_channels = Validator.check_positive_int(in_channels) self.out_channels = Validator.check_positive_int(out_channels) self.has_bias = Validator.check_bool(has_bias) + self.shape_op = P.Shape() if isinstance(weight_init, Tensor): if weight_init.dim() != 2 or weight_init.shape[0] != out_channels or \ @@ -219,28 +236,33 @@ class Dense(Cell): if bias_init.dim() != 1 or bias_init.shape[0] != out_channels: raise ValueError("Bias init shape error.") self.bias = Parameter(initializer(bias_init, [out_channels]), name="bias") - self.bias_add = P.BiasAdd() - self.matmul = P.MatMul(transpose_b=True) self.activation = get_activation(activation) if isinstance(activation, str) else activation if activation is not None and not isinstance(self.activation, (Cell, Primitive)): raise TypeError("The activation must be str or Cell or Primitive,"" but got {}.".format(activation)) self.activation_flag = self.activation is not None def construct(self, x): - x = self.matmul(x, self.weight) + x_shape = self.shape_op(x) + x_dim = len(x_shape) + matmul, bias_add, weight_broadcast_to, bias_broadcast_to = matmul_bias_select(x_shape, self.in_channels, + self.out_channels) + weight = self.weight if x_dim == 2 else weight_broadcast_to(self.weight) + x = matmul(x, weight) if self.has_bias: - x = self.bias_add(x, self.bias) + bias = self.bias if x_dim == 2 else bias_broadcast_to(self.bias) + x = bias_add(x, bias) if self.activation_flag: x = self.activation(x) return x + def extend_repr(self): s = 'input_channels={}, output_channels={}'.format(self.in_channels, self.out_channels) if self.has_bias: s += ', has_bias={}'.format(self.has_bias) if self.activation_flag: - s += ', activation={}'.fomat(self.activation) + s += ', activation={}'.format(self.activation) return s diff --git a/tests/st/host_device/test_host_device_lenet.py b/tests/st/host_device/test_host_device_lenet.py index 80bf7b578a..243f3facaf 100644 --- a/tests/st/host_device/test_host_device_lenet.py +++ b/tests/st/host_device/test_host_device_lenet.py @@ -21,10 +21,34 @@ from mindspore import Tensor from mindspore.nn import TrainOneStepCell, WithLossCell from mindspore.nn.optim import Momentum 
from mindspore.ops import operations as P +from mindspore.common.initializer import initializer +from mindspore.common.parameter import Parameter context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") +class DenseDevice(nn.Cell): + def __init__(self, + in_channels, + out_channels, + device_target='Ascend', + weight_init='normal', + bias_init='zeros'): + super(DenseDevice, self).__init__() + self.device_target = device_target + self.weight = Parameter(initializer(weight_init, [out_channels, in_channels]), name="weight") + self.bias = Parameter(initializer(bias_init, [out_channels]), name="bias") + self.bias_add = P.BiasAdd() + self.matmul = P.MatMul(transpose_b=True) + self.matmul.add_prim_attr("primitive_target", self.device_target) + self.bias_add.add_prim_attr("primitive_target", self.device_target) + + def construct(self, x): + x = self.matmul(x, self.weight) + x = self.bias_add(x, self.bias) + return x + + class LeNet(nn.Cell): def __init__(self): super(LeNet, self).__init__() @@ -35,15 +59,9 @@ class LeNet(nn.Cell): self.conv2 = nn.Conv2d(6, 16, kernel_size=5, stride=1, padding=0, has_bias=False, pad_mode='valid') self.pool = nn.MaxPool2d(kernel_size=2, stride=2) self.reshape = P.Reshape() - self.fc1 = nn.Dense(400, 120) - self.fc1.matmul.add_prim_attr("primitive_target", "CPU") - self.fc1.bias_add.add_prim_attr("primitive_target", "CPU") - self.fc2 = nn.Dense(120, 84) - self.fc2.matmul.add_prim_attr("primitive_target", "CPU") - self.fc2.bias_add.add_prim_attr("primitive_target", "CPU") - self.fc3 = nn.Dense(84, 10) - self.fc3.matmul.add_prim_attr("primitive_target", "CPU") - self.fc3.bias_add.add_prim_attr("primitive_target", "CPU") + self.fc1 = DenseDevice(400, 120, device_target='CPU') + self.fc2 = DenseDevice(120, 84, device_target='CPU') + self.fc3 = DenseDevice(84, 10, device_target='CPU') def construct(self, input_x): output = self.conv1(input_x) diff --git a/tests/st/ops/ascend/test_dense.py b/tests/st/ops/ascend/test_dense.py index c4916d53cd..65b1eb400f 100644 --- a/tests/st/ops/ascend/test_dense.py +++ b/tests/st/ops/ascend/test_dense.py @@ -38,3 +38,11 @@ def test_net(): output = net(Tensor(x)) print(x) print(output.asnumpy()) + + +def test_net_ND(): + x = np.random.randn(2, 332, 2048).astype(np.float32) + net = Net() + output = net(Tensor(x)) + print(x) + print(output.asnumpy()) diff --git a/tests/st/ops/ascend/test_dense_grad.py b/tests/st/ops/ascend/test_dense_grad.py index 6cd6516da1..1b20078b3f 100644 --- a/tests/st/ops/ascend/test_dense_grad.py +++ b/tests/st/ops/ascend/test_dense_grad.py @@ -49,3 +49,10 @@ def test_net(): net = Grad(Net()) output = net(Tensor(x), Tensor(sens)) print(output.asnumpy()) + +def test_net_ND(): + x = np.random.randn(2, 32, 2048).astype(np.float32) + sens = np.random.randn(2, 32, 1001).astype(np.float32) + net = Grad(Net()) + output = net(Tensor(x), Tensor(sens)) + print(output.asnumpy()) diff --git a/tests/st/ops/gpu/test_dense_op.py b/tests/st/ops/gpu/test_dense_op.py index b07baa658b..4c03bfc417 100644 --- a/tests/st/ops/gpu/test_dense_op.py +++ b/tests/st/ops/gpu/test_dense_op.py @@ -128,6 +128,47 @@ def test_dx(): assert np.all(-diff < error) +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_dx_ND(): + x = np.array([[[0.1, 0.2, 0.3, 0.4], + [0.1, 0.2, 0.3, 0.4], + [0.1, 0.2, 0.3, 0.4]], + [[0.1, 0.2, 0.3, 0.4], + [0.1, 0.2, 0.3, 0.4], + [0.1, 0.2, 0.3, 0.4]] + ]).astype(np.float32) + dy = np.array([[[1, 1], + [1, 1], + [1, 1]], + [[1, 1], + [1, 1], + [1, 
1]]]).astype(np.float32) + dx_expect = np.array([[[1.1, 1.8, 1.1, 1.1], + [1.1, 1.8, 1.1, 1.1], + [1.1, 1.8, 1.1, 1.1]], + [[1.1, 1.8, 1.1, 1.1], + [1.1, 1.8, 1.1, 1.1], + [1.1, 1.8, 1.1, 1.1]] + ]).astype(np.float32) + error = np.ones(shape=[2, 3, 4]) * 1.0e-6 + + context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU") + net = GradData(DenseNet()) + dx = net(Tensor(x), Tensor(dy)) + diff = dx[0].asnumpy() - dx_expect + assert np.all(diff < error) + assert np.all(-diff < error) + + context.set_context(mode=context.GRAPH_MODE, device_target="GPU") + net = GradData(DenseNet()) + dx = net(Tensor(x), Tensor(dy)) + diff = dx[0].asnumpy() - dx_expect + assert np.all(diff < error) + assert np.all(-diff < error) + + @pytest.mark.level0 @pytest.mark.platform_x86_gpu_training @pytest.mark.env_onecard @@ -165,6 +206,49 @@ def test_dw(): assert np.all(-diff < db_error) +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_dw_ND(): + x = np.array([[[0.1, 0.2, 0.3, 0.4], + [0.1, 0.2, 0.3, 0.4], + [0.1, 0.2, 0.3, 0.4]], + [[0.1, 0.2, 0.3, 0.4], + [0.1, 0.2, 0.3, 0.4], + [0.1, 0.2, 0.3, 0.4]]]).astype(np.float32) + dy = np.array([[[1, 1], + [1, 1], + [1, 1]], + [[1, 1], + [1, 1], + [1, 1]]]).astype(np.float32) + dw_expect = 2 * np.array([[0.3, 0.6, 0.9, 1.2], + [0.3, 0.6, 0.9, 1.2]]).astype(np.float32) + dw_error = np.ones(shape=[2, 4]) * 1.0e-6 + db_expect = 2 * np.array([3, 3]).astype(np.float32) + db_error = np.ones(shape=[2]) * 1.0e-6 + + context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU") + net = GradWeight(DenseNet()) + dw, db = net(Tensor(x), Tensor(dy)) + diff = dw.asnumpy() - dw_expect + assert np.all(diff < dw_error) + assert np.all(-diff < dw_error) + diff = db.asnumpy() - db_expect + assert np.all(diff < db_error) + assert np.all(-diff < db_error) + + context.set_context(mode=context.GRAPH_MODE, device_target="GPU") + net = GradWeight(DenseNet()) + dw, db = net(Tensor(x), Tensor(dy)) + diff = dw.asnumpy() - dw_expect + assert np.all(diff < dw_error) + assert np.all(-diff < dw_error) + diff = db.asnumpy() - db_expect + assert np.all(diff < db_error) + assert np.all(-diff < db_error) + + class Grad(nn.Cell): def __init__(self, network): super(Grad, self).__init__()
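
Note (not part of the patch): a minimal usage sketch of the behavior this change enables. It assumes the patched nn.Dense above and a reachable device/backend; the channel sizes and batch shapes are illustrative only and mirror the ND test cases.

    # Illustrative sketch only, not part of the commit.
    # With the patched Dense, a 2-D input keeps the MatMul + BiasAdd path,
    # while an N-D input (N > 2) goes through BatchMatMul with the weight and
    # bias broadcast over the leading batch dimensions by matmul_bias_select.
    import numpy as np
    import mindspore.nn as nn
    from mindspore import Tensor, context

    context.set_context(mode=context.GRAPH_MODE)

    dense = nn.Dense(2048, 1001)   # weight: (1001, 2048), bias: (1001,)

    x_2d = Tensor(np.random.randn(32, 2048).astype(np.float32))
    x_3d = Tensor(np.random.randn(2, 32, 2048).astype(np.float32))

    print(dense(x_2d).shape)   # (32, 1001)    -> MatMul + BiasAdd
    print(dense(x_3d).shape)   # (2, 32, 1001) -> BatchMatMul + broadcast weight/bias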