Add FP16 support.
Merge the PartialFC PyTorch implementation into arcface_torch. PyTorch 1.6+ is all you need.
@@ -49,11 +49,32 @@ More details see [eval.md](docs/eval.md) in docs.
| MS1MV3-Arcface | r18     | 92.08 | 94.68 | 97.65 | 97.63 | 99.73 |
| MS1MV3-Arcface | r34     |       |       |       |       |       |
| MS1MV3-Arcface | r50     | 94.79 | 96.43 | 98.28 | 98.89 | 99.85 |
| MS1MV3-Arcface | r50-amp | 94.72 | 96.41 | 98.30 | 99.06 | 99.85 |
| MS1MV3-Arcface | r100    | 95.22 | 96.87 | 98.45 | 99.19 | 99.85 |

### Glint360k

| Datasets | backbone | IJBC(1e-05) | IJBC(1e-04) | agedb30 | cfp_fp | lfw |
| :---: | :--- | :--- | :--- | :--- | :--- | :--- |
-| Glint360k-Cosface | r100 | - | - | - | - | - |
+| Glint360k-Cosface | r100 | 96.19 | 97.39 | 98.52 | 99.26 | 99.83 |

More details see [modelzoo.md](docs/modelzoo.md) in docs.
## Citation

```
@inproceedings{deng2019arcface,
  title={Arcface: Additive angular margin loss for deep face recognition},
  author={Deng, Jiankang and Guo, Jia and Xue, Niannan and Zafeiriou, Stefanos},
  booktitle={Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition},
  pages={4690--4699},
  year={2019}
}
@inproceedings{an2020partical_fc,
  title={Partial FC: Training 10 Million Identities on a Single Machine},
  author={An, Xiang and Zhu, Xuhan and Xiao, Yang and Wu, Lan and Zhang, Ming and Gao, Yuan and Qin, Bin and Zhang, Debing and Fu, Ying},
  booktitle={Arxiv 2010.05222},
  year={2020}
}
```
@@ -35,50 +35,36 @@ class IBasicBlock(nn.Module):
             raise ValueError('BasicBlock only supports groups=1 and base_width=64')
         if dilation > 1:
             raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
-        self.bn1 = nn.BatchNorm2d(
-            inplanes,
-            eps=1e-05,
-        )
+        self.bn1 = nn.BatchNorm2d(inplanes, eps=1e-05,)
         self.conv1 = conv3x3(inplanes, planes)
-        self.bn2 = nn.BatchNorm2d(
-            planes,
-            eps=1e-05,
-        )
+        self.bn2 = nn.BatchNorm2d(planes, eps=1e-05,)
         self.prelu = nn.PReLU(planes)
         self.conv2 = conv3x3(planes, planes, stride)
-        self.bn3 = nn.BatchNorm2d(
-            planes,
-            eps=1e-05,
-        )
+        self.bn3 = nn.BatchNorm2d(planes, eps=1e-05,)
         self.downsample = downsample
         self.stride = stride

     def forward(self, x):
         identity = x
         out = self.bn1(x)
         out = self.conv1(out)
         out = self.bn2(out)
         out = self.prelu(out)
         out = self.conv2(out)
         out = self.bn3(out)
         if self.downsample is not None:
             identity = self.downsample(x)
         out += identity
         return out


 class IResNet(nn.Module):
     fc_scale = 7 * 7

     def __init__(self,
                  block, layers, dropout=0, num_features=512, zero_init_residual=False,
-                 groups=1, width_per_group=64, replace_stride_with_dilation=None):
+                 groups=1, width_per_group=64, replace_stride_with_dilation=None, fp16=False):
         super(IResNet, self).__init__()
+        self.fp16 = fp16
         self.inplanes = 64
         self.dilation = 1
         if replace_stride_with_dilation is None:
@@ -109,8 +95,7 @@ class IResNet(nn.Module):
                                        dilate=replace_stride_with_dilation[2])
         self.bn2 = nn.BatchNorm2d(512 * block.expansion, eps=1e-05,)
         self.dropout = nn.Dropout(p=dropout, inplace=True)
-        self.fc = nn.Linear(512 * block.expansion * self.fc_scale,
-                            num_features)
+        self.fc = nn.Linear(512 * block.expansion * self.fc_scale, num_features)
         self.features = nn.BatchNorm1d(num_features, eps=1e-05)
         nn.init.constant_(self.features.weight, 1.0)
         self.features.weight.requires_grad = False
@@ -154,21 +139,19 @@ class IResNet(nn.Module):
         return nn.Sequential(*layers)

     def forward(self, x):
-        x = self.conv1(x)
-        x = self.bn1(x)
-        x = self.prelu(x)
-        x = self.layer1(x)
-        x = self.layer2(x)
-        x = self.layer3(x)
-        x = self.layer4(x)
-        x = self.bn2(x)
-        x = torch.flatten(x, 1)
-        x = self.dropout(x)
-        x = self.fc(x)
+        with torch.cuda.amp.autocast(self.fp16):
+            x = self.conv1(x)
+            x = self.bn1(x)
+            x = self.prelu(x)
+            x = self.layer1(x)
+            x = self.layer2(x)
+            x = self.layer3(x)
+            x = self.layer4(x)
+            x = self.bn2(x)
+            x = torch.flatten(x, 1)
+            x = self.dropout(x)
+        x = self.fc(x.float() if self.fp16 else x)
         x = self.features(x)
         return x
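The forward rewrite above is the heart of the FP16 support: the convolutional trunk runs under `torch.cuda.amp.autocast`, and activations are cast back to FP32 before the final `fc` and `features` layers. A minimal runnable sketch of the same pattern — the toy module below is illustrative, not the repository's IResNet:

```python
# Sketch of the mixed-precision forward introduced in the diff above.
# Only the autocast/casting structure mirrors the commit; the layers are toys.
import torch
import torch.nn as nn

class TinyEmbedder(nn.Module):
    def __init__(self, fp16: bool = True, num_features: int = 512):
        super().__init__()
        self.fp16 = fp16
        self.trunk = nn.Sequential(
            nn.Conv2d(3, 64, 3, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.PReLU(64),
            nn.AdaptiveAvgPool2d(7),
            nn.Flatten(1),
        )
        self.fc = nn.Linear(64 * 7 * 7, num_features)
        self.features = nn.BatchNorm1d(num_features)

    def forward(self, x):
        # ops inside autocast may run in half precision when fp16 is True
        with torch.cuda.amp.autocast(self.fp16):
            x = self.trunk(x)
        # cast back to FP32 for the numerically sensitive embedding head
        return self.features(self.fc(x.float() if self.fp16 else x))

if __name__ == "__main__" and torch.cuda.is_available():
    model = TinyEmbedder().cuda()
    emb = model(torch.randn(4, 3, 112, 112, device="cuda"))
    print(emb.dtype)  # torch.float32
```

Keeping the last linear layer and the feature BatchNorm in FP32 is exactly what the `x.float() if self.fp16 else x` cast accomplishes: the embedding itself never accumulates in half precision.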
@@ -9,7 +9,7 @@ config.momentum = 0.9
 config.weight_decay = 5e-4
 config.batch_size = 64
 config.lr = 0.1  # batch size is 512
-config.output = "ms1mv3_r50_arcface"
+config.output = "ms1mv3_arcface_r50"

 if config.dataset == "emore":
     config.rec = "/train_tmp/faces_emore"
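The `config.lr = 0.1  # batch size is 512` comment follows the common linear learning-rate scaling rule: 0.1 is quoted for a global batch of 512 (for example 8 GPUs x 64 per GPU). A hedged sketch of that convention; the helper below is ours, not repository code:

```python
# Illustrative linear LR scaling implied by "config.lr = 0.1  # batch size is 512".
# The function name and defaults are assumptions, not part of the repository.
def scaled_lr(base_lr: float = 0.1, base_batch: int = 512,
              batch_per_gpu: int = 64, num_gpus: int = 8) -> float:
    global_batch = batch_per_gpu * num_gpus
    return base_lr * global_batch / base_batch

print(scaled_lr())            # 0.1  (8 GPUs x 64 = 512)
print(scaled_lr(num_gpus=4))  # 0.05 (global batch 256)
```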
@@ -10,7 +10,7 @@ from torch.nn.parameter import Parameter

 class PartialFC(Module):
     """
-    Author: {Yang Xiao, Xiang An, XuHan Zhu} in DeepGlint,
+    Author: {Xiang An, Yang Xiao, XuHan Zhu} in DeepGlint,
     Partial FC: Training 10 Million Identities on a Single Machine
     See the original paper:
     https://arxiv.org/abs/2010.05222
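For readers new to the class whose docstring is edited above: Partial FC shards the class-center matrix across ranks and, at each step, builds the softmax over only the positive centers present in the batch plus a random subset of negative centers. A conceptual sketch of that sampling step — illustrative only, not the repository's `partial_fc.py`:

```python
# Conceptual Partial FC center sampling; names and details here are ours.
import torch

def sample_centers(weight: torch.Tensor, labels: torch.Tensor, sample_rate: float):
    """weight: [num_local_classes, dim] shard of class centers on this rank.
    labels: [batch], with -1 for samples whose class lives on another rank."""
    num_local = weight.size(0)
    num_sample = int(num_local * sample_rate)
    positive = torch.unique(labels[labels != -1])
    if num_sample > positive.numel():
        scores = torch.rand(num_local, device=weight.device)
        scores[positive] = 2.0                      # force positives to be kept
        index = torch.topk(scores, k=num_sample)[1].sort()[0]
    else:
        index = positive
    # remap original class ids into the sampled index space
    remap = torch.full((num_local,), -1, dtype=torch.long, device=weight.device)
    remap[index] = torch.arange(index.numel(), device=weight.device)
    new_labels = torch.where(labels == -1, labels, remap[labels.clamp(min=0)])
    return weight[index], new_labels
```

With `sample_rate=0.1` and a million local classes, each step computes a 100k-way softmax instead of a million-way one, which is what makes training 10M identities on a single machine feasible.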
@@ -16,6 +16,7 @@ from dataset import MXFaceDataset, DataLoaderX
 from partial_fc import PartialFC
 from utils.utils_callbacks import CallBackVerification, CallBackLogging, CallBackModelCheckpoint
 from utils.utils_logging import AverageMeter, init_logging
+from utils.utils_amp import MaxClipGradScaler

 torch.backends.cudnn.benchmark = True

@@ -42,7 +43,7 @@ def main(args):
         sampler=train_sampler, num_workers=0, pin_memory=True, drop_last=True)

     dropout = 0.4 if cfg.dataset is "webface" else 0
-    backbone = eval("backbones.{}".format(args.network))(False, dropout=dropout).to(local_rank)
+    backbone = eval("backbones.{}".format(args.network))(False, dropout=dropout, fp16=cfg.fp16).to(local_rank)

     if args.resume:
         try:
@@ -81,8 +82,7 @@ def main(args):

     start_epoch = 0
     total_step = int(len(trainset) / cfg.batch_size / world_size * cfg.num_epoch)
-    if rank is 0:
-        logging.info("Total Step is: %d" % total_step)
+    if rank is 0: logging.info("Total Step is: %d" % total_step)

     callback_verification = CallBackVerification(2000, rank, cfg.val_targets, cfg.rec)
     callback_logging = CallBackLogging(50, rank, total_step, cfg.batch_size, world_size, None)
@@ -90,21 +90,31 @@ def main(args):

     loss = AverageMeter()
     global_step = 0
+    grad_scaler = MaxClipGradScaler(cfg.batch_size, 128 * cfg.batch_size, growth_interval=100) if cfg.fp16 else None
     for epoch in range(start_epoch, cfg.num_epoch):
         train_sampler.set_epoch(epoch)
         for step, (img, label) in enumerate(train_loader):
             global_step += 1
             features = F.normalize(backbone(img))
             x_grad, loss_v = module_partial_fc.forward_backward(label, features, opt_pfc)
-            features.backward(x_grad)
-            clip_grad_norm_(backbone.parameters(), max_norm=5, norm_type=2)
-            opt_backbone.step()
+            if cfg.fp16:
+                features.backward(grad_scaler.scale(x_grad))
+                grad_scaler.unscale_(opt_backbone)
+                clip_grad_norm_(backbone.parameters(), max_norm=5, norm_type=2)
+                grad_scaler.step(opt_backbone)
+                grad_scaler.update()
+            else:
+                features.backward(x_grad)
+                clip_grad_norm_(backbone.parameters(), max_norm=5, norm_type=2)
+                opt_backbone.step()

             opt_pfc.step()
             module_partial_fc.update()
             opt_backbone.zero_grad()
             opt_pfc.zero_grad()
             loss.update(loss_v, 1)
-            callback_logging(global_step, loss, epoch)
+            callback_logging(global_step, loss, epoch, cfg.fp16, grad_scaler)
             callback_verification(global_step, backbone)
         callback_checkpoint(global_step, backbone, module_partial_fc)
         scheduler_backbone.step()
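The fp16 branch above backpropagates a scaled upstream gradient rather than a scaled loss, because PartialFC returns `x_grad` directly. The ordering matters: scale before `backward`, `unscale_` before `clip_grad_norm_` so clipping sees true magnitudes, then `step`/`update`. A minimal runnable sketch with a toy model standing in for the backbone/PartialFC pair:

```python
# Toy reproduction of the fp16 step structure in the training loop above.
import torch
from torch.cuda.amp import GradScaler
from torch.nn.utils import clip_grad_norm_

model = torch.nn.Linear(8, 4).cuda()
opt = torch.optim.SGD(model.parameters(), lr=0.1)
scaler = GradScaler(init_scale=64)

features = model(torch.randn(2, 8, device="cuda"))
x_grad = torch.randn_like(features)       # stand-in for PartialFC's returned gradient

features.backward(scaler.scale(x_grad))   # backprop a *scaled* upstream grad
scaler.unscale_(opt)                      # restore true gradient magnitudes
clip_grad_norm_(model.parameters(), max_norm=5, norm_type=2)  # clip real grads
scaler.step(opt)                          # skips the update if grads overflowed
scaler.update()
opt.zero_grad()
```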
recognition/arcface_torch/utils/utils_amp.py (new file, 81 lines)
@@ -0,0 +1,81 @@
from typing import Dict, List

import torch
from torch._six import container_abcs
from torch.cuda.amp import GradScaler


class _MultiDeviceReplicator(object):
    """
    Lazily serves copies of a tensor to requested devices. Copies are cached per-device.
    """

    def __init__(self, master_tensor: torch.Tensor) -> None:
        assert master_tensor.is_cuda
        self.master = master_tensor
        self._per_device_tensors: Dict[torch.device, torch.Tensor] = {}

    def get(self, device) -> torch.Tensor:
        retval = self._per_device_tensors.get(device, None)
        if retval is None:
            retval = self.master.to(device=device, non_blocking=True, copy=True)
            self._per_device_tensors[device] = retval
        return retval


class MaxClipGradScaler(GradScaler):
    def __init__(self, init_scale, max_scale: float, growth_interval=100):
        GradScaler.__init__(self, init_scale=init_scale, growth_interval=growth_interval)
        self.max_scale = max_scale

    def scale_clip(self):
        if self.get_scale() == self.max_scale:
            self.set_growth_factor(1)
        elif self.get_scale() < self.max_scale:
            self.set_growth_factor(2)
        elif self.get_scale() > self.max_scale:
            self._scale.fill_(self.max_scale)
            self.set_growth_factor(1)

    def scale(self, outputs):
        """
        Multiplies ('scales') a tensor or list of tensors by the scale factor.

        Returns scaled outputs. If this instance of :class:`GradScaler` is not enabled, outputs are returned
        unmodified.

        Arguments:
            outputs (Tensor or iterable of Tensors): Outputs to scale.
        """
        if not self._enabled:
            return outputs
        self.scale_clip()
        # Short-circuit for the common case.
        if isinstance(outputs, torch.Tensor):
            assert outputs.is_cuda
            if self._scale is None:
                self._lazy_init_scale_growth_tracker(outputs.device)
            assert self._scale is not None
            return outputs * self._scale.to(device=outputs.device, non_blocking=True)

        # Invoke the more complex machinery only if we're treating multiple outputs.
        stash: List[_MultiDeviceReplicator] = []  # holds a reference that can be overwritten by apply_scale

        def apply_scale(val):
            if isinstance(val, torch.Tensor):
                assert val.is_cuda
                if len(stash) == 0:
                    if self._scale is None:
                        self._lazy_init_scale_growth_tracker(val.device)
                    assert self._scale is not None
                    stash.append(_MultiDeviceReplicator(self._scale))
                return val * stash[0].get(val.device)
            elif isinstance(val, container_abcs.Iterable):
                iterable = map(apply_scale, val)
                if isinstance(val, list) or isinstance(val, tuple):
                    return type(val)(iterable)
                else:
                    return iterable
            else:
                raise ValueError("outputs must be a Tensor or an iterable of Tensors")

        return apply_scale(outputs)
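As constructed in train.py above (`MaxClipGradScaler(cfg.batch_size, 128 * cfg.batch_size, growth_interval=100)`), the loss scale starts at the batch size, doubles every 100 overflow-free steps, and is clamped at 128x the batch size by `scale_clip()`. A small usage sketch with illustrative values:

```python
# Illustrative use of MaxClipGradScaler; requires a CUDA device.
import torch
from utils.utils_amp import MaxClipGradScaler

batch_size = 64
scaler = MaxClipGradScaler(init_scale=batch_size,
                           max_scale=128 * batch_size,
                           growth_interval=100)

grad = torch.ones(4, device="cuda")
scaled = scaler.scale(grad)   # scale_clip() runs inside scale()
print(scaler.get_scale())     # 64.0 at first; never exceeds 8192
```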
@@ -62,14 +62,13 @@ class CallBackLogging(object):
         self.init = False
         self.tic = 0

-    def __call__(self, global_step, loss: AverageMeter, epoch: int):
+    def __call__(self, global_step, loss: AverageMeter, epoch: int, fp16: bool, grad_scaler: torch.cuda.amp.GradScaler):
         if self.rank is 0 and global_step > 0 and global_step % self.frequent == 0:
             if self.init:
                 try:
                     speed: float = self.frequent * self.batch_size / (time.time() - self.tic)
                     speed_total = speed * self.world_size
                 except ZeroDivisionError:
                     speed = float('inf')
                     speed_total = float('inf')

                 time_now = (time.time() - self.time_start) / 3600
@@ -78,10 +77,15 @@ class CallBackLogging(object):
                 if self.writer is not None:
                     self.writer.add_scalar('time_for_end', time_for_end, global_step)
                     self.writer.add_scalar('loss', loss.avg, global_step)
-                msg = "Speed %.2f samples/sec   Loss %.4f   Epoch: %d   Global Step: %d   Required: %1.f hours" % (
-                    speed_total, loss.avg, epoch, global_step, time_for_end
-                )
+                if fp16:
+                    msg = "Speed %.2f samples/sec   Loss %.4f   Epoch: %d   Global Step: %d   " \
+                          "Fp16 Grad Scale: %2.f   Required: %1.f hours" % (
+                              speed_total, loss.avg, epoch, global_step, grad_scaler.get_scale(), time_for_end)
+                else:
+                    msg = "Speed %.2f samples/sec   Loss %.4f   Epoch: %d   Global Step: %d   Required: %1.f hours" % (
+                        speed_total, loss.avg, epoch, global_step, time_for_end)
                 logging.info(msg)
                 loss.reset()
                 self.tic = time.time()
@@ -1,32 +0,0 @@
import torch


def get_image_feature(img_path, files_list, model_path, epoch, gpu_id):
    batch_size = args.batch_size
    data_shape = (3, 112, 112)

    files = files_list
    print('files:', len(files))
    rare_size = len(files) % batch_size
    faceness_scores = []
    batch = 0
    img_feats = np.empty((len(files), 1024), dtype=np.float32)

    batch_data = np.empty((2 * batch_size, 3, 112, 112))
    embedding = Embedding(model_path, epoch, data_shape, batch_size, gpu_id)
    for img_index, each_line in enumerate(files[:len(files) - rare_size]):
        name_lmk_score = each_line.strip().split(' ')
        img_name = os.path.join(img_path, name_lmk_score[0])
        img = cv2.imread(img_name)
        lmk = np.array([float(x) for x in name_lmk_score[1:-1]],
                       dtype=np.float32)
        lmk = lmk.reshape((5, 2))
        input_blob = embedding.get(img, lmk)
        # print(2*(img_index-batch*batch_size), 2*(img_index-batch*batch_size)+1)
        batch_data[2 * (img_index - batch * batch_size)][:] = input_blob[0]
        batch_data[2 * (img_index - batch * batch_size) + 1][:] = input_blob[1]
        if (img_index + 1) % batch_size == 0:
            print('batch', batch)
            img_feats[batch * batch_size:batch * batch_size +
                      batch_size][:] = embedding.forward_db(batch_data)
            batch += 1
        faceness_scores.append(name_lmk_score[-1])
@@ -1,455 +0,0 @@
# coding: utf-8

import os
import numpy as np
# import cPickle
import pickle
import pandas as pd
import matplotlib

matplotlib.use('Agg')
import matplotlib.pyplot as plt
import timeit
import sklearn
import argparse
from sklearn.metrics import roc_curve, auc
from sklearn import preprocessing
import cv2
import sys
import glob

sys.path.append('./recognition')
from embedding import Embedding
# from embedding_test import Embedding
from menpo.visualize import print_progress
from menpo.visualize.viewmatplotlib import sample_colours_from_colourmap
from prettytable import PrettyTable
from pathlib import Path
import warnings
import horovod as hvd
import mxnet as mx

warnings.filterwarnings("ignore")

parser = argparse.ArgumentParser(description='do ijb test')
# general
parser.add_argument('--model-prefix', default='', help='path to load model.')
parser.add_argument('--model-epoch', default=1, type=int, help='')
parser.add_argument('--image-path', default='', type=str, help='')
parser.add_argument('--result-dir', default='.', type=str, help='')
parser.add_argument('--gpu', default=7, type=int, help='gpu id')
parser.add_argument('--batch-size', default=32, type=int, help='')
parser.add_argument('--job', default='insightface', type=str, help='job name')
parser.add_argument('--target',
                    default='IJBC',
                    type=str,
                    help='target, set to IJBC or IJBB')
args = parser.parse_args()

target = args.target
model_path = args.model_prefix
image_path = args.image_path
result_dir = args.result_dir
gpu_id = args.gpu
epoch = args.model_epoch
use_norm_score = True  # if True, TestMode(N1)
use_detector_score = True  # if True, TestMode(D1)
use_flip_test = True  # if True, TestMode(F1)
job = args.job

# initialize Horovod
# hvd.init()
# rank_size = hvd.size()


# Split a list into n parts as evenly as possible (len(result) == n);
# if n exceeds the number of elements, the surplus sublists are empty.
def divideIntoNstrand(listTemp, n):
    twoList = [[] for i in range(n)]
    for i, e in enumerate(listTemp):
        twoList[i % n].append(e)
    return twoList


def read_template_media_list(path):
    # ijb_meta = np.loadtxt(path, dtype=str)
    ijb_meta = pd.read_csv(path, sep=' ', header=None).values
    templates = ijb_meta[:, 1].astype(np.int)
    medias = ijb_meta[:, 2].astype(np.int)
    return templates, medias


# In[ ]:


def read_template_pair_list(path):
    # pairs = np.loadtxt(path, dtype=str)
    pairs = pd.read_csv(path, sep=' ', header=None).values
    # print(pairs.shape)
    # print(pairs[:, 0].astype(np.int))
    t1 = pairs[:, 0].astype(np.int)
    t2 = pairs[:, 1].astype(np.int)
    label = pairs[:, 2].astype(np.int)
    return t1, t2, label


# In[ ]:


def read_image_feature(path):
    with open(path, 'rb') as fid:
        img_feats = pickle.load(fid)
    return img_feats


# In[ ]:


#
# def get_image_feature(img_path, files_list, model_path, epoch, gpu_id):
#     batch_size = 125
#     data_shape = (3, 112, 112)
#     embedding = Embedding(model_path, epoch, data_shape, batch_size, gpu_id)
#     files = files_list
#     print('files:', len(files))
#
#     faceness_scores = []
#     img_feats = []
#     batch = 0
#     img_feats = np.empty((len(files), 1024), dtype=np.float32)
#     batch_data = mx.ndarray.empty((2 * batch_size, 3, 112, 112))
#     for img_index, each_line in enumerate(files):
#         # if img_index % 500 == 0:
#         #     print('processing', img_index)
#         name_lmk_score = each_line.strip().split(' ')
#         img_name = os.path.join(img_path, name_lmk_score[0])
#         img = cv2.imread(img_name)
#         lmk = np.array([float(x) for x in name_lmk_score[1:-1]], dtype=np.float32)
#         lmk = lmk.reshape((5, 2))
#         input_blob = embedding.get(img, lmk)
#         # print(img_index*2 - batch*batch_size, img_index*2 + 1 - batch*batch_size)
#         batch_data[2*(img_index-batch*batch_size)][:] = mx.nd.array(input_blob)[0]
#         batch_data[2*(img_index-batch*batch_size)+1][:] = mx.nd.array(input_blob)[1]
#
#         if (img_index+1) % batch_size == 0:
#             print('batch', batch)
#             # img_feats.append(embedding.forward_db(batch_data))
#             img_feats[batch*batch_size:batch*batch_size+batch_size][:] = embedding.forward_db(batch_data)
#             batch += 1
#             batch_data = mx.ndarray.empty((2 * batch_size, 3, 112, 112))
#         faceness_scores.append(name_lmk_score[-1])
#     # img_feats = np.array(img_feats).astype(np.float32)
#     faceness_scores = np.array(faceness_scores).astype(np.float32)
#     # img_feats = np.ones( (len(files), 1024), dtype=np.float32) * 0.01
#     # faceness_scores = np.ones( (len(files), ), dtype=np.float32 )
#     return img_feats, faceness_scores
def get_image_feature(img_path, img_list_path, model_path, epoch, gpu_id):
    img_list = open(img_list_path)
    embedding = Embedding(model_path, epoch, gpu_id)
    files = img_list.readlines()
    print('files:', len(files))
    faceness_scores = []
    img_feats = []
    for img_index, each_line in enumerate(files):
        if img_index % 500 == 0:
            print('processing', img_index)
        name_lmk_score = each_line.strip().split(' ')
        img_name = os.path.join(img_path, name_lmk_score[0])
        img = cv2.imread(img_name)
        lmk = np.array([float(x) for x in name_lmk_score[1:-1]],
                       dtype=np.float32)
        lmk = lmk.reshape((5, 2))
        img_feats.append(embedding.get(img, lmk))
        faceness_scores.append(name_lmk_score[-1])
    img_feats = np.array(img_feats).astype(np.float32)
    faceness_scores = np.array(faceness_scores).astype(np.float32)

    # img_feats = np.ones( (len(files), 1024), dtype=np.float32) * 0.01
    # faceness_scores = np.ones( (len(files), ), dtype=np.float32 )
    return img_feats, faceness_scores


# In[ ]:


def image2template_feature(img_feats=None, templates=None, medias=None):
    # ==========================================================
    # 1. face image feature l2 normalization. img_feats:[number_image x feats_dim]
    # 2. compute media feature.
    # 3. compute template feature.
    # ==========================================================
    unique_templates = np.unique(templates)
    template_feats = np.zeros((len(unique_templates), img_feats.shape[1]))

    for count_template, uqt in enumerate(unique_templates):

        (ind_t, ) = np.where(templates == uqt)
        face_norm_feats = img_feats[ind_t]
        face_medias = medias[ind_t]
        unique_medias, unique_media_counts = np.unique(face_medias,
                                                       return_counts=True)
        media_norm_feats = []
        for u, ct in zip(unique_medias, unique_media_counts):
            (ind_m, ) = np.where(face_medias == u)
            if ct == 1:
                media_norm_feats += [face_norm_feats[ind_m]]
            else:  # image features from the same video will be aggregated into one feature
                media_norm_feats += [
                    np.mean(face_norm_feats[ind_m], axis=0, keepdims=True)
                ]
        media_norm_feats = np.array(media_norm_feats)
        # media_norm_feats = media_norm_feats / np.sqrt(np.sum(media_norm_feats ** 2, -1, keepdims=True))
        template_feats[count_template] = np.sum(media_norm_feats, axis=0)
        if count_template % 2000 == 0:
            print('Finish Calculating {} template features.'.format(
                count_template))
    # template_norm_feats = template_feats / np.sqrt(np.sum(template_feats ** 2, -1, keepdims=True))
    template_norm_feats = sklearn.preprocessing.normalize(template_feats)
    # print(template_norm_feats.shape)
    return template_norm_feats, unique_templates


# In[ ]:


def verification(template_norm_feats=None,
                 unique_templates=None,
                 p1=None,
                 p2=None):
    # ==========================================================
    # Compute set-to-set Similarity Score.
    # ==========================================================
    template2id = np.zeros((max(unique_templates) + 1, 1), dtype=int)
    for count_template, uqt in enumerate(unique_templates):
        template2id[uqt] = count_template

    score = np.zeros((len(p1), ))  # save cosine distance between pairs

    total_pairs = np.array(range(len(p1)))
    batchsize = 100000  # small batchsize instead of all pairs in one batch due to the memory limitation
    sublists = [
        total_pairs[i:i + batchsize] for i in range(0, len(p1), batchsize)
    ]
    total_sublists = len(sublists)
    for c, s in enumerate(sublists):
        feat1 = template_norm_feats[template2id[p1[s]]]
        feat2 = template_norm_feats[template2id[p2[s]]]
        similarity_score = np.sum(feat1 * feat2, -1)
        score[s] = similarity_score.flatten()
        if c % 10 == 0:
            print('Finish {}/{} pairs.'.format(c, total_sublists))
    return score


# In[ ]:
def verification2(template_norm_feats=None,
                  unique_templates=None,
                  p1=None,
                  p2=None):
    template2id = np.zeros((max(unique_templates) + 1, 1), dtype=int)
    for count_template, uqt in enumerate(unique_templates):
        template2id[uqt] = count_template
    score = np.zeros((len(p1), ))  # save cosine distance between pairs
    total_pairs = np.array(range(len(p1)))
    batchsize = 100000  # small batchsize instead of all pairs in one batch due to the memory limitation
    sublists = [
        total_pairs[i:i + batchsize] for i in range(0, len(p1), batchsize)
    ]
    total_sublists = len(sublists)
    for c, s in enumerate(sublists):
        feat1 = template_norm_feats[template2id[p1[s]]]
        feat2 = template_norm_feats[template2id[p2[s]]]
        similarity_score = np.sum(feat1 * feat2, -1)
        score[s] = similarity_score.flatten()
        if c % 10 == 0:
            print('Finish {}/{} pairs.'.format(c, total_sublists))
    return score


def read_score(path):
    with open(path, 'rb') as fid:
        img_feats = pickle.load(fid)
    return img_feats


# # Step1: Load Meta Data

# In[ ]:

assert target == 'IJBC' or target == 'IJBB'

# =============================================================
# load image and template relationships for template feature embedding
# tid --> template id,  mid --> media id
# format:
#           image_name tid mid
# =============================================================
start = timeit.default_timer()
templates, medias = read_template_media_list(
    os.path.join('%s/meta' % image_path,
                 '%s_face_tid_mid.txt' % target.lower()))
stop = timeit.default_timer()
print('Time: %.2f s. ' % (stop - start))

# In[ ]:

# =============================================================
# load template pairs for template-to-template verification
# tid : template id,  label : 1/0
# format:
#           tid_1 tid_2 label
# =============================================================
start = timeit.default_timer()
p1, p2, label = read_template_pair_list(
    os.path.join('%s/meta' % image_path,
                 '%s_template_pair_label.txt' % target.lower()))
stop = timeit.default_timer()
print('Time: %.2f s. ' % (stop - start))

# # Step 2: Get Image Features

# In[ ]:

# =============================================================
# load image features
# format:
#           img_feats: [image_num x feats_dim] (227630, 512)
# =============================================================
# start = timeit.default_timer()
# img_path = '%s/loose_crop' % image_path
# img_list_path = '%s/meta/%s_name_5pts_score.txt' % (image_path, target.lower())
# img_list = open(img_list_path)
# files = img_list.readlines()
# # files_list = divideIntoNstrand(files, rank_size)
# files_list = files

# img_feats
# for i in range(rank_size):
#     img_feats, faceness_scores = get_image_feature(img_path, files_list, model_path, epoch, gpu_id)

start = timeit.default_timer()
img_path = '%s/loose_crop' % image_path
img_list_path = '%s/meta/%s_name_5pts_score.txt' % (image_path, target.lower())
img_feats, faceness_scores = get_image_feature(img_path, img_list_path,
                                               model_path, epoch, gpu_id)

stop = timeit.default_timer()
print('Time: %.2f s. ' % (stop - start))
print('Feature Shape: ({} , {}) .'.format(img_feats.shape[0],
                                          img_feats.shape[1]))

# # Step3: Get Template Features

# In[ ]:

# =============================================================
# compute template features from image features.
# =============================================================
start = timeit.default_timer()
# ==========================================================
# Norm feature before aggregation into template feature?
# Feature norm from embedding network and faceness score are able to decrease weights for noise samples (not face).
# ==========================================================
# 1. FaceScore (Feature Norm)
# 2. FaceScore (Detector)

if use_flip_test:
    # concat --- F1
    # img_input_feats = img_feats
    # add --- F2
    img_input_feats = img_feats[:, 0:img_feats.shape[1] //
                                2] + img_feats[:, img_feats.shape[1] // 2:]
else:
    img_input_feats = img_feats[:, 0:img_feats.shape[1] // 2]

if use_norm_score:
    img_input_feats = img_input_feats
else:
    # normalise features to remove norm information
    img_input_feats = img_input_feats / np.sqrt(
        np.sum(img_input_feats**2, -1, keepdims=True))

if use_detector_score:
    print(img_input_feats.shape, faceness_scores.shape)
    # img_input_feats = img_input_feats * np.matlib.repmat(faceness_scores[:,np.newaxis], 1, img_input_feats.shape[1])
    img_input_feats = img_input_feats * faceness_scores[:, np.newaxis]
else:
    img_input_feats = img_input_feats

template_norm_feats, unique_templates = image2template_feature(
    img_input_feats, templates, medias)
stop = timeit.default_timer()
print('Time: %.2f s. ' % (stop - start))

# # Step 4: Get Template Similarity Scores

# In[ ]:

# =============================================================
# compute verification scores between template pairs.
# =============================================================
start = timeit.default_timer()
score = verification(template_norm_feats, unique_templates, p1, p2)
stop = timeit.default_timer()
print('Time: %.2f s. ' % (stop - start))

# In[ ]:

save_path = result_dir + '/%s_result' % target

if not os.path.exists(save_path):
    os.makedirs(save_path)

score_save_file = os.path.join(save_path, "%s.npy" % job)
np.save(score_save_file, score)

# # Step 5: Get ROC Curves and TPR@FPR Table

# In[ ]:

files = [score_save_file]
methods = []
scores = []
for file in files:
    methods.append(Path(file).stem)
    scores.append(np.load(file))

methods = np.array(methods)
scores = dict(zip(methods, scores))
colours = dict(
    zip(methods, sample_colours_from_colourmap(methods.shape[0], 'Set2')))
# x_labels = [1/(10**x) for x in np.linspace(6, 0, 6)]
x_labels = [10**-6, 10**-5, 10**-4, 10**-3, 10**-2, 10**-1]
tpr_fpr_table = PrettyTable(['Methods'] + [str(x) for x in x_labels])
fig = plt.figure()
for method in methods:
    fpr, tpr, _ = roc_curve(label, scores[method])
    roc_auc = auc(fpr, tpr)
    fpr = np.flipud(fpr)
    tpr = np.flipud(tpr)  # select largest tpr at same fpr
    plt.plot(fpr,
             tpr,
             color=colours[method],
             lw=1,
             label=('[%s (AUC = %0.4f %%)]' %
                    (method.split('-')[-1], roc_auc * 100)))
    tpr_fpr_row = []
    tpr_fpr_row.append("%s-%s" % (method, target))
    for fpr_iter in np.arange(len(x_labels)):
        _, min_index = min(
            list(zip(abs(fpr - x_labels[fpr_iter]), range(len(fpr)))))
        # tpr_fpr_row.append('%.4f' % tpr[min_index])
        tpr_fpr_row.append('%.2f' % (tpr[min_index] * 100))
    tpr_fpr_table.add_row(tpr_fpr_row)
plt.xlim([10**-6, 0.1])
plt.ylim([0.3, 1.0])
plt.grid(linestyle='--', linewidth=1)
plt.xticks(x_labels)
plt.yticks(np.linspace(0.3, 1.0, 8, endpoint=True))
plt.xscale('log')
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.title('ROC on IJB')
plt.legend(loc="lower right")
# plt.show()
fig.savefig(os.path.join(save_path, '%s.pdf' % job))
print(tpr_fpr_table)
@@ -1,440 +0,0 @@
# coding: utf-8

import os
import numpy as np
# import cPickle
import pickle
import pandas as pd
import matplotlib

matplotlib.use('Agg')
import matplotlib.pyplot as plt
import timeit
import sklearn
import argparse
from sklearn.metrics import roc_curve, auc
from sklearn import preprocessing
import cv2
import sys
import glob

sys.path.append('./recognition')
from embedding_test import Embedding
from menpo.visualize import print_progress
from menpo.visualize.viewmatplotlib import sample_colours_from_colourmap
from prettytable import PrettyTable
from pathlib import Path
import warnings
# import horovod as hvd
import mxnet as mx

warnings.filterwarnings("ignore")

parser = argparse.ArgumentParser(description='do ijb test')
# general
parser.add_argument('--model-prefix', default='', help='path to load model.')
parser.add_argument('--model-epoch', default=1, type=int, help='')
parser.add_argument('--image-path', default='', type=str, help='')
parser.add_argument('--result-dir', default='.', type=str, help='')
parser.add_argument('--gpu', default=7, type=int, help='gpu id')
parser.add_argument('--batch-size', default=128, type=int, help='')
parser.add_argument('--job', default='insightface', type=str, help='job name')
parser.add_argument('--target',
                    default='IJBC',
                    type=str,
                    help='target, set to IJBC or IJBB')
args = parser.parse_args()

target = args.target
model_path = args.model_prefix
image_path = args.image_path
result_dir = args.result_dir
gpu_id = args.gpu
epoch = args.model_epoch
use_norm_score = True  # if True, TestMode(N1)
use_detector_score = True  # if True, TestMode(D1)
use_flip_test = True  # if True, TestMode(F1)
job = args.job
batch_size = args.batch_size

# initialize Horovod
# hvd.init()
# rank_size = hvd.size()


# Split a list into n parts as evenly as possible (len(result) == n);
# if n exceeds the number of elements, the surplus sublists are empty.
def divideIntoNstrand(listTemp, n):
    twoList = [[] for i in range(n)]
    for i, e in enumerate(listTemp):
        twoList[i % n].append(e)
    return twoList


def read_template_media_list(path):
    # ijb_meta = np.loadtxt(path, dtype=str)
    ijb_meta = pd.read_csv(path, sep=' ', header=None).values
    templates = ijb_meta[:, 1].astype(np.int)
    medias = ijb_meta[:, 2].astype(np.int)
    return templates, medias


# In[ ]:


def read_template_pair_list(path):
    # pairs = np.loadtxt(path, dtype=str)
    pairs = pd.read_csv(path, sep=' ', header=None).values
    # print(pairs.shape)
    # print(pairs[:, 0].astype(np.int))
    t1 = pairs[:, 0].astype(np.int)
    t2 = pairs[:, 1].astype(np.int)
    label = pairs[:, 2].astype(np.int)
    return t1, t2, label


# In[ ]:


def read_image_feature(path):
    with open(path, 'rb') as fid:
        img_feats = pickle.load(fid)
    return img_feats


# In[ ]:


def get_image_feature(img_path, files_list, model_path, epoch, gpu_id):
    batch_size = args.batch_size
    data_shape = (3, 112, 112)

    files = files_list
    print('files:', len(files))
    rare_size = len(files) % batch_size
    faceness_scores = []
    batch = 0
    img_feats = np.empty((len(files), 1024), dtype=np.float32)

    batch_data = np.empty((2 * batch_size, 3, 112, 112))
    embedding = Embedding(model_path, epoch, data_shape, batch_size, gpu_id)
    for img_index, each_line in enumerate(files[:len(files) - rare_size]):
        name_lmk_score = each_line.strip().split(' ')
        img_name = os.path.join(img_path, name_lmk_score[0])
        img = cv2.imread(img_name)
        lmk = np.array([float(x) for x in name_lmk_score[1:-1]],
                       dtype=np.float32)
        lmk = lmk.reshape((5, 2))
        input_blob = embedding.get(img, lmk)
        # print(2*(img_index-batch*batch_size), 2*(img_index-batch*batch_size)+1)
        batch_data[2 * (img_index - batch * batch_size)][:] = input_blob[0]
        batch_data[2 * (img_index - batch * batch_size) + 1][:] = input_blob[1]
        if (img_index + 1) % batch_size == 0:
            print('batch', batch)
            img_feats[batch * batch_size:batch * batch_size +
                      batch_size][:] = embedding.forward_db(batch_data)
            batch += 1
        faceness_scores.append(name_lmk_score[-1])
    # img_feats = np.array(img_feats).astype(np.float32)
    batch_data = np.empty((2 * rare_size, 3, 112, 112))
    embedding = Embedding(model_path, epoch, data_shape, rare_size, gpu_id)
    for img_index, each_line in enumerate(files[len(files) - rare_size:]):
        name_lmk_score = each_line.strip().split(' ')
        img_name = os.path.join(img_path, name_lmk_score[0])
        img = cv2.imread(img_name)
        lmk = np.array([float(x) for x in name_lmk_score[1:-1]],
                       dtype=np.float32)
        lmk = lmk.reshape((5, 2))
        input_blob = embedding.get(img, lmk)
        batch_data[2 * img_index][:] = input_blob[0]
        batch_data[2 * img_index + 1][:] = input_blob[1]
        if (img_index + 1) % rare_size == 0:
            print('batch', batch)
            img_feats[len(files) -
                      rare_size:][:] = embedding.forward_db(batch_data)
            batch += 1
        faceness_scores.append(name_lmk_score[-1])
    faceness_scores = np.array(faceness_scores).astype(np.float32)
    # img_feats = np.ones( (len(files), 1024), dtype=np.float32) * 0.01
    # faceness_scores = np.ones( (len(files), ), dtype=np.float32 )
    return img_feats, faceness_scores


# In[ ]:


def image2template_feature(img_feats=None, templates=None, medias=None):
    # ==========================================================
    # 1. face image feature l2 normalization. img_feats:[number_image x feats_dim]
    # 2. compute media feature.
    # 3. compute template feature.
    # ==========================================================
    unique_templates = np.unique(templates)
    template_feats = np.zeros((len(unique_templates), img_feats.shape[1]))

    for count_template, uqt in enumerate(unique_templates):

        (ind_t, ) = np.where(templates == uqt)
        face_norm_feats = img_feats[ind_t]
        face_medias = medias[ind_t]
        unique_medias, unique_media_counts = np.unique(face_medias,
                                                       return_counts=True)
        media_norm_feats = []
        for u, ct in zip(unique_medias, unique_media_counts):
            (ind_m, ) = np.where(face_medias == u)
            if ct == 1:
                media_norm_feats += [face_norm_feats[ind_m]]
            else:  # image features from the same video will be aggregated into one feature
                media_norm_feats += [
                    np.mean(face_norm_feats[ind_m], axis=0, keepdims=True)
                ]
        media_norm_feats = np.array(media_norm_feats)
        # media_norm_feats = media_norm_feats / np.sqrt(np.sum(media_norm_feats ** 2, -1, keepdims=True))
        template_feats[count_template] = np.sum(media_norm_feats, axis=0)
        if count_template % 2000 == 0:
            print('Finish Calculating {} template features.'.format(
                count_template))
    # template_norm_feats = template_feats / np.sqrt(np.sum(template_feats ** 2, -1, keepdims=True))
    template_norm_feats = sklearn.preprocessing.normalize(template_feats)
    # print(template_norm_feats.shape)
    return template_norm_feats, unique_templates


# In[ ]:


def verification(template_norm_feats=None,
                 unique_templates=None,
                 p1=None,
                 p2=None):
    # ==========================================================
    # Compute set-to-set Similarity Score.
    # ==========================================================
    template2id = np.zeros((max(unique_templates) + 1, 1), dtype=int)
    for count_template, uqt in enumerate(unique_templates):
        template2id[uqt] = count_template

    score = np.zeros((len(p1), ))  # save cosine distance between pairs

    total_pairs = np.array(range(len(p1)))
    batchsize = 100000  # small batchsize instead of all pairs in one batch due to the memory limitation
    sublists = [
        total_pairs[i:i + batchsize] for i in range(0, len(p1), batchsize)
    ]
    total_sublists = len(sublists)
    for c, s in enumerate(sublists):
        feat1 = template_norm_feats[template2id[p1[s]]]
        feat2 = template_norm_feats[template2id[p2[s]]]
        similarity_score = np.sum(feat1 * feat2, -1)
        score[s] = similarity_score.flatten()
        if c % 10 == 0:
            print('Finish {}/{} pairs.'.format(c, total_sublists))
    return score


# In[ ]:
def verification2(template_norm_feats=None,
                  unique_templates=None,
                  p1=None,
                  p2=None):
    template2id = np.zeros((max(unique_templates) + 1, 1), dtype=int)
    for count_template, uqt in enumerate(unique_templates):
        template2id[uqt] = count_template
    score = np.zeros((len(p1), ))  # save cosine distance between pairs
    total_pairs = np.array(range(len(p1)))
    batchsize = 100000  # small batchsize instead of all pairs in one batch due to the memory limitation
    sublists = [
        total_pairs[i:i + batchsize] for i in range(0, len(p1), batchsize)
    ]
    total_sublists = len(sublists)
    for c, s in enumerate(sublists):
        feat1 = template_norm_feats[template2id[p1[s]]]
        feat2 = template_norm_feats[template2id[p2[s]]]
        similarity_score = np.sum(feat1 * feat2, -1)
        score[s] = similarity_score.flatten()
        if c % 10 == 0:
            print('Finish {}/{} pairs.'.format(c, total_sublists))
    return score


def read_score(path):
    with open(path, 'rb') as fid:
        img_feats = pickle.load(fid)
    return img_feats


# # Step1: Load Meta Data

# In[ ]:

assert target == 'IJBC' or target == 'IJBB'

# =============================================================
# load image and template relationships for template feature embedding
# tid --> template id,  mid --> media id
# format:
#           image_name tid mid
# =============================================================
start = timeit.default_timer()
templates, medias = read_template_media_list(
    os.path.join('%s/meta' % image_path,
                 '%s_face_tid_mid.txt' % target.lower()))
stop = timeit.default_timer()
print('Time: %.2f s. ' % (stop - start))

# In[ ]:

# =============================================================
# load template pairs for template-to-template verification
# tid : template id,  label : 1/0
# format:
#           tid_1 tid_2 label
# =============================================================
start = timeit.default_timer()
p1, p2, label = read_template_pair_list(
    os.path.join('%s/meta' % image_path,
                 '%s_template_pair_label.txt' % target.lower()))
stop = timeit.default_timer()
print('Time: %.2f s. ' % (stop - start))

# # Step 2: Get Image Features

# In[ ]:

# =============================================================
# load image features
# format:
#           img_feats: [image_num x feats_dim] (227630, 512)
# =============================================================
start = timeit.default_timer()
img_path = '%s/loose_crop' % image_path
img_list_path = '%s/meta/%s_name_5pts_score.txt' % (image_path, target.lower())
img_list = open(img_list_path)
files = img_list.readlines()
# files_list = divideIntoNstrand(files, rank_size)
files_list = files

# img_feats
# for i in range(rank_size):
img_feats, faceness_scores = get_image_feature(img_path, files_list,
                                               model_path, epoch, gpu_id)
stop = timeit.default_timer()
print('Time: %.2f s. ' % (stop - start))
print('Feature Shape: ({} , {}) .'.format(img_feats.shape[0],
                                          img_feats.shape[1]))

# # Step3: Get Template Features

# In[ ]:

# =============================================================
# compute template features from image features.
# =============================================================
start = timeit.default_timer()
# ==========================================================
# Norm feature before aggregation into template feature?
# Feature norm from embedding network and faceness score are able to decrease weights for noise samples (not face).
# ==========================================================
# 1. FaceScore (Feature Norm)
# 2. FaceScore (Detector)

if use_flip_test:
    # concat --- F1
    # img_input_feats = img_feats
    # add --- F2
    img_input_feats = img_feats[:, 0:img_feats.shape[1] //
                                2] + img_feats[:, img_feats.shape[1] // 2:]
else:
    img_input_feats = img_feats[:, 0:img_feats.shape[1] // 2]

if use_norm_score:
    img_input_feats = img_input_feats
else:
    # normalise features to remove norm information
    img_input_feats = img_input_feats / np.sqrt(
        np.sum(img_input_feats**2, -1, keepdims=True))

if use_detector_score:
    print(img_input_feats.shape, faceness_scores.shape)
    # img_input_feats = img_input_feats * np.matlib.repmat(faceness_scores[:,np.newaxis], 1, img_input_feats.shape[1])
    img_input_feats = img_input_feats * faceness_scores[:, np.newaxis]
else:
    img_input_feats = img_input_feats

template_norm_feats, unique_templates = image2template_feature(
    img_input_feats, templates, medias)
stop = timeit.default_timer()
print('Time: %.2f s. ' % (stop - start))

# # Step 4: Get Template Similarity Scores

# In[ ]:

# =============================================================
# compute verification scores between template pairs.
# =============================================================
start = timeit.default_timer()
score = verification(template_norm_feats, unique_templates, p1, p2)
stop = timeit.default_timer()
print('Time: %.2f s. ' % (stop - start))

# In[ ]:

save_path = result_dir + '/%s_result' % target

if not os.path.exists(save_path):
    os.makedirs(save_path)

score_save_file = os.path.join(save_path, "%s.npy" % job)
np.save(score_save_file, score)

# # Step 5: Get ROC Curves and TPR@FPR Table

# In[ ]:

files = [score_save_file]
methods = []
scores = []
for file in files:
    methods.append(Path(file).stem)
    scores.append(np.load(file))

methods = np.array(methods)
scores = dict(zip(methods, scores))
colours = dict(
    zip(methods, sample_colours_from_colourmap(methods.shape[0], 'Set2')))
# x_labels = [1/(10**x) for x in np.linspace(6, 0, 6)]
x_labels = [10**-6, 10**-5, 10**-4, 10**-3, 10**-2, 10**-1]
tpr_fpr_table = PrettyTable(['Methods'] + [str(x) for x in x_labels])
fig = plt.figure()
for method in methods:
    fpr, tpr, _ = roc_curve(label, scores[method])
    roc_auc = auc(fpr, tpr)
    fpr = np.flipud(fpr)
    tpr = np.flipud(tpr)  # select largest tpr at same fpr
    plt.plot(fpr,
             tpr,
             color=colours[method],
             lw=1,
             label=('[%s (AUC = %0.4f %%)]' %
                    (method.split('-')[-1], roc_auc * 100)))
    tpr_fpr_row = []
    tpr_fpr_row.append("%s-%s" % (method, target))
    for fpr_iter in np.arange(len(x_labels)):
        _, min_index = min(
            list(zip(abs(fpr - x_labels[fpr_iter]), range(len(fpr)))))
        # tpr_fpr_row.append('%.4f' % tpr[min_index])
        tpr_fpr_row.append('%.2f' % (tpr[min_index] * 100))
    tpr_fpr_table.add_row(tpr_fpr_row)
plt.xlim([10**-6, 0.1])
plt.ylim([0.3, 1.0])
plt.grid(linestyle='--', linewidth=1)
plt.xticks(x_labels)
plt.yticks(np.linspace(0.3, 1.0, 8, endpoint=True))
plt.xscale('log')
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.title('ROC on IJB')
plt.legend(loc="lower right")
# plt.show()
fig.savefig(os.path.join(save_path, '%s.pdf' % job))
print(tpr_fpr_table)
@@ -1,103 +0,0 @@
import os
import matplotlib
import numpy as np
import pandas as pd
from pathlib import Path
from prettytable import PrettyTable
from sklearn.metrics import roc_curve, auc

matplotlib.use('Agg')
import matplotlib.pyplot as plt

from menpo.visualize import print_progress
from menpo.visualize.viewmatplotlib import sample_colours_from_colourmap

target = 'IJBC'
job = 'IJBC'
title = 'IJB-C'

root = '/train/trainset/1'
score_save_file_1 = '{}/glint-face/IJB/result/celeb360kfinal0.1/{}_result/cosface.npy'.format(
    root, target)
score_save_file_2 = '{}/glint-face/IJB/result/celeb360kfinal1.0/{}_result/cosface.npy'.format(
    root, target)
score_save_file_3 = '{}/glint-face/IJB/result/emore0.4/{}_result/cosface.npy'.format(
    root, target)
score_save_file_4 = '{}/glint-face/IJB/result/emore0.8/{}_result/cosface.npy'.format(
    root, target)
score_save_file_5 = '{}/glint-face/IJB/result/emore1.0/{}_result/cosface.npy'.format(
    root, target)
score_save_file_6 = '{}/glint-face/IJB/result/retina1.0/{}_result/cosface.npy'.format(
    root, target)

save_path = '{}/glint-face/IJB'.format(root)
image_path = '{}/face/IJB_release/{}'.format(root, target)
methods = [
    'celeb360k_final0.1, S=0.1', 'celeb360k_final1.0, S=1.0', 'MS1MV2, S=0.4',
    'MS1MV2, S=0.8', 'MS1MV2, S=1.0', 'RETINA, S=1.0'
]
files = [
    score_save_file_1, score_save_file_2, score_save_file_3, score_save_file_4,
    score_save_file_5, score_save_file_6
]


def read_template_pair_list(path):
    pairs = pd.read_csv(path, sep=' ', header=None).values
    # print(pairs.shape)
    # print(pairs[:, 0].astype(np.int))
    t1 = pairs[:, 0].astype(np.int)
    t2 = pairs[:, 1].astype(np.int)
    label = pairs[:, 2].astype(np.int)
    return t1, t2, label


p1, p2, label = read_template_pair_list(
    os.path.join('%s/meta' % image_path,
                 '%s_template_pair_label.txt' % target.lower()))

scores = []
for file in files:
    scores.append(np.load(file))

methods = np.array(methods)
scores = dict(zip(methods, scores))
colours = dict(
    zip(methods, sample_colours_from_colourmap(methods.shape[0], 'Set2')))
# x_labels = [1/(10**x) for x in np.linspace(6, 0, 6)]
x_labels = [10**-6, 10**-5, 10**-4, 10**-3, 10**-2, 10**-1]
tpr_fpr_table = PrettyTable(['Methods'] + [str(x) for x in x_labels])
fig = plt.figure()
for method in methods:
    fpr, tpr, _ = roc_curve(label, scores[method])
    roc_auc = auc(fpr, tpr)
    fpr = np.flipud(fpr)
    tpr = np.flipud(tpr)  # select largest tpr at same fpr
    plt.plot(
        fpr,
        tpr,
        color=colours[method],
        lw=1,
        # label=('[%s (AUC = %0.4f %%)]' % (method.split('-')[-1], roc_auc * 100))
        label=method)
    tpr_fpr_row = []
    tpr_fpr_row.append("%s-%s" % (method, target))
    for fpr_iter in np.arange(len(x_labels)):
        _, min_index = min(
            list(zip(abs(fpr - x_labels[fpr_iter]), range(len(fpr)))))
        # tpr_fpr_row.append('%.4f' % tpr[min_index])
        tpr_fpr_row.append('%.2f' % (tpr[min_index] * 100))
    tpr_fpr_table.add_row(tpr_fpr_row)
plt.xlim([10**-6, 0.1])
plt.ylim([0.30, 1.0])
plt.grid(linestyle='--', linewidth=1)
plt.xticks(x_labels)
plt.yticks(np.linspace(0.3, 1.0, 8, endpoint=True))
plt.xscale('log')
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.title('ROC on {}'.format(title))
plt.legend(loc="lower right")
# plt.show()
fig.savefig(os.path.join(save_path, '%s.pdf' % job))
print(tpr_fpr_table)
@@ -1,404 +0,0 @@
|
||||
# coding: utf-8
|
||||
|
||||
import os
|
||||
import numpy as np
|
||||
# import cPickle
|
||||
import pickle
|
||||
import pandas as pd
|
||||
import matplotlib
|
||||
|
||||
matplotlib.use('Agg')
|
||||
import matplotlib.pyplot as plt
|
||||
import timeit
|
||||
import sklearn
|
||||
import argparse
|
||||
from sklearn.metrics import roc_curve, auc
|
||||
from sklearn import preprocessing
|
||||
import cv2
|
||||
import sys
|
||||
import glob
|
||||
|
||||
sys.path.append('./recognition')
|
||||
from embedding import Embedding
|
||||
from menpo.visualize import print_progress
|
||||
from menpo.visualize.viewmatplotlib import sample_colours_from_colourmap
|
||||
from prettytable import PrettyTable
|
||||
from pathlib import Path
|
||||
import warnings
|
||||
import horovod as hvd
|
||||
|
||||
warnings.filterwarnings("ignore")
|
||||
|
||||
parser = argparse.ArgumentParser(description='do ijb test')
|
||||
# general
|
||||
parser.add_argument('--model-prefix', default='', help='path to load model.')
|
||||
parser.add_argument('--model-epoch', default=1, type=int, help='')
|
||||
parser.add_argument('--image-path', default='', type=str, help='')
|
||||
parser.add_argument('--gpu', default=7, type=int, help='gpu id')
|
||||
parser.add_argument('--batch-size', default=32, type=int, help='')
|
||||
parser.add_argument('--job', default='insightface', type=str, help='job name')
|
||||
parser.add_argument('--target',
|
||||
default='IJBC',
|
||||
type=str,
|
||||
help='target, set to IJBC or IJBB')
|
||||
args = parser.parse_args()
|
||||
|
||||
target = args.target
|
||||
model_path = args.model_prefix
|
||||
image_path = args.image_path
|
||||
gpu_id = args.gpu
|
||||
epoch = args.model_epoch
|
||||
use_norm_score = True # if Ture, TestMode(N1)
|
||||
use_detector_score = True # if Ture, TestMode(D1)
|
||||
use_flip_test = True # if Ture, TestMode(F1)
|
||||
job = args.job
|
||||
|
||||
# initialize Horovod
|
||||
hvd.init()
|
||||
rank_size = hvd.size()
|
||||
|
||||
|
||||
# 将一个list尽量均分成n份,限制len(list)==n,份数大于原list内元素个数则分配空list[]
|
||||
def divideIntoNstrand(listTemp, n):
|
||||
twoList = [[] for i in range(n)]
|
||||
for i, e in enumerate(listTemp):
|
||||
twoList[i % n].append(e)
|
||||
return twoList

def read_template_media_list(path):
    # ijb_meta = np.loadtxt(path, dtype=str)
    ijb_meta = pd.read_csv(path, sep=' ', header=None).values
    templates = ijb_meta[:, 1].astype(np.int)
    medias = ijb_meta[:, 2].astype(np.int)
    return templates, medias


# In[ ]:


def read_template_pair_list(path):
    # pairs = np.loadtxt(path, dtype=str)
    pairs = pd.read_csv(path, sep=' ', header=None).values
    # print(pairs.shape)
    # print(pairs[:, 0].astype(np.int))
    t1 = pairs[:, 0].astype(np.int)
    t2 = pairs[:, 1].astype(np.int)
    label = pairs[:, 2].astype(np.int)
    return t1, t2, label


# In[ ]:


def read_image_feature(path):
    with open(path, 'rb') as fid:
        img_feats = pickle.load(fid)
    return img_feats


# In[ ]:


def get_image_feature(img_path, files_list, model_path, epoch, gpu_id):
    embedding = Embedding(model_path, epoch, gpu_id)
    files = files_list[hvd.rank()]
    print('files:', len(files))
    faceness_scores = []
    img_feats = []
    for img_index, each_line in enumerate(files):
        if img_index % 500 == 0:
            print('processing', img_index)
        name_lmk_score = each_line.strip().split(' ')
        img_name = os.path.join(img_path, name_lmk_score[0])
        img = cv2.imread(img_name)
        lmk = np.array([float(x) for x in name_lmk_score[1:-1]],
                       dtype=np.float32)
        lmk = lmk.reshape((5, 2))
        img_feats.append(embedding.get(img, lmk))
        faceness_scores.append(name_lmk_score[-1])
    # each row stores the features of an image and its horizontal flip,
    # concatenated, which is why the halves are split again further below
    img_feats = np.array(img_feats).astype(np.float32)
    faceness_scores = np.array(faceness_scores).astype(np.float32)

    # img_feats = np.ones( (len(files), 1024), dtype=np.float32) * 0.01
    # faceness_scores = np.ones( (len(files), ), dtype=np.float32 )
    return img_feats, faceness_scores


# In[ ]:


def image2template_feature(img_feats=None, templates=None, medias=None):
    # ==========================================================
    # 1. face image feature l2 normalization. img_feats:[number_image x feats_dim]
    # 2. compute media feature.
    # 3. compute template feature.
    # ==========================================================
    unique_templates = np.unique(templates)
    template_feats = np.zeros((len(unique_templates), img_feats.shape[1]))

    for count_template, uqt in enumerate(unique_templates):
        (ind_t, ) = np.where(templates == uqt)
        face_norm_feats = img_feats[ind_t]
        face_medias = medias[ind_t]
        unique_medias, unique_media_counts = np.unique(face_medias,
                                                       return_counts=True)
        media_norm_feats = []
        for u, ct in zip(unique_medias, unique_media_counts):
            (ind_m, ) = np.where(face_medias == u)
            if ct == 1:
                media_norm_feats += [face_norm_feats[ind_m]]
            else:  # image features from the same video will be aggregated into one feature
                media_norm_feats += [
                    np.mean(face_norm_feats[ind_m], axis=0, keepdims=True)
                ]
        media_norm_feats = np.array(media_norm_feats)
        # media_norm_feats = media_norm_feats / np.sqrt(np.sum(media_norm_feats ** 2, -1, keepdims=True))
        template_feats[count_template] = np.sum(media_norm_feats, axis=0)
        if count_template % 2000 == 0:
            print('Finish Calculating {} template features.'.format(
                count_template))
    # template_norm_feats = template_feats / np.sqrt(np.sum(template_feats ** 2, -1, keepdims=True))
    template_norm_feats = sklearn.preprocessing.normalize(template_feats)
    # print(template_norm_feats.shape)
    return template_norm_feats, unique_templates


# In[ ]:


def verification(template_norm_feats=None,
                 unique_templates=None,
                 p1=None,
                 p2=None):
    # ==========================================================
    # Compute set-to-set Similarity Score.
    # ==========================================================
    template2id = np.zeros((max(unique_templates) + 1, 1), dtype=int)
    for count_template, uqt in enumerate(unique_templates):
        template2id[uqt] = count_template

    score = np.zeros((len(p1), ))  # save cosine distance between pairs

    total_pairs = np.array(range(len(p1)))
    batchsize = 100000  # small batchsize instead of all pairs in one batch due to the memory limitation
    sublists = [
        total_pairs[i:i + batchsize] for i in range(0, len(p1), batchsize)
    ]
    total_sublists = len(sublists)
    for c, s in enumerate(sublists):
        feat1 = template_norm_feats[template2id[p1[s]]]
        feat2 = template_norm_feats[template2id[p2[s]]]
        similarity_score = np.sum(feat1 * feat2, -1)
        score[s] = similarity_score.flatten()
        if c % 10 == 0:
            print('Finish {}/{} pairs.'.format(c, total_sublists))
    return score


# In[ ]:
def verification2(template_norm_feats=None,
                  unique_templates=None,
                  p1=None,
                  p2=None):
    template2id = np.zeros((max(unique_templates) + 1, 1), dtype=int)
    for count_template, uqt in enumerate(unique_templates):
        template2id[uqt] = count_template
    score = np.zeros((len(p1), ))  # save cosine distance between pairs
    total_pairs = np.array(range(len(p1)))
    batchsize = 100000  # small batchsize instead of all pairs in one batch due to the memory limitation
    sublists = [
        total_pairs[i:i + batchsize] for i in range(0, len(p1), batchsize)
    ]
    total_sublists = len(sublists)
    for c, s in enumerate(sublists):
        feat1 = template_norm_feats[template2id[p1[s]]]
        feat2 = template_norm_feats[template2id[p2[s]]]
        similarity_score = np.sum(feat1 * feat2, -1)
        score[s] = similarity_score.flatten()
        if c % 10 == 0:
            print('Finish {}/{} pairs.'.format(c, total_sublists))
    return score


def read_score(path):
    with open(path, 'rb') as fid:
        img_feats = pickle.load(fid)
    return img_feats


# # Step1: Load Meta Data

# In[ ]:

assert target == 'IJBC' or target == 'IJBB'

# =============================================================
# load image and template relationships for template feature embedding
# tid --> template id, mid --> media id
# format:
#     image_name tid mid
# =============================================================
start = timeit.default_timer()
templates, medias = read_template_media_list(
    os.path.join('%s/meta' % image_path,
                 '%s_face_tid_mid.txt' % target.lower()))
stop = timeit.default_timer()
print('Time: %.2f s. ' % (stop - start))

# In[ ]:

# =============================================================
# load template pairs for template-to-template verification
# tid : template id, label : 1/0
# format:
#     tid_1 tid_2 label
# =============================================================
start = timeit.default_timer()
p1, p2, label = read_template_pair_list(
    os.path.join('%s/meta' % image_path,
                 '%s_template_pair_label.txt' % target.lower()))
stop = timeit.default_timer()
print('Time: %.2f s. ' % (stop - start))

# # Step 2: Get Image Features

# In[ ]:

# =============================================================
# load image features
# format:
#     img_feats: [image_num x feats_dim] (227630, 512)
# =============================================================
start = timeit.default_timer()
img_path = '%s/loose_crop' % image_path
img_list_path = '%s/meta/%s_name_5pts_score.txt' % (image_path, target.lower())
img_list = open(img_list_path)
files = img_list.readlines()
files_list = divideIntoNstrand(files, rank_size)

# img_feats
# for i in range(rank_size):
img_feats, faceness_scores = get_image_feature(img_path, files_list,
                                               model_path, epoch, gpu_id)
stop = timeit.default_timer()
print('Time: %.2f s. ' % (stop - start))
print('Feature Shape: ({} , {}) .'.format(img_feats.shape[0],
                                          img_feats.shape[1]))

# # Step3: Get Template Features

# In[ ]:

# =============================================================
# compute template features from image features.
# =============================================================
start = timeit.default_timer()
# ==========================================================
# Norm feature before aggregation into template feature?
# Feature norm from embedding network and faceness score are able to decrease weights for noise samples (not face).
# ==========================================================
# 1. FaceScore (Feature Norm)
# 2. FaceScore (Detector)

if use_flip_test:
    # concat --- F1
    # img_input_feats = img_feats
    # add --- F2
    img_input_feats = img_feats[:, 0:img_feats.shape[1] //
                                2] + img_feats[:, img_feats.shape[1] // 2:]
else:
    img_input_feats = img_feats[:, 0:img_feats.shape[1] // 2]

if use_norm_score:
    img_input_feats = img_input_feats
else:
    # normalise features to remove norm information
    img_input_feats = img_input_feats / np.sqrt(
        np.sum(img_input_feats**2, -1, keepdims=True))

if use_detector_score:
    print(img_input_feats.shape, faceness_scores.shape)
    # img_input_feats = img_input_feats * np.matlib.repmat(faceness_scores[:,np.newaxis], 1, img_input_feats.shape[1])
    img_input_feats = img_input_feats * faceness_scores[:, np.newaxis]
else:
    img_input_feats = img_input_feats

template_norm_feats, unique_templates = image2template_feature(
    img_input_feats, templates, medias)
stop = timeit.default_timer()
print('Time: %.2f s. ' % (stop - start))

# # Step 4: Get Template Similarity Scores

# In[ ]:

# =============================================================
# compute verification scores between template pairs.
# =============================================================
start = timeit.default_timer()
score = verification(template_norm_feats, unique_templates, p1, p2)
stop = timeit.default_timer()
print('Time: %.2f s. ' % (stop - start))

# In[ ]:

save_path = './%s_result' % target

if not os.path.exists(save_path):
    os.makedirs(save_path)

score_save_file = os.path.join(save_path, "%s.npy" % job)
np.save(score_save_file, score)

# # Step 5: Get ROC Curves and TPR@FPR Table

# In[ ]:

files = [score_save_file]
methods = []
scores = []
for file in files:
    methods.append(Path(file).stem)
    scores.append(np.load(file))

methods = np.array(methods)
scores = dict(zip(methods, scores))
colours = dict(
    zip(methods, sample_colours_from_colourmap(methods.shape[0], 'Set2')))
# x_labels = [1/(10**x) for x in np.linspace(6, 0, 6)]
x_labels = [10**-6, 10**-5, 10**-4, 10**-3, 10**-2, 10**-1]
tpr_fpr_table = PrettyTable(['Methods'] + [str(x) for x in x_labels])
fig = plt.figure()
for method in methods:
    fpr, tpr, _ = roc_curve(label, scores[method])
    roc_auc = auc(fpr, tpr)
    fpr = np.flipud(fpr)
    tpr = np.flipud(tpr)  # select largest tpr at same fpr
    plt.plot(fpr,
             tpr,
             color=colours[method],
             lw=1,
             label=('[%s (AUC = %0.4f %%)]' %
                    (method.split('-')[-1], roc_auc * 100)))
    tpr_fpr_row = []
    tpr_fpr_row.append("%s-%s" % (method, target))
    for fpr_iter in np.arange(len(x_labels)):
        _, min_index = min(
            list(zip(abs(fpr - x_labels[fpr_iter]), range(len(fpr)))))
        # tpr_fpr_row.append('%.4f' % tpr[min_index])
        tpr_fpr_row.append('%.2f' % (tpr[min_index] * 100))
    tpr_fpr_table.add_row(tpr_fpr_row)
plt.xlim([10**-6, 0.1])
plt.ylim([0.3, 1.0])
plt.grid(linestyle='--', linewidth=1)
plt.xticks(x_labels)
plt.yticks(np.linspace(0.3, 1.0, 8, endpoint=True))
plt.xscale('log')
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.title('ROC on IJB')
plt.legend(loc="lower right")
# plt.show()
fig.savefig(os.path.join(save_path, '%s.pdf' % job))
print(tpr_fpr_table)
@@ -1,366 +0,0 @@
#!/usr/bin/env python
# coding: utf-8
import os
import numpy as np
import timeit
import sklearn
import cv2
import sys
import argparse
import glob
import pickle  # needed by read_score below; the original called cPickle.load without importing it
import numpy.matlib
import heapq
import math
from datetime import datetime as dt

from sklearn import preprocessing
sys.path.append('./recognition')
from embedding import Embedding
from menpo.visualize import print_progress
from menpo.visualize.viewmatplotlib import sample_colours_from_colourmap


def read_template_subject_id_list(path):
    ijb_meta = np.loadtxt(path, dtype=str, skiprows=1, delimiter=',')
    templates = ijb_meta[:, 0].astype(np.int)
    subject_ids = ijb_meta[:, 1].astype(np.int)
    return templates, subject_ids


def read_template_media_list(path):
    ijb_meta = np.loadtxt(path, dtype=str)
    templates = ijb_meta[:, 1].astype(np.int)
    medias = ijb_meta[:, 2].astype(np.int)
    return templates, medias


def read_template_pair_list(path):
    pairs = np.loadtxt(path, dtype=str)
    t1 = pairs[:, 0].astype(np.int)
    t2 = pairs[:, 1].astype(np.int)
    label = pairs[:, 2].astype(np.int)
    return t1, t2, label


# def get_image_feature(feature_path, faceness_path):
#     img_feats = np.loadtxt(feature_path)
#     faceness_scores = np.loadtxt(faceness_path)
#     return img_feats, faceness_scores
def get_image_feature(img_path, img_list_path, model_path, epoch, gpu_id):
    img_list = open(img_list_path)
    embedding = Embedding(model_path, epoch, gpu_id)
    files = img_list.readlines()
    print('files:', len(files))
    faceness_scores = []
    img_feats = []
    for img_index, each_line in enumerate(files):
        if img_index % 500 == 0:
            print('processing', img_index)
        name_lmk_score = each_line.strip().split(' ')
        img_name = os.path.join(img_path, name_lmk_score[0])
        img = cv2.imread(img_name)
        lmk = np.array([float(x) for x in name_lmk_score[1:-1]],
                       dtype=np.float32)
        lmk = lmk.reshape((5, 2))
        img_feats.append(embedding.get(img, lmk))
        faceness_scores.append(name_lmk_score[-1])
    img_feats = np.array(img_feats).astype(np.float32)
    faceness_scores = np.array(faceness_scores).astype(np.float32)

    # img_feats = np.ones( (len(files), 1024), dtype=np.float32) * 0.01
    # faceness_scores = np.ones( (len(files), ), dtype=np.float32 )
    return img_feats, faceness_scores


def image2template_feature(img_feats=None,
                           templates=None,
                           medias=None,
                           choose_templates=None,
                           choose_ids=None):
    # ==========================================================
    # 1. face image feature l2 normalization. img_feats:[number_image x feats_dim]
    # 2. compute media feature.
    # 3. compute template feature.
    # ==========================================================
    unique_templates, indices = np.unique(choose_templates, return_index=True)
    unique_subjectids = choose_ids[indices]
    template_feats = np.zeros((len(unique_templates), img_feats.shape[1]))

    for count_template, uqt in enumerate(unique_templates):
        (ind_t, ) = np.where(templates == uqt)
        face_norm_feats = img_feats[ind_t]
        face_medias = medias[ind_t]
        unique_medias, unique_media_counts = np.unique(face_medias,
                                                       return_counts=True)
        media_norm_feats = []
        for u, ct in zip(unique_medias, unique_media_counts):
            (ind_m, ) = np.where(face_medias == u)
            if ct == 1:
                media_norm_feats += [face_norm_feats[ind_m]]
            else:  # image features from the same video will be aggregated into one feature
                media_norm_feats += [
                    np.mean(face_norm_feats[ind_m], 0, keepdims=True)
                ]
        media_norm_feats = np.array(media_norm_feats)
        # media_norm_feats = media_norm_feats / np.sqrt(np.sum(media_norm_feats ** 2, -1, keepdims=True))
        template_feats[count_template] = np.sum(media_norm_feats, 0)
        if count_template % 2000 == 0:
            print('Finish Calculating {} template features.'.format(
                count_template))
    template_norm_feats = template_feats / np.sqrt(
        np.sum(template_feats**2, -1, keepdims=True))
    return template_norm_feats, unique_templates, unique_subjectids


def verification(template_norm_feats=None,
                 unique_templates=None,
                 p1=None,
                 p2=None):
    # ==========================================================
    # Compute set-to-set Similarity Score.
    # ==========================================================
    template2id = np.zeros((max(unique_templates) + 1, 1), dtype=int)
    for count_template, uqt in enumerate(unique_templates):
        template2id[uqt] = count_template

    score = np.zeros((len(p1), ))  # save cosine distance between pairs

    total_pairs = np.array(range(len(p1)))
    batchsize = 100000  # small batchsize instead of all pairs in one batch due to the memory limitation
    sublists = [
        total_pairs[i:i + batchsize] for i in range(0, len(p1), batchsize)
    ]
    total_sublists = len(sublists)
    for c, s in enumerate(sublists):
        feat1 = template_norm_feats[template2id[p1[s]]]
        feat2 = template_norm_feats[template2id[p2[s]]]
        similarity_score = np.sum(feat1 * feat2, -1)
        score[s] = similarity_score.flatten()
        if c % 10 == 0:
            print('Finish {}/{} pairs.'.format(c, total_sublists))
    return score


def read_score(path):
    with open(path, 'rb') as fid:
        img_feats = pickle.load(fid)  # fixed: was cPickle.load, which was undefined here
    return img_feats


def evaluation(query_feats, gallery_feats, mask):
    # open-set 1:N evaluation: top-k identification accuracy, then TPIR at
    # the false-positive identification rates listed in Fars
    Fars = [0.01, 0.1]
    print(query_feats.shape)
    print(gallery_feats.shape)

    query_num = query_feats.shape[0]
    gallery_num = gallery_feats.shape[0]

    similarity = np.dot(query_feats, gallery_feats.T)
    print('similarity shape', similarity.shape)
    top_inds = np.argsort(-similarity)
    print(top_inds.shape)

    # calculate top1
    correct_num = 0
    for i in range(query_num):
        j = top_inds[i, 0]
        if j == mask[i]:
            correct_num += 1
    print("top1 = {}".format(correct_num / query_num))
    # calculate top5
    correct_num = 0
    for i in range(query_num):
        j = top_inds[i, 0:5]
        if mask[i] in j:
            correct_num += 1
    print("top5 = {}".format(correct_num / query_num))
    # calculate top10
    correct_num = 0
    for i in range(query_num):
        j = top_inds[i, 0:10]
        if mask[i] in j:
            correct_num += 1
    print("top10 = {}".format(correct_num / query_num))

    neg_pair_num = query_num * gallery_num - query_num
    print(neg_pair_num)
    required_topk = [math.ceil(query_num * x) for x in Fars]
    top_sims = similarity
    # calculate fars and tprs
    pos_sims = []
    for i in range(query_num):
        gt = mask[i]
        pos_sims.append(top_sims[i, gt])
        top_sims[i, gt] = -2.0

    pos_sims = np.array(pos_sims)
    print(pos_sims.shape)
    neg_sims = top_sims[np.where(top_sims > -2.0)]
    print("neg_sims num = {}".format(len(neg_sims)))
    neg_sims = heapq.nlargest(max(required_topk), neg_sims)  # heap sort
    print("after sorting, neg_sims num = {}".format(len(neg_sims)))
    for far, pos in zip(Fars, required_topk):
        th = neg_sims[pos - 1]
        recall = np.sum(pos_sims > th) / query_num
        print("far = {:.10f} pr = {:.10f} th = {:.10f}".format(
            far, recall, th))


def gen_mask(query_ids, reg_ids):
    mask = []
    for query_id in query_ids:
        pos = [i for i, x in enumerate(reg_ids) if query_id == x]
        if len(pos) != 1:
            raise RuntimeError(
                "RegIdsError with id = {}, duplicate = {} ".format(
                    query_id, len(pos)))
        mask.append(pos[0])
    return mask


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='do ijb 1n test')
    # general
    parser.add_argument('--model-prefix',
                        default='',
                        help='path to load model.')
    parser.add_argument('--model-epoch', default=1, type=int, help='')
    parser.add_argument('--gpu', default=7, type=int, help='gpu id')
    parser.add_argument('--batch-size', default=32, type=int, help='')
    parser.add_argument('--job',
                        default='insightface',
                        type=str,
                        help='job name')
    parser.add_argument('--target',
                        default='IJBC',
                        type=str,
                        help='target, set to IJBC or IJBB')
    args = parser.parse_args()
    target = args.target
    model_path = args.model_prefix
    gpu_id = args.gpu
    epoch = args.model_epoch
    meta_dir = "%s/meta" % args.target  # meta root dir
    if target == 'IJBC':
        gallery_s1_record = "%s_1N_gallery_G1.csv" % (args.target.lower())
        gallery_s2_record = "%s_1N_gallery_G2.csv" % (args.target.lower())
    else:
        gallery_s1_record = "%s_1N_gallery_S1.csv" % (args.target.lower())
        gallery_s2_record = "%s_1N_gallery_S2.csv" % (args.target.lower())
    gallery_s1_templates, gallery_s1_subject_ids = read_template_subject_id_list(
        os.path.join(meta_dir, gallery_s1_record))
    print(gallery_s1_templates.shape, gallery_s1_subject_ids.shape)

    gallery_s2_templates, gallery_s2_subject_ids = read_template_subject_id_list(
        os.path.join(meta_dir, gallery_s2_record))
    print(gallery_s2_templates.shape, gallery_s2_subject_ids.shape)  # fixed: printed templates.shape twice

    gallery_templates = np.concatenate(
        [gallery_s1_templates, gallery_s2_templates])
    gallery_subject_ids = np.concatenate(
        [gallery_s1_subject_ids, gallery_s2_subject_ids])
    print(gallery_templates.shape, gallery_subject_ids.shape)

    media_record = "%s_face_tid_mid.txt" % args.target.lower()
    total_templates, total_medias = read_template_media_list(
        os.path.join(meta_dir, media_record))
    print("total_templates", total_templates.shape, total_medias.shape)
    # load image features
    start = timeit.default_timer()
    feature_path = ''  # feature path
    face_path = ''  # face path
    img_path = './%s/loose_crop' % target
    img_list_path = './%s/meta/%s_name_5pts_score.txt' % (target,
                                                          target.lower())
    # img_feats, faceness_scores = get_image_feature(feature_path, face_path)
    img_feats, faceness_scores = get_image_feature(img_path, img_list_path,
                                                   model_path, epoch, gpu_id)
    print('img_feats', img_feats.shape)
    print('faceness_scores', faceness_scores.shape)
    stop = timeit.default_timer()
    print('Time: %.2f s. ' % (stop - start))
    print('Feature Shape: ({} , {}) .'.format(img_feats.shape[0],
                                              img_feats.shape[1]))

    # compute template features from image features.
    start = timeit.default_timer()
    # ==========================================================
    # Norm feature before aggregation into template feature?
    # Feature norm from embedding network and faceness score are able to decrease weights for noise samples (not face).
    # ==========================================================
    use_norm_score = True  # if True, TestMode(N1)
    use_detector_score = True  # if True, TestMode(D1)
    use_flip_test = True  # if True, TestMode(F1)

    if use_flip_test:
        # concat --- F1
        # img_input_feats = img_feats
        # add --- F2
        img_input_feats = img_feats[:, 0:int(
            img_feats.shape[1] / 2)] + img_feats[:,
                                                 int(img_feats.shape[1] / 2):]
    else:
        img_input_feats = img_feats[:, 0:int(img_feats.shape[1] / 2)]

    if use_norm_score:
        img_input_feats = img_input_feats
    else:
        # normalise features to remove norm information
        img_input_feats = img_input_feats / np.sqrt(
            np.sum(img_input_feats**2, -1, keepdims=True))

    if use_detector_score:
        img_input_feats = img_input_feats * np.matlib.repmat(
            faceness_scores[:, np.newaxis], 1, img_input_feats.shape[1])
    else:
        img_input_feats = img_input_feats
    print("input features shape", img_input_feats.shape)

    # load gallery features
    gallery_templates_feature, gallery_unique_templates, gallery_unique_subject_ids = image2template_feature(
        img_input_feats, total_templates, total_medias, gallery_templates,
        gallery_subject_ids)
    stop = timeit.default_timer()
    print('Time: %.2f s. ' % (stop - start))
    print("gallery_templates_feature", gallery_templates_feature.shape)
    print("gallery_unique_subject_ids", gallery_unique_subject_ids.shape)
    # np.savetxt("gallery_templates_feature.txt", gallery_templates_feature)
    # np.savetxt("gallery_unique_subject_ids.txt", gallery_unique_subject_ids)

    # load probe features
    probe_mixed_record = "%s_1N_probe_mixed.csv" % target.lower()
    probe_mixed_templates, probe_mixed_subject_ids = read_template_subject_id_list(
        os.path.join(meta_dir, probe_mixed_record))
    print(probe_mixed_templates.shape, probe_mixed_subject_ids.shape)
    probe_mixed_templates_feature, probe_mixed_unique_templates, probe_mixed_unique_subject_ids = image2template_feature(
        img_input_feats, total_templates, total_medias, probe_mixed_templates,
        probe_mixed_subject_ids)
    print("probe_mixed_templates_feature", probe_mixed_templates_feature.shape)
    print("probe_mixed_unique_subject_ids",
          probe_mixed_unique_subject_ids.shape)
    # np.savetxt("probe_mixed_templates_feature.txt", probe_mixed_templates_feature)
    # np.savetxt("probe_mixed_unique_subject_ids.txt", probe_mixed_unique_subject_ids)

    # root_dir = ""  # feature root dir
    # gallery_id_path = ""  # id filepath
    # gallery_feats_path = ""  # feature filepath
    # print("{}: start loading gallery feat {}".format(dt.now(), gallery_id_path))
    # gallery_ids, gallery_feats = load_feat_file(root_dir, gallery_id_path, gallery_feats_path)
    # print("{}: end loading gallery feat".format(dt.now()))
    #
    # probe_id_path = "probe_mixed_unique_subject_ids.txt"  # probe id filepath
    # probe_feats_path = "probe_mixed_templates_feature.txt"  # probe feats filepath
    # print("{}: start loading probe feat {}".format(dt.now(), probe_id_path))
    # probe_ids, probe_feats = load_feat_file(root_dir, probe_id_path, probe_feats_path)
    # print("{}: end loading probe feat".format(dt.now()))

    gallery_ids = gallery_unique_subject_ids
    gallery_feats = gallery_templates_feature
    probe_ids = probe_mixed_unique_subject_ids
    probe_feats = probe_mixed_templates_feature

    mask = gen_mask(probe_ids, gallery_ids)

    print("{}: start evaluation".format(dt.now()))
    evaluation(probe_feats, gallery_feats, mask)
    print("{}: end evaluation".format(dt.now()))
@@ -1,36 +0,0 @@
To reproduce the figures and tables in the notebook, please download everything (model, code, data and meta info) from here:
[Dropbox](https://www.dropbox.com/s/33a6haw7v79e5qe/IJB_release.tar?dl=0)
or
[Baidu Cloud](https://pan.baidu.com/s/1oer0p4_mcOrs4cfdeWfbFg)

Updated meta data (1:1 and 1:N):
[IJB-B Dropbox](https://www.dropbox.com/s/5n2ehrsucmu7vsd/IJBB_meta.tar?dl=0)
and
[IJB-C Dropbox](https://www.dropbox.com/s/pgju50f2gcgqkc2/IJBC_meta.tar?dl=0)

Please apply for IJB-B and IJB-C yourself and strictly follow their distribution licenses.

## Acknowledgement

Great thanks to Weidi Xie for his instructions [2,3,4,5] on evaluating ArcFace [1] on IJB-B [6] and IJB-C [7] (1:1 protocol).

Great thanks to Yuge Huang for his code [8] for evaluating ArcFace [1] on IJB-B [6] and IJB-C [7] (1:N protocol).

## Reference

[1] Jiankang Deng, Jia Guo, Niannan Xue, Stefanos Zafeiriou. ArcFace: Additive angular margin loss for deep face recognition. arXiv:1801.07698, 2018.

[2] https://github.com/ox-vgg/vgg_face2.

[3] Qiong Cao, Li Shen, Weidi Xie, Omkar M Parkhi, Andrew Zisserman. VGGFace2: A dataset for recognising faces across pose and age. FG, 2018.

[4] Weidi Xie, Andrew Zisserman. Multicolumn Networks for Face Recognition. BMVC, 2018.

[5] Weidi Xie, Li Shen, Andrew Zisserman. Comparator Networks. ECCV, 2018.

[6] Whitelam, Cameron, Emma Taborsky, Austin Blanton, Brianna Maze, Jocelyn C. Adams, Tim Miller, Nathan D. Kalka et al. IARPA Janus Benchmark-B Face Dataset. CVPR Workshops, 2017.

[7] Maze, Brianna, Jocelyn Adams, James A. Duncan, Nathan Kalka, Tim Miller, Charles Otto, Anil K. Jain et al. IARPA Janus Benchmark-C: Face Dataset and Protocol. ICB, 2018.

[8] Yuge Huang, Pengcheng Shen, Ying Tai, Shaoxin Li, Xiaoming Liu, Jilin Li, Feiyue Huang, Rongrong Ji. Distribution Distillation Loss: Generic Approach for Improving Face Recognition from Hard Samples. arXiv:2002.03662.

@@ -1,17 +0,0 @@
#!/usr/bin/env bash

export CUDA_VISIBLE_DEVICES='0,1,2,3,4,5,6,7'
METHOD='webface'
root='/train/trainset/1'
IJB='IJBB'
GPU_ID=1
echo ${root}/glint-face/IJB/result/${METHOD}

cd IJB
/usr/bin/python3 -u IJB_11_Batch.py --model-prefix /root/xy/work_dir/xyface/models/32backbone.pth \
    --image-path ${root}/face/IJB_release/${IJB} \
    --result-dir ${root}/glint-face/IJB/result/${METHOD} \
    --model-epoch 0 --gpu ${GPU_ID} \
    --target ${IJB} --job cosface \
    --batch-size 2096
cd ..
@@ -1,130 +0,0 @@
import os
import matplotlib
import numpy as np
import pandas as pd
from pathlib import Path
from prettytable import PrettyTable
from sklearn.metrics import roc_curve, auc

matplotlib.use('Agg')
import matplotlib.pyplot as plt

from menpo.visualize import print_progress
from menpo.visualize.viewmatplotlib import sample_colours_from_colourmap
version = 'C'

target = 'IJB' + version
job = 'IJB' + version
title = 'IJB-' + version

root = '/train/trainset/1'

#retina = '{}/glint-face/IJB/result/retina1.0/{}_result/cosface.npy'.format(root, target)
#glint_retinaface_fp16 = '{}/glint-face/IJB/result/glint_retinaface_fp16/{}_result/cosface.npy'.format(root, target)
retina_fp16_10_percents = '{}/glint-face/IJB/result/glint_retinaface_fp16_0.1/{}_result/arcface.npy'.format(
    root, target)
retina_fp32_10_percents = '{}/glint-face/IJB/result/retina_0.1_fp32/{}_result/cosface.npy'.format(
    root, target)
retina_fp16 = '{}/glint-face/IJB/result/glint_retinaface_fp16/{}_result/cosface.npy'.format(
    root, target)
celeb360k_final = '{}/glint-face/IJB/result/celeb360kfinal1.0/{}_result/cosface.npy'.format(
    root, target)
celeb360k_final_10_percents = '{}/glint-face/IJB/result/celeb360kfinal0.1/{}_result/cosface.npy'.format(
    root, target)
retina_4GPU = '{}/glint-face/IJB/result/anxiang_ms1m_retina/{}_result/cosface.npy'.format(
    root, target)
retina_4GPU_scale2 = '{}/glint-face/IJB/result/anxiang_retina_largelr/{}_result/cosface.npy'.format(
    root, target)
emore_percents_10 = '{}/glint-face/IJB/result/emore0.1/{}_result/cosface.npy'.format(
    root, target)
emore_percents_40 = '{}/glint-face/IJB/result/emore0.4/{}_result/cosface.npy'.format(
    root, target)
emore_percents_80 = '{}/glint-face/IJB/result/emore0.8/{}_result/cosface.npy'.format(
    root, target)
#emore_percents_10 = '{}/glint-face/IJB/result/emore0.1/{}_result/cosface.npy'.format(root, target)
#emore_percents_10 = '{}/glint-face/IJB/result/emore_cosface_0.1_margin0.45/{}_result/cosface.npy'.format(root, target)
emore = '{}/glint-face/IJB/result/emore1.0/{}_result/cosface.npy'.format(
    root, target)

#celeb360k_0_1 = '{}/glint-face/IJB/result/celeb360k_0.1/{}_result/cosface.npy'.format(root, target)
#celeb360k_1_0_1 = '{}/glint-face/IJB/result/celeb360k/{}_result/cosface.npy'.format(root, target)

save_path = '{}/glint-face/IJB'.format(root)
image_path = '{}/face/IJB_release/{}'.format(root, target)
methods = [
    'retina_fp16', 'retina_fp16_0.1', 'retina_fp32_0.1', 'celeb360k_final',
    'celeb360k_final_10_percents'
]
methods = ['retina_4GPU', 'retina_4GPU_scale2']
methods = ['emore', 'emore_percents_10']
methods = [
    'emore', 'emore_percents_10', 'emore_percents_40', 'emore_percents_80'
]
files = [
    retina_fp16, retina_fp16_10_percents, retina_fp32_10_percents,
    celeb360k_final, celeb360k_final_10_percents
]
#files = [retina_4GPU, retina_4GPU_scale2]
#files = [emore, emore_percents_10]
files = [emore, emore_percents_10, emore_percents_40, emore_percents_80]


def read_template_pair_list(path):
    pairs = pd.read_csv(path, sep=' ', header=None).values
    # print(pairs.shape)
    # print(pairs[:, 0].astype(np.int))
    t1 = pairs[:, 0].astype(np.int)
    t2 = pairs[:, 1].astype(np.int)
    label = pairs[:, 2].astype(np.int)
    return t1, t2, label


p1, p2, label = read_template_pair_list(
    os.path.join('%s/meta' % image_path,
                 '%s_template_pair_label.txt' % target.lower()))

scores = []
for file in files:
    scores.append(np.load(file))

methods = np.array(methods)
scores = dict(zip(methods, scores))
colours = dict(
    zip(methods, sample_colours_from_colourmap(methods.shape[0], 'Set2')))
# x_labels = [1/(10**x) for x in np.linspace(6, 0, 6)]
x_labels = [10**-6, 10**-5, 10**-4, 10**-3, 10**-2, 10**-1]
tpr_fpr_table = PrettyTable(['Methods'] + [str(x) for x in x_labels])
fig = plt.figure()
for method in methods:
    fpr, tpr, _ = roc_curve(label, scores[method])
    roc_auc = auc(fpr, tpr)
    fpr = np.flipud(fpr)
    tpr = np.flipud(tpr)  # select largest tpr at same fpr
    plt.plot(
        fpr,
        tpr,
        color=colours[method],
        lw=1,
        # label=('[%s (AUC = %0.4f %%)]' % (method.split('-')[-1], roc_auc * 100))
        label=method)
    tpr_fpr_row = []
    tpr_fpr_row.append("%s-%s" % (method, target))
    for fpr_iter in np.arange(len(x_labels)):
        _, min_index = min(
            list(zip(abs(fpr - x_labels[fpr_iter]), range(len(fpr)))))
        # tpr_fpr_row.append('%.4f' % tpr[min_index])
        tpr_fpr_row.append('%.2f' % (tpr[min_index] * 100))
    tpr_fpr_table.add_row(tpr_fpr_row)
plt.xlim([10**-6, 0.1])
plt.ylim([0.30, 1.0])
plt.grid(linestyle='--', linewidth=1)
plt.xticks(x_labels)
plt.yticks(np.linspace(0.3, 1.0, 8, endpoint=True))
plt.xscale('log')
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.title('ROC on {}'.format(title))
plt.legend(loc="lower right")
# plt.show()
fig.savefig(os.path.join(save_path, '%s.pdf' % job))
print(tpr_fpr_table)
@@ -1,19 +0,0 @@
To reproduce the figures and tables in the notebook, please download everything (model, code, data and meta info) from here:
[Dropbox] https://www.dropbox.com/s/33a6haw7v79e5qe/IJB_release.tar?dl=0
or
[Baidu Cloud] https://pan.baidu.com/s/1oer0p4_mcOrs4cfdeWfbFg

Please apply for IJB-B and IJB-C yourself and strictly follow their distribution licenses.

Acknowledgement
Great thanks to Weidi Xie for his instructions [2,3,4,5] on evaluating ArcFace [1] on IJB-B [6] and IJB-C [7].

[1] Jiankang Deng, Jia Guo, Niannan Xue, Stefanos Zafeiriou. ArcFace: Additive angular margin loss for deep face recognition. arXiv:1801.07698, 2018.
[2] https://github.com/ox-vgg/vgg_face2.
[3] Qiong Cao, Li Shen, Weidi Xie, Omkar M Parkhi, Andrew Zisserman. VGGFace2: A dataset for recognising faces across pose and age. FG, 2018.
[4] Weidi Xie, Andrew Zisserman. Multicolumn Networks for Face Recognition. BMVC, 2018.
[5] Weidi Xie, Li Shen, Andrew Zisserman. Comparator Networks. ECCV, 2018.
[6] Whitelam, Cameron, Emma Taborsky, Austin Blanton, Brianna Maze, Jocelyn C. Adams, Tim Miller, Nathan D. Kalka et al. IARPA Janus Benchmark-B Face Dataset. CVPR Workshops, 2017.
[7] Maze, Brianna, Jocelyn Adams, James A. Duncan, Nathan Kalka, Tim Miller, Charles Otto, Anil K. Jain et al. IARPA Janus Benchmark-C: Face Dataset and Protocol. ICB, 2018.

@@ -1,82 +0,0 @@
import argparse
import cv2
import numpy as np
import sys
import mxnet as mx
import datetime
from skimage import transform as trans
import sklearn
from sklearn import preprocessing
import torch
from torchvision import transforms
sys.path.append('/root/xy/work_dir/xyface/')
from backbones import iresnet50, iresnet100
from torch.nn.parallel import DistributedDataParallel


class Embedding:
    def __init__(self, prefix, epoch, data_shape, batch_size=1, ctx_id=0):
        print('loading', prefix, epoch)
        image_size = (112, 112)
        self.image_size = image_size
        weight = torch.load(prefix)
        resnet = iresnet50().cuda()
        resnet.load_state_dict(weight)
        model = torch.nn.DataParallel(resnet)
        self.model = model
        self.model.eval()
        # reference 5-point landmark template for similarity alignment
        src = np.array(
            [[30.2946, 51.6963], [65.5318, 51.5014], [48.0252, 71.7366],
             [33.5493, 92.3655], [62.7299, 92.2041]],
            dtype=np.float32)
        src[:, 0] += 8.0
        self.src = src
        self.batch_size = batch_size
        self.data_shape = data_shape

    def get(self, rimg, landmark):

        assert landmark.shape[0] == 68 or landmark.shape[0] == 5
        assert landmark.shape[1] == 2
        if landmark.shape[0] == 68:
            landmark5 = np.zeros((5, 2), dtype=np.float32)
            landmark5[0] = (landmark[36] + landmark[39]) / 2
            landmark5[1] = (landmark[42] + landmark[45]) / 2
            landmark5[2] = landmark[30]
            landmark5[3] = landmark[48]
            landmark5[4] = landmark[54]
        else:
            landmark5 = landmark
        tform = trans.SimilarityTransform()
        tform.estimate(landmark5, self.src)
        M = tform.params[0:2, :]
        img = cv2.warpAffine(rimg,
                             M, (self.image_size[1], self.image_size[0]),
                             borderValue=0.0)
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        img_flip = np.fliplr(img)
        img = np.transpose(img, (2, 0, 1))  # 3*112*112, RGB
        img_flip = np.transpose(img_flip, (2, 0, 1))
        input_blob = np.zeros((2, 3, self.image_size[1], self.image_size[0]),
                              dtype=np.uint8)
        input_blob[0] = img
        input_blob[1] = img_flip
        return input_blob

    @torch.no_grad()
    def forward_db(self, batch_data):
        imgs = torch.Tensor(batch_data).cuda()
        imgs.div_(255).sub_(0.5).div_(0.5)
        feat = self.model(imgs)
        feat = feat.reshape([self.batch_size, 2 * feat.shape[1]])
        return feat.cpu().numpy()


if __name__ == "__main__":
    weight = torch.load('/root/xy/work_dir/xyface/backbone0.pth')
    resnet = iresnet50().cuda()  # fixed: was `resnet50()`, which is undefined in this module
    resnet.load_state_dict(weight)
    res = torch.nn.DataParallel(resnet, [0, 1, 2, 3, 4, 5, 6, 7])
    tin = torch.Tensor(1023, 3, 112, 112).cuda()
    out = res(tin)
    print(out.size())
@@ -1,64 +1,3 @@
# Parital FC
# partial fc

## Results
We employ ResNet100 as the backbone.

### 1. IJB-C results

| Datasets | 1e-05 | 1e-04 | 1e-03 | 1e-02 | 1e-01 |
| :---: | :--- | :--- | :--- | :--- | :--- |
| Glint360K | 95.92 | 97.30 | 98.13 | 98.78 | 99.28 |
| MS1MV2 | 94.22 | 96.27 | 97.61 | 98.34 | 99.08 |

### 2. IFRT results

TODO

## Training Speed Benchmark
### 1. Train MS1MV2
Employ **ResNet100** as the backbone.
| GPU | FP16 | BatchSize / it | Throughput img / sec | Time / hours |
| :--- | :--- | :--- | :--- | :--- |
| 8 * Tesla V100-SXM2-32GB | False | 64 | 1658 | 15 |
| 8 * Tesla V100-SXM2-32GB | True | 64 | 2243 | 12 |
| 8 * Tesla V100-SXM2-32GB | False | 128 | 1800 | 14 |
| 8 * Tesla V100-SXM2-32GB | True | 128 | 3337 | 7 |
| 8 * RTX2080Ti | False | | 1200 | |
| 8 * RTX2080Ti | | | | |


Employ **ResNet50** as the backbone.
| GPU | FP16 | BatchSize / it | Throughput img / sec | Time / hours |
| :--- | :--- | :--- | :--- | :--- |
| 8 * Tesla V100-SXM2-32GB | False | 64 | 2745 | 9 |
| 8 * Tesla V100-SXM2-32GB | True | 64 | 3770 | 7 |
| 8 * Tesla V100-SXM2-32GB | False | 128 | 2833 | 9 |
| 8 * Tesla V100-SXM2-32GB | True | 128 | 5102 | 5 |

### 2. Train millions of classes
TODO

## How to run
cuda=10.1
pytorch==1.6.0
pip install -r requirement.txt

```shell
bash run.sh
```
Run it with the command `bash run.sh`.
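
The sampling step at the heart of Partial FC keeps the full class-center matrix in memory but computes the margin softmax over only a random subset of centers that is forced to contain every class present in the current batch. A minimal single-process sketch of that step (illustrative only; the real `DistSampleClassifier` later in this diff shards the centers across GPUs, and `sample_rate` mirrors the config option of the same name):

```python
import torch

def sample_centers(weight, labels, sample_rate=0.1):
    """Pick a subset of class centers that contains every positive class.

    weight: [num_classes, dim] full center matrix
    labels: [batch] ground-truth class ids (int64) for this batch
    Returns the sampled centers and the labels remapped into them.
    """
    num_classes = weight.size(0)
    num_sample = int(sample_rate * num_classes)
    positive = torch.unique(labels, sorted=True)
    if num_sample >= positive.numel():
        # rank all classes by a random score, forcing positives to the top
        perm = torch.rand(num_classes)
        perm[positive] = 2.0
        index = torch.topk(perm, k=num_sample)[1].sort()[0]
    else:
        index = positive
    # remap each label to its position inside the sampled subset
    remapped = torch.searchsorted(index, labels)
    return weight[index], remapped
```

Logits are then computed against the sampled centers only, so the softmax cost scales with `sample_rate` rather than with the total number of identities.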

## Citation
If you find Partial-FC or Glint360K useful in your research, please consider citing the following related paper:

[Partial FC](https://arxiv.org/abs/2010.05222)
```
@inproceedings{an2020partical_fc,
  title={Partial FC: Training 10 Million Identities on a Single Machine},
  author={An, Xiang and Zhu, Xuhan and Xiao, Yang and Wu, Lan and Zhang, Ming and Gao, Yuan and Qin, Bin and
  Zhang, Debing and Fu, Ying},
  booktitle={Arxiv 2010.05222},
  year={2020}
}
```
PartialFC-Pytorch has been merged into [arcface_torch](https://github.com/deepinsight/insightface/tree/master/recognition/arcface_torch).

@@ -1 +0,0 @@
from .iresnet import iresnet34, iresnet50, iresnet100
@@ -1,237 +0,0 @@
import torch
from torch import nn
from torchvision.models.utils import load_state_dict_from_url

__all__ = ['iresnet34', 'iresnet50', 'iresnet100']

model_urls = {
    'iresnet34': 'https://sota.nizhib.ai/insightface/iresnet34-5b0d0e90.pth',
    'iresnet50': 'https://sota.nizhib.ai/insightface/iresnet50-7f187506.pth',
    'iresnet100': 'https://sota.nizhib.ai/insightface/iresnet100-73e07ba7.pth'
}


def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
    """3x3 convolution with padding"""
    return nn.Conv2d(in_planes,
                     out_planes,
                     kernel_size=3,
                     stride=stride,
                     padding=dilation,
                     groups=groups,
                     bias=False,
                     dilation=dilation)


def conv1x1(in_planes, out_planes, stride=1):
    """1x1 convolution"""
    return nn.Conv2d(in_planes,
                     out_planes,
                     kernel_size=1,
                     stride=stride,
                     bias=False)


class IBasicBlock(nn.Module):
    expansion = 1

    def __init__(self,
                 inplanes,
                 planes,
                 stride=1,
                 downsample=None,
                 groups=1,
                 base_width=64,
                 dilation=1):
        super(IBasicBlock, self).__init__()
        if groups != 1 or base_width != 64:
            raise ValueError(
                'BasicBlock only supports groups=1 and base_width=64')
        if dilation > 1:
            raise NotImplementedError(
                "Dilation > 1 not supported in BasicBlock")
        # Both self.conv1 and self.downsample layers downsample the input when stride != 1
        self.bn1 = nn.BatchNorm2d(
            inplanes,
            eps=1e-05,
        )
        self.conv1 = conv3x3(inplanes, planes)
        self.bn2 = nn.BatchNorm2d(
            planes,
            eps=1e-05,
        )
        self.prelu = nn.PReLU(planes)
        self.conv2 = conv3x3(planes, planes, stride)
        self.bn3 = nn.BatchNorm2d(
            planes,
            eps=1e-05,
        )
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        identity = x

        out = self.bn1(x)
        out = self.conv1(out)
        out = self.bn2(out)
        out = self.prelu(out)
        out = self.conv2(out)
        out = self.bn3(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity

        return out


class IResNet(nn.Module):
    fc_scale = 7 * 7

    def __init__(self,
                 block,
                 layers,
                 num_features=512,
                 zero_init_residual=False,
                 groups=1,
                 width_per_group=64,
                 replace_stride_with_dilation=None):
        super(IResNet, self).__init__()

        self.inplanes = 64
        self.dilation = 1
        if replace_stride_with_dilation is None:
            # each element in the tuple indicates if we should replace
            # the 2x2 stride with a dilated convolution instead
            replace_stride_with_dilation = [False, False, False]
        if len(replace_stride_with_dilation) != 3:
            raise ValueError("replace_stride_with_dilation should be None "
                             "or a 3-element tuple, got {}".format(
                                 replace_stride_with_dilation))
        self.groups = groups
        self.base_width = width_per_group
        self.conv1 = nn.Conv2d(3,
                               self.inplanes,
                               kernel_size=3,
                               stride=1,
                               padding=1,
                               bias=False)
        self.bn1 = nn.BatchNorm2d(self.inplanes, eps=1e-05)
        self.prelu = nn.PReLU(self.inplanes)
        self.layer1 = self._make_layer(block, 64, layers[0], stride=2)
        self.layer2 = self._make_layer(block,
                                       128,
                                       layers[1],
                                       stride=2,
                                       dilate=replace_stride_with_dilation[0])
        self.layer3 = self._make_layer(block,
                                       256,
                                       layers[2],
                                       stride=2,
                                       dilate=replace_stride_with_dilation[1])
        self.layer4 = self._make_layer(block,
                                       512,
                                       layers[3],
                                       stride=2,
                                       dilate=replace_stride_with_dilation[2])
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))

        self.bn2 = nn.BatchNorm2d(
            512 * block.expansion,
            eps=1e-05,
        )
        self.dropout = nn.Dropout(p=0.4, inplace=True)
        self.fc = nn.Linear(512 * block.expansion * self.fc_scale,
                            num_features)
        self.features = nn.BatchNorm1d(
            num_features,
            eps=1e-05,
        )

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight,
                                        mode='fan_out',
                                        nonlinearity='relu')
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, IBasicBlock):
                    nn.init.constant_(m.bn2.weight, 0)

    def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
        downsample = None
        previous_dilation = self.dilation
        if dilate:
            self.dilation *= stride
            stride = 1
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                conv1x1(self.inplanes, planes * block.expansion, stride),
                nn.BatchNorm2d(
                    planes * block.expansion,
                    eps=1e-05,
                ),
            )

        layers = []
        layers.append(
            block(self.inplanes, planes, stride, downsample, self.groups,
                  self.base_width, previous_dilation))
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(
                block(self.inplanes,
                      planes,
                      groups=self.groups,
                      base_width=self.base_width,
                      dilation=self.dilation))

        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.prelu(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.bn2(x)
        x = torch.flatten(x, 1)
        # x = self.dropout(x)
        x = self.fc(x)
        x = self.features(x)

        return x


def _iresnet(arch, block, layers, pretrained, progress, **kwargs):
    model = IResNet(block, layers, **kwargs)
    if pretrained:
        state_dict = load_state_dict_from_url(model_urls[arch],
                                              progress=progress)
        model.load_state_dict(state_dict)
    return model


def iresnet34(pretrained=False, progress=True, **kwargs):
    return _iresnet('iresnet34', IBasicBlock, [3, 4, 6, 3], pretrained,
                    progress, **kwargs)


def iresnet50(pretrained=False, progress=True, **kwargs):
    return _iresnet('iresnet50', IBasicBlock, [3, 4, 14, 3], pretrained,
                    progress, **kwargs)


def iresnet100(pretrained=False, progress=True, **kwargs):
    return _iresnet('iresnet100', IBasicBlock, [3, 13, 30, 3], pretrained,
                    progress, **kwargs)
@@ -1,34 +0,0 @@
from easydict import EasyDict as edict

config = edict()
config.dataset = "glint360k"
config.embedding_size = 512
config.sample_rate = 1.0
config.fp16 = False
config.momentum = 0.9
config.weight_decay = 5e-4
config.batch_size = 64
config.lr = 0.1
config.output = "tmp_models"

if config.dataset == "emore":
    config.rec = "/train_tmp/faces_emore"
    config.num_classes = 85742
    config.num_epoch = 16

    def lr_step_func(epoch):
        return ((epoch + 1) / (4 + 1)) ** 2 if epoch < -1 else 0.1 ** len(
            [m for m in [8, 14] if m - 1 <= epoch])
    config.lr_func = lr_step_func
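    # With warmup disabled (epoch < -1 is never true), this schedule keeps
    # lr x 1.0 for epochs 0-6, drops to lr x 0.1 at epoch 7 and to
    # lr x 0.01 at epoch 13.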

elif config.dataset == "glint360k":
    config.rec = "/train_tmp/glint360k"
    config.num_classes = 360232
    config.num_image = 17091657
    config.num_epoch = 17
    config.warmup_epoch = -1

    def lr_step_func(epoch):
        return ((epoch + 1) / (4 + 1)) ** 2 if epoch < config.warmup_epoch else 0.1 ** len(
            [m for m in [6, 10, 14] if m - 1 <= epoch])
    config.lr_func = lr_step_func
@@ -1,114 +0,0 @@
import numbers
import os
import queue as Queue
import threading

import mxnet as mx
import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms


class BackgroundGenerator(threading.Thread):
    def __init__(self, generator, local_rank, max_prefetch=6):
        super(BackgroundGenerator, self).__init__()
        self.queue = Queue.Queue(max_prefetch)
        self.generator = generator
        self.local_rank = local_rank
        self.daemon = True
        self.start()

    def run(self):
        torch.cuda.set_device(self.local_rank)
        for item in self.generator:
            self.queue.put(item)
        self.queue.put(None)

    def next(self):
        next_item = self.queue.get()
        if next_item is None:
            raise StopIteration
        return next_item

    def __next__(self):
        return self.next()

    def __iter__(self):
        return self


class DataLoaderX(DataLoader):
    def __init__(self, local_rank, **kwargs):
        super(DataLoaderX, self).__init__(**kwargs)
        self.stream = torch.cuda.Stream(local_rank)
        self.local_rank = local_rank

    def __iter__(self):
        self.iter = super(DataLoaderX, self).__iter__()
        self.iter = BackgroundGenerator(self.iter, self.local_rank)
        self.preload()
        return self

    def preload(self):
        self.batch = next(self.iter, None)
        if self.batch is None:
            return None
        with torch.cuda.stream(self.stream):
            for k in range(len(self.batch)):
                self.batch[k] = self.batch[k].to(device=self.local_rank,
                                                 non_blocking=True)

    def __next__(self):
        torch.cuda.current_stream().wait_stream(self.stream)
        batch = self.batch
        if batch is None:
            raise StopIteration
        self.preload()
        return batch


TFS = transforms.Compose([
    transforms.ToPILImage(),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
])
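# Standard face-recognition augmentation: random horizontal flip, then map
# uint8 pixels to float tensors in [-1, 1] via (x - 0.5) / 0.5.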


class MXFaceDataset(Dataset):
    def __init__(self, root_dir, local_rank, transform=TFS):
        super(MXFaceDataset, self).__init__()
        self.transform = transform
        self.root_dir = root_dir
        self.local_rank = local_rank
        path_imgrec = os.path.join(root_dir, 'train.rec')
        path_imgidx = os.path.join(root_dir, 'train.idx')
        self.imgrec = mx.recordio.MXIndexedRecordIO(path_imgidx, path_imgrec,
                                                    'r')
        s = self.imgrec.read_idx(0)
        header, _ = mx.recordio.unpack(s)
        if header.flag > 0:
            # print('header0 label', header.label)
            self.header0 = (int(header.label[0]), int(header.label[1]))
            self.imgidx = np.array(range(1, int(header.label[0])))
        else:
            self.imgidx = np.array(list(self.imgrec.keys))
        # print("Number of Samples:{}".format(len(self.imgidx)))

    def __getitem__(self, index):
        # index = 0
        idx = self.imgidx[index]
        s = self.imgrec.read_idx(idx)
        header, img = mx.recordio.unpack(s)
        label = header.label
        if not isinstance(label, numbers.Number):
            label = label[0]
        label = torch.tensor(label, dtype=torch.long)
        sample = mx.image.imdecode(img).asnumpy()
        if self.transform is not None:
            sample = self.transform(sample)
        return sample, label

    def __len__(self):
        return len(self.imgidx)
@@ -1,95 +0,0 @@
|
||||
"""
|
||||
Author: {Yang Xiao, Xiang An, XuHan Zhu} in DeepGlint,
|
||||
Partial FC: Training 10 Million Identities on a Single Machine
|
||||
See the original paper:
|
||||
https://arxiv.org/abs/2010.05222
|
||||
"""
|
||||
|
||||
import math
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.nn.functional as F
|
||||
from torch.nn import Module
|
||||
from torch.nn import init
|
||||
from torch.nn.parameter import Parameter
|
||||
|
||||
from config import config as cfg
|
||||
|
||||
|
||||
class DistSampleClassifier(Module):
    def _forward_unimplemented(self, *input: Any) -> None:
        pass

    @torch.no_grad()
    def __init__(self, rank, local_rank, world_size):
        super(DistSampleClassifier, self).__init__()
        self.sample_rate = cfg.sample_rate
        # Split cfg.num_classes as evenly as possible across ranks; each rank
        # owns a contiguous range of class centers starting at class_start.
        self.num_local = cfg.num_classes // world_size + int(
            rank < cfg.num_classes % world_size)
        self.class_start = cfg.num_classes // world_size * rank + min(
            rank, cfg.num_classes % world_size)
        self.num_sample = int(self.sample_rate * self.num_local)
        self.local_rank = local_rank
        self.world_size = world_size

        self.weight = torch.empty(size=(self.num_local, cfg.embedding_size),
                                  device=local_rank)
        self.weight_mom = torch.zeros_like(self.weight)
        self.stream = torch.cuda.Stream(local_rank)
        init.kaiming_uniform_(self.weight, a=math.sqrt(5))

        self.index = None
        if int(self.sample_rate) == 1:
            self.update = lambda: 0
            self.sub_weight = Parameter(self.weight)
            self.sub_weight_mom = self.weight_mom
        else:
            self.sub_weight = Parameter(torch.empty((0, 0)).cuda(local_rank))

    @torch.no_grad()
    def sample(self, total_label):
        # P marks the labels whose class centers live on this rank.
        P = (self.class_start <=
             total_label) & (total_label < self.class_start + self.num_local)
        total_label[~P] = -1
        total_label[P] -= self.class_start
        if int(self.sample_rate) != 1:
            # Always keep the positive classes, then fill up to num_sample
            # with randomly chosen negatives.
            positive = torch.unique(total_label[P], sorted=True)
            if self.num_sample - positive.size(0) >= 0:
                perm = torch.rand(self.num_local, device=cfg.local_rank)
                perm[positive] = 2.0
                index = torch.topk(perm, k=self.num_sample)[1]
                index = index.sort()[0]
            else:
                index = positive
            self.index = index
            total_label[P] = torch.searchsorted(index, total_label[P])
            self.sub_weight = Parameter(self.weight[index])
            self.sub_weight_mom = self.weight_mom[index]

    def forward(self, total_features, norm_weight):
        torch.cuda.current_stream().wait_stream(self.stream)
        logits = F.linear(total_features, norm_weight)
        return logits

    @torch.no_grad()
    def update(self, ):
        self.weight_mom[self.index] = self.sub_weight_mom
        self.weight[self.index] = self.sub_weight

    def prepare(self, label, optimizer):
        with torch.cuda.stream(self.stream):
            total_label = torch.zeros(label.size()[0] * self.world_size,
                                      device=self.local_rank,
                                      dtype=torch.long)
            dist.all_gather(list(total_label.chunk(self.world_size, dim=0)),
                            label)
            self.sample(total_label)
            # Swap the freshly sampled sub-weight (and its momentum buffer)
            # into the optimizer's last parameter group.
            optimizer.state.pop(optimizer.param_groups[-1]['params'][0], None)
            optimizer.param_groups[-1]['params'][0] = self.sub_weight
            optimizer.state[
                self.sub_weight]['momentum_buffer'] = self.sub_weight_mom
            norm_weight = F.normalize(self.sub_weight)
            return total_label, norm_weight
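
The even split of class centers across ranks in __init__ can be checked in isolation; a small standalone sketch with assumed sizes (10 classes over 4 ranks):

# Illustrative only: num_classes and world_size are made-up values.
def partition(num_classes, world_size):
    parts = []
    for rank in range(world_size):
        num_local = num_classes // world_size + int(rank < num_classes % world_size)
        class_start = num_classes // world_size * rank + min(rank, num_classes % world_size)
        parts.append((class_start, num_local))
    return parts

print(partition(10, 4))  # [(0, 3), (3, 3), (6, 2), (8, 2)] -- contiguous, disjoint, covers all 10
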
@@ -1,229 +0,0 @@
"""
Author: {Yang Xiao, Xiang An, XuHan Zhu} in DeepGlint,
Partial FC: Training 10 Million Identities on a Single Machine
See the original paper:
https://arxiv.org/abs/2010.05222
"""

import argparse
import os
import time

import torch
import torch.distributed as dist
import torch.nn.functional as F
import torch.utils.data.distributed
from torch import nn
from torch.utils.tensorboard import SummaryWriter

import backbones
from config import config as cfg
from dataset import MXFaceDataset, DataLoaderX
from partial_classifier import DistSampleClassifier
from sgd import SGD

torch.backends.cudnn.benchmark = True

class MarginSoftmax(nn.Module):
    def __init__(self, s=64.0, m=0.40):
        super(MarginSoftmax, self).__init__()
        self.s = s
        self.m = m

    def forward(self, cosine, label):
        # Subtract the additive margin m from the target-class cosine
        # (CosFace-style), then scale all logits by s. Labels of -1 mark
        # samples whose class center is not on this rank.
        index = torch.where(label != -1)[0]
        m_hot = torch.zeros(index.size()[0],
                            cosine.size()[1],
                            device=cosine.device)
        m_hot.scatter_(1, label[index, None], self.m)
        cosine[index] -= m_hot
        ret = cosine * self.s
        return ret

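
A toy check of MarginSoftmax on made-up cosines (values here are illustrative assumptions; note that forward modifies its input in place, hence the clone):

fn = MarginSoftmax(s=64.0, m=0.40)
cosine = torch.tensor([[0.9, 0.1], [0.2, 0.8]])
label = torch.tensor([0, 1])
out = fn(cosine.clone(), label)
# Row 0: [(0.9 - 0.4) * 64, 0.1 * 64] = [32.0, 6.4]
assert torch.allclose(out[0], torch.tensor([32.0, 6.4]))
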
class AverageMeter(object):
    """Computes and stores the average and current value"""

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

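
A quick illustrative use of the meter (made-up values):

meter = AverageMeter()
meter.update(2.0)        # one observation of 2.0
meter.update(4.0, n=3)   # three observations of 4.0
print(meter.avg)         # (2.0 + 4.0 * 3) / 4 = 3.5
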
# .......
def main(local_rank):
    dist.init_process_group(backend='nccl', init_method='env://')
    cfg.local_rank = local_rank
    torch.cuda.set_device(local_rank)
    cfg.rank = dist.get_rank()
    cfg.world_size = dist.get_world_size()
    trainset = MXFaceDataset(root_dir=cfg.rec, local_rank=local_rank)
    train_sampler = torch.utils.data.distributed.DistributedSampler(
        trainset, shuffle=True)
    train_loader = DataLoaderX(local_rank=local_rank,
                               dataset=trainset,
                               batch_size=cfg.batch_size,
                               sampler=train_sampler,
                               num_workers=0,
                               pin_memory=True,
                               drop_last=False)

    backbone = backbones.iresnet100(False).to(local_rank)
    backbone.train()

    # Broadcast init parameters
    for ps in backbone.parameters():
        dist.broadcast(ps, 0)

    # DDP
    backbone = torch.nn.parallel.DistributedDataParallel(
        module=backbone,
        broadcast_buffers=False,
        device_ids=[cfg.local_rank])
    backbone.train()

    # Memory classifier
    dist_sample_classifer = DistSampleClassifier(
        rank=dist.get_rank(),
        local_rank=local_rank,
        world_size=cfg.world_size)

    # Margin softmax
    margin_softmax = MarginSoftmax(s=64.0, m=0.4)

    # Optimizer for backbone and classifier
    optimizer = SGD([{
        'params': backbone.parameters()
    }, {
        'params': dist_sample_classifer.parameters()
    }],
        lr=cfg.lr,
        momentum=0.9,
        weight_decay=cfg.weight_decay,
        rescale=cfg.world_size)

    # Lr scheduler
    scheduler = torch.optim.lr_scheduler.LambdaLR(
        optimizer=optimizer,
        lr_lambda=cfg.lr_func)
    n_epochs = cfg.num_epoch
    start_epoch = 0

    if local_rank == 0:
        writer = SummaryWriter(log_dir='logs/shows')

    # Total number of optimizer steps over the whole run.
    total_step = int(len(trainset) / cfg.batch_size / dist.get_world_size() * cfg.num_epoch)
    if dist.get_rank() == 0:
        print("Total Step is: %d" % total_step)

    losses = AverageMeter()
    global_step = 0
    train_start = time.time()
    for epoch in range(start_epoch, n_epochs):
        train_sampler.set_epoch(epoch)
        for step, (img, label) in enumerate(train_loader):
            total_label, norm_weight = dist_sample_classifer.prepare(
                label, optimizer)
            features = F.normalize(backbone(img))

            # Features all-gather
            total_features = torch.zeros(features.size()[0] * cfg.world_size,
                                         cfg.embedding_size,
                                         device=local_rank)
            dist.all_gather(list(total_features.chunk(cfg.world_size, dim=0)),
                            features.data)
            total_features.requires_grad = True

            # Calculate logits
            logits = dist_sample_classifer(total_features, norm_weight)
            logits = margin_softmax(logits, total_label)

            with torch.no_grad():
                max_fc = torch.max(logits, dim=1, keepdim=True)[0]
                dist.all_reduce(max_fc, dist.ReduceOp.MAX)

                # Calculate exp(logits) and all-reduce
                logits_exp = torch.exp(logits - max_fc)
                logits_sum_exp = logits_exp.sum(dim=1, keepdim=True)
                dist.all_reduce(logits_sum_exp, dist.ReduceOp.SUM)

                # Calculate prob
                logits_exp.div_(logits_sum_exp)

                # Get one-hot
                grad = logits_exp
                index = torch.where(total_label != -1)[0]
                one_hot = torch.zeros(index.size()[0],
                                      grad.size()[1],
                                      device=grad.device)
                one_hot.scatter_(1, total_label[index, None], 1)

                # Calculate loss
                loss = torch.zeros(grad.size()[0], 1, device=grad.device)
                loss[index] = grad[index].gather(1, total_label[index, None])
                dist.all_reduce(loss, dist.ReduceOp.SUM)
                loss_v = loss.clamp_min_(1e-30).log_().mean() * (-1)

                # Calculate grad: softmax probability minus one-hot target
                grad[index] -= one_hot
                grad.div_(features.size()[0])

            logits.backward(grad)
            if total_features.grad is not None:
                total_features.grad.detach_()
            x_grad = torch.zeros_like(features)

            # Feature gradient all-reduce
            dist.reduce_scatter(
                x_grad, list(total_features.grad.chunk(cfg.world_size, dim=0)))
            x_grad.mul_(cfg.world_size)
            # Backward backbone
            features.backward(x_grad)
            optimizer.step()

            # Update classifier
            dist_sample_classifer.update()
            optimizer.zero_grad()
            losses.update(loss_v, 1)
            if cfg.local_rank == 0 and step % 50 == 0:
                time_now = (time.time() - train_start) / 3600
                time_total = time_now / ((global_step + 1) / total_step)
                time_for_end = time_total - time_now
                writer.add_scalar('time_for_end', time_for_end, global_step)
                writer.add_scalar('loss', loss_v, global_step)
                print("Speed %d samples/sec Loss %.4f Epoch: %d Global Step: %d Required: %.1f hours" %
                      (
                          (cfg.batch_size * global_step / (time.time() - train_start) * cfg.world_size),
                          losses.avg,
                          epoch,
                          global_step,
                          time_for_end
                      ))
                losses.reset()

            global_step += 1
        scheduler.step()
        if dist.get_rank() == 0:
            if not os.path.exists(cfg.output):
                os.makedirs(cfg.output)
            torch.save(backbone.module.state_dict(), os.path.join(cfg.output, str(epoch) + 'backbone.pth'))
    dist.destroy_process_group()


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='PyTorch Partial-FC Training')
    parser.add_argument('--local_rank', type=int, default=0, help='local_rank')
    args = parser.parse_args()
    main(args.local_rank)
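
The loop above computes softmax cross-entropy and its gradient by hand (probability minus one-hot) so the all-reduces can be placed explicitly. A single-process sketch (no distributed setup; the toy shapes are assumptions) checking that this manual gradient matches autograd:

# Verify: d/dlogits of summed cross-entropy == softmax(logits) - one_hot.
import torch
import torch.nn.functional as F

logits = torch.randn(4, 10, requires_grad=True)
label = torch.randint(0, 10, (4,))

loss = F.cross_entropy(logits, label, reduction='sum')
loss.backward()

with torch.no_grad():
    prob = torch.softmax(logits, dim=1)
    one_hot = torch.zeros_like(prob).scatter_(1, label[:, None], 1)
    manual_grad = prob - one_hot

assert torch.allclose(logits.grad, manual_grad, atol=1e-5)
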
@@ -1,3 +0,0 @@
tensorboard
easydict
mxnet==1.6.0
@@ -1,4 +0,0 @@
# /usr/bin/
export OMP_NUM_THREADS=4
python -m torch.distributed.launch --nproc_per_node=8 --nnodes=1 --node_rank=0 --master_addr="127.0.0.1" --master_port=1234 partial_fc.py | tee hist.log
ps -ef | grep "partial_fc" | grep -v grep | awk '{print "kill -9 "$2}' | sh
@@ -1,71 +0,0 @@
import torch
from torch.optim.optimizer import Optimizer, required


class SGD(Optimizer):
    def __init__(self,
                 params,
                 lr=required,
                 momentum=0,
                 dampening=0,
                 weight_decay=0,
                 nesterov=False,
                 rescale=1):
        if lr is not required and lr < 0.0:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if momentum < 0.0:
            raise ValueError("Invalid momentum value: {}".format(momentum))
        if weight_decay < 0.0:
            raise ValueError(
                "Invalid weight_decay value: {}".format(weight_decay))

        defaults = dict(lr=lr,
                        momentum=momentum,
                        dampening=dampening,
                        weight_decay=weight_decay,
                        nesterov=nesterov)
        self.rescale = rescale
        if nesterov and (momentum <= 0 or dampening != 0):
            raise ValueError(
                "Nesterov momentum requires a momentum and zero dampening")
        super(SGD, self).__init__(params, defaults)

    def __setstate__(self, state):
        super(SGD, self).__setstate__(state)
        for group in self.param_groups:
            group.setdefault('nesterov', False)

    def step(self, closure=None):
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            weight_decay = group['weight_decay']
            momentum = group['momentum']
            dampening = group['dampening']
            nesterov = group['nesterov']

            for p in group['params']:
                if p.grad is None:
                    continue
                # Undo the world-size scaling applied to the gradients in the
                # training loop before taking the SGD step.
                p.grad.data.div_(self.rescale)
                d_p = p.grad.data
                if weight_decay != 0:
                    d_p.add_(p.data, alpha=weight_decay)
                if momentum != 0:
                    param_state = self.state[p]
                    if 'momentum_buffer' not in param_state:
                        buf = param_state['momentum_buffer'] = torch.clone(
                            d_p).detach()
                    else:
                        buf = param_state['momentum_buffer']
                        buf.mul_(momentum).add_(d_p, alpha=1 - dampening)
                    if nesterov:
                        d_p = d_p.add(buf, alpha=momentum)
                    else:
                        d_p = buf

                p.data.add_(d_p, alpha=-group['lr'])

        return loss
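
A toy sanity check of the rescale behavior (made-up numbers; SGD here is the class defined above): with rescale=4 and no momentum or weight decay, the applied update equals lr * grad / 4.

import torch

p = torch.nn.Parameter(torch.ones(2))
opt = SGD([p], lr=0.1, rescale=4)
p.grad = torch.full((2,), 4.0)
opt.step()
assert torch.allclose(p.data, torch.full((2,), 0.9))  # 1 - 0.1 * (4 / 4)
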