You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

graph_kernels.py 44 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """
  16. Graph kernels. They are composites of basic primitives and can be compiled into
  17. a fused kernel automatically when context.set_context(enable_graph_kernel=True).
  18. """
  19. from ...common import dtype as mstype
  20. from ...ops import operations as P
  21. from ...ops.primitive import PrimitiveWithInfer, prim_attr_register
  22. from ...ops.composite import multitype_ops as C
  23. from ...ops.operations import _grad_ops as G
  24. from ..._checkparam import ParamValidator as validator
  25. from ..cell import Cell, GraphKernel
class InplaceAssign(PrimitiveWithInfer):
    """
    Inplace assign `Parameter` with a value.

    This primitive can only be used in graph kernel.

    Inputs:
        - **variable** (Parameter) - The `Parameter`.
        - **value** (Tensor) - The value to assign.
        - **depend** (Tensor) - The dependent tensor to keep this op connected in graph.

    Outputs:
        Tensor, has the same type as original `variable`.

    Examples:
        >>> def construct(self, x):
        >>> val = x - 1.0
        >>> ret = x + 2.0
        >>> return InplaceAssign()(x, val, ret)
        >>> x = Tensor([2.0], mindspore.float32)
        >>> net = Net()
        >>> net(x)
    """
    @prim_attr_register
    def __init__(self):
        self.init_prim_io_names(inputs=['x', 'y', 'z'], outputs=['output'])

    def infer_shape(self, x, y, z):
        # Output mirrors the `depend` input (z), so this op is shape-transparent
        # in the graph and only exists to carry the assignment side effect.
        return z

    def infer_dtype(self, x, y, z):
        return z

    def get_bprop(self):
        def bprop(x, y, z, out, dout):
            # Gradient passes straight through the depend input; the assigned
            # `value` gets a zero gradient.
            return (x, C.zeros_like(y), dout)
        return bprop
class MaximumGrad(GraphKernel):
    """
    Backprop function for Maximum operator.

    Inputs:
        - **x** (Tensor) - The first input tensor of maximum.
        - **y** (Tensor) - The second input tensor of maximum.
        - **dout** (Tensor) - has the same shape as x and y, next operator's backprop output.

    Outputs:
        dx (Tensor): has the same shape as x and y, returns dout element if
            `x >= y` returns true at the same position, or returns zero at that
            position
        dy (Tensor): has the same shape as x and y, dy = dout - dx

    Examples:
        >>> layer = MaximumGrad()
        >>> output = layer(Tensor([1, 2, 3]), Tensor([3, 2, 1]), Tensor([4, 5, 6]))
    """
    def __init__(self, grad_x=True, grad_y=True):
        super(MaximumGrad, self).__init__()
        # NOTE(review): grad_x/grad_y and self.sub are stored but never used in
        # construct; presumably kept for interface parity with the built-in
        # MaximumGrad primitive — confirm before removing.
        self.grad_x = grad_x
        self.grad_y = grad_y
        self.select = P.Select()
        self.greater_equal = P.GreaterEqual()
        self.zeros_like = P.ZerosLike()
        self.sub = P.Sub()

    def construct(self, x, y, dout):
        # Route dout to dx wherever x >= y (ties go to x), zeros elsewhere.
        cmp_result = self.greater_equal(x, y)
        dx = self.select(cmp_result, dout, self.zeros_like(dout))
        # dy is the complement: dout - dx.
        dy = dout - dx
        return dx, dy
class MinimumGrad(GraphKernel):
    """
    Backprop function for Minimum operator.

    Compares x and y elementwise, dout should have the same shape as x and y.

    Inputs:
        - **x** (Tensor) - The first input
        - **y** (Tensor) - x and y should have same shape
        - **dout** (Tensor) - Has the same shape as x and y, next operator's backprop output

    Outputs:
        - dx (Tensor) - Has the same shape as x and y, returns dout element if
            `x <= y` returns true at the same position, or returns zero at that
            position
        - dy (Tensor) - Has the same shape as x and y, dy = dout - dx

    Examples:
        >>> layer = MinimumGrad()
        >>> output = layer(Tensor([1, 2, 3]), Tensor([3, 2, 1]), Tensor([4, 5, 6]))
    """
    def __init__(self, grad_x=True, grad_y=True):
        super(MinimumGrad, self).__init__()
        # NOTE(review): grad_x/grad_y and self.sub are unused in construct;
        # presumably kept for interface parity with the MinimumGrad primitive.
        self.grad_x = grad_x
        self.grad_y = grad_y
        self.select = P.Select()
        self.less_equal = P.LessEqual()
        self.zeros_like = P.ZerosLike()
        self.sub = P.Sub()

    def construct(self, x, y, dout):
        # Route dout to dx wherever x <= y (ties go to x), zeros elsewhere.
        cmp_result = self.less_equal(x, y)
        dx = self.select(cmp_result, dout, self.zeros_like(dout))
        # dy = select(cmp_result, 0, dout), computed as the complement dout - dx.
        dy = dout - dx
        return dx, dy
class AbsGrad(GraphKernel):
    """
    Abs's backprop function.

    Inputs:
        **input_x** (Tensor) - input data of this operator.
        **dout** (Tensor) - output of the next operator's backprop function.

    Outputs:
        Tensor, has the same shape as input_x.

    Examples:
        >>> back = AbsGrad()
        >>> output = back(Tensor([1, 2, 3]), Tensor([4, 5, 6]))
    """
    def __init__(self):
        super(AbsGrad, self).__init__()
        self.mul = P.Mul()
        self.abs = P.Abs()
        self.add = P.TensorAdd()
        self.div = P.RealDiv()
        self.round = P.Round()

    def construct(self, input_x, dout):
        # Compute sign(input_x) as round(s / |s|) with s = input_x * 32768, then
        # multiply by dout: d|x|/dx = sign(x).
        # NOTE(review): the 32768 pre-scale presumably guards tiny fp16 values
        # from underflowing in the division — confirm against target hardware.
        NUM_MAX = 32768
        mul_max = self.mul(input_x, P.Fill()(P.DType()(input_x), (1,), NUM_MAX))
        res_abs = self.abs(mul_max)
        res_div = self.div(mul_max, res_abs)
        res_round = self.round(res_div)
        res = self.mul(res_round, dout)
        return res
