mirror of
https://github.com/lucidrains/vit-pytorch.git
synced 2026-05-14 12:18:06 +00:00
Compare commits
3 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
6032a54b48 | ||
|
|
06a1f42924 | ||
|
|
6ae6a3ab64 |
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
|
||||
|
||||
[project]
|
||||
name = "vit-pytorch"
|
||||
version = "1.17.7"
|
||||
version = "1.17.8"
|
||||
description = "Vision Transformer (ViT) - Pytorch"
|
||||
readme = { file = "README.md", content-type = "text/markdown" }
|
||||
license = { file = "LICENSE" }
|
||||
|
||||
@@ -189,7 +189,7 @@ class ViT(Module):
|
||||
|
||||
x = self.transformer(x)
|
||||
|
||||
if self.mlp_head is None:
|
||||
if not exists(self.mlp_head):
|
||||
return x
|
||||
|
||||
x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0]
|
||||
|
||||
@@ -89,7 +89,6 @@ class Attention(Module):
|
||||
dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
|
||||
|
||||
if exists(mask):
|
||||
mask = rearrange(mask, 'b j -> b 1 1 j')
|
||||
dots = dots.masked_fill(~mask, -torch.finfo(dots.dtype).max)
|
||||
|
||||
attn = self.attend(dots)
|
||||
@@ -109,7 +108,7 @@ class Transformer(Module):
|
||||
self.layers = ModuleList([])
|
||||
for _ in range(depth):
|
||||
self.layers.append(nn.ModuleList([
|
||||
Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout),
|
||||
Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout, use_flash_attn = use_flash_attn),
|
||||
FeedForward(dim, mlp_dim, dropout = dropout)
|
||||
]))
|
||||
|
||||
|
||||
Reference in New Issue
Block a user