@@ -20,7 +20,7 @@ from ..core.tensor.array_method import _broadcast, _remove_axis
 from ..core.tensor.utils import astensor1d, convert_inputs, get_device
 from ..device import get_default_device
 from ..tensor import Tensor
-from .elemwise import ceil, floor_div
+from .elemwise import ceil

 __all__ = [
     "arange",
@@ -442,10 +442,10 @@ def split(inp, nsplits_or_sections, axis=0):
     Ntotal = inp.shape[axis]
-    try:
+    if isinstance(nsplits_or_sections, Sequence):
         Nsections = len(nsplits_or_sections) + 1
         is_array = True
-    except TypeError:
+    else:
         Nsections = int(nsplits_or_sections)
         is_array = False
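
Reviewer note: the old code relied on len() raising TypeError for an
integer argument; the explicit Sequence check states the dispatch
directly. A minimal standalone sketch of the new behavior (plain
Python, hypothetical helper name):

    from collections.abc import Sequence

    def count_sections(nsplits_or_sections):
        # A list of N split points yields N + 1 sections; an integer
        # asks for that many equal sections.
        if isinstance(nsplits_or_sections, Sequence):
            return len(nsplits_or_sections) + 1
        return int(nsplits_or_sections)

    assert count_sections([3, 3, 5]) == 4
    assert count_sections(2) == 2
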
| @@ -465,27 +465,19 @@ def split(inp, nsplits_or_sections, axis=0): | |||
| Ntotal, axis, Nsections | |||
| ) | |||
| ) | |||
| func = ( | |||
| floor_div | |||
| if isinstance(Nsections, (SymbolVar, Tensor)) | |||
| else lambda x, y: x // y | |||
| ) | |||
| div_points = [0] + [ | |||
| func(Ntotal + Nsections - i - 1, Nsections) for i in range(Nsections) | |||
| ] | |||
| for i in range(2, Nsections + 1): | |||
| div_points[i] = div_points[i - 1] + div_points[i] | |||
| sub_tensors = [] | |||
| for i in range(Nsections): | |||
| l = div_points[i] | |||
| r = div_points[i + 1] | |||
| slices = tuple( | |||
| [slice(None)] * axis + [slice(l, r)] + [slice(None)] * (ndim - axis - 1) | |||
| ) | |||
| sub_tensors.append(inp[slices]) | |||
| return sub_tensors | |||
| partitions = [] | |||
| for i in range(Nsections): | |||
| section_size = (Ntotal + Nsections - i - 1) // Nsections | |||
| partitions.append(section_size) | |||
| partitions = [ | |||
| part | |||
| if isinstance(part, (SymbolVar, Tensor)) | |||
| else Const(part, dtype="int32", device=inp.device)(inp)[0] | |||
| for part in partitions | |||
| ] | |||
| op = builtin.Split(axis=axis) | |||
| return apply(op, inp, *partitions) | |||
| def _get_idx(index, axis): | |||
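
Reviewer note: the new implementation no longer slices in Python; it
computes the per-output sizes and hands them to a single builtin.Split.
The expression (Ntotal + Nsections - i - 1) // Nsections is
ceil((Ntotal - i) / Nsections) written as one floor division, so it also
works on symbolic shape values. A standalone sketch (hypothetical helper
name) of how it distributes elements:

    def partition_sizes(Ntotal, Nsections):
        # Section i gets ceil((Ntotal - i) / Nsections) elements: the
        # larger sections come first and the sizes always sum to Ntotal.
        return [(Ntotal + Nsections - i - 1) // Nsections
                for i in range(Nsections)]

    assert partition_sizes(10, 3) == [4, 3, 3]
    assert partition_sizes(6, 2) == [3, 3]
    assert sum(partition_sizes(7, 4)) == 7
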
@@ -178,6 +178,21 @@ def test_regression_1762():
         gm.backward(loss)


+def test_empty_grad_in_backward():
+    x = mge.Parameter(F.full(100, 0.5))
+    y = mge.Parameter(F.ones(100))
+
+    gm = GradManager()
+    gm.attach([x, y])
+    with gm:
+        z = F.where(x > 0.7, x, y)
+        loss = z.sum()
+        gm.backward(loss)
+
+    assert np.all(x.grad.numpy() == 0)
+    assert np.all(y.grad.numpy() == 1)
+
+
 @pytest.mark.require_ngpu(2)
 @pytest.mark.isolated_distributed
 @pytest.mark.parametrize(
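
Reviewer note: with x filled with 0.5, the mask x > 0.7 selects no
elements from the x branch, so x's gradient contribution comes from an
empty selection and must materialize as all zeros; the test guards that
empty-tensor path in backward. A NumPy sketch of the gradients the test
asserts:

    import numpy as np

    # For z = where(m, x, y) and loss = z.sum(), the elementwise
    # gradients are dloss/dx = m and dloss/dy = ~m.
    x = np.full(100, 0.5, dtype=np.float32)
    mask = x > 0.7                        # all False here
    grad_x = mask.astype(np.float32)      # all zeros
    grad_y = (~mask).astype(np.float32)   # all ones
    assert np.all(grad_x == 0) and np.all(grad_y == 1)
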
@@ -119,7 +119,7 @@ def test_stack(is_varnode):
 @pytest.mark.parametrize("is_varnode", [True, False])
-def test_split(is_varnode):
+def test_split_basic(is_varnode):
     if is_varnode:
         network = Network()
         saved_symbolic_shape = set_symbolic_shape(False)
@@ -150,15 +150,48 @@ def test_split(is_varnode):
         pass

     try:
-        F.split(inp, [3, 3, 5], axis=3)
+        F.split(inp, [3, 2, 5], axis=3)
         assert False
     except ValueError as e:
-        assert str(e) == "Invalid nsplits_or_secions: [3, 3, 5]"
+        assert str(e) == "Invalid nsplits_or_secions: [3, 2, 5]"

     if is_varnode:
         set_symbolic_shape(saved_symbolic_shape)


+@pytest.mark.parametrize("symbolic", [None, False, True])
+def test_split(symbolic):
+    inp1 = np.random.random((3, 4, 5, 6)).astype(np.float32)
+    inp2 = np.random.random((0, 4, 5, 6)).astype(np.float32)
+
+    def ref(inp, nsplits_or_sections, axis):
+        return np.split(inp, nsplits_or_sections, axis)
+
+    def func(inp, nsplits_or_sections, axis):
+        return F.split(inp, nsplits_or_sections, axis)
+
+    cases = [
+        (inp1, 2, 3),
+        (inp1, [3], 3),
+        (inp1, [3, 3, 5], 3),
+        (inp2, 2, 3),
+        (inp2, [3], 3),
+        (inp2, [3, 3, 5], 3),
+    ]
+
+    for case in cases:
+        if symbolic is None:
+            fn = func
+        else:
+            fn = trace(symbolic=symbolic)(func)
+        for i in range(3 if symbolic is not None else 1):
+            ref_out = ref(*case)
+            out = fn(tensor(case[0]), case[1], case[2])
+            assert len(ref_out) == len(out)
+            for idx in range(len(ref_out)):
+                np.testing.assert_equal(ref_out[idx], out[idx].numpy())
+
+
 @pytest.mark.parametrize("is_varnode", [True, False])
 def test_reshape(is_varnode):
     if is_varnode:
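
Reviewer note: the list cases are chosen so that some outputs are empty,
which is exactly what the C++ changes below permit. A NumPy sketch of
the reference semantics the test compares against:

    import numpy as np

    # Split points [3, 3, 5] on an axis of length 6 give section sizes
    # 3, 0, 2, 1; the second section is an empty tensor.
    inp1 = np.random.random((3, 4, 5, 6)).astype(np.float32)
    parts = np.split(inp1, [3, 3, 5], axis=3)
    assert [p.shape[3] for p in parts] == [3, 0, 2, 1]

    # inp2 is empty along axis 0, so every section it yields is empty.
    inp2 = np.random.random((0, 4, 5, 6)).astype(np.float32)
    assert all(p.size == 0 for p in np.split(inp2, 2, axis=3))
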
@@ -987,7 +987,8 @@ Split::Split(VarNode *inp, const Options &opt, const OperatorNodeConfig &config)
     }
     for (size_t i = 0; i < m_opt.nr_part; ++ i)
-        add_output(ssprintf("o%zd", i))->dtype(inp->dtype());
+        add_output(ssprintf("o%zd", i))->dtype(inp->dtype())
+            .add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE);
     m_output_spec.resize(m_opt.nr_part);
 }
@@ -1060,10 +1061,6 @@ bool Split::infer_shape(size_t out_idx, TensorShape &dest,
     size_t size = 0;
     for (size_t i = 0; i < m_opt.nr_part; ++ i) {
         auto p = partition[i];
-        mgb_assert(p,
-                "got zero partition size at part %zu, tot_size=%zu",
-                i, ishp.shape[axis]);
         size += p;
         auto &&cur = m_output_spec[i].shape;
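
Reviewer note: dropping this assert lets infer_shape accept zero-size
parts; combined with ALLOW_EMPTY_SHAPE on the outputs above, a zero-size
section becomes a legal empty var instead of a graph error. A
Python-level check, assuming the new semantics:

    import numpy as np
    import megengine.functional as F
    from megengine import tensor

    x = tensor(np.arange(12, dtype=np.float32).reshape(3, 4))
    # Split points [2, 2] along axis 1 give sizes 2, 0, 2.
    a, b, c = F.split(x, [2, 2], axis=1)
    assert b.numpy().shape == (3, 0)   # empty middle section, no error
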
@@ -1126,6 +1123,7 @@ cg::OperatorNodeBase::NodeProp* Split::do_make_node_prop() const {
     auto rst = OperatorNodeBase::do_make_node_prop();
     rst->add_flag(NodeProp::Flag::CROSS_COMP_NODE_MEMORY);
     outshape_by_symvar_reset_node_dep_type(rst);
+    rst->add_dep_type_existing_var(input(0), NodeProp::DepType::VALUE_ALLOW_EMPTY);
     return rst;
 }
@@ -1141,6 +1139,10 @@ void Split::do_execute(ExecEnv &env) {
         auto &&in = input(0)->dev_tensor();
         auto &&out = output(idx)->dev_tensor();
         auto &&spec = m_output_spec.at(idx);
+        if (out.layout().is_empty()) {
+            mgb_assert(spec.subspec.layout().is_empty());
+            return;
+        }
         owner_graph()->event().signal_inplace<cg::event::BeforeKernel>(
                 this, out.comp_node());
         if (spec.mem_fwd_success) {
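
Reviewer note: the early return skips the copy kernel (and the kernel
events signaled just below) when an output is empty, and the
VALUE_ALLOW_EMPTY dep type above keeps the executor from rejecting an
empty input value. A Python-level smoke test under the same assumption:

    import numpy as np
    import megengine.functional as F
    from megengine import tensor

    # A tensor empty along axis 0 can still be split along axis 1;
    # every piece is a (0, 2) empty tensor and nothing is copied.
    x = tensor(np.zeros((0, 6), dtype=np.float32))
    outs = F.split(x, 3, axis=1)
    assert all(o.numpy().shape == (0, 2) for o in outs)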