Compare commits

...

3 Commits

Author SHA1 Message Date
Phil Wang
25b384297d return None from extractor if no attention layers 2022-01-28 17:49:58 -08:00
Phil Wang
64a07f50e6 epsilon should be inside square root 2022-01-24 17:24:41 -08:00
Phil Wang
126d204ff2 fix block repeats in readme example for Nest 2022-01-22 21:32:53 -08:00
7 changed files with 11 additions and 11 deletions

View File

@@ -542,7 +542,7 @@ nest = NesT(
dim = 96,
heads = 3,
num_hierarchies = 3, # number of hierarchies
block_repeats = (8, 4, 1), # the number of transformer blocks at each heirarchy, starting from the bottom
block_repeats = (2, 2, 8), # the number of transformer blocks at each heirarchy, starting from the bottom
num_classes = 1000
)

View File

@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
setup(
name = 'vit-pytorch',
packages = find_packages(exclude=['examples']),
version = '0.26.4',
version = '0.26.6',
license='MIT',
description = 'Vision Transformer (ViT) - Pytorch',
author = 'Phil Wang',

View File

@@ -62,9 +62,9 @@ class LayerNorm(nn.Module):
self.b = nn.Parameter(torch.zeros(1, dim, 1, 1))
def forward(self, x):
std = torch.var(x, dim = 1, unbiased = False, keepdim = True).sqrt()
var = torch.var(x, dim = 1, unbiased = False, keepdim = True)
mean = torch.mean(x, dim = 1, keepdim = True)
return (x - mean) / (std + self.eps) * self.g + self.b
return (x - mean) / (var + self.eps).sqrt() * self.g + self.b
def FeedForward(dim, mult = 4, dropout = 0.):
return nn.Sequential(

View File

@@ -30,9 +30,9 @@ class LayerNorm(nn.Module): # layernorm, but done in the channel dimension #1
self.b = nn.Parameter(torch.zeros(1, dim, 1, 1))
def forward(self, x):
std = torch.var(x, dim = 1, unbiased = False, keepdim = True).sqrt()
var = torch.var(x, dim = 1, unbiased = False, keepdim = True)
mean = torch.mean(x, dim = 1, keepdim = True)
return (x - mean) / (std + self.eps) * self.g + self.b
return (x - mean) / (var + self.eps).sqrt() * self.g + self.b
class PreNorm(nn.Module):
def __init__(self, dim, fn):

View File

@@ -20,9 +20,9 @@ class LayerNorm(nn.Module):
self.b = nn.Parameter(torch.zeros(1, dim, 1, 1))
def forward(self, x):
std = torch.var(x, dim = 1, unbiased = False, keepdim = True).sqrt()
var = torch.var(x, dim = 1, unbiased = False, keepdim = True)
mean = torch.mean(x, dim = 1, keepdim = True)
return (x - mean) / (std + self.eps) * self.g + self.b
return (x - mean) / (var + self.eps).sqrt() * self.g + self.b
class PreNorm(nn.Module):
def __init__(self, dim, fn):

View File

@@ -55,5 +55,5 @@ class Recorder(nn.Module):
target_device = self.device if self.device is not None else img.device
recordings = tuple(map(lambda t: t.to(target_device), self.recordings))
attns = torch.stack(recordings, dim = 1)
attns = torch.stack(recordings, dim = 1) if len(recordings) > 0 else None
return pred, attns

View File

@@ -38,9 +38,9 @@ class LayerNorm(nn.Module):
self.b = nn.Parameter(torch.zeros(1, dim, 1, 1))
def forward(self, x):
std = torch.var(x, dim = 1, unbiased = False, keepdim = True).sqrt()
var = torch.var(x, dim = 1, unbiased = False, keepdim = True)
mean = torch.mean(x, dim = 1, keepdim = True)
return (x - mean) / (std + self.eps) * self.g + self.b
return (x - mean) / (var + self.eps).sqrt() * self.g + self.b
class PreNorm(nn.Module):
def __init__(self, dim, fn):