From 3cff5e547a188d817ff67ad88cdd9216a314dc64 Mon Sep 17 00:00:00 2001
From: lucidrains
Date: Tue, 2 Dec 2025 05:21:52 -0800
Subject: [PATCH] address https://github.com/lucidrains/vit-pytorch/issues/352

---
 pyproject.toml                          | 2 +-
 vit_pytorch/na_vit_nested_tensor_3d.py  | 2 +-
 vit_pytorch/simple_flash_attn_vit_3d.py | 2 +-
 vit_pytorch/simple_vit_3d.py            | 2 +-
 vit_pytorch/vit_3d.py                   | 2 +-
 vit_pytorch/vivit.py                    | 2 +-
 6 files changed, 6 insertions(+), 6 deletions(-)

diff --git a/pyproject.toml b/pyproject.toml
index eff48eb..23d429b 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
 
 [project]
 name = "vit-pytorch"
-version = "1.16.1"
+version = "1.16.2"
 description = "Vision Transformer (ViT) - Pytorch"
 readme = { file = "README.md", content-type = "text/markdown" }
 license = { file = "LICENSE" }
diff --git a/vit_pytorch/na_vit_nested_tensor_3d.py b/vit_pytorch/na_vit_nested_tensor_3d.py
index 0faf922..0531ac4 100644
--- a/vit_pytorch/na_vit_nested_tensor_3d.py
+++ b/vit_pytorch/na_vit_nested_tensor_3d.py
@@ -176,7 +176,7 @@ class NaViT(Module):
         self.channels = channels
         self.patch_size = patch_size
 
-        self.to_patches = Rearrange('c (f pf) (h p1) (w p2) -> f h w (c p1 p2 pf)', p1 = patch_size, p2 = patch_size, pf = frame_patch_size)
+        self.to_patches = Rearrange('c (f pf) (h p1) (w p2) -> f h w (c pf p1 p2)', p1 = patch_size, p2 = patch_size, pf = frame_patch_size)
 
         self.to_patch_embedding = nn.Sequential(
             nn.LayerNorm(patch_dim),
diff --git a/vit_pytorch/simple_flash_attn_vit_3d.py b/vit_pytorch/simple_flash_attn_vit_3d.py
index 8381c4a..8b84d84 100644
--- a/vit_pytorch/simple_flash_attn_vit_3d.py
+++ b/vit_pytorch/simple_flash_attn_vit_3d.py
@@ -146,7 +146,7 @@ class SimpleViT(Module):
         patch_dim = channels * patch_height * patch_width * frame_patch_size
 
         self.to_patch_embedding = nn.Sequential(
-            Rearrange('b c (f pf) (h p1) (w p2) -> b f h w (p1 p2 pf c)', p1 = patch_height, p2 = patch_width, pf = frame_patch_size),
+            Rearrange('b c (f pf) (h p1) (w p2) -> b f h w (pf p1 p2 c)', p1 = patch_height, p2 = patch_width, pf = frame_patch_size),
             nn.LayerNorm(patch_dim),
             nn.Linear(patch_dim, dim),
             nn.LayerNorm(dim),
diff --git a/vit_pytorch/simple_vit_3d.py b/vit_pytorch/simple_vit_3d.py
index 8a1460f..38a90a8 100644
--- a/vit_pytorch/simple_vit_3d.py
+++ b/vit_pytorch/simple_vit_3d.py
@@ -103,7 +103,7 @@ class SimpleViT(nn.Module):
         patch_dim = channels * patch_height * patch_width * frame_patch_size
 
         self.to_patch_embedding = nn.Sequential(
-            Rearrange('b c (f pf) (h p1) (w p2) -> b f h w (p1 p2 pf c)', p1 = patch_height, p2 = patch_width, pf = frame_patch_size),
+            Rearrange('b c (f pf) (h p1) (w p2) -> b f h w (pf p1 p2 c)', p1 = patch_height, p2 = patch_width, pf = frame_patch_size),
             nn.LayerNorm(patch_dim),
             nn.Linear(patch_dim, dim),
             nn.LayerNorm(dim),
diff --git a/vit_pytorch/vit_3d.py b/vit_pytorch/vit_3d.py
index a2058fb..c2c35c9 100644
--- a/vit_pytorch/vit_3d.py
+++ b/vit_pytorch/vit_3d.py
@@ -89,7 +89,7 @@ class ViT(nn.Module):
         assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'
 
         self.to_patch_embedding = nn.Sequential(
-            Rearrange('b c (f pf) (h p1) (w p2) -> b (f h w) (p1 p2 pf c)', p1 = patch_height, p2 = patch_width, pf = frame_patch_size),
+            Rearrange('b c (f pf) (h p1) (w p2) -> b (f h w) (pf p1 p2 c)', p1 = patch_height, p2 = patch_width, pf = frame_patch_size),
             nn.LayerNorm(patch_dim),
             nn.Linear(patch_dim, dim),
             nn.LayerNorm(dim),
diff --git a/vit_pytorch/vivit.py b/vit_pytorch/vivit.py
index 43ad0d3..76a04de 100644
--- a/vit_pytorch/vivit.py
+++ b/vit_pytorch/vivit.py
@@ -141,7 +141,7 @@ class ViT(nn.Module):
         self.global_average_pool = pool == 'mean'
 
         self.to_patch_embedding = nn.Sequential(
-            Rearrange('b c (f pf) (h p1) (w p2) -> b f (h w) (p1 p2 pf c)', p1 = patch_height, p2 = patch_width, pf = frame_patch_size),
+            Rearrange('b c (f pf) (h p1) (w p2) -> b f (h w) (pf p1 p2 c)', p1 = patch_height, p2 = patch_width, pf = frame_patch_size),
             nn.LayerNorm(patch_dim),
             nn.Linear(patch_dim, dim),
             nn.LayerNorm(dim)
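
Note (not part of the patch itself): the sketch below illustrates why the order of the axes inside the flattened group matters, using the before/after patterns from the na_vit_nested_tensor_3d.py hunk. It assumes torch and einops are available and uses made-up toy sizes; names like `old` and `new` are only for illustration.

import torch
from einops import rearrange

# toy video of shape (channels, frames, height, width) = (2, 2, 2, 2),
# i.e. a single 3D patch with c = pf = p1 = p2 = 2
c, pf, p1, p2 = 2, 2, 2, 2
video = torch.arange(c * pf * p1 * p2).reshape(c, pf, p1, p2)

# old ordering: frame patch axis (pf) flattened last
old = rearrange(video, 'c (f pf) (h p1) (w p2) -> f h w (c p1 p2 pf)', p1 = p1, p2 = p2, pf = pf)

# new ordering from this patch: pf flattened before the spatial patch axes
new = rearrange(video, 'c (f pf) (h p1) (w p2) -> f h w (c pf p1 p2)', p1 = p1, p2 = p2, pf = pf)

# same shape, but the values land in different positions of each patch vector,
# so the downstream nn.Linear patch embedding would see the pixels in a different order
assert old.shape == new.shape == (1, 1, 1, 16)
assert not torch.equal(old, new)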