You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

traced_module.py 26 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762
  1. # -*- coding: utf-8 -*-
  2. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  3. #
  4. # Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  5. #
  6. # Unless required by applicable law or agreed to in writing,
  7. # software distributed under the License is distributed on an
  8. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  9. import collections
  10. import copy
  11. import functools
  12. import weakref
  13. from inspect import getmembers, isclass, ismethod
  14. from typing import Callable, Dict, Iterable, List, Sequence, Type
  15. import numpy as np
  16. from numpy.lib.arraysetops import isin
  17. from ... import functional as F
  18. from ... import get_logger
  19. from ... import module as M
  20. from ...core._imperative_rt.core2 import Tensor as RawTensor
  21. from ...core._imperative_rt.core2 import (
  22. is_tracing_module,
  23. set_module_tracing,
  24. unset_module_tracing,
  25. )
  26. from ...core._trace_option import set_symbolic_shape
  27. from ...core.tensor.array_method import ArrayMethodMixin
  28. from ...module import Module
  29. from ...tensor import Tensor
  30. from .expr import Apply, CallFunction, CallMethod, Constant, Expr, GetAttr, Input
  31. from .module_tracer import (
  32. Patcher,
  33. active_module_tracer,
  34. module_tracer,
  35. set_active_module_tracer,
  36. )
  37. from .node import ModuleNode, Node, NodeMixin, TensorNode
  38. from .pytree import tree_flatten
