Compare commits

...

10 Commits

Author SHA1 Message Date
Phil Wang
61450ae1cf add a 3d version of cct, addressing https://github.com/lucidrains/vit-pytorch/issues/238 2022-10-29 11:33:17 -07:00
Phil Wang
6ec8fdaa6d make sure global average pool can be used for vivit in place of cls token 2022-10-24 19:59:48 -07:00
Phil Wang
13fabf901e add vivit 2022-10-24 09:34:04 -07:00
Ryan Russell
c0eb4c0150 Improving Readability (#220)
Signed-off-by: Ryan Russell <git@ryanrussell.org>

Signed-off-by: Ryan Russell <git@ryanrussell.org>
2022-10-17 10:42:45 -07:00
Phil Wang
5f1a6a05e9 release updated mae where one can more easily visualize reconstructions, thanks to @Vishu26 2022-10-17 10:41:46 -07:00
Srikumar Sastry
9a95e7904e Update mae.py (#242)
update mae so decoded tokens can be easily reshaped back to visualize the reconstruction
2022-10-17 10:41:10 -07:00
Phil Wang
b4853d39c2 add the 3d simple vit 2022-10-16 20:45:30 -07:00
Phil Wang
29fbf0aff4 begin extending some of the architectures over to 3d, starting with basic ViT 2022-10-16 15:31:59 -07:00
Phil Wang
4b8f5bc900 add link to Flax translation by @conceptofmind 2022-07-27 08:58:18 -07:00
Phil Wang
f86e052c05 offer way for extractor to return latents without detaching them 2022-07-16 16:22:40 -07:00
12 changed files with 693 additions and 111 deletions

130
README.md
View File

@@ -30,6 +30,8 @@
- [Adaptive Token Sampling](#adaptive-token-sampling)
- [Patch Merger](#patch-merger)
- [Vision Transformer for Small Datasets](#vision-transformer-for-small-datasets)
- [3D Vit](#3d-vit)
- [ViVit](#vivit)
- [Parallel ViT](#parallel-vit)
- [Learnable Memory ViT](#learnable-memory-vit)
- [Dino](#dino)
@@ -52,6 +54,8 @@ The official Jax repository is <a href="https://github.com/google-research/visio
A tensorflow2 translation also exists <a href="https://github.com/taki0112/vit-tensorflow">here</a>, created by research scientist <a href="https://github.com/taki0112">Junho Kim</a>! 🙏
<a href="https://github.com/conceptofmind/vit-flax">Flax translation</a> by <a href="https://github.com/conceptofmind">Enrico Shippole</a>!
## Install
```bash
@@ -661,7 +665,7 @@ preds = v(img) # (2, 1000)
<img src="./images/nest.png" width="400px"></img>
This <a href="https://arxiv.org/abs/2105.12723">paper</a> decided to process the image in hierarchical stages, with attention only within tokens of local blocks, which aggregate as it moves up the heirarchy. The aggregation is done in the image plane, and contains a convolution and subsequent maxpool to allow it to pass information across the boundary.
This <a href="https://arxiv.org/abs/2105.12723">paper</a> decided to process the image in hierarchical stages, with attention only within tokens of local blocks, which aggregate as it moves up the hierarchy. The aggregation is done in the image plane, and contains a convolution and subsequent maxpool to allow it to pass information across the boundary.
You can use it with the following code (ex. NesT-T)
@@ -675,7 +679,7 @@ nest = NesT(
dim = 96,
heads = 3,
num_hierarchies = 3, # number of hierarchies
block_repeats = (2, 2, 8), # the number of transformer blocks at each heirarchy, starting from the bottom
block_repeats = (2, 2, 8), # the number of transformer blocks at each hierarchy, starting from the bottom
num_classes = 1000
)
@@ -965,6 +969,118 @@ img = torch.randn(4, 3, 256, 256)
tokens = spt(img) # (4, 256, 1024)
```
## 3D ViT
By popular request, I will start extending a few of the architectures in this repository to 3D ViTs, for use with video, medical imaging, etc.
You will need to pass in two additional hyperparameters: (1) the number of frames `frames` and (2) patch size along the frame dimension `frame_patch_size`
For starters, 3D ViT
```python
import torch
from vit_pytorch.vit_3d import ViT
v = ViT(
image_size = 128, # image size
frames = 16, # number of frames
image_patch_size = 16, # image patch size
frame_patch_size = 2, # frame patch size
num_classes = 1000,
dim = 1024,
depth = 6,
heads = 8,
mlp_dim = 2048,
dropout = 0.1,
emb_dropout = 0.1
)
video = torch.randn(4, 3, 16, 128, 128) # (batch, channels, frames, height, width)
preds = v(video) # (4, 1000)
```
3D Simple ViT
```python
import torch
from vit_pytorch.simple_vit_3d import SimpleViT
v = SimpleViT(
image_size = 128, # image size
frames = 16, # number of frames
image_patch_size = 16, # image patch size
frame_patch_size = 2, # frame patch size
num_classes = 1000,
dim = 1024,
depth = 6,
heads = 8,
mlp_dim = 2048
)
video = torch.randn(4, 3, 16, 128, 128) # (batch, channels, frames, height, width)
preds = v(video) # (4, 1000)
```
3D version of <a href="https://github.com/lucidrains/vit-pytorch#cct">CCT</a>
```python
import torch
from vit_pytorch.cct_3d import CCT
cct = CCT(
img_size = 224,
num_frames = 8,
embedding_dim = 384,
n_conv_layers = 2,
frame_kernel_size = 3,
kernel_size = 7,
stride = 2,
padding = 3,
pooling_kernel_size = 3,
pooling_stride = 2,
pooling_padding = 1,
num_layers = 14,
num_heads = 6,
mlp_radio = 3.,
num_classes = 1000,
positional_embedding = 'learnable'
)
video = torch.randn(1, 3, 8, 224, 224) # (batch, channels, frames, height, width)
pred = cct(video)
print(pred.shape)
```
## ViViT
<img src="./images/vivit.png" width="350px"></img>
This <a href="https://arxiv.org/abs/2103.15691">paper</a> offers 3 different types of architectures for efficient attention of videos, with the main theme being factorizing the attention across space and time. This repository will offer the first variant, which is a spatial transformer followed by a temporal one.
```python
import torch
from vit_pytorch.vivit import ViT
v = ViT(
image_size = 128, # image size
frames = 16, # number of frames
image_patch_size = 16, # image patch size
frame_patch_size = 2, # frame patch size
num_classes = 1000,
dim = 1024,
spatial_depth = 6, # depth of the spatial transformer
temporal_depth = 6, # depth of the temporal transformer
heads = 8,
mlp_dim = 2048
)
video = torch.randn(4, 3, 16, 128, 128) # (batch, channels, frames, height, width)
preds = v(video) # (4, 1000)
```
## Parallel ViT
<img src="./images/parallel-vit.png" width="350px"></img>
@@ -1748,6 +1864,16 @@ Coming from computer vision and new to transformers? Here are some resources tha
```
```bibtex
@article{Arnab2021ViViTAV,
title = {ViViT: A Video Vision Transformer},
author = {Anurag Arnab and Mostafa Dehghani and Georg Heigold and Chen Sun and Mario Lucic and Cordelia Schmid},
journal = {2021 IEEE/CVF International Conference on Computer Vision (ICCV)},
year = {2021},
pages = {6816-6826}
}
```
```bibtex
@misc{vaswani2017attention,
title = {Attention Is All You Need},

View File

@@ -16,7 +16,7 @@
"\n",
"* Dogs vs. Cats Redux: Kernels Edition - https://www.kaggle.com/c/dogs-vs-cats-redux-kernels-edition\n",
"* Base Code - https://www.kaggle.com/reukki/pytorch-cnn-tutorial-with-cats-and-dogs/\n",
"* Effecient Attention Implementation - https://github.com/lucidrains/vit-pytorch#efficient-attention"
"* Efficient Attention Implementation - https://github.com/lucidrains/vit-pytorch#efficient-attention"
]
},
{
@@ -342,7 +342,7 @@
"id": "ZhYDJXk2SRDu"
},
"source": [
"## Image Augumentation"
"## Image Augmentation"
]
},
{
@@ -497,7 +497,7 @@
"id": "TF9yMaRrSvmv"
},
"source": [
"## Effecient Attention"
"## Efficient Attention"
]
},
{
@@ -1307,7 +1307,7 @@
"celltoolbar": "Edit Metadata",
"colab": {
"collapsed_sections": [],
"name": "Effecient Attention | Cats & Dogs",
"name": "Efficient Attention | Cats & Dogs",
"provenance": [],
"toc_visible": true
},

BIN
images/vivit.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 104 KiB

View File

@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
setup(
name = 'vit-pytorch',
packages = find_packages(exclude=['examples']),
version = '0.35.7',
version = '0.38.0',
license='MIT',
description = 'Vision Transformer (ViT) - Pytorch',
long_description_content_type = 'text/markdown',

View File

@@ -1,9 +1,17 @@
import torch
import torch.nn as nn
from torch import nn, einsum
import torch.nn.functional as F
from einops import rearrange, repeat
# helpers
def exists(val):
return val is not None
def default(val, d):
return val if exists(val) else d
def pair(t):
return t if isinstance(t, tuple) else (t, t)
@@ -50,8 +58,9 @@ def cct_16(*args, **kwargs):
def _cct(num_layers, num_heads, mlp_ratio, embedding_dim,
kernel_size=3, stride=None, padding=None,
*args, **kwargs):
stride = stride if stride is not None else max(1, (kernel_size // 2) - 1)
padding = padding if padding is not None else max(1, (kernel_size // 2))
stride = default(stride, max(1, (kernel_size // 2) - 1))
padding = default(padding, max(1, (kernel_size // 2)))
return CCT(num_layers=num_layers,
num_heads=num_heads,
mlp_ratio=mlp_ratio,
@@ -61,13 +70,22 @@ def _cct(num_layers, num_heads, mlp_ratio, embedding_dim,
padding=padding,
*args, **kwargs)
# positional
def sinusoidal_embedding(n_channels, dim):
pe = torch.FloatTensor([[p / (10000 ** (2 * (i // 2) / dim)) for i in range(dim)]
for p in range(n_channels)])
pe[:, 0::2] = torch.sin(pe[:, 0::2])
pe[:, 1::2] = torch.cos(pe[:, 1::2])
return rearrange(pe, '... -> 1 ...')
# modules
class Attention(nn.Module):
def __init__(self, dim, num_heads=8, attention_dropout=0.1, projection_dropout=0.1):
super().__init__()
self.num_heads = num_heads
head_dim = dim // self.num_heads
self.heads = num_heads
head_dim = dim // self.heads
self.scale = head_dim ** -0.5
self.qkv = nn.Linear(dim, dim * 3, bias=False)
@@ -77,17 +95,20 @@ class Attention(nn.Module):
def forward(self, x):
B, N, C = x.shape
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
q, k, v = qkv[0], qkv[1], qkv[2]
attn = (q @ k.transpose(-2, -1)) * self.scale
qkv = self.qkv(x).chunk(3, dim = -1)
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)
q = q * self.scale
attn = einsum('b h i d, b h j d -> b h i j', q, k)
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
x = (attn @ v).transpose(1, 2).reshape(B, N, C)
x = self.proj(x)
x = self.proj_drop(x)
return x
x = einsum('b h i j, b h j d -> b h i d', attn, v)
x = rearrange(x, 'b h n d -> b n (h d)')
return self.proj_drop(self.proj(x))
class TransformerEncoderLayer(nn.Module):
@@ -97,7 +118,8 @@ class TransformerEncoderLayer(nn.Module):
"""
def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1,
attention_dropout=0.1, drop_path_rate=0.1):
super(TransformerEncoderLayer, self).__init__()
super().__init__()
self.pre_norm = nn.LayerNorm(d_model)
self.self_attn = Attention(dim=d_model, num_heads=nhead,
attention_dropout=attention_dropout, projection_dropout=dropout)
@@ -108,50 +130,34 @@ class TransformerEncoderLayer(nn.Module):
self.linear2 = nn.Linear(dim_feedforward, d_model)
self.dropout2 = nn.Dropout(dropout)
self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0 else nn.Identity()
self.drop_path = DropPath(drop_path_rate)
self.activation = F.gelu
def forward(self, src: torch.Tensor, *args, **kwargs) -> torch.Tensor:
def forward(self, src, *args, **kwargs):
src = src + self.drop_path(self.self_attn(self.pre_norm(src)))
src = self.norm1(src)
src2 = self.linear2(self.dropout1(self.activation(self.linear1(src))))
src = src + self.drop_path(self.dropout2(src2))
return src
def drop_path(x, drop_prob: float = 0., training: bool = False):
"""
Obtained from: github.com:rwightman/pytorch-image-models
Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
'survival rate' as the argument.
"""
if drop_prob == 0. or not training:
return x
keep_prob = 1 - drop_prob
shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
random_tensor.floor_() # binarize
output = x.div(keep_prob) * random_tensor
return output
class DropPath(nn.Module):
"""
Obtained from: github.com:rwightman/pytorch-image-models
Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
"""
def __init__(self, drop_prob=None):
super(DropPath, self).__init__()
self.drop_prob = drop_prob
super().__init__()
self.drop_prob = float(drop_prob)
def forward(self, x):
return drop_path(x, self.drop_prob, self.training)
batch, drop_prob, device, dtype = x.shape[0], self.drop_prob, x.device, x.dtype
if drop_prob <= 0. or not self.training:
return x
keep_prob = 1 - self.drop_prob
shape = (batch, *((1,) * (x.ndim - 1)))
keep_mask = torch.zeros(shape, device = device).float().uniform_(0, 1) < keep_prob
output = x.div(keep_prob) * keep_mask.float()
return output
class Tokenizer(nn.Module):
def __init__(self,
@@ -164,34 +170,35 @@ class Tokenizer(nn.Module):
activation=None,
max_pool=True,
conv_bias=False):
super(Tokenizer, self).__init__()
super().__init__()
n_filter_list = [n_input_channels] + \
[in_planes for _ in range(n_conv_layers - 1)] + \
[n_output_channels]
n_filter_list_pairs = zip(n_filter_list[:-1], n_filter_list[1:])
self.conv_layers = nn.Sequential(
*[nn.Sequential(
nn.Conv2d(n_filter_list[i], n_filter_list[i + 1],
nn.Conv2d(chan_in, chan_out,
kernel_size=(kernel_size, kernel_size),
stride=(stride, stride),
padding=(padding, padding), bias=conv_bias),
nn.Identity() if activation is None else activation(),
nn.Identity() if not exists(activation) else activation(),
nn.MaxPool2d(kernel_size=pooling_kernel_size,
stride=pooling_stride,
padding=pooling_padding) if max_pool else nn.Identity()
)
for i in range(n_conv_layers)
for chan_in, chan_out in n_filter_list_pairs
])
self.flattener = nn.Flatten(2, 3)
self.apply(self.init_weight)
def sequence_length(self, n_channels=3, height=224, width=224):
return self.forward(torch.zeros((1, n_channels, height, width))).shape[1]
def forward(self, x):
return self.flattener(self.conv_layers(x)).transpose(-2, -1)
return rearrange(self.conv_layers(x), 'b c h w -> b (h w) c')
@staticmethod
def init_weight(m):
@@ -214,106 +221,104 @@ class TransformerClassifier(nn.Module):
sequence_length=None,
*args, **kwargs):
super().__init__()
positional_embedding = positional_embedding if \
positional_embedding in ['sine', 'learnable', 'none'] else 'sine'
assert positional_embedding in {'sine', 'learnable', 'none'}
dim_feedforward = int(embedding_dim * mlp_ratio)
self.embedding_dim = embedding_dim
self.sequence_length = sequence_length
self.seq_pool = seq_pool
assert sequence_length is not None or positional_embedding == 'none', \
assert exists(sequence_length) or positional_embedding == 'none', \
f"Positional embedding is set to {positional_embedding} and" \
f" the sequence length was not specified."
if not seq_pool:
sequence_length += 1
self.class_emb = nn.Parameter(torch.zeros(1, 1, self.embedding_dim),
requires_grad=True)
self.class_emb = nn.Parameter(torch.zeros(1, 1, self.embedding_dim), requires_grad=True)
else:
self.attention_pool = nn.Linear(self.embedding_dim, 1)
if positional_embedding != 'none':
if positional_embedding == 'learnable':
self.positional_emb = nn.Parameter(torch.zeros(1, sequence_length, embedding_dim),
requires_grad=True)
nn.init.trunc_normal_(self.positional_emb, std=0.2)
else:
self.positional_emb = nn.Parameter(self.sinusoidal_embedding(sequence_length, embedding_dim),
requires_grad=False)
else:
if positional_embedding == 'none':
self.positional_emb = None
elif positional_embedding == 'learnable':
self.positional_emb = nn.Parameter(torch.zeros(1, sequence_length, embedding_dim),
requires_grad=True)
nn.init.trunc_normal_(self.positional_emb, std=0.2)
else:
self.positional_emb = nn.Parameter(sinusoidal_embedding(sequence_length, embedding_dim),
requires_grad=False)
self.dropout = nn.Dropout(p=dropout_rate)
dpr = [x.item() for x in torch.linspace(0, stochastic_depth_rate, num_layers)]
self.blocks = nn.ModuleList([
TransformerEncoderLayer(d_model=embedding_dim, nhead=num_heads,
dim_feedforward=dim_feedforward, dropout=dropout_rate,
attention_dropout=attention_dropout, drop_path_rate=dpr[i])
for i in range(num_layers)])
attention_dropout=attention_dropout, drop_path_rate=layer_dpr)
for layer_dpr in dpr])
self.norm = nn.LayerNorm(embedding_dim)
self.fc = nn.Linear(embedding_dim, num_classes)
self.apply(self.init_weight)
def forward(self, x):
if self.positional_emb is None and x.size(1) < self.sequence_length:
b = x.shape[0]
if not exists(self.positional_emb) and x.size(1) < self.sequence_length:
x = F.pad(x, (0, 0, 0, self.n_channels - x.size(1)), mode='constant', value=0)
if not self.seq_pool:
cls_token = self.class_emb.expand(x.shape[0], -1, -1)
cls_token = repeat(self.class_emb, '1 1 d -> b 1 d', b = b)
x = torch.cat((cls_token, x), dim=1)
if self.positional_emb is not None:
if exists(self.positional_emb):
x += self.positional_emb
x = self.dropout(x)
for blk in self.blocks:
x = blk(x)
x = self.norm(x)
if self.seq_pool:
x = torch.matmul(F.softmax(self.attention_pool(x), dim=1).transpose(-1, -2), x).squeeze(-2)
attn_weights = rearrange(self.attention_pool(x), 'b n 1 -> b n')
x = einsum('b n, b n d -> b d', attn_weights.softmax(dim = 1), x)
else:
x = x[:, 0]
x = self.fc(x)
return x
return self.fc(x)
@staticmethod
def init_weight(m):
if isinstance(m, nn.Linear):
nn.init.trunc_normal_(m.weight, std=.02)
if isinstance(m, nn.Linear) and m.bias is not None:
if isinstance(m, nn.Linear) and exists(m.bias):
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
@staticmethod
def sinusoidal_embedding(n_channels, dim):
pe = torch.FloatTensor([[p / (10000 ** (2 * (i // 2) / dim)) for i in range(dim)]
for p in range(n_channels)])
pe[:, 0::2] = torch.sin(pe[:, 0::2])
pe[:, 1::2] = torch.cos(pe[:, 1::2])
return pe.unsqueeze(0)
# CCT Main model
class CCT(nn.Module):
def __init__(self,
img_size=224,
embedding_dim=768,
n_input_channels=3,
n_conv_layers=1,
kernel_size=7,
stride=2,
padding=3,
pooling_kernel_size=3,
pooling_stride=2,
pooling_padding=1,
*args, **kwargs):
super(CCT, self).__init__()
def __init__(
self,
img_size=224,
embedding_dim=768,
n_input_channels=3,
n_conv_layers=1,
kernel_size=7,
stride=2,
padding=3,
pooling_kernel_size=3,
pooling_stride=2,
pooling_padding=1,
*args, **kwargs
):
super().__init__()
img_height, img_width = pair(img_size)
self.tokenizer = Tokenizer(n_input_channels=n_input_channels,

View File

@@ -4,6 +4,12 @@ from torch import nn
def exists(val):
return val is not None
def identity(t):
return t
def clone_and_detach(t):
return t.clone().detach()
def apply_tuple_or_single(fn, val):
if isinstance(val, tuple):
return tuple(map(fn, val))
@@ -17,7 +23,8 @@ class Extractor(nn.Module):
layer = None,
layer_name = 'transformer',
layer_save_input = False,
return_embeddings_only = False
return_embeddings_only = False,
detach = True
):
super().__init__()
self.vit = vit
@@ -34,9 +41,11 @@ class Extractor(nn.Module):
self.layer_save_input = layer_save_input # whether to save input or output of layer
self.return_embeddings_only = return_embeddings_only
self.detach_fn = clone_and_detach if detach else identity
def _hook(self, _, inputs, output):
layer_output = inputs if self.layer_save_input else output
self.latents = apply_tuple_or_single(lambda t: t.clone().detach(), layer_output)
self.latents = apply_tuple_or_single(self.detach_fn, layer_output)
def _register_hook(self):
if not exists(self.layer):

View File

@@ -28,7 +28,7 @@ class MAE(nn.Module):
pixel_values_per_patch = self.patch_to_emb.weight.shape[-1]
# decoder parameters
self.decoder_dim = decoder_dim
self.enc_to_dec = nn.Linear(encoder_dim, decoder_dim) if encoder_dim != decoder_dim else nn.Identity()
self.mask_token = nn.Parameter(torch.randn(decoder_dim))
self.decoder = Transformer(dim = decoder_dim, depth = decoder_depth, heads = decoder_heads, dim_head = decoder_dim_head, mlp_dim = decoder_dim * 4)
@@ -73,7 +73,7 @@ class MAE(nn.Module):
# reapply decoder position embedding to unmasked tokens
decoder_tokens = decoder_tokens + self.decoder_pos_emb(unmasked_indices)
unmasked_decoder_tokens = decoder_tokens + self.decoder_pos_emb(unmasked_indices)
# repeat mask tokens for number of masked, and add the positions using the masked indices derived above
@@ -81,13 +81,15 @@ class MAE(nn.Module):
mask_tokens = mask_tokens + self.decoder_pos_emb(masked_indices)
# concat the masked tokens to the decoder tokens and attend with decoder
decoder_tokens = torch.cat((mask_tokens, decoder_tokens), dim = 1)
decoder_tokens = torch.zeros(batch, num_patches, self.decoder_dim, device=device)
decoder_tokens[batch_range, unmasked_indices] = unmasked_decoder_tokens
decoder_tokens[batch_range, masked_indices] = mask_tokens
decoded_tokens = self.decoder(decoder_tokens)
# splice out the mask tokens and project to pixel values
mask_tokens = decoded_tokens[:, :num_masked]
mask_tokens = decoded_tokens[batch_range, masked_indices]
pred_pixel_values = self.to_pixels(mask_tokens)
# calculate reconstruction loss

View File

@@ -13,9 +13,9 @@ def conv_1x1_bn(inp, oup):
nn.SiLU()
)
def conv_nxn_bn(inp, oup, kernal_size=3, stride=1):
def conv_nxn_bn(inp, oup, kernel_size=3, stride=1):
return nn.Sequential(
nn.Conv2d(inp, oup, kernal_size, stride, 1, bias=False),
nn.Conv2d(inp, oup, kernel_size, stride, 1, bias=False),
nn.BatchNorm2d(oup),
nn.SiLU()
)

View File

@@ -131,7 +131,7 @@ class NesT(nn.Module):
fmap_size = image_size // patch_size
blocks = 2 ** (num_hierarchies - 1)
seq_len = (fmap_size // blocks) ** 2 # sequence length is held constant across heirarchy
seq_len = (fmap_size // blocks) ** 2 # sequence length is held constant across hierarchy
hierarchies = list(reversed(range(num_hierarchies)))
mults = [2 ** i for i in reversed(hierarchies)]

View File

@@ -0,0 +1,128 @@
import torch
import torch.nn.functional as F
from torch import nn
from einops import rearrange
from einops.layers.torch import Rearrange
# helpers
def pair(t):
return t if isinstance(t, tuple) else (t, t)
def posemb_sincos_3d(patches, temperature = 10000, dtype = torch.float32):
_, f, h, w, dim, device, dtype = *patches.shape, patches.device, patches.dtype
z, y, x = torch.meshgrid(
torch.arange(f, device = device),
torch.arange(h, device = device),
torch.arange(w, device = device),
indexing = 'ij')
fourier_dim = dim // 6
omega = torch.arange(fourier_dim, device = device) / (fourier_dim - 1)
omega = 1. / (temperature ** omega)
z = z.flatten()[:, None] * omega[None, :]
y = y.flatten()[:, None] * omega[None, :]
x = x.flatten()[:, None] * omega[None, :]
pe = torch.cat((x.sin(), x.cos(), y.sin(), y.cos(), z.sin(), z.cos()), dim = 1)
pe = F.pad(pe, (0, dim - (fourier_dim * 6))) # pad if feature dimension not cleanly divisible by 6
return pe.type(dtype)
# classes
class FeedForward(nn.Module):
def __init__(self, dim, hidden_dim):
super().__init__()
self.net = nn.Sequential(
nn.LayerNorm(dim),
nn.Linear(dim, hidden_dim),
nn.GELU(),
nn.Linear(hidden_dim, dim),
)
def forward(self, x):
return self.net(x)
class Attention(nn.Module):
def __init__(self, dim, heads = 8, dim_head = 64):
super().__init__()
inner_dim = dim_head * heads
self.heads = heads
self.scale = dim_head ** -0.5
self.norm = nn.LayerNorm(dim)
self.attend = nn.Softmax(dim = -1)
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
self.to_out = nn.Linear(inner_dim, dim, bias = False)
def forward(self, x):
x = self.norm(x)
qkv = self.to_qkv(x).chunk(3, dim = -1)
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)
dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
attn = self.attend(dots)
out = torch.matmul(attn, v)
out = rearrange(out, 'b h n d -> b n (h d)')
return self.to_out(out)
class Transformer(nn.Module):
def __init__(self, dim, depth, heads, dim_head, mlp_dim):
super().__init__()
self.layers = nn.ModuleList([])
for _ in range(depth):
self.layers.append(nn.ModuleList([
Attention(dim, heads = heads, dim_head = dim_head),
FeedForward(dim, mlp_dim)
]))
def forward(self, x):
for attn, ff in self.layers:
x = attn(x) + x
x = ff(x) + x
return x
class SimpleViT(nn.Module):
def __init__(self, *, image_size, image_patch_size, frames, frame_patch_size, num_classes, dim, depth, heads, mlp_dim, channels = 3, dim_head = 64):
super().__init__()
image_height, image_width = pair(image_size)
patch_height, patch_width = pair(image_patch_size)
assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'
assert frames % frame_patch_size == 0, 'Frames must be divisible by the frame patch size'
num_patches = (image_height // patch_height) * (image_width // patch_width) * (frames // frame_patch_size)
patch_dim = channels * patch_height * patch_width * frame_patch_size
self.to_patch_embedding = nn.Sequential(
Rearrange('b c (f pf) (h p1) (w p2) -> b f h w (p1 p2 pf c)', p1 = patch_height, p2 = patch_width, pf = frame_patch_size),
nn.Linear(patch_dim, dim),
)
self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim)
self.to_latent = nn.Identity()
self.linear_head = nn.Sequential(
nn.LayerNorm(dim),
nn.Linear(dim, num_classes)
)
def forward(self, video):
*_, h, w, dtype = *video.shape, video.dtype
x = self.to_patch_embedding(video)
pe = posemb_sincos_3d(x)
x = rearrange(x, 'b ... d -> b (...) d') + pe
x = self.transformer(x)
x = x.mean(dim = 1)
x = self.to_latent(x)
return self.linear_head(x)

129
vit_pytorch/vit_3d.py Normal file
View File

@@ -0,0 +1,129 @@
import torch
from torch import nn
from einops import rearrange, repeat
from einops.layers.torch import Rearrange
# helpers
def pair(t):
return t if isinstance(t, tuple) else (t, t)
# classes
class PreNorm(nn.Module):
def __init__(self, dim, fn):
super().__init__()
self.norm = nn.LayerNorm(dim)
self.fn = fn
def forward(self, x, **kwargs):
return self.fn(self.norm(x), **kwargs)
class FeedForward(nn.Module):
def __init__(self, dim, hidden_dim, dropout = 0.):
super().__init__()
self.net = nn.Sequential(
nn.Linear(dim, hidden_dim),
nn.GELU(),
nn.Dropout(dropout),
nn.Linear(hidden_dim, dim),
nn.Dropout(dropout)
)
def forward(self, x):
return self.net(x)
class Attention(nn.Module):
def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
super().__init__()
inner_dim = dim_head * heads
project_out = not (heads == 1 and dim_head == dim)
self.heads = heads
self.scale = dim_head ** -0.5
self.attend = nn.Softmax(dim = -1)
self.dropout = nn.Dropout(dropout)
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
self.to_out = nn.Sequential(
nn.Linear(inner_dim, dim),
nn.Dropout(dropout)
) if project_out else nn.Identity()
def forward(self, x):
qkv = self.to_qkv(x).chunk(3, dim = -1)
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)
dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
attn = self.attend(dots)
attn = self.dropout(attn)
out = torch.matmul(attn, v)
out = rearrange(out, 'b h n d -> b n (h d)')
return self.to_out(out)
class Transformer(nn.Module):
def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
super().__init__()
self.layers = nn.ModuleList([])
for _ in range(depth):
self.layers.append(nn.ModuleList([
PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)),
PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout))
]))
def forward(self, x):
for attn, ff in self.layers:
x = attn(x) + x
x = ff(x) + x
return x
class ViT(nn.Module):
def __init__(self, *, image_size, image_patch_size, frames, frame_patch_size, num_classes, dim, depth, heads, mlp_dim, pool = 'cls', channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0.):
super().__init__()
image_height, image_width = pair(image_size)
patch_height, patch_width = pair(image_patch_size)
assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'
assert frames % frame_patch_size == 0, 'Frames must be divisible by frame patch size'
num_patches = (image_height // patch_height) * (image_width // patch_width) * (frames // frame_patch_size)
patch_dim = channels * patch_height * patch_width * frame_patch_size
assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'
self.to_patch_embedding = nn.Sequential(
Rearrange('b c (f pf) (h p1) (w p2) -> b (f h w) (p1 p2 pf c)', p1 = patch_height, p2 = patch_width, pf = frame_patch_size),
nn.Linear(patch_dim, dim),
)
self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
self.dropout = nn.Dropout(emb_dropout)
self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)
self.pool = pool
self.to_latent = nn.Identity()
self.mlp_head = nn.Sequential(
nn.LayerNorm(dim),
nn.Linear(dim, num_classes)
)
def forward(self, video):
x = self.to_patch_embedding(video)
b, n, _ = x.shape
cls_tokens = repeat(self.cls_token, '1 1 d -> b 1 d', b = b)
x = torch.cat((cls_tokens, x), dim=1)
x += self.pos_embedding[:, :(n + 1)]
x = self.dropout(x)
x = self.transformer(x)
x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0]
x = self.to_latent(x)
return self.mlp_head(x)

183
vit_pytorch/vivit.py Normal file
View File

@@ -0,0 +1,183 @@
import torch
from torch import nn
from einops import rearrange, repeat, reduce
from einops.layers.torch import Rearrange
# helpers
def exists(val):
return val is not None
def pair(t):
return t if isinstance(t, tuple) else (t, t)
# classes
class PreNorm(nn.Module):
def __init__(self, dim, fn):
super().__init__()
self.norm = nn.LayerNorm(dim)
self.fn = fn
def forward(self, x, **kwargs):
return self.fn(self.norm(x), **kwargs)
class FeedForward(nn.Module):
def __init__(self, dim, hidden_dim, dropout = 0.):
super().__init__()
self.net = nn.Sequential(
nn.Linear(dim, hidden_dim),
nn.GELU(),
nn.Dropout(dropout),
nn.Linear(hidden_dim, dim),
nn.Dropout(dropout)
)
def forward(self, x):
return self.net(x)
class Attention(nn.Module):
def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
super().__init__()
inner_dim = dim_head * heads
project_out = not (heads == 1 and dim_head == dim)
self.heads = heads
self.scale = dim_head ** -0.5
self.attend = nn.Softmax(dim = -1)
self.dropout = nn.Dropout(dropout)
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
self.to_out = nn.Sequential(
nn.Linear(inner_dim, dim),
nn.Dropout(dropout)
) if project_out else nn.Identity()
def forward(self, x):
qkv = self.to_qkv(x).chunk(3, dim = -1)
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)
dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
attn = self.attend(dots)
attn = self.dropout(attn)
out = torch.matmul(attn, v)
out = rearrange(out, 'b h n d -> b n (h d)')
return self.to_out(out)
class Transformer(nn.Module):
def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
super().__init__()
self.layers = nn.ModuleList([])
for _ in range(depth):
self.layers.append(nn.ModuleList([
PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)),
PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout))
]))
def forward(self, x):
for attn, ff in self.layers:
x = attn(x) + x
x = ff(x) + x
return x
class ViT(nn.Module):
def __init__(
self,
*,
image_size,
image_patch_size,
frames,
frame_patch_size,
num_classes,
dim,
spatial_depth,
temporal_depth,
heads,
mlp_dim,
pool = 'cls',
channels = 3,
dim_head = 64,
dropout = 0.,
emb_dropout = 0.
):
super().__init__()
image_height, image_width = pair(image_size)
patch_height, patch_width = pair(image_patch_size)
assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'
assert frames % frame_patch_size == 0, 'Frames must be divisible by frame patch size'
num_image_patches = (image_height // patch_height) * (image_width // patch_width)
num_frame_patches = (frames // frame_patch_size)
patch_dim = channels * patch_height * patch_width * frame_patch_size
assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'
self.global_average_pool = pool == 'mean'
self.to_patch_embedding = nn.Sequential(
Rearrange('b c (f pf) (h p1) (w p2) -> b f (h w) (p1 p2 pf c)', p1 = patch_height, p2 = patch_width, pf = frame_patch_size),
nn.Linear(patch_dim, dim),
)
self.pos_embedding = nn.Parameter(torch.randn(1, num_frame_patches, num_image_patches, dim))
self.dropout = nn.Dropout(emb_dropout)
self.spatial_cls_token = nn.Parameter(torch.randn(1, 1, dim)) if not self.global_average_pool else None
self.temporal_cls_token = nn.Parameter(torch.randn(1, 1, dim)) if not self.global_average_pool else None
self.spatial_transformer = Transformer(dim, spatial_depth, heads, dim_head, mlp_dim, dropout)
self.temporal_transformer = Transformer(dim, temporal_depth, heads, dim_head, mlp_dim, dropout)
self.pool = pool
self.to_latent = nn.Identity()
self.mlp_head = nn.Sequential(
nn.LayerNorm(dim),
nn.Linear(dim, num_classes)
)
def forward(self, video):
x = self.to_patch_embedding(video)
b, f, n, _ = x.shape
x = x + self.pos_embedding
if exists(self.spatial_cls_token):
spatial_cls_tokens = repeat(self.spatial_cls_token, '1 1 d -> b f 1 d', b = b, f = f)
x = torch.cat((spatial_cls_tokens, x), dim = 2)
x = self.dropout(x)
x = rearrange(x, 'b f n d -> (b f) n d')
# attend across space
x = self.spatial_transformer(x)
x = rearrange(x, '(b f) n d -> b f n d', b = b)
# excise out the spatial cls tokens or average pool for temporal attention
x = x[:, :, 0] if not self.global_average_pool else reduce(x, 'b f n d -> b f d', 'mean')
# append temporal CLS tokens
if exists(self.temporal_cls_token):
temporal_cls_tokens = repeat(self.temporal_cls_token, '1 1 d-> b 1 d', b = b)
x = torch.cat((temporal_cls_tokens, x), dim = 1)
# attend across time
x = self.temporal_transformer(x)
# excise out temporal cls token or average pool
x = x[:, 0] if not self.global_average_pool else reduce(x, 'b f d -> b d', 'mean')
x = self.to_latent(x)
return self.mlp_head(x)