diff --git a/src/data.py b/src/data.py index d1b1789..7ff0438 100644 --- a/src/data.py +++ b/src/data.py @@ -769,6 +769,14 @@ class FaceImageIter(io.DataIter): v = numpy.random.normal(mu, sigma) v = math.abs(v)*-1.0+mrange[1] v = max(v, mrange[0]) + elif self.output_c2c==5: + v = np.random.uniform(0.41, 0.51) + if count>=175: + v = np.random.uniform(0.37, 0.47) + elif self.output_c2c==6: + v = np.random.uniform(0.41, 0.51) + if count>=175: + v = np.random.uniform(0.38, 0.48) else: assert False diff --git a/src/train_softmax.py b/src/train_softmax.py index dd50245..02dd32d 100644 --- a/src/train_softmax.py +++ b/src/train_softmax.py @@ -133,12 +133,14 @@ def parse_args(): help='') parser.add_argument('--per-batch-size', type=int, default=128, help='') - parser.add_argument('--margin-m', type=float, default=0.35, + parser.add_argument('--margin-m', type=float, default=0.5, help='') parser.add_argument('--margin-s', type=float, default=64.0, help='') parser.add_argument('--margin-a', type=float, default=0.0, help='') + parser.add_argument('--margin-b', type=float, default=0.0, + help='') parser.add_argument('--easy-margin', type=int, default=0, help='') parser.add_argument('--margin-verbose', type=int, default=0, @@ -172,7 +174,7 @@ def parse_args(): parser.add_argument('--triplet-alpha', type=float, default=0.3, help='') parser.add_argument('--triplet-max-ap', type=float, default=0.0, help='') parser.add_argument('--verbose', type=int, default=2000, help='') - parser.add_argument('--loss-type', type=int, default=1, + parser.add_argument('--loss-type', type=int, default=4, help='') parser.add_argument('--incay', type=float, default=0.0, help='feature incay') @@ -396,48 +398,69 @@ def get_symbol(args, arg_params, aux_params): body = mx.sym.broadcast_mul(gt_one_hot, diff) fc7 = fc7+body elif args.loss_type==5: + #s = args.margin_s + #m = args.margin_m + #assert s>0.0 + #assert m>=0.0 + #assert m<(math.pi/2) + #_weight = mx.symbol.Variable("fc7_weight", shape=(args.num_classes, args.emb_size), lr_mult=1.0) + #_weight = mx.symbol.L2Normalization(_weight, mode='instance') + #nembedding = mx.symbol.L2Normalization(embedding, mode='instance', name='fc1n')*s + #fc7 = mx.sym.FullyConnected(data=nembedding, weight = _weight, no_bias = True, num_hidden=args.num_classes, name='fc7') + #zy = mx.sym.pick(fc7, gt_label, axis=1) + #cos_t = zy/s + #if args.margin_verbose>0: + # margin_symbols.append(mx.symbol.mean(cos_t)) + #if m>0.0: + # a1 = args.margin_a + # r1 = ta-a1 + # r1 = mx.symbol.Activation(data=r1, act_type='relu') + # r1 = r1+a1 + # t = mx.sym.arccos(cos_t) + # cond = t-1.0 + # cond = mx.symbol.Activation(data=cond, act_type='relu') + # r = mx.sym.where(cond, r2, r1) + # t = t+var_m + # body = mx.sym.cos(t) + # new_zy = body*s + # if args.margin_verbose>0: + # new_cos_t = new_zy/s + # margin_symbols.append(mx.symbol.mean(new_cos_t)) + # #margin_symbols.append(mx.symbol.mean(var_m)) + # diff = new_zy - zy + # diff = mx.sym.expand_dims(diff, 1) + # gt_one_hot = mx.sym.one_hot(gt_label, depth = args.num_classes, on_value = 1.0, off_value = 0.0) + # body = mx.sym.broadcast_mul(gt_one_hot, diff) + # fc7 = fc7+body s = args.margin_s m = args.margin_m assert s>0.0 - assert m>=0.0 - assert m<(math.pi/2) + #assert m>=0.0 + #assert m<(math.pi/2) _weight = mx.symbol.Variable("fc7_weight", shape=(args.num_classes, args.emb_size), lr_mult=1.0) _weight = mx.symbol.L2Normalization(_weight, mode='instance') - #nembedding = mx.symbol.L2Normalization(embedding, mode='instance', name='fc1n') - #cos_a = mx.sym.FullyConnected(data=nembedding, weight = _weight, no_bias = True, num_hidden=args.num_classes, name='fc7') - #theta_a = mx.sym.arccos(cos_a) - #gt_one_hot = mx.sym.one_hot(gt_label, depth = args.num_classes, on_value = m, off_value = 0.0) - #theta_a = theta_a+gt_one_hot - #fc7 = math.pi/2 - theta_a - #fc7 = fc7*s nembedding = mx.symbol.L2Normalization(embedding, mode='instance', name='fc1n')*s fc7 = mx.sym.FullyConnected(data=nembedding, weight = _weight, no_bias = True, num_hidden=args.num_classes, name='fc7') zy = mx.sym.pick(fc7, gt_label, axis=1) cos_t = zy/s + t = mx.sym.arccos(cos_t) if args.margin_verbose>0: - margin_symbols.append(mx.symbol.mean(cos_t)) - if m>0.0: - - a1 = args.margin_a - r1 = ta-a1 - r1 = mx.symbol.Activation(data=r1, act_type='relu') - r1 = r1+a1 - t = mx.sym.arccos(cos_t) - cond = t-1.0 - cond = mx.symbol.Activation(data=cond, act_type='relu') - r = mx.sym.where(cond, r2, r1) - t = t+var_m - body = mx.sym.cos(t) - new_zy = body*s - if args.margin_verbose>0: - new_cos_t = new_zy/s - margin_symbols.append(mx.symbol.mean(new_cos_t)) - #margin_symbols.append(mx.symbol.mean(var_m)) - diff = new_zy - zy - diff = mx.sym.expand_dims(diff, 1) - gt_one_hot = mx.sym.one_hot(gt_label, depth = args.num_classes, on_value = 1.0, off_value = 0.0) - body = mx.sym.broadcast_mul(gt_one_hot, diff) - fc7 = fc7+body + margin_symbols.append(mx.symbol.mean(t)) + if args.margin_a>0.0: + t = t*args.margin_a + if args.margin_m>0.0: + t = t+args.margin_m + body = mx.sym.cos(t) + if args.margin_b>0.0: + body = body - args.margin_b + new_zy = body*s + if args.margin_verbose>0: + margin_symbols.append(mx.symbol.mean(t)) + diff = new_zy - zy + diff = mx.sym.expand_dims(diff, 1) + gt_one_hot = mx.sym.one_hot(gt_label, depth = args.num_classes, on_value = 1.0, off_value = 0.0) + body = mx.sym.broadcast_mul(gt_one_hot, diff) + fc7 = fc7+body elif args.loss_type==6: s = args.margin_s m = args.margin_m