mirror of
https://github.com/lucidrains/vit-pytorch.git
synced 2025-12-30 08:02:29 +00:00
Compare commits
26 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
1de866d15d | ||
|
|
9f49a31977 | ||
|
|
ab63fc9cc8 | ||
|
|
c3018d1433 | ||
|
|
b7ed6bad28 | ||
|
|
e7cba9ba6d | ||
|
|
56373c0cbd | ||
|
|
24196a3e8a | ||
|
|
f6d7287b6b | ||
|
|
d47c57e32f | ||
|
|
0449865786 | ||
|
|
6693d47d0b | ||
|
|
141239ca86 | ||
|
|
0b5c9b4559 | ||
|
|
e300cdd7dc | ||
|
|
36ddc7a6ba | ||
|
|
1d1a63fc5c | ||
|
|
74b62009f8 | ||
|
|
f50d7d1436 | ||
|
|
82f2fa751d | ||
|
|
fcb9501cdd | ||
|
|
c4651a35a3 | ||
|
|
9d43e4d0bb | ||
|
|
5e808f48d1 | ||
|
|
bed48b5912 | ||
|
|
73199ab486 |
1
.github/workflows/python-test.yml
vendored
1
.github/workflows/python-test.yml
vendored
@@ -28,6 +28,7 @@ jobs:
|
||||
python -m pip install --upgrade pip
|
||||
python -m pip install pytest
|
||||
python -m pip install wheel
|
||||
python -m pip install torch==2.4.0 torchvision==0.19.0 --index-url https://download.pytorch.org/whl/cpu
|
||||
if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
|
||||
- name: Test with pytest
|
||||
run: |
|
||||
|
||||
86
README.md
86
README.md
@@ -198,6 +198,38 @@ preds = v(
|
||||
) # (5, 1000)
|
||||
```
|
||||
|
||||
Finally, if you would like to make use of a flavor of NaViT using <a href="https://pytorch.org/tutorials/prototype/nestedtensor.html">nested tensors</a> (which will omit a lot of the masking and padding altogether), make sure you are on version `2.5` and import as follows
|
||||
|
||||
```python
|
||||
import torch
|
||||
from vit_pytorch.na_vit_nested_tensor import NaViT
|
||||
|
||||
v = NaViT(
|
||||
image_size = 256,
|
||||
patch_size = 32,
|
||||
num_classes = 1000,
|
||||
dim = 1024,
|
||||
depth = 6,
|
||||
heads = 16,
|
||||
mlp_dim = 2048,
|
||||
dropout = 0.,
|
||||
emb_dropout = 0.,
|
||||
token_dropout_prob = 0.1
|
||||
)
|
||||
|
||||
# 5 images of different resolutions - List[Tensor]
|
||||
|
||||
images = [
|
||||
torch.randn(3, 256, 256), torch.randn(3, 128, 128),
|
||||
torch.randn(3, 128, 256), torch.randn(3, 256, 128),
|
||||
torch.randn(3, 64, 256)
|
||||
]
|
||||
|
||||
preds = v(images)
|
||||
|
||||
assert preds.shape == (5, 1000)
|
||||
```
|
||||
|
||||
## Distillation
|
||||
|
||||
<img src="./images/distill.png" width="300px"></img>
|
||||
@@ -1186,7 +1218,8 @@ pred = cct(video)
|
||||
|
||||
<img src="./images/vivit.png" width="350px"></img>
|
||||
|
||||
This <a href="https://arxiv.org/abs/2103.15691">paper</a> offers 3 different types of architectures for efficient attention of videos, with the main theme being factorizing the attention across space and time. This repository will offer the first variant, which is a spatial transformer followed by a temporal one.
|
||||
This <a href="https://arxiv.org/abs/2103.15691">paper</a> offers 3 different types of architectures for efficient attention of videos, with the main theme being factorizing the attention across space and time. This repository includes the factorized encoder and the factorized self-attention variant.
|
||||
The factorized encoder variant is a spatial transformer followed by a temporal one. The factorized self-attention variant is a spatio-temporal transformer with alternating spatial and temporal self-attention layers.
|
||||
|
||||
```python
|
||||
import torch
|
||||
@@ -1202,7 +1235,8 @@ v = ViT(
|
||||
spatial_depth = 6, # depth of the spatial transformer
|
||||
temporal_depth = 6, # depth of the temporal transformer
|
||||
heads = 8,
|
||||
mlp_dim = 2048
|
||||
mlp_dim = 2048,
|
||||
variant = 'factorized_encoder', # or 'factorized_self_attention'
|
||||
)
|
||||
|
||||
video = torch.randn(4, 3, 16, 128, 128) # (batch, channels, frames, height, width)
|
||||
@@ -2099,4 +2133,52 @@ Coming from computer vision and new to transformers? Here are some resources tha
|
||||
}
|
||||
```
|
||||
|
||||
```bibtex
|
||||
@inproceedings{Loshchilov2024nGPTNT,
|
||||
title = {nGPT: Normalized Transformer with Representation Learning on the Hypersphere},
|
||||
author = {Ilya Loshchilov and Cheng-Ping Hsieh and Simeng Sun and Boris Ginsburg},
|
||||
year = {2024},
|
||||
url = {https://api.semanticscholar.org/CorpusID:273026160}
|
||||
}
|
||||
```
|
||||
|
||||
```bibtex
|
||||
@inproceedings{Liu2017DeepHL,
|
||||
title = {Deep Hyperspherical Learning},
|
||||
author = {Weiyang Liu and Yanming Zhang and Xingguo Li and Zhen Liu and Bo Dai and Tuo Zhao and Le Song},
|
||||
booktitle = {Neural Information Processing Systems},
|
||||
year = {2017},
|
||||
url = {https://api.semanticscholar.org/CorpusID:5104558}
|
||||
}
|
||||
```
|
||||
|
||||
```bibtex
|
||||
@inproceedings{Zhou2024ValueRL,
|
||||
title = {Value Residual Learning For Alleviating Attention Concentration In Transformers},
|
||||
author = {Zhanchao Zhou and Tianyi Wu and Zhiyun Jiang and Zhenzhong Lan},
|
||||
year = {2024},
|
||||
url = {https://api.semanticscholar.org/CorpusID:273532030}
|
||||
}
|
||||
```
|
||||
|
||||
```bibtex
|
||||
@article{Zhu2024HyperConnections,
|
||||
title = {Hyper-Connections},
|
||||
author = {Defa Zhu and Hongzhi Huang and Zihao Huang and Yutao Zeng and Yunyao Mao and Banggu Wu and Qiyang Min and Xun Zhou},
|
||||
journal = {ArXiv},
|
||||
year = {2024},
|
||||
volume = {abs/2409.19606},
|
||||
url = {https://api.semanticscholar.org/CorpusID:272987528}
|
||||
}
|
||||
```
|
||||
|
||||
```bibtex
|
||||
@inproceedings{Fuller2025SimplerFV,
|
||||
title = {Simpler Fast Vision Transformers with a Jumbo CLS Token},
|
||||
author = {Anthony Fuller and Yousef Yassin and Daniel G. Kyrollos and Evan Shelhamer and James R. Green},
|
||||
year = {2025},
|
||||
url = {https://api.semanticscholar.org/CorpusID:276557720}
|
||||
}
|
||||
```
|
||||
|
||||
*I visualise a time when we will be to robots what dogs are to humans, and I’m rooting for the machines.* — Claude Shannon
|
||||
|
||||
8
setup.py
8
setup.py
@@ -6,10 +6,10 @@ with open('README.md') as f:
|
||||
setup(
|
||||
name = 'vit-pytorch',
|
||||
packages = find_packages(exclude=['examples']),
|
||||
version = '1.7.5',
|
||||
version = '1.10.1',
|
||||
license='MIT',
|
||||
description = 'Vision Transformer (ViT) - Pytorch',
|
||||
long_description=long_description,
|
||||
long_description = long_description,
|
||||
long_description_content_type = 'text/markdown',
|
||||
author = 'Phil Wang',
|
||||
author_email = 'lucidrains@gmail.com',
|
||||
@@ -29,8 +29,8 @@ setup(
|
||||
],
|
||||
tests_require=[
|
||||
'pytest',
|
||||
'torch==1.12.1',
|
||||
'torchvision==0.13.1'
|
||||
'torch==2.4.0',
|
||||
'torchvision==0.19.0'
|
||||
],
|
||||
classifiers=[
|
||||
'Development Status :: 4 - Beta',
|
||||
|
||||
@@ -167,8 +167,10 @@ class Tokenizer(nn.Module):
|
||||
stride,
|
||||
padding,
|
||||
frame_stride=1,
|
||||
frame_padding=None,
|
||||
frame_pooling_stride=1,
|
||||
frame_pooling_kernel_size=1,
|
||||
frame_pooling_padding=None,
|
||||
pooling_kernel_size=3,
|
||||
pooling_stride=2,
|
||||
pooling_padding=1,
|
||||
@@ -188,16 +190,22 @@ class Tokenizer(nn.Module):
|
||||
|
||||
n_filter_list_pairs = zip(n_filter_list[:-1], n_filter_list[1:])
|
||||
|
||||
if frame_padding is None:
|
||||
frame_padding = frame_kernel_size // 2
|
||||
|
||||
if frame_pooling_padding is None:
|
||||
frame_pooling_padding = frame_pooling_kernel_size // 2
|
||||
|
||||
self.conv_layers = nn.Sequential(
|
||||
*[nn.Sequential(
|
||||
nn.Conv3d(chan_in, chan_out,
|
||||
kernel_size=(frame_kernel_size, kernel_size, kernel_size),
|
||||
stride=(frame_stride, stride, stride),
|
||||
padding=(frame_kernel_size // 2, padding, padding), bias=conv_bias),
|
||||
padding=(frame_padding, padding, padding), bias=conv_bias),
|
||||
nn.Identity() if not exists(activation) else activation(),
|
||||
nn.MaxPool3d(kernel_size=(frame_pooling_kernel_size, pooling_kernel_size, pooling_kernel_size),
|
||||
stride=(frame_pooling_stride, pooling_stride, pooling_stride),
|
||||
padding=(frame_pooling_kernel_size // 2, pooling_padding, pooling_padding)) if max_pool else nn.Identity()
|
||||
padding=(frame_pooling_padding, pooling_padding, pooling_padding)) if max_pool else nn.Identity()
|
||||
)
|
||||
for chan_in, chan_out in n_filter_list_pairs
|
||||
])
|
||||
@@ -324,8 +332,10 @@ class CCT(nn.Module):
|
||||
n_conv_layers=1,
|
||||
frame_stride=1,
|
||||
frame_kernel_size=3,
|
||||
frame_padding=None,
|
||||
frame_pooling_kernel_size=1,
|
||||
frame_pooling_stride=1,
|
||||
frame_pooling_padding=None,
|
||||
kernel_size=7,
|
||||
stride=2,
|
||||
padding=3,
|
||||
@@ -342,8 +352,10 @@ class CCT(nn.Module):
|
||||
n_output_channels=embedding_dim,
|
||||
frame_stride=frame_stride,
|
||||
frame_kernel_size=frame_kernel_size,
|
||||
frame_padding=frame_padding,
|
||||
frame_pooling_stride=frame_pooling_stride,
|
||||
frame_pooling_kernel_size=frame_pooling_kernel_size,
|
||||
frame_pooling_padding=frame_pooling_padding,
|
||||
kernel_size=kernel_size,
|
||||
stride=stride,
|
||||
padding=padding,
|
||||
|
||||
201
vit_pytorch/jumbo_vit.py
Normal file
201
vit_pytorch/jumbo_vit.py
Normal file
@@ -0,0 +1,201 @@
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.nn import Module, ModuleList
|
||||
|
||||
from einops import rearrange, repeat, reduce, pack, unpack
|
||||
from einops.layers.torch import Rearrange
|
||||
|
||||
# helpers
|
||||
|
||||
def pair(t):
|
||||
return t if isinstance(t, tuple) else (t, t)
|
||||
|
||||
def divisible_by(num, den):
|
||||
return (num % den) == 0
|
||||
|
||||
def posemb_sincos_2d(h, w, dim, temperature: int = 10000, dtype = torch.float32):
|
||||
y, x = torch.meshgrid(torch.arange(h), torch.arange(w), indexing="ij")
|
||||
assert divisible_by(dim, 4), "feature dimension must be multiple of 4 for sincos emb"
|
||||
|
||||
omega = torch.arange(dim // 4) / (dim // 4 - 1)
|
||||
omega = temperature ** -omega
|
||||
|
||||
y = y.flatten()[:, None] * omega[None, :]
|
||||
x = x.flatten()[:, None] * omega[None, :]
|
||||
pos_emb = torch.cat((x.sin(), x.cos(), y.sin(), y.cos()), dim=1)
|
||||
|
||||
return pos_emb.type(dtype)
|
||||
|
||||
# classes
|
||||
|
||||
def FeedForward(dim, mult = 4.):
|
||||
hidden_dim = int(dim * mult)
|
||||
return nn.Sequential(
|
||||
nn.LayerNorm(dim),
|
||||
nn.Linear(dim, hidden_dim),
|
||||
nn.GELU(),
|
||||
nn.Linear(hidden_dim, dim),
|
||||
)
|
||||
|
||||
class Attention(Module):
|
||||
def __init__(self, dim, heads = 8, dim_head = 64):
|
||||
super().__init__()
|
||||
inner_dim = dim_head * heads
|
||||
self.heads = heads
|
||||
self.scale = dim_head ** -0.5
|
||||
self.norm = nn.LayerNorm(dim)
|
||||
|
||||
self.attend = nn.Softmax(dim = -1)
|
||||
|
||||
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
|
||||
self.to_out = nn.Linear(inner_dim, dim, bias = False)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.norm(x)
|
||||
|
||||
qkv = self.to_qkv(x).chunk(3, dim = -1)
|
||||
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)
|
||||
|
||||
dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
|
||||
|
||||
attn = self.attend(dots)
|
||||
|
||||
out = torch.matmul(attn, v)
|
||||
out = rearrange(out, 'b h n d -> b n (h d)')
|
||||
return self.to_out(out)
|
||||
|
||||
class JumboViT(Module):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
image_size,
|
||||
patch_size,
|
||||
num_classes,
|
||||
dim,
|
||||
depth,
|
||||
heads,
|
||||
mlp_dim,
|
||||
num_jumbo_cls = 1, # differing from paper, allow for multiple jumbo cls, so one could break it up into 2 jumbo cls tokens with 3x the dim, as an example
|
||||
jumbo_cls_k = 6, # they use a CLS token with this factor times the dimension - 6 was the value they settled on
|
||||
jumbo_ff_mult = 2, # expansion factor of the jumbo cls token feedforward
|
||||
channels = 3,
|
||||
dim_head = 64
|
||||
):
|
||||
super().__init__()
|
||||
image_height, image_width = pair(image_size)
|
||||
patch_height, patch_width = pair(patch_size)
|
||||
|
||||
assert divisible_by(image_height, patch_height) and divisible_by(image_width, patch_width), 'Image dimensions must be divisible by the patch size.'
|
||||
|
||||
patch_dim = channels * patch_height * patch_width
|
||||
|
||||
self.to_patch_embedding = nn.Sequential(
|
||||
Rearrange("b c (h p1) (w p2) -> b (h w) (p1 p2 c)", p1 = patch_height, p2 = patch_width),
|
||||
nn.LayerNorm(patch_dim),
|
||||
nn.Linear(patch_dim, dim),
|
||||
nn.LayerNorm(dim),
|
||||
)
|
||||
|
||||
self.pos_embedding = posemb_sincos_2d(
|
||||
h = image_height // patch_height,
|
||||
w = image_width // patch_width,
|
||||
dim = dim,
|
||||
)
|
||||
|
||||
jumbo_cls_dim = dim * jumbo_cls_k
|
||||
|
||||
self.jumbo_cls_token = nn.Parameter(torch.zeros(num_jumbo_cls, jumbo_cls_dim))
|
||||
|
||||
jumbo_cls_to_tokens = Rearrange('b n (k d) -> b (n k) d', k = jumbo_cls_k)
|
||||
self.jumbo_cls_to_tokens = jumbo_cls_to_tokens
|
||||
|
||||
self.norm = nn.LayerNorm(dim)
|
||||
self.layers = ModuleList([])
|
||||
|
||||
# attention and feedforwards
|
||||
|
||||
self.jumbo_ff = nn.Sequential(
|
||||
Rearrange('b (n k) d -> b n (k d)', k = jumbo_cls_k),
|
||||
FeedForward(jumbo_cls_dim, int(jumbo_cls_dim * jumbo_ff_mult)), # they use separate parameters for the jumbo feedforward, weight tied for parameter efficient
|
||||
jumbo_cls_to_tokens
|
||||
)
|
||||
|
||||
for _ in range(depth):
|
||||
self.layers.append(ModuleList([
|
||||
Attention(dim, heads = heads, dim_head = dim_head),
|
||||
FeedForward(dim, mlp_dim),
|
||||
]))
|
||||
|
||||
self.to_latent = nn.Identity()
|
||||
|
||||
self.linear_head = nn.Linear(dim, num_classes)
|
||||
|
||||
def forward(self, img):
|
||||
|
||||
batch, device = img.shape[0], img.device
|
||||
|
||||
x = self.to_patch_embedding(img)
|
||||
|
||||
# pos embedding
|
||||
|
||||
pos_emb = self.pos_embedding.to(device, dtype = x.dtype)
|
||||
|
||||
x = x + pos_emb
|
||||
|
||||
# add cls tokens
|
||||
|
||||
cls_tokens = repeat(self.jumbo_cls_token, 'nj d -> b nj d', b = batch)
|
||||
|
||||
jumbo_tokens = self.jumbo_cls_to_tokens(cls_tokens)
|
||||
|
||||
x, cls_packed_shape = pack([jumbo_tokens, x], 'b * d')
|
||||
|
||||
# attention and feedforwards
|
||||
|
||||
for layer, (attn, ff) in enumerate(self.layers, start = 1):
|
||||
is_last = layer == len(self.layers)
|
||||
|
||||
x = attn(x) + x
|
||||
|
||||
# jumbo feedforward
|
||||
|
||||
jumbo_cls_tokens, x = unpack(x, cls_packed_shape, 'b * d')
|
||||
|
||||
x = ff(x) + x
|
||||
jumbo_cls_tokens = self.jumbo_ff(jumbo_cls_tokens) + jumbo_cls_tokens
|
||||
|
||||
if is_last:
|
||||
continue
|
||||
|
||||
x, _ = pack([jumbo_cls_tokens, x], 'b * d')
|
||||
|
||||
pooled = reduce(jumbo_cls_tokens, 'b n d -> b d', 'mean')
|
||||
|
||||
# normalization and project to logits
|
||||
|
||||
embed = self.norm(pooled)
|
||||
|
||||
embed = self.to_latent(embed)
|
||||
logits = self.linear_head(embed)
|
||||
return logits
|
||||
|
||||
# copy pasteable file
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
v = JumboViT(
|
||||
num_classes = 1000,
|
||||
image_size = 64,
|
||||
patch_size = 8,
|
||||
dim = 16,
|
||||
depth = 2,
|
||||
heads = 2,
|
||||
mlp_dim = 32,
|
||||
jumbo_cls_k = 3,
|
||||
jumbo_ff_mult = 2,
|
||||
)
|
||||
|
||||
images = torch.randn(1, 3, 64, 64)
|
||||
|
||||
logits = v(images)
|
||||
assert logits.shape == (1, 1000)
|
||||
@@ -1,5 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from functools import partial
|
||||
from typing import List, Union
|
||||
from typing import List
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
@@ -245,7 +247,7 @@ class NaViT(nn.Module):
|
||||
|
||||
def forward(
|
||||
self,
|
||||
batched_images: Union[List[Tensor], List[List[Tensor]]], # assume different resolution images already grouped correctly
|
||||
batched_images: List[Tensor] | List[List[Tensor]], # assume different resolution images already grouped correctly
|
||||
group_images = False,
|
||||
group_max_seq_len = 2048
|
||||
):
|
||||
@@ -264,6 +266,11 @@ class NaViT(nn.Module):
|
||||
max_seq_len = group_max_seq_len
|
||||
)
|
||||
|
||||
# if List[Tensor] is not grouped -> List[List[Tensor]]
|
||||
|
||||
if torch.is_tensor(batched_images[0]):
|
||||
batched_images = [batched_images]
|
||||
|
||||
# process images into variable lengthed sequences with attention mask
|
||||
|
||||
num_images = []
|
||||
|
||||
330
vit_pytorch/na_vit_nested_tensor.py
Normal file
330
vit_pytorch/na_vit_nested_tensor.py
Normal file
@@ -0,0 +1,330 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import List
|
||||
from functools import partial
|
||||
|
||||
import torch
|
||||
import packaging.version as pkg_version
|
||||
|
||||
from torch import nn, Tensor
|
||||
import torch.nn.functional as F
|
||||
from torch.nn import Module, ModuleList
|
||||
from torch.nested import nested_tensor
|
||||
|
||||
from einops import rearrange
|
||||
from einops.layers.torch import Rearrange
|
||||
|
||||
# helpers
|
||||
|
||||
def exists(val):
|
||||
return val is not None
|
||||
|
||||
def default(val, d):
|
||||
return val if exists(val) else d
|
||||
|
||||
def pair(t):
|
||||
return t if isinstance(t, tuple) else (t, t)
|
||||
|
||||
def divisible_by(numer, denom):
|
||||
return (numer % denom) == 0
|
||||
|
||||
# feedforward
|
||||
|
||||
def FeedForward(dim, hidden_dim, dropout = 0.):
|
||||
return nn.Sequential(
|
||||
nn.LayerNorm(dim, bias = False),
|
||||
nn.Linear(dim, hidden_dim),
|
||||
nn.GELU(),
|
||||
nn.Dropout(dropout),
|
||||
nn.Linear(hidden_dim, dim),
|
||||
nn.Dropout(dropout)
|
||||
)
|
||||
|
||||
class Attention(Module):
|
||||
def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0., qk_norm = True):
|
||||
super().__init__()
|
||||
self.norm = nn.LayerNorm(dim, bias = False)
|
||||
|
||||
dim_inner = heads * dim_head
|
||||
self.heads = heads
|
||||
self.dim_head = dim_head
|
||||
|
||||
self.to_queries = nn.Linear(dim, dim_inner, bias = False)
|
||||
self.to_keys = nn.Linear(dim, dim_inner, bias = False)
|
||||
self.to_values = nn.Linear(dim, dim_inner, bias = False)
|
||||
|
||||
# in the paper, they employ qk rmsnorm, a way to stabilize attention
|
||||
# will use layernorm in place of rmsnorm, which has been shown to work in certain papers. requires l2norm on non-ragged dimension to be supported in nested tensors
|
||||
|
||||
self.query_norm = nn.LayerNorm(dim_head, bias = False) if qk_norm else nn.Identity()
|
||||
self.key_norm = nn.LayerNorm(dim_head, bias = False) if qk_norm else nn.Identity()
|
||||
|
||||
self.dropout = dropout
|
||||
|
||||
self.to_out = nn.Linear(dim_inner, dim, bias = False)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x,
|
||||
context: Tensor | None = None
|
||||
):
|
||||
x = self.norm(x)
|
||||
|
||||
# for attention pooling, one query pooling to entire sequence
|
||||
|
||||
context = default(context, x)
|
||||
|
||||
# queries, keys, values
|
||||
|
||||
query = self.to_queries(x)
|
||||
key = self.to_keys(context)
|
||||
value = self.to_values(context)
|
||||
|
||||
# split heads
|
||||
|
||||
def split_heads(t):
|
||||
return t.unflatten(-1, (self.heads, self.dim_head))
|
||||
|
||||
def transpose_head_seq(t):
|
||||
return t.transpose(1, 2)
|
||||
|
||||
query, key, value = map(split_heads, (query, key, value))
|
||||
|
||||
# qk norm for attention stability
|
||||
|
||||
query = self.query_norm(query)
|
||||
key = self.key_norm(key)
|
||||
|
||||
query, key, value = map(transpose_head_seq, (query, key, value))
|
||||
|
||||
# attention
|
||||
|
||||
out = F.scaled_dot_product_attention(
|
||||
query, key, value,
|
||||
dropout_p = self.dropout if self.training else 0.
|
||||
)
|
||||
|
||||
# merge heads
|
||||
|
||||
out = out.transpose(1, 2).flatten(-2)
|
||||
|
||||
return self.to_out(out)
|
||||
|
||||
class Transformer(Module):
|
||||
def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0., qk_norm = True):
|
||||
super().__init__()
|
||||
self.layers = ModuleList([])
|
||||
|
||||
for _ in range(depth):
|
||||
self.layers.append(ModuleList([
|
||||
Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout, qk_norm = qk_norm),
|
||||
FeedForward(dim, mlp_dim, dropout = dropout)
|
||||
]))
|
||||
|
||||
self.norm = nn.LayerNorm(dim, bias = False)
|
||||
|
||||
def forward(self, x):
|
||||
|
||||
for attn, ff in self.layers:
|
||||
x = attn(x) + x
|
||||
x = ff(x) + x
|
||||
|
||||
return self.norm(x)
|
||||
|
||||
class NaViT(Module):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
image_size,
|
||||
patch_size,
|
||||
num_classes,
|
||||
dim,
|
||||
depth,
|
||||
heads,
|
||||
mlp_dim,
|
||||
channels = 3,
|
||||
dim_head = 64,
|
||||
dropout = 0.,
|
||||
emb_dropout = 0.,
|
||||
qk_rmsnorm = True,
|
||||
token_dropout_prob: float | None = None
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
if pkg_version.parse(torch.__version__) < pkg_version.parse('2.5'):
|
||||
print('nested tensor NaViT was tested on pytorch 2.5')
|
||||
|
||||
|
||||
image_height, image_width = pair(image_size)
|
||||
|
||||
# what percent of tokens to dropout
|
||||
# if int or float given, then assume constant dropout prob
|
||||
# otherwise accept a callback that in turn calculates dropout prob from height and width
|
||||
|
||||
self.token_dropout_prob = token_dropout_prob
|
||||
|
||||
# calculate patching related stuff
|
||||
|
||||
assert divisible_by(image_height, patch_size) and divisible_by(image_width, patch_size), 'Image dimensions must be divisible by the patch size.'
|
||||
|
||||
patch_height_dim, patch_width_dim = (image_height // patch_size), (image_width // patch_size)
|
||||
patch_dim = channels * (patch_size ** 2)
|
||||
|
||||
self.channels = channels
|
||||
self.patch_size = patch_size
|
||||
self.to_patches = Rearrange('c (h p1) (w p2) -> h w (c p1 p2)', p1 = patch_size, p2 = patch_size)
|
||||
|
||||
self.to_patch_embedding = nn.Sequential(
|
||||
nn.LayerNorm(patch_dim),
|
||||
nn.Linear(patch_dim, dim),
|
||||
nn.LayerNorm(dim),
|
||||
)
|
||||
|
||||
self.pos_embed_height = nn.Parameter(torch.randn(patch_height_dim, dim))
|
||||
self.pos_embed_width = nn.Parameter(torch.randn(patch_width_dim, dim))
|
||||
|
||||
self.dropout = nn.Dropout(emb_dropout)
|
||||
|
||||
self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout, qk_rmsnorm)
|
||||
|
||||
# final attention pooling queries
|
||||
|
||||
self.attn_pool_queries = nn.Parameter(torch.randn(dim))
|
||||
self.attn_pool = Attention(dim = dim, dim_head = dim_head, heads = heads)
|
||||
|
||||
# output to logits
|
||||
|
||||
self.to_latent = nn.Identity()
|
||||
|
||||
self.mlp_head = nn.Sequential(
|
||||
nn.LayerNorm(dim, bias = False),
|
||||
nn.Linear(dim, num_classes, bias = False)
|
||||
)
|
||||
|
||||
@property
|
||||
def device(self):
|
||||
return next(self.parameters()).device
|
||||
|
||||
def forward(
|
||||
self,
|
||||
images: List[Tensor], # different resolution images
|
||||
):
|
||||
batch, device = len(images), self.device
|
||||
arange = partial(torch.arange, device = device)
|
||||
|
||||
assert all([image.ndim == 3 and image.shape[0] == self.channels for image in images]), f'all images must have {self.channels} channels and number of dimensions of 3 (channels, height, width)'
|
||||
|
||||
all_patches = [self.to_patches(image) for image in images]
|
||||
|
||||
# prepare factorized positional embedding height width indices
|
||||
|
||||
positions = []
|
||||
|
||||
for patches in all_patches:
|
||||
patch_height, patch_width = patches.shape[:2]
|
||||
hw_indices = torch.stack(torch.meshgrid((arange(patch_height), arange(patch_width)), indexing = 'ij'), dim = -1)
|
||||
hw_indices = rearrange(hw_indices, 'h w c -> (h w) c')
|
||||
positions.append(hw_indices)
|
||||
|
||||
# need the sizes to compute token dropout + positional embedding
|
||||
|
||||
tokens = [rearrange(patches, 'h w d -> (h w) d') for patches in all_patches]
|
||||
|
||||
# handle token dropout
|
||||
|
||||
seq_lens = torch.tensor([i.shape[0] for i in tokens], device = device)
|
||||
|
||||
if self.training and self.token_dropout_prob > 0:
|
||||
|
||||
keep_seq_lens = ((1. - self.token_dropout_prob) * seq_lens).int().clamp(min = 1)
|
||||
|
||||
kept_tokens = []
|
||||
kept_positions = []
|
||||
|
||||
for one_image_tokens, one_image_positions, seq_len, num_keep in zip(tokens, positions, seq_lens, keep_seq_lens):
|
||||
keep_indices = torch.randn((seq_len,), device = device).topk(num_keep, dim = -1).indices
|
||||
|
||||
one_image_kept_tokens = one_image_tokens[keep_indices]
|
||||
one_image_kept_positions = one_image_positions[keep_indices]
|
||||
|
||||
kept_tokens.append(one_image_kept_tokens)
|
||||
kept_positions.append(one_image_kept_positions)
|
||||
|
||||
tokens, positions, seq_lens = kept_tokens, kept_positions, keep_seq_lens
|
||||
|
||||
# add all height and width factorized positions
|
||||
|
||||
height_indices, width_indices = torch.cat(positions).unbind(dim = -1)
|
||||
height_embed, width_embed = self.pos_embed_height[height_indices], self.pos_embed_width[width_indices]
|
||||
|
||||
pos_embed = height_embed + width_embed
|
||||
|
||||
# use nested tensor for transformers and save on padding computation
|
||||
|
||||
tokens = torch.cat(tokens)
|
||||
|
||||
# linear projection to patch embeddings
|
||||
|
||||
tokens = self.to_patch_embedding(tokens)
|
||||
|
||||
# absolute positions
|
||||
|
||||
tokens = tokens + pos_embed
|
||||
|
||||
tokens = nested_tensor(tokens.split(seq_lens.tolist()), layout = torch.jagged, device = device)
|
||||
|
||||
# embedding dropout
|
||||
|
||||
tokens = self.dropout(tokens)
|
||||
|
||||
# transformer
|
||||
|
||||
tokens = self.transformer(tokens)
|
||||
|
||||
# attention pooling
|
||||
# will use a jagged tensor for queries, as SDPA requires all inputs to be jagged, or not
|
||||
|
||||
attn_pool_queries = [rearrange(self.attn_pool_queries, '... -> 1 ...')] * batch
|
||||
|
||||
attn_pool_queries = nested_tensor(attn_pool_queries, layout = torch.jagged)
|
||||
|
||||
pooled = self.attn_pool(attn_pool_queries, tokens)
|
||||
|
||||
# back to unjagged
|
||||
|
||||
logits = torch.stack(pooled.unbind())
|
||||
|
||||
logits = rearrange(logits, 'b 1 d -> b d')
|
||||
|
||||
logits = self.to_latent(logits)
|
||||
|
||||
return self.mlp_head(logits)
|
||||
|
||||
# quick test
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
v = NaViT(
|
||||
image_size = 256,
|
||||
patch_size = 32,
|
||||
num_classes = 1000,
|
||||
dim = 1024,
|
||||
depth = 6,
|
||||
heads = 16,
|
||||
mlp_dim = 2048,
|
||||
dropout = 0.,
|
||||
emb_dropout = 0.,
|
||||
token_dropout_prob = 0.1
|
||||
)
|
||||
|
||||
# 5 images of different resolutions - List[Tensor]
|
||||
|
||||
images = [
|
||||
torch.randn(3, 256, 256), torch.randn(3, 128, 128),
|
||||
torch.randn(3, 128, 256), torch.randn(3, 256, 128),
|
||||
torch.randn(3, 64, 256)
|
||||
]
|
||||
|
||||
assert v(images).shape == (5, 1000)
|
||||
|
||||
v(images).sum().backward()
|
||||
356
vit_pytorch/na_vit_nested_tensor_3d.py
Normal file
356
vit_pytorch/na_vit_nested_tensor_3d.py
Normal file
@@ -0,0 +1,356 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import List
|
||||
from functools import partial
|
||||
|
||||
import torch
|
||||
import packaging.version as pkg_version
|
||||
|
||||
from torch import nn, Tensor
|
||||
import torch.nn.functional as F
|
||||
from torch.nn import Module, ModuleList
|
||||
from torch.nested import nested_tensor
|
||||
|
||||
from einops import rearrange
|
||||
from einops.layers.torch import Rearrange
|
||||
|
||||
# helpers
|
||||
|
||||
def exists(val):
|
||||
return val is not None
|
||||
|
||||
def default(val, d):
|
||||
return val if exists(val) else d
|
||||
|
||||
def pair(t):
|
||||
return t if isinstance(t, tuple) else (t, t)
|
||||
|
||||
def divisible_by(numer, denom):
|
||||
return (numer % denom) == 0
|
||||
|
||||
# feedforward
|
||||
|
||||
def FeedForward(dim, hidden_dim, dropout = 0.):
|
||||
return nn.Sequential(
|
||||
nn.LayerNorm(dim, bias = False),
|
||||
nn.Linear(dim, hidden_dim),
|
||||
nn.GELU(),
|
||||
nn.Dropout(dropout),
|
||||
nn.Linear(hidden_dim, dim),
|
||||
nn.Dropout(dropout)
|
||||
)
|
||||
|
||||
class Attention(Module):
|
||||
def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0., qk_norm = True):
|
||||
super().__init__()
|
||||
self.norm = nn.LayerNorm(dim, bias = False)
|
||||
|
||||
dim_inner = heads * dim_head
|
||||
self.heads = heads
|
||||
self.dim_head = dim_head
|
||||
|
||||
self.to_queries = nn.Linear(dim, dim_inner, bias = False)
|
||||
self.to_keys = nn.Linear(dim, dim_inner, bias = False)
|
||||
self.to_values = nn.Linear(dim, dim_inner, bias = False)
|
||||
|
||||
# in the paper, they employ qk rmsnorm, a way to stabilize attention
|
||||
# will use layernorm in place of rmsnorm, which has been shown to work in certain papers. requires l2norm on non-ragged dimension to be supported in nested tensors
|
||||
|
||||
self.query_norm = nn.LayerNorm(dim_head, bias = False) if qk_norm else nn.Identity()
|
||||
self.key_norm = nn.LayerNorm(dim_head, bias = False) if qk_norm else nn.Identity()
|
||||
|
||||
self.dropout = dropout
|
||||
|
||||
self.to_out = nn.Linear(dim_inner, dim, bias = False)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x,
|
||||
context: Tensor | None = None
|
||||
):
|
||||
|
||||
x = self.norm(x)
|
||||
|
||||
# for attention pooling, one query pooling to entire sequence
|
||||
|
||||
context = default(context, x)
|
||||
|
||||
# queries, keys, values
|
||||
|
||||
query = self.to_queries(x)
|
||||
key = self.to_keys(context)
|
||||
value = self.to_values(context)
|
||||
|
||||
# split heads
|
||||
|
||||
def split_heads(t):
|
||||
return t.unflatten(-1, (self.heads, self.dim_head))
|
||||
|
||||
def transpose_head_seq(t):
|
||||
return t.transpose(1, 2)
|
||||
|
||||
query, key, value = map(split_heads, (query, key, value))
|
||||
|
||||
# qk norm for attention stability
|
||||
|
||||
query = self.query_norm(query)
|
||||
key = self.key_norm(key)
|
||||
|
||||
query, key, value = map(transpose_head_seq, (query, key, value))
|
||||
|
||||
# attention
|
||||
|
||||
out = F.scaled_dot_product_attention(
|
||||
query, key, value,
|
||||
dropout_p = self.dropout if self.training else 0.
|
||||
)
|
||||
|
||||
# merge heads
|
||||
|
||||
out = out.transpose(1, 2).flatten(-2)
|
||||
|
||||
return self.to_out(out)
|
||||
|
||||
class Transformer(Module):
|
||||
def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0., qk_norm = True):
|
||||
super().__init__()
|
||||
self.layers = ModuleList([])
|
||||
|
||||
for _ in range(depth):
|
||||
self.layers.append(ModuleList([
|
||||
Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout, qk_norm = qk_norm),
|
||||
FeedForward(dim, mlp_dim, dropout = dropout)
|
||||
]))
|
||||
|
||||
self.norm = nn.LayerNorm(dim, bias = False)
|
||||
|
||||
def forward(self, x):
|
||||
|
||||
for attn, ff in self.layers:
|
||||
x = attn(x) + x
|
||||
x = ff(x) + x
|
||||
|
||||
return self.norm(x)
|
||||
|
||||
class NaViT(Module):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
image_size,
|
||||
max_frames,
|
||||
patch_size,
|
||||
frame_patch_size,
|
||||
num_classes,
|
||||
dim,
|
||||
depth,
|
||||
heads,
|
||||
mlp_dim,
|
||||
channels = 3,
|
||||
dim_head = 64,
|
||||
dropout = 0.,
|
||||
emb_dropout = 0.,
|
||||
num_registers = 4,
|
||||
qk_rmsnorm = True,
|
||||
token_dropout_prob: float | None = None
|
||||
):
|
||||
super().__init__()
|
||||
image_height, image_width = pair(image_size)
|
||||
|
||||
if pkg_version.parse(torch.__version__) < pkg_version.parse('2.5'):
|
||||
print('nested tensor NaViT was tested on pytorch 2.5')
|
||||
|
||||
# what percent of tokens to dropout
|
||||
# if int or float given, then assume constant dropout prob
|
||||
# otherwise accept a callback that in turn calculates dropout prob from height and width
|
||||
|
||||
self.token_dropout_prob = token_dropout_prob
|
||||
|
||||
# calculate patching related stuff
|
||||
|
||||
assert divisible_by(image_height, patch_size) and divisible_by(image_width, patch_size), 'Image dimensions must be divisible by the patch size.'
|
||||
assert divisible_by(max_frames, frame_patch_size)
|
||||
|
||||
patch_frame_dim, patch_height_dim, patch_width_dim = (max_frames // frame_patch_size), (image_height // patch_size), (image_width // patch_size)
|
||||
|
||||
patch_dim = channels * (patch_size ** 2) * frame_patch_size
|
||||
|
||||
self.channels = channels
|
||||
self.patch_size = patch_size
|
||||
self.to_patches = Rearrange('c (f pf) (h p1) (w p2) -> f h w (c p1 p2 pf)', p1 = patch_size, p2 = patch_size, pf = frame_patch_size)
|
||||
|
||||
self.to_patch_embedding = nn.Sequential(
|
||||
nn.LayerNorm(patch_dim),
|
||||
nn.Linear(patch_dim, dim),
|
||||
nn.LayerNorm(dim),
|
||||
)
|
||||
|
||||
self.pos_embed_frame = nn.Parameter(torch.zeros(patch_frame_dim, dim))
|
||||
self.pos_embed_height = nn.Parameter(torch.zeros(patch_height_dim, dim))
|
||||
self.pos_embed_width = nn.Parameter(torch.zeros(patch_width_dim, dim))
|
||||
|
||||
# register tokens
|
||||
|
||||
self.register_tokens = nn.Parameter(torch.zeros(num_registers, dim))
|
||||
|
||||
nn.init.normal_(self.pos_embed_frame, std = 0.02)
|
||||
nn.init.normal_(self.pos_embed_height, std = 0.02)
|
||||
nn.init.normal_(self.pos_embed_width, std = 0.02)
|
||||
nn.init.normal_(self.register_tokens, std = 0.02)
|
||||
|
||||
self.dropout = nn.Dropout(emb_dropout)
|
||||
|
||||
self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout, qk_rmsnorm)
|
||||
|
||||
# final attention pooling queries
|
||||
|
||||
self.attn_pool_queries = nn.Parameter(torch.randn(dim))
|
||||
self.attn_pool = Attention(dim = dim, dim_head = dim_head, heads = heads)
|
||||
|
||||
# output to logits
|
||||
|
||||
self.to_latent = nn.Identity()
|
||||
|
||||
self.mlp_head = nn.Sequential(
|
||||
nn.LayerNorm(dim, bias = False),
|
||||
nn.Linear(dim, num_classes, bias = False)
|
||||
)
|
||||
|
||||
@property
|
||||
def device(self):
|
||||
return next(self.parameters()).device
|
||||
|
||||
def forward(
|
||||
self,
|
||||
volumes: List[Tensor], # different resolution images / CT scans
|
||||
):
|
||||
batch, device = len(volumes), self.device
|
||||
arange = partial(torch.arange, device = device)
|
||||
|
||||
assert all([volume.ndim == 4 and volume.shape[0] == self.channels for volume in volumes]), f'all volumes must have {self.channels} channels and number of dimensions of {self.channels} (channels, frame, height, width)'
|
||||
|
||||
all_patches = [self.to_patches(volume) for volume in volumes]
|
||||
|
||||
# prepare factorized positional embedding height width indices
|
||||
|
||||
positions = []
|
||||
|
||||
for patches in all_patches:
|
||||
patch_frame, patch_height, patch_width = patches.shape[:3]
|
||||
fhw_indices = torch.stack(torch.meshgrid((arange(patch_frame), arange(patch_height), arange(patch_width)), indexing = 'ij'), dim = -1)
|
||||
fhw_indices = rearrange(fhw_indices, 'f h w c -> (f h w) c')
|
||||
|
||||
positions.append(fhw_indices)
|
||||
|
||||
# need the sizes to compute token dropout + positional embedding
|
||||
|
||||
tokens = [rearrange(patches, 'f h w d -> (f h w) d') for patches in all_patches]
|
||||
|
||||
# handle token dropout
|
||||
|
||||
seq_lens = torch.tensor([i.shape[0] for i in tokens], device = device)
|
||||
|
||||
if self.training and self.token_dropout_prob > 0:
|
||||
|
||||
keep_seq_lens = ((1. - self.token_dropout_prob) * seq_lens).int().clamp(min = 1)
|
||||
|
||||
kept_tokens = []
|
||||
kept_positions = []
|
||||
|
||||
for one_image_tokens, one_image_positions, seq_len, num_keep in zip(tokens, positions, seq_lens, keep_seq_lens):
|
||||
keep_indices = torch.randn((seq_len,), device = device).topk(num_keep, dim = -1).indices
|
||||
|
||||
one_image_kept_tokens = one_image_tokens[keep_indices]
|
||||
one_image_kept_positions = one_image_positions[keep_indices]
|
||||
|
||||
kept_tokens.append(one_image_kept_tokens)
|
||||
kept_positions.append(one_image_kept_positions)
|
||||
|
||||
tokens, positions, seq_lens = kept_tokens, kept_positions, keep_seq_lens
|
||||
|
||||
# add all height and width factorized positions
|
||||
|
||||
|
||||
frame_indices, height_indices, width_indices = torch.cat(positions).unbind(dim = -1)
|
||||
frame_embed, height_embed, width_embed = self.pos_embed_frame[frame_indices], self.pos_embed_height[height_indices], self.pos_embed_width[width_indices]
|
||||
|
||||
pos_embed = frame_embed + height_embed + width_embed
|
||||
|
||||
tokens = torch.cat(tokens)
|
||||
|
||||
# linear projection to patch embeddings
|
||||
|
||||
tokens = self.to_patch_embedding(tokens)
|
||||
|
||||
# absolute positions
|
||||
|
||||
tokens = tokens + pos_embed
|
||||
|
||||
# add register tokens
|
||||
|
||||
tokens = tokens.split(seq_lens.tolist())
|
||||
|
||||
tokens = [torch.cat((self.register_tokens, one_tokens)) for one_tokens in tokens]
|
||||
|
||||
# use nested tensor for transformers and save on padding computation
|
||||
|
||||
tokens = nested_tensor(tokens, layout = torch.jagged, device = device)
|
||||
|
||||
# embedding dropout
|
||||
|
||||
tokens = self.dropout(tokens)
|
||||
|
||||
# transformer
|
||||
|
||||
tokens = self.transformer(tokens)
|
||||
|
||||
# attention pooling
|
||||
# will use a jagged tensor for queries, as SDPA requires all inputs to be jagged, or not
|
||||
|
||||
attn_pool_queries = [rearrange(self.attn_pool_queries, '... -> 1 ...')] * batch
|
||||
|
||||
attn_pool_queries = nested_tensor(attn_pool_queries, layout = torch.jagged)
|
||||
|
||||
pooled = self.attn_pool(attn_pool_queries, tokens)
|
||||
|
||||
# back to unjagged
|
||||
|
||||
logits = torch.stack(pooled.unbind())
|
||||
|
||||
logits = rearrange(logits, 'b 1 d -> b d')
|
||||
|
||||
logits = self.to_latent(logits)
|
||||
|
||||
return self.mlp_head(logits)
|
||||
|
||||
# quick test
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
# works for torch 2.5
|
||||
|
||||
v = NaViT(
|
||||
image_size = 256,
|
||||
max_frames = 8,
|
||||
patch_size = 32,
|
||||
frame_patch_size = 2,
|
||||
num_classes = 1000,
|
||||
dim = 1024,
|
||||
depth = 6,
|
||||
heads = 16,
|
||||
mlp_dim = 2048,
|
||||
dropout = 0.,
|
||||
emb_dropout = 0.,
|
||||
token_dropout_prob = 0.1
|
||||
)
|
||||
|
||||
# 5 volumetric data (videos or CT scans) of different resolutions - List[Tensor]
|
||||
|
||||
volumes = [
|
||||
torch.randn(3, 2, 256, 256), torch.randn(3, 8, 128, 128),
|
||||
torch.randn(3, 4, 128, 256), torch.randn(3, 2, 256, 128),
|
||||
torch.randn(3, 4, 64, 256)
|
||||
]
|
||||
|
||||
assert v(volumes).shape == (5, 1000)
|
||||
|
||||
v(volumes).sum().backward()
|
||||
264
vit_pytorch/normalized_vit.py
Normal file
264
vit_pytorch/normalized_vit.py
Normal file
@@ -0,0 +1,264 @@
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.nn import Module, ModuleList
|
||||
import torch.nn.functional as F
|
||||
import torch.nn.utils.parametrize as parametrize
|
||||
|
||||
from einops import rearrange, reduce
|
||||
from einops.layers.torch import Rearrange
|
||||
|
||||
# functions
|
||||
|
||||
def exists(v):
|
||||
return v is not None
|
||||
|
||||
def default(v, d):
|
||||
return v if exists(v) else d
|
||||
|
||||
def pair(t):
|
||||
return t if isinstance(t, tuple) else (t, t)
|
||||
|
||||
def divisible_by(numer, denom):
|
||||
return (numer % denom) == 0
|
||||
|
||||
def l2norm(t, dim = -1):
|
||||
return F.normalize(t, dim = dim, p = 2)
|
||||
|
||||
# for use with parametrize
|
||||
|
||||
class L2Norm(Module):
|
||||
def __init__(self, dim = -1):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
|
||||
def forward(self, t):
|
||||
return l2norm(t, dim = self.dim)
|
||||
|
||||
class NormLinear(Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim,
|
||||
dim_out,
|
||||
norm_dim_in = True
|
||||
):
|
||||
super().__init__()
|
||||
self.linear = nn.Linear(dim, dim_out, bias = False)
|
||||
|
||||
parametrize.register_parametrization(
|
||||
self.linear,
|
||||
'weight',
|
||||
L2Norm(dim = -1 if norm_dim_in else 0)
|
||||
)
|
||||
|
||||
@property
|
||||
def weight(self):
|
||||
return self.linear.weight
|
||||
|
||||
def forward(self, x):
|
||||
return self.linear(x)
|
||||
|
||||
# attention and feedforward
|
||||
|
||||
class Attention(Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim,
|
||||
*,
|
||||
dim_head = 64,
|
||||
heads = 8,
|
||||
dropout = 0.
|
||||
):
|
||||
super().__init__()
|
||||
dim_inner = dim_head * heads
|
||||
self.to_q = NormLinear(dim, dim_inner)
|
||||
self.to_k = NormLinear(dim, dim_inner)
|
||||
self.to_v = NormLinear(dim, dim_inner)
|
||||
|
||||
self.dropout = dropout
|
||||
|
||||
self.q_scale = nn.Parameter(torch.ones(heads, 1, dim_head) * (dim_head ** 0.25))
|
||||
self.k_scale = nn.Parameter(torch.ones(heads, 1, dim_head) * (dim_head ** 0.25))
|
||||
|
||||
self.split_heads = Rearrange('b n (h d) -> b h n d', h = heads)
|
||||
self.merge_heads = Rearrange('b h n d -> b n (h d)')
|
||||
|
||||
self.to_out = NormLinear(dim_inner, dim, norm_dim_in = False)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x
|
||||
):
|
||||
q, k, v = self.to_q(x), self.to_k(x), self.to_v(x)
|
||||
|
||||
q, k, v = map(self.split_heads, (q, k, v))
|
||||
|
||||
# query key rmsnorm
|
||||
|
||||
q, k = map(l2norm, (q, k))
|
||||
|
||||
q = q * self.q_scale
|
||||
k = k * self.k_scale
|
||||
|
||||
# scale is 1., as scaling factor is moved to s_qk (dk ^ 0.25) - eq. 16
|
||||
|
||||
out = F.scaled_dot_product_attention(
|
||||
q, k, v,
|
||||
dropout_p = self.dropout if self.training else 0.,
|
||||
scale = 1.
|
||||
)
|
||||
|
||||
out = self.merge_heads(out)
|
||||
return self.to_out(out)
|
||||
|
||||
class FeedForward(Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim,
|
||||
*,
|
||||
dim_inner,
|
||||
dropout = 0.
|
||||
):
|
||||
super().__init__()
|
||||
dim_inner = int(dim_inner * 2 / 3)
|
||||
|
||||
self.dim = dim
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
|
||||
self.to_hidden = NormLinear(dim, dim_inner)
|
||||
self.to_gate = NormLinear(dim, dim_inner)
|
||||
|
||||
self.hidden_scale = nn.Parameter(torch.ones(dim_inner))
|
||||
self.gate_scale = nn.Parameter(torch.ones(dim_inner))
|
||||
|
||||
self.to_out = NormLinear(dim_inner, dim, norm_dim_in = False)
|
||||
|
||||
def forward(self, x):
|
||||
hidden, gate = self.to_hidden(x), self.to_gate(x)
|
||||
|
||||
hidden = hidden * self.hidden_scale
|
||||
gate = gate * self.gate_scale * (self.dim ** 0.5)
|
||||
|
||||
hidden = F.silu(gate) * hidden
|
||||
|
||||
hidden = self.dropout(hidden)
|
||||
return self.to_out(hidden)
|
||||
|
||||
# classes
|
||||
|
||||
class nViT(Module):
|
||||
""" https://arxiv.org/abs/2410.01131 """
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
image_size,
|
||||
patch_size,
|
||||
num_classes,
|
||||
dim,
|
||||
depth,
|
||||
heads,
|
||||
mlp_dim,
|
||||
dropout = 0.,
|
||||
channels = 3,
|
||||
dim_head = 64,
|
||||
residual_lerp_scale_init = None
|
||||
):
|
||||
super().__init__()
|
||||
image_height, image_width = pair(image_size)
|
||||
|
||||
# calculate patching related stuff
|
||||
|
||||
assert divisible_by(image_height, patch_size) and divisible_by(image_width, patch_size), 'Image dimensions must be divisible by the patch size.'
|
||||
|
||||
patch_height_dim, patch_width_dim = (image_height // patch_size), (image_width // patch_size)
|
||||
patch_dim = channels * (patch_size ** 2)
|
||||
num_patches = patch_height_dim * patch_width_dim
|
||||
|
||||
self.channels = channels
|
||||
self.patch_size = patch_size
|
||||
|
||||
self.to_patch_embedding = nn.Sequential(
|
||||
Rearrange('b c (h p1) (w p2) -> b (h w) (c p1 p2)', p1 = patch_size, p2 = patch_size),
|
||||
NormLinear(patch_dim, dim, norm_dim_in = False),
|
||||
)
|
||||
|
||||
self.abs_pos_emb = NormLinear(dim, num_patches)
|
||||
|
||||
residual_lerp_scale_init = default(residual_lerp_scale_init, 1. / depth)
|
||||
|
||||
# layers
|
||||
|
||||
self.dim = dim
|
||||
self.scale = dim ** 0.5
|
||||
|
||||
self.layers = ModuleList([])
|
||||
self.residual_lerp_scales = nn.ParameterList([])
|
||||
|
||||
for _ in range(depth):
|
||||
self.layers.append(ModuleList([
|
||||
Attention(dim, dim_head = dim_head, heads = heads, dropout = dropout),
|
||||
FeedForward(dim, dim_inner = mlp_dim, dropout = dropout),
|
||||
]))
|
||||
|
||||
self.residual_lerp_scales.append(nn.ParameterList([
|
||||
nn.Parameter(torch.ones(dim) * residual_lerp_scale_init / self.scale),
|
||||
nn.Parameter(torch.ones(dim) * residual_lerp_scale_init / self.scale),
|
||||
]))
|
||||
|
||||
self.logit_scale = nn.Parameter(torch.ones(num_classes))
|
||||
|
||||
self.to_pred = NormLinear(dim, num_classes)
|
||||
|
||||
@torch.no_grad()
|
||||
def norm_weights_(self):
|
||||
for module in self.modules():
|
||||
if not isinstance(module, NormLinear):
|
||||
continue
|
||||
|
||||
normed = module.weight
|
||||
original = module.linear.parametrizations.weight.original
|
||||
|
||||
original.copy_(normed)
|
||||
|
||||
def forward(self, images):
|
||||
device = images.device
|
||||
|
||||
tokens = self.to_patch_embedding(images)
|
||||
|
||||
seq_len = tokens.shape[-2]
|
||||
pos_emb = self.abs_pos_emb.weight[torch.arange(seq_len, device = device)]
|
||||
|
||||
tokens = l2norm(tokens + pos_emb)
|
||||
|
||||
for (attn, ff), (attn_alpha, ff_alpha) in zip(self.layers, self.residual_lerp_scales):
|
||||
|
||||
attn_out = l2norm(attn(tokens))
|
||||
tokens = l2norm(tokens.lerp(attn_out, attn_alpha * self.scale))
|
||||
|
||||
ff_out = l2norm(ff(tokens))
|
||||
tokens = l2norm(tokens.lerp(ff_out, ff_alpha * self.scale))
|
||||
|
||||
pooled = reduce(tokens, 'b n d -> b d', 'mean')
|
||||
|
||||
logits = self.to_pred(pooled)
|
||||
logits = logits * self.logit_scale * self.scale
|
||||
|
||||
return logits
|
||||
|
||||
# quick test
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
v = nViT(
|
||||
image_size = 256,
|
||||
patch_size = 16,
|
||||
num_classes = 1000,
|
||||
dim = 1024,
|
||||
depth = 6,
|
||||
heads = 8,
|
||||
mlp_dim = 2048,
|
||||
)
|
||||
|
||||
img = torch.randn(4, 3, 256, 256)
|
||||
logits = v(img) # (4, 1000)
|
||||
assert logits.shape == (4, 1000)
|
||||
@@ -20,6 +20,18 @@ def divisible_by(val, d):
|
||||
|
||||
# helper classes
|
||||
|
||||
class ChanLayerNorm(nn.Module):
|
||||
def __init__(self, dim, eps = 1e-5):
|
||||
super().__init__()
|
||||
self.eps = eps
|
||||
self.g = nn.Parameter(torch.ones(1, dim, 1, 1))
|
||||
self.b = nn.Parameter(torch.zeros(1, dim, 1, 1))
|
||||
|
||||
def forward(self, x):
|
||||
var = torch.var(x, dim = 1, unbiased = False, keepdim = True)
|
||||
mean = torch.mean(x, dim = 1, keepdim = True)
|
||||
return (x - mean) / (var + self.eps).sqrt() * self.g + self.b
|
||||
|
||||
class Downsample(nn.Module):
|
||||
def __init__(self, dim_in, dim_out):
|
||||
super().__init__()
|
||||
@@ -212,10 +224,10 @@ class RegionViT(nn.Module):
|
||||
if tokenize_local_3_conv:
|
||||
self.local_encoder = nn.Sequential(
|
||||
nn.Conv2d(3, init_dim, 3, 2, 1),
|
||||
nn.LayerNorm(init_dim),
|
||||
ChanLayerNorm(init_dim),
|
||||
nn.GELU(),
|
||||
nn.Conv2d(init_dim, init_dim, 3, 2, 1),
|
||||
nn.LayerNorm(init_dim),
|
||||
ChanLayerNorm(init_dim),
|
||||
nn.GELU(),
|
||||
nn.Conv2d(init_dim, init_dim, 3, 1, 1)
|
||||
)
|
||||
|
||||
@@ -3,14 +3,14 @@ from math import sqrt, pi, log
|
||||
import torch
|
||||
from torch import nn, einsum
|
||||
import torch.nn.functional as F
|
||||
from torch.cuda.amp import autocast
|
||||
from torch.amp import autocast
|
||||
|
||||
from einops import rearrange, repeat
|
||||
from einops.layers.torch import Rearrange
|
||||
|
||||
# rotary embeddings
|
||||
|
||||
@autocast(enabled = False)
|
||||
@autocast('cuda', enabled = False)
|
||||
def rotate_every_two(x):
|
||||
x = rearrange(x, '... (d j) -> ... d j', j = 2)
|
||||
x1, x2 = x.unbind(dim = -1)
|
||||
@@ -24,7 +24,7 @@ class AxialRotaryEmbedding(nn.Module):
|
||||
scales = torch.linspace(1., max_freq / 2, self.dim // 4)
|
||||
self.register_buffer('scales', scales)
|
||||
|
||||
@autocast(enabled = False)
|
||||
@autocast('cuda', enabled = False)
|
||||
def forward(self, x):
|
||||
device, dtype, n = x.device, x.dtype, int(sqrt(x.shape[-2]))
|
||||
|
||||
|
||||
233
vit_pytorch/simple_vit_with_hyper_connections.py
Normal file
233
vit_pytorch/simple_vit_with_hyper_connections.py
Normal file
@@ -0,0 +1,233 @@
|
||||
"""
|
||||
ViT + Hyper-Connections + Register Tokens
|
||||
https://arxiv.org/abs/2409.19606
|
||||
"""
|
||||
|
||||
import torch
|
||||
from torch import nn, tensor
|
||||
from torch.nn import Module, ModuleList
|
||||
|
||||
from einops import rearrange, repeat, reduce, einsum, pack, unpack
|
||||
from einops.layers.torch import Rearrange
|
||||
|
||||
# b - batch, h - heads, n - sequence, e - expansion rate / residual streams, d - feature dimension
|
||||
|
||||
# helpers
|
||||
|
||||
def pair(t):
|
||||
return t if isinstance(t, tuple) else (t, t)
|
||||
|
||||
def posemb_sincos_2d(h, w, dim, temperature: int = 10000, dtype = torch.float32):
|
||||
y, x = torch.meshgrid(torch.arange(h), torch.arange(w), indexing="ij")
|
||||
assert (dim % 4) == 0, "feature dimension must be multiple of 4 for sincos emb"
|
||||
omega = torch.arange(dim // 4) / (dim // 4 - 1)
|
||||
omega = 1.0 / (temperature ** omega)
|
||||
|
||||
y = y.flatten()[:, None] * omega[None, :]
|
||||
x = x.flatten()[:, None] * omega[None, :]
|
||||
pe = torch.cat((x.sin(), x.cos(), y.sin(), y.cos()), dim=1)
|
||||
return pe.type(dtype)
|
||||
|
||||
# hyper connections
|
||||
|
||||
class HyperConnection(Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim,
|
||||
num_residual_streams,
|
||||
layer_index
|
||||
):
|
||||
""" Appendix J - Algorithm 2, Dynamic only """
|
||||
super().__init__()
|
||||
|
||||
self.norm = nn.LayerNorm(dim, bias = False)
|
||||
|
||||
self.num_residual_streams = num_residual_streams
|
||||
self.layer_index = layer_index
|
||||
|
||||
self.static_beta = nn.Parameter(torch.ones(num_residual_streams))
|
||||
|
||||
init_alpha0 = torch.zeros((num_residual_streams, 1))
|
||||
init_alpha0[layer_index % num_residual_streams, 0] = 1.
|
||||
|
||||
self.static_alpha = nn.Parameter(torch.cat([init_alpha0, torch.eye(num_residual_streams)], dim = 1))
|
||||
|
||||
self.dynamic_alpha_fn = nn.Parameter(torch.zeros(dim, num_residual_streams + 1))
|
||||
self.dynamic_alpha_scale = nn.Parameter(tensor(1e-2))
|
||||
self.dynamic_beta_fn = nn.Parameter(torch.zeros(dim))
|
||||
self.dynamic_beta_scale = nn.Parameter(tensor(1e-2))
|
||||
|
||||
def width_connection(self, residuals):
|
||||
normed = self.norm(residuals)
|
||||
|
||||
wc_weight = (normed @ self.dynamic_alpha_fn).tanh()
|
||||
dynamic_alpha = wc_weight * self.dynamic_alpha_scale
|
||||
alpha = dynamic_alpha + self.static_alpha
|
||||
|
||||
dc_weight = (normed @ self.dynamic_beta_fn).tanh()
|
||||
dynamic_beta = dc_weight * self.dynamic_beta_scale
|
||||
beta = dynamic_beta + self.static_beta
|
||||
|
||||
# width connection
|
||||
mix_h = einsum(alpha, residuals, '... e1 e2, ... e1 d -> ... e2 d')
|
||||
|
||||
branch_input, residuals = mix_h[..., 0, :], mix_h[..., 1:, :]
|
||||
|
||||
return branch_input, residuals, beta
|
||||
|
||||
def depth_connection(
|
||||
self,
|
||||
branch_output,
|
||||
residuals,
|
||||
beta
|
||||
):
|
||||
return einsum(branch_output, beta, "b n d, b n e -> b n e d") + residuals
|
||||
|
||||
# classes
|
||||
|
||||
class FeedForward(Module):
|
||||
def __init__(self, dim, hidden_dim):
|
||||
super().__init__()
|
||||
self.net = nn.Sequential(
|
||||
nn.LayerNorm(dim),
|
||||
nn.Linear(dim, hidden_dim),
|
||||
nn.GELU(),
|
||||
nn.Linear(hidden_dim, dim),
|
||||
)
|
||||
def forward(self, x):
|
||||
return self.net(x)
|
||||
|
||||
class Attention(Module):
|
||||
def __init__(self, dim, heads = 8, dim_head = 64):
|
||||
super().__init__()
|
||||
inner_dim = dim_head * heads
|
||||
self.heads = heads
|
||||
self.scale = dim_head ** -0.5
|
||||
self.norm = nn.LayerNorm(dim)
|
||||
|
||||
self.attend = nn.Softmax(dim = -1)
|
||||
|
||||
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
|
||||
self.to_out = nn.Linear(inner_dim, dim, bias = False)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.norm(x)
|
||||
|
||||
qkv = self.to_qkv(x).chunk(3, dim = -1)
|
||||
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)
|
||||
|
||||
dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
|
||||
|
||||
attn = self.attend(dots)
|
||||
|
||||
out = torch.matmul(attn, v)
|
||||
out = rearrange(out, 'b h n d -> b n (h d)')
|
||||
return self.to_out(out)
|
||||
|
||||
class Transformer(Module):
|
||||
def __init__(self, dim, depth, heads, dim_head, mlp_dim, num_residual_streams):
|
||||
super().__init__()
|
||||
|
||||
self.num_residual_streams = num_residual_streams
|
||||
|
||||
self.norm = nn.LayerNorm(dim)
|
||||
self.layers = ModuleList([])
|
||||
|
||||
for layer_index in range(depth):
|
||||
self.layers.append(nn.ModuleList([
|
||||
HyperConnection(dim, num_residual_streams, layer_index),
|
||||
Attention(dim, heads = heads, dim_head = dim_head),
|
||||
HyperConnection(dim, num_residual_streams, layer_index),
|
||||
FeedForward(dim, mlp_dim)
|
||||
]))
|
||||
|
||||
def forward(self, x):
|
||||
|
||||
x = repeat(x, 'b n d -> b n e d', e = self.num_residual_streams)
|
||||
|
||||
for attn_hyper_conn, attn, ff_hyper_conn, ff in self.layers:
|
||||
|
||||
x, attn_res, beta = attn_hyper_conn.width_connection(x)
|
||||
|
||||
x = attn(x)
|
||||
|
||||
x = attn_hyper_conn.depth_connection(x, attn_res, beta)
|
||||
|
||||
x, ff_res, beta = ff_hyper_conn.width_connection(x)
|
||||
|
||||
x = ff(x)
|
||||
|
||||
x = ff_hyper_conn.depth_connection(x, ff_res, beta)
|
||||
|
||||
x = reduce(x, 'b n e d -> b n d', 'sum')
|
||||
|
||||
return self.norm(x)
|
||||
|
||||
class SimpleViT(nn.Module):
|
||||
def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, num_residual_streams, num_register_tokens = 4, channels = 3, dim_head = 64):
|
||||
super().__init__()
|
||||
image_height, image_width = pair(image_size)
|
||||
patch_height, patch_width = pair(patch_size)
|
||||
|
||||
assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'
|
||||
|
||||
patch_dim = channels * patch_height * patch_width
|
||||
|
||||
self.to_patch_embedding = nn.Sequential(
|
||||
Rearrange("b c (h p1) (w p2) -> b (h w) (p1 p2 c)", p1 = patch_height, p2 = patch_width),
|
||||
nn.LayerNorm(patch_dim),
|
||||
nn.Linear(patch_dim, dim),
|
||||
nn.LayerNorm(dim),
|
||||
)
|
||||
|
||||
self.register_tokens = nn.Parameter(torch.randn(num_register_tokens, dim))
|
||||
|
||||
self.pos_embedding = posemb_sincos_2d(
|
||||
h = image_height // patch_height,
|
||||
w = image_width // patch_width,
|
||||
dim = dim,
|
||||
)
|
||||
|
||||
self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, num_residual_streams)
|
||||
|
||||
self.pool = "mean"
|
||||
self.to_latent = nn.Identity()
|
||||
|
||||
self.linear_head = nn.Linear(dim, num_classes)
|
||||
|
||||
def forward(self, img):
|
||||
batch, device = img.shape[0], img.device
|
||||
|
||||
x = self.to_patch_embedding(img)
|
||||
x += self.pos_embedding.to(x)
|
||||
|
||||
r = repeat(self.register_tokens, 'n d -> b n d', b = batch)
|
||||
|
||||
x, ps = pack([x, r], 'b * d')
|
||||
|
||||
x = self.transformer(x)
|
||||
|
||||
x, _ = unpack(x, ps, 'b * d')
|
||||
|
||||
x = x.mean(dim = 1)
|
||||
|
||||
x = self.to_latent(x)
|
||||
return self.linear_head(x)
|
||||
|
||||
# main
|
||||
|
||||
if __name__ == '__main__':
|
||||
vit = SimpleViT(
|
||||
num_classes = 1000,
|
||||
image_size = 256,
|
||||
patch_size = 8,
|
||||
dim = 1024,
|
||||
depth = 12,
|
||||
heads = 8,
|
||||
mlp_dim = 2048,
|
||||
num_residual_streams = 8
|
||||
)
|
||||
|
||||
images = torch.randn(3, 3, 256, 256)
|
||||
|
||||
logits = vit(images)
|
||||
159
vit_pytorch/simple_vit_with_value_residual.py
Normal file
159
vit_pytorch/simple_vit_with_value_residual.py
Normal file
@@ -0,0 +1,159 @@
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.nn import Module, ModuleList
|
||||
|
||||
from einops import rearrange
|
||||
from einops.layers.torch import Rearrange
|
||||
|
||||
# helpers
|
||||
|
||||
def exists(v):
|
||||
return v is not None
|
||||
|
||||
def default(v, d):
|
||||
return v if exists(v) else d
|
||||
|
||||
def pair(t):
|
||||
return t if isinstance(t, tuple) else (t, t)
|
||||
|
||||
def posemb_sincos_2d(h, w, dim, temperature: int = 10000, dtype = torch.float32):
|
||||
y, x = torch.meshgrid(torch.arange(h), torch.arange(w), indexing="ij")
|
||||
assert (dim % 4) == 0, "feature dimension must be multiple of 4 for sincos emb"
|
||||
omega = torch.arange(dim // 4) / (dim // 4 - 1)
|
||||
omega = 1.0 / (temperature ** omega)
|
||||
|
||||
y = y.flatten()[:, None] * omega[None, :]
|
||||
x = x.flatten()[:, None] * omega[None, :]
|
||||
pe = torch.cat((x.sin(), x.cos(), y.sin(), y.cos()), dim=1)
|
||||
return pe.type(dtype)
|
||||
|
||||
# classes
|
||||
|
||||
def FeedForward(dim, hidden_dim):
|
||||
return nn.Sequential(
|
||||
nn.LayerNorm(dim),
|
||||
nn.Linear(dim, hidden_dim),
|
||||
nn.GELU(),
|
||||
nn.Linear(hidden_dim, dim),
|
||||
)
|
||||
|
||||
class Attention(Module):
|
||||
def __init__(self, dim, heads = 8, dim_head = 64, learned_value_residual_mix = False):
|
||||
super().__init__()
|
||||
inner_dim = dim_head * heads
|
||||
self.heads = heads
|
||||
self.scale = dim_head ** -0.5
|
||||
self.norm = nn.LayerNorm(dim)
|
||||
|
||||
self.attend = nn.Softmax(dim = -1)
|
||||
|
||||
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
|
||||
self.to_out = nn.Linear(inner_dim, dim, bias = False)
|
||||
|
||||
self.to_residual_mix = nn.Sequential(
|
||||
nn.Linear(dim, heads),
|
||||
nn.Sigmoid(),
|
||||
Rearrange('b n h -> b h n 1')
|
||||
) if learned_value_residual_mix else (lambda _: 0.5)
|
||||
|
||||
def forward(self, x, value_residual = None):
|
||||
x = self.norm(x)
|
||||
|
||||
qkv = self.to_qkv(x).chunk(3, dim = -1)
|
||||
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)
|
||||
|
||||
if exists(value_residual):
|
||||
mix = self.to_residual_mix(x)
|
||||
v = v * mix + value_residual * (1. - mix)
|
||||
|
||||
dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
|
||||
|
||||
attn = self.attend(dots)
|
||||
|
||||
out = torch.matmul(attn, v)
|
||||
out = rearrange(out, 'b h n d -> b n (h d)')
|
||||
|
||||
return self.to_out(out), v
|
||||
|
||||
class Transformer(Module):
|
||||
def __init__(self, dim, depth, heads, dim_head, mlp_dim):
|
||||
super().__init__()
|
||||
self.norm = nn.LayerNorm(dim)
|
||||
self.layers = ModuleList([])
|
||||
for i in range(depth):
|
||||
is_first = i == 0
|
||||
self.layers.append(ModuleList([
|
||||
Attention(dim, heads = heads, dim_head = dim_head, learned_value_residual_mix = not is_first),
|
||||
FeedForward(dim, mlp_dim)
|
||||
]))
|
||||
def forward(self, x):
|
||||
value_residual = None
|
||||
|
||||
for attn, ff in self.layers:
|
||||
|
||||
attn_out, values = attn(x, value_residual = value_residual)
|
||||
value_residual = default(value_residual, values)
|
||||
|
||||
x = attn_out + x
|
||||
x = ff(x) + x
|
||||
|
||||
return self.norm(x)
|
||||
|
||||
class SimpleViT(Module):
|
||||
def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, channels = 3, dim_head = 64):
|
||||
super().__init__()
|
||||
image_height, image_width = pair(image_size)
|
||||
patch_height, patch_width = pair(patch_size)
|
||||
|
||||
assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'
|
||||
|
||||
patch_dim = channels * patch_height * patch_width
|
||||
|
||||
self.to_patch_embedding = nn.Sequential(
|
||||
Rearrange("b c (h p1) (w p2) -> b (h w) (p1 p2 c)", p1 = patch_height, p2 = patch_width),
|
||||
nn.LayerNorm(patch_dim),
|
||||
nn.Linear(patch_dim, dim),
|
||||
nn.LayerNorm(dim),
|
||||
)
|
||||
|
||||
self.pos_embedding = posemb_sincos_2d(
|
||||
h = image_height // patch_height,
|
||||
w = image_width // patch_width,
|
||||
dim = dim,
|
||||
)
|
||||
|
||||
self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim)
|
||||
|
||||
self.pool = "mean"
|
||||
self.to_latent = nn.Identity()
|
||||
|
||||
self.linear_head = nn.Linear(dim, num_classes)
|
||||
|
||||
def forward(self, img):
|
||||
device = img.device
|
||||
|
||||
x = self.to_patch_embedding(img)
|
||||
x += self.pos_embedding.to(device, dtype=x.dtype)
|
||||
|
||||
x = self.transformer(x)
|
||||
x = x.mean(dim = 1)
|
||||
|
||||
x = self.to_latent(x)
|
||||
return self.linear_head(x)
|
||||
|
||||
# quick test
|
||||
|
||||
if __name__ == '__main__':
|
||||
v = SimpleViT(
|
||||
num_classes = 1000,
|
||||
image_size = 256,
|
||||
patch_size = 8,
|
||||
dim = 1024,
|
||||
depth = 6,
|
||||
heads = 8,
|
||||
mlp_dim = 2048,
|
||||
)
|
||||
|
||||
images = torch.randn(2, 3, 256, 256)
|
||||
|
||||
logits = v(images)
|
||||
@@ -78,6 +78,30 @@ class Transformer(nn.Module):
|
||||
x = ff(x) + x
|
||||
return self.norm(x)
|
||||
|
||||
class FactorizedTransformer(nn.Module):
|
||||
def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
|
||||
super().__init__()
|
||||
self.norm = nn.LayerNorm(dim)
|
||||
self.layers = nn.ModuleList([])
|
||||
for _ in range(depth):
|
||||
self.layers.append(nn.ModuleList([
|
||||
Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout),
|
||||
Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout),
|
||||
FeedForward(dim, mlp_dim, dropout = dropout)
|
||||
]))
|
||||
|
||||
def forward(self, x):
|
||||
b, f, n, _ = x.shape
|
||||
for spatial_attn, temporal_attn, ff in self.layers:
|
||||
x = rearrange(x, 'b f n d -> (b f) n d')
|
||||
x = spatial_attn(x) + x
|
||||
x = rearrange(x, '(b f) n d -> (b n) f d', b=b, f=f)
|
||||
x = temporal_attn(x) + x
|
||||
x = ff(x) + x
|
||||
x = rearrange(x, '(b n) f d -> b f n d', b=b, n=n)
|
||||
|
||||
return self.norm(x)
|
||||
|
||||
class ViT(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
@@ -96,7 +120,8 @@ class ViT(nn.Module):
|
||||
channels = 3,
|
||||
dim_head = 64,
|
||||
dropout = 0.,
|
||||
emb_dropout = 0.
|
||||
emb_dropout = 0.,
|
||||
variant = 'factorized_encoder',
|
||||
):
|
||||
super().__init__()
|
||||
image_height, image_width = pair(image_size)
|
||||
@@ -104,6 +129,7 @@ class ViT(nn.Module):
|
||||
|
||||
assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'
|
||||
assert frames % frame_patch_size == 0, 'Frames must be divisible by frame patch size'
|
||||
assert variant in ('factorized_encoder', 'factorized_self_attention'), f'variant = {variant} is not implemented'
|
||||
|
||||
num_image_patches = (image_height // patch_height) * (image_width // patch_width)
|
||||
num_frame_patches = (frames // frame_patch_size)
|
||||
@@ -125,15 +151,20 @@ class ViT(nn.Module):
|
||||
self.dropout = nn.Dropout(emb_dropout)
|
||||
|
||||
self.spatial_cls_token = nn.Parameter(torch.randn(1, 1, dim)) if not self.global_average_pool else None
|
||||
self.temporal_cls_token = nn.Parameter(torch.randn(1, 1, dim)) if not self.global_average_pool else None
|
||||
|
||||
self.spatial_transformer = Transformer(dim, spatial_depth, heads, dim_head, mlp_dim, dropout)
|
||||
self.temporal_transformer = Transformer(dim, temporal_depth, heads, dim_head, mlp_dim, dropout)
|
||||
if variant == 'factorized_encoder':
|
||||
self.temporal_cls_token = nn.Parameter(torch.randn(1, 1, dim)) if not self.global_average_pool else None
|
||||
self.spatial_transformer = Transformer(dim, spatial_depth, heads, dim_head, mlp_dim, dropout)
|
||||
self.temporal_transformer = Transformer(dim, temporal_depth, heads, dim_head, mlp_dim, dropout)
|
||||
elif variant == 'factorized_self_attention':
|
||||
assert spatial_depth == temporal_depth, 'Spatial and temporal depth must be the same for factorized self-attention'
|
||||
self.factorized_transformer = FactorizedTransformer(dim, spatial_depth, heads, dim_head, mlp_dim, dropout)
|
||||
|
||||
self.pool = pool
|
||||
self.to_latent = nn.Identity()
|
||||
|
||||
self.mlp_head = nn.Linear(dim, num_classes)
|
||||
self.variant = variant
|
||||
|
||||
def forward(self, video):
|
||||
x = self.to_patch_embedding(video)
|
||||
@@ -147,32 +178,37 @@ class ViT(nn.Module):
|
||||
|
||||
x = self.dropout(x)
|
||||
|
||||
x = rearrange(x, 'b f n d -> (b f) n d')
|
||||
if self.variant == 'factorized_encoder':
|
||||
x = rearrange(x, 'b f n d -> (b f) n d')
|
||||
|
||||
# attend across space
|
||||
# attend across space
|
||||
|
||||
x = self.spatial_transformer(x)
|
||||
x = self.spatial_transformer(x)
|
||||
x = rearrange(x, '(b f) n d -> b f n d', b = b)
|
||||
|
||||
x = rearrange(x, '(b f) n d -> b f n d', b = b)
|
||||
# excise out the spatial cls tokens or average pool for temporal attention
|
||||
|
||||
# excise out the spatial cls tokens or average pool for temporal attention
|
||||
x = x[:, :, 0] if not self.global_average_pool else reduce(x, 'b f n d -> b f d', 'mean')
|
||||
|
||||
x = x[:, :, 0] if not self.global_average_pool else reduce(x, 'b f n d -> b f d', 'mean')
|
||||
# append temporal CLS tokens
|
||||
|
||||
# append temporal CLS tokens
|
||||
if exists(self.temporal_cls_token):
|
||||
temporal_cls_tokens = repeat(self.temporal_cls_token, '1 1 d-> b 1 d', b = b)
|
||||
|
||||
if exists(self.temporal_cls_token):
|
||||
temporal_cls_tokens = repeat(self.temporal_cls_token, '1 1 d-> b 1 d', b = b)
|
||||
x = torch.cat((temporal_cls_tokens, x), dim = 1)
|
||||
|
||||
|
||||
x = torch.cat((temporal_cls_tokens, x), dim = 1)
|
||||
# attend across time
|
||||
|
||||
# attend across time
|
||||
x = self.temporal_transformer(x)
|
||||
|
||||
x = self.temporal_transformer(x)
|
||||
# excise out temporal cls token or average pool
|
||||
|
||||
# excise out temporal cls token or average pool
|
||||
x = x[:, 0] if not self.global_average_pool else reduce(x, 'b f d -> b d', 'mean')
|
||||
|
||||
x = x[:, 0] if not self.global_average_pool else reduce(x, 'b f d -> b d', 'mean')
|
||||
elif self.variant == 'factorized_self_attention':
|
||||
x = self.factorized_transformer(x)
|
||||
x = x[:, 0, 0] if not self.global_average_pool else reduce(x, 'b f n d -> b d', 'mean')
|
||||
|
||||
x = self.to_latent(x)
|
||||
return self.mlp_head(x)
|
||||
|
||||
Reference in New Issue
Block a user