mirror of
https://github.com/lucidrains/vit-pytorch.git
synced 2025-12-29 23:52:27 +00:00
108 lines
1.9 KiB
Python
108 lines
1.9 KiB
Python
# /// script
# dependencies = [
#   "accelerate",
#   "vit-pytorch",
#   "wandb"
# ]
# ///

import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader

import torchvision.transforms as T
from torchvision.datasets import CIFAR100

# constants

# number of samples per batch handed to the DataLoader
BATCH_SIZE = 32
# learning rate given to the Adam optimizer
LEARNING_RATE = 3e-4
# number of full passes over the dataloader
EPOCHS = 10
# multiplier applied to the auxiliary decorrelation loss before it is added to the cross entropy
DECORR_LOSS_WEIGHT = 1e-1

# when False, wandb is initialized in 'disabled' mode; when True, in 'online' mode
TRACK_EXPERIMENT_ONLINE = False

# helpers

def exists(v):
    """Return True when `v` is anything other than None (falsy values such as 0 or '' still count)."""
    return v is not None

# data

# convert each image to a tensor, then normalize all three channels with mean 0.5 and std 0.5
transform = T.Compose([
    T.ToTensor(),
    T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

# CIFAR-100 training split, stored under the 'data' root with download enabled
dataset = CIFAR100(
    root = 'data',
    download = True,
    train = True,
    transform = transform
)

# shuffled batches of BATCH_SIZE samples
dataloader = DataLoader(dataset, batch_size = BATCH_SIZE, shuffle = True)

# model

from vit_pytorch.vit_with_decorr import ViT

# num_classes matches the 100 CIFAR-100 labels; mlp_dim is four times dim
vit = ViT(
    dim = 128,
    num_classes = 100,
    image_size = 32,
    patch_size = 4,
    depth = 6,
    heads = 8,
    dim_head = 64,
    mlp_dim = 128 * 4,
    decorr_sample_frac = 1. # use all tokens
)

# optim

from torch.optim import Adam

# Adam over every parameter of the model, at the learning rate set in the constants
optim = Adam(vit.parameters(), lr = LEARNING_RATE)

# prepare

from accelerate import Accelerator

accelerator = Accelerator()

# hand model, optimizer and dataloader to accelerate and rebind the names to the prepared objects
vit, optim, dataloader = accelerator.prepare(vit, optim, dataloader)

# experiment

import wandb

# tracking is only sent online when TRACK_EXPERIMENT_ONLINE is set, otherwise wandb is disabled
wandb.init(
    project = 'vit-decorr',
    mode = 'online' if TRACK_EXPERIMENT_ONLINE else 'disabled'
)

wandb.run.name = 'baseline'

# loop

for _ in range(EPOCHS):
    for images, labels in dataloader:

        # the model returns the class logits together with an auxiliary decorrelation loss
        logits, decorr_aux_loss = vit(images)
        loss = F.cross_entropy(logits, labels)

        # training objective: cross entropy plus the weighted auxiliary loss
        total_loss = (
            loss +
            decorr_aux_loss * DECORR_LOSS_WEIGHT
        )

        # read both scalars out once and reuse them for the tracker and the console
        loss_value = loss.item()
        decorr_loss_value = decorr_aux_loss.item()

        wandb.log(dict(loss = loss_value, decorr_loss = decorr_loss_value))

        accelerator.print(f'loss: {loss_value:.3f} | decorr aux loss: {decorr_loss_value:.3f}')

        # backward goes through accelerate; gradients are cleared after the optimizer step
        accelerator.backward(total_loss)
        optim.step()
        optim.zero_grad()