You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

AddConst.py 1.5 kB

4 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748
  1. from __future__ import absolute_import
  2. from .Node import Op
  3. from .._base import DNNL_LIB
  4. from ..cpu_links import matrix_elementwise_add_by_const as cpu_matrix_elementwise_add_by_const
  5. from ..gpu_links import matrix_elementwise_add_by_const
  6. class AddByConstOp(Op):
  7. def __init__(self, node_A, const_val, ctx=None):
  8. super().__init__(AddByConstOp, [node_A], ctx)
  9. self.const_attr = const_val
  10. self.desc = self.name + '(%s, %s)' % (node_A.name, str(const_val))
  11. def compute(self, input_vals, output_val, stream_handle=None):
  12. if self.on_cpu:
  13. if DNNL_LIB['DnnlMatrixElementwiseAddByConst']:
  14. cpu_matrix_elementwise_add_by_const(
  15. input_vals[0], self.const_attr, output_val)
  16. else:
  17. output_val[:] = input_vals[0].asnumpy() + self.const_attr
  18. else:
  19. matrix_elementwise_add_by_const(
  20. input_vals[0], self.const_attr, output_val, stream_handle)
  21. def gradient(self, output_grad):
  22. return [output_grad]
  23. def infer_shape(self, input_shapes):
  24. assert len(input_shapes) == 1
  25. return input_shapes[0]
  26. def addbyconst_op(node, const_val, ctx=None):
  27. """Make a new instance of AddByConstOp and call the instance.
  28. Parameters:
  29. ----
  30. node : Node
  31. The Node to be added.
  32. const_val : scalar value
  33. The constant value to be added.
  34. Returns:
  35. ----
  36. A new Node instance created by Op.
  37. """
  38. return AddByConstOp(node, const_val, ctx=ctx)