Compare commits

...

8 Commits

Author SHA1 Message Date
Phil Wang
6ec8fdaa6d make sure global average pool can be used for vivit in place of cls token 2022-10-24 19:59:48 -07:00
Phil Wang
13fabf901e add vivit 2022-10-24 09:34:04 -07:00
Ryan Russell
c0eb4c0150 Improving Readability (#220)
Signed-off-by: Ryan Russell <git@ryanrussell.org>

Signed-off-by: Ryan Russell <git@ryanrussell.org>
2022-10-17 10:42:45 -07:00
Phil Wang
5f1a6a05e9 release updated mae where one can more easily visualize reconstructions, thanks to @Vishu26 2022-10-17 10:41:46 -07:00
Srikumar Sastry
9a95e7904e Update mae.py (#242)
update mae so decoded tokens can be easily reshaped back to visualize the reconstruction
2022-10-17 10:41:10 -07:00
Phil Wang
b4853d39c2 add the 3d simple vit 2022-10-16 20:45:30 -07:00
Phil Wang
29fbf0aff4 begin extending some of the architectures over to 3d, starting with basic ViT 2022-10-16 15:31:59 -07:00
Phil Wang
4b8f5bc900 add link to Flax translation by @conceptofmind 2022-07-27 08:58:18 -07:00
10 changed files with 553 additions and 15 deletions

100
README.md
View File

@@ -30,6 +30,8 @@
- [Adaptive Token Sampling](#adaptive-token-sampling)
- [Patch Merger](#patch-merger)
- [Vision Transformer for Small Datasets](#vision-transformer-for-small-datasets)
- [3D Vit](#3d-vit)
- [ViVit](#vivit)
- [Parallel ViT](#parallel-vit)
- [Learnable Memory ViT](#learnable-memory-vit)
- [Dino](#dino)
@@ -52,6 +54,8 @@ The official Jax repository is <a href="https://github.com/google-research/visio
A tensorflow2 translation also exists <a href="https://github.com/taki0112/vit-tensorflow">here</a>, created by research scientist <a href="https://github.com/taki0112">Junho Kim</a>! 🙏
<a href="https://github.com/conceptofmind/vit-flax">Flax translation</a> by <a href="https://github.com/conceptofmind">Enrico Shippole</a>!
## Install
```bash
@@ -661,7 +665,7 @@ preds = v(img) # (2, 1000)
<img src="./images/nest.png" width="400px"></img>
This <a href="https://arxiv.org/abs/2105.12723">paper</a> decided to process the image in hierarchical stages, with attention only within tokens of local blocks, which aggregate as it moves up the heirarchy. The aggregation is done in the image plane, and contains a convolution and subsequent maxpool to allow it to pass information across the boundary.
This <a href="https://arxiv.org/abs/2105.12723">paper</a> decided to process the image in hierarchical stages, with attention only within tokens of local blocks, which aggregate as it moves up the hierarchy. The aggregation is done in the image plane, and contains a convolution and subsequent maxpool to allow it to pass information across the boundary.
You can use it with the following code (ex. NesT-T)
@@ -675,7 +679,7 @@ nest = NesT(
dim = 96,
heads = 3,
num_hierarchies = 3, # number of hierarchies
block_repeats = (2, 2, 8), # the number of transformer blocks at each heirarchy, starting from the bottom
block_repeats = (2, 2, 8), # the number of transformer blocks at each hierarchy, starting from the bottom
num_classes = 1000
)
@@ -965,6 +969,88 @@ img = torch.randn(4, 3, 256, 256)
tokens = spt(img) # (4, 256, 1024)
```
## 3D ViT
By popular request, I will start extending a few of the architectures in this repository to 3D ViTs, for use with video, medical imaging, etc.
You will need to pass in two additional hyperparameters: (1) the number of frames `frames` and (2) patch size along the frame dimension `frame_patch_size`
For starters, 3D ViT
```python
import torch
from vit_pytorch.vit_3d import ViT
v = ViT(
image_size = 128, # image size
frames = 16, # number of frames
image_patch_size = 16, # image patch size
frame_patch_size = 2, # frame patch size
num_classes = 1000,
dim = 1024,
depth = 6,
heads = 8,
mlp_dim = 2048,
dropout = 0.1,
emb_dropout = 0.1
)
video = torch.randn(4, 3, 16, 128, 128) # (batch, channels, frames, height, width)
preds = v(video) # (4, 1000)
```
3D Simple ViT
```python
import torch
from vit_pytorch.simple_vit_3d import SimpleViT
v = SimpleViT(
image_size = 128, # image size
frames = 16, # number of frames
image_patch_size = 16, # image patch size
frame_patch_size = 2, # frame patch size
num_classes = 1000,
dim = 1024,
depth = 6,
heads = 8,
mlp_dim = 2048
)
video = torch.randn(4, 3, 16, 128, 128) # (batch, channels, frames, height, width)
preds = v(video) # (4, 1000)
```
## ViViT
<img src="./images/vivit.png" width="350px"></img>
This <a href="https://arxiv.org/abs/2103.15691">paper</a> offers 3 different types of architectures for efficient attention of videos, with the main theme being factorizing the attention across space and time. This repository will offer the first variant, which is a spatial transformer followed by a temporal one.
```python
import torch
from vit_pytorch.vivit import ViT
v = ViT(
image_size = 128, # image size
frames = 16, # number of frames
image_patch_size = 16, # image patch size
frame_patch_size = 2, # frame patch size
num_classes = 1000,
dim = 1024,
spatial_depth = 6, # depth of the spatial transformer
temporal_depth = 6, # depth of the temporal transformer
heads = 8,
mlp_dim = 2048
)
video = torch.randn(4, 3, 16, 128, 128) # (batch, channels, frames, height, width)
preds = v(video) # (4, 1000)
```
## Parallel ViT
<img src="./images/parallel-vit.png" width="350px"></img>
@@ -1748,6 +1834,16 @@ Coming from computer vision and new to transformers? Here are some resources tha
```
```bibtex
@article{Arnab2021ViViTAV,
title = {ViViT: A Video Vision Transformer},
author = {Anurag Arnab and Mostafa Dehghani and Georg Heigold and Chen Sun and Mario Lucic and Cordelia Schmid},
journal = {2021 IEEE/CVF International Conference on Computer Vision (ICCV)},
year = {2021},
pages = {6816-6826}
}
```
```bibtex
@misc{vaswani2017attention,
title = {Attention Is All You Need},

View File

@@ -16,7 +16,7 @@
"\n",
"* Dogs vs. Cats Redux: Kernels Edition - https://www.kaggle.com/c/dogs-vs-cats-redux-kernels-edition\n",
"* Base Code - https://www.kaggle.com/reukki/pytorch-cnn-tutorial-with-cats-and-dogs/\n",
"* Effecient Attention Implementation - https://github.com/lucidrains/vit-pytorch#efficient-attention"
"* Efficient Attention Implementation - https://github.com/lucidrains/vit-pytorch#efficient-attention"
]
},
{
@@ -342,7 +342,7 @@
"id": "ZhYDJXk2SRDu"
},
"source": [
"## Image Augumentation"
"## Image Augmentation"
]
},
{
@@ -497,7 +497,7 @@
"id": "TF9yMaRrSvmv"
},
"source": [
"## Effecient Attention"
"## Efficient Attention"
]
},
{
@@ -1307,7 +1307,7 @@
"celltoolbar": "Edit Metadata",
"colab": {
"collapsed_sections": [],
"name": "Effecient Attention | Cats & Dogs",
"name": "Efficient Attention | Cats & Dogs",
"provenance": [],
"toc_visible": true
},

BIN
images/vivit.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 104 KiB

View File

@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
setup(
name = 'vit-pytorch',
packages = find_packages(exclude=['examples']),
version = '0.35.8',
version = '0.37.1',
license='MIT',
description = 'Vision Transformer (ViT) - Pytorch',
long_description_content_type = 'text/markdown',

View File

@@ -28,7 +28,7 @@ class MAE(nn.Module):
pixel_values_per_patch = self.patch_to_emb.weight.shape[-1]
# decoder parameters
self.decoder_dim = decoder_dim
self.enc_to_dec = nn.Linear(encoder_dim, decoder_dim) if encoder_dim != decoder_dim else nn.Identity()
self.mask_token = nn.Parameter(torch.randn(decoder_dim))
self.decoder = Transformer(dim = decoder_dim, depth = decoder_depth, heads = decoder_heads, dim_head = decoder_dim_head, mlp_dim = decoder_dim * 4)
@@ -73,7 +73,7 @@ class MAE(nn.Module):
# reapply decoder position embedding to unmasked tokens
decoder_tokens = decoder_tokens + self.decoder_pos_emb(unmasked_indices)
unmasked_decoder_tokens = decoder_tokens + self.decoder_pos_emb(unmasked_indices)
# repeat mask tokens for number of masked, and add the positions using the masked indices derived above
@@ -81,13 +81,15 @@ class MAE(nn.Module):
mask_tokens = mask_tokens + self.decoder_pos_emb(masked_indices)
# concat the masked tokens to the decoder tokens and attend with decoder
decoder_tokens = torch.cat((mask_tokens, decoder_tokens), dim = 1)
decoder_tokens = torch.zeros(batch, num_patches, self.decoder_dim, device=device)
decoder_tokens[batch_range, unmasked_indices] = unmasked_decoder_tokens
decoder_tokens[batch_range, masked_indices] = mask_tokens
decoded_tokens = self.decoder(decoder_tokens)
# splice out the mask tokens and project to pixel values
mask_tokens = decoded_tokens[:, :num_masked]
mask_tokens = decoded_tokens[batch_range, masked_indices]
pred_pixel_values = self.to_pixels(mask_tokens)
# calculate reconstruction loss

View File

@@ -13,9 +13,9 @@ def conv_1x1_bn(inp, oup):
nn.SiLU()
)
def conv_nxn_bn(inp, oup, kernal_size=3, stride=1):
def conv_nxn_bn(inp, oup, kernel_size=3, stride=1):
return nn.Sequential(
nn.Conv2d(inp, oup, kernal_size, stride, 1, bias=False),
nn.Conv2d(inp, oup, kernel_size, stride, 1, bias=False),
nn.BatchNorm2d(oup),
nn.SiLU()
)

View File

@@ -131,7 +131,7 @@ class NesT(nn.Module):
fmap_size = image_size // patch_size
blocks = 2 ** (num_hierarchies - 1)
seq_len = (fmap_size // blocks) ** 2 # sequence length is held constant across heirarchy
seq_len = (fmap_size // blocks) ** 2 # sequence length is held constant across hierarchy
hierarchies = list(reversed(range(num_hierarchies)))
mults = [2 ** i for i in reversed(hierarchies)]

View File

@@ -0,0 +1,128 @@
import torch
import torch.nn.functional as F
from torch import nn
from einops import rearrange
from einops.layers.torch import Rearrange
# helpers
def pair(t):
return t if isinstance(t, tuple) else (t, t)
def posemb_sincos_3d(patches, temperature = 10000, dtype = torch.float32):
_, f, h, w, dim, device, dtype = *patches.shape, patches.device, patches.dtype
z, y, x = torch.meshgrid(
torch.arange(f, device = device),
torch.arange(h, device = device),
torch.arange(w, device = device),
indexing = 'ij')
fourier_dim = dim // 6
omega = torch.arange(fourier_dim, device = device) / (fourier_dim - 1)
omega = 1. / (temperature ** omega)
z = z.flatten()[:, None] * omega[None, :]
y = y.flatten()[:, None] * omega[None, :]
x = x.flatten()[:, None] * omega[None, :]
pe = torch.cat((x.sin(), x.cos(), y.sin(), y.cos(), z.sin(), z.cos()), dim = 1)
pe = F.pad(pe, (0, dim - (fourier_dim * 6))) # pad if feature dimension not cleanly divisible by 6
return pe.type(dtype)
# classes
class FeedForward(nn.Module):
def __init__(self, dim, hidden_dim):
super().__init__()
self.net = nn.Sequential(
nn.LayerNorm(dim),
nn.Linear(dim, hidden_dim),
nn.GELU(),
nn.Linear(hidden_dim, dim),
)
def forward(self, x):
return self.net(x)
class Attention(nn.Module):
def __init__(self, dim, heads = 8, dim_head = 64):
super().__init__()
inner_dim = dim_head * heads
self.heads = heads
self.scale = dim_head ** -0.5
self.norm = nn.LayerNorm(dim)
self.attend = nn.Softmax(dim = -1)
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
self.to_out = nn.Linear(inner_dim, dim, bias = False)
def forward(self, x):
x = self.norm(x)
qkv = self.to_qkv(x).chunk(3, dim = -1)
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)
dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
attn = self.attend(dots)
out = torch.matmul(attn, v)
out = rearrange(out, 'b h n d -> b n (h d)')
return self.to_out(out)
class Transformer(nn.Module):
def __init__(self, dim, depth, heads, dim_head, mlp_dim):
super().__init__()
self.layers = nn.ModuleList([])
for _ in range(depth):
self.layers.append(nn.ModuleList([
Attention(dim, heads = heads, dim_head = dim_head),
FeedForward(dim, mlp_dim)
]))
def forward(self, x):
for attn, ff in self.layers:
x = attn(x) + x
x = ff(x) + x
return x
class SimpleViT(nn.Module):
def __init__(self, *, image_size, image_patch_size, frames, frame_patch_size, num_classes, dim, depth, heads, mlp_dim, channels = 3, dim_head = 64):
super().__init__()
image_height, image_width = pair(image_size)
patch_height, patch_width = pair(image_patch_size)
assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'
assert frames % frame_patch_size == 0, 'Frames must be divisible by the frame patch size'
num_patches = (image_height // patch_height) * (image_width // patch_width) * (frames // frame_patch_size)
patch_dim = channels * patch_height * patch_width * frame_patch_size
self.to_patch_embedding = nn.Sequential(
Rearrange('b c (f pf) (h p1) (w p2) -> b f h w (p1 p2 pf c)', p1 = patch_height, p2 = patch_width, pf = frame_patch_size),
nn.Linear(patch_dim, dim),
)
self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim)
self.to_latent = nn.Identity()
self.linear_head = nn.Sequential(
nn.LayerNorm(dim),
nn.Linear(dim, num_classes)
)
def forward(self, video):
*_, h, w, dtype = *video.shape, video.dtype
x = self.to_patch_embedding(video)
pe = posemb_sincos_3d(x)
x = rearrange(x, 'b ... d -> b (...) d') + pe
x = self.transformer(x)
x = x.mean(dim = 1)
x = self.to_latent(x)
return self.linear_head(x)

129
vit_pytorch/vit_3d.py Normal file
View File

@@ -0,0 +1,129 @@
import torch
from torch import nn
from einops import rearrange, repeat
from einops.layers.torch import Rearrange
# helpers
def pair(t):
return t if isinstance(t, tuple) else (t, t)
# classes
class PreNorm(nn.Module):
def __init__(self, dim, fn):
super().__init__()
self.norm = nn.LayerNorm(dim)
self.fn = fn
def forward(self, x, **kwargs):
return self.fn(self.norm(x), **kwargs)
class FeedForward(nn.Module):
def __init__(self, dim, hidden_dim, dropout = 0.):
super().__init__()
self.net = nn.Sequential(
nn.Linear(dim, hidden_dim),
nn.GELU(),
nn.Dropout(dropout),
nn.Linear(hidden_dim, dim),
nn.Dropout(dropout)
)
def forward(self, x):
return self.net(x)
class Attention(nn.Module):
def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
super().__init__()
inner_dim = dim_head * heads
project_out = not (heads == 1 and dim_head == dim)
self.heads = heads
self.scale = dim_head ** -0.5
self.attend = nn.Softmax(dim = -1)
self.dropout = nn.Dropout(dropout)
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
self.to_out = nn.Sequential(
nn.Linear(inner_dim, dim),
nn.Dropout(dropout)
) if project_out else nn.Identity()
def forward(self, x):
qkv = self.to_qkv(x).chunk(3, dim = -1)
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)
dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
attn = self.attend(dots)
attn = self.dropout(attn)
out = torch.matmul(attn, v)
out = rearrange(out, 'b h n d -> b n (h d)')
return self.to_out(out)
class Transformer(nn.Module):
def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
super().__init__()
self.layers = nn.ModuleList([])
for _ in range(depth):
self.layers.append(nn.ModuleList([
PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)),
PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout))
]))
def forward(self, x):
for attn, ff in self.layers:
x = attn(x) + x
x = ff(x) + x
return x
class ViT(nn.Module):
def __init__(self, *, image_size, image_patch_size, frames, frame_patch_size, num_classes, dim, depth, heads, mlp_dim, pool = 'cls', channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0.):
super().__init__()
image_height, image_width = pair(image_size)
patch_height, patch_width = pair(image_patch_size)
assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'
assert frames % frame_patch_size == 0, 'Frames must be divisible by frame patch size'
num_patches = (image_height // patch_height) * (image_width // patch_width) * (frames // frame_patch_size)
patch_dim = channels * patch_height * patch_width * frame_patch_size
assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'
self.to_patch_embedding = nn.Sequential(
Rearrange('b c (f pf) (h p1) (w p2) -> b (f h w) (p1 p2 pf c)', p1 = patch_height, p2 = patch_width, pf = frame_patch_size),
nn.Linear(patch_dim, dim),
)
self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
self.dropout = nn.Dropout(emb_dropout)
self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)
self.pool = pool
self.to_latent = nn.Identity()
self.mlp_head = nn.Sequential(
nn.LayerNorm(dim),
nn.Linear(dim, num_classes)
)
def forward(self, video):
x = self.to_patch_embedding(video)
b, n, _ = x.shape
cls_tokens = repeat(self.cls_token, '1 1 d -> b 1 d', b = b)
x = torch.cat((cls_tokens, x), dim=1)
x += self.pos_embedding[:, :(n + 1)]
x = self.dropout(x)
x = self.transformer(x)
x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0]
x = self.to_latent(x)
return self.mlp_head(x)

183
vit_pytorch/vivit.py Normal file
View File

@@ -0,0 +1,183 @@
import torch
from torch import nn
from einops import rearrange, repeat, reduce
from einops.layers.torch import Rearrange
# helpers
def exists(val):
return val is not None
def pair(t):
return t if isinstance(t, tuple) else (t, t)
# classes
class PreNorm(nn.Module):
def __init__(self, dim, fn):
super().__init__()
self.norm = nn.LayerNorm(dim)
self.fn = fn
def forward(self, x, **kwargs):
return self.fn(self.norm(x), **kwargs)
class FeedForward(nn.Module):
def __init__(self, dim, hidden_dim, dropout = 0.):
super().__init__()
self.net = nn.Sequential(
nn.Linear(dim, hidden_dim),
nn.GELU(),
nn.Dropout(dropout),
nn.Linear(hidden_dim, dim),
nn.Dropout(dropout)
)
def forward(self, x):
return self.net(x)
class Attention(nn.Module):
def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
super().__init__()
inner_dim = dim_head * heads
project_out = not (heads == 1 and dim_head == dim)
self.heads = heads
self.scale = dim_head ** -0.5
self.attend = nn.Softmax(dim = -1)
self.dropout = nn.Dropout(dropout)
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
self.to_out = nn.Sequential(
nn.Linear(inner_dim, dim),
nn.Dropout(dropout)
) if project_out else nn.Identity()
def forward(self, x):
qkv = self.to_qkv(x).chunk(3, dim = -1)
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)
dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
attn = self.attend(dots)
attn = self.dropout(attn)
out = torch.matmul(attn, v)
out = rearrange(out, 'b h n d -> b n (h d)')
return self.to_out(out)
class Transformer(nn.Module):
def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
super().__init__()
self.layers = nn.ModuleList([])
for _ in range(depth):
self.layers.append(nn.ModuleList([
PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)),
PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout))
]))
def forward(self, x):
for attn, ff in self.layers:
x = attn(x) + x
x = ff(x) + x
return x
class ViT(nn.Module):
def __init__(
self,
*,
image_size,
image_patch_size,
frames,
frame_patch_size,
num_classes,
dim,
spatial_depth,
temporal_depth,
heads,
mlp_dim,
pool = 'cls',
channels = 3,
dim_head = 64,
dropout = 0.,
emb_dropout = 0.
):
super().__init__()
image_height, image_width = pair(image_size)
patch_height, patch_width = pair(image_patch_size)
assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'
assert frames % frame_patch_size == 0, 'Frames must be divisible by frame patch size'
num_image_patches = (image_height // patch_height) * (image_width // patch_width)
num_frame_patches = (frames // frame_patch_size)
patch_dim = channels * patch_height * patch_width * frame_patch_size
assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'
self.global_average_pool = pool == 'mean'
self.to_patch_embedding = nn.Sequential(
Rearrange('b c (f pf) (h p1) (w p2) -> b f (h w) (p1 p2 pf c)', p1 = patch_height, p2 = patch_width, pf = frame_patch_size),
nn.Linear(patch_dim, dim),
)
self.pos_embedding = nn.Parameter(torch.randn(1, num_frame_patches, num_image_patches, dim))
self.dropout = nn.Dropout(emb_dropout)
self.spatial_cls_token = nn.Parameter(torch.randn(1, 1, dim)) if not self.global_average_pool else None
self.temporal_cls_token = nn.Parameter(torch.randn(1, 1, dim)) if not self.global_average_pool else None
self.spatial_transformer = Transformer(dim, spatial_depth, heads, dim_head, mlp_dim, dropout)
self.temporal_transformer = Transformer(dim, temporal_depth, heads, dim_head, mlp_dim, dropout)
self.pool = pool
self.to_latent = nn.Identity()
self.mlp_head = nn.Sequential(
nn.LayerNorm(dim),
nn.Linear(dim, num_classes)
)
def forward(self, video):
x = self.to_patch_embedding(video)
b, f, n, _ = x.shape
x = x + self.pos_embedding
if exists(self.spatial_cls_token):
spatial_cls_tokens = repeat(self.spatial_cls_token, '1 1 d -> b f 1 d', b = b, f = f)
x = torch.cat((spatial_cls_tokens, x), dim = 2)
x = self.dropout(x)
x = rearrange(x, 'b f n d -> (b f) n d')
# attend across space
x = self.spatial_transformer(x)
x = rearrange(x, '(b f) n d -> b f n d', b = b)
# excise out the spatial cls tokens or average pool for temporal attention
x = x[:, :, 0] if not self.global_average_pool else reduce(x, 'b f n d -> b f d', 'mean')
# append temporal CLS tokens
if exists(self.temporal_cls_token):
temporal_cls_tokens = repeat(self.temporal_cls_token, '1 1 d-> b 1 d', b = b)
x = torch.cat((temporal_cls_tokens, x), dim = 1)
# attend across time
x = self.temporal_transformer(x)
# excise out temporal cls token or average pool
x = x[:, 0] if not self.global_average_pool else reduce(x, 'b f d -> b d', 'mean')
x = self.to_latent(x)
return self.mlp_head(x)