mirror of
https://github.com/deepinsight/insightface.git
synced 2026-05-16 05:27:56 +00:00
add coco loss
This commit is contained in:
@@ -272,21 +272,18 @@ def get_symbol(args, arg_params, aux_params):
|
||||
triplet_loss = mx.symbol.mean(triplet_loss)
|
||||
#triplet_loss = mx.symbol.sum(triplet_loss)/(args.per_batch_size//3)
|
||||
extra_loss = mx.symbol.MakeLoss(triplet_loss)
|
||||
-elif args.loss_type==13: #triplet loss II
|
||||
-nembedding = mx.symbol.L2Normalization(embedding, mode='instance', name='fc1n')
|
||||
-anchor = mx.symbol.slice_axis(nembedding, axis=0, begin=0, end=args.per_batch_size//3)
|
||||
-positive = mx.symbol.slice_axis(nembedding, axis=0, begin=args.per_batch_size//3, end=2*args.per_batch_size//3)
|
||||
-negative = mx.symbol.slice_axis(nembedding, axis=0, begin=2*args.per_batch_size//3, end=args.per_batch_size)
|
||||
-ap = anchor - positive
|
||||
-an = anchor - negative
|
||||
-ap = ap*ap
|
||||
-an = an*an
|
||||
-ap = mx.symbol.sum(ap, axis=1, keepdims=1) #(T,1)
|
||||
-an = mx.symbol.sum(an, axis=1, keepdims=1) #(T,1)
|
||||
-triplet_loss = mx.symbol.Activation(data = (ap-an+args.triplet_alpha), act_type='relu')
|
||||
-triplet_loss = mx.symbol.mean(triplet_loss)
|
||||
-#triplet_loss = mx.symbol.sum(triplet_loss)/(args.per_batch_size//3)
|
||||
-extra_loss = mx.symbol.MakeLoss(triplet_loss)
|
||||
+elif args.loss_type==9: #coco loss
|
||||
+centroids = []
|
||||
+for i in xrange(args.per_identities):
|
||||
+xs = mx.symbol.slice_axis(embedding, axis=0, begin=i*args.images_per_identity, end=(i+1)*args.images_per_identity)
|
||||
+mean = mx.symbol.mean(xs, axis=0, keepdims=True)
|
||||
+mean = mx.symbol.L2Normalization(mean, mode='instance')
|
||||
+centroids.append(mean)
|
||||
+centroids = mx.symbol.concat(*centroids, dim=0)
|
||||
+nembedding = mx.symbol.L2Normalization(embedding, mode='instance', name='fc1n')*args.coco_scale
|
||||
+fc7 = mx.symbol.dot(nembedding, centroids, transpose_b = True) #(batchsize, per_identities)
|
||||
+#extra_loss = mx.symbol.softmax_cross_entropy(fc7, gt_label, name='softmax_ce')/args.per_batch_size
|
||||
+#extra_loss = mx.symbol.BlockGrad(extra_loss)
|
||||
else:
|
||||
#embedding = mx.symbol.L2Normalization(embedding, mode='instance', name='fc1n')*float(args.loss_type)
|
||||
embedding = embedding * 5
|
||||
@@ -375,6 +372,7 @@ def train_net(args):
|
||||
|
||||
assert(args.num_classes>0)
|
||||
print('num_classes', args.num_classes)
|
||||
+args.coco_scale = 0.5*math.log(float(args.num_classes-1))+3
|
||||
|
||||
#path_imglist = "/raid5data/dplearn/MS-Celeb-Aligned/lst2"
|
||||
path_imgrec = os.path.join(data_dir, "train.rec")
|
||||
@@ -394,13 +392,13 @@ def train_net(args):
|
||||
args.beta_freeze = 5000
|
||||
args.gamma = 0.06
|
||||
|
||||
-if args.loss_type<10:
|
||||
+if args.loss_type<9:
|
||||
assert args.images_per_identity==0
|
||||
else:
|
||||
if args.images_per_identity==0:
|
||||
if args.loss_type==11:
|
||||
args.images_per_identity = 2
|
||||
-elif args.loss_type==10:
|
||||
+elif args.loss_type==10 or args.loss_type==9:
|
||||
args.images_per_identity = 16
|
||||
elif args.loss_type==12:
|
||||
args.images_per_identity = 5
|
||||
@@ -445,6 +443,7 @@ def train_net(args):
|
||||
data_extra = None
|
||||
hard_mining = False
|
||||
triplet_params = None
|
||||
+coco_mode = False
|
||||
if args.loss_type==10:
|
||||
hard_mining = True
|
||||
_shape = (args.batch_size, args.per_batch_size)
|
||||
@@ -467,6 +466,8 @@ def train_net(args):
|
||||
c+=args.per_batch_size
|
||||
elif args.loss_type==12:
|
||||
triplet_params = [args.triplet_bag_size, args.triplet_alpha]
|
||||
+elif args.loss_type==9:
|
||||
+coco_mode = True
|
||||
|
||||
label_name = 'softmax_label'
|
||||
if data_extra is None:
|
||||
@@ -497,6 +498,7 @@ def train_net(args):
|
||||
data_extra = data_extra,
|
||||
hard_mining = hard_mining,
|
||||
triplet_params = triplet_params,
|
||||
+coco_mode = coco_mode,
|
||||
mx_model = model,
|
||||
label_name = label_name,
|
||||
)
|
||||
@@ -516,6 +518,7 @@ def train_net(args):
|
||||
data_extra = data_extra,
|
||||
hard_mining = hard_mining,
|
||||
triplet_params = triplet_params,
|
||||
+coco_mode = coco_mode,
|
||||
mx_model = model,
|
||||
label_name = label_name,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user