| @@ -323,19 +323,30 @@ def get_bprop_concat(self): | |||
| axis = self.axis | |||
| def bprop(x, out, dout): | |||
| dx = () | |||
| out_offset = G.ConcatOffset(F.tuple_len(x), axis)(x) | |||
| input_nums = F.tuple_len(x) | |||
| out_offset = G.ConcatOffset(len(x), axis)(x) | |||
| input_nums = len(x) | |||
| input_shapes = () | |||
| for i in range(input_nums): | |||
| input_shapes = input_shapes + (shape_op(x[i]),) | |||
| is_uniform = _concat_grad_uniform(input_shapes, input_nums) | |||
| if is_uniform: | |||
| dx = P.Split(axis, input_nums)(dout) | |||
| if isinstance(x, list): | |||
| dx = [] | |||
| if is_uniform: | |||
| dx_tuple = P.Split(axis, input_nums)(dout) | |||
| for _, i in enumerate(dx_tuple): | |||
| dx = dx + [i,] | |||
| else: | |||
| for i in range(input_nums): | |||
| slice_out = P.Slice()(dout, out_offset[i], input_shapes[i]) | |||
| dx = dx + [slice_out,] | |||
| else: | |||
| for i in range(input_nums): | |||
| slice_out = P.Slice()(dout, out_offset[i], input_shapes[i]) | |||
| dx = dx + (slice_out,) | |||
| dx = () | |||
| if is_uniform: | |||
| dx = P.Split(axis, input_nums)(dout) | |||
| else: | |||
| for i in range(input_nums): | |||
| slice_out = P.Slice()(dout, out_offset[i], input_shapes[i]) | |||
| dx = dx + (slice_out,) | |||
| return (dx,) | |||
| return bprop | |||
| @@ -6,4 +6,4 @@ | |||
| bprop.10:x* | |||
| bprop.10:out* | |||
| bprop.10:dout2 | |||
| bprop.10:[CNode]12:2:€027af68f320ba40d9fbd0893da424c07f9c3a4ec82e98f9543bff9b5a15547a2102a58399653345b09bd6f5b337c4b81c4f8900664c0abc09fb80f38f8e95be82366f7bd59ea5ec135e982de03b4f7cab6b61d833d046a6e13f78bdaf2fb2b22248b4695c64d61a01e33ef3ba7144b288e54122debb351a5ac8f55d8914329584c332efad4a51b4773cb78093dd53a4ca850b2dc6cdd5f2ae47106b3fda77bb365c0e00bc893ef15ec6199798d6c8c46997153587d375b3240c1195ff2c7278c7e635a08323207b4cb3f73fd8437b4d7ee28a7676a68f005a7749bd19e5ed4eca6c407ad6a3b57190d3702d6a45031d13b97bb6952735edf94fb36f73dbff6cdab258748286fc6d783abacce203dfc79d2fc31e23a427ce1f86e08777a687f71c0606bdbf14ec1b2b2d86ab82b5eb2ac71f1d3d0ba743f7cee45a1d9a0a2d82ac414b8c313aac4f85c6217fbbb7009dd079b2d5548f8b695a470a11cb8cc83e6f5e78f5b3c67f2e7bf339b250c3638aee952e1a073002e2834011401f3827260 | |||
| bprop.10:[CNode]12:2:€027af68f320ba40d9fbd0893da424c07f9c3a4ec82e98f9543bff9b5a15547a2102a58399653345b09bd6f5b337c4b81c4f8900664c0abc09fb80f38f8e95be82366f7bd59ea5ec135e982de03b4f7cab6b61d833d046a6e13f78bdaf2fb2b224c332efad4a51b4773cb78093dd53a4ca850b2dc6cdd5f2ae47106b3fda77bb365c0e00bc893ef15ec6199798d6c8c46997153587d375b3240c1195ff2c7278c7e635a08323207b4cb3f73fd8437b4d7ee28a7676a68f005a7749bd19e5ed4eca0593a639478ea8dfad17fdbe39f66855cc459eb58bcaf5eac44185e03b16374a6c407ad6a3b57190d3702d6a45031d13b97bb6952735edf94fb36f73dbff6cdab258748286fc6d783abacce203dfc79d2fc31e23a427ce1f86e08777a687f71c0606bdbf14ec1b2b2d86ab82b5eb2ac71f1d3d0ba743f7cee45a1d9a0a2d82ac414b8c313aac4f85c6217fbbb7009dd079b2d5548f8b695a470a11cb8cc83e6f5e78f5b3c67f2e7bf339b250c3638aee952e1a073002e2834011401f3827260 | |||
| @@ -8,4 +8,4 @@ | |||
| bprop.2:x* | |||
| bprop.2:out* | |||
| bprop.2:dout2 | |||
| bprop.2:[CNode]4:3:€027af68f320ba40d9fbd0893da424c07f9c3a4ec82e98f9543bff9b5a15547a2102a58399653345b09bd6f5b337c4b81c4f8900664c0abc09fb80f38f8e95be82366f7bd59ea5ec135e982de03b4f7cab6b61d833d046a6e13f78bdaf2fb2b22248b4695c64d61a01e33ef3ba7144b288e54122debb351a5ac8f55d8914329584c332efad4a51b4773cb78093dd53a4ca850b2dc6cdd5f2ae47106b3fda77bb365c0e00bc893ef15ec6199798d6c8c46997153587d375b3240c1195ff2c7278c7e635a08323207b4cb3f73fd8437b4d7ee28a7676a68f005a7749bd19e5ed4eca6c407ad6a3b57190d3702d6a45031d13b97bb6952735edf94fb36f73dbff6cdab258748286fc6d783abacce203dfc79d2fc31e23a427ce1f86e08777a687f71c0606bdbf14ec1b2b2d86ab82b5eb2ac71f1d3d0ba743f7cee45a1d9a0a2d82ac414b8c313aac4f85c6217fbbb7009dd079b2d5548f8b695a470a11cb8cc83e6f5e78f5b3c67f2e7bf339b250c3638aee952e1a073002e2834011401f3827260 | |||
| bprop.2:[CNode]4:3:€027af68f320ba40d9fbd0893da424c07f9c3a4ec82e98f9543bff9b5a15547a2102a58399653345b09bd6f5b337c4b81c4f8900664c0abc09fb80f38f8e95be82366f7bd59ea5ec135e982de03b4f7cab6b61d833d046a6e13f78bdaf2fb2b224c332efad4a51b4773cb78093dd53a4ca850b2dc6cdd5f2ae47106b3fda77bb365c0e00bc893ef15ec6199798d6c8c46997153587d375b3240c1195ff2c7278c7e635a08323207b4cb3f73fd8437b4d7ee28a7676a68f005a7749bd19e5ed4eca0593a639478ea8dfad17fdbe39f66855cc459eb58bcaf5eac44185e03b16374a6c407ad6a3b57190d3702d6a45031d13b97bb6952735edf94fb36f73dbff6cdab258748286fc6d783abacce203dfc79d2fc31e23a427ce1f86e08777a687f71c0606bdbf14ec1b2b2d86ab82b5eb2ac71f1d3d0ba743f7cee45a1d9a0a2d82ac414b8c313aac4f85c6217fbbb7009dd079b2d5548f8b695a470a11cb8cc83e6f5e78f5b3c67f2e7bf339b250c3638aee952e1a073002e2834011401f3827260 | |||
| @@ -18,6 +18,7 @@ import numpy as np | |||
| from mindspore import Tensor | |||
| from mindspore.ops import operations as P | |||
| import mindspore.nn as nn | |||
| import mindspore.ops as ops | |||
| import mindspore.context as context | |||
| context.set_context(mode=context.GRAPH_MODE, device_target='CPU') | |||
| @@ -65,6 +66,7 @@ def test_axis10_int32(): | |||
| def test_axis10_bool(): | |||
| axis10(np.bool) | |||
| class ConcatV32(nn.Cell): | |||
| def __init__(self, nptype): | |||
| super(ConcatV32, self).__init__() | |||
| @@ -106,6 +108,68 @@ def test_axis32_bool(): | |||
| axis32(np.bool) | |||
| class ConcatWithList(nn.Cell): | |||
| def __init__(self): | |||
| super(ConcatWithList, self).__init__() | |||
| self.concat = P.Concat(axis=2) | |||
| def construct(self, x, y): | |||
| input_list = [x, y] | |||
| return self.concat(input_list) | |||
| class ConcatWithTuple(nn.Cell): | |||
| def __init__(self): | |||
| super(ConcatWithTuple, self).__init__() | |||
| self.concat = P.Concat(axis=2) | |||
| def construct(self, x, y): | |||
| input_list = (x, y) | |||
| return self.concat(input_list) | |||
| class GradConcat(nn.Cell): | |||
| def __init__(self, network): | |||
| super(GradConcat, self).__init__() | |||
| self.grad = ops.GradOperation() | |||
| self.network = network | |||
| def construct(self, x, y): | |||
| gout = self.grad(self.network)(x, y) | |||
| return gout | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_cpu | |||
| @pytest.mark.env_onecard | |||
| def test_concat_list_grad(): | |||
| x1 = Tensor(np.arange(2 * 2 * 1).reshape(2, 2, 1).astype(np.float32)) | |||
| x2 = Tensor(np.arange(2 * 2 * 2).reshape(2, 2, 2).astype(np.float32)) | |||
| concat = ConcatWithList() | |||
| output = GradConcat(concat)(x1, x2) | |||
| expect = np.array([[[1.], | |||
| [1.]], | |||
| [[1.], | |||
| [1.]]]).astype(np.float32) | |||
| print(output) | |||
| assert (output.asnumpy() == expect).all() | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_cpu | |||
| @pytest.mark.env_onecard | |||
| def test_concat_tuple_grad(): | |||
| x1 = Tensor(np.arange(2 * 2 * 1).reshape(2, 2, 1).astype(np.float32)) | |||
| x2 = Tensor(np.arange(2 * 2 * 2).reshape(2, 2, 2).astype(np.float32)) | |||
| concat = ConcatWithTuple() | |||
| output = GradConcat(concat)(x1, x2) | |||
| expect = np.array([[[1.], | |||
| [1.]], | |||
| [[1.], | |||
| [1.]]]).astype(np.float32) | |||
| print(output) | |||
| assert (output.asnumpy() == expect).all() | |||
| class ConcatV43(nn.Cell): | |||
| def __init__(self, nptype): | |||
| super(ConcatV43, self).__init__() | |||