class ApplyMomentum(GraphKernel):
    """
    Update parameter according to the ApplyMomentum algorithm.

    accum = accum * momentum + gradient * gradient_scale
    var = var - accum * learning_rate

    Inputs:
        variable (Tensor): mutable tensor var
        accumulation (Tensor): mutable tensor accum
        learning_rate (float32): learning rate
        gradient (float32): The gradient
        momentum (float32): Momentum

    Outputs: updated accumulation and variable
    """
    def __init__(self,
                 use_nesterov=False,
                 use_locking=False,
                 gradient_scale=1.0):
        super(ApplyMomentum, self).__init__()
        # NOTE(review): use_nesterov/use_locking are accepted but not used here;
        # presumably kept to match P.ApplyMomentum's signature — confirm.
        self.gradient_scale = validator.check_type('gradient_scale', gradient_scale, [float])
        # "fake_output" marks the InplaceAssign results as kernel outputs so the
        # in-place parameter updates survive graph-kernel fusion.
        self.fake_output_assign_1 = InplaceAssign()
        self.fake_output_assign_1.add_prim_attr("fake_output", True)
        self.fake_output_assign_2 = InplaceAssign()
        self.fake_output_assign_2.add_prim_attr("fake_output", True)

    def construct(self, variable, accumulation, learning_rate, gradient, momentum):
        gradient = gradient * self.gradient_scale
        momt_accumulation = accumulation * momentum
        accumulation_inplace = momt_accumulation + gradient
        sum_gradient = accumulation_inplace * learning_rate
        variable_inplace = variable - sum_gradient
        # Write the new values back into the parameters in place.
        accumulation_inplace = self.fake_output_assign_1(accumulation, accumulation_inplace, accumulation_inplace)
        variable_inplace = self.fake_output_assign_2(variable, variable_inplace, variable_inplace)
        return accumulation_inplace, variable_inplace
class BiasAdd(GraphKernel):
    """
    Return the sum of x and bias.

    Inputs:
        x (Tensor): Tensor of input data.
        bias (Tensor): The bias tensor.

    Output:
        Tensor, the sum of x and bias.

    Example:
        >>> layer = BiasAdd()
        >>> output = layer(Tensor([[1, 2, 3]]), Tensor([1,]))
    """
    def __init__(self):
        super(BiasAdd, self).__init__()

    def construct(self, x, bias):
        # Reshape the 1-D bias so it broadcasts over the channel axis:
        # axis 1 for 4-D (NCHW) inputs, last axis for 2-D inputs.
        shape = P.Shape()(x)
        if len(shape) == 4:
            bias_shape = (1, P.Shape()(bias)[0], 1, 1)  # NCHW
        else:
            bias_shape = (1, P.Shape()(bias)[0])
        res = x + P.Reshape()(bias, bias_shape)
        return res
class BiasAddGrad(GraphKernel):
    """
    Computes gradients of BiasAdd.

    Inputs:
        x (Tensor): the gradients of bias add output.

    Output:
        Tensor, the gradients of bias add input.

    Examples:
        >>> dout = Tensor(np.ones(shape=[1, 2, 3, 4]), mindspore.float32)
        >>> bias_add_grad = BiasAddGrad()
        >>> dx = bias_add_grad(dout)
    """
    def __init__(self):
        super(BiasAddGrad, self).__init__()

    def construct(self, x):
        # Sum over every axis except axis 1 (the channel/bias axis),
        # i.e. reduce over [0] + [2, 3, ...].
        shape_x = P.Shape()(x)
        reduce_axis = [0]
        for i in range(2, len(shape_x)):
            reduce_axis.append(i)
        res = P.ReduceSum()(x, reduce_axis)
        return res
class EqualCount(GraphKernel):
    """
    Computes the number of the same elements of two tensors.

    The two input tensors should have same shape and data type.

    Inputs:
        x (Tensor): the first input tensor.
        y (Tensor): the second input tensor.

    Outputs:
        Tensor, the type is same as input tensor and size as (1,).

    Examples:
        >>> x = Tensor(np.array([1, 2, 3]), mindspore.int32)
        >>> y = Tensor(np.array([1, 2, 4]), mindspore.int32)
        >>> equal_count = EqualCount()
        >>> equal_count(x, y)
    """
    def __init__(self):
        super(EqualCount, self).__init__()

    def construct(self, x, y):
        # Compare in float32, then count matches by summing a 0/1 mask.
        equal_bool = P.Equal()(P.Cast()(x, mstype.float32), P.Cast()(y, mstype.float32))
        # NOTE(review): the mask is summed in float16 — presumably an Ascend
        # requirement; counts above fp16's exact-integer range would lose
        # precision. Confirm expected input sizes.
        equal_count = P.Cast()(equal_bool, mstype.float16)
        axes = (0,)
        res = P.ReduceSum()(equal_count, axes)
        res = P.Cast()(res, P.DType()(x))
        return res
class ReduceMean(GraphKernel):
    """
    Reduce a dimension of a tensor by averaging all elements in the dimension.

    The dtype of the tensor to be reduced is number.

    Args:
        keep_dims (bool): If True, keep these reduced dimensions and the length is 1.
            If False, don't keep these dimensions. Default: True.

    Inputs:
        - **input_x** (Tensor[Number]) - The input tensor.
        - **axis** (Union[tuple(int), list(int)]) - The dimensions to reduce.
            Only constant value is allowed. Must be iterable (the scale factor
            is computed by iterating over it).

    Outputs:
        Tensor, has the same dtype as the 'input_x'.

        - If axis is (), and keep_dims is false,
          the output is a 0-D tensor representing the sum of all elements in the input tensor.
        - If axis is int, set as 2, and keep_dims is false,
          the shape of output is :math:`(x_1, x_3, ..., x_R)`.
        - If axis is tuple(int), set as (2, 3), and keep_dims is false,
          the shape of output is :math:`(x_1, x_4, ..., x_R)`.

    Examples:
        >>> input_x = Tensor(np.random.randn(3, 4, 5, 6).astype(np.float32))
        >>> op = ReduceMean(keep_dims=True)
        >>> output = op(input_x, (1,))
    """
    def __init__(self, keep_dims=True):
        super(ReduceMean, self).__init__()
        self.keep_dims = validator.check_type('keep_dims', keep_dims, [bool])
        self.sum = P.ReduceSum(self.keep_dims)

    def construct(self, x, axis):
        # mean = sum(x, axis) / (product of reduced dimension sizes)
        shape = P.Shape()(x)
        value_num = 1
        for i in axis:
            value_num *= shape[i]
        data_sum = self.sum(x, axis)
        # Multiply by the reciprocal instead of dividing.
        avg = 1.0 / P.Fill()(P.DType()(x), (1,), value_num)
        res = data_sum * avg
        return res
