|
- from __future__ import absolute_import
- import numpy as np
- import scipy.sparse
- from .Node import Op
- from .. import ndarray
- from .Transpose import transpose_op
- from ..gpu_links import CuSparse_Csrmv
- from ..gpu_links import CuSparse_Csrmm
-
-
- class CsrmvOp(Op):
- def __init__(self, node_A, node_B, trans=False, ctx=None):
- super().__init__(CsrmvOp, [node_A, node_B], ctx)
- self.csrmv_attr_trans = trans
-
- def compute(self, input_vals, output_val, stream_handle=None):
- if self.on_cpu:
- assert isinstance(input_vals[0], scipy.sparse.spmatrix)
- if self.csrmv_attr_trans is False:
- output_val[:] = input_vals[0].dot(input_vals[1].asnumpy())
- else:
- output_val[:] = input_vals[0].T.dot(input_vals[1].asnumpy())
- else:
- assert isinstance(input_vals[0], ndarray.ND_Sparse_Array)
- CuSparse_Csrmv(
- input_vals[0], self.csrmv_attr_trans,
- input_vals[1], output_val, stream_handle)
-
- # ND_Sparse_Array gradient not implemented
- def gradient(self, output_grad):
- if self.csrmv_attr_trans is False:
- # if Y=AB, then dA=dY B^T, dB=A^T dY
- # lhs_grad = matmul_op(
- # output_grad, self.inputs[1], trans_A=False, trans_B=True)
- rhs_grad = csrmv_op(
- self.inputs[0], output_grad, trans=True, ctx=self.raw_ctx)
- else:
- # if Y=A^T B, then dA=(dY B^T)^T=B dY^T, dB=A dY
- # lhs_grad = matmul_op(
- # self.inputs[1], output_grad, trans_A=False, trans_B=True)
- rhs_grad = csrmv_op(
- self.inputs[0], output_grad, trans=False, ctx=self.raw_ctx)
- return [None, rhs_grad]
-
- def infer_shape(self, input_shapes):
- assert len(input_shapes) == 2
- A = input_shapes[0]
- B = input_shapes[1]
- assert len(A) == 2 and len(B) == 1
- shape_A = A[0]
- shape_mid_1 = A[1]
- shape_mid_2 = B[0]
- if self.csrmv_attr_trans == True:
- shape_A = A[1]
- shape_mid_1 = A[0]
- assert shape_mid_1 == shape_mid_2
- return (shape_A, )
-
-
- class CsrmmOp(Op):
- def __init__(self, node_A, node_B, trans_A=False, trans_B=False, ctx=None):
- super().__init__(CsrmmOp, [node_A, node_B], ctx)
- self.csrmm_attr_trans_A = trans_A
- self.csrmm_attr_trans_B = trans_B
-
- def compute(self, input_vals, output_val, stream_handle=None):
- if self.on_cpu:
- assert isinstance(input_vals[0], scipy.sparse.spmatrix)
- if ((self.csrmm_attr_trans_A is False) and
- (self.csrmm_attr_trans_B is False)):
- output_val[:] = input_vals[0].dot(input_vals[1].asnumpy())
- elif ((self.csrmm_attr_trans_A is True) and
- (self.csrmm_attr_trans_B is False)):
- output_val[:] = input_vals[0].T.dot(input_vals[1].asnumpy())
- elif ((self.csrmm_attr_trans_A is False) and
- (self.csrmm_attr_trans_B is True)):
- output_val[:] = input_vals[0].dot(
- np.transpose(input_vals[1].asnumpy()))
- elif ((self.csrmm_attr_trans_A is True) and
- (self.csrmm_attr_trans_B is True)):
- output_val[:] = input_vals[0].T.dot(
- np.transpose(input_vals[1].asnumpy()))
- else:
- assert isinstance(input_vals[0], ndarray.ND_Sparse_Array)
- CuSparse_Csrmm(
- input_vals[0], self.csrmm_attr_trans_A,
- input_vals[1], self.csrmm_attr_trans_B,
- output_val, stream_handle)
-
- # ND_Sparse_Array gradient not implemented
- def gradient(self, output_grad):
- if ((self.csrmm_attr_trans_A is False) and
- (self.csrmm_attr_trans_B is False)):
- # if Y=AB, then dA=dY B^T, dB=A^T dY
- # lhs_grad = matmul_op(
- # output_grad, self.inputs[1], trans_A=False, trans_B=True)
- # Notice: cuSparse not support left trans right not trans
- rhs_grad = csrmm_op(
- self.inputs[0], output_grad, trans_A=True, trans_B=False, ctx=self.raw_ctx)
- elif ((self.csrmm_attr_trans_A is True) and
- (self.csrmm_attr_trans_B is False)):
- # if Y=A^T B, then dA=(dY B^T)^T=B dY^T, dB=A dY
- # lhs_grad = matmul_op(
- # self.inputs[1], output_grad, trans_A=False, trans_B=True)
- rhs_grad = csrmm_op(
- self.inputs[0], output_grad, trans_A=False, trans_B=False, ctx=self.raw_ctx)
- elif ((self.csrmm_attr_trans_A is False) and
- (self.csrmm_attr_trans_B is True)):
- # if Y=A B^T, then dA=dY B, dB=(A^T dY)^T=dY^T A
- # lhs_grad = matmul_op(
- # output_grad, self.inputs[1], trans_A=False, trans_B=False)
- # rhs_grad = matmul_op(
- # output_grad, self.inputs[0], trans_A=True, trans_B=False)
- # Notice: cuSparse not support left trans right not trans
- rhs_grad = transpose_op(csrmm_op(
- self.inputs[0], output_grad, trans_A=True, trans_B=False, ctx=self.raw_ctx))
- elif ((self.csrmm_attr_trans_A is True) and
- (self.csrmm_attr_trans_B is True)):
- # if Y=A^T B^T, then dA=(dY B)^T=B^T dY^T, dB=(A dY)^T=dY^T A^T
- # lhs_grad = matmul_op(
- # self.inputs[1], output_grad, trans_A=True, trans_B=True)
- # rhs_grad = matmul_op(
- # output_grad, self.inputs[0], trans_A=True, trans_B=True)
- rhs_grad = transpose_op(csrmm_op(
- self.inputs[0], output_grad, trans_A=False, trans_B=False, ctx=self.raw_ctx))
- # return [lhs_grad, rhs_grad]
- return [None, rhs_grad]
-
- def infer_shape(self, input_shapes):
- assert len(input_shapes) == 2
- A = input_shapes[0]
- B = input_shapes[1]
- assert len(A) == 2 and len(B) == 2
- shape_A = A[0]
- shape_B = B[1]
- shape_mid_1 = A[1]
- shape_mid_2 = B[0]
- if self.csrmm_attr_trans_A == True:
- shape_A = A[1]
- shape_mid_1 = A[0]
- if self.csrmm_attr_trans_B == True:
- shape_B = B[0]
- shape_mid_2 = B[1]
- assert shape_mid_1 == shape_mid_2
- return (shape_A, shape_B)
-
-
- def csrmv_op(node_A, node_B, trans=False, ctx=None):
- """Make a new instance of multiplication of a sparse matrix and a vector,
- and call the instance.
-
- Parameters:
- ----
- node_A : Node
- The left operand, a sparse matrix.
- node_B : Node
- The right operand, a vector.
- trans : Boolean
- Whether node_A to be transposed, default to be False.
-
- Returns:
- ----
- A new Node instance created by Op.
-
- """
- return CsrmvOp(node_A, node_B, trans, ctx=ctx)
-
-
- def csrmm_op(node_A, node_B, trans_A=False, trans_B=False, ctx=None):
- """Make a new instance of Sparse Matrix Multiplication and call the instance.
-
- Parameters:
- ----
- node_A : Node
- The left operand, a sparse matrix.
- node_B : Node
- The right operand, a dense matrix.
- trans_A : Boolean
- Whether node_A to be transposed, default to be False.
- trans_B : Boolean
- Whether node_B to be transposed, default to be False.
-
- Returns:
- ----
- A new Node instance created by Op.
-
- """
- return CsrmmOp(node_A, node_B, trans_A, trans_B, ctx=ctx)
|