Compare commits

...

14 Commits
0.0.2 ... 0.2.2

Author SHA1 Message Date
Phil Wang
35796104b0 dropouts are more specific and aggressive in the paper, thanks for letting me know @hilach70 2020-10-14 05:48:27 -07:00
Phil Wang
b0e4790c24 bump package 2020-10-13 13:12:19 -07:00
Phil Wang
0b2b3fc20c add dropouts 2020-10-13 13:11:59 -07:00
Phil Wang
ced464dcb4 Update setup.py 2020-10-11 00:06:26 -07:00
Phil Wang
5bf45a2d4d Merge pull request #4 from adimyth/main
Image Classification Example
2020-10-10 19:12:31 -07:00
adimyth
fa32e22855 adds a classification example using 'cats & dogs' data 2020-10-11 03:15:19 +05:30
Phil Wang
a0fa41070f norm cls token before sending to mlp head 2020-10-10 12:08:42 -07:00
Phil Wang
b298031c17 write up example for using efficient transformers 2020-10-07 19:15:21 -07:00
Phil Wang
d66b29e4cf cleanup stray print 2020-10-07 11:22:45 -07:00
Phil Wang
f7123720c3 add masking 2020-10-07 11:21:03 -07:00
Phil Wang
f5fffd9e2e remove extraneous line 2020-10-04 15:22:26 -07:00
Phil Wang
8fb261ca66 fix a bug and add suggestion for BYOL pre-training 2020-10-04 14:55:29 -07:00
Phil Wang
112ba5c476 update with link to Yannics video 2020-10-04 13:53:47 -07:00
Phil Wang
f899226d4f add diagram 2020-10-04 12:47:08 -07:00
6 changed files with 6449 additions and 33 deletions

101
README.md
View File

@@ -1,6 +1,8 @@
<img src="./vit.png" width="500px"></img>
## Vision Transformer - Pytorch
Implementation of Vision Transformer, a simple way to achieve SOTA in vision classification with only a single transformer encoder, in Pytorch. There's really not much to code here, but may as well lay out all the code so we expedite the attention revolution and get everyone on the same page.
Implementation of <a href="https://openreview.net/pdf?id=YicbFdNTTy">Vision Transformer</a>, a simple way to achieve SOTA in vision classification with only a single transformer encoder, in Pytorch. Significance is further explained in <a href="https://www.youtube.com/watch?v=TrdevFK_am4">Yannic Kilcher's</a> video. There's really not much to code here, but may as well lay it out for everyone so we expedite the attention revolution.
## Install
@@ -21,13 +23,106 @@ v = ViT(
dim = 1024,
depth = 6,
heads = 8,
mlp_dim = 2048
mlp_dim = 2048,
dropout = 0.1,
emb_dropout = 0.1
)
img = torch.randn(1, 3, 256, 256)
preds = v(img) # (1, 1000)
mask = torch.ones(1, 8, 8).bool() # optional mask, designating which patch to attend to
preds = v(img, mask = mask) # (1, 1000)
```
## Research Ideas
### Self Supervised Training
You can train this with a near SOTA self-supervised learning technique, <a href="https://github.com/lucidrains/byol-pytorch">BYOL</a>, with the following code.
(1)
```bash
$ pip install byol-pytorch
```
(2)
```python
import torch
from vit_pytorch import ViT
from byol_pytorch import BYOL
model = ViT(
image_size = 256,
patch_size = 32,
num_classes = 1000,
dim = 1024,
depth = 6,
heads = 8,
mlp_dim = 2048
)
learner = BYOL(
model,
image_size = 256,
hidden_layer = 'to_cls_token'
)
opt = torch.optim.Adam(learner.parameters(), lr=3e-4)
def sample_unlabelled_images():
return torch.randn(20, 3, 256, 256)
for _ in range(100):
images = sample_unlabelled_images()
loss = learner(images)
opt.zero_grad()
loss.backward()
opt.step()
learner.update_moving_average() # update moving average of target encoder
# save your improved network
torch.save(model.state_dict(), './pretrained-net.pt')
```
A pytorch-lightning script is ready for you to use at the repository link above.
### Efficient Attention
There may be some coming from computer vision who think attention still suffers from quadratic costs. Fortunately, we have a lot of new techniques that may help. This repository offers a way for you to plugin your own sparse attention transformer.
An example with <a href="https://arxiv.org/abs/2006.04768">Linformer</a>
```bash
$ pip install linformer
```
```python
import torch
from vit_pytorch.efficient import ViT
from linformer import Linformer
efficient_transformer = Linformer(
dim = 512,
seq_len = 4096 + 1, # 64 x 64 patches + 1 cls token
depth = 12,
heads = 8,
k = 256
)
v = ViT(
dim = 512,
image_size = 2048,
patch_size = 32,
num_classes = 1000,
transformer = efficient_transformer
)
img = torch.randn(1, 3, 2048, 2048) # your high resolution picture
v(img) # (1, 1000)
```
Other sparse attention frameworks I would highly recommend is <a href="https://github.com/lucidrains/routing-transformer">Routing Transformer</a> or <a href="https://github.com/lucidrains/sinkhorn-transformer">Sinkhorn Transformer</a>
## Citations
```bibtex

