Compare commits

..

23 Commits
1.7.8 ... 1.9.2

Author SHA1 Message Date
lucidrains
9f49a31977 1.9.2 2025-01-19 05:53:11 -08:00
JacobLinCool
ab63fc9cc8 remove duplicated qkv computation in na_vit_nested_tensor_3d.py (#341) 2025-01-19 05:52:46 -08:00
Phil Wang
c3018d1433 1.9.1 2025-01-04 07:55:49 -08:00
Kale Kundert
b7ed6bad28 add option to set frame padding for 3D CCT (#339) 2025-01-04 07:55:27 -08:00
lucidrains
e7cba9ba6d add a simple vit flavor for a new bytedance paper that proposes to break out of the traditional one residual stream architecture - "hyper-connections" 2024-12-20 17:43:50 -08:00
lucidrains
56373c0cbd make value residual learned 2024-11-24 08:21:28 -08:00
lucidrains
24196a3e8a allow for qk norm to be turned off for na vit nested tensor 2024-11-20 10:59:22 -08:00
Phil Wang
f6d7287b6b readme 2024-11-19 08:20:38 -08:00
lucidrains
d47c57e32f fix tests 2024-11-10 09:43:54 -08:00
lucidrains
0449865786 update minimum version for nested tensor of NaViT 2024-11-10 09:37:48 -08:00
lucidrains
6693d47d0b update comment for navit 3d 2024-11-07 20:02:07 -08:00
Phil Wang
141239ca86 fix value residual 2024-10-31 06:48:24 -07:00
lucidrains
0b5c9b4559 add value residual based simple vit 2024-10-28 09:19:00 -07:00
lucidrains
e300cdd7dc fix multiheaded qk rmsnorm in nViT 2024-10-10 19:15:17 -07:00
Phil Wang
36ddc7a6ba go all the way with the normalized vit, fix some scales 2024-10-10 10:42:37 -07:00
Phil Wang
1d1a63fc5c cite for hypersphere vit adapted from ngpt 2024-10-10 10:15:04 -07:00
Phil Wang
74b62009f8 go for multi-headed rmsnorm for the qknorm on hypersphere vit 2024-10-10 08:09:58 -07:00
Phil Wang
f50d7d1436 add a hypersphere vit, adapted from https://arxiv.org/abs/2410.01131 2024-10-09 07:32:25 -07:00
lucidrains
82f2fa751d address https://github.com/lucidrains/vit-pytorch/issues/330 2024-10-04 07:01:48 -07:00
lucidrains
fcb9501cdd add register tokens to the nested tensor 3d na vit example for researcher 2024-08-28 12:21:31 -07:00
lucidrains
c4651a35a3 1.7.11 2024-08-21 19:24:13 -07:00
roydenwa
9d43e4d0bb Add ViViT variant with factorized self-attention (#327)
* Add FactorizedTransformer

* Add variant param and check in fwd method

* Check if variant is implemented

* Describe new ViViT variant
2024-08-21 19:23:38 -07:00
Phil Wang
5e808f48d1 3d version of navit nested tensor 2024-08-21 07:23:21 -07:00
11 changed files with 847 additions and 58 deletions

View File

@@ -198,7 +198,7 @@ preds = v(
) # (5, 1000)
```
Finally, if you would like to make use of a flavor of NaViT using <a href="https://pytorch.org/tutorials/prototype/nestedtensor.html">nested tensors</a> (which will omit a lot of the masking and padding altogether), make sure you are on version `2.4` and import as follows
Finally, if you would like to make use of a flavor of NaViT using <a href="https://pytorch.org/tutorials/prototype/nestedtensor.html">nested tensors</a> (which will omit a lot of the masking and padding altogether), make sure you are on version `2.5` and import as follows
```python
import torch
@@ -1218,7 +1218,8 @@ pred = cct(video)
<img src="./images/vivit.png" width="350px"></img>
This <a href="https://arxiv.org/abs/2103.15691">paper</a> offers 3 different types of architectures for efficient attention of videos, with the main theme being factorizing the attention across space and time. This repository will offer the first variant, which is a spatial transformer followed by a temporal one.
This <a href="https://arxiv.org/abs/2103.15691">paper</a> offers 3 different types of architectures for efficient attention of videos, with the main theme being factorizing the attention across space and time. This repository includes the factorized encoder and the factorized self-attention variant.
The factorized encoder variant is a spatial transformer followed by a temporal one. The factorized self-attention variant is a spatio-temporal transformer with alternating spatial and temporal self-attention layers.
```python
import torch
@@ -1234,7 +1235,8 @@ v = ViT(
spatial_depth = 6, # depth of the spatial transformer
temporal_depth = 6, # depth of the temporal transformer
heads = 8,
mlp_dim = 2048
mlp_dim = 2048,
variant = 'factorized_encoder', # or 'factorized_self_attention'
)
video = torch.randn(4, 3, 16, 128, 128) # (batch, channels, frames, height, width)
@@ -2131,4 +2133,43 @@ Coming from computer vision and new to transformers? Here are some resources tha
}
```
```bibtex
@inproceedings{Loshchilov2024nGPTNT,
title = {nGPT: Normalized Transformer with Representation Learning on the Hypersphere},
author = {Ilya Loshchilov and Cheng-Ping Hsieh and Simeng Sun and Boris Ginsburg},
year = {2024},
url = {https://api.semanticscholar.org/CorpusID:273026160}
}
```
```bibtex
@inproceedings{Liu2017DeepHL,
title = {Deep Hyperspherical Learning},
author = {Weiyang Liu and Yanming Zhang and Xingguo Li and Zhen Liu and Bo Dai and Tuo Zhao and Le Song},
booktitle = {Neural Information Processing Systems},
year = {2017},
url = {https://api.semanticscholar.org/CorpusID:5104558}
}
```
```bibtex
@inproceedings{Zhou2024ValueRL,
title = {Value Residual Learning For Alleviating Attention Concentration In Transformers},
author = {Zhanchao Zhou and Tianyi Wu and Zhiyun Jiang and Zhenzhong Lan},
year = {2024},
url = {https://api.semanticscholar.org/CorpusID:273532030}
}
```
```bibtex
@article{Zhu2024HyperConnections,
title = {Hyper-Connections},
author = {Defa Zhu and Hongzhi Huang and Zihao Huang and Yutao Zeng and Yunyao Mao and Banggu Wu and Qiyang Min and Xun Zhou},
journal = {ArXiv},
year = {2024},
volume = {abs/2409.19606},
url = {https://api.semanticscholar.org/CorpusID:272987528}
}
```
*I visualise a time when we will be to robots what dogs are to humans, and Im rooting for the machines.* — Claude Shannon

View File

@@ -6,10 +6,10 @@ with open('README.md') as f:
setup(
name = 'vit-pytorch',
packages = find_packages(exclude=['examples']),
version = '1.7.8',
version = '1.9.2',
license='MIT',
description = 'Vision Transformer (ViT) - Pytorch',
long_description=long_description,
long_description = long_description,
long_description_content_type = 'text/markdown',
author = 'Phil Wang',
author_email = 'lucidrains@gmail.com',

View File

@@ -167,8 +167,10 @@ class Tokenizer(nn.Module):
stride,
padding,
frame_stride=1,
frame_padding=None,
frame_pooling_stride=1,
frame_pooling_kernel_size=1,
frame_pooling_padding=None,
pooling_kernel_size=3,
pooling_stride=2,
pooling_padding=1,
@@ -188,16 +190,22 @@ class Tokenizer(nn.Module):
n_filter_list_pairs = zip(n_filter_list[:-1], n_filter_list[1:])
if frame_padding is None:
frame_padding = frame_kernel_size // 2
if frame_pooling_padding is None:
frame_pooling_padding = frame_pooling_kernel_size // 2
self.conv_layers = nn.Sequential(
*[nn.Sequential(
nn.Conv3d(chan_in, chan_out,
kernel_size=(frame_kernel_size, kernel_size, kernel_size),
stride=(frame_stride, stride, stride),
padding=(frame_kernel_size // 2, padding, padding), bias=conv_bias),
padding=(frame_padding, padding, padding), bias=conv_bias),
nn.Identity() if not exists(activation) else activation(),
nn.MaxPool3d(kernel_size=(frame_pooling_kernel_size, pooling_kernel_size, pooling_kernel_size),
stride=(frame_pooling_stride, pooling_stride, pooling_stride),
padding=(frame_pooling_kernel_size // 2, pooling_padding, pooling_padding)) if max_pool else nn.Identity()
padding=(frame_pooling_padding, pooling_padding, pooling_padding)) if max_pool else nn.Identity()
)
for chan_in, chan_out in n_filter_list_pairs
])
@@ -324,8 +332,10 @@ class CCT(nn.Module):
n_conv_layers=1,
frame_stride=1,
frame_kernel_size=3,
frame_padding=None,
frame_pooling_kernel_size=1,
frame_pooling_stride=1,
frame_pooling_padding=None,
kernel_size=7,
stride=2,
padding=3,
@@ -342,8 +352,10 @@ class CCT(nn.Module):
n_output_channels=embedding_dim,
frame_stride=frame_stride,
frame_kernel_size=frame_kernel_size,
frame_padding=frame_padding,
frame_pooling_stride=frame_pooling_stride,
frame_pooling_kernel_size=frame_pooling_kernel_size,
frame_pooling_padding=frame_pooling_padding,
kernel_size=kernel_size,
stride=stride,
padding=padding,

View File

@@ -6,9 +6,6 @@ from functools import partial
import torch
import packaging.version as pkg_version
if pkg_version.parse(torch.__version__) < pkg_version.parse('2.4'):
print('nested tensor NaViT was tested on pytorch 2.4')
from torch import nn, Tensor
import torch.nn.functional as F
from torch.nn import Module, ModuleList
@@ -44,7 +41,7 @@ def FeedForward(dim, hidden_dim, dropout = 0.):
)
class Attention(Module):
def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0., qk_norm = True):
super().__init__()
self.norm = nn.LayerNorm(dim, bias = False)
@@ -59,8 +56,8 @@ class Attention(Module):
# in the paper, they employ qk rmsnorm, a way to stabilize attention
# will use layernorm in place of rmsnorm, which has been shown to work in certain papers. requires l2norm on non-ragged dimension to be supported in nested tensors
self.query_norm = nn.LayerNorm(dim_head, bias = False)
self.key_norm = nn.LayerNorm(dim_head, bias = False)
self.query_norm = nn.LayerNorm(dim_head, bias = False) if qk_norm else nn.Identity()
self.key_norm = nn.LayerNorm(dim_head, bias = False) if qk_norm else nn.Identity()
self.dropout = dropout
@@ -114,13 +111,13 @@ class Attention(Module):
return self.to_out(out)
class Transformer(Module):
def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0., qk_norm = True):
super().__init__()
self.layers = ModuleList([])
for _ in range(depth):
self.layers.append(ModuleList([
Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout),
Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout, qk_norm = qk_norm),
FeedForward(dim, mlp_dim, dropout = dropout)
]))
@@ -149,9 +146,15 @@ class NaViT(Module):
dim_head = 64,
dropout = 0.,
emb_dropout = 0.,
qk_rmsnorm = True,
token_dropout_prob: float | None = None
):
super().__init__()
if pkg_version.parse(torch.__version__) < pkg_version.parse('2.5'):
print('nested tensor NaViT was tested on pytorch 2.5')
image_height, image_width = pair(image_size)
# what percent of tokens to dropout
@@ -182,7 +185,7 @@ class NaViT(Module):
self.dropout = nn.Dropout(emb_dropout)
self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)
self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout, qk_rmsnorm)
# final attention pooling queries
@@ -323,3 +326,5 @@ if __name__ == '__main__':
]
assert v(images).shape == (5, 1000)
v(images).sum().backward()

