You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

array_ops.py 124 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
71778177917801781178217831784178517861787178817891790179117921793179417951796179717981799180018011802180318041805180618071808180918101811181218131814181518161817181818191820182118221823182418251826182718281829183018311832183318341835183618371838183918401841184218431844184518461847184818491850185118521853185418551856185718581859186018611862186318641865186618671868186918701871187218731874187518761877187818791880188118821883188418851886188718881889189018911892189318941895189618971898189919001901190219031904190519061907190819091910191119121913191419151916191719181919192019211922192319241925192619271928192919301931193219331934193519361937193819391940194119421943194419451946194719481949195019511952195319541955195619571958195919601961196219631964196519661967196819691970197119721973197419751976197719781979198019811982198319841985198619871988198919901991199219931994199519961997199819992000200120022003200420052006200720082009201020112012201320142015201620172018201920202021202220232024202520262027202820292030203120322033203420352036203720382039204020412042204320442045204620472048204920502051205220532054205520562057205820592060206120622063206420652066206720682069207020712072207320742075207620772078207920802081208220832084208520862087208820892090209120922093209420952096209720982099210021012102210321042105210621072108210921102111211221132114211521162117211821192120212121222123212421252126212721282129213021312132213321342135213621372138213921402141214221432144214521462147214821492150215121522153215421552156215721582159216021612162216321642165216621672168216921702171217221732174217521762177217821792180218121822183218421852186218721882189219021912192219321942195219621972198219922002201220222032204220522062207220822092210221122122213221422152216221722182219222022212222222322242225222622272228222922302231223222332234223522362237223822392240224122422243224422452246224722482249225022512252225322542255225622572258225922602261226222632264226522662267226822692270227122722273227422752276227
72278227922802281228222832284228522862287228822892290229122922293229422952296229722982299230023012302230323042305230623072308230923102311231223132314231523162317231823192320232123222323232423252326232723282329233023312332233323342335233623372338233923402341234223432344234523462347234823492350235123522353235423552356235723582359236023612362236323642365236623672368236923702371237223732374237523762377237823792380238123822383238423852386238723882389239023912392239323942395239623972398239924002401240224032404240524062407240824092410241124122413241424152416241724182419242024212422242324242425242624272428242924302431243224332434243524362437243824392440244124422443244424452446244724482449245024512452245324542455245624572458245924602461246224632464246524662467246824692470247124722473247424752476247724782479248024812482248324842485248624872488248924902491249224932494249524962497249824992500250125022503250425052506250725082509251025112512251325142515251625172518251925202521252225232524252525262527252825292530253125322533253425352536253725382539254025412542254325442545254625472548254925502551255225532554255525562557255825592560256125622563256425652566256725682569257025712572257325742575257625772578257925802581258225832584258525862587258825892590259125922593259425952596259725982599260026012602260326042605260626072608260926102611261226132614261526162617261826192620262126222623262426252626262726282629263026312632263326342635263626372638263926402641264226432644264526462647264826492650265126522653265426552656265726582659266026612662266326642665266626672668266926702671267226732674267526762677267826792680268126822683268426852686268726882689269026912692269326942695269626972698269927002701270227032704270527062707270827092710271127122713271427152716271727182719272027212722272327242725272627272728272927302731273227332734273527362737273827392740274127422743274427452746274727482749275027512752275327542755275627572758275927602761276227632764276527662767276827692770277127722773277427752776277
7277827792780278127822783278427852786278727882789279027912792279327942795279627972798279928002801280228032804280528062807280828092810281128122813281428152816281728182819282028212822282328242825282628272828282928302831283228332834283528362837283828392840284128422843284428452846284728482849285028512852285328542855285628572858285928602861286228632864286528662867286828692870287128722873287428752876287728782879288028812882288328842885288628872888288928902891289228932894289528962897289828992900290129022903290429052906290729082909291029112912291329142915291629172918291929202921292229232924292529262927292829292930293129322933293429352936293729382939294029412942294329442945294629472948294929502951295229532954295529562957295829592960296129622963296429652966296729682969297029712972297329742975297629772978297929802981298229832984298529862987298829892990299129922993299429952996299729982999300030013002300330043005300630073008300930103011301230133014301530163017301830193020302130223023302430253026302730283029303030313032303330343035303630373038303930403041304230433044304530463047304830493050305130523053305430553056305730583059306030613062306330643065306630673068306930703071307230733074307530763077307830793080308130823083308430853086308730883089309030913092309330943095
  1. # coding: utf-8
  2. # Copyright 2020 Huawei Technologies Co., Ltd
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. # ============================================================================
  16. """Operators for array."""
  17. import copy
  18. import functools
  19. import itertools
  20. import numbers
  21. import numpy as np
  22. from ..._checkparam import Validator as validator
  23. from ..._checkparam import Rel
  24. from ...common import dtype as mstype
  25. from ...common.tensor import Tensor
  26. from ...common.parameter import Parameter
  27. from ..operations.math_ops import _infer_shape_reduce
  28. from .._utils import get_concat_offset
  29. from ..primitive import Primitive, PrimitiveWithInfer, prim_attr_register, _run_op
  30. from ..._c_expression import signature_rw as sig_rw
  31. from ..._c_expression import signature_kind as sig_kind
  32. from ..._c_expression import signature_dtype as sig_dtype
  33. from ..._c_expression import typing
  34. def _check_infer_attr_reduce(axis, keep_dims, prim_name):
  35. validator.check_value_type('keep_dims', keep_dims, [bool], prim_name)
  36. validator.check_value_type('axis', axis, [int, tuple], prim_name)
  37. if isinstance(axis, tuple):
  38. for index, value in enumerate(axis):
  39. validator.check_value_type('axis[%d]' % index, value, [int], prim_name)
  40. class ExpandDims(PrimitiveWithInfer):
  41. """
  42. Adds an additional dimension at the given axis.
  43. Note:
  44. If the specified axis is a negative number, the index is counted
  45. backward from the end and starts at 1.
  46. Raises:
  47. ValueError: If axis is not an integer or not in the valid range.
  48. Inputs:
  49. - **input_x** (Tensor) - The shape of tensor is :math:`(x_1, x_2, ..., x_R)`.
  50. - **axis** (int) - Specifies the dimension index at which to expand
  51. the shape of `input_x`. The value of axis must be in the range
  52. `[-input_x.dim()-1, input_x.dim()]`. Only constant value is allowed.
  53. Outputs:
  54. Tensor, the shape of tensor is :math:`(1, x_1, x_2, ..., x_R)` if the
  55. value of `axis` is 0.
  56. Examples:
  57. >>> input_tensor = Tensor(np.array([[2, 2], [2, 2]]), mindspore.float32)
  58. >>> expand_dims = P.ExpandDims()
  59. >>> output = expand_dims(input_tensor, 0)
  60. """
  61. @prim_attr_register
  62. def __init__(self):
  63. """init ExpandDims"""
  64. self.init_prim_io_names(inputs=['x', 'axis'], outputs=['output'])
  65. def __infer__(self, x, axis):
  66. validator.check_subclass("input_x", x['dtype'], mstype.tensor, self.name)
  67. x_shape = list(x['shape'])
  68. axis_v = axis['value']
  69. rank = len(x_shape)
  70. validator.check_int_range('axis', axis_v, -rank - 1, rank, Rel.INC_BOTH, self.name)
  71. value = None
  72. if x['value'] is not None:
  73. value = x['value'].asnumpy()
  74. value = np.expand_dims(value, axis_v)
  75. value = Tensor(value)
  76. if axis_v < 0:
  77. axis_v = rank + 1 + axis_v
  78. x_shape.insert(axis_v, 1)
  79. out = {'shape': x_shape,
  80. 'dtype': x['dtype'],
  81. 'value': value}
  82. return out
  83. class DType(PrimitiveWithInfer):
  84. """
  85. Returns the data type of input tensor as mindspore.dtype.
  86. Inputs:
  87. - **input_x** (Tensor) - The shape of tensor is :math:`(x_1, x_2, ..., x_R)`.
  88. Outputs:
  89. mindspore.dtype, the data type of a tensor.
  90. Examples:
  91. >>> input_tensor = Tensor(np.array([[2, 2], [2, 2]]), mindspore.float32)
  92. >>> type = P.DType()(input_tensor)
  93. """
  94. @prim_attr_register
  95. def __init__(self):
  96. """init DType"""
  97. def __infer__(self, x):
  98. validator.check_subclass("input_x", x['dtype'], mstype.tensor, self.name)
  99. out = {'shape': (),
  100. 'dtype': mstype.type_type,
  101. 'value': x['dtype'].element_type()}
  102. return out
class SameTypeShape(PrimitiveWithInfer):
    """
    Checks whether data type and shape of two tensors are the same.

    Raises:
        TypeError: If data type not the same.
        ValueError: If shape of two tensors not the same.

    Inputs:
        - **input_x** (Tensor) - The shape of tensor is :math:`(x_1, x_2, ..., x_R)`.
        - **input_y** (Tensor) - The shape of tensor is :math:`(x_1, x_2, ..., x_S)`.

    Outputs:
        Tensor, the shape of tensor is :math:`(x_1, x_2, ..., x_R)`,
        if data type and shape of `input_x` and `input_y` are the same.

    Examples:
        >>> input_x = Tensor(np.array([[2, 2], [2, 2]]), mindspore.float32)
        >>> input_y = Tensor(np.array([[2, 2], [2, 2]]), mindspore.float32)
        >>> out = P.SameTypeShape()(input_x, input_y)
    """

    @prim_attr_register
    def __init__(self):
        """init Same"""

    def __call__(self, x, y):
        """run in PyNative mode"""
        # Order matters: dtype is compared before shape so a dtype mismatch
        # raises TypeError (explicit last arg) rather than the default
        # ValueError used by the shape check.
        validator.check_value_type('x', x, Tensor, self.name)
        validator.check_value_type('y', y, Tensor, self.name)
        validator.check('x dtype', x.dtype, 'y dtype', y.dtype, Rel.EQ, self.name, TypeError)
        validator.check('x shape', x.shape, 'y shape', y.shape, Rel.EQ, self.name)
        return x

    def __infer__(self, x, y):
        # Graph-mode inference: the same checks as __call__, applied to the
        # abstract-value dicts; on success x's abstract value passes through.
        validator.check_subclass('x', x['dtype'], mstype.tensor, self.name)
        validator.check_subclass('y', y['dtype'], mstype.tensor, self.name)
        validator.check('x dtype', x['dtype'], 'y dtype', y['dtype'], Rel.EQ, self.name, TypeError)
        validator.check('x shape', x['shape'], 'y shape', y['shape'], Rel.EQ, self.name)
        return x
  136. class Cast(PrimitiveWithInfer):
  137. """
  138. Returns a tensor with the new specified data type.
  139. Inputs:
  140. - **input_x** (Union[Tensor, Number]) - The shape of tensor is :math:`(x_1, x_2, ..., x_R)`.
  141. The tensor to be casted.
  142. - **type** (dtype.Number) - The valid data type of the output tensor. Only constant value is allowed.
  143. Outputs:
  144. Tensor, the shape of tensor is :math:`(x_1, x_2, ..., x_R)`, same as `input_x`.
  145. Examples:
  146. >>> input_np = np.random.randn(2, 3, 4, 5).astype(np.float32)
  147. >>> input_x = Tensor(input_np)
  148. >>> type_dst = mindspore.float16
  149. >>> cast = P.Cast()
  150. >>> result = cast(input_x, type_dst)
  151. """
  152. @prim_attr_register
  153. def __init__(self):
  154. # if primitive need setattr in __infer__ need add this flag
  155. """init Cast"""
  156. self.init_prim_io_names(inputs=['x', 'dst_type'], outputs=['output'])
  157. def check_elim(self, x, dtype):
  158. if isinstance(x, (Tensor, numbers.Number, Parameter)):
  159. if isinstance(x, Tensor) and x.dtype == dtype:
  160. return (True, x)
  161. if isinstance(x, numbers.Number):
  162. return (True, Tensor(x, dtype=dtype))
  163. if isinstance(x, Parameter):
  164. data = x.default_input
  165. if data.dtype == dtype:
  166. return (True, x)
  167. return (False, None)
  168. def __infer__(self, x, t):
  169. src_type = x['dtype']
  170. dst_type = t['value']
  171. validator.check_subclass("input_x", src_type, [mstype.tensor, mstype.number], self.name)
  172. validator.check_subclass("type", dst_type, mstype.number, self.name)
  173. if isinstance(src_type, type(mstype.tensor)):
  174. src_type = x['dtype'].element_type()
  175. if isinstance(dst_type, type(mstype.tensor)):
  176. dst_type = dst_type.element_type()
  177. self.add_prim_attr('DstT', dst_type)
  178. self.add_prim_attr('SrcT', src_type)
  179. value = None
  180. if x['value'] is not None:
  181. np_dst_type = mstype.dtype_to_nptype(dst_type)
  182. if isinstance(x['value'], (int, float)):
  183. value = Tensor(np.array(x['value']).astype(np_dst_type))
  184. else:
  185. value = Tensor(x['value'].asnumpy().astype(np_dst_type))
  186. out = {'shape': x['shape'],
  187. 'dtype': mstype.tensor_type(t['value']),
  188. 'value': value}
  189. return out
  190. class IsSubClass(PrimitiveWithInfer):
  191. """
  192. Check whether one type is sub class of another type.
  193. Inputs:
  194. - **sub_type** (mindspore.dtype) - The type to be check. Only constant value is allowed.
  195. - **type_** (mindspore.dtype) - The target type. Only constant value is allowed.
  196. Outputs:
  197. bool, the check result.
  198. Examples:
  199. >>> result = P.IsSubClass()(mindspore.int32, mindspore.intc)
  200. """
  201. @prim_attr_register
  202. def __init__(self):
  203. pass
  204. def __infer__(self, sub_type, type_):
  205. sub_type_t = sub_type['value']
  206. type_v = type_['value']
  207. validator.check_value_type("sub_type", sub_type_t, [mstype.Type], self.name)
  208. validator.check_value_type("type_", type_v, [mstype.Type], self.name)
  209. value = mstype.issubclass_(sub_type_t, type_v)
  210. out = {'shape': (),
  211. 'dtype': mstype.type_type,
  212. 'value': value}
  213. return out
  214. class IsInstance(PrimitiveWithInfer):
  215. """
  216. Check whether an object is an instance of a target type.
  217. Inputs:
  218. - **inst** (Any Object) - The instance to be check. Only constant value is allowed.
  219. - **type_** (mindspore.dtype) - The target type. Only constant value is allowed.
  220. Outputs:
  221. bool, the check result.
  222. Examples:
  223. >>> a = 1
  224. >>> result = P.IsInstance()(a, mindspore.int32)
  225. """
  226. @prim_attr_register
  227. def __init__(self):
  228. pass
  229. def __infer__(self, inst, type_):
  230. sub_type_t = inst['dtype']
  231. type_v = type_['value']
  232. validator.check_const_input("inst", inst['value'], self.name)
  233. validator.check_value_type("type_", type_v, [mstype.Type], self.name)
  234. value = mstype.issubclass_(sub_type_t, type_v)
  235. out = {'shape': (),
  236. 'dtype': mstype.type_type,
  237. 'value': value}
  238. return out
  239. class Reshape(PrimitiveWithInfer):
  240. """
  241. Reshapes input tensor with the same values based on a given shape tuple.
  242. Raises:
  243. ValueError: Given a shape tuple, if it has more than one -1; or if the product
  244. of its elements is less than or equal to 0 or cannot be divided by the product
  245. of the input tensor shape; or if it does not match the input's array size.
  246. Inputs:
  247. - **input_x** (Tensor) - The shape of tensor is :math:`(x_1, x_2, ..., x_R)`.
  248. - **input_shape** (tuple[int]) - The input tuple is constructed by multiple
  249. integers, i.e., :math:`(y_1, y_2, ..., y_S)`. Only constant value is allowed.
  250. Outputs:
  251. Tensor, the shape of tensor is :math:`(y_1, y_2, ..., y_S)`.
  252. Examples:
  253. >>> input_tensor = Tensor(np.array([[-0.1, 0.3, 3.6], [0.4, 0.5, -3.2]]), mindspore.float32)
  254. >>> reshape = P.Reshape()
  255. >>> output = reshape(input_tensor, (3, 2))
  256. """
  257. @prim_attr_register
  258. def __init__(self):
  259. """init Reshape"""
  260. self.init_prim_io_names(inputs=['tensor', 'shape'], outputs=['output'])
  261. def __infer__(self, x, shape):
  262. shape_v = shape['value']
  263. x_shp = x['shape']
  264. validator.check_subclass("x", x['dtype'], mstype.tensor, self.name)
  265. validator.check_value_type("shape", shape_v, [tuple], self.name)
  266. shape_v = list(shape_v)
  267. neg_index = -1
  268. dim_prod = 1
  269. for i, shp_i in enumerate(shape_v):
  270. validator.check_value_type("shape[%d]" % i, shp_i, [int], self.name)
  271. if shp_i == -1:
  272. if neg_index != -1:
  273. raise ValueError(f'The shape can only has one -1 at most, but {shape_v}.')
  274. neg_index = i
  275. else:
  276. dim_prod *= shp_i
  277. arr_prod = np.prod(x_shp)
  278. if dim_prod <= 0 or arr_prod % dim_prod != 0:
  279. raise ValueError(f'For \'{self.name}\' the product of shape should > 0 and'
  280. f' can be divided by prod of input {arr_prod},'
  281. f' but shape {shape}, product of shape {dim_prod}.')
  282. if neg_index != -1:
  283. shape_v[neg_index] = int(arr_prod / dim_prod)
  284. dim_prod *= shape_v[neg_index]
  285. if dim_prod != arr_prod:
  286. raise ValueError(f'For \'{self.name}\' The shape arg for reshape must match array''s size'
  287. f' input shape {arr_prod}, shape {dim_prod}.')
  288. value = None
  289. if x['value'] is not None:
  290. value = Tensor(x['value'].asnumpy().reshape(shape_v))
  291. out = {'shape': tuple(shape_v),
  292. 'dtype': x['dtype'],
  293. 'value': value}
  294. return out
  295. class Shape(Primitive):
  296. """
  297. Returns the shape of input tensor.
  298. Inputs:
  299. - **input_x** (Tensor) - The shape of tensor is :math:`(x_1, x_2, ..., x_R)`.
  300. Outputs:
  301. tuple[int], the output tuple is constructed by multiple integers,
  302. :math:`(x_1, x_2, ..., x_R)`.
  303. Examples:
  304. >>> input_tensor = Tensor(np.ones(shape=[3, 2, 1]), mindspore.float32)
  305. >>> shape = P.Shape()
  306. >>> output = shape(input_tensor)
  307. """
  308. @prim_attr_register
  309. def __init__(self):
  310. """init Shape"""
