mirror of
https://github.com/lucidrains/vit-pytorch.git
synced 2025-12-30 08:02:29 +00:00
Add an option to return zero for the decorrelation auxiliary loss when there are insufficient samples
This commit is contained in:
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
|
|||||||
|
|
||||||
[project]
|
[project]
|
||||||
name = "vit-pytorch"
|
name = "vit-pytorch"
|
||||||
version = "1.15.2"
|
version = "1.15.3"
|
||||||
description = "Vision Transformer (ViT) - Pytorch"
|
description = "Vision Transformer (ViT) - Pytorch"
|
||||||
readme = { file = "README.md", content-type = "text/markdown" }
|
readme = { file = "README.md", content-type = "text/markdown" }
|
||||||
license = { file = "LICENSE" }
|
license = { file = "LICENSE" }
|
||||||
|
|||||||
@@ -2,7 +2,7 @@
|
|||||||
# but instead of their decorr module updated with SGD, remove all projections and just return a decorrelation auxiliary loss
|
# but instead of their decorr module updated with SGD, remove all projections and just return a decorrelation auxiliary loss
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch import nn, stack
|
from torch import nn, stack, tensor
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from torch.nn import Module, ModuleList
|
from torch.nn import Module, ModuleList
|
||||||
|
|
||||||
@@ -25,13 +25,17 @@ def pair(t):
|
|||||||
class DecorrelationLoss(Module):
|
class DecorrelationLoss(Module):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
sample_frac = 1.
|
sample_frac = 1.,
|
||||||
|
soft_validate_num_sampled = False
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
assert 0. <= sample_frac <= 1.
|
assert 0. <= sample_frac <= 1.
|
||||||
self.need_sample = sample_frac < 1.
|
self.need_sample = sample_frac < 1.
|
||||||
self.sample_frac = sample_frac
|
self.sample_frac = sample_frac
|
||||||
|
|
||||||
|
self.soft_validate_num_sampled = soft_validate_num_sampled
|
||||||
|
self.register_buffer('zero', tensor(0.), persistent = False)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
tokens
|
tokens
|
||||||
@@ -40,7 +44,11 @@ class DecorrelationLoss(Module):
|
|||||||
|
|
||||||
if self.need_sample:
|
if self.need_sample:
|
||||||
num_sampled = int(seq_len * self.sample_frac)
|
num_sampled = int(seq_len * self.sample_frac)
|
||||||
assert num_sampled >= 2.
|
|
||||||
|
assert self.soft_validate_num_sampled or num_sampled >= 2.
|
||||||
|
|
||||||
|
if num_sampled <= 1:
|
||||||
|
return self.zero
|
||||||
|
|
||||||
tokens, packed_shape = pack([tokens], '* n d e')
|
tokens, packed_shape = pack([tokens], '* n d e')
|
||||||
|
|
||||||
@@ -220,3 +228,7 @@ if __name__ == '__main__':
|
|||||||
|
|
||||||
decorr_loss(hiddens)
|
decorr_loss(hiddens)
|
||||||
decorr_loss(hiddens[0])
|
decorr_loss(hiddens[0])
|
||||||
|
|
||||||
|
decorr_loss = DecorrelationLoss(0.0001, soft_validate_num_sampled = True)
|
||||||
|
out = decorr_loss(hiddens)
|
||||||
|
assert out.item() == 0
|
||||||
Reference in New Issue
Block a user