Compare commits

..

14 Commits

Author SHA1 Message Date
Phil Wang
9e3fec2398 fix mpp 2023-06-28 08:02:43 -07:00
Phil Wang
ce4bcd08fb address https://github.com/lucidrains/vit-pytorch/issues/266 2023-05-20 08:24:49 -07:00
Phil Wang
ad4ca19775 enforce latest einops 2023-05-08 09:34:14 -07:00
Phil Wang
e1b08c15b9 fix tests 2023-03-19 10:52:47 -07:00
Phil Wang
c59843d7b8 add a version of simple vit using flash attention 2023-03-18 09:41:39 -07:00
lucidrains
9a8e509b27 separate a simple vit from mp3, so that simple vit can be used after being pretrained 2023-03-07 19:31:10 -08:00
Phil Wang
258dd8c7c6 release mp3, contributed by @Vishu26 2023-03-07 14:29:45 -08:00
Srikumar Sastry
4218556acd Add Masked Position Prediction (#260)
* Create mp3.py

* Implementation: Position Prediction as an Effective Pretraining Strategy

* Added description for Masked Position Prediction

* MP3 image added
2023-03-07 14:28:40 -08:00
Phil Wang
f621c2b041 typo 2023-03-04 20:30:02 -08:00
Phil Wang
5699ed7d13 double down on dual patch norm, fix MAE and Simmim to be compatible with dual patchnorm 2023-02-10 10:39:50 -08:00
Phil Wang
46dcaf23d8 seeing a signal with dual patchnorm in another repository, fully incorporate 2023-02-06 09:45:12 -08:00
Phil Wang
bdaf2d1491 adopt dual patchnorm paper for as many vit as applicable, release 1.0.0 2023-02-03 08:11:29 -08:00
Phil Wang
500e23105a need simple vit with patch dropout for another project 2022-12-05 10:47:36 -08:00
Phil Wang
89e1996c8b add vit with patch dropout, fully embrace structured dropout as multiple papers are now corroborating each other 2022-12-02 11:28:11 -08:00
29 changed files with 637 additions and 18 deletions

View File

@@ -27,6 +27,7 @@ jobs:
run: |
python -m pip install --upgrade pip
python -m pip install pytest
python -m pip install wheel
if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
- name: Test with pytest
run: |

View File

@@ -27,6 +27,7 @@
- [Masked Autoencoder](#masked-autoencoder)
- [Simple Masked Image Modeling](#simple-masked-image-modeling)
- [Masked Patch Prediction](#masked-patch-prediction)
- [Masked Position Prediction](#masked-position-prediction)
- [Adaptive Token Sampling](#adaptive-token-sampling)
- [Patch Merger](#patch-merger)
- [Vision Transformer for Small Datasets](#vision-transformer-for-small-datasets)
@@ -303,7 +304,7 @@ cct = CCT(
pooling_padding = 1,
num_layers = 14,
num_heads = 6,
mlp_radio = 3.,
mlp_ratio = 3.,
num_classes = 1000,
positional_embedding = 'learnable', # ['sine', 'learnable', 'none']
)
@@ -844,6 +845,44 @@ for _ in range(100):
torch.save(model.state_dict(), './pretrained-net.pt')
```
## Masked Position Prediction
<img src="./images/mp3.png" width="400px"></img>
New <a href="https://arxiv.org/abs/2207.07611">paper</a> that introduces masked position prediction pre-training criteria. This strategy is more efficient than the Masked Autoencoder strategy and has comparable performance.
```python
import torch
from vit_pytorch.mp3 import ViT, MP3
v = ViT(
num_classes = 1000,
image_size = 256,
patch_size = 8,
dim = 1024,
depth = 6,
heads = 8,
mlp_dim = 2048,
dropout = 0.1,
)
mp3 = MP3(
vit = v,
masking_ratio = 0.75
)
images = torch.randn(8, 3, 256, 256)
loss = mp3(images)
loss.backward()
# that's all!
# do the above in a for loop many times with a lot of images and your vision transformer will learn
# save your improved vision transformer
torch.save(v.state_dict(), './trained-vit.pt')
```
## Adaptive Token Sampling
<img src="./images/ats.png" width="400px"></img>
@@ -1043,7 +1082,7 @@ cct = CCT(
pooling_padding = 1,
num_layers = 14,
num_heads = 6,
mlp_radio = 3.,
mlp_ratio = 3.,
num_classes = 1000,
positional_embedding = 'learnable'
)
@@ -1873,6 +1912,28 @@ Coming from computer vision and new to transformers? Here are some resources tha
}
```
```bibtex
@article{Liu2022PatchDropoutEV,
title = {PatchDropout: Economizing Vision Transformers Using Patch Dropout},
author = {Yue Liu and Christos Matsoukas and Fredrik Strand and Hossein Azizpour and Kevin Smith},
journal = {ArXiv},
year = {2022},
volume = {abs/2208.07220}
}
```
```bibtex
@misc{https://doi.org/10.48550/arxiv.2302.01327,
doi = {10.48550/ARXIV.2302.01327},
url = {https://arxiv.org/abs/2302.01327},
author = {Kumar, Manoj and Dehghani, Mostafa and Houlsby, Neil},
title = {Dual PatchNorm},
publisher = {arXiv},
year = {2023},
copyright = {Creative Commons Attribution 4.0 International}
}
```
```bibtex
@misc{vaswani2017attention,
title = {Attention Is All You Need},
@@ -1885,12 +1946,11 @@ Coming from computer vision and new to transformers? Here are some resources tha
```
```bibtex
@article{Liu2022PatchDropoutEV,
title = {PatchDropout: Economizing Vision Transformers Using Patch Dropout},
author = {Yue Liu and Christos Matsoukas and Fredrik Strand and Hossein Azizpour and Kevin Smith},
journal = {ArXiv},
year = {2022},
volume = {abs/2208.07220}
@inproceedings{dao2022flashattention,
title = {Flash{A}ttention: Fast and Memory-Efficient Exact Attention with {IO}-Awareness},
author = {Dao, Tri and Fu, Daniel Y. and Ermon, Stefano and Rudra, Atri and R{\'e}, Christopher},
booktitle = {Advances in Neural Information Processing Systems},
year = {2022}
}
```

BIN
images/mp3.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 518 KiB

View File

@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
setup(
name = 'vit-pytorch',
packages = find_packages(exclude=['examples']),
version = '0.40.0',
version = '1.2.4',
license='MIT',
description = 'Vision Transformer (ViT) - Pytorch',
long_description_content_type = 'text/markdown',
@@ -16,7 +16,7 @@ setup(
'image recognition'
],
install_requires=[
'einops>=0.6.0',
'einops>=0.6.1',
'torch>=1.10',
'torchvision'
],

View File

@@ -1,3 +1,10 @@
import torch
from packaging import version
if version.parse(torch.__version__) >= version.parse('2.0.0'):
from einops._torch_specific import allow_ops_in_compiled_graph
allow_ops_in_compiled_graph()
from vit_pytorch.vit import ViT
from vit_pytorch.simple_vit import SimpleViT

View File

@@ -230,7 +230,9 @@ class ViT(nn.Module):
self.to_patch_embedding = nn.Sequential(
Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_height, p2 = patch_width),
nn.LayerNorm(patch_dim),
nn.Linear(patch_dim, dim),
nn.LayerNorm(dim)
)
self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))

View File

@@ -150,7 +150,9 @@ class CaiT(nn.Module):
self.to_patch_embedding = nn.Sequential(
Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_size, p2 = patch_size),
nn.LayerNorm(patch_dim),
nn.Linear(patch_dim, dim),
nn.LayerNorm(dim)
)
self.pos_embedding = nn.Parameter(torch.randn(1, num_patches, dim))

View File

@@ -186,7 +186,9 @@ class ImageEmbedder(nn.Module):
self.to_patch_embedding = nn.Sequential(
Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_size, p2 = patch_size),
nn.LayerNorm(patch_dim),
nn.Linear(patch_dim, dim),
nn.LayerNorm(dim)
)
self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))

View File

@@ -105,7 +105,9 @@ class DeepViT(nn.Module):
self.to_patch_embedding = nn.Sequential(
Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_size, p2 = patch_size),
nn.LayerNorm(patch_dim),
nn.Linear(patch_dim, dim),
nn.LayerNorm(dim)
)
self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))

