
tensor.py 26 kB

# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import collections
import copy
import functools
import itertools
import weakref
from typing import Callable, Tuple, Union

import numpy as np

import megengine._internal as mgb

from .graph import _use_default_if_none, get_default_graph


def wrap_io_tensor(func):
    r"""A wrapper to make ``func`` compatible with functions in ``_internal.opr``.
    """

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        comp_graph = None
        for i in itertools.chain(args, kwargs.values()):
            if isinstance(i, Tensor) and i._comp_graph:
                comp_graph = i._comp_graph
                break
        else:
            comp_graph = get_default_graph()
        new_args = (
            arg._attach(comp_graph) if isinstance(arg, Tensor) else arg for arg in args
        )
        new_kwargs = {
            k: v._attach(comp_graph) if isinstance(v, Tensor) else v
            for k, v in kwargs.items()
        }
        ret = func(*new_args, **new_kwargs)
        if isinstance(ret, mgb.SymbolVar):
            ret = Tensor(ret)
        elif isinstance(ret, list):
            ret = [Tensor(t) if isinstance(t, mgb.SymbolVar) else t for t in ret]
        elif isinstance(ret, tuple):
            ret = tuple(Tensor(t) if isinstance(t, mgb.SymbolVar) else t for t in ret)
        return ret

    return wrapper
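
# A minimal usage sketch (hedged; ``mgb.opr.copy`` is the same raw op that
# ``Tensor.to`` wraps later in this file). ``wrap_io_tensor`` attaches every
# Tensor argument to a common computing graph and re-wraps SymbolVar results
# back into Tensors:
#
#     copy_op = wrap_io_tensor(mgb.opr.copy)
#     y = copy_op(x, comp_node="cpu0")   # x: Tensor in, y: Tensor out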
def _wrap_symbolvar_binary_op(f):
    @functools.wraps(f)
    def wrapped(self, other):
        comp_graph = (
            isinstance(other, Tensor)
            and other._comp_graph
            or self._comp_graph
            or get_default_graph()
        )
        if isinstance(other, Tensor):
            other = other._attach(comp_graph)
        return Tensor(f(self._attach(comp_graph), other))

    return wrapped


def _wrap_slice(inp: slice):
    r"""
    A wrapper to handle Tensor values in the ``inp`` slice.
    """
    start = inp.start._symvar if isinstance(inp.start, Tensor) else inp.start
    stop = inp.stop._symvar if isinstance(inp.stop, Tensor) else inp.stop
    step = inp.step._symvar if isinstance(inp.step, Tensor) else inp.step
    return slice(start, stop, step)


def _wrap_idx(idx: Tuple[Union[int, "Tensor"]]):
    r"""
    A wrapper to handle Tensor values in ``idx``.
    """
    if not isinstance(idx, tuple):
        idx = (idx,)
    idx = tuple(i._symvar if isinstance(i, Tensor) else i for i in idx)
    idx = tuple(_wrap_slice(i) if isinstance(i, slice) else i for i in idx)
    return idx


class _MGBIndexWrapper:
    r"""
    A wrapper class to handle ``__getitem__`` for an index containing Tensor values.

    :param dest: a destination Tensor to do indexing on.
    :param mgb_index: an ``_internal`` helper function indicating how to index.
    :param val: an optional Tensor parameter used by ``mgb_index``.
    """

    def __init__(self, dest: "Tensor", mgb_index: Callable, val=None):
        self.dest = dest
        self.val = val
        self.mgb_index = mgb_index

    def __getitem__(self, idx):
        if self.val is None:
            return wrap_io_tensor(self.mgb_index(self.dest._symvar).__getitem__)(
                _wrap_idx(idx)
            )
        else:
            return wrap_io_tensor(
                self.mgb_index(self.dest._symvar, self.val._symvar).__getitem__
            )(_wrap_idx(idx))
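
# Illustrative sketch of the deferred-indexing pattern above (hedged; this
# mirrors how ``set_subtensor``/``ai``/``mi`` use the wrapper later in the file):
#
#     w = _MGBIndexWrapper(a, mgb.opr.set_subtensor, b)  # nothing happens yet
#     c = w[1:3, 0]   # only __getitem__ builds the op, after _wrap_idx unwraps
#                     # Tensor-valued indices and slice bounds into SymbolVars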
class _Guard:
    r"""
    A wrapper class with a custom ``__del__`` method calling ``deleter``.

    :param deleter: a function to be called in ``__del__``.
    """

    def __init__(self, deleter: Callable):
        self.deleter = deleter

    def __del__(self):
        self.deleter()


class Tensor:
    r"""The main data container in MegEngine.

    Use :func:`~.tensor` to create a Tensor from existing data.
    """

    requires_grad = False
    grad = None

    def __init__(self, val=None, *, requires_grad=None):
        self._reset(val, requires_grad=requires_grad)

    def _reset(self, val=None, *, requires_grad=None):
        self.__sym_override = None
        if val is None:
            self.__val = None
            self.__sym = None
        elif isinstance(val, mgb.SharedND):
            self.__val = val
            self.__sym = None
        elif isinstance(val, mgb.SymbolVar):
            self.__val = None
            self.__sym = val
        else:
            raise TypeError("must be initialized with SymbolVar or SharedND")
        self.requires_grad = requires_grad

    def _as_tensor(self, obj):
        r"""Convert the data into a ``Tensor``. If the data is already a Tensor
        with the same dtype and device, no copy will be performed. Otherwise a
        new Tensor will be returned with the computational graph retained.
        """
        if isinstance(obj, Tensor):
            return obj
        if isinstance(obj, mgb.SymbolVar):
            return Tensor(obj)
        if isinstance(obj, mgb.SharedScalar):
            return Tensor(obj._as_sym_var(self._comp_graph, self._comp_node))
        return tensor(data=obj, device=self.device)

    def numpy(self):
        r"""Return the tensor value as a numpy.ndarray.
        """
        if self.__val is not None:
            assert self.__sym is None
            return self.__val.get_value()
        if self.__sym is None:
            raise ValueError("uninitialized")
        if self.__sym.eager_val is not None:
            return self.__sym.eager_val.get_value()
        return self.__sym.inferred_value

    def item(self):
        r"""If the tensor holds only one value, return it."""
        return self.numpy().item()

    def _attach(self, comp_graph, *, volatile=True):
        sym = self.__sym_override or self.__sym
        if sym:
            if sym.owner_graph != comp_graph:
                raise RuntimeError("internal error")
            return sym
        if self.__val:
            return self.__val.symvar(comp_graph, volatile=volatile)
        else:
            raise ValueError("uninitialized")

    @property
    def _symvar(self):
        if self.__sym_override:
            return self.__sym_override
        if self.__sym:
            assert not self.__val
            return self.__sym
        if not self.__val:
            raise ValueError("uninitialized")
        return self._attach(get_default_graph())

    def __mgb_symvar__(self, comp_graph=None, **_):
        if self.__sym_override:
            return self.__sym_override
        if self.__val and comp_graph:
            return self._attach(comp_graph)
        return self._symvar  # read by mgb.opr

    def _override_symvar_during_trace(self, trace, symvar):
        assert self.__val and not self.__sym
        assert trace is type(trace)._active_instance
        deleters = trace._user_cache.setdefault(Tensor, set())
        self_ref = weakref.ref(self)

        def restore():
            self = self_ref()
            if self is not None:
                self.__sym_override = None

        deleters.add(_Guard(restore))
        self.__sym_override = symvar

    @property
    def dtype(self):
        r"""Return the data type of the tensor.
        """
        if self.__val is not None:
            return self.__val.dtype
        return self._symvar.dtype

    @dtype.setter
    def dtype(self, dtype: str = None):
        r"""Set the data type of the tensor.
        """
        if self.__val is not None:
            self.__val = mgb.make_shared(self.device, value=self.astype(dtype).numpy())
        elif self.__sym_override is not None:
            self.__sym_override = self.__sym_override.astype(dtype)
        elif self.__sym is not None:
            self.__sym = self.__sym.astype(dtype)

    @property
    def name(self):
        r"""Get the tensor name; not supported for Parameter and Buffer.
        """
        return self._symvar.name

    @name.setter
    def name(self, name: str = None):
        r"""Set the tensor name; not supported for Parameter and Buffer.
        """
        if self.__val is not None:
            raise ValueError("name setting is not available for Parameter or Buffer.")
        if self.__sym_override is not None:
            self.__sym_override = self.__sym_override.rename(name)
        if self.__sym is not None:
            assert not self.__val
            self.__sym = self.__sym.rename(name)

    @property
    def _comp_node(self):
        if self.__val is not None:
            return self.__val.comp_node
        return self._symvar.comp_node

    device = _comp_node

    @property
    def _comp_graph(self):
        if self.__sym is not None:
            return self.__sym.owner_graph
        return None

    @property
    def shape(self):
        r"""Return an int tuple that is the shape/layout of the tensor.
        Could be invalid in static graph mode.
        """
        from ..jit import trace

        if trace._active_instance:  # pylint: disable=protected-access
            # NOTE: this is a hack
            shape = mgb.opr.get_var_shape(self._symvar)
            return tuple(Tensor(shape[i]) for i in range(self.ndim))
        return self._symvar.imm_shape

    def set_value(self, value, *, sync=True, inplace=False, share=False):
        r"""Set the value of the tensor.
        """
        if not self.__val:
            raise ValueError("not detached")
        if isinstance(value, Tensor):
            value = value.__val or value.__sym.eager_val
        self.__val.set_value(value, sync=sync, inplace=inplace, share=share)

    def fill(self, value):
        r"""Fill the tensor with the specified value.
        """
        self.set_value(np.full(self.shape, value, dtype=self.dtype))

    def reset_zero(self):
        r"""Reset the tensor and fill it with zeros.
        """
        if not self.__val:
            raise ValueError("not detached")
        self.__val.reset_zero()

    def to(self, device):
        r"""Perform Tensor device conversion; return a Tensor on the specified device.
        """
        return wrap_io_tensor(mgb.opr.copy)(self, comp_node=device)
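
    # Example (a hedged sketch; the device string follows MegEngine's
    # comp-node naming such as "cpu0"/"gpu0" and is an assumption here):
    #
    #     x = tensor([1.0, 2.0])
    #     y = x.to("cpu0")   # y is a copy of x placed on comp node cpu0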
    # https://docs.python.org/3/reference/datamodel.html#object.__hash__
    # > If a class does not define an __eq__() method it should not define a
    # > __hash__() operation either
    __hash__ = None  # type: ignore[assignment]

    def __eq__(self, rhs):
        rhs = self._as_tensor(rhs)
        return Tensor(self._symvar._binary_opr("EQ", rhs._symvar))

    def __ne__(self, rhs):
        return 1 - self.__eq__(rhs)

    def __len__(self):
        if self._symvar.eager_val is not None:
            return self._symvar.eager_val.shape[0]
        raise TypeError(
            "__len__ and __iter__ are not available for tensors on a non-eager graph."
        )

    __add__ = _wrap_symbolvar_binary_op(mgb.SymbolVar.__add__)
    __radd__ = _wrap_symbolvar_binary_op(mgb.SymbolVar.__radd__)
    __sub__ = _wrap_symbolvar_binary_op(mgb.SymbolVar.__sub__)
    __rsub__ = _wrap_symbolvar_binary_op(mgb.SymbolVar.__rsub__)
    __mul__ = _wrap_symbolvar_binary_op(mgb.SymbolVar.__mul__)
    __rmul__ = _wrap_symbolvar_binary_op(mgb.SymbolVar.__rmul__)
    __matmul__ = _wrap_symbolvar_binary_op(mgb.SymbolVar.__matmul__)
    __rmatmul__ = _wrap_symbolvar_binary_op(mgb.SymbolVar.__rmatmul__)
    __lshift__ = _wrap_symbolvar_binary_op(mgb.SymbolVar.__lshift__)
    __rshift__ = _wrap_symbolvar_binary_op(mgb.SymbolVar.__rshift__)
    __truediv__ = _wrap_symbolvar_binary_op(mgb.SymbolVar.__truediv__)
    __rtruediv__ = _wrap_symbolvar_binary_op(mgb.SymbolVar.__rtruediv__)
    __floordiv__ = _wrap_symbolvar_binary_op(mgb.SymbolVar.__floordiv__)
    __rfloordiv__ = _wrap_symbolvar_binary_op(mgb.SymbolVar.__rfloordiv__)
    __mod__ = _wrap_symbolvar_binary_op(mgb.SymbolVar.__mod__)
    __rmod__ = _wrap_symbolvar_binary_op(mgb.SymbolVar.__rmod__)
    __pow__ = _wrap_symbolvar_binary_op(mgb.SymbolVar.__pow__)
    __rpow__ = _wrap_symbolvar_binary_op(mgb.SymbolVar.__rpow__)
    __lt__ = _wrap_symbolvar_binary_op(mgb.SymbolVar.__lt__)
    __gt__ = _wrap_symbolvar_binary_op(mgb.SymbolVar.__gt__)
    __le__ = _wrap_symbolvar_binary_op(mgb.SymbolVar.__le__)
    __ge__ = _wrap_symbolvar_binary_op(mgb.SymbolVar.__ge__)
    __neg__ = wrap_io_tensor(mgb.SymbolVar.__neg__)

    sum = wrap_io_tensor(mgb.SymbolVar.sum)
    """
    Sum up the given tensor.
    """

    max = wrap_io_tensor(mgb.SymbolVar.max)
    """
    Return the maximum value of the given tensor.
    """

    min = wrap_io_tensor(mgb.SymbolVar.min)
    """
    Return the minimum value of the given tensor.
    """

    prod = wrap_io_tensor(mgb.SymbolVar.prod)
    """
    Return the product of the given tensor.
    """

    mean = wrap_io_tensor(mgb.SymbolVar.mean)
    """
    Return the mean value of the given tensor.
    """

    dimshuffle = wrap_io_tensor(mgb.SymbolVar.dimshuffle)
    """
    See more details in :func:`~.functional.tensor.dimshuffle`.
    """

    astype = wrap_io_tensor(mgb.SymbolVar.astype)
    """
    Cast the tensor to a specified type.
    """
    def reshape(self, *target_shape):
        r"""Return a tensor which has the given target shape.

        Examples:

        .. testcode::

            import numpy as np
            from megengine import tensor

            inp = tensor(np.arange(1, 17, dtype=np.int32).reshape(4, 4))
            out = tensor(np.arange(100, 116, dtype=np.int32).reshape(1, 16))
            out = out.reshape(inp.shape)
            print(out.numpy())

        .. testoutput::

            [[100 101 102 103]
             [104 105 106 107]
             [108 109 110 111]
             [112 113 114 115]]
        """
        if isinstance(target_shape[0], tuple):
            if len(target_shape) > 1:
                raise ValueError("Only a single tuple is accepted in reshape")
            target_shape = target_shape[0]
        target_shape = (t._symvar if isinstance(t, Tensor) else t for t in target_shape)
        return Tensor(mgb.SymbolVar.reshape(self._symvar, *target_shape))

    def broadcast(self, *target_shape):
        r"""Return a tensor broadcast from the current tensor to the given target shape.

        Examples:

        .. testcode::

            import numpy as np
            from megengine import tensor

            data = tensor(np.arange(100, 104, dtype=np.int32).reshape(1, 4))
            data = data.broadcast((4, 4))
            print(data.numpy())

        .. testoutput::

            [[100 101 102 103]
             [100 101 102 103]
             [100 101 102 103]
             [100 101 102 103]]
        """
        if isinstance(target_shape[0], tuple):
            if len(target_shape) > 1:
                raise ValueError("Only a single tuple is accepted in broadcast")
            target_shape = target_shape[0]
        target_shape = (t._symvar if isinstance(t, Tensor) else t for t in target_shape)
        return Tensor(mgb.SymbolVar.broadcast(self._symvar, *target_shape))
    # Prefer operators on Tensor instead of converting to numpy
    __array_priority__ = 1000

    # mgb indexing family
    def __getitem__(self, idx):
        return wrap_io_tensor(self._symvar.__getitem__)(_wrap_idx(idx))

    def set_subtensor(self, val: "Tensor") -> _MGBIndexWrapper:
        r"""
        Return an object which supports using ``__getitem__`` to set a subtensor.

        ``c = a.set_subtensor(b)[idx]`` is equivalent to ``c = a.copy()`` followed by ``c[idx] = b``.
        """
        return _MGBIndexWrapper(self, mgb.opr.set_subtensor, val)

    def incr_subtensor(self, val: "Tensor") -> _MGBIndexWrapper:
        r"""
        Return an object which supports using ``__getitem__`` to increment a subtensor.

        ``c = a.incr_subtensor(b)[idx]`` is equivalent to ``c = a.copy()`` followed by ``c[idx] += b``.
        """
        return _MGBIndexWrapper(self, mgb.opr.incr_subtensor, val)
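
    # Worked example of the documented equivalences (a hedged sketch):
    #
    #     a = tensor(np.zeros((4, 4), dtype="float32"))
    #     b = tensor(np.ones((2, 4), dtype="float32"))
    #     c = a.set_subtensor(b)[1:3]   # c == a.copy(), then c[1:3] = b
    #     d = a.incr_subtensor(b)[1:3]  # d == a.copy(), then d[1:3] += b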
    @property
    def ai(self) -> _MGBIndexWrapper:
        r"""
        Return an object which supports complex index methods to get a subtensor.

        Examples:

        .. testcode::

            import numpy as np
            from megengine import tensor

            a = tensor(np.arange(16, dtype=np.float32).reshape((4, 4)))
            print(a.ai[:, [2, 3]])

        Outputs:

        .. testoutput::

            Tensor([[ 2.  3.]
             [ 6.  7.]
             [10. 11.]
             [14. 15.]])
        """
        return _MGBIndexWrapper(self, mgb.opr.advanced_indexing)

    def set_ai(self, val: "Tensor") -> _MGBIndexWrapper:
        r"""
        Equal to :meth:`~.Tensor.set_subtensor`, but supports advanced indexing.
        """
        return _MGBIndexWrapper(self, mgb.opr.set_advanced_indexing, val)

    def incr_ai(self, val: "Tensor") -> _MGBIndexWrapper:
        r"""
        Equal to :meth:`~.Tensor.incr_subtensor`, but supports advanced indexing.
        """
        return _MGBIndexWrapper(self, mgb.opr.incr_advanced_indexing, val)

    @property
    def mi(self) -> _MGBIndexWrapper:
        r"""
        Return an object which supports getting a subtensor at the
        coordinates given by the Cartesian product of the indices.

        Examples:

        .. testcode::

            import numpy as np
            from megengine import tensor

            a = tensor(np.arange(16, dtype=np.float32).reshape((4, 4)))
            print(a.mi[[1, 2], [2, 3]])
            # equal to the elements on [1, 2] x [2, 3] = [[(1, 2), (1, 3)], [(2, 2), (2, 3)]]
            # a[1, 2] = 6, a[1, 3] = 7, a[2, 2] = 10, a[2, 3] = 11

        Outputs:

        .. testoutput::

            Tensor([[ 6.  7.]
             [10. 11.]])
        """
        return _MGBIndexWrapper(self, mgb.opr.mesh_indexing)

    def set_mi(self, val: "Tensor") -> _MGBIndexWrapper:
        r"""
        Equal to :meth:`~.Tensor.set_subtensor`, but using mesh indexing.
        """
        return _MGBIndexWrapper(self, mgb.opr.set_mesh_indexing, val)

    def incr_mi(self, val: "Tensor") -> _MGBIndexWrapper:
        r"""
        Equal to :meth:`~.Tensor.incr_subtensor`, but using mesh indexing.
        """
        return _MGBIndexWrapper(self, mgb.opr.incr_mesh_indexing, val)
    @property
    def batched_mi(self) -> _MGBIndexWrapper:
        r"""
        Return an object which supports getting a subtensor by
        batched mesh indexing.

        For a Tensor ``a`` and index ``idx``, each value of ``idx`` must be a 2-dim matrix or a slice.
        The Cartesian product ``... * idx[k-1][i] * idx[k][i] * idx[k+1][i] * ...`` forms a subtensor of ``a[i]``.
        Each matrix ``idx[k]`` should have ``batched_dim`` rows, as indicated by ``idx[0]``.
        A slice value applies the same slice to every ``batched_dim``. For more details see the example below.

        Examples:

        .. testcode::

            import numpy as np
            from megengine import tensor

            a = tensor(np.arange(144, dtype=np.float32).reshape((3, 3, 4, 4)))
            print(a.batched_mi[:2, [[0],[1]], [[0,1],[2,3]], [[0],[1]]])
            # equal to the elements from a[0] at [0] * [0,1] * [0] = [[[(0,0,0)], [(0,1,0)]]] (shape is [1,2,1])
            # and from a[1] at [1] * [2,3] * [1] = [[[(1,2,1)], [(1,3,1)]]] (shape is also [1,2,1])
            # a[0,0,0,0] = 0, a[0,0,1,0] = 4, a[1,1,2,1] = 73, a[1,1,3,1] = 77
            print(a.batched_mi[:2, [[0],[1]], :2, :1])
            # equal to a.batched_mi[:2, [[0],[1]], [[0,1],[0,1]], [[0],[0]]]

        Outputs:

        .. testoutput::

            Tensor([[[[ 0.]
               [ 4.]]]
             [[[73.]
               [77.]]]])
            Tensor([[[[ 0.]
               [ 4.]]]
             [[[64.]
               [68.]]]])
        """
        return _MGBIndexWrapper(self, mgb.opr.batched_mesh_indexing)

    def batched_set_mi(self, val: "Tensor") -> _MGBIndexWrapper:
        r"""
        Equal to :meth:`~.Tensor.set_subtensor`, but using batched mesh indexing.
        """
        return _MGBIndexWrapper(self, mgb.opr.batched_set_mesh_indexing, val)

    def batched_incr_mi(self, val: "Tensor") -> _MGBIndexWrapper:
        r"""
        Equal to :meth:`~.Tensor.incr_subtensor`, but using batched mesh indexing.
        """
        return _MGBIndexWrapper(self, mgb.opr.batched_incr_mesh_indexing, val)
    def __array__(self, dtype=None):
        if dtype is None:
            return self.numpy()
        else:
            return self.numpy().astype(dtype, copy=False)

    def __int__(self):
        return int(self.item())

    def __index__(self):
        return int(self.item())

    def __round__(self, ndigits=0):
        if ndigits != 0:
            raise ValueError("ndigits must be 0 for Tensor.round")
        return Tensor(mgb.opr.elemwise([self._symvar], mode="ROUND"))

    round = __round__

    def sqrt(self):
        r"""Return a tensor whose elements are the square roots of the
        original values.
        """
        return Tensor(mgb.opr.sqrt(self._symvar))

    def shapeof(self, axis=None):
        r"""Return a Tensor that represents the shape of this tensor.
        """
        return Tensor(mgb.opr.get_var_shape(self._symvar, axis=axis))

    @property
    def ndim(self):
        r"""Return the number of dimensions of the tensor.
        """
        return len(self._symvar.imm_shape)

    def __repr__(self):
        piece = "Tensor("
        with np.printoptions(precision=4, suppress=True):
            piece += "{}".format(str(self.numpy()))
        if self.dtype != np.float32:
            piece += ", dtype={}".format(np.dtype(self.dtype).name)
        if self._comp_node.locator_logical != ("XPU", -1, 0):
            piece += ", device={}".format(self.device)
        piece += ")"
        return piece

    def __bool__(self):
        raise RuntimeError(
            "Tensor object should not be converted to bool or used in an if statement. "
            "Use .numpy(), int() or float() if you want to use its value in an if statement; "
            "be aware that this may lead to incorrect results in non-eager mode."
        )

    def __getstate__(self):
        r"""__getstate__ will be called for pickle serialization or deep copy.
        """
        assert (self.__val is not None) and (
            self.__sym is None
        ), "Only SharedND-initialized Tensors can be serialized or deep copied"
        metadata = {"requires_grad": self.requires_grad}
        state = {
            "data": self.numpy(),
            "device": self.device,
            "dtype": self.dtype,
            "metadata": metadata,
        }
        return state

    def __setstate__(self, state):
        data = state.pop("data")
        device = state.pop("device")
        dtype = state.pop("dtype")
        metadata = state.pop("metadata", {})
        requires_grad = metadata.pop("requires_grad", None)
        snd = mgb.make_shared(device, value=data, dtype=dtype)
        self._reset(snd, requires_grad=requires_grad)

    def __deepcopy__(self, memo):
        r"""
        Since Tensor has __getstate__ and __setstate__ methods, deepcopy would
        only process that state and ignore attributes added by subclasses such
        as Parameter. We therefore define __deepcopy__ to copy all attributes
        correctly.
        """
        assert (self.__val is not None) and (
            self.__sym is None
        ), "Only SharedND-initialized Tensors can be serialized or deep copied"
        cls = self.__class__
        result = cls.__new__(cls)
        memo[id(self)] = result
        for k, v in self.__dict__.items():
            setattr(result, k, copy.deepcopy(v, memo))
        return result
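
# Serialization sketch (hedged; relies only on the __getstate__/__setstate__
# pair above, which restricts pickling to SharedND-backed tensors):
#
#     import pickle
#     t = tensor(np.arange(4, dtype="float32"), requires_grad=True)
#     t2 = pickle.loads(pickle.dumps(t))  # data, device, dtype and the
#                                         # requires_grad flag are restored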
def tensor(
    data: Union[list, np.ndarray] = None,
    *,
    dtype: str = None,
    device: mgb.CompNode = None,
    requires_grad: bool = None
):
    r"""A helper function to create a :class:`~.Tensor` from existing data.

    :param data: an existing data array; must be a Python list, a NumPy array or None.
    :param dtype: target Tensor data type, one of ``("uint8", "int8", "int16", "int32", "float32", "float16")``.
    :param device: target device for Tensor storage.
    :param requires_grad: whether its gradient will be calculated during :meth:`~.Optimizer.backward`
    """
    supported_dtypes = ("uint8", "int8", "int16", "int32", "float32", "float16")

    if isinstance(data, Tensor):
        raise NotImplementedError
    if dtype is not None and np.dtype(dtype).name not in supported_dtypes:
        raise TypeError("unsupported dtype {}".format(dtype))

    if data is not None:
        if not isinstance(data, np.ndarray):
            data = np.array(data, dtype=dtype)
            # In order to accept tensor([1]), automatically convert to a 32-bit
            # number instead of numpy's default 64-bit when the input is not an ndarray.
            dtype = mgb.to_mgb_supported_dtype(data.dtype)
        if dtype is None:
            if data.dtype.name not in supported_dtypes:
                raise TypeError("unsupported dtype {}".format(data.dtype))

    device, _ = _use_default_if_none(device, None)
    shared_nd = mgb.make_shared(device, value=data, dtype=dtype)
    return Tensor(shared_nd, requires_grad=requires_grad)
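
# Usage sketch for ``tensor`` (hedged; the "cpu0" device string is an
# assumption following MegEngine's comp-node naming):
#
#     t = tensor([1, 2, 3])            # Python list: converted to int32,
#                                      # not numpy's default int64
#     t = tensor(np.ones((2, 2), dtype=np.float32), device="cpu0")
#     t.set_value(np.zeros((2, 2), dtype=np.float32))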
class TensorDict(collections.abc.MutableMapping):
    r"""
    A helper class to maintain a dict with Tensor keys.
    """

    def __init__(self, *args, **kwargs):
        self.data = {}
        for i in args:
            self.update(i)
        self.update(**kwargs)

    class keyfn:
        def __new__(cls, x: Tensor):
            if not isinstance(x, Tensor):
                return x
            return super().__new__(cls)

        def __init__(self, x: Tensor):
            self._data = x  # do not save id directly, to make pickle work

        def __hash__(self):
            return id(self._data)

        def __eq__(self, other):
            return isinstance(other, type(self)) and id(self._data) == id(other._data)

    def __getitem__(self, key):
        _, v = self.data[self.keyfn(key)]
        return v

    def __setitem__(self, key, value):
        self.data[self.keyfn(key)] = key, value

    def __delitem__(self, key):
        del self.data[self.keyfn(key)]

    def __iter__(self):
        for _, (k, _) in self.data.items():
            yield k

    def __len__(self):
        return len(self.data)
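
# Usage sketch (hedged): TensorDict keys on Tensor identity via ``keyfn``,
# since Tensor defines elementwise __eq__ and is therefore unhashable:
#
#     d = TensorDict()
#     x = tensor([1.0])
#     d[x] = "state"        # stored under keyfn(x), which hashes id(x)
#     assert d[x] == "state" and len(d) == 1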

The MegEngine installation package bundles the CUDA environment needed to run code on GPUs, so there is no separate CPU/GPU build. To run GPU programs, make sure the machine has GPU hardware and the driver is installed. If you would like to try deep-learning development on cloud GPU computing power, you are welcome to visit the MegStudio platform.