Mirror of https://github.com/lucidrains/vit-pytorch.git, synced 2025-12-30 08:02:29 +00:00
fix typo
README.md
@@ -382,7 +382,7 @@ pred = model(img) # (1, 1000)

 <img src="./images/nest.png" width="400px"></img>

-This <a href="https://arxiv.org/abs/2105.12723">paper</a> decided to process the image in heirarchical stages, with attention only within tokens of local blocks, which aggregate as it moves up the heirarchy. The aggregation is done in the image plane, and contains a convolution to allow it to pass information across the boundary.
+This <a href="https://arxiv.org/abs/2105.12723">paper</a> decided to process the image in hierarchical stages, with attention only within tokens of local blocks, which aggregate as it moves up the heirarchy. The aggregation is done in the image plane, and contains a convolution and subsequent maxpool to allow it to pass information across the boundary.

 You can use it with the following code (ex. NesT-T)

@@ -395,7 +395,7 @@ nest = NesT(
     patch_size = 4,
     dim = 96,
     heads = 3,
-    num_heirarchies = 3,        # number of heirarchies
+    num_hierarchies = 3,        # number of hierarchies
     block_repeats = (8, 4, 1),  # the number of transformer blocks at each heirarchy, starting from the bottom
     num_classes = 1000
 )
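For readers landing on this hunk without the surrounding README context, the example above is cut off by the diff. A minimal runnable sketch of the full NesT-T usage, assuming image_size = 224 and a random input tensor (both lie outside the lines shown in this hunk), would be:

```python
import torch
from vit_pytorch.nest import NesT

nest = NesT(
    image_size = 224,           # assumed; not visible in the hunk above
    patch_size = 4,
    dim = 96,
    heads = 3,
    num_hierarchies = 3,        # number of hierarchies
    block_repeats = (8, 4, 1),  # number of transformer blocks at each hierarchy, starting from the bottom
    num_classes = 1000
)

img = torch.randn(1, 3, 224, 224)
pred = nest(img)  # (1, 1000)
```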
setup.py
@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
 setup(
   name = 'vit-pytorch',
   packages = find_packages(exclude=['examples']),
-  version = '0.19.2',
+  version = '0.19.3',
   license='MIT',
   description = 'Vision Transformer (ViT) - Pytorch',
   author = 'Phil Wang',
vit_pytorch/nest.py
@@ -114,7 +114,7 @@ class NesT(nn.Module):
         num_classes,
         dim,
         heads,
-        num_heirarchies,
+        num_hierarchies,
         block_repeats,
         mlp_mult = 4,
         channels = 3,
@@ -126,11 +126,11 @@ class NesT(nn.Module):
         num_patches = (image_size // patch_size) ** 2
         patch_dim = channels * patch_size ** 2
         fmap_size = image_size // patch_size
-        blocks = 2 ** (num_heirarchies - 1)
+        blocks = 2 ** (num_hierarchies - 1)

         seq_len = (fmap_size // blocks) ** 2 # sequence length is held constant across heirarchy
-        heirarchies = list(reversed(range(num_heirarchies)))
-        mults = [2 ** i for i in heirarchies]
+        hierarchies = list(reversed(range(num_hierarchies)))
+        mults = [2 ** i for i in hierarchies]

         layer_heads = list(map(lambda t: t * heads, mults))
         layer_dims = list(map(lambda t: t * dim, mults))
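As a sanity check on the expressions this rename touches, plugging in the README's NesT-T settings (patch_size 4, num_hierarchies 3; image_size 224 is assumed) gives the following values. This is an illustrative sketch, not part of the commit:

```python
# hypothetical NesT-T numbers, mirroring the expressions in the hunk above
image_size, patch_size, num_hierarchies = 224, 4, 3   # image_size is an assumption

fmap_size = image_size // patch_size   # 56: side length of the patch-embedded feature map
blocks = 2 ** (num_hierarchies - 1)    # 4: blocks per side at the bottom of the hierarchy
seq_len = (fmap_size // blocks) ** 2   # 14 * 14 = 196 tokens per local block, constant across hierarchies

hierarchies = list(reversed(range(num_hierarchies)))   # [2, 1, 0]
mults = [2 ** i for i in hierarchies]                  # [4, 2, 1], used to scale heads and dim per level
```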
@@ -143,11 +143,11 @@ class NesT(nn.Module):
             nn.Conv2d(patch_dim, layer_dims[0], 1),
         )

-        block_repeats = cast_tuple(block_repeats, num_heirarchies)
+        block_repeats = cast_tuple(block_repeats, num_hierarchies)

         self.layers = nn.ModuleList([])

-        for level, heads, (dim_in, dim_out), block_repeat in zip(heirarchies, layer_heads, dim_pairs, block_repeats):
+        for level, heads, (dim_in, dim_out), block_repeat in zip(hierarchies, layer_heads, dim_pairs, block_repeats):
             is_last = level == 0
             depth = block_repeat

@@ -166,9 +166,9 @@ class NesT(nn.Module):
         x = self.to_patch_embedding(img)
         b, c, h, w = x.shape

-        num_heirarchies = len(self.layers)
+        num_hierarchies = len(self.layers)

-        for level, (transformer, aggregate) in zip(reversed(range(num_heirarchies)), self.layers):
+        for level, (transformer, aggregate) in zip(reversed(range(num_hierarchies)), self.layers):
             block_size = 2 ** level
             x = rearrange(x, 'b c (b1 h) (b2 w) -> (b b1 b2) c h w', b1 = block_size, b2 = block_size)
             x = transformer(x)
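To make the renamed loop above easier to follow: the `rearrange` call folds the spatial grid of local blocks into the batch dimension, so each transformer only attends within its own block. A self-contained sketch of that reshaping, with hypothetical sizes unrelated to the commit:

```python
import torch
from einops import rearrange

# hypothetical sizes: batch 1, 32 channels, a 16x16 feature map
x = torch.randn(1, 32, 16, 16)
block_size = 4   # e.g. level = 2 in the loop above gives 2 ** 2 = 4 blocks per side

# same pattern as in forward(): fold the 4x4 grid of blocks into the batch dimension
blocks = rearrange(x, 'b c (b1 h) (b2 w) -> (b b1 b2) c h w', b1 = block_size, b2 = block_size)
print(blocks.shape)              # torch.Size([16, 32, 4, 4]) -> attention only sees one 4x4 block

# the inverse pattern restores the image-plane layout after the transformer runs
restored = rearrange(blocks, '(b b1 b2) c h w -> b c (b1 h) (b2 w)', b1 = block_size, b2 = block_size)
print(torch.equal(x, restored))  # True: the reshaping is lossless
```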