mirror of
https://github.com/lucidrains/vit-pytorch.git
synced 2025-12-30 08:02:29 +00:00
Compare commits
8 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
077d8c188f | ||
|
|
5888f05300 | ||
|
|
d518e89573 | ||
|
|
dd6462d19b | ||
|
|
a1ee1daa1a | ||
|
|
3cff5e547a | ||
|
|
fdaf7f92b9 | ||
|
|
0ebd4edab9 |
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
|
||||
|
||||
[project]
|
||||
name = "vit-pytorch"
|
||||
version = "1.15.7"
|
||||
version = "1.16.5"
|
||||
description = "Vision Transformer (ViT) - Pytorch"
|
||||
readme = { file = "README.md", content-type = "text/markdown" }
|
||||
license = { file = "LICENSE" }
|
||||
|
||||
@@ -25,12 +25,12 @@ class DistillMixin:
|
||||
x = self.to_patch_embedding(img)
|
||||
b, n, _ = x.shape
|
||||
|
||||
cls_tokens = repeat(self.cls_token, '1 n d -> b n d', b = b)
|
||||
cls_tokens = repeat(self.cls_token, 'n d -> b n d', b = b)
|
||||
x = torch.cat((cls_tokens, x), dim = 1)
|
||||
x += self.pos_embedding[:, :(n + 1)]
|
||||
x += self.pos_embedding[:(n + 1)]
|
||||
|
||||
if distilling:
|
||||
distill_tokens = repeat(distill_token, '1 n d -> b n d', b = b)
|
||||
distill_tokens = repeat(distill_token, 'n d -> b n d', b = b)
|
||||
x = torch.cat((x, distill_tokens), dim = 1)
|
||||
|
||||
x = self._attend(x)
|
||||
@@ -125,7 +125,7 @@ class DistillWrapper(Module):
|
||||
self.alpha = alpha
|
||||
self.hard = hard
|
||||
|
||||
self.distillation_token = nn.Parameter(torch.randn(1, 1, dim))
|
||||
self.distillation_token = nn.Parameter(torch.randn(1, dim))
|
||||
|
||||
self.distill_mlp = nn.Sequential(
|
||||
nn.LayerNorm(dim) if mlp_layernorm else nn.Identity(),
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from functools import partial
|
||||
from functools import partial, lru_cache
|
||||
from typing import List
|
||||
|
||||
import torch
|
||||
@@ -9,7 +9,6 @@ from torch import nn, Tensor
|
||||
from torch.nn.utils.rnn import pad_sequence as orig_pad_sequence
|
||||
|
||||
from einops import rearrange, repeat
|
||||
from einops.layers.torch import Rearrange
|
||||
|
||||
# helpers
|
||||
|
||||
@@ -28,6 +27,12 @@ def pair(t):
|
||||
def divisible_by(numer, denom):
|
||||
return (numer % denom) == 0
|
||||
|
||||
@lru_cache(maxsize=128)
|
||||
def posemb_grid(ph, pw, device):
|
||||
h_idx = torch.arange(ph, device=device).repeat_interleave(pw)
|
||||
w_idx = torch.arange(pw, device=device).repeat(ph)
|
||||
return torch.stack([h_idx, w_idx], dim=-1)
|
||||
|
||||
# auto grouping images
|
||||
|
||||
def group_images_by_max_seq_len(
|
||||
@@ -117,8 +122,7 @@ class Attention(nn.Module):
|
||||
self.q_norm = RMSNorm(heads, dim_head)
|
||||
self.k_norm = RMSNorm(heads, dim_head)
|
||||
|
||||
self.attend = nn.Softmax(dim = -1)
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
self.dropout_p = dropout
|
||||
|
||||
self.to_q = nn.Linear(dim, inner_dim, bias = False)
|
||||
self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False)
|
||||
@@ -145,19 +149,22 @@ class Attention(nn.Module):
|
||||
q = self.q_norm(q)
|
||||
k = self.k_norm(k)
|
||||
|
||||
dots = torch.matmul(q, k.transpose(-1, -2))
|
||||
# combine masks if both exist
|
||||
if exists(mask) or exists(attn_mask):
|
||||
if exists(mask):
|
||||
mask = rearrange(mask, 'b j -> b 1 1 j')
|
||||
if exists(mask) and exists(attn_mask):
|
||||
attn_mask = mask & attn_mask
|
||||
elif exists(mask):
|
||||
attn_mask = mask
|
||||
|
||||
if exists(mask):
|
||||
mask = rearrange(mask, 'b j -> b 1 1 j')
|
||||
dots = dots.masked_fill(~mask, -torch.finfo(dots.dtype).max)
|
||||
out = F.scaled_dot_product_attention(
|
||||
q, k, v,
|
||||
attn_mask = attn_mask,
|
||||
dropout_p = self.dropout_p if self.training else 0.,
|
||||
scale = 1. # RMSNorm already includes sqrt(dim) scaling
|
||||
)
|
||||
|
||||
if exists(attn_mask):
|
||||
dots = dots.masked_fill(~attn_mask, -torch.finfo(dots.dtype).max)
|
||||
|
||||
attn = self.attend(dots)
|
||||
attn = self.dropout(attn)
|
||||
|
||||
out = torch.matmul(attn, v)
|
||||
out = rearrange(out, 'b h n d -> b n (h d)')
|
||||
return self.to_out(out)
|
||||
|
||||
@@ -281,42 +288,41 @@ class NaViT(nn.Module):
|
||||
for images in batched_images:
|
||||
num_images.append(len(images))
|
||||
|
||||
sequences = []
|
||||
positions = []
|
||||
image_ids = torch.empty((0,), device = device, dtype = torch.long)
|
||||
|
||||
for image_id, image in enumerate(images):
|
||||
assert image.ndim ==3 and image.shape[0] == c
|
||||
# compute patch dimensions for all images
|
||||
patch_dims = []
|
||||
for image in images:
|
||||
assert image.ndim == 3 and image.shape[0] == c
|
||||
image_dims = image.shape[-2:]
|
||||
assert all([divisible_by(dim, p) for dim in image_dims]), f'height and width {image_dims} of images must be divisible by patch size {p}'
|
||||
patch_dims.append((image_dims[0] // p, image_dims[1] // p))
|
||||
|
||||
ph, pw = map(lambda dim: dim // p, image_dims)
|
||||
# extract patches for all images
|
||||
sequences = [rearrange(img, 'c (h p1) (w p2) -> (h w) (c p1 p2)', p1=p, p2=p) for img in images]
|
||||
|
||||
pos = torch.stack(torch.meshgrid((
|
||||
arange(ph),
|
||||
arange(pw)
|
||||
), indexing = 'ij'), dim = -1)
|
||||
# compute positions - uses lru_cache to avoid redundant computation across forward passes
|
||||
positions = [posemb_grid(ph, pw, device) for ph, pw in patch_dims]
|
||||
|
||||
pos = rearrange(pos, 'h w c -> (h w) c')
|
||||
seq = rearrange(image, 'c (h p1) (w p2) -> (h w) (c p1 p2)', p1 = p, p2 = p)
|
||||
|
||||
seq_len = seq.shape[-2]
|
||||
|
||||
if has_token_dropout:
|
||||
# handle token dropout
|
||||
if has_token_dropout:
|
||||
for i, (seq, pos) in enumerate(zip(sequences, positions)):
|
||||
image_dims = images[i].shape[-2:]
|
||||
token_dropout = self.calc_token_dropout(*image_dims)
|
||||
seq_len = seq.shape[0]
|
||||
num_keep = max(1, int(seq_len * (1 - token_dropout)))
|
||||
keep_indices = torch.randn((seq_len,), device = device).topk(num_keep, dim = -1).indices
|
||||
keep_indices = torch.randn((seq_len,), device=device).topk(num_keep, dim=-1).indices
|
||||
sequences[i] = seq[keep_indices]
|
||||
positions[i] = pos[keep_indices]
|
||||
|
||||
seq = seq[keep_indices]
|
||||
pos = pos[keep_indices]
|
||||
|
||||
image_ids = F.pad(image_ids, (0, seq.shape[-2]), value = image_id)
|
||||
sequences.append(seq)
|
||||
positions.append(pos)
|
||||
# build image_ids efficiently using repeat_interleave
|
||||
patch_counts = [seq.shape[0] for seq in sequences]
|
||||
image_ids = torch.repeat_interleave(
|
||||
arange(len(images)),
|
||||
torch.tensor(patch_counts, device=device)
|
||||
)
|
||||
|
||||
batched_image_ids.append(image_ids)
|
||||
batched_sequences.append(torch.cat(sequences, dim = 0))
|
||||
batched_positions.append(torch.cat(positions, dim = 0))
|
||||
batched_sequences.append(torch.cat(sequences, dim=0))
|
||||
batched_positions.append(torch.cat(positions, dim=0))
|
||||
|
||||
# derive key padding mask
|
||||
|
||||
|
||||
@@ -176,7 +176,7 @@ class NaViT(Module):
|
||||
|
||||
self.channels = channels
|
||||
self.patch_size = patch_size
|
||||
self.to_patches = Rearrange('c (f pf) (h p1) (w p2) -> f h w (c p1 p2 pf)', p1 = patch_size, p2 = patch_size, pf = frame_patch_size)
|
||||
self.to_patches = Rearrange('c (f pf) (h p1) (w p2) -> f h w (c pf p1 p2)', p1 = patch_size, p2 = patch_size, pf = frame_patch_size)
|
||||
|
||||
self.to_patch_embedding = nn.Sequential(
|
||||
nn.LayerNorm(patch_dim),
|
||||
|
||||
@@ -146,7 +146,7 @@ class SimpleViT(Module):
|
||||
patch_dim = channels * patch_height * patch_width * frame_patch_size
|
||||
|
||||
self.to_patch_embedding = nn.Sequential(
|
||||
Rearrange('b c (f pf) (h p1) (w p2) -> b f h w (p1 p2 pf c)', p1 = patch_height, p2 = patch_width, pf = frame_patch_size),
|
||||
Rearrange('b c (f pf) (h p1) (w p2) -> b f h w (pf p1 p2 c)', p1 = patch_height, p2 = patch_width, pf = frame_patch_size),
|
||||
nn.LayerNorm(patch_dim),
|
||||
nn.Linear(patch_dim, dim),
|
||||
nn.LayerNorm(dim),
|
||||
|
||||
@@ -103,7 +103,7 @@ class SimpleViT(nn.Module):
|
||||
patch_dim = channels * patch_height * patch_width * frame_patch_size
|
||||
|
||||
self.to_patch_embedding = nn.Sequential(
|
||||
Rearrange('b c (f pf) (h p1) (w p2) -> b f h w (p1 p2 pf c)', p1 = patch_height, p2 = patch_width, pf = frame_patch_size),
|
||||
Rearrange('b c (f pf) (h p1) (w p2) -> b f h w (pf p1 p2 c)', p1 = patch_height, p2 = patch_width, pf = frame_patch_size),
|
||||
nn.LayerNorm(patch_dim),
|
||||
nn.Linear(patch_dim, dim),
|
||||
nn.LayerNorm(dim),
|
||||
|
||||
@@ -735,7 +735,7 @@ if __name__ == '__main__':
|
||||
mlp_dim = 384 * 4
|
||||
)
|
||||
|
||||
vat = VAAT(
|
||||
vaat = VAAT(
|
||||
vit,
|
||||
ast,
|
||||
dim = 512,
|
||||
@@ -767,11 +767,11 @@ if __name__ == '__main__':
|
||||
|
||||
actions = torch.randn(2, 7, 20) # actions for learning
|
||||
|
||||
loss = vat(images, audio, actions = actions, tasks = tasks, extra = extra, freeze_vit = True)
|
||||
loss = vaat(images, audio, actions = actions, tasks = tasks, extra = extra, freeze_vit = True)
|
||||
loss.backward()
|
||||
|
||||
# after much training
|
||||
|
||||
pred_actions, hiddens = vat(images, audio, tasks = tasks, extra = extra, return_hiddens = True)
|
||||
pred_actions, hiddens = vaat(images, audio, tasks = tasks, extra = extra, return_hiddens = True)
|
||||
|
||||
assert pred_actions.shape == (2, 7, 20)
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.nn import Module, ModuleList
|
||||
|
||||
from einops import rearrange, repeat
|
||||
from einops.layers.torch import Rearrange
|
||||
@@ -11,7 +12,7 @@ def pair(t):
|
||||
|
||||
# classes
|
||||
|
||||
class FeedForward(nn.Module):
|
||||
class FeedForward(Module):
|
||||
def __init__(self, dim, hidden_dim, dropout = 0.):
|
||||
super().__init__()
|
||||
self.net = nn.Sequential(
|
||||
@@ -26,7 +27,7 @@ class FeedForward(nn.Module):
|
||||
def forward(self, x):
|
||||
return self.net(x)
|
||||
|
||||
class Attention(nn.Module):
|
||||
class Attention(Module):
|
||||
def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
|
||||
super().__init__()
|
||||
inner_dim = dim_head * heads
|
||||
@@ -62,13 +63,14 @@ class Attention(nn.Module):
|
||||
out = rearrange(out, 'b h n d -> b n (h d)')
|
||||
return self.to_out(out)
|
||||
|
||||
class Transformer(nn.Module):
|
||||
class Transformer(Module):
|
||||
def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
|
||||
super().__init__()
|
||||
self.norm = nn.LayerNorm(dim)
|
||||
self.layers = nn.ModuleList([])
|
||||
self.layers = ModuleList([])
|
||||
|
||||
for _ in range(depth):
|
||||
self.layers.append(nn.ModuleList([
|
||||
self.layers.append(ModuleList([
|
||||
Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout),
|
||||
FeedForward(dim, mlp_dim, dropout = dropout)
|
||||
]))
|
||||
@@ -80,7 +82,7 @@ class Transformer(nn.Module):
|
||||
|
||||
return self.norm(x)
|
||||
|
||||
class ViT(nn.Module):
|
||||
class ViT(Module):
|
||||
def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, pool = 'cls', channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0.):
|
||||
super().__init__()
|
||||
image_height, image_width = pair(image_size)
|
||||
@@ -90,7 +92,9 @@ class ViT(nn.Module):
|
||||
|
||||
num_patches = (image_height // patch_height) * (image_width // patch_width)
|
||||
patch_dim = channels * patch_height * patch_width
|
||||
|
||||
assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'
|
||||
num_cls_tokens = 1 if pool == 'cls' else 0
|
||||
|
||||
self.to_patch_embedding = nn.Sequential(
|
||||
Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_height, p2 = patch_width),
|
||||
@@ -99,8 +103,9 @@ class ViT(nn.Module):
|
||||
nn.LayerNorm(dim),
|
||||
)
|
||||
|
||||
self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
|
||||
self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
|
||||
self.cls_token = nn.Parameter(torch.randn(num_cls_tokens, dim))
|
||||
self.pos_embedding = nn.Parameter(torch.randn(num_patches + num_cls_tokens, dim))
|
||||
|
||||
self.dropout = nn.Dropout(emb_dropout)
|
||||
|
||||
self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)
|
||||
@@ -111,12 +116,15 @@ class ViT(nn.Module):
|
||||
self.mlp_head = nn.Linear(dim, num_classes)
|
||||
|
||||
def forward(self, img):
|
||||
batch = img.shape[0]
|
||||
x = self.to_patch_embedding(img)
|
||||
b, n, _ = x.shape
|
||||
|
||||
cls_tokens = repeat(self.cls_token, '1 1 d -> b 1 d', b = b)
|
||||
x = torch.cat((cls_tokens, x), dim=1)
|
||||
x += self.pos_embedding[:, :(n + 1)]
|
||||
cls_tokens = repeat(self.cls_token, '... d -> b ... d', b = batch)
|
||||
x = torch.cat((cls_tokens, x), dim = 1)
|
||||
|
||||
seq = x.shape[1]
|
||||
|
||||
x = x + self.pos_embedding[:seq]
|
||||
x = self.dropout(x)
|
||||
|
||||
x = self.transformer(x)
|
||||
|
||||
@@ -89,7 +89,7 @@ class ViT(nn.Module):
|
||||
assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'
|
||||
|
||||
self.to_patch_embedding = nn.Sequential(
|
||||
Rearrange('b c (f pf) (h p1) (w p2) -> b (f h w) (p1 p2 pf c)', p1 = patch_height, p2 = patch_width, pf = frame_patch_size),
|
||||
Rearrange('b c (f pf) (h p1) (w p2) -> b (f h w) (pf p1 p2 c)', p1 = patch_height, p2 = patch_width, pf = frame_patch_size),
|
||||
nn.LayerNorm(patch_dim),
|
||||
nn.Linear(patch_dim, dim),
|
||||
nn.LayerNorm(dim),
|
||||
|
||||
@@ -141,7 +141,7 @@ class ViT(nn.Module):
|
||||
self.global_average_pool = pool == 'mean'
|
||||
|
||||
self.to_patch_embedding = nn.Sequential(
|
||||
Rearrange('b c (f pf) (h p1) (w p2) -> b f (h w) (p1 p2 pf c)', p1 = patch_height, p2 = patch_width, pf = frame_patch_size),
|
||||
Rearrange('b c (f pf) (h p1) (w p2) -> b f (h w) (pf p1 p2 c)', p1 = patch_height, p2 = patch_width, pf = frame_patch_size),
|
||||
nn.LayerNorm(patch_dim),
|
||||
nn.Linear(patch_dim, dim),
|
||||
nn.LayerNorm(dim)
|
||||
|
||||
Reference in New Issue
Block a user