make sure global average pool can be used for vivit in place of cls token

Phil Wang
2022-10-24 19:59:48 -07:00
parent 13fabf901e
commit 6ec8fdaa6d
4 changed files with 35 additions and 21 deletions
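In practice, the new option is exercised by constructing the ViViT model with pool = 'mean'. A minimal usage sketch, assuming the model lives at vit_pytorch.vivit and takes the constructor arguments shown in the project README; treat the exact signature as an assumption:

import torch
from vit_pytorch.vivit import ViT  # assumed module path

v = ViT(
    image_size = 128,          # image size
    frames = 16,               # number of frames
    image_patch_size = 16,     # spatial patch size
    frame_patch_size = 2,      # frame patch size
    num_classes = 1000,
    dim = 1024,
    spatial_depth = 6,         # depth of the spatial transformer
    temporal_depth = 6,        # depth of the temporal transformer
    heads = 8,
    mlp_dim = 2048,
    pool = 'mean'              # the new option: global average pool in place of cls tokens
)

video = torch.randn(4, 3, 16, 128, 128)  # (batch, channels, frames, height, width)
preds = v(video)                          # (4, 1000)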

View File

@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
 setup(
   name = 'vit-pytorch',
   packages = find_packages(exclude=['examples']),
-  version = '0.37.0',
+  version = '0.37.1',
   license='MIT',
   description = 'Vision Transformer (ViT) - Pytorch',
   long_description_content_type = 'text/markdown',
View File

@@ -114,10 +114,10 @@ class SimpleViT(nn.Module):
             nn.Linear(dim, num_classes)
         )

-    def forward(self, img):
-        *_, h, w, dtype = *img.shape, img.dtype
+    def forward(self, video):
+        *_, h, w, dtype = *video.shape, video.dtype

-        x = self.to_patch_embedding(img)
+        x = self.to_patch_embedding(video)
         pe = posemb_sincos_3d(x)
         x = rearrange(x, 'b ... d -> b (...) d') + pe
View File

@@ -112,8 +112,8 @@ class ViT(nn.Module):
             nn.Linear(dim, num_classes)
         )

-    def forward(self, img):
-        x = self.to_patch_embedding(img)
+    def forward(self, video):
+        x = self.to_patch_embedding(video)
         b, n, _ = x.shape

         cls_tokens = repeat(self.cls_token, '1 1 d -> b 1 d', b = b)
View File

@@ -1,11 +1,14 @@
 import torch
 from torch import nn

-from einops import rearrange, repeat
+from einops import rearrange, repeat, reduce
 from einops.layers.torch import Rearrange

 # helpers

+def exists(val):
+    return val is not None
+
 def pair(t):
     return t if isinstance(t, tuple) else (t, t)
@@ -106,20 +109,25 @@ class ViT(nn.Module):
         assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'
         assert frames % frame_patch_size == 0, 'Frames must be divisible by frame patch size'

-        num_patches = (image_height // patch_height) * (image_width // patch_width) * (frames // frame_patch_size)
+        num_image_patches = (image_height // patch_height) * (image_width // patch_width)
+        num_frame_patches = (frames // frame_patch_size)

         patch_dim = channels * patch_height * patch_width * frame_patch_size

         assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'

+        self.global_average_pool = pool == 'mean'
+
         self.to_patch_embedding = nn.Sequential(
             Rearrange('b c (f pf) (h p1) (w p2) -> b f (h w) (p1 p2 pf c)', p1 = patch_height, p2 = patch_width, pf = frame_patch_size),
             nn.Linear(patch_dim, dim),
         )

-        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
+        self.pos_embedding = nn.Parameter(torch.randn(1, num_frame_patches, num_image_patches, dim))
         self.dropout = nn.Dropout(emb_dropout)

-        self.spatial_cls_token = nn.Parameter(torch.randn(1, 1, dim))
-        self.temporal_cls_token = nn.Parameter(torch.randn(1, 1, dim))
+        self.spatial_cls_token = nn.Parameter(torch.randn(1, 1, dim)) if not self.global_average_pool else None
+        self.temporal_cls_token = nn.Parameter(torch.randn(1, 1, dim)) if not self.global_average_pool else None

         self.spatial_transformer = Transformer(dim, spatial_depth, heads, dim_head, mlp_dim, dropout)
         self.temporal_transformer = Transformer(dim, temporal_depth, heads, dim_head, mlp_dim, dropout)
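When pool = 'mean', the cls tokens are assigned None rather than created, so they contribute no parameters, and downstream code guards on exists(...). A standalone sketch of this pattern with a hypothetical toy module:

import torch
from torch import nn

def exists(val):
    return val is not None

class Toy(nn.Module):
    def __init__(self, dim, use_cls = True):
        super().__init__()
        # the parameter is only registered when the cls-token path is active
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim)) if use_cls else None

print(sum(p.numel() for p in Toy(8, use_cls = True).parameters()))   # 8
print(sum(p.numel() for p in Toy(8, use_cls = False).parameters()))  # 0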
@@ -132,13 +140,16 @@ class ViT(nn.Module):
             nn.Linear(dim, num_classes)
         )

-    def forward(self, img):
-        x = self.to_patch_embedding(img)
+    def forward(self, video):
+        x = self.to_patch_embedding(video)
         b, f, n, _ = x.shape

-        spatial_cls_tokens = repeat(self.spatial_cls_token, '1 1 d -> b f 1 d', b = b, f = f)
-        x = torch.cat((spatial_cls_tokens, x), dim = 2)
-        x += self.pos_embedding[:, :(n + 1)]
+        x = x + self.pos_embedding
+
+        if exists(self.spatial_cls_token):
+            spatial_cls_tokens = repeat(self.spatial_cls_token, '1 1 d -> b f 1 d', b = b, f = f)
+            x = torch.cat((spatial_cls_tokens, x), dim = 2)

         x = self.dropout(x)

         x = rearrange(x, 'b f n d -> (b f) n d')
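The positional embedding no longer reserves a slot for the cls token: it is added before the (now optional) cls tokens are concatenated, so the old slicing by (n + 1) goes away and the factorized (1, frames, patches, dim) shape simply broadcasts over the batch. A standalone shape sketch with made-up dimensions:

import torch

b, f, n, d = 2, 8, 64, 1024    # batch, frame patches, image patches, dim
x = torch.randn(b, f, n, d)    # patch tokens
pos = torch.randn(1, f, n, d)  # factorized positional embedding
print((x + pos).shape)         # torch.Size([2, 8, 64, 1024]) -- broadcasts over batch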
@@ -149,21 +160,24 @@ class ViT(nn.Module):
         x = rearrange(x, '(b f) n d -> b f n d', b = b)

-        # excise out the spatial cls tokens for temporal attention
+        # excise out the spatial cls tokens or average pool for temporal attention

-        x = x[:, :, 0]
+        x = x[:, :, 0] if not self.global_average_pool else reduce(x, 'b f n d -> b f d', 'mean')

         # append temporal CLS tokens

-        temporal_cls_tokens = repeat(self.temporal_cls_token, '1 1 d-> b 1 d', b = b)
-
-        x = torch.cat((temporal_cls_tokens, x), dim = 1)
+        if exists(self.temporal_cls_token):
+            temporal_cls_tokens = repeat(self.temporal_cls_token, '1 1 d-> b 1 d', b = b)
+
+            x = torch.cat((temporal_cls_tokens, x), dim = 1)

         # attend across time

         x = self.temporal_transformer(x)

-        x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0]
+        # excise out temporal cls token or average pool
+
+        x = x[:, 0] if not self.global_average_pool else reduce(x, 'b f d -> b d', 'mean')

         x = self.to_latent(x)
         return self.mlp_head(x)
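For reference, the two new pooling branches are plain einops reductions: spatial tokens are averaged per frame, then frame tokens are averaged per video. A standalone sketch with made-up dimensions:

import torch
from einops import reduce

b, f, n, d = 2, 8, 64, 1024
x = torch.randn(b, f, n, d)

x = reduce(x, 'b f n d -> b f d', 'mean')  # average over spatial tokens per frame
x = reduce(x, 'b f d -> b d', 'mean')      # average over frames
print(x.shape)                             # torch.Size([2, 1024])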