View File

@@ -17,7 +17,9 @@ class ViT(nn.Module):
self.to_patch_embedding = nn.Sequential(
Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_size, p2 = patch_size),
nn.LayerNorm(patch_dim),
nn.Linear(patch_dim, dim),
nn.LayerNorm(dim)
)
self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))

View File

@@ -118,7 +118,9 @@ class ViT(nn.Module):
self.to_patch_embedding = nn.Sequential(
Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_height, p2 = patch_width),
nn.LayerNorm(patch_dim),
nn.Linear(patch_dim, dim),
nn.LayerNorm(dim)
)
self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))

View File

@@ -126,7 +126,9 @@ class LocalViT(nn.Module):
self.to_patch_embedding = nn.Sequential(
Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_size, p2 = patch_size),
nn.LayerNorm(patch_dim),
nn.Linear(patch_dim, dim),
nn.LayerNorm(dim),
)
self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))

View File

@@ -24,8 +24,11 @@ class MAE(nn.Module):
self.encoder = encoder
num_patches, encoder_dim = encoder.pos_embedding.shape[-2:]
self.to_patch, self.patch_to_emb = encoder.to_patch_embedding[:2]
pixel_values_per_patch = self.patch_to_emb.weight.shape[-1]
self.to_patch = encoder.to_patch_embedding[0]
self.patch_to_emb = nn.Sequential(*encoder.to_patch_embedding[1:])
pixel_values_per_patch = encoder.to_patch_embedding[2].weight.shape[-1]
# decoder parameters
self.decoder_dim = decoder_dim