View File

@@ -4,6 +4,8 @@ from typing import List
from functools import partial
import torch
import packaging.version as pkg_version
from torch import nn, Tensor
import torch.nn.functional as F
from torch.nn import Module, ModuleList
@@ -39,7 +41,7 @@ def FeedForward(dim, hidden_dim, dropout = 0.):
)
class Attention(Module):
def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0., qk_norm = True):
super().__init__()
self.norm = nn.LayerNorm(dim, bias = False)
@@ -54,8 +56,8 @@ class Attention(Module):
# in the paper, they employ qk rmsnorm, a way to stabilize attention
# will use layernorm in place of rmsnorm, which has been shown to work in certain papers. requires l2norm on non-ragged dimension to be supported in nested tensors
self.query_norm = nn.LayerNorm(dim_head, bias = False)
self.key_norm = nn.LayerNorm(dim_head, bias = False)
self.query_norm = nn.LayerNorm(dim_head, bias = False) if qk_norm else nn.Identity()
self.key_norm = nn.LayerNorm(dim_head, bias = False) if qk_norm else nn.Identity()
self.dropout = dropout
@@ -82,7 +84,10 @@ class Attention(Module):
# split heads
def split_heads(t):
return t.unflatten(-1, (self.heads, self.dim_head)).transpose(1, 2).contiguous()
return t.unflatten(-1, (self.heads, self.dim_head))
def transpose_head_seq(t):
return t.transpose(1, 2)
query, key, value = map(split_heads, (query, key, value))
@@ -91,6 +96,8 @@ class Attention(Module):
query = self.query_norm(query)
key = self.key_norm(key)
query, key, value = map(transpose_head_seq, (query, key, value))
# attention
out = F.scaled_dot_product_attention(
@@ -105,13 +112,13 @@ class Attention(Module):
return self.to_out(out)
class Transformer(Module):
def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0., qk_norm = True):
super().__init__()
self.layers = ModuleList([])
for _ in range(depth):
self.layers.append(ModuleList([
Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout),
Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout, qk_norm = qk_norm),
FeedForward(dim, mlp_dim, dropout = dropout)
]))
@@ -142,11 +149,16 @@ class NaViT(Module):
dim_head = 64,
dropout = 0.,
emb_dropout = 0.,
num_registers = 4,
qk_rmsnorm = True,
token_dropout_prob: float | None = None
):
super().__init__()
image_height, image_width = pair(image_size)
if pkg_version.parse(torch.__version__) < pkg_version.parse('2.5'):
print('nested tensor NaViT was tested on pytorch 2.5')
# what percent of tokens to dropout
# if int or float given, then assume constant dropout prob
# otherwise accept a callback that in turn calculates dropout prob from height and width
@@ -172,13 +184,22 @@ class NaViT(Module):
nn.LayerNorm(dim),
)
self.pos_embed_frame = nn.Parameter(torch.randn(patch_frame_dim, dim))
self.pos_embed_height = nn.Parameter(torch.randn(patch_height_dim, dim))
self.pos_embed_width = nn.Parameter(torch.randn(patch_width_dim, dim))
self.pos_embed_frame = nn.Parameter(torch.zeros(patch_frame_dim, dim))
self.pos_embed_height = nn.Parameter(torch.zeros(patch_height_dim, dim))
self.pos_embed_width = nn.Parameter(torch.zeros(patch_width_dim, dim))
# register tokens
self.register_tokens = nn.Parameter(torch.zeros(num_registers, dim))
nn.init.normal_(self.pos_embed_frame, std = 0.02)
nn.init.normal_(self.pos_embed_height, std = 0.02)
nn.init.normal_(self.pos_embed_width, std = 0.02)
nn.init.normal_(self.register_tokens, std = 0.02)
self.dropout = nn.Dropout(emb_dropout)
self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)
self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout, qk_rmsnorm)
# final attention pooling queries
@@ -202,7 +223,7 @@ class NaViT(Module):
self,
volumes: List[Tensor], # different resolution images / CT scans
):
batch, device = len(images), self.device
batch, device = len(volumes), self.device
arange = partial(torch.arange, device = device)
assert all([volume.ndim == 4 and volume.shape[0] == self.channels for volume in volumes]), f'all volumes must have {self.channels} channels and number of dimensions of {self.channels} (channels, frame, height, width)'
@@ -254,8 +275,6 @@ class NaViT(Module):
pos_embed = frame_embed + height_embed + width_embed
# use nested tensor for transformers and save on padding computation
tokens = torch.cat(tokens)
# linear projection to patch embeddings
@@ -266,7 +285,15 @@ class NaViT(Module):
tokens = tokens + pos_embed
tokens = nested_tensor(tokens.split(seq_len.tolist()), layout = torch.jagged, device = device)
# add register tokens
tokens = tokens.split(seq_lens.tolist())
tokens = [torch.cat((self.register_tokens, one_tokens)) for one_tokens in tokens]
# use nested tensor for transformers and save on padding computation
tokens = nested_tensor(tokens, layout = torch.jagged, device = device)
# embedding dropout
@@ -299,7 +326,7 @@ class NaViT(Module):
if __name__ == '__main__':
# works for torch 2.2.2
# works for torch 2.5
v = NaViT(
image_size = 256,
@@ -318,12 +345,12 @@ if __name__ == '__main__':
# 5 volumetric data (videos or CT scans) of different resolutions - List[Tensor]
# for now, you'll have to correctly place images in same batch element as to not exceed maximum allowed sequence length for self-attention w/ masking
images = [
volumes = [
torch.randn(3, 2, 256, 256), torch.randn(3, 8, 128, 128),
torch.randn(3, 4, 128, 256), torch.randn(3, 2, 256, 128),
torch.randn(3, 4, 64, 256)
]
assert v(images).shape == (5, 1000)
assert v(volumes).shape == (5, 1000)
v(volumes).sum().backward()

View File

@@ -0,0 +1,264 @@
import torch
from torch import nn
from torch.nn import Module, ModuleList
import torch.nn.functional as F
import torch.nn.utils.parametrize as parametrize
from einops import rearrange, reduce
from einops.layers.torch import Rearrange
# functions
def exists(v):
return v is not None
def default(v, d):
return v if exists(v) else d
def pair(t):
return t if isinstance(t, tuple) else (t, t)
def divisible_by(numer, denom):
return (numer % denom) == 0
def l2norm(t, dim = -1):
return F.normalize(t, dim = dim, p = 2)
# for use with parametrize
class L2Norm(Module):
def __init__(self, dim = -1):
super().__init__()
self.dim = dim
def forward(self, t):
return l2norm(t, dim = self.dim)
class NormLinear(Module):
def __init__(
self,
dim,
dim_out,
norm_dim_in = True
):
super().__init__()
self.linear = nn.Linear(dim, dim_out, bias = False)
parametrize.register_parametrization(
self.linear,
'weight',
L2Norm(dim = -1 if norm_dim_in else 0)
)
@property
def weight(self):
return self.linear.weight
def forward(self, x):
return self.linear(x)
# attention and feedforward
class Attention(Module):
def __init__(
self,
dim,
*,
dim_head = 64,
heads = 8,
dropout = 0.
):
super().__init__()
dim_inner = dim_head * heads
self.to_q = NormLinear(dim, dim_inner)
self.to_k = NormLinear(dim, dim_inner)
self.to_v = NormLinear(dim, dim_inner)
self.dropout = dropout
self.q_scale = nn.Parameter(torch.ones(heads, 1, dim_head) * (dim_head ** 0.25))
self.k_scale = nn.Parameter(torch.ones(heads, 1, dim_head) * (dim_head ** 0.25))
self.split_heads = Rearrange('b n (h d) -> b h n d', h = heads)
self.merge_heads = Rearrange('b h n d -> b n (h d)')
self.to_out = NormLinear(dim_inner, dim, norm_dim_in = False)
def forward(
self,
x
):
q, k, v = self.to_q(x), self.to_k(x), self.to_v(x)
q, k, v = map(self.split_heads, (q, k, v))
# query key rmsnorm
q, k = map(l2norm, (q, k))
q = q * self.q_scale
k = k * self.k_scale
# scale is 1., as scaling factor is moved to s_qk (dk ^ 0.25) - eq. 16
out = F.scaled_dot_product_attention(
q, k, v,
dropout_p = self.dropout if self.training else 0.,
scale = 1.
)
out = self.merge_heads(out)
return self.to_out(out)
class FeedForward(Module):
def __init__(
self,
dim,
*,
dim_inner,
dropout = 0.
):
super().__init__()
dim_inner = int(dim_inner * 2 / 3)
self.dim = dim
self.dropout = nn.Dropout(dropout)
self.to_hidden = NormLinear(dim, dim_inner)
self.to_gate = NormLinear(dim, dim_inner)
self.hidden_scale = nn.Parameter(torch.ones(dim_inner))
self.gate_scale = nn.Parameter(torch.ones(dim_inner))
self.to_out = NormLinear(dim_inner, dim, norm_dim_in = False)
def forward(self, x):
hidden, gate = self.to_hidden(x), self.to_gate(x)
hidden = hidden * self.hidden_scale
gate = gate * self.gate_scale * (self.dim ** 0.5)
hidden = F.silu(gate) * hidden
hidden = self.dropout(hidden)
return self.to_out(hidden)
# classes
class nViT(Module):
""" https://arxiv.org/abs/2410.01131 """
def __init__(
self,
*,
image_size,
patch_size,
num_classes,
dim,
depth,
heads,
mlp_dim,
dropout = 0.,
channels = 3,
dim_head = 64,
residual_lerp_scale_init = None
):
super().__init__()
image_height, image_width = pair(image_size)
# calculate patching related stuff
assert divisible_by(image_height, patch_size) and divisible_by(image_width, patch_size), 'Image dimensions must be divisible by the patch size.'
patch_height_dim, patch_width_dim = (image_height // patch_size), (image_width // patch_size)
patch_dim = channels * (patch_size ** 2)
num_patches = patch_height_dim * patch_width_dim
self.channels = channels
self.patch_size = patch_size
self.to_patch_embedding = nn.Sequential(
Rearrange('b c (h p1) (w p2) -> b (h w) (c p1 p2)', p1 = patch_size, p2 = patch_size),
NormLinear(patch_dim, dim, norm_dim_in = False),
)
self.abs_pos_emb = NormLinear(dim, num_patches)
residual_lerp_scale_init = default(residual_lerp_scale_init, 1. / depth)
# layers
self.dim = dim
self.scale = dim ** 0.5
self.layers = ModuleList([])
self.residual_lerp_scales = nn.ParameterList([])
for _ in range(depth):
self.layers.append(ModuleList([
Attention(dim, dim_head = dim_head, heads = heads, dropout = dropout),
FeedForward(dim, dim_inner = mlp_dim, dropout = dropout),
]))
self.residual_lerp_scales.append(nn.ParameterList([
nn.Parameter(torch.ones(dim) * residual_lerp_scale_init / self.scale),
nn.Parameter(torch.ones(dim) * residual_lerp_scale_init / self.scale),
]))
self.logit_scale = nn.Parameter(torch.ones(num_classes))
self.to_pred = NormLinear(dim, num_classes)
@torch.no_grad()
def norm_weights_(self):
for module in self.modules():
if not isinstance(module, NormLinear):
continue
normed = module.weight
original = module.linear.parametrizations.weight.original
original.copy_(normed)
def forward(self, images):
device = images.device
tokens = self.to_patch_embedding(images)
seq_len = tokens.shape[-2]
pos_emb = self.abs_pos_emb.weight[torch.arange(seq_len, device = device)]
tokens = l2norm(tokens + pos_emb)
for (attn, ff), (attn_alpha, ff_alpha) in zip(self.layers, self.residual_lerp_scales):
attn_out = l2norm(attn(tokens))
tokens = l2norm(tokens.lerp(attn_out, attn_alpha * self.scale))
ff_out = l2norm(ff(tokens))
tokens = l2norm(tokens.lerp(ff_out, ff_alpha * self.scale))
pooled = reduce(tokens, 'b n d -> b d', 'mean')
logits = self.to_pred(pooled)
logits = logits * self.logit_scale * self.scale
return logits
# quick test
if __name__ == '__main__':
v = nViT(
image_size = 256,
patch_size = 16,
num_classes = 1000,
dim = 1024,
depth = 6,
heads = 8,
mlp_dim = 2048,
)
img = torch.randn(4, 3, 256, 256)
logits = v(img) # (4, 1000)
assert logits.shape == (4, 1000)

View File

@@ -20,6 +20,18 @@ def divisible_by(val, d):
# helper classes
class ChanLayerNorm(nn.Module):
def __init__(self, dim, eps = 1e-5):
super().__init__()
self.eps = eps
self.g = nn.Parameter(torch.ones(1, dim, 1, 1))
self.b = nn.Parameter(torch.zeros(1, dim, 1, 1))
def forward(self, x):
var = torch.var(x, dim = 1, unbiased = False, keepdim = True)
mean = torch.mean(x, dim = 1, keepdim = True)
return (x - mean) / (var + self.eps).sqrt() * self.g + self.b
class Downsample(nn.Module):
def __init__(self, dim_in, dim_out):
super().__init__()
@@ -212,10 +224,10 @@ class RegionViT(nn.Module):
if tokenize_local_3_conv:
self.local_encoder = nn.Sequential(
nn.Conv2d(3, init_dim, 3, 2, 1),
nn.LayerNorm(init_dim),
ChanLayerNorm(init_dim),
nn.GELU(),
nn.Conv2d(init_dim, init_dim, 3, 2, 1),
nn.LayerNorm(init_dim),
ChanLayerNorm(init_dim),
nn.GELU(),
nn.Conv2d(init_dim, init_dim, 3, 1, 1)
)

View File

@@ -3,14 +3,14 @@ from math import sqrt, pi, log
import torch
from torch import nn, einsum
import torch.nn.functional as F
from torch.cuda.amp import autocast
from torch.amp import autocast
from einops import rearrange, repeat
from einops.layers.torch import Rearrange
# rotary embeddings
@autocast(enabled = False)
@autocast('cuda', enabled = False)
def rotate_every_two(x):
x = rearrange(x, '... (d j) -> ... d j', j = 2)
x1, x2 = x.unbind(dim = -1)
@@ -24,7 +24,7 @@ class AxialRotaryEmbedding(nn.Module):
scales = torch.linspace(1., max_freq / 2, self.dim // 4)
self.register_buffer('scales', scales)
@autocast(enabled = False)
@autocast('cuda', enabled = False)
def forward(self, x):
device, dtype, n = x.device, x.dtype, int(sqrt(x.shape[-2]))

View File

@@ -0,0 +1,233 @@
"""
ViT + Hyper-Connections + Register Tokens
https://arxiv.org/abs/2409.19606
"""
import torch
from torch import nn, tensor
from torch.nn import Module, ModuleList
from einops import rearrange, repeat, reduce, einsum, pack, unpack
from einops.layers.torch import Rearrange
# b - batch, h - heads, n - sequence, e - expansion rate / residual streams, d - feature dimension
# helpers
def pair(t):
return t if isinstance(t, tuple) else (t, t)
def posemb_sincos_2d(h, w, dim, temperature: int = 10000, dtype = torch.float32):
y, x = torch.meshgrid(torch.arange(h), torch.arange(w), indexing="ij")
assert (dim % 4) == 0, "feature dimension must be multiple of 4 for sincos emb"
omega = torch.arange(dim // 4) / (dim // 4 - 1)
omega = 1.0 / (temperature ** omega)
y = y.flatten()[:, None] * omega[None, :]
x = x.flatten()[:, None] * omega[None, :]
pe = torch.cat((x.sin(), x.cos(), y.sin(), y.cos()), dim=1)
return pe.type(dtype)
# hyper connections
class HyperConnection(Module):
def __init__(
self,
dim,
num_residual_streams,
layer_index
):
""" Appendix J - Algorithm 2, Dynamic only """
super().__init__()
self.norm = nn.LayerNorm(dim, bias = False)
self.num_residual_streams = num_residual_streams
self.layer_index = layer_index
self.static_beta = nn.Parameter(torch.ones(num_residual_streams))
init_alpha0 = torch.zeros((num_residual_streams, 1))
init_alpha0[layer_index % num_residual_streams, 0] = 1.
self.static_alpha = nn.Parameter(torch.cat([init_alpha0, torch.eye(num_residual_streams)], dim = 1))
self.dynamic_alpha_fn = nn.Parameter(torch.zeros(dim, num_residual_streams + 1))
self.dynamic_alpha_scale = nn.Parameter(tensor(1e-2))
self.dynamic_beta_fn = nn.Parameter(torch.zeros(dim))
self.dynamic_beta_scale = nn.Parameter(tensor(1e-2))
def width_connection(self, residuals):
normed = self.norm(residuals)
wc_weight = (normed @ self.dynamic_alpha_fn).tanh()
dynamic_alpha = wc_weight * self.dynamic_alpha_scale
alpha = dynamic_alpha + self.static_alpha
dc_weight = (normed @ self.dynamic_beta_fn).tanh()
dynamic_beta = dc_weight * self.dynamic_beta_scale
beta = dynamic_beta + self.static_beta
# width connection
mix_h = einsum(alpha, residuals, '... e1 e2, ... e1 d -> ... e2 d')
branch_input, residuals = mix_h[..., 0, :], mix_h[..., 1:, :]
return branch_input, residuals, beta
def depth_connection(
self,
branch_output,
residuals,
beta
):
return einsum(branch_output, beta, "b n d, b n e -> b n e d") + residuals
# classes
class FeedForward(Module):
def __init__(self, dim, hidden_dim):
super().__init__()
self.net = nn.Sequential(
nn.LayerNorm(dim),
nn.Linear(dim, hidden_dim),
nn.GELU(),
nn.Linear(hidden_dim, dim),
)
def forward(self, x):
return self.net(x)
class Attention(Module):
def __init__(self, dim, heads = 8, dim_head = 64):
super().__init__()
inner_dim = dim_head * heads
self.heads = heads
self.scale = dim_head ** -0.5
self.norm = nn.LayerNorm(dim)
self.attend = nn.Softmax(dim = -1)
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
self.to_out = nn.Linear(inner_dim, dim, bias = False)
def forward(self, x):
x = self.norm(x)
qkv = self.to_qkv(x).chunk(3, dim = -1)
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)
dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
attn = self.attend(dots)
out = torch.matmul(attn, v)
out = rearrange(out, 'b h n d -> b n (h d)')
return self.to_out(out)
class Transformer(Module):
def __init__(self, dim, depth, heads, dim_head, mlp_dim, num_residual_streams):
super().__init__()
self.num_residual_streams = num_residual_streams
self.norm = nn.LayerNorm(dim)
self.layers = ModuleList([])
for layer_index in range(depth):
self.layers.append(nn.ModuleList([
HyperConnection(dim, num_residual_streams, layer_index),
Attention(dim, heads = heads, dim_head = dim_head),
HyperConnection(dim, num_residual_streams, layer_index),
FeedForward(dim, mlp_dim)
]))
def forward(self, x):
x = repeat(x, 'b n d -> b n e d', e = self.num_residual_streams)
for attn_hyper_conn, attn, ff_hyper_conn, ff in self.layers:
x, attn_res, beta = attn_hyper_conn.width_connection(x)
x = attn(x)
x = attn_hyper_conn.depth_connection(x, attn_res, beta)
x, ff_res, beta = ff_hyper_conn.width_connection(x)
x = ff(x)
x = ff_hyper_conn.depth_connection(x, ff_res, beta)
x = reduce(x, 'b n e d -> b n d', 'sum')
return self.norm(x)
class SimpleViT(nn.Module):
def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, num_residual_streams, num_register_tokens = 4, channels = 3, dim_head = 64):
super().__init__()
image_height, image_width = pair(image_size)
patch_height, patch_width = pair(patch_size)
assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'
patch_dim = channels * patch_height * patch_width
self.to_patch_embedding = nn.Sequential(
Rearrange("b c (h p1) (w p2) -> b (h w) (p1 p2 c)", p1 = patch_height, p2 = patch_width),
nn.LayerNorm(patch_dim),
nn.Linear(patch_dim, dim),
nn.LayerNorm(dim),
)
self.register_tokens = nn.Parameter(torch.randn(num_register_tokens, dim))
self.pos_embedding = posemb_sincos_2d(
h = image_height // patch_height,
w = image_width // patch_width,
dim = dim,
)
self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, num_residual_streams)
self.pool = "mean"
self.to_latent = nn.Identity()
self.linear_head = nn.Linear(dim, num_classes)
def forward(self, img):
batch, device = img.shape[0], img.device
x = self.to_patch_embedding(img)
x += self.pos_embedding.to(x)
r = repeat(self.register_tokens, 'n d -> b n d', b = batch)
x, ps = pack([x, r], 'b * d')
x = self.transformer(x)
x, _ = unpack(x, ps, 'b * d')
x = x.mean(dim = 1)
x = self.to_latent(x)
return self.linear_head(x)
# main
if __name__ == '__main__':
vit = SimpleViT(
num_classes = 1000,
image_size = 256,
patch_size = 8,
dim = 1024,
depth = 12,
heads = 8,
mlp_dim = 2048,
num_residual_streams = 8
)
images = torch.randn(3, 3, 256, 256)
logits = vit(images)

View File

@@ -0,0 +1,159 @@
import torch
from torch import nn
from torch.nn import Module, ModuleList
from einops import rearrange
from einops.layers.torch import Rearrange
# helpers
def exists(v):
return v is not None
def default(v, d):
return v if exists(v) else d
def pair(t):
return t if isinstance(t, tuple) else (t, t)
def posemb_sincos_2d(h, w, dim, temperature: int = 10000, dtype = torch.float32):
y, x = torch.meshgrid(torch.arange(h), torch.arange(w), indexing="ij")
assert (dim % 4) == 0, "feature dimension must be multiple of 4 for sincos emb"
omega = torch.arange(dim // 4) / (dim // 4 - 1)
omega = 1.0 / (temperature ** omega)
y = y.flatten()[:, None] * omega[None, :]
x = x.flatten()[:, None] * omega[None, :]
pe = torch.cat((x.sin(), x.cos(), y.sin(), y.cos()), dim=1)
return pe.type(dtype)
# classes
def FeedForward(dim, hidden_dim):
return nn.Sequential(
nn.LayerNorm(dim),
nn.Linear(dim, hidden_dim),
nn.GELU(),
nn.Linear(hidden_dim, dim),
)
class Attention(Module):
def __init__(self, dim, heads = 8, dim_head = 64, learned_value_residual_mix = False):
super().__init__()
inner_dim = dim_head * heads
self.heads = heads
self.scale = dim_head ** -0.5
self.norm = nn.LayerNorm(dim)
self.attend = nn.Softmax(dim = -1)
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
self.to_out = nn.Linear(inner_dim, dim, bias = False)
self.to_residual_mix = nn.Sequential(
nn.Linear(dim, heads),
nn.Sigmoid(),
Rearrange('b n h -> b h n 1')
) if learned_value_residual_mix else (lambda _: 0.5)
def forward(self, x, value_residual = None):
x = self.norm(x)
qkv = self.to_qkv(x).chunk(3, dim = -1)
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)
if exists(value_residual):
mix = self.to_residual_mix(x)
v = v * mix + value_residual * (1. - mix)
dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
attn = self.attend(dots)
out = torch.matmul(attn, v)
out = rearrange(out, 'b h n d -> b n (h d)')
return self.to_out(out), v
class Transformer(Module):
def __init__(self, dim, depth, heads, dim_head, mlp_dim):
super().__init__()
self.norm = nn.LayerNorm(dim)
self.layers = ModuleList([])
for i in range(depth):
is_first = i == 0
self.layers.append(ModuleList([
Attention(dim, heads = heads, dim_head = dim_head, learned_value_residual_mix = not is_first),
FeedForward(dim, mlp_dim)
]))
def forward(self, x):
value_residual = None
for attn, ff in self.layers:
attn_out, values = attn(x, value_residual = value_residual)
value_residual = default(value_residual, values)
x = attn_out + x
x = ff(x) + x
return self.norm(x)
class SimpleViT(Module):
def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, channels = 3, dim_head = 64):
super().__init__()
image_height, image_width = pair(image_size)
patch_height, patch_width = pair(patch_size)
assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'
patch_dim = channels * patch_height * patch_width
self.to_patch_embedding = nn.Sequential(
Rearrange("b c (h p1) (w p2) -> b (h w) (p1 p2 c)", p1 = patch_height, p2 = patch_width),
nn.LayerNorm(patch_dim),
nn.Linear(patch_dim, dim),
nn.LayerNorm(dim),
)
self.pos_embedding = posemb_sincos_2d(
h = image_height // patch_height,
w = image_width // patch_width,
dim = dim,
)
self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim)
self.pool = "mean"
self.to_latent = nn.Identity()
self.linear_head = nn.Linear(dim, num_classes)
def forward(self, img):
device = img.device
x = self.to_patch_embedding(img)
x += self.pos_embedding.to(device, dtype=x.dtype)
x = self.transformer(x)
x = x.mean(dim = 1)
x = self.to_latent(x)
return self.linear_head(x)
# quick test
if __name__ == '__main__':
v = SimpleViT(
num_classes = 1000,
image_size = 256,
patch_size = 8,
dim = 1024,
depth = 6,
heads = 8,
mlp_dim = 2048,
)
images = torch.randn(2, 3, 256, 256)
logits = v(images)

