diff --git a/imperative/python/test/unit/functional/test_math.py b/imperative/python/test/unit/functional/test_math.py index e0765c8d..a8f10326 100644 --- a/imperative/python/test/unit/functional/test_math.py +++ b/imperative/python/test/unit/functional/test_math.py @@ -7,6 +7,8 @@ from utils import opr_test import megengine.functional as F from megengine import jit, tensor +from megengine.core._imperative_rt.core2 import apply +from megengine.core.ops import builtin def common_test_reduce(opr, ref_opr): @@ -182,6 +184,21 @@ def test_sum_neg_axis(): F.sum(tensor(data), axis=(-1, 1)) +def test_builtin_reduce(): + shape = (2, 3, 3, 2) + data = np.random.random(shape).astype(np.float32) + for axis in (-1, -2, 0, 1): + for keepdims in (True, False): + op = builtin.Reduce(mode="sum", axis=axis, keepdim=keepdims) + get = apply(op, tensor(data))[0] + def_op = builtin.Reduce(mode="sum", axis=axis) + def_get = apply(def_op, tensor(data))[0] + ref = np.sum(data, axis=axis, keepdims=keepdims) + np.testing.assert_allclose(get.numpy(), ref, rtol=1e-6) + if keepdims == True: + np.testing.assert_allclose(def_get.numpy(), ref, rtol=1e-6) + + def test_non_finite(): shape = (32, 3, 32, 32) data = [] diff --git a/imperative/src/impl/ops/reduce.cpp b/imperative/src/impl/ops/reduce.cpp index dc4595f1..f50bef0f 100644 --- a/imperative/src/impl/ops/reduce.cpp +++ b/imperative/src/impl/ops/reduce.cpp @@ -222,7 +222,7 @@ std::tuple, bool> infer_output_attrs_fallible( for (size_t i = 0; i < size; ++i) { dests[i].comp_node = inputs[i].comp_node; dests[i].layout = inputs[i].layout; - if (not keepdim && dests[i].layout.ndim > 1) { + if (!keepdim && dests[i].layout.ndim > 1) { dests[i].layout.remove_axis_inplace(axis); } else { dests[i].layout.shape[axis] = 1; diff --git a/src/core/include/megbrain/ir/ops.td b/src/core/include/megbrain/ir/ops.td index 3100df70..4d8dff62 100644 --- a/src/core/include/megbrain/ir/ops.td +++ b/src/core/include/megbrain/ir/ops.td @@ -16,7 +16,7 @@ def Elemwise : MgbHashableOp<"Elemwise", [ElemwiseParam], [NoSideEffect]> { def Reduce: MgbHashableOp<"Reduce", [ReduceParam]>{ let extraArguments = (ins - MgbBoolAttr:$keepdim + MgbDefaultValuedAttr:$keepdim ); }