@@ -153,6 +153,22 @@ class Flatten(Cell):
         return F.reshape(x, (F.shape(x)[0], -1))


+def matmul_bias_select(x_shape, in_channel, out_channel):
+    """matmul and bias_add selection for different input"""
+    x_dim = len(x_shape)
+    broad_weight_shape = x_shape[:-2] + (out_channel, in_channel)
+    broad_bias_shape = x_shape[:-1] + (out_channel,)
+    weight_broadcast_to = P.BroadcastTo(broad_weight_shape)
+    bias_broadcast_to = P.BroadcastTo(broad_bias_shape)
+    if x_dim == 2:
+        matmul = P.MatMul(False, True)
+        bias_add = P.BiasAdd()
+    else:
+        matmul = P.BatchMatMul(False, True)
+        bias_add = P.TensorAdd()
+    return matmul, bias_add, weight_broadcast_to, bias_broadcast_to
+
+
 class Dense(Cell):
     r"""
     The dense connected layer.
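The helper added in this hunk only inspects the input shape: rank-2 inputs keep the old MatMul/BiasAdd pair, while higher-rank inputs get BatchMatMul/TensorAdd together with BroadcastTo ops that expand the (out, in) weight and the bias over the leading dimensions. A rough NumPy sketch of that shape logic (the names below are illustrative and not part of the patch):

import numpy as np

batch, rows, in_ch, out_ch = 2, 3, 4, 5
x = np.random.randn(batch, rows, in_ch).astype(np.float32)
weight = np.random.randn(out_ch, in_ch).astype(np.float32)
bias = np.random.randn(out_ch).astype(np.float32)

# Rank > 2: broadcast the weight the way P.BroadcastTo(broad_weight_shape) does,
# then do a batched matmul with the last weight axis transposed.
broad_weight = np.broadcast_to(weight, (batch, out_ch, in_ch))
y_nd = x @ broad_weight.transpose(0, 2, 1) + bias

# The rank-2 path applied row by row gives the same numbers.
y_2d = x.reshape(-1, in_ch) @ weight.T + bias
assert np.allclose(y_nd.reshape(-1, out_ch), y_2d, atol=1e-5)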
@@ -206,6 +222,7 @@ class Dense(Cell):
         self.in_channels = Validator.check_positive_int(in_channels)
         self.out_channels = Validator.check_positive_int(out_channels)
         self.has_bias = Validator.check_bool(has_bias)
+        self.shape_op = P.Shape()

         if isinstance(weight_init, Tensor):
             if weight_init.dim() != 2 or weight_init.shape[0] != out_channels or \
@@ -219,28 +236,33 @@ class Dense(Cell):
                 if bias_init.dim() != 1 or bias_init.shape[0] != out_channels:
                     raise ValueError("Bias init shape error.")
             self.bias = Parameter(initializer(bias_init, [out_channels]), name="bias")
-            self.bias_add = P.BiasAdd()

-        self.matmul = P.MatMul(transpose_b=True)
         self.activation = get_activation(activation) if isinstance(activation, str) else activation
         if activation is not None and not isinstance(self.activation, (Cell, Primitive)):
             raise TypeError("The activation must be str or Cell or Primitive,"" but got {}.".format(activation))
         self.activation_flag = self.activation is not None

     def construct(self, x):
-        x = self.matmul(x, self.weight)
+        x_shape = self.shape_op(x)
+        x_dim = len(x_shape)
+        matmul, bias_add, weight_broadcast_to, bias_broadcast_to = matmul_bias_select(x_shape, self.in_channels,
+                                                                                      self.out_channels)
+        weight = self.weight if x_dim == 2 else weight_broadcast_to(self.weight)
+        x = matmul(x, weight)
         if self.has_bias:
-            x = self.bias_add(x, self.bias)
+            bias = self.bias if x_dim == 2 else bias_broadcast_to(self.bias)
+            x = bias_add(x, bias)
         if self.activation_flag:
             x = self.activation(x)
         return x

     def extend_repr(self):
         s = 'input_channels={}, output_channels={}'.format(self.in_channels, self.out_channels)
         if self.has_bias:
             s += ', has_bias={}'.format(self.has_bias)
         if self.activation_flag:
-            s += ', activation={}'.fomat(self.activation)
+            s += ', activation={}'.format(self.activation)
         return s
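With construct now picking its primitives per input shape, nn.Dense accepts inputs with more than two dimensions and only projects the last axis. A minimal usage sketch, assuming the patched layer; the shapes mirror the ND tests added further below:

import numpy as np
import mindspore.nn as nn
from mindspore import Tensor

dense = nn.Dense(2048, 1001)
x = Tensor(np.random.randn(2, 32, 2048).astype(np.float32))
y = dense(x)
print(y.shape)   # expected (2, 32, 1001): leading dimensions are preserved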
@@ -21,10 +21,34 @@ from mindspore import Tensor
 from mindspore.nn import TrainOneStepCell, WithLossCell
 from mindspore.nn.optim import Momentum
 from mindspore.ops import operations as P
+from mindspore.common.initializer import initializer
+from mindspore.common.parameter import Parameter

 context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")


+class DenseDevice(nn.Cell):
+    def __init__(self,
+                 in_channels,
+                 out_channels,
+                 device_target='Ascend',
+                 weight_init='normal',
+                 bias_init='zeros'):
+        super(DenseDevice, self).__init__()
+        self.device_target = device_target
+        self.weight = Parameter(initializer(weight_init, [out_channels, in_channels]), name="weight")
+        self.bias = Parameter(initializer(bias_init, [out_channels]), name="bias")
+        self.bias_add = P.BiasAdd()
+        self.matmul = P.MatMul(transpose_b=True)
+        self.matmul.add_prim_attr("primitive_target", self.device_target)
+        self.bias_add.add_prim_attr("primitive_target", self.device_target)
+
+    def construct(self, x):
+        x = self.matmul(x, self.weight)
+        x = self.bias_add(x, self.bias)
+        return x
+
+
 class LeNet(nn.Cell):
     def __init__(self):
         super(LeNet, self).__init__()
@@ -35,15 +59,9 @@ class LeNet(nn.Cell):
         self.conv2 = nn.Conv2d(6, 16, kernel_size=5, stride=1, padding=0, has_bias=False, pad_mode='valid')
         self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
         self.reshape = P.Reshape()
-        self.fc1 = nn.Dense(400, 120)
-        self.fc1.matmul.add_prim_attr("primitive_target", "CPU")
-        self.fc1.bias_add.add_prim_attr("primitive_target", "CPU")
-        self.fc2 = nn.Dense(120, 84)
-        self.fc2.matmul.add_prim_attr("primitive_target", "CPU")
-        self.fc2.bias_add.add_prim_attr("primitive_target", "CPU")
-        self.fc3 = nn.Dense(84, 10)
-        self.fc3.matmul.add_prim_attr("primitive_target", "CPU")
-        self.fc3.bias_add.add_prim_attr("primitive_target", "CPU")
+        self.fc1 = DenseDevice(400, 120, device_target='CPU')
+        self.fc2 = DenseDevice(120, 84, device_target='CPU')
+        self.fc3 = DenseDevice(84, 10, device_target='CPU')

     def construct(self, input_x):
         output = self.conv1(input_x)
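Because the patched Dense builds its matmul and bias_add locally inside construct, the old trick of reaching into self.fc1.matmul / self.fc1.bias_add to set primitive_target no longer has attributes to modify; the test therefore wraps a fixed 2-D dense in DenseDevice and pins its two primitives at construction time. A hedged usage sketch of that wrapper:

import numpy as np
from mindspore import Tensor

# primitive_target pins only these two ops to the CPU; everything else keeps
# the global Ascend target set by context.set_context above. The batch size
# here is illustrative.
fc = DenseDevice(400, 120, device_target='CPU')
out = fc(Tensor(np.random.randn(8, 400).astype(np.float32)))
print(out.shape)   # expected (8, 120)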
@@ -38,3 +38,11 @@ def test_net():
     output = net(Tensor(x))
     print(x)
     print(output.asnumpy())
+
+
+def test_net_ND():
+    x = np.random.randn(2, 332, 2048).astype(np.float32)
+    net = Net()
+    output = net(Tensor(x))
+    print(x)
+    print(output.asnumpy())
@@ -49,3 +49,10 @@ def test_net():
     net = Grad(Net())
     output = net(Tensor(x), Tensor(sens))
     print(output.asnumpy())
+
+def test_net_ND():
+    x = np.random.randn(2, 32, 2048).astype(np.float32)
+    sens = np.random.randn(2, 32, 1001).astype(np.float32)
+    net = Grad(Net())
+    output = net(Tensor(x), Tensor(sens))
+    print(output.asnumpy())
@@ -128,6 +128,47 @@ def test_dx():
     assert np.all(-diff < error)


+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_dx_ND():
+    x = np.array([[[0.1, 0.2, 0.3, 0.4],
+                   [0.1, 0.2, 0.3, 0.4],
+                   [0.1, 0.2, 0.3, 0.4]],
+                  [[0.1, 0.2, 0.3, 0.4],
+                   [0.1, 0.2, 0.3, 0.4],
+                   [0.1, 0.2, 0.3, 0.4]]
+                  ]).astype(np.float32)
+    dy = np.array([[[1, 1],
+                    [1, 1],
+                    [1, 1]],
+                   [[1, 1],
+                    [1, 1],
+                    [1, 1]]]).astype(np.float32)
+    dx_expect = np.array([[[1.1, 1.8, 1.1, 1.1],
+                           [1.1, 1.8, 1.1, 1.1],
+                           [1.1, 1.8, 1.1, 1.1]],
+                          [[1.1, 1.8, 1.1, 1.1],
+                           [1.1, 1.8, 1.1, 1.1],
+                           [1.1, 1.8, 1.1, 1.1]]
+                          ]).astype(np.float32)
+    error = np.ones(shape=[2, 3, 4]) * 1.0e-6
+
+    context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
+    net = GradData(DenseNet())
+    dx = net(Tensor(x), Tensor(dy))
+    diff = dx[0].asnumpy() - dx_expect
+    assert np.all(diff < error)
+    assert np.all(-diff < error)
+
+    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
+    net = GradData(DenseNet())
+    dx = net(Tensor(x), Tensor(dy))
+    diff = dx[0].asnumpy() - dx_expect
+    assert np.all(diff < error)
+    assert np.all(-diff < error)
+
+
 @pytest.mark.level0
 @pytest.mark.platform_x86_gpu_training
 @pytest.mark.env_onecard
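For reference, the expected input gradient follows from dx = dy @ W for a layer computing x @ W.T: with dy filled with ones, every row of dx is the column sum of DenseNet's weight, which is why every row of dx_expect carries the same values. A shape-only NumPy sketch (w is a placeholder; DenseNet's actual weight is defined elsewhere in this file and is not shown in this diff):

import numpy as np

dy = np.ones((2, 3, 2), dtype=np.float32)
w = np.zeros((2, 4), dtype=np.float32)   # hypothetical (out_channels, in_channels)
dx = dy @ w                              # backward of x @ w.T with respect to x
print(dx.shape)                          # (2, 3, 4), the shape dx_expect asserts against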
@@ -165,6 +206,49 @@ def test_dw():
     assert np.all(-diff < db_error)


+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_dw_ND():
+    x = np.array([[[0.1, 0.2, 0.3, 0.4],
+                   [0.1, 0.2, 0.3, 0.4],
+                   [0.1, 0.2, 0.3, 0.4]],
+                  [[0.1, 0.2, 0.3, 0.4],
+                   [0.1, 0.2, 0.3, 0.4],
+                   [0.1, 0.2, 0.3, 0.4]]]).astype(np.float32)
+    dy = np.array([[[1, 1],
+                    [1, 1],
+                    [1, 1]],
+                   [[1, 1],
+                    [1, 1],
+                    [1, 1]]]).astype(np.float32)
+    dw_expect = 2 * np.array([[0.3, 0.6, 0.9, 1.2],
+                              [0.3, 0.6, 0.9, 1.2]]).astype(np.float32)
+    dw_error = np.ones(shape=[2, 4]) * 1.0e-6
+    db_expect = 2 * np.array([3, 3]).astype(np.float32)
+    db_error = np.ones(shape=[2]) * 1.0e-6
+
+    context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
+    net = GradWeight(DenseNet())
+    dw, db = net(Tensor(x), Tensor(dy))
+    diff = dw.asnumpy() - dw_expect
+    assert np.all(diff < dw_error)
+    assert np.all(-diff < dw_error)
+    diff = db.asnumpy() - db_expect
+    assert np.all(diff < db_error)
+    assert np.all(-diff < db_error)
+
+    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
+    net = GradWeight(DenseNet())
+    dw, db = net(Tensor(x), Tensor(dy))
+    diff = dw.asnumpy() - dw_expect
+    assert np.all(diff < dw_error)
+    assert np.all(-diff < dw_error)
+    diff = db.asnumpy() - db_expect
+    assert np.all(diff < db_error)
+    assert np.all(-diff < db_error)
+
+
 class Grad(nn.Cell):
     def __init__(self, network):
         super(Grad, self).__init__()
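The expected weight and bias gradients in test_dw_ND can be cross-checked by hand: for a dense layer the weight gradient sums dy^T @ x over all leading dimensions and the bias gradient sums dy over them, so with dy filled with ones both reduce to 2 batches times 3 rows of the same input row. A quick NumPy check using the same constants as the test:

import numpy as np

x = np.tile(np.array([0.1, 0.2, 0.3, 0.4], np.float32), (2, 3, 1))   # same values as the test input
dy = np.ones((2, 3, 2), dtype=np.float32)

dw = np.einsum('bmo,bmi->oi', dy, x)   # 2 * [[0.3, 0.6, 0.9, 1.2], [0.3, 0.6, 0.9, 1.2]]
db = dy.sum(axis=(0, 1))               # 2 * [3, 3]
print(dw, db)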