Compare commits

...

11 Commits
0.0.3 ... 0.2.2

Author SHA1 Message Date
Phil Wang
35796104b0 dropouts are more specific and aggressive in the paper, thanks for letting me know @hilach70 2020-10-14 05:48:27 -07:00
Phil Wang
b0e4790c24 bump package 2020-10-13 13:12:19 -07:00
Phil Wang
0b2b3fc20c add dropouts 2020-10-13 13:11:59 -07:00
Phil Wang
ced464dcb4 Update setup.py 2020-10-11 00:06:26 -07:00
Phil Wang
5bf45a2d4d Merge pull request #4 from adimyth/main
Image Classification Example
2020-10-10 19:12:31 -07:00
adimyth
fa32e22855 adds a classification example using 'cats & dogs' data 2020-10-11 03:15:19 +05:30
Phil Wang
a0fa41070f norm cls token before sending to mlp head 2020-10-10 12:08:42 -07:00
Phil Wang
b298031c17 write up example for using efficient transformers 2020-10-07 19:15:21 -07:00
Phil Wang
d66b29e4cf cleanup stray print 2020-10-07 11:22:45 -07:00
Phil Wang
f7123720c3 add masking 2020-10-07 11:21:03 -07:00
Phil Wang
f5fffd9e2e remove extraneous line 2020-10-04 15:22:26 -07:00
5 changed files with 6389 additions and 30 deletions

View File

@@ -23,14 +23,20 @@ v = ViT(
dim = 1024,
depth = 6,
heads = 8,
mlp_dim = 2048
mlp_dim = 2048,
dropout = 0.1,
emb_dropout = 0.1
)
img = torch.randn(1, 3, 256, 256)
preds = v(img) # (1, 1000)
mask = torch.ones(1, 8, 8).bool() # optional mask, designating which patch to attend to
preds = v(img, mask = mask) # (1, 1000)
```
## Suggestion
## Research Ideas
### Self Supervised Training
You can train this with a near SOTA self-supervised learning technique, <a href="https://github.com/lucidrains/byol-pytorch">BYOL</a>, with the following code.
@@ -80,6 +86,43 @@ torch.save(model.state_dict(), './pretrained-net.pt')
A pytorch-lightning script is ready for you to use at the repository link above.
### Efficient Attention
There may be some coming from computer vision who think attention still suffers from quadratic costs. Fortunately, we have a lot of new techniques that may help. This repository offers a way for you to plugin your own sparse attention transformer.
An example with <a href="https://arxiv.org/abs/2006.04768">Linformer</a>
```bash
$ pip install linformer
```
```python
import torch
from vit_pytorch.efficient import ViT
from linformer import Linformer
efficient_transformer = Linformer(
dim = 512,
seq_len = 4096 + 1, # 64 x 64 patches + 1 cls token
depth = 12,
heads = 8,
k = 256
)
v = ViT(
dim = 512,
image_size = 2048,
patch_size = 32,
num_classes = 1000,
transformer = efficient_transformer
)
img = torch.randn(1, 3, 2048, 2048) # your high resolution picture
v(img) # (1, 1000)
```
Other sparse attention frameworks I would highly recommend is <a href="https://github.com/lucidrains/routing-transformer">Routing Transformer</a> or <a href="https://github.com/lucidrains/sinkhorn-transformer">Sinkhorn Transformer</a>
## Citations
```bibtex

File diff suppressed because one or more lines are too long

View File

@@ -2,8 +2,8 @@ from setuptools import setup, find_packages
setup(
name = 'vit-pytorch',
packages = find_packages(),
version = '0.0.3',
packages = find_packages(exclude=['examples']),
version = '0.2.2',
license='MIT',
description = 'Vision Transformer (ViT) - Pytorch',
author = 'Phil Wang',
@@ -25,4 +25,4 @@ setup(
'License :: OSI Approved :: MIT License',
'Programming Language :: Python :: 3.6',
],
)
)

40
vit_pytorch/efficient.py Normal file
View File

