Compare commits


1 Commit
1.4.4 ... 1.2.6

Author      SHA1        Message        Date
Phil Wang   57a3862b7b  release NaViT  2023-07-24 13:54:31 -07:00
30 changed files with 292 additions and 425 deletions

View File

@@ -7,7 +7,7 @@
- [Usage](#usage)
- [Parameters](#parameters)
- [Simple ViT](#simple-vit)
- [NaViT](#navit)
- [NaViT](#na-vit)
- [Distillation](#distillation)
- [Deep ViT](#deep-vit)
- [CaiT](#cait)
@@ -142,7 +142,7 @@ preds = v(img) # (1, 1000)
## NaViT
<img src="./images/navit.png" width="450px"></img>
<img src="./images/na_vit.png" width="450px"></img>
<a href="https://arxiv.org/abs/2307.06304">This paper</a> proposes to leverage the flexibility of attention and masking for variable lengthed sequences to train images of multiple resolution, packed into a single batch. They demonstrate much faster training and improved accuracies, with the only cost being extra complexity in the architecture and dataloading. They use factorized 2d positional encodings, token dropping, as well as query-key normalization.
@@ -161,8 +161,7 @@ v = NaViT(
heads = 16,
mlp_dim = 2048,
dropout = 0.1,
emb_dropout = 0.1,
token_dropout_prob = 0.1 # token dropout of 10% (keep 90% of tokens)
emb_dropout = 0.1
)
# 5 images of different resolutions - List[List[Tensor]]
@@ -179,24 +178,6 @@ preds = v(images) # (5, 1000) - 5, because 5 images of different resolution abov
```
Or, if you would rather have the framework automatically group the images into variable-length sequences that do not exceed a certain maximum length:
```python
images = [
torch.randn(3, 256, 256),
torch.randn(3, 128, 128),
torch.randn(3, 128, 256),
torch.randn(3, 256, 128),
torch.randn(3, 64, 256)
]
preds = v(
images,
group_images = True,
group_max_seq_len = 64
) # (5, 1000)
```
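Under the hood, `group_images = True` greedily packs images into groups whose combined patch count stays under `group_max_seq_len`. A condensed sketch of that logic, omitting the token-dropout adjustment that the actual `group_images_by_max_seq_len` helper shown further down in this diff applies:
```python
def group_by_max_seq_len(images, patch_size, max_seq_len = 2048):
    # greedy packing: start a new group whenever the next image would overflow max_seq_len
    groups, group, seq_len = [], [], 0
    for image in images:
        h, w = image.shape[-2:]
        image_seq_len = (h // patch_size) * (w // patch_size)
        assert image_seq_len <= max_seq_len, 'single image exceeds the maximum sequence length'
        if seq_len + image_seq_len > max_seq_len:
            groups.append(group)
            group, seq_len = [], 0
        group.append(image)
        seq_len += image_seq_len
    if len(group) > 0:
        groups.append(group)
    return groups
```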
## Distillation
<img src="./images/distill.png" width="300px"></img>

View File

@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
setup(
name = 'vit-pytorch',
packages = find_packages(exclude=['examples']),
version = '1.4.4',
version = '1.2.6',
license='MIT',
description = 'Vision Transformer (ViT) - Pytorch',
long_description_content_type = 'text/markdown',

View File

@@ -110,11 +110,18 @@ class AdaptiveTokenSampling(nn.Module):
# classes
class PreNorm(nn.Module):
def __init__(self, dim, fn):
super().__init__()
self.norm = nn.LayerNorm(dim)
self.fn = fn
def forward(self, x, **kwargs):
return self.fn(self.norm(x), **kwargs)
class FeedForward(nn.Module):
def __init__(self, dim, hidden_dim, dropout = 0.):
super().__init__()
self.net = nn.Sequential(
nn.LayerNorm(dim),
nn.Linear(dim, hidden_dim),
nn.GELU(),
nn.Dropout(dropout),
@@ -131,7 +138,6 @@ class Attention(nn.Module):
self.heads = heads
self.scale = dim_head ** -0.5
self.norm = nn.LayerNorm(dim)
self.attend = nn.Softmax(dim = -1)
self.dropout = nn.Dropout(dropout)
@@ -148,7 +154,6 @@ class Attention(nn.Module):
def forward(self, x, *, mask):
num_tokens = x.shape[1]
x = self.norm(x)
qkv = self.to_qkv(x).chunk(3, dim = -1)
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)
@@ -184,8 +189,8 @@ class Transformer(nn.Module):
self.layers = nn.ModuleList([])
for _, output_num_tokens in zip(range(depth), max_tokens_per_depth):
self.layers.append(nn.ModuleList([
Attention(dim, output_num_tokens = output_num_tokens, heads = heads, dim_head = dim_head, dropout = dropout),
FeedForward(dim, mlp_dim, dropout = dropout)
PreNorm(dim, Attention(dim, output_num_tokens = output_num_tokens, heads = heads, dim_head = dim_head, dropout = dropout)),
PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout))
]))
def forward(self, x):

View File

@@ -44,11 +44,18 @@ class LayerScale(nn.Module):
def forward(self, x, **kwargs):
return self.fn(x, **kwargs) * self.scale
class PreNorm(nn.Module):
def __init__(self, dim, fn):
super().__init__()
self.norm = nn.LayerNorm(dim)
self.fn = fn
def forward(self, x, **kwargs):
return self.fn(self.norm(x), **kwargs)
class FeedForward(nn.Module):
def __init__(self, dim, hidden_dim, dropout = 0.):
super().__init__()
self.net = nn.Sequential(
nn.LayerNorm(dim),
nn.Linear(dim, hidden_dim),
nn.GELU(),
nn.Dropout(dropout),
@@ -65,7 +72,6 @@ class Attention(nn.Module):
self.heads = heads
self.scale = dim_head ** -0.5
self.norm = nn.LayerNorm(dim)
self.to_q = nn.Linear(dim, inner_dim, bias = False)
self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False)
@@ -83,7 +89,6 @@ class Attention(nn.Module):
def forward(self, x, context = None):
b, n, _, h = *x.shape, self.heads
x = self.norm(x)
context = x if not exists(context) else torch.cat((x, context), dim = 1)
qkv = (self.to_q(x), *self.to_kv(context).chunk(2, dim = -1))
@@ -110,8 +115,8 @@ class Transformer(nn.Module):
for ind in range(depth):
self.layers.append(nn.ModuleList([
LayerScale(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout), depth = ind + 1),
LayerScale(dim, FeedForward(dim, mlp_dim, dropout = dropout), depth = ind + 1)
LayerScale(dim, PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)), depth = ind + 1),
LayerScale(dim, PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout)), depth = ind + 1)
]))
def forward(self, x, context = None):
layers = dropout_layers(self.layers, dropout = self.layer_dropout)

View File

@@ -13,13 +13,22 @@ def exists(val):
def default(val, d):
return val if exists(val) else d
# pre-layernorm
class PreNorm(nn.Module):
def __init__(self, dim, fn):
super().__init__()
self.norm = nn.LayerNorm(dim)
self.fn = fn
def forward(self, x, **kwargs):
return self.fn(self.norm(x), **kwargs)
# feedforward
class FeedForward(nn.Module):
def __init__(self, dim, hidden_dim, dropout = 0.):
super().__init__()
self.net = nn.Sequential(
nn.LayerNorm(dim),
nn.Linear(dim, hidden_dim),
nn.GELU(),
nn.Dropout(dropout),
@@ -38,7 +47,6 @@ class Attention(nn.Module):
self.heads = heads
self.scale = dim_head ** -0.5
self.norm = nn.LayerNorm(dim)
self.attend = nn.Softmax(dim = -1)
self.dropout = nn.Dropout(dropout)
@@ -52,7 +60,6 @@ class Attention(nn.Module):
def forward(self, x, context = None, kv_include_self = False):
b, n, _, h = *x.shape, self.heads
x = self.norm(x)
context = default(context, x)
if kv_include_self:
@@ -79,8 +86,8 @@ class Transformer(nn.Module):
self.norm = nn.LayerNorm(dim)
for _ in range(depth):
self.layers.append(nn.ModuleList([
Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout),
FeedForward(dim, mlp_dim, dropout = dropout)
PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)),
PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout))
]))
def forward(self, x):
@@ -114,8 +121,8 @@ class CrossTransformer(nn.Module):
self.layers = nn.ModuleList([])
for _ in range(depth):
self.layers.append(nn.ModuleList([
ProjectInOut(sm_dim, lg_dim, Attention(lg_dim, heads = heads, dim_head = dim_head, dropout = dropout)),
ProjectInOut(lg_dim, sm_dim, Attention(sm_dim, heads = heads, dim_head = dim_head, dropout = dropout))
ProjectInOut(sm_dim, lg_dim, PreNorm(lg_dim, Attention(lg_dim, heads = heads, dim_head = dim_head, dropout = dropout))),
ProjectInOut(lg_dim, sm_dim, PreNorm(sm_dim, Attention(sm_dim, heads = heads, dim_head = dim_head, dropout = dropout)))
]))
def forward(self, sm_tokens, lg_tokens):

View File

@@ -34,11 +34,19 @@ class LayerNorm(nn.Module): # layernorm, but done in the channel dimension #1
mean = torch.mean(x, dim = 1, keepdim = True)
return (x - mean) / (var + self.eps).sqrt() * self.g + self.b
class PreNorm(nn.Module):
def __init__(self, dim, fn):
super().__init__()
self.norm = LayerNorm(dim)
self.fn = fn
def forward(self, x, **kwargs):
x = self.norm(x)
return self.fn(x, **kwargs)
class FeedForward(nn.Module):
def __init__(self, dim, mult = 4, dropout = 0.):
super().__init__()
self.net = nn.Sequential(
LayerNorm(dim),
nn.Conv2d(dim, dim * mult, 1),
nn.GELU(),
nn.Dropout(dropout),
@@ -67,7 +75,6 @@ class Attention(nn.Module):
self.heads = heads
self.scale = dim_head ** -0.5
self.norm = LayerNorm(dim)
self.attend = nn.Softmax(dim = -1)
self.dropout = nn.Dropout(dropout)
@@ -82,8 +89,6 @@ class Attention(nn.Module):
def forward(self, x):
shape = x.shape
b, n, _, y, h = *shape, self.heads
x = self.norm(x)
q, k, v = (self.to_q(x), *self.to_kv(x).chunk(2, dim = 1))
q, k, v = map(lambda t: rearrange(t, 'b (h d) x y -> (b h) (x y) d', h = h), (q, k, v))
@@ -102,8 +107,8 @@ class Transformer(nn.Module):
self.layers = nn.ModuleList([])
for _ in range(depth):
self.layers.append(nn.ModuleList([
Attention(dim, proj_kernel = proj_kernel, kv_proj_stride = kv_proj_stride, heads = heads, dim_head = dim_head, dropout = dropout),
FeedForward(dim, mlp_mult, dropout = dropout)
PreNorm(dim, Attention(dim, proj_kernel = proj_kernel, kv_proj_stride = kv_proj_stride, heads = heads, dim_head = dim_head, dropout = dropout)),
PreNorm(dim, FeedForward(dim, mlp_mult, dropout = dropout))
]))
def forward(self, x):
for attn, ff in self.layers:

View File

@@ -5,11 +5,25 @@ import torch.nn.functional as F
from einops import rearrange, repeat
from einops.layers.torch import Rearrange
class Residual(nn.Module):
def __init__(self, fn):
super().__init__()
self.fn = fn
def forward(self, x, **kwargs):
return self.fn(x, **kwargs) + x
class PreNorm(nn.Module):
def __init__(self, dim, fn):
super().__init__()
self.norm = nn.LayerNorm(dim)
self.fn = fn
def forward(self, x, **kwargs):
return self.fn(self.norm(x), **kwargs)
class FeedForward(nn.Module):
def __init__(self, dim, hidden_dim, dropout = 0.):
super().__init__()
self.net = nn.Sequential(
nn.LayerNorm(dim),
nn.Linear(dim, hidden_dim),
nn.GELU(),
nn.Dropout(dropout),
@@ -26,7 +40,6 @@ class Attention(nn.Module):
self.heads = heads
self.scale = dim_head ** -0.5
self.norm = nn.LayerNorm(dim)
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
self.dropout = nn.Dropout(dropout)
@@ -46,8 +59,6 @@ class Attention(nn.Module):
def forward(self, x):
b, n, _, h = *x.shape, self.heads
x = self.norm(x)
qkv = self.to_qkv(x).chunk(3, dim = -1)
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv)
@@ -75,13 +86,13 @@ class Transformer(nn.Module):
self.layers = nn.ModuleList([])
for _ in range(depth):
self.layers.append(nn.ModuleList([
Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout),
FeedForward(dim, mlp_dim, dropout = dropout)
Residual(PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout))),
Residual(PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout)))
]))
def forward(self, x):
for attn, ff in self.layers:
x = attn(x) + x
x = ff(x) + x
x = attn(x)
x = ff(x)
return x
class DeepViT(nn.Module):

View File

@@ -26,6 +26,16 @@ class ExcludeCLS(nn.Module):
x = self.fn(x, **kwargs)
return torch.cat((cls_token, x), dim = 1)
# prenorm
class PreNorm(nn.Module):
def __init__(self, dim, fn):
super().__init__()
self.norm = nn.LayerNorm(dim)
self.fn = fn
def forward(self, x, **kwargs):
return self.fn(self.norm(x), **kwargs)
# feed forward related classes
class DepthWiseConv2d(nn.Module):
@@ -42,7 +52,6 @@ class FeedForward(nn.Module):
def __init__(self, dim, hidden_dim, dropout = 0.):
super().__init__()
self.net = nn.Sequential(
nn.LayerNorm(dim),
nn.Conv2d(dim, hidden_dim, 1),
nn.Hardswish(),
DepthWiseConv2d(hidden_dim, hidden_dim, 3, padding = 1),
@@ -68,7 +77,6 @@ class Attention(nn.Module):
self.heads = heads
self.scale = dim_head ** -0.5
self.norm = nn.LayerNorm(dim)
self.attend = nn.Softmax(dim = -1)
self.dropout = nn.Dropout(dropout)
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
@@ -80,8 +88,6 @@ class Attention(nn.Module):
def forward(self, x):
b, n, _, h = *x.shape, self.heads
x = self.norm(x)
qkv = self.to_qkv(x).chunk(3, dim = -1)
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv)
@@ -100,8 +106,8 @@ class Transformer(nn.Module):
self.layers = nn.ModuleList([])
for _ in range(depth):
self.layers.append(nn.ModuleList([
Residual(Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)),
ExcludeCLS(Residual(FeedForward(dim, mlp_dim, dropout = dropout)))
Residual(PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout))),
ExcludeCLS(Residual(PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout))))
]))
def forward(self, x):
for attn, ff in self.layers:

View File

@@ -19,20 +19,20 @@ def cast_tuple(val, length = 1):
# helper classes
class Residual(nn.Module):
class PreNormResidual(nn.Module):
def __init__(self, dim, fn):
super().__init__()
self.norm = nn.LayerNorm(dim)
self.fn = fn
def forward(self, x):
return self.fn(x) + x
return self.fn(self.norm(x)) + x
class FeedForward(nn.Module):
def __init__(self, dim, mult = 4, dropout = 0.):
super().__init__()
inner_dim = int(dim * mult)
self.net = nn.Sequential(
nn.LayerNorm(dim),
nn.Linear(dim, inner_dim),
nn.GELU(),
nn.Dropout(dropout),
@@ -132,7 +132,6 @@ class Attention(nn.Module):
self.heads = dim // dim_head
self.scale = dim_head ** -0.5
self.norm = nn.LayerNorm(dim)
self.to_qkv = nn.Linear(dim, dim * 3, bias = False)
self.attend = nn.Sequential(
@@ -161,8 +160,6 @@ class Attention(nn.Module):
def forward(self, x):
batch, height, width, window_height, window_width, _, device, h = *x.shape, x.device, self.heads
x = self.norm(x)
# flatten
x = rearrange(x, 'b x y w1 w2 d -> (b x y) (w1 w2) d')
@@ -262,13 +259,13 @@ class MaxViT(nn.Module):
shrinkage_rate = mbconv_shrinkage_rate
),
Rearrange('b d (x w1) (y w2) -> b x y w1 w2 d', w1 = w, w2 = w), # block-like attention
Residual(layer_dim, Attention(dim = layer_dim, dim_head = dim_head, dropout = dropout, window_size = w)),
Residual(layer_dim, FeedForward(dim = layer_dim, dropout = dropout)),
PreNormResidual(layer_dim, Attention(dim = layer_dim, dim_head = dim_head, dropout = dropout, window_size = w)),
PreNormResidual(layer_dim, FeedForward(dim = layer_dim, dropout = dropout)),
Rearrange('b x y w1 w2 d -> b d (x w1) (y w2)'),
Rearrange('b d (w1 x) (w2 y) -> b x y w1 w2 d', w1 = w, w2 = w), # grid-like attention
Residual(layer_dim, Attention(dim = layer_dim, dim_head = dim_head, dropout = dropout, window_size = w)),
Residual(layer_dim, FeedForward(dim = layer_dim, dropout = dropout)),
PreNormResidual(layer_dim, Attention(dim = layer_dim, dim_head = dim_head, dropout = dropout, window_size = w)),
PreNormResidual(layer_dim, FeedForward(dim = layer_dim, dropout = dropout)),
Rearrange('b x y w1 w2 d -> b d (w1 x) (w2 y)'),
)

View File

@@ -22,11 +22,20 @@ def conv_nxn_bn(inp, oup, kernel_size=3, stride=1):
# classes
class PreNorm(nn.Module):
def __init__(self, dim, fn):
super().__init__()
self.norm = nn.LayerNorm(dim)
self.fn = fn
def forward(self, x, **kwargs):
return self.fn(self.norm(x), **kwargs)
class FeedForward(nn.Module):
def __init__(self, dim, hidden_dim, dropout=0.):
super().__init__()
self.net = nn.Sequential(
nn.LayerNorm(dim),
nn.Linear(dim, hidden_dim),
nn.SiLU(),
nn.Dropout(dropout),
@@ -44,7 +53,6 @@ class Attention(nn.Module):
self.heads = heads
self.scale = dim_head ** -0.5
self.norm = nn.LayerNorm(dim)
self.attend = nn.Softmax(dim=-1)
self.dropout = nn.Dropout(dropout)
@@ -56,10 +64,9 @@ class Attention(nn.Module):
)
def forward(self, x):
x = self.norm(x)
qkv = self.to_qkv(x).chunk(3, dim=-1)
q, k, v = map(lambda t: rearrange(t, 'b p n (h d) -> b p h n d', h=self.heads), qkv)
q, k, v = map(lambda t: rearrange(
t, 'b p n (h d) -> b p h n d', h=self.heads), qkv)
dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
@@ -81,8 +88,8 @@ class Transformer(nn.Module):
self.layers = nn.ModuleList([])
for _ in range(depth):
self.layers.append(nn.ModuleList([
Attention(dim, heads, dim_head, dropout),
FeedForward(dim, mlp_dim, dropout)
PreNorm(dim, Attention(dim, heads, dim_head, dropout)),
PreNorm(dim, FeedForward(dim, mlp_dim, dropout))
]))
def forward(self, x):
@@ -160,9 +167,11 @@ class MobileViTBlock(nn.Module):
# Global representations
_, _, h, w = x.shape
x = rearrange(x, 'b d (h ph) (w pw) -> b (ph pw) (h w) d', ph=self.ph, pw=self.pw)
x = self.transformer(x)
x = rearrange(x, 'b (ph pw) (h w) d -> b d (h ph) (w pw)', h=h//self.ph, w=w//self.pw, ph=self.ph, pw=self.pw)
x = rearrange(x, 'b d (h ph) (w pw) -> b (ph pw) (h w) d',
ph=self.ph, pw=self.pw)
x = self.transformer(x)
x = rearrange(x, 'b (ph pw) (h w) d -> b d (h ph) (w pw)',
h=h//self.ph, w=w//self.pw, ph=self.ph, pw=self.pw)
# Fusion
x = self.conv3(x)

View File

@@ -1,5 +1,5 @@
from functools import partial
from typing import List, Union
from typing import List
import torch
import torch.nn.functional as F
@@ -17,58 +17,12 @@ def exists(val):
def default(val, d):
return val if exists(val) else d
def always(val):
return lambda *args: val
def pair(t):
return t if isinstance(t, tuple) else (t, t)
def divisible_by(numer, denom):
return (numer % denom) == 0
# auto grouping images
def group_images_by_max_seq_len(
images: List[Tensor],
patch_size: int,
calc_token_dropout = None,
max_seq_len = 2048
) -> List[List[Tensor]]:
calc_token_dropout = default(calc_token_dropout, always(0.))
groups = []
group = []
seq_len = 0
if isinstance(calc_token_dropout, (float, int)):
calc_token_dropout = always(calc_token_dropout)
for image in images:
assert isinstance(image, Tensor)
image_dims = image.shape[-2:]
ph, pw = map(lambda t: t // patch_size, image_dims)
image_seq_len = (ph * pw)
image_seq_len = int(image_seq_len * (1 - calc_token_dropout(*image_dims)))
assert image_seq_len <= max_seq_len, f'image with dimensions {image_dims} exceeds maximum sequence length'
if (seq_len + image_seq_len) > max_seq_len:
groups.append(group)
group = []
seq_len = 0
group.append(image)
seq_len += image_seq_len
if len(group) > 0:
groups.append(group)
return groups
# normalization
# they use layernorm without bias, something that pytorch does not offer
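A bias-free LayerNorm of that kind keeps only the learned gain and passes a fixed zero bias to `F.layer_norm`. A minimal sketch, not necessarily the exact class used in the file:
```python
import torch
from torch import nn
import torch.nn.functional as F

class LayerNormNoBias(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.gamma = nn.Parameter(torch.ones(dim))       # learned gain
        self.register_buffer('beta', torch.zeros(dim))   # fixed zero bias (not learned)

    def forward(self, x):
        return F.layer_norm(x, x.shape[-1:], self.gamma, self.beta)
```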
@@ -184,26 +138,10 @@ class Transformer(nn.Module):
return self.norm(x)
class NaViT(nn.Module):
def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0., token_dropout_prob = None):
def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0.):
super().__init__()
image_height, image_width = pair(image_size)
# what percent of tokens to dropout
# if int or float given, then assume constant dropout prob
# otherwise accept a callback that in turn calculates dropout prob from height and width
self.calc_token_dropout = None
if callable(token_dropout_prob):
self.calc_token_dropout = token_dropout_prob
elif isinstance(token_dropout_prob, (float, int)):
assert 0. < token_dropout_prob < 1.
token_dropout_prob = float(token_dropout_prob)
self.calc_token_dropout = lambda height, width: token_dropout_prob
# calculate patching related stuff
assert divisible_by(image_height, patch_size) and divisible_by(image_width, patch_size), 'Image dimensions must be divisible by the patch size.'
patch_height_dim, patch_width_dim = (image_height // patch_size), (image_width // patch_size)
@@ -245,25 +183,13 @@ class NaViT(nn.Module):
def forward(
self,
batched_images: Union[List[Tensor], List[List[Tensor]]], # assume different resolution images already grouped correctly
group_images = False,
group_max_seq_len = 2048
batched_images: List[List[Tensor]] # assume different resolution images already grouped correctly
):
p, c, device, has_token_dropout = self.patch_size, self.channels, self.device, exists(self.calc_token_dropout)
p, c, device = self.patch_size, self.channels, self.device
arange = partial(torch.arange, device = device)
pad_sequence = partial(orig_pad_sequence, batch_first = True)
# auto pack if specified
if group_images:
batched_images = group_images_by_max_seq_len(
batched_images,
patch_size = self.patch_size,
calc_token_dropout = self.calc_token_dropout,
max_seq_len = group_max_seq_len
)
# process images into variable lengthed sequences with attention mask
num_images = []
@@ -293,16 +219,6 @@ class NaViT(nn.Module):
pos = rearrange(pos, 'h w c -> (h w) c')
seq = rearrange(image, 'c (h p1) (w p2) -> (h w) (c p1 p2)', p1 = p, p2 = p)
seq_len = seq.shape[-2]
if has_token_dropout:
token_dropout = self.calc_token_dropout(*image_dims)
num_keep = max(1, int(seq_len * (1 - token_dropout)))
keep_indices = torch.randn((seq_len,), device = device).topk(num_keep, dim = -1).indices
seq = seq[keep_indices]
pos = pos[keep_indices]
image_ids = F.pad(image_ids, (0, seq.shape[-2]), value = image_id)
sequences.append(seq)
positions.append(pos)

View File

@@ -24,11 +24,19 @@ class LayerNorm(nn.Module):
mean = torch.mean(x, dim = 1, keepdim = True)
return (x - mean) / (var + self.eps).sqrt() * self.g + self.b
class PreNorm(nn.Module):
def __init__(self, dim, fn):
super().__init__()
self.norm = LayerNorm(dim)
self.fn = fn
def forward(self, x, **kwargs):
return self.fn(self.norm(x), **kwargs)
class FeedForward(nn.Module):
def __init__(self, dim, mlp_mult = 4, dropout = 0.):
super().__init__()
self.net = nn.Sequential(
LayerNorm(dim),
nn.Conv2d(dim, dim * mlp_mult, 1),
nn.GELU(),
nn.Dropout(dropout),
@@ -46,7 +54,6 @@ class Attention(nn.Module):
self.heads = heads
self.scale = dim_head ** -0.5
self.norm = LayerNorm(dim)
self.attend = nn.Softmax(dim = -1)
self.dropout = nn.Dropout(dropout)
self.to_qkv = nn.Conv2d(dim, inner_dim * 3, 1, bias = False)
@@ -59,8 +66,6 @@ class Attention(nn.Module):
def forward(self, x):
b, c, h, w, heads = *x.shape, self.heads
x = self.norm(x)
qkv = self.to_qkv(x).chunk(3, dim = 1)
q, k, v = map(lambda t: rearrange(t, 'b (h d) x y -> b h (x y) d', h = heads), qkv)
@@ -88,8 +93,8 @@ class Transformer(nn.Module):
for _ in range(depth):
self.layers.append(nn.ModuleList([
Attention(dim, heads = heads, dropout = dropout),
FeedForward(dim, mlp_mult, dropout = dropout)
PreNorm(dim, Attention(dim, heads = heads, dropout = dropout)),
PreNorm(dim, FeedForward(dim, mlp_mult, dropout = dropout))
]))
def forward(self, x):
*_, h, w = x.shape

View File

@@ -19,11 +19,18 @@ class Parallel(nn.Module):
def forward(self, x):
return sum([fn(x) for fn in self.fns])
class PreNorm(nn.Module):
def __init__(self, dim, fn):
super().__init__()
self.norm = nn.LayerNorm(dim)
self.fn = fn
def forward(self, x, **kwargs):
return self.fn(self.norm(x), **kwargs)
class FeedForward(nn.Module):
def __init__(self, dim, hidden_dim, dropout = 0.):
super().__init__()
self.net = nn.Sequential(
nn.LayerNorm(dim),
nn.Linear(dim, hidden_dim),
nn.GELU(),
nn.Dropout(dropout),
@@ -42,7 +49,6 @@ class Attention(nn.Module):
self.heads = heads
self.scale = dim_head ** -0.5
self.norm = nn.LayerNorm(dim)
self.attend = nn.Softmax(dim = -1)
self.dropout = nn.Dropout(dropout)
@@ -54,7 +60,6 @@ class Attention(nn.Module):
) if project_out else nn.Identity()
def forward(self, x):
x = self.norm(x)
qkv = self.to_qkv(x).chunk(3, dim = -1)
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)
@@ -72,8 +77,8 @@ class Transformer(nn.Module):
super().__init__()
self.layers = nn.ModuleList([])
attn_block = lambda: Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)
ff_block = lambda: FeedForward(dim, mlp_dim, dropout = dropout)
attn_block = lambda: PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout))
ff_block = lambda: PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout))
for _ in range(depth):
self.layers.append(nn.ModuleList([

View File

@@ -17,11 +17,18 @@ def conv_output_size(image_size, kernel_size, stride, padding = 0):
# classes
class PreNorm(nn.Module):
def __init__(self, dim, fn):
super().__init__()
self.norm = nn.LayerNorm(dim)
self.fn = fn
def forward(self, x, **kwargs):
return self.fn(self.norm(x), **kwargs)
class FeedForward(nn.Module):
def __init__(self, dim, hidden_dim, dropout = 0.):
super().__init__()
self.net = nn.Sequential(
nn.LayerNorm(dim),
nn.Linear(dim, hidden_dim),
nn.GELU(),
nn.Dropout(dropout),
@@ -40,7 +47,6 @@ class Attention(nn.Module):
self.heads = heads
self.scale = dim_head ** -0.5
self.norm = nn.LayerNorm(dim)
self.attend = nn.Softmax(dim = -1)
self.dropout = nn.Dropout(dropout)
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
@@ -52,8 +58,6 @@ class Attention(nn.Module):
def forward(self, x):
b, n, _, h = *x.shape, self.heads
x = self.norm(x)
qkv = self.to_qkv(x).chunk(3, dim = -1)
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv)
@@ -72,8 +76,8 @@ class Transformer(nn.Module):
self.layers = nn.ModuleList([])
for _ in range(depth):
self.layers.append(nn.ModuleList([
Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout),
FeedForward(dim, mlp_dim, dropout = dropout)
PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)),
PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout))
]))
def forward(self, x):
for attn, ff in self.layers:

View File

@@ -55,6 +55,14 @@ class DepthWiseConv2d(nn.Module):
# helper classes
class PreNorm(nn.Module):
def __init__(self, dim, fn):
super().__init__()
self.norm = nn.LayerNorm(dim)
self.fn = fn
def forward(self, x, **kwargs):
return self.fn(self.norm(x), **kwargs)
class SpatialConv(nn.Module):
def __init__(self, dim_in, dim_out, kernel, bias = False):
super().__init__()
@@ -78,7 +86,6 @@ class FeedForward(nn.Module):
def __init__(self, dim, hidden_dim, dropout = 0., use_glu = True):
super().__init__()
self.net = nn.Sequential(
nn.LayerNorm(dim),
nn.Linear(dim, hidden_dim * 2 if use_glu else hidden_dim),
GEGLU() if use_glu else nn.GELU(),
nn.Dropout(dropout),
@@ -96,7 +103,6 @@ class Attention(nn.Module):
self.heads = heads
self.scale = dim_head ** -0.5
self.norm = nn.LayerNorm(dim)
self.attend = nn.Softmax(dim = -1)
self.dropout = nn.Dropout(dropout)
@@ -115,9 +121,6 @@ class Attention(nn.Module):
b, n, _, h = *x.shape, self.heads
to_q_kwargs = {'fmap_dims': fmap_dims} if self.use_ds_conv else {}
x = self.norm(x)
q = self.to_q(x, **to_q_kwargs)
qkv = (q, *self.to_kv(x).chunk(2, dim = -1))
@@ -159,8 +162,8 @@ class Transformer(nn.Module):
self.pos_emb = AxialRotaryEmbedding(dim_head, max_freq = image_size)
for _ in range(depth):
self.layers.append(nn.ModuleList([
Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout, use_rotary = use_rotary, use_ds_conv = use_ds_conv),
FeedForward(dim, mlp_dim, dropout = dropout, use_glu = use_glu)
PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout, use_rotary = use_rotary, use_ds_conv = use_ds_conv)),
PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout, use_glu = use_glu))
]))
def forward(self, x, fmap_dims):
pos_emb = self.pos_emb(x[:, 1:])

View File

@@ -33,6 +33,15 @@ class ChanLayerNorm(nn.Module):
mean = torch.mean(x, dim = 1, keepdim = True)
return (x - mean) / (var + self.eps).sqrt() * self.g + self.b
class PreNorm(nn.Module):
def __init__(self, dim, fn):
super().__init__()
self.norm = ChanLayerNorm(dim)
self.fn = fn
def forward(self, x):
return self.fn(self.norm(x))
class Downsample(nn.Module):
def __init__(self, dim_in, dim_out):
super().__init__()
@@ -56,7 +65,6 @@ class FeedForward(nn.Module):
super().__init__()
inner_dim = dim * expansion_factor
self.net = nn.Sequential(
ChanLayerNorm(dim),
nn.Conv2d(dim, inner_dim, 1),
nn.GELU(),
nn.Dropout(dropout),
@@ -84,7 +92,6 @@ class ScalableSelfAttention(nn.Module):
self.attend = nn.Softmax(dim = -1)
self.dropout = nn.Dropout(dropout)
self.norm = ChanLayerNorm(dim)
self.to_q = nn.Conv2d(dim, dim_key * heads, 1, bias = False)
self.to_k = nn.Conv2d(dim, dim_key * heads, reduction_factor, stride = reduction_factor, bias = False)
self.to_v = nn.Conv2d(dim, dim_value * heads, reduction_factor, stride = reduction_factor, bias = False)
@@ -97,8 +104,6 @@ class ScalableSelfAttention(nn.Module):
def forward(self, x):
height, width, heads = *x.shape[-2:], self.heads
x = self.norm(x)
q, k, v = self.to_q(x), self.to_k(x), self.to_v(x)
# split out heads
@@ -140,7 +145,6 @@ class InteractiveWindowedSelfAttention(nn.Module):
self.attend = nn.Softmax(dim = -1)
self.dropout = nn.Dropout(dropout)
self.norm = ChanLayerNorm(dim)
self.local_interactive_module = nn.Conv2d(dim_value * heads, dim_value * heads, 3, padding = 1)
self.to_q = nn.Conv2d(dim, dim_key * heads, 1, bias = False)
@@ -155,8 +159,6 @@ class InteractiveWindowedSelfAttention(nn.Module):
def forward(self, x):
height, width, heads, wsz = *x.shape[-2:], self.heads, self.window_size
x = self.norm(x)
wsz_h, wsz_w = default(wsz, height), default(wsz, width)
assert (height % wsz_h) == 0 and (width % wsz_w) == 0, f'height ({height}) or width ({width}) of feature map is not divisible by the window size ({wsz_h}, {wsz_w})'
@@ -215,11 +217,11 @@ class Transformer(nn.Module):
is_first = ind == 0
self.layers.append(nn.ModuleList([
ScalableSelfAttention(dim, heads = heads, dim_key = ssa_dim_key, dim_value = ssa_dim_value, reduction_factor = ssa_reduction_factor, dropout = dropout),
FeedForward(dim, expansion_factor = ff_expansion_factor, dropout = dropout),
PreNorm(dim, ScalableSelfAttention(dim, heads = heads, dim_key = ssa_dim_key, dim_value = ssa_dim_value, reduction_factor = ssa_reduction_factor, dropout = dropout)),
PreNorm(dim, FeedForward(dim, expansion_factor = ff_expansion_factor, dropout = dropout)),
PEG(dim) if is_first else None,
FeedForward(dim, expansion_factor = ff_expansion_factor, dropout = dropout),
InteractiveWindowedSelfAttention(dim, heads = heads, dim_key = iwsa_dim_key, dim_value = iwsa_dim_value, window_size = iwsa_window_size, dropout = dropout)
PreNorm(dim, FeedForward(dim, expansion_factor = ff_expansion_factor, dropout = dropout)),
PreNorm(dim, InteractiveWindowedSelfAttention(dim, heads = heads, dim_key = iwsa_dim_key, dim_value = iwsa_dim_value, window_size = iwsa_window_size, dropout = dropout))
]))
self.norm = ChanLayerNorm(dim) if norm_output else nn.Identity()

View File

@@ -25,6 +25,15 @@ class ChanLayerNorm(nn.Module):
mean = torch.mean(x, dim = 1, keepdim = True)
return (x - mean) / (var + self.eps).sqrt() * self.g + self.b
class PreNorm(nn.Module):
def __init__(self, dim, fn):
super().__init__()
self.norm = ChanLayerNorm(dim)
self.fn = fn
def forward(self, x):
return self.fn(self.norm(x))
class OverlappingPatchEmbed(nn.Module):
def __init__(self, dim_in, dim_out, stride = 2):
super().__init__()
@@ -50,7 +59,6 @@ class FeedForward(nn.Module):
super().__init__()
inner_dim = int(dim * mult)
self.net = nn.Sequential(
ChanLayerNorm(dim),
nn.Conv2d(dim, inner_dim, 1),
nn.GELU(),
nn.Dropout(dropout),
@@ -77,8 +85,6 @@ class DSSA(nn.Module):
self.window_size = window_size
inner_dim = dim_head * heads
self.norm = ChanLayerNorm(dim)
self.attend = nn.Sequential(
nn.Softmax(dim = -1),
nn.Dropout(dropout)
@@ -132,8 +138,6 @@ class DSSA(nn.Module):
assert (height % wsz) == 0 and (width % wsz) == 0, f'height {height} and width {width} must be divisible by window size {wsz}'
num_windows = (height // wsz) * (width // wsz)
x = self.norm(x)
# fold in windows for "depthwise" attention - not sure why it is named depthwise when it is just "windowed" attention
x = rearrange(x, 'b c (h w1) (w w2) -> (b h w) c (w1 w2)', w1 = wsz, w2 = wsz)
@@ -221,8 +225,8 @@ class Transformer(nn.Module):
for ind in range(depth):
self.layers.append(nn.ModuleList([
DSSA(dim, heads = heads, dim_head = dim_head, dropout = dropout),
FeedForward(dim, mult = ff_mult, dropout = dropout),
PreNorm(dim, DSSA(dim, heads = heads, dim_head = dim_head, dropout = dropout)),
PreNorm(dim, FeedForward(dim, mult = ff_mult, dropout = dropout)),
]))
self.norm = ChanLayerNorm(dim) if norm_output else nn.Identity()

View File

@@ -64,7 +64,6 @@ class Attention(nn.Module):
class Transformer(nn.Module):
def __init__(self, dim, depth, heads, dim_head, mlp_dim):
super().__init__()
self.norm = nn.LayerNorm(dim)
self.layers = nn.ModuleList([])
for _ in range(depth):
self.layers.append(nn.ModuleList([
@@ -75,7 +74,7 @@ class Transformer(nn.Module):
for attn, ff in self.layers:
x = attn(x) + x
x = ff(x) + x
return self.norm(x)
return x
class SimpleViT(nn.Module):
def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, channels = 3, dim_head = 64):
@@ -102,10 +101,12 @@ class SimpleViT(nn.Module):
self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim)
self.pool = "mean"
self.to_latent = nn.Identity()
self.linear_head = nn.Linear(dim, num_classes)
self.linear_head = nn.Sequential(
nn.LayerNorm(dim),
nn.Linear(dim, num_classes)
)
self.pool = "mean"
def forward(self, img):
device = img.device

View File

@@ -62,7 +62,6 @@ class Attention(nn.Module):
class Transformer(nn.Module):
def __init__(self, dim, depth, heads, dim_head, mlp_dim):
super().__init__()
self.norm = nn.LayerNorm(dim)
self.layers = nn.ModuleList([])
for _ in range(depth):
self.layers.append(nn.ModuleList([
@@ -73,7 +72,7 @@ class Transformer(nn.Module):
for attn, ff in self.layers:
x = attn(x) + x
x = ff(x) + x
return self.norm(x)
return x
class SimpleViT(nn.Module):
def __init__(self, *, seq_len, patch_size, num_classes, dim, depth, heads, mlp_dim, channels = 3, dim_head = 64):
@@ -94,7 +93,10 @@ class SimpleViT(nn.Module):
self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim)
self.to_latent = nn.Identity()
self.linear_head = nn.Linear(dim, num_classes)
self.linear_head = nn.Sequential(
nn.LayerNorm(dim),
nn.Linear(dim, num_classes)
)
def forward(self, series):
*_, n, dtype = *series.shape, series.dtype

View File

@@ -77,7 +77,6 @@ class Attention(nn.Module):
class Transformer(nn.Module):
def __init__(self, dim, depth, heads, dim_head, mlp_dim):
super().__init__()
self.norm = nn.LayerNorm(dim)
self.layers = nn.ModuleList([])
for _ in range(depth):
self.layers.append(nn.ModuleList([
@@ -88,7 +87,7 @@ class Transformer(nn.Module):
for attn, ff in self.layers:
x = attn(x) + x
x = ff(x) + x
return self.norm(x)
return x
class SimpleViT(nn.Module):
def __init__(self, *, image_size, image_patch_size, frames, frame_patch_size, num_classes, dim, depth, heads, mlp_dim, channels = 3, dim_head = 64):
@@ -112,7 +111,10 @@ class SimpleViT(nn.Module):
self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim)
self.to_latent = nn.Identity()
self.linear_head = nn.Linear(dim, num_classes)
self.linear_head = nn.Sequential(
nn.LayerNorm(dim),
nn.Linear(dim, num_classes)
)
def forward(self, video):
*_, h, w, dtype = *video.shape, video.dtype

View File

@@ -87,7 +87,6 @@ class Attention(nn.Module):
class Transformer(nn.Module):
def __init__(self, dim, depth, heads, dim_head, mlp_dim):
super().__init__()
self.norm = nn.LayerNorm(dim)
self.layers = nn.ModuleList([])
for _ in range(depth):
self.layers.append(nn.ModuleList([
@@ -98,7 +97,7 @@ class Transformer(nn.Module):
for attn, ff in self.layers:
x = attn(x) + x
x = ff(x) + x
return self.norm(x)
return x
class SimpleViT(nn.Module):
def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, channels = 3, dim_head = 64, patch_dropout = 0.5):
@@ -123,7 +122,10 @@ class SimpleViT(nn.Module):
self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim)
self.to_latent = nn.Identity()
self.linear_head = nn.Linear(dim, num_classes)
self.linear_head = nn.Sequential(
nn.LayerNorm(dim),
nn.Linear(dim, num_classes)
)
def forward(self, img):
*_, h, w, dtype = *img.shape, img.dtype

View File

@@ -1,141 +0,0 @@
import torch
from torch import nn
import torch.nn.functional as F
from einops import rearrange
from einops.layers.torch import Rearrange
# helpers
def pair(t):
return t if isinstance(t, tuple) else (t, t)
def posemb_sincos_2d(h, w, dim, temperature: int = 10000, dtype = torch.float32):
y, x = torch.meshgrid(torch.arange(h), torch.arange(w), indexing="ij")
assert (dim % 4) == 0, "feature dimension must be multiple of 4 for sincos emb"
omega = torch.arange(dim // 4) / (dim // 4 - 1)
omega = 1.0 / (temperature ** omega)
y = y.flatten()[:, None] * omega[None, :]
x = x.flatten()[:, None] * omega[None, :]
pe = torch.cat((x.sin(), x.cos(), y.sin(), y.cos()), dim=1)
return pe.type(dtype)
# they use a query-key normalization that is equivalent to rms norm (no mean-centering, learned gamma), from vit 22B paper
# in latest tweet, seem to claim more stable training at higher learning rates
# unsure if this has taken off within Brain, or it has some hidden drawback
class RMSNorm(nn.Module):
def __init__(self, heads, dim):
super().__init__()
self.scale = dim ** 0.5
self.gamma = nn.Parameter(torch.ones(heads, 1, dim) / self.scale)
def forward(self, x):
normed = F.normalize(x, dim = -1)
return normed * self.scale * self.gamma
# classes
class FeedForward(nn.Module):
def __init__(self, dim, hidden_dim):
super().__init__()
self.net = nn.Sequential(
nn.LayerNorm(dim),
nn.Linear(dim, hidden_dim),
nn.GELU(),
nn.Linear(hidden_dim, dim),
)
def forward(self, x):
return self.net(x)
class Attention(nn.Module):
def __init__(self, dim, heads = 8, dim_head = 64):
super().__init__()
inner_dim = dim_head * heads
self.heads = heads
self.norm = nn.LayerNorm(dim)
self.attend = nn.Softmax(dim = -1)
self.q_norm = RMSNorm(heads, dim_head)
self.k_norm = RMSNorm(heads, dim_head)
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
self.to_out = nn.Linear(inner_dim, dim, bias = False)
def forward(self, x):
x = self.norm(x)
qkv = self.to_qkv(x).chunk(3, dim = -1)
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)
q = self.q_norm(q)
k = self.k_norm(k)
dots = torch.matmul(q, k.transpose(-1, -2))
attn = self.attend(dots)
out = torch.matmul(attn, v)
out = rearrange(out, 'b h n d -> b n (h d)')
return self.to_out(out)
class Transformer(nn.Module):
def __init__(self, dim, depth, heads, dim_head, mlp_dim):
super().__init__()
self.norm = nn.LayerNorm(dim)
self.layers = nn.ModuleList([])
for _ in range(depth):
self.layers.append(nn.ModuleList([
Attention(dim, heads = heads, dim_head = dim_head),
FeedForward(dim, mlp_dim)
]))
def forward(self, x):
for attn, ff in self.layers:
x = attn(x) + x
x = ff(x) + x
return self.norm(x)
class SimpleViT(nn.Module):
def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, channels = 3, dim_head = 64):
super().__init__()
image_height, image_width = pair(image_size)
patch_height, patch_width = pair(patch_size)
assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'
patch_dim = channels * patch_height * patch_width
self.to_patch_embedding = nn.Sequential(
Rearrange("b c (h p1) (w p2) -> b (h w) (p1 p2 c)", p1 = patch_height, p2 = patch_width),
nn.LayerNorm(patch_dim),
nn.Linear(patch_dim, dim),
nn.LayerNorm(dim),
)
self.pos_embedding = posemb_sincos_2d(
h = image_height // patch_height,
w = image_width // patch_width,
dim = dim,
)
self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim)
self.pool = "mean"
self.to_latent = nn.Identity()
self.linear_head = nn.LayerNorm(dim)
def forward(self, img):
device = img.device
x = self.to_patch_embedding(img)
x += self.pos_embedding.to(device, dtype=x.dtype)
x = self.transformer(x)
x = x.mean(dim = 1)
x = self.to_latent(x)
return self.linear_head(x)

View File

@@ -42,11 +42,20 @@ class LayerNorm(nn.Module):
mean = torch.mean(x, dim = 1, keepdim = True)
return (x - mean) / (var + self.eps).sqrt() * self.g + self.b
class PreNorm(nn.Module):
def __init__(self, dim, fn):
super().__init__()
self.norm = LayerNorm(dim)
self.fn = fn
def forward(self, x, **kwargs):
x = self.norm(x)
return self.fn(x, **kwargs)
class FeedForward(nn.Module):
def __init__(self, dim, mult = 4, dropout = 0.):
super().__init__()
self.net = nn.Sequential(
LayerNorm(dim),
nn.Conv2d(dim, dim * mult, 1),
nn.GELU(),
nn.Dropout(dropout),
@@ -90,7 +99,6 @@ class LocalAttention(nn.Module):
self.heads = heads
self.scale = dim_head ** -0.5
self.norm = LayerNorm(dim)
self.to_q = nn.Conv2d(dim, inner_dim, 1, bias = False)
self.to_kv = nn.Conv2d(dim, inner_dim * 2, 1, bias = False)
@@ -100,8 +108,6 @@ class LocalAttention(nn.Module):
)
def forward(self, fmap):
fmap = self.norm(fmap)
shape, p = fmap.shape, self.patch_size
b, n, x, y, h = *shape, self.heads
x, y = map(lambda t: t // p, (x, y))
@@ -126,8 +132,6 @@ class GlobalAttention(nn.Module):
self.heads = heads
self.scale = dim_head ** -0.5
self.norm = LayerNorm(dim)
self.to_q = nn.Conv2d(dim, inner_dim, 1, bias = False)
self.to_kv = nn.Conv2d(dim, inner_dim * 2, k, stride = k, bias = False)
@@ -139,8 +143,6 @@ class GlobalAttention(nn.Module):
)
def forward(self, x):
x = self.norm(x)
shape = x.shape
b, n, _, y, h = *shape, self.heads
q, k, v = (self.to_q(x), *self.to_kv(x).chunk(2, dim = 1))
@@ -162,10 +164,10 @@ class Transformer(nn.Module):
self.layers = nn.ModuleList([])
for _ in range(depth):
self.layers.append(nn.ModuleList([
Residual(LocalAttention(dim, heads = heads, dim_head = dim_head, dropout = dropout, patch_size = local_patch_size)) if has_local else nn.Identity(),
Residual(FeedForward(dim, mlp_mult, dropout = dropout)) if has_local else nn.Identity(),
Residual(GlobalAttention(dim, heads = heads, dim_head = dim_head, dropout = dropout, k = global_k)),
Residual(FeedForward(dim, mlp_mult, dropout = dropout))
Residual(PreNorm(dim, LocalAttention(dim, heads = heads, dim_head = dim_head, dropout = dropout, patch_size = local_patch_size))) if has_local else nn.Identity(),
Residual(PreNorm(dim, FeedForward(dim, mlp_mult, dropout = dropout))) if has_local else nn.Identity(),
Residual(PreNorm(dim, GlobalAttention(dim, heads = heads, dim_head = dim_head, dropout = dropout, k = global_k))),
Residual(PreNorm(dim, FeedForward(dim, mlp_mult, dropout = dropout)))
]))
def forward(self, x):
for local_attn, ff1, global_attn, ff2 in self.layers:

View File

@@ -11,18 +11,24 @@ def pair(t):
# classes
class PreNorm(nn.Module):
def __init__(self, dim, fn):
super().__init__()
self.norm = nn.LayerNorm(dim)
self.fn = fn
def forward(self, x, **kwargs):
return self.fn(self.norm(x), **kwargs)
class FeedForward(nn.Module):
def __init__(self, dim, hidden_dim, dropout = 0.):
super().__init__()
self.net = nn.Sequential(
nn.LayerNorm(dim),
nn.Linear(dim, hidden_dim),
nn.GELU(),
nn.Dropout(dropout),
nn.Linear(hidden_dim, dim),
nn.Dropout(dropout)
)
def forward(self, x):
return self.net(x)
@@ -35,8 +41,6 @@ class Attention(nn.Module):
self.heads = heads
self.scale = dim_head ** -0.5
self.norm = nn.LayerNorm(dim)
self.attend = nn.Softmax(dim = -1)
self.dropout = nn.Dropout(dropout)
@@ -48,8 +52,6 @@ class Attention(nn.Module):
) if project_out else nn.Identity()
def forward(self, x):
x = self.norm(x)
qkv = self.to_qkv(x).chunk(3, dim = -1)
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)
@@ -65,20 +67,17 @@ class Attention(nn.Module):
class Transformer(nn.Module):
def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
super().__init__()
self.norm = nn.LayerNorm(dim)
self.layers = nn.ModuleList([])
for _ in range(depth):
self.layers.append(nn.ModuleList([
Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout),
FeedForward(dim, mlp_dim, dropout = dropout)
PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)),
PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout))
]))
def forward(self, x):
for attn, ff in self.layers:
x = attn(x) + x
x = ff(x) + x
return self.norm(x)
return x
class ViT(nn.Module):
def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, pool = 'cls', channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0.):
@@ -108,7 +107,10 @@ class ViT(nn.Module):
self.pool = pool
self.to_latent = nn.Identity()
self.mlp_head = nn.Linear(dim, num_classes)
self.mlp_head = nn.Sequential(
nn.LayerNorm(dim),
nn.Linear(dim, num_classes)
)
def forward(self, img):
x = self.to_patch_embedding(img)

View File

@@ -6,11 +6,18 @@ from einops.layers.torch import Rearrange
# classes
class PreNorm(nn.Module):
def __init__(self, dim, fn):
super().__init__()
self.norm = nn.LayerNorm(dim)
self.fn = fn
def forward(self, x, **kwargs):
return self.fn(self.norm(x), **kwargs)
class FeedForward(nn.Module):
def __init__(self, dim, hidden_dim, dropout = 0.):
super().__init__()
self.net = nn.Sequential(
nn.LayerNorm(dim),
nn.Linear(dim, hidden_dim),
nn.GELU(),
nn.Dropout(dropout),
@@ -29,7 +36,6 @@ class Attention(nn.Module):
self.heads = heads
self.scale = dim_head ** -0.5
self.norm = nn.LayerNorm(dim)
self.attend = nn.Softmax(dim = -1)
self.dropout = nn.Dropout(dropout)
@@ -41,7 +47,6 @@ class Attention(nn.Module):
) if project_out else nn.Identity()
def forward(self, x):
x = self.norm(x)
qkv = self.to_qkv(x).chunk(3, dim = -1)
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)
@@ -60,8 +65,8 @@ class Transformer(nn.Module):
self.layers = nn.ModuleList([])
for _ in range(depth):
self.layers.append(nn.ModuleList([
Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout),
FeedForward(dim, mlp_dim, dropout = dropout)
PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)),
PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout))
]))
def forward(self, x):
for attn, ff in self.layers:

View File

@@ -11,11 +11,18 @@ def pair(t):
# classes
class PreNorm(nn.Module):
def __init__(self, dim, fn):
super().__init__()
self.norm = nn.LayerNorm(dim)
self.fn = fn
def forward(self, x, **kwargs):
return self.fn(self.norm(x), **kwargs)
class FeedForward(nn.Module):
def __init__(self, dim, hidden_dim, dropout = 0.):
super().__init__()
self.net = nn.Sequential(
nn.LayerNorm(dim),
nn.Linear(dim, hidden_dim),
nn.GELU(),
nn.Dropout(dropout),
@@ -34,7 +41,6 @@ class Attention(nn.Module):
self.heads = heads
self.scale = dim_head ** -0.5
self.norm = nn.LayerNorm(dim)
self.attend = nn.Softmax(dim = -1)
self.dropout = nn.Dropout(dropout)
@@ -46,7 +52,6 @@ class Attention(nn.Module):
) if project_out else nn.Identity()
def forward(self, x):
x = self.norm(x)
qkv = self.to_qkv(x).chunk(3, dim = -1)
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)
@@ -65,8 +70,8 @@ class Transformer(nn.Module):
self.layers = nn.ModuleList([])
for _ in range(depth):
self.layers.append(nn.ModuleList([
Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout),
FeedForward(dim, mlp_dim, dropout = dropout)
PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)),
PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout))
]))
def forward(self, x):
for attn, ff in self.layers:

View File

@@ -13,11 +13,18 @@ def pair(t):
# classes
class PreNorm(nn.Module):
def __init__(self, dim, fn):
super().__init__()
self.norm = nn.LayerNorm(dim)
self.fn = fn
def forward(self, x, **kwargs):
return self.fn(self.norm(x), **kwargs)
class FeedForward(nn.Module):
def __init__(self, dim, hidden_dim, dropout = 0.):
super().__init__()
self.net = nn.Sequential(
nn.LayerNorm(dim),
nn.Linear(dim, hidden_dim),
nn.GELU(),
nn.Dropout(dropout),
@@ -34,7 +41,6 @@ class LSA(nn.Module):
self.heads = heads
self.temperature = nn.Parameter(torch.log(torch.tensor(dim_head ** -0.5)))
self.norm = nn.LayerNorm(dim)
self.attend = nn.Softmax(dim = -1)
self.dropout = nn.Dropout(dropout)
@@ -46,7 +52,6 @@ class LSA(nn.Module):
)
def forward(self, x):
x = self.norm(x)
qkv = self.to_qkv(x).chunk(3, dim = -1)
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)
@@ -69,8 +74,8 @@ class Transformer(nn.Module):
self.layers = nn.ModuleList([])
for _ in range(depth):
self.layers.append(nn.ModuleList([
LSA(dim, heads = heads, dim_head = dim_head, dropout = dropout),
FeedForward(dim, mlp_dim, dropout = dropout)
PreNorm(dim, LSA(dim, heads = heads, dim_head = dim_head, dropout = dropout)),
PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout))
]))
def forward(self, x):
for attn, ff in self.layers:

View File

@@ -30,11 +30,18 @@ class PatchDropout(nn.Module):
return x[batch_indices, patch_indices_keep]
class PreNorm(nn.Module):
def __init__(self, dim, fn):
super().__init__()
self.norm = nn.LayerNorm(dim)
self.fn = fn
def forward(self, x, **kwargs):
return self.fn(self.norm(x), **kwargs)
class FeedForward(nn.Module):
def __init__(self, dim, hidden_dim, dropout = 0.):
super().__init__()
self.net = nn.Sequential(
nn.LayerNorm(dim),
nn.Linear(dim, hidden_dim),
nn.GELU(),
nn.Dropout(dropout),
@@ -53,7 +60,6 @@ class Attention(nn.Module):
self.heads = heads
self.scale = dim_head ** -0.5
self.norm = nn.LayerNorm(dim)
self.attend = nn.Softmax(dim = -1)
self.dropout = nn.Dropout(dropout)
@@ -65,7 +71,6 @@ class Attention(nn.Module):
) if project_out else nn.Identity()
def forward(self, x):
x = self.norm(x)
qkv = self.to_qkv(x).chunk(3, dim = -1)
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)
@@ -84,8 +89,8 @@ class Transformer(nn.Module):
self.layers = nn.ModuleList([])
for _ in range(depth):
self.layers.append(nn.ModuleList([
Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout),
FeedForward(dim, mlp_dim, dropout = dropout)
PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)),
PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout))
]))
def forward(self, x):
for attn, ff in self.layers:

View File

@@ -32,11 +32,18 @@ class PatchMerger(nn.Module):
# classes
class PreNorm(nn.Module):
def __init__(self, dim, fn):
super().__init__()
self.norm = nn.LayerNorm(dim)
self.fn = fn
def forward(self, x, **kwargs):
return self.fn(self.norm(x), **kwargs)
class FeedForward(nn.Module):
def __init__(self, dim, hidden_dim, dropout = 0.):
super().__init__()
self.net = nn.Sequential(
nn.LayerNorm(dim),
nn.Linear(dim, hidden_dim),
nn.GELU(),
nn.Dropout(dropout),
@@ -55,7 +62,6 @@ class Attention(nn.Module):
self.heads = heads
self.scale = dim_head ** -0.5
self.norm = nn.LayerNorm(dim)
self.attend = nn.Softmax(dim = -1)
self.dropout = nn.Dropout(dropout)
@@ -67,7 +73,6 @@ class Attention(nn.Module):
) if project_out else nn.Identity()
def forward(self, x):
x = self.norm(x)
qkv = self.to_qkv(x).chunk(3, dim = -1)
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)
@@ -83,7 +88,6 @@ class Attention(nn.Module):
class Transformer(nn.Module):
def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0., patch_merge_layer = None, patch_merge_num_tokens = 8):
super().__init__()
self.norm = nn.LayerNorm(dim)
self.layers = nn.ModuleList([])
self.patch_merge_layer_index = default(patch_merge_layer, depth // 2) - 1 # default to mid-way through transformer, as shown in paper
@@ -91,8 +95,8 @@ class Transformer(nn.Module):
for _ in range(depth):
self.layers.append(nn.ModuleList([
Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout),
FeedForward(dim, mlp_dim, dropout = dropout)
PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)),
PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout))
]))
def forward(self, x):
for index, (attn, ff) in enumerate(self.layers):
@@ -102,7 +106,7 @@ class Transformer(nn.Module):
if index == self.patch_merge_layer_index:
x = self.patch_merger(x)
return self.norm(x)
return x
class ViT(nn.Module):
def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, patch_merge_layer = None, patch_merge_num_tokens = 8, channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0.):
@@ -129,6 +133,7 @@ class ViT(nn.Module):
self.mlp_head = nn.Sequential(
Reduce('b n d -> b d', 'mean'),
nn.LayerNorm(dim),
nn.Linear(dim, num_classes)
)

View File

@@ -14,11 +14,18 @@ def pair(t):
# classes
class PreNorm(nn.Module):
def __init__(self, dim, fn):
super().__init__()
self.norm = nn.LayerNorm(dim)
self.fn = fn
def forward(self, x, **kwargs):
return self.fn(self.norm(x), **kwargs)
class FeedForward(nn.Module):
def __init__(self, dim, hidden_dim, dropout = 0.):
super().__init__()
self.net = nn.Sequential(
nn.LayerNorm(dim),
nn.Linear(dim, hidden_dim),
nn.GELU(),
nn.Dropout(dropout),
@@ -37,7 +44,6 @@ class Attention(nn.Module):
self.heads = heads
self.scale = dim_head ** -0.5
self.norm = nn.LayerNorm(dim)
self.attend = nn.Softmax(dim = -1)
self.dropout = nn.Dropout(dropout)
@@ -49,7 +55,6 @@ class Attention(nn.Module):
) if project_out else nn.Identity()
def forward(self, x):
x = self.norm(x)
qkv = self.to_qkv(x).chunk(3, dim = -1)
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)
@@ -65,18 +70,17 @@ class Attention(nn.Module):
class Transformer(nn.Module):
def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
super().__init__()
self.norm = nn.LayerNorm(dim)
self.layers = nn.ModuleList([])
for _ in range(depth):
self.layers.append(nn.ModuleList([
Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout),
FeedForward(dim, mlp_dim, dropout = dropout)
PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)),
PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout))
]))
def forward(self, x):
for attn, ff in self.layers:
x = attn(x) + x
x = ff(x) + x
return self.norm(x)
return x
class ViT(nn.Module):
def __init__(
@@ -133,7 +137,10 @@ class ViT(nn.Module):
self.pool = pool
self.to_latent = nn.Identity()
self.mlp_head = nn.Linear(dim, num_classes)
self.mlp_head = nn.Sequential(
nn.LayerNorm(dim),
nn.Linear(dim, num_classes)
)
def forward(self, video):
x = self.to_patch_embedding(video)