From 0ffed746045bd936acec44206247da8fee67f736 Mon Sep 17 00:00:00 2001 From: AnXiang Date: Mon, 16 Nov 2020 17:20:33 +0800 Subject: [PATCH] Update training log. --- .../partial_fc/pytorch/IJB/IJBC_img2array.py | 32 ++++++++++ recognition/partial_fc/pytorch/config.py | 5 +- recognition/partial_fc/pytorch/partial_fc.py | 63 ++++++++++++++----- 3 files changed, 84 insertions(+), 16 deletions(-) create mode 100644 recognition/partial_fc/pytorch/IJB/IJBC_img2array.py diff --git a/recognition/partial_fc/pytorch/IJB/IJBC_img2array.py b/recognition/partial_fc/pytorch/IJB/IJBC_img2array.py new file mode 100644 index 0000000..c282cc6 --- /dev/null +++ b/recognition/partial_fc/pytorch/IJB/IJBC_img2array.py @@ -0,0 +1,32 @@ +import torch + +def get_image_feature(img_path, files_list, model_path, epoch, gpu_id): + batch_size = args.batch_size + data_shape = (3, 112, 112) + + files = files_list + print('files:', len(files)) + rare_size = len(files) % batch_size + faceness_scores = [] + batch = 0 + img_feats = np.empty((len(files), 1024), dtype=np.float32) + + batch_data = np.empty((2 * batch_size, 3, 112, 112)) + embedding = Embedding(model_path, epoch, data_shape, batch_size, gpu_id) + for img_index, each_line in enumerate(files[:len(files) - rare_size]): + name_lmk_score = each_line.strip().split(' ') + img_name = os.path.join(img_path, name_lmk_score[0]) + img = cv2.imread(img_name) + lmk = np.array([float(x) for x in name_lmk_score[1:-1]], + dtype=np.float32) + lmk = lmk.reshape((5, 2)) + input_blob = embedding.get(img, lmk) + # print(2*(img_index-batch*batch_size), 2*(img_index-batch*batch_size)+1) + batch_data[2 * (img_index - batch * batch_size)][:] = input_blob[0] + batch_data[2 * (img_index - batch * batch_size) + 1][:] = input_blob[1] + if (img_index + 1) % batch_size == 0: + print('batch', batch) + img_feats[batch * batch_size:batch * batch_size + + batch_size][:] = embedding.forward_db(batch_data) + batch += 1 + faceness_scores.append(name_lmk_score[-1]) \ No newline at end of file diff --git a/recognition/partial_fc/pytorch/config.py b/recognition/partial_fc/pytorch/config.py index f9ff3e8..4df4ba1 100644 --- a/recognition/partial_fc/pytorch/config.py +++ b/recognition/partial_fc/pytorch/config.py @@ -9,14 +9,17 @@ config.momentum = 0.9 config.weight_decay = 5e-4 config.batch_size = 64 config.lr = 0.1 +config.output = "tmp_models" if config.dataset == "emore": config.rec = "/train_tmp/faces_emore" config.num_classes = 85742 config.num_epoch = 16 + def lr_step_func(epoch): - return ((epoch + 1) / (4 + 1))**2 if epoch < -1 else 0.1**len( + return ((epoch + 1) / (4 + 1)) ** 2 if epoch < -1 else 0.1 ** len( [m for m in [8, 14] if m - 1 <= epoch]) + config.lr_func = lr_step_func diff --git a/recognition/partial_fc/pytorch/partial_fc.py b/recognition/partial_fc/pytorch/partial_fc.py index 308ae94..acc4ba7 100644 --- a/recognition/partial_fc/pytorch/partial_fc.py +++ b/recognition/partial_fc/pytorch/partial_fc.py @@ -41,6 +41,25 @@ class MarginSoftmax(nn.Module): return ret +class AverageMeter(object): + """Computes and stores the average and current value""" + + def __init__(self): + self.reset() + + def reset(self): + self.val = 0 + self.avg = 0 + self.sum = 0 + self.count = 0 + + def update(self, val, n=1): + self.val = val + self.sum += val * n + self.count += n + self.avg = self.sum / self.count + + # ....... def main(local_rank): dist.init_process_group(backend='nccl', init_method='env://') @@ -85,10 +104,10 @@ def main(local_rank): }, { 'params': dist_sample_classifer.parameters() }], - lr=0.1, - momentum=0.9, - weight_decay=cfg.weight_decay, - rescale=cfg.world_size) + lr=0.1, + momentum=0.9, + weight_decay=cfg.weight_decay, + rescale=cfg.world_size) # Lr scheduler scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, @@ -99,11 +118,17 @@ def main(local_rank): if local_rank == 0: writer = SummaryWriter(log_dir='logs/shows') + # + total_step = int(len(trainset) / cfg.batch_size / dist.get_world_size() * cfg.num_epoch) + if dist.get_rank() == 0: + print("Total Step is: %d" % total_step) + + losses = AverageMeter() global_step = 0 + train_start = time.time() for epoch in range(start_epoch, n_epochs): train_sampler.set_epoch(epoch) for step, (img, label) in enumerate(train_loader): - start = time.time() total_label, norm_weight = dist_sample_classifer.prepare( label, optimizer) features = F.normalize(backbone(img)) @@ -166,22 +191,30 @@ def main(local_rank): # Update classifer dist_sample_classifer.update() optimizer.zero_grad() - - if cfg.local_rank == 0: + losses.update(loss_v, 1) + if cfg.local_rank == 0 and step % 50 == 0: + time_now = (time.time() - train_start) / 3600 + time_total = time_now / ((global_step + 1) / total_step) + time_for_end = time_total - time_now + writer.add_scalar('time_for_end', time_for_end, global_step) writer.add_scalar('loss', loss_v, global_step) - print( - "Speed %d samples/sec Loss %.4f Epoch: %d Global Step: %d" - % ((cfg.batch_size / (time.time() - start) * cfg.world_size), - loss_v, epoch, global_step)) + print("Speed %d samples/sec Loss %.4f Epoch: %d Global Step: %d Required: %1.f hours" % + ( + (cfg.batch_size * global_step / (time.time() - train_start) * cfg.world_size), + losses.avg, + epoch, + global_step, + time_for_end + )) + losses.reset() global_step += 1 scheduler.step() if dist.get_rank() == 0: import os - if not os.path.exists('models'): - os.makedirs('models') - torch.save(backbone.module.state_dict(), - "models/" + str(epoch) + 'backbone.pth') + if not os.path.exists(cfg.output): + os.makedirs(cfg.output) + torch.save(backbone.module.state_dict(), os.path.join(cfg.output, str(epoch) + 'backbone.pth')) dist.destroy_process_group()