mirror of
https://github.com/lucidrains/vit-pytorch.git
synced 2025-12-30 08:02:29 +00:00
cleanup rvt
This commit is contained in:
@@ -119,7 +119,8 @@ class Attention(nn.Module):
|
||||
def forward(self, x, pos_emb, fmap_dims):
|
||||
b, n, _, h = *x.shape, self.heads
|
||||
|
||||
q = self.to_q(x, fmap_dims = fmap_dims) if self.use_ds_conv else self.to_q(x)
|
||||
to_q_kwargs = {'fmap_dims': fmap_dims} if self.use_ds_conv else {}
|
||||
q = self.to_q(x, **to_q_kwargs)
|
||||
|
||||
qkv = (q, *self.to_kv(x).chunk(2, dim = -1))
|
||||
|
||||
|
||||
Reference in New Issue
Block a user