"""Time the forward pass of a saved heatmap model on one test image.

The script loads a checkpoint, binds its ``heatmap_output`` layer for
inference, feeds a batch made of copies of ``./test.png`` and prints the mean
forward time per image, excluding warm-up iterations.
"""
import argparse
import datetime
import sys

import cv2
import mxnet as mx
import numpy as np

parser = argparse.ArgumentParser(description='face model test')
# general
parser.add_argument('--image-size', default='128,128', help='')
parser.add_argument('--model', default='./models/test,15', help='path to load model.')
parser.add_argument('--gpu', default=0, type=int, help='gpu id')
parser.add_argument('--batch-size', default=10, type=int, help='batch size')
parser.add_argument('--iterations', default=10, type=int, help='iterations')
args = parser.parse_args()

_vec = args.image_size.split(',')
assert len(_vec) == 2
image_size = (int(_vec[0]), int(_vec[1]))
_vec = args.model.split(',')
assert len(_vec) == 2
prefix = _vec[0]
epoch = int(_vec[1])
print('loading', prefix, epoch)
if args.gpu >= 0:
    ctx = mx.gpu(args.gpu)
else:
    ctx = mx.cpu()
sym, arg_params, aux_params = mx.model.load_checkpoint(prefix, epoch)
all_layers = sym.get_internals()
sym = all_layers['heatmap_output']
model = mx.mod.Module(symbol=sym, context=ctx, label_names=None)
model.bind(for_training=False,
           data_shapes=[('data', (args.batch_size, 3, image_size[0], image_size[1]))])
model.set_params(arg_params, aux_params)
img_path = './test.png'

img = cv2.imread(img_path)
if img is None:
    # cv2.imread signals failure by returning None instead of raising.
    sys.exit('cannot read image %s' % img_path)

# cv2.resize takes (width, height); the resulting array is (height, width, 3).
rimg = cv2.resize(img, (image_size[1], image_size[0]))
img = cv2.cvtColor(rimg, cv2.COLOR_BGR2RGB)
img = np.transpose(img, (2, 0, 1))  # channel-first, RGB
# Same (H, W) order as the shape the module was bound with above.
input_blob = np.zeros((args.batch_size, 3, image_size[0], image_size[1]), dtype=np.uint8)
input_blob[...] = img  # every batch entry is the same image
data = mx.nd.array(input_blob)
print(data.shape)
label = mx.nd.zeros((args.batch_size, 84, 64, 64))
db = mx.io.DataBatch(data=(data,), label=(label,))
stat = []
warmup = 2
for i in range(args.iterations + warmup):
    time_now = datetime.datetime.now()
    model.forward(db, is_train=False)
    # asnumpy() is inside the timed region, so the copy back is measured too.
    output = model.get_outputs()[-1].asnumpy()
    time_now2 = datetime.datetime.now()
    diff = time_now2 - time_now
    stat.append(diff.total_seconds())
stat = stat[warmup:]
print(np.mean(stat) / args.batch_size)
class FaceSegIter0(DataIter):
    """Sequential batch iterator over an indexed RecordIO landmark file.

    Each record's header label is reshaped to ``(68, 2)`` and its decoded
    image is copied, channel-first, into a ``(batch, 3, 384, 384)`` array.
    No cropping, augmentation or shuffling is applied.
    """

    def __init__(self, batch_size,
                 path_imgrec=None,
                 data_name="data",
                 label_name="softmax_label"):
        """Open ``path_imgrec`` and the ``.idx`` file that shares its stem."""
        self.batch_size = batch_size
        self.data_name = data_name
        self.label_name = label_name
        assert path_imgrec
        logging.info('loading recordio %s...', path_imgrec)
        # Derive the index path from the stem rather than slicing off four
        # characters, so the extension length does not matter.
        path_imgidx = os.path.splitext(path_imgrec)[0] + ".idx"
        self.imgrec = mx.recordio.MXIndexedRecordIO(path_imgidx, path_imgrec, 'r')  # pylint: disable=redefined-variable-type
        self.seq = list(self.imgrec.keys)
        self.cur = 0
        self.reset()
        self.num_classes = 68
        self.record_img_size = 384
        self.input_img_size = self.record_img_size
        self.data_shape = (3, self.input_img_size, self.input_img_size)
        self.label_shape = (self.num_classes, 2)
        self.provide_data = [(data_name, (batch_size,) + self.data_shape)]
        self.provide_label = [(label_name, (batch_size,) + self.label_shape)]

    def reset(self):
        """Resets the iterator to the beginning of the data."""
        print('reset')
        self.cur = 0

    def next_sample(self):
        """Read the next record and return ``(image, landmarks)``.

        Raises StopIteration once every key in ``self.seq`` has been read.
        """
        if self.cur >= len(self.seq):
            raise StopIteration
        idx = self.seq[self.cur]
        self.cur += 1
        s = self.imgrec.read_idx(idx)
        header, img = recordio.unpack(s)
        img = mx.image.imdecode(img).asnumpy()
        hlabel = np.array(header.label).reshape((self.num_classes, 2))
        return img, hlabel

    def next(self):
        """Returns the next batch of data.

        A short final batch is returned with ``pad`` set to the number of
        unfilled rows; StopIteration is raised only when no sample was read.
        """
        batch_size = self.batch_size
        # Zero-filled so the padded rows of a short final batch are defined.
        batch_data = nd.zeros((batch_size,) + self.data_shape)
        batch_label = nd.zeros((batch_size,) + self.label_shape)
        i = 0
        try:
            while i < batch_size:
                data, label = self.next_sample()
                data = nd.transpose(nd.array(data), axes=(2, 0, 1))
                label = nd.array(label)
                batch_data[i][:] = data
                batch_label[i][:] = label
                i += 1
        except StopIteration:
            if not i:
                raise StopIteration

        return mx.io.DataBatch([batch_data], [batch_label], batch_size - i)
