mirror of https://github.com/deepinsight/insightface.git, synced 2026-05-16 05:27:56 +00:00

add coco loss

src/data.py (153 lines changed)
@@ -20,9 +20,37 @@ from mxnet import io
 from mxnet import recordio
 sys.path.append(os.path.join(os.path.dirname(__file__), 'common'))
 import face_preprocess
+import multiprocessing
 
 logger = logging.getLogger()
 
+def pick_triplets_impl(q_in, q_out):
+  more = True
+  while more:
+    deq = q_in.get()
+    if deq is None:
+      more = False
+    else:
+      embeddings, emb_start_idx, nrof_images, alpha = deq
+      print('running', emb_start_idx, nrof_images, os.getpid())
+      for j in xrange(1,nrof_images):
+        a_idx = emb_start_idx + j - 1
+        neg_dists_sqr = np.sum(np.square(embeddings[a_idx] - embeddings), 1)
+        for pair in xrange(j, nrof_images): # For every possible positive pair.
+          p_idx = emb_start_idx + pair
+          pos_dist_sqr = np.sum(np.square(embeddings[a_idx]-embeddings[p_idx]))
+          neg_dists_sqr[emb_start_idx:emb_start_idx+nrof_images] = np.NaN
+          all_neg = np.where(np.logical_and(neg_dists_sqr-pos_dist_sqr<alpha, pos_dist_sqr<neg_dists_sqr))[0] # FaceNet selection
+          #all_neg = np.where(neg_dists_sqr-pos_dist_sqr<alpha)[0] # VGG Face selection
+          nrof_random_negs = all_neg.shape[0]
+          if nrof_random_negs>0:
+            rnd_idx = np.random.randint(nrof_random_negs)
+            n_idx = all_neg[rnd_idx]
+            #triplets.append( (a_idx, p_idx, n_idx) )
+            q_out.put( (a_idx, p_idx, n_idx) )
+      #emb_start_idx += nrof_images
+  print('exit',os.getpid())
+
 class FaceImageIter(io.DataIter):
 
   def __init__(self, batch_size, data_shape,
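The predicate inside pick_triplets_impl is FaceNet's semi-hard negative selection: a negative qualifies when it is farther from the anchor than the positive yet still violates the alpha margin, and masking the identity's own block with NaN keeps positives out of the candidate set. A standalone NumPy sketch of the same rule on made-up embeddings (all names below are illustrative, not part of data.py):

import numpy as np

rng = np.random.RandomState(0)
emb = rng.randn(6, 4)
emb /= np.linalg.norm(emb, axis=1, keepdims=True)   # unit-length, as in training

a_idx, p_idx, alpha = 0, 1, 0.2                     # rows 0..2 share one identity
pos_dist_sqr = np.sum(np.square(emb[a_idx] - emb[p_idx]))
neg_dists_sqr = np.sum(np.square(emb[a_idx] - emb), 1)
neg_dists_sqr[0:3] = np.nan                         # mask the anchor's own identity

# semi-hard: farther than the positive, but still inside the alpha margin
all_neg = np.where(np.logical_and(neg_dists_sqr - pos_dist_sqr < alpha,
                                  pos_dist_sqr < neg_dists_sqr))[0]
if all_neg.shape[0] > 0:
    n_idx = all_neg[np.random.randint(all_neg.shape[0])]
    print('triplet:', (a_idx, p_idx, n_idx))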
@@ -30,7 +58,7 @@ class FaceImageIter(io.DataIter):
                shuffle=False, aug_list=None, mean = None,
                rand_mirror = False,
                ctx_num = 0, images_per_identity = 0, data_extra = None, hard_mining = False,
-               triplet_params = None,
+               triplet_params = None, coco_mode = False,
                mx_model = None,
                data_name='data', label_name='softmax_label', **kwargs):
     super(FaceImageIter, self).__init__()
@@ -73,10 +101,6 @@ class FaceImageIter(io.DataIter):
 
     self.check_data_shape(data_shape)
     self.provide_data = [(data_name, (batch_size,) + data_shape)]
-    if len(label_name)>0:
-      self.provide_label = [(label_name, (batch_size,))]
-    else:
-      self.provide_label = []
     self.batch_size = batch_size
     self.data_shape = data_shape
     self.shuffle = shuffle
@@ -104,6 +128,14 @@ class FaceImageIter(io.DataIter):
       assert self.mx_model is not None
     self.triplet_params = triplet_params
     self.triplet_mode = False
+    self.coco_mode = coco_mode
+    if len(label_name)>0:
+      self.provide_label = [(label_name, (batch_size,))]
+    else:
+      self.provide_label = []
+    if self.coco_mode:
+      assert self.triplet_params is None
+      assert self.images_per_identity>0
     if self.triplet_params is not None:
       assert self.images_per_identity>0
       assert self.mx_model is not None
@@ -121,12 +153,17 @@ class FaceImageIter(io.DataIter):
     self.times = [0.0, 0.0, 0.0]
     #self.reset()
 
-  def pick_triplets(self, embeddings, nrof_images_per_class):
-    trip_idx = 0
+  def ____pick_triplets(self, embeddings, nrof_images_per_class):
     emb_start_idx = 0
-    num_trips = 0
     triplets = []
     people_per_batch = len(nrof_images_per_class)
+    nrof_threads = 8
+    q_in = multiprocessing.Queue()
+    q_out = multiprocessing.Queue()
+    processes = [multiprocessing.Process(target=pick_triplets_impl, args=(q_in, q_out)) \
+        for i in range(nrof_threads)]
+    for p in processes:
+      p.start()
 
     # VGG Face: Choosing good triplets is crucial and should strike a balance between
     # selecting informative (i.e. challenging) examples and swamping training with examples that
@@ -135,6 +172,35 @@ class FaceImageIter(io.DataIter):
     # latter is a form of hard-negative mining, but it is not as aggressive (and much cheaper) than
     # choosing the maximally violating example, as often done in structured output learning.
 
+    for i in xrange(people_per_batch):
+      nrof_images = int(nrof_images_per_class[i])
+      job = (embeddings, emb_start_idx, nrof_images, self.triplet_alpha)
+      emb_start_idx+=nrof_images
+      q_in.put(job)
+    for i in xrange(nrof_threads):
+      q_in.put(None)
+    print('joining')
+    for p in processes:
+      p.join()
+    print('joined')
+    q_out.put(None)
+
+    triplets = []
+    more = True
+    while more:
+      triplet = q_out.get()
+      if triplet is None:
+        more = False
+      else:
+        triplets.append(triplet)
+    np.random.shuffle(triplets)
+    return triplets
+
+  def pick_triplets(self, embeddings, nrof_images_per_class):
+    emb_start_idx = 0
+    triplets = []
+    people_per_batch = len(nrof_images_per_class)
+
     for i in xrange(people_per_batch):
       nrof_images = int(nrof_images_per_class[i])
       for j in xrange(1,nrof_images):
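The handoff above uses the standard multiprocessing sentinel pattern: one None per worker is queued on q_in so each process sees its own shutdown signal, and a final None on q_out terminates the collection loop. Note that q_out is drained only after join(), so this relies on the queue buffering every produced triplet in the interim.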
@@ -150,17 +216,63 @@ class FaceImageIter(io.DataIter):
           if nrof_random_negs>0:
             rnd_idx = np.random.randint(nrof_random_negs)
             n_idx = all_neg[rnd_idx]
-            #triplets.append((image_paths[a_idx], image_paths[p_idx], image_paths[n_idx]))
             triplets.append( (a_idx, p_idx, n_idx) )
             #triplets.append((image_paths[a_idx], image_paths[p_idx], image_paths[n_idx]))
-            #print('Triplet %d: (%d, %d, %d), pos_dist=%2.6f, neg_dist=%2.6f (%d, %d, %d, %d, %d)' %
-            #  (trip_idx, a_idx, p_idx, n_idx, pos_dist_sqr, neg_dists_sqr[n_idx], nrof_random_negs, rnd_idx, i, j, emb_start_idx))
-            trip_idx += 1
-
-            num_trips += 1
-
       emb_start_idx += nrof_images
     np.random.shuffle(triplets)
     return triplets
 
+  def __pick_triplets(self, embeddings, nrof_images_per_class):
+    emb_start_idx = 0
+    triplets = []
+    people_per_batch = len(nrof_images_per_class)
+
+    for i in xrange(people_per_batch):
+      nrof_images = int(nrof_images_per_class[i])
+      if nrof_images<2:
+        continue
+      for j in xrange(1,nrof_images):
+        a_idx = emb_start_idx + j - 1
+        pcount = nrof_images-1
+        dists_a2all = np.sum(np.square(embeddings[a_idx] - embeddings), 1) #(N,)
+        #print(a_idx, dists_a2all.shape)
+        ba = emb_start_idx
+        bb = emb_start_idx+nrof_images
+        sorted_idx = np.argsort(dists_a2all)
+        #print('assert', sorted_idx[0], a_idx)
+        #assert sorted_idx[0]==a_idx
+        #for idx in sorted_idx:
+        #  print(idx, dists_a2all[idx])
+        p2n_map = {}
+        pfound = 0
+        for idx in sorted_idx:
+          if idx==a_idx: #is anchor
+            continue
+          if idx<bb and idx>=ba: #is pos
+            p2n_map[idx] = [dists_a2all[idx], []] #ap, [neg_list]
+            pfound+=1
+          else: # is neg
+            an = dists_a2all[idx]
+            if pfound==pcount and len(p2n_map)==0:
+              break
+            to_del = []
+            for p_idx in p2n_map:
+              v = p2n_map[p_idx]
+              an_ap = an - v[0]
+              if an_ap<self.triplet_alpha:
+                v[1].append(idx)
+              else:
+                #output
+                if len(v[1])>0:
+                  n_idx = random.choice(v[1])
+                  triplets.append( (a_idx, p_idx, n_idx) )
+                to_del.append(p_idx)
+            for _del in to_del:
+              del p2n_map[_del]
+        for p_idx,v in p2n_map.iteritems():
+          if len(v[1])>0:
+            n_idx = random.choice(v[1])
+            triplets.append( (a_idx, p_idx, n_idx) )
+      emb_start_idx += nrof_images
+    np.random.shuffle(triplets)
+    return triplets
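The new __pick_triplets variant walks sorted_idx in order of increasing distance to the anchor: an index inside [ba, bb) is a positive and opens an entry in p2n_map, any other index is a negative. Because the scan is sorted, once a negative arrives with an - ap >= alpha for some positive, every later negative is at least as far, so that positive can be finalized on the spot with a random choice among its recorded semi-hard negatives; the trailing iteritems() pass flushes positives whose margin was never exceeded.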
@@ -202,10 +314,10 @@ class FaceImageIter(io.DataIter):
     tag = []
     #idx = np.zeros( (bag_size,) )
-    print('eval %d images..'%bag_size, self.triplet_oseq_cur)
-    print('triplet time stat', self.times)
     if self.triplet_oseq_cur+bag_size>len(self.oseq):
       self.triplet_oseq_reset()
+    print('eval %d images..'%bag_size, self.triplet_oseq_cur)
+    print('triplet time stat', self.times)
     self.times[0] += self.time_elapsed()
     self.time_reset()
     #print(data.shape)
@@ -520,7 +632,10 @@ class FaceImageIter(io.DataIter):
             #print(datum.shape)
             batch_data[i][:] = self.postprocess_data(datum)
             if self.provide_label is not None:
-              batch_label[i][:] = label
+              if not self.coco_mode:
+                batch_label[i][:] = label
+              else:
+                batch_label[i][:] = (i%self.per_batch_size)//self.images_per_identity
             i += 1
       except StopIteration:
         if i<batch_size:
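In coco_mode the iterator discards the record's own label and derives it from batch position instead: each device's sub-batch packs images_per_identity consecutive images per identity, so the expression above yields per-device identity slots 0..per_identities-1. A quick check of the indexing with illustrative sizes:

per_batch_size = 8          # batch per device
images_per_identity = 4     # consecutive images per identity
labels = [(i % per_batch_size) // images_per_identity for i in range(16)]
print(labels)   # [0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1]

These slot labels line up with the order in which the coco-loss symbol below builds its centroids.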
@@ -272,21 +272,18 @@ def get_symbol(args, arg_params, aux_params):
     triplet_loss = mx.symbol.mean(triplet_loss)
     #triplet_loss = mx.symbol.sum(triplet_loss)/(args.per_batch_size//3)
     extra_loss = mx.symbol.MakeLoss(triplet_loss)
-  elif args.loss_type==13: #triplet loss II
-    nembedding = mx.symbol.L2Normalization(embedding, mode='instance', name='fc1n')
-    anchor = mx.symbol.slice_axis(nembedding, axis=0, begin=0, end=args.per_batch_size//3)
-    positive = mx.symbol.slice_axis(nembedding, axis=0, begin=args.per_batch_size//3, end=2*args.per_batch_size//3)
-    negative = mx.symbol.slice_axis(nembedding, axis=0, begin=2*args.per_batch_size//3, end=args.per_batch_size)
-    ap = anchor - positive
-    an = anchor - negative
-    ap = ap*ap
-    an = an*an
-    ap = mx.symbol.sum(ap, axis=1, keepdims=1) #(T,1)
-    an = mx.symbol.sum(an, axis=1, keepdims=1) #(T,1)
-    triplet_loss = mx.symbol.Activation(data = (ap-an+args.triplet_alpha), act_type='relu')
-    triplet_loss = mx.symbol.mean(triplet_loss)
-    #triplet_loss = mx.symbol.sum(triplet_loss)/(args.per_batch_size//3)
-    extra_loss = mx.symbol.MakeLoss(triplet_loss)
+  elif args.loss_type==9: #coco loss
+    centroids = []
+    for i in xrange(args.per_identities):
+      xs = mx.symbol.slice_axis(embedding, axis=0, begin=i*args.images_per_identity, end=(i+1)*args.images_per_identity)
+      mean = mx.symbol.mean(xs, axis=0, keepdims=True)
+      mean = mx.symbol.L2Normalization(mean, mode='instance')
+      centroids.append(mean)
+    centroids = mx.symbol.concat(*centroids, dim=0)
+    nembedding = mx.symbol.L2Normalization(embedding, mode='instance', name='fc1n')*args.coco_scale
+    fc7 = mx.symbol.dot(nembedding, centroids, transpose_b = True) #(batchsize, per_identities)
+    #extra_loss = mx.symbol.softmax_cross_entropy(fc7, gt_label, name='softmax_ce')/args.per_batch_size
+    #extra_loss = mx.symbol.BlockGrad(extra_loss)
   else:
     #embedding = mx.symbol.L2Normalization(embedding, mode='instance', name='fc1n')*float(args.loss_type)
     embedding = embedding * 5
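This is the congenerous cosine (COCO) loss as wired above: per-identity centroids are averaged from the grouped embeddings and renormalized, and the scaled cosine similarity between each embedding and every centroid becomes a logit for softmax cross-entropy against the slot labels (the symbol version leaves the cross-entropy line commented out and feeds fc7 to the regular softmax output instead). A rough NumPy sketch of the forward pass under those assumptions, with illustrative names and sizes:

import numpy as np

def coco_logits(embedding, images_per_identity, coco_scale):
    # embedding: (batch, dim), grouped so that each consecutive block of
    # images_per_identity rows belongs to one identity
    per_identities = embedding.shape[0] // images_per_identity
    centroids = []
    for i in range(per_identities):
        xs = embedding[i*images_per_identity:(i+1)*images_per_identity]
        mean = xs.mean(axis=0)
        centroids.append(mean / np.linalg.norm(mean))    # L2-normalize centroid
    centroids = np.stack(centroids)                      # (per_identities, dim)
    nemb = embedding / np.linalg.norm(embedding, axis=1, keepdims=True)
    return coco_scale * nemb.dot(centroids.T)            # (batch, per_identities)

emb = np.random.randn(8, 16)
logits = coco_logits(emb, images_per_identity=4, coco_scale=7.6)
labels = np.arange(8) // 4                               # slot labels, as in the iterator
# softmax cross-entropy over the cosine logits
p = np.exp(logits - logits.max(axis=1, keepdims=True))
p /= p.sum(axis=1, keepdims=True)
loss = -np.log(p[np.arange(8), labels]).mean()
print(loss)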
@@ -375,6 +372,7 @@ def train_net(args):
 
   assert(args.num_classes>0)
   print('num_classes', args.num_classes)
+  args.coco_scale = 0.5*math.log(float(args.num_classes-1))+3
 
   #path_imglist = "/raid5data/dplearn/MS-Celeb-Aligned/lst2"
   path_imgrec = os.path.join(data_dir, "train.rec")
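math.log is the natural logarithm here, so with num_classes = 10000, for example, args.coco_scale = 0.5*ln(9999) + 3, roughly 7.6; the scale grows only logarithmically with the class count.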
@@ -394,13 +392,13 @@ def train_net(args):
   args.beta_freeze = 5000
   args.gamma = 0.06
 
-  if args.loss_type<10:
+  if args.loss_type<9:
     assert args.images_per_identity==0
   else:
     if args.images_per_identity==0:
       if args.loss_type==11:
         args.images_per_identity = 2
-      elif args.loss_type==10:
+      elif args.loss_type==10 or args.loss_type==9:
         args.images_per_identity = 16
       elif args.loss_type==12:
         args.images_per_identity = 5
@@ -445,6 +443,7 @@ def train_net(args):
   data_extra = None
   hard_mining = False
   triplet_params = None
+  coco_mode = False
   if args.loss_type==10:
     hard_mining = True
     _shape = (args.batch_size, args.per_batch_size)
@@ -467,6 +466,8 @@ def train_net(args):
       c+=args.per_batch_size
   elif args.loss_type==12:
     triplet_params = [args.triplet_bag_size, args.triplet_alpha]
+  elif args.loss_type==9:
+    coco_mode = True
 
   label_name = 'softmax_label'
   if data_extra is None:
@@ -497,6 +498,7 @@ def train_net(args):
       data_extra = data_extra,
       hard_mining = hard_mining,
       triplet_params = triplet_params,
+      coco_mode = coco_mode,
       mx_model = model,
       label_name = label_name,
   )
@@ -516,6 +518,7 @@ def train_net(args):
       data_extra = data_extra,
       hard_mining = hard_mining,
       triplet_params = triplet_params,
+      coco_mode = coco_mode,
      mx_model = model,
       label_name = label_name,
   )
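Putting the pieces together, the training script now derives coco_mode from the loss type and threads it through to the iterator, which in turn emits identity-slot labels matching the centroid order in the symbol. A condensed sketch of the call path (argument values are illustrative; data-source arguments are omitted):

coco_mode = (args.loss_type == 9)           # loss_type 9 selects coco loss
train_dataiter = FaceImageIter(
    batch_size          = args.batch_size,
    data_shape          = (3, 112, 96),     # illustrative input shape
    shuffle             = True,
    images_per_identity = args.images_per_identity,
    coco_mode           = coco_mode,        # labels become identity slots
    mx_model            = model,
    label_name          = 'softmax_label',
)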