class Squeeze(PrimitiveWithInfer):
    """
    Returns a tensor with the same type but dimensions of 1 being removed based on axis.

    Note:
        The dimension index starts at 0 and must be in the range `[-input.dim(), input.dim())`.

    Raises:
        ValueError: If the corresponding dimension of the specified axis does not equal to 1.

    Args:
        axis (int): Specifies the dimension indexes of shape to be removed, which will remove
            all the dimensions that are equal to 1. If specified, it must be int32 or int64.
            Default: (), an empty tuple.

    Inputs:
        - **input_x** (Tensor) - The shape of tensor is :math:`(x_1, x_2, ..., x_R)`.

    Outputs:
        Tensor, the shape of tensor is :math:`(x_1, x_2, ..., x_S)`.

    Examples:
        >>> input_tensor = Tensor(np.ones(shape=[3, 2, 1]), mindspore.float32)
        >>> squeeze = P.Squeeze(2)
        >>> output = squeeze(input_tensor)
    """

    @prim_attr_register
    def __init__(self, axis=()):
        """init Squeeze"""
        self.init_prim_io_names(inputs=['x'], outputs=['output'])
        validator.check_value_type('axis', axis, [int, tuple], self.name)
        if isinstance(axis, tuple):
            # Tuple case: only validated here. NOTE(review): `self.axis` is not
            # assigned in this branch — presumably @prim_attr_register stores the
            # constructor arg as an attribute; confirm against the decorator.
            for idx, item in enumerate(axis):
                validator.check_value_type("axis[%d]" % idx, item, [int], self.name)
        else:
            # A bare int is normalized to a one-element tuple and re-registered.
            self.axis = (axis,)
            self.add_prim_attr("axis", (axis,))

    def infer_shape(self, x_shape):
        # Empty axis tuple means "drop every size-1 dimension"; otherwise each
        # requested axis must have size exactly 1 or a ValueError is raised.
        axis = self.axis
        x_shape = list(x_shape)
        ndim = len(x_shape)
        if not axis:
            ret = [d for d in x_shape if d != 1]
        else:
            for a in axis:
                validator.check_int_range('axis or its elements', a, -ndim, ndim - 1, Rel.INC_BOTH, self.name)
                if x_shape[a] != 1:
                    raise ValueError('Cannot select an axis to squeeze out which has size not equal to one.')
            # Keep dimension i unless it is named by a positive or negative axis.
            ret = [x_shape[i] for i in range(ndim) if not (i in axis or (i - ndim) in axis)]
        return ret

    def infer_dtype(self, x_dtype):
        # Squeeze never changes the element type.
        validator.check_subclass("x", x_dtype, mstype.tensor, self.name)
        return x_dtype
  358. class Transpose(PrimitiveWithInfer):
  359. """
  360. Permutes the dimensions of input tensor according to input perm.
  361. Inputs:
  362. - **input_x** (Tensor) - The shape of tensor is :math:`(x_1, x_2, ..., x_R)`.
  363. - **input_perm** (tuple[int]) - The permutation to be converted. The input tuple is constructed by multiple
  364. indexes. The length of `input_perm` and the shape of `input_x` should be the same. Only constant value is
  365. allowed.
  366. Outputs:
  367. Tensor, the type of output tensor is same as `input_x` and the shape of output tensor is decided by the
  368. shape of `input_x` and the value of `input_perm`.
  369. Examples:
  370. >>> input_tensor = Tensor(np.array([[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]]), mindspore.float32)
  371. >>> perm = (0, 2, 1)
  372. >>> transpose = P.Transpose()
  373. >>> output = transpose(input_tensor, perm)
  374. """
  375. @prim_attr_register
  376. def __init__(self):
  377. """init Transpose"""
  378. self.init_prim_io_names(inputs=['x', 'perm'], outputs=['output'])
  379. def __infer__(self, x, perm):
  380. x_shape = x['shape']
  381. p_value = perm['value']
  382. x_type = x['dtype']
  383. validator.check_value_type("p_value", p_value, [tuple], self.name)
  384. validator.check_subclass("x_type", x_type, mstype.tensor, self.name)
  385. if len(x_shape) != len(p_value):
  386. raise ValueError('The dimension of x and perm must be equal.')
  387. tmp = list(p_value)
  388. for i, dim in enumerate(p_value):
  389. validator.check_integer("perm[%d]" % i, dim, 0, Rel.GE, self.name)
  390. validator.check_integer("perm[%d]" % i, dim, len(p_value), Rel.LT, self.name)
  391. tmp.remove(dim)
  392. if dim in tmp:
  393. raise ValueError('The value of perm is wrong.')
  394. out_shapes = []
  395. for i in p_value:
  396. out_shapes.append(x_shape[i])
  397. out = {'shape': tuple(out_shapes),
  398. 'dtype': x['dtype'],
  399. 'value': None}
  400. return out
  401. class GatherV2(PrimitiveWithInfer):
  402. """
  403. Returns a slice of input tensor based on the specified indices and axis.
  404. Inputs:
  405. - **input_params** (Tensor) - The shape of tensor is :math:`(x_1, x_2, ..., x_R)`.
  406. The original Tensor.
  407. - **input_indices** (Tensor) - The shape of tensor is :math:`(y_1, y_2, ..., y_S)`.
  408. Specifies the indices of elements of the original Tensor. Must be in the range
  409. `[0, input_param.shape[axis])`.
  410. - **axis** (int) - Specifies the dimension index to gather indices.
  411. Outputs:
  412. Tensor, the shape of tensor is :math:`(z_1, z_2, ..., z_N)`.
  413. Examples:
  414. >>> input_params = Tensor(np.array([[1, 2, 7, 42], [3, 4, 54, 22], [2, 2, 55, 3]]), mindspore.float32)
  415. >>> input_indices = Tensor(np.array([1, 2]), mindspore.int32)
  416. >>> axis = 1
  417. >>> out = P.GatherV2()(input_params, input_indices, axis)
  418. """
  419. @prim_attr_register
  420. def __init__(self):
  421. """init index_select"""
  422. self.init_prim_io_names(inputs=['params', 'indices', 'axis'], outputs=['output'])
  423. def __infer__(self, params, indices, axis):
  424. validator.check_subclass("params", params['dtype'], mstype.tensor, self.name)
  425. validator.check_tensor_type_same({"indices": indices['dtype']}, mstype.int_type, self.name)
  426. validator.check_subclass("axis", axis['dtype'], mstype.int_, self.name)
  427. axis_v = axis['value']
  428. params_shp = params['shape']
  429. rank = len(params_shp)
  430. validator.check_int_range("axis", axis_v, -rank, rank, Rel.INC_LEFT, self.name)
  431. if axis_v < 0:
  432. axis_v += rank
  433. out_shape = params_shp[:axis_v] + indices['shape'] + params_shp[axis_v + 1:]
  434. out = {'shape': out_shape,
  435. 'dtype': params['dtype'],
  436. 'value': None}
  437. return out
  438. class SparseGatherV2(GatherV2):
  439. """
  440. Returns a slice of input tensor based on the specified indices and axis.
  441. Inputs:
  442. - **input_params** (Tensor) - The shape of tensor is :math:`(x_1, x_2, ..., x_R)`.
  443. The original Tensor.
  444. - **input_indices** (Tensor) - The shape of tensor is :math:`(y_1, y_2, ..., y_S)`.
  445. Specifies the indices of elements of the original Tensor. Must be in the range
  446. `[0, input_param.shape[axis])`.
  447. - **axis** (int) - Specifies the dimension index to gather indices.
  448. Outputs:
  449. Tensor, the shape of tensor is :math:`(z_1, z_2, ..., z_N)`.
  450. Examples:
  451. >>> input_params = Tensor(np.array([[1, 2, 7, 42], [3, 4, 54, 22], [2, 2, 55, 3]]), mindspore.float32)
  452. >>> input_indices = Tensor(np.array([1, 2]), mindspore.int32)
  453. >>> axis = 1
  454. >>> out = P.SparseGatherV2()(input_params, input_indices, axis)
  455. """
class Split(PrimitiveWithInfer):
    """
    Splits input tensor into output_num of tensors along the given axis and output numbers.

    Args:
        axis (int): Index of the split position. Default: 0.
        output_num (int): The number of output tensors. Default: 1.

    Raises:
        ValueError: If axis is out of the range [-len(input_x.shape), len(input_x.shape)),
            or if the output_num is less than or equal to 0, or if the
            dimension which to split cannot be evenly divided by output_num.

    Inputs:
        - **input_x** (Tensor) - The shape of tensor is :math:`(x_1, x_2, ..., x_R)`.

    Outputs:
        tuple[Tensor], the shape of each output tensor is same, which is
        :math:`(y_1, y_2, ..., y_S)`.

    Examples:
        >>> split = P.Split(1, 2)
        >>> x = Tensor(np.array([[1, 1, 1, 1], [2, 2, 2, 2]]))
        >>> output = split(x)
    """

    @prim_attr_register
    def __init__(self, axis=0, output_num=1):
        """init Split"""
        validator.check_value_type("axis", axis, [int], self.name)
        validator.check_value_type("output_num", output_num, [int], self.name)
        self.axis = axis
        self.output_num = output_num

    def __infer__(self, x):
        """Infer one (shape, dtype) per output; the split axis is divided evenly by output_num."""
        validator.check_subclass("x", x['dtype'], mstype.tensor, self.name)
        x_shape = list(x['shape'])
        dim = len(x_shape)
        validator.check_int_range('axis value', self.axis, -dim, dim, Rel.INC_LEFT, self.name)
        validator.check_integer("output_num", self.output_num, 0, Rel.GT, self.name)
        # The dimension being split must be exactly divisible by output_num.
        output_valid_check = x_shape[self.axis] % self.output_num
        validator.check_integer("the dimension which to split divides output_num", output_valid_check, 0, Rel.EQ,
                                self.name)
        x_shape[self.axis] = int(x_shape[self.axis] / self.output_num)
        out_shapes = []
        out_dtypes = []
        # Every output shares the same reduced shape and the input dtype.
        for _ in range(self.output_num):
            out_shapes.append(tuple(x_shape))
            out_dtypes.append(x['dtype'])
        out_shapes = tuple(out_shapes)
        out_dtypes = tuple(out_dtypes)
        out = {'shape': out_shapes,
               'dtype': out_dtypes,
               'value': None}
        return out
  504. class Rank(PrimitiveWithInfer):
  505. """
  506. Returns the rank of a tensor.
  507. Returns a 0-D int32 Tensor representing the rank of input; the rank of a tensor
  508. is the number of indices required to uniquely select each element of the tensor.
  509. Inputs:
  510. - **input_x** (Tensor) - The shape of tensor is :math:`(x_1, x_2, ..., x_R)`.
  511. Outputs:
  512. Tensor. 0-D int32 Tensor representing the rank of input, i.e., :math:`R`.
  513. Examples:
  514. >>> input_tensor = Tensor(np.array([[2, 2], [2, 2]]), mindspore.float32)
  515. >>> rank = P.Rank()
  516. >>> rank(input_tensor)
  517. """
  518. @prim_attr_register
  519. def __init__(self):
  520. """init Rank"""
  521. def __infer__(self, x):
  522. validator.check_subclass("x", x['dtype'], mstype.tensor, self.name)
  523. out = {'shape': None,
  524. 'dtype': None,
  525. 'value': len(x['shape'])}
  526. return out
class TruncatedNormal(PrimitiveWithInfer):
    """
    Returns a tensor of the specified shape filled with truncated normal values.

    The generated values follow a normal distribution.

    Args:
        seed (int): An int number used to create random seed. Default: 0.
        dtype (:class:`mindspore.dtype`): Data type. Default: mindspore.float32.

    Inputs:
        - **shape** (tuple[int]) - Shape of output tensor, is a tuple of positive int.

    Outputs:
        Tensor, type of output tensor is same as attribute `dtype`.

    Examples:
        >>> shape = (1, 2, 3)
        >>> truncated_normal = P.TruncatedNormal()
        >>> output = truncated_normal(shape)
    """

    @prim_attr_register
    def __init__(self, seed=0, dtype=mstype.float32):
        """init TruncatedNormal"""
        validator.check_value_type('seed', seed, [int], self.name)
        validator.check_type_same({'dtype': dtype}, mstype.number_type, self.name)

    def __infer__(self, shape):
        """Infer: the output shape is the constant `shape` input; each entry must be a positive int."""
        shape_value = shape['value']
        validator.check_value_type("shape", shape_value, [tuple], self.name)
        for i, value in enumerate(shape_value):
            validator.check_integer(f'{i}th value of shape', value, 0, Rel.GT, self.name)
        # self.dtype is presumably stored as a prim attr by @prim_attr_register -- TODO confirm.
        out = {'shape': shape_value,
               'dtype': mstype.tensor_type(self.dtype),
               'value': None}
        return out
  557. class Size(PrimitiveWithInfer):
  558. r"""
  559. Returns the elements count size of a tensor.
  560. Returns an int scalar representing the elements size of input, the total number of elements in the tensor.
  561. Inputs:
  562. - **input_x** (Tensor) - The shape of tensor is :math:`(x_1, x_2, ..., x_R)`.
  563. Outputs:
  564. int, a scalar representing the elements size of `input_x`, tensor is the number of elements
  565. in a tensor, :math:`size=x_1*x_2*...x_R`.
  566. Examples:
  567. >>> input_tensor = Tensor(np.array([[2, 2], [2, 2]]), mindspore.float32)
  568. >>> size = P.Size()
  569. >>> output = size(input_tensor)
  570. """
  571. @prim_attr_register
  572. def __init__(self):
  573. """init Size"""
  574. def __infer__(self, x):
  575. size = 1
  576. validator.check_subclass("x", x['dtype'], mstype.tensor, self.name)
  577. shp = x['shape']
  578. if not shp:
  579. size = 0
  580. else:
  581. size = functools.reduce(lambda x, y: x * y, x['shape'])
  582. out = {'shape': None,
  583. 'dtype': mstype.int32,
  584. 'value': size}
  585. return out
