You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

MatrixMult.py 5.6 kB

4 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130
  1. from __future__ import absolute_import
  2. import numpy as np
  3. from .Node import Op
  4. from .._base import DNNL_LIB
  5. from ..gpu_links import matrix_multiply
  6. from ..cpu_links import matrix_multiply as cpu_matrix_multiply
  7. class MatMulOp(Op):
  8. def __init__(self, node_A, node_B, trans_A=False, trans_B=False, ctx=None):
  9. super().__init__(MatMulOp, [node_A, node_B], ctx)
  10. self.matmul_attr_trans_A = trans_A
  11. self.matmul_attr_trans_B = trans_B
  12. def compute(self, input_vals, output_val, stream_handle=None):
  13. if self.on_cpu:
  14. if DNNL_LIB['DnnlMatrixMultiply']:
  15. cpu_matrix_multiply(
  16. input_vals[0], self.matmul_attr_trans_A,
  17. input_vals[1], self.matmul_attr_trans_B,
  18. output_val)
  19. else:
  20. input_vals = [n.asnumpy() for n in input_vals]
  21. if ((self.matmul_attr_trans_A is False) and
  22. (self.matmul_attr_trans_B is False)):
  23. output_val[:] = np.matmul(input_vals[0], input_vals[1])
  24. elif ((self.matmul_attr_trans_A is True) and
  25. (self.matmul_attr_trans_B is False)):
  26. output_val[:] = np.matmul(
  27. np.transpose(input_vals[0]), input_vals[1])
  28. elif ((self.matmul_attr_trans_A is False) and
  29. (self.matmul_attr_trans_B is True)):
  30. output_val[:] = np.matmul(
  31. input_vals[0], np.transpose(input_vals[1]))
  32. elif ((self.matmul_attr_trans_A is True) and
  33. (self.matmul_attr_trans_B is True)):
  34. output_val[:] = np.matmul(
  35. np.transpose(input_vals[0]), np.transpose(input_vals[1]))
  36. else:
  37. matrix_multiply(
  38. input_vals[0], self.matmul_attr_trans_A,
  39. input_vals[1], self.matmul_attr_trans_B,
  40. output_val, stream_handle)
  41. def gradient(self, output_grad):
  42. if ((self.matmul_attr_trans_A is False) and
  43. (self.matmul_attr_trans_B is False)):
  44. # if Y=AB, then dA=dY B^T, dB=A^T dY
  45. lhs_grad = matmul_op(
  46. output_grad, self.inputs[1], trans_A=False, trans_B=True, ctx=self.raw_ctx)
  47. rhs_grad = matmul_op(
  48. self.inputs[0], output_grad, trans_A=True, trans_B=False, ctx=self.raw_ctx)
  49. elif ((self.matmul_attr_trans_A is True) and
  50. (self.matmul_attr_trans_B is False)):
  51. # if Y=A^T B, then dA=(dY B^T)^T=B dY^T, dB=A dY
  52. lhs_grad = matmul_op(
  53. self.inputs[1], output_grad, trans_A=False, trans_B=True, ctx=self.raw_ctx)
  54. rhs_grad = matmul_op(
  55. self.inputs[0], output_grad, trans_A=False, trans_B=False, ctx=self.raw_ctx)
  56. elif ((self.matmul_attr_trans_A is False) and
  57. (self.matmul_attr_trans_B is True)):
  58. # if Y=A B^T, then dA=dY B, dB=(A^T dY)^T=dY^T A
  59. lhs_grad = matmul_op(
  60. output_grad, self.inputs[1], trans_A=False, trans_B=False, ctx=self.raw_ctx)
  61. rhs_grad = matmul_op(
  62. output_grad, self.inputs[0], trans_A=True, trans_B=False, ctx=self.raw_ctx)
  63. elif ((self.matmul_attr_trans_A is True) and
  64. (self.matmul_attr_trans_B is True)):
  65. # if Y=A^T B^T, then dA=(dY B)^T=B^T dY^T, dB=(A dY)^T=dY^T A^T
  66. lhs_grad = matmul_op(
  67. self.inputs[1], output_grad, trans_A=True, trans_B=True, ctx=self.raw_ctx)
  68. rhs_grad = matmul_op(
  69. output_grad, self.inputs[0], trans_A=True, trans_B=True, ctx=self.raw_ctx)
  70. return [lhs_grad, rhs_grad]
  71. def infer_shape(self, input_shapes):
  72. assert len(input_shapes) == 2
  73. A = input_shapes[0]
  74. B = input_shapes[1]
  75. shape_A = A[0]
  76. shape_B = B[1]
  77. if self.matmul_attr_trans_A == True:
  78. shape_A = A[1]
  79. if self.matmul_attr_trans_B == True:
  80. shape_B = B[0]
  81. return (shape_A, shape_B)
  82. def deduce_states(self, states, duplicates):
  83. def revert(x):
  84. return (x[1], x[0])
  85. def gcd(x, y):
  86. return y if x % y == 0 else gcd(y, x % y)
  87. if states[0] is None and states[1] is None:
  88. return None, min(duplicates)
  89. if states[0] is None:
  90. states[0] = (1, 1)
  91. if states[1] is None:
  92. states[1] = (1, 1)
  93. assert len(states[0]) == 2 and len(states[1]) == 2
  94. assert np.prod(states[0]) * \
  95. duplicates[0] == np.prod(states[1]) * duplicates[1]
  96. if self.matmul_attr_trans_A:
  97. states[0] = revert(states[0])
  98. if self.matmul_attr_trans_B:
  99. states[1] = revert(states[1])
  100. assert states[0][1] == states[1][0], 'Partition number of left matrix column shoule match that of right matrix row.'
  101. return (states[0][0], states[1][1]), gcd(max(duplicates), min(duplicates)) * states[0][1]
  102. def matmul_op(node_A, node_B, trans_A=False, trans_B=False, ctx=None):
  103. """Make a new instance of Matrix Multiplication and call the instance.
  104. Parameters:
  105. ----
  106. node_A : Node
  107. The left operand of the matrix multiplication.
  108. node_B : Node
  109. The right operand of the matrix multiplication.
  110. trans_A : Boolean
  111. Whether node_A to be transposed
  112. trans_B : Boolean
  113. Whether node_B to be transposed
  114. Returns:
  115. ----
  116. A new Node instance created by Op.
  117. """
  118. return MatMulOp(node_A, node_B, trans_A, trans_B, ctx=ctx)