diff --git a/README.md b/README.md
index cd47750..390bfe4 100644
--- a/README.md
+++ b/README.md
@@ -1218,7 +1218,8 @@ pred = cct(video)
-This paper offers 3 different types of architectures for efficient attention of videos, with the main theme being factorizing the attention across space and time. This repository will offer the first variant, which is a spatial transformer followed by a temporal one.
+This paper offers 3 different types of architectures for efficient attention of videos, with the main theme being factorizing the attention across space and time. This repository includes the factorized encoder and the factorized self-attention variant.
+The factorized encoder variant is a spatial transformer followed by a temporal one. The factorized self-attention variant is a spatio-temporal transformer with alternating spatial and temporal self-attention layers.
 
 ```python
 import torch
@@ -1234,7 +1235,8 @@ v = ViT(
     spatial_depth = 6,         # depth of the spatial transformer
     temporal_depth = 6,        # depth of the temporal transformer
     heads = 8,
-    mlp_dim = 2048
+    mlp_dim = 2048,
+    variant = 'factorized_encoder', # or 'factorized_self_attention'
 )
 
 video = torch.randn(4, 3, 16, 128, 128) # (batch, channels, frames, height, width)
diff --git a/vit_pytorch/vivit.py b/vit_pytorch/vivit.py
index b95afdc..43ad0d3 100644
--- a/vit_pytorch/vivit.py
+++ b/vit_pytorch/vivit.py
@@ -78,6 +78,30 @@ class Transformer(nn.Module):
             x = ff(x) + x
         return self.norm(x)
 
+class FactorizedTransformer(nn.Module):
+    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
+        super().__init__()
+        self.norm = nn.LayerNorm(dim)
+        self.layers = nn.ModuleList([])
+        for _ in range(depth):
+            self.layers.append(nn.ModuleList([
+                Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout),
+                Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout),
+                FeedForward(dim, mlp_dim, dropout = dropout)
+            ]))
+
+    def forward(self, x):
+        b, f, n, _ = x.shape
+        for spatial_attn, temporal_attn, ff in self.layers:
+            x = rearrange(x, 'b f n d -> (b f) n d')
+            x = spatial_attn(x) + x
+            x = rearrange(x, '(b f) n d -> (b n) f d', b=b, f=f)
+            x = temporal_attn(x) + x
+            x = ff(x) + x
+            x = rearrange(x, '(b n) f d -> b f n d', b=b, n=n)
+
+        return self.norm(x)
+
 class ViT(nn.Module):
     def __init__(
         self,
@@ -96,7 +120,8 @@ class ViT(nn.Module):
         channels = 3,
         dim_head = 64,
         dropout = 0.,
-        emb_dropout = 0.
+        emb_dropout = 0.,
+        variant = 'factorized_encoder',
     ):
         super().__init__()
         image_height, image_width = pair(image_size)
@@ -104,6 +129,7 @@ class ViT(nn.Module):
 
         assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'
         assert frames % frame_patch_size == 0, 'Frames must be divisible by frame patch size'
+        assert variant in ('factorized_encoder', 'factorized_self_attention'), f'variant = {variant} is not implemented'
 
         num_image_patches = (image_height // patch_height) * (image_width // patch_width)
         num_frame_patches = (frames // frame_patch_size)
@@ -125,15 +151,20 @@ class ViT(nn.Module):
         self.dropout = nn.Dropout(emb_dropout)
 
         self.spatial_cls_token = nn.Parameter(torch.randn(1, 1, dim)) if not self.global_average_pool else None
-        self.temporal_cls_token = nn.Parameter(torch.randn(1, 1, dim)) if not self.global_average_pool else None
 
-        self.spatial_transformer = Transformer(dim, spatial_depth, heads, dim_head, mlp_dim, dropout)
-        self.temporal_transformer = Transformer(dim, temporal_depth, heads, dim_head, mlp_dim, dropout)
+        if variant == 'factorized_encoder':
+            self.temporal_cls_token = nn.Parameter(torch.randn(1, 1, dim)) if not self.global_average_pool else None
+            self.spatial_transformer = Transformer(dim, spatial_depth, heads, dim_head, mlp_dim, dropout)
+            self.temporal_transformer = Transformer(dim, temporal_depth, heads, dim_head, mlp_dim, dropout)
+        elif variant == 'factorized_self_attention':
+            assert spatial_depth == temporal_depth, 'Spatial and temporal depth must be the same for factorized self-attention'
+            self.factorized_transformer = FactorizedTransformer(dim, spatial_depth, heads, dim_head, mlp_dim, dropout)
 
         self.pool = pool
         self.to_latent = nn.Identity()
 
         self.mlp_head = nn.Linear(dim, num_classes)
+        self.variant = variant
 
     def forward(self, video):
         x = self.to_patch_embedding(video)
@@ -147,32 +178,37 @@ class ViT(nn.Module):
 
         x = self.dropout(x)
 
-        x = rearrange(x, 'b f n d -> (b f) n d')
+        if self.variant == 'factorized_encoder':
+            x = rearrange(x, 'b f n d -> (b f) n d')
 
-        # attend across space
+            # attend across space
 
-        x = self.spatial_transformer(x)
+            x = self.spatial_transformer(x)
+            x = rearrange(x, '(b f) n d -> b f n d', b = b)
 
-        x = rearrange(x, '(b f) n d -> b f n d', b = b)
+            # excise out the spatial cls tokens or average pool for temporal attention
 
-        # excise out the spatial cls tokens or average pool for temporal attention
+            x = x[:, :, 0] if not self.global_average_pool else reduce(x, 'b f n d -> b f d', 'mean')
 
-        x = x[:, :, 0] if not self.global_average_pool else reduce(x, 'b f n d -> b f d', 'mean')
+            # append temporal CLS tokens
 
-        # append temporal CLS tokens
+            if exists(self.temporal_cls_token):
+                temporal_cls_tokens = repeat(self.temporal_cls_token, '1 1 d-> b 1 d', b = b)
 
-        if exists(self.temporal_cls_token):
-            temporal_cls_tokens = repeat(self.temporal_cls_token, '1 1 d-> b 1 d', b = b)
+                x = torch.cat((temporal_cls_tokens, x), dim = 1)
+
 
-            x = torch.cat((temporal_cls_tokens, x), dim = 1)
+            # attend across time
 
-        # attend across time
+            x = self.temporal_transformer(x)
 
-        x = self.temporal_transformer(x)
+            # excise out temporal cls token or average pool
 
-        # excise out temporal cls token or average pool
+            x = x[:, 0] if not self.global_average_pool else reduce(x, 'b f d -> b d', 'mean')
 
-        x = x[:, 0] if not self.global_average_pool else reduce(x, 'b f d -> b d', 'mean')
+        elif self.variant == 'factorized_self_attention':
+            x = self.factorized_transformer(x)
+            x = x[:, 0, 0] if not self.global_average_pool else reduce(x, 'b f n d -> b d', 'mean')
 
         x = self.to_latent(x)
         return self.mlp_head(x)
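A usage sketch to go with the patch: the snippet below runs the new `factorized_self_attention` variant end to end. Only the depths, `heads`, `mlp_dim`, the `variant` flag, and the `(4, 3, 16, 128, 128)` example video appear in the hunks above; the remaining constructor values (`image_size`, `image_patch_size`, `frame_patch_size`, `num_classes`, `dim`) are illustrative assumptions chosen to be compatible with that video shape, not values taken from this diff.

```python
import torch
from vit_pytorch.vivit import ViT

v = ViT(
    image_size = 128,          # matches the 128 x 128 frames of the example video (assumed value)
    image_patch_size = 16,     # spatial patch size (assumed value)
    frames = 16,               # matches the 16 frames of the example video
    frame_patch_size = 2,      # temporal patch size (assumed value)
    num_classes = 1000,        # assumed number of classes
    dim = 1024,                # assumed model dimension
    spatial_depth = 6,
    temporal_depth = 6,        # must equal spatial_depth for this variant (asserted in __init__)
    heads = 8,
    mlp_dim = 2048,
    variant = 'factorized_self_attention'
)

video = torch.randn(4, 3, 16, 128, 128) # (batch, channels, frames, height, width)

preds = v(video) # (4, 1000) class logits
```

Because `FactorizedTransformer` interleaves spatial and temporal self-attention inside every block, `spatial_depth` sets the total number of blocks for this variant, and `temporal_depth` is only checked against it by the new assert.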