Compare commits

...

1 Commits
0.2.0 ... 0.1.0

Author SHA1 Message Date
Phil Wang
ee1cbbadec write up example for using efficient transformers 2020-10-07 19:13:25 -07:00
3 changed files with 78 additions and 2 deletions

View File

@@ -32,7 +32,7 @@ mask = torch.ones(1, 8, 8).bool() # optional mask, designating which patch to at
preds = v(img, mask = mask) # (1, 1000)
```
## Suggestion
## Suggestion - Self Supervised Training
You can train this with a near SOTA self-supervised learning technique, <a href="https://github.com/lucidrains/byol-pytorch">BYOL</a>, with the following code.
@@ -82,6 +82,43 @@ torch.save(model.state_dict(), './pretrained-net.pt')
A pytorch-lightning script is ready for you to use at the repository link above.
## Suggestion 2 - Efficient Transformers
There may be some coming from computer vision who think attention still suffers from quadratic costs. Fortunately, we have a lot of new techniques that may help. This repository offers a way for you to plugin your own sparse attention transformer.
An example with <a href="https://arxiv.org/abs/2006.04768">Linformer</a>
```bash
$ pip install linformer
```
```python
import torch
from vit_pytorch.efficient import ViT
from linformer import Linformer
efficient_transformer = Linformer(
dim = 512,
seq_len = 4096 + 1, # 64 x 64 patches + 1 cls token
depth = 12,
heads = 8,
k = 256
)
v = ViT(
dim = 512,
image_size = 2048,
patch_size = 32,
num_classes = 1000,
transformer = efficient_transformer
)
img = torch.randn(1, 3, 2048, 2048) # your high resolution picture
v(img) # (1, 1000)
```
Other sparse attention framework I would highly recommend is <a href="https://github.com/lucidrains/routing-transformer">Routing Transformer</a> or <a href="https://github.com/lucidrains/sinkhorn-transformer">Sinkhorn Transformer</a>
## Citations
```bibtex

View File

@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
setup(
name = 'vit-pytorch',
packages = find_packages(),
version = '0.0.5',
version = '0.1.0',
license='MIT',
description = 'Vision Transformer (ViT) - Pytorch',
author = 'Phil Wang',

39
vit_pytorch/efficient.py Normal file
View File

@@ -0,0 +1,39 @@
import torch
from einops import rearrange
from torch import nn
class ViT(nn.Module):
def __init__(self, *, image_size, patch_size, num_classes, dim, transformer, channels = 3):
super().__init__()
assert image_size % patch_size == 0, 'image dimensions must be divisible by the patch size'
num_patches = (image_size // patch_size) ** 2
patch_dim = channels * patch_size ** 2
self.patch_size = patch_size
self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
self.patch_to_embedding = nn.Linear(patch_dim, dim)
self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
self.transformer = transformer
self.to_cls_token = nn.Identity()
self.mlp_head = nn.Sequential(
nn.Linear(dim, dim * 4),
nn.GELU(),
nn.Linear(dim * 4, num_classes)
)
def forward(self, img):
p = self.patch_size
x = rearrange(img, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = p, p2 = p)
x = self.patch_to_embedding(x)
cls_tokens = self.cls_token.expand(img.shape[0], -1, -1)
x = torch.cat((cls_tokens, x), dim=1)
x += self.pos_embedding
x = self.transformer(x)
x = self.to_cls_token(x[:, 0])
return self.mlp_head(x)