GitOrigin-RevId: 47a26dd6dd
tags/v1.5.0
| @@ -893,6 +893,10 @@ class trace: | |||
| if isinstance(file, str): | |||
| permission = "wb" if append == False else "ab" | |||
| file = open(file, permission) | |||
| if keep_opr_priority: | |||
| graph._set_priority_to_id(dest_vars) | |||
| dump_content, dump_info = G.dump_graph( | |||
| dest_vars, | |||
| keep_var_name=keep_var_name, | |||
| @@ -6,6 +6,7 @@ | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| import collections | |||
| import heapq | |||
| from collections import OrderedDict | |||
| from typing import Dict, List, Tuple, Union | |||
| @@ -88,6 +89,41 @@ def get_opr_type(opr: _OpNode) -> str: | |||
| return opr.type | |||
class _OprStableOrderHeapq:
    """Min-heap of operators that pops them in a stable, reproducible order.

    Every operator is stored as the tuple
    ``extra_priority(opr) + (opr.id, opr.name, opr)``, so entries are ordered
    by the caller-supplied priority first and by ``(id, name)`` afterwards.
    Two distinct operators sharing the same ``(id, name)`` pair are rejected,
    which keeps the tuple comparison from ever reaching the operator object.
    """

    _list = None
    _extra_priority = None
    _used_id_name_pairs = None

    def __init__(self, extra_priority):
        """Create an empty heap.

        :param extra_priority: callable mapping an operator to a tuple that is
            prepended to the ``(id, name, opr)`` sort key.
        """
        # ``collections.Callable`` was removed in Python 3.10 (the ABC lives in
        # ``collections.abc``); the ``callable`` builtin works on every version.
        assert callable(extra_priority)
        self._list = []
        self._extra_priority = extra_priority
        self._used_id_name_pairs = {}

    def pop_min(self):
        """Remove and return the operator with the smallest sort key."""
        # the operator itself is always the last element of the heap entry
        return heapq.heappop(self._list)[-1]

    def add(self, opr):
        """Push ``opr`` onto the heap.

        :raises RuntimeError: if a different operator with the same
            ``(id, name)`` pair has already been added.
        """
        # named as add to mimic set() interface
        id_ = opr.id
        name = opr.name
        other = self._used_id_name_pairs.setdefault((id_, name), opr)
        if other is not opr:
            raise RuntimeError(
                "duplicated (id, name) pair: opr0={} opr1={}".format(other, opr)
            )
        item = self._extra_priority(opr) + (id_, name, opr)
        heapq.heappush(self._list, item)

    def __bool__(self):
        """Return True while at least one operator is still queued."""
        return bool(self._list)
| def graph_traversal(outputs: _VarNode): | |||
| """ | |||
| Helper function to traverse the computing graph and return enough useful information. | |||
| @@ -110,12 +146,13 @@ def graph_traversal(outputs: _VarNode): | |||
| var2oprs = collections.defaultdict(list) | |||
| opr2receivers = collections.defaultdict(list) | |||
| queue = list(set(map(lambda x: x.owner, outputs))) | |||
| queue = [] | |||
| [queue.append(o) for o in [x.owner for x in outputs] if o not in queue] | |||
| visited = set(map(lambda x: x.id, queue)) | |||
| # iterate through whole comp_graph, fill in meta information | |||
| indegree2opr = collections.defaultdict(set) | |||
| indegree2opr[0] = _OprStableOrderHeapq(lambda op: (op.priority,)) | |||
| opr2indegree = {} | |||
| idx = 0 | |||
| @@ -138,8 +175,8 @@ def graph_traversal(outputs: _VarNode): | |||
| indegree += 1 | |||
| opr2receivers[pre_opr.id].append(cur_opr.id) | |||
| indegree2opr[indegree].add(cur_opr.id) | |||
| opr = cur_opr if indegree == 0 else cur_opr.id | |||
| indegree2opr[indegree].add(opr) | |||
| opr2indegree[cur_opr.id] = indegree | |||
| return map_oprs, map_vars, var2oprs, opr2receivers, indegree2opr, opr2indegree | |||
| @@ -162,8 +199,8 @@ def get_oprs_seq( | |||
| oprs_seq = [] | |||
| nr_remain = len(map_oprs) | |||
| while indegree2opr[0]: | |||
| opr_id = indegree2opr[0].pop() | |||
| opr = map_oprs[opr_id] | |||
| opr = indegree2opr[0].pop_min() | |||
| opr_id = opr.id | |||
| nr_remain -= 1 | |||
| if opr.type != "ImmutableTensor" or not prune_immtensor: | |||
| oprs_seq.append(opr) | |||
| @@ -173,7 +210,10 @@ def get_oprs_seq( | |||
| indegree2opr[indegree].remove(post_id) | |||
| indegree -= 1 | |||
| indegree2opr[indegree].add(post_id) | |||
| if indegree == 0: | |||
| indegree2opr[indegree].add(map_oprs[post_id]) | |||
| else: | |||
| indegree2opr[indegree].add(post_id) | |||
| opr2indegree[post_id] = indegree | |||
| assert nr_remain == 0, "there are {} remaining nodes; cyclic graph?".format( | |||
| @@ -213,10 +253,34 @@ def get_oprs_seq( | |||
| # filter out all marked oprs | |||
| return list(filter(lambda x: x.id not in marked_opr_ids, oprs_seq)) | |||
| # adjust the order of oprs so that param/data provider oprs are placed close to the oprs which use them as inputs. | |||
def reorder_oprs_seq(oprs):
    """Return ``oprs`` with provider oprs moved next to their first consumer.

    Oprs whose ``type`` is ``"ImmutableTensor"`` or ``"Host2DeviceCopy"`` are
    pulled out of the sequence and re-inserted immediately before the first
    remaining opr that lists one of their outputs as an input. Providers that
    no remaining opr consumes are appended at the end in their original
    relative order. The relative order of all other oprs is unchanged.
    """
    provider_types = ("ImmutableTensor", "Host2DeviceCopy")

    # split the incoming sequence, keeping the original relative order
    pending_providers = []
    consumers = []
    for candidate in oprs:
        bucket = pending_providers if candidate.type in provider_types else consumers
        bucket.append(candidate)

    ordered = []
    for consumer in consumers:
        for var in consumer.inputs:
            producer = var.owner
            if producer.type not in provider_types:
                continue
            # emit each provider only once, right before its first consumer
            if producer in pending_providers:
                ordered.append(producer)
                pending_providers.remove(producer)
        ordered.append(consumer)

    # providers that were never consumed go last
    ordered.extend(pending_providers)
    assert len(ordered) == len(oprs)
    return ordered
| map_oprs, _, var2oprs, opr2receivers, indegree2opr, opr2indegree = graph_traversal( | |||
| outputs | |||
| ) | |||
| oprs_seq = topological_sort(map_oprs, opr2receivers, indegree2opr, opr2indegree) | |||
| oprs_seq = reorder_oprs_seq(oprs_seq) | |||
| if prune_reshape is True: | |||
| oprs_seq = prune_reshape_oprs(outputs, oprs_seq, var2oprs.copy()) | |||
| return oprs_seq | |||
| @@ -241,6 +241,7 @@ class Network: | |||
| if optimize_for_inference: | |||
| metadata.optimize_options = optimize_options | |||
| G.set_priority_to_id([o._node if isinstance(o, G.VarNode) else o for o in out]) | |||
| dump_content, _ = G.dump_graph( | |||
| out, | |||
| keep_var_name=keep_var_name, | |||
| @@ -353,7 +354,7 @@ class Network: | |||
| ) | |||
| shp[0] = batchsize | |||
| i.shape = tuple(shp) | |||
| self._compile() | |||
| assert prev_batchsize is not None, "no data provider found" | |||
| assert not blacklist, "unused items in blacklist: {}".format(blacklist) | |||
| @@ -363,7 +364,6 @@ class Network: | |||
| :param repl_dict: the map {old_var: new_var} that specifies how to replace the vars. | |||
| """ | |||
| if not all([var.owner for var in repl_dict.values()]): | |||
| print(repl_dict.values()) | |||
| self.add_dep_oprs(*list(repl_dict.values())) | |||
| for var in self.all_vars: | |||
| if var in repl_dict: | |||
| @@ -373,6 +373,7 @@ class Network: | |||
| owner.outputs[idx] = var | |||
| var.__dict__.update(repl_var.__dict__) | |||
| var.var = repl_var.var | |||
| self._compile() | |||
| def replace_oprs(self, repl_dict: Dict[OpNode, OpNode]): | |||
| """ | |||
| @@ -384,11 +385,11 @@ class Network: | |||
| assert len(opr.outputs) == len( | |||
| repl_dict[opr].outputs | |||
| ), "can not replace {} with {}".format(type(opr), type(repl_dict[opr])) | |||
| repl_dict[opr].outputs = opr.outputs | |||
| for ind, var in enumerate(opr.outputs): | |||
| var.owner = repl_dict[opr] | |||
| var.__dict__.update(repl_dict[opr].outputs[ind].__dict__) | |||
| var.var = repl_dict[opr].outputs[ind].var | |||
| self._compile() | |||
| def get_opr_by_type(self, oprcls, unique=True): | |||
| assert issubclass(oprcls, OpNode) | |||
| @@ -90,6 +90,10 @@ class VarNode(NetworkNode, SymbolVar, ArrayMethodMixin, metaclass=VarNodeMeta): | |||
| def dtype(self): | |||
| return self.var.dtype if self.var else None | |||
| @property | |||
| def ndim(self): | |||
| return super().ndim | |||
| def __bool__(self): | |||
| return False | |||
| @@ -134,7 +138,18 @@ class OpNode(NetworkNode): | |||
| self.outputs = [] | |||
| self.params = {} | |||
| self._opr = None # mgb opnode | |||
| self.id = id(self) | |||
| @property | |||
| def id(self): | |||
| if self._opr is not None: | |||
| return self._opr.id | |||
| return id(self) | |||
| @property | |||
| def priority(self): | |||
| if self._opr is not None: | |||
| return self._opr.priority | |||
| return 0 | |||
| @classmethod | |||
| def load(cls, opr): | |||
| @@ -144,16 +159,21 @@ class OpNode(NetworkNode): | |||
| obj._opr = opr | |||
| return obj | |||
| def compile(self, graph=None): | |||
| op = self.opdef(**self.params) | |||
| args = [i.var for i in self.inputs] | |||
| outputs = rt.invoke_op(op, args) | |||
| assert len(outputs) == len(self.outputs) | |||
| self._opr = outputs[0].owner | |||
| for i in range(len(self.outputs)): | |||
| self.outputs[i].var = outputs[i] | |||
| self.outputs[i].var.name = self.outputs[i].name | |||
| assert self.outputs[i].owner is self | |||
| def compile(self): | |||
| if ( | |||
| self._opr is None | |||
| or len(self._opr.inputs) != len(self.inputs) | |||
| or any([i != j.var for i, j in zip(self._opr.inputs, self.inputs)]) | |||
| ): | |||
| op = self.opdef(**self.params) | |||
| args = [i.var for i in self.inputs] | |||
| outputs = rt.invoke_op(op, args) | |||
| assert len(outputs) == len(self.outputs) | |||
| self._opr = outputs[0].owner | |||
| for i in range(len(self.outputs)): | |||
| self.outputs[i].var = outputs[i] | |||
| self.outputs[i].var.name = self.outputs[i].name | |||
| assert self.outputs[i].owner is self | |||
| def add_inp_var(self, x): | |||
| self.inputs.append(x) | |||
| @@ -197,11 +217,17 @@ class Host2DeviceCopy(OpNode): | |||
| return self | |||
| def compile(self, graph): | |||
| outputs = rt.make_h2d(graph, self.device, self.dtype, self.shape, self.name) | |||
| self._opr = outputs.owner | |||
| if len(self.outputs) == 0: | |||
| self.outputs.append(VarNode(owner_opr=self, name=self.name)) | |||
| self.outputs[0].var = outputs | |||
| if ( | |||
| self._opr is None | |||
| or self._opr.outputs[0].comp_node != self.device | |||
| or self._opr.outputs[0].shape != self.shape | |||
| or self._opr.outputs[0].dtype != self.dtype | |||
| ): | |||
| outputs = rt.make_h2d(graph, self.device, self.dtype, self.shape, self.name) | |||
| self._opr = outputs.owner | |||
| if len(self.outputs) == 0: | |||
| self.outputs.append(VarNode(owner_opr=self, name=self.name)) | |||
| self.outputs[0].var = outputs | |||
| assert self.outputs[0].owner is self | |||
| @@ -192,7 +192,14 @@ void init_graph_rt(py::module m) { | |||
| }) | |||
| .def("__repr__", [](cg::OperatorNodeBase* opr){ | |||
| return "Opr:" + opr->name(); | |||
| }); | |||
| }) | |||
| .def_property("priority", | |||
| [](cg::OperatorNodeBase* opr) { | |||
| return opr->node_prop().attribute().priority; | |||
| }, | |||
| [](cg::OperatorNodeBase* opr, int priority) { | |||
| opr->node_prop().attribute().priority = priority; | |||
| }); | |||
| py::class_<cg::AsyncExecutable>(m, "AsyncExecutable") | |||
| .def("execute", &cg::AsyncExecutable::execute, py::call_guard<py::gil_scoped_release>()) | |||
| @@ -19,6 +19,7 @@ from megengine.core.tensor import megbrain_graph as mgb_graph | |||
| from megengine.core.tensor.megbrain_graph import apply_normal_varnode | |||
| from megengine.core.tensor.utils import astensor1d | |||
| from megengine.jit import trace | |||
| from megengine.utils.network import Network | |||
| def make_dev_tensor(value, dtype=None, device=None): | |||
| @@ -143,6 +144,46 @@ def test_get_opr_seq(): | |||
| assert len(seq_2) == 6 | |||
def test_topological_sort():
    """Check the opr order of a Network loaded from a priority-preserving dump.

    A small traced function is dumped with ``keep_opr_priority=True``, loaded
    back through ``Network.load`` and its ``all_oprs`` sequence is compared,
    element by element, against the expected opr types / Elemwise modes.
    """

    @trace(symbolic=True, capture_as_const=True)
    def func(x, y):
        # NOTE(review): the expected sequence below appears to mirror this
        # statement order, so the statements are deliberately left as they are
        a = x + y
        a1 = F.relu(a)
        a2 = F.abs(a)
        a3 = F.ceil(a) * 2
        a4 = F.floor(a)
        r = a1 - a2
        r1 = a3 / a4
        return r, r1

    buf = io.BytesIO()
    func(megengine.tensor(1.0), megengine.tensor(2.0))
    func.dump(
        buf, optimize_for_inference=False, keep_opr_name=True, keep_opr_priority=True
    )
    buf.seek(0)
    loaded_oprs = Network.load(buf).all_oprs

    expected_modes = [
        "Host2DeviceCopy",
        "Host2DeviceCopy",
        "ADD",
        "RELU",
        "ABS",
        "CEIL",
        "ImmutableTensor",
        "MUL",
        "FLOOR",
        "SUB",
        "TRUE_DIV",
    ]
    for op, mode in zip(loaded_oprs, expected_modes):
        # Elemwise oprs are told apart by their mode, every other opr by type
        actual = op.params["mode"] if op.type == "Elemwise" else op.type
        assert actual == mode
| def test_graph_function(): | |||
| class Net(M.Module): | |||
| def forward(self, a, b): | |||