diff --git a/pyproject.toml b/pyproject.toml
index 58336e8..db230fc 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
 
 [project]
 name = "vit-pytorch"
-version = "1.12.4"
+version = "1.12.5"
 description = "Vision Transformer (ViT) - Pytorch"
 readme = { file = "README.md", content-type = "text/markdown" }
 license = { file = "LICENSE" }
diff --git a/vit_pytorch/vit_nd_rotary.py b/vit_pytorch/vit_nd_rotary.py
index afe3158..84dde6d 100644
--- a/vit_pytorch/vit_nd_rotary.py
+++ b/vit_pytorch/vit_nd_rotary.py
@@ -126,8 +126,9 @@ class Attention(Module):
         self.attend = nn.Softmax(dim = -1)
         self.dropout = nn.Dropout(dropout)
 
-        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
-
+        self.to_qk = nn.Linear(dim, inner_dim * 2, bias = False)
+        self.to_v = nn.Linear(dim, inner_dim, bias = False)
+
         self.to_out = nn.Sequential(
             nn.Linear(inner_dim, dim),
             nn.Dropout(dropout)
@@ -135,7 +136,8 @@ class Attention(Module):
 
     def forward(self, x, pos = None):
         x = self.norm(x)
-        qkv = self.to_qkv(x).chunk(3, dim = -1)
+        qkv = (*self.to_qk(x).chunk(2, dim = -1), self.to_v(x))
+
         q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)
 
         # Apply rotary embeddings if available
@@ -245,6 +247,23 @@ class ViTND(Module):
         self.to_latent = nn.Identity()
         self.mlp_head = nn.Linear(dim, num_classes)
 
+    def muon_parameters(self):
+        params = []
+
+        for m in self.modules():
+            if isinstance(m, Attention):
+                params.extend([
+                    m.to_v.weight,
+                    m.to_out[0].weight
+                ])
+            elif isinstance(m, FeedForward):
+                params.extend([
+                    m.net[1].weight,
+                    m.net[-2].weight
+                ])
+
+        return params
+
     def forward(
         self,
         x,
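
The new muon_parameters() helper collects exactly the 2D weight matrices a Muon-style optimizer is meant to update: the attention value and output projections plus the two feedforward matrices, leaving out the q/k projection, norms, embeddings, and the classifier head. Splitting the fused to_qkv into separate to_qk and to_v projections is what makes the value matrix addressable on its own. A minimal sketch of how the helper might be consumed follows; the Muon optimizer itself ships outside this repo, so the commented-out import is an assumption, and the ViTND constructor kwargs are illustrative only.

    # Partition parameters between Muon (2D projection matrices) and AdamW
    # (everything else). The ViTND kwargs below are assumptions for
    # illustration -- check the constructor for the real signature.

    import torch
    from vit_pytorch.vit_nd_rotary import ViTND

    model = ViTND(
        image_size = (32, 64, 64),   # assumed: one size per input dimension
        patch_size = (4, 8, 8),
        num_classes = 1000,
        dim = 512,
        depth = 6,
        heads = 8,
        mlp_dim = 1024
    )

    # Compare by identity so the same tensor never lands in both groups
    muon_ids = {id(p) for p in model.muon_parameters()}
    other_params = [p for p in model.parameters() if id(p) not in muon_ids]

    # from muon import Muon                                # assumed external package
    # opt_muon = Muon(model.muon_parameters(), lr = 2e-2)  # hypothetical usage
    opt_other = torch.optim.AdamW(other_params, lr = 3e-4)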