GitOrigin-RevId: 2071bb63a8
tags/v1.0.0
| @@ -15,6 +15,7 @@ from ..core.ops._internal import param_defs as P | |||||
| from ..core.ops.special import Const | from ..core.ops.special import Const | ||||
| from ..core.tensor import utils | from ..core.tensor import utils | ||||
| from ..core.tensor.core import TensorBase, TensorWrapperBase, apply | from ..core.tensor.core import TensorBase, TensorWrapperBase, apply | ||||
| from ..core.tensor.utils import astensor1d | |||||
| from ..distributed import WORLD, is_distributed | from ..distributed import WORLD, is_distributed | ||||
| from ..random import uniform | from ..random import uniform | ||||
| from ..tensor import Tensor | from ..tensor import Tensor | ||||
| @@ -868,7 +869,8 @@ def warp_perspective( | |||||
| imode=interp_mode, bmode=border_mode, format="NCHW", border_val=border_val | imode=interp_mode, bmode=border_mode, format="NCHW", border_val=border_val | ||||
| ) | ) | ||||
| inp, M = utils.convert_inputs(inp, M) | inp, M = utils.convert_inputs(inp, M) | ||||
| (result,) = apply(op, inp, M, Tensor(dsize)) | |||||
| dsize = astensor1d(dsize, inp, dtype="int32", device=inp.device) | |||||
| (result,) = apply(op, inp, M, dsize) | |||||
| return result | return result | ||||
| @@ -13,6 +13,7 @@ import numpy as np | |||||
| import pytest | import pytest | ||||
| import megengine.core.tensor.megbrain_graph as G | import megengine.core.tensor.megbrain_graph as G | ||||
| import megengine.functional as F | |||||
| from megengine import cgtools, tensor | from megengine import cgtools, tensor | ||||
| from megengine.core._trace_option import set_tensor_shape | from megengine.core._trace_option import set_tensor_shape | ||||
| from megengine.core.ops import builtin as ops | from megengine.core.ops import builtin as ops | ||||
| @@ -261,3 +262,36 @@ def test_trace_reshape(): | |||||
| f(x1) | f(x1) | ||||
| f(x2) | f(x2) | ||||
| f(x3) | f(x3) | ||||
| def test_trace_topk(): | |||||
| x = tensor([5, 2, 7, 1, 0, 3, 2]) | |||||
| @trace(symbolic=True) | |||||
| def f(x): | |||||
| y = F.topk(x, 3) | |||||
| np.testing.assert_equal(y[0].shape.numpy(), np.array([3,])) | |||||
| return y | |||||
| for i in range(3): | |||||
| f(x) | |||||
| def test_trace_warp_perspective(): | |||||
| inp_shape = (1, 1, 4, 4) | |||||
| x = tensor(np.arange(16, dtype=np.float32).reshape(inp_shape)) | |||||
| M_shape = (1, 3, 3) | |||||
| M = tensor( | |||||
| np.array( | |||||
| [[1.0, 0.0, 1.0], [0.0, 1.0, 1.0], [0.0, 0.0, 1.0]], dtype=np.float32 | |||||
| ).reshape(M_shape) | |||||
| ) | |||||
| @trace(symbolic=True) | |||||
| def f(x, M): | |||||
| out = F.warp_perspective(x, M, (2, 2)) | |||||
| np.testing.assert_equal(out.shape.numpy(), np.array([1, 1, 2, 2])) | |||||
| return out | |||||
| for i in range(1): | |||||
| f(x, M) | |||||