mirror of
https://github.com/lucidrains/vit-pytorch.git
synced 2025-12-30 08:02:29 +00:00
Compare commits
28 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
771fb6daaf | ||
|
|
4f22eae631 | ||
|
|
dfc8df6713 | ||
|
|
9992a615d1 | ||
|
|
4b2c00cb63 | ||
|
|
ec6c48b8ff | ||
|
|
547bf94d07 | ||
|
|
bd72b58355 | ||
|
|
e3256d77cd | ||
|
|
90be7233a3 | ||
|
|
bca88e9039 | ||
|
|
96f66d2754 | ||
|
|
12249dcc5f | ||
|
|
8b8da8dede | ||
|
|
5578ac472f | ||
|
|
d446a41243 | ||
|
|
0ad09c4cbc | ||
|
|
92b69321f4 | ||
|
|
fb4ac25174 | ||
|
|
53fe345e85 | ||
|
|
efb94608ea | ||
|
|
51310d1d07 | ||
|
|
1616288e30 | ||
|
|
9e1e824385 | ||
|
|
bbb24e34d4 | ||
|
|
df8733d86e | ||
|
|
680d446e46 | ||
|
|
3fdb8dd352 |
25
.github/workflows/python-publish.yml
vendored
25
.github/workflows/python-publish.yml
vendored
@@ -1,11 +1,16 @@
|
||||
# This workflows will upload a Python Package using Twine when a release is created
|
||||
# This workflow will upload a Python Package using Twine when a release is created
|
||||
# For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries
|
||||
|
||||
# This workflow uses actions that are not certified by GitHub.
|
||||
# They are provided by a third-party and are governed by
|
||||
# separate terms of service, privacy policy, and support
|
||||
# documentation.
|
||||
|
||||
name: Upload Python Package
|
||||
|
||||
on:
|
||||
release:
|
||||
types: [created]
|
||||
types: [published]
|
||||
|
||||
jobs:
|
||||
deploy:
|
||||
@@ -21,11 +26,11 @@ jobs:
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
pip install setuptools wheel twine
|
||||
- name: Build and publish
|
||||
env:
|
||||
TWINE_USERNAME: ${{ secrets.PYPI_USERNAME }}
|
||||
TWINE_PASSWORD: ${{ secrets.PYPI_PASSWORD }}
|
||||
run: |
|
||||
python setup.py sdist bdist_wheel
|
||||
twine upload dist/*
|
||||
pip install build
|
||||
- name: Build package
|
||||
run: python -m build
|
||||
- name: Publish package
|
||||
uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29
|
||||
with:
|
||||
user: __token__
|
||||
password: ${{ secrets.PYPI_API_TOKEN }}
|
||||
|
||||
2
.github/workflows/python-test.yml
vendored
2
.github/workflows/python-test.yml
vendored
@@ -15,7 +15,7 @@ jobs:
|
||||
runs-on: ubuntu-latest
|
||||
strategy:
|
||||
matrix:
|
||||
python-version: [3.7, 3.8, 3.9]
|
||||
python-version: [3.8, 3.9]
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v2
|
||||
|
||||
104
README.md
104
README.md
@@ -25,6 +25,7 @@
|
||||
- [MaxViT](#maxvit)
|
||||
- [NesT](#nest)
|
||||
- [MobileViT](#mobilevit)
|
||||
- [XCiT](#xcit)
|
||||
- [Masked Autoencoder](#masked-autoencoder)
|
||||
- [Simple Masked Image Modeling](#simple-masked-image-modeling)
|
||||
- [Masked Patch Prediction](#masked-patch-prediction)
|
||||
@@ -92,7 +93,7 @@ preds = v(img) # (1, 1000)
|
||||
- `image_size`: int.
|
||||
Image size. If you have rectangular images, make sure your image size is the maximum of the width and height
|
||||
- `patch_size`: int.
|
||||
Number of patches. `image_size` must be divisible by `patch_size`.
|
||||
Size of patches. `image_size` must be divisible by `patch_size`.
|
||||
The number of patches is: ` n = (image_size // patch_size) ** 2` and `n` **must be greater than 16**.
|
||||
- `num_classes`: int.
|
||||
Number of classes to classify.
|
||||
@@ -197,6 +198,38 @@ preds = v(
|
||||
) # (5, 1000)
|
||||
```
|
||||
|
||||
Finally, if you would like to make use of a flavor of NaViT using <a href="https://pytorch.org/tutorials/prototype/nestedtensor.html">nested tensors</a> (which will omit a lot of the masking and padding altogether), make sure you are on PyTorch version `2.4` and import as follows
|
||||
|
||||
```python
|
||||
import torch
|
||||
from vit_pytorch.na_vit_nested_tensor import NaViT
|
||||
|
||||
v = NaViT(
|
||||
image_size = 256,
|
||||
patch_size = 32,
|
||||
num_classes = 1000,
|
||||
dim = 1024,
|
||||
depth = 6,
|
||||
heads = 16,
|
||||
mlp_dim = 2048,
|
||||
dropout = 0.,
|
||||
emb_dropout = 0.,
|
||||
token_dropout_prob = 0.1
|
||||
)
|
||||
|
||||
# 5 images of different resolutions - List[Tensor]
|
||||
|
||||
images = [
|
||||
torch.randn(3, 256, 256), torch.randn(3, 128, 128),
|
||||
torch.randn(3, 128, 256), torch.randn(3, 256, 128),
|
||||
torch.randn(3, 64, 256)
|
||||
]
|
||||
|
||||
preds = v(images)
|
||||
|
||||
assert preds.shape == (5, 1000)
|
||||
```
|
||||
|
||||
## Distillation
|
||||
|
||||
<img src="./images/distill.png" width="300px"></img>
|
||||
@@ -772,6 +805,38 @@ img = torch.randn(1, 3, 256, 256)
|
||||
pred = mbvit_xs(img) # (1, 1000)
|
||||
```
|
||||
|
||||
## XCiT
|
||||
|
||||
<img src="./images/xcit.png" width="400px"></img>
|
||||
|
||||
This <a href="https://arxiv.org/abs/2106.09681">paper</a> introduces the cross covariance attention (abbreviated XCA). One can think of it as doing attention across the features dimension rather than the spatial one (another perspective would be a dynamic 1x1 convolution, the kernel being an attention map defined by spatial correlations).
|
||||
|
||||
Technically, this amounts to simply transposing the query, key, values before executing cosine similarity attention with learned temperature.
|
||||
|
||||
```python
|
||||
import torch
|
||||
from vit_pytorch.xcit import XCiT
|
||||
|
||||
v = XCiT(
|
||||
image_size = 256,
|
||||
patch_size = 32,
|
||||
num_classes = 1000,
|
||||
dim = 1024,
|
||||
depth = 12, # depth of xcit transformer
|
||||
cls_depth = 2, # depth of cross attention of CLS tokens to patch, attention pool at end
|
||||
heads = 16,
|
||||
mlp_dim = 2048,
|
||||
dropout = 0.1,
|
||||
emb_dropout = 0.1,
|
||||
layer_dropout = 0.05, # randomly dropout 5% of the layers
|
||||
local_patch_kernel_size = 3 # kernel size of the local patch interaction module (depthwise convs)
|
||||
)
|
||||
|
||||
img = torch.randn(1, 3, 256, 256)
|
||||
|
||||
preds = v(img) # (1, 1000)
|
||||
```
|
||||
|
||||
## Simple Masked Image Modeling
|
||||
|
||||
<img src="./images/simmim.png" width="400px"/>
|
||||
@@ -2029,4 +2094,41 @@ Coming from computer vision and new to transformers? Here are some resources tha
|
||||
}
|
||||
```
|
||||
|
||||
```bibtex
|
||||
@inproceedings{ElNouby2021XCiTCI,
|
||||
title = {XCiT: Cross-Covariance Image Transformers},
|
||||
author = {Alaaeldin El-Nouby and Hugo Touvron and Mathilde Caron and Piotr Bojanowski and Matthijs Douze and Armand Joulin and Ivan Laptev and Natalia Neverova and Gabriel Synnaeve and Jakob Verbeek and Herv{\'e} J{\'e}gou},
|
||||
booktitle = {Neural Information Processing Systems},
|
||||
year = {2021},
|
||||
url = {https://api.semanticscholar.org/CorpusID:235458262}
|
||||
}
|
||||
```
|
||||
|
||||
```bibtex
|
||||
@inproceedings{Koner2024LookupViTCV,
|
||||
title = {LookupViT: Compressing visual information to a limited number of tokens},
|
||||
author = {Rajat Koner and Gagan Jain and Prateek Jain and Volker Tresp and Sujoy Paul},
|
||||
year = {2024},
|
||||
url = {https://api.semanticscholar.org/CorpusID:271244592}
|
||||
}
|
||||
```
|
||||
|
||||
```bibtex
|
||||
@article{Bao2022AllAW,
|
||||
title = {All are Worth Words: A ViT Backbone for Diffusion Models},
|
||||
author = {Fan Bao and Shen Nie and Kaiwen Xue and Yue Cao and Chongxuan Li and Hang Su and Jun Zhu},
|
||||
journal = {2023 IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
|
||||
year = {2022},
|
||||
pages = {22669-22679},
|
||||
url = {https://api.semanticscholar.org/CorpusID:253581703}
|
||||
}
|
||||
```
|
||||
|
||||
```bibtex
|
||||
@misc{Rubin2024,
|
||||
author = {Ohad Rubin},
|
||||
url = {https://medium.com/@ohadrubin/exploring-weight-decay-in-layer-normalization-challenges-and-a-reparameterization-solution-ad4d12c24950}
|
||||
}
|
||||
```
|
||||
|
||||
*I visualise a time when we will be to robots what dogs are to humans, and I’m rooting for the machines.* — Claude Shannon
|
||||
|
||||
BIN
images/xcit.png
Normal file
BIN
images/xcit.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 814 KiB |
8
setup.py
8
setup.py
@@ -1,11 +1,15 @@
|
||||
from setuptools import setup, find_packages
|
||||
|
||||
with open('README.md') as f:
|
||||
long_description = f.read()
|
||||
|
||||
setup(
|
||||
name = 'vit-pytorch',
|
||||
packages = find_packages(exclude=['examples']),
|
||||
version = '1.5.0 ',
|
||||
version = '1.7.6',
|
||||
license='MIT',
|
||||
description = 'Vision Transformer (ViT) - Pytorch',
|
||||
long_description=long_description,
|
||||
long_description_content_type = 'text/markdown',
|
||||
author = 'Phil Wang',
|
||||
author_email = 'lucidrains@gmail.com',
|
||||
@@ -16,7 +20,7 @@ setup(
|
||||
'image recognition'
|
||||
],
|
||||
install_requires=[
|
||||
'einops>=0.6.1',
|
||||
'einops>=0.7.0',
|
||||
'torch>=1.10',
|
||||
'torchvision'
|
||||
],
|
||||
|
||||
@@ -1,10 +1,3 @@
|
||||
import torch
|
||||
from packaging import version
|
||||
|
||||
if version.parse(torch.__version__) >= version.parse('2.0.0'):
|
||||
from einops._torch_specific import allow_ops_in_compiled_graph
|
||||
allow_ops_in_compiled_graph()
|
||||
|
||||
from vit_pytorch.vit import ViT
|
||||
from vit_pytorch.simple_vit import SimpleViT
|
||||
|
||||
|
||||
@@ -170,12 +170,13 @@ class ImageEmbedder(nn.Module):
|
||||
dim,
|
||||
image_size,
|
||||
patch_size,
|
||||
dropout = 0.
|
||||
dropout = 0.,
|
||||
channels = 3
|
||||
):
|
||||
super().__init__()
|
||||
assert image_size % patch_size == 0, 'Image dimensions must be divisible by the patch size.'
|
||||
num_patches = (image_size // patch_size) ** 2
|
||||
patch_dim = 3 * patch_size ** 2
|
||||
patch_dim = channels * patch_size ** 2
|
||||
|
||||
self.to_patch_embedding = nn.Sequential(
|
||||
Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_size, p2 = patch_size),
|
||||
@@ -223,11 +224,12 @@ class CrossViT(nn.Module):
|
||||
cross_attn_dim_head = 64,
|
||||
depth = 3,
|
||||
dropout = 0.1,
|
||||
emb_dropout = 0.1
|
||||
emb_dropout = 0.1,
|
||||
channels = 3
|
||||
):
|
||||
super().__init__()
|
||||
self.sm_image_embedder = ImageEmbedder(dim = sm_dim, image_size = image_size, patch_size = sm_patch_size, dropout = emb_dropout)
|
||||
self.lg_image_embedder = ImageEmbedder(dim = lg_dim, image_size = image_size, patch_size = lg_patch_size, dropout = emb_dropout)
|
||||
self.sm_image_embedder = ImageEmbedder(dim = sm_dim, channels= channels, image_size = image_size, patch_size = sm_patch_size, dropout = emb_dropout)
|
||||
self.lg_image_embedder = ImageEmbedder(dim = lg_dim, channels = channels, image_size = image_size, patch_size = lg_patch_size, dropout = emb_dropout)
|
||||
|
||||
self.multi_scale_encoder = MultiScaleEncoder(
|
||||
depth = depth,
|
||||
|
||||
@@ -140,12 +140,13 @@ class CvT(nn.Module):
|
||||
s3_heads = 6,
|
||||
s3_depth = 10,
|
||||
s3_mlp_mult = 4,
|
||||
dropout = 0.
|
||||
dropout = 0.,
|
||||
channels = 3
|
||||
):
|
||||
super().__init__()
|
||||
kwargs = dict(locals())
|
||||
|
||||
dim = 3
|
||||
dim = channels
|
||||
layers = []
|
||||
|
||||
for prefix in ('s1', 's2', 's3'):
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import nn
|
||||
from torch.nn import Module
|
||||
import torch.nn.functional as F
|
||||
|
||||
from vit_pytorch.vit import ViT
|
||||
from vit_pytorch.t2t import T2TViT
|
||||
from vit_pytorch.efficient import ViT as EfficientViT
|
||||
@@ -12,6 +14,9 @@ from einops import rearrange, repeat
|
||||
def exists(val):
|
||||
return val is not None
|
||||
|
||||
def default(val, d):
|
||||
return val if exists(val) else d
|
||||
|
||||
# classes
|
||||
|
||||
class DistillMixin:
|
||||
@@ -20,12 +25,12 @@ class DistillMixin:
|
||||
x = self.to_patch_embedding(img)
|
||||
b, n, _ = x.shape
|
||||
|
||||
cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b)
|
||||
cls_tokens = repeat(self.cls_token, '1 n d -> b n d', b = b)
|
||||
x = torch.cat((cls_tokens, x), dim = 1)
|
||||
x += self.pos_embedding[:, :(n + 1)]
|
||||
|
||||
if distilling:
|
||||
distill_tokens = repeat(distill_token, '() n d -> b n d', b = b)
|
||||
distill_tokens = repeat(distill_token, '1 n d -> b n d', b = b)
|
||||
x = torch.cat((x, distill_tokens), dim = 1)
|
||||
|
||||
x = self._attend(x)
|
||||
@@ -97,7 +102,7 @@ class DistillableEfficientViT(DistillMixin, EfficientViT):
|
||||
|
||||
# knowledge distillation wrapper
|
||||
|
||||
class DistillWrapper(nn.Module):
|
||||
class DistillWrapper(Module):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
@@ -105,7 +110,8 @@ class DistillWrapper(nn.Module):
|
||||
student,
|
||||
temperature = 1.,
|
||||
alpha = 0.5,
|
||||
hard = False
|
||||
hard = False,
|
||||
mlp_layernorm = False
|
||||
):
|
||||
super().__init__()
|
||||
assert (isinstance(student, (DistillableViT, DistillableT2TViT, DistillableEfficientViT))) , 'student must be a vision transformer'
|
||||
@@ -122,14 +128,14 @@ class DistillWrapper(nn.Module):
|
||||
self.distillation_token = nn.Parameter(torch.randn(1, 1, dim))
|
||||
|
||||
self.distill_mlp = nn.Sequential(
|
||||
nn.LayerNorm(dim),
|
||||
nn.LayerNorm(dim) if mlp_layernorm else nn.Identity(),
|
||||
nn.Linear(dim, num_classes)
|
||||
)
|
||||
|
||||
def forward(self, img, labels, temperature = None, alpha = None, **kwargs):
|
||||
b, *_ = img.shape
|
||||
alpha = alpha if exists(alpha) else self.alpha
|
||||
T = temperature if exists(temperature) else self.temperature
|
||||
|
||||
alpha = default(alpha, self.alpha)
|
||||
T = default(temperature, self.temperature)
|
||||
|
||||
with torch.no_grad():
|
||||
teacher_logits = self.teacher(img)
|
||||
|
||||
278
vit_pytorch/look_vit.py
Normal file
278
vit_pytorch/look_vit.py
Normal file
@@ -0,0 +1,278 @@
|
||||
import torch
|
||||
from torch import nn
|
||||
import torch.nn.functional as F
|
||||
from torch.nn import Module, ModuleList
|
||||
|
||||
from einops import einsum, rearrange, repeat, reduce
|
||||
from einops.layers.torch import Rearrange
|
||||
|
||||
# helpers
|
||||
|
||||
def exists(val):
    """Return True when `val` is anything other than None."""
    is_missing = val is None
    return not is_missing
|
||||
|
||||
def default(val, d):
    """Return `val` unless it is None, in which case fall back to `d`."""
    if exists(val):
        return val
    return d
|
||||
|
||||
def divisible_by(num, den):
    """Return whether `num` is an exact multiple of `den`."""
    remainder = num % den
    return remainder == 0
|
||||
|
||||
# simple vit sinusoidal pos emb
|
||||
|
||||
def posemb_sincos_2d(t, temperature = 10000):
    """Build a fixed 2D sine / cosine positional embedding for a grid of tokens.

    Args:
        t: tensor whose shape is unpacked as (batch, height, width, dim); only its
            shape and device are read, never its values.
        temperature: base of the geometric frequency progression.

    Returns:
        A float32 tensor of shape (height * width, dim). Along the feature axis the
        layout is [sin(x), cos(x), sin(y), cos(y)], each chunk being dim // 4 wide.

    Raises:
        AssertionError: if dim is not a multiple of 4.
    """
    h, w, d, device = *t.shape[1:], t.device
    assert (d % 4) == 0, "feature dimension must be multiple of 4 for sincos emb"

    y, x = torch.meshgrid(torch.arange(h, device = device), torch.arange(w, device = device), indexing = 'ij')

    num_freqs = d // 4

    # with a single frequency (d == 4) the denominator `num_freqs - 1` is zero,
    # which turned the whole embedding into NaN - clamp it to 1 instead
    omega = torch.arange(num_freqs, device = device) / max(num_freqs - 1, 1)
    omega = temperature ** -omega

    # outer product of flattened grid coordinates with the frequencies
    y = y.flatten()[:, None] * omega[None, :]
    x = x.flatten()[:, None] * omega[None, :]
    pos = torch.cat((x.sin(), x.cos(), y.sin(), y.cos()), dim = 1)

    return pos.float()
|
||||
|
||||
# bias-less layernorm with unit offset trick (discovered by Ohad Rubin)
|
||||
|
||||
class LayerNorm(Module):
    """Bias-less layer norm whose learned scale is stored as an offset from one.

    `gamma` starts at zero and the effective multiplier is `gamma + 1`, so the
    module is a plain non-affine layer norm at initialisation.
    """

    def __init__(self, dim):
        super().__init__()
        # normalisation itself carries no learnable parameters
        self.ln = nn.LayerNorm(dim, elementwise_affine = False)
        self.gamma = nn.Parameter(torch.zeros(dim))

    def forward(self, x):
        scale = self.gamma + 1
        return self.ln(x) * scale
|
||||
|
||||
# mlp
|
||||
|
||||
def MLP(dim, factor = 4, dropout = 0.):
    """Build a pre-normalised two-layer feedforward block.

    The hidden width is `int(dim * factor)`; dropout follows the activation and
    the output projection.
    """
    hidden_dim = int(dim * factor)

    layers = [
        LayerNorm(dim),
        nn.Linear(dim, hidden_dim),
        nn.GELU(),
        nn.Dropout(dropout),
        nn.Linear(hidden_dim, dim),
        nn.Dropout(dropout),
    ]

    return nn.Sequential(*layers)
|
||||
|
||||
# attention
|
||||
|
||||
class Attention(Module):
    """Multi-head attention that can self-attend, cross-attend, or reuse a similarity matrix.

    With `reuse_attention = True` no query / key projections are created; the
    caller must pass a precomputed `qk_sim` to `forward`, and only the value
    projection of this module is used.
    """

    def __init__(
        self,
        dim,
        heads = 8,
        dim_head = 64,
        dropout = 0.,
        cross_attend = False,
        reuse_attention = False
    ):
        super().__init__()
        inner_dim = heads * dim_head
        computes_own_sim = not reuse_attention

        def projection():
            # bias-free linear map from model width to all heads
            return nn.Linear(dim, inner_dim, bias = False)

        self.scale = dim_head ** -0.5
        self.heads = heads
        self.reuse_attention = reuse_attention
        self.cross_attend = cross_attend

        self.split_heads = Rearrange('b n (h d) -> b h n d', h = heads)

        # queries are only normalised when this module derives its own similarity
        self.norm = LayerNorm(dim) if computes_own_sim else nn.Identity()
        self.norm_context = LayerNorm(dim) if cross_attend else nn.Identity()

        self.attend = nn.Softmax(dim = -1)
        self.dropout = nn.Dropout(dropout)

        self.to_q = projection() if computes_own_sim else None
        self.to_k = projection() if computes_own_sim else None
        self.to_v = projection()

        self.to_out = nn.Sequential(
            Rearrange('b h n d -> b n (h d)'),
            nn.Linear(inner_dim, dim, bias = False),
            nn.Dropout(dropout)
        )

    def forward(
        self,
        x,
        context = None,
        return_qk_sim = False,
        qk_sim = None
    ):
        """Attend from `x` onto `context` (or onto `x` itself when not cross attending).

        Returns the attended output, or `(output, qk_sim)` when `return_qk_sim`
        is set. `qk_sim` is the pre-softmax similarity that was used.
        """
        x = self.norm(x)

        # a context must be given exactly when the module was built for cross attention
        assert not (exists(context) ^ self.cross_attend)

        context = self.norm_context(context) if self.cross_attend else x

        values = self.split_heads(self.to_v(context))

        if self.reuse_attention:
            assert exists(qk_sim), 'qk sim matrix must be passed in for reusing previous attention'
        else:
            queries = self.split_heads(self.to_q(x)) * self.scale
            keys = self.split_heads(self.to_k(context))
            qk_sim = einsum(queries, keys, 'b h i d, b h j d -> b h i j')

        attn = self.dropout(self.attend(qk_sim))

        out = self.to_out(einsum(attn, values, 'b h i j, b h j d -> b h i d'))

        return (out, qk_sim) if return_qk_sim else out
|
||||
|
||||
# LookViT
|
||||
|
||||
class LookViT(Module):
    """Vision transformer with a coarse token stream that looks up a high resolution one.

    The image is patched at `highres_patch_size` into high resolution tokens. The
    main (coarse) tokens are obtained by bilinearly resizing that token grid to
    `image_size // patch_size` per side. In every layer the coarse tokens cross
    attend to the high resolution tokens, run self attention and an MLP, and the
    high resolution tokens are then updated by reusing the transposed
    similarity matrix of that cross attention.

    Args:
        dim: model width shared by both token streams.
        image_size: side length of the (square) input image.
        num_classes: number of output logits.
        depth: number of LookViT layers.
        patch_size: patch size that sets the coarse token grid.
        heads, dim_head: self attention configuration of the coarse stream.
        mlp_factor: hidden width multiplier of the coarse stream MLP.
        highres_patch_size: patch size of the high resolution stream.
        highres_mlp_factor: hidden width multiplier of the high resolution MLP.
        cross_attn_heads, cross_attn_dim_head: configuration shared by the lookup
            cross attention and the attention-reuse module.
        patch_conv_kernel_size: odd kernel size of the patch embedding convolution.
        dropout: dropout used in the attention and MLP modules.
        channels: number of input image channels.

    Raises:
        AssertionError: if `image_size` is not divisible by both patch sizes, if
            `patch_size` is not larger than `highres_patch_size`, or if the
            convolution kernel size is even.
    """

    def __init__(
        self,
        *,
        dim,
        image_size,
        num_classes,
        depth = 3,
        patch_size = 16,
        heads = 8,
        mlp_factor = 4,
        dim_head = 64,
        highres_patch_size = 12,
        highres_mlp_factor = 4,
        cross_attn_heads = 8,
        cross_attn_dim_head = 64,
        patch_conv_kernel_size = 7,
        dropout = 0.1,
        channels = 3
    ):
        super().__init__()
        assert divisible_by(image_size, highres_patch_size)
        assert divisible_by(image_size, patch_size)
        # message previously said "smaller", contradicting the condition being checked
        assert patch_size > highres_patch_size, 'patch size of the main vision transformer should be larger than the highres patch size (that does the `lookup`)'
        assert not divisible_by(patch_conv_kernel_size, 2)

        self.dim = dim
        self.image_size = image_size
        self.patch_size = patch_size

        kernel_size = patch_conv_kernel_size
        patch_dim = (highres_patch_size * highres_patch_size) * channels

        # fold each patch into channels, mix neighbouring patches with a "same" padded conv,
        # then move channels last so tokens come out as (batch, height, width, dim)
        self.to_patches = nn.Sequential(
            Rearrange('b c (h p1) (w p2) -> b (p1 p2 c) h w', p1 = highres_patch_size, p2 = highres_patch_size),
            nn.Conv2d(patch_dim, dim, kernel_size, padding = kernel_size // 2),
            Rearrange('b c h w -> b h w c'),
            LayerNorm(dim),
        )

        # absolute positions
        # NOTE(review): `forward` adds the fixed sincos embedding and never reads this
        # parameter; it is kept so existing checkpoints keep loading - confirm before removing

        num_patches = (image_size // highres_patch_size) ** 2
        self.pos_embedding = nn.Parameter(torch.randn(num_patches, dim))

        # lookvit blocks

        layers = ModuleList([])

        for _ in range(depth):
            layers.append(ModuleList([
                Attention(dim = dim, dim_head = dim_head, heads = heads, dropout = dropout),
                MLP(dim = dim, factor = mlp_factor, dropout = dropout),
                Attention(dim = dim, dim_head = cross_attn_dim_head, heads = cross_attn_heads, dropout = dropout, cross_attend = True),
                Attention(dim = dim, dim_head = cross_attn_dim_head, heads = cross_attn_heads, dropout = dropout, cross_attend = True, reuse_attention = True),
                LayerNorm(dim),
                MLP(dim = dim, factor = highres_mlp_factor, dropout = dropout)
            ]))

        self.layers = layers

        self.norm = LayerNorm(dim)
        self.highres_norm = LayerNorm(dim)

        self.to_logits = nn.Linear(dim, num_classes, bias = False)

    def forward(self, img):
        """Classify a batch of images; returns logits of shape (batch, num_classes)."""
        assert img.shape[-2:] == (self.image_size, self.image_size)

        # to patch tokens and positions

        highres_tokens = self.to_patches(img)
        size = highres_tokens.shape[-2]

        pos_emb = posemb_sincos_2d(highres_tokens)
        highres_tokens = highres_tokens + rearrange(pos_emb, '(h w) d -> h w d', h = size)

        # coarse tokens are the high resolution token grid resized to the main patch grid
        tokens = F.interpolate(
            rearrange(highres_tokens, 'b h w d -> b d h w'),
            img.shape[-1] // self.patch_size,
            mode = 'bilinear'
        )

        tokens = rearrange(tokens, 'b c h w -> b (h w) c')
        highres_tokens = rearrange(highres_tokens, 'b h w c -> b (h w) c')

        # attention and feedforwards

        for attn, mlp, lookup_cross_attn, highres_attn, highres_norm, highres_mlp in self.layers:

            # main tokens cross attends (lookup) on the high res tokens

            lookup_out, qk_sim = lookup_cross_attn(tokens, highres_tokens, return_qk_sim = True) # return attention as they reuse the attention matrix
            tokens = lookup_out + tokens

            tokens = attn(tokens) + tokens
            tokens = mlp(tokens) + tokens

            # attention-reuse

            qk_sim = rearrange(qk_sim, 'b h i j -> b h j i') # transpose for reverse cross attention

            highres_tokens = highres_attn(highres_tokens, tokens, qk_sim = qk_sim) + highres_tokens
            highres_tokens = highres_norm(highres_tokens)

            highres_tokens = highres_mlp(highres_tokens) + highres_tokens

        # to logits

        tokens = self.norm(tokens)
        highres_tokens = self.highres_norm(highres_tokens)

        # mean pool each stream and classify their sum
        tokens = reduce(tokens, 'b n d -> b d', 'mean')
        highres_tokens = reduce(highres_tokens, 'b n d -> b d', 'mean')

        return self.to_logits(tokens + highres_tokens)
|
||||
|
||||
# main
|
||||
|
||||
if __name__ == '__main__':
    # smoke test: build a small LookViT on the GPU and check the logits shape
    config = dict(
        image_size = 256,
        num_classes = 1000,
        dim = 512,
        depth = 2,
        heads = 8,
        dim_head = 64,
        patch_size = 32,
        highres_patch_size = 8,
        highres_mlp_factor = 2,
        cross_attn_heads = 8,
        cross_attn_dim_head = 64,
        dropout = 0.1
    )

    v = LookViT(**config).cuda()

    img = torch.randn(2, 3, 256, 256).cuda()
    pred = v(img)

    assert pred.shape == (2, 1000)
|
||||
@@ -173,7 +173,7 @@ class Attention(nn.Module):
|
||||
|
||||
# split heads
|
||||
|
||||
q, k, v = map(lambda t: rearrange(t, 'b n (h d ) -> b h n d', h = h), (q, k, v))
|
||||
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v))
|
||||
|
||||
# scale
|
||||
|
||||
|
||||
340
vit_pytorch/max_vit_with_registers.py
Normal file
340
vit_pytorch/max_vit_with_registers.py
Normal file
@@ -0,0 +1,340 @@
|
||||
from functools import partial
|
||||
|
||||
import torch
|
||||
from torch import nn, einsum
|
||||
import torch.nn.functional as F
|
||||
from torch.nn import Module, ModuleList, Sequential
|
||||
|
||||
from einops import rearrange, repeat, reduce, pack, unpack
|
||||
from einops.layers.torch import Rearrange, Reduce
|
||||
|
||||
# helpers
|
||||
|
||||
def exists(val):
    """Return True when `val` is anything other than None."""
    is_missing = val is None
    return not is_missing
|
||||
|
||||
def default(val, d):
    """Return `val` unless it is None, in which case fall back to `d`."""
    if exists(val):
        return val
    return d
|
||||
|
||||
def pack_one(x, pattern):
    """Pack a single tensor with einops `pack`; returns the `(packed, packed_shapes)` pair."""
    packed_and_shapes = pack([x], pattern)
    return packed_and_shapes
|
||||
|
||||
def unpack_one(x, ps, pattern):
    """Undo `pack_one`: unpack `x` with the recorded shapes `ps` and return the first tensor."""
    unpacked = unpack(x, ps, pattern)
    return unpacked[0]
|
||||
|
||||
def cast_tuple(val, length = 1):
    """Return `val` unchanged if it is a tuple, otherwise repeat it `length` times in a tuple."""
    if isinstance(val, tuple):
        return val
    return (val,) * length
|
||||
|
||||
# helper classes
|
||||
|
||||
def FeedForward(dim, mult = 4, dropout = 0.):
    """Build a pre-normalised two-layer feedforward block.

    The hidden width is `int(dim * mult)`; dropout follows the activation and
    the output projection.
    """
    inner_dim = int(dim * mult)

    layers = [
        nn.LayerNorm(dim),
        nn.Linear(dim, inner_dim),
        nn.GELU(),
        nn.Dropout(dropout),
        nn.Linear(inner_dim, dim),
        nn.Dropout(dropout),
    ]

    return Sequential(*layers)
|
||||
|
||||
# MBConv
|
||||
|
||||
class SqueezeExcitation(Module):
    """Channel gating: scale each channel by a sigmoid gate computed from its spatial mean."""

    def __init__(self, dim, shrinkage_rate = 0.25):
        super().__init__()
        hidden_dim = int(dim * shrinkage_rate)

        gate_layers = [
            Reduce('b c h w -> b c', 'mean'),
            nn.Linear(dim, hidden_dim, bias = False),
            nn.SiLU(),
            nn.Linear(hidden_dim, dim, bias = False),
            nn.Sigmoid(),
            # restore singleton spatial dims so the gate broadcasts over height and width
            Rearrange('b c -> b c 1 1'),
        ]

        self.gate = Sequential(*gate_layers)

    def forward(self, x):
        gate = self.gate(x)
        return x * gate
|
||||
|
||||
class MBConvResidual(Module):
    """Residual wrapper: adds the input to the sample-dropped output of `fn`."""

    def __init__(self, fn, dropout = 0.):
        super().__init__()
        self.fn = fn
        self.dropsample = Dropsample(dropout)

    def forward(self, x):
        branch = self.dropsample(self.fn(x))
        return branch + x
|
||||
|
||||
class Dropsample(Module):
    """Randomly zero whole samples of a batch during training.

    Each sample is dropped with probability `prob`; surviving samples are scaled
    by `1 / (1 - prob)`. The module is the identity in eval mode or when
    `prob == 0`.

    Args:
        prob: probability of dropping a sample.
    """

    def __init__(self, prob = 0):
        super().__init__()
        self.prob = prob

    def forward(self, x):
        """Apply sample dropout to `x`, expected as a 4D (batch, ...) tensor."""
        if self.prob == 0. or (not self.training):
            return x

        # one keep / drop decision per sample, broadcast over the other three dims.
        # `torch.FloatTensor(tuple)` treated the tuple as data (a 4 element vector)
        # and cannot allocate on a non-CPU device, so draw the mask with `torch.rand`
        keep_mask = torch.rand((x.shape[0], 1, 1, 1), device = x.device) > self.prob
        return x * keep_mask / (1 - self.prob)
|
||||
|
||||
def MBConv(
    dim_in,
    dim_out,
    *,
    downsample,
    expansion_rate = 4,
    shrinkage_rate = 0.25,
    dropout = 0.
):
    """Build an MBConv block: 1x1 expand, 3x3 depthwise, squeeze-excitation, 1x1 project.

    The depthwise convolution uses stride 2 when `downsample` is set. A residual
    connection (with sample dropout) is added only when the block keeps both
    the channel count and the resolution.
    """
    hidden_dim = int(expansion_rate * dim_out)
    stride = 2 if downsample else 1

    layers = [
        nn.Conv2d(dim_in, hidden_dim, 1),
        nn.BatchNorm2d(hidden_dim),
        nn.GELU(),
        # depthwise: one group per channel
        nn.Conv2d(hidden_dim, hidden_dim, 3, stride = stride, padding = 1, groups = hidden_dim),
        nn.BatchNorm2d(hidden_dim),
        nn.GELU(),
        SqueezeExcitation(hidden_dim, shrinkage_rate = shrinkage_rate),
        nn.Conv2d(hidden_dim, dim_out, 1),
        nn.BatchNorm2d(dim_out),
    ]

    net = Sequential(*layers)

    # shapes differ between input and output, so no skip connection is possible
    if dim_in != dim_out or downsample:
        return net

    return MBConvResidual(net, dropout = dropout)
|
||||
|
||||
# attention related classes
|
||||
|
||||
class Attention(Module):
    """Windowed multi-head self attention with relative position bias and register tokens.

    The sequence passed to `forward` is expected to hold `num_registers` register
    tokens followed by `window_size ** 2` window tokens; every pair involving a
    register token shares one extra, dedicated bias embedding.
    """

    def __init__(
        self,
        dim,
        dim_head = 32,
        dropout = 0.,
        window_size = 7,
        num_registers = 1
    ):
        super().__init__()
        assert num_registers > 0
        assert (dim % dim_head) == 0, 'dimension should be divisible by dimension per head'

        self.heads = dim // dim_head
        self.scale = dim_head ** -0.5

        self.norm = nn.LayerNorm(dim)
        self.to_qkv = nn.Linear(dim, dim * 3, bias = False)

        self.attend = nn.Sequential(nn.Softmax(dim = -1), nn.Dropout(dropout))
        self.to_out = nn.Sequential(nn.Linear(dim, dim, bias = False), nn.Dropout(dropout))

        # relative positional bias

        bias_span = 2 * window_size - 1
        num_rel_pos_bias = bias_span ** 2

        # the extra (last) embedding row serves every pair that involves a register token
        self.rel_pos_bias = nn.Embedding(num_rel_pos_bias + 1, self.heads)

        axis = torch.arange(window_size)
        coords = torch.stack(torch.meshgrid(axis, axis, indexing = 'ij'))
        coords = rearrange(coords, 'c i j -> (i j) c')

        # pairwise (row, col) offsets, shifted to be non-negative
        offsets = coords[:, None, :] - coords[None, :, :]
        offsets = offsets + (window_size - 1)

        # flatten the 2D offset into a single embedding index
        rel_pos_indices = (offsets * torch.tensor([bias_span, 1])).sum(dim = -1)

        # pad on the top / left for the register tokens, pointing at the extra row
        rel_pos_indices = F.pad(rel_pos_indices, (num_registers, 0, num_registers, 0), value = num_rel_pos_bias)
        self.register_buffer('rel_pos_indices', rel_pos_indices, persistent = False)

    def forward(self, x):
        num_heads = self.heads

        x = self.norm(x)

        # project for queries, keys, values and split heads

        qkv = self.to_qkv(x).chunk(3, dim = -1)
        q, k, v = (rearrange(t, 'b n (h d) -> b h n d', h = num_heads) for t in qkv)

        # scaled similarity

        sim = einsum('b h i d, b h j d -> b h i j', q * self.scale, k)

        # add positional bias

        bias = self.rel_pos_bias(self.rel_pos_indices)
        sim = sim + rearrange(bias, 'i j h -> h i j')

        # attention and aggregation

        attn = self.attend(sim)
        out = einsum('b h i j, b h j d -> b h i d', attn, v)

        # combine heads out

        return self.to_out(rearrange(out, 'b h n d -> b n (h d)'))
|
||||
|
||||
class MaxViT(Module):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
num_classes,
|
||||
dim,
|
||||
depth,
|
||||
dim_head = 32,
|
||||
dim_conv_stem = None,
|
||||
window_size = 7,
|
||||
mbconv_expansion_rate = 4,
|
||||
mbconv_shrinkage_rate = 0.25,
|
||||
dropout = 0.1,
|
||||
channels = 3,
|
||||
num_register_tokens = 4
|
||||
):
|
||||
super().__init__()
|
||||
assert isinstance(depth, tuple), 'depth needs to be tuple if integers indicating number of transformer blocks at that stage'
|
||||
assert num_register_tokens > 0
|
||||
|
||||
# convolutional stem
|
||||
|
||||
dim_conv_stem = default(dim_conv_stem, dim)
|
||||
|
||||
self.conv_stem = Sequential(
|
||||
nn.Conv2d(channels, dim_conv_stem, 3, stride = 2, padding = 1),
|
||||
nn.Conv2d(dim_conv_stem, dim_conv_stem, 3, padding = 1)
|
||||
)
|
||||
|
||||
# variables
|
||||
|
||||
num_stages = len(depth)
|
||||
|
||||
dims = tuple(map(lambda i: (2 ** i) * dim, range(num_stages)))
|
||||
dims = (dim_conv_stem, *dims)
|
||||
dim_pairs = tuple(zip(dims[:-1], dims[1:]))
|
||||
|
||||
self.layers = nn.ModuleList([])
|
||||
|
||||
# window size
|
||||
|
||||
self.window_size = window_size
|
||||
|
||||
self.register_tokens = nn.ParameterList([])
|
||||
|
||||
# iterate through stages
|
||||
|
||||
for ind, ((layer_dim_in, layer_dim), layer_depth) in enumerate(zip(dim_pairs, depth)):
|
||||
for stage_ind in range(layer_depth):
|
||||
is_first = stage_ind == 0
|
||||
stage_dim_in = layer_dim_in if is_first else layer_dim
|
||||
|
||||
conv = MBConv(
|
||||
stage_dim_in,
|
||||
layer_dim,
|
||||
downsample = is_first,
|
||||
expansion_rate = mbconv_expansion_rate,
|
||||
shrinkage_rate = mbconv_shrinkage_rate
|
||||
)
|
||||
|
||||
block_attn = Attention(dim = layer_dim, dim_head = dim_head, dropout = dropout, window_size = window_size, num_registers = num_register_tokens)
|
||||
block_ff = FeedForward(dim = layer_dim, dropout = dropout)
|
||||
|
||||
grid_attn = Attention(dim = layer_dim, dim_head = dim_head, dropout = dropout, window_size = window_size, num_registers = num_register_tokens)
|
||||
grid_ff = FeedForward(dim = layer_dim, dropout = dropout)
|
||||
|
||||
register_tokens = nn.Parameter(torch.randn(num_register_tokens, layer_dim))
|
||||
|
||||
self.layers.append(ModuleList([
|
||||
conv,
|
||||
ModuleList([block_attn, block_ff]),
|
||||
ModuleList([grid_attn, grid_ff])
|
||||
]))
|
||||
|
||||
self.register_tokens.append(register_tokens)
|
||||
|
||||
# mlp head out
|
||||
|
||||
self.mlp_head = nn.Sequential(
|
||||
Reduce('b d h w -> b d', 'mean'),
|
||||
nn.LayerNorm(dims[-1]),
|
||||
nn.Linear(dims[-1], num_classes)
|
||||
)
|
||||
|
||||
def forward(self, x):
    # Runs the conv stem, then every stage (MBConv -> block attention -> grid attention),
    # each with its own learned register tokens, and finally the classification head.
    # NOTE(review): `pack_one` / `unpack_one` are defined outside this view; presumably thin
    # single-tensor wrappers around einops `pack` / `unpack` - confirm against their definition.
    b, w = x.shape[0], self.window_size

    x = self.conv_stem(x)

    # self.layers and self.register_tokens are appended to in lockstep, one entry per stage
    for (conv, (block_attn, block_ff), (grid_attn, grid_ff)), register_tokens in zip(self.layers, self.register_tokens):
        x = conv(x)

        # block-like attention

        # split the feature map into (x, y) windows of w1 x w2 neighbouring positions
        x = rearrange(x, 'b d (x w1) (y w2) -> b x y w1 w2 d', w1 = w, w2 = w)

        # prepare register tokens

        # one copy of the registers per batch element and per window
        r = repeat(register_tokens, 'n d -> b x y n d', b = b, x = x.shape[1],y = x.shape[2])
        r, register_batch_ps = pack_one(r, '* n d')

        x, window_ps = pack_one(x, 'b x y * d')
        x, batch_ps = pack_one(x, '* n d')
        # registers are placed in front of the window tokens along the sequence axis
        x, register_ps = pack([r, x], 'b * d')

        # in the block branch both attention and feedforward see the registers
        x = block_attn(x) + x
        x = block_ff(x) + x

        r, x = unpack(x, register_ps, 'b * d')

        x = unpack_one(x, batch_ps, '* n d')
        x = unpack_one(x, window_ps, 'b x y * d')
        x = rearrange(x, 'b x y w1 w2 d -> b d (x w1) (y w2)')

        r = unpack_one(r, register_batch_ps, '* n d')

        # grid-like attention

        # here w1 / w2 are the outer factors, so each group gathers strided (dilated) positions
        x = rearrange(x, 'b d (w1 x) (w2 y) -> b x y w1 w2 d', w1 = w, w2 = w)

        # prepare register tokens

        # average the registers over all windows, then hand the same average to every grid group
        r = reduce(r, 'b x y n d -> b n d', 'mean')
        r = repeat(r, 'b n d -> b x y n d', x = x.shape[1], y = x.shape[2])
        r, register_batch_ps = pack_one(r, '* n d')

        x, window_ps = pack_one(x, 'b x y * d')
        x, batch_ps = pack_one(x, '* n d')
        x, register_ps = pack([r, x], 'b * d')

        x = grid_attn(x) + x

        # unlike the block branch, the registers are split off before the feedforward,
        # and the registers produced here are not used again within this stage
        r, x = unpack(x, register_ps, 'b * d')

        x = grid_ff(x) + x

        x = unpack_one(x, batch_ps, '* n d')
        x = unpack_one(x, window_ps, 'b x y * d')
        x = rearrange(x, 'b x y w1 w2 d -> b d (w1 x) (w2 y)')

    return self.mlp_head(x)
|
||||
@@ -1,5 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from functools import partial
|
||||
from typing import List, Union
|
||||
from typing import List
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
@@ -198,7 +200,7 @@ class NaViT(nn.Module):
|
||||
self.calc_token_dropout = token_dropout_prob
|
||||
|
||||
elif isinstance(token_dropout_prob, (float, int)):
|
||||
assert 0. < token_dropout_prob < 1.
|
||||
assert 0. <= token_dropout_prob < 1.
|
||||
token_dropout_prob = float(token_dropout_prob)
|
||||
self.calc_token_dropout = lambda height, width: token_dropout_prob
|
||||
|
||||
@@ -245,11 +247,11 @@ class NaViT(nn.Module):
|
||||
|
||||
def forward(
|
||||
self,
|
||||
batched_images: Union[List[Tensor], List[List[Tensor]]], # assume different resolution images already grouped correctly
|
||||
batched_images: List[Tensor] | List[List[Tensor]], # assume different resolution images already grouped correctly
|
||||
group_images = False,
|
||||
group_max_seq_len = 2048
|
||||
):
|
||||
p, c, device, has_token_dropout = self.patch_size, self.channels, self.device, exists(self.calc_token_dropout)
|
||||
p, c, device, has_token_dropout = self.patch_size, self.channels, self.device, exists(self.calc_token_dropout) and self.training
|
||||
|
||||
arange = partial(torch.arange, device = device)
|
||||
pad_sequence = partial(orig_pad_sequence, batch_first = True)
|
||||
@@ -260,10 +262,15 @@ class NaViT(nn.Module):
|
||||
batched_images = group_images_by_max_seq_len(
|
||||
batched_images,
|
||||
patch_size = self.patch_size,
|
||||
calc_token_dropout = self.calc_token_dropout,
|
||||
calc_token_dropout = self.calc_token_dropout if self.training else None,
|
||||
max_seq_len = group_max_seq_len
|
||||
)
|
||||
|
||||
# if List[Tensor] is not grouped -> List[List[Tensor]]
|
||||
|
||||
if torch.is_tensor(batched_images[0]):
|
||||
batched_images = [batched_images]
|
||||
|
||||
# process images into variable lengthed sequences with attention mask
|
||||
|
||||
num_images = []
|
||||
@@ -314,8 +321,8 @@ class NaViT(nn.Module):
|
||||
# derive key padding mask
|
||||
|
||||
lengths = torch.tensor([seq.shape[-2] for seq in batched_sequences], device = device, dtype = torch.long)
|
||||
max_length = arange(lengths.amax().item())
|
||||
key_pad_mask = rearrange(lengths, 'b -> b 1') <= rearrange(max_length, 'n -> 1 n')
|
||||
seq_arange = arange(lengths.amax().item())
|
||||
key_pad_mask = rearrange(seq_arange, 'n -> 1 n') < rearrange(lengths, 'b -> b 1')
|
||||
|
||||
# derive attention mask, and combine with key padding mask from above
|
||||
|
||||
|
||||
323
vit_pytorch/na_vit_nested_tensor.py
Normal file
323
vit_pytorch/na_vit_nested_tensor.py
Normal file
@@ -0,0 +1,323 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import List
|
||||
from functools import partial
|
||||
|
||||
import torch
|
||||
import packaging.version as pkg_version
|
||||
assert pkg_version.parse(torch.__version__) >= pkg_version.parse('2.4'), 'install pytorch 2.4 or greater to use this flavor of NaViT'
|
||||
|
||||
from torch import nn, Tensor
|
||||
import torch.nn.functional as F
|
||||
from torch.nn import Module, ModuleList
|
||||
from torch.nested import nested_tensor
|
||||
|
||||
from einops import rearrange
|
||||
from einops.layers.torch import Rearrange
|
||||
|
||||
# helpers
|
||||
|
||||
def exists(val):
    """Return True for any value other than None (falsy values such as 0 or '' still exist)."""
    is_missing = val is None
    return not is_missing
|
||||
|
||||
def default(val, d):
    """Return `val` unless it is None, in which case return the fallback `d`."""
    if exists(val):
        return val

    return d
|
||||
|
||||
def pair(t):
    """Return `t` untouched when it is already a tuple, otherwise duplicate it into `(t, t)`."""
    if isinstance(t, tuple):
        return t

    return (t, t)
|
||||
|
||||
def divisible_by(numer, denom):
    """Return True when `numer` leaves no remainder after division by `denom`."""
    remainder = numer % denom
    return remainder == 0
|
||||
|
||||
# feedforward
|
||||
|
||||
def FeedForward(dim, hidden_dim, dropout = 0.):
|
||||
return nn.Sequential(
|
||||
nn.LayerNorm(dim, bias = False),
|
||||
nn.Linear(dim, hidden_dim),
|
||||
nn.GELU(),
|
||||
nn.Dropout(dropout),
|
||||
nn.Linear(hidden_dim, dim),
|
||||
nn.Dropout(dropout)
|
||||
)
|
||||
|
||||
class Attention(Module):
    """Multi-head attention with a pre-LayerNorm on the input and LayerNorm on per-head queries and keys.

    `forward(x, context = None)` attends from `x` to `context`. When `context` is omitted the
    normalized `x` attends to itself; a supplied `context` is used as given (it is not normalized here).
    """

    def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
        super().__init__()
        self.norm = nn.LayerNorm(dim, bias = False)

        self.heads = heads
        self.dim_head = dim_head
        dim_inner = heads * dim_head

        self.to_queries = nn.Linear(dim, dim_inner, bias = False)
        self.to_keys = nn.Linear(dim, dim_inner, bias = False)
        self.to_values = nn.Linear(dim, dim_inner, bias = False)

        # in the paper, they employ qk rmsnorm, a way to stabilize attention
        # layernorm is used in place of rmsnorm here; rmsnorm would require l2norm on the
        # non-ragged dimension to be supported in nested tensors

        self.query_norm = nn.LayerNorm(dim_head, bias = False)
        self.key_norm = nn.LayerNorm(dim_head, bias = False)

        # plain float, handed to scaled_dot_product_attention while training
        self.dropout = dropout

        self.to_out = nn.Linear(dim_inner, dim, bias = False)

    def forward(
        self,
        x,
        context: Tensor | None = None
    ):
        x = self.norm(x)

        # for attention pooling, one query pooling to entire sequence

        context = default(context, x)

        # project, then split the last dimension into (heads, dim_head)

        head_shape = (self.heads, self.dim_head)

        query = self.to_queries(x).unflatten(-1, head_shape)
        key = self.to_keys(context).unflatten(-1, head_shape)
        value = self.to_values(context).unflatten(-1, head_shape)

        # qk norm for attention stability

        query = self.query_norm(query)
        key = self.key_norm(key)

        # swap the sequence and head axes so heads come first

        query, key, value = (t.transpose(1, 2) for t in (query, key, value))

        # attention, dropout is only active in training mode

        dropout_p = self.dropout if self.training else 0.

        out = F.scaled_dot_product_attention(query, key, value, dropout_p = dropout_p)

        # merge heads

        merged = out.transpose(1, 2).flatten(-2)

        return self.to_out(merged)
|
||||
|
||||
class Transformer(Module):
    """Stack of `depth` residual (Attention, FeedForward) pairs followed by a bias-free LayerNorm."""

    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
        super().__init__()

        blocks = []

        for _ in range(depth):
            attention = Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)
            feedforward = FeedForward(dim, mlp_dim, dropout = dropout)
            blocks.append(ModuleList([attention, feedforward]))

        self.layers = ModuleList(blocks)

        self.norm = nn.LayerNorm(dim, bias = False)

    def forward(self, x):
        # each sublayer is wrapped in a residual connection
        for attention, feedforward in self.layers:
            x = attention(x) + x
            x = feedforward(x) + x

        return self.norm(x)
|
||||
|
||||
class NaViT(Module):
    """NaViT classifier that batches images of different resolutions with jagged nested tensors.

    Every image is patchified on its own, optionally has a fraction of its tokens dropped while
    training, receives factorized (height + width) learned positional embeddings and is packed
    with the others into one jagged nested tensor, so no padding is computed. A single learned
    query attention-pools every image's sequence before the classification head.

    Args:
        image_size: int or (height, width); sets the size of the positional embedding tables.
        patch_size: side length of the square patches.
        num_classes: number of output logits.
        dim, depth, heads, mlp_dim, dim_head, dropout: transformer hyperparameters.
        channels: number of channels every input image must have.
        emb_dropout: dropout applied to the embedded tokens.
        token_dropout_prob: fraction of tokens dropped per image in training mode.
            `None` or `0` disables token dropout.
    """

    def __init__(
        self,
        *,
        image_size,
        patch_size,
        num_classes,
        dim,
        depth,
        heads,
        mlp_dim,
        channels = 3,
        dim_head = 64,
        dropout = 0.,
        emb_dropout = 0.,
        token_dropout_prob: float | None = None
    ):
        super().__init__()
        image_height, image_width = pair(image_size)

        # what fraction of tokens to drop out per image while training
        # this variant only takes a constant probability; None (the default) means no token dropout

        self.token_dropout_prob = token_dropout_prob

        # calculate patching related stuff

        assert divisible_by(image_height, patch_size) and divisible_by(image_width, patch_size), 'Image dimensions must be divisible by the patch size.'

        patch_height_dim, patch_width_dim = (image_height // patch_size), (image_width // patch_size)
        patch_dim = channels * (patch_size ** 2)

        self.channels = channels
        self.patch_size = patch_size
        self.to_patches = Rearrange('c (h p1) (w p2) -> h w (c p1 p2)', p1 = patch_size, p2 = patch_size)

        self.to_patch_embedding = nn.Sequential(
            nn.LayerNorm(patch_dim),
            nn.Linear(patch_dim, dim),
            nn.LayerNorm(dim),
        )

        # factorized positional embeddings, one table per axis, indexed by patch row / column

        self.pos_embed_height = nn.Parameter(torch.randn(patch_height_dim, dim))
        self.pos_embed_width = nn.Parameter(torch.randn(patch_width_dim, dim))

        self.dropout = nn.Dropout(emb_dropout)

        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)

        # final attention pooling queries

        self.attn_pool_queries = nn.Parameter(torch.randn(dim))
        self.attn_pool = Attention(dim = dim, dim_head = dim_head, heads = heads)

        # output to logits

        self.to_latent = nn.Identity()

        self.mlp_head = nn.Sequential(
            nn.LayerNorm(dim, bias = False),
            nn.Linear(dim, num_classes, bias = False)
        )

    @property
    def device(self):
        """Device of the first parameter, used to create index tensors in `forward`."""
        return next(self.parameters()).device

    def forward(
        self,
        images: List[Tensor], # different resolution images
    ):
        """Classify a list of `(channels, height, width)` images and return `(len(images), num_classes)` logits."""
        batch, device = len(images), self.device
        arange = partial(torch.arange, device = device)

        assert all([image.ndim == 3 and image.shape[0] == self.channels for image in images]), f'all images must have {self.channels} channels and number of dimensions of 3 (channels, height, width)'

        all_patches = [self.to_patches(image) for image in images]

        # prepare factorized positional embedding height width indices

        positions = []

        for patches in all_patches:
            patch_height, patch_width = patches.shape[:2]
            hw_indices = torch.stack(torch.meshgrid((arange(patch_height), arange(patch_width)), indexing = 'ij'), dim = -1)
            hw_indices = rearrange(hw_indices, 'h w c -> (h w) c')
            positions.append(hw_indices)

        # need the sizes to compute token dropout + positional embedding

        tokens = [rearrange(patches, 'h w d -> (h w) d') for patches in all_patches]

        # handle token dropout

        seq_lens = torch.tensor([i.shape[0] for i in tokens], device = device)

        # the probability defaults to None, which must not be compared against a number

        has_token_dropout = exists(self.token_dropout_prob) and self.token_dropout_prob > 0

        if self.training and has_token_dropout:

            # always keep at least one token per image

            keep_seq_lens = ((1. - self.token_dropout_prob) * seq_lens).int().clamp(min = 1)

            kept_tokens = []
            kept_positions = []

            for one_image_tokens, one_image_positions, seq_len, num_keep in zip(tokens, positions, seq_lens.tolist(), keep_seq_lens.tolist()):

                # top-k over random noise selects a uniformly random subset of token indices

                keep_indices = torch.randn((seq_len,), device = device).topk(num_keep, dim = -1).indices

                kept_tokens.append(one_image_tokens[keep_indices])
                kept_positions.append(one_image_positions[keep_indices])

            tokens, positions, seq_lens = kept_tokens, kept_positions, keep_seq_lens

        # add all height and width factorized positions

        height_indices, width_indices = torch.cat(positions).unbind(dim = -1)
        height_embed, width_embed = self.pos_embed_height[height_indices], self.pos_embed_width[width_indices]

        pos_embed = height_embed + width_embed

        # use nested tensor for transformers and save on padding computation

        tokens = torch.cat(tokens)

        # linear projection to patch embeddings

        tokens = self.to_patch_embedding(tokens)

        # absolute positions

        tokens = tokens + pos_embed

        # `as_nested_tensor` preserves autograd history, whereas `nested_tensor` would build a
        # detached leaf and stop gradients reaching the patch and positional embeddings

        tokens = torch.nested.as_nested_tensor(list(tokens.split(seq_lens.tolist())), layout = torch.jagged, device = device)

        # embedding dropout

        tokens = self.dropout(tokens)

        # transformer

        tokens = self.transformer(tokens)

        # attention pooling
        # will use a jagged tensor for queries, as SDPA requires all inputs to be jagged, or not

        attn_pool_queries = [rearrange(self.attn_pool_queries, '... -> 1 ...')] * batch

        # again keep the autograd link so the pooling query parameter is trained

        attn_pool_queries = torch.nested.as_nested_tensor(attn_pool_queries, layout = torch.jagged)

        pooled = self.attn_pool(attn_pool_queries, tokens)

        # back to unjagged

        logits = torch.stack(pooled.unbind())

        logits = rearrange(logits, 'b 1 d -> b d')

        logits = self.to_latent(logits)

        return self.mlp_head(logits)
|
||||
|
||||
# quick test
|
||||
|
||||
if __name__ == '__main__':

    # smoke test: a freshly built module is in training mode, so token dropout is exercised

    v = NaViT(
        image_size = 256,
        patch_size = 32,
        num_classes = 1000,
        dim = 1024,
        depth = 6,
        heads = 16,
        mlp_dim = 2048,
        dropout = 0.,
        emb_dropout = 0.,
        token_dropout_prob = 0.1
    )

    # 5 images of different resolutions - List[Tensor]

    resolutions = [(256, 256), (128, 128), (128, 256), (256, 128), (64, 256)]

    images = [torch.randn(3, height, width) for height, width in resolutions]

    logits = v(images)

    assert logits.shape == (5, 1000)
|
||||
@@ -3,12 +3,14 @@ from math import sqrt, pi, log
|
||||
import torch
|
||||
from torch import nn, einsum
|
||||
import torch.nn.functional as F
|
||||
from torch.cuda.amp import autocast
|
||||
|
||||
from einops import rearrange, repeat
|
||||
from einops.layers.torch import Rearrange
|
||||
|
||||
# rotary embeddings
|
||||
|
||||
@autocast(enabled = False)
|
||||
def rotate_every_two(x):
|
||||
x = rearrange(x, '... (d j) -> ... d j', j = 2)
|
||||
x1, x2 = x.unbind(dim = -1)
|
||||
@@ -22,6 +24,7 @@ class AxialRotaryEmbedding(nn.Module):
|
||||
scales = torch.linspace(1., max_freq / 2, self.dim // 4)
|
||||
self.register_buffer('scales', scales)
|
||||
|
||||
@autocast(enabled = False)
|
||||
def forward(self, x):
|
||||
device, dtype, n = x.device, x.dtype, int(sqrt(x.shape[-2]))
|
||||
|
||||
|
||||
171
vit_pytorch/simple_flash_attn_vit_3d.py
Normal file
171
vit_pytorch/simple_flash_attn_vit_3d.py
Normal file
@@ -0,0 +1,171 @@
|
||||
from packaging import version
|
||||
from collections import namedtuple
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
import torch.nn.functional as F
|
||||
from torch.nn import Module, ModuleList
|
||||
|
||||
from einops import rearrange
|
||||
from einops.layers.torch import Rearrange
|
||||
|
||||
# constants
|
||||
|
||||
Config = namedtuple('FlashAttentionConfig', ['enable_flash', 'enable_math', 'enable_mem_efficient'])
|
||||
|
||||
# helpers
|
||||
|
||||
def pair(t):
    """Pass tuples through unchanged; turn any other value into the 2-tuple `(t, t)`."""
    already_tuple = isinstance(t, tuple)
    return t if already_tuple else (t, t)
|
||||
|
||||
def posemb_sincos_3d(patches, temperature = 10000, dtype = torch.float32):
    """Return fixed 3d sin-cos positional embeddings of shape `(f * h * w, dim)`.

    `patches` must be 5-dimensional `(batch, f, h, w, dim)`. The feature axis is laid out as
    sin / cos blocks for the x, y and z coordinates, each `dim // 6` wide, and zero padded up
    to `dim` when `dim` is not a multiple of 6.

    NOTE: the `dtype` argument is overridden by `patches.dtype`, so the result always has
    the dtype of `patches`.
    """
    _, frames, height, width, dim = patches.shape
    device, dtype = patches.device, patches.dtype

    grids = torch.meshgrid(
        torch.arange(frames, device = device),
        torch.arange(height, device = device),
        torch.arange(width, device = device),
        indexing = 'ij')

    fourier_dim = dim // 6

    # frequencies fall geometrically from 1 to 1 / temperature
    exponents = torch.arange(fourier_dim, device = device) / (fourier_dim - 1)
    omega = 1. / (temperature ** exponents)

    def to_angles(coords):
        # (positions, 1) * (1, fourier_dim) -> (positions, fourier_dim)
        return coords.flatten()[:, None] * omega[None, :]

    z, y, x = (to_angles(grid) for grid in grids)

    features = []

    for angles in (x, y, z):
        features.append(angles.sin())
        features.append(angles.cos())

    pe = torch.cat(features, dim = 1)

    pe = F.pad(pe, (0, dim - (fourier_dim * 6))) # pad if feature dimension not cleanly divisible by 6
    return pe.type(dtype)
|
||||
|
||||
# main class
|
||||
|
||||
class Attend(Module):
    """Computes attention output from per-head queries, keys and values.

    With `use_flash = True`, `F.scaled_dot_product_attention` is called with the kernel
    selection given by `config`; otherwise plain softmax attention is computed explicitly.
    `q`, `k`, `v` are expected as `(batch, heads, seq, dim_head)`.
    """

    def __init__(self, use_flash = False, config: Config = Config(True, True, True)):
        super().__init__()
        self.config = config
        self.use_flash = use_flash
        assert not (use_flash and version.parse(torch.__version__) < version.parse('2.0.0')), 'in order to use flash attention, you must be using pytorch 2.0 or above'

    def flash_attn(self, q, k, v):
        # flash attention - https://arxiv.org/abs/2205.14135

        with torch.backends.cuda.sdp_kernel(**self.config._asdict()):
            out = F.scaled_dot_product_attention(q, k, v)

        return out

    def forward(self, q, k, v):
        if self.use_flash:
            return self.flash_attn(q, k, v)

        scale = q.shape[-1] ** -0.5

        # similarity
        # keys carry a head axis too, so contract over the feature axis per head
        # (the previous einsum dropped the head axis and relied on an unimported `einsum`)

        sim = torch.matmul(q, k.transpose(-1, -2)) * scale

        # attention

        attn = sim.softmax(dim = -1)

        # aggregate values

        return torch.matmul(attn, v)
|
||||
|
||||
# classes
|
||||
|
||||
class FeedForward(Module):
    """Pre-norm MLP: LayerNorm -> Linear(dim, hidden_dim) -> GELU -> Linear(hidden_dim, dim)."""

    def __init__(self, dim, hidden_dim):
        super().__init__()

        stages = [
            nn.LayerNorm(dim),
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, dim),
        ]

        self.net = nn.Sequential(*stages)

    def forward(self, x):
        out = self.net(x)
        return out
|
||||
|
||||
class Attention(Module):
    """Pre-norm multi-head self-attention that delegates the attention computation to `Attend`."""

    def __init__(self, dim, heads = 8, dim_head = 64, use_flash = True):
        super().__init__()
        self.heads = heads
        self.scale = dim_head ** -0.5
        self.norm = nn.LayerNorm(dim)

        self.attend = Attend(use_flash = use_flash)

        inner_dim = heads * dim_head

        # a single projection produces queries, keys and values
        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
        self.to_out = nn.Linear(inner_dim, dim, bias = False)

    def forward(self, x):
        normed = self.norm(x)

        projected = self.to_qkv(normed)

        # split into q, k, v and move heads in front of the sequence axis
        q, k, v = (
            rearrange(part, 'b n (h d) -> b h n d', h = self.heads)
            for part in projected.chunk(3, dim = -1)
        )

        attended = self.attend(q, k, v)

        merged = rearrange(attended, 'b h n d -> b n (h d)')
        return self.to_out(merged)
|
||||
|
||||
class Transformer(Module):
    """Stack of `depth` residual (Attention, FeedForward) pairs; no final normalization is applied."""

    def __init__(self, dim, depth, heads, dim_head, mlp_dim, use_flash):
        super().__init__()
        self.layers = ModuleList([])

        for _ in range(depth):
            attention = Attention(dim, heads = heads, dim_head = dim_head, use_flash = use_flash)
            feedforward = FeedForward(dim, mlp_dim)
            self.layers.append(ModuleList([attention, feedforward]))

    def forward(self, x):
        # each sublayer is wrapped in a residual connection
        for attention, feedforward in self.layers:
            x = attention(x) + x
            x = feedforward(x) + x

        return x
|
||||
|
||||
class SimpleViT(Module):
    """Simple ViT for video / volumetric input `(batch, channels, frames, height, width)`.

    The input is cut into `frame_patch_size x patch_height x patch_width` patches, linearly
    embedded, given fixed 3d sin-cos positional embeddings, run through the transformer,
    mean pooled over all patches and projected to `num_classes` logits.
    """

    def __init__(self, *, image_size, image_patch_size, frames, frame_patch_size, num_classes, dim, depth, heads, mlp_dim, channels = 3, dim_head = 64, use_flash_attn = True):
        super().__init__()
        image_height, image_width = pair(image_size)
        patch_height, patch_width = pair(image_patch_size)

        assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'
        assert frames % frame_patch_size == 0, 'Frames must be divisible by the frame patch size'

        patch_dim = channels * patch_height * patch_width * frame_patch_size

        # the (f, h, w) grid is kept so positional embeddings can be derived from it in forward
        self.to_patch_embedding = nn.Sequential(
            Rearrange('b c (f pf) (h p1) (w p2) -> b f h w (p1 p2 pf c)', p1 = patch_height, p2 = patch_width, pf = frame_patch_size),
            nn.LayerNorm(patch_dim),
            nn.Linear(patch_dim, dim),
            nn.LayerNorm(dim),
        )

        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, use_flash_attn)

        self.to_latent = nn.Identity()
        self.linear_head = nn.Linear(dim, num_classes)

    def forward(self, video):
        x = self.to_patch_embedding(video)

        # positional embeddings are computed from the patch grid on every call
        pe = posemb_sincos_3d(x)

        # flatten the (f, h, w) grid into one sequence axis before adding positions
        x = rearrange(x, 'b ... d -> b (...) d') + pe

        x = self.transformer(x)

        # mean pool over the sequence
        x = x.mean(dim = 1)

        x = self.to_latent(x)
        return self.linear_head(x)
|
||||
176
vit_pytorch/simple_uvit.py
Normal file
176
vit_pytorch/simple_uvit.py
Normal file
@@ -0,0 +1,176 @@
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.nn import Module, ModuleList
|
||||
|
||||
from einops import rearrange, repeat, pack, unpack
|
||||
from einops.layers.torch import Rearrange
|
||||
|
||||
# helpers
|
||||
|
||||
def pair(t):
    """Return `t` itself if it is a tuple, else the 2-tuple `(t, t)`."""
    if not isinstance(t, tuple):
        return (t, t)

    return t
|
||||
|
||||
def exists(v):
    """Return False only for None; every other value (including falsy ones) exists."""
    if v is None:
        return False

    return True
|
||||
|
||||
def divisible_by(num, den):
    """Return True when `num` is an exact multiple of `den`."""
    leftover = num % den
    return leftover == 0
|
||||
|
||||
def posemb_sincos_2d(h, w, dim, temperature: int = 10000, dtype = torch.float32):
    """Return fixed 2d sin-cos positional embeddings of shape `(h * w, dim)`.

    Features are laid out as `[sin(x), cos(x), sin(y), cos(y)]`, each `dim // 4` wide.
    Asserts that `dim` is a multiple of 4.
    """
    rows, cols = torch.meshgrid(torch.arange(h), torch.arange(w), indexing="ij")
    assert divisible_by(dim, 4), "feature dimension must be multiple of 4 for sincos emb"

    quarter = dim // 4

    # frequencies fall geometrically from 1 to 1 / temperature
    exponents = torch.arange(quarter) / (quarter - 1)
    omega = temperature ** -exponents

    # (positions, 1) * (1, quarter) -> (positions, quarter)
    y = rows.flatten()[:, None] * omega[None, :]
    x = cols.flatten()[:, None] * omega[None, :]

    features = (x.sin(), x.cos(), y.sin(), y.cos())
    pe = torch.cat(features, dim=1)
    return pe.type(dtype)
|
||||
|
||||
# classes
|
||||
|
||||
def FeedForward(dim, hidden_dim):
    """Build the pre-norm MLP `LayerNorm -> Linear -> GELU -> Linear` as an `nn.Sequential`."""
    stages = [
        nn.LayerNorm(dim),
        nn.Linear(dim, hidden_dim),
        nn.GELU(),
        nn.Linear(hidden_dim, dim),
    ]

    return nn.Sequential(*stages)
|
||||
|
||||
class Attention(Module):
    """Pre-norm multi-head self-attention with explicit scaled dot-product softmax attention."""

    def __init__(self, dim, heads = 8, dim_head = 64):
        super().__init__()
        self.heads = heads
        self.scale = dim_head ** -0.5
        self.norm = nn.LayerNorm(dim)

        self.attend = nn.Softmax(dim = -1)

        inner_dim = heads * dim_head

        # a single projection produces queries, keys and values
        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
        self.to_out = nn.Linear(inner_dim, dim, bias = False)

    def forward(self, x):
        normed = self.norm(x)

        projected = self.to_qkv(normed)

        # split into q, k, v and move heads in front of the sequence axis
        q, k, v = (
            rearrange(part, 'b n (h d) -> b h n d', h = self.heads)
            for part in projected.chunk(3, dim = -1)
        )

        similarity = torch.matmul(q, k.transpose(-1, -2)) * self.scale

        weights = self.attend(similarity)

        aggregated = torch.matmul(weights, v)
        merged = rearrange(aggregated, 'b h n d -> b n (h d)')
        return self.to_out(merged)
|
||||
|
||||
class Transformer(Module):
    """Transformer with U-Net style long skip connections between its two halves.

    With 1-indexed layers, those with `layer <= depth / 2` push their input onto a stack, and
    those with `layer >= depth / 2 + 1` pop one entry, concatenate it with the current
    activations on the feature axis and project back to `dim` before attention. For odd
    `depth` the middle layer does neither.
    """

    def __init__(self, dim, depth, heads, dim_head, mlp_dim):
        super().__init__()
        self.depth = depth
        self.norm = nn.LayerNorm(dim)
        self.layers = ModuleList([])

        latter_half_start = depth / 2 + 1

        for layer in range(1, depth + 1):
            # only latter-half layers get a projection for the concatenated skip, others hold None
            combine_skip = nn.Linear(dim * 2, dim) if layer >= latter_half_start else None
            attn = Attention(dim, heads = heads, dim_head = dim_head)
            ff = FeedForward(dim, mlp_dim)

            self.layers.append(nn.ModuleList([combine_skip, attn, ff]))

    def forward(self, x):

        skips = []
        half_depth = self.depth / 2

        for layer, (combine_skip, attn, ff) in enumerate(self.layers, start = 1):

            if layer <= half_depth:
                skips.append(x)

            if exists(combine_skip):
                # last in, first out: the deepest first-half layer pairs with the first latter-half layer
                skip = skips.pop()
                x = combine_skip(torch.cat((skip, x), dim = -1))

            x = attn(x) + x
            x = ff(x) + x

        # every stored skip must have been consumed
        assert len(skips) == 0

        return self.norm(x)
|
||||
|
||||
class SimpleUViT(Module):
    """Simple ViT classifier with a U-Net style skip-connected transformer and register tokens.

    Images are patchified and embedded, given fixed 2d sin-cos positional embeddings,
    concatenated with learned register tokens, passed through the transformer, and the patch
    tokens (registers discarded) are mean pooled into the linear classification head.
    """

    def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, num_register_tokens = 4, channels = 3, dim_head = 64):
        super().__init__()
        image_height, image_width = pair(image_size)
        patch_height, patch_width = pair(patch_size)

        assert divisible_by(image_height, patch_height) and divisible_by(image_width, patch_width), 'Image dimensions must be divisible by the patch size.'

        patch_dim = channels * patch_height * patch_width

        self.to_patch_embedding = nn.Sequential(
            Rearrange("b c (h p1) (w p2) -> b (h w) (p1 p2 c)", p1 = patch_height, p2 = patch_width),
            nn.LayerNorm(patch_dim),
            nn.Linear(patch_dim, dim),
            nn.LayerNorm(dim),
        )

        grid_height = image_height // patch_height
        grid_width = image_width // patch_width

        pos_embedding = posemb_sincos_2d(h = grid_height, w = grid_width, dim = dim)

        # non-persistent: moves with the module but is left out of the state dict
        self.register_buffer('pos_embedding', pos_embedding, persistent = False)

        self.register_tokens = nn.Parameter(torch.randn(num_register_tokens, dim))

        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim)

        self.pool = "mean"
        self.to_latent = nn.Identity()

        self.linear_head = nn.Linear(dim, num_classes)

    def forward(self, img):
        batch = img.shape[0]

        tokens = self.to_patch_embedding(img)
        tokens = tokens + self.pos_embedding.type(tokens.dtype)

        # one copy of the register tokens per batch element, appended after the patch tokens
        registers = repeat(self.register_tokens, 'n d -> b n d', b = batch)

        packed, ps = pack([tokens, registers], 'b * d')

        packed = self.transformer(packed)

        # drop the registers again, only patch tokens are pooled
        tokens, _ = unpack(packed, ps, 'b * d')

        pooled = tokens.mean(dim = 1)

        latent = self.to_latent(pooled)
        return self.linear_head(latent)
|
||||
|
||||
# quick test on odd number of layers
|
||||
|
||||
if __name__ == '__main__':

    # smoke test with an odd depth, so the middle layer neither stores nor consumes a skip
    # run on the GPU when one is present, otherwise fall back to the CPU instead of crashing

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    v = SimpleUViT(
        image_size = 256,
        patch_size = 32,
        num_classes = 1000,
        dim = 1024,
        depth = 7,
        heads = 16,
        mlp_dim = 2048
    ).to(device)

    img = torch.randn(2, 3, 256, 256).to(device)

    preds = v(img)
    assert preds.shape == (2, 1000)
|
||||
162
vit_pytorch/simple_vit_with_fft.py
Normal file
162
vit_pytorch/simple_vit_with_fft.py
Normal file
@@ -0,0 +1,162 @@
|
||||
import torch
|
||||
from torch.fft import fft2
|
||||
from torch import nn
|
||||
|
||||
from einops import rearrange, reduce, pack, unpack
|
||||
from einops.layers.torch import Rearrange
|
||||
|
||||
# helpers
|
||||
|
||||
def pair(t):
    """Duplicate a non-tuple value into `(t, t)`; tuples are returned as they are."""
    if isinstance(t, tuple):
        return t

    return (t, t)
|
||||
|
||||
def posemb_sincos_2d(h, w, dim, temperature: int = 10000, dtype = torch.float32):
    """Return fixed 2d sin-cos positional embeddings of shape `(h * w, dim)`.

    Features are laid out as `[sin(x), cos(x), sin(y), cos(y)]`, each `dim // 4` wide.
    Asserts that `dim` is a multiple of 4.
    """
    rows, cols = torch.meshgrid(torch.arange(h), torch.arange(w), indexing="ij")
    assert (dim % 4) == 0, "feature dimension must be multiple of 4 for sincos emb"

    quarter = dim // 4

    # frequencies fall geometrically from 1 to 1 / temperature
    exponents = torch.arange(quarter) / (quarter - 1)
    omega = 1.0 / (temperature ** exponents)

    # (positions, 1) * (1, quarter) -> (positions, quarter)
    y = rows.flatten()[:, None] * omega[None, :]
    x = cols.flatten()[:, None] * omega[None, :]

    features = (x.sin(), x.cos(), y.sin(), y.cos())
    pe = torch.cat(features, dim=1)
    return pe.type(dtype)
|
||||
|
||||
# classes
|
||||
|
||||
class FeedForward(nn.Module):
    """Pre-norm MLP: LayerNorm -> Linear(dim, hidden_dim) -> GELU -> Linear(hidden_dim, dim)."""

    def __init__(self, dim, hidden_dim):
        super().__init__()

        stages = [
            nn.LayerNorm(dim),
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, dim),
        ]

        self.net = nn.Sequential(*stages)

    def forward(self, x):
        out = self.net(x)
        return out
|
||||
|
||||
class Attention(nn.Module):
    """Pre-norm multi-head self-attention with explicit scaled dot-product softmax attention."""

    def __init__(self, dim, heads = 8, dim_head = 64):
        super().__init__()
        self.heads = heads
        self.scale = dim_head ** -0.5
        self.norm = nn.LayerNorm(dim)

        self.attend = nn.Softmax(dim = -1)

        inner_dim = heads * dim_head

        # a single projection produces queries, keys and values
        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
        self.to_out = nn.Linear(inner_dim, dim, bias = False)

    def forward(self, x):
        normed = self.norm(x)

        projected = self.to_qkv(normed)

        # split into q, k, v and move heads in front of the sequence axis
        q, k, v = (
            rearrange(part, 'b n (h d) -> b h n d', h = self.heads)
            for part in projected.chunk(3, dim = -1)
        )

        similarity = torch.matmul(q, k.transpose(-1, -2)) * self.scale

        weights = self.attend(similarity)

        aggregated = torch.matmul(weights, v)
        merged = rearrange(aggregated, 'b h n d -> b n (h d)')
        return self.to_out(merged)
|
||||
|
||||
class Transformer(nn.Module):
    """Stack of `depth` residual (Attention, FeedForward) pairs followed by a final LayerNorm."""

    def __init__(self, dim, depth, heads, dim_head, mlp_dim):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.layers = nn.ModuleList([])

        for _ in range(depth):
            attention = Attention(dim, heads = heads, dim_head = dim_head)
            feedforward = FeedForward(dim, mlp_dim)
            self.layers.append(nn.ModuleList([attention, feedforward]))

    def forward(self, x):
        # each sublayer is wrapped in a residual connection
        for attention, feedforward in self.layers:
            x = attention(x) + x
            x = feedforward(x) + x

        return self.norm(x)
|
||||
|
||||
class SimpleViT(nn.Module):
    """Simple ViT that attends jointly over image patches and patches of the image's 2d FFT.

    The real / imaginary parts of `fft2(img)` are patchified with their own patch size and
    embedding, placed in front of the spatial patch tokens, and run through one transformer.
    Only the spatial tokens are mean pooled into the classification head.
    """

    def __init__(self, *, image_size, patch_size, freq_patch_size, num_classes, dim, depth, heads, mlp_dim, channels = 3, dim_head = 64):
        super().__init__()
        image_height, image_width = pair(image_size)
        patch_height, patch_width = pair(patch_size)
        freq_patch_height, freq_patch_width = pair(freq_patch_size)

        assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'
        assert image_height % freq_patch_height == 0 and image_width % freq_patch_width == 0, 'Image dimensions must be divisible by the freq patch size.'

        patch_dim = channels * patch_height * patch_width

        # factor 2 for the real and imaginary component of every frequency
        freq_patch_dim = channels * 2 * freq_patch_height * freq_patch_width

        self.to_patch_embedding = nn.Sequential(
            Rearrange("b c (h p1) (w p2) -> b (h w) (p1 p2 c)", p1 = patch_height, p2 = patch_width),
            nn.LayerNorm(patch_dim),
            nn.Linear(patch_dim, dim),
            nn.LayerNorm(dim),
        )

        self.to_freq_embedding = nn.Sequential(
            Rearrange("b c (h p1) (w p2) ri -> b (h w) (p1 p2 ri c)", p1 = freq_patch_height, p2 = freq_patch_width),
            nn.LayerNorm(freq_patch_dim),
            nn.Linear(freq_patch_dim, dim),
            nn.LayerNorm(dim)
        )

        # plain tensor attributes (not buffers): they are moved to the input's device in forward

        spatial_grid = (image_height // patch_height, image_width // patch_width)
        freq_grid = (image_height // freq_patch_height, image_width // freq_patch_width)

        self.pos_embedding = posemb_sincos_2d(h = spatial_grid[0], w = spatial_grid[1], dim = dim)
        self.freq_pos_embedding = posemb_sincos_2d(h = freq_grid[0], w = freq_grid[1], dim = dim)

        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim)

        self.pool = "mean"
        self.to_latent = nn.Identity()

        self.linear_head = nn.Linear(dim, num_classes)

    def forward(self, img):
        device, dtype = img.device, img.dtype

        x = self.to_patch_embedding(img)

        # complex spectrum -> trailing axis of size 2 holding (real, imaginary)
        freqs = torch.view_as_real(fft2(img))

        f = self.to_freq_embedding(freqs)

        x += self.pos_embedding.to(device, dtype = dtype)
        f += self.freq_pos_embedding.to(device, dtype = dtype)

        # frequency tokens first, spatial tokens second
        tokens, ps = pack((f, x), 'b * d')

        tokens = self.transformer(tokens)

        # discard the frequency tokens, pool only the spatial ones
        _, x = unpack(tokens, ps, 'b * d')
        pooled = reduce(x, 'b n d -> b d', 'mean')

        latent = self.to_latent(pooled)
        return self.linear_head(latent)
|
||||
|
||||
if __name__ == '__main__':
    # smoke test: one forward pass of a single-layer model on a random batch
    vit = SimpleViT(
        image_size = 256,
        patch_size = 8,
        freq_patch_size = 8,
        num_classes = 1000,
        dim = 1024,
        depth = 1,
        heads = 8,
        mlp_dim = 2048,
    )

    # batch of 8 three-channel 256 x 256 images
    images = torch.randn(8, 3, 256, 256)

    logits = vit(images)
|
||||
@@ -1,3 +1,8 @@
|
||||
"""
|
||||
Vision Transformers Need Registers
|
||||
https://arxiv.org/abs/2309.16588
|
||||
"""
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
|
||||
@@ -61,10 +61,7 @@ class T2TViT(nn.Module):
|
||||
self.pool = pool
|
||||
self.to_latent = nn.Identity()
|
||||
|
||||
self.mlp_head = nn.Sequential(
|
||||
nn.LayerNorm(dim),
|
||||
nn.Linear(dim, num_classes)
|
||||
)
|
||||
self.mlp_head = nn.Linear(dim, num_classes)
|
||||
|
||||
def forward(self, img):
|
||||
x = self.to_patch_embedding(img)
|
||||
|
||||
@@ -10,7 +10,7 @@ class FeedForward(nn.Module):
|
||||
def __init__(self, dim, hidden_dim, dropout = 0.):
|
||||
super().__init__()
|
||||
self.net = nn.Sequential(
|
||||
nn.Layernorm(dim),
|
||||
nn.LayerNorm(dim),
|
||||
nn.Linear(dim, hidden_dim),
|
||||
nn.GELU(),
|
||||
nn.Dropout(dropout),
|
||||
|
||||
283
vit_pytorch/xcit.py
Normal file
283
vit_pytorch/xcit.py
Normal file
@@ -0,0 +1,283 @@
|
||||
from random import randrange
|
||||
|
||||
import torch
|
||||
from torch import nn, einsum
|
||||
from torch.nn import Module, ModuleList
|
||||
import torch.nn.functional as F
|
||||
|
||||
from einops import rearrange, repeat, pack, unpack
|
||||
from einops.layers.torch import Rearrange
|
||||
|
||||
# helpers
|
||||
|
||||
def exists(val):
    """Return True when `val` is anything other than None (falsy values still count)."""
    if val is None:
        return False
    return True
|
||||
|
||||
def pack_one(t, pattern):
    """Pack the single tensor `t` with einops `pack`; returns (packed, packed_shapes)."""
    packed, packed_shapes = pack([t], pattern)
    return packed, packed_shapes
|
||||
|
||||
def unpack_one(t, ps, pattern):
    """Undo `pack_one`: unpack `t` using shapes `ps` and return the first result."""
    unpacked = unpack(t, ps, pattern)
    return unpacked[0]
|
||||
|
||||
def l2norm(t):
    """Scale `t` to unit L2 norm along its last dimension."""
    return F.normalize(t, p = 2, dim = -1)
|
||||
|
||||
def dropout_layers(layers, dropout):
    """Randomly drop whole layers, always keeping at least one when any exist.

    Args:
        layers: sequence of layers (e.g. a ModuleList); only its length and
            iteration order are used.
        dropout: per-layer probability of being dropped.

    Returns:
        `layers` itself when `dropout == 0`; otherwise a new list holding the
        surviving layers in their original order.
    """
    if dropout == 0:
        return layers

    num_layers = len(layers)

    # nothing to sample from - without this guard the "keep one" step below
    # would call randrange(0) and raise ValueError
    if num_layers == 0:
        return list(layers)

    to_drop = torch.zeros(num_layers).uniform_(0., 1.) < dropout

    # make sure at least one layer makes it
    if to_drop.all():
        to_drop[randrange(num_layers)] = False

    # tolist() yields plain bools, so filtering does not create per-element tensors
    return [layer for layer, drop in zip(layers, to_drop.tolist()) if not drop]
|
||||
|
||||
# classes
|
||||
|
||||
class LayerScale(Module):
    """Wrap `fn` and multiply its output by a learnable per-channel scale.

    The scale starts at a small constant chosen from the layer's depth:
    0.1 for depth <= 18, 1e-5 for 18 < depth <= 24, and 1e-6 beyond that.

    Args:
        dim: size of the last dimension of `fn`'s output.
        fn: the wrapped module.
        depth: 1-based index of the layer this wrapper belongs to.
    """

    def __init__(self, dim, fn, depth):
        super().__init__()

        # the original middle test was `18 > depth <= 24`, which can never hold
        # once `depth <= 18` has failed, so depths 19-24 wrongly got 1e-6
        if depth <= 18:
            init_eps = 0.1
        elif depth <= 24:
            init_eps = 1e-5
        else:
            init_eps = 1e-6

        self.fn = fn
        self.scale = nn.Parameter(torch.full((dim,), init_eps))

    def forward(self, x, **kwargs):
        """Apply `fn(x, **kwargs)` and scale the result along its last dimension."""
        return self.fn(x, **kwargs) * self.scale
|
||||
|
||||
class FeedForward(Module):
    """Pre-norm MLP block: LayerNorm -> Linear -> GELU -> Dropout -> Linear -> Dropout.

    Args:
        dim: input and output feature width.
        hidden_dim: width of the hidden layer.
        dropout: dropout probability used after the activation and after the output projection.
    """

    def __init__(self, dim, hidden_dim, dropout = 0.):
        super().__init__()

        stages = [
            nn.LayerNorm(dim),
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout),
        ]

        self.net = nn.Sequential(*stages)

    def forward(self, x):
        """Run `x` through the MLP; the last dimension stays `dim`."""
        out = self.net(x)
        return out
|
||||
|
||||
class Attention(Module):
    """Pre-norm multi-head softmax attention with optional extra context tokens.

    Queries always come from the normalized input. Keys and values come from
    the normalized input alone, or from the normalized input concatenated with
    `context` along the sequence axis when a context is given.

    Args:
        dim: token feature width.
        heads: number of attention heads.
        dim_head: feature width per head.
        dropout: dropout probability on the attention weights and on the output projection.
    """

    def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
        super().__init__()
        self.heads = heads
        self.scale = dim_head ** -0.5

        inner_dim = dim_head * heads

        self.norm = nn.LayerNorm(dim)
        self.to_q = nn.Linear(dim, inner_dim, bias = False)
        # keys and values share one projection and are split apart in forward
        self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False)

        self.attend = nn.Softmax(dim = -1)
        self.dropout = nn.Dropout(dropout)

        out_proj = nn.Linear(inner_dim, dim)
        self.to_out = nn.Sequential(out_proj, nn.Dropout(dropout))

    def forward(self, x, context = None):
        """Attend from `x` (batch, n, dim) to itself plus the optional `context`."""
        heads = self.heads

        x = self.norm(x)

        # note: context is used as given, only x has been normalized here
        if exists(context):
            context = torch.cat((x, context), dim = 1)
        else:
            context = x

        q = self.to_q(x)
        k, v = self.to_kv(context).chunk(2, dim = -1)

        # split the feature axis into heads
        q, k, v = (rearrange(t, 'b n (h d) -> b h n d', h = heads) for t in (q, k, v))

        sim = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale

        attn = self.dropout(self.attend(sim))

        out = einsum('b h i j, b h j d -> b h i d', attn, v)
        # merge heads back into the feature axis
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)
|
||||
|
||||
class XCAttention(Module):
    """Cross-covariance attention: attention across feature channels instead of tokens.

    Queries and keys are L2-normalized along the token axis, so the similarity
    is a (head-dim x head-dim) map per head, scaled by a learned per-head
    temperature that is applied through `exp`.

    Args:
        dim: token feature width.
        heads: number of attention heads.
        dim_head: feature width per head.
        dropout: dropout probability on the attention weights and on the output projection.
    """

    def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
        super().__init__()
        self.heads = heads

        inner_dim = dim_head * heads

        self.norm = nn.LayerNorm(dim)

        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)

        # one learned temperature per head, broadcast over the channel-by-channel map
        self.temperature = nn.Parameter(torch.ones(heads, 1, 1))

        self.attend = nn.Softmax(dim = -1)
        self.dropout = nn.Dropout(dropout)

        out_proj = nn.Linear(inner_dim, dim)
        self.to_out = nn.Sequential(out_proj, nn.Dropout(dropout))

    def forward(self, x):
        """Apply cross-covariance attention; the output has the same shape as `x`."""
        heads = self.heads

        # flatten every axis between batch and features into one token axis
        x, ps = pack_one(x, 'b * d')

        x = self.norm(x)
        qkv = self.to_qkv(x).chunk(3, dim = -1)

        # heads split out, tokens moved last so attention runs over channels
        q, k, v = (rearrange(t, 'b n (h d) -> b h d n', h = heads) for t in qkv)

        q = l2norm(q)
        k = l2norm(k)

        sim = einsum('b h i n, b h j n -> b h i j', q, k) * self.temperature.exp()

        attn = self.dropout(self.attend(sim))

        out = einsum('b h i j, b h j n -> b h i n', attn, v)
        out = rearrange(out, 'b h d n -> b n (h d)')

        # restore the axes that were flattened on entry
        out = unpack_one(out, ps, 'b * d')
        return self.to_out(out)
|
||||
|
||||
class LocalPatchInteraction(Module):
    """Mix neighbouring patches with two depthwise convolutions.

    Takes channels-last input (b, h, w, c): LayerNorm over channels, switch to
    channels-first, depthwise conv -> BatchNorm -> GELU -> depthwise conv,
    then switch back to channels-last.

    Args:
        dim: number of channels.
        kernel_size: odd convolution kernel size; padding keeps h and w unchanged.
    """

    def __init__(self, dim, kernel_size = 3):
        super().__init__()
        assert (kernel_size % 2) == 1

        padding = kernel_size // 2
        # groups = dim makes each convolution depthwise (one filter per channel)
        conv_kwargs = dict(padding = padding, groups = dim)

        stages = [
            nn.LayerNorm(dim),
            Rearrange('b h w c -> b c h w'),
            nn.Conv2d(dim, dim, kernel_size, **conv_kwargs),
            nn.BatchNorm2d(dim),
            nn.GELU(),
            nn.Conv2d(dim, dim, kernel_size, **conv_kwargs),
            Rearrange('b c h w -> b h w c'),
        ]

        self.net = nn.Sequential(*stages)

    def forward(self, x):
        """Apply the local interaction to `x` of layout (b, h, w, c)."""
        out = self.net(x)
        return out
|
||||
|
||||
class Transformer(Module):
    """Stack of residual (attention, feedforward) layers, each wrapped in LayerScale.

    Args:
        dim: token feature width.
        depth: number of layers.
        heads, dim_head: forwarded to `Attention`.
        mlp_dim: hidden width of each `FeedForward`.
        dropout: dropout probability inside attention and feedforward.
        layer_dropout: probability handed to `dropout_layers` for dropping whole layers.
    """

    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0., layer_dropout = 0.):
        super().__init__()
        self.layers = ModuleList([])
        self.layer_dropout = layer_dropout

        # LayerScale picks its initial scale from the 1-based layer index
        for layer in range(1, depth + 1):
            attn = LayerScale(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout), depth = layer)
            ff = LayerScale(dim, FeedForward(dim, mlp_dim, dropout = dropout), depth = layer)
            self.layers.append(ModuleList([attn, ff]))

    def forward(self, x, context = None):
        """Run `x` through the surviving layers; `context` is passed to every attention."""
        active_layers = dropout_layers(self.layers, dropout = self.layer_dropout)

        for attn, ff in active_layers:
            x = attn(x, context = context) + x
            x = ff(x) + x

        return x
|
||||
|
||||
class XCATransformer(Module):
    """Stack of residual XCiT layers: cross-covariance attention, local patch interaction, feedforward.

    Every sub-block is wrapped in LayerScale and added back to its input.

    Args:
        dim: token feature width.
        depth: number of layers.
        heads, dim_head: forwarded to `XCAttention`.
        mlp_dim: hidden width of each `FeedForward`.
        local_patch_kernel_size: kernel size for `LocalPatchInteraction`.
        dropout: dropout probability inside attention and feedforward.
        layer_dropout: probability handed to `dropout_layers` for dropping whole layers.
    """

    def __init__(self, dim, depth, heads, dim_head, mlp_dim, local_patch_kernel_size = 3, dropout = 0., layer_dropout = 0.):
        super().__init__()
        self.layers = ModuleList([])
        self.layer_dropout = layer_dropout

        # LayerScale picks its initial scale from the 1-based layer index
        for layer in range(1, depth + 1):
            xc_attn = LayerScale(dim, XCAttention(dim, heads = heads, dim_head = dim_head, dropout = dropout), depth = layer)
            local_mix = LayerScale(dim, LocalPatchInteraction(dim, local_patch_kernel_size), depth = layer)
            ff = LayerScale(dim, FeedForward(dim, mlp_dim, dropout = dropout), depth = layer)
            self.layers.append(ModuleList([xc_attn, local_mix, ff]))

    def forward(self, x):
        """Run `x` through the surviving layers and return a tensor of the same shape."""
        active_layers = dropout_layers(self.layers, dropout = self.layer_dropout)

        for xc_attn, local_mix, ff in active_layers:
            x = xc_attn(x) + x
            x = local_mix(x) + x
            x = ff(x) + x

        return x
|
||||
|
||||
class XCiT(Module):
    """Cross-covariance image transformer classifier.

    Patches are embedded on a 2D grid and processed by `XCATransformer`. A
    learned CLS token then attends to the normalized patch tokens through a
    regular `Transformer`, and the resulting CLS token is classified.

    Args:
        image_size: side length of the (square) input image.
        patch_size: side length of each square patch; must divide `image_size`.
        num_classes: width of the output logits.
        dim: token feature width.
        depth: number of XCiT layers.
        cls_depth: number of class-attention layers.
        heads, dim_head: attention configuration shared by both stacks.
        mlp_dim: hidden width of the feedforward blocks.
        dropout: dropout probability inside attention and feedforward.
        emb_dropout: dropout probability on the embedded patches.
        local_patch_kernel_size: kernel size of the local patch interaction.
        layer_dropout: probability of dropping whole layers in both stacks.
        channels: number of input image channels (defaults to 3, the previously hard-coded value).
    """

    def __init__(
        self,
        *,
        image_size,
        patch_size,
        num_classes,
        dim,
        depth,
        cls_depth,
        heads,
        mlp_dim,
        dim_head = 64,
        dropout = 0.,
        emb_dropout = 0.,
        local_patch_kernel_size = 3,
        layer_dropout = 0.,
        channels = 3
    ):
        super().__init__()
        assert image_size % patch_size == 0, 'Image dimensions must be divisible by the patch size.'

        num_patches = (image_size // patch_size) ** 2
        # flattened feature width of one patch; was fixed at 3 channels before
        patch_dim = channels * patch_size ** 2

        # patches stay on an (h, w) grid so the local patch interaction can convolve over them
        self.to_patch_embedding = nn.Sequential(
            Rearrange('b c (h p1) (w p2) -> b h w (p1 p2 c)', p1 = patch_size, p2 = patch_size),
            nn.LayerNorm(patch_dim),
            nn.Linear(patch_dim, dim),
            nn.LayerNorm(dim)
        )

        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches, dim))
        self.cls_token = nn.Parameter(torch.randn(dim))

        self.dropout = nn.Dropout(emb_dropout)

        self.xcit_transformer = XCATransformer(dim, depth, heads, dim_head, mlp_dim, local_patch_kernel_size, dropout, layer_dropout)

        self.final_norm = nn.LayerNorm(dim)

        self.cls_transformer = Transformer(dim, cls_depth, heads, dim_head, mlp_dim, dropout, layer_dropout)

        self.mlp_head = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, num_classes)
        )

    def forward(self, img):
        """Return class logits for `img`, a (batch, channels, height, width) tensor."""
        x = self.to_patch_embedding(img)

        # flatten the patch grid just long enough to add the positional embedding
        x, ps = pack_one(x, 'b * d')

        b, n, _ = x.shape
        x += self.pos_embedding[:, :n]

        x = unpack_one(x, ps, 'b * d')

        x = self.dropout(x)

        x = self.xcit_transformer(x)

        x = self.final_norm(x)

        # one CLS token per batch element
        cls_tokens = repeat(self.cls_token, 'd -> b 1 d', b = b)

        # class attention: the CLS token is the input, the flattened patches are the context
        x = rearrange(x, 'b ... d -> b (...) d')
        cls_tokens = self.cls_transformer(cls_tokens, context = x)

        return self.mlp_head(cls_tokens[:, 0])
|
||||
Reference in New Issue
Block a user