mirror of
https://github.com/deepinsight/insightface.git
synced 2026-05-18 14:55:42 +00:00
add c2c label
This commit is contained in:
50
src/data.py
50
src/data.py
@@ -57,6 +57,7 @@ class FaceImageIter(io.DataIter):
|
||||
path_imgrec = None,
|
||||
shuffle=False, aug_list=None, mean = None,
|
||||
rand_mirror = False,
|
||||
c2c_threshold = 0.0, output_c2c = 0,
|
||||
ctx_num = 0, images_per_identity = 0, data_extra = None, hard_mining = False,
|
||||
triplet_params = None, coco_mode = False,
|
||||
mx_model = None,
|
||||
@@ -110,6 +111,8 @@ class FaceImageIter(io.DataIter):
|
||||
#self.cast_aug = mx.image.CastAug()
|
||||
#self.color_aug = mx.image.ColorJitterAug(0.4, 0.4, 0.4)
|
||||
self.ctx_num = ctx_num
|
||||
self.c2c_threshold = c2c_threshold
|
||||
self.output_c2c = output_c2c
|
||||
self.per_batch_size = int(self.batch_size/self.ctx_num)
|
||||
self.images_per_identity = images_per_identity
|
||||
if self.images_per_identity>0:
|
||||
@@ -131,7 +134,10 @@ class FaceImageIter(io.DataIter):
|
||||
self.triplet_mode = False
|
||||
self.coco_mode = coco_mode
|
||||
if len(label_name)>0:
|
||||
self.provide_label = [(label_name, (batch_size,))]
|
||||
if output_c2c:
|
||||
self.provide_label = [(label_name, (batch_size,2))]
|
||||
else:
|
||||
self.provide_label = [(label_name, (batch_size,))]
|
||||
else:
|
||||
self.provide_label = []
|
||||
if self.coco_mode:
|
||||
@@ -575,17 +581,24 @@ class FaceImageIter(io.DataIter):
|
||||
"""Helper function for reading in next sample."""
|
||||
#set total batch size, for example, 1800, and maximum size for each people, for example 45
|
||||
if self.seq is not None:
|
||||
if self.cur >= len(self.seq):
|
||||
raise StopIteration
|
||||
idx = self.seq[self.cur]
|
||||
self.cur += 1
|
||||
if self.imgrec is not None:
|
||||
s = self.imgrec.read_idx(idx)
|
||||
header, img = recordio.unpack(s)
|
||||
return header.label, img, None, None
|
||||
else:
|
||||
label, fname, bbox, landmark = self.imglist[idx]
|
||||
return label, self.read_image(fname), bbox, landmark
|
||||
while True:
|
||||
if self.cur >= len(self.seq):
|
||||
raise StopIteration
|
||||
idx = self.seq[self.cur]
|
||||
self.cur += 1
|
||||
if self.imgrec is not None:
|
||||
s = self.imgrec.read_idx(idx)
|
||||
header, img = recordio.unpack(s)
|
||||
label = header.label
|
||||
if not isinstance(header.label, numbers.Number):
|
||||
label = header.label[0]
|
||||
c = header.label[1]
|
||||
if c<self.c2c_threshold:
|
||||
continue
|
||||
return header.label, img, None, None
|
||||
else:
|
||||
label, fname, bbox, landmark = self.imglist[idx]
|
||||
return label, self.read_image(fname), bbox, landmark
|
||||
else:
|
||||
s = self.imgrec.read()
|
||||
if s is None:
|
||||
@@ -685,7 +698,18 @@ class FaceImageIter(io.DataIter):
|
||||
batch_data[i][:] = self.postprocess_data(datum)
|
||||
if self.provide_label is not None:
|
||||
if not self.coco_mode:
|
||||
batch_label[i][:] = label
|
||||
if len(batch_label.shape)==1:
|
||||
batch_label[i][:] = label
|
||||
else:
|
||||
for ll in xrange(batch_label.shape[1]):
|
||||
v = label[ll]
|
||||
if ll>0:
|
||||
v = min(0.5, max(0.25,math.log(v+1)*4-1.85))
|
||||
v = math.cos(v)
|
||||
v = v*v
|
||||
print('c2c', i,v)
|
||||
|
||||
batch_label[i][ll] = v
|
||||
else:
|
||||
batch_label[i][:] = (i%self.per_batch_size)//self.images_per_identity
|
||||
i += 1
|
||||
|
||||
@@ -137,6 +137,10 @@ def parse_args():
|
||||
help='')
|
||||
parser.add_argument('--margin-verbose', type=int, default=0,
|
||||
help='')
|
||||
parser.add_argument('--c2c-threshold', type=float, default=0.0,
|
||||
help='')
|
||||
parser.add_argument('--output-c2c', type=int, default=0,
|
||||
help='')
|
||||
parser.add_argument('--margin', type=int, default=4,
|
||||
help='')
|
||||
parser.add_argument('--beta', type=float, default=1000.,
|
||||
@@ -214,7 +218,12 @@ def get_symbol(args, arg_params, aux_params):
|
||||
embedding = fresnet.get_symbol(args.emb_size, args.num_layers,
|
||||
version_se=args.version_se, version_input=args.version_input,
|
||||
version_output=args.version_output, version_unit=args.version_unit)
|
||||
gt_label = mx.symbol.Variable('softmax_label')
|
||||
all_label = mx.symbol.Variable('softmax_label')
|
||||
if not args.output_c2c:
|
||||
gt_label = all_label
|
||||
else:
|
||||
gt_label = mx.symbol.slice_axis(all_label, axis=1, begin=0, end=1)
|
||||
c2c_label = mx.symbol.slice_axis(all_label, axis=1, begin=1, end=2)
|
||||
assert args.loss_type>=0
|
||||
extra_loss = None
|
||||
if args.loss_type==0: #softmax
|
||||
@@ -302,46 +311,56 @@ def get_symbol(args, arg_params, aux_params):
|
||||
s = args.margin_s
|
||||
m = args.margin_m
|
||||
assert s>0.0
|
||||
assert m>0.0
|
||||
assert m>=0.0
|
||||
assert m<(math.pi/2)
|
||||
cos_m = math.cos(m)
|
||||
sin_m = math.sin(m)
|
||||
mm = math.sin(math.pi-m)*m
|
||||
_weight = mx.symbol.Variable("fc7_weight", shape=(args.num_classes, args.emb_size), lr_mult=1.0)
|
||||
_weight = mx.symbol.L2Normalization(_weight, mode='instance')
|
||||
#threshold = 0.0
|
||||
threshold = math.cos(math.pi-m)
|
||||
nembedding = mx.symbol.L2Normalization(embedding, mode='instance', name='fc1n')*s
|
||||
fc7 = mx.sym.FullyConnected(data=nembedding, weight = _weight, no_bias = True, num_hidden=args.num_classes, name='fc7')
|
||||
zy = mx.sym.pick(fc7, gt_label, axis=1)
|
||||
cos_t = zy/s
|
||||
if args.margin_verbose>0:
|
||||
margin_symbols.append(mx.symbol.mean(cos_t))
|
||||
if args.easy_margin:
|
||||
cond = mx.symbol.Activation(data=cos_t, act_type='relu')
|
||||
#cond_v = cos_t - 0.4
|
||||
#cond = mx.symbol.Activation(data=cond_v, act_type='relu')
|
||||
if m>0.0:
|
||||
cos_m = math.cos(m)
|
||||
sin_m = math.sin(m)
|
||||
mm = math.sin(math.pi-m)*m
|
||||
#threshold = 0.0
|
||||
threshold = math.cos(math.pi-m)
|
||||
if args.margin_verbose>0:
|
||||
margin_symbols.append(mx.symbol.mean(cos_t))
|
||||
if args.easy_margin:
|
||||
cond = mx.symbol.Activation(data=cos_t, act_type='relu')
|
||||
else:
|
||||
cond_v = cos_t - threshold
|
||||
cond = mx.symbol.Activation(data=cond_v, act_type='relu')
|
||||
body = cos_t*cos_t
|
||||
body = 1.0-body
|
||||
sin_t = mx.sym.sqrt(body)
|
||||
new_zy = cos_t*cos_m
|
||||
b = sin_t*sin_m
|
||||
new_zy = new_zy - b
|
||||
new_zy = new_zy*s
|
||||
if args.easy_margin:
|
||||
zy_keep = zy
|
||||
else:
|
||||
zy_keep = zy - s*mm
|
||||
new_zy = mx.sym.where(cond, new_zy, zy_keep)
|
||||
else:
|
||||
cond_v = cos_t - threshold
|
||||
cond = mx.symbol.Activation(data=cond_v, act_type='relu')
|
||||
#theta = mx.sym.arccos(costheta)
|
||||
#sintheta = mx.sym.sin(theta)
|
||||
body = cos_t*cos_t
|
||||
body = 1.0-body
|
||||
sin_t = mx.sym.sqrt(body)
|
||||
new_zy = cos_t*cos_m
|
||||
b = sin_t*sin_m
|
||||
new_zy = new_zy - b
|
||||
new_zy = new_zy*s
|
||||
if args.easy_margin:
|
||||
zy_keep = zy
|
||||
else:
|
||||
zy_keep = zy - s*mm
|
||||
new_zy = mx.sym.where(cond, new_zy, zy_keep)
|
||||
assert args.output_c2c
|
||||
#set c2c as cosm^2 in data.py
|
||||
cos_m = mx.sym.sqrt(c2c_label)
|
||||
sin_m = 1.0-c2c_label
|
||||
sin_m = mx.sym.sqrt(sin_m)
|
||||
body = cos_t*cos_t
|
||||
body = 1.0-body
|
||||
sin_t = mx.sym.sqrt(body)
|
||||
new_zy = cos_t*cos_m
|
||||
b = sin_t*sin_m
|
||||
new_zy = new_zy - b
|
||||
new_zy = new_zy*s
|
||||
|
||||
if args.margin_verbose>0:
|
||||
new_cos_t = new_zy/s
|
||||
margin_symbols.append(mx.symbol.mean(new_cos_t))
|
||||
|
||||
diff = new_zy - zy
|
||||
diff = mx.sym.expand_dims(diff, 1)
|
||||
gt_one_hot = mx.sym.one_hot(gt_label, depth = args.num_classes, on_value = 1.0, off_value = 0.0)
|
||||
@@ -638,6 +657,7 @@ def train_net(args):
|
||||
shuffle = True,
|
||||
rand_mirror = args.rand_mirror,
|
||||
mean = mean,
|
||||
c2c_threshold = args.c2c_threshold,
|
||||
ctx_num = args.ctx_num,
|
||||
images_per_identity = args.images_per_identity,
|
||||
data_extra = data_extra,
|
||||
@@ -658,6 +678,7 @@ def train_net(args):
|
||||
shuffle = True,
|
||||
rand_mirror = args.rand_mirror,
|
||||
mean = mean,
|
||||
c2c_threshold = args.c2c_threshold,
|
||||
ctx_num = args.ctx_num,
|
||||
images_per_identity = args.images_per_identity,
|
||||
data_extra = data_extra,
|
||||
|
||||
Reference in New Issue
Block a user