From 90be7233a3f55c29692a72da6ee4dcb5aab267d4 Mon Sep 17 00:00:00 2001
From: lucidrains
Date: Sat, 11 May 2024 08:04:14 -0700
Subject: [PATCH] rotary needs to be done with full precision to be safe

---
 setup.py           | 2 +-
 vit_pytorch/rvt.py | 3 +++
 2 files changed, 4 insertions(+), 1 deletion(-)

diff --git a/setup.py b/setup.py
index a704c7d..5ad1dde 100644
--- a/setup.py
+++ b/setup.py
@@ -6,7 +6,7 @@ with open('README.md') as f:
 setup(
   name = 'vit-pytorch',
   packages = find_packages(exclude=['examples']),
-  version = '1.6.8',
+  version = '1.6.9',
   license='MIT',
   description = 'Vision Transformer (ViT) - Pytorch',
   long_description=long_description,
diff --git a/vit_pytorch/rvt.py b/vit_pytorch/rvt.py
index 1ad51dc..1d7559c 100644
--- a/vit_pytorch/rvt.py
+++ b/vit_pytorch/rvt.py
@@ -3,12 +3,14 @@ from math import sqrt, pi, log
 import torch
 from torch import nn, einsum
 import torch.nn.functional as F
+from torch.cuda.amp import autocast
 
 from einops import rearrange, repeat
 from einops.layers.torch import Rearrange
 
 # rotary embeddings
 
+@autocast(enabled = False)
 def rotate_every_two(x):
     x = rearrange(x, '... (d j) -> ... d j', j = 2)
     x1, x2 = x.unbind(dim = -1)
@@ -22,6 +24,7 @@ class AxialRotaryEmbedding(nn.Module):
         scales = torch.linspace(1., max_freq / 2, self.dim // 4)
         self.register_buffer('scales', scales)
 
+    @autocast(enabled = False)
     def forward(self, x):
         device, dtype, n = x.device, x.dtype, int(sqrt(x.shape[-2]))
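
Context for the change: torch.cuda.amp.autocast objects can be used as decorators, which is what the patch relies on to keep the rotary-embedding math in float32 even when the surrounding model runs under mixed precision. The motivation is numerical: float16 represents integers exactly only up to 2048, so position indices and position-times-frequency products degrade at longer sequence lengths, which corrupts the sin/cos rotations. The sketch below is a minimal illustration of the pattern, not code from the repository; the rotary_freqs helper and its arguments are hypothetical.

import torch
from torch.cuda.amp import autocast

# Hypothetical helper illustrating the pattern from the patch: the decorator
# disables autocast inside the function body, so nothing here is downcast to fp16.
@autocast(enabled = False)
def rotary_freqs(positions, inv_freq):
    # autocast(enabled = False) only stops further downcasting; tensors that
    # arrive from an autocast region may already be fp16, so cast up explicitly
    positions = positions.float()
    inv_freq = inv_freq.float()
    return torch.einsum('i, j -> i j', positions, inv_freq)

if torch.cuda.is_available():
    positions = torch.arange(16, device = 'cuda', dtype = torch.float16)
    inv_freq = torch.rand(8, device = 'cuda', dtype = torch.float16)
    with autocast():
        freqs = rotary_freqs(positions, inv_freq)
    assert freqs.dtype == torch.float32  # computed in full precision

Note that in later PyTorch releases torch.cuda.amp.autocast is deprecated in favor of torch.amp.autocast('cuda', ...); the decorator usage is the same.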