Compare commits

...

1 Commits
1.8.1 ... 0.9.2

Author SHA1 Message Date
Phil Wang
2d413a392c remove masking, as it complicates with little benefit 2021-03-23 12:16:32 -07:00
3 changed files with 9 additions and 9 deletions

View File

@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
setup(
name = 'vit-pytorch',
packages = find_packages(exclude=['examples']),
version = '0.9.1',
version = '0.9.2',
license='MIT',
description = 'Vision Transformer (ViT) - Pytorch',
author = 'Phil Wang',

View File

@@ -1,7 +1,7 @@
import torch
import torch.nn.functional as F
from torch import nn
from vit_pytorch.vit_pytorch import ViT
from vit_pytorch.vit import ViT
from vit_pytorch.t2t import T2TViT
from vit_pytorch.efficient import ViT as EfficientViT
@@ -15,7 +15,7 @@ def exists(val):
# classes
class DistillMixin:
def forward(self, img, distill_token = None, mask = None):
def forward(self, img, distill_token = None):
distilling = exists(distill_token)
x = self.to_patch_embedding(img)
b, n, _ = x.shape
@@ -28,7 +28,7 @@ class DistillMixin:
distill_tokens = repeat(distill_token, '() n d -> b n d', b = b)
x = torch.cat((x, distill_tokens), dim = 1)
x = self._attend(x, mask)
x = self._attend(x)
if distilling:
x, distill_tokens = x[:, :-1], x[:, -1]
@@ -56,9 +56,9 @@ class DistillableViT(DistillMixin, ViT):
v.load_state_dict(self.state_dict())
return v
def _attend(self, x, mask):
def _attend(self, x):
x = self.dropout(x)
x = self.transformer(x, mask)
x = self.transformer(x)
return x
class DistillableT2TViT(DistillMixin, T2TViT):
@@ -74,7 +74,7 @@ class DistillableT2TViT(DistillMixin, T2TViT):
v.load_state_dict(self.state_dict())
return v
def _attend(self, x, mask):
def _attend(self, x):
x = self.dropout(x)
x = self.transformer(x)
return x
@@ -92,7 +92,7 @@ class DistillableEfficientViT(DistillMixin, EfficientViT):
v.load_state_dict(self.state_dict())
return v
def _attend(self, x, mask):
def _attend(self, x):
return self.transformer(x)
# knowledge distillation wrapper

View File

@@ -2,7 +2,7 @@ import math
import torch
from torch import nn
from vit_pytorch.vit_pytorch import Transformer
from vit_pytorch.vit import Transformer
from einops import rearrange, repeat
from einops.layers.torch import Rearrange