You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

traced_module.py 64 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
7127812791280128112821283128412851286128712881289129012911292129312941295129612971298129913001301130213031304130513061307130813091310131113121313131413151316131713181319132013211322132313241325132613271328132913301331133213331334133513361337133813391340134113421343134413451346134713481349135013511352135313541355135613571358135913601361136213631364136513661367136813691370137113721373137413751376137713781379138013811382138313841385138613871388138913901391139213931394139513961397139813991400140114021403140414051406140714081409141014111412141314141415141614171418141914201421142214231424142514261427142814291430143114321433143414351436143714381439144014411442144314441445144614471448144914501451145214531454145514561457145814591460146114621463146414651466146714681469147014711472147314741475147614771478147914801481148214831484148514861487148814891490149114921493149414951496149714981499150015011502150315041505150615071508150915101511151215131514151515161517151815191520152115221523152415251526152715281529153015311532153315341535153615371538153915401541154215431544154515461547154815491550155115521553155415551556155715581559156015611562156315641565156615671568156915701571157215731574157515761577157815791580158115821583158415851586158715881589159015911592159315941595159615971598159916001601160216031604160516061607160816091610161116121613161416151616161716181619162016211622162316241625162616271628162916301631163216331634163516361637163816391640164116421643164416451646164716481649165016511652165316541655165616571658165916601661166216631664166516661667166816691670167116721673167416751676167716781679168016811682168316841685168616871688168916901691169216931694169516961697169816991700170117021703170417051706170717081709171017111712171317141715171617171718171917201721172217231724172517261727172817291730173117321733173417351736173717381739174017411742174317441745174617471748174917501751
  1. # -*- coding: utf-8 -*-
  2. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  3. #
  4. # Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  5. #
  6. # Unless required by applicable law or agreed to in writing,
  7. # software distributed under the License is distributed on an
  8. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  9. import builtins
  10. import collections
  11. import copy
  12. import ctypes
  13. import fnmatch
  14. import functools
  15. import inspect
  16. import keyword
  17. import re
  18. import weakref
  19. from inspect import getcallargs, getmembers, isclass, ismethod
  20. from itertools import chain
  21. from typing import Callable, Dict, Iterable, List, Optional, Sequence, Type, Union
  22. from megengine import tensor
  23. from ... import functional as F
  24. from ... import get_logger
  25. from ... import module as M
  26. from ...core._imperative_rt.core2 import Tensor as RawTensor
  27. from ...core._imperative_rt.core2 import (
  28. is_tracing_module,
  29. set_module_tracing,
  30. unset_module_tracing,
  31. )
  32. from ...core._trace_option import set_symbolic_shape
  33. from ...core.tensor.array_method import ArrayMethodMixin
  34. from ...module import Module
  35. from ...module.qat import QATModule
  36. from ...quantization.fake_quant import LSQ, TQT, FakeQuantize, _FakeQuantize
  37. from ...quantization.observer import (
  38. ExponentialMovingAverageObserver,
  39. HistogramObserver,
  40. MinMaxObserver,
  41. Observer,
  42. PassiveObserver,
  43. SyncExponentialMovingAverageObserver,
  44. SyncMinMaxObserver,
  45. )
  46. from ...tensor import Tensor
  47. from .expr import Apply, CallFunction, CallMethod, Constant, Expr, GetAttr, Input
  48. from .fake_quant import FakeQuantize as TM_FakeQuant
  49. from .module_tracer import (
  50. PatchedFn,
  51. Patcher,
  52. active_module_tracer,
  53. get_tensor_wrapable_method,
  54. module_tracer,
  55. set_active_module_tracer,
  56. )
  57. from .node import ModuleNode, Node, NodeMixin, TensorNode
  58. from .pytree import ArgsIndex, tree_flatten
  59. from .utils import replace_container_with_module_container
  60. logger = get_logger(__name__)
  61. def _is_builtin_name(name: str) -> bool:
  62. return (
  63. name in builtins.__dict__
  64. or name in keyword.kwlist
  65. or name in {"inf", "nan", "NoneType"}
  66. )
  67. def _is_leaf(node):
  68. assert isinstance(node, RawTensor), "doesn't support {} in return values".format(
  69. type(node)
  70. )
  71. return isinstance(node, RawTensor)
  72. _enable_node_to_tensor = False
  73. def _convert_node_flag():
  74. return _enable_node_to_tensor
  75. def _set_convert_node_flag(flag: bool = False):
  76. global _enable_node_to_tensor
  77. pre_flag = _enable_node_to_tensor
  78. _enable_node_to_tensor = flag
  79. return pre_flag
  80. def _node_to_tensor(*args, **kwargs):
  81. tensors = []
  82. nodes, tree_def = tree_flatten((args, kwargs))
  83. for n in nodes:
  84. if isinstance(n, TensorNode):
  85. if n.top_graph is not None:
  86. active_module_tracer().current_scope()._add_input(n)
  87. value = n.value
  88. if value is None:
  89. flag = _set_convert_node_flag(False)
  90. unset_module_tracing()
  91. value = F.zeros(shape=n._shape, dtype=n._dtype)
  92. set_module_tracing()
  93. _set_convert_node_flag(flag)
  94. orig_n = NodeMixin.get(value, None)
  95. if orig_n is None or "setitem" not in orig_n._name:
  96. NodeMixin.wrap_safe(value, n)
  97. tensors.append(value)
  98. else:
  99. tensors.append(n)
  100. tensors = tree_def.unflatten(tensors)
  101. return tensors
  102. def _tensor_to_node(tensors):
  103. if tensors is None:
  104. return None
  105. nodes = []
  106. tensors, out_def = tree_flatten(tensors)
  107. for t in tensors:
  108. if isinstance(t, Tensor):
  109. n = NodeMixin.get(t, None)
  110. if isinstance(n, TensorNode):
  111. n.value = t
  112. nodes.append(n)
  113. else:
  114. nodes.append(t)
  115. else:
  116. nodes.append(t)
  117. nodes = out_def.unflatten(nodes)
  118. return nodes
  119. def _wrap_method_to_tensor_node():
  120. def _any_method(name):
  121. def _any(*args, **kwargs):
  122. args, kwargs = _node_to_tensor(*args, **kwargs)
  123. attr = getattr(args[0], name)
  124. outs = attr
  125. if callable(attr):
  126. outs = attr(*(args[1:]), **kwargs)
  127. if name == "__setitem__":
  128. _node_to_tensor(outs)
  129. return None
  130. outs = _tensor_to_node(outs)
  131. return outs
  132. return _any
  133. tensor_method_patch = []
  134. for method in get_tensor_wrapable_method():
  135. patch = PatchedFn(TensorNode, method)
  136. if type(getattr(Tensor, method)) == property:
  137. patch.set_func(property(_any_method(method)))
  138. else:
  139. patch.set_func(_any_method(method))
  140. tensor_method_patch.append(patch)
  141. return tensor_method_patch
  142. def _convert_node_and_tensor(orig_func):
  143. @functools.wraps(orig_func)
  144. def _convert(*args, **kwargs):
  145. if _convert_node_flag() and is_tracing_module():
  146. args, kwargs = _node_to_tensor(*args, **kwargs)
  147. rst = orig_func(*args, **kwargs, method_func=_convert)
  148. rst = _tensor_to_node(rst)
  149. return rst
  150. else:
  151. rst = orig_func(*args, **kwargs)
  152. return rst
  153. return _convert
