You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; can include dashes ('-'); and can be up to 35 characters long.

CuSparse.py 7.5 kB

4 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188
  1. from __future__ import absolute_import
  2. import numpy as np
  3. import scipy.sparse
  4. from .Node import Op
  5. from .. import ndarray
  6. from .Transpose import transpose_op
  7. from ..gpu_links import CuSparse_Csrmv
  8. from ..gpu_links import CuSparse_Csrmm
class CsrmvOp(Op):
    """Sparse-matrix / dense-vector product node: y = A x (or A^T x).

    ``node_A`` is the sparse CSR matrix operand, ``node_B`` the dense
    vector operand.
    """

    def __init__(self, node_A, node_B, trans=False, ctx=None):
        super().__init__(CsrmvOp, [node_A, node_B], ctx)
        # Whether the sparse operand A is transposed before the product.
        self.csrmv_attr_trans = trans

    def compute(self, input_vals, output_val, stream_handle=None):
        """Evaluate the product into ``output_val``.

        CPU path: scipy.sparse ``dot`` on the dense vector pulled out via
        ``asnumpy()``.  GPU path: dispatch to the cuSPARSE csrmv kernel.
        """
        if self.on_cpu:
            assert isinstance(input_vals[0], scipy.sparse.spmatrix)
            if self.csrmv_attr_trans is False:
                output_val[:] = input_vals[0].dot(input_vals[1].asnumpy())
            else:
                output_val[:] = input_vals[0].T.dot(input_vals[1].asnumpy())
        else:
            assert isinstance(input_vals[0], ndarray.ND_Sparse_Array)
            CuSparse_Csrmv(
                input_vals[0], self.csrmv_attr_trans,
                input_vals[1], output_val, stream_handle)

    # ND_Sparse_Array gradient not implemented
    def gradient(self, output_grad):
        """Return gradients ``[dA, dB]``.

        ``dA`` is always None (gradient w.r.t. the sparse operand is not
        implemented — see note above); ``dB`` is another csrmv with the
        transpose flag flipped relative to the forward pass.
        """
        if self.csrmv_attr_trans is False:
            # if Y=AB, then dA=dY B^T, dB=A^T dY
            # lhs_grad = matmul_op(
            #     output_grad, self.inputs[1], trans_A=False, trans_B=True)
            rhs_grad = csrmv_op(
                self.inputs[0], output_grad, trans=True, ctx=self.raw_ctx)
        else:
            # if Y=A^T B, then dA=(dY B^T)^T=B dY^T, dB=A dY
            # lhs_grad = matmul_op(
            #     self.inputs[1], output_grad, trans_A=False, trans_B=True)
            rhs_grad = csrmv_op(
                self.inputs[0], output_grad, trans=False, ctx=self.raw_ctx)
        return [None, rhs_grad]

    def infer_shape(self, input_shapes):
        """Infer the output shape ``(rows_of_op(A),)``.

        A must be 2-D and B 1-D; the contracted dimensions must agree.
        When ``trans`` is set, A's dimensions swap roles.
        """
        assert len(input_shapes) == 2
        A = input_shapes[0]
        B = input_shapes[1]
        assert len(A) == 2 and len(B) == 1
        shape_A = A[0]
        shape_mid_1 = A[1]
        shape_mid_2 = B[0]
        if self.csrmv_attr_trans == True:
            shape_A = A[1]
            shape_mid_1 = A[0]
        assert shape_mid_1 == shape_mid_2
        return (shape_A, )