class Fill(PrimitiveWithInfer):
    """
    Creates a tensor filled with a scalar value.

    Creates a tensor with shape described by the first argument and fills it with values in the second argument.

    Inputs:
        - **type** (mindspore.dtype) - The specified type of output tensor. Only constant value is allowed.
        - **shape** (tuple) - The specified shape of output tensor. Only constant value is allowed.
        - **value** (scalar) - Value to fill the returned tensor. Only constant value is allowed.

    Outputs:
        Tensor, has the same type and shape as input value.

    Examples:
        >>> fill = P.Fill()
        >>> fill(mindspore.float32, (2, 2), 1)
    """

    @prim_attr_register
    def __init__(self):
        """init Fill"""

    def __infer__(self, dtype, dims, x):
        """All inputs are constants, so the result tensor is built eagerly at infer time."""
        validator.check_value_type("shape", dims['value'], [tuple], self.name)
        validator.check_value_type("value", x['value'], [numbers.Number, bool], self.name)
        for idx, item in enumerate(dims['value']):
            # Every dimension must be a positive integer.
            validator.check_integer("dims[%d]" % idx, item, 0, Rel.GT, self.name)
        # NOTE(review): int16/uint16 are absent from this whitelist -- confirm intended.
        valid_types = [mstype.bool_, mstype.int8, mstype.int32, mstype.int64,
                       mstype.uint8, mstype.uint32, mstype.uint64,
                       mstype.float16, mstype.float32, mstype.float64]
        validator.check_type_same({"value": dtype['value']}, valid_types, self.name)
        x_nptype = mstype.dtype_to_nptype(dtype['value'])
        ret = np.full(dims['value'], x['value'], x_nptype)
        # NOTE(review): 'dtype' echoes the fill value's dtype while 'value' is built with the
        # requested dtype -- confirm this asymmetry is intended.
        out = {
            'value': Tensor(ret),
            'shape': dims['value'],
            'dtype': x['dtype'],
        }
        return out
  620. class OnesLike(PrimitiveWithInfer):
  621. """
  622. Creates a new tensor. All elements' value are 1.
  623. Returns a tensor of ones with the same shape and type as the input.
  624. Inputs:
  625. - **input_x** (Tensor) - Input tensor.
  626. Outputs:
  627. Tensor, has the same shape and type as `input_x` but filled with ones.
  628. Examples:
  629. >>> oneslike = P.OnesLike()
  630. >>> x = Tensor(np.array([[0, 1], [2, 1]]).astype(np.int32))
  631. >>> output = oneslike(x)
  632. """
  633. @prim_attr_register
  634. def __init__(self):
  635. """Init OnesLike"""
  636. def infer_shape(self, x_shape):
  637. return x_shape
  638. def infer_dtype(self, x_dtype):
  639. validator.check_tensor_type_same({'x': x_dtype}, mstype.number_type + (mstype.bool_,), self.name)
  640. return x_dtype
  641. class ZerosLike(PrimitiveWithInfer):
  642. """
  643. Creates a new tensor. All elements value are 0.
  644. Returns a tensor of zeros with the same shape and type as the input tensor.
  645. Inputs:
  646. - **input_x** (Tensor) - Input tensor.
  647. Outputs:
  648. Tensor, has the same shape and type as `input_x` but filled with zeros.
  649. Examples:
  650. >>> zeroslike = P.ZerosLike()
  651. >>> x = Tensor(np.array([[0, 1], [2, 1]]).astype(np.float32))
  652. >>> output = zeroslike(x)
  653. """
  654. @prim_attr_register
  655. def __init__(self):
  656. """Init ZerosLike"""
  657. self.init_prim_io_names(inputs=['x'], outputs=['y'])
  658. def infer_shape(self, x_shape):
  659. return x_shape
  660. def infer_dtype(self, x_dtype):
  661. validator.check_tensor_type_same({'x': x_dtype}, mstype.number_type + (mstype.bool_,), self.name)
  662. return x_dtype
  663. class TupleToArray(PrimitiveWithInfer):
  664. """
  665. Converts a tuple to tensor.
  666. If the first number type of tuple is int, the output tensor type is int. Else, the output tensor type is float.
  667. Inputs:
  668. - **input_x** (tuple) - A tuple of numbers. These numbers have the same type. Only constant value is allowed.
  669. Outputs:
  670. Tensor, if the input tuple contain `N` numbers, then the output tensor shape is (N,).
  671. Examples:
  672. >>> type = P.TupleToArray()((1,2,3))
  673. """
  674. @prim_attr_register
  675. def __init__(self):
  676. """init TupleToArray"""
  677. def infer_value(self, x):
  678. validator.check_value_type("x", x, [tuple], self.name)
  679. validator.check("size of x", len(x), '', 0, Rel.GT, self.name)
  680. dtype = type(x[0])
  681. for i, item in enumerate(x):
  682. validator.check_value_type(f"x[{i}]", item, [numbers.Number], self.name)
  683. if not all(isinstance(item, dtype) for item in x):
  684. raise TypeError("For \'{self.name}\' all elements of input x must be have same type.")
  685. if isinstance(x[0], int):
  686. ret = np.array(x, np.int32)
  687. else:
  688. ret = np.array(x, np.float32)
  689. return Tensor(ret)
  690. def __call__(self, x):
  691. args = list()
  692. if isinstance(x, range):
  693. args.append(tuple(x))
  694. else:
  695. args.append(x)
  696. return _run_op(self, self.name, args)
  697. class ScalarToArray(PrimitiveWithInfer):
  698. """
  699. Converts scalar to `Tensor`.
  700. Inputs:
  701. - **input_x** (Union[int, float]) - The input is a scalar. Only constant value is allowed.
  702. Outputs:
  703. Tensor. 0-D Tensor and the content is the input.
  704. Examples:
  705. >>> op = P.ScalarToArray()
  706. >>> data = 1.0
  707. >>> output = op(data)
  708. """
  709. @prim_attr_register
  710. def __init__(self):
  711. pass
  712. def infer_value(self, x):
  713. validator.check_value_type("x", x, [int, float], self.name)
  714. if isinstance(x, int):
  715. ret = np.array(x, np.int32)
  716. else:
  717. ret = np.array(x, np.float32)
  718. return Tensor(ret)
  719. class ScalarToTensor(PrimitiveWithInfer):
  720. """
  721. Converts scalar to `Tensor`, and convert data type to specified type.
  722. Inputs:
  723. - **input_x** (Union[int, float]) - The input is a scalar. Only constant value is allowed.
  724. - **dtype** (mindspore.dtype) - The target data type. Default: mindspore.float32. Only
  725. constant value is allowed.
  726. Outputs:
  727. Tensor. 0-D Tensor and the content is the input.
  728. Examples:
  729. >>> op = P.ScalarToTensor()
  730. >>> data = 1
  731. >>> output = op(data, mindspore.float32)
  732. """
  733. @prim_attr_register
  734. def __init__(self):
  735. pass
  736. def infer_value(self, x, dtype=mstype.float32):
  737. validator.check_value_type("x", x, [int, float], self.name)
  738. validator.check_subclass("dtype", dtype, mstype.number, self.name)
  739. data_type = mstype.dtype_to_nptype(dtype)
  740. return Tensor(np.array(x, data_type))
class InvertPermutation(PrimitiveWithInfer):
    r"""
    Computes the inverse of an index permutation.

    This operation calculates the inverse of the index replacement. It requires a
    1-dimensional tuple x, which represents the array starting at zero,
    and swaps each value with its index position. In other words, for the output
    tuple y and the input tuple x, this operation calculates the following:
    :math:`y[x[i]] = i, \quad i \in [0, 1, \ldots, \text{len}(x)-1]`.

    Note:
        These values must include 0. There must be no duplicate values and the
        values can not be negative.

    Inputs:
        - **input_x** (Union(tuple[int], Tensor[int])) - The input tuple is constructed by multiple
          integers, i.e., :math:`(y_1, y_2, ..., y_S)` representing the indices.
          The values must include 0. There can be no duplicate values or negative values.
          If the input is Tensor, it must be 1-d and the dtype is int. Only constant value is allowed.

    Outputs:
        tuple[int]. the lenth is same as input.

    Examples:
        >>> invert = P.InvertPermutation()
        >>> input_data = (3, 4, 0, 2, 1)
        >>> output = invert(input_data)
        >>> output == (2, 4, 3, 0, 1)
    """

    @prim_attr_register
    def __init__(self):
        """init InvertPermutation"""
        self.const_value = True

    def __infer__(self, x):
        """Validate that x is a permutation of [0, len-1] and return its inverse as a const tuple."""
        x_shp = x['shape']
        x_value = x['value']
        # The permutation must be known at compile time.
        if x_value is None:
            raise ValueError(f'For \'{self.name}\' the input value must be const.')
        validator.check_value_type("shape", x_shp, [tuple, list], self.name)
        if mstype.issubclass_(x['dtype'], mstype.tensor):
            # Tensor input: must be 1-D with an int dtype; convert to a plain Python list.
            validator.check('x dimension', len(x_shp), '', 1, Rel.EQ, self.name)
            validator.check_tensor_type_same({'x dtype': x['dtype']}, mstype.int_type, self.name)
            x_value = [int(i) for i in x_value.asnumpy()]
        # z is a sorted copy used to detect duplicates and verify a dense 0..n-1 range.
        z = [x_value[i] for i in range(len(x_value))]
        z.sort()
        for i in range(1, len(z)):
            if z[i - 1] == z[i]:
                raise ValueError(f"For {self.name}, {z[i]} is duplicated in the input.")
        validator.check(f'value min', min(x_value), '', 0, Rel.EQ, self.name)
        validator.check(f'value max', max(x_value), '', len(x_value) - 1, Rel.EQ, self.name)
        y = [None] * len(x_value)
        for i, value in enumerate(x_value):
            validator.check_value_type("input[%d]" % i, value, [int], self.name)
            # Since z is sorted and duplicate/range-checked above, z[i] == i must hold.
            validator.check(f'value', z[i], f'index', i, Rel.EQ, self.name)
            # Inverse permutation: position `value` receives index `i`.
            y[value] = i
            # NOTE(review): this append grows z past len(x_value) but those entries are
            # never read (z[i] above uses i < len(x_value)) -- looks redundant; confirm.
            z.append(value)
        return {'shape': x_shp,
                'dtype': x['dtype'],
                'value': tuple(y)}
  796. class Argmax(PrimitiveWithInfer):
  797. """
  798. Returns the indices of the max value of a tensor across the axis.
  799. If the shape of input tensor is :math:`(x_1, ..., x_N)`, the output tensor shape is
  800. :math:`(x_1, ..., x_{axis-1}, x_{axis+1}, ..., x_N)`.
  801. Args:
  802. axis (int): Axis on which Argmax operation applies. Default: -1.
  803. output_type (:class:`mindspore.dtype`): An optional data type of `mindspore.dtype.int32`.
  804. Default: `mindspore.dtype.int32`.
  805. Inputs:
  806. - **input_x** (Tensor) - Input tensor.
  807. Outputs:
  808. Tensor, indices of the max value of input tensor across the axis.
  809. Examples:
  810. >>> input_x = Tensor(np.array([2.0, 3.1, 1.2]), mindspore.float32)
  811. >>> index = P.Argmax(output_type=mindspore.int32)(input_x)
  812. """
  813. @prim_attr_register
  814. def __init__(self, axis=-1, output_type=mstype.int32):
  815. """init Argmax"""
  816. self.init_prim_io_names(inputs=['x'], outputs=['output'])
  817. validator.check_value_type("axis", axis, [int], self.name)
  818. validator.check_type_same({'output': output_type}, [mstype.int32, mstype.int64], self.name)
  819. self.axis = axis
  820. self.add_prim_attr('output_type', output_type)
  821. def infer_shape(self, x_shape):
  822. axis = self.axis
  823. if axis is None:
  824. axis = 0
  825. x_rank = len(x_shape)
  826. validator.check_int_range("axis", axis, -x_rank, x_rank, Rel.INC_LEFT, self.name)
  827. axis = axis + x_rank if axis < 0 else axis
  828. ouput_shape = [x_shape[i] for i in range(x_rank) if i != axis]
  829. return ouput_shape
  830. def infer_dtype(self, x_dtype):
  831. validator.check_subclass("input_x", x_dtype, mstype.tensor, self.name)
  832. return mstype.tensor_type(self.output_type)
  833. class Argmin(PrimitiveWithInfer):
  834. """
  835. Returns the indices of the min value of a tensor across the axis.
  836. If the shape of input tensor is :math:`(x_1, ..., x_N)`, the output tensor shape is
  837. :math:`(x_1, ..., x_{axis-1}, x_{axis+1}, ..., x_N)`.
  838. Args:
  839. axis (int): Axis on which Argmin operation applies. Default: -1.
  840. output_type (:class:`mindspore.dtype`): An optional data type of `mindspore.dtype.int32`.
  841. Default: `mindspore.dtype.int32`.
  842. Inputs:
  843. - **input_x** (Tensor) - Input tensor.
  844. Outputs:
  845. Tensor, indices of the min value of input tensor across the axis.
  846. Examples:
  847. >>> input_x = Tensor(np.array([2.0, 3.1, 1.2]), mindspore.float32)
  848. >>> index = P.Argmin()(input_x)
  849. >>> assert index == Tensor(2, mindspore.int64)
  850. """
  851. @prim_attr_register
  852. def __init__(self, axis=-1, output_type=mstype.int32):
  853. """init Argmin"""
  854. self.init_prim_io_names(inputs=['x'], outputs=['output'])
  855. validator.check_value_type("axis", axis, [int], self.name)
  856. validator.check_type_name("output_type", output_type, [mstype.int32, mstype.int64], self.name)
  857. self.axis = axis
  858. self.add_prim_attr('output_type', output_type)
  859. def infer_shape(self, x_shape):
  860. axis = self.axis
  861. if axis is None:
  862. axis = 0
  863. x_rank = len(x_shape)
  864. validator.check_int_range("axis", axis, -x_rank, x_rank, Rel.INC_LEFT, self.name)
  865. axis = axis + x_rank if axis < 0 else axis
  866. ouput_shape = [x_shape[i] for i in range(x_rank) if i != axis]
  867. return ouput_shape
  868. def infer_dtype(self, x_dtype):
  869. validator.check_subclass("input_x", x_dtype, mstype.tensor, self.name)
  870. return mstype.tensor_type(self.output_type)
  871. class ArgMaxWithValue(PrimitiveWithInfer):
  872. """
  873. Calculates maximum value with corresponding index.
  874. Calculates maximum value along with given axis for the input tensor. Returns the maximum values and indices.
  875. Note:
  876. In auto_parallel and semi_auto_parallel mode, the first output index can not be used.
  877. Args:
  878. axis (int): The dimension to reduce. Default: 0.
  879. keep_dims (bool): Whether to reduce dimension, if true the output will keep same dimension with the input,
  880. the output will reduce dimension if false. Default: False.
  881. Inputs:
  882. - **input_x** (Tensor) - The input tensor, can be any dimension. Set the shape of input tensor as
  883. :math:`(x_1, x_2, ..., x_N)`.
  884. Outputs:
  885. tuple(Tensor), tuple of 2 tensors, corresponding index and maximum value of input tensor.
  886. - index (Tensor) - The index for maximum value of input tensor. If `keep_dims` is true, the output tensors shape
  887. is :math:`(x_1, x_2, ..., x_{axis-1}, 1, x_{axis+1}, ..., x_N)`. Else, the shape is
  888. :math:`(x_1, x_2, ..., x_{axis-1}, x_{axis+1}, ..., x_N)`.
  889. - output_x (Tensor) - The maximum value of input tensor, the shape same as index.
  890. Examples:
  891. >>> input_x = Tensor(np.random.rand(5), mindspore.float32)
  892. >>> index, output = P.ArgMaxWithValue()(input_x)
  893. """
  894. @prim_attr_register
  895. def __init__(self, axis=0, keep_dims=False):
  896. """init ArgMaxWithValue"""
  897. self.axis = axis
  898. self.keep_dims = keep_dims
  899. _check_infer_attr_reduce(axis, keep_dims, self.name)
  900. def infer_shape(self, x_shape):
  901. axis = self.axis
  902. x_rank = len(x_shape)
  903. validator.check_int_range("axis", axis, -x_rank, x_rank, Rel.INC_LEFT, self.name)
  904. ouput_shape = _infer_shape_reduce(x_shape, self.axis, self.keep_dims, self.name)
  905. return ouput_shape, ouput_shape
  906. def infer_dtype(self, x_dtype):
  907. validator.check_subclass("input_x", x_dtype, mstype.tensor, self.name)
  908. return mstype.tensor_type(mstype.int32), x_dtype
  909. class ArgMinWithValue(PrimitiveWithInfer):
  910. """
  911. Calculates minimum value with corresponding index, return indices and values.
  912. Calculates minimum value along with given axis for the input tensor. Returns the minimum values and indices.
  913. Note:
  914. In auto_parallel and semi_auto_parallel mode, the first output index can not be used.
  915. Args:
  916. axis (int): The dimension to reduce. Default: 0.
  917. keep_dims (bool): Whether to reduce dimension, if true the output will keep same dimension as the input,
  918. the output will reduce dimension if false. Default: False.
  919. Inputs:
  920. - **input_x** (Tensor) - The input tensor, can be any dimension. Set the shape of input tensor as
  921. :math:`(x_1, x_2, ..., x_N)`.
  922. Outputs:
  923. Tensor, corresponding index and minimum value of input tensor. If `keep_dims` is true, the output tensors shape
  924. is :math:`(x_1, x_2, ..., x_{axis-1}, 1, x_{axis+1}, ..., x_N)`. Else, the shape is
  925. :math:`(x_1, x_2, ..., x_{axis-1}, x_{axis+1}, ..., x_N)`.
  926. Examples:
  927. >>> input_x = Tensor(np.random.rand(5))
  928. >>> index, output = P.ArgMinWithValue()(input_x)
  929. """
  930. @prim_attr_register
  931. def __init__(self, axis=0, keep_dims=False):
  932. """init ArgMinWithValue"""
  933. self.axis = axis
  934. self.keep_dims = keep_dims
  935. _check_infer_attr_reduce(axis, keep_dims, self.name)
  936. def infer_shape(self, x_shape):
  937. axis = self.axis
  938. x_rank = len(x_shape)
  939. validator.check_int_range("axis", axis, -x_rank, x_rank, Rel.INC_LEFT, self.name)
  940. ouput_shape = _infer_shape_reduce(x_shape, self.axis, self.keep_dims, self.name)
  941. return ouput_shape, ouput_shape
  942. def infer_dtype(self, x_dtype):
  943. validator.check_subclass("input_x", x_dtype, mstype.tensor, self.name)
  944. return mstype.tensor_type(mstype.int32), x_dtype
