Compare commits

...

6 Commits
0.2.1 ... 0.2.6

Author SHA1 Message Date
Phil Wang
7a214d7109 allow for training on different image sizes, provided images are smaller than what was passed as image_size keyword on init 2020-10-25 13:17:42 -07:00
Phil Wang
6d1df1a970 more efficient 2020-10-22 22:37:06 -07:00
Phil Wang
d65a8c17a5 remove dropout from last linear to logits 2020-10-16 13:58:23 -07:00
Phil Wang
f7c164d910 assert minimum number of patches 2020-10-16 12:19:50 -07:00
Phil Wang
c7b74e0bc3 rename ipy notebook 2020-10-14 10:35:46 -07:00
Phil Wang
5b5d98a3a7 dropouts are more specific and aggressive in the paper, thanks for letting me know @hila-chefer 2020-10-14 09:22:16 -07:00
5 changed files with 31 additions and 18 deletions

View File

@@ -24,8 +24,8 @@ v = ViT(
depth = 6,
heads = 8,
mlp_dim = 2048,
attn_dropout = 0.1,
ff_dropout = 0.1
dropout = 0.1,
emb_dropout = 0.1
)
img = torch.randn(1, 3, 256, 256)

View File

@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
setup(
name = 'vit-pytorch',
packages = find_packages(exclude=['examples']),
version = '0.2.1',
version = '0.2.6',
license='MIT',
description = 'Vision Transformer (ViT) - Pytorch',
author = 'Phil Wang',

View File

@@ -30,10 +30,11 @@ class ViT(nn.Module):
x = rearrange(img, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = p, p2 = p)
x = self.patch_to_embedding(x)
b, n, _ = x.shape
cls_tokens = self.cls_token.expand(img.shape[0], -1, -1)
cls_tokens = self.cls_token.expand(b, -1, -1)
x = torch.cat((cls_tokens, x), dim=1)
x += self.pos_embedding
x += self.pos_embedding[:, :(n + 1)]
x = self.transformer(x)
x = self.to_cls_token(x[:, 0])

View File

@@ -3,6 +3,8 @@ import torch.nn.functional as F
from einops import rearrange
from torch import nn
MIN_NUM_PATCHES = 16
class Residual(nn.Module):
def __init__(self, fn):
super().__init__()
@@ -25,7 +27,8 @@ class FeedForward(nn.Module):
nn.Linear(dim, hidden_dim),
nn.GELU(),
nn.Dropout(dropout),
nn.Linear(hidden_dim, dim)
nn.Linear(hidden_dim, dim),
nn.Dropout(dropout)
)
def forward(self, x):
return self.net(x)
@@ -37,12 +40,15 @@ class Attention(nn.Module):
self.scale = dim ** -0.5
self.to_qkv = nn.Linear(dim, dim * 3, bias = False)
self.to_out = nn.Linear(dim, dim)
self.dropout = nn.Dropout(dropout)
self.to_out = nn.Sequential(
nn.Linear(dim, dim),
nn.Dropout(dropout)
)
def forward(self, x, mask = None):
b, n, _, h = *x.shape, self.heads
qkv = self.to_qkv(x)
q, k, v = rearrange(qkv, 'b n (qkv h d) -> qkv b h n d', qkv = 3, h = h)
qkv = self.to_qkv(x).chunk(3, dim = -1)
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv)
dots = torch.einsum('bhid,bhjd->bhij', q, k) * self.scale
@@ -54,7 +60,6 @@ class Attention(nn.Module):
del mask
attn = dots.softmax(dim=-1)
attn = self.dropout(attn)
out = torch.einsum('bhij,bhjd->bhid', attn, v)
out = rearrange(out, 'b h n d -> b n (h d)')
@@ -62,13 +67,13 @@ class Attention(nn.Module):
return out
class Transformer(nn.Module):
def __init__(self, dim, depth, heads, mlp_dim, attn_dropout, ff_dropout):
def __init__(self, dim, depth, heads, mlp_dim, dropout):
super().__init__()
self.layers = nn.ModuleList([])
for _ in range(depth):
self.layers.append(nn.ModuleList([
Residual(PreNorm(dim, Attention(dim, heads = heads, dropout = attn_dropout))),
Residual(PreNorm(dim, FeedForward(dim, mlp_dim, dropout = ff_dropout)))
Residual(PreNorm(dim, Attention(dim, heads = heads, dropout = dropout))),
Residual(PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout)))
]))
def forward(self, x, mask = None):
for attn, ff in self.layers:
@@ -77,18 +82,21 @@ class Transformer(nn.Module):
return x
class ViT(nn.Module):
def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, channels = 3, attn_dropout = 0., ff_dropout = 0.):
def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, channels = 3, dropout = 0., emb_dropout = 0.):
super().__init__()
assert image_size % patch_size == 0, 'image dimensions must be divisible by the patch size'
num_patches = (image_size // patch_size) ** 2
patch_dim = channels * patch_size ** 2
assert num_patches > MIN_NUM_PATCHES, f'your number of patches ({num_patches}) is way too small for attention to be effective. try decreasing your patch size'
self.patch_size = patch_size
self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
self.patch_to_embedding = nn.Linear(patch_dim, dim)
self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
self.transformer = Transformer(dim, depth, heads, mlp_dim, attn_dropout, ff_dropout)
self.dropout = nn.Dropout(emb_dropout)
self.transformer = Transformer(dim, depth, heads, mlp_dim, dropout)
self.to_cls_token = nn.Identity()
@@ -96,6 +104,7 @@ class ViT(nn.Module):
nn.LayerNorm(dim),
nn.Linear(dim, mlp_dim),
nn.GELU(),
nn.Dropout(dropout),
nn.Linear(mlp_dim, num_classes)
)
@@ -104,10 +113,13 @@ class ViT(nn.Module):
x = rearrange(img, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = p, p2 = p)
x = self.patch_to_embedding(x)
b, n, _ = x.shape
cls_tokens = self.cls_token.expand(img.shape[0], -1, -1)
cls_tokens = self.cls_token.expand(b, -1, -1)
x = torch.cat((cls_tokens, x), dim=1)
x += self.pos_embedding
x += self.pos_embedding[:, :(n + 1)]
x = self.dropout(x)
x = self.transformer(x, mask)
x = self.to_cls_token(x[:, 0])