|
|
|
@@ -24,7 +24,7 @@ from ..._checkparam import Rel |
|
|
|
from ...common import dtype as mstype |
|
|
|
from ...common.tensor import Tensor |
|
|
|
from .._utils import _get_broadcast_shape |
|
|
|
from ..primitive import PrimitiveWithInfer, prim_attr_register |
|
|
|
from ..primitive import PrimitiveWithInfer, prim_attr_register, _run_op |
|
|
|
|
|
|
|
|
|
|
|
def _infer_shape_reduce(x, axis, keep_dims, prim_name): |
|
|
|
@@ -225,6 +225,11 @@ class _Reduce(PrimitiveWithInfer): |
|
|
|
validator.check_value_type('keep_dims', keep_dims, [bool], self.name) |
|
|
|
self.init_prim_io_names(inputs=['input_x', 'axis'], outputs=['y']) |
|
|
|
|
|
|
|
def __call__(self, x, axis=()): |
|
|
|
args = [x, axis] |
|
|
|
output = _run_op(self, self.name, args) |
|
|
|
return output |
|
|
|
|
|
|
|
def do_infer(self, input_x, axis, valid_dtype=mstype.number_type): |
|
|
|
axis_v = axis['value'] |
|
|
|
input_shp = input_x['shape'] |
|
|
|
|