mirror of
https://github.com/deepinsight/insightface.git
synced 2026-05-20 08:25:14 +00:00
tiny
This commit is contained in:
36
src/data.py
36
src/data.py
@@ -108,34 +108,34 @@ class FaceImageIter(io.DataIter):
|
||||
self.imgidx = imgidx2
|
||||
elif c2c_mode==1:
|
||||
imgidx2 = []
|
||||
tmp = []
|
||||
for idx in self.imgidx:
|
||||
c = self.idx2cos[idx]
|
||||
f = self.idx2flag[idx]
|
||||
if f==2 and c>=0.05:
|
||||
continue
|
||||
imgidx2.append(idx)
|
||||
if f==1:
|
||||
imgidx2.append(idx)
|
||||
else:
|
||||
tmp.append( (idx, c) )
|
||||
tmp = sorted(tmp, key = lambda x:x[1])
|
||||
tmp = tmp[250000:300000]
|
||||
for _t in tmp:
|
||||
imgidx2.append(_t[0])
|
||||
print('idx count', len(self.imgidx), len(imgidx2))
|
||||
self.imgidx = imgidx2
|
||||
elif c2c_mode==2:
|
||||
imgidx2 = []
|
||||
tmp = []
|
||||
for idx in self.imgidx:
|
||||
c = self.idx2cos[idx]
|
||||
f = self.idx2flag[idx]
|
||||
if f==2 and c>=0.1:
|
||||
continue
|
||||
imgidx2.append(idx)
|
||||
print('idx count', len(self.imgidx), len(imgidx2))
|
||||
self.imgidx = imgidx2
|
||||
elif c2c_mode==-1:
|
||||
imgidx2 = []
|
||||
for idx in self.imgidx:
|
||||
c = self.idx2cos[idx]
|
||||
f = self.idx2flag[idx]
|
||||
if f==2:
|
||||
continue
|
||||
if c<0.7:
|
||||
continue
|
||||
imgidx2.append(idx)
|
||||
if f==1:
|
||||
imgidx2.append(idx)
|
||||
else:
|
||||
tmp.append( (idx, c) )
|
||||
tmp = sorted(tmp, key = lambda x:x[1])
|
||||
tmp = tmp[200000:300000]
|
||||
for _t in tmp:
|
||||
imgidx2.append(_t[0])
|
||||
print('idx count', len(self.imgidx), len(imgidx2))
|
||||
self.imgidx = imgidx2
|
||||
elif c2c_mode==-2:
|
||||
|
||||
@@ -445,7 +445,7 @@ def get_symbol(args, arg_params, aux_params):
|
||||
triplet_loss = mx.symbol.mean(triplet_loss)
|
||||
#triplet_loss = mx.symbol.sum(triplet_loss)/(args.per_batch_size//3)
|
||||
extra_loss = mx.symbol.MakeLoss(triplet_loss)
|
||||
elif args.loss_type==13: #triplet loss with insightface margin
|
||||
elif args.loss_type==13: #triplet loss with angular margin
|
||||
m = args.margin_m
|
||||
sin_m = math.sin(m)
|
||||
cos_m = math.cos(m)
|
||||
@@ -457,16 +457,19 @@ def get_symbol(args, arg_params, aux_params):
|
||||
an = anchor * negative
|
||||
ap = mx.symbol.sum(ap, axis=1, keepdims=1) #(T,1)
|
||||
an = mx.symbol.sum(an, axis=1, keepdims=1) #(T,1)
|
||||
#ap = mx.symbol.arccos(ap)
|
||||
#an = mx.symbol.arccos(an)
|
||||
#triplet_loss = mx.symbol.Activation(data = (ap-an+args.margin_m), act_type='relu')
|
||||
body = ap*ap
|
||||
body = 1.0-body
|
||||
body = mx.symbol.sqrt(body)
|
||||
body = body*sin_m
|
||||
ap = ap*cos_m
|
||||
ap = ap-body
|
||||
triplet_loss = mx.symbol.Activation(data = (an-ap), act_type='relu')
|
||||
|
||||
ap = mx.symbol.arccos(ap)
|
||||
an = mx.symbol.arccos(an)
|
||||
triplet_loss = mx.symbol.Activation(data = (ap-an+args.margin_m), act_type='relu')
|
||||
|
||||
#body = ap*ap
|
||||
#body = 1.0-body
|
||||
#body = mx.symbol.sqrt(body)
|
||||
#body = body*sin_m
|
||||
#ap = ap*cos_m
|
||||
#ap = ap-body
|
||||
#triplet_loss = mx.symbol.Activation(data = (an-ap), act_type='relu')
|
||||
|
||||
triplet_loss = mx.symbol.mean(triplet_loss)
|
||||
extra_loss = mx.symbol.MakeLoss(triplet_loss)
|
||||
elif args.loss_type==9: #coco loss
|
||||
|
||||
Reference in New Issue
Block a user