You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

tracing.py 16 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515
  1. import contextlib
  2. import functools
  3. import typing
  4. import weakref
  5. from ..core.ops.special import Const
  6. from ..core.tensor import megbrain_graph as G
  7. from ..core.tensor.core import OpBase, apply
  8. from ..core.tensor.raw_tensor import OpDef, RawTensor, as_raw_tensor
class TraceMismatchError(RuntimeError):
    """Raised when execution under a compiled trace diverges from the recorded op sequence."""

    pass
# The trace instance currently installed by trace._setup(); None when no
# trace is running.
active_trace = None
# Set True by exclude_from_trace() so the apply hooks bypass recording.
skip_tracing = False
@contextlib.contextmanager
def exclude_from_trace():
    """Context manager that suspends op recording inside an active trace.

    Ops applied inside the ``with`` block execute normally but are not
    recorded into the trace. Nested/re-entrant use is a no-op.
    """
    global skip_tracing
    if skip_tracing:
        # already inside an excluded region: nothing to do
        yield
        return
    try:
        skip_tracing = True
        if active_trace is not None:
            # let the trace mark currently-live tensors as potentially read
            active_trace._begin_excluded_region()
        yield
    finally:
        skip_tracing = False
  26. class TensorInfo:
  27. __slots__ = (
  28. # collected attributes
  29. "external",
  30. "exported",
  31. "data_read",
  32. "shape_read",
  33. "value_read",
  34. "device",
  35. "dtype",
  36. "bound_data",
  37. # resources for execution
  38. "varnode",
  39. "data_setter",
  40. "shape_reader",
  41. "value_reader",
  42. "data_reader",
  43. )
  44. def __init__(self):
  45. self.exported = None
  46. self.data_read = None
  47. self.shape_read = None
  48. self.value_read = None
  49. self.bound_data = None
  50. self.data_setter = None
  51. self.shape_reader = None
  52. self.value_reader = None
  53. self.data_reader = None