class FaceSegIter(DataIter):
    """Batch iterator that turns RecordIO landmark records into heatmap labels.

    Each sample is decoded and cropped (``next_sample``), optionally rotated,
    rescaled and mirrored (``do_aug``), and every landmark is rendered into
    its own single-peak heatmap channel.  When ``use_coherent`` is non-zero
    each sample also yields a second view and a flattened 2x3 matrix that is
    exposed as the ``coherent_label`` output.
    """

    def __init__(self, batch_size,
                 per_batch_size=0,
                 path_imgrec=None,
                 aug_level=0,
                 force_mirror=False,
                 use_coherent=0,
                 args=None,
                 data_name="data",
                 label_name="softmax_label"):
        """Open the record file and derive data/label shapes from ``args``.

        ``args`` must provide ``num_classes`` and ``input_img_size``, plus
        ``output_label_size`` when ``aug_level`` is positive.
        """
        self.aug_level = aug_level
        self.force_mirror = force_mirror
        self.use_coherent = use_coherent
        self.batch_size = batch_size
        self.per_batch_size = per_batch_size
        self.data_name = data_name
        self.label_name = label_name
        assert path_imgrec
        logging.info('loading recordio %s...', path_imgrec)
        path_imgidx = path_imgrec[0:-4] + ".idx"
        self.imgrec = mx.recordio.MXIndexedRecordIO(path_imgidx, path_imgrec, 'r')  # pylint: disable=redefined-variable-type
        self.seq = list(self.imgrec.keys)
        print('train size', len(self.seq))
        self.cur = 0
        self.reset()
        self.num_classes = args.num_classes
        self.record_img_size = 384
        self.input_img_size = args.input_img_size
        self.data_shape = (3, self.input_img_size, self.input_img_size)
        self.label_classes = self.num_classes
        if aug_level > 0:
            self.output_label_size = args.output_label_size
        else:
            self.output_label_size = args.input_img_size
        self.label_shape = (self.num_classes, self.output_label_size, self.output_label_size)
        self.provide_data = [(data_name, (batch_size,) + self.data_shape)]
        self.provide_label = [(label_name, (batch_size,) + self.label_shape)]
        if self.use_coherent > 0:
            self.provide_label += [("coherent_label", (batch_size, 6))]
        self.img_num = 0
        self.invalid_num = 0
        self.mode = 1
        self.vis = 0
        self.stats = [0, 0]
        # 1-based landmark index pairs whose channels swap when mirroring.
        self.mirror_set = [
            (22, 23), (21, 24), (20, 25), (19, 26), (18, 27),
            (40, 43), (39, 44), (38, 45), (37, 46), (42, 47), (41, 48),
            (33, 35), (32, 36),
            (51, 53), (50, 54), (62, 64), (61, 65), (49, 55),
            (68, 66), (60, 56), (59, 57),
            (1, 17), (2, 16), (3, 15), (4, 14), (5, 13), (6, 12), (7, 11), (8, 10),
        ]

    def get_data_shape(self):
        """Return the per-sample data shape ``(3, size, size)``."""
        return self.data_shape

    def get_shape_dict(self):
        """Return a name -> batch shape dict covering data and label outputs."""
        shapes = dict(self.provide_data)
        shapes.update(self.provide_label)
        return shapes

    def get_label_names(self):
        """Return the label output names in ``provide_label`` order."""
        return [k for (k, _) in self.provide_label]

    def reset(self):
        """Resets the iterator to the beginning of the data."""
        print('reset')
        if self.aug_level > 0:
            random.shuffle(self.seq)
        self.cur = 0

    def next_sample(self):
        """Read, decode and crop the next record.

        Returns ``(img, hlabel, annot)``: the cropped image, the landmarks
        shifted into crop coordinates, and a dict holding ``invalid`` (True
        if any landmark fell outside the crop) and, in mode 1, ``scale``.
        Raises StopIteration when the sequence is exhausted.
        """
        if self.cur >= len(self.seq):
            raise StopIteration
        idx = self.seq[self.cur]
        self.cur += 1
        s = self.imgrec.read_idx(idx)
        header, img = recordio.unpack(s)
        img = mx.image.imdecode(img).asnumpy()
        hlabel = np.array(header.label).reshape((self.num_classes, 2))
        annot = {}

        if self.mode == 0:
            # Landmark bounding box grown by fixed margins, clipped to the image.
            ul = np.array((50000, 50000), dtype=np.int32)
            br = np.array((0, 0), dtype=np.int32)
            for i in range(hlabel.shape[0]):
                key = np.array((int(hlabel[i][0]), int(hlabel[i][1])))
                ul = np.minimum(key, ul)
                br = np.maximum(key, br)
            crop_ul = np.maximum(ul - np.array((60, 30)), np.array((0, 0), dtype=np.int32))
            crop_br = np.minimum(br + np.array((30, 30)),
                                 np.array((img.shape[0], img.shape[1]), dtype=np.int32))
        elif self.mode == 1:
            crop_ul = (0, 0)
            crop_br = (self.record_img_size, self.record_img_size)
            annot['scale'] = 256
        else:
            mm = (48, 64)
            crop_ul = mm
            crop_br = (self.record_img_size - mm[0], self.record_img_size - mm[1])
        crop_ul = np.array(crop_ul)
        crop_br = np.array(crop_br)
        img = img[crop_ul[0]:crop_br[0], crop_ul[1]:crop_br[1], :]

        # NOTE(review): the text between '<' and '>' of this condition was
        # lost in the extracted source; restored from context - confirm
        # against the repository.
        invalid = bool(((hlabel < crop_ul) | (hlabel >= crop_br)).any())
        hlabel -= crop_ul
        if invalid:
            self.invalid_num += 1
        annot['invalid'] = invalid
        return img, hlabel, annot

    def _mirrored(self, image, heatmaps):
        """Return left-right flipped copies of ``image`` and ``heatmaps``.

        ``image`` is indexed ``[row, col, channel]`` and ``heatmaps``
        ``[landmark, row, col]``; channels listed in ``mirror_set`` swap.
        """
        flipped_image = np.ascontiguousarray(image[:, ::-1, :])
        flipped = heatmaps[:, :, ::-1]
        swapped = np.copy(flipped)
        for first, second in self.mirror_set:
            swapped[second - 1] = flipped[first - 1]
            swapped[first - 1] = flipped[second - 1]
        return flipped_image, swapped

    @staticmethod
    def _keep_peak_only(heatmaps, count):
        """Reduce the first ``count`` channels, in place, to a single 1.0 at their argmax."""
        for k in range(count):
            peak = np.unravel_index(np.argmax(heatmaps[k], axis=None), heatmaps[k].shape)
            heatmaps[k, :, :] = 0.0
            heatmaps[k, peak[0], peak[1]] = 1.0

    def do_aug(self, data, label, annot):
        """Crop/augment one sample and render its landmark heatmaps.

        Returns None when no attempt rendered every landmark, otherwise
        ``(cropped, label_out)``, or, when ``use_coherent`` is set,
        ``(cropped, label_out, cropped2, label2_out, M)`` where ``M`` is a
        flattened 2x3 rotation matrix.
        """
        if self.vis:
            self.img_num += 1

        rotate = 0
        # Initialised here so the use_coherent == 2 path also works when
        # aug_level is 0 (previously unbound in that case).
        rotate2 = 0
        if 'scale' in annot:
            scale = annot['scale']
        else:
            scale = max(data.shape[0], data.shape[1])
        if 'center' in annot:
            center = annot['center']
        else:
            center = np.array((data.shape[0] / 2, data.shape[1] / 2))
        max_retry = 6 if self.aug_level == 0 else 3
        in_size = (self.input_img_size, self.input_img_size)
        out_size = (self.output_label_size, self.output_label_size)
        crop_fn = img_helper.crop2 if self.mode == 1 else img_helper.crop
        sigma = 1

        retry = 0
        found = False
        _scale = scale
        # NOTE(review): the loop header and its first statements were lost in
        # the extracted source ("while retry0:"); the bound, the counter
        # increment and the succ reset are restored from how they are used
        # below - confirm against the repository.
        while retry < max_retry:
            retry += 1
            succ = True
            if self.aug_level > 0:
                rotate = np.random.randint(-40, 40)
                scale_config = 0.2
                _scale = min(1 + scale_config,
                             max(1 - scale_config, (np.random.randn() * scale_config) + 1))
                _scale *= scale
                _scale = int(_scale)
            cropped = crop_fn(data, center, _scale, in_size, rot=rotate)
            if self.use_coherent == 2:
                cropped2 = crop_fn(data, center, _scale, in_size, rot=rotate2)
            label_out = np.zeros(self.label_shape, dtype=np.float32)
            label2_out = np.zeros(self.label_shape, dtype=np.float32)
            for i in range(label.shape[0]):
                pt = label[i].copy()[::-1]
                trans = img_helper.transform(pt.copy(), center, _scale, out_size, rot=rotate)
                if not img_helper.gaussian(label_out[i], trans, sigma):
                    succ = False
                    break
                if self.use_coherent == 2:
                    trans2 = img_helper.transform(pt.copy(), center, _scale, out_size, rot=rotate2)
                    if not img_helper.gaussian(label2_out[i], trans2, sigma):
                        succ = False
                        break
            if not succ:
                if self.aug_level == 0:
                    # Without augmentation the only lever is a larger crop.
                    _scale += 20
                continue

            if self.use_coherent == 1:
                cropped2, label2_out = self._mirrored(cropped, label_out)
            elif self.use_coherent == 2:
                pass
            elif (self.aug_level > 0 and np.random.rand() < 0.5) or self.force_mirror:
                cropped, label_out = self._mirrored(cropped, label_out)

            self._keep_peak_only(label_out, label.shape[0])
            if self.use_coherent:
                self._keep_peak_only(label2_out, label.shape[0])
            found = True
            break

        if not found:
            return None

        if self.vis > 0 and self.img_num <= self.vis:
            print('crop', data.shape, center, _scale, rotate, cropped.shape)
            filename = './vis/cropped_%d.jpg' % (self.img_num)
            print('save', filename)
            draw = cropped.copy()
            alabel = label_out.copy()
            for i in range(label.shape[0]):
                a = cv2.resize(alabel[i], (self.input_img_size, self.input_img_size))
                ind = np.unravel_index(np.argmax(a, axis=None), a.shape)
                cv2.circle(draw, (ind[1], ind[0]), 1, (0, 0, 255), 2)
            # NOTE(review): scipy.misc.imsave is gone from recent SciPy
            # releases; this debug path needs a replacement if vis is enabled.
            scipy.misc.imsave(filename, draw)

        if not self.use_coherent:
            return cropped, label_out

        r = math.pi * (rotate - 0) / 180
        cos_r = math.cos(r)
        sin_r = math.sin(r)
        M = np.array([
            [cos_r, -1 * sin_r, 0.0],
            [sin_r, cos_r, 0.0],
        ])
        return cropped, label_out, cropped2, label2_out, M.flatten()

    def next(self):
        """Returns the next batch of data.

        Samples that fail augmentation or validation are skipped.  In
        coherent mode the second view of slot ``i`` is written to slot
        ``i + per_batch_size // 2``.  StopIteration is raised only when the
        data ran out before any slot was filled.
        """
        batch_size = self.batch_size
        # Zero-filled so slots left unfilled by a short final batch are defined.
        batch_data = nd.zeros((batch_size,) + self.data_shape)
        batch_label = nd.zeros((batch_size,) + self.label_shape)
        if self.use_coherent:
            batch_coherent_label = nd.zeros((batch_size, 6))
        i = 0
        try:
            while i < batch_size:
                data, label, annot = self.next_sample()
                R = self.do_aug(data, label, annot)
                if R is None:
                    continue
                data = nd.transpose(nd.array(R[0]), axes=(2, 0, 1))
                label = nd.array(R[1])
                if self.use_coherent:
                    data2 = nd.transpose(nd.array(R[2]), axes=(2, 0, 1))
                    label2 = nd.array(R[3])
                    M = nd.array(R[4])
                try:
                    self.check_valid_image(data)
                except RuntimeError as e:
                    logging.debug('Invalid image, skipping: %s', str(e))
                    continue
                batch_data[i][:] = data
                batch_label[i][:] = label
                if not self.use_coherent:
                    i += 1
                    continue
                batch_coherent_label[i][:] = M
                j = i + self.per_batch_size // 2
                batch_data[j][:] = data2
                batch_label[j][:] = label2
                batch_coherent_label[j][:] = M
                i += 1
                # Once the second half of a per-device chunk is full, jump to
                # the start of the next chunk.
                if j % self.per_batch_size == self.per_batch_size - 1:
                    i = j + 1
        except StopIteration:
            if not i:
                raise StopIteration

        if not self.use_coherent:
            return mx.io.DataBatch([batch_data], [batch_label], batch_size - i)
        return mx.io.DataBatch([batch_data], [batch_label, batch_coherent_label], batch_size - i)

    def check_data_shape(self, data_shape):
        """Checks if the input data shape is valid"""
        if not len(data_shape) == 3:
            raise ValueError('data_shape should have length 3, with dimensions CxHxW')
        if not data_shape[0] == 3:
            raise ValueError('This iterator expects inputs to have 3 channels.')

    def check_valid_image(self, data):
        """Checks if the input data is valid"""
        if len(data[0].shape) == 0:
            raise RuntimeError('Data shape is wrong')

    def imdecode(self, s):
        """Decodes a string or byte string to an NDArray.
        See mx.image.imdecode for more details."""
        # The bare name `imdecode` is not defined in this module; use the
        # same decoder next_sample() uses.
        return mx.image.imdecode(s)

    def read_image(self, fname):
        """Reads an input image `fname` and returns its raw file bytes.

        Example usage:
        ----------
        >>> dataIter.read_image('Face.jpg') # returns the file's bytes.
        """
        # path_root is never assigned by this class; fall back to using
        # fname as given when the caller has not set it.
        root = getattr(self, 'path_root', '')
        with open(os.path.join(root, fname), 'rb') as fin:
            img = fin.read()
        return img
