# pylint: skip-file
"""RecordIO-backed MXNet data iterators producing images with point labels.

``FaceSegIter0`` yields raw images with ``(num_classes, 2)`` coordinate
labels.  ``FaceSegIter`` crops/augments each image through ``img_helper`` and
yields one heatmap per label point, optionally paired with a second
("coherent") view and a flattened 2x3 rotation matrix.
"""
import mxnet as mx
import numpy as np
import sys
import os
import random
import math
import scipy.misc
import cv2
import logging
import sklearn
import datetime
import img_helper
from mxnet.io import DataIter
from mxnet import ndarray as nd
from mxnet import io
from mxnet import recordio
from PIL import Image


class FaceSegIter0(DataIter):
    """Sequential iterator over an indexed RecordIO file without augmentation.

    Each batch holds images of shape ``(3, 384, 384)`` and labels of shape
    ``(68, 2)`` taken from the record header.
    """

    def __init__(self, batch_size,
                 path_imgrec=None,
                 data_name="data",
                 label_name="softmax_label"):
        """Open the record file and describe the provided data/label shapes.

        Parameters
        ----------
        batch_size : int
            Number of samples per batch.
        path_imgrec : str
            Path to the record file (required).  The index file path is
            derived by replacing the last four characters with ``.idx``.
        data_name, label_name : str
            Names used in ``provide_data`` / ``provide_label``.
        """
        self.batch_size = batch_size
        self.data_name = data_name
        self.label_name = label_name
        assert path_imgrec
        logging.info('loading recordio %s...', path_imgrec)
        # Drops the last four characters (e.g. ".rec") to build the .idx path.
        path_imgidx = path_imgrec[0:-4]+".idx"
        self.imgrec = mx.recordio.MXIndexedRecordIO(path_imgidx, path_imgrec, 'r')  # pylint: disable=redefined-variable-type
        self.seq = list(self.imgrec.keys)
        self.cur = 0
        self.reset()
        self.num_classes = 68
        self.record_img_size = 384
        self.input_img_size = self.record_img_size
        self.data_shape = (3, self.input_img_size, self.input_img_size)
        self.label_shape = (self.num_classes, 2)
        self.provide_data = [(data_name, (batch_size,) + self.data_shape)]
        self.provide_label = [(label_name, (batch_size,) + self.label_shape)]

    def reset(self):
        """Resets the iterator to the beginning of the data."""
        print('reset')
        self.cur = 0

    def next_sample(self):
        """Read the next record and return ``(image, label)``.

        The image is decoded to a numpy array; the header label is reshaped
        to ``(num_classes, 2)``.  Raises ``StopIteration`` when the record
        sequence is exhausted.
        """
        if self.cur >= len(self.seq):
            raise StopIteration
        idx = self.seq[self.cur]
        self.cur += 1
        s = self.imgrec.read_idx(idx)
        header, img = recordio.unpack(s)
        img = mx.image.imdecode(img).asnumpy()
        hlabel = np.array(header.label).reshape((self.num_classes, 2))
        return img, hlabel

    def next(self):
        """Returns the next batch of data.

        A final partial batch is returned with ``pad`` set to the number of
        unfilled slots; ``StopIteration`` is raised only when no sample at
        all could be read.
        """
        batch_size = self.batch_size
        batch_data = nd.empty((batch_size,)+self.data_shape)
        batch_label = nd.empty((batch_size,)+self.label_shape)
        i = 0
        try:
            while i < batch_size:
                data, label = self.next_sample()
                # HWC -> CHW to match data_shape.
                data = nd.transpose(nd.array(data), axes=(2, 0, 1))
                label = nd.array(label)
                batch_data[i][:] = data
                batch_label[i][:] = label
                i += 1
        except StopIteration:
            if not i:
                raise
        return mx.io.DataBatch([batch_data], [batch_label], batch_size - i)


