|
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136 |
- # Copyright 2019-2020 Huawei Technologies Co., Ltd
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ============================================================================
-
- import numpy as np
- import pytest
-
- import mindspore.context as context
- import mindspore.nn as nn
- from mindspore import Tensor
- from mindspore.ops.operations import _inner_ops as inner
- from mindspore.ops import operations as P
-
-
- class GatherNet(nn.Cell):
- def __init__(self):
- super(GatherNet, self).__init__()
- self.gather = P.Gather()
-
- def construct(self, x, indices):
- return self.gather(x, indices, 1)
-
-
- @pytest.mark.level0
- @pytest.mark.platform_x86_gpu_training
- @pytest.mark.env_onecard
- def test_gather0():
- x = Tensor(np.arange(2 * 3 * 4 * 5, dtype=np.float32).reshape(2, 3, 4, 5))
- indices = Tensor(np.ones((2, 2, 4, 5), dtype='i4'))
- expect = np.array([[[[[[[20., 21., 22., 23., 24.],
- [25., 26., 27., 28., 29.],
- [30., 31., 32., 33., 34.],
- [35., 36., 37., 38., 39.]],
-
- [[20., 21., 22., 23., 24.],
- [25., 26., 27., 28., 29.],
- [30., 31., 32., 33., 34.],
- [35., 36., 37., 38., 39.]],
-
- [[20., 21., 22., 23., 24.],
- [25., 26., 27., 28., 29.],
- [30., 31., 32., 33., 34.],
- [35., 36., 37., 38., 39.]],
-
- [[20., 21., 22., 23., 24.],
- [25., 26., 27., 28., 29.],
- [30., 31., 32., 33., 34.],
- [35., 36., 37., 38., 39.]],
-
- [[20., 21., 22., 23., 24.],
- [25., 26., 27., 28., 29.],
- [30., 31., 32., 33., 34.],
- [35., 36., 37., 38., 39.]]],
-
- [[[20., 21., 22., 23., 24.],
- [25., 26., 27., 28., 29.],
- [30., 31., 32., 33., 34.],
- [35., 36., 37., 38., 39.]],
-
- [[20., 21., 22., 23., 24.],
- [25., 26., 27., 28., 29.],
- [30., 31., 32., 33., 34.],
- [35., 36., 37., 38., 39.]],
-
- [[20., 21., 22., 23., 24.],
- [25., 26., 27., 28., 29.],
- [30., 31., 32., 33., 34.],
- [35., 36., 37., 38., 39.]],
-
- [[20., 21., 22., 23., 24.],
- [25., 26., 27., 28., 29.],
- [30., 31., 32., 33., 34.],
- [35., 36., 37., 38., 39.]],
-
- [[20., 21., 22., 23., 24.],
- [25., 26., 27., 28., 29.],
- [30., 31., 32., 33., 34.],
- [35., 36., 37., 38., 39.]]],
-
- [[[20., 21., 22., 23., 24.],
- [25., 26., 27., 28., 29.],
- [30., 31., 32., 33., 34.],
- [35., 36., 37., 38., 39.]],
-
- [[20., 21., 22., 23., 24.],
- [25., 26., 27., 28., 29.],
- [30., 31., 32., 33., 34.],
- [35., 36., 37., 38., 39.]],
-
- [[20., 21., 22., 23., 24.],
- [25., 26., 27., 28., 29.],
- [30., 31., 32., 33., 34.],
- [35., 36., 37., 38., 39.]],
-
- [[20., 21., 22., 23., 24.],
- [25., 26., 27., 28., 29.],
- [30., 31., 32., 33., 34.],
- [35., 36., 37., 38., 39.]],
-
- [[20., 21., 22., 23., 24.],
- [25., 26., 27., 28., 29.],
- [30., 31., 32., 33., 34.],
- [35., 36., 37., 38., 39.]]],
-
- [[[20., 21., 22., 23., 24.],
- [25., 26., 27., 28., 29.],
- [30., 31., 32., 33., 34.],
- [35., 36., 37., 38., 39.]],
-
- [[20., 21., 22., 23., 24.],
- [25., 26., 27., 28., 29.],
- [30., 31., 32., 33., 34.],
- [35., 36., 37., 38., 39.]],
-
- [[20., 21., 22., 23., 24.],
- [25., 26., 27., 28., 29.],
- [30., 31., 32., 33., 34.],
- [35., 36., 37., 38., 39.]],
-
- [[20., 21., 22., 23., 24.],
- [25., 26., 27., 28., 29.],
- [30., 31., 32., 33., 34.],
- [35., 36., 37., 38., 39.]],
-
- [[20., 21., 22., 23., 24.],
- [25., 26., 27., 28., 29.],
- [30., 31., 32., 33., 34.],
- [35., 36., 37., 38., 39.]]]],
-
- [[[[20., 21., 22., 23., 24.],
- [25., 26., 27., 28., 29.],
- [30., 31., 32., 33., 34.],
- [35., 36., 37., 38., 39.]],
-
- [[20., 21., 22., 23., 24.],
- [25., 26., 27., 28., 29.],
- [30., 31., 32., 33., 34.],
- [35., 36., 37., 38., 39.]],
-
- [[20., 21., 22., 23., 24.],
- [25., 26., 27., 28., 29.],
- [30., 31., 32., 33., 34.],
- [35., 36., 37., 38., 39.]],
-
- [[20., 21., 22., 23., 24.],
- [25., 26., 27., 28., 29.],
- [30., 31., 32., 33., 34.],
- [35., 36., 37., 38., 39.]],
-
- [[20., 21., 22., 23., 24.],
- [25., 26., 27., 28., 29.],
- [30., 31., 32., 33., 34.],
- [35., 36., 37., 38., 39.]]],
-
- [[[20., 21., 22., 23., 24.],
- [25., 26., 27., 28., 29.],
- [30., 31., 32., 33., 34.],
- [35., 36., 37., 38., 39.]],
-
- [[20., 21., 22., 23., 24.],
- [25., 26., 27., 28., 29.],
- [30., 31., 32., 33., 34.],
- [35., 36., 37., 38., 39.]],
-
- [[20., 21., 22., 23., 24.],
- [25., 26., 27., 28., 29.],
- [30., 31., 32., 33., 34.],
- [35., 36., 37., 38., 39.]],
-
- [[20., 21., 22., 23., 24.],
- [25., 26., 27., 28., 29.],
- [30., 31., 32., 33., 34.],
- [35., 36., 37., 38., 39.]],
-
- [[20., 21., 22., 23., 24.],
- [25., 26., 27., 28., 29.],
- [30., 31., 32., 33., 34.],
- [35., 36., 37., 38., 39.]]],
-
- [[[20., 21., 22., 23., 24.],
- [25., 26., 27., 28., 29.],
- [30., 31., 32., 33., 34.],
- [35., 36., 37., 38., 39.]],
-
- [[20., 21., 22., 23., 24.],
- [25., 26., 27., 28., 29.],
- [30., 31., 32., 33., 34.],
- [35., 36., 37., 38., 39.]],
-
- [[20., 21., 22., 23., 24.],
- [25., 26., 27., 28., 29.],
- [30., 31., 32., 33., 34.],
- [35., 36., 37., 38., 39.]],
-
- [[20., 21., 22., 23., 24.],
- [25., 26., 27., 28., 29.],
- [30., 31., 32., 33., 34.],
- [35., 36., 37., 38., 39.]],
-
- [[20., 21., 22., 23., 24.],
- [25., 26., 27., 28., 29.],
- [30., 31., 32., 33., 34.],
- [35., 36., 37., 38., 39.]]],
-
- [[[20., 21., 22., 23., 24.],
- [25., 26., 27., 28., 29.],
- [30., 31., 32., 33., 34.],
- [35., 36., 37., 38., 39.]],
-
- [[20., 21., 22., 23., 24.],
- [25., 26., 27., 28., 29.],
- [30., 31., 32., 33., 34.],
- [35., 36., 37., 38., 39.]],
-
- [[20., 21., 22., 23., 24.],
- [25., 26., 27., 28., 29.],
- [30., 31., 32., 33., 34.],
- [35., 36., 37., 38., 39.]],
-
- [[20., 21., 22., 23., 24.],
- [25., 26., 27., 28., 29.],
- [30., 31., 32., 33., 34.],
- [35., 36., 37., 38., 39.]],
-
- [[20., 21., 22., 23., 24.],
- [25., 26., 27., 28., 29.],
- [30., 31., 32., 33., 34.],
- [35., 36., 37., 38., 39.]]]]],
-
- [[[[[20., 21., 22., 23., 24.],
- [25., 26., 27., 28., 29.],
- [30., 31., 32., 33., 34.],
- [35., 36., 37., 38., 39.]],
-
- [[20., 21., 22., 23., 24.],
- [25., 26., 27., 28., 29.],
- [30., 31., 32., 33., 34.],
- [35., 36., 37., 38., 39.]],
-
- [[20., 21., 22., 23., 24.],
- [25., 26., 27., 28., 29.],
- [30., 31., 32., 33., 34.],
- [35., 36., 37., 38., 39.]],
-
- [[20., 21., 22., 23., 24.],
- [25., 26., 27., 28., 29.],
- [30., 31., 32., 33., 34.],
- [35., 36., 37., 38., 39.]],
-
- [[20., 21., 22., 23., 24.],
- [25., 26., 27., 28., 29.],
- [30., 31., 32., 33., 34.],
- [35., 36., 37., 38., 39.]]],
-
- [[[20., 21., 22., 23., 24.],
- [25., 26., 27., 28., 29.],
- [30., 31., 32., 33., 34.],
- [35., 36., 37., 38., 39.]],
-
- [[20., 21., 22., 23., 24.],
- [25., 26., 27., 28., 29.],
- [30., 31., 32., 33., 34.],
- [35., 36., 37., 38., 39.]],
-
- [[20., 21., 22., 23., 24.],
- [25., 26., 27., 28., 29.],
- [30., 31., 32., 33., 34.],
- [35., 36., 37., 38., 39.]],
-
- [[20., 21., 22., 23., 24.],
- [25., 26., 27., 28., 29.],
- [30., 31., 32., 33., 34.],
- [35., 36., 37., 38., 39.]],
-
- [[20., 21., 22., 23., 24.],
- [25., 26., 27., 28., 29.],
- [30., 31., 32., 33., 34.],
- [35., 36., 37., 38., 39.]]],
-
- [[[20., 21., 22., 23., 24.],
- [25., 26., 27., 28., 29.],
- [30., 31., 32., 33., 34.],
- [35., 36., 37., 38., 39.]],
-
- [[20., 21., 22., 23., 24.],
- [25., 26., 27., 28., 29.],
- [30., 31., 32., 33., 34.],
- [35., 36., 37., 38., 39.]],
-
- [[20., 21., 22., 23., 24.],
- [25., 26., 27., 28., 29.],
- [30., 31., 32., 33., 34.],
- [35., 36., 37., 38., 39.]],
-
- [[20., 21., 22., 23., 24.],
- [25., 26., 27., 28., 29.],
- [30., 31., 32., 33., 34.],
- [35., 36., 37., 38., 39.]],
-
- [[20., 21., 22., 23., 24.],
- [25., 26., 27., 28., 29.],
- [30., 31., 32., 33., 34.],
- [35., 36., 37., 38., 39.]]],
-
- [[[20., 21., 22., 23., 24.],
- [25., 26., 27., 28., 29.],
- [30., 31., 32., 33., 34.],
- [35., 36., 37., 38., 39.]],
-
- [[20., 21., 22., 23., 24.],
- [25., 26., 27., 28., 29.],
- [30., 31., 32., 33., 34.],
- [35., 36., 37., 38., 39.]],
-
- [[20., 21., 22., 23., 24.],
- [25., 26., 27., 28., 29.],
- [30., 31., 32., 33., 34.],
- [35., 36., 37., 38., 39.]],
-
- [[20., 21., 22., 23., 24.],
- [25., 26., 27., 28., 29.],
- [30., 31., 32., 33., 34.],
- [35., 36., 37., 38., 39.]],
-
- [[20., 21., 22., 23., 24.],
- [25., 26., 27., 28., 29.],
- [30., 31., 32., 33., 34.],
- [35., 36., 37., 38., 39.]]]],
-
- [[[[20., 21., 22., 23., 24.],
- [25., 26., 27., 28., 29.],
- [30., 31., 32., 33., 34.],
- [35., 36., 37., 38., 39.]],
-
- [[20., 21., 22., 23., 24.],
- [25., 26., 27., 28., 29.],
- [30., 31., 32., 33., 34.],
- [35., 36., 37., 38., 39.]],
-
- [[20., 21., 22., 23., 24.],
- [25., 26., 27., 28., 29.],
- [30., 31., 32., 33., 34.],
- [35., 36., 37., 38., 39.]],
-
- [[20., 21., 22., 23., 24.],
- [25., 26., 27., 28., 29.],
- [30., 31., 32., 33., 34.],
- [35., 36., 37., 38., 39.]],
-
- [[20., 21., 22., 23., 24.],
- [25., 26., 27., 28., 29.],
- [30., 31., 32., 33., 34.],
- [35., 36., 37., 38., 39.]]],
-
- [[[20., 21., 22., 23., 24.],
- [25., 26., 27., 28., 29.],
- [30., 31., 32., 33., 34.],
- [35., 36., 37., 38., 39.]],
-
- [[20., 21., 22., 23., 24.],
- [25., 26., 27., 28., 29.],
- [30., 31., 32., 33., 34.],
- [35., 36., 37., 38., 39.]],
-
- [[20., 21., 22., 23., 24.],
- [25., 26., 27., 28., 29.],
- [30., 31., 32., 33., 34.],
- [35., 36., 37., 38., 39.]],
-
- [[20., 21., 22., 23., 24.],
- [25., 26., 27., 28., 29.],
- [30., 31., 32., 33., 34.],
- [35., 36., 37., 38., 39.]],
-
- [[20., 21., 22., 23., 24.],
- [25., 26., 27., 28., 29.],
- [30., 31., 32., 33., 34.],
- [35., 36., 37., 38., 39.]]],
-
- [[[20., 21., 22., 23., 24.],
- [25., 26., 27., 28., 29.],
- [30., 31., 32., 33., 34.],
- [35., 36., 37., 38., 39.]],
-
- [[20., 21., 22., 23., 24.],
- [25., 26., 27., 28., 29.],
- [30., 31., 32., 33., 34.],
- [35., 36., 37., 38., 39.]],
-
- [[20., 21., 22., 23., 24.],
- [25., 26., 27., 28., 29.],
- [30., 31., 32., 33., 34.],
- [35., 36., 37., 38., 39.]],
-
- [[20., 21., 22., 23., 24.],
- [25., 26., 27., 28., 29.],
- [30., 31., 32., 33., 34.],
- [35., 36., 37., 38., 39.]],
-
- [[20., 21., 22., 23., 24.],
- [25., 26., 27., 28., 29.],
- [30., 31., 32., 33., 34.],
- [35., 36., 37., 38., 39.]]],
-
- [[[20., 21., 22., 23., 24.],
- [25., 26., 27., 28., 29.],
- [30., 31., 32., 33., 34.],
- [35., 36., 37., 38., 39.]],
-
- [[20., 21., 22., 23., 24.],
- [25., 26., 27., 28., 29.],
- [30., 31., 32., 33., 34.],
- [35., 36., 37., 38., 39.]],
-
- [[20., 21., 22., 23., 24.],
- [25., 26., 27., 28., 29.],
- [30., 31., 32., 33., 34.],
- [35., 36., 37., 38., 39.]],
-
- [[20., 21., 22., 23., 24.],
- [25., 26., 27., 28., 29.],
- [30., 31., 32., 33., 34.],
- [35., 36., 37., 38., 39.]],
-
- [[20., 21., 22., 23., 24.],
- [25., 26., 27., 28., 29.],
- [30., 31., 32., 33., 34.],
- [35., 36., 37., 38., 39.]]]]]],
-
- [[[[[[80., 81., 82., 83., 84.],
- [85., 86., 87., 88., 89.],
- [90., 91., 92., 93., 94.],
- [95., 96., 97., 98., 99.]],
-
- [[80., 81., 82., 83., 84.],
- [85., 86., 87., 88., 89.],
- [90., 91., 92., 93., 94.],
- [95., 96., 97., 98., 99.]],
-
- [[80., 81., 82., 83., 84.],
- [85., 86., 87., 88., 89.],
- [90., 91., 92., 93., 94.],
- [95., 96., 97., 98., 99.]],
-
- [[80., 81., 82., 83., 84.],
- [85., 86., 87., 88., 89.],
- [90., 91., 92., 93., 94.],
- [95., 96., 97., 98., 99.]],
-
- [[80., 81., 82., 83., 84.],
- [85., 86., 87., 88., 89.],
- [90., 91., 92., 93., 94.],
- [95., 96., 97., 98., 99.]]],
-
- [[[80., 81., 82., 83., 84.],
- [85., 86., 87., 88., 89.],
- [90., 91., 92., 93., 94.],
- [95., 96., 97., 98., 99.]],
-
- [[80., 81., 82., 83., 84.],
- [85., 86., 87., 88., 89.],
- [90., 91., 92., 93., 94.],
- [95., 96., 97., 98., 99.]],
-
- [[80., 81., 82., 83., 84.],
- [85., 86., 87., 88., 89.],
- [90., 91., 92., 93., 94.],
- [95., 96., 97., 98., 99.]],
-
- [[80., 81., 82., 83., 84.],
- [85., 86., 87., 88., 89.],
- [90., 91., 92., 93., 94.],
- [95., 96., 97., 98., 99.]],
-
- [[80., 81., 82., 83., 84.],
- [85., 86., 87., 88., 89.],
- [90., 91., 92., 93., 94.],
- [95., 96., 97., 98., 99.]]],
-
- [[[80., 81., 82., 83., 84.],
- [85., 86., 87., 88., 89.],
- [90., 91., 92., 93., 94.],
- [95., 96., 97., 98., 99.]],
-
- [[80., 81., 82., 83., 84.],
- [85., 86., 87., 88., 89.],
- [90., 91., 92., 93., 94.],
- [95., 96., 97., 98., 99.]],
-
- [[80., 81., 82., 83., 84.],
- [85., 86., 87., 88., 89.],
- [90., 91., 92., 93., 94.],
- [95., 96., 97., 98., 99.]],
-
- [[80., 81., 82., 83., 84.],
- [85., 86., 87., 88., 89.],
- [90., 91., 92., 93., 94.],
- [95., 96., 97., 98., 99.]],
-
- [[80., 81., 82., 83., 84.],
- [85., 86., 87., 88., 89.],
- [90., 91., 92., 93., 94.],
- [95., 96., 97., 98., 99.]]],
-
- [[[80., 81., 82., 83., 84.],
- [85., 86., 87., 88., 89.],
- [90., 91., 92., 93., 94.],
- [95., 96., 97., 98., 99.]],
-
- [[80., 81., 82., 83., 84.],
- [85., 86., 87., 88., 89.],
- [90., 91., 92., 93., 94.],
- [95., 96., 97., 98., 99.]],
-
- [[80., 81., 82., 83., 84.],
- [85., 86., 87., 88., 89.],
- [90., 91., 92., 93., 94.],
- [95., 96., 97., 98., 99.]],
-
- [[80., 81., 82., 83., 84.],
- [85., 86., 87., 88., 89.],
- [90., 91., 92., 93., 94.],
- [95., 96., 97., 98., 99.]],
-
- [[80., 81., 82., 83., 84.],
- [85., 86., 87., 88., 89.],
- [90., 91., 92., 93., 94.],
- [95., 96., 97., 98., 99.]]]],
-
- [[[[80., 81., 82., 83., 84.],
- [85., 86., 87., 88., 89.],
- [90., 91., 92., 93., 94.],
- [95., 96., 97., 98., 99.]],
-
- [[80., 81., 82., 83., 84.],
- [85., 86., 87., 88., 89.],
- [90., 91., 92., 93., 94.],
- [95., 96., 97., 98., 99.]],
-
- [[80., 81., 82., 83., 84.],
- [85., 86., 87., 88., 89.],
- [90., 91., 92., 93., 94.],
- [95., 96., 97., 98., 99.]],
-
- [[80., 81., 82., 83., 84.],
- [85., 86., 87., 88., 89.],
- [90., 91., 92., 93., 94.],
- [95., 96., 97., 98., 99.]],
-
- [[80., 81., 82., 83., 84.],
- [85., 86., 87., 88., 89.],
- [90., 91., 92., 93., 94.],
- [95., 96., 97., 98., 99.]]],
-
- [[[80., 81., 82., 83., 84.],
- [85., 86., 87., 88., 89.],
- [90., 91., 92., 93., 94.],
- [95., 96., 97., 98., 99.]],
-
- [[80., 81., 82., 83., 84.],
- [85., 86., 87., 88., 89.],
- [90., 91., 92., 93., 94.],
- [95., 96., 97., 98., 99.]],
-
- [[80., 81., 82., 83., 84.],
- [85., 86., 87., 88., 89.],
- [90., 91., 92., 93., 94.],
- [95., 96., 97., 98., 99.]],
-
- [[80., 81., 82., 83., 84.],
- [85., 86., 87., 88., 89.],
- [90., 91., 92., 93., 94.],
- [95., 96., 97., 98., 99.]],
-
- [[80., 81., 82., 83., 84.],
- [85., 86., 87., 88., 89.],
- [90., 91., 92., 93., 94.],
- [95., 96., 97., 98., 99.]]],
-
- [[[80., 81., 82., 83., 84.],
- [85., 86., 87., 88., 89.],
- [90., 91., 92., 93., 94.],
- [95., 96., 97., 98., 99.]],
-
- [[80., 81., 82., 83., 84.],
- [85., 86., 87., 88., 89.],
- [90., 91., 92., 93., 94.],
- [95., 96., 97., 98., 99.]],
-
- [[80., 81., 82., 83., 84.],
- [85., 86., 87., 88., 89.],
- [90., 91., 92., 93., 94.],
- [95., 96., 97., 98., 99.]],
-
- [[80., 81., 82., 83., 84.],
- [85., 86., 87., 88., 89.],
- [90., 91., 92., 93., 94.],
- [95., 96., 97., 98., 99.]],
-
- [[80., 81., 82., 83., 84.],
- [85., 86., 87., 88., 89.],
- [90., 91., 92., 93., 94.],
- [95., 96., 97., 98., 99.]]],
-
- [[[80., 81., 82., 83., 84.],
- [85., 86., 87., 88., 89.],
- [90., 91., 92., 93., 94.],
- [95., 96., 97., 98., 99.]],
-
- [[80., 81., 82., 83., 84.],
- [85., 86., 87., 88., 89.],
- [90., 91., 92., 93., 94.],
- [95., 96., 97., 98., 99.]],
-
- [[80., 81., 82., 83., 84.],
- [85., 86., 87., 88., 89.],
- [90., 91., 92., 93., 94.],
- [95., 96., 97., 98., 99.]],
-
- [[80., 81., 82., 83., 84.],
- [85., 86., 87., 88., 89.],
- [90., 91., 92., 93., 94.],
- [95., 96., 97., 98., 99.]],
-
- [[80., 81., 82., 83., 84.],
- [85., 86., 87., 88., 89.],
- [90., 91., 92., 93., 94.],
- [95., 96., 97., 98., 99.]]]]],
-
- [[[[[80., 81., 82., 83., 84.],
- [85., 86., 87., 88., 89.],
- [90., 91., 92., 93., 94.],
- [95., 96., 97., 98., 99.]],
-
- [[80., 81., 82., 83., 84.],
- [85., 86., 87., 88., 89.],
- [90., 91., 92., 93., 94.],
- [95., 96., 97., 98., 99.]],
-
- [[80., 81., 82., 83., 84.],
- [85., 86., 87., 88., 89.],
- [90., 91., 92., 93., 94.],
- [95., 96., 97., 98., 99.]],
-
- [[80., 81., 82., 83., 84.],
- [85., 86., 87., 88., 89.],
- [90., 91., 92., 93., 94.],
- [95., 96., 97., 98., 99.]],
-
- [[80., 81., 82., 83., 84.],
- [85., 86., 87., 88., 89.],
- [90., 91., 92., 93., 94.],
- [95., 96., 97., 98., 99.]]],
-
- [[[80., 81., 82., 83., 84.],
- [85., 86., 87., 88., 89.],
- [90., 91., 92., 93., 94.],
- [95., 96., 97., 98., 99.]],
-
- [[80., 81., 82., 83., 84.],
- [85., 86., 87., 88., 89.],
- [90., 91., 92., 93., 94.],
- [95., 96., 97., 98., 99.]],
-
- [[80., 81., 82., 83., 84.],
- [85., 86., 87., 88., 89.],
- [90., 91., 92., 93., 94.],
- [95., 96., 97., 98., 99.]],
-
- [[80., 81., 82., 83., 84.],
- [85., 86., 87., 88., 89.],
- [90., 91., 92., 93., 94.],
- [95., 96., 97., 98., 99.]],
-
- [[80., 81., 82., 83., 84.],
- [85., 86., 87., 88., 89.],
- [90., 91., 92., 93., 94.],
- [95., 96., 97., 98., 99.]]],
-
- [[[80., 81., 82., 83., 84.],
- [85., 86., 87., 88., 89.],
- [90., 91., 92., 93., 94.],
- [95., 96., 97., 98., 99.]],
-
- [[80., 81., 82., 83., 84.],
- [85., 86., 87., 88., 89.],
- [90., 91., 92., 93., 94.],
- [95., 96., 97., 98., 99.]],
-
- [[80., 81., 82., 83., 84.],
- [85., 86., 87., 88., 89.],
- [90., 91., 92., 93., 94.],
- [95., 96., 97., 98., 99.]],
-
- [[80., 81., 82., 83., 84.],
- [85., 86., 87., 88., 89.],
- [90., 91., 92., 93., 94.],
- [95., 96., 97., 98., 99.]],
-
- [[80., 81., 82., 83., 84.],
- [85., 86., 87., 88., 89.],
- [90., 91., 92., 93., 94.],
- [95., 96., 97., 98., 99.]]],
-
- [[[80., 81., 82., 83., 84.],
- [85., 86., 87., 88., 89.],
- [90., 91., 92., 93., 94.],
- [95., 96., 97., 98., 99.]],
-
- [[80., 81., 82., 83., 84.],
- [85., 86., 87., 88., 89.],
- [90., 91., 92., 93., 94.],
- [95., 96., 97., 98., 99.]],
-
- [[80., 81., 82., 83., 84.],
- [85., 86., 87., 88., 89.],
- [90., 91., 92., 93., 94.],
- [95., 96., 97., 98., 99.]],
-
- [[80., 81., 82., 83., 84.],
- [85., 86., 87., 88., 89.],
- [90., 91., 92., 93., 94.],
- [95., 96., 97., 98., 99.]],
-
- [[80., 81., 82., 83., 84.],
- [85., 86., 87., 88., 89.],
- [90., 91., 92., 93., 94.],
- [95., 96., 97., 98., 99.]]]],
-
- [[[[80., 81., 82., 83., 84.],
- [85., 86., 87., 88., 89.],
- [90., 91., 92., 93., 94.],
- [95., 96., 97., 98., 99.]],
-
- [[80., 81., 82., 83., 84.],
- [85., 86., 87., 88., 89.],
- [90., 91., 92., 93., 94.],
- [95., 96., 97., 98., 99.]],
-
- [[80., 81., 82., 83., 84.],
- [85., 86., 87., 88., 89.],
- [90., 91., 92., 93., 94.],
- [95., 96., 97., 98., 99.]],
-
- [[80., 81., 82., 83., 84.],
- [85., 86., 87., 88., 89.],
- [90., 91., 92., 93., 94.],
- [95., 96., 97., 98., 99.]],
-
- [[80., 81., 82., 83., 84.],
- [85., 86., 87., 88., 89.],
- [90., 91., 92., 93., 94.],
- [95., 96., 97., 98., 99.]]],
-
- [[[80., 81., 82., 83., 84.],
- [85., 86., 87., 88., 89.],
- [90., 91., 92., 93., 94.],
- [95., 96., 97., 98., 99.]],
-
- [[80., 81., 82., 83., 84.],
- [85., 86., 87., 88., 89.],
- [90., 91., 92., 93., 94.],
- [95., 96., 97., 98., 99.]],
-
- [[80., 81., 82., 83., 84.],
- [85., 86., 87., 88., 89.],
- [90., 91., 92., 93., 94.],
- [95., 96., 97., 98., 99.]],
-
- [[80., 81., 82., 83., 84.],
- [85., 86., 87., 88., 89.],
- [90., 91., 92., 93., 94.],
- [95., 96., 97., 98., 99.]],
-
- [[80., 81., 82., 83., 84.],
- [85., 86., 87., 88., 89.],
- [90., 91., 92., 93., 94.],
- [95., 96., 97., 98., 99.]]],
-
- [[[80., 81., 82., 83., 84.],
- [85., 86., 87., 88., 89.],
- [90., 91., 92., 93., 94.],
- [95., 96., 97., 98., 99.]],
-
- [[80., 81., 82., 83., 84.],
- [85., 86., 87., 88., 89.],
- [90., 91., 92., 93., 94.],
- [95., 96., 97., 98., 99.]],
-
- [[80., 81., 82., 83., 84.],
- [85., 86., 87., 88., 89.],
- [90., 91., 92., 93., 94.],
- [95., 96., 97., 98., 99.]],
-
- [[80., 81., 82., 83., 84.],
- [85., 86., 87., 88., 89.],
- [90., 91., 92., 93., 94.],
- [95., 96., 97., 98., 99.]],
-
- [[80., 81., 82., 83., 84.],
- [85., 86., 87., 88., 89.],
- [90., 91., 92., 93., 94.],
- [95., 96., 97., 98., 99.]]],
-
- [[[80., 81., 82., 83., 84.],
- [85., 86., 87., 88., 89.],
- [90., 91., 92., 93., 94.],
- [95., 96., 97., 98., 99.]],
-
- [[80., 81., 82., 83., 84.],
- [85., 86., 87., 88., 89.],
- [90., 91., 92., 93., 94.],
- [95., 96., 97., 98., 99.]],
-
- [[80., 81., 82., 83., 84.],
- [85., 86., 87., 88., 89.],
- [90., 91., 92., 93., 94.],
- [95., 96., 97., 98., 99.]],
-
- [[80., 81., 82., 83., 84.],
- [85., 86., 87., 88., 89.],
- [90., 91., 92., 93., 94.],
- [95., 96., 97., 98., 99.]],
-
- [[80., 81., 82., 83., 84.],
- [85., 86., 87., 88., 89.],
- [90., 91., 92., 93., 94.],
- [95., 96., 97., 98., 99.]]]]]]])
-
- context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
- gather = GatherNet()
- output = gather(x, indices)
- error = np.ones(shape=output.asnumpy().shape) * 1.0e-6
- diff = output.asnumpy() - expect
- assert np.all(diff < error)
- assert np.all(-diff < error)
-
-
- class GatherNet1(nn.Cell):
- def __init__(self):
- super(GatherNet1, self).__init__()
- self.gather = P.Gather()
-
- def construct(self, x, indices):
- return self.gather(x, indices, -1)
-
-
- @pytest.mark.level0
- @pytest.mark.platform_x86_gpu_training
- @pytest.mark.env_onecard
- def test_gather1():
- x = Tensor(np.arange(2 * 3 * 4 * 5, dtype=np.float32).reshape(2, 3, 4, 5))
- indices = Tensor(np.array([1, 3, 4], dtype='i4'))
- expect = np.array([[[[1., 3., 4.],
- [6., 8., 9.],
- [11., 13., 14.],
- [16., 18., 19.]],
-
- [[21., 23., 24.],
- [26., 28., 29.],
- [31., 33., 34.],
- [36., 38., 39.]],
-
- [[41., 43., 44.],
- [46., 48., 49.],
- [51., 53., 54.],
- [56., 58., 59.]]],
-
- [[[61., 63., 64.],
- [66., 68., 69.],
- [71., 73., 74.],
- [76., 78., 79.]],
-
- [[81., 83., 84.],
- [86., 88., 89.],
- [91., 93., 94.],
- [96., 98., 99.]],
-
- [[101., 103., 104.],
- [106., 108., 109.],
- [111., 113., 114.],
- [116., 118., 119.]]]])
-
- context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
- gather = GatherNet1()
- output = gather(x, indices)
- error = np.ones(shape=output.asnumpy().shape) * 1.0e-6
- diff = output.asnumpy() - expect
- assert np.all(diff < error)
- assert np.all(-diff < error)
-
-
- class GatherNet2(nn.Cell):
- def __init__(self):
- super(GatherNet2, self).__init__()
- self.gather = P.Gather()
-
- def construct(self, x, indices):
- return self.gather(x, indices, 0)
-
-
- @pytest.mark.level0
- @pytest.mark.platform_x86_gpu_training
- @pytest.mark.env_onecard
- def test_gather2():
- x = Tensor(np.array([[4., 5., 4., 1., 5.,],
- [4., 9., 5., 6., 4.,],
- [9., 8., 4., 3., 6.,],
- [0., 4., 2., 2., 8.,],
- [1., 8., 6., 2., 8.,],
- [8., 1., 9., 7., 3.,],
- [7., 9., 2., 5., 7.,],
- [9., 8., 6., 8., 5.,],
- [3., 7., 2., 7., 4.,],
- [4., 2., 8., 2., 9.,]]
- ).astype(np.float32))
-
- indices = Tensor(np.array([[4000, 1, 300000]]).astype(np.int64))
- expect = np.array([[[0., 0., 0., 0., 0.],
- [4., 9., 5., 6., 4.],
- [0., 0., 0., 0., 0.]]])
-
- context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
- gather = GatherNet2()
- output = gather(x, indices)
- error = np.ones(shape=output.asnumpy().shape) * 1.0e-6
- diff = output.asnumpy() - expect
- assert np.all(diff < error)
- assert np.all(-diff < error)
-
-
- # Dynamic Shape testing ahead
- class GatherNetDynamic(nn.Cell):
- def __init__(self, axis=0, dyn_a=True, dyn_b=True):
- super(GatherNetDynamic, self).__init__()
- self.gather = P.Gather()
- self.gpu_convert_to_dynamic_shape = inner.GpuConvertToDynamicShape()
- self.to_dyn_1 = dyn_a
- self.to_dyn_2 = dyn_b
- self.axis = axis
- def construct(self, x, indices):
- # testing selective inputs being dynamic
- if self.to_dyn_1:
- x = self.gpu_convert_to_dynamic_shape(x)
- if self.to_dyn_2:
- indices = self.gpu_convert_to_dynamic_shape(indices)
- return self.gather(x, indices, self.axis)
-
-
- @pytest.mark.level0
- @pytest.mark.platform_x86_gpu_training
- @pytest.mark.env_onecard
- def test_gatherV2_dyn_ab():
- """
- Tests for Dynamic shape with both inputs dynamic
- """
- context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
- gather = GatherNetDynamic()
- x = Tensor(np.array([[4., 5., 4., 1., 5.,],
- [4., 9., 5., 6., 4.,],
- [9., 8., 4., 3., 6.,],
- [0., 4., 2., 2., 8.,],
- [1., 8., 6., 2., 8.,],
- [8., 1., 9., 7., 3.,],
- [7., 9., 2., 5., 7.,],
- [9., 8., 6., 8., 5.,],
- [3., 7., 2., 7., 4.,],
- [4., 2., 8., 2., 9.,]]
- ).astype(np.float32))
- indices = Tensor(np.array([[4000, 1, 300000]]).astype(np.int32))
- expect = np.array([[[0., 0., 0., 0., 0.],
- [4., 9., 5., 6., 4.],
- [0., 0., 0., 0., 0.]]])
- output = gather(x, indices)
- error = np.ones(shape=output.asnumpy().shape) * 1.0e-6
- diff = output.asnumpy() - expect
- assert np.all(diff < error)
- assert np.all(-diff < error)
-
-
- @pytest.mark.level0
- @pytest.mark.platform_x86_gpu_training
- @pytest.mark.env_onecard
- def test_gatherV2_dyn_a():
- """
- Tests for Dynamic shape with only first input dynamic
- """
- context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
- gather = GatherNetDynamic(-1, True, False)
- # test 1
- x = Tensor(np.array([[4., 5., 4., 1., 5.,],
- [4., 9., 5., 6., 4.,],
- [9., 8., 4., 3., 6.,],
- [0., 4., 2., 2., 8.,],
- [1., 8., 6., 2., 8.,],
- [8., 1., 9., 7., 3.,],
- [7., 9., 2., 5., 7.,],
- [9., 8., 6., 8., 5.,],
- [3., 7., 2., 7., 4.,],
- [4., 2., 8., 2., 9.,]]
- ).astype(np.float32))
- indices = Tensor(np.array([[4000, 1, 300000]]).astype(np.int64))
- expect = np.array([[[0., 5., 0.]],
- [[0., 9., 0.]],
- [[0., 8., 0.]],
- [[0., 4., 0.]],
- [[0., 8., 0.]],
- [[0., 1., 0.]],
- [[0., 9., 0.]],
- [[0., 8., 0.]],
- [[0., 7., 0.]],
- [[0., 2., 0.]]]).astype(np.float32)
- output = gather(x, indices)
- error = np.ones(shape=output.asnumpy().shape) * 1.0e-6
- diff = output.asnumpy() - expect
- assert np.all(diff < error)
- assert np.all(-diff < error)
- # test 2
- x = Tensor(np.arange(2 * 3 * 4 * 5, dtype=np.float32).reshape(2, 3, 4, 5))
- indices = Tensor(np.array([1, 3, 4], dtype='i4'))
- expect = np.array([[[[1., 3., 4.],
- [6., 8., 9.],
- [11., 13., 14.],
- [16., 18., 19.]],
-
- [[21., 23., 24.],
- [26., 28., 29.],
- [31., 33., 34.],
- [36., 38., 39.]],
-
- [[41., 43., 44.],
- [46., 48., 49.],
- [51., 53., 54.],
- [56., 58., 59.]]],
-
- [[[61., 63., 64.],
- [66., 68., 69.],
- [71., 73., 74.],
- [76., 78., 79.]],
-
- [[81., 83., 84.],
- [86., 88., 89.],
- [91., 93., 94.],
- [96., 98., 99.]],
-
- [[101., 103., 104.],
- [106., 108., 109.],
- [111., 113., 114.],
- [116., 118., 119.]]]])
- output = gather(x, indices)
- error = np.ones(shape=output.asnumpy().shape) * 1.0e-6
- diff = output.asnumpy() - expect
- assert np.all(diff < error)
- assert np.all(-diff < error)
-
-
- @pytest.mark.level0
- @pytest.mark.platform_x86_gpu_training
- @pytest.mark.env_onecard
- def test_gatherV2_dyn_b():
- """
- Tests for Dynamic shape with only second input dynamic
- """
- context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
- gather = GatherNetDynamic(-1, False, True)
- # test 1
- x = Tensor(np.array([[4., 5., 4., 1., 5.,],
- [4., 9., 5., 6., 4.,],
- [9., 8., 4., 3., 6.,],
- [0., 4., 2., 2., 8.,],
- [1., 8., 6., 2., 8.,],
- [8., 1., 9., 7., 3.,],
- [7., 9., 2., 5., 7.,],
- [9., 8., 6., 8., 5.,],
- [3., 7., 2., 7., 4.,],
- [4., 2., 8., 2., 9.,]]
- ).astype(np.float32))
- indices = Tensor(np.array([[4000, 1, 300000]]).astype(np.int32))
- expect = np.array([[[0., 5., 0.]],
- [[0., 9., 0.]],
- [[0., 8., 0.]],
- [[0., 4., 0.]],
- [[0., 8., 0.]],
- [[0., 1., 0.]],
- [[0., 9., 0.]],
- [[0., 8., 0.]],
- [[0., 7., 0.]],
- [[0., 2., 0.]]]).astype(np.float32)
- output = gather(x, indices)
- error = np.ones(shape=output.asnumpy().shape) * 1.0e-6
- diff = output.asnumpy() - expect
- assert np.all(diff < error)
- assert np.all(-diff < error)
- # test 2
- x = Tensor(np.arange(2 * 3 * 4 * 5, dtype=np.float32).reshape(2, 3, 4, 5))
- indices = Tensor(np.array([1, 3, 4], dtype='i4'))
- expect = np.array([[[[1., 3., 4.],
- [6., 8., 9.],
- [11., 13., 14.],
- [16., 18., 19.]],
- [[21., 23., 24.],
- [26., 28., 29.],
- [31., 33., 34.],
- [36., 38., 39.]],
- [[41., 43., 44.],
- [46., 48., 49.],
- [51., 53., 54.],
- [56., 58., 59.]]],
- [[[61., 63., 64.],
- [66., 68., 69.],
- [71., 73., 74.],
- [76., 78., 79.]],
- [[81., 83., 84.],
- [86., 88., 89.],
- [91., 93., 94.],
- [96., 98., 99.]],
- [[101., 103., 104.],
- [106., 108., 109.],
- [111., 113., 114.],
- [116., 118., 119.]]]])
- output = gather(x, indices)
- error = np.ones(shape=output.asnumpy().shape) * 1.0e-6
- diff = output.asnumpy() - expect
- assert np.all(diff < error)
- assert np.all(-diff < error)
|