You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

Transpose.py 2.7 kB

4 years ago
1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980
  1. from __future__ import absolute_import
  2. import numpy as np
  3. from .Node import Op
  4. from .._base import DNNL_LIB
  5. from ..cpu_links import transpose as cpu_transpose
  6. from ..gpu_links import matrix_transpose_simple
  7. from .. import ndarray
  8. class TransposeOp(Op):
  9. def __init__(self, node_A, perm=None, ctx=None):
  10. super().__init__(TransposeOp, [node_A], ctx)
  11. self.perm = perm
  12. def compute(self, input_vals, output_val, stream_handle=None):
  13. if self.on_cpu:
  14. if DNNL_LIB['cpu_Transpose']:
  15. cpu_transpose(input_vals[0], output_val, self.perm)
  16. else:
  17. output_val[:] = np.transpose(
  18. input_vals[0].asnumpy(), self.perm)
  19. else:
  20. # matrix_transpose(input_vals[0], output_val, self.perm, stream_handle)
  21. matrix_transpose_simple(
  22. input_vals[0], output_val, self.gpu_buffer, stream_handle)
  23. def gradient(self, output_grad):
  24. if self.perm:
  25. grad_perm = [0 for _ in self.perm]
  26. for i in range(len(self.perm)):
  27. grad_perm[self.perm[i]] = i
  28. else:
  29. grad_perm = None
  30. return [transpose_op(output_grad, grad_perm, ctx=self.raw_ctx)]
  31. def infer_shape(self, input_shapes):
  32. assert len(input_shapes) == 1
  33. # only support matrix transpose
  34. # assert len(input_shapes[0]) == 2
  35. ori_shape = list(input_shapes[0])
  36. if self.perm is None:
  37. self.perm = list(range(len(ori_shape))[::-1])
  38. res_shape = ori_shape[::-1]
  39. else:
  40. assert len(self.perm) == len(ori_shape) and set(
  41. self.perm) == set(range(len(self.perm)))
  42. res_shape = [ori_shape[self.perm[i]]
  43. for i in range(len(ori_shape))]
  44. # here we save the information for GPU computation
  45. if self.on_gpu:
  46. ndim = len(ori_shape)
  47. buffer = [0 for _ in range(3 * ndim)]
  48. in_stride = 1
  49. out_stride = 1
  50. for i in range(ndim - 1, -1, -1):
  51. buffer[i] = in_stride
  52. buffer[ndim + i] = out_stride
  53. buffer[2 * ndim + i] = self.perm[i]
  54. in_stride *= ori_shape[i]
  55. out_stride *= res_shape[i]
  56. self.gpu_buffer = ndarray.array(
  57. buffer, self.ctx, data_type=np.uintc)
  58. return res_shape
  59. def transpose_op(node_A, perm=None, ctx=None):
  60. """Make a new instance of transpose and call the instance.
  61. Parameters:
  62. ----
  63. node_A : Node
  64. Node to be transposed.
  65. Returns:
  66. ----
  67. A new Node instance created by Op.
  68. """
  69. return TransposeOp(node_A, perm, ctx=ctx)