From ab309eb5fc6174eb5e529af685b40b99dfac403c Mon Sep 17 00:00:00 2001
From: Megvii Engine Team
Date: Fri, 20 Aug 2021 19:14:49 +0800
Subject: [PATCH] feat(mgb/opr): let Split support empty IO

GitOrigin-RevId: aad6dc06bfe9b95889b924e0a26f3ea33c52319a
---
 .../python/megengine/functional/tensor.py    | 45 ++++++++++----------
 .../test/unit/autodiff/test_grad_manger.py   | 15 +++++++
 .../test/unit/functional/test_tensor.py      | 39 ++++++++++++++---
 src/opr/impl/tensor_manip.cpp                | 12 ++++---
 4 files changed, 78 insertions(+), 33 deletions(-)

diff --git a/imperative/python/megengine/functional/tensor.py b/imperative/python/megengine/functional/tensor.py
index 0890219d..2677c6b0 100755
--- a/imperative/python/megengine/functional/tensor.py
+++ b/imperative/python/megengine/functional/tensor.py
@@ -20,7 +20,7 @@ from ..core.tensor.array_method import _broadcast, _remove_axis
 from ..core.tensor.utils import astensor1d, convert_inputs, get_device
 from ..device import get_default_device
 from ..tensor import Tensor
-from .elemwise import ceil, floor_div
+from .elemwise import ceil
 
 __all__ = [
     "arange",
@@ -442,50 +442,45 @@ def split(inp, nsplits_or_sections, axis=0):
 
     Ntotal = inp.shape[axis]
 
-    try:
+    if isinstance(nsplits_or_sections, Sequence):
         Nsections = len(nsplits_or_sections) + 1
         is_array = True
-    except TypeError:
+    else:
         Nsections = int(nsplits_or_sections)
         is_array = False
 
     if is_array:
         div_points = [0] + list(nsplits_or_sections) + [Ntotal]
         for i in range(1, len(div_points)):
-            if div_points[i - 1] >= div_points[i]:
+            if div_points[i - 1] > div_points[i]:
                 raise ValueError(
                     "Invalid nsplits_or_secions: {}".format(nsplits_or_sections)
                 )
+        partitions = []
+        for i in range(1, len(div_points)):
+            partitions.append(div_points[i] - div_points[i - 1])
     else:
         if Nsections <= 0:
             raise TypeError("Number sections must be larger than 0")
         if Nsections > Ntotal:
             raise ValueError(
                 "The size {} at dim {} cannot be split into {} sections".format(
                     Ntotal, axis, Nsections
                 )
             )
-
-        func = (
-            floor_div
-            if isinstance(Nsections, (SymbolVar, Tensor))
-            else lambda x, y: x // y
-        )
-        div_points = [0] + [
-            func(Ntotal + Nsections - i - 1, Nsections) for i in range(Nsections)
-        ]
-        for i in range(2, Nsections + 1):
-            div_points[i] = div_points[i - 1] + div_points[i]
-
-    sub_tensors = []
-    for i in range(Nsections):
-        l = div_points[i]
-        r = div_points[i + 1]
-        slices = tuple(
-            [slice(None)] * axis + [slice(l, r)] + [slice(None)] * (ndim - axis - 1)
-        )
-        sub_tensors.append(inp[slices])
-    return sub_tensors
+        partitions = []
+        for i in range(Nsections):
+            section_size = (Ntotal + Nsections - i - 1) // Nsections
+            partitions.append(section_size)
+
+    partitions = [
+        part
+        if isinstance(part, (SymbolVar, Tensor))
+        else Const(part, dtype="int32", device=inp.device)(inp)[0]
+        for part in partitions
+    ]
+    op = builtin.Split(axis=axis)
+    return apply(op, inp, *partitions)
 
 
 def _get_idx(index, axis):
diff --git a/imperative/python/test/unit/autodiff/test_grad_manger.py b/imperative/python/test/unit/autodiff/test_grad_manger.py
index daf58823..a288ba26 100644
--- a/imperative/python/test/unit/autodiff/test_grad_manger.py
+++ b/imperative/python/test/unit/autodiff/test_grad_manger.py
@@ -178,6 +178,21 @@ def test_regression_1762():
         gm.backward(loss)
 
 
+def test_empty_grad_in_backward():
+    x = mge.Parameter(F.full(100, 0.5))
+    y = mge.Parameter(F.ones(100))
+
+    gm = GradManager()
+    gm.attach([x, y])
+
+    with gm:
+        z = F.where(x > 0.7, x, y)
+        loss = z.sum()
+        gm.backward(loss)
+    assert np.all(x.grad.numpy() == 0)
+    assert np.all(y.grad.numpy() == 1)
+
+
 @pytest.mark.require_ngpu(2)
 @pytest.mark.isolated_distributed
 @pytest.mark.parametrize(
diff --git a/imperative/python/test/unit/functional/test_tensor.py b/imperative/python/test/unit/functional/test_tensor.py
index dbc3fa83..e6f56819 100644
--- a/imperative/python/test/unit/functional/test_tensor.py
+++ b/imperative/python/test/unit/functional/test_tensor.py
@@ -119,7 +119,7 @@ def test_stack(is_varnode):
 
 
 @pytest.mark.parametrize("is_varnode", [True, False])
-def test_split(is_varnode):
+def test_split_basic(is_varnode):
     if is_varnode:
         network = Network()
         saved_symbolic_shape = set_symbolic_shape(False)
@@ -150,15 +150,48 @@ def test_split(is_varnode):
         pass
 
     try:
-        F.split(inp, [3, 3, 5], axis=3)
+        F.split(inp, [3, 2, 5], axis=3)
         assert False
     except ValueError as e:
-        assert str(e) == "Invalid nsplits_or_secions: [3, 3, 5]"
+        assert str(e) == "Invalid nsplits_or_secions: [3, 2, 5]"
 
     if is_varnode:
         set_symbolic_shape(saved_symbolic_shape)
 
 
+@pytest.mark.parametrize("symbolic", [None, False, True])
+def test_split(symbolic):
+    inp1 = np.random.random((3, 4, 5, 6)).astype(np.float32)
+    inp2 = np.random.random((0, 4, 5, 6)).astype(np.float32)
+
+    def ref(inp, nsplits_or_sections, axis):
+        return np.split(inp, nsplits_or_sections, axis)
+
+    def func(inp, nsplits_or_sections, axis):
+        return F.split(inp, nsplits_or_sections, axis)
+
+    cases = [
+        (inp1, 2, 3),
+        (inp1, [3], 3),
+        (inp1, [3, 3, 5], 3),
+        (inp2, 2, 3),
+        (inp2, [3], 3),
+        (inp2, [3, 3, 5], 3),
+    ]
+
+    for case in cases:
+        if symbolic is None:
+            fn = func
+        else:
+            fn = trace(symbolic=symbolic)(func)
+        for i in range(3 if symbolic is not None else 1):
+            ref_out = ref(*case)
+            out = fn(tensor(case[0]), case[1], case[2])
+            assert len(ref_out) == len(out)
+            for idx in range(len(ref_out)):
+                np.testing.assert_equal(ref_out[idx], out[idx].numpy())
+
+
 @pytest.mark.parametrize("is_varnode", [True, False])
 def test_reshape(is_varnode):
     if is_varnode:
diff --git a/src/opr/impl/tensor_manip.cpp b/src/opr/impl/tensor_manip.cpp
index 196f7200..11d56587 100644
--- a/src/opr/impl/tensor_manip.cpp
+++ b/src/opr/impl/tensor_manip.cpp
@@ -987,7 +987,8 @@ Split::Split(VarNode *inp, const Options &opt, const OperatorNodeConfig &config)
     }
 
     for (size_t i = 0; i < m_opt.nr_part; ++ i)
-        add_output(ssprintf("o%zd", i))->dtype(inp->dtype());
+        add_output(ssprintf("o%zd", i))->dtype(inp->dtype())
+            .add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE);
     m_output_spec.resize(m_opt.nr_part);
 }
 
@@ -1060,10 +1061,6 @@ bool Split::infer_shape(size_t out_idx, TensorShape &dest,
     size_t size = 0;
     for (size_t i = 0; i < m_opt.nr_part; ++ i) {
         auto p = partition[i];
-        mgb_assert(p,
-                "got zero partition size at part %zu, tot_size=%zu",
-                i, ishp.shape[axis]);
-
         size += p;
         auto &&cur = m_output_spec[i].shape;
 
@@ -1126,6 +1123,7 @@ cg::OperatorNodeBase::NodeProp* Split::do_make_node_prop() const {
     auto rst = OperatorNodeBase::do_make_node_prop();
     rst->add_flag(NodeProp::Flag::CROSS_COMP_NODE_MEMORY);
     outshape_by_symvar_reset_node_dep_type(rst);
+    rst->add_dep_type_existing_var(input(0), NodeProp::DepType::VALUE_ALLOW_EMPTY);
    return rst;
 }
 
@@ -1141,6 +1139,10 @@ void Split::do_execute(ExecEnv &env) {
         auto &&in = input(0)->dev_tensor();
         auto &&out = output(idx)->dev_tensor();
         auto &&spec = m_output_spec.at(idx);
+        if (out.layout().is_empty()) {
+            mgb_assert(spec.subspec.layout().is_empty());
+            return;
+        }
         owner_graph()->event().signal_inplace<cg::event::BeforeKernel>(
                 this, out.comp_node());
         if (spec.mem_fwd_success) {
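
Below is a brief usage sketch, separate from the patch itself, of the behavior this change enables. It assumes a MegEngine build that already contains this commit; the tensor names, values, and shapes are illustrative only.

    import numpy as np
    import megengine.functional as F
    from megengine import tensor

    # Repeated section indices now yield an empty middle piece, matching
    # numpy.split semantics: [3, 3, 5] on an axis of size 6 -> sizes 3, 0, 2, 1.
    x = tensor(np.arange(24, dtype="float32").reshape(2, 2, 6))
    parts = F.split(x, [3, 3, 5], axis=2)
    print([p.shape for p in parts])  # shapes (2, 2, 3), (2, 2, 0), (2, 2, 2), (2, 2, 1)

    # A tensor that is already empty along another axis can also be split;
    # every output is itself an empty tensor.
    e = tensor(np.zeros((0, 6), dtype="float32"))
    print([p.shape for p in F.split(e, 2, axis=1)])  # shapes (0, 3), (0, 3)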