diff --git a/README.md b/README.md
index cef6ae7..e097d56 100644
--- a/README.md
+++ b/README.md
@@ -32,7 +32,9 @@ mask = torch.ones(1, 8, 8).bool() # optional mask, designating which patch to at
preds = v(img, mask = mask) # (1, 1000)
```
-## Suggestion
+## Research Ideas
+
+### Self Supervised Training
You can train this with a near SOTA self-supervised learning technique, BYOL, with the following code.
@@ -82,6 +84,43 @@ torch.save(model.state_dict(), './pretrained-net.pt')
A pytorch-lightning script is ready for you to use at the repository link above.
+### Efficient Attention
+
+Some coming from computer vision may think attention still suffers from quadratic costs. Fortunately, there are many new techniques that may help. This repository offers a way for you to plug in your own sparse attention transformer.
+
+An example with Linformer:
+
+```bash
+$ pip install linformer
+```
+
+```python
+import torch
+from vit_pytorch.efficient import ViT
+from linformer import Linformer
+
+efficient_transformer = Linformer(
+ dim = 512,
+ seq_len = 4096 + 1, # 64 x 64 patches + 1 cls token
+ depth = 12,
+ heads = 8,
+ k = 256
+)
+
+v = ViT(
+ dim = 512,
+ image_size = 2048,
+ patch_size = 32,
+ num_classes = 1000,
+ transformer = efficient_transformer
+)
+
+img = torch.randn(1, 3, 2048, 2048) # your high resolution picture
+v(img) # (1, 1000)
+```
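+
+The `transformer` keyword argument only needs to be a module that maps token embeddings of shape `(batch, seq_len, dim)` to the same shape. As a minimal sketch of that interface (not an efficient attention model, and assuming a PyTorch version where `nn.TransformerEncoderLayer` accepts `batch_first`), you could even drop in the built-in `nn.TransformerEncoder`:
+
+```python
+import torch
+from torch import nn
+from vit_pytorch.efficient import ViT
+
+# any module mapping (batch, seq_len, dim) -> (batch, seq_len, dim) can be plugged in
+encoder_layer = nn.TransformerEncoderLayer(d_model = 512, nhead = 8, batch_first = True)
+vanilla_transformer = nn.TransformerEncoder(encoder_layer, num_layers = 6)
+
+v = ViT(
+    dim = 512,                # must match d_model above
+    image_size = 256,
+    patch_size = 32,
+    num_classes = 1000,
+    transformer = vanilla_transformer
+)
+
+img = torch.randn(1, 3, 256, 256)
+v(img) # (1, 1000)
+```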
+
+Other sparse attention frameworks I would highly recommend are the Routing Transformer and the Sinkhorn Transformer.
+
## Citations
```bibtex
diff --git a/setup.py b/setup.py
index 127b161..97a75df 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
setup(
name = 'vit-pytorch',
packages = find_packages(),
- version = '0.0.5',
+ version = '0.1.0',
license='MIT',
description = 'Vision Transformer (ViT) - Pytorch',
author = 'Phil Wang',
diff --git a/vit_pytorch/efficient.py b/vit_pytorch/efficient.py
new file mode 100644
index 0000000..706fb09
--- /dev/null
+++ b/vit_pytorch/efficient.py
@@ -0,0 +1,39 @@
+import torch
+from einops import rearrange
+from torch import nn
+
+class ViT(nn.Module):
+ def __init__(self, *, image_size, patch_size, num_classes, dim, transformer, channels = 3):
+ super().__init__()
+ assert image_size % patch_size == 0, 'image dimensions must be divisible by the patch size'
+ num_patches = (image_size // patch_size) ** 2
+ patch_dim = channels * patch_size ** 2
+
+ self.patch_size = patch_size
+
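+        # learned positional embeddings (one per patch plus the cls token), patch projection, and cls token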
+ self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
+ self.patch_to_embedding = nn.Linear(patch_dim, dim)
+ self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
+ self.transformer = transformer
+
+ self.to_cls_token = nn.Identity()
+
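+        # two-layer MLP head mapping the cls token embedding to class logits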
+ self.mlp_head = nn.Sequential(
+ nn.Linear(dim, dim * 4),
+ nn.GELU(),
+ nn.Linear(dim * 4, num_classes)
+ )
+
+ def forward(self, img):
+ p = self.patch_size
+
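+        # (b, c, h, w) -> (b, num_patches, patch_dim): flatten each p x p patch into a single vector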
+ x = rearrange(img, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = p, p2 = p)
+ x = self.patch_to_embedding(x)
+
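+        # prepend the cls token, add positional embeddings, then run the plugged-in transformer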
+ cls_tokens = self.cls_token.expand(img.shape[0], -1, -1)
+ x = torch.cat((cls_tokens, x), dim=1)
+ x += self.pos_embedding
+ x = self.transformer(x)
+
+ x = self.to_cls_token(x[:, 0])
+ return self.mlp_head(x)