From f0cb7f6b52853bc555eff2a3f8d306dbd6a685b4 Mon Sep 17 00:00:00 2001 From: Jia Guo Date: Tue, 30 Jan 2018 20:13:45 +0800 Subject: [PATCH] add c2c label --- src/data.py | 50 ++++++++++++++++++++------- src/train_softmax.py | 81 ++++++++++++++++++++++++++++---------------- 2 files changed, 88 insertions(+), 43 deletions(-) diff --git a/src/data.py b/src/data.py index 4a215a0..d7258bc 100644 --- a/src/data.py +++ b/src/data.py @@ -57,6 +57,7 @@ class FaceImageIter(io.DataIter): path_imgrec = None, shuffle=False, aug_list=None, mean = None, rand_mirror = False, + c2c_threshold = 0.0, output_c2c = 0, ctx_num = 0, images_per_identity = 0, data_extra = None, hard_mining = False, triplet_params = None, coco_mode = False, mx_model = None, @@ -110,6 +111,8 @@ class FaceImageIter(io.DataIter): #self.cast_aug = mx.image.CastAug() #self.color_aug = mx.image.ColorJitterAug(0.4, 0.4, 0.4) self.ctx_num = ctx_num + self.c2c_threshold = c2c_threshold + self.output_c2c = output_c2c self.per_batch_size = int(self.batch_size/self.ctx_num) self.images_per_identity = images_per_identity if self.images_per_identity>0: @@ -131,7 +134,10 @@ class FaceImageIter(io.DataIter): self.triplet_mode = False self.coco_mode = coco_mode if len(label_name)>0: - self.provide_label = [(label_name, (batch_size,))] + if output_c2c: + self.provide_label = [(label_name, (batch_size,2))] + else: + self.provide_label = [(label_name, (batch_size,))] else: self.provide_label = [] if self.coco_mode: @@ -575,17 +581,24 @@ class FaceImageIter(io.DataIter): """Helper function for reading in next sample.""" #set total batch size, for example, 1800, and maximum size for each people, for example 45 if self.seq is not None: - if self.cur >= len(self.seq): - raise StopIteration - idx = self.seq[self.cur] - self.cur += 1 - if self.imgrec is not None: - s = self.imgrec.read_idx(idx) - header, img = recordio.unpack(s) - return header.label, img, None, None - else: - label, fname, bbox, landmark = self.imglist[idx] - return label, self.read_image(fname), bbox, landmark + while True: + if self.cur >= len(self.seq): + raise StopIteration + idx = self.seq[self.cur] + self.cur += 1 + if self.imgrec is not None: + s = self.imgrec.read_idx(idx) + header, img = recordio.unpack(s) + label = header.label + if not isinstance(header.label, numbers.Number): + label = header.label[0] + c = header.label[1] + if c0: + v = min(0.5, max(0.25,math.log(v+1)*4-1.85)) + v = math.cos(v) + v = v*v + print('c2c', i,v) + + batch_label[i][ll] = v else: batch_label[i][:] = (i%self.per_batch_size)//self.images_per_identity i += 1 diff --git a/src/train_softmax.py b/src/train_softmax.py index 94c3de4..0024852 100644 --- a/src/train_softmax.py +++ b/src/train_softmax.py @@ -137,6 +137,10 @@ def parse_args(): help='') parser.add_argument('--margin-verbose', type=int, default=0, help='') + parser.add_argument('--c2c-threshold', type=float, default=0.0, + help='') + parser.add_argument('--output-c2c', type=int, default=0, + help='') parser.add_argument('--margin', type=int, default=4, help='') parser.add_argument('--beta', type=float, default=1000., @@ -214,7 +218,12 @@ def get_symbol(args, arg_params, aux_params): embedding = fresnet.get_symbol(args.emb_size, args.num_layers, version_se=args.version_se, version_input=args.version_input, version_output=args.version_output, version_unit=args.version_unit) - gt_label = mx.symbol.Variable('softmax_label') + all_label = mx.symbol.Variable('softmax_label') + if not args.output_c2c: + gt_label = all_label + else: + gt_label = mx.symbol.slice_axis(all_label, axis=1, begin=0, end=1) + c2c_label = mx.symbol.slice_axis(all_label, axis=1, begin=1, end=2) assert args.loss_type>=0 extra_loss = None if args.loss_type==0: #softmax @@ -302,46 +311,56 @@ def get_symbol(args, arg_params, aux_params): s = args.margin_s m = args.margin_m assert s>0.0 - assert m>0.0 + assert m>=0.0 assert m<(math.pi/2) - cos_m = math.cos(m) - sin_m = math.sin(m) - mm = math.sin(math.pi-m)*m _weight = mx.symbol.Variable("fc7_weight", shape=(args.num_classes, args.emb_size), lr_mult=1.0) _weight = mx.symbol.L2Normalization(_weight, mode='instance') - #threshold = 0.0 - threshold = math.cos(math.pi-m) nembedding = mx.symbol.L2Normalization(embedding, mode='instance', name='fc1n')*s fc7 = mx.sym.FullyConnected(data=nembedding, weight = _weight, no_bias = True, num_hidden=args.num_classes, name='fc7') zy = mx.sym.pick(fc7, gt_label, axis=1) cos_t = zy/s - if args.margin_verbose>0: - margin_symbols.append(mx.symbol.mean(cos_t)) - if args.easy_margin: - cond = mx.symbol.Activation(data=cos_t, act_type='relu') - #cond_v = cos_t - 0.4 - #cond = mx.symbol.Activation(data=cond_v, act_type='relu') + if m>0.0: + cos_m = math.cos(m) + sin_m = math.sin(m) + mm = math.sin(math.pi-m)*m + #threshold = 0.0 + threshold = math.cos(math.pi-m) + if args.margin_verbose>0: + margin_symbols.append(mx.symbol.mean(cos_t)) + if args.easy_margin: + cond = mx.symbol.Activation(data=cos_t, act_type='relu') + else: + cond_v = cos_t - threshold + cond = mx.symbol.Activation(data=cond_v, act_type='relu') + body = cos_t*cos_t + body = 1.0-body + sin_t = mx.sym.sqrt(body) + new_zy = cos_t*cos_m + b = sin_t*sin_m + new_zy = new_zy - b + new_zy = new_zy*s + if args.easy_margin: + zy_keep = zy + else: + zy_keep = zy - s*mm + new_zy = mx.sym.where(cond, new_zy, zy_keep) else: - cond_v = cos_t - threshold - cond = mx.symbol.Activation(data=cond_v, act_type='relu') - #theta = mx.sym.arccos(costheta) - #sintheta = mx.sym.sin(theta) - body = cos_t*cos_t - body = 1.0-body - sin_t = mx.sym.sqrt(body) - new_zy = cos_t*cos_m - b = sin_t*sin_m - new_zy = new_zy - b - new_zy = new_zy*s - if args.easy_margin: - zy_keep = zy - else: - zy_keep = zy - s*mm - new_zy = mx.sym.where(cond, new_zy, zy_keep) + assert args.output_c2c + #set c2c as cosm^2 in data.py + cos_m = mx.sym.sqrt(c2c_label) + sin_m = 1.0-c2c_label + sin_m = mx.sym.sqrt(sin_m) + body = cos_t*cos_t + body = 1.0-body + sin_t = mx.sym.sqrt(body) + new_zy = cos_t*cos_m + b = sin_t*sin_m + new_zy = new_zy - b + new_zy = new_zy*s + if args.margin_verbose>0: new_cos_t = new_zy/s margin_symbols.append(mx.symbol.mean(new_cos_t)) - diff = new_zy - zy diff = mx.sym.expand_dims(diff, 1) gt_one_hot = mx.sym.one_hot(gt_label, depth = args.num_classes, on_value = 1.0, off_value = 0.0) @@ -638,6 +657,7 @@ def train_net(args): shuffle = True, rand_mirror = args.rand_mirror, mean = mean, + c2c_threshold = args.c2c_threshold, ctx_num = args.ctx_num, images_per_identity = args.images_per_identity, data_extra = data_extra, @@ -658,6 +678,7 @@ def train_net(args): shuffle = True, rand_mirror = args.rand_mirror, mean = mean, + c2c_threshold = args.c2c_threshold, ctx_num = args.ctx_num, images_per_identity = args.images_per_identity, data_extra = data_extra,