mirror of
https://github.com/lucidrains/vit-pytorch.git
synced 2025-12-30 08:02:29 +00:00
Rotary embeddings need to be computed in full precision to be safe, so autocast is disabled for the rotary functions
This commit is contained in:
2
setup.py
2
setup.py
@@ -6,7 +6,7 @@ with open('README.md') as f:
|
|||||||
setup(
|
setup(
|
||||||
name = 'vit-pytorch',
|
name = 'vit-pytorch',
|
||||||
packages = find_packages(exclude=['examples']),
|
packages = find_packages(exclude=['examples']),
|
||||||
version = '1.6.8',
|
version = '1.6.9',
|
||||||
license='MIT',
|
license='MIT',
|
||||||
description = 'Vision Transformer (ViT) - Pytorch',
|
description = 'Vision Transformer (ViT) - Pytorch',
|
||||||
long_description=long_description,
|
long_description=long_description,
|
||||||
|
|||||||
@@ -3,12 +3,14 @@ from math import sqrt, pi, log
|
|||||||
import torch
|
import torch
|
||||||
from torch import nn, einsum
|
from torch import nn, einsum
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
from torch.cuda.amp import autocast
|
||||||
|
|
||||||
from einops import rearrange, repeat
|
from einops import rearrange, repeat
|
||||||
from einops.layers.torch import Rearrange
|
from einops.layers.torch import Rearrange
|
||||||
|
|
||||||
# rotary embeddings
|
# rotary embeddings
|
||||||
|
|
||||||
|
@autocast(enabled = False)
|
||||||
def rotate_every_two(x):
|
def rotate_every_two(x):
|
||||||
x = rearrange(x, '... (d j) -> ... d j', j = 2)
|
x = rearrange(x, '... (d j) -> ... d j', j = 2)
|
||||||
x1, x2 = x.unbind(dim = -1)
|
x1, x2 = x.unbind(dim = -1)
|
||||||
@@ -22,6 +24,7 @@ class AxialRotaryEmbedding(nn.Module):
|
|||||||
scales = torch.linspace(1., max_freq / 2, self.dim // 4)
|
scales = torch.linspace(1., max_freq / 2, self.dim // 4)
|
||||||
self.register_buffer('scales', scales)
|
self.register_buffer('scales', scales)
|
||||||
|
|
||||||
|
@autocast(enabled = False)
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
device, dtype, n = x.device, x.dtype, int(sqrt(x.shape[-2]))
|
device, dtype, n = x.device, x.dtype, int(sqrt(x.shape[-2]))
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user