Compare commits

...

1 Commits
1.4.2 ... 0.7.5

Author SHA1 Message Date
Phil Wang
be67c142aa make it so one can plug performer into t2tvit 2021-02-25 20:38:13 -08:00
2 changed files with 9 additions and 4 deletions

View File

@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
setup(
name = 'vit-pytorch',
packages = find_packages(exclude=['examples']),
version = '0.7.4',
version = '0.7.5',
license='MIT',
description = 'Vision Transformer (ViT) - Pytorch',
author = 'Phil Wang',

View File

@@ -7,11 +7,16 @@ from vit_pytorch.vit_pytorch import Transformer
from einops import rearrange, repeat
from einops.layers.torch import Rearrange
# classes
# helpers
def exists(val):
return val is not None
def conv_output_size(image_size, kernel_size, stride, padding):
return int(((image_size - kernel_size + (2 * padding)) / stride) + 1)
# classes
class RearrangeImage(nn.Module):
def forward(self, x):
return rearrange(x, 'b (h w) c -> b c h w', h = int(math.sqrt(x.shape[1])))
@@ -20,7 +25,7 @@ class RearrangeImage(nn.Module):
class T2TViT(nn.Module):
def __init__(
self, *, image_size, num_classes, dim, depth, heads, mlp_dim, pool = 'cls', channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0., t2t_layers = ((7, 4), (3, 2), (3, 2))):
self, *, image_size, num_classes, dim, depth, heads, mlp_dim, pool = 'cls', channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0., transformer = None, t2t_layers = ((7, 4), (3, 2), (3, 2))):
super().__init__()
assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'
@@ -47,7 +52,7 @@ class T2TViT(nn.Module):
self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
self.dropout = nn.Dropout(emb_dropout)
self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)
self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout) if not exists(transformer) else transformer
self.pool = pool
self.to_latent = nn.Identity()