diff --git a/vit_pytorch/xcit.py b/vit_pytorch/xcit.py index bc5a85c..83b098f 100644 --- a/vit_pytorch/xcit.py +++ b/vit_pytorch/xcit.py @@ -249,6 +249,9 @@ class XCiT(Module): self.dropout = nn.Dropout(emb_dropout) self.xcit_transformer = XCATransformer(dim, depth, heads, dim_head, mlp_dim, local_patch_kernel_size, dropout, layer_dropout) + + self.final_norm = nn.LayerNorm(dim) + self.cls_transformer = Transformer(dim, cls_depth, heads, dim_head, mlp_dim, dropout, layer_dropout) self.mlp_head = nn.Sequential( @@ -270,6 +273,8 @@ class XCiT(Module): x = self.xcit_transformer(x) + x = self.final_norm(x) + cls_tokens = repeat(self.cls_token, 'd -> b 1 d', b = b) x = rearrange(x, 'b ... d -> b (...) d')