mirror of
https://github.com/lucidrains/vit-pytorch.git
synced 2025-12-30 08:02:29 +00:00
Compare commits
3 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
077d8c188f | ||
|
|
5888f05300 | ||
|
|
d518e89573 |
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
|
||||
|
||||
[project]
|
||||
name = "vit-pytorch"
|
||||
version = "1.16.3"
|
||||
version = "1.16.5"
|
||||
description = "Vision Transformer (ViT) - Pytorch"
|
||||
readme = { file = "README.md", content-type = "text/markdown" }
|
||||
license = { file = "LICENSE" }
|
||||
|
||||
@@ -25,12 +25,12 @@ class DistillMixin:
|
||||
x = self.to_patch_embedding(img)
|
||||
b, n, _ = x.shape
|
||||
|
||||
cls_tokens = repeat(self.cls_token, '1 n d -> b n d', b = b)
|
||||
cls_tokens = repeat(self.cls_token, 'n d -> b n d', b = b)
|
||||
x = torch.cat((cls_tokens, x), dim = 1)
|
||||
x += self.pos_embedding[:, :(n + 1)]
|
||||
x += self.pos_embedding[:(n + 1)]
|
||||
|
||||
if distilling:
|
||||
distill_tokens = repeat(distill_token, '1 n d -> b n d', b = b)
|
||||
distill_tokens = repeat(distill_token, 'n d -> b n d', b = b)
|
||||
x = torch.cat((x, distill_tokens), dim = 1)
|
||||
|
||||
x = self._attend(x)
|
||||
@@ -125,7 +125,7 @@ class DistillWrapper(Module):
|
||||
self.alpha = alpha
|
||||
self.hard = hard
|
||||
|
||||
self.distillation_token = nn.Parameter(torch.randn(1, 1, dim))
|
||||
self.distillation_token = nn.Parameter(torch.randn(1, dim))
|
||||
|
||||
self.distill_mlp = nn.Sequential(
|
||||
nn.LayerNorm(dim) if mlp_layernorm else nn.Identity(),
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from functools import partial
|
||||
from functools import partial, lru_cache
|
||||
from typing import List
|
||||
|
||||
import torch
|
||||
@@ -27,6 +27,12 @@ def pair(t):
|
||||
def divisible_by(numer, denom):
|
||||
return (numer % denom) == 0
|
||||
|
||||
@lru_cache(maxsize=128)
|
||||
def posemb_grid(ph, pw, device):
|
||||
h_idx = torch.arange(ph, device=device).repeat_interleave(pw)
|
||||
w_idx = torch.arange(pw, device=device).repeat(ph)
|
||||
return torch.stack([h_idx, w_idx], dim=-1)
|
||||
|
||||
# auto grouping images
|
||||
|
||||
def group_images_by_max_seq_len(
|
||||
@@ -293,12 +299,8 @@ class NaViT(nn.Module):
|
||||
# extract patches for all images
|
||||
sequences = [rearrange(img, 'c (h p1) (w p2) -> (h w) (c p1 p2)', p1=p, p2=p) for img in images]
|
||||
|
||||
# compute positions using repeat_interleave (faster than meshgrid per image)
|
||||
positions = []
|
||||
for ph, pw in patch_dims:
|
||||
h_idx = arange(ph).repeat_interleave(pw)
|
||||
w_idx = arange(pw).repeat(ph)
|
||||
positions.append(torch.stack([h_idx, w_idx], dim=-1))
|
||||
# compute positions - uses lru_cache to avoid redundant computation across forward passes
|
||||
positions = [posemb_grid(ph, pw, device) for ph, pw in patch_dims]
|
||||
|
||||
# handle token dropout
|
||||
if has_token_dropout:
|
||||
|
||||
Reference in New Issue
Block a user