class Tile(PrimitiveWithInfer):
    r"""
    Replicates a tensor with given multiples times.

    Creates a new tensor by replicating input multiples times. The dimension of
    output tensor is the larger of the dimension length of input and the length of multiples.

    Inputs:
        - **input_x** (Tensor) - 1-D or higher Tensor. Set the shape of input tensor as
          :math:`(x_1, x_2, ..., x_S)`.
        - **multiples** (tuple[int]) - The input tuple is constructed by multiple
          integers, i.e., :math:`(y_1, y_2, ..., y_S)`. The length of `multiples`
          can't be smaller than the length of shape in `input_x`.

    Outputs:
        Tensor, has the same type as the `input_x`.

        - If the length of `multiples` is the same as the length of shape in `input_x`,
          then the shape of their corresponding positions can be multiplied, and
          the shape of Outputs is :math:`(x_1*y_1, x_2*y_2, ..., x_S*y_R)`.
        - If the length of `multiples` is larger than the length of shape in `input_x`,
          fill in multiple 1 in front of the shape in `input_x` until their lengths are consistent.
          Such as set the shape of `input_x` as :math:`(1, ..., x_1, x_2, ..., x_S)`,
          then the shape of their corresponding positions can be multiplied, and
          the shape of Outputs is :math:`(1*y_1, ..., x_S*y_R)`.

    Examples:
        >>> tile = P.Tile()
        >>> input_x = Tensor(np.array([[1, 2], [3, 4]]), mindspore.float32)
        >>> multiples = (2, 3)
        >>> result = tile(input_x, multiples)
        [[1. 2. 1. 2. 1. 2.]
         [3. 4. 3. 4. 3. 4.]
         [1. 2. 1. 2. 1. 2.]
         [3. 4. 3. 4. 3. 4.]]
    """

    @prim_attr_register
    def __init__(self):
        """init Tile"""
        self.init_prim_io_names(inputs=['x', 'multiples'], outputs=['output'])

    def check_elim(self, base_tensor, multiplier):
        """Return (True, input) when every multiple is 1, i.e. the op is an identity and can be elided."""
        if (not isinstance(base_tensor, Tensor)) or (not isinstance(multiplier, tuple)):
            raise TypeError("Expecting (Tensor, tuple), got: ({}, {})".format(base_tensor, multiplier))
        if all(v == 1 for v in multiplier):
            return (True, base_tensor)
        return (False, None)

    def __infer__(self, x, multiples):
        """Infer the tiled shape; computes the const result when the input value is known."""
        multiples_v = multiples['value']
        x_shp = x['shape']
        validator.check_value_type("shape", multiples_v, [tuple], self.name)
        for i, multiple in enumerate(multiples_v):
            validator.check_value_type("multiples[%d]" % i, multiple, [int], self.name)
        validator.check_value_type("x[\'dtype\']", x["dtype"], typing.TensorType, self.name)
        # len_sub > 0 means multiples has extra leading entries: pad the input shape with 1s.
        len_sub = len(multiples_v) - len(x_shp)
        multiples_w = None
        if len_sub == 0:
            multiples_w = multiples_v
        if len_sub > 0:
            # NOTE(review): this mutates x_shp (the list from x['shape']) in place -- confirm
            # downstream consumers of the infer dict tolerate that.
            for i in range(0, len_sub):
                x_shp.insert(0, 1)
            multiples_w = multiples_v
        elif len_sub < 0:
            raise ValueError(f'For \'{self.name}\' the length of multiples can not be smaller than '
                             f'the length of dimension in input_x.')
        # Each output dimension is the (padded) input dimension times its multiple.
        for i, a in enumerate(multiples_w):
            x_shp[i] *= a
        value = None
        if x['value'] is not None:
            # Constant folding: replicate the known value with numpy.
            value = Tensor(np.tile(x['value'].asnumpy(), multiples_w))
        return {'shape': x_shp,
                'dtype': x['dtype'],
                'value': value}
class UnsortedSegmentSum(PrimitiveWithInfer):
    r"""
    Computes the sum along segments of a tensor.

    Calculates a tensor such that :math:`\text{output}[i] = \sum_{segment\_ids[j] == i} \text{data}[j, \ldots]`, where
    :math:`j` is a tuple describing the index of element in data. `segment_ids` selects which elements in data to sum
    up. Segment_ids does not need to be sorted, and it does not need to cover all values in the entire valid value
    range.

    If the sum of the given segment_ids :math:`i` is empty, then :math:`\text{output}[i] = 0`. If the given segment_ids
    is negative, the value will be ignored. 'num_segments' should be equal to the number of different segment_ids.

    Inputs:
        - **input_x** (Tensor) - The shape is :math:`(x_1, x_2, ..., x_R)`.
        - **segment_ids** (Tensor) - Set the shape as :math:`(x_1, x_2, ..., x_N)`, where 0 < N <= R. Type must be int.
        - **num_segments** (int) - Set :math:`z` as num_segments.

    Outputs:
        Tensor, the shape is :math:`(z, x_{N+1}, ..., x_R)`.

    Examples:
        >>> input_x = Tensor([1, 2, 3, 4], mindspore.float32)
        >>> segment_ids = Tensor([0, 0, 1, 2], mindspore.int32)
        >>> num_segments = 4
        >>> P.UnsortedSegmentSum()(input_x, segment_ids, num_segments)
        [3, 3, 4, 0]
    """

    @prim_attr_register
    def __init__(self):
        """init UnsortedSegmentSum"""
        self.init_prim_io_names(inputs=['x', 'segment_ids', 'num_segments'], outputs=['y'])

    def __infer__(self, x, segment_ids, num_segments):
        """Infer output shape [num_segments] + x_shape[rank(segment_ids):]."""
        x_type = x['dtype']
        x_shp = x['shape']
        validator.check_subclass("input_x", x_type, mstype.tensor, self.name)
        validator.check_value_type("x_shape", x_shp, [list], self.name)
        x_shp_len = len(x_shp)
        validator.check_integer("rank of input_x", x_shp_len, 0, Rel.GT, self.name)
        segment_ids_shp = segment_ids['shape']
        segment_ids_type = segment_ids['dtype']
        validator.check_subclass("segment_ids", segment_ids_type, mstype.tensor, self.name)
        validator.check_value_type("segment_ids", segment_ids_shp, [list], self.name)
        segment_ids_shp_len = len(segment_ids_shp)
        validator.check_integer("rank of segment_ids", segment_ids_shp_len, 0, Rel.GT, self.name)
        validator.check(f'rank of input_x', len(x_shp),
                        'rank of segments_id', len(segment_ids_shp), Rel.GE, self.name)
        # segment_ids must match the leading dimensions of input_x exactly.
        for i, value in enumerate(segment_ids_shp):
            validator.check("ids[%d]" % i, value, 'input[%d]' % i, x_shp[i], Rel.EQ, self.name)
        num_segments_v = num_segments['value']
        validator.check_value_type('num_segments', num_segments_v, [int], self.name)
        validator.check_integer("num_segments", num_segments_v, 0, Rel.GT, self.name)
        # Output: leading dims replaced by num_segments; trailing dims kept.
        shp = [num_segments_v]
        shp += x_shp[segment_ids_shp_len:]
        out = {'shape': shp,
               'dtype': mstype.tensor_type(x_type.element_type()),
               'value': None}
        return out
class UnsortedSegmentMin(PrimitiveWithInfer):
    """
    Computes the minimum along segments of a tensor.

    If the given segment_ids is negative, the value will be ignored.

    Inputs:
        - **input_x** (Tensor) - The shape is :math:`(x_1, x_2, ..., x_R)`.
        - **segment_ids** (Tensor) - A `1-D` tensor whose shape is :math:`(x_1)`.
        - **num_segments** (int) - The value spcifies the number of distinct `segment_ids`.

    Outputs:
        Tensor, Set the number of `num_segments` as `N`, the shape is :math:`(N, x_2, ..., x_R)`.

    Examples:
        >>> input_x = Tensor(np.array([[1, 2, 3], [4, 5, 6], [4, 2, 1]]).astype(np.float32))
        >>> segment_ids = Tensor(np.array([0, 1, 1]).astype(np.int32))
        >>> num_segments = 2
        >>> unsorted_segment_min = P.UnsortedSegmentMin()
        >>> unsorted_segment_min(input_x, segment_ids, num_segments)
        [[1., 2., 3.], [4., 2., 1.]]
    """

    @prim_attr_register
    def __init__(self):
        """init UnsortedSegmentMin"""
        self.init_prim_io_names(inputs=['x', 'segment_ids', 'num_segments'], outputs=['y'])

    def __infer__(self, x, segment_ids, num_segments):
        """Infer output shape [num_segments] + x_shape[1:]; segment_ids must be 1-D int32."""
        x_type = x['dtype']
        x_shape = x['shape']
        segment_ids_shape = segment_ids['shape']
        # Unlike UnsortedSegmentSum, input dtype is restricted to these three types.
        valid_type = [mstype.float16, mstype.float32, mstype.int32]
        validator.check_tensor_type_same({"x": x['dtype']}, valid_type, self.name)
        validator.check_tensor_type_same({"segment_ids": segment_ids['dtype']}, [mstype.int32], self.name)
        validator.check_integer("rank of segment_ids_shape", len(segment_ids_shape), 1, Rel.EQ, self.name)
        # segment_ids pairs one id with each row of input_x.
        validator.check(f'first shape of input_x', x_shape[0],
                        'length of segments_id', segment_ids_shape[0], Rel.EQ, self.name)
        num_segments_v = num_segments['value']
        validator.check_value_type('num_segments', num_segments_v, [int], self.name)
        validator.check_integer("num_segments", num_segments_v, 0, Rel.GT, self.name)
        segment_ids_shape_len = len(segment_ids_shape)
        # Output: first dim replaced by num_segments; remaining dims kept.
        out_shape = [num_segments_v]
        out_shape += x_shape[segment_ids_shape_len:]
        out = {'shape': out_shape,
               'dtype': x_type,
               'value': None}
        return out
class Concat(PrimitiveWithInfer):
    r"""
    Concat tensor in specified axis.

    Concat input tensors along with the given axis.

    Note:
        The input data is a tuple of tensors. These tensors have the same rank `R`. Set the given axis as `m`, and
        :math:`0 \le m < N`. Set the number of input tensors as `N`. For the :math:`i`-th tensor :math:`t_i` has
        the shape :math:`(x_1, x_2, ..., x_{mi}, ..., x_R)`. :math:`x_{mi}` is the :math:`m`-th dimension of the
        :math:`i`-th tensor. Then, the output tensor shape is

        .. math::
            (x_1, x_2, ..., \sum_{i=1}^Nx_{mi}, ..., x_R)

    Args:
        axis (int): The specified axis. Default: 0.

    Inputs:
        - **input_x** (tuple, list) - Tuple or list of input tensors.

    Outputs:
        Tensor, the shape is :math:`(x_1, x_2, ..., \sum_{i=1}^Nx_{mi}, ..., x_R)`.

    Examples:
        >>> data1 = Tensor(np.array([[0, 1], [2, 1]]).astype(np.int32))
        >>> data2 = Tensor(np.array([[0, 1], [2, 1]]).astype(np.int32))
        >>> op = P.Concat()
        >>> output = op((data1, data2))
    """

    @prim_attr_register
    def __init__(self, axis=0):
        """init Concat"""
        validator.check_value_type("axis", axis, [int], self.name)

    def __infer__(self, input_x):
        # self.axis is presumably registered by @prim_attr_register from the ctor arg -- TODO confirm.
        axis = self.axis
        x_shp = input_x['shape']
        x_type = input_x['dtype']
        # all_shp is used below as the total extent of the output along `axis`.
        _, all_shp, _ = get_concat_offset(x_shp, x_type, axis, self.name)
        self.add_prim_attr('T', x_type[0].element_type())
        self.add_prim_attr('inputNums', len(x_shp))
        # Output shape: first input's shape with the concat axis replaced by the summed size.
        ret_shp = x_shp[0].copy()
        ret_shp[axis] = all_shp
        out = {'shape': ret_shp,
               'dtype': x_type[0],
               'value': None}
        return out
  1146. def _get_pack_shape(x_shape, x_type, axis, prim_name):
  1147. """for pack output shape"""
  1148. validator.check_value_type("shape", x_shape, [tuple, list], prim_name)
  1149. validator.check_integer("len of input_x", len(x_shape), 1, Rel.GT, prim_name)
  1150. validator.check_subclass("input_x[0]", x_type[0], mstype.tensor, prim_name)
  1151. rank_base = len(x_shape[0])
  1152. N = len(x_shape)
  1153. out_shape = x_shape[0]
  1154. validator.check_int_range('axis', axis, -rank_base - 1, rank_base, Rel.INC_BOTH, prim_name)
  1155. if axis < 0:
  1156. axis = axis + rank_base + 1
  1157. for i in range(1, N):
  1158. validator.check('x_type[%d]' % i, x_type[i], 'base', x_type[0], Rel.EQ, prim_name, TypeError)
  1159. if x_shape[i] != x_shape[0]:
  1160. raise ValueError(f"For \'{prim_name}\' element {i} shape in input can not pack with first element")
  1161. out_shape.insert(axis, N)
  1162. return out_shape
  1163. class Pack(PrimitiveWithInfer):
  1164. r"""
  1165. Packs a list of tensors in specified axis.
  1166. Packs the list of input tensors with the same rank `R`, output is a tensor of rank `(R+1)`.
  1167. Given input tensors of shape :math:`(x_1, x_2, ..., x_R)`. Set the number of input tensors as `N`.
  1168. If :math:`0 \le axis`, the output tensor shape is :math:`(x_1, x_2, ..., x_{axis}, N, x_{axis+1}, ..., x_R)`.
  1169. Args:
  1170. axis (int): Dimension along which to pack. Default: 0.
  1171. Negative values wrap around. The range is [-(R+1), R+1).
  1172. Inputs:
  1173. - **input_x** (Union[tuple, list]) - A Tuple or list of Tensor objects with the same shape and type.
  1174. Outputs:
  1175. Tensor. A packed Tensor with the same type as `input_x`.
  1176. Raises:
  1177. TypeError: If the data types of elements in input_x are not the same.
  1178. ValueError: If length of input_x is not greater than 1;
  1179. or if axis is out of the range [-(R+1), R+1);
  1180. or if the shapes of elements in input_x are not the same.
  1181. Examples:
  1182. >>> data1 = Tensor(np.array([0, 1]).astype(np.float32))
  1183. >>> data2 = Tensor(np.array([2, 3]).astype(np.float32))
  1184. >>> pack = P.Pack()
  1185. >>> output = pack([data1, data2])
  1186. [[0, 1], [2, 3]]
  1187. """
  1188. @prim_attr_register
  1189. def __init__(self, axis=0):
  1190. """init Pack"""
  1191. validator.check_value_type("axis", axis, [int], self.name)
  1192. self.axis = axis
  1193. def __infer__(self, value):
  1194. x_shape = value['shape']
  1195. x_type = value['dtype']
  1196. self.add_prim_attr('num', len(x_shape))
  1197. all_shape = _get_pack_shape(x_shape, x_type, self.axis, self.name)
  1198. out = {'shape': all_shape,
  1199. 'dtype': x_type[0],
  1200. 'value': None}
  1201. return out
class Unpack(PrimitiveWithInfer):
    r"""
    Unpacks tensor in specified axis.

    Unpacks a tensor of rank `R` along axis dimension, output tensors will have rank `(R-1)`.

    Given a tensor of shape :math:`(x_1, x_2, ..., x_R)`. If :math:`0 \le axis`,
    the shape of tensor in output is :math:`(x_1, x_2, ..., x_{axis}, x_{axis+2}, ..., x_R)`.

    This is the opposite of pack.

    Args:
        axis (int): Dimension along which to unpack. Default: 0.
            Negative values wrap around. The range is [-R, R).

    Inputs:
        - **input_x** (Tensor) - The shape is :math:`(x_1, x_2, ..., x_R)`.
          A rank R > 0 Tensor to be unpacked.

    Outputs:
        A tuple of Tensors, the shape of each objects is same.

    Raises:
        ValueError: If axis is out of the range [-len(input_x.shape), len(input_x.shape)).

    Examples:
        >>> unpack = P.Unpack()
        >>> input_x = Tensor(np.array([[1, 1, 1, 1], [2, 2, 2, 2]]))
        >>> output = unpack(input_x)
        ([1, 1, 1, 1], [2, 2, 2, 2])
    """

    @prim_attr_register
    def __init__(self, axis=0):
        """init Unpack"""
        validator.check_value_type("axis", axis, [int], self.name)
        self.axis = axis

    def __infer__(self, x):
        """Infer outputs: one (R-1)-rank shape/dtype pair per element along the axis."""
        validator.check_subclass("x", x['dtype'], mstype.tensor, self.name)
        x_shape = list(x['shape'])
        dim = len(x_shape)
        validator.check_int_range('axis value', self.axis, -dim, dim, Rel.INC_LEFT, self.name)
        if self.axis < 0:
            # NOTE(review): normalizing the negative axis mutates the primitive's
            # own state, so later infers see the already-normalized value.
            self.axis = self.axis + dim
        # One output tensor per element along the unpacked axis.
        output_num = x_shape[self.axis]
        validator.check_value_type("num", output_num, [int], self.name)
        validator.check_integer("output_num", output_num, 0, Rel.GT, self.name)
        self.add_prim_attr('num', output_num)
        # output_num was just read from this very dimension, so the difference is
        # always 0 here; kept as a defensive guard.
        output_valid_check = x_shape[self.axis] - output_num
        validator.check_integer("The dimension which to unpack divides output_num", output_valid_check, 0, Rel.EQ,
                                self.name)
        out_shapes = []
        out_dtypes = []
        # Each output drops the unpacked axis from the input shape.
        out_shape = x_shape[:self.axis] + x_shape[self.axis + 1:]
        for _ in range(output_num):
            out_shapes.append(tuple(out_shape))
            out_dtypes.append(x['dtype'])
        out_shapes = tuple(out_shapes)
        out_dtypes = tuple(out_dtypes)
        out = {'shape': out_shapes,
               'dtype': out_dtypes,
               'value': None}
        return out