186
vit_pytorch/mp3.py Normal file
View File

@@ -0,0 +1,186 @@
import torch
from torch import nn, einsum
import torch.nn.functional as F
from einops import rearrange, repeat
from einops.layers.torch import Rearrange
# helpers
def exists(val):
return val is not None
def default(val, d):
return val if exists(val) else d
def pair(t):
return t if isinstance(t, tuple) else (t, t)
# positional embedding
def posemb_sincos_2d(patches, temperature = 10000, dtype = torch.float32):
_, h, w, dim, device, dtype = *patches.shape, patches.device, patches.dtype
y, x = torch.meshgrid(torch.arange(h, device = device), torch.arange(w, device = device), indexing = 'ij')
assert (dim % 4) == 0, 'feature dimension must be multiple of 4 for sincos emb'
omega = torch.arange(dim // 4, device = device) / (dim // 4 - 1)
omega = 1. / (temperature ** omega)
y = y.flatten()[:, None] * omega[None, :]
x = x.flatten()[:, None] * omega[None, :]
pe = torch.cat((x.sin(), x.cos(), y.sin(), y.cos()), dim = 1)
return pe.type(dtype)
# feedforward
class FeedForward(nn.Module):
def __init__(self, dim, hidden_dim, dropout = 0.):
super().__init__()
self.net = nn.Sequential(
nn.LayerNorm(dim),
nn.Linear(dim, hidden_dim),
nn.GELU(),
nn.Dropout(dropout),
nn.Linear(hidden_dim, dim),
nn.Dropout(dropout)
)
def forward(self, x):
return self.net(x)
# (cross)attention
class Attention(nn.Module):
def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
super().__init__()
inner_dim = dim_head * heads
self.heads = heads
self.scale = dim_head ** -0.5
self.attend = nn.Softmax(dim = -1)
self.dropout = nn.Dropout(dropout)
self.norm = nn.LayerNorm(dim)
self.to_q = nn.Linear(dim, inner_dim, bias = False)
self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False)
self.to_out = nn.Sequential(
nn.Linear(inner_dim, dim),
nn.Dropout(dropout)
)
def forward(self, x, context = None):
b, n, _, h = *x.shape, self.heads
x = self.norm(x)
context = self.norm(context) if exists(context) else x
qkv = (self.to_q(x), *self.to_kv(context).chunk(2, dim = -1))
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv)
dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
attn = self.attend(dots)
attn = self.dropout(attn)
out = einsum('b h i j, b h j d -> b h i d', attn, v)
out = rearrange(out, 'b h n d -> b n (h d)')
return self.to_out(out)
class Transformer(nn.Module):
def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
super().__init__()
self.layers = nn.ModuleList([])
for _ in range(depth):
self.layers.append(nn.ModuleList([
Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout),
FeedForward(dim, mlp_dim, dropout = dropout)
]))
def forward(self, x, context = None):
for attn, ff in self.layers:
x = attn(x, context = context) + x
x = ff(x) + x
return x
class ViT(nn.Module):
def __init__(self, *, num_classes, image_size, patch_size, dim, depth, heads, mlp_dim, channels = 3, dim_head = 64, dropout = 0.):
super().__init__()
image_height, image_width = pair(image_size)
patch_height, patch_width = pair(patch_size)
assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'
num_patches = (image_height // patch_height) * (image_width // patch_width)
patch_dim = channels * patch_height * patch_width
self.dim = dim
self.num_patches = num_patches
self.to_patch_embedding = nn.Sequential(
Rearrange('b c (h p1) (w p2) -> b h w (p1 p2 c)', p1 = patch_height, p2 = patch_width),
nn.LayerNorm(patch_dim),
nn.Linear(patch_dim, dim),
nn.LayerNorm(dim),
)
self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)
self.to_latent = nn.Identity()
self.linear_head = nn.Sequential(
nn.LayerNorm(dim),
nn.Linear(dim, num_classes)
)
def forward(self, img):
*_, h, w, dtype = *img.shape, img.dtype
x = self.to_patch_embedding(img)
pe = posemb_sincos_2d(x)
x = rearrange(x, 'b ... d -> b (...) d') + pe
x = self.transformer(x)
x = x.mean(dim = 1)
x = self.to_latent(x)
return self.linear_head(x)
# Masked Position Prediction Pre-Training
class MP3(nn.Module):
def __init__(self, vit: ViT, masking_ratio):
super().__init__()
self.vit = vit
assert masking_ratio > 0 and masking_ratio < 1, 'masking ratio must be kept between 0 and 1'
self.masking_ratio = masking_ratio
dim = vit.dim
self.mlp_head = nn.Sequential(
nn.LayerNorm(dim),
nn.Linear(dim, vit.num_patches)
)
def forward(self, img):
device = img.device
tokens = self.vit.to_patch_embedding(img)
tokens = rearrange(tokens, 'b ... d -> b (...) d')
batch, num_patches, *_ = tokens.shape
# Masking
num_masked = int(self.masking_ratio * num_patches)
rand_indices = torch.rand(batch, num_patches, device = device).argsort(dim = -1)
masked_indices, unmasked_indices = rand_indices[:, :num_masked], rand_indices[:, num_masked:]
batch_range = torch.arange(batch, device = device)[:, None]
tokens_unmasked = tokens[batch_range, unmasked_indices]
attended_tokens = self.vit.transformer(tokens, tokens_unmasked)
logits = rearrange(self.mlp_head(attended_tokens), 'b n d -> (b n) d')
# Define labels
labels = repeat(torch.arange(num_patches, device = device), 'n -> (b n)', b = batch)
loss = F.cross_entropy(logits, labels)
return loss

