You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

Conv2dReduceSum.py 1.3 kB

4 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445
  1. from __future__ import absolute_import
  2. import numpy as np
  3. from .Node import Op
  4. from ..gpu_links import conv2d_reduce_sum
  class Conv2d_ReduceSumOp(Op):
      """Per-channel sum of a 4-D tensor.

      Axes 0, 2 and 3 are reduced and axis 1 is kept, so an input of shape
      (N, C, H, W) yields an output of shape (C,). The axis-1-is-channels
      reading follows ``infer_shape``; an NCHW layout is assumed.
      """

      def __init__(self, node_A, ctx=None):
          super().__init__(Conv2d_ReduceSumOp, [node_A], ctx)

      def compute(self, input_vals, output_val, stream_handle=None):
          """Write the per-channel sum of ``input_vals[0]`` into ``output_val``.

          On CPU the sum is computed with NumPy over axes (0, 2, 3).
          Otherwise the ``conv2d_reduce_sum`` GPU link is called with
          ``stream_handle``.
          """
          if self.on_cpu:
              output_val[:] = np.sum(input_vals[0].asnumpy(), axis=(0, 2, 3))
          else:
              conv2d_reduce_sum(input_vals[0], output_val, stream_handle)

      def gradient(self, output_grad):
          """Return a one-element list holding the gradient w.r.t. the input.

          ``output_grad`` (shape (C,)) is passed to ``conv2d_broadcastto_op``
          together with ``self.inputs[0]``. That op presumably broadcasts it
          back to the input's shape (confirm in Conv2dBroadcast).
          """
          # Imported locally, presumably to avoid a circular import between
          # op modules.
          from .Conv2dBroadcast import conv2d_broadcastto_op
          return [
              conv2d_broadcastto_op(output_grad, self.inputs[0], ctx=self.raw_ctx)
          ]

      def infer_shape(self, input_shapes):
          """Infer the output shape ``(C,)`` from a single 4-D input shape.

          Axis 1 of the input (the channel axis) is kept. Axes 0, 2 and 3
          are summed away, matching the CPU ``compute`` path,
          e.g. (2, 3, 4, 5) -> (3,).
          """
          assert len(input_shapes) == 1
          input_shape = input_shapes[0]
          # compute() reduces over axes (0, 2, 3), so only a rank-4 input
          # produces the (C,) result promised here. Fail early otherwise.
          assert len(input_shape) == 4, \
              "Conv2d_ReduceSumOp expects a 4-D input, got shape %s" % (input_shape,)
          channels = input_shape[1]
          return (channels,)


  def conv2d_reducesum_op(node, ctx=None):
      """Create a node that sums ``node`` over axes (0, 2, 3).

      For an input of shape (N, C, H, W) the resulting node has shape (C,):
      one sum per channel (axis 1). The input is expected to be 4-D
      (an NCHW layout is assumed).

      Parameters:
      ----
      node : Node
          The 4-D Node to be summed per channel.
      ctx : optional
          Device context forwarded to the created op; ``None`` by default.

      Returns:
      ----
      A new Conv2d_ReduceSumOp node.
      """
      return Conv2d_ReduceSumOp(node, ctx=ctx)