diff --git a/vit_pytorch/rvt.py b/vit_pytorch/rvt.py
index 03ebf58..01085d8 100644
--- a/vit_pytorch/rvt.py
+++ b/vit_pytorch/rvt.py
@@ -43,6 +43,16 @@ class AxialRotaryEmbedding(nn.Module):
         sin, cos = map(lambda t: repeat(t, 'n d -> () n (d j)', j = 2), (sin, cos))
         return sin, cos
 
+class DepthWiseConv2d(nn.Module):
+    def __init__(self, dim_in, dim_out, kernel_size, padding, stride = 1, bias = True):
+        super().__init__()
+        self.net = nn.Sequential(
+            nn.Conv2d(dim_in, dim_in, kernel_size = kernel_size, padding = padding, groups = dim_in, stride = stride, bias = bias),
+            nn.Conv2d(dim_in, dim_out, kernel_size = 1, bias = bias)
+        )
+    def forward(self, x):
+        return self.net(x)
+
 # helper classes
 
 class PreNorm(nn.Module):
@@ -53,6 +63,18 @@ class PreNorm(nn.Module):
     def forward(self, x, **kwargs):
         return self.fn(self.norm(x), **kwargs)
 
+class SpatialConv(nn.Module):
+    def __init__(self, dim_in, dim_out, kernel, bias = False):
+        super().__init__()
+        self.conv = DepthWiseConv2d(dim_in, dim_out, kernel, padding = kernel // 2, bias = False)
+
+    def forward(self, x, fmap_dims):
+        cls_token, x = x[:, :1], x[:, 1:]
+        x = rearrange(x, 'b (h w) d -> b d h w', **fmap_dims)
+        x = self.conv(x)
+        x = rearrange(x, 'b d h w -> b (h w) d')
+        return torch.cat((cls_token, x), dim = 1)
+
 class GEGLU(nn.Module):
     def forward(self, x):
         x, gates = x.chunk(2, dim = -1)
@@ -72,7 +94,7 @@ class FeedForward(nn.Module):
         return self.net(x)
 
 class Attention(nn.Module):
-    def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
+    def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0., conv_query_kernel = 9):
         super().__init__()
         inner_dim = dim_head * heads
 
@@ -80,16 +102,22 @@ class Attention(nn.Module):
         self.scale = dim_head ** -0.5
 
         self.attend = nn.Softmax(dim = -1)
-        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
+
+        self.to_q = SpatialConv(dim, inner_dim, conv_query_kernel, bias = False)
+
+        self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False)
 
         self.to_out = nn.Sequential(
             nn.Linear(inner_dim, dim),
             nn.Dropout(dropout)
         )
 
-    def forward(self, x, pos_emb):
+    def forward(self, x, pos_emb, fmap_dims):
         b, n, _, h = *x.shape, self.heads
-        qkv = self.to_qkv(x).chunk(3, dim = -1)
+
+        q = self.to_q(x, fmap_dims = fmap_dims)
+        qkv = (q, *self.to_kv(x).chunk(2, dim = -1))
+
         q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h = h), qkv)
 
         # apply 2d rotary embeddings to queries and keys, excluding CLS tokens
@@ -121,11 +149,11 @@ class Transformer(nn.Module):
                 PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)),
                 PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout))
             ]))
-    def forward(self, x):
+    def forward(self, x, fmap_dims):
         pos_emb = self.pos_emb(x[:, 1:])
 
         for attn, ff in self.layers:
-            x = attn(x, pos_emb = pos_emb) + x
+            x = attn(x, pos_emb = pos_emb, fmap_dims = fmap_dims) + x
             x = ff(x) + x
         return x
 
@@ -138,6 +166,7 @@ class RvT(nn.Module):
         num_patches = (image_size // patch_size) ** 2
         patch_dim = channels * patch_size ** 2
 
+        self.patch_size = patch_size
         self.to_patch_embedding = nn.Sequential(
             Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_size, p2 = patch_size),
             nn.Linear(patch_dim, dim),
@@ -152,12 +181,15 @@ class RvT(nn.Module):
         )
 
     def forward(self, img):
+        b, _, h, w, p = *img.shape, self.patch_size
+
         x = self.to_patch_embedding(img)
-        b, n, _ = x.shape
+        n = x.shape[1]
 
         cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b)
         x = torch.cat((cls_tokens, x), dim=1)
 
-        x = self.transformer(x)
+        fmap_dims = {'h': h // p, 'w': w // p}
+        x = self.transformer(x, fmap_dims = fmap_dims)
 
         return self.mlp_head(x)
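For reference, below is a minimal, self-contained sketch that exercises the new convolutional query path in isolation. The `DepthWiseConv2d` and `SpatialConv` modules are copied from the hunks above; the batch size, the 8 x 8 patch grid, the 512 embedding dimension, and `kernel = 9` (the `conv_query_kernel` default) are illustrative assumptions, not values taken from the repository. Note that, as written in this patch, the CLS token is concatenated back without a projection, so the sketch keeps `dim_in == dim_out`.

```python
import torch
from torch import nn
from einops import rearrange

class DepthWiseConv2d(nn.Module):
    def __init__(self, dim_in, dim_out, kernel_size, padding, stride = 1, bias = True):
        super().__init__()
        # depthwise conv (groups = dim_in) followed by a 1x1 pointwise conv
        self.net = nn.Sequential(
            nn.Conv2d(dim_in, dim_in, kernel_size = kernel_size, padding = padding, groups = dim_in, stride = stride, bias = bias),
            nn.Conv2d(dim_in, dim_out, kernel_size = 1, bias = bias)
        )
    def forward(self, x):
        return self.net(x)

class SpatialConv(nn.Module):
    def __init__(self, dim_in, dim_out, kernel, bias = False):
        super().__init__()
        self.conv = DepthWiseConv2d(dim_in, dim_out, kernel, padding = kernel // 2, bias = False)

    def forward(self, x, fmap_dims):
        # split off the CLS token, convolve the spatial tokens as a 2d feature map, re-attach
        cls_token, x = x[:, :1], x[:, 1:]
        x = rearrange(x, 'b (h w) d -> b d h w', **fmap_dims)
        x = self.conv(x)
        x = rearrange(x, 'b d h w -> b (h w) d')
        return torch.cat((cls_token, x), dim = 1)

# illustrative shapes (assumed): dim = inner_dim = 8 heads * 64 dim_head = 512, 8 x 8 patch grid
to_q = SpatialConv(dim_in = 512, dim_out = 512, kernel = 9)
tokens = torch.randn(2, 1 + 8 * 8, 512)         # (batch, 1 + h*w, dim), CLS token first
q = to_q(tokens, fmap_dims = {'h': 8, 'w': 8})  # same dict the RvT forward builds from h // p and w // p
print(q.shape)                                   # torch.Size([2, 65, 512])
```

The design choice the patch makes: queries are produced by a depthwise-separable convolution over the reshaped token grid, so each query mixes in a local spatial neighborhood, while keys and values stay plain linear projections (`to_kv`). Because the flat token sequence no longer carries the 2D shape, `fmap_dims` has to be threaded from `RvT.forward` through `Transformer` into `Attention`, which is what the remaining hunks do.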