@@ -0,0 +1,40 @@
import torch
from einops import rearrange
from torch import nn
class ViT(nn.Module):
def __init__(self, *, image_size, patch_size, num_classes, dim, transformer, channels = 3):
super().__init__()
assert image_size % patch_size == 0, 'image dimensions must be divisible by the patch size'
num_patches = (image_size // patch_size) ** 2
patch_dim = channels * patch_size ** 2
self.patch_size = patch_size
self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
self.patch_to_embedding = nn.Linear(patch_dim, dim)
self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
self.transformer = transformer
self.to_cls_token = nn.Identity()
self.mlp_head = nn.Sequential(
nn.LayerNorm(dim),
nn.Linear(dim, dim * 4),
nn.GELU(),
nn.Linear(dim * 4, num_classes)
)
def forward(self, img):
p = self.patch_size
x = rearrange(img, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = p, p2 = p)
x = self.patch_to_embedding(x)
cls_tokens = self.cls_token.expand(img.shape[0], -1, -1)
x = torch.cat((cls_tokens, x), dim=1)
x += self.pos_embedding
x = self.transformer(x)
x = self.to_cls_token(x[:, 0])
return self.mlp_head(x)

View File

@@ -1,48 +1,62 @@
import torch
from einops import rearrange
import torch.nn.functional as F
from einops import rearrange
from torch import nn
class Residual(nn.Module):
def __init__(self, fn):
super().__init__()
self.fn = fn
def forward(self, x):
return self.fn(x) + x
def forward(self, x, **kwargs):
return self.fn(x, **kwargs) + x
class PreNorm(nn.Module):
def __init__(self, dim, fn):
super().__init__()
self.norm = nn.LayerNorm(dim)
self.fn = fn
def forward(self, x):
return self.fn(self.norm(x))
def forward(self, x, **kwargs):
return self.fn(self.norm(x), **kwargs)
class FeedForward(nn.Module):
def __init__(self, dim, hidden_dim):
def __init__(self, dim, hidden_dim, dropout = 0.):
super().__init__()
self.net = nn.Sequential(
nn.Linear(dim, hidden_dim),
nn.GELU(),
nn.Linear(hidden_dim, dim)
nn.Dropout(dropout),
nn.Linear(hidden_dim, dim),
nn.Dropout(dropout)
)
def forward(self, x):
return self.net(x)
class Attention(nn.Module):
def __init__(self, dim, heads = 8):
def __init__(self, dim, heads = 8, dropout = 0.):
super().__init__()
self.heads = heads
self.scale = dim ** -0.5
self.to_qkv = nn.Linear(dim, dim * 3, bias = False)
self.to_out = nn.Linear(dim, dim)
def forward(self, x):
self.to_out = nn.Sequential(
nn.Linear(dim, dim),
nn.Dropout(dropout)
)
def forward(self, x, mask = None):
b, n, _, h = *x.shape, self.heads
qkv = self.to_qkv(x)
q, k, v = rearrange(qkv, 'b n (qkv h d) -> qkv b h n d', qkv = 3, h = h)
dots = torch.einsum('bhid,bhjd->bhij', q, k) * self.scale
if mask is not None:
mask = F.pad(mask.flatten(1), (1, 0), value = True)
assert mask.shape[-1] == dots.shape[-1], 'mask has incorrect dimensions'
mask = mask[:, None, :] * mask[:, :, None]
dots.masked_fill_(~mask, float('-inf'))
del mask
attn = dots.softmax(dim=-1)
out = torch.einsum('bhij,bhjd->bhid', attn, v)
@@ -51,20 +65,22 @@ class Attention(nn.Module):
return out
class Transformer(nn.Module):
def __init__(self, dim, depth, heads, mlp_dim):
def __init__(self, dim, depth, heads, mlp_dim, dropout):
super().__init__()
layers = []
self.layers = nn.ModuleList([])
for _ in range(depth):
layers.extend([
Residual(PreNorm(dim, Attention(dim, heads = heads))),
Residual(PreNorm(dim, FeedForward(dim, mlp_dim)))
])
self.net = nn.Sequential(*layers)
def forward(self, x):
return self.net(x)
self.layers.append(nn.ModuleList([
Residual(PreNorm(dim, Attention(dim, heads = heads, dropout = dropout))),
Residual(PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout)))
]))
def forward(self, x, mask = None):
for attn, ff in self.layers:
x = attn(x, mask = mask)
x = ff(x)
return x
class ViT(nn.Module):
def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, channels = 3):
def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, channels = 3, dropout = 0., emb_dropout = 0.):
super().__init__()
assert image_size % patch_size == 0, 'image dimensions must be divisible by the patch size'
num_patches = (image_size // patch_size) ** 2
@@ -75,17 +91,22 @@ class ViT(nn.Module):
self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
self.patch_to_embedding = nn.Linear(patch_dim, dim)
self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
self.transformer = Transformer(dim, depth, heads, mlp_dim)
self.dropout = nn.Dropout(emb_dropout)
self.transformer = Transformer(dim, depth, heads, mlp_dim, dropout)
self.to_cls_token = nn.Identity()
self.mlp_head = nn.Sequential(
nn.LayerNorm(dim),
nn.Linear(dim, mlp_dim),
nn.GELU(),
nn.Linear(mlp_dim, num_classes)
nn.Dropout(dropout),
nn.Linear(mlp_dim, num_classes),
nn.Dropout(dropout)
)
def forward(self, img):
def forward(self, img, mask = None):
p = self.patch_size
x = rearrange(img, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = p, p2 = p)
@@ -94,7 +115,9 @@ class ViT(nn.Module):
cls_tokens = self.cls_token.expand(img.shape[0], -1, -1)
x = torch.cat((cls_tokens, x), dim=1)
x += self.pos_embedding
x = self.transformer(x)
x = self.dropout(x)
x = self.transformer(x, mask)
x = self.to_cls_token(x[:, 0])
return self.mlp_head(x)