You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

graph_kernels.py 44 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077107810791080108110821083108410851086108710881089109010911092109310941095109610971098109911001101110211031104110511061107110811091110111111121113111411151116111711181119112011211122112311241125112611271128112911301131113211331134113511361137113811391140114111421143114411451146114711481149115011511152115311541155115611571158115911601161116211631164116511661167116811691170117111721173117411751176117711781179118011811182118311841185118611871188118911901191119211931194119511961197119811991200
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """
  16. Graph kernels. They are composites of basic primitives and can be compiled into
  17. a fused kernel automaticly when context.set_context(enable_graph_kernel=True).
  18. """
  19. from ...common import dtype as mstype
  20. from ...ops import operations as P
  21. from ...ops.primitive import PrimitiveWithInfer, prim_attr_register
  22. from ...ops.composite import multitype_ops as C
  23. from ...ops.operations import _grad_ops as G
  24. from ..._checkparam import Validator
  25. from ..cell import Cell, GraphKernel
  26. class InplaceAssign(PrimitiveWithInfer):
  27. """
  28. Inplace assign `Parameter` with a value.
  29. This primitive can only use in graph kernel.
  30. Inputs:
  31. - **variable** (Parameter) - The `Parameter`.
  32. - **value** (Tensor) - The value to be assigned.
  33. - **depend** (Tensor) - The dependent tensor to keep this op connected in graph.
  34. Outputs:
  35. Tensor, has the same type as original `variable`.
  36. Examples:
  37. >>> def construct(self, x):
  38. >>> val = x - 1.0
  39. >>> ret = x + 2.0
  40. >>> return InplaceAssign()(x, val, ret)
  41. >>> x = Tensor([2.0], mindspore.float32)
  42. >>> net = Net()
  43. >>> net(x)
  44. """
  45. @prim_attr_register
  46. def __init__(self):
  47. self.init_prim_io_names(inputs=['x', 'y', 'z'], outputs=['output'])
  48. def infer_shape(self, x, y, z):
  49. return z
  50. def infer_dtype(self, x, y, z):
  51. return z
  52. def get_bprop(self):
  53. def bprop(x, y, z, out, dout):
  54. return (x, C.zeros_like(y), dout)
  55. return bprop
  56. class MaximumGrad(GraphKernel):
  57. """
  58. Backprop function for Maximum operator.
  59. Inputs:
  60. - **x** (Tensor) - The first input tensor of maximum.
  61. - **y** (Tensor) - The second input tensor of maximum.
  62. - **dout** (Tensor) - has the same shape as x and y, next operator's backprop output.
  63. Outputs:
  64. dx (Tensor): has the same shape as x and y, returns dout element if
  65. `x >= y` returns true at the same position, or returns zero at that
  66. position
  67. dy (Tensor): has the same shape as x and y, dy = dout - dx
  68. Examples:
  69. >>> layer = MaximumGrad()
  70. >>> output = layer(Tensor([1,2,3], [3, 2, 1], [4, 5, 6]))
  71. """
  72. def __init__(self, grad_x=True, grad_y=True):
  73. super(MaximumGrad, self).__init__()
  74. self.grad_x = grad_x
  75. self.grad_y = grad_y
  76. self.select = P.Select()
  77. self.greater_equal = P.GreaterEqual()
  78. self.zeros_like = P.ZerosLike()
  79. self.sub = P.Sub()
  80. def construct(self, x, y, dout):
  81. cmp_result = self.greater_equal(x, y)
  82. dx = self.select(cmp_result, dout, self.zeros_like(dout))
  83. dy = dout - dx
  84. return dx, dy
  85. class MinimumGrad(GraphKernel):
  86. """
  87. Backprop function for Minimum operator.
  88. Compares x and y elementwise, dout must has the same shape with x and y.
  89. Inputs:
  90. - **x** (Tensor) - The first input
  91. - **y** (Tensor) - x and y must have same shape
  92. - **dout** (Tensor) - Has the same shape as x and y, next operator's backprop output
  93. Outputs:
  94. - dx (Tensor) - Has the same shape as x and y, returns dout element if
  95. `x <= y` returns true at the same position, or returns zero at that
  96. position
  97. - dy (Tensor) - Has the same shape as x and y, dy = dout - dx
  98. Examples:
  99. >>> layer = MinimumGrad()
  100. >>> output = layer(Tensor([1,2,3], [3, 2, 1], [4, 5, 6]))
  101. """
  102. def __init__(self, grad_x=True, grad_y=True):
  103. super(MinimumGrad, self).__init__()
  104. self.grad_x = grad_x
  105. self.grad_y = grad_y
  106. self.select = P.Select()
  107. self.less_equal = P.LessEqual()
  108. self.zeros_like = P.ZerosLike()
  109. self.sub = P.Sub()
  110. def construct(self, x, y, dout):
  111. cmp_result = self.less_equal(x, y)
  112. dx = self.select(cmp_result, dout, self.zeros_like(dout))
  113. dy = dout - dx
  114. return dx, dy
  115. class AbsGrad(GraphKernel):
  116. """
  117. Abs's backprop function.
  118. Inputs:
  119. **input_x** (Tensor) - input data of this operator.
  120. **dout** (Tensor) - output of the next operator's backprop function.
  121. Outputs:
  122. Tensor, has the same shape as input_x.
  123. Examples:
  124. >>> back = AbsGrad()
  125. >>> output = back(Tensor([1, 2, 3]), Tensor([4, 5, 6]))
  126. """
  127. def __init__(self):
  128. super(AbsGrad, self).__init__()
  129. self.mul = P.Mul()
  130. self.abs = P.Abs()
  131. self.add = P.TensorAdd()
  132. self.div = P.RealDiv()
  133. self.round = P.Round()
  134. def construct(self, input_x, dout):
  135. NUM_MAX = 32768
  136. mul_max = self.mul(input_x, P.Fill()(P.DType()(input_x), (1,), NUM_MAX))
  137. res_abs = self.abs(mul_max)
  138. res_div = self.div(mul_max, res_abs)
  139. res_round = self.round(res_div)
  140. res = self.mul(res_round, dout)
  141. return res
  142. class ApplyMomentum(GraphKernel):
  143. """
  144. Update parameter according to the ApplyMomentum algorithm.
  145. Inputs:
  146. variable (Tensor): mutable tensor var
  147. accumulation (Tensor): mutable tensor accum
  148. learning_rate (float32): learning rate
  149. gradient (float32): The gradient
  150. momentum (float32): Momentum
  151. Outputs: updated accumulation and variable
  152. """
  153. def __init__(self,
  154. use_nesterov=False,
  155. use_locking=False,
  156. gradient_scale=1.0):
  157. super(ApplyMomentum, self).__init__()
  158. self.gradient_scale = Validator.check_type('gradient_scale', gradient_scale, [float])
  159. self.fake_output_assign_1 = InplaceAssign()
  160. self.fake_output_assign_1.add_prim_attr("fake_output", True)
  161. self.fake_output_assign_2 = InplaceAssign()
  162. self.fake_output_assign_2.add_prim_attr("fake_output", True)
  163. def construct(self, variable, accumulation, learning_rate, gradient, momentum):
  164. gradient = gradient * self.gradient_scale
  165. momt_accumulation = accumulation * momentum
  166. accumulation_inplace = momt_accumulation + gradient
  167. sum_gradient = accumulation_inplace * learning_rate
  168. variable_inplace = variable - sum_gradient
  169. accumulation_inplace = self.fake_output_assign_1(accumulation, accumulation_inplace, accumulation_inplace)
  170. variable_inplace = self.fake_output_assign_2(variable, variable_inplace, variable_inplace)
  171. return accumulation_inplace, variable_inplace
  172. class BiasAdd(GraphKernel):
  173. """
  174. Return the sum of x and bias.
  175. Inputs:
  176. x (Tensor): Tensor of input data.
  177. bias (Tensor): The bias tensor.
  178. Output:
  179. Tensor, the sum of x and bias.
  180. Example:
  181. >>> layer = BiasGrad()
  182. >>> output = BiasAdd(Tensor([1, 2, 3]), Tensor([1,]))
  183. """
  184. def __init__(self):
  185. super(BiasAdd, self).__init__()
  186. def construct(self, x, bias):
  187. shape = P.Shape()(x)
  188. if len(shape) == 4:
  189. bias_shape = (1, P.Shape()(bias)[0], 1, 1) # NCHW
  190. else:
  191. bias_shape = (1, P.Shape()(bias)[0])
  192. res = x + P.Reshape()(bias, bias_shape)
  193. return res
  194. class BiasAddGrad(GraphKernel):
  195. """
  196. Computes gradients of BiasAdd.
  197. Inputs:
  198. x (Tensor): the gradients of bias add output.
  199. Output:
  200. Tensor, the gradients of bias add input.
  201. Examples:
  202. >>> dout = Tensor(np.ones(shape=[1, 2, 3, 4]), mindspore.float32)
  203. >>> bias_add_grad = BiasAddGrad()
  204. >>> dx = bias_add_grad(dout)
  205. """
  206. def __init__(self):
  207. super(BiasAddGrad, self).__init__()
  208. def construct(self, x):
  209. shape_x = P.Shape()(x)
  210. reduce_axis = [0]
  211. for i in range(2, len(shape_x)):
  212. reduce_axis.append(i)
  213. res = P.ReduceSum()(x, reduce_axis)
  214. return res
  215. class EqualCount(GraphKernel):
  216. """
  217. Computes the number of the same elements of two tensors.
  218. The two input tensors must have the same shape and data type.
  219. Inputs:
  220. x (Tensor): the first input tensor.
  221. y (Tensor): the second input tensor.
  222. Outputs:
  223. Tensor, the type is same as input tensor and size as (1,).
  224. Examples:
  225. >>> x = Tensor(np.array([1, 2, 3]), mindspore.int32)
  226. >>> y = Tensor(np.array([1, 2, 4]), mindspore.int32)
  227. >>> equal_count = EqualCount()
  228. >>> equal_count(x, y)
  229. """
  230. def __init__(self):
  231. super(EqualCount, self).__init__()
  232. def construct(self, x, y):
  233. equal_bool = P.Equal()(P.Cast()(x, mstype.float32), P.Cast()(y, mstype.float32))
  234. equal_count = P.Cast()(equal_bool, mstype.float16)
  235. axes = (0,)
  236. res = P.ReduceSum()(equal_count, axes)
  237. res = P.Cast()(res, P.DType()(x))
  238. return res
  239. class ReduceMean(GraphKernel):
  240. """
  241. Reduce a dimension of a tensor by averaging all elements in the dimension.
  242. The dtype of the tensor to be reduced is number.
  243. Args:
  244. keep_dims (bool): If true, keep these reduced dimensions and the length is 1.
  245. If false, don't keep these dimensions. Default: False.
  246. Inputs:
  247. - **input_x** (Tensor[Number]) - The input tensor.
  248. - **axis** (Union[int, tuple(int), list(int)]) - The dimensions to reduce. Default: (), reduce all dimensions.
  249. Only constant value is allowed.
  250. Outputs:
  251. Tensor, has the same dtype as the `input_x`.
  252. - If axis is (), and keep_dims is False,
  253. the output is a 0-D tensor representing the sum of all elements in the input tensor.
  254. - If axis is int, set as 2, and keep_dims is False,
  255. the shape of output is :math:`(x_1, x_3, ..., x_R)`.
  256. - If axis is tuple(int), set as (2, 3), and keep_dims is False,
  257. the shape of output is :math:`(x_1, x_4, ..., x_R)`.
  258. Examples:
  259. >>> input_x = Tensor(np.random.randn(3, 4, 5, 6).astype(np.float32))
  260. >>> op = ReduceMean(keep_dims=True)
  261. >>> output = op(input_x, 1)
  262. """
  263. def __init__(self, keep_dims=True):
  264. super(ReduceMean, self).__init__()
  265. self.keep_dims = Validator.check_type('keep_dims', keep_dims, [bool])
  266. self.sum = P.ReduceSum(self.keep_dims)
  267. def construct(self, x, axis):
  268. shape = P.Shape()(x)
  269. value_num = 1
  270. for i in axis:
  271. value_num *= shape[i]
  272. data_sum = self.sum(x, axis)
  273. avg = 1.0 / P.Fill()(P.DType()(x), (1,), value_num)
  274. res = data_sum * avg
  275. return res
  276. class ReLU(GraphKernel):
  277. r"""
  278. Computes ReLU(Rectified Linear Unit) of input tensor element-wise.
  279. It returns :math:`\max(x,\ 0)` element-wise.
  280. Inputs:
  281. - **input_x** (Tensor) - The input tensor.
  282. Outputs:
  283. Tensor, with the same type and shape as the `input_x`.
  284. Examples:
  285. >>> input_x = Tensor(np.array([[-1.0, 4.0, -8.0], [2.0, -5.0, 9.0]]), mindspore.float32)
  286. >>> relu = ReLU()
  287. >>> result = relu(input_x)
  288. [[0, 4.0, 0.0], [2.0, 0.0, 9.0]]
  289. """
  290. def __init__(self):
  291. super(ReLU, self).__init__()
  292. self.max = P.Maximum()
  293. def construct(self, x):
  294. return self.max(P.Fill()(P.DType()(x), P.Shape()(x), 0.0), x)
  295. class SoftmaxCrossEntropyWithLogits(GraphKernel):
  296. r"""
  297. Gets the softmax cross-entropy value between logits and labels which shoule be one-hot encoding.
  298. Note:
  299. Sets input logits as `X`, input label as `Y`, output as `loss`. Then,
  300. .. math::
  301. p_{ij} = softmax(X_{ij}) = \frac{exp(x_i)}{\sum_{j = 0}^{N-1}\exp(x_j)}
  302. .. math::
  303. loss_{ij} = -\sum_j{Y_{ij} * ln(p_{ij})}
  304. Inputs:
  305. - **logits** (Tensor) - Input logits, with shape :math:`(N, C)`.
  306. - **labels** (Tensor) - Ground truth labels, with shape :math:`(N, C)`.
  307. Outputs:
  308. Tuple of 2 Tensors, the loss shape is `(N,)`, and the dlogits with the same shape as `logits`.
  309. Examples:
  310. >>> logits = Tensor([[2, 4, 1, 4, 5], [2, 1, 2, 4, 3]], mindspore.float32)
  311. >>> labels = Tensor([[0, 0, 0, 0, 1], [0, 0, 0, 1, 0]], mindspore.float32)
  312. >>> softmax_cross = SoftmaxCrossEntropyWithLogits()
  313. >>> loss, backprop = softmax_cross(logits, labels)
  314. """
  315. def __init__(self):
  316. super(SoftmaxCrossEntropyWithLogits, self).__init__()
  317. self.max = P.ReduceMax(keep_dims=True)
  318. self.sum_keep_dims = P.ReduceSum(keep_dims=True)
  319. def construct(self, features, labels):
  320. data_max = self.max(features, (1,))
  321. data_sub = features - data_max
  322. data_exp = P.Exp()(data_sub)
  323. data_sum = self.sum_keep_dims(data_exp, (1,))
  324. data_div = data_exp / data_sum
  325. data_log_tmp = P.Log()(data_sum)
  326. data_log = data_sub - data_log_tmp
  327. data_mul = labels * data_log
  328. data_muls = P.Neg()(data_mul)
  329. loss = P.ReduceSum()(data_muls, (1,))
  330. backprop = data_div - labels
  331. return loss, backprop
  332. def bprop(self, features, labels, out, dout):
  333. grad = out[1]
  334. grad = grad * P.ExpandDims()(dout[0], -1)
  335. return grad, P.ZerosLike()(labels)
  336. class LayerNormForward(GraphKernel):
  337. """ Forward function of the LayerNorm operator. """
  338. def __init__(self, begin_norm_axis=1, begin_params_axis=1):
  339. super(LayerNormForward, self).__init__()
  340. self.begin_norm_axis = Validator.check_type('begin_norm_axis', begin_norm_axis, [int])
  341. self.begin_params_axis = Validator.check_type('begin_params_axis', begin_params_axis, [int])
  342. self.mul = P.Mul()
  343. self.sum_keep_dims = P.ReduceSum(keep_dims=True)
  344. self.sub = P.Sub()
  345. self.add = P.TensorAdd()
  346. self.log = P.Log()
  347. self.exp = P.Exp()
  348. self.eps = P.Eps()
  349. def construct(self, input_x, input_gamma, input_beta):
  350. shape_x = P.Shape()(input_x)
  351. # Calculate the scaling ratio of the average
  352. begin_norm_axis = self.begin_norm_axis
  353. if begin_norm_axis < 0:
  354. begin_norm_axis += len(shape_x)
  355. reduce_axis = ()
  356. for i in range(len(shape_x)):
  357. if i > begin_norm_axis or i == begin_norm_axis:
  358. reduce_axis = reduce_axis + (i,)
  359. reduce_elts = 1.0
  360. for i in reduce_axis:
  361. reduce_elts *= shape_x[i]
  362. mean_cof = 1.0 / reduce_elts
  363. # Calculate mean
  364. mean_muls = self.mul(input_x, mean_cof)
  365. mean = self.sum_keep_dims(mean_muls, reduce_axis)
  366. # Calculate variance
  367. variance_sub = self.sub(input_x, mean)
  368. variance_mul = self.mul(variance_sub, variance_sub)
  369. variance_muls = self.mul(variance_mul, mean_cof)
  370. variance = self.sum_keep_dims(variance_muls, reduce_axis)
  371. # Calculate normalize
  372. normalize_sub = self.sub(input_x, mean)
  373. epsilon = self.eps(input_x)
  374. normalize_add = self.add(variance, epsilon)
  375. normalize_log = self.log(normalize_add)
  376. normalize_log_mul = self.mul(normalize_log, -0.5)
  377. normalize_exp = self.exp(normalize_log_mul)
  378. normalize_mul = self.mul(normalize_sub, normalize_exp)
  379. # Calculate scale and translate
  380. if self.begin_params_axis == 0:
  381. scale_mul = self.mul(input_gamma, normalize_mul)
  382. res = self.add(scale_mul, input_beta)
  383. else:
  384. scale_mul = self.mul(input_gamma, normalize_mul)
  385. res = self.add(scale_mul, input_beta)
  386. return res, mean, variance
