Update README.md
@@ -62,61 +62,6 @@ Dropout rate.
Embedding dropout rate.
- `pool`: string, either `cls` token pooling or `mean` pooling

## Distillation

<img src="./images/distill.png" width="300px"></img>

A recent <a href="https://arxiv.org/abs/2012.12877">paper</a> has shown that using a distillation token to distill knowledge from convolutional nets into vision transformers can yield small and efficient vision transformers. This repository offers the means to do distillation easily.

ex. distilling from ResNet50 (or any teacher) to a vision transformer

```python
import torch
from torchvision.models import resnet50

from vit_pytorch.distill import DistillableViT, DistillWrapper

teacher = resnet50(pretrained = True)

v = DistillableViT(
    image_size = 256,
    patch_size = 32,
    num_classes = 1000,
    dim = 1024,
    depth = 6,
    heads = 8,
    mlp_dim = 2048,
    dropout = 0.1,
    emb_dropout = 0.1
)

distiller = DistillWrapper(
    student = v,
    teacher = teacher,
    temperature = 3,    # temperature of distillation
    alpha = 0.5,        # trade between main loss and distillation loss
    hard = False        # whether to use soft or hard distillation
)

img = torch.randn(2, 3, 256, 256)
labels = torch.randint(0, 1000, (2,))

loss = distiller(img, labels)
loss.backward()

# after lots of training above ...

pred = v(img) # (2, 1000)
```

The `DistillableViT` class is identical to `ViT` except for how the forward pass is handled, so you should be able to load the parameters back to `ViT` after you have completed distillation training.

You can also use the handy `.to_vit` method on the `DistillableViT` instance to get back a `ViT` instance.

```python
v = v.to_vit()
type(v) # <class 'vit_pytorch.vit_pytorch.ViT'>
```

## CCT

<img src="https://raw.githubusercontent.com/SHI-Labs/Compact-Transformers/main/images/model_sym.png" width="400px"></img>
@@ -182,6 +127,63 @@ model = cct_2(
Repository</a>

## Distillation

<img src="./images/distill.png" width="300px"></img>

A recent <a href="https://arxiv.org/abs/2012.12877">paper</a> has shown that using a distillation token to distill knowledge from convolutional nets into vision transformers can yield small and efficient vision transformers. This repository offers the means to do distillation easily.

ex. distilling from ResNet50 (or any teacher) to a vision transformer

```python
import torch
from torchvision.models import resnet50

from vit_pytorch.distill import DistillableViT, DistillWrapper

teacher = resnet50(pretrained = True)

v = DistillableViT(
    image_size = 256,
    patch_size = 32,
    num_classes = 1000,
    dim = 1024,
    depth = 6,
    heads = 8,
    mlp_dim = 2048,
    dropout = 0.1,
    emb_dropout = 0.1
)

distiller = DistillWrapper(
    student = v,
    teacher = teacher,
    temperature = 3,    # temperature of distillation
    alpha = 0.5,        # trade between main loss and distillation loss
    hard = False        # whether to use soft or hard distillation
)

img = torch.randn(2, 3, 256, 256)
labels = torch.randint(0, 1000, (2,))

loss = distiller(img, labels)
loss.backward()

# after lots of training above ...

pred = v(img) # (2, 1000)
```
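
For intuition on what `temperature`, `alpha`, and `hard` control, here is a minimal sketch of how soft distillation losses are commonly combined with the usual classification loss. It is illustrative only and not necessarily the exact formula used inside `DistillWrapper`.

```python
import torch
import torch.nn.functional as F

def soft_distill_loss(student_logits, teacher_logits, labels, temperature = 3.0, alpha = 0.5):
    # ordinary cross entropy on the ground-truth labels (the "main" loss)
    ce = F.cross_entropy(student_logits, labels)

    # KL divergence between temperature-softened student and teacher distributions,
    # scaled by T^2 so gradient magnitudes stay comparable across temperatures
    kl = F.kl_div(
        F.log_softmax(student_logits / temperature, dim = -1),
        F.softmax(teacher_logits / temperature, dim = -1),
        reduction = 'batchmean'
    ) * temperature ** 2

    # alpha trades the main loss off against the distillation loss
    return alpha * ce + (1 - alpha) * kl
```

In the hard variant described in the DeiT paper, the KL term is replaced by a cross entropy against the teacher's argmax predictions.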

The `DistillableViT` class is identical to `ViT` except for how the forward pass is handled, so you should be able to load the parameters back to `ViT` after you have completed distillation training.
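
For example, a manual parameter transfer could look like the sketch below. This is only an illustration of the point above; `strict = False` is a defensive assumption in case the distillation variant registers any extra parameters or buffers, and the `.to_vit` helper described next is the simpler route.

```python
from vit_pytorch import ViT

# build a plain ViT with the same hyperparameters as the distilled student above
plain_vit = ViT(
    image_size = 256,
    patch_size = 32,
    num_classes = 1000,
    dim = 1024,
    depth = 6,
    heads = 8,
    mlp_dim = 2048,
    dropout = 0.1,
    emb_dropout = 0.1
)

# copy the distilled weights across (strict = False is defensive, see note above)
plain_vit.load_state_dict(v.state_dict(), strict = False)
```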

You can also use the handy `.to_vit` method on the `DistillableViT` instance to get back a `ViT` instance.

```python
v = v.to_vit()
type(v) # <class 'vit_pytorch.vit_pytorch.ViT'>
```

## Deep ViT

This <a href="https://arxiv.org/abs/2103.11886">paper</a> notes that ViT struggles to attend at greater depths (past 12 layers) and suggests mixing the attention of each head post-softmax as a solution, dubbed Re-attention. The results line up with the <a href="https://github.com/lucidrains/x-transformers#talking-heads-attention">Talking Heads</a> paper from NLP.
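
As a rough illustration of the Re-attention idea (a minimal sketch, not the code in `vit_pytorch.deepvit`; the class name and the normalization choice are assumptions), the post-softmax attention maps are mixed across heads with a small learned matrix before they are applied to the values:

```python
import torch
from torch import nn

class ReAttentionSketch(nn.Module):
    def __init__(self, dim, heads = 8, dim_head = 64):
        super().__init__()
        inner_dim = heads * dim_head
        self.heads = heads
        self.scale = dim_head ** -0.5

        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
        # learned mixing of the per-head attention maps (the Re-attention step)
        self.reattn_weights = nn.Parameter(torch.randn(heads, heads))
        # assumption: normalize the mixed maps over the head dimension for stability
        self.reattn_norm = nn.LayerNorm(heads)
        self.to_out = nn.Linear(inner_dim, dim)

    def forward(self, x):
        b, n, _ = x.shape
        h = self.heads

        q, k, v = self.to_qkv(x).chunk(3, dim = -1)
        q, k, v = map(lambda t: t.reshape(b, n, h, -1).transpose(1, 2), (q, k, v))

        attn = (q @ k.transpose(-1, -2) * self.scale).softmax(dim = -1)  # (b, h, n, n)

        # Re-attention: mix attention maps across heads after the softmax
        attn = torch.einsum('b h i j, h g -> b g i j', attn, self.reattn_weights)
        attn = self.reattn_norm(attn.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)

        out = (attn @ v).transpose(1, 2).reshape(b, n, -1)
        return self.to_out(out)

x = torch.randn(2, 65, 1024)
out = ReAttentionSketch(dim = 1024, heads = 8)(x)  # (2, 65, 1024)
```

The learned head-mixing matrix is what lets attention maps that have become similar at depth regain diversity, which is the paper's stated motivation.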