Compare commits

...

2 Commits

Author      SHA1        Message                                                        Date
lucidrains  ad80b6c51e  fix positional embed for mean pool case and cleanup           2025-11-27 16:56:36 -08:00
lucidrains  0ebd4edab9  address https://github.com/lucidrains/vit-pytorch/issues/351  2025-11-27 06:07:43 -08:00
3 changed files with 25 additions and 16 deletions

pyproject.toml

@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
 [project]
 name = "vit-pytorch"
-version = "1.15.7"
+version = "1.16.1"
 description = "Vision Transformer (ViT) - Pytorch"
 readme = { file = "README.md", content-type = "text/markdown" }
 license = { file = "LICENSE" }


@@ -735,7 +735,7 @@ if __name__ == '__main__':
         mlp_dim = 384 * 4
     )

-    vat = VAAT(
+    vaat = VAAT(
         vit,
         ast,
         dim = 512,
@@ -767,11 +767,11 @@ if __name__ == '__main__':
     actions = torch.randn(2, 7, 20) # actions for learning

-    loss = vat(images, audio, actions = actions, tasks = tasks, extra = extra, freeze_vit = True)
+    loss = vaat(images, audio, actions = actions, tasks = tasks, extra = extra, freeze_vit = True)
     loss.backward()

     # after much training

-    pred_actions, hiddens = vat(images, audio, tasks = tasks, extra = extra, return_hiddens = True)
+    pred_actions, hiddens = vaat(images, audio, tasks = tasks, extra = extra, return_hiddens = True)
     assert pred_actions.shape == (2, 7, 20)

vit_pytorch/vit.py

@@ -1,5 +1,6 @@
 import torch
 from torch import nn
+from torch.nn import Module, ModuleList

 from einops import rearrange, repeat
 from einops.layers.torch import Rearrange
@@ -11,7 +12,7 @@ def pair(t):
 # classes

-class FeedForward(nn.Module):
+class FeedForward(Module):
     def __init__(self, dim, hidden_dim, dropout = 0.):
         super().__init__()
         self.net = nn.Sequential(
@@ -26,7 +27,7 @@ class FeedForward(nn.Module):
     def forward(self, x):
         return self.net(x)

-class Attention(nn.Module):
+class Attention(Module):
     def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
         super().__init__()
         inner_dim = dim_head * heads
@@ -62,13 +63,14 @@ class Attention(nn.Module):
         out = rearrange(out, 'b h n d -> b n (h d)')
         return self.to_out(out)

-class Transformer(nn.Module):
+class Transformer(Module):
     def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
         super().__init__()
         self.norm = nn.LayerNorm(dim)
-        self.layers = nn.ModuleList([])
+        self.layers = ModuleList([])
         for _ in range(depth):
-            self.layers.append(nn.ModuleList([
+            self.layers.append(ModuleList([
                 Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout),
                 FeedForward(dim, mlp_dim, dropout = dropout)
             ]))
@@ -80,7 +82,7 @@ class Transformer(nn.Module):
         return self.norm(x)

-class ViT(nn.Module):
+class ViT(Module):
     def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, pool = 'cls', channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0.):
         super().__init__()
         image_height, image_width = pair(image_size)
@@ -90,7 +92,9 @@ class ViT(nn.Module):
         num_patches = (image_height // patch_height) * (image_width // patch_width)
         patch_dim = channels * patch_height * patch_width
         assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'

+        num_cls_tokens = 1 if pool == 'cls' else 0
+
         self.to_patch_embedding = nn.Sequential(
             Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_height, p2 = patch_width),
@@ -99,8 +103,10 @@ class ViT(nn.Module):
             nn.LayerNorm(dim),
         )

-        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
-        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
+        self.num_cls_tokens = num_cls_tokens
+        self.cls_token = nn.Parameter(torch.randn(num_cls_tokens, dim))
+        self.pos_embedding = nn.Parameter(torch.randn(num_patches + num_cls_tokens, dim))

         self.dropout = nn.Dropout(emb_dropout)
         self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)
@@ -111,12 +117,15 @@ class ViT(nn.Module):
         self.mlp_head = nn.Linear(dim, num_classes)

     def forward(self, img):
+        batch = img.shape[0]
+
         x = self.to_patch_embedding(img)
-        b, n, _ = x.shape

-        cls_tokens = repeat(self.cls_token, '1 1 d -> b 1 d', b = b)
-        x = torch.cat((cls_tokens, x), dim=1)
-        x += self.pos_embedding[:, :(n + 1)]
+        cls_tokens = repeat(self.cls_token, '... d -> b ... d', b = batch)
+        x = torch.cat((cls_tokens, x), dim = 1)
+
+        seq = x.shape[1]
+        x = x + self.pos_embedding[:seq]

         x = self.dropout(x)
         x = self.transformer(x)
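
For context (not part of the diff): a minimal usage sketch of the mean-pool fix, assuming the class is importable as "from vit_pytorch import ViT" and that the tail of ViT.forward (pooling, to_latent, mlp_head) is unchanged by these commits; the hyperparameters below are illustrative. With pool = 'mean', the patched constructor allocates zero cls tokens, so pos_embedding covers exactly num_patches positions rather than num_patches + 1.

import torch
from vit_pytorch import ViT

# 256px image with 32px patches -> 8 * 8 = 64 patches
v = ViT(
    image_size = 256,
    patch_size = 32,
    num_classes = 1000,
    dim = 1024,
    depth = 6,
    heads = 16,
    mlp_dim = 2048,
    pool = 'mean'   # num_cls_tokens == 0 under the patched code
)

# zero-row cls token, and no positional embedding row left unused
assert v.cls_token.shape == (0, 1024)
assert v.pos_embedding.shape == (64, 1024)

preds = v(torch.randn(1, 3, 256, 256))   # (1, 1000) class logits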