class trace:
    """Trace a function call as a static op sequence and replay it.

    The first call runs the wrapped function while recording every applied
    op (building a lazy-evaluation graph instead when ``symbolic`` is True).
    Subsequent calls replay the recorded sequence through a compiled graph
    and raise :class:`TraceMismatchError` when execution diverges from the
    recording.

    :param function: the callable to be traced.
    :param symbolic: if True, the first call is evaluated lazily through a
        megbrain graph instead of eagerly while recording.
    :param capture_as_const: if True, external input tensors are captured
        as graph constants instead of runtime-fed inputs.
    """

    def __new__(cls, *args, **kwargs):
        # Support both ``trace(fn, ...)`` and the decorator-factory form
        # ``@trace(symbolic=True)``: with no positional args, return a
        # partial that waits for the function.
        if not args:
            return functools.partial(cls, **kwargs)
        self = super().__new__(cls)
        # NOTE(review): __init__ also runs automatically after __new__
        # returns this instance, so it executes twice; it only (re)assigns
        # attributes — confirm the double call is harmless.
        self.__init__(*args, **kwargs)
        return self

    def __init__(self, function, symbolic=False, capture_as_const=False):
        self.__wrapped__ = function
        self._symbolic = symbolic
        self._capture_as_const = capture_as_const
        self._capture_static_shape = False
        self._untraced = True  # True until the first (recording) call completes
        self._tinfo = []  # handle -> TensorInfo
        self._seq = []  # recorded ops: (op, input handles, output handles)
        self._pc = 0  # replay position within self._seq
        self._graph = None  # compiled megbrain graph, built lazily by _compile()
        self._need_reset_nodes = None  # I/O nodes to reset between executions
        self._lazy_eval_graph = None  # graph used during symbolic recording
        self._lazy_eval_tensors = weakref.WeakSet()
        self._active_tensors = weakref.WeakSet()  # tensors produced during the current call

    def _new_handle(self):
        # Allocate the next integer handle together with its bookkeeping record.
        handle = len(self._tinfo)
        info = TensorInfo()
        self._tinfo.append(info)
        return handle, info

    def _apply_op(self, op, args):
        """Replay one op against the recorded sequence (compiled mode only)."""
        assert not self._untraced
        # check against trace
        if self._pc >= len(self._seq):
            raise TraceMismatchError("trace should end here, but more op observed")
        record = self._seq[self._pc]
        op_, ihandles, ohandles = record
        if op != op_:
            raise TraceMismatchError("op different from last time")
        if len(ihandles) != len(args):
            raise TraceMismatchError("op input size different from last time")
        for h, x in zip(ihandles, args):
            info = self._tinfo[h]
            if info.external:
                if (
                    x.__class__ is CompiledTensorProxy
                    and not self._tinfo[x._CompiledTensorProxy__handle].exported
                ):
                    raise TraceMismatchError(
                        "failed to capture: input was an external tensor "
                        "last time, got an internal tensor this time"
                    )
                if info.bound_data:
                    # captured as const: replay must feed the very same tensor
                    if x.__class__ is CompiledTensorProxy:
                        raise TraceMismatchError(
                            "const capture violated: was an external tensor "
                            "last time, got an internal tensor this time"
                        )
                    if x._handle != info.bound_data._handle:
                        raise TraceMismatchError(
                            "const capture violated: got "
                            "a different tensor this time"
                        )
                else:
                    # runtime-fed input: metadata must match, then feed data
                    if info.dtype != x.dtype:
                        raise TraceMismatchError(
                            "failed to capture: different dtype from last time"
                        )
                    if info.device != x.device:
                        raise TraceMismatchError(
                            "failed to capture: different device from last time"
                        )
                    info.data_setter.set_value(x._dev_tensor())
            else:
                # internal edge: must be wired to the exact same producer
                if x.__class__ is not CompiledTensorProxy:
                    raise TraceMismatchError(
                        "unexpected capture: trying to use an external tensor as input, "
                        "but that input was an internal tensor last time"
                    )
                if x._CompiledTensorProxy__handle != h:
                    raise TraceMismatchError(
                        "mis-wiring: input edge to an data flow "
                        "graph node is different from last time"
                    )
        self._pc += 1
        outputs = tuple([CompiledTensorProxy(h) for h in ohandles])
        self._active_tensors.update(outputs)
        return outputs

    def _record_op(self, op, inputs, outputs):
        """Record one eagerly-executed op into the trace (first call only)."""
        if skip_tracing:
            # inside exclude_from_trace(): don't record, but any traced input
            # consumed here must have its device data available on replay
            for x in inputs:
                h = getattr(x, "_TraceMixin__handle", None)
                if h is not None:
                    self._tinfo[h].data_read = True
            return
        ihandles = []
        for x in inputs:
            h = getattr(x, "_TraceMixin__handle", None)
            if h is None or (not self._capture_as_const and self._tinfo[h].exported):
                # untraced tensor (or one that already escaped the trace):
                # register it as an external input
                h, info = self._new_handle()
                info.external = True
                info.device = x.device
                info.dtype = x.dtype
                if self._capture_as_const:
                    info.bound_data = x
            ihandles.append(h)
        ohandles = []
        for x in outputs:
            h, info = self._new_handle()
            ohandles.append(h)
            info.external = False
            # splice TraceMixin into the output so later reads are observed
            TraceMixin._TraceMixin__inject(x, h)
        self._seq.append((op, tuple(ihandles), tuple(ohandles)))
        self._active_tensors.update(outputs)

    @contextlib.contextmanager
    def _setup(self):
        """Install this trace globally, run the wrapped call, then finalize."""
        global active_trace
        if active_trace:
            raise NotImplementedError("sorry, not implemented: nested trace")
        active_trace = self
        if self._untraced:
            apply.enable(apply_with_tracing)
            if self._symbolic:
                apply.enable(apply_symbolic_mode)
                self._lazy_eval_graph = G.Graph()
        else:
            apply.enable(apply_compiled_mode)
            if self._graph is None:
                self._compile()
            self._graph.execute()
        # NOTE(review): no try/finally around this yield — if the traced
        # function raises, the apply hooks stay enabled and active_trace is
        # never reset; confirm whether that is acceptable.
        yield
        escaped_tensors = tuple(self._active_tensors)
        self._active_tensors.clear()
        if self._untraced:
            # tensors still alive after the call may be read later: force
            # their data to become outputs of the (future) compiled graph
            for x in escaped_tensors:
                info = self._tinfo[x._TraceMixin__handle]
                info.data_read = True
                x._TraceMixin__restore()
            if self._symbolic:
                # eval lazy eval tensors
                lazy_eval_tensors = tuple(self._lazy_eval_tensors)
                if lazy_eval_tensors:
                    readers = [
                        G.OutputNode(x._LazyEvalTensor__varnode).outputs[0]
                        for x in lazy_eval_tensors
                    ]
                    self._lazy_eval_graph.compile(*readers)
                    self._lazy_eval_graph()
                    for r, x in zip(readers, lazy_eval_tensors):
                        # swap each lazy tensor for its computed value in place
                        assign_raw_tensor(x, as_raw_tensor(r.op.get_value()))
                self._lazy_eval_graph = None
                self._lazy_eval_tensors = None
            self._untraced = False
        else:
            # compiled replay: the whole recorded sequence must be consumed
            if self._pc != len(self._seq):
                raise TraceMismatchError("premature end")
            for x in escaped_tensors:
                # materialize escaping proxies into plain raw tensors
                assign_raw_tensor(x, as_raw_tensor(x._dev_tensor()))
            self._graph.wait()
            self._reset_exec_env()
            self._pc = 0
        apply.disable(apply_with_tracing)
        apply.disable(apply_symbolic_mode)
        apply.disable(apply_compiled_mode)
        active_trace = None

    def _begin_excluded_region(self):
        if self._untraced:
            # conditionally reading a compiled tensor in excluded region
            # is permitted, so we have to assume every tensor might be read
            for x in self._active_tensors:
                info = self._tinfo[x._TraceMixin__handle]
                info.exported = True
                info.data_read = True

    def _compile(self):
        """Build the megbrain graph that replays the recorded op sequence."""
        graph = self._graph = G.Graph()
        graph.options.no_force_inplace = True
        # graph.options.graph_opt_level = 0
        need_reset_nodes = self._need_reset_nodes = []
        # links enforce ordering of I/O nodes
        links = ()
        for op, ihandles, ohandles in self._seq:
            ivars = []
            # NOTE(review): ``readers`` restarts for every op, so the final
            # graph.compile() below only receives the readers created for the
            # last recorded op; earlier I/O nodes are chained in through
            # ``links`` — confirm this is the intended behavior.
            readers = []
            for h in ihandles:
                info = self._tinfo[h]
                if not hasattr(info, "varnode"):
                    # first use of this handle: materialize its graph input
                    assert info.external
                    if info.bound_data:
                        info.varnode = graph.make_const(info.bound_data._dev_tensor())
                    else:
                        opnode = info.data_setter = G.InputNode(
                            *links, device=info.device, dtype=info.dtype, graph=graph
                        )
                        need_reset_nodes.append(opnode)
                        info.varnode, *links = opnode.outputs
                ivars.append(info.varnode)
            ovars = apply(op, *ivars)
            assert len(ovars) == len(ohandles)
            for h, v in zip(ohandles, ovars):
                info = self._tinfo[h]
                info.varnode = v

                def add_reader(opnode):
                    # register an output node: needs reset between runs,
                    # contributes a reader var, and extends the link chain
                    nonlocal links
                    need_reset_nodes.append(opnode)
                    readers.append(opnode.outputs[0])
                    links = opnode.outputs

                if info.data_read:
                    # Shape can be obtained from data so doesn't need its own
                    # output node. On the other hand, value is read separately
                    # to leverage eager h2d copy
                    info.shape_read = False
                    opnode = info.data_reader = G.OutputNode(v, *links)
                    add_reader(opnode)
                if info.value_read:
                    opnode = info.value_reader = G.ValueOutputNode(v, *links)
                    add_reader(opnode)
                if info.shape_read:
                    opnode = info.shape_reader = G.AttrOutputNode(v, *links)
                    add_reader(opnode)
        graph.compile(*readers)

    def _reset_exec_env(self):
        # reset feedable I/O nodes so the compiled graph can run again
        for opnode in self._need_reset_nodes:
            opnode.reset()

    def _require_shape(self, handle):
        # mark that this tensor's shape must be retrievable on replay
        info = self._tinfo[handle]
        info.shape_read = True

    def _require_value(self, handle):
        # mark that this tensor's host value must be retrievable on replay
        info = self._tinfo[handle]
        info.value_read = True

    def _require_data(self, handle):
        # mark that this tensor's device data must be retrievable on replay
        info = self._tinfo[handle]
        info.data_read = True

    def __call__(self, *args, **kwargs):
        with self._setup():
            return self.__wrapped__(*args, **kwargs)