class CsrmmOp(Op):
    """Sparse-matrix / dense-matrix product node: Y = op(A) op(B).

    ``node_A`` is a sparse CSR matrix, ``node_B`` a dense matrix; each
    operand may independently be transposed via ``trans_A`` / ``trans_B``.
    """

    def __init__(self, node_A, node_B, trans_A=False, trans_B=False, ctx=None):
        super().__init__(CsrmmOp, [node_A, node_B], ctx)
        # Transpose flags for the sparse (A) and dense (B) operands.
        self.csrmm_attr_trans_A = trans_A
        self.csrmm_attr_trans_B = trans_B

    def compute(self, input_vals, output_val, stream_handle=None):
        """Evaluate op(A).dot(op(B)) into ``output_val``.

        CPU path: scipy.sparse ``dot``, with one branch per (trans_A,
        trans_B) combination.  GPU path: dispatch to the cuSPARSE csrmm
        kernel, which takes both flags directly.
        """
        if self.on_cpu:
            assert isinstance(input_vals[0], scipy.sparse.spmatrix)
            if ((self.csrmm_attr_trans_A is False) and
                    (self.csrmm_attr_trans_B is False)):
                output_val[:] = input_vals[0].dot(input_vals[1].asnumpy())
            elif ((self.csrmm_attr_trans_A is True) and
                    (self.csrmm_attr_trans_B is False)):
                output_val[:] = input_vals[0].T.dot(input_vals[1].asnumpy())
            elif ((self.csrmm_attr_trans_A is False) and
                    (self.csrmm_attr_trans_B is True)):
                output_val[:] = input_vals[0].dot(
                    np.transpose(input_vals[1].asnumpy()))
            elif ((self.csrmm_attr_trans_A is True) and
                    (self.csrmm_attr_trans_B is True)):
                output_val[:] = input_vals[0].T.dot(
                    np.transpose(input_vals[1].asnumpy()))
        else:
            assert isinstance(input_vals[0], ndarray.ND_Sparse_Array)
            CuSparse_Csrmm(
                input_vals[0], self.csrmm_attr_trans_A,
                input_vals[1], self.csrmm_attr_trans_B,
                output_val, stream_handle)

    # ND_Sparse_Array gradient not implemented
    def gradient(self, output_grad):
        """Return gradients ``[dA, dB]``.

        ``dA`` is always None (sparse gradient not implemented — see note
        above).  ``dB`` is rewritten so the sparse operand stays on the
        left; where the math would need an unsupported flag combination,
        the result is produced by a csrmm followed by an explicit
        ``transpose_op`` instead.
        """
        if ((self.csrmm_attr_trans_A is False) and
                (self.csrmm_attr_trans_B is False)):
            # if Y=AB, then dA=dY B^T, dB=A^T dY
            # lhs_grad = matmul_op(
            #     output_grad, self.inputs[1], trans_A=False, trans_B=True)
            # Notice: cuSparse not support left trans right not trans
            rhs_grad = csrmm_op(
                self.inputs[0], output_grad, trans_A=True, trans_B=False, ctx=self.raw_ctx)
        elif ((self.csrmm_attr_trans_A is True) and
                (self.csrmm_attr_trans_B is False)):
            # if Y=A^T B, then dA=(dY B^T)^T=B dY^T, dB=A dY
            # lhs_grad = matmul_op(
            #     self.inputs[1], output_grad, trans_A=False, trans_B=True)
            rhs_grad = csrmm_op(
                self.inputs[0], output_grad, trans_A=False, trans_B=False, ctx=self.raw_ctx)
        elif ((self.csrmm_attr_trans_A is False) and
                (self.csrmm_attr_trans_B is True)):
            # if Y=A B^T, then dA=dY B, dB=(A^T dY)^T=dY^T A
            # lhs_grad = matmul_op(
            #     output_grad, self.inputs[1], trans_A=False, trans_B=False)
            # rhs_grad = matmul_op(
            #     output_grad, self.inputs[0], trans_A=True, trans_B=False)
            # Notice: cuSparse not support left trans right not trans
            rhs_grad = transpose_op(csrmm_op(
                self.inputs[0], output_grad, trans_A=True, trans_B=False, ctx=self.raw_ctx))
        elif ((self.csrmm_attr_trans_A is True) and
                (self.csrmm_attr_trans_B is True)):
            # if Y=A^T B^T, then dA=(dY B)^T=B^T dY^T, dB=(A dY)^T=dY^T A^T
            # lhs_grad = matmul_op(
            #     self.inputs[1], output_grad, trans_A=True, trans_B=True)
            # rhs_grad = matmul_op(
            #     output_grad, self.inputs[0], trans_A=True, trans_B=True)
            rhs_grad = transpose_op(csrmm_op(
                self.inputs[0], output_grad, trans_A=False, trans_B=False, ctx=self.raw_ctx))
        # return [lhs_grad, rhs_grad]
        return [None, rhs_grad]

    def infer_shape(self, input_shapes):
        """Infer the output shape ``(rows_of_op(A), cols_of_op(B))``.

        Both inputs must be 2-D; the contracted (inner) dimensions must
        agree after applying the transpose flags.
        """
        assert len(input_shapes) == 2
        A = input_shapes[0]
        B = input_shapes[1]
        assert len(A) == 2 and len(B) == 2
        shape_A = A[0]
        shape_B = B[1]
        shape_mid_1 = A[1]
        shape_mid_2 = B[0]
        if self.csrmm_attr_trans_A == True:
            shape_A = A[1]
            shape_mid_1 = A[0]
        if self.csrmm_attr_trans_B == True:
            shape_B = B[0]
            shape_mid_2 = B[1]
        assert shape_mid_1 == shape_mid_2
        return (shape_A, shape_B)
  136. def csrmv_op(node_A, node_B, trans=False, ctx=None):
  137. """Make a new instance of multiplication of a sparse matrix and a vector,
  138. and call the instance.
  139. Parameters:
  140. ----
  141. node_A : Node
  142. The left operand, a sparse matrix.
  143. node_B : Node
  144. The right operand, a vector.
  145. trans : Boolean
  146. Whether node_A to be transposed, default to be False.
  147. Returns:
  148. ----
  149. A new Node instance created by Op.
  150. """
  151. return CsrmvOp(node_A, node_B, trans, ctx=ctx)
  152. def csrmm_op(node_A, node_B, trans_A=False, trans_B=False, ctx=None):
  153. """Make a new instance of Sparse Matrix Multiplication and call the instance.
  154. Parameters:
  155. ----
  156. node_A : Node
  157. The left operand, a sparse matrix.
  158. node_B : Node
  159. The right operand, a dense matrix.
  160. trans_A : Boolean
  161. Whether node_A to be transposed, default to be False.
  162. trans_B : Boolean
  163. Whether node_B to be transposed, default to be False.
  164. Returns:
  165. ----
  166. A new Node instance created by Op.
  167. """
  168. return CsrmmOp(node_A, node_B, trans_A, trans_B, ctx=ctx)