Browse Source

feat(traced_module): add some functions of graph modification

GitOrigin-RevId: ac0603057a
tags/v1.6.0
Megvii Engine Team 4 years ago
parent
commit
b1c46ba46c
10 changed files with 729 additions and 73 deletions
  1. +42
    -27
      imperative/python/megengine/experimental/traced_module/expr.py
  2. +11
    -2
      imperative/python/megengine/experimental/traced_module/module_tracer.py
  3. +4
    -1
      imperative/python/megengine/experimental/traced_module/node.py
  4. +7
    -2
      imperative/python/megengine/experimental/traced_module/pytree.py
  5. +274
    -41
      imperative/python/megengine/experimental/traced_module/traced_module.py
  6. +90
    -0
      imperative/python/test/unit/traced_module/test_haoruitao.py
  7. +113
    -0
      imperative/python/test/unit/traced_module/test_modification.py
  8. +52
    -0
      imperative/python/test/unit/traced_module/test_serialization.py
  9. +42
    -0
      imperative/python/test/unit/traced_module/test_trace_module.py
  10. +94
    -0
      imperative/python/test/unit/traced_module/test_wujianan.py

+ 42
- 27
imperative/python/megengine/experimental/traced_module/expr.py View File

@@ -9,6 +9,7 @@

import builtins
import collections
import inspect
from typing import Callable, List

from ...core._imperative_rt import OpDef
@@ -16,10 +17,10 @@ from ...core._imperative_rt.core2 import Tensor as RawTensor
from ...core._imperative_rt.core2 import apply, set_module_tracing, unset_module_tracing
from ...core.ops.special import Const
from ...module import Module
from ...tensor import Tensor
from ...tensor import Parameter, Tensor
from .module_tracer import active_module_tracer, module_tracer
from .node import ModuleNode, Node, NodeMixin, TensorNode
from .pytree import TreeDef
from .pytree import TreeDef, tree_flatten


class Expr:
@@ -38,25 +39,28 @@ class Expr:
for val in vals:
node = NodeMixin.get(val, None)
if isinstance(node, (TensorNode, ModuleNode)):
if node not in self.inputs:
self.inputs.append(node)
self.inputs.append(node)
node.users.append(self)
else:
assert node is None
assert type(val) in builtins.__dict__.values()
idx = len(self.inputs) + len(self.const_val)
self.const_val.append((idx, val))

def add_outputs(self, outputs):
def add_outputs(self, outputs, check_inplace=True):
self.outputs = []
if not isinstance(outputs, collections.Sequence):
outputs = (outputs,)
if outputs is not None:
if not isinstance(outputs, collections.Sequence):
outputs = (outputs,)

for i in outputs:
assert isinstance(i, RawTensor)
self.outputs.append(NodeMixin.get_wrapped_type(i)(self))
for i in outputs:
assert isinstance(i, RawTensor)
node = NodeMixin.get(i, None) if check_inplace else None
self.outputs.append(
node if node else NodeMixin.get_wrapped_type(i)(self)
)

for i, node in zip(outputs, self.outputs,):
NodeMixin.wrap_safe(i, node)
for i, node in zip(outputs, self.outputs,):
NodeMixin.wrap_safe(i, node)