def _wrap_mnode_getattr(orig_getattr):
    """Build the attribute-access hook used for module nodes while tracing.

    ``orig_getattr`` is only used as the metadata source for
    ``functools.wraps``; the returned function never calls it.

    The returned function reads ``name`` from the node's owner and returns
    the node bound to that attribute when the attribute is a ``NodeMixin``
    or ``RawTensor``; any other attribute is returned as is.
    """

    @functools.wraps(orig_getattr)
    def wraped_fn(self, name):
        obj = self.owner
        # A node that already belongs to a graph becomes an input of the
        # scope currently being traced.
        if self.top_graph is not None:
            active_module_tracer().current_scope()._add_input(self)
        attr = getattr(obj, name)
        node = attr
        # Qualified name recorded for this object in the tracer, if any.
        full_name = None
        if id(attr) in active_module_tracer().id2name:
            full_name = active_module_tracer().id2name[id(attr)]

        if not isinstance(attr, TracedModuleBuilder):
            if isinstance(attr, Module):
                # Replace the plain module on the owner with a builder and
                # register the builder under the same qualified name.
                attr = TracedModuleBuilder(attr)
                setattr(obj, name, attr)
                active_module_tracer().id2name[id(attr)] = full_name

            if isinstance(attr, (NodeMixin, RawTensor)):
                if full_name:
                    scope_name = active_module_tracer().current_scope()._module_name
                    if scope_name:
                        # Drop the scope's module name and the following
                        # separator character from the front.
                        full_name = full_name[len(scope_name) + 1 :]
                    else:
                        full_name = name
                else:
                    full_name = name
                # The GetAttr expr is produced by the callable handed to
                # NodeMixin.wrap.
                NodeMixin.wrap(
                    attr,
                    lambda: GetAttr.make(
                        self,
                        name,
                        type=NodeMixin.get_wrapped_type(attr),
                        orig_name=full_name,
                    ),
                )
        if isinstance(attr, (NodeMixin, RawTensor)):
            node = NodeMixin.get(attr)
        if isinstance(node, ModuleNode):
            # The module node only keeps a weak reference to its owner.
            node._owner = weakref.ref(attr)
        return node

    return wraped_fn
  194. def _wrap_mnode_call(orig_call):
  195. @functools.wraps(orig_call)
  196. def wraped_fn(self, *args, **kwargs):
  197. obj = self.owner
  198. if self.top_graph is not None:
  199. active_module_tracer().current_scope()._add_input(self)
  200. rst = obj(*args, **kwargs)
  201. return rst
  202. return wraped_fn
  203. def _init_id2name(mod: Module, prefix: str = ""):
  204. id2name = {
  205. id(m): "%s.%s" % (prefix, key)
  206. for key, m in chain(
  207. mod.named_modules(), mod.named_parameters(), mod.named_buffers()
  208. )
  209. }
  210. return id2name
  211. class _InsertExprs:
  212. def __init__(self, graph, expr: Optional[Expr] = None):
  213. self.graph = graph
  214. self.global_scope = InternalGraph(
  215. graph._name, graph._prefix_name, graph._module_name
  216. )
  217. self.global_scope._used_names.update(graph._used_names)
  218. self.expr = expr
  219. self._tensor_method_patch = None
  220. def __enter__(self):
  221. self.use_sym_shape = set_symbolic_shape(True)
  222. set_module_tracing()
  223. _set_convert_node_flag(True)
  224. assert active_module_tracer() is None
  225. module = self.graph.inputs[0].owner
  226. _wrap_func = lambda x: _convert_node_and_tensor(_wrapped_function(x))
  227. set_active_module_tracer(
  228. module_tracer(_wrap_func, _init_id2name(module, self.graph._module_name))
  229. )
  230. active_module_tracer().patcher.__enter__()
  231. for cls, name, func in [
  232. [ModuleNode, "__getattr__", _wrap_mnode_getattr],
  233. [ModuleNode, "__call__", _wrap_mnode_call],
  234. [TracedModuleBuilder, "__call__", _convert_node_and_tensor],
  235. ]:
  236. active_module_tracer().patcher.patch_function(cls, name, func)
  237. self._tensor_method_patch = _wrap_method_to_tensor_node()
  238. active_module_tracer().push_scope(self.global_scope)
  239. def __exit__(self, ty, va, tr):
  240. if va is not None:
  241. return False
  242. set_symbolic_shape(self.use_sym_shape)
  243. unset_module_tracing()
  244. active_module_tracer().patcher.__exit__(ty, va, tr)
  245. _set_convert_node_flag(False)
  246. while self._tensor_method_patch:
  247. pf = self._tensor_method_patch.pop()
  248. pf.set_func(pf.origin_fn)
  249. module = self.graph.inputs[0].owner
  250. for mod, parent in module.modules(with_parent=True):
  251. name = mod._name
  252. if isinstance(mod, TracedModuleBuilder):
  253. mod = mod.build()
  254. if hasattr(mod, "graph"):
  255. for node in mod.graph.nodes():
  256. node.value = None
  257. setattr(parent, name, mod)
  258. set_active_module_tracer(None)
  259. for node in self.global_scope.nodes():
  260. node.value = None
  261. extra_inp_nodes = set(self.global_scope.inputs)
  262. max_inp_expr_idx = -1
  263. for node in extra_inp_nodes:
  264. assert (
  265. node.top_graph == self.graph
  266. ), "The input node ({}) is not in the graph ({})".format(node, self.graph)
  267. if isinstance(node, TensorNode) and node.expr in self.graph._exprs:
  268. max_inp_expr_idx = max(
  269. max_inp_expr_idx, self.graph._exprs.index(node.expr)
  270. )
  271. max_inp_expr_idx += 1
  272. insert_index = -1
  273. if self.expr is not None:
  274. insert_index = self.graph._exprs.index(self.expr)
  275. insert_index += 1
  276. if insert_index < max_inp_expr_idx:
  277. insert_index = max_inp_expr_idx
  278. anchor_index = insert_index - 1
  279. if anchor_index >= 0:
  280. logger.info(
  281. "The new expr will be inserted after ( {} )".format(
  282. self.graph._exprs[anchor_index]
  283. )
  284. )
  285. for expr in self.global_scope._exprs:
  286. self.graph._exprs.insert(insert_index, expr)
  287. insert_index += 1
  288. self.graph._used_names.update(self.global_scope._used_names)
  289. graph = self.graph
  290. while graph.top_graph is not None:
  291. graph = graph.top_graph
  292. graph.inputs[0].owner._update_ref()
  293. return True
  294. class InternalGraph:
  295. """
  296. ``InternalGraph`` is a graph consist of ``Node`` and ``Expr``, it is used to represent the execution procedure of Module's forward method.
  297. Attributes:
  298. _exprs: List of Exprs in order of execution
  299. _inputs: Input Nodes of InternalGraph
  300. _outputs: Output Nodes of InternalGraph
  301. """
  302. _exprs = None # type: List[Expr]
  303. _inputs = None # type: List[Node]
  304. _outputs = None # type: List[Node]
  305. _top_graph = None
  306. def __init__(self, name: str = None, prefix_name: str = "", module_name: str = ""):
  307. self._exprs = []
  308. self._inputs = []
  309. self._outputs = []
  310. self._watch_point = []
  311. self._end_point = []
  312. self._used_names = {}
  313. self._rst = collections.defaultdict(list)
  314. self._name = name
  315. self._prefix_name = prefix_name
  316. self._module_name = module_name
  317. def _insert(self, expr):
  318. self._exprs.append(expr)
  319. def _create_unique_name(self, name: str) -> str:
  320. assert isinstance(name, str), "The name must be a str"
  321. name = re.sub("[^0-9a-zA-Z_]+", "_", name)
  322. if name[0].isdigit():
  323. name = "_{}".format(name)
  324. while name in self._used_names or _is_builtin_name(name):
  325. match = re.match(r"(.*)_(\d+)$", name)
  326. if match is None:
  327. name = name + "_1"
  328. else:
  329. base, num = match.group(1, 2)
  330. name = "{}_{}".format(base, int(num) + 1)
  331. self._used_names.setdefault(name)
  332. return name
  333. @property
  334. def inputs(self):
  335. return self._inputs
  336. @property
  337. def outputs(self):
  338. return self._outputs
  339. @property
  340. def top_graph(self):
  341. if self._top_graph:
  342. return self._top_graph()
  343. return None
  344. def exprs(self, recursive=True):
  345. return ExprFilter(_expr_iter(self, recursive))
  346. def nodes(self, recursive=True):
  347. return NodeFilter(_node_iter(self, recursive))
  348. def get_function_by_type(self, func: Callable = None, recursive=True):
  349. return self.exprs(recursive).call_function(func)
  350. def get_method_by_type(self, method: str = None, recursive=True):
  351. return self.exprs(recursive).call_method(method)
  352. def get_expr_by_id(self, expr_id: List[int] = None, recursive=True):
  353. return self.exprs(recursive).expr_id(expr_id)
  354. def get_module_by_type(self, module_cls: Module, recursive=True):
  355. assert issubclass(module_cls, Module)
  356. return self.nodes(recursive).type(module_cls, ModuleNode)
  357. def get_node_by_id(self, node_id: List[int] = None, recursive=True):
  358. return self.nodes(recursive).node_id(node_id)
  359. def get_node_by_name(
  360. self, name: str = None, ignorecase: bool = True, recursive=True
  361. ):
  362. return self.nodes(recursive).name(name, ignorecase)
  363. def _add_input(self, i):
  364. self._inputs.append(i)
  365. def _add_output(self, o):
  366. self._outputs.append(o)
  367. def _replace_inputs_outputs(self, repl_dict, prefix_name="", module_name=""):
  368. for node, repl_node in repl_dict.items():
  369. assert node in self._inputs or node in self._outputs
  370. for i in node.users:
  371. if i not in repl_node.users:
  372. repl_node.users.append(i)
  373. for idx, i in enumerate(self._inputs):
  374. if i in repl_dict:
  375. self._inputs[idx] = repl_dict[i]
  376. for idx, o in enumerate(self._outputs):
  377. if o in repl_dict:
  378. repl_dict[o]._orig_name = "{}{}".format(module_name, o._orig_name)
  379. self._outputs[idx] = repl_dict[o]
  380. for expr in self._exprs:
  381. for idx, i in enumerate(expr.inputs):
  382. assert isinstance(
  383. i._name, str
  384. ), "The node ({}) name must be a str".format(i)
  385. if i in repl_dict:
  386. expr.inputs[idx] = repl_dict[i]
  387. elif isinstance(i, TensorNode) and prefix_name not in i._name:
  388. if i.top_graph != active_module_tracer().current_scope():
  389. i._name = (
  390. active_module_tracer()
  391. .current_scope()
  392. ._create_unique_name(prefix_name + i._name.lstrip("_"))
  393. )
  394. i._orig_name = "{}{}".format(module_name, i._orig_name)
  395. for idx, o in enumerate(expr.outputs):
  396. assert isinstance(
  397. o._name, str
  398. ), "The node ({}) name must be a str".format(i)
  399. if o in repl_dict:
  400. expr.outputs[idx] = repl_dict[o]
  401. expr.outputs[idx].expr = expr
  402. elif isinstance(o, TensorNode) and prefix_name not in i._name:
  403. if o.top_graph != active_module_tracer().current_scope():
  404. o._name = (
  405. active_module_tracer()
  406. .current_scope()
  407. ._create_unique_name(prefix_name + o._name.lstrip("_"))
  408. )
  409. o._orig_name = "{}{}".format(module_name, o._orig_name)
  410. def get_dep_exprs(self, nodes: Sequence[Node]) -> List[Expr]:
  411. if not isinstance(nodes, Sequence):
  412. nodes = (nodes,)
  413. ret = list()
  414. queue = list(nodes)
  415. visited_queue = list()
  416. while queue:
  417. node = queue.pop()
  418. visited_queue.append(node)
  419. expr = node.expr
  420. if expr not in ret:
  421. ret.append(expr)
  422. for i in expr.inputs:
  423. if i not in queue and i not in visited_queue:
  424. queue.append(i)
  425. return ret
  426. def reset_inputs(self, *args, **kwargs):
  427. forma_mnode = self.inputs[0]
  428. actual_mnodes = forma_mnode.actual_node
  429. call_nodes = []
  430. for n in actual_mnodes:
  431. for c_expr in n.users:
  432. if isinstance(c_expr, CallMethod) and c_expr.method == "__call__":
  433. call_nodes.append((c_expr, n))
  434. moudle = forma_mnode.owner
  435. assert moudle._is_top, "reset_inputs only support the top-level graph"
  436. inputs, tree_def = tree_flatten(((moudle, *args), kwargs))
  437. def create_node(val: Tensor):
  438. node = Input(type=TensorNode).outputs[0]
  439. node.shape = val.shape
  440. node.dtype = val.dtype
  441. return node
  442. formal_node_inputs = [
  443. forma_mnode,
  444. ]
  445. org_argdef = list(moudle.argdef_graph_map.keys())[0]
  446. if call_nodes:
  447. org_argdef = call_nodes[0][0].arg_def
  448. for v in inputs[1:]:
  449. assert isinstance(v, RawTensor)
  450. formal_node_inputs.append(create_node(v))
  451. actual_nodes = []
  452. for e, n in call_nodes:
  453. e.arg_def = tree_def
  454. actual_node_inputs = [
  455. n,
  456. ]
  457. for v in inputs[1:]:
  458. actual_node_inputs.append(create_node(v))
  459. for org_n in e.inputs:
  460. org_n.users.pop(e)
  461. e.inputs[:] = actual_node_inputs
  462. e.const_val = []
  463. actual_nodes.append(actual_node_inputs[1:])
  464. self._inputs[:] = formal_node_inputs
  465. moudle.argdef_graph_map[tree_def] = moudle.argdef_graph_map.pop(org_argdef)
  466. moudle.argdef_outdef_map[tree_def] = moudle.argdef_outdef_map.pop(org_argdef)
  467. # return formal_node_inputs[1:], actual_nodes
  468. return formal_node_inputs[1:]
  469. def add_input_node(self, shape, dtype="float32", name="args"):
  470. forma_mnode = self.inputs[0]
  471. actual_mnodes = forma_mnode.actual_node
  472. moudle = forma_mnode.owner
  473. assert moudle._is_top, "add_input_node only support the top-level graph"
  474. call_nodes = []
  475. for n in actual_mnodes:
  476. for c_expr in n.users:
  477. if isinstance(c_expr, CallMethod) and c_expr.method == "__call__":
  478. call_nodes.append(c_expr)
  479. def create_node(name=None, is_input: bool = True):
  480. if is_input:
  481. node = Input(type=TensorNode, name=name).outputs[0]
  482. else:
  483. node = TensorNode(expr=None, name=None)
  484. node.shape = shape
  485. node.dtype = dtype
  486. return node
  487. org_argdef = list(moudle.argdef_graph_map.keys())[0]
  488. if call_nodes:
  489. org_argdef = call_nodes[0].arg_def
  490. args, kwargs = org_argdef.unflatten(self._inputs)
  491. formal_inp_node = create_node(self._create_unique_name(name), True)
  492. inputs, tree_def = tree_flatten(
  493. ((*args, formal_inp_node), kwargs),
  494. is_const_leaf=lambda x: not isinstance(x, (TensorNode, ModuleNode)),
  495. )
  496. self._inputs[:] = inputs[:]
  497. actual_inp_nodes = []
  498. for e in call_nodes:
  499. args, kwargs = e.unflatten_args(e.inputs)
  500. args = args + (create_node(False),)
  501. inputs, tree_def = tree_flatten(
  502. (args, kwargs),
  503. is_const_leaf=lambda x: not isinstance(x, (TensorNode, ModuleNode)),
  504. )
  505. e.inputs[:] = inputs[:]
  506. e.arg_def = tree_def
  507. actual_inp_nodes.append(args[-1])
  508. moudle.argdef_graph_map[tree_def] = moudle.argdef_graph_map.pop(org_argdef)
  509. moudle.argdef_outdef_map[tree_def] = moudle.argdef_outdef_map.pop(org_argdef)
  510. # return formal_inp_node, actual_inp_nodes
  511. return formal_inp_node
  512. def reset_outputs(self, outputs):
  513. outputs, out_def = tree_flatten(
  514. outputs, is_leaf=lambda x: isinstance(x, TensorNode),
  515. )
  516. forma_mnode = self.inputs[0]
  517. moudle = forma_mnode.owner
  518. assert moudle._is_top, "reset_outputs only support the top-level graph"
  519. actual_mnodes = forma_mnode.actual_node
  520. call_nodes = []
  521. for n in actual_mnodes:
  522. for c_expr in n.users:
  523. if isinstance(c_expr, CallMethod) and c_expr.method == "__call__":
  524. call_nodes.append((c_expr))
  525. def create_node(val: TensorNode, expr: Expr):
  526. node = TensorNode(expr)
  527. node.shape = val.shape
  528. node.dtype = val.dtype
  529. return node
  530. tree_def = list(moudle.argdef_graph_map.keys())[0]
  531. if call_nodes:
  532. tree_def = call_nodes[0].arg_def
  533. actual_nodes = []
  534. for e in call_nodes:
  535. actual_node_outputs = []
  536. for v in outputs:
  537. actual_node_outputs.append(create_node(v, e))
  538. e.outputs[:] = actual_node_outputs
  539. e.out_def = out_def
  540. actual_nodes.append(actual_node_outputs)
  541. self._outputs[:] = outputs
  542. moudle.argdef_outdef_map[tree_def] = out_def
  543. return actual_nodes
  544. def add_output_node(self, node: TensorNode):
  545. forma_mnode = self.inputs[0]
  546. moudle = forma_mnode.owner
  547. assert moudle._is_top, "add_output_node only support the top-level graph"
  548. actual_mnodes = forma_mnode.actual_node
  549. call_nodes = []
  550. for n in actual_mnodes:
  551. for c_expr in n.users:
  552. if isinstance(c_expr, CallMethod) and c_expr.method == "__call__":
  553. call_nodes.append((c_expr))
  554. def create_node(val: TensorNode, expr: Expr):
  555. node = TensorNode(expr)
  556. node.shape = val.shape
  557. node.dtype = val.dtype
  558. return node
  559. tree_def = list(moudle.argdef_graph_map.keys())[0]
  560. if call_nodes:
  561. tree_def = call_nodes[0].arg_def
  562. org_out_def = moudle.argdef_outdef_map[tree_def]
  563. org_outs = org_out_def.unflatten(self._outputs)
  564. outputs, out_def = tree_flatten(
  565. (org_outs, node), is_leaf=lambda x: isinstance(x, TensorNode),
  566. )
  567. self._outputs[:] = outputs
  568. actual_out_nodes = []
  569. for e in call_nodes:
  570. actual_node = create_node(node, e)
  571. org_outs = org_out_def.unflatten(e.outputs)
  572. outputs, out_def = tree_flatten(
  573. (org_outs, actual_node), is_leaf=lambda x: isinstance(x, TensorNode),
  574. )
  575. e.outputs[:] = outputs
  576. e.out_def = out_def
  577. actual_out_nodes.append(actual_node)
  578. moudle.argdef_outdef_map[tree_def] = out_def
  579. return actual_out_nodes
  580. def insert_exprs(self, expr: Optional[Expr] = None):
  581. if expr is not None:
  582. assert expr.top_graph == self, "Expr to insert after is not in graph."
  583. return _InsertExprs(self, expr)
  584. def replace_node(self, repl_dict: Dict[Node, Node]):
  585. while repl_dict:
  586. node, repl_node = repl_dict.popitem()
  587. # check graph inputs and outputs
  588. # assert node not in self.inputs, "Cannot replace inputs"
  589. for i, n in enumerate(self.outputs):
  590. if n is node:
  591. self.outputs[i] = repl_node
  592. # update users of node and repl_node
  593. # update inputs of expr in node.users
  594. graph = repl_node.top_graph
  595. assert graph is not None
  596. index = graph._exprs.index(repl_node.expr)
  597. dep_exprs = self.get_dep_exprs(repl_node)
  598. i = 0
  599. while i < len(node.users):
  600. n = node.users[i]
  601. if n in graph._exprs and index >= graph._exprs.index(n):
  602. i += 1
  603. continue
  604. if n in dep_exprs:
  605. logger.info("Find a loop: ignore this replacement once")
  606. logger.info("node: %s" % node.__repr__())
  607. logger.info("expr: %s" % n.__repr__())
  608. i += 1
  609. continue
  610. repl_node.users.append(n)
  611. node.users.pop(i)
  612. idx = n.inputs.index(node)
  613. n.inputs[idx] = repl_node
  614. def compile(self):
  615. """
  616. Delete unused expr.
  617. """
  618. dep_exprs = self.get_dep_exprs(self.outputs)
  619. i = 0
  620. while i < len(self._exprs):
  621. expr = self._exprs[i]
  622. if expr in dep_exprs or expr._disable_remove:
  623. i += 1
  624. continue
  625. for n in expr.inputs:
  626. n.users.remove(expr)
  627. self._exprs.remove(expr)
  628. def interpret(self, *inputs):
  629. node2value = {}
  630. end_nodes_set = set(self._end_point)
  631. endnode2value = {}
  632. def get_all_endnode_val(n, v):
  633. if n in end_nodes_set:
  634. endnode2value[n] = v
  635. end_nodes_set.remove(n)
  636. return not end_nodes_set
  637. return False
  638. for n, v in zip(self._inputs, inputs):
  639. node2value[n] = v
  640. if n in self._watch_point:
  641. self._rst[n].append(v)
  642. if n in self._end_point and get_all_endnode_val(n, v):
  643. return list(endnode2value[i] for i in self._end_point)
  644. for expr in self._exprs:
  645. values = expr.interpret(*list(node2value[i] for i in expr.inputs))
  646. if values is not None:
  647. for n, v in zip(expr.outputs, values):
  648. node2value[n] = v
  649. if n in self._watch_point:
  650. self._rst[n] = v
  651. if self._end_point and get_all_endnode_val(n, v):
  652. return list(endnode2value[i] for i in self._end_point)
  653. return list(node2value[i] for i in self._outputs)
  654. def eval(self, *inputs):
  655. assert len(inputs) == len(self._inputs) - 1
  656. inp = [self._inputs[0].owner] + list(inputs)
  657. return self.interpret(*inp)
  658. def __repr__(self):
  659. return self.__format__()
  660. def __format__(self, format_spec: str = "") -> str:
  661. saved_format_spec = Node.set_format_spec(format_spec)
  662. name = ""
  663. if self._name:
  664. name = "%s.Graph" % self._name
  665. res = "{} ({}) {{\n\t{}\n\treturn {}\n}}".format(
  666. name,
  667. ", ".join(str(i) for i in self._inputs),
  668. "\n\t".join("{}".format(str(i)) for i in self._exprs),
  669. ", ".join(str(i) for i in self._outputs),
  670. )
  671. Node.set_format_spec(saved_format_spec)
  672. return res
  673. def __getstate__(self):
  674. state = self.__dict__.copy()
  675. if "_top_graph" in state:
  676. state.pop("_top_graph")
  677. return state
  678. def _get_meth_name(obj, func):
  679. tp = obj if isinstance(obj, type) else type(obj)
  680. for cls in tp.mro():
  681. for k, v in cls.__dict__.items():
  682. if v == func:
  683. return k
  684. return None
