from __future__ import absolute_import from .Node import Op import numpy as np from .._base import DNNL_LIB from ..cpu_links import concat as cpu_concat from ..cpu_links import concat_gradient as cpu_concat_gradient from ..gpu_links import concat from ..gpu_links import concat_gradient class ConcatOp(Op): def __init__(self, node_A, node_B, axis=0, ctx=None): super().__init__(ConcatOp, [node_A, node_B], ctx) self.axis = axis def compute(self, input_vals, output_val, stream_handle=None): if self.on_cpu: if DNNL_LIB['DnnlConcat']: cpu_concat(input_vals[0], input_vals[1], output_val, self.axis) else: output_val[:] = np.concatenate( (input_vals[0].asnumpy(), input_vals[1].asnumpy()), self.axis) else: concat(input_vals[0], input_vals[1], output_val, self.axis, stream_handle) def gradient(self, output_grad): return [concat_gradient_op(output_grad, self.inputs[0], self.axis, idx=0, ctx=self.raw_ctx), concat_gradient_op(output_grad, self.inputs[1], self.axis, idx=1, ctx=self.raw_ctx)] def infer_shape(self, input_shapes): assert len(input_shapes) == 2 assert len(input_shapes[0]) == len(input_shapes[1]) for i in range(self.axis): assert input_shapes[0][i] == input_shapes[1][i] for i in range(self.axis+1, len(input_shapes[0])): assert input_shapes[0][i] == input_shapes[1][i] out_shape = list(input_shapes[0]) out_shape[self.axis] = out_shape[self.axis] + \ input_shapes[1][self.axis] return tuple(out_shape) class Concat_gradientOP(Op): def __init__(self, grad_node, input_node, axis, idx, ctx=None): super().__init__(Concat_gradientOP, [grad_node, input_node], ctx) self.axis = axis self.idx = idx def compute(self, input_vals, output_val, stream_handle=None): if self.on_cpu: if 'cpu_Concat_Gradient' in DNNL_LIB: cpu_concat_gradient( input_vals[0], output_val, self.axis, self.idx) else: output_val[:] = concat_backward( input_vals[0].asnumpy(), self.idx, self.axis) else: concat_gradient(input_vals[0], output_val, self.axis, self.idx, stream_handle) def gradient(self, output_grad): raise NotImplementedError def infer_shape(self, input_shapes): assert len(input_shapes) == 2 return input_shapes[1] def concat_op(node_A, node_B, axis=0, ctx=None): """Concatenates given variables along an axis. Parameters: ---- node_A : Node The first node to be concated. node_B : Node The second node to be concated. axis : The axis along which two nodes are concated. Returns: ---- A new Node instance created by Op. """ return ConcatOp(node_A, node_B, axis, ctx=ctx) def concat_gradient_op(grad_node, input_node, axis, idx, ctx=None): """Gradient node of concat operation. Parameters: ---- grad_node : Node Previous gradient node. input_node : Node axis : Axis along which to be concatenated. idx : The index of concatenation. Returns: ---- A new Node instance created by Op. """ return Concat_gradientOP(grad_node, input_node, axis, idx, ctx=ctx) def concat_backward(grad, idx, axis=0): if axis == 0: gradient_x1 = grad[:idx] gradient_x2 = grad[idx:] elif axis == 1: gradient_x1 = grad[:, :idx] gradient_x2 = grad[:, idx:] elif axis == 2: gradient_x1 = grad[:, :, :idx] gradient_x2 = grad[:, :, idx:] else: gradient_x1 = grad[:, :, :, :idx] gradient_x2 = grad[:, :, :, idx:] return [gradient_x1, gradient_x2]