From b7c01d451d67a20dbe821ffb70d678cc41212087 Mon Sep 17 00:00:00 2001 From: jmwang0117 <1021347250@qq.com> Date: Thu, 31 Oct 2024 16:00:50 +0000 Subject: [PATCH] Update --- networks/__pycache__/bev_net.cpython-310.pyc | Bin 8523 -> 7684 bytes .../__pycache__/completion.cpython-310.pyc | Bin 4131 -> 4740 bytes networks/__pycache__/occrwkv.cpython-310.pyc | Bin 0 -> 5231 bytes .../semantic_segmentation.cpython-310.pyc | Bin 10031 -> 10629 bytes networks/__pycache__/vrwkv.cpython-310.pyc | Bin 0 -> 14048 bytes networks/__pycache__/vrwkv6.cpython-310.pyc | Bin 0 -> 17351 bytes networks/cuda/wkv_cuda.cu | 345 ++++++++++++++++++ networks/cuda/wkv_op.cpp | 21 ++ scripts/check.sh | 1 + scripts/run_test.sh | 1 + scripts/run_train.sh | 2 +- scripts/run_val.sh | 1 + utils/__pycache__/checkpoint.cpython-310.pyc | Bin 2377 -> 2377 bytes utils/__pycache__/model.cpython-310.pyc | Bin 313 -> 321 bytes utils/__pycache__/optimizer.cpython-310.pyc | Bin 1189 -> 1263 bytes utils/__pycache__/ssc_loss.cpython-310.pyc | Bin 0 -> 2211 bytes utils/ssc_loss.py | 102 ++++++ 17 files changed, 472 insertions(+), 1 deletion(-) create mode 100644 networks/__pycache__/occrwkv.cpython-310.pyc create mode 100644 networks/__pycache__/vrwkv.cpython-310.pyc create mode 100644 networks/__pycache__/vrwkv6.cpython-310.pyc create mode 100755 networks/cuda/wkv_cuda.cu create mode 100755 networks/cuda/wkv_op.cpp create mode 100644 scripts/check.sh create mode 100644 scripts/run_test.sh create mode 100644 scripts/run_val.sh create mode 100644 utils/__pycache__/ssc_loss.cpython-310.pyc create mode 100644 utils/ssc_loss.py diff --git a/networks/__pycache__/bev_net.cpython-310.pyc b/networks/__pycache__/bev_net.cpython-310.pyc index 82126f7337484988ae856f000f6fcb0cbc48a971..c935f8f372b7f21fcf4403affcec9273f5596538 100644 GIT binary patch literal 7684 zcmbtZOK==V8J?c^KDAoOviwNmJj%(YK#pZYcsPz@IZi?nCB$}|Y(Y(}c1F_5`*P2$ zY^z2PkRfmo1s5(HkT{_$M@|9mT%b4%7pgdbs^H|CtCH~j|IF@=)>2|Bv(Yk2xfVMD7HAu(MG5-X96L1McZNHU3} z1WDG-L6S=(%I*Xt6N%)IE3PP&$z8mk zGG7rW9jraD}VhwX2$|x%yqDs(+~H zO4Yy&MwGkw^1GLNexuV~JzZB*qY-bT(4hWrK>}q%*%I4W%0S)HP;^&pi$Gm5LJ{&5 zn2m2E9l9P6Qa*$>q#gI<`$)s8AO6OW(WN0xIHyp%q_EAYIuY&SgF_L`30So3!2 zg(~^T@OxcPMz-TL+Ks?*BDLL)vRAyDJ+B=!YRzc;CAZcM8mrz*t4nWpI?Y)(GUu>7 zsI4oFrBIF8tJ~*T#T&xa;=RA z-4Rj7aq7*Q?>o+pa{rlA%bk{YYNd6n)?PaGR=xht_1E7$)%Jp09eKk)g{d7hX1aGI zDNF{%r>)NlUA09~n8N&6Hy-3B&1tbVncDc7{$kWgwfYzBzJW4_;#WaJWlIUwZLzL` z6QKqUJH4fbz`?o>PUE`S2X>Ei{Q=w?gZRvLogzmYl#G za*{}i$T*QQNMr;ZSznH{Ryca;t*}6hXS06a6=0Attq}3dD1KS*$p}m1kb+iB~@2Zd(iNt_J zb+welgalS5C9xpEFD0>q4E|-1ecdEekwsfewR3308V~O0N7{vvc5&qYV`!(xFO9Uv zmte_%BYvs`<*f;}j~O?6*A6D#+=jTMYyi%9szbpPY`GrVp^o>>gqB;_v>-nWc@ghC zDxZdY468FLKLUAaANkSX7^GqY%H$jm$}5kA8Tkiyye%+Od6S;tQ8s52^Sc+YQM(@y zf%OG}DcdE*kZk9)s@ZnD)`DBL+D@}}$CFi?E2rUBGqfieRL{YN<2NMWO*~WhaKqPiai{T(jW^(5>$I zmuXy@WLjA}Il$k*+RU7TN2@jc^I-aa*(D%0;GbZ}BYUCIY_z?ajIv9z)9X5Q*x{(q z^F7CfLtG$-7@5R6-3EE8#s{=MA7O|k6u#gFM7Q#oB9T`MV z$JK0Hor$ZnQHG2zJKQKIZ{1jRc-qtH#%#KACfzt&Ee`+V^tz`<+Os3=Gb8P@GKc4m zvTw5Kcmt~x84}LGi*jrL*?9Sp+UrIdMkg{$ljg4v0SJMY=p|SMs!14&tiT%%wpDDm>Zcj z6Xw3qtT~=l$E-QbN;RH`XU${QY?#Nag|IL(Yc4E&p;-$&tBTnQn3ZacDBiV*S@U5L zvyO#hBeNF5u`e|17{)G!V;EgshgIAfLqL&F_!9VR!jH#+2KOnW-vsSFo(RDyf047>V11IL0A+Sj{JK47Vcn z_Gy&aNL-A>TqItN#PzCvd)8Tm^bEK&%$)^ymbtH_;s;he3NhYBiB(U{szq^097Qm* zU!6akhuR^$5V}|!AF5sw+DM9rQWFdVl&5fGhGrQAD342$)lv3bF!6lddlpsywBib! zQY`k#K=j>AmE!=MSqwh2nAk}>%DJ810s<+rZ1<<1q3Xm7=TBUmC;E=}#X9JHSF9-bT)WWPS!COmnDJo)I@uDOZ1K5NYVdkjacSo1?H@2Eg%Vt-9`m;P0_ zXbU#GS*v?HLV9>*SbX+`3*h6(lN^ksWsyY?1lf%?PKBft^RVzZ^<>`gYB%oCx<)#! z^^Ota)o8$LEJr5%2x|5a@+}&N!7vD+d6pkT&S#S%@T9|}+X$n}Otbg(Oo_ReX78^} zrZ#vGEj(~KrEMf*e~+FU3LhJLR(E*}BD`^nDwQ8sUI(8->_8=c1^GP!BaTgCZ6ZB& zKaLlSr5#J2pfMgN@&rg^IgZ<@BT^@?#v#0XlDOd+VyG9i;CE4cvOu+BuQFF z*>m+ZL)IV$-qXwB^nIyNj2-Dk<)h4UkGM2KsFh@c!%tx?G_fH;6PrB7>lf5Rj@_*(Uiu zINpYw?PnUtbjnV!{IPDvC?;8g=a1JONr7sxfd>Pd?9?1h^fiz`SxM3>urC^WxA?$~ zApL<^eOJ+=kyiH(H`EK-pI@V-_2P&m=>vK}n5HW^(#o%buhYx* zFXG9%G5CK5W(54YJQW)9kr>RBQGzk2{Dex!Tvy#ho&v_&8p1c}WAEPwW1ZXk|G~fa z{G%_eX!Dn0jVP36Ls`kh5T*ARg(w4)u}e0L?L@_l4Oxc--TokwT)O-=k$v#`C&>Nd zATeHb*x91UeiCPesGbs2>e^U*B03n6w4ie|#Z4k)T%{y(5J>(;Hv^EhVls3e3BWVX6wXg+IS872qq>)$blPpPtaU-$0nT7izU|IcpX zX#WFx>Vx+4-*l7rGtWbR?rzc=L@2}EEJ)rc$o5=)U+15Mcn}5J905{BK7)9W^YE#x zbw<97&Q~x0D<0_|FV@MY_Alb;97s6ve+M$M{koh$+%`D#aRy4XpMWwy9Y=nyh||nR z@nr#q0R^0~Nf|KFF7DTd{1x3r1bHDg2%(YWwyd@t#NZi7aFRqsh_h5iUJtUajeJjL zlMD~1uQ}(FOfOU}JmOhd)rXR?L>&Zpj z#eV)Mt_VfDv>fP><;(?V!>fEkce=8(P|3$uh}Lun+sMAw^=qwe(~HLDfYhtVI{4j= z@6Ec5U=FPqzH*lSs1kXH=7|%ua%iO(!}KPx<>T&WsJV~5`Y+@@`5NF_pM-CDN*q^p zBne8YshWrw^-tGwz3vd(eQSWfD1Gx!mdgJnNGd2)>C*+D&qbe`R*=4q<#d*f<=`OK z>9Yi#Cm9QRTu;+^m#~3+C)Lx+A{+*iWjdb%#1iKaHdV?AAYWkM?8(^y`5M}G(&n!& z$?m3(p^@e`S}-8A3HT_0?-g<#to$Akvfq(buLbd<^8)fxCl28~>hwNH)o>e&i|@zP zdk03mD#CyymrWU#DYYcbPs>&o9}L)?XR*MEhFA2fmoLj7(gX)uZhGebpd`A^aV^Q} zD0{B{UiB^aw{W-K7gA5+jimOU0JB$h$!f=*n8QYcRYjfWAP3Ho&owmW5XF3?;Z;K- zl)f9JV7U+a4%5PrMJkF(u~wp-8K6G|8-cuk5l_!Mhw`v)@fgOB&FENv?9a8N^g>hg zT!e00!HHa#B&$qCM!VPQ-oXc~ZVVBdKa6ypU;U?DC_f!bT1tNuMM8T>Sat hJl-V^No>VNf}zZ52IK)xmftNOD?eGDE|;gC{U5_hZzli% literal 8523 zcmbVRTXP)8b)KHv&Rzg477z?kvV+jFtW_A&1eKdY(-I-mvMs}jNQw43$!NJV02bVv zdS*##vxpT#iS1Mnmt2)Rlc_uydCV_xmGhQ7=4oD%DyPbY$Ap!nN>#G*eP?EGARsE4 ztv+-5bWcxD_jfLRwmUahFmQe6{q3#)Id2$0XJz`bQF#d^`cII+2vp0cDY=_9Q|?yV zuGuPWbK16PD5LGxaw0KXUOQjQE8{ChUl1e72K(Z7pL$aJn zmLPc|Sb=0ElPm}7x>5bq5x(2ltg58=ithHVwz`em>#sB@oW88A5`6(oC`(i`zi(`* zKn3P|Ud>8Bzh;L{`q}%c=6>I>jam*b$tC%#Uw-piKWcV64K+0`c@u@hNAn=Dv2Wa0 z53t68aospnn^iM$qJA&biRb&xPBZp>%+-#<*4Cj)^1k0_)uYJw4~-9=yRg%3hZlC+ zck7+)3*T%sUcdSEH!pO;_-%P3|}?0ht30MkI$Lh-vdP+U+N5k z!diB{thb{V$2AJ?y}0(=)pKhb7tgJ|dT#Azb*8b~kNw6@z0(O>(Ti&v>xFUa>me4U zgR#aE&E-7$2-gw6OV5E+E#P6N-U}1^PBXllSW&-SbrW?jDNPTZsH;h_*%`koQLiTI z=J7>c%x2>2Fz7d8pmPDQh`99gg;iBCOQx*`PfeEjk^WVWxa`#1q3~d@qPYyjjj+}7{i!*6Xw}?(P5od^rsf$X65!VlCxy!yr#J;eZiL;FNa8CSb1s8+YjK zJ43^_Vw;KpiIYhjNMIJGBrYWQO-XW)I5QG2&Ic|e9>F)v4|758Acye_kUJRdxV#9t zjrSgxmmv2}kk7^Qy9>j--U;#@h0zNK0K8w07tvSko8T+){QafW3dhUA+`f9n*jEqO z>j~PV>ip)>OIVPjcS&HVK!DC88i=L=ojbG>?^d(b?1Xil6t;D@-}4*Y&YfhgABBDZ zBYTTRHgU=Ny`!IEoL`{G1>>Q5GbWPA=LTFyi7Fu5(7``+_7qL!`{vuQJOY_&iUyb+ z^ZXYQHH{-NAcf1VW)#E3^sZdwP#m{n44%n0er)*8U&id}t>_Av(E~@Ez~PE^a1$&i zZnGobWZ|CJLHBMaaY!$wxApXPDZRarnEhU2^?Mgd)=4grK;G#jIr<&wZj$4I2m>Oc z+hmf$H$QOv0LljyJ@rnQ!Xw$m3RUdP`Fo7+2a;swnXw5{9GaQ%k9 z0FsoOogQq^t$N(p@jG4J)?XEIqp2JHR$cqOdJr@_+ev=BRaz8L46+}l9VM$*me}!a|3_b#HW*X(=i{ua1P;4TUG0=O51 z`;95vDjP|QdPot+h7iU~Nm6+r6f#@e7llUn%?^`Ia z$`4d*?mE~t3DyoMbsrcHELgPMC+g*WqF#CQ!dn9C3{Z@>a{R&xvZ7c~;DivVBJR8$ zYD5Em)EtC@m}+TaX!KXmDJfhJ-|mN?5&2&&UIBpFD3T6yVs25xl;xY>y7WO`z6*w9S zOIl$lpz3UOV+U$1Dk_|b7|#5ZAv_HRc$f@C9auhqqH2G1k5BV3>$Tg1T2QN7VeNVA{X5 z?ZMLYdS#lYeNNyc^?C~pv?CMsTG5m}R=s#QVekXx;<%%W{5bq6|9`?CnA)jehv5&R z=qU-qpV@$j;SU3%NjT$;;R{>+9sshr_YekS}$<&*HIAmLB((eP&*htI!N zKQtHy$qA3gPYw+o z&yII}tcoYaKKgg?jDC^inEDX2`*v)JLqkXun6b;#>OOSsz(E^=AUHg0 z*Mnv`2jmbsF?fPDoYLn3f+;^ww{(2$z&*gbK5opDdrU^G<~CCV@v9EaBsU5#T{#&< zH-Oe0Lef^f5gsbtN3Uuj^?w?C8dlH=>RTiM2yAk*gP4F3>?Q(Oc)t8=Vg2?yT-U_r zS|2*8>vH4>oVbXiaTJySSpPo95flrc#wh*~f|>7gJ7 zcSGN0{T+~MF$MY$D3JY`1-hoW)_+VwLzg&&Hcptlt?NIb`~)2TJ>=0jkZBwdp2-@)SsL(P*=ZECb0}lckk0xj#SQm=tAt@ua4-fiiPhSuH!c=-)-X!L1!#%$Z334Jiu<=N-QPtQCtk|K$F z4c6!aCMJfmnol9h?>G)o0VY>LY?#=Is+&g2-=*|j60x?gle8mYb2$OxW-0O=$9`Dh>PsGKS_ z{o80eIiX|QndHAwzn2~;1=@+*?$OYE+`bA-PTSXi$2;WDf^=VJcXcqwY}x;1H<90a zpR+z~KmUV`vY#ax`mwvoHTXH_NR^aq#GICmLUraMmUu{qE%=9q(1alY()h5Qfd%}4pPe;IoI*x6g*=I(%Og!ku=Ne zn{grVkSNX{(C|p|LJ}B-3|7o6-oU>Sr3hKm(BG;z@UM^lOGuGuXf?NymFyzFp(A8w z-6+=0Ak=rziv9fS^hOjNaXHYVAh}`5h14Vg>GyQ^)+i&wHJ!jV@!se~^>(inCJP%t z>J8*Zqh2=(*8@TD#4Kt4Yo{5I=^K1ccAlSDDZwzmB(;3f{UWO;*sK4BTr!`)wY?1A z@{D@gw9Seto3801VzfUR6#Kmiwnx?g&oMtP=6{2O{0+!dP-x!AAAgeUR3mQHK(-^5 zGqIV;1t zRn)z#F8_#RyDNu3an5s?EX)VD0DnI+SH)0K(}hW3!k;wANLTB#a+mzDG)tsQY zwe{`v_N`On-6Dd3EHAQZt-}15mHSa8S3q`2oOuDuo9TDezH#lE{xM&0s@>+C{|hD4 zbct#yYoN?rqn#Rg@2{b?F%q&Td4@i`M`uB16_++U^}`Z0I-)3Mmg#d-Ps%YT7?uN# ztQt}R)B!yKStGFAO44i7HWb(_A)%tVXp>YuO^wb%|yfg_I)lTOOodf@DsN@!nG?=(%t~UiIVLgedMRfO12q10E@jE zCVtuuB0ox$G;QxTT2Fp$5!ZmBtcrjWWwr7zvOw(eEx9x9qM-r6$u+LBFzR?3YnD&s z$=Y#r2{=*yDbEwVtSU3JHd^1n0k-eAqX0e=Mv?CinSy={DJKxLGEl~98i)f0b-gT< z-qcwk0(d&tgRQMFrpy7AR~Lf^x056~PhsSoIx#th%of5+OQ7nRBQyL?oz~t3CN8@C zLw!iTEf=*1g^Bx$JK9rn$**N5?DU~-f$omzcBbq(nX+C7Yn>9>=qX@v`;xHX8I9Qb zDnx9&YD8ZKZ?nkpvtzkJn;SBFx=&WA49`d?2#+>xXzSyc)^n~wjhDE_5$>_)+}aD= z0s|i=x5u#|D@Kg9`5Q83b`Ws6r;Zxb5hlK{yBSg{S{3!8>3L777pt zP1Fg1r*J*?!_9k?S7JX39)^q?;^nW=!B@zYfaj56*f|_>oLTTwwhERgmf+jpZ6hZq~|S zWx27cud2=CrjDu< z;4hz5PkubofN_YmlCYhCQ9p^>AxRH1T`AA$sv(}E1IAj<0De)>3{e7Y= z?6bxlu;Mz)AM9-sm8YHc*3qSi`(vZrbAsX=hI|h}RC|S;MEMlLn+RBAOnevb9CjMv zFxcTuGYbwJnqtQ?B!?JpyJs0%p-G9gVXG7C6>v@DjAu8vktM8se>RSpQ}*gp1*Qhq zgpoUOnC&ENBc0vGmce9-z`Y2!aGEmQU9K-Cv=fC4 ouR_tP#MDMs^1Q1|apU5qAYaCWPBB|7sYwhtU3{9dX^)%#0BgAJ0RR91 delta 1605 zcmah}OKcoP5S^ai&VR>iVp}F8j!D21gM-Xp35of|m;5Ai2u8|ivOQkbjAxvlS;s_! z771T{uv+45i*t}7j$Ct##DN=qKwS5L#E~Oay^g(*NO#Y58VlOyJ=-&L|%?44H6{=BvN1=L3)~)ZfRh#)jfTul>-El_gL(OFE4IrTt{#!q+ z=^VoA0G~zXa|N$HRSlNF8Yl$dTd*QBlwUx5oZ$^%*g3+?WzWC~LzgfgDXW{uS@Lzt z*gXHeaY|*c@HedGW!af_H}6()*?o`-JHN|@4aY9ycCyQc?fG=~88FIE032eHDR`#I zG=FR*P3$Mk?_(7FsPMP!Yi1$<13)Bb?NLX9@{Q>nzuwLi-fLX#_cp@Y`|?n?dV%v*S7R8aQx4iNgrzkTn}l zF=U~yWqvZ#Be*J0G4efORRwE*xsd0~<3Ep`yN72f3`vUdUe(RYg~LgV<$$S)nF{lI rL3>#$=N6j>Mt*A~G#Mn0NMT*tklM>w^%7^ar?iMgJhi1c{)F=nVZ>2; diff --git a/networks/__pycache__/occrwkv.cpython-310.pyc b/networks/__pycache__/occrwkv.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..05ea5783508eee57ee2014b4d7f57aab4014ea6b GIT binary patch literal 5231 zcmZ`-Taz0{6`rd!x~}%_bsP{O5UVh|wnKopQZex@NfDdib)3qC)G#~U(vGZ|k$XnF z_SOgjPAa@C@B?_z=FR_wCm#EWDxUU{$`cfo@SPs5cGs9u_5F1BIj8%4=d_YWqpIO+ zzw*~^w4!PMpvKA1LgQyh>8Hq;#`J;a>Z&wcLzQM^xt3mx*^y;vx)zn(GI|VV4xFgs zR#e>@RHK?(Q+0b#j~Z@6)k}kBH0REtuCel<70tWzy7sol99H>EV-;@o7oM2zBGWgu zPW2F6Xm<{EwsQ%M*7YPB4tN$O@ipPc!8T3V zytD#+Hb=E-XuqK;BSz#?UkFv@Z*dodrTVl&7>uiOsqHKJjS(@8~ylb-j ziy*lB;m_|&Clh`c@A!kjRO}R)je-u4`qHydOQt>1Ci(#mAvbc(Gjg5jYdDHEZEWTy zddz8$xu&6IOv8N*?)hbO}Wz{O@TED~`RmU-5oeHZWhq%aVs$OSx)Ela& zsq(q3g?Tp3MWQtv*4Vc|HJ>f0UXxYMfUCZfTe-oS`=Ip}vnADIqGxWOBv{T?(1Tjv zSt_WtK6vPZgFg1}W8Z!QsfjfAgt%K}8kE_(Bg;45y7$)R&3m$R=lXkZZEm`a&Fh}3 zdAIJ~{3V*3Cz{fF@7>#*(%yX6y}9Yu-oE?p4e$2NdpGZH9)5#Q7*P7XTRr6X{(iM% zNox>BskGB@kEgQM6Cv{=o^g@NN*ITkM*|&Ov_LN_@s{v{fuE*gf!gz`?sWx!IO1^- zVg-A+?WeqBDgvX)_Ik~|elL_S9uZ71fDNhkKUdDz>| zq{G62PnIE_t#A;=+!t=OCz8?73zB%pZH4i0lzCfz7HoTQA|hF4A{i#5%xwjs2u1^6 zctfADFz(?(V`yuT1P=sVL|GlqtWuOzN6I5_OWB~HD=Vu(GV34%la(#L<3$PMgN`k& zln=Vn47xofmpB6ox0I$Hbn_HO?iG6Tyr^Kxz;7{9W{&bTtJj%#Xu1z@XvZZM+@WSBV2WcC3#g&kFx-P2iVpTM0c0Lq^$ zs-D4YwI|tEbyCZnyq?u3%F?oWf9^odVCLt9EXIK~ZeYFEWS+1n)W%KFH<|N>mRC?# zP*zb^QPxn_P-+wO9EGVhw z^Tv$JMXa}+FXnS>4*V=oeqfH5@}>R?p&?{gMLN?nRZoj#UDLkQ$IJQBWG!Dl(1}*< zZnJsf;J_Xe?wn-{4~#$S0=P1HAxF>T95~R~BG#N^OKLCV;C7WQ=PRL}uXc@1t*8Tq zR_@5sFbU%<739T46B+Kb$-+nh3Vf7>dU{Bf1w(TW-=b9+-*d#95OG}Jp3Babx1ZRD#)aH$-4yHO48Jw zhuJa@rm7sIhk?$E(j4$u+F2rkZCQ>JuP1yaZJ2V8%aT7Fa&Ti2uEAdYOuR^Imci|| zKjhNt4#J`M7V*<04&bW$JW#M>#6xNOLimphIJBXXS1(H|7%^Ygcr?r&dx0Npb6Ez` zMWc+%$|z2OjeL)bZxjD*+RIiuDTokj)M4#}e7C^EI#32E#yt{GR(m}2GGCzRoDtt4 z+A@FSQ-~m|1<)xu0M#snEQxR{!@vvGkkmuhqLwsBz8uc_d9ugavkMUqOE=vl9Q>j0KyWqQ?lGSa{{50?foX*It zkaTK2@jYTgW#1=W%H2fl`hxucjr1C_x=y%f0rRMK7T9*~UrWw4{XANh;poM$j{1^9 z#06tTe+gp?dfj*la~xxD?g$B|YX^kqM_&yU0(=F8e?$f5p&e*_J=gmPMfxVw0a|uu zAq>!SV^TWMazGq2_bq^5DJy4AztXS5jq4Os?HdIRC`$Y0PY56pVZDwrs{S@w9y55Ixz z)Vo`F+j`i859MUK{E;$qah@o2oUj@ZbVN@rGPKMexb>4pc%qf^p4&d&s1IOEv;Q_k_n8ob_We`M4*_MDb&yFyfd{C03g72eLek zhoza0B3VMsTLyNewkVnLsELG-|0BA8m3{KX24r6evMUR={mxu%LjFfV5dWb5x3;hAmRcv8?NvO0B@-0Un)n_78Fi2JjR|dCLh$Se7#dWoT=Yz*0LF=imn0uD9`pLa$ zR!aJ8>BovcN$@{h5V2A`2~uFgrdI5(BWfgrRN=CCi1BClZqg+`rs;#^D_YRe_g0Qn zahl}R{-l9fkVe+W^s%1dvBZqbn3za-iqOI5I)z#;>`Q^4My3JFla z8MiNO43ggYZqR9;-_EjOdSzq77mvc7^+fbGhOE1B`QpW&u3x(N+H0?`quH4)*89)X zfBxi=CclZD+`oa7rOG3g1}3 z`vd+s{hX$aknOck>g{=?|Cbl>TUz8(%GCKhkAX#EK&kF{??2ECwf7ACoRBkrrzq#5z)uKjW|WZ$3A!|bG%tQf+3%6b z#`F;ya0KGsi4*%Skx=+0NZ|yxY4JX6>dsl`6{qFcj)`Bzah!6?R8*ZMX&)PW0RxVT zx!2=Z`4@NbsJVXZq=ir0d_oz;-Rkk}mMPH2`&ZB{Q4j6ue_-n6?Jmzk3wYiL(SSN$ z^-OmcM_D*X*QbAokt@bsq!2-tZY$*}W>3hf4o`)e&Kre#-=r&-(ti&S5nh=hNAk|} HTQUC&-AO;( literal 0 HcmV?d00001 diff --git a/networks/__pycache__/semantic_segmentation.cpython-310.pyc b/networks/__pycache__/semantic_segmentation.cpython-310.pyc index f38f62d6cb63d62861d23eb6f8a0617dcfb51ff2..3ef8de2ef7ce2a77d5905b9957cbc5aa9e45886b 100644 GIT binary patch delta 4057 zcmbVPU2q#$72dnrl~%I;EKB~CZP|{UC{CQjPMiihY3inNN*ohx+JG9mm37y)Wc_h> zZ6{%c0!hdKF-^Fo1s*bTp$r))rIVQs?LgsyHy9onW}jey&g20Y`o!=+IOj@rlr)(Z zq}gxpJwJEPz2}~D&wYLN@_euq2>2xUy*KjP#lIiE9*oN}xxVX*RVqV(ExERfu2htIsCG$8#b}IrsqYd=wbK^r2TBJG&>*}!K_^7R zpc8KBbOEK6Mt~AAT)A!-r3WZ&GzyfcpnTkwic>N#CEIVKb+bw01oNe}w0T}PON&+^ zeHJL{RKAp1fkwbI3c8s|=kqXlX@$8$y6zRpzuPG?OZe@pqNJ z&ij--izT+6W^{nZ)%hv>I!gjL0bS4J)25j*@_Aiv`eY}H;CGn{5Q4{a189<;gj<=) z-%yvTNJy%VJ6$dtMJfaxm8FYI1`7a-1(EM&?bz-BaNG;HJSOuX1ub_-^EbsbgN5{p zq36@*3})(Pc@(5f+z$$gkRXW?A7K;F%_tIo+5Hk3;A5Wjo?#rzm8EALcfqihOO#Cl zrD-~1Xb(8VvJ5uC-|=*j$N3*Tm-dV0ml+Pg9(#a}dns)#8-iuZz}MVt>ee;w7v$(8 zD2+?N4gh?b5kPQWfEnP&qMj+O7A;fQF?(A9-c8vFHh9_M8qCgMll;1G$bAUM_!xiN zckbYQUJI|4&Kh|g_mY)nONMnEL{0Qww>L_*kua=>M%W2{IMBC6fnKmr=To=2+3%q|s5S;`iXw1gyOF`re-u!J%G^R};g92abQww=#JkJsIp zMM>dvOaWi!Ka8Fpd#DHJKqLcC!-G*0>0@(zckD{|0#I)!5HOP2m-w$^JIHhVZ?S6C z3Q!qSE)8tBY)KD7Uuo)xpzm$!hiw(4T5XjES7i2@E#;Jo+ltiMtVk{jBgvSoH(Wr8 zQF)WVyPdX*H*jtDro5?ue#e%Ep?pM@&Mg)1N7U$|(mb3#a12CvqGRvk)=p&0xCmnR zsbyBmjk60tNQw5u?FceOf^a$RqHYuxsN*TtBht}CveUtb7eVY8dj^Fue8mEP=udL}JvF>Z`j5l{C9JRchhKz%WoGl!RjZ*jFyKQA5>RIh-DImfSd4ryT}49UqY z^EW!*O>ITl?f=}lbElClMW~7RAR^R~-g-EQosTiWoo$N{ymvfI^J_ zv#Td8Y#u_nFCf@F-rd>16|!QZJHZ8GO-HvF87o^V{*KRg9|`RMaSfg);4v+Qzta83 z1v`)nolz>mN~jX9wAx|2by2QFU`K}SK&=%vi3Y`mx~bG;JBZwn&>yv>TALl+B$XJ99fkcL%LQmlUBdSBQ}Ki8 zu7_9>v%CIpOS*+6?RJN-gvJM|Jzz<@-2;}y!IBQJBo3DN>k@V+^=_zVT$SERU!@-; zgLbdoXL)KG$a*bL&bKL3e>aw6P>@}Abx9D7N+n?@9%6f^ow(O_`adk=0DQEI23F*2 zG8?y~9D#(PbxAY`5Bx7fqF9u>q*^ysAUQ}m*ZX|_PJC$hU_+zjLo@_mXbd$p{2!tr z_^46cX7$*ExU{x{co8h>k}AVM9r^%u1gP6SK;3S|?KZpH9$0eO!?j)*NviePzD=?u z+aryBzuhkS+Z+9a-HZJtSCxh~x@7Iqs0t4D(wmh*@Nxj`;)26X zC^2o^m!bTP0$nc^Gsf>h-Nd6sgYqs&reuW-66KE1m5_HrUn57zn9D~JWUD7P_x!F? z1pn>i)}2I^oD@E~0mHG^5PkyS_{$|&ZWE561tZU1NBZ{xz>98btYtF>q*?16EH$25 z>;gi)P~yKCJk~9oEHivL3f^#%;rdFy$w!BJhJS|QZz1eOpbfF&68kasv5sa1zA`kR z{Ro;#KVjYL!j&G5b3OXD|<#X&+B)IaWCH4bo ziQ@|ER-}53vY7>6>+6nZ-N-I2Lq&e?!n0q3+#LVP@I-YKwRa)xM(`ul5Z*+Ku(r4zfNoT@@=6CP2aYf_wOyEILk3?F02fs$HO0KJadWVdFOSY~;wSY&Q-Hc#`z zBg5_jj(Lb@M|w$sUm5AGI>Dl0t(Vw}IktueK69Z22XKNAl6i9s?r`F+$?&d`^3Rvk zaKkuRqN{mB+|0zSseY7Bp)Q64o52Cx4Jj=3gr!ZB6VvRG9s@D$(SjK5FwVl<34oVD c4X$Mx!MDHB@?%TZ;yz9DYu-p}#MgfKzgzpaVE_OC delta 3468 zcma)8TWlOx8J;t@oxQJj*UtKWN$q%(#)<7X=|f1Opg2w1T1g#Ix+QGKbH?kn*E8EQ zW4DgR*_0$eOH&#hwNO;RRs~WJ5<-4JJir4L61??!Bm|-YDX6M=;(-Uu|DWC1RuX7! z&9~?Lmvj05@0|1Bp8ds4%1tI?68w$w%KV=v?xuR=e@J(qyxur3QH83vB&yo+#neru zl%^WhZ%L&L&Cv)oZjn+OjnWuMvNTQ;@S6jlBu#-&s^yahNt$LrlChP=b{M4tByBVc zl59(|kCIubn7@yUm?@GlRdZKOFM6%)x%2+A`PvpG{XM|0Y0P6p2gYO0e~=BE?YIz^DMhoE0_07*?}fFFH-o`!<{4|BQ5!(>!L}Iwry|GaP_@=`lWf#q<|!b_!IU3}EQ& z%+~LW-;nrmw8jvyN&aT6yANlaxsQT-a+qQ^WYZ~EomJR}&5rSpV*B+IX#EubG`1Gs zlOJA>x9>X-Dh~tJ70r_;NCtLl5O#srlfC@c@d={xPvhqXqd0)b9>yR{dUeZnR$I{$ zxqm)!JTG!5CJ+fk`xp5KiSwtPfmV^QXYog2IUGW-I`tLbD<#E14@me>dv6k$M?YTeI0wGIs{z&%pZF+Al`bLS$oy`Xr9 zAIzO;M9^*mpx7Rg=}={MZH1LkSVnjW;T*y=Kp0(dyjNCi`Bt{;&2?GRFHm8jU~; zq;fLQXnaX#?*`JMx~}`FjkKRy%upHr6Et~80ZAK`HwpY^X-fQpHqbZaO%?oeJ09v6 z^vJh7rZRAof3Zp0sWdC$S4{J7>^pjThh1zQcGhH+R%hZ~j`Z|lcVd=pGOJlXtW=YXs$oOS-au|xR#|JmN&W|N?9 zv0hM>2-X*)O=b!bB0U*iM$ADFoD`RNA9O>F+Ky*~gSb25tFM+H+!!STI+p7)+T6w< zdh2V5*Z7~h3jI5WN>DLO&DpDE46<$cRoD3g&v&0m4}dww7SE*zkBrs*7yf?F;j5`Y z+DHdLm-RFdD+5$ZFD7X$h;QhdQjnn0JKEQj^~`$PdKN5GK_+PPBO3900q+z;frsEDoJE*LxPh?1PY#SWz6Z@>_wEhbr0^?fu)|)#=umd*p=L5>t_h-s z8r;8=6YQI4`UXN5!n+9HLU~<95iS6PDy>?6sP$Sk*Rk0hl)sJeU4Y_1 z^H_5}Ja~u<^2>vl`A35@d~oPQ3LmOc6!||{p0oMt(D}wg2PDy|*vhVB*&l-j<2JJpi^3w>splq5Tc=+zpt6afBZs z{1_li)Ln>?2UW(LU1J+4|2_Z|syek-tCkJdg1=U`o7u!G^Fdkk*%Q-UA~Q0>pHQv2 zo>$B?gZLqz8SWnb37Y>DVGIGE9yD_*?1$LLl8Y_zw}%If51?60z$5po;b%z?S4U1Z zpEnV=o###CS4RGQ1mAe5Jn_K|WAn^iF2mhfUt{m1K&iPE_5)~%yA=x|zJ8$a@L)d& zg8_c;z+|I{rUwxYAz(Sx@~?+mGj7zYf)n96iB z)IdRo$DmjUwLOb8e3w}ZV~)dg@sOP3=IF4#1cvMce`~at#QD!gdmE+1Y~6$c^EsEU zK>eejejGtu+-I?>pV| zTCGe1z3Maf)%SVsIo~<=s-4egG(2B@{?pZOUedIGWMurwBJv1c?+IPg9L>=iT2n9S zI>Sc8C>in_E5+p3ESd5fFD2wRSxU-ps+lgO<(pU|)y$N#(f7G#-q3WdStw2ETG~v< zoik3tNnX=S)8C^xDJOkRbJA{ldCz97G~>)Wd3Uc}bB{XvodeFnn&IS}L(XC6h;!6=mvhWL z{JQprQA;@QzHXM@g&HmAJ*e?LZgTn9^1IQx(drf+`MsG`MOI%$v_0ExF1n7j=GJRVzBixASa=>)HSa+QF*Yi7%5SU6()<;-<+WAuLF?JJSLfVY zQi)S<)dm$=YjuCg@|MaScUTYW`sDf3*3+l0i!aR&lRZ_jrJ!}n?{*q4>b zjjq(*u3Aoet>u-Q$abu@s@Li*%U9)k3r!9yV*T}Ix#pg>kX>!7X4$t+J#pH4p*U*I z@%-U0iCnJ=bctT&t?u zO>3^@mN5bT-06wY%{A)Gm^*FFE!M}w!uE4kx#d`Mz8YkWM$M^t%&nHY4d1i;Ha`zP z85L=Do3?uulWP0z6)doIs^0RWwLspYt1PU8lt2wEaoX~_l_jg}S(Sztel@CYXN%_} z)_TtMyQ($DUE?e@TzN33GsbslIOGsv9~(!`s&}FYc5}VfoNs}HFhU=N-K!n9;-gm)s|NaAZ|QgU z=8E&yvkkY5#j7m2l@*qKx!$bV++QzSWmW4o-IhPUa09zeyHtz?CUb&(xlymR>`Gg; zT;&BOdF4$#h&lCUF%y_A{04DpDoBD6xQ~Kl$5qX)?*{s%AdAIuRj1u3BOaGI43bsE ze76;(v5GAZV{(I7rR*2;rSx!srSx#WmD0n_pmOM4X?iCQ1iBsQPY3$NQk-8cg+FVh z+&H6@GRzAFNGZcA!|tR=l&T zYHek8zOx=oW1N^8d)U<}eC?eEu@0E}jGog|pefzZ&70wyxM|*;HD`^Zx@n~J#LZb_ z#!$!bK?U!`V@5IWE(Bbyq&ph4unrBJdR43Hj_Jg&rJ;wZN3~M?Rn62&2}sLyka?CY z@HmBJ@0j%@#S1r(q*{_pu@|yp0+rv;z=oAUdu*0r=8!$TQy{*!scq@o(BXY(6x;A) z%Vyu`#}I>7x^5yZ?m+((jfJ8SWLI49NJAJZNI;s@9k-YOb9BKDLCUrv1by2MG97j_ zEM~M?LE>V&wR(>em>tz#E}Gyww^3DhqWw~;4$fNwiv(#A+%~uk%$K(9N~7#~wk->% z-i5O21jq;RdK^MVPv{5r-rmui%n#bHkiQ2H%W~Kk@p?~z^z~)!MRUX0h;5h~@jiyQ ztnUXff~wuukB?&Bfu}U8*)5~5Z|hKb2B=B3w{2`-dN3*GLb&DAs8;rUF|MW==Fk-9 zSM4V`08&h-B7Oq>YM_5e9g*M@f&M}`K60lF5<91@*2>EH5K_DpNI|zSP*d;Cj14s^ zR5TRFr!0_Sk}J?EH(joPZDSjE8w}@c`;~6F5#=Op+i6!YubfSF9wf*|lU;EejlqOV zhwf$(nfiAl^dMg8QA#uQf-xOeO#LQac`jWE(vE6(Y%KC(kQ%7OAn#Sm=*Qx^z2dr^ z33g24!{OSzj0l+yz$MW)yi^}-2gsox)Hb+AfIo+1dvJvu+XbP1Z}vfl>(d49bDt<>?CiWYd; z3iZ2nYHi7FS>;ZrQHQvNCT~K|%a;_@6cP?*)>>?|D=Qu}yX?IyOLfn}KZJbtSiJF zXqv0a#+r2xqF-PSnwygmMvcuezic`6YSmS!U*HJ)&(&&o(HRWU^4v~Yp)(jACSZO( zRG*h!7aw|+uIB+;J!mcYe#d+8?AaRdLw9k$(r%uGlm`NVXma55IHOlK)1!9PS+LDom zfPwv%MVPbD<0TaI@hzs{>ckvl9?Di&k<1cn1LBd!=De<2t*>GNIcA`haES(uMJuV^ z2cDFTE&IHvPi%7Pgdhz70NmQ~*MkgjNvqv$H0pL>I`vh%m{ez2!TU)bAmJ7&rMVAe zmx!<^<)I;+cGC_ga;H3q_7ukX1U~jCUUC@|y63Py4Ly`WNIi_i3f_rlU~_m6A$da2 z+NJ<)lhJ#72J^7HuA$!Hv|r?Bvmm~x)nyuM=+<&<)4Z&y zN03o8Wzlx92E;jaM94MLlTmjlsVjKD#YSfJ-t@35kuInZeGVRmhmw^&adh-WBPOL8B{dqILYWJlQ#%*A%QG15wklc??|8MQl7NLs?kda*Mfxwn@+^YM;<6_SUM zdgS%No_A@j z8`rc=y<{RBb4&@xoj6z^5ySw2D+C&Yhc2U`gW?;A2sDI)0vk&QGz3iS11PTO-jpCD zq|Iu+3FBYiG=QWGXz-(w1IGZ&{^Sar6a6rpLO2bbod)iLYl6-MXlHXPv#o7qx2fsl z!!YuPVU!$(lc+O?J^`jyF&xD5j(I({k?{*#Q(^j`_I_qNh5A!!sh`GyF4ek~Z@`1?-#s)|@?rOaP^Qh@I+iKm$HpAK#r`0^BB+##f7%qOre&*4S z2eD?o6~rBX9lCQL^UbAlqZ*_Zg}2Y&AMEkUtKfjPvKPyh6#_i<7|UfU?UrA!b=zGp zh=+P)&nPKsG)R#**)IWX#Vl+o|EhFEeFU{caB?lsyMewE=&M1R%@YEGkrtzXzCN~- z1l3Wdtm73^AYmjRQu>}tT}&*f4`V??OXq2Xpbnm40v(yrb{KJ? zdX`xq1p)8}%R3RQ@xpBp@42e&spnbfIg-mHS4chqf(;03pzp^YnV!s_8T$kC zzhx$8VE7-$HVA(UMggH-j$2~T^#z!ElX@=5d43rdM$`3|+K%jD#pcJlB%PCf(MG+> zeubLNKo5)q1zj<{LW^neFg%pgClT>s%EQtT$w){9xk?%rnjqcAf`<4bCiMkz+BWM- zmNoeuOz&8a+1ViV+{YH|iywRJ(ceMLqcAwBt7ZH|X5rT_Oc;FOkRl23Ad1uL4Mma& zO>G$FVwA&-!%)DJF?f)%^PwoPm*EqK@qI{x&sii3B!o%tHzY(1hMorV52}Ux2>xu- zr+P7Fq4f}p(%`>_ISHRO>q-w&Z!L)_OtI8MXeTT+Dl-TsJ`rj$ZE4p0pq%%#}} z#fXf2e3xiXEY{0j@la?|p1?E&DRD2@btg!+YzYU67LTzUQIYVhiAOX#(NZ0LlL;A2 zu$k$?1u;nB>0ywsz?DHCYsh(dFI*@a&JBA7Ds?QnIvG`pIv_|qE$1>pY{gv<;;6ps z1{vj6;M>B&xgndW`h)p*e!&VA9hR}b?kAFDNaA^}QrLlJ6t= z0TLNfT1LqcsUK$I=RitH900UzueY0B{g)`<(E`lG2s-&^Vl(hJgwqxSsl77tK$9OfntK(&D%^UMYB5>UstV;iu8hCRx` z!h*+iIetXzCvh~M>!+$QwBL+sEkGngnMyrVlW7i)ZrN|in;zxiNEoUl`ex4yP=qG_ zHe)>czpS0tyf@ddv-$T4yiCLV=HW&Ex5<%2W1kpZG@dAJFs2#={6r6jqaEf9`gd#2 zo~Q?f730Ts^+42z8JVlF*7PoE`9WXD`@r?MeLse|yY^+6EB*2Jenu}&EvPlHp(v^d z7X9mm8~=Ip=FJ;q&stTUf4WfnyT9f0?-wRKImA>{0R+k(QDO~T`%;D{5Dr$&;n@Xv z_4CFn^ynC&I+Q~J7KR;b%%2PN`c9RUd8kqMd^l%24}Y6Avzen>@19$?*0=IKVfx+j|kiK!DcCS2te<8G^<5Dz8)Ee!@a-%pKYTsn5ZP!%U3FQbSN?_us*;AA- za5}*U?1q9Y{Ve>^K+HH<8j3x~t(4b;MCd8SNhN$KL9EfPjY?GO%JYL*r@f{=gBsN) z$u$xx)?n{K+kcWr_w-4-&f`kq@<2M2ao16xl!HIG-nP%-4B0*(%!IM?jFlVBw&$NK zWu+6g*Ku7Z(3eVyUcJ^U*FvY=7uXUfJ|qI@HxQ-L^d17Cc$=nm9V_UkfC7%ij_cD1 z@7HJbqcA^?|4VA$EW(HJcMN42K88P-97gZhJFm`!vy_H7*+8Tp(Rp9r(zXe(!*ehC z0(nyM0XadT9^?!6Dd2U2ntU0?IdnCpjTxP`0M;P}&=;l0qjbP(qyy4L>4_*k-9k;g zo&xD4;2hUNfCmH_Kwcq~6<<&qp_~t66-KiTLp)FOvl@d7Hx`|+5q%VMil9VPj`}o1 zqJefS^A9rPRg%|81o9=ldac>6JH^77e^oJ?zMPTq8W>lv;)UUObrr7~nU#UA0>e`~ zmNrN&{S;swBn8N1Hf-6Wuy^ zyKce3KvQ9a+C{-!x-^*Gdzd-$vY%j-Qb@^Uv%%Xt)>o1H7kK5QDWw^><6!88w^Mhd zPR{I8%%vcmc*v@JJhJMTg%VFEoPt;=NhgIn4KXL}WW++50z}EfG8x@&m>yUr7eoAe zt7W1Zs68BLg$Bvs;j|5nlfOdbos1I*78)lsOdx)CHB5@P7$)Ftd=(icv~9uHS-$|U zp5D&_IE1FjG~CLM!v(#k57>tCGm+Ub6O|m99eX7uJ7$Q?iUBff?~4qPeFQU#UZCRs zk-0I_nW8O)1;ceE3J*c~_cQcql96z{#^_0s&yt9Q{2_*9pNW_hf%zjygOC-Tr8>Sn z;Lk4A9mj3iJUms@n`5{?9p&wg{69cWKaDR&$bTp9PXX%>Fh5a$P9Grt5pV2{_{jzo zaH8|((Ncu?KMLeb#D4)emHLcti0}=;)NmCeg_s(r2@DC!&%s?S=pTX(?kBi4;8X^t zJ{DodG-NIAg2-ZLc1g?bf|;M%1vAHinO_;=(5Np%#5<^&H&3LO!<=k{coX>ki@SaA zBqJQVYaPQ}smIR&slufqNcBg#qMXdo>H9^7ev;%TK#F-$s_LhhAl9JRg+I-xaLvy! zB<%LH3~iB!k`_MPMij@g!eqiZKhJy-kzZg)q@xg#jucy;AkSC`48T(em-g$&-%2G3 zlTteX(m0=t!`ItuS@mnEzPLkQCXmKg^yM!jK*cMd=@hgY_Bi8s4d_Y}DwkoHOx)KD zMIeZ}6nb-*J^)j(7MP2tBUI>H(5KLvFfzAb)`cb`8`Dp!4l%@WA& zra)KNBsD=n&rR=Cptm*xT(Luej+9HJMUy*x_~FU4Xwi38pudY=Jjm`85ExUS$7W^} z^KP#`&mnh0eWubWCoc+=ijxYo;AEW~6elijn7AAWi*G83e~8zQb`I!$sBEM%7cN!U zox#H}btnO*_YWXDb*6aKL$f(P7KS@d`2VO6Ait@%aGED4W10b@)Z{Y^#>>{PNRk*cU8VP3Y*liVX$ z3xd>>(0hK}UzhtE%s3~YD2Tz`CF>i+s?}DI4LyT$_buR+S3mT$iAdzmQxur6j_F`N zAN}4VDt%z3h%mjK&n>(@@_Vd^QnZ9ixG+V;-RG%4W3njs&oOj2$!Lq7K=c|J^ezN4 zCeD2EcNm)!$6au{8G7n%rsB-^?NCY67RBY4g7Dc$^OAW6kBIImsWA{0Yg|NdAcAt03TX5#zgv>o1WPj-w1MBD!FY z)7y|e`fc20P9@%wn0fSV#Pr^sArJh&WHtP*Ls>mDl-2o(nsHTVhk)KG&j!}c&a>9B z8$>im#_)t}0|6C(fQ(J{Vq8SS%K7h^!-{!uqmh{f0^yz*o4#}f0f9s-h%FRP!V9(D z!F5Gtu;d#cfzifo9eAU7i!q4f0y}P!8?8?0gGhz{xuDewp^y_Fjz>T%`?16*K%bj|gjMju=y3M@#lLx?@a3-GrjIU2U5frglN=U@R=O<9DhpGY$#jj)@leVHsfTbL zfD6bEW1d4_KQEQggP#(CGou@6oNfQt%-shIGkY_8O(O?K$$nm)H&YK~%$dTBzyHbq E0oXLkq5uE@ literal 0 HcmV?d00001 diff --git a/networks/__pycache__/vrwkv6.cpython-310.pyc b/networks/__pycache__/vrwkv6.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6e317cf327951cf0fc1b64ebb6969cb521568364 GIT binary patch literal 17351 zcmb_^dvqMvdEdPDiN)f;jDwW272YmuNx$+V?VqGVEx=t!_4QkJ}tGFt2m77Of4 zcV+={&urtEv?aG?EaxQcNmQkP+on~UoSf61o}Tpd)M?#U()5oseNJ1~joYMb`f}o) zI%;k9_xomFfFN0EJK%hK=f3CO?|$!Lb9_9l;rjYcW%byPXxcaVFnVS1@C0tJp=+9> zIeJa2>qT8>+Nc>tLw?PoDZf_HlHXV{F29LlLVlC=R52xO%v!RZE{^ruXX@jIrfc84)<4ytPgad-LomX4}oyQn}UY+_irxt1gw>S3BvG%JojRy;8;POq@i*wO31?Z`Vuy zTE{&3sfRjaCtqs$*GgWV4+|$xy4QSFI_bEZ)v|kX(=Gc=Rqd2G$xi&_MSIzPOn$H0 zj}@%YYLx13nDClyRd&5FvFWOnrso!?q!(LeTwbk}DqflIKYE#XcmlU~48YZjx}y~h z*K~BpxT)RHix$$RV@W#Z#IT0(F!tgPecC=-7BojIUPS6~+!t_r-v;otZGA`kQcc%< zBQOJNUfVFd@lR@>(*r%wY%?$bEtF_>EU*CML2Ns*qj`5row%g$lXODT4@f#G>C=)< zN%~_ECIG`5(ZkW1uMLVr|-WCom|2~7;L0-XFE|_6F z+O#m{_}5!*I9_+_O?BN~Z7S?`VLFUqE6ZzPqS3S~s^o;Jr=NQAqJ8CZ@u^V%M5td4 z^$Vf?bU0r1>~gc=S1avi+Y3$iS}TmL0wet}J(54=l{SG)O=Yi?${SlH<%BV}-tw=j zG3-K^kQNV}38#i_dQ%Bgg26-%RBI}>Eo4Ho>|aw;Xsf0P4iMznm99d4Bh)uTeJj-4 zq2c1^`a7ChMB%%Or`DQv_tbiQtJJ8R`c%36JaFXH;8biXj7_kfV45G z$9sPnz5kcfr}ZQHf^O)?zSS8QTwChRuWX=I|C&UCa}=C=9_O})b8lh4v~{Bk3|lw5 zz_4|z3k+M2bz_0P9S4rh3jS#8iI-#DcsJ2ac2nInFd<3Y2rOWSzMgV`RqMtvES=tBb4mVaTydbY1<^hwjHKzyWVu#HKsGR{aU+J>y>0YAN0Ficm1`dBj-TI z9hnPHl9CCLw9eu7m`G}dp4N@&tgiT2P66~Urn!Q02DqL)1A(@OjwN?eSG0HZ<$@tS zgt1bqRl8nDeVB#37oEyNj!cwu471O_uxwv=;o_4v&-&TNk>~Mv&g<%9__3W)Lu5batJ}DjTuNoU=vw6o`%zZo{?uuoel(vKW{i@1$@fUK5wj= zhqP{Tkgqu=N{1yFVOw`BCx!!I1Tlwf2fXc&25249c4FOV4j>I60U!c6=H@nt#g5+1 zjP#0;Kn`Mo;aE9HX@|iX`{Ot~g8Cfi5<|b4X!IU@pEf+UtbY!pkJgtFq<2ZhM#t&5 z**nGuwHck;zGKM<{T`fnE_yGWgoSU^Tgmsl*Kl@w+hUgKVY_|D#3<1UvMtwfcbk3R z?>%ksW_#v{v8*?B^@$+Eod!0=9cxb8G5h=9&G_be9K4M<;Oy-?xYWm)B9h-W0p8I=)2Y@Ai7@7^mB8y$kl7ku3!+z;0HI`_ki{_F zaaGf^Yt;=mG^-6CTo$sxco?fTT5UfZ^VX`XzTK)`b8BI8%`G`L`VA*rCBM98qo-!A zP0=sRZdLs?yIk{Zzq#Qy~#oN#W>bI6{er#n?F@wah=YQ{N+02 z=Dgp`OKx7Woob_!&f`+&(D`i*`%&(X$1$0X7R`Fd?^Dan1cQT0jgqb-tZMxznNFEksQXPnS# zspdK*bI+}%WJ@oFy->axW&Mrl#Tr?d7f>LFr*jI;8olnUm*<%X-@AP z7_P}ufBaqY5bN@_FR8AXO!L*$X8)IjY&Rv8O9x5@#Hn->5*EDVM4-=H`!Yt+|an#ek~ohHOm`bVJR&~?dn?9 z%j2KF=CU6*4|SDkjtabQKugV6(S81MzO3AmkDjq_`KMM_tL3WO@RwcxMBeqw`MzFo zBEQz$a-prBkp9^h)Z9uHig=z<69=nTy&8xU7M#N^${Y3uTV#1D?^IV;U4{NRTAlOf zZh`9Yy-In{ZIu)zgSW#1EG;d`SXW#Z4ZU*P^FUZ1&9C`>%X{?HsR{^edu6HIte=X+ zR47ner$9hYRXwlmdZ!+K_VyGKSj=O_YpCd>EtcTgt~JIwl=4wWvP$d? zCIuO6(QB*C>LwPD@4VRGqW)mLom4BpNjb-n^sbMdX_ynpeNqi~%NA8KOs{|%HrusY z)ebGEx@qI=+3cZA;1KZm7E?S3a!iCc6vx31oo2l!@b5YG{X>Npt)a0u36M!!dJwl?xmuK|2?I!F|AKAQFQE zCy7uXG;5t771a7LhvyG{dW27L$*)E$#>Fa|Bk#m?|MsXN`AQX=aB^yyNh58N$Q$p^B`UR`sz%BeG3;vSUJ5&9Y*q^&l>fyp) zU^5y7e3{ldrnxSf_6@5)F-x{>4{|{FAV-MsMrJac4&V3<+`rF3a(ZWSFqOXEPi3Ni zF+CuzoC!7ch}NKxnNro3$ZRQVo`r2!^8$`4RW? zv$%V@AH*;t8Jjd^pTopTwcN)29=(d1p=`}}kRZ2#91Wf25X^3B8qbD21IjaS2nz{Jxc@L90wTa7`FOG+(B?|yupCXfu|dBe%x;-*iOehnz}JlIzU^d1toNLu+Vw^~{A zy&myM%O(+j_9ZZ#st?DQC5zAt6$>c2`s_ypM+ zIbF%mAAkBp{zb57j>OA>P5BQs8fg6g+sP(+-w^d^VyGBuJu=7ApX9VvsaI|NFJ~p<5C?tC-+&SXYV$SM$d^#|n9;>3WStA2J`5ERjlJT!!$5S@n=7 z)IajB{y=B3u#|re(ib+byyliSSohUxy&`JSs{;vUX<77wOE7+0q6v+cYSjv~ohEcB z<%Je;MLmWW0F_+DPoz3f-wYFw$?9$24fRXm81}?fttRw{NN-j@%F0g=Tm%5SP7U6r zm>L|dVrqc*s*QT#LsrGF3n*HgsqP7cq!L?9ol&@bZ)0o)5m?q0NaeFzi-4Pv{`ZA5>bmdT= zwgkeWgLE96gz9B{9$dywAeHcw;2uY5$Ab}Cvs2J+(xCe~WE-UIG4L79&QK|}$Ng*& zyTLvu{K@qxw5H-3qeto%$D7_a_rSinxqWjF21!+iyJv>#E_$48m@=!`UTwGu2BXa( z?-1Wb<^Zdq%wZmUN3#!$y&$%H32MBxQZ0FvZ>@dvjlcSxfB!g~ho;-|3JZ~lMsX`l zio|YL9k_mMNr&+U85A<0I5%vs!wgzt!j#y6h+&FianArB86A2&NL^+DQ6x!3m?S^UPmZ2D8-_fjsGZ&_wp?0zi-V!EJ(QY86 zmf1e_gAkcoQexE__O&n(S)y8nu}FTp$=Q5?;0D3x2);-lGfauOM99ETGE?UM(@gCn znErL-dk+Gn&Auf*Ze;aHOpS8V`iwE7^AGB}9yfA$h6uZ`5039I&1TU3z!V)oW{4DY z7!c1zUZgkmFD+8Nj$~V^<%hK{q|1KJVc6%pChA%cJ}Cdjc3@@k1x;?~7qwS1P~mw9 zU}A;XgTElQYM#?xJMJeS!mwou78!VIV5B4xrr#%j?0xbxd-EZf88jXA`WqiDhqz|u z;0fxLWj~;7!bx=FEXNEdIe!>ZlWe;^MLM$s+mIZ>UJVF;^3E`xj+P(b%Aj=)ri+BE zN={bgU^fXW1z#=8qE%(!K{WFNV)}sTHi>omE1W}+*gn`>-`tKVBhNV5Zfe*3QhVk$ zJ17Nd!3=iKe6IIioLPEgW(#m5K(w6IoC(;-uu{%sPU}(#rI0$9EzL8**sKNufLXq& zzoyAN)7uA8hlFCjw?cNd%8<>ZBh-~Fw#^Iz>J&R5@sk-D-f*KeP%{G(sz_3iVD zwM7G}{Or&0`sMS5yGHzVa#(q6LklvRr+Aj3Bf)CnM${wq94shsT*35dwwhaE+ONXd zRWFr52!`p={Edo&ESinia$VGE z6mwGN;!_u%{LHXcX!`Y5G2W?GI;F}U*+Nkh4snZlBMni4*7!-#xfy-VoP%vNuP@-w zGL9mZF|+z{+DnZiaGGa8Cv*Byecm|sb?aE?o_*<|(Bd8!*=2>8!=Q287dZGQ!0+KM z*LLWkiwwxn56R_05=G$`t2wQhapaN+z&8fVWx|hbGw1?be-fk(hZq)2`h4J)Hn&re zH%@9n0zj>vPRClPVRaKhVjXrICw@zZdtO}g-7)xTqA|vUD(U* zknZJ7^l~P7WMy`fy_~5a)60Q`u{|A3f*edEbs)$hbpWXxTt{n_Y7p`X~@_7OM~&JaXcUZ7+GU~5uo4YE^H^JlQ0UlUqoULw(D@D z3+BQKPd5nd@pO+!F7iytqx1{53(OElNo1L`qRo_z`YH4*GE~nc5*fh0%G}+Hx_2Sc z{aM>*@>~j9c(QoW9L~6b`wTx4t0zoFZ>v2-x- z?HD-F@I*y1H@+-gz`^0nfUE?Q64DU#Z>72#QV4i#>*^F*gScRUkq0A?2DO2v9>iMd z5|>a&Myt(2XutjGwOKj`ze=g*J-)QNuXijZLU94zg+{pr zmoC-rRuhUb%48z)2686<4reR#`WjQ5oH)IXb4CC&YLlT&-Gnp9`*O@G_@PZ6|papv*hf& z&OFgR1t;I&<7I+N1YachqXfdTsAQ@?P9Qwy%S_3+6TU;1QyBj+isIN>)p6X0P3M)O zG#2$D<8m6sRIh^9*-v$Rl&Jp^)EiM9Avo<;9h2Z4k=i&7k*afOFOux9F%l_~xG&)L zejXsw7}XC$pJ1fxdBlKHEE3U3v_peB5RBTSgJc)xF71#sFVYTSRDgDfs6%KyM|!$T z3S<}FciCvs4AIwU9^z5=Y8T!o-=kY(pj*Uxy2YQwl%oTVHR5iNW<~LOpJPf;;Afe- zN$?hd?CPIl>TLixyd&e04Tq$?vHMY}UqBW3q=ZPx&ix`A3c30vriQdZC>8PP1#VXk ziKMXbNXCUOuOEHaO7=P2@br>eLOAG&uJ6y`)Za#zdwEqHdAx;Jaeica;7VY6ZxDhR z!qEUWk2G{;?5@F7&mB+x1wJ*%M`)3Vli*(4(7VB%p!;t_(}lPT$2^}=KP4ROQPhF1 zALYU9#Z(M6$B+*P2Wtm09C^$Je9c@|e;Kort->zJ>5>yCXW$>Q=+_D4oN*5ko`NVt z1sjfXA0m*fDR>UYqf$ zeMzu%B)Gwf8FT4GxCcK_i5x{uU0FoqOMsg9vQa z?1+2Cm!qf!>Zy$yVZg#P#4uCQ%`YLO)3`!+51sua{2ue$R*uSD5$^>@MXm`Mr!qCq(jl z5ks7p(oR+oZpi0|O8tw-%$s@p5>l>z8n+PAB|L%9N!+eVgP>v8LEe~WsOJGq^)N~2 zHpyWOzj53U;;9S}2X=Jm%tD5x1{8qp6mDSA-Wv2+qXhlz?qhckyIc3mu5do4y58?W zW|Ck}oGnM%cP4xvVEf)0BQx36M}M^5EO)i;mxxT9k~xD^J$K_wb~5 z?u&5X|A}CKnd9{c_Vj>YkMHUkF25m_-b>i}dd55Pef117h4&AX_zje9at5Px8mh&2 zr{@*R%D8u313LXYvPbB&kuu?2i>2cDP2e|~N~O}NG5lpxokU-^=J_*mr zp*#)~Rpgk9tA~AjH|3|d$3zMRBk0-LN#d!YlQqz-2L=I?9F8u?JHDG*pFoO)eNsYA z_=|vgNHH{kY4+5x*I<@?^m?84lj{c%h@ZH{Im5UXSpLinLp1`cLK%wBvo{Q{i$Ig4 zKNlb{MM6tV&fP!6XZTi*Avk5Odnh;rhNh`E(Z3ZOf=%8E=FWi>@%Ipjn5k{(>TQ;8 z=-!`ez}bji)4?Hr*2w-|`*F(P_Pl708rI-2+NOf!j(HJ*@ZEX5!-6o#yZj^UMaG`* z&LMr*_T7FS77O?)f;spC#)8RUo}>1^#-Rxip(`*Z>e_2x4Pa!n?+K>Ve?*^0(LR#w z@1v0)Ot6ze+VkPaTkkl$rQGB65HCWjPd zpaM$&N}TpI8k4?kyascR5IW4os`_pn5nRU7gAo>Hvgk z#w3V5;f2{KFbz?Mk)q4C;pEWGXxOwyns|%>9xP&l2^IOK0i7RI#FTlzaHrk`d18cUgD1Vf$^9Y~s7D_umNqGr?f% z7^4pXth`z&jriWi7y4^&@%xvFlZJ8O$UieaiIdi&U z(DY^Dj{cH=kWS_FxGBbnJ6V%op0NdLBUoh8Ou}$+1Xjg3dVubQC(3f8f0Xftf9cL>_PlY)0=Qu$j1#P%*>2ZGUmxKs zFh&`I)fg4@{S}b)FYC{1KfzeU^mYz*F?=25B;akrS46#5W)zp$fOhhs+G*F?jy|0L%vb3?@LJT5-;NyIx^G95>5<&&X%Xwz} zSAu|mhc)sa3BHn)MELq&kP71n+(KAfWXrk8vTqT5o8UXFU?L1m3A&3o$#W!LRJF;J zVE-SmUMvbu?AhoBCg0;+mMDEcNzCT61cUW*T6*8=m!CAN;J`Wh`m#zXUDDla33dh4Z$Sgu8do&mEd268onw+RLEBt~W zjDae(u7^gW6?p}dQG8*e5%CI>U2OprXQBdpHCFSMhE9YqCOPWwv*O_Ou`14v?qWyDQrmCJ2kpsAz5h-@J~oZK zF0NNh=i~s-t literal 0 HcmV?d00001 diff --git a/networks/cuda/wkv_cuda.cu b/networks/cuda/wkv_cuda.cu new file mode 100755 index 0000000..8d178b6 --- /dev/null +++ b/networks/cuda/wkv_cuda.cu @@ -0,0 +1,345 @@ +#include +#include +#define MIN_VALUE (-1e38) +#define CHANNEL_SPLIT (512 / 32) +#define EPS (1e-6) +#define TOKEN_SPLIT (512 / CHANNEL_SPLIT) // the number of split tokens +#define IDEAL_T_LEN (Tmax / TOKEN_SPLIT) + + +template +__global__ void kernel_forward(const int B, const int T, const int C, + const F *__restrict__ const _w, const F *__restrict__ const _u, const F *__restrict__ const _k, const F *__restrict__ const _v, + F *__restrict__ const _y) { + const int idx = blockIdx.x * blockDim.x + threadIdx.x; + const int channel_id = threadIdx.x; + const int token_id = threadIdx.y; + const int _b = idx / C; + const int _c = idx % C; + const int _T = (T + TOKEN_SPLIT - 1) / TOKEN_SPLIT; + const int _t = _T * token_id; + const int _offset = _b * T * C + _c; + const int _tokenLength = min(T - _t, _T); + F u = _u[_c]; + F w = _w[_c]; + const F *__restrict__ const k = _k + _offset; + const F *__restrict__ const v = _v + _offset; + F *__restrict__ const y = _y + _offset; + // for saving smem, del Sc, Sd, So1 + __shared__ F Sa[CHANNEL_SPLIT][TOKEN_SPLIT], Sb[CHANNEL_SPLIT][TOKEN_SPLIT], + So2[CHANNEL_SPLIT][TOKEN_SPLIT]; + F a = 0, b = 0, c = 0, d = 0; + F o1 = MIN_VALUE, o2 = MIN_VALUE; + for (int i = _t; i < (_t + _tokenLength); i++){ + const int ii = i * C; + F no = max(o1, k[ii] - w * (i - _t)); + F e1 = exp(o1 - no); + F e3 = exp(k[ii] - w * (i - _t) - no); + c = e1 * c + e3 * v[ii]; + d = e1 * d + e3; + o1 = no; + const int ni = 2 * _t + _tokenLength - 1 - i; + const int nini = ni * C; + const int exp_w = _t + _tokenLength - ni; + no = max(o2, k[nini] - w * exp_w); + F e2 = exp(o2 - no); + e3 = exp(k[nini] - w * exp_w - no); + a = e2 * a + e3 * v[nini]; + b = e2 * b + e3; + o2 = no; + } + + So2[channel_id][token_id] = o2; + Sa[channel_id][token_id] = a; + Sb[channel_id][token_id] = b; + __syncthreads(); + a = 0; + b = 0; + o2 = MIN_VALUE; + for (int i = 0; i < token_id; i++){ + const int exp_w = (token_id - i - 1) * _T; + F no = max(So2[channel_id][i] - w * exp_w, o2); + a = a * exp(o2 - no) + Sa[channel_id][i] * exp(So2[channel_id][i] - w * exp_w - no); + b = b * exp(o2 - no) + Sb[channel_id][i] * exp(So2[channel_id][i] - w * exp_w - no); + o2 = no; + } + __syncthreads(); + Sa[channel_id][token_id] = c; + Sb[channel_id][token_id] = d; + So2[channel_id][token_id] = o1; + __syncthreads(); + c = 0; + d = 0; + o1 = MIN_VALUE; + for (int i = token_id; i < TOKEN_SPLIT; i++){ + const int exp_w = (i - token_id) * _T; + F no = max(So2[channel_id][i] - w * exp_w, o1); + c = c * exp(o1 - no) + Sa[channel_id][i] * exp(So2[channel_id][i] - w * exp_w - no); + d = d * exp(o1 - no) + Sb[channel_id][i] * exp(So2[channel_id][i] - w * exp_w - no); + o1 = no; + } + c -= exp(k[_t * C] - o1) * v[_t * C]; + d -= exp(k[_t * C] - o1); + for (int i = _t; i < (_t + _tokenLength); i++) { + const int ii = i * C; + F no = max(o1, u + k[ii]); + no = max(no, o2); + F e1 = exp(o1 - no); + F e2 = exp(o2 - no); + F e3 = exp(u + k[ii] - no); + y[ii] = (c * e1 + a * e2 + e3 * v[ii])/(d * e1 + b * e2 + e3 + EPS); + // update a, b, c, d + const int ii2 = ((i + 1) % T) * C; + no = max(o2 - w, k[ii]); + e2 = exp(o2 - w - no); + e3 = exp(k[ii] - no); + a = e2 * a + e3 * v[ii]; + b = e2 * b + e3; + o2 = no; + no = max(o1 + w, k[ii2] + w); + e1 = exp(o1 + w - no); + e3 = exp(k[ii2] + w - no); + c = e1 * c - e3 * v[ii2]; + d = e1 * d - e3; + o1 = no; + } +} + +template +__global__ void kernel_backward(const int B, const int T, const int C, + const F *__restrict__ const _w, const F *__restrict__ const _u, const F *__restrict__ const _k, const F *__restrict__ const _v, const F *__restrict__ const _gy, + F *__restrict__ const _gw, F *__restrict__ const _gu, F *__restrict__ const _gk, F *__restrict__ const _gv) { + const int idx = blockIdx.x * blockDim.x + threadIdx.x; + const int channel_id = threadIdx.x; + const int token_id = threadIdx.y; + const int _b = idx / C; + const int _c = idx % C; + const int _T = (T + TOKEN_SPLIT - 1) / TOKEN_SPLIT; + const int _t = _T * token_id; + const int _offset = _b * T * C + _c; + const int _tokenLength = min(T - _t, _T); + F u = _u[_c]; + F w = _w[_c]; + const F *__restrict__ const k = _k + _offset; + const F *__restrict__ const v = _v + _offset; + const F *__restrict__ const gy = _gy + _offset; + + F *__restrict__ const gk = _gk + _offset; + F *__restrict__ const gv = _gv + _offset; + + F y[IDEAL_T_LEN], z[IDEAL_T_LEN], zexp[IDEAL_T_LEN]; + // for saving smem, del Sc, Sd, Sdcdw, Sdddw, So1 + __shared__ F Sa[CHANNEL_SPLIT][TOKEN_SPLIT], Sb[CHANNEL_SPLIT][TOKEN_SPLIT]; + __shared__ F Sdadw[CHANNEL_SPLIT][TOKEN_SPLIT], Sdbdw[CHANNEL_SPLIT][TOKEN_SPLIT]; + __shared__ F So2[CHANNEL_SPLIT][TOKEN_SPLIT]; + F a = 0, b = 0, c = 0, d = 0; + F dadw = 0, dbdw = 0, dcdw = 0, dddw = 0; + F o1 = MIN_VALUE, o2 = MIN_VALUE; + for (int i = _t; i < (_t + _tokenLength); i++){ + const int ii = i * C; + F no = max(o1, k[ii] - w * (i - _t)); + F e1 = exp(o1 - no); + F e3 = exp(k[ii] - w * (i - _t) - no); + dcdw = dcdw * e1 - (i - _t) * e3 * v[ii]; + dddw = dddw * e1 - (i - _t) * e3; + c = e1 * c + e3 * v[ii]; + d = e1 * d + e3; + o1 = no; + const int ni = 2 * _t + _tokenLength - 1 - i; + const int nini = ni * C; + const int exp_w = _t + _tokenLength - ni; + no = max(o2, k[nini] - w * exp_w); + F e2 = exp(o2 - no); + e3 = exp(k[nini] - w * exp_w - no); + dadw = dadw * e2 - exp_w * e3 * v[nini]; + dbdw = dbdw * e2 - exp_w * e3; + a = e2 * a + e3 * v[nini]; + b = e2 * b + e3; + o2 = no; + } + __syncthreads(); + So2[channel_id][token_id] = o2; + Sa[channel_id][token_id] = a; + Sb[channel_id][token_id] = b; + Sdadw[channel_id][token_id] = dadw; + Sdbdw[channel_id][token_id] = dbdw; + __syncthreads(); + a = 0; + b = 0; + dadw = 0; + dbdw = 0; + o2 = MIN_VALUE; + for (int i = 0; i < token_id; i++){ + const int exp_w = (token_id - i - 1) * _T; + F no = max(So2[channel_id][i] - w * exp_w, o2); + a = a * exp(o2 - no) + Sa[channel_id][i] * exp(So2[channel_id][i] - w * exp_w - no); + b = b * exp(o2 - no) + Sb[channel_id][i] * exp(So2[channel_id][i] - w * exp_w - no); + dadw = dadw * exp(o2 - no) + (Sdadw[channel_id][i] - exp_w * Sa[channel_id][i]) + * exp(So2[channel_id][i] - w * exp_w - no); + dbdw = dbdw * exp(o2 - no) + (Sdbdw[channel_id][i] - exp_w * Sb[channel_id][i]) + * exp(So2[channel_id][i] - w * exp_w - no); + o2 = no; + } + __syncthreads(); + So2[channel_id][token_id] = o1; + Sa[channel_id][token_id] = c; + Sb[channel_id][token_id] = d; + Sdadw[channel_id][token_id] = dcdw; + Sdbdw[channel_id][token_id] = dddw; + __syncthreads(); + c = 0; + d = 0; + dcdw = 0; + dddw = 0; + o1 = MIN_VALUE; + for (int i = token_id; i < TOKEN_SPLIT; i++){ + const int exp_w = (i - token_id) * _T; + F no = max(So2[channel_id][i] - w * exp_w, o1); + c = c * exp(o1 - no) + Sa[channel_id][i] * exp(So2[channel_id][i] - w * exp_w - no); + d = d * exp(o1 - no) + Sb[channel_id][i] * exp(So2[channel_id][i] - w * exp_w - no); + dcdw = dcdw * exp(o1 - no) + (Sdadw[channel_id][i] - exp_w * Sa[channel_id][i]) + * exp(So2[channel_id][i] - w * exp_w - no); + dddw = dddw * exp(o1 - no) + (Sdbdw[channel_id][i] - exp_w * Sb[channel_id][i]) + * exp(So2[channel_id][i] - w * exp_w - no); + o1 = no; + } + c -= exp(k[_t * C] - o1) * v[_t * C]; + d -= exp(k[_t * C] - o1); + + F gw = 0, gu = 0; + F gc = 0, gd = 0, ga = 0, gb = 0; + F go1 = MIN_VALUE, go2 = MIN_VALUE; + for (int i = _t; i < (_t + _tokenLength); i++) { + const int ii = i * C; + F no = max(o1, u + k[ii]); + no = max(no, o2); + F e1 = exp(o1 - no); + F e2 = exp(o2 - no); + F e3 = exp(u + k[ii] - no); + F num = (c * e1 + a * e2 + e3 * v[ii]); + F iden = 1 / (d * e1 + b * e2 + e3 + EPS); + y[i - _t] = num * iden; + z[i - _t] = iden; + zexp[i - _t] = -no; + gw += gy[ii] * (dadw - dbdw * y[i - _t]) * iden * e2; + gw += gy[ii] * (dcdw - dddw * y[i - _t]) * iden * e1; + gu += gy[ii] * (v[ii] - y[i - _t]) * e3 * iden; + gk[ii] = gy[ii] * iden * (v[ii] - y[i - _t]) * e3; + gv[ii] = gy[ii] * iden * e3; + // cal gc & gd for gk & gv + F gno = max(- w + go1, -no); + e1 = exp(- w + go1 - gno); + e3 = gy[ii] * iden * exp(- no - gno); + gc = e1 * gc + e3 * y[i - _t]; + gd = e1 * gd + e3; + go1 = gno; + + // update a, b, c, d + const int ii2 = ((i + 1) % T) * C; + no = max(o2 - w, k[ii]); + e2 = exp(o2 - w - no); + e3 = exp(k[ii] - no); + dadw = e2 * (dadw - a); + dbdw = e2 * (dbdw - b); + a = e2 * a + e3 * v[ii]; + b = e2 * b + e3; + o2 = no; + no = max(o1 + w, k[ii2] + w); + e1 = exp(o1 + w - no); + e3 = exp(k[ii2] + w - no); + dcdw = e1 * (c + dcdw) - e3 * v[ii2]; + dddw = e1 * (d + dddw) - e3; + c = e1 * c - e3 * v[ii2]; + d = e1 * d - e3; + o1 = no; + } + __syncthreads(); + Sdadw[channel_id][token_id] = gw; + Sdbdw[channel_id][token_id] = gu; + __syncthreads(); + if(token_id == 0){ + const int _offsetBC = _b * C + _c; + for(int i = 0; i < TOKEN_SPLIT; i++){ + _gw[_offsetBC] += Sdadw[channel_id][i]; + _gu[_offsetBC] += Sdbdw[channel_id][i]; + } + } + __syncthreads(); + for (int i = _t + _tokenLength - 1; i >=_t ; i--) { + const int ii = i * C; + F gno = max(-w + go2, zexp[i - _t]); + F e2 = exp(-w + go2 - gno); + F e3 = gy[ii] * z[i - _t] * exp(zexp[i - _t] - gno); + ga = e2 * ga + e3 * y[i - _t]; + gb = e2 * gb + e3; + go2 = gno; + } + __syncthreads(); + Sa[channel_id][token_id] = gc; + Sb[channel_id][token_id] = gd; + So2[channel_id][token_id] = go1; + __syncthreads(); + gc = 0; + gd = 0; + go1 = MIN_VALUE; + for (int i = 0; i < token_id; i++){ + const int exp_w = (token_id - i - 1) * _T; + F gno = max(So2[channel_id][i] - w * exp_w, go1); + gc = gc * exp(go1 - gno) + Sa[channel_id][i] * exp(So2[channel_id][i] - w * exp_w - gno); + gd = gd * exp(go1 - gno) + Sb[channel_id][i] * exp(So2[channel_id][i] - w * exp_w - gno); + go1 = gno; + } + + __syncthreads(); + Sa[channel_id][token_id] = ga; + Sb[channel_id][token_id] = gb; + So2[channel_id][token_id] = go2; + __syncthreads(); + ga = 0; + gb = 0; + go2 = MIN_VALUE; + for (int i = token_id + 1; i < TOKEN_SPLIT; i++){ + const int exp_w = (i - token_id - 1) * _T; + F gno = max(So2[channel_id][i] - w * exp_w, go2); + ga = ga * exp(go2 - gno) + Sa[channel_id][i] * exp(So2[channel_id][i] - w * exp_w - gno); + gb = gb * exp(go2 - gno) + Sb[channel_id][i] * exp(So2[channel_id][i] - w * exp_w - gno); + go2 = gno; + } + + for (int i = _t; i < (_t + _tokenLength); i++) { + const int ii = i * C; + const int ni = 2 * _t + _tokenLength - 1 - i; + const int nini = ni * C; + gk[ii] += exp(k[ii] + go1) * (gd * v[ii] - gc); + gk[nini] += exp(k[nini] + go2) * (gb * v[nini] - ga); + gv[ii] += exp(k[ii] + go1) * gd; + gv[nini] += exp(k[nini] + go2) * gb; + F gno = max(-w + go1, zexp[i - _t]); + F e1 = exp(-w + go1 - gno); + F e3 = gy[ii] * z[i - _t] * exp(zexp[i - _t] - gno); + gc = e1 * gc + e3 * y[i - _t]; + gd = e1 * gd + e3; + go1 = gno; + gno = max(-w + go2, zexp[ni - _t]); + F e2 = exp(-w + go2 - gno); + e3 = gy[nini] * z[ni - _t] * exp(zexp[ni - _t] - gno); + ga = e2 * ga + e3 * y[ni - _t]; + gb = e2 * gb + e3; + go2 = gno; + } +} + +void cuda_forward(int B, int T, int C, float *w, float *u, float *k, float *v, float *y) { + // 1024 threads per plock + dim3 threadsPerBlock(min(CHANNEL_SPLIT, C), TOKEN_SPLIT); // requires --maxrregcount 60 for optimal performance + assert(B * C % threadsPerBlock.x == 0); + dim3 numBlocks(B * C / threadsPerBlock.x); + kernel_forward<<>>(B, T, C, w, u, k, v, y); +} + +void cuda_backward(int B, int T, int C, float *w, float *u, float *k, float *v, float *gy, float *gw, float *gu, float *gk, float *gv) { + dim3 threadsPerBlock(min(CHANNEL_SPLIT, C), TOKEN_SPLIT); // requires --maxrregcount 60 for optimal performance + assert(B * C % threadsPerBlock.x == 0); + dim3 numBlocks(B * C / threadsPerBlock.x); + kernel_backward<<>>(B, T, C, w, u, k, v, gy, gw, gu, gk, gv); +} diff --git a/networks/cuda/wkv_op.cpp b/networks/cuda/wkv_op.cpp new file mode 100755 index 0000000..efe56d8 --- /dev/null +++ b/networks/cuda/wkv_op.cpp @@ -0,0 +1,21 @@ +#include + +void cuda_forward(int B, int T, int C, float *w, float *u, float *k, float *v, float *y); +void cuda_backward(int B, int T, int C, float *w, float *u, float *k, float *v, float *gy, float *gw, float *gu, float *gk, float *gv); + +void forward(int64_t B, int64_t T, int64_t C, torch::Tensor &w, torch::Tensor &u, torch::Tensor &k, torch::Tensor &v, torch::Tensor &y) { + cuda_forward(B, T, C, w.data_ptr(), u.data_ptr(), k.data_ptr(), v.data_ptr(), y.data_ptr()); +} +void backward(int64_t B, int64_t T, int64_t C, torch::Tensor &w, torch::Tensor &u, torch::Tensor &k, torch::Tensor &v, torch::Tensor &gy, torch::Tensor &gw, torch::Tensor &gu, torch::Tensor &gk, torch::Tensor &gv) { + cuda_backward(B, T, C, w.data_ptr(), u.data_ptr(), k.data_ptr(), v.data_ptr(), gy.data_ptr(), gw.data_ptr(), gu.data_ptr(), gk.data_ptr(), gv.data_ptr()); +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("forward", &forward, "wkv forward"); + m.def("backward", &backward, "wkv backward"); +} + +TORCH_LIBRARY(wkv, m) { + m.def("forward", forward); + m.def("backward", backward); +} diff --git a/scripts/check.sh b/scripts/check.sh new file mode 100644 index 0000000..9f1c1a9 --- /dev/null +++ b/scripts/check.sh @@ -0,0 +1 @@ +python /home/jmwang/semantic-kitti-api/validate_submission.py --task completion /home/jmwang/OccRWKV/predictions/submission.zip /home/jmwang/datasets/semantic_kitti/dataset \ No newline at end of file diff --git a/scripts/run_test.sh b/scripts/run_test.sh new file mode 100644 index 0000000..a664f2d --- /dev/null +++ b/scripts/run_test.sh @@ -0,0 +1 @@ +python test.py --weights /home/jmwang/OccRWKV/outputs/OccRWKV_SemanticKITTI_0828_074122/chkpt/best-metric/weights_epoch_043.pth --dset_root /home/jmwang/datasets/semantic_kitti/dataset/sequences --out_path /home/jmwang/OccRWKV/predictions diff --git a/scripts/run_train.sh b/scripts/run_train.sh index 4f6426c..3da68ad 100644 --- a/scripts/run_train.sh +++ b/scripts/run_train.sh @@ -1 +1 @@ - python train.py --cfg /home/jmwang/OccRWKV/cfgs/2024.6.11.yaml --dset_root /home/jmwang/datasets/semantic_kitti/dataset/sequences + python train.py --cfg /home/jmwang/OccRWKV/cfgs/2024.8.29.yaml --dset_root /home/jmwang/datasets/semantic_kitti/dataset/sequences diff --git a/scripts/run_val.sh b/scripts/run_val.sh new file mode 100644 index 0000000..d9f6c40 --- /dev/null +++ b/scripts/run_val.sh @@ -0,0 +1 @@ +python /home/jmwang/OccRWKV/validate.py --weights /home/jmwang/OccRWKV/outputs/OccRWKV_SemanticKITTI_0828_074122/chkpt/best-metric/weights_epoch_060.pth --dset_root /home/jmwang/datasets/semantic_kitti/dataset/sequences \ No newline at end of file diff --git a/utils/__pycache__/checkpoint.cpython-310.pyc b/utils/__pycache__/checkpoint.cpython-310.pyc index 8c3b6bf16ccca5aee306e91db5c757163472c2c6..f1a9de740fce24b953059169e487cfecb04d9f17 100644 GIT binary patch delta 22 ccmX>pbW(^npO=@50SKft&ZK2-pbW(^npO=@50SKZerl)0X4F#_fJj^3il417`9MQATPC~JijQrST8?0xu`t5 JY%&X@DFC`P5f=ae delta 47 zcmX@ew3CT9pO=@50SKZerl+M&Coi?6JijQrSTChGc`^^9DF7$B B4PXEO diff --git a/utils/__pycache__/optimizer.cpython-310.pyc b/utils/__pycache__/optimizer.cpython-310.pyc index 66721224a351080b21d2fe9aad75f530fc75b93c..8cda097d455c13a2e2ebd3eec57bc5264dcf033a 100644 GIT binary patch delta 682 zcmZ9Iy>HV%7{+}ccI?D08DK?HS7a=5v z44oJ{ApZmHfCM`$13QTQMg|x_;txPbyyuV*^(;TS_kDhTpF6*?epyc9I7JPXKt~53 zUOU0$EflGCs4-O67_Mt$EYQ3rb?pd^$*exXj57TckLg0kU(A%Kz0Br%#f6(I$N3%UOBVTRqy-fPip;a zwC%mEhuaSxcQ&`08&5cbyGXjh{Mf^sK#@H{%M)ES_PW=bt3STpS!sBL>tJz1!S?g8 z-|dMG3j_a!B)i+2|7>Z-NgO^CqnMijdZyGntRD!eA4G{SBh6R$nM^z(FEM{L&4OOyl=>Fvua9ngQe@t9K)G&Y75T)iHP|Ryt3xB*; z0q7l)dN@de9#_>=s<4yh-N=vknSbv>cALDMI9zr%VW>Yt2nPe3uZI7OM zD1n~s!9)KH@9N$88wx#o5I=siOIC1(dAv98&0_}U-1+1*vu3k_V7xerU!EN`ql04* zCE7&-FCjse&>op73W^F=>o-U(*IY30nBcVr?|RJz2alITSg(k>Ti{(%;*azMl~l|L z`;hKgrL}{=cfjv%Lx9zRl9gOE4%sx~U#NQc+QD=8>E`ywaRVCy$?yq|^v|HZiSs$=qbr47MTthpRz(LfnBg^Q-xzCR4pz6e(@SNnAT z*2G<$MJCeZR0rS+tegPY$Y;s{q8imr*0Wqg=c~u8bLYSM`TjT=iXe{$QjCXET{0Cd s0FTQ*Og;hy;_$}`Y4@ULrj54dldnWj|?iQEEkkxku&zX7O@NdN!< diff --git a/utils/__pycache__/ssc_loss.cpython-310.pyc b/utils/__pycache__/ssc_loss.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e070a03cc512192b2b8c5ac25b25b418a7b4f598 GIT binary patch literal 2211 zcmZWqUuzpj5Wn5~qr0;t%YV{1aSKN%r6R_Lz7#@eOP!R`7#A97K*b2@`l4Cb3c6MfVclI>%o0+233N4Jozn>2~ zEzA0ogv*0N_ySXX1Cm$?8(Y14Wf|{rR`Yhoxy7u^>ABGG#Qxdpd5M#_SR09#G_d-K zp9ENgggvyn;Tzbq4r8*v94;7|`v1h!pDt^P7*`@V_m3(ej_8F@f>FBLj z_ROgq*mR!X4kL^Q2T_v#B;-isv51O1%8UHDkcEok@jxk|%qhw6g0|H(>oOfkkxb(< zE%L6T*;m>-8TVmByJaEcV{MO%QP>rZ+O+$x34UR!bn#{cPneXSi88_Uuh-Ze!1@yvinS+nP1Ny9?fjU3Uv%s?XKZVlODIxl=VtckWe< zGj_&}>>Jr9`8xguP6S31)Eb1-J$yqUdd;+*q?rlVn{z0-4e4Of4OI-wZ17CG!*MYv zB|%AVK2Gz2JnhE_x9W?$l*Qy!l4o^L51ru+L<8YI;1Gn5$&yrY*L=~C$e~~ zdTj*tlB}IO+vWearTlkJ)5cS(7)xz0W$l!}xiJ^MCod`~U~9AYv1W3q-tt!>%GCLHc06 zjkom#7(jI0(oXEkeaY`xN53Ji3*=PZkW#Htv6){rO#H!{H7K3ATOmHP5I4grsJs*V zu4S}{BlDm){EIr5LnPTWr^iJf@^mcva8mYVpsR;^i_E34*G2FBVz zhF`VK8*M_yY_)1ttFTkwT3Y{gaSJkTb&D)?v9|$x8)OeT(!$DryaCSMwt!pjP__{ixk`i(EZ0Crm>7S3 z-VF?Vaxm9|k_z+7&V za>XcX!1q=N@aZs()AF>}F*}zu?_`1!+1pIG4QuKUq-{IQK?a1Z%{s0_S;0bNLmTh7E4(qQ7Lj&$ek#gsxo>kq949|xD)N~L|^Yj zk9MMaM|JlhOrWfYLEg2s_f({#<5D-%QC>*VPxC}Plk~~hNr9*5WuG^UmRZE_Wby!@6?% zrVH-ox%?Pfy|EN`htu4+U@+F 0: + count += 1.0 + nominator = torch.sum(p * completion_target) + loss_class = 0 + if torch.sum(p) > 0: + precision = nominator / (torch.sum(p)) + loss_precision = F.binary_cross_entropy( + precision, torch.ones_like(precision) + ) + loss_class += loss_precision + if torch.sum(completion_target) > 0: + recall = nominator / (torch.sum(completion_target)) + loss_recall = F.binary_cross_entropy(recall, torch.ones_like(recall)) + loss_class += loss_recall + if torch.sum(1 - completion_target) > 0: + specificity = torch.sum((1 - p) * (1 - completion_target)) / ( + torch.sum(1 - completion_target) + ) + loss_specificity = F.binary_cross_entropy( + specificity, torch.ones_like(specificity) + ) + loss_class += loss_specificity + loss += loss_class + return loss / count + + +def CE_ssc_loss(pred, target, class_weights): + """ + :param: prediction: the predicted tensor, must be [BS, C, H, W, D] + """ + criterion = nn.CrossEntropyLoss( + weight=class_weights, ignore_index=255, reduction="mean" + ) + loss = criterion(pred, target.long()) + + return loss \ No newline at end of file