Compare commits

..

11 Commits
0.2.0 ... 0.2.6

Author SHA1 Message Date
Phil Wang
7a214d7109 allow for training on different image sizes, provided images are smaller than what was passed as image_size keyword on init 2020-10-25 13:17:42 -07:00
Phil Wang
6d1df1a970 more efficient 2020-10-22 22:37:06 -07:00
Phil Wang
d65a8c17a5 remove dropout from last linear to logits 2020-10-16 13:58:23 -07:00
Phil Wang
f7c164d910 assert minimum number of patches 2020-10-16 12:19:50 -07:00
Phil Wang
c7b74e0bc3 rename ipy notebook 2020-10-14 10:35:46 -07:00
Phil Wang
5b5d98a3a7 dropouts are more specific and aggressive in the paper, thanks for letting me know @hila-chefer 2020-10-14 09:22:16 -07:00
Phil Wang
b0e4790c24 bump package 2020-10-13 13:12:19 -07:00
Phil Wang
0b2b3fc20c add dropouts 2020-10-13 13:11:59 -07:00
Phil Wang
ced464dcb4 Update setup.py 2020-10-11 00:06:26 -07:00
Phil Wang
5bf45a2d4d Merge pull request #4 from adimyth/main
Image Classification Example
2020-10-10 19:12:31 -07:00
adimyth
fa32e22855 adds a classification example using 'cats & dogs' data 2020-10-11 03:15:19 +05:30
5 changed files with 6290 additions and 19 deletions

View File

@@ -23,7 +23,9 @@ v = ViT(
dim = 1024,
depth = 6,
heads = 8,
mlp_dim = 2048
mlp_dim = 2048,
dropout = 0.1,
emb_dropout = 0.1
)
img = torch.randn(1, 3, 256, 256)

6253
examples/cats_and_dogs.ipynb Normal file

File diff suppressed because one or more lines are too long

View File

@@ -2,8 +2,8 @@ from setuptools import setup, find_packages
setup(
name = 'vit-pytorch',
packages = find_packages(),
version = '0.2.0',
packages = find_packages(exclude=['examples']),
version = '0.2.6',
license='MIT',
description = 'Vision Transformer (ViT) - Pytorch',
author = 'Phil Wang',
@@ -25,4 +25,4 @@ setup(
'License :: OSI Approved :: MIT License',
'Programming Language :: Python :: 3.6',
],
)
)

View File

@@ -30,10 +30,11 @@ class ViT(nn.Module):
x = rearrange(img, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = p, p2 = p)
x = self.patch_to_embedding(x)
b, n, _ = x.shape
cls_tokens = self.cls_token.expand(img.shape[0], -1, -1)
cls_tokens = self.cls_token.expand(b, -1, -1)
x = torch.cat((cls_tokens, x), dim=1)
x += self.pos_embedding
x += self.pos_embedding[:, :(n + 1)]
x = self.transformer(x)
x = self.to_cls_token(x[:, 0])

View File

@@ -3,6 +3,8 @@ import torch.nn.functional as F
from einops import rearrange
from torch import nn
MIN_NUM_PATCHES = 16
class Residual(nn.Module):
def __init__(self, fn):
super().__init__()
@@ -19,28 +21,34 @@ class PreNorm(nn.Module):
return self.fn(self.norm(x), **kwargs)
class FeedForward(nn.Module):
def __init__(self, dim, hidden_dim):
def __init__(self, dim, hidden_dim, dropout = 0.):
super().__init__()
self.net = nn.Sequential(
nn.Linear(dim, hidden_dim),
nn.GELU(),
nn.Linear(hidden_dim, dim)
nn.Dropout(dropout),
nn.Linear(hidden_dim, dim),
nn.Dropout(dropout)
)
def forward(self, x):
return self.net(x)
class Attention(nn.Module):
def __init__(self, dim, heads = 8):
def __init__(self, dim, heads = 8, dropout = 0.):
super().__init__()
self.heads = heads
self.scale = dim ** -0.5
self.to_qkv = nn.Linear(dim, dim * 3, bias = False)
self.to_out = nn.Linear(dim, dim)
self.to_out = nn.Sequential(
nn.Linear(dim, dim),
nn.Dropout(dropout)
)
def forward(self, x, mask = None):
b, n, _, h = *x.shape, self.heads
qkv = self.to_qkv(x)
q, k, v = rearrange(qkv, 'b n (qkv h d) -> qkv b h n d', qkv = 3, h = h)
qkv = self.to_qkv(x).chunk(3, dim = -1)
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv)
dots = torch.einsum('bhid,bhjd->bhij', q, k) * self.scale
@@ -59,13 +67,13 @@ class Attention(nn.Module):
return out
class Transformer(nn.Module):
def __init__(self, dim, depth, heads, mlp_dim):
def __init__(self, dim, depth, heads, mlp_dim, dropout):
super().__init__()
self.layers = nn.ModuleList([])
for _ in range(depth):
self.layers.append(nn.ModuleList([
Residual(PreNorm(dim, Attention(dim, heads = heads))),
Residual(PreNorm(dim, FeedForward(dim, mlp_dim)))
Residual(PreNorm(dim, Attention(dim, heads = heads, dropout = dropout))),
Residual(PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout)))
]))
def forward(self, x, mask = None):
for attn, ff in self.layers:
@@ -74,18 +82,21 @@ class Transformer(nn.Module):
return x
class ViT(nn.Module):
def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, channels = 3):
def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, channels = 3, dropout = 0., emb_dropout = 0.):
super().__init__()
assert image_size % patch_size == 0, 'image dimensions must be divisible by the patch size'
num_patches = (image_size // patch_size) ** 2
patch_dim = channels * patch_size ** 2
assert num_patches > MIN_NUM_PATCHES, f'your number of patches ({num_patches}) is way too small for attention to be effective. try decreasing your patch size'
self.patch_size = patch_size
self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
self.patch_to_embedding = nn.Linear(patch_dim, dim)
self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
self.transformer = Transformer(dim, depth, heads, mlp_dim)
self.dropout = nn.Dropout(emb_dropout)
self.transformer = Transformer(dim, depth, heads, mlp_dim, dropout)
self.to_cls_token = nn.Identity()
@@ -93,6 +104,7 @@ class ViT(nn.Module):
nn.LayerNorm(dim),
nn.Linear(dim, mlp_dim),
nn.GELU(),
nn.Dropout(dropout),
nn.Linear(mlp_dim, num_classes)
)
@@ -101,10 +113,13 @@ class ViT(nn.Module):
x = rearrange(img, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = p, p2 = p)
x = self.patch_to_embedding(x)
b, n, _ = x.shape
cls_tokens = self.cls_token.expand(img.shape[0], -1, -1)
cls_tokens = self.cls_token.expand(b, -1, -1)
x = torch.cat((cls_tokens, x), dim=1)
x += self.pos_embedding
x += self.pos_embedding[:, :(n + 1)]
x = self.dropout(x)
x = self.transformer(x, mask)
x = self.to_cls_token(x[:, 0])