Mirror of https://github.com/lucidrains/vit-pytorch.git (synced 2026-02-16 16:20:16 +00:00)
Compare commits
11 Commits

| Author | SHA1 | Date |
|---|---|---|
| | 6c8dfc185e | |
| | 4f84ad7a64 | |
| | c74bc781f0 | |
| | dc5b89c942 | |
| | c1043ab00c | |
| | 7a214d7109 | |
| | 6d1df1a970 | |
| | d65a8c17a5 | |
| | f7c164d910 | |
| | c7b74e0bc3 | |
| | 5b5d98a3a7 | |
README.md (28 changed lines)
````diff
@@ -4,6 +4,8 @@
 Implementation of <a href="https://openreview.net/pdf?id=YicbFdNTTy">Vision Transformer</a>, a simple way to achieve SOTA in vision classification with only a single transformer encoder, in Pytorch. Significance is further explained in <a href="https://www.youtube.com/watch?v=TrdevFK_am4">Yannic Kilcher's</a> video. There's really not much to code here, but may as well lay it out for everyone so we expedite the attention revolution.
 
+For a Pytorch implementation with pretrained models, please see Ross Wightman's repository <a href="https://github.com/rwightman/pytorch-image-models">here</a>
+
 ## Install
 
 ```bash
````
````diff
@@ -126,13 +128,23 @@ Other sparse attention frameworks I would highly recommend is <a href="https://g
 ## Citations
 
 ```bibtex
-@inproceedings{
-anonymous2021an,
-title={An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale},
-author={Anonymous},
-booktitle={Submitted to International Conference on Learning Representations},
-year={2021},
-url={https://openreview.net/forum?id=YicbFdNTTy},
-note={under review}
+@misc{dosovitskiy2020image,
+    title   = {An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale},
+    author  = {Alexey Dosovitskiy and Lucas Beyer and Alexander Kolesnikov and Dirk Weissenborn and Xiaohua Zhai and Thomas Unterthiner and Mostafa Dehghani and Matthias Minderer and Georg Heigold and Sylvain Gelly and Jakob Uszkoreit and Neil Houlsby},
+    year    = {2020},
+    eprint  = {2010.11929},
+    archivePrefix = {arXiv},
+    primaryClass = {cs.CV}
 }
 ```
+
+```bibtex
+@misc{vaswani2017attention,
+    title   = {Attention Is All You Need},
+    author  = {Ashish Vaswani and Noam Shazeer and Niki Parmar and Jakob Uszkoreit and Llion Jones and Aidan N. Gomez and Lukasz Kaiser and Illia Polosukhin},
+    year    = {2017},
+    eprint  = {1706.03762},
+    archivePrefix = {arXiv},
+    primaryClass = {cs.CL}
+}
+```
````
setup.py (2 changed lines)
```diff
@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
 setup(
   name = 'vit-pytorch',
   packages = find_packages(exclude=['examples']),
-  version = '0.2.2',
+  version = '0.2.7',
   license='MIT',
   description = 'Vision Transformer (ViT) - Pytorch',
   author = 'Phil Wang',
```
```diff
@@ -1,5 +1,5 @@
 import torch
-from einops import rearrange
+from einops import rearrange, repeat
 from torch import nn
 
 class ViT(nn.Module):
```
```diff
@@ -30,10 +30,11 @@ class ViT(nn.Module):
 
         x = rearrange(img, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = p, p2 = p)
         x = self.patch_to_embedding(x)
+        b, n, _ = x.shape
 
-        cls_tokens = self.cls_token.expand(img.shape[0], -1, -1)
+        cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b)
         x = torch.cat((cls_tokens, x), dim=1)
-        x += self.pos_embedding
+        x += self.pos_embedding[:, :(n + 1)]
         x = self.transformer(x)
 
         x = self.to_cls_token(x[:, 0])
```
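The change above swaps `Tensor.expand` for `einops.repeat` when broadcasting the class token over the batch, and slices the positional embedding down to the number of tokens actually present. A minimal sketch of the same pattern with toy shapes (illustrative sizes only, not the module's real dimensions):

```python
import torch
from einops import repeat

dim, n = 4, 3                               # toy embedding size and patch count
cls_token = torch.randn(1, 1, dim)          # stands in for the learned (1, 1, dim) parameter
pos_embedding = torch.randn(1, 10, dim)     # sized for up to 10 tokens in this sketch

x = torch.randn(2, n, dim)                  # batch of 2 images, already patch-embedded
b, n, _ = x.shape

# one copy of the cls token per batch element
cls_tokens = repeat(cls_token, '() n d -> b n d', b = b)   # (2, 1, dim)
x = torch.cat((cls_tokens, x), dim = 1)                    # (2, n + 1, dim)

# only add as many positional embeddings as there are tokens
x = x + pos_embedding[:, :(n + 1)]
print(x.shape)  # torch.Size([2, 4, 4])
```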
```diff
@@ -1,8 +1,10 @@
 import torch
 import torch.nn.functional as F
-from einops import rearrange
+from einops import rearrange, repeat
 from torch import nn
 
+MIN_NUM_PATCHES = 16
+
 class Residual(nn.Module):
     def __init__(self, fn):
         super().__init__()
```
```diff
@@ -45,16 +47,17 @@ class Attention(nn.Module):
 
     def forward(self, x, mask = None):
         b, n, _, h = *x.shape, self.heads
-        qkv = self.to_qkv(x)
-        q, k, v = rearrange(qkv, 'b n (qkv h d) -> qkv b h n d', qkv = 3, h = h)
+        qkv = self.to_qkv(x).chunk(3, dim = -1)
+        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv)
 
         dots = torch.einsum('bhid,bhjd->bhij', q, k) * self.scale
+        mask_value = -torch.finfo(dots.dtype).max
 
         if mask is not None:
             mask = F.pad(mask.flatten(1), (1, 0), value = True)
             assert mask.shape[-1] == dots.shape[-1], 'mask has incorrect dimensions'
             mask = mask[:, None, :] * mask[:, :, None]
-            dots.masked_fill_(~mask, float('-inf'))
+            dots.masked_fill_(~mask, mask_value)
             del mask
 
         attn = dots.softmax(dim=-1)
```
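Two things change in the attention block above: the fused qkv projection is now split with `chunk` and reshaped per head via `rearrange`, and masked positions are filled with the most negative finite value for the tensor's dtype rather than `float('-inf')`, so a fully masked row softmaxes to a uniform distribution instead of NaN. A self-contained sketch of that plumbing, with arbitrary toy sizes:

```python
import torch
from torch import nn
from einops import rearrange

b, n, heads, dim_head = 2, 5, 4, 8          # toy sizes for the sketch
inner_dim = heads * dim_head
to_qkv = nn.Linear(inner_dim, inner_dim * 3, bias = False)

x = torch.randn(b, n, inner_dim)

# single projection, split into q, k, v, then separate out the heads
qkv = to_qkv(x).chunk(3, dim = -1)
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = heads), qkv)

dots = torch.einsum('bhid,bhjd->bhij', q, k) * dim_head ** -0.5   # (b, h, n, n)

# largest-magnitude negative finite value for this dtype, instead of float('-inf')
mask_value = -torch.finfo(dots.dtype).max
mask = torch.rand(b, 1, n, n) > 0.5         # arbitrary boolean mask, just for illustration
dots = dots.masked_fill(~mask, mask_value)

attn = dots.softmax(dim = -1)               # stays finite even if a whole row is masked out
out = torch.einsum('bhij,bhjd->bhid', attn, v)
print(out.shape)  # torch.Size([2, 4, 5, 8])
```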
```diff
@@ -85,6 +88,7 @@ class ViT(nn.Module):
         assert image_size % patch_size == 0, 'image dimensions must be divisible by the patch size'
         num_patches = (image_size // patch_size) ** 2
         patch_dim = channels * patch_size ** 2
+        assert num_patches > MIN_NUM_PATCHES, f'your number of patches ({num_patches}) is way too small for attention to be effective. try decreasing your patch size'
 
         self.patch_size = patch_size
 
```
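The new assert above puts a floor on the patch count: with `MIN_NUM_PATCHES = 16`, a 256-pixel image with `patch_size = 32` gives (256 // 32) ** 2 = 64 patches and passes, while `patch_size = 64` gives exactly 16 and fails, since the check is strictly greater-than. The same arithmetic as a quick standalone check:

```python
MIN_NUM_PATCHES = 16

def patch_count_ok(image_size, patch_size):
    # mirrors the two asserts in the constructor
    assert image_size % patch_size == 0, 'image dimensions must be divisible by the patch size'
    num_patches = (image_size // patch_size) ** 2
    return num_patches, num_patches > MIN_NUM_PATCHES

print(patch_count_ok(256, 32))   # (64, True)   passes
print(patch_count_ok(256, 64))   # (16, False)  would trip the assert
```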
```diff
@@ -102,8 +106,7 @@ class ViT(nn.Module):
             nn.Linear(dim, mlp_dim),
             nn.GELU(),
             nn.Dropout(dropout),
-            nn.Linear(mlp_dim, num_classes),
-            nn.Dropout(dropout)
+            nn.Linear(mlp_dim, num_classes)
         )
 
     def forward(self, img, mask = None):
```
```diff
@@ -111,10 +114,11 @@ class ViT(nn.Module):
 
         x = rearrange(img, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = p, p2 = p)
         x = self.patch_to_embedding(x)
+        b, n, _ = x.shape
 
-        cls_tokens = self.cls_token.expand(img.shape[0], -1, -1)
+        cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b)
         x = torch.cat((cls_tokens, x), dim=1)
-        x += self.pos_embedding
+        x += self.pos_embedding[:, :(n + 1)]
         x = self.dropout(x)
 
         x = self.transformer(x, mask)
```
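For reference, a hedged usage sketch against this version of the package, exercising the optional patch mask that the forward pass above threads through to attention. The constructor keywords are assumed to follow the README of the repository around this release and may differ in later versions:

```python
import torch
from vit_pytorch import ViT

v = ViT(
    image_size = 256,
    patch_size = 32,     # (256 // 32) ** 2 = 64 patches, above MIN_NUM_PATCHES
    num_classes = 1000,
    dim = 1024,
    depth = 6,
    heads = 8,
    mlp_dim = 2048,
    dropout = 0.1
)

img = torch.randn(1, 3, 256, 256)
mask = torch.ones(1, 8, 8).bool()   # optional: one boolean per patch (8 x 8 grid)

preds = v(img, mask = mask)         # (1, 1000)
```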