Compare commits

...

1 Commits
1.4.5 ... 0.7.2

Author SHA1 Message Date
Phil Wang
c2e3d9303b remove patch size from T2TViT 2021-02-21 19:13:30 -08:00
3 changed files with 8 additions and 7 deletions

View File

@@ -130,7 +130,6 @@ from vit_pytorch.t2t import T2TViT
v = T2TViT(
dim = 512,
image_size = 224,
patch_size = 16,
depth = 5,
heads = 8,
mlp_dim = 512,

View File

@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
setup(
name = 'vit-pytorch',
packages = find_packages(exclude=['examples']),
version = '0.7.1',
version = '0.7.2',
license='MIT',
description = 'Vision Transformer (ViT) - Pytorch',
author = 'Phil Wang',

View File

@@ -9,6 +9,9 @@ from einops.layers.torch import Rearrange
# classes
def conv_output_size(image_size, kernel_size, stride, padding):
return int(((image_size - kernel_size + (2 * padding)) / stride) + 1)
class RearrangeImage(nn.Module):
def forward(self, x):
return rearrange(x, 'b (h w) c -> b c h w', h = int(math.sqrt(x.shape[1])))
@@ -17,19 +20,18 @@ class RearrangeImage(nn.Module):
class T2TViT(nn.Module):
def __init__(
self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, pool = 'cls', channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0., t2t_layers = ((7, 4), (3, 2), (3, 2))):
self, *, image_size, num_classes, dim, depth, heads, mlp_dim, pool = 'cls', channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0., t2t_layers = ((7, 4), (3, 2), (3, 2))):
super().__init__()
assert image_size % patch_size == 0, 'Image dimensions must be divisible by the patch size.'
num_patches = (image_size // patch_size) ** 2
patch_dim = channels * patch_size ** 2
assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'
layers = []
layer_dim = channels
output_image_size = image_size
for i, (kernel_size, stride) in enumerate(t2t_layers):
layer_dim *= kernel_size ** 2
is_first = i == 0
output_image_size = conv_output_size(output_image_size, kernel_size, stride, stride // 2)
layers.extend([
RearrangeImage() if not is_first else nn.Identity(),
@@ -41,7 +43,7 @@ class T2TViT(nn.Module):
layers.append(nn.Linear(layer_dim, dim))
self.to_patch_embedding = nn.Sequential(*layers)
self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
self.pos_embedding = nn.Parameter(torch.randn(1, output_image_size ** 2 + 1, dim))
self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
self.dropout = nn.Dropout(emb_dropout)