mirror of
https://github.com/lucidrains/vit-pytorch.git
synced 2025-12-30 08:02:29 +00:00
Compare commits
1 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
a954089eb7 |
2
setup.py
2
setup.py
@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
|
||||
setup(
|
||||
name = 'vit-pytorch',
|
||||
packages = find_packages(exclude=['examples']),
|
||||
version = '0.18.1',
|
||||
version = '0.18.2',
|
||||
license='MIT',
|
||||
description = 'Vision Transformer (ViT) - Pytorch',
|
||||
author = 'Phil Wang',
|
||||
|
||||
@@ -95,10 +95,12 @@ class Attention(nn.Module):
|
||||
qkv = (q, self.to_k(x), self.to_v(x))
|
||||
q, k, v = map(lambda t: rearrange(t, 'b (h d) ... -> b h (...) d', h = h), qkv)
|
||||
|
||||
dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
|
||||
dots = einsum('b h i d, b h j d -> b h i j', q, k)
|
||||
|
||||
dots = self.apply_pos_bias(dots)
|
||||
|
||||
dots *= self.scale
|
||||
|
||||
attn = self.attend(dots)
|
||||
|
||||
out = einsum('b h i j, b h j d -> b h i d', attn, v)
|
||||
|
||||
Reference in New Issue
Block a user