mirror of
https://github.com/lucidrains/vit-pytorch.git
synced 2026-01-05 12:12:26 +00:00
Compare commits
7 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
b0e4790c24 | ||
|
|
0b2b3fc20c | ||
|
|
ced464dcb4 | ||
|
|
5bf45a2d4d | ||
|
|
fa32e22855 | ||
|
|
a0fa41070f | ||
|
|
b298031c17 |
45
README.md
45
README.md
@@ -23,7 +23,9 @@ v = ViT(
|
||||
dim = 1024,
|
||||
depth = 6,
|
||||
heads = 8,
|
||||
mlp_dim = 2048
|
||||
mlp_dim = 2048,
|
||||
attn_dropout = 0.1,
|
||||
ff_dropout = 0.1
|
||||
)
|
||||
|
||||
img = torch.randn(1, 3, 256, 256)
|
||||
@@ -32,7 +34,9 @@ mask = torch.ones(1, 8, 8).bool() # optional mask, designating which patch to at
|
||||
preds = v(img, mask = mask) # (1, 1000)
|
||||
```
|
||||
|
||||
## Suggestion
|
||||
## Research Ideas
|
||||
|
||||
### Self Supervised Training
|
||||
|
||||
You can train this with a near SOTA self-supervised learning technique, <a href="https://github.com/lucidrains/byol-pytorch">BYOL</a>, with the following code.
|
||||
|
||||
@@ -82,6 +86,43 @@ torch.save(model.state_dict(), './pretrained-net.pt')
|
||||
|
||||
A pytorch-lightning script is ready for you to use at the repository link above.
|
||||
|
||||
### Efficient Attention
|
||||
|
||||
There may be some coming from computer vision who think attention still suffers from quadratic costs. Fortunately, we have a lot of new techniques that may help. This repository offers a way for you to plugin your own sparse attention transformer.
|
||||
|
||||
An example with <a href="https://arxiv.org/abs/2006.04768">Linformer</a>
|
||||
|
||||
```bash
|
||||
$ pip install linformer
|
||||
```
|
||||
|
||||
```python
|
||||
import torch
|
||||
from vit_pytorch.efficient import ViT
|
||||
from linformer import Linformer
|
||||
|
||||
efficient_transformer = Linformer(
|
||||
dim = 512,
|
||||
seq_len = 4096 + 1, # 64 x 64 patches + 1 cls token
|
||||
depth = 12,
|
||||
heads = 8,
|
||||
k = 256
|
||||
)
|
||||
|
||||
v = ViT(
|
||||
dim = 512,
|
||||
image_size = 2048,
|
||||
patch_size = 32,
|
||||
num_classes = 1000,
|
||||
transformer = efficient_transformer
|
||||
)
|
||||
|
||||
img = torch.randn(1, 3, 2048, 2048) # your high resolution picture
|
||||
v(img) # (1, 1000)
|
||||
```
|
||||
|
||||
Other sparse attention frameworks I would highly recommend is <a href="https://github.com/lucidrains/routing-transformer">Routing Transformer</a> or <a href="https://github.com/lucidrains/sinkhorn-transformer">Sinkhorn Transformer</a>
|
||||
|
||||
## Citations
|
||||
|
||||
```bibtex
|
||||
|
||||
6253
examples/VisualTransformer | Cats&Dogs Edition.ipynb
Normal file
6253
examples/VisualTransformer | Cats&Dogs Edition.ipynb
Normal file
File diff suppressed because one or more lines are too long
6
setup.py
6
setup.py
@@ -2,8 +2,8 @@ from setuptools import setup, find_packages
|
||||
|
||||
setup(
|
||||
name = 'vit-pytorch',
|
||||
packages = find_packages(),
|
||||
version = '0.0.5',
|
||||
packages = find_packages(exclude=['examples']),
|
||||
version = '0.2.1',
|
||||
license='MIT',
|
||||
description = 'Vision Transformer (ViT) - Pytorch',
|
||||
author = 'Phil Wang',
|
||||
@@ -25,4 +25,4 @@ setup(
|
||||
'License :: OSI Approved :: MIT License',
|
||||
'Programming Language :: Python :: 3.6',
|
||||
],
|
||||
)
|
||||
)
|
||||
|
||||
40
vit_pytorch/efficient.py
Normal file
40
vit_pytorch/efficient.py
Normal file
@@ -0,0 +1,40 @@
|
||||
import torch
|
||||
from einops import rearrange
|
||||
from torch import nn
|
||||
|
||||
class ViT(nn.Module):
|
||||
def __init__(self, *, image_size, patch_size, num_classes, dim, transformer, channels = 3):
|
||||
super().__init__()
|
||||
assert image_size % patch_size == 0, 'image dimensions must be divisible by the patch size'
|
||||
num_patches = (image_size // patch_size) ** 2
|
||||
patch_dim = channels * patch_size ** 2
|
||||
|
||||
self.patch_size = patch_size
|
||||
|
||||
self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
|
||||
self.patch_to_embedding = nn.Linear(patch_dim, dim)
|
||||
self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
|
||||
self.transformer = transformer
|
||||
|
||||
self.to_cls_token = nn.Identity()
|
||||
|
||||
self.mlp_head = nn.Sequential(
|
||||
nn.LayerNorm(dim),
|
||||
nn.Linear(dim, dim * 4),
|
||||
nn.GELU(),
|
||||
nn.Linear(dim * 4, num_classes)
|
||||
)
|
||||
|
||||
def forward(self, img):
|
||||
p = self.patch_size
|
||||
|
||||
x = rearrange(img, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = p, p2 = p)
|
||||
x = self.patch_to_embedding(x)
|
||||
|
||||
cls_tokens = self.cls_token.expand(img.shape[0], -1, -1)
|
||||
x = torch.cat((cls_tokens, x), dim=1)
|
||||
x += self.pos_embedding
|
||||
x = self.transformer(x)
|
||||
|
||||
x = self.to_cls_token(x[:, 0])
|
||||
return self.mlp_head(x)
|
||||
@@ -19,24 +19,26 @@ class PreNorm(nn.Module):
|
||||
return self.fn(self.norm(x), **kwargs)
|
||||
|
||||
class FeedForward(nn.Module):
|
||||
def __init__(self, dim, hidden_dim):
|
||||
def __init__(self, dim, hidden_dim, dropout = 0.):
|
||||
super().__init__()
|
||||
self.net = nn.Sequential(
|
||||
nn.Linear(dim, hidden_dim),
|
||||
nn.GELU(),
|
||||
nn.Dropout(dropout),
|
||||
nn.Linear(hidden_dim, dim)
|
||||
)
|
||||
def forward(self, x):
|
||||
return self.net(x)
|
||||
|
||||
class Attention(nn.Module):
|
||||
def __init__(self, dim, heads = 8):
|
||||
def __init__(self, dim, heads = 8, dropout = 0.):
|
||||
super().__init__()
|
||||
self.heads = heads
|
||||
self.scale = dim ** -0.5
|
||||
|
||||
self.to_qkv = nn.Linear(dim, dim * 3, bias = False)
|
||||
self.to_out = nn.Linear(dim, dim)
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
def forward(self, x, mask = None):
|
||||
b, n, _, h = *x.shape, self.heads
|
||||
qkv = self.to_qkv(x)
|
||||
@@ -52,6 +54,7 @@ class Attention(nn.Module):
|
||||
del mask
|
||||
|
||||
attn = dots.softmax(dim=-1)
|
||||
attn = self.dropout(attn)
|
||||
|
||||
out = torch.einsum('bhij,bhjd->bhid', attn, v)
|
||||
out = rearrange(out, 'b h n d -> b n (h d)')
|
||||
@@ -59,13 +62,13 @@ class Attention(nn.Module):
|
||||
return out
|
||||
|
||||
class Transformer(nn.Module):
|
||||
def __init__(self, dim, depth, heads, mlp_dim):
|
||||
def __init__(self, dim, depth, heads, mlp_dim, attn_dropout, ff_dropout):
|
||||
super().__init__()
|
||||
self.layers = nn.ModuleList([])
|
||||
for _ in range(depth):
|
||||
self.layers.append(nn.ModuleList([
|
||||
Residual(PreNorm(dim, Attention(dim, heads = heads))),
|
||||
Residual(PreNorm(dim, FeedForward(dim, mlp_dim)))
|
||||
Residual(PreNorm(dim, Attention(dim, heads = heads, dropout = attn_dropout))),
|
||||
Residual(PreNorm(dim, FeedForward(dim, mlp_dim, dropout = ff_dropout)))
|
||||
]))
|
||||
def forward(self, x, mask = None):
|
||||
for attn, ff in self.layers:
|
||||
@@ -74,7 +77,7 @@ class Transformer(nn.Module):
|
||||
return x
|
||||
|
||||
class ViT(nn.Module):
|
||||
def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, channels = 3):
|
||||
def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, channels = 3, attn_dropout = 0., ff_dropout = 0.):
|
||||
super().__init__()
|
||||
assert image_size % patch_size == 0, 'image dimensions must be divisible by the patch size'
|
||||
num_patches = (image_size // patch_size) ** 2
|
||||
@@ -85,11 +88,12 @@ class ViT(nn.Module):
|
||||
self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
|
||||
self.patch_to_embedding = nn.Linear(patch_dim, dim)
|
||||
self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
|
||||
self.transformer = Transformer(dim, depth, heads, mlp_dim)
|
||||
self.transformer = Transformer(dim, depth, heads, mlp_dim, attn_dropout, ff_dropout)
|
||||
|
||||
self.to_cls_token = nn.Identity()
|
||||
|
||||
self.mlp_head = nn.Sequential(
|
||||
nn.LayerNorm(dim),
|
||||
nn.Linear(dim, mlp_dim),
|
||||
nn.GELU(),
|
||||
nn.Linear(mlp_dim, num_classes)
|
||||
|
||||
Reference in New Issue
Block a user