You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

traced_module.py 41 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156
  1. # -*- coding: utf-8 -*-
  2. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  3. #
  4. # Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  5. #
  6. # Unless required by applicable law or agreed to in writing,
  7. # software distributed under the License is distributed on an
  8. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  9. import collections
  10. import copy
  11. import functools
  12. import inspect
  13. import weakref
  14. from inspect import getmembers, isclass, ismethod
  15. from typing import Callable, Dict, Iterable, List, Optional, Sequence, Type, Union
  16. from ... import functional as F
  17. from ... import get_logger
  18. from ... import module as M
  19. from ...core._imperative_rt.core2 import Tensor as RawTensor
  20. from ...core._imperative_rt.core2 import (
  21. is_tracing_module,
  22. set_module_tracing,
  23. unset_module_tracing,
  24. )
  25. from ...core._trace_option import set_symbolic_shape
  26. from ...core.tensor.array_method import ArrayMethodMixin
  27. from ...module import Module
  28. from ...tensor import Tensor
  29. from .expr import Apply, CallFunction, CallMethod, Constant, Expr, GetAttr, Input
  30. from .module_tracer import (
  31. Patcher,
  32. active_module_tracer,
  33. module_tracer,
  34. set_active_module_tracer,
  35. )
  36. from .node import ModuleNode, Node, NodeMixin, TensorNode
  37. from .pytree import tree_flatten
  38. logger = get_logger(__name__)
  39. def _leaf_type(node):
  40. if isinstance(node, (RawTensor, TensorNode)):
  41. return (Tensor, TensorNode)
  42. elif isinstance(node, (NodeMixin, Module, ModuleNode)):
  43. return (Module, ModuleNode, NodeMixin)
  44. else:
  45. return type(node)
  46. def _is_leaf(node):
  47. assert isinstance(node, RawTensor), "doesn't support {} in return values".format(
  48. type(node)
  49. )
  50. return isinstance(node, RawTensor)
  51. def _is_const_leaf(node):
  52. if isinstance(node, (RawTensor, NodeMixin, Module)):
  53. return False
  54. return True
  55. def wrap_tensors(tensors: Tensor, nodes: TensorNode):
  56. inp_tensors = copy.deepcopy(tensors)
  57. inp_tensors, inp_def_v = tree_flatten(
  58. inp_tensors, leaf_type=_leaf_type, is_const_leaf=_is_const_leaf
  59. )
  60. inp_nodes, inp_def_n = tree_flatten(
  61. nodes, leaf_type=_leaf_type, is_const_leaf=_is_const_leaf
  62. )
  63. for v, n in zip(inp_tensors, inp_nodes):
  64. if isinstance(n, TensorNode) and isinstance(v, Tensor):
  65. NodeMixin.wrap_safe(v, n)
  66. return inp_def_v.unflatten(inp_tensors)
class _InsertExprs:
    """Context manager that records newly traced exprs into ``graph``.

    While the context is active, module tracing is enabled and every traced
    expr is captured into a scratch ``InternalGraph``; on exit the captured
    exprs are spliced into ``graph._exprs`` — after/before ``expr`` when one
    is given, otherwise at the end/front of the graph.
    """

    def __init__(self, graph, expr: Optional[Expr] = None, after: bool = True):
        self.graph = graph
        # Scratch scope that collects exprs created inside the context.
        self.global_scope = InternalGraph()
        self.expr = expr
        self.after = after

    def __enter__(self):
        # Remember previous symbolic-shape setting so __exit__ can restore it.
        self.use_sym_shape = set_symbolic_shape(True)
        set_module_tracing()
        assert active_module_tracer() is None
        set_active_module_tracer(module_tracer(_wrapped_function))
        active_module_tracer().patcher.__enter__()
        active_module_tracer().push_scope(self.global_scope)

    def __exit__(self, ty, va, tr):
        # Tear down tracing state first, then splice the captured exprs in.
        set_symbolic_shape(self.use_sym_shape)
        unset_module_tracing()
        active_module_tracer().patcher.__exit__(ty, va, tr)
        set_active_module_tracer(None)
        # Default insert position: end of graph if ``after`` else the front.
        index = len(self.graph._exprs) if self.after else 0
        if self.expr is not None:
            index = self.graph._exprs.index(self.expr)
            if self.after:
                index += 1
        for expr in self.global_scope._exprs:
            self.graph._exprs.insert(index, expr)
            index += 1
