from __future__ import absolute_import import numpy as np from .Node import Op from .._base import DNNL_LIB from ..gpu_links import matrix_multiply from ..cpu_links import matrix_multiply as cpu_matrix_multiply class MatMulOp(Op): def __init__(self, node_A, node_B, trans_A=False, trans_B=False, ctx=None): super().__init__(MatMulOp, [node_A, node_B], ctx) self.matmul_attr_trans_A = trans_A self.matmul_attr_trans_B = trans_B def compute(self, input_vals, output_val, stream_handle=None): if self.on_cpu: if DNNL_LIB['DnnlMatrixMultiply']: cpu_matrix_multiply( input_vals[0], self.matmul_attr_trans_A, input_vals[1], self.matmul_attr_trans_B, output_val) else: input_vals = [n.asnumpy() for n in input_vals] if ((self.matmul_attr_trans_A is False) and (self.matmul_attr_trans_B is False)): output_val[:] = np.matmul(input_vals[0], input_vals[1]) elif ((self.matmul_attr_trans_A is True) and (self.matmul_attr_trans_B is False)): output_val[:] = np.matmul( np.transpose(input_vals[0]), input_vals[1]) elif ((self.matmul_attr_trans_A is False) and (self.matmul_attr_trans_B is True)): output_val[:] = np.matmul( input_vals[0], np.transpose(input_vals[1])) elif ((self.matmul_attr_trans_A is True) and (self.matmul_attr_trans_B is True)): output_val[:] = np.matmul( np.transpose(input_vals[0]), np.transpose(input_vals[1])) else: matrix_multiply( input_vals[0], self.matmul_attr_trans_A, input_vals[1], self.matmul_attr_trans_B, output_val, stream_handle) def gradient(self, output_grad): if ((self.matmul_attr_trans_A is False) and (self.matmul_attr_trans_B is False)): # if Y=AB, then dA=dY B^T, dB=A^T dY lhs_grad = matmul_op( output_grad, self.inputs[1], trans_A=False, trans_B=True, ctx=self.raw_ctx) rhs_grad = matmul_op( self.inputs[0], output_grad, trans_A=True, trans_B=False, ctx=self.raw_ctx) elif ((self.matmul_attr_trans_A is True) and (self.matmul_attr_trans_B is False)): # if Y=A^T B, then dA=(dY B^T)^T=B dY^T, dB=A dY lhs_grad = matmul_op( self.inputs[1], output_grad, trans_A=False, trans_B=True, ctx=self.raw_ctx) rhs_grad = matmul_op( self.inputs[0], output_grad, trans_A=False, trans_B=False, ctx=self.raw_ctx) elif ((self.matmul_attr_trans_A is False) and (self.matmul_attr_trans_B is True)): # if Y=A B^T, then dA=dY B, dB=(A^T dY)^T=dY^T A lhs_grad = matmul_op( output_grad, self.inputs[1], trans_A=False, trans_B=False, ctx=self.raw_ctx) rhs_grad = matmul_op( output_grad, self.inputs[0], trans_A=True, trans_B=False, ctx=self.raw_ctx) elif ((self.matmul_attr_trans_A is True) and (self.matmul_attr_trans_B is True)): # if Y=A^T B^T, then dA=(dY B)^T=B^T dY^T, dB=(A dY)^T=dY^T A^T lhs_grad = matmul_op( self.inputs[1], output_grad, trans_A=True, trans_B=True, ctx=self.raw_ctx) rhs_grad = matmul_op( output_grad, self.inputs[0], trans_A=True, trans_B=True, ctx=self.raw_ctx) return [lhs_grad, rhs_grad] def infer_shape(self, input_shapes): assert len(input_shapes) == 2 A = input_shapes[0] B = input_shapes[1] shape_A = A[0] shape_B = B[1] if self.matmul_attr_trans_A == True: shape_A = A[1] if self.matmul_attr_trans_B == True: shape_B = B[0] return (shape_A, shape_B) def deduce_states(self, states, duplicates): def revert(x): return (x[1], x[0]) def gcd(x, y): return y if x % y == 0 else gcd(y, x % y) if states[0] is None and states[1] is None: return None, min(duplicates) if states[0] is None: states[0] = (1, 1) if states[1] is None: states[1] = (1, 1) assert len(states[0]) == 2 and len(states[1]) == 2 assert np.prod(states[0]) * \ duplicates[0] == np.prod(states[1]) * duplicates[1] if self.matmul_attr_trans_A: states[0] = revert(states[0]) if self.matmul_attr_trans_B: states[1] = revert(states[1]) assert states[0][1] == states[1][0], 'Partition number of left matrix column shoule match that of right matrix row.' return (states[0][0], states[1][1]), gcd(max(duplicates), min(duplicates)) * states[0][1] def matmul_op(node_A, node_B, trans_A=False, trans_B=False, ctx=None): """Make a new instance of Matrix Multiplication and call the instance. Parameters: ---- node_A : Node The left operand of the matrix multiplication. node_B : Node The right operand of the matrix multiplication. trans_A : Boolean Whether node_A to be transposed trans_B : Boolean Whether node_B to be transposed Returns: ---- A new Node instance created by Op. """ return MatMulOp(node_A, node_B, trans_A, trans_B, ctx=ctx)