mirror of https://github.com/lucidrains/vit-pytorch.git
synced 2025-12-30 16:12:29 +00:00

Compare commits

1 commit

| Author | SHA1 | Date |
|---|---|---|
|  | 0b347973fe |  |

README.md (75 lines changed)

```diff
@@ -32,6 +32,7 @@
 - [Parallel ViT](#parallel-vit)
 - [Learnable Memory ViT](#learnable-memory-vit)
 - [Dino](#dino)
+- [EsViT](#esvit)
 - [Accessing Attention](#accessing-attention)
 - [Research Ideas](#research-ideas)
   * [Efficient Attention](#efficient-attention)
```

@@ -1076,6 +1077,80 @@ for _ in range(100):

```python
torch.save(model.state_dict(), './pretrained-net.pt')
```

## EsViT

<img src="./images/esvit.png" width="350px"></img>

`EsViT` is a variant of Dino (from above) re-engineered to support efficient `ViT`s with patch merging / downsampling, by taking into account an extra regional loss between the augmented views. To quote the abstract, it `outperforms its supervised counterpart on 17 out of 18 datasets` at 3 times higher throughput.

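
To make the objective concrete, here is a minimal sketch of the idea rather than the internals of `es_vit.py`: the Dino cross entropy is applied once at the view level (pooled features) and once between best-matching region pairs. The matching rule below (argmax of cosine similarity on the projected outputs) is a simplification of the paper's formulation, and all names and shapes are illustrative.

```python
import torch
import torch.nn.functional as F

def dino_ce(teacher_logits, student_logits, teacher_temp, student_temp, center = 0.):
    # cross entropy between the centered, sharpened teacher distribution
    # and the student distribution (teacher side is detached)
    t = F.softmax((teacher_logits - center) / teacher_temp, dim = -1).detach()
    s = F.log_softmax(student_logits / student_temp, dim = -1)
    return -(t * s).sum(dim = -1).mean()

def esvit_loss(student_view, teacher_view, student_regions, teacher_regions,
               student_temp = 0.9, teacher_temp = 0.04):
    # view-level term - the usual Dino objective on the pooled representation
    view_loss = dino_ce(teacher_view, student_view, teacher_temp, student_temp)

    # region-level term - pair each student region with its most similar
    # teacher region, then apply the same cross entropy per pair
    sim = F.normalize(student_regions, dim = -1) @ F.normalize(teacher_regions, dim = -1).transpose(-1, -2)
    idx = sim.argmax(dim = -1)  # (batch, num_regions)
    matched = torch.gather(teacher_regions, 1, idx.unsqueeze(-1).expand_as(student_regions))
    region_loss = dino_ce(matched, student_regions, teacher_temp, student_temp)

    return view_loss + region_loss

# views are (batch, K) projections, regions are (batch, num_regions, K)
loss = esvit_loss(torch.randn(8, 256), torch.randn(8, 256),
                  torch.randn(8, 49, 256), torch.randn(8, 49, 256))
```
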
Even though it is named as though it were a new `ViT` variant, `EsViT` is really just a strategy for training any multistage `ViT` (the paper focused on Swin). The example below shows how to use it with `CvT`. You'll need to set `hidden_layer` to the name of the layer within your efficient ViT that outputs the non-average-pooled visual representations, just before the global pooling and projection to logits.

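
If you are unsure which name to pass, one quick way to find it is plain PyTorch module introspection (not a feature of the trainer): print the submodule names of your model and pick the one that emits the pre-pool feature map.

```python
# assuming `cvt` is the model constructed in the example below;
# for it, the correct name is 'layers'
for name, module in cvt.named_modules():
    print(name, type(module).__name__)
```
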
```python
import torch
from vit_pytorch.cvt import CvT
from vit_pytorch.es_vit import EsViTTrainer

cvt = CvT(
    num_classes = 1000,
    s1_emb_dim = 64,
    s1_emb_kernel = 7,
    s1_emb_stride = 4,
    s1_proj_kernel = 3,
    s1_kv_proj_stride = 2,
    s1_heads = 1,
    s1_depth = 1,
    s1_mlp_mult = 4,
    s2_emb_dim = 192,
    s2_emb_kernel = 3,
    s2_emb_stride = 2,
    s2_proj_kernel = 3,
    s2_kv_proj_stride = 2,
    s2_heads = 3,
    s2_depth = 2,
    s2_mlp_mult = 4,
    s3_emb_dim = 384,
    s3_emb_kernel = 3,
    s3_emb_stride = 2,
    s3_proj_kernel = 3,
    s3_kv_proj_stride = 2,
    s3_heads = 4,
    s3_depth = 10,
    s3_mlp_mult = 4,
    dropout = 0.
)

learner = EsViTTrainer(
    cvt,
    image_size = 256,
    hidden_layer = 'layers',             # hidden layer name or index, from which to extract the embedding
    projection_hidden_size = 256,        # projector network hidden dimension
    projection_layers = 4,               # number of layers in projection network
    num_classes_K = 65336,               # output logits dimensions (referenced as K in paper)
    student_temp = 0.9,                  # student temperature
    teacher_temp = 0.04,                 # teacher temperature, needs to be annealed from 0.04 to 0.07 over 30 epochs
    local_upper_crop_scale = 0.4,        # upper bound for local crop - 0.4 was recommended in the paper
    global_lower_crop_scale = 0.5,       # lower bound for global crop - 0.5 was recommended in the paper
    moving_average_decay = 0.9,          # moving average of encoder - paper showed anywhere from 0.9 to 0.999 was ok
    center_moving_average_decay = 0.9,   # moving average of teacher centers - paper showed anywhere from 0.9 to 0.999 was ok
)

opt = torch.optim.AdamW(learner.parameters(), lr = 3e-4)

def sample_unlabelled_images():
    return torch.randn(8, 3, 256, 256)

for _ in range(1000):
    images = sample_unlabelled_images()
    loss = learner(images)
    opt.zero_grad()
    loss.backward()
    opt.step()
    learner.update_moving_average() # update moving average of teacher encoder and teacher centers

# save your improved network
torch.save(cvt.state_dict(), './pretrained-net.pt')
```

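
The `teacher_temp` comment above follows the paper's schedule: warm the teacher temperature from 0.04 to 0.07 over the first 30 epochs. Below is a minimal sketch of that linear schedule. Passing `teacher_temp` to the forward call is an assumption patterned after the `Dino` trainer in this repository, so verify against `es_vit.py` before relying on it.

```python
def teacher_temp_schedule(epoch, warmup_epochs = 30, start = 0.04, end = 0.07):
    # linearly anneal the teacher temperature, then hold the final value
    if epoch >= warmup_epochs:
        return end
    return start + (end - start) * epoch / warmup_epochs

for epoch in range(100):
    temp = teacher_temp_schedule(epoch)
    for _ in range(10):  # batches per epoch, illustrative only
        images = sample_unlabelled_images()
        loss = learner(images, teacher_temp = temp)  # assumed keyword - see note above
        opt.zero_grad()
        loss.backward()
        opt.step()
        learner.update_moving_average()
```
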
## Accessing Attention

If you would like to visualize the attention weights (post-softmax) for your research, just follow the procedure below.

setup.py (2 lines changed)

```diff
@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
 setup(
   name = 'vit-pytorch',
   packages = find_packages(exclude=['examples']),
-  version = '0.33.2',
+  version = '0.34.0',
   license='MIT',
   description = 'Vision Transformer (ViT) - Pytorch',
   author = 'Phil Wang',
```

vit_pytorch/cvt.py

This is the change that makes `hidden_layer = 'layers'` work for `EsViTTrainer`: the average pooling and classification head move out of `self.layers` into a separate `self.to_logits`, so the `'layers'` module now outputs the non-average-pooled features.

```diff
@@ -164,12 +164,14 @@ class CvT(nn.Module):
 
         dim = config['emb_dim']
 
-        self.layers = nn.Sequential(
-            *layers,
+        self.layers = nn.Sequential(*layers)
+
+        self.to_logits = nn.Sequential(
             nn.AdaptiveAvgPool2d(1),
             Rearrange('... () () -> ...'),
             nn.Linear(dim, num_classes)
         )
 
     def forward(self, x):
-        return self.layers(x)
+        latents = self.layers(x)
+        return self.to_logits(latents)
```

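
A quick way to confirm the split, as a sketch: register a forward hook on `cvt.layers` and compare shapes. This assumes the stage hyperparameters default to the values shown in the README example above; pass them explicitly otherwise.

```python
import torch
from vit_pytorch.cvt import CvT

cvt = CvT(num_classes = 1000)  # stage hyperparameters assumed to default to the README values

feats = {}
cvt.layers.register_forward_hook(lambda module, inputs, output: feats.update(layers = output))

logits = cvt(torch.randn(1, 3, 256, 256))
print(logits.shape)           # torch.Size([1, 1000])
print(feats['layers'].shape)  # pre-pool feature map; with strides 4 * 2 * 2 = 16, torch.Size([1, 384, 16, 16])
```
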