class ReLU(GraphKernel):
    r"""
    Computes ReLU(Rectified Linear Unit) of input tensor element-wise.

    It returns :math:`\max(x,\ 0)` element-wise.

    Inputs:
        - **input_x** (Tensor) - The input tensor.

    Outputs:
        Tensor, with the same type and shape as the `input_x`.

    Examples:
        >>> input_x = Tensor(np.array([[-1.0, 4.0, -8.0], [2.0, -5.0, 9.0]]), mindspore.float32)
        >>> relu = ReLU()
        >>> result = relu(input_x)
        [[0, 4.0, 0.0], [2.0, 0.0, 9.0]]
    """
    def __init__(self):
        super(ReLU, self).__init__()
        self.max = P.Maximum()

    def construct(self, x):
        # relu(x) = max(0, x); the zero tensor matches x's shape and dtype.
        return self.max(P.Fill()(P.DType()(x), P.Shape()(x), 0.0), x)
class SoftmaxCrossEntropyWithLogits(GraphKernel):
    r"""
    Gets the softmax cross-entropy value between logits and labels which should be one-hot encoding.

    Note:
        Sets input logits as `X`, input label as `Y`, output as `loss`. Then,

        .. math::
            p_{ij} = softmax(X_{ij}) = \frac{exp(x_i)}{\sum_{j = 0}^{N-1}\exp(x_j)}

        .. math::
            loss_{ij} = -\sum_j{Y_{ij} * ln(p_{ij})}

    Inputs:
        - **logits** (Tensor) - Input logits, with shape :math:`(N, C)`.
        - **labels** (Tensor) - Ground truth labels, with shape :math:`(N, C)`.

    Outputs:
        Tuple of 2 Tensor, the loss shape is `(N,)`, and the dlogits with the same shape as `logits`.

    Examples:
        >>> logits = Tensor([[2, 4, 1, 4, 5], [2, 1, 2, 4, 3]], mindspore.float32)
        >>> labels = Tensor([[0, 0, 0, 0, 1], [0, 0, 0, 1, 0]], mindspore.float32)
        >>> softmax_cross = SoftmaxCrossEntropyWithLogits()
        >>> loss, backprop = softmax_cross(logits, labels)
    """
    def __init__(self):
        super(SoftmaxCrossEntropyWithLogits, self).__init__()
        self.max = P.ReduceMax(keep_dims=True)
        self.sum_keep_dims = P.ReduceSum(keep_dims=True)

    def construct(self, features, labels):
        # Numerically stable softmax: subtract the per-row max before exp.
        data_max = self.max(features, (1,))
        data_sub = features - data_max
        data_exp = P.Exp()(data_sub)
        data_sum = self.sum_keep_dims(data_exp, (1,))
        data_div = data_exp / data_sum
        # log-softmax = (x - max) - log(sum(exp(x - max)))
        data_log_tmp = P.Log()(data_sum)
        data_log = data_sub - data_log_tmp
        # loss_i = -sum_j labels_ij * log_softmax_ij
        data_mul = labels * data_log
        data_muls = P.Neg()(data_mul)
        loss = P.ReduceSum()(data_muls, (1,))
        # dlogits = softmax - labels, returned for use in bprop.
        backprop = data_div - labels
        return loss, backprop

    def bprop(self, features, labels, out, dout):
        # out[1] is (softmax - labels) from the forward pass; scale it by the
        # incoming loss gradient broadcast over the class axis.
        grad = out[1]
        grad = grad * P.ExpandDims()(dout[0], -1)
        # Labels are constants: zero gradient.
        return grad, P.ZerosLike()(labels)
  337. class LayerNormForward(GraphKernel):
  338. """ Forward function of the LayerNorm operator. """
  339. def __init__(self, begin_norm_axis=1, begin_params_axis=1):
  340. super(LayerNormForward, self).__init__()
  341. self.begin_norm_axis = validator.check_type('begin_norm_axis', begin_norm_axis, [int])
  342. self.begin_params_axis = validator.check_type('begin_params_axis', begin_params_axis, [int])
  343. self.mul = P.Mul()
  344. self.sum_keep_dims = P.ReduceSum(keep_dims=True)
  345. self.sub = P.Sub()
  346. self.add = P.TensorAdd()
  347. self.log = P.Log()
  348. self.exp = P.Exp()
  349. self.eps = P.Eps()
  350. def construct(self, input_x, input_gamma, input_beta):
  351. shape_x = P.Shape()(input_x)
  352. # Calculate the scaling ratio of the average
  353. begin_norm_axis = self.begin_norm_axis
  354. if begin_norm_axis < 0:
  355. begin_norm_axis += len(shape_x)
  356. reduce_axis = ()
  357. for i in range(len(shape_x)):
  358. if i > begin_norm_axis or i == begin_norm_axis:
  359. reduce_axis = reduce_axis + (i,)
  360. reduce_elts = 1.0
  361. for i in reduce_axis:
  362. reduce_elts *= shape_x[i]
  363. mean_cof = 1.0 / reduce_elts
  364. # Calculate mean
  365. mean_muls = self.mul(input_x, mean_cof)
  366. mean = self.sum_keep_dims(mean_muls, reduce_axis)
  367. # Calculate variance
  368. variance_sub = self.sub(input_x, mean)
  369. variance_mul = self.mul(variance_sub, variance_sub)
  370. variance_muls = self.mul(variance_mul, mean_cof)
  371. variance = self.sum_keep_dims(variance_muls, reduce_axis)
  372. # Calculate normalize
  373. normalize_sub = self.sub(input_x, mean)
  374. epsilon = self.eps(input_x)
  375. normalize_add = self.add(variance, epsilon)
  376. normalize_log = self.log(normalize_add)
  377. normalize_log_mul = self.mul(normalize_log, -0.5)
  378. normalize_exp = self.exp(normalize_log_mul)
  379. normalize_mul = self.mul(normalize_sub, normalize_exp)
  380. # Calculate scale and translate
  381. if self.begin_params_axis == 0:
  382. scale_mul = self.mul(input_gamma, normalize_mul)
  383. res = self.add(scale_mul, input_beta)
  384. else:
  385. scale_mul = self.mul(input_gamma, normalize_mul)
  386. res = self.add(scale_mul, input_beta)
  387. return res, mean, variance
