From 507e644d522d472441d107d2e5b36fcbca375b3e Mon Sep 17 00:00:00 2001 From: nttstar Date: Mon, 19 Feb 2018 19:13:20 +0800 Subject: [PATCH] tiny --- src/data.py | 36 ++++++++++++++++++------------------ src/train_softmax.py | 25 ++++++++++++++----------- 2 files changed, 32 insertions(+), 29 deletions(-) diff --git a/src/data.py b/src/data.py index af8d0c9..b7efff1 100644 --- a/src/data.py +++ b/src/data.py @@ -108,34 +108,34 @@ class FaceImageIter(io.DataIter): self.imgidx = imgidx2 elif c2c_mode==1: imgidx2 = [] + tmp = [] for idx in self.imgidx: c = self.idx2cos[idx] f = self.idx2flag[idx] - if f==2 and c>=0.05: - continue - imgidx2.append(idx) + if f==1: + imgidx2.append(idx) + else: + tmp.append( (idx, c) ) + tmp = sorted(tmp, key = lambda x:x[1]) + tmp = tmp[250000:300000] + for _t in tmp: + imgidx2.append(_t[0]) print('idx count', len(self.imgidx), len(imgidx2)) self.imgidx = imgidx2 elif c2c_mode==2: imgidx2 = [] + tmp = [] for idx in self.imgidx: c = self.idx2cos[idx] f = self.idx2flag[idx] - if f==2 and c>=0.1: - continue - imgidx2.append(idx) - print('idx count', len(self.imgidx), len(imgidx2)) - self.imgidx = imgidx2 - elif c2c_mode==-1: - imgidx2 = [] - for idx in self.imgidx: - c = self.idx2cos[idx] - f = self.idx2flag[idx] - if f==2: - continue - if c<0.7: - continue - imgidx2.append(idx) + if f==1: + imgidx2.append(idx) + else: + tmp.append( (idx, c) ) + tmp = sorted(tmp, key = lambda x:x[1]) + tmp = tmp[200000:300000] + for _t in tmp: + imgidx2.append(_t[0]) print('idx count', len(self.imgidx), len(imgidx2)) self.imgidx = imgidx2 elif c2c_mode==-2: diff --git a/src/train_softmax.py b/src/train_softmax.py index eec7ac3..df36a02 100644 --- a/src/train_softmax.py +++ b/src/train_softmax.py @@ -445,7 +445,7 @@ def get_symbol(args, arg_params, aux_params): triplet_loss = mx.symbol.mean(triplet_loss) #triplet_loss = mx.symbol.sum(triplet_loss)/(args.per_batch_size//3) extra_loss = mx.symbol.MakeLoss(triplet_loss) - elif args.loss_type==13: #triplet loss with insightface margin + elif args.loss_type==13: #triplet loss with angular margin m = args.margin_m sin_m = math.sin(m) cos_m = math.cos(m) @@ -457,16 +457,19 @@ def get_symbol(args, arg_params, aux_params): an = anchor * negative ap = mx.symbol.sum(ap, axis=1, keepdims=1) #(T,1) an = mx.symbol.sum(an, axis=1, keepdims=1) #(T,1) - #ap = mx.symbol.arccos(ap) - #an = mx.symbol.arccos(an) - #triplet_loss = mx.symbol.Activation(data = (ap-an+args.margin_m), act_type='relu') - body = ap*ap - body = 1.0-body - body = mx.symbol.sqrt(body) - body = body*sin_m - ap = ap*cos_m - ap = ap-body - triplet_loss = mx.symbol.Activation(data = (an-ap), act_type='relu') + + ap = mx.symbol.arccos(ap) + an = mx.symbol.arccos(an) + triplet_loss = mx.symbol.Activation(data = (ap-an+args.margin_m), act_type='relu') + + #body = ap*ap + #body = 1.0-body + #body = mx.symbol.sqrt(body) + #body = body*sin_m + #ap = ap*cos_m + #ap = ap-body + #triplet_loss = mx.symbol.Activation(data = (an-ap), act_type='relu') + triplet_loss = mx.symbol.mean(triplet_loss) extra_loss = mx.symbol.MakeLoss(triplet_loss) elif args.loss_type==9: #coco loss