diff --git a/setup.py b/setup.py
index b2b7fdb..83f4dc4 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
 setup(
   name = 'vit-pytorch',
   packages = find_packages(exclude=['examples']),
-  version = '0.16.7',
+  version = '0.16.8',
   license='MIT',
   description = 'Vision Transformer (ViT) - Pytorch',
   author = 'Phil Wang',
diff --git a/vit_pytorch/rvt.py b/vit_pytorch/rvt.py
index 9c70e16..42dcd09 100644
--- a/vit_pytorch/rvt.py
+++ b/vit_pytorch/rvt.py
@@ -94,7 +94,7 @@ class FeedForward(nn.Module):
         return self.net(x)
 
 class Attention(nn.Module):
-    def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0., use_rotary = True, conv_query_kernel = 5):
+    def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0., use_rotary = True, use_ds_conv = True, conv_query_kernel = 5):
         super().__init__()
         inner_dim = dim_head * heads
         self.use_rotary = use_rotary
@@ -103,7 +103,8 @@ class Attention(nn.Module):
 
         self.attend = nn.Softmax(dim = -1)
 
-        self.to_q = SpatialConv(dim, inner_dim, conv_query_kernel, bias = False)
+        self.use_ds_conv = use_ds_conv
+        self.to_q = SpatialConv(dim, inner_dim, conv_query_kernel, bias = False) if use_ds_conv else nn.Linear(dim, inner_dim, bias = False)
 
         self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False)
 
@@ -115,7 +116,8 @@ class Attention(nn.Module):
     def forward(self, x, pos_emb, fmap_dims):
         b, n, _, h = *x.shape, self.heads
 
-        q = self.to_q(x, fmap_dims = fmap_dims)
+        q = self.to_q(x, fmap_dims = fmap_dims) if self.use_ds_conv else self.to_q(x)
+
         qkv = (q, *self.to_kv(x).chunk(2, dim = -1))
         q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h = h), qkv)
 
@@ -141,13 +143,13 @@ class Attention(nn.Module):
         return self.to_out(out)
 
 class Transformer(nn.Module):
-    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0., use_rotary = True):
+    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0., use_rotary = True, use_ds_conv = True):
         super().__init__()
         self.layers = nn.ModuleList([])
         self.pos_emb = AxialRotaryEmbedding(dim_head)
         for _ in range(depth):
             self.layers.append(nn.ModuleList([
-                PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout, use_rotary = use_rotary)),
+                PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout, use_rotary = use_rotary, use_ds_conv = use_ds_conv)),
                 PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout))
             ]))
     def forward(self, x, fmap_dims):
@@ -161,7 +163,7 @@ class Transformer(nn.Module):
 # Rotary Vision Transformer
 
 class RvT(nn.Module):
-    def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0., use_rotary = True):
+    def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0., use_rotary = True, use_ds_conv = True):
         super().__init__()
         assert image_size % patch_size == 0, 'Image dimensions must be divisible by the patch size.'
         num_patches = (image_size // patch_size) ** 2
@@ -174,7 +176,7 @@ class RvT(nn.Module):
         )
 
         self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
-        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout, use_rotary)
+        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout, use_rotary, use_ds_conv)
 
         self.mlp_head = nn.Sequential(
             nn.LayerNorm(dim),
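
A minimal usage sketch of the new use_ds_conv flag introduced above, not taken from this commit: the import path is assumed from the repo layout (vit_pytorch/rvt.py), and the hyperparameter values and the model(img) forward call follow the usual vit-pytorch convention rather than anything shown in the diff.

# Illustrative only: exercises RvT with the new use_ds_conv flag.
# Values below are assumptions, not part of the diff.
import torch
from vit_pytorch.rvt import RvT

model = RvT(
    image_size = 256,
    patch_size = 32,
    num_classes = 1000,
    dim = 512,
    depth = 6,
    heads = 8,
    mlp_dim = 1024,
    use_rotary = True,     # axial rotary embeddings, unchanged by this commit
    use_ds_conv = False    # new flag: query projection falls back to nn.Linear instead of SpatialConv
)

img = torch.randn(1, 3, 256, 256)
preds = model(img)         # expected shape: (1, 1000)

Leaving use_ds_conv at its default of True keeps the previous behavior (SpatialConv query projection), so existing callers are unaffected; passing False threads the flag from RvT through Transformer into each Attention layer, which then builds and calls a plain linear query projection.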