remove float(-inf) as masking value

This commit is contained in:
Phil Wang
2020-11-13 12:25:21 -08:00
parent 4f84ad7a64
commit 6c8dfc185e
2 changed files with 3 additions and 2 deletions

View File

@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
setup(
name = 'vit-pytorch',
packages = find_packages(exclude=['examples']),
version = '0.2.6',
version = '0.2.7',
license='MIT',
description = 'Vision Transformer (ViT) - Pytorch',
author = 'Phil Wang',

View File

@@ -51,12 +51,13 @@ class Attention(nn.Module):
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv)
dots = torch.einsum('bhid,bhjd->bhij', q, k) * self.scale
mask_value = -torch.finfo(dots.dtype).max
if mask is not None:
mask = F.pad(mask.flatten(1), (1, 0), value = True)
assert mask.shape[-1] == dots.shape[-1], 'mask has incorrect dimensions'
mask = mask[:, None, :] * mask[:, :, None]
dots.masked_fill_(~mask, float('-inf'))
dots.masked_fill_(~mask, mask_value)
del mask
attn = dots.softmax(dim=-1)