from math import ceil

import torch
from torch import nn, einsum
import torch.nn.functional as F

from einops import rearrange, repeat
from einops.layers.torch import Rearrange

# helpers

def exists(val):
    return val is not None

def default(val, d):
    return val if exists(val) else d

def cast_tuple(val, l = 3):
    # broadcast a single value or a short tuple out to length l by repeating the last element
    val = val if isinstance(val, tuple) else (val,)
    return (*val, *((val[-1],) * max(l - len(val), 0)))

def always(val):
    return lambda *args, **kwargs: val
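
# e.g. cast_tuple(64) -> (64, 64, 64) and cast_tuple((64, 128), 3) -> (64, 128, 128)
# (illustrative values, not from the original source); this is how the per-stage
# hyperparameters of LeViT below accept either a single value or a tuple
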
# classes

class FeedForward(nn.Module):
    def __init__(self, dim, mult, dropout = 0.):
        super().__init__()
        # channel-wise MLP built from 1x1 convolutions
        self.net = nn.Sequential(
            nn.Conv2d(dim, dim * mult, 1),
            nn.Hardswish(),
            nn.Dropout(dropout),
            nn.Conv2d(dim * mult, dim, 1),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        return self.net(x)

class Attention(nn.Module):
    def __init__(self, dim, fmap_size, heads = 8, dim_key = 32, dim_value = 64, dropout = 0., dim_out = None, downsample = False):
        super().__init__()
        inner_dim_key = dim_key * heads
        inner_dim_value = dim_value * heads
        dim_out = default(dim_out, dim)

        self.heads = heads
        self.scale = dim_key ** -0.5

        # queries are optionally computed with a stride-2 convolution to downsample the feature map
        self.to_q = nn.Sequential(nn.Conv2d(dim, inner_dim_key, 1, stride = (2 if downsample else 1), bias = False), nn.BatchNorm2d(inner_dim_key))
        self.to_k = nn.Sequential(nn.Conv2d(dim, inner_dim_key, 1, bias = False), nn.BatchNorm2d(inner_dim_key))
        self.to_v = nn.Sequential(nn.Conv2d(dim, inner_dim_value, 1, bias = False), nn.BatchNorm2d(inner_dim_value))

        self.attend = nn.Softmax(dim = -1)
        self.dropout = nn.Dropout(dropout)

        # zero-init the output batchnorm weight so the attention branch starts out as identity in the residual
        out_batch_norm = nn.BatchNorm2d(dim_out)
        nn.init.zeros_(out_batch_norm.weight)

        self.to_out = nn.Sequential(
            nn.GELU(),
            nn.Conv2d(inner_dim_value, dim_out, 1),
            out_batch_norm,
            nn.Dropout(dropout)
        )

        # positional bias

        self.pos_bias = nn.Embedding(fmap_size * fmap_size, heads)

        q_range = torch.arange(0, fmap_size, step = (2 if downsample else 1))
        k_range = torch.arange(fmap_size)

        q_pos = torch.stack(torch.meshgrid(q_range, q_range, indexing = 'ij'), dim = -1)
        k_pos = torch.stack(torch.meshgrid(k_range, k_range, indexing = 'ij'), dim = -1)

        q_pos, k_pos = map(lambda t: rearrange(t, 'i j c -> (i j) c'), (q_pos, k_pos))
        rel_pos = (q_pos[:, None, ...] - k_pos[None, :, ...]).abs()

        x_rel, y_rel = rel_pos.unbind(dim = -1)
        pos_indices = (x_rel * fmap_size) + y_rel

        self.register_buffer('pos_indices', pos_indices)
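
        # illustration (assumed values, not from the original source): with fmap_size = 3 and no
        # downsampling, q_pos and k_pos both enumerate the 9 grid coordinates, so
        # pos_indices[i, j] = |dx| * 3 + |dy| for the offset between query i and key j, giving one
        # learned bias bucket per absolute offset, looked up per head in apply_pos_bias below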

    def apply_pos_bias(self, fmap):
        bias = self.pos_bias(self.pos_indices)
        bias = rearrange(bias, 'i j h -> () h i j')
        return fmap + (bias / self.scale)

    def forward(self, x):
        b, n, *_, h = *x.shape, self.heads

        q = self.to_q(x)
        y = q.shape[2]

        qkv = (q, self.to_k(x), self.to_v(x))
        # flatten the spatial dimensions of each projection into a per-head sequence
        q, k, v = map(lambda t: rearrange(t, 'b (h d) ... -> b h (...) d', h = h), qkv)

        dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
        dots = self.apply_pos_bias(dots)

        attn = self.attend(dots)
        attn = self.dropout(attn)

        out = einsum('b h i j, b h j d -> b h i d', attn, v)
        # fold the heads back into channels and restore the (possibly downsampled) spatial layout
        out = rearrange(out, 'b h (x y) d -> b (h d) x y', h = h, y = y)
        return self.to_out(out)
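
# shape sketch (assumed example, not part of the original file):
#   attn = Attention(dim = 128, fmap_size = 14, heads = 4, downsample = True)
#   attn(torch.randn(2, 128, 14, 14)).shape  # -> torch.Size([2, 128, 7, 7])
# the strided query projection halves the resolution while keys and values still cover the full 14 x 14 map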

class Transformer(nn.Module):
    def __init__(self, dim, fmap_size, depth, heads, dim_key, dim_value, mlp_mult = 2, dropout = 0., dim_out = None, downsample = False):
        super().__init__()
        dim_out = default(dim_out, dim)
        self.layers = nn.ModuleList([])
        # the attention residual is only kept when the block neither downsamples nor changes width
        self.attn_residual = (not downsample) and dim == dim_out

        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                Attention(dim, fmap_size = fmap_size, heads = heads, dim_key = dim_key, dim_value = dim_value, dropout = dropout, downsample = downsample, dim_out = dim_out),
                FeedForward(dim_out, mlp_mult, dropout = dropout)
            ]))

    def forward(self, x):
        for attn, ff in self.layers:
            attn_res = (x if self.attn_residual else 0)
            x = attn(x) + attn_res
            x = ff(x) + x
        return x

class LeViT(nn.Module):
    def __init__(
        self,
        *,
        image_size,
        num_classes,
        dim,
        depth,
        heads,
        mlp_mult,
        stages = 3,
        dim_key = 32,
        dim_value = 64,
        dropout = 0.,
        num_distill_classes = None
    ):
        super().__init__()

        dims = cast_tuple(dim, stages)
        depths = cast_tuple(depth, stages)
        layer_heads = cast_tuple(heads, stages)

        assert all(map(lambda t: len(t) == stages, (dims, depths, layer_heads))), 'dimensions, depths, and heads must be given as single values or as tuples no longer than the designated number of stages'

        # four stride-2 convolutions embed the image, downsampling it 16x before the transformer stages
        self.conv_embedding = nn.Sequential(
            nn.Conv2d(3, 32, 3, stride = 2, padding = 1),
            nn.Conv2d(32, 64, 3, stride = 2, padding = 1),
            nn.Conv2d(64, 128, 3, stride = 2, padding = 1),
            nn.Conv2d(128, dims[0], 3, stride = 2, padding = 1)
        )

        fmap_size = image_size // (2 ** 4)
        layers = []

        for ind, dim, depth, heads in zip(range(stages), dims, depths, layer_heads):
            is_last = ind == (stages - 1)
            layers.append(Transformer(dim, fmap_size, depth, heads, dim_key, dim_value, mlp_mult, dropout))

            if not is_last:
                next_dim = dims[ind + 1]
                # a single downsampling transformer bridges consecutive stages, halving the feature map
                layers.append(Transformer(dim, fmap_size, 1, heads * 2, dim_key, dim_value, dim_out = next_dim, downsample = True))
                fmap_size = ceil(fmap_size / 2)

        self.backbone = nn.Sequential(*layers)

        self.pool = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            Rearrange('... () () -> ...')
        )

        self.distill_head = nn.Linear(dim, num_distill_classes) if exists(num_distill_classes) else always(None)
        self.mlp_head = nn.Linear(dim, num_classes)

    def forward(self, img):
        x = self.conv_embedding(img)
        x = self.backbone(x)

        x = self.pool(x)

        out = self.mlp_head(x)
        distill = self.distill_head(x)

        if exists(distill):
            return out, distill

        return out
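
# usage sketch (assumed example, not part of the original file); the hyperparameter values
# below are illustrative only

if __name__ == '__main__':
    levit = LeViT(
        image_size = 224,
        num_classes = 1000,
        stages = 3,                # number of transformer stages
        dim = (256, 384, 512),     # dimension per stage
        depth = 4,                 # transformer depth per stage
        heads = (4, 6, 8),         # attention heads per stage
        mlp_mult = 2,
        dropout = 0.1
    )

    img = torch.randn(2, 3, 224, 224)
    logits = levit(img)            # (2, 1000)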