def line(img, pt1, pt2, color, width):
    """Draw a filled line of thickness ``width`` from ``pt1`` to ``pt2``.

    Points are ``(x, y)``.  ``color`` must be assignable to the selected
    pixels of ``img`` (its dimension should match the channel count).  The
    image is modified in place and also returned.
    """
    # Half-width offset vector; x and y components are stored swapped so it
    # can be applied directly as the perpendicular offset below.
    # Builtin float replaces the np.float alias removed in NumPy 1.24.
    diff = np.array([pt1[1] - pt2[1], pt1[0] - pt2[0]], float)
    mag = np.linalg.norm(diff)
    if mag >= 1:
        diff *= width / (2 * mag)
        x = np.array([pt1[0] - diff[0], pt2[0] - diff[0], pt2[0] + diff[0], pt1[0] + diff[0]], int)
        y = np.array([pt1[1] + diff[1], pt2[1] + diff[1], pt2[1] - diff[1], pt1[1] - diff[1]], int)
    else:
        # Endpoints (nearly) coincide: draw a width x width square at pt1.
        d = float(width) / 2
        x = np.array([pt1[0] - d, pt1[0] + d, pt1[0] + d, pt1[0] - d], int)
        y = np.array([pt1[1] - d, pt1[1] - d, pt1[1] + d, pt1[1] + d], int)

    # noinspection PyArgumentList
    rr, cc = skimage.draw.polygon(y, x, img.shape)
    img[rr, cc] = color

    return img


def limb(img, pt1, pt2, color, width):
    """Draw a limb between two joints, tolerating a missing annotation.

    A joint counts as present when its x coordinate is positive.  Both
    present: a line; one present: a circle at that joint; none: nothing.
    Returns ``img`` like the other drawing helpers in this module.
    """
    if pt1[0] > 0 and pt2[0] > 0:
        line(img, pt1, pt2, color, width)
    elif pt1[0] > 0:
        circle(img, pt1, color, width)
    elif pt2[0] > 0:
        circle(img, pt2, color, width)
    return img
def gaussian(img, pt, sigma):
    """Paste an unnormalised 2D gaussian centred at ``pt`` into ``img``.

    ``pt`` is ``(x, y)``.  The patch covers ``3 * sigma`` pixels either side
    of the centre, has a peak value of 1 and overwrites (does not add to)
    the pixels it covers.  Parts outside the image are clipped; if nothing
    overlaps, ``img`` is returned untouched.  ``img`` is modified in place
    and returned.
    """
    # Patch bounds: ul inclusive, br exclusive.
    ul = [int(pt[0] - 3 * sigma), int(pt[1] - 3 * sigma)]
    br = [int(pt[0] + 3 * sigma + 1), int(pt[1] + 3 * sigma + 1)]
    # Same inclusive/exclusive convention on all four sides.
    if (ul[0] >= img.shape[1] or ul[1] >= img.shape[0] or
            br[0] <= 0 or br[1] <= 0):
        return img

    # Generate gaussian
    size = 6 * sigma + 1
    x = np.arange(0, size, 1, float)
    y = x[:, np.newaxis]
    x0 = y0 = size // 2
    # The gaussian is not normalized, we want the center value to equal 1
    g = np.exp(- ((x - x0) ** 2 + (y - y0) ** 2) / (2 * sigma ** 2))

    # Usable gaussian range
    g_x = max(0, -ul[0]), min(br[0], img.shape[1]) - ul[0]
    g_y = max(0, -ul[1]), min(br[1], img.shape[0]) - ul[1]
    # Image range
    img_x = max(0, ul[0]), min(br[0], img.shape[1])
    img_y = max(0, ul[1]), min(br[1], img.shape[0])

    img[img_y[0]:img_y[1], img_x[0]:img_x[1]] = g[g_y[0]:g_y[1], g_x[0]:g_x[1]]
    return img


def circle(img, pt, color, radius):
    """Fill a disc of ``radius`` centred at ``pt`` (``(x, y)``) with ``color``.

    The image is modified in place and returned.
    """
    # skimage.draw.circle was replaced by skimage.draw.disk; prefer disk and
    # keep the old call for installations that predate it.
    draw_disk = getattr(skimage.draw, 'disk', None)
    if draw_disk is not None:
        rr, cc = draw_disk((pt[1], pt[0]), radius, shape=img.shape)
    else:
        rr, cc = skimage.draw.circle(pt[1], pt[0], radius, img.shape)
    img[rr, cc] = color
    return img
ACT_BIT = 1
N = 4
use_STN = False
use_DLA = 0
DCN = 0


def Conv(**kwargs):
    """Create an ``mx.sym.Convolution``, forwarding every keyword unchanged."""
    return mx.sym.Convolution(**kwargs)


def Act(data, act_type, name):
    """Create an activation symbol; ``'prelu'`` is built with LeakyReLU."""
    if act_type == 'prelu':
        return mx.sym.LeakyReLU(data=data, act_type='prelu', name=name)
    return mx.symbol.Activation(data=data, act_type=act_type, name=name)


def lin(data, num_filter, workspace, name, binarize, dcn):
    """Build one projection layer in one of three variants.

    * binarize: BN -> ReLU -> 1x1 ``QConvolution_v1`` -> BN
    * dcn (and not binarize): BN -> ReLU -> 3x3 deformable convolution
    * otherwise: 1x1 convolution -> BN -> ReLU
    """
    bn_mom = 0.9

    if not binarize and not dcn:
        projected = Conv(data=data, num_filter=num_filter, kernel=(1, 1), stride=(1, 1), pad=(0, 0),
                         no_bias=True, workspace=workspace, name=name + '_conv')
        normed = mx.sym.BatchNorm(data=projected, fix_gamma=False, momentum=bn_mom, eps=2e-5,
                                  name=name + '_bn')
        return Act(data=normed, act_type='relu', name=name + '_relu')

    # Both remaining variants are pre-activation: BN and ReLU come first.
    normed = mx.sym.BatchNorm(data=data, fix_gamma=False, momentum=bn_mom, eps=2e-5,
                              name=name + '_bn')
    activated = Act(data=normed, act_type='relu', name=name + '_relu')

    if binarize:
        weight_bit = 1
        quantized = mx.sym.QConvolution_v1(data=activated, num_filter=num_filter, kernel=(1, 1),
                                           stride=(1, 1), pad=(0, 0), no_bias=True,
                                           workspace=workspace, name=name + '_conv',
                                           act_bit=ACT_BIT, weight_bit=weight_bit)
        return mx.sym.BatchNorm(data=quantized, fix_gamma=False, momentum=bn_mom, eps=2e-5,
                                name=name + '_bn2')

    offset = mx.symbol.Convolution(name=name + '_conv_offset', data=activated,
                                   num_filter=18, pad=(1, 1), kernel=(3, 3), stride=(1, 1))
    return mx.contrib.symbol.DeformableConvolution(name=name + "_conv", data=activated, offset=offset,
                                                   num_filter=num_filter, pad=(1, 1), kernel=(3, 3),
                                                   num_deformable_group=1, stride=(1, 1),
                                                   dilate=(1, 1), no_bias=False)


def lin2(data, num_filter, workspace, name):
    """Build a 1x1 convolution followed by BatchNorm and ReLU."""
    bn_mom = 0.9
    projected = Conv(data=data, num_filter=num_filter, kernel=(1, 1), stride=(1, 1), pad=(0, 0),
                     no_bias=True, workspace=workspace, name=name + '_conv')
    normed = mx.sym.BatchNorm(data=projected, fix_gamma=False, momentum=bn_mom, eps=2e-5,
                              name=name + '_bn')
    return Act(data=normed, act_type='relu', name=name + '_relu')
def lin3(data, num_filter, workspace, name, k, g=1, d=1):
    """Build a stride-1 ``k`` x ``k`` grouped convolution, then BN and ReLU.

    For ``k == 3`` the convolution is dilated by ``d`` and padded by ``d``;
    for any other ``k`` it is undilated and padded by ``(k - 1) // 2``.
    ``g`` is the number of convolution groups.
    """
    bn_mom = 0.9
    conv_args = dict(data=data, num_filter=num_filter, kernel=(k, k), stride=(1, 1),
                     num_group=g, no_bias=True, workspace=workspace, name=name + '_conv')
    if k == 3:
        conv_args.update(pad=(d, d), dilate=(d, d))
    else:
        half = (k - 1) // 2
        conv_args['pad'] = (half, half)
    convolved = Conv(**conv_args)
    normed = mx.sym.BatchNorm(data=convolved, fix_gamma=False, momentum=bn_mom, eps=2e-5,
                              name=name + '_bn')
    return Act(data=normed, act_type='relu', name=name + '_relu')


def lin3_red(data, num_filter, workspace, name, k, g=1):
    """Build a 3x3 grouped convolution with stride ``k``, then BN and sigmoid.

    The convolution and BatchNorm carry an explicit ``lr_mult`` attribute of
    ``'1'``.  The activation symbol keeps the ``_relu`` name suffix although
    its type is sigmoid.
    """
    bn_mom = 0.9
    unit_lr = {'lr_mult': '1'}
    convolved = Conv(data=data, num_filter=num_filter, kernel=(3, 3), stride=(k, k), pad=(1, 1),
                     num_group=g, no_bias=True, workspace=workspace, name=name + '_conv',
                     attr=unit_lr)
    normed = mx.sym.BatchNorm(data=convolved, fix_gamma=False, momentum=bn_mom, eps=2e-5,
                              name=name + '_bn', attr={'lr_mult': '1'})
    return Act(data=normed, act_type='sigmoid', name=name + '_relu')
