import torch
from torch import nn, einsum
import torch.nn.functional as F

from einops import rearrange, repeat
from einops.layers.torch import Rearrange

# helpers

def exists(val):
    """Return True when ``val`` is not None."""
    return val is not None

def default(val, d):
    """Return ``val`` if it is not None, otherwise the fallback ``d``."""
    return val if exists(val) else d

def pair(t):
    """Normalize a size argument to a ``(height, width)`` tuple.

    A tuple or list is returned as a tuple; any other value (e.g. an int) is
    duplicated, so square sizes can still be passed as a single number.
    """
    return tuple(t) if isinstance(t, (tuple, list)) else (t, t)

# feedforward

class FeedForward(nn.Module):
    """Pre-LayerNorm MLP: LayerNorm -> Linear -> GELU -> Dropout -> Linear -> Dropout.

    Args:
        dim: input/output feature size.
        hidden_dim: width of the hidden layer.
        dropout: dropout probability applied after each linear stage.
    """

    def __init__(self, dim, hidden_dim, dropout = 0.):
        super().__init__()
        self.net = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        return self.net(x)

# attention

class Attention(nn.Module):
    """Multi-head attention with a LayerNorm applied to the query tokens.

    Used both as self-attention (no ``context``) and as cross-attention
    (``context`` supplied).

    Args:
        dim: feature size of the query tokens and of ``context``.
        heads: number of attention heads.
        dim_head: feature size per head; ``heads * dim_head`` is the inner width.
        dropout: dropout on the attention weights and on the output projection.
    """

    def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
        super().__init__()
        inner_dim = dim_head * heads
        self.heads = heads
        self.scale = dim_head ** -0.5

        self.norm = nn.LayerNorm(dim)
        self.attend = nn.Softmax(dim = -1)
        self.dropout = nn.Dropout(dropout)

        self.to_q = nn.Linear(dim, inner_dim, bias = False)
        self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        )

    def forward(self, x, context = None, kv_include_self = False):
        """Attend from ``x`` to ``context`` (or to ``x`` itself when no context is given).

        Args:
            x: query tokens, ``(batch, n, dim)``.
            context: optional key/value tokens, ``(batch, m, dim)``. Only ``x`` is
                layer-normed here; an explicitly passed ``context`` is used as-is.
            kv_include_self: if True, the normalized ``x`` is prepended to the
                key/value sequence.

        Returns:
            Tensor of shape ``(batch, n, dim)``.
        """
        h = self.heads

        x = self.norm(x)
        context = default(context, x)

        if kv_include_self:
            # cross attention requires the CLS token to include itself as key / value
            context = torch.cat((x, context), dim = 1)

        qkv = (self.to_q(x), *self.to_kv(context).chunk(2, dim = -1))
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv)

        dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale

        attn = self.attend(dots)
        attn = self.dropout(attn)

        out = einsum('b h i j, b h j d -> b h i d', attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)

# transformer encoder, for small and large patches

class Transformer(nn.Module):
    """Stack of ``depth`` residual (Attention, FeedForward) layers followed by a final LayerNorm.

    Args:
        dim: token feature size.
        depth: number of attention + feedforward layers.
        heads: attention heads per layer.
        dim_head: feature size per attention head.
        mlp_dim: hidden width of each FeedForward.
        dropout: dropout probability used inside every sub-layer.
    """

    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
        super().__init__()
        self.layers = nn.ModuleList([])
        self.norm = nn.LayerNorm(dim)
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout),
                FeedForward(dim, mlp_dim, dropout = dropout)
            ]))

    def forward(self, x):
        for attn, ff in self.layers:
            x = attn(x) + x
            x = ff(x) + x
        return self.norm(x)

# projecting CLS tokens, in the case that small and large patch tokens have different dimensions