class CompiledTensorProxy(RawTensor):
    """
    Duck-typed RawTensor

    Stands in for an internal tensor while replaying a compiled trace.
    Shape, host value and device data are served by the graph's output
    readers, and are only available if the corresponding kind of read was
    observed during the recording pass.
    """

    def __init__(self, handle):
        self.__handle = handle
        self.__info = active_trace._tinfo[handle]
        # lazily-filled caches for the three kinds of reads
        self.__shape = None
        self.__data = None
        self.__value = None

    @property
    def dtype(self):
        return self.__info.varnode.dtype

    @property
    def device(self):
        return self.__info.varnode.device

    @property
    def shape(self):
        if self.__shape is None:
            if self.__info.shape_read:
                self.__shape = self.__info.shape_reader.get_value().shape
            elif self.__info.data_read:
                # no dedicated shape reader: derive the shape from full data
                self.__shape = self._dev_tensor().shape
            else:
                raise TraceMismatchError("shape of this tensor is not read in trace")
        return self.__shape

    def numpy(self):
        if self.__value is None:
            if self.__info.value_read:
                self.__value = self.__info.value_reader.get_value()
            elif self.__info.data_read:
                # fall back to converting the device data
                self.__value = self._dev_tensor().numpy()
            else:
                raise TraceMismatchError("value of this tensor is not read in trace")
        return self.__value

    def _dev_tensor(self):
        if self.__data is None:
            if not self.__info.data_read:
                raise TraceMismatchError("raw data of this tensor is not read in trace")
            self.__data = self.__info.data_reader.get_value()
        return self.__data

    def __del__(self):
        # NOTE(review): readers are told to drop their value only when the
        # local cache was already fetched (``is not None``) — verify whether
        # the intent was instead to drop values that were never fetched.
        if self.__info.shape_read and self.__shape is not None:
            self.__info.shape_reader.drop_value()
        if self.__info.value_read and self.__value is not None:
            self.__info.value_reader.drop_value()
        if self.__info.data_read and self.__data is not None:
            self.__info.data_reader.drop_value()
  333. class LazyEvalTensor(RawTensor):
  334. def __init__(self, varnode):
  335. self.__varnode = varnode
  336. @property
  337. def dtype(self):
  338. return self.__varnode.dtype
  339. @property
  340. def device(self):
  341. return self.__varnode.device
  342. @property
  343. def shape(self):
  344. return self.__varnode.shape
  345. def numpy(self):
  346. return self.__varnode.value
  347. def _dev_tensor(self):
  348. raise RuntimeError("cannot access data during symbolic tracing")
