
math.py

# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""math"""
import numpy as np
from mindspore.ops import operations as P
from mindspore.common.tensor import Tensor
from mindspore.common._decorator import deprecated
from mindspore.ops.primitive import constexpr
from mindspore.ops import functional as F
from ..cell import Cell
from ...common import dtype as mstype
from ..._checkparam import Validator as validator

__all__ = ['ReduceLogSumExp',
           'Range',
           'LGamma',
           'DiGamma',
           'IGamma',
           'LBeta',
           'MatMul',
           'Moments',
           'MatInverse',
           'MatDet',
           ]

_BASE_LANCZOS_COEFF = 0.99999999999980993227684700473478
_LANCZOS_COEFFICIENTS = [676.520368121885098567009190444019,
                         -1259.13921672240287047156078755283,
                         771.3234287776530788486528258894,
                         -176.61502916214059906584551354,
                         12.507343278686904814458936853,
                         -0.13857109526572011689554707,
                         9.984369578019570859563e-6,
                         1.50563273514931155834e-7]
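
# Note: these are the eight Lanczos series coefficients for g = 7; the LGamma
# and DiGamma cells below consume them together with k_lanczos_gamma = 7.
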
@constexpr
def _check_input_dtype(param_name, input_dtype, allow_dtypes, cls_name):
    validator.check_type_name(param_name, input_dtype, allow_dtypes, cls_name)

class ReduceLogSumExp(Cell):
    r"""
    Reduces a dimension of a tensor by calculating exponential for all elements in the dimension,
    then calculates the logarithm of the sum.

    .. math::

        ReduceLogSumExp(x) = \log(\sum(e^x))

    Args:
        axis (Union[int, tuple(int), list(int)]): The dimensions to reduce. Default: (), reduce all dimensions.
            Only constant value is allowed.
        keep_dims (bool): If True, keep these reduced dimensions and the length is 1.
            If False, don't keep these dimensions. Default: False.

    Inputs:
        - **x** (Tensor) - The input tensor. With float16 or float32 data type.

    Outputs:
        Tensor, has the same dtype as the `x`.

        - If axis is (), and keep_dims is False,
          the output is a 0-D tensor representing the sum of all elements in the input tensor.
        - If axis is int, set as 2, and keep_dims is False,
          the shape of output is :math:`(x_1, x_3, ..., x_R)`.
        - If axis is tuple(int), set as (2, 3), and keep_dims is False,
          the shape of output is :math:`(x_1, x_4, ..., x_R)`.

    Raises:
        TypeError: If `axis` is not one of int, list, tuple.
        TypeError: If `keep_dims` is not bool.
        TypeError: If dtype of `x` is neither float16 nor float32.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> x = Tensor(np.random.randn(3, 4, 5, 6).astype(np.float32))
        >>> op = nn.ReduceLogSumExp(1, keep_dims=True)
        >>> output = op(x)
        >>> print(output.shape)
        (3, 1, 5, 6)
    """

    def __init__(self, axis, keep_dims=False):
        """Initialize ReduceLogSumExp."""
        super(ReduceLogSumExp, self).__init__()
        validator.check_value_type('axis', axis, [int, list, tuple], self.cls_name)
        validator.check_value_type('keep_dims', keep_dims, [bool], self.cls_name)
        self.axis = axis
        self.exp = P.Exp()
        self.sum = P.ReduceSum(keep_dims)
        self.log = P.Log()

    def construct(self, x):
        exp = self.exp(x)
        sumexp = self.sum(exp, self.axis)
        logsumexp = self.log(sumexp)
        return logsumexp
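
# Note: the cell above evaluates log(sum(exp(x))) directly. For inputs with
# large magnitude the exp can overflow; the usual overflow-safe variant
# computes log(sum(exp(x - m))) + m with m = max(x, axis), if that matters
# for your use case.
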
class Range(Cell):
    r"""
    Creates a sequence of numbers in range [start, limit) with step size delta.

    The size of the output is :math:`\left \lceil \frac{limit-start}{delta} \right \rceil` and `delta` is the gap
    between two values in the tensor.

    .. math::

        out_{i+1} = out_{i} + delta

    Args:
        start (Union[int, float]): If `limit` is `None`, the value acts as limit in the range and first entry
            defaults to `0`. Otherwise, it acts as first entry in the range.
        limit (Union[int, float]): Acts as upper limit of sequence. If `None`, defaults to the value of `start`
            while the first entry of the range is set to `0`. It can not be equal to `start`. Default: None.
        delta (Union[int, float]): Increment of the range. It can not be equal to zero. Default: 1.

    Outputs:
        Tensor, the dtype is int if the dtype of `start`, `limit` and `delta` all are int. Otherwise, dtype is float.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> net = nn.Range(1, 8, 2)
        >>> output = net()
        >>> print(output)
        [1 3 5 7]
    """

    def __init__(self, start, limit=None, delta=1):
        """Initialize Range."""
        super(Range, self).__init__()
        if delta == 0:
            raise ValueError(f"For '{self.cls_name}', the 'delta' can not be zero.")
        data = np.arange(start, limit, delta)
        # The deprecated np.float alias only matched float64; test for any floating dtype instead.
        if np.issubdtype(data.dtype, np.floating):
            self.ms_dtype = mstype.float32
        else:
            self.ms_dtype = mstype.int32
        self.result_tensor = Tensor(data, dtype=self.ms_dtype)

    def construct(self):
        return self.result_tensor

class LGamma(Cell):
    r"""
    Calculates LGamma using Lanczos' approximation referring to "A Precision Approximation of the Gamma Function".

    The algorithm is:

    .. math::

        \begin{array}{ll} \\
            lgamma(z + 1) = \frac{\log(2) + \log(\pi)}{2} + (z + 1/2) * \log(t(z)) - t(z) + A(z) \\
            t(z) = z + kLanczosGamma + 1/2 \\
            A(z) = kBaseLanczosCoeff + \sum_{k=1}^{n} \frac{kLanczosCoefficients[k]}{z + k}
        \end{array}

    However, if the input is less than 0.5, Euler's reflection formula is used:

    .. math::

        lgamma(x) = \log(\pi) - lgamma(1-x) - \log(\left| \sin(\pi * x) \right|)

    And please note that

    .. math::

        lgamma(\pm\infty) = +\infty

    Thus, the behaviour of LGamma follows:

    - when x > 0.5, return log(Gamma(x))
    - when x < 0.5 and is not an integer, return the real part of Log(Gamma(x)), where Log is the complex logarithm
    - when x is an integer less than or equal to 0, return +inf
    - when x = +/- inf, return +inf

    Inputs:
        - **x** (Tensor) - The input tensor. Only float16, float32 are supported.

    Outputs:
        Tensor, has the same shape and dtype as the `x`.

    Raises:
        TypeError: If dtype of `x` is neither float16 nor float32.

    Supported Platforms:
        ``Ascend`` ``GPU``

    Examples:
        >>> x = Tensor(np.array([2, 3, 4]).astype(np.float32))
        >>> op = nn.LGamma()
        >>> output = op(x)
        >>> print(output)
        [3.5762787e-07 6.9314754e-01 1.7917603e+00]
    """

    def __init__(self):
        """Initialize LGamma."""
        super(LGamma, self).__init__()
        # const numbers
        self.k_lanczos_gamma = 7
        self.k_base_lanczos_coeff = _BASE_LANCZOS_COEFF
        self.k_lanczos_coefficients = _LANCZOS_COEFFICIENTS
        self.one_half = 0.5
        self.one = 1
        self.two = 2
        self.inf = np.inf
        self.pi = np.pi
        self.log_2 = np.log(self.two)
        self.log_pi = np.log(np.pi)
        self.log_sqrt_two_pi = (self.log_2 + self.log_pi) / self.two
        self.lanczos_gamma_plus_one_half = self.k_lanczos_gamma + 0.5
        self.log_lanczos_gamma_plus_one_half = np.log(self.lanczos_gamma_plus_one_half)
        # operations
        self.log = P.Log()
        self.log1p = P.Log1p()
        self.abs = P.Abs()
        self.shape = P.Shape()
        self.dtype = P.DType()
        self.fill = P.Fill()
        self.floor = P.Floor()
        self.equal = P.Equal()
        self.greater = P.Greater()
        self.less = P.Less()
        self.lessequal = P.LessEqual()
        self.select = P.Select()
        self.sin = P.Sin()
        self.isfinite = P.IsFinite()

    def construct(self, x):
        input_dtype = self.dtype(x)
        _check_input_dtype("x", input_dtype, [mstype.float16, mstype.float32], self.cls_name)
        infinity = self.fill(input_dtype, self.shape(x), self.inf)
        need_to_reflect = self.less(x, 0.5)
        neg_input = -x
        z = self.select(need_to_reflect, neg_input, x - 1)

        @constexpr
        def _calculate_reflected_x(z, k_base_lanczos_coeff, k_lanczos_coefficients):
            reflex_x = k_base_lanczos_coeff
            for i in range(8):
                product_ = k_lanczos_coefficients[i] / (z + i + 1)
                reflex_x = product_ + reflex_x
            return reflex_x

        reflex_x = _calculate_reflected_x(z, self.k_base_lanczos_coeff, self.k_lanczos_coefficients)
        t = z + self.lanczos_gamma_plus_one_half
        log_t = self.log1p(z / self.lanczos_gamma_plus_one_half) + self.log_lanczos_gamma_plus_one_half
        log_y = self.log(reflex_x) + (z + self.one_half - t / log_t) * log_t + self.log_sqrt_two_pi
        abs_input = self.abs(x)
        abs_frac_input = abs_input - self.floor(abs_input)
        x = self.select(self.lessequal(x, 0.0), self.select(self.equal(abs_frac_input, 0.0), infinity, x), x)
        reduced_frac_input = self.select(self.greater(abs_frac_input, 0.5),
                                         1 - abs_frac_input, abs_frac_input)
        reflection_denom = self.log(self.sin(self.pi * reduced_frac_input))
        reflection = self.select(self.isfinite(reflection_denom),
                                 -reflection_denom - log_y + self.log_pi,  # pylint: disable=invalid-unary-operand-type
                                 -reflection_denom)  # pylint: disable=invalid-unary-operand-type
        result = self.select(need_to_reflect, reflection, log_y)
        return self.select(self.isfinite(x), result, infinity)

class DiGamma(Cell):
    r"""
    Calculates Digamma using Lanczos' approximation referring to "A Precision Approximation of the Gamma Function".

    The algorithm is:

    .. math::

        \begin{array}{ll} \\
            digamma(z + 1) = \log(t(z)) + A'(z) / A(z) - kLanczosGamma / t(z) \\
            t(z) = z + kLanczosGamma + 1/2 \\
            A(z) = kBaseLanczosCoeff + \sum_{k=1}^{n} \frac{kLanczosCoefficients[k]}{z + k} \\
            A'(z) = -\sum_{k=1}^{n} \frac{kLanczosCoefficients[k]}{(z + k)^{2}}
        \end{array}

    However, if the input is less than 0.5, Euler's reflection formula is used:

    .. math::

        digamma(x) = digamma(1 - x) - \pi * \cot(\pi * x)

    Inputs:
        - **x** (Tensor[Number]) - The input tensor. Only float16, float32 are supported.

    Outputs:
        Tensor, has the same shape and dtype as the `x`.

    Raises:
        TypeError: If dtype of `x` is neither float16 nor float32.

    Supported Platforms:
        ``Ascend`` ``GPU``

    Examples:
        >>> x = Tensor(np.array([2, 3, 4]).astype(np.float32))
        >>> op = nn.DiGamma()
        >>> output = op(x)
        >>> print(output)
        [0.42278463 0.92278427 1.2561178]
    """

    def __init__(self):
        """Initialize DiGamma."""
        super(DiGamma, self).__init__()
        # const numbers
        self.k_lanczos_gamma = 7
        self.k_base_lanczos_coeff = _BASE_LANCZOS_COEFF
        self.k_lanczos_coefficients = _LANCZOS_COEFFICIENTS
        self.nan = np.nan
        self.pi = np.pi
        self.lanczos_gamma_plus_one_half = self.k_lanczos_gamma + 0.5
        self.log_lanczos_gamma_plus_one_half = np.log(self.lanczos_gamma_plus_one_half)
        # operations
        self.log1p = P.Log1p()
        self.abs = P.Abs()
        self.shape = P.Shape()
        self.dtype = P.DType()
        self.fill = P.Fill()
        self.floor = P.Floor()
        self.equal = P.Equal()
        self.less = P.Less()
        self.select = P.Select()
        self.sin = P.Sin()
        self.cos = P.Cos()
        self.logicaland = P.LogicalAnd()

    def construct(self, x):
        input_dtype = self.dtype(x)
        _check_input_dtype("x", input_dtype, [mstype.float16, mstype.float32], self.cls_name)
        need_to_reflect = self.less(x, 0.5)
        neg_input = -x
        z = self.select(need_to_reflect, neg_input, x - 1)

        @constexpr
        def _calculate_num_denom(z, k_base_lanczos_coeff, k_lanczos_coefficients):
            num = 0
            denom = k_base_lanczos_coeff
            for i in range(8):
                num = num - k_lanczos_coefficients[i] / ((z + i + 1) * (z + i + 1))
                denom = denom + k_lanczos_coefficients[i] / (z + i + 1)
            return num, denom

        num, denom = _calculate_num_denom(z, self.k_base_lanczos_coeff, self.k_lanczos_coefficients)
        t = z + self.lanczos_gamma_plus_one_half
        log_t = self.log1p(z / self.lanczos_gamma_plus_one_half) + self.log_lanczos_gamma_plus_one_half
        y = log_t + num / denom - self.k_lanczos_gamma / t
        reduced_input = x + self.abs(self.floor(x + 0.5))
        reflection = y - self.pi * self.cos(self.pi * reduced_input) / self.sin(self.pi * reduced_input)
        real_result = self.select(need_to_reflect, reflection, y)
        nan = self.fill(self.dtype(x), self.shape(x), np.nan)
        return self.select(self.logicaland(self.less(x, 0), self.equal(x, self.floor(x))),
                           nan, real_result)

eps_fp32 = Tensor(np.finfo(np.float32).eps, mstype.float32)


def _while_helper_func(cond, body, vals):
    while cond(vals).any():
        vals = body(vals)
    return vals
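
# Note: _while_helper_func iterates element-wise: cond(vals) returns a boolean
# tensor of per-element "still converging" flags and the loop runs until every
# element has converged. The two helpers below use it with an `enabled` mask so
# that already-converged lanes keep their values via P.Select.
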
def _igamma_series(ax, x, a, enabled):
    """Helper function for computing Igamma using a power series."""
    logicaland = P.LogicalAnd()
    greater = P.Greater()
    fill = P.Fill()
    shape = P.Shape()
    dtype = P.DType()
    select = P.Select()

    # If more data types are supported, this epsilon needs to be selected.
    epsilon = eps_fp32

    def cond(vals):
        enabled = vals[0]
        return enabled

    def body(vals):
        enabled = vals[0]
        r = vals[1]
        c = vals[2]
        ans = vals[3]
        x = vals[4]
        dc_da = vals[5]
        dans_da = vals[6]
        r = r + 1
        dc_da = dc_da * (x / r) + (-1 * c * x) / (r * r)
        dans_da = dans_da + dc_da
        c = c * (x / r)
        ans = ans + c
        conditional = logicaland(enabled, greater(c / ans, epsilon))
        return (conditional, select(enabled, r, vals[1]),
                select(enabled, c, vals[2]), select(enabled, ans, vals[3]),
                select(enabled, x, vals[4]), select(enabled, dc_da, vals[5]),
                select(enabled, dans_da, vals[6]))

    ones = fill(dtype(a), shape(a), 1)
    zeros = fill(dtype(a), shape(a), 0)
    vals = (enabled, a, ones, ones, x, zeros, zeros)
    vals = _while_helper_func(cond, body, vals)
    ans = vals[3]
    return (ans * ax) / a
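
# Note: the caller passes ax = x^a * e^(-x) / Gamma(a). The loop accumulates
# ans = 1 + sum_{k>=1} x^k / ((a+1)...(a+k)), so (ans * ax) / a is the
# standard power-series expansion of the lower regularized incomplete gamma
# function P(a, x).
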
def _igammac_continued_fraction(ax, x, a, enabled):
    """Helper function for computing Igammac using a continued fraction."""
    abs_x = P.Abs()
    logicaland = P.LogicalAnd()
    greater = P.Greater()
    less = P.Less()
    notequal = P.NotEqual()
    fill = P.Fill()
    shape = P.Shape()
    dtype = P.DType()
    select = P.Select()

    # If more data types are supported, this epsilon needs to be selected.
    epsilon = eps_fp32

    def cond(vals):
        enabled = vals[0]
        c = vals[5]
        return logicaland(less(c, 2000), enabled)

    def body(vals):
        enabled = vals[0]
        ans = vals[1]
        t = vals[2]
        y = vals[3]
        z = vals[4]
        c = vals[5]
        pkm1 = vals[6]
        qkm1 = vals[7]
        pkm2 = vals[8]
        qkm2 = vals[9]
        dpkm2_da = vals[10]
        dqkm2_da = vals[11]
        dpkm1_da = vals[12]
        dqkm1_da = vals[13]
        dans_da = vals[14]
        c = c + 1
        y = y + 1
        z = z + 2
        yc = y * c
        pk = pkm1 * z - pkm2 * yc
        qk = qkm1 * z - qkm2 * yc
        qk_is_nonzero = notequal(qk, 0)
        r = pk / qk
        t = select(qk_is_nonzero, abs_x((ans - r) / r), fill(dtype(t), shape(t), 1))
        ans = select(qk_is_nonzero, r, ans)
        dpk_da = dpkm1_da * z - pkm1 - dpkm2_da * yc + pkm2 * c
        dqk_da = dqkm1_da * z - qkm1 - dqkm2_da * yc + qkm2 * c
        dans_da_new = select(qk_is_nonzero, (dpk_da - ans * dqk_da) / qk, dans_da)
        grad_conditional = select(qk_is_nonzero,
                                  abs_x(dans_da_new - dans_da),
                                  fill(dtype(dans_da), shape(dans_da), 1))
        pkm2 = pkm1
        pkm1 = pk
        qkm2 = qkm1
        qkm1 = qk
        dpkm2_da = dpkm1_da
        dqkm2_da = dqkm1_da
        dpkm1_da = dpk_da
        dqkm1_da = dqk_da
        # Rescale the recurrence terms when they grow large to avoid overflow.
        rescale = greater(abs_x(pk), 1 / epsilon)
        pkm2 = select(rescale, pkm2 * epsilon, pkm2)
        pkm1 = select(rescale, pkm1 * epsilon, pkm1)
        qkm2 = select(rescale, qkm2 * epsilon, qkm2)
        qkm1 = select(rescale, qkm1 * epsilon, qkm1)
        dpkm2_da = select(rescale, dpkm2_da * epsilon, dpkm2_da)
        dqkm2_da = select(rescale, dqkm2_da * epsilon, dqkm2_da)
        dpkm1_da = select(rescale, dpkm1_da * epsilon, dpkm1_da)
        dqkm1_da = select(rescale, dqkm1_da * epsilon, dqkm1_da)
        conditional = logicaland(enabled, greater(grad_conditional, epsilon))
        return (conditional, select(enabled, ans, vals[1]), select(enabled, t, vals[2]),
                select(enabled, y, vals[3]), select(enabled, z, vals[4]),
                c, select(enabled, pkm1, vals[6]),
                select(enabled, qkm1, vals[7]), select(enabled, pkm2, vals[8]),
                select(enabled, qkm2, vals[9]), select(enabled, dpkm2_da, vals[10]),
                select(enabled, dqkm2_da, vals[11]), select(enabled, dpkm1_da, vals[12]),
                select(enabled, dqkm1_da, vals[13]), select(enabled, dans_da_new, vals[14]))

    y = 1 - a
    z = x + y + 1
    c = fill(dtype(x), shape(x), 0)
    pkm2 = fill(dtype(x), shape(x), 1)
    qkm2 = x
    pkm1 = x + 1
    qkm1 = z * x
    ans = pkm1 / qkm1
    t = fill(dtype(x), shape(x), 1)
    dpkm2_da = fill(dtype(x), shape(x), 0)
    dqkm2_da = fill(dtype(x), shape(x), 0)
    dpkm1_da = fill(dtype(x), shape(x), 0)
    dqkm1_da = -x
    dans_da = (dpkm1_da - ans * dqkm1_da) / qkm1
    vals = (enabled, ans, t, y, z, c, pkm1, qkm1, pkm2, qkm2, dpkm2_da, dqkm2_da, dpkm1_da, dqkm1_da, dans_da)
    vals = _while_helper_func(cond, body, vals)
    ans = vals[1]
    return ans * ax

class IGamma(Cell):
    r"""
    Calculates lower regularized incomplete Gamma function.

    The lower regularized incomplete Gamma function is defined as:

    .. math::

        P(a, x) = gamma(a, x) / Gamma(a) = 1 - Q(a, x)

    where

    .. math::

        gamma(a, x) = \int_0^x t^{a-1} e^{-t} dt

    is the lower incomplete Gamma function.

    Above, :math:`Q(a, x)` is the upper regularized incomplete Gamma function.

    Inputs:
        - **a** (Tensor) - The input tensor. With float32 data type. `a` should have
          the same dtype as `x`.
        - **x** (Tensor) - The input tensor. With float32 data type. `x` should have
          the same dtype as `a`.

    Outputs:
        Tensor, has the same dtype as `a` and `x`.

    Raises:
        TypeError: If dtype of `a` or `x` is not float32,
            or if `x` has a different dtype than `a`.

    Supported Platforms:
        ``Ascend`` ``GPU``

    Examples:
        >>> a = Tensor(np.array([2.0, 4.0, 6.0, 8.0]).astype(np.float32))
        >>> x = Tensor(np.array([2.0, 3.0, 4.0, 5.0]).astype(np.float32))
        >>> igamma = nn.IGamma()
        >>> output = igamma(a, x)
        >>> print(output)
        [0.593994 0.35276785 0.21486944 0.13337152]
    """

    def __init__(self):
        """Initialize IGamma."""
        super(IGamma, self).__init__()
        # const numbers
        # If more data types are supported, this float max value needs to be selected.
        self.log_maxfloat32 = Tensor(np.log(np.finfo(np.float32).max), mstype.float32)
        # operations
        self.logicaland = P.LogicalAnd()
        self.logicalor = P.LogicalOr()
        self.logicalnot = P.LogicalNot()
        self.equal = P.Equal()
        self.greater = P.Greater()
        self.less = P.Less()
        self.neg = P.Neg()
        self.log = P.Log()
        self.exp = P.Exp()
        self.select = P.Select()
        self.zeroslike = P.ZerosLike()
        self.fill = P.Fill()
        self.shape = P.Shape()
        self.dtype = P.DType()
        self.lgamma = LGamma()
        self.const = P.ScalarToArray()
        self.cast = P.Cast()

    def construct(self, a, x):
        a_dtype = self.dtype(a)
        x_dtype = self.dtype(x)
        _check_input_dtype("a", a_dtype, [mstype.float32], self.cls_name)
        _check_input_dtype("x", x_dtype, a_dtype, self.cls_name)
        domain_error = self.logicalor(self.less(x, 0), self.less(a, 0))
        use_igammac = self.logicaland(self.greater(x, 1), self.greater(x, a))
        ax = a * self.log(x) - x - self.lgamma(a)
        para_shape = self.shape(ax)
        if para_shape != ():
            broadcastto = P.BroadcastTo(para_shape)
            x = broadcastto(x)
            a = broadcastto(a)
        x_is_zero = self.equal(x, 0)
        log_maxfloat = self.log_maxfloat32
        underflow = self.less(ax, self.neg(log_maxfloat))
        ax = self.exp(ax)
        enabled = self.logicalnot(self.logicalor(self.logicalor(x_is_zero, domain_error), underflow))
        output = self.select(use_igammac,
                             1 - _igammac_continued_fraction(ax, x, a, self.logicaland(enabled, use_igammac)),
                             _igamma_series(ax, x, a, self.logicaland(enabled, self.logicalnot(use_igammac))))
        output = self.select(x_is_zero, self.zeroslike(output), output)
        output = self.select(domain_error, self.fill(self.dtype(a), self.shape(a), np.nan), output)
        return output
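
# Note: in IGamma.construct, ax = exp(a*log(x) - x - lgamma(a)) is the common
# prefactor x^a * e^(-x) / Gamma(a) shared by both expansions; the
# continued-fraction branch is taken when x > 1 and x > a, the regime where
# the power series converges slowly.
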
class LBeta(Cell):
    r"""
    Calculates the log of the Beta function, which is semantically equal to

    .. math::

        P(x, y) = lgamma(x) + lgamma(y) - lgamma(x + y).

    This method avoids numeric cancellation by explicitly decomposing lgamma into the Stirling
    approximation and an explicit log_gamma_correction, and cancelling the large terms from the Stirling
    approximation analytically. The method is more accurate for arguments above 8; the reason for accuracy
    loss in the naive computation is catastrophic cancellation between the lgammas.

    Inputs:
        - **x** (Tensor) - The input tensor. With float16 or float32 data type. `x` should have
          the same dtype as `y`.
        - **y** (Tensor) - The input tensor. With float16 or float32 data type. `y` should have
          the same dtype as `x`.

    Outputs:
        Tensor, has the same dtype as `x` and `y`.

    Raises:
        TypeError: If dtype of `x` or `y` is neither float16 nor float32,
            or if `x` has a different dtype than `y`.

    Supported Platforms:
        ``Ascend`` ``GPU``

    Examples:
        >>> x = Tensor(np.array([2.0, 4.0, 6.0, 8.0]).astype(np.float32))
        >>> y = Tensor(np.array([2.0, 3.0, 14.0, 15.0]).astype(np.float32))
        >>> lbeta = nn.LBeta()
        >>> output = lbeta(y, x)
        >>> print(output)
        [-1.7917596 -4.094345 -12.000229 -14.754799]
    """

    def __init__(self):
        """Initialize LBeta."""
        super(LBeta, self).__init__()
        # const numbers
        self.log_2pi = np.log(2 * np.pi)
        self.minimax_coeff = [-0.165322962780713e-02,
                              0.837308034031215e-03,
                              -0.595202931351870e-03,
                              0.793650666825390e-03,
                              -0.277777777760991e-02,
                              0.833333333333333e-01]
        # operations
        self.log = P.Log()
        self.log1p = P.Log1p()
        self.less = P.Less()
        self.select = P.Select()
        self.shape = P.Shape()
        self.dtype = P.DType()
        self.lgamma = LGamma()
        self.const = P.ScalarToTensor()

    def construct(self, x, y):
        x_dtype = self.dtype(x)
        y_dtype = self.dtype(y)
        _check_input_dtype("x", x_dtype, [mstype.float16, mstype.float32], self.cls_name)
        _check_input_dtype("y", y_dtype, x_dtype, self.cls_name)
        x_plus_y = x + y
        para_shape = self.shape(x_plus_y)
        if para_shape != ():
            broadcastto = P.BroadcastTo(para_shape)
            x = broadcastto(x)
            y = broadcastto(y)
        comp_less = self.less(x, y)
        x_min = self.select(comp_less, x, y)
        y_max = self.select(comp_less, y, x)

        @constexpr
        def _log_gamma_correction(x, minimax_coeff):
            inverse_x = 1. / x
            inverse_x_squared = inverse_x * inverse_x
            accum = minimax_coeff[0]
            for i in range(1, 6):
                accum = accum * inverse_x_squared + minimax_coeff[i]
            return accum * inverse_x

        log_gamma_correction_x = _log_gamma_correction(x_min, self.minimax_coeff)
        log_gamma_correction_y = _log_gamma_correction(y_max, self.minimax_coeff)
        log_gamma_correction_x_y = _log_gamma_correction(x_plus_y, self.minimax_coeff)
        # Two large arguments case: y >= x >= 8.
        log_beta_two_large = self.const(0.5 * self.log_2pi, x_dtype) - 0.5 * self.log(y_max) \
                             + log_gamma_correction_x + log_gamma_correction_y - log_gamma_correction_x_y \
                             + (x_min - 0.5) * self.log(x_min / (x_min + y_max)) - y_max * self.log1p(x_min / y_max)
        cancelled_stirling = -1 * (x_min + y_max - 0.5) * self.log1p(x_min / y_max) - x_min * self.log(y_max) + x_min
        correction = log_gamma_correction_y - log_gamma_correction_x_y
        log_gamma_difference_big_y = correction + cancelled_stirling
        # One large argument case: x < 8, y >= 8.
        log_beta_one_large = self.lgamma(x_min) + log_gamma_difference_big_y
        # Small arguments case: x <= y < 8.
        log_beta_small = self.lgamma(x_min) + self.lgamma(y_max) - self.lgamma(x_min + y_max)
        comp_xless8 = self.less(x_min, 8)
        comp_yless8 = self.less(y_max, 8)
        temp = self.select(comp_yless8, log_beta_small, log_beta_one_large)
        return self.select(comp_xless8, temp, log_beta_two_large)
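
# Note: the final two selects above route each element to one of the three
# branches documented inline: both arguments >= 8, only the larger >= 8, or
# both < 8 (where the plain lgamma combination is already accurate).
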
@constexpr
def get_broadcast_matmul_shape(x_shape, y_shape, prim_name=None):
    """get broadcast_matmul shape"""
    msg_prefix = f"For '{prim_name}', the" if prim_name else "The"
    if (len(x_shape) < 2) or (len(y_shape) < 2):
        raise ValueError(f"{msg_prefix} length of 'x_shape' and 'y_shape' should be equal to or greater than 2, "
                         f"but got the length of 'x_shape': {len(x_shape)} and the length of 'y_shape': "
                         f"{len(y_shape)}.")
    x_shape_batch = x_shape[:-2]
    y_shape_batch = y_shape[:-2]
    if x_shape_batch == y_shape_batch:
        return x_shape, y_shape
    x_len = len(x_shape)
    y_len = len(y_shape)
    length = x_len if x_len < y_len else y_len
    broadcast_shape_back = []
    for i in range(-length, -2):
        if x_shape[i] == 1:
            broadcast_shape_back.append(y_shape[i])
        elif y_shape[i] == 1:
            broadcast_shape_back.append(x_shape[i])
        elif x_shape[i] == y_shape[i]:
            broadcast_shape_back.append(x_shape[i])
        else:
            raise ValueError(f"{msg_prefix} 'x_shape[{i}]' should be equal to 1, or the 'y_shape[{i}]' should be "
                             f"equal to 1, or the 'x_shape[{i}]' should be equal to 'y_shape[{i}]', but got "
                             f"'x_shape[{i}]': {x_shape[i]}, 'y_shape[{i}]': {y_shape[i]}.")
    broadcast_shape_front = y_shape[0: y_len - length] if length == x_len else x_shape[0: x_len - length]
    x_broadcast_shape = broadcast_shape_front + tuple(broadcast_shape_back) + x_shape[-2:]
    y_broadcast_shape = broadcast_shape_front + tuple(broadcast_shape_back) + y_shape[-2:]
    return x_broadcast_shape, y_broadcast_shape
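
# Example: for x_shape (2, 1, 3, 4) and y_shape (5, 4, 6), the batch dims
# broadcast to (2, 5), so get_broadcast_matmul_shape returns
# ((2, 5, 3, 4), (2, 5, 4, 6)); the trailing two (matrix) dims are untouched.
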
@constexpr
def check_col_row_equal(x1_shape, x2_shape, transpose_x1, transpose_x2, prim_name=None):
    """check col and row equal"""
    msg_prefix = f"For '{prim_name}', the" if prim_name else "The"
    if len(x1_shape) == 1:
        transpose_x1 = False
        x1_shape = (1,) + x1_shape
    if len(x2_shape) == 1:
        transpose_x2 = False
        x2_shape = x2_shape + (1,)
    x1_last = x1_shape[-2:]
    x2_last = x2_shape[-2:]
    x1_col = x1_last[not transpose_x1]  # x1_col = x1_last[1] if (not transpose_x1) else x1_last[0]
    x2_row = x2_last[transpose_x2]  # x2_row = x2_last[0] if (not transpose_x2) else x2_last[1]
    if x1_col != x2_row:
        raise ValueError(f"{msg_prefix} column of matrix dimensions of 'x1' should be equal to "
                         f"the row of matrix dimensions of 'x2', but got 'x1_col' {x1_col} and 'x2_row' {x2_row}.")

def matmul_op_select(x1_shape, x2_shape, transpose_x1, transpose_x2):
    """select matmul op"""
    x1_dim, x2_dim = len(x1_shape), len(x2_shape)
    if x1_dim == 1 and x2_dim == 1:
        matmul_op = P.Mul()
    elif x1_dim <= 2 and x2_dim <= 2:
        transpose_x1 = False if x1_dim == 1 else transpose_x1
        transpose_x2 = False if x2_dim == 1 else transpose_x2
        matmul_op = P.MatMul(transpose_x1, transpose_x2)
    elif x1_dim == 1 and x2_dim > 2:
        matmul_op = P.BatchMatMul(False, transpose_x2)
    elif x1_dim > 2 and x2_dim == 1:
        matmul_op = P.BatchMatMul(transpose_x1, False)
    else:
        matmul_op = P.BatchMatMul(transpose_x1, transpose_x2)
    return matmul_op
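
# Note: the dispatch above picks P.Mul for the 1-D x 1-D case (MatMul.construct
# finishes the dot product with a ReduceSum), P.MatMul when both inputs are at
# most 2-D, and P.BatchMatMul otherwise, dropping the transpose flag on any
# 1-D operand.
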
class MatMul(Cell):
    r"""
    Multiplies matrix `x1` by matrix `x2`.

    nn.MatMul will be deprecated in future versions. Please use ops.matmul instead.

    - If both `x1` and `x2` are 1-dimensional, the dot product is returned.
    - If the dimensions of `x1` and `x2` are all not greater than 2, the matrix-matrix product will
      be returned. Note that if one of `x1` and `x2` is 1-dimensional, the argument will first be
      expanded to 2 dimensions. After the matrix multiply, the expanded dimension will be removed.
    - If at least one of `x1` and `x2` is N-dimensional (N > 2), the non-matrix (batch) dimensions of the
      inputs will be broadcast and must be broadcastable. Note that if one of `x1` and `x2` is
      1-dimensional, the argument will first be expanded to 2 dimensions and then the non-matrix
      dimensions will be broadcast. After the matrix multiply, the expanded dimension will be removed.
      For example, if `x1` is a :math:`(j \times 1 \times n \times m)` tensor and
      `x2` is a :math:`(k \times m \times p)` tensor, the output will be a :math:`(j \times k \times n \times p)`
      tensor.

    Args:
        transpose_x1 (bool): If true, `x1` is transposed before multiplication. Default: False.
        transpose_x2 (bool): If true, `x2` is transposed before multiplication. Default: False.

    Inputs:
        - **x1** (Tensor) - The first tensor to be multiplied.
        - **x2** (Tensor) - The second tensor to be multiplied.

    Outputs:
        Tensor, the shape of the output tensor depends on the dimension of input tensors.

    Raises:
        TypeError: If `transpose_x1` or `transpose_x2` is not a bool.
        ValueError: If the column of matrix dimensions of `x1` is not equal to
            the row of matrix dimensions of `x2`.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> net = nn.MatMul()
        >>> x1 = Tensor(np.ones(shape=[3, 2, 3]), mindspore.float32)
        >>> x2 = Tensor(np.ones(shape=[3, 4]), mindspore.float32)
        >>> output = net(x1, x2)
        >>> print(output.shape)
        (3, 2, 4)
    """

    @deprecated('1.2', 'ops.matmul', False)
    def __init__(self, transpose_x1=False, transpose_x2=False):
        """Initialize MatMul."""
        super(MatMul, self).__init__()
        validator.check_value_type('transpose_x1', transpose_x1, [bool], self.cls_name)
        validator.check_value_type('transpose_x2', transpose_x2, [bool], self.cls_name)
        self.transpose_x1 = transpose_x1
        self.transpose_x2 = transpose_x2
        self.shape_op = P.Shape()
        self.expand_op = P.ExpandDims()
        self.squeeze_left_op = P.Squeeze(-2)
        self.squeeze_right_op = P.Squeeze(-1)
        self.reduce_sum_op = P.ReduceSum(keep_dims=False)

    def construct(self, x1, x2):
        x1_shape = self.shape_op(x1)
        x2_shape = self.shape_op(x2)
        check_col_row_equal(x1_shape, x2_shape, self.transpose_x1, self.transpose_x2, self.cls_name)
        matmul_op = matmul_op_select(x1_shape, x2_shape, self.transpose_x1, self.transpose_x2)
        x1_dim, x2_dim = len(x1_shape), len(x2_shape)
        if x1_dim == x2_dim and x2_dim == 1:
            return self.reduce_sum_op(matmul_op(x1, x2), -1)
        if x1_dim == 1:
            x1 = self.expand_op(x1, 0)
            x1_shape = self.shape_op(x1)
        if x2_dim == 1:
            x2 = self.expand_op(x2, 1)
            x2_shape = self.shape_op(x2)
        x1_broadcast_shape, x2_broadcast_shape = get_broadcast_matmul_shape(x1_shape, x2_shape)
        x1_broadcast_to = P.BroadcastTo(x1_broadcast_shape)
        x2_broadcast_to = P.BroadcastTo(x2_broadcast_shape)
        if x1_broadcast_shape != x1_shape:
            x1 = x1_broadcast_to(x1)
        if x2_broadcast_shape != x2_shape:
            x2 = x2_broadcast_to(x2)
        matmul_broadcast = matmul_op(x1, x2)
        if x1_dim == 1:
            matmul_broadcast = self.squeeze_left_op(matmul_broadcast)
        if x2_dim == 1:
            matmul_broadcast = self.squeeze_right_op(matmul_broadcast)
        return matmul_broadcast

class Moments(Cell):
    """
    Calculates the mean and variance of `x`.

    The mean and variance are calculated by aggregating the contents of `x` across `axis`.
    If `x` is 1-D and `axis` = [0], this is just the mean and variance of a vector.

    Args:
        axis (Union[int, tuple(int)]): Calculates the mean and variance along the specified axis. Default: None.
        keep_dims (bool): If true, the dimensions of mean and variance are identical with the input's.
            If false, don't keep these dimensions. Default: None.

    Inputs:
        - **x** (Tensor) - The tensor to be calculated. Only float16 and float32 are supported.
          :math:`(N, *)` where :math:`*` means any number of additional dimensions.

    Outputs:
        - **mean** (Tensor) - The mean of `x`, with the same data type as input `x`.
        - **variance** (Tensor) - The variance of `x`, with the same data type as input `x`.

    Raises:
        TypeError: If `axis` is not one of int, tuple, None.
        TypeError: If `keep_dims` is neither bool nor None.
        TypeError: If dtype of `x` is neither float16 nor float32.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> x = Tensor(np.array([[[[1, 2, 3, 4], [3, 4, 5, 6]]]]), mindspore.float32)
        >>> net = nn.Moments(axis=0, keep_dims=True)
        >>> output = net(x)
        >>> print(output)
        (Tensor(shape=[1, 1, 2, 4], dtype=Float32, value=
        [[[[ 1.00000000e+00, 2.00000000e+00, 3.00000000e+00, 4.00000000e+00],
        [ 3.00000000e+00, 4.00000000e+00, 5.00000000e+00, 6.00000000e+00]]]]),
        Tensor(shape=[1, 1, 2, 4], dtype=Float32, value=
        [[[[ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00],
        [ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00]]]]))
        >>> net = nn.Moments(axis=1, keep_dims=True)
        >>> output = net(x)
        >>> print(output)
        (Tensor(shape=[1, 1, 2, 4], dtype=Float32, value=
        [[[[ 1.00000000e+00, 2.00000000e+00, 3.00000000e+00, 4.00000000e+00],
        [ 3.00000000e+00, 4.00000000e+00, 5.00000000e+00, 6.00000000e+00]]]]),
        Tensor(shape=[1, 1, 2, 4], dtype=Float32, value=
        [[[[ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00],
        [ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00]]]]))
        >>> net = nn.Moments(axis=2, keep_dims=True)
        >>> output = net(x)
        >>> print(output)
        (Tensor(shape=[1, 1, 1, 4], dtype=Float32, value=
        [[[[ 2.00000000e+00, 3.00000000e+00, 4.00000000e+00, 5.00000000e+00]]]]),
        Tensor(shape=[1, 1, 1, 4], dtype=Float32, value=
        [[[[ 1.00000000e+00, 1.00000000e+00, 1.00000000e+00, 1.00000000e+00]]]]))
        >>> net = nn.Moments(axis=3, keep_dims=True)
        >>> output = net(x)
        >>> print(output)
        (Tensor(shape=[1, 1, 2, 1], dtype=Float32, value=
        [[[[ 2.50000000e+00],
        [ 4.50000000e+00]]]]), Tensor(shape=[1, 1, 2, 1], dtype=Float32, value=
        [[[[ 1.25000000e+00],
        [ 1.25000000e+00]]]]))
    """

    def __init__(self, axis=None, keep_dims=None):
        """Initialize Moments."""
        super(Moments, self).__init__()
        if axis is None:
            axis = ()
        if isinstance(axis, tuple):
            for idx, item in enumerate(axis):
                validator.check_value_type("axis[%d]" % idx, item, [int], self.cls_name)
        self.axis = validator.check_value_type('axis', axis, [int, tuple], self.cls_name)
        if keep_dims is None:
            keep_dims = False
        self.keep_dims = validator.check_value_type('keep_dims', keep_dims, [bool], self.cls_name)
        self.cast = P.Cast()
        self.reduce_mean = P.ReduceMean(keep_dims=True)
        self.square_diff = P.SquaredDifference()
        self.squeeze = P.Squeeze(self.axis)

    def construct(self, x):
        tensor_dtype = F.dtype(x)
        _check_input_dtype("input x", tensor_dtype, [mstype.float16, mstype.float32], self.cls_name)
        if tensor_dtype == mstype.float16:
            x = self.cast(x, mstype.float32)
        mean = self.reduce_mean(x, self.axis)
        variance = self.reduce_mean(self.square_diff(x, F.stop_gradient(mean)), self.axis)
        if not self.keep_dims:
            mean = self.squeeze(mean)
            variance = self.squeeze(variance)
        if tensor_dtype == mstype.float16:
            mean = self.cast(mean, mstype.float16)
            variance = self.cast(variance, mstype.float16)
            return mean, variance
        return mean, variance
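
# Note: Moments computes variance as mean((x - stop_gradient(mean))**2), so no
# gradient flows through the inner mean, and it upcasts float16 inputs to
# float32 for the reductions before casting the results back to float16.
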
class MatInverse(Cell):
    """
    Calculates the inverse of a Positive-Definite Hermitian matrix using Cholesky decomposition.

    Inputs:
        - **x** (Tensor[Number]) - The input tensor. It must be a positive-definite matrix.
          With float16 or float32 data type.

    Outputs:
        Tensor, has the same dtype as the `x`.

    Raises:
        TypeError: If dtype of `x` is neither float16 nor float32.

    Supported Platforms:
        ``GPU``

    Examples:
        >>> x = Tensor(np.array([[4, 12, -16], [12, 37, -43], [-16, -43, 98]]).astype(np.float32))
        >>> op = nn.MatInverse()
        >>> output = op(x)
        >>> print(output)
        [[49.36112 -13.555558 2.1111116]
        [-13.555558 3.7777784 -0.5555557]
        [2.1111116 -0.5555557 0.11111113]]
    """

    def __init__(self):
        """Initialize MatInverse."""
        super(MatInverse, self).__init__()
        self.dtype = P.DType()
        self.choleskytrsm = P.CholeskyTrsm()
        self.matmul = MatMul(transpose_x1=True)

    def construct(self, a):
        input_dtype = self.dtype(a)
        _check_input_dtype("input_a", input_dtype, [mstype.float16, mstype.float32], self.cls_name)
        l_inverse = self.choleskytrsm(a)
        a_inverse = self.matmul(l_inverse, l_inverse)
        return a_inverse
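
# Note: assuming P.CholeskyTrsm returns the inverse of the Cholesky factor L
# (with a = L @ L.T), MatMul(transpose_x1=True) forms
# L^-T @ L^-1 = (L @ L.T)^-1 = a^-1.
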
class MatDet(Cell):
    """
    Calculates the determinant of a Positive-Definite Hermitian matrix using Cholesky decomposition.

    Inputs:
        - **x** (Tensor[Number]) - The input tensor. It must be a positive-definite matrix.
          With float16 or float32 data type.

    Outputs:
        Tensor, has the same dtype as the `x`.

    Raises:
        TypeError: If dtype of `x` is neither float16 nor float32.

    Supported Platforms:
        ``GPU``

    Examples:
        >>> x = Tensor(np.array([[4, 12, -16], [12, 37, -43], [-16, -43, 98]]).astype(np.float32))
        >>> op = nn.MatDet()
        >>> output = op(x)
        >>> print(output)
        35.999996
    """

    def __init__(self):
        """Initialize MatDet."""
        super(MatDet, self).__init__()
        self.dtype = P.DType()
        self.cholesky = P.Cholesky()
        self.det_triangle = P.DetTriangle()
        self.square = P.Square()

    def construct(self, a):
        input_dtype = self.dtype(a)
        _check_input_dtype("input_a", input_dtype, [mstype.float16, mstype.float32], self.cls_name)
        l = self.cholesky(a)
        l_det = self.det_triangle(l)
        a_det = self.square(l_det)
        return a_det
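
# Note: MatDet relies on the identity det(a) = det(L)^2 for the Cholesky
# factor L of a = L @ L.T, which is why the triangular determinant is squared.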