go all the way with the normalized vit, fix some scales

Phil Wang
2024-10-10 10:42:37 -07:00
parent 1d1a63fc5c
commit 36ddc7a6ba
2 changed files with 12 additions and 11 deletions

setup.py

@@ -6,7 +6,7 @@ with open('README.md') as f:
 setup(
   name = 'vit-pytorch',
   packages = find_packages(exclude=['examples']),
-  version = '1.8.2',
+  version = '1.8.4',
   license='MIT',
   description = 'Vision Transformer (ViT) - Pytorch',
   long_description=long_description,

vit_pytorch/normalized_vit.py

@@ -179,18 +179,18 @@ class nViT(Module):
         self.to_patch_embedding = nn.Sequential(
             Rearrange('b c (h p1) (w p2) -> b (h w) (c p1 p2)', p1 = patch_size, p2 = patch_size),
-            nn.LayerNorm(patch_dim),
-            nn.Linear(patch_dim, dim),
-            nn.LayerNorm(dim),
+            NormLinear(patch_dim, dim, norm_dim_in = False),
         )

-        self.abs_pos_emb = nn.Embedding(num_patches, dim)
+        self.abs_pos_emb = NormLinear(dim, num_patches)

         residual_lerp_scale_init = default(residual_lerp_scale_init, 1. / depth)

         # layers

         self.dim = dim
+        self.scale = dim ** 0.5

         self.layers = ModuleList([])
         self.residual_lerp_scales = nn.ParameterList([])
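Taken together, the two-step patch embedding (LayerNorm → Linear → LayerNorm) collapses into a single NormLinear, the absolute positional embedding table becomes a NormLinear as well, and the norm scale dim ** 0.5 is cached once as self.scale. The NormLinear definition itself is not part of this diff; below is a minimal sketch of what such a layer can look like, assuming it simply re-normalizes its weight on each use (the actual implementation may differ, e.g. by registering a weight parametrization):

import torch
from torch import nn
import torch.nn.functional as F

# Sketch only: a bias-free linear layer whose weight is L2-normalized on every
# access. `norm_dim_in` chooses which dimension of the (dim_out, dim) weight is
# normalized: dim -1 normalizes over the input features, dim 0 over the outputs.
class NormLinearSketch(nn.Module):
    def __init__(self, dim, dim_out, norm_dim_in = True):
        super().__init__()
        self.linear = nn.Linear(dim, dim_out, bias = False)
        self.norm_dim = -1 if norm_dim_in else 0

    @property
    def weight(self):
        return F.normalize(self.linear.weight, dim = self.norm_dim)

    def forward(self, x):
        return F.linear(x, self.weight)

With norm_dim_in = False, as in the patch embedding above, each input feature is mapped through a unit-norm column in the model dimension; with the default norm_dim_in = True, as in abs_pos_emb = NormLinear(dim, num_patches), each of the num_patches rows of .weight is a unit vector of size dim, which is what lets the forward pass index .weight directly for positional embeddings.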
@@ -201,8 +201,8 @@ class nViT(Module):
             ]))

             self.residual_lerp_scales.append(nn.ParameterList([
-                nn.Parameter(torch.ones(dim) * residual_lerp_scale_init),
-                nn.Parameter(torch.ones(dim) * residual_lerp_scale_init),
+                nn.Parameter(torch.ones(dim) * residual_lerp_scale_init / self.scale),
+                nn.Parameter(torch.ones(dim) * residual_lerp_scale_init / self.scale),
             ]))

         self.logit_scale = nn.Parameter(torch.ones(num_classes))
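The per-layer residual interpolation scales are now stored divided by self.scale and multiplied back by the same factor at the point of use (see the forward pass below), so the effective interpolation weight at initialization is unchanged while the raw parameter lives at a much smaller magnitude, presumably to change how large an optimizer update looks relative to the parameter. A quick numeric check with hypothetical sizes:

# hypothetical sizes, only to show stored vs. effective lerp scale
dim, depth = 1024, 12

scale     = dim ** 0.5        # self.scale, about 32.0
init      = 1. / depth        # default residual_lerp_scale_init, about 0.0833
stored    = init / scale      # raw nn.Parameter value, about 0.0026
effective = stored * scale    # value used in tokens.lerp(...), equals init again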
@@ -225,22 +225,23 @@ class nViT(Module):
         tokens = self.to_patch_embedding(images)

-        pos_emb = self.abs_pos_emb(torch.arange(tokens.shape[-2], device = device))
+        seq_len = tokens.shape[-2]
+        pos_emb = self.abs_pos_emb.weight[torch.arange(seq_len, device = device)]

         tokens = l2norm(tokens + pos_emb)

         for (attn, ff), (attn_alpha, ff_alpha) in zip(self.layers, self.residual_lerp_scales):

             attn_out = l2norm(attn(tokens))
-            tokens = l2norm(tokens.lerp(attn_out, attn_alpha))
+            tokens = l2norm(tokens.lerp(attn_out, attn_alpha * self.scale))

             ff_out = l2norm(ff(tokens))
-            tokens = l2norm(tokens.lerp(ff_out, ff_alpha))
+            tokens = l2norm(tokens.lerp(ff_out, ff_alpha * self.scale))

         pooled = reduce(tokens, 'b n d -> b d', 'mean')

         logits = self.to_pred(pooled)
-        logits = logits * self.logit_scale * (self.dim ** 0.5)
+        logits = logits * self.logit_scale * self.scale

         return logits
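In the forward pass, positional embeddings are read straight from the NormLinear weight (whose rows are already unit-norm), each branch output is l2-normalized, the residual update interpolates with the rescaled alpha before renormalizing, and the final logit scaling reuses the cached self.scale instead of recomputing dim ** 0.5. A self-contained sketch of one such residual step, with an identity lambda standing in for the attention or feedforward branch:

import torch
import torch.nn.functional as F

def l2norm(t, dim = -1):
    return F.normalize(t, dim = dim)

# one normalized residual step, mirroring the loop above; `branch` is a stand-in
# for the attention or feedforward module
def normalized_residual_step(tokens, branch, alpha, scale):
    out = l2norm(branch(tokens))                      # branch output back on the unit sphere
    return l2norm(tokens.lerp(out, alpha * scale))    # interpolate toward it, then renormalize

dim = 64
tokens = l2norm(torch.randn(2, 16, dim))              # (batch, patches, dim), unit-norm tokens
alpha  = torch.full((dim,), (1. / 12) / dim ** 0.5)   # stored per-channel lerp scale, as initialized above
tokens = normalized_residual_step(tokens, lambda t: t, alpha, dim ** 0.5)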