import torch
from torch import nn
import torch.nn.functional as F
from torch.nn import Module, ModuleList

from einops import einsum, rearrange, repeat, reduce
from einops.layers.torch import Rearrange
# helpers

def exists(val):
    return val is not None

def default(val, d):
    return val if exists(val) else d

def divisible_by(num, den):
    return (num % den) == 0
# simple vit sinusoidal pos emb

def posemb_sincos_2d(t, temperature = 10000):
    h, w, d, device = *t.shape[1:], t.device

    y, x = torch.meshgrid(torch.arange(h, device = device), torch.arange(w, device = device), indexing = 'ij')

    assert (d % 4) == 0, 'feature dimension must be multiple of 4 for sincos emb'

    omega = torch.arange(d // 4, device = device) / (d // 4 - 1)
    omega = temperature ** -omega

    y = y.flatten()[:, None] * omega[None, :]
    x = x.flatten()[:, None] * omega[None, :]

    pos = torch.cat((x.sin(), x.cos(), y.sin(), y.cos()), dim = 1)
    return pos.float()
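
# for features of shape (batch, height, width, dim), this returns a (height * width, dim) positional embedding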

# bias-less layernorm with unit offset trick (discovered by Ohad Rubin)

class LayerNorm(Module):
    def __init__(self, dim):
        super().__init__()
        self.ln = nn.LayerNorm(dim, elementwise_affine = False)
        self.gamma = nn.Parameter(torch.zeros(dim))

    def forward(self, x):
        normed = self.ln(x)
        return normed * (self.gamma + 1)
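
# gamma is initialized at zero, so the learned scale (gamma + 1) starts out as the identity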

# mlp

def MLP(dim, factor = 4, dropout = 0.):
    hidden_dim = int(dim * factor)

    return nn.Sequential(
        LayerNorm(dim),
        nn.Linear(dim, hidden_dim),
        nn.GELU(),
        nn.Dropout(dropout),
        nn.Linear(hidden_dim, dim),
        nn.Dropout(dropout)
    )

# attention
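# when `reuse_attention = True`, no query / key projections are created and a precomputed
# attention matrix must be passed to `forward` - LookViT uses this to reuse the transposed lookup attention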

class Attention(Module):
    def __init__(
        self,
        dim,
        heads = 8,
        dim_head = 64,
        dropout = 0.,
        reuse_attention = False
    ):
        super().__init__()
        inner_dim = dim_head * heads

        self.scale = dim_head ** -0.5
        self.heads = heads
        self.reuse_attention = reuse_attention

        self.split_heads = Rearrange('b n (h d) -> b h n d', h = heads)

        self.norm = LayerNorm(dim) if not reuse_attention else nn.Identity()

        self.attend = nn.Softmax(dim = -1)
        self.dropout = nn.Dropout(dropout)

        self.to_q = nn.Linear(dim, inner_dim, bias = False) if not reuse_attention else None
        self.to_k = nn.Linear(dim, inner_dim, bias = False) if not reuse_attention else None
        self.to_v = nn.Linear(dim, inner_dim, bias = False)

        self.to_out = nn.Sequential(
            Rearrange('b h n d -> b n (h d)'),
            nn.Linear(inner_dim, dim, bias = False),
            nn.Dropout(dropout)
        )

    def forward(
        self,
        x,
        context = None,
        return_attn = False,
        attn = None
    ):
        x = self.norm(x)
        context = default(context, x)

        v = self.to_v(context)
        v = self.split_heads(v)

        if not self.reuse_attention:
            qk = (self.to_q(x), self.to_k(context))
            q, k = tuple(self.split_heads(t) for t in qk)

            q = q * self.scale
            sim = einsum(q, k, 'b h i d, b h j d -> b h i j')

            attn = self.attend(sim)
            attn = self.dropout(attn)
        else:
            assert exists(attn), 'attention matrix must be passed in for reusing previous attention'

        out = einsum(attn, v, 'b h i j, b h j d -> b h i d')
        out = self.to_out(out)

        if not return_attn:
            return out

        return out, attn

# LookViT

class LookViT(Module):
    def __init__(
        self,
        *,
        dim,
        image_size,
        num_classes,
        depth = 3,
        patch_size = 16,
        heads = 8,
        mlp_factor = 4,
        dim_head = 64,
        highres_patch_size = 12,
        highres_mlp_factor = 4,
        cross_attn_heads = 8,
        cross_attn_dim_head = 64,
        patch_conv_kernel_size = 7,
        dropout = 0.1,
        channels = 3
    ):
        super().__init__()

        assert divisible_by(image_size, highres_patch_size)
        assert divisible_by(image_size, patch_size)
        assert patch_size > highres_patch_size, 'the main vision transformer patch size must be larger than the highres patch size (which produces the tokens that get looked up)'
        assert not divisible_by(patch_conv_kernel_size, 2), 'patch conv kernel size must be odd'

        self.dim = dim
        self.image_size = image_size
        self.patch_size = patch_size

        kernel_size = patch_conv_kernel_size
        patch_dim = (highres_patch_size * highres_patch_size) * channels

        self.to_patches = nn.Sequential(
            Rearrange('b c (h p1) (w p2) -> b (p1 p2 c) h w', p1 = highres_patch_size, p2 = highres_patch_size),
            nn.Conv2d(patch_dim, dim, kernel_size, padding = kernel_size // 2),
            Rearrange('b c h w -> b h w c'),
            LayerNorm(dim),
        )

        # absolute positions

        num_patches = (image_size // highres_patch_size) ** 2
        self.pos_embedding = nn.Parameter(torch.randn(num_patches, dim))

        # lookvit blocks

        layers = ModuleList([])

        for _ in range(depth):
            # per layer: main self attention, main mlp, lookup cross attention, reversed cross attention (reusing the lookup attention), highres norm, highres mlp
            layers.append(ModuleList([
                Attention(dim = dim, dim_head = dim_head, heads = heads, dropout = dropout),
                MLP(dim = dim, factor = mlp_factor, dropout = dropout),
                Attention(dim = dim, dim_head = cross_attn_dim_head, heads = cross_attn_heads, dropout = dropout),
                Attention(dim = dim, dim_head = cross_attn_dim_head, heads = cross_attn_heads, dropout = dropout, reuse_attention = True),
                LayerNorm(dim),
                MLP(dim = dim, factor = highres_mlp_factor, dropout = dropout)
            ]))

        self.layers = layers

        self.norm = LayerNorm(dim)
        self.highres_norm = LayerNorm(dim)

        self.to_logits = nn.Linear(dim, num_classes, bias = False)

    def forward(self, img):
        assert img.shape[-2:] == (self.image_size, self.image_size)

        # to patch tokens and positions

        highres_tokens = self.to_patches(img)
        size = highres_tokens.shape[-2] # side length of the square highres token grid

        pos_emb = posemb_sincos_2d(highres_tokens)
        highres_tokens = highres_tokens + rearrange(pos_emb, '(h w) d -> h w d', h = size)

        tokens = F.interpolate(
            rearrange(highres_tokens, 'b h w d -> b d h w'),
            img.shape[-1] // self.patch_size,
            mode = 'bilinear'
        )

        tokens = rearrange(tokens, 'b c h w -> b (h w) c')
        highres_tokens = rearrange(highres_tokens, 'b h w c -> b (h w) c')

        # attention and feedforwards

        for attn, mlp, lookup_cross_attn, highres_attn, highres_norm, highres_mlp in self.layers:

            # main tokens cross attend (lookup) onto the highres tokens

            lookup_out, lookup_attn = lookup_cross_attn(tokens, highres_tokens, return_attn = True) # the attention matrix is returned so it can be reused below

            tokens = lookup_out + tokens
            tokens = attn(tokens) + tokens
            tokens = mlp(tokens) + tokens

            # attention reuse - highres tokens attend back to the main tokens using the transposed lookup attention

            lookup_attn = rearrange(lookup_attn, 'b h i j -> b h j i') # transpose for the reversed cross attention

            highres_tokens = highres_attn(highres_tokens, tokens, attn = lookup_attn) + highres_tokens
            highres_tokens = highres_norm(highres_tokens)
            highres_tokens = highres_mlp(highres_tokens) + highres_tokens

        # to logits

        tokens = self.norm(tokens)
        highres_tokens = self.highres_norm(highres_tokens)

        tokens = reduce(tokens, 'b n d -> b d', 'mean')
        highres_tokens = reduce(highres_tokens, 'b n d -> b d', 'mean')

        return self.to_logits(tokens + highres_tokens)

# main

if __name__ == '__main__':

    v = LookViT(
        image_size = 256,
        num_classes = 1000,
        dim = 512,
        depth = 2,
        heads = 8,
        dim_head = 64,
        patch_size = 32,
        highres_patch_size = 8,
        highres_mlp_factor = 2,
        cross_attn_heads = 8,
        cross_attn_dim_head = 64,
        dropout = 0.1
    ).cuda()

    img = torch.randn(2, 3, 256, 256).cuda()

    pred = v(img)
    assert pred.shape == (2, 1000)
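
    # for this config, the main stream processes (256 // 32) ** 2 = 64 tokens,
    # while the highres (lookup) stream processes (256 // 8) ** 2 = 1024 tokens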