mirror of
https://github.com/lucidrains/vit-pytorch.git
synced 2026-05-15 12:54:12 +00:00
Compare commits
2 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
dc7ee3e007 | ||
|
|
3744ac691a |
@@ -130,7 +130,6 @@ from vit_pytorch.t2t import T2TViT
|
||||
v = T2TViT(
|
||||
dim = 512,
|
||||
image_size = 224,
|
||||
patch_size = 16,
|
||||
depth = 5,
|
||||
heads = 8,
|
||||
mlp_dim = 512,
|
||||
|
||||
2
setup.py
2
setup.py
@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
|
||||
setup(
|
||||
name = 'vit-pytorch',
|
||||
packages = find_packages(exclude=['examples']),
|
||||
version = '0.7.1',
|
||||
version = '0.7.3',
|
||||
license='MIT',
|
||||
description = 'Vision Transformer (ViT) - Pytorch',
|
||||
author = 'Phil Wang',
|
||||
|
||||
@@ -2,6 +2,7 @@ import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import nn
|
||||
from vit_pytorch.vit_pytorch import ViT
|
||||
from vit_pytorch.t2t import T2TViT
|
||||
from vit_pytorch.efficient import ViT as EfficientViT
|
||||
|
||||
from einops import rearrange, repeat
|
||||
@@ -60,6 +61,24 @@ class DistillableViT(DistillMixin, ViT):
|
||||
x = self.transformer(x, mask)
|
||||
return x
|
||||
|
||||
class DistillableT2TViT(DistillMixin, T2TViT):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(DistillableT2TViT, self).__init__(*args, **kwargs)
|
||||
self.args = args
|
||||
self.kwargs = kwargs
|
||||
self.dim = kwargs['dim']
|
||||
self.num_classes = kwargs['num_classes']
|
||||
|
||||
def to_vit(self):
|
||||
v = T2TViT(*self.args, **self.kwargs)
|
||||
v.load_state_dict(self.state_dict())
|
||||
return v
|
||||
|
||||
def _attend(self, x, mask):
|
||||
x = self.dropout(x)
|
||||
x = self.transformer(x, mask)
|
||||
return x
|
||||
|
||||
class DistillableEfficientViT(DistillMixin, EfficientViT):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(DistillableEfficientViT, self).__init__(*args, **kwargs)
|
||||
@@ -88,7 +107,7 @@ class DistillWrapper(nn.Module):
|
||||
alpha = 0.5
|
||||
):
|
||||
super().__init__()
|
||||
assert (isinstance(student, (DistillableViT, DistillableEfficientViT))) , 'student must be a vision transformer'
|
||||
assert (isinstance(student, (DistillableViT, DistillableT2TViT, DistillableEfficientViT))) , 'student must be a vision transformer'
|
||||
|
||||
self.teacher = teacher
|
||||
self.student = student
|
||||
|
||||
@@ -9,6 +9,9 @@ from einops.layers.torch import Rearrange
|
||||
|
||||
# classes
|
||||
|
||||
def conv_output_size(image_size, kernel_size, stride, padding):
|
||||
return int(((image_size - kernel_size + (2 * padding)) / stride) + 1)
|
||||
|
||||
class RearrangeImage(nn.Module):
|
||||
def forward(self, x):
|
||||
return rearrange(x, 'b (h w) c -> b c h w', h = int(math.sqrt(x.shape[1])))
|
||||
@@ -17,19 +20,18 @@ class RearrangeImage(nn.Module):
|
||||
|
||||
class T2TViT(nn.Module):
|
||||
def __init__(
|
||||
self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, pool = 'cls', channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0., t2t_layers = ((7, 4), (3, 2), (3, 2))):
|
||||
self, *, image_size, num_classes, dim, depth, heads, mlp_dim, pool = 'cls', channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0., t2t_layers = ((7, 4), (3, 2), (3, 2))):
|
||||
super().__init__()
|
||||
assert image_size % patch_size == 0, 'Image dimensions must be divisible by the patch size.'
|
||||
num_patches = (image_size // patch_size) ** 2
|
||||
patch_dim = channels * patch_size ** 2
|
||||
assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'
|
||||
|
||||
layers = []
|
||||
layer_dim = channels
|
||||
output_image_size = image_size
|
||||
|
||||
for i, (kernel_size, stride) in enumerate(t2t_layers):
|
||||
layer_dim *= kernel_size ** 2
|
||||
is_first = i == 0
|
||||
output_image_size = conv_output_size(output_image_size, kernel_size, stride, stride // 2)
|
||||
|
||||
layers.extend([
|
||||
RearrangeImage() if not is_first else nn.Identity(),
|
||||
@@ -41,7 +43,7 @@ class T2TViT(nn.Module):
|
||||
layers.append(nn.Linear(layer_dim, dim))
|
||||
self.to_patch_embedding = nn.Sequential(*layers)
|
||||
|
||||
self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
|
||||
self.pos_embedding = nn.Parameter(torch.randn(1, output_image_size ** 2 + 1, dim))
|
||||
self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
|
||||
self.dropout = nn.Dropout(emb_dropout)
|
||||
|
||||
@@ -61,7 +63,7 @@ class T2TViT(nn.Module):
|
||||
|
||||
cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b)
|
||||
x = torch.cat((cls_tokens, x), dim=1)
|
||||
x += self.pos_embedding[:, :(n + 1)]
|
||||
x += self.pos_embedding
|
||||
x = self.dropout(x)
|
||||
|
||||
x = self.transformer(x)
|
||||
|
||||
Reference in New Issue
Block a user