Compare commits

...

9 Commits
1.6.9 ... 1.7.6

Author SHA1 Message Date
Phil Wang
771fb6daaf Nested navit (#325)
add a variant of NaViT using nested tensors
2024-08-20 15:07:20 -07:00
Phil Wang
4f22eae631 1.7.5 2024-08-07 08:46:18 -07:00
Phil Wang
dfc8df6713 add the u-vit implementation with simple vit + register tokens 2024-08-07 08:45:57 -07:00
lucidrains
9992a615d1 attention re-use in lookup vit should use pre-softmax attention matrix 2024-07-19 19:23:38 -07:00
Phil Wang
4b2c00cb63 when cross attending in look vit, make sure context tokens are normalized 2024-07-19 10:23:12 -07:00
Phil Wang
ec6c48b8ff norm not needed when reusing attention in lookvit 2024-07-19 10:00:03 -07:00
Phil Wang
547bf94d07 1.7.1 2024-07-19 09:49:44 -07:00
Phil Wang
bd72b58355 add lookup vit, cite, document later 2024-07-19 09:48:58 -07:00
lucidrains
e3256d77cd fix t2t vit having two layernorms, and make final layernorm in distillation wrapper configurable, default to False for vit 2024-06-11 15:12:53 -07:00
8 changed files with 862 additions and 16 deletions

README.md

@@ -198,6 +198,38 @@ preds = v(
) # (5, 1000)
```
Finally, if you would like to make use of a flavor of NaViT using <a href="https://pytorch.org/tutorials/prototype/nestedtensor.html">nested tensors</a> (which will omit a lot of the masking and padding altogether), make sure you are on PyTorch version `2.4` or greater and import as follows
```python
import torch
from vit_pytorch.na_vit_nested_tensor import NaViT

v = NaViT(
    image_size = 256,
    patch_size = 32,
    num_classes = 1000,
    dim = 1024,
    depth = 6,
    heads = 16,
    mlp_dim = 2048,
    dropout = 0.,
    emb_dropout = 0.,
    token_dropout_prob = 0.1
)

# 5 images of different resolutions - List[Tensor]

images = [
    torch.randn(3, 256, 256), torch.randn(3, 128, 128),
    torch.randn(3, 128, 256), torch.randn(3, 256, 128),
    torch.randn(3, 64, 256)
]

preds = v(images)

assert preds.shape == (5, 1000)
```
## Distillation
<img src="./images/distill.png" width="300px"></img>
@@ -2072,4 +2104,31 @@ Coming from computer vision and new to transformers? Here are some resources tha
}
```
```bibtex
@inproceedings{Koner2024LookupViTCV,
title = {LookupViT: Compressing visual information to a limited number of tokens},
author = {Rajat Koner and Gagan Jain and Prateek Jain and Volker Tresp and Sujoy Paul},
year = {2024},
url = {https://api.semanticscholar.org/CorpusID:271244592}
}
```
```bibtex
@article{Bao2022AllAW,
title = {All are Worth Words: A ViT Backbone for Diffusion Models},
author = {Fan Bao and Shen Nie and Kaiwen Xue and Yue Cao and Chongxuan Li and Hang Su and Jun Zhu},
journal = {2023 IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
year = {2022},
pages = {22669-22679},
url = {https://api.semanticscholar.org/CorpusID:253581703}
}
```
```bibtex
@misc{Rubin2024,
author = {Ohad Rubin},
url = {https://medium.com/@ohadrubin/exploring-weight-decay-in-layer-normalization-challenges-and-a-reparameterization-solution-ad4d12c24950}
}
```
*I visualise a time when we will be to robots what dogs are to humans, and I'm rooting for the machines.* — Claude Shannon

setup.py

@@ -6,7 +6,7 @@ with open('README.md') as f:
setup(
name = 'vit-pytorch',
packages = find_packages(exclude=['examples']),
version = '1.6.9',
version = '1.7.6',
license='MIT',
description = 'Vision Transformer (ViT) - Pytorch',
long_description=long_description,

vit_pytorch/distill.py

@@ -1,6 +1,8 @@
import torch
import torch.nn.functional as F
from torch import nn
from torch.nn import Module
import torch.nn.functional as F
from vit_pytorch.vit import ViT
from vit_pytorch.t2t import T2TViT
from vit_pytorch.efficient import ViT as EfficientViT
@@ -12,6 +14,9 @@ from einops import rearrange, repeat
def exists(val):
return val is not None
def default(val, d):
return val if exists(val) else d
# classes
class DistillMixin:
@@ -20,12 +25,12 @@ class DistillMixin:
x = self.to_patch_embedding(img)
b, n, _ = x.shape
cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b)
cls_tokens = repeat(self.cls_token, '1 n d -> b n d', b = b)
x = torch.cat((cls_tokens, x), dim = 1)
x += self.pos_embedding[:, :(n + 1)]
if distilling:
distill_tokens = repeat(distill_token, '() n d -> b n d', b = b)
distill_tokens = repeat(distill_token, '1 n d -> b n d', b = b)
x = torch.cat((x, distill_tokens), dim = 1)
x = self._attend(x)
@@ -97,7 +102,7 @@ class DistillableEfficientViT(DistillMixin, EfficientViT):
# knowledge distillation wrapper
class DistillWrapper(nn.Module):
class DistillWrapper(Module):
def __init__(
self,
*,
@@ -105,7 +110,8 @@ class DistillWrapper(nn.Module):
student,
temperature = 1.,
alpha = 0.5,
hard = False
hard = False,
mlp_layernorm = False
):
super().__init__()
assert (isinstance(student, (DistillableViT, DistillableT2TViT, DistillableEfficientViT))) , 'student must be a vision transformer'
@@ -122,14 +128,14 @@ class DistillWrapper(nn.Module):
self.distillation_token = nn.Parameter(torch.randn(1, 1, dim))
self.distill_mlp = nn.Sequential(
nn.LayerNorm(dim),
nn.LayerNorm(dim) if mlp_layernorm else nn.Identity(),
nn.Linear(dim, num_classes)
)
def forward(self, img, labels, temperature = None, alpha = None, **kwargs):
b, *_ = img.shape
alpha = alpha if exists(alpha) else self.alpha
T = temperature if exists(temperature) else self.temperature
alpha = default(alpha, self.alpha)
T = default(temperature, self.temperature)
with torch.no_grad():
teacher_logits = self.teacher(img)
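For orientation, here is roughly how the new `mlp_layernorm` flag slots into the distillation example from this repository's README (the ResNet-50 teacher and the hyperparameters below are lifted from that example and are illustrative only):

```python
import torch
from torchvision.models import resnet50

from vit_pytorch.distill import DistillableViT, DistillWrapper

teacher = resnet50(pretrained = True)

student = DistillableViT(
    image_size = 256,
    patch_size = 32,
    num_classes = 1000,
    dim = 1024,
    depth = 6,
    heads = 8,
    mlp_dim = 2048,
    dropout = 0.1,
    emb_dropout = 0.1
)

distiller = DistillWrapper(
    student = student,
    teacher = teacher,
    temperature = 3,           # distillation temperature
    alpha = 0.5,               # weight given to the distillation loss
    hard = False,              # soft distillation
    mlp_layernorm = False      # new flag - defaults to False, i.e. no layernorm before the distillation head
)

img = torch.randn(2, 3, 256, 256)
labels = torch.randint(0, 1000, (2,))

loss = distiller(img, labels)
loss.backward()
```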

vit_pytorch/look_vit.py (new file, 278 lines)

@@ -0,0 +1,278 @@
import torch
from torch import nn
import torch.nn.functional as F
from torch.nn import Module, ModuleList
from einops import einsum, rearrange, repeat, reduce
from einops.layers.torch import Rearrange
# helpers
def exists(val):
return val is not None
def default(val, d):
return val if exists(val) else d
def divisible_by(num, den):
return (num % den) == 0
# simple vit sinusoidal pos emb
def posemb_sincos_2d(t, temperature = 10000):
h, w, d, device = *t.shape[1:], t.device
y, x = torch.meshgrid(torch.arange(h, device = device), torch.arange(w, device = device), indexing = 'ij')
assert (d % 4) == 0, "feature dimension must be multiple of 4 for sincos emb"
omega = torch.arange(d // 4, device = device) / (d // 4 - 1)
omega = temperature ** -omega
y = y.flatten()[:, None] * omega[None, :]
x = x.flatten()[:, None] * omega[None, :]
pos = torch.cat((x.sin(), x.cos(), y.sin(), y.cos()), dim = 1)
return pos.float()
# bias-less layernorm with unit offset trick (discovered by Ohad Rubin)
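# gamma starts at zero and the output is scaled by (gamma + 1), so weight decay pulls the scale toward the identity rather than toward zero (see the Ohad Rubin reference in the README)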
class LayerNorm(Module):
def __init__(self, dim):
super().__init__()
self.ln = nn.LayerNorm(dim, elementwise_affine = False)
self.gamma = nn.Parameter(torch.zeros(dim))
def forward(self, x):
normed = self.ln(x)
return normed * (self.gamma + 1)
# mlp
def MLP(dim, factor = 4, dropout = 0.):
hidden_dim = int(dim * factor)
return nn.Sequential(
LayerNorm(dim),
nn.Linear(dim, hidden_dim),
nn.GELU(),
nn.Dropout(dropout),
nn.Linear(hidden_dim, dim),
nn.Dropout(dropout)
)
# attention
class Attention(Module):
def __init__(
self,
dim,
heads = 8,
dim_head = 64,
dropout = 0.,
cross_attend = False,
reuse_attention = False
):
super().__init__()
inner_dim = dim_head * heads
self.scale = dim_head ** -0.5
self.heads = heads
self.reuse_attention = reuse_attention
self.cross_attend = cross_attend
self.split_heads = Rearrange('b n (h d) -> b h n d', h = heads)
self.norm = LayerNorm(dim) if not reuse_attention else nn.Identity()
self.norm_context = LayerNorm(dim) if cross_attend else nn.Identity()
self.attend = nn.Softmax(dim = -1)
self.dropout = nn.Dropout(dropout)
self.to_q = nn.Linear(dim, inner_dim, bias = False) if not reuse_attention else None
self.to_k = nn.Linear(dim, inner_dim, bias = False) if not reuse_attention else None
self.to_v = nn.Linear(dim, inner_dim, bias = False)
self.to_out = nn.Sequential(
Rearrange('b h n d -> b n (h d)'),
nn.Linear(inner_dim, dim, bias = False),
nn.Dropout(dropout)
)
def forward(
self,
x,
context = None,
return_qk_sim = False,
qk_sim = None
):
x = self.norm(x)
assert not (exists(context) ^ self.cross_attend)
if self.cross_attend:
context = self.norm_context(context)
else:
context = x
v = self.to_v(context)
v = self.split_heads(v)
if not self.reuse_attention:
qk = (self.to_q(x), self.to_k(context))
q, k = tuple(self.split_heads(t) for t in qk)
q = q * self.scale
qk_sim = einsum(q, k, 'b h i d, b h j d -> b h i j')
else:
assert exists(qk_sim), 'qk sim matrix must be passed in for reusing previous attention'
attn = self.attend(qk_sim)
attn = self.dropout(attn)
out = einsum(attn, v, 'b h i j, b h j d -> b h i d')
out = self.to_out(out)
if not return_qk_sim:
return out
return out, qk_sim
# LookViT
class LookViT(Module):
def __init__(
self,
*,
dim,
image_size,
num_classes,
depth = 3,
patch_size = 16,
heads = 8,
mlp_factor = 4,
dim_head = 64,
highres_patch_size = 12,
highres_mlp_factor = 4,
cross_attn_heads = 8,
cross_attn_dim_head = 64,
patch_conv_kernel_size = 7,
dropout = 0.1,
channels = 3
):
super().__init__()
assert divisible_by(image_size, highres_patch_size)
assert divisible_by(image_size, patch_size)
assert patch_size > highres_patch_size, 'highres patch size (that does the `lookup`) should be smaller than the patch size of the main vision transformer'
assert not divisible_by(patch_conv_kernel_size, 2)
self.dim = dim
self.image_size = image_size
self.patch_size = patch_size
kernel_size = patch_conv_kernel_size
patch_dim = (highres_patch_size * highres_patch_size) * channels
self.to_patches = nn.Sequential(
Rearrange('b c (h p1) (w p2) -> b (p1 p2 c) h w', p1 = highres_patch_size, p2 = highres_patch_size),
nn.Conv2d(patch_dim, dim, kernel_size, padding = kernel_size // 2),
Rearrange('b c h w -> b h w c'),
LayerNorm(dim),
)
# absolute positions
num_patches = (image_size // highres_patch_size) ** 2
self.pos_embedding = nn.Parameter(torch.randn(num_patches, dim))
# lookvit blocks
layers = ModuleList([])
for _ in range(depth):
layers.append(ModuleList([
Attention(dim = dim, dim_head = dim_head, heads = heads, dropout = dropout),
MLP(dim = dim, factor = mlp_factor, dropout = dropout),
Attention(dim = dim, dim_head = cross_attn_dim_head, heads = cross_attn_heads, dropout = dropout, cross_attend = True),
Attention(dim = dim, dim_head = cross_attn_dim_head, heads = cross_attn_heads, dropout = dropout, cross_attend = True, reuse_attention = True),
LayerNorm(dim),
MLP(dim = dim, factor = highres_mlp_factor, dropout = dropout)
]))
self.layers = layers
self.norm = LayerNorm(dim)
self.highres_norm = LayerNorm(dim)
self.to_logits = nn.Linear(dim, num_classes, bias = False)
def forward(self, img):
assert img.shape[-2:] == (self.image_size, self.image_size)
# to patch tokens and positions
highres_tokens = self.to_patches(img)
size = highres_tokens.shape[-2]
pos_emb = posemb_sincos_2d(highres_tokens)
highres_tokens = highres_tokens + rearrange(pos_emb, '(h w) d -> h w d', h = size)
tokens = F.interpolate(
rearrange(highres_tokens, 'b h w d -> b d h w'),
img.shape[-1] // self.patch_size,
mode = 'bilinear'
)
tokens = rearrange(tokens, 'b c h w -> b (h w) c')
highres_tokens = rearrange(highres_tokens, 'b h w c -> b (h w) c')
# attention and feedforwards
for attn, mlp, lookup_cross_attn, highres_attn, highres_norm, highres_mlp in self.layers:
# main tokens cross attends (lookup) on the high res tokens
lookup_out, qk_sim = lookup_cross_attn(tokens, highres_tokens, return_qk_sim = True) # return attention as they reuse the attention matrix
tokens = lookup_out + tokens
tokens = attn(tokens) + tokens
tokens = mlp(tokens) + tokens
# attention-reuse
qk_sim = rearrange(qk_sim, 'b h i j -> b h j i') # transpose for reverse cross attention
highres_tokens = highres_attn(highres_tokens, tokens, qk_sim = qk_sim) + highres_tokens
highres_tokens = highres_norm(highres_tokens)
highres_tokens = highres_mlp(highres_tokens) + highres_tokens
# to logits
tokens = self.norm(tokens)
highres_tokens = self.highres_norm(highres_tokens)
tokens = reduce(tokens, 'b n d -> b d', 'mean')
highres_tokens = reduce(highres_tokens, 'b n d -> b d', 'mean')
return self.to_logits(tokens + highres_tokens)
# main
if __name__ == '__main__':
v = LookViT(
image_size = 256,
num_classes = 1000,
dim = 512,
depth = 2,
heads = 8,
dim_head = 64,
patch_size = 32,
highres_patch_size = 8,
highres_mlp_factor = 2,
cross_attn_heads = 8,
cross_attn_dim_head = 64,
dropout = 0.1
).cuda()
img = torch.randn(2, 3, 256, 256).cuda()
pred = v(img)
assert pred.shape == (2, 1000)
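A small standalone illustration (not part of the repository) of why commit 9992a615d1 reuses the pre-softmax similarity matrix rather than the attention weights: softmax normalizes over the key axis, so the reverse cross-attention has to transpose the raw similarities and re-normalize over its own key axis, which transposing the post-softmax matrix would not give.

```python
import torch

# toy pre-softmax similarities between 4 lookup tokens (queries) and 6 highres tokens (keys)
qk_sim = torch.randn(1, 8, 4, 6)

# forward direction: lookup tokens attend over the highres tokens
attn = qk_sim.softmax(dim = -1)

# reverse direction: highres tokens attend back over the lookup tokens,
# reusing the *pre-softmax* matrix - transpose first, then softmax
reverse_attn = qk_sim.transpose(-2, -1).softmax(dim = -1)

# transposing the already-softmaxed matrix is normalized over the wrong axis
wrong = attn.transpose(-2, -1)

assert torch.allclose(reverse_attn.sum(dim = -1), torch.ones(1, 8, 6))
assert not torch.allclose(wrong.sum(dim = -1), torch.ones(1, 8, 6))
```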

vit_pytorch/na_vit.py

@@ -1,5 +1,7 @@
from __future__ import annotations
from functools import partial
from typing import List, Union
from typing import List
import torch
import torch.nn.functional as F
@@ -245,7 +247,7 @@ class NaViT(nn.Module):
def forward(
self,
batched_images: Union[List[Tensor], List[List[Tensor]]], # assume different resolution images already grouped correctly
batched_images: List[Tensor] | List[List[Tensor]], # assume different resolution images already grouped correctly
group_images = False,
group_max_seq_len = 2048
):
@@ -264,6 +266,11 @@ class NaViT(nn.Module):
max_seq_len = group_max_seq_len
)
# if List[Tensor] is not grouped -> List[List[Tensor]]
if torch.is_tensor(batched_images[0]):
batched_images = [batched_images]
# process images into variable lengthed sequences with attention mask
num_images = []
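With the guard above, the masked `NaViT` accepts a flat `List[Tensor]` in addition to pre-grouped `List[List[Tensor]]`. A hedged usage sketch (constructor arguments follow the NaViT example in the README and are illustrative):

```python
import torch
from vit_pytorch.na_vit import NaViT

v = NaViT(
    image_size = 256,
    patch_size = 32,
    num_classes = 1000,
    dim = 1024,
    depth = 6,
    heads = 16,
    mlp_dim = 2048,
    token_dropout_prob = 0.1
)

images = [
    torch.randn(3, 256, 256),
    torch.randn(3, 128, 128),
    torch.randn(3, 64, 256)
]

preds = v([images])                     # explicitly grouped - List[List[Tensor]]
preds = v(images)                       # new - a flat List[Tensor] is treated as a single group
preds = v(images, group_images = True)  # or let NaViT pack the images into groups itself

assert preds.shape == (3, 1000)
```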

vit_pytorch/na_vit_nested_tensor.py (new file, 323 lines)

@@ -0,0 +1,323 @@
from __future__ import annotations
from typing import List
from functools import partial
import torch
import packaging.version as pkg_version
assert pkg_version.parse(torch.__version__) >= pkg_version.parse('2.4'), 'install pytorch 2.4 or greater to use this flavor of NaViT'
from torch import nn, Tensor
import torch.nn.functional as F
from torch.nn import Module, ModuleList
from torch.nested import nested_tensor
from einops import rearrange
from einops.layers.torch import Rearrange
# helpers
def exists(val):
return val is not None
def default(val, d):
return val if exists(val) else d
def pair(t):
return t if isinstance(t, tuple) else (t, t)
def divisible_by(numer, denom):
return (numer % denom) == 0
# feedforward
def FeedForward(dim, hidden_dim, dropout = 0.):
return nn.Sequential(
nn.LayerNorm(dim, bias = False),
nn.Linear(dim, hidden_dim),
nn.GELU(),
nn.Dropout(dropout),
nn.Linear(hidden_dim, dim),
nn.Dropout(dropout)
)
class Attention(Module):
def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
super().__init__()
self.norm = nn.LayerNorm(dim, bias = False)
dim_inner = heads * dim_head
self.heads = heads
self.dim_head = dim_head
self.to_queries = nn.Linear(dim, dim_inner, bias = False)
self.to_keys = nn.Linear(dim, dim_inner, bias = False)
self.to_values = nn.Linear(dim, dim_inner, bias = False)
# in the paper, they employ qk rmsnorm, a way to stabilize attention
# will use layernorm in place of rmsnorm, which has been shown to work in certain papers. requires l2norm on non-ragged dimension to be supported in nested tensors
self.query_norm = nn.LayerNorm(dim_head, bias = False)
self.key_norm = nn.LayerNorm(dim_head, bias = False)
self.dropout = dropout
self.to_out = nn.Linear(dim_inner, dim, bias = False)
def forward(
self,
x,
context: Tensor | None = None
):
x = self.norm(x)
# for attention pooling, one query pooling to entire sequence
context = default(context, x)
# queries, keys, values
query = self.to_queries(x)
key = self.to_keys(context)
value = self.to_values(context)
# split heads
def split_heads(t):
return t.unflatten(-1, (self.heads, self.dim_head))
def transpose_head_seq(t):
return t.transpose(1, 2)
query, key, value = map(split_heads, (query, key, value))
# qk norm for attention stability
query = self.query_norm(query)
key = self.key_norm(key)
query, key, value = map(transpose_head_seq, (query, key, value))
# attention
out = F.scaled_dot_product_attention(
query, key, value,
dropout_p = self.dropout if self.training else 0.
)
# merge heads
out = out.transpose(1, 2).flatten(-2)
return self.to_out(out)
class Transformer(Module):
def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
super().__init__()
self.layers = ModuleList([])
for _ in range(depth):
self.layers.append(ModuleList([
Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout),
FeedForward(dim, mlp_dim, dropout = dropout)
]))
self.norm = nn.LayerNorm(dim, bias = False)
def forward(self, x):
for attn, ff in self.layers:
x = attn(x) + x
x = ff(x) + x
return self.norm(x)
class NaViT(Module):
def __init__(
self,
*,
image_size,
patch_size,
num_classes,
dim,
depth,
heads,
mlp_dim,
channels = 3,
dim_head = 64,
dropout = 0.,
emb_dropout = 0.,
token_dropout_prob: float | None = None
):
super().__init__()
image_height, image_width = pair(image_size)
# what percent of tokens to dropout
# if int or float given, then assume constant dropout prob
# otherwise accept a callback that in turn calculates dropout prob from height and width
self.token_dropout_prob = token_dropout_prob
# calculate patching related stuff
assert divisible_by(image_height, patch_size) and divisible_by(image_width, patch_size), 'Image dimensions must be divisible by the patch size.'
patch_height_dim, patch_width_dim = (image_height // patch_size), (image_width // patch_size)
patch_dim = channels * (patch_size ** 2)
self.channels = channels
self.patch_size = patch_size
self.to_patches = Rearrange('c (h p1) (w p2) -> h w (c p1 p2)', p1 = patch_size, p2 = patch_size)
self.to_patch_embedding = nn.Sequential(
nn.LayerNorm(patch_dim),
nn.Linear(patch_dim, dim),
nn.LayerNorm(dim),
)
self.pos_embed_height = nn.Parameter(torch.randn(patch_height_dim, dim))
self.pos_embed_width = nn.Parameter(torch.randn(patch_width_dim, dim))
self.dropout = nn.Dropout(emb_dropout)
self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)
# final attention pooling queries
self.attn_pool_queries = nn.Parameter(torch.randn(dim))
self.attn_pool = Attention(dim = dim, dim_head = dim_head, heads = heads)
# output to logits
self.to_latent = nn.Identity()
self.mlp_head = nn.Sequential(
nn.LayerNorm(dim, bias = False),
nn.Linear(dim, num_classes, bias = False)
)
@property
def device(self):
return next(self.parameters()).device
def forward(
self,
images: List[Tensor], # different resolution images
):
batch, device = len(images), self.device
arange = partial(torch.arange, device = device)
assert all([image.ndim == 3 and image.shape[0] == self.channels for image in images]), f'all images must have {self.channels} channels and number of dimensions of 3 (channels, height, width)'
all_patches = [self.to_patches(image) for image in images]
# prepare factorized positional embedding height width indices
positions = []
for patches in all_patches:
patch_height, patch_width = patches.shape[:2]
hw_indices = torch.stack(torch.meshgrid((arange(patch_height), arange(patch_width)), indexing = 'ij'), dim = -1)
hw_indices = rearrange(hw_indices, 'h w c -> (h w) c')
positions.append(hw_indices)
# need the sizes to compute token dropout + positional embedding
tokens = [rearrange(patches, 'h w d -> (h w) d') for patches in all_patches]
# handle token dropout
seq_lens = torch.tensor([i.shape[0] for i in tokens], device = device)
if self.training and exists(self.token_dropout_prob) and self.token_dropout_prob > 0.:
keep_seq_lens = ((1. - self.token_dropout_prob) * seq_lens).int().clamp(min = 1)
kept_tokens = []
kept_positions = []
for one_image_tokens, one_image_positions, seq_len, num_keep in zip(tokens, positions, seq_lens, keep_seq_lens):
keep_indices = torch.randn((seq_len,), device = device).topk(num_keep, dim = -1).indices
one_image_kept_tokens = one_image_tokens[keep_indices]
one_image_kept_positions = one_image_positions[keep_indices]
kept_tokens.append(one_image_kept_tokens)
kept_positions.append(one_image_kept_positions)
tokens, positions, seq_lens = kept_tokens, kept_positions, keep_seq_lens
# add all height and width factorized positions
height_indices, width_indices = torch.cat(positions).unbind(dim = -1)
height_embed, width_embed = self.pos_embed_height[height_indices], self.pos_embed_width[width_indices]
pos_embed = height_embed + width_embed
# use nested tensor for transformers and save on padding computation
tokens = torch.cat(tokens)
# linear projection to patch embeddings
tokens = self.to_patch_embedding(tokens)
# absolute positions
tokens = tokens + pos_embed
tokens = nested_tensor(tokens.split(seq_lens.tolist()), layout = torch.jagged, device = device)
# embedding dropout
tokens = self.dropout(tokens)
# transformer
tokens = self.transformer(tokens)
# attention pooling
# will use a jagged tensor for queries, as SDPA requires all inputs to be jagged, or not
attn_pool_queries = [rearrange(self.attn_pool_queries, '... -> 1 ...')] * batch
attn_pool_queries = nested_tensor(attn_pool_queries, layout = torch.jagged)
pooled = self.attn_pool(attn_pool_queries, tokens)
# back to unjagged
logits = torch.stack(pooled.unbind())
logits = rearrange(logits, 'b 1 d -> b d')
logits = self.to_latent(logits)
return self.mlp_head(logits)
# quick test
if __name__ == '__main__':
v = NaViT(
image_size = 256,
patch_size = 32,
num_classes = 1000,
dim = 1024,
depth = 6,
heads = 16,
mlp_dim = 2048,
dropout = 0.,
emb_dropout = 0.,
token_dropout_prob = 0.1
)
# 5 images of different resolutions - List[Tensor]
images = [
torch.randn(3, 256, 256), torch.randn(3, 128, 128),
torch.randn(3, 128, 256), torch.randn(3, 256, 128),
torch.randn(3, 64, 256)
]
assert v(images).shape == (5, 1000)

vit_pytorch/simple_uvit.py (new file, 176 lines)

@@ -0,0 +1,176 @@
import torch
from torch import nn
from torch.nn import Module, ModuleList
from einops import rearrange, repeat, pack, unpack
from einops.layers.torch import Rearrange
# helpers
def pair(t):
return t if isinstance(t, tuple) else (t, t)
def exists(v):
return v is not None
def divisible_by(num, den):
return (num % den) == 0
def posemb_sincos_2d(h, w, dim, temperature: int = 10000, dtype = torch.float32):
y, x = torch.meshgrid(torch.arange(h), torch.arange(w), indexing="ij")
assert divisible_by(dim, 4), "feature dimension must be multiple of 4 for sincos emb"
omega = torch.arange(dim // 4) / (dim // 4 - 1)
omega = temperature ** -omega
y = y.flatten()[:, None] * omega[None, :]
x = x.flatten()[:, None] * omega[None, :]
pe = torch.cat((x.sin(), x.cos(), y.sin(), y.cos()), dim=1)
return pe.type(dtype)
# classes
def FeedForward(dim, hidden_dim):
return nn.Sequential(
nn.LayerNorm(dim),
nn.Linear(dim, hidden_dim),
nn.GELU(),
nn.Linear(hidden_dim, dim),
)
class Attention(Module):
def __init__(self, dim, heads = 8, dim_head = 64):
super().__init__()
inner_dim = dim_head * heads
self.heads = heads
self.scale = dim_head ** -0.5
self.norm = nn.LayerNorm(dim)
self.attend = nn.Softmax(dim = -1)
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
self.to_out = nn.Linear(inner_dim, dim, bias = False)
def forward(self, x):
x = self.norm(x)
qkv = self.to_qkv(x).chunk(3, dim = -1)
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)
dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
attn = self.attend(dots)
out = torch.matmul(attn, v)
out = rearrange(out, 'b h n d -> b n (h d)')
return self.to_out(out)
class Transformer(Module):
def __init__(self, dim, depth, heads, dim_head, mlp_dim):
super().__init__()
self.depth = depth
self.norm = nn.LayerNorm(dim)
self.layers = ModuleList([])
for layer in range(1, depth + 1):
latter_half = layer >= (depth / 2 + 1)
self.layers.append(nn.ModuleList([
nn.Linear(dim * 2, dim) if latter_half else None,
Attention(dim, heads = heads, dim_head = dim_head),
FeedForward(dim, mlp_dim)
]))
def forward(self, x):
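# U-Net style long skips: the first half of the blocks push their inputs onto a stack, the second half pop them, concatenate along the feature dimension, and merge back down to `dim` with a linear layer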
skips = []
for ind, (combine_skip, attn, ff) in enumerate(self.layers):
layer = ind + 1
first_half = layer <= (self.depth / 2)
if first_half:
skips.append(x)
if exists(combine_skip):
skip = skips.pop()
skip_and_x = torch.cat((skip, x), dim = -1)
x = combine_skip(skip_and_x)
x = attn(x) + x
x = ff(x) + x
assert len(skips) == 0
return self.norm(x)
class SimpleUViT(Module):
def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, num_register_tokens = 4, channels = 3, dim_head = 64):
super().__init__()
image_height, image_width = pair(image_size)
patch_height, patch_width = pair(patch_size)
assert divisible_by(image_height, patch_height) and divisible_by(image_width, patch_width), 'Image dimensions must be divisible by the patch size.'
patch_dim = channels * patch_height * patch_width
self.to_patch_embedding = nn.Sequential(
Rearrange("b c (h p1) (w p2) -> b (h w) (p1 p2 c)", p1 = patch_height, p2 = patch_width),
nn.LayerNorm(patch_dim),
nn.Linear(patch_dim, dim),
nn.LayerNorm(dim),
)
pos_embedding = posemb_sincos_2d(
h = image_height // patch_height,
w = image_width // patch_width,
dim = dim
)
self.register_buffer('pos_embedding', pos_embedding, persistent = False)
self.register_tokens = nn.Parameter(torch.randn(num_register_tokens, dim))
self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim)
self.pool = "mean"
self.to_latent = nn.Identity()
self.linear_head = nn.Linear(dim, num_classes)
def forward(self, img):
batch, device = img.shape[0], img.device
x = self.to_patch_embedding(img)
x = x + self.pos_embedding.type(x.dtype)
r = repeat(self.register_tokens, 'n d -> b n d', b = batch)
x, ps = pack([x, r], 'b * d')
x = self.transformer(x)
x, _ = unpack(x, ps, 'b * d')
x = x.mean(dim = 1)
x = self.to_latent(x)
return self.linear_head(x)
# quick test on odd number of layers
if __name__ == '__main__':
v = SimpleUViT(
image_size = 256,
patch_size = 32,
num_classes = 1000,
dim = 1024,
depth = 7,
heads = 16,
mlp_dim = 2048
).cuda()
img = torch.randn(2, 3, 256, 256).cuda()
preds = v(img)
assert preds.shape == (2, 1000)

vit_pytorch/t2t.py

@@ -61,10 +61,7 @@ class T2TViT(nn.Module):
self.pool = pool
self.to_latent = nn.Identity()
self.mlp_head = nn.Sequential(
nn.LayerNorm(dim),
nn.Linear(dim, num_classes)
)
self.mlp_head = nn.Linear(dim, num_classes)
def forward(self, img):
x = self.to_patch_embedding(img)
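Context for commit e3256d77cd above: the transformer output fed into `mlp_head` is already layer-normalized, so the removed `nn.LayerNorm` was a second normalization applied back to back, and its affine parameters can be absorbed by the following `nn.Linear`. A standalone sketch of the redundancy (illustrative, not from the repository):

```python
import torch
from torch import nn

x = torch.randn(2, 65, 512)

# without affine parameters, layernorm is numerically idempotent -
# a second back-to-back application changes (almost) nothing
ln = nn.LayerNorm(512, elementwise_affine = False)

assert torch.allclose(ln(ln(x)), ln(x), atol = 1e-5)
```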