| @@ -20,7 +20,7 @@ from ..core.tensor.array_method import _broadcast, _remove_axis | |||||
| from ..core.tensor.utils import astensor1d, convert_inputs, get_device | from ..core.tensor.utils import astensor1d, convert_inputs, get_device | ||||
| from ..device import get_default_device | from ..device import get_default_device | ||||
| from ..tensor import Tensor | from ..tensor import Tensor | ||||
| from .elemwise import ceil, floor_div | |||||
| from .elemwise import ceil | |||||
| __all__ = [ | __all__ = [ | ||||
| "arange", | "arange", | ||||
| @@ -442,10 +442,10 @@ def split(inp, nsplits_or_sections, axis=0): | |||||
| Ntotal = inp.shape[axis] | Ntotal = inp.shape[axis] | ||||
| try: | |||||
| if isinstance(nsplits_or_sections, Sequence): | |||||
| Nsections = len(nsplits_or_sections) + 1 | Nsections = len(nsplits_or_sections) + 1 | ||||
| is_array = True | is_array = True | ||||
| except TypeError: | |||||
| else: | |||||
| Nsections = int(nsplits_or_sections) | Nsections = int(nsplits_or_sections) | ||||
| is_array = False | is_array = False | ||||
| @@ -465,27 +465,19 @@ def split(inp, nsplits_or_sections, axis=0): | |||||
| Ntotal, axis, Nsections | Ntotal, axis, Nsections | ||||
| ) | ) | ||||
| ) | ) | ||||
| func = ( | |||||
| floor_div | |||||
| if isinstance(Nsections, (SymbolVar, Tensor)) | |||||
| else lambda x, y: x // y | |||||
| ) | |||||
| div_points = [0] + [ | |||||
| func(Ntotal + Nsections - i - 1, Nsections) for i in range(Nsections) | |||||
| ] | |||||
| for i in range(2, Nsections + 1): | |||||
| div_points[i] = div_points[i - 1] + div_points[i] | |||||
| sub_tensors = [] | |||||
| for i in range(Nsections): | |||||
| l = div_points[i] | |||||
| r = div_points[i + 1] | |||||
| slices = tuple( | |||||
| [slice(None)] * axis + [slice(l, r)] + [slice(None)] * (ndim - axis - 1) | |||||
| ) | |||||
| sub_tensors.append(inp[slices]) | |||||
| return sub_tensors | |||||
| partitions = [] | |||||
| for i in range(Nsections): | |||||
| section_size = (Ntotal + Nsections - i - 1) // Nsections | |||||
| partitions.append(section_size) | |||||
| partitions = [ | |||||
| part | |||||
| if isinstance(part, (SymbolVar, Tensor)) | |||||
| else Const(part, dtype="int32", device=inp.device)(inp)[0] | |||||
| for part in partitions | |||||
| ] | |||||
| op = builtin.Split(axis=axis) | |||||
| return apply(op, inp, *partitions) | |||||
| def _get_idx(index, axis): | def _get_idx(index, axis): | ||||
| @@ -178,6 +178,21 @@ def test_regression_1762(): | |||||
| gm.backward(loss) | gm.backward(loss) | ||||
def test_empty_grad_in_backward():
    """Check gradients through ``F.where`` when one branch is never selected.

    ``x`` is filled with 0.5, so ``x > 0.7`` is false for every element and
    ``F.where`` takes all of its output from ``y``.  After backward, ``x`` is
    expected to hold an all-zero gradient and ``y`` an all-one gradient.
    """
    never_selected = mge.Parameter(F.full(100, 0.5))
    always_selected = mge.Parameter(F.ones(100))

    manager = GradManager()
    manager.attach([never_selected, always_selected])

    with manager:
        picked = F.where(never_selected > 0.7, never_selected, always_selected)
        total = picked.sum()
        manager.backward(total)
        # The branch that contributed nothing must still report zeros.
        assert np.all(never_selected.grad.numpy() == 0)
        assert np.all(always_selected.grad.numpy() == 1)
