diff --git a/README.md b/README.md
index 232de0b..673a8c0 100644
--- a/README.md
+++ b/README.md
@@ -422,6 +422,56 @@ for _ in range(100):
 torch.save(model.state_dict(), './pretrained-net.pt')
 ```
 
+## Dino
+
+You can train `ViT` with DINO, a recent state-of-the-art self-supervised learning technique, using the following code.
+
+```python
+import torch
+from vit_pytorch import ViT, Dino
+
+model = ViT(
+    image_size = 256,
+    patch_size = 32,
+    num_classes = 1000,
+    dim = 1024,
+    depth = 6,
+    heads = 8,
+    mlp_dim = 2048
+)
+
+learner = Dino(
+    model,
+    image_size = 256,
+    hidden_layer = 'to_latent',        # hidden layer name or index from which to extract the embedding
+    projection_hidden_size = 256,      # projector network hidden dimension
+    projection_layers = 4,             # number of layers in projection network
+    num_classes_K = 65536,             # output logit dimensions (referenced as K in the paper)
+    student_temp = 0.9,                # student temperature
+    teacher_temp = 0.04,               # teacher temperature, warmed up from 0.04 to 0.07 over the first 30 epochs in the paper
+    local_upper_crop_scale = 0.4,      # upper bound for local crop - 0.4 was recommended in the paper
+    global_lower_crop_scale = 0.5,     # lower bound for global crop - 0.5 was recommended in the paper
+    moving_average_decay = 0.9,        # moving average of encoder - paper showed anywhere from 0.9 to 0.999 was ok
+    center_moving_average_decay = 0.9, # moving average of teacher centers - paper showed anywhere from 0.9 to 0.999 was ok
+)
+
+opt = torch.optim.Adam(learner.parameters(), lr = 3e-4)
+
+def sample_unlabelled_images():
+    return torch.randn(20, 3, 256, 256)
+
+for _ in range(100):
+    images = sample_unlabelled_images()
+    loss = learner(images)
+    opt.zero_grad()
+    loss.backward()
+    opt.step()
+    learner.update_moving_average() # update moving average of teacher encoder and teacher centers
+
+# save your improved network
+torch.save(model.state_dict(), './pretrained-net.pt')
+```
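+
+The `Dino` wrapper can also hand back representations directly. A minimal sketch, reusing the `learner` and `images` from above (`return_embedding` and `return_projection` are the flags exposed by `Dino.forward`):
+
+```python
+# projection (logits over the K dimensions) and raw embedding from the student encoder
+proj, embed = learner(images, return_embedding = True)
+
+# raw embedding only, skipping the projection head
+embed = learner(images, return_embedding = True, return_projection = False)
+```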
+
 ## Accessing Attention
 
 If you would like to visualize the attention weights (post-softmax) for your research, just follow the procedure below
@@ -465,56 +515,6 @@ v = v.eject() # wrapper is discarded and original ViT instance is returned
 
 ## Research Ideas
 
-### Self Supervised Training
-
-You can train this with a near SOTA self-supervised learning technique, BYOL, with the following code.
-
-(1)
-```bash
-$ pip install byol-pytorch
-```
-
-(2)
-```python
-import torch
-from vit_pytorch import ViT
-from byol_pytorch import BYOL
-
-model = ViT(
-    image_size = 256,
-    patch_size = 32,
-    num_classes = 1000,
-    dim = 1024,
-    depth = 6,
-    heads = 8,
-    mlp_dim = 2048
-)
-
-learner = BYOL(
-    model,
-    image_size = 256,
-    hidden_layer = 'to_latent'
-)
-
-opt = torch.optim.Adam(learner.parameters(), lr=3e-4)
-
-def sample_unlabelled_images():
-    return torch.randn(20, 3, 256, 256)
-
-for _ in range(100):
-    images = sample_unlabelled_images()
-    loss = learner(images)
-    opt.zero_grad()
-    loss.backward()
-    opt.step()
-    learner.update_moving_average() # update moving average of target encoder
-
-# save your improved network
-torch.save(model.state_dict(), './pretrained-net.pt')
-```
-
-A pytorch-lightning script is ready for you to use at the repository link above.
-
 ### Efficient Attention
 
 There may be some coming from computer vision who think attention still suffers from quadratic costs. Fortunately, we have a lot of new techniques that may help. This repository offers a way for you to plugin your own sparse attention transformer.
@@ -781,6 +781,17 @@ Coming from computer vision and new to transformers? Here are some resources tha
 }
 ```
 
+```bibtex
+@misc{caron2021emerging,
+    title   = {Emerging Properties in Self-Supervised Vision Transformers},
+    author  = {Mathilde Caron and Hugo Touvron and Ishan Misra and Hervé Jégou and Julien Mairal and Piotr Bojanowski and Armand Joulin},
+    year    = {2021},
+    eprint  = {2104.14294},
+    archivePrefix = {arXiv},
+    primaryClass = {cs.CV}
+}
+```
+
 ```bibtex
 @misc{vaswani2017attention,
     title   = {Attention Is All You Need},
diff --git a/setup.py b/setup.py
index 9922c4b..c60f3cf 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
 setup(
   name = 'vit-pytorch',
   packages = find_packages(exclude=['examples']),
-  version = '0.17.3',
+  version = '0.18.1',
   license='MIT',
   description = 'Vision Transformer (ViT) - Pytorch',
   author = 'Phil Wang',
@@ -15,8 +15,9 @@ setup(
     'image recognition'
   ],
   install_requires=[
+    'einops>=0.3',
     'torch>=1.6',
-    'einops>=0.3'
+    'torchvision'
   ],
   classifiers=[
     'Development Status :: 4 - Beta',
diff --git a/vit_pytorch/__init__.py b/vit_pytorch/__init__.py
index 1edbfeb..7928faf 100644
--- a/vit_pytorch/__init__.py
+++ b/vit_pytorch/__init__.py
@@ -1 +1,2 @@
 from vit_pytorch.vit import ViT
+from vit_pytorch.dino import Dino
diff --git a/vit_pytorch/dino.py b/vit_pytorch/dino.py
new file mode 100644
index 0000000..42fdf7e
--- /dev/null
+++ b/vit_pytorch/dino.py
@@ -0,0 +1,303 @@
+import copy
+import random
+from functools import wraps, partial
+
+import torch
+from torch import nn
+import torch.nn.functional as F
+
+from torchvision import transforms as T
+
+# helper functions
+
+def exists(val):
+    return val is not None
+
+def default(val, default):
+    return val if exists(val) else default
+
+def singleton(cache_key):
+    def inner_fn(fn):
+        @wraps(fn)
+        def wrapper(self, *args, **kwargs):
+            instance = getattr(self, cache_key)
+            if instance is not None:
+                return instance
+
+            instance = fn(self, *args, **kwargs)
+            setattr(self, cache_key, instance)
+            return instance
+        return wrapper
+    return inner_fn
+
+def get_module_device(module):
+    return next(module.parameters()).device
+
+def set_requires_grad(model, val):
+    for p in model.parameters():
+        p.requires_grad = val
+
+# loss function (algorithm 1 in the paper)
+
+def loss_fn(
+    teacher_logits,
+    student_logits,
+    teacher_temp,
+    student_temp,
+    centers,
+    eps = 1e-20
+):
+    teacher_logits = teacher_logits.detach()
+    student_probs = (student_logits / student_temp).softmax(dim = -1)
+    teacher_probs = ((teacher_logits - centers) / teacher_temp).softmax(dim = -1)
+    return - (teacher_probs * torch.log(student_probs + eps)).sum(dim = -1).mean()
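+
+# shape sketch for loss_fn, with batch size b:
+#   teacher_logits, student_logits : (b, num_classes_K)
+#   centers                        : (1, num_classes_K)
+# the teacher distribution is centered, then sharpened by the low teacher
+# temperature; the loss is the cross entropy between the teacher and student
+# distributions, averaged over the batch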
+
+# augmentation utils
+
+class RandomApply(nn.Module):
+    def __init__(self, fn, p):
+        super().__init__()
+        self.fn = fn
+        self.p = p
+
+    def forward(self, x):
+        if random.random() > self.p:
+            return x
+        return self.fn(x)
+
+# exponential moving average
+
+class EMA():
+    def __init__(self, beta):
+        super().__init__()
+        self.beta = beta
+
+    def update_average(self, old, new):
+        if old is None:
+            return new
+        return old * self.beta + (1 - self.beta) * new
+
+def update_moving_average(ema_updater, ma_model, current_model):
+    for current_params, ma_params in zip(current_model.parameters(), ma_model.parameters()):
+        old_weight, up_weight = ma_params.data, current_params.data
+        ma_params.data = ema_updater.update_average(old_weight, up_weight)
+
+# MLP class for the projector
+
+class L2Norm(nn.Module):
+    def forward(self, x, eps = 1e-6):
+        norm = x.norm(dim = 1, keepdim = True).clamp(min = eps)
+        return x / norm
+
+class MLP(nn.Module):
+    def __init__(self, dim, dim_out, num_layers, hidden_size = 256):
+        super().__init__()
+
+        layers = []
+        dims = (dim, *((hidden_size,) * (num_layers - 1)))
+
+        for ind, (layer_dim_in, layer_dim_out) in enumerate(zip(dims[:-1], dims[1:])):
+            is_last = ind == (len(dims) - 2) # zip yields len(dims) - 1 pairs, so the final index is len(dims) - 2
+
+            layers.extend([
+                nn.Linear(layer_dim_in, layer_dim_out),
+                nn.GELU() if not is_last else nn.Identity()
+            ])
+
+        self.net = nn.Sequential(
+            *layers,
+            L2Norm(),
+            nn.Linear(hidden_size, dim_out)
+        )
+
+    def forward(self, x):
+        return self.net(x)
+
+# a wrapper class for the base neural network
+# will manage the interception of the hidden layer output
+# and pipe it into the projector net
+
+class NetWrapper(nn.Module):
+    def __init__(self, net, output_dim, projection_hidden_size, projection_num_layers, layer = -2):
+        super().__init__()
+        self.net = net
+        self.layer = layer
+
+        self.projector = None
+        self.projection_hidden_size = projection_hidden_size
+        self.projection_num_layers = projection_num_layers
+        self.output_dim = output_dim
+
+        self.hidden = {}
+        self.hook_registered = False
+
+    def _find_layer(self):
+        if type(self.layer) == str:
+            modules = dict([*self.net.named_modules()])
+            return modules.get(self.layer, None)
+        elif type(self.layer) == int:
+            children = [*self.net.children()]
+            return children[self.layer]
+        return None
+
+    def _hook(self, _, input, output):
+        device = input[0].device
+        self.hidden[device] = output.flatten(1)
+
+    def _register_hook(self):
+        layer = self._find_layer()
+        assert layer is not None, f'hidden layer ({self.layer}) not found'
+        layer.register_forward_hook(self._hook)
+        self.hook_registered = True
+
+    @singleton('projector')
+    def _get_projector(self, hidden):
+        _, dim = hidden.shape
+        projector = MLP(dim, self.output_dim, self.projection_num_layers, self.projection_hidden_size)
+        return projector.to(hidden)
+
+    def get_embedding(self, x):
+        if self.layer == -1:
+            return self.net(x)
+
+        if not self.hook_registered:
+            self._register_hook()
+
+        self.hidden.clear()
+        _ = self.net(x)
+        hidden = self.hidden.get(x.device) # .get, so a missing hook output trips the assert below rather than a KeyError
+        self.hidden.clear()
+
+        assert hidden is not None, f'hidden layer {self.layer} never emitted an output'
+        return hidden
+
+    def forward(self, x, return_projection = True):
+        embed = self.get_embedding(x)
+        if not return_projection:
+            return embed
+
+        projector = self._get_projector(embed)
+        return projector(embed), embed
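+
+# usage sketch (with a hypothetical `vit` exposing a `to_latent` submodule, as in the readme):
+#   wrapper = NetWrapper(vit, output_dim = 65536, projection_hidden_size = 256, projection_num_layers = 4, layer = 'to_latent')
+#   proj, embed = wrapper(torch.randn(2, 3, 256, 256)) # (2, 65536) projection and (2, d) embedding, d being the flattened hidden output size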
+
+# main class
+
+class Dino(nn.Module):
+    def __init__(
+        self,
+        net,
+        image_size,
+        hidden_layer = -2,
+        projection_hidden_size = 256,
+        num_classes_K = 65536,
+        projection_layers = 4,
+        student_temp = 0.9,
+        teacher_temp = 0.04,
+        local_upper_crop_scale = 0.4,
+        global_lower_crop_scale = 0.5,
+        moving_average_decay = 0.9,
+        center_moving_average_decay = 0.9,
+        augment_fn = None,
+        augment_fn2 = None
+    ):
+        super().__init__()
+        self.net = net
+
+        # default BYOL-style augmentations
+
+        DEFAULT_AUG = torch.nn.Sequential(
+            RandomApply(
+                T.ColorJitter(0.8, 0.8, 0.8, 0.2),
+                p = 0.3
+            ),
+            T.RandomGrayscale(p = 0.2),
+            T.RandomHorizontalFlip(),
+            RandomApply(
+                T.GaussianBlur((3, 3), (1.0, 2.0)),
+                p = 0.2
+            ),
+            T.Normalize(
+                mean = torch.tensor([0.485, 0.456, 0.406]),
+                std = torch.tensor([0.229, 0.224, 0.225])),
+        )
+
+        self.augment1 = default(augment_fn, DEFAULT_AUG)
+        self.augment2 = default(augment_fn2, DEFAULT_AUG)
+
+        # local and global crops
+
+        self.local_crop = T.RandomResizedCrop((image_size, image_size), scale = (0.05, local_upper_crop_scale))
+        self.global_crop = T.RandomResizedCrop((image_size, image_size), scale = (global_lower_crop_scale, 1.))
+
+        self.student_encoder = NetWrapper(net, num_classes_K, projection_hidden_size, projection_layers, layer = hidden_layer)
+
+        self.teacher_encoder = None
+        self.teacher_ema_updater = EMA(moving_average_decay)
+
+        self.register_buffer('teacher_centers', torch.zeros(1, num_classes_K))
+        self.register_buffer('last_teacher_centers', torch.zeros(1, num_classes_K))
+
+        self.teacher_centering_ema_updater = EMA(center_moving_average_decay)
+
+        self.student_temp = student_temp
+        self.teacher_temp = teacher_temp
+
+        # get device of network and make wrapper same device
+        device = get_module_device(net)
+        self.to(device)
+
+        # send a mock image tensor to instantiate singleton parameters
+        self.forward(torch.randn(2, 3, image_size, image_size, device = device))
+
+    @singleton('teacher_encoder')
+    def _get_teacher_encoder(self):
+        teacher_encoder = copy.deepcopy(self.student_encoder)
+        set_requires_grad(teacher_encoder, False)
+        return teacher_encoder
+
+    def reset_moving_average(self):
+        del self.teacher_encoder
+        self.teacher_encoder = None
+
+    def update_moving_average(self):
+        assert self.teacher_encoder is not None, 'teacher encoder has not been created yet'
+        update_moving_average(self.teacher_ema_updater, self.teacher_encoder, self.student_encoder)
+
+        new_teacher_centers = self.teacher_centering_ema_updater.update_average(self.teacher_centers, self.last_teacher_centers)
+        self.teacher_centers.copy_(new_teacher_centers)
+
+    def forward(
+        self,
+        x,
+        return_embedding = False,
+        return_projection = True,
+        student_temp = None,
+        teacher_temp = None
+    ):
+        if return_embedding:
+            return self.student_encoder(x, return_projection = return_projection)
+
+        image_one, image_two = self.augment1(x), self.augment2(x)
+
+        local_image_one, local_image_two = self.local_crop(image_one), self.local_crop(image_two)
+        global_image_one, global_image_two = self.global_crop(image_one), self.global_crop(image_two)
+
+        student_proj_one, _ = self.student_encoder(local_image_one)
+        student_proj_two, _ = self.student_encoder(local_image_two)
+
+        with torch.no_grad():
+            teacher_encoder = self._get_teacher_encoder()
+            teacher_proj_one, _ = teacher_encoder(global_image_one)
+            teacher_proj_two, _ = teacher_encoder(global_image_two)
+
+        loss_fn_ = partial(
+            loss_fn,
+            student_temp = default(student_temp, self.student_temp),
+            teacher_temp = default(teacher_temp, self.teacher_temp),
+            centers = self.teacher_centers
+        )
+
+        teacher_logits_avg = torch.cat((teacher_proj_one, teacher_proj_two)).mean(dim = 0)
+        self.last_teacher_centers.copy_(teacher_logits_avg)
+
+        loss = (loss_fn_(teacher_proj_one, student_proj_two) + loss_fn_(teacher_proj_two, student_proj_one)) / 2
+        return loss
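+
+# minimal smoke test, a sketch mirroring the readme example (assumes the `ViT`
+# from this package, which exposes a `to_latent` layer)
+
+if __name__ == '__main__':
+    from vit_pytorch import ViT
+
+    vit = ViT(image_size = 256, patch_size = 32, num_classes = 1000, dim = 1024, depth = 6, heads = 8, mlp_dim = 2048)
+    learner = Dino(vit, image_size = 256, hidden_layer = 'to_latent')
+
+    images = torch.randn(2, 3, 256, 256)
+    loss = learner(images)
+    loss.backward()
+    learner.update_moving_average() # update teacher weights and centers after each step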