class LayerNormXBackprop(GraphKernel):
    r"""
    Together with LayerNormBetaGammaBackprop, to supply the backprop
    functionality for LayerNorm.

    Note:
        Sets input_x as :math:`x_i`, variance as :math:`\sigma^2`, mean as :math:`\mu`,
        input_gamma as :math:`\gamma`. Then,

        .. math::
            \begin{array}{ll} \\
                \hat{x_i} = \frac{x_i - \mu}{\sqrt{\sigma^2 + \epsilon}} \\
                \frac {\partial L} {\partial x_i} =
                \frac{\gamma}{\sqrt{\sigma^2+\epsilon}}
                ( \frac{\partial L}{\partial y_i}
                - \frac{1}{m} \cdot \frac{\partial L}{\partial \beta}
                - \frac{\hat{x_i}}{m} \cdot \frac{\partial L}{\partial \gamma})
            \end{array}

    Inputs:
        - **dy** (Tensor) - The first item of the next operator's backprop's output.
        - **input_x** (Tensor) - The first input of the forward function of LayerNorm.
        - **variance** (Tensor) - The variance output of the forward LayerNorm.
        - **mean** (Tensor) - The mean output of the forward LayerNorm.
        - **input_gamma** (Tensor) - The gamma input of the forward LayerNorm.

    Outputs:
        Tensor, will be used as the first item of the result of LayerNorm's
        backprop function; has the same shape and data type as `input_x`.

    Examples:
        >>> dy = Tensor(np.random.randn(3, 4, 5, 6).astype(np.float32))
        >>> input_x = Tensor(np.random.randn(3, 4, 5, 6).astype(np.float32))
        >>> variance = Tensor(np.random.randn(3, 4, 5, 6).astype(np.float32))
        >>> mean = Tensor(np.random.randn(3, 4, 5, 6).astype(np.float32))
        >>> input_gamma = Tensor(np.random.randn(3, 4, 5, 6).astype(np.float32))
        >>> op = LayerNormXBackprop()
        >>> output = op(dy, input_x, variance, mean, input_gamma)
    """

    def __init__(self):
        super(LayerNormXBackprop, self).__init__()
        self.sum_keep_dims = P.ReduceSum(keep_dims=True)
        self.log = P.Log()
        self.exp = P.Exp()
        self.eps = P.Eps()

    def construct(self, dy, input_x, variance, mean, input_gamma):
        shape_x = P.Shape()(input_x)
        shape_mean = P.Shape()(mean)
        # Recover the reduce axes of the forward pass: the first axis where
        # x's shape and mean's shape differ starts the normalized range.
        reduce_axis = ()
        flag = -1
        # NOTE(review): despite its name, min_l is set to the LARGER of the
        # two ranks; the loop below would index out of range if the ranks
        # ever differed. Presumably both always have the same rank because
        # the forward mean is produced with keep_dims=True — confirm.
        min_l = 0
        if len(shape_x) > len(shape_mean):
            min_l = len(shape_x)
        else:
            min_l = len(shape_mean)
        for i in range(min_l):
            if (shape_x[i] != shape_mean[i]) and (flag == -1):
                flag = i
        if flag != -1:
            for i in range(flag, len(shape_x)):
                reduce_axis = reduce_axis + (i,)
        else:
            # Shapes match everywhere: fall back to the last axis.
            reduce_axis = reduce_axis + (len(shape_x) - 1,)
        # m: number of elements normalized over.
        mean_num = 1.0
        for i in reduce_axis:
            mean_num *= shape_x[i]
        # dL/dxhat = gamma * dy
        pd_xl = input_gamma * dy
        # rstd = (variance + eps)^-0.5, built as exp(-0.5 * log(var + eps)).
        epsilon = self.eps(input_x)
        var_elta = variance + epsilon
        var_elta_log = self.log(var_elta)
        var_elta_mul = var_elta_log * -0.5
        var_elta_2 = P.Exp()(var_elta_mul)
        # rstd^3, used in the variance gradient.
        pdvar1_mul = var_elta_2 * var_elta_2
        pd_var_1 = pdvar1_mul * var_elta_2
        sub_x_mean = input_x - mean
        # dL/dvar = -0.5 * sum(dL/dxhat * (x - mean)) * rstd^3
        pdvar_mul1 = pd_xl * sub_x_mean
        pdvar_sum = self.sum_keep_dims(pdvar_mul1, reduce_axis)
        pdvar_mul3 = pdvar_sum * pd_var_1
        pd_var = pdvar_mul3 * -0.5
        # dL/dmean = -sum(dL/dxhat) * rstd + dL/dvar * mean(-2 * (x - mean))
        pdmean1_sum = self.sum_keep_dims(pd_xl, reduce_axis)
        pdmean1_mul = pdmean1_sum * var_elta_2
        pd_mean_1 = pdmean1_mul * -1.0
        pdmean2_mul1 = sub_x_mean * -2.0
        pdmean2_sum = self.sum_keep_dims(pdmean2_mul1, reduce_axis)
        pdmean2_mul3 = pdmean2_sum * (1.0 / mean_num)
        pd_mean_2 = pd_var * pdmean2_mul3
        pd_mean = pd_mean_2 + pd_mean_1
        # dL/dx = dL/dxhat * rstd + dL/dvar * 2(x - mean)/m + dL/dmean / m
        pd_x_1 = var_elta_2 * pd_xl
        pdx2_mul = pd_var * sub_x_mean
        pd_x_2 = pdx2_mul * (2.0 * (1.0 / mean_num))
        pd_x_3 = pd_mean * (1.0 / mean_num)
        pdx_add = pd_x_1 + pd_x_2
        pd_x = pdx_add + pd_x_3
        return pd_x
