You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

ReduceSumAxisZero.py 1.6 kB

4 years ago
1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253
  1. from __future__ import absolute_import
  2. import numpy as np
  3. from .Node import Op
  4. from .._base import DNNL_LIB
  5. from ..cpu_links import reduce_sum_axis_zero as cpu_reduce_sum_axis_zero
  6. from ..gpu_links import reduce_sum_axis_zero
  7. class ReduceSumAxisZeroOp(Op):
  8. def __init__(self, node_A, ctx=None):
  9. super().__init__(ReduceSumAxisZeroOp, [node_A], ctx)
  10. def compute(self, input_vals, output_val, stream_handle=None):
  11. if self.on_cpu:
  12. if DNNL_LIB['cpu_ReduceSumAxisZero']:
  13. cpu_reduce_sum_axis_zero(input_vals[0], output_val)
  14. else:
  15. output_val[:] = np.sum(input_vals[0].asnumpy(), axis=0)
  16. else:
  17. reduce_sum_axis_zero(input_vals[0], output_val, stream_handle)
  18. def gradient(self, output_grad):
  19. from .Broadcast import broadcastto_op
  20. return [broadcastto_op(output_grad, self.inputs[0], ctx=self.raw_ctx)]
  21. def infer_shape(self, input_shapes):
  22. """summation reduction axis = 0
  23. e.g. (3,4,5)->(4,5)
  24. for vector, simpler to do (3,)->(1,)
  25. """
  26. assert len(input_shapes) == 1
  27. input_shape = input_shapes[0]
  28. if len(input_shape) == 1:
  29. return (1,)
  30. else:
  31. return input_shape[1:]
  32. def reducesumaxiszero_op(node, ctx=None):
  33. """Creates a node that represents np.sum(node_A, axis=0).
  34. Parameters:
  35. ----
  36. node : Node
  37. The Node needed to be summed.
  38. Returns:
  39. ----
  40. A new Node instance created by Op.
  41. """
  42. return ReduceSumAxisZeroOp(node, ctx=ctx)