GitOrigin-RevId: ecc47edab8
tags/v1.5.0
| @@ -18,7 +18,6 @@ from ..core._trace_option import use_symbolic_shape | |||
| from ..core._wrap import Device | |||
| from ..core.ops import builtin | |||
| from ..core.tensor.array_method import ArrayMethodMixin | |||
| from ..core.tensor.megbrain_graph import OutputNode | |||
| from .comp_graph_tools import replace_vars | |||
| from .module_stats import ( | |||
| preprocess_receptive_field, | |||
| @@ -106,9 +105,7 @@ class VarNode(NetworkNode, SymbolVar, ArrayMethodMixin, metaclass=VarNodeMeta): | |||
| return id(self) | |||
| def numpy(self): | |||
| o = OutputNode(self.var) | |||
| self.graph.compile(o.outputs).execute() | |||
| return o.get_value().numpy() | |||
| return super().numpy() | |||
| def _reset(self, other): | |||
| if not isinstance(other, VarNode): | |||
| @@ -141,15 +138,13 @@ class OpNode(NetworkNode): | |||
| @property | |||
| def id(self): | |||
| if self._opr is not None: | |||
| return self._opr.id | |||
| return id(self) | |||
| @property | |||
| def priority(self): | |||
| if self._opr is not None: | |||
| return self._opr.priority | |||
| return 0 | |||
| return (self._opr.priority, self._opr.id) | |||
| return (0, 0) | |||
| @classmethod | |||
| def load(cls, opr): | |||
| @@ -5,6 +5,7 @@ import numpy as np | |||
| import megengine.core.tensor.megbrain_graph as G | |||
| import megengine.utils.comp_graph_tools as cgtools | |||
| from megengine import tensor | |||
| from megengine.core.tensor.megbrain_graph import OutputNode | |||
| from megengine.jit import trace | |||
| from megengine.utils.network_node import VarNode | |||
| @@ -12,8 +13,10 @@ from megengine.utils.network_node import VarNode | |||
| def _default_compare_fn(x, y): | |||
| if isinstance(x, np.ndarray): | |||
| np.testing.assert_allclose(x, y, rtol=1e-6) | |||
| else: | |||
| elif isinstance(x, tensor): | |||
| np.testing.assert_allclose(x.numpy(), y, rtol=1e-6) | |||
| else: | |||
| np.testing.assert_allclose(get_var_value(x), y, rtol=1e-6) | |||
| def make_tensor(x, network=None, device=None): | |||
| @@ -25,6 +28,15 @@ def make_tensor(x, network=None, device=None): | |||
| return tensor(x, device=device) | |||
| def get_var_value(x): | |||
| try: | |||
| o = OutputNode(x.var) | |||
| o.graph.compile(o.outputs).execute() | |||
| return o.get_value().numpy() | |||
| except RuntimeError: | |||
| raise ValueError("value invalid!") | |||
| def opr_test( | |||
| cases, | |||
| func, | |||
| @@ -10,7 +10,7 @@ import copy | |||
| import numpy as np | |||
| import pytest | |||
| from utils import make_tensor | |||
| from utils import get_var_value, make_tensor | |||
| from megengine.core.tensor.dtype import get_scale, get_zero_point, qint8, quint8 | |||
| from megengine.tensor import Parameter, Tensor | |||
| @@ -55,7 +55,12 @@ def test_matmul(is_varnode): | |||
| A = make_tensor(np.random.rand(5, 7).astype("float32"), network) | |||
| B = make_tensor(np.random.rand(7, 10).astype("float32"), network) | |||
| C = A @ B | |||
| np.testing.assert_almost_equal(C.numpy(), A.numpy() @ B.numpy(), decimal=6) | |||
| if is_varnode: | |||
| np.testing.assert_almost_equal( | |||
| get_var_value(C), get_var_value(A) @ get_var_value(B), decimal=6 | |||
| ) | |||
| else: | |||
| np.testing.assert_almost_equal(C.numpy(), A.numpy() @ B.numpy(), decimal=6) | |||
| @pytest.mark.parametrize("is_varnode", [True, False]) | |||
| @@ -116,11 +121,17 @@ def test_set_subtensor(is_varnode): | |||
| x = make_tensor([1, 2, 3], network) | |||
| x[:] = [1, 1, 1] | |||
| np.testing.assert_almost_equal(x.numpy(), [1, 1, 1], decimal=6) | |||
| np.testing.assert_almost_equal( | |||
| get_var_value(x) if is_varnode else x.numpy(), [1, 1, 1], decimal=6 | |||
| ) | |||
| x[[0, 2]] = [3, 2] | |||
| np.testing.assert_almost_equal(x.numpy(), [3, 1, 2], decimal=6) | |||
| np.testing.assert_almost_equal( | |||
| get_var_value(x) if is_varnode else x.numpy(), [3, 1, 2], decimal=6 | |||
| ) | |||
| x[1:3] = [4, 5] | |||
| np.testing.assert_almost_equal(x.numpy(), [3, 4, 5], decimal=6) | |||
| np.testing.assert_almost_equal( | |||
| get_var_value(x) if is_varnode else x.numpy(), [3, 4, 5], decimal=6 | |||
| ) | |||
| def test_computing_with_numpy_array(): | |||
| @@ -11,7 +11,7 @@ import platform | |||
| import numpy as np | |||
| import pytest | |||
| from utils import make_tensor, opr_test | |||
| from utils import get_var_value, make_tensor, opr_test | |||
| import megengine.functional as F | |||
| from megengine import tensor | |||
| @@ -75,8 +75,12 @@ def test_condtake(is_varnode): | |||
| xx = make_tensor(x, network) | |||
| yy = make_tensor(y, network) | |||
| val, idx = F.cond_take(yy, xx) | |||
| np.testing.assert_equal(val.numpy(), x[y]) | |||
| np.testing.assert_equal(idx.numpy(), np.where(y.reshape(-1))[0]) | |||
| if is_varnode: | |||
| np.testing.assert_equal(get_var_value(val), x[y]) | |||
| np.testing.assert_equal(get_var_value(idx), np.where(y.reshape(-1))[0]) | |||
| else: | |||
| np.testing.assert_equal(val.numpy(), x[y]) | |||
| np.testing.assert_equal(idx.numpy(), np.where(y.reshape(-1))[0]) | |||
| @pytest.mark.parametrize("is_varnode", [True, False]) | |||