diff --git a/README.md b/README.md
index 1de5f85..65dfcd0 100644
--- a/README.md
+++ b/README.md
@@ -1023,6 +1023,35 @@ video = torch.randn(4, 3, 16, 128, 128) # (batch, channels, frames, height, widt
preds = v(video) # (4, 1000)
```
+A 3D version of CCT, for video classification
+
+```python
+import torch
+from vit_pytorch.cct_3d import CCT
+
+cct = CCT(
+ img_size = 224,
+ num_frames = 8,
+ embedding_dim = 384,
+ n_conv_layers = 2,
+ frame_kernel_size = 3,
+ kernel_size = 7,
+ stride = 2,
+ padding = 3,
+ pooling_kernel_size = 3,
+ pooling_stride = 2,
+ pooling_padding = 1,
+ num_layers = 14,
+ num_heads = 6,
+ mlp_ratio = 3.,
+ num_classes = 1000,
+ positional_embedding = 'learnable'
+)
+
+video = torch.randn(1, 3, 8, 224, 224) # (batch, channels, frames, height, width)
+pred = cct(video) # (1, 1000)
+```
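+
+The predefined CCT sizes (`cct_2` through `cct_16`) are also exposed by `vit_pytorch.cct_3d`. Below is a minimal sketch using `cct_14`, which fixes `num_layers = 14`, `num_heads = 6`, `mlp_ratio = 3` and `embedding_dim = 384`; the remaining keyword arguments simply mirror the example above and are forwarded through to the underlying `CCT`.
+
+```python
+import torch
+from vit_pytorch.cct_3d import cct_14
+
+cct = cct_14(
+ img_size = 224,
+ num_frames = 8,
+ n_conv_layers = 2,
+ frame_kernel_size = 3,
+ kernel_size = 7,
+ stride = 2,
+ padding = 3,
+ num_classes = 1000,
+ positional_embedding = 'learnable'
+)
+
+video = torch.randn(1, 3, 8, 224, 224) # (batch, channels, frames, height, width)
+pred = cct(video) # (1, 1000)
+```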
+
## ViViT
diff --git a/setup.py b/setup.py
index f3dc68f..0877f90 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
setup(
name = 'vit-pytorch',
packages = find_packages(exclude=['examples']),
- version = '0.37.1',
+ version = '0.38.1',
license='MIT',
description = 'Vision Transformer (ViT) - Pytorch',
long_description_content_type = 'text/markdown',
diff --git a/vit_pytorch/cct.py b/vit_pytorch/cct.py
index 643b180..4b37699 100644
--- a/vit_pytorch/cct.py
+++ b/vit_pytorch/cct.py
@@ -1,9 +1,17 @@
import torch
-import torch.nn as nn
+from torch import nn, einsum
import torch.nn.functional as F
+from einops import rearrange, repeat
+
# helpers
+def exists(val):
+ return val is not None
+
+def default(val, d):
+ return val if exists(val) else d
+
def pair(t):
return t if isinstance(t, tuple) else (t, t)
@@ -50,8 +58,9 @@ def cct_16(*args, **kwargs):
def _cct(num_layers, num_heads, mlp_ratio, embedding_dim,
kernel_size=3, stride=None, padding=None,
*args, **kwargs):
- stride = stride if stride is not None else max(1, (kernel_size // 2) - 1)
- padding = padding if padding is not None else max(1, (kernel_size // 2))
+ stride = default(stride, max(1, (kernel_size // 2) - 1))
+ padding = default(padding, max(1, (kernel_size // 2)))
+
return CCT(num_layers=num_layers,
num_heads=num_heads,
mlp_ratio=mlp_ratio,
@@ -61,13 +70,22 @@ def _cct(num_layers, num_heads, mlp_ratio, embedding_dim,
padding=padding,
*args, **kwargs)
+# positional
+
+def sinusoidal_embedding(n_channels, dim):
+ pe = torch.FloatTensor([[p / (10000 ** (2 * (i // 2) / dim)) for i in range(dim)]
+ for p in range(n_channels)])
+ pe[:, 0::2] = torch.sin(pe[:, 0::2])
+ pe[:, 1::2] = torch.cos(pe[:, 1::2])
+ return rearrange(pe, '... -> 1 ...')
+
# modules
class Attention(nn.Module):
def __init__(self, dim, num_heads=8, attention_dropout=0.1, projection_dropout=0.1):
super().__init__()
- self.num_heads = num_heads
- head_dim = dim // self.num_heads
+ self.heads = num_heads
+ head_dim = dim // self.heads
self.scale = head_dim ** -0.5
self.qkv = nn.Linear(dim, dim * 3, bias=False)
@@ -77,17 +95,20 @@ class Attention(nn.Module):
def forward(self, x):
B, N, C = x.shape
- qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
- q, k, v = qkv[0], qkv[1], qkv[2]
- attn = (q @ k.transpose(-2, -1)) * self.scale
+ qkv = self.qkv(x).chunk(3, dim = -1)
+ q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)
+
+ q = q * self.scale
+
+ attn = einsum('b h i d, b h j d -> b h i j', q, k)
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
- x = (attn @ v).transpose(1, 2).reshape(B, N, C)
- x = self.proj(x)
- x = self.proj_drop(x)
- return x
+ x = einsum('b h i j, b h j d -> b h i d', attn, v)
+ x = rearrange(x, 'b h n d -> b n (h d)')
+
+ return self.proj_drop(self.proj(x))
class TransformerEncoderLayer(nn.Module):
@@ -97,7 +118,8 @@ class TransformerEncoderLayer(nn.Module):
"""
def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1,
attention_dropout=0.1, drop_path_rate=0.1):
- super(TransformerEncoderLayer, self).__init__()
+ super().__init__()
+
self.pre_norm = nn.LayerNorm(d_model)
self.self_attn = Attention(dim=d_model, num_heads=nhead,
attention_dropout=attention_dropout, projection_dropout=dropout)
@@ -108,50 +130,34 @@ class TransformerEncoderLayer(nn.Module):
self.linear2 = nn.Linear(dim_feedforward, d_model)
self.dropout2 = nn.Dropout(dropout)
- self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0 else nn.Identity()
+ self.drop_path = DropPath(drop_path_rate)
self.activation = F.gelu
- def forward(self, src: torch.Tensor, *args, **kwargs) -> torch.Tensor:
+ def forward(self, src, *args, **kwargs):
src = src + self.drop_path(self.self_attn(self.pre_norm(src)))
src = self.norm1(src)
src2 = self.linear2(self.dropout1(self.activation(self.linear1(src))))
src = src + self.drop_path(self.dropout2(src2))
return src
-
-def drop_path(x, drop_prob: float = 0., training: bool = False):
- """
- Obtained from: github.com:rwightman/pytorch-image-models
- Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
- This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
- the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
- See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
- changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
- 'survival rate' as the argument.
- """
- if drop_prob == 0. or not training:
- return x
- keep_prob = 1 - drop_prob
- shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
- random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
- random_tensor.floor_() # binarize
- output = x.div(keep_prob) * random_tensor
- return output
-
-
class DropPath(nn.Module):
- """
- Obtained from: github.com:rwightman/pytorch-image-models
- Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
- """
def __init__(self, drop_prob=None):
- super(DropPath, self).__init__()
- self.drop_prob = drop_prob
+ super().__init__()
+ self.drop_prob = float(default(drop_prob, 0.))
def forward(self, x):
- return drop_path(x, self.drop_prob, self.training)
+ batch, drop_prob, device, dtype = x.shape[0], self.drop_prob, x.device, x.dtype
+ if drop_prob <= 0. or not self.training:
+ return x
+
+ keep_prob = 1 - self.drop_prob
+ shape = (batch, *((1,) * (x.ndim - 1)))
+
+ keep_mask = torch.zeros(shape, device = device).float().uniform_(0, 1) < keep_prob
+ output = x.div(keep_prob) * keep_mask.float()
+ return output
class Tokenizer(nn.Module):
def __init__(self,
@@ -164,34 +170,35 @@ class Tokenizer(nn.Module):
activation=None,
max_pool=True,
conv_bias=False):
- super(Tokenizer, self).__init__()
+ super().__init__()
n_filter_list = [n_input_channels] + \
[in_planes for _ in range(n_conv_layers - 1)] + \
[n_output_channels]
+ n_filter_list_pairs = zip(n_filter_list[:-1], n_filter_list[1:])
+
self.conv_layers = nn.Sequential(
*[nn.Sequential(
- nn.Conv2d(n_filter_list[i], n_filter_list[i + 1],
+ nn.Conv2d(chan_in, chan_out,
kernel_size=(kernel_size, kernel_size),
stride=(stride, stride),
padding=(padding, padding), bias=conv_bias),
- nn.Identity() if activation is None else activation(),
+ nn.Identity() if not exists(activation) else activation(),
nn.MaxPool2d(kernel_size=pooling_kernel_size,
stride=pooling_stride,
padding=pooling_padding) if max_pool else nn.Identity()
)
- for i in range(n_conv_layers)
+ for chan_in, chan_out in n_filter_list_pairs
])
- self.flattener = nn.Flatten(2, 3)
self.apply(self.init_weight)
def sequence_length(self, n_channels=3, height=224, width=224):
return self.forward(torch.zeros((1, n_channels, height, width))).shape[1]
def forward(self, x):
- return self.flattener(self.conv_layers(x)).transpose(-2, -1)
+ return rearrange(self.conv_layers(x), 'b c h w -> b (h w) c')
@staticmethod
def init_weight(m):
@@ -214,106 +221,104 @@ class TransformerClassifier(nn.Module):
sequence_length=None,
*args, **kwargs):
super().__init__()
- positional_embedding = positional_embedding if \
- positional_embedding in ['sine', 'learnable', 'none'] else 'sine'
+ assert positional_embedding in {'sine', 'learnable', 'none'}
+
dim_feedforward = int(embedding_dim * mlp_ratio)
self.embedding_dim = embedding_dim
self.sequence_length = sequence_length
self.seq_pool = seq_pool
- assert sequence_length is not None or positional_embedding == 'none', \
+ assert exists(sequence_length) or positional_embedding == 'none', \
f"Positional embedding is set to {positional_embedding} and" \
f" the sequence length was not specified."
if not seq_pool:
sequence_length += 1
- self.class_emb = nn.Parameter(torch.zeros(1, 1, self.embedding_dim),
- requires_grad=True)
+ self.class_emb = nn.Parameter(torch.zeros(1, 1, self.embedding_dim), requires_grad=True)
else:
self.attention_pool = nn.Linear(self.embedding_dim, 1)
- if positional_embedding != 'none':
- if positional_embedding == 'learnable':
- self.positional_emb = nn.Parameter(torch.zeros(1, sequence_length, embedding_dim),
- requires_grad=True)
- nn.init.trunc_normal_(self.positional_emb, std=0.2)
- else:
- self.positional_emb = nn.Parameter(self.sinusoidal_embedding(sequence_length, embedding_dim),
- requires_grad=False)
- else:
+ if positional_embedding == 'none':
self.positional_emb = None
+ elif positional_embedding == 'learnable':
+ self.positional_emb = nn.Parameter(torch.zeros(1, sequence_length, embedding_dim),
+ requires_grad=True)
+ nn.init.trunc_normal_(self.positional_emb, std=0.2)
+ else:
+ self.positional_emb = nn.Parameter(sinusoidal_embedding(sequence_length, embedding_dim),
+ requires_grad=False)
self.dropout = nn.Dropout(p=dropout_rate)
+
dpr = [x.item() for x in torch.linspace(0, stochastic_depth_rate, num_layers)]
+
self.blocks = nn.ModuleList([
TransformerEncoderLayer(d_model=embedding_dim, nhead=num_heads,
dim_feedforward=dim_feedforward, dropout=dropout_rate,
- attention_dropout=attention_dropout, drop_path_rate=dpr[i])
- for i in range(num_layers)])
+ attention_dropout=attention_dropout, drop_path_rate=layer_dpr)
+ for layer_dpr in dpr])
+
self.norm = nn.LayerNorm(embedding_dim)
self.fc = nn.Linear(embedding_dim, num_classes)
self.apply(self.init_weight)
def forward(self, x):
- if self.positional_emb is None and x.size(1) < self.sequence_length:
+ b = x.shape[0]
+
+ if not exists(self.positional_emb) and x.size(1) < self.sequence_length:
x = F.pad(x, (0, 0, 0, self.n_channels - x.size(1)), mode='constant', value=0)
if not self.seq_pool:
- cls_token = self.class_emb.expand(x.shape[0], -1, -1)
+ cls_token = repeat(self.class_emb, '1 1 d -> b 1 d', b = b)
x = torch.cat((cls_token, x), dim=1)
- if self.positional_emb is not None:
+ if exists(self.positional_emb):
x += self.positional_emb
x = self.dropout(x)
for blk in self.blocks:
x = blk(x)
+
x = self.norm(x)
if self.seq_pool:
- x = torch.matmul(F.softmax(self.attention_pool(x), dim=1).transpose(-1, -2), x).squeeze(-2)
+ attn_weights = rearrange(self.attention_pool(x), 'b n 1 -> b n')
+ x = einsum('b n, b n d -> b d', attn_weights.softmax(dim = 1), x)
else:
x = x[:, 0]
- x = self.fc(x)
- return x
+ return self.fc(x)
@staticmethod
def init_weight(m):
if isinstance(m, nn.Linear):
nn.init.trunc_normal_(m.weight, std=.02)
- if isinstance(m, nn.Linear) and m.bias is not None:
+ if isinstance(m, nn.Linear) and exists(m.bias):
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
- @staticmethod
- def sinusoidal_embedding(n_channels, dim):
- pe = torch.FloatTensor([[p / (10000 ** (2 * (i // 2) / dim)) for i in range(dim)]
- for p in range(n_channels)])
- pe[:, 0::2] = torch.sin(pe[:, 0::2])
- pe[:, 1::2] = torch.cos(pe[:, 1::2])
- return pe.unsqueeze(0)
-
-
# CCT Main model
+
class CCT(nn.Module):
- def __init__(self,
- img_size=224,
- embedding_dim=768,
- n_input_channels=3,
- n_conv_layers=1,
- kernel_size=7,
- stride=2,
- padding=3,
- pooling_kernel_size=3,
- pooling_stride=2,
- pooling_padding=1,
- *args, **kwargs):
- super(CCT, self).__init__()
+ def __init__(
+ self,
+ img_size=224,
+ embedding_dim=768,
+ n_input_channels=3,
+ n_conv_layers=1,
+ kernel_size=7,
+ stride=2,
+ padding=3,
+ pooling_kernel_size=3,
+ pooling_stride=2,
+ pooling_padding=1,
+ *args, **kwargs
+ ):
+ super().__init__()
img_height, img_width = pair(img_size)
self.tokenizer = Tokenizer(n_input_channels=n_input_channels,
diff --git a/vit_pytorch/cct_3d.py b/vit_pytorch/cct_3d.py
new file mode 100644
index 0000000..e14fda1
--- /dev/null
+++ b/vit_pytorch/cct_3d.py
@@ -0,0 +1,376 @@
+import torch
+from torch import nn, einsum
+import torch.nn.functional as F
+
+from einops import rearrange, repeat
+
+# helpers
+
+def exists(val):
+ return val is not None
+
+def default(val, d):
+ return val if exists(val) else d
+
+def pair(t):
+ return t if isinstance(t, tuple) else (t, t)
+
+# CCT Models
+
+__all__ = ['cct_2', 'cct_4', 'cct_6', 'cct_7', 'cct_8', 'cct_14', 'cct_16']
+
+
+def cct_2(*args, **kwargs):
+ return _cct(num_layers=2, num_heads=2, mlp_ratio=1, embedding_dim=128,
+ *args, **kwargs)
+
+
+def cct_4(*args, **kwargs):
+ return _cct(num_layers=4, num_heads=2, mlp_ratio=1, embedding_dim=128,
+ *args, **kwargs)
+
+
+def cct_6(*args, **kwargs):
+ return _cct(num_layers=6, num_heads=4, mlp_ratio=2, embedding_dim=256,
+ *args, **kwargs)
+
+
+def cct_7(*args, **kwargs):
+ return _cct(num_layers=7, num_heads=4, mlp_ratio=2, embedding_dim=256,
+ *args, **kwargs)
+
+
+def cct_8(*args, **kwargs):
+ return _cct(num_layers=8, num_heads=4, mlp_ratio=2, embedding_dim=256,
+ *args, **kwargs)
+
+
+def cct_14(*args, **kwargs):
+ return _cct(num_layers=14, num_heads=6, mlp_ratio=3, embedding_dim=384,
+ *args, **kwargs)
+
+
+def cct_16(*args, **kwargs):
+ return _cct(num_layers=16, num_heads=6, mlp_ratio=3, embedding_dim=384,
+ *args, **kwargs)
+
+
+def _cct(num_layers, num_heads, mlp_ratio, embedding_dim,
+ kernel_size=3, stride=None, padding=None,
+ *args, **kwargs):
+ stride = default(stride, max(1, (kernel_size // 2) - 1))
+ padding = default(padding, max(1, (kernel_size // 2)))
+
+ return CCT(num_layers=num_layers,
+ num_heads=num_heads,
+ mlp_ratio=mlp_ratio,
+ embedding_dim=embedding_dim,
+ kernel_size=kernel_size,
+ stride=stride,
+ padding=padding,
+ *args, **kwargs)
+
+# positional
+
+def sinusoidal_embedding(n_channels, dim):
+ pe = torch.FloatTensor([[p / (10000 ** (2 * (i // 2) / dim)) for i in range(dim)]
+ for p in range(n_channels)])
+ pe[:, 0::2] = torch.sin(pe[:, 0::2])
+ pe[:, 1::2] = torch.cos(pe[:, 1::2])
+ return rearrange(pe, '... -> 1 ...')
+
+# modules
+
+class Attention(nn.Module):
+ def __init__(self, dim, num_heads=8, attention_dropout=0.1, projection_dropout=0.1):
+ super().__init__()
+ self.heads = num_heads
+ head_dim = dim // self.heads
+ self.scale = head_dim ** -0.5
+
+ self.qkv = nn.Linear(dim, dim * 3, bias=False)
+ self.attn_drop = nn.Dropout(attention_dropout)
+ self.proj = nn.Linear(dim, dim)
+ self.proj_drop = nn.Dropout(projection_dropout)
+
+ def forward(self, x):
+ B, N, C = x.shape
+
+ qkv = self.qkv(x).chunk(3, dim = -1)
+ q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)
+
+ q = q * self.scale
+
+ attn = einsum('b h i d, b h j d -> b h i j', q, k)
+ attn = attn.softmax(dim=-1)
+ attn = self.attn_drop(attn)
+
+ x = einsum('b h i j, b h j d -> b h i d', attn, v)
+ x = rearrange(x, 'b h n d -> b n (h d)')
+
+ return self.proj_drop(self.proj(x))
+
+
+class TransformerEncoderLayer(nn.Module):
+ """
+ Inspired by torch.nn.TransformerEncoderLayer and
+ rwightman's timm package.
+ """
+ def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1,
+ attention_dropout=0.1, drop_path_rate=0.1):
+ super().__init__()
+
+ self.pre_norm = nn.LayerNorm(d_model)
+ self.self_attn = Attention(dim=d_model, num_heads=nhead,
+ attention_dropout=attention_dropout, projection_dropout=dropout)
+
+ self.linear1 = nn.Linear(d_model, dim_feedforward)
+ self.dropout1 = nn.Dropout(dropout)
+ self.norm1 = nn.LayerNorm(d_model)
+ self.linear2 = nn.Linear(dim_feedforward, d_model)
+ self.dropout2 = nn.Dropout(dropout)
+
+ self.drop_path = DropPath(drop_path_rate)
+
+ self.activation = F.gelu
+
+ def forward(self, src, *args, **kwargs):
+ src = src + self.drop_path(self.self_attn(self.pre_norm(src)))
+ src = self.norm1(src)
+ src2 = self.linear2(self.dropout1(self.activation(self.linear1(src))))
+ src = src + self.drop_path(self.dropout2(src2))
+ return src
+
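+# stochastic depth (DropPath) - during training, drops the entire residual branch for a random subset of samples and rescales the rest by 1 / keep_prob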
+class DropPath(nn.Module):
+ def __init__(self, drop_prob=None):
+ super().__init__()
+ self.drop_prob = float(default(drop_prob, 0.))
+
+ def forward(self, x):
+ batch, drop_prob, device, dtype = x.shape[0], self.drop_prob, x.device, x.dtype
+
+ if drop_prob <= 0. or not self.training:
+ return x
+
+ keep_prob = 1 - self.drop_prob
+ shape = (batch, *((1,) * (x.ndim - 1)))
+
+ keep_mask = torch.zeros(shape, device = device).float().uniform_(0, 1) < keep_prob
+ output = x.div(keep_prob) * keep_mask.float()
+ return output
+
+class Tokenizer(nn.Module):
+ def __init__(
+ self,
+ frame_kernel_size,
+ kernel_size,
+ stride,
+ padding,
+ frame_stride=1,
+ frame_pooling_stride=1,
+ frame_pooling_kernel_size=1,
+ pooling_kernel_size=3,
+ pooling_stride=2,
+ pooling_padding=1,
+ n_conv_layers=1,
+ n_input_channels=3,
+ n_output_channels=64,
+ in_planes=64,
+ activation=None,
+ max_pool=True,
+ conv_bias=False
+ ):
+ super().__init__()
+
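+ # channel progression across the conv stages: n_input_channels -> in_planes for the intermediate layers -> n_output_channels (the embedding dimension)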
+ n_filter_list = [n_input_channels] + \
+ [in_planes for _ in range(n_conv_layers - 1)] + \
+ [n_output_channels]
+
+ n_filter_list_pairs = zip(n_filter_list[:-1], n_filter_list[1:])
+
+ self.conv_layers = nn.Sequential(
+ *[nn.Sequential(
+ nn.Conv3d(chan_in, chan_out,
+ kernel_size=(frame_kernel_size, kernel_size, kernel_size),
+ stride=(frame_stride, stride, stride),
+ padding=(frame_kernel_size // 2, padding, padding), bias=conv_bias),
+ nn.Identity() if not exists(activation) else activation(),
+ nn.MaxPool3d(kernel_size=(frame_pooling_kernel_size, pooling_kernel_size, pooling_kernel_size),
+ stride=(frame_pooling_stride, pooling_stride, pooling_stride),
+ padding=(frame_pooling_kernel_size // 2, pooling_padding, pooling_padding)) if max_pool else nn.Identity()
+ )
+ for chan_in, chan_out in n_filter_list_pairs
+ ])
+
+ self.apply(self.init_weight)
+
+ def sequence_length(self, n_channels=3, frames=8, height=224, width=224):
+ return self.forward(torch.zeros((1, n_channels, frames, height, width))).shape[1]
+
+ def forward(self, x):
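+ # run the 3d conv / pool stages, then flatten the (frames, height, width) grid into a sequence of tokens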
+ x = self.conv_layers(x)
+ return rearrange(x, 'b c f h w -> b (f h w) c')
+
+ @staticmethod
+ def init_weight(m):
+ if isinstance(m, nn.Conv3d):
+ nn.init.kaiming_normal_(m.weight)
+
+
+class TransformerClassifier(nn.Module):
+ def __init__(
+ self,
+ seq_pool=True,
+ embedding_dim=768,
+ num_layers=12,
+ num_heads=12,
+ mlp_ratio=4.0,
+ num_classes=1000,
+ dropout_rate=0.1,
+ attention_dropout=0.1,
+ stochastic_depth_rate=0.1,
+ positional_embedding='sine',
+ sequence_length=None,
+ *args, **kwargs
+ ):
+ super().__init__()
+ assert positional_embedding in {'sine', 'learnable', 'none'}
+
+ dim_feedforward = int(embedding_dim * mlp_ratio)
+ self.embedding_dim = embedding_dim
+ self.sequence_length = sequence_length
+ self.seq_pool = seq_pool
+
+ assert exists(sequence_length) or positional_embedding == 'none', \
+ f"Positional embedding is set to {positional_embedding} and" \
+ f" the sequence length was not specified."
+
+ if not seq_pool:
+ sequence_length += 1
+ self.class_emb = nn.Parameter(torch.zeros(1, 1, self.embedding_dim))
+ else:
+ self.attention_pool = nn.Linear(self.embedding_dim, 1)
+
+ if positional_embedding == 'none':
+ self.positional_emb = None
+ elif positional_embedding == 'learnable':
+ self.positional_emb = nn.Parameter(torch.zeros(1, sequence_length, embedding_dim))
+ nn.init.trunc_normal_(self.positional_emb, std = 0.2)
+ else:
+ self.register_buffer('positional_emb', sinusoidal_embedding(sequence_length, embedding_dim))
+
+ self.dropout = nn.Dropout(p=dropout_rate)
+
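+ # drop path rate increases linearly from 0 to stochastic_depth_rate across the transformer layers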
+ dpr = [x.item() for x in torch.linspace(0, stochastic_depth_rate, num_layers)]
+
+ self.blocks = nn.ModuleList([
+ TransformerEncoderLayer(d_model=embedding_dim, nhead=num_heads,
+ dim_feedforward=dim_feedforward, dropout=dropout_rate,
+ attention_dropout=attention_dropout, drop_path_rate=layer_dpr)
+ for layer_dpr in dpr])
+
+ self.norm = nn.LayerNorm(embedding_dim)
+
+ self.fc = nn.Linear(embedding_dim, num_classes)
+ self.apply(self.init_weight)
+
+ @staticmethod
+ def init_weight(m):
+ if isinstance(m, nn.Linear):
+ nn.init.trunc_normal_(m.weight, std=.02)
+ if isinstance(m, nn.Linear) and exists(m.bias):
+ nn.init.constant_(m.bias, 0)
+ elif isinstance(m, nn.LayerNorm):
+ nn.init.constant_(m.bias, 0)
+ nn.init.constant_(m.weight, 1.0)
+
+ def forward(self, x):
+ b = x.shape[0]
+
+ if not exists(self.positional_emb) and x.size(1) < self.sequence_length:
+ x = F.pad(x, (0, 0, 0, self.sequence_length - x.size(1)), mode='constant', value=0)
+
+ if not self.seq_pool:
+ cls_token = repeat(self.class_emb, '1 1 d -> b 1 d', b = b)
+ x = torch.cat((cls_token, x), dim=1)
+
+ if exists(self.positional_emb):
+ x += self.positional_emb
+
+ x = self.dropout(x)
+
+ for blk in self.blocks:
+ x = blk(x)
+
+ x = self.norm(x)
+
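+ # sequence pooling - attention-weighted sum over all tokens, used in place of a class token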
+ if self.seq_pool:
+ attn_weights = rearrange(self.attention_pool(x), 'b n 1 -> b n')
+ x = einsum('b n, b n d -> b d', attn_weights.softmax(dim = 1), x)
+ else:
+ x = x[:, 0]
+
+ return self.fc(x)
+
+# CCT Main model
+
+class CCT(nn.Module):
+ def __init__(
+ self,
+ img_size=224,
+ num_frames=8,
+ embedding_dim=768,
+ n_input_channels=3,
+ n_conv_layers=1,
+ frame_stride=1,
+ frame_kernel_size=3,
+ frame_pooling_kernel_size=1,
+ frame_pooling_stride=1,
+ kernel_size=7,
+ stride=2,
+ padding=3,
+ pooling_kernel_size=3,
+ pooling_stride=2,
+ pooling_padding=1,
+ *args, **kwargs
+ ):
+ super().__init__()
+ img_height, img_width = pair(img_size)
+
+ self.tokenizer = Tokenizer(
+ n_input_channels=n_input_channels,
+ n_output_channels=embedding_dim,
+ frame_stride=frame_stride,
+ frame_kernel_size=frame_kernel_size,
+ frame_pooling_stride=frame_pooling_stride,
+ frame_pooling_kernel_size=frame_pooling_kernel_size,
+ kernel_size=kernel_size,
+ stride=stride,
+ padding=padding,
+ pooling_kernel_size=pooling_kernel_size,
+ pooling_stride=pooling_stride,
+ pooling_padding=pooling_padding,
+ max_pool=True,
+ activation=nn.ReLU,
+ n_conv_layers=n_conv_layers,
+ conv_bias=False
+ )
+
+ self.classifier = TransformerClassifier(
+ sequence_length=self.tokenizer.sequence_length(
+ n_channels=n_input_channels,
+ frames=num_frames,
+ height=img_height,
+ width=img_width
+ ),
+ embedding_dim=embedding_dim,
+ seq_pool=True,
+ dropout_rate=0.,
+ attention_dropout=0.1,
+ stochastic_depth_rate=0.1,
+ *args, **kwargs
+ )
+
+ def forward(self, x):
+ x = self.tokenizer(x)
+ return self.classifier(x)