complete and release crossformer

This commit is contained in:
Phil Wang
2021-11-22 17:10:53 -08:00
parent 768e47441e
commit 36e32b70fb
3 changed files with 50 additions and 3 deletions

View File

@@ -503,6 +503,23 @@ This <a href="https://arxiv.org/abs/2108.00154">paper</a> beats PVT and Swin usi
They also have cross-scale embedding layer, which they shown to be a generic layer that can improve all vision transformers. Dynamic relative positional bias was also formulated to allow the net to generalize to images of greater resolution.
```python
import torch
from vit_pytorch.crossformer import CrossFormer
model = CrossFormer(
num_classes = 1000, # number of output classes
dim = (64, 128, 256, 512), # dimension at each stage
depth = (2, 2, 8, 2), # depth of transformer at each stage
global_window_size = (8, 4, 2, 1), # global window sizes at each stage
local_window_size = 7, # local window size (can be customized for each stage, but in paper, held constant at 7 for all stages)
)
img = torch.randn(1, 3, 224, 224)
pred = model(img) # (1, 1000)
```
## NesT
<img src="./images/nest.png" width="400px"></img>

View File

@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
setup(
name = 'vit-pytorch',
packages = find_packages(exclude=['examples']),
version = '0.23.2',
version = '0.24.0',
license='MIT',
description = 'Vision Transformer (ViT) - Pytorch',
author = 'Phil Wang',

View File

@@ -44,6 +44,23 @@ class CrossEmbedLayer(nn.Module):
fmaps = tuple(map(lambda conv: conv(x), self.convs))
return torch.cat(fmaps, dim = 1)
# dynamic positional bias
def DynamicPositionBias(dim):
return nn.Sequential(
nn.Linear(2, dim),
nn.LayerNorm(dim),
nn.ReLU(),
nn.Linear(dim, dim),
nn.LayerNorm(dim),
nn.ReLU(),
nn.Linear(dim, dim),
nn.LayerNorm(dim),
nn.ReLU(),
nn.Linear(dim, 1),
Rearrange('... () -> ...')
)
# transformer classes
class LayerNorm(nn.Module):
@@ -86,12 +103,14 @@ class Attention(nn.Module):
self.attn_type = attn_type
self.window_size = window_size
self.dpb = DynamicPositionBias(dim // 4)
self.norm = LayerNorm(dim)
self.to_qkv = nn.Conv2d(dim, inner_dim * 3, 1, bias = False)
self.to_out = nn.Conv2d(inner_dim, dim, 1)
def forward(self, x):
*_, height, width, heads, wsz = *x.shape, self.heads, self.window_size
*_, height, width, heads, wsz, device = *x.shape, self.heads, self.window_size, x.device
# prenorm
@@ -115,6 +134,17 @@ class Attention(nn.Module):
sim = einsum('b h i d, b h j d -> b h i j', q, k)
# add dynamic positional bias
i_pos = torch.arange(wsz, device = device)
j_pos = torch.arange(wsz, device = device)
grid = torch.stack(torch.meshgrid(i_pos, j_pos))
grid = rearrange(grid, 'c i j -> (i j) c')
rel_ij = grid[:, None] - grid[None, :]
rel_pos_bias = self.dpb(rel_ij.float())
sim = sim + rel_pos_bias
# attend
attn = sim.softmax(dim = -1)
@@ -212,7 +242,7 @@ class CrossFormer(nn.Module):
for (dim_in, dim_out), layers, global_wsz, local_wsz, cel_kernel_sizes, cel_stride in zip(dim_in_and_out, depth, global_window_size, local_window_size, cross_embed_kernel_sizes, cross_embed_strides):
self.layers.append(nn.ModuleList([
CrossEmbedLayer(dim_in, dim_out, cel_kernel_sizes, stride = cel_stride),
Transformer(dim_out, local_window_size = local_wsz, global_window_size = global_wsz, depth = layers)
Transformer(dim_out, local_window_size = local_wsz, global_window_size = global_wsz, depth = layers, attn_dropout = attn_dropout, ff_dropout = ff_dropout)
]))
# final logits