From b298031c1768b8566f73999d3635716997dc134d Mon Sep 17 00:00:00 2001 From: Phil Wang Date: Wed, 7 Oct 2020 19:15:21 -0700 Subject: [PATCH] write up example for using efficient transformers --- README.md | 41 +++++++++++++++++++++++++++++++++++++++- setup.py | 2 +- vit_pytorch/efficient.py | 39 ++++++++++++++++++++++++++++++++++++++ 3 files changed, 80 insertions(+), 2 deletions(-) create mode 100644 vit_pytorch/efficient.py diff --git a/README.md b/README.md index cef6ae7..e097d56 100644 --- a/README.md +++ b/README.md @@ -32,7 +32,9 @@ mask = torch.ones(1, 8, 8).bool() # optional mask, designating which patch to at preds = v(img, mask = mask) # (1, 1000) ``` -## Suggestion +## Research Ideas + +### Self Supervised Training You can train this with a near SOTA self-supervised learning technique, BYOL, with the following code. @@ -82,6 +84,43 @@ torch.save(model.state_dict(), './pretrained-net.pt') A pytorch-lightning script is ready for you to use at the repository link above. +### Efficient Attention + +There may be some coming from computer vision who think attention still suffers from quadratic costs. Fortunately, we have a lot of new techniques that may help. This repository offers a way for you to plugin your own sparse attention transformer. + +An example with Linformer + +```bash +$ pip install linformer +``` + +```python +import torch +from vit_pytorch.efficient import ViT +from linformer import Linformer + +efficient_transformer = Linformer( + dim = 512, + seq_len = 4096 + 1, # 64 x 64 patches + 1 cls token + depth = 12, + heads = 8, + k = 256 +) + +v = ViT( + dim = 512, + image_size = 2048, + patch_size = 32, + num_classes = 1000, + transformer = efficient_transformer +) + +img = torch.randn(1, 3, 2048, 2048) # your high resolution picture +v(img) # (1, 1000) +``` + +Other sparse attention frameworks I would highly recommend is Routing Transformer or Sinkhorn Transformer + ## Citations ```bibtex diff --git a/setup.py b/setup.py index 127b161..97a75df 100644 --- a/setup.py +++ b/setup.py @@ -3,7 +3,7 @@ from setuptools import setup, find_packages setup( name = 'vit-pytorch', packages = find_packages(), - version = '0.0.5', + version = '0.1.0', license='MIT', description = 'Vision Transformer (ViT) - Pytorch', author = 'Phil Wang', diff --git a/vit_pytorch/efficient.py b/vit_pytorch/efficient.py new file mode 100644 index 0000000..706fb09 --- /dev/null +++ b/vit_pytorch/efficient.py @@ -0,0 +1,39 @@ +import torch +from einops import rearrange +from torch import nn + +class ViT(nn.Module): + def __init__(self, *, image_size, patch_size, num_classes, dim, transformer, channels = 3): + super().__init__() + assert image_size % patch_size == 0, 'image dimensions must be divisible by the patch size' + num_patches = (image_size // patch_size) ** 2 + patch_dim = channels * patch_size ** 2 + + self.patch_size = patch_size + + self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim)) + self.patch_to_embedding = nn.Linear(patch_dim, dim) + self.cls_token = nn.Parameter(torch.randn(1, 1, dim)) + self.transformer = transformer + + self.to_cls_token = nn.Identity() + + self.mlp_head = nn.Sequential( + nn.Linear(dim, dim * 4), + nn.GELU(), + nn.Linear(dim * 4, num_classes) + ) + + def forward(self, img): + p = self.patch_size + + x = rearrange(img, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = p, p2 = p) + x = self.patch_to_embedding(x) + + cls_tokens = self.cls_token.expand(img.shape[0], -1, -1) + x = torch.cat((cls_tokens, x), dim=1) + x += self.pos_embedding + x = self.transformer(x) + + x = self.to_cls_token(x[:, 0]) + return self.mlp_head(x)