View File

@@ -96,6 +96,9 @@ class MPP(nn.Module):
self.loss = MPPLoss(patch_size, channels, output_channel_bits,
max_pixel_val, mean, std)
# extract patching function
self.patch_to_emb = nn.Sequential(transformer.to_patch_embedding[1:])
# output transformation
self.to_bits = nn.Linear(dim, 2**(output_channel_bits * channels))
@@ -151,7 +154,7 @@ class MPP(nn.Module):
masked_input[bool_mask_replace] = self.mask_token
# linear embedding of patches
masked_input = transformer.to_patch_embedding[-1](masked_input)
masked_input = self.patch_to_emb(masked_input)
# add cls token to input sequence
b, n, _ = masked_input.shape

View File

@@ -144,7 +144,9 @@ class NesT(nn.Module):
self.to_patch_embedding = nn.Sequential(
Rearrange('b c (h p1) (w p2) -> b (p1 p2 c) h w', p1 = patch_size, p2 = patch_size),
LayerNorm(patch_dim),
nn.Conv2d(patch_dim, layer_dims[0], 1),
LayerNorm(layer_dims[0])
)
block_repeats = cast_tuple(block_repeats, num_hierarchies)

View File

@@ -18,8 +18,11 @@ class SimMIM(nn.Module):
self.encoder = encoder
num_patches, encoder_dim = encoder.pos_embedding.shape[-2:]
self.to_patch, self.patch_to_emb = encoder.to_patch_embedding[:2]
pixel_values_per_patch = self.patch_to_emb.weight.shape[-1]
self.to_patch = encoder.to_patch_embedding[0]
self.patch_to_emb = nn.Sequential(*encoder.to_patch_embedding[1:])
pixel_values_per_patch = encoder.to_patch_embedding[2].weight.shape[-1]
# simple linear head

