mirror of
https://github.com/lucidrains/vit-pytorch.git
synced 2026-01-06 13:02:30 +00:00
Compare commits
8 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
5888f05300 | ||
|
|
d518e89573 | ||
|
|
dd6462d19b | ||
|
|
a1ee1daa1a | ||
|
|
3cff5e547a | ||
|
|
fdaf7f92b9 | ||
|
|
0ebd4edab9 | ||
|
|
aa49c2783a |
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
|
||||
|
||||
[project]
|
||||
name = "vit-pytorch"
|
||||
version = "1.15.6"
|
||||
version = "1.16.4"
|
||||
description = "Vision Transformer (ViT) - Pytorch"
|
||||
readme = { file = "README.md", content-type = "text/markdown" }
|
||||
license = { file = "LICENSE" }
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from functools import partial
|
||||
from functools import partial, lru_cache
|
||||
from typing import List
|
||||
|
||||
import torch
|
||||
@@ -9,7 +9,6 @@ from torch import nn, Tensor
|
||||
from torch.nn.utils.rnn import pad_sequence as orig_pad_sequence
|
||||
|
||||
from einops import rearrange, repeat
|
||||
from einops.layers.torch import Rearrange
|
||||
|
||||
# helpers
|
||||
|
||||
@@ -28,6 +27,12 @@ def pair(t):
|
||||
def divisible_by(numer, denom):
|
||||
return (numer % denom) == 0
|
||||
|
||||
@lru_cache(maxsize=128)
|
||||
def posemb_grid(ph, pw, device):
|
||||
h_idx = torch.arange(ph, device=device).repeat_interleave(pw)
|
||||
w_idx = torch.arange(pw, device=device).repeat(ph)
|
||||
return torch.stack([h_idx, w_idx], dim=-1)
|
||||
|
||||
# auto grouping images
|
||||
|
||||
def group_images_by_max_seq_len(
|
||||
@@ -117,8 +122,7 @@ class Attention(nn.Module):
|
||||
self.q_norm = RMSNorm(heads, dim_head)
|
||||
self.k_norm = RMSNorm(heads, dim_head)
|
||||
|
||||
self.attend = nn.Softmax(dim = -1)
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
self.dropout_p = dropout
|
||||
|
||||
self.to_q = nn.Linear(dim, inner_dim, bias = False)
|
||||
self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False)
|
||||
@@ -145,19 +149,22 @@ class Attention(nn.Module):
|
||||
q = self.q_norm(q)
|
||||
k = self.k_norm(k)
|
||||
|
||||
dots = torch.matmul(q, k.transpose(-1, -2))
|
||||
# combine masks if both exist
|
||||
if exists(mask) or exists(attn_mask):
|
||||
if exists(mask):
|
||||
mask = rearrange(mask, 'b j -> b 1 1 j')
|
||||
if exists(mask) and exists(attn_mask):
|
||||
attn_mask = mask & attn_mask
|
||||
elif exists(mask):
|
||||
attn_mask = mask
|
||||
|
||||
if exists(mask):
|
||||
mask = rearrange(mask, 'b j -> b 1 1 j')
|
||||
dots = dots.masked_fill(~mask, -torch.finfo(dots.dtype).max)
|
||||
out = F.scaled_dot_product_attention(
|
||||
q, k, v,
|
||||
attn_mask = attn_mask,
|
||||
dropout_p = self.dropout_p if self.training else 0.,
|
||||
scale = 1. # RMSNorm already includes sqrt(dim) scaling
|
||||
)
|
||||
|
||||
if exists(attn_mask):
|
||||
dots = dots.masked_fill(~attn_mask, -torch.finfo(dots.dtype).max)
|
||||
|
||||
attn = self.attend(dots)
|
||||
attn = self.dropout(attn)
|
||||
|
||||
out = torch.matmul(attn, v)
|
||||
out = rearrange(out, 'b h n d -> b n (h d)')
|
||||
return self.to_out(out)
|
||||
|
||||
@@ -281,42 +288,41 @@ class NaViT(nn.Module):
|
||||
for images in batched_images:
|
||||
num_images.append(len(images))
|
||||
|
||||
sequences = []
|
||||
positions = []
|
||||
image_ids = torch.empty((0,), device = device, dtype = torch.long)
|
||||
|
||||
for image_id, image in enumerate(images):
|
||||
assert image.ndim ==3 and image.shape[0] == c
|
||||
# compute patch dimensions for all images
|
||||
patch_dims = []
|
||||
for image in images:
|
||||
assert image.ndim == 3 and image.shape[0] == c
|
||||
image_dims = image.shape[-2:]
|
||||
assert all([divisible_by(dim, p) for dim in image_dims]), f'height and width {image_dims} of images must be divisible by patch size {p}'
|
||||
patch_dims.append((image_dims[0] // p, image_dims[1] // p))
|
||||
|
||||
ph, pw = map(lambda dim: dim // p, image_dims)
|
||||
# extract patches for all images
|
||||
sequences = [rearrange(img, 'c (h p1) (w p2) -> (h w) (c p1 p2)', p1=p, p2=p) for img in images]
|
||||
|
||||
pos = torch.stack(torch.meshgrid((
|
||||
arange(ph),
|
||||
arange(pw)
|
||||
), indexing = 'ij'), dim = -1)
|
||||
# compute positions - uses lru_cache to avoid redundant computation across forward passes
|
||||
positions = [posemb_grid(ph, pw, device) for ph, pw in patch_dims]
|
||||
|
||||
pos = rearrange(pos, 'h w c -> (h w) c')
|
||||
seq = rearrange(image, 'c (h p1) (w p2) -> (h w) (c p1 p2)', p1 = p, p2 = p)
|
||||
|
||||
seq_len = seq.shape[-2]
|
||||
|
||||
if has_token_dropout:
|
||||
# handle token dropout
|
||||
if has_token_dropout:
|
||||
for i, (seq, pos) in enumerate(zip(sequences, positions)):
|
||||
image_dims = images[i].shape[-2:]
|
||||
token_dropout = self.calc_token_dropout(*image_dims)
|
||||
seq_len = seq.shape[0]
|
||||
num_keep = max(1, int(seq_len * (1 - token_dropout)))
|
||||
keep_indices = torch.randn((seq_len,), device = device).topk(num_keep, dim = -1).indices
|
||||
keep_indices = torch.randn((seq_len,), device=device).topk(num_keep, dim=-1).indices
|
||||
sequences[i] = seq[keep_indices]
|
||||
positions[i] = pos[keep_indices]
|
||||
|
||||
seq = seq[keep_indices]
|
||||
pos = pos[keep_indices]
|
||||
|
||||
image_ids = F.pad(image_ids, (0, seq.shape[-2]), value = image_id)
|
||||
sequences.append(seq)
|
||||
positions.append(pos)
|
||||
# build image_ids efficiently using repeat_interleave
|
||||
patch_counts = [seq.shape[0] for seq in sequences]
|
||||
image_ids = torch.repeat_interleave(
|
||||
arange(len(images)),
|
||||
torch.tensor(patch_counts, device=device)
|
||||
)
|
||||
|
||||
batched_image_ids.append(image_ids)
|
||||
batched_sequences.append(torch.cat(sequences, dim = 0))
|
||||
batched_positions.append(torch.cat(positions, dim = 0))
|
||||
batched_sequences.append(torch.cat(sequences, dim=0))
|
||||
batched_positions.append(torch.cat(positions, dim=0))
|
||||
|
||||
# derive key padding mask
|
||||
|
||||
|
||||
@@ -176,7 +176,7 @@ class NaViT(Module):
|
||||
|
||||
self.channels = channels
|
||||
self.patch_size = patch_size
|
||||
self.to_patches = Rearrange('c (f pf) (h p1) (w p2) -> f h w (c p1 p2 pf)', p1 = patch_size, p2 = patch_size, pf = frame_patch_size)
|
||||
self.to_patches = Rearrange('c (f pf) (h p1) (w p2) -> f h w (c pf p1 p2)', p1 = patch_size, p2 = patch_size, pf = frame_patch_size)
|
||||
|
||||
self.to_patch_embedding = nn.Sequential(
|
||||
nn.LayerNorm(patch_dim),
|
||||
|
||||
@@ -146,7 +146,7 @@ class SimpleViT(Module):
|
||||
patch_dim = channels * patch_height * patch_width * frame_patch_size
|
||||
|
||||
self.to_patch_embedding = nn.Sequential(
|
||||
Rearrange('b c (f pf) (h p1) (w p2) -> b f h w (p1 p2 pf c)', p1 = patch_height, p2 = patch_width, pf = frame_patch_size),
|
||||
Rearrange('b c (f pf) (h p1) (w p2) -> b f h w (pf p1 p2 c)', p1 = patch_height, p2 = patch_width, pf = frame_patch_size),
|
||||
nn.LayerNorm(patch_dim),
|
||||
nn.Linear(patch_dim, dim),
|
||||
nn.LayerNorm(dim),
|
||||
|
||||
@@ -103,7 +103,7 @@ class SimpleViT(nn.Module):
|
||||
patch_dim = channels * patch_height * patch_width * frame_patch_size
|
||||
|
||||
self.to_patch_embedding = nn.Sequential(
|
||||
Rearrange('b c (f pf) (h p1) (w p2) -> b f h w (p1 p2 pf c)', p1 = patch_height, p2 = patch_width, pf = frame_patch_size),
|
||||
Rearrange('b c (f pf) (h p1) (w p2) -> b f h w (pf p1 p2 c)', p1 = patch_height, p2 = patch_width, pf = frame_patch_size),
|
||||
nn.LayerNorm(patch_dim),
|
||||
nn.Linear(patch_dim, dim),
|
||||
nn.LayerNorm(dim),
|
||||
|
||||
@@ -421,7 +421,8 @@ class VAAT(Module):
|
||||
dim_head,
|
||||
dim_action,
|
||||
mlp_dim,
|
||||
num_views = None,
|
||||
num_image_views = None,
|
||||
num_audio_views = None,
|
||||
num_tasks = None,
|
||||
dim_extra_token = None,
|
||||
num_register_tokens = 4,
|
||||
@@ -462,6 +463,8 @@ class VAAT(Module):
|
||||
|
||||
ast_dim = ast.dim
|
||||
|
||||
self.ast_accept_spec = ast.accept_spec
|
||||
|
||||
assert ast.depth == depth or exists(ast_layer_indices), f'if the VAAT depth is not equal to the AST depth, you must pass in the indices from the AST to be layered to the VAAT in order from bottom to top'
|
||||
|
||||
ast_layer_indices = default(ast_layer_indices, tuple(range(depth)))
|
||||
@@ -480,7 +483,9 @@ class VAAT(Module):
|
||||
|
||||
# maybe view embeddings
|
||||
|
||||
self.view_emb = nn.Parameter(torch.randn(num_views, vit_dim) * 1e-2) if exists(num_views) and num_views > 1 else None
|
||||
self.image_view_emb = nn.Parameter(torch.randn(num_image_views, vit_dim) * 1e-2) if exists(num_image_views) and num_image_views > 1 else None
|
||||
|
||||
self.audio_view_emb = nn.Parameter(torch.randn(num_audio_views, ast_dim) * 1e-2) if exists(num_audio_views) and num_audio_views > 1 else None
|
||||
|
||||
# handle maybe task conditioning
|
||||
|
||||
@@ -523,12 +528,12 @@ class VAAT(Module):
|
||||
|
||||
def forward(
|
||||
self,
|
||||
video_or_image, # (b v? c t? h w) - batch, views [wrist + third person or more], channels, maybe time, height, width
|
||||
audio_or_spec, # (b t) | (b f t) - batch, audio len | batch, spec freq, time
|
||||
video_or_image, # (b v? c t? h w) - batch, views [wrist + third person or more], channels, maybe time, height, width
|
||||
audio_or_spec, # (b v? t) | (b v? f t) - batch, maybe views, audio len | batch, maybe views, spec freq, time
|
||||
*,
|
||||
extra = None, # (b d) - batch, dim extra
|
||||
extra = None, # (b d) - batch, dim extra
|
||||
tasks = None, # (b)
|
||||
actions = None, # (b k d) - batch, action chunk length, action dimension
|
||||
actions = None, # (b k d) - batch, action chunk length, action dimension
|
||||
return_hiddens = False,
|
||||
freeze_vit = False,
|
||||
freeze_ast = False
|
||||
@@ -551,11 +556,26 @@ class VAAT(Module):
|
||||
|
||||
assert video_or_image.shape[3] == self.time_seq_len
|
||||
|
||||
# audio shapes - add a view dimension of 1 when it is implicit
|
||||
|
||||
if audio_or_spec.ndim == 2 and not self.ast_accept_spec:
|
||||
audio_or_spec = rearrange(audio_or_spec, 'b t -> b 1 t')
|
||||
|
||||
elif audio_or_spec.ndim == 3 and self.ast_accept_spec:
|
||||
audio_or_spec = rearrange(audio_or_spec, 'b f t -> b 1 f t')
|
||||
|
||||
# to images
|
||||
|
||||
images = rearrange(video_or_image, 'b v c t h w -> b v t c h w')
|
||||
|
||||
images, packed_shape = pack([images], '* c h w')
|
||||
images, image_packed_shape = pack([images], '* c h w')
|
||||
|
||||
# to audio
|
||||
|
||||
if self.ast_accept_spec:
|
||||
audio_or_spec, audio_packed_shape = pack([audio_or_spec], '* f t')
|
||||
else:
|
||||
audio_or_spec, audio_packed_shape = pack([audio_or_spec], '* t')
|
||||
|
||||
# get representation trajectory from vit
|
||||
|
||||
@@ -570,9 +590,9 @@ class VAAT(Module):
|
||||
|
||||
hiddens = hiddens[self.vit_layer_indices]
|
||||
|
||||
# pack temporarily for embedding
|
||||
# unpack temporarily for embedding
|
||||
|
||||
hiddens, = unpack(hiddens, packed_shape, 'l * n d') # l for layers
|
||||
hiddens, = unpack(hiddens, image_packed_shape, 'l * n d') # l for layers
|
||||
|
||||
# maybe add time embeddings
|
||||
|
||||
@@ -582,11 +602,11 @@ class VAAT(Module):
|
||||
|
||||
# maybe view embeddings
|
||||
|
||||
if exists(self.view_emb):
|
||||
assert self.view_emb.shape[0] == hiddens.shape[2]
|
||||
if exists(self.image_view_emb):
|
||||
assert self.image_view_emb.shape[0] == hiddens.shape[2]
|
||||
|
||||
view_emb = rearrange(self.view_emb, 'v d -> v 1 1 d')
|
||||
hiddens = hiddens + view_emb
|
||||
image_view_emb = rearrange(self.image_view_emb, 'v d -> v 1 1 d')
|
||||
hiddens = hiddens + image_view_emb
|
||||
|
||||
# get representation trajectory from ast
|
||||
|
||||
@@ -601,6 +621,18 @@ class VAAT(Module):
|
||||
|
||||
audio_hiddens = audio_hiddens[self.ast_layer_indices]
|
||||
|
||||
# unpack audio temporarily for embedding
|
||||
|
||||
audio_hiddens, = unpack(audio_hiddens, audio_packed_shape, 'l * n d') # l for layers
|
||||
|
||||
# maybe audio view embeddings
|
||||
|
||||
if exists(self.audio_view_emb):
|
||||
assert self.audio_view_emb.shape[0] == audio_hiddens.shape[2]
|
||||
|
||||
audio_view_emb = rearrange(self.audio_view_emb, 'v d -> v 1 1 d')
|
||||
audio_hiddens = audio_hiddens + audio_view_emb
|
||||
|
||||
# maybe tasks
|
||||
|
||||
if exists(tasks):
|
||||
@@ -612,7 +644,7 @@ class VAAT(Module):
|
||||
|
||||
image_context = rearrange(hiddens, 'l b v t n d -> l b (v t n) d')
|
||||
|
||||
audio_context = audio_hiddens # eventually handle views (stereo and beyond)
|
||||
audio_context = rearrange(audio_hiddens, 'l b v n d -> l b (v n) d')
|
||||
|
||||
# get main action tokens and maybe append extra
|
||||
|
||||
@@ -703,7 +735,7 @@ if __name__ == '__main__':
|
||||
mlp_dim = 384 * 4
|
||||
)
|
||||
|
||||
vat = VAAT(
|
||||
vaat = VAAT(
|
||||
vit,
|
||||
ast,
|
||||
dim = 512,
|
||||
@@ -714,7 +746,8 @@ if __name__ == '__main__':
|
||||
dim_action = 20,
|
||||
action_chunk_len = 7,
|
||||
time_seq_len = 4,
|
||||
num_views = 2,
|
||||
num_image_views = 2,
|
||||
num_audio_views = 2,
|
||||
num_tasks = 4,
|
||||
add_self_attn = True,
|
||||
dim_extra_token = 33, # extra token with some variable dimension
|
||||
@@ -727,18 +760,18 @@ if __name__ == '__main__':
|
||||
)
|
||||
|
||||
images = torch.randn(2, 2, 3, 4, 256, 256) # (2 views with 4 frames)
|
||||
audio = torch.randn(2, 14_100 * 5)
|
||||
audio = torch.randn(2, 2, 14_100 * 5)
|
||||
|
||||
tasks = torch.randint(0, 4, (2,))
|
||||
extra = torch.randn(2, 33) # extra internal state
|
||||
|
||||
actions = torch.randn(2, 7, 20) # actions for learning
|
||||
|
||||
loss = vat(images, audio, actions = actions, tasks = tasks, extra = extra, freeze_vit = True)
|
||||
loss = vaat(images, audio, actions = actions, tasks = tasks, extra = extra, freeze_vit = True)
|
||||
loss.backward()
|
||||
|
||||
# after much training
|
||||
|
||||
pred_actions, hiddens = vat(images, audio, tasks = tasks, extra = extra, return_hiddens = True)
|
||||
pred_actions, hiddens = vaat(images, audio, tasks = tasks, extra = extra, return_hiddens = True)
|
||||
|
||||
assert pred_actions.shape == (2, 7, 20)
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.nn import Module, ModuleList
|
||||
|
||||
from einops import rearrange, repeat
|
||||
from einops.layers.torch import Rearrange
|
||||
@@ -11,7 +12,7 @@ def pair(t):
|
||||
|
||||
# classes
|
||||
|
||||
class FeedForward(nn.Module):
|
||||
class FeedForward(Module):
|
||||
def __init__(self, dim, hidden_dim, dropout = 0.):
|
||||
super().__init__()
|
||||
self.net = nn.Sequential(
|
||||
@@ -26,7 +27,7 @@ class FeedForward(nn.Module):
|
||||
def forward(self, x):
|
||||
return self.net(x)
|
||||
|
||||
class Attention(nn.Module):
|
||||
class Attention(Module):
|
||||
def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
|
||||
super().__init__()
|
||||
inner_dim = dim_head * heads
|
||||
@@ -62,13 +63,14 @@ class Attention(nn.Module):
|
||||
out = rearrange(out, 'b h n d -> b n (h d)')
|
||||
return self.to_out(out)
|
||||
|
||||
class Transformer(nn.Module):
|
||||
class Transformer(Module):
|
||||
def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
|
||||
super().__init__()
|
||||
self.norm = nn.LayerNorm(dim)
|
||||
self.layers = nn.ModuleList([])
|
||||
self.layers = ModuleList([])
|
||||
|
||||
for _ in range(depth):
|
||||
self.layers.append(nn.ModuleList([
|
||||
self.layers.append(ModuleList([
|
||||
Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout),
|
||||
FeedForward(dim, mlp_dim, dropout = dropout)
|
||||
]))
|
||||
@@ -80,7 +82,7 @@ class Transformer(nn.Module):
|
||||
|
||||
return self.norm(x)
|
||||
|
||||
class ViT(nn.Module):
|
||||
class ViT(Module):
|
||||
def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, pool = 'cls', channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0.):
|
||||
super().__init__()
|
||||
image_height, image_width = pair(image_size)
|
||||
@@ -90,7 +92,9 @@ class ViT(nn.Module):
|
||||
|
||||
num_patches = (image_height // patch_height) * (image_width // patch_width)
|
||||
patch_dim = channels * patch_height * patch_width
|
||||
|
||||
assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'
|
||||
num_cls_tokens = 1 if pool == 'cls' else 0
|
||||
|
||||
self.to_patch_embedding = nn.Sequential(
|
||||
Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_height, p2 = patch_width),
|
||||
@@ -99,8 +103,9 @@ class ViT(nn.Module):
|
||||
nn.LayerNorm(dim),
|
||||
)
|
||||
|
||||
self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
|
||||
self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
|
||||
self.cls_token = nn.Parameter(torch.randn(num_cls_tokens, dim))
|
||||
self.pos_embedding = nn.Parameter(torch.randn(num_patches + num_cls_tokens, dim))
|
||||
|
||||
self.dropout = nn.Dropout(emb_dropout)
|
||||
|
||||
self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)
|
||||
@@ -111,12 +116,15 @@ class ViT(nn.Module):
|
||||
self.mlp_head = nn.Linear(dim, num_classes)
|
||||
|
||||
def forward(self, img):
|
||||
batch = img.shape[0]
|
||||
x = self.to_patch_embedding(img)
|
||||
b, n, _ = x.shape
|
||||
|
||||
cls_tokens = repeat(self.cls_token, '1 1 d -> b 1 d', b = b)
|
||||
x = torch.cat((cls_tokens, x), dim=1)
|
||||
x += self.pos_embedding[:, :(n + 1)]
|
||||
cls_tokens = repeat(self.cls_token, '... d -> b ... d', b = batch)
|
||||
x = torch.cat((cls_tokens, x), dim = 1)
|
||||
|
||||
seq = x.shape[1]
|
||||
|
||||
x = x + self.pos_embedding[:seq]
|
||||
x = self.dropout(x)
|
||||
|
||||
x = self.transformer(x)
|
||||
|
||||
@@ -89,7 +89,7 @@ class ViT(nn.Module):
|
||||
assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'
|
||||
|
||||
self.to_patch_embedding = nn.Sequential(
|
||||
Rearrange('b c (f pf) (h p1) (w p2) -> b (f h w) (p1 p2 pf c)', p1 = patch_height, p2 = patch_width, pf = frame_patch_size),
|
||||
Rearrange('b c (f pf) (h p1) (w p2) -> b (f h w) (pf p1 p2 c)', p1 = patch_height, p2 = patch_width, pf = frame_patch_size),
|
||||
nn.LayerNorm(patch_dim),
|
||||
nn.Linear(patch_dim, dim),
|
||||
nn.LayerNorm(dim),
|
||||
|
||||
@@ -141,7 +141,7 @@ class ViT(nn.Module):
|
||||
self.global_average_pool = pool == 'mean'
|
||||
|
||||
self.to_patch_embedding = nn.Sequential(
|
||||
Rearrange('b c (f pf) (h p1) (w p2) -> b f (h w) (p1 p2 pf c)', p1 = patch_height, p2 = patch_width, pf = frame_patch_size),
|
||||
Rearrange('b c (f pf) (h p1) (w p2) -> b f (h w) (pf p1 p2 c)', p1 = patch_height, p2 = patch_width, pf = frame_patch_size),
|
||||
nn.LayerNorm(patch_dim),
|
||||
nn.Linear(patch_dim, dim),
|
||||
nn.LayerNorm(dim)
|
||||
|
||||
Reference in New Issue
Block a user