lucidrains
2025-11-27 06:07:43 -08:00
parent aa49c2783a
commit 0ebd4edab9
3 changed files with 10 additions and 7 deletions

View File

@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
[project]
name = "vit-pytorch"
version = "1.15.7"
version = "1.16.0"
description = "Vision Transformer (ViT) - Pytorch"
readme = { file = "README.md", content-type = "text/markdown" }
license = { file = "LICENSE" }

View File

@@ -735,7 +735,7 @@ if __name__ == '__main__':
mlp_dim = 384 * 4
)
vat = VAAT(
vaat = VAAT(
vit,
ast,
dim = 512,
@@ -767,11 +767,11 @@ if __name__ == '__main__':
actions = torch.randn(2, 7, 20) # actions for learning
loss = vat(images, audio, actions = actions, tasks = tasks, extra = extra, freeze_vit = True)
loss = vaat(images, audio, actions = actions, tasks = tasks, extra = extra, freeze_vit = True)
loss.backward()
# after much training
pred_actions, hiddens = vat(images, audio, tasks = tasks, extra = extra, return_hiddens = True)
pred_actions, hiddens = vaat(images, audio, tasks = tasks, extra = extra, return_hiddens = True)
assert pred_actions.shape == (2, 7, 20)

View File

@@ -90,7 +90,9 @@ class ViT(nn.Module):
num_patches = (image_height // patch_height) * (image_width // patch_width)
patch_dim = channels * patch_height * patch_width
assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'
num_cls_tokens = 1 if pool == 'cls' else 0
self.to_patch_embedding = nn.Sequential(
Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_height, p2 = patch_width),
@@ -99,8 +101,9 @@ class ViT(nn.Module):
nn.LayerNorm(dim),
)
self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
self.cls_token = nn.Parameter(torch.randn(1, num_cls_tokens, dim))
self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + num_cls_tokens, dim))
self.dropout = nn.Dropout(emb_dropout)
self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)
@@ -114,7 +117,7 @@ class ViT(nn.Module):
x = self.to_patch_embedding(img)
b, n, _ = x.shape
cls_tokens = repeat(self.cls_token, '1 1 d -> b 1 d', b = b)
cls_tokens = repeat(self.cls_token, '1 ... d -> b ... d', b = b)
x = torch.cat((cls_tokens, x), dim=1)
x += self.pos_embedding[:, :(n + 1)]
x = self.dropout(x)