mirror of
https://github.com/lucidrains/vit-pytorch.git
synced 2025-12-30 08:02:29 +00:00
Compare commits
16 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
6c8dfc185e | ||
|
|
4f84ad7a64 | ||
|
|
c74bc781f0 | ||
|
|
dc5b89c942 | ||
|
|
c1043ab00c | ||
|
|
7a214d7109 | ||
|
|
6d1df1a970 | ||
|
|
d65a8c17a5 | ||
|
|
f7c164d910 | ||
|
|
c7b74e0bc3 | ||
|
|
5b5d98a3a7 | ||
|
|
b0e4790c24 | ||
|
|
0b2b3fc20c | ||
|
|
ced464dcb4 | ||
|
|
5bf45a2d4d | ||
|
|
fa32e22855 |
32
README.md
32
README.md
@@ -4,6 +4,8 @@
|
||||
|
||||
Implementation of <a href="https://openreview.net/pdf?id=YicbFdNTTy">Vision Transformer</a>, a simple way to achieve SOTA in vision classification with only a single transformer encoder, in Pytorch. Significance is further explained in <a href="https://www.youtube.com/watch?v=TrdevFK_am4">Yannic Kilcher's</a> video. There's really not much to code here, but may as well lay it out for everyone so we expedite the attention revolution.
|
||||
|
||||
For a Pytorch implementation with pretrained models, please see Ross Wightman's repository <a href="https://github.com/rwightman/pytorch-image-models">here</a>
|
||||
|
||||
## Install
|
||||
|
||||
```bash
|
||||
@@ -23,7 +25,9 @@ v = ViT(
|
||||
dim = 1024,
|
||||
depth = 6,
|
||||
heads = 8,
|
||||
mlp_dim = 2048
|
||||
mlp_dim = 2048,
|
||||
dropout = 0.1,
|
||||
emb_dropout = 0.1
|
||||
)
|
||||
|
||||
img = torch.randn(1, 3, 256, 256)
|
||||
@@ -124,13 +128,23 @@ Other sparse attention frameworks I would highly recommend is <a href="https://g
|
||||
## Citations
|
||||
|
||||
```bibtex
|
||||
@inproceedings{
|
||||
anonymous2021an,
|
||||
title={An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale},
|
||||
author={Anonymous},
|
||||
booktitle={Submitted to International Conference on Learning Representations},
|
||||
year={2021},
|
||||
url={https://openreview.net/forum?id=YicbFdNTTy},
|
||||
note={under review}
|
||||
@misc{dosovitskiy2020image,
|
||||
title = {An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale},
|
||||
author = {Alexey Dosovitskiy and Lucas Beyer and Alexander Kolesnikov and Dirk Weissenborn and Xiaohua Zhai and Thomas Unterthiner and Mostafa Dehghani and Matthias Minderer and Georg Heigold and Sylvain Gelly and Jakob Uszkoreit and Neil Houlsby},
|
||||
year = {2020},
|
||||
eprint = {2010.11929},
|
||||
archivePrefix = {arXiv},
|
||||
primaryClass = {cs.CV}
|
||||
}
|
||||
```
|
||||
|
||||
```bibtex
|
||||
@misc{vaswani2017attention,
|
||||
title = {Attention Is All You Need},
|
||||
author = {Ashish Vaswani and Noam Shazeer and Niki Parmar and Jakob Uszkoreit and Llion Jones and Aidan N. Gomez and Lukasz Kaiser and Illia Polosukhin},
|
||||
year = {2017},
|
||||
eprint = {1706.03762},
|
||||
archivePrefix = {arXiv},
|
||||
primaryClass = {cs.CL}
|
||||
}
|
||||
```
|
||||
|
||||
6253
examples/cats_and_dogs.ipynb
Normal file
6253
examples/cats_and_dogs.ipynb
Normal file
File diff suppressed because one or more lines are too long
6
setup.py
6
setup.py
@@ -2,8 +2,8 @@ from setuptools import setup, find_packages
|
||||
|
||||
setup(
|
||||
name = 'vit-pytorch',
|
||||
packages = find_packages(),
|
||||
version = '0.2.0',
|
||||
packages = find_packages(exclude=['examples']),
|
||||
version = '0.2.7',
|
||||
license='MIT',
|
||||
description = 'Vision Transformer (ViT) - Pytorch',
|
||||
author = 'Phil Wang',
|
||||
@@ -25,4 +25,4 @@ setup(
|
||||
'License :: OSI Approved :: MIT License',
|
||||
'Programming Language :: Python :: 3.6',
|
||||
],
|
||||
)
|
||||
)
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import torch
|
||||
from einops import rearrange
|
||||
from einops import rearrange, repeat
|
||||
from torch import nn
|
||||
|
||||
class ViT(nn.Module):
|
||||
@@ -30,10 +30,11 @@ class ViT(nn.Module):
|
||||
|
||||
x = rearrange(img, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = p, p2 = p)
|
||||
x = self.patch_to_embedding(x)
|
||||
b, n, _ = x.shape
|
||||
|
||||
cls_tokens = self.cls_token.expand(img.shape[0], -1, -1)
|
||||
cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b)
|
||||
x = torch.cat((cls_tokens, x), dim=1)
|
||||
x += self.pos_embedding
|
||||
x += self.pos_embedding[:, :(n + 1)]
|
||||
x = self.transformer(x)
|
||||
|
||||
x = self.to_cls_token(x[:, 0])
|
||||
|
||||
@@ -1,8 +1,10 @@
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from einops import rearrange
|
||||
from einops import rearrange, repeat
|
||||
from torch import nn
|
||||
|
||||
MIN_NUM_PATCHES = 16
|
||||
|
||||
class Residual(nn.Module):
|
||||
def __init__(self, fn):
|
||||
super().__init__()
|
||||
@@ -19,36 +21,43 @@ class PreNorm(nn.Module):
|
||||
return self.fn(self.norm(x), **kwargs)
|
||||
|
||||
class FeedForward(nn.Module):
|
||||
def __init__(self, dim, hidden_dim):
|
||||
def __init__(self, dim, hidden_dim, dropout = 0.):
|
||||
super().__init__()
|
||||
self.net = nn.Sequential(
|
||||
nn.Linear(dim, hidden_dim),
|
||||
nn.GELU(),
|
||||
nn.Linear(hidden_dim, dim)
|
||||
nn.Dropout(dropout),
|
||||
nn.Linear(hidden_dim, dim),
|
||||
nn.Dropout(dropout)
|
||||
)
|
||||
def forward(self, x):
|
||||
return self.net(x)
|
||||
|
||||
class Attention(nn.Module):
|
||||
def __init__(self, dim, heads = 8):
|
||||
def __init__(self, dim, heads = 8, dropout = 0.):
|
||||
super().__init__()
|
||||
self.heads = heads
|
||||
self.scale = dim ** -0.5
|
||||
|
||||
self.to_qkv = nn.Linear(dim, dim * 3, bias = False)
|
||||
self.to_out = nn.Linear(dim, dim)
|
||||
self.to_out = nn.Sequential(
|
||||
nn.Linear(dim, dim),
|
||||
nn.Dropout(dropout)
|
||||
)
|
||||
|
||||
def forward(self, x, mask = None):
|
||||
b, n, _, h = *x.shape, self.heads
|
||||
qkv = self.to_qkv(x)
|
||||
q, k, v = rearrange(qkv, 'b n (qkv h d) -> qkv b h n d', qkv = 3, h = h)
|
||||
qkv = self.to_qkv(x).chunk(3, dim = -1)
|
||||
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv)
|
||||
|
||||
dots = torch.einsum('bhid,bhjd->bhij', q, k) * self.scale
|
||||
mask_value = -torch.finfo(dots.dtype).max
|
||||
|
||||
if mask is not None:
|
||||
mask = F.pad(mask.flatten(1), (1, 0), value = True)
|
||||
assert mask.shape[-1] == dots.shape[-1], 'mask has incorrect dimensions'
|
||||
mask = mask[:, None, :] * mask[:, :, None]
|
||||
dots.masked_fill_(~mask, float('-inf'))
|
||||
dots.masked_fill_(~mask, mask_value)
|
||||
del mask
|
||||
|
||||
attn = dots.softmax(dim=-1)
|
||||
@@ -59,13 +68,13 @@ class Attention(nn.Module):
|
||||
return out
|
||||
|
||||
class Transformer(nn.Module):
|
||||
def __init__(self, dim, depth, heads, mlp_dim):
|
||||
def __init__(self, dim, depth, heads, mlp_dim, dropout):
|
||||
super().__init__()
|
||||
self.layers = nn.ModuleList([])
|
||||
for _ in range(depth):
|
||||
self.layers.append(nn.ModuleList([
|
||||
Residual(PreNorm(dim, Attention(dim, heads = heads))),
|
||||
Residual(PreNorm(dim, FeedForward(dim, mlp_dim)))
|
||||
Residual(PreNorm(dim, Attention(dim, heads = heads, dropout = dropout))),
|
||||
Residual(PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout)))
|
||||
]))
|
||||
def forward(self, x, mask = None):
|
||||
for attn, ff in self.layers:
|
||||
@@ -74,18 +83,21 @@ class Transformer(nn.Module):
|
||||
return x
|
||||
|
||||
class ViT(nn.Module):
|
||||
def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, channels = 3):
|
||||
def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, channels = 3, dropout = 0., emb_dropout = 0.):
|
||||
super().__init__()
|
||||
assert image_size % patch_size == 0, 'image dimensions must be divisible by the patch size'
|
||||
num_patches = (image_size // patch_size) ** 2
|
||||
patch_dim = channels * patch_size ** 2
|
||||
assert num_patches > MIN_NUM_PATCHES, f'your number of patches ({num_patches}) is way too small for attention to be effective. try decreasing your patch size'
|
||||
|
||||
self.patch_size = patch_size
|
||||
|
||||
self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
|
||||
self.patch_to_embedding = nn.Linear(patch_dim, dim)
|
||||
self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
|
||||
self.transformer = Transformer(dim, depth, heads, mlp_dim)
|
||||
self.dropout = nn.Dropout(emb_dropout)
|
||||
|
||||
self.transformer = Transformer(dim, depth, heads, mlp_dim, dropout)
|
||||
|
||||
self.to_cls_token = nn.Identity()
|
||||
|
||||
@@ -93,6 +105,7 @@ class ViT(nn.Module):
|
||||
nn.LayerNorm(dim),
|
||||
nn.Linear(dim, mlp_dim),
|
||||
nn.GELU(),
|
||||
nn.Dropout(dropout),
|
||||
nn.Linear(mlp_dim, num_classes)
|
||||
)
|
||||
|
||||
@@ -101,10 +114,13 @@ class ViT(nn.Module):
|
||||
|
||||
x = rearrange(img, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = p, p2 = p)
|
||||
x = self.patch_to_embedding(x)
|
||||
b, n, _ = x.shape
|
||||
|
||||
cls_tokens = self.cls_token.expand(img.shape[0], -1, -1)
|
||||
cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b)
|
||||
x = torch.cat((cls_tokens, x), dim=1)
|
||||
x += self.pos_embedding
|
||||
x += self.pos_embedding[:, :(n + 1)]
|
||||
x = self.dropout(x)
|
||||
|
||||
x = self.transformer(x, mask)
|
||||
|
||||
x = self.to_cls_token(x[:, 0])
|
||||
|
||||
Reference in New Issue
Block a user