Browse Source

fix(imperative): fix buildin reduce keepdim

GitOrigin-RevId: 38d90ab38a
tags/v1.10.0
Megvii Engine Team 3 years ago
parent
commit
8563f51404
3 changed files with 19 additions and 2 deletions
  1. +17
    -0
      imperative/python/test/unit/functional/test_math.py
  2. +1
    -1
      imperative/src/impl/ops/reduce.cpp
  3. +1
    -1
      src/core/include/megbrain/ir/ops.td

+ 17
- 0
imperative/python/test/unit/functional/test_math.py View File

@@ -7,6 +7,8 @@ from utils import opr_test

import megengine.functional as F
from megengine import jit, tensor
from megengine.core._imperative_rt.core2 import apply
from megengine.core.ops import builtin


def common_test_reduce(opr, ref_opr):
@@ -182,6 +184,21 @@ def test_sum_neg_axis():
F.sum(tensor(data), axis=(-1, 1))


def test_builtin_reduce():
shape = (2, 3, 3, 2)
data = np.random.random(shape).astype(np.float32)
for axis in (-1, -2, 0, 1):
for keepdims in (True, False):
op = builtin.Reduce(mode="sum", axis=axis, keepdim=keepdims)
get = apply(op, tensor(data))[0]
def_op = builtin.Reduce(mode="sum", axis=axis)
def_get = apply(def_op, tensor(data))[0]
ref = np.sum(data, axis=axis, keepdims=keepdims)
np.testing.assert_allclose(get.numpy(), ref, rtol=1e-6)
if keepdims == True:
np.testing.assert_allclose(def_get.numpy(), ref, rtol=1e-6)


def test_non_finite():
shape = (32, 3, 32, 32)
data = []


+ 1
- 1
imperative/src/impl/ops/reduce.cpp View File

@@ -222,7 +222,7 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
for (size_t i = 0; i < size; ++i) {
dests[i].comp_node = inputs[i].comp_node;
dests[i].layout = inputs[i].layout;
if (not keepdim && dests[i].layout.ndim > 1) {
if (!keepdim && dests[i].layout.ndim > 1) {
dests[i].layout.remove_axis_inplace(axis);
} else {
dests[i].layout.shape[axis] = 1;


+ 1
- 1
src/core/include/megbrain/ir/ops.td View File

@@ -16,7 +16,7 @@ def Elemwise : MgbHashableOp<"Elemwise", [ElemwiseParam], [NoSideEffect]> {

def Reduce: MgbHashableOp<"Reduce", [ReduceParam]>{
let extraArguments = (ins
MgbBoolAttr:$keepdim
MgbDefaultValuedAttr<MgbBoolAttr, "true">:$keepdim
);
}



Loading…
Cancel
Save