From 60b5687a7997f41c855ebc78ff77040ac5da5b61 Mon Sep 17 00:00:00 2001 From: Phil Wang Date: Tue, 27 Apr 2021 11:45:46 -0700 Subject: [PATCH] cleanup rvt --- vit_pytorch/rvt.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/vit_pytorch/rvt.py b/vit_pytorch/rvt.py index eeb842b..cbc4599 100644 --- a/vit_pytorch/rvt.py +++ b/vit_pytorch/rvt.py @@ -119,7 +119,8 @@ class Attention(nn.Module): def forward(self, x, pos_emb, fmap_dims): b, n, _, h = *x.shape, self.heads - q = self.to_q(x, fmap_dims = fmap_dims) if self.use_ds_conv else self.to_q(x) + to_q_kwargs = {'fmap_dims': fmap_dims} if self.use_ds_conv else {} + q = self.to_q(x, **to_q_kwargs) qkv = (q, *self.to_kv(x).chunk(2, dim = -1))