from __future__ import absolute_import from __future__ import division from __future__ import print_function import os import random import logging import sys import numbers import math import sklearn import datetime import numpy as np import cv2 import mxnet as mx from mxnet import ndarray as nd from mxnet import io from mxnet import recordio sys.path.append(os.path.join(os.path.dirname(__file__), 'common')) import face_preprocess import multiprocessing logger = logging.getLogger() class FaceImageIter(io.DataIter): def __init__(self, batch_size, data_shape, path_imgrec = None, shuffle=False, aug_list=None, mean = None, rand_mirror = False, cutoff = 0, data_name='data', label_name='softmax_label', **kwargs): super(FaceImageIter, self).__init__() assert path_imgrec if path_imgrec: logging.info('loading recordio %s...', path_imgrec) path_imgidx = path_imgrec[0:-4]+".idx" self.imgrec = recordio.MXIndexedRecordIO(path_imgidx, path_imgrec, 'r') # pylint: disable=redefined-variable-type s = self.imgrec.read_idx(0) header, _ = recordio.unpack(s) self.imgidx = list(self.imgrec.keys) if shuffle: self.seq = self.imgidx self.oseq = self.imgidx print(len(self.seq)) else: self.seq = None self.mean = mean self.nd_mean = None if self.mean: self.mean = np.array(self.mean, dtype=np.float32).reshape(1,1,3) self.nd_mean = mx.nd.array(self.mean).reshape((1,1,3)) self.check_data_shape(data_shape) self.provide_data = [(data_name, (batch_size,) + data_shape)] self.batch_size = batch_size self.data_shape = data_shape self.shuffle = shuffle self.image_size = '%d,%d'%(data_shape[1],data_shape[2]) self.rand_mirror = rand_mirror print('rand_mirror', rand_mirror) self.cutoff = cutoff self.provide_label = [(label_name, (batch_size,102))] #print(self.provide_label[0][1]) self.cur = 0 self.nbatch = 0 self.is_init = False def reset(self): """Resets the iterator to the beginning of the data.""" print('call reset()') self.cur = 0 if self.shuffle: random.shuffle(self.seq) if self.seq is None and self.imgrec is not None: self.imgrec.reset() def num_samples(self): return len(self.seq) def next_sample(self): """Helper function for reading in next sample.""" #set total batch size, for example, 1800, and maximum size for each people, for example 45 if self.seq is not None: while True: if self.cur >= len(self.seq): raise StopIteration idx = self.seq[self.cur] self.cur += 1 if self.imgrec is not None: s = self.imgrec.read_idx(idx) header, img = recordio.unpack(s) label = header.label return label, img, None, None else: label, fname, bbox, landmark = self.imglist[idx] return label, self.read_image(fname), bbox, landmark else: s = self.imgrec.read() if s is None: raise StopIteration header, img = recordio.unpack(s) return header.label, img, None, None def brightness_aug(self, src, x): alpha = 1.0 + random.uniform(-x, x) src *= alpha return src def contrast_aug(self, src, x): alpha = 1.0 + random.uniform(-x, x) coef = np.array([[[0.299, 0.587, 0.114]]]) gray = src * coef gray = (3.0 * (1.0 - alpha) / gray.size) * np.sum(gray) src *= alpha src += gray return src def saturation_aug(self, src, x): alpha = 1.0 + random.uniform(-x, x) coef = np.array([[[0.299, 0.587, 0.114]]]) gray = src * coef gray = np.sum(gray, axis=2, keepdims=True) gray *= (1.0 - alpha) src *= alpha src += gray return src def color_aug(self, img, x): augs = [self.brightness_aug, self.contrast_aug, self.saturation_aug] random.shuffle(augs) for aug in augs: #print(img.shape) img = aug(img, x) #print(img.shape) return img def mirror_aug(self, img): _rd = random.randint(0,1) if _rd==1: for c in xrange(img.shape[2]): img[:,:,c] = np.fliplr(img[:,:,c]) return img def next(self): if not self.is_init: self.reset() self.is_init = True """Returns the next batch of data.""" #print('in next', self.cur, self.labelcur) self.nbatch+=1 batch_size = self.batch_size c, h, w = self.data_shape batch_data = nd.empty((batch_size, c, h, w)) if self.provide_label is not None: batch_label = nd.empty(self.provide_label[0][1]) i = 0 try: while i < batch_size: label, s, bbox, landmark = self.next_sample() #if label[1]>=0.0 or label[2]>=0.0: # print(label[0:10]) _data = self.imdecode(s) if self.rand_mirror: _rd = random.randint(0,1) if _rd==1: _data = mx.ndarray.flip(data=_data, axis=1) if self.nd_mean is not None: _data = _data.astype('float32') _data -= self.nd_mean _data *= 0.0078125 if self.cutoff>0: centerh = random.randint(0, _data.shape[0]-1) centerw = random.randint(0, _data.shape[1]-1) half = self.cutoff//2 starth = max(0, centerh-half) endh = min(_data.shape[0], centerh+half) startw = max(0, centerw-half) endw = min(_data.shape[1], centerw+half) _data = _data.astype('float32') #print(starth, endh, startw, endw, _data.shape) _data[starth:endh, startw:endw, :] = 127.5 data = [_data] try: self.check_valid_image(data) except RuntimeError as e: logging.debug('Invalid image, skipping: %s', str(e)) continue #print('aa',data[0].shape) #data = self.augmentation_transform(data) #print('bb',data[0].shape) for datum in data: assert i < batch_size, 'Batch size must be multiples of augmenter output length' #print(datum.shape) batch_data[i][:] = self.postprocess_data(datum) batch_label[i][:] = label i += 1 except StopIteration: if i>> dataIter.read_image('Face.jpg') # returns decoded raw bytes. """ with open(os.path.join(self.path_root, fname), 'rb') as fin: img = fin.read() return img def augmentation_transform(self, data): """Transforms input data with specified augmentation.""" for aug in self.auglist: data = [ret for src in data for ret in aug(src)] return data def postprocess_data(self, datum): """Final postprocessing step before image is loaded into the batch.""" return nd.transpose(datum, axes=(2, 0, 1)) class FaceImageIterList(io.DataIter): def __init__(self, iter_list): assert len(iter_list)>0 self.provide_data = iter_list[0].provide_data self.provide_label = iter_list[0].provide_label self.iter_list = iter_list self.cur_iter = None def reset(self): self.cur_iter.reset() def next(self): self.cur_iter = random.choice(self.iter_list) while True: try: ret = self.cur_iter.next() except StopIteration: self.cur_iter.reset() continue return ret