from __future__ import absolute_import import numpy as np from .Node import Op from ..gpu_links import reduce_sum class ReduceSumOp(Op): def __init__(self, node_A, axes, keepdims=False, ctx=None): super().__init__(ReduceSumOp, [node_A], ctx) if axes is not None: if isinstance(axes, int): axes = [axes] self.axes = list(axes) assert all(map(lambda x: isinstance(x, int), self.axes)) if keepdims is not None: if keepdims is True or keepdims is False: self.keepdims = [keepdims] * len(self.axes) else: keepdims = list(keepdims) assert len(keepdims) == len(self.axes) assert all(map(lambda x: isinstance(x, bool), keepdims)) self.keepdims = keepdims def compute(self, input_vals, output_val, stream_handle=None): if self.on_cpu: if all(self.keepdims) or not any(self.keepdims): output_val[:] = np.sum(input_vals[0].asnumpy(), axis=tuple( self.axes), keepdims=self.keepdims[0]) else: temp = input_vals[0].asnumpy() for i in range(len(self.keepdims))[::-1]: temp = np.sum( temp, self.axes[i], keepdims=self.keepdims[i]) output_val[:] = temp else: reduce_sum(input_vals[0], output_val, self.axes, stream_handle) def gradient(self, output_grad): from .BroadcastShape import broadcast_shape_op self.grad_node = broadcast_shape_op( output_grad, None, None, ctx=self.raw_ctx) return [self.grad_node] def infer_shape(self, input_shapes): assert self.axes is not None and self.keepdims is not None assert len(input_shapes) == 1 input_shape = list(input_shapes[0]) if hasattr(self, 'grad_node'): self.grad_node.target_shape = tuple(input_shape) add_axes = [] for i in range(len(self.axes)): if not self.keepdims[i]: add_axes.append(self.axes[i]) self.grad_node.add_axes = add_axes for i in range(len(self.axes)): if self.axes[i] < 0: self.axes[i] += len(input_shape) assert 0 <= self.axes[i] < len(input_shape) input_shape[self.axes[i]] = 1 if self.keepdims[i] else 0 input_shape = [x for x in input_shape if x > 0] if input_shape == []: return (1,) else: return tuple(input_shape) def reduce_sum_op(node, axes, keepdims=False, ctx=None): """Creates a node that represents np.sum(node_A, axis, keepdims). Parameters: ---- node : Node The Node needed to be summed. axes : int or list The axis/axes needed to be summed. keepdims: bool or list Whether to keep the dimension(s). Returns: ---- A new Node instance created by Op. """ return ReduceSumOp(node, axes, keepdims, ctx=ctx)