diff --git a/pyproject.toml b/pyproject.toml index d606fe2..ceca7fb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "vit-pytorch" -version = "1.15.2" +version = "1.15.3" description = "Vision Transformer (ViT) - Pytorch" readme = { file = "README.md", content-type = "text/markdown" } license = { file = "LICENSE" } diff --git a/vit_pytorch/vit_with_decorr.py b/vit_pytorch/vit_with_decorr.py index 6982835..587b282 100644 --- a/vit_pytorch/vit_with_decorr.py +++ b/vit_pytorch/vit_with_decorr.py @@ -2,7 +2,7 @@ # but instead of their decorr module updated with SGD, remove all projections and just return a decorrelation auxiliary loss import torch -from torch import nn, stack +from torch import nn, stack, tensor import torch.nn.functional as F from torch.nn import Module, ModuleList @@ -25,13 +25,17 @@ def pair(t): class DecorrelationLoss(Module): def __init__( self, - sample_frac = 1. + sample_frac = 1., + soft_validate_num_sampled = False ): super().__init__() assert 0. <= sample_frac <= 1. self.need_sample = sample_frac < 1. self.sample_frac = sample_frac + self.soft_validate_num_sampled = soft_validate_num_sampled + self.register_buffer('zero', tensor(0.), persistent = False) + def forward( self, tokens @@ -40,7 +44,11 @@ class DecorrelationLoss(Module): if self.need_sample: num_sampled = int(seq_len * self.sample_frac) - assert num_sampled >= 2. + + assert self.soft_validate_num_sampled or num_sampled >= 2. + + if num_sampled <= 1: + return self.zero tokens, packed_shape = pack([tokens], '* n d e') @@ -220,3 +228,7 @@ if __name__ == '__main__': decorr_loss(hiddens) decorr_loss(hiddens[0]) + + decorr_loss = DecorrelationLoss(0.0001, soft_validate_num_sampled = True) + out = decorr_loss(hiddens) + assert out.item() == 0 \ No newline at end of file