|
|
@@ -12,14 +12,16 @@ import itertools |
|
|
import pickle |
|
|
import pickle |
|
|
import re |
|
|
import re |
|
|
from collections import OrderedDict |
|
|
from collections import OrderedDict |
|
|
from typing import Any, Dict, List, Sequence |
|
|
|
|
|
|
|
|
from typing import Any, Dict, List, Optional, Sequence |
|
|
|
|
|
|
|
|
|
|
|
from ..core import _imperative_rt |
|
|
from ..core._imperative_rt import ComputingGraph, SerializationMetadata |
|
|
from ..core._imperative_rt import ComputingGraph, SerializationMetadata |
|
|
from ..core._trace_option import set_symbolic_shape as _set_symbolic_shape |
|
|
from ..core._trace_option import set_symbolic_shape as _set_symbolic_shape |
|
|
from ..core.tensor import megbrain_graph as G |
|
|
from ..core.tensor import megbrain_graph as G |
|
|
from ..logger import get_logger |
|
|
from ..logger import get_logger |
|
|
from .comp_graph_tools import get_dep_vars, get_opr_type, get_oprs_seq |
|
|
from .comp_graph_tools import get_dep_vars, get_opr_type, get_oprs_seq |
|
|
from .network_node import ( |
|
|
from .network_node import ( |
|
|
|
|
|
ConstOpBase, |
|
|
Host2DeviceCopy, |
|
|
Host2DeviceCopy, |
|
|
ImmutableTensor, |
|
|
ImmutableTensor, |
|
|
NetworkNode, |
|
|
NetworkNode, |
|
|
@@ -37,8 +39,10 @@ class Network: |
|
|
self._orig_inputs = [] |
|
|
self._orig_inputs = [] |
|
|
self.output_vars = [] # output var of graph |
|
|
self.output_vars = [] # output var of graph |
|
|
self._orig_outputs = [] |
|
|
self._orig_outputs = [] |
|
|
self.all_oprs_map = OrderedDict() |
|
|
|
|
|
self.all_vars_map = OrderedDict() |
|
|
|
|
|
|
|
|
self.all_oprs_map = OrderedDict() # _imperative_rt.graph.VarNode.id: VarNode |
|
|
|
|
|
self.all_vars_map = ( |
|
|
|
|
|
OrderedDict() |
|
|
|
|
|
) # _imperative_rt.graph.OperatorNode.id: OpNode |
|
|
self.graph = ComputingGraph() |
|
|
self.graph = ComputingGraph() |
|
|
self._metadata = None |
|
|
self._metadata = None |
|
|
|
|
|
|
|
|
@@ -101,7 +105,7 @@ class Network: |
|
|
self.all_oprs_map = {} |
|
|
self.all_oprs_map = {} |
|
|
self.all_vars_map = {} |
|
|
self.all_vars_map = {} |
|
|
for opr in self.all_oprs: |
|
|
for opr in self.all_oprs: |
|
|
if isinstance(opr, (ImmutableTensor, Host2DeviceCopy)): |
|
|
|
|
|
|
|
|
if isinstance(opr, (ConstOpBase, Host2DeviceCopy)): |
|
|
opr.compile(self.graph) |
|
|
opr.compile(self.graph) |
|
|
else: |
|
|
else: |
|
|
opr.compile() |
|
|
opr.compile() |
|
|
@@ -295,6 +299,9 @@ class Network: |
|
|
def add_dep_oprs(self, *vars): |
|
|
def add_dep_oprs(self, *vars): |
|
|
if len(vars) == 0: |
|
|
if len(vars) == 0: |
|
|
vars = self.output_vars |
|
|
vars = self.output_vars |
|
|
|
|
|
|
|
|
|
|
|
assert all(isinstance(var, VarNode) for var in vars), "Only support add VarNode" |
|
|
|
|
|
|
|
|
q = list(vars) |
|
|
q = list(vars) |
|
|
while len(q) > 0: |
|
|
while len(q) > 0: |
|
|
cur = q.pop(0) |
|
|
cur = q.pop(0) |
|
|
@@ -368,11 +375,14 @@ class Network: |
|
|
for var in self.all_vars: |
|
|
for var in self.all_vars: |
|
|
if var in repl_dict: |
|
|
if var in repl_dict: |
|
|
repl_var = repl_dict[var] |
|
|
repl_var = repl_dict[var] |
|
|
owner = repl_var.owner |
|
|
|
|
|
idx = owner.outputs.index(repl_var) |
|
|
|
|
|
owner.outputs[idx] = var |
|
|
|
|
|
var.__dict__.update(repl_var.__dict__) |
|
|
|
|
|
var.var = repl_var.var |
|
|
|
|
|
|
|
|
if repl_var is var: |
|
|
|
|
|
continue |
|
|
|
|
|
for opnode in var.users: |
|
|
|
|
|
assert var in opnode.inputs |
|
|
|
|
|
opnode.inputs = [repl_var if var is i else i for i in opnode.inputs] |
|
|
|
|
|
if opnode not in repl_var.users: |
|
|
|
|
|
repl_var.users.append(opnode) |
|
|
|
|
|
var.users.clear() |
|
|
self._compile() |
|
|
self._compile() |
|
|
|
|
|
|
|
|
def replace_oprs(self, repl_dict: Dict[OpNode, OpNode]): |
|
|
def replace_oprs(self, repl_dict: Dict[OpNode, OpNode]): |
|
|
@@ -473,14 +483,20 @@ class Network: |
|
|
def all_oprs_dict(self): |
|
|
def all_oprs_dict(self): |
|
|
return self.opr_filter.as_dict() |
|
|
return self.opr_filter.as_dict() |
|
|
|
|
|
|
|
|
# used for loading and building graph |
|
|
|
|
|
def _add_opr(self, opr): |
|
|
|
|
|
|
|
|
def _add_opr(self, opr) -> Optional[OpNode]: |
|
|
|
|
|
""" |
|
|
|
|
|
Used for loading and building graph. |
|
|
|
|
|
""" |
|
|
|
|
|
assert isinstance(opr, _imperative_rt.graph.OperatorNode) |
|
|
|
|
|
|
|
|
# TODO: use megbrain C++ RTTI to replace type string |
|
|
# TODO: use megbrain C++ RTTI to replace type string |
|
|
if opr.id not in self.all_oprs_map: |
|
|
if opr.id not in self.all_oprs_map: |
|
|
opnode = str_to_mge_class(get_opr_type(opr)).load(opr) |
|
|
opnode = str_to_mge_class(get_opr_type(opr)).load(opr) |
|
|
self.all_oprs_map[opr.id] = opnode |
|
|
self.all_oprs_map[opr.id] = opnode |
|
|
for var in opr.inputs: |
|
|
for var in opr.inputs: |
|
|
opnode.add_inp_var(self._get_var(var)) |
|
|
|
|
|
|
|
|
varnode = self._get_var(var) |
|
|
|
|
|
opnode.add_inp_var(varnode) |
|
|
|
|
|
varnode.users.append(opnode) |
|
|
for var in opr.outputs: |
|
|
for var in opr.outputs: |
|
|
opnode.add_out_var(self._get_var(var)) |
|
|
opnode.add_out_var(self._get_var(var)) |
|
|
return opnode |
|
|
return opnode |
|
|
@@ -503,7 +519,10 @@ class Network: |
|
|
return None |
|
|
return None |
|
|
|
|
|
|
|
|
def _get_var(self, x): |
|
|
def _get_var(self, x): |
|
|
# auto convert to VarNode of Network |
|
|
|
|
|
|
|
|
""" |
|
|
|
|
|
Convert :class:`~._imperative_rt.graph.VarNode` to :class:`~.VarNode`. |
|
|
|
|
|
""" |
|
|
|
|
|
assert isinstance(x, _imperative_rt.graph.VarNode) |
|
|
if x.id not in self.all_vars_map or self.all_vars_map[x.id].var != x: |
|
|
if x.id not in self.all_vars_map or self.all_vars_map[x.id].var != x: |
|
|
self.all_vars_map[x.id] = VarNode.load(x, self._get_opr(x.owner)) |
|
|
self.all_vars_map[x.id] = VarNode.load(x, self._get_opr(x.owner)) |
|
|
return self.all_vars_map[x.id] |
|
|
return self.all_vars_map[x.id] |
|
|
|