|
- from __future__ import absolute_import
- import numpy as np
- from .Node import Op
- from .._base import DNNL_LIB
- from ..cpu_links import reduce_sum_axis_zero as cpu_reduce_sum_axis_zero
- from ..gpu_links import reduce_sum_axis_zero
-
-
class ReduceSumAxisZeroOp(Op):
    """Op that sums its single input along axis 0.

    Mirrors ``np.sum(x, axis=0)``: an input of shape ``(d0, d1, ..., dk)``
    yields an output of shape ``(d1, ..., dk)``. A 1-D input ``(d0,)`` is
    mapped to shape ``(1,)`` rather than to a 0-d scalar shape.
    """

    def __init__(self, node_A, ctx=None):
        super().__init__(ReduceSumAxisZeroOp, [node_A], ctx)

    def compute(self, input_vals, output_val, stream_handle=None):
        """Write the axis-0 sum of ``input_vals[0]`` into ``output_val``.

        On CPU, the DNNL-backed kernel is used when
        ``DNNL_LIB['cpu_ReduceSumAxisZero']`` is truthy; otherwise the sum is
        computed with NumPy on ``input_vals[0].asnumpy()``. Off CPU, the GPU
        kernel is called with ``stream_handle``.
        """
        if self.on_cpu:
            if DNNL_LIB['cpu_ReduceSumAxisZero']:
                cpu_reduce_sum_axis_zero(input_vals[0], output_val)
            else:
                # For a 1-D input np.sum(..., axis=0) returns a scalar, which
                # broadcasts into the (1,)-shaped output chosen by infer_shape.
                output_val[:] = np.sum(input_vals[0].asnumpy(), axis=0)
        else:
            reduce_sum_axis_zero(input_vals[0], output_val, stream_handle)

    def gradient(self, output_grad):
        """Return the input gradient.

        Every input element contributes once to the sum, so the gradient is
        ``output_grad`` broadcast back to the shape of the input node.
        """
        # Function-local import, kept from the original; presumably it avoids
        # a circular import between op modules — confirm before hoisting.
        from .Broadcast import broadcastto_op
        return [broadcastto_op(output_grad, self.inputs[0], ctx=self.raw_ctx)]

    def infer_shape(self, input_shapes):
        """Infer the output shape of the axis-0 summation.

        e.g. (3, 4, 5) -> (4, 5); for a vector, (3,) -> (1,).

        Parameters
        ----------
        input_shapes : sequence
            Exactly one shape, the shape of the input node.

        Returns
        -------
        tuple
            The output shape, always as a tuple.

        Raises
        ------
        AssertionError
            If there is not exactly one input shape, or the input is 0-d
            (there is no axis 0 to reduce).
        """
        assert len(input_shapes) == 1
        input_shape = input_shapes[0]
        assert len(input_shape) >= 1, \
            "ReduceSumAxisZeroOp requires an input with at least one axis"
        if len(input_shape) == 1:
            return (1,)
        # tuple() so list-typed shapes yield the same type as the 1-D branch.
        return tuple(input_shape[1:])
-
-
def reducesumaxiszero_op(node, ctx=None):
    """Build a node computing ``np.sum(node, axis=0)``.

    Parameters
    ----------
    node : Node
        Input node whose values are summed over their first axis.
    ctx : optional
        Context forwarded unchanged to ``ReduceSumAxisZeroOp``.

    Returns
    -------
    ReduceSumAxisZeroOp
        The newly created reduction node.
    """
    reduce_node = ReduceSumAxisZeroOp(node, ctx=ctx)
    return reduce_node
|