Browse Source

feat(mgb/opr): let Split support empty IO

GitOrigin-RevId: aad6dc06bf
tags/v1.6.0
Megvii Engine Team 4 years ago
parent
commit
ab309eb5fc
4 changed files with 74 additions and 32 deletions
  1. +16
    -24
      imperative/python/megengine/functional/tensor.py
  2. +15
    -0
      imperative/python/test/unit/autodiff/test_grad_manger.py
  3. +36
    -3
      imperative/python/test/unit/functional/test_tensor.py
  4. +7
    -5
      src/opr/impl/tensor_manip.cpp

+ 16
- 24
imperative/python/megengine/functional/tensor.py View File

@@ -20,7 +20,7 @@ from ..core.tensor.array_method import _broadcast, _remove_axis
from ..core.tensor.utils import astensor1d, convert_inputs, get_device from ..core.tensor.utils import astensor1d, convert_inputs, get_device
from ..device import get_default_device from ..device import get_default_device
from ..tensor import Tensor from ..tensor import Tensor
from .elemwise import ceil, floor_div
from .elemwise import ceil


__all__ = [ __all__ = [
"arange", "arange",
@@ -442,10 +442,10 @@ def split(inp, nsplits_or_sections, axis=0):


Ntotal = inp.shape[axis] Ntotal = inp.shape[axis]


try:
if isinstance(nsplits_or_sections, Sequence):
Nsections = len(nsplits_or_sections) + 1 Nsections = len(nsplits_or_sections) + 1
is_array = True is_array = True
except TypeError:
else:
Nsections = int(nsplits_or_sections) Nsections = int(nsplits_or_sections)
is_array = False is_array = False


@@ -465,27 +465,19 @@ def split(inp, nsplits_or_sections, axis=0):
Ntotal, axis, Nsections Ntotal, axis, Nsections
) )
) )

func = (
floor_div
if isinstance(Nsections, (SymbolVar, Tensor))
else lambda x, y: x // y
)
div_points = [0] + [
func(Ntotal + Nsections - i - 1, Nsections) for i in range(Nsections)
]
for i in range(2, Nsections + 1):
div_points[i] = div_points[i - 1] + div_points[i]

sub_tensors = []
for i in range(Nsections):
l = div_points[i]
r = div_points[i + 1]
slices = tuple(
[slice(None)] * axis + [slice(l, r)] + [slice(None)] * (ndim - axis - 1)
)
sub_tensors.append(inp[slices])
return sub_tensors
partitions = []
for i in range(Nsections):
section_size = (Ntotal + Nsections - i - 1) // Nsections
partitions.append(section_size)

partitions = [
part
if isinstance(part, (SymbolVar, Tensor))
else Const(part, dtype="int32", device=inp.device)(inp)[0]
for part in partitions
]
op = builtin.Split(axis=axis)
return apply(op, inp, *partitions)




def _get_idx(index, axis): def _get_idx(index, axis):


+ 15
- 0
imperative/python/test/unit/autodiff/test_grad_manger.py View File

@@ -178,6 +178,21 @@ def test_regression_1762():
gm.backward(loss) gm.backward(loss)




def test_empty_grad_in_backward():
    """Check gradients through ``F.where`` when one branch is never selected.

    Every element of the first parameter is 0.5, so the ``> 0.7`` mask is
    false everywhere. The test expects an all-zero gradient for that
    parameter and an all-one gradient for the other one.
    """
    never_picked = mge.Parameter(F.full(100, 0.5))
    always_picked = mge.Parameter(F.ones(100))

    manager = GradManager()
    manager.attach([never_picked, always_picked])

    with manager:
        mask = never_picked > 0.7
        total = F.where(mask, never_picked, always_picked).sum()
        manager.backward(total)
        assert np.all(never_picked.grad.numpy() == 0)
        assert np.all(always_picked.grad.numpy() == 1)