View File

@@ -0,0 +1,176 @@
from collections import namedtuple
from packaging import version
import torch
import torch.nn.functional as F
from torch import nn
from einops import rearrange
from einops.layers.torch import Rearrange
# constants
Config = namedtuple('FlashAttentionConfig', ['enable_flash', 'enable_math', 'enable_mem_efficient'])
# helpers
def pair(t):
return t if isinstance(t, tuple) else (t, t)
def posemb_sincos_2d(patches, temperature = 10000, dtype = torch.float32):
_, h, w, dim, device, dtype = *patches.shape, patches.device, patches.dtype
y, x = torch.meshgrid(torch.arange(h, device = device), torch.arange(w, device = device), indexing = 'ij')
assert (dim % 4) == 0, 'feature dimension must be multiple of 4 for sincos emb'
omega = torch.arange(dim // 4, device = device) / (dim // 4 - 1)
omega = 1. / (temperature ** omega)
y = y.flatten()[:, None] * omega[None, :]
x = x.flatten()[:, None] * omega[None, :]
pe = torch.cat((x.sin(), x.cos(), y.sin(), y.cos()), dim = 1)
return pe.type(dtype)
# main class
class Attend(nn.Module):
def __init__(self, use_flash = False):
super().__init__()
self.use_flash = use_flash
assert not (use_flash and version.parse(torch.__version__) < version.parse('2.0.0')), 'in order to use flash attention, you must be using pytorch 2.0 or above'
# determine efficient attention configs for cuda and cpu
self.cpu_config = Config(True, True, True)
self.cuda_config = None
if not torch.cuda.is_available() or not use_flash:
return
device_properties = torch.cuda.get_device_properties(torch.device('cuda'))
if device_properties.major == 8 and device_properties.minor == 0:
self.cuda_config = Config(True, False, False)
else:
self.cuda_config = Config(False, True, True)
def flash_attn(self, q, k, v):
config = self.cuda_config if q.is_cuda else self.cpu_config
# flash attention - https://arxiv.org/abs/2205.14135
with torch.backends.cuda.sdp_kernel(**config._asdict()):
out = F.scaled_dot_product_attention(q, k, v)
return out
def forward(self, q, k, v):
n, device, scale = q.shape[-2], q.device, q.shape[-1] ** -0.5
if self.use_flash:
return self.flash_attn(q, k, v)
# similarity
sim = einsum("b h i d, b j d -> b h i j", q, k) * scale
# attention
attn = sim.softmax(dim=-1)
# aggregate values
out = einsum("b h i j, b j d -> b h i d", attn, v)
return out
# classes
class FeedForward(nn.Module):
def __init__(self, dim, hidden_dim):
super().__init__()
self.net = nn.Sequential(
nn.LayerNorm(dim),
nn.Linear(dim, hidden_dim),
nn.GELU(),
nn.Linear(hidden_dim, dim),
)
def forward(self, x):
return self.net(x)
class Attention(nn.Module):
def __init__(self, dim, heads = 8, dim_head = 64, use_flash = True):
super().__init__()
inner_dim = dim_head * heads
self.heads = heads
self.scale = dim_head ** -0.5
self.norm = nn.LayerNorm(dim)
self.attend = Attend(use_flash = use_flash)
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
self.to_out = nn.Linear(inner_dim, dim, bias = False)
def forward(self, x):
x = self.norm(x)
qkv = self.to_qkv(x).chunk(3, dim = -1)
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)
out = self.attend(q, k, v)
out = rearrange(out, 'b h n d -> b n (h d)')
return self.to_out(out)
class Transformer(nn.Module):
def __init__(self, dim, depth, heads, dim_head, mlp_dim, use_flash):
super().__init__()
self.layers = nn.ModuleList([])
for _ in range(depth):
self.layers.append(nn.ModuleList([
Attention(dim, heads = heads, dim_head = dim_head, use_flash = use_flash),
FeedForward(dim, mlp_dim)
]))
def forward(self, x):
for attn, ff in self.layers:
x = attn(x) + x
x = ff(x) + x
return x
class SimpleViT(nn.Module):
def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, channels = 3, dim_head = 64, use_flash = True):
super().__init__()
image_height, image_width = pair(image_size)
patch_height, patch_width = pair(patch_size)
assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'
num_patches = (image_height // patch_height) * (image_width // patch_width)
patch_dim = channels * patch_height * patch_width
self.to_patch_embedding = nn.Sequential(
Rearrange('b c (h p1) (w p2) -> b h w (p1 p2 c)', p1 = patch_height, p2 = patch_width),
nn.LayerNorm(patch_dim),
nn.Linear(patch_dim, dim),
nn.LayerNorm(dim),
)
self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, use_flash)
self.to_latent = nn.Identity()
self.linear_head = nn.Sequential(
nn.LayerNorm(dim),
nn.Linear(dim, num_classes)
)
def forward(self, img):
*_, h, w, dtype = *img.shape, img.dtype
x = self.to_patch_embedding(img)
pe = posemb_sincos_2d(x)
x = rearrange(x, 'b ... d -> b (...) d') + pe
x = self.transformer(x)
x = x.mean(dim = 1)
x = self.to_latent(x)
return self.linear_head(x)

