diff --git a/vit_pytorch/t2t.py b/vit_pytorch/t2t.py index c5582df..0ccc7d7 100644 --- a/vit_pytorch/t2t.py +++ b/vit_pytorch/t2t.py @@ -72,7 +72,7 @@ class T2TViT(nn.Module): cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b) x = torch.cat((cls_tokens, x), dim=1) - x += self.pos_embedding + x += self.pos_embedding[:, :n+1] x = self.dropout(x) x = self.transformer(x)