class ProjectInOut(nn.Module):
    """Wrap ``fn`` so it runs in ``dim_out`` space while its caller works in ``dim_in`` space.

    The input is projected ``dim_in -> dim_out`` before ``fn`` and
    ``dim_out -> dim_in`` after it. Both projections are identities when the
    two sizes are equal.
    """

    def __init__(self, dim_in, dim_out, fn):
        super().__init__()
        self.fn = fn

        need_projection = dim_in != dim_out
        self.project_in = nn.Linear(dim_in, dim_out) if need_projection else nn.Identity()
        self.project_out = nn.Linear(dim_out, dim_in) if need_projection else nn.Identity()

    def forward(self, x, *args, **kwargs):
        x = self.project_in(x)
        x = self.fn(x, *args, **kwargs)
        x = self.project_out(x)
        return x

# cross attention transformer

class CrossTransformer(nn.Module):
    """Exchange information between the two branches through their CLS tokens.

    In each of ``depth`` layers:
    - the small-branch CLS token attends to the large-branch patch tokens
      (plus itself);
    - the large-branch CLS token attends to the small-branch patch tokens
      (plus itself).

    Patch tokens pass through unchanged.
    """

    def __init__(self, sm_dim, lg_dim, depth, heads, dim_head, dropout):
        super().__init__()
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                ProjectInOut(sm_dim, lg_dim, Attention(lg_dim, heads = heads, dim_head = dim_head, dropout = dropout)),
                ProjectInOut(lg_dim, sm_dim, Attention(sm_dim, heads = heads, dim_head = dim_head, dropout = dropout))
            ]))

    def forward(self, sm_tokens, lg_tokens):
        """Return updated ``(sm_tokens, lg_tokens)``; token 0 of each sequence is its CLS token."""
        (sm_cls, sm_patch_tokens), (lg_cls, lg_patch_tokens) = map(lambda t: (t[:, :1], t[:, 1:]), (sm_tokens, lg_tokens))

        for sm_attend_lg, lg_attend_sm in self.layers:
            sm_cls = sm_attend_lg(sm_cls, context = lg_patch_tokens, kv_include_self = True) + sm_cls
            lg_cls = lg_attend_sm(lg_cls, context = sm_patch_tokens, kv_include_self = True) + lg_cls

        sm_tokens = torch.cat((sm_cls, sm_patch_tokens), dim = 1)
        lg_tokens = torch.cat((lg_cls, lg_patch_tokens), dim = 1)
        return sm_tokens, lg_tokens

# multi-scale encoder

class MultiScaleEncoder(nn.Module):
    """``depth`` rounds of the per-branch Transformer encoders followed by CLS cross-attention.

    ``sm_enc_params`` / ``lg_enc_params`` are forwarded as keyword arguments to
    the respective ``Transformer`` (``depth``, ``heads``, ``mlp_dim``, ``dim_head``).
    """

    def __init__(
        self,
        *,
        depth,
        sm_dim,
        lg_dim,
        sm_enc_params,
        lg_enc_params,
        cross_attn_heads,
        cross_attn_depth,
        cross_attn_dim_head = 64,
        dropout = 0.
    ):
        super().__init__()
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                Transformer(dim = sm_dim, dropout = dropout, **sm_enc_params),
                Transformer(dim = lg_dim, dropout = dropout, **lg_enc_params),
                CrossTransformer(sm_dim = sm_dim, lg_dim = lg_dim, depth = cross_attn_depth, heads = cross_attn_heads, dim_head = cross_attn_dim_head, dropout = dropout)
            ]))

    def forward(self, sm_tokens, lg_tokens):
        for sm_enc, lg_enc, cross_attend in self.layers:
            sm_tokens, lg_tokens = sm_enc(sm_tokens), lg_enc(lg_tokens)
            sm_tokens, lg_tokens = cross_attend(sm_tokens, lg_tokens)

        return sm_tokens, lg_tokens

# patch-based image to token embedder

