mirror of
https://github.com/lucidrains/vit-pytorch.git
synced 2026-03-12 19:00:16 +00:00
Compare commits
4 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
d65a8c17a5 | ||
|
|
f7c164d910 | ||
|
|
c7b74e0bc3 | ||
|
|
5b5d98a3a7 |
2
setup.py
2
setup.py
@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
|
||||
setup(
|
||||
name = 'vit-pytorch',
|
||||
packages = find_packages(exclude=['examples']),
|
||||
version = '0.2.2',
|
||||
version = '0.2.4',
|
||||
license='MIT',
|
||||
description = 'Vision Transformer (ViT) - Pytorch',
|
||||
author = 'Phil Wang',
|
||||
|
||||
@@ -3,6 +3,8 @@ import torch.nn.functional as F
|
||||
from einops import rearrange
|
||||
from torch import nn
|
||||
|
||||
MIN_NUM_PATCHES = 16
|
||||
|
||||
class Residual(nn.Module):
|
||||
def __init__(self, fn):
|
||||
super().__init__()
|
||||
@@ -85,6 +87,7 @@ class ViT(nn.Module):
|
||||
assert image_size % patch_size == 0, 'image dimensions must be divisible by the patch size'
|
||||
num_patches = (image_size // patch_size) ** 2
|
||||
patch_dim = channels * patch_size ** 2
|
||||
assert num_patches > MIN_NUM_PATCHES, f'your number of patches ({num_patches}) is way too small for attention to be effective. try decreasing your patch size'
|
||||
|
||||
self.patch_size = patch_size
|
||||
|
||||
@@ -102,8 +105,7 @@ class ViT(nn.Module):
|
||||
nn.Linear(dim, mlp_dim),
|
||||
nn.GELU(),
|
||||
nn.Dropout(dropout),
|
||||
nn.Linear(mlp_dim, num_classes),
|
||||
nn.Dropout(dropout)
|
||||
nn.Linear(mlp_dim, num_classes)
|
||||
)
|
||||
|
||||
def forward(self, img, mask = None):
|
||||
|
||||
Reference in New Issue
Block a user