mirror of
https://github.com/deepinsight/insightface.git
synced 2025-12-30 08:02:27 +00:00
Fixed torch1.12 api changes
remove old train.py and partial_fc.py
@@ -17,7 +17,7 @@ To avail the latest features of PyTorch, we have upgraded to version 1.12.0.

## How to Train

To train a model, execute the `train.py` script with the path to the configuration files. The sample commands provided below demonstrate the process of conducting distributed training.
To train a model, execute the `train_v2.py` script with the path to the configuration files. The sample commands provided below demonstrate the process of conducting distributed training.

### 1. To run on one GPU:
@@ -32,7 +32,7 @@ It is not recommended to use a single GPU for training, as this may result in lo
### 2. To run on a machine with 8 GPUs:

```shell
torchrun --nproc_per_node=8 train.py configs/ms1mv3_r50
torchrun --nproc_per_node=8 train_v2.py configs/ms1mv3_r50
```

### 3. To run on 2 machines with 8 GPUs each:
@@ -40,13 +40,13 @@ torchrun --nproc_per_node=8 train.py configs/ms1mv3_r50
Node 0:

```shell
torchrun --nproc_per_node=8 --nnodes=2 --node_rank=0 --master_addr="ip1" --master_port=12581 train.py configs/wf42m_pfc02_16gpus_r100
torchrun --nproc_per_node=8 --nnodes=2 --node_rank=0 --master_addr="ip1" --master_port=12581 train_v2.py configs/wf42m_pfc02_16gpus_r100
```

Node 1:

```shell
torchrun --nproc_per_node=8 --nnodes=2 --node_rank=1 --master_addr="ip1" --master_port=12581 train.py configs/wf42m_pfc02_16gpus_r100
torchrun --nproc_per_node=8 --nnodes=2 --node_rank=1 --master_addr="ip1" --master_port=12581 train_v2.py configs/wf42m_pfc02_16gpus_r100
```

### 4. Run ViT-B on a machine with 24k batchsize:
@@ -80,6 +80,13 @@ def get_model(name, **kwargs):
        return VisionTransformer(
            img_size=112, patch_size=9, num_classes=num_features, embed_dim=768, depth=24,
            num_heads=8, drop_path_rate=0.05, norm_layer="ln", mask_ratio=0.05, using_checkpoint=True)

    elif name == "vit_h":  # For WebFace42M
        num_features = kwargs.get("num_features", 512)
        from .vit import VisionTransformer
        return VisionTransformer(
            img_size=112, patch_size=9, num_classes=num_features, embed_dim=1024, depth=48,
            num_heads=8, drop_path_rate=0.1, norm_layer="ln", mask_ratio=0, using_checkpoint=True)

    else:
        raise ValueError()
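For orientation, a minimal usage sketch of the new `vit_h` entry (illustrative only, not part of the diff; it assumes the `backbones` package layout used by `train_v2.py` further down and a CUDA device, and the 512-dimensional output matches `config.embedding_size` in the new config below):

```python
# Hypothetical smoke test for the vit_h branch added above: build the backbone
# and push one dummy 112x112 face crop through it to confirm the embedding size.
import torch
from backbones import get_model  # assumes the arcface_torch package layout

backbone = get_model("vit_h", num_features=512).cuda().eval()
dummy = torch.randn(1, 3, 112, 112).cuda()
with torch.no_grad():
    embedding = backbone(dummy)
print(embedding.shape)  # expected: torch.Size([1, 512])
```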
@@ -38,6 +38,7 @@ config.frequent = 10

# For Large Scale Dataset, such as WebFace42M
config.dali = False
config.dali_aug = False

# Gradient ACC
config.gradient_acc = 1
recognition/arcface_torch/configs/wf42m_pfc02_vit_h.py (new file, 28 lines)
@@ -0,0 +1,28 @@
from easydict import EasyDict as edict

# make training faster
# our RAM is 256G
# mount -t tmpfs -o size=140G tmpfs /train_tmp

config = edict()
config.margin_list = (1.0, 0.0, 0.4)
config.network = "vit_h"
config.resume = False
config.output = None
config.embedding_size = 512
config.sample_rate = 0.5
config.fp16 = True
config.weight_decay = 0.1
config.batch_size = 768
config.optimizer = "adamw"
config.lr = 0.001
config.verbose = 2000
config.dali = True

config.rec = "/train_tmp/WebFace42M"
config.num_classes = 2059906
config.num_image = 42474557
config.num_epoch = 16
config.warmup_epoch = config.num_epoch // 8
config.val_targets = []
config.dali_aug = True
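As a rough sanity check of the schedule this config implies (a sketch only; the 40-rank world size is an assumption taken from the 5-node x 8-GPU launch script updated below, and the step arithmetic mirrors what train_v2.py computes from these fields):

```python
# Back-of-the-envelope schedule math for wf42m_pfc02_vit_h.py.
# Assumption: 5 nodes x 8 GPUs = 40 ranks, as in the updated launch script below.
batch_size, world_size = 768, 40
num_image, num_epoch = 42474557, 16
warmup_epoch = num_epoch // 8                       # = 2

total_batch_size = batch_size * world_size          # 30720 images per optimizer step
warmup_step = num_image // total_batch_size * warmup_epoch
total_step = num_image // total_batch_size * num_epoch
print(total_batch_size, warmup_step, total_step)    # 30720 2764 22112
```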
@@ -21,6 +21,7 @@ def get_dataloader(
    local_rank,
    batch_size,
    dali = False,
    dali_aug = False,
    seed = 2048,
    num_workers = 2,
    ) -> Iterable:
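The new `dali_aug` argument sits between `dali` and `seed`; the updated call site in train_v2.py appears further down in this diff. In isolation the new signature is exercised like the following sketch (the record path comes from the config above, while the other values are illustrative placeholders):

```python
# Hypothetical call of the updated get_dataloader signature; seed, batch size
# and worker count are illustrative values, not taken from the diff.
train_loader = get_dataloader(
    "/train_tmp/WebFace42M",   # dataset folder, as in the new config
    local_rank=0,
    batch_size=128,
    dali=True,
    dali_aug=True,             # the flag added in this commit
    seed=2048,
    num_workers=2,
)
```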
@@ -51,7 +52,7 @@ def get_dataloader(
    if dali:
        return dali_data_iter(
            batch_size=batch_size, rec_file=rec, idx_file=idx,
            num_threads=2, local_rank=local_rank)
            num_threads=2, local_rank=local_rank, dali_aug=dali_aug)

    rank, world_size = get_dist_info()
    train_sampler = DistributedSampler(
@@ -194,7 +195,9 @@ def dali_data_iter(
        initial_fill=32768, random_shuffle=True,
        prefetch_queue_depth=1, local_rank=0, name="reader",
        mean=(127.5, 127.5, 127.5),
        std=(127.5, 127.5, 127.5)):
        std=(127.5, 127.5, 127.5),
        dali_aug=False
        ):
    """
    Parameters:
    ----------
@@ -209,6 +212,34 @@ def dali_data_iter(
    from nvidia.dali.pipeline import Pipeline
    from nvidia.dali.plugin.pytorch import DALIClassificationIterator

    def dali_random_resize(img, resize_size, image_size=112):
        img = fn.resize(img, resize_x=resize_size, resize_y=resize_size)
        img = fn.resize(img, size=(image_size, image_size))
        return img
    def dali_random_gaussian_blur(img, window_size):
        img = fn.gaussian_blur(img, window_size=window_size * 2 + 1)
        return img
    def dali_random_gray(img, prob_gray):
        saturate = fn.random.coin_flip(probability=1 - prob_gray)
        saturate = fn.cast(saturate, dtype=types.FLOAT)
        img = fn.hsv(img, saturation=saturate)
        return img
    def dali_random_hsv(img, hue, saturation):
        img = fn.hsv(img, hue=hue, saturation=saturation)
        return img
    def multiplexing(condition, true_case, false_case):
        neg_condition = condition ^ True
        return condition * true_case + neg_condition * false_case

    condition_resize = fn.random.coin_flip(probability=0.1)
    size_resize = fn.random.uniform(range=(int(112 * 0.5), int(112 * 0.8)), dtype=types.FLOAT)
    condition_blur = fn.random.coin_flip(probability=0.2)
    window_size_blur = fn.random.uniform(range=(1, 2), dtype=types.INT32)
    condition_flip = fn.random.coin_flip(probability=0.5)
    condition_hsv = fn.random.coin_flip(probability=0.2)
    hsv_hue = fn.random.uniform(range=(0., 20.), dtype=types.FLOAT)
    hsv_saturation = fn.random.uniform(range=(1., 1.2), dtype=types.FLOAT)

    pipe = Pipeline(
        batch_size=batch_size, num_threads=num_threads,
        device_id=local_rank, prefetch_queue_depth=prefetch_queue_depth, )
@@ -219,6 +250,13 @@ def dali_data_iter(
            num_shards=world_size, shard_id=rank,
            random_shuffle=random_shuffle, pad_last_batch=False, name=name)
        images = fn.decoders.image(jpegs, device="mixed", output_type=types.RGB)
        if dali_aug:
            images = fn.cast(images, dtype=types.UINT8)
            images = multiplexing(condition_resize, dali_random_resize(images, size_resize, image_size=112), images)
            images = multiplexing(condition_blur, dali_random_gaussian_blur(images, window_size_blur), images)
            images = multiplexing(condition_hsv, dali_random_hsv(images, hsv_hue, hsv_saturation), images)
            images = dali_random_gray(images, 0.1)

        images = fn.crop_mirror_normalize(
            images, dtype=types.FLOAT, mean=mean, std=std, mirror=condition_flip)
        pipe.set_outputs(images, labels)
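The `multiplexing` helper above is how a per-sample "if" is expressed inside the DALI graph: the 0/1 coin-flip output blends the augmented and the original batch arithmetically instead of branching. A plain-NumPy illustration of the same idea (illustrative only, not part of the pipeline):

```python
# Stand-alone illustration of the multiplexing trick used in the DALI graph:
# a boolean condition selects, per sample, between the augmented and the
# original image by arithmetic blending rather than control flow.
import numpy as np

def multiplexing(condition, true_case, false_case):
    neg_condition = condition ^ True          # flip the 0/1 mask
    return condition * true_case + neg_condition * false_case

condition = np.random.rand(8, 1, 1, 1) < 0.2  # roughly 20% of samples get augmented
original = np.random.rand(8, 112, 112, 3)
augmented = original[:, ::-1, :, :]           # stand-in for a real augmentation
blended = multiplexing(condition, augmented, original)
```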
@@ -1,15 +1,14 @@
ip_list=("ip1" "ip2" "ip3" "ip4")

config=wf42m_pfc03_32gpu_r100
ip_list=("ip1" "ip2" "ip3" "ip4" "ip5")
config=wf42m_pfc02_vit_h.py

for((node_rank=0;node_rank<${#ip_list[*]};node_rank++));
do
ssh ubuntu@${ip_list[node_rank]} "cd `pwd`;PATH=$PATH \
ssh root@${ip_list[node_rank]} "cd `pwd`;PATH=$PATH \
    CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \
    torchrun \
    --nproc_per_node=8 \
    --nnodes=${#ip_list[*]} \
    --node_rank=$node_rank \
    --master_addr=${ip_list[0]} \
    --master_port=22345 train.py configs/$config" &
    --master_port=22345 train_v2.py configs/$config" &
done
@@ -1,30 +1,86 @@
from torch.optim.lr_scheduler import _LRScheduler
from torch.optim import SGD
import torch
import warnings

class PolynomialLRWarmup(_LRScheduler):
    def __init__(self, optimizer, warmup_iters, total_iters=5, power=1.0, last_epoch=-1, verbose=False):
        super().__init__(optimizer, last_epoch=last_epoch, verbose=verbose)
        self.total_iters = total_iters
        self.power = power
        self.warmup_iters = warmup_iters

class PolyScheduler(_LRScheduler):
    def __init__(self, optimizer, base_lr, max_steps, warmup_steps, last_epoch=-1):
        self.base_lr = base_lr
        self.warmup_lr_init = 0.0001
        self.max_steps: int = max_steps
        self.warmup_steps: int = warmup_steps
        self.power = 2
        super(PolyScheduler, self).__init__(optimizer, -1, False)
        self.last_epoch = last_epoch

    def get_warmup_lr(self):
        alpha = float(self.last_epoch) / float(self.warmup_steps)
        return [self.base_lr * alpha for _ in self.optimizer.param_groups]
    def get_lr(self):
        if self.last_epoch == -1:
            return [self.warmup_lr_init for _ in self.optimizer.param_groups]
        if self.last_epoch < self.warmup_steps:
            return self.get_warmup_lr()
        if not self._get_lr_called_within_step:
            warnings.warn("To get the last learning rate computed by the scheduler, "
                          "please use `get_last_lr()`.", UserWarning)

        if self.last_epoch == 0 or self.last_epoch > self.total_iters:
            return [group["lr"] for group in self.optimizer.param_groups]

        if self.last_epoch <= self.warmup_iters:
            return [base_lr * self.last_epoch / self.warmup_iters for base_lr in self.base_lrs]
        else:
            l = self.last_epoch
            w = self.warmup_iters
            t = self.total_iters
            decay_factor = ((1.0 - (l - w) / (t - w)) / (1.0 - (l - 1 - w) / (t - w))) ** self.power
            return [group["lr"] * decay_factor for group in self.optimizer.param_groups]

    def _get_closed_form_lr(self):

        if self.last_epoch <= self.warmup_iters:
            return [
                base_lr * self.last_epoch / self.warmup_iters for base_lr in self.base_lrs]
        else:
            alpha = pow(
                1
                - float(self.last_epoch - self.warmup_steps)
                / float(self.max_steps - self.warmup_steps),
                self.power,
            )
            return [self.base_lr * alpha for _ in self.optimizer.param_groups]
            return [
                (
                    base_lr * (1.0 - (min(self.total_iters, self.last_epoch) - self.warmup_iters) / (self.total_iters - self.warmup_iters)) ** self.power
                )
                for base_lr in self.base_lrs
            ]
if __name__ == "__main__":

    class TestModule(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.linear = torch.nn.Linear(32, 32)

        def forward(self, x):
            return self.linear(x)

    test_module = TestModule()
    test_module_pfc = TestModule()
    lr_pfc_weight = 1 / 3
    base_lr = 10
    total_steps = 1000

    sgd = SGD([
        {"params": test_module.parameters(), "lr": base_lr},
        {"params": test_module_pfc.parameters(), "lr": base_lr * lr_pfc_weight}
    ], base_lr)

    scheduler = PolynomialLRWarmup(sgd, total_steps//10, total_steps, power=2)

    x = []
    y = []
    y_pfc = []
    for i in range(total_steps):
        scheduler.step()
        lr = scheduler.get_last_lr()[0]
        lr_pfc = scheduler.get_last_lr()[1]
        x.append(i)
        y.append(lr)
        y_pfc.append(lr_pfc)

    import matplotlib.pyplot as plt
    fontsize=15
    plt.figure(figsize=(6, 6))
    plt.plot(x, y, linestyle='-', linewidth=2, )
    plt.plot(x, y_pfc, linestyle='-', linewidth=2, )
    plt.xlabel('Iterations')  # x_label
    plt.ylabel("Lr")  # y_label
    plt.savefig("tmp.png", dpi=600, bbox_inches='tight')
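Written out, the closed-form schedule implemented by `_get_closed_form_lr` above is, for a parameter group with base rate eta_base, warmup length w = `warmup_iters`, horizon T = `total_iters` and exponent p = `power`:

```latex
\[
\eta(t) =
\begin{cases}
\eta_{\text{base}} \cdot \dfrac{t}{w}, & t \le w, \\[1.2ex]
\eta_{\text{base}} \cdot \left(1 - \dfrac{\min(t, T) - w}{T - w}\right)^{p}, & t > w.
\end{cases}
\]
```

The step-wise `get_lr` above reaches the same values multiplicatively, scaling the previous rate by the ratio of two consecutive closed-form factors.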
@@ -1,531 +0,0 @@
import collections
from typing import Callable

import torch
from torch import distributed
from torch.nn.functional import linear, normalize


class PartialFC(torch.nn.Module):
    """
    https://arxiv.org/abs/2203.15565
    A distributed sparsely updating variant of the FC layer, named Partial FC (PFC).

    When the sample rate is less than 1, in each iteration, positive class centers and a random subset of
    negative class centers are selected to compute the margin-based softmax loss; all class
    centers are still maintained throughout the whole training process, but only a subset is
    selected and updated in each iteration.

    .. note::
        When the sample rate is equal to 1, Partial FC is equal to model parallelism (the default sample rate is 1).

    Example:
    --------
    >>> module_pfc = PartialFC(embedding_size=512, num_classes=8000000, sample_rate=0.2)
    >>> for img, labels in data_loader:
    >>>     embeddings = net(img)
    >>>     loss = module_pfc(embeddings, labels, optimizer)
    >>>     loss.backward()
    >>>     optimizer.step()
    """
    _version = 1
    def __init__(
        self,
        margin_loss: Callable,
        embedding_size: int,
        num_classes: int,
        sample_rate: float = 1.0,
        fp16: bool = False,
    ):
        """
        Parameters:
        -----------
        embedding_size: int
            The dimension of embedding, required
        num_classes: int
            Total number of classes, required
        sample_rate: float
            The rate of negative centers participating in the calculation, default is 1.0.
        """
        super(PartialFC, self).__init__()
        assert (
            distributed.is_initialized()
        ), "must initialize distributed before create this"
        self.rank = distributed.get_rank()
        self.world_size = distributed.get_world_size()

        self.dist_cross_entropy = DistCrossEntropy()
        self.embedding_size = embedding_size
        self.sample_rate: float = sample_rate
        self.fp16 = fp16
        self.num_local: int = num_classes // self.world_size + int(
            self.rank < num_classes % self.world_size
        )
        self.class_start: int = num_classes // self.world_size * self.rank + min(
            self.rank, num_classes % self.world_size
        )
        self.num_sample: int = int(self.sample_rate * self.num_local)
        self.last_batch_size: int = 0
        self.weight: torch.Tensor
        self.weight_mom: torch.Tensor
        self.weight_activated: torch.nn.Parameter
        self.weight_activated_mom: torch.Tensor
        self.is_updated: bool = True
        self.init_weight_update: bool = True

        if self.sample_rate < 1:
            self.register_buffer("weight",
                tensor=torch.normal(0, 0.01, (self.num_local, embedding_size)))
            self.register_buffer("weight_mom",
                tensor=torch.zeros_like(self.weight))
            self.register_parameter("weight_activated",
                param=torch.nn.Parameter(torch.empty(0, 0)))
            self.register_buffer("weight_activated_mom",
                tensor=torch.empty(0, 0))
            self.register_buffer("weight_index",
                tensor=torch.empty(0, 0))
        else:
            self.weight_activated = torch.nn.Parameter(torch.normal(0, 0.01, (self.num_local, embedding_size)))

        # margin_loss
        if isinstance(margin_loss, Callable):
            self.margin_softmax = margin_loss
        else:
            raise
    @torch.no_grad()
    def sample(self,
        labels: torch.Tensor,
        index_positive: torch.Tensor,
        optimizer: torch.optim.Optimizer):
        """
        This function will change the value of labels

        Parameters:
        -----------
        labels: torch.Tensor
            pass
        index_positive: torch.Tensor
            pass
        optimizer: torch.optim.Optimizer
            pass
        """
        positive = torch.unique(labels[index_positive], sorted=True).cuda()
        if self.num_sample - positive.size(0) >= 0:
            perm = torch.rand(size=[self.num_local]).cuda()
            perm[positive] = 2.0
            index = torch.topk(perm, k=self.num_sample)[1].cuda()
            index = index.sort()[0].cuda()
        else:
            index = positive
        self.weight_index = index

        labels[index_positive] = torch.searchsorted(index, labels[index_positive])

        self.weight_activated = torch.nn.Parameter(self.weight[self.weight_index])
        self.weight_activated_mom = self.weight_mom[self.weight_index]

        if isinstance(optimizer, torch.optim.SGD):
            # TODO the params of partial fc must be last in the params list
            optimizer.state.pop(optimizer.param_groups[-1]["params"][0], None)
            optimizer.param_groups[-1]["params"][0] = self.weight_activated
            optimizer.state[self.weight_activated][
                "momentum_buffer"
            ] = self.weight_activated_mom
        else:
            raise

    @torch.no_grad()
    def update(self):
        """ partial weight to global
        """
        if self.init_weight_update:
            self.init_weight_update = False
            return

        if self.sample_rate < 1:
            self.weight[self.weight_index] = self.weight_activated
            self.weight_mom[self.weight_index] = self.weight_activated_mom
    def forward(
        self,
        local_embeddings: torch.Tensor,
        local_labels: torch.Tensor,
        optimizer: torch.optim.Optimizer,
    ):
        """
        Parameters:
        ----------
        local_embeddings: torch.Tensor
            feature embeddings on each GPU(Rank).
        local_labels: torch.Tensor
            labels on each GPU(Rank).

        Returns:
        -------
        loss: torch.Tensor
            pass
        """
        local_labels.squeeze_()
        local_labels = local_labels.long()
        self.update()

        batch_size = local_embeddings.size(0)
        if self.last_batch_size == 0:
            self.last_batch_size = batch_size
        assert self.last_batch_size == batch_size, (
            "last batch size do not equal current batch size: {} vs {}".format(
                self.last_batch_size, batch_size))

        _gather_embeddings = [
            torch.zeros((batch_size, self.embedding_size)).cuda()
            for _ in range(self.world_size)
        ]
        _gather_labels = [
            torch.zeros(batch_size).long().cuda() for _ in range(self.world_size)
        ]
        _list_embeddings = AllGather(local_embeddings, *_gather_embeddings)
        distributed.all_gather(_gather_labels, local_labels)

        embeddings = torch.cat(_list_embeddings)
        labels = torch.cat(_gather_labels)

        labels = labels.view(-1, 1)
        index_positive = (self.class_start <= labels) & (
            labels < self.class_start + self.num_local
        )
        labels[~index_positive] = -1
        labels[index_positive] -= self.class_start

        if self.sample_rate < 1:
            self.sample(labels, index_positive, optimizer)

        with torch.cuda.amp.autocast(self.fp16):
            norm_embeddings = normalize(embeddings)
            norm_weight_activated = normalize(self.weight_activated)
            logits = linear(norm_embeddings, norm_weight_activated)
        if self.fp16:
            logits = logits.float()
        logits = logits.clamp(-1, 1)

        logits = self.margin_softmax(logits, labels)
        loss = self.dist_cross_entropy(logits, labels)
        return loss
    def state_dict(self, destination=None, prefix="", keep_vars=False):
        if destination is None:
            destination = collections.OrderedDict()
            destination._metadata = collections.OrderedDict()

        for name, module in self._modules.items():
            if module is not None:
                module.state_dict(destination, prefix + name + ".", keep_vars=keep_vars)
        if self.sample_rate < 1:
            destination["weight"] = self.weight.detach()
        else:
            destination["weight"] = self.weight_activated.data.detach()
        return destination

    def load_state_dict(self, state_dict, strict: bool = True):
        if self.sample_rate < 1:
            self.weight = state_dict["weight"].to(self.weight.device)
            self.weight_mom.zero_()
            self.weight_activated.data.zero_()
            self.weight_activated_mom.zero_()
            self.weight_index.zero_()
        else:
            self.weight_activated.data = state_dict["weight"].to(self.weight_activated.data.device)
class PartialFCAdamW(torch.nn.Module):
    def __init__(self,
        margin_loss: Callable,
        embedding_size: int,
        num_classes: int,
        sample_rate: float = 1.0,
        fp16: bool = False,):
        """
        Parameters:
        -----------
        embedding_size: int
            The dimension of embedding, required
        num_classes: int
            Total number of classes, required
        sample_rate: float
            The rate of negative centers participating in the calculation, default is 1.0.
        """
        super(PartialFCAdamW, self).__init__()
        assert (
            distributed.is_initialized()
        ), "must initialize distributed before create this"
        self.rank = distributed.get_rank()
        self.world_size = distributed.get_world_size()

        self.dist_cross_entropy = DistCrossEntropy()
        self.embedding_size = embedding_size
        self.sample_rate: float = sample_rate
        self.fp16 = fp16
        self.num_local: int = num_classes // self.world_size + int(
            self.rank < num_classes % self.world_size
        )
        self.class_start: int = num_classes // self.world_size * self.rank + min(
            self.rank, num_classes % self.world_size
        )
        self.num_sample: int = int(self.sample_rate * self.num_local)
        self.last_batch_size: int = 0
        self.weight: torch.Tensor
        self.weight_exp_avg: torch.Tensor
        self.weight_exp_avg_sq: torch.Tensor
        self.weight_activated: torch.nn.Parameter
        self.weight_activated_exp_avg: torch.Tensor
        self.weight_activated_exp_avg_sq: torch.Tensor

        self.is_updated: bool = True
        self.init_weight_update: bool = True

        if self.sample_rate < 1:
            self.register_buffer("weight",
                tensor=torch.normal(0, 0.01, (self.num_local, embedding_size)))
            self.register_buffer("weight_exp_avg",
                tensor=torch.zeros_like(self.weight))
            self.register_buffer("weight_exp_avg_sq",
                tensor=torch.zeros_like(self.weight))
            self.register_parameter("weight_activated",
                param=torch.nn.Parameter(torch.empty(0, 0)))
            self.register_buffer("weight_activated_exp_avg",
                tensor=torch.empty(0, 0))
            self.register_buffer("weight_activated_exp_avg_sq",
                tensor=torch.empty(0, 0))
        else:
            self.weight_activated = torch.nn.Parameter(
                torch.normal(0, 0.01, (self.num_local, embedding_size))
            )
        self.step = 0

        if isinstance(margin_loss, Callable):
            self.margin_softmax = margin_loss
        else:
            raise
    @torch.no_grad()
    def sample(self, labels, index_positive, optimizer):
        self.step += 1
        positive = torch.unique(labels[index_positive], sorted=True).cuda()
        if self.num_sample - positive.size(0) >= 0:
            perm = torch.rand(size=[self.num_local]).cuda()
            perm[positive] = 2.0
            index = torch.topk(perm, k=self.num_sample)[1].cuda()
            index = index.sort()[0].cuda()
        else:
            index = positive
        self.weight_index = index
        labels[index_positive] = torch.searchsorted(index, labels[index_positive])
        self.weight_activated = torch.nn.Parameter(self.weight[self.weight_index])
        self.weight_activated_exp_avg = self.weight_exp_avg[self.weight_index]
        self.weight_activated_exp_avg_sq = self.weight_exp_avg_sq[self.weight_index]

        if isinstance(optimizer, (torch.optim.Adam, torch.optim.AdamW)):
            # TODO the params of partial fc must be last in the params list
            optimizer.state.pop(optimizer.param_groups[-1]["params"][0], None)
            optimizer.param_groups[-1]["params"][0] = self.weight_activated
            optimizer.state[self.weight_activated]["exp_avg"] = self.weight_activated_exp_avg
            optimizer.state[self.weight_activated]["exp_avg_sq"] = self.weight_activated_exp_avg_sq
            optimizer.state[self.weight_activated]["step"] = self.step
        else:
            raise

    @torch.no_grad()
    def update(self):
        """ partial weight to global
        """
        if self.init_weight_update:
            self.init_weight_update = False
            return

        if self.sample_rate < 1:
            self.weight[self.weight_index] = self.weight_activated
            self.weight_exp_avg[self.weight_index] = self.weight_activated_exp_avg
            self.weight_exp_avg_sq[self.weight_index] = self.weight_activated_exp_avg_sq
    def forward(
        self,
        local_embeddings: torch.Tensor,
        local_labels: torch.Tensor,
        optimizer: torch.optim.Optimizer,
    ):
        """
        Parameters:
        ----------
        local_embeddings: torch.Tensor
            feature embeddings on each GPU(Rank).
        local_labels: torch.Tensor
            labels on each GPU(Rank).

        Returns:
        -------
        loss: torch.Tensor
            pass
        """
        local_labels.squeeze_()
        local_labels = local_labels.long()
        self.update()

        batch_size = local_embeddings.size(0)
        if self.last_batch_size == 0:
            self.last_batch_size = batch_size
        assert self.last_batch_size == batch_size, (
            "last batch size do not equal current batch size: {} vs {}".format(
                self.last_batch_size, batch_size))

        _gather_embeddings = [
            torch.zeros((batch_size, self.embedding_size)).cuda()
            for _ in range(self.world_size)
        ]
        _gather_labels = [
            torch.zeros(batch_size).long().cuda() for _ in range(self.world_size)
        ]
        _list_embeddings = AllGather(local_embeddings, *_gather_embeddings)
        distributed.all_gather(_gather_labels, local_labels)

        embeddings = torch.cat(_list_embeddings)
        labels = torch.cat(_gather_labels)

        labels = labels.view(-1, 1)
        index_positive = (self.class_start <= labels) & (
            labels < self.class_start + self.num_local
        )
        labels[~index_positive] = -1
        labels[index_positive] -= self.class_start

        if self.sample_rate < 1:
            self.sample(labels, index_positive, optimizer)

        with torch.cuda.amp.autocast(self.fp16):
            norm_embeddings = normalize(embeddings)
            norm_weight_activated = normalize(self.weight_activated)
            logits = linear(norm_embeddings, norm_weight_activated)
        if self.fp16:
            logits = logits.float()
        logits = logits.clamp(-1, 1)

        logits = self.margin_softmax(logits, labels)
        loss = self.dist_cross_entropy(logits, labels)
        return loss
    def state_dict(self, destination=None, prefix="", keep_vars=False):
        if destination is None:
            destination = collections.OrderedDict()
            destination._metadata = collections.OrderedDict()

        for name, module in self._modules.items():
            if module is not None:
                module.state_dict(destination, prefix + name + ".", keep_vars=keep_vars)
        if self.sample_rate < 1:
            destination["weight"] = self.weight.detach()
        else:
            destination["weight"] = self.weight_activated.data.detach()
        return destination

    def load_state_dict(self, state_dict, strict: bool = True):
        if self.sample_rate < 1:
            self.weight = state_dict["weight"].to(self.weight.device)
            self.weight_exp_avg.zero_()
            self.weight_exp_avg_sq.zero_()
            self.weight_activated.data.zero_()
            self.weight_activated_exp_avg.zero_()
            self.weight_activated_exp_avg_sq.zero_()
        else:
            self.weight_activated.data = state_dict["weight"].to(self.weight_activated.data.device)
class DistCrossEntropyFunc(torch.autograd.Function):
    """
    CrossEntropy loss is calculated in parallel: the denominator is obtained with an allreduce and the softmax is then calculated locally.
    Implementation of ArcFace (https://arxiv.org/pdf/1801.07698v1.pdf):
    """

    @staticmethod
    def forward(ctx, logits: torch.Tensor, label: torch.Tensor):
        """ """
        batch_size = logits.size(0)
        # for numerical stability
        max_logits, _ = torch.max(logits, dim=1, keepdim=True)
        # local to global
        distributed.all_reduce(max_logits, distributed.ReduceOp.MAX)
        logits.sub_(max_logits)
        logits.exp_()
        sum_logits_exp = torch.sum(logits, dim=1, keepdim=True)
        # local to global
        distributed.all_reduce(sum_logits_exp, distributed.ReduceOp.SUM)
        logits.div_(sum_logits_exp)
        index = torch.where(label != -1)[0]
        # loss
        loss = torch.zeros(batch_size, 1, device=logits.device)
        loss[index] = logits[index].gather(1, label[index])
        distributed.all_reduce(loss, distributed.ReduceOp.SUM)
        ctx.save_for_backward(index, logits, label)
        return loss.clamp_min_(1e-30).log_().mean() * (-1)

    @staticmethod
    def backward(ctx, loss_gradient):
        """
        Args:
            loss_grad (torch.Tensor): gradient backward by last layer
        Returns:
            gradients for each input in forward function
            `None` gradients for one-hot label
        """
        (
            index,
            logits,
            label,
        ) = ctx.saved_tensors
        batch_size = logits.size(0)
        one_hot = torch.zeros(
            size=[index.size(0), logits.size(1)], device=logits.device
        )
        one_hot.scatter_(1, label[index], 1)
        logits[index] -= one_hot
        logits.div_(batch_size)
        return logits * loss_gradient.item(), None
class DistCrossEntropy(torch.nn.Module):
    def __init__(self):
        super(DistCrossEntropy, self).__init__()

    def forward(self, logit_part, label_part):
        return DistCrossEntropyFunc.apply(logit_part, label_part)


class AllGatherFunc(torch.autograd.Function):
    """AllGather op with gradient backward"""

    @staticmethod
    def forward(ctx, tensor, *gather_list):
        gather_list = list(gather_list)
        distributed.all_gather(gather_list, tensor)
        return tuple(gather_list)

    @staticmethod
    def backward(ctx, *grads):
        grad_list = list(grads)
        rank = distributed.get_rank()
        grad_out = grad_list[rank]

        dist_ops = [
            distributed.reduce(grad_out, rank, distributed.ReduceOp.SUM, async_op=True)
            if i == rank
            else distributed.reduce(
                grad_list[i], i, distributed.ReduceOp.SUM, async_op=True
            )
            for i in range(distributed.get_world_size())
        ]
        for _op in dist_ops:
            _op.wait()

        grad_out *= len(grad_list)  # cooperate with distributed loss function
        return (grad_out, *[None for _ in range(len(grad_list))])


AllGather = AllGatherFunc.apply
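To make the class-center sharding in the two `__init__` methods above concrete, here is a small sketch that reproduces the arithmetic outside the module (the 16-rank world size and the 0.2 sample rate are illustrative, the latter matching the docstring example; `num_classes` is the WebFace42M value used elsewhere in this commit):

```python
# Reproduces the class-center sharding and sampling sizes computed in
# PartialFC.__init__ above, without needing torch.distributed.
num_classes = 2059906      # WebFace42M, as in the new config
world_size = 16            # illustrative number of ranks
sample_rate = 0.2          # illustrative, matches the docstring example

for rank in range(world_size):
    # classes hosted on this rank, and where its shard starts
    num_local = num_classes // world_size + int(rank < num_classes % world_size)
    class_start = num_classes // world_size * rank + min(rank, num_classes % world_size)
    # how many (positive + sampled negative) centers join each iteration
    num_sample = int(sample_rate * num_local)
    if rank < 3:
        print(rank, num_local, class_start, num_sample)
# rank 0 hosts 128745 centers starting at class 0 and samples 25749 of them per step
```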
@@ -1,260 +0,0 @@
import argparse
import logging
import os
from datetime import datetime

import numpy as np
import torch
from backbones import get_model
from dataset import get_dataloader
from losses import CombinedMarginLoss
from lr_scheduler import PolyScheduler
from partial_fc import PartialFC, PartialFCAdamW
from torch import distributed
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from utils.utils_callbacks import CallBackLogging, CallBackVerification
from utils.utils_config import get_config
from utils.utils_distributed_sampler import setup_seed
from utils.utils_logging import AverageMeter, init_logging

assert torch.__version__ >= "1.12.0", "In order to enjoy the features of the new torch, \
we have upgraded the torch to 1.12.0. torch before than 1.12.0 may not work in the future."

try:
    rank = int(os.environ["RANK"])
    local_rank = int(os.environ["LOCAL_RANK"])
    world_size = int(os.environ["WORLD_SIZE"])
    distributed.init_process_group("nccl")
except KeyError:
    rank = 0
    local_rank = 0
    world_size = 1
    distributed.init_process_group(
        backend="nccl",
        init_method="tcp://127.0.0.1:12584",
        rank=rank,
        world_size=world_size,
    )
def main(args):

    # get config
    cfg = get_config(args.config)
    # global control random seed
    setup_seed(seed=cfg.seed, cuda_deterministic=False)

    torch.cuda.set_device(local_rank)

    os.makedirs(cfg.output, exist_ok=True)
    init_logging(rank, cfg.output)

    summary_writer = (
        SummaryWriter(log_dir=os.path.join(cfg.output, "tensorboard"))
        if rank == 0
        else None
    )

    wandb_logger = None
    if cfg.using_wandb:
        import wandb
        # Sign in to wandb
        try:
            wandb.login(key=cfg.wandb_key)
        except Exception as e:
            print("WandB Key must be provided in config file (base.py).")
            print(f"Config Error: {e}")
        # Initialize wandb
        run_name = datetime.now().strftime("%y%m%d_%H%M") + f"_GPU{rank}"
        run_name = run_name if cfg.suffix_run_name is None else run_name + f"_{cfg.suffix_run_name}"
        try:
            wandb_logger = wandb.init(
                entity = cfg.wandb_entity,
                project = cfg.wandb_project,
                sync_tensorboard = True,
                resume=cfg.wandb_resume,
                name = run_name,
                notes = cfg.notes) if rank == 0 or cfg.wandb_log_all else None
            if wandb_logger:
                wandb_logger.config.update(cfg)
        except Exception as e:
            print("WandB Data (Entity and Project name) must be provided in config file (base.py).")
            print(f"Config Error: {e}")

    train_loader = get_dataloader(
        cfg.rec,
        local_rank,
        cfg.batch_size,
        cfg.dali,
        cfg.seed,
        cfg.num_workers
    )
    backbone = get_model(
        cfg.network, dropout=0.0, fp16=cfg.fp16, num_features=cfg.embedding_size).cuda()

    backbone = torch.nn.parallel.DistributedDataParallel(
        module=backbone, broadcast_buffers=False, device_ids=[local_rank], bucket_cap_mb=16,
        find_unused_parameters=True)

    backbone.train()
    # FIXME using gradient checkpoint if there are some unused parameters will cause error
    backbone._set_static_graph()

    margin_loss = CombinedMarginLoss(
        64,
        cfg.margin_list[0],
        cfg.margin_list[1],
        cfg.margin_list[2],
        cfg.interclass_filtering_threshold
    )

    if cfg.optimizer == "sgd":
        module_partial_fc = PartialFC(
            margin_loss, cfg.embedding_size, cfg.num_classes,
            cfg.sample_rate, cfg.fp16)
        module_partial_fc.train().cuda()
        # TODO the params of partial fc must be last in the params list
        opt = torch.optim.SGD(
            params=[{"params": backbone.parameters()}, {"params": module_partial_fc.parameters()}],
            lr=cfg.lr, momentum=0.9, weight_decay=cfg.weight_decay)

    elif cfg.optimizer == "adamw":
        module_partial_fc = PartialFCAdamW(
            margin_loss, cfg.embedding_size, cfg.num_classes,
            cfg.sample_rate, cfg.fp16)
        module_partial_fc.train().cuda()
        opt = torch.optim.AdamW(
            params=[{"params": backbone.parameters()}, {"params": module_partial_fc.parameters()}],
            lr=cfg.lr, weight_decay=cfg.weight_decay)
    else:
        raise

    cfg.total_batch_size = cfg.batch_size * world_size
    cfg.warmup_step = cfg.num_image // cfg.total_batch_size * cfg.warmup_epoch
    cfg.total_step = cfg.num_image // cfg.total_batch_size * cfg.num_epoch

    lr_scheduler = PolyScheduler(
        optimizer=opt,
        base_lr=cfg.lr,
        max_steps=cfg.total_step,
        warmup_steps=cfg.warmup_step,
        last_epoch=-1
    )
    start_epoch = 0
    global_step = 0
    if cfg.resume:
        dict_checkpoint = torch.load(os.path.join(cfg.output, f"checkpoint_gpu_{rank}.pt"))
        start_epoch = dict_checkpoint["epoch"]
        global_step = dict_checkpoint["global_step"]
        backbone.module.load_state_dict(dict_checkpoint["state_dict_backbone"])
        module_partial_fc.load_state_dict(dict_checkpoint["state_dict_softmax_fc"])
        opt.load_state_dict(dict_checkpoint["state_optimizer"])
        lr_scheduler.load_state_dict(dict_checkpoint["state_lr_scheduler"])
        del dict_checkpoint

    for key, value in cfg.items():
        num_space = 25 - len(key)
        logging.info(": " + key + " " * num_space + str(value))

    callback_verification = CallBackVerification(
        val_targets=cfg.val_targets, rec_prefix=cfg.rec,
        summary_writer=summary_writer, wandb_logger = wandb_logger
    )
    callback_logging = CallBackLogging(
        frequent=cfg.frequent,
        total_step=cfg.total_step,
        batch_size=cfg.batch_size,
        start_step = global_step,
        writer=summary_writer
    )

    loss_am = AverageMeter()
    amp = torch.cuda.amp.grad_scaler.GradScaler(growth_interval=100)
    for epoch in range(start_epoch, cfg.num_epoch):

        if isinstance(train_loader, DataLoader):
            train_loader.sampler.set_epoch(epoch)
        for _, (img, local_labels) in enumerate(train_loader):
            global_step += 1
            local_embeddings = backbone(img)
            loss: torch.Tensor = module_partial_fc(local_embeddings, local_labels, opt)

            if cfg.fp16:
                amp.scale(loss).backward()
                amp.unscale_(opt)
                torch.nn.utils.clip_grad_norm_(backbone.parameters(), 5)
                amp.step(opt)
                amp.update()
            else:
                loss.backward()
                torch.nn.utils.clip_grad_norm_(backbone.parameters(), 5)
                opt.step()

            opt.zero_grad()
            lr_scheduler.step()

            with torch.no_grad():
                if wandb_logger:
                    wandb_logger.log({
                        'Loss/Step Loss': loss.item(),
                        'Loss/Train Loss': loss_am.avg,
                        'Process/Step': global_step,
                        'Process/Epoch': epoch
                    })

                loss_am.update(loss.item(), 1)
                callback_logging(global_step, loss_am, epoch, cfg.fp16, lr_scheduler.get_last_lr()[0], amp)

                if global_step % cfg.verbose == 0 and global_step > 0:
                    callback_verification(global_step, backbone)

        if cfg.save_all_states:
            checkpoint = {
                "epoch": epoch + 1,
                "global_step": global_step,
                "state_dict_backbone": backbone.module.state_dict(),
                "state_dict_softmax_fc": module_partial_fc.state_dict(),
                "state_optimizer": opt.state_dict(),
                "state_lr_scheduler": lr_scheduler.state_dict()
            }
            torch.save(checkpoint, os.path.join(cfg.output, f"checkpoint_gpu_{rank}.pt"))
        if rank == 0:
            path_module = os.path.join(cfg.output, "model.pt")
            torch.save(backbone.module.state_dict(), path_module)

            if wandb_logger and cfg.save_artifacts:
                artifact_name = f"{run_name}_E{epoch}"
                model = wandb.Artifact(artifact_name, type='model')
                model.add_file(path_module)
                wandb_logger.log_artifact(model)

        if cfg.dali:
            train_loader.reset()

    if rank == 0:
        path_module = os.path.join(cfg.output, "model.pt")
        torch.save(backbone.module.state_dict(), path_module)

        from torch2onnx import convert_onnx
        convert_onnx(backbone.module.cpu().eval(), path_module, os.path.join(cfg.output, "model.onnx"))

        if wandb_logger and cfg.save_artifacts:
            artifact_name = f"{run_name}_Final"
            model = wandb.Artifact(artifact_name, type='model')
            model.add_file(path_module)
            wandb_logger.log_artifact(model)

    distributed.destroy_process_group()
if __name__ == "__main__":
    torch.backends.cudnn.benchmark = True
    parser = argparse.ArgumentParser(
        description="Distributed Arcface Training in Pytorch")
    parser.add_argument("config", type=str, help="py config file")
    main(parser.parse_args())
@@ -8,7 +8,7 @@ import torch
from backbones import get_model
from dataset import get_dataloader
from losses import CombinedMarginLoss
from lr_scheduler import PolyScheduler
from lr_scheduler import PolynomialLRWarmup
from partial_fc_v2 import PartialFC_V2
from torch import distributed
from torch.utils.data import DataLoader
@@ -17,6 +17,7 @@ from utils.utils_callbacks import CallBackLogging, CallBackVerification
from utils.utils_config import get_config
from utils.utils_distributed_sampler import setup_seed
from utils.utils_logging import AverageMeter, init_logging
from torch.distributed.algorithms.ddp_comm_hooks.default_hooks import fp16_compress_hook

assert torch.__version__ >= "1.12.0", "In order to enjoy the features of the new torch, \
we have upgraded the torch to 1.12.0. torch before than 1.12.0 may not work in the future."
@@ -81,12 +82,12 @@ def main(args):
        except Exception as e:
            print("WandB Data (Entity and Project name) must be provided in config file (base.py).")
            print(f"Config Error: {e}")

    train_loader = get_dataloader(
        cfg.rec,
        local_rank,
        cfg.batch_size,
        cfg.dali,
        cfg.dali_aug,
        cfg.seed,
        cfg.num_workers
    )
@@ -97,6 +98,7 @@ def main(args):
    backbone = torch.nn.parallel.DistributedDataParallel(
        module=backbone, broadcast_buffers=False, device_ids=[local_rank], bucket_cap_mb=16,
        find_unused_parameters=True)
    backbone.register_comm_hook(None, fp16_compress_hook)

    backbone.train()
    # FIXME using gradient checkpoint if there are some unused parameters will cause error
@@ -113,7 +115,7 @@ def main(args):
    if cfg.optimizer == "sgd":
        module_partial_fc = PartialFC_V2(
            margin_loss, cfg.embedding_size, cfg.num_classes,
            cfg.sample_rate, cfg.fp16)
            cfg.sample_rate, False)
        module_partial_fc.train().cuda()
        # TODO the params of partial fc must be last in the params list
        opt = torch.optim.SGD(
@@ -123,7 +125,7 @@ def main(args):
    elif cfg.optimizer == "adamw":
        module_partial_fc = PartialFC_V2(
            margin_loss, cfg.embedding_size, cfg.num_classes,
            cfg.sample_rate, cfg.fp16)
            cfg.sample_rate, False)
        module_partial_fc.train().cuda()
        opt = torch.optim.AdamW(
            params=[{"params": backbone.parameters()}, {"params": module_partial_fc.parameters()}],
@@ -135,13 +137,10 @@ def main(args):
    cfg.warmup_step = cfg.num_image // cfg.total_batch_size * cfg.warmup_epoch
    cfg.total_step = cfg.num_image // cfg.total_batch_size * cfg.num_epoch

    lr_scheduler = PolyScheduler(
    lr_scheduler = PolynomialLRWarmup(
        optimizer=opt,
        base_lr=cfg.lr,
        max_steps=cfg.total_step,
        warmup_steps=cfg.warmup_step,
        last_epoch=-1
    )
        warmup_iters=cfg.warmup_step,
        total_iters=cfg.total_step)

    start_epoch = 0
    global_step = 0