Compare commits

..

6 Commits

Author SHA1 Message Date
Phil Wang
64aae4680b correct need for post-attention dropout 2022-03-30 10:05:19 -07:00
Phil Wang
6d7298d8ad link to tensorflow2 translation by @taki0112 2022-03-28 09:05:34 -07:00
Phil Wang
9cd56ff29b CCT allow for rectangular images 2022-03-26 14:02:49 -07:00
Phil Wang
2aae406ce8 add proposed parallel vit from facebook ai for exploration purposes 2022-03-23 10:42:35 -07:00
Phil Wang
c2b2db2a54 fix window size of none for scalable vit for rectangular images 2022-03-22 17:37:59 -07:00
Phil Wang
719048d1bd some better defaults for scalable vit 2022-03-22 17:19:58 -07:00
23 changed files with 302 additions and 54 deletions

104
README.md
View File

@@ -27,6 +27,7 @@
- [Adaptive Token Sampling](#adaptive-token-sampling)
- [Patch Merger](#patch-merger)
- [Vision Transformer for Small Datasets](#vision-transformer-for-small-datasets)
- [Parallel ViT](#parallel-vit)
- [Dino](#dino)
- [Accessing Attention](#accessing-attention)
- [Research Ideas](#research-ideas)
@@ -44,6 +45,8 @@ For a Pytorch implementation with pretrained models, please see Ross Wightman's
The official Jax repository is <a href="https://github.com/google-research/vision_transformer">here</a>.
A tensorflow2 translation also exists <a href="https://github.com/taki0112/vit-tensorflow">here</a>, created by research scientist <a href="https://github.com/taki0112">Junho Kim</a>! 🙏
## Install
```bash
@@ -240,6 +243,7 @@ preds = v(img) # (1, 1000)
```
## CCT
<img src="https://raw.githubusercontent.com/SHI-Labs/Compact-Transformers/main/images/model_sym.png" width="400px"></img>
<a href="https://arxiv.org/abs/2104.05704">CCT</a> proposes compact transformers
@@ -251,22 +255,25 @@ You can use this with two methods
import torch
from vit_pytorch.cct import CCT
model = CCT(
img_size=224,
embedding_dim=384,
n_conv_layers=2,
kernel_size=7,
stride=2,
padding=3,
pooling_kernel_size=3,
pooling_stride=2,
pooling_padding=1,
num_layers=14,
num_heads=6,
mlp_radio=3.,
num_classes=1000,
positional_embedding='learnable', # ['sine', 'learnable', 'none']
)
cct = CCT(
img_size = (224, 448),
embedding_dim = 384,
n_conv_layers = 2,
kernel_size = 7,
stride = 2,
padding = 3,
pooling_kernel_size = 3,
pooling_stride = 2,
pooling_padding = 1,
num_layers = 14,
num_heads = 6,
mlp_radio = 3.,
num_classes = 1000,
positional_embedding = 'learnable', # ['sine', 'learnable', 'none']
)
img = torch.randn(1, 3, 224, 448)
pred = cct(img) # (1, 1000)
```
Alternatively you can use one of several pre-defined models `[2,4,6,7,8,14,16]`
@@ -277,23 +284,23 @@ and the embedding dimension.
import torch
from vit_pytorch.cct import cct_14
model = cct_14(
img_size=224,
n_conv_layers=1,
kernel_size=7,
stride=2,
padding=3,
pooling_kernel_size=3,
pooling_stride=2,
pooling_padding=1,
num_classes=1000,
positional_embedding='learnable', # ['sine', 'learnable', 'none']
)
cct = cct_14(
img_size = 224,
n_conv_layers = 1,
kernel_size = 7,
stride = 2,
padding = 3,
pooling_kernel_size = 3,
pooling_stride = 2,
pooling_padding = 1,
num_classes = 1000,
positional_embedding = 'learnable', # ['sine', 'learnable', 'none']
)
```
<a href="https://github.com/SHI-Labs/Compact-Transformers">Official
Repository</a> includes links to pretrained model checkpoints.
## Cross ViT
<img src="./images/cross_vit.png" width="400px"></img>
@@ -866,6 +873,37 @@ img = torch.randn(4, 3, 256, 256)
tokens = spt(img) # (4, 256, 1024)
```
## Parallel ViT
<img src="./images/parallel-vit.png" width="350px"></img>
This <a href="https://arxiv.org/abs/2203.09795">paper</a> propose parallelizing multiple attention and feedforward blocks per layer (2 blocks), claiming that it is easier to train without loss of performance.
You can try this variant as follows
```python
import torch
from vit_pytorch.parallel_vit import ViT
v = ViT(
image_size = 256,
patch_size = 16,
num_classes = 1000,
dim = 1024,
depth = 6,
heads = 8,
mlp_dim = 2048,
num_parallel_branches = 2, # in paper, they claimed 2 was optimal
dropout = 0.1,
emb_dropout = 0.1
)
img = torch.randn(4, 3, 256, 256)
preds = v(img) # (4, 1000)
```
## Dino
<img src="./images/dino.png" width="350px"></img>
@@ -1396,6 +1434,14 @@ Coming from computer vision and new to transformers? Here are some resources tha
}
```
```bibtex
@inproceedings{Touvron2022ThreeTE,
title = {Three things everyone should know about Vision Transformers},
author = {Hugo Touvron and Matthieu Cord and Alaaeldin El-Nouby and Jakob Verbeek and Herv'e J'egou},
year = {2022}
}
```
```bibtex
@misc{vaswani2017attention,
title = {Attention Is All You Need},

BIN
images/parallel-vit.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 14 KiB

View File

@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
setup(
name = 'vit-pytorch',
packages = find_packages(exclude=['examples']),
version = '0.28.0',
version = '0.30.0',
license='MIT',
description = 'Vision Transformer (ViT) - Pytorch',
author = 'Phil Wang',

View File

@@ -139,6 +139,8 @@ class Attention(nn.Module):
self.scale = dim_head ** -0.5
self.attend = nn.Softmax(dim = -1)
self.dropout = nn.Dropout(dropout)
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
self.output_num_tokens = output_num_tokens
@@ -163,6 +165,7 @@ class Attention(nn.Module):
dots = dots.masked_fill(~dots_mask, mask_value)
attn = self.attend(dots)
attn = self.dropout(attn)
sampled_token_ids = None

View File

@@ -76,6 +76,7 @@ class Attention(nn.Module):
self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False)
self.attend = nn.Softmax(dim = -1)
self.dropout = nn.Dropout(dropout)
self.mix_heads_pre_attn = nn.Parameter(torch.randn(heads, heads))
self.mix_heads_post_attn = nn.Parameter(torch.randn(heads, heads))
@@ -96,7 +97,10 @@ class Attention(nn.Module):
dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
dots = einsum('b h i j, h g -> b g i j', dots, self.mix_heads_pre_attn) # talking heads, pre-softmax
attn = self.attend(dots)
attn = self.dropout(attn)
attn = einsum('b h i j, h g -> b g i j', attn, self.mix_heads_post_attn) # talking heads, post-softmax
out = einsum('b h i j, b h j d -> b h i d', attn, v)

View File

@@ -2,7 +2,13 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
# Pre-defined CCT Models
# helpers
def pair(t):
return t if isinstance(t, tuple) else (t, t)
# CCT Models
__all__ = ['cct_2', 'cct_4', 'cct_6', 'cct_7', 'cct_8', 'cct_14', 'cct_16']
@@ -55,8 +61,8 @@ def _cct(num_layers, num_heads, mlp_ratio, embedding_dim,
padding=padding,
*args, **kwargs)
# modules
# Modules
class Attention(nn.Module):
def __init__(self, dim, num_heads=8, attention_dropout=0.1, projection_dropout=0.1):
super().__init__()
@@ -308,6 +314,7 @@ class CCT(nn.Module):
pooling_padding=1,
*args, **kwargs):
super(CCT, self).__init__()
img_height, img_width = pair(img_size)
self.tokenizer = Tokenizer(n_input_channels=n_input_channels,
n_output_channels=embedding_dim,
@@ -324,8 +331,8 @@ class CCT(nn.Module):
self.classifier = TransformerClassifier(
sequence_length=self.tokenizer.sequence_length(n_channels=n_input_channels,
height=img_size,
width=img_size),
height=img_height,
width=img_width),
embedding_dim=embedding_dim,
seq_pool=True,
dropout_rate=0.,
@@ -336,4 +343,3 @@ class CCT(nn.Module):
def forward(self, x):
x = self.tokenizer(x)
return self.classifier(x)

View File

@@ -48,6 +48,8 @@ class Attention(nn.Module):
self.scale = dim_head ** -0.5
self.attend = nn.Softmax(dim = -1)
self.dropout = nn.Dropout(dropout)
self.to_q = nn.Linear(dim, inner_dim, bias = False)
self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False)
@@ -69,6 +71,7 @@ class Attention(nn.Module):
dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
attn = self.attend(dots)
attn = self.dropout(attn)
out = einsum('b h i j, b h j d -> b h i d', attn, v)
out = rearrange(out, 'b h n d -> b n (h d)')

View File

@@ -95,6 +95,9 @@ class Attention(nn.Module):
self.window_size = window_size
self.norm = LayerNorm(dim)
self.dropout = nn.Dropout(dropout)
self.to_qkv = nn.Conv2d(dim, inner_dim * 3, 1, bias = False)
self.to_out = nn.Conv2d(inner_dim, dim, 1)
@@ -151,6 +154,7 @@ class Attention(nn.Module):
# attend
attn = sim.softmax(dim = -1)
attn = self.dropout(attn)
# merge heads

View File

@@ -76,6 +76,7 @@ class Attention(nn.Module):
self.scale = dim_head ** -0.5
self.attend = nn.Softmax(dim = -1)
self.dropout = nn.Dropout(dropout)
self.to_q = DepthWiseConv2d(dim, inner_dim, proj_kernel, padding = padding, stride = 1, bias = False)
self.to_kv = DepthWiseConv2d(dim, inner_dim * 2, proj_kernel, padding = padding, stride = kv_proj_stride, bias = False)
@@ -94,6 +95,7 @@ class Attention(nn.Module):
dots = einsum('b i d, b j d -> b i j', q, k) * self.scale
attn = self.attend(dots)
attn = self.dropout(attn)
out = einsum('b i j, b j d -> b i d', attn, v)
out = rearrange(out, '(b h) (x y) d -> b (h d) x y', h = h, y = y)

View File

@@ -42,6 +42,8 @@ class Attention(nn.Module):
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
self.dropout = nn.Dropout(dropout)
self.reattn_weights = nn.Parameter(torch.randn(heads, heads))
self.reattn_norm = nn.Sequential(
@@ -64,6 +66,7 @@ class Attention(nn.Module):
dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
attn = dots.softmax(dim=-1)
attn = self.dropout(attn)
# re-attention

View File

@@ -52,6 +52,7 @@ class Attention(nn.Module):
self.to_v = nn.Sequential(nn.Conv2d(dim, inner_dim_value, 1, bias = False), nn.BatchNorm2d(inner_dim_value))
self.attend = nn.Softmax(dim = -1)
self.dropout = nn.Dropout(dropout)
out_batch_norm = nn.BatchNorm2d(dim_out)
nn.init.zeros_(out_batch_norm.weight)
@@ -100,6 +101,7 @@ class Attention(nn.Module):
dots = self.apply_pos_bias(dots)
attn = self.attend(dots)
attn = self.dropout(attn)
out = einsum('b h i j, b h j d -> b h i d', attn, v)
out = rearrange(out, 'b h (x y) d -> b (h d) x y', h = h, y = y)

View File

@@ -78,6 +78,7 @@ class Attention(nn.Module):
self.scale = dim_head ** -0.5
self.attend = nn.Softmax(dim = -1)
self.dropout = nn.Dropout(dropout)
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
self.to_out = nn.Sequential(
@@ -93,6 +94,7 @@ class Attention(nn.Module):
dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
attn = self.attend(dots)
attn = self.dropout(attn)
out = einsum('b h i j, b h j d -> b h i d', attn, v)
out = rearrange(out, 'b h n d -> b n (h d)')

View File

@@ -54,6 +54,8 @@ class Attention(nn.Module):
self.scale = dim_head ** -0.5
self.attend = nn.Softmax(dim=-1)
self.dropout = nn.Dropout(dropout)
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)
self.to_out = nn.Sequential(
@@ -67,7 +69,10 @@ class Attention(nn.Module):
t, 'b p n (h d) -> b p h n d', h=self.heads), qkv)
dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
attn = self.attend(dots)
attn = self.dropout(attn)
out = torch.matmul(attn, v)
out = rearrange(out, 'b p h n d -> b p n (h d)')
return self.to_out(out)

View File

@@ -55,6 +55,7 @@ class Attention(nn.Module):
self.scale = dim_head ** -0.5
self.attend = nn.Softmax(dim = -1)
self.dropout = nn.Dropout(dropout)
self.to_qkv = nn.Conv2d(dim, inner_dim * 3, 1, bias = False)
self.to_out = nn.Sequential(
@@ -71,6 +72,7 @@ class Attention(nn.Module):
dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
attn = self.attend(dots)
attn = self.dropout(attn)
out = einsum('b h i j, b h j d -> b h i d', attn, v)
out = rearrange(out, 'b h (x y) d -> b (h d) x y', x = h, y = w)

140
vit_pytorch/parallel_vit.py Normal file
View File

@@ -0,0 +1,140 @@
import torch
from torch import nn
from einops import rearrange, repeat
from einops.layers.torch import Rearrange
# helpers
def pair(t):
return t if isinstance(t, tuple) else (t, t)
# classes
class Parallel(nn.Module):
def __init__(self, *fns):
super().__init__()
self.fns = nn.ModuleList(fns)
def forward(self, x):
return sum([fn(x) for fn in self.fns])
class PreNorm(nn.Module):
def __init__(self, dim, fn):
super().__init__()
self.norm = nn.LayerNorm(dim)
self.fn = fn
def forward(self, x, **kwargs):
return self.fn(self.norm(x), **kwargs)
class FeedForward(nn.Module):
def __init__(self, dim, hidden_dim, dropout = 0.):
super().__init__()
self.net = nn.Sequential(
nn.Linear(dim, hidden_dim),
nn.GELU(),
nn.Dropout(dropout),
nn.Linear(hidden_dim, dim),
nn.Dropout(dropout)
)
def forward(self, x):
return self.net(x)
class Attention(nn.Module):
def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
super().__init__()
inner_dim = dim_head * heads
project_out = not (heads == 1 and dim_head == dim)
self.heads = heads
self.scale = dim_head ** -0.5
self.attend = nn.Softmax(dim = -1)
self.dropout = nn.Dropout(dropout)
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
self.to_out = nn.Sequential(
nn.Linear(inner_dim, dim),
nn.Dropout(dropout)
) if project_out else nn.Identity()
def forward(self, x):
qkv = self.to_qkv(x).chunk(3, dim = -1)
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)
dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
attn = self.attend(dots)
attn = self.dropout(attn)
out = torch.matmul(attn, v)
out = rearrange(out, 'b h n d -> b n (h d)')
return self.to_out(out)
class Transformer(nn.Module):
def __init__(self, dim, depth, heads, dim_head, mlp_dim, num_parallel_branches = 2, dropout = 0.):
super().__init__()
self.layers = nn.ModuleList([])
attn_block = lambda: PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout))
ff_block = lambda: PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout))
for _ in range(depth):
self.layers.append(nn.ModuleList([
Parallel(*[attn_block() for _ in range(num_parallel_branches)]),
Parallel(*[ff_block() for _ in range(num_parallel_branches)]),
]))
def forward(self, x):
for attns, ffs in self.layers:
x = attns(x) + x
x = ffs(x) + x
return x
class ViT(nn.Module):
def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, pool = 'cls', num_parallel_branches = 2, channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0.):
super().__init__()
image_height, image_width = pair(image_size)
patch_height, patch_width = pair(patch_size)
assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'
num_patches = (image_height // patch_height) * (image_width // patch_width)
patch_dim = channels * patch_height * patch_width
assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'
self.to_patch_embedding = nn.Sequential(
Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_height, p2 = patch_width),
nn.Linear(patch_dim, dim),
)
self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
self.dropout = nn.Dropout(emb_dropout)
self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, num_parallel_branches, dropout)
self.pool = pool
self.to_latent = nn.Identity()
self.mlp_head = nn.Sequential(
nn.LayerNorm(dim),
nn.Linear(dim, num_classes)
)
def forward(self, img):
x = self.to_patch_embedding(img)
b, n, _ = x.shape
cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b)
x = torch.cat((cls_tokens, x), dim=1)
x += self.pos_embedding[:, :(n + 1)]
x = self.dropout(x)
x = self.transformer(x)
x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0]
x = self.to_latent(x)
return self.mlp_head(x)

View File

@@ -48,6 +48,7 @@ class Attention(nn.Module):
self.scale = dim_head ** -0.5
self.attend = nn.Softmax(dim = -1)
self.dropout = nn.Dropout(dropout)
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
self.to_out = nn.Sequential(
@@ -63,6 +64,7 @@ class Attention(nn.Module):
dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
attn = self.attend(dots)
attn= self.dropout(attn)
out = einsum('b h i j, b h j d -> b h i d', attn, v)
out = rearrange(out, 'b h n d -> b n (h d)')

View File

@@ -61,8 +61,13 @@ class Attention(nn.Module):
inner_dim = dim_head * heads
self.norm = nn.LayerNorm(dim)
self.dropout = nn.Dropout(dropout)
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
self.to_out = nn.Linear(inner_dim, dim)
self.to_out = nn.Sequential(
nn.Linear(inner_dim, dim),
nn.Dropout(dropout)
)
def forward(self, x, rel_pos_bias = None):
h = self.heads
@@ -86,6 +91,7 @@ class Attention(nn.Module):
sim = sim + rel_pos_bias
attn = sim.softmax(dim = -1)
attn = self.dropout(attn)
# merge heads

View File

@@ -104,6 +104,7 @@ class Attention(nn.Module):
self.scale = dim_head ** -0.5
self.attend = nn.Softmax(dim = -1)
self.dropout = nn.Dropout(dropout)
self.use_ds_conv = use_ds_conv
@@ -148,6 +149,7 @@ class Attention(nn.Module):
dots = einsum('b i d, b j d -> b i j', q, k) * self.scale
attn = self.attend(dots)
attn = self.dropout(attn)
out = einsum('b i j, b j d -> b i d', attn, v)
out = rearrange(out, '(b h) n d -> b n (h d)', h = h)

View File

@@ -81,8 +81,8 @@ class ScalableSelfAttention(nn.Module):
self,
dim,
heads = 8,
dim_key = 64,
dim_value = 64,
dim_key = 32,
dim_value = 32,
dropout = 0.,
reduction_factor = 1
):
@@ -90,6 +90,7 @@ class ScalableSelfAttention(nn.Module):
self.heads = heads
self.scale = dim_key ** -0.5
self.attend = nn.Softmax(dim = -1)
self.dropout = nn.Dropout(dropout)
self.to_q = nn.Conv2d(dim, dim_key * heads, 1, bias = False)
self.to_k = nn.Conv2d(dim, dim_key * heads, reduction_factor, stride = reduction_factor, bias = False)
@@ -116,6 +117,7 @@ class ScalableSelfAttention(nn.Module):
# attention
attn = self.attend(dots)
attn = self.dropout(attn)
# aggregate values
@@ -132,8 +134,8 @@ class InteractiveWindowedSelfAttention(nn.Module):
dim,
window_size,
heads = 8,
dim_key = 64,
dim_value = 64,
dim_key = 32,
dim_value = 32,
dropout = 0.
):
super().__init__()
@@ -141,6 +143,7 @@ class InteractiveWindowedSelfAttention(nn.Module):
self.scale = dim_key ** -0.5
self.window_size = window_size
self.attend = nn.Softmax(dim = -1)
self.dropout = nn.Dropout(dropout)
self.local_interactive_module = nn.Conv2d(dim_value * heads, dim_value * heads, 3, padding = 1)
@@ -156,8 +159,8 @@ class InteractiveWindowedSelfAttention(nn.Module):
def forward(self, x):
height, width, heads, wsz = *x.shape[-2:], self.heads, self.window_size
wsz = default(wsz, height) # take height as window size if not given
assert (height % wsz) == 0 and (width % wsz) == 0, f'height ({height}) or width ({width}) of feature map is not divisible by the window size ({wsz})'
wsz_h, wsz_w = default(wsz, height), default(wsz, width)
assert (height % wsz_h) == 0 and (width % wsz_w) == 0, f'height ({height}) or width ({width}) of feature map is not divisible by the window size ({wsz_h}, {wsz_w})'
q, k, v = self.to_q(x), self.to_k(x), self.to_v(x)
@@ -167,7 +170,7 @@ class InteractiveWindowedSelfAttention(nn.Module):
# divide into window (and split out heads) for efficient self attention
q, k, v = map(lambda t: rearrange(t, 'b (h d) (x w1) (y w2) -> (b x y) h (w1 w2) d', h = heads, w1 = wsz, w2 = wsz), (q, k, v))
q, k, v = map(lambda t: rearrange(t, 'b (h d) (x w1) (y w2) -> (b x y) h (w1 w2) d', h = heads, w1 = wsz_h, w2 = wsz_w), (q, k, v))
# similarity
@@ -176,6 +179,7 @@ class InteractiveWindowedSelfAttention(nn.Module):
# attention
attn = self.attend(dots)
attn = self.dropout(attn)
# aggregate values
@@ -183,7 +187,7 @@ class InteractiveWindowedSelfAttention(nn.Module):
# reshape the windows back to full feature map (and merge heads)
out = rearrange(out, '(b x y) h (w1 w2) d -> b (h d) (x w1) (y w2)', x = height // wsz, y = width // wsz, w1 = wsz, w2 = wsz)
out = rearrange(out, '(b x y) h (w1 w2) d -> b (h d) (x w1) (y w2)', x = height // wsz_h, y = width // wsz_w, w1 = wsz_h, w2 = wsz_w)
# add LIM output
@@ -199,12 +203,12 @@ class Transformer(nn.Module):
heads = 8,
ff_expansion_factor = 4,
dropout = 0.,
ssa_dim_key = 64,
ssa_dim_value = 64,
ssa_dim_key = 32,
ssa_dim_value = 32,
ssa_reduction_factor = 1,
iwsa_dim_key = 64,
iwsa_dim_value = 64,
iwsa_window_size = 64,
iwsa_dim_key = 32,
iwsa_dim_value = 32,
iwsa_window_size = None,
norm_output = True
):
super().__init__()
@@ -244,12 +248,12 @@ class ScalableViT(nn.Module):
depth,
heads,
reduction_factor,
window_size = None,
iwsa_dim_key = 32,
iwsa_dim_value = 32,
ssa_dim_key = 32,
ssa_dim_value = 32,
ff_expansion_factor = 4,
iwsa_dim_key = 64,
iwsa_dim_value = 64,
window_size = 64,
ssa_dim_key = 64,
ssa_dim_value = 64,
channels = 3,
dropout = 0.
):

View File

@@ -130,6 +130,8 @@ class GlobalAttention(nn.Module):
self.to_q = nn.Conv2d(dim, inner_dim, 1, bias = False)
self.to_kv = nn.Conv2d(dim, inner_dim * 2, k, stride = k, bias = False)
self.dropout = nn.Dropout(dropout)
self.to_out = nn.Sequential(
nn.Conv2d(inner_dim, dim, 1),
nn.Dropout(dropout)
@@ -145,6 +147,7 @@ class GlobalAttention(nn.Module):
dots = einsum('b i d, b j d -> b i j', q, k) * self.scale
attn = dots.softmax(dim = -1)
attn = self.dropout(attn)
out = einsum('b i j, b j d -> b i d', attn, v)
out = rearrange(out, '(b h) (x y) d -> b (h d) x y', h = h, y = y)

View File

@@ -42,6 +42,8 @@ class Attention(nn.Module):
self.scale = dim_head ** -0.5
self.attend = nn.Softmax(dim = -1)
self.dropout = nn.Dropout(dropout)
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
self.to_out = nn.Sequential(
@@ -56,6 +58,7 @@ class Attention(nn.Module):
dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
attn = self.attend(dots)
attn = self.dropout(attn)
out = torch.matmul(attn, v)
out = rearrange(out, 'b h n d -> b n (h d)')

View File

@@ -42,6 +42,8 @@ class LSA(nn.Module):
self.temperature = nn.Parameter(torch.log(torch.tensor(dim_head ** -0.5)))
self.attend = nn.Softmax(dim = -1)
self.dropout = nn.Dropout(dropout)
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
self.to_out = nn.Sequential(
@@ -60,6 +62,7 @@ class LSA(nn.Module):
dots = dots.masked_fill(mask, mask_value)
attn = self.attend(dots)
attn = self.dropout(attn)
out = torch.matmul(attn, v)
out = rearrange(out, 'b h n d -> b n (h d)')

View File

@@ -63,6 +63,8 @@ class Attention(nn.Module):
self.scale = dim_head ** -0.5
self.attend = nn.Softmax(dim = -1)
self.dropout = nn.Dropout(dropout)
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
self.to_out = nn.Sequential(
@@ -77,6 +79,7 @@ class Attention(nn.Module):
dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
attn = self.attend(dots)
attn = self.dropout(attn)
out = torch.matmul(attn, v)
out = rearrange(out, 'b h n d -> b n (h d)')