Compare commits

...

1 Commits

Author SHA1 Message Date
Phil Wang
b6096b63a2 offer 1d versions, in light of https://arxiv.org/abs/2211.14730 2022-12-01 10:28:11 -08:00
3 changed files with 260 additions and 2 deletions

View File

@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
setup(
name = 'vit-pytorch',
packages = find_packages(exclude=['examples']),
version = '0.38.1',
version = '0.39.0',
license='MIT',
description = 'Vision Transformer (ViT) - Pytorch',
long_description_content_type = 'text/markdown',
@@ -16,7 +16,7 @@ setup(
'image recognition'
],
install_requires=[
'einops>=0.4.1',
'einops>=0.6.0',
'torch>=1.10',
'torchvision'
],

View File

@@ -0,0 +1,125 @@
import torch
from torch import nn
from einops import rearrange
from einops.layers.torch import Rearrange
# helpers
def posemb_sincos_1d(patches, temperature = 10000, dtype = torch.float32):
_, n, dim, device, dtype = *patches.shape, patches.device, patches.dtype
n = torch.arange(n, device = device)
assert (dim % 4) == 0, 'feature dimension must be multiple of 4 for sincos emb'
omega = torch.arange(dim // 2, device = device) / (dim // 2 - 1)
omega = 1. / (temperature ** omega)
n = n.flatten()[:, None] * omega[None, :]
pe = torch.cat((n.sin(), n.cos()), dim = 1)
return pe.type(dtype)
# classes
class FeedForward(nn.Module):
def __init__(self, dim, hidden_dim):
super().__init__()
self.net = nn.Sequential(
nn.LayerNorm(dim),
nn.Linear(dim, hidden_dim),
nn.GELU(),
nn.Linear(hidden_dim, dim),
)
def forward(self, x):
return self.net(x)
class Attention(nn.Module):
def __init__(self, dim, heads = 8, dim_head = 64):
super().__init__()
inner_dim = dim_head * heads
self.heads = heads
self.scale = dim_head ** -0.5
self.norm = nn.LayerNorm(dim)
self.attend = nn.Softmax(dim = -1)
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
self.to_out = nn.Linear(inner_dim, dim, bias = False)
def forward(self, x):
x = self.norm(x)
qkv = self.to_qkv(x).chunk(3, dim = -1)
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)
dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
attn = self.attend(dots)
out = torch.matmul(attn, v)
out = rearrange(out, 'b h n d -> b n (h d)')
return self.to_out(out)
class Transformer(nn.Module):
def __init__(self, dim, depth, heads, dim_head, mlp_dim):
super().__init__()
self.layers = nn.ModuleList([])
for _ in range(depth):
self.layers.append(nn.ModuleList([
Attention(dim, heads = heads, dim_head = dim_head),
FeedForward(dim, mlp_dim)
]))
def forward(self, x):
for attn, ff in self.layers:
x = attn(x) + x
x = ff(x) + x
return x
class SimpleViT(nn.Module):
def __init__(self, *, seq_len, patch_size, num_classes, dim, depth, heads, mlp_dim, channels = 3, dim_head = 64):
super().__init__()
assert seq_len % patch_size == 0
num_patches = seq_len // patch_size
patch_dim = channels * patch_size
self.to_patch_embedding = nn.Sequential(
Rearrange('b c (n p) -> b n (p c)', p = patch_size),
nn.Linear(patch_dim, dim),
)
self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim)
self.to_latent = nn.Identity()
self.linear_head = nn.Sequential(
nn.LayerNorm(dim),
nn.Linear(dim, num_classes)
)
def forward(self, series):
*_, n, dtype = *series.shape, series.dtype
x = self.to_patch_embedding(series)
pe = posemb_sincos_1d(x)
x = rearrange(x, 'b ... d -> b (...) d') + pe
x = self.transformer(x)
x = x.mean(dim = 1)
x = self.to_latent(x)
return self.linear_head(x)
if __name__ == '__main__':
v = SimpleViT(
seq_len = 256,
patch_size = 16,
num_classes = 1000,
dim = 1024,
depth = 6,
heads = 8,
mlp_dim = 2048
)
time_series = torch.randn(4, 3, 256)
logits = v(time_series) # (4, 1000)

133
vit_pytorch/vit_1d.py Normal file
View File

@@ -0,0 +1,133 @@
import torch
from torch import nn
from einops import rearrange, repeat, pack, unpack
from einops.layers.torch import Rearrange
# classes
class PreNorm(nn.Module):
def __init__(self, dim, fn):
super().__init__()
self.norm = nn.LayerNorm(dim)
self.fn = fn
def forward(self, x, **kwargs):
return self.fn(self.norm(x), **kwargs)
class FeedForward(nn.Module):
def __init__(self, dim, hidden_dim, dropout = 0.):
super().__init__()
self.net = nn.Sequential(
nn.Linear(dim, hidden_dim),
nn.GELU(),
nn.Dropout(dropout),
nn.Linear(hidden_dim, dim),
nn.Dropout(dropout)
)
def forward(self, x):
return self.net(x)
class Attention(nn.Module):
def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
super().__init__()
inner_dim = dim_head * heads
project_out = not (heads == 1 and dim_head == dim)
self.heads = heads
self.scale = dim_head ** -0.5
self.attend = nn.Softmax(dim = -1)
self.dropout = nn.Dropout(dropout)
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
self.to_out = nn.Sequential(
nn.Linear(inner_dim, dim),
nn.Dropout(dropout)
) if project_out else nn.Identity()
def forward(self, x):
qkv = self.to_qkv(x).chunk(3, dim = -1)
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)
dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
attn = self.attend(dots)
attn = self.dropout(attn)
out = torch.matmul(attn, v)
out = rearrange(out, 'b h n d -> b n (h d)')
return self.to_out(out)
class Transformer(nn.Module):
def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
super().__init__()
self.layers = nn.ModuleList([])
for _ in range(depth):
self.layers.append(nn.ModuleList([
PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)),
PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout))
]))
def forward(self, x):
for attn, ff in self.layers:
x = attn(x) + x
x = ff(x) + x
return x
class ViT(nn.Module):
def __init__(self, *, seq_len, patch_size, num_classes, dim, depth, heads, mlp_dim, channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0.):
super().__init__()
assert (seq_len % patch_size) == 0
num_patches = seq_len // patch_size
patch_dim = channels * patch_size
self.to_patch_embedding = nn.Sequential(
Rearrange('b c (n p) -> b n (p c)', p = patch_size),
nn.Linear(patch_dim, dim),
)
self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
self.cls_token = nn.Parameter(torch.randn(dim))
self.dropout = nn.Dropout(emb_dropout)
self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)
self.mlp_head = nn.Sequential(
nn.LayerNorm(dim),
nn.Linear(dim, num_classes)
)
def forward(self, series):
x = self.to_patch_embedding(series)
b, n, _ = x.shape
cls_tokens = repeat(self.cls_token, 'd -> b d', b = b)
x, ps = pack([cls_tokens, x], 'b * d')
x += self.pos_embedding[:, :(n + 1)]
x = self.dropout(x)
x = self.transformer(x)
cls_tokens, _ = unpack(x, ps, 'b * d')
return self.mlp_head(cls_tokens)
if __name__ == '__main__':
v = ViT(
seq_len = 256,
patch_size = 16,
num_classes = 1000,
dim = 1024,
depth = 6,
heads = 8,
mlp_dim = 2048,
dropout = 0.1,
emb_dropout = 0.1
)
time_series = torch.randn(4, 3, 256)
logits = v(time_series) # (4, 1000)