Compare commits

..

1 Commits
0.2.6 ... 0.2.2

4 changed files with 9 additions and 13 deletions

View File

@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
setup(
name = 'vit-pytorch',
packages = find_packages(exclude=['examples']),
version = '0.2.6',
version = '0.2.2',
license='MIT',
description = 'Vision Transformer (ViT) - Pytorch',
author = 'Phil Wang',

View File

@@ -30,11 +30,10 @@ class ViT(nn.Module):
x = rearrange(img, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = p, p2 = p)
x = self.patch_to_embedding(x)
b, n, _ = x.shape
cls_tokens = self.cls_token.expand(b, -1, -1)
cls_tokens = self.cls_token.expand(img.shape[0], -1, -1)
x = torch.cat((cls_tokens, x), dim=1)
x += self.pos_embedding[:, :(n + 1)]
x += self.pos_embedding
x = self.transformer(x)
x = self.to_cls_token(x[:, 0])

View File

@@ -3,8 +3,6 @@ import torch.nn.functional as F
from einops import rearrange
from torch import nn
MIN_NUM_PATCHES = 16
class Residual(nn.Module):
def __init__(self, fn):
super().__init__()
@@ -47,8 +45,8 @@ class Attention(nn.Module):
def forward(self, x, mask = None):
b, n, _, h = *x.shape, self.heads
qkv = self.to_qkv(x).chunk(3, dim = -1)
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv)
qkv = self.to_qkv(x)
q, k, v = rearrange(qkv, 'b n (qkv h d) -> qkv b h n d', qkv = 3, h = h)
dots = torch.einsum('bhid,bhjd->bhij', q, k) * self.scale
@@ -87,7 +85,6 @@ class ViT(nn.Module):
assert image_size % patch_size == 0, 'image dimensions must be divisible by the patch size'
num_patches = (image_size // patch_size) ** 2
patch_dim = channels * patch_size ** 2
assert num_patches > MIN_NUM_PATCHES, f'your number of patches ({num_patches}) is way too small for attention to be effective. try decreasing your patch size'
self.patch_size = patch_size
@@ -105,7 +102,8 @@ class ViT(nn.Module):
nn.Linear(dim, mlp_dim),
nn.GELU(),
nn.Dropout(dropout),
nn.Linear(mlp_dim, num_classes)
nn.Linear(mlp_dim, num_classes),
nn.Dropout(dropout)
)
def forward(self, img, mask = None):
@@ -113,11 +111,10 @@ class ViT(nn.Module):
x = rearrange(img, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = p, p2 = p)
x = self.patch_to_embedding(x)
b, n, _ = x.shape
cls_tokens = self.cls_token.expand(b, -1, -1)
cls_tokens = self.cls_token.expand(img.shape[0], -1, -1)
x = torch.cat((cls_tokens, x), dim=1)
x += self.pos_embedding[:, :(n + 1)]
x += self.pos_embedding
x = self.dropout(x)
x = self.transformer(x, mask)