You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

math.py 35 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077107810791080108110821083108410851086108710881089109010911092109310941095109610971098109911001101110211031104110511061107110811091110111111121113111411151116111711181119112011211122112311241125112611271128112911301131113211331134113511361137113811391140114111421143114411451146114711481149115011511152115311541155115611571158115911601161116211631164116511661167116811691170117111721173117411751176117711781179118011811182118311841185118611871188118911901191119211931194119511961197119811991200120112021203120412051206120712081209121012111212121312141215121612171218
  1. # -*- coding: utf-8 -*-
  2. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  3. #
  4. # Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  5. #
  6. # Unless required by applicable law or agreed to in writing,
  7. # software distributed under the License is distributed on an
  8. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  9. import collections
  10. import math
  11. from functools import lru_cache
  12. from typing import Iterable, Optional, Sequence, Tuple, Union
  13. from ..core import _config
  14. from ..core._imperative_rt.core2 import apply, dtype_promotion
  15. from ..core._imperative_rt.ops import SubgraphBuilder as _SubgraphBuilder
  16. from ..core._trace_option import use_symbolic_shape
  17. from ..core.ops import builtin
  18. from ..core.ops.builtin import BatchNorm, Elemwise, GetVarShape, Reduce, TypeCvt
  19. from ..core.ops.special import Const
  20. from ..core.tensor import amp
  21. from ..core.tensor.utils import _normalize_axis, cast_tensors, setscalar, subgraph
  22. from ..jit import exclude_from_trace
  23. from ..tensor import Tensor
  24. from .debug_param import get_execution_strategy
  25. from .elemwise import clip, minimum
  26. from .tensor import broadcast_to, concat, expand_dims, squeeze
  # Public API of this module: the names exported by ``from ... import *``.
  __all__ = [
      "argmax",
      "argmin",
      "argsort",
      "dot",
      "isinf",
      "isnan",
      "matinv",
      "matmul",
      "max",
      "mean",
      "min",
      "norm",
      "normalize",
      "prod",
      "sign",
      "sort",
      "std",
      "sum",
      "svd",
      "topk",
      "var",
  ]
  def isnan(inp: Tensor) -> Tensor:
      r"""Tests every element of ``inp`` for ``NaN``.

      Args:
          inp: input tensor.

      Returns:
          tensor that is ``True`` where the element is ``NaN`` and ``False`` elsewhere.

      Examples:

          .. testcode::

              from megengine import tensor
              import megengine.functional as F

              x = tensor([1, float("nan"), 0])
              print(F.isnan(x).numpy())

          Outputs:

          .. testoutput::

              [False  True False]
      """
      # NaN is the only value that compares unequal to itself.
      nan_mask = inp != inp
      return nan_mask
  def isinf(inp: Tensor) -> Tensor:
      r"""Tests every element of ``inp`` for positive or negative infinity.

      Args:
          inp: input tensor.

      Returns:
          tensor that is ``True`` where the element is ``Inf`` or ``-Inf`` and
          ``False`` elsewhere.

      Examples:

          .. testcode::

              from megengine import tensor
              import megengine.functional as F

              x = tensor([1, float("inf"), 0])
              print(F.isinf(x).numpy())

          Outputs:

          .. testoutput::

              [False  True False]
      """
      # Taking the absolute value first folds -inf onto +inf, so a single
      # comparison covers both signs.
      magnitude = abs(inp)
      magnitude_f32 = magnitude.astype("float32")
      return magnitude_f32 == float("inf")
  def sign(inp: Tensor):
      r"""Computes the element-wise sign of ``inp``.

      Args:
          inp: input tensor.

      Returns:
          tensor with the dtype of ``inp`` holding ``1`` for positive elements,
          ``-1`` for negative elements and ``0`` otherwise.

      Examples:

          .. testcode::

              from megengine import tensor
              import megengine.functional as F

              x = tensor([1, -1, 0])
              print(F.sign(x).numpy())

          Outputs:

          .. testoutput::

              [ 1 -1  0]
      """
      # Each mask is cast back to the input dtype so their difference is
      # exactly +1, -1 or 0 in that dtype.
      is_positive = (inp > 0).astype(inp.dtype)
      is_negative = (inp < 0).astype(inp.dtype)
      return is_positive - is_negative
  def sum(
      inp: Tensor,
      axis: Optional[Union[int, Sequence[int]]] = None,
      keepdims: bool = False,
  ) -> Tensor:
      r"""Sums the elements of ``inp`` over the given axis or axes.

      Args:
          inp: input tensor.
          axis: dimension (or list of dimensions) to reduce. If None, all
              dimensions will be reduced. Default: None
          keepdims: whether the output tensor has axis retained or not.
              Default: False

      Returns:
          output tensor.

      Examples:

          .. testcode::

              import numpy as np
              from megengine import tensor
              import megengine.functional as F

              x = tensor(np.arange(1, 7, dtype=np.int32).reshape(2, 3))
              out = F.sum(x)
              print(out.numpy())

          Outputs:

          .. testoutput::

              21
      """
      # Thin functional wrapper around the tensor's own reduction method.
      total = inp.sum(axis=axis, keepdims=keepdims)
      return total
  def prod(
      inp: Tensor, axis: Optional[Union[int, Sequence[int]]] = None, keepdims=False
  ) -> Tensor:
      r"""Multiplies the elements of ``inp`` over the given axis or axes.

      Args:
          inp: input tensor.
          axis: dimension (or list of dimensions) to reduce. If None, all
              dimensions will be reduced. Default: None
          keepdims: whether the output tensor has axis retained or not.
              Default: False

      Returns:
          output tensor.

      Examples:

          .. testcode::

              import numpy as np
              from megengine import tensor
              import megengine.functional as F

              x = tensor(np.arange(1, 7, dtype=np.int32).reshape(2, 3))
              out = F.prod(x)
              print(out.numpy())

          Outputs:

          .. testoutput::

              720
      """
      # Thin functional wrapper around the tensor's own reduction method.
      product = inp.prod(axis=axis, keepdims=keepdims)
      return product
  def mean(
      inp: Tensor,
      axis: Optional[Union[int, Sequence[int]]] = None,
      keepdims: bool = False,
  ) -> Tensor:
      r"""Averages the elements of ``inp`` over the given axis or axes.

      Args:
          inp: input tensor.
          axis: dimension (or list of dimensions) to reduce. If None, all
              dimensions will be reduced. Default: None
          keepdims: whether the output tensor has axis retained or not.
              Default: False

      Returns:
          output tensor.

      Examples:

          .. testcode::

              import numpy as np
              from megengine import tensor
              import megengine.functional as F

              x = tensor(np.arange(1, 7, dtype=np.int32).reshape(2, 3))
              out = F.mean(x)
              print(out.numpy())

          Outputs:

          .. testoutput::

              3.5
      """
      # Thin functional wrapper around the tensor's own reduction method.
      average = inp.mean(axis=axis, keepdims=keepdims)
      return average
  def var(
      inp: Tensor,
      axis: Optional[Union[int, Sequence[int]]] = None,
      keepdims: bool = False,
  ) -> Tensor:
      r"""Computes the variance of ``inp`` over the given axis or axes.

      The result is the mean of the squared deviations from the mean, i.e. it
      divides by the number of reduced elements.

      Args:
          inp: input tensor.
          axis: dimension (or list of dimensions) to reduce. If None, all
              dimensions will be reduced. Default: None
          keepdims: whether the output tensor has axis retained or not.
              Default: False

      Returns:
          output tensor.

      Examples:

          .. testcode::

              import numpy as np
              from megengine import tensor
              import megengine.functional as F

              data = tensor(np.arange(1, 7, dtype=np.float32).reshape(2, 3))
              out = F.var(data)
              print(out.numpy().round(decimals=4))

          Outputs:

          .. testoutput::

              2.9167
      """
      # When reducing along explicit axes the mean keeps those axes so that it
      # broadcasts against ``inp``; a full reduction is requested without
      # keepdims.
      centre = mean(inp, axis=axis, keepdims=axis is not None)
      squared_deviation = (inp - centre) ** 2
      return mean(squared_deviation, axis=axis, keepdims=keepdims)
  def std(
      inp: Tensor,
      axis: Optional[Union[int, Sequence[int]]] = None,
      keepdims: bool = False,
  ) -> Tensor:
      r"""Computes the standard deviation of ``inp`` over the given axis or axes.

      Args:
          inp: input tensor.
          axis: dimension (or list of dimensions) to reduce. If None, all
              dimensions will be reduced. Default: None
          keepdims: whether the output tensor has axis retained or not.
              Default: False

      Returns:
          output tensor.

      Examples:

          .. testcode::

              import numpy as np
              from megengine import tensor
              import megengine.functional as F

              data = tensor(np.arange(1, 7, dtype=np.float32).reshape(2, 3))
              out = F.std(data, axis=1)
              print(out.numpy().round(decimals=4))

          Outputs:

          .. testoutput::

              [0.8165 0.8165]
      """
      # The standard deviation is the square root of the variance.
      variance = var(inp, axis=axis, keepdims=keepdims)
      return variance ** 0.5
  def min(
      inp: Tensor,
      axis: Optional[Union[int, Sequence[int]]] = None,
      keepdims: bool = False,
  ) -> Tensor:
      r"""Finds the minimum of ``inp`` over the given axis or axes.

      Args:
          inp: input tensor.
          axis: dimension (or list of dimensions) to reduce. If None, all
              dimensions will be reduced. Default: None
          keepdims: whether the output tensor has axis retained or not.
              Default: False

      Returns:
          output tensor.

      Examples:

          .. testcode::

              import numpy as np
              from megengine import tensor
              import megengine.functional as F

              x = tensor(np.arange(1, 7, dtype=np.int32).reshape(2,3))
              out = F.min(x)
              print(out.numpy())

          Outputs:

          .. testoutput::

              1
      """
      # Thin functional wrapper around the tensor's own reduction method.
      smallest = inp.min(axis=axis, keepdims=keepdims)
      return smallest
  def max(
      inp: Tensor,
      axis: Optional[Union[int, Sequence[int]]] = None,
      keepdims: bool = False,
  ) -> Tensor:
      r"""Finds the maximum of ``inp`` over the given axis or axes.

      Args:
          inp: input tensor.
          axis: dimension (or list of dimensions) to reduce. If None, all
              dimensions will be reduced. Default: None
          keepdims: whether the output tensor has axis retained or not.
              Default: False

      Returns:
          output tensor.

      Examples:

          .. testcode::

              import numpy as np
              from megengine import tensor
              import megengine.functional as F

              x = tensor(np.arange(1, 7, dtype=np.int32).reshape(2,3))
              out = F.max(x)
              print(out.numpy())

          Outputs:

          .. testoutput::

              6
      """
      # Thin functional wrapper around the tensor's own reduction method.
      largest = inp.max(axis=axis, keepdims=keepdims)
      return largest
  def norm(
      inp: Tensor, ord: float = None, axis: int = None, keepdims=False,
  ):
      r"""Calculates the ``p``-norm of the input tensor along the given axis.

      Args:
          inp: input tensor.
          ord: power of value applied to inp. ``0`` counts the non-zero
              elements, ``math.inf`` / ``-math.inf`` take the largest / smallest
              absolute value. Default: 2
          axis: dimension to reduce. If None, input must be a vector. Default: None
          keepdims: whether the output tensor has axis retained or not. Default: False

      Returns:
          output tensor.

      Raises:
          TypeError: if ``axis`` is None and ``inp`` is not 1-dimensional.

      Examples:

          .. testcode::

              import numpy as np
              from megengine import tensor
              import megengine.functional as F

              x = tensor(np.arange(-3, 3, dtype=np.float32))
              out = F.norm(x)
              print(out.numpy().round(decimals=4))

          Outputs:

          .. testoutput::

              4.3589
      """
      if axis is None and inp.ndim != 1:
          raise TypeError("axis is required unless input is a vector")
      if ord is None:
          ord = 2
      if ord == 0:
          # The "0-norm" is the number of non-zero entries.
          return sum(inp != 0, axis=axis, keepdims=keepdims)
      # The infinity norms must honour ``axis`` and ``keepdims`` exactly like
      # every other order; previously they always reduced the whole tensor,
      # which gave wrong results (and wrong shapes) for multi-dimensional input.
      if ord == math.inf:
          return max(abs(inp), axis=axis, keepdims=keepdims)
      if ord == -math.inf:
          return min(abs(inp), axis=axis, keepdims=keepdims)
      return sum(abs(inp) ** ord, axis=axis, keepdims=keepdims) ** (1.0 / ord)
  def argmin(
      inp: Tensor,
      axis: Optional[Union[int, Sequence[int]]] = None,
      keepdims: bool = False,
  ) -> Tensor:
      r"""Returns the indices of the minimum values along the given axis.
      If axis is a list of dimensions, the reduction is applied to each of
      them in turn.

      Args:
          inp: input tensor.
          axis: dimension to reduce. If None, the input is flattened and
              reduced as a vector. Default: None
          keepdims: whether the output tensor has axis retained or not.
              Must be False when ``axis`` is None. Default: False

      Returns:
          output tensor.

      Examples:

          .. testcode::

              import numpy as np
              from megengine import tensor
              import megengine.functional as F

              x = tensor(np.arange(1, 7, dtype=np.int32).reshape(2,3))
              out = F.argmin(x)
              print(out.numpy())

          Outputs:

          .. testoutput::

              0
      """
      if axis is None:
          assert not keepdims, "can not set axis=None and keepdims=True"
          inp = inp.flatten()
          axis = 0

      reduce_axes = _normalize_axis(inp.ndim, axis, reverse=True)
      # Treat the single-axis case as a one-element sequence so both cases
      # share the same loop.
      if not isinstance(reduce_axes, collections.abc.Iterable):
          reduce_axes = (reduce_axes,)

      out = inp
      for ax in reduce_axes:
          (out,) = apply(builtin.Argmin(axis=ax), out)
          if not keepdims:
              out = squeeze(out, ax)
      return out
  def argmax(
      inp: Tensor,
      axis: Optional[Union[int, Sequence[int]]] = None,
      keepdims: bool = False,
  ) -> Tensor:
      r"""Returns the indices of the maximum values along the given axis.
      If axis is a list of dimensions, the reduction is applied to each of
      them in turn.

      Args:
          inp: input tensor.
          axis: dimension to reduce. If None, the input is flattened and
              reduced as a vector. Default: None
          keepdims: whether the output tensor has axis retained or not.
              Must be False when ``axis`` is None. Default: False

      Returns:
          output tensor.

      Examples:

          .. testcode::

              import numpy as np
              from megengine import tensor
              import megengine.functional as F

              x = tensor(np.arange(1, 7, dtype=np.int32).reshape(2,3))
              out = F.argmax(x)
              print(out.numpy())

          Outputs:

          .. testoutput::

              5
      """
      if axis is None:
          assert not keepdims, "can not set axis=None and keepdims=True"
          inp = inp.flatten()
          axis = 0

      reduce_axes = _normalize_axis(inp.ndim, axis, reverse=True)
      # Treat the single-axis case as a one-element sequence so both cases
      # share the same loop.
      if not isinstance(reduce_axes, collections.abc.Iterable):
          reduce_axes = (reduce_axes,)

      out = inp
      for ax in reduce_axes:
          (out,) = apply(builtin.Argmax(axis=ax), out)
          if not keepdims:
              out = squeeze(out, ax)
      return out
  def normalize(
      inp: Tensor, ord: float = None, axis: int = None, eps: float = 1e-12,
  ) -> Tensor:
      r"""Performs :math:`L_p` normalization of input tensor along the given axis.

      For a tensor of shape :math:`(n_0, ..., n_{dim}, ..., n_k)`, each
      :math:`n_{dim}` -element vector :math:`v` along dimension :attr:`axis` is
      transformed as:

      .. math::
          v = \frac{v}{\max(\lVert v \rVert_p, \epsilon)}.

      Args:
          inp: input tensor.
          ord: power of value applied to input tensor. Default: 2
          axis: dimension to reduce. If None, input must be a vector. Default: None
          eps: a small value to avoid division by zero. Default: 1e-12

      Returns:
          normalized output tensor.
      """
      # Along an explicit axis the norm keeps that axis so it broadcasts
      # against ``inp``; for a plain vector the norm is left fully reduced.
      keep_axis = axis is not None
      length = norm(inp, ord, axis, keepdims=keep_axis)
      # Clamp from below so an all-zero vector does not divide by zero.
      denominator = clip(length, lower=eps)
      return inp / denominator
  def argsort(inp: Tensor, descending: bool = False) -> Tensor:
      r"""Returns the indices that would sort the input tensor.

      Args:
          inp: input tensor, 1d or 2d. If it's 2d, the result is an array of
              indices showing how to sort each row of the input tensor.
          descending: sort in descending order, where the largest comes first.
              Default: False

      Returns:
          indices of int32 indicating how to sort the input.

      Examples:

          .. testcode::

              import numpy as np
              from megengine import tensor
              import megengine.functional as F

              x = tensor(np.array([1,2], dtype=np.float32))
              indices = F.argsort(x)
              print(indices.numpy())

          Outputs:

          .. testoutput::

              [0 1]
      """
      assert len(inp.shape) <= 2, "Input should be 1d or 2d"
      order = "descending" if descending else "ascending"
      op = builtin.Argsort(order=order)

      # A vector is sorted as a one-row matrix and the row is unwrapped again.
      is_vector = len(inp.shape) == 1
      if is_vector:
          inp = inp.reshape(1, -1)
      _, indices = apply(op, inp)
      return indices[0] if is_vector else indices
  def sort(inp: Tensor, descending: bool = False) -> Tuple[Tensor, Tensor]:
      r"""Returns the sorted tensor together with the indices that sort the input.

      Args:
          inp: input tensor, 1d or 2d. If it's 2d, the result is sorted by row.
          descending: sort in descending order, where the largest comes first.
              Default: False

      Returns:
          tuple of two tensors `(sorted_tensor, indices_of_int32)`.

      Examples:

          .. testcode::

              import numpy as np
              from megengine import tensor
              import megengine.functional as F

              x = tensor(np.array([1,2], dtype=np.float32))
              out, indices = F.sort(x)
              print(out.numpy())

          Outputs:

          .. testoutput::

              [1. 2.]
      """
      assert len(inp.shape) <= 2, "Input should be 1d or 2d"
      order = "descending" if descending else "ascending"
      op = builtin.Argsort(order=order)

      # A vector is sorted as a one-row matrix and the row is unwrapped again.
      is_vector = len(inp.shape) == 1
      if is_vector:
          inp = inp.reshape(1, -1)
      values, indices = apply(op, inp)
      if is_vector:
          return values[0], indices[0]
      return values, indices
  def topk(
      inp: Tensor,
      k: int,
      descending: bool = False,
      kth_only: bool = False,
      no_sort: bool = False,
  ) -> Tuple[Tensor, Tensor]:
      r"""Selects the ``Top-K`` (by default) smallest elements of 2d matrix by row.

      Args:
          inp: input tensor. If input tensor is 2d, each row will be sorted.
          k: number of elements needed.
          descending: if True, return the largest elements instead. Default: False
          kth_only: if True, only the k-th element will be returned. Default: False
          no_sort: if True, the returned elements can be unordered. Default: False

      Returns:
          tuple of two tensors ``(topk_tensor, indices_of_int32)``

      Examples:

          .. testcode::

              import numpy as np
              from megengine import tensor
              import megengine.functional as F

              x = tensor(np.array([2, 4, 6, 8, 7, 5, 3, 1], dtype=np.float32))
              top, indices = F.topk(x, 5)
              print(top.numpy(), indices.numpy())

          Outputs:

          .. testoutput::

              [1. 2. 3. 4. 5.] [7 0 6 1 5]
      """
      # The largest elements are requested from the op by negating k.
      if descending:
          k = -k

      # kth_only takes precedence over no_sort.
      if kth_only:
          mode = "kth_only"
      else:
          mode = "value_idx_nosort" if no_sort else "value_idx_sorted"
      op = builtin.TopK(mode=mode)

      if not isinstance(k, Tensor):
          (k,) = Const(k, dtype="int32", device=inp.device)()

      # A vector is processed as a one-row matrix.
      is_vector = len(inp.shape) == 1
      data = expand_dims(inp, 0) if is_vector else inp

      if not kth_only:
          values, indices = apply(op, data, k)
          if is_vector:
              values = squeeze(values, 0)
              indices = squeeze(indices, 0)
          return values, indices

      # In kth_only mode the op yields values only, so the index is recovered
      # by locating the k-th value in the input.
      # FIXME:
      # could use a dedicated kernel
      # gradient may be routed to other indices if k-th value is not unique
      (values,) = apply(op, data, k)
      if is_vector:
          indices = argmax((values == inp).astype("int8"))
          values = squeeze(values, 0)
      else:
          indices = argmax((expand_dims(values, 1) == inp).astype("int8"), 1)
      return values, indices
  def matinv(inp: Tensor) -> Tensor:
      r"""Computes the inverse of a batch of matrices; input must have shape [..., n, n].

      Args:
          inp: input tensor.

      Returns:
          output tensor.

      Examples:

          .. testcode::

              import numpy as np
              from megengine import tensor
              import megengine.functional as F

              data = tensor([[1.0, 0.0], [1.0, 1.0]])
              out = F.matinv(data)
              print(out.numpy())

          Outputs:

          .. testoutput::

              [[ 1.  0.]
               [-1.  1.]]
      """
      inverse_op = builtin.MatrixInverse()
      (inverse,) = apply(inverse_op, inp)
      return inverse
  class _Hashable:
      """Wrapper that makes an arbitrary value usable as a hash key.

      The hash is derived from ``str(value)``; equality compares the wrapped
      values and only holds between two ``_Hashable`` instances.
      """

      def __init__(self, value) -> None:
          self.value = value

      def __hash__(self) -> int:
          text = str(self.value)
          return hash(text)

      def __eq__(self, o: object) -> bool:
          # Anything that is not a _Hashable is never equal to one.
          return isinstance(o, _Hashable) and self.value == o.value
  @lru_cache(maxsize=None)
  def _get_extentedMatrixMulOp(
      device, dtype, dim1, dim2, transpose_a, transpose_b, compute_mode, format, strategy,
  ):
      # Builds the subgraph op that wraps ``builtin.MatrixMul`` for operands of
      # rank ``dim1`` and ``dim2``. The result is memoised by ``lru_cache``, so
      # every argument must be hashable.
      #
      # device, dtype: forwarded to the ``subgraph`` decorator; ``device`` is
      #     also used as ``comp_node`` of the shape concatenations.
      # dim1, dim2: number of dimensions of the first / second operand.
      # transpose_a, transpose_b, compute_mode, format: forwarded to
      #     ``builtin.MatrixMul``.
      # strategy: object whose ``.value`` is forwarded as the op strategy.
      #
      # Returns the decorated ``extentedMatrixMulOp``.
      @subgraph("extentedMatrixMulOp", dtype, device, 2, gopt_level=2)
      def extentedMatrixMulOp(inputs, f, c):
          # NOTE(review): ``f`` is used to apply an op to graph values and ``c``
          # to create a typed constant; confirm against ``subgraph``.
          assert len(inputs) == 2
          inp1, inp2 = inputs
          _dim1, _dim2 = dim1, dim2

          def build_shape_head(shape, idx=-1):
              # shape[:idx]
              return f(
                  builtin.Subtensor(items=[[0, False, True, False, False]]),
                  shape,
                  c(idx, "int32"),
              )

          def build_shape_tail(shape, idx=-1):
              # shape[idx:]
              return f(
                  builtin.Subtensor(items=[[0, True, False, False, False]]),
                  shape,
                  c(idx, "int32"),
              )

          # A 1-d operand is treated as a matrix: an axis is added before the
          # multiplication and removed from the result afterwards.
          remove_row, remove_col = False, False
          if _dim1 == 1:
              _dim1 = 2
              remove_row = True
          if _dim2 == 1:
              _dim2 = 2
              remove_col = True

          if remove_row:
              inp1 = f(builtin.AddAxis(axis=[0,]), inp1)
          if remove_col:
              inp2 = f(builtin.AddAxis(axis=[1,]), inp2)

          shape1 = f(GetVarShape(), inp1)
          shape2 = f(GetVarShape(), inp2)
          # An operand with more than two dims is flattened to 2-d: all dims but
          # the last are collapsed into their product.
          if _dim1 > 2:
              inp1 = f(
                  builtin.Reshape(),
                  inp1,
                  f(
                      builtin.Concat(axis=0, comp_node=device),
                      f(builtin.Reduce(mode="product", axis=0), build_shape_head(shape1)),
                      build_shape_tail(shape1),
                  ),
              )
          if _dim2 > 2:
              inp2 = f(
                  builtin.Reshape(),
                  inp2,
                  f(
                      builtin.Concat(axis=0, comp_node=device),
                      f(builtin.Reduce(mode="product", axis=0), build_shape_head(shape2)),
                      build_shape_tail(shape2),
                  ),
              )
          op = builtin.MatrixMul(
              transposeA=transpose_a,
              transposeB=transpose_b,
              compute_mode=compute_mode,
              format=format,
              strategy=strategy.value,
          )
          result = f(op, inp1, inp2)
          result_shape = f(GetVarShape(), result)
          # Restore the collapsed leading dims: the operand's shape[:-1]
          # followed by the last dim of the 2-d product.
          if _dim1 > 2:
              result = f(
                  builtin.Reshape(),
                  result,
                  f(
                      builtin.Concat(axis=0, comp_node=device),
                      build_shape_head(shape1),
                      build_shape_tail(result_shape),
                  ),
              )
          if _dim2 > 2:
              result = f(
                  builtin.Reshape(),
                  result,
                  f(
                      builtin.Concat(axis=0, comp_node=device),
                      build_shape_head(shape2),
                      build_shape_tail(result_shape),
                  ),
              )
          maxdim = _dim1 if _dim1 > _dim2 else _dim2
          # Drop the axes that were added for 1-d operands.
          if remove_row:
              result = f(builtin.RemoveAxis(axis=[maxdim - 2]), result)
          if remove_col:
              result = f(builtin.RemoveAxis(axis=[maxdim - 1]), result)
          # NOTE(review): the second tuple presumably flags per-output
          # properties expected by ``subgraph`` (e.g. gradient) — confirm.
          return (result,), (True,)

      return extentedMatrixMulOp
  @lru_cache(maxsize=None)
  def _get_extentedBatchedMatrixMulOp(
      device, dtype, dim1, dim2, transpose_a, transpose_b, compute_mode, format, strategy,
  ):
      # Builds the subgraph op that wraps ``builtin.BatchedMatrixMul`` for
      # operands of rank ``dim1`` and ``dim2``. The result is memoised by
      # ``lru_cache``, so every argument must be hashable.
      #
      # device, dtype: forwarded to the ``subgraph`` decorator; ``device`` is
      #     also used as ``comp_node`` of the shape concatenations.
      # dim1, dim2: number of dimensions of the first / second operand.
      # transpose_a, transpose_b, compute_mode, format: forwarded to
      #     ``builtin.BatchedMatrixMul``.
      # strategy: object whose ``.value`` is forwarded as the op strategy.
      #
      # Returns the decorated ``extentedBatchedMatrixMulOp``.
      @subgraph("extentedBatchedMatrixMulOp", dtype, device, 2, gopt_level=2)
      def extentedBatchedMatrixMulOp(inputs, f, c):
          # NOTE(review): ``f`` is used to apply an op to graph values and ``c``
          # to create a typed constant; confirm against ``subgraph``.
          assert len(inputs) == 2
          inp1, inp2 = inputs
          _dim1, _dim2 = dim1, dim2

          def build_shape_head(shape, idx=-2):
              # shape[:idx]
              return f(
                  builtin.Subtensor(items=[[0, False, True, False, False]]),
                  shape,
                  c(idx, "int32"),
              )

          def build_shape_tail(shape, idx=-2):
              # shape[idx:]
              return f(
                  builtin.Subtensor(items=[[0, True, False, False, False]]),
                  shape,
                  c(idx, "int32"),
              )

          # A 1-d operand is treated as a matrix: an axis is added before the
          # multiplication and removed from the result afterwards.
          remove_row, remove_col = False, False
          if _dim1 == 1:
              _dim1 = 2
              remove_row = True
          if _dim2 == 1:
              _dim2 = 2
              remove_col = True

          if remove_row:
              inp1 = f(builtin.AddAxis(axis=[0,]), inp1)
          if remove_col:
              inp2 = f(builtin.AddAxis(axis=[1,]), inp2)

          shape1 = f(GetVarShape(), inp1)
          shape2 = f(GetVarShape(), inp2)

          maxdim = _dim1 if _dim1 > _dim2 else _dim2
          # The lower-rank operand is broadcast so that it gains the leading
          # dims of the higher-rank one; ``batch_shape`` is everything but the
          # last two dims of the higher-rank operand.
          if _dim1 > _dim2:
              # broadcast
              shape2 = f(
                  builtin.Concat(axis=0, comp_node=device),
                  build_shape_head(shape1, idx=-_dim2),  # shape1[:-_dim2]
                  shape2,
              )
              inp2 = f(builtin.Broadcast(), inp2, shape2)
              batch_shape = build_shape_head(shape1)
          if _dim2 > _dim1:
              # broadcast
              shape1 = f(
                  builtin.Concat(axis=0, comp_node=device),
                  build_shape_head(shape2, idx=-_dim1),  # shape2[:-_dim1]
                  shape1,
              )
              inp1 = f(builtin.Broadcast(), inp1, shape1)
              batch_shape = build_shape_head(shape2)
          if _dim1 == _dim2:
              batch_shape = build_shape_head(shape1)

          # compress inputs to 3d
          if maxdim > 3:
              inp1 = f(
                  builtin.Reshape(),
                  inp1,
                  f(
                      builtin.Concat(axis=0, comp_node=device),
                      f(builtin.Reduce(mode="product", axis=0), batch_shape),
                      build_shape_tail(shape1),
                  ),
              )
              inp2 = f(
                  builtin.Reshape(),
                  inp2,
                  f(
                      builtin.Concat(axis=0, comp_node=device),
                      f(builtin.Reduce(mode="product", axis=0), batch_shape),
                      build_shape_tail(shape2),
                  ),
              )
          op = builtin.BatchedMatrixMul(
              transposeA=transpose_a,
              transposeB=transpose_b,
              compute_mode=compute_mode,
              format=format,
              strategy=strategy.value,
          )
          result = f(op, inp1, inp2)

          # Expand the single batch dim back to the original batch shape.
          if maxdim > 3:
              result = f(
                  builtin.Reshape(),
                  result,
                  f(
                      builtin.Concat(axis=0, comp_node=device),
                      batch_shape,
                      build_shape_tail(f(GetVarShape(), result)),
                  ),
              )
          # Drop the axes that were added for 1-d operands.
          if remove_row:
              result = f(builtin.RemoveAxis(axis=[maxdim - 2]), result)
          if remove_col:
              result = f(builtin.RemoveAxis(axis=[maxdim - 1]), result)
          # NOTE(review): the second tuple presumably flags per-output
          # properties expected by ``subgraph`` (e.g. gradient) — confirm.
          return (result,), (True,)

      return extentedBatchedMatrixMulOp
  def matmul(
      inp1: Tensor,
      inp2: Tensor,
      transpose_a=False,
      transpose_b=False,
      compute_mode="default",
      format="default",
  ) -> Tensor:
      r"""Performs a matrix multiplication of the matrices ``inp1`` and ``inp2``.

      With different inputs dim, this function behaves differently:

      * Both 1-D tensor, simply forward to ``dot``.
      * Both 2-D tensor, normal matrix multiplication.
      * If one input tensor is 1-D, matrix vector multiplication.
      * If at least one tensor is 3-dimensional or >3-dimensional, the other
        tensor should have dim >= 2, the batched matrix-matrix product is
        returned, and the tensor with smaller dimension will be broadcasted.
        For example:

        * inp1: `(n, k, m)`, inp2: `(n, m, p)`, return: `(n, k, p)`
        * inp1: `(n, k, m)`, inp2: `(m, p)`, return: `(n, k, p)`
        * inp1: `(n, j, k, m)`, inp2: `(n, j, m, p)`, return: `(n, j, k, p)`

      Args:
          inp1: first matrix to be multiplied.
          inp2: second matrix to be multiplied.
          transpose_a: forwarded as ``transposeA`` to the underlying op.
              Default: False
          transpose_b: forwarded as ``transposeB`` to the underlying op.
              Default: False
          compute_mode: forwarded to the underlying op; replaced by
              ``"float32"`` when amp is enabled. Default: "default"
          format: forwarded to the underlying op. Default: "default"

      Returns:
          output tensor.

      Examples:

          .. testcode::

              import numpy as np
              from megengine import tensor
              import megengine.functional as F

              data1 = tensor(np.arange(0, 6, dtype=np.float32).reshape(2, 3))
              data2 = tensor(np.arange(0, 6, dtype=np.float32).reshape(3, 2))
              out = F.matmul(data1, data2)
              print(out.numpy())

          Outputs:

          .. testoutput::

              [[10. 13.]
               [28. 40.]]
      """
      if amp._enabled:
          compute_mode = "float32"
          inp1, inp2 = cast_tensors(inp1, inp2)
      else:
          # Bring both operands to their promoted common dtype.
          common_dtype = dtype_promotion(inp1, inp2)
          if inp1.dtype != common_dtype:
              inp1 = inp1.astype(common_dtype)
          if inp2.dtype != common_dtype:
              inp2 = inp2.astype(common_dtype)

      dim1, dim2 = inp1.ndim, inp2.ndim
      assert dim1 > 0 and dim2 > 0
      # ``max`` in this module is the tensor reduction, so the larger rank is
      # picked with a plain comparison.
      maxdim = dim1 if dim1 > dim2 else dim2
      compute_mode = _config._get_actual_op_param(compute_mode, _config.__compute_mode)

      if dim1 == 1 and dim2 == 1:  # dispatch to Dot
          return dot(inp1, inp2)

      if maxdim <= 2 or dim2 <= 2:  # dispatch to MatrixMul
          op_factory = _get_extentedMatrixMulOp
      else:  # dispatch to BatchedMatrixMul
          op_factory = _get_extentedBatchedMatrixMulOp

      # Both factories take the same arguments; the execution strategy is
      # wrapped in _Hashable so it can be part of the factory's cache key.
      extended_op = op_factory(
          inp1.device,
          inp1.dtype,
          dim1,
          dim2,
          transpose_a,
          transpose_b,
          compute_mode,
          format,
          strategy=_Hashable(get_execution_strategy()),
      )
      (result,) = apply(extended_op(), inp1, inp2)
      return result
  863. def dot(inp1: Tensor, inp2: Tensor) -> Tensor:
  864. r"""Computes dot-product of two vectors ``inp1`` and ``inp2``.
  865. inputs must be 1-dimensional or scalar. A scalar input is automatically broadcasted.
  866. Refer to :func:`~.matmul` for more general usage.
  867. Args:
  868. inp1: first vector.
  869. inp2: second vector.
  870. Returns:
  871. output value.
  872. Examples:
  873. .. testcode::
  874. import numpy as np
  875. from megengine import tensor
  876. import megengine.functional as F
  877. data1 = tensor(np.arange(0, 6, dtype=np.float32))
  878. data2 = tensor(np.arange(0, 6, dtype=np.float32))
  879. out = F.dot(data1, data2)
  880. print(out.numpy())
  881. Outputs:
  882. .. testoutput::
  883. 55.
  884. """
  885. op = builtin.Dot()
  886. assert (
  887. inp1.ndim <= 1 and inp2.ndim <= 1
  888. ), "Input tensors for dot must be 1-dimensional or scalar"
  889. (result,) = apply(op, inp1, inp2)
  890. setscalar(result)
  891. return result
  892. def svd(x: Tensor, full_matrices=False, compute_uv=True) -> Tensor:
  893. r"""Returns a singular value decomposition ``A = USVh`` of a matrix (or a stack of matrices) ``x`` , where ``U`` is a matrix (or a stack of matrices) with orthonormal columns, ``S`` is a vector of non-negative numbers (or stack of vectors), and ``Vh`` is a matrix (or a stack of matrices) with orthonormal rows.
  894. Args:
  895. x (Tensor): A input real tensor having the shape ``(..., M, N)`` with ``x.ndim >= 2`` .
  896. full_matrices (bool, optional): If ``False`` , ``U`` and ``Vh`` have the shapes ``(..., M, K)`` and ``(..., K, N)`` , respectively, where ``K = min(M, N)`` . If ``True`` , the shapes are ``(..., M, M)`` and ``(..., N, N)`` , respectively. Default: ``False`` .
  897. compute_uv (bool, optional): Whether or not to compute ``U`` and ``Vh`` in addition to ``S`` . Default: ``True`` .
  898. Returns:
  899. Returns a tuple ( ``U`` , ``S`` , ``Vh`` ), which are SVD factors ``U`` , ``S``, ``Vh`` of input matrix ``x``. ( ``U`` , ``Vh`` only returned when ``compute_uv`` is True).
  900. ``U`` contains matrices orthonormal columns (i.e., the columns are left singular vectors). If ``full_matrices`` is ``True`` , the array must have shape ``(..., M, M)`` . If ``full_matrices`` is ``False`` , the array must have shape ``(..., M, K)`` , where ``K = min(M, N)`` .
  901. ``S`` contains the vector(s) of singular values of length ``K`` , where ``K = min(M, N)`` . For each vector, the singular values must be sorted in descending order by magnitude, such that ``s[..., 0]`` is the largest value, ``s[..., 1]`` is the second largest value, etc. The first ``x.ndim-2`` dimensions must have the same shape as those of the input ``x`` .
  902. ``Vh`` contains orthonormal rows (i.e., the rows are the right singular vectors and the array is the adjoint). If ``full_matrices`` is ``True`` , the array must have shape ``(..., N, N)`` . If ``full_matrices`` is ``False`` , the array must have shape ``(..., K, N)`` where ``K = min(M, N)`` . The first ``x.ndim-2`` dimensions must have the same shape as those of the input ``x`` .
  903. Each returned array must have the same floating-point data type as ``x`` .
  904. Examples:
  905. >>> import numpy as np
  906. >>> x = Tensor(np.random.randn(9, 6))
  907. >>> y = Tensor(np.random.randn(2, 7, 8, 3))
  908. Reconstruction based on full SVD, 2D case:
  909. >>> U, S, Vh = F.svd(x, full_matrices=True)
  910. >>> U.shape, S.shape, Vh.shape
  911. ((9, 9), (6,), (6, 6))
  912. Reconstruction based on reduced SVD, 2D case:
  913. >>> U, S, Vh = F.svd(x, full_matrices=False)
  914. >>> U.shape, S.shape, Vh.shape
  915. ((9, 6), (6,), (6, 6))
  916. Reconsturction based on full SVD, 4D case:
  917. >>> u, s, vh = F.svd(y, full_matrices=True)
  918. >>> u.shape, s.shape, vh.shape
  919. ((2, 7, 8, 8), (2, 7, 3), (2, 7, 3, 3))
  920. Reconsturction based on reduced SVD, 4D case:
  921. >>> u, s, vh = F.svd(y, full_matrices=False)
  922. >>> u.shape, s.shape, vh.shape
  923. ((2, 7, 8, 3), (2, 7, 3), (2, 7, 3, 3))
  924. """
  925. op = builtin.SVD(full_matrices=full_matrices, compute_uv=compute_uv)
  926. U, S, Vh = apply(op, x)
  927. return U, S, Vh
  928. def _check_non_finite(inps: Iterable[Tensor], scale=1.0) -> Tensor:
  929. r"""Check whether input contains infinite or nan value.
  930. Args:
  931. inp: a tensor to be checked.
  932. Returns:
  933. a int32 scalar tensor, 0 for False and 1 for True.
  934. """
  935. op = builtin.CheckNonFinite(scale=scale)
  936. oups = apply(op, *inps)
  937. out = oups[-1]
  938. for i in range(len(inps)):
  939. inps[i]._reset(oups[i])
  940. out._setscalar()
  941. return out