GitOrigin-RevId: 614302552c
tags/v1.5.0
@@ -519,8 +519,7 @@ def _unwrap(x):
         return type(x)(map(_unwrap, x))
     if isinstance(x, VarNode):
         return x._node
-    else:
-        return x
+    return x


 def apply_normal_varnode(op: OpDef, *args: VarNode):
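This hunk is a pure early-return refactor: once both `if` branches return, the trailing `else` is redundant. A minimal, framework-free sketch of the same shape, with a hypothetical `Wrapper` class standing in for `VarNode`:

```python
from collections.abc import Sequence

class Wrapper:
    """Stand-in for VarNode: holds the raw node in ._node."""
    def __init__(self, node):
        self._node = node

def _unwrap(x):
    # Recurse into containers, rebuilding the same container type.
    # (The str guard is only needed in this toy; the real helper sees
    # containers of VarNode, not strings.)
    if isinstance(x, Sequence) and not isinstance(x, str):
        return type(x)(map(_unwrap, x))
    if isinstance(x, Wrapper):
        return x._node
    return x  # plain values pass through unchanged

assert _unwrap([Wrapper(1), (Wrapper(2), 3)]) == [1, (2, 3)]
```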
@@ -12,14 +12,16 @@ import itertools
 import pickle
 import re
 from collections import OrderedDict
-from typing import Any, Dict, List, Sequence
+from typing import Any, Dict, List, Optional, Sequence

+from ..core import _imperative_rt
 from ..core._imperative_rt import ComputingGraph, SerializationMetadata
 from ..core._trace_option import set_symbolic_shape as _set_symbolic_shape
 from ..core.tensor import megbrain_graph as G
 from ..logger import get_logger
 from .comp_graph_tools import get_dep_vars, get_opr_type, get_oprs_seq
 from .network_node import (
+    ConstOpBase,
     Host2DeviceCopy,
     ImmutableTensor,
     NetworkNode,
@@ -37,8 +39,10 @@ class Network:
         self._orig_inputs = []
         self.output_vars = []  # output var of graph
         self._orig_outputs = []
-        self.all_oprs_map = OrderedDict()
-        self.all_vars_map = OrderedDict()
+        self.all_oprs_map = (
+            OrderedDict()
+        )  # _imperative_rt.graph.OperatorNode.id: OpNode
+        self.all_vars_map = OrderedDict()  # _imperative_rt.graph.VarNode.id: VarNode
         self.graph = ComputingGraph()
         self._metadata = None
@@ -101,7 +105,7 @@ class Network:
         self.all_oprs_map = {}
         self.all_vars_map = {}
         for opr in self.all_oprs:
-            if isinstance(opr, (ImmutableTensor, Host2DeviceCopy)):
+            if isinstance(opr, (ConstOpBase, Host2DeviceCopy)):
                 opr.compile(self.graph)
             else:
                 opr.compile()
@@ -295,6 +299,9 @@ class Network:
     def add_dep_oprs(self, *vars):
         if len(vars) == 0:
             vars = self.output_vars
+        assert all(
+            isinstance(var, VarNode) for var in vars
+        ), "Only support add VarNode"
         q = list(vars)
         while len(q) > 0:
             cur = q.pop(0)
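`add_dep_oprs` walks backwards from the requested vars with a FIFO worklist; the new assert rejects raw `_imperative_rt` vars up front instead of failing deep inside the traversal. A framework-free sketch of the same worklist pattern (`Node`/`collect_deps` are stand-ins, not the MegEngine API):

```python
class Node:
    """Toy graph node: .inputs lists the nodes it depends on."""
    def __init__(self, name, inputs=()):
        self.name, self.inputs = name, list(inputs)

def collect_deps(*roots):
    assert all(isinstance(n, Node) for n in roots), "Only support Node"
    seen, order = set(), []
    q = list(roots)
    while len(q) > 0:
        cur = q.pop(0)           # FIFO: breadth-first over producers
        if id(cur) in seen:
            continue
        seen.add(id(cur))
        order.append(cur)
        q.extend(cur.inputs)     # enqueue everything cur depends on
    return order

a = Node("a"); b = Node("b", [a]); c = Node("c", [a, b])
assert [n.name for n in collect_deps(c)] == ["c", "a", "b"]
```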
@@ -368,11 +375,14 @@ class Network:
         for var in self.all_vars:
             if var in repl_dict:
                 repl_var = repl_dict[var]
-                owner = repl_var.owner
-                idx = owner.outputs.index(repl_var)
-                owner.outputs[idx] = var
-                var.__dict__.update(repl_var.__dict__)
-                var.var = repl_var.var
+                if repl_var is var:
+                    continue
+                for opnode in var.users:
+                    assert var in opnode.inputs
+                    opnode.inputs = [repl_var if var is i else i for i in opnode.inputs]
+                    if opnode not in repl_var.users:
+                        repl_var.users.append(opnode)
+                var.users.clear()
         self._compile()

     def replace_oprs(self, repl_dict: Dict[OpNode, OpNode]):
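The rewritten `replace_vars` no longer mutates the replacement's owner or copies `__dict__` wholesale; it redirects every consumer recorded in `var.users` to the replacement and keeps both `users` lists consistent, which is what makes cross-network splicing (see the new test below) work. A self-contained model of that rewiring, with toy `Var`/`Op` classes rather than the real `VarNode`/`OpNode`:

```python
class Var:
    def __init__(self, name):
        self.name = name
        self.users = []           # ops that consume this var

class Op:
    def __init__(self, name, inputs):
        self.name, self.inputs = name, list(inputs)
        for v in inputs:
            v.users.append(self)  # mirrors _add_opr's users bookkeeping

def replace_vars(all_vars, repl_dict):
    for var in all_vars:
        if var in repl_dict:
            repl_var = repl_dict[var]
            if repl_var is var:
                continue
            for opnode in var.users:          # rewire each consumer
                assert var in opnode.inputs
                opnode.inputs = [repl_var if var is i else i for i in opnode.inputs]
                if opnode not in repl_var.users:
                    repl_var.users.append(opnode)
            var.users.clear()                 # the old var is now unused

x, y = Var("x"), Var("y")
add = Op("add", [x, x])
replace_vars([x, y], {x: y})
assert add.inputs == [y, y] and x.users == [] and add in y.users
```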
@@ -473,14 +483,20 @@ class Network:
     def all_oprs_dict(self):
         return self.opr_filter.as_dict()

-    # used for loading and building graph
-    def _add_opr(self, opr):
+    def _add_opr(self, opr) -> Optional[OpNode]:
+        """
+        Used for loading and building graph.
+        """
+        assert isinstance(opr, _imperative_rt.graph.OperatorNode)
+        # TODO: use megbrain C++ RTTI to replace type string
         if opr.id not in self.all_oprs_map:
             opnode = str_to_mge_class(get_opr_type(opr)).load(opr)
             self.all_oprs_map[opr.id] = opnode
             for var in opr.inputs:
-                opnode.add_inp_var(self._get_var(var))
+                varnode = self._get_var(var)
+                opnode.add_inp_var(varnode)
+                varnode.users.append(opnode)
             for var in opr.outputs:
                 opnode.add_out_var(self._get_var(var))
             return opnode
@@ -503,7 +519,10 @@ class Network:
             return None

     def _get_var(self, x):
-        # auto convert to VarNode of Network
+        """
+        Convert :class:`~._imperative_rt.graph.VarNode` to :class:`~.VarNode`.
+        """
+        assert isinstance(x, _imperative_rt.graph.VarNode)
         if x.id not in self.all_vars_map or self.all_vars_map[x.id].var != x:
             self.all_vars_map[x.id] = VarNode.load(x, self._get_opr(x.owner))
         return self.all_vars_map[x.id]
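`_get_var` is a memoized converter: one wrapper per raw var, keyed by `id`, with a staleness check so that a recompiled graph (whose raw `VarNode` objects have been replaced) gets fresh wrappers instead of stale cache hits. The same shape in isolation, using hypothetical `Raw`/`Wrapper` types in place of the real classes:

```python
class Raw:                 # stands in for _imperative_rt.graph.VarNode
    def __init__(self, id_):
        self.id = id_

class Wrapper:             # stands in for utils.network.VarNode
    def __init__(self, raw):
        self.var = raw

cache = {}                 # raw id -> Wrapper, like Network.all_vars_map

def get_var(x):
    assert isinstance(x, Raw)
    # rebuild if unseen OR if the cached entry points at a stale raw object
    if x.id not in cache or cache[x.id].var is not x:
        cache[x.id] = Wrapper(x)
    return cache[x.id]

r = Raw(7)
assert get_var(r) is get_var(r)   # memoized: same wrapper both times
r2 = Raw(7)                       # same id, new raw object (e.g. after recompile)
assert get_var(r2).var is r2      # stale entry was replaced
```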
@@ -37,6 +37,7 @@ class VarNodeMeta(type(SymbolVar), type(ArrayMethodMixin)):
 class VarNode(NetworkNode, SymbolVar, ArrayMethodMixin, metaclass=VarNodeMeta):
     def __init__(self, var=None, *, owner_opr=None, name=None):
         SymbolVar.__init__(self, var)
+        self.users = []  # List[OpNode]
         self.owner = owner_opr
         self.name = name
         self.id = id(self)
@@ -214,6 +215,7 @@ class Host2DeviceCopy(OpNode):
     def compile(self, graph):
         if (
             self._opr is None
+            or self._opr.graph != graph
             or self._opr.outputs[0].comp_node != self.device
             or self._opr.outputs[0].shape != self.shape
             or self._opr.outputs[0].dtype != self.dtype
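The added `or self._opr.graph != graph` makes `compile` rebuild the input placeholder when it is compiled into a different `ComputingGraph`, not only when device, shape, or dtype drift. The guard boils down to "rebuild unless every cached property still matches"; a sketch with a hypothetical `Placeholder` class:

```python
class Placeholder:
    def __init__(self, shape, dtype, device):
        self.shape, self.dtype, self.device = shape, dtype, device
        self._compiled = None   # (graph, shape, dtype, device) of the last build

    def compile(self, graph):
        if self._compiled != (graph, self.shape, self.dtype, self.device):
            # ...create the real graph node here...
            self._compiled = (graph, self.shape, self.dtype, self.device)
            return "rebuilt"
        return "reused"

p = Placeholder((2,), "float32", "cpu0")
assert p.compile("g1") == "rebuilt"
assert p.compile("g1") == "reused"
assert p.compile("g2") == "rebuilt"   # a new graph forces a rebuild, per the added check
```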
@@ -226,10 +228,11 @@ class Host2DeviceCopy(OpNode):
         assert self.outputs[0].owner is self


-class ImmutableTensor(OpNode):
-    type = "ImmutableTensor"
+class ConstOpBase(OpNode):
+    type = "ConstOpBase"

     def __init__(self, data=None, name=None, device=None, graph=None):
+        assert type(self) is not ConstOpBase, "ConstOpBase cannot be instantiated"
         super().__init__()
         self.name = name
         self.outputs = []
@@ -254,7 +257,7 @@ class ImmutableTensor(OpNode):
         return self._opr.outputs[0].dtype if self._opr else None

     def numpy(self):
-        return self._opr.outputs[0].value if self._opr else None
+        return self.outputs[0].numpy()

     def set_value(self, data, device=None):
         assert self.graph is not None
@@ -266,7 +269,7 @@ class ImmutableTensor(OpNode):
             data = data.astype(np.float32)
         elif data.dtype == np.int64:
             data = data.astype(np.int32)
-        varnode = rt.make_const(self.graph, data, cn, data.dtype, self.name)
+        varnode = type(self).rt_fun(self.graph, data, cn, data.dtype, self.name)
         if len(self.outputs) == 0:
             self.outputs.append(VarNode(owner_opr=self, name=self.name))
         self.outputs[0].var = varnode
@@ -291,6 +294,16 @@ class ImmutableTensor(OpNode):
         self.outputs[0].var.name = self.name


+class ImmutableTensor(ConstOpBase):
+    type = "ImmutableTensor"
+    rt_fun = rt.make_const
+
+
+class SharedDeviceTensor(ConstOpBase):
+    type = "SharedDeviceTensor"
+    rt_fun = rt.make_shared
+
+
 class ReadOnlyOpNode(OpNode):
     @classmethod
     def load(cls, opr):
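`ConstOpBase` keeps the shared const-op machinery in one place and lets each subclass pick its runtime constructor through the `rt_fun` class attribute; `type(self).rt_fun(...)` in `set_value` then dispatches to `rt.make_const` or `rt.make_shared`, and the assert in `__init__` makes the base effectively abstract. The pattern in miniature, with fake `make_*` functions standing in for the `rt` bindings:

```python
def make_const(data):            # stand-in for rt.make_const
    return ("const", data)

def make_shared(data):           # stand-in for rt.make_shared
    return ("shared", data)

class ConstOpBase:
    rt_fun = None

    def __init__(self, data):
        assert type(self) is not ConstOpBase, "ConstOpBase cannot be instantiated"
        # class-attribute lookup dispatches to the subclass's builder
        self.varnode = type(self).rt_fun(data)

class ImmutableTensor(ConstOpBase):
    # staticmethod is needed here because these are plain Python functions;
    # pybind11 functions like rt.make_const never bind self, so the real
    # code can assign them directly.
    rt_fun = staticmethod(make_const)

class SharedDeviceTensor(ConstOpBase):
    rt_fun = staticmethod(make_shared)

assert ImmutableTensor(1).varnode == ("const", 1)
assert SharedDeviceTensor(2).varnode == ("shared", 2)
```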
@@ -130,6 +130,52 @@ def test_replace_opr():
     np.testing.assert_equal(out["o"], [0, 0])


+def test_splice_network():
+    x = F.ones((2,))
+    y = F.ones((2,))
+
+    @trace(symbolic=True, capture_as_const=True)
+    def fun1(a, b):
+        return (a + b) * 2
+
+    @trace(symbolic=True, capture_as_const=True)
+    def fun2(a):
+        return a * 2 - 1
+
+    model = io.BytesIO()
+    fun1(x, y)
+    fun2(x)
+    fun1.dump(
+        model,
+        arg_names=["net1_i0", "net1_i1"],
+        output_names=["net1_o0"],
+        optimize_for_inference=False,
+    )
+    model.seek(0)
+    net1 = Net.load(model)
+    model.seek(0)
+    fun2.dump(
+        model,
+        arg_names=["net2_i0"],
+        output_names=["net2_o0"],
+        optimize_for_inference=False,
+    )
+    model.seek(0)
+    net2 = Net.load(model)
+
+    net1.add_output(*net2.output_vars)
+    var = net1.var_filter.name("net1_i0").as_unique()
+    repl_var = net2.var_filter.name("net2_o0").as_unique()
+    net1.replace_vars({var: repl_var})
+    assert "net1_i0" not in [var.name for var in net1.all_vars]
+    assert "net2_i0" in [var.name for var in net1.all_vars]
+
+    model.seek(0)
+    net1.dump(model, keep_var_name=2, optimize_for_inference=False)
+    model.seek(0)
+    net = Net.load(model)
+    assert "net1_i0" not in [var.name for var in net.all_vars]
+    assert "net2_i0" in [var.name for var in net.all_vars]
+
+
 def test_modify_params():
     a = Tensor([1, 2])