Mirror of https://github.com/lucidrains/vit-pytorch.git (synced 2025-12-30 16:12:29 +00:00)

Compare commits (5 commits)
| Author | SHA1 | Date |
|---|---|---|
| | 014df1e6e4 | |
| | 680d446e46 | |
| | 3fdb8dd352 | |
| | a36546df23 | |
| | d830b05f06 | |
25 .github/workflows/python-publish.yml (vendored)
@@ -1,11 +1,16 @@
-# This workflows will upload a Python Package using Twine when a release is created
+# This workflow will upload a Python Package using Twine when a release is created
 # For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries

+# This workflow uses actions that are not certified by GitHub.
+# They are provided by a third-party and are governed by
+# separate terms of service, privacy policy, and support
+# documentation.
+
 name: Upload Python Package

 on:
   release:
-    types: [created]
+    types: [published]

 jobs:
   deploy:
@@ -21,11 +26,11 @@ jobs:
     - name: Install dependencies
       run: |
         python -m pip install --upgrade pip
-        pip install setuptools wheel twine
-    - name: Build and publish
-      env:
-        TWINE_USERNAME: ${{ secrets.PYPI_USERNAME }}
-        TWINE_PASSWORD: ${{ secrets.PYPI_PASSWORD }}
-      run: |
-        python setup.py sdist bdist_wheel
-        twine upload dist/*
+        pip install build
+    - name: Build package
+      run: python -m build
+    - name: Publish package
+      uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29
+      with:
+        user: __token__
+        password: ${{ secrets.PYPI_API_TOKEN }}
9 README.md

@@ -2020,4 +2020,13 @@ Coming from computer vision and new to transformers? Here are some resources that greatly accelerated my learning.
 }
 ```

+```bibtex
+@inproceedings{Darcet2023VisionTN,
+    title   = {Vision Transformers Need Registers},
+    author  = {Timothée Darcet and Maxime Oquab and Julien Mairal and Piotr Bojanowski},
+    year    = {2023},
+    url     = {https://api.semanticscholar.org/CorpusID:263134283}
+}
+```
+
 *I visualise a time when we will be to robots what dogs are to humans, and I’m rooting for the machines.* — Claude Shannon
4 setup.py

@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
 setup(
   name = 'vit-pytorch',
   packages = find_packages(exclude=['examples']),
-  version = '1.4.4',
+  version = '1.5.2',
   license='MIT',
   description = 'Vision Transformer (ViT) - Pytorch',
   long_description_content_type = 'text/markdown',
@@ -16,7 +16,7 @@ setup(
     'image recognition'
   ],
   install_requires=[
-    'einops>=0.6.1',
+    'einops>=0.7.0',
     'torch>=1.10',
     'torchvision'
   ],
4 vit_pytorch/cross_vit.py

@@ -115,7 +115,7 @@ class CrossTransformer(nn.Module):
         for _ in range(depth):
             self.layers.append(nn.ModuleList([
                 ProjectInOut(sm_dim, lg_dim, Attention(lg_dim, heads = heads, dim_head = dim_head, dropout = dropout)),
-                ProjectInOut(lg_dim, sm_dim, ttention(sm_dim, heads = heads, dim_head = dim_head, dropout = dropout))
+                ProjectInOut(lg_dim, sm_dim, Attention(sm_dim, heads = heads, dim_head = dim_head, dropout = dropout))
             ]))

     def forward(self, sm_tokens, lg_tokens):
@@ -173,7 +173,7 @@ class Attention(nn.Module):

         # split heads

-        q, k, v = map(lambda t: rearrange(t, 'b n (h d ) -> b h n d', h = h), (q, k, v))
+        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v))

         # scale

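For context on the line being fixed: ProjectInOut wraps an attention module so one branch's tokens can attend in the other branch's dimension. Below is a minimal sketch of such a wrapper, reconstructed for illustration only (the real definition lives in cross_vit.py and is not part of this diff):

```python
from torch import nn

# hypothetical reconstruction of cross_vit.py's ProjectInOut wrapper:
# adapt token dimension so a branch can run attention in the other
# branch's embedding space, then map the result back
class ProjectInOut(nn.Module):
    def __init__(self, dim_in, dim_out, fn):
        super().__init__()
        self.fn = fn
        self.project_in = nn.Linear(dim_in, dim_out)
        self.project_out = nn.Linear(dim_out, dim_in)

    def forward(self, x, *args, **kwargs):
        x = self.project_in(x)            # dim_in -> dim_out
        x = self.fn(x, *args, **kwargs)   # e.g. the wrapped Attention
        x = self.project_out(x)           # dim_out -> dim_in
        return x
```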
339 vit_pytorch/max_vit_with_registers.py (new file)
@@ -0,0 +1,339 @@
from functools import partial

import torch
from torch import nn, einsum
import torch.nn.functional as F
from torch.nn import Module, ModuleList, Sequential

from einops import rearrange, repeat, reduce, pack, unpack
from einops.layers.torch import Rearrange, Reduce

# helpers

def exists(val):
    return val is not None

def default(val, d):
    return val if exists(val) else d

def pack_one(x, pattern):
    return pack([x], pattern)

def unpack_one(x, ps, pattern):
    return unpack(x, ps, pattern)[0]

def cast_tuple(val, length = 1):
    return val if isinstance(val, tuple) else ((val,) * length)

# helper classes

def FeedForward(dim, mult = 4, dropout = 0.):
    inner_dim = int(dim * mult)
    return Sequential(
        nn.LayerNorm(dim),
        nn.Linear(dim, inner_dim),
        nn.GELU(),
        nn.Dropout(dropout),
        nn.Linear(inner_dim, dim),
        nn.Dropout(dropout)
    )

# MBConv

class SqueezeExcitation(Module):
    def __init__(self, dim, shrinkage_rate = 0.25):
        super().__init__()
        hidden_dim = int(dim * shrinkage_rate)

        self.gate = Sequential(
            Reduce('b c h w -> b c', 'mean'),
            nn.Linear(dim, hidden_dim, bias = False),
            nn.SiLU(),
            nn.Linear(hidden_dim, dim, bias = False),
            nn.Sigmoid(),
            Rearrange('b c -> b c 1 1')
        )

    def forward(self, x):
        return x * self.gate(x)

class MBConvResidual(Module):
    def __init__(self, fn, dropout = 0.):
        super().__init__()
        self.fn = fn
        self.dropsample = Dropsample(dropout)

    def forward(self, x):
        out = self.fn(x)
        out = self.dropsample(out)
        return out + x

class Dropsample(Module):
    def __init__(self, prob = 0):
        super().__init__()
        self.prob = prob

    def forward(self, x):
        device = x.device

        if self.prob == 0. or (not self.training):
            return x

        # per-sample drop mask; torch.empty replaces the original
        # torch.FloatTensor((shape), device = ...) call, which would have built
        # a 1-d tensor from the shape values instead of an empty (b, 1, 1, 1) tensor
        keep_mask = torch.empty((x.shape[0], 1, 1, 1), device = device).uniform_() > self.prob
        return x * keep_mask / (1 - self.prob)

def MBConv(
    dim_in,
    dim_out,
    *,
    downsample,
    expansion_rate = 4,
    shrinkage_rate = 0.25,
    dropout = 0.
):
    hidden_dim = int(expansion_rate * dim_out)
    stride = 2 if downsample else 1

    net = Sequential(
        nn.Conv2d(dim_in, hidden_dim, 1),
        nn.BatchNorm2d(hidden_dim),
        nn.GELU(),
        nn.Conv2d(hidden_dim, hidden_dim, 3, stride = stride, padding = 1, groups = hidden_dim),
        nn.BatchNorm2d(hidden_dim),
        nn.GELU(),
        SqueezeExcitation(hidden_dim, shrinkage_rate = shrinkage_rate),
        nn.Conv2d(hidden_dim, dim_out, 1),
        nn.BatchNorm2d(dim_out)
    )

    if dim_in == dim_out and not downsample:
        net = MBConvResidual(net, dropout = dropout)

    return net

# attention related classes

class Attention(Module):
    def __init__(
        self,
        dim,
        dim_head = 32,
        dropout = 0.,
        window_size = 7
    ):
        super().__init__()
        assert (dim % dim_head) == 0, 'dimension should be divisible by dimension per head'

        self.heads = dim // dim_head
        self.scale = dim_head ** -0.5

        self.norm = nn.LayerNorm(dim)
        self.to_qkv = nn.Linear(dim, dim * 3, bias = False)

        self.attend = nn.Sequential(
            nn.Softmax(dim = -1),
            nn.Dropout(dropout)
        )

        self.to_out = nn.Sequential(
            nn.Linear(dim, dim, bias = False),
            nn.Dropout(dropout)
        )

        # relative positional bias

        self.rel_pos_bias = nn.Embedding((2 * window_size - 1) ** 2, self.heads)

        pos = torch.arange(window_size)
        grid = torch.stack(torch.meshgrid(pos, pos, indexing = 'ij'))
        grid = rearrange(grid, 'c i j -> (i j) c')
        rel_pos = rearrange(grid, 'i ... -> i 1 ...') - rearrange(grid, 'j ... -> 1 j ...')
        rel_pos += window_size - 1
        rel_pos_indices = (rel_pos * torch.tensor([2 * window_size - 1, 1])).sum(dim = -1)

        self.register_buffer('rel_pos_indices', rel_pos_indices, persistent = False)

    def forward(self, x):
        device, h = x.device, self.heads

        x = self.norm(x)

        # project for queries, keys, values

        q, k, v = self.to_qkv(x).chunk(3, dim = -1)

        # split heads

        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v))

        # scale

        q = q * self.scale

        # sim

        sim = einsum('b h i d, b h j d -> b h i j', q, k)

        # add positional bias

        bias = self.rel_pos_bias(self.rel_pos_indices)
        bias = rearrange(bias, 'i j h -> h i j')
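
        # register tokens are packed to the front of the sequence and have no
        # spatial position, so the bias below is zero-padded to cover them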
        num_registers = sim.shape[-1] - bias.shape[-1]
        bias = F.pad(bias, (num_registers, 0, num_registers, 0), value = 0.)

        sim = sim + bias

        # attention

        attn = self.attend(sim)

        # aggregate

        out = einsum('b h i j, b h j d -> b h i d', attn, v)

        # combine heads out

        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)

class MaxViT(Module):
    def __init__(
        self,
        *,
        num_classes,
        dim,
        depth,
        dim_head = 32,
        dim_conv_stem = None,
        window_size = 7,
        mbconv_expansion_rate = 4,
        mbconv_shrinkage_rate = 0.25,
        dropout = 0.1,
        channels = 3,
        num_register_tokens = 4
    ):
        super().__init__()
        assert isinstance(depth, tuple), 'depth needs to be a tuple of integers indicating the number of transformer blocks at each stage'

        # convolutional stem

        dim_conv_stem = default(dim_conv_stem, dim)

        self.conv_stem = Sequential(
            nn.Conv2d(channels, dim_conv_stem, 3, stride = 2, padding = 1),
            nn.Conv2d(dim_conv_stem, dim_conv_stem, 3, padding = 1)
        )

        # variables

        num_stages = len(depth)

        dims = tuple(map(lambda i: (2 ** i) * dim, range(num_stages)))
        dims = (dim_conv_stem, *dims)
        dim_pairs = tuple(zip(dims[:-1], dims[1:]))

        self.layers = nn.ModuleList([])

        # window size

        self.window_size = window_size

        self.register_tokens = nn.ParameterList([])

        # iterate through stages

        for ind, ((layer_dim_in, layer_dim), layer_depth) in enumerate(zip(dim_pairs, depth)):
            for stage_ind in range(layer_depth):
                is_first = stage_ind == 0
                stage_dim_in = layer_dim_in if is_first else layer_dim

                conv = MBConv(
                    stage_dim_in,
                    layer_dim,
                    downsample = is_first,
                    expansion_rate = mbconv_expansion_rate,
                    shrinkage_rate = mbconv_shrinkage_rate
                )

                block_attn = Attention(dim = layer_dim, dim_head = dim_head, dropout = dropout, window_size = window_size)
                block_ff = FeedForward(dim = layer_dim, dropout = dropout)

                grid_attn = Attention(dim = layer_dim, dim_head = dim_head, dropout = dropout, window_size = window_size)
                grid_ff = FeedForward(dim = layer_dim, dropout = dropout)

                register_tokens = nn.Parameter(torch.randn(num_register_tokens, layer_dim))

                self.layers.append(ModuleList([
                    conv,
                    ModuleList([block_attn, block_ff]),
                    ModuleList([grid_attn, grid_ff])
                ]))

                self.register_tokens.append(register_tokens)

        # mlp head out

        self.mlp_head = nn.Sequential(
            Reduce('b d h w -> b d', 'mean'),
            nn.LayerNorm(dims[-1]),
            nn.Linear(dims[-1], num_classes)
        )

    def forward(self, x):
        b, w = x.shape[0], self.window_size

        x = self.conv_stem(x)

        for (conv, (block_attn, block_ff), (grid_attn, grid_ff)), register_tokens in zip(self.layers, self.register_tokens):
            x = conv(x)

            # block-like attention

            x = rearrange(x, 'b d (x w1) (y w2) -> b x y w1 w2 d', w1 = w, w2 = w)

            # prepare register tokens

            r = repeat(register_tokens, 'n d -> b x y n d', b = b, x = x.shape[1], y = x.shape[2])
            r, register_batch_ps = pack_one(r, '* n d')

            x, window_ps = pack_one(x, 'b x y * d')
            x, batch_ps = pack_one(x, '* n d')
            x, register_ps = pack([r, x], 'b * d')

            x = block_attn(x) + x
            x = block_ff(x) + x

            r, x = unpack(x, register_ps, 'b * d')

            x = unpack_one(x, batch_ps, '* n d')
            x = unpack_one(x, window_ps, 'b x y * d')
            x = rearrange(x, 'b x y w1 w2 d -> b d (x w1) (y w2)')

            r = unpack_one(r, register_batch_ps, '* n d')

            # grid-like attention

            x = rearrange(x, 'b d (w1 x) (w2 y) -> b x y w1 w2 d', w1 = w, w2 = w)

            # prepare register tokens

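            # (the registers from the block pass are averaged over windows,
            # then re-broadcast across the grid windows)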
            r = reduce(r, 'b x y n d -> b n d', 'mean')
            r = repeat(r, 'b n d -> b x y n d', x = x.shape[1], y = x.shape[2])
            r, register_batch_ps = pack_one(r, '* n d')

            x, window_ps = pack_one(x, 'b x y * d')
            x, batch_ps = pack_one(x, '* n d')
            x, register_ps = pack([r, x], 'b * d')

            x = grid_attn(x) + x

            r, x = unpack(x, register_ps, 'b * d')

            x = grid_ff(x) + x

            x = unpack_one(x, batch_ps, '* n d')
            x = unpack_one(x, window_ps, 'b x y * d')
            x = rearrange(x, 'b x y w1 w2 d -> b d (w1 x) (w2 y)')

        return self.mlp_head(x)
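The commit adds no usage snippet for this module; the following is a minimal sketch, where the hyperparameter values are illustrative assumptions rather than values from the diff:

```python
import torch
from vit_pytorch.max_vit_with_registers import MaxViT

v = MaxViT(
    num_classes = 1000,       # classifier output dimension
    dim = 96,                 # dimension of first stage, doubled each stage
    depth = (2, 2, 5, 2),     # transformer blocks per stage (must be a tuple)
    dim_head = 32,
    window_size = 7,          # block / grid attention window size
    num_register_tokens = 4
)

img = torch.randn(1, 3, 224, 224)  # 224 -> 112 after the stem, then 56 / 28 / 14 / 7
preds = v(img)                     # (1, 1000)
```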
134 vit_pytorch/simple_vit_with_register_tokens.py (new file)
@@ -0,0 +1,134 @@
"""
Vision Transformers Need Registers
https://arxiv.org/abs/2309.16588
"""

import torch
from torch import nn

from einops import rearrange, repeat, pack, unpack
from einops.layers.torch import Rearrange

# helpers

def pair(t):
    return t if isinstance(t, tuple) else (t, t)

def posemb_sincos_2d(h, w, dim, temperature: int = 10000, dtype = torch.float32):
    y, x = torch.meshgrid(torch.arange(h), torch.arange(w), indexing="ij")
    assert (dim % 4) == 0, "feature dimension must be multiple of 4 for sincos emb"
    omega = torch.arange(dim // 4) / (dim // 4 - 1)
    omega = 1.0 / (temperature ** omega)

    y = y.flatten()[:, None] * omega[None, :]
    x = x.flatten()[:, None] * omega[None, :]
    pe = torch.cat((x.sin(), x.cos(), y.sin(), y.cos()), dim=1)
    return pe.type(dtype)

# classes

class FeedForward(nn.Module):
    def __init__(self, dim, hidden_dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, dim),
        )

    def forward(self, x):
        return self.net(x)

class Attention(nn.Module):
    def __init__(self, dim, heads = 8, dim_head = 64):
        super().__init__()
        inner_dim = dim_head * heads
        self.heads = heads
        self.scale = dim_head ** -0.5
        self.norm = nn.LayerNorm(dim)

        self.attend = nn.Softmax(dim = -1)

        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
        self.to_out = nn.Linear(inner_dim, dim, bias = False)

    def forward(self, x):
        x = self.norm(x)

        qkv = self.to_qkv(x).chunk(3, dim = -1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)

        dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale

        attn = self.attend(dots)

        out = torch.matmul(attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)

class Transformer(nn.Module):
    def __init__(self, dim, depth, heads, dim_head, mlp_dim):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                Attention(dim, heads = heads, dim_head = dim_head),
                FeedForward(dim, mlp_dim)
            ]))

    def forward(self, x):
        for attn, ff in self.layers:
            x = attn(x) + x
            x = ff(x) + x
        return self.norm(x)

class SimpleViT(nn.Module):
    def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, num_register_tokens = 4, channels = 3, dim_head = 64):
        super().__init__()
        image_height, image_width = pair(image_size)
        patch_height, patch_width = pair(patch_size)

        assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'

        patch_dim = channels * patch_height * patch_width

        self.to_patch_embedding = nn.Sequential(
            Rearrange("b c (h p1) (w p2) -> b (h w) (p1 p2 c)", p1 = patch_height, p2 = patch_width),
            nn.LayerNorm(patch_dim),
            nn.Linear(patch_dim, dim),
            nn.LayerNorm(dim),
        )

        self.register_tokens = nn.Parameter(torch.randn(num_register_tokens, dim))

        self.pos_embedding = posemb_sincos_2d(
            h = image_height // patch_height,
            w = image_width // patch_width,
            dim = dim,
        )

        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim)

        self.pool = "mean"
        self.to_latent = nn.Identity()

        self.linear_head = nn.Linear(dim, num_classes)

    def forward(self, img):
        batch, device = img.shape[0], img.device

        x = self.to_patch_embedding(img)
        x += self.pos_embedding.to(device, dtype=x.dtype)

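        # repeat the learned register tokens across the batch and pack them in
        # alongside the patch tokens; they attend like normal tokens but are
        # discarded again before pooling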
        r = repeat(self.register_tokens, 'n d -> b n d', b = batch)

        x, ps = pack([x, r], 'b * d')

        x = self.transformer(x)

        x, _ = unpack(x, ps, 'b * d')

        x = x.mean(dim = 1)

        x = self.to_latent(x)
        return self.linear_head(x)
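As with the other new file, no usage snippet is part of the commit; here is a minimal sketch, with illustrative hyperparameters that are assumptions rather than values from the diff:

```python
import torch
from vit_pytorch.simple_vit_with_register_tokens import SimpleViT

v = SimpleViT(
    image_size = 256,
    patch_size = 32,
    num_classes = 1000,
    dim = 1024,
    depth = 6,
    heads = 16,
    mlp_dim = 2048,
    num_register_tokens = 4   # extra non-patch tokens, per Darcet et al. 2023
)

img = torch.randn(1, 3, 256, 256)
preds = v(img)  # (1, 1000)
```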