mirror of
https://github.com/lucidrains/vit-pytorch.git
synced 2025-12-30 08:02:29 +00:00
Compare commits
3 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
aeed3381c1 | ||
|
|
3f754956fb | ||
|
|
918869571c |
2
setup.py
2
setup.py
@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
|
||||
setup(
|
||||
name = 'vit-pytorch',
|
||||
packages = find_packages(exclude=['examples']),
|
||||
version = '0.20.2',
|
||||
version = '0.20.5',
|
||||
license='MIT',
|
||||
description = 'Vision Transformer (ViT) - Pytorch',
|
||||
author = 'Phil Wang',
|
||||
|
||||
@@ -148,6 +148,6 @@ class DistillWrapper(nn.Module):
|
||||
|
||||
else:
|
||||
teacher_labels = teacher_logits.argmax(dim = -1)
|
||||
distill_loss = F.cross_entropy(student_logits, teacher_labels)
|
||||
distill_loss = F.cross_entropy(distill_logits, teacher_labels)
|
||||
|
||||
return loss * (1 - alpha) + distill_loss * alpha
|
||||
|
||||
@@ -29,7 +29,7 @@ class FeedForward(nn.Module):
|
||||
super().__init__()
|
||||
self.net = nn.Sequential(
|
||||
nn.Conv2d(dim, dim * mult, 1),
|
||||
nn.GELU(),
|
||||
nn.Hardswish(),
|
||||
nn.Dropout(dropout),
|
||||
nn.Conv2d(dim * mult, dim, 1),
|
||||
nn.Dropout(dropout)
|
||||
|
||||
@@ -35,13 +35,14 @@ class T2TViT(nn.Module):
|
||||
for i, (kernel_size, stride) in enumerate(t2t_layers):
|
||||
layer_dim *= kernel_size ** 2
|
||||
is_first = i == 0
|
||||
is_last = i == (len(t2t_layers) - 1)
|
||||
output_image_size = conv_output_size(output_image_size, kernel_size, stride, stride // 2)
|
||||
|
||||
layers.extend([
|
||||
RearrangeImage() if not is_first else nn.Identity(),
|
||||
nn.Unfold(kernel_size = kernel_size, stride = stride, padding = stride // 2),
|
||||
Rearrange('b c n -> b n c'),
|
||||
Transformer(dim = layer_dim, heads = 1, depth = 1, dim_head = layer_dim, mlp_dim = layer_dim, dropout = dropout),
|
||||
Transformer(dim = layer_dim, heads = 1, depth = 1, dim_head = layer_dim, mlp_dim = layer_dim, dropout = dropout) if not is_last else nn.Identity(),
|
||||
])
|
||||
|
||||
layers.append(nn.Linear(layer_dim, dim))
|
||||
|
||||
Reference in New Issue
Block a user