class TraceMixin:
    """Mixin dynamically spliced into a tensor's class while it is traced.

    Intercepts shape/value/data reads so the active trace can mark the
    tensor's handle as "read", guaranteeing the data will be available when
    the trace is replayed.
    """

    # maps original class -> generated Traced<Cls> subclass
    __subclass_cache = {}

    def __inject(self, handle):
        # Rewrite self's class to ``Traced<Cls>`` (cached per original
        # class) and remember the handle plus the original class so
        # __restore can undo the swap.
        cache = __class__.__subclass_cache
        cls = self.__class__
        subcls = cache.get(cls)
        if subcls is None:
            subcls = cache[cls] = type("Traced" + cls.__name__, (__class__, cls), {})
        self.__class__ = subcls
        self.__handle = handle
        self.__cls = cls
        return self

    def __restore(self):
        # Undo __inject: restore the original class and drop bookkeeping.
        cls = self.__cls
        del self.__handle
        del self.__cls
        self.__class__ = cls
        return self

    @property
    def shape(self):
        if not skip_tracing:
            active_trace._require_shape(self.__handle)
        return super().shape

    def numpy(self):
        if not skip_tracing:
            active_trace._require_value(self.__handle)
        return super().numpy()

    def _dev_tensor(self):
        if not skip_tracing:
            active_trace._require_data(self.__handle)
        return super()._dev_tensor()
