You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

traced_module.py 23 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676
  1. # -*- coding: utf-8 -*-
  2. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  3. #
  4. # Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  5. #
  6. # Unless required by applicable law or agreed to in writing,
  7. # software distributed under the License is distributed on an
  8. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  9. import collections
  10. import copy
  11. import functools
  12. from inspect import getmembers, isclass, ismethod
  13. from typing import Callable, Dict, Iterable, List, Sequence, Type
  14. import numpy as np
  15. from numpy.lib.arraysetops import isin
  16. from ... import functional as F
  17. from ... import get_logger
  18. from ... import module as M
  19. from ...core._imperative_rt.core2 import Tensor as RawTensor
  20. from ...core._imperative_rt.core2 import (
  21. is_tracing_module,
  22. set_module_tracing,
  23. unset_module_tracing,
  24. )
  25. from ...core._trace_option import set_symbolic_shape
  26. from ...core.tensor.array_method import ArrayMethodMixin
  27. from ...module import Module
  28. from ...tensor import Tensor
  29. from .expr import Apply, CallFunction, CallMethod, Constant, Expr, GetAttr, Input
  30. from .module_tracer import (
  31. Patcher,
  32. active_module_tracer,
  33. module_tracer,
  34. set_active_module_tracer,
  35. )
  36. from .node import ModuleNode, Node, NodeMixin, TensorNode
  37. from .pytree import tree_flatten
