mirror of
https://github.com/lucidrains/vit-pytorch.git
synced 2025-12-30 08:02:29 +00:00
Add ViViT variant with factorized self-attention (#327)
* Add FactorizedTransformer * Add variant param and check in fwd method * Check if variant is implemented * Describe new ViViT variant
This commit is contained in:
@@ -1218,7 +1218,8 @@ pred = cct(video)
|
||||
|
||||
<img src="./images/vivit.png" width="350px"></img>
|
||||
|
||||
This <a href="https://arxiv.org/abs/2103.15691">paper</a> offers 3 different types of architectures for efficient attention of videos, with the main theme being factorizing the attention across space and time. This repository will offer the first variant, which is a spatial transformer followed by a temporal one.
|
||||
This <a href="https://arxiv.org/abs/2103.15691">paper</a> offers 3 different types of architectures for efficient attention of videos, with the main theme being factorizing the attention across space and time. This repository includes the factorized encoder and the factorized self-attention variant.
|
||||
The factorized encoder variant is a spatial transformer followed by a temporal one. The factorized self-attention variant is a spatio-temporal transformer with alternating spatial and temporal self-attention layers.
|
||||
|
||||
```python
|
||||
import torch
|
||||
@@ -1234,7 +1235,8 @@ v = ViT(
|
||||
spatial_depth = 6, # depth of the spatial transformer
|
||||
temporal_depth = 6, # depth of the temporal transformer
|
||||
heads = 8,
|
||||
mlp_dim = 2048
|
||||
mlp_dim = 2048,
|
||||
variant = 'factorized_encoder', # or 'factorized_self_attention'
|
||||
)
|
||||
|
||||
video = torch.randn(4, 3, 16, 128, 128) # (batch, channels, frames, height, width)
|
||||
|
||||
@@ -78,6 +78,30 @@ class Transformer(nn.Module):
|
||||
x = ff(x) + x
|
||||
return self.norm(x)
|
||||
|
||||
class FactorizedTransformer(nn.Module):
|
||||
def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
|
||||
super().__init__()
|
||||
self.norm = nn.LayerNorm(dim)
|
||||
self.layers = nn.ModuleList([])
|
||||
for _ in range(depth):
|
||||
self.layers.append(nn.ModuleList([
|
||||
Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout),
|
||||
Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout),
|
||||
FeedForward(dim, mlp_dim, dropout = dropout)
|
||||
]))
|
||||
|
||||
def forward(self, x):
|
||||
b, f, n, _ = x.shape
|
||||
for spatial_attn, temporal_attn, ff in self.layers:
|
||||
x = rearrange(x, 'b f n d -> (b f) n d')
|
||||
x = spatial_attn(x) + x
|
||||
x = rearrange(x, '(b f) n d -> (b n) f d', b=b, f=f)
|
||||
x = temporal_attn(x) + x
|
||||
x = ff(x) + x
|
||||
x = rearrange(x, '(b n) f d -> b f n d', b=b, n=n)
|
||||
|
||||
return self.norm(x)
|
||||
|
||||
class ViT(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
@@ -96,7 +120,8 @@ class ViT(nn.Module):
|
||||
channels = 3,
|
||||
dim_head = 64,
|
||||
dropout = 0.,
|
||||
emb_dropout = 0.
|
||||
emb_dropout = 0.,
|
||||
variant = 'factorized_encoder',
|
||||
):
|
||||
super().__init__()
|
||||
image_height, image_width = pair(image_size)
|
||||
@@ -104,6 +129,7 @@ class ViT(nn.Module):
|
||||
|
||||
assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'
|
||||
assert frames % frame_patch_size == 0, 'Frames must be divisible by frame patch size'
|
||||
assert variant in ('factorized_encoder', 'factorized_self_attention'), f'variant = {variant} is not implemented'
|
||||
|
||||
num_image_patches = (image_height // patch_height) * (image_width // patch_width)
|
||||
num_frame_patches = (frames // frame_patch_size)
|
||||
@@ -125,15 +151,20 @@ class ViT(nn.Module):
|
||||
self.dropout = nn.Dropout(emb_dropout)
|
||||
|
||||
self.spatial_cls_token = nn.Parameter(torch.randn(1, 1, dim)) if not self.global_average_pool else None
|
||||
self.temporal_cls_token = nn.Parameter(torch.randn(1, 1, dim)) if not self.global_average_pool else None
|
||||
|
||||
self.spatial_transformer = Transformer(dim, spatial_depth, heads, dim_head, mlp_dim, dropout)
|
||||
self.temporal_transformer = Transformer(dim, temporal_depth, heads, dim_head, mlp_dim, dropout)
|
||||
if variant == 'factorized_encoder':
|
||||
self.temporal_cls_token = nn.Parameter(torch.randn(1, 1, dim)) if not self.global_average_pool else None
|
||||
self.spatial_transformer = Transformer(dim, spatial_depth, heads, dim_head, mlp_dim, dropout)
|
||||
self.temporal_transformer = Transformer(dim, temporal_depth, heads, dim_head, mlp_dim, dropout)
|
||||
elif variant == 'factorized_self_attention':
|
||||
assert spatial_depth == temporal_depth, 'Spatial and temporal depth must be the same for factorized self-attention'
|
||||
self.factorized_transformer = FactorizedTransformer(dim, spatial_depth, heads, dim_head, mlp_dim, dropout)
|
||||
|
||||
self.pool = pool
|
||||
self.to_latent = nn.Identity()
|
||||
|
||||
self.mlp_head = nn.Linear(dim, num_classes)
|
||||
self.variant = variant
|
||||
|
||||
def forward(self, video):
|
||||
x = self.to_patch_embedding(video)
|
||||
@@ -147,32 +178,37 @@ class ViT(nn.Module):
|
||||
|
||||
x = self.dropout(x)
|
||||
|
||||
x = rearrange(x, 'b f n d -> (b f) n d')
|
||||
if self.variant == 'factorized_encoder':
|
||||
x = rearrange(x, 'b f n d -> (b f) n d')
|
||||
|
||||
# attend across space
|
||||
# attend across space
|
||||
|
||||
x = self.spatial_transformer(x)
|
||||
x = self.spatial_transformer(x)
|
||||
x = rearrange(x, '(b f) n d -> b f n d', b = b)
|
||||
|
||||
x = rearrange(x, '(b f) n d -> b f n d', b = b)
|
||||
# excise out the spatial cls tokens or average pool for temporal attention
|
||||
|
||||
# excise out the spatial cls tokens or average pool for temporal attention
|
||||
x = x[:, :, 0] if not self.global_average_pool else reduce(x, 'b f n d -> b f d', 'mean')
|
||||
|
||||
x = x[:, :, 0] if not self.global_average_pool else reduce(x, 'b f n d -> b f d', 'mean')
|
||||
# append temporal CLS tokens
|
||||
|
||||
# append temporal CLS tokens
|
||||
if exists(self.temporal_cls_token):
|
||||
temporal_cls_tokens = repeat(self.temporal_cls_token, '1 1 d-> b 1 d', b = b)
|
||||
|
||||
if exists(self.temporal_cls_token):
|
||||
temporal_cls_tokens = repeat(self.temporal_cls_token, '1 1 d-> b 1 d', b = b)
|
||||
x = torch.cat((temporal_cls_tokens, x), dim = 1)
|
||||
|
||||
|
||||
x = torch.cat((temporal_cls_tokens, x), dim = 1)
|
||||
# attend across time
|
||||
|
||||
# attend across time
|
||||
x = self.temporal_transformer(x)
|
||||
|
||||
x = self.temporal_transformer(x)
|
||||
# excise out temporal cls token or average pool
|
||||
|
||||
# excise out temporal cls token or average pool
|
||||
x = x[:, 0] if not self.global_average_pool else reduce(x, 'b f d -> b d', 'mean')
|
||||
|
||||
x = x[:, 0] if not self.global_average_pool else reduce(x, 'b f d -> b d', 'mean')
|
||||
elif self.variant == 'factorized_self_attention':
|
||||
x = self.factorized_transformer(x)
|
||||
x = x[:, 0, 0] if not self.global_average_pool else reduce(x, 'b f n d -> b d', 'mean')
|
||||
|
||||
x = self.to_latent(x)
|
||||
return self.mlp_head(x)
|
||||
|
||||
Reference in New Issue
Block a user