You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

array_ops.py 166 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
6 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
6 years ago
6 years ago
5 years ago
6 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
6 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
71778177917801781178217831784178517861787178817891790179117921793179417951796179717981799180018011802180318041805180618071808180918101811181218131814181518161817181818191820182118221823182418251826182718281829183018311832183318341835183618371838183918401841184218431844184518461847184818491850185118521853185418551856185718581859186018611862186318641865186618671868186918701871187218731874187518761877187818791880188118821883188418851886188718881889189018911892189318941895189618971898189919001901190219031904190519061907190819091910191119121913191419151916191719181919192019211922192319241925192619271928192919301931193219331934193519361937193819391940194119421943194419451946194719481949195019511952195319541955195619571958195919601961196219631964196519661967196819691970197119721973197419751976197719781979198019811982198319841985198619871988198919901991199219931994199519961997199819992000200120022003200420052006200720082009201020112012201320142015201620172018201920202021202220232024202520262027202820292030203120322033203420352036203720382039204020412042204320442045204620472048204920502051205220532054205520562057205820592060206120622063206420652066206720682069207020712072207320742075207620772078207920802081208220832084208520862087208820892090209120922093209420952096209720982099210021012102210321042105210621072108210921102111211221132114211521162117211821192120212121222123212421252126212721282129213021312132213321342135213621372138213921402141214221432144214521462147214821492150215121522153215421552156215721582159216021612162216321642165216621672168216921702171217221732174217521762177217821792180218121822183218421852186218721882189219021912192219321942195219621972198219922002201220222032204220522062207220822092210221122122213221422152216221722182219222022212222222322242225222622272228222922302231223222332234223522362237223822392240224122422243224422452246224722482249225022512252225322542255225622572258225922602261226222632264226522662267226822692270227122722273227422752276227
72278227922802281228222832284228522862287228822892290229122922293229422952296229722982299230023012302230323042305230623072308230923102311231223132314231523162317231823192320232123222323232423252326232723282329233023312332233323342335233623372338233923402341234223432344234523462347234823492350235123522353235423552356235723582359236023612362236323642365236623672368236923702371237223732374237523762377237823792380238123822383238423852386238723882389239023912392239323942395239623972398239924002401240224032404240524062407240824092410241124122413241424152416241724182419242024212422242324242425242624272428242924302431243224332434243524362437243824392440244124422443244424452446244724482449245024512452245324542455245624572458245924602461246224632464246524662467246824692470247124722473247424752476247724782479248024812482248324842485248624872488248924902491249224932494249524962497249824992500250125022503250425052506250725082509251025112512251325142515251625172518251925202521252225232524252525262527252825292530253125322533253425352536253725382539254025412542254325442545254625472548254925502551255225532554255525562557255825592560256125622563256425652566256725682569257025712572257325742575257625772578257925802581258225832584258525862587258825892590259125922593259425952596259725982599260026012602260326042605260626072608260926102611261226132614261526162617261826192620262126222623262426252626262726282629263026312632263326342635263626372638263926402641264226432644264526462647264826492650265126522653265426552656265726582659266026612662266326642665266626672668266926702671267226732674267526762677267826792680268126822683268426852686268726882689269026912692269326942695269626972698269927002701270227032704270527062707270827092710271127122713271427152716271727182719272027212722272327242725272627272728272927302731273227332734273527362737273827392740274127422743274427452746274727482749275027512752275327542755275627572758275927602761276227632764276527662767276827692770277127722773277427752776277
72778277927802781278227832784278527862787278827892790279127922793279427952796279727982799280028012802280328042805280628072808280928102811281228132814281528162817281828192820282128222823282428252826282728282829283028312832283328342835283628372838283928402841284228432844284528462847284828492850285128522853285428552856285728582859286028612862286328642865286628672868286928702871287228732874287528762877287828792880288128822883288428852886288728882889289028912892289328942895289628972898289929002901290229032904290529062907290829092910291129122913291429152916291729182919292029212922292329242925292629272928292929302931293229332934293529362937293829392940294129422943294429452946294729482949295029512952295329542955295629572958295929602961296229632964296529662967296829692970297129722973297429752976297729782979298029812982298329842985298629872988298929902991299229932994299529962997299829993000300130023003300430053006300730083009301030113012301330143015301630173018301930203021302230233024302530263027302830293030303130323033303430353036303730383039304030413042304330443045304630473048304930503051305230533054305530563057305830593060306130623063306430653066306730683069307030713072307330743075307630773078307930803081308230833084308530863087308830893090309130923093309430953096309730983099310031013102310331043105310631073108310931103111311231133114311531163117311831193120312131223123312431253126312731283129313031313132313331343135313631373138313931403141314231433144314531463147314831493150315131523153315431553156315731583159316031613162316331643165316631673168316931703171317231733174317531763177317831793180318131823183318431853186318731883189319031913192319331943195319631973198319932003201320232033204320532063207320832093210321132123213321432153216321732183219322032213222322332243225322632273228322932303231323232333234323532363237323832393240324132423243324432453246324732483249325032513252325332543255325632573258325932603261326232633264326532663267326832693270327132723273327432753276327
73278327932803281328232833284328532863287328832893290329132923293329432953296329732983299330033013302330333043305330633073308330933103311331233133314331533163317331833193320332133223323332433253326332733283329333033313332333333343335333633373338333933403341334233433344334533463347334833493350335133523353335433553356335733583359336033613362336333643365336633673368336933703371337233733374337533763377337833793380338133823383338433853386338733883389339033913392339333943395339633973398339934003401340234033404340534063407340834093410341134123413341434153416341734183419342034213422342334243425342634273428342934303431343234333434343534363437343834393440344134423443344434453446344734483449345034513452345334543455345634573458345934603461346234633464346534663467346834693470347134723473347434753476347734783479348034813482348334843485348634873488348934903491349234933494349534963497349834993500350135023503350435053506350735083509351035113512351335143515351635173518351935203521352235233524352535263527352835293530353135323533353435353536353735383539354035413542354335443545354635473548354935503551355235533554355535563557355835593560356135623563356435653566356735683569357035713572357335743575357635773578357935803581358235833584358535863587358835893590359135923593359435953596359735983599360036013602360336043605360636073608360936103611361236133614361536163617361836193620362136223623362436253626362736283629363036313632363336343635363636373638363936403641364236433644364536463647364836493650365136523653365436553656365736583659366036613662366336643665366636673668366936703671367236733674367536763677367836793680368136823683368436853686368736883689369036913692369336943695369636973698369937003701370237033704370537063707370837093710371137123713371437153716371737183719372037213722372337243725372637273728372937303731373237333734373537363737373837393740374137423743374437453746374737483749375037513752375337543755375637573758375937603761376237633764376537663767376837693770377137723773377437753776377
7377837793780378137823783378437853786378737883789379037913792379337943795379637973798379938003801380238033804380538063807380838093810381138123813381438153816381738183819382038213822382338243825382638273828382938303831383238333834383538363837383838393840384138423843384438453846384738483849385038513852385338543855385638573858385938603861386238633864386538663867386838693870387138723873387438753876387738783879388038813882388338843885388638873888388938903891389238933894389538963897389838993900390139023903390439053906390739083909391039113912391339143915391639173918391939203921392239233924392539263927392839293930393139323933393439353936393739383939394039413942394339443945394639473948394939503951395239533954395539563957395839593960396139623963396439653966396739683969397039713972397339743975397639773978397939803981398239833984398539863987398839893990399139923993399439953996399739983999400040014002400340044005400640074008400940104011401240134014401540164017401840194020402140224023402440254026402740284029403040314032403340344035403640374038403940404041404240434044404540464047404840494050405140524053405440554056405740584059406040614062406340644065406640674068406940704071407240734074407540764077
  1. # coding: utf-8
  2. # Copyright 2020 Huawei Technologies Co., Ltd
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. # ============================================================================
  16. """Operators for array."""
  17. import copy
  18. import functools
  19. import itertools
  20. import numbers
  21. import numpy as np
  22. from .._utils import get_concat_offset
  23. from ..operations.math_ops import _infer_shape_reduce
  24. from ..primitive import Primitive, PrimitiveWithInfer, PrimitiveWithCheck, prim_attr_register, _run_op
  25. from .. import signature as sig
  26. from ..._checkparam import Rel
  27. from ..._checkparam import Validator as validator
  28. from ...common import dtype as mstype
  29. from ...common.parameter import Parameter
  30. from ...common.tensor import Tensor
  class _ScatterOp(PrimitiveWithInfer):
      """
      Base class shared by the Scatter* primitives.

      The primitive takes three inputs, `x` (declared RW_WRITE in the signature),
      `indices` and `updates`, and infers an output with the shape and dtype of `x`.
      Subclasses override `_check_scatter_shape` to apply a different shape rule.
      """
      __mindspore_signature__ = (
          sig.make_sig('x', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T),
          sig.make_sig('indices', dtype=sig.sig_dtype.T1),
          sig.make_sig('updates', dtype=sig.sig_dtype.T)
      )

      def _check_scatter_shape(self, x_shape, indices_shape, updates_shape, prim_name):
          """Require updates_shape == indices_shape + x_shape[1:].

          An empty (falsy) `updates_shape` skips the check entirely.

          Raises:
              ValueError: if the shapes do not satisfy the relation above.
          """
          if not updates_shape:
              return
          expected_shape = indices_shape + x_shape[1:]
          if updates_shape == expected_shape:
              return
          raise ValueError(f"For '{prim_name}', "
                           f"updates_shape = indices_shape + x_shape[1:], but got x_shape: {x_shape}, "
                           f"indices_shape: {indices_shape}, updates_shape: {updates_shape}.")

      @prim_attr_register
      def __init__(self, use_locking=False):
          """Initialize _ScatterOp"""
          validator.check_value_type('use_locking', use_locking, [bool], self.name)
          self.init_prim_io_names(inputs=['x', 'indices', 'updates'], outputs=['y'])

      def infer_shape(self, x_shape, indices_shape, updates_shape):
          # The output shape is always the shape of `x`.
          self._check_scatter_shape(x_shape, indices_shape, updates_shape, self.name)
          return x_shape

      def infer_dtype(self, x_dtype, indices_dtype, updates_dtype):
          # `indices` must be int32; `x` and `updates` are checked together against the numeric types.
          validator.check_tensor_type_same({'indices': indices_dtype}, [mstype.int32], self.name)
          value_dtypes = {"x": x_dtype, "updates": updates_dtype}
          validator.check_tensor_type_same(value_dtypes, mstype.number_type, self.name)
          return x_dtype
  class _ScatterNdOp(_ScatterOp):
      """
      Defines _ScatterNd operators.

      Differs from `_ScatterOp` only in the shape rule: the last entry of
      `indices_shape` is the number of leading `x` dimensions each index addresses, so
      updates_shape must equal indices_shape[:-1] + x_shape[indices_shape[-1]:].
      """

      def _check_scatter_shape(self, x_shape, indices_shape, updates_shape, prim_name):
          """Validate the ScatterNd shape relation between `x`, `indices` and `updates`.

          Args:
              x_shape (Sequence[int]): shape of the tensor being updated.
              indices_shape (Sequence[int]): shape of `indices`; must have rank >= 1.
              updates_shape (Sequence[int]): shape of `updates`.
              prim_name (str): primitive name used in error messages.

          Raises:
              ValueError: if `indices_shape` is empty, if it does not match
                  indices_shape[:-1] + x_shape[indices_shape[-1]:]. The rank check
                  (len(x_shape) >= indices_shape[-1]) is raised by `validator.check`.
          """
          if not indices_shape:
              # Without this guard indices_shape[-1] would raise a bare IndexError.
              raise ValueError(f"For '{prim_name}', the rank of indices must be at least 1, "
                               f"but got indices_shape: {indices_shape}.")
          index_depth = indices_shape[-1]
          # Pass prim_name so the error names the failing primitive, like the other checks in this file.
          validator.check('the dimension of x', len(x_shape),
                          'the dimension of indices', index_depth, Rel.GE, prim_name)
          # Normalize to lists: a list never compares equal to a tuple, and list + tuple raises,
          # so mixed shape containers would otherwise yield spurious failures.
          expected_shape = list(indices_shape[:-1]) + list(x_shape[index_depth:])
          if expected_shape != list(updates_shape):
              raise ValueError(f"For '{prim_name}', updates_shape = "
                               f"indices_shape[:-1] + x_shape[indices_shape[-1]:], but got x_shape: {x_shape}, "
                               f"indices_shape: {indices_shape}, updates_shape: {updates_shape}.")
  def _check_infer_attr_reduce(axis, keep_dims, prim_name):
      """Validate the `axis` / `keep_dims` attributes (presumably of a reduce-style primitive).

      `keep_dims` must be a bool. `axis` must be an int or a tuple whose every
      element is an int. The checks run in that order, and the validator raises
      on the first failure.
      """
      validator.check_value_type('keep_dims', keep_dims, [bool], prim_name)
      validator.check_value_type('axis', axis, [int, tuple], prim_name)
      if not isinstance(axis, tuple):
          return
      for position, item in enumerate(axis):
          validator.check_value_type('axis[%d]' % position, item, [int], prim_name)
  class ExpandDims(PrimitiveWithInfer):
      """
      Adds an additional dimension at the given axis.

      Note:
          If the specified axis is a negative number, the index is counted
          backward from the end and starts at 1.

      Raises:
          ValueError: If axis is not an integer or not in the valid range.

      Inputs:
          - **input_x** (Tensor) - The shape of tensor is :math:`(x_1, x_2, ..., x_R)`.
          - **axis** (int) - Specifies the dimension index at which to expand
            the shape of `input_x`. The value of axis must be in the range
            `[-input_x.dim()-1, input_x.dim()]`. Only constant value is allowed.

      Outputs:
          Tensor, the shape of tensor is :math:`(1, x_1, x_2, ..., x_R)` if the
          value of `axis` is 0.

      Examples:
          >>> input_tensor = Tensor(np.array([[2, 2], [2, 2]]), mindspore.float32)
          >>> expand_dims = P.ExpandDims()
          >>> output = expand_dims(input_tensor, 0)
          [[[2.0, 2.0],
            [2.0, 2.0]]]
      """

      @prim_attr_register
      def __init__(self):
          """Initialize ExpandDims"""
          self.init_prim_io_names(inputs=['x', 'axis'], outputs=['output'])

      def __infer__(self, x, axis):
          """Infer the output shape/dtype, and the value when `x` is constant.

          Returns:
              dict: 'shape' (list with a 1 inserted at `axis`), 'dtype' (same as `x`),
              'value' (expanded Tensor, or None), plus 'min_shape'/'max_shape' when
              the input provides both.
          """
          validator.check_subclass("input_x", x['dtype'], mstype.tensor, self.name)
          x_shape = list(x['shape'])
          axis_v = axis['value']
          rank = len(x_shape)
          validator.check_int_range(axis_v, -rank - 1, rank, Rel.INC_BOTH, 'axis', self.name)
          value = None
          if x['value'] is not None:
              value = x['value'].asnumpy()
              value = np.expand_dims(value, axis_v)
              value = Tensor(value)
          if axis_v < 0:
              # Map a negative axis into [0, rank] for list.insert.
              axis_v = rank + 1 + axis_v
          x_shape.insert(axis_v, 1)
          out = {'shape': x_shape,
                 'dtype': x['dtype'],
                 'value': value}
          if 'min_shape' in x and 'max_shape' in x:
              # Copy before inserting so the input's own min/max shape lists are not
              # mutated in place (this also accepts tuple-valued shapes).
              out['min_shape'] = list(x['min_shape'])
              out['min_shape'].insert(axis_v, 1)
              out['max_shape'] = list(x['max_shape'])
              out['max_shape'].insert(axis_v, 1)
          return out
  class DType(PrimitiveWithInfer):
      """
      Returns the data type of input tensor as mindspore.dtype.

      Inputs:
          - **input_x** (Tensor) - The shape of tensor is :math:`(x_1, x_2, ..., x_R)`.

      Outputs:
          mindspore.dtype, the data type of a tensor.

      Examples:
          >>> input_tensor = Tensor(np.array([[2, 2], [2, 2]]), mindspore.float32)
          >>> output = P.DType()(input_tensor)
      """

      @prim_attr_register
      def __init__(self):
          """Initialize DType"""

      def __infer__(self, x):
          # Only tensors are accepted. The inferred value is the tensor's element type.
          validator.check_subclass("input_x", x['dtype'], mstype.tensor, self.name)
          element_dtype = x['dtype'].element_type()
          return {'shape': (), 'dtype': mstype.type_type, 'value': element_dtype}
  class SameTypeShape(PrimitiveWithInfer):
      """
      Checks whether data type and shape of two tensors are the same.

      Raises:
          TypeError: If the data types of two tensors are not the same.
          ValueError: If the shapes of two tensors are not the same.

      Inputs:
          - **input_x** (Tensor) - The shape of tensor is :math:`(x_1, x_2, ..., x_R)`.
          - **input_y** (Tensor) - The shape of tensor is :math:`(x_1, x_2, ..., x_S)`.

      Outputs:
          Tensor, the shape of tensor is :math:`(x_1, x_2, ..., x_R)`,
          if data type and shape of `input_x` and `input_y` are the same.

      Examples:
          >>> input_x = Tensor(np.array([[2, 2], [2, 2]]), mindspore.float32)
          >>> input_y = Tensor(np.array([[2, 2], [2, 2]]), mindspore.float32)
          >>> out = P.SameTypeShape()(input_x, input_y)
          [[2. 2.]
           [2. 2.]]
      """

      @prim_attr_register
      def __init__(self):
          """Initialize SameTypeShape"""

      def __call__(self, x, y):
          """Eager (PyNative) path: validate both tensors, then return `x` unchanged."""
          for arg_name, arg in (('x', x), ('y', y)):
              validator.check_value_type(arg_name, arg, Tensor, self.name)
          validator.check('x dtype', x.dtype, 'y dtype', y.dtype, Rel.EQ, self.name, TypeError)
          validator.check('x shape', x.shape, 'y shape', y.shape, Rel.EQ, self.name)
          return x

      def __infer__(self, x, y):
          # Same checks as __call__, but on the abstract dicts; the output abstract is `x`'s.
          for arg_name, arg in (('x', x), ('y', y)):
              validator.check_subclass(arg_name, arg['dtype'], mstype.tensor, self.name)
          validator.check('x dtype', x['dtype'], 'y dtype', y['dtype'], Rel.EQ, self.name, TypeError)
          validator.check('x shape', x['shape'], 'y shape', y['shape'], Rel.EQ, self.name)
          return x
  class Cast(PrimitiveWithInfer):
      """
      Returns a tensor with the new specified data type.

      Inputs:
          - **input_x** (Union[Tensor, Number]) - The shape of tensor is :math:`(x_1, x_2, ..., x_R)`.
            The tensor to be cast.
          - **type** (dtype.Number) - The valid data type of the output tensor. Only constant value is allowed.

      Outputs:
          Tensor, the shape of tensor is the same as `input_x`, :math:`(x_1, x_2, ..., x_R)`.

      Examples:
          >>> input_np = np.random.randn(2, 3, 4, 5).astype(np.float32)
          >>> input_x = Tensor(input_np)
          >>> type_dst = mindspore.float16
          >>> cast = P.Cast()
          >>> result = cast(input_x, type_dst)
      """

      @prim_attr_register
      def __init__(self):
          """Initialize Cast"""
          self.init_prim_io_names(inputs=['x', 'dst_type'], outputs=['output'])

      def check_elim(self, x, dtype):
          """Check whether the cast can be skipped or replaced by a constant.

          Args:
              x: the value being cast.
              dtype: the target data type.

          Returns:
              tuple: ``(True, x)`` if `x` is a Tensor already of `dtype`, or a Parameter
              whose ``data`` already has `dtype`; ``(True, Tensor(x, dtype=dtype))`` if
              `x` is a number; otherwise ``(False, None)``.
          """
          if isinstance(x, (Tensor, numbers.Number, Parameter)):
              if isinstance(x, Tensor) and x.dtype == dtype:
                  return (True, x)
              if isinstance(x, numbers.Number):
                  return (True, Tensor(x, dtype=dtype))
              if isinstance(x, Parameter):
                  data = x.data
                  if data.dtype == dtype:
                      return (True, x)
          return (False, None)

      def __infer__(self, x, t):
          """Infer the output abstract and record the SrcT/DstT attributes.

          When the input value is known, the cast result is computed here and returned
          as 'value'.
          """
          src_type = x['dtype']
          dst_type = t['value']
          validator.check_subclass("input_x", src_type, [mstype.tensor, mstype.number], self.name)
          validator.check_subclass("type", dst_type, mstype.number, self.name)
          # SrcT/DstT are recorded as element types, not tensor types.
          if isinstance(src_type, type(mstype.tensor)):
              src_type = x['dtype'].element_type()
          if isinstance(dst_type, type(mstype.tensor)):
              dst_type = dst_type.element_type()
          self.add_prim_attr('DstT', dst_type)
          self.add_prim_attr('SrcT', src_type)
          value = None
          if x['value'] is not None:
              np_dst_type = mstype.dtype_to_nptype(dst_type)
              # Any scalar number (not just int/float, e.g. numpy scalars) goes through np.array;
              # this matches the numbers.Number test used by check_elim. Only tensors use asnumpy().
              if isinstance(x['value'], numbers.Number):
                  value = Tensor(np.array(x['value']).astype(np_dst_type))
              else:
                  value = Tensor(x['value'].asnumpy().astype(np_dst_type))
          out = {'shape': x['shape'],
                 'dtype': mstype.tensor_type(t['value']),
                 'value': value}
          if 'min_shape' in x and 'max_shape' in x:
              out['min_shape'] = x['min_shape']
              out['max_shape'] = x['max_shape']
          return out
  class IsSubClass(PrimitiveWithInfer):
      """
      Checks whether one type is a subclass of another type.

      Inputs:
          - **sub_type** (mindspore.dtype) - The type to be checked. Only constant value is allowed.
          - **type_** (mindspore.dtype) - The target type. Only constant value is allowed.

      Outputs:
          bool, the check result.

      Examples:
          >>> result = P.IsSubClass()(mindspore.int32, mindspore.intc)
          True
      """

      @prim_attr_register
      def __init__(self):
          """Initialize IsSubClass"""

      def __infer__(self, sub_type, type_):
          # Both operands must be constant mindspore Type objects.
          candidate = sub_type['value']
          target = type_['value']
          validator.check_value_type("sub_type", candidate, [mstype.Type], self.name)
          validator.check_value_type("type_", target, [mstype.Type], self.name)
          return {'shape': (),
                  'dtype': mstype.type_type,
                  'value': mstype.issubclass_(candidate, target)}
  class IsInstance(PrimitiveWithInfer):
      """
      Checks whether an object is an instance of a target type.

      Inputs:
          - **inst** (Any Object) - The instance to be checked. Only constant value is allowed.
          - **type_** (mindspore.dtype) - The target type. Only constant value is allowed.

      Outputs:
          bool, the check result.

      Examples:
          >>> a = 1
          >>> result = P.IsInstance()(a, mindspore.int32)
          True
      """

      @prim_attr_register
      def __init__(self):
          """Initialize IsInstance"""

      def __infer__(self, inst, type_):
          inst_dtype = inst['dtype']
          target = type_['value']
          validator.check_value_type("type_", target, [mstype.Type], self.name)

          def _as_output(flag):
              return {'shape': (),
                      'dtype': mstype.type_type,
                      'value': flag}

          # NOTE(review): for list/tuple targets the inferred dtype of `inst` is assumed to be a
          # Python list/tuple (of element dtypes) — TODO confirm against the abstract conversion.
          if target == mstype.list_:
              return _as_output(isinstance(inst_dtype, list))
          if target == mstype.tuple_:
              return _as_output(isinstance(inst_dtype, tuple))
          return _as_output(mstype.issubclass_(inst_dtype, target))
  class Reshape(PrimitiveWithInfer):
      """
      Reshapes input tensor with the same values based on a given shape tuple.

      Raises:
          ValueError: Given a shape tuple, if it has more than one -1; or if the product
              of its elements other than -1 is less than or equal to 0 or does not divide
              the product of the input tensor shape; or if it does not match the input's
              array size.

      Inputs:
          - **input_x** (Tensor) - The shape of tensor is :math:`(x_1, x_2, ..., x_R)`.
          - **input_shape** (tuple[int]) - The input tuple is constructed by multiple
            integers, i.e., :math:`(y_1, y_2, ..., y_S)`. Only constant value is allowed.

      Outputs:
          Tensor, the shape of tensor is :math:`(y_1, y_2, ..., y_S)`.

      Examples:
          >>> input_tensor = Tensor(np.array([[-0.1, 0.3, 3.6], [0.4, 0.5, -3.2]]), mindspore.float32)
          >>> reshape = P.Reshape()
          >>> output = reshape(input_tensor, (3, 2))
          [[-0.1  0.3]
           [ 3.6  0.4]
           [ 0.5 -3.2]]
      """

      @prim_attr_register
      def __init__(self):
          """Initialize Reshape"""
          self.init_prim_io_names(inputs=['tensor', 'shape'], outputs=['output'])

      def __infer__(self, x, shape):
          """Infer the reshaped abstract; a single -1 in `shape` is resolved from the input size."""
          shape_v = shape['value']
          x_shp = x['shape']
          validator.check_subclass("x", x['dtype'], mstype.tensor, self.name)
          validator.check_value_type("shape", shape_v, [tuple], self.name)
          shape_v = list(shape_v)
          neg_index = -1
          dim_prod = 1  # product of the requested dims, excluding the -1 placeholder
          for i, shp_i in enumerate(shape_v):
              validator.check_value_type("shape[%d]" % i, shp_i, [int], self.name)
              if shp_i == -1:
                  if neg_index != -1:
                      raise ValueError(f"For '{self.name}', the shape can contain at most one -1, "
                                       f"but got {shape_v}.")
                  neg_index = i
              else:
                  dim_prod *= shp_i
          arr_prod = np.prod(x_shp)
          if dim_prod <= 0 or arr_prod % dim_prod != 0:
              raise ValueError(f"For '{self.name}', input_x's shape is {x_shp}, input_shape's value is {shape_v}. "
                               f"The product of input_shape (excluding -1) should be > 0 "
                               f"and should divide the product of input_x's shape, "
                               f"but product of input_x's shape is {arr_prod}, product of input_shape is {dim_prod}.")
          if neg_index != -1:
              # Exact integer division: divisibility was verified just above, and float
              # division could lose precision for very large element counts.
              shape_v[neg_index] = int(arr_prod // dim_prod)
              dim_prod *= shape_v[neg_index]
          if dim_prod != arr_prod:
              raise ValueError(f"For '{self.name}', input_x's shape is {x_shp}, input_shape's value is {shape_v}. "
                               f"The product of input_x's shape should be equal to product of input_shape, "
                               f"but product of input_x's shape is {arr_prod}, product of input_shape is {dim_prod}.")
          value = None
          if x['value'] is not None:
              value = Tensor(x['value'].asnumpy().reshape(shape_v))
          out = {'shape': tuple(shape_v),
                 'dtype': x['dtype'],
                 'value': value}
          return out
  353. class Shape(PrimitiveWithInfer):
  354. """
  355. Returns the shape of input tensor.
  356. Inputs:
  357. - **input_x** (Tensor) - The shape of tensor is :math:`(x_1, x_2, ..., x_R)`.
  358. Outputs:
  359. tuple[int], the output tuple is constructed by multiple integers,
  360. :math:`(x_1, x_2, ..., x_R)`.
  361. Examples:
  362. >>> input_tensor = Tensor(np.ones(shape=[3, 2, 1]), mindspore.float32)
  363. >>> shape = P.Shape()
  364. >>> output = shape(input_tensor)
  365. (3, 2, 1)
  366. """
  367. @prim_attr_register
  368. def __init__(self):
  369. """Initialize Shape"""
  370. def __infer__(self, x):
  371. validator.check_subclass("input_x", x['dtype'], mstype.tensor, self.name)
  372. out = {'shape': (),
  373. 'dtype': mstype.tuple_,
  374. 'value': tuple(x['shape'])}
  375. return out
  376. class DynamicShape(Primitive):
  377. """
  378. Returns the shape of input tensor.
  379. Inputs:
  380. - **input_x** (Tensor) - The shape of tensor is :math:`(x_1, x_2, ..., x_R)`.
  381. Outputs:
  382. Tensor[int], 1-dim Tensor of type int32
  383. Examples:
  384. >>> input_tensor = Tensor(np.ones(shape=[3, 2, 1]), mindspore.float32)
  385. >>> shape = P.DynamicShape()
  386. >>> output = shape(input_tensor)
  387. """
  388. @prim_attr_register
  389. def __init__(self):
  390. """init Shape"""
  391. self.init_prim_io_names(inputs=['tensor'], outputs=['output'])
  392. self.add_prim_attr('is_dynamic_shape', True)
  393. self.add_prim_attr("dynamic_shape_depends", [0])
  394. class Squeeze(PrimitiveWithInfer):
  395. """
  396. Returns a tensor with the same type but dimensions of 1 are removed based on `axis`.
  397. Note:
  398. The dimension index starts at 0 and must be in the range `[-input.dim(), input.dim())`.
  399. Raises:
  400. ValueError: If the corresponding dimension of the specified axis does not equal to 1.
  401. Args:
  402. axis (Union[int, tuple(int)]): Specifies the dimension indexes of shape to be removed, which will remove
  403. all the dimensions that are equal to 1. If specified, it must be int32 or int64.
  404. Default: (), an empty tuple.
  405. Inputs:
  406. - **input_x** (Tensor) - The shape of tensor is :math:`(x_1, x_2, ..., x_R)`.
  407. Outputs:
  408. Tensor, the shape of tensor is :math:`(x_1, x_2, ..., x_S)`.
  409. Examples:
  410. >>> input_tensor = Tensor(np.ones(shape=[3, 2, 1]), mindspore.float32)
  411. >>> squeeze = P.Squeeze(2)
  412. >>> output = squeeze(input_tensor)
  413. [[1. 1.]
  414. [1. 1.]
  415. [1. 1.]]
  416. """
  417. @prim_attr_register
  418. def __init__(self, axis=()):
  419. """Initialize Squeeze"""
  420. self.init_prim_io_names(inputs=['x'], outputs=['output'])
  421. validator.check_value_type('axis', axis, [int, tuple], self.name)
  422. if isinstance(axis, tuple):
  423. for idx, item in enumerate(axis):
  424. validator.check_value_type("axis[%d]" % idx, item, [int], self.name)
  425. else:
  426. self.axis = (axis,)
  427. self.add_prim_attr("axis", (axis,))
  428. def infer_shape(self, x_shape):
  429. axis = self.axis
  430. x_shape = list(x_shape)
  431. ndim = len(x_shape)
  432. if not axis:
  433. ret = [d for d in x_shape if d != 1]
  434. else:
  435. for a in axis:
  436. validator.check_int_range(a, -ndim, ndim - 1, Rel.INC_BOTH, 'axis or its elements', self.name)
  437. if x_shape[a] != 1:
  438. raise ValueError('Cannot select an axis to squeeze out which has size not equal to one.')
  439. ret = [x_shape[i] for i in range(ndim) if not (i in axis or (i - ndim) in axis)]
  440. return ret
  441. def infer_dtype(self, x_dtype):
  442. validator.check_subclass("x", x_dtype, mstype.tensor, self.name)
  443. return x_dtype
  444. class Transpose(PrimitiveWithInfer):
  445. """
  446. Permutes the dimensions of input tensor according to input permutation.
  447. Inputs:
  448. - **input_x** (Tensor) - The shape of tensor is :math:`(x_1, x_2, ..., x_R)`.
  449. - **input_perm** (tuple[int]) - The permutation to be converted. The input tuple is constructed by multiple
  450. indexes. The length of `input_perm` and the shape of `input_x` must be the same. Only constant value is
  451. allowed. Must be in the range [0, rank(input_x)).
  452. Outputs:
  453. Tensor, the type of output tensor is the same as `input_x` and the shape of output tensor is decided by the
  454. shape of `input_x` and the value of `input_perm`.
  455. Examples:
  456. >>> input_tensor = Tensor(np.array([[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]]), mindspore.float32)
  457. >>> perm = (0, 2, 1)
  458. >>> transpose = P.Transpose()
  459. >>> output = transpose(input_tensor, perm)
  460. [[[1. 4.]
  461. [2. 5.]
  462. [3. 6.]]
  463. [[7. 10.]
  464. [8. 11.]
  465. [9. 12.]]]
  466. """
  467. @prim_attr_register
  468. def __init__(self):
  469. """Initialize Transpose"""
  470. self.init_prim_io_names(inputs=['x', 'perm'], outputs=['output'])
  471. def __infer__(self, x, perm):
  472. x_shape = x['shape']
  473. p_value = perm['value']
  474. x_type = x['dtype']
  475. validator.check_value_type("p_value", p_value, [tuple], self.name)
  476. validator.check_subclass("x_type", x_type, mstype.tensor, self.name)
  477. if len(x_shape) != len(p_value):
  478. raise ValueError('The dimension of x and perm must be equal.')
  479. tmp = list(p_value)
  480. for i, dim in enumerate(p_value):
  481. validator.check_int(dim, 0, Rel.GE, f'perm[{i}]', self.name)
  482. validator.check_int(dim, len(p_value), Rel.LT, f'perm[{i}]', self.name)
  483. tmp.remove(dim)
  484. if dim in tmp:
  485. raise ValueError('The value of perm is wrong.')
  486. out_shapes = []
  487. for i in p_value:
  488. out_shapes.append(x_shape[i])
  489. out = {'shape': tuple(out_shapes),
  490. 'dtype': x['dtype'],
  491. 'value': None}
  492. return out
  493. class Unique(Primitive):
  494. """
  495. Returns the unique elements of input tensor and also return a tensor containing the index of each value of input
  496. tensor corresponding to the output unique tensor.
  497. Inputs:
  498. - **x** (Tensor) - The input tensor.
  499. Outputs:
  500. Tuple, containing Tensor objects `(y, idx)`, `y` is a tensor has the same type as `x`, `idx` is a tensor
  501. containing indices of elements in the input coressponding to the output tensor.
  502. Examples:
  503. >>> x = Tensor(np.array([1, 2, 5, 2]), mindspore.int32)
  504. >>> out = P.Unique()(x)
  505. (Tensor([1, 2, 5], mindspore.int32), Tensor([0, 1, 2, 1], mindspore.int32))
  506. """
  507. @prim_attr_register
  508. def __init__(self):
  509. self.init_prim_io_names(inputs=['x'], outputs=['output'])
  510. class GatherV2(PrimitiveWithCheck):
  511. """
  512. Returns a slice of input tensor based on the specified indices and axis.
  513. Inputs:
  514. - **input_params** (Tensor) - The shape of tensor is :math:`(x_1, x_2, ..., x_R)`.
  515. The original Tensor.
  516. - **input_indices** (Tensor) - The shape of tensor is :math:`(y_1, y_2, ..., y_S)`.
  517. Specifies the indices of elements of the original Tensor. Must be in the range
  518. `[0, input_param.shape[axis])`.
  519. - **axis** (int) - Specifies the dimension index to gather indices.
  520. Outputs:
  521. Tensor, the shape of tensor is :math:`(z_1, z_2, ..., z_N)`.
  522. Examples:
  523. >>> input_params = Tensor(np.array([[1, 2, 7, 42], [3, 4, 54, 22], [2, 2, 55, 3]]), mindspore.float32)
  524. >>> input_indices = Tensor(np.array([1, 2]), mindspore.int32)
  525. >>> axis = 1
  526. >>> out = P.GatherV2()(input_params, input_indices, axis)
  527. [[2.0, 7.0],
  528. [4.0, 54.0],
  529. [2.0, 55.0]]
  530. """
  531. @prim_attr_register
  532. def __init__(self):
  533. """Initialize index_select"""
  534. self.init_prim_io_names(inputs=['params', 'indices', 'axis'], outputs=['output'])
  535. self.add_prim_attr("dynamic_shape_depends", [2,])
  536. def __check__(self, params, indices, axis):
  537. validator.check_subclass("params", params['dtype'], mstype.tensor, self.name)
  538. validator.check_tensor_type_same({"indices": indices['dtype']}, mstype.int_type, self.name)
  539. validator.check_subclass("axis", axis['dtype'], mstype.int_, self.name)
  540. axis_v = axis['value']
  541. params_shp = params['shape']
  542. rank = len(params_shp)
  543. validator.check_int_range(axis_v, -rank, rank, Rel.INC_LEFT, "axis", self.name)
  544. if axis_v < 0:
  545. axis_v += rank
  546. out_shape = params_shp[:axis_v] + indices['shape'] + params_shp[axis_v + 1:]
  547. out = {'shape': out_shape,
  548. 'dtype': params['dtype'],
  549. 'value': None}
  550. if 'min_shape' in indices and 'max_shape' in indices:
  551. out['min_shape'] = params_shp[:axis_v] + indices['min_shape'] + params_shp[axis_v + 1:]
  552. out['max_shape'] = params_shp[:axis_v] + indices['max_shape'] + params_shp[axis_v + 1:]
  553. return out
  554. class SparseGatherV2(GatherV2):
  555. """
  556. Returns a slice of input tensor based on the specified indices and axis.
  557. Inputs:
  558. - **input_params** (Tensor) - The shape of tensor is :math:`(x_1, x_2, ..., x_R)`.
  559. The original Tensor.
  560. - **input_indices** (Tensor) - The shape of tensor is :math:`(y_1, y_2, ..., y_S)`.
  561. Specifies the indices of elements of the original Tensor, must be in the range
  562. `[0, input_param.shape[axis])`.
  563. - **axis** (int) - Specifies the dimension index to gather indices.
  564. Outputs:
  565. Tensor, the shape of tensor is :math:`(z_1, z_2, ..., z_N)`.
  566. Examples:
  567. >>> input_params = Tensor(np.array([[1, 2, 7, 42], [3, 4, 54, 22], [2, 2, 55, 3]]), mindspore.float32)
  568. >>> input_indices = Tensor(np.array([1, 2]), mindspore.int32)
  569. >>> axis = 1
  570. >>> out = P.SparseGatherV2()(input_params, input_indices, axis)
  571. """
  572. class Padding(PrimitiveWithInfer):
  573. """
  574. Extends the last dimension of input tensor from 1 to pad_dim_size, by filling with 0.
  575. Args:
  576. pad_dim_size (int): The value of the last dimension of x to be extended, which must be positive.
  577. Inputs:
  578. - **x** (Tensor) - The shape of tensor is :math:`(x_1, x_2, ..., x_R)`. The rank of x must be at least 2.
  579. The last dimension of x must be 1.
  580. Outputs:
  581. Tensor, the shape of tensor is :math:`(z_1, z_2, ..., z_N)`.
  582. Examples:
  583. >>> x = Tensor(np.array([[8], [10]]), mindspore.float32)
  584. >>> pad_dim_size = 4
  585. >>> out = P.Padding(pad_dim_size)(x)
  586. [[8, 0, 0, 0], [10, 0, 0, 0]]
  587. """
  588. @prim_attr_register
  589. def __init__(self, pad_dim_size=8):
  590. """Initialize padding"""
  591. validator.check_value_type("pad_dim_size", pad_dim_size, [int], self.name)
  592. validator.check_positive_int(pad_dim_size, "pad_dim_size", self.name)
  593. self.pad_dim_size = pad_dim_size
  594. def __infer__(self, x):
  595. validator.check_subclass("x", x['dtype'], mstype.tensor, self.name)
  596. x_shape = list(x['shape'])
  597. validator.check_int(len(x_shape), 1, Rel.GT, "rank of x", self.name)
  598. validator.check_int(x_shape[-1], 1, Rel.EQ, "last dim of x", self.name)
  599. out_shape = x_shape
  600. out_shape[-1] = self.pad_dim_size
  601. out = {'shape': out_shape,
  602. 'dtype': x['dtype'],
  603. 'value': None}
  604. return out
  605. class Split(PrimitiveWithInfer):
  606. """
  607. Splits input tensor into output_num of tensors along the given axis and output numbers.
  608. Args:
  609. axis (int): Index of the split position. Default: 0.
  610. output_num (int): The number of output tensors. Default: 1.
  611. Raises:
  612. ValueError: If `axis` is out of the range [-len(`input_x.shape`), len(`input_x.shape`)),
  613. or if the `output_num` is less than or equal to 0, or if the
  614. dimension which to split cannot be evenly divided by `output_num`.
  615. Inputs:
  616. - **input_x** (Tensor) - The shape of tensor is :math:`(x_1, x_2, ..., x_R)`.
  617. Outputs:
  618. tuple[Tensor], the shape of each output tensor is the same, which is
  619. :math:`(y_1, y_2, ..., y_S)`.
  620. Examples:
  621. >>> split = P.Split(1, 2)
  622. >>> x = Tensor(np.array([[1, 1, 1, 1], [2, 2, 2, 2]]))
  623. >>> output = split(x)
  624. ([[1, 1],
  625. [2, 2]],
  626. [[1, 1],
  627. [2, 2]])
  628. """
  629. @prim_attr_register
  630. def __init__(self, axis=0, output_num=1):
  631. """Initialize Split"""
  632. validator.check_value_type("axis", axis, [int], self.name)
  633. validator.check_value_type("output_num", output_num, [int], self.name)
  634. self.axis = axis
  635. self.output_num = output_num
  636. def __infer__(self, x):
  637. validator.check_subclass("x", x['dtype'], mstype.tensor, self.name)
  638. x_shape = list(x['shape'])
  639. dim = len(x_shape)
  640. validator.check_int_range(self.axis, -dim, dim, Rel.INC_LEFT, 'axis value', self.name)
  641. validator.check_positive_int(self.output_num, "output_num", self.name)
  642. output_valid_check = x_shape[self.axis] % self.output_num
  643. if output_valid_check != 0:
  644. raise ValueError(f"x_shape[{self.axis}] {x_shape[self.axis]} must be divide exactly by"
  645. f" output_num {self.output_num}")
  646. x_shape[self.axis] = int(x_shape[self.axis] / self.output_num)
  647. out_shapes = []
  648. out_dtypes = []
  649. for _ in range(self.output_num):
  650. out_shapes.append(tuple(x_shape))
  651. out_dtypes.append(x['dtype'])
  652. out_shapes = tuple(out_shapes)
  653. out_dtypes = tuple(out_dtypes)
  654. out = {'shape': out_shapes,
  655. 'dtype': out_dtypes,
  656. 'value': None}
  657. return out
  658. class Rank(PrimitiveWithInfer):
  659. """
  660. Returns the rank of a tensor.
  661. Returns a 0-D int32 Tensor representing the rank of input; the rank of a tensor
  662. is the number of indices required to uniquely select each element of the tensor.
  663. Inputs:
  664. - **input_x** (Tensor) - The shape of tensor is :math:`(x_1, x_2, ..., x_R)`.
  665. Outputs:
  666. Tensor. 0-D int32 Tensor representing the rank of input, i.e., :math:`R`.
  667. Examples:
  668. >>> input_tensor = Tensor(np.array([[2, 2], [2, 2]]), mindspore.float32)
  669. >>> rank = P.Rank()
  670. >>> rank(input_tensor)
  671. 2
  672. """
  673. @prim_attr_register
  674. def __init__(self):
  675. """Initialize Rank"""
  676. def __infer__(self, x):
  677. validator.check_subclass("x", x['dtype'], mstype.tensor, self.name)
  678. out = {'shape': None,
  679. 'dtype': None,
  680. 'value': len(x['shape'])}
  681. return out
  682. class TruncatedNormal(PrimitiveWithInfer):
  683. """
  684. Returns a tensor of the specified shape filled with truncated normal values.
  685. The generated values follow a normal distribution.
  686. Args:
  687. seed (int): A integer number used to create random seed. Default: 0.
  688. dtype (:class:`mindspore.dtype`): Data type. Default: mindspore.float32.
  689. Inputs:
  690. - **shape** (tuple[int]) - The shape of the output tensor, is a tuple of positive integer.
  691. Outputs:
  692. Tensor, the data type of output tensor is the same as attribute `dtype`.
  693. Examples:
  694. >>> shape = (1, 2, 3)
  695. >>> truncated_normal = P.TruncatedNormal()
  696. >>> output = truncated_normal(shape)
  697. """
  698. @prim_attr_register
  699. def __init__(self, seed=0, dtype=mstype.float32):
  700. """Initialize TruncatedNormal"""
  701. validator.check_value_type('seed', seed, [int], self.name)
  702. validator.check_type_same({'dtype': dtype}, mstype.number_type, self.name)
  703. def __infer__(self, shape):
  704. shape_value = shape['value']
  705. validator.check_value_type("shape", shape_value, [tuple], self.name)
  706. for i, value in enumerate(shape_value):
  707. validator.check_positive_int(value, f'{i}th value of shape', self.name)
  708. out = {'shape': shape_value,
  709. 'dtype': mstype.tensor_type(self.dtype),
  710. 'value': None}
  711. return out
  712. class Size(PrimitiveWithInfer):
  713. r"""
  714. Returns the elements count size of a tensor.
  715. Returns an int scalar representing the elements size of input, the total number of elements in the tensor.
  716. Inputs:
  717. - **input_x** (Tensor) - The shape of tensor is :math:`(x_1, x_2, ..., x_R)`.
  718. Outputs:
  719. int, a scalar representing the elements size of `input_x`, tensor is the number of elements
  720. in a tensor, :math:`size=x_1*x_2*...x_R`.
  721. Examples:
  722. >>> input_tensor = Tensor(np.array([[2, 2], [2, 2]]), mindspore.float32)
  723. >>> size = P.Size()
  724. >>> output = size(input_tensor)
  725. 4
  726. """
  727. @prim_attr_register
  728. def __init__(self):
  729. """Initialize Size"""
  730. def __infer__(self, x):
  731. size = 1
  732. validator.check_subclass("x", x['dtype'], mstype.tensor, self.name)
  733. shp = x['shape']
  734. if not shp:
  735. size = 0
  736. else:
  737. size = functools.reduce(lambda x, y: x * y, x['shape'])
  738. out = {'shape': None,
  739. 'dtype': mstype.int32,
  740. 'value': size}
  741. return out
  742. class Fill(PrimitiveWithInfer):
  743. """
  744. Creates a tensor filled with a scalar value.
  745. Creates a tensor with shape described by the first argument and fills it with values in the second argument.
  746. Inputs:
  747. - **type** (mindspore.dtype) - The specified type of output tensor. Only constant value is allowed.
  748. - **shape** (tuple) - The specified shape of output tensor. Only constant value is allowed.
  749. - **value** (scalar) - Value to fill the returned tensor. Only constant value is allowed.
  750. Outputs:
  751. Tensor, has the same type and shape as input value.
  752. Examples:
  753. >>> fill = P.Fill()
  754. >>> fill(mindspore.float32, (2, 2), 1)
  755. [[1.0, 1.0],
  756. [1.0, 1.0]]
  757. """
  758. @prim_attr_register
  759. def __init__(self):
  760. """Initialize Fill"""
  761. def __infer__(self, dtype, dims, x):
  762. validator.check_value_type("shape", dims['value'], [tuple], self.name)
  763. validator.check_value_type("value", x['value'], [numbers.Number, bool], self.name)
  764. for i, item in enumerate(dims['value']):
  765. validator.check_positive_int(item, f'dims[{i}]', self.name)
  766. valid_types = [mstype.bool_, mstype.int8, mstype.int16, mstype.int32, mstype.int64,
  767. mstype.uint8, mstype.uint32, mstype.uint64,
  768. mstype.float16, mstype.float32, mstype.float64]
  769. validator.check_type_same({"value": dtype['value']}, valid_types, self.name)
  770. x_nptype = mstype.dtype_to_nptype(dtype['value'])
  771. ret = np.full(dims['value'], x['value'], x_nptype)
  772. out = {
  773. 'value': Tensor(ret),
  774. 'shape': dims['value'],
  775. 'dtype': x['dtype'],
  776. }
  777. return out
  778. class OnesLike(PrimitiveWithInfer):
  779. """
  780. Creates a new tensor. The values of all elements are 1.
  781. Returns a tensor of ones with the same shape and type as the input.
  782. Inputs:
  783. - **input_x** (Tensor) - Input tensor.
  784. Outputs:
  785. Tensor, has the same shape and type as `input_x` but filled with ones.
  786. Examples:
  787. >>> oneslike = P.OnesLike()
  788. >>> x = Tensor(np.array([[0, 1], [2, 1]]).astype(np.int32))
  789. >>> output = oneslike(x)
  790. [[1, 1],
  791. [1, 1]]
  792. """
  793. @prim_attr_register
  794. def __init__(self):
  795. """Initialize OnesLike"""
  796. def infer_shape(self, x_shape):
  797. return x_shape
  798. def infer_dtype(self, x_dtype):
  799. validator.check_tensor_type_same({'x': x_dtype}, mstype.number_type + (mstype.bool_,), self.name)
  800. return x_dtype
  801. class ZerosLike(PrimitiveWithInfer):
  802. """
  803. Creates a new tensor. All elements value are 0.
  804. Returns a tensor of zeros with the same shape and data type as the input tensor.
  805. Inputs:
  806. - **input_x** (Tensor) - Input tensor.
  807. Outputs:
  808. Tensor, has the same shape and data type as `input_x` but filled with zeros.
  809. Examples:
  810. >>> zeroslike = P.ZerosLike()
  811. >>> x = Tensor(np.array([[0, 1], [2, 1]]).astype(np.float32))
  812. >>> output = zeroslike(x)
  813. [[0.0, 0.0],
  814. [0.0, 0.0]]
  815. """
  816. @prim_attr_register
  817. def __init__(self):
  818. """Initialize ZerosLike"""
  819. self.init_prim_io_names(inputs=['x'], outputs=['y'])
  820. def infer_shape(self, x_shape):
  821. return x_shape
  822. def infer_dtype(self, x_dtype):
  823. validator.check_tensor_type_same({'x': x_dtype}, mstype.number_type + (mstype.bool_,), self.name)
  824. return x_dtype
  825. class TupleToArray(PrimitiveWithInfer):
  826. """
  827. Converts a tuple to a tensor.
  828. If the type of the first number in the tuple is integer, the data type of the output tensor is int.
  829. Otherwise, the data type of the output tensor is float.
  830. Inputs:
  831. - **input_x** (tuple) - A tuple of numbers. These numbers have the same type. Only constant value is allowed.
  832. Outputs:
  833. Tensor, if the input tuple contains `N` numbers, then the shape of the output tensor is (N,).
  834. Examples:
  835. >>> type = P.TupleToArray()((1,2,3))
  836. """
  837. @prim_attr_register
  838. def __init__(self):
  839. """Initialize TupleToArray"""
  840. def infer_value(self, x):
  841. validator.check_value_type("x", x, [tuple], self.name)
  842. validator.check("size of x", len(x), '', 0, Rel.GT, self.name)
  843. dtype = type(x[0])
  844. for i, item in enumerate(x):
  845. validator.check_value_type(f"x[{i}]", item, [numbers.Number], self.name)
  846. if not all(isinstance(item, dtype) for item in x):
  847. raise TypeError("For \'{self.name}\' all elements of input x must be have same type.")
  848. if isinstance(x[0], int):
  849. ret = np.array(x, np.int32)
  850. else:
  851. ret = np.array(x, np.float32)
  852. return Tensor(ret)
  853. def __call__(self, x):
  854. args = list()
  855. if isinstance(x, range):
  856. args.append(tuple(x))
  857. else:
  858. args.append(x)
  859. return _run_op(self, self.name, args)
  860. class ScalarToArray(PrimitiveWithInfer):
  861. """
  862. Converts a scalar to a `Tensor`.
  863. Inputs:
  864. - **input_x** (Union[int, float]) - The input is a scalar. Only constant value is allowed.
  865. Outputs:
  866. Tensor. 0-D Tensor and the content is the input.
  867. Examples:
  868. >>> op = P.ScalarToArray()
  869. >>> data = 1.0
  870. >>> output = op(data)
  871. """
  872. @prim_attr_register
  873. def __init__(self):
  874. pass
  875. def infer_value(self, x):
  876. validator.check_value_type("x", x, [int, float], self.name)
  877. if isinstance(x, int):
  878. ret = np.array(x, np.int32)
  879. else:
  880. ret = np.array(x, np.float32)
  881. return Tensor(ret)
  882. class ScalarToTensor(PrimitiveWithInfer):
  883. """
  884. Converts a scalar to a `Tensor`, and convert data type to specified type.
  885. Inputs:
  886. - **input_x** (Union[int, float]) - The input is a scalar. Only constant value is allowed.
  887. - **dtype** (mindspore.dtype) - The target data type. Default: mindspore.float32. Only
  888. constant value is allowed.
  889. Outputs:
  890. Tensor. 0-D Tensor and the content is the input.
  891. Examples:
  892. >>> op = P.ScalarToTensor()
  893. >>> data = 1
  894. >>> output = op(data, mindspore.float32)
  895. """
  896. @prim_attr_register
  897. def __init__(self):
  898. pass
  899. def infer_value(self, x, dtype=mstype.float32):
  900. validator.check_value_type("x", x, [int, float], self.name)
  901. validator.check_subclass("dtype", dtype, mstype.number, self.name)
  902. data_type = mstype.dtype_to_nptype(dtype)
  903. return Tensor(np.array(x, data_type))
  904. class InvertPermutation(PrimitiveWithInfer):
  905. r"""
  906. Computes the inverse of an index permutation.
  907. Given a tuple input, this operation inserts a dimension of 1 at the dimension
  908. This operation calculates the inverse of the index replacement. It requires a
  909. 1-dimensional tuple x, which represents the array starting at zero,
  910. and swaps each value with its index position. In other words, for the output
  911. tuple y and the input tuple x, this operation calculates the following:
  912. :math:`y[x[i]] = i, \quad i \in [0, 1, \ldots, \text{len}(x)-1]`.
  913. Note:
  914. These values must include 0. There must be no duplicate values and the
  915. values can not be negative.
  916. Inputs:
  917. - **input_x** (Union(tuple[int], list[int]) - The input is constructed by multiple
  918. integers, i.e., :math:`(y_1, y_2, ..., y_S)` representing the indices.
  919. The values must include 0. There can be no duplicate values or negative values.
  920. Only constant value is allowed. The maximum value msut be equal to length of input_x.
  921. Outputs:
  922. tuple[int]. It has the same length as the input.
  923. Examples:
  924. >>> invert = P.InvertPermutation()
  925. >>> input_data = (3, 4, 0, 2, 1)
  926. >>> output = invert(input_data)
  927. >>> output == (2, 4, 3, 0, 1)
  928. """
  929. @prim_attr_register
  930. def __init__(self):
  931. """Initialize InvertPermutation"""
  932. self.set_const_prim(True)
  933. def __infer__(self, x):
  934. x_shp = x['shape']
  935. x_value = x['value']
  936. if x_value is None:
  937. raise ValueError(f'For \'{self.name}\' the input value must be const.')
  938. validator.check_value_type("shape", x_shp, [tuple, list], self.name)
  939. if mstype.issubclass_(x['dtype'], mstype.tensor):
  940. raise ValueError(f'For \'{self.name}\' the input value must be non-Tensor.')
  941. for shp in x_shp:
  942. if shp != []:
  943. x_rank = len(np.array(x_value, np.int64).shape)
  944. raise ValueError(f'For \'{self.name}\' the rank of input must be 1, but got {x_rank}.')
  945. for i, value in enumerate(x_value):
  946. validator.check_value_type("input[%d]" % i, value, [int], self.name)
  947. z = [x_value[i] for i in range(len(x_value))]
  948. z.sort()
  949. for i in range(1, len(z)):
  950. if z[i - 1] == z[i]:
  951. raise ValueError(f"For {self.name}, {z[i]} is duplicated in the input.")
  952. validator.check(f'value min', min(x_value), '', 0, Rel.EQ, self.name)
  953. validator.check(f'value max', max(x_value), '', len(x_value) - 1, Rel.EQ, self.name)
  954. y = [None] * len(x_value)
  955. for i, value in enumerate(x_value):
  956. validator.check_value_type("input[%d]" % i, value, [int], self.name)
  957. validator.check(f'value', z[i], f'index', i, Rel.EQ, self.name)
  958. y[value] = i
  959. z.append(value)
  960. return {'shape': x_shp,
  961. 'dtype': x['dtype'],
  962. 'value': tuple(y)}
  963. class Argmax(PrimitiveWithInfer):
  964. """
  965. Returns the indices of the max value of a tensor across the axis.
  966. If the shape of input tensor is :math:`(x_1, ..., x_N)`, the shape of the output tensor will be
  967. :math:`(x_1, ..., x_{axis-1}, x_{axis+1}, ..., x_N)`.
  968. Args:
  969. axis (int): Axis where the Argmax operation applies to. Default: -1.
  970. output_type (:class:`mindspore.dtype`): An optional data type of `mindspore.dtype.int32`.
  971. Default: `mindspore.dtype.int32`.
  972. Inputs:
  973. - **input_x** (Tensor) - Input tensor.
  974. Outputs:
  975. Tensor, indices of the max value of input tensor across the axis.
  976. Examples:
  977. >>> input_x = Tensor(np.array([2.0, 3.1, 1.2]), mindspore.float32)
  978. >>> index = P.Argmax(output_type=mindspore.int32)(input_x)
  979. 1
  980. """
  981. @prim_attr_register
  982. def __init__(self, axis=-1, output_type=mstype.int32):
  983. """Initialize Argmax"""
  984. self.init_prim_io_names(inputs=['x'], outputs=['output'])
  985. validator.check_value_type("axis", axis, [int], self.name)
  986. validator.check_type_same({'output': output_type}, [mstype.int32], self.name)
  987. self.axis = axis
  988. self.add_prim_attr('output_type', output_type)
  989. def infer_shape(self, x_shape):
  990. axis = self.axis
  991. if axis is None:
  992. axis = 0
  993. x_rank = len(x_shape)
  994. validator.check_int_range(axis, -x_rank, x_rank, Rel.INC_LEFT, "axis", self.name)
  995. axis = axis + x_rank if axis < 0 else axis
  996. ouput_shape = [x_shape[i] for i in range(x_rank) if i != axis]
  997. return ouput_shape
  998. def infer_dtype(self, x_dtype):
  999. validator.check_subclass("input_x", x_dtype, mstype.tensor, self.name)
  1000. return mstype.tensor_type(self.output_type)
  1001. class Argmin(PrimitiveWithInfer):
  1002. """
  1003. Returns the indices of the min value of a tensor across the axis.
  1004. If the shape of input tensor is :math:`(x_1, ..., x_N)`, the shape of the output tensor is
  1005. :math:`(x_1, ..., x_{axis-1}, x_{axis+1}, ..., x_N)`.
  1006. Args:
  1007. axis (int): Axis where the Argmin operation applies to. Default: -1.
  1008. output_type (:class:`mindspore.dtype`): An optional data type of `mindspore.dtype.int32`.
  1009. Default: `mindspore.dtype.int32`.
  1010. Inputs:
  1011. - **input_x** (Tensor) - Input tensor.
  1012. Outputs:
  1013. Tensor, indices of the min value of input tensor across the axis.
  1014. Examples:
  1015. >>> input_x = Tensor(np.array([2.0, 3.1, 1.2]), mindspore.float32)
  1016. >>> index = P.Argmin()(input_x)
  1017. >>> assert index == Tensor(2, mindspore.int64)
  1018. """
  1019. @prim_attr_register
  1020. def __init__(self, axis=-1, output_type=mstype.int32):
  1021. """Initialize Argmin"""
  1022. self.init_prim_io_names(inputs=['x'], outputs=['output'])
  1023. validator.check_value_type("axis", axis, [int], self.name)
  1024. validator.check_type_name("output_type", output_type, [mstype.int32, mstype.int64], self.name)
  1025. self.axis = axis
  1026. self.add_prim_attr('output_type', output_type)
  1027. def infer_shape(self, x_shape):
  1028. axis = self.axis
  1029. if axis is None:
  1030. axis = 0
  1031. x_rank = len(x_shape)
  1032. validator.check_int_range(axis, -x_rank, x_rank, Rel.INC_LEFT, "axis", self.name)
  1033. axis = axis + x_rank if axis < 0 else axis
  1034. ouput_shape = [x_shape[i] for i in range(x_rank) if i != axis]
  1035. return ouput_shape
  1036. def infer_dtype(self, x_dtype):
  1037. validator.check_subclass("input_x", x_dtype, mstype.tensor, self.name)
  1038. return mstype.tensor_type(self.output_type)
  1039. class ArgMaxWithValue(PrimitiveWithInfer):
  1040. """
  1041. Calculates the maximum value with the corresponding index.
  1042. Calculates the maximum value along with the given axis for the input tensor. It returns the maximum values and
  1043. indices.
  1044. Note:
  1045. In auto_parallel and semi_auto_parallel mode, the first output index can not be used.
  1046. Args:
  1047. axis (int): The dimension to reduce. Default: 0.
  1048. keep_dims (bool): Whether to reduce dimension, if true, the output will keep same dimension with the input,
  1049. the output will reduce dimension if false. Default: False.
  1050. Inputs:
  1051. - **input_x** (Tensor) - The input tensor, can be any dimension. Set the shape of input tensor as
  1052. :math:`(x_1, x_2, ..., x_N)`.
  1053. Outputs:
  1054. tuple (Tensor), tuple of 2 tensors, containing the corresponding index and the maximum value of the input
  1055. tensor.
  1056. - index (Tensor) - The index for the maximum value of the input tensor. If `keep_dims` is true, the shape of
  1057. output tensors is :math:`(x_1, x_2, ..., x_{axis-1}, 1, x_{axis+1}, ..., x_N)`. Otherwise, the shape is
  1058. :math:`(x_1, x_2, ..., x_{axis-1}, x_{axis+1}, ..., x_N)`.
  1059. - output_x (Tensor) - The maximum value of input tensor, with the same shape as index.
  1060. Examples:
  1061. >>> input_x = Tensor(np.random.rand(5), mindspore.float32)
  1062. >>> index, output = P.ArgMaxWithValue()(input_x)
  1063. """
  1064. @prim_attr_register
  1065. def __init__(self, axis=0, keep_dims=False):
  1066. """Initialize ArgMaxWithValue"""
  1067. self.axis = axis
  1068. self.keep_dims = keep_dims
  1069. validator.check_value_type('keep_dims', keep_dims, [bool], self.name)
  1070. validator.check_value_type('axis', axis, [int], self.name)
  1071. def infer_shape(self, x_shape):
  1072. axis = self.axis
  1073. x_rank = len(x_shape)
  1074. validator.check_int_range(axis, -x_rank, x_rank, Rel.INC_LEFT, "axis", self.name)
  1075. ouput_shape = _infer_shape_reduce(x_shape, self.axis, self.keep_dims, self.name)
  1076. return ouput_shape, ouput_shape
  1077. def infer_dtype(self, x_dtype):
  1078. validator.check_subclass("input_x", x_dtype, mstype.tensor, self.name)
  1079. return mstype.tensor_type(mstype.int32), x_dtype
  1080. class ArgMinWithValue(PrimitiveWithInfer):
  1081. """
  1082. Calculates the minimum value with corresponding index, return indices and values.
  1083. Calculates the minimum value along with the given axis for the input tensor. It returns the minimum values and
  1084. indices.
  1085. Note:
  1086. In auto_parallel and semi_auto_parallel mode, the first output index can not be used.
  1087. Args:
  1088. axis (int): The dimension to reduce. Default: 0.
  1089. keep_dims (bool): Whether to reduce dimension, if true the output will keep the same dimension as the input,
  1090. the output will reduce dimension if false. Default: False.
  1091. Inputs:
  1092. - **input_x** (Tensor) - The input tensor, can be any dimension. Set the shape of input tensor as
  1093. :math:`(x_1, x_2, ..., x_N)`.
  1094. Outputs:
  1095. tuple (Tensor), tuple of 2 tensors, containing the corresponding index and the minimum value of the input
  1096. tensor.
  1097. - index (Tensor) - The index for the maximum value of the input tensor. If `keep_dims` is true, the shape of
  1098. output tensors is :math:`(x_1, x_2, ..., x_{axis-1}, 1, x_{axis+1}, ..., x_N)`. Otherwise, the shape is
  1099. :math:`(x_1, x_2, ..., x_{axis-1}, x_{axis+1}, ..., x_N)`.
  1100. - output_x (Tensor) - The minimum value of input tensor, with the same shape as index.
  1101. Examples:
  1102. >>> input_x = Tensor(np.random.rand(5), mindspore.float32)
  1103. >>> index, output = P.ArgMinWithValue()(input_x)
  1104. 0 0.0496291
  1105. """
  1106. @prim_attr_register
  1107. def __init__(self, axis=0, keep_dims=False):
  1108. """Initialize ArgMinWithValue"""
  1109. self.axis = axis
  1110. self.keep_dims = keep_dims
  1111. validator.check_value_type('keep_dims', keep_dims, [bool], self.name)
  1112. validator.check_value_type('axis', axis, [int], self.name)
  1113. def infer_shape(self, x_shape):
  1114. axis = self.axis
  1115. x_rank = len(x_shape)
  1116. validator.check_int_range(axis, -x_rank, x_rank, Rel.INC_LEFT, "axis", self.name)
  1117. ouput_shape = _infer_shape_reduce(x_shape, self.axis, self.keep_dims, self.name)
  1118. return ouput_shape, ouput_shape
  1119. def infer_dtype(self, x_dtype):
  1120. validator.check_subclass("input_x", x_dtype, mstype.tensor, self.name)
  1121. return mstype.tensor_type(mstype.int32), x_dtype
  1122. class Tile(PrimitiveWithInfer):
  1123. r"""
  1124. Replicates a tensor with given multiples times.
  1125. Creates a new tensor by replicating input multiples times. The dimension of
  1126. output tensor is the larger of the input tensor dimension and the length of `multiples`.
  1127. Inputs:
  1128. - **input_x** (Tensor) - 1-D or higher Tensor. Set the shape of input tensor as
  1129. :math:`(x_1, x_2, ..., x_S)`.
  1130. - **multiples** (tuple[int]) - The input tuple is constructed by multiple
  1131. integers, i.e., :math:`(y_1, y_2, ..., y_S)`. The length of `multiples`
  1132. cannot be smaller than the length of the shape of `input_x`.
  1133. Only constant value is allowed.
  1134. Outputs:
  1135. Tensor, has the same data type as the `input_x`.
  1136. - If the length of `multiples` is the same as the length of shape of `input_x`,
  1137. then the shape of their corresponding positions can be multiplied, and
  1138. the shape of Outputs is :math:`(x_1*y_1, x_2*y_2, ..., x_S*y_R)`.
  1139. - If the length of `multiples` is larger than the length of shape of `input_x`,
  1140. fill in multiple 1 in the length of the shape of `input_x` until their lengths are consistent.
  1141. Such as set the shape of `input_x` as :math:`(1, ..., x_1, x_2, ..., x_S)`,
  1142. then the shape of their corresponding positions can be multiplied, and
  1143. the shape of Outputs is :math:`(1*y_1, ..., x_S*y_R)`.
  1144. Examples:
  1145. >>> tile = P.Tile()
  1146. >>> input_x = Tensor(np.array([[1, 2], [3, 4]]), mindspore.float32)
  1147. >>> multiples = (2, 3)
  1148. >>> result = tile(input_x, multiples)
  1149. [[1. 2. 1. 2. 1. 2.]
  1150. [3. 4. 3. 4. 3. 4.]
  1151. [1. 2. 1. 2. 1. 2.]
  1152. [3. 4. 3. 4. 3. 4.]]
  1153. """
  1154. @prim_attr_register
  1155. def __init__(self):
  1156. """Initialize Tile"""
  1157. self.init_prim_io_names(inputs=['x', 'multiples'], outputs=['output'])
  1158. def check_elim(self, base_tensor, multiplier):
  1159. if (not isinstance(base_tensor, Tensor)) or (not isinstance(multiplier, tuple)):
  1160. raise TypeError("Expecting (Tensor, tuple), got: ({}, {})".format(base_tensor, multiplier))
  1161. if all(v == 1 for v in multiplier):
  1162. return (True, base_tensor)
  1163. return (False, None)
  1164. def __infer__(self, x, multiples):
  1165. multiples_v = multiples['value']
  1166. x_shp = x['shape']
  1167. validator.check_value_type("multiples", multiples_v, [tuple], self.name)
  1168. for i, multiple in enumerate(multiples_v):
  1169. validator.check_value_type("multiples[%d]" % i, multiple, [int], self.name)
  1170. validator.check_value_type("x[\'dtype\']", x["dtype"], mstype.tensor_type, self.name)
  1171. len_sub = len(multiples_v) - len(x_shp)
  1172. multiples_w = None
  1173. if len_sub == 0:
  1174. multiples_w = multiples_v
  1175. if len_sub > 0:
  1176. for i in range(0, len_sub):
  1177. x_shp.insert(0, 1)
  1178. multiples_w = multiples_v
  1179. elif len_sub < 0:
  1180. raise ValueError(f'For \'{self.name}\' the length of multiples can not be smaller than '
  1181. f'the length of dimension in input_x.')
  1182. for i, a in enumerate(multiples_w):
  1183. x_shp[i] *= a
  1184. value = None
  1185. if x['value'] is not None:
  1186. value = Tensor(np.tile(x['value'].asnumpy(), multiples_w))
  1187. return {'shape': x_shp,
  1188. 'dtype': x['dtype'],
  1189. 'value': value}
  1190. class UnsortedSegmentSum(PrimitiveWithInfer):
  1191. r"""
  1192. Computes the sum along segments of a tensor.
  1193. Calculates a tensor such that :math:`\text{output}[i] = \sum_{segment\_ids[j] == i} \text{data}[j, \ldots]`, where
  1194. :math:`j` is a tuple describing the index of element in data. `segment_ids` selects which elements in data to sum
  1195. up. Segment_ids does not need to be sorted, and it does not need to cover all values in the entire valid value
  1196. range.
  1197. If the sum of the given segment_ids :math:`i` is empty, then :math:`\text{output}[i] = 0`. If the given segment_ids
  1198. is negative, the value will be ignored. 'num_segments' must be equal to the number of different segment_ids.
  1199. Inputs:
  1200. - **input_x** (Tensor) - The shape is :math:`(x_1, x_2, ..., x_R)`.
  1201. - **segment_ids** (Tensor) - Set the shape as :math:`(x_1, x_2, ..., x_N)`, where 0 < N <= R. Type must be int.
  1202. - **num_segments** (int) - Set :math:`z` as num_segments.
  1203. Outputs:
  1204. Tensor, the shape is :math:`(z, x_{N+1}, ..., x_R)`.
  1205. Examples:
  1206. >>> input_x = Tensor([1, 2, 3, 4], mindspore.float32)
  1207. >>> segment_ids = Tensor([0, 0, 1, 2], mindspore.int32)
  1208. >>> num_segments = 4
  1209. >>> P.UnsortedSegmentSum()(input_x, segment_ids, num_segments)
  1210. [3, 3, 4, 0]
  1211. """
  1212. @prim_attr_register
  1213. def __init__(self):
  1214. """Initialize UnsortedSegmentSum"""
  1215. self.init_prim_io_names(inputs=['x', 'segment_ids', 'num_segments'], outputs=['y'])
  1216. self.add_prim_attr("dynamic_shape_depends", [2,])
  1217. def __infer__(self, x, segment_ids, num_segments):
  1218. x_type = x['dtype']
  1219. x_shp = x['shape']
  1220. validator.check_subclass("input_x", x_type, mstype.tensor, self.name)
  1221. validator.check_value_type("x_shape", x_shp, [list], self.name)
  1222. x_shp_len = len(x_shp)
  1223. validator.check_positive_int(x_shp_len, "rank of input_x", self.name)
  1224. segment_ids_shp = segment_ids['shape']
  1225. segment_ids_type = segment_ids['dtype']
  1226. validator.check_subclass("segment_ids", segment_ids_type, mstype.tensor, self.name)
  1227. validator.check_value_type("segment_ids", segment_ids_shp, [list], self.name)
  1228. segment_ids_shp_len = len(segment_ids_shp)
  1229. validator.check_positive_int(segment_ids_shp_len, "rank of segment_ids", self.name)
  1230. validator.check(f'rank of input_x', len(x_shp),
  1231. 'rank of segments_id', len(segment_ids_shp), Rel.GE, self.name)
  1232. for i, value in enumerate(segment_ids_shp):
  1233. validator.check("ids[%d]" % i, value, 'input[%d]' % i, x_shp[i], Rel.EQ, self.name)
  1234. num_segments_v = num_segments['value']
  1235. num_segments_type = num_segments['dtype']
  1236. validator.check_subclass("num_segments", num_segments_type, [mstype.tensor, mstype.number], self.name)
  1237. if isinstance(num_segments_type, type(mstype.tensor)):
  1238. validator.check_tensor_type_same({"num_segments": num_segments_type}, [mstype.int32], self.name)
  1239. shp = [-1]
  1240. else:
  1241. validator.check_value_type('num_segments', num_segments_v, [int], self.name)
  1242. validator.check_positive_int(num_segments_v, "num_segments", self.name)
  1243. shp = [num_segments_v]
  1244. shp += x_shp[segment_ids_shp_len:]
  1245. if 'max_shape' in x:
  1246. output_max_shape = x['max_shape']
  1247. else:
  1248. output_max_shape = x_shp
  1249. out = {'shape': shp,
  1250. 'max_shape': output_max_shape,
  1251. 'min_shape': [1] * segment_ids_shp_len + x_shp[segment_ids_shp_len:],
  1252. 'dtype': mstype.tensor_type(x_type.element_type()),
  1253. 'value': None}
  1254. return out
  1255. class UnsortedSegmentMin(PrimitiveWithInfer):
  1256. """
  1257. Computes the minimum along segments of a tensor.
  1258. Inputs:
  1259. - **input_x** (Tensor) - The shape is :math:`(x_1, x_2, ..., x_R)`.
  1260. The data type must be float16, float32 or int32.
  1261. - **segment_ids** (Tensor) - A `1-D` tensor whose shape is :math:`(x_1)`, the value must be >= 0.
  1262. The data type must be int32.
  1263. - **num_segments** (int) - The value spcifies the number of distinct `segment_ids`.
  1264. Outputs:
  1265. Tensor, set the number of `num_segments` as `N`, the shape is :math:`(N, x_2, ..., x_R)`.
  1266. Examples:
  1267. >>> input_x = Tensor(np.array([[1, 2, 3], [4, 5, 6], [4, 2, 1]]).astype(np.float32))
  1268. >>> segment_ids = Tensor(np.array([0, 1, 1]).astype(np.int32))
  1269. >>> num_segments = 2
  1270. >>> unsorted_segment_min = P.UnsortedSegmentMin()
  1271. >>> unsorted_segment_min(input_x, segment_ids, num_segments)
  1272. [[1., 2., 3.], [4., 2., 1.]]
  1273. """
  1274. @prim_attr_register
  1275. def __init__(self):
  1276. """Initialize UnsortedSegmentMin"""
  1277. self.init_prim_io_names(inputs=['x', 'segment_ids', 'num_segments'], outputs=['y'])
  1278. def __infer__(self, x, segment_ids, num_segments):
  1279. x_type = x['dtype']
  1280. x_shape = x['shape']
  1281. segment_ids_shape = segment_ids['shape']
  1282. valid_type = [mstype.float16, mstype.float32, mstype.int32]
  1283. validator.check_tensor_type_same({"x": x['dtype']}, valid_type, self.name)
  1284. validator.check_tensor_type_same({"segment_ids": segment_ids['dtype']}, [mstype.int32], self.name)
  1285. validator.check_equal_int(len(segment_ids_shape), 1, "rank of segment_ids_shape", self.name)
  1286. validator.check(f'first shape of input_x', x_shape[0],
  1287. 'length of segments_id', segment_ids_shape[0], Rel.EQ, self.name)
  1288. num_segments_v = num_segments['value']
  1289. validator.check_value_type('num_segments', num_segments_v, [int], self.name)
  1290. validator.check_positive_int(num_segments_v, "num_segments", self.name)
  1291. segment_ids_shape_len = len(segment_ids_shape)
  1292. out_shape = [num_segments_v]
  1293. out_shape += x_shape[segment_ids_shape_len:]
  1294. out = {'shape': out_shape,
  1295. 'dtype': x_type,
  1296. 'value': None}
  1297. return out
  1298. class UnsortedSegmentProd(PrimitiveWithInfer):
  1299. """
  1300. Computes the product along segments of a tensor.
  1301. Inputs:
  1302. - **input_x** (Tensor) - The shape is :math:`(x_1, x_2, ..., x_R)`.
  1303. With float16, float32 or int32 data type.
  1304. - **segment_ids** (Tensor) - A `1-D` tensor whose shape is :math:`(x_1)`, the value must be >= 0.
  1305. Data type must be int32.
  1306. - **num_segments** (int) - The value spcifies the number of distinct `segment_ids`,
  1307. must be greater than 0.
  1308. Outputs:
  1309. Tensor, set the number of `num_segments` as `N`, the shape is :math:`(N, x_2, ..., x_R)`.
  1310. Examples:
  1311. >>> input_x = Tensor(np.array([[1, 2, 3], [4, 5, 6], [4, 2, 1]]).astype(np.float32))
  1312. >>> segment_ids = Tensor(np.array([0, 1, 0]).astype(np.int32))
  1313. >>> num_segments = 2
  1314. >>> unsorted_segment_prod = P.UnsortedSegmentProd()
  1315. >>> unsorted_segment_prod(input_x, segment_ids, num_segments)
  1316. [[4., 4., 3.], [4., 5., 6.]]
  1317. """
  1318. @prim_attr_register
  1319. def __init__(self):
  1320. """Initialize UnsortedSegmentProd"""
  1321. self.init_prim_io_names(inputs=['x', 'segment_ids', 'num_segments'], outputs=['y'])
  1322. def __infer__(self, x, segment_ids, num_segments):
  1323. x_type = x['dtype']
  1324. x_shape = x['shape']
  1325. segment_ids_shape = segment_ids['shape']
  1326. validator.check_subclass("input_x", x_type, mstype.tensor, self.name)
  1327. validator.check_value_type("x_shape", x_shape, [list], self.name)
  1328. valid_type = [mstype.float16, mstype.float32, mstype.int32]
  1329. validator.check_tensor_type_same({"x": x['dtype']}, valid_type, self.name)
  1330. validator.check_tensor_type_same({"segment_ids": segment_ids['dtype']}, [mstype.int32], self.name)
  1331. validator.check_equal_int(len(segment_ids_shape), 1, "rank of segment_ids_shape", self.name)
  1332. validator.check(f'first shape of input_x', x_shape[0],
  1333. 'length of segments_id', segment_ids_shape[0], Rel.EQ, self.name)
  1334. num_segments_v = num_segments['value']
  1335. validator.check_value_type('num_segments', num_segments_v, [int], self.name)
  1336. validator.check_positive_int(num_segments_v, "num_segments", self.name)
  1337. segment_ids_shape_len = len(segment_ids_shape)
  1338. out_shape = [num_segments_v]
  1339. out_shape += x_shape[segment_ids_shape_len:]
  1340. out = {'shape': out_shape,
  1341. 'dtype': mstype.tensor_type(x_type.element_type()),
  1342. 'value': None}
  1343. return out
  1344. class Concat(PrimitiveWithInfer):
  1345. r"""
  1346. Concats tensor in specified axis.
  1347. Concats input tensors along with the given axis.
  1348. Note:
  1349. The input data is a tuple of tensors. These tensors have the same rank `R`. Set the given axis as `m`, and
  1350. :math:`0 \le m < R`. Set the number of input tensors as `N`. For the :math:`i`-th tensor :math:`t_i`, it has
  1351. the shape of :math:`(x_1, x_2, ..., x_{mi}, ..., x_R)`. :math:`x_{mi}` is the :math:`m`-th dimension of the
  1352. :math:`i`-th tensor. Then, the shape of the output tensor is
  1353. .. math::
  1354. (x_1, x_2, ..., \sum_{i=1}^Nx_{mi}, ..., x_R)
  1355. Args:
  1356. axis (int): The specified axis. Default: 0.
  1357. Inputs:
  1358. - **input_x** (tuple, list) - A tuple or a list of input tensors.
  1359. Outputs:
  1360. Tensor, the shape is :math:`(x_1, x_2, ..., \sum_{i=1}^Nx_{mi}, ..., x_R)`.
  1361. Examples:
  1362. >>> data1 = Tensor(np.array([[0, 1], [2, 1]]).astype(np.int32))
  1363. >>> data2 = Tensor(np.array([[0, 1], [2, 1]]).astype(np.int32))
  1364. >>> op = P.Concat()
  1365. >>> output = op((data1, data2))
  1366. [[0, 1],
  1367. [2, 1],
  1368. [0, 1],
  1369. [2, 1]]
  1370. """
  1371. @prim_attr_register
  1372. def __init__(self, axis=0):
  1373. """Initialize Tile"""
  1374. validator.check_value_type("axis", axis, [int], self.name)
  1375. def __infer__(self, input_x):
  1376. axis = self.axis
  1377. x_shp = input_x['shape']
  1378. x_type = input_x['dtype']
  1379. _, all_shp, _ = get_concat_offset(x_shp, x_type, axis, self.name)
  1380. self.add_prim_attr('T', x_type[0].element_type())
  1381. self.add_prim_attr('inputNums', len(x_shp))
  1382. ret_shp = x_shp[0].copy()
  1383. ret_shp[axis] = all_shp
  1384. out = {'shape': ret_shp,
  1385. 'dtype': x_type[0],
  1386. 'value': None}
  1387. return out
  1388. class ParallelConcat(PrimitiveWithInfer):
  1389. r"""
  1390. Concats tensor in the first dimension.
  1391. Concats input tensors along with the first dimension.
  1392. Note:
  1393. The input tensors are all required to have size 1 in the first dimension.
  1394. Inputs:
  1395. - **values** (tuple, list) - A tuple or a list of input tensors. The data type and shape of these
  1396. tensors must be the same.
  1397. Outputs:
  1398. Tensor, data type is the same as `values`.
  1399. Examples:
  1400. >>> data1 = Tensor(np.array([[0, 1]]).astype(np.int32))
  1401. >>> data2 = Tensor(np.array([[2, 1]]).astype(np.int32))
  1402. >>> op = P.ParallelConcat()
  1403. >>> output = op((data1, data2))
  1404. [[0, 1], [2, 1]]
  1405. """
  1406. @prim_attr_register
  1407. def __init__(self):
  1408. """Initialize ParallelConcat"""
  1409. def __infer__(self, values):
  1410. x_shp = values['shape']
  1411. x_type = values['dtype']
  1412. validator.check_int(len(x_shp), 1, Rel.GE, f'x_shp length', self.name)
  1413. args = {f"x_type[{i}]": elem for i, elem in enumerate(x_type)}
  1414. validator.check_tensor_type_same(args, mstype.number_type + (mstype.bool_,), self.name)
  1415. first_elem = x_shp[0]
  1416. for i, elem in enumerate(x_shp[1:]):
  1417. j = i + 1
  1418. validator.check_equal_int(elem[0], 1, f'x_shp[{j}][0]', self.name)
  1419. validator.check(f"x_shp[0] shape", first_elem, f"x_shp[{j}] shape", elem, Rel.EQ, self.name)
  1420. ret_shp = x_shp[0].copy()
  1421. ret_shp[0] = len(x_shp)
  1422. self.add_prim_attr('shape', ret_shp)
  1423. self.add_prim_attr('N', len(x_shp))
  1424. out = {'shape': ret_shp,
  1425. 'dtype': x_type[0],
  1426. 'value': None}
  1427. return out
  1428. def _get_pack_shape(x_shape, x_type, axis, prim_name):
  1429. """for pack output shape"""
  1430. validator.check_value_type("shape", x_shape, [tuple, list], prim_name)
  1431. validator.check_int(len(x_shape), 1, Rel.GE, "len of input_x", prim_name)
  1432. validator.check_subclass("input_x[0]", x_type[0], mstype.tensor, prim_name)
  1433. rank_base = len(x_shape[0])
  1434. N = len(x_shape)
  1435. out_shape = x_shape[0]
  1436. validator.check_int_range(axis, -rank_base - 1, rank_base, Rel.INC_BOTH, 'axis', prim_name)
  1437. if axis < 0:
  1438. axis = axis + rank_base + 1
  1439. for i in range(1, N):
  1440. validator.check('x_type[%d]' % i, x_type[i], 'base', x_type[0], Rel.EQ, prim_name, TypeError)
  1441. if x_shape[i] != x_shape[0]:
  1442. raise ValueError(f"For \'{prim_name}\' element {i} shape in input can not pack with first element")
  1443. out_shape.insert(axis, N)
  1444. return out_shape
  1445. class Pack(PrimitiveWithInfer):
  1446. r"""
  1447. Packs a list of tensors in specified axis.
  1448. Packs the list of input tensors with the same rank `R`, output is a tensor of rank `(R+1)`.
  1449. Given input tensors of shape :math:`(x_1, x_2, ..., x_R)`. Set the number of input tensors as `N`.
  1450. If :math:`0 \le axis`, the shape of the output tensor is :math:`(x_1, x_2, ..., x_{axis}, N, x_{axis+1}, ..., x_R)`.
  1451. Args:
  1452. axis (int): Dimension to pack. Default: 0.
  1453. Negative values wrap around. The range is [-(R+1), R+1).
  1454. Inputs:
  1455. - **input_x** (Union[tuple, list]) - A Tuple or list of Tensor objects with the same shape and type.
  1456. Outputs:
  1457. Tensor. A packed Tensor with the same type as `input_x`.
  1458. Raises:
  1459. TypeError: If the data types of elements in `input_x` are not the same.
  1460. ValueError: If the length of `input_x` is not greater than 1;
  1461. or if axis is out of the range [-(R+1), R+1);
  1462. or if the shapes of elements in input_x are not the same.
  1463. Examples:
  1464. >>> data1 = Tensor(np.array([0, 1]).astype(np.float32))
  1465. >>> data2 = Tensor(np.array([2, 3]).astype(np.float32))
  1466. >>> pack = P.Pack()
  1467. >>> output = pack([data1, data2])
  1468. [[0, 1], [2, 3]]
  1469. """
  1470. @prim_attr_register
  1471. def __init__(self, axis=0):
  1472. """Initialize Pack"""
  1473. validator.check_value_type("axis", axis, [int], self.name)
  1474. self.axis = axis
  1475. def __infer__(self, value):
  1476. x_shape = value['shape']
  1477. x_type = value['dtype']
  1478. self.add_prim_attr('num', len(x_shape))
  1479. all_shape = _get_pack_shape(x_shape, x_type, self.axis, self.name)
  1480. out = {'shape': all_shape,
  1481. 'dtype': x_type[0],
  1482. 'value': None}
  1483. return out
  1484. class Unpack(PrimitiveWithInfer):
  1485. r"""
  1486. Unpacks tensor in specified axis.
  1487. Unpacks a tensor of rank `R` along axis dimension, output tensors will have rank `(R-1)`.
  1488. Given a tensor of shape :math:`(x_1, x_2, ..., x_R)`. If :math:`0 \le axis`,
  1489. the shape of tensor in output is :math:`(x_1, x_2, ..., x_{axis}, x_{axis+2}, ..., x_R)`.
  1490. This is the opposite of pack.
  1491. Args:
  1492. axis (int): Dimension along which to pack. Default: 0.
  1493. Negative values wrap around. The range is [-R, R).
  1494. Inputs:
  1495. - **input_x** (Tensor) - The shape is :math:`(x_1, x_2, ..., x_R)`.
  1496. A tensor to be unpacked and the rank of the tensor must be greater than 0.
  1497. Outputs:
  1498. A tuple of tensors, the shape of each objects is the same.
  1499. Raises:
  1500. ValueError: If axis is out of the range [-len(input_x.shape), len(input_x.shape)).
  1501. Examples:
  1502. >>> unpack = P.Unpack()
  1503. >>> input_x = Tensor(np.array([[1, 1, 1, 1], [2, 2, 2, 2]]))
  1504. >>> output = unpack(input_x)
  1505. ([1, 1, 1, 1], [2, 2, 2, 2])
  1506. """
  1507. @prim_attr_register
  1508. def __init__(self, axis=0):
  1509. """Initialize Unpack"""
  1510. validator.check_value_type("axis", axis, [int], self.name)
  1511. self.axis = axis
  1512. def __infer__(self, x):
  1513. validator.check_subclass("x", x['dtype'], mstype.tensor, self.name)
  1514. x_shape = list(x['shape'])
  1515. dim = len(x_shape)
  1516. validator.check_int_range(self.axis, -dim, dim, Rel.INC_LEFT, 'axis value', self.name)
  1517. if self.axis < 0:
  1518. self.axis = self.axis + dim
  1519. output_num = x_shape[self.axis]
  1520. validator.check_value_type("num", output_num, [int], self.name)
  1521. validator.check_positive_int(output_num, "output_num", self.name)
  1522. self.add_prim_attr('num', output_num)
  1523. output_valid_check = x_shape[self.axis] - output_num
  1524. validator.check_int(output_valid_check, 0, Rel.EQ,
  1525. "The dimension which to unpack divides output_num", self.name)
  1526. out_shapes = []
  1527. out_dtypes = []
  1528. out_shape = x_shape[:self.axis] + x_shape[self.axis + 1:]
  1529. for _ in range(output_num):
  1530. out_shapes.append(tuple(out_shape))
  1531. out_dtypes.append(x['dtype'])
  1532. out_shapes = tuple(out_shapes)
  1533. out_dtypes = tuple(out_dtypes)
  1534. out = {'shape': out_shapes,
  1535. 'dtype': out_dtypes,
  1536. 'value': None}
  1537. return out
  1538. class Slice(PrimitiveWithInfer):
  1539. """
  1540. Slices a tensor in the specified shape.
  1541. Args:
  1542. x (Tensor): The target tensor.
  1543. begin (tuple): The beginning of the slice. Only constant value is allowed.
  1544. size (tuple): The size of the slice. Only constant value is allowed.
  1545. Returns:
  1546. Tensor.
  1547. Examples:
  1548. >>> data = Tensor(np.array([[[1, 1, 1], [2, 2, 2]],
  1549. >>> [[3, 3, 3], [4, 4, 4]],
  1550. >>> [[5, 5, 5], [6, 6, 6]]]).astype(np.int32))
  1551. >>> type = P.Slice()(data, (1, 0, 0), (1, 1, 3))
  1552. [[[3 3 3]]]
  1553. """
  1554. @prim_attr_register
  1555. def __init__(self):
  1556. """Initialize slice"""
  1557. self.init_prim_io_names(inputs=['x', 'begin', 'size'], outputs=['output'])
  1558. def __infer__(self, x, begin, size):
  1559. x_shape = x['shape']
  1560. x_shp_len = len(x_shape)
  1561. validator.check_const_input('begin', begin['value'], self.name)
  1562. validator.check_const_input('size', size['value'], self.name)
  1563. begin_v, size_v = begin['value'], size['value']
  1564. if begin_v is None or size_v is None:
  1565. return {'shape': None,
  1566. 'dtype': x['dtype'],
  1567. 'value': None}
  1568. for key, value in zip(('begin', 'size'), (begin_v, size_v)):
  1569. validator.check(f'len of {key}', len(value),
  1570. 'len x\'s dim', x_shp_len)
  1571. for i in range(x_shp_len):
  1572. if x_shape[i] < begin_v[i] + size_v[i]:
  1573. y = begin_v[i] + size_v[i]
  1574. raise ValueError("For '%s' slice shape can not bigger than orign shape %d, %d." %
  1575. (self.name, x_shape[i], y))
  1576. return {'shape': size_v,
  1577. 'dtype': x['dtype'],
  1578. 'value': None}
  1579. class ReverseV2(PrimitiveWithInfer):
  1580. """
  1581. Reverses specific dimensions of a tensor.
  1582. Args:
  1583. axis (Union[tuple(int), list(int)): The indices of the dimensions to reverse.
  1584. Inputs:
  1585. - **input_x** (Tensor) - The target tensor.
  1586. Outputs:
  1587. Tensor, has the same shape and type as `input_x`.
  1588. Examples:
  1589. >>> input_x = Tensor(np.array([[1, 2, 3, 4], [5, 6, 7, 8]]), mindspore.int32)
  1590. >>> op = P.ReverseV2(axis=[1])
  1591. >>> output = op(input_x)
  1592. [[4, 3, 2, 1], [8, 7, 6, 5]]
  1593. """
  1594. @prim_attr_register
  1595. def __init__(self, axis):
  1596. validator.check_value_type('axis', axis, [list, tuple], self.name)
  1597. for i, each in enumerate(axis):
  1598. validator.check_value_type(f'axis[{i}]', each, [int], self.name)
  1599. self.axis = axis
  1600. self.init_prim_io_names(inputs=['x'], outputs=['output'])
  1601. def infer_shape(self, x_shape):
  1602. dim = len(x_shape)
  1603. for i, each in enumerate(self.axis):
  1604. validator.check_int_range(each, -dim, dim, Rel.INC_LEFT, f'axis[{i}]', self.name)
  1605. return x_shape
  1606. def infer_dtype(self, x_dtype):
  1607. validator.check_tensor_type_same({'x': x_dtype}, (mstype.bool_,) + mstype.number_type, self.name)
  1608. return x_dtype
  1609. class Rint(PrimitiveWithInfer):
  1610. """
  1611. Returns element-wise integer closest to x.
  1612. Inputs:
  1613. - **input_x** (Tensor) - The target tensor, which must be one of the following types:
  1614. float16, float32.
  1615. Outputs:
  1616. Tensor, has the same shape and type as `input_x`.
  1617. Examples:
  1618. >>> input_x = Tensor(np.array([-1.6, -0.1, 1.5, 2.0]), mindspore.float32)
  1619. >>> op = P.Rint()
  1620. >>> output = op(input_x)
  1621. [-2., 0., 2., 2.]
  1622. """
  1623. @prim_attr_register
  1624. def __init__(self):
  1625. self.init_prim_io_names(inputs=['x'], outputs=['output'])
  1626. def infer_shape(self, x_shape):
  1627. return x_shape
  1628. def infer_dtype(self, x_dtype):
  1629. validator.check_tensor_type_same({'x': x_dtype}, [mstype.float16, mstype.float32], self.name)
  1630. return x_dtype
  1631. class Select(PrimitiveWithInfer):
  1632. r"""
  1633. Returns the selected elements, either from input :math:`x` or input :math:`y`, depending on the `condition`.
  1634. Given a tensor as input, this operation inserts a dimension of 1 at the dimension,
  1635. if both :math:`x` and :math:`y` are none, the operation returns the coordinates of the true
  1636. element in the `condition`, the coordinates are returned as a two-dimensional
  1637. tensor, where the first dimension (row) represents the number of true elements
  1638. and the second dimension (columns) represents the coordinates of the true
  1639. elements. Keep in mind that the shape of the output tensor can vary depending
  1640. on how many true values are in the input. Indexes are output in row-first
  1641. order.
  1642. If neither is None, :math:`x` and :math:`y` must have the same shape. If :math:`x` and :math:`y` are
  1643. scalars, the conditional tensor must be a scalar. If :math:`x` and :math:`y` are
  1644. higher-demensional vectors, the `condition` must be a vector whose size matches the
  1645. first dimension of :math:`x`, or must have the same shape as :math:`y`.
  1646. The conditional tensor acts as an optional compensation (mask), which
  1647. determines whether the corresponding element / row in the output must be
  1648. selected from :math:`x` (if true) or :math:`y` (if false) based on the value of each
  1649. element.
  1650. If condition is a vector, then :math:`x` and :math:`y` are higher-demensional matrices, then it
  1651. chooses to copy that row (external dimensions) from :math:`x` and :math:`y`. If condition has
  1652. the same shape as :math:`x` and :math:`y`, you can choose to copy these elements from :math:`x`
  1653. and :math:`y`.
  1654. Inputs:
  1655. - **input_cond** (Tensor[bool]) - The shape is :math:`(x_1, x_2, ..., x_N, ..., x_R)`.
  1656. The condition tensor, decides which element is chosen.
  1657. - **input_x** (Tensor) - The shape is :math:`(x_1, x_2, ..., x_N, ..., x_R)`.
  1658. The first input tensor.
  1659. - **input_y** (Tensor) - The shape is :math:`(x_1, x_2, ..., x_N, ..., x_R)`.
  1660. The second input tensor.
  1661. Outputs:
  1662. Tensor, has the same shape as `input_x`. The shape is :math:`(x_1, x_2, ..., x_N, ..., x_R)`.
  1663. Examples:
  1664. >>> select = P.Select()
  1665. >>> input_cond = Tensor([True, False])
  1666. >>> input_x = Tensor([2,3], mindspore.float32)
  1667. >>> input_y = Tensor([1,2], mindspore.float32)
  1668. >>> select(input_cond, input_x, input_y)
  1669. [2. 2.]
  1670. """
  1671. @prim_attr_register
  1672. def __init__(self):
  1673. """init"""
  1674. self.init_prim_io_names(inputs=['condition', 'x', 'y'], outputs=['output'])
  1675. def infer_shape(self, cond_shape, x_shape, y_shape):
  1676. if cond_shape != x_shape or x_shape != y_shape:
  1677. raise ValueError('The x_shape and y_shape must be the same as cond_shape.')
  1678. return x_shape
  1679. def infer_dtype(self, cond_type, x_type, y_type):
  1680. self.add_prim_attr('T', x_type)
  1681. validator.check_subclass("x_type", x_type, mstype.tensor, self.name)
  1682. validator.check_subclass("y_type", y_type, mstype.tensor, self.name)
  1683. validator.check_tensor_type_same({"cond": cond_type}, [mstype.bool_], self.name)
  1684. if x_type != y_type:
  1685. raise TypeError('\'%s\' the x_type %s must be the same as y_type %s.' % (self.name, x_type, y_type))
  1686. return x_type
  1687. def infer_value(self, cond, x, y):
  1688. if cond is not None and x is not None and y is not None:
  1689. cond = cond.asnumpy()
  1690. x = x.asnumpy()
  1691. y = y.asnumpy()
  1692. out = np.where(cond, x, y)
  1693. return Tensor(out)
  1694. return None
  1695. def _compute_slicing_length(begin, end, stride, x_shape, i):
  1696. """Computes the length of the slicing."""
  1697. if i >= len(x_shape):
  1698. raise ValueError(f"For 'StridedSlice', When their is no new axis, the index length must be less or "
  1699. f"equal than the dim of x.")
  1700. x_dim = x_shape[i]
  1701. if stride > 0:
  1702. # When slicing forward, convert begin and end to positive numbers.
  1703. if begin >= x_dim or end < -x_dim:
  1704. # When slicing forward, if begin >= x_dim or end < -x_dim, the length of the slicing is 0.
  1705. slicing_length = 0
  1706. else:
  1707. if -x_dim <= begin < 0:
  1708. begin += x_dim
  1709. if begin < -x_dim:
  1710. # When slicing forward, if begin < -x_dim, set begin = 0, which means start from the 0th element.
  1711. begin = 0
  1712. if -x_dim <= end < 0:
  1713. end += x_dim
  1714. if end > x_dim:
  1715. # When slicing forward, if end > x_dim, set end = x_dims, which means slice to the last element.
  1716. end = x_dim
  1717. if begin >= end:
  1718. # When slicing forward, if begin >= end, the length of the slicing is 0.
  1719. slicing_length = 0
  1720. else:
  1721. slicing_length = 1 + (end - 1 - begin) // stride
  1722. else:
  1723. # When slicing backward, convert begin and end to negative numbers.
  1724. if begin < -x_dim or end >= x_dim:
  1725. # When slicing backward, if begin < -x_dim or end >= x_dim, the length of the slicing is 0.
  1726. slicing_length = 0
  1727. else:
  1728. if 0 <= begin < x_dim:
  1729. begin += -x_dim
  1730. if begin >= x_dim:
  1731. begin = -1
  1732. if 0 <= end < x_dim:
  1733. end += -x_dim
  1734. if end < -x_dim - 1:
  1735. # When slicing backward, if end < -x_dim - 1, set end = -x_dim - 1, which means
  1736. # slicing to the 0th element.
  1737. end = -x_dim - 1
  1738. if begin <= end:
  1739. # When slicing backward, if begin <= end, the length of the slicing is 0.
  1740. slicing_length = 0
  1741. else:
  1742. slicing_length = 1 + (end + 1 - begin) // stride
  1743. return slicing_length
  1744. class StridedSlice(PrimitiveWithInfer):
  1745. r"""
  1746. Extracts a strided slice of a tensor.
  1747. Given an input tensor, this operation inserts a dimension of length 1 at the dimension.
  1748. This operation extracts a fragment of size (end-begin)/stride from the given 'input_tensor'.
  1749. Starting from the begining position, the fragment continues adding stride to the index until
  1750. all dimensions are not less than the ending position.
  1751. Note:
  1752. The stride may be negative value, which causes reverse slicing.
  1753. The shape of `begin`, `end` and `strides` must be the same.
  1754. Args:
  1755. begin_mask (int): Starting index of the slice. Default: 0.
  1756. end_mask (int): Ending index of the slice. Default: 0.
  1757. ellipsis_mask (int): An int mask. Default: 0.
  1758. new_axis_mask (int): An int mask. Default: 0.
  1759. shrink_axis_mask (int): An int mask. Default: 0.
  1760. Inputs:
  1761. - **input_x** (Tensor) - The input Tensor.
  1762. - **begin** (tuple[int]) - A tuple which represents the location where to start. Only
  1763. constant value is allowed.
  1764. - **end** (tuple[int]) - A tuple or which represents the maximum location where to end.
  1765. Only constant value is allowed.
  1766. - **strides** (tuple[int]) - A tuple which represents the stride is continuously added
  1767. before reaching the maximum location. Only constant value is allowed.
  1768. Outputs:
  1769. Tensor.
  1770. The output is explained by following example.
  1771. - In the 0th dimension, begin is 1, end is 2, and strides is 1,
  1772. because :math:`1+1=2\geq2`, the interval is :math:`[1,2)`.
  1773. Thus, return the element with :math:`index = 1` in 0th dimension, i.e., [[3, 3, 3], [4, 4, 4]].
  1774. - In the 1st dimension, similarly, the interval is :math:`[0,1)`.
  1775. Based on the return value of the 0th dimension, return the element with :math:`index = 0`,
  1776. i.e., [3, 3, 3].
  1777. - In the 2nd dimension, similarly, the interval is :math:`[0,3)`.
  1778. Based on the return value of the 1st dimension, return the element with :math:`index = 0,1,2`,
  1779. i.e., [3, 3, 3].
  1780. - Finally, the output is [3, 3, 3].
  1781. Examples
  1782. >>> input_x = Tensor([[[1, 1, 1], [2, 2, 2]], [[3, 3, 3], [4, 4, 4]],
  1783. >>> [[5, 5, 5], [6, 6, 6]]], mindspore.float32)
  1784. >>> slice = P.StridedSlice()
  1785. >>> output = slice(input_x, (1, 0, 0), (2, 1, 3), (1, 1, 1))
  1786. >>> output.shape
  1787. (1, 1, 3)
  1788. >>> output
  1789. [[[3, 3, 3]]]
  1790. """
  1791. @prim_attr_register
  1792. def __init__(self,
  1793. begin_mask=0,
  1794. end_mask=0,
  1795. ellipsis_mask=0,
  1796. new_axis_mask=0,
  1797. shrink_axis_mask=0):
  1798. """Initialize StrideSlice"""
  1799. self.init_prim_io_names(inputs=['x', 'begin', 'end', 'strides'], outputs=['output'])
  1800. validator.check_non_negative_int(begin_mask, 'begin_mask', self.name)
  1801. validator.check_non_negative_int(end_mask, 'end_mask', self.name)
  1802. validator.check_non_negative_int(ellipsis_mask, 'ellipsis_mask', self.name)
  1803. if len(tuple(filter(lambda x: x == '1', bin(ellipsis_mask)[-1:1:-1]))) > 1:
  1804. raise ValueError(f"For '{self.name}', only support one ellipsis in the index, but got {end_mask}.")
  1805. validator.check_non_negative_int(new_axis_mask, 'new_axis_mask', self.name)
  1806. validator.check_non_negative_int(shrink_axis_mask, 'shrink_axis_mask', self.name)
  1807. def __infer__(self, x, begin, end, strides):
  1808. begin_v, end_v, strides_v = begin['value'], end['value'], strides['value']
  1809. validator.check_value_type("begin", begin_v, [tuple], self.name)
  1810. validator.check_value_type("end", end_v, [tuple], self.name)
  1811. validator.check_value_type("strides", strides_v, [tuple], self.name)
  1812. if tuple(filter(lambda x: not isinstance(x, int), begin_v + end_v + strides_v)):
  1813. raise ValueError(f"For {self.name}, both the begins, ends, and strides must be a tuple of int, "
  1814. f"but got begins: {begin_v}, ends: {end_v}, strides: {strides_v}.")
  1815. if tuple(filter(lambda x: x == 0, strides_v)):
  1816. raise ValueError(f"For '{self.name}', the strides cannot contain 0, but got strides: {strides_v}.")
  1817. if len(end_v) != len(begin_v) or len(strides_v) != len(begin_v):
  1818. raise ValueError(f"For '{self.name}' the length of begin index: {begin_v}, end index: {end_v} and "
  1819. f"strides: {strides_v} must be equal.")
  1820. ret_shape = self._compute_slicing_shape(x['shape'], begin_v, end_v, strides_v)
  1821. value = None if all(ret_shape) else Tensor(np.array([]).reshape(ret_shape), x['dtype'].element_type())
  1822. return {'shape': ret_shape,
  1823. 'dtype': x['dtype'],
  1824. 'value': value}
  1825. def _compute_slicing_shape(self, x_shape, begin_v, end_v, strides_v):
  1826. """Computes the shape of the slicing."""
  1827. x_rank = len(x_shape)
  1828. slice_len = len(begin_v)
  1829. # After the integer is converted to binary, it is a str and the first two chars are the flag char '0b'.
  1830. begin_pos = bin(self.begin_mask)[-1:1:-1]
  1831. end_pos = bin(self.end_mask)[-1:1:-1]
  1832. ellipsis_pos = bin(self.ellipsis_mask)[-1:1:-1]
  1833. new_axis_pos = bin(self.new_axis_mask)[-1:1:-1]
  1834. shrink_axis_pos = bin(self.shrink_axis_mask)[-1:1:-1]
  1835. ret_shape = []
  1836. i, j = 0, 0
  1837. has_ellipsis = False
  1838. while i < x_rank or j < slice_len:
  1839. if j < slice_len:
  1840. begin, end, stride = begin_v[j], end_v[j], strides_v[j]
  1841. if j < len(ellipsis_pos) and ellipsis_pos[j] == '1':
  1842. # When there is ellipsis, the latter part of the ellipsis will be processed separately.
  1843. has_ellipsis = True
  1844. break
  1845. if j < len(begin_pos) and begin_pos[j] == '1':
  1846. begin = -1 if strides_v[j] < 0 else 0
  1847. if j < len(end_pos) and end_pos[j] == '1':
  1848. end = -(x_shape[i] + 1) if strides_v[j] < 0 else x_shape[i]
  1849. if j < len(new_axis_pos) and new_axis_pos[j] == '1':
  1850. ret_shape.append(1)
  1851. j += 1
  1852. continue
  1853. if j < len(shrink_axis_pos) and shrink_axis_pos[j] == '1':
  1854. if (not -x_shape[i] <= begin < x_shape[i]) or stride < 0:
  1855. raise ValueError(f"For {self.name}, when shrink axis, the stride cannot be negative number, "
  1856. f"and begin should be in [-{x_shape[i]}, {x_shape[i]}), "
  1857. f"but got stride: {stride}, begin: {begin}.")
  1858. j += 1
  1859. i += 1
  1860. continue
  1861. else:
  1862. begin, end, stride = 0, x_shape[i], 1
  1863. slicing_length = _compute_slicing_length(begin, end, stride, x_shape, i)
  1864. ret_shape.append(slicing_length)
  1865. i += 1
  1866. j += 1
  1867. if has_ellipsis:
  1868. # When there is ellipsis, handle the second half of the ellipsis split.
  1869. ellipsis_occupied_dims = x_rank - i - (slice_len - (j + 1)) + \
  1870. len(tuple(filter(lambda x: x == '1', new_axis_pos[j + 1:slice_len])))
  1871. ret_shape.extend(x_shape[i:i + ellipsis_occupied_dims])
  1872. j += 1
  1873. i += ellipsis_occupied_dims
  1874. while i < x_rank or j < slice_len:
  1875. begin, end, stride = begin_v[j], end_v[j], strides_v[j]
  1876. if j < len(begin_pos) and begin_pos[j] == '1':
  1877. begin = -1 if strides_v[j] < 0 else 0
  1878. if j < len(end_pos) and end_pos[j] == '1':
  1879. end = -(x_shape[i] + 1) if strides_v[j] < 0 else x_shape[i]
  1880. if j < len(new_axis_pos) and new_axis_pos[j] == '1':
  1881. ret_shape.append(1)
  1882. j += 1
  1883. continue
  1884. if j < len(shrink_axis_pos) and shrink_axis_pos[j] == '1':
  1885. if (not -x_shape[i] <= begin < x_shape[i]) or stride < 0:
  1886. raise ValueError(f"For {self.name}, when shrink axis, the stride cannot be negative number, "
  1887. f"and begin should be in [-{x_shape[i]}, {x_shape[i]}), "
  1888. f"but got stride: {stride}, begin: {begin}.")
  1889. j += 1
  1890. i += 1
  1891. continue
  1892. slicing_length = _compute_slicing_length(begin, end, stride, x_shape, i)
  1893. ret_shape.append(slicing_length)
  1894. i += 1
  1895. j += 1
  1896. return ret_shape
  1897. class Diag(PrimitiveWithInfer):
  1898. r"""
  1899. Constructs a diagonal tensor with a given diagonal values.
  1900. Assume `input_x` has dimensions :math:`[D_1,... D_k]`, the output is a tensor of
  1901. rank 2k with dimensions :math:`[D_1,..., D_k, D_1,..., D_k]` where:
  1902. :math:`output[i_1,..., i_k, i_1,..., i_k] = input_x[i_1,..., i_k]` and 0 everywhere else.
  1903. Inputs:
  1904. - **input_x** (Tensor) - The input tensor. The input shape must be less than 5d.
  1905. Outputs:
  1906. Tensor, has the same dtype as the `input_x`.
  1907. Examples:
  1908. >>> input_x = Tensor([1, 2, 3, 4])
  1909. >>> diag = P.Diag()
  1910. >>> diag(input_x)
  1911. [[1, 0, 0, 0],
  1912. [0, 2, 0, 0],
  1913. [0, 0, 3, 0],
  1914. [0, 0, 0, 4]]
  1915. """
  1916. @prim_attr_register
  1917. def __init__(self):
  1918. """Initialize Diag"""
  1919. def infer_dtype(self, x_type):
  1920. validator.check_subclass('input_x', x_type, mstype.tensor, self.name)
  1921. return x_type
  1922. def infer_shape(self, x_shape):
  1923. validator.check("x rank", len(x_shape), "", 1, Rel.GE)
  1924. ret_shape = copy.deepcopy(x_shape)
  1925. ret_shape = ret_shape + ret_shape
  1926. return ret_shape
  1927. def infer_value(self, x):
  1928. if x is None:
  1929. return None
  1930. # do constant-folding only when x rank is 1
  1931. if len(x.shape) != 1:
  1932. return None
  1933. ret = np.diag(x.asnumpy())
  1934. return Tensor(ret)
  1935. class DiagPart(PrimitiveWithInfer):
  1936. r"""
  1937. Extracts the diagonal part from given tensor.
  1938. Assume input has dimensions :math:`[D_1,..., D_k, D_1,..., D_k]`, the output is a tensor
  1939. of rank k with dimensions :math:`[D_1,..., D_k]` where:
  1940. :math:`output[i_1,..., i_k] = input[i_1,..., i_k, i_1,..., i_k]`.
  1941. Inputs:
  1942. - **input_x** (Tensor) - The input Tensor.
  1943. Outputs:
  1944. Tensor.
  1945. Examples
  1946. >>> input_x = Tensor([[1, 0, 0, 0],
  1947. >>> [0, 2, 0, 0],
  1948. >>> [0, 0, 3, 0],
  1949. >>> [0, 0, 0, 4]])
  1950. >>> diag_part = P.DiagPart()
  1951. >>> diag_part(input_x)
  1952. [1, 2, 3, 4]
  1953. """
  1954. @prim_attr_register
  1955. def __init__(self):
  1956. """Initialize DiagPart"""
  1957. def infer_dtype(self, x_type):
  1958. validator.check_subclass('input_x', x_type, mstype.tensor, self.name)
  1959. return x_type
  1960. def infer_shape(self, x_shape):
  1961. if len(x_shape) % 2 != 0 or \
  1962. not x_shape:
  1963. raise ValueError(f"For \'{self.name}\' input rank must be non-zero and even, but got rank {len(x_shape)}, "
  1964. f"with shapes {x_shape}")
  1965. length = len(x_shape) // 2
  1966. for i in range(length):
  1967. validator.check('input_shape[i + len(input_shape)/2]', x_shape[i + length],
  1968. 'input_shape[i]', x_shape[i], Rel.EQ, self.name)
  1969. ret_shape = x_shape[0:length]
  1970. return ret_shape
  1971. def infer_value(self, x):
  1972. if x is None:
  1973. return None
  1974. # do constant-folding only when x rank is 2
  1975. if len(x.shape) != 2:
  1976. return None
  1977. ret = np.diag(x.asnumpy())
  1978. return Tensor(ret)
  1979. class Eye(PrimitiveWithInfer):
  1980. """
  1981. Creates a tensor with ones on the diagonal and zeros the rest.
  1982. Inputs:
  1983. - **n** (int) - The number of rows of returned tensor
  1984. - **m** (int) - The number of columns of returned tensor
  1985. - **t** (mindspore.dtype) - MindSpore's dtype, The data type of the returned tensor.
  1986. Outputs:
  1987. Tensor, a tensor with ones on the diagonal and the rest of elements are zero.
  1988. Examples:
  1989. >>> eye = P.Eye()
  1990. >>> out_tensor = eye(2, 2, mindspore.int32)
  1991. [[1, 0],
  1992. [0, 1]]
  1993. """
  1994. @prim_attr_register
  1995. def __init__(self):
  1996. """Initialize Eye"""
  1997. def infer_value(self, n, m, t):
  1998. validator.check_positive_int(n, "n", self.name)
  1999. validator.check_positive_int(m, "m", self.name)
  2000. args = {"dtype": t}
  2001. validator.check_type_same(args, mstype.number_type + (mstype.bool_,), self.name)
  2002. np_type = mstype.dtype_to_nptype(t)
  2003. ret = np.eye(n, m, dtype=np_type)
  2004. return Tensor(ret)
  2005. class ScatterNd(PrimitiveWithInfer):
  2006. """
  2007. Scatters a tensor into a new tensor depending on the specified indices.
  2008. Creates an empty tensor, and set values by scattering the update tensor depending on indices.
  2009. Inputs:
  2010. - **indices** (Tensor) - The index of scattering in the new tensor with int32 data type.
  2011. - **update** (Tensor) - The source Tensor to be scattered.
  2012. - **shape** (tuple[int]) - Define the shape of the output tensor, has the same type as indices.
  2013. Outputs:
  2014. Tensor, the new tensor, has the same type as `update` and the same shape as `shape`.
  2015. Examples:
  2016. >>> op = P.ScatterNd()
  2017. >>> indices = Tensor(np.array([[0, 1], [1, 1]]), mindspore.int32)
  2018. >>> update = Tensor(np.array([3.2, 1.1]), mindspore.float32)
  2019. >>> shape = (3, 3)
  2020. >>> output = op(indices, update, shape)
  2021. [[0. 3.2 0.]
  2022. [0. 1.1 0.]
  2023. [0. 0. 0. ]]
  2024. """
  2025. @prim_attr_register
  2026. def __init__(self):
  2027. """Initialize ScatterNd"""
  2028. self.init_prim_io_names(inputs=['indices', 'update', 'shape'], outputs=['output'])
  2029. def __infer__(self, indices, update, shape):
  2030. shp = shape['value']
  2031. validator.check_subclass("update_dtype", update['dtype'], mstype.tensor, self.name)
  2032. validator.check_tensor_type_same({"indices": indices['dtype']}, [mstype.int32], self.name)
  2033. validator.check_value_type("shape", shp, [tuple], self.name)
  2034. for i, x in enumerate(shp):
  2035. validator.check_positive_int(x, f'shape[{i}]', self.name)
  2036. indices_shape, update_shape = indices["shape"], update["shape"]
  2037. if indices_shape[0] != update_shape[0]:
  2038. raise ValueError(f'For \'{self.name}\' The indices_shape[0] and update_shape[0] must be equal.')
  2039. return {'shape': shp,
  2040. 'dtype': update['dtype'],
  2041. 'value': None}
  2042. class ResizeNearestNeighbor(PrimitiveWithInfer):
  2043. r"""
  2044. Resizes the input tensor by using nearest neighbor algorithm.
  2045. Resizes the input tensor to a given size by using the nearest neighbor algorithm. The nearest
  2046. neighbor algorithm selects the value of the nearest point and does not consider the
  2047. values of neighboring points at all, yielding a piecewise-constant interpolant.
  2048. Args:
  2049. size (Union[tuple, list]): The target size. The dimension of size must be 2.
  2050. align_corners (bool): Whether the centers of the 4 corner pixels of the input
  2051. and output tensors are aligned. Default: False.
  2052. Inputs:
  2053. - **input_x** (Tensor) - The input tensor. The shape of the tensor is :math:`(N, C, H, W)`.
  2054. Outputs:
  2055. Tensor, the shape of the output tensor is :math:`(N, C, NEW\_H, NEW\_W)`.
  2056. Examples:
  2057. >>> input_tensor = Tensor(np.array([[[[-0.1, 0.3, 3.6], [0.4, 0.5, -3.2]]]]), mindspore.float32)
  2058. >>> resize = P.ResizeNearestNeighbor((2, 2))
  2059. >>> output = resize(input_tensor)
  2060. [[[[-0.1 0.3]
  2061. [0.4 0.5 ]]]]
  2062. """
  2063. @prim_attr_register
  2064. def __init__(self, size, align_corners=False):
  2065. """Initialize ResizeNearestNeighbor"""
  2066. validator.check_value_type("size", size, [tuple, list], self.name)
  2067. validator.check_value_type("align_corners", align_corners, [bool], self.name)
  2068. validator.check_equal_int(len(size), 2, "length of size", self.name)
  2069. for i, value in enumerate(size):
  2070. validator.check_non_negative_int(value, f'{i}th value of size', self.name)
  2071. self.init_prim_io_names(inputs=['image_in'], outputs=['image_out'])
  2072. def infer_shape(self, x):
  2073. validator.check('the dimension of input_x', len(x), '', 4, Rel.EQ, self.name)
  2074. return tuple(x)[:-2] + tuple(self.size)
  2075. def infer_dtype(self, x):
  2076. validator.check_subclass("x", x, mstype.tensor, self.name)
  2077. validator.check_tensor_type_same({"x": x}, mstype.number_type, self.name)
  2078. return x
  2079. class GatherNd(PrimitiveWithInfer):
  2080. """
  2081. Gathers slices from a tensor by indices.
  2082. Using given indices to gather slices from a tensor with a specified shape.
  2083. Inputs:
  2084. - **input_x** (Tensor) - The target tensor to gather values.
  2085. - **indices** (Tensor) - The index tensor, with int data type.
  2086. Outputs:
  2087. Tensor, has the same type as `input_x` and the shape is indices_shape[:-1] + x_shape[indices_shape[-1]:].
  2088. Examples:
  2089. >>> input_x = Tensor(np.array([[-0.1, 0.3, 3.6], [0.4, 0.5, -3.2]]), mindspore.float32)
  2090. >>> indices = Tensor(np.array([[0, 0], [1, 1]]), mindspore.int32)
  2091. >>> op = P.GatherNd()
  2092. >>> output = op(input_x, indices)
  2093. [-0.1, 0.5]
  2094. """
  2095. @prim_attr_register
  2096. def __init__(self):
  2097. """Initialize GatherNd"""
  2098. self.init_prim_io_names(inputs=['input_x', 'indices'], outputs=['y'])
  2099. def infer_shape(self, x_shape, indices_shape):
  2100. validator.check('the dimension of x', len(x_shape),
  2101. 'the dimension of indices', indices_shape[-1], Rel.GE, self.name)
  2102. return indices_shape[:-1] + x_shape[indices_shape[-1]:]
  2103. def infer_dtype(self, x_dtype, indices_dtype):
  2104. validator.check_subclass("x_dtype", x_dtype, mstype.tensor, self.name)
  2105. validator.check_tensor_type_same({"indices": indices_dtype}, mstype.int_type, self.name)
  2106. return x_dtype
  2107. class TensorScatterUpdate(PrimitiveWithInfer):
  2108. """
  2109. Updates tensor value using given values, along with the input indices.
  2110. Inputs:
  2111. - **input_x** (Tensor) - The target tensor. The dimension of input_x must be equal to indices.shape[-1].
  2112. - **indices** (Tensor) - The index of input tensor whose data type is int32.
  2113. - **update** (Tensor) - The tensor to update the input tensor, has the same type as input,
  2114. and update.shape = indices.shape[:-1] + input_x.shape[indices.shape[-1]:].
  2115. Outputs:
  2116. Tensor, has the same shape and type as `input_x`.
  2117. Examples:
  2118. >>> input_x = Tensor(np.array([[-0.1, 0.3, 3.6], [0.4, 0.5, -3.2]]), mindspore.float32)
  2119. >>> indices = Tensor(np.array([[0, 0], [1, 1]]), mindspore.int32)
  2120. >>> update = Tensor(np.array([1.0, 2.2]), mindspore.float32)
  2121. >>> op = P.TensorScatterUpdate()
  2122. >>> output = op(input_x, indices, update)
  2123. [[1.0, 0.3, 3.6],
  2124. [0.4, 2.2, -3.2]]
  2125. """
  2126. @prim_attr_register
  2127. def __init__(self):
  2128. """Initialize TensorScatterUpdate"""
  2129. self.init_prim_io_names(inputs=['x', 'indices', 'value'], outputs=['y'])
  2130. def infer_shape(self, x_shape, indices_shape, value_shape):
  2131. validator.check('the dimension of x', len(x_shape),
  2132. 'the dimension of indices', indices_shape[-1], Rel.GE)
  2133. if indices_shape[:-1] + x_shape[indices_shape[-1]:] != value_shape:
  2134. raise ValueError("For 'TensorScatterUpdate', input value are not match with input indices.")
  2135. return x_shape
  2136. def infer_dtype(self, x_dtype, indices_dtype, value_dtype):
  2137. validator.check_tensor_type_same({'indices': indices_dtype}, [mstype.int32], self.name)
  2138. args = {"x": x_dtype, "value": value_dtype}
  2139. validator.check_tensor_type_same(args, (mstype.bool_,) + mstype.number_type, self.name)
  2140. return x_dtype
  2141. class ScatterUpdate(_ScatterOp):
  2142. """
  2143. Updates tensor value by using input indices and value.
  2144. Using given values to update tensor value, along with the input indices.
  2145. Inputs of `input_x` and `updates` comply with the implicit type conversion rules to make the data types consistent.
  2146. If they have different data types, lower priority data type will be converted to
  2147. relatively highest priority data type.
  2148. RuntimeError exception will be thrown when the data type conversion of Parameter is required.
  2149. Args:
  2150. use_locking (bool): Whether protect the assignment by a lock. Default: True.
  2151. Inputs:
  2152. - **input_x** (Parameter) - The target tensor, with data type of Parameter.
  2153. - **indices** (Tensor) - The index of input tensor. With int32 data type.
  2154. - **updates** (Tensor) - The tensor to update the input tensor, has the same type as input,
  2155. and updates.shape = indices.shape + input_x.shape[1:].
  2156. Outputs:
  2157. Tensor, has the same shape and type as `input_x`.
  2158. Examples:
  2159. >>> np_x = np.array([[-0.1, 0.3, 3.6], [0.4, 0.5, -3.2]])
  2160. >>> input_x = mindspore.Parameter(Tensor(np_x, mindspore.float32), name="x")
  2161. >>> indices = Tensor(np.array([[0, 0], [1, 1]]), mindspore.int32)
  2162. >>> np_updates = np.array([[[1.0, 2.2, 1.0], [2.0, 1.2, 1.0]], [[2.0, 2.2, 1.0], [3.0, 1.2, 1.0]]])
  2163. >>> updates = Tensor(np_updates, mindspore.float32)
  2164. >>> op = P.ScatterUpdate()
  2165. >>> output = op(input_x, indices, updates)
  2166. [[2.0, 1.2, 1.0],
  2167. [3.0, 1.2, 1.0]]
  2168. """
  2169. @prim_attr_register
  2170. def __init__(self, use_locking=True):
  2171. """Initialize ScatterUpdate"""
  2172. validator.check_value_type('use_locking', use_locking, [bool], self.name)
  2173. self.init_prim_io_names(inputs=['x', 'indices', 'updates'], outputs=['y'])
  2174. def infer_dtype(self, x_dtype, indices_dtype, value_dtype):
  2175. validator.check_tensor_type_same({'indices': indices_dtype}, [mstype.int32], self.name)
  2176. args = {"x": x_dtype, "value": value_dtype}
  2177. validator.check_tensor_type_same(args, (mstype.bool_,) + mstype.number_type, self.name)
  2178. return x_dtype
  2179. class ScatterNdUpdate(_ScatterNdOp):
  2180. """
  2181. Updates tensor value by using input indices and value.
  2182. Using given values to update tensor value, along with the input indices.
  2183. Inputs of `input_x` and `updates` comply with the implicit type conversion rules to make the data types consistent.
  2184. If they have different data types, lower priority data type will be converted to
  2185. relatively highest priority data type.
  2186. RuntimeError exception will be thrown when the data type conversion of Parameter is required.
  2187. Args:
  2188. use_locking (bool): Whether protect the assignment by a lock. Default: True.
  2189. Inputs:
  2190. - **input_x** (Parameter) - The target tensor, with data type of Parameter.
  2191. - **indices** (Tensor) - The index of input tensor, with int32 data type.
  2192. - **update** (Tensor) - The tensor to be updated to the input tensor, has the same type as input.
  2193. Outputs:
  2194. Tensor, has the same shape and type as `input_x`.
  2195. Examples:
  2196. >>> np_x = np.array([[-0.1, 0.3, 3.6], [0.4, 0.5, -3.2]])
  2197. >>> input_x = mindspore.Parameter(Tensor(np_x, mindspore.float32), name="x")
  2198. >>> indices = Tensor(np.array([[0, 0], [1, 1]]), mindspore.int32)
  2199. >>> update = Tensor(np.array([1.0, 2.2]), mindspore.float32)
  2200. >>> op = P.ScatterNdUpdate()
  2201. >>> output = op(input_x, indices, update)
  2202. [[1. 0.3 3.6]
  2203. [0.4 2.2 -3.2]]
  2204. """
  2205. @prim_attr_register
  2206. def __init__(self, use_locking=True):
  2207. """Initialize ScatterNdUpdate"""
  2208. validator.check_value_type('use_locking', use_locking, [bool], self.name)
  2209. self.init_prim_io_names(inputs=['x', 'indices', 'value'], outputs=['y'])
  2210. def infer_dtype(self, x_dtype, indices_dtype, value_dtype):
  2211. validator.check_tensor_type_same({'indices': indices_dtype}, [mstype.int32], self.name)
  2212. args = {"x": x_dtype, "value": value_dtype}
  2213. validator.check_tensor_type_same(args, (mstype.bool_,) + mstype.number_type, self.name)
  2214. return x_dtype
  2215. class ScatterMax(_ScatterOp):
  2216. """
  2217. Updates the value of the input tensor through the max operation.
  2218. Using given values to update tensor value through the max operation, along with the input indices.
  2219. This operation outputs the `input_x` after the update is done, which makes it convenient to use the updated value.
  2220. Inputs of `input_x` and `updates` comply with the implicit type conversion rules to make the data types consistent.
  2221. If they have different data types, lower priority data type will be converted to
  2222. relatively highest priority data type.
  2223. RuntimeError exception will be thrown when the data type conversion of Parameter is required.
  2224. Args:
  2225. use_locking (bool): Whether protect the assignment by a lock. Default: True.
  2226. Inputs:
  2227. - **input_x** (Parameter) - The target parameter.
  2228. - **indices** (Tensor) - The index to do max operation whose data type must be mindspore.int32.
  2229. - **updates** (Tensor) - The tensor that performs the maximum operation with `input_x`,
  2230. the data type is the same as `input_x`, the shape is `indices_shape + x_shape[1:]`.
  2231. Outputs:
  2232. Parameter, the updated `input_x`.
  2233. Examples:
  2234. >>> input_x = Parameter(Tensor(np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]), mindspore.float32), name="input_x")
  2235. >>> indices = Tensor(np.array([[0, 0], [1, 1]]), mindspore.int32)
  2236. >>> update = Tensor(np.ones([2, 2, 3]) * 88, mindspore.float32)
  2237. >>> scatter_max = P.ScatterMax()
  2238. >>> output = scatter_max(input_x, indices, update)
  2239. [[88.0, 88.0, 88.0], [88.0, 88.0, 88.0]]
  2240. """
  2241. @prim_attr_register
  2242. def __init__(self, use_locking=True):
  2243. """Initialize ScatterMax"""
  2244. self.init_prim_io_names(inputs=['x', 'indices', 'updates'], outputs=['y'])
  2245. validator.check_value_type('use_locking', use_locking, (bool,), self.name)
  2246. class ScatterMin(_ScatterOp):
  2247. """
  2248. Updates the value of the input tensor through the min operation.
  2249. Using given values to update tensor value through the min operation, along with the input indices.
  2250. This operation outputs the `input_x` after the update is done, which makes it convenient to use the updated value.
  2251. Inputs of `input_x` and `updates` comply with the implicit type conversion rules to make the data types consistent.
  2252. If they have different data types, lower priority data type will be converted to
  2253. relatively highest priority data type.
  2254. RuntimeError exception will be thrown when the data type conversion of Parameter is required.
  2255. Args:
  2256. use_locking (bool): Whether protect the assignment by a lock. Default: False.
  2257. Inputs:
  2258. - **input_x** (Parameter) - The target parameter.
  2259. - **indices** (Tensor) - The index to do min operation whose data type must be mindspore.int32.
  2260. - **updates** (Tensor) - The tensor doing the min operation with `input_x`,
  2261. the data type is same as `input_x`, the shape is `indices_shape + x_shape[1:]`.
  2262. Outputs:
  2263. Parameter, the updated `input_x`.
  2264. Examples:
  2265. >>> input_x = Parameter(Tensor(np.array([[0.0, 1.0, 2.0], [0.0, 0.0, 0.0]]), mindspore.float32), name="input_x")
  2266. >>> indices = Tensor(np.array([[0, 0], [1, 1]]), mindspore.int32)
  2267. >>> update = Tensor(np.ones([2, 2, 3]), mindspore.float32)
  2268. >>> scatter_min = P.ScatterMin()
  2269. >>> output = scatter_min(input_x, indices, update)
  2270. [[0.0, 1.0, 1.0], [0.0, 0.0, 0.0]]
  2271. """
  2272. class ScatterAdd(_ScatterOp):
  2273. """
  2274. Updates the value of the input tensor through the add operation.
  2275. Using given values to update tensor value through the add operation, along with the input indices.
  2276. This operation outputs the `input_x` after the update is done, which makes it convenient to use the updated value.
  2277. Inputs of `input_x` and `updates` comply with the implicit type conversion rules to make the data types consistent.
  2278. If they have different data types, lower priority data type will be converted to
  2279. relatively highest priority data type.
  2280. RuntimeError exception will be thrown when the data type conversion of Parameter is required.
  2281. Args:
  2282. use_locking (bool): Whether protect the assignment by a lock. Default: False.
  2283. Inputs:
  2284. - **input_x** (Parameter) - The target parameter.
  2285. - **indices** (Tensor) - The index to do add operation whose data type must be mindspore.int32.
  2286. - **updates** (Tensor) - The tensor that performs the add operation with `input_x`,
  2287. the data type is the same as `input_x`, the shape is `indices_shape + x_shape[1:]`.
  2288. Outputs:
  2289. Parameter, the updated `input_x`.
  2290. Examples:
  2291. >>> input_x = Parameter(Tensor(np.array([[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]), mindspore.float32), name="x")
  2292. >>> indices = Tensor(np.array([[0, 1], [1, 1]]), mindspore.int32)
  2293. >>> updates = Tensor(np.ones([2, 2, 3]), mindspore.float32)
  2294. >>> scatter_add = P.ScatterAdd()
  2295. >>> output = scatter_add(input_x, indices, updates)
  2296. [[1.0, 1.0, 1.0], [3.0, 3.0, 3.0]]
  2297. """
  2298. class ScatterSub(_ScatterOp):
  2299. """
  2300. Updates the value of the input tensor through the subtraction operation.
  2301. Using given values to update tensor value through the subtraction operation, along with the input indices.
  2302. This operation outputs the `input_x` after the update is done, which makes it convenient to use the updated value.
  2303. Inputs of `input_x` and `updates` comply with the implicit type conversion rules to make the data types consistent.
  2304. If they have different data types, lower priority data type will be converted to
  2305. relatively highest priority data type.
  2306. RuntimeError exception will be thrown when the data type conversion of Parameter is required.
  2307. Args:
  2308. use_locking (bool): Whether protect the assignment by a lock. Default: False.
  2309. Inputs:
  2310. - **input_x** (Parameter) - The target parameter.
  2311. - **indices** (Tensor) - The index to perform the subtraction operation
  2312. whose data type must be mindspore.int32.
  2313. - **updates** (Tensor) - The tensor that performs the subtraction operation with `input_x`,
  2314. the data type is the same as `input_x`, the shape is `indices_shape + x_shape[1:]`.
  2315. Outputs:
  2316. Parameter, the updated `input_x`.
  2317. Examples:
  2318. >>> input_x = Parameter(Tensor(np.array([[0.0, 0.0, 0.0], [1.0, 1.0, 1.0]]), mindspore.float32), name="x")
  2319. >>> indices = Tensor(np.array([[0, 1]]), mindspore.int32)
  2320. >>> updates = Tensor(np.array([[1.0, 1.0, 1.0], [2.0, 2.0, 2.0]]), mindspore.float32)
  2321. >>> scatter_sub = P.ScatterSub()
  2322. >>> output = scatter_sub(input_x, indices, updates)
  2323. [[-1.0, -1.0, -1.0], [-1.0, -1.0, -1.0]]
  2324. """
  2325. class ScatterMul(_ScatterOp):
  2326. """
  2327. Updates the value of the input tensor through the mul operation.
  2328. Using given values to update tensor value through the mul operation, along with the input indices.
  2329. This operation outputs the `input_x` after the update is done, which makes it convenient to use the updated value.
  2330. Inputs of `input_x` and `updates` comply with the implicit type conversion rules to make the data types consistent.
  2331. If they have different data types, lower priority data type will be converted to
  2332. relatively highest priority data type.
  2333. RuntimeError exception will be thrown when the data type conversion of Parameter is required.
  2334. Args:
  2335. use_locking (bool): Whether protect the assignment by a lock. Default: False.
  2336. Inputs:
  2337. - **input_x** (Parameter) - The target parameter.
  2338. - **indices** (Tensor) - The index to do mul operation whose data type must be mindspore.int32.
  2339. - **updates** (Tensor) - The tensor doing the mul operation with `input_x`,
  2340. the data type is same as `input_x`, the shape is `indices_shape + x_shape[1:]`.
  2341. Outputs:
  2342. Parameter, the updated `input_x`.
  2343. Examples:
  2344. >>> input_x = Parameter(Tensor(np.array([[1.0, 1.0, 1.0], [2.0, 2.0, 2.0]]), mindspore.float32), name="x")
  2345. >>> indices = Tensor(np.array([0, 1]), mindspore.int32)
  2346. >>> updates = Tensor(np.ones([[2.0, 2.0, 2.0], [2.0, 2.0, 2.0]]), mindspore.float32)
  2347. >>> scatter_mul = P.ScatterMul()
  2348. >>> output = scatter_mul(input_x, indices, updates)
  2349. [[2.0, 2.0, 2.0], [4.0, 4.0, 4.0]]
  2350. """
  2351. class ScatterDiv(_ScatterOp):
  2352. """
  2353. Updates the value of the input tensor through the div operation.
  2354. Using given values to update tensor value through the div operation, along with the input indices.
  2355. This operation outputs the `input_x` after the update is done, which makes it convenient to use the updated value.
  2356. Inputs of `input_x` and `updates` comply with the implicit type conversion rules to make the data types consistent.
  2357. If they have different data types, lower priority data type will be converted to
  2358. relatively highest priority data type.
  2359. RuntimeError exception will be thrown when the data type conversion of Parameter is required.
  2360. Args:
  2361. use_locking (bool): Whether protect the assignment by a lock. Default: False.
  2362. Inputs:
  2363. - **input_x** (Parameter) - The target parameter.
  2364. - **indices** (Tensor) - The index to do div operation whose data type must be mindspore.int32.
  2365. - **updates** (Tensor) - The tensor that performs the div operation with `input_x`,
  2366. the data type is the same as `input_x`, the shape is `indices_shape + x_shape[1:]`.
  2367. Outputs:
  2368. Parameter, the updated `input_x`.
  2369. Examples:
  2370. >>> input_x = Parameter(Tensor(np.array([[6.0, 6.0, 6.0], [2.0, 2.0, 2.0]]), mindspore.float32), name="x")
  2371. >>> indices = Tensor(np.array([0, 1]), mindspore.int32)
  2372. >>> updates = Tensor(np.ones([[2.0, 2.0, 2.0], [2.0, 2.0, 2.0]]), mindspore.float32)
  2373. >>> scatter_div = P.ScatterDiv()
  2374. >>> output = scatter_div(input_x, indices, updates)
  2375. [[3.0, 3.0, 3.0], [1.0, 1.0, 1.0]]
  2376. """
  2377. class ScatterNdAdd(_ScatterNdOp):
  2378. """
  2379. Applies sparse addition to individual values or slices in a Tensor.
  2380. Using given values to update tensor value through the add operation, along with the input indices.
  2381. This operation outputs the `input_x` after the update is done, which makes it convenient to use the updated value.
  2382. Inputs of `input_x` and `updates` comply with the implicit type conversion rules to make the data types consistent.
  2383. If they have different data types, lower priority data type will be converted to
  2384. relatively highest priority data type.
  2385. RuntimeError exception will be thrown when the data type conversion of Parameter is required.
  2386. Args:
  2387. use_locking (bool): Whether protect the assignment by a lock. Default: False.
  2388. Inputs:
  2389. - **input_x** (Parameter) - The target parameter.
  2390. - **indices** (Tensor) - The index to do add operation whose data type must be mindspore.int32.
  2391. - **updates** (Tensor) - The tensor doing the add operation with `input_x`,
  2392. the data type is same as `input_x`, the shape is `indices_shape[:-1] + x_shape[indices_shape[-1]:]`.
  2393. Outputs:
  2394. Parameter, the updated `input_x`.
  2395. Examples:
  2396. >>> input_x = Parameter(Tensor(np.array([1, 2, 3, 4, 5, 6, 7, 8]), mindspore.float32), name="x")
  2397. >>> indices = Tensor(np.array([[2], [4], [1], [7]]), mindspore.int32)
  2398. >>> updates = Tensor(np.array([6, 7, 8, 9]), mindspore.float32)
  2399. >>> scatter_nd_add = P.ScatterNdAdd()
  2400. >>> output = scatter_nd_add(input_x, indices, updates)
  2401. [1, 10, 9, 4, 12, 6, 7, 17]
  2402. """
  2403. class ScatterNdSub(_ScatterNdOp):
  2404. """
  2405. Applies sparse subtraction to individual values or slices in a Tensor.
  2406. Using given values to update tensor value through the subtraction operation, along with the input indices.
  2407. This operation outputs the `input_x` after the update is done, which makes it convenient to use the updated value.
  2408. Inputs of `input_x` and `updates` comply with the implicit type conversion rules to make the data types consistent.
  2409. If they have different data types, lower priority data type will be converted to
  2410. relatively highest priority data type.
  2411. RuntimeError exception will be thrown when the data type conversion of Parameter is required.
  2412. Args:
  2413. use_locking (bool): Whether protect the assignment by a lock. Default: False.
  2414. Inputs:
  2415. - **input_x** (Parameter) - The target parameter.
  2416. - **indices** (Tensor) - The index to do add operation whose data type must be mindspore.int32.
  2417. - **updates** (Tensor) - The tensor that performs the subtraction operation with `input_x`,
  2418. the data type is the same as `input_x`, the shape is `indices_shape[:-1] + x_shape[indices_shape[-1]:]`.
  2419. Outputs:
  2420. Parameter, the updated `input_x`.
  2421. Examples:
  2422. >>> input_x = Parameter(Tensor(np.array([1, 2, 3, 4, 5, 6, 7, 8]), mindspore.float32), name="x")
  2423. >>> indices = Tensor(np.array([[2], [4], [1], [7]]), mindspore.int32)
  2424. >>> updates = Tensor(np.array([6, 7, 8, 9]), mindspore.float32)
  2425. >>> scatter_nd_sub = P.ScatterNdSub()
  2426. >>> output = scatter_nd_sub(input_x, indices, updates)
  2427. [1, -6, -3, 4, -2, 6, 7, -1]
  2428. """
  2429. class ScatterNonAliasingAdd(_ScatterNdOp):
  2430. """
  2431. Applies sparse addition to input using individual values or slices.
  2432. Using given values to update tensor value through the add operation, along with the input indices.
  2433. This operation outputs the `input_x` after the update is done, which makes it convenient to use the updated value.
  2434. Inputs of `input_x` and `updates` comply with the implicit type conversion rules to make the data types consistent.
  2435. If they have different data types, lower priority data type will be converted to
  2436. relatively highest priority data type.
  2437. RuntimeError exception will be thrown when the data type conversion of Parameter is required.
  2438. Inputs:
  2439. - **input_x** (Parameter) - The target parameter. The data type must be float16, float32 or int32.
  2440. - **indices** (Tensor) - The index to perform the addition operation whose data type must be mindspore.int32.
  2441. - **updates** (Tensor) - The tensor that performs the addition operation with `input_x`,
  2442. the data type is the same as `input_x`, the shape is `indices_shape[:-1] + x_shape[indices_shape[-1]:]`.
  2443. Outputs:
  2444. Parameter, the updated `input_x`.
  2445. Examples:
  2446. >>> input_x = Parameter(Tensor(np.array([1, 2, 3, 4, 5, 6, 7, 8]), mindspore.float32), name="x")
  2447. >>> indices = Tensor(np.array([[2], [4], [1], [7]]), mindspore.int32)
  2448. >>> updates = Tensor(np.array([6, 7, 8, 9]), mindspore.float32)
  2449. >>> scatter_non_aliasing_add = P.ScatterNonAliasingAdd()
  2450. >>> output = scatter_non_aliasing_add(input_x, indices, updates)
  2451. [1, 10, 9, 4, 12, 6, 7, 17]
  2452. """
  2453. @prim_attr_register
  2454. def __init__(self):
  2455. """Initialize ScatterNonAliasingAdd"""
  2456. self.init_prim_io_names(inputs=['x', 'indices', 'updates'], outputs=['y'])
  2457. def infer_dtype(self, x_dtype, indices_dtype, updates_dtype):
  2458. validator.check_tensor_type_same({'indices': indices_dtype}, [mstype.int32], self.name)
  2459. args = {"x": x_dtype, "updates": updates_dtype}
  2460. validator.check_tensor_type_same(args, [mstype.float16, mstype.float32, mstype.int32], self.name)
  2461. return x_dtype
  2462. class SpaceToDepth(PrimitiveWithInfer):
  2463. r"""
  2464. Rearranges blocks of spatial data into depth.
  2465. The output tensor's `height` dimension is :math:`height / block\_size`.
  2466. The output tensor's `weight` dimension is :math:`weight / block\_size`.
  2467. The depth of output tensor is :math:`block\_size * block\_size * input\_depth`.
  2468. The input tensor's height and width must be divisible by `block_size`.
  2469. The data format is "NCHW".
  2470. Args:
  2471. block_size (int): The block size used to divide spatial data. It must be >= 2.
  2472. Inputs:
  2473. - **x** (Tensor) - The target tensor.
  2474. Outputs:
  2475. Tensor, the same data type as `x`. It must be a 4-D tensor.
  2476. Examples:
  2477. >>> x = Tensor(np.random.rand(1,3,2,2), mindspore.float32)
  2478. >>> block_size = 2
  2479. >>> op = P.SpaceToDepth(block_size)
  2480. >>> output = op(x)
  2481. >>> output.asnumpy().shape == (1,12,1,1)
  2482. """
  2483. @prim_attr_register
  2484. def __init__(self, block_size):
  2485. """Initialize SpaceToDepth"""
  2486. self.init_prim_io_names(inputs=['x'], outputs=['y'])
  2487. validator.check_value_type('block_size', block_size, [int], self.name)
  2488. validator.check('block_size', block_size, '', 2, Rel.GE)
  2489. self.block_size = block_size
  2490. def infer_shape(self, x_shape):
  2491. validator.check('x dimension', len(x_shape), '', 4, Rel.EQ)
  2492. out_shape = copy.deepcopy(x_shape)
  2493. for i in range(2):
  2494. if out_shape[i + 2] % self.block_size != 0:
  2495. raise ValueError(f'For \'{self.name}\' input shape[{i + 2}] {out_shape[i + 2]} should be '
  2496. f'fully divided by block_size {self.block_size}')
  2497. out_shape[i + 2] //= self.block_size
  2498. out_shape[1] *= self.block_size * self.block_size
  2499. return out_shape
  2500. def infer_dtype(self, x_dtype):
  2501. validator.check_subclass("x_dtype", x_dtype, mstype.tensor, self.name)
  2502. return x_dtype
  2503. class DepthToSpace(PrimitiveWithInfer):
  2504. r"""
  2505. Rearranges blocks of depth data into spatial dimensions.
  2506. This is the reverse operation of SpaceToDepth.
  2507. The depth of output tensor is :math:`input\_depth / (block\_size * block\_size)`.
  2508. The output tensor's `height` dimension is :math:`height * block\_size`.
  2509. The output tensor's `weight` dimension is :math:`weight * block\_size`.
  2510. The input tensor's depth must be divisible by `block_size * block_size`.
  2511. The data format is "NCHW".
  2512. Args:
  2513. block_size (int): The block size used to divide depth data. It must be >= 2.
  2514. Inputs:
  2515. - **x** (Tensor) - The target tensor. It must be a 4-D tensor.
  2516. Outputs:
  2517. Tensor, has the same shape and dtype as the 'x'.
  2518. Examples:
  2519. >>> x = Tensor(np.random.rand(1,12,1,1), mindspore.float32)
  2520. >>> block_size = 2
  2521. >>> op = P.DepthToSpace(block_size)
  2522. >>> output = op(x)
  2523. >>> output.asnumpy().shape == (1,3,2,2)
  2524. """
  2525. @prim_attr_register
  2526. def __init__(self, block_size):
  2527. """Initialize DepthToSpace"""
  2528. self.init_prim_io_names(inputs=['x'], outputs=['y'])
  2529. validator.check_value_type('block_size', block_size, [int], self.name)
  2530. validator.check('block_size', block_size, '', 2, Rel.GE, self.name)
  2531. self.block_size = block_size
  2532. def infer_shape(self, x_shape):
  2533. validator.check('x dimension', len(x_shape), '', 4, Rel.EQ)
  2534. out_shape = copy.deepcopy(x_shape)
  2535. for i in range(2):
  2536. out_shape[i + 2] *= self.block_size
  2537. validator.check_int(x_shape[1] % (self.block_size * self.block_size),
  2538. 0, Rel.EQ, 'x_shape[1] % (block_size*block_size)', self.name)
  2539. out_shape[1] //= self.block_size * self.block_size
  2540. return out_shape
  2541. def infer_dtype(self, x_dtype):
  2542. validator.check_subclass("x_dtype", x_dtype, mstype.tensor, self.name)
  2543. return x_dtype
  2544. class SpaceToBatch(PrimitiveWithInfer):
  2545. r"""
  2546. Divides spatial dimensions into blocks and combine the block size with the original batch.
  2547. This operation will divide spatial dimensions (H, W) into blocks with `block_size`, the output tensor's H and W
  2548. dimension is the corresponding number of blocks after division. The output tensor's batch dimension is the
  2549. product of the original batch and the square of block_size. Before division, the spatial dimensions
  2550. of the input are zero padded according to paddings if necessary.
  2551. Args:
  2552. block_size (int): The block size of dividing blocks with value greater than 2.
  2553. paddings (Union[tuple, list]): The padding values for H and W dimension, containing 2 subtraction lists.
  2554. Each subtraction list contains 2 integer value. All values must be greater than 0.
  2555. paddings[i] specifies the paddings for the spatial dimension i, which corresponds to the
  2556. input dimension i+2. It is required that input_shape[i+2]+paddings[i][0]+paddings[i][1]
  2557. is divisible by block_size.
  2558. Inputs:
  2559. - **input_x** (Tensor) - The input tensor. It must be a 4-D tensor.
  2560. Outputs:
  2561. Tensor, the output tensor with the same data type as input. Assume input shape is :math:`(n, c, h, w)` with
  2562. :math:`block\_size` and :math:`paddings`. The shape of the output tensor will be :math:`(n', c', h', w')`,
  2563. where
  2564. :math:`n' = n*(block\_size*block\_size)`
  2565. :math:`c' = c`
  2566. :math:`h' = (h+paddings[0][0]+paddings[0][1])//block\_size`
  2567. :math:`w' = (w+paddings[1][0]+paddings[1][1])//block\_size`
  2568. Examples:
  2569. >>> block_size = 2
  2570. >>> paddings = [[0, 0], [0, 0]]
  2571. >>> space_to_batch = P.SpaceToBatch(block_size, paddings)
  2572. >>> input_x = Tensor(np.array([[[[1, 2], [3, 4]]]]), mindspore.float32)
  2573. >>> space_to_batch(input_x)
  2574. [[[[1.]]], [[[2.]]], [[[3.]]], [[[4.]]]]
  2575. """
  2576. @prim_attr_register
  2577. def __init__(self, block_size, paddings):
  2578. """Initialize SpaceToBatch"""
  2579. validator.check_value_type('block_size', block_size, [int], self.name)
  2580. validator.check('block_size', block_size, '', 2, Rel.GE, self.name)
  2581. self.block_size = block_size
  2582. validator.check('paddings shape', np.array(paddings).shape, '', (2, 2), Rel.EQ, self.name)
  2583. for elem in itertools.chain(*paddings):
  2584. validator.check_non_negative_int(elem, 'paddings element', self.name)
  2585. validator.check_value_type('paddings element', elem, [int], self.name)
  2586. self.paddings = paddings
  2587. def infer_dtype(self, x_dtype):
  2588. validator.check_tensor_type_same({'input_x': x_dtype}, mstype.number_type, self.name)
  2589. return x_dtype
  2590. def infer_shape(self, x_shape):
  2591. validator.check_equal_int(len(x_shape), 4, 'rank of input_x', self.name)
  2592. out_shape = copy.deepcopy(x_shape)
  2593. for i in range(2):
  2594. padded = out_shape[i + 2] + self.paddings[i][0] + self.paddings[i][1]
  2595. if padded % self.block_size != 0:
  2596. raise ValueError(f'For \'{self.name}\' padded[{i}] {padded} should be divisible by '
  2597. f'block_size {self.block_size}')
  2598. out_shape[i + 2] = padded // self.block_size
  2599. out_shape[0] *= self.block_size * self.block_size
  2600. return out_shape
  2601. class BatchToSpace(PrimitiveWithInfer):
  2602. r"""
  2603. Divides batch dimension with blocks and interleaves these blocks back into spatial dimensions.
  2604. This operation will divide batch dimension N into blocks with block_size, the output tensor's N dimension
  2605. is the corresponding number of blocks after division. The output tensor's H, W dimension is product of original H, W
  2606. dimension and block_size with given amount to crop from dimension, respectively.
  2607. Args:
  2608. block_size (int): The block size of division, has the value not less than 2.
  2609. crops (Union[list(int), tuple(int)]): The crop value for H and W dimension, containing 2 subtraction lists.
  2610. Each list contains 2 integers.
  2611. All values must be not less than 0. crops[i] specifies the crop values for the spatial dimension i, which
  2612. corresponds to the input dimension i+2. It is required that
  2613. input_shape[i+2]*block_size >= crops[i][0]+crops[i][1].
  2614. Inputs:
  2615. - **input_x** (Tensor) - The input tensor. It must be a 4-D tensor, dimension 0 must be divisible by
  2616. product of `block_shape`.
  2617. Outputs:
  2618. Tensor, the output tensor with the same type as input. Assume input shape is (n, c, h, w) with block_size
  2619. and crops. The output shape will be (n', c', h', w'), where
  2620. :math:`n' = n//(block\_size*block\_size)`
  2621. :math:`c' = c`
  2622. :math:`h' = h*block\_size-crops[0][0]-crops[0][1]`
  2623. :math:`w' = w*block\_size-crops[1][0]-crops[1][1]`
  2624. Examples:
  2625. >>> block_size = 2
  2626. >>> crops = [[0, 0], [0, 0]]
  2627. >>> op = P.BatchToSpace(block_size, crops)
  2628. >>> input_x = Tensor(np.array([[[[1]]], [[[2]]], [[[3]]], [[[4]]]]), mindspore.float32)
  2629. >>> output = op(input_x)
  2630. [[[[1., 2.], [3., 4.]]]]
  2631. """
  2632. @prim_attr_register
  2633. def __init__(self, block_size, crops):
  2634. """Initialize BatchToSpace"""
  2635. validator.check_value_type('block_size', block_size, [int], self.name)
  2636. validator.check('block_size', block_size, '', 2, Rel.GE, self.name)
  2637. self.block_size = block_size
  2638. validator.check_value_type('crops type', crops, [list, tuple], self.name)
  2639. validator.check('crops shape', np.array(crops).shape, '', (2, 2))
  2640. for elem in itertools.chain(*crops):
  2641. validator.check_non_negative_int(elem, 'crops element', self.name)
  2642. validator.check_value_type('crops element', elem, [int], self.name)
  2643. self.crops = crops
  2644. def infer_dtype(self, x_dtype):
  2645. validator.check_tensor_type_same({'input_x': x_dtype}, mstype.number_type, self.name)
  2646. return x_dtype
  2647. def infer_shape(self, x_shape):
  2648. validator.check('rank of input_x', len(x_shape), '', 4)
  2649. out_shape = copy.deepcopy(x_shape)
  2650. for i in range(2):
  2651. x_block_prod = out_shape[i + 2] * self.block_size
  2652. crops_sum = self.crops[i][0] + self.crops[i][1]
  2653. validator.check("x block shape prod", x_block_prod, 'crops sum', crops_sum, Rel.GT, self.name)
  2654. out_shape[i + 2] = x_block_prod - crops_sum
  2655. block_size_prod = self.block_size * self.block_size
  2656. if out_shape[0] % block_size_prod != 0:
  2657. raise ValueError(f'For \'{self.name}\' input_x dimension 0 {out_shape[0]} should be divisible by '
  2658. f'block_size_prod {block_size_prod}')
  2659. out_shape[0] = out_shape[0] // block_size_prod
  2660. return out_shape
  2661. class SpaceToBatchND(PrimitiveWithInfer):
  2662. r"""
  2663. Divides spatial dimensions into blocks and combine the block size with the original batch.
  2664. This operation will divide spatial dimensions (H, W) into blocks with block_shape, the output tensor's H and W
  2665. dimension is the corresponding number of blocks after division. The output tensor's batch dimension is the
  2666. product of the original batch and the product of `block_shape`. Before division,
  2667. the spatial dimensions of the input are zero padded according to paddings if necessary.
  2668. Args:
  2669. block_shape (Union[list(int), tuple(int)]): The block shape of dividing block with all value greater than 1.
  2670. The length of `block_shape` is M correspoding to the number of spatial dimensions. M must be 2.
  2671. paddings (Union[tuple, list]): The padding values for H and W dimension, containing 2 subtraction list.
  2672. Each contains 2 integer value. All values must be greater than 0.
  2673. `paddings[i]` specifies the paddings for the spatial dimension i,
  2674. which corresponds to the input dimension i+2.
  2675. It is required that input_shape[i+2]+paddings[i][0]+paddings[i][1] is divisible by block_shape[i].
  2676. Inputs:
  2677. - **input_x** (Tensor) - The input tensor. It must be a 4-D tensor.
  2678. Outputs:
  2679. Tensor, the output tensor with the same data type as input. Assume input shape is :math:`(n, c, h, w)` with
  2680. :math:`block\_shape` and :math:`padddings`. The shape of the output tensor will be :math:`(n', c', h', w')`,
  2681. where
  2682. :math:`n' = n*(block\_shape[0]*block\_shape[1])`
  2683. :math:`c' = c`
  2684. :math:`h' = (h+paddings[0][0]+paddings[0][1])//block\_shape[0]`
  2685. :math:`w' = (w+paddings[1][0]+paddings[1][1])//block\_shape[1]`
  2686. Examples:
  2687. >>> block_shape = [2, 2]
  2688. >>> paddings = [[0, 0], [0, 0]]
  2689. >>> space_to_batch_nd = P.SpaceToBatchND(block_shape, paddings)
  2690. >>> input_x = Tensor(np.array([[[[1, 2], [3, 4]]]]), mindspore.float32)
  2691. >>> space_to_batch_nd(input_x)
  2692. [[[[1.]]], [[[2.]]], [[[3.]]], [[[4.]]]]
  2693. """
  2694. @prim_attr_register
  2695. def __init__(self, block_shape, paddings):
  2696. """Initialize SpaceToBatchND"""
  2697. self.ori_block_shape = block_shape
  2698. self.ori_paddings = paddings
  2699. validator.check_value_type('block_shape type', block_shape, [list, tuple], self.name)
  2700. validator.check('block_shape shape', len(np.array(block_shape).shape), '', 1, Rel.EQ, self.name)
  2701. block_rank = len(block_shape)
  2702. validator.check('block_shape length', block_rank, '', 2, Rel.EQ, self.name)
  2703. for elem in block_shape:
  2704. validator.check('block_shape element', elem, '', 1, Rel.GE, self.name)
  2705. validator.check_value_type('block_shape element', elem, [int], self.name)
  2706. self.block_shape = block_shape
  2707. validator.check_value_type('paddings type', paddings, [list, tuple], self.name)
  2708. validator.check('paddings length', len(paddings), '', 2, Rel.EQ, self.name)
  2709. validator.check('paddings shape', np.array(paddings).shape, '', (block_rank, 2), Rel.EQ, self.name)
  2710. for elem in itertools.chain(*paddings):
  2711. validator.check_non_negative_int(elem, 'paddings element', self.name)
  2712. validator.check_value_type('paddings element', elem, [int], self.name)
  2713. self.paddings = paddings
  2714. block_shape_append = [1] + list(self.block_shape)
  2715. self.add_prim_attr("block_shape", block_shape_append)
  2716. paddings_append = [[0, 0]] + list(self.paddings)
  2717. self.add_prim_attr("paddings", paddings_append)
  2718. def infer_dtype(self, x_dtype):
  2719. validator.check_tensor_type_same({'input_x': x_dtype}, mstype.number_type, self.name)
  2720. return x_dtype
  2721. def infer_shape(self, x_shape):
  2722. x_rank = len(x_shape)
  2723. validator.check_equal_int(x_rank, 4, 'x_shape rank', self.name)
  2724. out_shape = copy.deepcopy(x_shape)
  2725. block_shape_prod = 1
  2726. offset = 2
  2727. if x_rank <= 4:
  2728. offset = 1
  2729. for i in range(len(self.block_shape)):
  2730. padded = out_shape[i + offset] + self.paddings[i][0] + \
  2731. self.paddings[i][1]
  2732. if padded % self.block_shape[i] != 0:
  2733. raise ValueError(f'For \'{self.name}\' padded[{i}] {padded} should be divisible by '
  2734. f'block_shape[{i}] {self.block_shape[i]}')
  2735. out_shape[i + offset] = padded // self.block_shape[i]
  2736. block_shape_prod = block_shape_prod * self.block_shape[i]
  2737. out_shape[0] *= block_shape_prod
  2738. return out_shape
  2739. class BatchToSpaceND(PrimitiveWithInfer):
  2740. r"""
  2741. Divides batch dimension with blocks and interleave these blocks back into spatial dimensions.
  2742. This operation will divide batch dimension N into blocks with block_shape, the output tensor's N dimension
  2743. is the corresponding number of blocks after division. The output tensor's H, W dimension is product of original H, W
  2744. dimension and block_shape with given amount to crop from dimension, respectively.B
  2745. Args:
  2746. block_shape (Union[list(int), tuple(int)]): The block shape of dividing block with all value >= 1.
  2747. The length of block_shape is M correspoding to the number of spatial dimensions. M must be 2.
  2748. crops (Union[list(int), tuple(int)]): The crop value for H and W dimension, containing 2 subtraction list,
  2749. each containing 2 int value.
  2750. All values must be >= 0. crops[i] specifies the crop values for spatial dimension i, which corresponds to
  2751. input dimension i+2. It is required that input_shape[i+2]*block_shape[i] > crops[i][0]+crops[i][1].
  2752. Inputs:
  2753. - **input_x** (Tensor) - The input tensor. It must be a 4-D tensor, dimension 0 must be divisible by
  2754. product of `block_shape`.
  2755. Outputs:
  2756. Tensor, the output tensor with the same type as input. Assume input shape is (n, c, h, w) with block_shape
  2757. and crops. The output shape will be (n', c', h', w'), where
  2758. :math:`n' = n//(block\_shape[0]*block\_shape[1])`
  2759. :math:`c' = c`
  2760. :math:`h' = h*block\_shape[0]-crops[0][0]-crops[0][1]`
  2761. :math:`w' = w*block\_shape[1]-crops[1][0]-crops[1][1]`
  2762. Examples:
  2763. >>> block_shape = [2, 2]
  2764. >>> crops = [[0, 0], [0, 0]]
  2765. >>> batch_to_space_nd = P.BatchToSpaceND(block_shape, crops)
  2766. >>> input_x = Tensor(np.array([[[[1]]], [[[2]]], [[[3]]], [[[4]]]]), mindspore.float32)
  2767. >>> output = batch_to_space_nd(input_x)
  2768. [[[[1., 2.], [3., 4.]]]]
  2769. """
  2770. @prim_attr_register
  2771. def __init__(self, block_shape, crops):
  2772. """Initialize BatchToSpaceND"""
  2773. self.ori_block_shape = block_shape
  2774. self.ori_crops = crops
  2775. validator.check_value_type('block_shape type', block_shape, [list, tuple], self.name)
  2776. validator.check('block_shape shape', len(np.array(block_shape).shape), '', 1, Rel.EQ, self.name)
  2777. block_rank = len(block_shape)
  2778. validator.check('block_shape length', block_rank, '', 2, Rel.EQ, self.name)
  2779. for elem in block_shape:
  2780. validator.check('block_shape element', elem, '', 1, Rel.GE, self.name)
  2781. validator.check_value_type('block_shape element', elem, [int], self.name)
  2782. self.block_shape = block_shape
  2783. validator.check_value_type('crops type', crops, [list, tuple], self.name)
  2784. validator.check('crops length', len(crops), '', 2, Rel.EQ, self.name)
  2785. validator.check('crops shape', np.array(crops).shape, '', (block_rank, 2), Rel.EQ, self.name)
  2786. for elem in itertools.chain(*crops):
  2787. validator.check_non_negative_int(elem, 'crops element', self.name)
  2788. validator.check_value_type('crops element', elem, [int], self.name)
  2789. self.crops = crops
  2790. block_shape_append = [1] + list(self.block_shape)
  2791. self.add_prim_attr("block_shape", block_shape_append)
  2792. crops_append = [[0, 0]] + list(self.crops)
  2793. self.add_prim_attr("crops", crops_append)
  2794. def infer_dtype(self, x_dtype):
  2795. validator.check_tensor_type_same({'input_x': x_dtype}, mstype.number_type, self.name)
  2796. return x_dtype
  2797. def infer_shape(self, x_shape):
  2798. x_rank = len(x_shape)
  2799. validator.check_int(x_rank, 4, Rel.EQ, 'x_shape rank', self.name)
  2800. out_shape = copy.deepcopy(x_shape)
  2801. block_shape_prod = 1
  2802. offset = 2
  2803. if x_rank <= 4:
  2804. offset = 1
  2805. for i in range(len(self.block_shape)):
  2806. block_shape_prod = block_shape_prod * self.block_shape[i]
  2807. x_block_prod = out_shape[i + offset] * self.block_shape[i]
  2808. crops_sum = self.crops[i][0] + self.crops[i][1]
  2809. validator.check("x block shape prod", x_block_prod, 'crops sum', crops_sum, Rel.GT, self.name)
  2810. out_shape[i + offset] = x_block_prod - crops_sum
  2811. if out_shape[0] % block_shape_prod != 0:
  2812. raise ValueError(f'For \'{self.name}\' input_x dimension 0 {out_shape[0]} should be divisible by '
  2813. f'block_shape_prod {block_shape_prod}')
  2814. out_shape[0] = out_shape[0] // block_shape_prod
  2815. return out_shape
  2816. class BroadcastTo(PrimitiveWithInfer):
  2817. """
  2818. Broadcasts input tensor to a given shape.
  2819. Input shape can be broadcast to target shape if for each dimension pair they are either equal or input is one.
  2820. When input shape is broadcast to target shape, it starts with the trailing dimensions.
  2821. Args:
  2822. shape (tuple): The target shape to broadcast.
  2823. Inputs:
  2824. - **input_x** (Tensor) - The input tensor.
  2825. Outputs:
  2826. Tensor, with the given `shape` and the same data type as `input_x`.
  2827. Examples:
  2828. >>> shape = (2, 3)
  2829. >>> input_x = Tensor(np.array([1, 2, 3]).astype(np.float32))
  2830. >>> broadcast_to = P.BroadcastTo(shape)
  2831. >>> broadcast_to(input_x)
  2832. [[1.0, 2.0, 3.0], [1.0, 2.0, 3.0]]
  2833. """
  2834. @prim_attr_register
  2835. def __init__(self, shape):
  2836. """Initialize BroadcastTo"""
  2837. validator.check_value_type("shape", shape, (tuple), self.name)
  2838. validator.check("shape length", len(shape), "", 0, Rel.GT, self.name)
  2839. for i in shape:
  2840. validator.check_positive_int(i, "shape element", self.name)
  2841. self.shape = shape
  2842. def infer_shape(self, x_shape):
  2843. validator.check("input_x shape length", len(x_shape), "target shape", len(self.shape), Rel.LE, self.name)
  2844. reversed_x_shape = tuple(reversed(x_shape))
  2845. reversed_target = tuple(reversed(self.shape))
  2846. for i, v in enumerate(reversed_x_shape):
  2847. if v not in (reversed_target[i], 1):
  2848. raise ValueError(f"Not supported shapes for broadcast, "
  2849. f"x_shape: {tuple(x_shape)}, target shape {self.shape}.")
  2850. return self.shape
  2851. def infer_dtype(self, x_dtype):
  2852. validator.check_subclass("input_x", x_dtype, mstype.tensor, self.name)
  2853. return x_dtype
  2854. class Meshgrid(PrimitiveWithInfer):
  2855. """
  2856. Generates coordinate matrices from given coordinate tensors.
  2857. Given N one-dimensional coordinate tensors, returns a list outputs of N N-D
  2858. coordinate tensors for evaluating expressions on an N-D grid.
  2859. Args:
  2860. indexing (str): Either 'xy' or 'ij'. Default: 'xy'.
  2861. When the indexing argument is set to 'xy' (the default),
  2862. the broadcasting instructions for the first two dimensions are swapped.
  2863. Inputs:
  2864. - **input_x** (Union[tuple, list]) - A Tuple or list of N 1-D Tensor objects.
  2865. The length of input_x should be greater than 1
  2866. Outputs:
  2867. Tensors, A Tuple of N N-D Tensor objects.
  2868. Examples:
  2869. >>> x = np.array([1, 2, 3, 4]).astype(np.int32)
  2870. >>> y = np.array([5, 6, 7]).astype(np.int32)
  2871. >>> z = np.array([8, 9, 0, 1, 2]).astype(np.int32)
  2872. >>> inputs = (x, y, z)
  2873. >>> meshgrid = P.Meshgrid(indexing="xy")
  2874. >>> meshgrid(inputs)
  2875. (Tensor(shape=[3, 4, 6], dtype=UInt32, value=
  2876. [[[1, 1, 1, 1, 1],
  2877. [2, 2, 2, 2, 2],
  2878. [3, 3, 3, 3, 3],
  2879. [4, 4, 4, 4, 4]],
  2880. [[1, 1, 1, 1, 1],
  2881. [2, 2, 2, 2, 2],
  2882. [3, 3, 3, 3, 3],
  2883. [4, 4, 4, 4, 4]],
  2884. [[1, 1, 1, 1, 1],
  2885. [2, 2, 2, 2, 2],
  2886. [3, 3, 3, 3, 3],
  2887. [4, 4, 4, 4, 4]]]),
  2888. Tensor(shape=[3, 4, 6], dtype=UInt32, value=
  2889. [[[5, 5, 5, 5, 5],
  2890. [5, 5, 5, 5, 5],
  2891. [5, 5, 5, 5, 5],
  2892. [5, 5, 5, 5, 5]],
  2893. [[6, 6, 6, 6, 6],
  2894. [6, 6, 6, 6, 6],
  2895. [6, 6, 6, 6, 6],
  2896. [6, 6, 6, 6, 6]],
  2897. [[7, 7, 7, 7, 7],
  2898. [7, 7, 7, 7, 7],
  2899. [7, 7, 7, 7, 7],
  2900. [7, 7, 7, 7, 7]]]),
  2901. Tensor(shape=[3, 4, 6], dtype=UInt32, value=
  2902. [[[8, 9, 0, 1, 2],
  2903. [8, 9, 0, 1, 2],
  2904. [8, 9, 0, 1, 2],
  2905. [8, 9, 0, 1, 2]],
  2906. [[8, 9, 0, 1, 2],
  2907. [8, 9, 0, 1, 2],
  2908. [8, 9, 0, 1, 2],
  2909. [8, 9, 0, 1, 2]],
  2910. [[8, 9, 0, 1, 2],
  2911. [8, 9, 0, 1, 2],
  2912. [8, 9, 0, 1, 2],
  2913. [8, 9, 0, 1, 2]]]))
  2914. """
  2915. @prim_attr_register
  2916. def __init__(self, indexing="xy"):
  2917. """Init Meshgrid"""
  2918. validator.check_value_type("indexing", indexing, (str), self.name)
  2919. if indexing not in ("xy", "ij"):
  2920. raise ValueError("indexing parameter must be either 'xy' or 'ij'")
  2921. self.indexing = indexing
  2922. def infer_shape(self, x_shape):
  2923. validator.check_value_type("shape", x_shape, [tuple, list], self.name)
  2924. validator.check_int(len(x_shape), 2, Rel.GE, "len of input_x", self.name)
  2925. n = len(x_shape)
  2926. shape_0 = []
  2927. for s in x_shape:
  2928. validator.check_int(len(s), 1, Rel.EQ, 'each_input_rank', self.name)
  2929. shape_0.append(s[0])
  2930. if self.indexing == "xy":
  2931. shape_0[0], shape_0[1] = shape_0[1], shape_0[0]
  2932. out_shape = tuple(tuple(shape_0) for _ in range(n))
  2933. return out_shape
  2934. def infer_dtype(self, x_type):
  2935. validator.check_subclass("input_x[0]", x_type[0], mstype.tensor, self.name)
  2936. n = len(x_type)
  2937. for i in range(1, n):
  2938. validator.check('x_type[%d]' % i, x_type[i], 'base', x_type[0], Rel.EQ, self.name, TypeError)
  2939. return x_type
  2940. class InplaceUpdate(PrimitiveWithInfer):
  2941. r"""
  2942. Updates specified rows with values in `v`.
  2943. Args:
  2944. indices (Union[int, tuple]): Indices into the left-most dimension of `x`, and determines which rows of x
  2945. to update with v. It is a int or tuple, whose value is in [0, the first dimension size of x).
  2946. Inputs:
  2947. - **x** (Tensor) - A tensor which to be inplace updated. It can be one of the following data types:
  2948. float32, float16 and int32.
  2949. - **v** (Tensor) - A tensor with the same type as `x` and the same dimension size as `x` except
  2950. the first dimension, which must be the same as the size of `indices`.
  2951. Outputs:
  2952. Tensor, with the same type and shape as the input `x`.
  2953. Examples:
  2954. >>> indices = (0, 1)
  2955. >>> x = Tensor(np.array([[1, 2], [3, 4], [5, 6]]), mindspore.float32)
  2956. >>> v = Tensor(np.array([[0.5, 1.0], [1.0, 1.5]]), mindspore.float32)
  2957. >>> inplace_update = P.InplaceUpdate(indices)
  2958. >>> result = inplace_update(x, v)
  2959. [[0.5, 1.0],
  2960. [1.0, 1.5],
  2961. [5.0, 6.0]]
  2962. """
  2963. @prim_attr_register
  2964. def __init__(self, indices):
  2965. """Initialize InplaceUpdate"""
  2966. self.init_prim_io_names(inputs=['x', 'v'], outputs=['y'])
  2967. self.indices = indices
  2968. validator.check_value_type("indices", indices, [int, tuple], self.name)
  2969. if isinstance(indices, int):
  2970. self.indices = (indices,)
  2971. for item in self.indices:
  2972. validator.check_value_type("item of indices", item, [int], self.name)
  2973. def infer_dtype(self, x_dtype, v_dtype):
  2974. args = {'x': x_dtype, 'v': v_dtype}
  2975. valid_type = [mstype.int32, mstype.float16, mstype.float32]
  2976. validator.check_tensor_type_same(args, valid_type, self.name)
  2977. return x_dtype
  2978. def infer_shape(self, x_shape, v_shape):
  2979. validator.check("x", len(x_shape), "v", len(v_shape), Rel.EQ, self.name)
  2980. validator.check("size of indices", len(self.indices), "v's first dimension", v_shape[0],
  2981. Rel.EQ, self.name)
  2982. for i in self.indices:
  2983. if i < 0 or i >= x_shape[0]:
  2984. raise ValueError(f'The value of indices must be in [0, {x_shape[0]}), but got {i}.')
  2985. x_rank = len(x_shape)
  2986. for idx in range(x_rank)[1:]:
  2987. validator.check('v dim %d' % idx, v_shape[idx], "x dim %d" % idx, x_shape[idx], Rel.EQ, self.name)
  2988. return x_shape
  2989. class ReverseSequence(PrimitiveWithInfer):
  2990. """
  2991. Reverses variable length slices.
  2992. Args:
  2993. seq_dim (int): The dimension where reversal is performed. Required.
  2994. batch_dim (int): The input is sliced in this dimension. Default: 0.
  2995. Inputs:
  2996. - **x** (Tensor) - The input to reverse, supporting all number types including bool.
  2997. - **seq_lengths** (Tensor) - Must be a 1-D vector with int32 or int64 types.
  2998. Outputs:
  2999. Reversed tensor with the same shape and data type as input.
  3000. Examples:
  3001. >>> x = Tensor(np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]), mindspore.float32)
  3002. >>> seq_lengths = Tensor(np.array([1, 2, 3]))
  3003. >>> reverse_sequence = P.ReverseSequence(seq_dim=1)
  3004. >>> output = reverse_sequence(x, seq_lengths)
  3005. [[1 2 3]
  3006. [5 4 6]
  3007. [9 8 7]]
  3008. """
  3009. @prim_attr_register
  3010. def __init__(self, seq_dim, batch_dim=0):
  3011. """Initialize ReverseSequence"""
  3012. self.init_prim_io_names(inputs=['x', 'seq_lengths'], outputs=['y'])
  3013. validator.check_value_type("seq_dim", seq_dim, [int], self.name)
  3014. self.seq_dim_ = seq_dim
  3015. validator.check_value_type("batch_dim", batch_dim, [int], self.name)
  3016. self.batch_dim_ = batch_dim
  3017. def infer_shape(self, x, seq_lengths):
  3018. validator.check("seq_dim", self.seq_dim_, "x rank", len(x), Rel.LE, self.name)
  3019. validator.check("batch_dim", self.batch_dim_, "x rank", len(x), Rel.LE, self.name)
  3020. validator.check("batch_dim", self.batch_dim_, "seq_dim", self.seq_dim_, Rel.NE, self.name)
  3021. validator.check("seq_lengths rank", len(seq_lengths), "expected", 1, Rel.EQ, self.name)
  3022. validator.check("seq_lengths vector size", seq_lengths[0],
  3023. "input size along batch_dim", x[self.batch_dim_], Rel.EQ, self.name)
  3024. return x
  3025. def infer_dtype(self, x, seq_lengths):
  3026. validator.check_tensor_type_same({"x_dtype": x}, mstype.number_type + (mstype.bool_,), self.name)
  3027. validator.check_tensor_type_same({"seq_lengths_dtype": seq_lengths}, [mstype.int32, mstype.int64], self.name)
  3028. return x
  3029. class EditDistance(PrimitiveWithInfer):
  3030. """
  3031. Computes the Levebshtein Edit Distance. It is used to measure the similarity of two sequences.
  3032. Args:
  3033. normalize (bool): If true, edit distances are normalized by length of truth. Default: True.
  3034. Inputs:
  3035. - **hypothesis_indices** (Tensor) - The indices of the hypothesis list SparseTensor. With int64 data type.
  3036. The shape of tensor is :math:`(N, R)`.
  3037. - **hypothesis_values** (Tensor) - The values of the hypothesis list SparseTensor.
  3038. Must be 1-D vector with length of N.
  3039. - **hypothesis_shape** (Tensor) - The shape of the hypothesis list SparseTensor.
  3040. Must be R-length vector with int64 data type. Only constant value is allowed.
  3041. - **truth_indices** (Tensor) - The indices of the truth list SparseTensor. With int64 data type.
  3042. The shape of tensor is :math:`(M, R)`.
  3043. - **truth_values** (Tensor) - The values of the truth list SparseTensor. Must be 1-D vector with length of M.
  3044. - **truth_shape** (Tensor) - The shape of the truth list SparseTensor.
  3045. Must be R-length vector with int64 data type. Only constant value is allowed.
  3046. Outputs:
  3047. Tensor, a dense tensor with rank `R-1` and float32 data type.
  3048. Examples:
  3049. >>> import numpy as np
  3050. >>> from mindspore import context
  3051. >>> from mindspore import Tensor
  3052. >>> import mindspore.nn as nn
  3053. >>> import mindspore.ops.operations as P
  3054. >>> context.set_context(mode=context.GRAPH_MODE)
  3055. >>> class EditDistance(nn.Cell):
  3056. >>> def __init__(self, hypothesis_shape, truth_shape, normalize=True):
  3057. >>> super(EditDistance, self).__init__()
  3058. >>> self.edit_distance = P.EditDistance(normalize)
  3059. >>> self.hypothesis_shape = hypothesis_shape
  3060. >>> self.truth_shape = truth_shape
  3061. >>>
  3062. >>> def construct(self, hypothesis_indices, hypothesis_values, truth_indices, truth_values):
  3063. >>> return self.edit_distance(hypothesis_indices, hypothesis_values, self.hypothesis_shape,
  3064. >>> truth_indices, truth_values, self.truth_shape)
  3065. >>>
  3066. >>> hypothesis_indices = Tensor(np.array([[0, 0, 0], [1, 0, 1], [1, 1, 1]]).astype(np.int64))
  3067. >>> hypothesis_values = Tensor(np.array([1, 2, 3]).astype(np.float32))
  3068. >>> hypothesis_shape = Tensor(np.array([1, 1, 2]).astype(np.int64))
  3069. >>> truth_indices = Tensor(np.array([[0, 1, 0], [0, 0, 1], [1, 1, 0], [1, 0, 1]]).astype(np.int64))
  3070. >>> truth_values = Tensor(np.array([1, 3, 2, 1]).astype(np.float32))
  3071. >>> truth_shape = Tensor(np.array([2, 2, 2]).astype(np.int64))
  3072. >>> edit_distance = EditDistance(hypothesis_shape, truth_shape)
  3073. >>> out = edit_distance(hypothesis_indices, hypothesis_values, truth_indices, truth_values)
  3074. >>> [[1.0, 1.0], [1.0, 1.0]]
  3075. """
  3076. @prim_attr_register
  3077. def __init__(self, normalize=True):
  3078. """Initialize EditDistance"""
  3079. self.normalize = validator.check_value_type("normalize", normalize, [bool], self.name)
  3080. self.set_const_input_indexes([2, 5])
  3081. def __infer__(self, h_indices, h_values, h_shape, truth_indices, truth_values, truth_shape):
  3082. validator.check_const_input('hypothesis_shape', h_shape['value'], self.name)
  3083. validator.check_const_input('truth_shape', truth_shape['value'], self.name)
  3084. args_int = {"hypothesis_indices": h_indices['dtype'], "hypothesis_shape": h_shape['dtype'],
  3085. "truth_indices": truth_indices['dtype'], "truth_shape": truth_shape['dtype']}
  3086. validator.check_tensor_type_same(args_int, [mstype.int64], self.name)
  3087. args = {"hypothesis_values": h_values['dtype'], "truth_values": truth_values['dtype']}
  3088. validator.check_tensor_type_same(args, mstype.number_type, self.name)
  3089. hypothesis_indices_shp, truth_indices_shp = h_indices['shape'], truth_indices['shape']
  3090. validator.check("hypothesis_indices rank", len(hypothesis_indices_shp), "expected", 2, Rel.EQ, self.name)
  3091. validator.check("truth_indices rank", len(truth_indices_shp), "expected", 2, Rel.EQ, self.name)
  3092. validator.check("hypothesis_values rank", len(h_values['shape']), "expected", 1, Rel.EQ, self.name)
  3093. validator.check("hypothesis_shape rank", len(h_shape['shape']), "expected", 1, Rel.EQ, self.name)
  3094. validator.check("truth_values rank", len(truth_values['shape']), "expected", 1, Rel.EQ, self.name)
  3095. validator.check("truth_shape rank", len(truth_shape['shape']), "expected", 1, Rel.EQ, self.name)
  3096. validator.check("hypothesis_values shape", h_values['shape'][0],
  3097. "hypothesis_indices shape[0]", hypothesis_indices_shp[0], Rel.EQ, self.name)
  3098. validator.check("hypothesis_shape", h_shape['shape'][0],
  3099. "hypothesis_indices shape[1]", hypothesis_indices_shp[1], Rel.EQ, self.name)
  3100. validator.check("truth_values shape", truth_values['shape'][0],
  3101. "truth_indices shape[0]", truth_indices_shp[0], Rel.EQ, self.name)
  3102. validator.check("hypothesis_shape", h_shape['shape'][0],
  3103. "truth_shape", truth_shape['shape'][0], Rel.EQ, self.name)
  3104. hypothesis_shape_v = h_shape['value'].asnumpy()
  3105. truth_shape_v = truth_shape['value'].asnumpy()
  3106. out_shape_rank = len(hypothesis_shape_v) - 1
  3107. out_shape = []
  3108. for i in range(out_shape_rank):
  3109. out_shape.append(max(hypothesis_shape_v[i], truth_shape_v[i]))
  3110. return {'shape': tuple(out_shape),
  3111. 'dtype': mstype.tensor_type(mstype.float32),
  3112. 'value': None}
  3113. class TransShape(PrimitiveWithInfer):
  3114. """
  3115. Transforms the shape of input tensor to target shape.
  3116. Inputs:
  3117. - **input_x** (Tensor) - A input tensor.
  3118. - **out_shape** (tuple[int]) - The shape of output data.
  3119. Outputs:
  3120. Tensor, a tensor whose data type is same as 'input_x', and the shape is the same as the `out_shape`.
  3121. """
  3122. @prim_attr_register
  3123. def __init__(self):
  3124. self.__setattr_flag__ = True
  3125. def __infer__(self, x, shape):
  3126. shp = shape['value']
  3127. dtype = x['dtype']
  3128. validator.check_tensor_type_same({'x': dtype}, mstype.number_type + (mstype.bool_,), self.name)
  3129. self.add_prim_attr('out_shape', tuple(shp))
  3130. return {'shape': shp,
  3131. 'dtype': dtype,
  3132. 'value': None}
  3133. class Sort(PrimitiveWithInfer):
  3134. """
  3135. Sorts the elements of the input tensor along a given dimension in ascending order by value.
  3136. Args:
  3137. axis (int): The dimension to sort along. Default: -1.
  3138. descending (bool): Controls the sorting order. If descending is True then the elements
  3139. are sorted in descending order by value. Default: False.
  3140. Inputs:
  3141. - **x** (Tensor) - The input to sort, with float16 or float32 data type.
  3142. Outputs:
  3143. - **y1** (Tensor) - A tensor whose values are the sorted values, with the same shape and data type as input.
  3144. - **y2** (Tensor) - The indices of the elements in the original input tensor. Data type is int32.
  3145. Examples:
  3146. >>> x = Tensor(np.array([[8, 2, 1], [5, 9, 3], [4, 6, 7]]), mindspore.float16)
  3147. >>> sort = P.Sort()
  3148. >>> sort(x)
  3149. >>> ([[1.0, 2.0, 8.0], [3.0, 5.0, 9.0], [4.0, 6.0 ,7.0]],
  3150. [[2, 1, 0], [2, 0, 1], [0, 1, 2]])
  3151. """
  3152. @prim_attr_register
  3153. def __init__(self, axis=-1, descending=False):
  3154. """Initialize Sort"""
  3155. self.axis = validator.check_value_type("axis", axis, [int], self.name)
  3156. self.descending = validator.check_value_type("descending", descending, [bool], self.name)
  3157. def infer_shape(self, x_shape):
  3158. return x_shape, x_shape
  3159. def infer_dtype(self, x_dtype):
  3160. validator.check_tensor_type_same({"x_dtype": x_dtype}, [mstype.float32, mstype.float16], self.name)
  3161. return x_dtype, mstype.tensor_type(mstype.int32)
  3162. class EmbeddingLookup(PrimitiveWithInfer):
  3163. """
  3164. Returns a slice of input tensor based on the specified indices.
  3165. This Primitive has the similar functionality as GatherV2 operating on `axis = 0`, but has one more inputs:
  3166. `offset`.
  3167. Inputs:
  3168. - **input_params** (Tensor) - The shape of tensor is :math:`(x_1, x_2, ..., x_R)`.
  3169. This represents a Tensor slice, instead of the entire Tensor. Currently, the dimension is restricted to be 2.
  3170. - **input_indices** (Tensor) - The shape of tensor is :math:`(y_1, y_2, ..., y_S)`.
  3171. Specifies the indices of elements of the original Tensor. Values can be out of range of `input_params`,
  3172. and the exceeding part will be filled with 0 in the output.
  3173. - **offset** (int) - Specifies the offset value of this `input_params` slice. Thus the real indices
  3174. are equal to `input_indices` minus `offset`.
  3175. Outputs:
  3176. Tensor, the shape of tensor is :math:`(z_1, z_2, ..., z_N)`.
  3177. Examples:
  3178. >>> input_params = Tensor(np.array([[8, 9], [10, 11], [12, 13], [14, 15]]), mindspore.float32)
  3179. >>> input_indices = Tensor(np.array([[5, 2], [8, 5]]), mindspore.int32)
  3180. >>> offset = 4
  3181. >>> out = P.EmbeddingLookup()(input_params, input_indices, offset)
  3182. [[[10, 11], [0 ,0]], [[0, 0], [10, 11]]]
  3183. """
  3184. @prim_attr_register
  3185. def __init__(self):
  3186. """Initialize index_select"""
  3187. self.__setattr_flag__ = True
  3188. self.init_prim_io_names(inputs=['params', 'indices', 'offset'],
  3189. outputs=['output'])
  3190. def __infer__(self, params, indices, offset):
  3191. validator.check_subclass("params", params['dtype'], mstype.tensor, self.name)
  3192. validator.check_tensor_type_same({"indices": indices['dtype']}, mstype.int_type, self.name)
  3193. validator.check_subclass("offset", offset['dtype'], mstype.int_, self.name)
  3194. params_shp = params['shape']
  3195. if len(params_shp) != 2:
  3196. raise ValueError("The dimension of 'params' in EmbeddingLookup must be 2, but got %d." % len(params_shp))
  3197. out_shape = indices['shape'] + params_shp[1:]
  3198. out = {'shape': out_shape,
  3199. 'dtype': params['dtype'],
  3200. 'value': None}
  3201. return out
  3202. class GatherD(PrimitiveWithInfer):
  3203. """
  3204. Gathers values along an axis specified by dim.
  3205. Inputs:
  3206. - **x** (Tensor) - The source tensor.
  3207. - **dim** (int) - The axis along which to index. It must be int32. Only constant value is allowed.
  3208. - **index** (Tensor) - The indices of elements to gather. It can be one of the following data types:
  3209. int32, int64.
  3210. Outputs:
  3211. Tensor, the shape of tensor is :math:`(z_1, z_2, ..., z_N)`.
  3212. Examples:
  3213. >>> x = Tensor(np.array([[1, 2], [3, 4]]), mindspore.int32)
  3214. >>> index = Tensor(np.array([[0, 0], [1, 0]]), mindspore.int32)
  3215. >>> dim = 1
  3216. >>> out = P.GatherD()(x, dim, index)
  3217. [[1, 1], [4, 3]]
  3218. """
  3219. @prim_attr_register
  3220. def __init__(self):
  3221. """Initialize GatherD"""
  3222. def __infer__(self, x, dim, index):
  3223. validator.check_subclass("x", x['dtype'], mstype.tensor, self.name)
  3224. validator.check_tensor_type_same({"index": index['dtype']}, [mstype.int32, mstype.int64], self.name)
  3225. validator.check_subclass("dim", dim['dtype'], mstype.int32, self.name)
  3226. x_shp = x['shape']
  3227. idx_shp = index['shape']
  3228. x_rank = len(x_shp)
  3229. idx_rank = len(idx_shp)
  3230. validator.check("x_rank, idx_rank", x_rank, "expected", idx_rank, Rel.EQ, self.name)
  3231. dim_v = dim['value']
  3232. validator.check("dim value", dim_v, "expected", 0, Rel.GE, self.name)
  3233. validator.check("dim value", dim_v, "expected", x_rank, Rel.LT, self.name)
  3234. for i in range(x_rank):
  3235. if i == dim_v:
  3236. continue
  3237. validator.check("x_shp[{0}], idx_shp[{0}]".format(i), x_shp[i], "expected", idx_shp[i], Rel.EQ, self.name)
  3238. out = {'shape': index['shape'],
  3239. 'dtype': x['dtype'],
  3240. 'value': None}
  3241. return out
  3242. class Identity(PrimitiveWithInfer):
  3243. """
  3244. Returns a Tensor with the same shape and contents as input.
  3245. Inputs:
  3246. - **x** (Tensor) - The shape of tensor is :math:`(x_1, x_2, ..., x_R)`.
  3247. Outputs:
  3248. Tensor, the shape of tensor is the same as `input_x`, :math:`(x_1, x_2, ..., x_R)`.
  3249. Examples:
  3250. >>> x = Tensor(np.array([1, 2, 3, 4]), mindspore.int64)
  3251. >>> y = P.Identity()(x)
  3252. [1, 2, 3, 4]
  3253. """
  3254. @prim_attr_register
  3255. def __init__(self):
  3256. """Initialize identity"""
  3257. def __infer__(self, x):
  3258. out = {'shape': x['shape'],
  3259. 'dtype': x['dtype'],
  3260. 'value': None}
  3261. return out
  3262. class RepeatElements(PrimitiveWithInfer):
  3263. """
  3264. Repeat elements of a tensor along an axis, like np.repeat.
  3265. Args:
  3266. rep (int): The number of times to repeat, must be positive, required.
  3267. axis (int): The axis along which to repeat, default 0.
  3268. Inputs:
  3269. - **x** (Tensor) - The tensor to repeat values for. Must be of type int32 or float16.
  3270. Outputs:
  3271. One tensor with values repeated along the specified axis. If x has shape
  3272. (s1, s2, ..., sn) and axis is i, the output will have shape (s1, s2, ..., si * rep, ..., sn)
  3273. Examples:
  3274. >>> x = Tensor(np.array([[0, 1, 2], [3, 4, 5]]), mindspore.int32)
  3275. >>> repeat_elements = P.RepeatElements(rep = 2, axis = 0)
  3276. >>> output = repeat_elements(x)
  3277. [[0, 1, 2],
  3278. [0, 1, 2],
  3279. [3, 4, 5],
  3280. [3, 4, 5]],
  3281. """
  3282. @prim_attr_register
  3283. def __init__(self, rep, axis=0):
  3284. self.init_prim_io_names(inputs=["x"], outputs=["output"])
  3285. validator.check_value_type("rep", rep, [int], self.name)
  3286. self.rep = rep
  3287. validator.check_value_type("axis", axis, [int], self.name)
  3288. self.axis = axis
  3289. def infer_shape(self, x_shape):
  3290. validator.check("rep", self.rep, "", 0, Rel.GT, self.name)
  3291. validator.check("axis", self.axis, "dimension of x", len(x_shape), Rel.LT, self.name)
  3292. validator.check("axis", self.axis, "negative dimension of x", -len(x_shape), Rel.GE, self.name)
  3293. x_shape[self.axis] *= self.rep
  3294. return x_shape
  3295. def infer_dtype(self, x_dtype):
  3296. validator.check_subclass("x_dtype", x_dtype, mstype.tensor, self.name)
  3297. return x_dtype