class FaceSegIter(DataIter):
    """Augmenting iterator that turns point labels into per-point heatmaps.

    Images are read from an indexed RecordIO file, cropped (and, when
    ``aug_level > 0``, randomly rotated/scaled/mirrored) via ``img_helper``,
    and each label point is rendered into its own heatmap channel.
    """

    def __init__(self, batch_size,
                 per_batch_size=0,
                 path_imgrec=None,
                 aug_level=0,
                 force_mirror=False,
                 use_coherent=0,
                 args=None,
                 data_name="data",
                 label_name="softmax_label"):
        """Open the record file and configure shapes and augmentation.

        Parameters
        ----------
        batch_size : int
            Number of samples per batch.
        per_batch_size : int
            Size of one sub-batch; only read in coherent mode, where the
            paired view is written ``per_batch_size // 2`` slots after the
            primary one.  Must be non-zero when ``use_coherent`` is set
            (``next`` takes a modulo by it).
        path_imgrec : str
            Path to the record file (required).
        aug_level : int
            ``0`` disables random augmentation and shuffling; ``> 0`` enables
            them and takes the heatmap size from ``args.output_label_size``.
        force_mirror : bool
            Always mirror the sample (only applies when ``use_coherent`` is 0).
        use_coherent : int
            ``0``: single view.  ``1``: second view is the mirrored first
            view.  ``2``: second view is a separate crop using ``rotate2``.
        args : object
            Must expose ``num_classes`` and ``input_img_size`` (and
            ``output_label_size`` when ``aug_level > 0``).
        data_name, label_name : str
            Names used in ``provide_data`` / ``provide_label``.
        """
        self.aug_level = aug_level
        self.force_mirror = force_mirror
        self.use_coherent = use_coherent
        self.batch_size = batch_size
        self.per_batch_size = per_batch_size
        self.data_name = data_name
        self.label_name = label_name
        assert path_imgrec
        logging.info('loading recordio %s...', path_imgrec)
        # Drops the last four characters (e.g. ".rec") to build the .idx path.
        path_imgidx = path_imgrec[0:-4]+".idx"
        self.imgrec = mx.recordio.MXIndexedRecordIO(path_imgidx, path_imgrec, 'r')  # pylint: disable=redefined-variable-type
        self.seq = list(self.imgrec.keys)
        print('train size', len(self.seq))
        self.cur = 0
        self.reset()
        self.num_classes = args.num_classes
        self.record_img_size = 384
        self.input_img_size = args.input_img_size
        self.data_shape = (3, self.input_img_size, self.input_img_size)
        self.label_classes = self.num_classes
        if aug_level > 0:
            self.output_label_size = args.output_label_size
        else:
            self.output_label_size = args.input_img_size
        # label_classes == num_classes, so both original branches agree here.
        self.label_shape = (self.label_classes, self.output_label_size, self.output_label_size)
        self.provide_data = [(data_name, (batch_size,) + self.data_shape)]
        self.provide_label = [(label_name, (batch_size,) + self.label_shape)]
        if self.use_coherent > 0:
            self.provide_label += [("coherent_label", (batch_size, 6))]
        self.img_num = 0
        self.invalid_num = 0
        self.mode = 1
        self.vis = 0
        self.stats = [0, 0]
        # 1-based pairs of label channels that trade places when the image
        # is mirrored horizontally.
        self.mirror_set = [
            (22, 23),
            (21, 24),
            (20, 25),
            (19, 26),
            (18, 27),
            (40, 43),
            (39, 44),
            (38, 45),
            (37, 46),
            (42, 47),
            (41, 48),
            (33, 35),
            (32, 36),
            (51, 53),
            (50, 54),
            (62, 64),
            (61, 65),
            (49, 55),
            (49, 55),
            (68, 66),
            (60, 56),
            (59, 57),
            (1, 17),
            (2, 16),
            (3, 15),
            (4, 14),
            (5, 13),
            (6, 12),
            (7, 11),
            (8, 10),
        ]

    def get_data_shape(self):
        """Return the per-sample data shape ``(3, H, W)``."""
        return self.data_shape

    def get_shape_dict(self):
        """Return a dict mapping every provided data/label name to its shape."""
        shapes = dict(self.provide_data)
        shapes.update(self.provide_label)
        return shapes

    def get_label_names(self):
        """Return the list of provided label names, in order."""
        return [name for (name, _) in self.provide_label]

    def reset(self):
        """Resets the iterator to the beginning of the data.

        The record order is reshuffled when ``aug_level > 0``.
        """
        print('reset')
        if self.aug_level > 0:
            random.shuffle(self.seq)
        self.cur = 0

    def next_sample(self):
        """Read, crop and validate the next record.

        Returns
        -------
        (img, hlabel, annot)
            ``img`` is the cropped image array, ``hlabel`` the
            ``(num_classes, 2)`` label shifted into crop coordinates, and
            ``annot`` a dict with ``'invalid'`` (any point outside the crop)
            and, in mode 1, ``'scale'``.

        Raises
        ------
        StopIteration
            When the record sequence is exhausted.
        """
        if self.cur >= len(self.seq):
            raise StopIteration
        idx = self.seq[self.cur]
        self.cur += 1
        s = self.imgrec.read_idx(idx)
        header, img = recordio.unpack(s)
        img = mx.image.imdecode(img).asnumpy()
        hlabel = np.array(header.label).reshape((self.num_classes, 2))
        annot = {}
        # Bounding box of all label points (only consumed by mode 0).
        ul = np.array((50000, 50000), dtype=np.int32)
        br = np.array((0, 0), dtype=np.int32)
        for i in range(hlabel.shape[0]):
            h = int(hlabel[i][0])
            w = int(hlabel[i][1])
            key = np.array((h, w))
            ul = np.minimum(key, ul)
            br = np.maximum(key, br)
        if self.mode == 0:
            # Crop around the label bounding box with fixed margins,
            # clamped to the image bounds.
            ul_margin = np.array((60, 30))
            br_margin = np.array((30, 30))
            crop_ul = ul-ul_margin
            crop_ul = np.maximum(crop_ul, np.array((0, 0), dtype=np.int32))
            crop_br = br+br_margin
            crop_br = np.minimum(crop_br, np.array((img.shape[0], img.shape[1]), dtype=np.int32))
        elif self.mode == 1:
            # Keep the full record image; do_aug crops with a fixed scale.
            crop_ul = (0, 0)
            crop_br = (self.record_img_size, self.record_img_size)
            annot['scale'] = 256
        else:
            mm = (48, 64)
            crop_ul = mm
            crop_br = (self.record_img_size-mm[0], self.record_img_size-mm[1])
        crop_ul = np.array(crop_ul)
        crop_br = np.array(crop_br)
        img = img[crop_ul[0]:crop_br[0], crop_ul[1]:crop_br[1], :]
        invalid = False
        for i in range(hlabel.shape[0]):
            # NOTE(review): this condition was garbled in the collapsed source
            # ("(hlabel[i]= crop_br).any()"); reconstructed as a bounds check
            # against both crop corners -- confirm against upstream.
            if (hlabel[i] < crop_ul).any() or (hlabel[i] >= crop_br).any():
                invalid = True
            hlabel[i] -= crop_ul
        if invalid:
            self.invalid_num += 1
        annot['invalid'] = invalid
        return img, hlabel, annot

    def _mirror(self, image, heatmaps):
        """Return horizontally mirrored copies of ``image`` and ``heatmaps``.

        ``image`` is flipped along axis 1 (HWC layout) and every heatmap
        along its last axis; heatmap channels listed in ``mirror_set`` are
        then swapped pairwise.  Inputs are left untouched.
        """
        flipped_image = image[:, ::-1, :].copy()
        flipped_maps = heatmaps[:, :, ::-1]
        mirrored = flipped_maps.copy()
        for first, second in self.mirror_set:
            a, b = first-1, second-1  # mirror_set is 1-based
            mirrored[b] = flipped_maps[a]
            mirrored[a] = flipped_maps[b]
        return flipped_image, mirrored

    def do_aug(self, data, label, annot):
        """Crop/augment one sample and render its label heatmaps.

        Parameters
        ----------
        data : numpy array
            Image in HWC layout.
        label : numpy array
            ``(num_points, 2)`` coordinates; each row is reversed before being
            passed to ``img_helper.transform``.
        annot : dict
            Optional ``'scale'`` and ``'center'`` overrides.

        Returns
        -------
        None
            If no attempt placed every point inside its heatmap
            (``img_helper.gaussian`` returned a falsy value) within the retry
            budget.
        (cropped, label_out)
            When ``use_coherent`` is 0.
        (cropped, label_out, cropped2, label2_out, M)
            When ``use_coherent`` is set; ``M`` is the flattened 2x3 rotation
            matrix for the angle between the two views.
        """
        if self.vis:
            self.img_num += 1
        rotate = 0
        # Initialised up front: the original only assigned rotate2 inside the
        # aug_level > 0 branch, so aug_level == 0 with use_coherent == 2 hit
        # an unbound local.
        rotate2 = 0
        if 'scale' in annot:
            scale = annot['scale']
        else:
            scale = max(data.shape[0], data.shape[1])
        if 'center' in annot:
            center = annot['center']
        else:
            center = np.array((data.shape[0]/2, data.shape[1]/2))
        max_retry = 3
        if self.aug_level == 0:
            max_retry = 6
        retry = 0
        found = False
        _scale = scale
        out_size = (self.input_img_size, self.input_img_size)
        map_size = (self.output_label_size, self.output_label_size)
        # Mode 1 uses crop2; every other mode uses crop.  Using the same
        # function for the second view keeps cropped2 defined in all modes.
        crop_fn = img_helper.crop2 if self.mode == 1 else img_helper.crop
        # NOTE(review): the loop header and the two statements below it were
        # garbled in the collapsed source ("while retry0:"); reconstructed
        # from the surrounding logic -- confirm against upstream.
        while retry < max_retry:
            retry += 1
            succ = True
            if self.aug_level > 0:
                rotate = np.random.randint(-40, 40)
                rotate2 = 0
                scale_config = 0.2
                # Random scale factor clamped to [1 - cfg, 1 + cfg].
                _scale = min(1+scale_config, max(1-scale_config, (np.random.randn() * scale_config) + 1))
                _scale *= scale
                _scale = int(_scale)
            cropped = crop_fn(data, center, _scale, out_size, rot=rotate)
            if self.use_coherent == 2:
                cropped2 = crop_fn(data, center, _scale, out_size, rot=rotate2)
            label_out = np.zeros(self.label_shape, dtype=np.float32)
            label2_out = np.zeros(self.label_shape, dtype=np.float32)
            G = 0
            # G == 0 means "one-hot output": draw with sigma 1, then keep
            # only the peak pixel further down.
            _g = G if G else 1
            for i in range(label.shape[0]):
                pt = label[i].copy()
                pt = pt[::-1]
                trans = img_helper.transform(pt.copy(), center, _scale, map_size, rot=rotate)
                if not img_helper.gaussian(label_out[i], trans, _g):
                    succ = False
                    break
                if self.use_coherent == 2:
                    trans2 = img_helper.transform(pt.copy(), center, _scale, map_size, rot=rotate2)
                    if not img_helper.gaussian(label2_out[i], trans2, _g):
                        succ = False
                        break
            if not succ:
                if self.aug_level == 0:
                    # Without random scaling, widen the crop before retrying.
                    _scale += 20
                continue

            if self.use_coherent == 1:
                # Second view is the mirrored first view.
                cropped2, label2_out = self._mirror(cropped, label_out)
            elif self.use_coherent == 2:
                pass
            elif ((self.aug_level > 0 and np.random.rand() < 0.5) or self.force_mirror):
                # Flip augmentation.
                cropped, label_out = self._mirror(cropped, label_out)

            if G == 0:
                # Collapse each heatmap to a single 1.0 at its maximum.
                for k in range(label.shape[0]):
                    ind = np.unravel_index(np.argmax(label_out[k], axis=None), label_out[k].shape)
                    label_out[k, :, :] = 0.0
                    label_out[k, ind[0], ind[1]] = 1.0
                    if self.use_coherent:
                        ind = np.unravel_index(np.argmax(label2_out[k], axis=None), label2_out[k].shape)
                        label2_out[k, :, :] = 0.0
                        label2_out[k, ind[0], ind[1]] = 1.0
            found = True
            break
        if not found:
            return None

        if self.vis > 0 and self.img_num <= self.vis:
            # Debug dump of the cropped image with each heatmap peak marked.
            print('crop', data.shape, center, _scale, rotate, cropped.shape)
            filename = './vis/cropped_%d.jpg' % (self.img_num)
            print('save', filename)
            draw = cropped.copy()
            alabel = label_out.copy()
            for i in range(label.shape[0]):
                a = cv2.resize(alabel[i], (self.input_img_size, self.input_img_size))
                ind = np.unravel_index(np.argmax(a, axis=None), a.shape)
                cv2.circle(draw, (ind[1], ind[0]), 1, (0, 0, 255), 2)
            # NOTE(review): scipy.misc.imsave is gone from recent SciPy
            # releases; this debug path needs another writer there -- TODO.
            scipy.misc.imsave(filename, draw)

        if not self.use_coherent:
            return cropped, label_out

        rotate2 = 0
        r = math.pi*(rotate - rotate2)/180
        cos_r = math.cos(r)
        sin_r = math.sin(r)
        M = np.array([[cos_r, -1*sin_r, 0.0], [sin_r, cos_r, 0.0]])
        return cropped, label_out, cropped2, label2_out, M.flatten()

    def next(self):
        """Returns the next batch of data.

        Samples for which ``do_aug`` returns ``None`` or that fail
        ``check_valid_image`` are skipped.  In coherent mode each sample
        fills two slots (``i`` and ``i + per_batch_size // 2``) and the batch
        carries ``[batch_label, batch_coherent_label]``.  A final partial
        batch is returned with ``pad = batch_size - i``; ``StopIteration`` is
        raised only when nothing was filled.
        """
        batch_size = self.batch_size
        batch_data = nd.empty((batch_size,)+self.data_shape)
        batch_label = nd.empty((batch_size,)+self.label_shape)
        if self.use_coherent:
            batch_coherent_label = nd.empty((batch_size, 6))
        i = 0
        try:
            while i < batch_size:
                data, label, annot = self.next_sample()
                R = self.do_aug(data, label, annot)
                if R is None:
                    continue
                if not self.use_coherent:
                    data, label = R
                    # HWC -> CHW to match data_shape.
                    data = nd.transpose(nd.array(data), axes=(2, 0, 1))
                    label = nd.array(label)
                    try:
                        self.check_valid_image(data)
                    except RuntimeError as e:
                        logging.debug('Invalid image, skipping: %s', str(e))
                        continue
                    batch_data[i][:] = data
                    batch_label[i][:] = label
                    i += 1
                else:
                    data, label, data2, label2, M = R
                    data = nd.transpose(nd.array(data), axes=(2, 0, 1))
                    label = nd.array(label)
                    data2 = nd.transpose(nd.array(data2), axes=(2, 0, 1))
                    label2 = nd.array(label2)
                    M = nd.array(M)
                    try:
                        self.check_valid_image(data)
                    except RuntimeError as e:
                        logging.debug('Invalid image, skipping: %s', str(e))
                        continue
                    batch_data[i][:] = data
                    batch_label[i][:] = label
                    batch_coherent_label[i][:] = M
                    # The paired view lives half a sub-batch further on.
                    j = i+self.per_batch_size//2
                    batch_data[j][:] = data2
                    batch_label[j][:] = label2
                    batch_coherent_label[j][:] = M
                    i += 1
                    # Once the second half of a sub-batch is full, jump past it.
                    if j % self.per_batch_size == self.per_batch_size-1:
                        i = j+1
        except StopIteration:
            if not i:
                raise

        if not self.use_coherent:
            return mx.io.DataBatch([batch_data], [batch_label], batch_size - i)
        return mx.io.DataBatch([batch_data], [batch_label, batch_coherent_label], batch_size - i)

    def check_data_shape(self, data_shape):
        """Checks if the input data shape is valid.

        Raises ``ValueError`` unless ``data_shape`` has three entries and the
        first one is 3.
        """
        if len(data_shape) != 3:
            raise ValueError('data_shape should have length 3, with dimensions CxHxW')
        if data_shape[0] != 3:
            raise ValueError('This iterator expects inputs to have 3 channels.')

    def check_valid_image(self, data):
        """Checks if the input data is valid.

        Raises ``RuntimeError`` when the first element of ``data`` has an
        empty shape.
        """
        if len(data[0].shape) == 0:
            raise RuntimeError('Data shape is wrong')

    def imdecode(self, s):
        """Decodes a string or byte string to an NDArray.

        See mx.image.imdecode for more details.
        """
        # The original called a bare `imdecode`, which is not defined at
        # module level; route to the MXNet decoder used by next_sample.
        return mx.image.imdecode(s)

    def read_image(self, fname):
        """Reads an input image `fname` and returns its raw file bytes.

        Example usage:
        ----------
        >>> dataIter.read_image('Face.jpg') # returns raw bytes.
        """
        # NOTE(review): self.path_root is not assigned anywhere in this
        # class; callers must set it before using this method -- confirm.
        with open(os.path.join(self.path_root, fname), 'rb') as fin:
            img = fin.read()
        return img