Compare commits

...

1 Commits

Author SHA1 Message Date
Phil Wang
a254a0258a fix typo 2021-06-01 07:33:00 -07:00
3 changed files with 11 additions and 11 deletions

View File

@@ -382,7 +382,7 @@ pred = model(img) # (1, 1000)
<img src="./images/nest.png" width="400px"></img>
This <a href="https://arxiv.org/abs/2105.12723">paper</a> decided to process the image in heirarchical stages, with attention only within tokens of local blocks, which aggregate as it moves up the heirarchy. The aggregation is done in the image plane, and contains a convolution to allow it to pass information across the boundary.
This <a href="https://arxiv.org/abs/2105.12723">paper</a> decided to process the image in hierarchical stages, with attention only within tokens of local blocks, which aggregate as they move up the hierarchy. The aggregation is done in the image plane, and contains a convolution and subsequent maxpool to allow it to pass information across the boundary.
You can use it with the following code (ex. NesT-T)
@@ -395,7 +395,7 @@ nest = NesT(
patch_size = 4,
dim = 96,
heads = 3,
num_heirarchies = 3, # number of heirarchies
num_hierarchies = 3, # number of hierarchies
block_repeats = (8, 4, 1), # the number of transformer blocks at each hierarchy, starting from the bottom
num_classes = 1000
)

View File

@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
setup(
name = 'vit-pytorch',
packages = find_packages(exclude=['examples']),
version = '0.19.2',
version = '0.19.3',
license='MIT',
description = 'Vision Transformer (ViT) - Pytorch',
author = 'Phil Wang',

View File

@@ -114,7 +114,7 @@ class NesT(nn.Module):
num_classes,
dim,
heads,
num_heirarchies,
num_hierarchies,
block_repeats,
mlp_mult = 4,
channels = 3,
@@ -126,11 +126,11 @@ class NesT(nn.Module):
num_patches = (image_size // patch_size) ** 2
patch_dim = channels * patch_size ** 2
fmap_size = image_size // patch_size
blocks = 2 ** (num_heirarchies - 1)
blocks = 2 ** (num_hierarchies - 1)
seq_len = (fmap_size // blocks) ** 2 # sequence length is held constant across the hierarchy
heirarchies = list(reversed(range(num_heirarchies)))
mults = [2 ** i for i in heirarchies]
hierarchies = list(reversed(range(num_hierarchies)))
mults = [2 ** i for i in hierarchies]
layer_heads = list(map(lambda t: t * heads, mults))
layer_dims = list(map(lambda t: t * dim, mults))
@@ -143,11 +143,11 @@ class NesT(nn.Module):
nn.Conv2d(patch_dim, layer_dims[0], 1),
)
block_repeats = cast_tuple(block_repeats, num_heirarchies)
block_repeats = cast_tuple(block_repeats, num_hierarchies)
self.layers = nn.ModuleList([])
for level, heads, (dim_in, dim_out), block_repeat in zip(heirarchies, layer_heads, dim_pairs, block_repeats):
for level, heads, (dim_in, dim_out), block_repeat in zip(hierarchies, layer_heads, dim_pairs, block_repeats):
is_last = level == 0
depth = block_repeat
@@ -166,9 +166,9 @@ class NesT(nn.Module):
x = self.to_patch_embedding(img)
b, c, h, w = x.shape
num_heirarchies = len(self.layers)
num_hierarchies = len(self.layers)
for level, (transformer, aggregate) in zip(reversed(range(num_heirarchies)), self.layers):
for level, (transformer, aggregate) in zip(reversed(range(num_hierarchies)), self.layers):
block_size = 2 ** level
x = rearrange(x, 'b c (b1 h) (b2 w) -> (b b1 b2) c h w', b1 = block_size, b2 = block_size)
x = transformer(x)