mirror of
https://github.com/lucidrains/vit-pytorch.git
synced 2025-12-30 08:02:29 +00:00
Compare commits
2 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
a0fa41070f | ||
|
|
b298031c17 |
41
README.md
41
README.md
@@ -32,7 +32,9 @@ mask = torch.ones(1, 8, 8).bool() # optional mask, designating which patch to at
|
||||
preds = v(img, mask = mask) # (1, 1000)
|
||||
```
|
||||
|
||||
## Suggestion
|
||||
## Research Ideas
|
||||
|
||||
### Self Supervised Training
|
||||
|
||||
You can train this with a near SOTA self-supervised learning technique, <a href="https://github.com/lucidrains/byol-pytorch">BYOL</a>, with the following code.
|
||||
|
||||
@@ -82,6 +84,43 @@ torch.save(model.state_dict(), './pretrained-net.pt')
|
||||
|
||||
A pytorch-lightning script is ready for you to use at the repository link above.
|
||||
|
||||
### Efficient Attention
|
||||
|
||||
There may be some coming from computer vision who think attention still suffers from quadratic costs. Fortunately, we have a lot of new techniques that may help. This repository offers a way for you to plugin your own sparse attention transformer.
|
||||
|
||||
An example with <a href="https://arxiv.org/abs/2006.04768">Linformer</a>
|
||||
|
||||
```bash
|
||||
$ pip install linformer
|
||||
```
|
||||
|
||||
```python
|
||||
import torch
|
||||
from vit_pytorch.efficient import ViT
|
||||
from linformer import Linformer
|
||||
|
||||
efficient_transformer = Linformer(
|
||||
dim = 512,
|
||||
seq_len = 4096 + 1, # 64 x 64 patches + 1 cls token
|
||||
depth = 12,
|
||||
heads = 8,
|
||||
k = 256
|
||||
)
|
||||
|
||||
v = ViT(
|
||||
dim = 512,
|
||||
image_size = 2048,
|
||||
patch_size = 32,
|
||||
num_classes = 1000,
|
||||
transformer = efficient_transformer
|
||||
)
|
||||
|
||||
img = torch.randn(1, 3, 2048, 2048) # your high resolution picture
|
||||
v(img) # (1, 1000)
|
||||
```
|
||||
|
||||
Other sparse attention frameworks I would highly recommend is <a href="https://github.com/lucidrains/routing-transformer">Routing Transformer</a> or <a href="https://github.com/lucidrains/sinkhorn-transformer">Sinkhorn Transformer</a>
|
||||
|
||||
## Citations
|
||||
|
||||
```bibtex
|
||||
|
||||
2
setup.py
2
setup.py
@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
|
||||
setup(
|
||||
name = 'vit-pytorch',
|
||||
packages = find_packages(),
|
||||
version = '0.0.5',
|
||||
version = '0.2.0',
|
||||
license='MIT',
|
||||
description = 'Vision Transformer (ViT) - Pytorch',
|
||||
author = 'Phil Wang',
|
||||
|
||||
40
vit_pytorch/efficient.py
Normal file
40
vit_pytorch/efficient.py
Normal file
@@ -0,0 +1,40 @@
|
||||
import torch
|
||||
from einops import rearrange
|
||||
from torch import nn
|
||||
|
||||
class ViT(nn.Module):
|
||||
def __init__(self, *, image_size, patch_size, num_classes, dim, transformer, channels = 3):
|
||||
super().__init__()
|
||||
assert image_size % patch_size == 0, 'image dimensions must be divisible by the patch size'
|
||||
num_patches = (image_size // patch_size) ** 2
|
||||
patch_dim = channels * patch_size ** 2
|
||||
|
||||
self.patch_size = patch_size
|
||||
|
||||
self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
|
||||
self.patch_to_embedding = nn.Linear(patch_dim, dim)
|
||||
self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
|
||||
self.transformer = transformer
|
||||
|
||||
self.to_cls_token = nn.Identity()
|
||||
|
||||
self.mlp_head = nn.Sequential(
|
||||
nn.LayerNorm(dim),
|
||||
nn.Linear(dim, dim * 4),
|
||||
nn.GELU(),
|
||||
nn.Linear(dim * 4, num_classes)
|
||||
)
|
||||
|
||||
def forward(self, img):
|
||||
p = self.patch_size
|
||||
|
||||
x = rearrange(img, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = p, p2 = p)
|
||||
x = self.patch_to_embedding(x)
|
||||
|
||||
cls_tokens = self.cls_token.expand(img.shape[0], -1, -1)
|
||||
x = torch.cat((cls_tokens, x), dim=1)
|
||||
x += self.pos_embedding
|
||||
x = self.transformer(x)
|
||||
|
||||
x = self.to_cls_token(x[:, 0])
|
||||
return self.mlp_head(x)
|
||||
@@ -90,6 +90,7 @@ class ViT(nn.Module):
|
||||
self.to_cls_token = nn.Identity()
|
||||
|
||||
self.mlp_head = nn.Sequential(
|
||||
nn.LayerNorm(dim),
|
||||
nn.Linear(dim, mlp_dim),
|
||||
nn.GELU(),
|
||||
nn.Linear(mlp_dim, num_classes)
|
||||
|
||||
Reference in New Issue
Block a user