Compare commits

...

3 Commits
1.8.1 ... 1.8.3

Author SHA1 Message Date
Phil Wang
5f85d7b987 go all the way with the normalized vit, fix some scales 2024-10-10 10:40:32 -07:00
Phil Wang
1d1a63fc5c cite for hypersphere vit adapted from ngpt 2024-10-10 10:15:04 -07:00
Phil Wang
74b62009f8 go for multi-headed rmsnorm for the qknorm on hypersphere vit 2024-10-10 08:09:58 -07:00
3 changed files with 27 additions and 13 deletions

View File

@@ -2142,4 +2142,14 @@ Coming from computer vision and new to transformers? Here are some resources tha
}
```
```bibtex
@inproceedings{Liu2017DeepHL,
title = {Deep Hyperspherical Learning},
author = {Weiyang Liu and Yanming Zhang and Xingguo Li and Zhen Liu and Bo Dai and Tuo Zhao and Le Song},
booktitle = {Neural Information Processing Systems},
year = {2017},
url = {https://api.semanticscholar.org/CorpusID:5104558}
}
```
*I visualise a time when we will be to robots what dogs are to humans, and Im rooting for the machines.* — Claude Shannon

View File

@@ -6,7 +6,7 @@ with open('README.md') as f:
setup(
name = 'vit-pytorch',
packages = find_packages(exclude=['examples']),
version = '1.8.1',
version = '1.8.3',
license='MIT',
description = 'Vision Transformer (ViT) - Pytorch',
long_description=long_description,

View File

@@ -76,7 +76,8 @@ class Attention(Module):
self.dropout = dropout
self.qk_scale = nn.Parameter(torch.ones(dim_head) * (dim_head ** 0.25))
self.q_scale = nn.Parameter(torch.ones(dim_inner) * (dim_head ** 0.25))
self.k_scale = nn.Parameter(torch.ones(dim_inner) * (dim_head ** 0.25))
self.split_heads = Rearrange('b n (h d) -> b h n d', h = heads)
self.merge_heads = Rearrange('b h n d -> b n (h d)')
@@ -89,12 +90,14 @@ class Attention(Module):
):
q, k, v = self.to_q(x), self.to_k(x), self.to_v(x)
q = q * self.q_scale
k = k * self.k_scale
q, k, v = map(self.split_heads, (q, k, v))
# query key rmsnorm
q, k = map(l2norm, (q, k))
q, k = (q * self.qk_scale), (k * self.qk_scale)
# scale is 1., as scaling factor is moved to s_qk (dk ^ 0.25) - eq. 16
@@ -176,18 +179,18 @@ class nViT(Module):
self.to_patch_embedding = nn.Sequential(
Rearrange('b c (h p1) (w p2) -> b (h w) (c p1 p2)', p1 = patch_size, p2 = patch_size),
nn.LayerNorm(patch_dim),
nn.Linear(patch_dim, dim),
nn.LayerNorm(dim),
NormLinear(patch_dim, dim),
)
self.abs_pos_emb = nn.Embedding(num_patches, dim)
self.abs_pos_emb = NormLinear(dim, num_patches)
residual_lerp_scale_init = default(residual_lerp_scale_init, 1. / depth)
# layers
self.dim = dim
self.scale = dim ** 0.5
self.layers = ModuleList([])
self.residual_lerp_scales = nn.ParameterList([])
@@ -198,8 +201,8 @@ class nViT(Module):
]))
self.residual_lerp_scales.append(nn.ParameterList([
nn.Parameter(torch.ones(dim) * residual_lerp_scale_init),
nn.Parameter(torch.ones(dim) * residual_lerp_scale_init),
nn.Parameter(torch.ones(dim) * residual_lerp_scale_init / self.scale),
nn.Parameter(torch.ones(dim) * residual_lerp_scale_init / self.scale),
]))
self.logit_scale = nn.Parameter(torch.ones(num_classes))
@@ -222,22 +225,23 @@ class nViT(Module):
tokens = self.to_patch_embedding(images)
pos_emb = self.abs_pos_emb(torch.arange(tokens.shape[-2], device = device))
seq_len = tokens.shape[-2]
pos_emb = self.abs_pos_emb.weight[torch.arange(seq_len, device = device)]
tokens = l2norm(tokens + pos_emb)
for (attn, ff), (attn_alpha, ff_alpha) in zip(self.layers, self.residual_lerp_scales):
attn_out = l2norm(attn(tokens))
tokens = l2norm(tokens.lerp(attn_out, attn_alpha))
tokens = l2norm(tokens.lerp(attn_out, attn_alpha * self.scale))
ff_out = l2norm(ff(tokens))
tokens = l2norm(tokens.lerp(ff_out, ff_alpha))
tokens = l2norm(tokens.lerp(ff_out, ff_alpha * self.scale))
pooled = reduce(tokens, 'b n d -> b d', 'mean')
logits = self.to_pred(pooled)
logits = logits * self.logit_scale * (self.dim ** 0.5)
logits = logits * self.logit_scale * self.scale
return logits