class RES:
    """Memoised recursive builder of a grid of symbols indexed by ``(w, h)``.

    ``get_output(w, h)`` returns ``(symbol, channel_count)``.  Cell
    ``(n, n)`` is the input; the rest of row ``h == n`` halves the channel
    count step by step with ``lin3``; every other cell averages a branch
    derived from ``(w + 1, h + 1)`` with one derived from ``(w, h + 1)``.
    """

    def __init__(self, data, nFilters, nModules, n, workspace, name, dilate, group):
        self.data = data
        self.nFilters = nFilters
        self.nModules = nModules
        self.n = n
        self.workspace = workspace
        self.name = name
        self.dilate = dilate
        self.group = group
        # (w, h) -> (symbol, channels); each cell is built at most once.
        self.sym_map = {}

    def get_output(self, w, h):
        """Return ``(symbol, channels)`` for cell ``(w, h)``, building it on first use."""
        key = (w, h)
        try:
            return self.sym_map[key]
        except KeyError:
            pass

        if h == self.n:
            if w == self.n:
                result = (self.data, self.nFilters)
            else:
                right, right_ch = self.get_output(w + 1, h)
                halved = int(right_ch * 0.5)
                # Only the cell next to the input uses the configured dilation.
                dilation = self.dilate if w == self.n - 1 else 1
                reduced = lin3(right, halved, self.workspace,
                               "%s_w%d_h%d_1" % (self.name, w, h), 3, self.group, dilation)
                result = (reduced, halved)
        else:
            x = self.get_output(w + 1, h + 1)
            y = self.get_output(w, h + 1)
            xbody = x[0]
            if h % 2 == 1 and h != w:
                # Group count equals the channel count here.
                xbody = lin3(x[0], x[1], self.workspace,
                             "%s_w%d_h%d_2" % (self.name, w, h), 3, x[1])
            ybody = y[0]
            if w == 0:
                ybody = lin3(y[0], y[1], self.workspace,
                             "%s_w%d_h%d_3" % (self.name, w, h), 3, self.group)
            ybody = mx.sym.concat(y[0], ybody, dim=1)
            merged = mx.sym.add_n(xbody, ybody, name="%s_w%d_h%d_add" % (self.name, w, h))
            result = (merged / 2, x[1])

        self.sym_map[key] = result
        return result

    def get(self):
        """Return the symbol of cell ``(1, 1)``."""
        return self.get_output(1, 1)[0]


def residual_unit_a(data, num_filter, stride, dim_match, name, binarize, dcn, dilate, **kwargs):
    """Build a pre-activation bottleneck residual unit (1x1 -> 3x3 -> 1x1).

    The first two convolutions use ``num_filter // 2`` channels and the last
    one ``num_filter``.  With ``binarize`` the activations and convolutions
    are the quantised ``QActivation`` / ``QConvolution`` operators.  ``dcn``
    and ``dilate`` are accepted but not used by this unit.  Recognised
    kwargs: ``bn_mom`` (0.9), ``workspace`` (256), ``memonger`` (False).
    """
    bn_mom = kwargs.get('bn_mom', 0.9)
    workspace = kwargs.get('workspace', 256)
    memonger = kwargs.get('memonger', False)
    weight_bit = 1

    def stage(source, index, filters, kernel, pad):
        """BN -> activation -> convolution; returns (activation, convolution)."""
        normed = mx.sym.BatchNorm(data=source, fix_gamma=False, eps=2e-5, momentum=bn_mom,
                                  name=name + '_bn%d' % index)
        act_name = name + '_relu%d' % index
        conv_name = name + '_conv%d' % index
        if binarize:
            activated = mx.sym.QActivation(data=normed, act_bit=ACT_BIT, name=act_name,
                                           backward_only=True)
            convolved = mx.sym.QConvolution(data=activated, num_filter=filters, kernel=kernel,
                                            stride=(1, 1), pad=pad, no_bias=True,
                                            workspace=workspace, name=conv_name,
                                            act_bit=ACT_BIT, weight_bit=weight_bit)
        else:
            activated = Act(data=normed, act_type='relu', name=act_name)
            convolved = Conv(data=activated, num_filter=filters, kernel=kernel, stride=(1, 1),
                             pad=pad, no_bias=True, workspace=workspace, name=conv_name)
        return activated, convolved

    half = int(num_filter * 0.5)
    act1, conv1 = stage(data, 1, half, (1, 1), (0, 0))
    _, conv2 = stage(conv1, 2, half, (3, 3), (1, 1))
    _, conv3 = stage(conv2, 3, num_filter, (1, 1), (0, 0))

    if dim_match:
        shortcut = data
    elif binarize:
        shortcut = mx.sym.QConvolution(data=act1, num_filter=num_filter, kernel=(1, 1),
                                       stride=stride, pad=(0, 0), no_bias=True,
                                       workspace=workspace, name=name + '_sc',
                                       act_bit=ACT_BIT, weight_bit=weight_bit)
    else:
        # The projection branches off after the first BN/activation.
        shortcut = Conv(data=act1, num_filter=num_filter, kernel=(1, 1), stride=stride,
                        no_bias=True, workspace=workspace, name=name + '_sc')
    if memonger:
        shortcut._set_attr(mirror_stage='True')
    return conv3 + shortcut
def residual_unit_g(data, num_filter, stride, dim_match, name, binarize, dcn, dilation, **kwargs):
    """Build a residual unit of three chained 3x3 stages whose outputs are concatenated.

    The stages produce 1/2, 1/4 and 1/4 of ``num_filter`` channels.  Each
    stage is BN -> activation -> convolution, where the convolution is
    dilated (plain), deformable (``dcn``) or quantised (``binarize``);
    ``dilation`` only affects the plain variant.  Recognised kwargs:
    ``bn_mom`` (0.9), ``workspace`` (256), ``memonger`` (False).
    """
    bn_mom = kwargs.get('bn_mom', 0.9)
    workspace = kwargs.get('workspace', 256)
    memonger = kwargs.get('memonger', False)
    weight_bit = 1

    def stage(source, index, filters):
        """Build stage ``index``; returns (activation, convolution)."""
        normed = mx.sym.BatchNorm(data=source, fix_gamma=False, eps=2e-5, momentum=bn_mom,
                                  name=name + '_bn%d' % index)
        act_name = name + '_relu%d' % index
        conv_name = name + '_conv%d' % index
        if binarize:
            activated = mx.sym.QActivation(data=normed, act_bit=ACT_BIT, name=act_name,
                                           backward_only=True)
            convolved = mx.sym.QConvolution_v1(data=activated, num_filter=filters, kernel=(3, 3),
                                               stride=(1, 1), pad=(1, 1), no_bias=True,
                                               workspace=workspace, name=conv_name,
                                               act_bit=ACT_BIT, weight_bit=weight_bit)
            return activated, convolved

        activated = Act(data=normed, act_type='relu', name=act_name)
        if dcn:
            offset = mx.symbol.Convolution(name=conv_name + '_offset', data=activated,
                                           num_filter=18, pad=(1, 1), kernel=(3, 3), stride=(1, 1))
            convolved = mx.contrib.symbol.DeformableConvolution(
                name=conv_name, data=activated, offset=offset, num_filter=filters,
                pad=(1, 1), kernel=(3, 3), num_deformable_group=1, stride=(1, 1),
                dilate=(1, 1), no_bias=True)
        else:
            convolved = Conv(data=activated, num_filter=filters, kernel=(3, 3), stride=(1, 1),
                             pad=(dilation, dilation), dilate=(dilation, dilation),
                             no_bias=True, workspace=workspace, name=conv_name)
        return activated, convolved

    act1, conv1 = stage(data, 1, int(num_filter * 0.5))
    _, conv2 = stage(conv1, 2, int(num_filter * 0.25))
    _, conv3 = stage(conv2, 3, int(num_filter * 0.25))

    conv4 = mx.symbol.Concat(*[conv1, conv2, conv3])
    if binarize:
        conv4 = mx.sym.BatchNorm(data=conv4, fix_gamma=False, eps=2e-5, momentum=bn_mom,
                                 name=name + '_bn4')

    if dim_match:
        shortcut = data
    elif not binarize:
        shortcut = Conv(data=act1, num_filter=num_filter, kernel=(1, 1), stride=stride,
                        no_bias=True, workspace=workspace, name=name + '_sc')
    else:
        shortcut = mx.sym.QConvolution_v1(data=act1, num_filter=num_filter, kernel=(1, 1),
                                          stride=stride, pad=(0, 0), no_bias=True,
                                          workspace=workspace, name=name + '_sc',
                                          act_bit=ACT_BIT, weight_bit=weight_bit)
        # NOTE(review): indentation was lost in the extracted source; this
        # BatchNorm is read as part of the binarized shortcut only,
        # mirroring the binarize-only '_bn4' above - confirm against the
        # repository.
        shortcut = mx.sym.BatchNorm(data=shortcut, fix_gamma=False, eps=2e-5, momentum=bn_mom,
                                    name=name + '_sc_bn')
    if memonger:
        shortcut._set_attr(mirror_stage='True')
    return conv4 + shortcut
def ConvFactory(data, num_filter, kernel, stride=(1, 1), pad=(0, 0), act_type="relu", mirror_attr=None, with_act=True):
    """Build a Convolution -> BatchNorm (-> Activation) symbol chain.

    Args:
        data: input symbol.
        num_filter: number of convolution output channels.
        kernel: convolution kernel size, e.g. ``(3, 3)``.
        stride: convolution stride.
        pad: convolution padding.
        act_type: activation type handed to ``mx.symbol.Activation``.
        mirror_attr: attribute dict forwarded as ``attr`` to the activation.
            ``None`` (the default) is treated as an empty dict.
        with_act: when false, the BatchNorm output is returned directly.

    Returns:
        The activation symbol, or the BatchNorm symbol if ``with_act`` is
        false.
    """
    # A ``{}`` default would be one dict object shared by every call; use a
    # None sentinel and create a fresh dict per call instead.
    if mirror_attr is None:
        mirror_attr = {}
    conv = mx.symbol.Convolution(
        data=data, num_filter=num_filter, kernel=kernel, stride=stride, pad=pad)
    bn = mx.symbol.BatchNorm(data=conv)
    if not with_act:
        return bn
    return mx.symbol.Activation(data=bn, act_type=act_type, attr=mirror_attr)