View File

@@ -91,7 +91,9 @@ class SimpleViT(nn.Module):
self.to_patch_embedding = nn.Sequential(
Rearrange('b c (h p1) (w p2) -> b h w (p1 p2 c)', p1 = patch_height, p2 = patch_width),
nn.LayerNorm(patch_dim),
nn.Linear(patch_dim, dim),
nn.LayerNorm(dim),
)
self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim)

View File

@@ -85,7 +85,9 @@ class SimpleViT(nn.Module):
self.to_patch_embedding = nn.Sequential(
Rearrange('b c (n p) -> b n (p c)', p = patch_size),
nn.LayerNorm(patch_dim),
nn.Linear(patch_dim, dim),
nn.LayerNorm(dim),
)
self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim)

View File

@@ -103,7 +103,9 @@ class SimpleViT(nn.Module):
self.to_patch_embedding = nn.Sequential(
Rearrange('b c (f pf) (h p1) (w p2) -> b f h w (p1 p2 pf c)', p1 = patch_height, p2 = patch_width, pf = frame_patch_size),
nn.LayerNorm(patch_dim),
nn.Linear(patch_dim, dim),
nn.LayerNorm(dim),
)
self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim)

View File

@@ -0,0 +1,143 @@
import torch
from torch import nn
from einops import rearrange
from einops.layers.torch import Rearrange
# helpers
def pair(t):
return t if isinstance(t, tuple) else (t, t)
def posemb_sincos_2d(patches, temperature = 10000, dtype = torch.float32):
_, h, w, dim, device, dtype = *patches.shape, patches.device, patches.dtype
y, x = torch.meshgrid(torch.arange(h, device = device), torch.arange(w, device = device), indexing = 'ij')
assert (dim % 4) == 0, 'feature dimension must be multiple of 4 for sincos emb'
omega = torch.arange(dim // 4, device = device) / (dim // 4 - 1)
omega = 1. / (temperature ** omega)
y = y.flatten()[:, None] * omega[None, :]
x = x.flatten()[:, None] * omega[None, :]
pe = torch.cat((x.sin(), x.cos(), y.sin(), y.cos()), dim = 1)
return pe.type(dtype)
# patch dropout
class PatchDropout(nn.Module):
def __init__(self, prob):
super().__init__()
assert 0 <= prob < 1.
self.prob = prob
def forward(self, x):
if not self.training or self.prob == 0.:
return x
b, n, _, device = *x.shape, x.device
batch_indices = torch.arange(b, device = device)
batch_indices = rearrange(batch_indices, '... -> ... 1')
num_patches_keep = max(1, int(n * (1 - self.prob)))
patch_indices_keep = torch.randn(b, n, device = device).topk(num_patches_keep, dim = -1).indices
return x[batch_indices, patch_indices_keep]
# classes
class FeedForward(nn.Module):
def __init__(self, dim, hidden_dim):
super().__init__()
self.net = nn.Sequential(
nn.LayerNorm(dim),
nn.Linear(dim, hidden_dim),
nn.GELU(),
nn.Linear(hidden_dim, dim),
)
def forward(self, x):
return self.net(x)
class Attention(nn.Module):
def __init__(self, dim, heads = 8, dim_head = 64):
super().__init__()
inner_dim = dim_head * heads
self.heads = heads
self.scale = dim_head ** -0.5
self.norm = nn.LayerNorm(dim)
self.attend = nn.Softmax(dim = -1)
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
self.to_out = nn.Linear(inner_dim, dim, bias = False)
def forward(self, x):
x = self.norm(x)
qkv = self.to_qkv(x).chunk(3, dim = -1)
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)
dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
attn = self.attend(dots)
out = torch.matmul(attn, v)
out = rearrange(out, 'b h n d -> b n (h d)')
return self.to_out(out)
class Transformer(nn.Module):
def __init__(self, dim, depth, heads, dim_head, mlp_dim):
super().__init__()
self.layers = nn.ModuleList([])
for _ in range(depth):
self.layers.append(nn.ModuleList([
Attention(dim, heads = heads, dim_head = dim_head),
FeedForward(dim, mlp_dim)
]))
def forward(self, x):
for attn, ff in self.layers:
x = attn(x) + x
x = ff(x) + x
return x
class SimpleViT(nn.Module):
def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, channels = 3, dim_head = 64, patch_dropout = 0.5):
super().__init__()
image_height, image_width = pair(image_size)
patch_height, patch_width = pair(patch_size)
assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'
num_patches = (image_height // patch_height) * (image_width // patch_width)
patch_dim = channels * patch_height * patch_width
self.to_patch_embedding = nn.Sequential(
Rearrange('b c (h p1) (w p2) -> b h w (p1 p2 c)', p1 = patch_height, p2 = patch_width),
nn.LayerNorm(patch_dim),
nn.Linear(patch_dim, dim),
nn.LayerNorm(dim)
)
self.patch_dropout = PatchDropout(patch_dropout)
self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim)
self.to_latent = nn.Identity()
self.linear_head = nn.Sequential(
nn.LayerNorm(dim),
nn.Linear(dim, num_classes)
)
def forward(self, img):
*_, h, w, dtype = *img.shape, img.dtype
x = self.to_patch_embedding(img)
pe = posemb_sincos_2d(x)
x = rearrange(x, 'b ... d -> b (...) d') + pe
x = self.patch_dropout(x)
x = self.transformer(x)
x = x.mean(dim = 1)
x = self.to_latent(x)
return self.linear_head(x)