class Slice(PrimitiveWithInfer):
    """
    Slice a tensor in specified shape.

    Args:
        x (Tensor): The target tensor.
        begin (tuple): The beginning of the slice. Only constant value is allowed.
        size (tuple): The size of the slice. Only constant value is allowed.

    Returns:
        Tensor.

    Examples:
        >>> data = Tensor(np.array([[[1, 1, 1], [2, 2, 2]],
        >>>                         [[3, 3, 3], [4, 4, 4]],
        >>>                         [[5, 5, 5], [6, 6, 6]]]).astype(np.int32))
        >>> type = P.Slice()(data, (1, 0, 0), (1, 1, 3))
    """

    @prim_attr_register
    def __init__(self):
        """init slice"""
        self.init_prim_io_names(inputs=['x', 'begin', 'size'], outputs=['output'])

    def __infer__(self, x, begin, size):
        """Infer output: shape equals `size`, dtype follows the input tensor."""
        x_shape = x['shape']
        x_shp_len = len(x_shape)
        validator.check_const_input('begin', begin['value'], self.name)
        validator.check_const_input('size', size['value'], self.name)
        begin_v, size_v = begin['value'], size['value']
        if begin_v is None or size_v is None:
            # Without constant begin/size the output shape cannot be determined.
            return {'shape': None,
                    'dtype': x['dtype'],
                    'value': None}
        # begin and size must each supply one entry per input dimension.
        for key, value in zip(('begin', 'size'), (begin_v, size_v)):
            validator.check(f'len of {key}', len(value),
                            'len x\'s dim', x_shp_len)
        for i in range(x_shp_len):
            # The requested window must stay inside the input on every axis.
            if x_shape[i] < begin_v[i] + size_v[i]:
                y = begin_v[i] + size_v[i]
                raise ValueError("For '%s' slice shape can not bigger than orign shape %d, %d." %
                                 (self.name, x_shape[i], y))
        return {'shape': size_v,
                'dtype': x['dtype'],
                'value': None}
class Select(PrimitiveWithInfer):
    r"""
    Return the selected elements, either from input :math:`x` or input :math:`y`, depending on the `condition`.

    The condition tensor, :math:`x` and :math:`y` must all have the same shape.
    For each position, the output takes the element from :math:`x` where the
    condition is true and from :math:`y` where it is false.
    :math:`x` and :math:`y` must also have the same data type.

    Inputs:
        - **input_x** (Tensor[bool]) - The shape is :math:`(x_1, x_2, ..., x_N)`.
          The condition tensor, decides whose element is chosen.
        - **input_y** (Tensor) - The shape is :math:`(x_1, x_2, ..., x_N)`.
          The first input tensor, chosen where the condition is true.
        - **input_z** (Tensor) - The shape is :math:`(x_1, x_2, ..., x_N)`.
          The second input tensor, chosen where the condition is false.

    Outputs:
        Tensor, has the same shape and data type as input_y.

    Examples:
        >>> select = P.Select()
        >>> input_x = Tensor([True, False])
        >>> input_y = Tensor([2,3], mindspore.float32)
        >>> input_z = Tensor([1,2], mindspore.float32)
        >>> select(input_x, input_y, input_z)
    """

    @prim_attr_register
    def __init__(self):
        """init"""
        self.init_prim_io_names(inputs=['condition', 'x', 'y'], outputs=['output'])

    def infer_shape(self, cond_shape, x_shape, y_shape):
        # All three inputs must have exactly the same shape (no broadcasting).
        if cond_shape != x_shape or x_shape != y_shape:
            raise ValueError('The x_shape and y_shape must be the same as cond_shape.')
        return x_shape

    def infer_dtype(self, cond_type, x_type, y_type):
        # Record the element dtype as the 'T' attribute before validating.
        self.add_prim_attr('T', x_type)
        validator.check_subclass("x_type", x_type, mstype.tensor, self.name)
        validator.check_subclass("y_type", y_type, mstype.tensor, self.name)
        validator.check_tensor_type_same({"cond": cond_type}, [mstype.bool_], self.name)
        if x_type != y_type:
            raise TypeError('\'%s\' the x_type %s must be the same as y_type %s.' % (self.name, x_type, y_type))
        return x_type

    def infer_value(self, cond, x, y):
        """Constant-fold with np.where when all three inputs are known constants."""
        if cond is not None and x is not None and y is not None:
            cond = cond.asnumpy()
            x = x.asnumpy()
            y = y.asnumpy()
            out = np.where(cond, x, y)
            return Tensor(out)
        return None
  1359. def _compute_slicing_length(begin, end, stride, x_shape, i):
  1360. """Compute the length of the slicing."""
  1361. if i >= len(x_shape):
  1362. raise ValueError(f"For 'StridedSlice', When their is no new axis, the index length must be less or "
  1363. f"equal than the dim of x.")
  1364. x_dim = x_shape[i]
  1365. if stride > 0:
  1366. # When slicing forward, convert begin and end to positive numbers.
  1367. if begin >= x_dim or end < -x_dim:
  1368. # When slicing forward, if begin >= x_dim or end < -x_dim, the length of the slicing is 0.
  1369. slicing_length = 0
  1370. else:
  1371. if -x_dim <= begin < 0:
  1372. begin += x_dim
  1373. if begin < -x_dim:
  1374. # When slicing forward, if begin < -x_dim, set begin = 0, which means start from the 0th element.
  1375. begin = 0
  1376. if -x_dim <= end < 0:
  1377. end += x_dim
  1378. if end > x_dim:
  1379. # When slicing forward, if end > x_dim, set end = x_dims, which means slice to the last element.
  1380. end = x_dim
  1381. if begin >= end:
  1382. # When slicing forward, if begin >= end, the length of the slicing is 0.
  1383. slicing_length = 0
  1384. else:
  1385. slicing_length = 1 + (end - 1 - begin) // stride
  1386. else:
  1387. # When slicing backward, convert begin and end to negative numbers.
  1388. if begin < -x_dim or end >= x_dim:
  1389. # When slicing backward, if begin < -x_dim or end >= x_dim, the length of the slicing is 0.
  1390. slicing_length = 0
  1391. else:
  1392. if 0 <= begin < x_dim:
  1393. begin += -x_dim
  1394. if begin >= x_dim:
  1395. # When slicing backward, if begin >= x_dim, set begin = -1, which means start from the last element.
  1396. begin = -1
  1397. if 0 < end < x_dim:
  1398. end += -x_dim
  1399. if end < -x_dim - 1:
  1400. # When slicing backward, if end < -x_dim - 1, set end = -x_dim - 1, which means
  1401. # slicing to the 0th element.
  1402. end = -x_dim - 1
  1403. if begin <= end:
  1404. # When slicing backward, if begin <= end, the length of the slicing is 0.
  1405. slicing_length = 0
  1406. else:
  1407. slicing_length = 1 + (end + 1 - begin) // stride
  1408. return slicing_length
  1409. class StridedSlice(PrimitiveWithInfer):
  1410. r"""
  1411. Extracts a strided slice of a tensor.
  1412. Given an input tensor, this operation inserts a dimension of length 1 at the dimension.
  1413. This operation extracts a fragment of size (end-begin)/stride from the given
  1414. 'input_tensor'. Starting from the position specified by the begin, the fragment
  1415. continues adding stride to the index until all dimensions are not less than end.
  1416. Note:
  1417. The stride may be negative value, which causes reverse slicing.
  1418. The shape of `begin`, `end` and `strides` should be the same.
  1419. Args:
  1420. begin_mask (int): Starting index of the slice. Default: 0.
  1421. end_mask (int): Ending index of the slice. Default: 0.
  1422. ellipsis_mask (int): An int mask. Default: 0.
  1423. new_axis_mask (int): An int mask. Default: 0.
  1424. shrink_axis_mask (int): An int mask. Default: 0.
  1425. Inputs:
  1426. - **input_x** (Tensor) - The input Tensor.
  1427. - **begin** (tuple[int]) - A tuple which represents the location where to start. Only
  1428. constant value is allowed.
  1429. - **end** (tuple[int]) - A tuple or which represents the maximum location where to stop.
  1430. Only constant value is allowed.
  1431. - **strides** (tuple[int]) - A tuple which represents the stride continuously added
  1432. before reach the maximum location. Only constant value is allowed.
  1433. Outputs:
  1434. Tensor.
  1435. Explain with the following example.
  1436. - In the 0th dim, begin is 1, end is 2, and strides is 1,
  1437. because :math:`1+1=2\geq2`, the interval is :math:`[1,2)`.
  1438. Thus, return the element with :math:`index = 1` in 0th dim, i.e., [[3, 3, 3], [4, 4, 4]].
  1439. - In the 1st dim, similarly, the interval is :math:`[0,1)`.
  1440. Based on the return value of the 0th dim, return the element with :math:`index = 0`,
  1441. i.e., [3, 3, 3].
  1442. - In the 2nd dim, similarly, the interval is :math:`[0,3)`.
  1443. Based on the return value of the 1st dim, return the element with :math:`index = 0,1,2`,
  1444. i.e., [3, 3, 3].
  1445. - Finally, the output is [3, 3, 3].
  1446. Examples
  1447. >>> input_x = Tensor([[[1, 1, 1], [2, 2, 2]], [[3, 3, 3], [4, 4, 4]],
  1448. >>> [[5, 5, 5], [6, 6, 6]]], mindspore.float32)
  1449. >>> slice = P.StridedSlice()
  1450. >>> output = slice(input_x, (1, 0, 0), (2, 1, 3), (1, 1, 1))
  1451. >>> output.shape
  1452. (1, 1, 3)
  1453. >>> output
  1454. [[[3, 3, 3]]]
  1455. """
  1456. @prim_attr_register
  1457. def __init__(self,
  1458. begin_mask=0,
  1459. end_mask=0,
  1460. ellipsis_mask=0,
  1461. new_axis_mask=0,
  1462. shrink_axis_mask=0):
  1463. """Init StrideSlice"""
  1464. self.init_prim_io_names(inputs=['x', 'begin', 'end', 'strides'], outputs=['output'])
  1465. validator.check_integer('begin_mask', begin_mask, 0, Rel.GE, self.name)
  1466. validator.check_integer('end_mask', end_mask, 0, Rel.GE, self.name)
  1467. validator.check_integer('ellipsis_mask', ellipsis_mask, 0, Rel.GE, self.name)
  1468. if len(tuple(filter(lambda x: x == '1', bin(ellipsis_mask)[-1:1:-1]))) > 1:
  1469. raise ValueError(f"For '{self.name}', only support one ellipsis in the index, but got {end_mask}.")
  1470. validator.check_integer('new_axis_mask', new_axis_mask, 0, Rel.GE, self.name)
  1471. validator.check_integer('shrink_axis_mask', shrink_axis_mask, 0, Rel.GE, self.name)
  1472. def __infer__(self, x, begin, end, strides):
  1473. begin_v, end_v, strides_v = begin['value'], end['value'], strides['value']
  1474. validator.check_value_type("begin", begin_v, [tuple], self.name)
  1475. validator.check_value_type("end", end_v, [tuple], self.name)
  1476. validator.check_value_type("strides", strides_v, [tuple], self.name)
  1477. if tuple(filter(lambda x: not isinstance(x, int), begin_v + end_v + strides_v)):
  1478. raise ValueError(f"For {self.name}, both the begins, ends, and strides must be a tuple of int, "
  1479. f"but got begins: {begin_v}, ends: {end_v}, strides: {strides_v}.")
  1480. if tuple(filter(lambda x: x == 0, strides_v)):
  1481. raise ValueError(f"For '{self.name}', the strides cannot contain 0, but got strides: {strides_v}.")
  1482. if len(end_v) != len(begin_v) or len(strides_v) != len(begin_v):
  1483. raise ValueError(f"For '{self.name}' the length of begin index: {begin_v}, end index: {end_v} and "
  1484. f"strides: {strides_v} must be equal.")
  1485. ret_shape = self._compute_slicing_shape(x['shape'], begin_v, end_v, strides_v)
  1486. value = None if all(ret_shape) else Tensor(np.array([]).reshape(ret_shape), x['dtype'].element_type())
  1487. return {'shape': ret_shape,
  1488. 'dtype': x['dtype'],
  1489. 'value': value}
  1490. def _compute_slicing_shape(self, x_shape, begin_v, end_v, strides_v):
  1491. """Compute the shape of the slicing."""
  1492. x_rank = len(x_shape)
  1493. slice_len = len(begin_v)
  1494. # After the integer is converted to binary, it is a str and the first two chars are the flag char '0b'.
  1495. begin_pos = bin(self.begin_mask)[-1:1:-1]
  1496. end_pos = bin(self.end_mask)[-1:1:-1]
  1497. ellipsis_pos = bin(self.ellipsis_mask)[-1:1:-1]
  1498. new_axis_pos = bin(self.new_axis_mask)[-1:1:-1]
  1499. shrink_axis_pos = bin(self.shrink_axis_mask)[-1:1:-1]
  1500. ret_shape = []
  1501. i, j = 0, 0
  1502. has_ellipsis = False
  1503. while i < x_rank or j < slice_len:
  1504. if j < slice_len:
  1505. begin, end, stride = begin_v[j], end_v[j], strides_v[j]
  1506. if j < len(ellipsis_pos) and ellipsis_pos[j] == '1':
  1507. # When there is ellipsis, the latter part of the ellipsis will be processed separately.
  1508. has_ellipsis = True
  1509. break
  1510. if j < len(begin_pos) and begin_pos[j] == '1':
  1511. begin = -1 if strides_v[j] < 0 else 0
  1512. if j < len(end_pos) and end_pos[j] == '1':
  1513. end = -(x_shape[i] + 1) if strides_v[j] < 0 else x_shape[i]
  1514. if j < len(new_axis_pos) and new_axis_pos[j] == '1':
  1515. ret_shape.append(1)
  1516. j += 1
  1517. continue
  1518. if j < len(shrink_axis_pos) and shrink_axis_pos[j] == '1':
  1519. if (not -x_shape[i] <= begin < x_shape[i]) or stride < 0:
  1520. raise ValueError(f"For {self.name}, when shrink axis, the stride cannot be negative number, "
  1521. f"and begin should be in [-{x_shape[i]}, {x_shape[i]}), "
  1522. f"but got stride: {stride}, begin: {begin}.")
  1523. j += 1
  1524. i += 1
  1525. continue
  1526. else:
  1527. begin, end, stride = 0, x_shape[i], 1
  1528. slicing_length = _compute_slicing_length(begin, end, stride, x_shape, i)
  1529. ret_shape.append(slicing_length)
  1530. i += 1
  1531. j += 1
  1532. if has_ellipsis:
  1533. # When there is ellipsis, handle the second half of the ellipsis split.
  1534. ellipsis_occupied_dims = x_rank - i - (slice_len - (j + 1)) + \
  1535. len(tuple(filter(lambda x: x == '1', new_axis_pos[j + 1:slice_len])))
  1536. ret_shape.extend(x_shape[i:i + ellipsis_occupied_dims])
  1537. j += 1
  1538. i += ellipsis_occupied_dims
  1539. while i < x_rank or j < slice_len:
  1540. begin, end, stride = begin_v[j], end_v[j], strides_v[j]
  1541. if j < len(begin_pos) and begin_pos[j] == '1':
  1542. begin = -1 if strides_v[j] < 0 else 0
  1543. if j < len(end_pos) and end_pos[j] == '1':
  1544. end = -(x_shape[i] + 1) if strides_v[j] < 0 else x_shape[i]
  1545. if j < len(new_axis_pos) and new_axis_pos[j] == '1':
  1546. ret_shape.append(1)
  1547. j += 1
  1548. continue
  1549. if j < len(shrink_axis_pos) and shrink_axis_pos[j] == '1':
  1550. if (not -x_shape[i] <= begin < x_shape[i]) or stride < 0:
  1551. raise ValueError(f"For {self.name}, when shrink axis, the stride cannot be negative number, "
  1552. f"and begin should be in [-{x_shape[i]}, {x_shape[i]}), "
  1553. f"but got stride: {stride}, begin: {begin}.")
  1554. j += 1
  1555. i += 1
  1556. continue
  1557. slicing_length = _compute_slicing_length(begin, end, stride, x_shape, i)
  1558. ret_shape.append(slicing_length)
  1559. i += 1
  1560. j += 1
  1561. return ret_shape
  1562. class Diag(PrimitiveWithInfer):
  1563. r"""
  1564. Construct a diagonal tensor with a given diagonal values.
  1565. Assume `input_x` has dimensions :math:`[D_1,... D_k]`, the output is a tensor of
  1566. rank 2k with dimensions :math:`[D_1,..., D_k, D_1,..., D_k]` where:
  1567. :math:`output[i_1,..., i_k, i_1,..., i_k] = input_x[i_1,..., i_k]` and 0 everywhere else.
  1568. Inputs:
  1569. - **input_x** (Tensor) - The input tensor.
  1570. Outputs:
  1571. Tensor.
  1572. Examples:
  1573. >>> input_x = Tensor([1, 2, 3, 4])
  1574. >>> diag = P.Diag()
  1575. >>> diag(input_x)
  1576. [[1, 0, 0, 0],
  1577. [0, 2, 0, 0],
  1578. [0, 0, 3, 0],
  1579. [0, 0, 0, 4]]
  1580. """
  1581. @prim_attr_register
  1582. def __init__(self):
  1583. """init Diag"""
  1584. def infer_dtype(self, x_type):
  1585. validator.check_subclass('input_x', x_type, mstype.tensor, self.name)
  1586. return x_type
  1587. def infer_shape(self, x_shape):
  1588. validator.check("x rank", len(x_shape), "", 1, Rel.GE)
  1589. ret_shape = copy.deepcopy(x_shape)
  1590. ret_shape = ret_shape + ret_shape
  1591. return ret_shape
  1592. def infer_value(self, x):
  1593. if x is None:
  1594. return None
  1595. # do constant-folding only when x rank is 1
  1596. if len(x.shape) != 1:
  1597. return None
  1598. ret = np.diag(x.asnumpy())
  1599. return Tensor(ret)
  1600. class DiagPart(PrimitiveWithInfer):
  1601. r"""
  1602. Extract the diagonal part from given tensor.
  1603. Assume input has dimensions :math:`[D_1,..., D_k, D_1,..., D_k]`, the output is a tensor
  1604. of rank k with dimensions :math:`[D_1,..., D_k]` where:
  1605. :math:`output[i_1,..., i_k] = input[i_1,..., i_k, i_1,..., i_k]`.
  1606. Inputs:
  1607. - **input_x** (Tensor) - The input Tensor.
  1608. Outputs:
  1609. Tensor.
  1610. Examples
  1611. >>> input_x = Tensor([[1, 0, 0, 0],
  1612. >>> [0, 2, 0, 0],
  1613. >>> [0, 0, 3, 0],
  1614. >>> [0, 0, 0, 4]])
  1615. >>> diag_part = P.DiagPart()
  1616. >>> diag_part(input_x)
  1617. [1, 2, 3, 4]
  1618. """
  1619. @prim_attr_register
  1620. def __init__(self):
  1621. """init DiagPart"""
  1622. def infer_dtype(self, x_type):
  1623. validator.check_subclass('input_x', x_type, mstype.tensor, self.name)
  1624. return x_type
  1625. def infer_shape(self, x_shape):
  1626. if len(x_shape) % 2 != 0 or \
  1627. not x_shape:
  1628. raise ValueError(f"For \'{self.name}\' input rank must be non-zero and even, but got rank {len(x_shape)}, "
  1629. f"with shapes {x_shape}")
  1630. length = len(x_shape) // 2
  1631. ret_shape = x_shape[0:length]
  1632. return ret_shape
  1633. def infer_value(self, x):
  1634. if x is None:
  1635. return None
  1636. # do constant-folding only when x rank is 2
  1637. if len(x.shape) != 2:
  1638. return None
  1639. ret = np.diag(x.asnumpy())
  1640. return Tensor(ret)
  1641. class Eye(PrimitiveWithInfer):
  1642. """
  1643. Creates a tensor with ones on the diagonal and zeros elsewhere.
  1644. Inputs:
  1645. - **n** (int) - Number of rows of returned tensor
  1646. - **m** (int) - Number of columns of returned tensor
  1647. - **t** (mindspore.dtype) - MindSpore's dtype, The data type of the returned tensor.
  1648. Outputs:
  1649. Tensor, a tensor with ones on the diagonal and zeros elsewhere.
  1650. Examples:
  1651. >>> eye = P.Eye()
  1652. >>> out_tensor = eye(2, 2, mindspore.int32)
  1653. """
  1654. @prim_attr_register
  1655. def __init__(self):
  1656. """init Eye"""
  1657. def infer_value(self, n, m, t):
  1658. validator.check_integer("n", n, 0, Rel.GT, self.name)
  1659. validator.check_integer("m", m, 0, Rel.GT, self.name)
  1660. args = {"dtype": t}
  1661. validator.check_type_same(args, mstype.number_type + (mstype.bool_,), self.name)
  1662. np_type = mstype.dtype_to_nptype(t)
  1663. ret = np.eye(n, m, dtype=np_type)
  1664. return Tensor(ret)
