From 2199039ba822a48932a8b79ffbd9cd9c47c0e784 Mon Sep 17 00:00:00 2001 From: Jia Guo Date: Wed, 29 Nov 2017 17:04:05 +0800 Subject: [PATCH] refine --- src/common/face_image.py | 2 +- src/common/face_preprocess.py | 31 +- src/data.py | 577 +++++++++++++++++++++++++++------- src/{ => eval}/lfw.py | 89 +++++- src/marginalnet.py | 29 +- src/train_softmax.py | 157 ++------- 6 files changed, 611 insertions(+), 274 deletions(-) rename src/{ => eval}/lfw.py (71%) diff --git a/src/common/face_image.py b/src/common/face_image.py index 035d13a..e169d50 100644 --- a/src/common/face_image.py +++ b/src/common/face_image.py @@ -114,7 +114,7 @@ def get_dataset_common(input_dir): return ret def get_dataset(name, input_dir): - if name=='webface' or name=='lfw': + if name=='webface' or name=='lfw' or name=='vgg': return get_dataset_common(input_dir) if name=='celeb': return get_dataset_celeb(input_dir) diff --git a/src/common/face_preprocess.py b/src/common/face_preprocess.py index 5271877..1c9525b 100644 --- a/src/common/face_preprocess.py +++ b/src/common/face_preprocess.py @@ -49,13 +49,25 @@ def preprocess(img, bbox=None, landmark=None, **kwargs): if isinstance(img, str): img = read_image(img, **kwargs) M = None + image_size = [] + str_image_size = kwargs.get('image_size', '') + if len(str_image_size)>0: + image_size = [int(x) for x in str_image_size.split(',')] + if len(image_size)==1: + image_size = [image_size[0], image_size[0]] + assert len(image_size)==2 + assert image_size[0]==112 + assert image_size[0]==112 or image_size[1]==96 if landmark is not None: + assert len(image_size)==2 src = np.array([ [30.2946, 51.6963], [65.5318, 51.5014], [48.0252, 71.7366], [33.5493, 92.3655], [62.7299, 92.2041] ], dtype=np.float32 ) + if image_size[1]==112: + src[:,0] += 8.0 dst = landmark.astype(np.float32) tform = trans.SimilarityTransform() @@ -78,16 +90,12 @@ def preprocess(img, bbox=None, landmark=None, **kwargs): bb[1] = np.maximum(det[1]-margin/2, 0) bb[2] = np.minimum(det[2]+margin/2, img.shape[1]) bb[3] = np.minimum(det[3]+margin/2, img.shape[0]) - cropped = img[bb[1]:bb[3],bb[0]:bb[2],:] - str_image_size = kwargs.get('image_size', '') - if len(str_image_size)>0: - image_size = [int(x) for x in str_image_size.split(',')] - if len(image_size)==1: - image_size = [image_size[0], image_size[0]] - assert len(image_size)==2 - scaled = cv2.resize(cropped, (image_size[1], image_size[0])) - return scaled + ret = img[bb[1]:bb[3],bb[0]:bb[2],:] + if len(image_size)>0: + ret = cv2.resize(ret, (image_size[1], image_size[0])) + return ret else: #do align using landmark + assert len(image_size)==2 #src = src[0:3,:] #dst = dst[0:3,:] @@ -96,11 +104,8 @@ def preprocess(img, bbox=None, landmark=None, **kwargs): #print(src.shape, dst.shape) #print(src) #print(dst) - _shape = [int(x) for x in kwargs.get('image_size').split(',')] - #print(_shape) - #M = cv2.getAffineTransform(src,dst) #print(M) - warped = cv2.warpAffine(img,M,(_shape[1],_shape[0]), borderValue = 0.0) + warped = cv2.warpAffine(img,M,(image_size[1],image_size[0]), borderValue = 0.0) #tform3 = trans.ProjectiveTransform() #tform3.estimate(src, dst) diff --git a/src/data.py b/src/data.py index 0e64edd..469e3c4 100644 --- a/src/data.py +++ b/src/data.py @@ -27,118 +27,140 @@ import face_preprocess logger = logging.getLogger() -#modification on ImageIter class FaceImageIter(io.DataIter): - def __init__(self, batch_size, data_shape, images_per_person, margin = 44, path_imglist=None, path_root=None, - shuffle=False, aug_list=None, + def __init__(self, batch_size, data_shape, + path_imgrec = None, + shuffle=False, aug_list=None, mean = None, + rand_mirror = False, data_name='data', label_name='softmax_label', **kwargs): super(FaceImageIter, self).__init__() - assert path_imglist - self.label2key = {} - self.labelkeys = [] - print('loading image list...') - with open(path_imglist) as fin: - imglist = {} - imgkeys = [] - key = 0 - for line in iter(fin.readline, ''): - line = line.strip().split('\t') - if len(line)<17: - continue #skip no detected face image - label = nd.array([float(line[2])]) - ilabel = int(line[2]) - if ilabel not in self.label2key: - self.label2key[ilabel] = [key] - self.labelkeys.append(ilabel) - #self.labelcur[ilabel] = 0 - else: - self.label2key[ilabel].append(key) - #label = nd.array([float(i) for i in line[1:-1]]) - bbox = np.array([int(i) for i in line[3:7]]) - #key = int(line[0]) - imglist[key] = (label, line[1], bbox) - imgkeys.append(key) - key+=1 - self.imglist = imglist - print('image list size', len(self.imglist)) + assert path_imgrec + if path_imgrec: + logging.info('loading recordio %s...', + path_imgrec) + path_imgidx = path_imgrec[0:-4]+".idx" + self.imgrec = recordio.MXIndexedRecordIO(path_imgidx, path_imgrec, 'r') # pylint: disable=redefined-variable-type + s = self.imgrec.read_idx(0) + header, _ = recordio.unpack(s) + if header.flag>0: + print('header0 label', header.label) + #assert(header.flag==1) + self.imgidx = range(1, int(header.label[0])) + self.idx2range = {} + self.seq_identity = range(int(header.label[0]), int(header.label[1])) + for identity in self.seq_identity: + s = self.imgrec.read_idx(identity) + header, _ = recordio.unpack(s) + #print('flag', header.flag) + #print(header.label) + #assert(header.flag==2) + self.idx2range[identity] = (int(header.label[0]), int(header.label[1])) + print('idx2range', len(self.idx2range)) + else: + self.imgidx = list(self.imgrec.keys) + if shuffle: + self.seq = self.imgidx + else: + self.seq = None - self.path_root = path_root - self.margin = margin + self.mean = mean + self.nd_mean = None + if self.mean: + self.mean = np.array(self.mean, dtype=np.float32).reshape(1,1,3) + self.nd_mean = mx.nd.array(self.mean).reshape((1,1,3)) self.check_data_shape(data_shape) self.provide_data = [(data_name, (batch_size,) + data_shape)] self.provide_label = [(label_name, (batch_size,))] self.batch_size = batch_size self.data_shape = data_shape - self.images_per_person = images_per_person - #self.label_width = label_width - self.imgkeys = imgkeys self.shuffle = shuffle + self.image_size = '%d,%d'%(data_shape[1],data_shape[2]) + self.rand_mirror = rand_mirror + #self.cast_aug = mx.image.CastAug() + #self.color_aug = mx.image.ColorJitterAug(0.4, 0.4, 0.4) - if aug_list is None: - self.auglist = mx.image.CreateAugmenter(data_shape, **kwargs) - else: - self.auglist = aug_list - print('aug size:', len(self.auglist)) - #for aug in self.auglist: - # print(aug.__name__) self.cur = 0 - self.labelcur = 0 self.reset() def reset(self): """Resets the iterator to the beginning of the data.""" + print('call reset()') if self.shuffle: - #random.shuffle(self.imgkeys) - random.shuffle(self.labelkeys) + random.shuffle(self.seq) + if self.imgrec is not None: + self.imgrec.reset() self.cur = 0 - self.labelcur = 0 - #for k in self.label2key: - # random.shuffle(self.label2key[k]) - def _next_sample(self): - """Helper function for reading in next sample.""" - #set total batch size, for example, 1800, and maximum size for each people, for example 45 - while True: - if self.cur >= len(self.labelkeys): - raise StopIteration - ilabel = self.labelkeys[self.cur] - if self.labelcur>=min(len(self.label2key[ilabel]), self.images_per_person): - self.labelcur=0 - self.cur+=1 - else: - idx = self.label2key[ilabel][self.labelcur] - self.labelcur += 1 - label, fname, bbox = self.imglist[idx] - return label, self.read_image(fname), bbox + def num_samples(self): + return len(self.seq) def next_sample(self): """Helper function for reading in next sample.""" #set total batch size, for example, 1800, and maximum size for each people, for example 45 - while True: - if self.cur >= len(self.labelkeys): - raise StopIteration - ilabel = self.labelkeys[self.cur] - if self.labelcur>=min(len(self.label2key[ilabel]), self.images_per_person): - self.labelcur=0 - self.cur+=1 + if self.seq is not None: + if self.cur >= len(self.seq): + raise StopIteration + idx = self.seq[self.cur] + self.cur += 1 + if self.imgrec is not None: + s = self.imgrec.read_idx(idx) + header, img = recordio.unpack(s) + return header.label, img, None, None else: - #print('in next_sample', self.cur, self.labelcur) - if self.labelcur==0 and self.shuffle: - #print('shuffling') - random.shuffle(self.label2key[ilabel]) - idx = self.label2key[ilabel][self.labelcur] - self.labelcur += 1 - label, fname, bbox = self.imglist[idx] - return label, self.read_image(fname), bbox + label, fname, bbox, landmark = self.imglist[idx] + return label, self.read_image(fname), bbox, landmark + else: + s = self.imgrec.read() + if s is None: + raise StopIteration + header, img = recordio.unpack(s) + return header.label, img, None, None + + def brightness_aug(self, src, x): + alpha = 1.0 + random.uniform(-x, x) + src *= alpha + return src + + def contrast_aug(self, src, x): + alpha = 1.0 + random.uniform(-x, x) + coef = np.array([[[0.299, 0.587, 0.114]]]) + gray = src * coef + gray = (3.0 * (1.0 - alpha) / gray.size) * np.sum(gray) + src *= alpha + src += gray + return src + + def saturation_aug(self, src, x): + alpha = 1.0 + random.uniform(-x, x) + coef = np.array([[[0.299, 0.587, 0.114]]]) + gray = src * coef + gray = np.sum(gray, axis=2, keepdims=True) + gray *= (1.0 - alpha) + src *= alpha + src += gray + return src + + def color_aug(self, img, x): + augs = [self.brightness_aug, self.contrast_aug, self.saturation_aug] + random.shuffle(augs) + for aug in augs: + #print(img.shape) + img = aug(img, x) + #print(img.shape) + return img + + def mirror_aug(self, img): + _rd = random.randint(0,1) + if _rd==1: + for c in xrange(img.shape[2]): + img[:,:,c] = np.fliplr(img[:,:,c]) + return img + def next(self): """Returns the next batch of data.""" - if self.shuffle: - random.shuffle(self.labelkeys) - self.cur = 0 - self.labelcur = 0 #print('in next', self.cur, self.labelcur) batch_size = self.batch_size c, h, w = self.data_shape @@ -147,23 +169,48 @@ class FaceImageIter(io.DataIter): i = 0 try: while i < batch_size: - label, s, bbox = self.next_sample() - data = [self.imdecode(s, bbox)] + label, s, bbox, landmark = self.next_sample() + _data = self.imdecode(s) + if self.rand_mirror: + _rd = random.randint(0,1) + if _rd==1: + _data = mx.ndarray.flip(data=_data, axis=1) + if self.nd_mean is not None: + _data = _data.astype('float32') + _data -= self.nd_mean + _data *= 0.0078125 + #_npdata = _data.asnumpy() + #if landmark is not None: + # _npdata = face_preprocess.preprocess(_npdata, bbox = bbox, landmark=landmark, image_size=self.image_size) + #if self.rand_mirror: + # _npdata = self.mirror_aug(_npdata) + #if self.mean is not None: + # _npdata = _npdata.astype(np.float32) + # _npdata -= self.mean + # _npdata *= 0.0078125 + #nimg = np.zeros(_npdata.shape, dtype=np.float32) + #nimg[self.patch[1]:self.patch[3],self.patch[0]:self.patch[2],:] = _npdata[self.patch[1]:self.patch[3], self.patch[0]:self.patch[2], :] + #_data = mx.nd.array(nimg) + data = [_data] try: self.check_valid_image(data) except RuntimeError as e: logging.debug('Invalid image, skipping: %s', str(e)) continue - data = self.augmentation_transform(data) + #print('aa',data[0].shape) + #data = self.augmentation_transform(data) + #print('bb',data[0].shape) for datum in data: assert i < batch_size, 'Batch size must be multiples of augmenter output length' + #print(datum.shape) batch_data[i][:] = self.postprocess_data(datum) batch_label[i][:] = label i += 1 except StopIteration: - if not i: + if i= len(self.seq_sim_identity): + raise StopIteration + identity = self.seq_sim_identity[self.cur[0]] + if self.cur[1]>=self.images_per_identity: + self.cur[0]+=1 + self.cur[1]=0 + s = self.imgrec.read_idx(identity) + header, _ = recordio.unpack(s) + self.idx_range = range(int(header.label[0]), int(header.label[1])) + continue + if self.shuffle and self.cur[1]==0: + random.shuffle(self.idx_range) + idx = self.idx_range[self.cur[1]] + self.cur[1] += 1 + s = self.imgrec.read_idx(idx) + header, img = recordio.unpack(s) + return header.label, img, None, None + + + def brightness_aug(self, src, x): + alpha = 1.0 + random.uniform(-x, x) + src *= alpha + return src + + def contrast_aug(self, src, x): + alpha = 1.0 + random.uniform(-x, x) + coef = np.array([[[0.299, 0.587, 0.114]]]) + gray = src * coef + gray = (3.0 * (1.0 - alpha) / gray.size) * np.sum(gray) + src *= alpha + src += gray + return src + + def saturation_aug(self, src, x): + alpha = 1.0 + random.uniform(-x, x) + coef = np.array([[[0.299, 0.587, 0.114]]]) + gray = src * coef + gray = np.sum(gray, axis=2, keepdims=True) + gray *= (1.0 - alpha) + src *= alpha + src += gray + return src + + def color_aug(self, img, x): + augs = [self.brightness_aug, self.contrast_aug, self.saturation_aug] + random.shuffle(augs) + for aug in augs: + #print(img.shape) + img = aug(img, x) + #print(img.shape) + return img + + def mirror_aug(self, img): + _rd = random.randint(0,1) + if _rd==1: + for c in xrange(img.shape[2]): + img[:,:,c] = np.fliplr(img[:,:,c]) + return img + + + def next(self): + if not self.inited: + self.reset() + self.inited = True + """Returns the next batch of data.""" + #print('in next', self.cur, self.labelcur) + batch_size = self.batch_size + c, h, w = self.data_shape + batch_data = nd.empty((batch_size, c, h, w)) + batch_label = nd.empty(self.provide_label[0][1]) + i = 0 + try: + while i < batch_size: + label, s, bbox, landmark = self.next_sample() + _data = self.imdecode(s) + if self.rand_mirror: + _rd = random.randint(0,1) + if _rd==1: + _data = mx.ndarray.flip(data=_data, axis=1) + if self.nd_mean is not None: + _data = _data.astype('float32') + _data -= self.nd_mean + _data *= 0.0078125 + data = [_data] + try: + self.check_valid_image(data) + except RuntimeError as e: + logging.debug('Invalid image, skipping: %s', str(e)) + continue + #print('aa',data[0].shape) + #data = self.augmentation_transform(data) + #print('bb',data[0].shape) + for datum in data: + assert i < batch_size, 'Batch size must be multiples of augmenter output length' + #print(datum.shape) + batch_data[i][:] = self.postprocess_data(datum) + batch_label[i][:] = label + i += 1 + except StopIteration: + if i>> dataIter.read_image('Face.jpg') # returns decoded raw bytes. + """ + with open(os.path.join(self.path_root, fname), 'rb') as fin: + img = fin.read() + return img + + def augmentation_transform(self, data): + """Transforms input data with specified augmentation.""" + for aug in self.auglist: + data = [ret for src in data for ret in aug(src)] + return data + + def postprocess_data(self, datum): + """Final postprocessing step before image is loaded into the batch.""" + return nd.transpose(datum, axes=(2, 0, 1)) class FaceImageIter4(io.DataIter): @@ -705,15 +1055,6 @@ class FaceImageIter4(io.DataIter): per_batch_size = int(batch_size/ctx_num) self.provide_label = [(label_name, (batch_size,))] self.batch_size = batch_size - self.ctx_num = ctx_num - self.images_per_identity = images_per_identity - self.identities = int(per_batch_size/self.images_per_identity) - self.min_per_identity = 10 - if self.images_per_identity<=10: - self.min_per_identity = self.images_per_identity - self.min_per_identity = 1 - assert self.min_per_identity<=self.images_per_identity - print(self.images_per_identity, self.identities, self.min_per_identity) self.data_shape = data_shape self.shuffle = shuffle self.image_size = '%d,%d'%(data_shape[1],data_shape[2]) @@ -748,6 +1089,15 @@ class FaceImageIter4(io.DataIter): print(self.extra) else: self.provide_data = [(data_name, (batch_size,) + data_shape)] + self.ctx_num = ctx_num + self.images_per_identity = images_per_identity + self.identities = int(per_batch_size/self.images_per_identity) + self.min_per_identity = 10 + if self.images_per_identity<=10: + self.min_per_identity = self.images_per_identity + self.min_per_identity = 1 + assert self.min_per_identity<=self.images_per_identity + print(self.images_per_identity, self.identities, self.min_per_identity) if aug_list is None: self.auglist = mx.image.CreateAugmenter(data_shape, **kwargs) @@ -893,25 +1243,14 @@ class FaceImageIter4(io.DataIter): while i < batch_size: label, s, bbox, landmark = self.next_sample() _data = self.imdecode(s) - #_data = self.augmentation_transform([_data])[0] - _npdata = _data.asnumpy() - if landmark is not None: - _npdata = face_preprocess.preprocess(_npdata, bbox = bbox, landmark=landmark, image_size=self.image_size) if self.rand_mirror: _rd = random.randint(0,1) if _rd==1: - for c in xrange(_npdata.shape[2]): - _npdata[:,:,c] = np.fliplr(_npdata[:,:,c]) - if self.mean is not None: - _npdata = _npdata.astype(np.float32) - _npdata -= self.mean - _npdata *= 0.0078125 - nimg = np.zeros(_npdata.shape, dtype=np.float32) - nimg[self.patch[1]:self.patch[3],self.patch[0]:self.patch[2],:] = _npdata[self.patch[1]:self.patch[3], self.patch[0]:self.patch[2], :] - #print(_npdata.shape) - #print(_npdata) - _data = mx.nd.array(nimg) - #print(_data.shape) + _data = mx.ndarray.flip(data=_data, axis=1) + if self.nd_mean is not None: + _data = _data.astype('float32') + _data -= self.nd_mean + _data *= 0.0078125 data = [_data] try: self.check_valid_image(data) diff --git a/src/lfw.py b/src/eval/lfw.py similarity index 71% rename from src/lfw.py rename to src/eval/lfw.py index 0ea902b..7400769 100644 --- a/src/lfw.py +++ b/src/eval/lfw.py @@ -34,7 +34,10 @@ from sklearn.model_selection import KFold from scipy import interpolate import sklearn from sklearn.decomposition import PCA -#import facenet +import mxnet as mx +from mxnet import ndarray as nd + + def calculate_roc(thresholds, embeddings1, embeddings2, actual_issame, nrof_folds=10, pca = 0): assert(embeddings1.shape[0] == embeddings2.shape[0]) @@ -186,5 +189,89 @@ def read_pairs(pairs_filename): pairs.append(pair) return np.array(pairs) +def load_dataset(lfw_dir, image_size): + lfw_pairs = read_pairs(os.path.join(lfw_dir, 'pairs.txt')) + lfw_paths, issame_list = get_paths(lfw_dir, lfw_pairs, 'jpg') + lfw_data_list = [] + for flip in [0,1]: + lfw_data = nd.empty((len(lfw_paths), 3, image_size[0], image_size[1])) + lfw_data_list.append(lfw_data) + i = 0 + for path in lfw_paths: + with open(path, 'rb') as fin: + _bin = fin.read() + img = mx.image.imdecode(_bin) + img = nd.transpose(img, axes=(2, 0, 1)) + for flip in [0,1]: + if flip==1: + img = mx.ndarray.flip(data=img, axis=2) + lfw_data_list[flip][i][:] = img + i+=1 + if i%1000==0: + print('loading lfw', i) + print(lfw_data_list[0].shape) + print(lfw_data_list[1].shape) + return (lfw_data_list, issame_list) +def test(lfw_set, mx_model, batch_size): + print('testing lfw..') + lfw_data_list = lfw_set[0] + issame_list = lfw_set[1] + model = mx_model + embeddings_list = [] + for i in xrange( len(lfw_data_list) ): + lfw_data = lfw_data_list[i] + embeddings = None + ba = 0 + while ba