Add an option to return zero for the decorrelation auxiliary loss when too few tokens are sampled

This commit is contained in:
lucidrains
2025-11-09 10:08:06 -08:00
parent 5cf8384c56
commit 4386742cd1
2 changed files with 16 additions and 4 deletions

View File

@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
[project]
name = "vit-pytorch"
version = "1.15.2"
version = "1.15.3"
description = "Vision Transformer (ViT) - Pytorch"
readme = { file = "README.md", content-type = "text/markdown" }
license = { file = "LICENSE" }

View File

@@ -2,7 +2,7 @@
# but instead of their decorr module updated with SGD, remove all projections and just return a decorrelation auxiliary loss
import torch
from torch import nn, stack
from torch import nn, stack, tensor
import torch.nn.functional as F
from torch.nn import Module, ModuleList
@@ -25,13 +25,17 @@ def pair(t):
class DecorrelationLoss(Module):
def __init__(
self,
sample_frac = 1.
sample_frac = 1.,
soft_validate_num_sampled = False
):
super().__init__()
assert 0. <= sample_frac <= 1.
self.need_sample = sample_frac < 1.
self.sample_frac = sample_frac
self.soft_validate_num_sampled = soft_validate_num_sampled
self.register_buffer('zero', tensor(0.), persistent = False)
def forward(
self,
tokens
@@ -40,7 +44,11 @@ class DecorrelationLoss(Module):
if self.need_sample:
num_sampled = int(seq_len * self.sample_frac)
assert num_sampled >= 2.
assert self.soft_validate_num_sampled or num_sampled >= 2.
if num_sampled <= 1:
return self.zero
tokens, packed_shape = pack([tokens], '* n d e')
@@ -220,3 +228,7 @@ if __name__ == '__main__':
decorr_loss(hiddens)
decorr_loss(hiddens[0])
decorr_loss = DecorrelationLoss(0.0001, soft_validate_num_sampled = True)
out = decorr_loss(hiddens)
assert out.item() == 0