class TracedRawTensor(TraceMixin, RawTensor):
    """Statically-declared traced RawTensor; TraceMixin.__inject normally
    builds the equivalent subclass dynamically — confirm this is still used."""

    pass
class TracedLazyTensor(TraceMixin, LazyEvalTensor):
    """Statically-declared traced LazyEvalTensor; TraceMixin.__inject normally
    builds the equivalent subclass dynamically — confirm this is still used."""

    pass
def assign_raw_tensor(lhs, rhs):
    # Re-initialize ``lhs`` in place as a plain RawTensor that takes over
    # ``rhs``'s underlying handle; both instance dicts are cleared first so
    # the handle ends up with a single owner and no stale state survives.
    handle = rhs._handle
    rhs.__dict__.clear()
    lhs.__dict__.clear()
    lhs.__class__ = RawTensor
    lhs.__init__(handle)
  390. # this hook turns RawTensor into LazyEvalTensor
  391. @apply.register()
  392. def apply_symbolic_mode(op: OpDef, *args: RawTensor):
  393. graph = active_trace._lazy_eval_graph
  394. ivars = [
  395. getattr(x, "_LazyEvalTensor__varnode", None)
  396. or graph.make_const(x._dev_tensor())
  397. for x in args
  398. ]
  399. ovars = apply(op, *ivars)
  400. outputs = [LazyEvalTensor(v) for v in ovars]
  401. active_trace._lazy_eval_tensors.update(outputs)
  402. return outputs
  403. apply.disable(apply_symbolic_mode)
  404. @apply.register()
  405. def apply_compiled_mode(op: OpDef, *args: RawTensor):
  406. if skip_tracing:
  407. args = [
  408. as_raw_tensor(x._dev_tensor()) if x.__class__ is CompiledTensorProxy else x
  409. for x in args
  410. ]
  411. return apply.super(op, *args)
  412. return active_trace._apply_op(op, args)
  413. apply.disable(apply_compiled_mode)
# this hook injects TraceMixin
@apply.register()
def apply_with_tracing(op: OpDef, *args: RawTensor):
    # First-call hook: execute the op normally, then record it; recording
    # also injects TraceMixin into the outputs via _record_op.
    outputs = apply.super(op, *args)
    active_trace._record_op(op, args, outputs)
    return outputs


apply.disable(apply_with_tracing)
# @apply.register()
# def _(op: Const, *args: RawTensor):
#     return active_trace._apply_const(op, args)


class BrokenRawTensor(RawTensor):
    """Poisoned tensor class: every attribute access or assignment raises.

    Appears unused within this file — presumably intended to invalidate
    tensors after misuse of tracing; confirm before removing.
    """

    def __getattribute__(self, _):
        raise RuntimeError("broken due to misuse of tracing")

    def __setattr__(self, *_):
        raise RuntimeError("broken due to misuse of tracing")

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台