class InternalGraph:
    """
    ``InternalGraph`` is a graph consist of ``Node`` and ``Expr``, it is used to represent the execution procedure of Module's forward method.

    Attributes:
        _exprs: List of Exprs in order of execution
        _inputs: Input Nodes of InternalGraph
        _outputs: Output Nodes of InternalGraph
    """

    _exprs = None  # type: List[Expr]
    _inputs = None  # type: List[Node]
    _outputs = None  # type: List[Node]

    def __init__(self):
        self._exprs = []
        self._inputs = []
        self._outputs = []

    def insert(self, expr):
        """Append ``expr`` to the end of the execution list."""
        self._exprs.append(expr)

    @property
    def inputs(self):
        return self._inputs

    @property
    def outputs(self):
        return self._outputs

    @property
    def expr_filter(self):
        # Lazily iterates every expr, recursing into sub-module graphs.
        return ExprFilter(_expr_iter(self))

    @property
    def node_filter(self):
        # Lazily iterates every unique node touched by this graph.
        return NodeFilter(_node_iter(self))

    def get_function_by_type(self, func: Callable = None):
        """Return a filter over CallFunction exprs invoking ``func``."""
        return self.expr_filter.call_function(func)

    def get_method_by_type(self, method: str = None):
        """Return a filter over CallMethod exprs invoking ``method``."""
        return self.expr_filter.call_method(method)

    def get_expr_by_id(self, expr_id: List[int] = None):
        """Return a filter over exprs whose id is in ``expr_id``."""
        return self.expr_filter.expr_id(expr_id)

    def get_module_by_type(self, module_cls: Module):
        """Return a filter over ModuleNodes owning a ``module_cls`` instance."""
        assert issubclass(module_cls, Module)
        return self.node_filter.type(module_cls, ModuleNode)

    def get_node_by_id(self, node_id: List[int] = None):
        """Return a filter over nodes whose id is in ``node_id``."""
        return self.node_filter.node_id(node_id)

    def add_input(self, i):
        self._inputs.append(i)

    def add_output(self, o):
        self._outputs.append(o)

    def _replace_inputs_outputs(self, repl_dict):
        # Redirect graph boundary nodes and all expr edges through repl_dict.
        for node, repl_node in repl_dict.items():
            assert node in self._inputs or node in self._outputs
            # Merge the users of the replaced node into the replacement.
            for i in node.users:
                if i not in repl_node.users:
                    repl_node.users.append(i)
        for idx, i in enumerate(self._inputs):
            if i in repl_dict:
                self._inputs[idx] = repl_dict[i]
        for idx, o in enumerate(self._outputs):
            if o in repl_dict:
                self._outputs[idx] = repl_dict[o]
        for expr in self._exprs:
            for idx, i in enumerate(expr.inputs):
                if i in repl_dict:
                    expr.inputs[idx] = repl_dict[i]
            for idx, o in enumerate(expr.outputs):
                if o in repl_dict:
                    expr.outputs[idx] = repl_dict[o]
                    # A replaced output is now produced by this expr.
                    expr.outputs[idx].expr = expr

    def get_dep_exprs(self, nodes: Sequence[Node]) -> List[Expr]:
        """Return every Expr that (transitively) produces ``nodes``.

        Depth-first walk backwards along expr inputs; result is in visit
        order, deduplicated.
        """
        if not isinstance(nodes, Sequence):
            nodes = (nodes,)
        ret = list()
        queue = list(nodes)
        visited_queue = list()
        while queue:
            node = queue.pop()
            visited_queue.append(node)
            expr = node.expr
            if expr not in ret:
                ret.append(expr)
            for i in expr.inputs:
                if i not in queue and i not in visited_queue:
                    queue.append(i)
        return ret

    def reset_inputs(self, *args, **kwargs):
        """Replace the graph's formal inputs with nodes matching ``args``/``kwargs``.

        Only valid on the top-level graph.  Also rewires every ``__call__``
        CallMethod expr that invokes this module, and re-keys the module's
        argdef maps to the new tree_def.  Returns the new formal input nodes
        (excluding the module node itself).
        """
        forma_mnode = self.inputs[0]
        actual_mnodes = forma_mnode.actual_mnode
        call_nodes = []
        for n in actual_mnodes:
            for c_expr in n.users:
                if isinstance(c_expr, CallMethod) and c_expr.method == "__call__":
                    call_nodes.append((c_expr, n))
        moudle = forma_mnode.owner
        assert moudle._is_top, "reset_inputs only support the top-level graph"
        inputs, tree_def = tree_flatten(
            ((moudle, *args), kwargs),
            leaf_type=_leaf_type,
            is_const_leaf=_is_const_leaf,
        )

        def create_node(val: Tensor):
            # Fresh formal input node carrying the sample tensor's metadata.
            node = Input(type=TensorNode).outputs[0]
            node.shape = val.shape
            node.dtype = val.dtype
            return node

        formal_node_inputs = [
            forma_mnode,
        ]
        org_argdef = list(moudle.argdef_graph_map.keys())[0]
        if call_nodes:
            org_argdef = call_nodes[0][0].arg_def
        for v in inputs[1:]:
            assert isinstance(v, RawTensor)
            formal_node_inputs.append(create_node(v))
        actual_nodes = []
        for e, n in call_nodes:
            e.arg_def = tree_def
            actual_node_inputs = [
                n,
            ]
            for v in inputs[1:]:
                actual_node_inputs.append(create_node(v))
            for org_n in e.inputs:
                # NOTE(review): ``users`` is appended to as a list elsewhere,
                # but list.pop expects an index — ``users.pop(e)`` looks like it
                # should be ``users.remove(e)``; confirm against Node's API.
                org_n.users.pop(e)
            e.inputs[:] = actual_node_inputs
            e.const_val = []
            actual_nodes.append(actual_node_inputs[1:])
        self._inputs[:] = formal_node_inputs
        # Re-key the module's maps from the old argdef to the new one.
        moudle.argdef_graph_map[tree_def] = moudle.argdef_graph_map.pop(org_argdef)
        moudle.argdef_outdef_map[tree_def] = moudle.argdef_outdef_map.pop(org_argdef)
        # return formal_node_inputs[1:], actual_nodes
        return formal_node_inputs[1:]

    def add_input_node(self, shape, dtype="float32"):
        """Append one extra tensor input to the top-level graph.

        A formal input node is added to this graph, and a matching actual
        node is appended to every ``__call__`` expr invoking the module.
        Returns the new formal input node.
        """
        forma_mnode = self.inputs[0]
        actual_mnodes = forma_mnode.actual_mnode
        moudle = forma_mnode.owner
        assert moudle._is_top, "add_input_node only support the top-level graph"
        call_nodes = []
        for n in actual_mnodes:
            for c_expr in n.users:
                if isinstance(c_expr, CallMethod) and c_expr.method == "__call__":
                    call_nodes.append(c_expr)

        def create_node(is_input: bool = True):
            # Formal inputs get a real Input expr; actual ones are detached.
            if is_input:
                node = Input(type=TensorNode).outputs[0]
            else:
                node = TensorNode(expr=None)
            node.shape = shape
            node.dtype = dtype
            return node

        org_argdef = list(moudle.argdef_graph_map.keys())[0]
        if call_nodes:
            org_argdef = call_nodes[0].arg_def
        args, kwargs = org_argdef.unflatten(self._inputs)
        formal_inp_node = create_node(True)
        inputs, tree_def = tree_flatten(
            ((*args, formal_inp_node), kwargs),
            leaf_type=_leaf_type,
            is_const_leaf=lambda x: not isinstance(x, (TensorNode, ModuleNode)),
        )
        self._inputs[:] = inputs[:]
        actual_inp_nodes = []
        for e in call_nodes:
            args, kwargs = e.unflatten_args(e.inputs)
            args = args + (create_node(False),)
            inputs, tree_def = tree_flatten(
                (args, kwargs),
                leaf_type=_leaf_type,
                is_const_leaf=lambda x: not isinstance(x, (TensorNode, ModuleNode)),
            )
            e.inputs[:] = inputs[:]
            e.arg_def = tree_def
            actual_inp_nodes.append(args[-1])
        # Re-key the module's maps from the old argdef to the new one.
        moudle.argdef_graph_map[tree_def] = moudle.argdef_graph_map.pop(org_argdef)
        moudle.argdef_outdef_map[tree_def] = moudle.argdef_outdef_map.pop(org_argdef)
        # return formal_inp_node, actual_inp_nodes
        return formal_inp_node

    def reset_outputs(self, outputs):
        """Replace the graph's outputs with ``outputs`` (TensorNodes).

        Only valid on the top-level graph.  Every ``__call__`` expr invoking
        the module gets fresh output nodes mirroring ``outputs``.  Returns
        the per-call lists of newly created actual output nodes.
        """
        outputs, out_def = tree_flatten(
            outputs, leaf_type=_leaf_type, is_leaf=lambda x: isinstance(x, TensorNode),
        )
        forma_mnode = self.inputs[0]
        moudle = forma_mnode.owner
        assert moudle._is_top, "reset_outputs only support the top-level graph"
        actual_mnodes = forma_mnode.actual_mnode
        call_nodes = []
        for n in actual_mnodes:
            for c_expr in n.users:
                if isinstance(c_expr, CallMethod) and c_expr.method == "__call__":
                    call_nodes.append((c_expr))

        def create_node(val: TensorNode, expr: Expr):
            # Clone metadata from the formal output onto a call-site node.
            node = TensorNode(expr)
            node.shape = val.shape
            node.dtype = val.dtype
            return node

        tree_def = list(moudle.argdef_graph_map.keys())[0]
        if call_nodes:
            tree_def = call_nodes[0].arg_def
        actual_nodes = []
        for e in call_nodes:
            actual_node_outputs = []
            for v in outputs:
                actual_node_outputs.append(create_node(v, e))
            e.outputs[:] = actual_node_outputs
            e.out_def = out_def
            actual_nodes.append(actual_node_outputs)
        self._outputs[:] = outputs
        moudle.argdef_outdef_map[tree_def] = out_def
        return actual_nodes

    def add_output_node(self, node: TensorNode):
        """Append ``node`` as an extra output of the top-level graph.

        The new output is threaded through every ``__call__`` expr invoking
        the module.  Returns the newly created actual output nodes.
        """
        forma_mnode = self.inputs[0]
        moudle = forma_mnode.owner
        assert moudle._is_top, "add_output_node only support the top-level graph"
        actual_mnodes = forma_mnode.actual_mnode
        call_nodes = []
        for n in actual_mnodes:
            for c_expr in n.users:
                if isinstance(c_expr, CallMethod) and c_expr.method == "__call__":
                    call_nodes.append((c_expr))

        def create_node(val: TensorNode, expr: Expr):
            # Clone metadata from the formal output onto a call-site node.
            node = TensorNode(expr)
            node.shape = val.shape
            node.dtype = val.dtype
            return node

        tree_def = list(moudle.argdef_graph_map.keys())[0]
        if call_nodes:
            tree_def = call_nodes[0].arg_def
        org_out_def = moudle.argdef_outdef_map[tree_def]
        # Rebuild the output structure as (old_outputs, new_node).
        org_outs = org_out_def.unflatten(self._outputs)
        outputs, out_def = tree_flatten(
            (org_outs, node),
            leaf_type=_leaf_type,
            is_leaf=lambda x: isinstance(x, TensorNode),
        )
        self._outputs[:] = outputs
        actual_out_nodes = []
        for e in call_nodes:
            actual_node = create_node(node, e)
            org_outs = org_out_def.unflatten(e.outputs)
            outputs, out_def = tree_flatten(
                (org_outs, actual_node),
                leaf_type=_leaf_type,
                is_leaf=lambda x: isinstance(x, TensorNode),
            )
            e.outputs[:] = outputs
            e.out_def = out_def
            actual_out_nodes.append(actual_node)
        moudle.argdef_outdef_map[tree_def] = out_def
        return actual_out_nodes

    def insert_function(self, func: Callable, *args, **kwargs):
        """Trace ``func`` on fake inputs and splice its exprs into this graph.

        TensorNode arguments are stood in by zero tensors of the same
        shape/dtype.  New exprs are inserted after the latest producing expr
        of any input node (or at the front when none is in this graph).
        Returns the resulting output nodes (or None when ``func`` returns
        None).
        """
        assert isinstance(func, Callable)
        inp_nodes, inp_def = tree_flatten(
            (args, kwargs), leaf_type=_leaf_type, is_const_leaf=_is_const_leaf
        )
        insert_idx = -1
        for i in inp_nodes:
            if isinstance(i, TensorNode) and i.expr in self._exprs:
                insert_idx = max(insert_idx, self._exprs.index(i.expr))
        fake_inp_val = list(
            F.zeros(shape=i.shape, dtype=i.dtype) if isinstance(i, TensorNode) else i
            for i in inp_nodes
        )
        for v, n in zip(fake_inp_val, inp_nodes):
            if isinstance(n, TensorNode):
                NodeMixin.wrap_safe(v, n)
        fake_args, fake_kwargs = inp_def.unflatten(fake_inp_val)
        insert_point = self.insert_exprs_before()
        if insert_idx != -1:
            insert_point = self.insert_exprs_after(self._exprs[insert_idx])
        with insert_point:
            rst = func(*fake_args, **fake_kwargs)
        if rst is None:
            return None
        outputs, out_def = tree_flatten(rst, leaf_type=_leaf_type, is_leaf=_is_leaf)
        node_outputs = []
        for out in outputs:
            assert isinstance(out, RawTensor)
            node_outputs.append(NodeMixin.get(out, None))
        node_outputs = out_def.unflatten(node_outputs)
        return node_outputs

    def insert_exprs_after(self, expr: Optional[Expr] = None):
        """Context manager inserting traced exprs after ``expr`` (or at end)."""
        if expr is not None:
            assert expr.top_graph == self, "Expr to insert after is not in graph."
        return _InsertExprs(self, expr, after=True)

    def insert_exprs_before(self, expr: Optional[Expr] = None):
        """Context manager inserting traced exprs before ``expr`` (or at front)."""
        if expr is not None:
            assert expr.top_graph == self, "Expr to insert before is not in graph."
        return _InsertExprs(self, expr, after=False)

    def replace_node(self, repl_dict: Dict[Node, Node]):
        """Replace each key node by its value node in outputs and all users.

        A replacement that would create a cycle (the user expr also produces
        the replacement node) is skipped for that user.  ``repl_dict`` is
        consumed.
        """
        while repl_dict:
            node, repl_node = repl_dict.popitem()
            # check graph inputs and outputs
            assert node not in self.inputs, "Cannot replace inputs"
            for i, n in enumerate(self.outputs):
                if n is node:
                    self.outputs[i] = repl_node
            # update users of node and repl_node
            # update inputs of expr in node.users
            dep_exprs = self.get_dep_exprs(repl_node)
            i = 0
            while i < len(node.users):
                n = node.users[i]
                if n in dep_exprs:
                    logger.info("Find a loop: ignore this replacement once")
                    logger.info("node: %s" % node.__repr__())
                    logger.info("repl_node: %s" % repl_node.__repr__())
                    i += 1
                    continue
                repl_node.users.append(n)
                node.users.pop(i)
                idx = n.inputs.index(node)
                n.inputs[idx] = repl_node

    def compile(self):
        """
        Delete unused expr.
        """
        dep_exprs = self.get_dep_exprs(self.outputs)
        i = 0
        while i < len(self._exprs):
            expr = self._exprs[i]
            # Keep exprs the outputs depend on, and explicitly pinned ones.
            if expr in dep_exprs or expr._disable_remove:
                i += 1
                continue
            for n in expr.inputs:
                n.users.remove(expr)
            self._exprs.remove(expr)

    def interpret(self, *inputs):
        """Execute the graph on concrete ``inputs``; return output values."""
        node2value = {}
        for n, v in zip(self._inputs, inputs):
            node2value[n] = v
        for expr in self._exprs:
            values = expr.interpret(*list(node2value[i] for i in expr.inputs))
            if values is not None:
                for n, v in zip(expr.outputs, values):
                    node2value[n] = v
        return list(node2value[i] for i in self._outputs)

    def __repr__(self):
        return "InternalGraph ({}) {{\n\t{}\n\treturn {}\n}}".format(
            ", ".join(str(i) for i in self._inputs),
            "\n\t".join("{}".format(str(i)) for i in self._exprs),
            ", ".join(str(i) for i in self._outputs),
        )
  430. def _get_meth_name(obj, func):
  431. tp = obj if isinstance(obj, type) else type(obj)
  432. for cls in tp.mro():
  433. for k, v in cls.__dict__.items():
  434. if v == func:
  435. return k
  436. return None
