mirror of https://github.com/lucidrains/vit-pytorch.git (synced 2025-12-30 08:02:29 +00:00)
offer easy way to turn DistillableViT to ViT at the end of training
README.md
@@ -104,6 +104,13 @@ loss.backward()
+The `DistillableViT` class is identical to `ViT` except for how the forward pass is handled, so you should be able to load the parameters back to `ViT` after you have completed distillation training.
+
+You can also use the handy `.to_vit` method on the `DistillableViT` instance to get back a `ViT` instance.
+
+```python
+v = v.to_vit()
+type(v) # <class 'vit_pytorch.vit_pytorch.ViT'>
+```
 
 ## Research Ideas
 
 ### Self Supervised Training
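The README paragraph above hinges on the two modules being architecturally identical. As a hedged sketch of what that means in practice, the weights can be transferred by hand rather than through `.to_vit` (the hyperparameters and the `vit_pytorch.distill` module path here are assumptions, not part of this commit):

```python
from vit_pytorch import ViT
from vit_pytorch.distill import DistillableViT  # module path is an assumption

# illustrative hyperparameters -- any ViT configuration behaves the same way
kwargs = dict(
    image_size = 256, patch_size = 32, num_classes = 1000,
    dim = 1024, depth = 6, heads = 8, mlp_dim = 2048
)

distillable = DistillableViT(**kwargs)
# ... distillation training would happen here ...

# the modules are architecturally identical, so the state dict maps 1:1
v = ViT(**kwargs)
v.load_state_dict(distillable.state_dict())
```

This is effectively what the new `.to_vit` method automates, as the last hunk below shows.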
setup.py
@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
 setup(
   name = 'vit-pytorch',
   packages = find_packages(exclude=['examples']),
-  version = '0.6.3',
+  version = '0.6.4',
   license='MIT',
   description = 'Vision Transformer (ViT) - Pytorch',
   author = 'Phil Wang',
@@ -16,9 +16,16 @@ def exists(val):
 class DistillableViT(ViT):
     def __init__(self, *args, **kwargs):
         super(DistillableViT, self).__init__(*args, **kwargs)
+        self.args = args
+        self.kwargs = kwargs
         self.dim = kwargs['dim']
         self.num_classes = kwargs['num_classes']
 
+    def to_vit(self):
+        v = ViT(*self.args, **self.kwargs)
+        v.load_state_dict(self.state_dict())
+        return v
+
     def forward(self, img, distill_token, mask = None):
         p = self.patch_size
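One design note on the hunk above: `to_vit` rebuilds a fresh `ViT` from the constructor arguments stashed in `__init__`, then copies the trained weights across with `load_state_dict`. Because `__init__` reads `kwargs['dim']` and `kwargs['num_classes']` directly, the model has to be constructed with keyword arguments for the round trip to work. A minimal round-trip sketch (hyperparameters illustrative, module path an assumption):

```python
from vit_pytorch.distill import DistillableViT  # module path is an assumption

# keyword arguments are required: __init__ looks up kwargs['dim'] directly
v = DistillableViT(
    image_size = 256, patch_size = 32, num_classes = 1000,
    dim = 1024, depth = 6, heads = 8, mlp_dim = 2048
)

# ... distillation training would happen here ...

v = v.to_vit()  # a fresh ViT(*args, **kwargs) with the distilled state dict loaded
```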