# Mirror of https://github.com/deepinsight/insightface.git
# Synced 2026-05-15 04:37:50 +00:00
# 203 lines, 6.0 KiB, Python
import os
|
|
import shutil
|
|
import datetime
|
|
import sys
|
|
from mxnet import ndarray as nd
|
|
import mxnet as mx
|
|
import random
|
|
import argparse
|
|
import numbers
|
|
import cv2
|
|
import time
|
|
import pickle
|
|
import sklearn
|
|
import sklearn.preprocessing
|
|
from easydict import EasyDict as edict
|
|
import numpy as np
|
|
sys.path.append(os.path.join(os.path.dirname(__file__), '..', 'common'))
|
|
from rec_builder import *
|
|
|
|
def get_embedding(args, imgrec, a, b, image_size, model):
    """Embed the records at indices [a, b) of `imgrec` and L2-normalize.

    Records are run through `model` in batches of `args.batch_size`; each
    batch is padded up to a multiple of `args.ctx_num` so it splits evenly
    across devices, and the padded outputs are discarded.

    Returns:
        embeddings: (b - a, dim) float array, row-normalized.
        rlabel: int label of the first record (-1 if the range is empty).
        contents: list of the packed image payloads, in read order.

    Fix vs. original: the original also kept a BGR copy of every decoded
    image in an `imgs` list that was never used or returned — pure memory
    waste for large identities; that dead code is removed here.
    """
    ocontents = [imgrec.read_idx(idx) for idx in range(a, b)]
    embeddings = None
    ba = 0
    rlabel = -1
    contents = []
    while True:
        bb = min(ba + args.batch_size, len(ocontents))
        if ba >= bb:
            break
        _batch_size = bb - ba
        # Round the batch up to a multiple of ctx_num; padding slots are
        # filled with a copy of the first row below.
        _batch_size2 = _batch_size
        if _batch_size % args.ctx_num != 0:
            _batch_size2 = ((_batch_size // args.ctx_num) + 1) * args.ctx_num
        data = np.zeros((_batch_size2, 3, image_size[0], image_size[1]))
        ii = 0
        for i in range(ba, bb):
            header, img = mx.recordio.unpack(ocontents[i])
            contents.append(img)
            label = header.label
            if not isinstance(label, numbers.Number):
                label = label[0]
            if rlabel < 0:
                rlabel = int(label)  # remember the first record's label
            img = mx.image.imdecode(img)      # decoded HWC, RGB order
            rgb = img.asnumpy()
            data[ii] = rgb.transpose((2, 0, 1))  # HWC -> CHW for the model
            ii += 1
        while ii < _batch_size2:
            # Replicate the first image into the padding slots.
            data[ii] = data[0]
            ii += 1
        db = mx.io.DataBatch(data=(nd.array(data),))
        model.forward(db, is_train=False)
        net_out = model.get_outputs()[0].asnumpy()
        if embeddings is None:
            embeddings = np.zeros((len(ocontents), net_out.shape[1]))
        # Keep only the real rows; drop the padded ones.
        embeddings[ba:bb, :] = net_out[0:_batch_size, :]
        ba = bb
    embeddings = sklearn.preprocessing.normalize(embeddings)
    return embeddings, rlabel, contents
|
|
|
|
def main(args):
    """Drop noisy images from a train.rec using sub-center affinity.

    Loads a checkpoint trained with K sub-centers per identity, embeds every
    image of every identity, votes each image to its nearest sub-center,
    picks the dominant sub-center, and keeps only images whose cosine
    similarity to it exceeds cos(args.threshold degrees). Survivors are
    written to a new record under args.output via SeqRecBuilder.

    Fixes vs. original:
      * `dtype=np.int` -> `int` (np.int was removed in NumPy 1.24).
      * CUDA_VISIBLE_DEVICES read with .get() so an unset variable falls
        back to CPU instead of raising KeyError.
      * Dead locals removed (imgidx, id2range, header0, per-loop contents).
    """
    print(args)
    image_size = (112, 112)
    print('image_size', image_size)
    vec = args.model.split(',')
    prefix = vec[0]
    epoch = int(vec[1])
    print('loading', prefix, epoch)
    ctx = []
    cvd = os.environ.get('CUDA_VISIBLE_DEVICES', '').strip()
    if len(cvd) > 0:
        for i in range(len(cvd.split(','))):
            ctx.append(mx.gpu(i))
    if len(ctx) == 0:
        ctx = [mx.cpu()]
        print('use cpu')
    else:
        print('gpu num:', len(ctx))
    args.ctx_num = len(ctx)
    args.batch_size *= args.ctx_num
    sym, arg_params, aux_params = mx.model.load_checkpoint(prefix, epoch)
    # Concatenate the (possibly sharded) fc7 weight blocks:
    # fc7_0_weight, fc7_1_weight, ... until a key is missing.
    W = None
    i = 0
    while True:
        key = 'fc7_%d_weight' % i
        i += 1
        if key not in arg_params:
            break
        _W = arg_params[key].asnumpy()
        W = _W if W is None else np.concatenate((W, _W), axis=0)
    K = args.k
    W = sklearn.preprocessing.normalize(W)
    # One row of K sub-centers (512-d each) per identity.
    W = W.reshape((-1, K, 512))
    all_layers = sym.get_internals()
    sym = all_layers['fc1_output']  # embedding layer output only
    model = mx.mod.Module(symbol=sym, context=ctx, label_names=None)
    model.bind(data_shapes=[('data', (args.ctx_num, 3, image_size[0], image_size[1]))])
    model.set_params(arg_params, aux_params)
    print('W:', W.shape)
    path_imgrec = os.path.join(args.data, 'train.rec')
    path_imgidx = os.path.join(args.data, 'train.idx')
    imgrec = mx.recordio.MXIndexedRecordIO(path_imgidx, path_imgrec, 'r')  # pylint: disable=redefined-variable-type
    # Record 0 is the meta header: label = (start, end) of the identity
    # index range in the .idx file.
    s = imgrec.read_idx(0)
    header, _ = mx.recordio.unpack(s)
    assert header.flag > 0
    print('header0 label', header.label)
    a, b = int(header.label[0]), int(header.label[1])
    seq_identity = range(a, b)
    print(len(seq_identity))
    # First pass: collect (identity-id, image-range) tuples and total count.
    id_list = []
    image_count = 0
    for wid, identity in enumerate(seq_identity):
        s = imgrec.read_idx(identity)
        header, _ = mx.recordio.unpack(s)
        a, b = int(header.label[0]), int(header.label[1])
        _count = b - a
        id_list.append((wid, a, b, _count))
        image_count += _count
    pp = 0
    if not os.path.exists(args.output):
        os.makedirs(args.output)
    ret = np.zeros((image_count, K + 1), dtype=np.float32)
    builder = SeqRecBuilder(args.output)
    print(ret.shape)
    da = datetime.datetime.now()
    label = 0
    num_images = 0
    # Convert the angular threshold (degrees) to a cosine-similarity bound.
    cos_thresh = np.cos(np.pi * args.threshold / 180.0)
    for id_item in id_list:
        wid = id_item[0]
        pp += 1
        if pp % 40 == 0:
            db = datetime.datetime.now()
            print('processing id', pp, (db - da).total_seconds())
            da = db
        x, _, contents = get_embedding(args, imgrec, id_item[1], id_item[2], image_size, model)
        subcenters = W[wid]
        # Vote: count how many images are closest to each sub-center.
        K_stat = np.zeros((K,), dtype=int)
        for i in range(x.shape[0]):
            sim = np.dot(subcenters, x[i])  # len(sim) == K
            K_stat[np.argmax(sim)] += 1
        dominant_center = subcenters[np.argmax(K_stat)]
        # Keep only images close enough to the dominant sub-center.
        sim = np.dot(x, dominant_center)
        idx = np.where(sim > cos_thresh)[0]
        if len(idx) == 0:
            continue  # whole identity dropped; label not consumed
        num_images += len(idx)
        for _idx in idx:
            builder.add(label, contents[_idx], is_image=False)
        label += 1
    builder.close()
    print('total:', num_images)
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == '__main__':
    # Command-line entry point: parse options and run the filtering pass.
    ap = argparse.ArgumentParser(description='')
    # general
    ap.add_argument('--data', default='/bigdata/faces_ms1m_full', type=str, help='')
    ap.add_argument('--output', default='/bigdata/ms1m_full_k3drop075', type=str, help='')
    ap.add_argument('--model', default='../Evaluation/IJB/pretrained_models/r50-arcfacesc-msf-k3z/model,2', help='path to load model.')
    ap.add_argument('--batch-size', default=16, type=int, help='')
    ap.add_argument('--threshold', default=75, type=float, help='')
    ap.add_argument('--k', default=3, type=int, help='')
    main(ap.parse_args())
|
|
|