def _wrapped_function(orig_func):
    """Wrap ``orig_func`` so calls made during module tracing are recorded.

    When tracing is active the wrapper: flattens the call arguments, captures
    untracked tensors as constants, creates a CallMethod/CallFunction expr in
    the current scope, runs ``orig_func`` with tracing disabled, records its
    outputs, and re-enables tracing.  Outside tracing it is a transparent
    pass-through.
    """

    @functools.wraps(orig_func)
    def wrapped_fn(*args, **kwargs):
        if is_tracing_module():
            # Disable tracing while building the expr to avoid re-entrancy.
            unset_module_tracing()
            inputs, tree_def = tree_flatten(
                (args, kwargs), leaf_type=_leaf_type, is_const_leaf=_is_const_leaf
            )
            for i in inputs:
                # Tensors/modules without a node yet are captured as constants.
                if not NodeMixin.get(i, None):
                    if isinstance(i, (RawTensor, NodeMixin)):
                        NodeMixin.wrap_safe(i, Constant.make(i))
            # Non-None meth_name means orig_func is a method of args[0].
            meth_name = _get_meth_name(args[0], wrapped_fn) if args else None
            if meth_name:
                self = inputs[0]
                if meth_name == "__new__":
                    if all([not isinstance(i, RawTensor) for i in inputs]):
                        # only trace Tensor.__new__() when there are tensors in args
                        set_module_tracing()
                        return orig_func(*args, **kwargs)
                    if isinstance(args[1], RawTensor):
                        node = NodeMixin.get(inputs[1])
                        inputs[1] = copy.copy(inputs[1])
                        # copy inputs[1] to avoid tensor and Tensor(tensor) share same
                        # m_tensor, which will cause they have same _NodeMixin__node
                        # in tracing.
                        NodeMixin.wrap_safe(inputs[1], node)
                        args, kwargs = tree_def.unflatten(inputs)
                    call_node = CallMethod.make(self, meth_name)
                else:
                    call_node = CallMethod.make(NodeMixin.get(self), meth_name)
                call_node.add_inputs(inputs[1:])
            else:
                call_node = CallFunction.make(orig_func)
                call_node.add_inputs(inputs)
            call_node.arg_def = tree_def
            rst = orig_func(*args, **kwargs)
            if meth_name == "__setitem__":
                # __setitem__ mutates in place and returns None; record self.
                rst = self
            if rst is not None:
                outputs, out_def = tree_flatten(
                    rst, leaf_type=_leaf_type, is_leaf=_is_leaf
                )
                call_node.out_def = out_def
            else:
                outputs = None
            call_node.add_outputs(outputs)
            set_module_tracing()
            return rst
        return orig_func(*args, **kwargs)

    return wrapped_fn
