rotary needs to be done with full precision to be safe

This commit is contained in:
lucidrains
2024-05-11 08:04:14 -07:00
parent bca88e9039
commit 90be7233a3
2 changed files with 4 additions and 1 deletions

View File

@@ -6,7 +6,7 @@ with open('README.md') as f:
 setup(
   name = 'vit-pytorch',
   packages = find_packages(exclude=['examples']),
-  version = '1.6.8',
+  version = '1.6.9',
   license='MIT',
   description = 'Vision Transformer (ViT) - Pytorch',
   long_description=long_description,

View File

@@ -3,12 +3,14 @@ from math import sqrt, pi, log
 import torch
 from torch import nn, einsum
 import torch.nn.functional as F
+from torch.cuda.amp import autocast

 from einops import rearrange, repeat
 from einops.layers.torch import Rearrange

 # rotary embeddings

+@autocast(enabled = False)
 def rotate_every_two(x):
     x = rearrange(x, '... (d j) -> ... d j', j = 2)
     x1, x2 = x.unbind(dim = -1)
@@ -22,6 +24,7 @@ class AxialRotaryEmbedding(nn.Module):
         scales = torch.linspace(1., max_freq / 2, self.dim // 4)
         self.register_buffer('scales', scales)

+    @autocast(enabled = False)
     def forward(self, x):
         device, dtype, n = x.device, x.dtype, int(sqrt(x.shape[-2]))