mirror of
https://github.com/lucidrains/vit-pytorch.git
synced 2025-12-30 08:02:29 +00:00
Compare commits
8 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
13b313b6d8 | ||
|
|
56373c0cbd | ||
|
|
24196a3e8a | ||
|
|
f6d7287b6b | ||
|
|
d47c57e32f | ||
|
|
0449865786 | ||
|
|
6693d47d0b | ||
|
|
141239ca86 |
13
README.md
13
README.md
@@ -198,7 +198,7 @@ preds = v(
|
||||
) # (5, 1000)
|
||||
```
|
||||
|
||||
Finally, if you would like to make use of a flavor of NaViT using <a href="https://pytorch.org/tutorials/prototype/nestedtensor.html">nested tensors</a> (which will omit a lot of the masking and padding altogether), make sure you are on version `2.4` and import as follows
|
||||
Finally, if you would like to make use of a flavor of NaViT using <a href="https://pytorch.org/tutorials/prototype/nestedtensor.html">nested tensors</a> (which will omit a lot of the masking and padding altogether), make sure you are on version `2.5` and import as follows
|
||||
|
||||
```python
|
||||
import torch
|
||||
@@ -2161,4 +2161,15 @@ Coming from computer vision and new to transformers? Here are some resources tha
|
||||
}
|
||||
```
|
||||
|
||||
```bibtex
|
||||
@article{Zhu2024HyperConnections,
|
||||
title = {Hyper-Connections},
|
||||
author = {Defa Zhu and Hongzhi Huang and Zihao Huang and Yutao Zeng and Yunyao Mao and Banggu Wu and Qiyang Min and Xun Zhou},
|
||||
journal = {ArXiv},
|
||||
year = {2024},
|
||||
volume = {abs/2409.19606},
|
||||
url = {https://api.semanticscholar.org/CorpusID:272987528}
|
||||
}
|
||||
```
|
||||
|
||||
*I visualise a time when we will be to robots what dogs are to humans, and I’m rooting for the machines.* — Claude Shannon
|
||||
|
||||
4
setup.py
4
setup.py
@@ -6,10 +6,10 @@ with open('README.md') as f:
|
||||
setup(
|
||||
name = 'vit-pytorch',
|
||||
packages = find_packages(exclude=['examples']),
|
||||
version = '1.8.6',
|
||||
version = '1.9.0',
|
||||
license='MIT',
|
||||
description = 'Vision Transformer (ViT) - Pytorch',
|
||||
long_description=long_description,
|
||||
long_description = long_description,
|
||||
long_description_content_type = 'text/markdown',
|
||||
author = 'Phil Wang',
|
||||
author_email = 'lucidrains@gmail.com',
|
||||
|
||||
@@ -6,9 +6,6 @@ from functools import partial
|
||||
import torch
|
||||
import packaging.version as pkg_version
|
||||
|
||||
if pkg_version.parse(torch.__version__) < pkg_version.parse('2.4'):
|
||||
print('nested tensor NaViT was tested on pytorch 2.4')
|
||||
|
||||
from torch import nn, Tensor
|
||||
import torch.nn.functional as F
|
||||
from torch.nn import Module, ModuleList
|
||||
@@ -44,7 +41,7 @@ def FeedForward(dim, hidden_dim, dropout = 0.):
|
||||
)
|
||||
|
||||
class Attention(Module):
|
||||
def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
|
||||
def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0., qk_norm = True):
|
||||
super().__init__()
|
||||
self.norm = nn.LayerNorm(dim, bias = False)
|
||||
|
||||
@@ -59,8 +56,8 @@ class Attention(Module):
|
||||
# in the paper, they employ qk rmsnorm, a way to stabilize attention
|
||||
# will use layernorm in place of rmsnorm, which has been shown to work in certain papers. requires l2norm on non-ragged dimension to be supported in nested tensors
|
||||
|
||||
self.query_norm = nn.LayerNorm(dim_head, bias = False)
|
||||
self.key_norm = nn.LayerNorm(dim_head, bias = False)
|
||||
self.query_norm = nn.LayerNorm(dim_head, bias = False) if qk_norm else nn.Identity()
|
||||
self.key_norm = nn.LayerNorm(dim_head, bias = False) if qk_norm else nn.Identity()
|
||||
|
||||
self.dropout = dropout
|
||||
|
||||
@@ -114,13 +111,13 @@ class Attention(Module):
|
||||
return self.to_out(out)
|
||||
|
||||
class Transformer(Module):
|
||||
def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
|
||||
def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0., qk_norm = True):
|
||||
super().__init__()
|
||||
self.layers = ModuleList([])
|
||||
|
||||
for _ in range(depth):
|
||||
self.layers.append(ModuleList([
|
||||
Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout),
|
||||
Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout, qk_norm = qk_norm),
|
||||
FeedForward(dim, mlp_dim, dropout = dropout)
|
||||
]))
|
||||
|
||||
@@ -149,9 +146,15 @@ class NaViT(Module):
|
||||
dim_head = 64,
|
||||
dropout = 0.,
|
||||
emb_dropout = 0.,
|
||||
qk_rmsnorm = True,
|
||||
token_dropout_prob: float | None = None
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
if pkg_version.parse(torch.__version__) < pkg_version.parse('2.5'):
|
||||
print('nested tensor NaViT was tested on pytorch 2.5')
|
||||
|
||||
|
||||
image_height, image_width = pair(image_size)
|
||||
|
||||
# what percent of tokens to dropout
|
||||
@@ -182,7 +185,7 @@ class NaViT(Module):
|
||||
|
||||
self.dropout = nn.Dropout(emb_dropout)
|
||||
|
||||
self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)
|
||||
self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout, qk_rmsnorm)
|
||||
|
||||
# final attention pooling queries
|
||||
|
||||
@@ -323,3 +326,5 @@ if __name__ == '__main__':
|
||||
]
|
||||
|
||||
assert v(images).shape == (5, 1000)
|
||||
|
||||
v(images).sum().backward()
|
||||
|
||||
@@ -6,9 +6,6 @@ from functools import partial
|
||||
import torch
|
||||
import packaging.version as pkg_version
|
||||
|
||||
if pkg_version.parse(torch.__version__) < pkg_version.parse('2.4'):
|
||||
print('nested tensor NaViT was tested on pytorch 2.4')
|
||||
|
||||
from torch import nn, Tensor
|
||||
import torch.nn.functional as F
|
||||
from torch.nn import Module, ModuleList
|
||||
@@ -44,7 +41,7 @@ def FeedForward(dim, hidden_dim, dropout = 0.):
|
||||
)
|
||||
|
||||
class Attention(Module):
|
||||
def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
|
||||
def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0., qk_norm = True):
|
||||
super().__init__()
|
||||
self.norm = nn.LayerNorm(dim, bias = False)
|
||||
|
||||
@@ -59,8 +56,8 @@ class Attention(Module):
|
||||
# in the paper, they employ qk rmsnorm, a way to stabilize attention
|
||||
# will use layernorm in place of rmsnorm, which has been shown to work in certain papers. requires l2norm on non-ragged dimension to be supported in nested tensors
|
||||
|
||||
self.query_norm = nn.LayerNorm(dim_head, bias = False)
|
||||
self.key_norm = nn.LayerNorm(dim_head, bias = False)
|
||||
self.query_norm = nn.LayerNorm(dim_head, bias = False) if qk_norm else nn.Identity()
|
||||
self.key_norm = nn.LayerNorm(dim_head, bias = False) if qk_norm else nn.Identity()
|
||||
|
||||
self.dropout = dropout
|
||||
|
||||
@@ -126,13 +123,13 @@ class Attention(Module):
|
||||
return self.to_out(out)
|
||||
|
||||
class Transformer(Module):
|
||||
def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
|
||||
def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0., qk_norm = True):
|
||||
super().__init__()
|
||||
self.layers = ModuleList([])
|
||||
|
||||
for _ in range(depth):
|
||||
self.layers.append(ModuleList([
|
||||
Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout),
|
||||
Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout, qk_norm = qk_norm),
|
||||
FeedForward(dim, mlp_dim, dropout = dropout)
|
||||
]))
|
||||
|
||||
@@ -164,11 +161,15 @@ class NaViT(Module):
|
||||
dropout = 0.,
|
||||
emb_dropout = 0.,
|
||||
num_registers = 4,
|
||||
qk_rmsnorm = True,
|
||||
token_dropout_prob: float | None = None
|
||||
):
|
||||
super().__init__()
|
||||
image_height, image_width = pair(image_size)
|
||||
|
||||
if pkg_version.parse(torch.__version__) < pkg_version.parse('2.5'):
|
||||
print('nested tensor NaViT was tested on pytorch 2.5')
|
||||
|
||||
# what percent of tokens to dropout
|
||||
# if int or float given, then assume constant dropout prob
|
||||
# otherwise accept a callback that in turn calculates dropout prob from height and width
|
||||
@@ -209,7 +210,7 @@ class NaViT(Module):
|
||||
|
||||
self.dropout = nn.Dropout(emb_dropout)
|
||||
|
||||
self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)
|
||||
self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout, qk_rmsnorm)
|
||||
|
||||
# final attention pooling queries
|
||||
|
||||
@@ -336,7 +337,7 @@ class NaViT(Module):
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
# works for torch 2.4
|
||||
# works for torch 2.5
|
||||
|
||||
v = NaViT(
|
||||
image_size = 256,
|
||||
@@ -362,3 +363,5 @@ if __name__ == '__main__':
|
||||
]
|
||||
|
||||
assert v(volumes).shape == (5, 1000)
|
||||
|
||||
v(volumes).sum().backward()
|
||||
|
||||
233
vit_pytorch/simple_vit_with_hyper_connections.py
Normal file
233
vit_pytorch/simple_vit_with_hyper_connections.py
Normal file
@@ -0,0 +1,233 @@
|
||||
"""
|
||||
ViT + Hyper-Connections + Register Tokens
|
||||
https://arxiv.org/abs/2409.19606
|
||||
"""
|
||||
|
||||
import torch
|
||||
from torch import nn, tensor
|
||||
from torch.nn import Module, ModuleList
|
||||
|
||||
from einops import rearrange, repeat, reduce, einsum, pack, unpack
|
||||
from einops.layers.torch import Rearrange
|
||||
|
||||
# b - batch, h - heads, n - sequence, e - expansion rate / residual streams, d - feature dimension
|
||||
|
||||
# helpers
|
||||
|
||||
def pair(t):
|
||||
return t if isinstance(t, tuple) else (t, t)
|
||||
|
||||
def posemb_sincos_2d(h, w, dim, temperature: int = 10000, dtype = torch.float32):
|
||||
y, x = torch.meshgrid(torch.arange(h), torch.arange(w), indexing="ij")
|
||||
assert (dim % 4) == 0, "feature dimension must be multiple of 4 for sincos emb"
|
||||
omega = torch.arange(dim // 4) / (dim // 4 - 1)
|
||||
omega = 1.0 / (temperature ** omega)
|
||||
|
||||
y = y.flatten()[:, None] * omega[None, :]
|
||||
x = x.flatten()[:, None] * omega[None, :]
|
||||
pe = torch.cat((x.sin(), x.cos(), y.sin(), y.cos()), dim=1)
|
||||
return pe.type(dtype)
|
||||
|
||||
# hyper connections
|
||||
|
||||
class HyperConnection(Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim,
|
||||
num_residual_streams,
|
||||
layer_index
|
||||
):
|
||||
""" Appendix J - Algorithm 2, Dynamic only """
|
||||
super().__init__()
|
||||
|
||||
self.norm = nn.LayerNorm(dim, bias = False)
|
||||
|
||||
self.num_residual_streams = num_residual_streams
|
||||
self.layer_index = layer_index
|
||||
|
||||
self.static_beta = nn.Parameter(torch.ones(num_residual_streams))
|
||||
|
||||
init_alpha0 = torch.zeros((num_residual_streams, 1))
|
||||
init_alpha0[layer_index % num_residual_streams, 0] = 1.
|
||||
|
||||
self.static_alpha = nn.Parameter(torch.cat([init_alpha0, torch.eye(num_residual_streams)], dim = 1))
|
||||
|
||||
self.dynamic_alpha_fn = nn.Parameter(torch.zeros(dim, num_residual_streams + 1))
|
||||
self.dynamic_alpha_scale = nn.Parameter(tensor(1e-2))
|
||||
self.dynamic_beta_fn = nn.Parameter(torch.zeros(dim))
|
||||
self.dynamic_beta_scale = nn.Parameter(tensor(1e-2))
|
||||
|
||||
def width_connection(self, residuals):
|
||||
normed = self.norm(residuals)
|
||||
|
||||
wc_weight = (normed @ self.dynamic_alpha_fn).tanh()
|
||||
dynamic_alpha = wc_weight * self.dynamic_alpha_scale
|
||||
alpha = dynamic_alpha + self.static_alpha
|
||||
|
||||
dc_weight = (normed @ self.dynamic_beta_fn).tanh()
|
||||
dynamic_beta = dc_weight * self.dynamic_beta_scale
|
||||
beta = dynamic_beta + self.static_beta
|
||||
|
||||
# width connection
|
||||
mix_h = einsum(alpha, residuals, '... e1 e2, ... e1 d -> ... e2 d')
|
||||
|
||||
branch_input, residuals = mix_h[..., 0, :], mix_h[..., 1:, :]
|
||||
|
||||
return branch_input, residuals, beta
|
||||
|
||||
def depth_connection(
|
||||
self,
|
||||
residuals,
|
||||
branch_output,
|
||||
beta
|
||||
):
|
||||
return einsum(branch_output, beta, "b n d, b n e -> b n e d") + residuals
|
||||
|
||||
# classes
|
||||
|
||||
class FeedForward(Module):
|
||||
def __init__(self, dim, hidden_dim):
|
||||
super().__init__()
|
||||
self.net = nn.Sequential(
|
||||
nn.LayerNorm(dim),
|
||||
nn.Linear(dim, hidden_dim),
|
||||
nn.GELU(),
|
||||
nn.Linear(hidden_dim, dim),
|
||||
)
|
||||
def forward(self, x):
|
||||
return self.net(x)
|
||||
|
||||
class Attention(Module):
|
||||
def __init__(self, dim, heads = 8, dim_head = 64):
|
||||
super().__init__()
|
||||
inner_dim = dim_head * heads
|
||||
self.heads = heads
|
||||
self.scale = dim_head ** -0.5
|
||||
self.norm = nn.LayerNorm(dim)
|
||||
|
||||
self.attend = nn.Softmax(dim = -1)
|
||||
|
||||
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
|
||||
self.to_out = nn.Linear(inner_dim, dim, bias = False)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.norm(x)
|
||||
|
||||
qkv = self.to_qkv(x).chunk(3, dim = -1)
|
||||
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)
|
||||
|
||||
dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
|
||||
|
||||
attn = self.attend(dots)
|
||||
|
||||
out = torch.matmul(attn, v)
|
||||
out = rearrange(out, 'b h n d -> b n (h d)')
|
||||
return self.to_out(out)
|
||||
|
||||
class Transformer(Module):
|
||||
def __init__(self, dim, depth, heads, dim_head, mlp_dim, num_residual_streams):
|
||||
super().__init__()
|
||||
|
||||
self.num_residual_streams = num_residual_streams
|
||||
|
||||
self.norm = nn.LayerNorm(dim)
|
||||
self.layers = ModuleList([])
|
||||
|
||||
for layer_index in range(depth):
|
||||
self.layers.append(nn.ModuleList([
|
||||
HyperConnection(dim, num_residual_streams, layer_index),
|
||||
Attention(dim, heads = heads, dim_head = dim_head),
|
||||
HyperConnection(dim, num_residual_streams, layer_index),
|
||||
FeedForward(dim, mlp_dim)
|
||||
]))
|
||||
|
||||
def forward(self, x):
|
||||
|
||||
x = repeat(x, 'b n d -> b n e d', e = self.num_residual_streams)
|
||||
|
||||
for attn_hyper_conn, attn, ff_hyper_conn, ff in self.layers:
|
||||
|
||||
x, attn_res, beta = attn_hyper_conn.width_connection(x)
|
||||
|
||||
x = attn(x)
|
||||
|
||||
x = attn_hyper_conn.depth_connection(attn_res, x, beta)
|
||||
|
||||
x, ff_res, beta = ff_hyper_conn.width_connection(x)
|
||||
|
||||
x = ff(x)
|
||||
|
||||
x = ff_hyper_conn.depth_connection(ff_res, x, beta)
|
||||
|
||||
x = reduce(x, 'b n e d -> b n d', 'sum')
|
||||
|
||||
return self.norm(x)
|
||||
|
||||
class SimpleViT(nn.Module):
|
||||
def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, num_residual_streams, num_register_tokens = 4, channels = 3, dim_head = 64):
|
||||
super().__init__()
|
||||
image_height, image_width = pair(image_size)
|
||||
patch_height, patch_width = pair(patch_size)
|
||||
|
||||
assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'
|
||||
|
||||
patch_dim = channels * patch_height * patch_width
|
||||
|
||||
self.to_patch_embedding = nn.Sequential(
|
||||
Rearrange("b c (h p1) (w p2) -> b (h w) (p1 p2 c)", p1 = patch_height, p2 = patch_width),
|
||||
nn.LayerNorm(patch_dim),
|
||||
nn.Linear(patch_dim, dim),
|
||||
nn.LayerNorm(dim),
|
||||
)
|
||||
|
||||
self.register_tokens = nn.Parameter(torch.randn(num_register_tokens, dim))
|
||||
|
||||
self.pos_embedding = posemb_sincos_2d(
|
||||
h = image_height // patch_height,
|
||||
w = image_width // patch_width,
|
||||
dim = dim,
|
||||
)
|
||||
|
||||
self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, num_residual_streams)
|
||||
|
||||
self.pool = "mean"
|
||||
self.to_latent = nn.Identity()
|
||||
|
||||
self.linear_head = nn.Linear(dim, num_classes)
|
||||
|
||||
def forward(self, img):
|
||||
batch, device = img.shape[0], img.device
|
||||
|
||||
x = self.to_patch_embedding(img)
|
||||
x += self.pos_embedding.to(x)
|
||||
|
||||
r = repeat(self.register_tokens, 'n d -> b n d', b = batch)
|
||||
|
||||
x, ps = pack([x, r], 'b * d')
|
||||
|
||||
x = self.transformer(x)
|
||||
|
||||
x, _ = unpack(x, ps, 'b * d')
|
||||
|
||||
x = x.mean(dim = 1)
|
||||
|
||||
x = self.to_latent(x)
|
||||
return self.linear_head(x)
|
||||
|
||||
# main
|
||||
|
||||
if __name__ == '__main__':
|
||||
vit = SimpleViT(
|
||||
num_classes = 1000,
|
||||
image_size = 256,
|
||||
patch_size = 8,
|
||||
dim = 1024,
|
||||
depth = 12,
|
||||
heads = 8,
|
||||
mlp_dim = 2048,
|
||||
num_residual_streams = 8
|
||||
)
|
||||
|
||||
images = torch.randn(3, 3, 256, 256)
|
||||
|
||||
logits = vit(images)
|
||||
@@ -38,7 +38,7 @@ def FeedForward(dim, hidden_dim):
|
||||
)
|
||||
|
||||
class Attention(Module):
|
||||
def __init__(self, dim, heads = 8, dim_head = 64):
|
||||
def __init__(self, dim, heads = 8, dim_head = 64, learned_value_residual_mix = False):
|
||||
super().__init__()
|
||||
inner_dim = dim_head * heads
|
||||
self.heads = heads
|
||||
@@ -50,6 +50,12 @@ class Attention(Module):
|
||||
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
|
||||
self.to_out = nn.Linear(inner_dim, dim, bias = False)
|
||||
|
||||
self.to_residual_mix = nn.Sequential(
|
||||
nn.Linear(dim, heads),
|
||||
nn.Sigmoid(),
|
||||
Rearrange('b n h -> b h n 1')
|
||||
) if learned_value_residual_mix else (lambda _: 0.5)
|
||||
|
||||
def forward(self, x, value_residual = None):
|
||||
x = self.norm(x)
|
||||
|
||||
@@ -57,7 +63,8 @@ class Attention(Module):
|
||||
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)
|
||||
|
||||
if exists(value_residual):
|
||||
v = v + value_residual
|
||||
mix = self.to_residual_mix(x)
|
||||
v = v * mix + value_residual * (1. - mix)
|
||||
|
||||
dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
|
||||
|
||||
@@ -73,9 +80,10 @@ class Transformer(Module):
|
||||
super().__init__()
|
||||
self.norm = nn.LayerNorm(dim)
|
||||
self.layers = ModuleList([])
|
||||
for _ in range(depth):
|
||||
for i in range(depth):
|
||||
is_first = i == 0
|
||||
self.layers.append(ModuleList([
|
||||
Attention(dim, heads = heads, dim_head = dim_head),
|
||||
Attention(dim, heads = heads, dim_head = dim_head, learned_value_residual_mix = not is_first),
|
||||
FeedForward(dim, mlp_dim)
|
||||
]))
|
||||
def forward(self, x):
|
||||
|
||||
Reference in New Issue
Block a user