updated gradient accumulation

This commit is contained in:
XiangAn
2022-07-13 15:12:41 +08:00
parent e78eee59ca
commit 4f227ba5b4
6 changed files with 510 additions and 3 deletions

View File

@@ -1,7 +1,9 @@
# Distributed Arcface Training in Pytorch
This is a deep learning library that makes face recognition efficient and effective, and which can train tens of
millions of identities on a single server.
[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/killing-two-birds-with-one-stone-efficient/face-verification-on-ijb-c)](https://paperswithcode.com/sota/face-verification-on-ijb-c?p=killing-two-birds-with-one-stone-efficient)
## Requirements
@@ -38,8 +40,12 @@ Node 1:
python -m torch.distributed.launch --nproc_per_node=8 --nnodes=2 --node_rank=1 --master_addr="ip1" --master_port=12581 train.py configs/webface42m_r100_lr01_pfc02_bs4k_16gpus
```
config.num_classes = 85742
config.num_image = 5822653
### 3. Run ViT-B on a single machine with a 24k batch size:
```shell
python -m torch.distributed.launch --nproc_per_node=8 --nnodes=1 --node_rank=0 --master_addr="127.0.0.1" --master_port=12345 train_v2.py configs/wf42m_pfc03_40epoch_8gpu_vit_b.py
```
## Download Datasets or Prepare Datasets
- [MS1MV2](https://github.com/deepinsight/insightface/tree/master/recognition/_datasets_#ms1m-arcface-85k-ids58m-images-57) (87k IDs, 5.8M images)
@@ -83,6 +89,7 @@ globalised multi-racial testset contains 242,143 identities and 1,624,305 images
| WF12M | r100 | 94.69 | 97.59 | 95.97 | [click me](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/wf12m_r100/training.log) |
| WF42M-PFC-0.2 | r100 | 96.27 | 97.70 | 96.31 | [click me](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/wf42m_pfc02_r100/training.log) |
| WF42M-PFC-0.2 | ViT-T-1.5G | 92.04 | 97.27 | 95.68 | [click me](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/wf42m_pfc02_40epoch_8gpu_vit_t/training.log) |
| WF42M-PFC-0.3 | ViT-B-11G | 97.16 | 97.91 | 97.05 | [click me](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/pfc03_wf42m_vit_b_8gpu/training.log) |
#### 2. Training on Multi-Host GPU

View File

@@ -39,6 +39,8 @@ config.frequent = 10
# For Large Scale Dataset, such as WebFace42M
config.dali = False
# Gradient accumulation steps
config.gradient_acc = 1
# setup seed
config.seed = 2048
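
For context, `gradient_acc` controls how many iterations' gradients are summed before one optimizer update. Below is a minimal, self-contained sketch of the pattern; `model`, `criterion`, `optimizer`, and `loader` are illustrative placeholders, not names from this repo. In `train_v2.py` below, `backward()` runs every iteration while `opt.step()` and `opt.zero_grad()` run only when `global_step % cfg.gradient_acc == 0`, matching this pattern.

```python
import torch

# Placeholders for illustration only; any model/criterion/optimizer behaves the same way.
model = torch.nn.Linear(16, 4)
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
loader = [(torch.randn(8, 16), torch.randint(0, 4, (8,))) for _ in range(24)]

gradient_acc = 12                        # same meaning as config.gradient_acc

for step, (img, labels) in enumerate(loader, start=1):
    loss = criterion(model(img), labels)
    loss.backward()                      # gradients accumulate in the .grad buffers
    if step % gradient_acc == 0:
        optimizer.step()                 # one parameter update per accumulation window
        optimizer.zero_grad()
```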

View File

@@ -0,0 +1,28 @@
from easydict import EasyDict as edict
# make training faster
# our RAM is 256G
# mount -t tmpfs -o size=140G tmpfs /train_tmp
config = edict()
config.margin_list = (1.0, 0.0, 0.4)
config.network = "vit_b_dp005_mask_005"
config.resume = False
config.output = None
config.embedding_size = 512
config.sample_rate = 0.3
config.fp16 = True
config.weight_decay = 0.1
config.batch_size = 256
config.gradient_acc = 12 # accumulated batch size per GPU is 256 * 12
config.optimizer = "adamw"
config.lr = 0.001
config.verbose = 2000
config.dali = False
config.rec = "/train_tmp/WebFace42M"
config.num_classes = 2059906
config.num_image = 42474557
config.num_epoch = 40
config.warmup_epoch = config.num_epoch // 10
config.val_targets = []
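
A quick sanity check of the batch-size arithmetic implied by this config, assuming the 8-GPU launch command shown in the README:

```python
# Batch-size arithmetic for this config (assumes the 8-GPU launch from the README).
batch_size = 256      # config.batch_size, per GPU per iteration
gradient_acc = 12     # config.gradient_acc
world_size = 8        # --nproc_per_node=8 in the README launch command

effective_batch = batch_size * world_size * gradient_acc
print(effective_batch)  # 24576, i.e. the "24k" batch size referenced in the README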

View File

@@ -32,6 +32,7 @@ def get_dataloader(
# Synthetic
if root_dir == "synthetic":
train_set = SyntheticDataset()
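# the DALI pipeline reads MXNet RecordIO files, so fall back to the standard PyTorch loader for synthetic data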
dali = False
# Mxnet RecordIO
elif os.path.exists(rec) and os.path.exists(idx):

View File

@@ -0,0 +1,260 @@
import math
from typing import Callable
import torch
from torch import distributed
from torch.nn.functional import linear, normalize
class PartialFC_V2(torch.nn.Module):
"""
https://arxiv.org/abs/2203.15565
A distributed, sparsely updating variant of the FC layer, named Partial FC (PFC).
When the sample rate is less than 1, in each iteration the positive class centers and a random subset of
negative class centers are selected to compute the margin-based softmax loss; all class
centers are still maintained throughout the whole training process, but only the selected subset is
updated in each iteration.
.. note::
When the sample rate equals 1, Partial FC is equivalent to model parallelism (the default sample rate is 1).
Example:
--------
>>> module_pfc = PartialFC_V2(margin_loss, embedding_size=512, num_classes=8000000, sample_rate=0.2)
>>> for img, labels in data_loader:
>>> embeddings = net(img)
>>> loss = module_pfc(embeddings, labels)
>>> loss.backward()
>>> optimizer.step()
"""
_version = 2
def __init__(
self,
margin_loss: Callable,
embedding_size: int,
num_classes: int,
sample_rate: float = 1.0,
fp16: bool = False,
):
"""
Paramenters:
-----------
embedding_size: int
The dimension of embedding, required
num_classes: int
Total number of classes, required
sample_rate: float
The rate of negative centers participating in the calculation, default is 1.0.
"""
super(PartialFC_V2, self).__init__()
assert (
distributed.is_initialized()
), "must initialize distributed before create this"
self.rank = distributed.get_rank()
self.world_size = distributed.get_world_size()
self.dist_cross_entropy = DistCrossEntropy()
self.embedding_size = embedding_size
self.sample_rate: float = sample_rate
self.fp16 = fp16
self.num_local: int = num_classes // self.world_size + int(
self.rank < num_classes % self.world_size
)
self.class_start: int = num_classes // self.world_size * self.rank + min(
self.rank, num_classes % self.world_size
)
self.num_sample: int = int(self.sample_rate * self.num_local)
self.last_batch_size: int = 0
self.is_updated: bool = True
self.init_weight_update: bool = True
self.weight = torch.nn.Parameter(torch.normal(0, 0.01, (self.num_local, embedding_size)))
# margin_loss
if isinstance(margin_loss, Callable):
self.margin_softmax = margin_loss
else:
raise TypeError("margin_loss must be callable")
def sample(self, labels, index_positive):
"""
This functions will change the value of labels
Parameters:
-----------
labels: torch.Tensor
pass
index_positive: torch.Tensor
pass
optimizer: torch.optim.Optimizer
pass
"""
with torch.no_grad():
positive = torch.unique(labels[index_positive], sorted=True).cuda()
if self.num_sample - positive.size(0) >= 0:
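# assign every local class a random score, then push the positive classes above the [0, 1) range so topk always keeps them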
perm = torch.rand(size=[self.num_local]).cuda()
perm[positive] = 2.0
index = torch.topk(perm, k=self.num_sample)[1].cuda()
index = index.sort()[0].cuda()
else:
index = positive
self.weight_index = index
labels[index_positive] = torch.searchsorted(index, labels[index_positive])
return self.weight[self.weight_index]
def forward(
self,
local_embeddings: torch.Tensor,
local_labels: torch.Tensor,
):
"""
Parameters:
----------
local_embeddings: torch.Tensor
feature embeddings on each GPU (rank).
local_labels: torch.Tensor
labels on each GPU (rank).
Returns:
-------
loss: torch.Tensor
the margin-softmax cross-entropy loss averaged over the global batch.
"""
local_labels.squeeze_()
local_labels = local_labels.long()
batch_size = local_embeddings.size(0)
if self.last_batch_size == 0:
self.last_batch_size = batch_size
assert self.last_batch_size == batch_size, (
f"last batch size do not equal current batch size: {self.last_batch_size} vs {batch_size}")
_gather_embeddings = [
torch.zeros((batch_size, self.embedding_size)).cuda()
for _ in range(self.world_size)
]
_gather_labels = [
torch.zeros(batch_size).long().cuda() for _ in range(self.world_size)
]
_list_embeddings = AllGather(local_embeddings, *_gather_embeddings)
distributed.all_gather(_gather_labels, local_labels)
embeddings = torch.cat(_list_embeddings)
labels = torch.cat(_gather_labels)
labels = labels.view(-1, 1)
index_positive = (self.class_start <= labels) & (
labels < self.class_start + self.num_local
)
labels[~index_positive] = -1
labels[index_positive] -= self.class_start
if self.sample_rate < 1:
weight = self.sample(labels, index_positive)
else:
weight = self.weight
with torch.cuda.amp.autocast(self.fp16):
norm_embeddings = normalize(embeddings)
norm_weight_activated = normalize(weight)
logits = linear(norm_embeddings, norm_weight_activated)
if self.fp16:
logits = logits.float()
logits = logits.clamp(-1, 1)
logits = self.margin_softmax(logits, labels)
loss = self.dist_cross_entropy(logits, labels)
return loss
class DistCrossEntropyFunc(torch.autograd.Function):
"""
CrossEntropy loss is calculated in parallel: the softmax denominator is all-reduced across GPUs
so that each rank can normalize its local logits against the full (global) class set.
Used by the ArcFace implementation (https://arxiv.org/pdf/1801.07698v1.pdf).
"""
@staticmethod
def forward(ctx, logits: torch.Tensor, label: torch.Tensor):
""" """
batch_size = logits.size(0)
# for numerical stability
max_logits, _ = torch.max(logits, dim=1, keepdim=True)
# local to global
distributed.all_reduce(max_logits, distributed.ReduceOp.MAX)
logits.sub_(max_logits)
logits.exp_()
sum_logits_exp = torch.sum(logits, dim=1, keepdim=True)
# local to global
distributed.all_reduce(sum_logits_exp, distributed.ReduceOp.SUM)
logits.div_(sum_logits_exp)
index = torch.where(label != -1)[0]
# loss
loss = torch.zeros(batch_size, 1, device=logits.device)
loss[index] = logits[index].gather(1, label[index])
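# every sample's target class is owned by exactly one rank, so the SUM all-reduce below fills in the full loss vector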
distributed.all_reduce(loss, distributed.ReduceOp.SUM)
ctx.save_for_backward(index, logits, label)
return loss.clamp_min_(1e-30).log_().mean() * (-1)
@staticmethod
def backward(ctx, loss_gradient):
"""
Args:
loss_gradient (torch.Tensor): gradient passed back from the layer above
Returns:
gradients for each input in forward function
`None` gradients for one-hot label
"""
(
index,
logits,
label,
) = ctx.saved_tensors
batch_size = logits.size(0)
one_hot = torch.zeros(
size=[index.size(0), logits.size(1)], device=logits.device
)
one_hot.scatter_(1, label[index], 1)
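# for softmax cross-entropy, d(loss)/d(logits) = softmax(logits) - one_hot(label); `logits` already holds the probabilities from forward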
logits[index] -= one_hot
logits.div_(batch_size)
return logits * loss_gradient.item(), None
class DistCrossEntropy(torch.nn.Module):
def __init__(self):
super(DistCrossEntropy, self).__init__()
def forward(self, logit_part, label_part):
return DistCrossEntropyFunc.apply(logit_part, label_part)
class AllGatherFunc(torch.autograd.Function):
"""AllGather op with gradient backward"""
@staticmethod
def forward(ctx, tensor, *gather_list):
gather_list = list(gather_list)
distributed.all_gather(gather_list, tensor)
return tuple(gather_list)
@staticmethod
def backward(ctx, *grads):
grad_list = list(grads)
rank = distributed.get_rank()
grad_out = grad_list[rank]
dist_ops = [
distributed.reduce(grad_out, rank, distributed.ReduceOp.SUM, async_op=True)
if i == rank
else distributed.reduce(
grad_list[i], i, distributed.ReduceOp.SUM, async_op=True
)
for i in range(distributed.get_world_size())
]
for _op in dist_ops:
_op.wait()
grad_out *= len(grad_list) # cooperate with distributed loss function
return (grad_out, *[None for _ in range(len(grad_list))])
AllGather = AllGatherFunc.apply
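
For reference, a minimal wiring sketch mirroring how `train_v2.py` uses this module on a single machine with one process per GPU. The `torch.nn.Linear` backbone and the random input tensors are placeholders for illustration only; the margin-loss arguments follow the call in `train_v2.py`.

```python
# Illustrative sketch only; mirrors the wiring in train_v2.py with placeholder data.
import torch
from torch import distributed
from losses import CombinedMarginLoss
from partial_fc_v2 import PartialFC_V2

distributed.init_process_group("nccl")                     # one process per GPU
torch.cuda.set_device(distributed.get_rank())

backbone = torch.nn.Linear(3 * 112 * 112, 512).cuda()      # stand-in for a real backbone
margin_loss = CombinedMarginLoss(64, 1.0, 0.0, 0.4, 0)     # s, m1, m2, m3, filtering threshold
module_pfc = PartialFC_V2(margin_loss, embedding_size=512,
                          num_classes=2_059_906, sample_rate=0.3, fp16=False)
module_pfc.train().cuda()

# PFC parameters join the backbone parameters in a single optimizer (PFC last, as in train_v2.py).
opt = torch.optim.AdamW([{"params": backbone.parameters()},
                         {"params": module_pfc.parameters()}], lr=1e-3)

images = torch.randn(128, 3 * 112 * 112).cuda()
labels = torch.randint(0, 2_059_906, (128,)).cuda()
loss = module_pfc(backbone(images), labels)                # distributed margin-softmax cross-entropy
loss.backward()
opt.step()
```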

View File

@@ -0,0 +1,209 @@
import argparse
import logging
import os
import numpy as np
import torch
from torch import distributed
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from backbones import get_model
from dataset import get_dataloader
from losses import CombinedMarginLoss
from lr_scheduler import PolyScheduler
from partial_fc_v2 import PartialFC_V2
from utils.utils_callbacks import CallBackLogging, CallBackVerification
from utils.utils_config import get_config
from utils.utils_logging import AverageMeter, init_logging
from utils.utils_distributed_sampler import setup_seed
assert torch.__version__ >= "1.9.0", "In order to enjoy the features of the new torch, \
we have upgraded the torch to 1.9.0. torch before than 1.9.0 may not work in the future."
try:
world_size = int(os.environ["WORLD_SIZE"])
rank = int(os.environ["RANK"])
distributed.init_process_group("nccl")
except KeyError:
world_size = 1
rank = 0
distributed.init_process_group(
backend="nccl",
init_method="tcp://127.0.0.1:12584",
rank=rank,
world_size=world_size,
)
def main(args):
# get config
cfg = get_config(args.config)
# global control random seed
setup_seed(seed=cfg.seed, cuda_deterministic=False)
torch.cuda.set_device(args.local_rank)
os.makedirs(cfg.output, exist_ok=True)
init_logging(rank, cfg.output)
summary_writer = (
SummaryWriter(log_dir=os.path.join(cfg.output, "tensorboard"))
if rank == 0
else None
)
train_loader = get_dataloader(
cfg.rec,
args.local_rank,
cfg.batch_size,
cfg.dali,
cfg.seed,
cfg.num_workers
)
backbone = get_model(
cfg.network, dropout=0.0, fp16=cfg.fp16, num_features=cfg.embedding_size).cuda()
backbone = torch.nn.parallel.DistributedDataParallel(
module=backbone, broadcast_buffers=False, device_ids=[args.local_rank], bucket_cap_mb=16,
find_unused_parameters=True)
backbone.train()
# FIXME: using gradient checkpointing when there are unused parameters will cause an error
backbone._set_static_graph()
margin_loss = CombinedMarginLoss(
64,
cfg.margin_list[0],
cfg.margin_list[1],
cfg.margin_list[2],
cfg.interclass_filtering_threshold
)
if cfg.optimizer == "sgd":
module_partial_fc = PartialFC_V2(
margin_loss, cfg.embedding_size, cfg.num_classes,
cfg.sample_rate, cfg.fp16)
module_partial_fc.train().cuda()
# TODO the params of partial fc must be last in the params list
opt = torch.optim.SGD(
params=[{"params": backbone.parameters()}, {"params": module_partial_fc.parameters()}],
lr=cfg.lr, momentum=0.9, weight_decay=cfg.weight_decay)
elif cfg.optimizer == "adamw":
module_partial_fc = PartialFC_V2(
margin_loss, cfg.embedding_size, cfg.num_classes,
cfg.sample_rate, cfg.fp16)
module_partial_fc.train().cuda()
opt = torch.optim.AdamW(
params=[{"params": backbone.parameters()}, {"params": module_partial_fc.parameters()}],
lr=cfg.lr, weight_decay=cfg.weight_decay)
else:
raise ValueError(f"unknown optimizer: {cfg.optimizer}")
cfg.total_batch_size = cfg.batch_size * world_size
cfg.warmup_step = cfg.num_image // cfg.total_batch_size * cfg.warmup_epoch
cfg.total_step = cfg.num_image // cfg.total_batch_size * cfg.num_epoch
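# total_batch_size counts one forward pass across all GPUs; with gradient accumulation the optimizer still steps only once every cfg.gradient_acc such iterations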
lr_scheduler = PolyScheduler(
optimizer=opt,
base_lr=cfg.lr,
max_steps=cfg.total_step,
warmup_steps=cfg.warmup_step,
last_epoch=-1
)
start_epoch = 0
global_step = 0
if cfg.resume:
dict_checkpoint = torch.load(os.path.join(cfg.output, f"checkpoint_gpu_{rank}.pt"))
start_epoch = dict_checkpoint["epoch"]
global_step = dict_checkpoint["global_step"]
backbone.module.load_state_dict(dict_checkpoint["state_dict_backbone"])
module_partial_fc.load_state_dict(dict_checkpoint["state_dict_softmax_fc"])
opt.load_state_dict(dict_checkpoint["state_optimizer"])
lr_scheduler.load_state_dict(dict_checkpoint["state_lr_scheduler"])
del dict_checkpoint
for key, value in cfg.items():
num_space = 25 - len(key)
logging.info(": " + key + " " * num_space + str(value))
callback_verification = CallBackVerification(
val_targets=cfg.val_targets, rec_prefix=cfg.rec, summary_writer=summary_writer
)
callback_logging = CallBackLogging(
frequent=cfg.frequent,
total_step=cfg.total_step,
batch_size=cfg.batch_size,
start_step=global_step,
writer=summary_writer
)
loss_am = AverageMeter()
amp = torch.cuda.amp.grad_scaler.GradScaler(growth_interval=100)
for epoch in range(start_epoch, cfg.num_epoch):
if isinstance(train_loader, DataLoader):
train_loader.sampler.set_epoch(epoch)
for _, (img, local_labels) in enumerate(train_loader):
global_step += 1
local_embeddings = backbone(img)
loss: torch.Tensor = module_partial_fc(local_embeddings, local_labels)
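# gradient accumulation: backward() runs every iteration, but the optimizer steps and gradients are zeroed only every cfg.gradient_acc iterations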
if cfg.fp16:
amp.scale(loss).backward()
if global_step % cfg.gradient_acc == 0:
amp.unscale_(opt)
torch.nn.utils.clip_grad_norm_(backbone.parameters(), 5)
amp.step(opt)
amp.update()
opt.zero_grad()
else:
loss.backward()
if global_step % cfg.gradient_acc == 0:
torch.nn.utils.clip_grad_norm_(backbone.parameters(), 5)
opt.step()
opt.zero_grad()
lr_scheduler.step()
with torch.no_grad():
loss_am.update(loss.item(), 1)
callback_logging(global_step, loss_am, epoch, cfg.fp16, lr_scheduler.get_last_lr()[0], amp)
if global_step % cfg.verbose == 0 and global_step > 0:
callback_verification(global_step, backbone)
if cfg.save_all_states:
checkpoint = {
"epoch": epoch + 1,
"global_step": global_step,
"state_dict_backbone": backbone.module.state_dict(),
"state_dict_softmax_fc": module_partial_fc.state_dict(),
"state_optimizer": opt.state_dict(),
"state_lr_scheduler": lr_scheduler.state_dict()
}
torch.save(checkpoint, os.path.join(cfg.output, f"checkpoint_gpu_{rank}.pt"))
if rank == 0:
path_module = os.path.join(cfg.output, "model.pt")
torch.save(backbone.module.state_dict(), path_module)
if cfg.dali:
train_loader.reset()
if rank == 0:
path_module = os.path.join(cfg.output, "model.pt")
torch.save(backbone.module.state_dict(), path_module)
if __name__ == "__main__":
torch.backends.cudnn.benchmark = True
parser = argparse.ArgumentParser(
description="Distributed Arcface Training in Pytorch")
parser.add_argument("config", type=str, help="py config file")
parser.add_argument("--local_rank", type=int, default=0, help="local_rank")
main(parser.parse_args())