import torch
import torch.nn.functional as F
from einops import rearrange, repeat
from torch import nn

MIN_NUM_PATCHES = 16

class Residual(nn.Module):
    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, **kwargs):
        return self.fn(x, **kwargs) + x

class PreNorm(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, x, **kwargs):
        return self.fn(self.norm(x), **kwargs)

class FeedForward(nn.Module):
    def __init__(self, dim, hidden_dim, dropout = 0.):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        return self.net(x)

class Attention(nn.Module):
    def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
        super().__init__()
        inner_dim = dim_head * heads
        self.heads = heads
        self.scale = dim ** -0.5

        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        )

    def forward(self, x, mask = None):
        b, n, _, h = *x.shape, self.heads

        # project to queries, keys, values and split out the head dimension
        qkv = self.to_qkv(x).chunk(3, dim = -1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv)

        # scaled dot-product attention scores
        dots = torch.einsum('bhid,bhjd->bhij', q, k) * self.scale
        mask_value = -torch.finfo(dots.dtype).max

        if mask is not None:
            # pad the per-patch mask so the cls token is always visible
            mask = F.pad(mask.flatten(1), (1, 0), value = True)
            assert mask.shape[-1] == dots.shape[-1], 'mask has incorrect dimensions'
            mask = mask[:, None, :] * mask[:, :, None]
            # add a head dimension so the (b, i, j) mask broadcasts over heads
            dots.masked_fill_(~mask[:, None], mask_value)
            del mask

        attn = dots.softmax(dim = -1)

        # weighted sum of values, then merge heads and project back to dim
        out = torch.einsum('bhij,bhjd->bhid', attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        out = self.to_out(out)
        return out
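
# A minimal shape-check sketch for the Attention block above (an addition, not
# part of the original module). Sizes here (batch 2, 64 patches + cls token,
# dim 128) are illustrative assumptions only; call the function manually to run it.
def _attention_shape_check():
    attn = Attention(dim = 128, heads = 8, dim_head = 64, dropout = 0.)
    x = torch.randn(2, 65, 128)                    # (batch, 64 patches + cls, dim)
    mask = torch.ones(2, 64, dtype = torch.bool)   # per-patch mask; cls is padded in internally
    out = attn(x, mask = mask)
    assert out.shape == (2, 65, 128)               # attention preserves the input shape
    return out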

class Transformer(nn.Module):
    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout):
        super().__init__()
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                Residual(PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout))),
                Residual(PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout)))
            ]))

    def forward(self, x, mask = None):
        for attn, ff in self.layers:
            x = attn(x, mask = mask)
            x = ff(x)
        return x
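
# A minimal sketch (an addition, not from the original module) showing that the
# Transformer above keeps the (batch, tokens, dim) shape while stacking `depth`
# attention / feed-forward blocks. Hyperparameters are arbitrary assumptions.
def _transformer_shape_check():
    model = Transformer(dim = 128, depth = 2, heads = 8, dim_head = 64, mlp_dim = 256, dropout = 0.)
    x = torch.randn(2, 65, 128)
    out = model(x)
    assert out.shape == x.shape                    # shape is unchanged layer to layer
    return out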

class ViT(nn.Module):
    def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0.):
        super().__init__()
        assert image_size % patch_size == 0, 'Image dimensions must be divisible by the patch size.'
        num_patches = (image_size // patch_size) ** 2
        patch_dim = channels * patch_size ** 2
        assert num_patches > MIN_NUM_PATCHES, f'your number of patches ({num_patches}) is way too small for attention to be effective (at least 16). Try decreasing your patch size'

        self.patch_size = patch_size

        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
        self.patch_to_embedding = nn.Linear(patch_dim, dim)
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
        self.dropout = nn.Dropout(emb_dropout)
        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)

        self.to_cls_token = nn.Identity()

        self.mlp_head = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, num_classes)
        )

    def forward(self, img, mask = None):
        p = self.patch_size

        # flatten non-overlapping p x p patches and project them to the model dimension
        x = rearrange(img, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = p, p2 = p)
        x = self.patch_to_embedding(x)
        b, n, _ = x.shape

        # prepend the learnable cls token and add positional embeddings
        cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b)
        x = torch.cat((cls_tokens, x), dim = 1)
        x += self.pos_embedding[:, :(n + 1)]
        x = self.dropout(x)

        x = self.transformer(x, mask)

        # classify from the cls token's final representation
        x = self.to_cls_token(x[:, 0])
        return self.mlp_head(x)
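
# A minimal end-to-end usage sketch. The 256x256 RGB input and the
# hyperparameters below are illustrative assumptions, not values prescribed by
# this module; run the file directly to get one logit vector per image.
if __name__ == '__main__':
    model = ViT(
        image_size = 256,
        patch_size = 32,
        num_classes = 1000,
        dim = 1024,
        depth = 6,
        heads = 16,
        mlp_dim = 2048,
        dropout = 0.1,
        emb_dropout = 0.1
    )
    img = torch.randn(1, 3, 256, 256)
    preds = model(img)                              # (1, 1000) class logits
    print(preds.shape)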