Browse Source

Fix concate problem

tags/v1.3.0
l00591931 4 years ago
parent
commit
d25e3497f7
4 changed files with 85 additions and 10 deletions
  1. +19
    -8
      mindspore/ops/_grad/grad_array_ops.py
  2. +1
    -1
      mindspore/ops/bprop_mindir/Identity_bprop.mindir
  3. +1
    -1
      mindspore/ops/bprop_mindir/ReLU_bprop.mindir
  4. +64
    -0
      tests/st/ops/cpu/test_concat_op.py

+ 19
- 8
mindspore/ops/_grad/grad_array_ops.py View File

@@ -323,19 +323,30 @@ def get_bprop_concat(self):
axis = self.axis

def bprop(x, out, dout):
dx = ()
out_offset = G.ConcatOffset(F.tuple_len(x), axis)(x)
input_nums = F.tuple_len(x)
out_offset = G.ConcatOffset(len(x), axis)(x)
input_nums = len(x)
input_shapes = ()
for i in range(input_nums):
input_shapes = input_shapes + (shape_op(x[i]),)
is_uniform = _concat_grad_uniform(input_shapes, input_nums)
if is_uniform:
dx = P.Split(axis, input_nums)(dout)
if isinstance(x, list):
dx = []
if is_uniform:
dx_tuple = P.Split(axis, input_nums)(dout)
for _, i in enumerate(dx_tuple):
dx = dx + [i,]
else:
for i in range(input_nums):
slice_out = P.Slice()(dout, out_offset[i], input_shapes[i])
dx = dx + [slice_out,]
else:
for i in range(input_nums):
slice_out = P.Slice()(dout, out_offset[i], input_shapes[i])
dx = dx + (slice_out,)
dx = ()
if is_uniform:
dx = P.Split(axis, input_nums)(dout)
else:
for i in range(input_nums):
slice_out = P.Slice()(dout, out_offset[i], input_shapes[i])
dx = dx + (slice_out,)
return (dx,)

return bprop


+ 1
- 1
mindspore/ops/bprop_mindir/Identity_bprop.mindir View File

@@ -6,4 +6,4 @@
bprop.10:x*
bprop.10:out*
bprop.10:dout2
bprop.10:[CNode]12:2:€027af68f320ba40d9fbd0893da424c07f9c3a4ec82e98f9543bff9b5a15547a2102a58399653345b09bd6f5b337c4b81c4f8900664c0abc09fb80f38f8e95be82366f7bd59ea5ec135e982de03b4f7cab6b61d833d046a6e13f78bdaf2fb2b22248b4695c64d61a01e33ef3ba7144b288e54122debb351a5ac8f55d8914329584c332efad4a51b4773cb78093dd53a4ca850b2dc6cdd5f2ae47106b3fda77bb365c0e00bc893ef15ec6199798d6c8c46997153587d375b3240c1195ff2c7278c7e635a08323207b4cb3f73fd8437b4d7ee28a7676a68f005a7749bd19e5ed4eca6c407ad6a3b57190d3702d6a45031d13b97bb6952735edf94fb36f73dbff6cdab258748286fc6d783abacce203dfc79d2fc31e23a427ce1f86e08777a687f71c0606bdbf14ec1b2b2d86ab82b5eb2ac71f1d3d0ba743f7cee45a1d9a0a2d82ac414b8c313aac4f85c6217fbbb7009dd079b2d5548f8b695a470a11cb8cc83e6f5e78f5b3c67f2e7bf339b250c3638aee952e1a073002e2834011401f3827260
bprop.10:[CNode]12:2:€027af68f320ba40d9fbd0893da424c07f9c3a4ec82e98f9543bff9b5a15547a2102a58399653345b09bd6f5b337c4b81c4f8900664c0abc09fb80f38f8e95be82366f7bd59ea5ec135e982de03b4f7cab6b61d833d046a6e13f78bdaf2fb2b224c332efad4a51b4773cb78093dd53a4ca850b2dc6cdd5f2ae47106b3fda77bb365c0e00bc893ef15ec6199798d6c8c46997153587d375b3240c1195ff2c7278c7e635a08323207b4cb3f73fd8437b4d7ee28a7676a68f005a7749bd19e5ed4eca0593a639478ea8dfad17fdbe39f66855cc459eb58bcaf5eac44185e03b16374a6c407ad6a3b57190d3702d6a45031d13b97bb6952735edf94fb36f73dbff6cdab258748286fc6d783abacce203dfc79d2fc31e23a427ce1f86e08777a687f71c0606bdbf14ec1b2b2d86ab82b5eb2ac71f1d3d0ba743f7cee45a1d9a0a2d82ac414b8c313aac4f85c6217fbbb7009dd079b2d5548f8b695a470a11cb8cc83e6f5e78f5b3c67f2e7bf339b250c3638aee952e1a073002e2834011401f3827260

+ 1
- 1
mindspore/ops/bprop_mindir/ReLU_bprop.mindir View File

@@ -8,4 +8,4 @@
bprop.2:x*
bprop.2:out*
bprop.2:dout2
bprop.2:[CNode]4:3:€027af68f320ba40d9fbd0893da424c07f9c3a4ec82e98f9543bff9b5a15547a2102a58399653345b09bd6f5b337c4b81c4f8900664c0abc09fb80f38f8e95be82366f7bd59ea5ec135e982de03b4f7cab6b61d833d046a6e13f78bdaf2fb2b22248b4695c64d61a01e33ef3ba7144b288e54122debb351a5ac8f55d8914329584c332efad4a51b4773cb78093dd53a4ca850b2dc6cdd5f2ae47106b3fda77bb365c0e00bc893ef15ec6199798d6c8c46997153587d375b3240c1195ff2c7278c7e635a08323207b4cb3f73fd8437b4d7ee28a7676a68f005a7749bd19e5ed4eca6c407ad6a3b57190d3702d6a45031d13b97bb6952735edf94fb36f73dbff6cdab258748286fc6d783abacce203dfc79d2fc31e23a427ce1f86e08777a687f71c0606bdbf14ec1b2b2d86ab82b5eb2ac71f1d3d0ba743f7cee45a1d9a0a2d82ac414b8c313aac4f85c6217fbbb7009dd079b2d5548f8b695a470a11cb8cc83e6f5e78f5b3c67f2e7bf339b250c3638aee952e1a073002e2834011401f3827260
bprop.2:[CNode]4:3:€027af68f320ba40d9fbd0893da424c07f9c3a4ec82e98f9543bff9b5a15547a2102a58399653345b09bd6f5b337c4b81c4f8900664c0abc09fb80f38f8e95be82366f7bd59ea5ec135e982de03b4f7cab6b61d833d046a6e13f78bdaf2fb2b224c332efad4a51b4773cb78093dd53a4ca850b2dc6cdd5f2ae47106b3fda77bb365c0e00bc893ef15ec6199798d6c8c46997153587d375b3240c1195ff2c7278c7e635a08323207b4cb3f73fd8437b4d7ee28a7676a68f005a7749bd19e5ed4eca0593a639478ea8dfad17fdbe39f66855cc459eb58bcaf5eac44185e03b16374a6c407ad6a3b57190d3702d6a45031d13b97bb6952735edf94fb36f73dbff6cdab258748286fc6d783abacce203dfc79d2fc31e23a427ce1f86e08777a687f71c0606bdbf14ec1b2b2d86ab82b5eb2ac71f1d3d0ba743f7cee45a1d9a0a2d82ac414b8c313aac4f85c6217fbbb7009dd079b2d5548f8b695a470a11cb8cc83e6f5e78f5b3c67f2e7bf339b250c3638aee952e1a073002e2834011401f3827260

+ 64
- 0
tests/st/ops/cpu/test_concat_op.py View File

@@ -18,6 +18,7 @@ import numpy as np
from mindspore import Tensor
from mindspore.ops import operations as P
import mindspore.nn as nn
import mindspore.ops as ops
import mindspore.context as context

context.set_context(mode=context.GRAPH_MODE, device_target='CPU')
@@ -65,6 +66,7 @@ def test_axis10_int32():
def test_axis10_bool():
axis10(np.bool)


class ConcatV32(nn.Cell):
def __init__(self, nptype):
super(ConcatV32, self).__init__()
@@ -106,6 +108,68 @@ def test_axis32_bool():
axis32(np.bool)


class ConcatWithList(nn.Cell):
def __init__(self):
super(ConcatWithList, self).__init__()
self.concat = P.Concat(axis=2)

def construct(self, x, y):
input_list = [x, y]
return self.concat(input_list)


class ConcatWithTuple(nn.Cell):
def __init__(self):
super(ConcatWithTuple, self).__init__()
self.concat = P.Concat(axis=2)

def construct(self, x, y):
input_list = (x, y)
return self.concat(input_list)


class GradConcat(nn.Cell):
def __init__(self, network):
super(GradConcat, self).__init__()
self.grad = ops.GradOperation()
self.network = network

def construct(self, x, y):
gout = self.grad(self.network)(x, y)
return gout

@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_concat_list_grad():
x1 = Tensor(np.arange(2 * 2 * 1).reshape(2, 2, 1).astype(np.float32))
x2 = Tensor(np.arange(2 * 2 * 2).reshape(2, 2, 2).astype(np.float32))
concat = ConcatWithList()
output = GradConcat(concat)(x1, x2)
expect = np.array([[[1.],
[1.]],
[[1.],
[1.]]]).astype(np.float32)
print(output)
assert (output.asnumpy() == expect).all()


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_concat_tuple_grad():
x1 = Tensor(np.arange(2 * 2 * 1).reshape(2, 2, 1).astype(np.float32))
x2 = Tensor(np.arange(2 * 2 * 2).reshape(2, 2, 2).astype(np.float32))
concat = ConcatWithTuple()
output = GradConcat(concat)(x1, x2)
expect = np.array([[[1.],
[1.]],
[[1.],
[1.]]]).astype(np.float32)
print(output)
assert (output.asnumpy() == expect).all()


class ConcatV43(nn.Cell):
def __init__(self, nptype):
super(ConcatV43, self).__init__()


Loading…
Cancel
Save