# Module-level logger; ``InternalGraph.replace_node`` uses it to report
# replacements that are skipped because they would create a loop.
logger = get_logger(__name__)
  39. def _leaf_type(node):
  40. if isinstance(node, RawTensor):
  41. return (Tensor, TensorNode)
  42. elif isinstance(node, (NodeMixin, Module)):
  43. return (Module, ModuleNode, NodeMixin)
  44. else:
  45. return type(node)
  46. def _is_leaf(node):
  47. assert isinstance(node, RawTensor), type(node)
  48. return isinstance(node, RawTensor)
  49. def _is_const_leaf(node):
  50. if isinstance(node, (RawTensor, NodeMixin, Module)):
  51. return False
  52. return True
  53. class InternalGraph:
  54. """
  55. ``InternalGraph`` is a graph consist of ``Node`` and ``Expr``, it is used to represent the execution procedure of Module's forward method.
  56. Attributes:
  57. _exprs: List of Exprs in order of execution
  58. _inputs: Input Nodes of InternalGraph
  59. _outputs: Output Nodes of InternalGraph
  60. """
  61. _exprs = None # type: List[Expr]
  62. _inputs = None # type: List[Node]
  63. _outputs = None # type: List[Node]
  64. def __init__(self):
  65. self._exprs = []
  66. self._inputs = []
  67. self._outputs = []
  68. def insert(self, expr):
  69. self._exprs.append(expr)
  70. @property
  71. def inputs(self):
  72. return self._inputs
  73. @property
  74. def outputs(self):
  75. return self._outputs
  76. @property
  77. def exprs(self):
  78. return ExprFilter(_expr_iter(self))
  79. def get_call_function(self, func: Callable = None):
  80. return self.exprs.call_function(func)
  81. def get_call_method(self, method: str = None):
  82. return self.exprs.call_method(method)
  83. def add_input(self, i):
  84. self._inputs.append(i)
  85. def add_output(self, o):
  86. self._outputs.append(o)
  87. def get_dep_exprs(self, nodes: Sequence[Node]) -> List[Expr]:
  88. if not isinstance(nodes, Sequence):
  89. nodes = (nodes,)
  90. ret = list()
  91. queue = list(nodes)
  92. while queue:
  93. node = queue.pop()
  94. expr = node.expr
  95. if expr not in ret:
  96. ret.append(expr)
  97. for i in expr.inputs:
  98. if i not in queue:
  99. queue.append(i)
  100. return ret
  101. def insert_call_function(self, func: Callable, nodes: Sequence[Node]):
  102. if not isinstance(nodes, Sequence):
  103. nodes = [nodes]
  104. assert isinstance(func, Callable)
  105. for i in nodes:
  106. assert isinstance(
  107. i, TensorNode
  108. ), "CallFunction only accept TensorNode as inputs"
  109. expr = CallFunction(func)
  110. expr.inputs = nodes
  111. for i in nodes:
  112. i.users.append(expr)
  113. idx = max(self._exprs.index(i.expr) for i in nodes) + 1
  114. self._exprs.insert(idx, expr)
  115. fake_inp_val = tuple(F.zeros(shape=i.shape, dtype=i.dtype) for i in nodes)
  116. fake_out_val = func(*fake_inp_val)
  117. def create_node(val: Tensor):
  118. node = TensorNode(expr)
  119. node.shape = val.shape
  120. node.dtype = val.dtype
  121. return node
  122. out_nodes = list(create_node(i) for i in fake_out_val)
  123. expr.outputs = out_nodes
  124. return out_nodes
  125. def insert_call_method(self, target, method, args):
  126. if not isinstance(args, Sequence):
  127. args = [args]
  128. assert isinstance(target, (TensorNode, ModuleNode))
  129. assert isinstance(method, str)
  130. for i in args:
  131. assert isinstance(i, TensorNode)
  132. expr = CallMethod(method)
  133. expr.inputs = [target, *args]
  134. if isinstance(target, TensorNode):
  135. fake_target_val = F.zeros(shape=target.shape, dtype=target.dtype)
  136. fake_inp_val = tuple(F.zeros(shape=i.shape, dtype=i.dtype) for i in args)
  137. fake_out_val = getattr(fake_target_val, method)(fake_inp_val)
  138. def create_node(val: Tensor):
  139. node = TensorNode(expr)
  140. node.shape = val.shape
  141. node.dtype = val.dtype
  142. return node
  143. out_nodes = list(create_node(i) for i in fake_out_val)
  144. expr.outputs = out_nodes
  145. else:
  146. raise NotImplementedError()
  147. return out_nodes
  148. def replace_node(self, repl_dict: Dict[Node, Node]):
  149. while repl_dict:
  150. node, repl_node = repl_dict.popitem()
  151. # check graph inputs and outputs
  152. assert node not in self.inputs, "Cannot replace inputs"
  153. for i, n in enumerate(self.outputs):
  154. if n is node:
  155. self.outputs[i] = repl_node
  156. # update users of node and repl_node
  157. # update inputs of expr in node.users
  158. dep_exprs = self.get_dep_exprs(repl_node)
  159. i = 0
  160. while i < len(node.users):
  161. n = node.users[i]
  162. if n in dep_exprs:
  163. logger.info("Find a loop: ignore this replacement once")
  164. logger.info("node: %s" % node.__repr__())
  165. logger.info("repl_node: %s" % repl_node.__repr__())
  166. i += 1
  167. continue
  168. repl_node.users.append(n)
  169. node.users.pop(i)
  170. idx = n.inputs.index(node)
  171. n.inputs[idx] = repl_node
  172. def compile(self):
  173. """
  174. Delete unused expr.
  175. """
  176. dep_exprs = self.get_dep_exprs(self.outputs)
  177. i = 0
  178. while i < len(self._exprs):
  179. expr = self._exprs[i]
  180. if expr in dep_exprs:
  181. i += 1
  182. continue
  183. for n in expr.inputs:
  184. n.users.remove(expr)
  185. self._exprs.remove(expr)
  186. def interpret(self, *inputs):
  187. node2value = {}
  188. for n, v in zip(self._inputs, inputs):
  189. node2value[n] = v
  190. for expr in self._exprs:
  191. values = expr.interpret(*list(node2value[i] for i in expr.inputs))
  192. if values is not None:
  193. for n, v in zip(expr.outputs, values):
  194. node2value[n] = v
  195. return list(node2value[i] for i in self._outputs)
  196. def __repr__(self):
  197. return "InternalGraph ({}) {{\n\t{}\n\treturn {}\n}}".format(
  198. ", ".join(str(i) for i in self._inputs),
  199. "\n\t".join(str(i) for i in self._exprs),
  200. ", ".join(str(i) for i in self._outputs),
  201. )
  202. def _get_meth_name(obj, func):
  203. tp = obj if isinstance(obj, type) else type(obj)
  204. for cls in tp.mro():
  205. for k, v in cls.__dict__.items():
  206. if v == func:
  207. return k
  208. return None
  209. def _wrapped_function(orig_func):
  210. @functools.wraps(orig_func)
  211. def wrapped_fn(*args, **kwargs):
  212. if is_tracing_module():
  213. unset_module_tracing()
  214. inputs, tree_def = tree_flatten(
  215. (args, kwargs), leaf_type=_leaf_type, is_const_leaf=_is_const_leaf
  216. )
  217. for i in inputs:
  218. if not NodeMixin.get(i, None):
  219. if isinstance(i, (RawTensor, NodeMixin)):
  220. NodeMixin.wrap_safe(i, Constant.make(i))
  221. meth_name = _get_meth_name(args[0], wrapped_fn)
  222. if meth_name:
  223. self = inputs[0]
  224. if meth_name == "__new__":
  225. if all([not isinstance(i, RawTensor) for i in inputs]):
  226. # only trace Tensor.__new__() when there are tensors in args
  227. set_module_tracing()
  228. return orig_func(*args, **kwargs)
  229. if isinstance(args[1], RawTensor):
  230. node = NodeMixin.get(inputs[1])
  231. inputs[1] = copy.copy(inputs[1])
  232. # copy inputs[1] to avoid tensor and Tensor(tensor) share same m_tensor, which will cause they have same _NodeMixin__node in tracing.
  233. NodeMixin.wrap_safe(inputs[1], node)
  234. args, kwargs = tree_def.unflatten(inputs)
  235. call_node = CallMethod.make(self, meth_name)
  236. else:
  237. call_node = CallMethod.make(NodeMixin.get(self), meth_name)
  238. call_node.add_inputs(inputs[1:])
  239. else:
  240. call_node = CallFunction.make(orig_func)
  241. call_node.add_inputs(inputs)
  242. call_node.arg_def = tree_def
  243. outputs = orig_func(*args, **kwargs)
  244. if meth_name == "__new__":
  245. call_node.add_outputs(outputs, False)
  246. else:
  247. call_node.add_outputs(outputs)
  248. set_module_tracing()
  249. return outputs
  250. return orig_func(*args, **kwargs)
  251. return wrapped_fn
