You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

MultiplyElewise.py 3.3 kB

4 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081
  1. from __future__ import absolute_import
  2. from .Node import Op
  3. from .._base import DNNL_LIB
  4. from ..cpu_links import matrix_elementwise_multiply as\
  5. cpu_matrix_elementwise_multiply
  6. from ..cpu_links import matrix_elementwise_multiply_by_const as\
  7. cpu_matrix_elementwise_multiply_by_const
  8. from ..gpu_links import matrix_elementwise_multiply,\
  9. matrix_elementwise_multiply_by_const
  10. class MulOp(Op):
  11. def __init__(self, node_A, node_B, ctx=None):
  12. super().__init__(MulOp, [node_A, node_B], ctx)
  13. def compute(self, input_vals, output_val, stream_handle=None):
  14. if self.on_cpu:
  15. if DNNL_LIB['DnnlMatrixElementwiseMultiply'] and input_vals[0].shape == input_vals[1].shape:
  16. cpu_matrix_elementwise_multiply(
  17. input_vals[0], input_vals[1], output_val)
  18. elif DNNL_LIB['DnnlMatrixElementwiseMultiplyByConst'] and (input_vals[0].shape == (1,) or input_vals[1].shape == (1,)):
  19. if input_vals[1].shape == (1,):
  20. const_val = input_vals[1].asnumpy()[0]
  21. cpu_matrix_elementwise_multiply_by_const(
  22. input_vals[0], const_val, output_val)
  23. elif input_vals[0].shape == (1,):
  24. const_val = input_vals[0].asnumpy()[0]
  25. cpu_matrix_elementwise_multiply_by_const(
  26. input_vals[1], const_val, output_val)
  27. else:
  28. output_val[:] = input_vals[0].asnumpy() * \
  29. input_vals[1].asnumpy()
  30. else:
  31. if input_vals[0].shape == input_vals[1].shape:
  32. matrix_elementwise_multiply(
  33. input_vals[0], input_vals[1], output_val, stream_handle)
  34. else:
  35. if input_vals[1].shape == (1,):
  36. const_val = input_vals[1].asnumpy()[0]
  37. matrix_elementwise_multiply_by_const(
  38. input_vals[0], const_val, output_val, stream_handle)
  39. elif input_vals[0].shape == (1,):
  40. const_val = input_vals[0].asnumpy()[0]
  41. matrix_elementwise_multiply_by_const(
  42. input_vals[1], const_val, output_val, stream_handle)
  43. def gradient(self, output_grad):
  44. return [mul_op(self.inputs[1], output_grad, ctx=self.raw_ctx),
  45. mul_op(self.inputs[0], output_grad, ctx=self.raw_ctx)]
  46. def infer_shape(self, input_shapes):
  47. """Need to handle input_vals[0].shape != input_vals[1].shape"""
  48. assert len(input_shapes) == 2
  49. if input_shapes[0] == input_shapes[1]:
  50. output = input_shapes[0]
  51. else:
  52. if input_shapes[0] == (1,):
  53. output = input_shapes[1]
  54. elif input_shapes[1] == (1,):
  55. output = input_shapes[0]
  56. else:
  57. assert False, "can't do elementwise multiply between variables of different sizes."
  58. return output
  59. def mul_op(node_A, node_B, ctx=None):
  60. """Make a new instance of matrixs elementwise multiplication and call the instance.
  61. Parameters:
  62. ----
  63. node_a : Node
  64. The Node to be multiplied.
  65. node_b : Node
  66. Another Node to be multiplied.
  67. Returns:
  68. ----
  69. A new Node instance created by Op.
  70. """
  71. return MulOp(node_A, node_B, ctx=ctx)