mirror of https://github.com/lucidrains/vit-pytorch.git
synced 2025-12-30 08:02:29 +00:00
complete and release crossformer
README.md (17 changed lines)
@@ -503,6 +503,23 @@ This <a href="https://arxiv.org/abs/2108.00154">paper</a> beats PVT and Swin usi
 
+They also have a cross-scale embedding layer, which they show to be a generic layer that can improve all vision transformers. Dynamic relative positional bias was also formulated to allow the net to generalize to images of greater resolution.
+
+```python
+import torch
+from vit_pytorch.crossformer import CrossFormer
+
+model = CrossFormer(
+    num_classes = 1000,                # number of output classes
+    dim = (64, 128, 256, 512),         # dimension at each stage
+    depth = (2, 2, 8, 2),              # depth of transformer at each stage
+    global_window_size = (8, 4, 2, 1), # global window sizes at each stage
+    local_window_size = 7,             # local window size (can be customized for each stage, but in paper, held constant at 7 for all stages)
+)
+
+img = torch.randn(1, 3, 224, 224)
+
+pred = model(img) # (1, 1000)
+```
+
 ## NesT
 
 <img src="./images/nest.png" width="400px"></img>
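Because the positional bias introduced by this commit is computed from relative window offsets rather than a fixed-size learned table, the same weights can in principle be applied at a larger test resolution, as the README paragraph above claims. A minimal sketch of that usage, under the assumption that the only constraint is that each stage's feature map stays divisible by its local and global window sizes (448px satisfies this with the default 4x, 2x, 2x, 2x downsampling):

```python
import torch
from vit_pytorch.crossformer import CrossFormer

model = CrossFormer(
    num_classes = 1000,
    dim = (64, 128, 256, 512),
    depth = (2, 2, 8, 2),
    global_window_size = (8, 4, 2, 1),
    local_window_size = 7
)

# train-size forward pass
pred_224 = model(torch.randn(1, 3, 224, 224)) # (1, 1000)

# larger test resolution (assumption: 448 -> 112 -> 56 -> 28 -> 14 feature maps,
# each divisible by both the local (7) and global (8, 4, 2, 1) window sizes)
pred_448 = model(torch.randn(1, 3, 448, 448)) # (1, 1000)
```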
setup.py (2 changed lines)
@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
 setup(
   name = 'vit-pytorch',
   packages = find_packages(exclude=['examples']),
-  version = '0.23.2',
+  version = '0.24.0',
   license='MIT',
   description = 'Vision Transformer (ViT) - Pytorch',
   author = 'Phil Wang',
vit_pytorch/crossformer.py
@@ -44,6 +44,23 @@ class CrossEmbedLayer(nn.Module):
         fmaps = tuple(map(lambda conv: conv(x), self.convs))
         return torch.cat(fmaps, dim = 1)
 
+# dynamic positional bias
+
+def DynamicPositionBias(dim):
+    return nn.Sequential(
+        nn.Linear(2, dim),
+        nn.LayerNorm(dim),
+        nn.ReLU(),
+        nn.Linear(dim, dim),
+        nn.LayerNorm(dim),
+        nn.ReLU(),
+        nn.Linear(dim, dim),
+        nn.LayerNorm(dim),
+        nn.ReLU(),
+        nn.Linear(dim, 1),
+        Rearrange('... () -> ...')
+    )
+
 # transformer classes
 
 class LayerNorm(nn.Module):
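For reference, the bias module added above is just a small MLP that maps a 2-vector relative offset (di, dj) to a single scalar. A standalone shape-check of that behavior (a sketch only; the hidden dimension of 16 and the example offsets are arbitrary choices for illustration):

```python
import torch
from torch import nn
from einops.layers.torch import Rearrange

def DynamicPositionBias(dim):
    # maps a 2d relative offset (di, dj) to one scalar bias value
    return nn.Sequential(
        nn.Linear(2, dim),
        nn.LayerNorm(dim),
        nn.ReLU(),
        nn.Linear(dim, dim),
        nn.LayerNorm(dim),
        nn.ReLU(),
        nn.Linear(dim, dim),
        nn.LayerNorm(dim),
        nn.ReLU(),
        nn.Linear(dim, 1),
        Rearrange('... () -> ...')  # drop the trailing singleton dimension
    )

dpb = DynamicPositionBias(16)  # the Attention module passes dim // 4; 16 is an arbitrary stand-in
offsets = torch.tensor([[0., 0.], [1., -3.]])  # two example (di, dj) offsets
print(dpb(offsets).shape)  # torch.Size([2]) - one scalar bias per offset
```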
@@ -86,12 +103,14 @@ class Attention(nn.Module):
|
||||
self.attn_type = attn_type
|
||||
self.window_size = window_size
|
||||
|
||||
self.dpb = DynamicPositionBias(dim // 4)
|
||||
|
||||
self.norm = LayerNorm(dim)
|
||||
self.to_qkv = nn.Conv2d(dim, inner_dim * 3, 1, bias = False)
|
||||
self.to_out = nn.Conv2d(inner_dim, dim, 1)
|
||||
|
||||
def forward(self, x):
|
||||
*_, height, width, heads, wsz = *x.shape, self.heads, self.window_size
|
||||
*_, height, width, heads, wsz, device = *x.shape, self.heads, self.window_size, x.device
|
||||
|
||||
# prenorm
|
||||
|
||||
@@ -115,6 +134,17 @@ class Attention(nn.Module):
 
         sim = einsum('b h i d, b h j d -> b h i j', q, k)
 
+        # add dynamic positional bias
+
+        i_pos = torch.arange(wsz, device = device)
+        j_pos = torch.arange(wsz, device = device)
+        grid = torch.stack(torch.meshgrid(i_pos, j_pos))
+        grid = rearrange(grid, 'c i j -> (i j) c')
+        rel_ij = grid[:, None] - grid[None, :]
+        rel_pos_bias = self.dpb(rel_ij.float())
+
+        sim = sim + rel_pos_bias
+
         # attend
 
         attn = sim.softmax(dim = -1)
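A rough standalone sketch of the relative-offset grid built in the hunk above, and of how the resulting bias broadcasts into the per-window attention logits. The window size, head count, batch size, and the stand-in linear layer are all illustrative and not part of the actual module:

```python
import torch
from torch import nn
from einops import rearrange

wsz, heads, batch = 7, 4, 8                 # illustrative window size, head count, batch of windows
bias_mlp = nn.Linear(2, 1)                  # stand-in for the DynamicPositionBias MLP, shapes only
dpb = lambda rel: bias_mlp(rel).squeeze(-1)

i_pos = torch.arange(wsz)
j_pos = torch.arange(wsz)
grid = torch.stack(torch.meshgrid(i_pos, j_pos, indexing = 'ij'))  # (2, wsz, wsz)
grid = rearrange(grid, 'c i j -> (i j) c')                         # (wsz*wsz, 2) positions in the window
rel_ij = grid[:, None] - grid[None, :]                             # (wsz*wsz, wsz*wsz, 2) pairwise offsets
rel_pos_bias = dpb(rel_ij.float())                                 # (wsz*wsz, wsz*wsz)

sim = torch.randn(batch, heads, wsz * wsz, wsz * wsz)              # dummy window attention logits
sim = sim + rel_pos_bias                                           # bias broadcasts over batch and heads
print(rel_pos_bias.shape, sim.shape)                               # torch.Size([49, 49]) torch.Size([8, 4, 49, 49])
```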
@@ -212,7 +242,7 @@ class CrossFormer(nn.Module):
         for (dim_in, dim_out), layers, global_wsz, local_wsz, cel_kernel_sizes, cel_stride in zip(dim_in_and_out, depth, global_window_size, local_window_size, cross_embed_kernel_sizes, cross_embed_strides):
             self.layers.append(nn.ModuleList([
                 CrossEmbedLayer(dim_in, dim_out, cel_kernel_sizes, stride = cel_stride),
-                Transformer(dim_out, local_window_size = local_wsz, global_window_size = global_wsz, depth = layers)
+                Transformer(dim_out, local_window_size = local_wsz, global_window_size = global_wsz, depth = layers, attn_dropout = attn_dropout, ff_dropout = ff_dropout)
             ]))
 
         # final logits
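The dropout values threaded into Transformer here appear to come from the enclosing CrossFormer constructor; assuming it exposes attn_dropout and ff_dropout keyword arguments (they are used but not defined in this hunk), enabling dropout would look roughly like:

```python
import torch
from vit_pytorch.crossformer import CrossFormer

# attn_dropout / ff_dropout as constructor kwargs is inferred from the hunk above
model = CrossFormer(
    num_classes = 1000,
    dim = (64, 128, 256, 512),
    depth = (2, 2, 8, 2),
    global_window_size = (8, 4, 2, 1),
    local_window_size = 7,
    attn_dropout = 0.1,  # dropout on the attention matrix
    ff_dropout = 0.1     # dropout inside the feedforward blocks
)

pred = model(torch.randn(1, 3, 224, 224)) # (1, 1000)
```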