class TracedModuleBuilder(NodeMixin):
    """Stand-in for a ``Module`` while that module is being traced.

    The builder wraps the module ``_mod`` and is associated with a graph node
    through ``NodeMixin``. It records the calls and attribute accesses made on
    it. :meth:`build` then turns it into a ``TracedModule``, or returns the
    original module for builtin modules.
    """

    # The module being traced.
    _mod = None  # type: Module
    # Graph recorded for the module's forward; None for builtin modules.
    _body = None  # type: InternalGraph
    # Whether ``module_tracer`` treats ``_mod`` as builtin (called, not traced into).
    _is_builtin = None  # type: bool

    # Names that ``__getattribute__`` resolves on the builder itself, and that
    # ``build`` skips when copying attributes onto the TracedModule.
    __builder_attributes__ = [
        "_mod",
        "_body",
        "_NodeMixin__node",
        "_is_builtin",
        "build",
    ]

    def __init__(self, mod, is_top_module=False):
        # NOTE(review): ``is_top_module`` is accepted but not used here.
        super(TracedModuleBuilder, self).__init__()
        self._mod = mod
        self._body = None
        self._is_builtin = module_tracer.is_builtin(mod)

    def build(self):
        """Finish tracing and return the resulting module.

        Builtin modules are returned as-is; their node records the module type.
        Otherwise a ``TracedModule`` bound to this builder's node is created, and
        every non-builder attribute is copied onto it. Nested builders are
        built recursively.
        """
        if self._is_builtin:
            node = NodeMixin.get(self)
            node.module_type = type(self._mod)
            return self._mod
        else:
            node = NodeMixin.get(self)
            traced_module = TracedModule(node)
            for k, v in self.__dict__.items():
                if k not in TracedModuleBuilder.__builder_attributes__:
                    if isinstance(v, TracedModuleBuilder):
                        v = v.build()
                    setattr(traced_module, k, v)
                    traced_module.m_node.attr_type_map[k] = type(v)
            return traced_module

    def __call__(self, *args, **kwargs):
        """Record a call of the wrapped module and return its result.

        A ``CallMethod`` expr on this builder's node is created through
        ``CallMethod.make``. Builtin modules are then executed with module
        tracing disabled. For other modules, ``forward`` is traced into a fresh
        ``InternalGraph`` whose inputs are new ``Input`` nodes. The graph (or
        None) and the output tree structure are stored on the node, keyed by
        the argument tree structure.
        """
        assert isinstance(self._mod, Module)
        # prepare args and kwargs for inner graph
        def mark_constant(x):
            node = NodeMixin.get(x, None)
            if node is None:  # capture as constant
                NodeMixin.wrap(x, lambda: Constant.make(x))

        inputs, tree_def = tree_flatten(
            ((self, *args), kwargs), leaf_type=_leaf_type, is_const_leaf=_is_const_leaf
        )
        # Any input without a node yet is captured as a constant.
        for i in inputs:
            mark_constant(i)
        callnode = CallMethod.make(NodeMixin.get(self))
        # inputs[0] is the builder itself (the receiver), already the target.
        callnode.add_inputs(inputs[1:])
        callnode.arg_def = tree_def
        if self._is_builtin:
            # Run the module with tracing off: patched functions called inside
            # it take their untraced path, so only this one call is recorded.
            unset_module_tracing()
            rst = self._mod(*args, **kwargs)
            outputs, out_def = tree_flatten(rst, leaf_type=_leaf_type, is_leaf=_is_leaf)
            set_module_tracing()
            if self._is_builtin:
                self._body = None
        else:
            self._body = InternalGraph()
            active_module_tracer().push_scope(self._body)
            # rebind self to new input node
            orig_self = NodeMixin.get(self)
            NodeMixin.wrap_safe(
                self, Input.make("self", NodeMixin.get_wrapped_type(self))
            )
            # Remember the outer-graph nodes so they can be restored below.
            origin_inp_node = [NodeMixin.get(i, None) for i in inputs[1:]]
            # prepare args and kwargs for inner graph
            def wrap(x):
                NodeMixin.wrap(
                    x, lambda: Input.make(type=NodeMixin.get_wrapped_type(x)),
                )
                return x

            args = [self]
            for i in inputs[1:]:
                args.append(wrap(i))
            args, kwargs = tree_def.unflatten(args)
            # Patch the functions visible from the globals of the module's
            # ``forward`` (an empty dict when there is no such attribute).
            active_module_tracer().patcher.auto_patch(
                getattr(getattr(self._mod, "forward", self._mod), "__globals__", {})
            )
            rst = type(self._mod).forward(*args, **kwargs)
            outputs, out_def = tree_flatten(rst, leaf_type=_leaf_type, is_leaf=_is_leaf)
            for i in (
                outputs if isinstance(outputs, collections.abc.Sequence) else (outputs,)
            ):
                active_module_tracer().current_scope().add_output(NodeMixin.get(i))
            # Restore the outer-graph nodes of self and of the arguments.
            NodeMixin.wrap_safe(self, orig_self)
            for arg, node in zip(inputs[1:], origin_inp_node):
                if node:
                    NodeMixin.wrap_safe(arg, node)
            active_module_tracer().pop_scope()
        # rebind output to outer graph
        callnode.add_outputs(outputs)
        self_node = NodeMixin.get(self)
        self_node.argdef_graph_map[callnode.arg_def] = self._body
        self_node.argdef_outdef_map[callnode.arg_def] = out_def
        return rst

    def __getattr__(self, name):
        """Resolve attributes that normal lookup on the builder did not find.

        Names absent from the module's instance dict are looked up on the
        module's class and bound to the builder via ``__get__``, so methods run
        with the builder as ``self``. Instance attributes are fetched from the
        module; sub-modules are wrapped in new builders. The result is cached on
        the builder with ``setattr`` and wrapped with a ``GetAttr`` node.
        """
        if name not in self._mod.__dict__:
            attr = getattr(type(self._mod), name).__get__(self, type(self))
        else:
            attr = getattr(self._mod, name)
            if isinstance(attr, Module):
                attr = TracedModuleBuilder(attr)
            setattr(self, name, attr)
            NodeMixin.wrap(
                attr,
                lambda: GetAttr.make(
                    NodeMixin.get(self), name, type=NodeMixin.get_wrapped_type(attr)
                ),
            )
        return attr

    def __getattribute__(self, name):
        """Look up ``name`` and record a ``GetAttr`` for module attributes.

        Builder attributes bypass all recording. For names in the wrapped
        module's instance dict:

        - a value without a node gets a lazily created ``GetAttr`` node;
        - a value that already has a node gets a new ``GetAttr`` expr whose
          output is replaced by that existing node.
        """
        if name in TracedModuleBuilder.__builder_attributes__:
            return super().__getattribute__(name)
        else:
            wrapped = super().__getattribute__(name)
            if name in self._mod.__dict__:
                if not NodeMixin.get(wrapped, None):
                    assert not self._is_builtin
                    NodeMixin.wrap(
                        wrapped,
                        lambda: GetAttr.make(
                            NodeMixin.get(self),
                            name,
                            type=NodeMixin.get_wrapped_type(wrapped),
                        ),
                    )
                else:
                    node = NodeMixin.get(wrapped)
                    expr = GetAttr.make(
                        NodeMixin.get(self),
                        name,
                        type=NodeMixin.get_wrapped_type(wrapped),
                    ).expr
                    expr.outputs[0] = node
            return wrapped
  384. class _expr_iter:
  385. def __init__(self, graph: InternalGraph):
  386. self.graph = graph
  387. def __iter__(self):
  388. for expr in self.graph._exprs:
  389. if isinstance(expr, CallMethod) and isinstance(expr.inputs[0], ModuleNode):
  390. yield expr
  391. if expr.graph is not None:
  392. yield from expr.graph.exprs
  393. else:
  394. yield expr
  395. class ExprFilter:
  396. def __init__(self, expr_iter: Iterable):
  397. self._iter = expr_iter
  398. def __iter__(self):
  399. return iter(self._iter)
  400. def call_function(self, func):
  401. return ExprFilterCallFunction(self, func)
  402. def call_method(self, method):
  403. return ExprFilterCallMethod(self, method)
  404. def as_list(self):
  405. return list(self)
  406. def as_dict(self):
  407. raise NotImplementedError("need key")
  408. def as_unique(self):
  409. (expr,) = self
  410. return expr
  411. def as_count(self):
  412. return sum(1 for _ in self)
  413. class ExprFilterCallFunction(ExprFilter):
  414. def __init__(self, expr_iter, func: Callable = None):
  415. super().__init__(expr_iter)
  416. self.func = func
  417. def __iter__(self):
  418. for i in self._iter:
  419. if not isinstance(i, CallFunction):
  420. continue
  421. if self.func is None or i.func == self.func:
  422. yield i
  423. class ExprFilterCallMethod(ExprFilter):
  424. def __init__(self, expr_iter, method: str = None):
  425. super().__init__(expr_iter)
  426. self.method = method
  427. def __iter__(self):
  428. for i in self._iter:
  429. if not isinstance(i, CallMethod):
  430. continue
  431. if self.method is None or i.method == self.method:
  432. yield i
  433. class TracedModule(Module):
  434. """
  435. `TracedModule` is the Module created by tracing normal module. It owns a ModuleNode(m_node). `TracedModule` can not be called directly. It can be
  436. interpreted by CallMethod Expr.
  437. """
  438. m_node = None # type: ModuleNode
  439. def __init__(self, node):
  440. super(TracedModule, self).__init__()
  441. self.m_node = node
  442. def forward(self, *args, **kwargs):
  443. inputs, treedef = tree_flatten(
  444. ((self, *args), kwargs), _leaf_type, is_const_leaf=_is_const_leaf
  445. )
  446. assert treedef in self.m_node.argdef_graph_map
  447. inputs = filter(
  448. lambda i: isinstance(i, (Module, TracedModuleBuilder, RawTensor)), inputs
  449. ) # allow TracedModuleBuilder for retrace.
  450. outputs = self.m_node.argdef_graph_map[treedef].interpret(*inputs)
  451. out_def = self.m_node.argdef_outdef_map[treedef]
  452. outputs = out_def.unflatten(outputs)
  453. return outputs
  454. @property
  455. def graph(self):
  456. assert len(self.m_node.argdef_graph_map) == 1
  457. return list(self.m_node.argdef_graph_map.values())[0]
  458. @property
  459. def exprs(self):
  460. return self.graph.exprs
  461. def flatten(self):
  462. """
  463. Get a new module, which eliminates ``GetAttr`` and has no hierarchy.
  464. :return: :class:`TracedModule`
  465. """
  466. new_module = copy.deepcopy(self)
  467. def _flatten_subgraph(graph, module, call=None):
  468. if graph is None:
  469. assert not isinstance(module, TracedModule)
  470. const = Constant(module)
  471. const.outputs[0] = call.inputs[0]
  472. const.outputs[0].expr = const
  473. return [const, call]
  474. exprs = []
  475. for expr in graph._exprs:
  476. # replace inputs for submodule's expr
  477. for idx, inp in enumerate(expr.inputs):
  478. if call and inp in graph._inputs:
  479. inp_idx = graph._inputs.index(inp)
  480. expr.inputs[idx] = call.inputs[inp_idx]
  481. call.inputs[inp_idx].users.append(expr)
  482. # replace outputs for submodule's expr
  483. for idx, outp in enumerate(expr.outputs):
  484. if call and outp in graph._outputs:
  485. oup_idx = graph._outputs.index(outp)
  486. expr.outputs[idx] = call.outputs[oup_idx]
  487. call.outputs[oup_idx].expr = expr
  488. if isinstance(expr, GetAttr):
  489. # replace GetAttr with Constant
  490. if isinstance(expr.outputs[0], TensorNode):
  491. const = Constant(getattr(module, expr.name))
  492. const.outputs = expr.outputs
  493. const.outputs[0].expr = const
  494. exprs.append(const)
  495. elif isinstance(expr, CallMethod):
  496. obj_node = expr.inputs[0]
  497. if isinstance(obj_node, ModuleNode):
  498. pre_expr = expr.inputs[0].expr
  499. if isinstance(pre_expr, GetAttr):
  500. (obj,) = expr.inputs[0].expr.interpret(module)
  501. exprs.extend(_flatten_subgraph(expr.graph, obj, expr))
  502. else:
  503. # module has been replaced.
  504. assert isinstance(pre_expr, Constant)
  505. else:
  506. exprs.append(expr)
  507. else:
  508. exprs.append(expr)
  509. if call is not None:
  510. for i in call.inputs:
  511. i.users.remove(call)
  512. return exprs
  513. new_module.graph._exprs = _flatten_subgraph(new_module.graph, new_module)
  514. return new_module
  515. def __getstate__(self):
  516. d = self.__dict__
  517. for k in Module.__dict__:
  518. d.pop(k, None)
  519. return d
  520. def cpp_apply_module_trace(opdef, *args):
  521. return Apply.apply_module_trace_hook(opdef, *args)
  522. def register_as_builtin(mod_cls: Type[Module]) -> None:
  523. """
  524. Registers class ``mod_cls`` (subclass of megengine.module.Module) as builtin module.
  525. param mod_cls: the Module class which will be threated as builtin module in tracing
  526. """
  527. module_tracer.register_as_builtin(mod_cls)
  528. def _register_all_builtin_module():
  529. for sub_mod in [M, M.qat, M.quantized]:
  530. for m in getmembers(sub_mod):
  531. if (
  532. isclass(m[1])
  533. and issubclass(m[1], M.Module)
  534. and m[1] is not M.Sequential
  535. ):
  536. module_tracer.register_as_builtin(m[1])
  537. def trace_module(mod: Module, *args: Tensor, **kwargs: Tensor) -> TracedModule:
  538. """
  539. Traces module ``mod`` and returns corresponding TracedModule.
  540. param mod: the module will be converted to TracedModule
  541. param input: the positional arguments passed to forward method of ``mod``
  542. param kwargs: the keyword arguments passed to forward method of ``mod``
  543. """
  544. assert active_module_tracer() is None
  545. try:
  546. use_sym_shape = set_symbolic_shape(True)
  547. set_module_tracing()
  548. set_active_module_tracer(module_tracer(_wrapped_function))
  549. with active_module_tracer().patcher:
  550. global_scope = InternalGraph()
  551. active_module_tracer().push_scope(global_scope)
  552. builder = TracedModuleBuilder(mod, True)
  553. NodeMixin.wrap_safe(builder, Input.make("TopModule", ModuleNode))
  554. inputs, _ = tree_flatten((args, kwargs), is_const_leaf=_is_const_leaf)
  555. for _, i in enumerate(inputs):
  556. if isinstance(i, RawTensor):
  557. NodeMixin.wrap_safe(
  558. i, Input.make("arg_{}".format(_), NodeMixin.get_wrapped_type(i))
  559. )
  560. builder(*args, **kwargs)
  561. active_module_tracer().pop_scope()
  562. return builder.build()
  563. finally:
  564. set_symbolic_shape(use_sym_shape)
  565. set_active_module_tracer(None)
  566. unset_module_tracing()

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境，不用区分 CPU 和 GPU 版。如果想要运行 GPU 程序，请确保机器本身配有 GPU 硬件设备并安装好驱动。如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉，欢迎访问 MegStudio 平台。