mirror of
https://github.com/lucidrains/vit-pytorch.git
synced 2025-12-30 16:12:29 +00:00
Compare commits
19 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
74b62009f8 | ||
|
|
f50d7d1436 | ||
|
|
82f2fa751d | ||
|
|
fcb9501cdd | ||
|
|
c4651a35a3 | ||
|
|
9d43e4d0bb | ||
|
|
5e808f48d1 | ||
|
|
bed48b5912 | ||
|
|
73199ab486 | ||
|
|
4f22eae631 | ||
|
|
dfc8df6713 | ||
|
|
9992a615d1 | ||
|
|
4b2c00cb63 | ||
|
|
ec6c48b8ff | ||
|
|
547bf94d07 | ||
|
|
bd72b58355 | ||
|
|
e3256d77cd | ||
|
|
90be7233a3 | ||
|
|
bca88e9039 |
1
.github/workflows/python-test.yml
vendored
1
.github/workflows/python-test.yml
vendored
@@ -28,6 +28,7 @@ jobs:
|
||||
python -m pip install --upgrade pip
|
||||
python -m pip install pytest
|
||||
python -m pip install wheel
|
||||
python -m pip install torch==2.4.0 torchvision==0.19.0 --index-url https://download.pytorch.org/whl/cpu
|
||||
if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
|
||||
- name: Test with pytest
|
||||
run: |
|
||||
|
||||
74
README.md
74
README.md
@@ -198,6 +198,38 @@ preds = v(
|
||||
) # (5, 1000)
|
||||
```
|
||||
|
||||
Finally, if you would like to make use of a flavor of NaViT using <a href="https://pytorch.org/tutorials/prototype/nestedtensor.html">nested tensors</a> (which will omit a lot of the masking and padding altogether), make sure you are on version `2.4` and import as follows
|
||||
|
||||
```python
|
||||
import torch
|
||||
from vit_pytorch.na_vit_nested_tensor import NaViT
|
||||
|
||||
v = NaViT(
|
||||
image_size = 256,
|
||||
patch_size = 32,
|
||||
num_classes = 1000,
|
||||
dim = 1024,
|
||||
depth = 6,
|
||||
heads = 16,
|
||||
mlp_dim = 2048,
|
||||
dropout = 0.,
|
||||
emb_dropout = 0.,
|
||||
token_dropout_prob = 0.1
|
||||
)
|
||||
|
||||
# 5 images of different resolutions - List[Tensor]
|
||||
|
||||
images = [
|
||||
torch.randn(3, 256, 256), torch.randn(3, 128, 128),
|
||||
torch.randn(3, 128, 256), torch.randn(3, 256, 128),
|
||||
torch.randn(3, 64, 256)
|
||||
]
|
||||
|
||||
preds = v(images)
|
||||
|
||||
assert preds.shape == (5, 1000)
|
||||
```
|
||||
|
||||
## Distillation
|
||||
|
||||
<img src="./images/distill.png" width="300px"></img>
|
||||
@@ -1186,7 +1218,8 @@ pred = cct(video)
|
||||
|
||||
<img src="./images/vivit.png" width="350px"></img>
|
||||
|
||||
This <a href="https://arxiv.org/abs/2103.15691">paper</a> offers 3 different types of architectures for efficient attention of videos, with the main theme being factorizing the attention across space and time. This repository will offer the first variant, which is a spatial transformer followed by a temporal one.
|
||||
This <a href="https://arxiv.org/abs/2103.15691">paper</a> offers 3 different types of architectures for efficient attention of videos, with the main theme being factorizing the attention across space and time. This repository includes the factorized encoder and the factorized self-attention variant.
|
||||
The factorized encoder variant is a spatial transformer followed by a temporal one. The factorized self-attention variant is a spatio-temporal transformer with alternating spatial and temporal self-attention layers.
|
||||
|
||||
```python
|
||||
import torch
|
||||
@@ -1202,7 +1235,8 @@ v = ViT(
|
||||
spatial_depth = 6, # depth of the spatial transformer
|
||||
temporal_depth = 6, # depth of the temporal transformer
|
||||
heads = 8,
|
||||
mlp_dim = 2048
|
||||
mlp_dim = 2048,
|
||||
variant = 'factorized_encoder', # or 'factorized_self_attention'
|
||||
)
|
||||
|
||||
video = torch.randn(4, 3, 16, 128, 128) # (batch, channels, frames, height, width)
|
||||
@@ -2072,4 +2106,40 @@ Coming from computer vision and new to transformers? Here are some resources tha
|
||||
}
|
||||
```
|
||||
|
||||
```bibtex
|
||||
@inproceedings{Koner2024LookupViTCV,
|
||||
title = {LookupViT: Compressing visual information to a limited number of tokens},
|
||||
author = {Rajat Koner and Gagan Jain and Prateek Jain and Volker Tresp and Sujoy Paul},
|
||||
year = {2024},
|
||||
url = {https://api.semanticscholar.org/CorpusID:271244592}
|
||||
}
|
||||
```
|
||||
|
||||
```bibtex
|
||||
@article{Bao2022AllAW,
|
||||
title = {All are Worth Words: A ViT Backbone for Diffusion Models},
|
||||
author = {Fan Bao and Shen Nie and Kaiwen Xue and Yue Cao and Chongxuan Li and Hang Su and Jun Zhu},
|
||||
journal = {2023 IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
|
||||
year = {2022},
|
||||
pages = {22669-22679},
|
||||
url = {https://api.semanticscholar.org/CorpusID:253581703}
|
||||
}
|
||||
```
|
||||
|
||||
```bibtex
|
||||
@misc{Rubin2024,
|
||||
author = {Ohad Rubin},
|
||||
url = {https://medium.com/@ohadrubin/exploring-weight-decay-in-layer-normalization-challenges-and-a-reparameterization-solution-ad4d12c24950}
|
||||
}
|
||||
```
|
||||
|
||||
```bibtex
|
||||
@inproceedings{Loshchilov2024nGPTNT,
|
||||
title = {nGPT: Normalized Transformer with Representation Learning on the Hypersphere},
|
||||
author = {Ilya Loshchilov and Cheng-Ping Hsieh and Simeng Sun and Boris Ginsburg},
|
||||
year = {2024},
|
||||
url = {https://api.semanticscholar.org/CorpusID:273026160}
|
||||
}
|
||||
```
|
||||
|
||||
*I visualise a time when we will be to robots what dogs are to humans, and I’m rooting for the machines.* — Claude Shannon
|
||||
|
||||
6
setup.py
6
setup.py
@@ -6,7 +6,7 @@ with open('README.md') as f:
|
||||
setup(
|
||||
name = 'vit-pytorch',
|
||||
packages = find_packages(exclude=['examples']),
|
||||
version = '1.6.7',
|
||||
version = '1.8.2',
|
||||
license='MIT',
|
||||
description = 'Vision Transformer (ViT) - Pytorch',
|
||||
long_description=long_description,
|
||||
@@ -29,8 +29,8 @@ setup(
|
||||
],
|
||||
tests_require=[
|
||||
'pytest',
|
||||
'torch==1.12.1',
|
||||
'torchvision==0.13.1'
|
||||
'torch==2.4.0',
|
||||
'torchvision==0.19.0'
|
||||
],
|
||||
classifiers=[
|
||||
'Development Status :: 4 - Beta',
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import nn
|
||||
from torch.nn import Module
|
||||
import torch.nn.functional as F
|
||||
|
||||
from vit_pytorch.vit import ViT
|
||||
from vit_pytorch.t2t import T2TViT
|
||||
from vit_pytorch.efficient import ViT as EfficientViT
|
||||
@@ -12,6 +14,9 @@ from einops import rearrange, repeat
|
||||
def exists(val):
|
||||
return val is not None
|
||||
|
||||
def default(val, d):
|
||||
return val if exists(val) else d
|
||||
|
||||
# classes
|
||||
|
||||
class DistillMixin:
|
||||
@@ -20,12 +25,12 @@ class DistillMixin:
|
||||
x = self.to_patch_embedding(img)
|
||||
b, n, _ = x.shape
|
||||
|
||||
cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b)
|
||||
cls_tokens = repeat(self.cls_token, '1 n d -> b n d', b = b)
|
||||
x = torch.cat((cls_tokens, x), dim = 1)
|
||||
x += self.pos_embedding[:, :(n + 1)]
|
||||
|
||||
if distilling:
|
||||
distill_tokens = repeat(distill_token, '() n d -> b n d', b = b)
|
||||
distill_tokens = repeat(distill_token, '1 n d -> b n d', b = b)
|
||||
x = torch.cat((x, distill_tokens), dim = 1)
|
||||
|
||||
x = self._attend(x)
|
||||
@@ -97,7 +102,7 @@ class DistillableEfficientViT(DistillMixin, EfficientViT):
|
||||
|
||||
# knowledge distillation wrapper
|
||||
|
||||
class DistillWrapper(nn.Module):
|
||||
class DistillWrapper(Module):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
@@ -105,7 +110,8 @@ class DistillWrapper(nn.Module):
|
||||
student,
|
||||
temperature = 1.,
|
||||
alpha = 0.5,
|
||||
hard = False
|
||||
hard = False,
|
||||
mlp_layernorm = False
|
||||
):
|
||||
super().__init__()
|
||||
assert (isinstance(student, (DistillableViT, DistillableT2TViT, DistillableEfficientViT))) , 'student must be a vision transformer'
|
||||
@@ -122,14 +128,14 @@ class DistillWrapper(nn.Module):
|
||||
self.distillation_token = nn.Parameter(torch.randn(1, 1, dim))
|
||||
|
||||
self.distill_mlp = nn.Sequential(
|
||||
nn.LayerNorm(dim),
|
||||
nn.LayerNorm(dim) if mlp_layernorm else nn.Identity(),
|
||||
nn.Linear(dim, num_classes)
|
||||
)
|
||||
|
||||
def forward(self, img, labels, temperature = None, alpha = None, **kwargs):
|
||||
b, *_ = img.shape
|
||||
alpha = alpha if exists(alpha) else self.alpha
|
||||
T = temperature if exists(temperature) else self.temperature
|
||||
|
||||
alpha = default(alpha, self.alpha)
|
||||
T = default(temperature, self.temperature)
|
||||
|
||||
with torch.no_grad():
|
||||
teacher_logits = self.teacher(img)
|
||||
|
||||
278
vit_pytorch/look_vit.py
Normal file
278
vit_pytorch/look_vit.py
Normal file
@@ -0,0 +1,278 @@
|
||||
import torch
|
||||
from torch import nn
|
||||
import torch.nn.functional as F
|
||||
from torch.nn import Module, ModuleList
|
||||
|
||||
from einops import einsum, rearrange, repeat, reduce
|
||||
from einops.layers.torch import Rearrange
|
||||
|
||||
# helpers
|
||||
|
||||
def exists(val):
|
||||
return val is not None
|
||||
|
||||
def default(val, d):
|
||||
return val if exists(val) else d
|
||||
|
||||
def divisible_by(num, den):
|
||||
return (num % den) == 0
|
||||
|
||||
# simple vit sinusoidal pos emb
|
||||
|
||||
def posemb_sincos_2d(t, temperature = 10000):
|
||||
h, w, d, device = *t.shape[1:], t.device
|
||||
y, x = torch.meshgrid(torch.arange(h, device = device), torch.arange(w, device = device), indexing = 'ij')
|
||||
assert (d % 4) == 0, "feature dimension must be multiple of 4 for sincos emb"
|
||||
omega = torch.arange(d // 4, device = device) / (d // 4 - 1)
|
||||
omega = temperature ** -omega
|
||||
|
||||
y = y.flatten()[:, None] * omega[None, :]
|
||||
x = x.flatten()[:, None] * omega[None, :]
|
||||
pos = torch.cat((x.sin(), x.cos(), y.sin(), y.cos()), dim = 1)
|
||||
|
||||
return pos.float()
|
||||
|
||||
# bias-less layernorm with unit offset trick (discovered by Ohad Rubin)
|
||||
|
||||
class LayerNorm(Module):
|
||||
def __init__(self, dim):
|
||||
super().__init__()
|
||||
self.ln = nn.LayerNorm(dim, elementwise_affine = False)
|
||||
self.gamma = nn.Parameter(torch.zeros(dim))
|
||||
|
||||
def forward(self, x):
|
||||
normed = self.ln(x)
|
||||
return normed * (self.gamma + 1)
|
||||
|
||||
# mlp
|
||||
|
||||
def MLP(dim, factor = 4, dropout = 0.):
|
||||
hidden_dim = int(dim * factor)
|
||||
return nn.Sequential(
|
||||
LayerNorm(dim),
|
||||
nn.Linear(dim, hidden_dim),
|
||||
nn.GELU(),
|
||||
nn.Dropout(dropout),
|
||||
nn.Linear(hidden_dim, dim),
|
||||
nn.Dropout(dropout)
|
||||
)
|
||||
|
||||
# attention
|
||||
|
||||
class Attention(Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim,
|
||||
heads = 8,
|
||||
dim_head = 64,
|
||||
dropout = 0.,
|
||||
cross_attend = False,
|
||||
reuse_attention = False
|
||||
):
|
||||
super().__init__()
|
||||
inner_dim = dim_head * heads
|
||||
|
||||
self.scale = dim_head ** -0.5
|
||||
self.heads = heads
|
||||
self.reuse_attention = reuse_attention
|
||||
self.cross_attend = cross_attend
|
||||
|
||||
self.split_heads = Rearrange('b n (h d) -> b h n d', h = heads)
|
||||
|
||||
self.norm = LayerNorm(dim) if not reuse_attention else nn.Identity()
|
||||
self.norm_context = LayerNorm(dim) if cross_attend else nn.Identity()
|
||||
|
||||
self.attend = nn.Softmax(dim = -1)
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
|
||||
self.to_q = nn.Linear(dim, inner_dim, bias = False) if not reuse_attention else None
|
||||
self.to_k = nn.Linear(dim, inner_dim, bias = False) if not reuse_attention else None
|
||||
self.to_v = nn.Linear(dim, inner_dim, bias = False)
|
||||
|
||||
self.to_out = nn.Sequential(
|
||||
Rearrange('b h n d -> b n (h d)'),
|
||||
nn.Linear(inner_dim, dim, bias = False),
|
||||
nn.Dropout(dropout)
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x,
|
||||
context = None,
|
||||
return_qk_sim = False,
|
||||
qk_sim = None
|
||||
):
|
||||
x = self.norm(x)
|
||||
|
||||
assert not (exists(context) ^ self.cross_attend)
|
||||
|
||||
if self.cross_attend:
|
||||
context = self.norm_context(context)
|
||||
else:
|
||||
context = x
|
||||
|
||||
v = self.to_v(context)
|
||||
v = self.split_heads(v)
|
||||
|
||||
if not self.reuse_attention:
|
||||
qk = (self.to_q(x), self.to_k(context))
|
||||
q, k = tuple(self.split_heads(t) for t in qk)
|
||||
|
||||
q = q * self.scale
|
||||
qk_sim = einsum(q, k, 'b h i d, b h j d -> b h i j')
|
||||
|
||||
else:
|
||||
assert exists(qk_sim), 'qk sim matrix must be passed in for reusing previous attention'
|
||||
|
||||
attn = self.attend(qk_sim)
|
||||
attn = self.dropout(attn)
|
||||
|
||||
out = einsum(attn, v, 'b h i j, b h j d -> b h i d')
|
||||
out = self.to_out(out)
|
||||
|
||||
if not return_qk_sim:
|
||||
return out
|
||||
|
||||
return out, qk_sim
|
||||
|
||||
# LookViT
|
||||
|
||||
class LookViT(Module):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
dim,
|
||||
image_size,
|
||||
num_classes,
|
||||
depth = 3,
|
||||
patch_size = 16,
|
||||
heads = 8,
|
||||
mlp_factor = 4,
|
||||
dim_head = 64,
|
||||
highres_patch_size = 12,
|
||||
highres_mlp_factor = 4,
|
||||
cross_attn_heads = 8,
|
||||
cross_attn_dim_head = 64,
|
||||
patch_conv_kernel_size = 7,
|
||||
dropout = 0.1,
|
||||
channels = 3
|
||||
):
|
||||
super().__init__()
|
||||
assert divisible_by(image_size, highres_patch_size)
|
||||
assert divisible_by(image_size, patch_size)
|
||||
assert patch_size > highres_patch_size, 'patch size of the main vision transformer should be smaller than the highres patch sizes (that does the `lookup`)'
|
||||
assert not divisible_by(patch_conv_kernel_size, 2)
|
||||
|
||||
self.dim = dim
|
||||
self.image_size = image_size
|
||||
self.patch_size = patch_size
|
||||
|
||||
kernel_size = patch_conv_kernel_size
|
||||
patch_dim = (highres_patch_size * highres_patch_size) * channels
|
||||
|
||||
self.to_patches = nn.Sequential(
|
||||
Rearrange('b c (h p1) (w p2) -> b (p1 p2 c) h w', p1 = highres_patch_size, p2 = highres_patch_size),
|
||||
nn.Conv2d(patch_dim, dim, kernel_size, padding = kernel_size // 2),
|
||||
Rearrange('b c h w -> b h w c'),
|
||||
LayerNorm(dim),
|
||||
)
|
||||
|
||||
# absolute positions
|
||||
|
||||
num_patches = (image_size // highres_patch_size) ** 2
|
||||
self.pos_embedding = nn.Parameter(torch.randn(num_patches, dim))
|
||||
|
||||
# lookvit blocks
|
||||
|
||||
layers = ModuleList([])
|
||||
|
||||
for _ in range(depth):
|
||||
layers.append(ModuleList([
|
||||
Attention(dim = dim, dim_head = dim_head, heads = heads, dropout = dropout),
|
||||
MLP(dim = dim, factor = mlp_factor, dropout = dropout),
|
||||
Attention(dim = dim, dim_head = cross_attn_dim_head, heads = cross_attn_heads, dropout = dropout, cross_attend = True),
|
||||
Attention(dim = dim, dim_head = cross_attn_dim_head, heads = cross_attn_heads, dropout = dropout, cross_attend = True, reuse_attention = True),
|
||||
LayerNorm(dim),
|
||||
MLP(dim = dim, factor = highres_mlp_factor, dropout = dropout)
|
||||
]))
|
||||
|
||||
self.layers = layers
|
||||
|
||||
self.norm = LayerNorm(dim)
|
||||
self.highres_norm = LayerNorm(dim)
|
||||
|
||||
self.to_logits = nn.Linear(dim, num_classes, bias = False)
|
||||
|
||||
def forward(self, img):
|
||||
assert img.shape[-2:] == (self.image_size, self.image_size)
|
||||
|
||||
# to patch tokens and positions
|
||||
|
||||
highres_tokens = self.to_patches(img)
|
||||
size = highres_tokens.shape[-2]
|
||||
|
||||
pos_emb = posemb_sincos_2d(highres_tokens)
|
||||
highres_tokens = highres_tokens + rearrange(pos_emb, '(h w) d -> h w d', h = size)
|
||||
|
||||
tokens = F.interpolate(
|
||||
rearrange(highres_tokens, 'b h w d -> b d h w'),
|
||||
img.shape[-1] // self.patch_size,
|
||||
mode = 'bilinear'
|
||||
)
|
||||
|
||||
tokens = rearrange(tokens, 'b c h w -> b (h w) c')
|
||||
highres_tokens = rearrange(highres_tokens, 'b h w c -> b (h w) c')
|
||||
|
||||
# attention and feedforwards
|
||||
|
||||
for attn, mlp, lookup_cross_attn, highres_attn, highres_norm, highres_mlp in self.layers:
|
||||
|
||||
# main tokens cross attends (lookup) on the high res tokens
|
||||
|
||||
lookup_out, qk_sim = lookup_cross_attn(tokens, highres_tokens, return_qk_sim = True) # return attention as they reuse the attention matrix
|
||||
tokens = lookup_out + tokens
|
||||
|
||||
tokens = attn(tokens) + tokens
|
||||
tokens = mlp(tokens) + tokens
|
||||
|
||||
# attention-reuse
|
||||
|
||||
qk_sim = rearrange(qk_sim, 'b h i j -> b h j i') # transpose for reverse cross attention
|
||||
|
||||
highres_tokens = highres_attn(highres_tokens, tokens, qk_sim = qk_sim) + highres_tokens
|
||||
highres_tokens = highres_norm(highres_tokens)
|
||||
|
||||
highres_tokens = highres_mlp(highres_tokens) + highres_tokens
|
||||
|
||||
# to logits
|
||||
|
||||
tokens = self.norm(tokens)
|
||||
highres_tokens = self.highres_norm(highres_tokens)
|
||||
|
||||
tokens = reduce(tokens, 'b n d -> b d', 'mean')
|
||||
highres_tokens = reduce(highres_tokens, 'b n d -> b d', 'mean')
|
||||
|
||||
return self.to_logits(tokens + highres_tokens)
|
||||
|
||||
# main
|
||||
|
||||
if __name__ == '__main__':
|
||||
v = LookViT(
|
||||
image_size = 256,
|
||||
num_classes = 1000,
|
||||
dim = 512,
|
||||
depth = 2,
|
||||
heads = 8,
|
||||
dim_head = 64,
|
||||
patch_size = 32,
|
||||
highres_patch_size = 8,
|
||||
highres_mlp_factor = 2,
|
||||
cross_attn_heads = 8,
|
||||
cross_attn_dim_head = 64,
|
||||
dropout = 0.1
|
||||
).cuda()
|
||||
|
||||
img = torch.randn(2, 3, 256, 256).cuda()
|
||||
pred = v(img)
|
||||
|
||||
assert pred.shape == (2, 1000)
|
||||
@@ -1,5 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from functools import partial
|
||||
from typing import List, Union
|
||||
from typing import List
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
@@ -245,7 +247,7 @@ class NaViT(nn.Module):
|
||||
|
||||
def forward(
|
||||
self,
|
||||
batched_images: Union[List[Tensor], List[List[Tensor]]], # assume different resolution images already grouped correctly
|
||||
batched_images: List[Tensor] | List[List[Tensor]], # assume different resolution images already grouped correctly
|
||||
group_images = False,
|
||||
group_max_seq_len = 2048
|
||||
):
|
||||
@@ -264,6 +266,11 @@ class NaViT(nn.Module):
|
||||
max_seq_len = group_max_seq_len
|
||||
)
|
||||
|
||||
# if List[Tensor] is not grouped -> List[List[Tensor]]
|
||||
|
||||
if torch.is_tensor(batched_images[0]):
|
||||
batched_images = [batched_images]
|
||||
|
||||
# process images into variable lengthed sequences with attention mask
|
||||
|
||||
num_images = []
|
||||
|
||||
325
vit_pytorch/na_vit_nested_tensor.py
Normal file
325
vit_pytorch/na_vit_nested_tensor.py
Normal file
@@ -0,0 +1,325 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import List
|
||||
from functools import partial
|
||||
|
||||
import torch
|
||||
import packaging.version as pkg_version
|
||||
|
||||
if pkg_version.parse(torch.__version__) < pkg_version.parse('2.4'):
|
||||
print('nested tensor NaViT was tested on pytorch 2.4')
|
||||
|
||||
from torch import nn, Tensor
|
||||
import torch.nn.functional as F
|
||||
from torch.nn import Module, ModuleList
|
||||
from torch.nested import nested_tensor
|
||||
|
||||
from einops import rearrange
|
||||
from einops.layers.torch import Rearrange
|
||||
|
||||
# helpers
|
||||
|
||||
def exists(val):
|
||||
return val is not None
|
||||
|
||||
def default(val, d):
|
||||
return val if exists(val) else d
|
||||
|
||||
def pair(t):
|
||||
return t if isinstance(t, tuple) else (t, t)
|
||||
|
||||
def divisible_by(numer, denom):
|
||||
return (numer % denom) == 0
|
||||
|
||||
# feedforward
|
||||
|
||||
def FeedForward(dim, hidden_dim, dropout = 0.):
|
||||
return nn.Sequential(
|
||||
nn.LayerNorm(dim, bias = False),
|
||||
nn.Linear(dim, hidden_dim),
|
||||
nn.GELU(),
|
||||
nn.Dropout(dropout),
|
||||
nn.Linear(hidden_dim, dim),
|
||||
nn.Dropout(dropout)
|
||||
)
|
||||
|
||||
class Attention(Module):
|
||||
def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
|
||||
super().__init__()
|
||||
self.norm = nn.LayerNorm(dim, bias = False)
|
||||
|
||||
dim_inner = heads * dim_head
|
||||
self.heads = heads
|
||||
self.dim_head = dim_head
|
||||
|
||||
self.to_queries = nn.Linear(dim, dim_inner, bias = False)
|
||||
self.to_keys = nn.Linear(dim, dim_inner, bias = False)
|
||||
self.to_values = nn.Linear(dim, dim_inner, bias = False)
|
||||
|
||||
# in the paper, they employ qk rmsnorm, a way to stabilize attention
|
||||
# will use layernorm in place of rmsnorm, which has been shown to work in certain papers. requires l2norm on non-ragged dimension to be supported in nested tensors
|
||||
|
||||
self.query_norm = nn.LayerNorm(dim_head, bias = False)
|
||||
self.key_norm = nn.LayerNorm(dim_head, bias = False)
|
||||
|
||||
self.dropout = dropout
|
||||
|
||||
self.to_out = nn.Linear(dim_inner, dim, bias = False)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x,
|
||||
context: Tensor | None = None
|
||||
):
|
||||
x = self.norm(x)
|
||||
|
||||
# for attention pooling, one query pooling to entire sequence
|
||||
|
||||
context = default(context, x)
|
||||
|
||||
# queries, keys, values
|
||||
|
||||
query = self.to_queries(x)
|
||||
key = self.to_keys(context)
|
||||
value = self.to_values(context)
|
||||
|
||||
# split heads
|
||||
|
||||
def split_heads(t):
|
||||
return t.unflatten(-1, (self.heads, self.dim_head))
|
||||
|
||||
def transpose_head_seq(t):
|
||||
return t.transpose(1, 2)
|
||||
|
||||
query, key, value = map(split_heads, (query, key, value))
|
||||
|
||||
# qk norm for attention stability
|
||||
|
||||
query = self.query_norm(query)
|
||||
key = self.key_norm(key)
|
||||
|
||||
query, key, value = map(transpose_head_seq, (query, key, value))
|
||||
|
||||
# attention
|
||||
|
||||
out = F.scaled_dot_product_attention(
|
||||
query, key, value,
|
||||
dropout_p = self.dropout if self.training else 0.
|
||||
)
|
||||
|
||||
# merge heads
|
||||
|
||||
out = out.transpose(1, 2).flatten(-2)
|
||||
|
||||
return self.to_out(out)
|
||||
|
||||
class Transformer(Module):
|
||||
def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
|
||||
super().__init__()
|
||||
self.layers = ModuleList([])
|
||||
|
||||
for _ in range(depth):
|
||||
self.layers.append(ModuleList([
|
||||
Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout),
|
||||
FeedForward(dim, mlp_dim, dropout = dropout)
|
||||
]))
|
||||
|
||||
self.norm = nn.LayerNorm(dim, bias = False)
|
||||
|
||||
def forward(self, x):
|
||||
|
||||
for attn, ff in self.layers:
|
||||
x = attn(x) + x
|
||||
x = ff(x) + x
|
||||
|
||||
return self.norm(x)
|
||||
|
||||
class NaViT(Module):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
image_size,
|
||||
patch_size,
|
||||
num_classes,
|
||||
dim,
|
||||
depth,
|
||||
heads,
|
||||
mlp_dim,
|
||||
channels = 3,
|
||||
dim_head = 64,
|
||||
dropout = 0.,
|
||||
emb_dropout = 0.,
|
||||
token_dropout_prob: float | None = None
|
||||
):
|
||||
super().__init__()
|
||||
image_height, image_width = pair(image_size)
|
||||
|
||||
# what percent of tokens to dropout
|
||||
# if int or float given, then assume constant dropout prob
|
||||
# otherwise accept a callback that in turn calculates dropout prob from height and width
|
||||
|
||||
self.token_dropout_prob = token_dropout_prob
|
||||
|
||||
# calculate patching related stuff
|
||||
|
||||
assert divisible_by(image_height, patch_size) and divisible_by(image_width, patch_size), 'Image dimensions must be divisible by the patch size.'
|
||||
|
||||
patch_height_dim, patch_width_dim = (image_height // patch_size), (image_width // patch_size)
|
||||
patch_dim = channels * (patch_size ** 2)
|
||||
|
||||
self.channels = channels
|
||||
self.patch_size = patch_size
|
||||
self.to_patches = Rearrange('c (h p1) (w p2) -> h w (c p1 p2)', p1 = patch_size, p2 = patch_size)
|
||||
|
||||
self.to_patch_embedding = nn.Sequential(
|
||||
nn.LayerNorm(patch_dim),
|
||||
nn.Linear(patch_dim, dim),
|
||||
nn.LayerNorm(dim),
|
||||
)
|
||||
|
||||
self.pos_embed_height = nn.Parameter(torch.randn(patch_height_dim, dim))
|
||||
self.pos_embed_width = nn.Parameter(torch.randn(patch_width_dim, dim))
|
||||
|
||||
self.dropout = nn.Dropout(emb_dropout)
|
||||
|
||||
self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)
|
||||
|
||||
# final attention pooling queries
|
||||
|
||||
self.attn_pool_queries = nn.Parameter(torch.randn(dim))
|
||||
self.attn_pool = Attention(dim = dim, dim_head = dim_head, heads = heads)
|
||||
|
||||
# output to logits
|
||||
|
||||
self.to_latent = nn.Identity()
|
||||
|
||||
self.mlp_head = nn.Sequential(
|
||||
nn.LayerNorm(dim, bias = False),
|
||||
nn.Linear(dim, num_classes, bias = False)
|
||||
)
|
||||
|
||||
@property
|
||||
def device(self):
|
||||
return next(self.parameters()).device
|
||||
|
||||
def forward(
|
||||
self,
|
||||
images: List[Tensor], # different resolution images
|
||||
):
|
||||
batch, device = len(images), self.device
|
||||
arange = partial(torch.arange, device = device)
|
||||
|
||||
assert all([image.ndim == 3 and image.shape[0] == self.channels for image in images]), f'all images must have {self.channels} channels and number of dimensions of 3 (channels, height, width)'
|
||||
|
||||
all_patches = [self.to_patches(image) for image in images]
|
||||
|
||||
# prepare factorized positional embedding height width indices
|
||||
|
||||
positions = []
|
||||
|
||||
for patches in all_patches:
|
||||
patch_height, patch_width = patches.shape[:2]
|
||||
hw_indices = torch.stack(torch.meshgrid((arange(patch_height), arange(patch_width)), indexing = 'ij'), dim = -1)
|
||||
hw_indices = rearrange(hw_indices, 'h w c -> (h w) c')
|
||||
positions.append(hw_indices)
|
||||
|
||||
# need the sizes to compute token dropout + positional embedding
|
||||
|
||||
tokens = [rearrange(patches, 'h w d -> (h w) d') for patches in all_patches]
|
||||
|
||||
# handle token dropout
|
||||
|
||||
seq_lens = torch.tensor([i.shape[0] for i in tokens], device = device)
|
||||
|
||||
if self.training and self.token_dropout_prob > 0:
|
||||
|
||||
keep_seq_lens = ((1. - self.token_dropout_prob) * seq_lens).int().clamp(min = 1)
|
||||
|
||||
kept_tokens = []
|
||||
kept_positions = []
|
||||
|
||||
for one_image_tokens, one_image_positions, seq_len, num_keep in zip(tokens, positions, seq_lens, keep_seq_lens):
|
||||
keep_indices = torch.randn((seq_len,), device = device).topk(num_keep, dim = -1).indices
|
||||
|
||||
one_image_kept_tokens = one_image_tokens[keep_indices]
|
||||
one_image_kept_positions = one_image_positions[keep_indices]
|
||||
|
||||
kept_tokens.append(one_image_kept_tokens)
|
||||
kept_positions.append(one_image_kept_positions)
|
||||
|
||||
tokens, positions, seq_lens = kept_tokens, kept_positions, keep_seq_lens
|
||||
|
||||
# add all height and width factorized positions
|
||||
|
||||
height_indices, width_indices = torch.cat(positions).unbind(dim = -1)
|
||||
height_embed, width_embed = self.pos_embed_height[height_indices], self.pos_embed_width[width_indices]
|
||||
|
||||
pos_embed = height_embed + width_embed
|
||||
|
||||
# use nested tensor for transformers and save on padding computation
|
||||
|
||||
tokens = torch.cat(tokens)
|
||||
|
||||
# linear projection to patch embeddings
|
||||
|
||||
tokens = self.to_patch_embedding(tokens)
|
||||
|
||||
# absolute positions
|
||||
|
||||
tokens = tokens + pos_embed
|
||||
|
||||
tokens = nested_tensor(tokens.split(seq_lens.tolist()), layout = torch.jagged, device = device)
|
||||
|
||||
# embedding dropout
|
||||
|
||||
tokens = self.dropout(tokens)
|
||||
|
||||
# transformer
|
||||
|
||||
tokens = self.transformer(tokens)
|
||||
|
||||
# attention pooling
|
||||
# will use a jagged tensor for queries, as SDPA requires all inputs to be jagged, or not
|
||||
|
||||
attn_pool_queries = [rearrange(self.attn_pool_queries, '... -> 1 ...')] * batch
|
||||
|
||||
attn_pool_queries = nested_tensor(attn_pool_queries, layout = torch.jagged)
|
||||
|
||||
pooled = self.attn_pool(attn_pool_queries, tokens)
|
||||
|
||||
# back to unjagged
|
||||
|
||||
logits = torch.stack(pooled.unbind())
|
||||
|
||||
logits = rearrange(logits, 'b 1 d -> b d')
|
||||
|
||||
logits = self.to_latent(logits)
|
||||
|
||||
return self.mlp_head(logits)
|
||||
|
||||
# quick test
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
v = NaViT(
|
||||
image_size = 256,
|
||||
patch_size = 32,
|
||||
num_classes = 1000,
|
||||
dim = 1024,
|
||||
depth = 6,
|
||||
heads = 16,
|
||||
mlp_dim = 2048,
|
||||
dropout = 0.,
|
||||
emb_dropout = 0.,
|
||||
token_dropout_prob = 0.1
|
||||
)
|
||||
|
||||
# 5 images of different resolutions - List[Tensor]
|
||||
|
||||
images = [
|
||||
torch.randn(3, 256, 256), torch.randn(3, 128, 128),
|
||||
torch.randn(3, 128, 256), torch.randn(3, 256, 128),
|
||||
torch.randn(3, 64, 256)
|
||||
]
|
||||
|
||||
assert v(images).shape == (5, 1000)
|
||||
364
vit_pytorch/na_vit_nested_tensor_3d.py
Normal file
364
vit_pytorch/na_vit_nested_tensor_3d.py
Normal file
@@ -0,0 +1,364 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import List
|
||||
from functools import partial
|
||||
|
||||
import torch
|
||||
import packaging.version as pkg_version
|
||||
|
||||
if pkg_version.parse(torch.__version__) < pkg_version.parse('2.4'):
|
||||
print('nested tensor NaViT was tested on pytorch 2.4')
|
||||
|
||||
from torch import nn, Tensor
|
||||
import torch.nn.functional as F
|
||||
from torch.nn import Module, ModuleList
|
||||
from torch.nested import nested_tensor
|
||||
|
||||
from einops import rearrange
|
||||
from einops.layers.torch import Rearrange
|
||||
|
||||
# helpers
|
||||
|
||||
def exists(val):
|
||||
return val is not None
|
||||
|
||||
def default(val, d):
|
||||
return val if exists(val) else d
|
||||
|
||||
def pair(t):
|
||||
return t if isinstance(t, tuple) else (t, t)
|
||||
|
||||
def divisible_by(numer, denom):
|
||||
return (numer % denom) == 0
|
||||
|
||||
# feedforward
|
||||
|
||||
def FeedForward(dim, hidden_dim, dropout = 0.):
|
||||
return nn.Sequential(
|
||||
nn.LayerNorm(dim, bias = False),
|
||||
nn.Linear(dim, hidden_dim),
|
||||
nn.GELU(),
|
||||
nn.Dropout(dropout),
|
||||
nn.Linear(hidden_dim, dim),
|
||||
nn.Dropout(dropout)
|
||||
)
|
||||
|
||||
class Attention(Module):
|
||||
def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
|
||||
super().__init__()
|
||||
self.norm = nn.LayerNorm(dim, bias = False)
|
||||
|
||||
dim_inner = heads * dim_head
|
||||
self.heads = heads
|
||||
self.dim_head = dim_head
|
||||
|
||||
self.to_queries = nn.Linear(dim, dim_inner, bias = False)
|
||||
self.to_keys = nn.Linear(dim, dim_inner, bias = False)
|
||||
self.to_values = nn.Linear(dim, dim_inner, bias = False)
|
||||
|
||||
# in the paper, they employ qk rmsnorm, a way to stabilize attention
|
||||
# will use layernorm in place of rmsnorm, which has been shown to work in certain papers. requires l2norm on non-ragged dimension to be supported in nested tensors
|
||||
|
||||
self.query_norm = nn.LayerNorm(dim_head, bias = False)
|
||||
self.key_norm = nn.LayerNorm(dim_head, bias = False)
|
||||
|
||||
self.dropout = dropout
|
||||
|
||||
self.to_out = nn.Linear(dim_inner, dim, bias = False)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x,
|
||||
context: Tensor | None = None
|
||||
):
|
||||
|
||||
x = self.norm(x)
|
||||
|
||||
# for attention pooling, one query pooling to entire sequence
|
||||
|
||||
context = default(context, x)
|
||||
|
||||
# queries, keys, values
|
||||
|
||||
query = self.to_queries(x)
|
||||
key = self.to_keys(context)
|
||||
value = self.to_values(context)
|
||||
|
||||
# split heads
|
||||
|
||||
def split_heads(t):
|
||||
return t.unflatten(-1, (self.heads, self.dim_head)).transpose(1, 2).contiguous()
|
||||
|
||||
# queries, keys, values
|
||||
|
||||
query = self.to_queries(x)
|
||||
key = self.to_keys(context)
|
||||
value = self.to_values(context)
|
||||
|
||||
# split heads
|
||||
|
||||
def split_heads(t):
|
||||
return t.unflatten(-1, (self.heads, self.dim_head))
|
||||
|
||||
def transpose_head_seq(t):
|
||||
return t.transpose(1, 2)
|
||||
|
||||
query, key, value = map(split_heads, (query, key, value))
|
||||
|
||||
# qk norm for attention stability
|
||||
|
||||
query = self.query_norm(query)
|
||||
key = self.key_norm(key)
|
||||
|
||||
query, key, value = map(transpose_head_seq, (query, key, value))
|
||||
|
||||
# attention
|
||||
|
||||
out = F.scaled_dot_product_attention(
|
||||
query, key, value,
|
||||
dropout_p = self.dropout if self.training else 0.
|
||||
)
|
||||
|
||||
# merge heads
|
||||
|
||||
out = out.transpose(1, 2).flatten(-2)
|
||||
|
||||
return self.to_out(out)
|
||||
|
||||
class Transformer(Module):
|
||||
def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
|
||||
super().__init__()
|
||||
self.layers = ModuleList([])
|
||||
|
||||
for _ in range(depth):
|
||||
self.layers.append(ModuleList([
|
||||
Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout),
|
||||
FeedForward(dim, mlp_dim, dropout = dropout)
|
||||
]))
|
||||
|
||||
self.norm = nn.LayerNorm(dim, bias = False)
|
||||
|
||||
def forward(self, x):
|
||||
|
||||
for attn, ff in self.layers:
|
||||
x = attn(x) + x
|
||||
x = ff(x) + x
|
||||
|
||||
return self.norm(x)
|
||||
|
||||
class NaViT(Module):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
image_size,
|
||||
max_frames,
|
||||
patch_size,
|
||||
frame_patch_size,
|
||||
num_classes,
|
||||
dim,
|
||||
depth,
|
||||
heads,
|
||||
mlp_dim,
|
||||
channels = 3,
|
||||
dim_head = 64,
|
||||
dropout = 0.,
|
||||
emb_dropout = 0.,
|
||||
num_registers = 4,
|
||||
token_dropout_prob: float | None = None
|
||||
):
|
||||
super().__init__()
|
||||
image_height, image_width = pair(image_size)
|
||||
|
||||
# what percent of tokens to dropout
|
||||
# if int or float given, then assume constant dropout prob
|
||||
# otherwise accept a callback that in turn calculates dropout prob from height and width
|
||||
|
||||
self.token_dropout_prob = token_dropout_prob
|
||||
|
||||
# calculate patching related stuff
|
||||
|
||||
assert divisible_by(image_height, patch_size) and divisible_by(image_width, patch_size), 'Image dimensions must be divisible by the patch size.'
|
||||
assert divisible_by(max_frames, frame_patch_size)
|
||||
|
||||
patch_frame_dim, patch_height_dim, patch_width_dim = (max_frames // frame_patch_size), (image_height // patch_size), (image_width // patch_size)
|
||||
|
||||
patch_dim = channels * (patch_size ** 2) * frame_patch_size
|
||||
|
||||
self.channels = channels
|
||||
self.patch_size = patch_size
|
||||
self.to_patches = Rearrange('c (f pf) (h p1) (w p2) -> f h w (c p1 p2 pf)', p1 = patch_size, p2 = patch_size, pf = frame_patch_size)
|
||||
|
||||
self.to_patch_embedding = nn.Sequential(
|
||||
nn.LayerNorm(patch_dim),
|
||||
nn.Linear(patch_dim, dim),
|
||||
nn.LayerNorm(dim),
|
||||
)
|
||||
|
||||
self.pos_embed_frame = nn.Parameter(torch.zeros(patch_frame_dim, dim))
|
||||
self.pos_embed_height = nn.Parameter(torch.zeros(patch_height_dim, dim))
|
||||
self.pos_embed_width = nn.Parameter(torch.zeros(patch_width_dim, dim))
|
||||
|
||||
# register tokens
|
||||
|
||||
self.register_tokens = nn.Parameter(torch.zeros(num_registers, dim))
|
||||
|
||||
nn.init.normal_(self.pos_embed_frame, std = 0.02)
|
||||
nn.init.normal_(self.pos_embed_height, std = 0.02)
|
||||
nn.init.normal_(self.pos_embed_width, std = 0.02)
|
||||
nn.init.normal_(self.register_tokens, std = 0.02)
|
||||
|
||||
self.dropout = nn.Dropout(emb_dropout)
|
||||
|
||||
self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)
|
||||
|
||||
# final attention pooling queries
|
||||
|
||||
self.attn_pool_queries = nn.Parameter(torch.randn(dim))
|
||||
self.attn_pool = Attention(dim = dim, dim_head = dim_head, heads = heads)
|
||||
|
||||
# output to logits
|
||||
|
||||
self.to_latent = nn.Identity()
|
||||
|
||||
self.mlp_head = nn.Sequential(
|
||||
nn.LayerNorm(dim, bias = False),
|
||||
nn.Linear(dim, num_classes, bias = False)
|
||||
)
|
||||
|
||||
@property
|
||||
def device(self):
|
||||
return next(self.parameters()).device
|
||||
|
||||
def forward(
|
||||
self,
|
||||
volumes: List[Tensor], # different resolution images / CT scans
|
||||
):
|
||||
batch, device = len(volumes), self.device
|
||||
arange = partial(torch.arange, device = device)
|
||||
|
||||
assert all([volume.ndim == 4 and volume.shape[0] == self.channels for volume in volumes]), f'all volumes must have {self.channels} channels and number of dimensions of {self.channels} (channels, frame, height, width)'
|
||||
|
||||
all_patches = [self.to_patches(volume) for volume in volumes]
|
||||
|
||||
# prepare factorized positional embedding height width indices
|
||||
|
||||
positions = []
|
||||
|
||||
for patches in all_patches:
|
||||
patch_frame, patch_height, patch_width = patches.shape[:3]
|
||||
fhw_indices = torch.stack(torch.meshgrid((arange(patch_frame), arange(patch_height), arange(patch_width)), indexing = 'ij'), dim = -1)
|
||||
fhw_indices = rearrange(fhw_indices, 'f h w c -> (f h w) c')
|
||||
|
||||
positions.append(fhw_indices)
|
||||
|
||||
# need the sizes to compute token dropout + positional embedding
|
||||
|
||||
tokens = [rearrange(patches, 'f h w d -> (f h w) d') for patches in all_patches]
|
||||
|
||||
# handle token dropout
|
||||
|
||||
seq_lens = torch.tensor([i.shape[0] for i in tokens], device = device)
|
||||
|
||||
if self.training and self.token_dropout_prob > 0:
|
||||
|
||||
keep_seq_lens = ((1. - self.token_dropout_prob) * seq_lens).int().clamp(min = 1)
|
||||
|
||||
kept_tokens = []
|
||||
kept_positions = []
|
||||
|
||||
for one_image_tokens, one_image_positions, seq_len, num_keep in zip(tokens, positions, seq_lens, keep_seq_lens):
|
||||
keep_indices = torch.randn((seq_len,), device = device).topk(num_keep, dim = -1).indices
|
||||
|
||||
one_image_kept_tokens = one_image_tokens[keep_indices]
|
||||
one_image_kept_positions = one_image_positions[keep_indices]
|
||||
|
||||
kept_tokens.append(one_image_kept_tokens)
|
||||
kept_positions.append(one_image_kept_positions)
|
||||
|
||||
tokens, positions, seq_lens = kept_tokens, kept_positions, keep_seq_lens
|
||||
|
||||
# add all height and width factorized positions
|
||||
|
||||
|
||||
frame_indices, height_indices, width_indices = torch.cat(positions).unbind(dim = -1)
|
||||
frame_embed, height_embed, width_embed = self.pos_embed_frame[frame_indices], self.pos_embed_height[height_indices], self.pos_embed_width[width_indices]
|
||||
|
||||
pos_embed = frame_embed + height_embed + width_embed
|
||||
|
||||
tokens = torch.cat(tokens)
|
||||
|
||||
# linear projection to patch embeddings
|
||||
|
||||
tokens = self.to_patch_embedding(tokens)
|
||||
|
||||
# absolute positions
|
||||
|
||||
tokens = tokens + pos_embed
|
||||
|
||||
# add register tokens
|
||||
|
||||
tokens = tokens.split(seq_lens.tolist())
|
||||
|
||||
tokens = [torch.cat((self.register_tokens, one_tokens)) for one_tokens in tokens]
|
||||
|
||||
# use nested tensor for transformers and save on padding computation
|
||||
|
||||
tokens = nested_tensor(tokens, layout = torch.jagged, device = device)
|
||||
|
||||
# embedding dropout
|
||||
|
||||
tokens = self.dropout(tokens)
|
||||
|
||||
# transformer
|
||||
|
||||
tokens = self.transformer(tokens)
|
||||
|
||||
# attention pooling
|
||||
# will use a jagged tensor for queries, as SDPA requires all inputs to be jagged, or not
|
||||
|
||||
attn_pool_queries = [rearrange(self.attn_pool_queries, '... -> 1 ...')] * batch
|
||||
|
||||
attn_pool_queries = nested_tensor(attn_pool_queries, layout = torch.jagged)
|
||||
|
||||
pooled = self.attn_pool(attn_pool_queries, tokens)
|
||||
|
||||
# back to unjagged
|
||||
|
||||
logits = torch.stack(pooled.unbind())
|
||||
|
||||
logits = rearrange(logits, 'b 1 d -> b d')
|
||||
|
||||
logits = self.to_latent(logits)
|
||||
|
||||
return self.mlp_head(logits)
|
||||
|
||||
# quick test
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
# works for torch 2.4
|
||||
|
||||
v = NaViT(
|
||||
image_size = 256,
|
||||
max_frames = 8,
|
||||
patch_size = 32,
|
||||
frame_patch_size = 2,
|
||||
num_classes = 1000,
|
||||
dim = 1024,
|
||||
depth = 6,
|
||||
heads = 16,
|
||||
mlp_dim = 2048,
|
||||
dropout = 0.,
|
||||
emb_dropout = 0.,
|
||||
token_dropout_prob = 0.1
|
||||
)
|
||||
|
||||
# 5 volumetric data (videos or CT scans) of different resolutions - List[Tensor]
|
||||
|
||||
volumes = [
|
||||
torch.randn(3, 2, 256, 256), torch.randn(3, 8, 128, 128),
|
||||
torch.randn(3, 4, 128, 256), torch.randn(3, 2, 256, 128),
|
||||
torch.randn(3, 4, 64, 256)
|
||||
]
|
||||
|
||||
assert v(volumes).shape == (5, 1000)
|
||||
263
vit_pytorch/normalized_vit.py
Normal file
263
vit_pytorch/normalized_vit.py
Normal file
@@ -0,0 +1,263 @@
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.nn import Module, ModuleList
|
||||
import torch.nn.functional as F
|
||||
import torch.nn.utils.parametrize as parametrize
|
||||
|
||||
from einops import rearrange, reduce
|
||||
from einops.layers.torch import Rearrange
|
||||
|
||||
# functions
|
||||
|
||||
def exists(v):
|
||||
return v is not None
|
||||
|
||||
def default(v, d):
|
||||
return v if exists(v) else d
|
||||
|
||||
def pair(t):
|
||||
return t if isinstance(t, tuple) else (t, t)
|
||||
|
||||
def divisible_by(numer, denom):
|
||||
return (numer % denom) == 0
|
||||
|
||||
def l2norm(t, dim = -1):
|
||||
return F.normalize(t, dim = dim, p = 2)
|
||||
|
||||
# for use with parametrize
|
||||
|
||||
class L2Norm(Module):
|
||||
def __init__(self, dim = -1):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
|
||||
def forward(self, t):
|
||||
return l2norm(t, dim = self.dim)
|
||||
|
||||
class NormLinear(Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim,
|
||||
dim_out,
|
||||
norm_dim_in = True
|
||||
):
|
||||
super().__init__()
|
||||
self.linear = nn.Linear(dim, dim_out, bias = False)
|
||||
|
||||
parametrize.register_parametrization(
|
||||
self.linear,
|
||||
'weight',
|
||||
L2Norm(dim = -1 if norm_dim_in else 0)
|
||||
)
|
||||
|
||||
@property
|
||||
def weight(self):
|
||||
return self.linear.weight
|
||||
|
||||
def forward(self, x):
|
||||
return self.linear(x)
|
||||
|
||||
# attention and feedforward
|
||||
|
||||
class Attention(Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim,
|
||||
*,
|
||||
dim_head = 64,
|
||||
heads = 8,
|
||||
dropout = 0.
|
||||
):
|
||||
super().__init__()
|
||||
dim_inner = dim_head * heads
|
||||
self.to_q = NormLinear(dim, dim_inner)
|
||||
self.to_k = NormLinear(dim, dim_inner)
|
||||
self.to_v = NormLinear(dim, dim_inner)
|
||||
|
||||
self.dropout = dropout
|
||||
|
||||
self.q_scale = nn.Parameter(torch.ones(dim_inner) * (dim_head ** 0.25))
|
||||
self.k_scale = nn.Parameter(torch.ones(dim_inner) * (dim_head ** 0.25))
|
||||
|
||||
self.split_heads = Rearrange('b n (h d) -> b h n d', h = heads)
|
||||
self.merge_heads = Rearrange('b h n d -> b n (h d)')
|
||||
|
||||
self.to_out = NormLinear(dim_inner, dim, norm_dim_in = False)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x
|
||||
):
|
||||
q, k, v = self.to_q(x), self.to_k(x), self.to_v(x)
|
||||
|
||||
q = q * self.q_scale
|
||||
k = k * self.k_scale
|
||||
|
||||
q, k, v = map(self.split_heads, (q, k, v))
|
||||
|
||||
# query key rmsnorm
|
||||
|
||||
q, k = map(l2norm, (q, k))
|
||||
|
||||
# scale is 1., as scaling factor is moved to s_qk (dk ^ 0.25) - eq. 16
|
||||
|
||||
out = F.scaled_dot_product_attention(
|
||||
q, k, v,
|
||||
dropout_p = self.dropout if self.training else 0.,
|
||||
scale = 1.
|
||||
)
|
||||
|
||||
out = self.merge_heads(out)
|
||||
return self.to_out(out)
|
||||
|
||||
class FeedForward(Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim,
|
||||
*,
|
||||
dim_inner,
|
||||
dropout = 0.
|
||||
):
|
||||
super().__init__()
|
||||
dim_inner = int(dim_inner * 2 / 3)
|
||||
|
||||
self.dim = dim
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
|
||||
self.to_hidden = NormLinear(dim, dim_inner)
|
||||
self.to_gate = NormLinear(dim, dim_inner)
|
||||
|
||||
self.hidden_scale = nn.Parameter(torch.ones(dim_inner))
|
||||
self.gate_scale = nn.Parameter(torch.ones(dim_inner))
|
||||
|
||||
self.to_out = NormLinear(dim_inner, dim, norm_dim_in = False)
|
||||
|
||||
def forward(self, x):
|
||||
hidden, gate = self.to_hidden(x), self.to_gate(x)
|
||||
|
||||
hidden = hidden * self.hidden_scale
|
||||
gate = gate * self.gate_scale * (self.dim ** 0.5)
|
||||
|
||||
hidden = F.silu(gate) * hidden
|
||||
|
||||
hidden = self.dropout(hidden)
|
||||
return self.to_out(hidden)
|
||||
|
||||
# classes
|
||||
|
||||
class nViT(Module):
|
||||
""" https://arxiv.org/abs/2410.01131 """
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
image_size,
|
||||
patch_size,
|
||||
num_classes,
|
||||
dim,
|
||||
depth,
|
||||
heads,
|
||||
mlp_dim,
|
||||
dropout = 0.,
|
||||
channels = 3,
|
||||
dim_head = 64,
|
||||
residual_lerp_scale_init = None
|
||||
):
|
||||
super().__init__()
|
||||
image_height, image_width = pair(image_size)
|
||||
|
||||
# calculate patching related stuff
|
||||
|
||||
assert divisible_by(image_height, patch_size) and divisible_by(image_width, patch_size), 'Image dimensions must be divisible by the patch size.'
|
||||
|
||||
patch_height_dim, patch_width_dim = (image_height // patch_size), (image_width // patch_size)
|
||||
patch_dim = channels * (patch_size ** 2)
|
||||
num_patches = patch_height_dim * patch_width_dim
|
||||
|
||||
self.channels = channels
|
||||
self.patch_size = patch_size
|
||||
|
||||
self.to_patch_embedding = nn.Sequential(
|
||||
Rearrange('b c (h p1) (w p2) -> b (h w) (c p1 p2)', p1 = patch_size, p2 = patch_size),
|
||||
nn.LayerNorm(patch_dim),
|
||||
nn.Linear(patch_dim, dim),
|
||||
nn.LayerNorm(dim),
|
||||
)
|
||||
|
||||
self.abs_pos_emb = nn.Embedding(num_patches, dim)
|
||||
|
||||
residual_lerp_scale_init = default(residual_lerp_scale_init, 1. / depth)
|
||||
|
||||
# layers
|
||||
|
||||
self.dim = dim
|
||||
self.layers = ModuleList([])
|
||||
self.residual_lerp_scales = nn.ParameterList([])
|
||||
|
||||
for _ in range(depth):
|
||||
self.layers.append(ModuleList([
|
||||
Attention(dim, dim_head = dim_head, heads = heads, dropout = dropout),
|
||||
FeedForward(dim, dim_inner = mlp_dim, dropout = dropout),
|
||||
]))
|
||||
|
||||
self.residual_lerp_scales.append(nn.ParameterList([
|
||||
nn.Parameter(torch.ones(dim) * residual_lerp_scale_init),
|
||||
nn.Parameter(torch.ones(dim) * residual_lerp_scale_init),
|
||||
]))
|
||||
|
||||
self.logit_scale = nn.Parameter(torch.ones(num_classes))
|
||||
|
||||
self.to_pred = NormLinear(dim, num_classes)
|
||||
|
||||
@torch.no_grad()
|
||||
def norm_weights_(self):
|
||||
for module in self.modules():
|
||||
if not isinstance(module, NormLinear):
|
||||
continue
|
||||
|
||||
normed = module.weight
|
||||
original = module.linear.parametrizations.weight.original
|
||||
|
||||
original.copy_(normed)
|
||||
|
||||
def forward(self, images):
|
||||
device = images.device
|
||||
|
||||
tokens = self.to_patch_embedding(images)
|
||||
|
||||
pos_emb = self.abs_pos_emb(torch.arange(tokens.shape[-2], device = device))
|
||||
|
||||
tokens = l2norm(tokens + pos_emb)
|
||||
|
||||
for (attn, ff), (attn_alpha, ff_alpha) in zip(self.layers, self.residual_lerp_scales):
|
||||
|
||||
attn_out = l2norm(attn(tokens))
|
||||
tokens = l2norm(tokens.lerp(attn_out, attn_alpha))
|
||||
|
||||
ff_out = l2norm(ff(tokens))
|
||||
tokens = l2norm(tokens.lerp(ff_out, ff_alpha))
|
||||
|
||||
pooled = reduce(tokens, 'b n d -> b d', 'mean')
|
||||
|
||||
logits = self.to_pred(pooled)
|
||||
logits = logits * self.logit_scale * (self.dim ** 0.5)
|
||||
|
||||
return logits
|
||||
|
||||
# quick test
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
v = nViT(
|
||||
image_size = 256,
|
||||
patch_size = 16,
|
||||
num_classes = 1000,
|
||||
dim = 1024,
|
||||
depth = 6,
|
||||
heads = 8,
|
||||
mlp_dim = 2048,
|
||||
)
|
||||
|
||||
img = torch.randn(4, 3, 256, 256)
|
||||
logits = v(img) # (4, 1000)
|
||||
assert logits.shape == (4, 1000)
|
||||
@@ -20,6 +20,18 @@ def divisible_by(val, d):
|
||||
|
||||
# helper classes
|
||||
|
||||
class ChanLayerNorm(nn.Module):
|
||||
def __init__(self, dim, eps = 1e-5):
|
||||
super().__init__()
|
||||
self.eps = eps
|
||||
self.g = nn.Parameter(torch.ones(1, dim, 1, 1))
|
||||
self.b = nn.Parameter(torch.zeros(1, dim, 1, 1))
|
||||
|
||||
def forward(self, x):
|
||||
var = torch.var(x, dim = 1, unbiased = False, keepdim = True)
|
||||
mean = torch.mean(x, dim = 1, keepdim = True)
|
||||
return (x - mean) / (var + self.eps).sqrt() * self.g + self.b
|
||||
|
||||
class Downsample(nn.Module):
|
||||
def __init__(self, dim_in, dim_out):
|
||||
super().__init__()
|
||||
@@ -212,10 +224,10 @@ class RegionViT(nn.Module):
|
||||
if tokenize_local_3_conv:
|
||||
self.local_encoder = nn.Sequential(
|
||||
nn.Conv2d(3, init_dim, 3, 2, 1),
|
||||
nn.LayerNorm(init_dim),
|
||||
ChanLayerNorm(init_dim),
|
||||
nn.GELU(),
|
||||
nn.Conv2d(init_dim, init_dim, 3, 2, 1),
|
||||
nn.LayerNorm(init_dim),
|
||||
ChanLayerNorm(init_dim),
|
||||
nn.GELU(),
|
||||
nn.Conv2d(init_dim, init_dim, 3, 1, 1)
|
||||
)
|
||||
|
||||
@@ -3,12 +3,14 @@ from math import sqrt, pi, log
|
||||
import torch
|
||||
from torch import nn, einsum
|
||||
import torch.nn.functional as F
|
||||
from torch.amp import autocast
|
||||
|
||||
from einops import rearrange, repeat
|
||||
from einops.layers.torch import Rearrange
|
||||
|
||||
# rotary embeddings
|
||||
|
||||
@autocast('cuda', enabled = False)
|
||||
def rotate_every_two(x):
|
||||
x = rearrange(x, '... (d j) -> ... d j', j = 2)
|
||||
x1, x2 = x.unbind(dim = -1)
|
||||
@@ -22,6 +24,7 @@ class AxialRotaryEmbedding(nn.Module):
|
||||
scales = torch.linspace(1., max_freq / 2, self.dim // 4)
|
||||
self.register_buffer('scales', scales)
|
||||
|
||||
@autocast('cuda', enabled = False)
|
||||
def forward(self, x):
|
||||
device, dtype, n = x.device, x.dtype, int(sqrt(x.shape[-2]))
|
||||
|
||||
|
||||
171
vit_pytorch/simple_flash_attn_vit_3d.py
Normal file
171
vit_pytorch/simple_flash_attn_vit_3d.py
Normal file
@@ -0,0 +1,171 @@
|
||||
from packaging import version
|
||||
from collections import namedtuple
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
import torch.nn.functional as F
|
||||
from torch.nn import Module, ModuleList
|
||||
|
||||
from einops import rearrange
|
||||
from einops.layers.torch import Rearrange
|
||||
|
||||
# constants
|
||||
|
||||
Config = namedtuple('FlashAttentionConfig', ['enable_flash', 'enable_math', 'enable_mem_efficient'])
|
||||
|
||||
# helpers
|
||||
|
||||
def pair(t):
|
||||
return t if isinstance(t, tuple) else (t, t)
|
||||
|
||||
def posemb_sincos_3d(patches, temperature = 10000, dtype = torch.float32):
|
||||
_, f, h, w, dim, device, dtype = *patches.shape, patches.device, patches.dtype
|
||||
|
||||
z, y, x = torch.meshgrid(
|
||||
torch.arange(f, device = device),
|
||||
torch.arange(h, device = device),
|
||||
torch.arange(w, device = device),
|
||||
indexing = 'ij')
|
||||
|
||||
fourier_dim = dim // 6
|
||||
|
||||
omega = torch.arange(fourier_dim, device = device) / (fourier_dim - 1)
|
||||
omega = 1. / (temperature ** omega)
|
||||
|
||||
z = z.flatten()[:, None] * omega[None, :]
|
||||
y = y.flatten()[:, None] * omega[None, :]
|
||||
x = x.flatten()[:, None] * omega[None, :]
|
||||
|
||||
pe = torch.cat((x.sin(), x.cos(), y.sin(), y.cos(), z.sin(), z.cos()), dim = 1)
|
||||
|
||||
pe = F.pad(pe, (0, dim - (fourier_dim * 6))) # pad if feature dimension not cleanly divisible by 6
|
||||
return pe.type(dtype)
|
||||
|
||||
# main class
|
||||
|
||||
class Attend(Module):
|
||||
def __init__(self, use_flash = False, config: Config = Config(True, True, True)):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.use_flash = use_flash
|
||||
assert not (use_flash and version.parse(torch.__version__) < version.parse('2.0.0')), 'in order to use flash attention, you must be using pytorch 2.0 or above'
|
||||
|
||||
def flash_attn(self, q, k, v):
|
||||
# flash attention - https://arxiv.org/abs/2205.14135
|
||||
|
||||
with torch.backends.cuda.sdp_kernel(**self.config._asdict()):
|
||||
out = F.scaled_dot_product_attention(q, k, v)
|
||||
|
||||
return out
|
||||
|
||||
def forward(self, q, k, v):
|
||||
n, device, scale = q.shape[-2], q.device, q.shape[-1] ** -0.5
|
||||
|
||||
if self.use_flash:
|
||||
return self.flash_attn(q, k, v)
|
||||
|
||||
# similarity
|
||||
|
||||
sim = einsum("b h i d, b j d -> b h i j", q, k) * scale
|
||||
|
||||
# attention
|
||||
|
||||
attn = sim.softmax(dim=-1)
|
||||
|
||||
# aggregate values
|
||||
|
||||
out = einsum("b h i j, b j d -> b h i d", attn, v)
|
||||
|
||||
return out
|
||||
|
||||
# classes
|
||||
|
||||
class FeedForward(Module):
|
||||
def __init__(self, dim, hidden_dim):
|
||||
super().__init__()
|
||||
self.net = nn.Sequential(
|
||||
nn.LayerNorm(dim),
|
||||
nn.Linear(dim, hidden_dim),
|
||||
nn.GELU(),
|
||||
nn.Linear(hidden_dim, dim),
|
||||
)
|
||||
def forward(self, x):
|
||||
return self.net(x)
|
||||
|
||||
class Attention(Module):
|
||||
def __init__(self, dim, heads = 8, dim_head = 64, use_flash = True):
|
||||
super().__init__()
|
||||
inner_dim = dim_head * heads
|
||||
self.heads = heads
|
||||
self.scale = dim_head ** -0.5
|
||||
self.norm = nn.LayerNorm(dim)
|
||||
|
||||
self.attend = Attend(use_flash = use_flash)
|
||||
|
||||
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
|
||||
self.to_out = nn.Linear(inner_dim, dim, bias = False)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.norm(x)
|
||||
|
||||
qkv = self.to_qkv(x).chunk(3, dim = -1)
|
||||
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)
|
||||
|
||||
out = self.attend(q, k, v)
|
||||
|
||||
out = rearrange(out, 'b h n d -> b n (h d)')
|
||||
return self.to_out(out)
|
||||
|
||||
class Transformer(Module):
|
||||
def __init__(self, dim, depth, heads, dim_head, mlp_dim, use_flash):
|
||||
super().__init__()
|
||||
self.layers = ModuleList([])
|
||||
for _ in range(depth):
|
||||
self.layers.append(ModuleList([
|
||||
Attention(dim, heads = heads, dim_head = dim_head, use_flash = use_flash),
|
||||
FeedForward(dim, mlp_dim)
|
||||
]))
|
||||
|
||||
def forward(self, x):
|
||||
for attn, ff in self.layers:
|
||||
x = attn(x) + x
|
||||
x = ff(x) + x
|
||||
|
||||
return x
|
||||
|
||||
class SimpleViT(Module):
|
||||
def __init__(self, *, image_size, image_patch_size, frames, frame_patch_size, num_classes, dim, depth, heads, mlp_dim, channels = 3, dim_head = 64, use_flash_attn = True):
|
||||
super().__init__()
|
||||
image_height, image_width = pair(image_size)
|
||||
patch_height, patch_width = pair(image_patch_size)
|
||||
|
||||
assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'
|
||||
assert frames % frame_patch_size == 0, 'Frames must be divisible by the frame patch size'
|
||||
|
||||
num_patches = (image_height // patch_height) * (image_width // patch_width) * (frames // frame_patch_size)
|
||||
patch_dim = channels * patch_height * patch_width * frame_patch_size
|
||||
|
||||
self.to_patch_embedding = nn.Sequential(
|
||||
Rearrange('b c (f pf) (h p1) (w p2) -> b f h w (p1 p2 pf c)', p1 = patch_height, p2 = patch_width, pf = frame_patch_size),
|
||||
nn.LayerNorm(patch_dim),
|
||||
nn.Linear(patch_dim, dim),
|
||||
nn.LayerNorm(dim),
|
||||
)
|
||||
|
||||
self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, use_flash_attn)
|
||||
|
||||
self.to_latent = nn.Identity()
|
||||
self.linear_head = nn.Linear(dim, num_classes)
|
||||
|
||||
def forward(self, video):
|
||||
*_, h, w, dtype = *video.shape, video.dtype
|
||||
|
||||
x = self.to_patch_embedding(video)
|
||||
pe = posemb_sincos_3d(x)
|
||||
x = rearrange(x, 'b ... d -> b (...) d') + pe
|
||||
|
||||
x = self.transformer(x)
|
||||
x = x.mean(dim = 1)
|
||||
|
||||
x = self.to_latent(x)
|
||||
return self.linear_head(x)
|
||||
176
vit_pytorch/simple_uvit.py
Normal file
176
vit_pytorch/simple_uvit.py
Normal file
@@ -0,0 +1,176 @@
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.nn import Module, ModuleList
|
||||
|
||||
from einops import rearrange, repeat, pack, unpack
|
||||
from einops.layers.torch import Rearrange
|
||||
|
||||
# helpers
|
||||
|
||||
def pair(t):
|
||||
return t if isinstance(t, tuple) else (t, t)
|
||||
|
||||
def exists(v):
|
||||
return v is not None
|
||||
|
||||
def divisible_by(num, den):
|
||||
return (num % den) == 0
|
||||
|
||||
def posemb_sincos_2d(h, w, dim, temperature: int = 10000, dtype = torch.float32):
|
||||
y, x = torch.meshgrid(torch.arange(h), torch.arange(w), indexing="ij")
|
||||
assert divisible_by(dim, 4), "feature dimension must be multiple of 4 for sincos emb"
|
||||
omega = torch.arange(dim // 4) / (dim // 4 - 1)
|
||||
omega = temperature ** -omega
|
||||
|
||||
y = y.flatten()[:, None] * omega[None, :]
|
||||
x = x.flatten()[:, None] * omega[None, :]
|
||||
pe = torch.cat((x.sin(), x.cos(), y.sin(), y.cos()), dim=1)
|
||||
return pe.type(dtype)
|
||||
|
||||
# classes
|
||||
|
||||
def FeedForward(dim, hidden_dim):
|
||||
return nn.Sequential(
|
||||
nn.LayerNorm(dim),
|
||||
nn.Linear(dim, hidden_dim),
|
||||
nn.GELU(),
|
||||
nn.Linear(hidden_dim, dim),
|
||||
)
|
||||
|
||||
class Attention(Module):
|
||||
def __init__(self, dim, heads = 8, dim_head = 64):
|
||||
super().__init__()
|
||||
inner_dim = dim_head * heads
|
||||
self.heads = heads
|
||||
self.scale = dim_head ** -0.5
|
||||
self.norm = nn.LayerNorm(dim)
|
||||
|
||||
self.attend = nn.Softmax(dim = -1)
|
||||
|
||||
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
|
||||
self.to_out = nn.Linear(inner_dim, dim, bias = False)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.norm(x)
|
||||
|
||||
qkv = self.to_qkv(x).chunk(3, dim = -1)
|
||||
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)
|
||||
|
||||
dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
|
||||
|
||||
attn = self.attend(dots)
|
||||
|
||||
out = torch.matmul(attn, v)
|
||||
out = rearrange(out, 'b h n d -> b n (h d)')
|
||||
return self.to_out(out)
|
||||
|
||||
class Transformer(Module):
|
||||
def __init__(self, dim, depth, heads, dim_head, mlp_dim):
|
||||
super().__init__()
|
||||
self.depth = depth
|
||||
self.norm = nn.LayerNorm(dim)
|
||||
self.layers = ModuleList([])
|
||||
|
||||
for layer in range(1, depth + 1):
|
||||
latter_half = layer >= (depth / 2 + 1)
|
||||
|
||||
self.layers.append(nn.ModuleList([
|
||||
nn.Linear(dim * 2, dim) if latter_half else None,
|
||||
Attention(dim, heads = heads, dim_head = dim_head),
|
||||
FeedForward(dim, mlp_dim)
|
||||
]))
|
||||
|
||||
def forward(self, x):
|
||||
|
||||
skips = []
|
||||
|
||||
for ind, (combine_skip, attn, ff) in enumerate(self.layers):
|
||||
layer = ind + 1
|
||||
first_half = layer <= (self.depth / 2)
|
||||
|
||||
if first_half:
|
||||
skips.append(x)
|
||||
|
||||
if exists(combine_skip):
|
||||
skip = skips.pop()
|
||||
skip_and_x = torch.cat((skip, x), dim = -1)
|
||||
x = combine_skip(skip_and_x)
|
||||
|
||||
x = attn(x) + x
|
||||
x = ff(x) + x
|
||||
|
||||
assert len(skips) == 0
|
||||
|
||||
return self.norm(x)
|
||||
|
||||
class SimpleUViT(Module):
|
||||
def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, num_register_tokens = 4, channels = 3, dim_head = 64):
|
||||
super().__init__()
|
||||
image_height, image_width = pair(image_size)
|
||||
patch_height, patch_width = pair(patch_size)
|
||||
|
||||
assert divisible_by(image_height, patch_height) and divisible_by(image_width, patch_width), 'Image dimensions must be divisible by the patch size.'
|
||||
|
||||
patch_dim = channels * patch_height * patch_width
|
||||
|
||||
self.to_patch_embedding = nn.Sequential(
|
||||
Rearrange("b c (h p1) (w p2) -> b (h w) (p1 p2 c)", p1 = patch_height, p2 = patch_width),
|
||||
nn.LayerNorm(patch_dim),
|
||||
nn.Linear(patch_dim, dim),
|
||||
nn.LayerNorm(dim),
|
||||
)
|
||||
|
||||
pos_embedding = posemb_sincos_2d(
|
||||
h = image_height // patch_height,
|
||||
w = image_width // patch_width,
|
||||
dim = dim
|
||||
)
|
||||
|
||||
self.register_buffer('pos_embedding', pos_embedding, persistent = False)
|
||||
|
||||
self.register_tokens = nn.Parameter(torch.randn(num_register_tokens, dim))
|
||||
|
||||
self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim)
|
||||
|
||||
self.pool = "mean"
|
||||
self.to_latent = nn.Identity()
|
||||
|
||||
self.linear_head = nn.Linear(dim, num_classes)
|
||||
|
||||
def forward(self, img):
|
||||
batch, device = img.shape[0], img.device
|
||||
|
||||
x = self.to_patch_embedding(img)
|
||||
x = x + self.pos_embedding.type(x.dtype)
|
||||
|
||||
r = repeat(self.register_tokens, 'n d -> b n d', b = batch)
|
||||
|
||||
x, ps = pack([x, r], 'b * d')
|
||||
|
||||
x = self.transformer(x)
|
||||
|
||||
x, _ = unpack(x, ps, 'b * d')
|
||||
|
||||
x = x.mean(dim = 1)
|
||||
|
||||
x = self.to_latent(x)
|
||||
return self.linear_head(x)
|
||||
|
||||
# quick test on odd number of layers
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
v = SimpleUViT(
|
||||
image_size = 256,
|
||||
patch_size = 32,
|
||||
num_classes = 1000,
|
||||
dim = 1024,
|
||||
depth = 7,
|
||||
heads = 16,
|
||||
mlp_dim = 2048
|
||||
).cuda()
|
||||
|
||||
img = torch.randn(2, 3, 256, 256).cuda()
|
||||
|
||||
preds = v(img)
|
||||
assert preds.shape == (2, 1000)
|
||||
@@ -61,10 +61,7 @@ class T2TViT(nn.Module):
|
||||
self.pool = pool
|
||||
self.to_latent = nn.Identity()
|
||||
|
||||
self.mlp_head = nn.Sequential(
|
||||
nn.LayerNorm(dim),
|
||||
nn.Linear(dim, num_classes)
|
||||
)
|
||||
self.mlp_head = nn.Linear(dim, num_classes)
|
||||
|
||||
def forward(self, img):
|
||||
x = self.to_patch_embedding(img)
|
||||
|
||||
@@ -78,6 +78,30 @@ class Transformer(nn.Module):
|
||||
x = ff(x) + x
|
||||
return self.norm(x)
|
||||
|
||||
class FactorizedTransformer(nn.Module):
|
||||
def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
|
||||
super().__init__()
|
||||
self.norm = nn.LayerNorm(dim)
|
||||
self.layers = nn.ModuleList([])
|
||||
for _ in range(depth):
|
||||
self.layers.append(nn.ModuleList([
|
||||
Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout),
|
||||
Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout),
|
||||
FeedForward(dim, mlp_dim, dropout = dropout)
|
||||
]))
|
||||
|
||||
def forward(self, x):
|
||||
b, f, n, _ = x.shape
|
||||
for spatial_attn, temporal_attn, ff in self.layers:
|
||||
x = rearrange(x, 'b f n d -> (b f) n d')
|
||||
x = spatial_attn(x) + x
|
||||
x = rearrange(x, '(b f) n d -> (b n) f d', b=b, f=f)
|
||||
x = temporal_attn(x) + x
|
||||
x = ff(x) + x
|
||||
x = rearrange(x, '(b n) f d -> b f n d', b=b, n=n)
|
||||
|
||||
return self.norm(x)
|
||||
|
||||
class ViT(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
@@ -96,7 +120,8 @@ class ViT(nn.Module):
|
||||
channels = 3,
|
||||
dim_head = 64,
|
||||
dropout = 0.,
|
||||
emb_dropout = 0.
|
||||
emb_dropout = 0.,
|
||||
variant = 'factorized_encoder',
|
||||
):
|
||||
super().__init__()
|
||||
image_height, image_width = pair(image_size)
|
||||
@@ -104,6 +129,7 @@ class ViT(nn.Module):
|
||||
|
||||
assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'
|
||||
assert frames % frame_patch_size == 0, 'Frames must be divisible by frame patch size'
|
||||
assert variant in ('factorized_encoder', 'factorized_self_attention'), f'variant = {variant} is not implemented'
|
||||
|
||||
num_image_patches = (image_height // patch_height) * (image_width // patch_width)
|
||||
num_frame_patches = (frames // frame_patch_size)
|
||||
@@ -125,15 +151,20 @@ class ViT(nn.Module):
|
||||
self.dropout = nn.Dropout(emb_dropout)
|
||||
|
||||
self.spatial_cls_token = nn.Parameter(torch.randn(1, 1, dim)) if not self.global_average_pool else None
|
||||
self.temporal_cls_token = nn.Parameter(torch.randn(1, 1, dim)) if not self.global_average_pool else None
|
||||
|
||||
self.spatial_transformer = Transformer(dim, spatial_depth, heads, dim_head, mlp_dim, dropout)
|
||||
self.temporal_transformer = Transformer(dim, temporal_depth, heads, dim_head, mlp_dim, dropout)
|
||||
if variant == 'factorized_encoder':
|
||||
self.temporal_cls_token = nn.Parameter(torch.randn(1, 1, dim)) if not self.global_average_pool else None
|
||||
self.spatial_transformer = Transformer(dim, spatial_depth, heads, dim_head, mlp_dim, dropout)
|
||||
self.temporal_transformer = Transformer(dim, temporal_depth, heads, dim_head, mlp_dim, dropout)
|
||||
elif variant == 'factorized_self_attention':
|
||||
assert spatial_depth == temporal_depth, 'Spatial and temporal depth must be the same for factorized self-attention'
|
||||
self.factorized_transformer = FactorizedTransformer(dim, spatial_depth, heads, dim_head, mlp_dim, dropout)
|
||||
|
||||
self.pool = pool
|
||||
self.to_latent = nn.Identity()
|
||||
|
||||
self.mlp_head = nn.Linear(dim, num_classes)
|
||||
self.variant = variant
|
||||
|
||||
def forward(self, video):
|
||||
x = self.to_patch_embedding(video)
|
||||
@@ -147,32 +178,37 @@ class ViT(nn.Module):
|
||||
|
||||
x = self.dropout(x)
|
||||
|
||||
x = rearrange(x, 'b f n d -> (b f) n d')
|
||||
if self.variant == 'factorized_encoder':
|
||||
x = rearrange(x, 'b f n d -> (b f) n d')
|
||||
|
||||
# attend across space
|
||||
# attend across space
|
||||
|
||||
x = self.spatial_transformer(x)
|
||||
x = self.spatial_transformer(x)
|
||||
x = rearrange(x, '(b f) n d -> b f n d', b = b)
|
||||
|
||||
x = rearrange(x, '(b f) n d -> b f n d', b = b)
|
||||
# excise out the spatial cls tokens or average pool for temporal attention
|
||||
|
||||
# excise out the spatial cls tokens or average pool for temporal attention
|
||||
x = x[:, :, 0] if not self.global_average_pool else reduce(x, 'b f n d -> b f d', 'mean')
|
||||
|
||||
x = x[:, :, 0] if not self.global_average_pool else reduce(x, 'b f n d -> b f d', 'mean')
|
||||
# append temporal CLS tokens
|
||||
|
||||
# append temporal CLS tokens
|
||||
if exists(self.temporal_cls_token):
|
||||
temporal_cls_tokens = repeat(self.temporal_cls_token, '1 1 d-> b 1 d', b = b)
|
||||
|
||||
if exists(self.temporal_cls_token):
|
||||
temporal_cls_tokens = repeat(self.temporal_cls_token, '1 1 d-> b 1 d', b = b)
|
||||
x = torch.cat((temporal_cls_tokens, x), dim = 1)
|
||||
|
||||
|
||||
x = torch.cat((temporal_cls_tokens, x), dim = 1)
|
||||
# attend across time
|
||||
|
||||
# attend across time
|
||||
x = self.temporal_transformer(x)
|
||||
|
||||
x = self.temporal_transformer(x)
|
||||
# excise out temporal cls token or average pool
|
||||
|
||||
# excise out temporal cls token or average pool
|
||||
x = x[:, 0] if not self.global_average_pool else reduce(x, 'b f d -> b d', 'mean')
|
||||
|
||||
x = x[:, 0] if not self.global_average_pool else reduce(x, 'b f d -> b d', 'mean')
|
||||
elif self.variant == 'factorized_self_attention':
|
||||
x = self.factorized_transformer(x)
|
||||
x = x[:, 0, 0] if not self.global_average_pool else reduce(x, 'b f n d -> b d', 'mean')
|
||||
|
||||
x = self.to_latent(x)
|
||||
return self.mlp_head(x)
|
||||
|
||||
Reference in New Issue
Block a user