GitOrigin-RevId: 47a26dd6dd
tags/v1.5.0
| @@ -893,6 +893,10 @@ class trace: | |||||
| if isinstance(file, str): | if isinstance(file, str): | ||||
| permission = "wb" if append == False else "ab" | permission = "wb" if append == False else "ab" | ||||
| file = open(file, permission) | file = open(file, permission) | ||||
| if keep_opr_priority: | |||||
| graph._set_priority_to_id(dest_vars) | |||||
| dump_content, dump_info = G.dump_graph( | dump_content, dump_info = G.dump_graph( | ||||
| dest_vars, | dest_vars, | ||||
| keep_var_name=keep_var_name, | keep_var_name=keep_var_name, | ||||
| @@ -6,6 +6,7 @@ | |||||
| # software distributed under the License is distributed on an | # software distributed under the License is distributed on an | ||||
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| import collections | import collections | ||||
| import heapq | |||||
| from collections import OrderedDict | from collections import OrderedDict | ||||
| from typing import Dict, List, Tuple, Union | from typing import Dict, List, Tuple, Union | ||||
| @@ -88,6 +89,41 @@ def get_opr_type(opr: _OpNode) -> str: | |||||
| return opr.type | return opr.type | ||||
class _OprStableOrderHeapq:
    """Min-heap of operators that pops them in a stable, reproducible order.

    Every heap entry is the tuple returned by ``extra_priority(opr)`` followed
    by ``(opr.id, opr.name, opr)``. Entries are therefore ordered first by the
    extra priority, then by id, then by name; the operator itself is stored
    last and is what :meth:`pop_min` hands back.
    """

    _list = None
    _extra_priority = None
    _used_id_name_pairs = None

    def __init__(self, extra_priority):
        """
        :param extra_priority: callable mapping an operator to a tuple; that
            tuple is placed in front of ``(id, name)`` in the heap key.
        """
        # Use the builtin ``callable`` rather than ``collections.Callable``:
        # the latter alias was removed from ``collections`` in Python 3.10 and
        # raises AttributeError there.
        assert callable(extra_priority), "extra_priority must be callable"
        self._list = []
        self._extra_priority = extra_priority
        self._used_id_name_pairs = {}

    def pop_min(self):
        """Remove the smallest entry from the heap and return its operator."""
        return heapq.heappop(self._list)[-1]

    def add(self, opr):
        """Push ``opr`` onto the heap.

        Named ``add`` to mimic the ``set()`` interface.

        :param opr: operator exposing ``id`` and ``name`` attributes.
        :raises RuntimeError: if a different operator with the same
            ``(id, name)`` pair has been added before.
        """
        id_ = opr.id
        name = opr.name
        # Distinct operators must never share (id, name); this keeps tuple
        # comparison inside the heap from having to order the operator
        # objects themselves.
        other = self._used_id_name_pairs.setdefault((id_, name), opr)
        if other is not opr:
            raise RuntimeError(
                "duplicated (id, name) pair: opr0={} opr1={}".format(other, opr)
            )

        item = self._extra_priority(opr) + (id_, name, opr)
        heapq.heappush(self._list, item)

    def __bool__(self):
        """Return True while the heap still holds entries."""
        return bool(self._list)
| def graph_traversal(outputs: _VarNode): | def graph_traversal(outputs: _VarNode): | ||||
| """ | """ | ||||
| Helper function to traverse the computing graph and return enough useful information. | Helper function to traverse the computing graph and return enough useful information. | ||||
| @@ -110,12 +146,13 @@ def graph_traversal(outputs: _VarNode): | |||||
| var2oprs = collections.defaultdict(list) | var2oprs = collections.defaultdict(list) | ||||
| opr2receivers = collections.defaultdict(list) | opr2receivers = collections.defaultdict(list) | ||||
| queue = list(set(map(lambda x: x.owner, outputs))) | |||||
| queue = [] | |||||
| [queue.append(o) for o in [x.owner for x in outputs] if o not in queue] | |||||
| visited = set(map(lambda x: x.id, queue)) | visited = set(map(lambda x: x.id, queue)) | ||||
| # iterate through whole comp_graph, fill in meta information | # iterate through whole comp_graph, fill in meta information | ||||
| indegree2opr = collections.defaultdict(set) | indegree2opr = collections.defaultdict(set) | ||||
| indegree2opr[0] = _OprStableOrderHeapq(lambda op: (op.priority,)) | |||||
| opr2indegree = {} | opr2indegree = {} | ||||
| idx = 0 | idx = 0 | ||||
| @@ -138,8 +175,8 @@ def graph_traversal(outputs: _VarNode): | |||||
| indegree += 1 | indegree += 1 | ||||
| opr2receivers[pre_opr.id].append(cur_opr.id) | opr2receivers[pre_opr.id].append(cur_opr.id) | ||||
| indegree2opr[indegree].add(cur_opr.id) | |||||
| opr = cur_opr if indegree == 0 else cur_opr.id | |||||
| indegree2opr[indegree].add(opr) | |||||
| opr2indegree[cur_opr.id] = indegree | opr2indegree[cur_opr.id] = indegree | ||||
| return map_oprs, map_vars, var2oprs, opr2receivers, indegree2opr, opr2indegree | return map_oprs, map_vars, var2oprs, opr2receivers, indegree2opr, opr2indegree | ||||
| @@ -162,8 +199,8 @@ def get_oprs_seq( | |||||
| oprs_seq = [] | oprs_seq = [] | ||||
| nr_remain = len(map_oprs) | nr_remain = len(map_oprs) | ||||
| while indegree2opr[0]: | while indegree2opr[0]: | ||||
| opr_id = indegree2opr[0].pop() | |||||
| opr = map_oprs[opr_id] | |||||
| opr = indegree2opr[0].pop_min() | |||||
| opr_id = opr.id | |||||
| nr_remain -= 1 | nr_remain -= 1 | ||||
| if opr.type != "ImmutableTensor" or not prune_immtensor: | if opr.type != "ImmutableTensor" or not prune_immtensor: | ||||
| oprs_seq.append(opr) | oprs_seq.append(opr) | ||||
| @@ -173,7 +210,10 @@ def get_oprs_seq( | |||||
| indegree2opr[indegree].remove(post_id) | indegree2opr[indegree].remove(post_id) | ||||
| indegree -= 1 | indegree -= 1 | ||||
| indegree2opr[indegree].add(post_id) | |||||
| if indegree == 0: | |||||
| indegree2opr[indegree].add(map_oprs[post_id]) | |||||
| else: | |||||
| indegree2opr[indegree].add(post_id) | |||||
| opr2indegree[post_id] = indegree | opr2indegree[post_id] = indegree | ||||
| assert nr_remain == 0, "there are {} remaining nodes; cyclic graph?".format( | assert nr_remain == 0, "there are {} remaining nodes; cyclic graph?".format( | ||||
| @@ -213,10 +253,34 @@ def get_oprs_seq( | |||||
| # filter out all marked oprs | # filter out all marked oprs | ||||
| return list(filter(lambda x: x.id not in marked_opr_ids, oprs_seq)) | return list(filter(lambda x: x.id not in marked_opr_ids, oprs_seq)) | ||||
# Adjust the order of oprs so that param/data provider oprs sit close to the
# oprs which use them as inputs.
def reorder_oprs_seq(oprs):
    """Return ``oprs`` reordered so providers are emitted just ahead of their first consumer.

    Oprs whose type is ``ImmutableTensor`` or ``Host2DeviceCopy`` are held
    back; every other opr keeps its relative order. When an opr has an input
    owned by a held-back provider, that provider is emitted before the opr.
    Providers that no opr consumes are appended at the end in their original
    relative order.
    """
    provider_types = ["ImmutableTensor", "Host2DeviceCopy"]

    pending_providers = []
    consumers = []
    for opr in oprs:
        bucket = pending_providers if opr.type in provider_types else consumers
        bucket.append(opr)

    reordered = []
    for opr in consumers:
        for var in opr.inputs:
            producer = var.owner
            if producer.type not in provider_types:
                continue
            # emit each provider once only: it leaves the pending list as
            # soon as its first consumer is reached
            if producer in pending_providers:
                reordered.append(producer)
                pending_providers.remove(producer)
        reordered.append(opr)

    reordered.extend(pending_providers)
    assert len(reordered) == len(oprs)
    return reordered
| map_oprs, _, var2oprs, opr2receivers, indegree2opr, opr2indegree = graph_traversal( | map_oprs, _, var2oprs, opr2receivers, indegree2opr, opr2indegree = graph_traversal( | ||||
| outputs | outputs | ||||
| ) | ) | ||||
| oprs_seq = topological_sort(map_oprs, opr2receivers, indegree2opr, opr2indegree) | oprs_seq = topological_sort(map_oprs, opr2receivers, indegree2opr, opr2indegree) | ||||
| oprs_seq = reorder_oprs_seq(oprs_seq) | |||||
| if prune_reshape is True: | if prune_reshape is True: | ||||
| oprs_seq = prune_reshape_oprs(outputs, oprs_seq, var2oprs.copy()) | oprs_seq = prune_reshape_oprs(outputs, oprs_seq, var2oprs.copy()) | ||||
| return oprs_seq | return oprs_seq | ||||
| @@ -241,6 +241,7 @@ class Network: | |||||
| if optimize_for_inference: | if optimize_for_inference: | ||||
| metadata.optimize_options = optimize_options | metadata.optimize_options = optimize_options | ||||
| G.set_priority_to_id([o._node if isinstance(o, G.VarNode) else o for o in out]) | |||||
| dump_content, _ = G.dump_graph( | dump_content, _ = G.dump_graph( | ||||
| out, | out, | ||||
| keep_var_name=keep_var_name, | keep_var_name=keep_var_name, | ||||
| @@ -353,7 +354,7 @@ class Network: | |||||
| ) | ) | ||||
| shp[0] = batchsize | shp[0] = batchsize | ||||
| i.shape = tuple(shp) | i.shape = tuple(shp) | ||||
| self._compile() | |||||
| assert prev_batchsize is not None, "no data provider found" | assert prev_batchsize is not None, "no data provider found" | ||||
| assert not blacklist, "unused items in blacklist: {}".format(blacklist) | assert not blacklist, "unused items in blacklist: {}".format(blacklist) | ||||
| @@ -363,7 +364,6 @@ class Network: | |||||
| :param repl_dict: the map {old_var: new_var} that specifies how to replace the vars. | :param repl_dict: the map {old_var: new_var} that specifies how to replace the vars. | ||||
| """ | """ | ||||
| if not all([var.owner for var in repl_dict.values()]): | if not all([var.owner for var in repl_dict.values()]): | ||||
| print(repl_dict.values()) | |||||
| self.add_dep_oprs(*list(repl_dict.values())) | self.add_dep_oprs(*list(repl_dict.values())) | ||||
| for var in self.all_vars: | for var in self.all_vars: | ||||
| if var in repl_dict: | if var in repl_dict: | ||||
| @@ -373,6 +373,7 @@ class Network: | |||||
| owner.outputs[idx] = var | owner.outputs[idx] = var | ||||
| var.__dict__.update(repl_var.__dict__) | var.__dict__.update(repl_var.__dict__) | ||||
| var.var = repl_var.var | var.var = repl_var.var | ||||
| self._compile() | |||||
| def replace_oprs(self, repl_dict: Dict[OpNode, OpNode]): | def replace_oprs(self, repl_dict: Dict[OpNode, OpNode]): | ||||
| """ | """ | ||||
| @@ -384,11 +385,11 @@ class Network: | |||||
| assert len(opr.outputs) == len( | assert len(opr.outputs) == len( | ||||
| repl_dict[opr].outputs | repl_dict[opr].outputs | ||||
| ), "can not replace {} with {}".format(type(opr), type(repl_dict[opr])) | ), "can not replace {} with {}".format(type(opr), type(repl_dict[opr])) | ||||
| repl_dict[opr].outputs = opr.outputs | |||||
| for ind, var in enumerate(opr.outputs): | for ind, var in enumerate(opr.outputs): | ||||
| var.owner = repl_dict[opr] | var.owner = repl_dict[opr] | ||||
| var.__dict__.update(repl_dict[opr].outputs[ind].__dict__) | var.__dict__.update(repl_dict[opr].outputs[ind].__dict__) | ||||
| var.var = repl_dict[opr].outputs[ind].var | var.var = repl_dict[opr].outputs[ind].var | ||||
| self._compile() | |||||
| def get_opr_by_type(self, oprcls, unique=True): | def get_opr_by_type(self, oprcls, unique=True): | ||||
| assert issubclass(oprcls, OpNode) | assert issubclass(oprcls, OpNode) | ||||
| @@ -90,6 +90,10 @@ class VarNode(NetworkNode, SymbolVar, ArrayMethodMixin, metaclass=VarNodeMeta): | |||||
| def dtype(self): | def dtype(self): | ||||
| return self.var.dtype if self.var else None | return self.var.dtype if self.var else None | ||||
@property
def ndim(self):
    """Return ``ndim`` exactly as provided by the parent class implementation."""
    return super().ndim
| def __bool__(self): | def __bool__(self): | ||||
| return False | return False | ||||
| @@ -134,7 +138,18 @@ class OpNode(NetworkNode): | |||||
| self.outputs = [] | self.outputs = [] | ||||
| self.params = {} | self.params = {} | ||||
| self._opr = None # mgb opnode | self._opr = None # mgb opnode | ||||
| self.id = id(self) | |||||
@property
def id(self):
    """Identifier of this node.

    Uses the ``id`` of the underlying ``_opr`` when one is attached; otherwise
    falls back to the Python object identity of this node.
    """
    backing_opr = self._opr
    return id(self) if backing_opr is None else backing_opr.id
@property
def priority(self):
    """Priority of this node.

    Taken from the underlying ``_opr`` when one is attached; ``0`` otherwise.
    """
    backing_opr = self._opr
    return 0 if backing_opr is None else backing_opr.priority
| @classmethod | @classmethod | ||||
| def load(cls, opr): | def load(cls, opr): | ||||
| @@ -144,16 +159,21 @@ class OpNode(NetworkNode): | |||||
| obj._opr = opr | obj._opr = opr | ||||
| return obj | return obj | ||||
| def compile(self, graph=None): | |||||
| op = self.opdef(**self.params) | |||||
| args = [i.var for i in self.inputs] | |||||
| outputs = rt.invoke_op(op, args) | |||||
| assert len(outputs) == len(self.outputs) | |||||
| self._opr = outputs[0].owner | |||||
| for i in range(len(self.outputs)): | |||||
| self.outputs[i].var = outputs[i] | |||||
| self.outputs[i].var.name = self.outputs[i].name | |||||
| assert self.outputs[i].owner is self | |||||
def compile(self):
    """(Re)build the underlying operator unless the current one still matches.

    Nothing is done when ``_opr`` already exists, has the same number of
    inputs as this node, and each of its inputs compares equal to the ``var``
    of the corresponding input node. Otherwise a new op is created from
    ``opdef``/``params``, invoked on the input vars, and its results are bound
    to this node's output nodes.
    """
    stale = (
        self._opr is None
        or len(self._opr.inputs) != len(self.inputs)
        or any([i != j.var for i, j in zip(self._opr.inputs, self.inputs)])
    )
    if not stale:
        return

    op = self.opdef(**self.params)
    args = [inp.var for inp in self.inputs]
    outputs = rt.invoke_op(op, args)
    assert len(outputs) == len(self.outputs)
    self._opr = outputs[0].owner
    for out_node, new_var in zip(self.outputs, outputs):
        out_node.var = new_var
        # carry the node's name over to the freshly created var
        out_node.var.name = out_node.name
        assert out_node.owner is self
| def add_inp_var(self, x): | def add_inp_var(self, x): | ||||
| self.inputs.append(x) | self.inputs.append(x) | ||||
| @@ -197,11 +217,17 @@ class Host2DeviceCopy(OpNode): | |||||
| return self | return self | ||||
| def compile(self, graph): | def compile(self, graph): | ||||
| outputs = rt.make_h2d(graph, self.device, self.dtype, self.shape, self.name) | |||||
| self._opr = outputs.owner | |||||
| if len(self.outputs) == 0: | |||||
| self.outputs.append(VarNode(owner_opr=self, name=self.name)) | |||||
| self.outputs[0].var = outputs | |||||
| if ( | |||||
| self._opr is None | |||||
| or self._opr.outputs[0].comp_node != self.device | |||||
| or self._opr.outputs[0].shape != self.shape | |||||
| or self._opr.outputs[0].dtype != self.dtype | |||||
| ): | |||||
| outputs = rt.make_h2d(graph, self.device, self.dtype, self.shape, self.name) | |||||
| self._opr = outputs.owner | |||||
| if len(self.outputs) == 0: | |||||
| self.outputs.append(VarNode(owner_opr=self, name=self.name)) | |||||
| self.outputs[0].var = outputs | |||||
| assert self.outputs[0].owner is self | assert self.outputs[0].owner is self | ||||
| @@ -192,7 +192,14 @@ void init_graph_rt(py::module m) { | |||||
| }) | }) | ||||
| .def("__repr__", [](cg::OperatorNodeBase* opr){ | .def("__repr__", [](cg::OperatorNodeBase* opr){ | ||||
| return "Opr:" + opr->name(); | return "Opr:" + opr->name(); | ||||
| }); | |||||
| }) | |||||
| .def_property("priority", | |||||
| [](cg::OperatorNodeBase* opr) { | |||||
| return opr->node_prop().attribute().priority; | |||||
| }, | |||||
| [](cg::OperatorNodeBase* opr, int priority) { | |||||
| opr->node_prop().attribute().priority = priority; | |||||
| }); | |||||
| py::class_<cg::AsyncExecutable>(m, "AsyncExecutable") | py::class_<cg::AsyncExecutable>(m, "AsyncExecutable") | ||||
| .def("execute", &cg::AsyncExecutable::execute, py::call_guard<py::gil_scoped_release>()) | .def("execute", &cg::AsyncExecutable::execute, py::call_guard<py::gil_scoped_release>()) | ||||
| @@ -19,6 +19,7 @@ from megengine.core.tensor import megbrain_graph as mgb_graph | |||||
| from megengine.core.tensor.megbrain_graph import apply_normal_varnode | from megengine.core.tensor.megbrain_graph import apply_normal_varnode | ||||
| from megengine.core.tensor.utils import astensor1d | from megengine.core.tensor.utils import astensor1d | ||||
| from megengine.jit import trace | from megengine.jit import trace | ||||
| from megengine.utils.network import Network | |||||
| def make_dev_tensor(value, dtype=None, device=None): | def make_dev_tensor(value, dtype=None, device=None): | ||||
| @@ -143,6 +144,46 @@ def test_get_opr_seq(): | |||||
| assert len(seq_2) == 6 | assert len(seq_2) == 6 | ||||
def test_topological_sort():
    """Dump a traced function with ``keep_opr_priority=True`` and check the
    operator order seen after loading it back as a ``Network``."""

    @trace(symbolic=True, capture_as_const=True)
    def func(x, y):
        a = x + y
        a1 = F.relu(a)
        a2 = F.abs(a)
        a3 = F.ceil(a) * 2
        a4 = F.floor(a)
        r = a1 - a2
        r1 = a3 / a4
        return r, r1

    buf = io.BytesIO()
    func(megengine.tensor(1.0), megengine.tensor(2.0))
    func.dump(
        buf, optimize_for_inference=False, keep_opr_name=True, keep_opr_priority=True
    )
    buf.seek(0)
    net = Network.load(buf)
    loaded_oprs = net.all_oprs

    # expected sequence: Elemwise oprs are listed by mode, all others by type
    expected_seq = [
        "Host2DeviceCopy",
        "Host2DeviceCopy",
        "ADD",
        "RELU",
        "ABS",
        "CEIL",
        "ImmutableTensor",
        "MUL",
        "FLOOR",
        "SUB",
        "TRUE_DIV",
    ]
    for opr, expected in zip(loaded_oprs, expected_seq):
        actual = opr.params["mode"] if opr.type == "Elemwise" else opr.type
        assert actual == expected
| def test_graph_function(): | def test_graph_function(): | ||||
| class Net(M.Module): | class Net(M.Module): | ||||
| def forward(self, a, b): | def forward(self, a, b): | ||||