|
- from __future__ import absolute_import
- from .Node import Op
- import numpy as np
- from .MultiplyElewise import mul_op
- from .ReduceSum import reduce_sum_op
- from ..gpu_links import matrix_dot
-
- # TODO: This op may have bugs and is not complete!
- # Use other ops to replace it
-
-
class MatrixDotOp(Op):
    """Graph op that combines two input nodes with a "matrix dot" product.

    On the CPU path the result is ``np.tensordot(A, B, axes=self.axes)``.
    On the non-CPU path the ``matrix_dot`` GPU link is called instead; it is
    not given ``axes``.

    NOTE(review): the module-level TODO marks this op as incomplete and
    possibly buggy; whether the CPU and GPU paths compute the same thing is
    not verifiable from this file -- confirm before relying on it.
    """

    def __init__(self, node_A, node_B, axes=0, ctx=None):
        """Register the two operand nodes and remember ``axes``.

        Parameters:
        ----
        node_A : Node
            First operand.
        node_B : Node
            Second operand.
        axes : int or tuple
            Forwarded to ``np.tensordot`` on the CPU path.
        ctx :
            Context handed to the base ``Op`` constructor.
        """
        operands = [node_A, node_B]
        super().__init__(MatrixDotOp, operands, ctx)
        self.axes = axes

    def compute(self, input_vals, output_val, stream_handle=None):
        """Write the product of the two inputs into ``output_val`` in place."""
        lhs, rhs = input_vals[0], input_vals[1]
        if not self.on_cpu:
            # Device path: the GPU link fills output_val itself.
            matrix_dot(lhs, rhs, output_val, stream_handle)
            return
        # Host path: numpy computes the contraction, copied into the buffer.
        output_val[:] = np.tensordot(lhs, rhs, axes=self.axes)

    def gradient(self, output_grad):
        """Build the gradient nodes for both inputs.

        Returns:
        ----
        A two-element list: the gradient node for the first input followed
        by the gradient node for the second input.

        NOTE(review): mathematical correctness of these expressions has not
        been confirmed (see the module-level TODO).
        """
        ctx = self.raw_ctx
        # Nodes are created in the same order as before: matrix-dot first,
        # then the elementwise product, then its reduction.
        grad_first = matrix_dot_op(output_grad, self.inputs[1], axes=1, ctx=ctx)
        weighted = mul_op(self.inputs[0], output_grad, ctx=ctx)
        grad_second = reduce_sum_op(weighted, axes=1, keepdims=True, ctx=ctx)
        return [grad_first, grad_second]

    def infer_shape(self, input_shapes):
        """Return the shape of the first input as the output shape.

        TODO: the case where the two input shapes differ still needs to be
        handled; the second shape is currently ignored.
        """
        return input_shapes[0]
-
-
def matrix_dot_op(node_A, node_B, axes=0, ctx=None):
    """Create a ``MatrixDotOp`` node combining ``node_A`` and ``node_B``.

    Parameters:
    ----
    node_A : Node
        The first Node to be multiplied.
    node_B : Node
        The second Node to be multiplied.
    axes : int or tuple
        Axes specification stored on the op; ``MatrixDotOp.compute`` passes
        it to ``np.tensordot`` on the CPU path. Defaults to 0.
    ctx :
        Context forwarded to the op constructor.

    Returns:
    ----
    A new Node instance created by Op.

    """
    # Forward ``axes`` explicitly: it used to be dropped here, so every op
    # was built with the default axes=0 regardless of what the caller asked.
    return MatrixDotOp(node_A, node_B, axes=axes, ctx=ctx)
|