diff --git a/README.md b/README.md
index e14aad2..8d6971d 100644
--- a/README.md
+++ b/README.md
@@ -32,6 +32,7 @@
- [Parallel ViT](#parallel-vit)
- [Learnable Memory ViT](#learnable-memory-vit)
- [Dino](#dino)
+- [EsViT](#esvit)
- [Accessing Attention](#accessing-attention)
- [Research Ideas](#research-ideas)
* [Efficient Attention](#efficient-attention)
@@ -1076,6 +1077,80 @@ for _ in range(100):
torch.save(model.state_dict(), './pretrained-net.pt')
```
+## EsViT
+
+
+
+`EsViT` is a variant of Dino (from above) re-engineered to support efficient `ViT`s with patch merging / downsampling by taking into an account an extra regional loss between the augmented views. To quote the abstract, it `outperforms its supervised counterpart on 17 out of 18 datasets` at 3 times higher throughput.
+
+Even though it is named as though it were a new `ViT` variant, it actually is just a strategy for training any multistage `ViT` (in the paper, they focused on Swin). The example below will show how to use it with `CvT`. You'll need to set the `hidden_layer` to the name of the layer within your efficient ViT that outputs the non-average pooled visual representations, just before the global pooling and projection to logits.
+
+```python
+import torch
+from vit_pytorch.cvt import CvT
+from vit_pytorch.es_vit import EsViTTrainer
+
+cvt = CvT(
+ num_classes = 1000,
+ s1_emb_dim = 64,
+ s1_emb_kernel = 7,
+ s1_emb_stride = 4,
+ s1_proj_kernel = 3,
+ s1_kv_proj_stride = 2,
+ s1_heads = 1,
+ s1_depth = 1,
+ s1_mlp_mult = 4,
+ s2_emb_dim = 192,
+ s2_emb_kernel = 3,
+ s2_emb_stride = 2,
+ s2_proj_kernel = 3,
+ s2_kv_proj_stride = 2,
+ s2_heads = 3,
+ s2_depth = 2,
+ s2_mlp_mult = 4,
+ s3_emb_dim = 384,
+ s3_emb_kernel = 3,
+ s3_emb_stride = 2,
+ s3_proj_kernel = 3,
+ s3_kv_proj_stride = 2,
+ s3_heads = 4,
+ s3_depth = 10,
+ s3_mlp_mult = 4,
+ dropout = 0.
+)
+
+learner = EsViTTrainer(
+ cvt,
+ image_size = 256,
+ hidden_layer = 'layers', # hidden layer name or index, from which to extract the embedding
+ projection_hidden_size = 256, # projector network hidden dimension
+ projection_layers = 4, # number of layers in projection network
+ num_classes_K = 65336, # output logits dimensions (referenced as K in paper)
+ student_temp = 0.9, # student temperature
+ teacher_temp = 0.04, # teacher temperature, needs to be annealed from 0.04 to 0.07 over 30 epochs
+ local_upper_crop_scale = 0.4, # upper bound for local crop - 0.4 was recommended in the paper
+ global_lower_crop_scale = 0.5, # lower bound for global crop - 0.5 was recommended in the paper
+ moving_average_decay = 0.9, # moving average of encoder - paper showed anywhere from 0.9 to 0.999 was ok
+ center_moving_average_decay = 0.9, # moving average of teacher centers - paper showed anywhere from 0.9 to 0.999 was ok
+)
+
+opt = torch.optim.AdamW(learner.parameters(), lr = 3e-4)
+
+def sample_unlabelled_images():
+ return torch.randn(8, 3, 256, 256)
+
+for _ in range(1000):
+ images = sample_unlabelled_images()
+ loss = learner(images)
+ opt.zero_grad()
+ loss.backward()
+ opt.step()
+ learner.update_moving_average() # update moving average of teacher encoder and teacher centers
+
+# save your improved network
+torch.save(cvt.state_dict(), './pretrained-net.pt')
+```
+
## Accessing Attention
If you would like to visualize the attention weights (post-softmax) for your research, just follow the procedure below
@@ -1584,6 +1659,16 @@ Coming from computer vision and new to transformers? Here are some resources tha
}
```
+```bibtex
+@article{Li2021EfficientSV,
+ title = {Efficient Self-supervised Vision Transformers for Representation Learning},
+ author = {Chunyuan Li and Jianwei Yang and Pengchuan Zhang and Mei Gao and Bin Xiao and Xiyang Dai and Lu Yuan and Jianfeng Gao},
+ journal = {ArXiv},
+ year = {2021},
+ volume = {abs/2106.09785}
+}
+```
+
```bibtex
@misc{vaswani2017attention,
title = {Attention Is All You Need},
diff --git a/images/esvit.png b/images/esvit.png
new file mode 100644
index 0000000..5acdaf4
Binary files /dev/null and b/images/esvit.png differ
diff --git a/setup.py b/setup.py
index e4474b4..6df8283 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
setup(
name = 'vit-pytorch',
packages = find_packages(exclude=['examples']),
- version = '0.33.2',
+ version = '0.34.1',
license='MIT',
description = 'Vision Transformer (ViT) - Pytorch',
author = 'Phil Wang',
diff --git a/vit_pytorch/cvt.py b/vit_pytorch/cvt.py
index add9ce9..2750284 100644
--- a/vit_pytorch/cvt.py
+++ b/vit_pytorch/cvt.py
@@ -164,12 +164,14 @@ class CvT(nn.Module):
dim = config['emb_dim']
- self.layers = nn.Sequential(
- *layers,
+ self.layers = nn.Sequential(*layers)
+
+ self.to_logits = nn.Sequential(
nn.AdaptiveAvgPool2d(1),
Rearrange('... () () -> ...'),
nn.Linear(dim, num_classes)
)
def forward(self, x):
- return self.layers(x)
+ latents = self.layers(x)
+ return self.to_logits(latents)
diff --git a/vit_pytorch/es_vit.py b/vit_pytorch/es_vit.py
new file mode 100644
index 0000000..4db6bcb
--- /dev/null
+++ b/vit_pytorch/es_vit.py
@@ -0,0 +1,367 @@
+import copy
+import random
+from functools import wraps, partial
+
+import torch
+from torch import nn, einsum
+import torch.nn.functional as F
+from torchvision import transforms as T
+
+from einops import rearrange, reduce, repeat
+
+# helper functions
+
+def exists(val):
+ return val is not None
+
+def default(val, default):
+ return val if exists(val) else default
+
+def singleton(cache_key):
+ def inner_fn(fn):
+ @wraps(fn)
+ def wrapper(self, *args, **kwargs):
+ instance = getattr(self, cache_key)
+ if instance is not None:
+ return instance
+
+ instance = fn(self, *args, **kwargs)
+ setattr(self, cache_key, instance)
+ return instance
+ return wrapper
+ return inner_fn
+
+def get_module_device(module):
+ return next(module.parameters()).device
+
+def set_requires_grad(model, val):
+ for p in model.parameters():
+ p.requires_grad = val
+
+# tensor related helpers
+
+def log(t, eps = 1e-20):
+ return torch.log(t + eps)
+
+# loss function # (algorithm 1 in the paper)
+
+def view_loss_fn(
+ teacher_logits,
+ student_logits,
+ teacher_temp,
+ student_temp,
+ centers,
+ eps = 1e-20
+):
+ teacher_logits = teacher_logits.detach()
+ student_probs = (student_logits / student_temp).softmax(dim = -1)
+ teacher_probs = ((teacher_logits - centers) / teacher_temp).softmax(dim = -1)
+ return - (teacher_probs * log(student_probs, eps)).sum(dim = -1).mean()
+
+def region_loss_fn(
+ teacher_logits,
+ student_logits,
+ teacher_latent,
+ student_latent,
+ teacher_temp,
+ student_temp,
+ centers,
+ eps = 1e-20
+):
+ teacher_logits = teacher_logits.detach()
+ student_probs = (student_logits / student_temp).softmax(dim = -1)
+ teacher_probs = ((teacher_logits - centers) / teacher_temp).softmax(dim = -1)
+
+ sim_matrix = einsum('b i d, b j d -> b i j', student_latent, teacher_latent)
+ sim_indices = sim_matrix.max(dim = -1).indices
+ sim_indices = repeat(sim_indices, 'b n -> b n k', k = teacher_probs.shape[-1])
+ max_sim_teacher_probs = teacher_probs.gather(1, sim_indices)
+
+ return - (max_sim_teacher_probs * log(student_probs, eps)).sum(dim = -1).mean()
+
+# augmentation utils
+
+class RandomApply(nn.Module):
+ def __init__(self, fn, p):
+ super().__init__()
+ self.fn = fn
+ self.p = p
+
+ def forward(self, x):
+ if random.random() > self.p:
+ return x
+ return self.fn(x)
+
+# exponential moving average
+
+class EMA():
+ def __init__(self, beta):
+ super().__init__()
+ self.beta = beta
+
+ def update_average(self, old, new):
+ if old is None:
+ return new
+ return old * self.beta + (1 - self.beta) * new
+
+def update_moving_average(ema_updater, ma_model, current_model):
+ for current_params, ma_params in zip(current_model.parameters(), ma_model.parameters()):
+ old_weight, up_weight = ma_params.data, current_params.data
+ ma_params.data = ema_updater.update_average(old_weight, up_weight)
+
+# MLP class for projector and predictor
+
+class L2Norm(nn.Module):
+ def forward(self, x, eps = 1e-6):
+ return F.normalize(x, dim = 1, eps = eps)
+
+class MLP(nn.Module):
+ def __init__(self, dim, dim_out, num_layers, hidden_size = 256):
+ super().__init__()
+
+ layers = []
+ dims = (dim, *((hidden_size,) * (num_layers - 1)))
+
+ for ind, (layer_dim_in, layer_dim_out) in enumerate(zip(dims[:-1], dims[1:])):
+ is_last = ind == (len(dims) - 1)
+
+ layers.extend([
+ nn.Linear(layer_dim_in, layer_dim_out),
+ nn.GELU() if not is_last else nn.Identity()
+ ])
+
+ self.net = nn.Sequential(
+ *layers,
+ L2Norm(),
+ nn.Linear(hidden_size, dim_out)
+ )
+
+ def forward(self, x):
+ return self.net(x)
+
+# a wrapper class for the base neural network
+# will manage the interception of the hidden layer output
+# and pipe it into the projecter and predictor nets
+
+class NetWrapper(nn.Module):
+ def __init__(self, net, output_dim, projection_hidden_size, projection_num_layers, layer = -2):
+ super().__init__()
+ self.net = net
+ self.layer = layer
+
+ self.view_projector = None
+ self.region_projector = None
+ self.projection_hidden_size = projection_hidden_size
+ self.projection_num_layers = projection_num_layers
+ self.output_dim = output_dim
+
+ self.hidden = {}
+ self.hook_registered = False
+
+ def _find_layer(self):
+ if type(self.layer) == str:
+ modules = dict([*self.net.named_modules()])
+ return modules.get(self.layer, None)
+ elif type(self.layer) == int:
+ children = [*self.net.children()]
+ return children[self.layer]
+ return None
+
+ def _hook(self, _, input, output):
+ device = input[0].device
+ self.hidden[device] = output
+
+ def _register_hook(self):
+ layer = self._find_layer()
+ assert layer is not None, f'hidden layer ({self.layer}) not found'
+ handle = layer.register_forward_hook(self._hook)
+ self.hook_registered = True
+
+ @singleton('view_projector')
+ def _get_view_projector(self, hidden):
+ dim = hidden.shape[1]
+ projector = MLP(dim, self.output_dim, self.projection_num_layers, self.projection_hidden_size)
+ return projector.to(hidden)
+
+ @singleton('region_projector')
+ def _get_region_projector(self, hidden):
+ dim = hidden.shape[1]
+ projector = MLP(dim, self.output_dim, self.projection_num_layers, self.projection_hidden_size)
+ return projector.to(hidden)
+
+ def get_embedding(self, x):
+ if self.layer == -1:
+ return self.net(x)
+
+ if not self.hook_registered:
+ self._register_hook()
+
+ self.hidden.clear()
+ _ = self.net(x)
+ hidden = self.hidden[x.device]
+ self.hidden.clear()
+
+ assert hidden is not None, f'hidden layer {self.layer} never emitted an output'
+ return hidden
+
+ def forward(self, x, return_projection = True):
+ region_latents = self.get_embedding(x)
+ global_latent = reduce(region_latents, 'b c h w -> b c', 'mean')
+
+ if not return_projection:
+ return global_latent, region_latents
+
+ view_projector = self._get_view_projector(global_latent)
+ region_projector = self._get_region_projector(region_latents)
+
+ region_latents = rearrange(region_latents, 'b c h w -> b (h w) c')
+
+ return view_projector(global_latent), region_projector(region_latents), region_latents
+
+# main class
+
+class EsViTTrainer(nn.Module):
+ def __init__(
+ self,
+ net,
+ image_size,
+ hidden_layer = -2,
+ projection_hidden_size = 256,
+ num_classes_K = 65336,
+ projection_layers = 4,
+ student_temp = 0.9,
+ teacher_temp = 0.04,
+ local_upper_crop_scale = 0.4,
+ global_lower_crop_scale = 0.5,
+ moving_average_decay = 0.9,
+ center_moving_average_decay = 0.9,
+ augment_fn = None,
+ augment_fn2 = None
+ ):
+ super().__init__()
+ self.net = net
+
+ # default BYOL augmentation
+
+ DEFAULT_AUG = torch.nn.Sequential(
+ RandomApply(
+ T.ColorJitter(0.8, 0.8, 0.8, 0.2),
+ p = 0.3
+ ),
+ T.RandomGrayscale(p=0.2),
+ T.RandomHorizontalFlip(),
+ RandomApply(
+ T.GaussianBlur((3, 3), (1.0, 2.0)),
+ p = 0.2
+ ),
+ T.Normalize(
+ mean=torch.tensor([0.485, 0.456, 0.406]),
+ std=torch.tensor([0.229, 0.224, 0.225])),
+ )
+
+ self.augment1 = default(augment_fn, DEFAULT_AUG)
+ self.augment2 = default(augment_fn2, DEFAULT_AUG)
+
+ # local and global crops
+
+ self.local_crop = T.RandomResizedCrop((image_size, image_size), scale = (0.05, local_upper_crop_scale))
+ self.global_crop = T.RandomResizedCrop((image_size, image_size), scale = (global_lower_crop_scale, 1.))
+
+ self.student_encoder = NetWrapper(net, num_classes_K, projection_hidden_size, projection_layers, layer = hidden_layer)
+
+ self.teacher_encoder = None
+ self.teacher_ema_updater = EMA(moving_average_decay)
+
+ self.register_buffer('teacher_view_centers', torch.zeros(1, num_classes_K))
+ self.register_buffer('last_teacher_view_centers', torch.zeros(1, num_classes_K))
+
+ self.register_buffer('teacher_region_centers', torch.zeros(1, num_classes_K))
+ self.register_buffer('last_teacher_region_centers', torch.zeros(1, num_classes_K))
+
+ self.teacher_centering_ema_updater = EMA(center_moving_average_decay)
+
+ self.student_temp = student_temp
+ self.teacher_temp = teacher_temp
+
+ # get device of network and make wrapper same device
+ device = get_module_device(net)
+ self.to(device)
+
+ # send a mock image tensor to instantiate singleton parameters
+ self.forward(torch.randn(2, 3, image_size, image_size, device=device))
+
+ @singleton('teacher_encoder')
+ def _get_teacher_encoder(self):
+ teacher_encoder = copy.deepcopy(self.student_encoder)
+ set_requires_grad(teacher_encoder, False)
+ return teacher_encoder
+
+ def reset_moving_average(self):
+ del self.teacher_encoder
+ self.teacher_encoder = None
+
+ def update_moving_average(self):
+ assert self.teacher_encoder is not None, 'target encoder has not been created yet'
+ update_moving_average(self.teacher_ema_updater, self.teacher_encoder, self.student_encoder)
+
+ new_teacher_view_centers = self.teacher_centering_ema_updater.update_average(self.teacher_view_centers, self.last_teacher_view_centers)
+ self.teacher_view_centers.copy_(new_teacher_view_centers)
+
+ new_teacher_region_centers = self.teacher_centering_ema_updater.update_average(self.teacher_region_centers, self.last_teacher_region_centers)
+ self.teacher_region_centers.copy_(new_teacher_region_centers)
+
+ def forward(
+ self,
+ x,
+ return_embedding = False,
+ return_projection = True,
+ student_temp = None,
+ teacher_temp = None
+ ):
+ if return_embedding:
+ return self.student_encoder(x, return_projection = return_projection)
+
+ image_one, image_two = self.augment1(x), self.augment2(x)
+
+ local_image_one, local_image_two = self.local_crop(image_one), self.local_crop(image_two)
+ global_image_one, global_image_two = self.global_crop(image_one), self.global_crop(image_two)
+
+ student_view_proj_one, student_region_proj_one, student_latent_one = self.student_encoder(local_image_one)
+ student_view_proj_two, student_region_proj_two, student_latent_two = self.student_encoder(local_image_two)
+
+ with torch.no_grad():
+ teacher_encoder = self._get_teacher_encoder()
+ teacher_view_proj_one, teacher_region_proj_one, teacher_latent_one = teacher_encoder(global_image_one)
+ teacher_view_proj_two, teacher_region_proj_two, teacher_latent_two = teacher_encoder(global_image_two)
+
+ view_loss_fn_ = partial(
+ view_loss_fn,
+ student_temp = default(student_temp, self.student_temp),
+ teacher_temp = default(teacher_temp, self.teacher_temp),
+ centers = self.teacher_view_centers
+ )
+
+ region_loss_fn_ = partial(
+ region_loss_fn,
+ student_temp = default(student_temp, self.student_temp),
+ teacher_temp = default(teacher_temp, self.teacher_temp),
+ centers = self.teacher_region_centers
+ )
+
+ # calculate view-level loss
+
+ teacher_view_logits_avg = torch.cat((teacher_view_proj_one, teacher_view_proj_two)).mean(dim = 0)
+ self.last_teacher_view_centers.copy_(teacher_view_logits_avg)
+
+ teacher_region_logits_avg = torch.cat((teacher_region_proj_one, teacher_region_proj_two)).mean(dim = (0, 1))
+ self.last_teacher_region_centers.copy_(teacher_region_logits_avg)
+
+ view_loss = (view_loss_fn_(teacher_view_proj_one, student_view_proj_two) \
+ + view_loss_fn_(teacher_view_proj_two, student_view_proj_one)) / 2
+
+ # calculate region-level loss
+
+ region_loss = (region_loss_fn_(teacher_region_proj_one, student_region_proj_two, teacher_latent_one, student_latent_two) \
+ + region_loss_fn_(teacher_region_proj_two, student_region_proj_one, teacher_latent_two, student_latent_one)) / 2
+
+ return (view_loss + region_loss) / 2