File diff suppressed because one or more lines are too long

View File

@@ -2,8 +2,8 @@ from setuptools import setup, find_packages
setup(
name = 'vit-pytorch',
packages = find_packages(),
version = '0.0.2',
packages = find_packages(exclude=['examples']),
version = '0.2.2',
license='MIT',
description = 'Vision Transformer (ViT) - Pytorch',
author = 'Phil Wang',
@@ -25,4 +25,4 @@ setup(
'License :: OSI Approved :: MIT License',
'Programming Language :: Python :: 3.6',
],
)
)

BIN
vit.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 137 KiB

40
vit_pytorch/efficient.py Normal file
View File

@@ -0,0 +1,40 @@
import torch
from einops import rearrange
from torch import nn
class ViT(nn.Module):
def __init__(self, *, image_size, patch_size, num_classes, dim, transformer, channels = 3):
super().__init__()
assert image_size % patch_size == 0, 'image dimensions must be divisible by the patch size'
num_patches = (image_size // patch_size) ** 2
patch_dim = channels * patch_size ** 2
self.patch_size = patch_size
self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
self.patch_to_embedding = nn.Linear(patch_dim, dim)
self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
self.transformer = transformer
self.to_cls_token = nn.Identity()
self.mlp_head = nn.Sequential(
nn.LayerNorm(dim),
nn.Linear(dim, dim * 4),
nn.GELU(),
nn.Linear(dim * 4, num_classes)
)
def forward(self, img):
p = self.patch_size
x = rearrange(img, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = p, p2 = p)
x = self.patch_to_embedding(x)
cls_tokens = self.cls_token.expand(img.shape[0], -1, -1)
x = torch.cat((cls_tokens, x), dim=1)
x += self.pos_embedding
x = self.transformer(x)
x = self.to_cls_token(x[:, 0])
return self.mlp_head(x)

View File

@@ -1,48 +1,62 @@
import torch
from einops import rearrange
import torch.nn.functional as F
from einops import rearrange
from torch import nn
class Residual(nn.Module):
def __init__(self, fn):
super().__init__()
self.fn = fn
def forward(self, x):
return self.fn(x) + x
def forward(self, x, **kwargs):
return self.fn(x, **kwargs) + x
class PreNorm(nn.Module):
def __init__(self, dim, fn):
super().__init__()
self.norm = nn.LayerNorm(dim)
self.fn = fn
def forward(self, x):
return self.fn(self.norm(x))
def forward(self, x, **kwargs):
return self.fn(self.norm(x), **kwargs)
class FeedForward(nn.Module):
def __init__(self, dim, hidden_dim):
def __init__(self, dim, hidden_dim, dropout = 0.):
super().__init__()
self.net = nn.Sequential(
nn.Linear(dim, hidden_dim),
nn.GELU(),
nn.Linear(hidden_dim, dim)
nn.Dropout(dropout),
nn.Linear(hidden_dim, dim),
nn.Dropout(dropout)
)
def forward(self, x):
return self.net(x)
class Attention(nn.Module):
def __init__(self, dim, heads = 8):
def __init__(self, dim, heads = 8, dropout = 0.):
super().__init__()
self.heads = heads
self.scale = dim ** -0.5
self.to_qkv = nn.Linear(dim, dim * 3, bias = False)
self.to_out = nn.Linear(dim, dim)
def forward(self, x):
self.to_out = nn.Sequential(
nn.Linear(dim, dim),
nn.Dropout(dropout)
)
def forward(self, x, mask = None):
b, n, _, h = *x.shape, self.heads
qkv = self.to_qkv(x)
q, k, v = rearrange(qkv, 'b n (qkv h d) -> qkv b h n d', qkv = 3, h = h)
dots = torch.einsum('bhid,bhjd->bhij', q, k) * self.scale
if mask is not None:
mask = F.pad(mask.flatten(1), (1, 0), value = True)
assert mask.shape[-1] == dots.shape[-1], 'mask has incorrect dimensions'
mask = mask[:, None, :] * mask[:, :, None]
dots.masked_fill_(~mask, float('-inf'))
del mask
attn = dots.softmax(dim=-1)
out = torch.einsum('bhij,bhjd->bhid', attn, v)
@@ -51,20 +65,22 @@ class Attention(nn.Module):
return out
class Transformer(nn.Module):
def __init__(self, dim, depth, heads, mlp_dim):
def __init__(self, dim, depth, heads, mlp_dim, dropout):
super().__init__()
layers = []
self.layers = nn.ModuleList([])
for _ in range(depth):
layers.extend([
Residual(PreNorm(dim, Attention(dim, heads = heads))),
Residual(PreNorm(dim, FeedForward(dim, mlp_dim)))
])
self.net = nn.Sequential(*layers)
def forward(self, x):
return self.net(x)
self.layers.append(nn.ModuleList([
Residual(PreNorm(dim, Attention(dim, heads = heads, dropout = dropout))),
Residual(PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout)))
]))
def forward(self, x, mask = None):
for attn, ff in self.layers:
x = attn(x, mask = mask)
x = ff(x)
return x
class ViT(nn.Module):
def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, channels = 3):
def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, channels = 3, dropout = 0., emb_dropout = 0.):
super().__init__()
assert image_size % patch_size == 0, 'image dimensions must be divisible by the patch size'
num_patches = (image_size // patch_size) ** 2
@@ -75,21 +91,33 @@ class ViT(nn.Module):
self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
self.patch_to_embedding = nn.Linear(patch_dim, dim)
self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
self.transformer = Transformer(dim, depth, heads, mlp_dim)
self.dropout = nn.Dropout(emb_dropout)
self.transformer = Transformer(dim, depth, heads, mlp_dim, dropout)
self.to_cls_token = nn.Identity()
self.mlp_head = nn.Sequential(
nn.LayerNorm(dim),
nn.Linear(dim, mlp_dim),
nn.GELU(),
nn.Linear(mlp_dim, num_classes)
nn.Dropout(dropout),
nn.Linear(mlp_dim, num_classes),
nn.Dropout(dropout)
)
def forward(self, img):
def forward(self, img, mask = None):
p = self.patch_size
x = rearrange(img, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = p, p2 = p)
x = self.patch_to_embedding(x)
x = torch.cat((self.cls_token, x), dim=1)
x += self.pos_embedding
x = self.transformer(x)
return self.mlp_head(x[:, 0])
cls_tokens = self.cls_token.expand(img.shape[0], -1, -1)
x = torch.cat((cls_tokens, x), dim=1)
x += self.pos_embedding
x = self.dropout(x)
x = self.transformer(x, mask)
x = self.to_cls_token(x[:, 0])
return self.mlp_head(x)