mirror of https://github.com/lucidrains/vit-pytorch.git
synced 2025-12-30 08:02:29 +00:00
first commit

README.md | 28
@@ -1,7 +1,33 @@
-## Vision Transformer - Pytorch (wip)
+## Vision Transformer - Pytorch

Implementation of Vision Transformer, a simple way to achieve SOTA in vision classification with only a single transformer encoder, in Pytorch. There's really not much to code here, but we may as well lay it all out so we expedite the attention revolution and get everyone on the same page.

## Install

```bash
$ pip install vit-pytorch
```

## Usage
```python
import torch
from vit_pytorch import ViT

v = ViT(
    image_size = 256,
    patch_size = 32,
    num_classes = 1000,
    dim = 1024,
    depth = 6,
    heads = 8,
    mlp_dim = 2048
)

img = torch.randn(1, 3, 256, 256)
preds = v(img) # (1, 1000)
```
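
For training, the returned logits can go straight into a standard classification loss. A minimal sketch of one optimization step, reusing the model `v` from the snippet above (the batch, labels, optimizer, and learning rate here are hypothetical, not part of this commit):

```python
import torch
import torch.nn.functional as F

imgs = torch.randn(4, 3, 256, 256)     # hypothetical batch of 4 images
labels = torch.randint(0, 1000, (4,))  # hypothetical integer class labels

optimizer = torch.optim.Adam(v.parameters(), lr = 3e-4)

logits = v(imgs)                        # (4, 1000)
loss = F.cross_entropy(logits, labels)  # softmax cross-entropy on raw logits
loss.backward()
optimizer.step()
optimizer.zero_grad()
```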

## Citations

```bibtex
@misc{dosovitskiy2020image,
    title   = {An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale},
    author  = {Alexey Dosovitskiy and Lucas Beyer and Alexander Kolesnikov and Dirk Weissenborn and Xiaohua Zhai and Thomas Unterthiner and Mostafa Dehghani and Matthias Minderer and Georg Heigold and Sylvain Gelly and Jakob Uszkoreit and Neil Houlsby},
    year    = {2020},
    eprint  = {2010.11929},
    archivePrefix = {arXiv},
    primaryClass  = {cs.CV}
}
```

setup.py | 28 (new file)
@@ -0,0 +1,28 @@
from setuptools import setup, find_packages

setup(
  name = 'vit-pytorch',
  packages = find_packages(),
  version = '0.0.1',
  license='MIT',
  description = 'Vision Transformer (ViT) - Pytorch',
  author = 'Phil Wang',
  author_email = 'lucidrains@gmail.com',
  url = 'https://github.com/lucidrains/vit-pytorch',
  keywords = [
    'artificial intelligence',
    'attention mechanism',
    'image recognition'
  ],
  install_requires=[
    'torch>=1.6',
    'einops>=0.3'
  ],
  classifiers=[
    'Development Status :: 4 - Beta',
    'Intended Audience :: Developers',
    'Topic :: Scientific/Engineering :: Artificial Intelligence',
    'License :: OSI Approved :: MIT License',
    'Programming Language :: Python :: 3.6',
  ],
)
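
With this setup.py in place, the package can also be installed from a local clone in editable mode; a standard setuptools/pip workflow (not shown in the original commit):

```bash
$ git clone https://github.com/lucidrains/vit-pytorch.git
$ cd vit-pytorch
$ pip install -e .  # editable install: local changes take effect without reinstalling
```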

vit_pytorch/__init__.py | 1 (new file)
@@ -0,0 +1 @@
from vit_pytorch.vit_pytorch import ViT

vit_pytorch/vit_pytorch.py | 95 (new file)
@@ -0,0 +1,95 @@
import torch
from torch import nn
from einops import rearrange

# wraps a module with a residual (skip) connection
class Residual(nn.Module):
    def __init__(self, fn):
        super().__init__()
        self.fn = fn
    def forward(self, x):
        return self.fn(x) + x

# applies LayerNorm before the wrapped module (pre-norm transformer block)
class PreNorm(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn
    def forward(self, x):
        return self.fn(self.norm(x))

# two-layer MLP with GELU, the position-wise feedforward of the transformer
class FeedForward(nn.Module):
    def __init__(self, dim, hidden_dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, dim)
        )
    def forward(self, x):
        return self.net(x)

# multi-head self-attention
class Attention(nn.Module):
    def __init__(self, dim, heads = 8):
        super().__init__()
        self.heads = heads
        self.scale = dim ** -0.5

        self.to_qkv = nn.Linear(dim, dim * 3, bias = False)
        self.to_out = nn.Linear(dim, dim)

    def forward(self, x):
        b, n, _, h = *x.shape, self.heads
        # project to queries, keys, values in one matmul, then split out the heads
        qkv = self.to_qkv(x)
        q, k, v = rearrange(qkv, 'b n (qkv h d) -> qkv b h n d', qkv = 3, h = h)

        # scaled dot-product attention
        dots = torch.einsum('bhid,bhjd->bhij', q, k) * self.scale
        attn = dots.softmax(dim=-1)

        # weighted sum of values, then merge the heads back together
        out = torch.einsum('bhij,bhjd->bhid', attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        out = self.to_out(out)
        return out
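
As a sanity check of the head bookkeeping in `Attention.forward`, the einops split is equivalent to a manual chunk-and-reshape; an illustrative snippet with sizes matching the README configuration (not part of the file):

```python
import torch

qkv = torch.randn(2, 65, 3 * 1024)                # what to_qkv produces for dim = 1024
q, k, v = qkv.chunk(3, dim = -1)                  # three (2, 65, 1024) tensors
q = q.reshape(2, 65, 8, 128).permute(0, 2, 1, 3)  # (batch, heads = 8, tokens, head_dim = 128)
```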

# alternating attention and feedforward blocks, each pre-normed and residual
class Transformer(nn.Module):
    def __init__(self, dim, depth, heads, mlp_dim):
        super().__init__()
        layers = []
        for _ in range(depth):
            layers.extend([
                Residual(PreNorm(dim, Attention(dim, heads = heads))),
                Residual(PreNorm(dim, FeedForward(dim, mlp_dim)))
            ])
        self.net = nn.Sequential(*layers)
    def forward(self, x):
        return self.net(x)
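
Since every block is residual, the encoder stack maps a token sequence to a token sequence of the same shape; a quick check with illustrative sizes (not part of the file):

```python
t = Transformer(dim = 1024, depth = 6, heads = 8, mlp_dim = 2048)
tokens = torch.randn(1, 65, 1024)         # 64 patch tokens + 1 cls token
assert t(tokens).shape == (1, 65, 1024)   # shape preserved through all blocks
```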

class ViT(nn.Module):
    def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, channels = 3):
        super().__init__()
        assert image_size % patch_size == 0, 'image dimensions must be divisible by the patch size'
        num_patches = (image_size // patch_size) ** 2
        patch_dim = channels * patch_size ** 2

        self.patch_size = patch_size

        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
        self.patch_to_embedding = nn.Linear(patch_dim, dim)
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
        self.transformer = Transformer(dim, depth, heads, mlp_dim)

        self.mlp_head = nn.Sequential(
            nn.Linear(dim, mlp_dim),
            nn.GELU(),
            nn.Linear(mlp_dim, num_classes)
        )

    def forward(self, img):
        p = self.patch_size

        # split the image into contiguous p x p patches and flatten each one;
        # note the (h p1) grouping: h indexes patches, p1 indexes pixels within a patch
        x = rearrange(img, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = p, p2 = p)
        x = self.patch_to_embedding(x)

        # prepend a learned cls token, expanded across the batch so that
        # torch.cat works for batch sizes greater than one
        cls_tokens = self.cls_token.expand(x.shape[0], -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)
        x += self.pos_embedding
        x = self.transformer(x)

        # classify from the cls token's final representation
        return self.mlp_head(x[:, 0])
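
Tracing shapes for the README configuration may help; an illustrative end-to-end check (not part of the file):

```python
model = ViT(image_size = 256, patch_size = 32, num_classes = 1000,
            dim = 1024, depth = 6, heads = 8, mlp_dim = 2048)

imgs = torch.randn(2, 3, 256, 256)
# 256 / 32 = 8 patches per side -> 64 patches, each flattened to 32 * 32 * 3 = 3072 values;
# with the cls token prepended, the transformer sees 65 tokens of width 1024
assert model(imgs).shape == (2, 1000)
```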