| @pytest.mark.require_ngpu(2) | @pytest.mark.require_ngpu(2) | ||||
| @pytest.mark.isolated_distributed | @pytest.mark.isolated_distributed | ||||
| @pytest.mark.parametrize( | @pytest.mark.parametrize( | ||||
| @@ -119,7 +119,7 @@ def test_stack(is_varnode): | |||||
| @pytest.mark.parametrize("is_varnode", [True, False]) | @pytest.mark.parametrize("is_varnode", [True, False]) | ||||
| def test_split(is_varnode): | |||||
| def test_split_basic(is_varnode): | |||||
| if is_varnode: | if is_varnode: | ||||
| network = Network() | network = Network() | ||||
| saved_symbolic_shape = set_symbolic_shape(False) | saved_symbolic_shape = set_symbolic_shape(False) | ||||
| @@ -150,15 +150,48 @@ def test_split(is_varnode): | |||||
| pass | pass | ||||
| try: | try: | ||||
| F.split(inp, [3, 3, 5], axis=3) | |||||
| F.split(inp, [3, 2, 5], axis=3) | |||||
| assert False | assert False | ||||
| except ValueError as e: | except ValueError as e: | ||||
| assert str(e) == "Invalid nsplits_or_secions: [3, 3, 5]" | |||||
| assert str(e) == "Invalid nsplits_or_secions: [3, 2, 5]" | |||||
| if is_varnode: | if is_varnode: | ||||
| set_symbolic_shape(saved_symbolic_shape) | set_symbolic_shape(saved_symbolic_shape) | ||||
@pytest.mark.parametrize("symbolic", [None, False, True])
def test_split(symbolic):
    """Compare ``F.split`` against ``np.split`` on regular and empty inputs.

    With ``symbolic=None`` the function is called directly.  Otherwise it is
    wrapped in ``trace(symbolic=symbolic)`` (a fresh trace per case) and
    invoked three times in a row.
    """
    dense = np.random.random((3, 4, 5, 6)).astype(np.float32)
    # Zero-sized leading axis: every split result is an empty tensor.
    empty = np.random.random((0, 4, 5, 6)).astype(np.float32)

    def ref(inp, nsplits_or_sections, axis):
        return np.split(inp, nsplits_or_sections, axis)

    def func(inp, nsplits_or_sections, axis):
        return F.split(inp, nsplits_or_sections, axis)

    # (input, nsplits_or_sections, axis)
    cases = [
        (dense, 2, 3),
        (dense, [3], 3),
        (dense, [3, 3, 5], 3),
        (empty, 2, 3),
        (empty, [3], 3),
        (empty, [3, 3, 5], 3),
    ]

    traced = symbolic is not None
    repeats = 3 if traced else 1

    for data, sections, axis in cases:
        fn = trace(symbolic=symbolic)(func) if traced else func
        for _ in range(repeats):
            expected = ref(data, sections, axis)
            actual = fn(tensor(data), sections, axis)
            assert len(expected) == len(actual)
            for want, got in zip(expected, actual):
                np.testing.assert_equal(want, got.numpy())
| @pytest.mark.parametrize("is_varnode", [True, False]) | @pytest.mark.parametrize("is_varnode", [True, False]) | ||||
| def test_reshape(is_varnode): | def test_reshape(is_varnode): | ||||
| if is_varnode: | if is_varnode: | ||||
| @@ -987,7 +987,8 @@ Split::Split(VarNode *inp, const Options &opt, const OperatorNodeConfig &config) | |||||
| } | } | ||||
| for (size_t i = 0; i < m_opt.nr_part; ++ i) | for (size_t i = 0; i < m_opt.nr_part; ++ i) | ||||
| add_output(ssprintf("o%zd", i))->dtype(inp->dtype()); | |||||
| add_output(ssprintf("o%zd", i))->dtype(inp->dtype()) | |||||
| .add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE); | |||||
| m_output_spec.resize(m_opt.nr_part); | m_output_spec.resize(m_opt.nr_part); | ||||
| } | } | ||||
| @@ -1060,10 +1061,6 @@ bool Split::infer_shape(size_t out_idx, TensorShape &dest, | |||||
| size_t size = 0; | size_t size = 0; | ||||
| for (size_t i = 0; i < m_opt.nr_part; ++ i) { | for (size_t i = 0; i < m_opt.nr_part; ++ i) { | ||||
| auto p = partition[i]; | auto p = partition[i]; | ||||
| mgb_assert(p, | |||||
| "got zero partition size at part %zu, tot_size=%zu", | |||||
| i, ishp.shape[axis]); | |||||
| size += p; | size += p; | ||||
| auto &&cur = m_output_spec[i].shape; | auto &&cur = m_output_spec[i].shape; | ||||
| @@ -1126,6 +1123,7 @@ cg::OperatorNodeBase::NodeProp* Split::do_make_node_prop() const { | |||||
| auto rst = OperatorNodeBase::do_make_node_prop(); | auto rst = OperatorNodeBase::do_make_node_prop(); | ||||
| rst->add_flag(NodeProp::Flag::CROSS_COMP_NODE_MEMORY); | rst->add_flag(NodeProp::Flag::CROSS_COMP_NODE_MEMORY); | ||||
| outshape_by_symvar_reset_node_dep_type(rst); | outshape_by_symvar_reset_node_dep_type(rst); | ||||
| rst->add_dep_type_existing_var(input(0), NodeProp::DepType::VALUE_ALLOW_EMPTY); | |||||
| return rst; | return rst; | ||||
| } | } | ||||
| @@ -1141,6 +1139,10 @@ void Split::do_execute(ExecEnv &env) { | |||||
| auto &&in = input(0)->dev_tensor(); | auto &&in = input(0)->dev_tensor(); | ||||
| auto &&out = output(idx)->dev_tensor(); | auto &&out = output(idx)->dev_tensor(); | ||||
| auto &&spec = m_output_spec.at(idx); | auto &&spec = m_output_spec.at(idx); | ||||
| if (out.layout().is_empty()) { | |||||
| mgb_assert(spec.subspec.layout().is_empty()); | |||||
| return; | |||||
| } | |||||
| owner_graph()->event().signal_inplace<cg::event::BeforeKernel>( | owner_graph()->event().signal_inplace<cg::event::BeforeKernel>( | ||||
| this, out.comp_node()); | this, out.comp_node()); | ||||
| if (spec.mem_fwd_success) { | if (spec.mem_fwd_success) { | ||||