diff --git a/vit_pytorch/vit_1d.py b/vit_pytorch/vit_1d.py index c67e135..7213422 100644 --- a/vit_pytorch/vit_1d.py +++ b/vit_pytorch/vit_1d.py @@ -10,7 +10,7 @@ class FeedForward(nn.Module): def __init__(self, dim, hidden_dim, dropout = 0.): super().__init__() self.net = nn.Sequential( - nn.Layernorm(dim), + nn.LayerNorm(dim), nn.Linear(dim, hidden_dim), nn.GELU(), nn.Dropout(dropout),