def block35(net, input_num_channels, scale=1.0, with_act=True, act_type='relu', mirror_attr=None):
    """Three-tower residual block built from ``ConvFactory`` units.

    The towers are a single 1x1 conv, a 1x1 -> 3x3 pair, and a
    1x1 -> 3x3 -> 3x3 chain. Their outputs are concatenated, projected back
    to ``input_num_channels`` with a 1x1 conv (no activation), scaled by
    ``scale`` and added to ``net``.

    Args:
        net: input symbol; also used as the residual shortcut.
        input_num_channels: channel count of the projected tower output
            that is added to ``net``.
        scale: multiplier applied to the tower output before the addition.
        with_act: when true, an activation is applied to the sum.
        act_type: activation type for the final activation.
        mirror_attr: attribute dict forwarded as ``attr`` to the final
            activation. ``None`` (the default) is treated as an empty dict.

    Returns:
        The activated sum if ``with_act`` is true, otherwise the raw sum.
    """
    # Avoid a dict default shared across calls.
    if mirror_attr is None:
        mirror_attr = {}

    M = 1.0  # global width multiplier for the tower channel counts
    quarter = int(input_num_channels * 0.25 * M)
    three_eighths = int(input_num_channels * 0.375 * M)
    half = int(input_num_channels * 0.5 * M)

    # Symbols are created in the same order as before: the ConvFactory
    # layers are unnamed, so their auto-generated names depend on it.
    tower_conv = ConvFactory(net, quarter, (1, 1))
    tower_conv1_0 = ConvFactory(net, quarter, (1, 1))
    tower_conv1_1 = ConvFactory(tower_conv1_0, quarter, (3, 3), pad=(1, 1))
    tower_conv2_0 = ConvFactory(net, quarter, (1, 1))
    tower_conv2_1 = ConvFactory(tower_conv2_0, three_eighths, (3, 3), pad=(1, 1))
    tower_conv2_2 = ConvFactory(tower_conv2_1, half, (3, 3), pad=(1, 1))
    tower_mixed = mx.symbol.Concat(*[tower_conv, tower_conv1_1, tower_conv2_2])
    tower_out = ConvFactory(
        tower_mixed, input_num_channels, (1, 1), with_act=False)

    net = net + scale * tower_out
    if not with_act:
        return net
    return mx.symbol.Activation(data=net, act_type=act_type, attr=mirror_attr)
def _hourglass_units(body, nFilters, nModules, name, tag, binarize):
    """Chain ``nModules`` residual units named ``<name>_<tag>_<i>`` onto ``body``."""
    for i in range(nModules):
        # stride (1,1), dim_match=True, dcn disabled, dilate=1
        body = residual_unit(body, nFilters, (1, 1), True,
                             "%s_%s_%d" % (name, tag, i), binarize, False, 1)
    return body


def hourglass(data, nFilters, nModules, n, workspace, name, binarize, dcn):
    """Build one recursive hourglass module of depth ``n``.

    The upper branch runs residual units on ``data``. The lower branch
    max-pools by 2, runs residual units, recurses while ``n > 1``, runs
    residual units again and upsamples by 2 with a grouped Deconvolution
    whose ``lr_mult`` and ``wd_mult`` attributes are ``'0.0'``. The two
    branches are summed.

    Args:
        data: input symbol.
        nFilters: channel count used by every residual unit and by the
            upsampling deconvolution.
        nModules: number of residual units per stage.
        n: remaining recursion depth; recursion stops at ``n == 1``.
        workspace: workspace size passed to the Deconvolution.
        name: prefix for all symbol names created here.
        binarize: forwarded to ``residual_unit``.
        dcn: only forwarded to the recursive call; the residual units built
            here always receive ``False`` for their dcn argument.

    Returns:
        Symbol holding the sum of the upper and lower branches.
    """
    s = 2  # pooling / upsampling factor

    up1 = _hourglass_units(data, nFilters, nModules, name, 'up1', binarize)

    low1 = mx.sym.Pooling(data=data, kernel=(s, s), stride=(s, s), pad=(0, 0), pool_type='max')
    low1 = _hourglass_units(low1, nFilters, nModules, name, 'low1', binarize)

    if n > 1:
        low2 = hourglass(low1, nFilters, nModules, n - 1, workspace,
                         "%s_%d" % (name, n - 1), binarize, dcn)
    else:
        # NOTE(review): the source was flattened; this stage is taken to
        # belong to the innermost level only (inside ``else``) - confirm.
        low2 = _hourglass_units(low1, nFilters, nModules, name, 'low2', binarize)  # TODO

    low3 = _hourglass_units(low2, nFilters, nModules, name, 'low3', binarize)

    up2 = mx.symbol.Deconvolution(data=low3, num_filter=nFilters, kernel=(s, s),
                                  stride=(s, s),
                                  num_group=nFilters, no_bias=True, name='%s_upsampling_%s' % (name, n),
                                  attr={'lr_mult': '0.0', 'wd_mult': '0.0'}, workspace=workspace)
    return mx.symbol.add_n(up1, up2)
"%s_low1_%d"%(name,i), binarize, _dcn, dilate) + if n>1: + low2 = hourglass2(low1, nFilters, nModules, n-1, workspace, "%s_%d"%(name, n-1), binarize, dcn) + else: + low2 = low1 + for i in xrange(nModules): + low2 = residual_unit(low2, nFilters, (1,1), True, "%s_low2_%d"%(name,i), binarize, _dcn, dilate) #TODO + #low2 = residual_unit(low2, nFilters, (1,1), True, "%s_low2_%d"%(name,i), False) #TODO + low3 = low2 + for i in xrange(nModules): + low3 = residual_unit(low3, nFilters, (1,1), True, "%s_low3_%d"%(name,i), binarize, _dcn, dilate) + up2 = low3 + #up2 = mx.symbol.Deconvolution(data=low3, num_filter=nFilters, kernel=(s,s), + # stride=(s, s), + # num_group=nFilters, no_bias=True, name='%s_upsampling_%s'%(name,n), + # attr={'lr_mult': '0.0', 'wd_mult': '0.0'}, workspace=workspace) + return mx.symbol.add_n(up1, up2) + + +class DLA: + def __init__(self, data, nFilters, nModules, n, workspace, name): + self.data = data + self.nFilters = nFilters + self.nModules = nModules + self.n = n + self.workspace = workspace + self.name = name + self.sym_map = {} + + + def residual_unit(self, data, name, dilate=1, group=1): + res = RES(data, self.nFilters, self.nModules, 4, self.workspace, name, dilate, group) + return res.get() + #body = data + #for i in xrange(self.nModules): + # body = residual_unit(body, self.nFilters, (1,1), True, name, False, False, 1) + #return body + + def get_output(self, w, h): + #print(w,h) + assert w>=1 and w<=N+1 + assert h>=1 and h<=N+1 + s = 2 + bn_mom = 0.9 + key = (w,h) + if key in self.sym_map: + return self.sym_map[key] + ret = None + if h==self.n: + if w==self.n: + ret = self.data,64 + #elif w==1: + # x = self.get_output(w+1, h) + # body = self.residual_unit(x[0], "%s_w%d_h%d_1"%(self.name, w, h)) + # body = self.residual_unit(body, "%s_w%d_h%d_2"%(self.name, w, h), 2) + # ret = body,x[1] + else: + x = self.get_output(w+1, h) + body = self.residual_unit(x[0], "%s_w%d_h%d_1"%(self.name, w, h)) + body = mx.sym.Pooling(data=body, kernel=(s, s), 
stride=(s,s), pad=(0,0), pool_type='max') + body = self.residual_unit(body, "%s_w%d_h%d_2"%(self.name, w, h)) + ret = body, x[1]//2 + else: + x = self.get_output(w+1, h+1) + y = self.get_output(w, h+1) + #xbody = Conv(data=x, num_filter=self.nFilters, kernel=(3,3), stride=(1,1), pad=(1,1), + # no_bias=True, workspace=self.workspace, name="%s_w%d_h%d_x_conv"%(self.name, w, h)) + #xbody = mx.sym.BatchNorm(data=xbody, fix_gamma=False, momentum=bn_mom, eps=2e-5, name="%s_w%d_h%d_x_bn"%(self.name, w, h)) + #xbody = Act(data=xbody, act_type='relu', name="%s_w%d_h%d_x_act"%(self.name, w, h)) + + HC = False + + if use_DLA<10: + if h%2==1 and h!=w: + xbody = lin3(x[0], self.nFilters, self.workspace, "%s_w%d_h%d_x"%(self.name, w, h), 3, self.nFilters, 1) + HC = True + #xbody = x[0] + else: + xbody = x[0] + else: + xbody = lin3(x[0], self.nFilters, self.workspace, "%s_w%d_h%d_x"%(self.name, w, h), 3, 1, 1) + #xbody = x[0] + if x[1]//y[1]==2: + if w>1: + ybody = mx.symbol.Deconvolution(data=y[0], num_filter=self.nFilters, kernel=(s,s), + stride=(s, s), + name='%s_upsampling_w%d_h%d'%(self.name,w, h), + attr={'lr_mult': '1.0'}, workspace=self.workspace) + ybody = mx.sym.BatchNorm(data=ybody, fix_gamma=False, momentum=bn_mom, eps=2e-5, name="%s_w%d_h%d_y_bn"%(self.name, w, h)) + ybody = Act(data=ybody, act_type='relu', name="%s_w%d_h%d_y_act"%(self.name, w, h)) + #ybody = Conv(data=ybody, num_filter=self.nFilters, kernel=(3,3), stride=(1,1), pad=(1,1), + # no_bias=True, name="%s_w%d_h%d_y_conv2"%(self.name, w, h), workspace=self.workspace) + #ybody = mx.sym.BatchNorm(data=ybody, fix_gamma=False, momentum=bn_mom, eps=2e-5, name="%s_w%d_h%d_y_bn2"%(self.name, w, h)) + #ybody = Act(data=ybody, act_type='relu', name="%s_w%d_h%d_y_act2"%(self.name, w, h)) + else: + if h>=1: + ybody = mx.symbol.Deconvolution(data=y[0], num_filter=self.nFilters, kernel=(s,s), + stride=(s, s), + num_group=self.nFilters, no_bias=True, name='%s_upsampling_w%d_h%d'%(self.name,w, h), + attr={'lr_mult': 
'0.0', 'wd_mult': '0.0'}, workspace=self.workspace) + #ybody = mx.sym.BatchNorm(data=ybody, fix_gamma=False, momentum=bn_mom, eps=2e-5, name="%s_w%d_h%d_y_bn"%(self.name, w, h)) + import math + #ybody = Act(data=ybody, act_type='relu', name="%s_w%d_h%d_y_act"%(self.name, w, h)) + ybody = self.residual_unit(ybody, "%s_w%d_h%d_4"%(self.name, w, h)) + else: + ybody = mx.symbol.Deconvolution(data=y[0], num_filter=self.nFilters, kernel=(s,s), + stride=(s, s), + name='%s_upsampling_w%d_h%d'%(self.name,w, h), + attr={'lr_mult': '1.0'}, workspace=self.workspace) + ybody = mx.sym.BatchNorm(data=ybody, fix_gamma=False, momentum=bn_mom, eps=2e-5, name="%s_w%d_h%d_y_bn"%(self.name, w, h)) + ybody = Act(data=ybody, act_type='relu', name="%s_w%d_h%d_y_act"%(self.name, w, h)) + ybody = Conv(data=ybody, num_filter=self.nFilters, kernel=(3,3), stride=(1,1), pad=(1,1), + no_bias=True, name="%s_w%d_h%d_y_conv2"%(self.name, w, h), workspace=self.workspace) + ybody = mx.sym.BatchNorm(data=ybody, fix_gamma=False, momentum=bn_mom, eps=2e-5, name="%s_w%d_h%d_y_bn2"%(self.name, w, h)) + ybody = Act(data=ybody, act_type='relu', name="%s_w%d_h%d_y_act2"%(self.name, w, h)) + else: + ybody = self.residual_unit(y[0], "%s_w%d_h%d_5"%(self.name, w, h)) + #if not HC: + if use_DLA<10: + if use_DLA>1 and h==3 and w==2: + z = self.get_output(w+1, h) + zbody = z[0] + #zbody = lin3_red(zbody, self.nFilters, self.workspace, "%s_w%d_h%d_z"%(self.name, w, h), 2, self.nFilters) + #zbody = mx.sym.Pooling(data=zbody, kernel=(s, s), stride=(s,s), pad=(0,0), pool_type='avg') + zbody = mx.sym.Pooling(data=zbody, kernel=(z[1], z[1]), stride=(z[1],z[1]), pad=(0,0), pool_type='avg') + #zbody = mx.sym.Activation(data = zbody, act_type='sigmoid') + + #body = zbody+ybody + #body = body/2 + body = xbody+ybody + body = body/2 + #body = body*zbody + body = mx.sym.broadcast_mul(body, zbody) + #body = mx.sym.add_n(*[xbody, ybody, zbody]) + #body = body/3 + else: + body = xbody+ybody + body = body/2 + else: + if 
def ce_loss(x, y):
    """Cross-entropy between a softmax over axes 2 and 3 of ``x`` and ``y``.

    Computes ``mean(-y * log(softmax(x)))`` where the softmax normalises
    each slice over axes 2 and 3 and the mean runs over every element.

    Args:
        x: symbol of raw scores.
        y: symbol of target weights, multiplied element-wise with the
            log-probabilities.

    Returns:
        Symbol holding the mean loss.
    """
    # log(exp(x) / sum(exp(x))) overflows to inf/inf = NaN for large scores
    # and underflows to log(0) = -inf for very negative ones. Shifting by the
    # per-slice maximum gives the same value without either failure.
    peak = mx.sym.max(x, axis=[2, 3], keepdims=True)
    # The result does not depend on the shift, so blocking its gradient is
    # exact and keeps the backward pass equal to the unshifted form.
    peak = mx.sym.BlockGrad(peak)
    shifted = mx.sym.broadcast_sub(x, peak)
    log_norm = mx.sym.log(mx.sym.sum(mx.sym.exp(shifted), axis=[2, 3], keepdims=True))
    log_prob = mx.sym.broadcast_sub(shifted, log_norm)
    loss = log_prob*y*-1.0
    loss = mx.symbol.mean(loss)
    return loss
