use einops repeat

This commit is contained in:
Phil Wang
2020-10-28 18:13:57 -07:00
parent c1043ab00c
commit dc5b89c942
2 changed files with 4 additions and 4 deletions

View File

@@ -1,5 +1,5 @@
import torch
-from einops import rearrange
+from einops import rearrange, repeat
from torch import nn
class ViT(nn.Module):
@@ -32,7 +32,7 @@ class ViT(nn.Module):
x = self.patch_to_embedding(x)
b, n, _ = x.shape
-        cls_tokens = self.cls_token.expand(b, -1, -1)
+        cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b)
x = torch.cat((cls_tokens, x), dim=1)
x += self.pos_embedding[:, :(n + 1)]
x = self.transformer(x)

View File

@@ -1,6 +1,6 @@
import torch
import torch.nn.functional as F
-from einops import rearrange
+from einops import rearrange, repeat
from torch import nn
MIN_NUM_PATCHES = 16
@@ -115,7 +115,7 @@ class ViT(nn.Module):
x = self.patch_to_embedding(x)
b, n, _ = x.shape
-        cls_tokens = self.cls_token.expand(b, -1, -1)
+        cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b)
x = torch.cat((cls_tokens, x), dim=1)
x += self.pos_embedding[:, :(n + 1)]
x = self.dropout(x)