class ScatterNd(PrimitiveWithInfer):
    """
    Scatters a tensor into a new tensor depending on the specified indices.

    Creates an empty tensor, and set values by scattering the update tensor depending on indices.

    Inputs:
        - **indices** (Tensor) - The index of scattering in the new tensor. With int32 data type.
        - **update** (Tensor) - The source Tensor to be scattered.
        - **shape** (tuple[int]) - Define the shape of the output tensor. Has the same type as indices.

    Outputs:
        Tensor, the new tensor, has the same type as `update` and the same shape as `shape`.

    Examples:
        >>> op = P.ScatterNd()
        >>> indices = Tensor(np.array([[0, 1], [1, 1]]), mindspore.int32)
        >>> update = Tensor(np.array([3.2, 1.1]), mindspore.float32)
        >>> shape = (3, 3)
        >>> output = op(indices, update, shape)
    """

    @prim_attr_register
    def __init__(self):
        """Init ScatterNd"""
        self.init_prim_io_names(inputs=['indices', 'update', 'shape'], outputs=['output'])

    def __infer__(self, indices, update, shape):
        """Infer output: shape is the constant `shape` tuple, dtype follows `update`."""
        shp = shape['value']
        validator.check_subclass("update_dtype", update['dtype'], mstype.tensor, self.name)
        validator.check_tensor_type_same({"indices": indices['dtype']}, [mstype.int32], self.name)
        validator.check_value_type("shape", shp, [tuple], self.name)
        # Every output dimension must be strictly positive.
        for i, x in enumerate(shp):
            validator.check_integer("shape[%d]" % i, x, 0, Rel.GT, self.name)
        indices_shape, update_shape = indices["shape"], update["shape"]
        # Each index row must pair with one update row.
        if indices_shape[0] != update_shape[0]:
            raise ValueError(f'For \'{self.name}\' The indices_shape[0] and update_shape[0] must be equal.')
        return {'shape': shp,
                'dtype': update['dtype'],
                'value': None}
