mirror of https://github.com/lucidrains/vit-pytorch.git
synced 2025-12-30 08:02:29 +00:00

Compare commits (1 commit)
Commit ad80b6c51e
pyproject.toml

@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
 
 [project]
 name = "vit-pytorch"
-version = "1.16.0"
+version = "1.16.1"
 description = "Vision Transformer (ViT) - Pytorch"
 readme = { file = "README.md", content-type = "text/markdown" }
 license = { file = "LICENSE" }
vit_pytorch/vit.py

@@ -1,5 +1,6 @@
 import torch
 from torch import nn
+from torch.nn import Module, ModuleList
 
 from einops import rearrange, repeat
 from einops.layers.torch import Rearrange
@@ -11,7 +12,7 @@ def pair(t):
 
 # classes
 
-class FeedForward(nn.Module):
+class FeedForward(Module):
     def __init__(self, dim, hidden_dim, dropout = 0.):
         super().__init__()
         self.net = nn.Sequential(

@@ -26,7 +27,7 @@ class FeedForward(nn.Module):
     def forward(self, x):
         return self.net(x)
 
-class Attention(nn.Module):
+class Attention(Module):
     def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
         super().__init__()
         inner_dim = dim_head * heads
@@ -62,13 +63,14 @@ class Attention(nn.Module):
         out = rearrange(out, 'b h n d -> b n (h d)')
         return self.to_out(out)
 
-class Transformer(nn.Module):
+class Transformer(Module):
     def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
         super().__init__()
         self.norm = nn.LayerNorm(dim)
-        self.layers = nn.ModuleList([])
+        self.layers = ModuleList([])
 
         for _ in range(depth):
-            self.layers.append(nn.ModuleList([
+            self.layers.append(ModuleList([
                 Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout),
                 FeedForward(dim, mlp_dim, dropout = dropout)
             ]))
@@ -80,7 +82,7 @@ class Transformer(nn.Module):
 
         return self.norm(x)
 
-class ViT(nn.Module):
+class ViT(Module):
     def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, pool = 'cls', channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0.):
         super().__init__()
         image_height, image_width = pair(image_size)
@@ -101,8 +103,9 @@ class ViT(nn.Module):
             nn.LayerNorm(dim),
         )
 
-        self.cls_token = nn.Parameter(torch.randn(1, num_cls_tokens, dim))
-        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + num_cls_tokens, dim))
+        self.num_cls_tokens = num_cls_tokens
+        self.cls_token = nn.Parameter(torch.randn(num_cls_tokens, dim))
+        self.pos_embedding = nn.Parameter(torch.randn(num_patches + num_cls_tokens, dim))
 
         self.dropout = nn.Dropout(emb_dropout)
@@ -114,12 +117,15 @@ class ViT(nn.Module):
         self.mlp_head = nn.Linear(dim, num_classes)
 
     def forward(self, img):
+        batch = img.shape[0]
         x = self.to_patch_embedding(img)
-        b, n, _ = x.shape
 
-        cls_tokens = repeat(self.cls_token, '1 ... d -> b ... d', b = b)
-        x = torch.cat((cls_tokens, x), dim=1)
-        x += self.pos_embedding[:, :(n + 1)]
+        cls_tokens = repeat(self.cls_token, '... d -> b ... d', b = batch)
+        x = torch.cat((cls_tokens, x), dim = 1)
+
+        seq = x.shape[1]
+
+        x = x + self.pos_embedding[:seq]
         x = self.dropout(x)
 
         x = self.transformer(x)