diff --git a/README.md b/README.md index ab6d8b0..e49e76f 100644 --- a/README.md +++ b/README.md @@ -435,6 +435,33 @@ img = torch.randn(1, 3, 224, 224) pred = model(img) # (1, 1000) ``` +## RegionViT + + + + + +This paper proposes to divide up the feature map into local regions, whereby the local tokens attend to each other. Each local region has its own regional token which then attends to all its local tokens, as well as other regional tokens. + +You can use it as follows + +```python +import torch +from vit_pytorch.regionvit import RegionViT + +model = RegionViT( + dim = (64, 128, 256, 512), # tuple of size 4, indicating dimension at each stage + depth = (2, 2, 8, 2), # depth of the region to local transformer at each stage + window_size = 7, # window size, which should be either 7 or 14 + num_classes = 1000, # number of output lcasses + tokenize_local_3_conv = False, # whether to use a 3 layer convolution to encode the local tokens from the image. the paper uses this for the smaller models, but uses only 1 conv (set to False) for the larger models + use_peg = False, # whether to use positional generating module. they used this for object detection for a boost in performance +) + +x = torch.randn(1, 3, 224, 224) +logits = model(x) # (1, 1000) +``` + ## NesT @@ -892,6 +919,17 @@ Coming from computer vision and new to transformers? Here are some resources tha } ``` +```bibtex +@misc{chen2021regionvit, + title = {RegionViT: Regional-to-Local Attention for Vision Transformers}, + author = {Chun-Fu Chen and Rameswar Panda and Quanfu Fan}, + year = {2021}, + eprint = {2106.02689}, + archivePrefix = {arXiv}, + primaryClass = {cs.CV} +} +``` + ```bibtex @misc{caron2021emerging, title = {Emerging Properties in Self-Supervised Vision Transformers}, diff --git a/images/regionvit.png b/images/regionvit.png new file mode 100644 index 0000000..c843ca3 Binary files /dev/null and b/images/regionvit.png differ diff --git a/images/regionvit2.png b/images/regionvit2.png new file mode 100644 index 0000000..0468bcf Binary files /dev/null and b/images/regionvit2.png differ diff --git a/setup.py b/setup.py index 6edaf35..c979310 100644 --- a/setup.py +++ b/setup.py @@ -3,7 +3,7 @@ from setuptools import setup, find_packages setup( name = 'vit-pytorch', packages = find_packages(exclude=['examples']), - version = '0.20.8', + version = '0.21.0', license='MIT', description = 'Vision Transformer (ViT) - Pytorch', author = 'Phil Wang', diff --git a/vit_pytorch/regionvit.py b/vit_pytorch/regionvit.py new file mode 100644 index 0000000..55557bf --- /dev/null +++ b/vit_pytorch/regionvit.py @@ -0,0 +1,268 @@ +import torch +from torch import nn, einsum +from einops import rearrange +from einops.layers.torch import Rearrange, Reduce +import torch.nn.functional as F + +# helpers + +def exists(val): + return val is not None + +def default(val, d): + return val if exists(val) else d + +def cast_tuple(val, length = 1): + return val if isinstance(val, tuple) else ((val,) * length) + +def divisible_by(val, d): + return (val % d) == 0 + +# helper classes + +class Downsample(nn.Module): + def __init__(self, dim_in, dim_out): + super().__init__() + self.conv = nn.Conv2d(dim_in, dim_out, 3, stride = 2, padding = 1) + + def forward(self, x): + return self.conv(x) + +class PEG(nn.Module): + def __init__(self, dim, kernel_size = 3): + super().__init__() + self.proj = nn.Conv2d(dim, dim, kernel_size = kernel_size, padding = kernel_size // 2, groups = dim, stride = 1) + + def forward(self, x): + return self.proj(x) + x + +# transformer classes + +def FeedForward(dim, mult = 4, dropout = 0.): + return nn.Sequential( + nn.LayerNorm(dim), + nn.Linear(dim, dim * mult, 1), + nn.GELU(), + nn.Dropout(dropout), + nn.Linear(dim * mult, dim, 1) + ) + +class Attention(nn.Module): + def __init__( + self, + dim, + heads = 4, + dim_head = 32, + dropout = 0. + ): + super().__init__() + self.heads = heads + self.scale = dim_head ** -0.5 + inner_dim = dim_head * heads + + self.norm = nn.LayerNorm(dim) + self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False) + self.to_out = nn.Linear(inner_dim, dim) + + def forward(self, x, rel_pos_bias = None): + h = self.heads + + # prenorm + + x = self.norm(x) + + q, k, v = self.to_qkv(x).chunk(3, dim = -1) + + # split heads + + q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v)) + q = q * self.scale + + sim = einsum('b h i d, b h j d -> b h i j', q, k) + + # add relative positional bias for local tokens + + if exists(rel_pos_bias): + sim = sim + rel_pos_bias + + attn = sim.softmax(dim = -1) + + # merge heads + + out = einsum('b h i j, b h j d -> b h i d', attn, v) + out = rearrange(out, 'b h n d -> b n (h d)') + return self.to_out(out) + +class R2LTransformer(nn.Module): + def __init__( + self, + dim, + *, + window_size, + depth = 4, + heads = 4, + dim_head = 32, + attn_dropout = 0., + ff_dropout = 0., + ): + super().__init__() + self.layers = nn.ModuleList([]) + + self.window_size = window_size + rel_positions = 2 * window_size - 1 + self.local_rel_pos_bias = nn.Embedding(rel_positions ** 2, heads) + + for _ in range(depth): + self.layers.append(nn.ModuleList([ + Attention(dim, heads = heads, dim_head = dim_head, dropout = attn_dropout), + FeedForward(dim, dropout = ff_dropout) + ])) + + def forward(self, local_tokens, region_tokens): + device = local_tokens.device + lh, lw = local_tokens.shape[-2:] + rh, rw = region_tokens.shape[-2:] + window_size_h, window_size_w = lh // rh, lw // rw + + local_tokens = rearrange(local_tokens, 'b c h w -> b (h w) c') + region_tokens = rearrange(region_tokens, 'b c h w -> b (h w) c') + + # calculate local relative positional bias + + h_range = torch.arange(window_size_h, device = device) + w_range = torch.arange(window_size_w, device = device) + + grid_x, grid_y = torch.meshgrid(h_range, w_range) + grid = torch.stack((grid_x, grid_y)) + grid = rearrange(grid, 'c h w -> c (h w)') + grid = (grid[:, :, None] - grid[:, None, :]) + (self.window_size - 1) + bias_indices = (grid * torch.tensor([1, self.window_size * 2 - 1], device = device)[:, None, None]).sum(dim = 0) + rel_pos_bias = self.local_rel_pos_bias(bias_indices) + rel_pos_bias = rearrange(rel_pos_bias, 'i j h -> () h i j') + rel_pos_bias = F.pad(rel_pos_bias, (1, 0, 1, 0), value = 0) + + # go through r2l transformer layers + + for attn, ff in self.layers: + region_tokens = attn(region_tokens) + region_tokens + + # concat region tokens to local tokens + + local_tokens = rearrange(local_tokens, 'b (h w) d -> b h w d', h = lh) + local_tokens = rearrange(local_tokens, 'b (h p1) (w p2) d -> (b h w) (p1 p2) d', p1 = window_size_h, p2 = window_size_w) + region_tokens = rearrange(region_tokens, 'b n d -> (b n) () d') + + # do self attention on local tokens, along with its regional token + + region_and_local_tokens = torch.cat((region_tokens, local_tokens), dim = 1) + region_and_local_tokens = attn(region_and_local_tokens, rel_pos_bias = rel_pos_bias) + region_and_local_tokens + + # split back local and regional tokens + + region_tokens, local_tokens = region_and_local_tokens[:, :1], region_and_local_tokens[:, 1:] + local_tokens = rearrange(local_tokens, '(b h w) (p1 p2) d -> b (h p1 w p2) d', h = lh // window_size_h, w = lw // window_size_w, p1 = window_size_h) + region_tokens = rearrange(region_tokens, '(b n) () d -> b n d', n = rh * rw) + + # feedforwards + + local_tokens = ff(local_tokens) + local_tokens + region_tokens = ff(region_tokens) + region_tokens + + local_tokens = rearrange(local_tokens, 'b (h w) c -> b c h w', h = lh, w = lw) + region_tokens = rearrange(region_tokens, 'b (h w) c -> b c h w', h = rh, w = rw) + return local_tokens, region_tokens + +# classes + +class RegionViT(nn.Module): + def __init__( + self, + *, + dim = (64, 128, 256, 512), + depth = (2, 2, 8, 2), + window_size = 7, + num_classes = 1000, + tokenize_local_3_conv = False, + local_patch_size = 4, + use_peg = False, + attn_dropout = 0., + ff_dropout = 0., + channels = 3, + ): + super().__init__() + dim = cast_tuple(dim, 4) + depth = cast_tuple(depth, 4) + assert len(dim) == 4, 'dim needs to be a single value or a tuple of length 4' + assert len(depth) == 4, 'depth needs to be a single value or a tuple of length 4' + + self.local_patch_size = local_patch_size + + region_patch_size = local_patch_size * window_size + self.region_patch_size = local_patch_size * window_size + + init_dim, *_, last_dim = dim + + # local and region encoders + + if tokenize_local_3_conv: + self.local_encoder = nn.Sequential( + nn.Conv2d(3, init_dim, 3, 2, 1), + nn.LayerNorm(init_dim), + nn.GELU(), + nn.Conv2d(init_dim, init_dim, 3, 2, 1), + nn.LayerNorm(init_dim), + nn.GELU(), + nn.Conv2d(init_dim, init_dim, 3, 1, 1) + ) + else: + self.local_encoder = nn.Conv2d(3, init_dim, 8, 4, 3) + + self.region_encoder = nn.Sequential( + Rearrange('b c (h p1) (w p2) -> b (c p1 p2) h w', p1 = region_patch_size, p2 = region_patch_size), + nn.Conv2d((region_patch_size ** 2) * channels, init_dim, 1) + ) + + # layers + + current_dim = init_dim + self.layers = nn.ModuleList([]) + + for ind, dim, num_layers in zip(range(4), dim, depth): + not_first = ind != 0 + need_downsample = not_first + need_peg = not_first and use_peg + + self.layers.append(nn.ModuleList([ + Downsample(current_dim, dim) if need_downsample else nn.Identity(), + PEG(dim) if need_peg else nn.Identity(), + R2LTransformer(dim, depth = num_layers, window_size = window_size, attn_dropout = attn_dropout, ff_dropout = ff_dropout) + ])) + + current_dim = dim + + # final logits + + self.to_logits = nn.Sequential( + Reduce('b c h w -> b c', 'mean'), + nn.LayerNorm(last_dim), + nn.Linear(last_dim, num_classes) + ) + + def forward( + self, + x, + return_local_tokens = False + ): + *_, h, w = x.shape + assert divisible_by(h, self.region_patch_size) and divisible_by(w, self.region_patch_size), 'height and width must be divisible by region patch size' + assert divisible_by(h, self.local_patch_size) and divisible_by(w, self.local_patch_size), 'height and width must be divisible by local patch size' + + local_tokens = self.local_encoder(x) + region_tokens = self.region_encoder(x) + + for down, peg, transformer in self.layers: + local_tokens, region_tokens = down(local_tokens), down(region_tokens) + local_tokens = peg(local_tokens) + local_tokens, region_tokens = transformer(local_tokens, region_tokens) + + return self.to_logits(region_tokens)