class ImageEmbedder(nn.Module):
    """Split an image into patches, embed them, prepend a CLS token and add positional embeddings.

    Args:
        dim: embedding size of every token.
        image_size: image side length as an int, or ``(height, width)``.
        patch_size: patch side length as an int, or ``(height, width)``.
            Each image dimension must be divisible by the matching patch dimension.
        dropout: dropout applied to the output tokens.
        channels: number of input image channels.
    """

    def __init__(
        self,
        *,
        dim,
        image_size,
        patch_size,
        dropout = 0.,
        channels = 3
    ):
        super().__init__()
        image_height, image_width = pair(image_size)
        patch_height, patch_width = pair(patch_size)

        assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'

        num_patches = (image_height // patch_height) * (image_width // patch_width)
        patch_dim = channels * patch_height * patch_width

        self.to_patch_embedding = nn.Sequential(
            Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_height, p2 = patch_width),
            nn.LayerNorm(patch_dim),
            nn.Linear(patch_dim, dim),
            nn.LayerNorm(dim)
        )

        # created in the same order as before so seeded initialization is unchanged
        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
        self.dropout = nn.Dropout(dropout)

    def forward(self, img):
        """Map ``img`` ``(batch, channels, H, W)`` to tokens ``(batch, num_patches + 1, dim)``; token 0 is CLS."""
        x = self.to_patch_embedding(img)
        b, n, _ = x.shape

        cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b)
        x = torch.cat((cls_tokens, x), dim=1)
        x += self.pos_embedding[:, :(n + 1)]

        return self.dropout(x)

# cross ViT class

class CrossViT(nn.Module):
    """Dual-branch ViT classifier over small- and large-patch tokenizations of the same image.

    ``image_size``, ``sm_patch_size`` and ``lg_patch_size`` may each be an int
    or a ``(height, width)`` tuple. The returned logits are the sum of the two
    branches' CLS-token classification heads, shape ``(batch, num_classes)``.
    """

    def __init__(
        self,
        *,
        image_size,
        num_classes,
        sm_dim,
        lg_dim,
        sm_patch_size = 12,
        sm_enc_depth = 1,
        sm_enc_heads = 8,
        sm_enc_mlp_dim = 2048,
        sm_enc_dim_head = 64,
        lg_patch_size = 16,
        lg_enc_depth = 4,
        lg_enc_heads = 8,
        lg_enc_mlp_dim = 2048,
        lg_enc_dim_head = 64,
        cross_attn_depth = 2,
        cross_attn_heads = 8,
        cross_attn_dim_head = 64,
        depth = 3,
        dropout = 0.1,
        emb_dropout = 0.1,
        channels = 3
    ):
        super().__init__()
        self.sm_image_embedder = ImageEmbedder(dim = sm_dim, channels = channels, image_size = image_size, patch_size = sm_patch_size, dropout = emb_dropout)
        self.lg_image_embedder = ImageEmbedder(dim = lg_dim, channels = channels, image_size = image_size, patch_size = lg_patch_size, dropout = emb_dropout)

        self.multi_scale_encoder = MultiScaleEncoder(
            depth = depth,
            sm_dim = sm_dim,
            lg_dim = lg_dim,
            cross_attn_heads = cross_attn_heads,
            cross_attn_dim_head = cross_attn_dim_head,
            cross_attn_depth = cross_attn_depth,
            sm_enc_params = dict(
                depth = sm_enc_depth,
                heads = sm_enc_heads,
                mlp_dim = sm_enc_mlp_dim,
                dim_head = sm_enc_dim_head
            ),
            lg_enc_params = dict(
                depth = lg_enc_depth,
                heads = lg_enc_heads,
                mlp_dim = lg_enc_mlp_dim,
                dim_head = lg_enc_dim_head
            ),
            dropout = dropout
        )

        self.sm_mlp_head = nn.Sequential(nn.LayerNorm(sm_dim), nn.Linear(sm_dim, num_classes))
        self.lg_mlp_head = nn.Sequential(nn.LayerNorm(lg_dim), nn.Linear(lg_dim, num_classes))

    def forward(self, img):
        sm_tokens = self.sm_image_embedder(img)
        lg_tokens = self.lg_image_embedder(img)

        sm_tokens, lg_tokens = self.multi_scale_encoder(sm_tokens, lg_tokens)

        # classify from each branch's CLS token (index 0)
        sm_cls, lg_cls = map(lambda t: t[:, 0], (sm_tokens, lg_tokens))

        sm_logits = self.sm_mlp_head(sm_cls)
        lg_logits = self.lg_mlp_head(lg_cls)

        return sm_logits + lg_logits