Mirror of https://github.com/lucidrains/vit-pytorch.git, synced 2025-12-30 08:02:29 +00:00

Minor changes

README.md (119 lines changed)
@@ -62,70 +62,6 @@ Dropout rate.
 Embedding dropout rate.
 - `pool`: string, either `cls` token pooling or `mean` pooling
 
-## CCT
-
-<img src="https://raw.githubusercontent.com/SHI-Labs/Compact-Transformers/main/images/model_sym.png" width="400px"></img>
-
-<a href="https://arxiv.org/abs/2104.05704">CCT</a> proposes compact transformers
-by using convolutions instead of patching and performing sequence pooling. This
-allows for CCT to have high accuracy and a low number of parameters.
-
-You can use this with two methods
-```python
-import torch
-from vit_pytorch.cct import CCT
-
-model = CCT(
-    img_size=224,
-    embedding_dim=768,
-    n_input_channels=3,
-    n_conv_layers=1,
-    kernel_size=7,
-    stride=2,
-    padding=3,
-    pooling_kernel_size=3,
-    pooling_stride=2,
-    pooling_padding=1,
-    num_layers=12,
-    num_heads=12,
-    mlp_radio=4.,
-    num_classes=1000,
-    dropout_rate=0.1,
-    attention_dropout=0.1,
-    stochastic_depth_rate=0.1,
-    positional_embedding='sine', # ['sine', 'learnable', 'none']
-    sequence_length=None,
-)
-```
-
-Alternatively you can use one of several pre-defined models `[2,4,6,7,8,14,16]`
-which pre-define the number of layers, number of attention heads, the mlp ratio,
-and the embedding dimension.
-
-```python
-import torch
-from vit_pytorch.cct import cct_2
-
-model = cct_2(
-    img_size=224,
-    n_input_channels=3,
-    n_conv_layers=1,
-    kernel_size=7,
-    stride=2,
-    padding=3,
-    pooling_kernel_size=3,
-    pooling_stride=2,
-    pooling_padding=1,
-    num_classes=1000,
-    dropout_rate=0.1,
-    attention_dropout=0.1,
-    stochastic_depth_rate=0.1,
-    positional_embedding='sine', # ['sine', 'learnable', 'none']
-    sequence_length=None,
-)
-```
-<a href="https://github.com/SHI-Labs/Compact-Transformers">Official
-Repository</a>
-
 
 ## Distillation
 
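The README section relocated by this hunk (it is re-added further down in the same diff) names CCT's two ideas: a convolutional tokenizer in place of ViT's patch slicing, and sequence pooling in place of a class token. The sketch below is only a rough illustration of those two ideas using the layer sizes from the example arguments above; it is not the actual `vit_pytorch.cct` implementation.

```python
import torch
import torch.nn as nn

# Rough sketch of the two ideas named in the README text (NOT the library's code):
# 1) a convolution + pooling stage produces the token sequence,
# 2) "sequence pooling" collapses that sequence with learned attention weights.

embedding_dim = 384
tokenizer = nn.Sequential(
    nn.Conv2d(3, embedding_dim, kernel_size=7, stride=2, padding=3, bias=False),
    nn.ReLU(),
    nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
)
seq_pool_attn = nn.Linear(embedding_dim, 1)    # scores each token for pooling

img = torch.randn(1, 3, 224, 224)
feat = tokenizer(img)                          # (1, 384, 56, 56)
tokens = feat.flatten(2).transpose(1, 2)       # (1, 3136, 384) token sequence

# ... transformer encoder layers would run over `tokens` here ...

weights = seq_pool_attn(tokens).softmax(dim=1) # (1, 3136, 1) per-token weights
pooled = (weights * tokens).sum(dim=1)         # (1, 384) single vector for the classifier
```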
@@ -267,6 +203,61 @@ img = torch.randn(1, 3, 224, 224)
 preds = v(img) # (1, 1000)
 ```
 
+## CCT
+
+<img src="https://raw.githubusercontent.com/SHI-Labs/Compact-Transformers/main/images/model_sym.png" width="400px"></img>
+
+<a href="https://arxiv.org/abs/2104.05704">CCT</a> proposes compact transformers
+by using convolutions instead of patching and performing sequence pooling. This
+allows for CCT to have high accuracy and a low number of parameters.
+
+You can use this with two methods
+```python
+import torch
+from vit_pytorch.cct import CCT
+
+model = CCT(
+    img_size=224,
+    embedding_dim=384,
+    n_conv_layers=2,
+    kernel_size=7,
+    stride=2,
+    padding=3,
+    pooling_kernel_size=3,
+    pooling_stride=2,
+    pooling_padding=1,
+    num_layers=14,
+    num_heads=6,
+    mlp_ratio=3.,
+    num_classes=1000,
+    positional_embedding='learnable', # ['sine', 'learnable', 'none']
+)
+```
+
+Alternatively you can use one of several pre-defined models `[2,4,6,7,8,14,16]`
+which pre-define the number of layers, number of attention heads, the mlp ratio,
+and the embedding dimension.
+
+```python
+import torch
+from vit_pytorch.cct import cct_14
+
+model = cct_14(
+    img_size=224,
+    n_conv_layers=1,
+    kernel_size=7,
+    stride=2,
+    padding=3,
+    pooling_kernel_size=3,
+    pooling_stride=2,
+    pooling_padding=1,
+    num_classes=1000,
+    positional_embedding='learnable', # ['sine', 'learnable', 'none']
+)
+```
+<a href="https://github.com/SHI-Labs/Compact-Transformers">Official
+Repository</a> includes links to pretrained model checkpoints.
+
 ## Cross ViT
 
 <img src="./images/cross_vit.png" width="400px"></img>
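As a usage note (not part of the commit), a forward pass with the `cct_14` configuration documented in the added section would look roughly as follows; the arguments are copied from the README block above, and the input and output shapes follow the `preds = v(img) # (1, 1000)` convention used earlier in the README.

```python
import torch
from vit_pytorch.cct import cct_14

# Build the model exactly as in the updated README section above.
model = cct_14(
    img_size=224,
    n_conv_layers=1,
    kernel_size=7,
    stride=2,
    padding=3,
    pooling_kernel_size=3,
    pooling_stride=2,
    pooling_padding=1,
    num_classes=1000,
    positional_embedding='learnable',
)

img = torch.randn(1, 3, 224, 224)  # one RGB image at the configured resolution
preds = model(img)                 # expected shape: (1, 1000) class logits
```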
vit_pytorch/cct.py

@@ -3,7 +3,8 @@ import torch.nn as nn
 import torch.nn.functional as F
 
 # Pre-defined CCT Models
-__all__ = ['cct_2', 'cct_4', 'cct_6', 'cct_7', 'cct_8', 'cct_16']
+__all__ = ['cct_2', 'cct_4', 'cct_6', 'cct_7', 'cct_8', 'cct_14', 'cct_16']
 
+
 def cct_2(*args, **kwargs):
     return _cct(num_layers=2, num_heads=2, mlp_ratio=1, embedding_dim=128,
@@ -39,6 +40,7 @@ def cct_16(*args, **kwargs):
     return _cct(num_layers=16, num_heads=6, mlp_ratio=3, embedding_dim=384,
                 *args, **kwargs)
 
+
 def _cct(num_layers, num_heads, mlp_ratio, embedding_dim,
          kernel_size=3, stride=None, padding=None,
          *args, **kwargs):
@@ -81,6 +83,7 @@ class Attention(nn.Module):
         x = self.proj_drop(x)
         return x
 
+
 class TransformerEncoderLayer(nn.Module):
     """
     Inspired by torch.nn.TransformerEncoderLayer and
@@ -143,6 +146,7 @@ class DropPath(nn.Module):
     def forward(self, x):
         return drop_path(x, self.drop_prob, self.training)
 
+
 class Tokenizer(nn.Module):
     def __init__(self,
                  kernel_size, stride, padding,
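Apart from the added blank lines, the change in this file is the `__all__` update. Its practical effect, sketched below on the assumption that `cct_14` is already defined in the module and that its remaining constructor arguments have defaults, is that the new factory is now reachable through a star import as well as a direct one.

```python
# Assumes a vit_pytorch checkout that includes the commit above,
# where 'cct_14' is listed in vit_pytorch.cct.__all__.
from vit_pytorch.cct import *  # the star import now exposes cct_14

model = cct_14(img_size=224, num_classes=1000)  # the full argument list from the README also works
```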