@pytest.mark.require_ngpu(2) @pytest.mark.require_ngpu(2)
@pytest.mark.isolated_distributed @pytest.mark.isolated_distributed
@pytest.mark.parametrize( @pytest.mark.parametrize(


+ 36
- 3
imperative/python/test/unit/functional/test_tensor.py View File

@@ -119,7 +119,7 @@ def test_stack(is_varnode):




@pytest.mark.parametrize("is_varnode", [True, False]) @pytest.mark.parametrize("is_varnode", [True, False])
def test_split(is_varnode):
def test_split_basic(is_varnode):
if is_varnode: if is_varnode:
network = Network() network = Network()
saved_symbolic_shape = set_symbolic_shape(False) saved_symbolic_shape = set_symbolic_shape(False)
@@ -150,15 +150,48 @@ def test_split(is_varnode):
pass pass


try: try:
F.split(inp, [3, 3, 5], axis=3)
F.split(inp, [3, 2, 5], axis=3)
assert False assert False
except ValueError as e: except ValueError as e:
assert str(e) == "Invalid nsplits_or_secions: [3, 3, 5]"
assert str(e) == "Invalid nsplits_or_secions: [3, 2, 5]"


if is_varnode: if is_varnode:
set_symbolic_shape(saved_symbolic_shape) set_symbolic_shape(saved_symbolic_shape)




@pytest.mark.parametrize("symbolic", [None, False, True])
def test_split(symbolic):
    """Compare ``F.split`` with ``np.split`` on a regular and an empty input.

    When ``symbolic`` is ``None`` the function is called directly, once per
    case. Otherwise it is wrapped in ``trace(symbolic=symbolic)`` and called
    three times per case.
    """
    regular = np.random.random((3, 4, 5, 6)).astype(np.float32)
    # Leading dimension is 0, so this input holds no elements.
    empty = np.random.random((0, 4, 5, 6)).astype(np.float32)

    def run_split(inp, nsplits_or_sections, axis):
        return F.split(inp, nsplits_or_sections, axis)

    # Same three split specifications along axis 3 for each input, in order.
    cases = [
        (data, sections, 3)
        for data in (regular, empty)
        for sections in (2, [3], [3, 3, 5])
    ]

    repeats = 1 if symbolic is None else 3
    for data, sections, axis in cases:
        # A fresh traced wrapper is built for every case.
        fn = run_split if symbolic is None else trace(symbolic=symbolic)(run_split)
        for _ in range(repeats):
            expected = np.split(data, sections, axis)
            actual = fn(tensor(data), sections, axis)
            assert len(expected) == len(actual)
            for idx, want in enumerate(expected):
                np.testing.assert_equal(want, actual[idx].numpy())


@pytest.mark.parametrize("is_varnode", [True, False]) @pytest.mark.parametrize("is_varnode", [True, False])
def test_reshape(is_varnode): def test_reshape(is_varnode):
if is_varnode: if is_varnode:


+ 7
- 5
src/opr/impl/tensor_manip.cpp View File

@@ -987,7 +987,8 @@ Split::Split(VarNode *inp, const Options &opt, const OperatorNodeConfig &config)
} }


for (size_t i = 0; i < m_opt.nr_part; ++ i) for (size_t i = 0; i < m_opt.nr_part; ++ i)
add_output(ssprintf("o%zd", i))->dtype(inp->dtype());
add_output(ssprintf("o%zd", i))->dtype(inp->dtype())
.add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE);


m_output_spec.resize(m_opt.nr_part); m_output_spec.resize(m_opt.nr_part);
} }
@@ -1060,10 +1061,6 @@ bool Split::infer_shape(size_t out_idx, TensorShape &dest,
size_t size = 0; size_t size = 0;
for (size_t i = 0; i < m_opt.nr_part; ++ i) { for (size_t i = 0; i < m_opt.nr_part; ++ i) {
auto p = partition[i]; auto p = partition[i];
mgb_assert(p,
"got zero partition size at part %zu, tot_size=%zu",
i, ishp.shape[axis]);

size += p; size += p;


auto &&cur = m_output_spec[i].shape; auto &&cur = m_output_spec[i].shape;
@@ -1126,6 +1123,7 @@ cg::OperatorNodeBase::NodeProp* Split::do_make_node_prop() const {
auto rst = OperatorNodeBase::do_make_node_prop(); auto rst = OperatorNodeBase::do_make_node_prop();
rst->add_flag(NodeProp::Flag::CROSS_COMP_NODE_MEMORY); rst->add_flag(NodeProp::Flag::CROSS_COMP_NODE_MEMORY);
outshape_by_symvar_reset_node_dep_type(rst); outshape_by_symvar_reset_node_dep_type(rst);
rst->add_dep_type_existing_var(input(0), NodeProp::DepType::VALUE_ALLOW_EMPTY);
return rst; return rst;
} }


@@ -1141,6 +1139,10 @@ void Split::do_execute(ExecEnv &env) {
auto &&in = input(0)->dev_tensor(); auto &&in = input(0)->dev_tensor();
auto &&out = output(idx)->dev_tensor(); auto &&out = output(idx)->dev_tensor();
auto &&spec = m_output_spec.at(idx); auto &&spec = m_output_spec.at(idx);
if (out.layout().is_empty()) {
mgb_assert(spec.subspec.layout().is_empty());
return;
}
owner_graph()->event().signal_inplace<cg::event::BeforeKernel>( owner_graph()->event().signal_inplace<cg::event::BeforeKernel>(
this, out.comp_node()); this, out.comp_node());
if (spec.mem_fwd_success) { if (spec.mem_fwd_success) {


Loading…
Cancel
Save