offer a way to turn off ds conv in rotary vision transformer for ablation

This commit is contained in:
Phil Wang
2021-04-20 10:12:03 -07:00
parent 566365978d
commit e42e9876bc
2 changed files with 10 additions and 8 deletions

View File

@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
setup(
name = 'vit-pytorch',
packages = find_packages(exclude=['examples']),
version = '0.16.7',
version = '0.16.8',
license='MIT',
description = 'Vision Transformer (ViT) - Pytorch',
author = 'Phil Wang',

View File

@@ -94,7 +94,7 @@ class FeedForward(nn.Module):
return self.net(x)
class Attention(nn.Module):
def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0., use_rotary = True, conv_query_kernel = 5):
def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0., use_rotary = True, use_ds_conv = True, conv_query_kernel = 5):
super().__init__()
inner_dim = dim_head * heads
self.use_rotary = use_rotary
@@ -103,7 +103,8 @@ class Attention(nn.Module):
self.attend = nn.Softmax(dim = -1)
self.to_q = SpatialConv(dim, inner_dim, conv_query_kernel, bias = False)
self.use_ds_conv = use_ds_conv
self.to_q = SpatialConv(dim, inner_dim, conv_query_kernel, bias = False) if use_ds_conv else nn.Linear(dim, inner_dim, bias = False)
self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False)
@@ -115,7 +116,8 @@ class Attention(nn.Module):
def forward(self, x, pos_emb, fmap_dims):
b, n, _, h = *x.shape, self.heads
q = self.to_q(x, fmap_dims = fmap_dims)
q = self.to_q(x, fmap_dims = fmap_dims) if self.use_ds_conv else self.to_q(x)
qkv = (q, *self.to_kv(x).chunk(2, dim = -1))
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h = h), qkv)
@@ -141,13 +143,13 @@ class Attention(nn.Module):
return self.to_out(out)
class Transformer(nn.Module):
def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0., use_rotary = True):
def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0., use_rotary = True, use_ds_conv = True):
super().__init__()
self.layers = nn.ModuleList([])
self.pos_emb = AxialRotaryEmbedding(dim_head)
for _ in range(depth):
self.layers.append(nn.ModuleList([
PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout, use_rotary = use_rotary)),
PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout, use_rotary = use_rotary, use_ds_conv = use_ds_conv)),
PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout))
]))
def forward(self, x, fmap_dims):
@@ -161,7 +163,7 @@ class Transformer(nn.Module):
# Rotary Vision Transformer
class RvT(nn.Module):
def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0., use_rotary = True):
def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0., use_rotary = True, use_ds_conv = True):
super().__init__()
assert image_size % patch_size == 0, 'Image dimensions must be divisible by the patch size.'
num_patches = (image_size // patch_size) ** 2
@@ -174,7 +176,7 @@ class RvT(nn.Module):
)
self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout, use_rotary)
self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout, use_rotary, use_ds_conv)
self.mlp_head = nn.Sequential(
nn.LayerNorm(dim),