You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

callback.i 6.1 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215
  1. /*
  2. * $File: callback.i
  3. *
  4. * This file is part of MegBrain, a deep learning framework developed by Megvii.
  5. *
  6. * $Copyright: Copyright (c) 2014-2017 Megvii Inc. All rights reserved.
  7. */
  8. %feature("autodoc",
  9. """It is used to be passed as arguments to callbacks (used in
  10. :meth:`.CompGraph.compile`, :func:`.callback_injector`, and
  11. :meth:`.CraniotomeBase.execute`). Object of this type could also be directly
  12. passed to :meth:`.SharedND.set_value`, to bypass some host and device
  13. communication. Note that the underlying buffer may be reused after the callback
  14. returns, so reference to this object should not be passed outside of the
  15. callback, and :meth:`get_value` should be called immediately if the numerical
  16. value is needed.""")
  17. CompGraphCallbackValueProxy;
  18. class CompGraphCallbackValueProxy {
  19. public:
  20. PyObject* _get_npyarr();
  21. PyObject* _get_dtype();
  22. std::vector<size_t> _get_shape();
  23. uintptr_t _pubapi_dev_tensor_ptr(int version);
  24. CompNode _get_comp_node();
  25. %pythoncode{
  26. @property
  27. def shape(self):
  28. """get shape of the var
  29. :type: tuple of int
  30. """
  31. return tuple(map(int, self._get_shape()))
  32. @property
  33. def comp_node(self):
  34. """get comp node of the var
  35. :type: :class:`.CompNode`
  36. """
  37. return self._get_comp_node()
  38. @property
  39. def dtype(self):
  40. """get data type of the var
  41. :type: :class:`.numpy.dtype`
  42. """
  43. return self._get_dtype()
  44. def get_value(self, *, borrow_mem=False):
  45. """get value as numpy array
  46. :param borrow_mem: whether to forward internal buffer with
  47. zero-copy; if True, the content in returned buffer would be
  48. modified directly by asynchronous graph execution.
  49. """
  50. ret = self._get_npyarr()
  51. if not borrow_mem:
  52. ret = ret.copy()
  53. return ret
  54. @property
  55. def dev_ptr(self):
  56. """this method is DEPRECATED; use :meth:`pubapi_dev_tensor_ptr`
  57. instead"""
  58. return self._pubapi_dev_tensor_ptr(0)
  59. @property
  60. def pubapi_dev_tensor_ptr(self):
  61. """get a pointer to the corresponding mgb::pubapi::DeviceTensor object
  62. :rtype: int
  63. :return: the address as an integer
  64. """
  65. return self._pubapi_dev_tensor_ptr(1)
  66. }
  67. };
  68. %template(_VectorCompGraphCallbackValueProxy)
  69. std::vector<CompGraphCallbackValueProxy>;
  70. %feature("director") _CompGraphCallback;
  71. class _CompGraphCallback {
  72. public:
  73. _CompGraphCallback();
  74. void set_eager_copy(bool flag);
  75. virtual ~_CompGraphCallback();
  76. virtual void call(std::vector<CompGraphCallbackValueProxy> &value) = 0;
  77. };
  78. %feature("director") _SplitPartCallback;
  79. class _SplitPartCallback {
  80. public:
  81. _SplitPartCallback();
  82. virtual ~_SplitPartCallback();
  83. virtual std::vector<size_t> call(size_t tot_size) = 0;
  84. };
  85. %feature("director") _SetGradCallback;
  86. class _SetGradCallback {
  87. public:
  88. _SetGradCallback();
  89. virtual ~_SetGradCallback();
  90. virtual SymbolVar call(CompGraph &graph) = 0;
  91. virtual bool empty() = 0;
  92. };
  93. %feature("director") _TimeoutCallback;
  94. class _TimeoutCallback {
  95. public:
  96. _TimeoutCallback();
  97. virtual ~_TimeoutCallback();
  98. virtual bool call() = 0;
  99. };
  100. %pythoncode{
  101. import collections
  102. import inspect
  103. from .mgb_helper import callback_lazycopy
  104. class _CompGraphCallbackPyWrapper(_CompGraphCallback):
  105. """wraps around a callable to be used as comp graph callback"""
  106. def __init__(self, f):
  107. super().__init__()
  108. if isinstance(f, callback_lazycopy):
  109. f = f.func
  110. self.set_eager_copy(False)
  111. else:
  112. self.set_eager_copy(True)
  113. assert isinstance(f, collections.Callable)
  114. self._func = f
  115. self.__disown__()
  116. def call(self, value):
  117. if value.size() == 1:
  118. self._func(value[0])
  119. else:
  120. self._func(value)
  121. _CompGraphCallbackPyWrapperNoEager = lambda f: (
  122. _CompGraphCallbackPyWrapper(callback_lazycopy(f)))
  123. class _SplitPartCallbackPyWrapper(_SplitPartCallback):
  124. def __init__(self, f):
  125. super().__init__()
  126. assert isinstance(f, collections.Callable)
  127. self._func = f
  128. self.__disown__()
  129. def call(self, size):
  130. return tuple(map(int, self._func(size)))
  131. class _SetGradCallbackPyWrapper(_SetGradCallback):
  132. def __init__(self, f):
  133. super().__init__()
  134. if f is None:
  135. self._func = None
  136. else:
  137. assert isinstance(f, collections.Callable)
  138. nr_arg = len(list(filter(
  139. lambda x: (
  140. x.kind == inspect.Parameter.POSITIONAL_OR_KEYWORD and
  141. x.default == inspect.Parameter.empty),
  142. inspect.signature(f).parameters.values())))
  143. if not nr_arg:
  144. f = lambda graph, f0=f: f0()
  145. else:
  146. assert nr_arg == 1, 'bad callback for SetGrad: {}'.format(f)
  147. self._func = f
  148. self.__disown__()
  149. def call(self, graph):
  150. if self._func is None:
  151. return SymbolVar()
  152. ret = self._func(graph)
  153. if ret is None:
  154. ret = SymbolVar()
  155. else:
  156. assert isinstance(ret, SymbolVar), (
  157. 'bad return value for var maker: {!r}'.format(ret))
  158. return ret
  159. def empty(self):
  160. return self._func is None
  161. class _TimeoutCallbackPyWrapper(_TimeoutCallback):
  162. def __init__(self, f):
  163. super().__init__()
  164. assert isinstance(f, collections.Callable)
  165. self._func = f
  166. self.__disown__()
  167. def call(self):
  168. return bool(self._func())
  169. } // %pythoncode
  170. // vim: ft=swig

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境，不用区分 CPU 和 GPU 版本。如果想要运行 GPU 程序，请确保机器本身配有 GPU 硬件设备并安装好驱动。如果你想体验在云端 GPU 算力平台上进行深度学习开发，欢迎访问 MegStudio 平台。