# Module-level logger; InternalGraph.replace_node reports skipped (loop-creating)
# replacements through it.
logger = get_logger(__name__)
  40. def _leaf_type(node):
  41. if isinstance(node, RawTensor):
  42. return (Tensor, TensorNode)
  43. elif isinstance(node, (NodeMixin, Module)):
  44. return (Module, ModuleNode, NodeMixin)
  45. else:
  46. return type(node)
  47. def _is_leaf(node):
  48. assert isinstance(node, RawTensor), "doesn't support {} in return values".format(
  49. type(node)
  50. )
  51. return isinstance(node, RawTensor)
  52. def _is_const_leaf(node):
  53. if isinstance(node, (RawTensor, NodeMixin, Module)):
  54. return False
  55. return True
  56. class InternalGraph:
  57. """
  58. ``InternalGraph`` is a graph consist of ``Node`` and ``Expr``, it is used to represent the execution procedure of Module's forward method.
  59. Attributes:
  60. _exprs: List of Exprs in order of execution
  61. _inputs: Input Nodes of InternalGraph
  62. _outputs: Output Nodes of InternalGraph
  63. """
  64. _exprs = None # type: List[Expr]
  65. _inputs = None # type: List[Node]
  66. _outputs = None # type: List[Node]
  67. def __init__(self):
  68. self._exprs = []
  69. self._inputs = []
  70. self._outputs = []
  71. def insert(self, expr):
  72. self._exprs.append(expr)
  73. @property
  74. def inputs(self):
  75. return self._inputs
  76. @property
  77. def outputs(self):
  78. return self._outputs
  79. @property
  80. def exprs(self):
  81. return ExprFilter(_expr_iter(self))
  82. def get_call_function(self, func: Callable = None):
  83. return self.exprs.call_function(func)
  84. def get_call_method(self, method: str = None):
  85. return self.exprs.call_method(method)
  86. def add_input(self, i):
  87. self._inputs.append(i)
  88. def add_output(self, o):
  89. self._outputs.append(o)
  90. def _replace_inputs_outputs(self, repl_dict):
  91. for node, repl_node in repl_dict.items():
  92. assert node in self._inputs or node in self._outputs
  93. for i in node.users:
  94. if i not in repl_node.users:
  95. repl_node.users.append(i)
  96. for idx, i in enumerate(self._inputs):
  97. if i in repl_dict:
  98. self._inputs[idx] = repl_dict[i]
  99. for idx, o in enumerate(self._outputs):
  100. if o in repl_dict:
  101. self._outputs[idx] = repl_dict[o]
  102. self._outputs[idx].expr = node.expr
  103. for expr in self._exprs:
  104. for idx, i in enumerate(expr.inputs):
  105. if i in repl_dict:
  106. expr.inputs[idx] = repl_dict[i]
  107. for idx, o in enumerate(expr.outputs):
  108. if o in repl_dict:
  109. expr.outputs[idx] = repl_dict[o]
  110. def get_dep_exprs(self, nodes: Sequence[Node]) -> List[Expr]:
  111. if not isinstance(nodes, Sequence):
  112. nodes = (nodes,)
  113. ret = list()
  114. queue = list(nodes)
  115. while queue:
  116. node = queue.pop()
  117. expr = node.expr
  118. if expr not in ret:
  119. ret.append(expr)
  120. for i in expr.inputs:
  121. if i not in queue:
  122. queue.append(i)
  123. return ret
  124. def insert_call_function(self, func: Callable, nodes: Sequence[Node]):
  125. if not isinstance(nodes, Sequence):
  126. nodes = [nodes]
  127. assert isinstance(func, Callable)
  128. for i in nodes:
  129. assert isinstance(
  130. i, TensorNode
  131. ), "CallFunction only accept TensorNode as inputs"
  132. expr = CallFunction(func)
  133. expr.inputs = nodes
  134. for i in nodes:
  135. i.users.append(expr)
  136. idx = max(self._exprs.index(i.expr) for i in nodes) + 1
  137. self._exprs.insert(idx, expr)
  138. fake_inp_val = tuple(F.zeros(shape=i.shape, dtype=i.dtype) for i in nodes)
  139. fake_out_val = func(*fake_inp_val)
  140. def create_node(val: Tensor):
  141. node = TensorNode(expr)
  142. node.shape = val.shape
  143. node.dtype = val.dtype
  144. return node
  145. out_nodes = list(create_node(i) for i in fake_out_val)
  146. expr.outputs = out_nodes
  147. return out_nodes
  148. def insert_call_method(self, target, method, args):
  149. if not isinstance(args, Sequence):
  150. args = [args]
  151. assert isinstance(target, (TensorNode, ModuleNode))
  152. assert isinstance(method, str)
  153. for i in args:
  154. assert isinstance(i, TensorNode)
  155. expr = CallMethod(method)
  156. expr.inputs = [target, *args]
  157. if isinstance(target, TensorNode):
  158. fake_target_val = F.zeros(shape=target.shape, dtype=target.dtype)
  159. fake_inp_val = tuple(F.zeros(shape=i.shape, dtype=i.dtype) for i in args)
  160. fake_out_val = getattr(fake_target_val, method)(fake_inp_val)
  161. def create_node(val: Tensor):
  162. node = TensorNode(expr)
  163. node.shape = val.shape
  164. node.dtype = val.dtype
  165. return node
  166. out_nodes = list(create_node(i) for i in fake_out_val)
  167. expr.outputs = out_nodes
  168. else:
  169. raise NotImplementedError()
  170. return out_nodes
  171. def replace_node(self, repl_dict: Dict[Node, Node]):
  172. while repl_dict:
  173. node, repl_node = repl_dict.popitem()
  174. # check graph inputs and outputs
  175. assert node not in self.inputs, "Cannot replace inputs"
  176. for i, n in enumerate(self.outputs):
  177. if n is node:
  178. self.outputs[i] = repl_node
  179. # update users of node and repl_node
  180. # update inputs of expr in node.users
  181. dep_exprs = self.get_dep_exprs(repl_node)
  182. i = 0
  183. while i < len(node.users):
  184. n = node.users[i]
  185. if n in dep_exprs:
  186. logger.info("Find a loop: ignore this replacement once")
  187. logger.info("node: %s" % node.__repr__())
  188. logger.info("repl_node: %s" % repl_node.__repr__())
  189. i += 1
  190. continue
  191. repl_node.users.append(n)
  192. node.users.pop(i)
  193. idx = n.inputs.index(node)
  194. n.inputs[idx] = repl_node
  195. def compile(self):
  196. """
  197. Delete unused expr.
  198. """
  199. dep_exprs = self.get_dep_exprs(self.outputs)
  200. i = 0
  201. while i < len(self._exprs):
  202. expr = self._exprs[i]
  203. if expr in dep_exprs:
  204. i += 1
  205. continue
  206. for n in expr.inputs:
  207. n.users.remove(expr)
  208. self._exprs.remove(expr)
  209. def interpret(self, *inputs):
  210. node2value = {}
  211. for n, v in zip(self._inputs, inputs):
  212. node2value[n] = v
  213. for expr in self._exprs:
  214. values = expr.interpret(*list(node2value[i] for i in expr.inputs))
  215. if values is not None:
  216. for n, v in zip(expr.outputs, values):
  217. node2value[n] = v
  218. return list(node2value[i] for i in self._outputs)
  219. def __repr__(self):
  220. return "InternalGraph ({}) {{\n\t{}\n\treturn {}\n}}".format(
  221. ", ".join(str(i) for i in self._inputs),
  222. "\n\t".join(str(i) for i in self._exprs),
  223. ", ".join(str(i) for i in self._outputs),
  224. )
  225. def _get_meth_name(obj, func):
  226. tp = obj if isinstance(obj, type) else type(obj)
  227. for cls in tp.mro():
  228. for k, v in cls.__dict__.items():
  229. if v == func:
  230. return k
  231. return None
  232. def _wrapped_function(orig_func):
  233. @functools.wraps(orig_func)
  234. def wrapped_fn(*args, **kwargs):
  235. if is_tracing_module():
  236. unset_module_tracing()
  237. inputs, tree_def = tree_flatten(
  238. (args, kwargs), leaf_type=_leaf_type, is_const_leaf=_is_const_leaf
  239. )
  240. for i in inputs:
  241. if not NodeMixin.get(i, None):
  242. if isinstance(i, (RawTensor, NodeMixin)):
  243. NodeMixin.wrap_safe(i, Constant.make(i))
  244. meth_name = _get_meth_name(args[0], wrapped_fn)
  245. if meth_name:
  246. self = inputs[0]
  247. if meth_name == "__new__":
  248. if all([not isinstance(i, RawTensor) for i in inputs]):
  249. # only trace Tensor.__new__() when there are tensors in args
  250. set_module_tracing()
  251. return orig_func(*args, **kwargs)
  252. if isinstance(args[1], RawTensor):
  253. node = NodeMixin.get(inputs[1])
  254. inputs[1] = copy.copy(inputs[1])
  255. # copy inputs[1] to avoid tensor and Tensor(tensor) share same m_tensor, which will cause they have same _NodeMixin__node in tracing.
  256. NodeMixin.wrap_safe(inputs[1], node)
  257. args, kwargs = tree_def.unflatten(inputs)
  258. call_node = CallMethod.make(self, meth_name)
  259. else:
  260. call_node = CallMethod.make(NodeMixin.get(self), meth_name)
  261. call_node.add_inputs(inputs[1:])
  262. else:
  263. call_node = CallFunction.make(orig_func)
  264. call_node.add_inputs(inputs)
  265. call_node.arg_def = tree_def
  266. outputs = orig_func(*args, **kwargs)
  267. call_node.add_outputs(outputs)
  268. set_module_tracing()
  269. return outputs
  270. return orig_func(*args, **kwargs)
  271. return wrapped_fn
  272. class TracedModuleBuilder(NodeMixin):
  273. _mod = None # type: Module
  274. _body = None # type: InternalGraph
  275. _is_builtin = None # type: bool
  276. _argdef_graph_map = None # type: Dict[Treedef, "InternalGraph"]
  277. _argdef_outdef_map = None # type: Dict[Treedef, Treedef]
  278. nodes = None
  279. __builder_attributes__ = [
  280. "_mod",
  281. "_body",
  282. "_NodeMixin__node",
  283. "_is_builtin",
  284. "build",
  285. "_argdef_graph_map",
  286. "_argdef_outdef_map",
  287. "nodes",
  288. ]
  289. def __init__(self, mod, is_top_module=False):
  290. super(TracedModuleBuilder, self).__init__()
  291. self._mod = mod
  292. self._body = None
  293. self._is_builtin = module_tracer.is_builtin(mod)
  294. self._argdef_graph_map = {}
  295. self._argdef_outdef_map = {}
  296. self.nodes = set()
  297. def build(self):
  298. if self._is_builtin:
  299. for node in self.nodes:
  300. node.module_type = type(self._mod)
  301. # node._owner = weakref.ref(self._mod)
  302. return self._mod
  303. else:
  304. traced_module = TracedModule(
  305. self._argdef_graph_map, self._argdef_outdef_map
  306. )
  307. for _, g in self._argdef_graph_map.items():
  308. g.compile()
  309. # for node in self.nodes:
  310. # node._owner = weakref.ref(traced_module)
  311. for k, v in self.__dict__.items():
  312. if k not in TracedModuleBuilder.__builder_attributes__:
  313. if isinstance(v, TracedModuleBuilder):
  314. v = v.build()
  315. setattr(traced_module, k, v)
  316. return traced_module
  317. def _record_wrapped_nodes(self, node):
  318. self.nodes.add(node)
  319. def __call__(self, *args, **kwargs):
  320. assert isinstance(self._mod, Module)
  321. # prepare args and kwargs for inner graph
  322. def mark_constant(x):
  323. node = NodeMixin.get(x, None)
  324. if node is None: # capture as constant
  325. NodeMixin.wrap(x, lambda: Constant.make(x))
  326. inputs, tree_def = tree_flatten(
  327. ((self, *args), kwargs), leaf_type=_leaf_type, is_const_leaf=_is_const_leaf
  328. )
  329. for i in inputs:
  330. mark_constant(i)
  331. callnode = CallMethod.make(NodeMixin.get(self))
  332. callnode.add_inputs(inputs[1:])
  333. callnode.arg_def = tree_def
  334. if self._is_builtin:
  335. unset_module_tracing()
  336. rst = self._mod(*args, **kwargs)
  337. outputs, out_def = tree_flatten(rst, leaf_type=_leaf_type, is_leaf=_is_leaf)
  338. set_module_tracing()
  339. if self._is_builtin:
  340. self._body = None
  341. else:
  342. self_node = None
  343. if self._body:
  344. self_node = self._body.inputs[0]
  345. self._body = InternalGraph()
  346. active_module_tracer().push_scope(self._body)
  347. # rebind self to new input node
  348. orig_self = NodeMixin.get(self)
  349. if self_node:
  350. NodeMixin.wrap_safe(self, self_node)
  351. active_module_tracer().current_scope().add_input(self_node)
  352. else:
  353. NodeMixin.wrap_safe(
  354. self,
  355. self_node
  356. if self_node
  357. else Input.make("self", NodeMixin.get_wrapped_type(self)),
  358. )
  359. origin_inp_node = [NodeMixin.get(i, None) for i in inputs[1:]]
  360. # prepare args and kwargs for inner graph
  361. def wrap(x):
  362. if isinstance(x, (RawTensor, NodeMixin)):
  363. NodeMixin.wrap(
  364. x, lambda: Input.make(type=NodeMixin.get_wrapped_type(x)),
  365. )
  366. return x
  367. args = [self]
  368. for i in inputs[1:]:
  369. args.append(wrap(i))
  370. args, kwargs = tree_def.unflatten(args)
  371. active_module_tracer().patcher.auto_patch(
  372. getattr(getattr(self._mod, "forward", self._mod), "__globals__", {})
  373. )
  374. rst = type(self._mod).forward(*args, **kwargs)
  375. outputs, out_def = tree_flatten(rst, leaf_type=_leaf_type, is_leaf=_is_leaf)
  376. for i in (
  377. outputs if isinstance(outputs, collections.abc.Sequence) else (outputs,)
  378. ):
  379. active_module_tracer().current_scope().add_output(NodeMixin.get(i))
  380. NodeMixin.wrap_safe(self, orig_self)
  381. for arg, node in zip(inputs[1:], origin_inp_node):
  382. if node:
  383. NodeMixin.wrap_safe(arg, node)
  384. active_module_tracer().pop_scope()
  385. # rebind output to outer graph
  386. callnode.add_outputs(outputs)
  387. self._argdef_graph_map[callnode.arg_def] = self._body
  388. self._argdef_outdef_map[callnode.arg_def] = out_def
  389. return rst
  390. def __getattr__(self, name):
  391. if name not in self._mod.__dict__:
  392. attr = getattr(type(self._mod), name).__get__(self, type(self))
  393. else:
  394. attr = getattr(self._mod, name)
  395. if isinstance(attr, Module):
  396. attr = TracedModuleBuilder(attr)
  397. setattr(self, name, attr)
  398. NodeMixin.wrap(
  399. attr,
  400. lambda: GetAttr.make(
  401. NodeMixin.get(self), name, type=NodeMixin.get_wrapped_type(attr)
  402. ),
  403. )
  404. return attr
  405. def __getattribute__(self, name):
  406. if name in TracedModuleBuilder.__builder_attributes__:
  407. return super().__getattribute__(name)
  408. else:
  409. wrapped = super().__getattribute__(name)
  410. if name in self._mod.__dict__:
  411. assert not self._is_builtin
  412. if isinstance(wrapped, (NodeMixin, RawTensor)):
  413. NodeMixin.wrap(
  414. wrapped,
  415. lambda: GetAttr.make(
  416. NodeMixin.get(self),
  417. name,
  418. type=NodeMixin.get_wrapped_type(wrapped),
  419. ),
  420. )
  421. """
  422. else:
  423. node = NodeMixin.get(wrapped)
  424. expr = node.expr
  425. assert isinstance(expr, GetAttr)
  426. if expr not in active_module_tracer().current_scope()._exprs:
  427. active_module_tracer().current_scope().insert(expr)
  428. """
  429. return wrapped
  430. class _expr_iter:
  431. def __init__(self, graph: InternalGraph):
  432. self.graph = graph
  433. def __iter__(self):
  434. for expr in self.graph._exprs:
  435. if isinstance(expr, CallMethod) and isinstance(expr.inputs[0], ModuleNode):
  436. yield expr
  437. if expr.graph is not None:
  438. yield from expr.graph.exprs
  439. else:
  440. yield expr
  441. class ExprFilter:
  442. def __init__(self, expr_iter: Iterable):
  443. self._iter = expr_iter
  444. def __iter__(self):
  445. return iter(self._iter)
  446. def call_function(self, func):
  447. return ExprFilterCallFunction(self, func)
  448. def call_method(self, method):
  449. return ExprFilterCallMethod(self, method)
  450. def as_list(self):
  451. return list(self)
  452. def as_dict(self):
  453. raise NotImplementedError("need key")
  454. def as_unique(self):
  455. (expr,) = self
  456. return expr
  457. def as_count(self):
  458. return sum(1 for _ in self)
  459. class ExprFilterCallFunction(ExprFilter):
  460. def __init__(self, expr_iter, func: Callable = None):
  461. super().__init__(expr_iter)
  462. self.func = func
  463. def __iter__(self):
  464. for i in self._iter:
  465. if not isinstance(i, CallFunction):
  466. continue
  467. if self.func is None or i.func == self.func:
  468. yield i
  469. class ExprFilterCallMethod(ExprFilter):
  470. def __init__(self, expr_iter, method: str = None):
  471. super().__init__(expr_iter)
  472. self.method = method
  473. def __iter__(self):
  474. for i in self._iter:
  475. if not isinstance(i, CallMethod):
  476. continue
  477. if self.method is None or i.method == self.method:
  478. yield i
  479. class TracedModule(Module):
  480. """
  481. `TracedModule` is the Module created by tracing normal module. It owns an argdef to graph(InternalGraph) map. The forward method of `TracedModule` will get a graph from `argdef_graph_map` according to the argdef of input args/kwargs and interpret it.
  482. """
  483. # m_node = None # type: ModuleNode
  484. argdef_graph_map = None
  485. argdef_outdef_map = None
  486. def __init__(self, argdef_graph_map, argdef_outdef_map):
  487. super(TracedModule, self).__init__()
  488. self.argdef_graph_map = argdef_graph_map
  489. self.argdef_outdef_map = argdef_outdef_map
  490. def forward(self, *args, **kwargs):
  491. inputs, treedef = tree_flatten(
  492. ((self, *args), kwargs), _leaf_type, is_const_leaf=_is_const_leaf
  493. )
  494. assert treedef in self.argdef_graph_map
  495. inputs = filter(
  496. lambda i: isinstance(i, (Module, TracedModuleBuilder, RawTensor)), inputs
  497. ) # allow TracedModuleBuilder for retrace.
  498. outputs = self.argdef_graph_map[treedef].interpret(*inputs)
  499. out_def = self.argdef_outdef_map[treedef]
  500. outputs = out_def.unflatten(outputs)
  501. return outputs
  502. @property
  503. def graph(self):
  504. self._update_modulenode_ref()
  505. assert len(self.argdef_graph_map) == 1
  506. return list(self.argdef_graph_map.values())[0]
  507. def _update_modulenode_ref(self):
  508. for _, graph in self.argdef_graph_map.items():
  509. graph._inputs[0]._owner = weakref.ref(self)
  510. node2obj = {}
  511. node2obj[graph._inputs[0]] = self
  512. for expr in graph._exprs:
  513. if isinstance(expr, GetAttr) and isinstance(
  514. expr.outputs[0], ModuleNode
  515. ):
  516. obj = getattr(node2obj[expr.inputs[0]], expr.name)
  517. expr.outputs[0]._owner = weakref.ref(obj)
  518. node2obj[expr.outputs[0]] = obj
  519. if isinstance(obj, TracedModule):
  520. obj._update_modulenode_ref()
  521. @property
  522. def exprs(self):
  523. return self.graph.exprs
  524. def flatten(self):
  525. """
  526. Get a new module, which eliminates ``GetAttr`` and has no hierarchy.
  527. :return: :class:`TracedModule`
  528. """
  529. new_module = copy.deepcopy(self)
  530. def _flatten_subgraph(graph, module, call=None):
  531. if graph is None:
  532. assert not isinstance(module, TracedModule)
  533. const = Constant(module)
  534. const.outputs[0] = call.inputs[0]
  535. const.outputs[0].expr = const
  536. return [const, call]
  537. if call is not None:
  538. graph = copy.deepcopy(graph)
  539. exprs = []
  540. node2obj = {}
  541. node2obj[graph._inputs[0]] = module
  542. if call:
  543. node2obj[call.inputs[0]] = module
  544. for expr in graph._exprs:
  545. # replace inputs for submodule's exprx
  546. if call:
  547. repl_dict = dict(
  548. zip(graph._inputs + graph._outputs, call.inputs + call.outputs)
  549. )
  550. graph._replace_inputs_outputs(repl_dict)
  551. if isinstance(expr, GetAttr):
  552. # replace GetAttr with Constant
  553. if isinstance(expr.outputs[0], TensorNode):
  554. const = Constant(getattr(node2obj[expr.inputs[0]], expr.name))
  555. const.outputs = expr.outputs
  556. const.outputs[0].expr = const
  557. exprs.append(const)
  558. elif isinstance(expr.outputs[0], ModuleNode):
  559. node2obj[expr.outputs[0]] = getattr(
  560. node2obj[expr.inputs[0]], expr.name
  561. )
  562. elif isinstance(expr, CallMethod):
  563. obj_node = expr.inputs[0]
  564. if isinstance(obj_node, ModuleNode):
  565. pre_expr = expr.inputs[0].expr
  566. if isinstance(pre_expr, GetAttr):
  567. (obj,) = pre_expr.interpret(node2obj[pre_expr.inputs[0]])
  568. expr_graph = (
  569. obj.argdef_graph_map[expr.arg_def]
  570. if hasattr(obj, "argdef_graph_map")
  571. else None
  572. )
  573. exprs.extend(_flatten_subgraph(expr_graph, obj, expr))
  574. else:
  575. # module has been replaced.
  576. assert isinstance(pre_expr, Constant)
  577. exprs.append(expr)
  578. else:
  579. exprs.append(expr)
  580. else:
  581. exprs.append(expr)
  582. if call is not None:
  583. for i in call.inputs:
  584. i.users.remove(call)
  585. return exprs
  586. new_module.graph._exprs = _flatten_subgraph(new_module.graph, new_module)
  587. return new_module
  588. def __getstate__(self):
  589. d = self.__dict__
  590. for k in Module.__dict__:
  591. d.pop(k, None)
  592. return d
  593. def cpp_apply_module_trace(opdef, *args):
  594. return Apply.apply_module_trace_hook(opdef, *args)
  595. def register_as_builtin(mod_cls: Type[Module]) -> None:
  596. """
  597. Registers class ``mod_cls`` (subclass of megengine.module.Module) as builtin module.
  598. param mod_cls: the Module class which will be threated as builtin module in tracing
  599. """
  600. module_tracer.register_as_builtin(mod_cls)
  601. def _register_all_builtin_module():
  602. for sub_mod in [M, M.qat, M.quantized]:
  603. for m in getmembers(sub_mod):
  604. if (
  605. isclass(m[1])
  606. and issubclass(m[1], M.Module)
  607. and m[1] is not M.Sequential
  608. ):
  609. module_tracer.register_as_builtin(m[1])
  610. def trace_module(mod: Module, *args: Tensor, **kwargs: Tensor) -> TracedModule:
  611. """
  612. Traces module ``mod`` and returns corresponding TracedModule.
  613. param mod: the module will be converted to TracedModule
  614. param input: the positional arguments passed to forward method of ``mod``
  615. param kwargs: the keyword arguments passed to forward method of ``mod``
  616. """
  617. assert active_module_tracer() is None
  618. try:
  619. use_sym_shape = set_symbolic_shape(True)
  620. set_module_tracing()
  621. set_active_module_tracer(module_tracer(_wrapped_function))
  622. with active_module_tracer().patcher:
  623. global_scope = InternalGraph()
  624. active_module_tracer().push_scope(global_scope)
  625. builder = TracedModuleBuilder(mod, True)
  626. NodeMixin.wrap_safe(builder, Input.make("TopModule", ModuleNode))
  627. inputs, _ = tree_flatten((args, kwargs), is_const_leaf=_is_const_leaf)
  628. for _, i in enumerate(inputs):
  629. if isinstance(i, RawTensor):
  630. NodeMixin.wrap_safe(
  631. i, Input.make("arg_{}".format(_), NodeMixin.get_wrapped_type(i))
  632. )
  633. builder(*args, **kwargs)
  634. active_module_tracer().pop_scope()
  635. return builder.build()
  636. finally:
  637. set_symbolic_shape(use_sym_shape)
  638. set_active_module_tracer(None)
  639. unset_module_tracing()

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境，不用区分 CPU 和 GPU 版。如果想要运行 GPU 程序，请确保机器本身配有 GPU 硬件设备并安装好驱动。如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉，欢迎访问 MegStudio 平台。