Update README.md

@@ -62,61 +62,6 @@ Dropout rate.
Embedding dropout rate.
- `pool`: string, either `cls` token pooling or `mean` pooling
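
For example, assuming the standard `ViT` constructor that these parameters document (with the other arguments mirroring the distillation example below), mean pooling would be selected like so:

```python
from vit_pytorch import ViT

v = ViT(
    image_size = 256,
    patch_size = 32,
    num_classes = 1000,
    dim = 1024,
    depth = 6,
    heads = 8,
    mlp_dim = 2048,
    pool = 'mean'   # pool over all tokens instead of using the `cls` token
)
```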
## CCT

<img src="https://raw.githubusercontent.com/SHI-Labs/Compact-Transformers/main/images/model_sym.png" width="400px"></img>

@@ -182,6 +127,63 @@ model = cct_2(
Repository</a>

## Distillation

<img src="./images/distill.png" width="300px"></img>

A recent <a href="https://arxiv.org/abs/2012.12877">paper</a> has shown that using a distillation token to distill knowledge from convolutional nets into a vision transformer can yield small and efficient vision transformers. This repository offers the means to do distillation easily.

ex. distilling from Resnet50 (or any teacher) to a vision transformer

```python
import torch
from torchvision.models import resnet50

from vit_pytorch.distill import DistillableViT, DistillWrapper

teacher = resnet50(pretrained = True)

v = DistillableViT(
    image_size = 256,
    patch_size = 32,
    num_classes = 1000,
    dim = 1024,
    depth = 6,
    heads = 8,
    mlp_dim = 2048,
    dropout = 0.1,
    emb_dropout = 0.1
)

distiller = DistillWrapper(
    student = v,
    teacher = teacher,
    temperature = 3,    # temperature of distillation
    alpha = 0.5,        # trade between main loss and distillation loss
    hard = False        # whether to use soft or hard distillation
)

img = torch.randn(2, 3, 256, 256)
labels = torch.randint(0, 1000, (2,))

loss = distiller(img, labels)
loss.backward()

# after lots of training above ...

pred = v(img) # (2, 1000)
```
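
For intuition, the objective that `temperature`, `alpha` and `hard` control looks roughly like the sketch below, following the cited paper. This is an illustrative approximation rather than the exact code inside `DistillWrapper`, and the logit names (`student_logits`, `distill_logits`, `teacher_logits`) are placeholders.

```python
import torch.nn.functional as F

def distillation_loss(student_logits, distill_logits, teacher_logits, labels,
                      temperature = 3., alpha = 0.5, hard = False):
    # supervised loss on the student's class-token logits
    ce = F.cross_entropy(student_logits, labels)

    if hard:
        # hard distillation: the teacher's argmax is used as an extra label
        distill = F.cross_entropy(distill_logits, teacher_logits.argmax(dim = -1))
    else:
        # soft distillation: KL divergence between temperature-softened distributions,
        # scaled by T^2 so gradient magnitudes stay comparable
        T = temperature
        distill = F.kl_div(
            F.log_softmax(distill_logits / T, dim = -1),
            F.softmax(teacher_logits / T, dim = -1),
            reduction = 'batchmean'
        ) * T * T

    # alpha trades the supervised loss off against the distillation loss
    return (1 - alpha) * ce + alpha * distill
```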
The `DistillableViT` class is identical to `ViT` except for how the forward pass is handled, so you should be able to load the parameters back to `ViT` after you have completed distillation training (a sketch of doing so by hand follows the `.to_vit` example below).

You can also use the handy `.to_vit` method on the `DistillableViT` instance to get back a `ViT` instance.
```python
v = v.to_vit()
type(v) # <class 'vit_pytorch.vit_pytorch.ViT'>
```
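
As for loading the parameters back by hand, a rough sketch: since the two classes differ only in their forward pass, the trained weights should load into a freshly constructed `ViT` with the same hyperparameters, whether `v` is still the `DistillableViT` or the instance returned by `.to_vit`.

```python
from vit_pytorch import ViT

vit = ViT(
    image_size = 256,
    patch_size = 32,
    num_classes = 1000,
    dim = 1024,
    depth = 6,
    heads = 8,
    mlp_dim = 2048,
    dropout = 0.1,
    emb_dropout = 0.1
)

# the parameter names should line up, per the note above
vit.load_state_dict(v.state_dict())
```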
## Deep ViT
This <a href="https://arxiv.org/abs/2103.11886">paper</a> notes that ViT struggles to attend at greater depths (past 12 layers), and suggests mixing the attention of each head post-softmax as a solution, dubbed Re-attention. The results line up with the <a href="https://github.com/lucidrains/x-transformers#talking-heads-attention">Talking Heads</a> paper from NLP.
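
To make the re-attention idea concrete, here is a minimal sketch: the per-head attention maps are mixed, after the softmax, by a learned `heads x heads` matrix before being applied to the values. This is an illustrative module, not the implementation in this repository; the paper additionally normalizes the mixed maps, which is omitted here for brevity.

```python
import torch
from torch import nn

class ReAttention(nn.Module):
    def __init__(self, dim, heads = 8, dim_head = 64):
        super().__init__()
        inner_dim = heads * dim_head
        self.heads = heads
        self.scale = dim_head ** -0.5
        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
        self.to_out = nn.Linear(inner_dim, dim)

        # learned mixing of the heads' attention maps (the re-attention transformation)
        self.reattn_weights = nn.Parameter(torch.randn(heads, heads))

    def forward(self, x):
        b, n, _ = x.shape
        h = self.heads

        qkv = self.to_qkv(x).chunk(3, dim = -1)
        q, k, v = (t.reshape(b, n, h, -1).transpose(1, 2) for t in qkv)    # (b, h, n, d)

        attn = (q @ k.transpose(-1, -2) * self.scale).softmax(dim = -1)    # (b, h, n, n)

        # re-attention: blend the attention maps across heads, post-softmax
        attn = torch.einsum('b h i j, h g -> b g i j', attn, self.reattn_weights)

        out = (attn @ v).transpose(1, 2).reshape(b, n, -1)
        return self.to_out(out)

x = torch.randn(2, 65, 1024)
ReAttention(dim = 1024)(x).shape    # torch.Size([2, 65, 1024])
```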