give cross correlation transformer a final norm at end

This commit is contained in:
lucidrains
2023-10-12 19:51:07 -07:00
parent bcfb0f054a
commit d9679d3e26

View File

@@ -249,6 +249,9 @@ class XCiT(Module):
self.dropout = nn.Dropout(emb_dropout)
self.xcit_transformer = XCATransformer(dim, depth, heads, dim_head, mlp_dim, local_patch_kernel_size, dropout, layer_dropout)
self.final_norm = nn.LayerNorm(dim)
self.cls_transformer = Transformer(dim, cls_depth, heads, dim_head, mlp_dim, dropout, layer_dropout)
self.mlp_head = nn.Sequential(
@@ -270,6 +273,8 @@ class XCiT(Module):
x = self.xcit_transformer(x)
x = self.final_norm(x)
cls_tokens = repeat(self.cls_token, 'd -> b 1 d', b = b)
x = rearrange(x, 'b ... d -> b (...) d')