diff --git a/README.md b/README.md
index e14aad2..8d6971d 100644
--- a/README.md
+++ b/README.md
@@ -32,6 +32,7 @@
 - [Parallel ViT](#parallel-vit)
 - [Learnable Memory ViT](#learnable-memory-vit)
 - [Dino](#dino)
+- [EsViT](#esvit)
 - [Accessing Attention](#accessing-attention)
 - [Research Ideas](#research-ideas)
   * [Efficient Attention](#efficient-attention)
@@ -1076,6 +1077,80 @@ for _ in range(100):
 torch.save(model.state_dict(), './pretrained-net.pt')
 ```
 
+## EsViT
+
+<img src="./images/esvit.png" width="350px"></img>
+
+`EsViT` is a variant of Dino (from above) re-engineered to support efficient `ViT`s with patch merging / downsampling, by taking into account an extra regional loss between the augmented views. To quote the abstract, it `outperforms its supervised counterpart on 17 out of 18 datasets` at 3 times higher throughput.
+
+Even though it is named as though it were a new `ViT` variant, it is actually just a training strategy for any multistage `ViT` (in the paper, the authors focused on Swin). The example below shows how to use it with `CvT`. You'll need to set `hidden_layer` to the name of the layer within your efficient ViT that outputs the non-average-pooled visual representations, just before the global pooling and projection to logits.
+
+```python
+import torch
+from vit_pytorch.cvt import CvT
+from vit_pytorch.es_vit import EsViTTrainer
+
+cvt = CvT(
+    num_classes = 1000,
+    s1_emb_dim = 64,
+    s1_emb_kernel = 7,
+    s1_emb_stride = 4,
+    s1_proj_kernel = 3,
+    s1_kv_proj_stride = 2,
+    s1_heads = 1,
+    s1_depth = 1,
+    s1_mlp_mult = 4,
+    s2_emb_dim = 192,
+    s2_emb_kernel = 3,
+    s2_emb_stride = 2,
+    s2_proj_kernel = 3,
+    s2_kv_proj_stride = 2,
+    s2_heads = 3,
+    s2_depth = 2,
+    s2_mlp_mult = 4,
+    s3_emb_dim = 384,
+    s3_emb_kernel = 3,
+    s3_emb_stride = 2,
+    s3_proj_kernel = 3,
+    s3_kv_proj_stride = 2,
+    s3_heads = 4,
+    s3_depth = 10,
+    s3_mlp_mult = 4,
+    dropout = 0.
+)
+
+learner = EsViTTrainer(
+    cvt,
+    image_size = 256,
+    hidden_layer = 'layers',             # hidden layer name or index, from which to extract the embedding
+    projection_hidden_size = 256,        # projector network hidden dimension
+    projection_layers = 4,               # number of layers in projection network
+    num_classes_K = 65336,               # output logits dimensions (referenced as K in paper)
+    student_temp = 0.9,                  # student temperature
+    teacher_temp = 0.04,                 # teacher temperature, needs to be annealed from 0.04 to 0.07 over 30 epochs
+    local_upper_crop_scale = 0.4,        # upper bound for local crop - 0.4 was recommended in the paper
+    global_lower_crop_scale = 0.5,       # lower bound for global crop - 0.5 was recommended in the paper
+    moving_average_decay = 0.9,          # moving average of encoder - paper showed anywhere from 0.9 to 0.999 was ok
+    center_moving_average_decay = 0.9,   # moving average of teacher centers - paper showed anywhere from 0.9 to 0.999 was ok
+)
+
+opt = torch.optim.AdamW(learner.parameters(), lr = 3e-4)
+
+def sample_unlabelled_images():
+    return torch.randn(8, 3, 256, 256)
+
+for _ in range(1000):
+    images = sample_unlabelled_images()
+    loss = learner(images)
+    opt.zero_grad()
+    loss.backward()
+    opt.step()
+    learner.update_moving_average() # update moving average of teacher encoder and teacher centers
+
+# save your improved network
+torch.save(cvt.state_dict(), './pretrained-net.pt')
+```
+
 ## Accessing Attention
 
 If you would like to visualize the attention weights (post-softmax) for your research, just follow the procedure below
@@ -1584,6 +1659,16 @@ Coming from computer vision and new to transformers? 
Here are some resources tha } ``` +```bibtex +@article{Li2021EfficientSV, + title = {Efficient Self-supervised Vision Transformers for Representation Learning}, + author = {Chunyuan Li and Jianwei Yang and Pengchuan Zhang and Mei Gao and Bin Xiao and Xiyang Dai and Lu Yuan and Jianfeng Gao}, + journal = {ArXiv}, + year = {2021}, + volume = {abs/2106.09785} +} +``` + ```bibtex @misc{vaswani2017attention, title = {Attention Is All You Need}, diff --git a/images/esvit.png b/images/esvit.png new file mode 100644 index 0000000..5acdaf4 Binary files /dev/null and b/images/esvit.png differ diff --git a/setup.py b/setup.py index e4474b4..6df8283 100644 --- a/setup.py +++ b/setup.py @@ -3,7 +3,7 @@ from setuptools import setup, find_packages setup( name = 'vit-pytorch', packages = find_packages(exclude=['examples']), - version = '0.33.2', + version = '0.34.1', license='MIT', description = 'Vision Transformer (ViT) - Pytorch', author = 'Phil Wang', diff --git a/vit_pytorch/cvt.py b/vit_pytorch/cvt.py index add9ce9..2750284 100644 --- a/vit_pytorch/cvt.py +++ b/vit_pytorch/cvt.py @@ -164,12 +164,14 @@ class CvT(nn.Module): dim = config['emb_dim'] - self.layers = nn.Sequential( - *layers, + self.layers = nn.Sequential(*layers) + + self.to_logits = nn.Sequential( nn.AdaptiveAvgPool2d(1), Rearrange('... () () -> ...'), nn.Linear(dim, num_classes) ) def forward(self, x): - return self.layers(x) + latents = self.layers(x) + return self.to_logits(latents) diff --git a/vit_pytorch/es_vit.py b/vit_pytorch/es_vit.py new file mode 100644 index 0000000..4db6bcb --- /dev/null +++ b/vit_pytorch/es_vit.py @@ -0,0 +1,367 @@ +import copy +import random +from functools import wraps, partial + +import torch +from torch import nn, einsum +import torch.nn.functional as F +from torchvision import transforms as T + +from einops import rearrange, reduce, repeat + +# helper functions + +def exists(val): + return val is not None + +def default(val, default): + return val if exists(val) else default + +def singleton(cache_key): + def inner_fn(fn): + @wraps(fn) + def wrapper(self, *args, **kwargs): + instance = getattr(self, cache_key) + if instance is not None: + return instance + + instance = fn(self, *args, **kwargs) + setattr(self, cache_key, instance) + return instance + return wrapper + return inner_fn + +def get_module_device(module): + return next(module.parameters()).device + +def set_requires_grad(model, val): + for p in model.parameters(): + p.requires_grad = val + +# tensor related helpers + +def log(t, eps = 1e-20): + return torch.log(t + eps) + +# loss function # (algorithm 1 in the paper) + +def view_loss_fn( + teacher_logits, + student_logits, + teacher_temp, + student_temp, + centers, + eps = 1e-20 +): + teacher_logits = teacher_logits.detach() + student_probs = (student_logits / student_temp).softmax(dim = -1) + teacher_probs = ((teacher_logits - centers) / teacher_temp).softmax(dim = -1) + return - (teacher_probs * log(student_probs, eps)).sum(dim = -1).mean() + +def region_loss_fn( + teacher_logits, + student_logits, + teacher_latent, + student_latent, + teacher_temp, + student_temp, + centers, + eps = 1e-20 +): + teacher_logits = teacher_logits.detach() + student_probs = (student_logits / student_temp).softmax(dim = -1) + teacher_probs = ((teacher_logits - centers) / teacher_temp).softmax(dim = -1) + + sim_matrix = einsum('b i d, b j d -> b i j', student_latent, teacher_latent) + sim_indices = sim_matrix.max(dim = -1).indices + sim_indices = repeat(sim_indices, 'b n -> b n k', k = teacher_probs.shape[-1]) 
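+
+    # the argmax above matched every student region token to its most similar teacher region token,
+    # so gather those teacher probabilities: each student region is supervised by its best-matching teacher region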
+ max_sim_teacher_probs = teacher_probs.gather(1, sim_indices) + + return - (max_sim_teacher_probs * log(student_probs, eps)).sum(dim = -1).mean() + +# augmentation utils + +class RandomApply(nn.Module): + def __init__(self, fn, p): + super().__init__() + self.fn = fn + self.p = p + + def forward(self, x): + if random.random() > self.p: + return x + return self.fn(x) + +# exponential moving average + +class EMA(): + def __init__(self, beta): + super().__init__() + self.beta = beta + + def update_average(self, old, new): + if old is None: + return new + return old * self.beta + (1 - self.beta) * new + +def update_moving_average(ema_updater, ma_model, current_model): + for current_params, ma_params in zip(current_model.parameters(), ma_model.parameters()): + old_weight, up_weight = ma_params.data, current_params.data + ma_params.data = ema_updater.update_average(old_weight, up_weight) + +# MLP class for projector and predictor + +class L2Norm(nn.Module): + def forward(self, x, eps = 1e-6): + return F.normalize(x, dim = 1, eps = eps) + +class MLP(nn.Module): + def __init__(self, dim, dim_out, num_layers, hidden_size = 256): + super().__init__() + + layers = [] + dims = (dim, *((hidden_size,) * (num_layers - 1))) + + for ind, (layer_dim_in, layer_dim_out) in enumerate(zip(dims[:-1], dims[1:])): + is_last = ind == (len(dims) - 1) + + layers.extend([ + nn.Linear(layer_dim_in, layer_dim_out), + nn.GELU() if not is_last else nn.Identity() + ]) + + self.net = nn.Sequential( + *layers, + L2Norm(), + nn.Linear(hidden_size, dim_out) + ) + + def forward(self, x): + return self.net(x) + +# a wrapper class for the base neural network +# will manage the interception of the hidden layer output +# and pipe it into the projecter and predictor nets + +class NetWrapper(nn.Module): + def __init__(self, net, output_dim, projection_hidden_size, projection_num_layers, layer = -2): + super().__init__() + self.net = net + self.layer = layer + + self.view_projector = None + self.region_projector = None + self.projection_hidden_size = projection_hidden_size + self.projection_num_layers = projection_num_layers + self.output_dim = output_dim + + self.hidden = {} + self.hook_registered = False + + def _find_layer(self): + if type(self.layer) == str: + modules = dict([*self.net.named_modules()]) + return modules.get(self.layer, None) + elif type(self.layer) == int: + children = [*self.net.children()] + return children[self.layer] + return None + + def _hook(self, _, input, output): + device = input[0].device + self.hidden[device] = output + + def _register_hook(self): + layer = self._find_layer() + assert layer is not None, f'hidden layer ({self.layer}) not found' + handle = layer.register_forward_hook(self._hook) + self.hook_registered = True + + @singleton('view_projector') + def _get_view_projector(self, hidden): + dim = hidden.shape[1] + projector = MLP(dim, self.output_dim, self.projection_num_layers, self.projection_hidden_size) + return projector.to(hidden) + + @singleton('region_projector') + def _get_region_projector(self, hidden): + dim = hidden.shape[1] + projector = MLP(dim, self.output_dim, self.projection_num_layers, self.projection_hidden_size) + return projector.to(hidden) + + def get_embedding(self, x): + if self.layer == -1: + return self.net(x) + + if not self.hook_registered: + self._register_hook() + + self.hidden.clear() + _ = self.net(x) + hidden = self.hidden[x.device] + self.hidden.clear() + + assert hidden is not None, f'hidden layer {self.layer} never emitted an output' + return hidden + + def 
forward(self, x, return_projection = True): + region_latents = self.get_embedding(x) + global_latent = reduce(region_latents, 'b c h w -> b c', 'mean') + + if not return_projection: + return global_latent, region_latents + + view_projector = self._get_view_projector(global_latent) + region_projector = self._get_region_projector(region_latents) + + region_latents = rearrange(region_latents, 'b c h w -> b (h w) c') + + return view_projector(global_latent), region_projector(region_latents), region_latents + +# main class + +class EsViTTrainer(nn.Module): + def __init__( + self, + net, + image_size, + hidden_layer = -2, + projection_hidden_size = 256, + num_classes_K = 65336, + projection_layers = 4, + student_temp = 0.9, + teacher_temp = 0.04, + local_upper_crop_scale = 0.4, + global_lower_crop_scale = 0.5, + moving_average_decay = 0.9, + center_moving_average_decay = 0.9, + augment_fn = None, + augment_fn2 = None + ): + super().__init__() + self.net = net + + # default BYOL augmentation + + DEFAULT_AUG = torch.nn.Sequential( + RandomApply( + T.ColorJitter(0.8, 0.8, 0.8, 0.2), + p = 0.3 + ), + T.RandomGrayscale(p=0.2), + T.RandomHorizontalFlip(), + RandomApply( + T.GaussianBlur((3, 3), (1.0, 2.0)), + p = 0.2 + ), + T.Normalize( + mean=torch.tensor([0.485, 0.456, 0.406]), + std=torch.tensor([0.229, 0.224, 0.225])), + ) + + self.augment1 = default(augment_fn, DEFAULT_AUG) + self.augment2 = default(augment_fn2, DEFAULT_AUG) + + # local and global crops + + self.local_crop = T.RandomResizedCrop((image_size, image_size), scale = (0.05, local_upper_crop_scale)) + self.global_crop = T.RandomResizedCrop((image_size, image_size), scale = (global_lower_crop_scale, 1.)) + + self.student_encoder = NetWrapper(net, num_classes_K, projection_hidden_size, projection_layers, layer = hidden_layer) + + self.teacher_encoder = None + self.teacher_ema_updater = EMA(moving_average_decay) + + self.register_buffer('teacher_view_centers', torch.zeros(1, num_classes_K)) + self.register_buffer('last_teacher_view_centers', torch.zeros(1, num_classes_K)) + + self.register_buffer('teacher_region_centers', torch.zeros(1, num_classes_K)) + self.register_buffer('last_teacher_region_centers', torch.zeros(1, num_classes_K)) + + self.teacher_centering_ema_updater = EMA(center_moving_average_decay) + + self.student_temp = student_temp + self.teacher_temp = teacher_temp + + # get device of network and make wrapper same device + device = get_module_device(net) + self.to(device) + + # send a mock image tensor to instantiate singleton parameters + self.forward(torch.randn(2, 3, image_size, image_size, device=device)) + + @singleton('teacher_encoder') + def _get_teacher_encoder(self): + teacher_encoder = copy.deepcopy(self.student_encoder) + set_requires_grad(teacher_encoder, False) + return teacher_encoder + + def reset_moving_average(self): + del self.teacher_encoder + self.teacher_encoder = None + + def update_moving_average(self): + assert self.teacher_encoder is not None, 'target encoder has not been created yet' + update_moving_average(self.teacher_ema_updater, self.teacher_encoder, self.student_encoder) + + new_teacher_view_centers = self.teacher_centering_ema_updater.update_average(self.teacher_view_centers, self.last_teacher_view_centers) + self.teacher_view_centers.copy_(new_teacher_view_centers) + + new_teacher_region_centers = self.teacher_centering_ema_updater.update_average(self.teacher_region_centers, self.last_teacher_region_centers) + self.teacher_region_centers.copy_(new_teacher_region_centers) + + def forward( + 
self, + x, + return_embedding = False, + return_projection = True, + student_temp = None, + teacher_temp = None + ): + if return_embedding: + return self.student_encoder(x, return_projection = return_projection) + + image_one, image_two = self.augment1(x), self.augment2(x) + + local_image_one, local_image_two = self.local_crop(image_one), self.local_crop(image_two) + global_image_one, global_image_two = self.global_crop(image_one), self.global_crop(image_two) + + student_view_proj_one, student_region_proj_one, student_latent_one = self.student_encoder(local_image_one) + student_view_proj_two, student_region_proj_two, student_latent_two = self.student_encoder(local_image_two) + + with torch.no_grad(): + teacher_encoder = self._get_teacher_encoder() + teacher_view_proj_one, teacher_region_proj_one, teacher_latent_one = teacher_encoder(global_image_one) + teacher_view_proj_two, teacher_region_proj_two, teacher_latent_two = teacher_encoder(global_image_two) + + view_loss_fn_ = partial( + view_loss_fn, + student_temp = default(student_temp, self.student_temp), + teacher_temp = default(teacher_temp, self.teacher_temp), + centers = self.teacher_view_centers + ) + + region_loss_fn_ = partial( + region_loss_fn, + student_temp = default(student_temp, self.student_temp), + teacher_temp = default(teacher_temp, self.teacher_temp), + centers = self.teacher_region_centers + ) + + # calculate view-level loss + + teacher_view_logits_avg = torch.cat((teacher_view_proj_one, teacher_view_proj_two)).mean(dim = 0) + self.last_teacher_view_centers.copy_(teacher_view_logits_avg) + + teacher_region_logits_avg = torch.cat((teacher_region_proj_one, teacher_region_proj_two)).mean(dim = (0, 1)) + self.last_teacher_region_centers.copy_(teacher_region_logits_avg) + + view_loss = (view_loss_fn_(teacher_view_proj_one, student_view_proj_two) \ + + view_loss_fn_(teacher_view_proj_two, student_view_proj_one)) / 2 + + # calculate region-level loss + + region_loss = (region_loss_fn_(teacher_region_proj_one, student_region_proj_two, teacher_latent_one, student_latent_two) \ + + region_loss_fn_(teacher_region_proj_two, student_region_proj_one, teacher_latent_two, student_latent_one)) / 2 + + return (view_loss + region_loss) / 2
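+
+# usage note: call EsViTTrainer.update_moving_average() after every optimizer step,
+# so the EMA teacher encoder and the view / region centers track the student (see the README example)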