GitOrigin-RevId: d22e639cd3
tags/v1.5.0
| @@ -140,7 +140,7 @@ def astensor1d(x, *reference, dtype=None, device=None): | |||||
| else: | else: | ||||
| if ndim != 0 and ndim != 1: | if ndim != 0 and ndim != 1: | ||||
| raise ValueError("ndim != 1 or 0, get : %d" % ndim) | raise ValueError("ndim != 1 or 0, get : %d" % ndim) | ||||
| if not isinstance(x, Tensor): | |||||
| if not isinstance(x, (Tensor, SymbolVar)): | |||||
| (x,) = Const(x, dtype=dtype, device=device)(*reference) | (x,) = Const(x, dtype=dtype, device=device)(*reference) | ||||
| return x | return x | ||||
| @@ -334,7 +334,7 @@ def split(inp, nsplits_or_sections, axis=0): | |||||
| x = tensor(np.random.random((10, 20)), dtype=np.float32) | x = tensor(np.random.random((10, 20)), dtype=np.float32) | ||||
| y = F.split(x, 3) | y = F.split(x, 3) | ||||
| z = F.split(x, [6, 17], axis=1) | z = F.split(x, [6, 17], axis=1) | ||||
| print([i.numpy().shape for i in y]) | print([i.numpy().shape for i in y]) | ||||
| print([i.numpy().shape for i in z]) | print([i.numpy().shape for i in z]) | ||||
| @@ -686,9 +686,9 @@ def cond_take(mask: Tensor, x: Tensor) -> Tensor: | |||||
| [1. 4.] [0 3] | [1. 4.] [0 3] | ||||
| """ | """ | ||||
| if not isinstance(x, Tensor): | |||||
| if not isinstance(x, (Tensor, SymbolVar)): | |||||
| raise TypeError("input must be a tensor") | raise TypeError("input must be a tensor") | ||||
| if not isinstance(mask, Tensor): | |||||
| if not isinstance(mask, (Tensor, SymbolVar)): | |||||
| raise TypeError("mask must be a tensor") | raise TypeError("mask must be a tensor") | ||||
| if mask.dtype != np.bool_: | if mask.dtype != np.bool_: | ||||
| raise ValueError("mask must be bool") | raise ValueError("mask must be bool") | ||||
| @@ -17,6 +17,7 @@ import numpy as np | |||||
| from ..core._imperative_rt import ComputingGraph | from ..core._imperative_rt import ComputingGraph | ||||
| from ..core._imperative_rt.core2 import SymbolVar | from ..core._imperative_rt.core2 import SymbolVar | ||||
| from ..core._trace_option import set_symbolic_shape as _set_symbolic_shape | |||||
| from ..core.tensor import megbrain_graph as G | from ..core.tensor import megbrain_graph as G | ||||
| from ..logger import get_logger | from ..logger import get_logger | ||||
| from .comp_graph_tools import get_dep_vars, get_opr_type, get_oprs_seq | from .comp_graph_tools import get_dep_vars, get_opr_type, get_oprs_seq | ||||
| @@ -182,8 +183,13 @@ class Network: | |||||
| """ | """ | ||||
| def _set_var_name(var): | |||||
| graph_var = G.VarNode(var.var) | |||||
| graph_var.name = var.name | |||||
| return graph_var | |||||
| self._compile() | self._compile() | ||||
| out = [G.VarNode(var.var) for var in self.output_vars] | |||||
| out = list(map(_set_var_name, self.output_vars)) | |||||
| if kwargs.pop("arg_names", False): | if kwargs.pop("arg_names", False): | ||||
| logger.warning( | logger.warning( | ||||
| @@ -231,15 +237,20 @@ class Network: | |||||
| if not all([var.owner for var in vars]): | if not all([var.owner for var in vars]): | ||||
| self.add_dep_oprs(*vars) | self.add_dep_oprs(*vars) | ||||
| for var in vars: | for var in vars: | ||||
| if var not in self.output_vars: | |||||
| # test identity with 'is' instead of membership with 'in' to avoid | |||||
| # comparing VarNodes with elementwise equality | |||||
| if not any(var is _ for _ in self.output_vars): | |||||
| self.output_vars.append(var) | self.output_vars.append(var) | ||||
| def remove_output(self, *vars: VarNode): | def remove_output(self, *vars: VarNode): | ||||
| """Removes vars from the network output node list. | """Removes vars from the network output node list. | ||||
| """ | """ | ||||
| for var in vars: | for var in vars: | ||||
| if var in self.output_vars: | |||||
| self.output_vars.remove(var) | |||||
| # use list.pop instead of list.remove to avoid | |||||
| # comparing VarNodes with elementwise equality | |||||
| for idx, out_var in enumerate(self.output_vars): | |||||
| if var is out_var: | |||||
| self.output_vars.pop(idx) | |||||
| def add_dep_oprs(self, *vars): | def add_dep_oprs(self, *vars): | ||||
| if len(vars) == 0: | if len(vars) == 0: | ||||
| @@ -434,6 +445,15 @@ class Network: | |||||
| opnode.add_out_var(self._get_var(var)) | opnode.add_out_var(self._get_var(var)) | ||||
| return opnode | return opnode | ||||
| else: | else: | ||||
| # overwrite the opnode's 'new' output VarNodes with the | |||||
| # original ones when the opr has more than one output; | |||||
| # otherwise this causes a dependency issue in the _compile step. | |||||
| if len(opr.outputs) > 1: | |||||
| opnode = self.all_oprs_map[opr.id] | |||||
| for idx, output in enumerate(opnode.outputs): | |||||
| if output.var.id in self.all_vars_map: | |||||
| opnode.outputs[idx] = self.all_vars_map[output.var.id] | |||||
| return None | return None | ||||
| def _get_opr(self, x): | def _get_opr(self, x): | ||||
| @@ -449,6 +469,15 @@ class Network: | |||||
| return self.all_vars_map[x.id] | return self.all_vars_map[x.id] | ||||
| def set_symbolic_shape(option: bool): | |||||
| """ | |||||
| Set whether VarNode uses symbolic shape, and return the previous status. | |||||
| Set it to True, and restore the previous status after dump, if you want to change the input batch size. | |||||
| :param option: True for enable symbolic shape. | |||||
| """ | |||||
| return _set_symbolic_shape(option) | |||||
| def as_varnode(obj): | def as_varnode(obj): | ||||
| """convert a :class:`.VarNode` compatible object to :class:`.VarNode`. | """convert a :class:`.VarNode` compatible object to :class:`.VarNode`. | ||||
| @@ -14,7 +14,8 @@ from typing import Callable, Sequence | |||||
| import numpy as np | import numpy as np | ||||
| from ..core import _imperative_rt as rt | from ..core import _imperative_rt as rt | ||||
| from ..core._imperative_rt.core2 import SymbolVar | |||||
| from ..core._imperative_rt.core2 import SymbolVar, apply | |||||
| from ..core._trace_option import use_symbolic_shape | |||||
| from ..core._wrap import Device | from ..core._wrap import Device | ||||
| from ..core.ops import builtin | from ..core.ops import builtin | ||||
| from ..core.tensor.array_method import ArrayMethodMixin | from ..core.tensor.array_method import ArrayMethodMixin | ||||
| @@ -53,15 +54,41 @@ class VarNode(NetworkNode, SymbolVar, ArrayMethodMixin, metaclass=VarNodeMeta): | |||||
| obj.owner = owner_opr | obj.owner = owner_opr | ||||
| return obj | return obj | ||||
| def _get_var_shape(self, axis=None): | |||||
| opdef = ( | |||||
| builtin.GetVarShape() if axis is None else builtin.GetVarShape(axis=axis) | |||||
| ) | |||||
| return apply(opdef, self)[0] | |||||
| @property | |||||
| def partial_shape(self): | |||||
| Return the inferred shape of the VarNode as a tuple | |||||
| """ | |||||
| return tuple(self._get_var_shape().numpy()) | |||||
| def shapeof(self, axis): | |||||
| Return the symbolic shape along the given axis | |||||
| """ | |||||
| return self._get_var_shape(axis=axis) if self.var else None | |||||
| @property | |||||
| def _tuple_shape(self): | |||||
| return self.partial_shape | |||||
| @property | @property | ||||
| def shape(self): | def shape(self): | ||||
| Return the symbolic shape if set_symbolic_shape(True) is in effect, | |||||
| otherwise the inferred shape | |||||
| """ | |||||
| rst = None | rst = None | ||||
| if self.var: | if self.var: | ||||
| try: | try: | ||||
| rst = self.var.shape | rst = self.var.shape | ||||
| except: | except: | ||||
| rst = None | rst = None | ||||
| return rst | |||||
| if not use_symbolic_shape(): | |||||
| return rst | |||||
| return self._get_var_shape() if self.var else None | |||||
| @property | @property | ||||
| def dtype(self): | def dtype(self): | ||||
| @@ -78,10 +105,6 @@ class VarNode(NetworkNode, SymbolVar, ArrayMethodMixin, metaclass=VarNodeMeta): | |||||
| def __hash__(self): | def __hash__(self): | ||||
| return id(self) | return id(self) | ||||
| @property | |||||
| def _tuple_shape(self): | |||||
| return self.var.shape | |||||
| def numpy(self): | def numpy(self): | ||||
| o = OutputNode(self.var) | o = OutputNode(self.var) | ||||
| self.graph.compile(o.outputs).execute() | self.graph.compile(o.outputs).execute() | ||||
| @@ -19,7 +19,7 @@ from megengine.core._trace_option import use_symbolic_shape | |||||
| from megengine.core.tensor import megbrain_graph as G | from megengine.core.tensor import megbrain_graph as G | ||||
| from megengine.core.tensor.utils import astensor1d | from megengine.core.tensor.utils import astensor1d | ||||
| from megengine.distributed.helper import get_device_count_by_fork | from megengine.distributed.helper import get_device_count_by_fork | ||||
| from megengine.utils.network import Network | |||||
| from megengine.utils.network import Network, set_symbolic_shape | |||||
| from megengine.utils.network_node import VarNode | from megengine.utils.network_node import VarNode | ||||
| @@ -62,6 +62,22 @@ def test_concat(is_varnode): | |||||
| opr_test(cases, run, ref_fn=lambda x, y: np.concatenate([x, y]), network=network) | opr_test(cases, run, ref_fn=lambda x, y: np.concatenate([x, y]), network=network) | ||||
| @pytest.mark.parametrize("is_varnode", [True, False]) | |||||
| def test_condtake(is_varnode): | |||||
| if is_varnode: | |||||
| network = Network() | |||||
| else: | |||||
| network = None | |||||
| x = np.array([[1, 2, 3], [4, 5, 6]]).astype("float32") | |||||
| y = np.array([[True, False, True], [False, True, True]]) | |||||
| xx = make_tensor(x, network) | |||||
| yy = make_tensor(y, network) | |||||
| val, idx = F.cond_take(yy, xx) | |||||
| np.testing.assert_equal(val.numpy(), x[y]) | |||||
| np.testing.assert_equal(idx.numpy(), np.where(y.reshape(-1))[0]) | |||||
| @pytest.mark.parametrize("is_varnode", [True, False]) | @pytest.mark.parametrize("is_varnode", [True, False]) | ||||
| def test_concat_device(is_varnode): | def test_concat_device(is_varnode): | ||||
| if is_varnode: | if is_varnode: | ||||
| @@ -102,6 +118,7 @@ def test_stack(is_varnode): | |||||
| def test_split(is_varnode): | def test_split(is_varnode): | ||||
| if is_varnode: | if is_varnode: | ||||
| network = Network() | network = Network() | ||||
| saved_symbolic_shape = set_symbolic_shape(False) | |||||
| else: | else: | ||||
| network = None | network = None | ||||
| @@ -134,6 +151,9 @@ def test_split(is_varnode): | |||||
| except ValueError as e: | except ValueError as e: | ||||
| assert str(e) == "Invalid nsplits_or_secions: [3, 3, 5]" | assert str(e) == "Invalid nsplits_or_secions: [3, 3, 5]" | ||||
| if is_varnode: | |||||
| set_symbolic_shape(saved_symbolic_shape) | |||||
| @pytest.mark.parametrize("is_varnode", [True, False]) | @pytest.mark.parametrize("is_varnode", [True, False]) | ||||
| def test_reshape(is_varnode): | def test_reshape(is_varnode): | ||||
| @@ -161,6 +181,7 @@ def test_reshape(is_varnode): | |||||
| def test_reshape_shape_inference(is_varnode): | def test_reshape_shape_inference(is_varnode): | ||||
| if is_varnode: | if is_varnode: | ||||
| network = Network() | network = Network() | ||||
| saved_symbolic_shape = set_symbolic_shape(False) | |||||
| else: | else: | ||||
| network = None | network = None | ||||
| @@ -192,12 +213,15 @@ def test_reshape_shape_inference(is_varnode): | |||||
| {"input": [x_shape_unknown, tshp_known_unspec], "output": [(2, 2),]}, | {"input": [x_shape_unknown, tshp_known_unspec], "output": [(2, 2),]}, | ||||
| ] | ] | ||||
| opr_test(cases, func, compare_fn=check_shape, test_trace=True, network=network) | opr_test(cases, func, compare_fn=check_shape, test_trace=True, network=network) | ||||
| if is_varnode: | |||||
| set_symbolic_shape(saved_symbolic_shape) | |||||
| @pytest.mark.parametrize("is_varnode", [True, False]) | @pytest.mark.parametrize("is_varnode", [True, False]) | ||||
| def test_squeeze(is_varnode): | def test_squeeze(is_varnode): | ||||
| if is_varnode: | if is_varnode: | ||||
| network = Network() | network = Network() | ||||
| saved_symbolic_shape = set_symbolic_shape(False) | |||||
| else: | else: | ||||
| network = None | network = None | ||||
| @@ -209,6 +233,9 @@ def test_squeeze(is_varnode): | |||||
| yy = F.squeeze(xx, axis) | yy = F.squeeze(xx, axis) | ||||
| np.testing.assert_equal(y, yy.numpy()) | np.testing.assert_equal(y, yy.numpy()) | ||||
| if is_varnode: | |||||
| set_symbolic_shape(saved_symbolic_shape) | |||||
| @pytest.mark.parametrize("is_varnode", [True, False]) | @pytest.mark.parametrize("is_varnode", [True, False]) | ||||
| def test_expand_dims(is_varnode): | def test_expand_dims(is_varnode): | ||||
| @@ -358,7 +385,7 @@ def test_flatten(is_varnode): | |||||
| data1 = np.random.random(data1_shape).astype(np.float32) | data1 = np.random.random(data1_shape).astype(np.float32) | ||||
| def compare_fn(x, y): | def compare_fn(x, y): | ||||
| assert x.shape[0] == y | |||||
| assert x._tuple_shape[0] == y | |||||
| output0 = (2 * 3 * 4 * 5,) | output0 = (2 * 3 * 4 * 5,) | ||||
| output1 = (4 * 5 * 6 * 7,) | output1 = (4 * 5 * 6 * 7,) | ||||
| @@ -420,7 +447,7 @@ def test_broadcast(is_varnode): | |||||
| data3 = np.random.random(input3_shape).astype(np.float32) | data3 = np.random.random(input3_shape).astype(np.float32) | ||||
| def compare_fn(x, y): | def compare_fn(x, y): | ||||
| assert x.shape[0] == y | |||||
| assert x._tuple_shape[0] == y | |||||
| cases = [ | cases = [ | ||||
| {"input": [data1, output1_shape], "output": output1_shape}, | {"input": [data1, output1_shape], "output": output1_shape}, | ||||
| @@ -10,7 +10,7 @@ from megengine.jit.tracing import trace | |||||
| from megengine.tensor import Tensor | from megengine.tensor import Tensor | ||||
| from megengine.utils.comp_graph_tools import GraphInference | from megengine.utils.comp_graph_tools import GraphInference | ||||
| from megengine.utils.network import Network as Net | from megengine.utils.network import Network as Net | ||||
| from megengine.utils.network import as_oprnode | |||||
| from megengine.utils.network import as_oprnode, set_symbolic_shape | |||||
| from megengine.utils.network_node import Host2DeviceCopy, VarNode | from megengine.utils.network_node import Host2DeviceCopy, VarNode | ||||
| @@ -181,19 +181,22 @@ def test_add_input(): | |||||
| np.testing.assert_equal(out["o1"], ((a + b) * 2 + a).numpy()) | np.testing.assert_equal(out["o1"], ((a + b) * 2 + a).numpy()) | ||||
| def test_add_output(): | |||||
| def test_add_remove_output(): | |||||
| a = Tensor([1.0, 2.0]) | a = Tensor([1.0, 2.0]) | ||||
| b = Tensor([3.0, 4.0]) | b = Tensor([3.0, 4.0]) | ||||
| @trace(symbolic=True, capture_as_const=True) | @trace(symbolic=True, capture_as_const=True) | ||||
| def fwd(a, b): | def fwd(a, b): | ||||
| return (a + b) * 2 | |||||
| return (a + b) * 2, (a - b) | |||||
| fwd(a, b) | fwd(a, b) | ||||
| orig_model = io.BytesIO() | orig_model = io.BytesIO() | ||||
| fwd.dump( | fwd.dump( | ||||
| orig_model, arg_names=["a", "b"], output_names="o", optimize_for_inference=False | |||||
| orig_model, | |||||
| arg_names=["a", "b"], | |||||
| output_names=["o1", "o2"], | |||||
| optimize_for_inference=False, | |||||
| ) | ) | ||||
| orig_model.seek(0) | orig_model.seek(0) | ||||
| @@ -201,11 +204,13 @@ def test_add_output(): | |||||
| var_a = net.var_filter.name("a").as_unique() | var_a = net.var_filter.name("a").as_unique() | ||||
| var_b = net.var_filter.name("b").as_unique() | var_b = net.var_filter.name("b").as_unique() | ||||
| y = F.add(var_a, var_b) | |||||
| y = F.sigmoid(y) | |||||
| y1 = (var_a + var_b) * 3 | |||||
| y2 = F.sigmoid(var_a + var_b) | |||||
| y.name = "o1" | |||||
| net.add_output(y) | |||||
| net.remove_output(*net.output_vars) | |||||
| y1.name = "new_o1" | |||||
| y2.name = "new_o2" | |||||
| net.add_output(y1, y2) | |||||
| modified_model = io.BytesIO() | modified_model = io.BytesIO() | ||||
| net.dump(modified_model) | net.dump(modified_model) | ||||
| @@ -214,8 +219,8 @@ def test_add_output(): | |||||
| g = GraphInference(modified_model) | g = GraphInference(modified_model) | ||||
| out = g.run(a.numpy(), b.numpy()) | out = g.run(a.numpy(), b.numpy()) | ||||
| np.testing.assert_equal(out["o"], ((a + b) * 2).numpy()) | |||||
| np.testing.assert_equal(out["o1"], (F.sigmoid((a + b))).numpy()) | |||||
| np.testing.assert_equal(out["new_o1"], ((a + b) * 3).numpy()) | |||||
| np.testing.assert_equal(out["new_o2"], (F.sigmoid((a + b))).numpy()) | |||||
| def test_query(): | def test_query(): | ||||
| @@ -343,3 +348,68 @@ def test_modify_opr_name(): | |||||
| net1 = Net.load(modified_model) | net1 = Net.load(modified_model) | ||||
| assert net1.data_providers_filter.as_unique().name == "net1.net.a" | assert net1.data_providers_filter.as_unique().name == "net1.net.a" | ||||
| def test_dump_cond_take(): | |||||
| a = Tensor([1.0, 2.0]) | |||||
| @trace(symbolic=True, capture_as_const=True) | |||||
| def fwd(a): | |||||
| return F.cond_take(a > 1, a) | |||||
| fwd(a) | |||||
| orig_model = io.BytesIO() | |||||
| fwd.dump( | |||||
| orig_model, | |||||
| arg_names=["a"], | |||||
| output_names=["o1", "o2"], | |||||
| optimize_for_inference=False, | |||||
| ) | |||||
| orig_model.seek(0) | |||||
| net = Net.load(orig_model) | |||||
| var_a = net.input_vars[0] | |||||
| val, idx = F.cond_take(var_a > 1, var_a) | |||||
| net.remove_output(*net.output_vars) | |||||
| val.name = "value" | |||||
| idx.name = "index" | |||||
| net.add_output(val, idx) | |||||
| modified_model = io.BytesIO() | |||||
| net.dump(modified_model) | |||||
| modified_model.seek(0) | |||||
| g = GraphInference(modified_model) | |||||
| out = g.run(a.numpy()) | |||||
| data = a.numpy() | |||||
| mask = a.numpy() > 1 | |||||
| np.testing.assert_equal(out["index"], np.where(mask.reshape(-1))[0]) | |||||
| np.testing.assert_equal(out["value"], data[mask]) | |||||
| def test_set_symbolic_shape(): | |||||
| a = Tensor([1.0, 2.0]) | |||||
| @trace(symbolic=True, capture_as_const=True) | |||||
| def fwd(a): | |||||
| return F.relu(a * 2) | |||||
| fwd(a) | |||||
| orig_model = io.BytesIO() | |||||
| fwd.dump( | |||||
| orig_model, arg_names=["a"], output_names=["o"], optimize_for_inference=False, | |||||
| ) | |||||
| orig_model.seek(0) | |||||
| net = Net.load(orig_model) | |||||
| var_a = net.input_vars[0] | |||||
| saved_symbolic_shape = set_symbolic_shape(True) | |||||
| assert isinstance(var_a.shape, VarNode) | |||||
| set_symbolic_shape(False) | |||||
| assert var_a.shape == var_a.partial_shape | |||||
| set_symbolic_shape(saved_symbolic_shape) | |||||