mirror of
https://github.com/lucidrains/vit-pytorch.git
synced 2025-12-30 08:02:29 +00:00
Compare commits
10 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
61450ae1cf | ||
|
|
6ec8fdaa6d | ||
|
|
13fabf901e | ||
|
|
c0eb4c0150 | ||
|
|
5f1a6a05e9 | ||
|
|
9a95e7904e | ||
|
|
b4853d39c2 | ||
|
|
29fbf0aff4 | ||
|
|
4b8f5bc900 | ||
|
|
f86e052c05 |
130
README.md
130
README.md
@@ -30,6 +30,8 @@
|
||||
- [Adaptive Token Sampling](#adaptive-token-sampling)
|
||||
- [Patch Merger](#patch-merger)
|
||||
- [Vision Transformer for Small Datasets](#vision-transformer-for-small-datasets)
|
||||
- [3D Vit](#3d-vit)
|
||||
- [ViVit](#vivit)
|
||||
- [Parallel ViT](#parallel-vit)
|
||||
- [Learnable Memory ViT](#learnable-memory-vit)
|
||||
- [Dino](#dino)
|
||||
@@ -52,6 +54,8 @@ The official Jax repository is <a href="https://github.com/google-research/visio
|
||||
|
||||
A tensorflow2 translation also exists <a href="https://github.com/taki0112/vit-tensorflow">here</a>, created by research scientist <a href="https://github.com/taki0112">Junho Kim</a>! 🙏
|
||||
|
||||
<a href="https://github.com/conceptofmind/vit-flax">Flax translation</a> by <a href="https://github.com/conceptofmind">Enrico Shippole</a>!
|
||||
|
||||
## Install
|
||||
|
||||
```bash
|
||||
@@ -661,7 +665,7 @@ preds = v(img) # (2, 1000)
|
||||
|
||||
<img src="./images/nest.png" width="400px"></img>
|
||||
|
||||
This <a href="https://arxiv.org/abs/2105.12723">paper</a> decided to process the image in hierarchical stages, with attention only within tokens of local blocks, which aggregate as it moves up the heirarchy. The aggregation is done in the image plane, and contains a convolution and subsequent maxpool to allow it to pass information across the boundary.
|
||||
This <a href="https://arxiv.org/abs/2105.12723">paper</a> decided to process the image in hierarchical stages, with attention only within tokens of local blocks, which aggregate as it moves up the hierarchy. The aggregation is done in the image plane, and contains a convolution and subsequent maxpool to allow it to pass information across the boundary.
|
||||
|
||||
You can use it with the following code (ex. NesT-T)
|
||||
|
||||
@@ -675,7 +679,7 @@ nest = NesT(
|
||||
dim = 96,
|
||||
heads = 3,
|
||||
num_hierarchies = 3, # number of hierarchies
|
||||
block_repeats = (2, 2, 8), # the number of transformer blocks at each heirarchy, starting from the bottom
|
||||
block_repeats = (2, 2, 8), # the number of transformer blocks at each hierarchy, starting from the bottom
|
||||
num_classes = 1000
|
||||
)
|
||||
|
||||
@@ -965,6 +969,118 @@ img = torch.randn(4, 3, 256, 256)
|
||||
tokens = spt(img) # (4, 256, 1024)
|
||||
```
|
||||
|
||||
## 3D ViT
|
||||
|
||||
By popular request, I will start extending a few of the architectures in this repository to 3D ViTs, for use with video, medical imaging, etc.
|
||||
|
||||
You will need to pass in two additional hyperparameters: (1) the number of frames `frames` and (2) patch size along the frame dimension `frame_patch_size`
|
||||
|
||||
For starters, 3D ViT
|
||||
|
||||
```python
|
||||
import torch
|
||||
from vit_pytorch.vit_3d import ViT
|
||||
|
||||
v = ViT(
|
||||
image_size = 128, # image size
|
||||
frames = 16, # number of frames
|
||||
image_patch_size = 16, # image patch size
|
||||
frame_patch_size = 2, # frame patch size
|
||||
num_classes = 1000,
|
||||
dim = 1024,
|
||||
depth = 6,
|
||||
heads = 8,
|
||||
mlp_dim = 2048,
|
||||
dropout = 0.1,
|
||||
emb_dropout = 0.1
|
||||
)
|
||||
|
||||
video = torch.randn(4, 3, 16, 128, 128) # (batch, channels, frames, height, width)
|
||||
|
||||
preds = v(video) # (4, 1000)
|
||||
```
|
||||
|
||||
3D Simple ViT
|
||||
|
||||
```python
|
||||
import torch
|
||||
from vit_pytorch.simple_vit_3d import SimpleViT
|
||||
|
||||
v = SimpleViT(
|
||||
image_size = 128, # image size
|
||||
frames = 16, # number of frames
|
||||
image_patch_size = 16, # image patch size
|
||||
frame_patch_size = 2, # frame patch size
|
||||
num_classes = 1000,
|
||||
dim = 1024,
|
||||
depth = 6,
|
||||
heads = 8,
|
||||
mlp_dim = 2048
|
||||
)
|
||||
|
||||
video = torch.randn(4, 3, 16, 128, 128) # (batch, channels, frames, height, width)
|
||||
|
||||
preds = v(video) # (4, 1000)
|
||||
```
|
||||
|
||||
3D version of <a href="https://github.com/lucidrains/vit-pytorch#cct">CCT</a>
|
||||
|
||||
```python
|
||||
import torch
|
||||
from vit_pytorch.cct_3d import CCT
|
||||
|
||||
cct = CCT(
|
||||
img_size = 224,
|
||||
num_frames = 8,
|
||||
embedding_dim = 384,
|
||||
n_conv_layers = 2,
|
||||
frame_kernel_size = 3,
|
||||
kernel_size = 7,
|
||||
stride = 2,
|
||||
padding = 3,
|
||||
pooling_kernel_size = 3,
|
||||
pooling_stride = 2,
|
||||
pooling_padding = 1,
|
||||
num_layers = 14,
|
||||
num_heads = 6,
|
||||
mlp_radio = 3.,
|
||||
num_classes = 1000,
|
||||
positional_embedding = 'learnable'
|
||||
)
|
||||
|
||||
video = torch.randn(1, 3, 8, 224, 224) # (batch, channels, frames, height, width)
|
||||
pred = cct(video)
|
||||
print(pred.shape)
|
||||
```
|
||||
|
||||
## ViViT
|
||||
|
||||
<img src="./images/vivit.png" width="350px"></img>
|
||||
|
||||
This <a href="https://arxiv.org/abs/2103.15691">paper</a> offers 3 different types of architectures for efficient attention of videos, with the main theme being factorizing the attention across space and time. This repository will offer the first variant, which is a spatial transformer followed by a temporal one.
|
||||
|
||||
```python
|
||||
import torch
|
||||
from vit_pytorch.vivit import ViT
|
||||
|
||||
v = ViT(
|
||||
image_size = 128, # image size
|
||||
frames = 16, # number of frames
|
||||
image_patch_size = 16, # image patch size
|
||||
frame_patch_size = 2, # frame patch size
|
||||
num_classes = 1000,
|
||||
dim = 1024,
|
||||
spatial_depth = 6, # depth of the spatial transformer
|
||||
temporal_depth = 6, # depth of the temporal transformer
|
||||
heads = 8,
|
||||
mlp_dim = 2048
|
||||
)
|
||||
|
||||
video = torch.randn(4, 3, 16, 128, 128) # (batch, channels, frames, height, width)
|
||||
|
||||
preds = v(video) # (4, 1000)
|
||||
```
|
||||
|
||||
## Parallel ViT
|
||||
|
||||
<img src="./images/parallel-vit.png" width="350px"></img>
|
||||
@@ -1748,6 +1864,16 @@ Coming from computer vision and new to transformers? Here are some resources tha
|
||||
|
||||
```
|
||||
|
||||
```bibtex
|
||||
@article{Arnab2021ViViTAV,
|
||||
title = {ViViT: A Video Vision Transformer},
|
||||
author = {Anurag Arnab and Mostafa Dehghani and Georg Heigold and Chen Sun and Mario Lucic and Cordelia Schmid},
|
||||
journal = {2021 IEEE/CVF International Conference on Computer Vision (ICCV)},
|
||||
year = {2021},
|
||||
pages = {6816-6826}
|
||||
}
|
||||
```
|
||||
|
||||
```bibtex
|
||||
@misc{vaswani2017attention,
|
||||
title = {Attention Is All You Need},
|
||||
|
||||
@@ -16,7 +16,7 @@
|
||||
"\n",
|
||||
"* Dogs vs. Cats Redux: Kernels Edition - https://www.kaggle.com/c/dogs-vs-cats-redux-kernels-edition\n",
|
||||
"* Base Code - https://www.kaggle.com/reukki/pytorch-cnn-tutorial-with-cats-and-dogs/\n",
|
||||
"* Effecient Attention Implementation - https://github.com/lucidrains/vit-pytorch#efficient-attention"
|
||||
"* Efficient Attention Implementation - https://github.com/lucidrains/vit-pytorch#efficient-attention"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -342,7 +342,7 @@
|
||||
"id": "ZhYDJXk2SRDu"
|
||||
},
|
||||
"source": [
|
||||
"## Image Augumentation"
|
||||
"## Image Augmentation"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -497,7 +497,7 @@
|
||||
"id": "TF9yMaRrSvmv"
|
||||
},
|
||||
"source": [
|
||||
"## Effecient Attention"
|
||||
"## Efficient Attention"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -1307,7 +1307,7 @@
|
||||
"celltoolbar": "Edit Metadata",
|
||||
"colab": {
|
||||
"collapsed_sections": [],
|
||||
"name": "Effecient Attention | Cats & Dogs",
|
||||
"name": "Efficient Attention | Cats & Dogs",
|
||||
"provenance": [],
|
||||
"toc_visible": true
|
||||
},
|
||||
|
||||
BIN
images/vivit.png
Normal file
BIN
images/vivit.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 104 KiB |
2
setup.py
2
setup.py
@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
|
||||
setup(
|
||||
name = 'vit-pytorch',
|
||||
packages = find_packages(exclude=['examples']),
|
||||
version = '0.35.7',
|
||||
version = '0.38.0',
|
||||
license='MIT',
|
||||
description = 'Vision Transformer (ViT) - Pytorch',
|
||||
long_description_content_type = 'text/markdown',
|
||||
|
||||
@@ -1,9 +1,17 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch import nn, einsum
|
||||
import torch.nn.functional as F
|
||||
|
||||
from einops import rearrange, repeat
|
||||
|
||||
# helpers
|
||||
|
||||
def exists(val):
|
||||
return val is not None
|
||||
|
||||
def default(val, d):
|
||||
return val if exists(val) else d
|
||||
|
||||
def pair(t):
|
||||
return t if isinstance(t, tuple) else (t, t)
|
||||
|
||||
@@ -50,8 +58,9 @@ def cct_16(*args, **kwargs):
|
||||
def _cct(num_layers, num_heads, mlp_ratio, embedding_dim,
|
||||
kernel_size=3, stride=None, padding=None,
|
||||
*args, **kwargs):
|
||||
stride = stride if stride is not None else max(1, (kernel_size // 2) - 1)
|
||||
padding = padding if padding is not None else max(1, (kernel_size // 2))
|
||||
stride = default(stride, max(1, (kernel_size // 2) - 1))
|
||||
padding = default(padding, max(1, (kernel_size // 2)))
|
||||
|
||||
return CCT(num_layers=num_layers,
|
||||
num_heads=num_heads,
|
||||
mlp_ratio=mlp_ratio,
|
||||
@@ -61,13 +70,22 @@ def _cct(num_layers, num_heads, mlp_ratio, embedding_dim,
|
||||
padding=padding,
|
||||
*args, **kwargs)
|
||||
|
||||
# positional
|
||||
|
||||
def sinusoidal_embedding(n_channels, dim):
|
||||
pe = torch.FloatTensor([[p / (10000 ** (2 * (i // 2) / dim)) for i in range(dim)]
|
||||
for p in range(n_channels)])
|
||||
pe[:, 0::2] = torch.sin(pe[:, 0::2])
|
||||
pe[:, 1::2] = torch.cos(pe[:, 1::2])
|
||||
return rearrange(pe, '... -> 1 ...')
|
||||
|
||||
# modules
|
||||
|
||||
class Attention(nn.Module):
|
||||
def __init__(self, dim, num_heads=8, attention_dropout=0.1, projection_dropout=0.1):
|
||||
super().__init__()
|
||||
self.num_heads = num_heads
|
||||
head_dim = dim // self.num_heads
|
||||
self.heads = num_heads
|
||||
head_dim = dim // self.heads
|
||||
self.scale = head_dim ** -0.5
|
||||
|
||||
self.qkv = nn.Linear(dim, dim * 3, bias=False)
|
||||
@@ -77,17 +95,20 @@ class Attention(nn.Module):
|
||||
|
||||
def forward(self, x):
|
||||
B, N, C = x.shape
|
||||
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
|
||||
q, k, v = qkv[0], qkv[1], qkv[2]
|
||||
|
||||
attn = (q @ k.transpose(-2, -1)) * self.scale
|
||||
qkv = self.qkv(x).chunk(3, dim = -1)
|
||||
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)
|
||||
|
||||
q = q * self.scale
|
||||
|
||||
attn = einsum('b h i d, b h j d -> b h i j', q, k)
|
||||
attn = attn.softmax(dim=-1)
|
||||
attn = self.attn_drop(attn)
|
||||
|
||||
x = (attn @ v).transpose(1, 2).reshape(B, N, C)
|
||||
x = self.proj(x)
|
||||
x = self.proj_drop(x)
|
||||
return x
|
||||
x = einsum('b h i j, b h j d -> b h i d', attn, v)
|
||||
x = rearrange(x, 'b h n d -> b n (h d)')
|
||||
|
||||
return self.proj_drop(self.proj(x))
|
||||
|
||||
|
||||
class TransformerEncoderLayer(nn.Module):
|
||||
@@ -97,7 +118,8 @@ class TransformerEncoderLayer(nn.Module):
|
||||
"""
|
||||
def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1,
|
||||
attention_dropout=0.1, drop_path_rate=0.1):
|
||||
super(TransformerEncoderLayer, self).__init__()
|
||||
super().__init__()
|
||||
|
||||
self.pre_norm = nn.LayerNorm(d_model)
|
||||
self.self_attn = Attention(dim=d_model, num_heads=nhead,
|
||||
attention_dropout=attention_dropout, projection_dropout=dropout)
|
||||
@@ -108,50 +130,34 @@ class TransformerEncoderLayer(nn.Module):
|
||||
self.linear2 = nn.Linear(dim_feedforward, d_model)
|
||||
self.dropout2 = nn.Dropout(dropout)
|
||||
|
||||
self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0 else nn.Identity()
|
||||
self.drop_path = DropPath(drop_path_rate)
|
||||
|
||||
self.activation = F.gelu
|
||||
|
||||
def forward(self, src: torch.Tensor, *args, **kwargs) -> torch.Tensor:
|
||||
def forward(self, src, *args, **kwargs):
|
||||
src = src + self.drop_path(self.self_attn(self.pre_norm(src)))
|
||||
src = self.norm1(src)
|
||||
src2 = self.linear2(self.dropout1(self.activation(self.linear1(src))))
|
||||
src = src + self.drop_path(self.dropout2(src2))
|
||||
return src
|
||||
|
||||
|
||||
def drop_path(x, drop_prob: float = 0., training: bool = False):
|
||||
"""
|
||||
Obtained from: github.com:rwightman/pytorch-image-models
|
||||
Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
|
||||
This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
|
||||
the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
|
||||
See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
|
||||
changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
|
||||
'survival rate' as the argument.
|
||||
"""
|
||||
if drop_prob == 0. or not training:
|
||||
return x
|
||||
keep_prob = 1 - drop_prob
|
||||
shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
|
||||
random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
|
||||
random_tensor.floor_() # binarize
|
||||
output = x.div(keep_prob) * random_tensor
|
||||
return output
|
||||
|
||||
|
||||
class DropPath(nn.Module):
|
||||
"""
|
||||
Obtained from: github.com:rwightman/pytorch-image-models
|
||||
Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
|
||||
"""
|
||||
def __init__(self, drop_prob=None):
|
||||
super(DropPath, self).__init__()
|
||||
self.drop_prob = drop_prob
|
||||
super().__init__()
|
||||
self.drop_prob = float(drop_prob)
|
||||
|
||||
def forward(self, x):
|
||||
return drop_path(x, self.drop_prob, self.training)
|
||||
batch, drop_prob, device, dtype = x.shape[0], self.drop_prob, x.device, x.dtype
|
||||
|
||||
if drop_prob <= 0. or not self.training:
|
||||
return x
|
||||
|
||||
keep_prob = 1 - self.drop_prob
|
||||
shape = (batch, *((1,) * (x.ndim - 1)))
|
||||
|
||||
keep_mask = torch.zeros(shape, device = device).float().uniform_(0, 1) < keep_prob
|
||||
output = x.div(keep_prob) * keep_mask.float()
|
||||
return output
|
||||
|
||||
class Tokenizer(nn.Module):
|
||||
def __init__(self,
|
||||
@@ -164,34 +170,35 @@ class Tokenizer(nn.Module):
|
||||
activation=None,
|
||||
max_pool=True,
|
||||
conv_bias=False):
|
||||
super(Tokenizer, self).__init__()
|
||||
super().__init__()
|
||||
|
||||
n_filter_list = [n_input_channels] + \
|
||||
[in_planes for _ in range(n_conv_layers - 1)] + \
|
||||
[n_output_channels]
|
||||
|
||||
n_filter_list_pairs = zip(n_filter_list[:-1], n_filter_list[1:])
|
||||
|
||||
self.conv_layers = nn.Sequential(
|
||||
*[nn.Sequential(
|
||||
nn.Conv2d(n_filter_list[i], n_filter_list[i + 1],
|
||||
nn.Conv2d(chan_in, chan_out,
|
||||
kernel_size=(kernel_size, kernel_size),
|
||||
stride=(stride, stride),
|
||||
padding=(padding, padding), bias=conv_bias),
|
||||
nn.Identity() if activation is None else activation(),
|
||||
nn.Identity() if not exists(activation) else activation(),
|
||||
nn.MaxPool2d(kernel_size=pooling_kernel_size,
|
||||
stride=pooling_stride,
|
||||
padding=pooling_padding) if max_pool else nn.Identity()
|
||||
)
|
||||
for i in range(n_conv_layers)
|
||||
for chan_in, chan_out in n_filter_list_pairs
|
||||
])
|
||||
|
||||
self.flattener = nn.Flatten(2, 3)
|
||||
self.apply(self.init_weight)
|
||||
|
||||
def sequence_length(self, n_channels=3, height=224, width=224):
|
||||
return self.forward(torch.zeros((1, n_channels, height, width))).shape[1]
|
||||
|
||||
def forward(self, x):
|
||||
return self.flattener(self.conv_layers(x)).transpose(-2, -1)
|
||||
return rearrange(self.conv_layers(x), 'b c h w -> b (h w) c')
|
||||
|
||||
@staticmethod
|
||||
def init_weight(m):
|
||||
@@ -214,106 +221,104 @@ class TransformerClassifier(nn.Module):
|
||||
sequence_length=None,
|
||||
*args, **kwargs):
|
||||
super().__init__()
|
||||
positional_embedding = positional_embedding if \
|
||||
positional_embedding in ['sine', 'learnable', 'none'] else 'sine'
|
||||
assert positional_embedding in {'sine', 'learnable', 'none'}
|
||||
|
||||
dim_feedforward = int(embedding_dim * mlp_ratio)
|
||||
self.embedding_dim = embedding_dim
|
||||
self.sequence_length = sequence_length
|
||||
self.seq_pool = seq_pool
|
||||
|
||||
assert sequence_length is not None or positional_embedding == 'none', \
|
||||
assert exists(sequence_length) or positional_embedding == 'none', \
|
||||
f"Positional embedding is set to {positional_embedding} and" \
|
||||
f" the sequence length was not specified."
|
||||
|
||||
if not seq_pool:
|
||||
sequence_length += 1
|
||||
self.class_emb = nn.Parameter(torch.zeros(1, 1, self.embedding_dim),
|
||||
requires_grad=True)
|
||||
self.class_emb = nn.Parameter(torch.zeros(1, 1, self.embedding_dim), requires_grad=True)
|
||||
else:
|
||||
self.attention_pool = nn.Linear(self.embedding_dim, 1)
|
||||
|
||||
if positional_embedding != 'none':
|
||||
if positional_embedding == 'learnable':
|
||||
self.positional_emb = nn.Parameter(torch.zeros(1, sequence_length, embedding_dim),
|
||||
requires_grad=True)
|
||||
nn.init.trunc_normal_(self.positional_emb, std=0.2)
|
||||
else:
|
||||
self.positional_emb = nn.Parameter(self.sinusoidal_embedding(sequence_length, embedding_dim),
|
||||
requires_grad=False)
|
||||
else:
|
||||
if positional_embedding == 'none':
|
||||
self.positional_emb = None
|
||||
elif positional_embedding == 'learnable':
|
||||
self.positional_emb = nn.Parameter(torch.zeros(1, sequence_length, embedding_dim),
|
||||
requires_grad=True)
|
||||
nn.init.trunc_normal_(self.positional_emb, std=0.2)
|
||||
else:
|
||||
self.positional_emb = nn.Parameter(sinusoidal_embedding(sequence_length, embedding_dim),
|
||||
requires_grad=False)
|
||||
|
||||
self.dropout = nn.Dropout(p=dropout_rate)
|
||||
|
||||
dpr = [x.item() for x in torch.linspace(0, stochastic_depth_rate, num_layers)]
|
||||
|
||||
self.blocks = nn.ModuleList([
|
||||
TransformerEncoderLayer(d_model=embedding_dim, nhead=num_heads,
|
||||
dim_feedforward=dim_feedforward, dropout=dropout_rate,
|
||||
attention_dropout=attention_dropout, drop_path_rate=dpr[i])
|
||||
for i in range(num_layers)])
|
||||
attention_dropout=attention_dropout, drop_path_rate=layer_dpr)
|
||||
for layer_dpr in dpr])
|
||||
|
||||
self.norm = nn.LayerNorm(embedding_dim)
|
||||
|
||||
self.fc = nn.Linear(embedding_dim, num_classes)
|
||||
self.apply(self.init_weight)
|
||||
|
||||
def forward(self, x):
|
||||
if self.positional_emb is None and x.size(1) < self.sequence_length:
|
||||
b = x.shape[0]
|
||||
|
||||
if not exists(self.positional_emb) and x.size(1) < self.sequence_length:
|
||||
x = F.pad(x, (0, 0, 0, self.n_channels - x.size(1)), mode='constant', value=0)
|
||||
|
||||
if not self.seq_pool:
|
||||
cls_token = self.class_emb.expand(x.shape[0], -1, -1)
|
||||
cls_token = repeat(self.class_emb, '1 1 d -> b 1 d', b = b)
|
||||
x = torch.cat((cls_token, x), dim=1)
|
||||
|
||||
if self.positional_emb is not None:
|
||||
if exists(self.positional_emb):
|
||||
x += self.positional_emb
|
||||
|
||||
x = self.dropout(x)
|
||||
|
||||
for blk in self.blocks:
|
||||
x = blk(x)
|
||||
|
||||
x = self.norm(x)
|
||||
|
||||
if self.seq_pool:
|
||||
x = torch.matmul(F.softmax(self.attention_pool(x), dim=1).transpose(-1, -2), x).squeeze(-2)
|
||||
attn_weights = rearrange(self.attention_pool(x), 'b n 1 -> b n')
|
||||
x = einsum('b n, b n d -> b d', attn_weights.softmax(dim = 1), x)
|
||||
else:
|
||||
x = x[:, 0]
|
||||
|
||||
x = self.fc(x)
|
||||
return x
|
||||
return self.fc(x)
|
||||
|
||||
@staticmethod
|
||||
def init_weight(m):
|
||||
if isinstance(m, nn.Linear):
|
||||
nn.init.trunc_normal_(m.weight, std=.02)
|
||||
if isinstance(m, nn.Linear) and m.bias is not None:
|
||||
if isinstance(m, nn.Linear) and exists(m.bias):
|
||||
nn.init.constant_(m.bias, 0)
|
||||
elif isinstance(m, nn.LayerNorm):
|
||||
nn.init.constant_(m.bias, 0)
|
||||
nn.init.constant_(m.weight, 1.0)
|
||||
|
||||
@staticmethod
|
||||
def sinusoidal_embedding(n_channels, dim):
|
||||
pe = torch.FloatTensor([[p / (10000 ** (2 * (i // 2) / dim)) for i in range(dim)]
|
||||
for p in range(n_channels)])
|
||||
pe[:, 0::2] = torch.sin(pe[:, 0::2])
|
||||
pe[:, 1::2] = torch.cos(pe[:, 1::2])
|
||||
return pe.unsqueeze(0)
|
||||
|
||||
|
||||
# CCT Main model
|
||||
|
||||
class CCT(nn.Module):
|
||||
def __init__(self,
|
||||
img_size=224,
|
||||
embedding_dim=768,
|
||||
n_input_channels=3,
|
||||
n_conv_layers=1,
|
||||
kernel_size=7,
|
||||
stride=2,
|
||||
padding=3,
|
||||
pooling_kernel_size=3,
|
||||
pooling_stride=2,
|
||||
pooling_padding=1,
|
||||
*args, **kwargs):
|
||||
super(CCT, self).__init__()
|
||||
def __init__(
|
||||
self,
|
||||
img_size=224,
|
||||
embedding_dim=768,
|
||||
n_input_channels=3,
|
||||
n_conv_layers=1,
|
||||
kernel_size=7,
|
||||
stride=2,
|
||||
padding=3,
|
||||
pooling_kernel_size=3,
|
||||
pooling_stride=2,
|
||||
pooling_padding=1,
|
||||
*args, **kwargs
|
||||
):
|
||||
super().__init__()
|
||||
img_height, img_width = pair(img_size)
|
||||
|
||||
self.tokenizer = Tokenizer(n_input_channels=n_input_channels,
|
||||
|
||||
@@ -4,6 +4,12 @@ from torch import nn
|
||||
def exists(val):
|
||||
return val is not None
|
||||
|
||||
def identity(t):
|
||||
return t
|
||||
|
||||
def clone_and_detach(t):
|
||||
return t.clone().detach()
|
||||
|
||||
def apply_tuple_or_single(fn, val):
|
||||
if isinstance(val, tuple):
|
||||
return tuple(map(fn, val))
|
||||
@@ -17,7 +23,8 @@ class Extractor(nn.Module):
|
||||
layer = None,
|
||||
layer_name = 'transformer',
|
||||
layer_save_input = False,
|
||||
return_embeddings_only = False
|
||||
return_embeddings_only = False,
|
||||
detach = True
|
||||
):
|
||||
super().__init__()
|
||||
self.vit = vit
|
||||
@@ -34,9 +41,11 @@ class Extractor(nn.Module):
|
||||
self.layer_save_input = layer_save_input # whether to save input or output of layer
|
||||
self.return_embeddings_only = return_embeddings_only
|
||||
|
||||
self.detach_fn = clone_and_detach if detach else identity
|
||||
|
||||
def _hook(self, _, inputs, output):
|
||||
layer_output = inputs if self.layer_save_input else output
|
||||
self.latents = apply_tuple_or_single(lambda t: t.clone().detach(), layer_output)
|
||||
self.latents = apply_tuple_or_single(self.detach_fn, layer_output)
|
||||
|
||||
def _register_hook(self):
|
||||
if not exists(self.layer):
|
||||
|
||||
@@ -28,7 +28,7 @@ class MAE(nn.Module):
|
||||
pixel_values_per_patch = self.patch_to_emb.weight.shape[-1]
|
||||
|
||||
# decoder parameters
|
||||
|
||||
self.decoder_dim = decoder_dim
|
||||
self.enc_to_dec = nn.Linear(encoder_dim, decoder_dim) if encoder_dim != decoder_dim else nn.Identity()
|
||||
self.mask_token = nn.Parameter(torch.randn(decoder_dim))
|
||||
self.decoder = Transformer(dim = decoder_dim, depth = decoder_depth, heads = decoder_heads, dim_head = decoder_dim_head, mlp_dim = decoder_dim * 4)
|
||||
@@ -73,7 +73,7 @@ class MAE(nn.Module):
|
||||
|
||||
# reapply decoder position embedding to unmasked tokens
|
||||
|
||||
decoder_tokens = decoder_tokens + self.decoder_pos_emb(unmasked_indices)
|
||||
unmasked_decoder_tokens = decoder_tokens + self.decoder_pos_emb(unmasked_indices)
|
||||
|
||||
# repeat mask tokens for number of masked, and add the positions using the masked indices derived above
|
||||
|
||||
@@ -81,13 +81,15 @@ class MAE(nn.Module):
|
||||
mask_tokens = mask_tokens + self.decoder_pos_emb(masked_indices)
|
||||
|
||||
# concat the masked tokens to the decoder tokens and attend with decoder
|
||||
|
||||
decoder_tokens = torch.cat((mask_tokens, decoder_tokens), dim = 1)
|
||||
|
||||
decoder_tokens = torch.zeros(batch, num_patches, self.decoder_dim, device=device)
|
||||
decoder_tokens[batch_range, unmasked_indices] = unmasked_decoder_tokens
|
||||
decoder_tokens[batch_range, masked_indices] = mask_tokens
|
||||
decoded_tokens = self.decoder(decoder_tokens)
|
||||
|
||||
# splice out the mask tokens and project to pixel values
|
||||
|
||||
mask_tokens = decoded_tokens[:, :num_masked]
|
||||
mask_tokens = decoded_tokens[batch_range, masked_indices]
|
||||
pred_pixel_values = self.to_pixels(mask_tokens)
|
||||
|
||||
# calculate reconstruction loss
|
||||
|
||||
@@ -13,9 +13,9 @@ def conv_1x1_bn(inp, oup):
|
||||
nn.SiLU()
|
||||
)
|
||||
|
||||
def conv_nxn_bn(inp, oup, kernal_size=3, stride=1):
|
||||
def conv_nxn_bn(inp, oup, kernel_size=3, stride=1):
|
||||
return nn.Sequential(
|
||||
nn.Conv2d(inp, oup, kernal_size, stride, 1, bias=False),
|
||||
nn.Conv2d(inp, oup, kernel_size, stride, 1, bias=False),
|
||||
nn.BatchNorm2d(oup),
|
||||
nn.SiLU()
|
||||
)
|
||||
|
||||
@@ -131,7 +131,7 @@ class NesT(nn.Module):
|
||||
fmap_size = image_size // patch_size
|
||||
blocks = 2 ** (num_hierarchies - 1)
|
||||
|
||||
seq_len = (fmap_size // blocks) ** 2 # sequence length is held constant across heirarchy
|
||||
seq_len = (fmap_size // blocks) ** 2 # sequence length is held constant across hierarchy
|
||||
hierarchies = list(reversed(range(num_hierarchies)))
|
||||
mults = [2 ** i for i in reversed(hierarchies)]
|
||||
|
||||
|
||||
128
vit_pytorch/simple_vit_3d.py
Normal file
128
vit_pytorch/simple_vit_3d.py
Normal file
@@ -0,0 +1,128 @@
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import nn
|
||||
|
||||
from einops import rearrange
|
||||
from einops.layers.torch import Rearrange
|
||||
|
||||
# helpers
|
||||
|
||||
def pair(t):
|
||||
return t if isinstance(t, tuple) else (t, t)
|
||||
|
||||
def posemb_sincos_3d(patches, temperature = 10000, dtype = torch.float32):
|
||||
_, f, h, w, dim, device, dtype = *patches.shape, patches.device, patches.dtype
|
||||
|
||||
z, y, x = torch.meshgrid(
|
||||
torch.arange(f, device = device),
|
||||
torch.arange(h, device = device),
|
||||
torch.arange(w, device = device),
|
||||
indexing = 'ij')
|
||||
|
||||
fourier_dim = dim // 6
|
||||
|
||||
omega = torch.arange(fourier_dim, device = device) / (fourier_dim - 1)
|
||||
omega = 1. / (temperature ** omega)
|
||||
|
||||
z = z.flatten()[:, None] * omega[None, :]
|
||||
y = y.flatten()[:, None] * omega[None, :]
|
||||
x = x.flatten()[:, None] * omega[None, :]
|
||||
|
||||
pe = torch.cat((x.sin(), x.cos(), y.sin(), y.cos(), z.sin(), z.cos()), dim = 1)
|
||||
|
||||
pe = F.pad(pe, (0, dim - (fourier_dim * 6))) # pad if feature dimension not cleanly divisible by 6
|
||||
return pe.type(dtype)
|
||||
|
||||
# classes
|
||||
|
||||
class FeedForward(nn.Module):
|
||||
def __init__(self, dim, hidden_dim):
|
||||
super().__init__()
|
||||
self.net = nn.Sequential(
|
||||
nn.LayerNorm(dim),
|
||||
nn.Linear(dim, hidden_dim),
|
||||
nn.GELU(),
|
||||
nn.Linear(hidden_dim, dim),
|
||||
)
|
||||
def forward(self, x):
|
||||
return self.net(x)
|
||||
|
||||
class Attention(nn.Module):
|
||||
def __init__(self, dim, heads = 8, dim_head = 64):
|
||||
super().__init__()
|
||||
inner_dim = dim_head * heads
|
||||
self.heads = heads
|
||||
self.scale = dim_head ** -0.5
|
||||
self.norm = nn.LayerNorm(dim)
|
||||
|
||||
self.attend = nn.Softmax(dim = -1)
|
||||
|
||||
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
|
||||
self.to_out = nn.Linear(inner_dim, dim, bias = False)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.norm(x)
|
||||
|
||||
qkv = self.to_qkv(x).chunk(3, dim = -1)
|
||||
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)
|
||||
|
||||
dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
|
||||
|
||||
attn = self.attend(dots)
|
||||
|
||||
out = torch.matmul(attn, v)
|
||||
out = rearrange(out, 'b h n d -> b n (h d)')
|
||||
return self.to_out(out)
|
||||
|
||||
class Transformer(nn.Module):
|
||||
def __init__(self, dim, depth, heads, dim_head, mlp_dim):
|
||||
super().__init__()
|
||||
self.layers = nn.ModuleList([])
|
||||
for _ in range(depth):
|
||||
self.layers.append(nn.ModuleList([
|
||||
Attention(dim, heads = heads, dim_head = dim_head),
|
||||
FeedForward(dim, mlp_dim)
|
||||
]))
|
||||
def forward(self, x):
|
||||
for attn, ff in self.layers:
|
||||
x = attn(x) + x
|
||||
x = ff(x) + x
|
||||
return x
|
||||
|
||||
class SimpleViT(nn.Module):
|
||||
def __init__(self, *, image_size, image_patch_size, frames, frame_patch_size, num_classes, dim, depth, heads, mlp_dim, channels = 3, dim_head = 64):
|
||||
super().__init__()
|
||||
image_height, image_width = pair(image_size)
|
||||
patch_height, patch_width = pair(image_patch_size)
|
||||
|
||||
assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'
|
||||
assert frames % frame_patch_size == 0, 'Frames must be divisible by the frame patch size'
|
||||
|
||||
num_patches = (image_height // patch_height) * (image_width // patch_width) * (frames // frame_patch_size)
|
||||
patch_dim = channels * patch_height * patch_width * frame_patch_size
|
||||
|
||||
self.to_patch_embedding = nn.Sequential(
|
||||
Rearrange('b c (f pf) (h p1) (w p2) -> b f h w (p1 p2 pf c)', p1 = patch_height, p2 = patch_width, pf = frame_patch_size),
|
||||
nn.Linear(patch_dim, dim),
|
||||
)
|
||||
|
||||
self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim)
|
||||
|
||||
self.to_latent = nn.Identity()
|
||||
self.linear_head = nn.Sequential(
|
||||
nn.LayerNorm(dim),
|
||||
nn.Linear(dim, num_classes)
|
||||
)
|
||||
|
||||
def forward(self, video):
|
||||
*_, h, w, dtype = *video.shape, video.dtype
|
||||
|
||||
x = self.to_patch_embedding(video)
|
||||
pe = posemb_sincos_3d(x)
|
||||
x = rearrange(x, 'b ... d -> b (...) d') + pe
|
||||
|
||||
x = self.transformer(x)
|
||||
x = x.mean(dim = 1)
|
||||
|
||||
x = self.to_latent(x)
|
||||
return self.linear_head(x)
|
||||
129
vit_pytorch/vit_3d.py
Normal file
129
vit_pytorch/vit_3d.py
Normal file
@@ -0,0 +1,129 @@
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from einops import rearrange, repeat
|
||||
from einops.layers.torch import Rearrange
|
||||
|
||||
# helpers
|
||||
|
||||
def pair(t):
|
||||
return t if isinstance(t, tuple) else (t, t)
|
||||
|
||||
# classes
|
||||
|
||||
class PreNorm(nn.Module):
|
||||
def __init__(self, dim, fn):
|
||||
super().__init__()
|
||||
self.norm = nn.LayerNorm(dim)
|
||||
self.fn = fn
|
||||
def forward(self, x, **kwargs):
|
||||
return self.fn(self.norm(x), **kwargs)
|
||||
|
||||
class FeedForward(nn.Module):
|
||||
def __init__(self, dim, hidden_dim, dropout = 0.):
|
||||
super().__init__()
|
||||
self.net = nn.Sequential(
|
||||
nn.Linear(dim, hidden_dim),
|
||||
nn.GELU(),
|
||||
nn.Dropout(dropout),
|
||||
nn.Linear(hidden_dim, dim),
|
||||
nn.Dropout(dropout)
|
||||
)
|
||||
def forward(self, x):
|
||||
return self.net(x)
|
||||
|
||||
class Attention(nn.Module):
|
||||
def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
|
||||
super().__init__()
|
||||
inner_dim = dim_head * heads
|
||||
project_out = not (heads == 1 and dim_head == dim)
|
||||
|
||||
self.heads = heads
|
||||
self.scale = dim_head ** -0.5
|
||||
|
||||
self.attend = nn.Softmax(dim = -1)
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
|
||||
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
|
||||
|
||||
self.to_out = nn.Sequential(
|
||||
nn.Linear(inner_dim, dim),
|
||||
nn.Dropout(dropout)
|
||||
) if project_out else nn.Identity()
|
||||
|
||||
def forward(self, x):
|
||||
qkv = self.to_qkv(x).chunk(3, dim = -1)
|
||||
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)
|
||||
|
||||
dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
|
||||
|
||||
attn = self.attend(dots)
|
||||
attn = self.dropout(attn)
|
||||
|
||||
out = torch.matmul(attn, v)
|
||||
out = rearrange(out, 'b h n d -> b n (h d)')
|
||||
return self.to_out(out)
|
||||
|
||||
class Transformer(nn.Module):
|
||||
def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
|
||||
super().__init__()
|
||||
self.layers = nn.ModuleList([])
|
||||
for _ in range(depth):
|
||||
self.layers.append(nn.ModuleList([
|
||||
PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)),
|
||||
PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout))
|
||||
]))
|
||||
def forward(self, x):
|
||||
for attn, ff in self.layers:
|
||||
x = attn(x) + x
|
||||
x = ff(x) + x
|
||||
return x
|
||||
|
||||
class ViT(nn.Module):
|
||||
def __init__(self, *, image_size, image_patch_size, frames, frame_patch_size, num_classes, dim, depth, heads, mlp_dim, pool = 'cls', channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0.):
|
||||
super().__init__()
|
||||
image_height, image_width = pair(image_size)
|
||||
patch_height, patch_width = pair(image_patch_size)
|
||||
|
||||
assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'
|
||||
assert frames % frame_patch_size == 0, 'Frames must be divisible by frame patch size'
|
||||
|
||||
num_patches = (image_height // patch_height) * (image_width // patch_width) * (frames // frame_patch_size)
|
||||
patch_dim = channels * patch_height * patch_width * frame_patch_size
|
||||
|
||||
assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'
|
||||
|
||||
self.to_patch_embedding = nn.Sequential(
|
||||
Rearrange('b c (f pf) (h p1) (w p2) -> b (f h w) (p1 p2 pf c)', p1 = patch_height, p2 = patch_width, pf = frame_patch_size),
|
||||
nn.Linear(patch_dim, dim),
|
||||
)
|
||||
|
||||
self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
|
||||
self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
|
||||
self.dropout = nn.Dropout(emb_dropout)
|
||||
|
||||
self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)
|
||||
|
||||
self.pool = pool
|
||||
self.to_latent = nn.Identity()
|
||||
|
||||
self.mlp_head = nn.Sequential(
|
||||
nn.LayerNorm(dim),
|
||||
nn.Linear(dim, num_classes)
|
||||
)
|
||||
|
||||
def forward(self, video):
|
||||
x = self.to_patch_embedding(video)
|
||||
b, n, _ = x.shape
|
||||
|
||||
cls_tokens = repeat(self.cls_token, '1 1 d -> b 1 d', b = b)
|
||||
x = torch.cat((cls_tokens, x), dim=1)
|
||||
x += self.pos_embedding[:, :(n + 1)]
|
||||
x = self.dropout(x)
|
||||
|
||||
x = self.transformer(x)
|
||||
|
||||
x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0]
|
||||
|
||||
x = self.to_latent(x)
|
||||
return self.mlp_head(x)
|
||||
183
vit_pytorch/vivit.py
Normal file
183
vit_pytorch/vivit.py
Normal file
@@ -0,0 +1,183 @@
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from einops import rearrange, repeat, reduce
|
||||
from einops.layers.torch import Rearrange
|
||||
|
||||
# helpers
|
||||
|
||||
def exists(val):
|
||||
return val is not None
|
||||
|
||||
def pair(t):
|
||||
return t if isinstance(t, tuple) else (t, t)
|
||||
|
||||
# classes
|
||||
|
||||
class PreNorm(nn.Module):
|
||||
def __init__(self, dim, fn):
|
||||
super().__init__()
|
||||
self.norm = nn.LayerNorm(dim)
|
||||
self.fn = fn
|
||||
def forward(self, x, **kwargs):
|
||||
return self.fn(self.norm(x), **kwargs)
|
||||
|
||||
class FeedForward(nn.Module):
|
||||
def __init__(self, dim, hidden_dim, dropout = 0.):
|
||||
super().__init__()
|
||||
self.net = nn.Sequential(
|
||||
nn.Linear(dim, hidden_dim),
|
||||
nn.GELU(),
|
||||
nn.Dropout(dropout),
|
||||
nn.Linear(hidden_dim, dim),
|
||||
nn.Dropout(dropout)
|
||||
)
|
||||
def forward(self, x):
|
||||
return self.net(x)
|
||||
|
||||
class Attention(nn.Module):
|
||||
def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
|
||||
super().__init__()
|
||||
inner_dim = dim_head * heads
|
||||
project_out = not (heads == 1 and dim_head == dim)
|
||||
|
||||
self.heads = heads
|
||||
self.scale = dim_head ** -0.5
|
||||
|
||||
self.attend = nn.Softmax(dim = -1)
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
|
||||
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
|
||||
|
||||
self.to_out = nn.Sequential(
|
||||
nn.Linear(inner_dim, dim),
|
||||
nn.Dropout(dropout)
|
||||
) if project_out else nn.Identity()
|
||||
|
||||
def forward(self, x):
|
||||
qkv = self.to_qkv(x).chunk(3, dim = -1)
|
||||
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)
|
||||
|
||||
dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
|
||||
|
||||
attn = self.attend(dots)
|
||||
attn = self.dropout(attn)
|
||||
|
||||
out = torch.matmul(attn, v)
|
||||
out = rearrange(out, 'b h n d -> b n (h d)')
|
||||
return self.to_out(out)
|
||||
|
||||
class Transformer(nn.Module):
|
||||
def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
|
||||
super().__init__()
|
||||
self.layers = nn.ModuleList([])
|
||||
for _ in range(depth):
|
||||
self.layers.append(nn.ModuleList([
|
||||
PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)),
|
||||
PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout))
|
||||
]))
|
||||
def forward(self, x):
|
||||
for attn, ff in self.layers:
|
||||
x = attn(x) + x
|
||||
x = ff(x) + x
|
||||
return x
|
||||
|
||||
class ViT(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
image_size,
|
||||
image_patch_size,
|
||||
frames,
|
||||
frame_patch_size,
|
||||
num_classes,
|
||||
dim,
|
||||
spatial_depth,
|
||||
temporal_depth,
|
||||
heads,
|
||||
mlp_dim,
|
||||
pool = 'cls',
|
||||
channels = 3,
|
||||
dim_head = 64,
|
||||
dropout = 0.,
|
||||
emb_dropout = 0.
|
||||
):
|
||||
super().__init__()
|
||||
image_height, image_width = pair(image_size)
|
||||
patch_height, patch_width = pair(image_patch_size)
|
||||
|
||||
assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'
|
||||
assert frames % frame_patch_size == 0, 'Frames must be divisible by frame patch size'
|
||||
|
||||
num_image_patches = (image_height // patch_height) * (image_width // patch_width)
|
||||
num_frame_patches = (frames // frame_patch_size)
|
||||
|
||||
patch_dim = channels * patch_height * patch_width * frame_patch_size
|
||||
|
||||
assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'
|
||||
|
||||
self.global_average_pool = pool == 'mean'
|
||||
|
||||
self.to_patch_embedding = nn.Sequential(
|
||||
Rearrange('b c (f pf) (h p1) (w p2) -> b f (h w) (p1 p2 pf c)', p1 = patch_height, p2 = patch_width, pf = frame_patch_size),
|
||||
nn.Linear(patch_dim, dim),
|
||||
)
|
||||
|
||||
self.pos_embedding = nn.Parameter(torch.randn(1, num_frame_patches, num_image_patches, dim))
|
||||
self.dropout = nn.Dropout(emb_dropout)
|
||||
|
||||
self.spatial_cls_token = nn.Parameter(torch.randn(1, 1, dim)) if not self.global_average_pool else None
|
||||
self.temporal_cls_token = nn.Parameter(torch.randn(1, 1, dim)) if not self.global_average_pool else None
|
||||
|
||||
self.spatial_transformer = Transformer(dim, spatial_depth, heads, dim_head, mlp_dim, dropout)
|
||||
self.temporal_transformer = Transformer(dim, temporal_depth, heads, dim_head, mlp_dim, dropout)
|
||||
|
||||
self.pool = pool
|
||||
self.to_latent = nn.Identity()
|
||||
|
||||
self.mlp_head = nn.Sequential(
|
||||
nn.LayerNorm(dim),
|
||||
nn.Linear(dim, num_classes)
|
||||
)
|
||||
|
||||
def forward(self, video):
|
||||
x = self.to_patch_embedding(video)
|
||||
b, f, n, _ = x.shape
|
||||
|
||||
x = x + self.pos_embedding
|
||||
|
||||
if exists(self.spatial_cls_token):
|
||||
spatial_cls_tokens = repeat(self.spatial_cls_token, '1 1 d -> b f 1 d', b = b, f = f)
|
||||
x = torch.cat((spatial_cls_tokens, x), dim = 2)
|
||||
|
||||
x = self.dropout(x)
|
||||
|
||||
x = rearrange(x, 'b f n d -> (b f) n d')
|
||||
|
||||
# attend across space
|
||||
|
||||
x = self.spatial_transformer(x)
|
||||
|
||||
x = rearrange(x, '(b f) n d -> b f n d', b = b)
|
||||
|
||||
# excise out the spatial cls tokens or average pool for temporal attention
|
||||
|
||||
x = x[:, :, 0] if not self.global_average_pool else reduce(x, 'b f n d -> b f d', 'mean')
|
||||
|
||||
# append temporal CLS tokens
|
||||
|
||||
if exists(self.temporal_cls_token):
|
||||
temporal_cls_tokens = repeat(self.temporal_cls_token, '1 1 d-> b 1 d', b = b)
|
||||
|
||||
x = torch.cat((temporal_cls_tokens, x), dim = 1)
|
||||
|
||||
# attend across time
|
||||
|
||||
x = self.temporal_transformer(x)
|
||||
|
||||
# excise out temporal cls token or average pool
|
||||
|
||||
x = x[:, 0] if not self.global_average_pool else reduce(x, 'b f d -> b d', 'mean')
|
||||
|
||||
x = self.to_latent(x)
|
||||
return self.mlp_head(x)
|
||||
Reference in New Issue
Block a user