You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

ReduceSum.py 3.1 kB

4 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384
  1. from __future__ import absolute_import
  2. import numpy as np
  3. from .Node import Op
  4. from ..gpu_links import reduce_sum
  5. class ReduceSumOp(Op):
  6. def __init__(self, node_A, axes, keepdims=False, ctx=None):
  7. super().__init__(ReduceSumOp, [node_A], ctx)
  8. if axes is not None:
  9. if isinstance(axes, int):
  10. axes = [axes]
  11. self.axes = list(axes)
  12. assert all(map(lambda x: isinstance(x, int), self.axes))
  13. if keepdims is not None:
  14. if keepdims is True or keepdims is False:
  15. self.keepdims = [keepdims] * len(self.axes)
  16. else:
  17. keepdims = list(keepdims)
  18. assert len(keepdims) == len(self.axes)
  19. assert all(map(lambda x: isinstance(x, bool), keepdims))
  20. self.keepdims = keepdims
  21. def compute(self, input_vals, output_val, stream_handle=None):
  22. if self.on_cpu:
  23. if all(self.keepdims) or not any(self.keepdims):
  24. output_val[:] = np.sum(input_vals[0].asnumpy(), axis=tuple(
  25. self.axes), keepdims=self.keepdims[0])
  26. else:
  27. temp = input_vals[0].asnumpy()
  28. for i in range(len(self.keepdims))[::-1]:
  29. temp = np.sum(
  30. temp, self.axes[i], keepdims=self.keepdims[i])
  31. output_val[:] = temp
  32. else:
  33. reduce_sum(input_vals[0], output_val, self.axes, stream_handle)
  34. def gradient(self, output_grad):
  35. from .BroadcastShape import broadcast_shape_op
  36. self.grad_node = broadcast_shape_op(
  37. output_grad, None, None, ctx=self.raw_ctx)
  38. return [self.grad_node]
  39. def infer_shape(self, input_shapes):
  40. assert self.axes is not None and self.keepdims is not None
  41. assert len(input_shapes) == 1
  42. input_shape = list(input_shapes[0])
  43. if hasattr(self, 'grad_node'):
  44. self.grad_node.target_shape = tuple(input_shape)
  45. add_axes = []
  46. for i in range(len(self.axes)):
  47. if not self.keepdims[i]:
  48. add_axes.append(self.axes[i])
  49. self.grad_node.add_axes = add_axes
  50. for i in range(len(self.axes)):
  51. if self.axes[i] < 0:
  52. self.axes[i] += len(input_shape)
  53. assert 0 <= self.axes[i] < len(input_shape)
  54. input_shape[self.axes[i]] = 1 if self.keepdims[i] else 0
  55. input_shape = [x for x in input_shape if x > 0]
  56. if input_shape == []:
  57. return (1,)
  58. else:
  59. return tuple(input_shape)
  60. def reduce_sum_op(node, axes, keepdims=False, ctx=None):
  61. """Creates a node that represents np.sum(node_A, axis, keepdims).
  62. Parameters:
  63. ----
  64. node : Node
  65. The Node needed to be summed.
  66. axes : int or list
  67. The axis/axes needed to be summed.
  68. keepdims: bool or list
  69. Whether to keep the dimension(s).
  70. Returns:
  71. ----
  72. A new Node instance created by Op.
  73. """
  74. return ReduceSumOp(node, axes, keepdims, ctx=ctx)