From 48282b86c88c85cf7d7d8cfdd5fee9cbeb41aa25 Mon Sep 17 00:00:00 2001 From: Oceania2018 Date: Thu, 21 Feb 2019 23:47:00 -0600 Subject: [PATCH] fixed #175 --- data/linear_regression.zip | Bin 3423 -> 20774 bytes .../Gradients/math_grad.py.cs | 27 ++++--- src/TensorFlowNET.Core/Graphs/Graph.cs | 4 +- .../Operations/Operation.Output.cs | 5 +- .../Operations/Operation.cs | 9 +-- .../TensorFlowNET.Core.csproj | 6 -- src/TensorFlowNET.Core/Tensors/Tensor.cs | 13 ++- src/TensorFlowNET.Core/ops.py.cs | 2 +- .../LinearRegression.cs | 76 ++++++++---------- .../TensorFlowNET.Examples.csproj | 1 - test/TensorFlowNET.UnitTest/TrainSaverTest.cs | 17 ++++ 11 files changed, 83 insertions(+), 77 deletions(-) diff --git a/data/linear_regression.zip b/data/linear_regression.zip index b88caeb563075c8d2c5dbbeda48b1f6c99023aa5..50415d840a195126de18a099fc617c8245b0ea11 100644 GIT binary patch literal 20774 zcmaf(V~{0V)TWDFwr$(CZQHhO+qP|6T{gOG+g+}ydq2!f%*4!}b0RV#&)KneW;|;> z>y?)R20;b@000NTsa26ZiaKqdg$4i^7Y6{q{JCmvVQXUGsOM;6=4j&NWMOAZYh&VU zKx=4WJFP7fjV*?CC^Kh2=-zyk?VMF}qMBP;Kv4EA{p*|;8ynDV_f8;5Uy5EcNA zU96m?%J#=kdJQ@rnMlcCMr-{P9|pu{71XWj3KcsT^=g)}qNFCP*4G=Jjyais$IiEe z_V_r}m#Mgh2|=aJ%ns<3>pEgP+3_#0|F%k4-{@pvEvLK>DaVZ2qUGAIW{Q!4o;*}; zs_PQ{(@_!!O^P$VROkjDcwzt8$g}j$v*jR!rrr1#dp|R=N#v>ouayGmeqbAtZ%!{$ zEW=lOozksR-Xd+n#U1-N4;L02t&V?rbTj+Li(CyTsJYS#I=f&?8|XWn4%AiuEMq9Z zSO3{nd+};(A=O|2`*AvV&pKJ_G_w^P#gOoc=T&r5kB9z}r$}yq8i+%e73|uh8O04| zGkBt@)Wia_V};v3lb;yVtt)rEJq!z)CIQvy3K+hR(B?Wx-fIOD_R4H(1JucWSxWFl z8h`g(K6f~wiVL8{wMDy(a2GG4s2IhMq171O!_C^9XN5`x&mH zS~~EKX9m@-0y!VsCX~t^zX~LFi0I`THMREgO(xg@uu0Qt%F)|j2}54MqbegS!mx!& zI{w$p1kp}homTL<+SfO5Rm3T)p%bxNW@g>`$6N<1^CuXtoB{7a*0}y|zV#0_?XzT7 zcyG)JI_MkmhBE$q7QA5Chd)mPxdjL`u~i$kO^!pCJ0*HF5Bf`e*1jnqYFNw&-Xh8ByqAKTXhA{S!uPh&3M0!`hxHy-59t&X$YU&Fkx z=f4MhQu{x0E4xm4F(He)+n9B8WjyQ%KJEvxqhGOCuZ#L^7UiD*GGQ4bn2{1I=(R&H zqOD0qkf6BT5T$$&RPuZbA9)qk3VRz^b(~HR(eH(r9c9Dl z%tQJVMrAT-;#*A`cHfyMR%;7u&fub*WlutD;WT|PhmQt7>jJ@qaVKP#DEL58-e8v zFp4mn91m9rOwFwaY87z3(9SY;1V%WeUf+sUDqmn>J^DXM!%D=7bT~`RNrj6&DCCN<18qc0oZ-Er_mr<$#^RS+H zHz-OF9eV*)bM+Qdt!Gb01%R67{Ga!NHPzEDbJ>t@j>I!oyNCOSuY(a6gWY7Z5AUMb zqme#L2XC5Yih^neB|_)ISA2VTmID3z$^$TcQTcSaL-loWu~X$D_|pQlL>v`=@~_v8 zitM&$W+K*2P}uMnn(099o9j(Ra1X_0X!!>-sJsS1R@XM~ut4$FEx~JCJ0u5%CXVi?~g}HOC z7dYzDUONy%?ZmfmWZzchC+5`6gSP^}np6_enMycLAt_u-WVygGSQ#?yX&1a}`h*o* zs!L6G2hk#nEzP_pGCzot6X1Q(kY!Iw$}2=7*WxsztX<_s1E(UqIye?%9Z4~U2|4bu zi0SMPfDN6@*W=JjzM4m8%H-GgEsdwdGh|7aPRWNN4Icv)Ey6y20iy&QYj`d#gvt5l~{B;GNd1c;I?IT1~`| z5eOA2X$#52S4XQTxNzM*$4M;=cjtl-I1Uj@N0S|*xH2Au>UMSq8(9(QKjZ0|?B{YL8NJ4jFC6Mq-y%OPeRKZsW1x01h7ZdRxK6rz#2bF1 zXDygQ7iZr{3?$-aIifZl@9WSF$2n|EdfhTHsQEitUCcVRCXjBogqDDYdiu-7qNw;d zR`9o9nw9R_fc{?=P2*SN#B(BtSD3$nLox>0SajJI$^##i#)Dztvf!}2qHD3uw3r)r zvt!}Vc;4A_egx0HmWXW_ImJs-$6->eQt@?Ty)?9Z+o}a^2IS6$<^rn0$SXk>BN10* z3lJ!Of)j0-j@4FzIS;ldvgXwG0ry*`3{x(EK~6yLp=#_Uw)~m$<=PsbBIo`NZ>(7j zS3;F;xx-W5)<}!}_quffOB>WdI-(5Wd%M-qW_ixKPWAP#DCjN+>UkFF`9C) zZnRA6^3Cfp7Kaahev8XDkYda`%?cJNU%WxtA!pb)J4O|Bt7?I9d!_)=S4_KhPrh=m zuucFH!0c5<D&J_F7b_frFn{YX~<11I7;iX!o|{2H_;b-1QRD=upqBB zgt3n{_FW?V^Cc^6g$AT2;zv@ne6}Aky0W7#k0katy#wS&lJLx(+{$>fijuVcCgr?If(alb30$=^qkj zptHc4Kc?gR&Cq_raaY_I>ZCg&j~!G)+q z3JALbNmE&75`#fxa)OBLqTpr?2uX2s#ggFUg(kw1B;+}3N%3r$IT4yn zYe~ZJI21gE;(I9Y054CA(BXcoEFK2>*Ca$>ICwMuyjb+>5s!cf)3`ij489ONxz|$q zFSQ~Yu%z*T!4&d3h#X=r7EOrCX>-MrE)*qZB9dluE)qQum6I)^ zqAnJEC@toUl}N*8XtKHxv0&e(#*2*3Uiz@lovqiV+|dwz91k2VIc0tX=42_+V%ts8 zdq#DIer|*xWaO6#tJbUS>go45S)Qa!4;?FhgsX|_mwZ%}wsqgw6`d^|$jk*T3t@tN zFjNv!AB_z@P32IxB5NcV9t+gVvdlXo^~dX-9Qm_vu6AH>pPcJh;$Ap_5*m8^W>N=v z&A+7@R`Q#cc}YFzI-8!@17VdZYzYaz+UacT!PLf=#TMk=7p$_%@{a>9V8zxkt4pz0 zJ^%yXzA;X}f#m~epOMn}J}w0`v3ny8XX~``+nv8R*)YAdUrpVhor>yP z57UJ&<=g5SA5~Rkr;{Q+O}&2Q#t%1qON!OsjW0*Oeal{IxVls{lV+w>@^7L@iw(D< zmA81RcU6rq2i`4I(t`@_N=@F+f3U&HM5!=7ja7QmuNDf9dC z>Tn~!GjA%jJZl%eR@lk94s6`$1gZw1`%)~$sriiBloWmS{50))t9NGpZ_cyL{z@*z`YX*z+Q|W zE5F!zDqE!G{z{CPQ*;>-m^fmfMvkZJHKZ zD=uiUB3tSO%82Nc?L0-ugjCNpQPOTJZ%VJJH(OEnB2RYpJ;$Y99QktSZ7szaY(zJf zII>Uv##sTfQEHSZGHA^(q}|MkShqyiiNJmyL|kd;9$lWXpo=eo&!Ll~aYR9hbx}^I zqTr4kH=!*xsrftN+O}Ce>F^Fnn(>-mwow`C($L!McKGVi`=(Bt1oAYxY}(70c*XVI zyY$3DCbGT+OiX?t;a&9_T?hvT#@YxZx|epS=!iJZlK)wHRIlAc$VpVlQS z+{VWK$*bp{=fvNKE5~ckg`KpE&m<}~9o}%+)@qygUr!;DOsaOGVFZXv^93IV-2@P( z#-hQ}(Y2QPhRALHQl#jJ0U~9Z-gSn%;E?=tr3(t+fn4>t!mzgklL{Y>#xQR-4fKMS zWR~xa(!dU%V|Nn-qn39$bbn6whZLbx_C-Qu;rs0hI&xhcG8&SCa3q6C<1n&^{gcG# z+`5joG<_VJ$0`X8B1I&_$W_@2MJd7ef)$w+*44!#h2e(8V~x`vU@@4l z=pxd99B^GnrY)V`EjoIwWMWm)v*8ao){f0!8B&WpI?`{VvRq>x!HGSbx-xHEYb59) zj*BDbLd9<3O)L5%g)MCKyu(uEuJo;4bo7cwW^Yw<`A9HFwHRdjsQpBf`3NQ|d~4Dz zEs&-NK~#GOLPs2+^8@4xKiGfG=~Kbd!7iU0_-_vxbh;0vW|S@~mxutYP-##rlvU5L zGXlpdyR-YJYqUPi2E3|^;pdbO-^n19VfNInw5HQ7nQ?Mx+M$kk_HFt_Ikx1pNTkG~ zE@%R)N!^Gp#injW+_q~rMt_#C=vPg~8Ln#TL2F^Oipo?euU4*_hSvD39^w|Ln6M+I zj8Q@ds^FzQ=xi7B6^o1~vmMPYr|bsv$-XRBYzOz@<))8_Q)S<(-8l^Ep*4MaZ;`%+ z!R=6x8vDPWwni9mFRaowa|>X zI)Nt=hrd8%pR#%TCxT0`&yWlXNYoPZwa`dRXYYr(5*8wTNCzuL1tz^i8@@-Ok^MD6 zAR_zgLLjkCQd6hvDp41~8!BC|qg9ij_T4RC9d2Hw4m@+0>}rdRU*y6{Daj(-VMPRnbCiS;fI5=}={{s%qfP zF=kDLU*H5#YbKmdRbG*1|!D1NS_m2I8w!E(N3nPGm$PDz|NfCV|(O{bB+WQcd z>@mZ@y~xU+wTPmSm(EQhA%F&wO(||Nv2}YY93W{VL-Vc@*;z42-jA)lga744#?`gg z(y{kyOeIp(Y0B7D#Y)eB9@*^P)(UhP@R=#Gc#S4=A*kUMUMs))kNtUbT9Jh zZXj2EeV~Zq@3t6G=BBV&BA2{`YWr^1VN=34aEXkgK={l6o4c>~E0B{ENIGK(CTJS)~h&nZy-)PMDL*8wBdq zZ;LO_B5U!5d)rU~9`J2Qs5M9?g{g{_5WA2yx35ErzLT*ogJqNIsn4s>oVqB2?^^=n z5nX6|wo7AmbW3Ld7E<&%cdaS4I;l+Sc{mpNd;_*6084!SAxk}Z1QMcaVr2lLsHAy_ zQK5Cs>cmYCSLMrnVYA${1NMNQZ7`M7@R?o~(^~Ah@@}BcDc&NQjF58JnLZ{K({&ShDf+fSTi-5V&!PtsJO&UL83hcaO^s);E}%fQ z@CO7MGA{wlYFam6sN792Yw;mAeEP(G}PymLn zLf_f37j^`Rbs-jFjNj9VYY&-DuU;9!Ti#pa_g%UzNSF@6QY& zFo28Ite5kJcw0fGfQ2r_SA~d(WIHotHH9TjE*b^h)FrM|=`{_Ths_~Qb=il^$J4Vt zta*w`N)fT~G$lX5X%9%e><RPj2_g#B8zMYHvHn*=x(LDz_j%0?5DD6MqJ+b`ku8v+7?qAod|ZB#^!9d7}mBa2VGdyOh8+ z?AiGjCpj!*aly|bWAafj0X$=l=s+Q4vV&nN6fzTLptUo9d$Oum&&_Waka;^Tj+QD5 z3xJp7>R=6?svu(P?@4a85zwoGKys;E^jXAr1zJDZ4a5j^seyJsdv=c`_6ABOAUGX> zYW!e#BJt3)R;-+QL`Nkg+k(Xhr|S#Eh<92$O2SCDSl5??1^}jG^BxGSNU=P&SD{#* zyBq!pdv-UZsrcW=r8baasq`>Ulv8pmI35@ARa$iukn>|(yy0$PJDoFQw&Mo*9by@; zwZ4q8jjit-5^dvK^-jAPk>bl&&m|p7|Eh!T+LsTi#o0cVf-m?elG+{lMX-ra&h5Y` zyYs|B6+g1!-N_)W^mXT^JY9nb)*cCOZsrsYn z90Rkqr9`_$+~{OaIxKRaAe}zXaEE^bJlm6%slQU^7NT1+W?lFj5?{ISZwK-=NwC8z zd!_R`tUBs99|0WmkAjNAqT`1II721I)8mw%)&oMukWYzXPp#Bf&_WR?F2J*Ug8>2U z6=09#^AFgQ!`o9s5<|uTCHOz_JMAtA>433<{=&=y%J^N*|GcZbOVD8H@@cb-Z@>rN zh1T`yk}L(^rJMh0k#4K672{C0SF@^#6N(55HGK+tYdm|w?5!Nt_-sRf1gm`G#|eetCFKi$jIp61}ff+ zxXcH0d+Ds_KXf#>B%^N&JtwbL+F1yOn9u>P5(I`Q_f%Rm>=>F~Ik<(9*#gw~xu@}< zDvTAray>j?3sO)Og>Z$zPo*^!ZO-+xR(b&84P=If0xPUV)odP28Df%0{@B&PF(Pf0hl~Bu>LGWGKc-ZTeQDm$o<7QdG25QVW{X)rPB*qbRcB zxg{;O+d?Gp^t4KS(?`MMKm5}2iIsv;{fZ>#qh&$2IPO;{TmID(HdJ}}Ip|TktR>|g zE6%8-X!!AdT%n)EWVp{{NDnh&#-4sgK*)lBDJg>B%nA*!1rFI|Cnmqfz+VVFq@ck_ zz|vUE9UM9_rLlk8eMXJBYai2e-bj$_1P8=^YXDYUZvtwu*n|bh-Rz$O3_oVTg=9{> z)WHUbeMk>NiPJia>7>jbX5g~MI!Qz$aH9c`u-XQIi6IO|i6Rb0h$9WwNRLlqR3Z5c%u^cUl>rp|D*W3Xm`Y^YX% zWSDk^n!q6Ic|xL_gij{>mBQS=_YUh`sX7qO#Bw`0JR{t;!Q4${p+*M zd1-kMR({eu+j0ssKTd{(JddUS{XJ1dOln@aLfAEcO4_S-k2{+A@htrZ)a8nB7VfX@ zf+EPx0G%M7Xy8vn11et0)gko!%FDq3PxKCXax4!C8{@+SPK(ZIu-(-meYheLxS%w= z6Zk9N*9c6v=D30{i<>Zg3^RVArkQh1(i@A_n#x2s0 zD!uz|c<%{!_afa*lpGziPj#a0qxYVsC-W%D#vs&8?qIp;&`p_o3{FJvN|W_YSp`0o zFv=8}J!OpP_Ko0}g}^ZiMB9tRPm8VV_cR&?x(L!jRku*e{#vDeE?L1M)o~=jGGV2X zFcrxI{>0CcO2F+il-K)+6u78E@6N|f^`z7u9`XeCKTT$O6U&{StyisCWUa=Dgi?}x zCsc{4T0E^RQ{VG$kjJ|@EbgYBG&ykKozSkL(n0Q-md1@-$tcH3P7|fL&P&d0XWF?f z$;KZ~zHFplf6l-X-S|CXo;B~DAw5Xs9j}EZ!z6u?&J4{e-QyN$#))Pt^gfx?KDBcUWZevx`&Tk)-@3JG?C1mQck8MR zZfo1f?8DO_N%%kRs@GNg{_m0J_u+(08eCx#&FW&W(2HOR29$3PfGgkPT1je11fjGR z^;(H&mO)yljfN;MQrFVSwIWq{rsHJj(L9CtJ7Eu}4|z-PUixk)P2{p~Ew8ESpR2jf z>bZmqQ)jAG7o$HiWv!>38pm)?*)8pA?sqo@3?E?4Lena`wJVh< zJeC0wj~IHjXnP5ZntY&#=@n5)BjXK>FXqAwipNMN$IJ}TJ+A2$DO)uZalnp6h&z?s z152jpG+QWjT@?&Vw7;>t4?GgymM)$%@34+yzuL8HaF^vQ-;1YN5mSDX_+7k*tJej# zA$VXZl-;uu+7Ycdqx3$o8Y)3JevG!We+}K`ng6pdqkc0yUq$-o538cd{O^Q1oIb=Y z*M(wTHs=rH2k2CR*1X$#LOd)%@$(8o+RRn#?%%gBi4}fo@(u_jRh2j7-*Itkw=@y7 z!r)OLc~1+<(a_ZwQ2{K#R{Cc8HOX=4>QXVx8|pcC7^sAk&mn_I)ZVVFj+|J)_xN;* zIi6b&Yp~v*HA^e^uC034cHfF2_Kx%}uxkd{WfCofIONjDPAbe3evrC?q~s)5uq zrrt3CW}NBfbXe^(i4i>Oy4nqTMa5?OBg$tce+M;_w~@jgR}*UJT#eTfKwVvpQ|=MC z8WjVku{C6ekTmSuepcYcd*4hqK6KvScq8(Bd1{8>^XW9kfHUo6n$uSGP`*1u_fR)E z3zvuIqXZgHv8>ffv!*rsD9VYFeL5Yn&6$GK!!h4gTBVCl=E;_Ve^P{w7Iw!TjYjix zX#4cr9btFK!aleW+ho~33D?=Y@tSfy;tp#|V%!|oEHbW%hm8fWSu&FsjRde*q7d6w zWNjBYhBP>0;}x($*jIHC{!-dB#^+)uH{z z0>t3`hwSH#utMy{u+qP9KQdZ@#|ibP@`GLbsr-!C-w7R0E>tt>n;YcjlC+*?FdxJO z-wC<(W|hC@8JVFnoh(F$)wRP+h3$n5oKa9zRouM|Pp*L5IaybL2Ta8_(o3(JvQ>7F zC|M?voH0TBw`tTeNjMAW@W+l1^3xPF_GyFSA)L=@A~78_MtIwg1RdB9D3ob2FqB<_ z!JUt~-#_Zt>7=Z-c}@gFVirf`f-Y}QpQr%f=&Fjhk@yL6-UX2{ZWWjjv=FE<+7Uc^ImN)3ofH~S0+ z%OIe8xyYEVf+(N1SdyWC+I7t#JG#W|yF&$!nW>dK^L7~y#w#@_1^IexH0u!oh-;=0 z92BW-OnW}B8o+Mi3RV~Hs7&d#V$DZ^4AqT$Q4Ba48oW?z+D zB703F(BWf9W~d%fYh>U2G=sZ$Khh&W$z@+_^01`+`0)Yi!Wm}M_28Ycar{q0+3xMM zQ*mbJo%FV$E(lz3Mf&^KdqOi!nlY{D+bpe54u(Tr3- zsGbK09U=3~dYyv^lra^hz}2V3@(}L3F#*6IPl>jDiy(P&fEu|WnWrIUj9F#jNEB68F$M>V*&1=@J>`xdbS$urer0F(aCDds#qie^rnmR|~ zel^@lHY=P+Wknw`ih7Q}UbpMZd8{>tk6#(5mN`AlrPNs2t@6Q#8}3GWEa$vE?s&c)z-&mbH@+p2=Dn=2 zy-x%Cx8PsY?H`E<^(|e3kra2AAm=%wG?5xr*8?8EGiwmE7CWHq9!N4pXR3@=3nPdj z&{^&WixE4c*w;=!*m1OEhKz2rm+w<$GrqYk@a|kz_^X@(>`Ey;O+WCNM2(TygdQaf z)Oq@r;|+o^+0PlY%zXe(?-nK3{#BJNnbhEyT;%0Np3pBjZ)e9%Z^v`x`PqPTAbCEU zsMp@r*89+EYpuq%uT<)Vid{C7R~d~TQ?2yO&_^+8H=&x=Q2LP7NPs8}dK{@xkUv^4 zKRLb>4c>;Xb2JGT1^JFm!Z$!KC6kmxT!4z)wA$QV4mv=MUNWG?iHIvPV|+Q}klVU4 zSd2f?zD;6rEI`5Cz5Tag8*Va`%yLP2SfjjoZPOOrzq7 zZZ=*gN@-AG*A})<4Qja7riB-c^#djonP7s?+h0s$4&&Cj*A{LV9V^N-)7ImGO!*>W zai`+WHC_(2)*cR^959Z)-#Tf&|(Q;=vIZjpJY5QI)U99 zq0JN%sZW!i-5!BX^fheuQG0yZjZ?Z}m!poir*-jx7JeT5(*NPdek0?gMH$^Mf=&!Nw zsiEo~SJ;0&ghU?i{35gWI~LPp0_KLb7hERj5gV*$;kbgGtH1!k0iaAWQh}!GO9o`FXnm1JR%Cq^A zgW0>UXdG-l7u^WKV2Cd@X=Lt!Y+ewC*AKBK;_)TPTmzK+`xMcxaSxg$`=8&dP}wGN zI52%VT2Y5*f>(ze3p7*Py>zr>hD^onD;4Bll+F_&!nfSO6BR13T$C#c@$PeTz0|oU z8B+SC>d?;r2gVPVMj-e;t~MWq?u#L+ZxD%gZzwXE^q>hyNi)U4up;SIO6jkLEJab3 zw^$2aRE0l<#<`1#lq3d#`3&>BF1Vrl$^ z@TBk%`LN7j3*1XC`7-+X2xRdVOAy9z_aL~4KB|9qRACUufE&CTvUYLJ&FEZv_}>@l zf43t65L<6m(R!hF-~b_Y@uQ&m0{+4t0*Dx9;3sK-mbZWmvG5Py&pa0tWe>3P5AA~C zC)5VZFEB6boVq@6R@yRWx>{X4v2HeMD-yG-idxAR3Q%}c`eweZU&I20O%qdO3-BP} z7vf1ekHi!jX3H+nyd~e~*!GTA--c0EY3B;4R+VnSo3q>w9fQ%jf(2Kd?5m&+tu^y6 zu>NFpzT&_16qub`hV2$0VYMcun6Ag@c32+(E|RBF=R*V4t2vL>4PXIjVa5_gKNNhmx zd&buhMTmM#wvg?&LdzAJeUJzVp9sDE`q#_iS(XwGH%RB+_L!lkCc$Nf+Bci;>F~@Y?MU!-)To0$<1_^M>PYw$A~E9s82!#T0b@^$-w*?aDt=adBg3y;kHOs?bq6 zlp^LRxs9ej)Z3*weo8vm3cDB}5=Q;hg5J_M*T3vvUDnhtD!5*N-g3J)vJE2lh* zrSpAfQG8Ynmx2-O7EJ~1!#JlK!CZIxR}T1PtFV?W$Tr46>|5RPNbPWdZ^H{R9?_*F zw7iPz9)J!6cLX41qr%N-Ig!ggi&s6ohcg$Y8<0(8PXuarg!YhAvJ{+ylyfagJX0*k zdY1fma~oSdH+6nNhq?<5C0@olCsm0oyc8e9^+!iNL+ZVn_7c{ywvBFaB?IRwj&>wZ8ff? z1D}?E4cTrlW?1VLXF$PKz)j`R{AKy`61zxUlS~VDQPLlrkWK5Y2u!E!!m@-JD9KfjG1 zyHiOUkC@dD34eKIrJ_$e!?lb|NTz=Pa?}X4PB{D$!TYpMNeU+aHXot1N+PylF?**T zmcMeUlwA`Hv*el(>|3`zj@o#&{a#FXcJyk?(tpN0NoIpL>+`9Ru)02>!3Z3SVD& zfAy|wpRer+y2`#AM3>%5p@wC1xg5f-`=BZwL?ZKQmBT33`?dy*AxEmgz!X4pT=6m5 zDx2cgMmw(B>7gebyz~vCS=IhRg2g$Kgzhxxg(Xu>0prbS)J}$ zJ_&PT_fKv~9N7Ag5?560FaGRLi5vg-r^IDy!3+1MAZYoS(Erf|8RHH8v+nt)mid-x zP(N$%0Eu&GnwLP9NsUNvchQ%GPI{3jTT9kQRAA)|1{%6m)6Ps-P_3)?z$IB59@V2H z*mrFh%7nsLnqJ^NkqV=Pu{SaZtkbj{je`@OfQVZ6c0>K&Qr9urn|U|Fy^URZgD6{0 zHb+zAl@`Fo{h~Rf951D9S#c9?b`4Yf@Fq`#)s^EJ4HAa^`DZeQ{XR*_|0*^)w$$a> z-~q#2_cLQW-459yaoFB|yeRzQFn)>b6`F0L5-C!c@6(IodICJ6s5I%Y=7#M&cS&hd zY`K^2SFuvgmK9==^r>I&ANX0LC8N0I1PU{r3#KsMBI$7M zN=nNnU3e0IB|XvKe^zKY6)8Tqqk&; zI}9Ws^&kFg8IAq$A24;c`iK9eH5>;hQs%SnKm5lgW%y72NBr=gLm(X|qKf?IjT<+<*LUgJif zF;~c#2W;#upX@gNwyE>0)slVo`AXn&)gb@mKe0dBj`m;JA1O%{KtT)qf&K`Kg{sevp9gTNwSKp{tv|bf?f$1kr;x}GUh=sgxR}eqv3nof97Wp#x1>mTX5zr zeY^=st(pZ!k@K*4YG}ZJnE2&eeKFcy@m+iJ<#r>38w((@Sdx-Wm!P-Fx8M!z%TcTG zLTeKlO;6XIvG2I;cqQ#`HNTgLS|w$JsntY-XKY;TUGi593x*EHsE}51`?C_8^{!L? zQY1bDYSWpM<71xDT;%${^CS?3x7rB`6YB=s!KrT`H->U1dUt_n@@0nxbYK)%ptH4A zz1seX0=Jhq5Ch~Bg{M2$%*?9J8}A~ROyRI!=z$acAOzMg zzKWK=vN&Vd(7-q@{Ubw?>y?ocG2U7_Y^)9w;_`_;_=yWY_=$vO_^jmo+?)6@phgG+ zts#_vaAH(c(_W_TKSl1{DUFa%$DS)KKjVtJbch9*!4t;-Hi6$U^t@yBA8j0zXB9qk z_{(0qD62iLIu~E?FPG1cU&IR9wq6Uza5%RWcTSOAY=+c7*bc*I|AS^KTaZjlVTM`x zo0iVi1ok;P)SP{zV4R45NtQ_@P)sSn!g~6N?p!R}tQ+@C5NErVqOVp6qel4W<*-GJ zfKAJ*P!#7%y4-KQ3rignb#6|%x)U`eRz+vqo6bb3b}j13yV5u->u1!4s!)k9%V0^f z5zGFk{N0=DvABpD!)>LDBio7W0)#l)VMfOYVI8m{gTo&VE%Ji>B-ZKLKPUdE<8X5% z{PL_tX9B}hA?kvk@$m(nte6{6*D+T2qXv9tUg$UF$thsN-7_ZYmxYm2jMF1AhjU+O z+Y0QHyOoXq&&V;XwdFU5r+b=DXxk4~J!^3B<4iq_V%S#JSJfw(I6bx$#->+t${9vE zJTcd9gzaN-xe$j0%U;k8#0>wv|5*Iz+O|pFFPsXbnm-u0>wGFP3p8Dbmy*qUl82ai z-OuEQrdD!G6DJn&>v6%70I5d*}d*-6?%aFOl`o!k`KSwT6y#0p$FZ*B# z%1PDOF6psSxp^m5aV0Ip<$SfQ4?HryJsvOhl;&?)Ek*a%eLA%_*pD02CkyYd=hk1Z zN|_&)O&T26TN@u6hHW0J+o*iAf1v&-_2${7IsHlc>ear<)9yIC7jSN@c!dR6gi0~} z>Cr^4e1GChhsu}8Ar8O5;}nAs%s}ItDTRKE{V(C9o`1thnm^$rnBf94Ju5`1pKy`^ zmE2FPSMw*F#1-nSKQfT|(W6)Mcjv z&N!QCfZ&XS7jPcNerK{`0}Yg4(RaBO%tocpLGg4ZX?I(WcP(gujZu3o zt=lbBFdxzOWQu!t5IK9MdGKW!RS9&L7Jcv-rDq5;G-bR+_*S1^P`|Wl+>8hds&sbl z80Bffqq^sW`pyqR8j%@E)AN0Pq@^N;-pBy34%1Q;GfsFs0xI2$4fPA8T)Q`iI4|a{ zaQ7Nc=@r5ZS($828lg5orS?epa7U;DAE8BUO*>vDJzMhF4qwydh2te1GN%1yUBn~% z^MAx_5-QxD3?DGfby0Bh=6ay|r1rae&i;t}sGD1qTd>*M{1@3n{SV9iN$uaI~Nnc35;e!Jop!JPPgG$YFE8?lxHJrJk^4{H3~o>9y3_ zbV$+YgWP{!rADcUU7$>%fd8V{8(R77hKIMW5V3-U5-1v!%$2&jhrhdB7Z4M4V;4QDT6KoMvY zU?9TV0td~M7*Ah6AO}3WJnRv9B?zPobJ2mi=t)9!Z4#4@7=ve=!6Wv-DTnWT7DC{DaU&96HdR{2Q?;HSpg9}}kwI*^}27t=;TPXUuvhfbNM5jDL zu{`_}ZXVcCmY*PvJzGNj zbX7p&R`QbPzwqSpqf22XaVzdL%4IXY`&*}LsRBv>`1 zl(3m*WD2`C42lBL@a<*9&HNSFDJ;0XS2QYbuZG(%lI;uq`(@Jk|Jmj{o!gdojZW58 zm};hw+)4i%pn*{OCqSdrsMYlX*91NQZdZTio4uj_O7#<@Z(~@P0#uM)#QMq3FrM;| zvo1PTNYE1=NlGBz-O#@?T!u3WXOieF^C_f^-Jf%Qy{3QC5GlZ3HzsWRG#;u z2pmNNM8sDAI*EO0ZyHt;B9b}owr8(%wuP{7|E=P(_A-po3i^T-uaA%kJ4_LFE12)zi4+D{nCHQdh=KcEx%@+zpHK|Fs2$}`D5mP4P|RHi zgZwm_7YHfm7r61)KA|WXQOa`2mSVRq1GjG_dWpu@3YDc&h!GiZ`A3bKp*82dOTOjO z>7mO%M0P9ujoqoGsB@g;axIwe!SKn2uxb<0x(VLR6o~VrR~4lb@sn%0qu>=?P5C$1 z0@#W8@#@+<$s6+%3(;)?gGvx^8D_zYpAIKVG;TG$7gl`i*(FFV-)tqNk6K@*D zLs5Dl^j@X600HR&B1JlcB1lt;0s;XMgouP*gD6b|g4EC#l@^Lh5EPUS5g~L0=`CVJ zaD(oyyXWlx!+zYEGv_^X&dmFD?%dzKkJ@5xNe2zVA|isKA4|@imtu=6-fO7pf3j1y z(We~IP-t<+H<&|tZCdy;Teq%3ZP^sakS7Ecn81H0Q*?PCTqEJpmPbKUs)9MjS4K}m zSXqf`XI@X2PaF>$;rkm)$N%pSc;sD&$_PmZl^N5}{x0j?Q?R$^=@&C1vy5D`96Kp< zMpjHzhS@Sh$6=1sh)l^~Y!V7V;3jb z=1$ang!5>bbLj8#^I2Z@j+hgW>ubwv%~!hAm{=WSumVo--8=d|8ouv&{1x1hoeK4(9_60FciJCrA z!iUMSUYZQ$Li+08OiSi||F}$h^)o(HIesJO+TBCk$uOyPGr-YVux-bw$o2^lFayi* zaxiG=FWH&Xvo9Y@jP<#lCI99r7k589m(=}KqbDEA6Kzlpb_E}~eQ)5BZZ$N?XkbLy zy{2pzi|H<03-O+cgf$f?L8=y7;*95aWq-^_zR_c34JVPv2`%NC9&R?_$xeH3gxiF( zhRC@Q%z8+c2lSL5o(C#bvRI{Aqj2VD$O6USF@_VDn|6M;061yzPT|#EQL@KMSze}E zJ;vH!T1ODGvK+1vAi-@igU2Q+*@zx%5|oz@I3_W?w23C3v*edDo+o!fw0;ph875J3 zGlr)Upc_7Psl|wLvA{8PKh3b#zK4w64W${Zv;a$e zKY5l|;cM_)F%pmo9IHWG_mNlcjHtlHN+iioa!do z!6lW@5Y%#}BBt*!{1R*4Gi{3mgw*lk?Tp0a697Mvuk$LpkFcq=&=x`Y?E54B7d=rS zfntl&C|es=T{FkvPU-le!tFiw11{5OzoJiA^5o;unH#+3DO%pjKm|4YgHX%fM*Q8SR9YO z!kPzAs$S(c##4UI0x%FvtE7GC)jE&awUfh;=LJNvLn&<8w#)^q*=#ncUVK8C^wT>H z2aH&m_+BkHPJWG?3|{1 zBbHlQOVa>2VhC-rL3lq#at^~_RdVAj4uEcVkznJ!{2cw?|;n)v3T7c0%fv zd4VfRq)4s3#X(2`^QI|^WjXA~Vt48XqUA`E0Bw<}uT1c<&F)s}Xd$tObEbIsw3+!k z|ILL9q;j7~`|I)}nE&IPBfBV3R1Ua$Qw_kN7~5R^S|#-snV>^5!5*Cl$YA%K#ab$9)L_1|jtr;wxRC zM$y8sc8yTW&iK^xNz7t^2Sly{tpWzUmur@9#)U#}tnPKgr~{dTMf{lT40Sc~ARO zljBEaiAtzd2{5u0s>|y!=7`~gy|6l@30-wJykfmo(;Te&#e9P1E$3osTaWMy5o@Jj zfPft97)kYr$dgwueQXGh@aa@Cil%N$FiWv_fbDjDZS6-|M zr}%ykYS@u2*~q&yZ{6UVXNcU$rUhjLh% z_8LbSbYHGYNz8$AfRrlbg#E3Z^*2f}=x@d$Acf2|y+w&}&o|aZOt%t}D*RjJ4Qr{du=!0sW2%O$=oDN$mlCP;c#KQo>;6djDfoSSZn+!>6 zgy#x{kQP1YfmM$?b(?>Y#YBemdw1%rr%sc>oC|LQTivPGHpz=X0i3)yZl^|vkIv}? zVxvRAXY}{W1CJY8Mxa+qk zk$s3u=XFWrN$1l{wECGjj(*g%W7Cu`&c%CDeWLZ&J7^z;RX9IxrrD~VZ5y-Z)r^{r z%*%J{DRoTR`iUVG8=vs9R~~Km&(2j{LUA%p-jfZ*vuI{tbd;lJnFZKIc70}%#x=G&;$vYeC`{TOFv_B(s1nhCp909^c;n?_9`lsEA!P_y|K8oysE7oTW_vI zW9%4RYuMX22YXs&cQ~|g>HHkrh3X~8G4}>kt!=y+slB6HV}q(9Uw??I(b(;>kRDI0N`M!s8B``>NYfNMi}5=XA?0My#ma-E8RXJFx~ZIGbbsVnEnm46$M((GqteR-86BwT>n zSNkO5rh738nNq(uy9SkYaWabl+B`7mg_$viTMna zuZ^vV97-?00OjjO8tHm;*&;x@=IbL^NieSuUmC&l(0JL>d&cdd)i>;P2J;EqJOcCn zX!GB2rBm8K@PBzIzdU?PiaiAat@$8yn7^S9_6&_^aa`SYo0Xep|1SM%nFnUS}$CDlR zgvL&;k#}#v#{oSI>}LU_m!Ps^-kYI7!fQL@-N~H)9--!@X27n;(eM=Z`a+B(z^#+GZ2ECK^=vd>wCI66z9!~u9k zw@rQ5GJjralcDP|sXtsPTQ2isH93^my*zStb3lu!tN9)||GfOxJ;n8}DA zdY2TM&q4iF!yY;>n)^no+(h78uzg|s;)SiJ>|VPKyXVG({lW09!8z?W-zBLSuQ^)$ zOKSs`y8-Wes2!riD>AJ{2*0$2T1%mAegV1xOMfvT-R)V({1zN}%P{wDyPM{K3^v{B z`c&?7ed}mQ?TeS8hP1g><%ho?|F(53rKxBahSZHQzfrgT5l)+l%yDZXQwT`5kGWzRrsCIjn?l`0p zE`MnaazW=gUnMD-jX*ge^}fc)ImDKaB18B$K-)ESM>Tw3`A6q>N-==jP@#e!K=6Ko zpyAn&2%V=N3(U)Rm~zz9hX>Ni021#>+~JVhe; z(m_o-DLOr(FW8e7!Bx5Z5|kanLX?WICx3%4Eh-|u2J@sJ+B}pe+uYpTb4ocwnq{xI zqA!>N!@g)08110<@m?{NkQ*Mj1`V6GWSd?;78;<12GlPv)XLgi2hFLVFGu3M22L(c zo-drZF3fx4G{0}eaje)kGp?_7CM9$gRyw3k)czCxp%w%zP8Hnry9oN?ISl*y=zqwJ zS_IM@bT~=`eB2D{3AWsWhR_sBJ)Z_vroVl_WOn{xu9lufV zp{<6{YL}}dJhV(nl1xj)=bNw=$)umUhfa4Z%jCD8RWo@A%aLSndh1Pcanc4X0+8#B z)n0B3>K8nV!HoW`xc!x7sPx zkHRB~(U|xOhSW>pT#b!@M1m9=48KBvWoCr+AO{F8g0d(Qr)6Hb&$Ey6lEe-O9EEpUyd@r#+@%ABkQc zHllNAQbzfw`73pP$G#j}GNUQT(5Upb@KXfqA~kf~&Rf01QWOzyCn3oE@C5V2X*i{w zO!$G9PDk@G!9hKox%rJZmZ6W(sraE7v|o#%9Sq5IH@tCyCE))jfPa6^buauu+DH27 zB(X^zXPg3BgKDRum1+g?z0Y+k?BEP1nlr>RlbLEUJby*PV!%yxa~vqmyNnAbn3!e|K3}>-{n- zsSR>=5at{8MdXDAdVhnuv{6!XMIYUNm;cKO-YDd~iRZ|5EJz)HHZAI!URY%nz4rmLOk|eP3LD zyjz9d4MU+5LkqGZq_7@LD^BcwBArrkO(^L$V}4~#s59Q3vVRGEJsu!l%Pdy8K0$C^ zAivNGm(&e71Nw9;TW`dhOeUI5_+@!gVEzAfVRT76vsa;2mt?y_1K7P2AE(f>9)#R^ z6~Ql3z~2p#Ku~dyEp!hF0A2c^A&|sJ_&axVMXvKBB(!NKTw|0X!pbS0vC1MZ4e@oB zcadg-@IL$(9L5!Obv$ra;7S zU^rdN+&5-kdG~VkdG4D~>KUNhyUL~ye zxk> zPdxqw>+(`_R5~ry_4qwFBM>H8i$8HojvSxNIN05WKO%TSsws2*J%eKLS&>rty0kn* zPqB&A8Gksfok~okl6v7!2%ain4Y%T+8UCUc$v9_|+4C5j(N5>slU!vdc)4( zOj%Y(9JfvG>Ea#nr%0U95g`=vVM`$EesRRB!hhJaM+2m0Dd8kk?~L3PqVWqn}Gxo;uwff zO)Rq*1CbtbNKTlmbdH!L_viK)%hCHCf`R6;r!E!CSN5kzDDiSvZrzBS2;x-Ca z-*|L!s>=@7%&9IJFbUQ1=BSLdiKvbVQ=qyea*pcKW+kXDp-GkMQcBN?l+K>&l6n(V zmuroT>e7BHYSCm=Cy383oPP{6J@_{xo0g&519Py~wRJaRGba4y%$L;eJpBtnJE>$m z^7b-@+g?Tg?nL-cP}r8(Y}mM1nDQN)+LM{Nul1h@wj>t&dqvgl_4kpgTbo-|C$v_& zDfD}=lC;Dx8{4;vx&U94U6PR}%~I_|^Fst{uC1E`y4SaRC)o7sXkPCi?QH}kWqkpi zEsfu4jhAArL1(LdOtjIuv-LEB4pb>_A|bxr7dP0(RdV%_^lo2zZSxe=JaFCNNLz!| zl*5`oMo|Ah1N~o6O928D02BZK00;oJid9Z-ZtkYZla)|46a@wV000010096<00001 I4FCWD0G>o)MF0Q* diff --git a/src/TensorFlowNET.Core/Gradients/math_grad.py.cs b/src/TensorFlowNET.Core/Gradients/math_grad.py.cs index 9713e021..00caf73d 100644 --- a/src/TensorFlowNET.Core/Gradients/math_grad.py.cs +++ b/src/TensorFlowNET.Core/Gradients/math_grad.py.cs @@ -47,10 +47,14 @@ namespace Tensorflow x = math_ops.conj(x); y = math_ops.conj(y); - var r1 = math_ops.reduce_sum(gen_math_ops.mul(grad, y), rx); - var r2 = math_ops.reduce_sum(gen_math_ops.mul(x, grad), ry); - - return (gen_array_ops.reshape(r1, sx), gen_array_ops.reshape(r2, sy)); + var mul1 = gen_math_ops.mul(grad, y); + var mul2 = gen_math_ops.mul(x, grad); + var reduce_sum1 = math_ops.reduce_sum(mul1, rx); + var reduce_sum2 = math_ops.reduce_sum(mul2, ry); + var reshape1 = gen_array_ops.reshape(reduce_sum1, sx); + var reshape2 = gen_array_ops.reshape(reduce_sum2, sy); + + return (reshape1, reshape2); } public static (Tensor, Tensor) _SubGrad(Operation op, Tensor grad) @@ -129,9 +133,12 @@ namespace Tensorflow var (rx, ry) = gen_array_ops.broadcast_gradient_args(sx, sy); x = math_ops.conj(x); y = math_ops.conj(y); - y = math_ops.conj(z); - var gx = gen_array_ops.reshape(math_ops.reduce_sum(grad * y * gen_math_ops.pow(x, y - 1.0), rx), sx); - Tensor log_x = null; + z = math_ops.conj(z); + var pow = gen_math_ops.pow(x, y - 1.0f); + var mul = grad * y * pow; + var reduce_sum = math_ops.reduce_sum(mul, rx); + var gx = gen_array_ops.reshape(reduce_sum, sx); + // Avoid false singularity at x = 0 Tensor mask = null; if (x.dtype.is_complex()) @@ -142,8 +149,10 @@ namespace Tensorflow var safe_x = array_ops.where(mask, x, ones); var x1 = gen_array_ops.log(safe_x); var y1 = array_ops.zeros_like(x); - log_x = array_ops.where(mask, x1, y1); - var gy = gen_array_ops.reshape(math_ops.reduce_sum(grad * z * log_x, ry), sy); + var log_x = array_ops.where(mask, x1, y1); + var mul1 = grad * z * log_x; + var reduce_sum1 = math_ops.reduce_sum(mul1, ry); + var gy = gen_array_ops.reshape(reduce_sum1, sy); return (gx, gy); } diff --git a/src/TensorFlowNET.Core/Graphs/Graph.cs b/src/TensorFlowNET.Core/Graphs/Graph.cs index a025bfc6..5caaf6b0 100644 --- a/src/TensorFlowNET.Core/Graphs/Graph.cs +++ b/src/TensorFlowNET.Core/Graphs/Graph.cs @@ -196,11 +196,11 @@ namespace Tensorflow _create_op_helper(op, true); - Console.Write($"create_op: {op_type} '{node_def.Name}'"); + /*Console.Write($"create_op: {op_type} '{node_def.Name}'"); Console.Write($", inputs: {(inputs.Length == 0 ? "empty" : String.Join(", ", inputs.Select(x => x.name)))}"); Console.Write($", control_inputs: {(control_inputs.Length == 0 ? "empty" : String.Join(", ", control_inputs.Select(x => x.name)))}"); Console.Write($", outputs: {(op.outputs.Length == 0 ? "empty" : String.Join(", ", op.outputs.Select(x => x.name)))}"); - Console.WriteLine(); + Console.WriteLine();*/ return op; } diff --git a/src/TensorFlowNET.Core/Operations/Operation.Output.cs b/src/TensorFlowNET.Core/Operations/Operation.Output.cs index 02bdc6a1..3ec16704 100644 --- a/src/TensorFlowNET.Core/Operations/Operation.Output.cs +++ b/src/TensorFlowNET.Core/Operations/Operation.Output.cs @@ -1,5 +1,4 @@ -using Newtonsoft.Json; -using System; +using System; using System.Collections.Generic; using System.Linq; using System.Runtime.InteropServices; @@ -15,7 +14,7 @@ namespace Tensorflow private Tensor[] _outputs; public Tensor[] outputs => _outputs; - [JsonIgnore] + //[JsonIgnore] public Tensor output => _outputs.FirstOrDefault(); public int NumControlOutputs => c_api.TF_OperationNumControlOutputs(_handle); diff --git a/src/TensorFlowNET.Core/Operations/Operation.cs b/src/TensorFlowNET.Core/Operations/Operation.cs index 74e26ebc..5f5d9b1c 100644 --- a/src/TensorFlowNET.Core/Operations/Operation.cs +++ b/src/TensorFlowNET.Core/Operations/Operation.cs @@ -1,5 +1,4 @@ using Google.Protobuf.Collections; -using Newtonsoft.Json; using System; using System.Collections.Generic; using System.Linq; @@ -13,15 +12,15 @@ namespace Tensorflow private readonly IntPtr _handle; // _c_op in python private Graph _graph; - [JsonIgnore] + //[JsonIgnore] public Graph graph => _graph; - [JsonIgnore] + //[JsonIgnore] public int _id => _id_value; - [JsonIgnore] + //[JsonIgnore] public int _id_value; public string type => OpType; - [JsonIgnore] + //[JsonIgnore] public Operation op => this; public TF_DataType dtype => TF_DataType.DtInvalid; private Status status = new Status(); diff --git a/src/TensorFlowNET.Core/TensorFlowNET.Core.csproj b/src/TensorFlowNET.Core/TensorFlowNET.Core.csproj index 4a1ae983..8a28fc73 100644 --- a/src/TensorFlowNET.Core/TensorFlowNET.Core.csproj +++ b/src/TensorFlowNET.Core/TensorFlowNET.Core.csproj @@ -52,10 +52,4 @@ Upgraded to TensorFlow 1.13 RC2. - - - C:\Program Files\dotnet\sdk\NuGetFallbackFolder\newtonsoft.json\9.0.1\lib\netstandard1.0\Newtonsoft.Json.dll - - - diff --git a/src/TensorFlowNET.Core/Tensors/Tensor.cs b/src/TensorFlowNET.Core/Tensors/Tensor.cs index 993cd910..e2dd55a5 100644 --- a/src/TensorFlowNET.Core/Tensors/Tensor.cs +++ b/src/TensorFlowNET.Core/Tensors/Tensor.cs @@ -1,5 +1,4 @@ -using Newtonsoft.Json; -using NumSharp.Core; +using NumSharp.Core; using System; using System.Collections.Generic; using System.Linq; @@ -18,13 +17,13 @@ namespace Tensorflow private readonly IntPtr _handle; private int _id; - [JsonIgnore] + //[JsonIgnore] public int Id => _id; - [JsonIgnore] + //[JsonIgnore] public Graph graph => op?.graph; - [JsonIgnore] + //[JsonIgnore] public Operation op { get; } - [JsonIgnore] + //[JsonIgnore] public Tensor[] outputs => op.outputs; /// @@ -104,7 +103,7 @@ namespace Tensorflow public int NDims => rank; - [JsonIgnore] + //[JsonIgnore] public Operation[] Consumers => consumers(); public string Device => op.Device; diff --git a/src/TensorFlowNET.Core/ops.py.cs b/src/TensorFlowNET.Core/ops.py.cs index e019c14e..6f2408db 100644 --- a/src/TensorFlowNET.Core/ops.py.cs +++ b/src/TensorFlowNET.Core/ops.py.cs @@ -351,7 +351,7 @@ namespace Tensorflow return (oper, out_grads) => { - Console.WriteLine($"get_gradient_function: {oper.type} '{oper.name}'"); + // Console.WriteLine($"get_gradient_function: {oper.type} '{oper.name}'"); switch (oper.type) { diff --git a/test/TensorFlowNET.Examples/LinearRegression.cs b/test/TensorFlowNET.Examples/LinearRegression.cs index df223ea1..22f80d2e 100644 --- a/test/TensorFlowNET.Examples/LinearRegression.cs +++ b/test/TensorFlowNET.Examples/LinearRegression.cs @@ -1,5 +1,4 @@ -using Newtonsoft.Json; -using NumSharp.Core; +using NumSharp.Core; using System; using System.Collections.Generic; using System.Text; @@ -13,17 +12,15 @@ namespace TensorFlowNET.Examples /// public class LinearRegression : Python, IExample { - private NumPyRandom rng = np.random; + NumPyRandom rng = np.random; + + // Parameters + float learning_rate = 0.01f; + int training_epochs = 1000; + int display_step = 50; public void Run() { - var graph = tf.Graph().as_default(); - - // Parameters - float learning_rate = 0.01f; - int training_epochs = 1000; - int display_step = 10; - // Training Data var train_X = np.array(3.3f, 4.4f, 5.5f, 6.71f, 6.93f, 4.168f, 9.779f, 6.182f, 7.59f, 2.167f, 7.042f, 10.791f, 5.313f, 7.997f, 5.654f, 9.27f, 3.1f); @@ -31,46 +28,28 @@ namespace TensorFlowNET.Examples 2.827f, 3.465f, 1.65f, 2.904f, 2.42f, 2.94f, 1.3f); var n_samples = train_X.shape[0]; + var graph = tf.Graph().as_default(); + // tf Graph Input var X = tf.placeholder(tf.float32); var Y = tf.placeholder(tf.float32); // Set model weights - //var rnd1 = rng.randn(); - //var rnd2 = rng.randn(); + // We can set a fixed init value in order to debug + // var rnd1 = rng.randn(); + // var rnd2 = rng.randn(); var W = tf.Variable(-0.06f, name: "weight"); var b = tf.Variable(-0.73f, name: "bias"); - var mul = tf.multiply(X, W); - var pred = tf.add(mul, b); + // Construct a linear model + var pred = tf.add(tf.multiply(X, W), b); // Mean squared error - var sub = pred - Y; - var pow = tf.pow(sub, 2.0f); - - var reduce = tf.reduce_sum(pow); - var cost = reduce / (2.0f * n_samples); + var cost = tf.reduce_sum(tf.pow(pred - Y, 2.0f)) / (2.0f * n_samples); // radient descent // Note, minimize() knows to modify W and b because Variable objects are trainable=True by default - var grad = tf.train.GradientDescentOptimizer(learning_rate); - var optimizer = grad.minimize(cost); - - //tf.train.export_meta_graph(filename: "linear_regression.meta.bin"); - // import meta - // var new_saver = tf.train.import_meta_graph("linear_regression.meta.bin"); - var text = JsonConvert.SerializeObject(graph, new JsonSerializerSettings - { - Formatting = Formatting.Indented - }); - - /*var cost = graph.OperationByName("truediv").output; - var pred = graph.OperationByName("Add").output; - var optimizer = graph.OperationByName("GradientDescent"); - var X = graph.OperationByName("Placeholder").output; - var Y = graph.OperationByName("Placeholder_1").output; - var W = graph.OperationByName("weight").output; - var b = graph.OperationByName("bias").output;*/ + var optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost); // Initialize the variables (i.e. assign their default value) var init = tf.global_variables_initializer(); @@ -89,22 +68,33 @@ namespace TensorFlowNET.Examples sess.run(optimizer, new FeedItem(X, x), new FeedItem(Y, y)); - var rW = sess.run(W); } // Display logs per epoch step - /*if ((epoch + 1) % display_step == 0) + if ((epoch + 1) % display_step == 0) { var c = sess.run(cost, new FeedItem(X, train_X), new FeedItem(Y, train_Y)); - var rW = sess.run(W); - Console.WriteLine($"Epoch: {epoch + 1} cost={c} " + - $"W={rW} b={sess.run(b)}"); - }*/ + Console.WriteLine($"Epoch: {epoch + 1} cost={c} " + $"W={sess.run(W)} b={sess.run(b)}"); + } } Console.WriteLine("Optimization Finished!"); + var training_cost = sess.run(cost, + new FeedItem(X, train_X), + new FeedItem(Y, train_Y)); + Console.WriteLine($"Training cost={training_cost} W={sess.run(W)} b={sess.run(b)}"); + + // Testing example + var test_X = np.array(6.83f, 4.668f, 8.9f, 7.91f, 5.7f, 8.7f, 3.1f, 2.1f); + var test_Y = np.array(1.84f, 2.273f, 3.2f, 2.831f, 2.92f, 3.24f, 1.35f, 1.03f); + Console.WriteLine("Testing... (Mean square loss Comparison)"); + var testing_cost = sess.run(tf.reduce_sum(tf.pow(pred - Y, 2.0f)) / (2.0f * test_X.shape[0]), + new FeedItem(X, test_X), + new FeedItem(Y, test_Y)); + Console.WriteLine($"Testing cost={testing_cost}"); + Console.WriteLine($"Absolute mean square loss difference: {Math.Abs((float)training_cost - (float)testing_cost)}"); }); } } diff --git a/test/TensorFlowNET.Examples/TensorFlowNET.Examples.csproj b/test/TensorFlowNET.Examples/TensorFlowNET.Examples.csproj index c22cbaa1..9dc1bd17 100644 --- a/test/TensorFlowNET.Examples/TensorFlowNET.Examples.csproj +++ b/test/TensorFlowNET.Examples/TensorFlowNET.Examples.csproj @@ -6,7 +6,6 @@ - diff --git a/test/TensorFlowNET.UnitTest/TrainSaverTest.cs b/test/TensorFlowNET.UnitTest/TrainSaverTest.cs index c6023402..f5aec32b 100644 --- a/test/TensorFlowNET.UnitTest/TrainSaverTest.cs +++ b/test/TensorFlowNET.UnitTest/TrainSaverTest.cs @@ -23,6 +23,23 @@ namespace TensorFlowNET.UnitTest { var new_saver = tf.train.import_meta_graph("C:/tmp/my-model.meta"); }); + + //tf.train.export_meta_graph(filename: "linear_regression.meta.bin"); + // import meta + /*tf.train.import_meta_graph("linear_regression.meta.bin"); + + var cost = graph.OperationByName("truediv").output; + var pred = graph.OperationByName("Add").output; + var optimizer = graph.OperationByName("GradientDescent"); + var X = graph.OperationByName("Placeholder").output; + var Y = graph.OperationByName("Placeholder_1").output; + var W = graph.OperationByName("weight").output; + var b = graph.OperationByName("bias").output;*/ + + /*var text = JsonConvert.SerializeObject(graph, new JsonSerializerSettings + { + Formatting = Formatting.Indented + });*/ } public void ImportSavedModel()