diff --git a/src/data.py b/src/data.py index d8f273f..1466d8f 100644 --- a/src/data.py +++ b/src/data.py @@ -108,34 +108,34 @@ class FaceImageIter(io.DataIter): self.imgidx = imgidx2 elif c2c_mode==1: imgidx2 = [] + tmp = [] for idx in self.imgidx: c = self.idx2cos[idx] f = self.idx2flag[idx] - if f==2 and c>=0.05: - continue - imgidx2.append(idx) + if f==1: + imgidx2.append(idx) + else: + tmp.append( (idx, c) ) + tmp = sorted(tmp, key = lambda x:x[1]) + tmp = tmp[250000:300000] + for _t in tmp: + imgidx2.append(_t[0]) print('idx count', len(self.imgidx), len(imgidx2)) self.imgidx = imgidx2 elif c2c_mode==2: imgidx2 = [] + tmp = [] for idx in self.imgidx: c = self.idx2cos[idx] f = self.idx2flag[idx] - if f==2 and c>=0.1: - continue - imgidx2.append(idx) - print('idx count', len(self.imgidx), len(imgidx2)) - self.imgidx = imgidx2 - elif c2c_mode==-1: - imgidx2 = [] - for idx in self.imgidx: - c = self.idx2cos[idx] - f = self.idx2flag[idx] - if f==2: - continue - if c<0.7: - continue - imgidx2.append(idx) + if f==1: + imgidx2.append(idx) + else: + tmp.append( (idx, c) ) + tmp = sorted(tmp, key = lambda x:x[1]) + tmp = tmp[200000:300000] + for _t in tmp: + imgidx2.append(_t[0]) print('idx count', len(self.imgidx), len(imgidx2)) self.imgidx = imgidx2 elif c2c_mode==-2: diff --git a/src/train_softmax.py b/src/train_softmax.py index c42d251..d3eda2d 100644 --- a/src/train_softmax.py +++ b/src/train_softmax.py @@ -54,7 +54,7 @@ class AccMetric(mx.metric.EvalMetric): def update(self, labels, preds): self.count+=1 - if args.loss_type>=2 and args.loss_type<=5 and args.margin_verbose>0: + if args.loss_type>=2 and args.loss_type<=7 and args.margin_verbose>0: if self.count%args.ctx_num==0: mbatch = self.count//args.ctx_num if mbatch==1 or mbatch%args.margin_verbose==0: @@ -137,6 +137,8 @@ def parse_args(): help='') parser.add_argument('--margin-s', type=float, default=64.0, help='') + parser.add_argument('--margin-a', type=float, default=0.0, + help='') parser.add_argument('--easy-margin', type=int, default=0, help='') parser.add_argument('--margin-verbose', type=int, default=0, @@ -399,13 +401,93 @@ def get_symbol(args, arg_params, aux_params): assert m<(math.pi/2) _weight = mx.symbol.Variable("fc7_weight", shape=(args.num_classes, args.emb_size), lr_mult=1.0) _weight = mx.symbol.L2Normalization(_weight, mode='instance') - nembedding = mx.symbol.L2Normalization(embedding, mode='instance', name='fc1n') - cos_a = mx.sym.FullyConnected(data=nembedding, weight = _weight, no_bias = True, num_hidden=args.num_classes, name='fc7') - theta_a = mx.sym.arccos(cos_a) - gt_one_hot = mx.sym.one_hot(gt_label, depth = args.num_classes, on_value = m, off_value = 0.0) - theta_a = theta_a+gt_one_hot - fc7 = math.pi/2 - theta_a - fc7 = fc7*s + #nembedding = mx.symbol.L2Normalization(embedding, mode='instance', name='fc1n') + #cos_a = mx.sym.FullyConnected(data=nembedding, weight = _weight, no_bias = True, num_hidden=args.num_classes, name='fc7') + #theta_a = mx.sym.arccos(cos_a) + #gt_one_hot = mx.sym.one_hot(gt_label, depth = args.num_classes, on_value = m, off_value = 0.0) + #theta_a = theta_a+gt_one_hot + #fc7 = math.pi/2 - theta_a + #fc7 = fc7*s + nembedding = mx.symbol.L2Normalization(embedding, mode='instance', name='fc1n')*s + fc7 = mx.sym.FullyConnected(data=nembedding, weight = _weight, no_bias = True, num_hidden=args.num_classes, name='fc7') + zy = mx.sym.pick(fc7, gt_label, axis=1) + cos_t = zy/s + if args.margin_verbose>0: + margin_symbols.append(mx.symbol.mean(cos_t)) + if m>0.0: + #m = m*1.1 + #m_min = 0.3 + #var_m = m + #cos_ta = mx.symbol.Activation(data=cos_t, act_type='relu') + #cos_ta = cos_t + 1.001 + cos_ta = cos_t - 0.7 + cos_ta = mx.symbol.Activation(data=cos_ta, act_type='relu') + #cos_t_max = mx.symbol.max(cos_ta) + #cos_t_min = mx.symbol.min(cos_ta) + #cos_t_gap = cos_t_max-cos_t_min + #cos_t_max = cos_t_max + 1.0e-6 + #r = mx.symbol.broadcast_div(cos_ta,cos_t_max) + #r = cos_ta / 1.7 + r = cos_ta+0.7 + var_m = r*m + + t = mx.sym.arccos(cos_t) + t = t+var_m + body = mx.sym.cos(t) + new_zy = body*s + if args.margin_verbose>0: + new_cos_t = new_zy/s + margin_symbols.append(mx.symbol.mean(new_cos_t)) + #margin_symbols.append(mx.symbol.mean(var_m)) + diff = new_zy - zy + diff = mx.sym.expand_dims(diff, 1) + gt_one_hot = mx.sym.one_hot(gt_label, depth = args.num_classes, on_value = 1.0, off_value = 0.0) + body = mx.sym.broadcast_mul(gt_one_hot, diff) + fc7 = fc7+body + elif args.loss_type==6: + s = args.margin_s + m = args.margin_m + assert s>0.0 + assert m>=0.0 + assert m<(math.pi/2) + _weight = mx.symbol.Variable("fc7_weight", shape=(args.num_classes, args.emb_size), lr_mult=1.0) + _weight = mx.symbol.L2Normalization(_weight, mode='instance') + nembedding = mx.symbol.L2Normalization(embedding, mode='instance', name='fc1n')*s + fc7 = mx.sym.FullyConnected(data=nembedding, weight = _weight, no_bias = True, num_hidden=args.num_classes, name='fc7') + zy = mx.sym.pick(fc7, gt_label, axis=1) + cos_t = zy/s + t = mx.sym.arccos(cos_t) + if args.margin_verbose>0: + margin_symbols.append(mx.symbol.mean(t)) + t_min = mx.sym.min(t) + ta = mx.sym.broadcast_div(t_min, t) + + a1 = args.margin_a + r1 = ta-a1 + r1 = mx.symbol.Activation(data=r1, act_type='relu') + r1 = r1+a1 + + a2 = 1.0 + r2 = ta-a2 + r2 = mx.symbol.Activation(data=r2, act_type='relu') + r2 = r2+a2 + + cond = t-1.0 + cond = mx.symbol.Activation(data=cond, act_type='relu') + r = mx.sym.where(cond, r2, r1) + var_m = r*m + t = t+var_m + body = mx.sym.cos(t) + new_zy = body*s + if args.margin_verbose>0: + #new_cos_t = new_zy/s + #margin_symbols.append(mx.symbol.mean(new_cos_t)) + margin_symbols.append(mx.symbol.mean(t)) + diff = new_zy - zy + diff = mx.sym.expand_dims(diff, 1) + gt_one_hot = mx.sym.one_hot(gt_label, depth = args.num_classes, on_value = 1.0, off_value = 0.0) + body = mx.sym.broadcast_mul(gt_one_hot, diff) + fc7 = fc7+body elif args.loss_type==10: #marginal loss nembedding = mx.symbol.L2Normalization(embedding, mode='instance', name='fc1n') params = [1.2, 0.3, 1.0] @@ -475,7 +557,7 @@ def get_symbol(args, arg_params, aux_params): triplet_loss = mx.symbol.mean(triplet_loss) #triplet_loss = mx.symbol.sum(triplet_loss)/(args.per_batch_size//3) extra_loss = mx.symbol.MakeLoss(triplet_loss) - elif args.loss_type==13: #triplet loss with insightface margin + elif args.loss_type==13: #triplet loss with angular margin m = args.margin_m sin_m = math.sin(m) cos_m = math.cos(m) @@ -487,16 +569,19 @@ def get_symbol(args, arg_params, aux_params): an = anchor * negative ap = mx.symbol.sum(ap, axis=1, keepdims=1) #(T,1) an = mx.symbol.sum(an, axis=1, keepdims=1) #(T,1) - #ap = mx.symbol.arccos(ap) - #an = mx.symbol.arccos(an) - #triplet_loss = mx.symbol.Activation(data = (ap-an+args.margin_m), act_type='relu') - body = ap*ap - body = 1.0-body - body = mx.symbol.sqrt(body) - body = body*sin_m - ap = ap*cos_m - ap = ap-body - triplet_loss = mx.symbol.Activation(data = (an-ap), act_type='relu') + + ap = mx.symbol.arccos(ap) + an = mx.symbol.arccos(an) + triplet_loss = mx.symbol.Activation(data = (ap-an+args.margin_m), act_type='relu') + + #body = ap*ap + #body = 1.0-body + #body = mx.symbol.sqrt(body) + #body = body*sin_m + #ap = ap*cos_m + #ap = ap-body + #triplet_loss = mx.symbol.Activation(data = (an-ap), act_type='relu') + triplet_loss = mx.symbol.mean(triplet_loss) extra_loss = mx.symbol.MakeLoss(triplet_loss) elif args.loss_type==9: #coco loss @@ -833,7 +918,7 @@ def train_net(args): save_step = [0] if len(args.lr_steps)==0: lr_steps = [40000, 60000, 80000] - if args.loss_type>=1 and args.loss_type<=5: + if args.loss_type>=1 and args.loss_type<=7: lr_steps = [100000, 140000, 160000] p = 512.0/args.batch_size for l in xrange(len(lr_steps)):