diff --git a/README.md b/README.md index 1de5f85..65dfcd0 100644 --- a/README.md +++ b/README.md @@ -1023,6 +1023,35 @@ video = torch.randn(4, 3, 16, 128, 128) # (batch, channels, frames, height, widt preds = v(video) # (4, 1000) ``` +3D version of CCT + +```python +import torch +from vit_pytorch.cct_3d import CCT + +cct = CCT( + img_size = 224, + num_frames = 8, + embedding_dim = 384, + n_conv_layers = 2, + frame_kernel_size = 3, + kernel_size = 7, + stride = 2, + padding = 3, + pooling_kernel_size = 3, + pooling_stride = 2, + pooling_padding = 1, + num_layers = 14, + num_heads = 6, + mlp_radio = 3., + num_classes = 1000, + positional_embedding = 'learnable' +) + +video = torch.randn(1, 3, 8, 224, 224) # (batch, channels, frames, height, width) +pred = cct(video) +``` + ## ViViT diff --git a/setup.py b/setup.py index f3dc68f..0877f90 100644 --- a/setup.py +++ b/setup.py @@ -3,7 +3,7 @@ from setuptools import setup, find_packages setup( name = 'vit-pytorch', packages = find_packages(exclude=['examples']), - version = '0.37.1', + version = '0.38.1', license='MIT', description = 'Vision Transformer (ViT) - Pytorch', long_description_content_type = 'text/markdown', diff --git a/vit_pytorch/cct.py b/vit_pytorch/cct.py index 643b180..4b37699 100644 --- a/vit_pytorch/cct.py +++ b/vit_pytorch/cct.py @@ -1,9 +1,17 @@ import torch -import torch.nn as nn +from torch import nn, einsum import torch.nn.functional as F +from einops import rearrange, repeat + # helpers +def exists(val): + return val is not None + +def default(val, d): + return val if exists(val) else d + def pair(t): return t if isinstance(t, tuple) else (t, t) @@ -50,8 +58,9 @@ def cct_16(*args, **kwargs): def _cct(num_layers, num_heads, mlp_ratio, embedding_dim, kernel_size=3, stride=None, padding=None, *args, **kwargs): - stride = stride if stride is not None else max(1, (kernel_size // 2) - 1) - padding = padding if padding is not None else max(1, (kernel_size // 2)) + stride = default(stride, max(1, (kernel_size // 2) - 1)) + padding = default(padding, max(1, (kernel_size // 2))) + return CCT(num_layers=num_layers, num_heads=num_heads, mlp_ratio=mlp_ratio, @@ -61,13 +70,22 @@ def _cct(num_layers, num_heads, mlp_ratio, embedding_dim, padding=padding, *args, **kwargs) +# positional + +def sinusoidal_embedding(n_channels, dim): + pe = torch.FloatTensor([[p / (10000 ** (2 * (i // 2) / dim)) for i in range(dim)] + for p in range(n_channels)]) + pe[:, 0::2] = torch.sin(pe[:, 0::2]) + pe[:, 1::2] = torch.cos(pe[:, 1::2]) + return rearrange(pe, '... 
-> 1 ...') + # modules class Attention(nn.Module): def __init__(self, dim, num_heads=8, attention_dropout=0.1, projection_dropout=0.1): super().__init__() - self.num_heads = num_heads - head_dim = dim // self.num_heads + self.heads = num_heads + head_dim = dim // self.heads self.scale = head_dim ** -0.5 self.qkv = nn.Linear(dim, dim * 3, bias=False) @@ -77,17 +95,20 @@ class Attention(nn.Module): def forward(self, x): B, N, C = x.shape - qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) - q, k, v = qkv[0], qkv[1], qkv[2] - attn = (q @ k.transpose(-2, -1)) * self.scale + qkv = self.qkv(x).chunk(3, dim = -1) + q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv) + + q = q * self.scale + + attn = einsum('b h i d, b h j d -> b h i j', q, k) attn = attn.softmax(dim=-1) attn = self.attn_drop(attn) - x = (attn @ v).transpose(1, 2).reshape(B, N, C) - x = self.proj(x) - x = self.proj_drop(x) - return x + x = einsum('b h i j, b h j d -> b h i d', attn, v) + x = rearrange(x, 'b h n d -> b n (h d)') + + return self.proj_drop(self.proj(x)) class TransformerEncoderLayer(nn.Module): @@ -97,7 +118,8 @@ class TransformerEncoderLayer(nn.Module): """ def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, attention_dropout=0.1, drop_path_rate=0.1): - super(TransformerEncoderLayer, self).__init__() + super().__init__() + self.pre_norm = nn.LayerNorm(d_model) self.self_attn = Attention(dim=d_model, num_heads=nhead, attention_dropout=attention_dropout, projection_dropout=dropout) @@ -108,50 +130,34 @@ class TransformerEncoderLayer(nn.Module): self.linear2 = nn.Linear(dim_feedforward, d_model) self.dropout2 = nn.Dropout(dropout) - self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0 else nn.Identity() + self.drop_path = DropPath(drop_path_rate) self.activation = F.gelu - def forward(self, src: torch.Tensor, *args, **kwargs) -> torch.Tensor: + def forward(self, src, *args, **kwargs): src = src + self.drop_path(self.self_attn(self.pre_norm(src))) src = self.norm1(src) src2 = self.linear2(self.dropout1(self.activation(self.linear1(src)))) src = src + self.drop_path(self.dropout2(src2)) return src - -def drop_path(x, drop_prob: float = 0., training: bool = False): - """ - Obtained from: github.com:rwightman/pytorch-image-models - Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). - This is the same as the DropConnect impl I created for EfficientNet, etc networks, however, - the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper... - See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for - changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use - 'survival rate' as the argument. - """ - if drop_prob == 0. or not training: - return x - keep_prob = 1 - drop_prob - shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets - random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device) - random_tensor.floor_() # binarize - output = x.div(keep_prob) * random_tensor - return output - - class DropPath(nn.Module): - """ - Obtained from: github.com:rwightman/pytorch-image-models - Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). 
- """ def __init__(self, drop_prob=None): - super(DropPath, self).__init__() - self.drop_prob = drop_prob + super().__init__() + self.drop_prob = float(drop_prob) def forward(self, x): - return drop_path(x, self.drop_prob, self.training) + batch, drop_prob, device, dtype = x.shape[0], self.drop_prob, x.device, x.dtype + if drop_prob <= 0. or not self.training: + return x + + keep_prob = 1 - self.drop_prob + shape = (batch, *((1,) * (x.ndim - 1))) + + keep_mask = torch.zeros(shape, device = device).float().uniform_(0, 1) < keep_prob + output = x.div(keep_prob) * keep_mask.float() + return output class Tokenizer(nn.Module): def __init__(self, @@ -164,34 +170,35 @@ class Tokenizer(nn.Module): activation=None, max_pool=True, conv_bias=False): - super(Tokenizer, self).__init__() + super().__init__() n_filter_list = [n_input_channels] + \ [in_planes for _ in range(n_conv_layers - 1)] + \ [n_output_channels] + n_filter_list_pairs = zip(n_filter_list[:-1], n_filter_list[1:]) + self.conv_layers = nn.Sequential( *[nn.Sequential( - nn.Conv2d(n_filter_list[i], n_filter_list[i + 1], + nn.Conv2d(chan_in, chan_out, kernel_size=(kernel_size, kernel_size), stride=(stride, stride), padding=(padding, padding), bias=conv_bias), - nn.Identity() if activation is None else activation(), + nn.Identity() if not exists(activation) else activation(), nn.MaxPool2d(kernel_size=pooling_kernel_size, stride=pooling_stride, padding=pooling_padding) if max_pool else nn.Identity() ) - for i in range(n_conv_layers) + for chan_in, chan_out in n_filter_list_pairs ]) - self.flattener = nn.Flatten(2, 3) self.apply(self.init_weight) def sequence_length(self, n_channels=3, height=224, width=224): return self.forward(torch.zeros((1, n_channels, height, width))).shape[1] def forward(self, x): - return self.flattener(self.conv_layers(x)).transpose(-2, -1) + return rearrange(self.conv_layers(x), 'b c h w -> b (h w) c') @staticmethod def init_weight(m): @@ -214,106 +221,104 @@ class TransformerClassifier(nn.Module): sequence_length=None, *args, **kwargs): super().__init__() - positional_embedding = positional_embedding if \ - positional_embedding in ['sine', 'learnable', 'none'] else 'sine' + assert positional_embedding in {'sine', 'learnable', 'none'} + dim_feedforward = int(embedding_dim * mlp_ratio) self.embedding_dim = embedding_dim self.sequence_length = sequence_length self.seq_pool = seq_pool - assert sequence_length is not None or positional_embedding == 'none', \ + assert exists(sequence_length) or positional_embedding == 'none', \ f"Positional embedding is set to {positional_embedding} and" \ f" the sequence length was not specified." 
if not seq_pool: sequence_length += 1 - self.class_emb = nn.Parameter(torch.zeros(1, 1, self.embedding_dim), - requires_grad=True) + self.class_emb = nn.Parameter(torch.zeros(1, 1, self.embedding_dim), requires_grad=True) else: self.attention_pool = nn.Linear(self.embedding_dim, 1) - if positional_embedding != 'none': - if positional_embedding == 'learnable': - self.positional_emb = nn.Parameter(torch.zeros(1, sequence_length, embedding_dim), - requires_grad=True) - nn.init.trunc_normal_(self.positional_emb, std=0.2) - else: - self.positional_emb = nn.Parameter(self.sinusoidal_embedding(sequence_length, embedding_dim), - requires_grad=False) - else: + if positional_embedding == 'none': self.positional_emb = None + elif positional_embedding == 'learnable': + self.positional_emb = nn.Parameter(torch.zeros(1, sequence_length, embedding_dim), + requires_grad=True) + nn.init.trunc_normal_(self.positional_emb, std=0.2) + else: + self.positional_emb = nn.Parameter(sinusoidal_embedding(sequence_length, embedding_dim), + requires_grad=False) self.dropout = nn.Dropout(p=dropout_rate) + dpr = [x.item() for x in torch.linspace(0, stochastic_depth_rate, num_layers)] + self.blocks = nn.ModuleList([ TransformerEncoderLayer(d_model=embedding_dim, nhead=num_heads, dim_feedforward=dim_feedforward, dropout=dropout_rate, - attention_dropout=attention_dropout, drop_path_rate=dpr[i]) - for i in range(num_layers)]) + attention_dropout=attention_dropout, drop_path_rate=layer_dpr) + for layer_dpr in dpr]) + self.norm = nn.LayerNorm(embedding_dim) self.fc = nn.Linear(embedding_dim, num_classes) self.apply(self.init_weight) def forward(self, x): - if self.positional_emb is None and x.size(1) < self.sequence_length: + b = x.shape[0] + + if not exists(self.positional_emb) and x.size(1) < self.sequence_length: x = F.pad(x, (0, 0, 0, self.n_channels - x.size(1)), mode='constant', value=0) if not self.seq_pool: - cls_token = self.class_emb.expand(x.shape[0], -1, -1) + cls_token = repeat(self.class_emb, '1 1 d -> b 1 d', b = b) x = torch.cat((cls_token, x), dim=1) - if self.positional_emb is not None: + if exists(self.positional_emb): x += self.positional_emb x = self.dropout(x) for blk in self.blocks: x = blk(x) + x = self.norm(x) if self.seq_pool: - x = torch.matmul(F.softmax(self.attention_pool(x), dim=1).transpose(-1, -2), x).squeeze(-2) + attn_weights = rearrange(self.attention_pool(x), 'b n 1 -> b n') + x = einsum('b n, b n d -> b d', attn_weights.softmax(dim = 1), x) else: x = x[:, 0] - x = self.fc(x) - return x + return self.fc(x) @staticmethod def init_weight(m): if isinstance(m, nn.Linear): nn.init.trunc_normal_(m.weight, std=.02) - if isinstance(m, nn.Linear) and m.bias is not None: + if isinstance(m, nn.Linear) and exists(m.bias): nn.init.constant_(m.bias, 0) elif isinstance(m, nn.LayerNorm): nn.init.constant_(m.bias, 0) nn.init.constant_(m.weight, 1.0) - @staticmethod - def sinusoidal_embedding(n_channels, dim): - pe = torch.FloatTensor([[p / (10000 ** (2 * (i // 2) / dim)) for i in range(dim)] - for p in range(n_channels)]) - pe[:, 0::2] = torch.sin(pe[:, 0::2]) - pe[:, 1::2] = torch.cos(pe[:, 1::2]) - return pe.unsqueeze(0) - - # CCT Main model + class CCT(nn.Module): - def __init__(self, - img_size=224, - embedding_dim=768, - n_input_channels=3, - n_conv_layers=1, - kernel_size=7, - stride=2, - padding=3, - pooling_kernel_size=3, - pooling_stride=2, - pooling_padding=1, - *args, **kwargs): - super(CCT, self).__init__() + def __init__( + self, + img_size=224, + embedding_dim=768, + n_input_channels=3, + 
n_conv_layers=1, + kernel_size=7, + stride=2, + padding=3, + pooling_kernel_size=3, + pooling_stride=2, + pooling_padding=1, + *args, **kwargs + ): + super().__init__() img_height, img_width = pair(img_size) self.tokenizer = Tokenizer(n_input_channels=n_input_channels, diff --git a/vit_pytorch/cct_3d.py b/vit_pytorch/cct_3d.py new file mode 100644 index 0000000..e14fda1 --- /dev/null +++ b/vit_pytorch/cct_3d.py @@ -0,0 +1,376 @@ +import torch +from torch import nn, einsum +import torch.nn.functional as F + +from einops import rearrange, repeat + +# helpers + +def exists(val): + return val is not None + +def default(val, d): + return val if exists(val) else d + +def pair(t): + return t if isinstance(t, tuple) else (t, t) + +# CCT Models + +__all__ = ['cct_2', 'cct_4', 'cct_6', 'cct_7', 'cct_8', 'cct_14', 'cct_16'] + + +def cct_2(*args, **kwargs): + return _cct(num_layers=2, num_heads=2, mlp_ratio=1, embedding_dim=128, + *args, **kwargs) + + +def cct_4(*args, **kwargs): + return _cct(num_layers=4, num_heads=2, mlp_ratio=1, embedding_dim=128, + *args, **kwargs) + + +def cct_6(*args, **kwargs): + return _cct(num_layers=6, num_heads=4, mlp_ratio=2, embedding_dim=256, + *args, **kwargs) + + +def cct_7(*args, **kwargs): + return _cct(num_layers=7, num_heads=4, mlp_ratio=2, embedding_dim=256, + *args, **kwargs) + + +def cct_8(*args, **kwargs): + return _cct(num_layers=8, num_heads=4, mlp_ratio=2, embedding_dim=256, + *args, **kwargs) + + +def cct_14(*args, **kwargs): + return _cct(num_layers=14, num_heads=6, mlp_ratio=3, embedding_dim=384, + *args, **kwargs) + + +def cct_16(*args, **kwargs): + return _cct(num_layers=16, num_heads=6, mlp_ratio=3, embedding_dim=384, + *args, **kwargs) + + +def _cct(num_layers, num_heads, mlp_ratio, embedding_dim, + kernel_size=3, stride=None, padding=None, + *args, **kwargs): + stride = default(stride, max(1, (kernel_size // 2) - 1)) + padding = default(padding, max(1, (kernel_size // 2))) + + return CCT(num_layers=num_layers, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + embedding_dim=embedding_dim, + kernel_size=kernel_size, + stride=stride, + padding=padding, + *args, **kwargs) + +# positional + +def sinusoidal_embedding(n_channels, dim): + pe = torch.FloatTensor([[p / (10000 ** (2 * (i // 2) / dim)) for i in range(dim)] + for p in range(n_channels)]) + pe[:, 0::2] = torch.sin(pe[:, 0::2]) + pe[:, 1::2] = torch.cos(pe[:, 1::2]) + return rearrange(pe, '... -> 1 ...') + +# modules + +class Attention(nn.Module): + def __init__(self, dim, num_heads=8, attention_dropout=0.1, projection_dropout=0.1): + super().__init__() + self.heads = num_heads + head_dim = dim // self.heads + self.scale = head_dim ** -0.5 + + self.qkv = nn.Linear(dim, dim * 3, bias=False) + self.attn_drop = nn.Dropout(attention_dropout) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(projection_dropout) + + def forward(self, x): + B, N, C = x.shape + + qkv = self.qkv(x).chunk(3, dim = -1) + q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv) + + q = q * self.scale + + attn = einsum('b h i d, b h j d -> b h i j', q, k) + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = einsum('b h i j, b h j d -> b h i d', attn, v) + x = rearrange(x, 'b h n d -> b n (h d)') + + return self.proj_drop(self.proj(x)) + + +class TransformerEncoderLayer(nn.Module): + """ + Inspired by torch.nn.TransformerEncoderLayer and + rwightman's timm package. 
+ """ + def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, + attention_dropout=0.1, drop_path_rate=0.1): + super().__init__() + + self.pre_norm = nn.LayerNorm(d_model) + self.self_attn = Attention(dim=d_model, num_heads=nhead, + attention_dropout=attention_dropout, projection_dropout=dropout) + + self.linear1 = nn.Linear(d_model, dim_feedforward) + self.dropout1 = nn.Dropout(dropout) + self.norm1 = nn.LayerNorm(d_model) + self.linear2 = nn.Linear(dim_feedforward, d_model) + self.dropout2 = nn.Dropout(dropout) + + self.drop_path = DropPath(drop_path_rate) + + self.activation = F.gelu + + def forward(self, src, *args, **kwargs): + src = src + self.drop_path(self.self_attn(self.pre_norm(src))) + src = self.norm1(src) + src2 = self.linear2(self.dropout1(self.activation(self.linear1(src)))) + src = src + self.drop_path(self.dropout2(src2)) + return src + +class DropPath(nn.Module): + def __init__(self, drop_prob=None): + super().__init__() + self.drop_prob = float(drop_prob) + + def forward(self, x): + batch, drop_prob, device, dtype = x.shape[0], self.drop_prob, x.device, x.dtype + + if drop_prob <= 0. or not self.training: + return x + + keep_prob = 1 - self.drop_prob + shape = (batch, *((1,) * (x.ndim - 1))) + + keep_mask = torch.zeros(shape, device = device).float().uniform_(0, 1) < keep_prob + output = x.div(keep_prob) * keep_mask.float() + return output + +class Tokenizer(nn.Module): + def __init__( + self, + frame_kernel_size, + kernel_size, + stride, + padding, + frame_stride=1, + frame_pooling_stride=1, + frame_pooling_kernel_size=1, + pooling_kernel_size=3, + pooling_stride=2, + pooling_padding=1, + n_conv_layers=1, + n_input_channels=3, + n_output_channels=64, + in_planes=64, + activation=None, + max_pool=True, + conv_bias=False + ): + super().__init__() + + n_filter_list = [n_input_channels] + \ + [in_planes for _ in range(n_conv_layers - 1)] + \ + [n_output_channels] + + n_filter_list_pairs = zip(n_filter_list[:-1], n_filter_list[1:]) + + self.conv_layers = nn.Sequential( + *[nn.Sequential( + nn.Conv3d(chan_in, chan_out, + kernel_size=(frame_kernel_size, kernel_size, kernel_size), + stride=(frame_stride, stride, stride), + padding=(frame_kernel_size // 2, padding, padding), bias=conv_bias), + nn.Identity() if not exists(activation) else activation(), + nn.MaxPool3d(kernel_size=(frame_pooling_kernel_size, pooling_kernel_size, pooling_kernel_size), + stride=(frame_pooling_stride, pooling_stride, pooling_stride), + padding=(frame_pooling_kernel_size // 2, pooling_padding, pooling_padding)) if max_pool else nn.Identity() + ) + for chan_in, chan_out in n_filter_list_pairs + ]) + + self.apply(self.init_weight) + + def sequence_length(self, n_channels=3, frames=8, height=224, width=224): + return self.forward(torch.zeros((1, n_channels, frames, height, width))).shape[1] + + def forward(self, x): + x = self.conv_layers(x) + return rearrange(x, 'b c f h w -> b (f h w) c') + + @staticmethod + def init_weight(m): + if isinstance(m, nn.Conv3d): + nn.init.kaiming_normal_(m.weight) + + +class TransformerClassifier(nn.Module): + def __init__( + self, + seq_pool=True, + embedding_dim=768, + num_layers=12, + num_heads=12, + mlp_ratio=4.0, + num_classes=1000, + dropout_rate=0.1, + attention_dropout=0.1, + stochastic_depth_rate=0.1, + positional_embedding='sine', + sequence_length=None, + *args, **kwargs + ): + super().__init__() + assert positional_embedding in {'sine', 'learnable', 'none'} + + dim_feedforward = int(embedding_dim * mlp_ratio) + self.embedding_dim = embedding_dim + 
self.sequence_length = sequence_length + self.seq_pool = seq_pool + + assert exists(sequence_length) or positional_embedding == 'none', \ + f"Positional embedding is set to {positional_embedding} and" \ + f" the sequence length was not specified." + + if not seq_pool: + sequence_length += 1 + self.class_emb = nn.Parameter(torch.zeros(1, 1, self.embedding_dim)) + else: + self.attention_pool = nn.Linear(self.embedding_dim, 1) + + if positional_embedding == 'none': + self.positional_emb = None + elif positional_embedding == 'learnable': + self.positional_emb = nn.Parameter(torch.zeros(1, sequence_length, embedding_dim)) + nn.init.trunc_normal_(self.positional_emb, std = 0.2) + else: + self.register_buffer('positional_emb', sinusoidal_embedding(sequence_length, embedding_dim)) + + self.dropout = nn.Dropout(p=dropout_rate) + + dpr = [x.item() for x in torch.linspace(0, stochastic_depth_rate, num_layers)] + + self.blocks = nn.ModuleList([ + TransformerEncoderLayer(d_model=embedding_dim, nhead=num_heads, + dim_feedforward=dim_feedforward, dropout=dropout_rate, + attention_dropout=attention_dropout, drop_path_rate=layer_dpr) + for layer_dpr in dpr]) + + self.norm = nn.LayerNorm(embedding_dim) + + self.fc = nn.Linear(embedding_dim, num_classes) + self.apply(self.init_weight) + + @staticmethod + def init_weight(m): + if isinstance(m, nn.Linear): + nn.init.trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and exists(m.bias): + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + def forward(self, x): + b = x.shape[0] + + if not exists(self.positional_emb) and x.size(1) < self.sequence_length: + x = F.pad(x, (0, 0, 0, self.n_channels - x.size(1)), mode='constant', value=0) + + if not self.seq_pool: + cls_token = repeat(self.class_emb, '1 1 d -> b 1 d', b = b) + x = torch.cat((cls_token, x), dim=1) + + if exists(self.positional_emb): + x += self.positional_emb + + x = self.dropout(x) + + for blk in self.blocks: + x = blk(x) + + x = self.norm(x) + + if self.seq_pool: + attn_weights = rearrange(self.attention_pool(x), 'b n 1 -> b n') + x = einsum('b n, b n d -> b d', attn_weights.softmax(dim = 1), x) + else: + x = x[:, 0] + + return self.fc(x) + +# CCT Main model + +class CCT(nn.Module): + def __init__( + self, + img_size=224, + num_frames=8, + embedding_dim=768, + n_input_channels=3, + n_conv_layers=1, + frame_stride=1, + frame_kernel_size=3, + frame_pooling_kernel_size=1, + frame_pooling_stride=1, + kernel_size=7, + stride=2, + padding=3, + pooling_kernel_size=3, + pooling_stride=2, + pooling_padding=1, + *args, **kwargs + ): + super().__init__() + img_height, img_width = pair(img_size) + + self.tokenizer = Tokenizer( + n_input_channels=n_input_channels, + n_output_channels=embedding_dim, + frame_stride=frame_stride, + frame_kernel_size=frame_kernel_size, + frame_pooling_stride=frame_pooling_stride, + frame_pooling_kernel_size=frame_pooling_kernel_size, + kernel_size=kernel_size, + stride=stride, + padding=padding, + pooling_kernel_size=pooling_kernel_size, + pooling_stride=pooling_stride, + pooling_padding=pooling_padding, + max_pool=True, + activation=nn.ReLU, + n_conv_layers=n_conv_layers, + conv_bias=False + ) + + self.classifier = TransformerClassifier( + sequence_length=self.tokenizer.sequence_length( + n_channels=n_input_channels, + frames=num_frames, + height=img_height, + width=img_width + ), + embedding_dim=embedding_dim, + seq_pool=True, + dropout_rate=0., + attention_dropout=0.1, + 
stochastic_depth_rate=0.1,
+            *args, **kwargs
+        )
+
+    def forward(self, x):
+        x = self.tokenizer(x)
+        return self.classifier(x)
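For a quick smoke test of the new `cct_3d` module, here is a minimal sketch that builds a model through the `cct_14` convenience constructor instead of instantiating `CCT` directly. The constructor and every keyword used below come from the diff; the conv/pooling settings mirror the README example so the flattened token sequence stays manageable, and the variable names (`model`, `video`, `logits`) are just illustrative.

```python
import torch
from vit_pytorch.cct_3d import cct_14

# cct_14 fixes num_layers = 14, num_heads = 6, mlp_ratio = 3, embedding_dim = 384;
# the remaining keywords are forwarded through _cct -> CCT -> TransformerClassifier
model = cct_14(
    img_size = 224,
    num_frames = 8,
    n_conv_layers = 2,
    kernel_size = 7,
    stride = 2,
    padding = 3,
    num_classes = 1000,
    positional_embedding = 'learnable'
)

video = torch.randn(1, 3, 8, 224, 224)  # (batch, channels, frames, height, width)
logits = model(video)                   # (1, 1000)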
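The new `Tokenizer.sequence_length` helper is what sizes the positional embedding, so it is worth checking that the 3D conv/pool stack yields the token count you expect. A small sketch, assuming the README settings (two conv layers, spatial kernel 7 / stride 2 / padding 3, pooling 3/2/1, temporal kernel 3 with stride 1 and no temporal pooling): each conv+pool stage roughly quarters the spatial side (224 → 56 → 14) while the frame axis is preserved, so an 8-frame 224×224 clip should flatten to 8 × 14 × 14 = 1568 tokens.

```python
import torch
from torch import nn
from vit_pytorch.cct_3d import Tokenizer

tokenizer = Tokenizer(
    frame_kernel_size = 3,
    kernel_size = 7,
    stride = 2,
    padding = 3,
    pooling_kernel_size = 3,
    pooling_stride = 2,
    pooling_padding = 1,
    n_conv_layers = 2,
    n_input_channels = 3,
    n_output_channels = 384,
    activation = nn.ReLU,
    max_pool = True
)

# sequence_length runs a zero tensor through the conv stack and reads off dim 1
print(tokenizer.sequence_length(n_channels = 3, frames = 8, height = 224, width = 224))  # expect 8 * 14 * 14 = 1568

tokens = tokenizer(torch.randn(1, 3, 8, 224, 224))
print(tokens.shape)  # expect torch.Size([1, 1568, 384])
```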