You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

BatchMatrixMult.py 4.8 kB

4 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114
  1. from __future__ import absolute_import
  2. import numpy as np
  3. from .Node import Op
  4. from ..gpu_links import batch_matrix_multiply
  5. class BatchMatMulOp(Op):
  6. def __init__(self, node_A, node_B, trans_A=False, trans_B=False, ctx=None):
  7. super().__init__(BatchMatMulOp, [node_A, node_B], ctx)
  8. self.matmul_attr_trans_A = trans_A
  9. self.matmul_attr_trans_B = trans_B
  10. def compute(self, input_vals, output_val, stream_handle=None):
  11. if self.on_cpu:
  12. ndims = len(input_vals[0])
  13. perm = list(range(ndims-2)) + [ndims-1, ndims-2]
  14. if ((self.matmul_attr_trans_A is False) and
  15. (self.matmul_attr_trans_B is False)):
  16. output_val[:] = np.matmul(
  17. input_vals[0].asnumpy(), input_vals[1].asnumpy())
  18. elif ((self.matmul_attr_trans_A is True) and
  19. (self.matmul_attr_trans_B is False)):
  20. output_val[:] = np.matmul(
  21. np.transpose(input_vals[0].asnumpy(), perm), input_vals[1].asnumpy())
  22. elif ((self.matmul_attr_trans_A is False) and
  23. (self.matmul_attr_trans_B is True)):
  24. output_val[:] = np.matmul(
  25. input_vals[0].asnumpy(), np.transpose(input_vals[1].asnumpy(), perm))
  26. elif ((self.matmul_attr_trans_A is True) and
  27. (self.matmul_attr_trans_B is True)):
  28. output_val[:] = np.matmul(
  29. np.transpose(input_vals[0].asnumpy(), perm), np.transpose(input_vals[1].asnumpy(), perm))
  30. else:
  31. batch_matrix_multiply(
  32. input_vals[0], self.matmul_attr_trans_A,
  33. input_vals[1], self.matmul_attr_trans_B,
  34. output_val, stream_handle)
  35. def gradient(self, output_grad):
  36. if ((self.matmul_attr_trans_A is False) and
  37. (self.matmul_attr_trans_B is False)):
  38. # if Y=AB, then dA=dY B^T, dB=A^T dY
  39. lhs_grad = batch_matmul_op(
  40. output_grad, self.inputs[1], trans_A=False, trans_B=True, ctx=self.raw_ctx)
  41. rhs_grad = batch_matmul_op(
  42. self.inputs[0], output_grad, trans_A=True, trans_B=False, ctx=self.raw_ctx)
  43. elif ((self.matmul_attr_trans_A is True) and
  44. (self.matmul_attr_trans_B is False)):
  45. # if Y=A^T B, then dA=(dY B^T)^T=B dY^T, dB=A dY
  46. lhs_grad = batch_matmul_op(
  47. self.inputs[1], output_grad, trans_A=False, trans_B=True, ctx=self.raw_ctx)
  48. rhs_grad = batch_matmul_op(
  49. self.inputs[0], output_grad, trans_A=False, trans_B=False, ctx=self.raw_ctx)
  50. elif ((self.matmul_attr_trans_A is False) and
  51. (self.matmul_attr_trans_B is True)):
  52. # if Y=A B^T, then dA=dY B, dB=(A^T dY)^T=dY^T A
  53. lhs_grad = batch_matmul_op(
  54. output_grad, self.inputs[1], trans_A=False, trans_B=False, ctx=self.raw_ctx)
  55. rhs_grad = batch_matmul_op(
  56. output_grad, self.inputs[0], trans_A=True, trans_B=False, ctx=self.raw_ctx)
  57. elif ((self.matmul_attr_trans_A is True) and
  58. (self.matmul_attr_trans_B is True)):
  59. # if Y=A^T B^T, then dA=(dY B)^T=B^T dY^T, dB=(A dY)^T=dY^T A^T
  60. lhs_grad = batch_matmul_op(
  61. self.inputs[1], output_grad, trans_A=True, trans_B=True, ctx=self.raw_ctx)
  62. rhs_grad = batch_matmul_op(
  63. output_grad, self.inputs[0], trans_A=True, trans_B=True, ctx=self.raw_ctx)
  64. return [lhs_grad, rhs_grad]
  65. def infer_shape(self, input_shapes):
  66. assert len(input_shapes) == 2
  67. A = input_shapes[0]
  68. B = input_shapes[1]
  69. assert len(A) == len(B)
  70. assert len(A) >= 2
  71. for i in range(len(A)-2):
  72. assert A[i] == B[i]
  73. C = list(A)[:-2]
  74. shape_A = A[-2]
  75. shape_B = B[-1]
  76. k1 = A[-1]
  77. k2 = B[-2]
  78. if self.matmul_attr_trans_A == True:
  79. shape_A = A[-1]
  80. k1 = A[-2]
  81. if self.matmul_attr_trans_B == True:
  82. shape_B = B[-2]
  83. k2 = B[-1]
  84. assert k1 == k2
  85. C.extend([shape_A, shape_B])
  86. return tuple(C)
  87. def batch_matmul_op(node_A, node_B, trans_A=False, trans_B=False, ctx=None):
  88. """Make a new instance of Batch Matrix Multiplication and call the instance.
  89. Parameters:
  90. ----
  91. node_A : Node
  92. The left operand of the matrix multiplication.
  93. node_B : Node
  94. The right operand of the matrix multiplication.
  95. trans_A : Boolean
  96. Whether node_A to be transposed
  97. trans_B : Boolean
  98. Whether node_B to be transposed
  99. Returns:
  100. ----
  101. A new Node instance created by Op.
  102. """
  103. return BatchMatMulOp(node_A, node_B, trans_A, trans_B, ctx=ctx)