mirror of
https://github.com/lucidrains/vit-pytorch.git
synced 2025-12-30 08:02:29 +00:00
Compare commits
8 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
6c8dfc185e | ||
|
|
4f84ad7a64 | ||
|
|
c74bc781f0 | ||
|
|
dc5b89c942 | ||
|
|
c1043ab00c | ||
|
|
7a214d7109 | ||
|
|
6d1df1a970 | ||
|
|
d65a8c17a5 |
28
README.md
28
README.md
@@ -4,6 +4,8 @@
|
||||
|
||||
Implementation of <a href="https://openreview.net/pdf?id=YicbFdNTTy">Vision Transformer</a>, a simple way to achieve SOTA in vision classification with only a single transformer encoder, in Pytorch. Significance is further explained in <a href="https://www.youtube.com/watch?v=TrdevFK_am4">Yannic Kilcher's</a> video. There's really not much to code here, but may as well lay it out for everyone so we expedite the attention revolution.
|
||||
|
||||
For a Pytorch implementation with pretrained models, please see Ross Wightman's repository <a href="https://github.com/rwightman/pytorch-image-models">here</a>
|
||||
|
||||
## Install
|
||||
|
||||
```bash
|
||||
@@ -126,13 +128,23 @@ Other sparse attention frameworks I would highly recommend is <a href="https://g
|
||||
## Citations
|
||||
|
||||
```bibtex
|
||||
@inproceedings{
|
||||
anonymous2021an,
|
||||
title={An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale},
|
||||
author={Anonymous},
|
||||
booktitle={Submitted to International Conference on Learning Representations},
|
||||
year={2021},
|
||||
url={https://openreview.net/forum?id=YicbFdNTTy},
|
||||
note={under review}
|
||||
@misc{dosovitskiy2020image,
|
||||
title = {An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale},
|
||||
author = {Alexey Dosovitskiy and Lucas Beyer and Alexander Kolesnikov and Dirk Weissenborn and Xiaohua Zhai and Thomas Unterthiner and Mostafa Dehghani and Matthias Minderer and Georg Heigold and Sylvain Gelly and Jakob Uszkoreit and Neil Houlsby},
|
||||
year = {2020},
|
||||
eprint = {2010.11929},
|
||||
archivePrefix = {arXiv},
|
||||
primaryClass = {cs.CV}
|
||||
}
|
||||
```
|
||||
|
||||
```bibtex
|
||||
@misc{vaswani2017attention,
|
||||
title = {Attention Is All You Need},
|
||||
author = {Ashish Vaswani and Noam Shazeer and Niki Parmar and Jakob Uszkoreit and Llion Jones and Aidan N. Gomez and Lukasz Kaiser and Illia Polosukhin},
|
||||
year = {2017},
|
||||
eprint = {1706.03762},
|
||||
archivePrefix = {arXiv},
|
||||
primaryClass = {cs.CL}
|
||||
}
|
||||
```
|
||||
|
||||
2
setup.py
2
setup.py
@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
|
||||
setup(
|
||||
name = 'vit-pytorch',
|
||||
packages = find_packages(exclude=['examples']),
|
||||
version = '0.2.3',
|
||||
version = '0.2.7',
|
||||
license='MIT',
|
||||
description = 'Vision Transformer (ViT) - Pytorch',
|
||||
author = 'Phil Wang',
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import torch
|
||||
from einops import rearrange
|
||||
from einops import rearrange, repeat
|
||||
from torch import nn
|
||||
|
||||
class ViT(nn.Module):
|
||||
@@ -30,10 +30,11 @@ class ViT(nn.Module):
|
||||
|
||||
x = rearrange(img, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = p, p2 = p)
|
||||
x = self.patch_to_embedding(x)
|
||||
b, n, _ = x.shape
|
||||
|
||||
cls_tokens = self.cls_token.expand(img.shape[0], -1, -1)
|
||||
cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b)
|
||||
x = torch.cat((cls_tokens, x), dim=1)
|
||||
x += self.pos_embedding
|
||||
x += self.pos_embedding[:, :(n + 1)]
|
||||
x = self.transformer(x)
|
||||
|
||||
x = self.to_cls_token(x[:, 0])
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from einops import rearrange
|
||||
from einops import rearrange, repeat
|
||||
from torch import nn
|
||||
|
||||
MIN_NUM_PATCHES = 16
|
||||
@@ -47,16 +47,17 @@ class Attention(nn.Module):
|
||||
|
||||
def forward(self, x, mask = None):
|
||||
b, n, _, h = *x.shape, self.heads
|
||||
qkv = self.to_qkv(x)
|
||||
q, k, v = rearrange(qkv, 'b n (qkv h d) -> qkv b h n d', qkv = 3, h = h)
|
||||
qkv = self.to_qkv(x).chunk(3, dim = -1)
|
||||
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv)
|
||||
|
||||
dots = torch.einsum('bhid,bhjd->bhij', q, k) * self.scale
|
||||
mask_value = -torch.finfo(dots.dtype).max
|
||||
|
||||
if mask is not None:
|
||||
mask = F.pad(mask.flatten(1), (1, 0), value = True)
|
||||
assert mask.shape[-1] == dots.shape[-1], 'mask has incorrect dimensions'
|
||||
mask = mask[:, None, :] * mask[:, :, None]
|
||||
dots.masked_fill_(~mask, float('-inf'))
|
||||
dots.masked_fill_(~mask, mask_value)
|
||||
del mask
|
||||
|
||||
attn = dots.softmax(dim=-1)
|
||||
@@ -105,8 +106,7 @@ class ViT(nn.Module):
|
||||
nn.Linear(dim, mlp_dim),
|
||||
nn.GELU(),
|
||||
nn.Dropout(dropout),
|
||||
nn.Linear(mlp_dim, num_classes),
|
||||
nn.Dropout(dropout)
|
||||
nn.Linear(mlp_dim, num_classes)
|
||||
)
|
||||
|
||||
def forward(self, img, mask = None):
|
||||
@@ -114,10 +114,11 @@ class ViT(nn.Module):
|
||||
|
||||
x = rearrange(img, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = p, p2 = p)
|
||||
x = self.patch_to_embedding(x)
|
||||
b, n, _ = x.shape
|
||||
|
||||
cls_tokens = self.cls_token.expand(img.shape[0], -1, -1)
|
||||
cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b)
|
||||
x = torch.cat((cls_tokens, x), dim=1)
|
||||
x += self.pos_embedding
|
||||
x += self.pos_embedding[:, :(n + 1)]
|
||||
x = self.dropout(x)
|
||||
|
||||
x = self.transformer(x, mask)
|
||||
|
||||
Reference in New Issue
Block a user