class TracedModuleBuilder(NodeMixin):
    """Stand-in for a ``Module`` during tracing.

    Wraps a real module; calling the builder records a CallMethod expr and
    traces the module's ``forward`` into an ``InternalGraph`` (one graph per
    distinct argument tree-def).  ``build()`` converts this builder (and any
    nested builders) into a ``TracedModule``.
    """

    _mod = None  # type: Module
    _body = None  # type: InternalGraph
    _is_builtin = None  # type: bool
    _argdef_graph_map = None  # type: Dict[Treedef, "InternalGraph"]
    _argdef_outdef_map = None  # type: Dict[Treedef, Treedef]
    nodes = None

    # Names resolved on the builder itself; every other attribute access is
    # proxied to the wrapped module (see __getattr__/__getattribute__).
    __builder_attributes__ = [
        "_mod",
        "_body",
        "_NodeMixin__node",
        "_is_builtin",
        "build",
        "_argdef_graph_map",
        "_argdef_outdef_map",
        "nodes",
    ]

    def __init__(self, mod, is_top_module=False):
        super(TracedModuleBuilder, self).__init__()
        self._mod = mod
        self._body = None
        self._is_top = is_top_module
        self._is_builtin = module_tracer.is_builtin(mod)
        self._argdef_graph_map = {}
        self._argdef_outdef_map = {}
        self.nodes = set()

    def build(self):
        """Convert this builder into a TracedModule (builtins pass through)."""
        if self._is_builtin:
            # Builtin modules are kept as-is; just tag their nodes.
            for node in self.nodes:
                node.module_type = type(self._mod)
                # node._owner = weakref.ref(self._mod)
            return self._mod
        else:
            traced_module = TracedModule(
                self._is_top, self._argdef_graph_map, self._argdef_outdef_map
            )
            for _, g in self._argdef_graph_map.items():
                g.compile()
            # for node in self.nodes:
            #     node._owner = weakref.ref(traced_module)

            # Copy user attributes across, recursively building sub-builders.
            for k, v in self.__dict__.items():
                if k not in TracedModuleBuilder.__builder_attributes__:
                    if isinstance(v, TracedModuleBuilder):
                        v = v.build()
                    setattr(traced_module, k, v)
            return traced_module

    def _record_wrapped_nodes(self, node):
        # Called by the tracer whenever a node is created for this builder.
        self.nodes.add(node)

    def __call__(self, *args, **kwargs):
        assert isinstance(self._mod, Module)

        # prepare args and kwargs for inner graph
        def mark_constant(x):
            node = NodeMixin.get(x, None)
            if node is None:  # capture as constant
                NodeMixin.wrap(x, lambda: Constant.make(x))

        inputs, tree_def = tree_flatten(
            ((self, *args), kwargs), leaf_type=_leaf_type, is_const_leaf=_is_const_leaf
        )
        for i in inputs:
            mark_constant(i)
        callnode = CallMethod.make(NodeMixin.get(self))
        callnode.add_inputs(inputs[1:])
        callnode.arg_def = tree_def
        if self._is_builtin:
            # Builtins run eagerly (untraced); only the call itself is recorded.
            unset_module_tracing()
            rst = self._mod(*args, **kwargs)
            outputs, out_def = tree_flatten(rst, leaf_type=_leaf_type, is_leaf=_is_leaf)
            set_module_tracing()
        if self._is_builtin:
            self._body = None
        else:
            # Reuse the existing self node when this argdef was traced before.
            self_node = None
            if tree_def in self._argdef_graph_map:
                self_node = self._argdef_graph_map[tree_def].inputs[0]
            self._body = InternalGraph()
            active_module_tracer().push_scope(self._body)
            # rebind self to new input node
            orig_self = NodeMixin.get(self)
            if self_node:
                NodeMixin.wrap_safe(self, self_node)
                active_module_tracer().current_scope().add_input(self_node)
            else:
                NodeMixin.wrap_safe(
                    self,
                    self_node
                    if self_node
                    else Input.make("self", NodeMixin.get_wrapped_type(self)),
                )
            # Remember outer-graph nodes so they can be restored afterwards.
            origin_inp_node = [NodeMixin.get(i, None) for i in inputs[1:]]

            # prepare args and kwargs for inner graph
            def wrap(x):
                if isinstance(x, (RawTensor, NodeMixin)):
                    NodeMixin.wrap(
                        x, lambda: Input.make(type=NodeMixin.get_wrapped_type(x)),
                    )
                return x

            args = [self]
            for i in inputs[1:]:
                args.append(wrap(i))
            args, kwargs = tree_def.unflatten(args)
            # Patch globals of the module's forward so nested calls are traced.
            active_module_tracer().patcher.auto_patch(
                getattr(getattr(self._mod, "forward", self._mod), "__globals__", {})
            )
            rst = type(self._mod).forward(*args, **kwargs)
            outputs, out_def = tree_flatten(rst, leaf_type=_leaf_type, is_leaf=_is_leaf)
            for i in (
                outputs if isinstance(outputs, collections.abc.Sequence) else (outputs,)
            ):
                active_module_tracer().current_scope().add_output(NodeMixin.get(i))
            NodeMixin.get(self, None).actual_mnode.append(orig_self)
            # Restore the outer-graph bindings of self and the inputs.
            NodeMixin.wrap_safe(self, orig_self)
            for arg, node in zip(inputs[1:], origin_inp_node):
                if node:
                    NodeMixin.wrap_safe(arg, node)
            active_module_tracer().pop_scope()
        # rebind output to outer graph
        callnode.out_def = out_def
        callnode.add_outputs(outputs)
        self._argdef_graph_map[callnode.arg_def] = self._body
        self._argdef_outdef_map[callnode.arg_def] = out_def
        return rst

    def __getattr__(self, name):
        # Fallback for names not found on the builder: proxy to the module.
        if name not in self._mod.__dict__:
            attr = getattr(type(self._mod), name).__get__(self, type(self))
        else:
            attr = getattr(self._mod, name)
            if isinstance(attr, Module):
                # Sub-modules are wrapped so their calls are traced too.
                attr = TracedModuleBuilder(attr)
                setattr(self, name, attr)
            NodeMixin.wrap(
                attr,
                lambda: GetAttr.make(
                    NodeMixin.get(self), name, type=NodeMixin.get_wrapped_type(attr)
                ),
            )
        return attr

    def __getattribute__(self, name):
        if name in TracedModuleBuilder.__builder_attributes__:
            return super().__getattribute__(name)
        else:
            wrapped = super().__getattribute__(name)
            if name in self._mod.__dict__:
                assert not self._is_builtin
                if isinstance(wrapped, (NodeMixin, RawTensor)):
                    # Record the attribute access as a GetAttr expr.
                    NodeMixin.wrap(
                        wrapped,
                        lambda: GetAttr.make(
                            NodeMixin.get(self),
                            name,
                            type=NodeMixin.get_wrapped_type(wrapped),
                        ),
                    )
            """
            else:
                node = NodeMixin.get(wrapped)
                expr = node.expr
                assert isinstance(expr, GetAttr)
                if expr not in active_module_tracer().current_scope()._exprs:
                    active_module_tracer().current_scope().insert(expr)
            """
            return wrapped
  647. class _expr_iter:
  648. def __init__(self, graph: InternalGraph):
  649. self.graph = graph
  650. def __iter__(self):
  651. for expr in self.graph._exprs:
  652. if isinstance(expr, CallMethod) and isinstance(expr.inputs[0], ModuleNode):
  653. yield expr
  654. if expr.graph is not None:
  655. yield from expr.graph.expr_filter
  656. else:
  657. yield expr
  658. class _node_iter:
  659. def __init__(self, graph: InternalGraph) -> None:
  660. nodes = []
  661. node_ids = set()
  662. for expr in graph.expr_filter:
  663. for n in expr.inputs + expr.outputs:
  664. if n._id in node_ids:
  665. continue
  666. nodes.append(n)
  667. node_ids.add(n._id)
  668. self.nodes = list(sorted(nodes, key=lambda x: x._id))
  669. def __iter__(self):
  670. for node in self.nodes:
  671. yield node
  672. class BaseFilter:
  673. def __init__(self, expr_iter: Iterable):
  674. self._iter = expr_iter
  675. def __iter__(self):
  676. return iter(self._iter)
  677. def as_list(self):
  678. return list(self)
  679. def as_dict(self):
  680. return collections.OrderedDict((i._id, i) for i in self)
  681. def as_unique(self):
  682. rst = self.as_list()
  683. assert len(rst) == 1, "{} elements found".format(len(rst))
  684. (expr,) = self
  685. return expr
  686. def as_count(self):
  687. return sum(1 for _ in self)
  688. class ExprFilter(BaseFilter):
  689. def call_function(self, func):
  690. return ExprFilterCallFunction(self, func)
  691. def call_method(self, method):
  692. return ExprFilterCallMethod(self, method)
  693. def expr_id(self, expr_id: List[int]):
  694. return ExprFilterExprId(self, expr_id)
  695. class NodeFilter(BaseFilter):
  696. def type(self, owner_type, node_type):
  697. return NodeFilterType(self, owner_type, node_type)
  698. def node_id(self, node_id: List[int]):
  699. return NodeFilterNodeId(self, node_id)
  700. class NodeFilterType(NodeFilter):
  701. def __init__(self, expr_iter, owner_type, node_type):
  702. super().__init__(expr_iter)
  703. self.owner_type = owner_type
  704. self.node_type = node_type
  705. def __iter__(self):
  706. for node in self._iter:
  707. if not isinstance(node, self.node_type):
  708. continue
  709. if not hasattr(node, "owner"):
  710. continue
  711. if isinstance(node.owner, self.owner_type):
  712. yield node
  713. class NodeFilterNodeId(NodeFilter):
  714. def __init__(self, expr_iter, node_id: List[int]):
  715. super().__init__(expr_iter)
  716. if not isinstance(node_id, Sequence):
  717. node_id = [node_id]
  718. self.node_id = node_id
  719. def __iter__(self):
  720. for node in self._iter:
  721. if node._id in self.node_id:
  722. yield node
  723. class ExprFilterCallFunction(ExprFilter):
  724. def __init__(self, expr_iter, func: Callable = None):
  725. super().__init__(expr_iter)
  726. self.func = func
  727. def __iter__(self):
  728. for expr in self._iter:
  729. if not isinstance(expr, CallFunction):
  730. continue
  731. if self.func is None or expr.func == self.func:
  732. yield expr
  733. class ExprFilterCallMethod(ExprFilter):
  734. def __init__(self, expr_iter, method: str = None):
  735. super().__init__(expr_iter)
  736. self.method = method
  737. def __iter__(self):
  738. for expr in self._iter:
  739. if not isinstance(expr, CallMethod):
  740. continue
  741. if self.method is None or expr.method == self.method:
  742. yield expr
  743. class ExprFilterExprId(ExprFilter):
  744. def __init__(self, expr_iter, expr_id: List[int]):
  745. super().__init__(expr_iter)
  746. if not isinstance(expr_id, Sequence):
  747. expr_id = [expr_id]
  748. self.expr_id = expr_id
  749. def __iter__(self):
  750. for expr in self._iter:
  751. if expr._id in self.expr_id:
  752. yield expr