class LayerNormXBackprop(GraphKernel):
    r"""
    Together with LayerNormBetaGammaBackprop, to supply the backprop
    functionality for LayerNorm.

    Note:
        Sets input_x as :math:`x_i`, variance as :math:`\sigma^2`, mean as :math:`\mu`,
        input_gamma as :math:`\gamma`. Then,

        .. math::
            \begin{array}{ll} \\
                \hat{x_i} = \frac{x_i - \mu}{\sqrt{\sigma^2 + \epsilon}} \\
                \frac {\partial L} {\partial x_i} =
                    \frac{\gamma}{\sqrt{\sigma^2+\epsilon}}
                    ( \frac{\partial L}{\partial y_i}
                    - \frac{1}{m} \cdot \frac{\partial L}{\partial \beta}
                    - \frac{\hat{x_i}}{m} \cdot \frac{\partial L}{\partial \gamma})
            \end{array}

    Inputs:
        - **dy**(Tensor) - The first item of the next operator's backprop's output.
        - **input_x**(Tensor) - The first input of the forward function of LayerNorm.
        - **variance**(Tensor) - The second input of the forward function of LayerNorm.
        - **mean**(Tensor) - The third input of the forward function of LayerNorm.
        - **input_gamma**(Tensor) - The fourth input of the forward function of LayerNorm.

    Outputs:
        Tensor, the output of this operator, will be used as the first item of the result of
        LayerNorm's backprop function, has the same shape and data type as 'input_x'.

    Examples:
        >>> dy = Tensor(np.random.randn(3, 4, 5, 6).astype(np.float32))
        >>> input_x = Tensor(np.random.randn(3, 4, 5, 6).astype(np.float32))
        >>> variance = Tensor(np.random.randn(3, 4, 5, 6).astype(np.float32))
        >>> mean = Tensor(np.random.randn(3, 4, 5, 6).astype(np.float32))
        >>> input_gamma = Tensor(np.random.randn(3, 4, 5, 6).astype(np.float32))
        >>> op = LayerNormXBackprop()
        >>> output = op(dy, input_x, variance, mean, input_gamma)
    """
    def __init__(self):
        super(LayerNormXBackprop, self).__init__()
        self.sum_keep_dims = P.ReduceSum(keep_dims=True)
        self.log = P.Log()
        self.exp = P.Exp()
        self.eps = P.Eps()

    def construct(self, dy, input_x, variance, mean, input_gamma):
        shape_x = P.Shape()(input_x)
        shape_mean = P.Shape()(mean)
        # Recover the reduction axes used by the forward pass: the first axis
        # where input_x's shape differs from mean's shape, through the last
        # axis. If no axis differs, fall back to the last axis only.
        reduce_axis = ()
        flag = -1
        # NOTE(review): despite the name, min_l is set to the LARGER of the two
        # ranks; since mean comes from a keep_dims reduction of input_x the
        # ranks presumably always match — confirm.
        min_l = 0
        if len(shape_x) > len(shape_mean):
            min_l = len(shape_x)
        else:
            min_l = len(shape_mean)
        for i in range(min_l):
            if (shape_x[i] != shape_mean[i]) and (flag == -1):
                flag = i
        if flag != -1:
            for i in range(flag, len(shape_x)):
                reduce_axis = reduce_axis + (i,)
        else:
            reduce_axis = reduce_axis + (len(shape_x) - 1,)
        # m = number of normalized elements per statistic.
        mean_num = 1.0
        for i in reduce_axis:
            mean_num *= shape_x[i]
        # dL/dxhat = gamma * dy
        pd_xl = input_gamma * dy
        # var_elta_2 = 1 / sqrt(var + eps), via exp(-0.5 * log(var + eps)).
        epsilon = self.eps(input_x)
        var_elta = variance + epsilon
        var_elta_log = self.log(var_elta)
        var_elta_mul = var_elta_log * -0.5
        var_elta_2 = P.Exp()(var_elta_mul)
        # dL/dvar = -0.5 * sum(dL/dxhat * (x - mean)) * (var + eps)^(-3/2)
        pdvar1_mul = var_elta_2 * var_elta_2
        pd_var_1 = pdvar1_mul * var_elta_2
        sub_x_mean = input_x - mean
        pdvar_mul1 = pd_xl * sub_x_mean
        pdvar_sum = self.sum_keep_dims(pdvar_mul1, reduce_axis)
        pdvar_mul3 = pdvar_sum * pd_var_1
        pd_var = pdvar_mul3 * -0.5
        # dL/dmean = -sum(dL/dxhat) / sqrt(var + eps)
        #            + dL/dvar * mean(-2 * (x - mean))
        pdmean1_sum = self.sum_keep_dims(pd_xl, reduce_axis)
        pdmean1_mul = pdmean1_sum * var_elta_2
        pd_mean_1 = pdmean1_mul * -1.0
        pdmean2_mul1 = sub_x_mean * -2.0
        pdmean2_sum = self.sum_keep_dims(pdmean2_mul1, reduce_axis)
        pdmean2_mul3 = pdmean2_sum * (1.0 / mean_num)
        pd_mean_2 = pd_var * pdmean2_mul3
        pd_mean = pd_mean_2 + pd_mean_1
        # dL/dx = dL/dxhat / sqrt(var + eps)
        #         + dL/dvar * 2 * (x - mean) / m
        #         + dL/dmean / m
        pd_x_1 = var_elta_2 * pd_xl
        pdx2_mul = pd_var * sub_x_mean
        pd_x_2 = pdx2_mul * (2.0 * (1.0 / mean_num))
        pd_x_3 = pd_mean * (1.0 / mean_num)
        pdx_add = pd_x_1 + pd_x_2
        pd_x = pdx_add + pd_x_3
        return pd_x