def _wrapped_function(orig_func):
    """Wrap ``orig_func`` so that calling it while tracing records an expr.

    Outside of module tracing the wrapper simply forwards to ``orig_func``.
    While tracing it records either a ``CallMethod`` (when the first positional
    argument is a ``RawTensor`` type/instance and the function is found as one
    of its class attributes) or a ``CallFunction``, attaches the flattened
    inputs and outputs to it, and returns the original result.

    The wrapper accepts one extra keyword, ``method_func``, which is removed
    before ``orig_func`` is called; it names the function object to look up on
    the class of the first argument (defaults to the wrapper itself).
    """

    @functools.wraps(orig_func)
    def wrapped_fn(*args, **kwargs):
        method_func = wrapped_fn
        if "method_func" in kwargs:
            method_func = kwargs.pop("method_func")
        if is_tracing_module():
            # Tracing is switched off while the expr is being recorded and
            # switched back on before returning.
            # NOTE(review): if orig_func raises below, tracing stays switched
            # off -- confirm whether callers rely on that.
            unset_module_tracing()
            inputs, tree_def = tree_flatten((args, kwargs))
            # Tensors/NodeMixin objects that have no node yet are recorded as
            # constants.
            for i in inputs:
                if not NodeMixin.get(i, None):
                    if isinstance(i, (RawTensor, NodeMixin)):
                        NodeMixin.wrap_safe(i, Constant.make(i))
            meth_name, arg_type = None, None
            if args:
                meth_name = _get_meth_name(args[0], method_func)
                arg_type = args[0] if isinstance(args[0], type) else type(args[0])
            if meth_name and arg_type and issubclass(arg_type, RawTensor):
                self = inputs[0]
                if meth_name == "__new__":
                    if all([not isinstance(i, RawTensor) for i in inputs]):
                        # only trace Tensor.__new__() when there are tensors in args
                        set_module_tracing()
                        return orig_func(*args, **kwargs)
                    if isinstance(args[1], RawTensor):
                        node = NodeMixin.get(inputs[1])
                        inputs[1] = copy.copy(inputs[1])
                        # copy inputs[1] to avoid tensor and Tensor(tensor) share same m_tensor, which will cause they have same _NodeMixin__node in tracing.
                        NodeMixin.wrap_safe(inputs[1], node)
                        args, kwargs = tree_def.unflatten(inputs)
                    call_node = CallMethod.make(self, meth_name)
                else:
                    call_node = CallMethod.make(NodeMixin.get(self), meth_name)
                # The receiver is passed to CallMethod.make, so only the
                # remaining leaves are added as inputs.
                call_node.add_inputs(inputs[1:])
            else:
                call_node = CallFunction.make(orig_func)
                call_node.add_inputs(inputs)

            call_node.arg_def = tree_def
            rst = orig_func(*args, **kwargs)
            # Item assignment returns nothing; the receiver is used as the
            # result so that the expr gets an output.
            # NOTE(review): ``self`` is only bound in the method branch above;
            # a non-tensor ``__setitem__`` would hit a NameError -- confirm it
            # cannot reach this point.
            if meth_name == "__setitem__":
                rst = self
            if rst is not None:
                outputs, out_def = tree_flatten(rst, is_leaf=_is_leaf)
                call_node.out_def = out_def
            else:
                outputs = None
            call_node.add_outputs(outputs)
            set_module_tracing()
            return rst
        return orig_func(*args, **kwargs)

    return wrapped_fn
  class TracedModuleBuilder(NodeMixin):
      """Proxy that stands in for a module while its ``forward`` is traced.

      Calling a builder records a ``CallMethod`` expression in the current
      scope.  For a module that is neither builtin, already traced with the
      same argument structure, nor a ``TracedModule``, the call also runs
      ``type(mod).forward`` with the builder as ``self`` so that the body is
      recorded into a new ``InternalGraph``.  :meth:`build` converts the
      recorded graphs into a ``TracedModule``.
      """

      _mod = None  # type: Module
      _body = None  # type: InternalGraph
      _is_builtin = None  # type: bool
      _argdef_graph_map = None  # type: Dict[Treedef, "InternalGraph"]
      _argdef_outdef_map = None  # type: Dict[Treedef, Treedef]
      nodes = None
      # Names that ``__getattribute__`` serves directly from the builder and
      # that ``build`` skips when copying attributes to the traced module.
      __builder_attributes__ = [
          "_mod",
          "_body",
          "_NodeMixin__node",
          "_is_builtin",
          "build",
          "_record_wrapped_nodes",
          "_argdef_graph_map",
          "_argdef_outdef_map",
          "nodes",
          "__class__",
          "__dict__",
      ]

      def __init__(self, mod, is_top_module=False):
          """Create a builder around ``mod``.

          :param mod: the ``Module`` to trace.
          :param is_top_module: stored as ``_is_top`` and forwarded to the
              ``TracedModule`` created by :meth:`build`.
          """
          super(TracedModuleBuilder, self).__init__()
          assert isinstance(mod, Module)
          self._mod = mod
          self._body = None
          self._is_top = is_top_module
          self._is_builtin = (
              True
              if isinstance(mod, (Observer, _FakeQuantize))
              else module_tracer.is_builtin(mod)
          )
          if isinstance(self._mod, QATModule):
              unset_module_tracing()
              self._check_qat_module(self._mod)
              set_module_tracing()

          self._argdef_graph_map = {}
          self._argdef_outdef_map = {}
          self.nodes = set()
          # The builder will be passed to self._mod.forward as 'self' argument. If the 'forward' uses super().xxx to call method of its base classes, the trace procedure will throw exception, because the builder doesn't inherit from self._mod.__bases__.
          # modify self.__class__ and let the builder inherit from TracedModuleBuilder and mod.__class__.
          self.__class__ = type(
              "TracedModuleBuilder",
              (TracedModuleBuilder, mod.__class__),
              dict(TracedModuleBuilder.__dict__),
          )

      def _check_qat_module(self, qat_module):
          """Replace non-builtin observers / fake-quant modules of ``qat_module``.

          For the activation and the weight side separately: when either the
          observer or the fake-quant module is not builtin, the observer is set
          to ``None`` and the fake-quant module is replaced by a ``TM_FakeQuant``
          that receives the qparams and dtype read from the original objects.
          """

          def isbuiltin(m):
              return m is None or module_tracer.is_builtin(m)

          if qat_module.with_act:
              act_observer = qat_module.act_observer
              act_fakequant = qat_module.act_fake_quant
              if not isbuiltin(act_observer) or not isbuiltin(act_fakequant):
                  qparams = (
                      act_observer.get_qparams()
                      if hasattr(act_observer, "get_qparams")
                      else act_fakequant.get_qparams()
                  )
                  dtype = (
                      act_observer.dtype
                      if hasattr(act_observer, "dtype")
                      else act_fakequant.dtype
                  )
                  qat_module.act_observer = None
                  qat_module.act_fake_quant = TM_FakeQuant(dtype)
                  qat_module.act_fake_quant.set_qparams(qparams)

          if qat_module.with_weight:
              weight_observer = qat_module.weight_observer
              weight_fakequant = qat_module.weight_fake_quant
              if not isbuiltin(weight_observer) or not isbuiltin(weight_fakequant):
                  qparams = (
                      weight_observer.get_qparams()
                      if hasattr(weight_observer, "get_qparams")
                      else weight_fakequant.get_qparams()
                  )
                  dtype = (
                      weight_observer.dtype
                      if hasattr(weight_observer, "dtype")
                      else weight_fakequant.dtype
                  )
                  qat_module.weight_observer = None
                  qat_module.weight_fake_quant = TM_FakeQuant(dtype)
                  qat_module.weight_fake_quant.set_qparams(qparams)

      def build(self):
          """Finish tracing and return the resulting module.

          :return: the wrapped module itself when it is builtin or already a
              ``TracedModule`` (after tagging the recorded nodes with its
              type); otherwise a new ``TracedModule`` holding the compiled
              graphs plus the built sub-builders and tensors found in the
              builder's ``__dict__``.
          """
          if self._is_builtin or isinstance(self._mod, TracedModule):
              if module_tracer.is_builtin(self._mod) or isinstance(
                  self._mod, TracedModule
              ):
                  mod_type = type(self._mod)
              else:
                  assert isinstance(self._mod, (Observer, _FakeQuantize))
                  mod_type = (
                      Observer if isinstance(self._mod, Observer) else _FakeQuantize
                  )
              for node in self.nodes:
                  node.module_type = mod_type
              return self._mod
          else:
              is_qat = isinstance(self._mod, QATModule)
              traced_module = TracedModule(
                  self._is_top, self._argdef_graph_map, self._argdef_outdef_map, is_qat
              )
              for _, g in self._argdef_graph_map.items():
                  g.compile()

              for k, v in self.__dict__.items():
                  if k not in TracedModuleBuilder.__builder_attributes__:
                      if isinstance(v, TracedModuleBuilder):
                          v = v.build()
                          setattr(traced_module, k, v)
                      elif isinstance(v, RawTensor):
                          setattr(traced_module, k, v)

              if isinstance(self._mod, QATModule):
                  unset_module_tracing()
                  traced_module.with_act = self._mod.with_act
                  traced_module.with_weight = self._mod.with_weight
                  # Default every quantization attribute that was not copied
                  # above to None.  The name assigned must be the name that
                  # was checked, otherwise the attribute stays missing.
                  if not hasattr(traced_module, "act_fake_quant"):
                      traced_module.act_fake_quant = None
                  if not hasattr(traced_module, "act_observer"):
                      traced_module.act_observer = None
                  if not hasattr(traced_module, "weight_fake_quant"):
                      traced_module.weight_fake_quant = None
                  if not hasattr(traced_module, "weight_observer"):
                      traced_module.weight_observer = None
                  set_module_tracing()

              return traced_module

      def _record_wrapped_nodes(self, node):
          """Remember ``node`` so :meth:`build` can set its ``module_type``."""
          self.nodes.add(node)

      def __call__(self, *args, **kwargs):
          """Record a call of the wrapped module and return its result.

          A ``CallMethod`` expression is created for the call.  Builtin
          modules, ``TracedModule`` instances and argument structures that
          already have a graph are executed with tracing switched off;
          otherwise ``type(self._mod).forward`` is run inside a new
          ``InternalGraph`` scope with the builder as ``self``.
          """
          assert isinstance(self._mod, Module)
          # prepare args and kwargs for inner graph
          if "method_func" in kwargs:
              kwargs.pop("method_func")

          def mark_constant(x):
              node = NodeMixin.get(x, None)
              if node is None:  # capture as constant
                  NodeMixin.wrap(x, lambda: Constant.make(x))

          inputs, tree_def = tree_flatten(((self, *args), kwargs))
          for i in inputs:
              mark_constant(i)
          callnode = CallMethod.make(NodeMixin.get(self))

          callnode.add_inputs(inputs[1:])

          callnode.arg_def = tree_def

          if (
              self._is_builtin
              or tree_def in self._argdef_graph_map
              or isinstance(self._mod, TracedModule)
          ):
              unset_module_tracing()
              rst = self._mod(*args, **kwargs)
              outputs, out_def = tree_flatten(rst, is_leaf=_is_leaf)
              set_module_tracing()
              if self._is_builtin:
                  self._body = None
              elif tree_def in self._argdef_graph_map:
                  self._body = self._argdef_graph_map[tree_def]
              else:
                  self._mod._is_top = False
                  self._body = self._mod.graph
          else:
              orig_self = NodeMixin.get(self)
              top_graph = active_module_tracer().current_scope()
              graph_prefix_name = top_graph._name
              if top_graph._prefix_name:
                  graph_prefix_name = "{}_{}".format(
                      top_graph._prefix_name, graph_prefix_name.lstrip("_")
                  )
              module_name = orig_self._orig_name
              if top_graph._module_name:
                  module_name = "{}.{}".format(top_graph._module_name, module_name)
              self._body = InternalGraph(
                  orig_self._name, prefix_name=graph_prefix_name, module_name=module_name
              )
              active_module_tracer().push_scope(self._body)
              # rebind self to new input node
              NodeMixin.wrap_safe(
                  self, Input.make("self", NodeMixin.get_wrapped_type(self), "")
              )
              # Remember the outer nodes of the arguments; they are restored
              # once the inner graph has been recorded.
              origin_inp_node = [NodeMixin.get(i, None) for i in inputs[1:]]
              # prepare args and kwargs for inner graph
              index_args, index_kwargs = tree_def.unflatten(
                  [
                      ArgsIndex(0),
                      *list(ArgsIndex(i + 1) for i in range(len(origin_inp_node))),
                  ]
              )
              # Map each flattened argument position to the name of the
              # ``forward`` parameter it is bound to; leaves of a nested
              # argument get the suffix ``_<position within that argument>``.
              key2idx = getcallargs(type(self._mod).forward, *index_args, **index_kwargs)
              idx2key = {}
              for k, v in key2idx.items():
                  if isinstance(v, ArgsIndex):
                      idx2key[v.index] = k
                  else:
                      flatten_argidx, _ = tree_flatten(v)
                      for _i, v in enumerate(flatten_argidx):
                          if isinstance(v, ArgsIndex):
                              idx2key[v.index] = k + "_%d" % _i

              def wrap(x, name):
                  if isinstance(x, (RawTensor, NodeMixin)):
                      NodeMixin.wrap(
                          x,
                          lambda: Input.make(
                              type=NodeMixin.get_wrapped_type(x), name=name
                          ),
                      )
                  return x

              args = [self]
              for i, v in enumerate(inputs[1:]):
                  args.append(wrap(v, idx2key[i + 1]))

              args, kwargs = tree_def.unflatten(args)
              active_module_tracer().patcher.auto_patch(
                  getattr(getattr(self._mod, "forward", self._mod), "__globals__", {})
              )
              rst = type(self._mod).forward(*args, **kwargs)
              if _convert_node_flag():
                  rst = _node_to_tensor(rst)[0][0]
              outputs, out_def = tree_flatten(rst, is_leaf=_is_leaf)
              for i in (
                  outputs if isinstance(outputs, collections.abc.Sequence) else (outputs,)
              ):
                  active_module_tracer().current_scope()._add_output(NodeMixin.get(i))
              # Restore the nodes the builder and its arguments had outside.
              NodeMixin.wrap_safe(self, orig_self)
              for arg, node in zip(inputs[1:], origin_inp_node):
                  if node:
                      NodeMixin.wrap_safe(arg, node)
              active_module_tracer().pop_scope()

          # rebind output to outer graph
          callnode.out_def = out_def
          callnode.add_outputs(outputs)
          self._argdef_graph_map[callnode.arg_def] = self._body
          self._argdef_outdef_map[callnode.arg_def] = out_def
          return rst

      def __setattr__(self, name, value):
          """Store attributes with ``object.__setattr__``."""
          object.__setattr__(self, name, value)

      def __repr__(self):
          return repr(self._mod)

      def __getattr__(self, name):
          """Resolve ``name`` on the wrapped module.

          Names missing from ``self._mod.__dict__`` are looked up on the
          module's type and bound to the builder.  Instance attributes are
          wrapped (modules in a new builder) and given a ``GetAttr`` node.
          """
          if name not in self._mod.__dict__:
              attr = getattr(type(self._mod), name).__get__(self, type(self))
          else:
              attr = getattr(self._mod, name)
              full_name = None
              if id(attr) in active_module_tracer().id2name:
                  full_name = active_module_tracer().id2name[id(attr)]

              if isinstance(attr, (List, Dict)):
                  unset_module_tracing()
                  has_module, m_container = replace_container_with_module_container(attr)
                  if m_container:
                      attr = m_container
                  if has_module and not m_container:
                      raise ValueError(
                          "Can not trace the module that uses the same container to store Module and Non-Module objects "
                      )
                  set_module_tracing()

              if isinstance(attr, Module):
                  attr = TracedModuleBuilder(attr)

              if isinstance(attr, (Module, RawTensor)):
                  setattr(self, name, attr)
                  active_module_tracer().id2name[id(attr)] = full_name

              # Make the recorded name relative to the current scope's module.
              if full_name:
                  scope_name = active_module_tracer().current_scope()._module_name
                  if scope_name:
                      full_name = full_name[len(scope_name) + 1 :]
                  else:
                      full_name = name
              else:
                  full_name = name
              NodeMixin.wrap(
                  attr,
                  lambda: GetAttr.make(
                      NodeMixin.get(self),
                      name,
                      type=NodeMixin.get_wrapped_type(attr),
                      orig_name=full_name,
                  ),
              )
          return attr

      def __getattribute__(self, name):
          """Return builder attributes directly; wrap module attributes.

          Names listed in ``__builder_attributes__`` bypass all processing.
          For names present in ``self._mod.__dict__`` the value found on the
          builder is checked against the module's value and, when it is a
          ``NodeMixin`` or ``RawTensor``, given a ``GetAttr`` node.
          """
          if name in TracedModuleBuilder.__builder_attributes__:
              return object.__getattribute__(self, name)
          else:
              wrapped = object.__getattribute__(self, name)
              class_members = dict(inspect.getmembers(self.__class__))
              if name in self._mod.__dict__:
                  mod_attr = getattr(self._mod, name)

                  if name in class_members:
                      if (
                          not isinstance(wrapped, TracedModuleBuilder)
                          and wrapped is not mod_attr
                      ):
                          wrapped = self.__getattr__(name)

                  if isinstance(wrapped, TracedModuleBuilder):
                      if not isinstance(mod_attr, (List, Dict)):
                          assert mod_attr is wrapped._mod
                  else:
                      assert mod_attr is wrapped

                  full_name = None
                  if id(mod_attr) in active_module_tracer().id2name:
                      full_name = active_module_tracer().id2name[id(mod_attr)]
                      scope_name = active_module_tracer().current_scope()._module_name
                      if full_name and scope_name:
                          full_name = full_name[len(scope_name) + 1 :]
                      else:
                          full_name = name
                  else:
                      full_name = name

                  # assert not self._is_builtin
                  if isinstance(wrapped, (NodeMixin, RawTensor)):
                      NodeMixin.wrap(
                          wrapped,
                          lambda: GetAttr.make(
                              NodeMixin.get(self),
                              name,
                              type=NodeMixin.get_wrapped_type(wrapped),
                              orig_name=full_name,
                          ),
                      )

              return wrapped
  1059. class _expr_iter:
  1060. def __init__(self, graph: InternalGraph, recursive: bool = True):
  1061. self.graph = graph
  1062. self.recursive = recursive
  1063. def __iter__(self):
  1064. for expr in self.graph._exprs:
  1065. if isinstance(expr, CallMethod) and isinstance(expr.inputs[0], ModuleNode):
  1066. yield expr
  1067. if self.recursive and expr.graph is not None:
  1068. yield from expr.graph.exprs(self.recursive)
  1069. else:
  1070. yield expr
  1071. class _node_iter:
  1072. def __init__(self, graph: InternalGraph, recursive: bool = True) -> None:
  1073. nodes = []
  1074. node_ids = set()
  1075. for expr in graph.exprs(recursive):
  1076. for n in expr.inputs + expr.outputs:
  1077. if n._id in node_ids:
  1078. continue
  1079. nodes.append(n)
  1080. node_ids.add(n._id)
  1081. self.nodes = list(sorted(nodes, key=lambda x: x._id))
  1082. def __iter__(self):
  1083. for node in self.nodes:
  1084. yield node
  class BaseFilter:
      """Iterable wrapper with helpers that collect the wrapped items.

      ``__iter__`` simply iterates the wrapped iterable; the ``as_*`` helpers
      are written against ``__iter__``.
      """

      def __init__(self, expr_iter: Iterable):
          """:param expr_iter: the iterable this filter draws its items from."""
          self._iter = expr_iter

      def __iter__(self):
          return iter(self._iter)

      def as_list(self):
          """Return all items as a list."""
          return list(self)

      def as_dict(self):
          """Return an ``OrderedDict`` mapping each item's ``_id`` to the item."""
          return collections.OrderedDict((i._id, i) for i in self)

      def as_unique(self):
          """Return the single item of this filter.

          :raises AssertionError: when the filter does not hold exactly one item.
          """
          rst = self.as_list()
          assert len(rst) == 1, "{} elements found".format(len(rst))
          # Take the item from the list collected above.  Iterating ``self``
          # a second time repeats the filtering and finds nothing when the
          # source can only be consumed once.
          return rst[0]

      def as_count(self):
          """Return the number of items."""
          return sum(1 for _ in self)
  class ExprFilter(BaseFilter):
      """Filter over expressions; each method wraps this filter in a narrower one."""

      def call_function(self, func):
          """Return an ``ExprFilterCallFunction`` over this filter for ``func``."""
          return ExprFilterCallFunction(expr_iter=self, func=func)

      def call_method(self, method):
          """Return an ``ExprFilterCallMethod`` over this filter for ``method``."""
          return ExprFilterCallMethod(expr_iter=self, method=method)

      def expr_id(self, expr_id: List[int]):
          """Return an ``ExprFilterExprId`` over this filter for ``expr_id``."""
          return ExprFilterExprId(expr_iter=self, expr_id=expr_id)
  class NodeFilter(BaseFilter):
      """Filter over nodes; each method wraps this filter in a narrower one."""

      def type(self, owner_type, node_type):
          """Return a ``NodeFilterType`` over this filter."""
          return NodeFilterType(
              expr_iter=self, owner_type=owner_type, node_type=node_type
          )

      def node_id(self, node_id: List[int]):
          """Return a ``NodeFilterNodeId`` over this filter for ``node_id``."""
          return NodeFilterNodeId(expr_iter=self, node_id=node_id)

      def name(self, name: str, ignorecase: bool = True):
          """Return a ``NodeFilterName`` over this filter for the pattern ``name``."""
          return NodeFilterName(node_iter=self, pattern=name, ignorecase=ignorecase)
  class NodeFilterType(NodeFilter):
      """Keep nodes of ``node_type`` whose ``owner`` is an ``owner_type`` instance.

      Nodes without an ``owner`` attribute are dropped.
      """

      def __init__(self, expr_iter, owner_type, node_type):
          super().__init__(expr_iter)
          self.node_type = node_type
          self.owner_type = owner_type

      def __iter__(self):
          for candidate in self._iter:
              if (
                  isinstance(candidate, self.node_type)
                  and hasattr(candidate, "owner")
                  and isinstance(candidate.owner, self.owner_type)
              ):
                  yield candidate
  class NodeFilterNodeId(NodeFilter):
      """Keep nodes whose ``_id`` is contained in ``node_id``.

      A ``node_id`` that is not a ``Sequence`` is wrapped in a one-item list.
      """

      def __init__(self, expr_iter, node_id: List[int]):
          super().__init__(expr_iter)
          self.node_id = node_id if isinstance(node_id, Sequence) else [node_id]

      def __iter__(self):
          yield from (
              candidate for candidate in self._iter if candidate._id in self.node_id
          )
  class NodeFilterName(NodeFilter):
      """Keep nodes whose qualified name equals or glob-matches ``pattern``.

      The name compared is ``<graph name>_<node name>``, preceded by
      ``<graph prefix name>_`` when the node's top graph has a prefix name;
      leading underscores of the joined parts are stripped.
      """

      _re = None

      def __init__(self, node_iter, pattern, ignorecase):
          super().__init__(node_iter)
          self.pattern = pattern
          self._re = self.make_re(pattern, ignorecase)

      @classmethod
      def make_re(cls, pattern, ignorecase=True):
          """Compile the ``fnmatch``-style ``pattern`` into a regular expression."""
          assert isinstance(pattern, str), "bad pattern: {!r}".format(pattern)
          assert isinstance(ignorecase, bool)
          flags = re.IGNORECASE if ignorecase else 0
          return re.compile(fnmatch.translate(pattern), flags=flags)

      def __iter__(self):
          for node in self._iter:
              owner_graph = node.top_graph
              qualified = "{}_{}".format(owner_graph._name, node._name.lstrip("_"))
              prefix = owner_graph._prefix_name
              if prefix:
                  qualified = "{}_{}".format(prefix, qualified.lstrip("_"))
              if self.pattern == qualified or self._re.match(qualified):
                  yield node
  class ExprFilterCallFunction(ExprFilter):
      """Keep ``CallFunction`` expressions, optionally only those calling ``func``.

      With ``func`` left as ``None`` every ``CallFunction`` passes.
      """

      def __init__(self, expr_iter, func: Callable = None):
          super().__init__(expr_iter)
          self.func = func

      def __iter__(self):
          for candidate in self._iter:
              if isinstance(candidate, CallFunction) and (
                  self.func is None or candidate.func == self.func
              ):
                  yield candidate
  class ExprFilterCallMethod(ExprFilter):
      """Keep ``CallMethod`` expressions, optionally only those named ``method``.

      With ``method`` left as ``None`` every ``CallMethod`` passes.
      """

      def __init__(self, expr_iter, method: str = None):
          super().__init__(expr_iter)
          self.method = method

      def __iter__(self):
          for candidate in self._iter:
              if isinstance(candidate, CallMethod) and (
                  self.method is None or candidate.method == self.method
              ):
                  yield candidate
  class ExprFilterExprId(ExprFilter):
      """Keep expressions whose ``_id`` is contained in ``expr_id``.

      An ``expr_id`` that is not a ``Sequence`` is wrapped in a one-item list.
      """

      def __init__(self, expr_iter, expr_id: List[int]):
          super().__init__(expr_iter)
          self.expr_id = expr_id if isinstance(expr_id, Sequence) else [expr_id]

      def __iter__(self):
          yield from (
              candidate for candidate in self._iter if candidate._id in self.expr_id
          )
  class TracedModule(Module):
      """
      `TracedModule` is the Module created by tracing normal module. It owns an argdef to graph(InternalGraph) map. The forward method of `TracedModule` will get a graph from `argdef_graph_map` according to the argdef of input args/kwargs and interpret it.
      """

      # m_node = None  # type: ModuleNode
      argdef_graph_map = None
      argdef_outdef_map = None

      def __init__(self, is_top, argdef_graph_map, argdef_outdef_map, is_qat=False):
          """Store the traced graphs and reset watch/end point bookkeeping.

          :param is_top: stored as ``_is_top``; when true, :attr:`graph`
              refreshes references before returning the graph.
          :param argdef_graph_map: maps an argument tree definition to a graph.
          :param argdef_outdef_map: maps an argument tree definition to the
              tree definition used to unflatten the outputs.
          :param is_qat: stored as ``is_qat``.
          """
          super(TracedModule, self).__init__()
          self.argdef_graph_map = argdef_graph_map
          self.argdef_outdef_map = argdef_outdef_map
          self._is_top = is_top
          self.watch_points = []
          self.watch_node_value = {}
          self.end_points = []
          self.is_qat = is_qat

      def forward(self, *args, **kwargs):
          """Interpret the graph recorded for the structure of the arguments.

          :return: the unflattened outputs, or the flat list returned by the
              graph's ``interpret`` when end points are set.
          """
          inputs, treedef = tree_flatten(((self, *args), kwargs))
          assert treedef in self.argdef_graph_map
          inputs = filter(
              lambda i: isinstance(i, (Module, TracedModuleBuilder, RawTensor)), inputs
          )  # allow TracedModuleBuilder for retrace.
          outputs = self.argdef_graph_map[treedef].interpret(*inputs)
          if self.watch_points:
              # Collect (and remove) the values the graphs recorded for the
              # watched nodes during this run.
              self.watch_node_value = {}
              for n in self.watch_points:
                  self.watch_node_value[n] = n.top_graph._rst.pop(n)
          if self.end_points:
              return outputs
          out_def = self.argdef_outdef_map[treedef]
          outputs = out_def.unflatten(outputs)
          return outputs

      def set_watch_points(self, nodes):
          """Register ``nodes`` (one node or a sequence) as watch points."""
          if not isinstance(nodes, Sequence):
              nodes = [nodes]
          self.watch_points = nodes
          for n in nodes:
              n.top_graph._watch_point.append(n)

      def clear_watch_points(self):
          """Remove all watch points and the values recorded for them."""
          for n in self.watch_points:
              n.top_graph._watch_point = []
          self.watch_points = []
          self.watch_node_value = {}

      def set_end_points(self, nodes):
          """Register ``nodes`` (one node or a sequence) as end points.

          Every node must belong to one of this module's own graphs.
          """
          if not isinstance(nodes, Sequence):
              nodes = [nodes]
          self.end_points = nodes
          graphs = list(self.argdef_graph_map.values())
          for n in nodes:
              assert n.top_graph in graphs
              n.top_graph._end_point.append(n)

      def clear_end_points(self):
          """Remove all end points."""
          for n in self.end_points:
              n.top_graph._end_point = []
          self.end_points = []

      @property
      def graph(self) -> InternalGraph:
          """The single graph of this module (exactly one must exist)."""
          if self._is_top:
              self._update_ref()
          assert len(self.argdef_graph_map) == 1
          return list(self.argdef_graph_map.values())[0]

      def _update_ref(self, actual_node_map: Union[Dict] = None, top_graph=None):
          """Refresh the weak references held by graphs, nodes and expressions.

          For every graph of this module the nodes and expressions get a weak
          reference to the graph, module nodes get a weak reference to the
          object they stand for, and nested ``TracedModule`` objects are
          updated recursively.

          :param actual_node_map: per argument definition, the input lists of
              the calls made to this module from the enclosing graph.
          :param top_graph: the enclosing graph, if any.
          """
          for inp_def, graph in self.argdef_graph_map.items():
              if top_graph is not None:
                  graph._top_graph = weakref.ref(top_graph)
              for n in graph._inputs + graph.outputs:
                  n._top_graph = weakref.ref(graph)
              graph._inputs[0]._owner = weakref.ref(self)
              for i, n in enumerate(graph._inputs):
                  n.actual_node = []
                  if actual_node_map is not None and inp_def in actual_node_map.keys():
                      # Transpose the per-call input lists and take column i.
                      n.actual_node = list(list(zip(*(actual_node_map[inp_def])))[i])
              node2obj = {}
              next_actual_node_map = collections.defaultdict(
                  lambda: collections.defaultdict(list)
              )
              node2obj[graph._inputs[0]] = self
              for expr in graph._exprs:
                  for n in expr.inputs + expr.outputs:
                      n._top_graph = weakref.ref(graph)
                  expr._top_graph = weakref.ref(graph)
                  if isinstance(expr, GetAttr) and isinstance(
                      expr.outputs[0], ModuleNode
                  ):
                      obj = getattr(node2obj[expr.inputs[0]], expr.name)
                      expr.outputs[0]._owner = weakref.ref(obj)
                      node2obj[expr.outputs[0]] = obj
                  if isinstance(expr, Constant) and isinstance(
                      expr.outputs[0], ModuleNode
                  ):
                      obj = expr.value
                      expr.outputs[0]._owner = weakref.ref(obj)
                      node2obj[expr.outputs[0]] = obj
                  if (
                      isinstance(expr, CallMethod)
                      and expr.method == "__call__"
                      and isinstance(expr.inputs[0], ModuleNode)
                  ):
                      obj = node2obj[expr.inputs[0]]
                      if expr.arg_def is not None:
                          next_actual_node_map[obj][expr.arg_def].append(expr.inputs)

              for obj in node2obj.values():
                  if obj is self:
                      continue
                  mnode_map = None
                  if obj in next_actual_node_map.keys():
                      mnode_map = next_actual_node_map[obj]
                  if isinstance(obj, TracedModule):
                      obj._update_ref(mnode_map, graph)

      def flatten(self):
          """
          Get a new module, which eliminates ``GetAttr`` and has no hierarchy.

          :return: :class:`TracedModule`
          """
          new_module = copy.deepcopy(self)
          assert active_module_tracer() is None
          id2name = _init_id2name(new_module, "self")

          def _flatten_subgraph(
              graph: InternalGraph,
              module: Module,
              call=None,
              prefix_name="",
              module_name="",
          ):
              # Returns the flat expression list that replaces ``call`` (or,
              # for the top-level invocation without ``call``, the whole graph).
              if isinstance(prefix_name, str) and prefix_name and prefix_name[-1] != "_":
                  prefix_name += "_"
              if isinstance(module_name, str) and module_name:
                  module_name += "."
              if graph is None or module.is_qat:
                  # No graph to inline: keep the call and feed it the module
                  # through a Constant expression.
                  assert not isinstance(module, TracedModule) or module.is_qat
                  const = Constant(module, id2name[id(module)])
                  m_node = call.inputs[0]
                  if m_node.top_graph != active_module_tracer().current_scope():
                      m_node._name = (
                          active_module_tracer()
                          .current_scope()
                          ._create_unique_name(prefix_name)
                      )
                      m_node._orig_name = id2name[id(module)][5:]
                  const.outputs[0] = m_node
                  const.outputs[0].expr = const
                  return [const, call]
              if call is not None:
                  graph = copy.deepcopy(graph)
              exprs = []
              node2obj = {}
              node2obj[graph._inputs[0]] = module
              if call:
                  node2obj[call.inputs[0]] = module

              # replace inputs for submodule's exprs
              if call:
                  repl_dict = dict(zip(graph._inputs, call.inputs))
                  for ind, out in enumerate(graph.outputs):
                      if isinstance(out.expr, Input):
                          # The sub-graph returns one of its inputs unchanged:
                          # point the users of the call output at the node
                          # that was passed in.
                          assert out in repl_dict
                          call_out = call.outputs[ind]
                          for expr in call.outputs[ind].users:
                              for index, inp in enumerate(expr.inputs):
                                  if inp is call_out:
                                      expr.inputs[index] = repl_dict[out]
                          continue
                      repl_dict[out] = call.outputs[ind]

                  graph._replace_inputs_outputs(repl_dict, prefix_name, module_name)

              for expr in graph._exprs:
                  if isinstance(expr, GetAttr):
                      # replace GetAttr with Constant
                      if isinstance(expr.outputs[0], TensorNode):
                          const = Constant(getattr(node2obj[expr.inputs[0]], expr.name))
                          const.outputs = expr.outputs
                          const.outputs[0].expr = const
                          exprs.append(const)
                      elif isinstance(expr.outputs[0], ModuleNode):
                          node2obj[expr.outputs[0]] = getattr(
                              node2obj[expr.inputs[0]], expr.name
                          )

                  elif isinstance(expr, CallMethod):
                      obj_node = expr.inputs[0]
                      if isinstance(obj_node, ModuleNode):
                          pre_expr = expr.inputs[0].expr
                          if isinstance(pre_expr, GetAttr):
                              (obj,) = pre_expr.interpret(node2obj[pre_expr.inputs[0]])
                              expr_graph = (
                                  obj.argdef_graph_map[expr.arg_def]
                                  if hasattr(obj, "argdef_graph_map")
                                  else None
                              )
                              exprs.extend(
                                  _flatten_subgraph(
                                      expr_graph,
                                      obj,
                                      expr,
                                      prefix_name + obj_node._name.lstrip("_"),
                                      module_name + obj_node._orig_name,
                                  )
                              )
                          else:
                              # module has been replaced.
                              assert isinstance(pre_expr, Constant)
                              exprs.append(expr)
                      else:
                          exprs.append(expr)
                  else:
                      exprs.append(expr)

              if call is not None:
                  for i in call.inputs:
                      i.users.remove(call)

              return exprs

          set_active_module_tracer(module_tracer(lambda x: x, {}))
          try:
              active_module_tracer().push_scope(new_module.graph)
              new_module.graph._exprs = _flatten_subgraph(new_module.graph, new_module)
              new_module.graph.compile()
          finally:
              # Always drop the temporary tracer, so a failure while flattening
              # does not trip the ``active_module_tracer() is None`` assertion
              # of the next flatten call.
              set_active_module_tracer(None)
          # Renumber expressions, then nodes (graph inputs first).
          for _id, expr in enumerate(new_module.graph._exprs):
              expr._id = _id
          total_node_id = 0
          for i in new_module.graph._inputs:
              i._id = total_node_id
              total_node_id += 1
          for expr in new_module.graph._exprs:
              for o in expr.outputs:
                  o._id = total_node_id
                  total_node_id += 1
          return new_module

      def __getstate__(self):
          """Return the instance state without keys that also name ``Module`` class attributes.

          A copy of ``__dict__`` is filtered, so pickling or deep-copying the
          module leaves the instance itself untouched.
          """
          d = self.__dict__.copy()
          for k in Module.__dict__:
              d.pop(k, None)
          return d
  def cpp_apply_module_trace(opdef, *args):
      """Forward ``opdef`` and ``args`` to ``Apply.apply_module_trace_hook``.

      :return: the value returned by the hook.
      """
      hook = Apply.apply_module_trace_hook
      return hook(opdef, *args)
  def register_as_builtin(mod_cls: Type[Module]) -> None:
      """
      Registers class ``mod_cls`` (subclass of megengine.module.Module) as builtin module.

      The class is passed on to ``module_tracer.register_as_builtin``.

      :param mod_cls: the Module class which will be treated as builtin module in tracing
      """
      module_tracer.register_as_builtin(mod_cls)
  def wrap(func: Callable):
      """Register ``func`` as a builtin function and return it unchanged.

      The pair ``(globals of the calling frame, func.__code__.co_name)`` is
      appended to ``Patcher._builtin_functions``.  The call has to be made
      from the top level of a module.

      :param func: a callable that has a ``__code__`` attribute.
      :return: ``func`` itself.
      """
      assert callable(func), "func must be a callable"
      assert hasattr(func, "__code__")
      fn_name = func.__code__.co_name
      this_frame = inspect.currentframe()
      assert this_frame is not None
      caller = this_frame.f_back
      assert caller is not None
      caller_scope = caller.f_code.co_name
      assert (
          caller_scope == "<module>"
      ), "wrap must be called at the top level of a module"
      registration = (caller.f_globals, fn_name)
      Patcher._builtin_functions.append(registration)
      return func
  def _register_all_builtin_module():
      """Register the stock modules, observers and fake-quant classes as builtin.

      Every class found in ``M``, ``M.qat`` and ``M.quantized`` that subclasses
      ``M.Module`` is registered, except ``M.Sequential``; the observer and
      fake-quantize classes listed below are registered afterwards.
      """
      for package in (M, M.qat, M.quantized):
          for _, member in getmembers(package):
              if not isclass(member):
                  continue
              if issubclass(member, M.Module) and member is not M.Sequential:
                  module_tracer.register_as_builtin(member)

      quantization_builtins = (
          Observer,
          MinMaxObserver,
          SyncMinMaxObserver,
          ExponentialMovingAverageObserver,
          SyncExponentialMovingAverageObserver,
          HistogramObserver,
          PassiveObserver,
          LSQ,
          TQT,
          FakeQuantize,
          TM_FakeQuant,
      )
      for builtin_cls in quantization_builtins:
          module_tracer.register_as_builtin(builtin_cls)
  def trace_module(mod: Module, *args: Tensor, **kwargs: Tensor) -> TracedModule:
      """
      Traces module ``mod`` and returns corresponding TracedModule.

      Tracing state (symbolic shape setting, active tracer, tracing flag) is
      reset when the function exits, whether or not tracing succeeded.

      :param mod: the module will be converted to TracedModule
      :param args: the positional arguments passed to forward method of ``mod``
      :param kwargs: the keyword arguments passed to forward method of ``mod``
      :return: the module returned by the builder's ``build``.
      """
      assert active_module_tracer() is None
      assert isinstance(mod, Module)
      # Taken before the ``try`` so the ``finally`` clause below never refers
      # to an unassigned name if this call itself fails.
      use_sym_shape = set_symbolic_shape(True)
      try:
          set_module_tracing()
          set_active_module_tracer(
              module_tracer(_wrapped_function, _init_id2name(mod, "self"))
          )
          with active_module_tracer().patcher:
              global_scope = InternalGraph(name="")
              active_module_tracer().push_scope(global_scope)
              builder = TracedModuleBuilder(mod, True)
              name = mod._name if mod._name else mod.__class__.__name__
              NodeMixin.wrap_safe(builder, Input.make(name, ModuleNode, orig_name="self"))
              inputs, _ = tree_flatten((args, kwargs))
              # Each tensor argument becomes a graph input named after its
              # position in the flattened arguments.
              for idx, i in enumerate(inputs):
                  # assert isinstance(i, Tensor), "not support "
                  if isinstance(i, RawTensor):
                      NodeMixin.wrap_safe(
                          i, Input.make("arg_{}".format(idx), NodeMixin.get_wrapped_type(i))
                      )
              builder(*args, **kwargs)
              active_module_tracer().pop_scope()
              return builder.build()
      finally:
          set_symbolic_shape(use_sym_shape)
          set_active_module_tracer(None)
          unset_module_tracing()

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台。