add c2c label

This commit is contained in:
Jia Guo
2018-01-30 20:13:45 +08:00
parent 24c8b372ad
commit f0cb7f6b52
2 changed files with 88 additions and 43 deletions

View File

@@ -137,6 +137,10 @@ def parse_args():
help='')
parser.add_argument('--margin-verbose', type=int, default=0,
help='')
parser.add_argument('--c2c-threshold', type=float, default=0.0,
help='')
parser.add_argument('--output-c2c', type=int, default=0,
help='')
parser.add_argument('--margin', type=int, default=4,
help='')
parser.add_argument('--beta', type=float, default=1000.,
@@ -214,7 +218,12 @@ def get_symbol(args, arg_params, aux_params):
embedding = fresnet.get_symbol(args.emb_size, args.num_layers,
version_se=args.version_se, version_input=args.version_input,
version_output=args.version_output, version_unit=args.version_unit)
gt_label = mx.symbol.Variable('softmax_label')
all_label = mx.symbol.Variable('softmax_label')
if not args.output_c2c:
gt_label = all_label
else:
gt_label = mx.symbol.slice_axis(all_label, axis=1, begin=0, end=1)
c2c_label = mx.symbol.slice_axis(all_label, axis=1, begin=1, end=2)
assert args.loss_type>=0
extra_loss = None
if args.loss_type==0: #softmax
@@ -302,46 +311,56 @@ def get_symbol(args, arg_params, aux_params):
s = args.margin_s
m = args.margin_m
assert s>0.0
assert m>0.0
assert m>=0.0
assert m<(math.pi/2)
cos_m = math.cos(m)
sin_m = math.sin(m)
mm = math.sin(math.pi-m)*m
_weight = mx.symbol.Variable("fc7_weight", shape=(args.num_classes, args.emb_size), lr_mult=1.0)
_weight = mx.symbol.L2Normalization(_weight, mode='instance')
#threshold = 0.0
threshold = math.cos(math.pi-m)
nembedding = mx.symbol.L2Normalization(embedding, mode='instance', name='fc1n')*s
fc7 = mx.sym.FullyConnected(data=nembedding, weight = _weight, no_bias = True, num_hidden=args.num_classes, name='fc7')
zy = mx.sym.pick(fc7, gt_label, axis=1)
cos_t = zy/s
if args.margin_verbose>0:
margin_symbols.append(mx.symbol.mean(cos_t))
if args.easy_margin:
cond = mx.symbol.Activation(data=cos_t, act_type='relu')
#cond_v = cos_t - 0.4
#cond = mx.symbol.Activation(data=cond_v, act_type='relu')
if m>0.0:
cos_m = math.cos(m)
sin_m = math.sin(m)
mm = math.sin(math.pi-m)*m
#threshold = 0.0
threshold = math.cos(math.pi-m)
if args.margin_verbose>0:
margin_symbols.append(mx.symbol.mean(cos_t))
if args.easy_margin:
cond = mx.symbol.Activation(data=cos_t, act_type='relu')
else:
cond_v = cos_t - threshold
cond = mx.symbol.Activation(data=cond_v, act_type='relu')
body = cos_t*cos_t
body = 1.0-body
sin_t = mx.sym.sqrt(body)
new_zy = cos_t*cos_m
b = sin_t*sin_m
new_zy = new_zy - b
new_zy = new_zy*s
if args.easy_margin:
zy_keep = zy
else:
zy_keep = zy - s*mm
new_zy = mx.sym.where(cond, new_zy, zy_keep)
else:
cond_v = cos_t - threshold
cond = mx.symbol.Activation(data=cond_v, act_type='relu')
#theta = mx.sym.arccos(costheta)
#sintheta = mx.sym.sin(theta)
body = cos_t*cos_t
body = 1.0-body
sin_t = mx.sym.sqrt(body)
new_zy = cos_t*cos_m
b = sin_t*sin_m
new_zy = new_zy - b
new_zy = new_zy*s
if args.easy_margin:
zy_keep = zy
else:
zy_keep = zy - s*mm
new_zy = mx.sym.where(cond, new_zy, zy_keep)
assert args.output_c2c
#set c2c as cosm^2 in data.py
cos_m = mx.sym.sqrt(c2c_label)
sin_m = 1.0-c2c_label
sin_m = mx.sym.sqrt(sin_m)
body = cos_t*cos_t
body = 1.0-body
sin_t = mx.sym.sqrt(body)
new_zy = cos_t*cos_m
b = sin_t*sin_m
new_zy = new_zy - b
new_zy = new_zy*s
if args.margin_verbose>0:
new_cos_t = new_zy/s
margin_symbols.append(mx.symbol.mean(new_cos_t))
diff = new_zy - zy
diff = mx.sym.expand_dims(diff, 1)
gt_one_hot = mx.sym.one_hot(gt_label, depth = args.num_classes, on_value = 1.0, off_value = 0.0)
@@ -638,6 +657,7 @@ def train_net(args):
shuffle = True,
rand_mirror = args.rand_mirror,
mean = mean,
c2c_threshold = args.c2c_threshold,
ctx_num = args.ctx_num,
images_per_identity = args.images_per_identity,
data_extra = data_extra,
@@ -658,6 +678,7 @@ def train_net(args):
shuffle = True,
rand_mirror = args.rand_mirror,
mean = mean,
c2c_threshold = args.c2c_threshold,
ctx_num = args.ctx_num,
images_per_identity = args.images_per_identity,
data_extra = data_extra,