From dc5b89c9426f26c58ba34522ca0bac12dd58fa61 Mon Sep 17 00:00:00 2001 From: Phil Wang Date: Wed, 28 Oct 2020 18:13:57 -0700 Subject: [PATCH] use einops repeat --- vit_pytorch/efficient.py | 4 ++-- vit_pytorch/vit_pytorch.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/vit_pytorch/efficient.py b/vit_pytorch/efficient.py index 2f5d798..70c34df 100644 --- a/vit_pytorch/efficient.py +++ b/vit_pytorch/efficient.py @@ -1,5 +1,5 @@ import torch -from einops import rearrange +from einops import rearrange, repeat from torch import nn class ViT(nn.Module): @@ -32,7 +32,7 @@ class ViT(nn.Module): x = self.patch_to_embedding(x) b, n, _ = x.shape - cls_tokens = self.cls_token.expand(b, -1, -1) + cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b) x = torch.cat((cls_tokens, x), dim=1) x += self.pos_embedding[:, :(n + 1)] x = self.transformer(x) diff --git a/vit_pytorch/vit_pytorch.py b/vit_pytorch/vit_pytorch.py index 910c8b1..e30c57f 100644 --- a/vit_pytorch/vit_pytorch.py +++ b/vit_pytorch/vit_pytorch.py @@ -1,6 +1,6 @@ import torch import torch.nn.functional as F -from einops import rearrange +from einops import rearrange, repeat from torch import nn MIN_NUM_PATCHES = 16 @@ -115,7 +115,7 @@ class ViT(nn.Module): x = self.patch_to_embedding(x) b, n, _ = x.shape - cls_tokens = self.cls_token.expand(b, -1, -1) + cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b) x = torch.cat((cls_tokens, x), dim=1) x += self.pos_embedding[:, :(n + 1)] x = self.dropout(x)