offer an easy way to turn DistillableViT into ViT at the end of training

Phil Wang
2020-12-25 11:16:52 -08:00
parent 0c68688d61
commit 74074e2b6c
3 changed files with 15 additions and 1 deletion


@@ -104,6 +104,13 @@ loss.backward()
The `DistillableViT` class is identical to `ViT` except for how the forward pass is handled, so you should be able to load the parameters back into a `ViT` after you have completed distillation training.
You can also use the handy `.to_vit` method on the `DistillableViT` instance to get back a `ViT` instance.
```python
v = v.to_vit()
type(v) # <class 'vit_pytorch.vit_pytorch.ViT'>
```
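Since the two classes share the same parameter names, the conversion can also be done by hand with `state_dict` (a minimal sketch; the hyperparameters below are placeholders and must match whatever the `DistillableViT` was constructed with):

```python
from vit_pytorch import ViT

# placeholder hyperparameters -- these must match the trained DistillableViT
vit = ViT(
    image_size = 256,
    patch_size = 32,
    num_classes = 1000,
    dim = 1024,
    depth = 6,
    heads = 8,
    mlp_dim = 2048
)
vit.load_state_dict(v.state_dict())  # v is the trained DistillableViT instance
```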
## Research Ideas
### Self Supervised Training


@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
setup(
  name = 'vit-pytorch',
  packages = find_packages(exclude=['examples']),
  version = '0.6.3',
  version = '0.6.4',
  license='MIT',
  description = 'Vision Transformer (ViT) - Pytorch',
  author = 'Phil Wang',


@@ -16,9 +16,16 @@ def exists(val):
class DistillableViT(ViT):
    def __init__(self, *args, **kwargs):
        super(DistillableViT, self).__init__(*args, **kwargs)
        # store the constructor arguments so an equivalent ViT can be rebuilt later
        self.args = args
        self.kwargs = kwargs
        self.dim = kwargs['dim']
        self.num_classes = kwargs['num_classes']

    def to_vit(self):
        # rebuild a plain ViT with the same hyperparameters and copy the trained weights over
        v = ViT(*self.args, **self.kwargs)
        v.load_state_dict(self.state_dict())
        return v

    def forward(self, img, distill_token, mask = None):
        p = self.patch_size
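For reference, a minimal round-trip sketch of what this commit enables (the import path and hyperparameters are assumptions, not taken from this diff):

```python
from vit_pytorch import ViT
from vit_pytorch.distill import DistillableViT  # import path assumed

# placeholder hyperparameters; everything is passed as a keyword because
# __init__ looks up kwargs['dim'] and kwargs['num_classes'] directly
v = DistillableViT(
    image_size = 256,
    patch_size = 32,
    num_classes = 1000,
    dim = 1024,
    depth = 6,
    heads = 8,
    mlp_dim = 2048
)

# ... distillation training happens here ...

vit = v.to_vit()            # rebuilds a plain ViT carrying the distilled weights
assert isinstance(vit, ViT)
```

Note that `dim` and `num_classes` cannot be passed positionally: `__init__` reads them out of `kwargs`, so a positional call would raise a `KeyError`.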