mirror of https://github.com/lucidrains/vit-pytorch.git, synced 2025-12-30 08:02:29 +00:00
give cross correlation transformer a final norm at end
@@ -249,6 +249,9 @@ class XCiT(Module):
         self.dropout = nn.Dropout(emb_dropout)
 
         self.xcit_transformer = XCATransformer(dim, depth, heads, dim_head, mlp_dim, local_patch_kernel_size, dropout, layer_dropout)
+
+        self.final_norm = nn.LayerNorm(dim)
+
         self.cls_transformer = Transformer(dim, cls_depth, heads, dim_head, mlp_dim, dropout, layer_dropout)
 
         self.mlp_head = nn.Sequential(
@@ -270,6 +273,8 @@ class XCiT(Module):
 
         x = self.xcit_transformer(x)
 
+        x = self.final_norm(x)
+
         cls_tokens = repeat(self.cls_token, 'd -> b 1 d', b = b)
 
         x = rearrange(x, 'b ... d -> b (...) d')
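For context, a minimal sketch of where the new final_norm sits: the output of the cross-covariance (XCA) transformer is layer-normalized before the cls token is broadcast and the class-attention transformer runs. The tensor sizes and the random stand-in input below are illustrative only; final_norm, nn.LayerNorm(dim), and the repeat() pattern come from the commit, everything else is assumed for the example.

import torch
import torch.nn as nn
from einops import repeat

# illustrative sizes, not taken from the repository
batch, num_patches, dim = 2, 64, 256

# stand-ins for the XCA transformer output and the learned cls token
x = torch.randn(batch, num_patches, dim)
cls_token = nn.Parameter(torch.randn(dim))

# the commit's addition: normalize the cross-covariance transformer output
# before handing it to the class-attention transformer
final_norm = nn.LayerNorm(dim)
x = final_norm(x)

# as in the existing forward pass: broadcast the cls token per batch element
cls_tokens = repeat(cls_token, 'd -> b 1 d', b = batch)

print(x.shape, cls_tokens.shape)  # torch.Size([2, 64, 256]) torch.Size([2, 1, 256])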