cleanup rvt

This commit is contained in:
Phil Wang
2021-04-27 11:45:46 -07:00
parent 0df1505662
commit 60b5687a79

View File

@@ -119,7 +119,8 @@ class Attention(nn.Module):
def forward(self, x, pos_emb, fmap_dims):
b, n, _, h = *x.shape, self.heads
q = self.to_q(x, fmap_dims = fmap_dims) if self.use_ds_conv else self.to_q(x)
to_q_kwargs = {'fmap_dims': fmap_dims} if self.use_ds_conv else {}
q = self.to_q(x, **to_q_kwargs)
qkv = (q, *self.to_kv(x).chunk(2, dim = -1))