fix t2t vit having two layernorms, and make final layernorm in distillation wrapper configurable, default to False for vit

lucidrains
2024-06-11 15:12:53 -07:00
parent 90be7233a3
commit e3256d77cd
3 changed files with 17 additions and 14 deletions

setup.py

@@ -6,7 +6,7 @@ with open('README.md') as f:
 setup(
   name = 'vit-pytorch',
   packages = find_packages(exclude=['examples']),
-  version = '1.6.9',
+  version = '1.7.0',
   license='MIT',
   description = 'Vision Transformer (ViT) - Pytorch',
   long_description=long_description,

vit_pytorch/distill.py

@@ -1,6 +1,8 @@
 import torch
+import torch.nn.functional as F
+
 from torch import nn
+from torch.nn import Module
-import torch.nn.functional as F
 from vit_pytorch.vit import ViT
 from vit_pytorch.t2t import T2TViT
 from vit_pytorch.efficient import ViT as EfficientViT
@@ -12,6 +14,9 @@ from einops import rearrange, repeat
 def exists(val):
     return val is not None
 
+def default(val, d):
+    return val if exists(val) else d
+
 # classes
 
 class DistillMixin:
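The new `default` helper just collapses the repeated `x if exists(x) else fallback` pattern used in `forward` further down. A quick illustration:

```python
def exists(val):
    return val is not None

def default(val, d):
    return val if exists(val) else d

assert default(None, 0.5) == 0.5   # falls back when no override is given
assert default(3.0, 0.5) == 3.0    # keeps an explicitly passed value
```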
@@ -20,12 +25,12 @@ class DistillMixin:
         x = self.to_patch_embedding(img)
         b, n, _ = x.shape
 
-        cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b)
+        cls_tokens = repeat(self.cls_token, '1 n d -> b n d', b = b)
         x = torch.cat((cls_tokens, x), dim = 1)
         x += self.pos_embedding[:, :(n + 1)]
 
         if distilling:
-            distill_tokens = repeat(distill_token, '() n d -> b n d', b = b)
+            distill_tokens = repeat(distill_token, '1 n d -> b n d', b = b)
             x = torch.cat((x, distill_tokens), dim = 1)
 
         x = self._attend(x)
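The `'() n d -> b n d'` to `'1 n d -> b n d'` change is purely notational: einops accepts both `()` and a literal `1` for a singleton axis, and `1` reads more plainly. A minimal sketch (the tensor shape is illustrative):

```python
import torch
from einops import repeat

cls_token = torch.randn(1, 1, 512)

# '1' and '()' both mark the unit batch axis to be expanded
cls_tokens = repeat(cls_token, '1 n d -> b n d', b = 8)
assert cls_tokens.shape == (8, 1, 512)
```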
@@ -97,7 +102,7 @@ class DistillableEfficientViT(DistillMixin, EfficientViT):
 # knowledge distillation wrapper
 
-class DistillWrapper(nn.Module):
+class DistillWrapper(Module):
     def __init__(
         self,
         *,
@@ -105,7 +110,8 @@ class DistillWrapper(nn.Module):
         student,
         temperature = 1.,
         alpha = 0.5,
-        hard = False
+        hard = False,
+        mlp_layernorm = False
     ):
         super().__init__()
         assert (isinstance(student, (DistillableViT, DistillableT2TViT, DistillableEfficientViT))) , 'student must be a vision transformer'
@@ -122,14 +128,14 @@ class DistillWrapper(nn.Module):
         self.distillation_token = nn.Parameter(torch.randn(1, 1, dim))
 
         self.distill_mlp = nn.Sequential(
-            nn.LayerNorm(dim),
+            nn.LayerNorm(dim) if mlp_layernorm else nn.Identity(),
             nn.Linear(dim, num_classes)
         )
 
     def forward(self, img, labels, temperature = None, alpha = None, **kwargs):
-        b, *_ = img.shape
-        alpha = alpha if exists(alpha) else self.alpha
-        T = temperature if exists(temperature) else self.temperature
+        alpha = default(alpha, self.alpha)
+        T = default(temperature, self.temperature)
 
         with torch.no_grad():
             teacher_logits = self.teacher(img)
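For reference, a usage sketch of the wrapper with the new `mlp_layernorm` flag, following the README's distillation example; the teacher network and hyperparameters here are illustrative:

```python
import torch
from torchvision.models import resnet50
from vit_pytorch.distill import DistillableViT, DistillWrapper

teacher = resnet50(pretrained = True)

student = DistillableViT(
    image_size = 256,
    patch_size = 32,
    num_classes = 1000,
    dim = 1024,
    depth = 6,
    heads = 8,
    mlp_dim = 2048
)

distiller = DistillWrapper(
    student = student,
    teacher = teacher,
    temperature = 3,
    alpha = 0.5,
    hard = False,
    mlp_layernorm = False   # new flag; False (no LayerNorm in the distill head) is the default
)

img = torch.randn(2, 3, 256, 256)
labels = torch.randint(0, 1000, (2,))

loss = distiller(img, labels)
loss.backward()
```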

vit_pytorch/t2t.py

@@ -61,10 +61,7 @@ class T2TViT(nn.Module):
         self.pool = pool
         self.to_latent = nn.Identity()
 
-        self.mlp_head = nn.Sequential(
-            nn.LayerNorm(dim),
-            nn.Linear(dim, num_classes)
-        )
+        self.mlp_head = nn.Linear(dim, num_classes)
 
     def forward(self, img):
         x = self.to_patch_embedding(img)
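This is the "two layernorms" fix from the commit message: the Transformer that T2TViT reuses from vit.py already layer-normalizes its output, so the old `mlp_head` applied a second `nn.LayerNorm` back to back. A minimal before/after sketch (dimensions are illustrative):

```python
import torch
from torch import nn

dim, num_classes = 512, 1000

# stand-in for the transformer output, which vit.py already layer-normalizes
x = torch.randn(2, dim)

# before: a redundant second normalization in the head
old_head = nn.Sequential(nn.LayerNorm(dim), nn.Linear(dim, num_classes))

# after: a bare linear projection
new_head = nn.Linear(dim, num_classes)
```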