You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

MatrixDot.py 1.5 kB

4 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748
  1. from __future__ import absolute_import
  2. from .Node import Op
  3. import numpy as np
  4. from .MultiplyElewise import mul_op
  5. from .ReduceSum import reduce_sum_op
  6. from ..gpu_links import matrix_dot
  7. # TODO: This op may have bugs and is not complete!
  8. # Use other ops to replace it
  9. class MatrixDotOp(Op):
  10. def __init__(self, node_A, node_B, axes=0, ctx=None):
  11. super().__init__(MatrixDotOp, [node_A, node_B], ctx)
  12. self.axes = axes
  13. def compute(self, input_vals, output_val, stream_handle=None):
  14. if self.on_cpu:
  15. output_val[:] = np.tensordot(
  16. input_vals[0], input_vals[1], axes=self.axes)
  17. else:
  18. matrix_dot(input_vals[0], input_vals[1], output_val, stream_handle)
  19. def gradient(self, output_grad):
  20. return [matrix_dot_op(output_grad, self.inputs[1], axes=1, ctx=self.raw_ctx),
  21. reduce_sum_op(mul_op(self.inputs[0], output_grad, ctx=self.raw_ctx), axes=1, keepdims=True, ctx=self.raw_ctx)]
  22. def infer_shape(self, input_shapes):
  23. """Need to handle input_vals[0].shape != input_vals[1].shape"""
  24. return input_shapes[0]
  25. def matrix_dot_op(node_A, node_B, axes=0, ctx=None):
  26. """Make a new instance of matrixs elementwise multiplication and call the instance.
  27. Parameters:
  28. ----
  29. node_a : Node
  30. The Node to be multiplied.
  31. node_b : Node
  32. Another Node to be multiplied.
  33. Returns:
  34. ----
  35. A new Node instance created by Op.
  36. """
  37. return MatrixDotOp(node_A, node_B, ctx=ctx)