class TracedModule(Module):
    """
    ``TracedModule`` is the Module created by tracing a normal module.

    It owns a map from the argdef (flattened args/kwargs tree structure) to an
    ``InternalGraph``.  The ``forward`` method looks up the graph matching the
    argdef of the actual inputs in ``argdef_graph_map`` and interprets it.
    """

    # m_node = None # type: ModuleNode
    argdef_graph_map = None
    argdef_outdef_map = None

    def __init__(self, is_top, argdef_graph_map, argdef_outdef_map):
        super(TracedModule, self).__init__()
        self.argdef_graph_map = argdef_graph_map
        self.argdef_outdef_map = argdef_outdef_map
        self._is_top = is_top

    def forward(self, *args, **kwargs):
        # Flatten (self, args, kwargs) into leaves plus a tree definition; the
        # treedef selects which traced graph to interpret.
        inputs, treedef = tree_flatten(
            ((self, *args), kwargs), _leaf_type, is_const_leaf=_is_const_leaf
        )
        assert treedef in self.argdef_graph_map
        inputs = filter(
            lambda i: isinstance(i, (Module, TracedModuleBuilder, RawTensor)), inputs
        )  # allow TracedModuleBuilder for retrace.
        outputs = self.argdef_graph_map[treedef].interpret(*inputs)
        # Rebuild the original output structure from the flat interpreter output.
        out_def = self.argdef_outdef_map[treedef]
        outputs = out_def.unflatten(outputs)
        return outputs

    @property
    def graph(self) -> InternalGraph:
        """The single traced graph of this module.

        Only valid when exactly one argdef was traced; for a top-level module
        the node/graph back-references are refreshed first.
        """
        if self._is_top:
            self._update_ref()
        assert len(self.argdef_graph_map) == 1
        return list(self.argdef_graph_map.values())[0]

    def _update_ref(self, actual_node_map: Union[Dict] = None):
        # Re-establish weak references from nodes/exprs to their owning graph
        # and from ModuleNodes to the module objects they stand for, then
        # recurse into traced sub-modules.  ``actual_node_map`` maps an argdef
        # to the caller-side ModuleNodes that actually invoked this module.
        for inp_def, graph in self.argdef_graph_map.items():
            for n in graph._inputs + graph.outputs:
                n._top_graph = weakref.ref(graph)
            # inputs[0] is the module's own node; it is owned by ``self``.
            graph._inputs[0]._owner = weakref.ref(self)
            graph._inputs[0].actual_mnode = []
            if actual_node_map is not None and inp_def in actual_node_map.keys():
                graph._inputs[0].actual_mnode = actual_node_map[inp_def]
            # node2obj maps each ModuleNode in this graph to its module object.
            node2obj = {}
            next_actual_node_map = collections.defaultdict(
                lambda: collections.defaultdict(list)
            )
            node2obj[graph._inputs[0]] = self
            for expr in graph._exprs:
                for n in expr.inputs + expr.outputs:
                    n._top_graph = weakref.ref(graph)
                expr._top_graph = weakref.ref(graph)
                # GetAttr producing a ModuleNode: resolve the sub-module object.
                if isinstance(expr, GetAttr) and isinstance(
                    expr.outputs[0], ModuleNode
                ):
                    obj = getattr(node2obj[expr.inputs[0]], expr.name)
                    expr.outputs[0]._owner = weakref.ref(obj)
                    node2obj[expr.outputs[0]] = obj
                # Constant holding a module: the value itself is the owner.
                if isinstance(expr, Constant) and isinstance(
                    expr.outputs[0], ModuleNode
                ):
                    obj = expr.value
                    expr.outputs[0]._owner = weakref.ref(obj)
                    node2obj[expr.outputs[0]] = obj
                # A sub-module call site: remember which ModuleNode invoked it
                # so the recursion below can pass it down as actual_mnode.
                if (
                    isinstance(expr, CallMethod)
                    and expr.method == "__call__"
                    and isinstance(expr.inputs[0], ModuleNode)
                ):
                    obj = node2obj[expr.inputs[0]]
                    if expr.arg_def is not None:
                        next_actual_node_map[obj][expr.arg_def].append(expr.inputs[0])
            # Recurse into every traced sub-module found in this graph.
            for obj in node2obj.values():
                if obj is self:
                    continue
                mnode_map = None
                if obj in next_actual_node_map.keys():
                    mnode_map = next_actual_node_map[obj]
                if isinstance(obj, TracedModule):
                    obj._update_ref(mnode_map)

    def flatten(self):
        """
        Get a new module, which eliminates ``GetAttr`` and has no hierarchy.

        :return: :class:`TracedModule`
        """
        new_module = copy.deepcopy(self)

        def _flatten_subgraph(graph, module, call=None):
            # Inline ``graph`` (the traced graph of ``module``) into its
            # caller, rewiring the ``call`` expr's inputs/outputs onto the
            # subgraph's own nodes.  ``graph is None`` means ``module`` is a
            # builtin (untraced) module: keep the call, but emit a Constant
            # that produces the module object for it.
            if graph is None:
                assert not isinstance(module, TracedModule)
                const = Constant(module)
                const.outputs[0] = call.inputs[0]
                const.outputs[0].expr = const
                return [const, call]
            if call is not None:
                # Copy so the shared per-argdef graph is not mutated in place.
                graph = copy.deepcopy(graph)
            exprs = []
            node2obj = {}
            node2obj[graph._inputs[0]] = module
            if call:
                node2obj[call.inputs[0]] = module
                # Map subgraph inputs to the caller's actual input nodes.
                repl_dict = dict(zip(graph._inputs, call.inputs))
                for ind, out in enumerate(graph.outputs):
                    if isinstance(out.expr, Input):
                        # Output is a pass-through of an input: forward the
                        # caller's input directly to every user of the call's
                        # corresponding output.
                        assert out in repl_dict
                        call_out = call.outputs[ind]
                        for expr in call.outputs[ind].users:
                            for index, inp in enumerate(expr.inputs):
                                if inp is call_out:
                                    expr.inputs[index] = repl_dict[out]
                        continue
                    repl_dict[out] = call.outputs[ind]
                graph._replace_inputs_outputs(repl_dict)
            for expr in graph._exprs:
                if isinstance(expr, GetAttr):
                    # replace GetAttr with Constant
                    if isinstance(expr.outputs[0], TensorNode):
                        const = Constant(getattr(node2obj[expr.inputs[0]], expr.name))
                        const.outputs = expr.outputs
                        const.outputs[0].expr = const
                        exprs.append(const)
                    elif isinstance(expr.outputs[0], ModuleNode):
                        # Record the sub-module; the GetAttr itself is dropped.
                        node2obj[expr.outputs[0]] = getattr(
                            node2obj[expr.inputs[0]], expr.name
                        )
                elif isinstance(expr, CallMethod):
                    obj_node = expr.inputs[0]
                    if isinstance(obj_node, ModuleNode):
                        pre_expr = expr.inputs[0].expr
                        if isinstance(pre_expr, GetAttr):
                            # Resolve the sub-module object and recursively
                            # inline its traced graph (None for builtins).
                            (obj,) = pre_expr.interpret(node2obj[pre_expr.inputs[0]])
                            expr_graph = (
                                obj.argdef_graph_map[expr.arg_def]
                                if hasattr(obj, "argdef_graph_map")
                                else None
                            )
                            exprs.extend(_flatten_subgraph(expr_graph, obj, expr))
                        else:
                            # module has been replaced.
                            assert isinstance(pre_expr, Constant)
                            exprs.append(expr)
                    else:
                        exprs.append(expr)
                else:
                    exprs.append(expr)
            if call is not None:
                # The call expr has been inlined away; detach it from its inputs.
                for i in call.inputs:
                    i.users.remove(call)
            return exprs

        new_module.graph._exprs = _flatten_subgraph(new_module.graph, new_module)
        return new_module

    def __getstate__(self):
        # Strip Module class-level attributes so pickling keeps only the
        # instance state.
        d = self.__dict__
        for k in Module.__dict__:
            d.pop(k, None)
        return d
  903. def cpp_apply_module_trace(opdef, *args):
  904. return Apply.apply_module_trace_hook(opdef, *args)
  905. def register_as_builtin(mod_cls: Type[Module]) -> None:
  906. """
  907. Registers class ``mod_cls`` (subclass of megengine.module.Module) as builtin module.
  908. param mod_cls: the Module class which will be threated as builtin module in tracing
  909. """
  910. module_tracer.register_as_builtin(mod_cls)
  911. def wrap(func: Union[Callable]):
  912. assert callable(func)
  913. if hasattr(func, "__code__"):
  914. assert not isinstance(func, str)
  915. fn_name = func.__code__.co_name
  916. currentframe = inspect.currentframe()
  917. assert currentframe is not None
  918. f = currentframe.f_back
  919. assert f is not None
  920. if f.f_code.co_name != "<module>":
  921. raise NotImplementedError("wrap must be called at the top level of a module")
  922. Patcher._builtin_functions.append((f.f_globals, fn_name))
  923. return func
  924. def _register_all_builtin_module():
  925. for sub_mod in [M, M.qat, M.quantized]:
  926. for m in getmembers(sub_mod):
  927. if (
  928. isclass(m[1])
  929. and issubclass(m[1], M.Module)
  930. and m[1] is not M.Sequential
  931. ):
  932. module_tracer.register_as_builtin(m[1])
