diff --git a/README.md b/README.md index 604e384..1bc258e 100644 --- a/README.md +++ b/README.md @@ -2020,4 +2020,13 @@ Coming from computer vision and new to transformers? Here are some resources tha } ``` +```bibtex +@inproceedings{Darcet2023VisionTN, + title = {Vision Transformers Need Registers}, + author = {Timoth'ee Darcet and Maxime Oquab and Julien Mairal and Piotr Bojanowski}, + year = {2023}, + url = {https://api.semanticscholar.org/CorpusID:263134283} +} +``` + *I visualise a time when we will be to robots what dogs are to humans, and I’m rooting for the machines.* — Claude Shannon diff --git a/setup.py b/setup.py index c72e5a5..5fd08f3 100644 --- a/setup.py +++ b/setup.py @@ -3,7 +3,7 @@ from setuptools import setup, find_packages setup( name = 'vit-pytorch', packages = find_packages(exclude=['examples']), - version = '1.4.5 ', + version = '1.5.0 ', license='MIT', description = 'Vision Transformer (ViT) - Pytorch', long_description_content_type = 'text/markdown', diff --git a/vit_pytorch/simple_vit_with_register_tokens.py b/vit_pytorch/simple_vit_with_register_tokens.py new file mode 100644 index 0000000..6b5ab44 --- /dev/null +++ b/vit_pytorch/simple_vit_with_register_tokens.py @@ -0,0 +1,129 @@ +import torch +from torch import nn + +from einops import rearrange, repeat, pack, unpack +from einops.layers.torch import Rearrange + +# helpers + +def pair(t): + return t if isinstance(t, tuple) else (t, t) + +def posemb_sincos_2d(h, w, dim, temperature: int = 10000, dtype = torch.float32): + y, x = torch.meshgrid(torch.arange(h), torch.arange(w), indexing="ij") + assert (dim % 4) == 0, "feature dimension must be multiple of 4 for sincos emb" + omega = torch.arange(dim // 4) / (dim // 4 - 1) + omega = 1.0 / (temperature ** omega) + + y = y.flatten()[:, None] * omega[None, :] + x = x.flatten()[:, None] * omega[None, :] + pe = torch.cat((x.sin(), x.cos(), y.sin(), y.cos()), dim=1) + return pe.type(dtype) + +# classes + +class FeedForward(nn.Module): + def __init__(self, dim, hidden_dim): + super().__init__() + self.net = nn.Sequential( + nn.LayerNorm(dim), + nn.Linear(dim, hidden_dim), + nn.GELU(), + nn.Linear(hidden_dim, dim), + ) + def forward(self, x): + return self.net(x) + +class Attention(nn.Module): + def __init__(self, dim, heads = 8, dim_head = 64): + super().__init__() + inner_dim = dim_head * heads + self.heads = heads + self.scale = dim_head ** -0.5 + self.norm = nn.LayerNorm(dim) + + self.attend = nn.Softmax(dim = -1) + + self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False) + self.to_out = nn.Linear(inner_dim, dim, bias = False) + + def forward(self, x): + x = self.norm(x) + + qkv = self.to_qkv(x).chunk(3, dim = -1) + q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv) + + dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale + + attn = self.attend(dots) + + out = torch.matmul(attn, v) + out = rearrange(out, 'b h n d -> b n (h d)') + return self.to_out(out) + +class Transformer(nn.Module): + def __init__(self, dim, depth, heads, dim_head, mlp_dim): + super().__init__() + self.norm = nn.LayerNorm(dim) + self.layers = nn.ModuleList([]) + for _ in range(depth): + self.layers.append(nn.ModuleList([ + Attention(dim, heads = heads, dim_head = dim_head), + FeedForward(dim, mlp_dim) + ])) + def forward(self, x): + for attn, ff in self.layers: + x = attn(x) + x + x = ff(x) + x + return self.norm(x) + +class SimpleViT(nn.Module): + def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, num_register_tokens = 4, channels = 3, dim_head = 64): + super().__init__() + image_height, image_width = pair(image_size) + patch_height, patch_width = pair(patch_size) + + assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.' + + patch_dim = channels * patch_height * patch_width + + self.to_patch_embedding = nn.Sequential( + Rearrange("b c (h p1) (w p2) -> b (h w) (p1 p2 c)", p1 = patch_height, p2 = patch_width), + nn.LayerNorm(patch_dim), + nn.Linear(patch_dim, dim), + nn.LayerNorm(dim), + ) + + self.register_tokens = nn.Parameter(torch.randn(num_register_tokens, dim)) + + self.pos_embedding = posemb_sincos_2d( + h = image_height // patch_height, + w = image_width // patch_width, + dim = dim, + ) + + self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim) + + self.pool = "mean" + self.to_latent = nn.Identity() + + self.linear_head = nn.Linear(dim, num_classes) + + def forward(self, img): + batch, device = img.shape[0], img.device + + x = self.to_patch_embedding(img) + x += self.pos_embedding.to(device, dtype=x.dtype) + + r = repeat(self.register_tokens, 'n d -> b n d', b = batch) + + x, ps = pack([x, r], 'b * d') + + x = self.transformer(x) + + x, _ = unpack(x, ps, 'b * d') + + x = x.mean(dim = 1) + + x = self.to_latent(x) + return self.linear_head(x)