
tracing.py 33 kB

import collections
import contextlib
import functools
import itertools
import json
import typing
import warnings
import weakref

import numpy as np

from ..core._imperative_rt import GraphProfiler
from ..core._imperative_rt.ops import OprAttr
from ..core._trace_option import set_tensor_shape
from ..core.ops.special import Const
from ..core.tensor import megbrain_graph as G
from ..core.tensor.core import OpBase, TensorBase, TensorWrapperBase, apply
from ..core.tensor.raw_tensor import OpDef, RawTensor, as_raw_tensor
from ..core.tensor.tensor import Tensor
from .sublinear_memory_config import SublinearMemoryConfig


class TraceMismatchError(RuntimeError):
    pass


active_trace = None
skip_tracing = False


@contextlib.contextmanager
def exclude_from_trace():
    global skip_tracing
    if skip_tracing:
        yield
        return
    try:
        skip_tracing = True
        if active_trace is not None:
            active_trace._begin_excluded_region()
        yield
    finally:
        skip_tracing = False
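
# A minimal usage sketch (hypothetical names): code run inside
# `exclude_from_trace` executes eagerly and is not recorded into the trace,
# which is useful for logging and other side effects that should not be
# replayed on later calls:
#
#     with exclude_from_trace():
#         print("loss =", loss.numpy())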


class TensorInfo:
    __slots__ = (
        # collected attributes
        "external",
        "exported",
        "data_read",
        "shape_read",
        "value_read",
        "device",
        "dtype",
        "shape",
        "bound_data",
        # resources for execution
        "varnode",
        "data_setter",
        "shape_reader",
        "value_reader",
        "data_reader",
    )

    def __init__(self):
        self.exported = None
        self.data_read = None
        self.shape_read = None
        self.value_read = None
        self.bound_data = None

        self.data_setter = None
        self.shape_reader = None
        self.value_reader = None
        self.data_reader = None


class trace:
    """
    Wraps a callable and provides:

    * tracing via :meth:`.trace` and :meth:`.dump`
    * accelerated evaluation via :meth:`.__call__`

    :param function: the function to be traced.
    :param symbolic: whether to apply symbolic execution for tracing. Default: False
    :param capture_as_const: capture global vars or closures as const value. Default: False
    :param sublinear_memory_config: configuration for sublinear memory optimization.
        If not None, it enables sublinear memory optimization with the given setting.
    :param profiling: whether to profile compiled trace. Default: False
    :param opt_level: optimization level for compiling trace.
    :param tensor_shape: whether to use static tensor shape for tracing. Default: True
    """

    def __new__(cls, *args, **kwargs):
        if not args:
            return functools.partial(cls, **kwargs)
        return super().__new__(cls)
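
    # Returning a functools.partial from __new__ when there are no positional
    # arguments lets `trace` serve both as a bare decorator and as a
    # parameterized one (hypothetical usage):
    #
    #     @trace                 # bare: __new__ receives the function
    #     def f(x): ...
    #
    #     @trace(symbolic=True)  # parameterized: __new__ sees only kwargs
    #     def g(x): ...          # and defers construction via partial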

    def __init__(
        self,
        function,
        symbolic=False,
        capture_as_const=False,
        sublinear_memory_config: SublinearMemoryConfig = None,
        profiling: bool = False,
        opt_level: int = None,
        tensor_shape: bool = True,
    ):
        self.__wrapped__ = function
        self._symbolic = symbolic
        self._capture_as_const = capture_as_const
        self._sublinear_memory_config = sublinear_memory_config
        self._profiling = profiling
        self._profiler = None
        self._graph_opt_level = opt_level
        self._tensor_shape = tensor_shape

        self._untraced = True
        self._tinfo = []  # handle -> TensorInfo
        self._seq = []
        self._pc = 0
        self._graph = None
        self._need_reset_nodes = None
        self._lazy_eval_graph = None
        self._lazy_eval_tensors = []
        self._lazy_eval_tensor_count = 0
        self._active_tensors = weakref.WeakSet()
        self._tensor_remaps = None
        self._inputs_to_restore = None
        self._arg_bindings = None
        self._kwarg_bindings = None
        self._output_bindings = None
        self._output_names = None
        set_tensor_shape(self._tensor_shape)

    def _new_handle(self):
        handle = len(self._tinfo)
        info = TensorInfo()
        self._tinfo.append(info)
        return handle, info

    def _apply_op(self, op, args):
        assert not self._untraced
        # check against trace
        if self._pc >= len(self._seq):
            raise TraceMismatchError("trace should end here, but more op observed")
        record = self._seq[self._pc]
        op_, ihandles, ohandles = record
        if op != op_:
            # FIXME: will be removed once better rng implementation is done
            if isinstance(op, OprAttr) and (
                op.type in ("UniformRNG", "GaussianRNG") and op.type == op_.type
            ):
                if op.param[8:] != op_.param[8:]:
                    raise TraceMismatchError("op different from last time")
            else:
                raise TraceMismatchError("op different from last time")
        if len(ihandles) != len(args):
            raise TraceMismatchError("op input size different from last time")

        for h, x in zip(ihandles, args):
            info = self._tinfo[h]
            if info.external:
                if (
                    x.__class__ is CompiledTensorProxy
                    and not self._tinfo[x._CompiledTensorProxy__handle].exported
                ):
                    raise TraceMismatchError(
                        "failed to capture: input was an external tensor "
                        "last time, got an internal tensor this time"
                    )
                if info.bound_data:
                    if x.__class__ is CompiledTensorProxy:
                        raise TraceMismatchError(
                            "const capture violated: was an external tensor "
                            "last time, got an internal tensor this time"
                        )
                    if x._handle != info.bound_data._handle:
                        if not np.array_equal(x.numpy(), info.bound_data.numpy()):
                            raise TraceMismatchError(
                                "const capture violated: got "
                                "a different tensor this time"
                            )
                else:
                    if info.dtype != x.dtype:
                        raise TraceMismatchError(
                            "failed to capture: different dtype from last time"
                        )
                    if info.device != x.device:
                        raise TraceMismatchError(
                            "failed to capture: different device from last time"
                        )
                    info.data_setter.set_value(x._dev_tensor())
            else:
                if x.__class__ is not CompiledTensorProxy:
                    if x not in self._tensor_remaps:
                        raise TraceMismatchError(
                            "unexpected capture: trying to use an external tensor as "
                            "input, but that input was an internal tensor last time"
                        )
                    else:
                        x = self._tensor_remaps[x]
                if x._CompiledTensorProxy__handle != h:
                    raise TraceMismatchError(
                        "mis-wiring: input edge to a data flow "
                        "graph node is different from last time"
                    )

        self._pc += 1
        outputs = tuple([CompiledTensorProxy(h) for h in ohandles])
        self._active_tensors.update(outputs)
        return outputs
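
    # Note: on compiled runs ops are not executed eagerly. _apply_op only
    # replays the recorded sequence: it validates each input against the
    # previous run, feeds external inputs through their data setters, and
    # returns proxies into the compiled graph.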

    def _record_op(self, op, inputs, outputs):
        if skip_tracing:
            for x in inputs:
                h = getattr(x, "_TraceMixin__handle", None)
                if h is not None:
                    self._tinfo[h].data_read = True
            return

        ihandles = []
        for x in inputs:
            h = getattr(x, "_TraceMixin__handle", None)
            if h is None or (not self._capture_as_const and self._tinfo[h].exported):
                h, info = self._new_handle()
                info.external = True
                info.device = x.device
                info.dtype = x.dtype
                info.shape = x.shape
                if self._capture_as_const:
                    info.bound_data = x

            ihandles.append(h)

        ohandles = []
        for x in outputs:
            h, info = self._new_handle()
            ohandles.append(h)
            info.external = False
            TraceMixin._TraceMixin__inject(x, h)

        self._seq.append((op, tuple(ihandles), tuple(ohandles)))
        self._active_tensors.update(outputs)

    def _record_const(self, op, outputs):
        pass

    @contextlib.contextmanager
    def _setup(self):
        global active_trace
        if active_trace:
            raise NotImplementedError("sorry, not implemented: nested trace")
        active_trace = self

        if self._untraced:
            apply.enable(apply_with_tracing)
            apply.enable(apply_const_with_tracing)
            if self._symbolic:
                apply.enable(apply_symbolic_mode)
                apply.enable(apply_const_symbolic_mode)
                self._lazy_eval_graph = G.Graph()
        else:
            apply.enable(apply_compiled_mode)
            if self._graph is None:
                self._compile()
            self._graph.execute()

        yield

        escaped_tensors = tuple(self._active_tensors)
        self._active_tensors.clear()

        if self._untraced:
            for x in escaped_tensors:
                info = self._tinfo[x._TraceMixin__handle]
                info.data_read = True
                x._TraceMixin__restore()
            if self._inputs_to_restore:
                for x in self._inputs_to_restore:
                    x._TraceMixin__restore()
            if self._symbolic:
                # eval lazy eval tensors
                if self._lazy_eval_tensors:
                    lazy_eval_tensors = []
                    visited = set()
                    readers = []
                    for x in self._lazy_eval_tensors:
                        x = x()
                        if x is None or x in visited:
                            continue
                        reader = G.OutputNode(x._LazyEvalTensor__varnode).outputs[0]
                        readers.append(reader)
                        lazy_eval_tensors.append(x)
                        visited.add(x)
                    self._apply_graph_options(self._lazy_eval_graph)
                    self._lazy_eval_graph.compile(*readers)
                    self._lazy_eval_graph()
                    for r, x in zip(readers, lazy_eval_tensors):
                        assign_raw_tensor(x, as_raw_tensor(r.op.get_value()))
                self._lazy_eval_graph = None
                self._lazy_eval_tensors = None
            self._untraced = False
        else:
            if self._pc != len(self._seq):
                raise TraceMismatchError("premature end")
            for x in escaped_tensors:
                assign_raw_tensor(x, as_raw_tensor(x._dev_tensor()))
            self._graph.wait()
            self._reset_exec_env()
            self._pc = 0

        self._tensor_remaps = None
        apply.disable(apply_with_tracing)
        apply.disable(apply_const_with_tracing)
        apply.disable(apply_symbolic_mode)
        apply.disable(apply_const_symbolic_mode)
        apply.disable(apply_compiled_mode)
        active_trace = None

    def _begin_excluded_region(self):
        if self._capture_as_const:
            raise RuntimeError(
                "exclude_from_trace cannot be used with capture_as_const"
            )
        if self._untraced:
            # conditionally reading a compiled tensor in excluded region
            # is permitted, so we have to assume every tensor might be read
            for x in self._active_tensors:
                info = self._tinfo[x._TraceMixin__handle]
                info.exported = True
                info.data_read = True

    def _apply_graph_options(self, graph):
        graph.options.seq_opt.enable_seq_comp_node_opt = False
        # graph opt level
        if self._graph_opt_level is not None:
            graph.options.graph_opt_level = self._graph_opt_level
        # sublinear
        if self._sublinear_memory_config is not None:
            graph.options.enable_sublinear_memory_opt = True
            sublinear_config = graph.options.sublinear_mem_config
            sublinear_config.lb_memory = self._sublinear_memory_config.lb_memory
            sublinear_config.genetic_nr_iter = (
                self._sublinear_memory_config.genetic_nr_iter
            )
            sublinear_config.genetic_pool_size = (
                self._sublinear_memory_config.genetic_pool_size
            )
            sublinear_config.thresh_nr_try = self._sublinear_memory_config.thresh_nr_try
            sublinear_config.num_worker = self._sublinear_memory_config.num_worker
        # profile
        if self._profiling:
            self._profiler = GraphProfiler(graph)
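
    # A sketch of opting into sublinear memory optimization (hypothetical
    # function and values; the config fields are the ones consumed above):
    #
    #     config = SublinearMemoryConfig(genetic_nr_iter=50)
    #     traced_f = trace(my_func, sublinear_memory_config=config)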

    def _compile(self):
        graph = self._graph = G.Graph()
        graph.options.no_force_inplace = True
        self._apply_graph_options(graph)
        # graph.options.graph_opt_level = 0
        need_reset_nodes = self._need_reset_nodes = []
        # links enforce ordering of I/O nodes
        links = ()
        readers = []

        if self._capture_as_const:
            for h in itertools.chain(self._arg_bindings, self._kwarg_bindings.values()):
                info = self._tinfo[h]
                opnode = info.data_setter = G.InputNode(
                    device=info.device, dtype=info.dtype, shape=info.shape, graph=graph
                )
                need_reset_nodes.append(opnode)
                info.varnode = opnode.outputs[0]
                links += opnode.outputs[1:]

        for op, ihandles, ohandles in self._seq:
            ivars = []
            for h in ihandles:
                info = self._tinfo[h]
                if not hasattr(info, "varnode"):
                    assert info.external
                    if info.bound_data:
                        info.varnode = graph.make_const(info.bound_data._dev_tensor())
                    else:
                        opnode = info.data_setter = G.InputNode(
                            *links,
                            device=info.device,
                            dtype=info.dtype,
                            shape=info.shape,
                            graph=graph,
                        )
                        need_reset_nodes.append(opnode)
                        info.varnode, *links = opnode.outputs
                ivars.append(info.varnode)
            ovars = apply(op, *ivars)
            assert len(ovars) == len(ohandles)
            for h, v in zip(ohandles, ovars):
                info = self._tinfo[h]
                info.varnode = v

                def add_reader(opnode):
                    nonlocal links
                    need_reset_nodes.append(opnode)
                    readers.append(opnode.outputs[0])
                    links = opnode.outputs

                if info.data_read:
                    # Shape can be obtained from data so doesn't need its own
                    # output node. On the other hand, value is read separately
                    # to leverage eager h2d copy
                    info.shape_read = False
                    opnode = info.data_reader = G.OutputNode(v, *links)
                    add_reader(opnode)
                if info.value_read:
                    opnode = info.value_reader = G.ValueOutputNode(v, *links)
                    add_reader(opnode)
                if info.shape_read:
                    opnode = info.shape_reader = G.AttrOutputNode(v, *links)
                    add_reader(opnode)

        graph.compile(*readers)

    def _reset_exec_env(self):
        for opnode in self._need_reset_nodes:
            opnode.reset()

    def _require_shape(self, handle):
        info = self._tinfo[handle]
        info.shape_read = True

    def _require_value(self, handle):
        info = self._tinfo[handle]
        info.value_read = True

    def _require_data(self, handle):
        info = self._tinfo[handle]
        info.data_read = True

    def __call__(self, *args, **kwargs):
        with self._setup():
            if self._capture_as_const:
                self._process_inputs(*args, **kwargs)
            outputs = self.__wrapped__(*args, **kwargs)
            if self._capture_as_const:
                self._process_outputs(outputs)
            return outputs

    def dump(self, file, *, arg_names=None, output_names=None, append=False, **kwargs):
        r"""Serializes trace to file system.

        :param file: output file, could be file object or filename.
        :param arg_names: names of the input tensors in the traced function.
        :param output_names: names of the output tensors in the traced function,
            use the default name if not specified.
        :param append: whether output is appended to ``file``.
            Only works when ``file`` is str.

        :Keyword Arguments:

            * enable_io16xc32 --
              whether to use float16 for I/O between oprs and use
              float32 as internal computation precision. Note the output var would be
              changed to float16.
            * enable_ioc16 --
              whether to use float16 for both I/O and computation
              precision.
            * enable_hwcd4 --
              whether to use NHWCD4 data layout. This is faster on some
              OpenCL backends.
            * enable_nchw88 --
              whether to use NCHW88 data layout, currently
              used in the x86 AVX backend.
            * enable_nchw44 --
              whether to use NCHW44 data layout, currently
              used in the ARM backend.
            * enable_nchw44_dot --
              whether to use NCHW44_dot data layout, currently
              used in the ARMv8.2+dotprod backend.
            * enable_nchw4 --
              whether to use NCHW4 data layout, currently
              used in the NVIDIA backend (based on cuDNN).
            * enable_nchw32 --
              whether to use NCHW32 data layout, currently
              used in the NVIDIA backend with TensorCore (based on cuDNN).
            * enable_chwn4 --
              whether to use CHWN4 data layout, currently
              used in the NVIDIA backend with TensorCore.
            * enable_fuse_conv_bias_nonlinearity --
              whether to fuse conv+bias+nonlinearity into one opr.
            * enable_fuse_conv_bias_with_z --
              whether to fuse conv_bias with z input for inference on the
              NVIDIA backend (this optimization pass will cause the output
              precision of inference to mismatch that of training).
        """
        if not self._capture_as_const:
            raise ValueError(
                "you must specify capture_as_const=True at __init__ to use dump"
            )
        if self._untraced:
            raise RuntimeError("should run at least once before calling dump")
        if self._output_names and output_names:
            raise TypeError(
                "cannot specify output_names when output is already in dict format"
            )
        if output_names and not isinstance(output_names, collections.abc.Sequence):
            output_names = (output_names,)
        if output_names and len(output_names) != len(self._output_bindings):
            raise ValueError(
                "wrong number of output_names, should be {} values".format(
                    len(self._output_bindings)
                )
            )
        if arg_names and not isinstance(arg_names, collections.abc.Sequence):
            arg_names = (arg_names,)
        if arg_names and len(arg_names) != len(self._arg_bindings):
            raise ValueError(
                "wrong number of arg_names, should be {} values".format(
                    len(self._arg_bindings)
                )
            )
        output_names = output_names or self._output_names

        h2v = {}
        graph = G.Graph()

        for i, h in enumerate(self._arg_bindings):
            info = self._tinfo[h]
            h2v[h] = graph.make_h2d(
                dtype=info.dtype,
                device=info.device,
                shape=info.shape,
                name=arg_names[i] if arg_names else None,
            )
        for k, h in self._kwarg_bindings.items():
            info = self._tinfo[h]
            h2v[h] = graph.make_h2d(
                dtype=info.dtype, device=info.device, shape=info.shape, name=k
            )

        for op, ihandles, ohandles in self._seq:
            ivars = []
            for h in ihandles:
                info = self._tinfo[h]
                if h not in h2v:
                    assert info.external
                    assert info.bound_data
                    h2v[h] = graph.make_const(info.bound_data._dev_tensor())
                ivars.append(h2v[h])
            ovars = apply(op, *ivars)
            assert len(ovars) == len(ohandles)
            h2v.update(zip(ohandles, ovars))

        dest_vars = []
        for i, h in enumerate(self._output_bindings):
            v = h2v[h]
            if output_names:
                v.name = output_names[i]
            dest_vars.append(v)

        dest_vars = G.optimize_for_inference(dest_vars, **kwargs)

        if isinstance(file, str):
            permission = "ab" if append else "wb"
            file = open(file, permission)
        file.write(G.dump_graph(*dest_vars))

    def _process_inputs(self, *args, **kwargs):
        if self._untraced:
            self._inputs_to_restore = []

            def record_input(x):
                if x is None:
                    return
                h, info = self._new_handle()
                info.external = False
                info.device = x.device
                info.dtype = x.dtype
                info.shape = x.shape
                TraceMixin._TraceMixin__inject(x, h)
                self._inputs_to_restore.append(x)
                return h

            self._arg_bindings = []
            for i, x in enumerate(args):
                x = find_raw_tensor(x)
                if x is None:
                    raise TypeError(
                        "positional arguments should all be tensor "
                        "but args[%d] cannot be recognized as one" % i
                    )
                self._arg_bindings.append(record_input(x))

            self._kwarg_bindings = {}
            for k, x in kwargs.items():
                x = find_raw_tensor(x)
                if x is not None:
                    self._kwarg_bindings[k] = record_input(x)
        else:
            if len(args) != len(self._arg_bindings):
                raise TraceMismatchError("positional argument length mismatch")

            self._tensor_remaps = {}

            for i, (h, x) in enumerate(zip(self._arg_bindings, args)):
                x = find_raw_tensor(x)
                if x is None:
                    raise TypeError(
                        "positional arguments should all be tensor "
                        "but args[%d] cannot be recognized as one" % i
                    )
                info = self._tinfo[h]
                if x.dtype != info.dtype:
                    raise TypeError("args[%d].dtype different from last time" % i)
                if x.device != info.device:
                    raise TypeError("args[%d].device different from last time" % i)
                info.data_setter.set_value(x._dev_tensor())
                self._tensor_remaps[x] = CompiledTensorProxy(h)

            kwargs_tensors = {}
            for k, x in kwargs.items():
                x = find_raw_tensor(x)
                if x is not None:
                    kwargs_tensors[k] = x
            if set(kwargs_tensors) != set(self._kwarg_bindings):
                too_many = set(kwargs_tensors) - set(self._kwarg_bindings)
                too_few = set(self._kwarg_bindings) - set(kwargs_tensors)
                if too_many:
                    raise TraceMismatchError(
                        "keyword arguments found to be tensor this time "
                        "but were non-tensor previously: %s" % " ".join(too_many)
                    )
                if too_few:
                    raise TraceMismatchError(
                        "keyword arguments found to be non-tensor this time "
                        "but were tensor previously: %s" % " ".join(too_few)
                    )
            for k, h in self._kwarg_bindings.items():
                x = kwargs_tensors[k]
                info = self._tinfo[h]
                if x.dtype != info.dtype:
                    raise TypeError("kwargs[%s].dtype different from last time" % k)
                if x.device != info.device:
                    raise TypeError("kwargs[%s].device different from last time" % k)
                info.data_setter.set_value(x._dev_tensor())
                self._tensor_remaps[x] = CompiledTensorProxy(h)

    def _process_outputs(self, outputs):
        output_names = None
        if isinstance(outputs, collections.abc.Mapping):
            output_names, outputs = zip(*sorted(outputs.items()))
        elif not isinstance(outputs, collections.abc.Sequence):
            outputs = (outputs,)

        if not self._untraced:
            if output_names != self._output_names:
                too_many = set(output_names) - set(self._output_names)
                too_few = set(self._output_names) - set(output_names)
                if too_many:
                    raise TraceMismatchError(
                        "output has more keys than last time: %s" % " ".join(too_many)
                    )
                if too_few:
                    raise TraceMismatchError(
                        "output has fewer keys than last time: %s" % " ".join(too_few)
                    )
            if len(outputs) != len(self._output_bindings):
                raise TraceMismatchError("output size differs from last time")
        else:
            self._output_names = output_names
            self._output_bindings = []

        for i, x in enumerate(outputs):
            x = find_raw_tensor(x)
            if x is None:
                raise TypeError("every item of return value should be tensor")
            if self._untraced:
                if not isinstance(x, TraceMixin):
                    raise RuntimeError("output is not computed from inputs")
                h = x._TraceMixin__handle
                self._output_bindings.append(h)
            else:
                if not isinstance(x, CompiledTensorProxy):
                    raise RuntimeError("output is not computed from inputs")
                h = x._CompiledTensorProxy__handle
                if h != self._output_bindings[i]:
                    raise TraceMismatchError(
                        "retval[%s] is a different tensor than last time"
                        % (output_names and output_names[i] or i)
                    )

    def get_profile(self):
        """
        Get profiling result for compiled trace.

        :return: a json compatible object.
        """
        if not self._profiler:
            raise RuntimeError("trace is not set with profiling=True")
        return json.loads(self._profiler.get())
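
# A minimal end-to-end sketch (hypothetical; assumes the MegEngine tensor
# constructor `mge.tensor` and functional API `F` are imported):
#
#     @trace(symbolic=True, capture_as_const=True)
#     def f(x):
#         return F.exp(x)
#
#     f(mge.tensor([1.0]))              # first call records the trace
#     f(mge.tensor([2.0]))              # later calls replay the compiled graph
#     f.dump("f.mge", arg_names=["x"])  # serialization needs capture_as_const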


class CompiledTensorProxy(RawTensor):
    """
    Duck-typed RawTensor
    """

    def __init__(self, handle):
        self.__handle = handle
        self.__info = active_trace._tinfo[handle]
        self.__shape = None
        self.__data = None
        self.__value = None

    @property
    def dtype(self):
        return self.__info.varnode.dtype

    @property
    def device(self):
        return self.__info.varnode.device

    @property
    def shape(self):
        if self.__shape is None:
            if self.__info.shape_read:
                self.__shape = self.__info.shape_reader.get_value().shape
            elif self.__info.data_read:
                self.__shape = self._dev_tensor().shape
            else:
                raise TraceMismatchError("shape of this tensor is not read in trace")
        return self.__shape

    def numpy(self):
        if self.__value is None:
            if self.__info.value_read:
                self.__value = self.__info.value_reader.get_value()
            elif self.__info.data_read:
                self.__value = self._dev_tensor().numpy()
            else:
                raise TraceMismatchError("value of this tensor is not read in trace")
        return self.__value

    def _dev_tensor(self):
        if self.__data is None:
            if not self.__info.data_read:
                raise TraceMismatchError("raw data of this tensor is not read in trace")
            self.__data = self.__info.data_reader.get_value()
        return self.__data

    def __del__(self):
        if self.__info.shape_read and self.__shape is not None:
            self.__info.shape_reader.drop_value()
        if self.__info.value_read and self.__value is not None:
            self.__info.value_reader.drop_value()
        if self.__info.data_read and self.__data is not None:
            self.__info.data_reader.drop_value()
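
# CompiledTensorProxy exposes the three read granularities wired up in
# trace._compile: shape (AttrOutputNode), value (ValueOutputNode) and raw
# device data (OutputNode). Each is only available if the corresponding read
# was observed during tracing; otherwise a TraceMismatchError is raised.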


class LazyEvalTensor(RawTensor):
    def __init__(self, varnode):
        self.__varnode = varnode

    @property
    def dtype(self):
        return self.__varnode.dtype

    @property
    def device(self):
        return self.__varnode.device

    @property
    def shape(self):
        return self.__varnode.shape

    def numpy(self):
        return self.__varnode.value

    def _dev_tensor(self):
        raise RuntimeError("cannot access data during symbolic tracing")


class TraceMixin:
    __subclass_cache = {}

    def __inject(self, handle):
        cache = __class__.__subclass_cache
        cls = self.__class__
        subcls = cache.get(cls)
        if subcls is None:
            subcls = cache[cls] = type("Traced" + cls.__name__, (__class__, cls), {})
        self.__class__ = subcls
        self.__handle = handle
        self.__cls = cls
        return self

    def __restore(self):
        cls = self.__cls
        del self.__handle
        del self.__cls
        self.__class__ = cls
        return self

    @property
    def shape(self):
        if not skip_tracing:
            active_trace._require_shape(self.__handle)
        return super().shape

    def numpy(self):
        if not skip_tracing:
            active_trace._require_value(self.__handle)
        return super().numpy()

    def _dev_tensor(self):
        if not skip_tracing:
            active_trace._require_data(self.__handle)
        return super()._dev_tensor()


class TracedRawTensor(TraceMixin, RawTensor):
    pass


class TracedLazyTensor(TraceMixin, LazyEvalTensor):
    pass
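
# TraceMixin.__inject retypes a live tensor in place: it synthesizes a
# subclass via type("Traced" + cls.__name__, (TraceMixin, cls), {}), caches
# it, and assigns it to the instance's __class__; __restore undoes the swap.
# A standalone sketch of the same technique (hypothetical classes):
#
#     class Mixin:
#         def hello(self):
#             return "traced"
#
#     class Base:
#         pass
#
#     obj = Base()
#     obj.__class__ = type("TracedBase", (Mixin, Base), {})
#     assert obj.hello() == "traced"
#     obj.__class__ = Base  # restore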


def assign_raw_tensor(lhs, rhs):
    handle = rhs._handle
    rhs.__dict__.clear()
    lhs.__dict__.clear()
    lhs.__class__ = RawTensor
    lhs.__init__(handle)


# this hook turns RawTensor into LazyEvalTensor
@apply.register()
def apply_symbolic_mode(op: OpDef, *args: RawTensor):
    graph = active_trace._lazy_eval_graph
    ivars = [
        getattr(x, "_LazyEvalTensor__varnode", None)
        or graph.make_const(x._dev_tensor())
        for x in args
    ]
    ovars = apply(op, *ivars)
    outputs = [LazyEvalTensor(v) for v in ovars]
    active_trace._lazy_eval_tensors.extend(weakref.ref(oup) for oup in outputs)
    return outputs


apply.disable(apply_symbolic_mode)


@apply.register()
def apply_const_symbolic_mode(op: Const, *args: RawTensor):
    graph = active_trace._lazy_eval_graph
    ret = LazyEvalTensor(graph.make_const(op.value, dtype=op.dtype, device=op.device))
    active_trace._lazy_eval_tensors.append(weakref.ref(ret))
    return (ret,)


apply.disable(apply_const_symbolic_mode)


@apply.register()
def apply_compiled_mode(op: OpDef, *args: RawTensor):
    if skip_tracing:
        args = [
            as_raw_tensor(x._dev_tensor()) if x.__class__ is CompiledTensorProxy else x
            for x in args
        ]
        return apply.super(op, *args)
    return active_trace._apply_op(op, args)


apply.disable(apply_compiled_mode)


# this hook injects TraceMixin
@apply.register()
def apply_with_tracing(op: OpDef, *args: RawTensor):
    outputs = apply.super(op, *args)
    active_trace._record_op(op, args, outputs)
    return outputs


apply.disable(apply_with_tracing)


@apply.register()
def apply_const_with_tracing(op: Const, *args: RawTensor):
    outputs = apply.super(op, *args)
    active_trace._record_const(op, outputs)
    return outputs


apply.disable(apply_const_with_tracing)
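
# The hooks above are registered on `apply` but start out disabled;
# trace._setup enables the subset matching the current phase (eager
# recording, symbolic lazy-eval, or compiled replay) and disables them all
# again when the traced call finishes.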


class BrokenRawTensor(RawTensor):
    def __getattribute__(self, _):
        raise RuntimeError("broken due to misuse of tracing")

    def __setattr__(self, *_):
        raise RuntimeError("broken due to misuse of tracing")


@functools.singledispatch
def find_raw_tensor(x):
    return None


@find_raw_tensor.register(RawTensor)
def _(x):
    return x


@find_raw_tensor.register(TensorWrapperBase)
def _(x):
    x = getattr(x, "__wrapped__", None)
    if x is not None:
        return find_raw_tensor(x)


@find_raw_tensor.register(Tensor)
def _(x):
    x = getattr(x, "_data", None)
    if x is not None:
        return find_raw_tensor(x)
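
# Since find_raw_tensor is a functools.singledispatch function, support for a
# new wrapper type can be registered without modifying this module
# (hypothetical type shown):
#
#     @find_raw_tensor.register(MyTensorWrapper)
#     def _(x):
#         return find_raw_tensor(x.inner)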
