You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

nn_ops.py 123 kB

6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
71778177917801781178217831784178517861787178817891790179117921793179417951796179717981799180018011802180318041805180618071808180918101811181218131814181518161817181818191820182118221823182418251826182718281829183018311832183318341835183618371838183918401841184218431844184518461847184818491850185118521853185418551856185718581859186018611862186318641865186618671868186918701871187218731874187518761877187818791880188118821883188418851886188718881889189018911892189318941895189618971898189919001901190219031904190519061907190819091910191119121913191419151916191719181919192019211922192319241925192619271928192919301931193219331934193519361937193819391940194119421943194419451946194719481949195019511952195319541955195619571958195919601961196219631964196519661967196819691970197119721973197419751976197719781979198019811982198319841985198619871988198919901991199219931994199519961997199819992000200120022003200420052006200720082009201020112012201320142015201620172018201920202021202220232024202520262027202820292030203120322033203420352036203720382039204020412042204320442045204620472048204920502051205220532054205520562057205820592060206120622063206420652066206720682069207020712072207320742075207620772078207920802081208220832084208520862087208820892090209120922093209420952096209720982099210021012102210321042105210621072108210921102111211221132114211521162117211821192120212121222123212421252126212721282129213021312132213321342135213621372138213921402141214221432144214521462147214821492150215121522153215421552156215721582159216021612162216321642165216621672168216921702171217221732174217521762177217821792180218121822183218421852186218721882189219021912192219321942195219621972198219922002201220222032204220522062207220822092210221122122213221422152216221722182219222022212222222322242225222622272228222922302231223222332234223522362237223822392240224122422243224422452246224722482249225022512252225322542255225622572258225922602261226222632264226522662267226822692270227122722273227422752276227
72278227922802281228222832284228522862287228822892290229122922293229422952296229722982299230023012302230323042305230623072308230923102311231223132314231523162317231823192320232123222323232423252326232723282329233023312332233323342335233623372338233923402341234223432344234523462347234823492350235123522353235423552356235723582359236023612362236323642365236623672368236923702371237223732374237523762377237823792380238123822383238423852386238723882389239023912392239323942395239623972398239924002401240224032404240524062407240824092410241124122413241424152416241724182419242024212422242324242425242624272428242924302431243224332434243524362437243824392440244124422443244424452446244724482449245024512452245324542455245624572458245924602461246224632464246524662467246824692470247124722473247424752476247724782479248024812482248324842485248624872488248924902491249224932494249524962497249824992500250125022503250425052506250725082509251025112512251325142515251625172518251925202521252225232524252525262527252825292530253125322533253425352536253725382539254025412542254325442545254625472548254925502551255225532554255525562557255825592560256125622563256425652566256725682569257025712572257325742575257625772578257925802581258225832584258525862587258825892590259125922593259425952596259725982599260026012602260326042605260626072608260926102611261226132614261526162617261826192620262126222623262426252626262726282629263026312632263326342635263626372638263926402641264226432644264526462647264826492650265126522653265426552656265726582659266026612662266326642665266626672668266926702671267226732674267526762677267826792680268126822683268426852686268726882689269026912692269326942695269626972698269927002701270227032704270527062707270827092710271127122713271427152716271727182719272027212722272327242725272627272728272927302731273227332734273527362737273827392740274127422743274427452746274727482749275027512752275327542755275627572758275927602761276227632764276527662767276827692770277127722773277427752776277
72778
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """Operators for nn."""
  16. import math
  17. import operator
  18. from functools import reduce
  19. import numpy as np
  20. from ... import context
  21. from ..._c_expression import signature_rw as sig_rw
  22. from ..._c_expression import signature_kind as sig_kind
  23. from ..._checkparam import Validator as validator
  24. from ..._checkparam import Rel
  25. from ...common import dtype as mstype
  26. from ..primitive import Primitive, PrimitiveWithInfer, prim_attr_register
  27. from ..operations.math_ops import _infer_shape_reduce
  28. def _check_positive_int_or_tuple(arg_name, arg_value, prim_name, allow_four=False, ret_four=False):
  29. """
  30. Checks whether an argument is a positive int or tuple with 2 or 4(when allow_four is True) positive int elements.
  31. """
  32. def _raise_message():
  33. raise ValueError(f"For '{prim_name}' attr '{arg_name}' should be an positive int number or a tuple of two "
  34. f"{'or four ' if allow_four else ''}positive int numbers, but got {arg_value}")
  35. def _get_return_value():
  36. if isinstance(arg_value, int):
  37. ret = (1, 1, arg_value, arg_value) if ret_four else (arg_value, arg_value)
  38. elif len(arg_value) == 2:
  39. ret = (1, 1, arg_value[0], arg_value[1]) if ret_four else arg_value
  40. elif len(arg_value) == 4:
  41. if not allow_four:
  42. _raise_message()
  43. ret = arg_value if ret_four else (arg_value[2], arg_value[3])
  44. else:
  45. _raise_message()
  46. return ret
  47. validator.check_value_type(arg_name, arg_value, (int, tuple), prim_name)
  48. ret_value = _get_return_value()
  49. for item in ret_value:
  50. if isinstance(item, int) and item > 0:
  51. continue
  52. _raise_message()
  53. return ret_value
  54. class Flatten(PrimitiveWithInfer):
  55. r"""
  56. Flattens a tensor without changing its batch size on the 0-th axis.
  57. Inputs:
  58. - **input_x** (Tensor) - Tensor of shape :math:`(N, \ldots)` to be flattened.
  59. Outputs:
  60. Tensor, the shape of the output tensor is :math:`(N, X)`, where :math:`X` is
  61. the product of the remaining dimension.
  62. Examples:
  63. >>> input_tensor = Tensor(np.ones(shape=[1, 2, 3, 4]), mindspore.float32)
  64. >>> flatten = P.Flatten()
  65. >>> output = flatten(input_tensor)
  66. >>> assert output.shape() == (1, 24)
  67. """
  68. @prim_attr_register
  69. def __init__(self):
  70. pass
  71. def infer_shape(self, input_x):
  72. validator.check_integer('input_x rank', len(input_x), 1, Rel.GE, self.name)
  73. prod = 1 if len(input_x) == 1 else reduce(operator.mul, input_x[1:])
  74. return input_x[0], prod
  75. def infer_dtype(self, input_x):
  76. validator.check_subclass("input_x", input_x, mstype.tensor, self.name)
  77. return input_x
  78. class Softmax(PrimitiveWithInfer):
  79. r"""
  80. Softmax operation.
  81. Applies the Softmax operation to the input tensor on the specified axis.
  82. Suppose a slice along the given aixs :math:`x` then for each element :math:`x_i`
  83. the Softmax function is shown as follows:
  84. .. math::
  85. \text{output}(x_i) = \frac{exp(x_i)}{\sum_{j = 0}^{N-1}\exp(x_j)},
  86. where :math:`N` is the length of the tensor.
  87. Args:
  88. axis (Union[int, tuple]): The axis to do the Softmax operation. Default: -1.
  89. Inputs:
  90. - **logits** (Tensor) - The input of Softmax.
  91. Outputs:
  92. Tensor, with the same type and shape as the logits.
  93. """
  94. @prim_attr_register
  95. def __init__(self, axis=-1):
  96. self.init_prim_io_names(inputs=['x'], outputs=['output'])
  97. validator.check_value_type("axis", axis, [int, tuple], self.name)
  98. if isinstance(axis, int):
  99. self.add_prim_attr('axis', (axis,))
  100. for item in self.axis:
  101. validator.check_value_type("item of axis", item, [int], self.name)
  102. def infer_shape(self, logits):
  103. validator.check_integer("length of axis", len(self.axis), 1, Rel.GE, self.name)
  104. rank = len(logits)
  105. for axis_v in self.axis:
  106. validator.check_int_range("axis", axis_v, -rank, rank, Rel.INC_LEFT, self.name)
  107. return logits
  108. def infer_dtype(self, logits):
  109. validator.check_subclass("logits", logits, mstype.tensor, self.name)
  110. return logits
  111. class LogSoftmax(PrimitiveWithInfer):
  112. r"""
  113. Log Softmax activation function.
  114. Applies the Log Softmax function to the input tensor on the specified axis.
  115. Suppose a slice along the given aixs :math:`x` then for each element :math:`x_i`
  116. the Log Softmax function is shown as follows:
  117. .. math::
  118. \text{output}(x_i) = \log \left(\frac{exp(x_i)} {\sum_{j = 0}^{N-1}\exp(x_j)}\right),
  119. where :math:`N` is the length of the Tensor.
  120. Args:
  121. axis (int): The axis to do the Log softmax operation. Default: -1.
  122. Inputs:
  123. - **logits** (Tensor) - The input of Log Softmax.
  124. Outputs:
  125. Tensor, with the same type and shape as the logits.
  126. """
  127. @prim_attr_register
  128. def __init__(self, axis=-1):
  129. validator.check_value_type("axis", axis, [int], self.name)
  130. def infer_shape(self, logits):
  131. rank = len(logits)
  132. validator.check_int_range('axis', self.axis, -rank, rank, Rel.INC_LEFT, self.name)
  133. return logits
  134. def infer_dtype(self, logits):
  135. validator.check_subclass("logits", logits, mstype.tensor, self.name)
  136. return logits
  137. class ReLU(PrimitiveWithInfer):
  138. r"""
  139. Computes ReLU(Rectified Linear Unit) of input tensor element-wise.
  140. It returns :math:`\max(x,\ 0)` element-wise.
  141. Inputs:
  142. - **input_x** (Tensor) - The input tensor.
  143. Outputs:
  144. Tensor, with the same type and shape as the `input_x`.
  145. Examples:
  146. >>> input_x = Tensor(np.array([[-1.0, 4.0, -8.0], [2.0, -5.0, 9.0]]), mindspore.float32)
  147. >>> relu = P.ReLU()
  148. >>> result = relu(input_x)
  149. [[0, 4.0, 0.0], [2.0, 0.0, 9.0]]
  150. """
  151. @prim_attr_register
  152. def __init__(self):
  153. """init ReLU"""
  154. self.init_prim_io_names(inputs=['x'], outputs=['output'])
  155. def infer_shape(self, input_x):
  156. return input_x
  157. def infer_dtype(self, input_x):
  158. validator.check_tensor_type_same({'input_x': input_x}, mstype.number_type, self.name)
  159. return input_x
  160. class ReLU6(PrimitiveWithInfer):
  161. r"""
  162. Computes ReLU(Rectified Linear Unit) upper bounded by 6 of input tensor element-wise.
  163. It returns :math:`\min(\max(0,x), 6)` element-wise.
  164. Inputs:
  165. - **input_x** (Tensor) - The input tensor.
  166. Outputs:
  167. Tensor, with the same type and shape as the `input_x`.
  168. Examples:
  169. >>> input_x = Tensor(np.array([[-1.0, 4.0, -8.0], [2.0, -5.0, 9.0]]), mindspore.float32)
  170. >>> relu6 = P.ReLU6()
  171. >>> result = relu6(input_x)
  172. """
  173. @prim_attr_register
  174. def __init__(self):
  175. """init ReLU6"""
  176. self.init_prim_io_names(inputs=['x'], outputs=['output'])
  177. def infer_shape(self, input_x):
  178. return input_x
  179. def infer_dtype(self, input_x):
  180. validator.check_tensor_type_same({'input_x': input_x}, (mstype.float16, mstype.float32), self.name)
  181. return input_x
  182. class ReLUV2(PrimitiveWithInfer):
  183. r"""
  184. Computes ReLU(Rectified Linear Unit) of input tensor element-wise.
  185. It returns :math:`\max(x,\ 0)` element-wise.
  186. Inputs:
  187. - **input_x** (Tensor) - The input tensor should be a 4-D tensor.
  188. Outputs:
  189. - **output** (Tensor) - Has the same type and shape as the `input_x`.
  190. - **mask** (Tensor) - A tensor whose data type must be uint8.
  191. Examples:
  192. >>> input_x = Tensor(np.array([[[[1, -2], [-3, 4]], [[-5, 6], [7, -8]]]]), mindspore.float32)
  193. >>> relu_v2 = P.ReLUV2()
  194. >>> output = relu_v2(input_x)
  195. ([[[[1., 0.], [0., 4.]], [[0., 6.], [7., 0.]]]],
  196. [[[[1, 0], [2, 0]], [[2, 0], [1, 0]]]])
  197. """
  198. @prim_attr_register
  199. def __init__(self):
  200. """init ReLUV2"""
  201. self.init_prim_io_names(inputs=['x'], outputs=['output', 'mask'])
  202. def __infer__(self, input_x):
  203. input_shape = list(input_x['shape'])
  204. input_dtype = input_x['dtype']
  205. mask_shape = []
  206. if len(input_shape) != 4:
  207. raise ValueError("The `input_x` should be a 4-D tensor, "
  208. f"but got a {len(input_shape)}-D tensor whose shape is {input_shape}")
  209. for i in enumerate(input_shape):
  210. if i[0] == 1:
  211. if input_dtype == mstype.uint8 and input_dtype == mstype.int8:
  212. mask_shape.append((input_shape[1] + 31) // 32)
  213. else:
  214. mask_shape.append((input_shape[1] + 15) // 16)
  215. else:
  216. mask_shape.append(i[1])
  217. if input_dtype == mstype.uint8 and input_dtype == mstype.int8:
  218. mask_shape.append(4)
  219. else:
  220. mask_shape.append(2)
  221. output_shape = (input_x['shape'], mask_shape)
  222. validator.check_subclass("input_x", input_dtype, mstype.tensor, self.name)
  223. validator.check_tensor_type_same({'input_x': input_dtype}, mstype.number_type, self.name)
  224. mask_dtype = mstype.uint8
  225. output_dtype = (input_dtype, mask_dtype)
  226. return {'shape': output_shape,
  227. 'dtype': output_dtype,
  228. 'value': None}
  229. class Elu(PrimitiveWithInfer):
  230. r"""
  231. Computes exponential linear: `alpha * (exp(x) - 1)` if x < 0, `x` otherwise.
  232. The data type of input tensor should be float.
  233. Args:
  234. alpha (float): The coefficient of negative factor whose type is float. Default: 1.0.
  235. Inputs:
  236. - **input_x** (Tensor) - The input tensor whose data type should be float.
  237. Outputs:
  238. Tensor, has the same shape and data type as `input_x`.
  239. Examples:
  240. >>> input_x = Tensor(np.array([[-1.0, 4.0, -8.0], [2.0, -5.0, 9.0]]), mindspore.float32)
  241. >>> elu = P.Elu()
  242. >>> result = elu(input_x)
  243. Tensor([[-0.632 4.0 -0.999]
  244. [2.0 -0.993 9.0 ]], shape=(2, 3), dtype=mindspore.float32)
  245. """
  246. @prim_attr_register
  247. def __init__(self, alpha=1.0):
  248. """Init Elu"""
  249. validator.check_value_type("alpha", alpha, [float], self.name)
  250. def infer_shape(self, input_x):
  251. return input_x
  252. def infer_dtype(self, input_x):
  253. validator.check_tensor_type_same({'input_x': input_x}, mstype.float_type, self.name)
  254. return input_x
  255. class HSwish(PrimitiveWithInfer):
  256. r"""
  257. Hard swish activation function.
  258. Applies hswish-type activation element-wise. The input is a Tensor with any valid shape.
  259. Hard swish is defined as:
  260. .. math::
  261. \text{hswish}(x_{i}) = x_{i} * \frac{ReLU6(x_{i} + 3)}{6},
  262. where :math:`x_{i}` is the :math:`i`-th slice along the given dim of the input Tensor.
  263. Inputs:
  264. - **input_data** (Tensor) - The input of HSwish.
  265. Outputs:
  266. Tensor, with the same type and shape as the `input_data`.
  267. """
  268. @prim_attr_register
  269. def __init__(self):
  270. self.init_prim_io_names(inputs=['x'], outputs=['output'])
  271. def infer_shape(self, xshape):
  272. return xshape
  273. def infer_dtype(self, x_dtype):
  274. validator.check_tensor_type_same({"x": x_dtype}, (mstype.float16, mstype.float32), self.name)
  275. return x_dtype
  276. class Sigmoid(PrimitiveWithInfer):
  277. r"""
  278. Sigmoid activation function.
  279. Computes Sigmoid of input element-wise. The Sigmoid function is defined as:
  280. .. math::
  281. \text{sigmoid}(x_i) = \frac{1}{1 + exp(-x_i)},
  282. where :math:`x_i` is the element of the input.
  283. Inputs:
  284. - **input_x** (Tensor) - The input of Sigmoid.
  285. Outputs:
  286. Tensor, with the same type and shape as the input_x.
  287. """
  288. @prim_attr_register
  289. def __init__(self):
  290. self.init_prim_io_names(inputs=['x'], outputs=['output'])
  291. def infer_shape(self, input_x):
  292. return input_x
  293. def infer_dtype(self, input_x):
  294. validator.check_tensor_type_same({"input_x": input_x}, (mstype.float16, mstype.float32), self.name)
  295. return input_x
  296. class HSigmoid(PrimitiveWithInfer):
  297. r"""
  298. Hard sigmoid activation function.
  299. Applies hard sigmoid activation element-wise. The input is a Tensor with any valid shape.
  300. Hard sigmoid is defined as:
  301. .. math::
  302. \text{hsigmoid}(x_{i}) = max(0, min(1, \frac{2 * x_{i} + 5}{10})),
  303. where :math:`x_{i}` is the :math:`i`-th slice along the given dim of the input Tensor.
  304. Inputs:
  305. - **input_data** (Tensor) - The input of HSigmoid.
  306. Outputs:
  307. Tensor, with the same type and shape as the `input_data`.
  308. """
  309. @prim_attr_register
  310. def __init__(self):
  311. self.init_prim_io_names(inputs=['x'], outputs=['output'])
  312. def infer_shape(self, x_shape):
  313. return x_shape
  314. def infer_dtype(self, x_dtype):
  315. validator.check_tensor_type_same({"x": x_dtype}, (mstype.float16, mstype.float32), self.name)
  316. return x_dtype
  317. class Tanh(PrimitiveWithInfer):
  318. r"""
  319. Tanh activation function.
  320. Computes hyperbolic tangent of input element-wise. The Tanh function is defined as:
  321. .. math::
  322. tanh(x_i) = \frac{\exp(x_i) - \exp(-x_i)}{\exp(x_i) + \exp(-x_i)} = \frac{\exp(2x_i) - 1}{\exp(2x_i) + 1},
  323. where :math:`x_i` is an element of the input Tensor.
  324. Inputs:
  325. - **input_x** (Tensor) - The input of Tanh.
  326. Outputs:
  327. Tensor, with the same type and shape as the input_x.
  328. """
  329. @prim_attr_register
  330. def __init__(self):
  331. pass
  332. def infer_shape(self, input_x):
  333. return input_x
  334. def infer_dtype(self, input_x):
  335. validator.check_subclass("input_x", input_x, mstype.tensor, self.name)
  336. return input_x
  337. class FusedBatchNorm(Primitive):
  338. r"""
  339. FusedBatchNorm is a BatchNorm that moving mean and moving variance will be computed instead of being loaded.
  340. Batch Normalization is widely used in convolutional networks. This operation applies
  341. Batch Normalization over input to avoid internal covariate shift as described in the
  342. paper `Batch Normalization: Accelerating Deep Network Training by Reducing Internal
  343. Covariate Shift <https://arxiv.org/abs/1502.03167>`_. It rescales and recenters the
  344. feature using a mini-batch of data and the learned parameters which can be described
  345. in the following formula.
  346. .. math::
  347. y = \frac{x - mean}{\sqrt{variance + \epsilon}} * \gamma + \beta
  348. where :math:`\gamma` is scale, :math:`\beta` is bias, :math:`\epsilon` is epsilon.
  349. Args:
  350. mode (int): Mode of batch normalization, value is 0 or 1. Default: 0.
  351. epsilon (float): A small value added for numerical stability. Default: 1e-5.
  352. momentum (float): The hyper parameter to compute moving average for running_mean and running_var
  353. (e.g. :math:`new\_running\_mean = momentum * running\_mean + (1 - momentum) * current\_mean`).
  354. Momentum value should be [0, 1]. Default: 0.9.
  355. Inputs:
  356. - **input_x** (Tensor) - Tensor of shape :math:`(N, C)`.
  357. - **scale** (Tensor) - Tensor of shape :math:`(C,)`.
  358. - **bias** (Tensor) - Tensor of shape :math:`(C,)`.
  359. - **mean** (Tensor) - Tensor of shape :math:`(C,)`.
  360. - **variance** (Tensor) - Tensor of shape :math:`(C,)`.
  361. Outputs:
  362. Tuple of 5 Tensor, the normalized input and the updated parameters.
  363. - **output_x** (Tensor) - The same type and shape as the `input_x`.
  364. - **updated_scale** (Tensor) - Tensor of shape :math:`(C,)`.
  365. - **updated_bias** (Tensor) - Tensor of shape :math:`(C,)`.
  366. - **updated_moving_mean** (Tensor) - Tensor of shape :math:`(C,)`.
  367. - **updated_moving_variance** (Tensor) - Tensor of shape :math:`(C,)`.
  368. Examples:
  369. >>> input_x = Tensor(np.ones([128, 64, 32, 64]), mindspore.float32)
  370. >>> scale = Tensor(np.ones([64]), mindspore.float32)
  371. >>> bias = Tensor(np.ones([64]), mindspore.float32)
  372. >>> mean = Tensor(np.ones([64]), mindspore.float32)
  373. >>> variance = Tensor(np.ones([64]), mindspore.float32)
  374. >>> op = P.FusedBatchNorm()
  375. >>> output = op(input_x, scale, bias, mean, variance)
  376. """
  377. @prim_attr_register
  378. def __init__(self, mode=0, epsilon=1e-5, momentum=0.1):
  379. self.init_prim_io_names(inputs=['x', 'scale', 'b', 'mean', 'variance'],
  380. outputs=['y', 'running_mean', 'running_variance', 'save_mean', 'save_inv_variance'])
  381. self.mode = validator.check_integer('mode', mode, [0, 1], Rel.IN, self.name)
  382. self.epsilon = validator.check_number_range('epsilon', epsilon, 0, 1, Rel.INC_RIGHT, self.name)
  383. self.momentum = validator.check_number_range('momentum', momentum, 0, 1, Rel.INC_BOTH, self.name)
  384. class BatchNorm(PrimitiveWithInfer):
  385. r"""
  386. Batch Normalization for input data and updated parameters.
  387. Batch Normalization is widely used in convolutional neural networks. This operation
  388. applies Batch Normalization over input to avoid internal covariate shift as described
  389. in the paper `Batch Normalization: Accelerating Deep Network Training by Reducing Internal
  390. Covariate Shift <https://arxiv.org/abs/1502.03167>`_. It rescales and recenters the
  391. features using a mini-batch of data and the learned parameters which can be described
  392. in the following formula,
  393. .. math::
  394. y = \frac{x - mean}{\sqrt{variance + \epsilon}} * \gamma + \beta
  395. where :math:`\gamma` is scale, :math:`\beta` is bias, :math:`\epsilon` is epsilon.
  396. Args:
  397. is_training (bool): If `is_training` is True, `mean` and `variance` are computed during training.
  398. If `is_training` is False, they're loaded from checkpoint during inference. Default: False.
  399. epsilon (float): A small value added for numerical stability. Default: 1e-5.
  400. Inputs:
  401. - **input_x** (Tensor) - Tensor of shape :math:`(N, C)`.
  402. - **scale** (Tensor) - Tensor of shape :math:`(C,)`.
  403. - **bias** (Tensor) - Tensor of shape :math:`(C,)`.
  404. - **mean** (Tensor) - Tensor of shape :math:`(C,)`.
  405. - **variance** (Tensor) - Tensor of shape :math:`(C,)`.
  406. Outputs:
  407. Tuple of 5 Tensor, the normalized inputs and the updated parameters.
  408. - **output_x** (Tensor) - The same type and shape as the input_x. The shape is :math:`(N, C)`.
  409. - **updated_scale** (Tensor) - Tensor of shape :math:`(C,)`.
  410. - **updated_bias** (Tensor) - Tensor of shape :math:`(C,)`.
  411. - **reserve_space_1** (Tensor) - Tensor of shape :math:`(C,)`.
  412. - **reserve_space_2** (Tensor) - Tensor of shape :math:`(C,)`.
  413. - **reserve_space_3** (Tensor) - Tensor of shape :math:`(C,)`.
  414. """
  415. @prim_attr_register
  416. def __init__(self, is_training=False, epsilon=1e-5):
  417. validator.check_value_type('is_training', is_training, (bool,), self.name)
  418. validator.check_number_range('epsilon', epsilon, 0, 1, Rel.INC_RIGHT, self.name)
  419. self.add_prim_attr('data_format', "NCHW")
  420. self.init_prim_io_names(inputs=['x', 'scale', 'offset', 'mean', 'variance'],
  421. outputs=['y', 'batch_mean', 'batch_variance', 'reserve_space_1', 'reserve_space_2',
  422. 'reserve_space_3'])
  423. def infer_shape(self, input_x, scale, bias, mean, variance):
  424. validator.check_integer("scale rank", len(scale), 1, Rel.EQ, self.name)
  425. validator.check("scale shape", scale, "bias shape", bias, Rel.EQ, self.name)
  426. validator.check("scale shape[0]", scale[0], "input_x shape[1]", input_x[1], Rel.EQ, self.name)
  427. if not self.is_training:
  428. validator.check_integer("mean rank", len(mean), 1, Rel.EQ, self.name)
  429. validator.check("mean shape", mean, "variance shape", variance, Rel.EQ, self.name)
  430. validator.check("mean shape", mean, "scale shape", scale, Rel.EQ, self.name)
  431. return (input_x, scale, scale, scale, scale, scale)
  432. def infer_dtype(self, input_x, scale, bias, mean, variance):
  433. validator.check_tensor_type_same({"input_x": input_x}, [mstype.float16, mstype.float32], self.name)
  434. args = {"scale": scale, "bias": bias}
  435. validator.check_tensor_type_same(args, [mstype.float16, mstype.float32], self.name)
  436. args_moving = {"mean": mean, "variance": variance}
  437. if self.is_training:
  438. valid_types = [mstype.tensor_type(mstype.float16), mstype.tensor_type(mstype.float32), None]
  439. validator.check_type_same(args_moving, valid_types, self.name)
  440. else:
  441. args_moving = {"mean": mean, "variance": variance}
  442. validator.check_tensor_type_same(args_moving, [mstype.float16, mstype.float32], self.name)
  443. return (input_x, scale, bias, input_x, input_x, input_x)
class Conv2D(PrimitiveWithInfer):
    r"""
    2D convolution layer.

    Applies a 2D convolution over an input tensor which is typically of shape
    :math:`(N, C_{in}, H_{in}, W_{in})`, where :math:`N` is batch size and :math:`C_{in}` is channel
    number. For each batch of shape :math:`(C_{in}, H_{in}, W_{in})`, the formula is defined as:

    .. math::
        out_j = \sum_{i=0}^{C_{in} - 1} ccor(W_{ij}, X_i) + b_j,

    where :math:`ccor` is cross correlation operator, :math:`C_{in}` is the input channel number,
    :math:`j` ranges from :math:`0` to :math:`C_{out} - 1`, :math:`W_{ij}` corresponds to
    :math:`i`-th channel of the :math:`j`-th filter and :math:`out_{j}` corresponds to the
    :math:`j`-th channel of the output. :math:`W_{ij}` is a slice of kernel and it has shape
    :math:`(\text{ks_h}, \text{ks_w})`, where :math:`\text{ks_h}` and :math:`\text{ks_w}` are
    height and width of the convolution kernel. The full kernel has shape
    :math:`(C_{out}, C_{in} // \text{group}, \text{ks_h}, \text{ks_w})`, where group is the group
    number to split the input in the channel dimension.

    If the 'pad_mode' is set to be "valid", the output height and width will be
    :math:`\left \lfloor{1 + \frac{H_{in} + 2 \times \text{padding} - \text{ks_h} -
    (\text{ks_h} - 1) \times (\text{dilation} - 1) }{\text{stride}}} \right \rfloor` and
    :math:`\left \lfloor{1 + \frac{W_{in} + 2 \times \text{padding} - \text{ks_w} -
    (\text{ks_w} - 1) \times (\text{dilation} - 1) }{\text{stride}}} \right \rfloor` respectively.

    The first introduction can be found in paper `Gradient Based Learning Applied to Document
    Recognition <http://vision.stanford.edu/cs598_spring07/papers/Lecun98.pdf>`_. More detailed
    introduction can be found here: http://cs231n.github.io/convolutional-networks/.

    Args:
        out_channel (int): The dimension of the output.
        kernel_size (Union[int, tuple[int]]): The kernel size of the 2D convolution.
        mode (int): 0 Math convolutiuon, 1 cross-correlation convolution ,
            2 deconvolution, 3 depthwise convolution. Default: 1.
        pad_mode (str): "valid", "same", "pad" the mode to fill padding. Default: "valid".
        pad (int): The pad value to fill. Default: 0.
        stride (Union(int, tuple[int])): The stride to apply conv filter. Default: 1.
        dilation (Union(int, tuple[int])): Specify the space to use between kernel elements.
            Default: 1.
        group (int): Split input into groups. Default: 1.

    Returns:
        Tensor, the value that applied 2D convolution.

    Inputs:
        - **input** (Tensor) - Tensor of shape :math:`(N, C_{in}, H_{in}, W_{in})`.
        - **weight** (Tensor) - Set size of kernel is :math:`(K_1, K_2)`, then the shape is
          :math:`(C_{out}, C_{in}, K_1, K_2)`.

    Outputs:
        Tensor of shape :math:`(N, C_{out}, H_{out}, W_{out})`.

    Examples:
        >>> input = Tensor(np.ones([10, 32, 32, 32]), mindspore.float32)
        >>> weight = Tensor(np.ones([32, 32, 3, 3]), mindspore.float32)
        >>> conv2d = P.Conv2D(out_channel=32, kernel_size=3)
        >>> conv2d(input, weight)
    """

    @prim_attr_register
    def __init__(self,
                 out_channel,
                 kernel_size,
                 mode=1,
                 pad_mode="valid",
                 pad=0,
                 stride=1,
                 dilation=1,
                 group=1):
        """init Conv2D"""
        self.init_prim_io_names(inputs=['x', 'w'], outputs=['output'])
        self.kernel_size = _check_positive_int_or_tuple('kernel_size', kernel_size, self.name)
        self.stride = _check_positive_int_or_tuple('stride', stride, self.name)
        # Strides are registered in 4-element NCHW form (no stride on N/C axes).
        self.add_prim_attr('stride', (1, 1, self.stride[0], self.stride[1]))
        self.dilation = _check_positive_int_or_tuple('dilation', dilation, self.name, allow_four=True, ret_four=True)
        self.add_prim_attr('dilation', self.dilation)
        validator.check_value_type('pad', pad, (int,), self.name)
        self.pad_mode = validator.check_string('pad_mode', pad_mode, ['valid', 'same', 'pad'], self.name)
        self.pad = validator.check_pad_value_by_mode(pad_mode, pad, self.name)
        if self.pad_mode == 'pad':
            # An explicit pad amount is only meaningful (and must be >= 0) in 'pad' mode.
            validator.check_integer('pad', self.pad, 0, Rel.GE, self.name)
        # Only cross-correlation (mode 1) is supported by this primitive.
        self.mode = validator.check_integer('mode', mode, 1, Rel.EQ, self.name)
        self.add_prim_attr('data_format', "NCHW")
        self.out_channel = validator.check_integer('out_channel', out_channel, 0, Rel.GT, self.name)
        self.group = validator.check_integer('group', group, 0, Rel.GT, self.name)

    def infer_shape(self, x_shape, w_shape):
        """Infer the NCHW output shape and register the resolved pad amounts."""
        validator.check_integer("weight rank", len(w_shape), 4, Rel.EQ, self.name)
        validator.check_integer("x rank", len(x_shape), 4, Rel.EQ, self.name)
        # Grouped convolution: each filter sees C_in // group input channels.
        validator.check("x_shape[1] / group", x_shape[1] // self.group, "w_shape[1]", w_shape[1], Rel.EQ, self.name)
        validator.check('out_channel', self.out_channel, 'w_shape[0]', w_shape[0], Rel.EQ, self.name)
        validator.check('kernel_size', self.kernel_size, 'w_shape[2:4]', tuple(w_shape[2:4]), Rel.EQ, self.name)

        kernel_size_h = w_shape[2]
        kernel_size_w = w_shape[3]
        stride_h = self.stride[2]
        stride_w = self.stride[3]
        dilation_h = self.dilation[2]
        dilation_w = self.dilation[3]

        if self.pad_mode == "valid":
            # No padding: only fully-covered windows contribute to the output.
            h_out = math.ceil((x_shape[2] - dilation_h * (kernel_size_h - 1)) / stride_h)
            w_out = math.ceil((x_shape[3] - dilation_w * (kernel_size_w - 1)) / stride_w)
            pad_top, pad_bottom, pad_left, pad_right = 0, 0, 0, 0
        elif self.pad_mode == "same":
            # Output spatial size is ceil(input / stride); the required padding is
            # split as evenly as possible, extra pixel going to bottom/right.
            h_out = math.ceil(x_shape[2] / stride_h)
            w_out = math.ceil(x_shape[3] / stride_w)
            pad_needed_h = max(0, (h_out - 1) * stride_h + dilation_h * (kernel_size_h - 1) + 1 - x_shape[2])
            pad_top = math.floor(pad_needed_h / 2)
            pad_bottom = pad_needed_h - pad_top
            pad_needed_w = max(0, (w_out - 1) * stride_w + dilation_w * (kernel_size_w - 1) + 1 - x_shape[3])
            pad_left = math.floor(pad_needed_w / 2)
            pad_right = pad_needed_w - pad_left
        elif self.pad_mode == 'pad':
            # Symmetric user-specified padding on all four sides.
            pad_top, pad_bottom, pad_left, pad_right = self.pad, self.pad, self.pad, self.pad
            h_out = 1 + (x_shape[2] + 2 * self.pad - kernel_size_h - (kernel_size_h - 1) * (dilation_h - 1)) \
                / stride_h
            w_out = 1 + (x_shape[3] + 2 * self.pad - kernel_size_w - (kernel_size_w - 1) * (dilation_w - 1)) \
                / stride_w
            h_out = math.floor(h_out)
            w_out = math.floor(w_out)

        self.pad_list = [pad_top, pad_bottom, pad_left, pad_right]
        # The backend kernel reads the concrete per-side padding from this attr.
        self.add_prim_attr('pad_list', (pad_top, pad_bottom, pad_left, pad_right))
        out_channel = self.out_channel
        out_shape = [x_shape[0], out_channel, h_out, w_out]
        return out_shape

    def infer_dtype(self, x_dtype, w_dtype):
        """Check that x and w share one of the supported tensor dtypes; output keeps x's dtype."""
        args = {'x': x_dtype, 'w': w_dtype}
        valid_types = [mstype.int8, mstype.int32, mstype.float16, mstype.float32]
        validator.check_tensor_type_same(args, valid_types, self.name)
        return x_dtype
class DepthwiseConv2dNative(PrimitiveWithInfer):
    r"""
    Returns the depth-wise convolution value for the input.

    Applies depthwise conv2d for the input, which will generate more channels with
    channel_multiplier. Given an input tensor of shape :math:`(N, C_{in}, H_{in}, W_{in})` where
    :math:`N` is the batch size and a filter tensor with kernel size :math:`(ks_{h}, ks_{w})`,
    containing :math:`C_{in} * \text{channel_multiplier}` convolutional filters of depth 1;
    it applies different filters to each input channel (channel_multiplier channels for each with
    default value 1), then concatenates the results together. The output has
    :math:`\text{in_channels} * \text{channel_multiplier}` channels.

    Args:
        channel_multiplier (int): The multipiler for the original output conv.
        kernel_size (Union[int, tuple[int]]): The size of the conv kernel.
        mode (int): 0 Math convolution, 1 cross-correlation convolution ,
            2 deconvolution, 3 depthwise convolution. Default: 3.
        pad_mode (str): "valid", "same", "pad" the mode to fill padding. Default: "valid".
        pad (int): The pad value to fill. Default: 0.
        stride (Union[int, tuple[int]]): The stride to apply conv filter. Default: 1.
        dilation (Union[int, tuple[int]]): Specifies the dilation rate to use for dilated
            convolution. Default: 1.
        group (int): Splits input into groups. Default: 1.

    Inputs:
        - **input** (Tensor) - Tensor of shape :math:`(N, C_{in}, H_{in}, W_{in})`.
        - **weight** (Tensor) - Set size of kernel is :math:`(K_1, K_2)`, then the shape is
          :math:`(K, C_{in}, K_1, K_2)`, `K` must be 1.

    Outputs:
        Tensor of shape :math:`(N, C_{in} * \text{channel_multiplier}, H_{out}, W_{out})`.

    Examples:
        >>> input = Tensor(np.ones([10, 32, 32, 32]), mindspore.float32)
        >>> weight = Tensor(np.ones([1, 32, 3, 3]), mindspore.float32)
        >>> depthwise_conv2d = P.DepthwiseConv2dNative(channel_multiplier = 3, kernel_size = (3, 3))
        >>> output = depthwise_conv2d(input, weight)
        >>> assert output.shape() == (10, 96, 30, 30)
    """

    @prim_attr_register
    def __init__(self,
                 channel_multiplier,
                 kernel_size,
                 mode=3,
                 pad_mode="valid",
                 pad=0,
                 stride=1,
                 dilation=1,
                 group=1):
        """init DepthwiseConv2dNative"""
        self.init_prim_io_names(inputs=['x', 'w'], outputs=['output'])
        self.kernel_size = _check_positive_int_or_tuple('kernel_size', kernel_size, self.name)
        self.stride = _check_positive_int_or_tuple('stride', stride, self.name)
        # Unlike Conv2D, this op requires square (equal h/w) stride and dilation.
        if self.stride[0] != self.stride[1]:
            raise ValueError("The height and width of stride should be equal,"
                             f"but got height:{self.stride[0]}, width:{self.stride[1]}")
        self.add_prim_attr('stride', (1, 1, self.stride[0], self.stride[1]))
        self.dilation = _check_positive_int_or_tuple('dilation', dilation, self.name)
        if self.dilation[0] != self.dilation[1]:
            raise ValueError("The height and width of dilation should be equal,"
                             f"but got height:{self.dilation[0]}, width:{self.dilation[1]}")
        self.add_prim_attr('dilation', (1, 1, self.dilation[0], self.dilation[1]))
        validator.check_value_type('pad', pad, (int,), self.name)
        self.pad_mode = validator.check_string('pad_mode', pad_mode, ['valid', 'same', 'pad'], self.name)
        self.pad = validator.check_pad_value_by_mode(pad_mode, pad, self.name)
        # Only depthwise mode (3) is supported by this primitive.
        self.mode = validator.check_integer("mode", mode, 3, Rel.EQ, self.name)
        self.add_prim_attr('data_format', "NCHW")
        self.channel_multiplier = validator.check_integer("channel_multiplier", channel_multiplier, 0, Rel.GT,
                                                          self.name)
        self.group = validator.check_integer("group", group, 0, Rel.GT, self.name)

    def infer_shape(self, x_shape, w_shape):
        """Infer the NCHW output shape; output channels = C_in * channel_multiplier."""
        validator.check_integer("weight rank", len(w_shape), 4, Rel.EQ, self.name)
        validator.check_integer("x rank", len(x_shape), 4, Rel.EQ, self.name)
        validator.check("x_shape[1]", x_shape[1], "w_shape[1]", w_shape[1], Rel.EQ, self.name)
        validator.check('kernel_size', self.kernel_size, 'w_shape[2:4]', tuple(w_shape[2:4]), Rel.EQ, self.name)
        kernel_size_n, _, kernel_size_h, kernel_size_w = w_shape
        _, _, stride_h, stride_w = self.stride
        _, _, dilation_h, dilation_w = self.dilation
        # Depthwise filters have depth 1, so the weight's leading dim must be 1.
        if kernel_size_n != 1:
            raise ValueError(f"The batch of input weight should be 1, but got {kernel_size_n}")
        if self.pad_mode == "valid":
            # No padding: only fully-covered windows contribute.
            h_out = math.ceil((x_shape[2] - dilation_h * (kernel_size_h - 1)) / stride_h)
            w_out = math.ceil((x_shape[3] - dilation_w * (kernel_size_w - 1)) / stride_w)
            pad_top, pad_bottom, pad_left, pad_right = 0, 0, 0, 0
        elif self.pad_mode == "same":
            # Output is ceil(input / stride); required padding split evenly,
            # the odd pixel going to the bottom/right side.
            h_out = math.ceil(x_shape[2] / stride_h)
            w_out = math.ceil(x_shape[3] / stride_w)
            pad_needed_h = max(0, (h_out - 1) * stride_h + dilation_h * (kernel_size_h - 1) + 1 - x_shape[2])
            pad_top = math.floor(pad_needed_h / 2)
            pad_bottom = pad_needed_h - pad_top
            pad_needed_w = max(0, (w_out - 1) * stride_w + dilation_w * (kernel_size_w - 1) + 1 - x_shape[3])
            pad_left = math.floor(pad_needed_w / 2)
            pad_right = pad_needed_w - pad_left
        elif self.pad_mode == 'pad':
            # Symmetric user-specified padding on all four sides.
            pad_top, pad_bottom, pad_left, pad_right = self.pad, self.pad, self.pad, self.pad
            h_out = 1 + (x_shape[2] + 2 * self.pad - kernel_size_h - (kernel_size_h - 1) * (dilation_h - 1)) \
                / stride_h
            w_out = 1 + (x_shape[3] + 2 * self.pad - kernel_size_w - (kernel_size_w - 1) * (dilation_w - 1)) \
                / stride_w
            h_out = math.floor(h_out)
            w_out = math.floor(w_out)
        self.pad_list = (pad_top, pad_bottom, pad_left, pad_right)
        # NOTE: this primitive registers the attr under 'pads' (Conv2D uses 'pad_list').
        self.add_prim_attr('pads', self.pad_list)
        out_channel = self.channel_multiplier * x_shape[1]
        out_shape = [x_shape[0], out_channel, h_out, w_out]
        return out_shape

    def infer_dtype(self, x_dtype, w_dtype):
        """Check that x and w share a numeric tensor dtype; output keeps x's dtype."""
        args = {'x': x_dtype, 'w': w_dtype}
        validator.check_tensor_type_same(args, mstype.number_type, self.name)
        return x_dtype
  663. class _Pool(PrimitiveWithInfer):
  664. r"""
  665. Performs max/avg pooling operation.
  666. Args:
  667. ksize (Union[int, tuple[int]]): The size of the kernel, that should be a tuple
  668. of two `int` for height and width. Default: 1.
  669. strides (Union[int, tuple[int]]): The stride of the window, that should be
  670. a tuple of two `int` for height and width. Default: 1.
  671. padding (str): The optional values for pad mode, is "same" or "valid", not case sensitive.
  672. Default: "valid".
  673. """
  674. @prim_attr_register
  675. def __init__(self, ksize=1, strides=1, padding="valid"):
  676. self.init_prim_io_names(inputs=['x'], outputs=['output'])
  677. validator.check_value_type('ksize', ksize, [int, tuple], self.name)
  678. validator.check_value_type('strides', strides, [int, tuple], self.name)
  679. self.padding = validator.check_string('padding', padding.upper(), ['VALID', 'SAME'], self.name)
  680. self.add_prim_attr("padding", self.padding)
  681. self.is_maxpoolwithargmax = (self.name == "MaxPoolWithArgmax")
  682. if not self.is_maxpoolwithargmax:
  683. self.add_prim_attr('data_format', "NCHW")
  684. self.ksize = _check_positive_int_or_tuple("ksize", ksize, self.name, allow_four=False, ret_four=True)
  685. if self.is_maxpoolwithargmax:
  686. self.ksize = (1, self.ksize[-2], self.ksize[-1], 1)
  687. self.add_prim_attr("ksize", self.ksize)
  688. self.strides = _check_positive_int_or_tuple("strides", strides, self.name, allow_four=False, ret_four=True)
  689. if self.is_maxpoolwithargmax:
  690. self.strides = (1, self.strides[-2], self.strides[-1], 1)
  691. self.add_prim_attr("strides", self.strides)
  692. def infer_shape(self, x_shape):
  693. validator.check_integer("x rank", len(x_shape), 4, Rel.EQ, self.name)
  694. batch, channel, input_h, input_w = x_shape
  695. if self.is_maxpoolwithargmax:
  696. _, kernel_h, kernel_w, _ = self.ksize
  697. _, stride_h, stride_w, _ = self.strides
  698. else:
  699. _, _, kernel_h, kernel_w = self.ksize
  700. _, _, stride_h, stride_w = self.strides
  701. if self.padding == "VALID":
  702. out_h = math.ceil((input_h - (kernel_h - 1)) / stride_h)
  703. out_w = math.ceil((input_w - (kernel_w - 1)) / stride_w)
  704. elif self.padding == "SAME":
  705. out_h = math.ceil(input_h / stride_h)
  706. out_w = math.ceil(input_w / stride_w)
  707. out_shape = [batch, channel, out_h, out_w]
  708. for shape_value in out_shape:
  709. if shape_value <= 0:
  710. raise ValueError(f"For '{self.name}' The kernel size is not valid, "
  711. f"please check it if is larger than data's shape size.")
  712. return out_shape
  713. def infer_dtype(self, x_dtype):
  714. validator.check_subclass("input", x_dtype, mstype.tensor, self.name)
  715. return x_dtype
  716. class MaxPool(_Pool):
  717. r"""
  718. Max pooling operation.
  719. Applies a 2D max pooling over an input Tensor which can be regarded as a composition of 2D planes.
  720. Typically the input is of shape :math:`(N_{in}, C_{in}, H_{in}, W_{in})`, MaxPool outputs
  721. regional maximum in the :math:`(H_{in}, W_{in})`-dimension. Given kernel size
  722. :math:`ks = (h_{ker}, w_{ker})` and stride :math:`s = (s_0, s_1)`, the operation is as follows.
  723. .. math::
  724. \text{output}(N_i, C_j, h, w) = \max_{m=0, \ldots, h_{ker}-1} \max_{n=0, \ldots, w_{ker}-1}
  725. \text{input}(N_i, C_j, s_0 \times h + m, s_1 \times w + n)
  726. Args:
  727. ksize (Union[int, tuple[int]]): The size of kernel used to take the maximum value,
  728. is an int number that represents height and width are both ksize, or a tuple
  729. of two int numbers that represent height and width respectively. Default: 1.
  730. strides (Union[int, tuple[int]]): The distance of kernel moving, an int number that represents
  731. the height and width of movement are both strides, or a tuple of two int numbers that
  732. represent height and width of movement respectively. Default: 1.
  733. padding (str): The optional values for pad mode, is "same" or "valid", not case sensitive.
  734. Default: "valid".
  735. - same: Adopts the way of completion. Output height and width will be the same as
  736. the input. Total number of padding will be calculated for horizontal and vertical
  737. direction and evenly distributed to top and bottom, left and right if possible.
  738. Otherwise, the last extra padding will be done from the bottom and the right side.
  739. - valid: Adopts the way of discarding. The possibly largest height and width of output
  740. will be return without padding. Extra pixels will be discarded.
  741. Inputs:
  742. - **input** (Tensor) - Tensor of shape :math:`(N, C_{in}, H_{in}, W_{in})`.
  743. Outputs:
  744. Tensor, with shape :math:`(N, C_{out}, H_{out}, W_{out})`.
  745. Examples:
  746. >>> input_tensor = Tensor(np.arange(1 * 3 * 3 * 4).reshape((1, 3, 3, 4)), mindspore.float32)
  747. >>> maxpool_op = P.MaxPool(padding="VALID", ksize=2, strides=1)
  748. >>> output_tensor = maxpool_op(input_tensor)
  749. """
  750. @prim_attr_register
  751. def __init__(self, ksize=1, strides=1, padding="valid"):
  752. super(MaxPool, self).__init__(ksize, strides, padding)
  753. class MaxPoolWithArgmax(_Pool):
  754. r"""
  755. Performs max pooling on the input Tensor and return both max values and indices.
  756. Typically the input is of shape :math:`(N_{in}, C_{in}, H_{in}, W_{in})`, MaxPool outputs
  757. regional maximum in the :math:`(H_{in}, W_{in})`-dimension. Given kernel size
  758. :math:`ks = (h_{ker}, w_{ker})` and stride :math:`s = (s_0, s_1)`, the operation is as follows.
  759. .. math::
  760. \text{output}(N_i, C_j, h, w) = \max_{m=0, \ldots, h_{ker}-1} \max_{n=0, \ldots, w_{ker}-1}
  761. \text{input}(N_i, C_j, s_0 \times h + m, s_1 \times w + n)
  762. Args:
  763. ksize (Union[int, tuple[int]]): The size of kernel used to take the maximum value and arg value,
  764. is an int number that represents height and width are both ksize, or a tuple of
  765. two int numbers that represent height and width respectively. Default: 1.
  766. strides (Union[int, tuple[int]]): The distance of kernel moving, an int number that represents
  767. the height and width of movement are both strides, or a tuple of two int numbers that
  768. represent height and width of movement respectively. Default: 1.
  769. padding (str): The optional values for pad mode, is "same" or "valid", not case sensitive.
  770. Default: "valid".
  771. - same: Adopts the way of completion. Output height and width will be the same as
  772. the input. Total number of padding will be calculated for horizontal and vertical
  773. direction and evenly distributed to top and bottom, left and right if possible.
  774. Otherwise, the last extra padding will be done from the bottom and the right side.
  775. - valid: Adopts the way of discarding. The possibly largest height and width of output
  776. will be return without padding. Extra pixels will be discarded.
  777. Inputs:
  778. - **input** (Tensor) - Tensor of shape :math:`(N, C_{in}, H_{in}, W_{in})`.
  779. Outputs:
  780. Tuple of 2 Tensor, the maxpool result and where max values from.
  781. - **output** (Tensor) - Maxpooling result, with shape :math:`(N, C_{out}, H_{out}, W_{out})`.
  782. - **mask** (Tensor) - Max values' index represented by the mask.
  783. Examples:
  784. >>> input_tensor = Tensor(np.arange(1 * 3 * 3 * 4).reshape((1, 3, 3, 4)), mindspore.float32)
  785. >>> maxpool_arg_op = P.MaxPoolWithArgmax(padding="VALID", ksize=2, strides=1)
  786. >>> output_tensor, argmax = maxpool_arg_op(input_tensor)
  787. """
  788. def __init__(self, ksize=1, strides=1, padding="valid"):
  789. super(MaxPoolWithArgmax, self).__init__(ksize, strides, padding)
  790. self.is_tbe = context.get_context("device_target") == "Ascend"
  791. def infer_shape(self, x_shape):
  792. out_shape = _Pool.infer_shape(self, x_shape)
  793. _, _, out_h, out_w = out_shape
  794. _, kernel_h, kernel_w, _ = self.ksize
  795. argmax_shape = []
  796. if self.is_tbe:
  797. for i in range(4):
  798. if i == 2:
  799. dim = kernel_h * kernel_w
  800. argmax_shape.append(dim)
  801. elif i == 3:
  802. dim = math.ceil(out_h * out_w / 16) + 1
  803. argmax_shape.append(dim)
  804. else:
  805. argmax_shape.append(x_shape[i])
  806. else:
  807. argmax_shape = out_shape
  808. return out_shape, argmax_shape
  809. def infer_dtype(self, x_dtype):
  810. out_dtype = x_dtype
  811. validator.check_tensor_type_same({"x": x_dtype}, (mstype.float16, mstype.float32), self.name)
  812. argmax_dtype = mstype.uint16
  813. return out_dtype, argmax_dtype
  814. class AvgPool(_Pool):
  815. r"""
  816. Average pooling operation.
  817. Applies a 2D average pooling over an input Tensor which can be regarded as a composition of 2D input planes.
  818. Typically the input is of shape :math:`(N_{in}, C_{in}, H_{in}, W_{in})`, AvgPool2d outputs
  819. regional average in the :math:`(H_{in}, W_{in})`-dimension. Given kernel size
  820. :math:`ks = (h_{ker}, w_{ker})` and stride :math:`s = (s_0, s_1)`, the operation is as follows.
  821. .. math::
  822. \text{output}(N_i, C_j, h, w) = \frac{1}{h_{ker} * w_{ker}} \sum_{m=0}^{h_{ker}-1} \sum_{n=0}^{w_{ker}-1}
  823. \text{input}(N_i, C_j, s_0 \times h + m, s_1 \times w + n)
  824. Args:
  825. ksize (Union[int, tuple[int]]): The size of kernel used to take the average value,
  826. is an int number that represents height and width are both ksize, or a tuple
  827. of two int numbers that represent height and width respectively. Default: 1.
  828. strides (Union[int, tuple[int]]): The distance of kernel moving, an int number that represents
  829. the height and width of movement are both strides, or a tuple of two int numbers that
  830. represent height and width of movement respectively. Default: 1.
  831. padding (str): The optional values for pad mode, is "same" or "valid", not case sensitive.
  832. Default: "valid".
  833. - same: Adopts the way of completion. Output height and width will be the same as
  834. the input. Total number of padding will be calculated for horizontal and vertical
  835. direction and evenly distributed to top and bottom, left and right if possible.
  836. Otherwise, the last extra padding will be done from the bottom and the right side.
  837. - valid: Adopts the way of discarding. The possibly largest height and width of output
  838. will be return without padding. Extra pixels will be discarded.
  839. Inputs:
  840. - **input** (Tensor) - Tensor of shape :math:`(N, C_{in}, H_{in}, W_{in})`.
  841. Outputs:
  842. Tensor, with shape :math:`(N, C_{out}, H_{out}, W_{out})`.
  843. """
  844. @prim_attr_register
  845. def __init__(self, ksize=1, strides=1, padding="valid"):
  846. if context.get_context("device_target") == "GPU":
  847. self.target = "GPU"
  848. else:
  849. self.target = "OTHER"
  850. super(AvgPool, self).__init__(ksize, strides, padding)
class Conv2DBackpropInput(PrimitiveWithInfer):
    """
    Computes the gradients of convolution with respect to the input.

    Args:
        out_channel (int): The dimensionality of the output space.
        kernel_size (Union[int, tuple[int]]): The size of the convolution window.
        pad_mode (str): "valid", "same", "pad" the mode to fill padding. Default: "valid".
        pad (int): The pad value to fill. Default: 0.
        mode (int): 0 Math convolutiuon, 1 cross-correlation convolution ,
            2 deconvolution, 3 depthwise convolution. Default: 1.
        stride (Union[int. tuple[int]]): The stride to apply conv filter. Default: 1.
        dilation (Union[int. tuple[int]]): Specifies the dilation rate to use for dilated
            convolution. Default: 1.
        group (int): Splits input into groups. Default: 1.

    Returns:
        Tensor, the gradients of convolution.

    Examples:
        >>> dout = Tensor(np.ones([10, 32, 30, 30]), mindspore.float32)
        >>> weight = Tensor(np.ones([32, 32, 3, 3]), mindspore.float32)
        >>> x = Tensor(np.ones([10, 32, 32, 32]))
        >>> conv2d_backprop_input = P.Conv2DBackpropInput(out_channel=32, kernel_size=3)
        >>> conv2d_backprop_input(dout, weight, F.shape(x))
    """

    @prim_attr_register
    def __init__(self,
                 out_channel,
                 kernel_size,
                 pad_mode="valid",
                 pad=0,
                 pad_list=None,
                 mode=1,
                 stride=1,
                 dilation=1,
                 group=1):
        """init Conv2DBackpropInput"""
        self.init_prim_io_names(inputs=['out_backprop', 'filter', 'input_sizes'], outputs=['output'])
        self.out_channel = validator.check_integer('out_channel', out_channel, 0, Rel.GT, self.name)
        self.kernel_size = _check_positive_int_or_tuple('kernel_size', kernel_size, self.name)
        self.stride = _check_positive_int_or_tuple('stride', stride, self.name, allow_four=True, ret_four=False)
        self.add_prim_attr('stride', self.stride)
        self.dilation = _check_positive_int_or_tuple('dilation', dilation, self.name, allow_four=True, ret_four=True)
        self.add_prim_attr('dilation', self.dilation)
        validator.check_value_type('pad', pad, (int,), self.name)
        self.pad_mode = validator.check_string('pad_mode', pad_mode, ['valid', 'same', 'pad'], self.name)
        self.pad = validator.check_pad_value_by_mode(pad_mode, pad, self.name)
        pad_mode = pad_mode.upper()
        # NOTE(review): __infer__ below compares self.pad_mode against the UPPER-CASE
        # strings "SAME"/"PAD"; presumably add_prim_attr also rebinds self.pad_mode
        # to this upper-cased value — confirm against the Primitive base class.
        self.add_prim_attr('pad_mode', pad_mode)
        # Only cross-correlation (mode 1) is supported by this primitive.
        self.mode = validator.check_integer('mode', mode, 1, Rel.EQ, self.name)
        self.group = validator.check_integer('group', group, 0, Rel.GT, self.name)
        self.add_prim_attr('data_format', "NCHW")
        if pad_list:
            # An explicit per-side pad list overrides pad_mode-based computation.
            for x in pad_list:
                validator.check_integer('element of pad_list', x, 0, Rel.GE, self.name)
        self.pad_list = pad_list

    def __infer__(self, doutput, w, x_size):
        """Validate inputs, resolve per-side padding, and return the x-shaped gradient spec."""
        x_size_v = x_size['value']
        # input_sizes must be a constant tuple of ints (it becomes the output shape).
        validator.check_value_type('x_size', x_size_v, [tuple], self.name)
        for i, dim_len in enumerate(x_size_v):
            validator.check_value_type("x_size[%d]" % i, dim_len, [int], self.name)
        args = {'doutput': doutput['dtype'], 'w': w['dtype']}
        valid_types = [mstype.int8, mstype.int32, mstype.float16, mstype.float32]
        validator.check_tensor_type_same(args, valid_types, self.name)
        # infer shape
        dout_shape = doutput['shape']
        kernel_h = self.kernel_size[0]
        kernel_w = self.kernel_size[1]
        stride_h = self.stride[0]
        stride_w = self.stride[1]
        # default pad mode is valid
        pad_list = (0, 0, 0, 0)
        if self.pad_list:
            pad_list = tuple(self.pad_list)
        elif self.pad_mode == "SAME":
            # Recompute the forward conv's SAME padding from dout and x sizes;
            # the odd pixel goes to bottom/right.
            pad_needed_h = max(0, (dout_shape[2] - 1) * stride_h + kernel_h - x_size_v[2])
            pad_top = math.floor(pad_needed_h / 2)
            pad_bottom = pad_needed_h - pad_top
            pad_needed_w = max(0, (dout_shape[3] - 1) * stride_w + kernel_w - x_size_v[3])
            pad_left = math.floor(pad_needed_w / 2)
            pad_right = pad_needed_w - pad_left
            pad_list = (pad_top, pad_bottom, pad_left, pad_right)
        elif self.pad_mode == 'PAD':
            pad_list = (self.pad,) * 4
        self.add_prim_attr('pad_list', pad_list)
        out = {
            'value': None,
            'shape': x_size_v,
            'dtype': doutput['dtype'],
        }
        return out
  939. class BiasAdd(PrimitiveWithInfer):
  940. r"""
  941. Returns sum of input and bias tensor.
  942. Adds the 1-D bias tensor to the input tensor, and boardcasts the shape on all axis
  943. except for the channel axis.
  944. Inputs:
  945. - **input_x** (Tensor) - Input value, with shape :math:`(N, C)` or :math:`(N, C, H, W)`.
  946. - **bias** (Tensor) - Bias value, with shape :math:`(C)`.
  947. Outputs:
  948. Tensor, with the same shape and type as `input_x`.
  949. """
  950. @prim_attr_register
  951. def __init__(self):
  952. self.init_prim_io_names(inputs=['x', 'b'], outputs=['output'])
  953. self.add_prim_attr('data_format', 'NCHW')
  954. def infer_shape(self, x_shape, b_shape):
  955. validator.check_integer("x rank", len(x_shape), 2, Rel.GE, self.name)
  956. validator.check_integer("bias rank", len(b_shape), 1, Rel.EQ, self.name)
  957. validator.check("b_shape[0]", b_shape[0], "x_shape[1]", x_shape[1], Rel.EQ, self.name)
  958. return x_shape
  959. def infer_dtype(self, x_type, b_type):
  960. args = {"input_x": x_type, "bias": b_type}
  961. valid_types = (mstype.int8, mstype.int32, mstype.float16, mstype.float32)
  962. validator.check_tensor_type_same(args, valid_types, self.name)
  963. return x_type
  964. class TopK(PrimitiveWithInfer):
  965. """
  966. Finds values and indices of the `k` largest entries along the last dimension.
  967. Args:
  968. sorted (bool): If true, the resulting elements will
  969. be sorted by the values in descending order. Default: False.
  970. Inputs:
  971. - **input_x** (Tensor) - Input to be computed.
  972. - **k** (int) - Number of top elements to be computed along the last dimension, constant input is needed.
  973. Outputs:
  974. Tuple of 2 Tensor, the values and the indices.
  975. - **values** (Tensor) - The `k` largest elements along each last dimensional slice.
  976. - **indices** (Tensor) - The indices of values within the last dimension of input.
  977. Examples:
  978. >>> topk = P.TopK(sorted=True)
  979. >>> input_x = Tensor([1, 2, 3, 4, 5], mindspore.float16)
  980. >>> k = 3
  981. >>> values, indices = topk(input_x, k)
  982. >>> assert values == Tensor(np.array([5, 4, 3]), mstype.float16)
  983. >>> assert indices == Tensor(np.array([4, 3, 2]), mstype.int32)
  984. """
  985. @prim_attr_register
  986. def __init__(self, sorted=False):
  987. validator.check_value_type("sorted", sorted, [bool], self.name)
  988. self.init_prim_io_names(inputs=['input', 'k'],
  989. outputs=['values', 'indices'])
  990. def __infer__(self, input_x, k):
  991. x_dtype = input_x['dtype']
  992. valid_types = (mstype.int32, mstype.float16, mstype.float32)
  993. validator.check_tensor_type_same({'x': x_dtype}, valid_types, self.name)
  994. k_v = k['value']
  995. validator.check_value_type('k', k_v, (int,), self.name)
  996. x_shape = list(input_x['shape'])
  997. ndim = len(x_shape) - 1
  998. x_shape[ndim] = k_v
  999. return {'shape': (x_shape, x_shape),
  1000. 'dtype': (x_dtype, mstype.int32),
  1001. 'value': None}
  1002. class SoftmaxCrossEntropyWithLogits(PrimitiveWithInfer):
  1003. r"""
  1004. Gets the softmax cross-entropy value between logits and labels which shoule be one-hot encoding.
  1005. Note:
  1006. Sets input logits as `X`, input label as `Y`, output as `loss`. Then,
  1007. .. math::
  1008. p_{ij} = softmax(X_{ij}) = \frac{exp(x_i)}{\sum_{j = 0}^{N-1}\exp(x_j)}
  1009. .. math::
  1010. loss_{ij} = -\sum_j{Y_{ij} * ln(p_{ij})}
  1011. Inputs:
  1012. - **logits** (Tensor) - Input logits, with shape :math:`(N, C)`.
  1013. - **labels** (Tensor) - Ground truth labels, with shape :math:`(N, C)`.
  1014. Outputs:
  1015. Tuple of 2 Tensor, the loss shape is `(N,)`, and the dlogits with the same shape as `logits`.
  1016. Examples:
  1017. Please refer to the usage in nn.SoftmaxCrossEntropyWithLogits source code.
  1018. """
  1019. @prim_attr_register
  1020. def __init__(self):
  1021. pass
  1022. def infer_shape(self, logits_shape, labels_shape):
  1023. validator.check("logits_shape", logits_shape, "labels_shape", labels_shape, Rel.EQ, self.name)
  1024. loss_shape = [logits_shape[0]]
  1025. dlogits_shape = logits_shape
  1026. return (loss_shape, dlogits_shape)
  1027. def infer_dtype(self, logits_type, labels_type):
  1028. args = {"logits": logits_type, "labels": labels_type}
  1029. validator.check_tensor_type_same(args, (mstype.float16, mstype.float32), self.name)
  1030. return (logits_type, logits_type)
  1031. class SparseSoftmaxCrossEntropyWithLogits(PrimitiveWithInfer):
  1032. r"""
  1033. Computes the softmax cross-entropy value between logits and sparse encoding labels.
  1034. Note:
  1035. Sets input logits as `X`, input label as `Y`, output as `loss`. Then,
  1036. .. math::
  1037. p_{ij} = softmax(X_{ij}) = \frac{exp(x_i)}{\sum_{j = 0}^{N-1}\exp(x_j)}
  1038. .. math::
  1039. loss_{ij} = \begin{cases} -ln(p_{ij}), &j = y_i \cr -ln(1 - p_{ij}), & j \neq y_i \end{cases}
  1040. .. math::
  1041. loss = \sum_{ij} loss_{ij}
  1042. Args:
  1043. is_grad (bool): If it's true, this operation returns the computed gradient. Default: False.
  1044. Inputs:
  1045. - **logits** (Tensor) - Input logits, with shape :math:`(N, C)`.
  1046. - **labels** (Tensor) - Ground truth labels, with shape :math:`(N)`.
  1047. Outputs:
  1048. Tensor, if `is_grad` is False, the output tensor is the value of loss which is a scalar tensor;
  1049. if `is_grad` is True, the output tensor is the gradient of input with the same shape as `logits`.
  1050. Examples:
  1051. Please refer to the usage in nn.SoftmaxCrossEntropyWithLogits source code.
  1052. """
  1053. @prim_attr_register
  1054. def __init__(self, is_grad=False):
  1055. self.init_prim_io_names(inputs=['features', 'labels'], outputs=['output'])
  1056. self.is_grad = is_grad
  1057. self.add_prim_attr('sens', 1.0)
  1058. def infer_shape(self, logits_shape, labels_shape):
  1059. validator.check("logits_shape[0]", logits_shape[0], "labels_shape[0]", labels_shape[0], Rel.EQ, self.name)
  1060. loss_shape = []
  1061. if self.is_grad:
  1062. return logits_shape
  1063. return loss_shape
  1064. def infer_dtype(self, logits_type, labels_type):
  1065. validator.check_tensor_type_same({"logits": logits_type}, (mstype.float16, mstype.float32), self.name)
  1066. validator.check_tensor_type_same({"labels": labels_type}, (mstype.int32, mstype.int64), self.name)
  1067. return logits_type
  1068. class ApplyMomentum(PrimitiveWithInfer):
  1069. """
  1070. Optimizer that implements the Momentum algorithm.
  1071. Refer to the paper `On the importance of initialization and momentum in deep
  1072. learning <https://dl.acm.org/doi/10.5555/3042817.3043064>`_ for more details.
  1073. Args:
  1074. use_locking (bool): Enable a lock to protect the update of variable and accumlation tensors. Default: False.
  1075. use_nesterov (bool): Enable Nesterov momentum. Default: False.
  1076. gradient_scale (float): The scale of the gradient. Default: 1.0.
  1077. Inputs:
  1078. - **variable** (Tensor) - Weights to be updated.
  1079. - **accumulation** (Tensor) - Accumulated gradient value by moment weight.
  1080. - **learning_rate** (float) - Learning rate.
  1081. - **gradient** (Tensor) - Gradients.
  1082. - **momentum** (float) - Momentum.
  1083. Outputs:
  1084. Tensor, parameters to be updated.
  1085. Examples:
  1086. Please refer to the usage in nn.ApplyMomentum.
  1087. """
  1088. __mindspore_signature__ = (
  1089. ('variable', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD),
  1090. ('accumulation', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD),
  1091. ('learning_rate', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD),
  1092. ('gradient', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD),
  1093. ('momentum', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD)
  1094. )
  1095. @prim_attr_register
  1096. def __init__(self, use_nesterov=False, use_locking=False, gradient_scale=1.0):
  1097. self.init_prim_io_names(inputs=['variable', 'accumulation', 'learning_rate', 'gradient', 'momentum'],
  1098. outputs=['output'])
  1099. def infer_shape(self, v_shape, a_shape, l_shape, g_shape, m_shape):
  1100. return v_shape
  1101. def infer_dtype(self, v_dtype, a_dtype, l_dtype, g_dtype, m_dtype):
  1102. valid_types = [mstype.float16, mstype.float32, mstype.float64]
  1103. if v_dtype != mstype.type_refkey and a_dtype != mstype.type_refkey:
  1104. validator.check_tensor_type_same({"v": v_dtype}, valid_types, self.name)
  1105. validator.check_tensor_type_same({"a": a_dtype}, valid_types, self.name)
  1106. validator.check_scalar_or_tensor_type_same({"l_dtype": l_dtype}, valid_types, self.name)
  1107. validator.check_scalar_or_tensor_type_same({"g_dtype": g_dtype}, valid_types, self.name)
  1108. validator.check_scalar_or_tensor_type_same({"m_dtype": m_dtype}, valid_types, self.name)
  1109. return g_dtype
  1110. class SmoothL1Loss(PrimitiveWithInfer):
  1111. r"""
  1112. Computes smooth L1 loss, a robust L1 loss.
  1113. SmoothL1Loss is a Loss similar to MSELoss but less sensitive to outliers as described in the
  1114. `Fast R-CNN <https://arxiv.org/abs/1504.08083>`_ by Ross Girshick.
  1115. Note:
  1116. Sets input prediction as `X`, input target as `Y`, output as `loss`. Then,
  1117. .. math::
  1118. \text{SmoothL1Loss} = \begin{cases}0.5x^{2}, &if \left |x \right |\leq \text{sigma} \cr
  1119. \left |x \right|-0.5, &\text{otherwise}\end{cases}
  1120. Args:
  1121. sigma (float): A parameter used to control the point where the function will change from
  1122. quadratic to linear. Default: 1.0.
  1123. Inputs:
  1124. - **prediction** (Tensor) - Predict data.
  1125. - **target** (Tensor) - Ground truth data, with the same type and shape as `prediction`.
  1126. Outputs:
  1127. Tensor, with the same type and shape as `prediction`.
  1128. """
  1129. @prim_attr_register
  1130. def __init__(self, sigma=1.0):
  1131. validator.check_value_type('sigma', sigma, [float], self.name)
  1132. validator.check('sigma', sigma, '', 0, Rel.GT, self.name)
  1133. self.init_prim_io_names(inputs=['prediction', 'target'], outputs=['output'])
  1134. def infer_shape(self, prediction, target):
  1135. validator.check('prediction shape', prediction, 'target shape', target, Rel.EQ, self.name)
  1136. return prediction
  1137. def infer_dtype(self, prediction, target):
  1138. args = {"prediction": prediction, "target": target}
  1139. validator.check_tensor_type_same(args, (mstype.float16, mstype.float32), self.name)
  1140. return prediction
  1141. class L2Loss(PrimitiveWithInfer):
  1142. """
  1143. Calculates half of the L2 norm of a tensor without using the `sqrt`.
  1144. Set `input_x` as x and output as loss.
  1145. .. math::
  1146. loss = sum(x ** 2) / 2
  1147. Inputs:
  1148. - **input_x** (Tensor) - A input Tensor.
  1149. Outputs:
  1150. Tensor. Has the same dtype as `input_x`. The output tensor is the value of loss which is a scalar tensor.
  1151. Examples
  1152. >>> input_x = Tensor(np.array([1, 2, 3]), mindspore.float16)
  1153. >>> l2_loss = P.L2Loss()
  1154. >>> l2_loss(input_x)
  1155. 7.0
  1156. """
  1157. @prim_attr_register
  1158. def __init__(self):
  1159. """init L2Loss"""
  1160. def infer_shape(self, input_x):
  1161. loss_shape = []
  1162. return loss_shape
  1163. def infer_dtype(self, x_type):
  1164. validator.check_subclass("x_type", x_type, mstype.tensor, self.name)
  1165. validator.check_tensor_type_same({'x_type': x_type}, [mstype.double, mstype.float_, mstype.float16], self.name)
  1166. return x_type
  1167. class SGD(PrimitiveWithInfer):
  1168. """
  1169. Computes stochastic gradient descent (optionally with momentum).
  1170. Nesterov momentum is based on the formula from On the importance of
  1171. initialization and momentum in deep learning.
  1172. Note:
  1173. For details, please refer to `nn.SGD` source code.
  1174. Args:
  1175. dampening (float): The dampening for momentum. Default: 0.0.
  1176. weight_decay (float): Weight decay (L2 penalty). Default: 0.0.
  1177. nesterov (bool): Enable Nesterov momentum. Default: False.
  1178. Inputs:
  1179. - **parameters** (Tensor) - Parameters to be updated. Their data type can be list or tuple.
  1180. - **gradient** (Tensor) - Gradients.
  1181. - **learning_rate** (Tensor) - Learning rate. Must be float value. e.g. Tensor(0.1, mindspore.float32).
  1182. - **accum** (Tensor) - Accum(velocity) to be updated.
  1183. - **momentum** (Tensor) - Momentum. e.g. Tensor(0.1, mindspore.float32).
  1184. - **stat** (Tensor) - States to be updated with the same shape as gradient.
  1185. Outputs:
  1186. Tensor, parameters to be updated.
  1187. """
  1188. @prim_attr_register
  1189. def __init__(self, dampening=0.0, weight_decay=0.0, nesterov=False):
  1190. validator.check_value_type("nesterov", nesterov, [bool], self.name)
  1191. self.init_prim_io_names(inputs=['parameters', 'gradient', 'learning_rate', 'accum', 'momentum', 'stat'],
  1192. outputs=['output'])
  1193. def infer_shape(self, parameters_shape, gradient_shape, learning_rate_shape,
  1194. accum_shape, momentum_shape, stat_shape):
  1195. validator.check_integer(f'parameters rank', len(parameters_shape), 0, Rel.GT, self.name)
  1196. validator.check_integer(f'gradient rank', len(gradient_shape), 0, Rel.GE, self.name)
  1197. validator.check_integer(f'learning rate rank', len(learning_rate_shape), 0, Rel.GE, self.name)
  1198. validator.check_integer(f'accumulation rank', len(accum_shape), 0, Rel.GT, self.name)
  1199. validator.check_integer(f'momentum rank', len(momentum_shape), 0, Rel.GE, self.name)
  1200. validator.check_integer(f'stat rank', len(stat_shape), 0, Rel.GE, self.name)
  1201. validator.check("gradient shape", gradient_shape, "stat shape", stat_shape, Rel.EQ, self.name)
  1202. return parameters_shape
  1203. def infer_dtype(self, parameters_dtype, gradient_dtype, learning_rate_dtype,
  1204. accum_dtype, momentum_dtype, stat_dtype):
  1205. valid_types = [mstype.float16, mstype.float32]
  1206. validator.check_tensor_type_same({"parameters": parameters_dtype}, valid_types, self.name)
  1207. validator.check_tensor_type_same({"gradient": gradient_dtype}, valid_types, self.name)
  1208. validator.check_tensor_type_same({"learning_rate": learning_rate_dtype}, valid_types, self.name)
  1209. validator.check_tensor_type_same({"accum": accum_dtype}, valid_types, self.name)
  1210. validator.check_tensor_type_same({"momentum": momentum_dtype}, valid_types, self.name)
  1211. validator.check_tensor_type_same({"stat": stat_dtype}, valid_types, self.name)
  1212. return parameters_dtype
class ApplyRMSProp(PrimitiveWithInfer):
    """
    Optimizer that implements the Root Mean Square prop(RMSProp) algorithm.
    Please refer to the usage in source code of `nn.RMSProp`.
    Note:
        Update `var` according to the RMSProp algorithm.
        .. math::
            s_{t} = \\rho s_{t-1} + (1 - \\rho)(\\nabla Q_{i}(w))^2
        .. math::
            m_{t} = \\beta m_{t-1} + \\frac{\\eta} {\\sqrt{s_{t} + \\epsilon}} \\nabla Q_{i}(w)
        .. math::
            w = w - m_{t}
        where, :math:`w` represents `var`, which will be updated.
        :math:`s_{t}` represents `mean_square`, :math:`s_{t-1}` is the last momentent of :math:`s_{t}`,
        :math:`m_{t}` represents `moment`, :math:`m_{t-1}` is the last momentent of :math:`m_{t}`.
        :math:`\\rho` represents `decay`. :math:`\\beta` is the momentum term, represents `momentum`.
        :math:`\\epsilon` is a smoothing term to avoid division by zero, represents `epsilon`.
        :math:`\\eta` represents `learning_rate`. :math:`\\nabla Q_{i}(w)` represents `grad`.
    Args:
        use_locking (bool): Enable a lock to protect the update of variable tensors. Default: False.
    Inputs:
        - **var** (Tensor) - Weights to be update.
        - **mean_square** (Tensor) - Mean square gradients, must have the same type as `var`.
        - **moment** (Tensor) - Delta of `var`, must have the same type as `var`.
        - **grad** (Tensor) - Gradients, must have the same type as `var`.
        - **learning_rate** (Union[Number, Tensor]) - Learning rate.
        - **decay** (float) - Decay rate.
        - **momentum** (float) - Momentum.
        - **epsilon** (float) - Ridge term.
    Outputs:
        Tensor, parameters to be update.
    """

    @prim_attr_register
    def __init__(self, use_locking=False):
        # use_locking must be a bool; kept as an attribute for the backend kernel.
        self.use_locking = validator.check_value_type("use_locking", use_locking, [bool], self.name)

    def infer_shape(self, var_shape, mean_square_shape, moment_shape, grad_shape, learning_rate_shape, decay_shape,
                    momentum_shape, epsilon_shape):
        # All tensor operands must have the variable's shape; scalars are unconstrained here.
        validator.check("var_shape", var_shape, "mean_square_shape", mean_square_shape, Rel.EQ, self.name)
        validator.check("var_shape", var_shape, "moment_shape", moment_shape, Rel.EQ, self.name)
        validator.check("var_shape", var_shape, "grad_shape", grad_shape, Rel.EQ, self.name)
        return var_shape

    def infer_dtype(self, var_dtype, mean_square_dtype, moment_dtype, grad_dtype, learning_rate_dtype, decay_dtype,
                    momentum_dtype, epsilon_dtype):
        # var/mean_square/moment/grad: same numeric tensor type.
        args = {"var": var_dtype, "mean_square": mean_square_dtype, "moment": moment_dtype, "grad": grad_dtype}
        validator.check_tensor_type_same(args, mstype.number_type, self.name)
        valid_types = [mstype.float16, mstype.float32]
        # decay/momentum/epsilon: plain (non-tensor) float16/float32 values of one type.
        args_decay = {"decay": decay_dtype, 'momentum': momentum_dtype, "epsilon": epsilon_dtype}
        validator.check_type_same(args_decay, valid_types, self.name)
        # learning_rate may be scalar or tensor; allow_mix lets it pair with the scalar decay.
        args_lr = {"learning_rate": learning_rate_dtype, "decay": decay_dtype}
        validator.check_scalar_or_tensor_type_same(args_lr, valid_types, self.name, allow_mix=True)
        return var_dtype
class ApplyCenteredRMSProp(PrimitiveWithInfer):
    """
    Optimizer that implements the centered RMSProp algorithm.
    Please refer to the usage in source code of `nn.RMSProp`.
    Note:
        Update `var` according to the centered RMSProp algorithm.
        .. math::
            g_{t} = \\rho g_{t-1} + (1 - \\rho)\\nabla Q_{i}(w)
        .. math::
            s_{t} = \\rho s_{t-1} + (1 - \\rho)(\\nabla Q_{i}(w))^2
        .. math::
            m_{t} = \\beta m_{t-1} + \\frac{\\eta} {\\sqrt{s_{t} - g_{t}^2 + \\epsilon}} \\nabla Q_{i}(w)
        .. math::
            w = w - m_{t}
        where, :math:`w` represents `var`, which will be updated.
        :math:`g_{t}` represents `mean_gradient`, :math:`g_{t-1}` is the last momentent of :math:`g_{t}`.
        :math:`s_{t}` represents `mean_square`, :math:`s_{t-1}` is the last momentent of :math:`s_{t}`,
        :math:`m_{t}` represents `moment`, :math:`m_{t-1}` is the last momentent of :math:`m_{t}`.
        :math:`\\rho` represents `decay`. :math:`\\beta` is the momentum term, represents `momentum`.
        :math:`\\epsilon` is a smoothing term to avoid division by zero, represents `epsilon`.
        :math:`\\eta` represents `learning_rate`. :math:`\\nabla Q_{i}(w)` represents `grad`.
    Args:
        use_locking (bool): Enable a lock to protect the update of variable tensors. Default: False.
    Inputs:
        - **var** (Tensor) - Weights to be update.
        - **mean_gradient** (Tensor) - Mean gradients, must have the same type as `var`.
        - **mean_square** (Tensor) - Mean square gradients, must have the same type as `var`.
        - **moment** (Tensor) - Delta of `var`, must have the same type as `var`.
        - **grad** (Tensor) - Gradients, must have the same type as `var`.
        - **learning_rate** (Union[Number, Tensor]) - Learning rate.
        - **decay** (float) - Decay rate.
        - **momentum** (float) - Momentum.
        - **epsilon** (float) - Ridge term.
    Outputs:
        Tensor, parameters to be update.
    """

    @prim_attr_register
    def __init__(self, use_locking=False):
        # use_locking must be a bool; kept as an attribute for the backend kernel.
        self.use_locking = validator.check_value_type("use_locking", use_locking, [bool], self.name)

    def infer_shape(self, var_shape, mean_gradient_shape, mean_square_shape, moment_shape, grad_shape,
                    learning_rate_shape, decay_shape, momentum_shape, epsilon_shape):
        # Every tensor operand must match the variable's shape exactly.
        validator.check("var_shape", var_shape, "mean_gradient_shape", mean_gradient_shape, Rel.EQ, self.name)
        validator.check("var_shape", var_shape, "mean_square_shape", mean_square_shape, Rel.EQ, self.name)
        validator.check("var_shape", var_shape, "moment_shape", moment_shape, Rel.EQ, self.name)
        validator.check("var_shape", var_shape, "grad_shape", grad_shape, Rel.EQ, self.name)
        return var_shape

    def infer_dtype(self, var_dtype, mean_gradient_dtype, mean_square_dtype, moment_dtype, grad_dtype,
                    learning_rate_dtype, rho_dtype, momentum_dtype, epsilon_dtype):
        # Tensor operands share one numeric type.
        args = {"var": var_dtype, "mean_gradient": mean_gradient_dtype,
                "mean_square": mean_square_dtype, "moment": moment_dtype, "grad": grad_dtype}
        validator.check_tensor_type_same(args, mstype.number_type, self.name)
        # Hyper-parameters may be scalars or tensors of float16/float32.
        args = {"learning_rate": learning_rate_dtype, "rho": rho_dtype, 'momentum': momentum_dtype,
                "epsilon": epsilon_dtype}
        validator.check_scalar_or_tensor_type_same(args, [mstype.float16, mstype.float32], self.name)
        return var_dtype
  1319. class LayerNorm(Primitive):
  1320. r"""
  1321. Applies the Layer Normalization to the input tensor.
  1322. This operator will normalize the input tensor on given axis. LayerNorm is described in the paper
  1323. `Layer Normalization <https://arxiv.org/abs/1607.06450>`_.
  1324. .. math::
  1325. y = \frac{x - mean]}{\sqrt{variance + \epsilon}} * \gamma + \beta
  1326. where :math:`\gamma` is scale, :math:`\beta` is bias, :math:`\epsilon` is epsilon.
  1327. Args:
  1328. begin_norm_axis (int): The begin axis of the `input_x` to apply LayerNorm,
  1329. the value should be in [-1, rank(input)). Default: 1.
  1330. begin_params_axis (int): The begin axis of the parameter input (`gamma`, `beta`) to
  1331. apply LayerNorm, the value should be in [-1, rank(input)). Default: 1.
  1332. Inputs:
  1333. - **input_x** (Tensor) - Tensor of shape :math:`(N, \ldots)`.
  1334. The input of LayerNorm.
  1335. - **gamma** (Tensor) - Tensor of shape :math:`(P_0, \ldots, P_\text{begin_params_axis})`.
  1336. The learnable parameter `gamma` as the scale on norm.
  1337. - **beta** (Tensor) - Tensor of shape :math:`(P_0, \ldots, P_\text{begin_params_axis})`.
  1338. The learnable parameter `beta` as the scale on norm.
  1339. Outputs:
  1340. tuple[Tensor], tuple of 3 tensors, the normalized input and the updated parameters.
  1341. - **output_x** (Tensor) - The normalized input, has the same type and shape as the `input_x`.
  1342. The shape is :math:`(N, C)`.
  1343. - **updated_gamma** (Tensor) - Tensor of shape :math:`(C,)`.
  1344. - **updated_beta** (Tensor) - Tensor of shape :math:`(C,)`.
  1345. Examples:
  1346. >>> input_x = Tensor(np.array([[1, 2, 3], [1, 2, 3]]), mindspore.float32)
  1347. >>> gamma = Tensor(np.ones([3]), mindspore.float32)
  1348. >>> beta = Tensor(np.ones([3]), mindspore.float32)
  1349. >>> layer_norm = P.LayerNorm()
  1350. >>> output = layer_norm(input_x, gamma, beta)
  1351. ([[-0.22474492, 1., 2.2247488], [-0.22474492, 1., 2.2247488]],
  1352. [[2.], [2.]], [[0.6666667], [0.6666667]])
  1353. """
  1354. @prim_attr_register
  1355. def __init__(self, begin_norm_axis=1, begin_params_axis=1):
  1356. validator.check_value_type('begin_norm_axis', begin_norm_axis, [int], self.name)
  1357. validator.check_value_type('begin_params_axis', begin_params_axis, [int], self.name)
  1358. class L2Normalize(PrimitiveWithInfer):
  1359. r"""
  1360. L2 normalization Operator.
  1361. This operator will normalizes the input using the given axis. The function is shown as follows:
  1362. .. math::
  1363. \text{output} = \frac{x}{\sqrt{\text{max}(\text{sum} (\text{input_x}^2), \epsilon)}},
  1364. where :math:`\epsilon` is epsilon.
  1365. Args:
  1366. axis (int): The begin axis for the input to apply L2 normalize. Default: 0.
  1367. epsilon (float): A small value added for numerical stability. Default: 1e-4.
  1368. Inputs:
  1369. - **input_x** (Tensor) - Input to compute the normalization.
  1370. Outputs:
  1371. Tensor, with the same type and shape as the input.
  1372. """
  1373. @prim_attr_register
  1374. def __init__(self, axis=0, epsilon=1e-4):
  1375. validator.check_value_type('axis', axis, [int], self.name)
  1376. validator.check_value_type('epsilon', epsilon, [int, float], self.name)
  1377. def infer_shape(self, input_x):
  1378. dim = len(input_x)
  1379. validator.check_int_range('axis value', self.axis, -dim, dim, Rel.INC_LEFT, self.name)
  1380. return input_x
  1381. def infer_dtype(self, input_x):
  1382. validator.check_subclass("x", input_x, mstype.tensor, self.name)
  1383. return input_x
  1384. class DropoutGenMask(Primitive):
  1385. """
  1386. Generates the mask value for the input shape.
  1387. Args:
  1388. Seed0 (int): Seed0 value for random generating. Default: 0.
  1389. Seed1 (int): Seed1 value for random generating. Default: 0.
  1390. Inputs:
  1391. - **shape** (tuple[int]) - The shape of target mask.
  1392. - **keep_prob** (Tensor) - The keep rate, between 0 and 1, e.g. keep_prob = 0.9,
  1393. means dropping out 10% of input units.
  1394. Outputs:
  1395. Tensor, the value of generated mask for input shape.
  1396. Examples:
  1397. >>> dropout_gen_mask = P.DropoutGenMask()
  1398. >>> shape = (20, 16, 50)
  1399. >>> keep_prob = Tensor(0.5, mindspore.float32)
  1400. >>> mask = dropout_gen_mask(shape, keep_prob)
  1401. """
  1402. @prim_attr_register
  1403. def __init__(self, Seed0=0, Seed1=0):
  1404. self.init_prim_io_names(inputs=['shape', 'keep_prob'], outputs=['output'])
  1405. validator.check_value_type("Seed0", Seed0, [int], self.name)
  1406. validator.check_value_type("Seed1", Seed1, [int], self.name)
  1407. class DropoutDoMask(PrimitiveWithInfer):
  1408. """
  1409. Applies dropout mask on the input tensor.
  1410. Take the mask output of DropoutGenMask as input, and apply dropout on the input.
  1411. Inputs:
  1412. - **input_x** (Tensor) - The input tensor.
  1413. - **mask** (Tensor) - The mask to be applied on `input_x`, which is the output of `DropoutGenMask`. And the
  1414. shape of `input_x` must be same as the value of `DropoutGenMask`'s input `shape`. If input wrong `mask`,
  1415. the output of `DropoutDoMask` are unpredictable.
  1416. - **keep_prob** (Tensor) - The keep rate, between 0 and 1, e.g. keep_prob = 0.9,
  1417. means dropping out 10% of input units. The value of `keep_prob` is same as the input `keep_prob` of
  1418. `DropoutGenMask`.
  1419. Outputs:
  1420. Tensor, the value that applied dropout on.
  1421. Examples:
  1422. >>> x = Tensor(np.ones([20, 16, 50]), mindspore.float32)
  1423. >>> shape = (20, 16, 50)
  1424. >>> keep_prob = Tensor(0.5, mindspore.float32)
  1425. >>> dropout_gen_mask = P.DropoutGenMask()
  1426. >>> dropout_do_mask = P.DropoutDoMask()
  1427. >>> mask = dropout_gen_mask(shape, keep_prob)
  1428. >>> output = dropout_do_mask(x, mask, keep_prob)
  1429. >>> assert output.shape() == (20, 16, 50)
  1430. """
  1431. @prim_attr_register
  1432. def __init__(self):
  1433. pass
  1434. def __infer__(self, input_x, mask, keep_prob):
  1435. input_x_shape = input_x['shape']
  1436. mask_shape = mask['shape']
  1437. keep_prob_shape = keep_prob['shape']
  1438. validator.check("keep_prob's dim", len(keep_prob_shape), '0(scalar)', 0, Rel.EQ, self.name)
  1439. size_x = reduce(lambda x, y: x * y, input_x_shape)
  1440. if len(mask_shape) != 1:
  1441. raise ValueError("DropoutDoMask mask shape should be 1-dimension.")
  1442. size_y = mask_shape[0] * 8
  1443. if size_x > size_y:
  1444. raise ValueError(f"DropoutDoMask y mask do not math input input_x shape:"
  1445. "{input_x_shape}, mask shape: {mask_shape}.")
  1446. validator.check_tensor_type_same({"input_x": input_x['dtype']}, [mstype.float32, mstype.float16, mstype.int32],
  1447. self.name)
  1448. validator.check_tensor_type_same({"input_mask": mask['dtype']}, [mstype.uint8], self.name)
  1449. keep_prob_v = keep_prob['value']
  1450. if keep_prob_v is not None:
  1451. validator.check_number_range('keep_prob', keep_prob_v.asnumpy(), 0, 1, Rel.INC_BOTH, self.name)
  1452. out = {'shape': input_x_shape,
  1453. 'dtype': input_x['dtype'],
  1454. 'value': None}
  1455. return out
  1456. class ResizeBilinear(PrimitiveWithInfer):
  1457. r"""
  1458. Resizes the image to certain size using bilinear interpolation.
  1459. The resizing only affects the lower two dimensions which represent the height and width. The input images
  1460. can be represented by different data types, but the data types of output images are always float32.
  1461. Args:
  1462. size (tuple[int]): A tuple of 2 int elements `(new_height, new_width)`, the new size for the images.
  1463. align_corners (bool): If it's true, rescale input by `(new_height - 1) / (height - 1)`,
  1464. which exactly aligns the 4 corners of images and resized images. If it's false,
  1465. rescale by `new_height / height`. Default: False.
  1466. Inputs:
  1467. - **input** (Tensor) - Image to be resized. Tensor of shape `(N_i, ..., N_n, height, width)`.
  1468. Outputs:
  1469. Tensor, resized image. Tensor of shape `(N_i, ..., N_n, new_height, new_width)` in `float32`.
  1470. Examples:
  1471. >>> tensor = Tensor([[[[1, 2, 3, 4, 5], [1, 2, 3, 4, 5]]]], mindspore.int32)
  1472. >>> resize_bilinear = P.ResizeBilinear((5, 5))
  1473. >>> result = resize_bilinear(tensor)
  1474. >>> assert result.shape() == (5, 5)
  1475. """
  1476. @prim_attr_register
  1477. def __init__(self, size, align_corners=False):
  1478. pass
  1479. def infer_shape(self, input_shape):
  1480. input_shape = list(input_shape)
  1481. batch, channel, _, _ = input_shape
  1482. out_shape = [batch, channel]
  1483. for i in self.size:
  1484. out_shape.append(int(i))
  1485. return out_shape
  1486. def infer_dtype(self, input_dtype):
  1487. return mstype.tensor_type(mstype.float32)
  1488. class OneHot(PrimitiveWithInfer):
  1489. r"""
  1490. Computes a one-hot tensor.
  1491. Makes a new tensor, whose locations represented by indices in `indices` take value `on_value`, while all
  1492. other locations take value `off_value`.
  1493. Note:
  1494. If the input indices is rank `N`, the output will have rank `N+1`. The new axis is created at dimension `axis`.
  1495. Args:
  1496. axis (int): Position to insert the value. e.g. If `indices` shape is [n, c], and `axis` is `-1` the output shape
  1497. will be [n, c, depth], If `axis` is `0` the output shape will be [depth, n, c]. Default: -1.
  1498. Inputs:
  1499. - **indices** (Tensor) - A tensor of indices. Tensor of shape :math:`(X_0, \ldots, X_n)`.
  1500. - **depth** (int) - A scalar defining the depth of the one hot dimension.
  1501. - **on_value** (Tensor) - A value to fill in output when `indices[j] = i`.
  1502. - **off_value** (Tensor) - A value to fill in output when `indices[j] != i`.
  1503. Outputs:
  1504. Tensor, one_hot tensor. Tensor of shape :math:`(X_0, \ldots, X_{axis}, \text{depth} ,X_{axis+1}, \ldots, X_n)`.
  1505. Examples:
  1506. >>> indices = Tensor(np.array([0, 1, 2]), mindspore.int32)
  1507. >>> depth, on_value, off_value = 3, Tensor(1.0, mindspore.float32), Tensor(0.0, mindspore.float32)
  1508. >>> onehot = P.OneHot()
  1509. >>> result = onehot(indices, depth, on_value, off_value)
  1510. [[1, 0, 0], [0, 1, 0], [0, 0, 1]]
  1511. """
  1512. @prim_attr_register
  1513. def __init__(self, axis=-1):
  1514. self.init_prim_io_names(inputs=['indices', 'depth', 'on_value', 'off_value'], outputs=['output'])
  1515. validator.check_value_type("axis", axis, [int], self.name)
  1516. def __infer__(self, indices, depth, on_value, off_value):
  1517. # check type
  1518. validator.check_tensor_type_same({"indices": indices['dtype']}, (mstype.int32,), self.name)
  1519. validator.check_type_name("depth", depth['dtype'], mstype.int_type, self.name)
  1520. args = {"on_value": on_value['dtype'], "off_value": off_value['dtype']}
  1521. validator.check_tensor_type_same(args, (mstype.float16, mstype.float32), self.name)
  1522. # check shape
  1523. indices_shp = indices['shape']
  1524. validator.check_int_range("axis", self.axis, -1, len(indices_shp), Rel.INC_BOTH, self.name)
  1525. depth_val = depth['value']
  1526. validator.check_integer("depth", depth_val, 0, Rel.GE, self.name)
  1527. # create new dimension at end if self.axis is -1
  1528. indices_shp.insert(self.axis, depth_val) if self.axis >= 0 else indices_shp.append(depth_val)
  1529. return {'shape': indices_shp,
  1530. 'dtype': on_value['dtype'],
  1531. 'value': None}
  1532. class Gelu(PrimitiveWithInfer):
  1533. r"""
  1534. Gaussian Error Linear Units activation function.
  1535. GeLU is described in the paper `Gaussian Error Linear Units (GELUs) <https://arxiv.org/abs/1606.08415>`_.
  1536. And also please refer to `BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding.
  1537. <https://arxiv.org/abs/1810.04805>`_.
  1538. Defined as follows:
  1539. .. math::
  1540. \text{output} = 0.5 * x * (1 + erf(x / \sqrt{2})),
  1541. where :math:`erf` is the "Gauss error function" .
  1542. Inputs:
  1543. - **input_x** (Tensor) - Input to compute the Gelu.
  1544. Outputs:
  1545. Tensor, with the same type and shape as input.
  1546. Examples:
  1547. >>> tensor = Tensor(np.array([1.0, 2.0, 3.0]), mindspore.float32)
  1548. >>> gelu = P.Gelu()
  1549. >>> result = gelu(tensor)
  1550. """
  1551. @prim_attr_register
  1552. def __init__(self):
  1553. """init GeLU"""
  1554. self.init_prim_io_names(inputs=['x'], outputs=['output'])
  1555. def infer_shape(self, input_x):
  1556. return input_x
  1557. def infer_dtype(self, input_x):
  1558. validator.check_tensor_type_same({"input_x": input_x}, (mstype.float16, mstype.float32), self.name)
  1559. return input_x
  1560. class GetNext(PrimitiveWithInfer):
  1561. """
  1562. Returns the next element in the dataset queue.
  1563. Note:
  1564. GetNext op needs to be associated with network and also depends on the init_dataset interface,
  1565. it can't be used directly as a single op.
  1566. For details, please refer to `nn.cell_wrapper.DataWrapper` source code.
  1567. Args:
  1568. types (list[:class:`mindspore.dtype`]): The type of the outputs.
  1569. shapes (list[tuple[int]]): The dimensionality of the outputs.
  1570. output_num (int): The output number, length of `types` and `shapes`.
  1571. shared_name (str): The queue name of `init_dataset` interface.
  1572. Inputs:
  1573. No inputs.
  1574. Outputs:
  1575. tuple[Tensor], the output of Dataset. The shape is described in `shapes`
  1576. and the type is described is `types`.
  1577. Examples:
  1578. >>> get_next = P.GetNext([mindspore.float32, mindspore.int32], [[32, 1, 28, 28], [10]], 2, 'shared_name')
  1579. >>> feature, label = get_next()
  1580. """
  1581. @prim_attr_register
  1582. def __init__(self, types, shapes, output_num, shared_name):
  1583. validator.check_value_type("types", types, [list, tuple], self.name)
  1584. validator.check_value_type("shapes", shapes, [list, tuple], self.name)
  1585. validator.check("types length", len(types), "shapes length", len(shapes), Rel.EQ, self.name)
  1586. validator.check_value_type("output_num", output_num, [int], self.name)
  1587. def infer_shape(self):
  1588. return tuple(self.shapes)
  1589. def infer_dtype(self):
  1590. return tuple(self.types)
  1591. class PReLU(PrimitiveWithInfer):
  1592. r"""
  1593. Parametric Rectified Linear Unit activation function.
  1594. PReLU is described in the paper `Delving Deep into Rectifiers: Surpassing Human-Level Performance on
  1595. ImageNet Classification <https://arxiv.org/abs/1502.01852>`_. Defined as follows:
  1596. .. math::
  1597. prelu(x_i)= \max(0, x_i) + \min(0, w * x_i),
  1598. where :math:`x_i` is an element of an channel of the input.
  1599. Inputs:
  1600. - **input_x** (Tensor) - Float tensor, representing the output of the preview layer.
  1601. - **weight** (Tensor) - Float Tensor, w > 0, there is only two shapes are legitimate,
  1602. 1 or the number of channels at input.
  1603. Outputs:
  1604. Tensor, with the same type as `input_x`.
  1605. Detailed information, please refer to `nn.PReLU`.
  1606. """
  1607. @prim_attr_register
  1608. def __init__(self):
  1609. pass
  1610. def infer_shape(self, input_x_shape, weight_shape):
  1611. input_x_dim = len(input_x_shape)
  1612. weight_dim = len(weight_shape)
  1613. if weight_dim != 1:
  1614. raise ValueError(f'For \'{self.name}\' weight_dim must be 1, while weight_dim is {weight_dim}.')
  1615. if input_x_dim == 1 and weight_shape[0] != 1:
  1616. raise ValueError(f'For \'{self.name}\' when input_x_dim is 1, weight_shape[0] must be 1, '
  1617. f'while weight_shape[0] is {weight_shape[0]}.')
  1618. if input_x_dim != 1 and weight_shape[0] != input_x_shape[1] and weight_shape[0] != 1:
  1619. raise ValueError(f'For \'{self.name}\' channel of input_x and weight must be matched,'
  1620. f' while channel of input_x is {input_x_shape[1]},'
  1621. f' weight_shape[0] is {weight_shape[0]}.')
  1622. return input_x_shape
  1623. def infer_dtype(self, input_x_dtype, weight_dtype):
  1624. args = {"input_x": input_x_dtype, "weight": weight_dtype}
  1625. validator.check_tensor_type_same(args, (mstype.float16, mstype.float32), self.name)
  1626. return input_x_dtype
  1627. class LSTM(PrimitiveWithInfer):
  1628. """
  1629. Performs the long short term memory(LSTM) on the input.
  1630. Detailed information, please refer to `nn.LSTM`.
  1631. """
  1632. @prim_attr_register
  1633. def __init__(self, input_size, hidden_size, num_layers, has_bias, bidirectional, dropout):
  1634. self.input_size = validator.check_integer("input_size", input_size, 0, Rel.GT, self.name)
  1635. self.hidden_size = validator.check_integer("hidden_size", hidden_size, 0, Rel.GT, self.name)
  1636. self.num_layers = validator.check_integer("num_layers", num_layers, 0, Rel.GT, self.name)
  1637. self.has_bias = validator.check_value_type("has_bias", has_bias, (bool,), self.name)
  1638. self.bidirectional = validator.check_value_type("bidirectional", bidirectional, (bool,), self.name)
  1639. self.dropout = validator.check_value_type("dropout", dropout, [float], self.name)
  1640. self.dropout = validator.check_number_range('dropout', dropout, 0, 1, Rel.INC_BOTH, self.name)
  1641. if bidirectional:
  1642. self.num_directions = 2
  1643. else:
  1644. self.num_directions = 1
  1645. def infer_shape(self, x_shape, h_shape, c_shape, w_shape):
  1646. # (batch, seq, feature)
  1647. validator.check_integer("x rank", len(x_shape), 3, Rel.EQ, self.name)
  1648. # h and c should be same shape
  1649. validator.check_integer("h rank", len(h_shape), 3, Rel.EQ, self.name)
  1650. validator.check("h_shape", h_shape, "c_shape", c_shape, Rel.EQ, self.name)
  1651. # (num_layers * num_directions, batch, hidden_size)
  1652. validator.check_integer("h[0]", h_shape[0], self.num_layers * self.num_directions, Rel.EQ, self.name)
  1653. validator.check_integer("h[1]", h_shape[1], x_shape[1], Rel.EQ, self.name)
  1654. validator.check_integer("h[2]", h_shape[2], self.hidden_size, Rel.EQ, self.name)
  1655. y_shape = (x_shape[0], x_shape[1], self.hidden_size * self.num_directions)
  1656. # set arbitrary shape for reserved space
  1657. reserved_shape = (1, 1)
  1658. state_shape = (1, 1)
  1659. return (y_shape, h_shape, c_shape, reserved_shape, state_shape)
  1660. def infer_dtype(self, x_dtype, h_dtype, c_dtype, w_dtype):
  1661. args = {'x': x_dtype, 'h': h_dtype, 'c': c_dtype, 'w': w_dtype}
  1662. validator.check_tensor_type_same(args, (mstype.float32, mstype.float16), self.name)
  1663. return (x_dtype, x_dtype, x_dtype, x_dtype, x_dtype)
  1664. class SigmoidCrossEntropyWithLogits(PrimitiveWithInfer):
  1665. r"""
  1666. Uses the given logits to compute sigmoid cross entropy.
  1667. Note:
  1668. Sets input logits as `X`, input label as `Y`, output as `loss`. Then,
  1669. .. math::
  1670. p_{ij} = sigmoid(X_{ij}) = \frac{1}{1 + e^{-X_{ij}}}
  1671. .. math::
  1672. loss_{ij} = -[Y_{ij} * ln(p_{ij}) + (1 - Y_{ij})ln(1 - p_{ij})]
  1673. Inputs:
  1674. - **logits** (Tensor) - Input logits.
  1675. - **label** (Tensor) - Ground truth label.
  1676. Outputs:
  1677. Tensor, with the same shape and type as input `logits`.
  1678. Examples:
  1679. >>> logits = Tensor(np.random.randn(2, 3).astype(np.float16))
  1680. >>> labels = Tensor(np.random.randn(2, 3).astype(np.float16))
  1681. >>> sigmoid = P.SigmoidCrossEntropyWithLogits()
  1682. >>> sigmoid(logits, labels)
  1683. """
  1684. @prim_attr_register
  1685. def __init__(self):
  1686. """Init SigmoidCrossEntropyWithLogits"""
  1687. self.init_prim_io_names(inputs=['predict', 'target'], outputs=['loss'])
  1688. def infer_shape(self, x_shape, y_shape):
  1689. validator.check("x_shape", x_shape, "y_shape", y_shape, Rel.EQ, self.name)
  1690. return x_shape
  1691. def infer_dtype(self, x_dtype, y_dtype):
  1692. args = {"x_dtype": x_dtype, "y_dtype": y_dtype}
  1693. validator.check_tensor_type_same(args, mstype.number_type, self.name)
  1694. return x_dtype
  1695. class Pad(PrimitiveWithInfer):
  1696. """
  1697. Pads input tensor according to the paddings.
  1698. Args:
  1699. paddings (tuple): The shape of parameter `paddings` is (N, 2). N is the rank of input data. All elements of
  1700. paddings are int type. For `D` th dimension of input, paddings[D, 0] indicates how many sizes to be
  1701. extended ahead of the `D` th dimension of the input tensor, and paddings[D, 1] indicates how many sizes to
  1702. be extended behind of the `D` th dimension of the input tensor.
  1703. Inputs:
  1704. - **input_x** (Tensor) - The input tensor.
  1705. Outputs:
  1706. Tensor, the tensor after padding.
  1707. Examples:
  1708. >>> input_tensor = Tensor(np.array([[-0.1, 0.3, 3.6], [0.4, 0.5, -3.2]]), mindspore.float32)
  1709. >>> pad_op = P.Pad(((1, 2), (2, 1)))
  1710. >>> output_tensor = pad_op(input_tensor)
  1711. >>> assert output_tensor == Tensor(np.array([[ 0. , 0. , 0. , 0. , 0. , 0. ],
  1712. >>> [ 0. , 0. , -0.1, 0.3, 3.6, 0. ],
  1713. >>> [ 0. , 0. , 0.4, 0.5, -3.2, 0. ],
  1714. >>> [ 0. , 0. , 0. , 0. , 0. , 0. ],
  1715. >>> [ 0. , 0. , 0. , 0. , 0. , 0. ]]), mindspore.float32)
  1716. """
  1717. @prim_attr_register
  1718. def __init__(self, paddings):
  1719. """Init Pad"""
  1720. self.init_prim_io_names(inputs=['x'], outputs=['y'])
  1721. if not isinstance(paddings, tuple):
  1722. raise TypeError('Paddings must be tuple type.')
  1723. for item in paddings:
  1724. if len(item) != 2:
  1725. raise ValueError('The shape of paddings must be (n, 2).')
  1726. self.paddings = paddings
  1727. def infer_shape(self, x):
  1728. paddings = np.array(self.paddings)
  1729. validator.check_integer('paddings.shape', paddings.size, len(x) * 2, Rel.EQ, self.name)
  1730. if not np.all(paddings >= 0):
  1731. raise ValueError('All elements of paddings must be >= 0.')
  1732. y_shape = ()
  1733. for i in range(int(paddings.size / 2)):
  1734. y_shape += ((x[i] + paddings[i, 0] + paddings[i, 1]),)
  1735. return y_shape
  1736. def infer_dtype(self, x):
  1737. validator.check_subclass("input_x", x, mstype.tensor, self.name)
  1738. return x
  1739. class MirrorPad(PrimitiveWithInfer):
  1740. """
  1741. Pads the input tensor according to the paddings and mode.
  1742. Args:
  1743. mode (string): Specifies padding mode. The optional values are "REFLECT", "SYMMETRIC".
  1744. Default: "REFLECT".
  1745. Inputs:
  1746. - **input_x** (Tensor) - The input tensor.
  1747. - **paddings** (Tensor) - The paddings tensor. The value of `paddings` is a matrix(list),
  1748. and its shape is (N, 2). N is the rank of input data. All elements of paddings
  1749. are int type. For `D` th dimension of input, paddings[D, 0] indicates how many sizes to be
  1750. extended ahead of the `D` th dimension of the input tensor, and paddings[D, 1] indicates
  1751. how many sizes to be extended behind of the `D` th dimension of the input tensor.
  1752. Outputs:
  1753. Tensor, the tensor after padding.
  1754. - If 'mode` is "REFLECT", it uses a way of symmetrical copying throught the axis of symmetry to fill in,
  1755. symmetry. If the `input_x` is [[1,2,3],[4,5,6],[7,8,9]] and `paddings` is [[1,1],[2,2]], then the
  1756. Outputs is [[6,5,4,5,6,5,4],[3,2,1,2,3,2,1],[6,5,4,5,6,5,4],[9,8,7,8,9,8,7],[6,5,4,5,6,5,4]].
  1757. - If 'mode' is "SYMMETRIC", the filling method is similar to the "REFLECT". It is also copied
  1758. according to the symmetry axis, except that it includes the symmetry axis. If the `input_x`
  1759. is [[1,2,3],[4,5,6],[7,8,9]] and `paddings` is [[1,1],[2,2]], then the Outputs is
  1760. [[2,1,1,2,3,3,2],[2,1,1,2,3,3,2],[5,4,4,5,6,6,5],[8,7,7,8,9,9,8],[8,7,7,8,9,9,8]].
  1761. Examples:
  1762. >>> from mindspore import Tensor
  1763. >>> from mindspore.ops import operations as P
  1764. >>> import mindspore.nn as nn
  1765. >>> import numpy as np
  1766. >>> class Net(nn.Cell):
  1767. >>> def __init__(self):
  1768. >>> super(Net, self).__init__()
  1769. >>> self.pad = P.MirrorPad(mode="REFLECT")
  1770. >>> def construct(self, x, paddings):
  1771. >>> return self.pad(x, paddings)
  1772. >>> x = np.random.random(size=(2, 3)).astype(np.float32)
  1773. >>> paddings = Tensor([[1,1],[2,2]])
  1774. >>> pad = Net()
  1775. >>> ms_output = pad(Tensor(x), paddings)
  1776. """
  1777. @prim_attr_register
  1778. def __init__(self, mode='REFLECT'):
  1779. """Init Pad"""
  1780. validator.check_string('mode', mode, ['REFLECT', 'SYMMETRIC'], self.name)
  1781. self.mode = mode
  1782. def __infer__(self, input_x, paddings):
  1783. validator.check_subclass("input_x", input_x['dtype'], mstype.tensor, self.name)
  1784. validator.check_subclass("paddings", paddings['dtype'], mstype.tensor, self.name)
  1785. x_shape = list(input_x['shape'])
  1786. paddings_value = paddings['value'].asnumpy()
  1787. paddings_size = paddings_value.size
  1788. validator.check_integer('paddings.shape', paddings_size, len(x_shape) * 2, Rel.EQ, self.name)
  1789. if not np.all(paddings_size >= 0):
  1790. raise ValueError('All elements of paddings must be >= 0.')
  1791. y_shape = ()
  1792. for i in range(0, int(paddings_size / 2)):
  1793. y_shape += ((x_shape[i] + paddings_value[i, 0] + paddings_value[i, 1]),)
  1794. return {'shape': y_shape,
  1795. 'dtype': input_x['dtype'],
  1796. 'value': None}
  1797. class ROIAlign(PrimitiveWithInfer):
  1798. """
  1799. Computes Region of Interest (RoI) Align operator.
  1800. The operator computes the value of each sampling point by bilinear interpolation from the nearby grid points on the
  1801. feature map. No quantization is performed on any coordinates involved in the RoI, its bins, or the sampling
  1802. points. The details of (RoI) Align operator are described in `Mask R-CNN <https://arxiv.org/abs/1703.06870>`_.
  1803. Args:
  1804. pooled_height (int): The output features' height.
  1805. pooled_width (int): The output features' width.
  1806. spatial_scale (float): A scaling factor that maps the raw image coordinates to the input
  1807. feature map coordinates. Suppose the height of a RoI is `ori_h` in the raw image and `fea_h` in the
  1808. input feature map, the `spatial_scale` should be `fea_h / ori_h`.
  1809. sample_num (int): Number of sampling points. Default: 2.
  1810. Inputs:
  1811. - **features** (Tensor) - The input features, whose shape should be `(N, C, H, W)`.
  1812. - **rois** (Tensor) - The shape is `(rois_n, 5)`. `rois_n` represents the number of RoI. The size of
  1813. the second dimension should be `5` and the `5` colunms are
  1814. `(image_index, top_left_x, top_left_y, bottom_right_x, bottom_right_y)`. `image_index` represents the
  1815. index of image. `top_left_x` and `top_left_y` represent the `x, y` coordinates of the top left corner
  1816. of corresponding RoI, respectively. `bottom_right_x` and `bottom_right_y` represent the `x, y`
  1817. coordinates of the bottom right corner of corresponding RoI, respectively.
  1818. Outputs:
  1819. Tensor, the shape is `(rois_n, C, pooled_height, pooled_width)`.
  1820. Examples:
  1821. >>> input_tensor = Tensor(np.array([[[[1., 2.], [3., 4.]]]]), mindspore.float32)
  1822. >>> rois = Tensor(np.array([[0, 0.2, 0.3, 0.2, 0.3]]), mindspore.float32)
  1823. >>> roi_align = P.ROIAlign(1, 1, 0.5, 2)
  1824. >>> output_tensor = roi_align(input_tensor, rois)
  1825. >>> assert output_tensor == Tensor(np.array([[[[2.15]]]]), mindspore.float32)
  1826. """
  1827. @prim_attr_register
  1828. def __init__(self, pooled_height, pooled_width, spatial_scale, sample_num=2):
  1829. """init ROIAlign"""
  1830. validator.check_value_type("pooled_height", pooled_height, [int], self.name)
  1831. validator.check_value_type("pooled_width", pooled_width, [int], self.name)
  1832. validator.check_value_type("spatial_scale", spatial_scale, [float], self.name)
  1833. validator.check_value_type("sample_num", sample_num, [int], self.name)
  1834. self.pooled_height = pooled_height
  1835. self.pooled_width = pooled_width
  1836. self.spatial_scale = spatial_scale
  1837. self.sample_num = sample_num
  1838. def infer_shape(self, inputs_shape, rois_shape):
  1839. return [rois_shape[0], inputs_shape[1], self.pooled_height, self.pooled_width]
  1840. def infer_dtype(self, inputs_type, rois_type):
  1841. return inputs_type
  1842. class Adam(PrimitiveWithInfer):
  1843. r"""
  1844. Updates gradients by Adaptive Moment Estimation (Adam) algorithm.
  1845. The Adam algorithm is proposed in `Adam: A Method for Stochastic Optimization <https://arxiv.org/abs/1412.6980>`_.
  1846. The updating formulas are as follows,
  1847. .. math::
  1848. \begin{array}{ll} \\
  1849. m = \beta_1 * m + (1 - \beta_1) * g \\
  1850. v = \beta_2 * v + (1 - \beta_2) * g * g \\
  1851. l = \alpha * \frac{\sqrt{1-\beta_2^t}}{1-\beta_1^t} \\
  1852. w = w - l * \frac{m}{\sqrt{v} + \epsilon}
  1853. \end{array}
  1854. :math:`m` represents the 1st moment vector, :math:`v` represents the 2nd moment vector, :math:`g` represents
  1855. `gradient`, :math:`l` represents scaling factor `lr`, :math:`\beta_1, \beta_2` represent `beta1` and `beta2`,
  1856. :math:`t` represents updating step while :math:`beta_1^t` and :math:`beta_2^t` represent `beta1_power` and
  1857. `beta2_power`, :math:`\alpha` represents `learning_rate`, :math:`w` represents `var`, :math:`\epsilon` represents
  1858. `epsilon`.
  1859. Args:
  1860. use_locking (bool): Whether to enable a lock to protect updating variable tensors.
  1861. If True, updating of the var, m, and v tensors will be protected by a lock.
  1862. If False, the result is unpredictable. Default: False.
  1863. use_nesterov (bool): Whether to use Nesterov Accelerated Gradient (NAG) algorithm to update the gradients.
  1864. If True, updates the gradients using NAG.
  1865. If False, updates the gradients without using NAG. Default: False.
  1866. Inputs:
  1867. - **var** (Tensor) - Weights to be updated.
  1868. - **m** (Tensor) - The 1st moment vector in the updating formula. Has the same type as `var`.
  1869. - **v** (Tensor) - the 2nd moment vector in the updating formula.
  1870. Mean square gradients, has the same type as `var`.
  1871. - **beta1_power** (float) - :math:`beta_1^t` in the updating formula.
  1872. - **beta2_power** (float) - :math:`beta_2^t` in the updating formula.
  1873. - **lr** (Union[float, Tensor, Iterable]) - :math:`l` in the updating formula.
  1874. Iterable type is used for the dynamic learning rate.
  1875. - **beta1** (float) - The exponential decay rate for the 1st moment estimates.
  1876. - **beta2** (float) - The exponential decay rate for the 2nd moment estimates.
  1877. - **epsilon** (float) - Term added to the denominator to improve numerical stability.
  1878. - **gradient** (Tensor) - Gradients.
  1879. Outputs:
  1880. Tuple of 3 Tensor, the updated parameters.
  1881. - **var** (Tensor) - The same shape and data type as `var`.
  1882. - **m** (Tensor) - The same shape and data type as `m`.
  1883. - **v** (Tensor) - The same shape and data type as `v`.
  1884. Examples:
  1885. Please refer to the usage in nn.Adam.
  1886. """
  1887. @prim_attr_register
  1888. def __init__(self, use_locking=False, use_nesterov=False):
  1889. validator.check_value_type("use_locking", use_locking, [bool], self.name)
  1890. validator.check_value_type("use_nesterov", use_nesterov, [bool], self.name)
  1891. def infer_shape(self, var_shape, m_shape, v_shape, beta1_power_shape, beta2_power_shape, lr_shape,
  1892. beta1_shape, beta2_shape, epsilon_shape, grad_shape):
  1893. validator.check("var_shape", var_shape, "m_shape", m_shape, Rel.EQ, self.name)
  1894. validator.check("var_shape", var_shape, "v_shape", v_shape, Rel.EQ, self.name)
  1895. validator.check("var_shape", var_shape, "grad_shape", grad_shape, Rel.EQ, self.name)
  1896. return var_shape, m_shape, v_shape
  1897. def infer_dtype(self, var_dtype, m_dtype, v_dtype, beta1_power_dtype, beta2_power_dtype, lr_dtype,
  1898. beta1_dtype, beta2_dtype, epsilon_dtype, grad_dtype):
  1899. args = {"var": var_dtype, "m": m_dtype, "v": v_dtype, "grad": grad_dtype}
  1900. validator.check_tensor_type_same(args, mstype.number_type, self.name)
  1901. args = {"beta1_power": beta1_power_dtype, "beta2_power": beta2_power_dtype, 'lr': lr_dtype,
  1902. "beta1": beta1_dtype, "beta2": beta2_dtype, "epsilon": epsilon_dtype}
  1903. validator.check_scalar_or_tensor_type_same(args, [mstype.float16, mstype.float32], self.name, True)
  1904. return var_dtype, m_dtype, v_dtype
  1905. class BinaryCrossEntropy(PrimitiveWithInfer):
  1906. r"""
  1907. Computes the Binary Cross Entropy between the target and the output.
  1908. Note:
  1909. Sets input as :math:`x`, input label as :math:`y`, output as :math:`\ell(x, y)`.
  1910. Let,
  1911. .. math::
  1912. L = \{l_1,\dots,l_N\}^\top, \quad
  1913. l_n = - w_n \left[ y_n \cdot \log x_n + (1 - y_n) \cdot \log (1 - x_n) \right]
  1914. Then,
  1915. .. math::
  1916. \ell(x, y) = \begin{cases}
  1917. L, & \text{if reduction} = \text{'none';}\\
  1918. \operatorname{mean}(L), & \text{if reduction} = \text{'mean';}\\
  1919. \operatorname{sum}(L), & \text{if reduction} = \text{'sum'.}
  1920. \end{cases}
  1921. Args:
  1922. reduction (str): Specifies the reduction to apply to the output.
  1923. Its value should be one of 'none', 'mean', 'sum'. Default: 'mean'.
  1924. Inputs:
  1925. - **input_x** (Tensor) - The input Tensor.
  1926. - **input_y** (Tensor) - The label Tensor which has same shape as `input_x`.
  1927. - **weight** (Tensor, optional) - A rescaling weight applied to the loss of each batch element.
  1928. And it should have same shape as `input_x`. Default: None.
  1929. Outputs:
  1930. Tensor or Scalar, if `reduction` is 'none', then output is a tensor and same shape as `input_x`.
  1931. Otherwise it is a scalar.
  1932. """
  1933. @prim_attr_register
  1934. def __init__(self, reduction='mean'):
  1935. self.reduction = validator.check_string('reduction', reduction, ['none', 'mean', 'sum'], self.name)
  1936. def infer_shape(self, x_shape, y_shape, weight_shape):
  1937. validator.check('x_shape', x_shape, 'y_shape', y_shape, Rel.EQ, self.name)
  1938. if weight_shape:
  1939. validator.check('y_shape', y_shape, 'weight_shape', weight_shape, Rel.EQ, self.name)
  1940. if self.reduction in ('mean', 'sum'):
  1941. shape = []
  1942. else:
  1943. shape = x_shape
  1944. return shape
  1945. def infer_dtype(self, x_type, y_type, weight_type):
  1946. args = {'x': x_type, 'y': y_type}
  1947. valid_types = (mstype.float16, mstype.float32)
  1948. validator.check_tensor_type_same(args, valid_types, self.name)
  1949. if weight_type:
  1950. validator.check_tensor_type_same({'x': x_type, 'weight': weight_type}, valid_types, self.name)
  1951. return x_type
  1952. class SparseApplyAdagrad(PrimitiveWithInfer):
  1953. r"""
  1954. Update relevant entries according to the adagrad scheme.
  1955. .. math::
  1956. accum += grad * grad
  1957. .. math::
  1958. var -= lr * grad * (1 / sqrt(accum))
  1959. Args:
  1960. lr (float): Learning rate.
  1961. use_locking (bool): If True, updating of the var and accum tensors will be protected. Default: False.
  1962. Inputs:
  1963. - **var** (Tensor) - Variable to be updated. The type must be float32.
  1964. - **accum** (Tensor) - Accum to be updated. The shape must be the same as `var`'s shape,
  1965. the type must be float32.
  1966. - **grad** (Tensor) - Gradient. The shape must be the same as `var`'s shape
  1967. except first dimension, the type must be float32.
  1968. - **indices** (Tensor) - A vector of indices into the first dimension of `var` and `accum`.
  1969. The shape of `indices` must be the same as `grad` in first dimension, the type must be int32.
  1970. Outputs:
  1971. Tensor, has the same shape and type as `var`.
  1972. Examples:
  1973. var = Tensor(np.random.random((3, 3)), mindspore.float32)
  1974. accum = Tensor(np.random.random((3, 3)), mindspore.float32)
  1975. grad = Tensor(np.random.random((3, 3)), mindspore.float32)
  1976. indices = Tensor(np.ones((3,), np.int32))
  1977. sparse_apply_ada_grad = P.SparseApplyAdagrad(0.5)
  1978. sparse_apply_ada_grad(var, accum, grad, indices)
  1979. """
  1980. @prim_attr_register
  1981. def __init__(self, lr, use_locking=False):
  1982. self.lr = validator.check_value_type("lr", lr, [float], self.name)
  1983. self.use_locking = validator.check_value_type("use_locking", use_locking, [bool], self.name)
  1984. def infer_shape(self, var_shape, accum_shape, grad_shape, indices_shape):
  1985. validator.check('var shape', var_shape, 'accum shape', accum_shape, Rel.EQ, self.name)
  1986. validator.check('len of var shape', len(var_shape), 'len of grad shape', len(grad_shape), Rel.EQ, self.name)
  1987. if len(var_shape) > 1:
  1988. validator.check('var_shape[1:]', var_shape[1:], 'grad_shape[1:]', grad_shape[1:], Rel.EQ, self.name)
  1989. validator.check_integer("indices rank", len(indices_shape), 1, Rel.EQ, self.name)
  1990. validator.check('grad_shape[0]', grad_shape[0], 'indices_shape[0]', indices_shape[0], Rel.EQ, self.name)
  1991. return var_shape
  1992. def infer_dtype(self, var_type, accum_type, grad_type, indices_type):
  1993. args = {'var': var_type, 'accum': accum_type, 'grad': grad_type}
  1994. validator.check_tensor_type_same(args, (mstype.float32,), self.name)
  1995. validator.check_tensor_type_same({'indices': indices_type}, [mstype.int32], self.name)
  1996. return var_type
  1997. class LARSUpdate(PrimitiveWithInfer):
  1998. """
  1999. Conduct lars (layer-wise adaptive rate scaling) update on the square sum of gradient.
  2000. Args:
  2001. epsilon (float): Term added to the denominator to improve numerical stability. Default: 1e-05.
  2002. hyperpara (float): Trust coefficient for calculating the local learning rate. Default: 0.001.
  2003. use_clip (bool): Whether to use clip operation for calculating the local learning rate. Default: False.
  2004. Inputs:
  2005. - **weight** (Tensor) - The weight to be updated.
  2006. - **gradient** (Tensor) - The gradient of weight, which has the same shape and dtype with weight.
  2007. - **norm_weight** (Tensor) - A scalar tensor, representing the square sum of weight.
  2008. - **norm_gradient** (Tensor) - A scalar tensor, representing the square sum of gradient.
  2009. - **weight_decay** (Union[Number, Tensor]) - Weight decay. It should be a scalar tensor or number.
  2010. - **learning_rate** (Union[Number, Tensor]) - Learning rate. It should be a scalar tensor or number.
  2011. Outputs:
  2012. Tensor, representing the new gradient.
  2013. Examples:
  2014. >>> from mindspore import Tensor
  2015. >>> from mindspore.ops import operations as P
  2016. >>> from mindspore.ops import functional as F
  2017. >>> import mindspore.nn as nn
  2018. >>> import numpy as np
  2019. >>> class Net(nn.Cell):
  2020. >>> def __init__(self):
  2021. >>> super(Net, self).__init__()
  2022. >>> self.lars = P.LARSUpdate()
  2023. >>> self.reduce = P.ReduceSum()
  2024. >>> def construct(self, weight, gradient):
  2025. >>> w_square_sum = self.reduce(F.square(weight))
  2026. >>> grad_square_sum = self.reduce(F.square(gradient))
  2027. >>> grad_t = self.lars(weight, gradient, w_square_sum, grad_square_sum, 0.0, 1.0)
  2028. >>> return grad_t
  2029. >>> weight = np.random.random(size=(2, 3)).astype(np.float32)
  2030. >>> gradient = np.random.random(size=(2, 3)).astype(np.float32)
  2031. >>> net = Net()
  2032. >>> ms_output = net(Tensor(weight), Tensor(gradient))
  2033. """
  2034. @prim_attr_register
  2035. def __init__(self, epsilon=1e-05, hyperpara=0.001, use_clip=False):
  2036. """init"""
  2037. validator.check_value_type("epsilon", epsilon, [float], self.name)
  2038. validator.check_value_type("hyperpara", hyperpara, [float], self.name)
  2039. validator.check_value_type("use_clip", use_clip, [bool], self.name)
  2040. def infer_shape(self, weight_shape, gradient_shape, norm_weight_shape, norm_gradient_shape, weight_decay_shape,
  2041. learning_rate_shape):
  2042. validator.check("weight shape", weight_shape, "gradient shape", gradient_shape, Rel.EQ, self.name)
  2043. validator.check("norm weight shape", norm_weight_shape, "norm gradient shape", norm_gradient_shape, Rel.EQ,
  2044. self.name)
  2045. shp_len = len(weight_decay_shape)
  2046. validator.check_integer("weight decay's rank", shp_len, 1, Rel.LE, self.name)
  2047. if shp_len == 1:
  2048. validator.check_integer("weight_decay_shape[0]", weight_decay_shape[0], 1, Rel.EQ, self.name)
  2049. shp_len = len(learning_rate_shape)
  2050. validator.check_integer("learning rate's rank", shp_len, 1, Rel.LE, self.name)
  2051. if shp_len == 1:
  2052. validator.check_integer("learning_rate_shape[0]", learning_rate_shape[0], 1, Rel.EQ, self.name)
  2053. return weight_shape
  2054. def infer_dtype(self, weight_dtype, gradient_dtype, norm_weight_dtype, norm_gradient_dtype,
  2055. weight_decay_dtype, learning_rate_dtype):
  2056. args = {"Weight dtype": weight_dtype, "gradient dtype": gradient_dtype, "norm weight dtype": norm_weight_dtype,
  2057. "norm gradient dtype": norm_gradient_dtype}
  2058. validator.check_tensor_type_same(args, [mstype.float16, mstype.float32, mstype.int16, mstype.int32], self.name)
  2059. validator.check_scalar_or_tensor_type_same({"weight_decay": weight_decay_dtype},
  2060. [mstype.float16, mstype.float32, mstype.float64], self.name)
  2061. validator.check_scalar_or_tensor_type_same({"learning_rate": learning_rate_dtype},
  2062. [mstype.float16, mstype.float32, mstype.float64], self.name)
  2063. return weight_dtype
  2064. class ApplyFtrl(PrimitiveWithInfer):
  2065. """
  2066. Update relevant entries according to the FTRL scheme.
  2067. Args:
  2068. use_locking (bool): Use locks for update operation if True . Default: False.
  2069. Inputs:
  2070. - **var** (Tensor): The variable to be updated.
  2071. - **accum** (Tensor): The accum to be updated, must be same type and shape as `var`.
  2072. - **linear** (Tensor): The linear to be updated, must be same type and shape as `var`.
  2073. - **grad** (Tensor): Gradient.
  2074. - **lr** (Union[Number, Tensor]): The learning rate value, must be positive. Default: 0.001.
  2075. - **l1** (Union[Number, Tensor]): l1 regularization strength, must be greater than or equal to zero.
  2076. Default: 0.0.
  2077. - **l2** (Union[Number, Tensor]): l2 regularization strength, must be greater than or equal to zero.
  2078. Default: 0.0.
  2079. - **lr_power** (Union[Number, Tensor]): Learning rate power controls how the learning rate decreases
  2080. during training, must be less than or equal to zero. Use fixed learning rate if lr_power is zero.
  2081. Default: -0.5.
  2082. Outputs:
  2083. Tensor, representing the updated var.
  2084. """
  2085. @prim_attr_register
  2086. def __init__(self, use_locking=False):
  2087. self.init_prim_io_names(inputs=['var', 'accum', 'linear', 'grad', 'lr', 'l1', 'l2', 'lr_power'],
  2088. outputs=['output'])
  2089. self.use_locking = validator.check_value_type("use_locking", use_locking, [bool], self.name)
  2090. def infer_shape(self, var_shape, accum_shape, linear_shape, grad_shape, lr_shape, l1_shape, l2_shape,
  2091. lr_power_shape):
  2092. validator.check('var shape', var_shape, 'accum shape', accum_shape, Rel.EQ, self.name)
  2093. validator.check('var shape', var_shape, 'linear shape', linear_shape, Rel.EQ, self.name)
  2094. return var_shape
  2095. def infer_dtype(self, var_type, accum_type, linear_type, grad_type, lr_type, l1_type, l2_type, lr_power_type):
  2096. valid_types = [mstype.float16, mstype.float32]
  2097. args = {'var': var_type, 'accum': accum_type, 'linear': linear_type, 'grad': grad_type}
  2098. validator.check_tensor_type_same(args, valid_types, self.name)
  2099. validator.check_scalar_or_tensor_type_same({"lr": lr_type}, valid_types, self.name)
  2100. validator.check_scalar_or_tensor_type_same({"l1": l1_type}, valid_types, self.name)
  2101. validator.check_scalar_or_tensor_type_same({"l2": l2_type}, valid_types, self.name)
  2102. validator.check_scalar_or_tensor_type_same({"lr_power": lr_power_type}, valid_types, self.name)
  2103. return var_type
  2104. class ExtractImagePatches(PrimitiveWithInfer):
  2105. """
  2106. Extract patches from images.
  2107. The input tensor must be a 4-D tensor and the data format is NHWC.
  2108. Args:
  2109. ksizes (Union[tuple[int], list[int]]): The size of sliding window, should be a tuple or list of int,
  2110. and the format is [1, ksize_row, ksize_col, 1].
  2111. strides (Union[tuple[int], list[int]]): Distance between the centers of the two consecutive patches,
  2112. should be a tuple or list of int, and the format is [1, stride_row, stride_col, 1].
  2113. rates (Union[tuple[int], list[int]]): In each extracted patch, the gap between the corresponding dim
  2114. pixel positions, should be a tuple or list of int, and the format is [1, rate_row, rate_col, 1].
  2115. padding (str): The type of padding algorithm, is a string whose value is "same" or "valid",
  2116. not case sensitive. Default: "valid".
  2117. - same: Means that the patch can take the part beyond the original image, and this part is filled with 0.
  2118. - valid: Means that the patch area taken must be completely contained in the original image.
  2119. Inputs:
  2120. - **input_x** (Tensor) - A 4-D tensor whose shape is [in_batch, in_row, in_col, in_depth] and
  2121. data type is int8, float16, uint8.
  2122. Outputs:
  2123. Tensor, a 4-D tensor whose data type is same as 'input_x',
  2124. and the shape is [out_batch, out_row, out_col, out_depth], the out_batch is same as the in_batch.
  2125. """
  2126. @prim_attr_register
  2127. def __init__(self, ksizes, strides, rates, padding="valid"):
  2128. """init"""
  2129. def _check_tuple_or_list(arg_name, arg_val, prim_name):
  2130. validator.check_value_type(f"{arg_name}s", ksizes, [tuple, list], self.name)
  2131. if len(arg_val) != 4 or arg_val[0] != 1 or arg_val[3] != 1:
  2132. raise ValueError(f"For \'{prim_name}\' the format of {arg_name}s should be [1, {arg_name}_row, "
  2133. f"{arg_name}_col, 1], but got {arg_val}.")
  2134. if not isinstance(arg_val[1], int) or not isinstance(arg_val[2], int) or arg_val[1] < 1 or arg_val[2] < 1:
  2135. raise ValueError(f"For '{prim_name}' the {arg_name}_row and {arg_name}_col in {arg_name}s should be an "
  2136. f"positive integer number, but got {arg_name}_row is {arg_val[1]}, {arg_name}_col "
  2137. f"is {arg_val[2]}")
  2138. _check_tuple_or_list("ksize", ksizes, self.name)
  2139. _check_tuple_or_list("stride", strides, self.name)
  2140. _check_tuple_or_list("rate", rates, self.name)
  2141. self.padding = validator.check_string('padding', padding.upper(), ['VALID', 'SAME'], self.name)
  2142. self.add_prim_attr("padding", self.padding)
  2143. def infer_shape(self, input_x):
  2144. in_batch, in_row, in_col, in_depth = input_x
  2145. _, ksize_row, ksize_col, _ = self.ksizes
  2146. _, stride_row, stride_col, _ = self.strides
  2147. _, rate_row, rate_col, _ = self.rates
  2148. if len(input_x) != 4:
  2149. raise ValueError("The `input_x` should be a 4-D tensor, "
  2150. f"but got a {len(input_x)}-D tensor whose shape is {input_x}")
  2151. out_batch = in_batch
  2152. out_depth = ksize_row * ksize_col * in_depth
  2153. if self.padding == "VALID":
  2154. out_row = \
  2155. (in_row - (ksize_row + (ksize_row - 1) * (rate_row - 1))) // stride_row + 1
  2156. out_col = \
  2157. (in_col - (ksize_col + (ksize_col - 1) * (rate_col - 1))) // stride_col + 1
  2158. else:
  2159. out_row = (in_row - 1) // stride_row + 1
  2160. out_col = (in_col - 1) // stride_col + 1
  2161. out_shape = [out_batch, out_row, out_col, out_depth]
  2162. return out_shape
  2163. def infer_dtype(self, input_x):
  2164. validator.check_tensor_type_same({"input_x": input_x}, (mstype.int8, mstype.float16, mstype.float32), self.name)
  2165. return input_x
  2166. class ConfusionMulGrad(PrimitiveWithInfer):
  2167. """
  2168. `output0` is the result of which input0 dot multily input1.
  2169. `output1` is the result of which input0 dot multily input1, then reducesum it.
  2170. Args:
  2171. axis (Union[int, tuple[int], list[int]]): The dimensions to reduce.
  2172. Default:(), reduce all dimensions. Only constant value is allowed.
  2173. keep_dims (bool):
  2174. - If true, keep these reduced dimensions and the length is 1.
  2175. - If false, don't keep these dimensions. Default:False.
  2176. Inputs:
  2177. - **input_0** (Tensor) - The input Tensor.
  2178. - **input_1** (Tensor) - The input Tensor.
  2179. - **input_2** (Tensor) - The input Tensor.
  2180. outputs:
  2181. - **output_0** (Tensor) - The same shape with `input0`.
  2182. - **output_1** (Tensor)
  2183. - If axis is (), and keep_dims is false, the output is a 0-D array representing
  2184. the sum of all elements in the input array.
  2185. - If axis is int, set as 2, and keep_dims is false,
  2186. the shape of output is :math:`(x_1,x_3,...,x_R)`.
  2187. - If axis is tuple(int), set as (2,3), and keep_dims is false,
  2188. the shape of output is :math:`(x_1,x_4,...x_R)`.
  2189. """
  2190. @prim_attr_register
  2191. def __init__(self, axis=(), keep_dims=False):
  2192. self.init_prim_io_names(inputs=["input0", "input1", "input2"], outputs=["output0", "output1"])
  2193. self.axis_ = validator.check_value_type("axis", axis, [int, tuple, list], self.name)
  2194. self.keep_dims_ = validator.check_value_type("keep_dims", keep_dims, [bool], self.name)
  2195. def infer_shape(self, input0_shape, input1_shape, input2_shape):
  2196. outshape0 = input0_shape
  2197. outshape1 = _infer_shape_reduce(input1_shape, self.axis_, self.keep_dims_, self.name)
  2198. return outshape0, outshape1
  2199. def infer_dtype(self, input0_dtype, input1_dtype, input2_dtype):
  2200. validator.check_subclass("input0_dtype", input0_dtype, mstype.tensor, self.name)
  2201. validator.check_subclass("input1_dtype", input1_dtype, mstype.tensor, self.name)
  2202. validator.check_subclass("input2_dtype", input2_dtype, mstype.tensor, self.name)
  2203. return input0_dtype, input1_dtype