assert(input_size==128 or input_size==256) + D = input_size // label_size + print(input_size, label_size, D) + dcn = False + kwargs = {} + data = mx.sym.Variable(name='data') + data = data-127.5 + data = data*0.0078125 + gt_label = mx.symbol.Variable(name='softmax_label') + losses = [] + closses = [] + if use_coherent: + M = mx.sym.Variable(name="coherent_label") + #gt_label2 = mx.sym.Variable(name="softmax_label2") + coherent_weight = 0.0001 + ref_label = gt_label + if use_STN: + lr_mult = '0.00001' + loc_net = Conv(data=data, num_filter=sFilters, kernel=(7, 7), stride=(2,2), pad=(3, 3), + no_bias=True, name="stn_conv0", workspace=workspace, attr={'lr_mult': lr_mult}) + loc_net = mx.sym.BatchNorm(data=loc_net, fix_gamma=False, eps=2e-5, momentum=bn_mom, name='stn_bn0', attr={'lr_mult': lr_mult}) + loc_net = Act(data=loc_net, act_type='relu', name='stn_relu0') + loc_net = Conv(data=loc_net, num_filter=mFilters, kernel=(3,3), stride=(1,1), pad=(1,1), + no_bias=True, name="stn_conv1", workspace=workspace, attr={'lr_mult': lr_mult}) + loc_net = mx.sym.BatchNorm(data=loc_net, fix_gamma=False, eps=2e-5, momentum=bn_mom, name='stn_bn1', attr={'lr_mult': lr_mult}) + loc_net = Act(data=loc_net, act_type='relu', name='stn_relu1') + loc_net = mx.sym.Pooling(data=loc_net, kernel=(2, 2), stride=(2,2), pad=(0,0), pool_type='max') + + loc_net = mx.sym.FullyConnected(data=loc_net, num_hidden=int(nFilters*0.5), name='loc_net_half', attr={'lr_mult': lr_mult}) + loc_net = mx.sym.Activation(data=loc_net, act_type='tanh', name='loc_net_act') + #loc_net = mx.sym.Activation(data=loc_net, act_type='relu', name='loc_net_act') + #loc_theta = mx.sym.FullyConnected(data=loc_net, num_hidden=6, name='loc_theta', attr={'lr_mult': lr_mult}) + #loc_theta = mx.sym.Activation(data=loc_theta, act_type='tanh', name='loc_theta_tanh') + loc_theta = mx.sym.FullyConnected(data=loc_net, num_hidden=1, name='loc_theta', attr={'lr_mult': lr_mult}) + loc_theta = mx.sym.Activation(data=loc_theta, 
act_type='tanh', name='loc_theta_tanh') + loc_theta = loc_theta*0.5 + sin_t = mx.sym.sin(loc_theta) + m_sin_t = sin_t*-1.0 + cos_t = mx.sym.cos(loc_theta) + zero_t = mx.sym.zeros_like(loc_theta) + loc_theta = mx.sym.concat(*[cos_t, m_sin_t, zero_t, sin_t, cos_t, zero_t], dim=1) + data = mx.sym.SpatialTransformer(data = data, loc = loc_theta, target_shape=(input_size,input_size), transform_type="affine", sampler_type="bilinear") + ref_label = mx.sym.SpatialTransformer(data = ref_label, loc = loc_theta, target_shape=(label_size,label_size), transform_type="affine", sampler_type="bilinear") + #data = mx.sym.BatchNorm(data=data, fix_gamma=False, eps=2e-5, momentum=bn_mom, name='bn_data') + if D==4: + body = Conv(data=data, num_filter=sFilters, kernel=(7, 7), stride=(2,2), pad=(3, 3), + no_bias=True, name="conv0", workspace=workspace) + else: + body = Conv(data=data, num_filter=sFilters, kernel=(3, 3), stride=(1,1), pad=(1, 1), + no_bias=True, name="conv0", workspace=workspace) + body = mx.sym.BatchNorm(data=body, fix_gamma=False, eps=2e-5, momentum=bn_mom, name='bn0') + body = Act(data=body, act_type='relu', name='relu0') + + body = residual_unit(body, mFilters, (1,1), sFilters==mFilters, 'res0', False, dcn, 1, **kwargs) + #body = residual_unit(body, nFilters, (1,1), False, 'res0', binarize, **kwargs) + + body = mx.sym.Pooling(data=body, kernel=(2, 2), stride=(2,2), pad=(0,0), pool_type='max') + + body = residual_unit(body, mFilters, (1,1), True, 'res1', False, dcn, 1, **kwargs) #TODO + #body = residual_unit(body, nFilters, (1,1), True, 'res1', binarize, **kwargs) #TODO + body = residual_unit(body, nFilters, (1,1), mFilters==nFilters, 'res2', binarize, dcn, 1, **kwargs) #binarize=True? + #body = residual_unit(body, nFilters, (1,1), False, 'res2', False, **kwargs) #binarize=True? 
+ + use_lin = True + heatmap = None + + for i in xrange(nStacks): + shortcut = body + if use_DLA>0: + dla = DLA(body, nFilters, nModules, N+1, workspace, 'dla%d'%(i)) + body = dla.get() + else: + body = hourglass(body, nFilters, nModules, N, workspace, 'stack%d_hg'%(i), binarize, dcn) + for j in xrange(nModules): + body = residual_unit(body, nFilters, (1,1), True, 'stack%d_unit%d'%(i,j), binarize, dcn, 1, **kwargs) + if use_lin: + _dcn = True if DCN>=2 else False + ll = lin(body, nFilters, workspace, name='stack%d_ll'%(i), binarize = False, dcn = _dcn) #TODO + #ll = lin(body, nFilters, workspace, name='stack%d_ll'%(i), binarize = binarize) + else: + ll = body + _name = "heatmap%d"%(i) + if i==nStacks-1: + _name = "heatmap" + + _dcn = True if DCN>=2 else False + if not _dcn: + out = Conv(data=ll, num_filter=num_classes, kernel=(1, 1), stride=(1,1), pad=(0,0), + name=_name, workspace=workspace) + else: + out_offset = mx.symbol.Convolution(name=_name+'_offset', data = ll, + num_filter=18, pad=(1, 1), kernel=(3, 3), stride=(1, 1)) + out = mx.contrib.symbol.DeformableConvolution(name=_name, data=ll, offset=out_offset, + num_filter=num_classes, pad=(1,1), kernel=(3, 3), num_deformable_group=1, stride=(1, 1), dilate=(1, 1), no_bias=False) + #out = Conv(data=ll, num_filter=num_classes, kernel=(3,3), stride=(1,1), pad=(1,1), + # name=_name, workspace=workspace) + if i==nStacks-1: + heatmap = out + #outs.append(out) + if use_coherent>0: + #px = mx.sym.slice_axis(out, axis=0, begin=0, end=b) + #py = mx.sym.slice_axis(ref_label, axis=0, begin=0, end=b) + px = out + py = ref_label + gloss = ce_loss(px, py) + gloss = gloss/nStacks + losses.append(gloss) + + b = per_batch_size//2 + ux = mx.sym.slice_axis(out, axis=0, begin=0, end=b) + dx = mx.sym.slice_axis(out, axis=0, begin=b, end=b*2) + if use_coherent==1: + ux = mx.sym.flip(ux, axis=3) + ux_list = [None]*68 + for k in xrange(68): + if k in mirror_map: + vk = mirror_map[k] + #print('k', k, vk) + ux_list[vk] = 
mx.sym.slice_axis(ux, axis=1, begin=k, end=k+1) + else: + ux_list[k] = mx.sym.slice_axis(ux, axis=1, begin=k, end=k+1) + ux = mx.sym.concat(*ux_list, dim=1) + #dx = mx.sym.slice_axis(ref_label, axis=0, begin=b, end=b*2) + #closs = ce_loss(ux, dx) + closs = l2_loss(ux, dx) + closs = closs/nStacks + closses.append(closs) + else: + m = mx.sym.slice_axis(M, axis=0, begin=0, end=b) + ux = mx.sym.SpatialTransformer(data=ux, loc=m, target_shape=(label_size, label_size), transform_type='affine', sampler_type='bilinear') + closs = l2_loss(ux, dx) + closs = closs/nStacks + closses.append(closs) + + else: + loss = ce_loss(out, ref_label) + loss = loss/nStacks + losses.append(loss) + + if i0: + closs = mx.symbol.add_n(*closses) + closs = mx.symbol.MakeLoss(closs, grad_scale = coherent_weight) + syms.append(closs) + #syms.append(mx.symbol.BlockGrad(M)) + #syms.append(mx.symbol.BlockGrad(px)) + #syms.append(mx.symbol.BlockGrad(qx)) + #syms.append(mx.symbol.BlockGrad(m)) + #syms.append(mx.symbol.BlockGrad(closs)) + if use_coherent>1: + syms.append(mx.symbol.BlockGrad(gt_label)) + if use_coherent>0: + syms.append(mx.symbol.BlockGrad(M)) + syms.append(pred) + sym = mx.symbol.Group( syms ) + return sym + +def init_weights(sym, data_shape_dict): + print('in hg2') + arg_name = sym.list_arguments() + aux_name = sym.list_auxiliary_states() + arg_shape, _, aux_shape = sym.infer_shape(**data_shape_dict) + arg_shape_dict = dict(zip(arg_name, arg_shape)) + aux_shape_dict = dict(zip(aux_name, aux_shape)) + #print(aux_shape) + #print(aux_params) + #print(arg_shape_dict) + arg_params = {} + aux_params = {} + for k,v in arg_shape_dict.iteritems(): + #print(k,v) + if k.endswith('offset_weight') or k.endswith('offset_bias'): + print('initializing',k) + arg_params[k] = mx.nd.zeros(shape = v) + elif k.startswith('fc6_'): + if k.endswith('_weight'): + print('initializing',k) + arg_params[k] = mx.random.normal(0, 0.01, shape=v) + elif k.endswith('_bias'): + print('initializing',k) + arg_params[k] = 
def get_transform(center, scale, res, rot=0):
    """Return the 3x3 matrix mapping image coordinates into a crop.

    A square window of side ``scale`` centred on ``center`` is mapped onto
    an output of size ``res``. A non-zero ``rot`` additionally rotates
    about the centre of the output.

    Args:
        center: ``(x, y)`` centre of the window in the source image.
        scale: side length of the window in source pixels.
        res: output resolution as ``(rows, cols)``.
        rot: rotation in degrees; ``0`` disables rotation.

    Returns:
        ``numpy.ndarray`` of shape ``(3, 3)`` (homogeneous 2-D transform).
    """
    h = scale
    t = np.zeros((3, 3))
    t[0, 0] = float(res[1]) / h
    t[1, 1] = float(res[0]) / h
    t[0, 2] = res[1] * (-float(center[0]) / h + .5)
    t[1, 2] = res[0] * (-float(center[1]) / h + .5)
    t[2, 2] = 1
    if rot != 0:
        rot = -rot  # To match direction of rotation from cropping
        rot_rad = rot * np.pi / 180
        sn, cs = np.sin(rot_rad), np.cos(rot_rad)
        rot_mat = np.zeros((3, 3))
        rot_mat[0, :2] = [cs, -sn]
        rot_mat[1, :2] = [sn, cs]
        rot_mat[2, 2] = 1
        # Rotate about the output centre. Divide by 2.0 so integer ``res``
        # is not floored under Python 2, which would shift the pivot by
        # half a pixel for odd sizes.
        t_mat = np.eye(3)
        t_mat[0, 2] = -res[1] / 2.0
        t_mat[1, 2] = -res[0] / 2.0
        t_inv = t_mat.copy()
        t_inv[:2, 2] *= -1
        t = np.dot(t_inv, np.dot(rot_mat, np.dot(t_mat, t)))
    return t