class LayerNormBetaGammaBackprop(GraphKernel):
    r"""
    Together with LayerNormXBackprop, to supply the backprop functionality for
    LayerNorm.

    Note:
        Sets input_x as :math:`x_i`, variance as :math:`\sigma^2`, mean as :math:`\mu`,
        input_gamma as :math:`\gamma`. Then,

        .. math::
            \begin{array}{ll} \\
                \hat{x_i} = \frac{x_i - \mu}{\sqrt{\sigma^2 + \epsilon}} \\
                \frac {\partial L} {\partial \beta} =
                    \sum_{i=1}^m \frac{\partial L}{\partial y_i} \\
                \frac {\partial L} {\partial \gamma} =
                    \sum_{i=1}^m \frac{\partial L}{\partial y_i} \cdot \hat{x_i}
            \end{array}

    Inputs:
        - **dy**(Tensor) - The first item of the next operator's backprop's output.
        - **input_x**(Tensor) - The first input of the forward function of LayerNorm.
        - **variance**(Tensor) - The second input of the forward function of LayerNorm.
        - **mean**(Tensor) - The third input of the forward function of LayerNorm.
        - **shape_gamma** - The shape of the gamma parameter; only its length is
          used, to derive the leading reduction axes.

    Outputs:
        Tuple of 2 Tensor, the backprop outputs.

        - **pd_beta**(Tensor) - The first item of return value of this operator, will be used as
          the second item of the LayerNorm's backprop function.
        - **pd_gamma**(Tensor) - The second item of return value of this operator, will be used as
          the third item of the LayerNorm's backprop function.

    Examples:
        >>> dy = Tensor(np.random.randn(3, 4, 5, 6).astype(np.float32))
        >>> input_x = Tensor(np.random.randn(3, 4, 5, 6).astype(np.float32))
        >>> variance = Tensor(np.random.randn(3, 4, 5, 6).astype(np.float32))
        >>> mean = Tensor(np.random.randn(3, 4, 5, 6).astype(np.float32))
        >>> op = LayerNormBetaGammaBackprop()
        >>> pd_beta, pd_gamma = op(dy, input_x, variance, mean, (5, 6))
    """
    def __init__(self):
        super(LayerNormBetaGammaBackprop, self).__init__()
        self.sum_not_keep_dims = P.ReduceSum(keep_dims=False)
        self.log = P.Log()
        self.exp = P.Exp()
        self.eps = P.Eps()

    def construct(self, dy, input_x, variance, mean, shape_gamma):
        shape_x = P.Shape()(input_x)
        # Reduce over the leading axes that gamma/beta do not cover, i.e. the
        # first (rank(x) - rank(gamma)) axes.
        # NOTE(review): shape_gamma is only used via len(); it is a shape
        # tuple, not the gamma tensor the surrounding docstring suggests.
        params_axis = ()
        if len(shape_x) != len(shape_gamma):
            sub = len(shape_x) - len(shape_gamma)
            for i in range(sub):
                params_axis = params_axis + (i,)
        # dL/dbeta = sum(dy) over the non-parameter axes.
        pd_beta = self.sum_not_keep_dims(dy, params_axis)
        # var_elta_2 = 1 / sqrt(var + eps), via exp(-0.5 * log(var + eps)).
        epsilon = self.eps(input_x)
        var_elta = variance + epsilon
        var_elta_log = self.log(var_elta)
        var_elta_mul = var_elta_log * -0.5
        var_elta_2 = P.Exp()(var_elta_mul)
        sub_x_mean = input_x - mean
        var_elta_2_cast = var_elta_2
        # xhat = (x - mean) / sqrt(var + eps); dL/dgamma = sum(dy * xhat).
        xl_mul = var_elta_2_cast * sub_x_mean
        pdga_mul = dy * xl_mul
        pd_gamma = self.sum_not_keep_dims(pdga_mul, params_axis)
        return pd_beta, pd_gamma
class LogSoftmax(GraphKernel):
    r"""
    Log Softmax activation function.

    Applies the Log Softmax function to the input tensor on the specified axis.
    Suppose a slice along the given axis :math:`x` then for each element :math:`x_i`
    the Log Softmax function is shown as follows:

    .. math::
        \text{output}(x_i) = \log \left(\frac{exp(x_i)} {\sum_{j = 0}^{N-1}\exp(x_j)}\right),

    where :math:`N` is the length of the Tensor.

    Args:
        axis (int): The axis to do the Log softmax operation. Default: -1.

    Inputs:
        logits (Tensor): The input of Log Softmax.

    Outputs:
        Tensor, with the same type and shape as the logits.

    Examples:
        >>> input_x = Tensor(np.array([1, 2, 3, 4, 5]), mindspore.float32)
        >>> log_softmax = LogSoftmax()
        >>> log_softmax(input_x)
        [-4.4519143, -3.4519143, -2.4519143, -1.4519144, -0.4519144]
    """
    def __init__(self, axis=-1):
        super(LogSoftmax, self).__init__()
        self.axis = validator.check_type('axis', axis, [int])
        self.max_keep_dims = P.ReduceMax(keep_dims=True)
        self.sub = P.Sub()
        self.exp = P.Exp()
        self.sum_keep_dims = P.ReduceSum(keep_dims=True)
        self.log = P.Log()
        self.mul = P.Mul()

    def construct(self, input_x):
        # Numerically stable: log_softmax(x) = (x - max) - log(sum(exp(x - max))).
        data_max = self.max_keep_dims(input_x, (self.axis,))
        data_sub = self.sub(input_x, data_max)
        data_exp = self.exp(data_sub)
        data_sum = self.sum_keep_dims(data_exp, (self.axis,))
        data_log = self.log(data_sum)
        res = self.sub(data_sub, data_log)
        return res

    def bprop(self, input_x, out, dout):
        # Uses the forward output (out = log_softmax(x)) instead of input_x:
        # dx = dy - softmax(x) * sum(dy) = dy - exp(out) * sum(dy).
        input_x = out
        input_dy = dout
        data_exp = self.exp(input_x)
        data_sum = self.sum_keep_dims(input_dy, (self.axis,))
        data_softmax = self.mul(data_exp, data_sum)
        res = self.sub(input_dy, data_softmax)
        return (res,)
class Tanh(GraphKernel):
    r"""
    Tanh activation function.

    Computes hyperbolic tangent of input element-wise. The Tanh function is defined as:

    .. math::
        tanh(x_i) = \frac{\exp(x_i) - \exp(-x_i)}{\exp(x_i) + \exp(-x_i)} = \frac{\exp(2x_i) - 1}{\exp(2x_i) + 1},

    where :math:`x_i` is an element of the input Tensor.

    Inputs:
        - **input_x** (Tensor) - The input of Tanh.

    Outputs:
        Tensor, with the same type and shape as the input_x.

    Examples:
        >>> input_x = Tensor(np.array([1, 2, 3, 4, 5]), mindspore.float32)
        >>> tanh = Tanh()
        >>> tanh(input_x)
        [0.7615941, 0.9640276, 0.9950548, 0.9993293, 0.99990916]
    """
    def __init__(self):
        super(Tanh, self).__init__()
        self.abs = P.Abs()
        self.add = P.TensorAdd()
        self.div = P.RealDiv()
        self.mul = P.Mul()
        # Final bprop multiply is forced to fp16 output precision.
        self.mul_fp16 = P.Mul()
        self.mul_fp16.add_prim_attr("output_precision", "float16")
        self.exp = P.Exp()

    def construct(self, input_x):
        # Overflow-safe form: tanh(x) = -sign(x) * (exp(-2|x|) - 1) / (exp(-2|x|) + 1),
        # so exp never sees a positive argument.
        # NOTE(review): sign is computed as x / |x|, which is 0/0 at x == 0 —
        # presumably acceptable on the target fused kernel; confirm.
        input_abs = self.abs(input_x)
        sign_flag = self.div(input_x, input_abs)
        sign_flag_neg = self.mul(sign_flag, -1.0)
        power_val = self.mul(input_abs, -2.0)
        exp_val = self.exp(power_val)
        up_val = self.add(exp_val, -1.0)
        down_val = self.add(exp_val, 1.0)
        div_val = self.div(up_val, down_val)
        res = self.mul(sign_flag_neg, div_val)
        return res

    def bprop(self, input_x, out, dout):
        # d(tanh)/dx = 1 - tanh(x)^2, using the forward output `out`.
        input_y = out
        input_dy = dout
        data_square = self.mul(input_y, input_y)
        data_mul = self.mul(data_square, -1.0)
        anuminate = self.add(data_mul, 1.0)
        res = self.mul_fp16(anuminate, input_dy)
        return (res,)
