mirror of
https://github.com/lucidrains/vit-pytorch.git
synced 2026-01-06 21:12:31 +00:00
Compare commits
10 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
6d1df1a970 | ||
|
|
d65a8c17a5 | ||
|
|
f7c164d910 | ||
|
|
c7b74e0bc3 | ||
|
|
5b5d98a3a7 | ||
|
|
b0e4790c24 | ||
|
|
0b2b3fc20c | ||
|
|
ced464dcb4 | ||
|
|
5bf45a2d4d | ||
|
|
fa32e22855 |
@@ -23,7 +23,9 @@ v = ViT(
|
||||
dim = 1024,
|
||||
depth = 6,
|
||||
heads = 8,
|
||||
mlp_dim = 2048
|
||||
mlp_dim = 2048,
|
||||
dropout = 0.1,
|
||||
emb_dropout = 0.1
|
||||
)
|
||||
|
||||
img = torch.randn(1, 3, 256, 256)
|
||||
|
||||
6253
examples/cats_and_dogs.ipynb
Normal file
6253
examples/cats_and_dogs.ipynb
Normal file
File diff suppressed because one or more lines are too long
6
setup.py
6
setup.py
@@ -2,8 +2,8 @@ from setuptools import setup, find_packages
|
||||
|
||||
setup(
|
||||
name = 'vit-pytorch',
|
||||
packages = find_packages(),
|
||||
version = '0.2.0',
|
||||
packages = find_packages(exclude=['examples']),
|
||||
version = '0.2.5',
|
||||
license='MIT',
|
||||
description = 'Vision Transformer (ViT) - Pytorch',
|
||||
author = 'Phil Wang',
|
||||
@@ -25,4 +25,4 @@ setup(
|
||||
'License :: OSI Approved :: MIT License',
|
||||
'Programming Language :: Python :: 3.6',
|
||||
],
|
||||
)
|
||||
)
|
||||
|
||||
@@ -3,6 +3,8 @@ import torch.nn.functional as F
|
||||
from einops import rearrange
|
||||
from torch import nn
|
||||
|
||||
MIN_NUM_PATCHES = 16
|
||||
|
||||
class Residual(nn.Module):
|
||||
def __init__(self, fn):
|
||||
super().__init__()
|
||||
@@ -19,28 +21,34 @@ class PreNorm(nn.Module):
|
||||
return self.fn(self.norm(x), **kwargs)
|
||||
|
||||
class FeedForward(nn.Module):
|
||||
def __init__(self, dim, hidden_dim):
|
||||
def __init__(self, dim, hidden_dim, dropout = 0.):
|
||||
super().__init__()
|
||||
self.net = nn.Sequential(
|
||||
nn.Linear(dim, hidden_dim),
|
||||
nn.GELU(),
|
||||
nn.Linear(hidden_dim, dim)
|
||||
nn.Dropout(dropout),
|
||||
nn.Linear(hidden_dim, dim),
|
||||
nn.Dropout(dropout)
|
||||
)
|
||||
def forward(self, x):
|
||||
return self.net(x)
|
||||
|
||||
class Attention(nn.Module):
|
||||
def __init__(self, dim, heads = 8):
|
||||
def __init__(self, dim, heads = 8, dropout = 0.):
|
||||
super().__init__()
|
||||
self.heads = heads
|
||||
self.scale = dim ** -0.5
|
||||
|
||||
self.to_qkv = nn.Linear(dim, dim * 3, bias = False)
|
||||
self.to_out = nn.Linear(dim, dim)
|
||||
self.to_out = nn.Sequential(
|
||||
nn.Linear(dim, dim),
|
||||
nn.Dropout(dropout)
|
||||
)
|
||||
|
||||
def forward(self, x, mask = None):
|
||||
b, n, _, h = *x.shape, self.heads
|
||||
qkv = self.to_qkv(x)
|
||||
q, k, v = rearrange(qkv, 'b n (qkv h d) -> qkv b h n d', qkv = 3, h = h)
|
||||
qkv = self.to_qkv(x).chunk(3, dim = -1)
|
||||
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv)
|
||||
|
||||
dots = torch.einsum('bhid,bhjd->bhij', q, k) * self.scale
|
||||
|
||||
@@ -59,13 +67,13 @@ class Attention(nn.Module):
|
||||
return out
|
||||
|
||||
class Transformer(nn.Module):
|
||||
def __init__(self, dim, depth, heads, mlp_dim):
|
||||
def __init__(self, dim, depth, heads, mlp_dim, dropout):
|
||||
super().__init__()
|
||||
self.layers = nn.ModuleList([])
|
||||
for _ in range(depth):
|
||||
self.layers.append(nn.ModuleList([
|
||||
Residual(PreNorm(dim, Attention(dim, heads = heads))),
|
||||
Residual(PreNorm(dim, FeedForward(dim, mlp_dim)))
|
||||
Residual(PreNorm(dim, Attention(dim, heads = heads, dropout = dropout))),
|
||||
Residual(PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout)))
|
||||
]))
|
||||
def forward(self, x, mask = None):
|
||||
for attn, ff in self.layers:
|
||||
@@ -74,18 +82,21 @@ class Transformer(nn.Module):
|
||||
return x
|
||||
|
||||
class ViT(nn.Module):
|
||||
def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, channels = 3):
|
||||
def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, channels = 3, dropout = 0., emb_dropout = 0.):
|
||||
super().__init__()
|
||||
assert image_size % patch_size == 0, 'image dimensions must be divisible by the patch size'
|
||||
num_patches = (image_size // patch_size) ** 2
|
||||
patch_dim = channels * patch_size ** 2
|
||||
assert num_patches > MIN_NUM_PATCHES, f'your number of patches ({num_patches}) is way too small for attention to be effective. try decreasing your patch size'
|
||||
|
||||
self.patch_size = patch_size
|
||||
|
||||
self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
|
||||
self.patch_to_embedding = nn.Linear(patch_dim, dim)
|
||||
self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
|
||||
self.transformer = Transformer(dim, depth, heads, mlp_dim)
|
||||
self.dropout = nn.Dropout(emb_dropout)
|
||||
|
||||
self.transformer = Transformer(dim, depth, heads, mlp_dim, dropout)
|
||||
|
||||
self.to_cls_token = nn.Identity()
|
||||
|
||||
@@ -93,6 +104,7 @@ class ViT(nn.Module):
|
||||
nn.LayerNorm(dim),
|
||||
nn.Linear(dim, mlp_dim),
|
||||
nn.GELU(),
|
||||
nn.Dropout(dropout),
|
||||
nn.Linear(mlp_dim, num_classes)
|
||||
)
|
||||
|
||||
@@ -105,6 +117,8 @@ class ViT(nn.Module):
|
||||
cls_tokens = self.cls_token.expand(img.shape[0], -1, -1)
|
||||
x = torch.cat((cls_tokens, x), dim=1)
|
||||
x += self.pos_embedding
|
||||
x = self.dropout(x)
|
||||
|
||||
x = self.transformer(x, mask)
|
||||
|
||||
x = self.to_cls_token(x[:, 0])
|
||||
|
||||
Reference in New Issue
Block a user