mirror of
https://github.com/lucidrains/vit-pytorch.git
synced 2026-05-10 01:28:40 +00:00
Compare commits
12 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
5f1a6a05e9 | ||
|
|
9a95e7904e | ||
|
|
b4853d39c2 | ||
|
|
29fbf0aff4 | ||
|
|
4b8f5bc900 | ||
|
|
f86e052c05 | ||
|
|
2fa2b62def | ||
|
|
9f87d1c43b | ||
|
|
2c6dd7010a | ||
|
|
6460119f65 | ||
|
|
4e62e5f05e | ||
|
|
b3e90a2652 |
136
README.md
136
README.md
@@ -6,6 +6,7 @@
|
||||
- [Install](#install)
|
||||
- [Usage](#usage)
|
||||
- [Parameters](#parameters)
|
||||
- [Simple ViT](#simple-vit)
|
||||
- [Distillation](#distillation)
|
||||
- [Deep ViT](#deep-vit)
|
||||
- [CaiT](#cait)
|
||||
@@ -29,6 +30,7 @@
|
||||
- [Adaptive Token Sampling](#adaptive-token-sampling)
|
||||
- [Patch Merger](#patch-merger)
|
||||
- [Vision Transformer for Small Datasets](#vision-transformer-for-small-datasets)
|
||||
- [3D Vit](#3d-vit)
|
||||
- [Parallel ViT](#parallel-vit)
|
||||
- [Learnable Memory ViT](#learnable-memory-vit)
|
||||
- [Dino](#dino)
|
||||
@@ -51,6 +53,8 @@ The official Jax repository is <a href="https://github.com/google-research/visio
|
||||
|
||||
A tensorflow2 translation also exists <a href="https://github.com/taki0112/vit-tensorflow">here</a>, created by research scientist <a href="https://github.com/taki0112">Junho Kim</a>! 🙏
|
||||
|
||||
<a href="https://github.com/conceptofmind/vit-flax">Flax translation</a> by <a href="https://github.com/conceptofmind">Enrico Shippole</a>!
|
||||
|
||||
## Install
|
||||
|
||||
```bash
|
||||
@@ -106,6 +110,33 @@ Embedding dropout rate.
|
||||
- `pool`: string, either `cls` token pooling or `mean` pooling
|
||||
|
||||
|
||||
## Simple ViT
|
||||
|
||||
<a href="https://arxiv.org/abs/2205.01580">An update</a> from some of the same authors of the original paper proposes simplifications to `ViT` that allows it to train faster and better.
|
||||
|
||||
Among these simplifications include 2d sinusoidal positional embedding, global average pooling (no CLS token), no dropout, batch sizes of 1024 rather than 4096, and use of RandAugment and MixUp augmentations. They also show that a simple linear at the end is not significantly worse than the original MLP head
|
||||
|
||||
You can use it by importing the `SimpleViT` as shown below
|
||||
|
||||
```python
|
||||
import torch
|
||||
from vit_pytorch import SimpleViT
|
||||
|
||||
v = SimpleViT(
|
||||
image_size = 256,
|
||||
patch_size = 32,
|
||||
num_classes = 1000,
|
||||
dim = 1024,
|
||||
depth = 6,
|
||||
heads = 16,
|
||||
mlp_dim = 2048
|
||||
)
|
||||
|
||||
img = torch.randn(1, 3, 256, 256)
|
||||
|
||||
preds = v(img) # (1, 1000)
|
||||
```
|
||||
|
||||
## Distillation
|
||||
|
||||
<img src="./images/distill.png" width="300px"></img>
|
||||
@@ -937,6 +968,60 @@ img = torch.randn(4, 3, 256, 256)
|
||||
tokens = spt(img) # (4, 256, 1024)
|
||||
```
|
||||
|
||||
## 3D ViT
|
||||
|
||||
By popular request, I will start extending a few of the architectures in this repository to 3D ViTs, for use with video, medical imaging, etc.
|
||||
|
||||
You will need to pass in two additional hyperparameters: (1) the number of frames `frames` and (2) patch size along the frame dimension `frame_patch_size`
|
||||
|
||||
For starters, 3D ViT
|
||||
|
||||
```python
|
||||
import torch
|
||||
from vit_pytorch.vit_3d import ViT
|
||||
|
||||
v = ViT(
|
||||
image_size = 128, # image size
|
||||
frames = 16, # number of frames
|
||||
image_patch_size = 16, # image patch size
|
||||
frame_patch_size = 2, # frame patch size
|
||||
num_classes = 1000,
|
||||
dim = 1024,
|
||||
depth = 6,
|
||||
heads = 8,
|
||||
mlp_dim = 2048,
|
||||
dropout = 0.1,
|
||||
emb_dropout = 0.1
|
||||
)
|
||||
|
||||
video = torch.randn(4, 3, 16, 128, 128) # (batch, channels, frames, height, width)
|
||||
|
||||
preds = v(video) # (4, 1000)
|
||||
```
|
||||
|
||||
3D Simple ViT
|
||||
|
||||
```python
|
||||
import torch
|
||||
from vit_pytorch.simple_vit_3d import SimpleViT
|
||||
|
||||
v = SimpleViT(
|
||||
image_size = 128, # image size
|
||||
frames = 16, # number of frames
|
||||
image_patch_size = 16, # image patch size
|
||||
frame_patch_size = 2, # frame patch size
|
||||
num_classes = 1000,
|
||||
dim = 1024,
|
||||
depth = 6,
|
||||
heads = 8,
|
||||
mlp_dim = 2048
|
||||
)
|
||||
|
||||
video = torch.randn(4, 3, 16, 128, 128) # (batch, channels, frames, height, width)
|
||||
|
||||
preds = v(video) # (4, 1000)
|
||||
```
|
||||
|
||||
## Parallel ViT
|
||||
|
||||
<img src="./images/parallel-vit.png" width="350px"></img>
|
||||
@@ -1227,6 +1312,47 @@ logits, embeddings = v(img)
|
||||
embeddings # (1, 65, 1024) - (batch x patches x model dim)
|
||||
```
|
||||
|
||||
Or say for `CrossViT`, which has a multi-scale encoder that outputs two sets of embeddings for 'large' and 'small' scales
|
||||
|
||||
```python
|
||||
import torch
|
||||
from vit_pytorch.cross_vit import CrossViT
|
||||
|
||||
v = CrossViT(
|
||||
image_size = 256,
|
||||
num_classes = 1000,
|
||||
depth = 4,
|
||||
sm_dim = 192,
|
||||
sm_patch_size = 16,
|
||||
sm_enc_depth = 2,
|
||||
sm_enc_heads = 8,
|
||||
sm_enc_mlp_dim = 2048,
|
||||
lg_dim = 384,
|
||||
lg_patch_size = 64,
|
||||
lg_enc_depth = 3,
|
||||
lg_enc_heads = 8,
|
||||
lg_enc_mlp_dim = 2048,
|
||||
cross_attn_depth = 2,
|
||||
cross_attn_heads = 8,
|
||||
dropout = 0.1,
|
||||
emb_dropout = 0.1
|
||||
)
|
||||
|
||||
# wrap the CrossViT
|
||||
|
||||
from vit_pytorch.extractor import Extractor
|
||||
v = Extractor(v, layer_name = 'multi_scale_encoder') # take embedding coming from the output of multi-scale-encoder
|
||||
|
||||
# forward pass now returns predictions and the attention maps
|
||||
|
||||
img = torch.randn(1, 3, 256, 256)
|
||||
logits, embeddings = v(img)
|
||||
|
||||
# there is one extra token due to the CLS token
|
||||
|
||||
embeddings # ((1, 257, 192), (1, 17, 384)) - (batch x patches x dimension) <- large and small scales respectively
|
||||
```
|
||||
|
||||
## Research Ideas
|
||||
|
||||
### Efficient Attention
|
||||
@@ -1669,6 +1795,16 @@ Coming from computer vision and new to transformers? Here are some resources tha
|
||||
}
|
||||
```
|
||||
|
||||
```bibtex
|
||||
@misc{Beyer2022BetterPlainViT
|
||||
title = {Better plain ViT baselines for ImageNet-1k},
|
||||
author = {Beyer, Lucas and Zhai, Xiaohua and Kolesnikov, Alexander},
|
||||
publisher = {arXiv},
|
||||
year = {2022}
|
||||
}
|
||||
|
||||
```
|
||||
|
||||
```bibtex
|
||||
@misc{vaswani2017attention,
|
||||
title = {Attention Is All You Need},
|
||||
|
||||
3
setup.py
3
setup.py
@@ -3,9 +3,10 @@ from setuptools import setup, find_packages
|
||||
setup(
|
||||
name = 'vit-pytorch',
|
||||
packages = find_packages(exclude=['examples']),
|
||||
version = '0.34.1',
|
||||
version = '0.36.2',
|
||||
license='MIT',
|
||||
description = 'Vision Transformer (ViT) - Pytorch',
|
||||
long_description_content_type = 'text/markdown',
|
||||
author = 'Phil Wang',
|
||||
author_email = 'lucidrains@gmail.com',
|
||||
url = 'https://github.com/lucidrains/vit-pytorch',
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
from vit_pytorch.vit import ViT
|
||||
from vit_pytorch.simple_vit import SimpleViT
|
||||
|
||||
from vit_pytorch.mae import MAE
|
||||
from vit_pytorch.dino import Dino
|
||||
|
||||
@@ -4,14 +4,27 @@ from torch import nn
|
||||
def exists(val):
|
||||
return val is not None
|
||||
|
||||
def identity(t):
|
||||
return t
|
||||
|
||||
def clone_and_detach(t):
|
||||
return t.clone().detach()
|
||||
|
||||
def apply_tuple_or_single(fn, val):
|
||||
if isinstance(val, tuple):
|
||||
return tuple(map(fn, val))
|
||||
return fn(val)
|
||||
|
||||
class Extractor(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
vit,
|
||||
device = None,
|
||||
layer = None,
|
||||
layer_name = 'transformer',
|
||||
layer_save_input = False,
|
||||
return_embeddings_only = False
|
||||
return_embeddings_only = False,
|
||||
detach = True
|
||||
):
|
||||
super().__init__()
|
||||
self.vit = vit
|
||||
@@ -23,17 +36,24 @@ class Extractor(nn.Module):
|
||||
self.ejected = False
|
||||
self.device = device
|
||||
|
||||
self.layer = layer
|
||||
self.layer_name = layer_name
|
||||
self.layer_save_input = layer_save_input # whether to save input or output of layer
|
||||
self.return_embeddings_only = return_embeddings_only
|
||||
|
||||
self.detach_fn = clone_and_detach if detach else identity
|
||||
|
||||
def _hook(self, _, inputs, output):
|
||||
tensor_to_save = inputs if self.layer_save_input else output
|
||||
self.latents = tensor_to_save.clone().detach()
|
||||
layer_output = inputs if self.layer_save_input else output
|
||||
self.latents = apply_tuple_or_single(self.detach_fn, layer_output)
|
||||
|
||||
def _register_hook(self):
|
||||
assert hasattr(self.vit, self.layer_name), 'layer whose output to take as embedding not found in vision transformer'
|
||||
layer = getattr(self.vit, self.layer_name)
|
||||
if not exists(self.layer):
|
||||
assert hasattr(self.vit, self.layer_name), 'layer whose output to take as embedding not found in vision transformer'
|
||||
layer = getattr(self.vit, self.layer_name)
|
||||
else:
|
||||
layer = self.layer
|
||||
|
||||
handle = layer.register_forward_hook(self._hook)
|
||||
self.hooks.append(handle)
|
||||
self.hook_registered = True
|
||||
@@ -62,7 +82,7 @@ class Extractor(nn.Module):
|
||||
pred = self.vit(img)
|
||||
|
||||
target_device = self.device if exists(self.device) else img.device
|
||||
latents = self.latents.to(target_device)
|
||||
latents = apply_tuple_or_single(lambda t: t.to(target_device), self.latents)
|
||||
|
||||
if return_embeddings_only or self.return_embeddings_only:
|
||||
return latents
|
||||
|
||||
@@ -28,7 +28,7 @@ class MAE(nn.Module):
|
||||
pixel_values_per_patch = self.patch_to_emb.weight.shape[-1]
|
||||
|
||||
# decoder parameters
|
||||
|
||||
self.decoder_dim = decoder_dim
|
||||
self.enc_to_dec = nn.Linear(encoder_dim, decoder_dim) if encoder_dim != decoder_dim else nn.Identity()
|
||||
self.mask_token = nn.Parameter(torch.randn(decoder_dim))
|
||||
self.decoder = Transformer(dim = decoder_dim, depth = decoder_depth, heads = decoder_heads, dim_head = decoder_dim_head, mlp_dim = decoder_dim * 4)
|
||||
@@ -73,7 +73,7 @@ class MAE(nn.Module):
|
||||
|
||||
# reapply decoder position embedding to unmasked tokens
|
||||
|
||||
decoder_tokens = decoder_tokens + self.decoder_pos_emb(unmasked_indices)
|
||||
unmasked_decoder_tokens = decoder_tokens + self.decoder_pos_emb(unmasked_indices)
|
||||
|
||||
# repeat mask tokens for number of masked, and add the positions using the masked indices derived above
|
||||
|
||||
@@ -81,13 +81,15 @@ class MAE(nn.Module):
|
||||
mask_tokens = mask_tokens + self.decoder_pos_emb(masked_indices)
|
||||
|
||||
# concat the masked tokens to the decoder tokens and attend with decoder
|
||||
|
||||
decoder_tokens = torch.cat((mask_tokens, decoder_tokens), dim = 1)
|
||||
|
||||
decoder_tokens = torch.zeros(batch, num_patches, self.decoder_dim, device=device)
|
||||
decoder_tokens[batch_range, unmasked_indices] = unmasked_decoder_tokens
|
||||
decoder_tokens[batch_range, masked_indices] = mask_tokens
|
||||
decoded_tokens = self.decoder(decoder_tokens)
|
||||
|
||||
# splice out the mask tokens and project to pixel values
|
||||
|
||||
mask_tokens = decoded_tokens[:, :num_masked]
|
||||
mask_tokens = decoded_tokens[batch_range, masked_indices]
|
||||
pred_pixel_values = self.to_pixels(mask_tokens)
|
||||
|
||||
# calculate reconstruction loss
|
||||
|
||||
@@ -100,12 +100,14 @@ def MBConv(
|
||||
stride = 2 if downsample else 1
|
||||
|
||||
net = nn.Sequential(
|
||||
nn.Conv2d(dim_in, dim_out, 1),
|
||||
nn.BatchNorm2d(dim_out),
|
||||
nn.SiLU(),
|
||||
nn.Conv2d(dim_out, dim_out, 3, stride = stride, padding = 1, groups = dim_out),
|
||||
SqueezeExcitation(dim_out, shrinkage_rate = shrinkage_rate),
|
||||
nn.Conv2d(dim_out, dim_out, 1),
|
||||
nn.Conv2d(dim_in, hidden_dim, 1),
|
||||
nn.BatchNorm2d(hidden_dim),
|
||||
nn.GELU(),
|
||||
nn.Conv2d(hidden_dim, hidden_dim, 3, stride = stride, padding = 1, groups = hidden_dim),
|
||||
nn.BatchNorm2d(hidden_dim),
|
||||
nn.GELU(),
|
||||
SqueezeExcitation(hidden_dim, shrinkage_rate = shrinkage_rate),
|
||||
nn.Conv2d(hidden_dim, dim_out, 1),
|
||||
nn.BatchNorm2d(dim_out)
|
||||
)
|
||||
|
||||
|
||||
116
vit_pytorch/simple_vit.py
Normal file
116
vit_pytorch/simple_vit.py
Normal file
@@ -0,0 +1,116 @@
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from einops import rearrange
|
||||
from einops.layers.torch import Rearrange
|
||||
|
||||
# helpers
|
||||
|
||||
def pair(t):
|
||||
return t if isinstance(t, tuple) else (t, t)
|
||||
|
||||
def posemb_sincos_2d(patches, temperature = 10000, dtype = torch.float32):
|
||||
_, h, w, dim, device, dtype = *patches.shape, patches.device, patches.dtype
|
||||
|
||||
y, x = torch.meshgrid(torch.arange(h, device = device), torch.arange(w, device = device), indexing = 'ij')
|
||||
assert (dim % 4) == 0, 'feature dimension must be multiple of 4 for sincos emb'
|
||||
omega = torch.arange(dim // 4, device = device) / (dim // 4 - 1)
|
||||
omega = 1. / (temperature ** omega)
|
||||
|
||||
y = y.flatten()[:, None] * omega[None, :]
|
||||
x = x.flatten()[:, None] * omega[None, :]
|
||||
pe = torch.cat((x.sin(), x.cos(), y.sin(), y.cos()), dim = 1)
|
||||
return pe.type(dtype)
|
||||
|
||||
# classes
|
||||
|
||||
class FeedForward(nn.Module):
|
||||
def __init__(self, dim, hidden_dim):
|
||||
super().__init__()
|
||||
self.net = nn.Sequential(
|
||||
nn.LayerNorm(dim),
|
||||
nn.Linear(dim, hidden_dim),
|
||||
nn.GELU(),
|
||||
nn.Linear(hidden_dim, dim),
|
||||
)
|
||||
def forward(self, x):
|
||||
return self.net(x)
|
||||
|
||||
class Attention(nn.Module):
|
||||
def __init__(self, dim, heads = 8, dim_head = 64):
|
||||
super().__init__()
|
||||
inner_dim = dim_head * heads
|
||||
self.heads = heads
|
||||
self.scale = dim_head ** -0.5
|
||||
self.norm = nn.LayerNorm(dim)
|
||||
|
||||
self.attend = nn.Softmax(dim = -1)
|
||||
|
||||
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
|
||||
self.to_out = nn.Linear(inner_dim, dim, bias = False)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.norm(x)
|
||||
|
||||
qkv = self.to_qkv(x).chunk(3, dim = -1)
|
||||
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)
|
||||
|
||||
dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
|
||||
|
||||
attn = self.attend(dots)
|
||||
|
||||
out = torch.matmul(attn, v)
|
||||
out = rearrange(out, 'b h n d -> b n (h d)')
|
||||
return self.to_out(out)
|
||||
|
||||
class Transformer(nn.Module):
|
||||
def __init__(self, dim, depth, heads, dim_head, mlp_dim):
|
||||
super().__init__()
|
||||
self.layers = nn.ModuleList([])
|
||||
for _ in range(depth):
|
||||
self.layers.append(nn.ModuleList([
|
||||
Attention(dim, heads = heads, dim_head = dim_head),
|
||||
FeedForward(dim, mlp_dim)
|
||||
]))
|
||||
def forward(self, x):
|
||||
for attn, ff in self.layers:
|
||||
x = attn(x) + x
|
||||
x = ff(x) + x
|
||||
return x
|
||||
|
||||
class SimpleViT(nn.Module):
|
||||
def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, channels = 3, dim_head = 64):
|
||||
super().__init__()
|
||||
image_height, image_width = pair(image_size)
|
||||
patch_height, patch_width = pair(patch_size)
|
||||
|
||||
assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'
|
||||
|
||||
num_patches = (image_height // patch_height) * (image_width // patch_width)
|
||||
patch_dim = channels * patch_height * patch_width
|
||||
|
||||
self.to_patch_embedding = nn.Sequential(
|
||||
Rearrange('b c (h p1) (w p2) -> b h w (p1 p2 c)', p1 = patch_height, p2 = patch_width),
|
||||
nn.Linear(patch_dim, dim),
|
||||
)
|
||||
|
||||
self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim)
|
||||
|
||||
self.to_latent = nn.Identity()
|
||||
self.linear_head = nn.Sequential(
|
||||
nn.LayerNorm(dim),
|
||||
nn.Linear(dim, num_classes)
|
||||
)
|
||||
|
||||
def forward(self, img):
|
||||
*_, h, w, dtype = *img.shape, img.dtype
|
||||
|
||||
x = self.to_patch_embedding(img)
|
||||
pe = posemb_sincos_2d(x)
|
||||
x = rearrange(x, 'b ... d -> b (...) d') + pe
|
||||
|
||||
x = self.transformer(x)
|
||||
x = x.mean(dim = 1)
|
||||
|
||||
x = self.to_latent(x)
|
||||
return self.linear_head(x)
|
||||
128
vit_pytorch/simple_vit_3d.py
Normal file
128
vit_pytorch/simple_vit_3d.py
Normal file
@@ -0,0 +1,128 @@
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import nn
|
||||
|
||||
from einops import rearrange
|
||||
from einops.layers.torch import Rearrange
|
||||
|
||||
# helpers
|
||||
|
||||
def pair(t):
|
||||
return t if isinstance(t, tuple) else (t, t)
|
||||
|
||||
def posemb_sincos_3d(patches, temperature = 10000, dtype = torch.float32):
|
||||
_, f, h, w, dim, device, dtype = *patches.shape, patches.device, patches.dtype
|
||||
|
||||
z, y, x = torch.meshgrid(
|
||||
torch.arange(f, device = device),
|
||||
torch.arange(h, device = device),
|
||||
torch.arange(w, device = device),
|
||||
indexing = 'ij')
|
||||
|
||||
fourier_dim = dim // 6
|
||||
|
||||
omega = torch.arange(fourier_dim, device = device) / (fourier_dim - 1)
|
||||
omega = 1. / (temperature ** omega)
|
||||
|
||||
z = z.flatten()[:, None] * omega[None, :]
|
||||
y = y.flatten()[:, None] * omega[None, :]
|
||||
x = x.flatten()[:, None] * omega[None, :]
|
||||
|
||||
pe = torch.cat((x.sin(), x.cos(), y.sin(), y.cos(), z.sin(), z.cos()), dim = 1)
|
||||
|
||||
pe = F.pad(pe, (0, dim - (fourier_dim * 6))) # pad if feature dimension not cleanly divisible by 6
|
||||
return pe.type(dtype)
|
||||
|
||||
# classes
|
||||
|
||||
class FeedForward(nn.Module):
|
||||
def __init__(self, dim, hidden_dim):
|
||||
super().__init__()
|
||||
self.net = nn.Sequential(
|
||||
nn.LayerNorm(dim),
|
||||
nn.Linear(dim, hidden_dim),
|
||||
nn.GELU(),
|
||||
nn.Linear(hidden_dim, dim),
|
||||
)
|
||||
def forward(self, x):
|
||||
return self.net(x)
|
||||
|
||||
class Attention(nn.Module):
|
||||
def __init__(self, dim, heads = 8, dim_head = 64):
|
||||
super().__init__()
|
||||
inner_dim = dim_head * heads
|
||||
self.heads = heads
|
||||
self.scale = dim_head ** -0.5
|
||||
self.norm = nn.LayerNorm(dim)
|
||||
|
||||
self.attend = nn.Softmax(dim = -1)
|
||||
|
||||
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
|
||||
self.to_out = nn.Linear(inner_dim, dim, bias = False)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.norm(x)
|
||||
|
||||
qkv = self.to_qkv(x).chunk(3, dim = -1)
|
||||
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)
|
||||
|
||||
dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
|
||||
|
||||
attn = self.attend(dots)
|
||||
|
||||
out = torch.matmul(attn, v)
|
||||
out = rearrange(out, 'b h n d -> b n (h d)')
|
||||
return self.to_out(out)
|
||||
|
||||
class Transformer(nn.Module):
|
||||
def __init__(self, dim, depth, heads, dim_head, mlp_dim):
|
||||
super().__init__()
|
||||
self.layers = nn.ModuleList([])
|
||||
for _ in range(depth):
|
||||
self.layers.append(nn.ModuleList([
|
||||
Attention(dim, heads = heads, dim_head = dim_head),
|
||||
FeedForward(dim, mlp_dim)
|
||||
]))
|
||||
def forward(self, x):
|
||||
for attn, ff in self.layers:
|
||||
x = attn(x) + x
|
||||
x = ff(x) + x
|
||||
return x
|
||||
|
||||
class SimpleViT(nn.Module):
|
||||
def __init__(self, *, image_size, image_patch_size, frames, frame_patch_size, num_classes, dim, depth, heads, mlp_dim, channels = 3, dim_head = 64):
|
||||
super().__init__()
|
||||
image_height, image_width = pair(image_size)
|
||||
patch_height, patch_width = pair(image_patch_size)
|
||||
|
||||
assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'
|
||||
assert frames % frame_patch_size == 0, 'Frames must be divisible by the frame patch size'
|
||||
|
||||
num_patches = (image_height // patch_height) * (image_width // patch_width) * (frames // frame_patch_size)
|
||||
patch_dim = channels * patch_height * patch_width * frame_patch_size
|
||||
|
||||
self.to_patch_embedding = nn.Sequential(
|
||||
Rearrange('b c (f pf) (h p1) (w p2) -> b f h w (p1 p2 pf c)', p1 = patch_height, p2 = patch_width, pf = frame_patch_size),
|
||||
nn.Linear(patch_dim, dim),
|
||||
)
|
||||
|
||||
self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim)
|
||||
|
||||
self.to_latent = nn.Identity()
|
||||
self.linear_head = nn.Sequential(
|
||||
nn.LayerNorm(dim),
|
||||
nn.Linear(dim, num_classes)
|
||||
)
|
||||
|
||||
def forward(self, img):
|
||||
*_, h, w, dtype = *img.shape, img.dtype
|
||||
|
||||
x = self.to_patch_embedding(img)
|
||||
pe = posemb_sincos_3d(x)
|
||||
x = rearrange(x, 'b ... d -> b (...) d') + pe
|
||||
|
||||
x = self.transformer(x)
|
||||
x = x.mean(dim = 1)
|
||||
|
||||
x = self.to_latent(x)
|
||||
return self.linear_head(x)
|
||||
@@ -114,7 +114,7 @@ class ViT(nn.Module):
|
||||
x = self.to_patch_embedding(img)
|
||||
b, n, _ = x.shape
|
||||
|
||||
cls_tokens = repeat(self.cls_token, '1 n d -> b n d', b = b)
|
||||
cls_tokens = repeat(self.cls_token, '1 1 d -> b 1 d', b = b)
|
||||
x = torch.cat((cls_tokens, x), dim=1)
|
||||
x += self.pos_embedding[:, :(n + 1)]
|
||||
x = self.dropout(x)
|
||||
|
||||
129
vit_pytorch/vit_3d.py
Normal file
129
vit_pytorch/vit_3d.py
Normal file
@@ -0,0 +1,129 @@
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from einops import rearrange, repeat
|
||||
from einops.layers.torch import Rearrange
|
||||
|
||||
# helpers
|
||||
|
||||
def pair(t):
|
||||
return t if isinstance(t, tuple) else (t, t)
|
||||
|
||||
# classes
|
||||
|
||||
class PreNorm(nn.Module):
|
||||
def __init__(self, dim, fn):
|
||||
super().__init__()
|
||||
self.norm = nn.LayerNorm(dim)
|
||||
self.fn = fn
|
||||
def forward(self, x, **kwargs):
|
||||
return self.fn(self.norm(x), **kwargs)
|
||||
|
||||
class FeedForward(nn.Module):
|
||||
def __init__(self, dim, hidden_dim, dropout = 0.):
|
||||
super().__init__()
|
||||
self.net = nn.Sequential(
|
||||
nn.Linear(dim, hidden_dim),
|
||||
nn.GELU(),
|
||||
nn.Dropout(dropout),
|
||||
nn.Linear(hidden_dim, dim),
|
||||
nn.Dropout(dropout)
|
||||
)
|
||||
def forward(self, x):
|
||||
return self.net(x)
|
||||
|
||||
class Attention(nn.Module):
|
||||
def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
|
||||
super().__init__()
|
||||
inner_dim = dim_head * heads
|
||||
project_out = not (heads == 1 and dim_head == dim)
|
||||
|
||||
self.heads = heads
|
||||
self.scale = dim_head ** -0.5
|
||||
|
||||
self.attend = nn.Softmax(dim = -1)
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
|
||||
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
|
||||
|
||||
self.to_out = nn.Sequential(
|
||||
nn.Linear(inner_dim, dim),
|
||||
nn.Dropout(dropout)
|
||||
) if project_out else nn.Identity()
|
||||
|
||||
def forward(self, x):
|
||||
qkv = self.to_qkv(x).chunk(3, dim = -1)
|
||||
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)
|
||||
|
||||
dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
|
||||
|
||||
attn = self.attend(dots)
|
||||
attn = self.dropout(attn)
|
||||
|
||||
out = torch.matmul(attn, v)
|
||||
out = rearrange(out, 'b h n d -> b n (h d)')
|
||||
return self.to_out(out)
|
||||
|
||||
class Transformer(nn.Module):
|
||||
def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
|
||||
super().__init__()
|
||||
self.layers = nn.ModuleList([])
|
||||
for _ in range(depth):
|
||||
self.layers.append(nn.ModuleList([
|
||||
PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)),
|
||||
PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout))
|
||||
]))
|
||||
def forward(self, x):
|
||||
for attn, ff in self.layers:
|
||||
x = attn(x) + x
|
||||
x = ff(x) + x
|
||||
return x
|
||||
|
||||
class ViT(nn.Module):
|
||||
def __init__(self, *, image_size, image_patch_size, frames, frame_patch_size, num_classes, dim, depth, heads, mlp_dim, pool = 'cls', channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0.):
|
||||
super().__init__()
|
||||
image_height, image_width = pair(image_size)
|
||||
patch_height, patch_width = pair(image_patch_size)
|
||||
|
||||
assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'
|
||||
assert frames % frame_patch_size == 0, 'Frames must be divisible by frame patch size'
|
||||
|
||||
num_patches = (image_height // patch_height) * (image_width // patch_width) * (frames // frame_patch_size)
|
||||
patch_dim = channels * patch_height * patch_width * frame_patch_size
|
||||
|
||||
assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'
|
||||
|
||||
self.to_patch_embedding = nn.Sequential(
|
||||
Rearrange('b c (f pf) (h p1) (w p2) -> b (f h w) (p1 p2 pf c)', p1 = patch_height, p2 = patch_width, pf = frame_patch_size),
|
||||
nn.Linear(patch_dim, dim),
|
||||
)
|
||||
|
||||
self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
|
||||
self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
|
||||
self.dropout = nn.Dropout(emb_dropout)
|
||||
|
||||
self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)
|
||||
|
||||
self.pool = pool
|
||||
self.to_latent = nn.Identity()
|
||||
|
||||
self.mlp_head = nn.Sequential(
|
||||
nn.LayerNorm(dim),
|
||||
nn.Linear(dim, num_classes)
|
||||
)
|
||||
|
||||
def forward(self, img):
|
||||
x = self.to_patch_embedding(img)
|
||||
b, n, _ = x.shape
|
||||
|
||||
cls_tokens = repeat(self.cls_token, '1 1 d -> b 1 d', b = b)
|
||||
x = torch.cat((cls_tokens, x), dim=1)
|
||||
x += self.pos_embedding[:, :(n + 1)]
|
||||
x = self.dropout(x)
|
||||
|
||||
x = self.transformer(x)
|
||||
|
||||
x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0]
|
||||
|
||||
x = self.to_latent(x)
|
||||
return self.mlp_head(x)
|
||||
Reference in New Issue
Block a user