class TanhGrad(GraphKernel):
    """
    Backprop function of Tanh.

    Mathematical calculating:
        result = 1 - out * out
        result = result * dout

    Inputs:
        out (Tensor): Tanh's output
        dout (Tensor): next layer's backward function's output, has same shape as out

    Outputs:
        result (Tensor): result of (1 - out^2) * dout

    Examples:
        >>> x_np = np.random.randn(5, 3, 6).astype(np.float16)
        >>> dy_np = np.random.randn(5, 3, 6).astype(np.float16)
        >>> x_ms = Tensor(x_np)
        >>> dy_ms = Tensor(dy_np)
        >>> tanh_grad = TanhGrad()
        >>> out = tanh_grad(x_ms, dy_ms)
    """
    def __init__(self):
        super(TanhGrad, self).__init__()
        self.add = P.TensorAdd()
        self.mul = P.Mul()
        # Final multiply is forced to fp16 output precision.
        self.mul_fp16 = P.Mul()
        self.mul_fp16.add_prim_attr("output_precision", "float16")

    def construct(self, out, dout):
        # (1 - out^2) * dout, matching Tanh.bprop.
        input_y = out
        input_dy = dout
        data_square = self.mul(input_y, input_y)
        data_mul = self.mul(data_square, -1.0)
        anuminate = self.add(data_mul, 1.0)
        res = self.mul_fp16(anuminate, input_dy)
        return res
  663. class Gelu(GraphKernel):
  664. r"""
  665. Gaussian Error Linear Units activation function.
  666. GeLU is described in the paper `Gaussian Error Linear Units (GELUs) <https://arxiv.org/abs/1606.08415>`_.
  667. And also please refer to `BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding.
  668. <https://arxiv.org/abs/1810.04805>`_.
  669. Defined as follows:
  670. .. math::
  671. \text{output} = 0.5 * x * (1 + erf(x / \sqrt{2})),
  672. where :math:`erf` is the "Gauss error function" .
  673. Inputs:
  674. - **input_x** (Tensor) - Input to compute the Gelu.
  675. Outputs:
  676. Tensor, with the same type and shape as input.
  677. Examples:
  678. >>> tensor = Tensor(np.array([1.0, 2.0, 3.0]), mindspore.float32)
  679. >>> gelu = Gelu()
  680. >>> result = gelu(tensor)
  681. """
  682. def __init__(self):
  683. super(Gelu, self).__init__()
  684. self.add = P.TensorAdd()
  685. self.abs = P.Abs()
  686. self.exp = P.Exp()
  687. self.neg = P.Neg()
  688. self.minimum = P.Minimum()
  689. self.div = P.RealDiv()
  690. self.mul = P.Mul()
  691. self.CSVALUE = 0.044715
  692. self.CSVALUE_A = 1.59576912
  693. self.CSVALUE_5 = 0.3989422804
  694. self.CSVALUE_3B = 0.2140644488
  695. def construct(self, input_x):
  696. def _tanh_parameter_compute(data_x):
  697. """
  698. compute the parameter of tanh:
  699. return: result equal (x+0.044715*tf.pow(x,3))
  700. """
  701. mul_0 = self.mul(data_x, data_x)
  702. pow_0 = self.mul(mul_0, data_x)
  703. mul_1 = self.mul(pow_0, self.CSVALUE)
  704. result = self.add(data_x, mul_1)
  705. return result
  706. tanh_parameter = _tanh_parameter_compute(input_x)
  707. mul_0 = self.mul(tanh_parameter, 1.5957691)
  708. mul_0_min = self.minimum(mul_0, 0.0)
  709. right_mul = self.exp(mul_0_min)
  710. mul_0_abs = self.abs(mul_0)
  711. mul_0_abs_neg = self.mul(mul_0_abs, -1.0)
  712. mul_0_abs_neg_exp = self.exp(mul_0_abs_neg)
  713. mul_0_abs_neg_exp_add = self.add(mul_0_abs_neg_exp, 1.0)
  714. left_mul = self.div(input_x, mul_0_abs_neg_exp_add)
  715. result = self.mul(left_mul, right_mul)
  716. return result
  717. def bprop(self, input_x, out, dout):
  718. """ register backprop function for Gelu """
  719. data_x = input_x
  720. data_gelu = out
  721. data_dy = dout
  722. def _math_four_compute(data_x):
  723. """
  724. return: math_four equal 2*(np(sqrt(2 / np.pi)*(x + 0.044715*tf.pow(x, 3)))
  725. """
  726. datax_pow = data_x * data_x * data_x
  727. datax_muls_c = self.mul(datax_pow, self.CSVALUE)
  728. datax_addx = self.add(datax_muls_c, data_x)
  729. datax_muls_s = self.mul(datax_addx, self.CSVALUE_A)
  730. return datax_muls_s
  731. # common part
  732. math_four = _math_four_compute(data_x)
  733. math_four_abs = self.abs(math_four)
  734. math_four_abs_neg = self.mul(math_four_abs, -1.0)
  735. math_four_abs_neg_exp = self.exp(math_four_abs_neg)
  736. math_four_min = self.minimum(math_four, 0.0)
  737. # dividend part
  738. datax_pow = self.mul(data_x, data_x)
  739. datax_pow_mul = self.mul(datax_pow, self.CSVALUE_3B)
  740. datax_pow_mul_add = self.add(datax_pow_mul, self.CSVALUE_A)
  741. data_gelu_mul = self.mul(data_gelu, datax_pow_mul_add)
  742. math_four_min_2 = self.mul(math_four_min, 2.0)
  743. div_right = self.mul(data_gelu_mul, math_four_abs_neg_exp)
  744. div_left = self.exp(math_four_min_2)
  745. dividend = self.add(div_left, div_right)
  746. # divisor part
  747. div_0 = self.add(math_four_abs_neg_exp, 1.0)
  748. div_1 = self.exp(math_four_min)
  749. divisor = self.mul(div_1, div_0)
  750. res_grad = self.div(dividend, divisor)
  751. result = self.mul(res_grad, data_dy)
  752. return (result,)
  753. class Softmax(GraphKernel):
  754. """
  755. Operator Softmax
  756. .. math: `exp(x-max(x)) / sum(exp(x-max(x)))`
  757. Args:
  758. axis (int, tuple): Axis along which the softmax normalization is applied
  759. Inputs:
  760. x (Tensor): input data for softmax
  761. Outputs:
  762. output (Tensor): a tensor with the same shape of the input
  763. Examples:
  764. >>> layer = Softmax(1)
  765. >>> x = Tensor(np.array([1.2, 2.1], [2.2, 3.2]), mindspore.float32)
  766. >>> output = layer(x)
  767. """
  768. def __init__(self, axis):
  769. super(Softmax, self).__init__()
  770. validator.check_type("axis", axis, [int, tuple])
  771. if isinstance(axis, int):
  772. self.axis = (axis,)
  773. else:
  774. self.axis = axis
  775. for item in self.axis:
  776. validator.check_type("item of axis", item, [int])
  777. self.max = P.ReduceMax(keep_dims=True)
  778. self.sub = P.Sub()
  779. self.exp = P.Exp()
  780. self.sum = P.ReduceSum(keep_dims=True)
  781. self.mul = P.Mul()
  782. def construct(self, x):
  783. max_x = self.max(x, self.axis)
  784. data_sub = self.sub(x, max_x)
  785. data_exp = self.exp(data_sub)
  786. data_expsum = self.sum(data_exp, self.axis)
  787. output = data_exp / data_expsum
  788. return output
  789. def bprop(self, x, out, dout):
  790. mul_res = self.mul(dout, out)
  791. sum_res = self.sum(mul_res, self.axis)
  792. sub_res = self.sub(dout, sum_res)
  793. res = self.mul(sub_res, out)
  794. return (res,)
