GitOrigin-RevId: a5dc3b997c
tags/v1.5.0
| @@ -7,12 +7,14 @@ | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| import numpy as np | |||
| import pytest | |||
| import megengine.functional as F | |||
| import megengine.functional.elemwise as elemwise | |||
| from megengine import tensor | |||
| from megengine.core.tensor import dtype | |||
| from megengine.functional.elemwise import Elemwise, _elwise | |||
| from megengine.jit import trace | |||
| def test_abs(): | |||
| @@ -180,3 +182,80 @@ def test_int32_input(): | |||
| inp = (x,) * nargs | |||
| y = op(*inp) | |||
| y.numpy() | |||
| @pytest.mark.parametrize("is_trace", [True, False]) | |||
| def test_empty_tensor(is_trace): | |||
| binary_func = [] | |||
| unary_func = [] | |||
| for op_name in elemwise.__all__: | |||
| op = getattr(elemwise, op_name) | |||
| nargs = op.__code__.co_argcount | |||
| if op_name == "clip": | |||
| unary_func.append(["clip", lambda x, f=op: f(x, lower=0, upper=1)]) | |||
| elif op_name.endswith("_shift"): | |||
| unary_func.append( | |||
| [op_name, lambda x, f=op: f(tensor(x.numpy(), dtype="int32"), 1)] | |||
| ) | |||
| elif op_name.startswith("logical_"): # logical_xxx op only accept boolean type | |||
| if nargs == 1: | |||
| unary_func.append( | |||
| [op_name, lambda x, f=op: f(tensor(x.numpy(), dtype="bool"))] | |||
| ) | |||
| else: | |||
| assert nargs == 2 | |||
| binary_func.append( | |||
| [ | |||
| op_name, | |||
| lambda x, y, f=op: f( | |||
| tensor(x.numpy(), dtype="bool"), | |||
| tensor(y.numpy(), dtype="bool"), | |||
| ), | |||
| ] | |||
| ) | |||
| elif nargs == 1: | |||
| unary_func.append([op_name, op]) | |||
| elif nargs == 2: | |||
| binary_func.append([op_name, op]) | |||
| else: | |||
| print(nargs) | |||
| raise NotImplementedError | |||
| def run_test(func, args, ref_shape, is_trace, sym=False): | |||
| args = [tensor(t, dtype="float32") for t in args] | |||
| if is_trace: | |||
| func = trace(symbolic=sym)(func) | |||
| for _ in range(3): | |||
| out = func(*args) | |||
| assert out.numpy().shape == ref_shape | |||
| else: | |||
| out = func(*args) | |||
| assert out.numpy().shape == ref_shape | |||
| print(out.numpy().shape) | |||
| inps = [ | |||
| np.array([]).astype("float32"), | |||
| np.random.randn(2, 0, 3).astype("float32"), | |||
| 123, | |||
| ] | |||
| for op_name, op in unary_func: | |||
| if is_trace: | |||
| for sym in [True, False]: | |||
| run_test(op, [inps[0],], inps[0].shape, True, sym) | |||
| run_test(op, [inps[1],], inps[1].shape, True, sym) | |||
| else: | |||
| run_test(op, [inps[0],], inps[0].shape, False) | |||
| run_test(op, [inps[1],], inps[1].shape, False) | |||
| for op_name, op in binary_func: | |||
| if is_trace: | |||
| for sym in [True, False]: | |||
| run_test(op, [inps[0], inps[0]], (inps[0] + inps[0]).shape, True, sym) | |||
| run_test(op, [inps[1], inps[1]], (inps[1] + inps[1]).shape, True, sym) | |||
| run_test(op, [inps[0], inps[2]], (inps[0] + inps[2]).shape, True, sym) | |||
| run_test(op, [inps[1], inps[2]], (inps[1] + inps[2]).shape, True, sym) | |||
| else: | |||
| run_test(op, [inps[0], inps[0]], (inps[0] + inps[0]).shape, False) | |||
| run_test(op, [inps[1], inps[1]], (inps[1] + inps[1]).shape, False) | |||
| run_test(op, [inps[0], inps[2]], (inps[0] + inps[2]).shape, False) | |||
| run_test(op, [inps[1], inps[2]], (inps[1] + inps[2]).shape, False) | |||
| @@ -19,6 +19,7 @@ from megengine.core._trace_option import use_symbolic_shape | |||
| from megengine.core.tensor import megbrain_graph as G | |||
| from megengine.core.tensor.utils import astensor1d | |||
| from megengine.distributed.helper import get_device_count_by_fork | |||
| from megengine.jit import trace | |||
| from megengine.utils.network import Network, set_symbolic_shape | |||
| from megengine.utils.network_node import VarNode | |||
| @@ -177,6 +178,48 @@ def test_reshape(is_varnode): | |||
| np.testing.assert_equal(yy.numpy(), y) | |||
| @pytest.mark.parametrize("is_trace", [True, False]) | |||
| def test_reshape_on_empty_tensor(is_trace): | |||
| input1_shape = (100, 0, 1) | |||
| output1_shape = (100, 0, 10) | |||
| data1 = tensor(np.random.random(input1_shape).astype(np.float32)) | |||
| input2_shape = (10, 0) | |||
| output2_shape = (0,) | |||
| data2 = tensor(np.random.random(input2_shape).astype(np.float32)) | |||
| input3_shape = (10, 0, 10) | |||
| output3_shape = (0, 1, 2, 3) | |||
| data3 = tensor(np.random.random(input3_shape).astype(np.float32)) | |||
| def comp(out, target_shp): | |||
| assert out._tuple_shape == target_shp | |||
| def func(x, shp): | |||
| return F.reshape(x, shp) | |||
| cases = [ | |||
| [data1, output1_shape], | |||
| [data2, output2_shape], | |||
| [data3, output3_shape], | |||
| ] | |||
| def test(func, inp, comp, target_shp): | |||
| out = func(inp, target_shp) | |||
| comp(out, target_shp) | |||
| if is_trace: | |||
| for symbolic in [False, True]: | |||
| for inp, target_shp in cases: | |||
| func_traced = trace(symbolic=symbolic)(func) | |||
| test(func_traced, inp, comp, target_shp) | |||
| test(func_traced, inp, comp, target_shp) | |||
| test(func_traced, inp, comp, target_shp) | |||
| else: | |||
| for inp, target_shp in cases: | |||
| test(func, inp, comp, target_shp) | |||
| @pytest.mark.parametrize("is_varnode", [True, False]) | |||
| def test_reshape_shape_inference(is_varnode): | |||
| if is_varnode: | |||
| @@ -480,6 +523,48 @@ def test_broadcast(is_varnode): | |||
| F.broadcast_to(x, (1, 3)) | |||
| @pytest.mark.parametrize("is_trace", [True, False]) | |||
| def test_broadcast_on_empty_tensor(is_trace): | |||
| input1_shape = (100, 0, 1) | |||
| output1_shape = (100, 0, 10) | |||
| data1 = tensor(np.random.random(input1_shape).astype(np.float32)) | |||
| input2_shape = (10, 0) | |||
| output2_shape = (10, 10, 0) | |||
| data2 = tensor(np.random.random(input2_shape).astype(np.float32)) | |||
| input3_shape = (0, 0, 1, 10) | |||
| output3_shape = (10, 0, 0, 10, 10) | |||
| data3 = tensor(np.random.random(input3_shape).astype(np.float32)) | |||
| def comp(out, target_shp): | |||
| assert out._tuple_shape == target_shp | |||
| def func(x, shp): | |||
| return F.broadcast_to(x, shp) | |||
| cases = [ | |||
| [data1, output1_shape], | |||
| [data2, output2_shape], | |||
| [data3, output3_shape], | |||
| ] | |||
| def test(func, inp, comp, target_shp): | |||
| out = func(inp, target_shp) | |||
| comp(out, target_shp) | |||
| if is_trace: | |||
| for symbolic in [False, True]: | |||
| for inp, target_shp in cases: | |||
| func_traced = trace(symbolic=symbolic)(func) | |||
| test(func_traced, inp, comp, target_shp) | |||
| test(func_traced, inp, comp, target_shp) | |||
| test(func_traced, inp, comp, target_shp) | |||
| else: | |||
| for inp, target_shp in cases: | |||
| test(func, inp, comp, target_shp) | |||
| @pytest.mark.parametrize("is_varnode", [True, False]) | |||
| def test_utils_astensor1d(is_varnode): | |||
| if is_varnode: | |||
| @@ -259,6 +259,10 @@ void Elemwise::perform( | |||
| mgb_assert(t.comp_node() == out_cn); | |||
| mgb_assert(t.dtype() == out_dt); | |||
| } | |||
| if (t.shape().is_empty()) { | |||
| mgb_assert(dest.empty()); | |||
| return; | |||
| } | |||
| inp_shapes[i] = t.shape(); | |||
| } | |||
| if (!opr) { | |||
| @@ -1064,4 +1064,37 @@ TEST(TestOprBasicArithElemwise, EmptyInputOutputBinary) { | |||
| MGB_ASSERT_SHAPE_EQ(host_z.shape(), TensorShape({0, 8, 1, 7})); | |||
| } | |||
| TEST(TestOprBasicArithElemwise, PerformEmptyIO) { | |||
| auto cn = CompNode::load("xpu0"); | |||
| HostTensorGenerator<> gen; | |||
| auto host_x1 = gen({2, 0, 3, 4}), | |||
| host_x2 = gen({1}); | |||
| auto dev_x1 = std::make_shared<DeviceTensorND>(cn), | |||
| dev_x2 = std::make_shared<DeviceTensorND>(cn); | |||
| dev_x1->copy_from(*host_x1); | |||
| dev_x2->copy_from(*host_x2); | |||
| auto dev_y = std::make_shared<DeviceTensorND>(cn, dev_x1->dtype()); | |||
| dev_y->resize(dev_x1->shape()); | |||
| auto&& dnn_opr = opr::intl::create_megdnn_opr<megdnn::Elemwise>(cn); | |||
| // test unary mode | |||
| for (auto mode: {Mode::NEGATE, Mode::EXP, Mode::LOG}) { | |||
| SmallVector<DeviceTensorND> inputs = {*dev_x1}; | |||
| ASSERT_NO_THROW(opr::Elemwise::perform(mode, *dev_y, inputs, dnn_opr)); | |||
| ASSERT_TRUE(dev_y->empty()); | |||
| ASSERT_TRUE(dev_y->shape().is_empty()); | |||
| MGB_ASSERT_SHAPE_EQ(dev_y->shape(), dev_x1->shape()); | |||
| } | |||
| // test binary mode | |||
| for (auto mode: {Mode::ADD, Mode::MUL, Mode::LT}) { | |||
| SmallVector<DeviceTensorND> inputs = {*dev_x1, *dev_x2}; | |||
| ASSERT_NO_THROW(opr::Elemwise::perform(mode, *dev_y, inputs, dnn_opr)); | |||
| ASSERT_TRUE(dev_y->empty()); | |||
| ASSERT_TRUE(dev_y->shape().is_empty()); | |||
| MGB_ASSERT_SHAPE_EQ(dev_y->shape(), dev_x1->shape()); | |||
| } | |||
| } | |||
| // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} | |||