You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

ReduceMean.py 3.6 kB

4 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091
  1. from __future__ import absolute_import
  2. import numpy as np
  3. from .Node import Op
  4. from ..gpu_links import reduce_mean
  5. class ReduceMeanOp(Op):
  6. def __init__(self, node_A, axes, keepdims=False, ctx=None):
  7. super().__init__(ReduceMeanOp, [node_A], ctx)
  8. if axes is not None:
  9. if isinstance(axes, int):
  10. axes = [axes]
  11. self.axes = list(axes)
  12. assert all(map(lambda x: isinstance(x, int), self.axes))
  13. if keepdims is not None:
  14. if keepdims is True or keepdims is False:
  15. self.keepdims = [keepdims] * len(self.axes)
  16. else:
  17. keepdims = list(keepdims)
  18. assert len(keepdims) == len(self.axes)
  19. assert all(map(lambda x: isinstance(x, bool), keepdims))
  20. self.keepdims = keepdims
  21. def compute(self, input_vals, output_val, stream_handle=None):
  22. assert self.axes is not None and self.keepdims is not None
  23. if self.on_cpu:
  24. if all(self.keepdims) or not any(self.keepdims):
  25. output_val[:] = np.mean(input_vals[0].asnumpy(), axis=tuple(
  26. self.axes), keepdims=self.keepdims[0])
  27. else:
  28. temp = input_vals[0].asnumpy()
  29. for i in range(len(self.keepdims))[::-1]:
  30. temp = np.mean(
  31. temp, self.axes[i], keepdims=self.keepdims[i])
  32. output_val[:] = temp
  33. else:
  34. reduce_mean(input_vals[0], output_val, self.axes, stream_handle)
  35. def gradient(self, output_grad):
  36. from .MultiplyConst import mul_byconst_op
  37. from .BroadcastShape import broadcast_shape_op
  38. # Here we don't know how to calculate gradient since we don't have shape information
  39. # The const is determined in infer_shape phase.
  40. self.grad_node = mul_byconst_op(broadcast_shape_op(
  41. output_grad, None, None, ctx=self.raw_ctx), None, ctx=self.raw_ctx)
  42. return [self.grad_node]
  43. def infer_shape(self, input_shapes):
  44. assert self.axes is not None and self.keepdims is not None
  45. assert len(input_shapes) == 1
  46. input_shape = list(input_shapes[0])
  47. mean_multiplier = 1
  48. for i in range(len(self.axes)):
  49. if self.axes[i] < 0:
  50. self.axes[i] += len(input_shape)
  51. assert 0 <= self.axes[i] < len(input_shape)
  52. mean_multiplier *= input_shape[self.axes[i]]
  53. input_shape[self.axes[i]] = 1 if self.keepdims[i] else 0
  54. if hasattr(self, 'grad_node'):
  55. self.grad_node.const_attr = 1.0 / mean_multiplier
  56. self.grad_node.inputs[0].target_shape = tuple(input_shapes[0])
  57. add_axes = []
  58. for i in range(len(self.axes)):
  59. if not self.keepdims[i]:
  60. add_axes.append(self.axes[i])
  61. self.grad_node.inputs[0].add_axes = add_axes
  62. input_shape = [x for x in input_shape if x > 0]
  63. if input_shape == []:
  64. return (1,)
  65. else:
  66. return tuple(input_shape)
  67. def reduce_mean_op(node, axes, keepdims=False, ctx=None):
  68. """Creates a node that represents np.mean(node_A, axis, keepdims).
  69. Parameters:
  70. ----
  71. node : Node
  72. The Node needed to be averaged.
  73. axes : int or list
  74. The axis/axes needed to be averaged.
  75. keepdims: bool or list
  76. Whether to keep the dimension(s).
  77. Returns:
  78. ----
  79. A new Node instance created by Op.
  80. """
  81. return ReduceMeanOp(node, axes, keepdims, ctx=ctx)