You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

MultiplyConst.py 1.6 kB

4 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849
  1. from __future__ import absolute_import
  2. from .Node import Op
  3. from .._base import DNNL_LIB
  4. from ..cpu_links import matrix_elementwise_multiply_by_const as cpu_matrix_elementwise_multiply_by_const
  5. from ..gpu_links import matrix_elementwise_multiply_by_const
  6. class MulByConstOp(Op):
  7. def __init__(self, node_A, const_val, ctx=None):
  8. super().__init__(MulByConstOp, [node_A], ctx)
  9. self.const_attr = const_val
  10. self.desc = self.name + '(%s, %s)' % (node_A.name, str(const_val))
  11. def compute(self, input_vals, output_val, stream_handle=None):
  12. assert self.const_attr is not None
  13. if self.on_cpu:
  14. if DNNL_LIB['DnnlMatrixElementwiseMultiplyByConst']:
  15. cpu_matrix_elementwise_multiply_by_const(
  16. input_vals[0], self.const_attr, output_val)
  17. else:
  18. output_val[:] = input_vals[0].asnumpy() * self.const_attr
  19. else:
  20. matrix_elementwise_multiply_by_const(
  21. input_vals[0], self.const_attr, output_val, stream_handle)
  22. def gradient(self, output_grad):
  23. return [self.const_attr * output_grad]
  24. def infer_shape(self, input_shapes):
  25. assert len(input_shapes) == 1
  26. return input_shapes[0]
  27. def mul_byconst_op(node_A, const_val, ctx=None):
  28. """Make a new instance of MulByConstOp and call the instance.
  29. Parameters:
  30. ----
  31. node : Node
  32. The Node to be multiplied.
  33. const_val : scalar value
  34. The constant value to be mutiplied.
  35. Returns:
  36. ----
  37. A new Node instance created by Op.
  38. """
  39. return MulByConstOp(node_A, const_val, ctx=ctx)