Update training log.

This commit is contained in:
AnXiang
2020-11-16 17:20:33 +08:00
parent 863a7ea9ea
commit 0ffed74604
3 changed files with 84 additions and 16 deletions

View File

@@ -41,6 +41,25 @@ class MarginSoftmax(nn.Module):
return ret
class AverageMeter(object):
"""Computes and stores the average and current value"""
def __init__(self):
self.reset()
def reset(self):
self.val = 0
self.avg = 0
self.sum = 0
self.count = 0
def update(self, val, n=1):
self.val = val
self.sum += val * n
self.count += n
self.avg = self.sum / self.count
# .......
def main(local_rank):
dist.init_process_group(backend='nccl', init_method='env://')
@@ -85,10 +104,10 @@ def main(local_rank):
}, {
'params': dist_sample_classifer.parameters()
}],
lr=0.1,
momentum=0.9,
weight_decay=cfg.weight_decay,
rescale=cfg.world_size)
lr=0.1,
momentum=0.9,
weight_decay=cfg.weight_decay,
rescale=cfg.world_size)
# Lr scheduler
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer,
@@ -99,11 +118,17 @@ def main(local_rank):
if local_rank == 0:
writer = SummaryWriter(log_dir='logs/shows')
#
total_step = int(len(trainset) / cfg.batch_size / dist.get_world_size() * cfg.num_epoch)
if dist.get_rank() == 0:
print("Total Step is: %d" % total_step)
losses = AverageMeter()
global_step = 0
train_start = time.time()
for epoch in range(start_epoch, n_epochs):
train_sampler.set_epoch(epoch)
for step, (img, label) in enumerate(train_loader):
start = time.time()
total_label, norm_weight = dist_sample_classifer.prepare(
label, optimizer)
features = F.normalize(backbone(img))
@@ -166,22 +191,30 @@ def main(local_rank):
# Update classifer
dist_sample_classifer.update()
optimizer.zero_grad()
if cfg.local_rank == 0:
losses.update(loss_v, 1)
if cfg.local_rank == 0 and step % 50 == 0:
time_now = (time.time() - train_start) / 3600
time_total = time_now / ((global_step + 1) / total_step)
time_for_end = time_total - time_now
writer.add_scalar('time_for_end', time_for_end, global_step)
writer.add_scalar('loss', loss_v, global_step)
print(
"Speed %d samples/sec Loss %.4f Epoch: %d Global Step: %d"
% ((cfg.batch_size / (time.time() - start) * cfg.world_size),
loss_v, epoch, global_step))
print("Speed %d samples/sec Loss %.4f Epoch: %d Global Step: %d Required: %1.f hours" %
(
(cfg.batch_size * global_step / (time.time() - train_start) * cfg.world_size),
losses.avg,
epoch,
global_step,
time_for_end
))
losses.reset()
global_step += 1
scheduler.step()
if dist.get_rank() == 0:
import os
if not os.path.exists('models'):
os.makedirs('models')
torch.save(backbone.module.state_dict(),
"models/" + str(epoch) + 'backbone.pth')
if not os.path.exists(cfg.output):
os.makedirs(cfg.output)
torch.save(backbone.module.state_dict(), os.path.join(cfg.output, str(epoch) + 'backbone.pth'))
dist.destroy_process_group()