class LayerNorm(Cell):
    r"""
    Applies Layer Normalization over a mini-batch of inputs.

    Layer normalization is widely used in recurrent neural networks. It applies
    normalization over a mini-batch of inputs for each single training case as described
    in the paper `Layer Normalization <https://arxiv.org/pdf/1607.06450.pdf>`_. Unlike batch
    normalization, layer normalization performs exactly the same computation at training and
    testing times. It can be described using the following formula. It is applied across all channels
    and pixel but only one batch size.

    .. math::
        y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta

    Args:
        begin_norm_axis (int): The first normalization dimension: normalization will be performed
            along dimensions `begin_norm_axis: rank(inputs)`, the value should be in
            [-1, rank(input)). Default: -1.
        begin_params_axis (int): The first parameter (beta, gamma) dimension: scale and centering
            parameters will have dimensions `begin_params_axis: rank(inputs)` and will be broadcast
            with the normalized inputs accordingly, the value should be in [-1, rank(input)).
            Default: -1.

    Inputs:
        - **input_x** (Tensor) - The shape of 'input_x' is :math:`(x_1, x_2, ..., x_R)`.
        - **input_gamma** (Tensor) - Scale parameter.
        - **input_beta** (Tensor) - Offset parameter.

    Outputs:
        Tensor, the normalized and scaled offset tensor, has the same shape and data type as the `input_x`.

    Examples:
        >>> x = Tensor(np.ones([20, 5, 10, 10]), mindspore.float32)
        >>> shape1 = x.shape()[1:]
        >>> m = G.LayerNorm(shape1, begin_norm_axis=1, begin_params_axis=1)
        >>> m(x)
    """

    def __init__(self,
                 begin_norm_axis=-1,
                 begin_params_axis=-1
                 ):
        super(LayerNorm, self).__init__()
        self.begin_norm_axis = begin_norm_axis
        self.begin_params_axis = begin_params_axis
        # Forward graph kernel plus backprop kernels. The two graph-kernel
        # backprop ops below are constructed but not referenced in this
        # class's construct/bprop; only the fused G.LayerNormGrad is used.
        self.layer_norm = LayerNormForward(begin_norm_axis, begin_params_axis)
        self.layer_norm_x_grad = LayerNormXBackprop()
        self.layer_norm_beta_gamma = LayerNormBetaGammaBackprop()
        self.layer_norm_grad = G.LayerNormGrad(self.begin_norm_axis, self.begin_params_axis)

    def construct(self, input_x, input_gamma, input_beta):
        # Delegates entirely to the forward graph kernel.
        return self.layer_norm(input_x, input_gamma, input_beta)

    # case 1
    def bprop(self, input_x, input_gamma, input_beta, out, dout):
        # NOTE(review): the indexing assumes the forward kernel returns a
        # tuple where out[2] and dout[0]/dout[1] carry the statistics and
        # matching cotangents expected by G.LayerNormGrad — confirm against
        # the LayerNormForward / LayerNormGrad definitions.
        dx, d_gamma, d_beta = self.layer_norm_grad(input_x, dout[0], out[2], dout[1], input_gamma)
        return dx, d_gamma, d_beta
  848. class LambUpdateWithLR(GraphKernel):
  849. r"""
  850. Part of Lamb optimizer.
  851. .. math::
  852. s_1 = select(i_1 \gt y_g, select(i_0 \gt y_g, \frac{i_1}{i_2}, se), se)
  853. i_5 = i_5 - max(min(s_1, y_m), y_g) \times i_3 \times i_4
  854. Inputs:
  855. - **input0** (Tensor) - The first tensor to be computed.
  856. - **input1** (Tensor) - The second tensor to be computed.
  857. - **input2** (Tensor) - The third tensor to be computed.
  858. - **input3** (Tensor) - The fourth tensor to be computed.
  859. - **input4** (Tensor) - The fifth tensor to be computed.
  860. - **input5** (Tensor) - The sixth tensor to be computed. It will be updated by result.
  861. - **greater_y** (Tensor) - The seventh tensor to be computed.
  862. - **select_e** (Tensor) - The eighth tensor to be computed.
  863. - **minimum_y** (Tensor) - The ninth tensor to be computed.
  864. Outputs:
  865. A fake output tensor.
  866. Examples:
  867. >>> lamb_update = LambUpdateWithLR()
  868. >>> i0 = np.random.normal(0, 1, [1, 16]).astype(np.float32)
  869. >>> i1 = np.random.normal(0, 1, [1]).astype(np.float32)
  870. >>> i2 = np.random.normal(0, 1, [1]).astype(np.float32)
  871. >>> i3 = np.random.normal(0, 1, [1]).astype(np.float32)
  872. >>> i4 = np.random.normal(0, 1, [1, 16]).astype(np.float32)
  873. >>> i5 = np.random.normal(0, 1, [1, 16]).astype(np.float32)
  874. >>> yg = np.random.normal(0, 1, [1]).astype(np.float32)
  875. >>> se = np.random.normal(0, 1, [1]).astype(np.float32)
  876. >>> ym = np.random.normal(0, 1, [1]).astype(np.float32)
  877. >>> lamb_update(i0, i1, i2, i3, i4, i5, yg, se, ym)
  878. """
  879. def __init__(self):
  880. super(LambUpdateWithLR, self).__init__()
  881. self.greater = P.Greater()
  882. self.select = P.Select()
  883. self.div = P.RealDiv()
  884. self.min = P.Minimum()
  885. self.max = P.Maximum()
  886. self.mul = P.Mul()
  887. self.sub = P.Sub()
  888. self.fake_output_assign = InplaceAssign()
  889. self.fake_output_assign.add_prim_attr("fake_output", True)
  890. def construct(self, input0, input1, input2, input3, input4, input5, greater_y, select_e, minimum_y):
  891. greater0 = self.greater(input0, greater_y)
  892. greater1 = self.greater(input1, greater_y)
  893. real_div0 = self.div(input1, input2)
  894. select0 = self.select(greater0, real_div0, select_e)
  895. select1 = self.select(greater1, select0, select_e)
  896. min0 = self.min(select1, minimum_y)
  897. max0 = self.max(min0, greater_y)
  898. mul0 = self.mul(max0, input3)
  899. mul1 = self.mul(mul0, input4)
  900. sub0 = self.sub(input5, mul1)
  901. sub0 = self.fake_output_assign(input5, sub0, sub0)
  902. return sub0