class LayerNormBetaGammaBackprop(GraphKernel):
    r"""
    Together with LayerNormXBackprop, to supply the backprop functionality for
    LayerNorm.

    Note:
        Sets input_x as :math:`x_i`, variance as :math:`\sigma^2`, mean as :math:`\mu`,
        input_gamma as :math:`\gamma`. Then,

        .. math::
            \begin{array}{ll} \\
                \hat{x_i} = \frac{x_i - \mu}{\sqrt{\sigma^2 + \epsilon}} \\
                \frac {\partial L} {\partial \beta} =
                \sum_{i=1}^m \frac{\partial L}{\partial y_i} \\
                \frac {\partial L} {\partial \gamma} =
                \sum_{i=1}^m \frac{\partial L}{\partial y_i} \cdot \hat{x_i}
            \end{array}

    Inputs:
        - **dy** (Tensor) - The first item of the next operator's backprop's output.
        - **input_x** (Tensor) - The first input of the forward function of LayerNorm.
        - **variance** (Tensor) - The variance output of the forward LayerNorm.
        - **mean** (Tensor) - The mean output of the forward LayerNorm.
        - **shape_gamma** (tuple) - The shape of gamma. NOTE(review): the code
          only ever calls ``len(shape_gamma)``, so this is a shape, not the
          gamma Tensor the original docstring claimed — confirm with callers.

    Outputs:
        Tuple of 2 Tensors, the backprop outputs.

        - **pd_beta** (Tensor) - Used as the second item of LayerNorm's
          backprop function.
        - **pd_gamma** (Tensor) - Used as the third item of LayerNorm's
          backprop function.

    Examples:
        >>> dy = Tensor(np.random.randn(3, 4, 5, 6).astype(np.float32))
        >>> input_x = Tensor(np.random.randn(3, 4, 5, 6).astype(np.float32))
        >>> variance = Tensor(np.random.randn(3, 4, 5, 6).astype(np.float32))
        >>> mean = Tensor(np.random.randn(3, 4, 5, 6).astype(np.float32))
        >>> op = LayerNormBetaGammaBackprop()
        >>> pd_beta, pd_gamma = op(dy, input_x, variance, mean, (6,))
    """

    def __init__(self):
        super(LayerNormBetaGammaBackprop, self).__init__()
        self.sum_not_keep_dims = P.ReduceSum(keep_dims=False)
        self.log = P.Log()
        self.exp = P.Exp()
        self.eps = P.Eps()

    def construct(self, dy, input_x, variance, mean, shape_gamma):
        shape_x = P.Shape()(input_x)
        # Sum over the leading axes of x that gamma does not have.
        params_axis = ()
        if len(shape_x) != len(shape_gamma):
            sub = len(shape_x) - len(shape_gamma)
            for i in range(sub):
                params_axis = params_axis + (i,)
        # dL/dbeta = sum(dy)
        pd_beta = self.sum_not_keep_dims(dy, params_axis)
        # rstd = (variance + eps)^-0.5, built as exp(-0.5 * log(var + eps)).
        epsilon = self.eps(input_x)
        var_elta = variance + epsilon
        var_elta_log = self.log(var_elta)
        var_elta_mul = var_elta_log * -0.5
        var_elta_2 = P.Exp()(var_elta_mul)
        sub_x_mean = input_x - mean
        var_elta_2_cast = var_elta_2
        # xhat = (x - mean) * rstd; dL/dgamma = sum(dy * xhat)
        xl_mul = var_elta_2_cast * sub_x_mean
        pdga_mul = dy * xl_mul
        pd_gamma = self.sum_not_keep_dims(pdga_mul, params_axis)
        return pd_beta, pd_gamma
  537. class LogSoftmax(GraphKernel):
  538. r"""
  539. Log Softmax activation function.
  540. Applies the Log Softmax function to the input tensor on the specified axis.
  541. Suppose a slice in the given aixs :math:`x` then for each element :math:`x_i`
  542. the Log Softmax function is shown as follows:
  543. .. math::
  544. \text{output}(x_i) = \log \left(\frac{exp(x_i)} {\sum_{j = 0}^{N-1}\exp(x_j)}\right),
  545. where :math:`N` is the length of the Tensor.
  546. Args:
  547. axis (int): The axis to do the Log softmax operation. Default: -1.
  548. Inputs:
  549. logits (Tensor): The input of Log Softmax.
  550. Outputs:
  551. Tensor, with the same type and shape as the logits.
  552. Examples:
  553. >>> input_x = Tensor(np.array([1, 2, 3, 4, 5]), mindspore.float32)
  554. >>> log_softmax = LogSoftmax()
  555. >>> log_softmax(input_x)
  556. [-4.4519143, -3.4519143, -2.4519143, -1.4519144, -0.4519144]
  557. """
  558. def __init__(self, axis=-1):
  559. super(LogSoftmax, self).__init__()
  560. self.axis = Validator.check_type('axis', axis, [int])
  561. self.max_keep_dims = P.ReduceMax(keep_dims=True)
  562. self.sub = P.Sub()
  563. self.exp = P.Exp()
  564. self.sum_keep_dims = P.ReduceSum(keep_dims=True)
  565. self.log = P.Log()
  566. self.mul = P.Mul()
  567. def construct(self, input_x):
  568. data_max = self.max_keep_dims(input_x, (self.axis,))
  569. data_sub = self.sub(input_x, data_max)
  570. data_exp = self.exp(data_sub)
  571. data_sum = self.sum_keep_dims(data_exp, (self.axis,))
  572. data_log = self.log(data_sum)
  573. res = self.sub(data_sub, data_log)
  574. return res
  575. def bprop(self, input_x, out, dout):
  576. input_x = out
  577. input_dy = dout
  578. data_exp = self.exp(input_x)
  579. data_sum = self.sum_keep_dims(input_dy, (self.axis,))
  580. data_softmax = self.mul(data_exp, data_sum)
  581. res = self.sub(input_dy, data_softmax)
  582. return (res,)
  583. class Tanh(GraphKernel):
  584. r"""
  585. Tanh activation function.
  586. Computes hyperbolic tangent of input element-wise. The Tanh function is defined as:
  587. .. math::
  588. tanh(x_i) = \frac{\exp(x_i) - \exp(-x_i)}{\exp(x_i) + \exp(-x_i)} = \frac{\exp(2x_i) - 1}{\exp(2x_i) + 1},
  589. where :math:`x_i` is an element of the input Tensor.
  590. Inputs:
  591. - **input_x** (Tensor) - The input of Tanh.
  592. Outputs:
  593. Tensor, with the same type and shape as the input_x.
  594. Examples:
  595. >>> input_x = Tensor(np.array([1, 2, 3, 4, 5]), mindspore.float32)
  596. >>> tanh = Tanh()
  597. >>> tanh(input_x)
  598. [0.7615941, 0.9640276, 0.9950548, 0.9993293, 0.99990916]
  599. """
  600. def __init__(self):
  601. super(Tanh, self).__init__()
  602. self.abs = P.Abs()
  603. self.add = P.TensorAdd()
  604. self.div = P.RealDiv()
  605. self.mul = P.Mul()
  606. self.mul_fp16 = P.Mul()
  607. self.mul_fp16.add_prim_attr("output_precision", "float16")
  608. self.exp = P.Exp()
  609. def construct(self, input_x):
  610. input_abs = self.abs(input_x)
  611. sign_flag = self.div(input_x, input_abs)
  612. sign_flag_neg = self.mul(sign_flag, -1.0)
  613. power_val = self.mul(input_abs, -2.0)
  614. exp_val = self.exp(power_val)
  615. up_val = self.add(exp_val, -1.0)
  616. down_val = self.add(exp_val, 1.0)
  617. div_val = self.div(up_val, down_val)
  618. res = self.mul(sign_flag_neg, div_val)
  619. return res
  620. def bprop(self, input_x, out, dout):
  621. input_y = out
  622. input_dy = dout
  623. data_square = self.mul(input_y, input_y)
  624. data_mul = self.mul(data_square, -1.0)
  625. anuminate = self.add(data_mul, 1.0)
  626. res = self.mul_fp16(anuminate, input_dy)
  627. return (res,)
  628. class TanhGrad(GraphKernel):
  629. """
  630. Backprop function of Tanh
  631. Mathematical calculating:
  632. result = Tanh(out)
  633. result = 1 - result * result
  634. result = result * dout
  635. Inputs:
  636. out (Tensor): Tanh's output
  637. dout (Tensor): next layer's backward function's output, has same shape as out
  638. Outputs:
  639. result (Tensor): result of (1 - tanh(out)^2) * dout
  640. Examples:
  641. >>> x_np = np.random.randn(5, 3, 6).astype(np.float16)
  642. >>> dy_np = np.random.randn(5, 3, 6).astype(np.float16)
  643. >>> x_ms = Tensor(x_np)
  644. >>> dy_ms = Tensor(dy_np)
  645. >>> tanh_grad = TanhGrad()
  646. >>> out = tanh_grad(x_np, dy_np)
  647. """
  648. def __init__(self):
  649. super(TanhGrad, self).__init__()
  650. self.add = P.TensorAdd()
  651. self.mul = P.Mul()
  652. self.mul_fp16 = P.Mul()
  653. self.mul_fp16.add_prim_attr("output_precision", "float16")
  654. def construct(self, out, dout):
  655. input_y = out
  656. input_dy = dout
  657. data_square = self.mul(input_y, input_y)
  658. data_mul = self.mul(data_square, -1.0)
  659. anuminate = self.add(data_mul, 1.0)
  660. res = self.mul_fp16(anuminate, input_dy)
  661. return res
  662. class Gelu(GraphKernel):
  663. r"""
  664. Gaussian Error Linear Units activation function.
  665. GeLU is described in the paper `Gaussian Error Linear Units (GELUs) <https://arxiv.org/abs/1606.08415>`_.
  666. And also please refer to `BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding.
  667. <https://arxiv.org/abs/1810.04805>`_.
  668. Defined as follows:
  669. .. math::
  670. \text{output} = 0.5 * x * (1 + erf(x / \sqrt{2})),
  671. where :math:`erf` is the "Gauss error function" .
  672. Inputs:
  673. - **input_x** (Tensor) - Input to compute the Gelu.
  674. Outputs:
  675. Tensor, with the same type and shape as input.
  676. Examples:
  677. >>> tensor = Tensor(np.array([1.0, 2.0, 3.0]), mindspore.float32)
  678. >>> gelu = Gelu()
  679. >>> result = gelu(tensor)
  680. """
  681. def __init__(self):
  682. super(Gelu, self).__init__()
  683. self.add = P.TensorAdd()
  684. self.abs = P.Abs()
  685. self.exp = P.Exp()
  686. self.neg = P.Neg()
  687. self.minimum = P.Minimum()
  688. self.div = P.RealDiv()
  689. self.mul = P.Mul()
  690. self.CSVALUE = 0.044715
  691. self.CSVALUE_A = 1.59576912
  692. self.CSVALUE_5 = 0.3989422804
  693. self.CSVALUE_3B = 0.2140644488
  694. def construct(self, input_x):
  695. def _tanh_parameter_compute(data_x):
  696. """
  697. compute the parameter of tanh:
  698. return: result equal (x+0.044715*tf.pow(x,3))
  699. """
  700. mul_0 = self.mul(data_x, data_x)
  701. pow_0 = self.mul(mul_0, data_x)
  702. mul_1 = self.mul(pow_0, self.CSVALUE)
  703. result = self.add(data_x, mul_1)
  704. return result
  705. tanh_parameter = _tanh_parameter_compute(input_x)
  706. mul_0 = self.mul(tanh_parameter, 1.5957691)
  707. mul_0_min = self.minimum(mul_0, 0.0)
  708. right_mul = self.exp(mul_0_min)
  709. mul_0_abs = self.abs(mul_0)
  710. mul_0_abs_neg = self.mul(mul_0_abs, -1.0)
  711. mul_0_abs_neg_exp = self.exp(mul_0_abs_neg)
  712. mul_0_abs_neg_exp_add = self.add(mul_0_abs_neg_exp, 1.0)
  713. left_mul = self.div(input_x, mul_0_abs_neg_exp_add)
  714. result = self.mul(left_mul, right_mul)
  715. return result
  716. def bprop(self, input_x, out, dout):
  717. """ register backprop function for Gelu """
  718. data_x = input_x
  719. data_gelu = out
  720. data_dy = dout
  721. def _math_four_compute(data_x):
  722. """
  723. return: math_four equal 2*(np(sqrt(2 / np.pi)*(x + 0.044715*tf.pow(x, 3)))
  724. """
  725. datax_pow = data_x * data_x * data_x
  726. datax_muls_c = self.mul(datax_pow, self.CSVALUE)
  727. datax_addx = self.add(datax_muls_c, data_x)
  728. datax_muls_s = self.mul(datax_addx, self.CSVALUE_A)
  729. return datax_muls_s
  730. # common part
  731. math_four = _math_four_compute(data_x)
  732. math_four_abs = self.abs(math_four)
  733. math_four_abs_neg = self.mul(math_four_abs, -1.0)
  734. math_four_abs_neg_exp = self.exp(math_four_abs_neg)
  735. math_four_min = self.minimum(math_four, 0.0)
  736. # dividend part
  737. datax_pow = self.mul(data_x, data_x)
  738. datax_pow_mul = self.mul(datax_pow, self.CSVALUE_3B)
  739. datax_pow_mul_add = self.add(datax_pow_mul, self.CSVALUE_A)
  740. data_gelu_mul = self.mul(data_gelu, datax_pow_mul_add)
  741. math_four_min_2 = self.mul(math_four_min, 2.0)
  742. div_right = self.mul(data_gelu_mul, math_four_abs_neg_exp)
  743. div_left = self.exp(math_four_min_2)
  744. dividend = self.add(div_left, div_right)
  745. # divisor part
  746. div_0 = self.add(math_four_abs_neg_exp, 1.0)
  747. div_1 = self.exp(math_four_min)
  748. divisor = self.mul(div_1, div_0)
  749. res_grad = self.div(dividend, divisor)
  750. result = self.mul(res_grad, data_dy)
  751. return (result,)
  752. class Softmax(GraphKernel):
  753. """
  754. Operator Softmax
  755. .. math: `exp(x-max(x)) / sum(exp(x-max(x)))`
  756. Args:
  757. axis (int, tuple): Axis along which the softmax normalization is applied
  758. Inputs:
  759. x (Tensor): input data for softmax
  760. Outputs:
  761. output (Tensor): a tensor with the same shape of the input
  762. Examples:
  763. >>> layer = Softmax(1)
  764. >>> x = Tensor(np.array([1.2, 2.1], [2.2, 3.2]), mindspore.float32)
  765. >>> output = layer(x)
  766. """
  767. def __init__(self, axis):
  768. super(Softmax, self).__init__()
  769. Validator.check_type("axis", axis, [int, tuple])
  770. if isinstance(axis, int):
  771. self.axis = (axis,)
  772. else:
  773. self.axis = axis
  774. for item in self.axis:
  775. Validator.check_type("item of axis", item, [int])
  776. self.max = P.ReduceMax(keep_dims=True)
  777. self.sub = P.Sub()
  778. self.exp = P.Exp()
  779. self.sum = P.ReduceSum(keep_dims=True)
  780. self.mul = P.Mul()
  781. def construct(self, x):
  782. max_x = self.max(x, self.axis)
  783. data_sub = self.sub(x, max_x)
  784. data_exp = self.exp(data_sub)
  785. data_expsum = self.sum(data_exp, self.axis)
  786. output = data_exp / data_expsum
  787. return output
  788. def bprop(self, x, out, dout):
  789. mul_res = self.mul(dout, out)
  790. sum_res = self.sum(mul_res, self.axis)
  791. sub_res = self.sub(dout, sum_res)
  792. res = self.mul(sub_res, out)
  793. return (res,)
  794. class LayerNorm(Cell):
  795. r"""
  796. Applies Layer Normalization over a mini-batch of inputs.
  797. Layer normalization is widely used in recurrent neural networks. It applies
  798. normalization on a mini-batch of inputs for each single training case as described
  799. in the paper `Layer Normalization <https://arxiv.org/pdf/1607.06450.pdf>`_. Unlike batch
  800. normalization, layer normalization performs exactly the same computation at training and
  801. testing time. It can be described using the following formula. It is applied across all channels
  802. and pixel but only one batch size.
  803. .. math::
  804. y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta
  805. Args:
  806. normalized_shape (Union(tuple[int], list[int]): The normalization is performed over axis
  807. `begin_norm_axis ... R - 1`.
  808. begin_norm_axis (int): It first normalization dimension: normalization will be performed along dimensions
  809. `begin_norm_axis: rank(inputs)`, the value must be in [-1, rank(input)). Default: -1.
  810. begin_params_axis (int): The first parameter(beta, gamma)dimension: scale and centering parameters
  811. will have dimensions `begin_params_axis: rank(inputs)` and will be broadcast with
  812. the normalized inputs accordingly, the value must be in [-1, rank(input)). Default: -1.
  813. gamma_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the gamma weight.
  814. The values of str refer to the function `initializer` including 'zeros', 'ones', 'xavier_uniform',
  815. 'he_uniform', etc. Default: 'ones'.
  816. beta_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the beta weight.
  817. The values of str refer to the function `initializer` including 'zeros', 'ones', 'xavier_uniform',
  818. 'he_uniform', etc. Default: 'zeros'.
  819. Inputs:
  820. - **input_x** (Tensor) - The shape of 'input_x' is :math:`(x_1, x_2, ..., x_R)`,
  821. and `input_shape[begin_norm_axis:]` is equal to `normalized_shape`.
  822. Outputs:
  823. Tensor, the normalized and scaled offset tensor, has the same shape and data type as the `input_x`.
  824. Examples:
  825. >>> x = Tensor(np.ones([20, 5, 10, 10]), mindspore.float32)
  826. >>> shape1 = x.shape[1:]
  827. >>> m = G.LayerNorm(shape1, begin_norm_axis=1, begin_params_axis=1)
  828. >>> m(x)
  829. """
  830. def __init__(self,
  831. begin_norm_axis=-1,
  832. begin_params_axis=-1
  833. ):
  834. super(LayerNorm, self).__init__()
  835. self.begin_norm_axis = begin_norm_axis
  836. self.begin_params_axis = begin_params_axis
  837. self.layer_norm = LayerNormForward(begin_norm_axis, begin_params_axis)
  838. self.layer_norm_x_grad = LayerNormXBackprop()
  839. self.layer_norm_beta_gamma = LayerNormBetaGammaBackprop()
  840. self.layer_norm_grad = G.LayerNormGrad(self.begin_norm_axis, self.begin_params_axis)
  841. def construct(self, input_x, input_gamma, input_beta):
  842. return self.layer_norm(input_x, input_gamma, input_beta)
  843. # case 1
  844. def bprop(self, input_x, input_gamma, input_beta, out, dout):
  845. dx, d_gamma, d_beta = self.layer_norm_grad(input_x, dout[0], out[2], dout[1], input_gamma)
  846. return dx, d_gamma, d_beta
  847. class LambUpdateWithLR(GraphKernel):
  848. r"""
  849. Part of Lamb optimizer.
  850. .. math::
  851. s_1 = select(i_1 \gt y_g, select(i_0 \gt y_g, \frac{i_1}{i_2}, se), se)
  852. i_5 = i_5 - max(min(s_1, y_m), y_g) \times i_3 \times i_4
  853. Inputs:
  854. - **input0** (Tensor) - The first tensor to be computed.
  855. - **input1** (Tensor) - The second tensor to be computed.
  856. - **input2** (Tensor) - The third tensor to be computed.
  857. - **input3** (Tensor) - The fourth tensor to be computed.
  858. - **input4** (Tensor) - The fifth tensor to be computed.
  859. - **input5** (Tensor) - The sixth tensor to be computed. It will be updated by result.
  860. - **greater_y** (Tensor) - The seventh tensor to be computed.
  861. - **select_e** (Tensor) - The eighth tensor to be computed.
  862. - **minimum_y** (Tensor) - The ninth tensor to be computed.
  863. Outputs:
  864. A fake output tensor.
  865. Examples:
  866. >>> lamb_update = LambUpdateWithLR()
  867. >>> i0 = np.random.normal(0, 1, [1, 16]).astype(np.float32)
  868. >>> i1 = np.random.normal(0, 1, [1]).astype(np.float32)
  869. >>> i2 = np.random.normal(0, 1, [1]).astype(np.float32)
  870. >>> i3 = np.random.normal(0, 1, [1]).astype(np.float32)
  871. >>> i4 = np.random.normal(0, 1, [1, 16]).astype(np.float32)
  872. >>> i5 = np.random.normal(0, 1, [1, 16]).astype(np.float32)
  873. >>> yg = np.random.normal(0, 1, [1]).astype(np.float32)
  874. >>> se = np.random.normal(0, 1, [1]).astype(np.float32)
  875. >>> ym = np.random.normal(0, 1, [1]).astype(np.float32)
  876. >>> lamb_update(i0, i1, i2, i3, i4, i5, yg, se, ym)
  877. """
  878. def __init__(self):
  879. super(LambUpdateWithLR, self).__init__()
  880. self.greater = P.Greater()
  881. self.select = P.Select()
  882. self.div = P.RealDiv()
  883. self.min = P.Minimum()
  884. self.max = P.Maximum()
  885. self.mul = P.Mul()
  886. self.sub = P.Sub()
  887. self.fake_output_assign = InplaceAssign()
  888. self.fake_output_assign.add_prim_attr("fake_output", True)
  889. def construct(self, input0, input1, input2, input3, input4, input5, greater_y, select_e, minimum_y):
  890. greater0 = self.greater(input0, greater_y)
  891. greater1 = self.greater(input1, greater_y)
  892. real_div0 = self.div(input1, input2)
  893. select0 = self.select(greater0, real_div0, select_e)
  894. select1 = self.select(greater1, select0, select_e)
  895. min0 = self.min(select1, minimum_y)
  896. max0 = self.max(min0, greater_y)
  897. mul0 = self.mul(max0, input3)
  898. mul1 = self.mul(mul0, input4)
  899. sub0 = self.sub(input5, mul1)
  900. sub0 = self.fake_output_assign(input5, sub0, sub0)
  901. return sub0
  902. class LambNextMV(GraphKernel):
  903. r"""
  904. Part of Lamb optimizer.
  905. .. math::
  906. rd_0 = \frac{i_8 \times i_5 + i_9 \times i_4}{i6}
  907. rd_1 = \frac{x_0 \times i_2 + x_1 \times i_1}{i3}
  908. y_2 = \frac{rd_0}{\sqrt{rd_1 + x3}} + x_2 \times i_7
  909. y_3 = \frac{rd_0}{\sqrt{rd_1} + x3}
  910. i5 = i_8 \times i_5 + i_9 \times i_4
  911. i2 = x_0 \times i_2 + x_1 \times i_1
  912. Inputs:
  913. - **inputs1** (Tensor) - The first input tensor to be computed.
  914. - **inputs2** (Tensor) - The second input tensor to be computed. It will be updated by result.
  915. - **inputs3** (Tensor) - The third input tensor to be computed.
  916. - **inputs4** (Tensor) - The fourth input tensor to be computed.
  917. - **inputs5** (Tensor) - The fifth input tensor to be computed. It will be updated by result.
  918. - **inputs6** (Tensor) - The sixth input tensor to be computed.
  919. - **inputs7** (Tensor) - The seventh input tensor to be computed.
  920. - **inputs8** (Tensor) - The eighth input tensor to be computed.
  921. - **inputs9** (Tensor) - The ninth input tensor to be computed.
  922. - **inputsx0** (Tensor) - The tenth input tensor to be computed.
  923. - **inputsx1** (Tensor) - The eleventh input tensor to be computed.
  924. - **inputsx2** (Tensor) - The twelfth input tensor to be computed.
  925. - **inputsx3** (Tensor) - The thirteenth input tensor to be computed.
  926. Outputs:
  927. Tuple of 2 Tensors.
  928. - **add3** (Tensor) - the shape is the same as the one after broadcasting, and the data type is
  929. the one with higher precision or higher digits among the inputs.
  930. - **realdiv4** (Tensor) - the shape is the same as the one after broadcasting, and the data type is
  931. the one with higher precision or higher digits among the inputs.
  932. Examples:
  933. >>> lamb_next_mv = LambNextMV()
  934. >>> i1 = Tensor(np.random.normal(0, 1, [1, 16]).astype(np.float32))
  935. >>> i2 = Tensor(np.random.normal(0, 1, [1, 16]).astype(np.float32))
  936. >>> i3 = Tensor(np.random.normal(0, 1, [1, 16]).astype(np.float32))
  937. >>> i4 = Tensor(np.random.normal(0, 1, [1, 16]).astype(np.float32))
  938. >>> i5 = Tensor(np.random.normal(0, 1, [1, 16]).astype(np.float32))
  939. >>> i6 = Tensor(np.random.normal(0, 1, [1, 16]).astype(np.float32))
  940. >>> i7 = Tensor(np.random.normal(0, 1, [1, 16]).astype(np.float32))
  941. >>> i8 = Tensor(np.random.normal(0, 1, [1, 16]).astype(np.float32))
  942. >>> i9 = Tensor(np.random.normal(0, 1, [1, 16]).astype(np.float32))
  943. >>> x0 = Tensor(np.random.normal(0, 1, [1, 16]).astype(np.float32))
  944. >>> x1 = Tensor(np.random.normal(0, 1, [1, 16]).astype(np.float32))
  945. >>> x2 = Tensor(np.random.normal(0, 1, [1, 16]).astype(np.float32))
  946. >>> x3 = Tensor(np.ones([1, 16]).astype(np.float32) * 1e-6)
  947. >>> lamb_next_mv(i1, i2, i3, i4, i5, i6, i7, i8, i9, x0, x1, x2, x3)
  948. """
  949. def __init__(self):
  950. super(LambNextMV, self).__init__()
  951. self.mul = P.Mul()
  952. self.add = P.TensorAdd()
  953. self.div = P.RealDiv()
  954. self.sqrt = P.Sqrt()
  955. self.rsqrt = P.Rsqrt()
  956. self.fake_output_assign_1 = InplaceAssign()
  957. self.fake_output_assign_1.add_prim_attr("fake_output", False)
  958. self.fake_output_assign_2 = InplaceAssign()
  959. self.fake_output_assign_2.add_prim_attr("fake_output", False)
  960. def construct(self, input1, input2, input3, input4, input5, input6, input7,
  961. input8, input9, inputx0, inputx1, inputx2, inputx3):
  962. mul3 = self.mul(inputx1, input1)
  963. mul2 = self.mul(inputx0, input2)
  964. add1 = self.add(mul2, mul3)
  965. realdiv1 = self.div(add1, input3)
  966. add2 = self.add(realdiv1, inputx3)
  967. sqrt0 = self.rsqrt(add2)
  968. sqrt1 = self.sqrt(realdiv1)
  969. add4 = self.add(sqrt1, inputx3)
  970. mul1 = self.mul(input9, input4)
  971. mul0 = self.mul(input8, input5)
  972. add0 = self.add(mul0, mul1)
  973. realdiv0 = self.div(add0, input6)
  974. realdiv2 = self.mul(realdiv0, sqrt0)
  975. realdiv4 = self.div(realdiv0, add4)
  976. mul4 = self.mul(inputx2, input7)
  977. add3 = self.add(realdiv2, mul4)
  978. add3 = self.fake_output_assign_1(input5, add0, add3)
  979. add3 = self.fake_output_assign_2(input2, add1, add3)
  980. return add3, realdiv4