for n-dimensional vit, have a method for fetching muon friendly parameters

This commit is contained in:
lucidrains
2025-10-13 12:07:48 -07:00
parent 3becf087bb
commit 39fd9ac8be
2 changed files with 23 additions and 4 deletions

View File

@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
[project]
name = "vit-pytorch"
version = "1.12.4"
version = "1.12.5"
description = "Vision Transformer (ViT) - Pytorch"
readme = { file = "README.md", content-type = "text/markdown" }
license = { file = "LICENSE" }

View File

@@ -126,8 +126,9 @@ class Attention(Module):
self.attend = nn.Softmax(dim = -1)
self.dropout = nn.Dropout(dropout)
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
self.to_qk = nn.Linear(dim, inner_dim * 2, bias = False)
self.to_v = nn.Linear(dim, inner_dim, bias = False)
self.to_out = nn.Sequential(
nn.Linear(inner_dim, dim),
nn.Dropout(dropout)
@@ -135,7 +136,8 @@ class Attention(Module):
def forward(self, x, pos = None):
x = self.norm(x)
qkv = self.to_qkv(x).chunk(3, dim = -1)
qkv = (*self.to_qk(x).chunk(2, dim = -1), self.to_v(x))
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)
# Apply rotary embeddings if available
@@ -245,6 +247,23 @@ class ViTND(Module):
self.to_latent = nn.Identity()
self.mlp_head = nn.Linear(dim, num_classes)
def muon_parameters(self):
params = []
for m in self.modules():
if isinstance(m, Attention):
params.extend([
m.to_v.weight,
m.to_out[0].weight
])
elif isinstance(m, FeedForward):
params.extend([
m.net[1].weight,
m.net[-2].weight
])
return params
def forward(
self,
x,