View File

@@ -71,7 +71,12 @@ class PatchEmbedding(nn.Module):
self.dim = dim
self.dim_out = dim_out
self.patch_size = patch_size
self.proj = nn.Conv2d(patch_size ** 2 * dim, dim_out, 1)
self.proj = nn.Sequential(
LayerNorm(patch_size ** 2 * dim),
nn.Conv2d(patch_size ** 2 * dim, dim_out, 1),
LayerNorm(dim_out)
)
def forward(self, fmap):
p = self.patch_size

View File

@@ -93,7 +93,9 @@ class ViT(nn.Module):
self.to_patch_embedding = nn.Sequential(
Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_height, p2 = patch_width),
nn.LayerNorm(patch_dim),
nn.Linear(patch_dim, dim),
nn.LayerNorm(dim),
)
self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))

View File

@@ -84,7 +84,9 @@ class ViT(nn.Module):
self.to_patch_embedding = nn.Sequential(
Rearrange('b c (n p) -> b n (p c)', p = patch_size),
nn.LayerNorm(patch_dim),
nn.Linear(patch_dim, dim),
nn.LayerNorm(dim),
)
self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))

View File

@@ -95,7 +95,9 @@ class ViT(nn.Module):
self.to_patch_embedding = nn.Sequential(
Rearrange('b c (f pf) (h p1) (w p2) -> b (f h w) (p1 p2 pf c)', p1 = patch_height, p2 = patch_width, pf = frame_patch_size),
nn.LayerNorm(patch_dim),
nn.Linear(patch_dim, dim),
nn.LayerNorm(dim),
)
self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))

