# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""math"""
import math
import numpy as np
from mindspore.ops import operations as P
from mindspore.ops.operations import _inner_ops as inner
from mindspore.common.tensor import Tensor
from mindspore.ops.primitive import constexpr
from mindspore.ops import functional as F
from ..cell import Cell
from ...common import dtype as mstype
from ..._checkparam import Validator as validator

__all__ = ['ReduceLogSumExp',
           'Range',
           'LGamma',
           'DiGamma',
           'IGamma',
           'LBeta',
           'MatMul',
           'Moments',
           'MatInverse',
           'MatDet',
           ]


@constexpr
def _check_input_dtype(param_name, input_dtype, allow_dtypes, cls_name):
    validator.check_type_name(param_name, input_dtype, allow_dtypes, cls_name)


class ReduceLogSumExp(Cell):
    r"""
    Reduces a dimension of a tensor by calculating the exponential of all elements in the dimension,
    then calculating the logarithm of the sum.

    The dtype of the tensor to be reduced is number.

    .. math::

        ReduceLogSumExp(x) = \log(\sum(e^x))

    Args:
        axis (Union[int, tuple(int), list(int)]): The dimensions to reduce. Default: (), reduce all dimensions.
            Only constant value is allowed.
        keep_dims (bool): If True, keep these reduced dimensions and the length is 1.
            If False, don't keep these dimensions. Default: False.

    Inputs:
        - **x** (Tensor) - The input tensor. With float16 or float32 data type.

    Outputs:
        Tensor, has the same dtype as the `x`.

        - If axis is (), and keep_dims is False,
          the output is a 0-D tensor representing the sum of all elements in the input tensor.
        - If axis is int, set as 2, and keep_dims is False,
          the shape of output is :math:`(x_1, x_3, ..., x_R)`.
        - If axis is tuple(int), set as (2, 3), and keep_dims is False,
          the shape of output is :math:`(x_1, x_4, ..., x_R)`.

    Supported Platforms:
        ``Ascend`` ``GPU``

    Examples:
        >>> input_x = Tensor(np.random.randn(3, 4, 5, 6).astype(np.float32))
        >>> op = nn.ReduceLogSumExp(1, keep_dims=True)
        >>> output = op(input_x)
        >>> print(output.shape)
        (3, 1, 5, 6)
    """

    def __init__(self, axis, keep_dims=False):
        super(ReduceLogSumExp, self).__init__()
        validator.check_value_type('axis', axis, [int, list, tuple], self.cls_name)
        validator.check_value_type('keep_dims', keep_dims, [bool], self.cls_name)
        self.axis = axis
        self.exp = P.Exp()
        self.sum = P.ReduceSum(keep_dims)
        self.log = P.Log()

    def construct(self, x):
        exp = self.exp(x)
        sumexp = self.sum(exp, self.axis)
        logsumexp = self.log(sumexp)
        return logsumexp
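

# Editorial note: a minimal NumPy sketch of the same reduction, useful for
# spot-checking small inputs offline. `_np_reduce_logsumexp` is a hypothetical
# helper added for illustration; it is not part of the MindSpore API.
def _np_reduce_logsumexp(array, axis, keep_dims=False):
    """Reference log-sum-exp via NumPy (no overflow guarding, sketch only)."""
    return np.log(np.sum(np.exp(array), axis=axis, keepdims=keep_dims))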


class Range(Cell):
    r"""
    Creates a sequence of numbers in range [start, limit) with step size delta.

    The size of output is :math:`\left \lceil \frac{limit-start}{delta} \right \rceil` and `delta` is the gap
    between two values in the tensor.

    .. math::

        out_{i+1} = out_{i} + delta

    Args:
        start (Union[int, float]): If `limit` is `None`, the value acts as limit in the range and first entry
            defaults to `0`. Otherwise, it acts as first entry in the range.
        limit (Union[int, float]): Acts as upper limit of sequence. If `None`, defaults to the value of `start`
            while the first entry of the range is set to `0`. It cannot be equal to `start`. Default: None.
        delta (Union[int, float]): Increment of the range. It cannot be equal to zero. Default: 1.

    Outputs:
        Tensor, the dtype is int if the dtype of `start`, `limit` and `delta` all are int. Otherwise, dtype is float.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> net = nn.Range(1, 8, 2)
        >>> output = net()
        >>> print(output)
        [1 3 5 7]
    """

    def __init__(self, start, limit=None, delta=1):
        super(Range, self).__init__()
        validator.check_value_type("start", start, [int, float], self.cls_name)
        validator.check_value_type("delta", delta, [int, float], self.cls_name)
        if delta == 0:
            raise ValueError("The input of `delta` can not be equal to zero.")
        if limit is not None:
            validator.check_value_type("limit", limit, [int, float], self.cls_name)
            if isinstance(start, int) and isinstance(limit, int) and isinstance(delta, int):
                self.dtype = mstype.int32
            else:
                self.dtype = mstype.float32
        else:
            if isinstance(start, int) and isinstance(delta, int):
                self.dtype = mstype.int32
            else:
                self.dtype = mstype.float32
        if isinstance(start, int):
            start = float(start)
        if isinstance(limit, int):
            limit = float(limit)
        if isinstance(delta, int):
            delta = float(delta)
        self.range_x = inner.Range(start, limit, delta)
        if limit is None:
            length_input = math.ceil(start / delta)
        else:
            length_input = math.ceil((limit - start) / delta)
        self.input_tensor = Tensor(list(range(length_input)), self.dtype)

    def construct(self):
        range_out = self.range_x(self.input_tensor)
        return range_out
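

# Editorial sketch: nn.Range mirrors NumPy's arange semantics, so small cases
# can be verified against it. `_np_range_reference` is a hypothetical helper
# for illustration only.
def _np_range_reference(start, limit=None, delta=1):
    """Reference sequence via np.arange, matching the [start, limit) contract."""
    if limit is None:
        return np.arange(0, start, delta)
    return np.arange(start, limit, delta)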


class LGamma(Cell):
    r"""
    Calculates LGamma using Lanczos' approximation, referring to "A Precision Approximation of the Gamma Function".

    The algorithm is:

    .. math::

        \begin{array}{ll} \\
            lgamma(z + 1) = \frac{\log(2) + \log(\pi)}{2} + (z + 1/2) * \log(t(z)) - t(z) + \log(A(z)) \\
            t(z) = z + kLanczosGamma + 1/2 \\
            A(z) = kBaseLanczosCoeff + \sum_{k=1}^n \frac{kLanczosCoefficients[k]}{z + k}
        \end{array}

    However, if the input is less than 0.5, use Euler's reflection formula:

    .. math::

        lgamma(x) = \log(\pi) - lgamma(1-x) - \log(abs(\sin(\pi * x)))

    And please note that

    .. math::

        lgamma(+/-inf) = +inf

    Thus, the behaviour of LGamma follows:

    - when x > 0.5, return log(Gamma(x))
    - when x < 0.5 and is not an integer, return the real part of Log(Gamma(x)) where Log is the complex logarithm
    - when x is an integer less than or equal to 0, return +inf
    - when x = +/- inf, return +inf

    Inputs:
        - **x** (Tensor) - The input tensor. Only float16, float32 are supported.

    Outputs:
        Tensor, has the same shape and dtype as the `x`.

    Raises:
        TypeError: If dtype of input x is neither float16 nor float32.

    Supported Platforms:
        ``Ascend`` ``GPU``

    Examples:
        >>> input_x = Tensor(np.array([2, 3, 4]).astype(np.float32))
        >>> op = nn.LGamma()
        >>> output = op(input_x)
        >>> print(output)
        [3.5762787e-07 6.9314754e-01 1.7917603e+00]
    """

    def __init__(self):
        super(LGamma, self).__init__()
        # const numbers
        self.k_lanczos_gamma = 7
        self.k_base_lanczos_coeff = 0.99999999999980993227684700473478
        self.k_lanczos_coefficients = [676.520368121885098567009190444019,
                                       -1259.13921672240287047156078755283,
                                       771.3234287776530788486528258894,
                                       -176.61502916214059906584551354,
                                       12.507343278686904814458936853,
                                       -0.13857109526572011689554707,
                                       9.984369578019570859563e-6,
                                       1.50563273514931155834e-7]
        self.one_half = 0.5
        self.one = 1
        self.two = 2
        self.inf = np.inf
        self.pi = np.pi
        self.log_2 = np.log(self.two)
        self.log_pi = np.log(np.pi)
        self.log_sqrt_two_pi = (self.log_2 + self.log_pi) / self.two
        self.lanczos_gamma_plus_one_half = self.k_lanczos_gamma + 0.5
        self.log_lanczos_gamma_plus_one_half = np.log(self.lanczos_gamma_plus_one_half)
        # operations
        self.log = P.Log()
        self.log1p = P.Log1p()
        self.abs = P.Abs()
        self.shape = P.Shape()
        self.dtype = P.DType()
        self.fill = P.Fill()
        self.floor = P.Floor()
        self.equal = P.Equal()
        self.greater = P.Greater()
        self.less = P.Less()
        self.lessequal = P.LessEqual()
        self.select = P.Select()
        self.sin = P.Sin()
        self.isfinite = P.IsFinite()

    def construct(self, x):
        input_dtype = self.dtype(x)
        _check_input_dtype("x", input_dtype, [mstype.float16, mstype.float32], self.cls_name)
        infinity = self.fill(input_dtype, self.shape(x), self.inf)
        need_to_reflect = self.less(x, 0.5)
        neg_input = -x
        z = self.select(need_to_reflect, neg_input, x - 1)

        @constexpr
        def _calculate_reflected_x(z, k_base_lanczos_coeff, k_lanczos_coefficients):
            reflex_x = k_base_lanczos_coeff
            for i in range(8):
                product_ = k_lanczos_coefficients[i] / (z + i + 1)
                reflex_x = product_ + reflex_x
            return reflex_x

        reflex_x = _calculate_reflected_x(z, self.k_base_lanczos_coeff, self.k_lanczos_coefficients)
        t = z + self.lanczos_gamma_plus_one_half
        log_t = self.log1p(z / self.lanczos_gamma_plus_one_half) + self.log_lanczos_gamma_plus_one_half
        log_y = self.log(reflex_x) + (z + self.one_half - t / log_t) * log_t + self.log_sqrt_two_pi
        abs_input = self.abs(x)
        abs_frac_input = abs_input - self.floor(abs_input)
        x = self.select(self.lessequal(x, 0.0), self.select(self.equal(abs_frac_input, 0.0), infinity, x), x)
        reduced_frac_input = self.select(self.greater(abs_frac_input, 0.5),
                                         1 - abs_frac_input, abs_frac_input)
        reflection_denom = self.log(self.sin(self.pi * reduced_frac_input))
        reflection = self.select(self.isfinite(reflection_denom),
                                 -reflection_denom - log_y + self.log_pi,
                                 -reflection_denom)
        result = self.select(need_to_reflect, reflection, log_y)
        return self.select(self.isfinite(x), result, infinity)
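

# Editorial cross-check (a sketch, not part of the module): the standard
# library's math.lgamma computes the same quantity on scalars, so the Lanczos
# path above can be spot-checked offline. `_lgamma_reference` is hypothetical.
def _lgamma_reference(values):
    """Reference log-Gamma for an iterable of positive scalars via math.lgamma."""
    return [math.lgamma(float(v)) for v in values]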


class DiGamma(Cell):
    r"""
    Calculates Digamma using Lanczos' approximation, referring to "A Precision Approximation of the Gamma Function".

    The algorithm is:

    .. math::

        \begin{array}{ll} \\
            digamma(z + 1) = \log(t(z)) + A'(z) / A(z) - kLanczosGamma / t(z) \\
            t(z) = z + kLanczosGamma + 1/2 \\
            A(z) = kBaseLanczosCoeff + \sum_{k=1}^n \frac{kLanczosCoefficients[k]}{z + k} \\
            A'(z) = -\sum_{k=1}^n \frac{kLanczosCoefficients[k]}{(z + k)^2}
        \end{array}

    However, if the input is less than 0.5, use Euler's reflection formula:

    .. math::

        digamma(x) = digamma(1 - x) - \pi * \cot(\pi * x)

    Inputs:
        - **x** (Tensor[Number]) - The input tensor. Only float16, float32 are supported.

    Outputs:
        Tensor, has the same shape and dtype as the `x`.

    Raises:
        TypeError: If dtype of input x is neither float16 nor float32.

    Supported Platforms:
        ``Ascend`` ``GPU``

    Examples:
        >>> input_x = Tensor(np.array([2, 3, 4]).astype(np.float32))
        >>> op = nn.DiGamma()
        >>> output = op(input_x)
        >>> print(output)
        [0.42278463 0.92278427 1.2561178]
    """

    def __init__(self):
        super(DiGamma, self).__init__()
        # const numbers
        self.k_lanczos_gamma = 7
        self.k_base_lanczos_coeff = 0.99999999999980993227684700473478
        self.k_lanczos_coefficients = [676.520368121885098567009190444019,
                                       -1259.13921672240287047156078755283,
                                       771.3234287776530788486528258894,
                                       -176.61502916214059906584551354,
                                       12.507343278686904814458936853,
                                       -0.13857109526572011689554707,
                                       9.984369578019570859563e-6,
                                       1.50563273514931155834e-7]
        self.nan = np.nan
        self.pi = np.pi
        self.lanczos_gamma_plus_one_half = self.k_lanczos_gamma + 0.5
        self.log_lanczos_gamma_plus_one_half = np.log(self.lanczos_gamma_plus_one_half)
        # operations
        self.log1p = P.Log1p()
        self.abs = P.Abs()
        self.shape = P.Shape()
        self.dtype = P.DType()
        self.fill = P.Fill()
        self.floor = P.Floor()
        self.equal = P.Equal()
        self.less = P.Less()
        self.select = P.Select()
        self.sin = P.Sin()
        self.cos = P.Cos()
        self.logicaland = P.LogicalAnd()

    def construct(self, x):
        input_dtype = self.dtype(x)
        _check_input_dtype("x", input_dtype, [mstype.float16, mstype.float32], self.cls_name)
        need_to_reflect = self.less(x, 0.5)
        neg_input = -x
        z = self.select(need_to_reflect, neg_input, x - 1)

        @constexpr
        def _calculate_num_denom(z, k_base_lanczos_coeff, k_lanczos_coefficients):
            num = 0
            denom = k_base_lanczos_coeff
            for i in range(8):
                num = num - k_lanczos_coefficients[i] / ((z + i + 1) * (z + i + 1))
                denom = denom + k_lanczos_coefficients[i] / (z + i + 1)
            return num, denom

        num, denom = _calculate_num_denom(z, self.k_base_lanczos_coeff, self.k_lanczos_coefficients)
        t = z + self.lanczos_gamma_plus_one_half
        log_t = self.log1p(z / self.lanczos_gamma_plus_one_half) + self.log_lanczos_gamma_plus_one_half
        y = log_t + num / denom - self.k_lanczos_gamma / t
        reduced_input = x + self.abs(self.floor(x + 0.5))
        reflection = y - self.pi * self.cos(self.pi * reduced_input) / self.sin(self.pi * reduced_input)
        real_result = self.select(need_to_reflect, reflection, y)
        nan = self.fill(self.dtype(x), self.shape(x), np.nan)
        return self.select(self.logicaland(self.less(x, 0), self.equal(x, self.floor(x))),
                           nan, real_result)
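

# Editorial cross-check (sketch only): digamma can be approximated by a central
# difference of math.lgamma, which is enough for loose spot checks of the
# Lanczos path above. `_digamma_reference` is a hypothetical helper.
def _digamma_reference(values, h=1e-5):
    """Approximate digamma(v) as (lgamma(v + h) - lgamma(v - h)) / (2h)."""
    return [(math.lgamma(float(v) + h) - math.lgamma(float(v) - h)) / (2 * h) for v in values]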


eps_fp32 = Tensor(np.finfo(np.float32).eps, mstype.float32)


def _while_helper_func(cond, body, vals):
    while cond(vals).any():
        vals = body(vals)
    return vals


def _IgammaSeries(ax, x, a, enabled):
    """Helper function for computing Igamma using a power series."""
    logicaland = P.LogicalAnd()
    greater = P.Greater()
    fill = P.Fill()
    shape = P.Shape()
    dtype = P.DType()
    select = P.Select()

    # If more data types are supported, this epsilon needs to be selected.
    epsilon = eps_fp32

    def cond(vals):
        enabled = vals[0]
        return enabled

    def body(vals):
        enabled = vals[0]
        r = vals[1]
        c = vals[2]
        ans = vals[3]
        x = vals[4]
        dc_da = vals[5]
        dans_da = vals[6]
        r = r + 1
        dc_da = dc_da * (x / r) + (-1 * c * x) / (r * r)
        dans_da = dans_da + dc_da
        c = c * (x / r)
        ans = ans + c
        conditional = logicaland(enabled, greater(c / ans, epsilon))
        return (conditional, select(enabled, r, vals[1]),
                select(enabled, c, vals[2]), select(enabled, ans, vals[3]),
                select(enabled, x, vals[4]), select(enabled, dc_da, vals[5]),
                select(enabled, dans_da, vals[6]))

    ones = fill(dtype(a), shape(a), 1)
    zeros = fill(dtype(a), shape(a), 0)
    vals = (enabled, a, ones, ones, x, zeros, zeros)
    vals = _while_helper_func(cond, body, vals)
    ans = vals[3]
    return (ans * ax) / a


def _IgammacContinuedFraction(ax, x, a, enabled):
    """Helper function for computing Igammac using a continued fraction."""
    abs_x = P.Abs()
    logicaland = P.LogicalAnd()
    greater = P.Greater()
    less = P.Less()
    notequal = P.NotEqual()
    fill = P.Fill()
    shape = P.Shape()
    dtype = P.DType()
    select = P.Select()

    # If more data types are supported, this epsilon needs to be selected.
    epsilon = eps_fp32

    def cond(vals):
        enabled = vals[0]
        c = vals[5]
        return logicaland(less(c, 2000), enabled)

    def body(vals):
        enabled = vals[0]
        ans = vals[1]
        t = vals[2]
        y = vals[3]
        z = vals[4]
        c = vals[5]
        pkm1 = vals[6]
        qkm1 = vals[7]
        pkm2 = vals[8]
        qkm2 = vals[9]
        dpkm2_da = vals[10]
        dqkm2_da = vals[11]
        dpkm1_da = vals[12]
        dqkm1_da = vals[13]
        dans_da = vals[14]
        c = c + 1
        y = y + 1
        z = z + 2
        yc = y * c
        pk = pkm1 * z - pkm2 * yc
        qk = qkm1 * z - qkm2 * yc
        qk_is_nonzero = notequal(qk, 0)
        r = pk / qk
        t = select(qk_is_nonzero, abs_x((ans - r) / r), fill(dtype(t), shape(t), 1))
        ans = select(qk_is_nonzero, r, ans)
        dpk_da = dpkm1_da * z - pkm1 - dpkm2_da * yc + pkm2 * c
        dqk_da = dqkm1_da * z - qkm1 - dqkm2_da * yc + qkm2 * c
        dans_da_new = select(qk_is_nonzero, (dpk_da - ans * dqk_da) / qk, dans_da)
        grad_conditional = select(qk_is_nonzero,
                                  abs_x(dans_da_new - dans_da),
                                  fill(dtype(dans_da), shape(dans_da), 1))
        pkm2 = pkm1
        pkm1 = pk
        qkm2 = qkm1
        qkm1 = qk
        dpkm2_da = dpkm1_da
        dqkm2_da = dqkm1_da
        dpkm1_da = dpk_da
        dqkm1_da = dqk_da
        rescale = greater(abs_x(pk), 1 / epsilon)
        pkm2 = select(rescale, pkm2 * epsilon, pkm2)
        pkm1 = select(rescale, pkm1 * epsilon, pkm1)
        qkm2 = select(rescale, qkm2 * epsilon, qkm2)
        qkm1 = select(rescale, qkm1 * epsilon, qkm1)
        dpkm2_da = select(rescale, dpkm2_da * epsilon, dpkm2_da)
        dqkm2_da = select(rescale, dqkm2_da * epsilon, dqkm2_da)
        dpkm1_da = select(rescale, dpkm1_da * epsilon, dpkm1_da)
        dqkm1_da = select(rescale, dqkm1_da * epsilon, dqkm1_da)
        conditional = logicaland(enabled, greater(grad_conditional, epsilon))
        return (conditional, select(enabled, ans, vals[1]), select(enabled, t, vals[2]),
                select(enabled, y, vals[3]), select(enabled, z, vals[4]),
                c, select(enabled, pkm1, vals[6]),
                select(enabled, qkm1, vals[7]), select(enabled, pkm2, vals[8]),
                select(enabled, qkm2, vals[9]), select(enabled, dpkm2_da, vals[10]),
                select(enabled, dqkm2_da, vals[11]), select(enabled, dpkm1_da, vals[12]),
                select(enabled, dqkm1_da, vals[13]), select(enabled, dans_da_new, vals[14]))

    y = 1 - a
    z = x + y + 1
    c = fill(dtype(x), shape(x), 0)
    pkm2 = fill(dtype(x), shape(x), 1)
    qkm2 = x
    pkm1 = x + 1
    qkm1 = z * x
    ans = pkm1 / qkm1
    t = fill(dtype(x), shape(x), 1)
    dpkm2_da = fill(dtype(x), shape(x), 0)
    dqkm2_da = fill(dtype(x), shape(x), 0)
    dpkm1_da = fill(dtype(x), shape(x), 0)
    dqkm1_da = -x
    dans_da = (dpkm1_da - ans * dqkm1_da) / qkm1
    vals = (enabled, ans, t, y, z, c, pkm1, qkm1, pkm2, qkm2, dpkm2_da, dqkm2_da, dpkm1_da, dqkm1_da, dans_da)
    vals = _while_helper_func(cond, body, vals)
    ans = vals[1]
    return ans * ax


class IGamma(Cell):
    r"""
    Calculates the lower regularized incomplete Gamma function.

    The lower regularized incomplete Gamma function is defined as:

    .. math::

        P(a, x) = gamma(a, x) / Gamma(a) = 1 - Q(a, x)

    where

    .. math::

        gamma(a, x) = \int_0^x t^{a-1} e^{-t} dt

    is the lower incomplete Gamma function.

    Above :math:`Q(a, x)` is the upper regularized incomplete Gamma function.

    Inputs:
        - **a** (Tensor) - The input tensor. With float32 data type. `a` should have
          the same dtype with `x`.
        - **x** (Tensor) - The input tensor. With float32 data type. `x` should have
          the same dtype with `a`.

    Outputs:
        Tensor, has the same dtype as `a` and `x`.

    Raises:
        TypeError: If dtype of input a or x is not float32,
            or if x has a different dtype from a.

    Supported Platforms:
        ``Ascend`` ``GPU``

    Examples:
        >>> input_a = Tensor(np.array([2.0, 4.0, 6.0, 8.0]).astype(np.float32))
        >>> input_x = Tensor(np.array([2.0, 3.0, 4.0, 5.0]).astype(np.float32))
        >>> igamma = nn.IGamma()
        >>> output = igamma(input_a, input_x)
        >>> print(output)
        [0.593994 0.35276785 0.21486944 0.13337152]
    """

    def __init__(self):
        super(IGamma, self).__init__()
        # const numbers
        # If more data types are supported, this float max value needs to be selected.
        self.log_maxfloat32 = Tensor(np.log(np.finfo(np.float32).max), mstype.float32)
        # operations
        self.logicaland = P.LogicalAnd()
        self.logicalor = P.LogicalOr()
        self.logicalnot = P.LogicalNot()
        self.equal = P.Equal()
        self.greater = P.Greater()
        self.less = P.Less()
        self.neg = P.Neg()
        self.log = P.Log()
        self.exp = P.Exp()
        self.select = P.Select()
        self.zeroslike = P.ZerosLike()
        self.fill = P.Fill()
        self.shape = P.Shape()
        self.dtype = P.DType()
        self.lgamma = LGamma()
        self.const = P.ScalarToArray()
        self.cast = P.Cast()

    def construct(self, a, x):
        a_dtype = self.dtype(a)
        x_dtype = self.dtype(x)
        _check_input_dtype("a", a_dtype, [mstype.float32], self.cls_name)
        _check_input_dtype("x", x_dtype, a_dtype, self.cls_name)
        domain_error = self.logicalor(self.less(x, 0), self.less(a, 0))
        use_igammac = self.logicaland(self.greater(x, 1), self.greater(x, a))
        ax = a * self.log(x) - x - self.lgamma(a)
        para_shape = self.shape(ax)
        if para_shape != ():
            broadcastto = P.BroadcastTo(para_shape)
            x = broadcastto(x)
            a = broadcastto(a)
        x_is_zero = self.equal(x, 0)
        log_maxfloat = self.log_maxfloat32
        underflow = self.less(ax, self.neg(log_maxfloat))
        ax = self.exp(ax)
        enabled = self.logicalnot(self.logicalor(self.logicalor(x_is_zero, domain_error), underflow))
        output = self.select(use_igammac,
                             1 - _IgammacContinuedFraction(ax, x, a, self.logicaland(enabled, use_igammac)),
                             _IgammaSeries(ax, x, a, self.logicaland(enabled, self.logicalnot(use_igammac))))
        output = self.select(x_is_zero, self.zeroslike(output), output)
        output = self.select(domain_error, self.fill(self.dtype(a), self.shape(a), np.nan), output)
        return output
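

# Editorial cross-check (a sketch assuming SciPy is available; SciPy is not a
# dependency of this module): scipy.special.gammainc computes the same lower
# regularized incomplete Gamma function. `_igamma_reference` is hypothetical.
def _igamma_reference(a, x):
    """Reference P(a, x) via scipy.special.gammainc, for spot checks only."""
    from scipy.special import gammainc
    return gammainc(a, x)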


class LBeta(Cell):
    r"""
    This is semantically equal to

    .. math::

        P(x, y) = lgamma(x) + lgamma(y) - lgamma(x + y).

    The method is more accurate for arguments above 8. The reason for accuracy loss in the naive computation
    is catastrophic cancellation between the lgammas. This method avoids the numeric cancellation by explicitly
    decomposing lgamma into the Stirling approximation and an explicit log_gamma_correction, and cancelling
    the large terms from the Stirling approximation analytically.

    Inputs:
        - **x** (Tensor) - The input tensor. With float16 or float32 data type. `x` should have
          the same dtype with `y`.
        - **y** (Tensor) - The input tensor. With float16 or float32 data type. `y` should have
          the same dtype with `x`.

    Outputs:
        Tensor, has the same dtype as `x` and `y`.

    Raises:
        TypeError: If dtype of input x and y is neither float16 nor float32,
            or if x has a different dtype from y.

    Supported Platforms:
        ``Ascend`` ``GPU``

    Examples:
        >>> input_x = Tensor(np.array([2.0, 4.0, 6.0, 8.0]).astype(np.float32))
        >>> input_y = Tensor(np.array([2.0, 3.0, 14.0, 15.0]).astype(np.float32))
        >>> lbeta = nn.LBeta()
        >>> output = lbeta(input_y, input_x)
        >>> print(output)
        [-1.7917596 -4.094345 -12.000229 -14.754799]
    """

    def __init__(self):
        super(LBeta, self).__init__()
        # const numbers
        self.log_2pi = np.log(2 * np.pi)
        self.minimax_coeff = [-0.165322962780713e-02,
                              0.837308034031215e-03,
                              -0.595202931351870e-03,
                              0.793650666825390e-03,
                              -0.277777777760991e-02,
                              0.833333333333333e-01]
        # operations
        self.log = P.Log()
        self.log1p = P.Log1p()
        self.less = P.Less()
        self.select = P.Select()
        self.shape = P.Shape()
        self.dtype = P.DType()
        self.lgamma = LGamma()
        self.const = P.ScalarToTensor()

    def construct(self, x, y):
        x_dtype = self.dtype(x)
        y_dtype = self.dtype(y)
        _check_input_dtype("x", x_dtype, [mstype.float16, mstype.float32], self.cls_name)
        _check_input_dtype("y", y_dtype, x_dtype, self.cls_name)
        x_plus_y = x + y
        para_shape = self.shape(x_plus_y)
        if para_shape != ():
            broadcastto = P.BroadcastTo(para_shape)
            x = broadcastto(x)
            y = broadcastto(y)
        comp_less = self.less(x, y)
        x_min = self.select(comp_less, x, y)
        y_max = self.select(comp_less, y, x)

        @constexpr
        def _log_gamma_correction(x, minimax_coeff):
            inverse_x = 1. / x
            inverse_x_squared = inverse_x * inverse_x
            accum = minimax_coeff[0]
            for i in range(1, 6):
                accum = accum * inverse_x_squared + minimax_coeff[i]
            return accum * inverse_x

        log_gamma_correction_x = _log_gamma_correction(x_min, self.minimax_coeff)
        log_gamma_correction_y = _log_gamma_correction(y_max, self.minimax_coeff)
        log_gamma_correction_x_y = _log_gamma_correction(x_plus_y, self.minimax_coeff)
        # Two large arguments case: y >= x >= 8.
        log_beta_two_large = self.const(0.5 * self.log_2pi, x_dtype) - 0.5 * self.log(y_max) \
                             + log_gamma_correction_x + log_gamma_correction_y - log_gamma_correction_x_y \
                             + (x_min - 0.5) * self.log(x_min / (x_min + y_max)) - y_max * self.log1p(x_min / y_max)
        cancelled_stirling = -1 * (x_min + y_max - 0.5) * self.log1p(x_min / y_max) - x_min * self.log(y_max) + x_min
        correction = log_gamma_correction_y - log_gamma_correction_x_y
        log_gamma_difference_big_y = correction + cancelled_stirling
        # One large argument case: x < 8, y >= 8.
        log_beta_one_large = self.lgamma(x_min) + log_gamma_difference_big_y
        # Small arguments case: x <= y < 8.
        log_beta_small = self.lgamma(x_min) + self.lgamma(y_max) - self.lgamma(x_min + y_max)
        comp_xless8 = self.less(x_min, 8)
        comp_yless8 = self.less(y_max, 8)
        temp = self.select(comp_yless8, log_beta_small, log_beta_one_large)
        return self.select(comp_xless8, temp, log_beta_two_large)
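

# Editorial cross-check (sketch only): the naive identity below is what LBeta
# refines for large arguments; it serves as a scalar reference for small inputs.
# `_lbeta_reference` is a hypothetical helper, not part of the API.
def _lbeta_reference(x, y):
    """Naive log-Beta via math.lgamma; loses accuracy for large x, y."""
    return math.lgamma(x) + math.lgamma(y) - math.lgamma(x + y)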


@constexpr
def get_broadcast_matmul_shape(x_shape, y_shape):
    """get broadcast_matmul shape"""
    if (len(x_shape) < 2) or (len(y_shape) < 2):
        raise ValueError('For matmul, rank of x1 and x2 should be equal to or greater than 2, '
                         + f'but got {x_shape} and {y_shape}.')
    x_shape_batch = x_shape[:-2]
    y_shape_batch = y_shape[:-2]
    if x_shape_batch == y_shape_batch:
        return x_shape, y_shape
    x_len = len(x_shape)
    y_len = len(y_shape)
    length = x_len if x_len < y_len else y_len
    broadcast_shape_back = []
    for i in range(-length, -2):
        if x_shape[i] == 1:
            broadcast_shape_back.append(y_shape[i])
        elif y_shape[i] == 1:
            broadcast_shape_back.append(x_shape[i])
        elif x_shape[i] == y_shape[i]:
            broadcast_shape_back.append(x_shape[i])
        else:
            raise ValueError(f"For MatMul, the x1_shape {x_shape} and x2_shape {y_shape} can not broadcast.")
    broadcast_shape_front = y_shape[0: y_len - length] if length == x_len else x_shape[0: x_len - length]
    x_broadcast_shape = broadcast_shape_front + tuple(broadcast_shape_back) + x_shape[-2:]
    y_broadcast_shape = broadcast_shape_front + tuple(broadcast_shape_back) + y_shape[-2:]
    return x_broadcast_shape, y_broadcast_shape
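

# Editorial worked example of the batch-broadcast rule above, in plain Python
# (no constexpr) so it can be run standalone. `_broadcast_batch_shapes_sketch`
# is a hypothetical illustration; it skips the mismatch validation done above.
def _broadcast_batch_shapes_sketch(x_shape, y_shape):
    """Sketch: broadcast batch dims, keep the trailing matrix dims as-is.

    E.g. ((3, 1, 4, 5), (2, 5, 6)) -> ((3, 2, 4, 5), (3, 2, 5, 6)).
    """
    xb, yb = tuple(x_shape[:-2]), tuple(y_shape[:-2])
    n = max(len(xb), len(yb))
    xb = (1,) * (n - len(xb)) + xb
    yb = (1,) * (n - len(yb)) + yb
    batch = tuple(max(a, b) for a, b in zip(xb, yb))
    return batch + tuple(x_shape[-2:]), batch + tuple(y_shape[-2:])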


@constexpr
def check_col_row_equal(x1_shape, x2_shape, transpose_x1, transpose_x2):
    """check col and row equal"""
    if len(x1_shape) == 1:
        transpose_x1 = False
        x1_shape = (1,) + x1_shape
    if len(x2_shape) == 1:
        transpose_x2 = False
        x2_shape = x2_shape + (1,)
    x1_last = x1_shape[-2:]
    x2_last = x2_shape[-2:]
    x1_col = x1_last[not transpose_x1]  # x1_col = x1_last[1] if (not transpose_x1) else x1_last[0]
    x2_row = x2_last[transpose_x2]  # x2_row = x2_last[0] if (not transpose_x2) else x2_last[1]
    if x1_col != x2_row:
        raise ValueError('The column of matrix dimensions of x1 should be equal to '
                         + f'the row of matrix dimensions of x2, but got {x1_col} and {x2_row}.')


def matmul_op_select(x1_shape, x2_shape, transpose_x1, transpose_x2):
    """select matmul op"""
    x1_dim, x2_dim = len(x1_shape), len(x2_shape)
    if x1_dim == 1 and x2_dim == 1:
        matmul_op = P.Mul()
    elif x1_dim <= 2 and x2_dim <= 2:
        transpose_x1 = False if x1_dim == 1 else transpose_x1
        transpose_x2 = False if x2_dim == 1 else transpose_x2
        matmul_op = P.MatMul(transpose_x1, transpose_x2)
    elif x1_dim == 1 and x2_dim > 2:
        matmul_op = P.BatchMatMul(False, transpose_x2)
    elif x1_dim > 2 and x2_dim == 1:
        matmul_op = P.BatchMatMul(transpose_x1, False)
    else:
        matmul_op = P.BatchMatMul(transpose_x1, transpose_x2)
    return matmul_op


class MatMul(Cell):
    r"""
    Multiplies matrix `x1` by matrix `x2`.

    - If both x1 and x2 are 1-dimensional, the dot product is returned.
    - If the dimensions of x1 and x2 are all not greater than 2, the matrix-matrix product will be returned. Note if
      one of 'x1' and 'x2' is 1-dimensional, the argument will first be expanded to 2 dimensions. After the matrix
      multiply, the expanded dimension will be removed.
    - If at least one of x1 and x2 is N-dimensional (N>2), the non-matrix dimensions (batch) of inputs will be
      broadcasted and must be broadcastable. Note if one of 'x1' and 'x2' is 1-dimensional, the argument will first be
      expanded to 2 dimensions and then the non-matrix dimensions will be broadcasted. After the matrix multiply, the
      expanded dimension will be removed. For example, if `x1` is a :math:`(j \times 1 \times n \times m)` tensor and
      `x2` is a :math:`(k \times m \times p)` tensor, the output will be a :math:`(j \times k \times n \times p)`
      tensor.

    Args:
        transpose_x1 (bool): If true, `x1` is transposed before multiplication. Default: False.
        transpose_x2 (bool): If true, `x2` is transposed before multiplication. Default: False.

    Inputs:
        - **input_x1** (Tensor) - The first tensor to be multiplied.
        - **input_x2** (Tensor) - The second tensor to be multiplied.

    Outputs:
        Tensor, the shape of the output tensor depends on the dimension of input tensors.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> net = nn.MatMul()
        >>> input_x1 = Tensor(np.ones(shape=[3, 2, 3]), mindspore.float32)
        >>> input_x2 = Tensor(np.ones(shape=[3, 4]), mindspore.float32)
        >>> output = net(input_x1, input_x2)
        >>> print(output.shape)
        (3, 2, 4)
    """

    def __init__(self, transpose_x1=False, transpose_x2=False):
        super(MatMul, self).__init__()
        validator.check_value_type('transpose_x1', transpose_x1, [bool], self.cls_name)
        validator.check_value_type('transpose_x2', transpose_x2, [bool], self.cls_name)
        self.transpose_x1 = transpose_x1
        self.transpose_x2 = transpose_x2
        self.shape_op = P.Shape()
        self.expand_op = P.ExpandDims()
        self.squeeze_left_op = P.Squeeze(-2)
        self.squeeze_right_op = P.Squeeze(-1)
        self.reduce_sum_op = P.ReduceSum(keep_dims=False)

    def construct(self, x1, x2):
        x1_shape = self.shape_op(x1)
        x2_shape = self.shape_op(x2)
        check_col_row_equal(x1_shape, x2_shape, self.transpose_x1, self.transpose_x2)
        matmul_op = matmul_op_select(x1_shape, x2_shape, self.transpose_x1, self.transpose_x2)
        x1_dim, x2_dim = len(x1_shape), len(x2_shape)
        if x1_dim == x2_dim and x2_dim == 1:
            return self.reduce_sum_op(matmul_op(x1, x2), -1)
        if x1_dim == 1:
            x1 = self.expand_op(x1, 0)
            x1_shape = self.shape_op(x1)
        if x2_dim == 1:
            x2 = self.expand_op(x2, 1)
            x2_shape = self.shape_op(x2)
        x1_broadcast_shape, x2_broadcast_shape = get_broadcast_matmul_shape(x1_shape, x2_shape)
        x1_broadcast_to = P.BroadcastTo(x1_broadcast_shape)
        x2_broadcast_to = P.BroadcastTo(x2_broadcast_shape)
        if x1_broadcast_shape != x1_shape:
            x1 = x1_broadcast_to(x1)
        if x2_broadcast_shape != x2_shape:
            x2 = x2_broadcast_to(x2)
        matmul_broadcast = matmul_op(x1, x2)
        if x1_dim == 1:
            matmul_broadcast = self.squeeze_left_op(matmul_broadcast)
        if x2_dim == 1:
            matmul_broadcast = self.squeeze_right_op(matmul_broadcast)
        return matmul_broadcast
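

# Editorial cross-check (sketch only): np.matmul follows the same 1-D expansion
# and batch-broadcast rules described in the docstring above, so it can serve
# as a NumPy reference. `_np_matmul_reference` is a hypothetical helper.
def _np_matmul_reference(x1, x2, transpose_x1=False, transpose_x2=False):
    """Reference product via np.matmul, with optional matrix-dim transposes."""
    if transpose_x1:
        x1 = np.swapaxes(x1, -1, -2)
    if transpose_x2:
        x2 = np.swapaxes(x2, -1, -2)
    return np.matmul(x1, x2)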


class Moments(Cell):
    """
    Calculates the mean and variance of `x`.

    Args:
        axis (Union[int, tuple(int)]): Calculates the mean and variance along the specified axis. Default: ().
        keep_dims (bool): If true, the dimensions of mean and variance are identical to the input's.
            If false, don't keep these dimensions. Default: False.

    Inputs:
        - **input_x** (Tensor) - The tensor to be calculated. Only float16 and float32 are supported.

    Outputs:
        - **mean** (Tensor) - The mean of input x, with the same data type as input x.
        - **variance** (Tensor) - The variance of input x, with the same data type as input x.

    Supported Platforms:
        ``Ascend``

    Examples:
        >>> net = nn.Moments(axis=3, keep_dims=True)
        >>> input_x = Tensor(np.array([[[[1, 2, 3, 4], [3, 4, 5, 6]]]]), mindspore.float32)
        >>> output = net(input_x)
        >>> print(output)
        (Tensor(shape=[1, 1, 2, 1], dtype=Float32, value=
        [[[[ 2.50000000e+00],
           [ 4.50000000e+00]]]]), Tensor(shape=[1, 1, 2, 1], dtype=Float32, value=
        [[[[ 1.25000000e+00],
           [ 1.25000000e+00]]]]))
    """

    def __init__(self, axis=None, keep_dims=None):
        super(Moments, self).__init__()
        if axis is None:
            axis = ()
        if isinstance(axis, tuple):
            for idx, item in enumerate(axis):
                validator.check_value_type("axis[%d]" % idx, item, [int], self.cls_name)
        self.axis = validator.check_value_type('axis', axis, [int, tuple], self.cls_name)
        if keep_dims is None:
            keep_dims = False
        self.keep_dims = validator.check_value_type('keep_dims', keep_dims, [bool], self.cls_name)
        self.cast = P.Cast()
        self.reduce_mean = P.ReduceMean(keep_dims=True)
        self.square_diff = P.SquaredDifference()
        self.squeeze = P.Squeeze(self.axis)

    def construct(self, x):
        tensor_dtype = F.dtype(x)
        _check_input_dtype("input x", tensor_dtype, [mstype.float16, mstype.float32], self.cls_name)
        if tensor_dtype == mstype.float16:
            x = self.cast(x, mstype.float32)
        mean = self.reduce_mean(x, self.axis)
        variance = self.reduce_mean(self.square_diff(x, F.stop_gradient(mean)), self.axis)
        if not self.keep_dims:
            mean = self.squeeze(mean)
            variance = self.squeeze(variance)
        if tensor_dtype == mstype.float16:
            mean = self.cast(mean, mstype.float16)
            variance = self.cast(variance, mstype.float16)
            return mean, variance
        return mean, variance
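

# Editorial cross-check (sketch only): NumPy's population mean/variance match
# what Moments computes. `_np_moments_reference` is a hypothetical helper;
# note that axis=() here means "reduce everything", mapped to None for NumPy.
def _np_moments_reference(array, axis=None, keep_dims=False):
    """Reference (mean, variance) via NumPy for spot checks."""
    np_axis = None if axis in (None, ()) else axis
    return (np.mean(array, axis=np_axis, keepdims=keep_dims),
            np.var(array, axis=np_axis, keepdims=keep_dims))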


class MatInverse(Cell):
    """
    Calculates the inverse of a positive-definite Hermitian matrix using Cholesky decomposition.

    Inputs:
        - **a** (Tensor[Number]) - The input tensor. It must be a positive-definite matrix.
          With float16 or float32 data type.

    Outputs:
        Tensor, has the same dtype as the `a`.

    Raises:
        TypeError: If dtype of input a is neither float16 nor float32.

    Supported Platforms:
        ``GPU``

    Examples:
        >>> input_a = Tensor(np.array([[4, 12, -16], [12, 37, -43], [-16, -43, 98]]).astype(np.float32))
        >>> op = nn.MatInverse()
        >>> output = op(input_a)
        >>> print(output)
        [[49.36112 -13.555558 2.1111116]
         [-13.555558 3.7777784 -0.5555557]
         [2.1111116 -0.5555557 0.11111111]]
    """

    def __init__(self):
        super(MatInverse, self).__init__()
        self.dtype = P.DType()
        self.choleskytrsm = P.CholeskyTrsm()
        self.matmul = MatMul(transpose_x1=True)

    def construct(self, a):
        input_dtype = self.dtype(a)
        _check_input_dtype("input_a", input_dtype, [mstype.float16, mstype.float32], self.cls_name)
        l_inverse = self.choleskytrsm(a)
        a_inverse = self.matmul(l_inverse, l_inverse)
        return a_inverse
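

# Editorial cross-check (sketch only): for a positive-definite input the result
# should agree with a dense inverse. `_np_matinverse_reference` is hypothetical.
def _np_matinverse_reference(a):
    """Reference inverse via np.linalg.inv, for spot-checking small matrices."""
    return np.linalg.inv(a)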


class MatDet(Cell):
    """
    Calculates the determinant of a positive-definite Hermitian matrix using Cholesky decomposition.

    Inputs:
        - **a** (Tensor[Number]) - The input tensor. It must be a positive-definite matrix.
          With float16 or float32 data type.

    Outputs:
        Tensor, has the same dtype as the `a`.

    Raises:
        TypeError: If dtype of input a is neither float16 nor float32.

    Supported Platforms:
        ``GPU``

    Examples:
        >>> input_a = Tensor(np.array([[4, 12, -16], [12, 37, -43], [-16, -43, 98]]).astype(np.float32))
        >>> op = nn.MatDet()
        >>> output = op(input_a)
        >>> print(output)
        35.999996
    """

    def __init__(self):
        super(MatDet, self).__init__()
        self.dtype = P.DType()
        self.cholesky = P.Cholesky()
        self.det_triangle = P.DetTriangle()
        self.square = P.Square()

    def construct(self, a):
        input_dtype = self.dtype(a)
        _check_input_dtype("input_a", input_dtype, [mstype.float16, mstype.float32], self.cls_name)
        l = self.cholesky(a)
        l_det = self.det_triangle(l)
        a_det = self.square(l_det)
        return a_det
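

# Editorial cross-check (sketch only): the Cholesky route above squares the
# determinant of the triangular factor; np.linalg.det gives the same value
# directly. `_np_matdet_reference` is a hypothetical helper.
def _np_matdet_reference(a):
    """Reference determinant via np.linalg.det, for spot checks."""
    return np.linalg.det(a)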