mirror of
https://github.com/lucidrains/vit-pytorch.git
synced 2025-12-30 16:12:29 +00:00
Compare commits
1 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
4c37586510 |
38
README.md
38
README.md
@@ -6,6 +6,7 @@
|
||||
- [Install](#install)
|
||||
- [Usage](#usage)
|
||||
- [Parameters](#parameters)
|
||||
- [Simple ViT](#simple-vit)
|
||||
- [Distillation](#distillation)
|
||||
- [Deep ViT](#deep-vit)
|
||||
- [CaiT](#cait)
|
||||
@@ -106,6 +107,33 @@ Embedding dropout rate.
|
||||
- `pool`: string, either `cls` token pooling or `mean` pooling
|
||||
|
||||
|
||||
## Simple ViT
|
||||
|
||||
<a href="https://arxiv.org/abs/2205.01580">An update</a> from some of the same authors of the original paper proposes simplifications to `ViT` that allows it to train faster and better.
|
||||
|
||||
Among these simplifications include 2d sinusoidal positional embedding, global average pooling (no CLS token), no dropout, batch sizes of 1024 rather than 4096, and use of RandAugment and MixUp augmentations. They also show that a simple linear at the end is not significantly worse than the original MLP head
|
||||
|
||||
You can use it by importing the `SimpleViT` as shown below
|
||||
|
||||
```python
|
||||
import torch
|
||||
from vit_pytorch import SimpleViT
|
||||
|
||||
v = SimpleViT(
|
||||
image_size = 256,
|
||||
patch_size = 32,
|
||||
num_classes = 1000,
|
||||
dim = 1024,
|
||||
depth = 6,
|
||||
heads = 16,
|
||||
mlp_dim = 2048
|
||||
)
|
||||
|
||||
img = torch.randn(1, 3, 256, 256)
|
||||
|
||||
preds = v(img) # (1, 1000)
|
||||
```
|
||||
|
||||
## Distillation
|
||||
|
||||
<img src="./images/distill.png" width="300px"></img>
|
||||
@@ -1669,6 +1697,16 @@ Coming from computer vision and new to transformers? Here are some resources tha
|
||||
}
|
||||
```
|
||||
|
||||
```bibtex
|
||||
@misc{Beyer2022BetterPlainViT
|
||||
title = {Better plain ViT baselines for ImageNet-1k},
|
||||
author = {Beyer, Lucas and Zhai, Xiaohua and Kolesnikov, Alexander},
|
||||
publisher = {arXiv},
|
||||
year = {2022}
|
||||
}
|
||||
|
||||
```
|
||||
|
||||
```bibtex
|
||||
@misc{vaswani2017attention,
|
||||
title = {Attention Is All You Need},
|
||||
|
||||
2
setup.py
2
setup.py
@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
|
||||
setup(
|
||||
name = 'vit-pytorch',
|
||||
packages = find_packages(exclude=['examples']),
|
||||
version = '0.34.1',
|
||||
version = '0.35.0',
|
||||
license='MIT',
|
||||
description = 'Vision Transformer (ViT) - Pytorch',
|
||||
author = 'Phil Wang',
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
from vit_pytorch.vit import ViT
|
||||
from vit_pytorch.simple_vit import SimpleViT
|
||||
|
||||
from vit_pytorch.mae import MAE
|
||||
from vit_pytorch.dino import Dino
|
||||
|
||||
120
vit_pytorch/simple_vit.py
Normal file
120
vit_pytorch/simple_vit.py
Normal file
@@ -0,0 +1,120 @@
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from einops import rearrange, repeat
|
||||
from einops.layers.torch import Rearrange
|
||||
|
||||
# helpers
|
||||
|
||||
def pair(t):
|
||||
return t if isinstance(t, tuple) else (t, t)
|
||||
|
||||
def posemb_sincos_2d(img, dim, temperature = 10000, dtype = torch.float32):
|
||||
_, h, w, _, device, dtype = *img.shape, img.device, img.dtype
|
||||
|
||||
y, x = torch.meshgrid(torch.arange(h, device = device), torch.arange(w, device = device), indexing = 'ij')
|
||||
assert (dim % 4) == 0, 'feature dimension must be multiple of 4 for sincos emb'
|
||||
omega = torch.arange(dim // 4) / (dim // 4 - 1)
|
||||
omega = 1. / (temperature ** omega)
|
||||
|
||||
y = y.flatten()[:, None] * omega[None, :]
|
||||
x = x.flatten()[:, None] * omega[None, :]
|
||||
pe = torch.cat((x.sin(), x.cos(), y.sin(), y.cos()), dim = 1)
|
||||
return pe.type(dtype)
|
||||
|
||||
# classes
|
||||
|
||||
class FeedForward(nn.Module):
|
||||
def __init__(self, dim, hidden_dim):
|
||||
super().__init__()
|
||||
self.net = nn.Sequential(
|
||||
nn.LayerNorm(dim),
|
||||
nn.Linear(dim, hidden_dim),
|
||||
nn.GELU(),
|
||||
nn.Linear(hidden_dim, dim),
|
||||
)
|
||||
def forward(self, x):
|
||||
return self.net(x)
|
||||
|
||||
class Attention(nn.Module):
|
||||
def __init__(self, dim, heads = 8, dim_head = 64):
|
||||
super().__init__()
|
||||
inner_dim = dim_head * heads
|
||||
project_out = not (heads == 1 and dim_head == dim)
|
||||
|
||||
self.heads = heads
|
||||
self.scale = dim_head ** -0.5
|
||||
self.norm = nn.LayerNorm(dim)
|
||||
|
||||
self.attend = nn.Softmax(dim = -1)
|
||||
|
||||
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
|
||||
self.to_out = nn.Linear(inner_dim, dim, bias = False)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.norm(x)
|
||||
|
||||
qkv = self.to_qkv(x).chunk(3, dim = -1)
|
||||
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)
|
||||
|
||||
dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
|
||||
|
||||
attn = self.attend(dots)
|
||||
|
||||
out = torch.matmul(attn, v)
|
||||
out = rearrange(out, 'b h n d -> b n (h d)')
|
||||
return self.to_out(out)
|
||||
|
||||
class Transformer(nn.Module):
|
||||
def __init__(self, dim, depth, heads, dim_head, mlp_dim):
|
||||
super().__init__()
|
||||
self.layers = nn.ModuleList([])
|
||||
for _ in range(depth):
|
||||
self.layers.append(nn.ModuleList([
|
||||
Attention(dim, heads = heads, dim_head = dim_head),
|
||||
FeedForward(dim, mlp_dim)
|
||||
]))
|
||||
def forward(self, x):
|
||||
for attn, ff in self.layers:
|
||||
x = attn(x) + x
|
||||
x = ff(x) + x
|
||||
return x
|
||||
|
||||
class SimpleViT(nn.Module):
|
||||
def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, channels = 3, dim_head = 64):
|
||||
super().__init__()
|
||||
image_height, image_width = pair(image_size)
|
||||
patch_height, patch_width = pair(patch_size)
|
||||
|
||||
assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'
|
||||
|
||||
num_patches = (image_height // patch_height) * (image_width // patch_width)
|
||||
patch_dim = channels * patch_height * patch_width
|
||||
|
||||
self.dim = dim
|
||||
|
||||
self.to_patch_embedding = nn.Sequential(
|
||||
Rearrange('b c (h p1) (w p2) -> b h w (p1 p2 c)', p1 = patch_height, p2 = patch_width),
|
||||
nn.Linear(patch_dim, dim),
|
||||
)
|
||||
|
||||
self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim)
|
||||
|
||||
self.to_latent = nn.Identity()
|
||||
self.linear_head = nn.Sequential(
|
||||
nn.LayerNorm(dim),
|
||||
nn.Linear(dim, num_classes)
|
||||
)
|
||||
|
||||
def forward(self, img):
|
||||
*_, h, w, dtype = *img.shape, img.dtype
|
||||
|
||||
x = self.to_patch_embedding(img)
|
||||
pe = posemb_sincos_2d(x, self.dim)
|
||||
x = rearrange(x, 'b ... d -> b (...) d') + pe
|
||||
|
||||
x = self.transformer(x)
|
||||
x = x.mean(dim = 1)
|
||||
|
||||
x = self.to_latent(x)
|
||||
return self.linear_head(x)
|
||||
Reference in New Issue
Block a user