diff --git a/README.md b/README.md
index ca500d2..bdb2bbf 100644
--- a/README.md
+++ b/README.md
@@ -382,7 +382,7 @@ pred = model(img) # (1, 1000)
-This paper decided to process the image in heirarchical stages, with attention only within tokens of local blocks, which aggregate as it moves up the heirarchy. The aggregation is done in the image plane, and contains a convolution to allow it to pass information across the boundary.
+This paper decided to process the image in hierarchical stages, with attention only within tokens of local blocks, which aggregate as it moves up the hierarchy. The aggregation is done in the image plane, and contains a convolution and subsequent maxpool to allow it to pass information across the boundary.
You can use it with the following code (ex. NesT-T)
@@ -395,7 +395,7 @@ nest = NesT(
patch_size = 4,
dim = 96,
heads = 3,
- num_heirarchies = 3, # number of heirarchies
+ num_hierarchies = 3, # number of hierarchies
block_repeats = (8, 4, 1), # the number of transformer blocks at each heirarchy, starting from the bottom
num_classes = 1000
)
diff --git a/setup.py b/setup.py
index 4cec5e1..6e0b1ed 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
setup(
name = 'vit-pytorch',
packages = find_packages(exclude=['examples']),
- version = '0.19.2',
+ version = '0.19.3',
license='MIT',
description = 'Vision Transformer (ViT) - Pytorch',
author = 'Phil Wang',
diff --git a/vit_pytorch/nest.py b/vit_pytorch/nest.py
index a98895f..5e9492a 100644
--- a/vit_pytorch/nest.py
+++ b/vit_pytorch/nest.py
@@ -114,7 +114,7 @@ class NesT(nn.Module):
num_classes,
dim,
heads,
- num_heirarchies,
+ num_hierarchies,
block_repeats,
mlp_mult = 4,
channels = 3,
@@ -126,11 +126,11 @@ class NesT(nn.Module):
num_patches = (image_size // patch_size) ** 2
patch_dim = channels * patch_size ** 2
fmap_size = image_size // patch_size
- blocks = 2 ** (num_heirarchies - 1)
+ blocks = 2 ** (num_hierarchies - 1)
seq_len = (fmap_size // blocks) ** 2 # sequence length is held constant across heirarchy
- heirarchies = list(reversed(range(num_heirarchies)))
- mults = [2 ** i for i in heirarchies]
+ hierarchies = list(reversed(range(num_hierarchies)))
+ mults = [2 ** i for i in hierarchies]
layer_heads = list(map(lambda t: t * heads, mults))
layer_dims = list(map(lambda t: t * dim, mults))
@@ -143,11 +143,11 @@ class NesT(nn.Module):
nn.Conv2d(patch_dim, layer_dims[0], 1),
)
- block_repeats = cast_tuple(block_repeats, num_heirarchies)
+ block_repeats = cast_tuple(block_repeats, num_hierarchies)
self.layers = nn.ModuleList([])
- for level, heads, (dim_in, dim_out), block_repeat in zip(heirarchies, layer_heads, dim_pairs, block_repeats):
+ for level, heads, (dim_in, dim_out), block_repeat in zip(hierarchies, layer_heads, dim_pairs, block_repeats):
is_last = level == 0
depth = block_repeat
@@ -166,9 +166,9 @@ class NesT(nn.Module):
x = self.to_patch_embedding(img)
b, c, h, w = x.shape
- num_heirarchies = len(self.layers)
+ num_hierarchies = len(self.layers)
- for level, (transformer, aggregate) in zip(reversed(range(num_heirarchies)), self.layers):
+ for level, (transformer, aggregate) in zip(reversed(range(num_hierarchies)), self.layers):
block_size = 2 ** level
x = rearrange(x, 'b c (b1 h) (b2 w) -> (b b1 b2) c h w', b1 = block_size, b2 = block_size)
x = transformer(x)