diff --git a/README.md b/README.md
index 59fbb0c..6cc14dc 100644
--- a/README.md
+++ b/README.md
@@ -104,6 +104,13 @@ loss.backward()
 
 The `DistillableViT` class is identical to `ViT` except for how the forward pass is handled, so you should be able to load the parameters back to `ViT` after you have completed distillation training.
 
+You can also use the handy `.to_vit` method on the `DistillableViT` instance to get back a `ViT` instance.
+
+```python
+v = v.to_vit()
+type(v) # <class 'vit_pytorch.vit_pytorch.ViT'>
+```
+
 ## Research Ideas
 
 ### Self Supervised Training
diff --git a/setup.py b/setup.py
index 63f4d04..1181bab 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
 setup(
   name = 'vit-pytorch',
   packages = find_packages(exclude=['examples']),
-  version = '0.6.3',
+  version = '0.6.4',
   license='MIT',
   description = 'Vision Transformer (ViT) - Pytorch',
   author = 'Phil Wang',
diff --git a/vit_pytorch/distill.py b/vit_pytorch/distill.py
index d41957f..2b4660b 100644
--- a/vit_pytorch/distill.py
+++ b/vit_pytorch/distill.py
@@ -16,9 +16,16 @@ def exists(val):
 class DistillableViT(ViT):
     def __init__(self, *args, **kwargs):
         super(DistillableViT, self).__init__(*args, **kwargs)
+        self.args = args
+        self.kwargs = kwargs
         self.dim = kwargs['dim']
         self.num_classes = kwargs['num_classes']
 
+    def to_vit(self):
+        v = ViT(*self.args, **self.kwargs)
+        v.load_state_dict(self.state_dict())
+        return v
+
     def forward(self, img, distill_token, mask = None):
         p = self.patch_size
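
Below is a minimal usage sketch showing where the new `to_vit` method fits in the distillation workflow. Only `DistillableViT.to_vit` comes from this diff; the `DistillWrapper` call, the `resnet50` teacher, and every hyperparameter are assumptions modelled on the distillation example that precedes the README hunk above (of which only `loss.backward()` is visible in the diff context).

```python
# Hedged sketch, not part of this diff: assumed DistillWrapper API and
# assumed teacher/hyperparameters; only `to_vit` is introduced by the change.
import torch
from torchvision.models import resnet50
from vit_pytorch.distill import DistillableViT, DistillWrapper

teacher = resnet50(pretrained = True)  # assumed teacher network

v = DistillableViT(
    image_size = 256,
    patch_size = 32,
    num_classes = 1000,
    dim = 1024,        # stored by __init__ via kwargs['dim']
    depth = 6,
    heads = 8,
    mlp_dim = 2048
)

distiller = DistillWrapper(
    student = v,
    teacher = teacher,
    temperature = 3,   # assumed soft-label temperature
    alpha = 0.5        # assumed hard/soft loss mixing weight
)

img = torch.randn(2, 3, 256, 256)
labels = torch.randint(0, 1000, (2,))

loss = distiller(img, labels)
loss.backward()

# after distillation training, recover a plain ViT carrying the learned weights
v = v.to_vit()
type(v)  # <class 'vit_pytorch.vit_pytorch.ViT'>
```

Because `__init__` now stores `args` and `kwargs`, `to_vit` can rebuild a fresh `ViT` with the same configuration and copy the trained parameters over via `load_state_dict`.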