Compare commits

...

2 Commits
0.7.1 ... 0.7.3

Author SHA1 Message Date
Phil Wang
dc7ee3e007 add DistillableT2TViT 2021-02-21 19:18:29 -08:00
Phil Wang
3744ac691a remove patch size from T2TViT 2021-02-21 19:15:19 -08:00
4 changed files with 29 additions and 9 deletions

View File

@@ -130,7 +130,6 @@ from vit_pytorch.t2t import T2TViT
v = T2TViT(
dim = 512,
image_size = 224,
patch_size = 16,
depth = 5,
heads = 8,
mlp_dim = 512,

View File

@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
setup(
name = 'vit-pytorch',
packages = find_packages(exclude=['examples']),
version = '0.7.1',
version = '0.7.3',
license='MIT',
description = 'Vision Transformer (ViT) - Pytorch',
author = 'Phil Wang',

View File

@@ -2,6 +2,7 @@ import torch
import torch.nn.functional as F
from torch import nn
from vit_pytorch.vit_pytorch import ViT
from vit_pytorch.t2t import T2TViT
from vit_pytorch.efficient import ViT as EfficientViT
from einops import rearrange, repeat
@@ -60,6 +61,24 @@ class DistillableViT(DistillMixin, ViT):
x = self.transformer(x, mask)
return x
class DistillableT2TViT(DistillMixin, T2TViT):
def __init__(self, *args, **kwargs):
super(DistillableT2TViT, self).__init__(*args, **kwargs)
self.args = args
self.kwargs = kwargs
self.dim = kwargs['dim']
self.num_classes = kwargs['num_classes']
def to_vit(self):
v = T2TViT(*self.args, **self.kwargs)
v.load_state_dict(self.state_dict())
return v
def _attend(self, x, mask):
x = self.dropout(x)
x = self.transformer(x, mask)
return x
class DistillableEfficientViT(DistillMixin, EfficientViT):
def __init__(self, *args, **kwargs):
super(DistillableEfficientViT, self).__init__(*args, **kwargs)
@@ -88,7 +107,7 @@ class DistillWrapper(nn.Module):
alpha = 0.5
):
super().__init__()
assert (isinstance(student, (DistillableViT, DistillableEfficientViT))) , 'student must be a vision transformer'
assert (isinstance(student, (DistillableViT, DistillableT2TViT, DistillableEfficientViT))) , 'student must be a vision transformer'
self.teacher = teacher
self.student = student

View File

@@ -9,6 +9,9 @@ from einops.layers.torch import Rearrange
# classes
def conv_output_size(image_size, kernel_size, stride, padding):
return int(((image_size - kernel_size + (2 * padding)) / stride) + 1)
class RearrangeImage(nn.Module):
def forward(self, x):
return rearrange(x, 'b (h w) c -> b c h w', h = int(math.sqrt(x.shape[1])))
@@ -17,19 +20,18 @@ class RearrangeImage(nn.Module):
class T2TViT(nn.Module):
def __init__(
self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, pool = 'cls', channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0., t2t_layers = ((7, 4), (3, 2), (3, 2))):
self, *, image_size, num_classes, dim, depth, heads, mlp_dim, pool = 'cls', channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0., t2t_layers = ((7, 4), (3, 2), (3, 2))):
super().__init__()
assert image_size % patch_size == 0, 'Image dimensions must be divisible by the patch size.'
num_patches = (image_size // patch_size) ** 2
patch_dim = channels * patch_size ** 2
assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'
layers = []
layer_dim = channels
output_image_size = image_size
for i, (kernel_size, stride) in enumerate(t2t_layers):
layer_dim *= kernel_size ** 2
is_first = i == 0
output_image_size = conv_output_size(output_image_size, kernel_size, stride, stride // 2)
layers.extend([
RearrangeImage() if not is_first else nn.Identity(),
@@ -41,7 +43,7 @@ class T2TViT(nn.Module):
layers.append(nn.Linear(layer_dim, dim))
self.to_patch_embedding = nn.Sequential(*layers)
self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
self.pos_embedding = nn.Parameter(torch.randn(1, output_image_size ** 2 + 1, dim))
self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
self.dropout = nn.Dropout(emb_dropout)
@@ -61,7 +63,7 @@ class T2TViT(nn.Module):
cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b)
x = torch.cat((cls_tokens, x), dim=1)
x += self.pos_embedding[:, :(n + 1)]
x += self.pos_embedding
x = self.dropout(x)
x = self.transformer(x)