class ResizeNearestNeighbor(PrimitiveWithInfer):
    r"""
    Resize the input tensor by using nearest neighbor algorithm.

    Resize input tensor to given size by using nearest neighbor algorithm. The nearest
    neighbor algorithm selects the value of the nearest point and does not consider the
    values of neighboring points at all, yielding a piecewise-constant interpolant.

    Args:
        size (Union[tuple, list]): The target size. The dimension of size must be 2.
        align_corners (bool): Whether the centers of the 4 corner pixels of the input
                              and output tensors are aligned. Default: False.

    Inputs:
        - **input_x** (Tensor) - The input tensor. The shape of the tensor is :math:`(N, C, H, W)`.

    Outputs:
        Tensor, the shape of the output tensor is :math:`(N, C, NEW\_H, NEW\_W)`,
        i.e. the last two dimensions are replaced by `size`.

    Examples:
        >>> input_tensor = Tensor(np.array([[-0.1, 0.3, 3.6], [0.4, 0.5, -3.2]]), mindspore.float32)
        >>> resize = P.ResizeNearestNeighbor((2, 2))
        >>> output = resize(input_tensor)
    """

    @prim_attr_register
    def __init__(self, size, align_corners=False):
        """Init ResizeNearestNeighbor"""
        validator.check_value_type("size", size, [tuple, list], self.name)
        validator.check_value_type("align_corners", align_corners, [bool], self.name)
        validator.check_integer("length of size", len(size), 2, Rel.EQ, self.name)
        for i, value in enumerate(size):
            validator.check_integer(f'{i}th value of size', value, 0, Rel.GE, self.name)
        self.init_prim_io_names(inputs=['image_in'], outputs=['image_out'])

    def infer_shape(self, x):
        # Output keeps all leading dims and replaces the last two with `size`.
        validator.check('the dimension of input_x', len(x), '', 2, Rel.GE, self.name)
        return tuple(x)[:-2] + tuple(self.size)

    def infer_dtype(self, x):
        # Output dtype is unchanged.
        return x
  1732. class GatherNd(PrimitiveWithInfer):
  1733. """
  1734. Gathers slices from a tensor by indices.
  1735. Using given indices to gather slices from a tensor with a specified shape.
  1736. Inputs:
  1737. - **input_x** (Tensor) - The target tensor to gather values.
  1738. - **indices** (Tensor) - The index tensor.
  1739. Outputs:
  1740. Tensor, has the same type as `input_x` and the shape is indices_shape[:-1] + x_shape[indices_shape[-1]:].
  1741. Examples:
  1742. >>> input_x = Tensor(np.array([[-0.1, 0.3, 3.6], [0.4, 0.5, -3.2]]), mindspore.float32)
  1743. >>> indices = Tensor(np.array([[0, 0], [1, 1]]), mindspore.int32)
  1744. >>> op = P.GatherNd()
  1745. >>> output = op(input_x, indices)
  1746. """
  1747. @prim_attr_register
  1748. def __init__(self):
  1749. """Init GatherNd"""
  1750. self.init_prim_io_names(inputs=['input_x', 'indices'], outputs=['y'])
  1751. def infer_shape(self, x_shape, indices_shape):
  1752. validator.check('the dimension of x', len(x_shape),
  1753. 'the dimension of indices', indices_shape[-1], Rel.GE, self.name)
  1754. return indices_shape[:-1] + x_shape[indices_shape[-1]:]
  1755. def infer_dtype(self, x_dtype, indices_dtype):
  1756. validator.check_subclass("x_dtype", x_dtype, mstype.tensor, self.name)
  1757. validator.check_tensor_type_same({"indices": indices_dtype}, mstype.int_type, self.name)
  1758. return x_dtype
  1759. class TensorScatterUpdate(PrimitiveWithInfer):
  1760. """
  1761. Update tensor value by using input indices and value.
  1762. Using given values to update tensor value, along with the input indices.
  1763. Inputs:
  1764. - **input_x** (Tensor) - The target tensor.
  1765. - **indices** (Tensor) - The index of input tensor whose data type is int32.
  1766. - **update** (Tensor) - The tensor to update the input tensor, has the same type as input,
  1767. and update.shape = indices.shape + input_x.shape[1:].
  1768. Outputs:
  1769. Tensor, has the same shape and type as `input_x`.
  1770. Examples:
  1771. >>> input_x = Tensor(np.array([[-0.1, 0.3, 3.6], [0.4, 0.5, -3.2]]), mindspore.float32)
  1772. >>> indices = Tensor(np.array([[0, 0], [1, 1]]), mindspore.int32)
  1773. >>> update = Tensor(np.array([1.0, 2.2]), mindspore.float32)
  1774. >>> op = P.TensorScatterUpdate()
  1775. >>> output = op(input_x, indices, update)
  1776. """
  1777. @prim_attr_register
  1778. def __init__(self):
  1779. """Init TensorScatterUpdate"""
  1780. self.init_prim_io_names(inputs=['x', 'indices', 'value'], outputs=['y'])
  1781. def infer_shape(self, x_shape, indices_shape, value_shape):
  1782. validator.check('the dimension of x', len(x_shape),
  1783. 'the dimension of indices', indices_shape[-1], Rel.GE)
  1784. if indices_shape[:-1] + x_shape[indices_shape[-1]:] != value_shape:
  1785. raise ValueError("For 'TensorScatterUpdate', input value are not match with input indices.")
  1786. return x_shape
  1787. def infer_dtype(self, x_dtype, indices_dtype, value_dtype):
  1788. validator.check_tensor_type_same({'indices': indices_dtype}, [mstype.int32], self.name)
  1789. args = {"x": x_dtype, "value": value_dtype}
  1790. validator.check_tensor_type_same(args, (mstype.bool_,) + mstype.number_type, self.name)
  1791. return x_dtype
  1792. class ScatterUpdate(PrimitiveWithInfer):
  1793. """
  1794. Update tensor value by using input indices and value.
  1795. Using given values to update tensor value, along with the input indices.
  1796. Args:
  1797. use_locking (bool): Whether protect the assignment by a lock. Default: True.
  1798. Inputs:
  1799. - **input_x** (Parameter) - The target tensor, with data type of Parameter.
  1800. - **indices** (Tensor) - The index of input tensor. With int32 data type.
  1801. - **update** (Tensor) - The tensor to update the input tensor, has the same type as input,
  1802. and update.shape = indices.shape + input_x.shape[1:].
  1803. Outputs:
  1804. Tensor, has the same shape and type as `input_x`.
  1805. Examples:
  1806. >>> np_x = np.array([[-0.1, 0.3, 3.6], [0.4, 0.5, -3.2]])
  1807. >>> input_x = mindspore.Parameter(Tensor(np_x, mindspore.float32), name="x")
  1808. >>> indices = Tensor(np.array([[0, 0], [1, 1]]), mindspore.int32)
  1809. >>> np_update = np.array([[[1.0, 2.2, 1.0], [2.0, 1.2, 1.0]], [[2.0, 2.2, 1.0], [3.0, 1.2, 1.0]]])
  1810. >>> update = Tensor(np_update, mindspore.float32)
  1811. >>> op = P.ScatterUpdate()
  1812. >>> output = op(input_x, indices, update)
  1813. """
  1814. __mindspore_signature__ = (
  1815. ('x', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T),
  1816. ('indices', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T1),
  1817. ('value', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T)
  1818. )
  1819. @prim_attr_register
  1820. def __init__(self, use_locking=True):
  1821. """Init ScatterUpdate"""
  1822. validator.check_value_type('use_locking', use_locking, [bool], self.name)
  1823. self.init_prim_io_names(inputs=['x', 'indices', 'value'], outputs=['y'])
  1824. def infer_shape(self, x_shape, indices_shape, value_shape):
  1825. if indices_shape + x_shape[1:] != value_shape:
  1826. raise ValueError("For 'ScatterUpdate', input value are not match with input indices.")
  1827. return x_shape
  1828. def infer_dtype(self, x_dtype, indices_dtype, value_dtype):
  1829. validator.check_tensor_type_same({'indices': indices_dtype}, [mstype.int32], self.name)
  1830. args = {"x": x_dtype, "value": value_dtype}
  1831. validator.check_tensor_type_same(args, (mstype.bool_,) + mstype.number_type, self.name)
  1832. return x_dtype
  1833. class ScatterNdUpdate(PrimitiveWithInfer):
  1834. """
  1835. Update tensor value by using input indices and value.
  1836. Using given values to update tensor value, along with the input indices.
  1837. Args:
  1838. use_locking (bool): Whether protect the assignment by a lock. Default: True.
  1839. Inputs:
  1840. - **input_x** (Parameter) - The target tensor, with data type of Parameter.
  1841. - **indices** (Tensor) - The index of input tensor, with int32 data type.
  1842. - **update** (Tensor) - The tensor to add to the input tensor, has the same type as input.
  1843. Outputs:
  1844. Tensor, has the same shape and type as `input_x`.
  1845. Examples:
  1846. >>> np_x = np.array([[-0.1, 0.3, 3.6], [0.4, 0.5, -3.2]])
  1847. >>> input_x = mindspore.Parameter(Tensor(np_x, mindspore.float32), name="x")
  1848. >>> indices = Tensor(np.array([[0, 0], [1, 1]]), mindspore.int32)
  1849. >>> update = Tensor(np.array([1.0, 2.2]), mindspore.float32)
  1850. >>> op = P.ScatterNdUpdate()
  1851. >>> output = op(input_x, indices, update)
  1852. """
  1853. __mindspore_signature__ = (
  1854. ('x', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T),
  1855. ('indices', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T1),
  1856. ('value', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T)
  1857. )
  1858. @prim_attr_register
  1859. def __init__(self, use_locking=True):
  1860. """Init ScatterNdUpdate"""
  1861. validator.check_value_type('use_locking', use_locking, [bool], self.name)
  1862. self.init_prim_io_names(inputs=['x', 'indices', 'value'], outputs=['y'])
  1863. def infer_shape(self, x_shape, indices_shape, value_shape):
  1864. validator.check('the dimension of x', len(x_shape),
  1865. 'the dimension of indices', indices_shape[-1], Rel.GE)
  1866. if indices_shape[:-1] + x_shape[indices_shape[-1]:] != value_shape:
  1867. raise ValueError("For 'ScatterNdUpdate', input value are not match with input indices.")
  1868. return x_shape
  1869. def infer_dtype(self, x_dtype, indices_dtype, value_dtype):
  1870. validator.check_tensor_type_same({'indices': indices_dtype}, [mstype.int32], self.name)
  1871. args = {"x": x_dtype, "value": value_dtype}
  1872. validator.check_tensor_type_same(args, (mstype.bool_,) + mstype.number_type, self.name)
  1873. return x_dtype
  1874. def _check_scatter_shape(x_shape, indices_shape, updates_shape, prim_name):
  1875. if updates_shape and updates_shape != indices_shape + x_shape[1:]:
  1876. raise ValueError(f"For '{prim_name}', the shape of updates should be [] or "
  1877. f"updates_shape = indices_shape + x_shape[1:], but got x_shape: {x_shape}, "
  1878. f"indices_shape: {indices_shape}, updates_shape: {updates_shape}.")
  1879. class ScatterMax(PrimitiveWithInfer):
  1880. """
  1881. Update the value of the input tensor through the max operation.
  1882. Using given values to update tensor value through the max operation, along with the input indices.
  1883. This operation outputs the `input_x` after the update is done, which makes it convenient to use the updated value.
  1884. Args:
  1885. use_locking (bool): Whether protect the assignment by a lock. Default: True.
  1886. Inputs:
  1887. - **input_x** (Parameter) - The target parameter.
  1888. - **indices** (Tensor) - The index to do max operation whose data type should be mindspore.int32.
  1889. - **updates** (Tensor) - The tensor doing the maximum operation with `input_x`,
  1890. the data type is same as `input_x`, the shape is `indices_shape + x_shape[1:]`.
  1891. Outputs:
  1892. Parameter, the updated `input_x`.
  1893. Examples:
  1894. >>> input_x = Parameter(Tensor(np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]), mindspore.float32), name="input_x")
  1895. >>> indices = Tensor(np.array([[0, 0], [1, 1]]), mindspore.int32)
  1896. >>> update = Tensor(np.ones([2, 2, 3]) * 88, mindspore.float32)
  1897. >>> scatter_max = P.ScatterMax()
  1898. >>> output = scatter_max(input_x, indices, update)
  1899. [[88.0, 88.0, 88.0], [88.0, 88.0, 88.0]]
  1900. """
  1901. @prim_attr_register
  1902. def __init__(self, use_locking=True):
  1903. """Init ScatterMax"""
  1904. self.init_prim_io_names(inputs=['x', 'indices', 'updates'], outputs=['y'])
  1905. validator.check_value_type('use_locking', use_locking, (bool,), self.name)
  1906. def infer_shape(self, x_shape, indices_shape, updates_shape):
  1907. _check_scatter_shape(x_shape, indices_shape, updates_shape, self.name)
  1908. return x_shape
  1909. def infer_dtype(self, x_dtype, indices_dtype, updates_dtype):
  1910. validator.check_tensor_type_same({'indices': indices_dtype}, (mstype.int32,), self.name)
  1911. args = {"x": x_dtype, "updates": updates_dtype}
  1912. validator.check_tensor_type_same(args, mstype.number_type, self.name)
  1913. return x_dtype
  1914. class ScatterMin(PrimitiveWithInfer):
  1915. """
  1916. Update the value of the input tensor through the min operation.
  1917. Using given values to update tensor value through the min operation, along with the input indices.
  1918. This operation outputs the `input_x` after the update is done, which makes it convenient to use the updated value.
  1919. Args:
  1920. use_locking (bool): Whether protect the assignment by a lock. Default: False.
  1921. Inputs:
  1922. - **input_x** (Parameter) - The target parameter.
  1923. - **indices** (Tensor) - The index to do min operation whose data type should be mindspore.int32.
  1924. - **updates** (Tensor) - The tensor doing the min operation with `input_x`,
  1925. the data type is same as `input_x`, the shape is `indices_shape + x_shape[1:]`.
  1926. Outputs:
  1927. Parameter, the updated `input_x`.
  1928. Examples:
  1929. >>> input_x = Parameter(Tensor(np.array([[0.0, 1.0, 2.0], [0.0, 0.0, 0.0]]), mindspore.float32), name="input_x")
  1930. >>> indices = Tensor(np.array([[0, 0], [1, 1]]), mindspore.int32)
  1931. >>> update = Tensor(np.ones([2, 2, 3]), mindspore.float32)
  1932. >>> scatter_min = P.ScatterMin()
  1933. >>> output = scatter_min(input_x, indices, update)
  1934. [[0.0, 1.0, 1.0], [0.0, 0.0, 0.0]]
  1935. """
  1936. @prim_attr_register
  1937. def __init__(self, use_locking=False):
  1938. """Init ScatterMin"""
  1939. self.init_prim_io_names(inputs=['x', 'indices', 'updates'], outputs=['y'])
  1940. validator.check_value_type('use_locking', use_locking, (bool,), self.name)
  1941. def infer_shape(self, x_shape, indices_shape, updates_shape):
  1942. _check_scatter_shape(x_shape, indices_shape, updates_shape, self.name)
  1943. return x_shape
  1944. def infer_dtype(self, x_dtype, indices_dtype, updates_dtype):
  1945. validator.check_tensor_type_same({'indices': indices_dtype}, (mstype.int32,), self.name)
  1946. args = {"x": x_dtype, "updates": updates_dtype}
  1947. validator.check_tensor_type_same(args, mstype.number_type, self.name)
  1948. return x_dtype
  1949. class ScatterAdd(PrimitiveWithInfer):
  1950. """
  1951. Update the value of the input tensor through the add operation.
  1952. Using given values to update tensor value through the add operation, along with the input indices.
  1953. This operation outputs the `input_x` after the update is done, which makes it convenient to use the updated value.
  1954. Args:
  1955. use_locking (bool): Whether protect the assignment by a lock. Default: False.
  1956. Inputs:
  1957. - **input_x** (Parameter) - The target parameter.
  1958. - **indices** (Tensor) - The index to do add operation whose data type should be mindspore.int32.
  1959. - **updates** (Tensor) - The tensor doing the add operation with `input_x`,
  1960. the data type is same as `input_x`, the shape is `indices_shape + x_shape[1:]`.
  1961. Outputs:
  1962. Parameter, the updated `input_x`.
  1963. Examples:
  1964. >>> input_x = Parameter(Tensor(np.array([[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]), mindspore.float32), name="x")
  1965. >>> indices = Tensor(np.array([[0, 1], [1, 1]]), mindspore.int32)
  1966. >>> updates = Tensor(np.ones([2, 2, 3]), mindspore.float32)
  1967. >>> scatter_add = P.ScatterAdd()
  1968. >>> output = scatter_add(input_x, indices, updates)
  1969. [[1.0, 1.0, 1.0], [3.0, 3.0, 3.0]]
  1970. """
  1971. @prim_attr_register
  1972. def __init__(self, use_locking=False):
  1973. """Init ScatterAdd"""
  1974. validator.check_value_type('use_locking', use_locking, (bool,), self.name)
  1975. def infer_shape(self, x_shape, indices_shape, updates_shape):
  1976. _check_scatter_shape(x_shape, indices_shape, updates_shape, self.name)
  1977. return x_shape
  1978. def infer_dtype(self, x_dtype, indices_dtype, updates_dtype):
  1979. validator.check_tensor_type_same({'indices': indices_dtype}, (mstype.int32,), self.name)
  1980. args = {'x': x_dtype, 'updates': updates_dtype}
  1981. validator.check_tensor_type_same(args, mstype.number_type, self.name)
  1982. return x_dtype
  1983. class ScatterSub(PrimitiveWithInfer):
  1984. """
  1985. Update the value of the input tensor through the sub operation.
  1986. Using given values to update tensor value through the sub operation, along with the input indices.
  1987. This operation outputs the `input_x` after the update is done, which makes it convenient to use the updated value.
  1988. Args:
  1989. use_locking (bool): Whether protect the assignment by a lock. Default: False.
  1990. Inputs:
  1991. - **input_x** (Parameter) - The target parameter.
  1992. - **indices** (Tensor) - The index to do sub operation whose data type should be mindspore.int32.
  1993. - **updates** (Tensor) - The tensor doing the sub operation with `input_x`,
  1994. the data type is same as `input_x`, the shape is `indices_shape + x_shape[1:]`.
  1995. Outputs:
  1996. Parameter, the updated `input_x`.
  1997. Examples:
  1998. >>> input_x = Parameter(Tensor(np.array([[0.0, 0.0, 0.0], [1.0, 1.0, 1.0]]), mindspore.float32), name="x")
  1999. >>> indices = Tensor(np.array([[0, 1]]), mindspore.int32)
  2000. >>> updates = Tensor(np.array([[1.0, 1.0, 1.0], [2.0, 2.0, 2.0]]), mindspore.float32)
  2001. >>> scatter_sub = P.ScatterSub()
  2002. >>> output = scatter_sub(input_x, indices, updates)
  2003. [[-1.0, -1.0, -1.0], [-1.0, -1.0, -1.0]]
  2004. """
  2005. @prim_attr_register
  2006. def __init__(self, use_locking=False):
  2007. """Init ScatterSub"""
  2008. validator.check_value_type('use_locking', use_locking, (bool,), self.name)
  2009. def infer_shape(self, x_shape, indices_shape, updates_shape):
  2010. _check_scatter_shape(x_shape, indices_shape, updates_shape, self.name)
  2011. return x_shape
  2012. def infer_dtype(self, x_dtype, indices_dtype, updates_dtype):
  2013. validator.check_tensor_type_same({'indices': indices_dtype}, (mstype.int32,), self.name)
  2014. args = {'x': x_dtype, 'updates': updates_dtype}
  2015. validator.check_tensor_type_same(args, mstype.number_type, self.name)
  2016. return x_dtype
  2017. class SpaceToDepth(PrimitiveWithInfer):
  2018. r"""
  2019. Rearrange blocks of spatial data into depth.
  2020. The output tensor's `height` dimension is :math:`height / block\_size`.
  2021. The output tensor's `weight` dimension is :math:`weight / block\_size`.
  2022. The depth of output tensor is :math:`block\_size * block\_size * input\_depth`.
  2023. The input tensor's height and width must be divisible by `block_size`.
  2024. The data format is "NCHW".
  2025. Args:
  2026. block_size (int): The block size used to divide spatial data. It must be >= 2.
  2027. Inputs:
  2028. - **x** (Tensor) - The target tensor.
  2029. Outputs:
  2030. Tensor, the same type as `x`.
  2031. Examples:
  2032. >>> x = Tensor(np.random.rand(1,3,2,2), mindspore.float32)
  2033. >>> block_size = 2
  2034. >>> op = P.SpaceToDepth(block_size)
  2035. >>> output = op(x)
  2036. >>> output.asnumpy().shape == (1,12,1,1)
  2037. """
  2038. @prim_attr_register
  2039. def __init__(self, block_size):
  2040. """Init SpaceToDepth"""
  2041. self.init_prim_io_names(inputs=['x'], outputs=['y'])
  2042. validator.check_value_type('block_size', block_size, [int], self.name)
  2043. validator.check('block_size', block_size, '', 2, Rel.GE)
  2044. self.block_size = block_size
  2045. def infer_shape(self, x_shape):
  2046. validator.check('x dimension', len(x_shape), '', 4, Rel.EQ)
  2047. out_shape = copy.deepcopy(x_shape)
  2048. for i in range(2):
  2049. if out_shape[i + 2] % self.block_size != 0:
  2050. raise ValueError(f'For \'{self.name}\' input shape[{i + 2}] {out_shape[i + 2]} should be '
  2051. f'fully divided by block_size {self.block_size}')
  2052. out_shape[i + 2] //= self.block_size
  2053. out_shape[1] *= self.block_size * self.block_size
  2054. return out_shape
  2055. def infer_dtype(self, x_dtype):
  2056. validator.check_subclass("x_dtype", x_dtype, mstype.tensor, self.name)
  2057. return x_dtype
  2058. class DepthToSpace(PrimitiveWithInfer):
  2059. r"""
  2060. Rearrange blocks of depth data into spatial dimensions.
  2061. This is the reverse operation of SpaceToDepth.
  2062. The output tensor's `height` dimension is :math:`height * block\_size`.
  2063. The output tensor's `weight` dimension is :math:`weight * block\_size`.
  2064. The depth of output tensor is :math:`input\_depth / (block\_size * block\_size)`.
  2065. The input tensor's depth must be divisible by `block_size * block_size`.
  2066. The data format is "NCHW".
  2067. Args:
  2068. block_size (int): The block size used to divide depth data. It must be >= 2.
  2069. Inputs:
  2070. - **x** (Tensor) - The target tensor.
  2071. Outputs:
  2072. Tensor, the same type as `x`.
  2073. Examples:
  2074. >>> x = Tensor(np.random.rand(1,12,1,1), mindspore.float32)
  2075. >>> block_size = 2
  2076. >>> op = P.DepthToSpace(block_size)
  2077. >>> output = op(x)
  2078. >>> output.asnumpy().shape == (1,3,2,2)
  2079. """
  2080. @prim_attr_register
  2081. def __init__(self, block_size):
  2082. """Init DepthToSpace"""
  2083. self.init_prim_io_names(inputs=['x'], outputs=['y'])
  2084. validator.check_value_type('block_size', block_size, [int], self.name)
  2085. validator.check('block_size', block_size, '', 2, Rel.GE, self.name)
  2086. self.block_size = block_size
  2087. def infer_shape(self, x_shape):
  2088. validator.check('x dimension', len(x_shape), '', 4, Rel.EQ)
  2089. out_shape = copy.deepcopy(x_shape)
  2090. for i in range(2):
  2091. out_shape[i + 2] *= self.block_size
  2092. validator.check_integer('x_shape[1] % (block_size*block_size)',
  2093. x_shape[1] % (self.block_size * self.block_size),
  2094. 0, Rel.EQ, self.name)
  2095. out_shape[1] //= self.block_size * self.block_size
  2096. return out_shape
  2097. def infer_dtype(self, x_dtype):
  2098. validator.check_subclass("x_dtype", x_dtype, mstype.tensor, self.name)
  2099. return x_dtype
  2100. class SpaceToBatch(PrimitiveWithInfer):
  2101. r"""
  2102. Divide spatial dimensions into blocks and combine the block size with the original batch.
  2103. This operation will divide spatial dimensions (H, W) into blocks with block_size, the output tensor's H and W
  2104. dimension is the corresponding number of blocks after division. The output tensor's batch dimension is the
  2105. product of the original batch and the square of block_size. Prior to division into blocks, the spatial dimensions
  2106. of the input are zero padded according to paddings if necessary.
  2107. Args:
  2108. block_size (int): The block size of dividing block with value >= 2.
  2109. paddings (list): The padding value for H and W dimension, containing 2 sub list, each containing 2 int value.
  2110. All values must be >= 0. paddings[i] specifies the paddings for spatial dimension i, which corresponds to
  2111. input dimension i+2. It is required that input_shape[i+2]+paddings[i][0]+paddings[i][1] is divisible
  2112. by block_size.
  2113. Inputs:
  2114. - **input_x** (Tensor) - The input tensor.
  2115. Outputs:
  2116. Tensor, the output tensor with the same type as input. Assume input shape is :math:`(n, c, h, w)` with
  2117. :math:`block\_size` and :math:`padddings`. The output tensor shape will be :math:`(n', c', h', w')`, where
  2118. :math:`n' = n*(block\_size*block\_size)`
  2119. :math:`c' = c`
  2120. :math:`h' = (h+paddings[0][0]+paddings[0][1])//block\_size`
  2121. :math:`w' = (w+paddings[1][0]+paddings[1][1])//block\_size`
  2122. Examples:
  2123. >>> block_size = 2
  2124. >>> paddings = [[0, 0], [0, 0]]
  2125. >>> space_to_batch = P.SpaceToBatch(block_size, paddings)
  2126. >>> input_x = Tensor(np.array([[[[1, 2], [3, 4]]]]), mindspore.float32)
  2127. >>> space_to_batch(input_x)
  2128. [[[[1.]]], [[[2.]]], [[[3.]]], [[[4.]]]]
  2129. """
  2130. @prim_attr_register
  2131. def __init__(self, block_size, paddings):
  2132. """Init SpaceToBatch"""
  2133. validator.check_value_type('block_size', block_size, [int], self.name)
  2134. validator.check('block_size', block_size, '', 2, Rel.GE, self.name)
  2135. self.block_size = block_size
  2136. validator.check('paddings shape', np.array(paddings).shape, '', (2, 2), Rel.EQ, self.name)
  2137. for elem in itertools.chain(*paddings):
  2138. validator.check_integer('paddings element', elem, 0, Rel.GE, self.name)
  2139. validator.check_value_type('paddings element', elem, [int], self.name)
  2140. self.paddings = paddings
  2141. def infer_dtype(self, x_dtype):
  2142. validator.check_tensor_type_same({'input_x': x_dtype}, mstype.number_type, self.name)
  2143. return x_dtype
  2144. def infer_shape(self, x_shape):
  2145. validator.check_integer('rank of input_x', len(x_shape), 4, Rel.EQ, self.name)
  2146. out_shape = copy.deepcopy(x_shape)
  2147. for i in range(2):
  2148. padded = out_shape[i + 2] + self.paddings[i][0] + self.paddings[i][1]
  2149. if padded % self.block_size != 0:
  2150. raise ValueError(f'For \'{self.name}\' padded[{i}] {padded} should be divisible by '
  2151. f'block_size {self.block_size}')
  2152. out_shape[i + 2] = padded // self.block_size
  2153. out_shape[0] *= self.block_size * self.block_size
  2154. return out_shape
  2155. class BatchToSpace(PrimitiveWithInfer):
  2156. r"""
  2157. Divide batch dimension with blocks and interleaves these blocks back into spatial dimensions.
  2158. This operation will divide batch dimension N into blocks with block_size, the output tensor's N dimension
  2159. is the corresponding number of blocks after division. The output tensor's H, W dimension is product of original H, W
  2160. dimension and block_size with given amount to crop from dimension, respectively.
  2161. Args:
  2162. block_size (int): The block size of dividing block with value >= 2.
  2163. crops (list): The crop value for H and W dimension, containing 2 sub list, each containing 2 int value.
  2164. All values must be >= 0. crops[i] specifies the crop values for spatial dimension i, which corresponds to
  2165. input dimension i+2. It is required that input_shape[i+2]*block_size >= crops[i][0]+crops[i][1].
  2166. Inputs:
  2167. - **input_x** (Tensor) - The input tensor.
  2168. Outputs:
  2169. Tensor, the output tensor with the same type as input. Assume input shape is (n, c, h, w) with block_size
  2170. and crops. The output shape will be (n', c', h', w'), where
  2171. :math:`n' = n//(block\_size*block\_size)`
  2172. :math:`c' = c`
  2173. :math:`h' = h*block\_size-crops[0][0]-crops[0][1]`
  2174. :math:`w' = w*block\_size-crops[1][0]-crops[1][1]`
  2175. Examples:
  2176. >>> block_size = 2
  2177. >>> crops = [[0, 0], [0, 0]]
  2178. >>> op = P.BatchToSpace(block_size, crops)
  2179. >>> input_x = Tensor(np.array([[[[1]]], [[[2]]], [[[3]]], [[[4]]]]), mindspore.float32)
  2180. >>> output = op(input_x)
  2181. [[[[1., 2.], [3., 4.]]]]
  2182. """
  2183. @prim_attr_register
  2184. def __init__(self, block_size, crops):
  2185. """Init BatchToSpace"""
  2186. validator.check_value_type('block_size', block_size, [int], self.name)
  2187. validator.check('block_size', block_size, '', 2, Rel.GE, self.name)
  2188. self.block_size = block_size
  2189. validator.check('crops shape', np.array(crops).shape, '', (2, 2))
  2190. for elem in itertools.chain(*crops):
  2191. validator.check_integer('crops element', elem, 0, Rel.GE, self.name)
  2192. validator.check_value_type('crops element', elem, [int], self.name)
  2193. self.crops = crops
  2194. def infer_dtype(self, x_dtype):
  2195. validator.check_tensor_type_same({'input_x': x_dtype}, mstype.number_type, self.name)
  2196. return x_dtype
  2197. def infer_shape(self, x_shape):
  2198. validator.check('rank of input_x', len(x_shape), '', 4)
  2199. out_shape = copy.deepcopy(x_shape)
  2200. for i in range(2):
  2201. x_block_prod = out_shape[i + 2] * self.block_size
  2202. crops_sum = self.crops[i][0] + self.crops[i][1]
  2203. validator.check("x block shape prod", x_block_prod, 'crops sum', crops_sum, Rel.GT, self.name)
  2204. out_shape[i + 2] = x_block_prod - crops_sum
  2205. block_size_prod = self.block_size * self.block_size
  2206. if out_shape[0] % block_size_prod != 0:
  2207. raise ValueError(f'For \'{self.name}\' input_x dimension 0 {out_shape[0]} should be divisible by '
  2208. f'block_size_prod {block_size_prod}')
  2209. out_shape[0] = out_shape[0] // block_size_prod
  2210. return out_shape
  2211. class SpaceToBatchND(PrimitiveWithInfer):
  2212. r"""
  2213. Divide spatial dimensions into blocks and combine the block size with the original batch.
  2214. This operation will divide spatial dimensions (H, W) into blocks with block_shape, the output tensor's H and W
  2215. dimension is the corresponding number of blocks after division. The output tensor's batch dimension is the
  2216. product of the original batch and the product of block_shape. Prior to division into blocks, the spatial dimensions
  2217. of the input are zero padded according to paddings if necessary.
  2218. Args:
  2219. block_shape (Union[list(int), tuple(int)]): The block shape of dividing block with all value >= 1.
  2220. The length of block_shape is M correspoding to the number of spatial dimensions.
  2221. paddings (list): The padding value for H and W dimension, containing M sub list, each containing 2 int value.
  2222. All values must be >= 0. paddings[i] specifies the paddings for spatial dimension i, which corresponds to
  2223. input dimension i+2. It is required that input_shape[i+2]+paddings[i][0]+paddings[i][1] is divisible
  2224. by block_shape[i].
  2225. Inputs:
  2226. - **input_x** (Tensor) - The input tensor.
  2227. Outputs:
  2228. Tensor, the output tensor with the same type as input. Assume input shape is :math:`(n, c, h, w)` with
  2229. :math:`block\_shape` and :math:`padddings`. The output tensor shape will be :math:`(n', c', h', w')`, where
  2230. :math:`n' = n*(block\_shape[0]*block\_shape[1])`
  2231. :math:`c' = c`
  2232. :math:`h' = (h+paddings[0][0]+paddings[0][1])//block\_shape[0]`
  2233. :math:`w' = (w+paddings[1][0]+paddings[1][1])//block\_shape[1]`
  2234. Examples:
  2235. >>> block_shape = [2, 2]
  2236. >>> paddings = [[0, 0], [0, 0]]
  2237. >>> space_to_batch_nd = P.SpaceToBatchND(block_shape, paddings)
  2238. >>> input_x = Tensor(np.array([[[[1, 2], [3, 4]]]]), mindspore.float32)
  2239. >>> space_to_batch_nd(input_x)
  2240. [[[[1.]]], [[[2.]]], [[[3.]]], [[[4.]]]]
  2241. """
  2242. @prim_attr_register
  2243. def __init__(self, block_shape, paddings):
  2244. """Init SpaceToBatchND"""
  2245. validator.check_value_type('block_shape type', block_shape, [list, tuple], self.name)
  2246. validator.check('block_shape shape', len(np.array(block_shape).shape), '', 1, Rel.EQ, self.name)
  2247. block_rank = len(block_shape)
  2248. for elem in block_shape:
  2249. validator.check('block_shape element', elem, '', 1, Rel.GE, self.name)
  2250. validator.check_value_type('block_shape element', elem, [int], self.name)
  2251. self.block_shape = block_shape
  2252. validator.check('paddings shape', np.array(paddings).shape, '', (block_rank, 2), Rel.EQ, self.name)
  2253. for elem in itertools.chain(*paddings):
  2254. validator.check_integer('paddings element', elem, 0, Rel.GE, self.name)
  2255. validator.check_value_type('paddings element', elem, [int], self.name)
  2256. self.paddings = paddings
  2257. def infer_dtype(self, x_dtype):
  2258. validator.check_tensor_type_same({'input_x': x_dtype}, mstype.number_type, self.name)
  2259. return x_dtype
  2260. def infer_shape(self, x_shape):
  2261. x_rank = len(x_shape)
  2262. validator.check_integer('x_shape rank', x_rank, 4, Rel.EQ, self.name)
  2263. out_shape = copy.deepcopy(x_shape)
  2264. block_shape_prod = 1
  2265. offset = 2
  2266. if x_rank < 4:
  2267. offset = 1
  2268. for i in range(len(self.block_shape)):
  2269. padded = out_shape[i + offset] + self.paddings[i][0] + \
  2270. self.paddings[i][1]
  2271. if padded % self.block_shape[i] != 0:
  2272. raise ValueError(f'For \'{self.name}\' padded[{i}] {padded} should be divisible by '
  2273. f'block_shape[{i}] {self.block_shape[i]}')
  2274. out_shape[i + offset] = padded // self.block_shape[i]
  2275. block_shape_prod = block_shape_prod * self.block_shape[i]
  2276. out_shape[0] *= block_shape_prod
  2277. return out_shape
  2278. class BatchToSpaceND(PrimitiveWithInfer):
  2279. r"""
  2280. Divide batch dimension with blocks and interleaves these blocks back into spatial dimensions.
  2281. This operation will divide batch dimension N into blocks with block_shape, the output tensor's N dimension
  2282. is the corresponding number of blocks after division. The output tensor's H, W dimension is product of original H, W
  2283. dimension and block_shape with given amount to crop from dimension, respectively.
  2284. Args:
  2285. block_shape (Union[list(int), tuple(int)]): The block shape of dividing block with all value >= 1.
  2286. The length of block_shape is M correspoding to the number of spatial dimensions.
  2287. crops (list): The crop value for H and W dimension, containing 2 sub list, each containing 2 int value.
  2288. All values must be >= 0. crops[i] specifies the crop values for spatial dimension i, which corresponds to
  2289. input dimension i+2. It is required that input_shape[i+2]*block_shape[i] > crops[i][0]+crops[i][1].
  2290. Inputs:
  2291. - **input_x** (Tensor) - The input tensor.
  2292. Outputs:
  2293. Tensor, the output tensor with the same type as input. Assume input shape is (n, c, h, w) with block_shape
  2294. and crops. The output shape will be (n', c', h', w'), where
  2295. :math:`n' = n//(block\_shape[0]*block\_shape[1])`
  2296. :math:`c' = c`
  2297. :math:`h' = h*block\_shape[0]-crops[0][0]-crops[0][1]`
  2298. :math:`w' = w*block\_shape[1]-crops[1][0]-crops[1][1]`
  2299. Examples:
  2300. >>> block_shape = [2, 2]
  2301. >>> crops = [[0, 0], [0, 0]]
  2302. >>> batch_to_space_nd = P.BatchToSpaceND(block_shape, crops)
  2303. >>> input_x = Tensor(np.array([[[[1]]], [[[2]]], [[[3]]], [[[4]]]]), mindspore.float32)
  2304. >>> output = batch_to_space_nd(input_x)
  2305. [[[[1., 2.], [3., 4.]]]]
  2306. """
  2307. @prim_attr_register
  2308. def __init__(self, block_shape, crops):
  2309. """Init BatchToSpaceND"""
  2310. validator.check_value_type('block_shape type', block_shape, [list, tuple], self.name)
  2311. validator.check('block_shape shape', len(np.array(block_shape).shape), '', 1, Rel.EQ, self.name)
  2312. block_rank = len(block_shape)
  2313. for elem in block_shape:
  2314. validator.check('block_shape element', elem, '', 1, Rel.GE, self.name)
  2315. validator.check_value_type('block_shape element', elem, [int], self.name)
  2316. self.block_shape = block_shape
  2317. validator.check('crops shape', np.array(crops).shape, '', (block_rank, 2), Rel.EQ, self.name)
  2318. for elem in itertools.chain(*crops):
  2319. validator.check_integer('crops element', elem, 0, Rel.GE, self.name)
  2320. validator.check_value_type('crops element', elem, [int], self.name)
  2321. self.crops = crops
  2322. def infer_dtype(self, x_dtype):
  2323. validator.check_tensor_type_same({'input_x': x_dtype}, mstype.number_type, self.name)
  2324. return x_dtype
  2325. def infer_shape(self, x_shape):
  2326. x_rank = len(x_shape)
  2327. validator.check_integer('x_shape rank', x_rank, 4, Rel.EQ, self.name)
  2328. out_shape = copy.deepcopy(x_shape)
  2329. block_shape_prod = 1
  2330. offset = 2
  2331. if x_rank < 4:
  2332. offset = 1
  2333. for i in range(len(self.block_shape)):
  2334. block_shape_prod = block_shape_prod * self.block_shape[i]
  2335. x_block_prod = out_shape[i + offset] * self.block_shape[i]
  2336. crops_sum = self.crops[i][0] + self.crops[i][1]
  2337. validator.check("x block shape prod", x_block_prod, 'crops sum', crops_sum, Rel.GT, self.name)
  2338. out_shape[i + offset] = x_block_prod - crops_sum
  2339. if out_shape[0] % block_shape_prod != 0:
  2340. raise ValueError(f'For \'{self.name}\' input_x dimension 0 {out_shape[0]} should be divisible by '
  2341. f'block_shape_prod {block_shape_prod}')
  2342. out_shape[0] = out_shape[0] // block_shape_prod
  2343. return out_shape
  2344. class BroadcastTo(PrimitiveWithInfer):
  2345. """
  2346. Broadcasts input tensor to a given shape.
  2347. Input shape can be broadcast to target shape if for each dimension pair they are either equal or input is one.
  2348. When input shape is broadcast to target shape, it starts with the trailing dimensions.
  2349. Args:
  2350. shape (tuple): The target shape to broadcast.
  2351. Inputs:
  2352. - **input_x** (Tensor) - The input tensor.
  2353. Outputs:
  2354. Tensor, with the given `shape` and the same data type as `input_x`.
  2355. Examples:
  2356. >>> shape = (2, 3)
  2357. >>> input_x = Tensor(np.array([1, 2, 3]).astype(np.float32))
  2358. >>> broadcast_to = P.BroadcastTo(shape)
  2359. >>> broadcast_to(input_x)
  2360. [[1.0, 2.0, 3.0], [1.0, 2.0, 3.0]]
  2361. """
  2362. @prim_attr_register
  2363. def __init__(self, shape):
  2364. """Init BroadcastTo"""
  2365. validator.check_value_type("shape", shape, (tuple), self.name)
  2366. validator.check("shape length", len(shape), "", 0, Rel.GT, self.name)
  2367. for i in shape:
  2368. validator.check_integer("shape element", i, 0, Rel.GT, self.name)
  2369. self.shape = shape
  2370. def infer_shape(self, x_shape):
  2371. validator.check("input_x shape length", len(x_shape), "target shape", len(self.shape), Rel.LE, self.name)
  2372. reversed_x_shape = tuple(reversed(x_shape))
  2373. reversed_target = tuple(reversed(self.shape))
  2374. for i, v in enumerate(reversed_x_shape):
  2375. if v not in (reversed_target[i], 1):
  2376. raise ValueError(f"Not supported shapes for broadcast, "
  2377. f"x_shape: {tuple(x_shape)}, target shape {self.shape}.")
  2378. return self.shape
  2379. def infer_dtype(self, x_dtype):
  2380. validator.check_subclass("input_x", x_dtype, mstype.tensor, self.name)
  2381. return x_dtype
  2382. class InplaceUpdate(PrimitiveWithInfer):
  2383. r"""
  2384. Updates specified rows with values in `v`.
  2385. Args:
  2386. indices (Union[int, tuple]): Indices into the left-most dimension of `x`.
  2387. Inputs:
  2388. - **x** (Tensor) - A tensor which to be inplace updated. It can be of the following data types:
  2389. float32, float16, int32.
  2390. - **v** (Tensor) - A tensor of the same type as `x`. Same dimension size as `x` except
  2391. the first dimension, which must be the same as the size of `indices`.
  2392. Outputs:
  2393. Tensor, with the same type and shape as the input `x`.
  2394. Examples:
  2395. >>> x = Tensor(np.arange(24).reshape(3, 4, 2), mindspore.float32)
  2396. >>> v = Tensor(np.arange(-8, 8).reshape(2, 4, 2), mindspore.float32)
  2397. >>> inplace_update = P.InplaceUpdate((0, 2))
  2398. >>> result = inplace_update(x, v)
  2399. [[[-8. -7.]
  2400. [-6. -5.]
  2401. [-4. -3.]
  2402. [-2. -1.]]
  2403. [[ 8. 9.]
  2404. [10. 11.]
  2405. [12. 13.]
  2406. [14. 15.]]
  2407. [[ 0. 1.]
  2408. [ 2. 3.]
  2409. [ 4. 5.]
  2410. [ 6. 7.]]]
  2411. """
  2412. @prim_attr_register
  2413. def __init__(self, indices):
  2414. """Init InplaceUpdate"""
  2415. self.init_prim_io_names(inputs=['x', 'v'], outputs=['y'])
  2416. self.indices = indices
  2417. validator.check_value_type("indices", indices, [int, tuple], self.name)
  2418. if isinstance(indices, int):
  2419. self.indices = (indices,)
  2420. for item in self.indices:
  2421. validator.check_value_type("item of indices", item, [int], self.name)
  2422. def infer_dtype(self, x_dtype, v_dtype):
  2423. args = {'x': x_dtype, 'v': v_dtype}
  2424. valid_type = [mstype.int32, mstype.float16, mstype.float32]
  2425. validator.check_tensor_type_same(args, valid_type, self.name)
  2426. return x_dtype
  2427. def infer_shape(self, x_shape, v_shape):
  2428. validator.check("x", len(x_shape), "v", len(v_shape), Rel.EQ, self.name)
  2429. validator.check("size of indices", len(self.indices), "v's first dimension", v_shape[0],
  2430. Rel.EQ, self.name)
  2431. for i in self.indices:
  2432. if i < 0 or i >= x_shape[0]:
  2433. raise ValueError(f'The value of indices must be in [0, {x_shape[0]}), but got {i}.')
  2434. x_rank = len(x_shape)
  2435. for idx in range(x_rank)[1:]:
  2436. validator.check("x dim %d" % idx, x_shape[idx], 'v dim %d' % idx, v_shape[idx], Rel.EQ, self.name)
  2437. return x_shape
  2438. class ReverseSequence(PrimitiveWithInfer):
  2439. """
  2440. Reverses variable length slices.
  2441. Args:
  2442. seq_dim (int): The dimension along which reversal is performed. Required.
  2443. batch_dim (int): The input is sliced along this dimmension. Default: 0.
  2444. Inputs:
  2445. - **x** (Tensor) - The input to reverse, support all number types including bool.
  2446. - **seq_lengths** (Tensor) - Must be 1-D vector with types: int32, int64.
  2447. Outputs:
  2448. Reversed tensor with the same shape and data type as input.
  2449. Examples:
  2450. >>> x = Tensor(np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]), mindspore.float32)
  2451. >>> seq_lengths = Tensor(np.array([1, 2, 3]))
  2452. >>> reverse_sequence = P.ReverseSequence(seq_dim=1)
  2453. >>> output = reverse_sequence(x, seq_lengths)
  2454. [[1 2 3]
  2455. [5 4 6]
  2456. [9 8 7]]
  2457. """
  2458. @prim_attr_register
  2459. def __init__(self, seq_dim, batch_dim=0):
  2460. """init ReverseSequence"""
  2461. self.init_prim_io_names(inputs=['x', 'seq_lengths'], outputs=['y'])
  2462. validator.check_value_type("seq_dim", seq_dim, [int], self.name)
  2463. self.seq_dim_ = seq_dim
  2464. validator.check_value_type("batch_dim", batch_dim, [int], self.name)
  2465. self.batch_dim_ = batch_dim
  2466. def infer_shape(self, x, seq_lengths):
  2467. validator.check("seq_dim", self.seq_dim_, "x rank", len(x), Rel.LE, self.name)
  2468. validator.check("batch_dim", self.batch_dim_, "x rank", len(x), Rel.LE, self.name)
  2469. validator.check("batch_dim", self.batch_dim_, "seq_dim", self.seq_dim_, Rel.NE, self.name)
  2470. validator.check("seq_lengths rank", len(seq_lengths), "expected", 1, Rel.EQ, self.name)
  2471. validator.check("seq_lengths vector size", seq_lengths[0],
  2472. "input size along batch_dim", x[self.batch_dim_], Rel.EQ, self.name)
  2473. return x
  2474. def infer_dtype(self, x, seq_lengths):
  2475. validator.check_tensor_type_same({"x_dtype": x}, mstype.number_type + (mstype.bool_,), self.name)
  2476. validator.check_tensor_type_same({"seq_lengths_dtype": seq_lengths}, [mstype.int32, mstype.int64], self.name)
  2477. return x