From 7fadc16d3c5c563b4a8a991f0a7685747cdc99c4 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Mon, 31 Aug 2020 12:16:01 +0800 Subject: [PATCH] refactor(mge/functional): support tensor shape in interpolate and split GitOrigin-RevId: 6430b64f010ea5d0ecb1caa59b6da0a1547552ae --- .../python/megengine/core/tensor/utils.py | 4 +++- .../python/megengine/functional/elemwise.py | 17 ++++++++++++----- imperative/python/megengine/functional/nn.py | 7 ++----- .../python/megengine/functional/tensor.py | 10 ++++++---- .../test/unit/functional/test_elemwise.py | 8 ++++---- .../test/unit/functional/test_functional.py | 3 --- .../python/test/unit/functional/test_tensor.py | 2 -- 7 files changed, 27 insertions(+), 24 deletions(-) diff --git a/imperative/python/megengine/core/tensor/utils.py b/imperative/python/megengine/core/tensor/utils.py index 5981b2f5..b700c1cd 100644 --- a/imperative/python/megengine/core/tensor/utils.py +++ b/imperative/python/megengine/core/tensor/utils.py @@ -31,7 +31,9 @@ def dtype_promotion(raw_inputs): ] inputs = [i for i in raw_inputs if hasattr(i, "dtype")] assert len(scalar_inputs + inputs) > 0 - dtype = np.result_type(*inputs) + dtype = None + if len(inputs) > 0: + dtype = np.result_type(*inputs) dtype_all = np.result_type(*(inputs + scalar_inputs)) assert ( dtype != np.float64 and dtype != np.int64 diff --git a/imperative/python/megengine/functional/elemwise.py b/imperative/python/megengine/functional/elemwise.py index f3a43733..bc7c68f8 100644 --- a/imperative/python/megengine/functional/elemwise.py +++ b/imperative/python/megengine/functional/elemwise.py @@ -10,8 +10,9 @@ import functools from ..core.ops import builtin -from ..core.tensor import utils +from ..core.tensor import megbrain_graph, utils from ..core.tensor.core import apply +from ..device import get_default_device from ..tensor import Tensor __all__ = [ @@ -76,11 +77,17 @@ __all__ = [ def _elwise(*args, mode): op = builtin.Elemwise(mode=mode) + tensor_args = list( + filter(lambda x: isinstance(x, (Tensor, megbrain_graph.VarNode)), args) + ) + if len(tensor_args) == 0: + dtype = utils.dtype_promotion(args) + first_arg = Tensor(args[0], dtype=dtype, device=get_default_device()) + args = utils.convert_inputs(first_arg, *args[1:]) + else: + args = utils.convert_inputs(*args) if mode in ("true_div", "exp", "pow", "log", "expm1", "log1p"): - args = tuple( - map(lambda x: x.astype("float32") if hasattr(x, "dtype") else x, args) - ) - args = utils.convert_inputs(*args) + args = tuple(map(lambda x: x.astype("float32"), args)) (result,) = apply(op, *args) return result diff --git a/imperative/python/megengine/functional/nn.py b/imperative/python/megengine/functional/nn.py index f7163cd4..38b4b16d 100644 --- a/imperative/python/megengine/functional/nn.py +++ b/imperative/python/megengine/functional/nn.py @@ -1126,11 +1126,8 @@ def interpolate( if mode == "LINEAR": inp = add_axis(inp, 3) - if not isinstance(inp.shape, inp.__class__): - if len(inp.shape) != 4: - raise ValueError( - "shape of input tensor must correspond to the operartion mode" - ) + if inp.ndim != 4: + raise ValueError("shape of input tensor must correspond to the operartion mode") if size is None: if scale_factor is None: diff --git a/imperative/python/megengine/functional/tensor.py b/imperative/python/megengine/functional/tensor.py index d1cd1110..ef1e3a76 100644 --- a/imperative/python/megengine/functional/tensor.py +++ b/imperative/python/megengine/functional/tensor.py @@ -317,7 +317,7 @@ def split(inp, nsplits_or_sections, axis=0): def swapaxis(inp, src, dst): if src == dst: return inp - shape = [i for i in range(len(inp.shape))] + shape = [i for i in range(inp.ndim)] shape[src] = dst shape[dst] = src return inp.transpose(shape) @@ -325,9 +325,11 @@ def split(inp, nsplits_or_sections, axis=0): inp = swapaxis(inp, 0, axis) if isinstance(nsplits_or_sections, int): - incr_step = math.ceil(inp.shape[0] / nsplits_or_sections) - while incr_step < inp.shape[0]: - sections.append(incr_step) + incr_step = ceil(inp.shape[0] / nsplits_or_sections) + nsplits = nsplits_or_sections + while nsplits > 0: + nsplits -= 1 + sections.append(incr_step.astype("int32")) incr_step += nsplits_or_sections else: sections = nsplits_or_sections diff --git a/imperative/python/test/unit/functional/test_elemwise.py b/imperative/python/test/unit/functional/test_elemwise.py index 75d6874d..683103fd 100644 --- a/imperative/python/test/unit/functional/test_elemwise.py +++ b/imperative/python/test/unit/functional/test_elemwise.py @@ -19,13 +19,13 @@ def test_abs(): np.abs(np.array([-3.0, -4.0, -5.0], dtype=np.float32)), ) - # assertTensorClose(F.abs(-3.0), np.abs(np.float32(-3.0))) + assertTensorClose(F.abs(-3.0).numpy(), np.abs(np.float32(-3.0))) def test_multiply(): - # assertTensorClose( - # F.mul(-3.0, -4.0), np.multiply(np.float32(-3.0), np.float32(-4.0)) - # ) + assertTensorClose( + F.mul(-3.0, -4.0).numpy(), np.multiply(np.float32(-3.0), np.float32(-4.0)) + ) assertTensorClose( F.mul(tensor([3.0, 4.0]), 4.0).numpy(), diff --git a/imperative/python/test/unit/functional/test_functional.py b/imperative/python/test/unit/functional/test_functional.py index 04d9e724..58a582d0 100644 --- a/imperative/python/test/unit/functional/test_functional.py +++ b/imperative/python/test/unit/functional/test_functional.py @@ -194,9 +194,6 @@ def test_matmul(): def test_interpolate(): - if use_tensor_shape(): # XXX: please fix me - return - def linear_interpolate(): inp = tensor(np.arange(1, 3, dtype=np.float32).reshape(1, 1, 2)) diff --git a/imperative/python/test/unit/functional/test_tensor.py b/imperative/python/test/unit/functional/test_tensor.py index 8fe8cb8d..72e1fb73 100644 --- a/imperative/python/test/unit/functional/test_tensor.py +++ b/imperative/python/test/unit/functional/test_tensor.py @@ -125,8 +125,6 @@ def test_stack(): def test_split(): - if use_tensor_shape(): # XXX: please fix me - return data = np.random.random((2, 3, 4, 5)).astype(np.float32) mge_out1 = F.split(tensor(data), 2, axis=3) mge_out2 = F.split(tensor(data), [3, 5], axis=3)