Compare commits


16 Commits
0.0.3 ... 0.2.6

| Author | SHA1 | Message | Date |
|---|---|---|---|
| Phil Wang | 7a214d7109 | allow for training on different image sizes, provided images are smaller than what was passed as image_size keyword on init | 2020-10-25 13:17:42 -07:00 |
| Phil Wang | 6d1df1a970 | more efficient | 2020-10-22 22:37:06 -07:00 |
| Phil Wang | d65a8c17a5 | remove dropout from last linear to logits | 2020-10-16 13:58:23 -07:00 |
| Phil Wang | f7c164d910 | assert minimum number of patches | 2020-10-16 12:19:50 -07:00 |
| Phil Wang | c7b74e0bc3 | rename ipy notebook | 2020-10-14 10:35:46 -07:00 |
| Phil Wang | 5b5d98a3a7 | dropouts are more specific and aggressive in the paper, thanks for letting me know @hila-chefer | 2020-10-14 09:22:16 -07:00 |
| Phil Wang | b0e4790c24 | bump package | 2020-10-13 13:12:19 -07:00 |
| Phil Wang | 0b2b3fc20c | add dropouts | 2020-10-13 13:11:59 -07:00 |
| Phil Wang | ced464dcb4 | Update setup.py | 2020-10-11 00:06:26 -07:00 |
| Phil Wang | 5bf45a2d4d | Merge pull request #4 from adimyth/main: Image Classification Example | 2020-10-10 19:12:31 -07:00 |
| adimyth | fa32e22855 | adds a classification example using 'cats & dogs' data | 2020-10-11 03:15:19 +05:30 |
| Phil Wang | a0fa41070f | norm cls token before sending to mlp head | 2020-10-10 12:08:42 -07:00 |
| Phil Wang | b298031c17 | write up example for using efficient transformers | 2020-10-07 19:15:21 -07:00 |
| Phil Wang | d66b29e4cf | cleanup stray print | 2020-10-07 11:22:45 -07:00 |
| Phil Wang | f7123720c3 | add masking | 2020-10-07 11:21:03 -07:00 |
| Phil Wang | f5fffd9e2e | remove extraneous line | 2020-10-04 15:22:26 -07:00 |
5 changed files with 6396 additions and 33 deletions

README.md

@@ -23,14 +23,20 @@ v = ViT(
dim = 1024,
depth = 6,
heads = 8,
mlp_dim = 2048
mlp_dim = 2048,
dropout = 0.1,
emb_dropout = 0.1
)
img = torch.randn(1, 3, 256, 256)
preds = v(img) # (1, 1000)
mask = torch.ones(1, 8, 8).bool() # optional mask, designating which patches to attend to
preds = v(img, mask = mask) # (1, 1000)
```
## Suggestion
## Research Ideas
### Self Supervised Training
You can train this with a near SOTA self-supervised learning technique, <a href="https://github.com/lucidrains/byol-pytorch">BYOL</a>, with the following code.
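The README snippet itself falls outside the hunks shown below (only its last line, the torch.save call, appears as hunk context). As a minimal sketch of what such training looks like, assuming the byol-pytorch API (`BYOL(net, image_size, hidden_layer)`) and hooking the ViT at its `to_cls_token` layer; `sample_unlabelled_images` is a placeholder for your own data loading:
```python
import torch
from vit_pytorch import ViT
from byol_pytorch import BYOL  # assumes the byol-pytorch package is installed

model = ViT(
    image_size = 256,
    patch_size = 32,
    num_classes = 1000,
    dim = 1024,
    depth = 6,
    heads = 8,
    mlp_dim = 2048
)

# wrap the ViT; BYOL hooks the named submodule to grab its representation
learner = BYOL(model, image_size = 256, hidden_layer = 'to_cls_token')
opt = torch.optim.Adam(learner.parameters(), lr = 3e-4)

def sample_unlabelled_images():
    # placeholder for an unlabelled image batch
    return torch.randn(20, 3, 256, 256)

for _ in range(100):
    images = sample_unlabelled_images()
    loss = learner(images)
    opt.zero_grad()
    loss.backward()
    opt.step()
    learner.update_moving_average()  # update the target network's exponential moving average

torch.save(model.state_dict(), './pretrained-net.pt')
```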
@@ -80,6 +86,43 @@ torch.save(model.state_dict(), './pretrained-net.pt')
A pytorch-lightning script is ready for you to use at the repository link above.
### Efficient Attention
Some readers coming from computer vision may think attention still suffers from quadratic costs. Fortunately, there are many new techniques that may help. This repository offers a way for you to plug in your own sparse attention transformer.
An example with <a href="https://arxiv.org/abs/2006.04768">Linformer</a>:
```bash
$ pip install linformer
```
```python
import torch
from vit_pytorch.efficient import ViT
from linformer import Linformer
efficient_transformer = Linformer(
dim = 512,
seq_len = 4096 + 1, # 64 x 64 patches + 1 cls token
depth = 12,
heads = 8,
k = 256
)
v = ViT(
dim = 512,
image_size = 2048,
patch_size = 32,
num_classes = 1000,
transformer = efficient_transformer
)
img = torch.randn(1, 3, 2048, 2048) # your high resolution picture
v(img) # (1, 1000)
```
Other sparse attention frameworks I would highly recommend are <a href="https://github.com/lucidrains/routing-transformer">Routing Transformer</a> and <a href="https://github.com/lucidrains/sinkhorn-transformer">Sinkhorn Transformer</a>.
## Citations
```bibtex

examples/cats_and_dogs.ipynb (new file, 6253 lines added)

File diff suppressed because one or more lines are too long

setup.py

@@ -2,8 +2,8 @@ from setuptools import setup, find_packages
setup(
name = 'vit-pytorch',
packages = find_packages(),
version = '0.0.3',
packages = find_packages(exclude=['examples']),
version = '0.2.6',
license='MIT',
description = 'Vision Transformer (ViT) - Pytorch',
author = 'Phil Wang',
@@ -25,4 +25,4 @@ setup(
'License :: OSI Approved :: MIT License',
'Programming Language :: Python :: 3.6',
],
)
)
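The setup.py changes above bump the package from 0.0.3 to 0.2.6 and exclude the examples folder from the distribution. Assuming that version was published to PyPI, the matching release installs with:
```bash
$ pip install vit-pytorch==0.2.6
```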

vit_pytorch/efficient.py (new file, 41 lines added)

@@ -0,0 +1,41 @@
import torch
from einops import rearrange
from torch import nn
class ViT(nn.Module):
def __init__(self, *, image_size, patch_size, num_classes, dim, transformer, channels = 3):
super().__init__()
assert image_size % patch_size == 0, 'image dimensions must be divisible by the patch size'
num_patches = (image_size // patch_size) ** 2
patch_dim = channels * patch_size ** 2
self.patch_size = patch_size
self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
self.patch_to_embedding = nn.Linear(patch_dim, dim)
self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
self.transformer = transformer
self.to_cls_token = nn.Identity()
self.mlp_head = nn.Sequential(
nn.LayerNorm(dim),
nn.Linear(dim, dim * 4),
nn.GELU(),
nn.Linear(dim * 4, num_classes)
)
def forward(self, img):
p = self.patch_size
x = rearrange(img, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = p, p2 = p)
x = self.patch_to_embedding(x)
b, n, _ = x.shape
cls_tokens = self.cls_token.expand(b, -1, -1)
x = torch.cat((cls_tokens, x), dim=1)
x += self.pos_embedding[:, :(n + 1)]
x = self.transformer(x)
x = self.to_cls_token(x[:, 0])
return self.mlp_head(x)
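Since `vit_pytorch.efficient.ViT` only requires that `transformer` map a `(batch, seq, dim)` tensor to the same shape, any compatible module can be dropped in. Below is a minimal sketch (not from the repo) that uses PyTorch's built-in encoder purely as an illustration, assuming a PyTorch version with `batch_first` support:
```python
import torch
from torch import nn
from vit_pytorch.efficient import ViT

# any module mapping (batch, seq, dim) -> (batch, seq, dim) works as the plug-in transformer
encoder_layer = nn.TransformerEncoderLayer(d_model = 512, nhead = 8, dim_feedforward = 2048, batch_first = True)
plain_transformer = nn.TransformerEncoder(encoder_layer, num_layers = 6)

v = ViT(
    image_size = 256,   # 256 / 32 = 8 patches per side -> 64 patches + 1 cls token
    patch_size = 32,
    num_classes = 1000,
    dim = 512,
    transformer = plain_transformer
)

img = torch.randn(1, 3, 256, 256)
preds = v(img)  # (1, 1000)
```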

vit_pytorch/vit_pytorch.py

@@ -1,48 +1,64 @@
import torch
from einops import rearrange
import torch.nn.functional as F
from einops import rearrange
from torch import nn
MIN_NUM_PATCHES = 16
class Residual(nn.Module):
def __init__(self, fn):
super().__init__()
self.fn = fn
def forward(self, x):
return self.fn(x) + x
def forward(self, x, **kwargs):
return self.fn(x, **kwargs) + x
class PreNorm(nn.Module):
def __init__(self, dim, fn):
super().__init__()
self.norm = nn.LayerNorm(dim)
self.fn = fn
def forward(self, x):
return self.fn(self.norm(x))
def forward(self, x, **kwargs):
return self.fn(self.norm(x), **kwargs)
class FeedForward(nn.Module):
def __init__(self, dim, hidden_dim):
def __init__(self, dim, hidden_dim, dropout = 0.):
super().__init__()
self.net = nn.Sequential(
nn.Linear(dim, hidden_dim),
nn.GELU(),
nn.Linear(hidden_dim, dim)
nn.Dropout(dropout),
nn.Linear(hidden_dim, dim),
nn.Dropout(dropout)
)
def forward(self, x):
return self.net(x)
class Attention(nn.Module):
def __init__(self, dim, heads = 8):
def __init__(self, dim, heads = 8, dropout = 0.):
super().__init__()
self.heads = heads
self.scale = dim ** -0.5
self.to_qkv = nn.Linear(dim, dim * 3, bias = False)
self.to_out = nn.Linear(dim, dim)
def forward(self, x):
self.to_out = nn.Sequential(
nn.Linear(dim, dim),
nn.Dropout(dropout)
)
def forward(self, x, mask = None):
b, n, _, h = *x.shape, self.heads
qkv = self.to_qkv(x)
q, k, v = rearrange(qkv, 'b n (qkv h d) -> qkv b h n d', qkv = 3, h = h)
qkv = self.to_qkv(x).chunk(3, dim = -1)
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv)
dots = torch.einsum('bhid,bhjd->bhij', q, k) * self.scale
if mask is not None:
mask = F.pad(mask.flatten(1), (1, 0), value = True)
assert mask.shape[-1] == dots.shape[-1], 'mask has incorrect dimensions'
mask = mask[:, None, :] * mask[:, :, None]
dots.masked_fill_(~mask, float('-inf'))
del mask
attn = dots.softmax(dim=-1)
out = torch.einsum('bhij,bhjd->bhid', attn, v)
@@ -51,50 +67,60 @@ class Attention(nn.Module):
return out
class Transformer(nn.Module):
def __init__(self, dim, depth, heads, mlp_dim):
def __init__(self, dim, depth, heads, mlp_dim, dropout):
super().__init__()
layers = []
self.layers = nn.ModuleList([])
for _ in range(depth):
layers.extend([
Residual(PreNorm(dim, Attention(dim, heads = heads))),
Residual(PreNorm(dim, FeedForward(dim, mlp_dim)))
])
self.net = nn.Sequential(*layers)
def forward(self, x):
return self.net(x)
self.layers.append(nn.ModuleList([
Residual(PreNorm(dim, Attention(dim, heads = heads, dropout = dropout))),
Residual(PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout)))
]))
def forward(self, x, mask = None):
for attn, ff in self.layers:
x = attn(x, mask = mask)
x = ff(x)
return x
class ViT(nn.Module):
def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, channels = 3):
def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, channels = 3, dropout = 0., emb_dropout = 0.):
super().__init__()
assert image_size % patch_size == 0, 'image dimensions must be divisible by the patch size'
num_patches = (image_size // patch_size) ** 2
patch_dim = channels * patch_size ** 2
assert num_patches > MIN_NUM_PATCHES, f'your number of patches ({num_patches}) is way too small for attention to be effective. try decreasing your patch size'
self.patch_size = patch_size
self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
self.patch_to_embedding = nn.Linear(patch_dim, dim)
self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
self.transformer = Transformer(dim, depth, heads, mlp_dim)
self.dropout = nn.Dropout(emb_dropout)
self.transformer = Transformer(dim, depth, heads, mlp_dim, dropout)
self.to_cls_token = nn.Identity()
self.mlp_head = nn.Sequential(
nn.LayerNorm(dim),
nn.Linear(dim, mlp_dim),
nn.GELU(),
nn.Dropout(dropout),
nn.Linear(mlp_dim, num_classes)
)
def forward(self, img):
def forward(self, img, mask = None):
p = self.patch_size
x = rearrange(img, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = p, p2 = p)
x = self.patch_to_embedding(x)
b, n, _ = x.shape
cls_tokens = self.cls_token.expand(img.shape[0], -1, -1)
cls_tokens = self.cls_token.expand(b, -1, -1)
x = torch.cat((cls_tokens, x), dim=1)
x += self.pos_embedding
x = self.transformer(x)
x += self.pos_embedding[:, :(n + 1)]
x = self.dropout(x)
x = self.transformer(x, mask)
x = self.to_cls_token(x[:, 0])
return self.mlp_head(x)
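To make the new masking path concrete: `Attention.forward` above prepends an always-True entry for the cls token, then takes an outer product to build the pairwise (query, key) mask used by `masked_fill_`. A small standalone sketch of just those steps (not part of the diff):
```python
import torch
import torch.nn.functional as F

# a batch of one image split into a 2 x 2 grid of patches; True means the patch is attended to
patch_mask = torch.ones(1, 2, 2).bool()
patch_mask[0, 1, 1] = False  # mask out the bottom-right patch

# same steps as in Attention.forward
mask = F.pad(patch_mask.flatten(1), (1, 0), value = True)  # (1, 5): cls token is always kept
pairwise = mask[:, None, :] * mask[:, :, None]             # (1, 5, 5): pairwise (query, key) mask

print(pairwise.int())
# rows and columns belonging to the masked patch are zero,
# so the corresponding attention logits get filled with -inf
```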