From 3dbac4f47f3faca8d31700349f04a43e028cbc6d Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Mon, 24 Aug 2020 23:36:03 +0800 Subject: [PATCH] feat(mge): add atlas_subgraph module GitOrigin-RevId: 11530383c0a31f4648ed89d3070b2dab178ea5b2 --- .../megengine/functional/external.py | 11 ++++++ python_module/megengine/module/external.py | 29 ++++++++++++++- .../unit/module/AtlasRuntimeOprTest.basic.om | Bin 0 -> 32916 bytes .../test/unit/module/test_external.py | 33 ++++++++++++++++-- 4 files changed, 69 insertions(+), 4 deletions(-) create mode 100644 python_module/test/unit/module/AtlasRuntimeOprTest.basic.om diff --git a/python_module/megengine/functional/external.py b/python_module/megengine/functional/external.py index badede8f..6c93d217 100644 --- a/python_module/megengine/functional/external.py +++ b/python_module/megengine/functional/external.py @@ -34,6 +34,17 @@ def cambricon_subgraph( ) +@wrap_io_tensor +def atlas_subgraph(inputs: List[Tensor], data: bytes) -> List[Tensor]: + """Load a serialized Atlas subgraph (i.e. om model) and + execute the operations defined in the subgraph. + + :param inputs: List of input tensors of the subgraph. + :param data: The serialized subgraph. + """ + return mgb.opr.atlas_runtime(tuple(map(lambda x: x._symvar, inputs)), data) + + @wrap_io_tensor def extern_opr_subgraph( inputs, output_shapes: List[tuple], dump_name: str, dump_data: bytes, diff --git a/python_module/megengine/module/external.py b/python_module/megengine/module/external.py index a5da2a14..962754e8 100644 --- a/python_module/megengine/module/external.py +++ b/python_module/megengine/module/external.py @@ -8,7 +8,11 @@ # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. import numpy as np -from ..functional.external import cambricon_subgraph, extern_opr_subgraph +from ..functional.external import ( + atlas_subgraph, + cambricon_subgraph, + extern_opr_subgraph, +) from .module import Module @@ -41,6 +45,29 @@ class CambriconSubgraph(Module): return outputs +class AtlasSubgraph(Module): + r"""Load a serialized Atlas subgraph. + + See :func:`~.atlas_subgraph` for more details. + """ + + def __init__(self, data): + super(AtlasSubgraph, self).__init__() + self._data = data + + @property + def data(self): + return self._data.tobytes() + + @data.setter + def data(self, val): + self._data = np.frombuffer(val, dtype=np.uint8) + + def forward(self, inputs): + outputs = atlas_subgraph(inputs, self._data) + return outputs + + class ExternOprSubgraph(Module): r"""Load a serialized extern opr subgraph. """ diff --git a/python_module/test/unit/module/AtlasRuntimeOprTest.basic.om b/python_module/test/unit/module/AtlasRuntimeOprTest.basic.om new file mode 100644 index 0000000000000000000000000000000000000000..942fe2edff071b681433616ea0de43b30aa9cb29 GIT binary patch literal 32916 zcmeHw4SZC^x%Zse&7RF+6P6`l2*_Q&Lr9j*W&;?tJ%NBh`7(kcwr)0?1tR$}*#JSU zCjk{xN^|j}sHm$Jl#f6l3H8#eu1H(zYkRrcR&U#1&AqnwTKm3N>-E=*y={5_b3XR$ zCVLLytGA_`A1CLTXXbfk=07vfJTr6luyFCxGQtsdIpx0>@aNI{72&#u>TvA^=wYH1 zccG`S(6h1jlUf<4JpLM(VZy_k@Zc}jV@HxSaz6C$L^HFVM( zyH>=KC$1HEnH%u>UbYBawaZ!-j0E#Y#c(0L8dMYAozWQzuJ(Oj6Rk_;mEV?km5|ZURI{e0KCrr> zsV*3C*<5l&Ik!xzY_cp_wqgf;k{hyE-dMq*>T6yb1boAU^c*=q)F^L+FihqKP`ov$ zPZP--FF4ioI^=@OUbZ~2VA0aj<#{f_QB%J<+@uwcW?N9Ss8TeEv-*C^3rW>=Sy{F7 zEY&66ifptSLoQGs4p#?44Rwt*wF-0oD3gaSG z;Xp%UAQ)+=s|f`*G&HSiZVZOPfz`FaHFD3mT5z^hY4MH0reIyOD_!7oFc?@9It8Jo zuCcaN>q=aOC|s44@jB!hu4TQ-HN`b45)On|qg54$ihZ-HR(pM+;5AiVZ@?c2C@H-a zR|z8m=oj^Ys$e9vHqhKyTN9DH#$;h^fPq-uQr8$zKn>Qi;xp~dtWTnIF!4}BeWa$o zrJOD7iuYguz-IqJ-z>kjVl8Nx7EdO%g@HS73eQT*m`b8}5YePB&fuyJjm zrkZuSg2I5e0<9fcUl*vVZ3wNauBk(Iv?5cmR|67tFmkDf#Z^KkFi{2+s@EJ*s%u3T za4bY8PB8+rgv`L|a3HCDY!#ygX97I3LKz*DBGS}S9|}fT2O>imzf*;DR++lsnwpSn zjO(h3<$<#Kw+7}eTvA$bdthw+3NdU+#njZpd$^F{vIms!jD#jrV zgX@|ZO!U%s5Qjq;%qKyXz!pa zkuxu(Sve-2eA`Gu?Mf+hYm*Bl}-Uy`htqmdCN-|1y-fTWrW27AAlO7S9S}%Q{QVDf|8lbhNfBXE6W=>XdIMy^F|D4 zNZ5FTA%DIe&sZTHy`(8D`yH7sN9jTqnqLoBIb4w=WHBF1@y%<4t?*CTFv76rDz241 zi1kVKVw!6TbM07#ss}rFUfo!X6s(QG>So1lW(ytWMV{XZ0|Rz-i|l4KhZ>kCo_$8U zIJh}{bEK)J8h)Ydp>w!cR;H8Mcp-yTp)l?(XSjy3Mrx>SiNKvp>Aj5|koWUDm3aV6)4%x~3Kj5tqGT_3Gwu zB(T9Huj-f+ta3I&5gF3j1f7(XLtS2}i_I41RC5)R1OY39h+>jmqlA%a{#q^@*ir#& zvq6g+L9pkrRd&kNjchKg$5J?)rmkw?@e~rrxq-&|O%AK76yB^~@n#XJIJ4~(-K&GX z!KIu@Oy{u_3)CUZV21Yed{?*eNz2+nUtrZpVQY;%aC6I_qJrnX0=Kd$J8rVsJkdc5 zbI8Qu3zk+aE?pj|n19Rsid*NG<&7G+xnOnCC1vHDJjI-k%ntgj)~Yl!^ZfmT5t*56 z9=s@M$ommnE!wo{+mm8(;!j+h@Ujs-%l5>1Im+bE?kJz%ZUPom30|wVpt>+e37&Am zB2D*%&8M^!j=uk#9DScYcch{;nUUVra|Oi|hjHn|@@#zGZRZ7%|(oEYyd` zg$?iI4cp>)UrpSwHFaK7Eu>P7*yQa-f@Q@;+F2a9gJ`tVID`>B)M2UbEOJBLS-imA zn{(Z$ku8EPg~#Ym;S>55OEE$@wOPV8_Ls`1GK#rSjyXEhI_Snh1ZcQWEGn}!}YT2+QNXcc_J5w`IhRD6%MP?vw!;j z*Wdm4(~syX^sRVlWW`vUufmo?B-peDt9NzS_O7KEM=`K z6vp5XvPs@7z_A@WYlelxK_5FC%o1EKW%H%PTOzL?U^!RecoDle;b5JL6SwM>;X)eA zrPWT46 zTIs#ZOLC|vpWU|;`KfIlv0bTDS5qe?YEpML5^&s^RLQ{mi{%58i+btM{}IrCqCihI z=xN311n3h^8dUsMFi|3<9R?Ilf3HF*_8|O|(#tB`5_o7AT{J1p|9ZPzI{Y8DJkR|> z!QZOkZ>>U%**d82OKg`b7rspbl*ON^fpLFN*xRuQwLOU3sseZInh&SW8?&c z|0dfaVJxI+ZINa$7W@}J7G#_g@K@VKT>H0*f5J3i)9|+qo(76l=r!|G&dEbV>+;%j z;IzQZP!hyR&}&lc`M+*%5(|@9o=azjhB>YJoG{EeawmQ7m0e%Q+`LlSnck@QB(=e+ zaBbZ2!P(E+mNnskI?vbEXs@;~=VKTSQyZFE6`#hM4IOZUMr{f=HQ=nNu{Dqw4VR~S z)gzqhCDqhdhc{~P2D$DN?r|mVzjz0PQoL-T9?}hp@h6EE4+srj#K4Ui&bO6Dbo9`}ZP+fC{}tmpCl-=C5wkxXZHE%guYN#-Sn z3cG%BaH~oEp7f60TYex>BAI^XsSk>Hllo;AeRBpe!PooJzZ6Y5lB_c^^sfc8^J`yD z<9Nr6+w-3@u}{Oj zY5B)}#`1UerRTnDg70fZz9}PS8q4P#?Va>mzHvO~xyo~Tk8%8|1HUh}n%Ey>$ty9J z0|NK^-|u_fWd8fc{@bQJWzv5R_jGpupw}3Qug!RA`fp71U)Qs!V3tYzTfMIrP4*k> z&!2U(?^%=ab>AUJ;VhH!U*>tI@OhK*cf7#XyTfGs=NvA{<4o+&uBmB#zch*e^5Y-n ze{52}k^6JHcFZ=e|5ttci^g1S9KW{jC$sN1X}{kc8R7qxiTyIxd?n^`K;VVmRWCg7 zcg8i@vFoipznx$lKVibd?oUkSC(FS{zVShivHa=+dLU@h|LoH~@B6OF_}kT0?tj_D zevI|MG0WjH&c9^Z3rAk)Gmd|~`)fT{n~a|u{FC!n9yON#iH9q)n~eW!4(;0akjeVw zJvtyh?ad>lo`+2AdLDmp1~xEV6pSChN!h_Pm#0 zb;uZlhi4v|{WBAJGCO>5rHOyKHQzaYugU!D@2Z_O!=(Rm-Pg{ZX)^xi9TxpZP3(*3 z@Z$$=HW{B=_I4c{V=}&~4;|`y#RQ+zJ%67czz<@Lu8(I1`ld`aiX~)j!B=K@ijCtp z7kz!gLnh^KnXsau%fvo@rSKlF)g<0H=PNOn0|J#j-=5jWe!y*Twf7#||AfaN670FT z^ZdUy>7Uo_y{h02llU*b{L1t+lX%;{oc&8o)~7d2`*n}S#J_mB_dB^&CjI-z`>&qT zV^Y7|{ZBpn8x#L#kMDQIb4fZk{v45`=JnRom zHI~mjG%WWjll58p{^`5_WU~G+)_f)AazLOx^M!(#$@qHL`_;YKM~rKFe?e~1ev|p< zH_tbB?=k6rQQxmTcbUk4tNWhayG+U_dY+p3eG~tJ>-~@Eg(l|{%MQnURulgt>*z7> z8I$t8;QRK}VJ7jP<=)v-Zeo8vIFMbi-b8+Nw5L16#C|k%`+BoY^#Anmj|%@_BL7ca zd0j_L%Ge>cF(ZqaNe@|4{gqi91c?z2jH*P?PcP>|L{Oi-~=z>b}+UoQZuIUi7{InVervp8if>#$jXmD>0Wh z0>_X1Z0dC;`%h0F8dr4Ni^dpy@?!ShL*F!x|9DpS;euS__zgYBdj*sJ{ms6o4hkmY z-?Oj5zsO{K3EgL2T5M9jZx?OserJYp`JVmOz4@z5=Ev6E?R#A&@z409d)y}WVfDfL z4#Z6MFWx`$*327B%Ael*xjSgGe|N3_k#4Jr{aVrUKYRY4$@#1?GFM_Q2L%3Dzv0Uc zROAyYA-;a+Y(S+B-I5ORt>SZx6`#rjPiSkniYkyi`u`+l+bOZHf?L=ZP~a5@$?wp8A0$_tbYRo+?7zr@m*&KJ|Yr z_EWE0$f-9h!-!jya)(3C$xCf-6X`81{zFRJT$FfVQyg;obfgz} zsq`cjeQ(l^x!)zE#D=#F?R+NEC^9Gu3d|8JFS)IBE6QodEzPmfIjn4uCuQHIQnsDi zbH2)No@*!M8m7a>lX8JRO9eSkC(=E~-NZ|-1*nt2yZw}R?{e_ZP&;KmpO;X_M^)V{ zFOJOT?bZi)3H2sxzr)+VIJcd}pFteUUrR%{@9@NLL;Rw4C*n#i^Lfb@6(t+wM#P$= zTtbPR(r!AJo*8cCg%QWeb{85WyI-^egIAAHzUvqB_>W)RjOQCXS$5oNUlv1)OQnQV zmJ*-W=PmHgD)9N;K3|F7U*ezbp5dM2^}0i~_*k5KeYlBz%fda!19@TCz z7A?AUvBI4$x4rQnX2PxHw}W0gy;&AX=tQcDP4W|<4Y+Pa*oEq9*8#0ldLrSC5#qsj z=IZd<72>IHh=e^iELl;|90`Wj5l?L`KF7d*^U++hhSasHzoC@BGU5q^!o-7@JR`v> zJX`D7b8@w(kQjavPXB2onY=OilZ#r(&h3eYKwlhub*ojqe?XL0ZH>*s?6zqP^(m^) zI8JI9jh}R(lgsi0r_%qcX0#wFwUU;xv~oBT=)|P#9p_Y+Ca1-@L0A@@X~09m9lx~t zYYeS^1LYSYL8`3gsLYM30w%}mju9+h?8h?fVs>x@p5jQ&4(h)N@ekD)N=$=#iSqMu zGlnpbFoI+zee>0UT5f%~d1RVt}fsTUi2OR_54VwI?jgNjTApaY1-=xU; ztWy7{jCT{?S!I>nzr)1w87t#WxEJ$L@G3tS_}G8zyl(KKe?Z>K(A}WjV<9&-Q>j}Ma6wo^xD#{~v;;Z|x*v25bT{MUJ3z#qsg`f4Dhpg3 zkR7J;JDY%y6FM?=UKG5_kHN^?SL&983;3PwzywbTeJTpNA9M_KH)wY;|&C)G_&Lv9S;+UY$mEeou%ea7+K+Cv* zj)9hOSqj;u3Ai9EA}r$qS^_QO0y+j-#^rj*T&I^osRM8kA^Sz0FXK`M9dmV_jLQws zHDBk+xZDVx3w2%`m&MSzNaxGA+^m!<376&2S)r$saajwUYg6C?z6idI3ups~~&3&X;ky6FR=6^JH9t&=t^m zGA`B78Pa)iT-HG6YMn3RvQ8;i5-#=7S*xd$acPIn_7u2)FM==Q0$Kts;{rMcTE=BF zWH%?^g0P6Nj0_sftGOr9Rn@n@>R%m=w*oGvK_Kt)A=$kJD{Ue=gGJ{23?QnJQ4VN*oiF3^4@$Wbafy8>U;!wQ%70{OZ|2)>$t2!KZIaa5Nv#4D(!ugl4FlktT3@|$l(vZ!<=uohZNWImyLsu?&rspRpA*sH=IuSe51{%Ybt83-BAlap?Gf=uK$pPooQIWw6*b0*BUeK?Uw zV_u*_`mcb^zhNQ$OX?(3QM7)s%|4NE*GwiEWDPG(|DcJ3Y#Vf+!qZ76^O>I$(q^B4 z)xBH%=nVUH@i^Y|4MNE1s7OY~#6K~dz$077Q-U(L*`|@Ug}EsIL_QN9Uq|@|a=lnu zLOlF4V>!|a8EN5eN-}bJOUX2Hg5;uZlj*RM{~%WIq|cwHmb~e5S=;7JL}q!wC?B_% z&-@Y>*LQ;W__lis3F+uVznR59(A5I0O8B;$W}>|hp`QOq+lG0_=C0X<+gnW9zlOJf zkNpFE@Q1BL$}ZtIA1o#hywyzTtB8N4n6w=zLfiB}?o%2ge!gv-5A|q)Em|CDD=8+@ z-ae{rvJlcQ5`4v-J>Ay>Hat1{8Rf3`6SvBz(PpfZ`bAL(3JRKspH>{7#`8qQL~m1=l2!V-;&PTxh&!YwK!mW5MD`@MKeWK#ouaML=pkB{HH zt&p_kPLlgB>$hol)RFd0GYI!3miJ_y_1z2-<~fu_x_v*|?IW~XKHrv;hjy#u$sLO$ zBqN`Z9^*WlG|`=;X1jk+-kT(4U1}cn;NcE!`s!c zq=q}ya8wQNP{T1b{HPl4SHq91VG>mG->HV(YWPVtEUMvMYPeDjKZ7u9JJxRG8fDyW z-wD%&v9%;k8zWJS?+JutOymz*Il{(uHrfeGY>{M)=QqDQ9=~FC!$t&QYj#oL^e-^x zeh!>|0sG*V?E}Y;BW<(C1Lq)5PXD|~Yd2;GxOa1)8_$QF8w^~byQZ~9%Noi$Vi4MH$b23@Dwt`)8=$|L?4Rw3zPVRmbqPg-Nv58PbytNAZTO093IY2{$aG&(WqCRIL>$LN7+-QcC#T+`~$WZ0$EZ=k%<=9;AD zKa|b2H0a-o4CD`f40S{8WCD^RDyMl);YE27*%u<)mR0<;ZAWdhSP5-0>6IvHNX8o^;XekX23% zY?Slc$ia|y(vbROf{;vpJJKjJI4M{rpdOSzC=#ONlYx9Pc`5i773cntc7#q6a-*GY zN4z4VoKBE2`7mq_o{Y)C6AoWZ7Dh6{*o1R&i?K16Nypq`tt1--nnB;B((N5EhHoOR zi*NIxrMO%8QZF{cwrq)boj35(2uz;dhw#LyOUC!0CyU?aE(>vF1y3lPWYpa;<;D;t zw_(G>!Ec+gD0GTBd>K*E;=Ucvm`J35;k^@1e8QYSvc%3y=~0nzcR+R=zin!Hh~vJD z(*`(pFvyN+cZMhjLtHVDcjQ)un4VLZ2$>#>JIoWuRQ}{brVCC~I&{+-$mQ{lJd}y9 z1D(%r%R_l+9q*X2NS2q&knKU7k3TsBvV}doWRHrQ?}dYmvN+*%N%okSL+<9C`DN?L z#yQkYO0X$82ae`FJeN5Ko08w=iQ6APQIR~Sn}~Z3&uw=SX-jW};l#?raM~7~A9C!z zDYR|(VkQ5zKDqqf37>w#^4qp&VaV}fAmn%{6ffV&g)A;6I`+*EZQFNK$gy9QMZHhX zKpe`;+K~8X$z_$YAK;z4Hmv9Fw!-1J$|qB2hNIn>!*4@5UVU9X@v-8>r;K(?I6=0% zNy2H56{kG&aoY0$o;W#jp0L`Tf7uh$y(m7$g#0)0&%lZ4I?ak-s)>nNC-bkwU%Xif z6)nZBRVlIFqH6cr5}BX}8fF$|w`RDtIj>iQ|7IsWTt}GZE zOtx90`OY%LyR&h+foCS-9S**m(bj&Ev`~qR&f=YYc;-MRdo=H6^mx$YM)O|KZZ~N6 zXnCid^jW1!ghhm%prfEA&~DKEpktuDpxqN7GeMKV4mrXi!p=UM6a_7Tc7yH*9RuwJ z%_cnp!>ZMn?WSraHA+z zbW9H>QZ{*xS{7*=q3o1!NZV$G6eeLd0dpfAq|F5fpl?BhbB!i#BVqLV$|FuP8u?

{bSs7xyRdP-$(QLC#x1{zHk zW;bkJ(j}^)I9_TP-8bnnnlv10{tJ>)>ywsN4!ELn{sWScid2Ewl%uY(oF8&p~4t1$sh1#lh^h;&u z&i21!UL+G&iZPKeMC1{|{%lq4_=Ua-6~{@{JXHI3QpNXt?OUcc55;Sw5L&ZoY&VEh z_D!q5mhN2bn^VpIZ1(NScZ)TI*>lhvi)v(?Z_uiVI?jpKzfqa+t=22wEgp#Bp8@|X t-z^>hdK_~j{-0$mT={PC0Qhwqb?NZG^4;Qrpey0{khpZvpY6NF{~yJK!y*6x literal 0 HcmV?d00001 diff --git a/python_module/test/unit/module/test_external.py b/python_module/test/unit/module/test_external.py index 3a4e6d7f..44f5cf21 100644 --- a/python_module/test/unit/module/test_external.py +++ b/python_module/test/unit/module/test_external.py @@ -13,10 +13,10 @@ import numpy as np import megengine as mge from megengine import tensor from megengine.module import Module -from megengine.module.external import CambriconSubgraph +from megengine.module.external import AtlasSubgraph, CambriconSubgraph -class MyModule(Module): +class CambriconModule(Module): def __init__(self, data): super().__init__() self.cambricon = CambriconSubgraph(data, "subnet0", True) @@ -31,7 +31,7 @@ def test_cambricon_module(): model = os.path.join(os.path.dirname(__file__), model) with open(model, "rb") as f: data = f.read() - m = MyModule(data) + m = CambriconModule(data) inputs = [] inputs.append(tensor(dtype=np.float16, device="cambricon0")) inputs[0].set_value(np.random.normal(size=(1, 64, 32, 32)).astype(np.float16)) @@ -41,3 +41,30 @@ def test_cambricon_module(): return pred pred = inference(inputs) + + +class AtlasModule(Module): + def __init__(self, data): + super().__init__() + self.atlas = AtlasSubgraph(data) + + def forward(self, inputs): + out = self.atlas(inputs) + return out + + +def test_atlas_module(): + model = "AtlasRuntimeOprTest.basic.om" + model = os.path.join(os.path.dirname(__file__), model) + with open(model, "rb") as f: + data = f.read() + m = AtlasModule(data) + inputs = [] + inputs.append(tensor(dtype=np.float32, device="atlas0")) + inputs[0].set_value(np.random.normal(size=(4, 3, 16, 16)).astype(np.float32)) + + def inference(inps): + pred = m(inps) + return pred + + pred = inference(inputs)