From 9bae62f69d957ae30fdd0d01fe56620d90835d21 Mon Sep 17 00:00:00 2001 From: Phil Wang Date: Wed, 23 Mar 2022 10:31:38 -0700 Subject: [PATCH] add proposed parallel vit from facebook ai for exploration purposes --- README.md | 41 +++++++++++ images/parallel-vit.png | Bin 0 -> 14658 bytes setup.py | 2 +- vit_pytorch/parallel_vit.py | 137 ++++++++++++++++++++++++++++++++++++ 4 files changed, 179 insertions(+), 1 deletion(-) create mode 100644 images/parallel-vit.png create mode 100644 vit_pytorch/parallel_vit.py diff --git a/README.md b/README.md index c589ea5..9028d5f 100644 --- a/README.md +++ b/README.md @@ -27,6 +27,7 @@ - [Adaptive Token Sampling](#adaptive-token-sampling) - [Patch Merger](#patch-merger) - [Vision Transformer for Small Datasets](#vision-transformer-for-small-datasets) +- [Parallel ViT](#parallelvit) - [Dino](#dino) - [Accessing Attention](#accessing-attention) - [Research Ideas](#research-ideas) @@ -240,6 +241,7 @@ preds = v(img) # (1, 1000) ``` ## CCT + CCT proposes compact transformers @@ -866,6 +868,37 @@ img = torch.randn(4, 3, 256, 256) tokens = spt(img) # (4, 256, 1024) ``` +## Parallel ViT + + + +This paper propose parallelizing multiple attention and feedforward blocks per layer (2 blocks), claiming that it is easier to train without loss of performance. + +You can try this variant as follows + +```python +import torch +from vit_pytorch.parallel_vit import ViT + +v = ViT( + image_size = 256, + patch_size = 16, + num_classes = 1000, + dim = 1024, + depth = 12, + heads = 8, + mlp_dim = 2048, + num_parallel_branches = 2, # in paper, they claimed 2 was optimal + dropout = 0.1, + emb_dropout = 0.1 +) + +img = torch.randn(4, 3, 256, 256) + +preds = v(img) # (4, 1000) +``` + + ## Dino @@ -1396,6 +1429,14 @@ Coming from computer vision and new to transformers? Here are some resources tha } ``` +```bibtex +@inproceedings{Touvron2022ThreeTE, + title = {Three things everyone should know about Vision Transformers}, + author = {Hugo Touvron and Matthieu Cord and Alaaeldin El-Nouby and Jakob Verbeek and Herv'e J'egou}, + year = {2022} +} +``` + ```bibtex @misc{vaswani2017attention, title = {Attention Is All You Need}, diff --git a/images/parallel-vit.png b/images/parallel-vit.png new file mode 100644 index 0000000000000000000000000000000000000000..4a84741cd69e3426470ff77ec01dee317cd08503 GIT binary patch literal 14658 zcmch;bxUe?eWN=-!;4Vf4j3JMBMUQS8_3hGrBcwLGJ5B{z9d@KP6 zSXT*oEky9+i)a}M-V?b=>$+(=zIXF9bFqT5c5t+}dhKfAVrAvvYUAj34$~t7PI?VF zNz%p2%+1!(fl|xX-U>>`%94_cn^M`#g_4Vdi;I$zLx77{fQO4xT$xf$N~_!F<2n=+ zC6v6BxRzJW>6)h&t_}(Og*AyZwK7(HNFgOPzM6&PV2xDBVjM61HQP`cVO4#)u9}pE zBx+ulG;UQrlm~0Qh4lAVjM5(a4=^Oi68}N0VjO)}I_l=Dq^Es?eyEA&1fVaBMc6e1)Rc>U_^TOU&h#~N3{lRc3 z;O&9KYFki))%ekFJx1}(?ypK6^%}#sMiCNuL^#UAiJw0Q%fu0zoh~XyX!urNQOq4_oKw36(mv{Yy(k;NHc- z#h)L}MF)q53a!S|2>G2T?G|d#qNECs6Xl0LCnsZ(lE(J-ib+XJ7ud`ca`5rVy0~zG zDT4O#@o_tthW5L^ni+4ex!?Y~JD$O1h3s*%7k;CFc!n{(#>|i+{n0sqUd*V2-uAWFSZ7p@2_m#&Q^(xIz7UYlW}5WV=o@>TysSH zh|SE*2>BdH1w7AUnKZxluC5v`n9a$0oGypu`QIb)+WqW3S!z`8@%J0qB@!Ll-rlbB zIMsJMo=1CEq1E@9PVwsc70TFpJelBhtD_0OprFNIB$kt>XGCV^Qs*fX6H{b#bbmA+ zi&}vMTr3ec{6GZ8e%UR76ejFMCO5*|SLJ~+^&*ad=SNlqSZr+Uiv^Q_f=`!ww|Dz_ ze~xQ={QnY(=MkjyQFpaEZAr9sb{2uv)#mr`X79(B81N+v%)dw1H8x6uJD%?JbgVLH zMNelp30EkQ0Mq7tdtqi2p+a$Sc`5nw{8ayblKA6t6A~K9n~5uwV(^U!xhxUcjK3j* zq1LL_#{*;aQ+%*kF__ZyXPG)7j}7L0jZw^6r{_-RdG~mt<*UGg0xB?jad<4+lw4fc zf}Uqm$@I#J85w0oBZrdU&n`Cm7VFGlRC0x)W{PAhg^#uBOvO=2`7y132aP~@S75|4bX@)rTT82s1$Q8)}7PMczcL`3D)j`sF(bc$)FhqEOY*VjW}T18az zi|3nc=jzM`;B)w$Uy+lO6Q5ex*hC)AmF^^nK8qE}#5tWV^ClG)vyXz|#3Uz=*DO^j z5D$R|=UH1v)YWmMuo)4*HfWL3(D>XYuF>-N_~`WaxBSUUD+bss#OKx)79n6AQM0k3 zONOIjaY~G#<0#J>knlN3nVM39b>(^m4F~mFzo~a9j%23BNN}~=&!8Al9_QuxfgPV! zH}L7O`El&FFVsml>sf`>fDw{$@Swr~ycMEH0(SN5K zQ%+_J#E)-%(9`S%!Uac(43yfif(A+*s2XJ-q#n_7gq+dXUG8&u>cDFD^6>!^NBqgo z`u9juo6pS|frtq=kta2V=}WiYVyh#?^(ltr&oxWlU8c5xdyaS`KloprZ?wF;4{6y3 zyR~1`sIT?jOGruxyfapF0R!*ugwd$W&m@Y|TBg*hrU z2I=5ks~>p)4ptIh;?_lg3(@e8b)OecLmrQpX2iyQ)`1t5LCT%2yufS*3>s$o?DjY& z49t34xV;}aA-jSVZta7K=qAW4iwBlYX=k=_qF@X7|e^ zXSbthzsh$?Y2{C2U5)&1C&;-07af(`HJgKB2qV7q7RYCLtHFWo+?))iB!zhaxs9)8 zW@g%a@3=ZVPP6t~^qXv;Qx(EO5HRehxW+hxy1m84#U&!p$+*_ia0R3oE%Gso@$OL& zjLW$``rwHuFgku@lr!?(&_XA;Ph(?2MGW}Yt+DOB)#+%5JjdRH54&q&rg*Lp{BCz% zqQ@*G>`b`=*&crlGsW?^yPNU0T^~_{+tX}4Zo!UP1?Mbc5_fqq8=3K`sE$uM)oL;{ z6t-`0koJLK3#PohJh-{J*$He-($n-*kOmM>S6UYo43_@f+z{Mf9VRz^`P?7oUoru= z;+Ty*naxL#)p~&)4!ePyax|t$j8ylLmbLr3CC?aH+&%{748@`QdJv(^Ivs zzJ7E@l2rQ^qb6%~#`rH)~rg8lgdVz8O#@R|xT-~RE5^R&$HMU3HnzGGD zKNGXXO>&83X8Ta-*17dx%7T2#PTX%WadR6(No*zdQf|WV?3s~&^_E6(-FQgV@B12zExLq=8>GQ07&X$7C7$D54+iFTJ= zmo^=53WBCeCY}%LhBoWp!lvyO1*qti{o0l6*m#V6rM*^^qmDZf>2n*yGE~w1OPu)f zLrl+?^>1w#Df@;?>L_akH2AN$_J;%R{_ci>?EAqmU@V2Hf4V@@^^rfXQh1hBmgnCW zN-^^l#7VrB&7R{!1d?2{hgthN+a<*LX(u_VL#X+%;)3A@G+iNU5$z{espSi-XMVANMWh zp2vusQp+J1eE&JQ*!+RO<8(Fr)9JWaMUCNmlp6&d&QV7k@(tu&;Cs(QhxDHdJ)qM5bkZ>oO5WJ3{u zy)Drmt8PN#j6>x5HzMl>BDn=E4GnBI!!{^SPtWFmDd{>En~Ta`DU&`Mv$NPLDk}W$ zM=2h!aDyo2OlMrUT27Y3fH*LJ$oav@q^ZLuednF({3l_f{-3i)QL=_^Ko2HEf@9E2 z4&36!4Pl$nYV>5Orc{7&Qc_ZxUgP^W%_(i)%o&G1`X!#W>P}|yg1BB@eZdHX-X08d zyRd*~dEgog-}YH{k@_bZ!f_S;qEb(!$qm@Qg>_80F#*96&V2uN;>!8So6tVL5jB@$ zz*t2`>JMWT8TidcQ}?jZOt|@;Zj1!f2G|Y`4omZMSK~uNL+~;(GHve1X&$d0vE^db z0u~XqmUkmEhZ{|Xq)$Y~mCVhl$who4fpn^@EsGrpyC`bQO7!d3DpA1e>PmOhrbQ>4 zZoV>ueI&cw8`q7>K7Qp~tnKMVz&qVW)~>9M+to;Me6u#h5D}mN8{G0u=2(Pgy1kF) z->lTC^?7sNy2=Pb z`%}dpq4hC7gb<28Fmh$T`$N@Uf;X|@Nl4q^XzD=!T-pn?eWU1g!C*2F?2azl30z#@tS#1zex%&w4}%PoZ**+Tr;Oo};zjuW+0VRE$OFE;MsJ zfAUeZ+oC%wO_j2T;;4#q9+%nFNKkQF*2EBf-i~lyTQy~ZOLW6M_8ghB|3KRUI_H8MEN zUPTkI*Eo~5L$o?jBm(<_6}i-I8ahn|v~aX=>QkPa@H!fr=+KP6$ddi_6NYfo@^99m zk5;Cn^Q@%jx_vv;-L9(QS+im>9!M1@1f?TsN|7b#DlYiWE?FG)$27s~ zxnzN#zk8ll>)DT+u`lzY80s6x^DI~w9ZI!d!xFQB7&j& zILX$dFiJ2qjD!Z024){h`Wb)q;e-SjC6v|wQscgKkV|(XW?9G!G1!0WG+}Ok1=OiI z(4uwC&F>yaH+VBsLP06^zWr&pS~IePEMbA*oaU_vc~S)WX3Z)#`i6G6JN`yEC@^%J zY_RsG@?J*Yh+{&uG^k?erabe#pYD0Vh&0yTBH`tpe+gzgeWVJ4gOXtn{28mAi?FoPdtgKTl7f-q@6}bPM8TQzA7eYaimVQHXTc*} z2fE_z?T~airI^727z|2wL-K#jR{mRwhHSDD60bl;a>46-bGleBS#-|m*C(Y317yB# zy*V7vAC_Et;V`j?Gs#aszPTG9utS2P}n=2YnbNyEC<8pq)K5eXlJL$j> zL!amYo`B<3;OKxo5zAO~ueCAbqaE?r^YkV>VI%GCB}QqiDd$C>dSQ=y)#1>sxVrP6 z>RV4cm@G;(w#OVV%z zqs)}5&`fdPe%Am!D|*Wnvjg^8oBvaXw#u)4Z!HxvSeL@A{Gc~Y*_J;Ceb;E^MP$!5+{b&&^md8 zy*7W`3^_IOU3#DBW} zPi674%fML#vT+}}nRHxfOue77&y2M3&g;!Em#=^({ z`p;uyeGk-uwx8^V2M1vdT0W|q$lJy53G16_ZQvoYU1G0%?A$Jkfxfs%j^^NL4Ue15 zHkoIiG!0DX{*AI+v^%ov)Ep_>PDB-XUyV#&^MX*T8Dp{6`+VsB8TI1e#5+{gDEjPD zr|4Y+lw7Q-9+xMQqDsx<)y>T?AP$X-S)J_JpmHg= z&-4B#u^9S;Enf7lSHWgCmIMl(D=-9aFGk0LH6$|(2*TpT?>S!@HW_3^yUV9hX<#Y` zxPHJTKNAFPe95(Vsv%HI%L@r|{K`e2Ik+`=&!J@S6Yg@JSMuIq6#KBp72DZGhT#ga zHX5%pK%94=YQW!TzZAfc4_7>=r>C6nCkR1_q9oJ5>vX?F(HQ`*f&LQA+>=G+bVHaj zNpBk+A1?#48CTra?)cR|Uc31!Bu!1tIP-}yfU8{FOJT!X-2eD>%6PqKd@W-N7$xc+?Pr>Xc23=%|T}j^6H?au0Roxx~Ura5&I46@-s_8@epq zk*|avkXk{WBV}WGk{#nad+~b<=UYpIcv5L6Chq$%Dyl62l}zt@E=mpx?-4${$cO52 z8Z0g?_5EleoDdyQCDGB*;T>JhX3y^OskeA9nBsI6)VtT(JnmiGKK|#&LB)>D~ zX+pp>c9jt+4o^~lBKDZgb!QFAgtqqft&t>}rJtY0sQVa@)ZfRG3s=lM8a-)3kqi5o zKHsR_%Cv0`t9aP;q==UEhwMCH-rY_{J{&W6-O{4XzmdjlwWoQmKThbuAocZpj9@={ zbyq`q;%60Y-r@l`N{BtD>qt`YWq^llXAiQ)URu~i!_3*!V0BfMse?m25a5;73`ug7 zGn-^ZTFlOXAhdOK*s%<i?yD4kiyH;-TUWLDlC*5|pbQ+yF%m5d_FuFrpmzIU~+ z-sSy|=l&1d%*EzLqcz6Se062m2xBQAfr6Q&E-=#f)6Ip*9iElJDH0o7`)}qCztkoU zY_16rGW|w|!~Khqp#K)U`0XO*IgD?TFyQvx!EgpQ8h50F$b%U|2=ELmvIpRJIbr?!(9#Y=c z#ndl>EWpJvk@1*m0fpT=tWQ-aQS0=uHv>saZ(d$W2^Y%7>dWM6b>%8Yh zMoLO*J{2`4zpA=q*`&YGdMZ3FZgM|cQI?1i7nxPR39-xPMn*zH;#$y>il3hlP!Ga4 zA2aZa#se(j6!SlSaOY3u8cz4VU0Gf(1PuRRqbDGDzT2qHiHa6m{wo3W%s!|2Afk}} zBOwYt>xerE;3Bu6K;X2UrF|_U^A@xyu0A(6^8%rmz-H93xiz3#<25@1SQw}4o)RoP zynkQmA+kskWXz1@x@1Vwn>9dMk z1w2k7Dg1At0c;%vjKsAG2m+t2ot+Tyi+rU_u5DO3VgqnrV`;BZJUwqdvR`>AV#2cO zHNcVydPvF1$%%9!G6ryQaY@+NFt#{u#QQ(p+djfbmuQ&%+5UboSBl{Q7B>L5Ni-!V zDJ(kZfMaKHTEKU>ABO|wV%0eXf)%IDG$ksT0ItvViFhjWI}yp^6g=?d7U-2S0s$x_ zs;Q~b?Qmn#&=hSim4{c&73#|u4}oxG$gNaLyZ2aXl;eWEe(^t!1paIDsl$OOjlRebUI)BV7*Il zgWH4>2M34MuniM1tni*5QC~kl=jX>yU^k5HxRYzln~ zG)fde_4d6)G5rFpZVQ0&A$JdY4{8-!OrVm41Gv8*l+|hg&lp80+NH4SVb7JS4uhcr zY>qBRz#aMN{@=i6UkIn~trJL5g>p$J4$KO7@;gIuGi}Zmt33e}pr6C2$d! zgV(1|Ti{xnj=#KBuPGJL;2&2)b6lx_CX@HgS)6nX{-FSFwV?Ahtgo-H8`+Sx zD`*Sw*i4Hpm_=>_bU5dZSWKhD^Di(W0-?XZ-;FHX=I@+~S{_$E`;C-LtoH zu8>|W&MoMM9SMX^6?*;I_Vzo0IlQc7dNC_3N2sc$0AznKr41D^_0xB(|1sD&|I=~b2HS5}iy7~~ zkA4|Q9~lT+zAUgIXQy>60RAbq{aeWiu*A-TH1V22s7$8mDNaM zeFwcRNHq9asiV43a4_My;TI9`@}$H5a0>znr1!KpnO%CU7AXg-#1nNDSYFOVF60H8 zpi)VBI1{a)M_1>#A>!rjeFWwdlniX<=H_5%v7&$~jRGBMIb712pHzSVm1sl_Cr}&< zr1*}GjVUZpVXBL?f_!Kcaq(_$zkU2ZIO;3<($bRI!E}LcooP^wr&ui|QkTE>KW}2@o#;cBcnLU$YkuKaYJ(no8Q^uS)w-Pb_!PaA2C#(MgnhtyZAoo`TuH zI1H2X))@tkc|q)xtDZMWg>@UQGe+7kmOg<-$Cwm^$i;RP*_xhSzWBJUxqc%26FKU8 zC{gUlE4{y*N)B_c;*Pwx(?g<&TItC8+URp_mJ7M}^5o$XgK@Y!3r=%z(#47Cf`R@o*72dhsn(mmDx(O_w?H0gQZu2oM^g4~2$QQZH?fR@T`LN5AzOC#WRK&kK`RZI?cXHv=6uZ5sV#=*-Qrj#bX zmnsXoBfI12I+~Km!zxCJ9KLkV-Nzv!be3Yv_EjeSu)IDh!K0zG+b-@L(Z!|8y2 zfbS@jMjAjh(2RE+qz}*@xOKjFE}-k1$)1P+6hc%~)RN8wO{%;->PtJN_p=Dj zaMqLrS#$Gl-Na&2A}&veE!&UeRM3G1&H2f@cw(UIdKgc!lip|8U`HvHXly;+UD_-) z(BVdtqW`m@eEk|t!AhH1BAQ6rPJGP1Qnd3o+DM)*}2rIW5N>jKQlic1GaZ%wAMbe}vbn+5$* zgd#twPD~Vg99T>GntlcPIw3%(SY1lX)kk*G4`}KOed~dv-wLhGrX# z`sQ@sd?zxC11dinWq2`9Cx7!mneeFix%g9_h@Toj5Ubr%dL5qUx*5_igIY*ApC1!32Ut+nG0uxwI68_+q&7ejx$Eayh+ffhG_Yoh^k-s*Wr}pwx1LFa#=F1_XG>}1 zqaJ;Fu*``muX#ipPqX0gur*TpVjaWfSTVxl(f7@LGLaPun?^B33rdVy<;OuN`PRF7 z^BeZEH0L*PSXem9g}wvDBHk^n63Kl%V#}7~X{kvts_5x%SnBHEj;v6U zak2y(U!Vj!l#AB+I@}Lt*QS?0KDzSIV$kgp&|$1kMPl^n17cKDPEYNIY!N|*=w-%R z+kYL(EOp@Ak2?DMT(Oj7REAZ<2I}vmc{P-6j}$VK(BxROpr7SR^SBI&;X!T1!fqFvZ(oA95Gn>w5j0=oGbzmvrX$KR{R^ zF&+J zdQMI(z`}hHdr^VS+U5Qfr{xGHL}5?^C4H63&CB~8D0m=r%ewFONQ@%#H6cZtzJ6BC z0A^f)+54L|T57*uspJC|z@M%>c#>!FSasEa*1_6E23id$%p^t-O#UxNe%a*6za_xc zvkBsBT}z9Mr6nzZ@n+`cg~xSrD8Q={esSRhabC$KQ6nNFOX%tn8X6fv%C!HeTB$tR48D)DfA>b%tt2a1;zb|+vgyvo%m9|WYF z)BQ;6_V(5kpoud5X6E02ZzzI5Bwnl>-G{JOT!J zKIjc%!ixWvV`>Ej3z~(61t_&Gvn~D0Kmohh>B$bjEC-+&STwX_8(=<@*8g5!h9o6n z0ayW-sgzda7Z(>-55z5S0wRIRoLXA$6BaPPjf}{FNafu5S7WOx=0F48u5wPhE0cdam+(oZtf6o#{ipN-{K#yRuB{mkYxr^OVg}zkN*=UAz@i+i<}?;S>a$c+LgSf zp`n3H|8P8+zHLI}x&_#M-uqznG6J_95Y{|oDxwY?d5NS(Z&q`AB-m5 z-|h+Yh$2w7WX+O|5MaPBgz(vobBaW03r{#f9n=fjMjYN(Hq|ZV8sz}9OKMf=CXN60 zr3d*`0yHEVSD*jRGfAz*4kdQ!}k5D4={_FH& z&(F`qbgJf`{iq{;blw zO#aOP=qL_@${?uc+;by$39QCa`oP}j68<{-Uv&rFq8Ar+7nhgK0XEW4Xa+4^eaGa3 z4}jtwzyJ;*Q6Lf4wEW)1g$t^8Y3Z#jih|}D;H{{TW}0~=sfy@>#m$u1*p`Xft`0Zs z8F@&h)%uDpRdFp_dU=k4lT=t;aGWx~a?E9Iq*cc~j?Hl4Z zL7f~GRrj@>UkuwVhL8eVF&PD)8BZpoFR4uV}ucr`RQNXf*6 z)Y8(jjGXkh3~;%{E+5Y~&9sSf`QjL*?CDyy#=`cYGEx+f+~Ga^^CvEF)B;;%$g(%Q z!$6(z1w+c{dme5M4ooruw-KGvuRU%MzaW5oHU3n70~Hr*_XOUl>Y`ahz;O%WMS<V#HxODsA<3l{bjTt63DBlr zKeK+-66PGR&jt#9G1{4XY~RO2bd=?;sO7r%|Z3L5Sk@GIHF1(o$mJYiMzprg${W9g8;FX!XTiFG4HmaBoY26 z)G>p315^gGWB$`L9f}K)NV*e)(K>hV<|a4Gn>xHIha)zIS=~o?BMw?K`kjas)lmL46ipUHv*7l_(fw z&2QjSjlKZvV6jkR1Q}~)?J7iTYil4D@};7%1!8mLUgTldD?;V8Mn^|E-n@|kxM~6L z3USQ3PnYi@NN|5DPd1L2ck|>V_7xm5L^A*wDy^l33;1IMm1KDD+??i4O9Gi-I0)lg z0Gn+KL(c+}HEJ5WNj$K$<_P=X11Ihvz${o46baz3ihvUdu-f$bg>kyV2@r~~iEt*_ zr+_yV=z<_ne3|V{WI|vzvrY}Fj?RMNT5w|2Ub=j_Rw2kQgl}DF0o@(`lp!V13n)g) zIV1s88?=U3RZtv%-^djTq=U1NXmz$UXIxqV_sgM}3-vVzos;sQ6 zQAD9Tq@_?)WYMj{1G>Jb1_lORUS0vWzguQgRr8X7$|lQAm8=TJ z>@3Wh9s@js?Ayf237?e2XLg`V@?##E@|m9c8E5mrF2ZqLOI^4<7Uu?e+aP@#3V2xk z**!>Ohj0dF%htZp*`;#$Zx}zm&JiL(IJWzxzD)bDh+QMSS@@Hj_jIE9>f_;;-<^>H z9&l4BuRs~nEEXp4OI>25kQBpdZ}2s4nIX*Xq}8E)S?-p`ps=oXE$j#04g3>X>M>+l z0M_)8v$b;}JK(b_VqE+*^g?|NZ| zVisEFLTo0!Gavu%)8=Oh=TF{2%30CMo7%=vOClsBJni?Lm+@CviEreE?Wx*Y3HtJc zKsfGe5etQMUrYSuma`C~#(h(e|1*RGc;iJB#ycJ_C~S7r3D~~i@7%BaP+U$xdJRu~ zD~{ha>Q7tv5hy^pDWqUrD`LK2Uri_v)sK-ad!Mh+oIC5>vft`^>C@=bl{vyfCC1jd z8wBO{Epl|Bql&CQ-jfEuVaVSfp+{VAgVsTajR;|F^q6%Ts}+=4vkP6f|6yU4n@eUT zk4cGZA8uLI-ckQ)zmn4GGJ3!#g30%^!B28hOFNPD>LQfR2Z1#u z7&M#bgvA(Q%+9Vxl<_~0(9a9=ievQ>2Zn}Tc}nn^Zmg*e*J%8Aj&f`3WzEywjzm(Z z+4!;aZ$>r$^MdKmIDp!|P}XUKeEQT(60Jy{{p9O9U!A3vg^ zprA~8js*eh$$mSCwssH&4v94qQuFy%~`>vY~z` z&f!tg9sV>S&JPDA7i@dtT!og0ej{VV;b_E7ixN0Ef=%a|kt0Wo4IVEUSXkh*7W97) zu}N3ktEx_|p6dQr2xG6B;#av*=No#}+^R14sDl`wt{68Mp%QXMI2^ffyEk|SKks@Z zQYl}qy{qT%7KOJW>F85{FBXpfjw5gOv)McT)U{vD{6K}2a3}VvuU=>IQq%nrvk9Rw zt67rMfIBG{CrUyaS(yGtT5LxTQL!>M9Ey_9LKIs)9wMQH8Is3M%j8~^nQGVJ)^gBH z=kH~kS`<`~@Lb-a;x@FV7O0I-rTcP(mnQ-!fuL3EfmtF)My0WUfZyLpf{e;Za0wXA$ic3 z=x^S{Jy*6%yWzFNRbT!xrP8r6er*Yysd>f{OdSPSsMz9|JbDYEk=>@C2e;T{oBY0X z`cf(?G-JH&FCnAGJbB5)_x83!$!_40H*C8g2skoSf;1cjJwBeP76kM`p1{e-Ae{Ij zsIyRjnS|(dloRP979L!Zv#GDJ;Eag{#5Sl0|F=@g|B-qisRaW_j8~9+7ZMTzd8`Zw z^TL3aX8;vS!DCaL0nffbY3V~~0?7O`Kr{V^CTx8PioEItPa0v8kf4kXx8nVT38u^k z4-J7QM!=(S0*!8m8irlo?;Iam!CJOi>u9c#r20?$=0Ew{>HlTdYWto% Q-v}iyts+$;VH*7305@kD0{{R3 literal 0 HcmV?d00001 diff --git a/setup.py b/setup.py index 4eb67ad..fcfc8c6 100644 --- a/setup.py +++ b/setup.py @@ -3,7 +3,7 @@ from setuptools import setup, find_packages setup( name = 'vit-pytorch', packages = find_packages(exclude=['examples']), - version = '0.28.2', + version = '0.29.0', license='MIT', description = 'Vision Transformer (ViT) - Pytorch', author = 'Phil Wang', diff --git a/vit_pytorch/parallel_vit.py b/vit_pytorch/parallel_vit.py new file mode 100644 index 0000000..62c574b --- /dev/null +++ b/vit_pytorch/parallel_vit.py @@ -0,0 +1,137 @@ +import torch +from torch import nn + +from einops import rearrange, repeat +from einops.layers.torch import Rearrange + +# helpers + +def pair(t): + return t if isinstance(t, tuple) else (t, t) + +# classes + +class Parallel(nn.Module): + def __init__(self, *fns): + super().__init__() + self.fns = nn.ModuleList(fns) + + def forward(self, x): + return sum([fn(x) for fn in self.fns]) + +class PreNorm(nn.Module): + def __init__(self, dim, fn): + super().__init__() + self.norm = nn.LayerNorm(dim) + self.fn = fn + def forward(self, x, **kwargs): + return self.fn(self.norm(x), **kwargs) + +class FeedForward(nn.Module): + def __init__(self, dim, hidden_dim, dropout = 0.): + super().__init__() + self.net = nn.Sequential( + nn.Linear(dim, hidden_dim), + nn.GELU(), + nn.Dropout(dropout), + nn.Linear(hidden_dim, dim), + nn.Dropout(dropout) + ) + def forward(self, x): + return self.net(x) + +class Attention(nn.Module): + def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.): + super().__init__() + inner_dim = dim_head * heads + project_out = not (heads == 1 and dim_head == dim) + + self.heads = heads + self.scale = dim_head ** -0.5 + + self.attend = nn.Softmax(dim = -1) + self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False) + + self.to_out = nn.Sequential( + nn.Linear(inner_dim, dim), + nn.Dropout(dropout) + ) if project_out else nn.Identity() + + def forward(self, x): + qkv = self.to_qkv(x).chunk(3, dim = -1) + q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv) + + dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale + + attn = self.attend(dots) + + out = torch.matmul(attn, v) + out = rearrange(out, 'b h n d -> b n (h d)') + return self.to_out(out) + +class Transformer(nn.Module): + def __init__(self, dim, depth, heads, dim_head, mlp_dim, num_parallel_branches = 2, dropout = 0.): + super().__init__() + self.layers = nn.ModuleList([]) + + attn_block = lambda: PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)) + ff_block = lambda: PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout)) + + for _ in range(depth): + self.layers.append(nn.ModuleList([ + Parallel(*[attn_block() for _ in range(num_parallel_branches)]), + Parallel(*[ff_block() for _ in range(num_parallel_branches)]), + ])) + + def forward(self, x): + for attns, ffs in self.layers: + x = attns(x) + x + x = ffs(x) + x + return x + +class ViT(nn.Module): + def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, pool = 'cls', num_parallel_branches = 2, channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0.): + super().__init__() + image_height, image_width = pair(image_size) + patch_height, patch_width = pair(patch_size) + + assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.' + + num_patches = (image_height // patch_height) * (image_width // patch_width) + patch_dim = channels * patch_height * patch_width + assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)' + + self.to_patch_embedding = nn.Sequential( + Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_height, p2 = patch_width), + nn.Linear(patch_dim, dim), + ) + + self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim)) + self.cls_token = nn.Parameter(torch.randn(1, 1, dim)) + self.dropout = nn.Dropout(emb_dropout) + + self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, num_parallel_branches, dropout) + + self.pool = pool + self.to_latent = nn.Identity() + + self.mlp_head = nn.Sequential( + nn.LayerNorm(dim), + nn.Linear(dim, num_classes) + ) + + def forward(self, img): + x = self.to_patch_embedding(img) + b, n, _ = x.shape + + cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b) + x = torch.cat((cls_tokens, x), dim=1) + x += self.pos_embedding[:, :(n + 1)] + x = self.dropout(x) + + x = self.transformer(x) + + x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0] + + x = self.to_latent(x) + return self.mlp_head(x)