def unflatten_args(self, inputs):
if self.arg_def is not None:
@@ -110,6 +114,7 @@ class GetAttr(Expr):
self.inputs = [
module,
]
module.users.append(self)
self.name = name
node_cls = type if type else Node
self.outputs = [
@@ -134,12 +139,20 @@ class GetAttr(Expr):

# expr: outputs = inputs[0].__call__(*inputs[1:])
class CallMethod(Expr):
def __init__(self, module, method="__call__"):
assert isinstance(module, (TensorNode, ModuleNode))
self.inputs = [
module,
]
self.const_val = []
def __init__(self, node, method="__call__"):
if isinstance(node, type):
assert issubclass(node, Tensor)
cls = Parameter if issubclass(node, Parameter) else Tensor

self.inputs = []
self.const_val = [(0, cls)]
else:
assert isinstance(node, (TensorNode, ModuleNode))
node.users.append(self)
self.inputs = [
node,
]
self.const_val = []
self.method = method

@classmethod
@@ -160,10 +173,13 @@ class CallMethod(Expr):
def interpret(self, *inputs):
args, kwargs = self.unflatten_args(inputs)
obj = args[0]
args = args[1:]
meth = getattr(obj, self.method)
if inspect.ismethod(meth):
args = args[1:]
outputs = getattr(obj, self.method)(*args, **kwargs)
if isinstance(outputs, RawTensor):
outputs = (outputs,)
if outputs is None:
return outputs
outputs, _ = tree_flatten(outputs, is_leaf=lambda x: isinstance(x, RawTensor))
return outputs

def __repr__(self):
@@ -171,7 +187,7 @@ class CallMethod(Expr):
kwargs = ", ".join("{}={}".format(k, v) for k, v in self.kwargs.items())
return "{} = {}.{}({})".format(
", ".join(str(i) for i in self.outputs),
self.inputs[0],
self.args[0],
self.method,
", ".join([args, kwargs]),
)
@@ -209,9 +225,8 @@ class Apply(Expr):
if node is None: # capture as constant
NodeMixin.wrap_safe(i, Constant.make(i))
apply_node = cls.make(opdef)
for i in inputs:
assert isinstance(i, RawTensor)
apply_node.inputs.append(NodeMixin.get(i))
apply_node.add_inputs(inputs)
assert not apply_node.const_val

unset_module_tracing()
outputs = apply(opdef, *inputs)
@@ -283,7 +298,7 @@ class Constant(Expr):
return (self.value,)

def __repr__(self):
return "{} = Constant({})".format(self.outputs[0], self.value)
return "{} = Constant({})".format(self.outputs[0], type(self.value))

def __getstate__(self):
state = self.__dict__.copy()


+ 11
- 2
imperative/python/megengine/experimental/traced_module/module_tracer.py View File

@@ -79,6 +79,8 @@ BUILTIN_ARRAY_METHOD = [
"min",
"max",
"mean",
"__getitem__",
"__setitem__",
]


@@ -176,7 +178,8 @@ class Patcher:
self.patch_module(module)
for meth in BUILTIN_ARRAY_METHOD:
self.patch_method(ArrayMethodMixin, meth, self.wrap_fn)

self.patch_method(Tensor, "detach", self.wrap_fn)
self.patch_method(Tensor, "__new__", self.wrap_fn)
for i, j in self._builtin_functions:
if id(i) not in self.visited_frames_ids:
self.patch_function(i, j, self.wrap_fn)
@@ -203,7 +206,13 @@ class Patcher:
import inspect

if id(module.__dict__) not in self.visited_frames_ids:
for k, v in module.__dict__.items():
keys = (
getattr(module, "__all__")
if hasattr(module, "__all__")
else module.__dict__.keys()
)
for k in keys:
v = getattr(module, k)
if inspect.isfunction(v) and not k.startswith("_"):
self.patch_function(module.__dict__, k, self.wrap_fn)
self.visited_frames_ids.add(id(module.__dict__))


+ 4
- 1
imperative/python/megengine/experimental/traced_module/node.py View File

@@ -6,7 +6,7 @@
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from typing import Any, Dict, Tuple, Type
from typing import Any, Dict, List, Tuple, Type

import numpy

@@ -31,6 +31,7 @@ class Node:

def __init__(self, expr: "Expr", name: str = None):
self.expr = expr
self.users = [] # List[Expr]
self._id = Node.__total_id
Node.__total_id += 1
self._name = name
@@ -59,11 +60,13 @@ class ModuleNode(Node):
module_type = Module # type: Type[Module]
attr_type_map = None # type: Dict[str, Type[Any]]
argdef_graph_map = None # type: Dict[Treedef, "InternalGraph"]
argdef_outdef_map = None # type: Dict[Treedef, Treedef]

def __init__(self, expr: "Expr", name: str = None):
super().__init__(expr, name)
self.attr_type_map = {}
self.argdef_graph_map = {}
self.argdef_outdef_map = {}

def __repr__(self):
if self._name is None:


+ 7
- 2
imperative/python/megengine/experimental/traced_module/pytree.py View File

@@ -10,6 +10,8 @@
import collections
from typing import Callable, NamedTuple

import numpy as np

SUPPORTED_TYPE = {}

NodeType = NamedTuple("NodeType", [("flatten", Callable), ("unflatten", Callable)])
@@ -33,7 +35,7 @@ def _dict_unflatten(inps, aux_data):


register_supported_type(list, lambda x: (x, None), lambda x, aux_data: list(x))
register_supported_type(tuple, lambda x: (x, None), lambda x, aux_data: list(x))
register_supported_type(tuple, lambda x: (x, None), lambda x, aux_data: tuple(x))
register_supported_type(dict, _dict_flatten, _dict_unflatten)
register_supported_type(
slice,
@@ -52,7 +54,10 @@ def tree_flatten(
assert is_leaf(values)
node = LeafDef(leaf_type(values))
if is_const_leaf(values):
node.const_val = values
if isinstance(values, np.ndarray):
node.const_val = str(values)
else:
node.const_val = values
return [values,], node

rst = []


+ 274
- 41
imperative/python/megengine/experimental/traced_module/traced_module.py View File

@@ -10,8 +10,13 @@ import collections
import copy
import functools
from inspect import getmembers, isclass, ismethod
from typing import Dict, List, Type
from typing import Callable, Dict, Iterable, List, Sequence, Type

import numpy as np
from numpy.lib.arraysetops import isin

from ... import functional as F
from ... import get_logger
from ... import module as M
from ...core._imperative_rt.core2 import Tensor as RawTensor
from ...core._imperative_rt.core2 import (
@@ -19,6 +24,7 @@ from ...core._imperative_rt.core2 import (
set_module_tracing,
unset_module_tracing,
)
from ...core._trace_option import set_symbolic_shape
from ...core.tensor.array_method import ArrayMethodMixin
from ...module import Module
from ...tensor import Tensor
@@ -32,6 +38,8 @@ from .module_tracer import (
from .node import ModuleNode, Node, NodeMixin, TensorNode
from .pytree import tree_flatten

logger = get_logger(__name__)


def _leaf_type(node):
if isinstance(node, RawTensor):
@@ -42,6 +50,11 @@ def _leaf_type(node):
return type(node)


def _is_leaf(node):
assert isinstance(node, RawTensor), type(node)
return isinstance(node, RawTensor)


def _is_const_leaf(node):
if isinstance(node, (RawTensor, NodeMixin, Module)):
return False
@@ -80,7 +93,13 @@ class InternalGraph:

@property
def exprs(self):
return _expr_list(self)
return ExprFilter(_expr_iter(self))

def get_call_function(self, func: Callable = None):
return self.exprs.call_function(func)

def get_call_method(self, method: str = None):
return self.exprs.call_method(method)

def add_input(self, i):
self._inputs.append(i)
@@ -88,16 +107,131 @@ class InternalGraph:
def add_output(self, o):
self._outputs.append(o)

def get_dep_exprs(self, nodes: Sequence[Node]) -> List[Expr]:
if not isinstance(nodes, Sequence):
nodes = (nodes,)
ret = list()
queue = list(nodes)
while queue:
node = queue.pop()
expr = node.expr
if expr not in ret:
ret.append(expr)
for i in expr.inputs:
if i not in queue:
queue.append(i)
return ret

def insert_call_function(self, func: Callable, nodes: Sequence[Node]):
if not isinstance(nodes, Sequence):
nodes = [nodes]
assert isinstance(func, Callable)
for i in nodes:
assert isinstance(
i, TensorNode
), "CallFunction only accept TensorNode as inputs"

expr = CallFunction(func)
expr.inputs = nodes

for i in nodes:
i.users.append(expr)

idx = max(self._exprs.index(i.expr) for i in nodes) + 1
self._exprs.insert(idx, expr)

fake_inp_val = tuple(F.zeros(shape=i.shape, dtype=i.dtype) for i in nodes)
fake_out_val = func(*fake_inp_val)

def create_node(val: Tensor):
node = TensorNode(expr)
node.shape = val.shape
node.dtype = val.dtype
return node

out_nodes = list(create_node(i) for i in fake_out_val)
expr.outputs = out_nodes

return out_nodes

def insert_call_method(self, target, method, args):
if not isinstance(args, Sequence):
args = [args]
assert isinstance(target, (TensorNode, ModuleNode))
assert isinstance(method, str)
for i in args:
assert isinstance(i, TensorNode)

expr = CallMethod(method)
expr.inputs = [target, *args]

if isinstance(target, TensorNode):
fake_target_val = F.zeros(shape=target.shape, dtype=target.dtype)
fake_inp_val = tuple(F.zeros(shape=i.shape, dtype=i.dtype) for i in args)
fake_out_val = getattr(fake_target_val, method)(fake_inp_val)

def create_node(val: Tensor):
node = TensorNode(expr)
node.shape = val.shape
node.dtype = val.dtype
return node

out_nodes = list(create_node(i) for i in fake_out_val)
expr.outputs = out_nodes
else:
raise NotImplementedError()

return out_nodes

def replace_node(self, repl_dict: Dict[Node, Node]):
while repl_dict:
node, repl_node = repl_dict.popitem()
# check graph inputs and outputs
assert node not in self.inputs, "Cannot replace inputs"
for i, n in enumerate(self.outputs):
if n is node:
self.outputs[i] = repl_node
# update users of node and repl_node
# update inputs of expr in node.users
dep_exprs = self.get_dep_exprs(repl_node)
i = 0
while i < len(node.users):
n = node.users[i]
if n in dep_exprs:
logger.info("Find a loop: ignore this replacement once")
logger.info("node: %s" % node.__repr__())
logger.info("repl_node: %s" % repl_node.__repr__())
i += 1
continue
repl_node.users.append(n)
node.users.pop(i)
idx = n.inputs.index(node)
n.inputs[idx] = repl_node

def compile(self):
"""
Delete unused expr.
"""
dep_exprs = self.get_dep_exprs(self.outputs)
i = 0
while i < len(self._exprs):
expr = self._exprs[i]
if expr in dep_exprs:
i += 1
continue
for n in expr.inputs:
n.users.remove(expr)
self._exprs.remove(expr)

def interpret(self, *inputs):
# TODO: support kwargs ?
# TODO: skip expressions which are independent and have no side effect
node2value = {}
for n, v in zip(self._inputs, inputs):
node2value[n] = v
for expr in self._exprs:
values = expr.interpret(*list(node2value[i] for i in expr.inputs))
for n, v in zip(expr.outputs, values):
node2value[n] = v
if values is not None:
for n, v in zip(expr.outputs, values):
node2value[n] = v
return list(node2value[i] for i in self._outputs)

def __repr__(self):
@@ -109,7 +243,8 @@ class InternalGraph:


def _get_meth_name(obj, func):
for cls in type(obj).mro():
tp = obj if isinstance(obj, type) else type(obj)
for cls in tp.mro():
for k, v in cls.__dict__.items():
if v == func:
return k
@@ -131,15 +266,31 @@ def _wrapped_function(orig_func):
meth_name = _get_meth_name(args[0], wrapped_fn)
if meth_name:
self = inputs[0]
call_node = CallMethod.make(NodeMixin.get(self), meth_name)
if meth_name == "__new__":
if all([not isinstance(i, RawTensor) for i in inputs]):
# only trace Tensor.__new__() when there are tensors in args
set_module_tracing()
return orig_func(*args, **kwargs)
if isinstance(args[1], RawTensor):
node = NodeMixin.get(inputs[1])
inputs[1] = copy.copy(inputs[1])
# copy inputs[1] to avoid tensor and Tensor(tensor) share same m_tensor, which will cause they have same _NodeMixin__node in tracing.
NodeMixin.wrap_safe(inputs[1], node)
args, kwargs = tree_def.unflatten(inputs)
call_node = CallMethod.make(self, meth_name)
else:
call_node = CallMethod.make(NodeMixin.get(self), meth_name)
call_node.add_inputs(inputs[1:])
else:
call_node = CallFunction.make(orig_func)

call_node.add_inputs(inputs)
call_node.add_inputs(inputs)

call_node.arg_def = tree_def
outputs = orig_func(*args, **kwargs)
call_node.add_outputs(outputs)
if meth_name == "__new__":
call_node.add_outputs(outputs, False)
else:
call_node.add_outputs(outputs)
set_module_tracing()
return outputs
return orig_func(*args, **kwargs)
@@ -197,13 +348,14 @@ class TracedModuleBuilder(NodeMixin):
mark_constant(i)
callnode = CallMethod.make(NodeMixin.get(self))

callnode.add_inputs(inputs)
callnode.add_inputs(inputs[1:])

callnode.arg_def = tree_def

if self._is_builtin:
unset_module_tracing()
outputs = self._mod(*args, **kwargs)
rst = self._mod(*args, **kwargs)
outputs, out_def = tree_flatten(rst, leaf_type=_leaf_type, is_leaf=_is_leaf)
set_module_tracing()
if self._is_builtin:
self._body = None
@@ -215,14 +367,13 @@ class TracedModuleBuilder(NodeMixin):
NodeMixin.wrap_safe(
self, Input.make("self", NodeMixin.get_wrapped_type(self))
)
origin_inp_node = [NodeMixin.get(i, None) for i in inputs[1:]]
# prepare args and kwargs for inner graph
def wrap(x):
wrapped = copy.copy(x) # FIXME
NodeMixin.wrap(
wrapped,
lambda: Input.make(type=NodeMixin.get_wrapped_type(wrapped)),
x, lambda: Input.make(type=NodeMixin.get_wrapped_type(x)),
)
return wrapped
return x

args = [self]
for i in inputs[1:]:
@@ -231,21 +382,25 @@ class TracedModuleBuilder(NodeMixin):
active_module_tracer().patcher.auto_patch(
getattr(getattr(self._mod, "forward", self._mod), "__globals__", {})
)
outputs = type(self._mod).forward(*args, **kwargs)
rst = type(self._mod).forward(*args, **kwargs)
outputs, out_def = tree_flatten(rst, leaf_type=_leaf_type, is_leaf=_is_leaf)
for i in (
outputs if isinstance(outputs, collections.abc.Sequence) else (outputs,)
):
active_module_tracer().current_scope().add_output(NodeMixin.get(i))

NodeMixin.wrap_safe(self, orig_self)
for arg, node in zip(inputs[1:], origin_inp_node):
if node:
NodeMixin.wrap_safe(arg, node)
active_module_tracer().pop_scope()

# rebind output to outer graph
callnode.add_outputs(outputs)
self_node = NodeMixin.get(self)
self_node.argdef_graph_map[callnode.arg_def] = self._body
return outputs
self_node.argdef_outdef_map[callnode.arg_def] = out_def
return rst

def __getattr__(self, name):
if name not in self._mod.__dict__:
@@ -268,20 +423,29 @@ class TracedModuleBuilder(NodeMixin):
return super().__getattribute__(name)
else:
wrapped = super().__getattribute__(name)
if name in self._mod.__dict__ and not NodeMixin.get(wrapped, None):
assert not self._is_builtin
NodeMixin.wrap(
wrapped,
lambda: GetAttr.make(
if name in self._mod.__dict__:
if not NodeMixin.get(wrapped, None):
assert not self._is_builtin
NodeMixin.wrap(
wrapped,
lambda: GetAttr.make(
NodeMixin.get(self),
name,
type=NodeMixin.get_wrapped_type(wrapped),
),
)
else:
node = NodeMixin.get(wrapped)
expr = GetAttr.make(
NodeMixin.get(self),
name,
type=NodeMixin.get_wrapped_type(wrapped),
),
)
).expr
expr.outputs[0] = node
return wrapped


class _expr_list:
class _expr_iter:
def __init__(self, graph: InternalGraph):
self.graph = graph

@@ -295,6 +459,59 @@ class _expr_list:
yield expr


class ExprFilter:
def __init__(self, expr_iter: Iterable):
self._iter = expr_iter

def __iter__(self):
return iter(self._iter)

def call_function(self, func):
return ExprFilterCallFunction(self, func)

def call_method(self, method):
return ExprFilterCallMethod(self, method)

def as_list(self):
return list(self)

def as_dict(self):
raise NotImplementedError("need key")

def as_unique(self):
(expr,) = self
return expr

def as_count(self):
return sum(1 for _ in self)


class ExprFilterCallFunction(ExprFilter):
def __init__(self, expr_iter, func: Callable = None):
super().__init__(expr_iter)
self.func = func

def __iter__(self):
for i in self._iter:
if not isinstance(i, CallFunction):
continue
if self.func is None or i.func == self.func:
yield i


class ExprFilterCallMethod(ExprFilter):
def __init__(self, expr_iter, method: str = None):
super().__init__(expr_iter)
self.method = method

def __iter__(self):
for i in self._iter:
if not isinstance(i, CallMethod):
continue
if self.method is None or i.method == self.method:
yield i


class TracedModule(Module):
"""
`TracedModule` is the Module created by tracing normal module. It owns a ModuleNode(m_node). `TracedModule` can not be called directly. It can be
@@ -312,10 +529,12 @@ class TracedModule(Module):
((self, *args), kwargs), _leaf_type, is_const_leaf=_is_const_leaf
)
assert treedef in self.m_node.argdef_graph_map
inputs = [i for i in inputs if isinstance(i, (Module, RawTensor))]
inputs = filter(
lambda i: isinstance(i, (Module, TracedModuleBuilder, RawTensor)), inputs
) # allow TracedModuleBuilder for retrace.
outputs = self.m_node.argdef_graph_map[treedef].interpret(*inputs)
if len(outputs) == 1:
return outputs[0]
out_def = self.m_node.argdef_outdef_map[treedef]
outputs = out_def.unflatten(outputs)
return outputs

@property
@@ -339,9 +558,8 @@ class TracedModule(Module):
if graph is None:
assert not isinstance(module, TracedModule)
const = Constant(module)
modulenode = const.outputs[0]
modulenode.module_type = type(module)
call.inputs[0] = modulenode
const.outputs[0] = call.inputs[0]
const.outputs[0].expr = const
return [const, call]
exprs = []
for expr in graph._exprs:
@@ -350,30 +568,41 @@ class TracedModule(Module):
if call and inp in graph._inputs:
inp_idx = graph._inputs.index(inp)
expr.inputs[idx] = call.inputs[inp_idx]
call.inputs[inp_idx].users.append(expr)
# replace outputs for submodule's expr
for idx, outp in enumerate(expr.outputs):
if call and outp in graph._outputs:
oup_idx = graph._outputs.index(outp)
expr.outputs[idx] = call.outputs[oup_idx]
call.outputs[oup_idx].expr = expr

if isinstance(expr, GetAttr):
# replace GetAttr with Constant
if isinstance(expr.outputs[0], TensorNode):
const = Constant(getattr(module, expr.name))
const.outputs = expr.outputs
const.outputs[0].expr = const
exprs.append(const)

elif isinstance(expr, CallMethod):
obj_node = expr.inputs[0]
if isinstance(obj_node, ModuleNode):
assert isinstance(expr.inputs[0].expr, GetAttr)
(obj,) = expr.inputs[0].expr.interpret(module)
exprs.extend(_flatten_subgraph(expr.graph, obj, expr))
pre_expr = expr.inputs[0].expr
if isinstance(pre_expr, GetAttr):
(obj,) = expr.inputs[0].expr.interpret(module)
exprs.extend(_flatten_subgraph(expr.graph, obj, expr))
else:
# module has been replaced.
assert isinstance(pre_expr, Constant)
else:
exprs.append(expr)
else:
exprs.append(expr)

if call is not None:
for i in call.inputs:
i.users.remove(call)

return exprs

new_module.graph._exprs = _flatten_subgraph(new_module.graph, new_module)
@@ -422,22 +651,26 @@ def trace_module(mod: Module, *args: Tensor, **kwargs: Tensor) -> TracedModule:
"""
assert active_module_tracer() is None
try:
use_sym_shape = set_symbolic_shape(True)
set_module_tracing()
set_active_module_tracer(module_tracer(_wrapped_function))

with active_module_tracer().patcher:
global_scope = InternalGraph()
active_module_tracer().push_scope(global_scope)

builder = TracedModuleBuilder(mod, True)
NodeMixin.wrap_safe(builder, Input.make("TopModule", ModuleNode))
inputs, _ = tree_flatten((args, kwargs))
inputs, _ = tree_flatten((args, kwargs), is_const_leaf=_is_const_leaf)
for _, i in enumerate(inputs):
NodeMixin.wrap_safe(
i, Input.make("arg_{}".format(_), NodeMixin.get_wrapped_type(i))
)
if isinstance(i, RawTensor):
NodeMixin.wrap_safe(
i, Input.make("arg_{}".format(_), NodeMixin.get_wrapped_type(i))
)
builder(*args, **kwargs)
active_module_tracer().pop_scope()
return builder.build()
finally:
set_symbolic_shape(use_sym_shape)
set_active_module_tracer(None)
unset_module_tracing()

+ 90
- 0
imperative/python/test/unit/traced_module/test_haoruitao.py View File

@@ -0,0 +1,90 @@
import io
import pickle

import numpy as np

import megengine.functional as F
import megengine.module as M
import megengine.utils.comp_graph_tools as cgtools
from megengine.core._trace_option import set_symbolic_shape
from megengine.experimental.traced_module import trace_module
from megengine.jit import trace

set_symbolic_shape(True)


class Main(M.Module):
def forward(self, x):
return x


class PreProcess(M.Module):
def __init__(self):
super().__init__()
self.I = F.ones((1,))
self.M = F.zeros((1,))

def forward(self, data, idx, roi):
N, H, W, C = data.shape
xmax = roi[:, 1, 0]
xmin = roi[:, 0, 0]
ymax = roi[:, 1, 1]
ymin = roi[:, 0, 1]
scale = F.maximum((xmax - xmin) / W, (ymax - ymin) / H)
I = F.broadcast_to(self.I, (N,))
M = F.broadcast_to(self.M, (N, 3, 3))
M[:, 0, 0] = scale
M[:, 0, 2] = xmin
M[:, 1, 1] = scale
M[:, 1, 2] = ymin
M[:, 2, 2] = I
resized = (
F.warp_perspective(
data, M, (H, W), mat_idx=idx, border_mode="CONSTANT", format="NHWC"
)
.transpose(0, 3, 1, 2)
.astype(np.float32)
)
return resized


class Net(M.Module):
def __init__(self, traced_module):
super().__init__()
self.pre_process = PreProcess()
self.traced_module = traced_module

def forward(self, data, idx, roi):
x = self.pre_process(data, idx, roi)
x = self.traced_module(x)
return x


def test_preprocess():
module = Main()
data = F.ones((1, 14, 8, 8), dtype=np.uint8)
traced_module = trace_module(module, data)
obj = pickle.dumps(traced_module)
traced_module = pickle.loads(obj)
module = Net(traced_module)
module.eval()
idx = F.zeros((1,), dtype=np.int32)
roi = F.ones((1, 2, 2), dtype=np.float32)
y = module(data, idx, roi)
traced_module = trace_module(module, data, idx, roi)
np.testing.assert_array_equal(traced_module(data, idx, roi), y)
func = trace(traced_module, capture_as_const=True)
np.testing.assert_array_equal(func(data, idx, roi), y)
model = io.BytesIO()
func.dump(model, arg_names=("data", "idx", "roi"))
model.seek(0)
infer_cg = cgtools.GraphInference(model)
np.testing.assert_allclose(
list(
infer_cg.run(
inp_dict={"data": data.numpy(), "idx": idx.numpy(), "roi": roi.numpy()}
).values()
)[0],
y,
atol=1e-6,
)

+ 113
- 0
imperative/python/test/unit/traced_module/test_modification.py View File

@@ -0,0 +1,113 @@
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import numpy as np

import megengine.functional as F
import megengine.module as M
from megengine.experimental.traced_module import trace_module
from megengine.experimental.traced_module.expr import CallFunction, GetAttr


class MyBlock(M.Module):
def __init__(self, in_channels=3, channels=3):
super(MyBlock, self).__init__()
self.conv1 = M.Conv2d(in_channels, channels, 3, 1, padding=1, bias=False)
self.bn1 = M.BatchNorm2d(channels)

def forward(self, x):
x = self.conv1(x)
x = self.bn1(x)
x = F.relu(x) + 1
return x


class MyModule(M.Module):
def __init__(self):
super(MyModule, self).__init__()
self.block0 = MyBlock()
self.block1 = MyBlock()

def forward(self, x):
x = self.block0(x)
x = self.block1(x)
return x


def _init_cls(cls):
module = cls()
x = F.ones((1, 3, 3, 3))
y = module(x)
traced_module = trace_module(module, x)
return traced_module, x, y


def _init_block():
return _init_cls(MyBlock)


def _init_module():
return _init_cls(MyModule)


def test_search():
traced_module, *_ = _init_block()
graph = traced_module.graph
relu_expr = graph.get_call_function(F.relu).as_unique()
assert isinstance(relu_expr, CallFunction) and relu_expr.func == F.relu


def test_insert():
traced_module, x, expect = _init_block()
graph = traced_module.graph
relu_node = graph.get_call_function(F.relu).as_unique().outputs
neg_node = graph.insert_call_function(F.neg, relu_node)
graph.replace_node({relu_node[0]: neg_node[0]})
graph.compile()
np.testing.assert_allclose(expect - 1, 1 - traced_module(x), atol=1e-6)


def test_delete():
traced_module, x, expect = _init_block()
graph = traced_module.graph
relu_expr = graph.get_call_function(F.relu).as_unique()
node = relu_expr.outputs
repl_node = relu_expr.inputs
graph.replace_node({node[0]: repl_node[0]})
graph.compile()
np.testing.assert_allclose(expect - 1, F.relu(traced_module(x) - 1), atol=1e-6)


def test_flatten():
traced_module, x, expect = _init_module()
traced_module = traced_module.flatten()
traced_module.graph.compile()
assert all(not isinstance(i, GetAttr) for i in traced_module.graph._exprs)
assert len(traced_module.graph._exprs) == 12


def test_extra_block():
class PostProcess(M.Module):
def forward(self, x):
return x * 2

class Net(M.Module):
def __init__(self, traced_module):
super().__init__()
self.post_process = PostProcess()
self.traced_module = traced_module

def forward(self, x):
x = self.traced_module(x)
x = self.post_process(x)
return x

traced_module, x, expect = _init_block()
module = Net(traced_module)
np.testing.assert_allclose(2 * expect, module(x), atol=1e-6)
traced_module = trace_module(module, x)
np.testing.assert_allclose(2 * expect, traced_module(x), atol=1e-6)

+ 52
- 0
imperative/python/test/unit/traced_module/test_serialization.py View File

@@ -0,0 +1,52 @@
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import pickle

import numpy as np

import megengine.functional as F
import megengine.module as M
from megengine import Tensor
from megengine.experimental.traced_module import trace_module
from megengine.module import Module


class MyBlock(Module):
def __init__(self, in_channels, channels):
super(MyBlock, self).__init__()
self.conv1 = M.Conv2d(in_channels, channels, 3, 1, padding=1, bias=False)
self.bn1 = M.BatchNorm2d(channels)

def forward(self, x):
x = self.conv1(x)
x = self.bn1(x)
x = F.relu(x) + 1
return x


class MyModule(Module):
def __init__(self):
super(MyModule, self).__init__()
self.block0 = MyBlock(8, 4)
self.block1 = MyBlock(4, 2)

def forward(self, x):
x = self.block0(x)
x = self.block1(x)
return x


def test_dump_and_load():
module = MyModule()
x = Tensor(np.ones((1, 8, 14, 14)))
expect = module(x)
traced_module = trace_module(module, x)
np.testing.assert_array_equal(expect, traced_module(x))
obj = pickle.dumps(traced_module)
pickle.loads(obj)
np.testing.assert_array_equal(expect, traced_module(x))

+ 42
- 0
imperative/python/test/unit/traced_module/test_trace_module.py View File

@@ -0,0 +1,42 @@
import numpy as np

from megengine import Tensor
from megengine.experimental.traced_module import trace_module
from megengine.module import Module as M


class MyModule1(M):
def forward(self, x):
y = Tensor(x)
y += 1
x = x + 2
return x, y


class MyModule2(M):
def forward(self, x):
y = Tensor([1, x, 1])
y += 1
x = x + 2
return x, y


def test_trace_module():

x = Tensor(1)
m1 = MyModule1()
tm1 = trace_module(m1, x)

m2 = MyModule2()
tm2 = trace_module(m2, x)
inp = Tensor(2)
gt = m1(inp)
output = tm1(inp)
for a, b in zip(output, gt):
np.testing.assert_equal(a.numpy(), b.numpy())

gt1 = m2(inp)
output1 = tm2(inp)

for a, b in zip(output1, gt1):
np.testing.assert_equal(a.numpy(), b.numpy())

+ 94
- 0
imperative/python/test/unit/traced_module/test_wujianan.py View File

@@ -0,0 +1,94 @@
import io
import pickle

import numpy as np

import megengine as mge
import megengine.functional as F
import megengine.module as M
import megengine.utils.comp_graph_tools as cgtools
from megengine.core._trace_option import set_symbolic_shape
from megengine.experimental.traced_module import trace_module
from megengine.jit import trace

set_symbolic_shape(True)


class Main(M.Module):
def forward(self, x):
return x["data"]


class PreProcess(M.Module):
def __init__(self):
super().__init__()
self.A = F.zeros((1,))
self.I = F.ones((1,))
self.bb_out = mge.tensor(
np.array([[[0, 0], [160, 0], [160, 48], [0, 48]]], dtype="float32")
)

def forward(self, data, quad):
"""
data: (1, 3, 48, 160)
quad: (1, 4, 2)
"""
N = quad.shape[0]
dst = F.repeat(self.bb_out, N, axis=0).reshape(-1, 4, 2)
I = F.broadcast_to(self.I, quad.shape)
A = F.broadcast_to(self.A, (N, 8, 8))
A[:, 0:4, 0:2] = quad
A[:, 4:8, 5:6] = I[:, :, 0:1]
A[:, 0:4, 6:8] = -quad * dst[:, :, 0:1]
A[:, 4:8, 3:5] = quad
A[:, 0:4, 2:3] = I[:, :, 0:1]
A[:, 4:8, 6:8] = -quad * dst[:, :, 1:2]
B = dst.transpose(0, 2, 1).reshape(-1, 8, 1)
M = F.concat([F.matmul(F.matinv(A), B)[:, :, 0], I[:, 0:1, 0]], axis=1).reshape(
-1, 3, 3
)
new_data = F.warp_perspective(data, M, (48, 160)) # (N, 3, 48, 160)
return {"data": new_data}


class Net(M.Module):
def __init__(self, traced_module):
super().__init__()
self.pre_process = PreProcess()
self.traced_module = traced_module

def forward(self, data, quad):
x = self.pre_process(data, quad)
x = self.traced_module(x)
return x


def test_preprocess():
batch_size = 2
module = Main()
data = mge.tensor(
np.random.randint(0, 256, size=(batch_size, 3, 48, 160)), dtype=np.float32
)
traced_module = trace_module(module, {"data": data})
obj = pickle.dumps(traced_module)
traced_module = pickle.loads(obj)
module = Net(traced_module)
module.eval()
quad = mge.tensor(np.random.normal(size=(batch_size, 4, 2)), dtype=np.float32)
expect = module(data, quad)
traced_module = trace_module(module, data, quad)
actual = traced_module(data, quad)
for i, j in zip(expect, actual):
np.testing.assert_array_equal(i, j)
func = trace(traced_module, capture_as_const=True)
actual = func(data, quad)
for i, j in zip(expect, actual):
np.testing.assert_array_equal(i, j)
model = io.BytesIO()
func.dump(model, arg_names=("data", "quad"))
model.seek(0)
infer_cg = cgtools.GraphInference(model)
actual = list(
infer_cg.run(inp_dict={"data": data.numpy(), "quad": quad.numpy()}).values()
)[0]
np.testing.assert_allclose(expect, actual)

Loading…
Cancel
Save