Browse Source

refactor(mge/functional): support tensor shape in interpolate and split

GitOrigin-RevId: 6430b64f01
tags/v1.0.0-rc1
Megvii Engine Team 5 years ago
parent
commit
7fadc16d3c
7 changed files with 27 additions and 24 deletions
  1. +3
    -1
      imperative/python/megengine/core/tensor/utils.py
  2. +12
    -5
      imperative/python/megengine/functional/elemwise.py
  3. +2
    -5
      imperative/python/megengine/functional/nn.py
  4. +6
    -4
      imperative/python/megengine/functional/tensor.py
  5. +4
    -4
      imperative/python/test/unit/functional/test_elemwise.py
  6. +0
    -3
      imperative/python/test/unit/functional/test_functional.py
  7. +0
    -2
      imperative/python/test/unit/functional/test_tensor.py

+ 3
- 1
imperative/python/megengine/core/tensor/utils.py View File

@@ -31,7 +31,9 @@ def dtype_promotion(raw_inputs):
]
inputs = [i for i in raw_inputs if hasattr(i, "dtype")]
assert len(scalar_inputs + inputs) > 0
dtype = np.result_type(*inputs)
dtype = None
if len(inputs) > 0:
dtype = np.result_type(*inputs)
dtype_all = np.result_type(*(inputs + scalar_inputs))
assert (
dtype != np.float64 and dtype != np.int64


+ 12
- 5
imperative/python/megengine/functional/elemwise.py View File

@@ -10,8 +10,9 @@
import functools

from ..core.ops import builtin
from ..core.tensor import utils
from ..core.tensor import megbrain_graph, utils
from ..core.tensor.core import apply
from ..device import get_default_device
from ..tensor import Tensor

__all__ = [
@@ -76,11 +77,17 @@ __all__ = [

def _elwise(*args, mode):
op = builtin.Elemwise(mode=mode)
tensor_args = list(
filter(lambda x: isinstance(x, (Tensor, megbrain_graph.VarNode)), args)
)
if len(tensor_args) == 0:
dtype = utils.dtype_promotion(args)
first_arg = Tensor(args[0], dtype=dtype, device=get_default_device())
args = utils.convert_inputs(first_arg, *args[1:])
else:
args = utils.convert_inputs(*args)
if mode in ("true_div", "exp", "pow", "log", "expm1", "log1p"):
args = tuple(
map(lambda x: x.astype("float32") if hasattr(x, "dtype") else x, args)
)
args = utils.convert_inputs(*args)
args = tuple(map(lambda x: x.astype("float32"), args))
(result,) = apply(op, *args)
return result



+ 2
- 5
imperative/python/megengine/functional/nn.py View File

@@ -1126,11 +1126,8 @@ def interpolate(
if mode == "LINEAR":
inp = add_axis(inp, 3)

if not isinstance(inp.shape, inp.__class__):
if len(inp.shape) != 4:
raise ValueError(
"shape of input tensor must correspond to the operartion mode"
)
if inp.ndim != 4:
raise ValueError("shape of input tensor must correspond to the operartion mode")

if size is None:
if scale_factor is None:


+ 6
- 4
imperative/python/megengine/functional/tensor.py View File

@@ -317,7 +317,7 @@ def split(inp, nsplits_or_sections, axis=0):
def swapaxis(inp, src, dst):
if src == dst:
return inp
shape = [i for i in range(len(inp.shape))]
shape = [i for i in range(inp.ndim)]
shape[src] = dst
shape[dst] = src
return inp.transpose(shape)
@@ -325,9 +325,11 @@ def split(inp, nsplits_or_sections, axis=0):
inp = swapaxis(inp, 0, axis)

if isinstance(nsplits_or_sections, int):
incr_step = math.ceil(inp.shape[0] / nsplits_or_sections)
while incr_step < inp.shape[0]:
sections.append(incr_step)
incr_step = ceil(inp.shape[0] / nsplits_or_sections)
nsplits = nsplits_or_sections
while nsplits > 0:
nsplits -= 1
sections.append(incr_step.astype("int32"))
incr_step += nsplits_or_sections
else:
sections = nsplits_or_sections


+ 4
- 4
imperative/python/test/unit/functional/test_elemwise.py View File

@@ -19,13 +19,13 @@ def test_abs():
np.abs(np.array([-3.0, -4.0, -5.0], dtype=np.float32)),
)

# assertTensorClose(F.abs(-3.0), np.abs(np.float32(-3.0)))
assertTensorClose(F.abs(-3.0).numpy(), np.abs(np.float32(-3.0)))


def test_multiply():
# assertTensorClose(
# F.mul(-3.0, -4.0), np.multiply(np.float32(-3.0), np.float32(-4.0))
# )
assertTensorClose(
F.mul(-3.0, -4.0).numpy(), np.multiply(np.float32(-3.0), np.float32(-4.0))
)

assertTensorClose(
F.mul(tensor([3.0, 4.0]), 4.0).numpy(),


+ 0
- 3
imperative/python/test/unit/functional/test_functional.py View File

@@ -194,9 +194,6 @@ def test_matmul():


def test_interpolate():
if use_tensor_shape(): # XXX: please fix me
return

def linear_interpolate():
inp = tensor(np.arange(1, 3, dtype=np.float32).reshape(1, 1, 2))



+ 0
- 2
imperative/python/test/unit/functional/test_tensor.py View File

@@ -125,8 +125,6 @@ def test_stack():


def test_split():
if use_tensor_shape(): # XXX: please fix me
return
data = np.random.random((2, 3, 4, 5)).astype(np.float32)
mge_out1 = F.split(tensor(data), 2, axis=3)
mge_out2 = F.split(tensor(data), [3, 5], axis=3)


Loading…
Cancel
Save