View File

@@ -25,7 +25,7 @@ class PatchDropout(nn.Module):
batch_indices = torch.arange(b, device = device)
batch_indices = rearrange(batch_indices, '... -> ... 1')
num_patches_keep = max(1, int(n * self.prob))
num_patches_keep = max(1, int(n * (1 - self.prob)))
patch_indices_keep = torch.randn(b, n, device = device).topk(num_patches_keep, dim = -1).indices
return x[batch_indices, patch_indices_keep]

View File

@@ -121,7 +121,9 @@ class ViT(nn.Module):
self.to_patch_embedding = nn.Sequential(
Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_height, p2 = patch_width),
nn.LayerNorm(patch_dim),
nn.Linear(patch_dim, dim),
nn.LayerNorm(dim)
)
self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))

View File

@@ -120,7 +120,9 @@ class ViT(nn.Module):
self.to_patch_embedding = nn.Sequential(
Rearrange('b c (f pf) (h p1) (w p2) -> b f (h w) (p1 p2 pf c)', p1 = patch_height, p2 = patch_width, pf = frame_patch_size),
nn.LayerNorm(patch_dim),
nn.Linear(patch_dim, dim),
nn.LayerNorm(dim)
)
self.pos_embedding = nn.Parameter(torch.randn(1, num_frame_patches, num_image_patches, dim))
@@ -144,7 +146,7 @@ class ViT(nn.Module):
x = self.to_patch_embedding(video)
b, f, n, _ = x.shape
x = x + self.pos_embedding
x = x + self.pos_embedding[:, :f, :n]
if exists(self.spatial_cls_token):
spatial_cls_tokens = repeat(self.spatial_cls_token, '1 1 d -> b f 1 d', b = b, f = f)