This commit is contained in:
nttstar
2018-02-19 19:13:20 +08:00
parent ef24eecc05
commit 507e644d52
2 changed files with 32 additions and 29 deletions

View File

@@ -108,34 +108,34 @@ class FaceImageIter(io.DataIter):
self.imgidx = imgidx2
elif c2c_mode==1:
imgidx2 = []
tmp = []
for idx in self.imgidx:
c = self.idx2cos[idx]
f = self.idx2flag[idx]
if f==2 and c>=0.05:
continue
imgidx2.append(idx)
if f==1:
imgidx2.append(idx)
else:
tmp.append( (idx, c) )
tmp = sorted(tmp, key = lambda x:x[1])
tmp = tmp[250000:300000]
for _t in tmp:
imgidx2.append(_t[0])
print('idx count', len(self.imgidx), len(imgidx2))
self.imgidx = imgidx2
elif c2c_mode==2:
imgidx2 = []
tmp = []
for idx in self.imgidx:
c = self.idx2cos[idx]
f = self.idx2flag[idx]
if f==2 and c>=0.1:
continue
imgidx2.append(idx)
print('idx count', len(self.imgidx), len(imgidx2))
self.imgidx = imgidx2
elif c2c_mode==-1:
imgidx2 = []
for idx in self.imgidx:
c = self.idx2cos[idx]
f = self.idx2flag[idx]
if f==2:
continue
if c<0.7:
continue
imgidx2.append(idx)
if f==1:
imgidx2.append(idx)
else:
tmp.append( (idx, c) )
tmp = sorted(tmp, key = lambda x:x[1])
tmp = tmp[200000:300000]
for _t in tmp:
imgidx2.append(_t[0])
print('idx count', len(self.imgidx), len(imgidx2))
self.imgidx = imgidx2
elif c2c_mode==-2:

View File

@@ -445,7 +445,7 @@ def get_symbol(args, arg_params, aux_params):
triplet_loss = mx.symbol.mean(triplet_loss)
#triplet_loss = mx.symbol.sum(triplet_loss)/(args.per_batch_size//3)
extra_loss = mx.symbol.MakeLoss(triplet_loss)
elif args.loss_type==13: #triplet loss with insightface margin
elif args.loss_type==13: #triplet loss with angular margin
m = args.margin_m
sin_m = math.sin(m)
cos_m = math.cos(m)
@@ -457,16 +457,19 @@ def get_symbol(args, arg_params, aux_params):
an = anchor * negative
ap = mx.symbol.sum(ap, axis=1, keepdims=1) #(T,1)
an = mx.symbol.sum(an, axis=1, keepdims=1) #(T,1)
#ap = mx.symbol.arccos(ap)
#an = mx.symbol.arccos(an)
#triplet_loss = mx.symbol.Activation(data = (ap-an+args.margin_m), act_type='relu')
body = ap*ap
body = 1.0-body
body = mx.symbol.sqrt(body)
body = body*sin_m
ap = ap*cos_m
ap = ap-body
triplet_loss = mx.symbol.Activation(data = (an-ap), act_type='relu')
ap = mx.symbol.arccos(ap)
an = mx.symbol.arccos(an)
triplet_loss = mx.symbol.Activation(data = (ap-an+args.margin_m), act_type='relu')
#body = ap*ap
#body = 1.0-body
#body = mx.symbol.sqrt(body)
#body = body*sin_m
#ap = ap*cos_m
#ap = ap-body
#triplet_loss = mx.symbol.Activation(data = (an-ap), act_type='relu')
triplet_loss = mx.symbol.mean(triplet_loss)
extra_loss = mx.symbol.MakeLoss(triplet_loss)
elif args.loss_type==9: #coco loss