From 6c8dfc185ea41f4d2388e4d33bbb76f900ff8a0a Mon Sep 17 00:00:00 2001
From: Phil Wang
Date: Fri, 13 Nov 2020 12:25:21 -0800
Subject: [PATCH] remove float(-inf) as masking value

---
 setup.py                   | 2 +-
 vit_pytorch/vit_pytorch.py | 3 ++-
 2 files changed, 3 insertions(+), 2 deletions(-)

diff --git a/setup.py b/setup.py
index e3b22c9..ed4c7d1 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
 setup(
   name = 'vit-pytorch',
   packages = find_packages(exclude=['examples']),
-  version = '0.2.6',
+  version = '0.2.7',
   license='MIT',
   description = 'Vision Transformer (ViT) - Pytorch',
   author = 'Phil Wang',

diff --git a/vit_pytorch/vit_pytorch.py b/vit_pytorch/vit_pytorch.py
index e30c57f..1520fe3 100644
--- a/vit_pytorch/vit_pytorch.py
+++ b/vit_pytorch/vit_pytorch.py
@@ -51,12 +51,13 @@ class Attention(nn.Module):
         q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv)
 
         dots = torch.einsum('bhid,bhjd->bhij', q, k) * self.scale
+        mask_value = -torch.finfo(dots.dtype).max
 
         if mask is not None:
             mask = F.pad(mask.flatten(1), (1, 0), value = True)
             assert mask.shape[-1] == dots.shape[-1], 'mask has incorrect dimensions'
             mask = mask[:, None, :] * mask[:, :, None]
-            dots.masked_fill_(~mask, float('-inf'))
+            dots.masked_fill_(~mask, mask_value)
             del mask
 
         attn = dots.softmax(dim=-1)
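
Note: the commit message does not spell out the motivation, but it follows
from the mask construction in the hunk above. For any masked query token i,
the outer product mask[:, None, :] * mask[:, :, None] leaves row i entirely
False, so every logit in that row receives the fill value. With
float('-inf') the row's softmax is exp(-inf)/sum(exp(-inf)) = 0/0 = NaN,
which then propagates through the rest of the network. Filling with the most
negative finite value for the dtype instead degrades gracefully to a uniform
row, and it adapts automatically to reduced precision (e.g. float16), unlike
a hard-coded constant such as -1e9. A minimal sketch in plain PyTorch (the
tensors here are illustrative, not taken from the library):

    import torch

    logits = torch.randn(1, 4)        # one attention row over 4 keys
    keep = torch.zeros(1, 4).bool()   # degenerate case: every key masked out

    # Old behaviour: a fully masked row becomes all -inf, and softmax of
    # all -inf is 0/0, i.e. NaN.
    print(logits.masked_fill(~keep, float('-inf')).softmax(dim=-1))
    # tensor([[nan, nan, nan, nan]])

    # New behaviour: the most negative finite value keeps the row finite;
    # softmax degrades to a uniform distribution instead of NaN.
    mask_value = -torch.finfo(logits.dtype).max
    print(logits.masked_fill(~keep, mask_value).softmax(dim=-1))
    # tensor([[0.2500, 0.2500, 0.2500, 0.2500]])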