Browse Source

!2834 Enable Split in the bprop of Concat

Merge pull request !2834 from gziyan/add_uniform_split_for_concat
tags/v0.6.0-beta
mindspore-ci-bot Gitee 5 years ago
parent
commit
fc24096baf
1 changed files with 22 additions and 4 deletions
  1. +22
    -4
      mindspore/ops/_grad/grad_array_ops.py

+ 22
- 4
mindspore/ops/_grad/grad_array_ops.py View File

@@ -220,19 +220,37 @@ def get_bprop_transpose(self):
return bprop


@constexpr
def _concat_grad_uniform(input_shapes, input_nums):
"""Helper function for bprop of Concat"""
is_uniform = True
for i in range(1, input_nums):
if input_shapes[i-1] != input_shapes[i]:
is_uniform = False
break
return is_uniform

@bprop_getters.register(P.Concat)
def get_bprop_concat(self):
"""Generate bprop for Concat"""
axis = self.axis
is_ascend = context.get_context('device_target') == "Ascend"

def bprop(x, out, dout):
dx = ()
out_offset = G.ConcatOffset(F.tuple_len(x), axis)(x)
for i in range(F.tuple_len(x)):
slice_out = P.Slice()(dout, out_offset[i], shape_op(x[i]))
dx = dx + (slice_out,)
input_nums = F.tuple_len(x)
input_shapes = ()
for i in range(input_nums):
input_shapes = input_shapes + (shape_op(x[i]),)
is_uniform = _concat_grad_uniform(input_shapes, input_nums)
if is_uniform and is_ascend:
dx = P.Split(axis, input_nums)(dout)
else:
for i in range(input_nums):
slice_out = P.Slice()(dout, out_offset[i], input_shapes[i])
dx = dx + (slice_out,)
return (dx,)

return bprop




Loading…
Cancel
Save