mirror of
https://github.com/lucidrains/vit-pytorch.git
synced 2025-12-29 23:52:27 +00:00
an option to return zero for decorr aux loss if insufficient samples
This commit is contained in:
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
|
||||
|
||||
[project]
|
||||
name = "vit-pytorch"
|
||||
version = "1.15.2"
|
||||
version = "1.15.3"
|
||||
description = "Vision Transformer (ViT) - Pytorch"
|
||||
readme = { file = "README.md", content-type = "text/markdown" }
|
||||
license = { file = "LICENSE" }
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
# but instead of their decorr module updated with SGD, remove all projections and just return a decorrelation auxiliary loss
|
||||
|
||||
import torch
|
||||
from torch import nn, stack
|
||||
from torch import nn, stack, tensor
|
||||
import torch.nn.functional as F
|
||||
from torch.nn import Module, ModuleList
|
||||
|
||||
@@ -25,13 +25,17 @@ def pair(t):
|
||||
class DecorrelationLoss(Module):
|
||||
def __init__(
|
||||
self,
|
||||
sample_frac = 1.
|
||||
sample_frac = 1.,
|
||||
soft_validate_num_sampled = False
|
||||
):
|
||||
super().__init__()
|
||||
assert 0. <= sample_frac <= 1.
|
||||
self.need_sample = sample_frac < 1.
|
||||
self.sample_frac = sample_frac
|
||||
|
||||
self.soft_validate_num_sampled = soft_validate_num_sampled
|
||||
self.register_buffer('zero', tensor(0.), persistent = False)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
tokens
|
||||
@@ -40,7 +44,11 @@ class DecorrelationLoss(Module):
|
||||
|
||||
if self.need_sample:
|
||||
num_sampled = int(seq_len * self.sample_frac)
|
||||
assert num_sampled >= 2.
|
||||
|
||||
assert self.soft_validate_num_sampled or num_sampled >= 2.
|
||||
|
||||
if num_sampled <= 1:
|
||||
return self.zero
|
||||
|
||||
tokens, packed_shape = pack([tokens], '* n d e')
|
||||
|
||||
@@ -220,3 +228,7 @@ if __name__ == '__main__':
|
||||
|
||||
decorr_loss(hiddens)
|
||||
decorr_loss(hiddens[0])
|
||||
|
||||
decorr_loss = DecorrelationLoss(0.0001, soft_validate_num_sampled = True)
|
||||
out = decorr_loss(hiddens)
|
||||
assert out.item() == 0
|
||||
Reference in New Issue
Block a user