def trace_module(mod: Module, *args: Tensor, **kwargs: Tensor) -> TracedModule:
    """
    Traces module ``mod`` and returns corresponding TracedModule.

    param mod: the module will be converted to TracedModule
    param input: the positional arguments passed to forward method of ``mod``
    param kwargs: the keyword arguments passed to forward method of ``mod``
    """
    # Tracing is not re-entrant: a tracer must not already be active.
    assert active_module_tracer() is None
    try:
        # Force symbolic shape during tracing; restored in ``finally``.
        use_sym_shape = set_symbolic_shape(True)
        set_module_tracing()
        set_active_module_tracer(module_tracer(_wrapped_function))
        with active_module_tracer().patcher:
            global_scope = InternalGraph()
            active_module_tracer().push_scope(global_scope)
            builder = TracedModuleBuilder(mod, True)
            # The top module itself becomes the first graph input.
            NodeMixin.wrap_safe(builder, Input.make("TopModule", ModuleNode))
            inputs, _ = tree_flatten((args, kwargs), is_const_leaf=_is_const_leaf)
            for _, i in enumerate(inputs):
                # NOTE(review): the assert already requires ``Tensor``, which
                # appears to make the RawTensor check below always true for
                # megengine tensors — confirm whether non-RawTensor leaves
                # were ever intended to pass through here.
                assert isinstance(i, Tensor), "not support "
                if isinstance(i, RawTensor):
                    # Wrap each tensor argument as a graph Input node.
                    NodeMixin.wrap_safe(
                        i, Input.make("arg_{}".format(_), NodeMixin.get_wrapped_type(i))
                    )
            builder(*args, **kwargs)
            active_module_tracer().pop_scope()
            return builder.build()
    finally:
        # Always restore global tracer state, even if tracing failed.
        set_symbolic_shape(use_sym_shape)
        set_active_module_tracer(None)
        unset_module_tracing()

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台