
tracing.py 30 kB

import collections
import contextlib
import functools
import itertools
import json
import typing
import warnings
import weakref

import numpy as np

from ..core._imperative_rt import GraphProfiler
from ..core._imperative_rt.ops import OprAttr
from ..core.ops.special import Const
from ..core.tensor import megbrain_graph as G
from ..core.tensor.core import OpBase, TensorBase, TensorWrapperBase, apply
from ..core.tensor.raw_tensor import OpDef, RawTensor, as_raw_tensor
from ..core.tensor.tensor import Tensor
from .sublinear_memory_config import SublinearMemoryConfig


class TraceMismatchError(RuntimeError):
    pass

active_trace = None
skip_tracing = False


@contextlib.contextmanager
def exclude_from_trace():
    global skip_tracing
    if skip_tracing:
        yield
        return
    try:
        skip_tracing = True
        if active_trace is not None:
            active_trace._begin_excluded_region()
        yield
    finally:
        skip_tracing = False
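

# A minimal usage sketch (hypothetical function, assuming a trace-decorated
# caller): code run under exclude_from_trace() executes eagerly and is not
# recorded, which is useful for side effects such as logging.
#
#     @trace
#     def step(x):
#         y = x * 2
#         with exclude_from_trace():
#             print(y.numpy())  # reads a value without baking the read into the trace
#         return y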


class TensorInfo:
    __slots__ = (
        # collected attributes
        "external",
        "exported",
        "data_read",
        "shape_read",
        "value_read",
        "device",
        "dtype",
        "shape",
        "bound_data",
        # resources for execution
        "varnode",
        "data_setter",
        "shape_reader",
        "value_reader",
        "data_reader",
    )

    def __init__(self):
        self.exported = None
        self.data_read = None
        self.shape_read = None
        self.value_read = None
        self.bound_data = None

        self.data_setter = None
        self.shape_reader = None
        self.value_reader = None
        self.data_reader = None


class trace:
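    """Trace a function so that later calls replay it as a compiled graph.

    Usage sketch (illustrative, not exhaustive)::

        @trace                 # plain decorator form
        def f(x):
            return x * 2

        @trace(symbolic=True)  # parameterized form, handled by __new__ below
        def g(x):
            return x + 1

    The first call records every op; subsequent calls replay the recorded
    sequence through a compiled graph and raise TraceMismatchError on any
    divergence from the recording.
    """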

    def __new__(cls, *args, **kwargs):
        if not args:
            return functools.partial(cls, **kwargs)
        return super().__new__(cls)

    def __init__(
        self,
        function,
        symbolic=False,
        capture_as_const=False,
        sublinear_memory_config: SublinearMemoryConfig = None,
        profiling: bool = False,
    ):
        self.__wrapped__ = function
        self._symbolic = symbolic
        self._capture_as_const = capture_as_const
        self._sublinear_memory_config = sublinear_memory_config
        self._profiling = profiling
        self._profiler = None

        self._untraced = True
        self._tinfo = []  # handle -> TensorInfo
        self._seq = []
        self._pc = 0
        self._graph = None
        self._need_reset_nodes = None
        self._lazy_eval_graph = None
        self._lazy_eval_tensors = weakref.WeakSet()
        self._active_tensors = weakref.WeakSet()
        self._tensor_remaps = None
        self._inputs_to_restore = None
        self._arg_bindings = None
        self._kwarg_bindings = None
        self._output_bindings = None
        self._output_names = None

    def _new_handle(self):
        handle = len(self._tinfo)
        info = TensorInfo()
        self._tinfo.append(info)
        return handle, info

    def _apply_op(self, op, args):
        assert not self._untraced
        # check against trace
        if self._pc >= len(self._seq):
            raise TraceMismatchError("trace should end here, but more op observed")
        record = self._seq[self._pc]
        op_, ihandles, ohandles = record
        if op != op_:
            # FIXME: will be removed once better rng implementation is done
            if isinstance(op, OprAttr) and (
                op.type in ("UniformRNG", "GaussianRNG") and op.type == op_.type
            ):
                if op.param[8:] != op_.param[8:]:
                    raise TraceMismatchError("op different from last time")
            else:
                raise TraceMismatchError("op different from last time")
        if len(ihandles) != len(args):
            raise TraceMismatchError("op input size different from last time")

        for h, x in zip(ihandles, args):
            info = self._tinfo[h]
            if info.external:
                if (
                    x.__class__ is CompiledTensorProxy
                    and not self._tinfo[x._CompiledTensorProxy__handle].exported
                ):
                    raise TraceMismatchError(
                        "failed to capture: input was an external tensor "
                        "last time, got an internal tensor this time"
                    )
                if info.bound_data:
                    if x.__class__ is CompiledTensorProxy:
                        raise TraceMismatchError(
                            "const capture violated: was an external tensor "
                            "last time, got an internal tensor this time"
                        )
                    if x._handle != info.bound_data._handle:
                        if not np.array_equal(
                            x.numpy(), info.bound_data.numpy(), equal_nan=True
                        ):
                            raise TraceMismatchError(
                                "const capture violated: got "
                                "a different tensor this time"
                            )
                else:
                    if info.dtype != x.dtype:
                        raise TraceMismatchError(
                            "failed to capture: different dtype from last time"
                        )
                    if info.device != x.device:
                        raise TraceMismatchError(
                            "failed to capture: different device from last time"
                        )
                    info.data_setter.set_value(x._dev_tensor())
            else:
                if x.__class__ is not CompiledTensorProxy:
                    if x not in self._tensor_remaps:
                        raise TraceMismatchError(
                            "unexpected capture: trying to use an external tensor as "
                            "input, but that input was an internal tensor last time"
                        )
                    else:
                        x = self._tensor_remaps[x]
                if x._CompiledTensorProxy__handle != h:
                    raise TraceMismatchError(
                        "mis-wiring: input edge to a data flow "
                        "graph node is different from last time"
                    )

        self._pc += 1
        outputs = tuple([CompiledTensorProxy(h) for h in ohandles])
        self._active_tensors.update(outputs)
        return outputs

    def _record_op(self, op, inputs, outputs):
        if skip_tracing:
            for x in inputs:
                h = getattr(x, "_TraceMixin__handle", None)
                if h is not None:
                    self._tinfo[h].data_read = True
            return

        ihandles = []
        for x in inputs:
            h = getattr(x, "_TraceMixin__handle", None)
            if h is None or (not self._capture_as_const and self._tinfo[h].exported):
                h, info = self._new_handle()
                info.external = True
                info.device = x.device
                info.dtype = x.dtype
                info.shape = x.shape
                if self._capture_as_const:
                    info.bound_data = x

            ihandles.append(h)

        ohandles = []
        for x in outputs:
            h, info = self._new_handle()
            ohandles.append(h)
            info.external = False
            TraceMixin._TraceMixin__inject(x, h)

        self._seq.append((op, tuple(ihandles), tuple(ohandles)))
        self._active_tensors.update(outputs)

    def _record_const(self, op, outputs):
        pass

    @contextlib.contextmanager
    def _setup(self):
        global active_trace
        if active_trace:
            raise NotImplementedError("sorry, not implemented: nested trace")
        active_trace = self

        if self._untraced:
            apply.enable(apply_with_tracing)
            apply.enable(apply_const_with_tracing)
            if self._symbolic:
                apply.enable(apply_symbolic_mode)
                apply.enable(apply_const_symbolic_mode)
                self._lazy_eval_graph = G.Graph()
        else:
            apply.enable(apply_compiled_mode)
            if self._graph is None:
                self._compile()
            self._graph.execute()

        yield

        escaped_tensors = tuple(self._active_tensors)
        self._active_tensors.clear()

        if self._untraced:
            for x in escaped_tensors:
                info = self._tinfo[x._TraceMixin__handle]
                info.data_read = True
                x._TraceMixin__restore()
            if self._inputs_to_restore:
                for x in self._inputs_to_restore:
                    x._TraceMixin__restore()
            if self._symbolic:
                # eval lazy eval tensors
                lazy_eval_tensors = tuple(self._lazy_eval_tensors)
                if lazy_eval_tensors:
                    readers = [
                        G.OutputNode(x._LazyEvalTensor__varnode).outputs[0]
                        for x in lazy_eval_tensors
                    ]
                    self._apply_graph_options(self._lazy_eval_graph)
                    self._lazy_eval_graph.compile(*readers)
                    self._lazy_eval_graph()
                    for r, x in zip(readers, lazy_eval_tensors):
                        assign_raw_tensor(x, as_raw_tensor(r.op.get_value()))
                self._lazy_eval_graph = None
                self._lazy_eval_tensors = None
            self._untraced = False
        else:
            if self._pc != len(self._seq):
                raise TraceMismatchError("premature end")
            for x in escaped_tensors:
                assign_raw_tensor(x, as_raw_tensor(x._dev_tensor()))
            self._graph.wait()
            self._reset_exec_env()
            self._pc = 0

        self._tensor_remaps = None
        apply.disable(apply_with_tracing)
        apply.disable(apply_const_with_tracing)
        apply.disable(apply_symbolic_mode)
        apply.disable(apply_const_symbolic_mode)
        apply.disable(apply_compiled_mode)
        active_trace = None

    def _begin_excluded_region(self):
        if self._capture_as_const:
            raise RuntimeError(
                "exclude_from_trace cannot be used with capture_as_const"
            )
        if self._untraced:
            # conditionally reading a compiled tensor in excluded region
            # is permitted, so we have to assume every tensor might be read
            for x in self._active_tensors:
                info = self._tinfo[x._TraceMixin__handle]
                info.exported = True
                info.data_read = True

    def _apply_graph_options(self, graph):
        graph.options.seq_opt.enable_seq_comp_node_opt = False
        # sublinear
        if self._sublinear_memory_config is not None:
            graph.options.enable_sublinear_memory_opt = True
            sublinear_config = graph.options.sublinear_mem_config
            sublinear_config.lb_memory = self._sublinear_memory_config.lb_memory
            sublinear_config.genetic_nr_iter = (
                self._sublinear_memory_config.genetic_nr_iter
            )
            sublinear_config.genetic_pool_size = (
                self._sublinear_memory_config.genetic_pool_size
            )
            sublinear_config.thresh_nr_try = self._sublinear_memory_config.thresh_nr_try
            sublinear_config.num_worker = self._sublinear_memory_config.num_worker
        if self._profiling:
            self._profiler = GraphProfiler(graph)
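
    # Illustrative sketch: enabling sublinear memory checkpointing together
    # with tracing. Field names mirror the assignments above; the exact
    # constructor arguments of SublinearMemoryConfig are assumed here, check
    # its definition before relying on them:
    #
    #     config = SublinearMemoryConfig(genetic_nr_iter=50)
    #
    #     @trace(symbolic=True, sublinear_memory_config=config)
    #     def train_step(data):
    #         ...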

    def _compile(self):
        graph = self._graph = G.Graph()
        graph.options.no_force_inplace = True
        self._apply_graph_options(graph)
        # graph.options.graph_opt_level = 0
        need_reset_nodes = self._need_reset_nodes = []
        # links enforce ordering of I/O nodes
        links = ()
        readers = []

        if self._capture_as_const:
            for h in itertools.chain(self._arg_bindings, self._kwarg_bindings.values()):
                info = self._tinfo[h]
                opnode = info.data_setter = G.InputNode(
                    device=info.device, dtype=info.dtype, shape=info.shape, graph=graph
                )
                need_reset_nodes.append(opnode)
                info.varnode = opnode.outputs[0]
                links += opnode.outputs[1:]

        for op, ihandles, ohandles in self._seq:
            ivars = []
            for h in ihandles:
                info = self._tinfo[h]
                if not hasattr(info, "varnode"):
                    assert info.external
                    if info.bound_data:
                        info.varnode = graph.make_const(info.bound_data._dev_tensor())
                    else:
                        opnode = info.data_setter = G.InputNode(
                            *links,
                            device=info.device,
                            dtype=info.dtype,
                            shape=info.shape,
                            graph=graph,
                        )
                        need_reset_nodes.append(opnode)
                        info.varnode, *links = opnode.outputs

                ivars.append(info.varnode)
            ovars = apply(op, *ivars)
            assert len(ovars) == len(ohandles)
            for h, v in zip(ohandles, ovars):
                info = self._tinfo[h]
                info.varnode = v

                def add_reader(opnode):
                    nonlocal links
                    need_reset_nodes.append(opnode)
                    readers.append(opnode.outputs[0])
                    links = opnode.outputs

                if info.data_read:
                    # Shape can be obtained from data so doesn't need its own
                    # output node. On the other hand, value is read separately
                    # to leverage eager h2d copy
                    info.shape_read = False
                    opnode = info.data_reader = G.OutputNode(v, *links)
                    add_reader(opnode)
                if info.value_read:
                    opnode = info.value_reader = G.ValueOutputNode(v, *links)
                    add_reader(opnode)
                if info.shape_read:
                    opnode = info.shape_reader = G.AttrOutputNode(v, *links)
                    add_reader(opnode)

        graph.compile(*readers)

    def _reset_exec_env(self):
        for opnode in self._need_reset_nodes:
            opnode.reset()

    def _require_shape(self, handle):
        info = self._tinfo[handle]
        info.shape_read = True

    def _require_value(self, handle):
        info = self._tinfo[handle]
        info.value_read = True

    def _require_data(self, handle):
        info = self._tinfo[handle]
        info.data_read = True

    def __call__(self, *args, **kwargs):
        with self._setup():
            if self._capture_as_const:
                self._process_inputs(*args, **kwargs)
            outputs = self.__wrapped__(*args, **kwargs)
            if self._capture_as_const:
                self._process_outputs(outputs)
            return outputs

    def dump(self, file, *, arg_names=None, output_names=None):
        if not self._capture_as_const:
            raise ValueError(
                "you must specify capture_as_const=True at __init__ to use dump"
            )
        if self._untraced:
            raise RuntimeError("should run at least once before calling dump")
        if self._output_names and output_names:
            raise TypeError(
                "cannot specify output_names when output is already in dict format"
            )
        if output_names and not isinstance(output_names, collections.abc.Sequence):
            output_names = (output_names,)
        if output_names and len(output_names) != len(self._output_bindings):
            raise ValueError(
                "wrong number of output_names, should be {} values".format(
                    len(self._output_bindings)
                )
            )
        if arg_names and not isinstance(arg_names, collections.abc.Sequence):
            arg_names = (arg_names,)
        if arg_names and len(arg_names) != len(self._arg_bindings):
            raise ValueError(
                "wrong number of arg_names, should be {} values".format(
                    len(self._arg_bindings)
                )
            )
        output_names = output_names or self._output_names

        h2v = {}
        graph = G.Graph()

        for i, h in enumerate(self._arg_bindings):
            info = self._tinfo[h]
            h2v[h] = graph.make_h2d(
                dtype=info.dtype,
                device=info.device,
                shape=info.shape,
                name=arg_names[i] if arg_names else None,
            )
        for k, h in self._kwarg_bindings.items():
            info = self._tinfo[h]
            h2v[h] = graph.make_h2d(
                dtype=info.dtype, device=info.device, shape=info.shape, name=k
            )

        for op, ihandles, ohandles in self._seq:
            ivars = []
            for h in ihandles:
                info = self._tinfo[h]
                if h not in h2v:
                    assert info.external
                    assert info.bound_data
                    h2v[h] = graph.make_const(info.bound_data._dev_tensor())
                ivars.append(h2v[h])
            ovars = apply(op, *ivars)
            assert len(ovars) == len(ohandles)
            h2v.update(zip(ohandles, ovars))

        dest_vars = []
        for i, h in enumerate(self._output_bindings):
            v = h2v[h]
            if output_names:
                v.name = output_names[i]
            dest_vars.append(v)

        if isinstance(file, str):
            file = open(file, "wb")
        file.write(G.dump(*dest_vars))
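
    # Illustrative sketch of serializing a trace (hypothetical model and
    # input; requires capture_as_const=True and at least one prior call, as
    # enforced above):
    #
    #     f = trace(model, symbolic=True, capture_as_const=True)
    #     f(sample_input)  # first run records the graph and binds the inputs
    #     f.dump("model.mge", arg_names=["data"], output_names=["pred"])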

    def _process_inputs(self, *args, **kwargs):
        if self._untraced:
            self._inputs_to_restore = []

            def record_input(x):
                if x is None:
                    return
                h, info = self._new_handle()
                info.external = False
                info.device = x.device
                info.dtype = x.dtype
                info.shape = x.shape
                TraceMixin._TraceMixin__inject(x, h)
                self._inputs_to_restore.append(x)
                return h

            self._arg_bindings = []
            for i, x in enumerate(args):
                x = find_raw_tensor(x)
                if x is None:
                    raise TypeError(
                        "positional arguments should all be tensor "
                        "but args[%d] cannot be recognized as one" % i
                    )
                self._arg_bindings.append(record_input(x))

            self._kwarg_bindings = {}
            for k, x in kwargs.items():
                x = find_raw_tensor(x)
                if x is not None:
                    self._kwarg_bindings[k] = record_input(x)
        else:
            if len(args) != len(self._arg_bindings):
                raise TraceMismatchError("positional argument length mismatch")

            self._tensor_remaps = {}

            for i, (h, x) in enumerate(zip(self._arg_bindings, args)):
                x = find_raw_tensor(x)
                if x is None:
                    raise TypeError(
                        "positional arguments should all be tensor "
                        "but args[%d] cannot be recognized as one" % i
                    )
                info = self._tinfo[h]
                if x.dtype != info.dtype:
                    raise TypeError("args[%d].dtype different from last time" % i)
                if x.device != info.device:
                    raise TypeError("args[%d].device different from last time" % i)
                info.data_setter.set_value(x._dev_tensor())
                self._tensor_remaps[x] = CompiledTensorProxy(h)

            kwargs_tensors = {}
            for k, x in kwargs.items():
                x = find_raw_tensor(x)
                if x is not None:
                    kwargs_tensors[k] = x
            if set(kwargs_tensors) != set(self._kwarg_bindings):
                too_many = set(kwargs_tensors) - set(self._kwarg_bindings)
                too_few = set(self._kwarg_bindings) - set(kwargs_tensors)
                if too_many:
                    raise TraceMismatchError(
                        "keyword arguments found to be tensor this time "
                        "but were non-tensor previously: %s" % " ".join(too_many)
                    )
                if too_few:
                    raise TraceMismatchError(
                        "keyword arguments found to be non-tensor this time "
                        "but were tensor previously: %s" % " ".join(too_few)
                    )
            for k, h in self._kwarg_bindings.items():
                x = kwargs_tensors[k]
                info = self._tinfo[h]
                if x.dtype != info.dtype:
                    raise TypeError("kwargs[%s].dtype different from last time" % k)
                if x.device != info.device:
                    raise TypeError("kwargs[%s].device different from last time" % k)
                info.data_setter.set_value(x._dev_tensor())
                self._tensor_remaps[x] = CompiledTensorProxy(h)

    def _process_outputs(self, outputs):
        output_names = None
        if isinstance(outputs, collections.abc.Mapping):
            output_names, outputs = zip(*sorted(outputs.items()))
        elif not isinstance(outputs, collections.abc.Sequence):
            outputs = (outputs,)

        if not self._untraced:
            if output_names != self._output_names:
                too_many = set(output_names) - set(self._output_names)
                too_few = set(self._output_names) - set(output_names)
                if too_many:
                    raise TraceMismatchError(
                        "output has more keys than last time: %s" % " ".join(too_many)
                    )
                if too_few:
                    raise TraceMismatchError(
                        "output has fewer keys than last time: %s" % " ".join(too_few)
                    )
            if len(outputs) != len(self._output_bindings):
                raise TraceMismatchError("output size differs from last time")
        else:
            self._output_names = output_names
            self._output_bindings = []

        for i, x in enumerate(outputs):
            x = find_raw_tensor(x)
            if x is None:
                raise TypeError("every item of return value should be tensor")
            if self._untraced:
                if not isinstance(x, TraceMixin):
                    raise RuntimeError("output is not computed from inputs")
                h = x._TraceMixin__handle
                self._output_bindings.append(h)
            else:
                if not isinstance(x, CompiledTensorProxy):
                    raise RuntimeError("output is not computed from inputs")
                h = x._CompiledTensorProxy__handle
                if h != self._output_bindings[i]:
                    raise TraceMismatchError(
                        "retval[%s] is a different tensor than last time"
                        % (output_names and output_names[i] or i)
                    )

    def get_profile(self):
        """
        Get profiling result for compiled trace.

        :return: a json compatible object.
        """
        if not self._profiler:
            raise RuntimeError("trace is not set with profiling=True")
        return json.loads(self._profiler.get())
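

# Illustrative profiling sketch (hypothetical function and input): the
# profiler is attached in _apply_graph_options, so results only exist once
# the trace has been compiled and executed.
#
#     f = trace(step, profiling=True)
#     f(x)  # first call: record
#     f(x)  # second call: compile and run the graph
#     with open("profile.json", "w") as fp:
#         json.dump(f.get_profile(), fp)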


class CompiledTensorProxy(RawTensor):
    """
    Duck-typed RawTensor
    """

    def __init__(self, handle):
        self.__handle = handle
        self.__info = active_trace._tinfo[handle]
        self.__shape = None
        self.__data = None
        self.__value = None

    @property
    def dtype(self):
        return self.__info.varnode.dtype

    @property
    def device(self):
        return self.__info.varnode.device

    @property
    def shape(self):
        if self.__shape is None:
            if self.__info.shape_read:
                self.__shape = self.__info.shape_reader.get_value().shape
            elif self.__info.data_read:
                self.__shape = self._dev_tensor().shape
            else:
                raise TraceMismatchError("shape of this tensor is not read in trace")
        return self.__shape

    def numpy(self):
        if self.__value is None:
            if self.__info.value_read:
                self.__value = self.__info.value_reader.get_value()
            elif self.__info.data_read:
                self.__value = self._dev_tensor().numpy()
            else:
                raise TraceMismatchError("value of this tensor is not read in trace")
        return self.__value

    def _dev_tensor(self):
        if self.__data is None:
            if not self.__info.data_read:
                raise TraceMismatchError("raw data of this tensor is not read in trace")
            self.__data = self.__info.data_reader.get_value()
        return self.__data

    def __del__(self):
        if self.__info.shape_read and self.__shape is not None:
            self.__info.shape_reader.drop_value()
        if self.__info.value_read and self.__value is not None:
            self.__info.value_reader.drop_value()
        if self.__info.data_read and self.__data is not None:
            self.__info.data_reader.drop_value()


class LazyEvalTensor(RawTensor):
    def __init__(self, varnode):
        self.__varnode = varnode

    @property
    def dtype(self):
        return self.__varnode.dtype

    @property
    def device(self):
        return self.__varnode.device

    @property
    def shape(self):
        return self.__varnode.shape

    def numpy(self):
        return self.__varnode.value

    def _dev_tensor(self):
        raise RuntimeError("cannot access data during symbolic tracing")


class TraceMixin:
    __subclass_cache = {}

    def __inject(self, handle):
        cache = __class__.__subclass_cache
        cls = self.__class__
        subcls = cache.get(cls)
        if subcls is None:
            subcls = cache[cls] = type("Traced" + cls.__name__, (__class__, cls), {})
        self.__class__ = subcls
        self.__handle = handle
        self.__cls = cls
        return self

    def __restore(self):
        cls = self.__cls
        del self.__handle
        del self.__cls
        self.__class__ = cls
        return self

    @property
    def shape(self):
        if not skip_tracing:
            active_trace._require_shape(self.__handle)
        return super().shape

    def numpy(self):
        if not skip_tracing:
            active_trace._require_value(self.__handle)
        return super().numpy()

    def _dev_tensor(self):
        if not skip_tracing:
            active_trace._require_data(self.__handle)
        return super()._dev_tensor()
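

# TraceMixin.__inject rewrites an instance's __class__ in place to a cached
# dynamic subclass named "Traced<OriginalName>", so that shape / numpy /
# _dev_tensor accesses go through the mixin and mark the corresponding
# TensorInfo as read; __restore undoes the swap once tracing finishes. The
# two specializations below are pre-built for the known tensor types.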


class TracedRawTensor(TraceMixin, RawTensor):
    pass


class TracedLazyTensor(TraceMixin, LazyEvalTensor):
    pass


def assign_raw_tensor(lhs, rhs):
    handle = rhs._handle
    rhs.__dict__.clear()
    lhs.__dict__.clear()
    lhs.__class__ = RawTensor
    lhs.__init__(handle)


# this hook turns RawTensor into LazyEvalTensor
@apply.register()
def apply_symbolic_mode(op: OpDef, *args: RawTensor):
    graph = active_trace._lazy_eval_graph
    ivars = [
        getattr(x, "_LazyEvalTensor__varnode", None)
        or graph.make_const(x._dev_tensor())
        for x in args
    ]
    ovars = apply(op, *ivars)
    outputs = [LazyEvalTensor(v) for v in ovars]
    active_trace._lazy_eval_tensors.update(outputs)
    return outputs


apply.disable(apply_symbolic_mode)


@apply.register()
def apply_const_symbolic_mode(op: Const, *args: RawTensor):
    graph = active_trace._lazy_eval_graph
    ret = LazyEvalTensor(graph.make_const(op.value, dtype=op.dtype, device=op.device))
    active_trace._lazy_eval_tensors.add(ret)
    return (ret,)


apply.disable(apply_const_symbolic_mode)


@apply.register()
def apply_compiled_mode(op: OpDef, *args: RawTensor):
    if skip_tracing:
        args = [
            as_raw_tensor(x._dev_tensor()) if x.__class__ is CompiledTensorProxy else x
            for x in args
        ]
        return apply.super(op, *args)
    return active_trace._apply_op(op, args)


apply.disable(apply_compiled_mode)


# this hook injects TraceMixin
@apply.register()
def apply_with_tracing(op: OpDef, *args: RawTensor):
    outputs = apply.super(op, *args)
    active_trace._record_op(op, args, outputs)
    return outputs


apply.disable(apply_with_tracing)


@apply.register()
def apply_const_with_tracing(op: Const, *args: RawTensor):
    outputs = apply.super(op, *args)
    active_trace._record_const(op, outputs)
    return outputs


apply.disable(apply_const_with_tracing)
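

# The hooks above are registered on the `apply` multimethod but start out
# disabled; trace._setup enables whichever subset matches the current phase
# (eager recording, symbolic recording, or compiled replay) and disables
# them again on exit, restoring plain eager execution.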


class BrokenRawTensor(RawTensor):
    def __getattribute__(self, _):
        raise RuntimeError("broken due to misuse of tracing")

    def __setattr__(self, *_):
        raise RuntimeError("broken due to misuse of tracing")


@functools.singledispatch
def find_raw_tensor(x):
    return None


@find_raw_tensor.register(RawTensor)
def _(x):
    return x


@find_raw_tensor.register(TensorWrapperBase)
def _(x):
    x = getattr(x, "__wrapped__", None)
    if x is not None:
        return find_raw_tensor(x)


@find_raw_tensor.register(Tensor)
def _(x):
    x = getattr(x, "_data", None)
    if x is not None:
        return find_raw_tensor(x)
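

# find_raw_tensor unwraps user-facing tensor types down to the underlying
# RawTensor via singledispatch: a TensorWrapperBase is unwrapped through
# __wrapped__, a Tensor through _data, and anything unrecognized yields
# None, which callers such as _process_inputs treat as "not a tensor".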
