diff --git a/vit_pytorch/rvt.py b/vit_pytorch/rvt.py index 01085d8..c6b038e 100644 --- a/vit_pytorch/rvt.py +++ b/vit_pytorch/rvt.py @@ -94,7 +94,7 @@ class FeedForward(nn.Module): return self.net(x) class Attention(nn.Module): - def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0., conv_query_kernel = 9): + def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0., conv_query_kernel = 5): super().__init__() inner_dim = dim_head * heads