You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

tensor.py 28 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037
  1. # -*- coding: utf-8 -*-
  2. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  3. #
  4. # Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
  5. #
  6. # Unless required by applicable law or agreed to in writing,
  7. # software distributed under the License is distributed on an
  8. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  9. import functools
  10. import math
  11. from itertools import accumulate
  12. from typing import Iterable, List, Optional, Sequence, Tuple, Union
  13. import numpy as np
  14. from ..core._imperative_rt import CompNode
  15. from ..core._wrap import device as as_device
  16. from ..core.ops import builtin
  17. from ..core.ops._internal import param_defs as P
  18. from ..core.ops.special import Const
  19. from ..core.tensor.core import TensorBase, TensorWrapperBase, apply
  20. from ..core.tensor.utils import (
  21. astensor1d,
  22. convert_inputs,
  23. convert_single_value,
  24. dtype_promotion,
  25. get_device,
  26. )
  27. from ..device import get_default_device
  28. from ..tensor import Tensor
  29. from .elemwise import ceil
  # Public API of this module. The aliases ``expand_dims``, ``squeeze`` and
  # ``transpose`` are exported next to the functions they are bound to.
  __all__ = [
      "add_axis", "arange", "broadcast", "concat", "cond_take",
      "dimshuffle", "expand_dims", "eye", "full", "full_like",
      "gather", "linspace", "ones", "ones_like",
      "param_pack_concat", "param_pack_split", "reshape", "remove_axis",
      "split", "squeeze", "stack", "scatter", "transpose", "where",
      "zeros", "zeros_like",
  ]
  def eye(
      n: Union[int, Iterable[int]],
      *,
      dtype="float32",
      device: Optional[CompNode] = None,
  ) -> Tensor:
      """
      Returns a 2D tensor with ones on the diagonal and zeros elsewhere.

      :param n: Shape of the output matrix. The example below passes
          ``[rows, cols]``; the value is converted to an int32 tensor and
          handed to the ``Eye`` operator.
      :param dtype: Data type of the result. Default: "float32"
      :param device: Compute node of the matrix. Default: None
      :return: The eye matrix

      Examples:

      .. testcode::

          import numpy as np
          import megengine.functional as F

          data_shape = (4, 6)
          n, m = data_shape
          out = F.eye([n, m], dtype=np.float32)
          print(out.numpy())

      Outputs:

      .. testoutput::

          [[1. 0. 0. 0. 0. 0.]
           [0. 1. 0. 0. 0. 0.]
           [0. 0. 1. 0. 0. 0.]
           [0. 0. 0. 1. 0. 0.]]
      """
      # k=0 presumably selects the main diagonal (diagonal offset) -- TODO confirm.
      op = builtin.Eye(k=0, dtype=dtype, comp_node=device)
      # The requested shape reaches the operator as an int32 input tensor.
      (result,) = apply(op, Tensor(n, dtype="int32", device=device))
      return result
  def full(shape, value, dtype="float32", device=None):
      """
      Returns a tensor of the given shape filled with ``value``.

      :param shape: Target shape; a bare int is treated as a 1-D shape.
      :param value: Fill value.
      :param dtype: Data type of the result. Default: "float32"
      :param device: Compute node of the result. Default: None, in which case
          ``get_default_device()`` is used.
      :return: The filled tensor.
      """
      target_shape = (shape,) if isinstance(shape, int) else shape
      target_device = get_default_device() if device is None else device
      # A single-valued constant is built (applied to a tensor holding
      # ``value``) and then broadcast to the requested shape.
      make_const = Const(value, dtype=dtype, device=target_device)
      (scalar,) = make_const(Tensor(value, dtype=dtype, device=target_device))
      return broadcast(scalar, target_shape)
  def ones(shape, dtype="float32", device=None):
      """
      Returns a tensor of the given shape filled with ones.

      Thin wrapper over :func:`full` with a fill value of ``1.0``.
      """
      return full(shape, value=1.0, dtype=dtype, device=device)
  def zeros(shape, dtype="float32", device=None):
      """
      Returns a tensor of the given shape filled with zeros.

      Thin wrapper over :func:`full` with a fill value of ``0.0``.
      """
      return full(shape, value=0.0, dtype=dtype, device=device)
  def zeros_like(inp: Tensor) -> Tensor:
      r"""
      Returns a zero tensor with the same shape, dtype and device as ``inp``.

      :param inp: input tensor
      :return: the zero tensor

      Examples:

      .. testcode::

          import numpy as np
          from megengine import tensor
          import megengine.functional as F

          inp = tensor(np.arange(1, 7, dtype=np.int32).reshape(2,3))
          out = F.zeros_like(inp)
          print(out.numpy())

      Outputs:

      .. testoutput::

          [[0 0 0]
           [0 0 0]]
      """
      shape, dtype, device = inp.shape, inp.dtype, inp.device
      return zeros(shape, dtype=dtype, device=device)
  def ones_like(inp: Tensor) -> Tensor:
      r"""
      Returns a tensor filled with ones that has the same shape, dtype and
      device as the input tensor.

      Every element of the result is one; this is not an identity matrix.

      :param inp: input tensor whose shape, dtype and device are reused
      :return: the ones tensor
      """
      return ones(inp.shape, dtype=inp.dtype, device=inp.device)
  def full_like(inp: Tensor, value: Union[int, float]) -> Tensor:
      r"""
      Returns a tensor filled with ``value`` that has the same shape, dtype and
      device as the input tensor.

      :param inp: input tensor whose shape, dtype and device are reused
      :param value: fill value
      :return: the filled tensor
      """
      shape, dtype, device = inp.shape, inp.dtype, inp.device
      return full(shape, value, dtype=dtype, device=device)
  def broadcast(inp: Tensor, shape: Union[int, Iterable[int]]) -> Tensor:
      """
      Broadcasts a tensor to ``shape``.

      :param inp: The input tensor
      :param shape: The target shape
      :return: The output tensor

      Examples:

      .. testcode::

          import numpy as np
          from megengine import tensor
          import megengine.functional as F

          data = tensor(np.arange(0, 6, dtype=np.float32).reshape(2, 3))
          out = F.broadcast(data, (4, 2, 3))
          print(out.numpy())

      Outputs:

      .. testoutput::

          [[[0. 1. 2.]
            [3. 4. 5.]]

           [[0. 1. 2.]
            [3. 4. 5.]]

           [[0. 1. 2.]
            [3. 4. 5.]]

           [[0. 1. 2.]
            [3. 4. 5.]]]
      """
      # The target shape is passed to the operator as a 1-D int32 tensor.
      target = astensor1d(shape, inp, dtype="int32", device=inp.device)
      outputs = apply(builtin.Broadcast(), inp, target)
      (broadcasted,) = outputs
      return broadcasted
  def concat(inps: Iterable[Tensor], axis: int = 0, device=None) -> Tensor:
      r"""
      Concatenates tensors along an existing axis.

      :param inps: Input tensors to concat; any iterable is accepted.
      :param axis: the dimension over which the tensors are concatenated. Default: 0
      :param device: Compute node of the output. Default: None, in which case
          it is derived from the inputs.
      :return: The output tensor. A single input is returned unchanged.

      Examples:

      .. testcode::

          import numpy as np
          from megengine import tensor
          import megengine.functional as F

          data1 = tensor(np.arange(0, 6, dtype=np.float32).reshape((2, 3)))
          data2 = tensor(np.arange(6, 12, dtype=np.float32).reshape((2, 3)))
          out = F.concat([data1, data2])
          print(out.numpy())

      Outputs:

      .. testoutput::

          [[ 0.  1.  2.]
           [ 3.  4.  5.]
           [ 6.  7.  8.]
           [ 9. 10. 11.]]
      """
      # len() and indexing are used below, so materialize generic iterables
      # (e.g. generators) as the annotation promises; lists/tuples pass as-is.
      if not isinstance(inps, (list, tuple)):
          inps = list(inps)
      if len(inps) == 1:
          return inps[0]
      # All inputs are converted to a common promoted dtype.
      dtype = dtype_promotion(inps)
      if device is None:
          device = get_device(inps)
      device = as_device(device)

      def convert(x):
          return convert_single_value(x, inps, dtype=dtype)

      inps = tuple(map(convert, inps))
      (result,) = apply(builtin.Concat(axis=axis, comp_node=device.to_c()), *inps)
      return result
  def stack(inps, axis=0, device=None):
      """Concatenates a sequence of tensors along a new axis.

      The input tensors must have the same shape.

      :param inps: The input tensors.
      :param axis: Position of the new axis along which they are stacked.
      :param device: Compute node of the output. Default: None
      :return: The stacked tensor.
      :raises ValueError: if the (non-symbolic) input shapes differ.

      Examples:

      .. testcode::

          import numpy as np
          from megengine import tensor
          import megengine.functional as F

          x1 = tensor(np.arange(0, 6, dtype=np.float32).reshape((2, 3)))
          x2 = tensor(np.arange(6, 12, dtype=np.float32).reshape((2, 3)))
          out = F.stack([x1, x2], axis=0)
          print(out.numpy())

      Outputs:

      .. testoutput::

          [[[ 0.  1.  2.]
            [ 3.  4.  5.]]

           [[ 6.  7.  8.]
            [ 9. 10. 11.]]]
      """
      if len(inps) > 0:
          first = inps[0]
          # Skip the eager shape check when ``.shape`` is itself an instance of
          # the tensor's class (presumably a symbolic shape).
          shape_is_tensor = isinstance(first.shape, first.__class__)
          if not shape_is_tensor and len({t.shape for t in inps}) != 1:
              raise ValueError("All input tensors must have the same shape")
      expanded = [add_axis(t, axis=axis) for t in inps]
      return concat(expanded, axis=axis, device=device)
  def split(inp, nsplits_or_sections, axis=0):
      """Splits the input tensor into several smaller tensors.

      When ``nsplits_or_sections`` is an int ``n``, the split axis is cut into
      ``n`` pieces of ``ceil(size / n)`` elements each; the last piece may be
      smaller than the others. When it is a sequence, its items are the
      exclusive end indices of consecutive pieces, and any remainder after the
      last index becomes one more piece.

      :param inp: The input tensor.
      :param nsplits_or_sections: Number of sub tensors or section information list.
      :param axis: Which axis will be split. Negative values count from the end.
      :return: The output tensor list.

      Examples:

      .. testcode::

          import numpy as np
          from megengine import tensor
          import megengine.functional as F

          x = tensor(np.random.random((2,3,4,5)), dtype=np.float32)
          out = F.split(x, 2, axis=3)
          print(out[0].shape, out[1].shape)

      Outputs:

      .. testoutput::

          (2, 3, 4, 3) (2, 3, 4, 2)
      """
      if axis < 0:
          axis += inp.ndim

      def swapaxis(tensor, src, dst):
          # Exchange axes ``src`` and ``dst``; the permutation is its own inverse.
          if src == dst:
              return tensor
          pattern = list(range(tensor.ndim))
          pattern[src], pattern[dst] = pattern[dst], pattern[src]
          return tensor.transpose(pattern)

      # Move the split axis to the front so that plain slicing can be used.
      inp = swapaxis(inp, 0, axis)
      size = inp.shape[0]

      is_count = isinstance(nsplits_or_sections, int)
      if is_count:
          step = ceil(size / nsplits_or_sections)
          # Interior boundaries step, 2*step, ..., (n-1)*step. The boundaries must
          # grow by the step size, not by the number of splits. The final piece is
          # the remainder taken below, so its end is exactly the axis size.
          sections = [
              (step * k).astype("int32") for k in range(1, nsplits_or_sections)
          ]
      else:
          sections = nsplits_or_sections

      sub_tensors = []
      start = 0
      for end in sections:
          sub_tensors.append(swapaxis(inp[start:end], axis, 0))
          start = end
      # With an int count the tail is always emitted so exactly n pieces come back.
      if is_count or start < size:
          sub_tensors.append(swapaxis(inp[start:], axis, 0))
      return sub_tensors
  def _get_idx(index, axis):
      """Builds one flattened int32 coordinate tensor per dimension of ``index``.

      Along ``axis`` the coordinates are the values stored in ``index``; along
      every other dimension they are each element's own position (a range
      broadcast to ``index.shape``). The returned tuple is suitable for
      advanced indexing, as done by ``gather`` and ``scatter``.
      """
      ndim = len(index.shape)
      flat_index = []
      for dim in range(ndim):
          if dim == axis:
              flat_index.append(index.reshape(-1))
              continue
          extent = index.shape[dim]
          view_shape = [1] * ndim
          view_shape[dim] = extent
          coords = linspace(0, extent - 1, extent, device=index.device,)
          coords = coords.reshape(*view_shape)
          coords = coords.broadcast(index.shape)
          coords = coords.reshape(-1)
          coords = coords.astype(np.int32)
          flat_index.append(coords)
      return tuple(flat_index)
  def gather(inp: Tensor, axis: int, index: Tensor) -> Tensor:
      r"""
      Gathers data from :attr:`inp` on :attr:`axis` using :attr:`index`.

      For a 3-D tensor, the output is specified by::

          out[i][j][k] = inp[index[i][j][k]][j][k] # if axis == 0
          out[i][j][k] = inp[i][index[i][j][k]][k] # if axis == 1
          out[i][j][k] = inp[i][j][index[i][j][k]] # if axis == 2

      if :attr:`inp` is an n-dimensional tensor with size
      :math:`(x_0,x_1,...,x_{i-1},x_i,x_{i+1},...,x_{n-1})` and axis=i,
      then :attr:`index` must be an n-dimensional tensor with size
      :math:`(x_0,x_1,...,x_{i-1},y,x_{i+1},...,x_{n-1})` where :math:`y\ge 1` and
      output will have the same size as :attr:`index`.

      :param inp: the source tensor
      :param axis: the axis along which to index; must be in ``[0, inp.ndim)``
      :param index: the indices of elements to gather
      :return: a tensor with the same shape as :attr:`index`
      :raises ValueError: on mismatched ranks, an out-of-range axis, or sizes
          that differ outside ``axis``.

      Examples:

      .. testcode::

          import megengine.functional as F
          from megengine import tensor

          inp = tensor([
              [1,2], [3,4], [5,6],
          ])
          index = tensor([[0,2], [1,0]])
          oup = F.gather(inp, 0, index)
          print(oup.numpy())

      Outputs:

      .. testoutput::

          [[1 6]
           [3 2]]
      """
      input_shape = inp.shape
      index_shape = index.shape
      input_dims = len(input_shape)
      index_dims = len(index_shape)
      if input_dims != index_dims:
          raise ValueError(
              "The index tensor must have same dimensions as input tensor, "
              "But the input dims:{}, the index dims:{}".format(input_dims, index_dims)
          )
      if axis < 0 or axis >= input_dims:
          raise ValueError(
              "Index axis {} is out of bounds, should be in range [0, {})".format(
                  axis, input_dims
              )
          )
      for i in range(input_dims):
          if i != axis and input_shape[i] != index_shape[i]:
              raise ValueError(
                  "The input {} and index {} must have the same size apart from axis {}".format(
                      input_shape, index_shape, axis
                  )
              )
      # Flattened coordinates for every element of ``index``; the gathered
      # values come back flat and are reshaped to ``index``'s shape.
      idx = _get_idx(index, axis)
      return inp[idx].reshape(index.shape)  # pylint: disable=no-member
  def scatter(inp: Tensor, axis: int, index: Tensor, source: Tensor) -> Tensor:
      r"""
      Writes all values from the tensor :attr:`source` into :attr:`inp` at the
      indices specified in the :attr:`index` tensor.

      For each value in :attr:`source`, its output index is specified by its index
      in :attr:`source` for ``axis != dimension`` and by the corresponding value in
      :attr:`index` for ``axis = dimension``.

      For a 3-D tensor, :attr:`inp` is updated as::

          inp[index[i][j][k]][j][k] = source[i][j][k] # if axis == 0
          inp[i][index[i][j][k]][k] = source[i][j][k] # if axis == 1
          inp[i][j][index[i][j][k]] = source[i][j][k] # if axis == 2

      :attr:`inp`, :attr:`index` and :attr:`source` should have same number of dimensions.
      It is also required that ``source.shape(d) <= inp.shape(d)`` and
      ``index.shape(d) == source.shape(d)`` for all dimensions ``d``.
      Moreover, the values of :attr:`index` must be between ``0`` and
      ``inp.shape(axis) - 1`` inclusive.

      .. note::

          Due to performance issues, the result is nondeterministic on GPU
          devices when several positions of :attr:`source` are scattered to the
          same destination position. For example, if ``index[1][2]`` in the
          example below were changed from 1 to 0, ``oup[0][2]`` could come from
          either ``source[0][2]`` (0.2256) or ``source[1][2]`` (0.5939).

      :param inp: the tensor to scatter into; it is modified in place and returned
      :param axis: the axis along which to index; must be in ``[0, inp.ndim)``
      :param index: the indices of elements to scatter
      :param source: the source element(s) to scatter
      :return: :attr:`inp` after the update
      :raises ValueError: on mismatched ranks, an out-of-range axis, or
          incompatible sizes.

      Examples:

      .. testcode::

          import numpy as np
          import megengine.functional as F
          from megengine import tensor

          inp = tensor(np.zeros(shape=(3,5),dtype=np.float32))
          source = tensor([[0.9935,0.9465,0.2256,0.8926,0.4396],[0.7723,0.0718,0.5939,0.357,0.4576]])
          index = tensor([[0,2,0,2,1],[2,0,1,1,2]])
          oup = F.scatter(inp, 0, index,source)
          print(oup.numpy())

      Outputs:

      .. testoutput::

          [[0.9935 0.0718 0.2256 0.     0.    ]
           [0.     0.     0.5939 0.357  0.4396]
           [0.7723 0.9465 0.     0.8926 0.4576]]
      """
      input_shape = inp.shape
      index_shape = index.shape
      source_shape = source.shape
      input_dims = len(input_shape)
      index_dims = len(index_shape)
      source_dims = len(source_shape)
      if input_dims != index_dims or input_dims != source_dims:
          raise ValueError("The input, source and index tensor must have same dimensions")
      if axis < 0 or axis >= input_dims:
          raise ValueError(
              "Index axis {} is out of bounds, should be in range [0, {})".format(
                  axis, input_dims
              )
          )
      for i in range(source_dims):
          if source_shape[i] > input_shape[i]:
              raise ValueError(
                  "Each dimension of source {} must be less than or equal to that of input {}".format(
                      source_shape, input_shape
                  )
              )
      for i in range(index_dims):
          if index_shape[i] != source_shape[i]:
              raise ValueError(
                  "Each dimension of index {} must be equal to that of source {}".format(
                      index_shape, source_shape
                  )
              )
      # No separate index-vs-input check is needed: the two loops above already
      # guarantee index_shape[i] == source_shape[i] <= input_shape[i] for all i.
      idx = _get_idx(index, axis)
      inp[idx] = source.flatten()
      return inp
  def where(mask: Tensor, x: Tensor, y: Tensor) -> Tensor:
      r"""
      Selects elements either from Tensor x or Tensor y, according to mask.

      .. math::

          \textrm{out}_i = x_i \textrm{ if } \textrm{mask}_i \textrm{ is True else } y_i

      :param mask: a bool mask used for choosing x or y
      :param x: the first choice
      :param y: the second choice
      :return: a tensor with the shape of ``x``
      :raises TypeError: if any argument is not a tensor
      :raises ValueError: if ``mask`` is not bool or devices differ

      Examples:

      .. testcode::

          import numpy as np
          from megengine import tensor
          import megengine.functional as F

          mask = tensor(np.array([[True, False], [False, True]], dtype=np.bool_))
          x = tensor(np.array([[1, np.inf], [np.nan, 4]],
              dtype=np.float32))
          y = tensor(np.array([[5, 6], [7, 8]], dtype=np.float32))
          out = F.where(mask, x, y)
          print(out.numpy())

      Outputs:

      .. testoutput::

          [[1. 6.]
           [7. 4.]]
      """
      x, y = convert_inputs(x, y)
      if not isinstance(x, (TensorWrapperBase, TensorBase)):
          raise TypeError("input x must be a tensor")
      if not isinstance(y, (TensorWrapperBase, TensorBase)):
          raise TypeError("input y must be a tensor")
      if not isinstance(mask, (TensorWrapperBase, TensorBase)):
          raise TypeError("mask must be a tensor")
      if mask.dtype != np.bool_:
          raise ValueError("mask must be bool")
      if x.device != mask.device:
          raise ValueError("ambiguous device: {} vs {}".format(x.device, mask.device))
      if y.device != mask.device:
          raise ValueError("ambiguous device: {} vs {}".format(y.device, mask.device))
      # Values of x where mask holds and of y elsewhere, with their flat positions.
      v0, index0 = cond_take(mask, x)
      v1, index1 = cond_take(~mask, y)
      if v0.shape == (0,):
          out = v1
      elif v1.shape == (0,):
          out = v0
      else:
          # Any buffer of the right total size works here: mask and ~mask
          # partition the elements, so the two writes below cover every position.
          out = concat([v0, v1])
      out[index0] = v0
      out[index1] = v1
      out = out.reshape(x.shape)
      return out
  def cond_take(mask: Tensor, x: Tensor) -> Tuple[Tensor, Tensor]:
      r"""
      Takes elements from data if a specific condition is satisfied on mask.

      This operator has two outputs: the first is the elements taken,
      and the second is the indices corresponding to those elements;
      they are both 1-dimensional. High-dimension input would first be flattened.

      :param mask: condition param; must be the same shape with data
      :param x: input tensor from which to take elements
      :return: a ``(values, indices)`` pair of 1-D tensors
      :raises TypeError: if ``x`` or ``mask`` is not a tensor
      :raises ValueError: if ``mask`` is not bool or the devices differ

      Examples:

      .. testcode::

          import numpy as np
          from megengine import tensor
          import megengine.functional as F

          mask = tensor(np.array([[True, False], [False, True]], dtype=np.bool_))
          x = tensor(np.array([[1, np.inf], [np.nan, 4]],
              dtype=np.float32))
          v, index = F.cond_take(mask, x)
          print(v.numpy(), index.numpy())

      Outputs:

      .. testoutput::

          [1. 4.] [0 3]
      """
      if not isinstance(x, (TensorWrapperBase, TensorBase)):
          raise TypeError("input must be a tensor")
      if not isinstance(mask, (TensorWrapperBase, TensorBase)):
          raise TypeError("mask must be a tensor")
      if mask.dtype != np.bool_:
          raise ValueError("mask must be bool")
      if x.device != mask.device:
          raise ValueError("ambiguous device: {} vs {}".format(x.device, mask.device))
      op = builtin.CondTake()
      # The operator takes the data first and the mask second.
      v, index = apply(op, x, mask)
      return v, index
  def dimshuffle(inp: Tensor, pattern: Iterable[int]) -> Tensor:
      r"""
      Swaps shapes and strides according to the given pattern.

      :param inp: Input tensor
      :param pattern: a list of integers including 0, 1, ... , ``ndim``-1, and
          any number of ``'x'`` chars in dimensions where this tensor should be
          broadcasted. For example:

          * (``'x'``) -> make a 0d (scalar) into a 1d vector
          * (0, 1) -> identity for 2d vectors
          * (1, 0) -> inverts the first and second dimensions
          * (``'x'``, 0) -> make a row out of a 1d vector (N to 1xN)
          * (0, ``'x'``) -> make a column out of a 1d vector (N to Nx1)
          * (2, 0, 1) -> AxBxC to CxAxB
          * (0, ``'x'``, 1) -> AxB to Ax1xB
          * (1, ``'x'``, 0) -> AxB to Bx1xA
          * (1,) -> removes dimension 0, which must be broadcastable (1xA to A)

      :return: The output tensor

      Examples:

      .. testcode::

          import numpy as np
          from megengine import tensor
          import megengine.functional as F

          x = tensor(np.array([[1, 1], [0, 0]], dtype=np.int32))
          out = F.dimshuffle(x, (1, 0))
          print(out.numpy())

      Outputs:

      .. testoutput::

          [[1 0]
           [1 0]]
      """
      shuffle = builtin.Dimshuffle(pattern)
      (converted,) = convert_inputs(inp)
      (shuffled,) = apply(shuffle, converted)
      return shuffled


  # ``transpose`` is an alias of ``dimshuffle``.
  transpose = dimshuffle
  def reshape(inp: Tensor, target_shape: Iterable[int]) -> Tensor:
      r"""
      Reshapes a tensor to the given target shape; the total number of logical
      elements must remain unchanged.

      :param inp: Input tensor
      :param target_shape: Target shape. It may be a sequence of ints, a tensor
          (converted through ``.numpy()``), or a single int for a 1-D result. It
          may contain one element of -1 marking the axis whose size is inferred.
      :return: The reshaped tensor
      :raises ValueError: if an element is below -1, or -1 appears more than once.

      Examples:

      .. testcode::

          import numpy as np
          from megengine import tensor
          import megengine.functional as F

          x = tensor(np.arange(12, dtype=np.int32))
          out = F.reshape(x, (3, 2, 2))
          print(out.numpy())

      Outputs:

      .. testoutput::

          [[[ 0  1]
            [ 2  3]]

           [[ 4  5]
            [ 6  7]]

           [[ 8  9]
            [10 11]]]
      """
      if isinstance(target_shape, (TensorBase, TensorWrapperBase)):
          target_shape = target_shape.numpy()
      # Accept a bare int (or 0-d array) as a 1-D shape, as ``full`` does;
      # previously ``map(int, 5)`` raised a TypeError.
      if isinstance(target_shape, (int, np.integer, np.ndarray)):
          target_shape = np.atleast_1d(target_shape)
      target_shape = tuple(map(int, target_shape))
      unspec_axis = None
      for i, s in enumerate(target_shape):
          if s < 0:
              if s != -1:
                  raise ValueError("expect shape[{}] >= -1, got {}".format(i, s))
              if unspec_axis is not None:
                  raise ValueError("multiple -1 in shape: {} & {}".format(unspec_axis, i))
              unspec_axis = i
      # TODO: device should be None (cpu)
      (target_shape,) = Const(target_shape, dtype="int32", device=inp.device)(inp)
      if unspec_axis is None:
          op = builtin.Reshape()
      else:
          op = builtin.Reshape(unspec_axis=unspec_axis)
      (x,) = apply(op, inp, target_shape)
      return x
  # Axis insertion/removal operator and its per-axis descriptor type, used by
  # ``add_axis`` and ``remove_axis``.
  AxisAddRemove, AxisDesc = builtin.AxisAddRemove, builtin.AxisAddRemove.AxisDesc
  def add_axis(inp: Tensor, axis: Union[int, Sequence[int]]) -> Tensor:
      r"""
      Adds dimension(s) of size 1 before the given axis position(s).

      :param inp: Input tensor
      :param axis: Position(s) of the new axes. Negative positions are counted
          from the end of the output shape.
      :return: The output tensor

      Examples:

      .. testcode::

          import numpy as np
          from megengine import tensor
          import megengine.functional as F

          x = tensor([1, 2])
          out = F.add_axis(x, 0)
          print(out.shape)

      Outputs:

      .. testoutput::

          (1, 2)
      """
      try:
          axes = [int(axis)]
      except (TypeError, ValueError):
          axes = None
      if axes is None:
          # ``axis`` is not a single integer: treat it as a sequence of them.
          axes = [int(a) for a in axis]
      out_ndim = inp.ndim + len(axes)
      # Negative positions refer to the output rank, which includes the new axes.
      axes = sorted(a + out_ndim if a < 0 else a for a in axes)
      descs = [AxisDesc.make_add(a) for a in axes]
      op = AxisAddRemove(param=AxisAddRemove.Param(*descs))
      (expanded,) = apply(op, inp)
      return expanded


  # ``expand_dims`` is an alias of ``add_axis``.
  expand_dims = add_axis
  def remove_axis(
      inp: Tensor, axis: Optional[Union[int, Sequence[int]]] = None
  ) -> Tensor:
      r"""
      Removes dimension(s) of size 1.

      :param inp: Input tensor
      :param axis: Position(s) of the axes to remove. If None, every axis of
          size 1 is removed. Default: None
      :return: The output tensor

      Examples:

      .. testcode::

          import numpy as np
          from megengine import tensor
          import megengine.functional as F

          x = tensor(np.array([1, 2], dtype=np.int32).reshape(1, 1, 2, 1))
          out = F.remove_axis(x, 3)
          print(out.shape)

      Outputs:

      .. testoutput::

          (1, 1, 2)
      """
      if axis is None:
          axes = [pos for pos, extent in enumerate(inp.shape) if extent == 1]
      else:
          try:
              axes = [int(axis)]
          except (TypeError, ValueError):
              axes = None
          if axes is None:
              # ``axis`` is not a single integer: treat it as a sequence of them.
              axes = [int(a) for a in axis]
      rank = inp.ndim
      axes = sorted(a + rank if a < 0 else a for a in axes)
      # Positions are adjusted as if the axes were removed one at a time in
      # ascending order: each shifts left by the count removed before it.
      axes = [a - removed for removed, a in enumerate(axes)]
      descs = [AxisDesc.make_remove(a) for a in axes]
      op = AxisAddRemove(param=AxisAddRemove.Param(*descs))
      (squeezed,) = apply(op, inp)
      return squeezed


  # ``squeeze`` is an alias of ``remove_axis``.
  squeeze = remove_axis
  def linspace(
      start: Union[int, float, Tensor],
      stop: Union[int, float, Tensor],
      num: Union[int, Tensor],
      dtype="float32",
      device: Optional[CompNode] = None,
  ) -> Tensor:
      r"""
      Returns equally spaced numbers over a specified interval.

      :param start: Starting value of the sequence, should be a scalar
      :param stop: Last value of the sequence (inclusive), should be a scalar
      :param num: Number of values to generate
      :param dtype: Result data type. Default: "float32"
      :param device: Compute node of the result. Default: None
      :return: The generated tensor

      Examples:

      .. testcode::

          import numpy as np
          import megengine.functional as F

          a = F.linspace(3,10,5)
          print(a.numpy())

      Outputs:

      .. testoutput::

          [ 3.    4.75  6.5   8.25 10.  ]
      """
      start = Tensor(start, device=device)
      stop = Tensor(stop, device=device)
      num = Tensor(num, device=device)
      device = device if device is None else device.to_c()
      op = builtin.Linspace(comp_node=device)
      (result,) = apply(op, start, stop, num)
      # Honour any requested dtype; previously only int32 was applied and every
      # other dtype was silently ignored.
      if np.dtype(result.dtype) != np.dtype(dtype):
          result = result.astype(dtype)
      return result
  def arange(
      start: Union[int, float, Tensor] = 0,
      end: Optional[Union[int, float, Tensor]] = None,
      step: Union[int, float, Tensor] = 1,
      dtype="float32",
      device: Optional[CompNode] = None,
  ) -> Tensor:
      r"""
      Returns a Tensor with values from ``start`` (inclusive) to ``end``
      (exclusive) with adjacent interval ``step``.

      If ``end`` is omitted, ``start`` is used as the end and the sequence
      starts at 0.

      :param start: starting value of the sequence, should be a scalar
      :param end: ending value of the sequence (exclusive), should be a scalar
      :param step: the gap between each pair of adjacent values. Default: 1
      :param dtype: result data type. Default: "float32"
      :param device: Compute node of the result. Default: None
      :return: The generated tensor

      Examples:

      .. testcode::

          import numpy as np
          import megengine.functional as F

          a = F.arange(5)
          print(a.numpy())

      Outputs:

      .. testoutput::

          [0. 1. 2. 3. 4.]
      """
      if end is None:
          start, end = 0, start
      if isinstance(start, Tensor):
          start = start.astype("float32")
      if isinstance(end, Tensor):
          end = end.astype("float32")
      if isinstance(step, Tensor):
          step = step.astype("float32")
      # num = ceil((end - start) / step); the last value start + step*(num-1)
      # is therefore strictly below ``end``.
      num = ceil(Tensor((end - start) / step, device=device))
      stop = start + step * (num - 1)
      result = linspace(start, stop, num, device=device)
      # Honour any requested dtype, not only int32.
      if np.dtype(result.dtype) != np.dtype(dtype):
          result = result.astype(dtype)
      return result
  def param_pack_split(inp: Tensor, offsets: List, shapes: List) -> Sequence[Tensor]:
      r"""
      Splits a tensor into a list of tensors as described by offsets and
      shapes; only used for parampack.

      :param inp: Input tensor
      :param offsets: offsets of outputs, of length 2 * n, where n is the number
          of tensors to split into, in the format [begin0, end0, begin1, end1].
      :param shapes: tensor shapes of outputs
      :return: the split tensors (one per entry of ``shapes``)

      Examples:

      .. testcode::

          import numpy as np
          import megengine.functional as F
          from megengine import tensor

          a = tensor(np.ones((10,), np.int32))
          b, c = F.param_pack_split(a, [0, 1, 1, 10], [(1,), (3, 3)])
          print(b.numpy())
          print(c.numpy())

      Outputs:

      .. testoutput::

          [1]
          [[1 1 1]
           [1 1 1]
           [1 1 1]]
      """
      op = builtin.ParamPackSplit()
      op.offsets = offsets
      op.shapes = shapes
      # All outputs of the operator are returned, not a single tensor.
      return apply(op, inp)
  def param_pack_concat(inps: List, offsets: Tensor, offsets_val: List) -> Tensor:
      r"""
      Concatenates tensors into one flat tensor; only used for parampack.

      :param inps: Input tensors
      :param offsets: offsets as a device tensor, passed to the operator as its
          last input
      :param offsets_val: the same offsets as a host list, of length 2 * n, in
          the format [begin0, end0, begin1, end1]
      :return: the concatenated tensor

      Examples:

      .. testcode::

          import numpy as np
          import megengine.functional as F
          from megengine import tensor

          a = tensor(np.ones((1,), np.int32))
          b = tensor(np.ones((3, 3), np.int32))
          offsets_val = [0, 1, 1, 10]
          offsets = tensor(offsets_val, np.int32)
          c = F.param_pack_concat([a, b], offsets, offsets_val)
          print(c.numpy())

      Outputs:

      .. testoutput::

          [1 1 1 1 1 1 1 1 1 1]
      """
      packer = builtin.ParamPackConcat()
      packer.offsets = offsets_val
      outputs = apply(packer, *inps, offsets)
      return outputs[0]

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境，不用区分 CPU 和 GPU 版。如果想要运行 GPU 程序，请确保机器本身配有 GPU 硬件设备并安装好驱动。如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉，欢迎访问 MegStudio 平台。