mirror of
https://github.com/lucidrains/vit-pytorch.git
synced 2025-12-30 08:02:29 +00:00
dropouts are more specific and aggressive in the paper, thanks for letting me know @hila-chefer
This commit is contained in:
@@ -24,8 +24,8 @@ v = ViT(
|
|||||||
depth = 6,
|
depth = 6,
|
||||||
heads = 8,
|
heads = 8,
|
||||||
mlp_dim = 2048,
|
mlp_dim = 2048,
|
||||||
attn_dropout = 0.1,
|
dropout = 0.1,
|
||||||
ff_dropout = 0.1
|
emb_dropout = 0.1
|
||||||
)
|
)
|
||||||
|
|
||||||
img = torch.randn(1, 3, 256, 256)
|
img = torch.randn(1, 3, 256, 256)
|
||||||
|
|||||||
2
setup.py
2
setup.py
@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
|
|||||||
setup(
|
setup(
|
||||||
name = 'vit-pytorch',
|
name = 'vit-pytorch',
|
||||||
packages = find_packages(exclude=['examples']),
|
packages = find_packages(exclude=['examples']),
|
||||||
version = '0.2.1',
|
version = '0.2.2',
|
||||||
license='MIT',
|
license='MIT',
|
||||||
description = 'Vision Transformer (ViT) - Pytorch',
|
description = 'Vision Transformer (ViT) - Pytorch',
|
||||||
author = 'Phil Wang',
|
author = 'Phil Wang',
|
||||||
|
|||||||
@@ -25,7 +25,8 @@ class FeedForward(nn.Module):
|
|||||||
nn.Linear(dim, hidden_dim),
|
nn.Linear(dim, hidden_dim),
|
||||||
nn.GELU(),
|
nn.GELU(),
|
||||||
nn.Dropout(dropout),
|
nn.Dropout(dropout),
|
||||||
nn.Linear(hidden_dim, dim)
|
nn.Linear(hidden_dim, dim),
|
||||||
|
nn.Dropout(dropout)
|
||||||
)
|
)
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
return self.net(x)
|
return self.net(x)
|
||||||
@@ -37,8 +38,11 @@ class Attention(nn.Module):
|
|||||||
self.scale = dim ** -0.5
|
self.scale = dim ** -0.5
|
||||||
|
|
||||||
self.to_qkv = nn.Linear(dim, dim * 3, bias = False)
|
self.to_qkv = nn.Linear(dim, dim * 3, bias = False)
|
||||||
self.to_out = nn.Linear(dim, dim)
|
self.to_out = nn.Sequential(
|
||||||
self.dropout = nn.Dropout(dropout)
|
nn.Linear(dim, dim),
|
||||||
|
nn.Dropout(dropout)
|
||||||
|
)
|
||||||
|
|
||||||
def forward(self, x, mask = None):
|
def forward(self, x, mask = None):
|
||||||
b, n, _, h = *x.shape, self.heads
|
b, n, _, h = *x.shape, self.heads
|
||||||
qkv = self.to_qkv(x)
|
qkv = self.to_qkv(x)
|
||||||
@@ -54,7 +58,6 @@ class Attention(nn.Module):
|
|||||||
del mask
|
del mask
|
||||||
|
|
||||||
attn = dots.softmax(dim=-1)
|
attn = dots.softmax(dim=-1)
|
||||||
attn = self.dropout(attn)
|
|
||||||
|
|
||||||
out = torch.einsum('bhij,bhjd->bhid', attn, v)
|
out = torch.einsum('bhij,bhjd->bhid', attn, v)
|
||||||
out = rearrange(out, 'b h n d -> b n (h d)')
|
out = rearrange(out, 'b h n d -> b n (h d)')
|
||||||
@@ -62,13 +65,13 @@ class Attention(nn.Module):
|
|||||||
return out
|
return out
|
||||||
|
|
||||||
class Transformer(nn.Module):
|
class Transformer(nn.Module):
|
||||||
def __init__(self, dim, depth, heads, mlp_dim, attn_dropout, ff_dropout):
|
def __init__(self, dim, depth, heads, mlp_dim, dropout):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.layers = nn.ModuleList([])
|
self.layers = nn.ModuleList([])
|
||||||
for _ in range(depth):
|
for _ in range(depth):
|
||||||
self.layers.append(nn.ModuleList([
|
self.layers.append(nn.ModuleList([
|
||||||
Residual(PreNorm(dim, Attention(dim, heads = heads, dropout = attn_dropout))),
|
Residual(PreNorm(dim, Attention(dim, heads = heads, dropout = dropout))),
|
||||||
Residual(PreNorm(dim, FeedForward(dim, mlp_dim, dropout = ff_dropout)))
|
Residual(PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout)))
|
||||||
]))
|
]))
|
||||||
def forward(self, x, mask = None):
|
def forward(self, x, mask = None):
|
||||||
for attn, ff in self.layers:
|
for attn, ff in self.layers:
|
||||||
@@ -77,7 +80,7 @@ class Transformer(nn.Module):
|
|||||||
return x
|
return x
|
||||||
|
|
||||||
class ViT(nn.Module):
|
class ViT(nn.Module):
|
||||||
def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, channels = 3, attn_dropout = 0., ff_dropout = 0.):
|
def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, channels = 3, dropout = 0., emb_dropout = 0.):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
assert image_size % patch_size == 0, 'image dimensions must be divisible by the patch size'
|
assert image_size % patch_size == 0, 'image dimensions must be divisible by the patch size'
|
||||||
num_patches = (image_size // patch_size) ** 2
|
num_patches = (image_size // patch_size) ** 2
|
||||||
@@ -88,7 +91,9 @@ class ViT(nn.Module):
|
|||||||
self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
|
self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
|
||||||
self.patch_to_embedding = nn.Linear(patch_dim, dim)
|
self.patch_to_embedding = nn.Linear(patch_dim, dim)
|
||||||
self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
|
self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
|
||||||
self.transformer = Transformer(dim, depth, heads, mlp_dim, attn_dropout, ff_dropout)
|
self.dropout = nn.Dropout(emb_dropout)
|
||||||
|
|
||||||
|
self.transformer = Transformer(dim, depth, heads, mlp_dim, dropout)
|
||||||
|
|
||||||
self.to_cls_token = nn.Identity()
|
self.to_cls_token = nn.Identity()
|
||||||
|
|
||||||
@@ -96,7 +101,9 @@ class ViT(nn.Module):
|
|||||||
nn.LayerNorm(dim),
|
nn.LayerNorm(dim),
|
||||||
nn.Linear(dim, mlp_dim),
|
nn.Linear(dim, mlp_dim),
|
||||||
nn.GELU(),
|
nn.GELU(),
|
||||||
nn.Linear(mlp_dim, num_classes)
|
nn.Dropout(dropout),
|
||||||
|
nn.Linear(mlp_dim, num_classes),
|
||||||
|
nn.Dropout(dropout)
|
||||||
)
|
)
|
||||||
|
|
||||||
def forward(self, img, mask = None):
|
def forward(self, img, mask = None):
|
||||||
@@ -108,6 +115,8 @@ class ViT(nn.Module):
|
|||||||
cls_tokens = self.cls_token.expand(img.shape[0], -1, -1)
|
cls_tokens = self.cls_token.expand(img.shape[0], -1, -1)
|
||||||
x = torch.cat((cls_tokens, x), dim=1)
|
x = torch.cat((cls_tokens, x), dim=1)
|
||||||
x += self.pos_embedding
|
x += self.pos_embedding
|
||||||
|
x = self.dropout(x)
|
||||||
|
|
||||||
x = self.transformer(x, mask)
|
x = self.transformer(x, mask)
|
||||||
|
|
||||||
x = self.to_cls_token(x[:, 0])
|
x = self.to_cls_token(x[:, 0])
|
||||||
|
|||||||
Reference in New Issue
Block a user