View File

@@ -78,6 +78,30 @@ class Transformer(nn.Module):
x = ff(x) + x
return self.norm(x)
class FactorizedTransformer(nn.Module):
def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
super().__init__()
self.norm = nn.LayerNorm(dim)
self.layers = nn.ModuleList([])
for _ in range(depth):
self.layers.append(nn.ModuleList([
Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout),
Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout),
FeedForward(dim, mlp_dim, dropout = dropout)
]))
def forward(self, x):
b, f, n, _ = x.shape
for spatial_attn, temporal_attn, ff in self.layers:
x = rearrange(x, 'b f n d -> (b f) n d')
x = spatial_attn(x) + x
x = rearrange(x, '(b f) n d -> (b n) f d', b=b, f=f)
x = temporal_attn(x) + x
x = ff(x) + x
x = rearrange(x, '(b n) f d -> b f n d', b=b, n=n)
return self.norm(x)
class ViT(nn.Module):
def __init__(
self,
@@ -96,7 +120,8 @@ class ViT(nn.Module):
channels = 3,
dim_head = 64,
dropout = 0.,
emb_dropout = 0.
emb_dropout = 0.,
variant = 'factorized_encoder',
):
super().__init__()
image_height, image_width = pair(image_size)
@@ -104,6 +129,7 @@ class ViT(nn.Module):
assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'
assert frames % frame_patch_size == 0, 'Frames must be divisible by frame patch size'
assert variant in ('factorized_encoder', 'factorized_self_attention'), f'variant = {variant} is not implemented'
num_image_patches = (image_height // patch_height) * (image_width // patch_width)
num_frame_patches = (frames // frame_patch_size)
@@ -125,15 +151,20 @@ class ViT(nn.Module):
self.dropout = nn.Dropout(emb_dropout)
self.spatial_cls_token = nn.Parameter(torch.randn(1, 1, dim)) if not self.global_average_pool else None
self.temporal_cls_token = nn.Parameter(torch.randn(1, 1, dim)) if not self.global_average_pool else None
self.spatial_transformer = Transformer(dim, spatial_depth, heads, dim_head, mlp_dim, dropout)
self.temporal_transformer = Transformer(dim, temporal_depth, heads, dim_head, mlp_dim, dropout)
if variant == 'factorized_encoder':
self.temporal_cls_token = nn.Parameter(torch.randn(1, 1, dim)) if not self.global_average_pool else None
self.spatial_transformer = Transformer(dim, spatial_depth, heads, dim_head, mlp_dim, dropout)
self.temporal_transformer = Transformer(dim, temporal_depth, heads, dim_head, mlp_dim, dropout)
elif variant == 'factorized_self_attention':
assert spatial_depth == temporal_depth, 'Spatial and temporal depth must be the same for factorized self-attention'
self.factorized_transformer = FactorizedTransformer(dim, spatial_depth, heads, dim_head, mlp_dim, dropout)
self.pool = pool
self.to_latent = nn.Identity()
self.mlp_head = nn.Linear(dim, num_classes)
self.variant = variant
def forward(self, video):
x = self.to_patch_embedding(video)
@@ -147,32 +178,37 @@ class ViT(nn.Module):
x = self.dropout(x)
x = rearrange(x, 'b f n d -> (b f) n d')
if self.variant == 'factorized_encoder':
x = rearrange(x, 'b f n d -> (b f) n d')
# attend across space
# attend across space
x = self.spatial_transformer(x)
x = self.spatial_transformer(x)
x = rearrange(x, '(b f) n d -> b f n d', b = b)
x = rearrange(x, '(b f) n d -> b f n d', b = b)
# excise out the spatial cls tokens or average pool for temporal attention
# excise out the spatial cls tokens or average pool for temporal attention
x = x[:, :, 0] if not self.global_average_pool else reduce(x, 'b f n d -> b f d', 'mean')
x = x[:, :, 0] if not self.global_average_pool else reduce(x, 'b f n d -> b f d', 'mean')
# append temporal CLS tokens
# append temporal CLS tokens
if exists(self.temporal_cls_token):
temporal_cls_tokens = repeat(self.temporal_cls_token, '1 1 d-> b 1 d', b = b)
if exists(self.temporal_cls_token):
temporal_cls_tokens = repeat(self.temporal_cls_token, '1 1 d-> b 1 d', b = b)
x = torch.cat((temporal_cls_tokens, x), dim = 1)
x = torch.cat((temporal_cls_tokens, x), dim = 1)
# attend across time
# attend across time
x = self.temporal_transformer(x)
x = self.temporal_transformer(x)
# excise out temporal cls token or average pool
# excise out temporal cls token or average pool
x = x[:, 0] if not self.global_average_pool else reduce(x, 'b f d -> b d', 'mean')
x = x[:, 0] if not self.global_average_pool else reduce(x, 'b f d -> b d', 'mean')
elif self.variant == 'factorized_self_attention':
x = self.factorized_transformer(x)
x = x[:, 0, 0] if not self.global_average_pool else reduce(x, 'b f n d -> b d', 'mean')
x = self.to_latent(x)
return self.mlp_head(x)