mirror of
https://github.com/lucidrains/vit-pytorch.git
synced 2025-12-30 08:02:29 +00:00
Compare commits
22 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
e05cd6d8b8 | ||
|
|
b46233c3d6 | ||
|
|
68e13a3c7d | ||
|
|
b22dc0ecd2 | ||
|
|
db05a141a6 | ||
|
|
9f49a31977 | ||
|
|
ab63fc9cc8 | ||
|
|
c3018d1433 | ||
|
|
b7ed6bad28 | ||
|
|
e7cba9ba6d | ||
|
|
56373c0cbd | ||
|
|
24196a3e8a | ||
|
|
f6d7287b6b | ||
|
|
d47c57e32f | ||
|
|
0449865786 | ||
|
|
6693d47d0b | ||
|
|
141239ca86 | ||
|
|
0b5c9b4559 | ||
|
|
e300cdd7dc | ||
|
|
36ddc7a6ba | ||
|
|
1d1a63fc5c | ||
|
|
74b62009f8 |
41
README.md
41
README.md
@@ -198,7 +198,7 @@ preds = v(
|
||||
) # (5, 1000)
|
||||
```
|
||||
|
||||
Finally, if you would like to make use of a flavor of NaViT using <a href="https://pytorch.org/tutorials/prototype/nestedtensor.html">nested tensors</a> (which will omit a lot of the masking and padding altogether), make sure you are on version `2.4` and import as follows
|
||||
Finally, if you would like to make use of a flavor of NaViT using <a href="https://pytorch.org/tutorials/prototype/nestedtensor.html">nested tensors</a> (which will omit a lot of the masking and padding altogether), make sure you are on version `2.5` and import as follows
|
||||
|
||||
```python
|
||||
import torch
|
||||
@@ -2142,4 +2142,43 @@ Coming from computer vision and new to transformers? Here are some resources tha
|
||||
}
|
||||
```
|
||||
|
||||
```bibtex
|
||||
@inproceedings{Liu2017DeepHL,
|
||||
title = {Deep Hyperspherical Learning},
|
||||
author = {Weiyang Liu and Yanming Zhang and Xingguo Li and Zhen Liu and Bo Dai and Tuo Zhao and Le Song},
|
||||
booktitle = {Neural Information Processing Systems},
|
||||
year = {2017},
|
||||
url = {https://api.semanticscholar.org/CorpusID:5104558}
|
||||
}
|
||||
```
|
||||
|
||||
```bibtex
|
||||
@inproceedings{Zhou2024ValueRL,
|
||||
title = {Value Residual Learning For Alleviating Attention Concentration In Transformers},
|
||||
author = {Zhanchao Zhou and Tianyi Wu and Zhiyun Jiang and Zhenzhong Lan},
|
||||
year = {2024},
|
||||
url = {https://api.semanticscholar.org/CorpusID:273532030}
|
||||
}
|
||||
```
|
||||
|
||||
```bibtex
|
||||
@article{Zhu2024HyperConnections,
|
||||
title = {Hyper-Connections},
|
||||
author = {Defa Zhu and Hongzhi Huang and Zihao Huang and Yutao Zeng and Yunyao Mao and Banggu Wu and Qiyang Min and Xun Zhou},
|
||||
journal = {ArXiv},
|
||||
year = {2024},
|
||||
volume = {abs/2409.19606},
|
||||
url = {https://api.semanticscholar.org/CorpusID:272987528}
|
||||
}
|
||||
```
|
||||
|
||||
```bibtex
|
||||
@inproceedings{Fuller2025SimplerFV,
|
||||
title = {Simpler Fast Vision Transformers with a Jumbo CLS Token},
|
||||
author = {Anthony Fuller and Yousef Yassin and Daniel G. Kyrollos and Evan Shelhamer and James R. Green},
|
||||
year = {2025},
|
||||
url = {https://api.semanticscholar.org/CorpusID:276557720}
|
||||
}
|
||||
```
|
||||
|
||||
*I visualise a time when we will be to robots what dogs are to humans, and I’m rooting for the machines.* — Claude Shannon
|
||||
|
||||
4
setup.py
4
setup.py
@@ -6,10 +6,10 @@ with open('README.md') as f:
|
||||
setup(
|
||||
name = 'vit-pytorch',
|
||||
packages = find_packages(exclude=['examples']),
|
||||
version = '1.8.1',
|
||||
version = '1.11.3',
|
||||
license='MIT',
|
||||
description = 'Vision Transformer (ViT) - Pytorch',
|
||||
long_description=long_description,
|
||||
long_description = long_description,
|
||||
long_description_content_type = 'text/markdown',
|
||||
author = 'Phil Wang',
|
||||
author_email = 'lucidrains@gmail.com',
|
||||
|
||||
129
vit_pytorch/accept_video_wrapper.py
Normal file
129
vit_pytorch/accept_video_wrapper.py
Normal file
@@ -0,0 +1,129 @@
|
||||
from contextlib import nullcontext
|
||||
|
||||
import torch
|
||||
from torch import is_tensor, randn
|
||||
from torch.nn import Module, Parameter
|
||||
from torch.utils._pytree import tree_flatten, tree_unflatten
|
||||
|
||||
from einops import rearrange, repeat
|
||||
|
||||
# helper functions
|
||||
|
||||
def exists(v):
|
||||
return v is not None
|
||||
|
||||
def default(v, d):
|
||||
return v if exists(v) else d
|
||||
|
||||
# classes
|
||||
|
||||
class AcceptVideoWrapper(Module):
|
||||
def __init__(
|
||||
self,
|
||||
image_net: Module,
|
||||
forward_function = 'forward',
|
||||
add_time_pos_emb = False,
|
||||
dim_emb = None,
|
||||
time_seq_len = None,
|
||||
output_pos_add_pos_emb = 0 # defaults to first output position to add embedding
|
||||
):
|
||||
super().__init__()
|
||||
self.image_net = image_net
|
||||
self.forward_function = forward_function # for openclip, used in TRI-LBM
|
||||
|
||||
self.add_time_pos_emb = add_time_pos_emb
|
||||
self.output_pos_add_pos_emb = output_pos_add_pos_emb
|
||||
|
||||
if add_time_pos_emb:
|
||||
assert exists(dim_emb) and exists(time_seq_len), '`dim_emb` and `time_seq_len` must be set if adding positional embeddings to the output'
|
||||
self.time_seq_len = time_seq_len
|
||||
|
||||
self.pos_emb = Parameter(randn(time_seq_len, dim_emb) * 1e-2)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
video, # (b c t h w)
|
||||
eval_with_no_grad = False,
|
||||
forward_kwargs = dict()
|
||||
):
|
||||
add_time_pos_emb = self.add_time_pos_emb
|
||||
time = video.shape[2]
|
||||
|
||||
# maybe validate time positional embedding
|
||||
|
||||
if add_time_pos_emb:
|
||||
assert time <= self.time_seq_len, f'received video with {time} frames but `time_seq_len` ({self.time_seq_len}) is too low'
|
||||
|
||||
video = rearrange(video, 'b c t h w -> b t c h w')
|
||||
|
||||
video = rearrange(video, 'b t ... -> (b t) ...')
|
||||
|
||||
# forward through image net for outputs
|
||||
|
||||
func = getattr(self.image_net, self.forward_function)
|
||||
|
||||
if eval_with_no_grad:
|
||||
self.image_net.eval()
|
||||
|
||||
context = torch.no_grad if eval_with_no_grad else nullcontext
|
||||
|
||||
with context():
|
||||
outputs = func(video, **forward_kwargs)
|
||||
|
||||
# handle multiple outputs, say logits and embeddings returned from extractor - also handle some reduce aux loss being returned
|
||||
|
||||
outputs, tree_spec = tree_flatten(outputs)
|
||||
|
||||
outputs = tuple(rearrange(t, '(b t) ... -> b t ...', t = time) if is_tensor(t) and t.numel() > 1 else t for t in outputs)
|
||||
|
||||
# maybe add time positional embedding
|
||||
|
||||
if add_time_pos_emb:
|
||||
|
||||
outputs = list(outputs)
|
||||
embed = outputs[self.output_pos_add_pos_emb]
|
||||
|
||||
pos_emb = rearrange(self.pos_emb, 't d -> 1 t d')
|
||||
|
||||
# handle the network outputting embeddings with spatial dimensions intact - assume embedded dimension is last
|
||||
|
||||
dims_to_unsqueeze = embed.ndim - pos_emb.ndim
|
||||
|
||||
pos_emb = pos_emb.reshape(*pos_emb.shape[:2], *((1,) * dims_to_unsqueeze) , pos_emb.shape[-1])
|
||||
|
||||
embed = embed + pos_emb
|
||||
|
||||
outputs[self.output_pos_add_pos_emb] = embed
|
||||
|
||||
return tree_unflatten(outputs, tree_spec)
|
||||
|
||||
# main
|
||||
|
||||
if __name__ == '__main__':
|
||||
from vit_pytorch import ViT
|
||||
|
||||
v = ViT(
|
||||
image_size = 256,
|
||||
patch_size = 32,
|
||||
num_classes = 1000,
|
||||
dim = 1024,
|
||||
depth = 6,
|
||||
heads = 16,
|
||||
mlp_dim = 2048,
|
||||
dropout = 0.1,
|
||||
emb_dropout = 0.1
|
||||
)
|
||||
|
||||
videos = torch.randn(1, 3, 10, 256, 256)
|
||||
|
||||
# step up the difficulty and return embeddings for robotics
|
||||
|
||||
from vit_pytorch.extractor import Extractor
|
||||
v = Extractor(v)
|
||||
|
||||
video_acceptor = AcceptVideoWrapper(v, add_time_pos_emb = True, output_pos_add_pos_emb = 1, time_seq_len = 10, dim_emb = 1024)
|
||||
|
||||
logits, embeddings = video_acceptor(videos, eval_with_no_grad = True) # always (batch, channels, time, height, width) - time is always dimension 2
|
||||
|
||||
assert logits.shape == (1, 10, 1000)
|
||||
assert embeddings.shape == (1, 10, 65, 1024)
|
||||
@@ -167,8 +167,10 @@ class Tokenizer(nn.Module):
|
||||
stride,
|
||||
padding,
|
||||
frame_stride=1,
|
||||
frame_padding=None,
|
||||
frame_pooling_stride=1,
|
||||
frame_pooling_kernel_size=1,
|
||||
frame_pooling_padding=None,
|
||||
pooling_kernel_size=3,
|
||||
pooling_stride=2,
|
||||
pooling_padding=1,
|
||||
@@ -188,16 +190,22 @@ class Tokenizer(nn.Module):
|
||||
|
||||
n_filter_list_pairs = zip(n_filter_list[:-1], n_filter_list[1:])
|
||||
|
||||
if frame_padding is None:
|
||||
frame_padding = frame_kernel_size // 2
|
||||
|
||||
if frame_pooling_padding is None:
|
||||
frame_pooling_padding = frame_pooling_kernel_size // 2
|
||||
|
||||
self.conv_layers = nn.Sequential(
|
||||
*[nn.Sequential(
|
||||
nn.Conv3d(chan_in, chan_out,
|
||||
kernel_size=(frame_kernel_size, kernel_size, kernel_size),
|
||||
stride=(frame_stride, stride, stride),
|
||||
padding=(frame_kernel_size // 2, padding, padding), bias=conv_bias),
|
||||
padding=(frame_padding, padding, padding), bias=conv_bias),
|
||||
nn.Identity() if not exists(activation) else activation(),
|
||||
nn.MaxPool3d(kernel_size=(frame_pooling_kernel_size, pooling_kernel_size, pooling_kernel_size),
|
||||
stride=(frame_pooling_stride, pooling_stride, pooling_stride),
|
||||
padding=(frame_pooling_kernel_size // 2, pooling_padding, pooling_padding)) if max_pool else nn.Identity()
|
||||
padding=(frame_pooling_padding, pooling_padding, pooling_padding)) if max_pool else nn.Identity()
|
||||
)
|
||||
for chan_in, chan_out in n_filter_list_pairs
|
||||
])
|
||||
@@ -324,8 +332,10 @@ class CCT(nn.Module):
|
||||
n_conv_layers=1,
|
||||
frame_stride=1,
|
||||
frame_kernel_size=3,
|
||||
frame_padding=None,
|
||||
frame_pooling_kernel_size=1,
|
||||
frame_pooling_stride=1,
|
||||
frame_pooling_padding=None,
|
||||
kernel_size=7,
|
||||
stride=2,
|
||||
padding=3,
|
||||
@@ -342,8 +352,10 @@ class CCT(nn.Module):
|
||||
n_output_channels=embedding_dim,
|
||||
frame_stride=frame_stride,
|
||||
frame_kernel_size=frame_kernel_size,
|
||||
frame_padding=frame_padding,
|
||||
frame_pooling_stride=frame_pooling_stride,
|
||||
frame_pooling_kernel_size=frame_pooling_kernel_size,
|
||||
frame_pooling_padding=frame_pooling_padding,
|
||||
kernel_size=kernel_size,
|
||||
stride=stride,
|
||||
padding=padding,
|
||||
|
||||
204
vit_pytorch/jumbo_vit.py
Normal file
204
vit_pytorch/jumbo_vit.py
Normal file
@@ -0,0 +1,204 @@
|
||||
# Simpler Fast Vision Transformers with a Jumbo CLS Token
|
||||
# https://arxiv.org/abs/2502.15021
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.nn import Module, ModuleList
|
||||
|
||||
from einops import rearrange, repeat, reduce, pack, unpack
|
||||
from einops.layers.torch import Rearrange
|
||||
|
||||
# helpers
|
||||
|
||||
def pair(t):
|
||||
return t if isinstance(t, tuple) else (t, t)
|
||||
|
||||
def divisible_by(num, den):
|
||||
return (num % den) == 0
|
||||
|
||||
def posemb_sincos_2d(h, w, dim, temperature: int = 10000, dtype = torch.float32):
|
||||
y, x = torch.meshgrid(torch.arange(h), torch.arange(w), indexing="ij")
|
||||
assert divisible_by(dim, 4), "feature dimension must be multiple of 4 for sincos emb"
|
||||
|
||||
omega = torch.arange(dim // 4) / (dim // 4 - 1)
|
||||
omega = temperature ** -omega
|
||||
|
||||
y = y.flatten()[:, None] * omega[None, :]
|
||||
x = x.flatten()[:, None] * omega[None, :]
|
||||
pos_emb = torch.cat((x.sin(), x.cos(), y.sin(), y.cos()), dim=1)
|
||||
|
||||
return pos_emb.type(dtype)
|
||||
|
||||
# classes
|
||||
|
||||
def FeedForward(dim, mult = 4.):
|
||||
hidden_dim = int(dim * mult)
|
||||
return nn.Sequential(
|
||||
nn.LayerNorm(dim),
|
||||
nn.Linear(dim, hidden_dim),
|
||||
nn.GELU(),
|
||||
nn.Linear(hidden_dim, dim),
|
||||
)
|
||||
|
||||
class Attention(Module):
|
||||
def __init__(self, dim, heads = 8, dim_head = 64):
|
||||
super().__init__()
|
||||
inner_dim = dim_head * heads
|
||||
self.heads = heads
|
||||
self.scale = dim_head ** -0.5
|
||||
self.norm = nn.LayerNorm(dim)
|
||||
|
||||
self.attend = nn.Softmax(dim = -1)
|
||||
|
||||
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
|
||||
self.to_out = nn.Linear(inner_dim, dim, bias = False)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.norm(x)
|
||||
|
||||
qkv = self.to_qkv(x).chunk(3, dim = -1)
|
||||
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)
|
||||
|
||||
dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
|
||||
|
||||
attn = self.attend(dots)
|
||||
|
||||
out = torch.matmul(attn, v)
|
||||
out = rearrange(out, 'b h n d -> b n (h d)')
|
||||
return self.to_out(out)
|
||||
|
||||
class JumboViT(Module):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
image_size,
|
||||
patch_size,
|
||||
num_classes,
|
||||
dim,
|
||||
depth,
|
||||
heads,
|
||||
mlp_dim,
|
||||
num_jumbo_cls = 1, # differing from paper, allow for multiple jumbo cls, so one could break it up into 2 jumbo cls tokens with 3x the dim, as an example
|
||||
jumbo_cls_k = 6, # they use a CLS token with this factor times the dimension - 6 was the value they settled on
|
||||
jumbo_ff_mult = 2, # expansion factor of the jumbo cls token feedforward
|
||||
channels = 3,
|
||||
dim_head = 64
|
||||
):
|
||||
super().__init__()
|
||||
image_height, image_width = pair(image_size)
|
||||
patch_height, patch_width = pair(patch_size)
|
||||
|
||||
assert divisible_by(image_height, patch_height) and divisible_by(image_width, patch_width), 'Image dimensions must be divisible by the patch size.'
|
||||
|
||||
patch_dim = channels * patch_height * patch_width
|
||||
|
||||
self.to_patch_embedding = nn.Sequential(
|
||||
Rearrange("b c (h p1) (w p2) -> b (h w) (p1 p2 c)", p1 = patch_height, p2 = patch_width),
|
||||
nn.LayerNorm(patch_dim),
|
||||
nn.Linear(patch_dim, dim),
|
||||
nn.LayerNorm(dim),
|
||||
)
|
||||
|
||||
self.pos_embedding = posemb_sincos_2d(
|
||||
h = image_height // patch_height,
|
||||
w = image_width // patch_width,
|
||||
dim = dim,
|
||||
)
|
||||
|
||||
jumbo_cls_dim = dim * jumbo_cls_k
|
||||
|
||||
self.jumbo_cls_token = nn.Parameter(torch.zeros(num_jumbo_cls, jumbo_cls_dim))
|
||||
|
||||
jumbo_cls_to_tokens = Rearrange('b n (k d) -> b (n k) d', k = jumbo_cls_k)
|
||||
self.jumbo_cls_to_tokens = jumbo_cls_to_tokens
|
||||
|
||||
self.norm = nn.LayerNorm(dim)
|
||||
self.layers = ModuleList([])
|
||||
|
||||
# attention and feedforwards
|
||||
|
||||
self.jumbo_ff = nn.Sequential(
|
||||
Rearrange('b (n k) d -> b n (k d)', k = jumbo_cls_k),
|
||||
FeedForward(jumbo_cls_dim, int(jumbo_cls_dim * jumbo_ff_mult)), # they use separate parameters for the jumbo feedforward, weight tied for parameter efficient
|
||||
jumbo_cls_to_tokens
|
||||
)
|
||||
|
||||
for _ in range(depth):
|
||||
self.layers.append(ModuleList([
|
||||
Attention(dim, heads = heads, dim_head = dim_head),
|
||||
FeedForward(dim, mlp_dim),
|
||||
]))
|
||||
|
||||
self.to_latent = nn.Identity()
|
||||
|
||||
self.linear_head = nn.Linear(dim, num_classes)
|
||||
|
||||
def forward(self, img):
|
||||
|
||||
batch, device = img.shape[0], img.device
|
||||
|
||||
x = self.to_patch_embedding(img)
|
||||
|
||||
# pos embedding
|
||||
|
||||
pos_emb = self.pos_embedding.to(device, dtype = x.dtype)
|
||||
|
||||
x = x + pos_emb
|
||||
|
||||
# add cls tokens
|
||||
|
||||
cls_tokens = repeat(self.jumbo_cls_token, 'nj d -> b nj d', b = batch)
|
||||
|
||||
jumbo_tokens = self.jumbo_cls_to_tokens(cls_tokens)
|
||||
|
||||
x, cls_packed_shape = pack([jumbo_tokens, x], 'b * d')
|
||||
|
||||
# attention and feedforwards
|
||||
|
||||
for layer, (attn, ff) in enumerate(self.layers, start = 1):
|
||||
is_last = layer == len(self.layers)
|
||||
|
||||
x = attn(x) + x
|
||||
|
||||
# jumbo feedforward
|
||||
|
||||
jumbo_cls_tokens, x = unpack(x, cls_packed_shape, 'b * d')
|
||||
|
||||
x = ff(x) + x
|
||||
jumbo_cls_tokens = self.jumbo_ff(jumbo_cls_tokens) + jumbo_cls_tokens
|
||||
|
||||
if is_last:
|
||||
continue
|
||||
|
||||
x, _ = pack([jumbo_cls_tokens, x], 'b * d')
|
||||
|
||||
pooled = reduce(jumbo_cls_tokens, 'b n d -> b d', 'mean')
|
||||
|
||||
# normalization and project to logits
|
||||
|
||||
embed = self.norm(pooled)
|
||||
|
||||
embed = self.to_latent(embed)
|
||||
logits = self.linear_head(embed)
|
||||
return logits
|
||||
|
||||
# copy pasteable file
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
v = JumboViT(
|
||||
num_classes = 1000,
|
||||
image_size = 64,
|
||||
patch_size = 8,
|
||||
dim = 16,
|
||||
depth = 2,
|
||||
heads = 2,
|
||||
mlp_dim = 32,
|
||||
jumbo_cls_k = 3,
|
||||
jumbo_ff_mult = 2,
|
||||
)
|
||||
|
||||
images = torch.randn(1, 3, 64, 64)
|
||||
|
||||
logits = v(images)
|
||||
assert logits.shape == (1, 1000)
|
||||
@@ -6,9 +6,6 @@ from functools import partial
|
||||
import torch
|
||||
import packaging.version as pkg_version
|
||||
|
||||
if pkg_version.parse(torch.__version__) < pkg_version.parse('2.4'):
|
||||
print('nested tensor NaViT was tested on pytorch 2.4')
|
||||
|
||||
from torch import nn, Tensor
|
||||
import torch.nn.functional as F
|
||||
from torch.nn import Module, ModuleList
|
||||
@@ -44,7 +41,7 @@ def FeedForward(dim, hidden_dim, dropout = 0.):
|
||||
)
|
||||
|
||||
class Attention(Module):
|
||||
def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
|
||||
def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0., qk_norm = True):
|
||||
super().__init__()
|
||||
self.norm = nn.LayerNorm(dim, bias = False)
|
||||
|
||||
@@ -59,8 +56,8 @@ class Attention(Module):
|
||||
# in the paper, they employ qk rmsnorm, a way to stabilize attention
|
||||
# will use layernorm in place of rmsnorm, which has been shown to work in certain papers. requires l2norm on non-ragged dimension to be supported in nested tensors
|
||||
|
||||
self.query_norm = nn.LayerNorm(dim_head, bias = False)
|
||||
self.key_norm = nn.LayerNorm(dim_head, bias = False)
|
||||
self.query_norm = nn.LayerNorm(dim_head, bias = False) if qk_norm else nn.Identity()
|
||||
self.key_norm = nn.LayerNorm(dim_head, bias = False) if qk_norm else nn.Identity()
|
||||
|
||||
self.dropout = dropout
|
||||
|
||||
@@ -114,13 +111,13 @@ class Attention(Module):
|
||||
return self.to_out(out)
|
||||
|
||||
class Transformer(Module):
|
||||
def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
|
||||
def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0., qk_norm = True):
|
||||
super().__init__()
|
||||
self.layers = ModuleList([])
|
||||
|
||||
for _ in range(depth):
|
||||
self.layers.append(ModuleList([
|
||||
Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout),
|
||||
Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout, qk_norm = qk_norm),
|
||||
FeedForward(dim, mlp_dim, dropout = dropout)
|
||||
]))
|
||||
|
||||
@@ -149,9 +146,15 @@ class NaViT(Module):
|
||||
dim_head = 64,
|
||||
dropout = 0.,
|
||||
emb_dropout = 0.,
|
||||
qk_rmsnorm = True,
|
||||
token_dropout_prob: float | None = None
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
if pkg_version.parse(torch.__version__) < pkg_version.parse('2.5'):
|
||||
print('nested tensor NaViT was tested on pytorch 2.5')
|
||||
|
||||
|
||||
image_height, image_width = pair(image_size)
|
||||
|
||||
# what percent of tokens to dropout
|
||||
@@ -182,7 +185,7 @@ class NaViT(Module):
|
||||
|
||||
self.dropout = nn.Dropout(emb_dropout)
|
||||
|
||||
self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)
|
||||
self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout, qk_rmsnorm)
|
||||
|
||||
# final attention pooling queries
|
||||
|
||||
@@ -323,3 +326,5 @@ if __name__ == '__main__':
|
||||
]
|
||||
|
||||
assert v(images).shape == (5, 1000)
|
||||
|
||||
v(images).sum().backward()
|
||||
|
||||
@@ -6,9 +6,6 @@ from functools import partial
|
||||
import torch
|
||||
import packaging.version as pkg_version
|
||||
|
||||
if pkg_version.parse(torch.__version__) < pkg_version.parse('2.4'):
|
||||
print('nested tensor NaViT was tested on pytorch 2.4')
|
||||
|
||||
from torch import nn, Tensor
|
||||
import torch.nn.functional as F
|
||||
from torch.nn import Module, ModuleList
|
||||
@@ -44,7 +41,7 @@ def FeedForward(dim, hidden_dim, dropout = 0.):
|
||||
)
|
||||
|
||||
class Attention(Module):
|
||||
def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
|
||||
def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0., qk_norm = True):
|
||||
super().__init__()
|
||||
self.norm = nn.LayerNorm(dim, bias = False)
|
||||
|
||||
@@ -59,8 +56,8 @@ class Attention(Module):
|
||||
# in the paper, they employ qk rmsnorm, a way to stabilize attention
|
||||
# will use layernorm in place of rmsnorm, which has been shown to work in certain papers. requires l2norm on non-ragged dimension to be supported in nested tensors
|
||||
|
||||
self.query_norm = nn.LayerNorm(dim_head, bias = False)
|
||||
self.key_norm = nn.LayerNorm(dim_head, bias = False)
|
||||
self.query_norm = nn.LayerNorm(dim_head, bias = False) if qk_norm else nn.Identity()
|
||||
self.key_norm = nn.LayerNorm(dim_head, bias = False) if qk_norm else nn.Identity()
|
||||
|
||||
self.dropout = dropout
|
||||
|
||||
@@ -86,17 +83,6 @@ class Attention(Module):
|
||||
|
||||
# split heads
|
||||
|
||||
def split_heads(t):
|
||||
return t.unflatten(-1, (self.heads, self.dim_head)).transpose(1, 2).contiguous()
|
||||
|
||||
# queries, keys, values
|
||||
|
||||
query = self.to_queries(x)
|
||||
key = self.to_keys(context)
|
||||
value = self.to_values(context)
|
||||
|
||||
# split heads
|
||||
|
||||
def split_heads(t):
|
||||
return t.unflatten(-1, (self.heads, self.dim_head))
|
||||
|
||||
@@ -126,13 +112,13 @@ class Attention(Module):
|
||||
return self.to_out(out)
|
||||
|
||||
class Transformer(Module):
|
||||
def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
|
||||
def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0., qk_norm = True):
|
||||
super().__init__()
|
||||
self.layers = ModuleList([])
|
||||
|
||||
for _ in range(depth):
|
||||
self.layers.append(ModuleList([
|
||||
Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout),
|
||||
Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout, qk_norm = qk_norm),
|
||||
FeedForward(dim, mlp_dim, dropout = dropout)
|
||||
]))
|
||||
|
||||
@@ -164,11 +150,15 @@ class NaViT(Module):
|
||||
dropout = 0.,
|
||||
emb_dropout = 0.,
|
||||
num_registers = 4,
|
||||
qk_rmsnorm = True,
|
||||
token_dropout_prob: float | None = None
|
||||
):
|
||||
super().__init__()
|
||||
image_height, image_width = pair(image_size)
|
||||
|
||||
if pkg_version.parse(torch.__version__) < pkg_version.parse('2.5'):
|
||||
print('nested tensor NaViT was tested on pytorch 2.5')
|
||||
|
||||
# what percent of tokens to dropout
|
||||
# if int or float given, then assume constant dropout prob
|
||||
# otherwise accept a callback that in turn calculates dropout prob from height and width
|
||||
@@ -209,7 +199,7 @@ class NaViT(Module):
|
||||
|
||||
self.dropout = nn.Dropout(emb_dropout)
|
||||
|
||||
self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)
|
||||
self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout, qk_rmsnorm)
|
||||
|
||||
# final attention pooling queries
|
||||
|
||||
@@ -336,7 +326,7 @@ class NaViT(Module):
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
# works for torch 2.4
|
||||
# works for torch 2.5
|
||||
|
||||
v = NaViT(
|
||||
image_size = 256,
|
||||
@@ -362,3 +352,5 @@ if __name__ == '__main__':
|
||||
]
|
||||
|
||||
assert v(volumes).shape == (5, 1000)
|
||||
|
||||
v(volumes).sum().backward()
|
||||
|
||||
@@ -76,7 +76,8 @@ class Attention(Module):
|
||||
|
||||
self.dropout = dropout
|
||||
|
||||
self.qk_scale = nn.Parameter(torch.ones(dim_head) * (dim_head ** 0.25))
|
||||
self.q_scale = nn.Parameter(torch.ones(heads, 1, dim_head) * (dim_head ** 0.25))
|
||||
self.k_scale = nn.Parameter(torch.ones(heads, 1, dim_head) * (dim_head ** 0.25))
|
||||
|
||||
self.split_heads = Rearrange('b n (h d) -> b h n d', h = heads)
|
||||
self.merge_heads = Rearrange('b h n d -> b n (h d)')
|
||||
@@ -94,7 +95,9 @@ class Attention(Module):
|
||||
# query key rmsnorm
|
||||
|
||||
q, k = map(l2norm, (q, k))
|
||||
q, k = (q * self.qk_scale), (k * self.qk_scale)
|
||||
|
||||
q = q * self.q_scale
|
||||
k = k * self.k_scale
|
||||
|
||||
# scale is 1., as scaling factor is moved to s_qk (dk ^ 0.25) - eq. 16
|
||||
|
||||
@@ -176,18 +179,18 @@ class nViT(Module):
|
||||
|
||||
self.to_patch_embedding = nn.Sequential(
|
||||
Rearrange('b c (h p1) (w p2) -> b (h w) (c p1 p2)', p1 = patch_size, p2 = patch_size),
|
||||
nn.LayerNorm(patch_dim),
|
||||
nn.Linear(patch_dim, dim),
|
||||
nn.LayerNorm(dim),
|
||||
NormLinear(patch_dim, dim, norm_dim_in = False),
|
||||
)
|
||||
|
||||
self.abs_pos_emb = nn.Embedding(num_patches, dim)
|
||||
self.abs_pos_emb = NormLinear(dim, num_patches)
|
||||
|
||||
residual_lerp_scale_init = default(residual_lerp_scale_init, 1. / depth)
|
||||
|
||||
# layers
|
||||
|
||||
self.dim = dim
|
||||
self.scale = dim ** 0.5
|
||||
|
||||
self.layers = ModuleList([])
|
||||
self.residual_lerp_scales = nn.ParameterList([])
|
||||
|
||||
@@ -198,8 +201,8 @@ class nViT(Module):
|
||||
]))
|
||||
|
||||
self.residual_lerp_scales.append(nn.ParameterList([
|
||||
nn.Parameter(torch.ones(dim) * residual_lerp_scale_init),
|
||||
nn.Parameter(torch.ones(dim) * residual_lerp_scale_init),
|
||||
nn.Parameter(torch.ones(dim) * residual_lerp_scale_init / self.scale),
|
||||
nn.Parameter(torch.ones(dim) * residual_lerp_scale_init / self.scale),
|
||||
]))
|
||||
|
||||
self.logit_scale = nn.Parameter(torch.ones(num_classes))
|
||||
@@ -222,22 +225,23 @@ class nViT(Module):
|
||||
|
||||
tokens = self.to_patch_embedding(images)
|
||||
|
||||
pos_emb = self.abs_pos_emb(torch.arange(tokens.shape[-2], device = device))
|
||||
seq_len = tokens.shape[-2]
|
||||
pos_emb = self.abs_pos_emb.weight[torch.arange(seq_len, device = device)]
|
||||
|
||||
tokens = l2norm(tokens + pos_emb)
|
||||
|
||||
for (attn, ff), (attn_alpha, ff_alpha) in zip(self.layers, self.residual_lerp_scales):
|
||||
|
||||
attn_out = l2norm(attn(tokens))
|
||||
tokens = l2norm(tokens.lerp(attn_out, attn_alpha))
|
||||
tokens = l2norm(tokens.lerp(attn_out, attn_alpha * self.scale))
|
||||
|
||||
ff_out = l2norm(ff(tokens))
|
||||
tokens = l2norm(tokens.lerp(ff_out, ff_alpha))
|
||||
tokens = l2norm(tokens.lerp(ff_out, ff_alpha * self.scale))
|
||||
|
||||
pooled = reduce(tokens, 'b n d -> b d', 'mean')
|
||||
|
||||
logits = self.to_pred(pooled)
|
||||
logits = logits * self.logit_scale * (self.dim ** 0.5)
|
||||
logits = logits * self.logit_scale * self.scale
|
||||
|
||||
return logits
|
||||
|
||||
|
||||
233
vit_pytorch/simple_vit_with_hyper_connections.py
Normal file
233
vit_pytorch/simple_vit_with_hyper_connections.py
Normal file
@@ -0,0 +1,233 @@
|
||||
"""
|
||||
ViT + Hyper-Connections + Register Tokens
|
||||
https://arxiv.org/abs/2409.19606
|
||||
"""
|
||||
|
||||
import torch
|
||||
from torch import nn, tensor
|
||||
from torch.nn import Module, ModuleList
|
||||
|
||||
from einops import rearrange, repeat, reduce, einsum, pack, unpack
|
||||
from einops.layers.torch import Rearrange
|
||||
|
||||
# b - batch, h - heads, n - sequence, e - expansion rate / residual streams, d - feature dimension
|
||||
|
||||
# helpers
|
||||
|
||||
def pair(t):
|
||||
return t if isinstance(t, tuple) else (t, t)
|
||||
|
||||
def posemb_sincos_2d(h, w, dim, temperature: int = 10000, dtype = torch.float32):
|
||||
y, x = torch.meshgrid(torch.arange(h), torch.arange(w), indexing="ij")
|
||||
assert (dim % 4) == 0, "feature dimension must be multiple of 4 for sincos emb"
|
||||
omega = torch.arange(dim // 4) / (dim // 4 - 1)
|
||||
omega = 1.0 / (temperature ** omega)
|
||||
|
||||
y = y.flatten()[:, None] * omega[None, :]
|
||||
x = x.flatten()[:, None] * omega[None, :]
|
||||
pe = torch.cat((x.sin(), x.cos(), y.sin(), y.cos()), dim=1)
|
||||
return pe.type(dtype)
|
||||
|
||||
# hyper connections
|
||||
|
||||
class HyperConnection(Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim,
|
||||
num_residual_streams,
|
||||
layer_index
|
||||
):
|
||||
""" Appendix J - Algorithm 2, Dynamic only """
|
||||
super().__init__()
|
||||
|
||||
self.norm = nn.LayerNorm(dim, bias = False)
|
||||
|
||||
self.num_residual_streams = num_residual_streams
|
||||
self.layer_index = layer_index
|
||||
|
||||
self.static_beta = nn.Parameter(torch.ones(num_residual_streams))
|
||||
|
||||
init_alpha0 = torch.zeros((num_residual_streams, 1))
|
||||
init_alpha0[layer_index % num_residual_streams, 0] = 1.
|
||||
|
||||
self.static_alpha = nn.Parameter(torch.cat([init_alpha0, torch.eye(num_residual_streams)], dim = 1))
|
||||
|
||||
self.dynamic_alpha_fn = nn.Parameter(torch.zeros(dim, num_residual_streams + 1))
|
||||
self.dynamic_alpha_scale = nn.Parameter(tensor(1e-2))
|
||||
self.dynamic_beta_fn = nn.Parameter(torch.zeros(dim))
|
||||
self.dynamic_beta_scale = nn.Parameter(tensor(1e-2))
|
||||
|
||||
def width_connection(self, residuals):
|
||||
normed = self.norm(residuals)
|
||||
|
||||
wc_weight = (normed @ self.dynamic_alpha_fn).tanh()
|
||||
dynamic_alpha = wc_weight * self.dynamic_alpha_scale
|
||||
alpha = dynamic_alpha + self.static_alpha
|
||||
|
||||
dc_weight = (normed @ self.dynamic_beta_fn).tanh()
|
||||
dynamic_beta = dc_weight * self.dynamic_beta_scale
|
||||
beta = dynamic_beta + self.static_beta
|
||||
|
||||
# width connection
|
||||
mix_h = einsum(alpha, residuals, '... e1 e2, ... e1 d -> ... e2 d')
|
||||
|
||||
branch_input, residuals = mix_h[..., 0, :], mix_h[..., 1:, :]
|
||||
|
||||
return branch_input, residuals, beta
|
||||
|
||||
def depth_connection(
|
||||
self,
|
||||
branch_output,
|
||||
residuals,
|
||||
beta
|
||||
):
|
||||
return einsum(branch_output, beta, "b n d, b n e -> b n e d") + residuals
|
||||
|
||||
# classes
|
||||
|
||||
class FeedForward(Module):
|
||||
def __init__(self, dim, hidden_dim):
|
||||
super().__init__()
|
||||
self.net = nn.Sequential(
|
||||
nn.LayerNorm(dim),
|
||||
nn.Linear(dim, hidden_dim),
|
||||
nn.GELU(),
|
||||
nn.Linear(hidden_dim, dim),
|
||||
)
|
||||
def forward(self, x):
|
||||
return self.net(x)
|
||||
|
||||
class Attention(Module):
|
||||
def __init__(self, dim, heads = 8, dim_head = 64):
|
||||
super().__init__()
|
||||
inner_dim = dim_head * heads
|
||||
self.heads = heads
|
||||
self.scale = dim_head ** -0.5
|
||||
self.norm = nn.LayerNorm(dim)
|
||||
|
||||
self.attend = nn.Softmax(dim = -1)
|
||||
|
||||
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
|
||||
self.to_out = nn.Linear(inner_dim, dim, bias = False)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.norm(x)
|
||||
|
||||
qkv = self.to_qkv(x).chunk(3, dim = -1)
|
||||
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)
|
||||
|
||||
dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
|
||||
|
||||
attn = self.attend(dots)
|
||||
|
||||
out = torch.matmul(attn, v)
|
||||
out = rearrange(out, 'b h n d -> b n (h d)')
|
||||
return self.to_out(out)
|
||||
|
||||
class Transformer(Module):
|
||||
def __init__(self, dim, depth, heads, dim_head, mlp_dim, num_residual_streams):
|
||||
super().__init__()
|
||||
|
||||
self.num_residual_streams = num_residual_streams
|
||||
|
||||
self.norm = nn.LayerNorm(dim)
|
||||
self.layers = ModuleList([])
|
||||
|
||||
for layer_index in range(depth):
|
||||
self.layers.append(nn.ModuleList([
|
||||
HyperConnection(dim, num_residual_streams, layer_index),
|
||||
Attention(dim, heads = heads, dim_head = dim_head),
|
||||
HyperConnection(dim, num_residual_streams, layer_index),
|
||||
FeedForward(dim, mlp_dim)
|
||||
]))
|
||||
|
||||
def forward(self, x):
|
||||
|
||||
x = repeat(x, 'b n d -> b n e d', e = self.num_residual_streams)
|
||||
|
||||
for attn_hyper_conn, attn, ff_hyper_conn, ff in self.layers:
|
||||
|
||||
x, attn_res, beta = attn_hyper_conn.width_connection(x)
|
||||
|
||||
x = attn(x)
|
||||
|
||||
x = attn_hyper_conn.depth_connection(x, attn_res, beta)
|
||||
|
||||
x, ff_res, beta = ff_hyper_conn.width_connection(x)
|
||||
|
||||
x = ff(x)
|
||||
|
||||
x = ff_hyper_conn.depth_connection(x, ff_res, beta)
|
||||
|
||||
x = reduce(x, 'b n e d -> b n d', 'sum')
|
||||
|
||||
return self.norm(x)
|
||||
|
||||
class SimpleViT(nn.Module):
|
||||
def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, num_residual_streams, num_register_tokens = 4, channels = 3, dim_head = 64):
|
||||
super().__init__()
|
||||
image_height, image_width = pair(image_size)
|
||||
patch_height, patch_width = pair(patch_size)
|
||||
|
||||
assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'
|
||||
|
||||
patch_dim = channels * patch_height * patch_width
|
||||
|
||||
self.to_patch_embedding = nn.Sequential(
|
||||
Rearrange("b c (h p1) (w p2) -> b (h w) (p1 p2 c)", p1 = patch_height, p2 = patch_width),
|
||||
nn.LayerNorm(patch_dim),
|
||||
nn.Linear(patch_dim, dim),
|
||||
nn.LayerNorm(dim),
|
||||
)
|
||||
|
||||
self.register_tokens = nn.Parameter(torch.randn(num_register_tokens, dim))
|
||||
|
||||
self.pos_embedding = posemb_sincos_2d(
|
||||
h = image_height // patch_height,
|
||||
w = image_width // patch_width,
|
||||
dim = dim,
|
||||
)
|
||||
|
||||
self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, num_residual_streams)
|
||||
|
||||
self.pool = "mean"
|
||||
self.to_latent = nn.Identity()
|
||||
|
||||
self.linear_head = nn.Linear(dim, num_classes)
|
||||
|
||||
def forward(self, img):
|
||||
batch, device = img.shape[0], img.device
|
||||
|
||||
x = self.to_patch_embedding(img)
|
||||
x += self.pos_embedding.to(x)
|
||||
|
||||
r = repeat(self.register_tokens, 'n d -> b n d', b = batch)
|
||||
|
||||
x, ps = pack([x, r], 'b * d')
|
||||
|
||||
x = self.transformer(x)
|
||||
|
||||
x, _ = unpack(x, ps, 'b * d')
|
||||
|
||||
x = x.mean(dim = 1)
|
||||
|
||||
x = self.to_latent(x)
|
||||
return self.linear_head(x)
|
||||
|
||||
# main
|
||||
|
||||
if __name__ == '__main__':
|
||||
vit = SimpleViT(
|
||||
num_classes = 1000,
|
||||
image_size = 256,
|
||||
patch_size = 8,
|
||||
dim = 1024,
|
||||
depth = 12,
|
||||
heads = 8,
|
||||
mlp_dim = 2048,
|
||||
num_residual_streams = 8
|
||||
)
|
||||
|
||||
images = torch.randn(3, 3, 256, 256)
|
||||
|
||||
logits = vit(images)
|
||||
159
vit_pytorch/simple_vit_with_value_residual.py
Normal file
159
vit_pytorch/simple_vit_with_value_residual.py
Normal file
@@ -0,0 +1,159 @@
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.nn import Module, ModuleList
|
||||
|
||||
from einops import rearrange
|
||||
from einops.layers.torch import Rearrange
|
||||
|
||||
# helpers
|
||||
|
||||
def exists(v):
|
||||
return v is not None
|
||||
|
||||
def default(v, d):
|
||||
return v if exists(v) else d
|
||||
|
||||
def pair(t):
|
||||
return t if isinstance(t, tuple) else (t, t)
|
||||
|
||||
def posemb_sincos_2d(h, w, dim, temperature: int = 10000, dtype = torch.float32):
|
||||
y, x = torch.meshgrid(torch.arange(h), torch.arange(w), indexing="ij")
|
||||
assert (dim % 4) == 0, "feature dimension must be multiple of 4 for sincos emb"
|
||||
omega = torch.arange(dim // 4) / (dim // 4 - 1)
|
||||
omega = 1.0 / (temperature ** omega)
|
||||
|
||||
y = y.flatten()[:, None] * omega[None, :]
|
||||
x = x.flatten()[:, None] * omega[None, :]
|
||||
pe = torch.cat((x.sin(), x.cos(), y.sin(), y.cos()), dim=1)
|
||||
return pe.type(dtype)
|
||||
|
||||
# classes
|
||||
|
||||
def FeedForward(dim, hidden_dim):
|
||||
return nn.Sequential(
|
||||
nn.LayerNorm(dim),
|
||||
nn.Linear(dim, hidden_dim),
|
||||
nn.GELU(),
|
||||
nn.Linear(hidden_dim, dim),
|
||||
)
|
||||
|
||||
class Attention(Module):
|
||||
def __init__(self, dim, heads = 8, dim_head = 64, learned_value_residual_mix = False):
|
||||
super().__init__()
|
||||
inner_dim = dim_head * heads
|
||||
self.heads = heads
|
||||
self.scale = dim_head ** -0.5
|
||||
self.norm = nn.LayerNorm(dim)
|
||||
|
||||
self.attend = nn.Softmax(dim = -1)
|
||||
|
||||
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
|
||||
self.to_out = nn.Linear(inner_dim, dim, bias = False)
|
||||
|
||||
self.to_residual_mix = nn.Sequential(
|
||||
nn.Linear(dim, heads),
|
||||
nn.Sigmoid(),
|
||||
Rearrange('b n h -> b h n 1')
|
||||
) if learned_value_residual_mix else (lambda _: 0.5)
|
||||
|
||||
def forward(self, x, value_residual = None):
|
||||
x = self.norm(x)
|
||||
|
||||
qkv = self.to_qkv(x).chunk(3, dim = -1)
|
||||
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)
|
||||
|
||||
if exists(value_residual):
|
||||
mix = self.to_residual_mix(x)
|
||||
v = v * mix + value_residual * (1. - mix)
|
||||
|
||||
dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
|
||||
|
||||
attn = self.attend(dots)
|
||||
|
||||
out = torch.matmul(attn, v)
|
||||
out = rearrange(out, 'b h n d -> b n (h d)')
|
||||
|
||||
return self.to_out(out), v
|
||||
|
||||
class Transformer(Module):
|
||||
def __init__(self, dim, depth, heads, dim_head, mlp_dim):
|
||||
super().__init__()
|
||||
self.norm = nn.LayerNorm(dim)
|
||||
self.layers = ModuleList([])
|
||||
for i in range(depth):
|
||||
is_first = i == 0
|
||||
self.layers.append(ModuleList([
|
||||
Attention(dim, heads = heads, dim_head = dim_head, learned_value_residual_mix = not is_first),
|
||||
FeedForward(dim, mlp_dim)
|
||||
]))
|
||||
def forward(self, x):
|
||||
value_residual = None
|
||||
|
||||
for attn, ff in self.layers:
|
||||
|
||||
attn_out, values = attn(x, value_residual = value_residual)
|
||||
value_residual = default(value_residual, values)
|
||||
|
||||
x = attn_out + x
|
||||
x = ff(x) + x
|
||||
|
||||
return self.norm(x)
|
||||
|
||||
class SimpleViT(Module):
|
||||
def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, channels = 3, dim_head = 64):
|
||||
super().__init__()
|
||||
image_height, image_width = pair(image_size)
|
||||
patch_height, patch_width = pair(patch_size)
|
||||
|
||||
assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'
|
||||
|
||||
patch_dim = channels * patch_height * patch_width
|
||||
|
||||
self.to_patch_embedding = nn.Sequential(
|
||||
Rearrange("b c (h p1) (w p2) -> b (h w) (p1 p2 c)", p1 = patch_height, p2 = patch_width),
|
||||
nn.LayerNorm(patch_dim),
|
||||
nn.Linear(patch_dim, dim),
|
||||
nn.LayerNorm(dim),
|
||||
)
|
||||
|
||||
self.pos_embedding = posemb_sincos_2d(
|
||||
h = image_height // patch_height,
|
||||
w = image_width // patch_width,
|
||||
dim = dim,
|
||||
)
|
||||
|
||||
self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim)
|
||||
|
||||
self.pool = "mean"
|
||||
self.to_latent = nn.Identity()
|
||||
|
||||
self.linear_head = nn.Linear(dim, num_classes)
|
||||
|
||||
def forward(self, img):
|
||||
device = img.device
|
||||
|
||||
x = self.to_patch_embedding(img)
|
||||
x += self.pos_embedding.to(device, dtype=x.dtype)
|
||||
|
||||
x = self.transformer(x)
|
||||
x = x.mean(dim = 1)
|
||||
|
||||
x = self.to_latent(x)
|
||||
return self.linear_head(x)
|
||||
|
||||
# quick test
|
||||
|
||||
if __name__ == '__main__':
|
||||
v = SimpleViT(
|
||||
num_classes = 1000,
|
||||
image_size = 256,
|
||||
patch_size = 8,
|
||||
dim = 1024,
|
||||
depth = 6,
|
||||
heads = 8,
|
||||
mlp_dim = 2048,
|
||||
)
|
||||
|
||||
images = torch.randn(2, 3, 256, 256)
|
||||
|
||||
logits = v(images)
|
||||
Reference in New Issue
Block a user