
opr_template.py

# -*- coding: utf-8 -*-
# This file is part of MegBrain.
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.

"""This python module contains functions to apply the operators defined by
megbrain.

.. note::

    Most of the functions are automatically generated. Their signatures
    contain a ``param`` argument (or more than one argument, such as
    :func:`convolution`, which has ``param`` and ``execution_policy``), and
    they also accept keyword arguments. Such a function can be called either
    by providing a param object of the appropriate type, or by passing the
    arguments needed by the constructor of the param object as keyword
    arguments. Furthermore, for a param that takes an enumeration member,
    the enum name can be used to refer to the enum object.

    For example, the following statements are equivalent::

        elemwise([a, b], mode='max')
        elemwise([a, b], mode=opr_param_defs.Elemwise.Mode.MAX)
        elemwise([a, b], param=opr_param_defs.Elemwise('max'))
"""
from . import mgb as _mgb
from . import helper as _helper
from . import opr_param_defs as _opr_param_defs

import sys
import enum
import collections
import json

__git_commit__ = "{%git_commit%}"

{%body%}
class _ElemMeta(type):
    def __getattr__(self, name):
        if name.startswith('__'):
            # do not intercept special attribute lookups
            return

        def run(*inputs, **kwargs):
            return elemwise(inputs, mode=name, **kwargs)

        return run


class elem(metaclass=_ElemMeta):
    """Helper class for easily applying element-wise operators: accessing a
    member method is translated to a call to :func:`elemwise` with *mode* set
    to the method name. Example::

        elem.exp(a)     # elemwise(a, mode='exp')
        elem.max(a, b)  # elemwise([a, b], mode='max')
    """


def add_update(
        dest, delta,
        alpha=_mgb.SharedScalar(1), beta=_mgb.SharedScalar(1),
        bias=_mgb.SharedScalar(0), disable=_mgb.SharedScalar(0), *,
        name=None, comp_node=None, config=None, comp_graph=None):
    """update *dest* by `dest := dest*alpha + delta*beta + bias`

    :param dest: target var to be updated; must be created from
        :func:`make_shared`
    :type dest: :class:`.SymbolVar`
    :param disable: AddUpdate will not be executed if *disable* is set to 1;
        this is used for dynamic param updating. The value of this
        SharedScalar can only be set to 0 or 1, of type `int`
    :type disable: :class:`.SharedScalar`
    """

    def as_ss(x):
        if not isinstance(x, _mgb.SharedScalar):
            x = _mgb.SharedScalar(x)
        return x

    assert isinstance(dest, _mgb.SymbolVar)
    config = _helper.gen_config(name, comp_node, config)
    dest, delta = _helper.canonize_input_vars(
        [dest, delta], comp_graph=comp_graph, config=config)
    assert isinstance(disable, _mgb.SharedScalar)
    alpha, beta, bias = map(as_ss, [alpha, beta, bias])
    return _mgb._Opr.add_update(dest, delta, alpha, beta, bias, disable, config)
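

# Illustrative sketch (an assumption, not part of the original file): how
# add_update can express a vanilla SGD step ``w := w - lr * grad``. ``w`` and
# ``w_grad`` are assumed to be a shared var and its gradient; the function
# below is never called and exists only as documentation.
def _example_add_update_sgd(w, w_grad, lr=0.01):
    # dest*alpha + delta*beta + bias with alpha=1, beta=-lr, bias=0
    return add_update(w, w_grad, alpha=1, beta=-lr)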


def reduce_(src, mode, axis=None, keepdims=False, *,
            name=None, comp_node=None, config=None, comp_graph=None):
    """reduce along the given axis; if *axis* is None, reduce to a scalar

    :param mode: reduction mode
    :type mode: :class:`~megengine._internal.opr_param_defs.Reduce.Mode`
        compatible
    :param axis: axis along which to reduce the input var
    :type axis: int
    :param keepdims: whether to keep an axis of shape 1 in the result
    :type keepdims: bool
    """
    assert isinstance(src, _mgb.SymbolVar)
    config = _helper.gen_config(name, comp_node, config)
    inputs = [src]
    kwargs = {'mode': mode}
    remove_axis = False
    if axis is None:
        inputs.append(1)
        assert not keepdims, 'can not set axis=None and keepdims=True'
    else:
        remove_axis = not keepdims
        kwargs['axis'] = axis
    ret = reduce_general(inputs, config=config, comp_graph=comp_graph,
                         **kwargs)
    if remove_axis:
        ret = _mgb._Opr.remove_axis(ret, axis, _mgb.make_opr_config())
    return _helper.cvt_opr_result(ret)
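

# Illustrative sketch (an assumption, not part of the original file): common
# reductions expressed through reduce_. ``x`` is any :class:`.SymbolVar`;
# never called, documentation only.
def _example_reduce(x):
    total = reduce_(x, 'SUM', axis=None)                    # reduce to scalar
    row_max = reduce_(x, 'MAX', axis=1)                     # axis 1 removed
    row_max_kd = reduce_(x, 'MAX', axis=1, keepdims=True)   # axis 1 kept as 1
    return total, row_max, row_max_kd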


def _reduce_like(impl, src, axis, keepdims,
                 name, comp_node, config, comp_graph):
    config = _helper.gen_config(name, comp_node, config)
    remove_axis = False
    if axis is None:
        assert not keepdims, 'can not set axis=None and keepdims=True'
        src = src.flatten()
        axis = 0
    else:
        assert isinstance(axis, int) and axis >= 0, (
            'bad axis: {!r}'.format(axis))
        remove_axis = not keepdims
    ret = impl(src, axis=axis, config=config, comp_graph=comp_graph)
    if remove_axis:
        ret = _mgb._Opr.remove_axis(ret, axis, _mgb.make_opr_config())
    return _helper.cvt_opr_result(ret)


def dimshuffle(src, pattern, ndim=0, *,
               name=None, comp_node=None, config=None):
    """swap shapes and strides according to the given pattern

    :param pattern: a list of integers, where each element gives the input
        axis of the corresponding output axis. An element can also be 'x',
        which creates a new axis with shape 1
    :param ndim: number of input dimensions; 0 means to infer it from
        *pattern*; this is required only for grad
    """
    config = _helper.gen_config(name, comp_node, config)
    if not isinstance(pattern, (list, tuple)):
        raise TypeError('could not convert {} to dimshuffle pattern'.format(
            pattern))
    pattern_mgb = _mgb._VectorInt()
    for i in pattern:
        if i == 'x':
            pattern_mgb.push_back(-1)  # -1 marks a new axis of shape 1
        else:
            i = int(i)
            assert i >= 0
            pattern_mgb.push_back(i)
    return _mgb._Opr.dimshuffle(src, pattern_mgb, int(ndim), config)
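

# Illustrative sketch (an assumption, not part of the original file): typical
# dimshuffle patterns, following the docstring above. ``x`` is assumed to be
# a 4-dim :class:`.SymbolVar` in NCHW layout; never called.
def _example_dimshuffle(x):
    nhwc = dimshuffle(x, (0, 2, 3, 1))                 # NCHW -> NHWC
    with_new_axis = dimshuffle(x, (0, 1, 'x', 2, 3))   # insert axis of shape 1
    return nhwc, with_new_axis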


def param_pack_split(src, shapes, *,
                     name=None, comp_node=None, config=None):
    """split a packed param into a list of tensors with the given shapes

    The ParamPackSplit operator has one input, ``src``, and produces a list
    of outputs; output[i] is the tensor into which the corresponding part of
    ``src`` transfers its elements.

    Example: given an input tensor of size 32 and shapes ``[(1, 2, 4),
    (4, 2, 2), (4, 2, 1)]``, the output is a list of three tensors:
    output[0] has shapes[0] ``(1, 2, 4)``, output[1] has shapes[1]
    ``(4, 2, 2)``, and output[2] has shapes[2] ``(4, 2, 1)``.

    :param src: the concatenated input tensor
    :type src: :class:`SymbolVar`
    :param shapes: shapes of the output tensors
    :type shapes: list of list of int
    """
    config = _helper.gen_config(name, comp_node, config)
    if not isinstance(shapes, (list, tuple)):
        raise TypeError('could not convert {} to tensor shapes'.format(
            shapes))
    shapes_mgb = _mgb._VectorTensorShape()
    for s in shapes:
        s = tuple(map(int, s))
        assert min(s) > 0
        shapes_mgb.push_back(s)
    return _mgb._Opr.param_pack_split(src, shapes_mgb, config)


class _modify_subtensor_helper:
    def __init__(self, dest, val, *, name=None, comp_node=None, config=None):
        self.dest = dest
        self.val = val
        self.config = _helper.gen_config(name, comp_node, config)

    def __getitem__(self, idx):
        inp = _mgb._VectorSymbolVar()
        dest, desc = _helper.cvt_getitem_to_idx_desc(
            self.dest, idx, allow_newaxis=False)
        assert desc is not None, 'no __getitem__ entries given'
        inp.push_back(dest)
        inp.push_back(self.val)
        return _mgb._create_subtensor_like_opr(
            self._opr_name, inp, desc, self.config)


class set_subtensor(_modify_subtensor_helper):
    """a proxy object which supports ``__getitem__`` to set a subtensor.
    ``c = set_subtensor(a, b)[idx]`` is equivalent to the numpy
    expression::

        c = a.copy()
        c[idx] = b
    """
    _opr_name = 'set_subtensor'


class incr_subtensor(_modify_subtensor_helper):
    """a proxy object which supports ``__getitem__`` to increase a subtensor.
    ``c = incr_subtensor(a, b)[idx]`` is equivalent to the numpy
    expression::

        c = a.copy()
        c[idx] += b
    """
    _opr_name = 'incr_subtensor'
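

# Illustrative sketch (an assumption, not part of the original file):
# functional subtensor updates via the proxies above, mirroring the numpy
# equivalents from their docstrings. ``a`` and ``b`` are assumed to be
# :class:`.SymbolVar`s; never called.
def _example_subtensor_updates(a, b):
    c1 = set_subtensor(a, b)[2:4]    # like: c1 = a.copy(); c1[2:4] = b
    c2 = incr_subtensor(a, b)[2:4]   # like: c2 = a.copy(); c2[2:4] += b
    return c1, c2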


class mesh_indexing:
    """Extract elements from the given tensor at the coordinates formed by
    the Cartesian product of the given indices; example::

        mesh_indexing(x)[:, [2, 3], :, [2, 3, 4]]
    """

    def __init__(self, src, *, name=None, comp_node=None, config=None):
        self.src = src
        self.config = _helper.gen_config(name, comp_node, config)

    def __getitem__(self, idx):
        inp, desc = _helper.cvt_getitem_to_idx_desc(self.src, idx)
        if desc is None:
            return inp
        return _mgb._create_subtensor_like_opr(
            'mesh_indexing', [inp], desc, self.config)


class batched_mesh_indexing:
    """Similar to :class:`mesh_indexing`, except that the index at the k-th
    position of the slices is a 2-dim matrix ``matrix[k]``. ``matrix[k]`` is
    a list of indices: the i-th row ``matrix[k][i]`` gives the index of the
    associated k-th position slice when ``batch_idx == i``; example::

        batched_mesh_indexing(x)[:, [[1, 2], [2, 3]], 1:-1:-1]

    .. warning::

        The first dimension of the slices must be (start, stop, step) like;
        it cannot be a SymbolVar, numpy.ndarray or Python list. The shape of
        the other indices must be (n, x), where n is the length of the first
        dimension of the tensor after applying [start:stop:step].
    """

    def __init__(self, src, *, name=None, comp_node=None, config=None):
        self.src = src
        self.config = _helper.gen_config(name, comp_node, config)

    def __getitem__(self, idx):
        inp, desc = _helper.cvt_getitem_to_idx_desc(self.src, idx)
        if desc is None:
            return inp
        return _mgb._create_subtensor_like_opr(
            'batched_mesh_indexing', [inp], desc, self.config)


class incr_mesh_indexing(_modify_subtensor_helper):
    _opr_name = 'incr_mesh_indexing'


class set_mesh_indexing(_modify_subtensor_helper):
    _opr_name = 'set_mesh_indexing'


class batched_incr_mesh_indexing(_modify_subtensor_helper):
    _opr_name = 'batched_incr_mesh_indexing'


class batched_set_mesh_indexing(_modify_subtensor_helper):
    _opr_name = 'batched_set_mesh_indexing'


class advanced_indexing:
    """wrapper for numpy-like advanced indexing, where a non-slice index can
    be a vector; example::

        advanced_indexing(x)[:, [2, 3]]
    """

    def __init__(self, src, *, name=None, comp_node=None, config=None):
        self.src = src
        self.config = _helper.gen_config(name, comp_node, config)

    def __getitem__(self, idx):
        inp, desc = _helper.cvt_getitem_to_idx_desc(self.src, idx)
        if desc is None:
            return inp
        return _mgb._create_subtensor_like_opr(
            'mavi', [inp], desc, self.config)


class set_advanced_indexing(_modify_subtensor_helper):
    """:class:`set_subtensor` equivalent with advanced-indexing support"""
    _opr_name = 'set_mavi'


class incr_advanced_indexing(_modify_subtensor_helper):
    """:class:`incr_subtensor` equivalent with advanced-indexing support"""
    _opr_name = 'incr_mavi'
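

# Illustrative sketch (an assumption, not part of the original file): reading
# and updating rows selected by a vector index, numpy-style, using the
# wrappers above. ``x`` and ``v`` are assumed :class:`.SymbolVar`s; never
# called.
def _example_advanced_indexing(x, v):
    picked = advanced_indexing(x)[[0, 2, 4]]           # gather rows 0, 2, 4
    updated = set_advanced_indexing(x, v)[[0, 2, 4]]   # write v into those rows
    return picked, updated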


def mean(inp, axis, keepdims):
    """average value along an axis"""
    if hasattr(inp.dtype, 'metadata'):
        # dtypes carrying metadata (e.g. quantized dtypes) use the native
        # MEAN reduction
        return reduce_(inp, 'MEAN', axis, keepdims)
    else:
        s = reduce_(inp, 'SUM', axis, keepdims)
        if axis is None:
            cnt = inp.shape.prod()
        else:
            cnt = inp.axis_shape(axis)
        return s / cnt


def square(inp):
    """*inp* squared"""
    return inp ** 2


def sqrt(inp):
    """square root of *inp*"""
    return inp ** 0.5


class _LoopDescMakerCallback(_mgb._LoopDescMakerCallback):
    def __init__(self, func):
        super().__init__()
        assert callable(func)
        self._func = func
        # hand ownership of this proxy to the C++ side, since the callback
        # may outlive the python wrapper
        self.__disown__()

    def call(self, desc):
        self._func(desc)


def make_loop(desc_maker, *,
              swap_interval=-5, name=None, comp_node=None, config=None):
    """Create a loop operator. The loop operator works in the following way:

    1. Copy variables specified by :meth:`.LoopDesc.add_input` from the
       parent graph into the sub graph.
    2. Evaluate the loop condition.
    3. If the absolute value of the loop condition is no more than 1e-6, go
       to 5.
    4. Update variables in the sub graph using the rules specified by
       :meth:`.LoopDesc.assign`, and then go to 2 again.
    5. Copy the values of the output variables given by
       :meth:`.LoopDesc.add_output` into the parent graph and exit.

    The loop operator can be thought of as a digital circuit, where the sub
    graph (which must be purely functional) is the combinational logic part
    and the :meth:`.LoopDesc.assign` rules serve as the flip-flops.

    :type desc_maker: callable
    :param desc_maker: a function to create the loop descriptor; it receives
        a :class:`.LoopDesc` object and should call methods on it to describe
        the sub graph. This function may be called multiple times, and it
        must behave exactly the same in every call.
    :type swap_interval: int
    :param swap_interval: number of loop executions between swapping saved
        mutable states to host; a larger *swap_interval* requires more memory
        but causes fewer copy stalls. If *swap_interval* is negative, the
        statically inferred loop count is used if possible; otherwise its
        absolute value is used as the swap interval.
    :rtype: list of :class:`.SymbolVar`
    :return: the output vars, corresponding to each
        :meth:`.LoopDesc.add_output` call
    """
    config = _helper.gen_config(name, comp_node, config)
    return _mgb._make_loop(_LoopDescMakerCallback(desc_maker), swap_interval,
                           config)
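

# Illustrative sketch (an assumption, not part of the original file): a
# desc_maker that sums the integers 0..9. ``add_input``, ``assign`` and
# ``add_output`` are the :class:`.LoopDesc` methods named in the docstring
# above; ``set_loop_condition`` is a *hypothetical* name for however the
# descriptor declares the condition checked in steps 2-3. Never called;
# documentation only.
def _example_loop_desc_maker(desc, i0, s0):
    i = desc.add_input(i0)          # loop counter copied into the sub graph
    s = desc.add_input(s0)          # running sum copied into the sub graph
    desc.assign(i, i + 1)           # step 4: flip-flop style updates
    desc.assign(s, s + i)
    desc.set_loop_condition(i < 9)  # hypothetical API; loop while nonzero
    desc.add_output(s)              # step 5: copied back to the parent graph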


def symvar_from_shared_nd(sv, comp_graph, name=None):
    """get a symbol var in a computing graph that represents a shared (i.e.
    pre-allocated) value on device

    :param sv: the shared value
    :type sv: :class:`.SharedND`
    :param comp_graph: the computing graph to which this symvar should belong
    :type comp_graph: :class:`.CompGraph`
    :param name: the name of the resulting symvar
    :type name: str or None
    :rtype: :class:`.SymbolVar`
    """
    assert isinstance(sv, _mgb.SharedND)
    return sv.symvar(comp_graph, name)


def zero_grad(sv, **kwargs):
    return set_grad(sv, None, **kwargs)


# for backward pickle compatibility
def _make_enum_unpickle(new_enum):
    """create a class that can be used for unpickling old enum values"""

    class OldEnum:
        def __new__(cls, value):
            return new_enum[value]

    return OldEnum


ConvMode = _make_enum_unpickle(_opr_param_defs.Convolution.Mode)
PoolingMode = _make_enum_unpickle(_opr_param_defs.Pooling.Mode)
ROIPoolingMode = _make_enum_unpickle(_opr_param_defs.ROIPooling.Mode)
WarpPerspectiveBorderMode = _make_enum_unpickle(
    _opr_param_defs.WarpPerspective.BorderMode)
WarpPerspectiveInterpMode = _make_enum_unpickle(
    _opr_param_defs.WarpPerspective.InterpolationMode)
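

# Illustrative sketch (an assumption, not part of the original file): how the
# shims above resolve old pickled enum values to current enum members.
# ``'MAX'`` is an assumed member name of Pooling.Mode, used only for
# illustration; never called.
def _example_old_enum_unpickle():
    # OldEnum.__new__ looks the name up in the new enum, so unpickling yields
    # the current enum member rather than an instance of the legacy class
    mode = PoolingMode('MAX')
    assert mode is _opr_param_defs.Pooling.Mode['MAX']
    return mode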

The MegEngine installation package integrates the CUDA environment needed to run code on GPU, so there is no separate CPU and GPU build. To run GPU programs, make sure the machine has GPU hardware installed along with its driver. If you would like to experience deep learning development on a cloud GPU computing platform, you are welcome to visit the MegStudio platform.