def crop_center(img, crop_size):
    """Return the central ``crop_size`` window of ``img``.

    Args:
        img: array whose first two axes are rows and columns; any further
            axes (e.g. channels) are kept whole, so 2-D images work too.
        crop_size: ``(rows, cols)`` of the window.

    Returns:
        A view of ``img`` covering the central window. If the window is
        larger than the image along an axis, that axis is returned whole.
    """
    rows, cols = img.shape[0], img.shape[1]
    # Clamp at 0: a negative start would be read by numpy as an offset from
    # the end and select the wrong region.
    startx = max(0, cols // 2 - (crop_size[1] // 2))
    starty = max(0, rows // 2 - (crop_size[0] // 2))
    # No explicit trailing ``:`` so arrays without a channel axis are accepted.
    return img[starty:(starty + crop_size[0]), startx:(startx + crop_size[1])]
def gaussian(img, pt, sigma):
    """Draw an unnormalised 2-D gaussian (peak value 1) into ``img`` in place.

    The patch spans ``3 * sigma`` on each side of ``pt`` and is clipped to
    the image. Pixels under the patch are overwritten, not accumulated.

    Args:
        img: 2-D array indexed ``[row, col]``; modified in place.
        pt: ``(x, y)`` centre of the gaussian (column, row).
        sigma: standard deviation in pixels; must be positive.

    Returns:
        ``True`` if some part of the patch landed inside ``img``, ``False``
        if the patch was entirely outside and nothing was drawn.
    """
    assert(sigma>0)

    # Patch bounds: ``ul`` is inclusive, ``br`` is exclusive.
    ul = [int(pt[0] - 3 * sigma), int(pt[1] - 3 * sigma)]
    br = [int(pt[0] + 3 * sigma + 1), int(pt[1] + 3 * sigma + 1)]
    # ``br`` is exclusive, so br <= 0 and ul >= size both mean no overlap.
    # The earlier strict comparisons let such patches through and returned
    # True after drawing nothing.
    if (ul[0] >= img.shape[1] or ul[1] >= img.shape[0] or
            br[0] <= 0 or br[1] <= 0):
        return False

    # Generate gaussian
    size = 6 * sigma + 1
    x = np.arange(0, size, 1, float)
    y = x[:, np.newaxis]
    x0 = y0 = size // 2
    # The gaussian is not normalized, we want the center value to equal 1
    g = np.exp(- ((x - x0) ** 2 + (y - y0) ** 2) / (2 * sigma ** 2))

    # Usable gaussian range
    g_x = max(0, -ul[0]), min(br[0], img.shape[1]) - ul[0]
    g_y = max(0, -ul[1]), min(br[1], img.shape[0]) - ul[1]
    # Image range
    img_x = max(0, ul[0]), min(br[0], img.shape[1])
    img_y = max(0, ul[1]), min(br[1], img.shape[0])

    img[img_y[0]:img_y[1], img_x[0]:img_x[1]] = g[g_y[0]:g_y[1], g_x[0]:g_x[1]]
    return True
args.image_size.split(',') +assert len(_vec)==2 +image_size = (int(_vec[0]), int(_vec[1])) +_vec = args.model.split(',') +assert len(_vec)==2 +prefix = _vec[0] +epoch = int(_vec[1]) +print('loading',prefix, epoch) +if args.gpu>=0: + ctx = mx.gpu(args.gpu) +else: + ctx = mx.cpu() +sym, arg_params, aux_params = mx.model.load_checkpoint(prefix, epoch) +all_layers = sym.get_internals() +sym = all_layers['heatmap_output'] +model = mx.mod.Module(symbol=sym, context=ctx, label_names = None) +#model.bind(data_shapes=[('data', (args.batch_size, 3, image_size[0], image_size[1]))], label_shapes=[('softmax_label', (args.batch_size,))]) +model.bind(data_shapes=[('data', (1, 3, image_size[0], image_size[1]))]) +model.set_params(arg_params, aux_params) +#img_path = '/raid5data/dplearn/megaface/facescrubr/112x112/Tom_Hanks/Tom_Hanks_54745.png' +img_path = './test.png' + +img = cv2.imread(img_path) + +rimg = cv2.resize(img, (image_size[1], image_size[0])) +img = cv2.cvtColor(rimg, cv2.COLOR_BGR2RGB) +img = np.transpose(img, (2,0,1)) #3*112*112, RGB +input_blob = np.expand_dims(img, axis=0) #1*3*112*112 +data = mx.nd.array(input_blob) +db = mx.io.DataBatch(data=(data,)) +model.forward(db, is_train=False) +output = model.get_outputs()[0].asnumpy() +#print(output[0,80]) +#sys.exit(0) +filename = "./vis/draw_%s" % img_path.split('/')[-1] +for i in xrange(output.shape[1]): + a = output[0,i,:,:] + a = cv2.resize(a, (image_size[1], image_size[0])) + ind = np.unravel_index(np.argmax(a, axis=None), a.shape) + cv2.circle(rimg, (ind[1], ind[0]), 1, (0, 0, 255), 2) + print(i, ind) +cv2.imwrite(filename, rimg) + diff --git a/alignment/train.py b/alignment/train.py new file mode 100644 index 0000000..bc1cd0b --- /dev/null +++ b/alignment/train.py @@ -0,0 +1,355 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +#import hg +import hg2 as hg +import logging +import argparse +from data import FaceSegIter +import mxnet as mx +import 
class NMEMetric(mx.metric.EvalMetric):
    """Mean landmark error between heatmap peaks, normalised by label size.

    For every heatmap channel the argmax positions of the ground truth and
    of the prediction (resized to the label resolution) are compared. The
    mean Euclidean distance over the batch is divided by
    ``sqrt(label_height * label_width)``.
    """

    def __init__(self):
        self.axis = 1
        super(NMEMetric, self).__init__(
            'NME', axis=self.axis,
            output_names=None, label_names=None)
        self.count = 0  # number of update() calls

    def update(self, labels, preds):
        """Accumulate the normalised error of one batch.

        Only the last entry of ``preds`` is evaluated, against the first
        entry of ``labels``.
        """
        self.count += 1
        preds = [preds[-1]]
        for label, pred_label in zip(labels, preds):
            label = label.asnumpy()
            pred_label = pred_label.asnumpy()
            label_h, label_w = label.shape[2], label.shape[3]

            nme = []
            # ``range`` rather than ``xrange`` so this also runs on Python 3.
            for b in range(pred_label.shape[0]):
                for p in range(pred_label.shape[1]):
                    heatmap_gt = label[b][p]
                    # cv2.resize takes dsize as (width, height); passing
                    # (shape[2], shape[3]) transposed non-square labels.
                    heatmap_pred = cv2.resize(pred_label[b][p], (label_w, label_h))
                    ind_gt = np.array(
                        np.unravel_index(np.argmax(heatmap_gt, axis=None), heatmap_gt.shape))
                    ind_pred = np.array(
                        np.unravel_index(np.argmax(heatmap_pred, axis=None), heatmap_pred.shape))
                    dist = np.sqrt(np.sum(np.square(ind_gt - ind_pred)))
                    nme.append(dist)
            nme = np.mean(nme)
            nme /= np.sqrt(float(label_h * label_w))

            self.sum_metric += nme
            self.num_inst += 1.0
ind_pred))) + item.append(value) + _nme = np.mean(item) + if args.norm_type=='2d': + left_eye = (record[0]+record[1])/2 + right_eye = (record[2]+record[3])/2 + _dist = np.sqrt(np.sum(np.square(left_eye - right_eye))) + #print('eye dist', _dist, left_eye, right_eye) + _nme /= _dist + else: + #_dist = np.sqrt(float(label.shape[2]*label.shape[3])) + _dist = np.sqrt(np.sum(np.square(record[5] - record[4]))) + #print(_dist) + _nme /= _dist + nme.append(_nme) + #print('nme', nme) + #nme = np.mean(nme) + + if len(nme)>0: + self.sum_metric += np.mean(nme) + self.num_inst += 1.0 + +def main(args): + _seed = 727 + random.seed(_seed) + np.random.seed(_seed) + mx.random.seed(_seed) + ctx = [] + cvd = os.environ['CUDA_VISIBLE_DEVICES'].strip() + if len(cvd)>0: + for i in xrange(len(cvd.split(','))): + ctx.append(mx.gpu(i)) + if len(ctx)==0: + ctx = [mx.cpu()] + print('use cpu') + else: + print('gpu num:', len(ctx)) + #ctx = [mx.gpu(0)] + args.ctx_num = len(ctx) + + args.batch_size = args.per_batch_size*args.ctx_num + + + print('Call with', args) + train_iter = FaceSegIter(path_imgrec = os.path.join(args.data_dir, 'train.rec'), + batch_size = args.batch_size, + per_batch_size = args.per_batch_size, + aug_level = 1, + use_coherent = args.use_coherent, + args = args, + ) + targets = ['ibug', 'cofw_testset', '300W', 'AFLW2000-3D'] + + data_shape = train_iter.get_data_shape() + #label_shape = train_iter.get_label_shape() + sym = hg.get_symbol(num_classes=args.num_classes, binarize=args.binarize, label_size=args.output_label_size, input_size=args.input_img_size, use_coherent = args.use_coherent, use_dla = args.use_dla, use_N = args.use_N, use_DCN = args.use_DCN, per_batch_size = args.per_batch_size) + if len(args.pretrained)==0: + #data_shape_dict = {'data' : (args.per_batch_size,)+data_shape, 'softmax_label' : (args.per_batch_size,)+label_shape} + data_shape_dict = train_iter.get_shape_dict() + arg_params, aux_params = hg.init_weights(sym, data_shape_dict) + else: + vec = 
args.pretrained.split(',') + print('loading', vec) + _, arg_params, aux_params = mx.model.load_checkpoint(vec[0], int(vec[1])) + #sym, arg_params, aux_params = get_symbol(args, arg_params, aux_params) + + model = mx.mod.Module( + context = ctx, + symbol = sym, + label_names = train_iter.get_label_names(), + ) + #lr = 1.0e-3 + #lr = 2.5e-4 + lr = args.lr + #_rescale_grad = 1.0 + _rescale_grad = 1.0/args.ctx_num + #lr = args.lr + #opt = optimizer.SGD(learning_rate=lr, momentum=0.9, wd=5.e-4, rescale_grad=_rescale_grad) + #opt = optimizer.Adam(learning_rate=lr, wd=args.wd, rescale_grad=_rescale_grad) + opt = optimizer.Nadam(learning_rate=lr, wd=args.wd, rescale_grad=_rescale_grad, clip_gradient=5.0) + #opt = optimizer.RMSProp(learning_rate=lr, wd=args.wd, rescale_grad=_rescale_grad) + initializer = mx.init.Xavier(rnd_type='gaussian', factor_type="in", magnitude=2) + _cb = mx.callback.Speedometer(args.batch_size, 10) + _metric = LossValueMetric() + #_metric2 = AccMetric() + #eval_metrics = [_metric, _metric2] + eval_metrics = [_metric] + + #lr_steps = [40000,60000,80000] + #lr_steps = [12000,18000,22000] + if len(args.lr_steps)==0: + lr_steps = [16000,24000,30000] + #lr_steps = [14000,24000,30000] + #lr_steps = [5000,10000] + else: + lr_steps = [int(x) for x in args.lr_steps.split(',')] + _a = 40//args.batch_size + for i in xrange(len(lr_steps)): + lr_steps[i] *= _a + print('lr-steps', lr_steps) + global_step = [0] + + def val_test(): + all_layers = sym.get_internals() + vsym = all_layers['heatmap_output'] + vmodel = mx.mod.Module(symbol=vsym, context=ctx, label_names = None) + #model.bind(data_shapes=[('data', (args.batch_size, 3, image_size[0], image_size[1]))], label_shapes=[('softmax_label', (args.batch_size,))]) + vmodel.bind(data_shapes=[('data', (args.batch_size,)+data_shape)]) + arg_params, aux_params = model.get_params() + vmodel.set_params(arg_params, aux_params) + for target in targets: + _file = os.path.join(args.data_dir, '%s.rec'%target) + if not 
os.path.exists(_file): + continue + val_iter = FaceSegIter(path_imgrec = _file, + batch_size = args.batch_size, + #batch_size = 4, + aug_level = 0, + args = args, + ) + _metric = NMEMetric2() + val_metric = mx.metric.create(_metric) + val_metric.reset() + val_iter.reset() + for i, eval_batch in enumerate(val_iter): + #print(eval_batch.data[0].shape, eval_batch.label[0].shape) + batch_data = mx.io.DataBatch(eval_batch.data) + model.forward(batch_data, is_train=False) + model.update_metric(val_metric, eval_batch.label) + nme_value = val_metric.get_name_value()[0][1] + print('[%d][%s]NME: %f'%(global_step[0], target, nme_value)) + + def _batch_callback(param): + _cb(param) + global_step[0]+=1 + mbatch = global_step[0] + for _lr in lr_steps: + if mbatch==_lr: + opt.lr *= 0.2 + print('lr change to', opt.lr) + break + if mbatch%1000==0: + print('lr-batch-epoch:',opt.lr,param.nbatch,param.epoch) + if mbatch>0 and mbatch%args.verbose==0: + val_test() + if args.ckpt>0: + msave = mbatch//args.verbose + print('saving', msave) + arg, aux = model.get_params() + mx.model.save_checkpoint(args.prefix, msave, model.symbol, arg, aux) + + + model.fit(train_iter, + begin_epoch = 0, + num_epoch = args.end_epoch, + #eval_data = val_iter, + eval_data = None, + eval_metric = eval_metrics, + kvstore = 'device', + optimizer = opt, + initializer = initializer, + arg_params = arg_params, + aux_params = aux_params, + allow_missing = True, + batch_end_callback = _batch_callback, + epoch_end_callback = None, + ) + +if __name__ == '__main__': + parser = argparse.ArgumentParser(description='Train face 3d') + # general + parser.add_argument('--data-dir', default='./data', help='') + parser.add_argument('--prefix', default='./models/test', + help='directory to save model.') + parser.add_argument('--pretrained', default='', + help='') + parser.add_argument('--lr-steps', default='', type=str, help='') + parser.add_argument('--verbose', type=int, default=2000, help='') + 
parser.add_argument('--retrain', action='store_true', default=False, + help='true means continue training.') + parser.add_argument('--binarize', action='store_true', default=False, help='') + parser.add_argument('--end-epoch', type=int, default=87, + help='training epoch size.') + parser.add_argument('--per-batch-size', type=int, default=20, help='') + parser.add_argument('--num-classes', type=int, default=68, help='') + parser.add_argument('--input-img-size', type=int, default=128, help='') + parser.add_argument('--output-label-size', type=int, default=64, help='') + parser.add_argument('--lr', type=float, default=2.5e-4, help='') + parser.add_argument('--wd', type=float, default=5e-4, help='') + parser.add_argument('--ckpt', type=int, default=1, help='') + parser.add_argument('--norm-type', type=str, default='2d', help='') + parser.add_argument('--use-coherent', type=int, default=1, help='') + parser.add_argument('--use-dla', type=int, default=1, help='') + parser.add_argument('--use-N', type=int, default=3, help='') + parser.add_argument('--use-DCN', type=int, default=2, help='') + args = parser.parse_args() + main(args) +