from __future__ import absolute_import import numpy as np from .Node import Op from .._base import DNNL_LIB from ..cpu_links import transpose as cpu_transpose from ..gpu_links import matrix_transpose_simple from .. import ndarray class TransposeOp(Op): def __init__(self, node_A, perm=None, ctx=None): super().__init__(TransposeOp, [node_A], ctx) self.perm = perm def compute(self, input_vals, output_val, stream_handle=None): if self.on_cpu: if DNNL_LIB['cpu_Transpose']: cpu_transpose(input_vals[0], output_val, self.perm) else: output_val[:] = np.transpose( input_vals[0].asnumpy(), self.perm) else: # matrix_transpose(input_vals[0], output_val, self.perm, stream_handle) matrix_transpose_simple( input_vals[0], output_val, self.gpu_buffer, stream_handle) def gradient(self, output_grad): if self.perm: grad_perm = [0 for _ in self.perm] for i in range(len(self.perm)): grad_perm[self.perm[i]] = i else: grad_perm = None return [transpose_op(output_grad, grad_perm, ctx=self.raw_ctx)] def infer_shape(self, input_shapes): assert len(input_shapes) == 1 # only support matrix transpose # assert len(input_shapes[0]) == 2 ori_shape = list(input_shapes[0]) if self.perm is None: self.perm = list(range(len(ori_shape))[::-1]) res_shape = ori_shape[::-1] else: assert len(self.perm) == len(ori_shape) and set( self.perm) == set(range(len(self.perm))) res_shape = [ori_shape[self.perm[i]] for i in range(len(ori_shape))] # here we save the information for GPU computation if self.on_gpu: ndim = len(ori_shape) buffer = [0 for _ in range(3 * ndim)] in_stride = 1 out_stride = 1 for i in range(ndim - 1, -1, -1): buffer[i] = in_stride buffer[ndim + i] = out_stride buffer[2 * ndim + i] = self.perm[i] in_stride *= ori_shape[i] out_stride *= res_shape[i] self.gpu_buffer = ndarray.array( buffer, self.ctx, data_type=np.uintc) return res_shape def transpose_op(node_A, perm=None, ctx=None): """Make a new instance of transpose and call the instance. Parameters: ---- node_A : Node Node to be transposed. Returns: ---- A new Node instance created by Op. """ return TransposeOp(node_A, perm, ctx=ctx)