class LambNextMV(GraphKernel):
    r"""
    Part of Lamb optimizer.

    .. math::
        rd_0 = \frac{i_8 \times i_5 + i_9 \times i_4}{i_6}

        rd_1 = \frac{x_0 \times i_2 + x_1 \times i_1}{i_3}

        y_2 = \frac{rd_0}{\sqrt{rd_1 + x_3}} + x_2 \times i_7

        y_3 = \frac{rd_0}{\sqrt{rd_1} + x_3}

        i_5 = i_8 \times i_5 + i_9 \times i_4

        i_2 = x_0 \times i_2 + x_1 \times i_1

    Inputs:
        - **input1** (Tensor) - The first input tensor to be computed.
        - **input2** (Tensor) - The second input tensor to be computed. It will be updated by result.
        - **input3** (Tensor) - The third input tensor to be computed.
        - **input4** (Tensor) - The fourth input tensor to be computed.
        - **input5** (Tensor) - The fifth input tensor to be computed. It will be updated by result.
        - **input6** (Tensor) - The sixth input tensor to be computed.
        - **input7** (Tensor) - The seventh input tensor to be computed.
        - **input8** (Tensor) - The eighth input tensor to be computed.
        - **input9** (Tensor) - The ninth input tensor to be computed.
        - **inputx0** (Tensor) - The tenth input tensor to be computed.
        - **inputx1** (Tensor) - The eleventh input tensor to be computed.
        - **inputx2** (Tensor) - The twelfth input tensor to be computed.
        - **inputx3** (Tensor) - The thirteenth input tensor to be computed.

    Outputs:
        Tuple of 2 Tensor.

        - **add3** (Tensor) - The shape is same as the shape after broadcasting, and the data type is
          the one with high precision or high digits among the inputs.
        - **realdiv4** (Tensor) - The shape is same as the shape after broadcasting, and the data type is
          the one with high precision or high digits among the inputs.

    Examples:
        >>> lamb_next_mv = LambNextMV()
        >>> i1 = Tensor(np.random.normal(0, 1, [1, 16]).astype(np.float32))
        >>> i2 = Tensor(np.random.normal(0, 1, [1, 16]).astype(np.float32))
        >>> i3 = Tensor(np.random.normal(0, 1, [1, 16]).astype(np.float32))
        >>> i4 = Tensor(np.random.normal(0, 1, [1, 16]).astype(np.float32))
        >>> i5 = Tensor(np.random.normal(0, 1, [1, 16]).astype(np.float32))
        >>> i6 = Tensor(np.random.normal(0, 1, [1, 16]).astype(np.float32))
        >>> i7 = Tensor(np.random.normal(0, 1, [1, 16]).astype(np.float32))
        >>> i8 = Tensor(np.random.normal(0, 1, [1, 16]).astype(np.float32))
        >>> i9 = Tensor(np.random.normal(0, 1, [1, 16]).astype(np.float32))
        >>> x0 = Tensor(np.random.normal(0, 1, [1, 16]).astype(np.float32))
        >>> x1 = Tensor(np.random.normal(0, 1, [1, 16]).astype(np.float32))
        >>> x2 = Tensor(np.random.normal(0, 1, [1, 16]).astype(np.float32))
        >>> x3 = Tensor(np.ones([1, 16]).astype(np.float32) * 1e-6)
        >>> lamb_next_mv(i1, i2, i3, i4, i5, i6, i7, i8, i9, x0, x1, x2, x3)
    """

    def __init__(self):
        super(LambNextMV, self).__init__()
        self.mul = P.Mul()
        self.add = P.TensorAdd()
        self.div = P.RealDiv()
        self.sqrt = P.Sqrt()
        self.rsqrt = P.Rsqrt()
        # In-place writers for the two state tensors (input5 and input2).
        # NOTE(review): "fake_output" is False here, unlike LambUpdateWithLR
        # where it is True — presumably because this kernel returns real
        # outputs; confirm against InplaceAssign's attribute semantics.
        self.fake_output_assign_1 = InplaceAssign()
        self.fake_output_assign_1.add_prim_attr("fake_output", False)
        self.fake_output_assign_2 = InplaceAssign()
        self.fake_output_assign_2.add_prim_attr("fake_output", False)

    def construct(self, input1, input2, input3, input4, input5, input6, input7,
                  input8, input9, inputx0, inputx1, inputx2, inputx3):
        # rd_1 numerator: x0 * i2 + x1 * i1
        mul3 = self.mul(inputx1, input1)
        mul2 = self.mul(inputx0, input2)
        add1 = self.add(mul2, mul3)
        realdiv1 = self.div(add1, input3)   # rd_1
        add2 = self.add(realdiv1, inputx3)  # rd_1 + x3
        sqrt0 = self.rsqrt(add2)            # 1 / sqrt(rd_1 + x3)
        sqrt1 = self.sqrt(realdiv1)         # sqrt(rd_1)
        add4 = self.add(sqrt1, inputx3)     # sqrt(rd_1) + x3
        # rd_0 numerator: i8 * i5 + i9 * i4
        mul1 = self.mul(input9, input4)
        mul0 = self.mul(input8, input5)
        add0 = self.add(mul0, mul1)
        realdiv0 = self.div(add0, input6)   # rd_0
        realdiv2 = self.mul(realdiv0, sqrt0)  # rd_0 / sqrt(rd_1 + x3)
        realdiv4 = self.div(realdiv0, add4)   # y_3
        mul4 = self.mul(inputx2, input7)      # x2 * i7
        add3 = self.add(realdiv2, mul4)       # y_2
        # Write the new state back in place: add0 -> input5, add1 -> input2.
        # The value add3 is threaded through both assigns unchanged.
        add3 = self.fake_output_assign_1(input5, add0, add3)
        add3 = self.fake_output_assign_2(input2, add1, add3)
        return add3, realdiv4