| @@ -78,7 +78,6 @@ from .logger import enable_debug_log, get_logger, set_log_file, set_log_level | |||||
| from .serialization import load, save | from .serialization import load, save | ||||
| from .tensor import Parameter, Tensor, tensor | from .tensor import Parameter, Tensor, tensor | ||||
| from .version import __version__ | from .version import __version__ | ||||
| from .utils import comp_graph_tools as cgtools | |||||
| _set_fork_exec_path_for_timed_func( | _set_fork_exec_path_for_timed_func( | ||||
| sys.executable, | sys.executable, | ||||
| @@ -15,6 +15,19 @@ from ..core._imperative_rt import OperatorNode, VarNode | |||||
| from ..core.tensor import megbrain_graph as G | from ..core.tensor import megbrain_graph as G | ||||
| from ..core.tensor.raw_tensor import as_raw_tensor | from ..core.tensor.raw_tensor import as_raw_tensor | ||||
| __all__ = [ | |||||
| "get_dep_vars", | |||||
| "get_owner_opr_inputs", | |||||
| "get_owner_opr_type", | |||||
| "get_opr_type", | |||||
| "graph_traversal", | |||||
| "get_oprs_seq", | |||||
| "replace_vars", | |||||
| "replace_oprs", | |||||
| "set_priority_to_id", | |||||
| "load_and_inference", | |||||
| ] | |||||
| def get_dep_vars(var: VarNode, var_type: str = None) -> List[VarNode]: | def get_dep_vars(var: VarNode, var_type: str = None) -> List[VarNode]: | ||||
| """ | """ | ||||
| @@ -166,7 +179,7 @@ def get_oprs_seq(outputs: List[VarNode], prune_reshape=False) -> List[OperatorNo | |||||
| # reshape op definition: reshape(input_tensor, dest_shape) -> output_tensor | # reshape op definition: reshape(input_tensor, dest_shape) -> output_tensor | ||||
| # when inferencing, shape of output_tensor is already known, so one can prune some operators related to dest_shape in the loaded graph | # when inferencing, shape of output_tensor is already known, so one can prune some operators related to dest_shape in the loaded graph | ||||
| def prune_reshape_oprs(outputs, oprs_seq, var2oprs): | def prune_reshape_oprs(outputs, oprs_seq, var2oprs): | ||||
| def iterative_pruning(cur_opr, post_opr, marked_opr_ids): | |||||
| def iterative_pruning(cur_opr, post_opr, marked_opr_ids, visited): | |||||
| useless = True | useless = True | ||||
| for oup in cur_opr.outputs: | for oup in cur_opr.outputs: | ||||
| if "workspace" not in oup.name: | if "workspace" not in oup.name: | ||||
| @@ -177,15 +190,20 @@ def get_oprs_seq(outputs: List[VarNode], prune_reshape=False) -> List[OperatorNo | |||||
| if useless: | if useless: | ||||
| marked_opr_ids.append(cur_opr.id) | marked_opr_ids.append(cur_opr.id) | ||||
| for inp in cur_opr.inputs: | |||||
| iterative_pruning(inp.owner, cur_opr, marked_opr_ids) | |||||
| for opr in set([var.owner for var in cur_opr.inputs]): | |||||
| if (opr.id, cur_opr.id) not in visited: | |||||
| visited.add((opr.id, cur_opr.id)) | |||||
| iterative_pruning(opr, cur_opr, marked_opr_ids, visited) | |||||
| reshape_vars = get_dep_vars(outputs, "Reshape") | reshape_vars = get_dep_vars(outputs, "Reshape") | ||||
| reshape_oprs = [var.owner for var in reshape_vars] | reshape_oprs = [var.owner for var in reshape_vars] | ||||
| marked_opr_ids = [] | marked_opr_ids = [] | ||||
| visited = set() | |||||
| for reshape_opr in reshape_oprs: | for reshape_opr in reshape_oprs: | ||||
| iterative_pruning(reshape_opr.inputs[1].owner, reshape_opr, marked_opr_ids) | |||||
| iterative_pruning( | |||||
| reshape_opr.inputs[1].owner, reshape_opr, marked_opr_ids, visited | |||||
| ) | |||||
| # filter out all marked oprs | # filter out all marked oprs | ||||
| return list(filter(lambda x: x.id not in marked_opr_ids, oprs_seq)) | return list(filter(lambda x: x.id not in marked_opr_ids, oprs_seq)) | ||||
| @@ -13,9 +13,10 @@ import pytest | |||||
| import megengine | import megengine | ||||
| import megengine.functional as F | import megengine.functional as F | ||||
| import megengine.module as M | import megengine.module as M | ||||
| from megengine import cgtools | |||||
| import megengine.utils.comp_graph_tools as cgtools | |||||
| from megengine.core.tensor import megbrain_graph as mgb_graph | from megengine.core.tensor import megbrain_graph as mgb_graph | ||||
| from megengine.core.tensor.raw_tensor import as_raw_tensor | from megengine.core.tensor.raw_tensor import as_raw_tensor | ||||
| from megengine.core.tensor.utils import astensor1d | |||||
| from megengine.jit import trace | from megengine.jit import trace | ||||
| @@ -98,3 +99,38 @@ def test_load_refcnt(): | |||||
| graph, _, (varnode,) = mgb_graph.load_graph(io.BytesIO(buf)) | graph, _, (varnode,) = mgb_graph.load_graph(io.BytesIO(buf)) | ||||
| del graph | del graph | ||||
| varnode.owner | varnode.owner | ||||
| def test_get_opr_seq(): | |||||
| class Net(M.Module): | |||||
| def __init__(self): | |||||
| super().__init__() | |||||
| self.data = megengine.tensor( | |||||
| np.random.random((1, 1, 4, 4)), dtype=np.float32 | |||||
| ) | |||||
| def forward(self, input): | |||||
| A = input.shape[0] | |||||
| shape = astensor1d((A, A), self.data, dtype="int32", device=input.device) | |||||
| x = F.reshape(self.data, shape) | |||||
| o = input + x | |||||
| return o | |||||
| net = Net() | |||||
| input = megengine.tensor(np.random.random((4, 4)), dtype=np.float32) | |||||
| @trace(symbolic=True, capture_as_const=True) | |||||
| def func(inp, *, net=None): | |||||
| return net(inp) | |||||
| func(input, net=net) | |||||
| file = io.BytesIO() | |||||
| func.dump(file, optimize_for_inference=False) | |||||
| file.seek(0) | |||||
| *_, outputs = mgb_graph.load_graph(file) | |||||
| seq_1 = cgtools.get_oprs_seq(outputs, True) | |||||
| assert len(seq_1) == 5 | |||||
| seq_2 = cgtools.get_oprs_seq(outputs, False) | |||||
| assert len(seq_2) == 6 | |||||
| @@ -14,7 +14,8 @@ import pytest | |||||
| import megengine.core.tensor.megbrain_graph as G | import megengine.core.tensor.megbrain_graph as G | ||||
| import megengine.functional as F | import megengine.functional as F | ||||
| from megengine import cgtools, tensor | |||||
| import megengine.utils.comp_graph_tools as cgtools | |||||
| from megengine import tensor | |||||
| from megengine.core._trace_option import set_symbolic_shape | from megengine.core._trace_option import set_symbolic_shape | ||||
| from megengine.core.ops import builtin as ops | from megengine.core.ops import builtin as ops | ||||
| from megengine.core.ops.builtin import Elemwise | from megengine.core.ops.builtin import Elemwise | ||||