mirror of
https://github.com/deepinsight/insightface.git
synced 2026-05-18 14:55:42 +00:00
add intra/inter loss
This commit is contained in:
@@ -32,6 +32,7 @@ class FaceImageIter(io.DataIter):
|
||||
path_imgrec = None,
|
||||
shuffle=False, aug_list=None, mean = None,
|
||||
rand_mirror = False, cutoff = 0, color_jittering = 0,
|
||||
images_filter = 0,
|
||||
data_name='data', label_name='softmax_label', **kwargs):
|
||||
super(FaceImageIter, self).__init__()
|
||||
assert path_imgrec
|
||||
@@ -46,15 +47,19 @@ class FaceImageIter(io.DataIter):
|
||||
print('header0 label', header.label)
|
||||
self.header0 = (int(header.label[0]), int(header.label[1]))
|
||||
#assert(header.flag==1)
|
||||
self.imgidx = range(1, int(header.label[0]))
|
||||
#self.imgidx = range(1, int(header.label[0]))
|
||||
self.imgidx = []
|
||||
self.id2range = {}
|
||||
self.seq_identity = range(int(header.label[0]), int(header.label[1]))
|
||||
for identity in self.seq_identity:
|
||||
s = self.imgrec.read_idx(identity)
|
||||
header, _ = recordio.unpack(s)
|
||||
a,b = int(header.label[0]), int(header.label[1])
|
||||
self.id2range[identity] = (a,b)
|
||||
count = b-a
|
||||
if count<images_filter:
|
||||
continue
|
||||
self.id2range[identity] = (a,b)
|
||||
self.imgidx += range(a, b)
|
||||
print('id2range', len(self.id2range))
|
||||
else:
|
||||
self.imgidx = list(self.imgrec.keys)
|
||||
|
||||
@@ -125,6 +125,7 @@ def parse_args():
|
||||
parser.add_argument('--rand-mirror', type=int, default=1, help='if do random mirror in training')
|
||||
parser.add_argument('--cutoff', type=int, default=0, help='cut off aug')
|
||||
parser.add_argument('--color', type=int, default=0, help='color jittering aug')
|
||||
parser.add_argument('--images-filter', type=int, default=0, help='minimum images per identity filter')
|
||||
parser.add_argument('--target', type=str, default='lfw,cfp_fp,agedb_30', help='verification targets')
|
||||
parser.add_argument('--ce-loss', default=False, action='store_true', help='if output ce loss')
|
||||
args = parser.parse_args()
|
||||
@@ -272,9 +273,71 @@ def get_symbol(args, arg_params, aux_params):
|
||||
gt_one_hot = mx.sym.one_hot(gt_label, depth = args.num_classes, on_value = 1.0, off_value = 0.0)
|
||||
body = mx.sym.broadcast_mul(gt_one_hot, diff)
|
||||
fc7 = fc7+body
|
||||
elif args.loss_type==6:
|
||||
s = args.margin_s
|
||||
m = args.margin_m
|
||||
assert s>0.0
|
||||
_weight = mx.symbol.L2Normalization(_weight, mode='instance')
|
||||
nembedding = mx.symbol.L2Normalization(embedding, mode='instance', name='fc1n')*s
|
||||
fc7 = mx.sym.FullyConnected(data=nembedding, weight = _weight, no_bias = True, num_hidden=args.num_classes, name='fc7')
|
||||
zy = mx.sym.pick(fc7, gt_label, axis=1)
|
||||
cos_t = zy/s
|
||||
t = mx.sym.arccos(cos_t)
|
||||
intra_loss = t/np.pi
|
||||
intra_loss = mx.sym.mean(intra_loss)
|
||||
#intra_loss = mx.sym.exp(cos_t*-1.0)
|
||||
intra_loss = mx.sym.MakeLoss(intra_loss, name='intra_loss', grad_scale = 1)
|
||||
if m>0.0:
|
||||
t = t+m
|
||||
body = mx.sym.cos(t)
|
||||
new_zy = body*s
|
||||
diff = new_zy - zy
|
||||
diff = mx.sym.expand_dims(diff, 1)
|
||||
gt_one_hot = mx.sym.one_hot(gt_label, depth = args.num_classes, on_value = 1.0, off_value = 0.0)
|
||||
body = mx.sym.broadcast_mul(gt_one_hot, diff)
|
||||
fc7 = fc7+body
|
||||
elif args.loss_type==7:
|
||||
s = args.margin_s
|
||||
m = args.margin_m
|
||||
assert s>0.0
|
||||
_weight = mx.symbol.L2Normalization(_weight, mode='instance')
|
||||
nembedding = mx.symbol.L2Normalization(embedding, mode='instance', name='fc1n')*s
|
||||
fc7 = mx.sym.FullyConnected(data=nembedding, weight = _weight, no_bias = True, num_hidden=args.num_classes, name='fc7')
|
||||
zy = mx.sym.pick(fc7, gt_label, axis=1)
|
||||
cos_t = zy/s
|
||||
t = mx.sym.arccos(cos_t)
|
||||
intra_loss = t/np.pi
|
||||
intra_loss = mx.sym.MakeLoss(intra_loss, name='intra_loss', grad_scale = 0.1)
|
||||
#idx = mx.sym.arange(0, args.num_classes)
|
||||
#idx = mx.sym.random.shuffle(idx)
|
||||
#idx = mx.sym.slice(data=idx, begin=0, end=100)
|
||||
#counter_weight = mx.sym.take(_weight, indices=idx, axis=1)
|
||||
#counter_weight = mx.sym.pick(_weight, gt_label, axis=1)
|
||||
counter_weight = mx.sym.take(_weight, gt_label, axis=1)
|
||||
counter_cos = mx.sym.dot(counter_weight, _weight, transpose_a=True)
|
||||
counter_angle = mx.sym.arccos(counter_cos)
|
||||
inter_loss = counter_angle/np.pi #[0,1]
|
||||
inter_loss = inter_loss*-1.0 # [-1,0]
|
||||
inter_loss = inter_loss+1.0 # [0,1]
|
||||
inter_loss = mx.sym.mean(inter_loss)
|
||||
inter_loss = mx.sym.MakeLoss(inter_loss, name='inter_loss', grad_scale = 0.1)
|
||||
if m>0.0:
|
||||
t = t+m
|
||||
body = mx.sym.cos(t)
|
||||
new_zy = body*s
|
||||
diff = new_zy - zy
|
||||
diff = mx.sym.expand_dims(diff, 1)
|
||||
gt_one_hot = mx.sym.one_hot(gt_label, depth = args.num_classes, on_value = 1.0, off_value = 0.0)
|
||||
body = mx.sym.broadcast_mul(gt_one_hot, diff)
|
||||
fc7 = fc7+body
|
||||
out_list = [mx.symbol.BlockGrad(embedding)]
|
||||
softmax = mx.symbol.SoftmaxOutput(data=fc7, label = gt_label, name='softmax', normalization='valid')
|
||||
out_list.append(softmax)
|
||||
if args.loss_type==6:
|
||||
out_list.append(intra_loss)
|
||||
if args.loss_type==7:
|
||||
out_list.append(inter_loss)
|
||||
#out_list.append(intra_loss)
|
||||
if args.ce_loss:
|
||||
#ce_loss = mx.symbol.softmax_cross_entropy(data=fc7, label = gt_label, name='ce_loss')/args.per_batch_size
|
||||
body = mx.symbol.SoftmaxActivation(data=fc7)
|
||||
@@ -346,6 +409,9 @@ def train_net(args):
|
||||
arg_params = None
|
||||
aux_params = None
|
||||
sym, arg_params, aux_params = get_symbol(args, arg_params, aux_params)
|
||||
if args.network[0]=='s':
|
||||
data_shape_dict = {'data' : (args.per_batch_size,)+data_shape}
|
||||
spherenet.init_weights(sym, data_shape_dict, args.num_layers)
|
||||
else:
|
||||
vec = args.pretrained.split(',')
|
||||
print('loading', vec)
|
||||
@@ -369,6 +435,7 @@ def train_net(args):
|
||||
mean = mean,
|
||||
cutoff = args.cutoff,
|
||||
color_jittering = args.color,
|
||||
images_filter = args.images_filter,
|
||||
)
|
||||
|
||||
metric1 = AccMetric()
|
||||
@@ -446,20 +513,33 @@ def train_net(args):
|
||||
save_step[0]+=1
|
||||
msave = save_step[0]
|
||||
do_save = False
|
||||
is_highest = False
|
||||
if len(acc_list)>0:
|
||||
lfw_score = acc_list[0]
|
||||
if lfw_score>highest_acc[0]:
|
||||
highest_acc[0] = lfw_score
|
||||
if lfw_score>=0.998:
|
||||
do_save = True
|
||||
#lfw_score = acc_list[0]
|
||||
#if lfw_score>highest_acc[0]:
|
||||
# highest_acc[0] = lfw_score
|
||||
# if lfw_score>=0.998:
|
||||
# do_save = True
|
||||
score = sum(acc_list)
|
||||
if acc_list[-1]>=highest_acc[-1]:
|
||||
if acc_list[-1]>highest_acc[-1]:
|
||||
is_highest = True
|
||||
else:
|
||||
if score>=highest_acc[0]:
|
||||
is_highest = True
|
||||
highest_acc[0] = score
|
||||
highest_acc[-1] = acc_list[-1]
|
||||
if lfw_score>=0.99:
|
||||
do_save = True
|
||||
#if lfw_score>=0.99:
|
||||
# do_save = True
|
||||
if is_highest:
|
||||
do_save = True
|
||||
if args.ckpt==0:
|
||||
do_save = False
|
||||
elif args.ckpt>1:
|
||||
elif args.ckpt==2:
|
||||
do_save = True
|
||||
elif args.ckpt==3:
|
||||
msave = 1
|
||||
|
||||
if do_save:
|
||||
print('saving', msave)
|
||||
arg, aux = model.get_params()
|
||||
|
||||
Reference in New Issue
Block a user