|
|
|
@@ -220,19 +220,37 @@ def get_bprop_transpose(self): |
|
|
|
return bprop |
|
|
|
|
|
|
|
|
|
|
|
@constexpr |
|
|
|
def _concat_grad_uniform(input_shapes, input_nums): |
|
|
|
"""Helper function for bprop of Concat""" |
|
|
|
is_uniform = True |
|
|
|
for i in range(1, input_nums): |
|
|
|
if input_shapes[i-1] != input_shapes[i]: |
|
|
|
is_uniform = False |
|
|
|
break |
|
|
|
return is_uniform |
|
|
|
|
|
|
|
@bprop_getters.register(P.Concat) |
|
|
|
def get_bprop_concat(self): |
|
|
|
"""Generate bprop for Concat""" |
|
|
|
axis = self.axis |
|
|
|
is_ascend = context.get_context('device_target') == "Ascend" |
|
|
|
|
|
|
|
def bprop(x, out, dout): |
|
|
|
dx = () |
|
|
|
out_offset = G.ConcatOffset(F.tuple_len(x), axis)(x) |
|
|
|
for i in range(F.tuple_len(x)): |
|
|
|
slice_out = P.Slice()(dout, out_offset[i], shape_op(x[i])) |
|
|
|
dx = dx + (slice_out,) |
|
|
|
input_nums = F.tuple_len(x) |
|
|
|
input_shapes = () |
|
|
|
for i in range(input_nums): |
|
|
|
input_shapes = input_shapes + (shape_op(x[i]),) |
|
|
|
is_uniform = _concat_grad_uniform(input_shapes, input_nums) |
|
|
|
if is_uniform and is_ascend: |
|
|
|
dx = P.Split(axis, input_nums)(dout) |
|
|
|
else: |
|
|
|
for i in range(input_nums): |
|
|
|
slice_out = P.Slice()(dout, out_offset[i], input_shapes[i]) |
|
|
|
dx = dx + (slice_out,) |
|
|
|
return (dx,) |
|
|
|
|
|
|
|
return bprop |
|
|
|
|
|
|
|
|
|
|
|
|