mirror of
https://github.com/lucidrains/vit-pytorch.git
synced 2025-12-30 08:02:29 +00:00
use einops repeat
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
import torch
|
||||
from einops import rearrange
|
||||
from einops import rearrange, repeat
|
||||
from torch import nn
|
||||
|
||||
class ViT(nn.Module):
|
||||
@@ -32,7 +32,7 @@ class ViT(nn.Module):
|
||||
x = self.patch_to_embedding(x)
|
||||
b, n, _ = x.shape
|
||||
|
||||
cls_tokens = self.cls_token.expand(b, -1, -1)
|
||||
cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b)
|
||||
x = torch.cat((cls_tokens, x), dim=1)
|
||||
x += self.pos_embedding[:, :(n + 1)]
|
||||
x = self.transformer(x)
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from einops import rearrange
|
||||
from einops import rearrange, repeat
|
||||
from torch import nn
|
||||
|
||||
MIN_NUM_PATCHES = 16
|
||||
@@ -115,7 +115,7 @@ class ViT(nn.Module):
|
||||
x = self.patch_to_embedding(x)
|
||||
b, n, _ = x.shape
|
||||
|
||||
cls_tokens = self.cls_token.expand(b, -1, -1)
|
||||
cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b)
|
||||
x = torch.cat((cls_tokens, x), dim=1)
|
||||
x += self.pos_embedding[:, :(n + 1)]
|
||||
x = self.dropout(x)
|
||||
|
||||
Reference in New Issue
Block a user