Mirror of https://github.com/deepinsight/insightface.git
Update training log.
@@ -41,6 +41,25 @@ class MarginSoftmax(nn.Module):
         return ret
 
 
+class AverageMeter(object):
+    """Computes and stores the average and current value"""
+
+    def __init__(self):
+        self.reset()
+
+    def reset(self):
+        self.val = 0
+        self.avg = 0
+        self.sum = 0
+        self.count = 0
+
+    def update(self, val, n=1):
+        self.val = val
+        self.sum += val * n
+        self.count += n
+        self.avg = self.sum / self.count
+
+
 # .......
 def main(local_rank):
     dist.init_process_group(backend='nccl', init_method='env://')
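The new AverageMeter keeps a running sum and count, so a log line can report the mean loss since the last reset instead of a single noisy batch value. A minimal usage sketch, assuming the AverageMeter class added above (the loss values are made-up placeholders):

    # Hypothetical loss values; in train.py the updates come from the training loop.
    meter = AverageMeter()
    for batch_loss in [0.91, 0.88, 0.85, 0.90]:
        meter.update(batch_loss, n=1)  # n may be a batch size for a weighted mean
    print("last %.4f, mean %.4f over %d updates" % (meter.val, meter.avg, meter.count))
    meter.reset()  # start a fresh averaging window after each log line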
@@ -85,10 +104,10 @@ def main(local_rank):
     }, {
         'params': dist_sample_classifer.parameters()
     }],
-        lr=0.1,
-        momentum=0.9,
-        weight_decay=cfg.weight_decay,
-        rescale=cfg.world_size)
+                  lr=0.1,
+                  momentum=0.9,
+                  weight_decay=cfg.weight_decay,
+                  rescale=cfg.world_size)
 
     # Lr scheduler
     scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer,
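The optimizer takes two parameter groups, the backbone and the sampled classifier weights, so both are stepped together; the four keyword lines removed and re-added above suggest a whitespace-only realignment. rescale=cfg.world_size is a keyword of this repository's modified SGD, not of stock torch.optim.SGD. A rough sketch of the same two-group setup with the standard optimizer (the module shapes and the 5e-4 weight decay are illustrative stand-ins):

    import torch
    import torch.nn as nn

    # Stand-ins for the real backbone and the sampled classifier weight.
    backbone = nn.Linear(512, 512)
    classifier_weight = nn.Parameter(torch.randn(1000, 512))

    # Stock SGD has no `rescale` kwarg, so any world-size rescaling would
    # have to be folded into the learning rate or the gradients instead.
    optimizer = torch.optim.SGD(
        [{'params': backbone.parameters()},
         {'params': [classifier_weight]}],
        lr=0.1,
        momentum=0.9,
        weight_decay=5e-4)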
@@ -99,11 +118,17 @@ def main(local_rank):
     if local_rank == 0:
         writer = SummaryWriter(log_dir='logs/shows')
 
     #
+    total_step = int(len(trainset) / cfg.batch_size / dist.get_world_size() * cfg.num_epoch)
+    if dist.get_rank() == 0:
+        print("Total Step is: %d" % total_step)
+
+    losses = AverageMeter()
     global_step = 0
+    train_start = time.time()
    for epoch in range(start_epoch, n_epochs):
         train_sampler.set_epoch(epoch)
         for step, (img, label) in enumerate(train_loader):
             start = time.time()
             total_label, norm_weight = dist_sample_classifer.prepare(
                 label, optimizer)
             features = F.normalize(backbone(img))
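total_step is the number of optimizer steps over the whole run: dataset size divided by the global batch size (per-GPU batch size times world size), times the number of epochs. The ETA added in the next hunk extrapolates linearly from the fraction of steps completed so far. The same arithmetic as a standalone sketch, with made-up numbers:

    # All values are hypothetical placeholders for len(trainset), cfg.batch_size, etc.
    num_samples = 5_800_000   # len(trainset)
    batch_size = 64           # per-GPU batch size
    world_size = 8            # dist.get_world_size()
    num_epoch = 16

    total_step = int(num_samples / batch_size / world_size * num_epoch)

    # Linear extrapolation used by the logging hunk below.
    elapsed_hours = 2.5       # (time.time() - train_start) / 3600
    global_step = 20_000
    time_total = elapsed_hours / ((global_step + 1) / total_step)
    time_for_end = time_total - elapsed_hours
    print("total_step=%d, about %.1f hours remaining" % (total_step, time_for_end))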
@@ -166,22 +191,30 @@ def main(local_rank):
             # Update classifer
             dist_sample_classifer.update()
             optimizer.zero_grad()
 
+            if cfg.local_rank == 0:
+                losses.update(loss_v, 1)
             if cfg.local_rank == 0 and step % 50 == 0:
+                time_now = (time.time() - train_start) / 3600
+                time_total = time_now / ((global_step + 1) / total_step)
+                time_for_end = time_total - time_now
+                writer.add_scalar('time_for_end', time_for_end, global_step)
                 writer.add_scalar('loss', loss_v, global_step)
-                print(
-                    "Speed %d samples/sec Loss %.4f Epoch: %d Global Step: %d"
-                    % ((cfg.batch_size / (time.time() - start) * cfg.world_size),
-                       loss_v, epoch, global_step))
+                print("Speed %d samples/sec Loss %.4f Epoch: %d Global Step: %d Required: %1.f hours" %
+                      (
+                          (cfg.batch_size * global_step / (time.time() - train_start) * cfg.world_size),
+                          losses.avg,
+                          epoch,
+                          global_step,
+                          time_for_end
+                      ))
+                losses.reset()
 
             global_step += 1
         scheduler.step()
         if dist.get_rank() == 0:
-            import os
-            if not os.path.exists('models'):
-                os.makedirs('models')
-            torch.save(backbone.module.state_dict(),
-                       "models/" + str(epoch) + 'backbone.pth')
+            if not os.path.exists(cfg.output):
+                os.makedirs(cfg.output)
+            torch.save(backbone.module.state_dict(), os.path.join(cfg.output, str(epoch) + 'backbone.pth'))
     dist.destroy_process_group()
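The save block now writes checkpoints to the configurable cfg.output directory instead of a hard-coded models/ folder. A self-contained sketch of the same step; the function name and the toy model are hypothetical, and in train.py the model is backbone.module (the module unwrapped from DistributedDataParallel) with output_dir=cfg.output:

    import os
    import torch
    import torch.nn as nn

    def save_checkpoint(model: nn.Module, output_dir: str, epoch: int) -> str:
        """Save a state_dict to <output_dir>/<epoch>backbone.pth, as in the diff."""
        os.makedirs(output_dir, exist_ok=True)  # replaces the explicit exists() check
        path = os.path.join(output_dir, "%dbackbone.pth" % epoch)
        torch.save(model.state_dict(), path)
        return path

    print(save_checkpoint(nn.Linear(8, 8), "output", epoch=0))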