From 141239ca86afc6e1fe6f4e50b60d173e21ca38ec Mon Sep 17 00:00:00 2001 From: Phil Wang Date: Thu, 31 Oct 2024 06:48:24 -0700 Subject: [PATCH] fix value residual --- setup.py | 2 +- vit_pytorch/simple_vit_with_value_residual.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/setup.py b/setup.py index 31ce3ac..c511aee 100644 --- a/setup.py +++ b/setup.py @@ -6,7 +6,7 @@ with open('README.md') as f: setup( name = 'vit-pytorch', packages = find_packages(exclude=['examples']), - version = '1.8.6', + version = '1.8.7', license='MIT', description = 'Vision Transformer (ViT) - Pytorch', long_description=long_description, diff --git a/vit_pytorch/simple_vit_with_value_residual.py b/vit_pytorch/simple_vit_with_value_residual.py index 392ed96..b87d1f5 100644 --- a/vit_pytorch/simple_vit_with_value_residual.py +++ b/vit_pytorch/simple_vit_with_value_residual.py @@ -57,7 +57,7 @@ class Attention(Module): q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv) if exists(value_residual): - v = v + value_residual + v = 0.5 * (v + value_residual) dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale