68points face alignment

This commit is contained in:
Jia Guo
2018-05-16 00:07:30 +08:00
parent a97ba7f947
commit 78d22c8cea
7 changed files with 2089 additions and 0 deletions

65
alignment/benchmark.py Normal file
View File

@@ -0,0 +1,65 @@
import argparse
import cv2
import numpy as np
import sys
import mxnet as mx
import datetime
# Benchmark script: loads a checkpoint, feeds one test image replicated across
# the batch, and prints the mean forward time per image in seconds.
parser = argparse.ArgumentParser(description='face model test')
# general
parser.add_argument('--image-size', default='128,128', help='')
parser.add_argument('--model', default='./models/test,15', help='path to load model.')
parser.add_argument('--gpu', default=0, type=int, help='gpu id')
parser.add_argument('--batch-size', default=10, type=int, help='batch size')
parser.add_argument('--iterations', default=10, type=int, help='iterations')
args = parser.parse_args()

# --image-size is "H,W": element 0 feeds the height axis of the NCHW bind shape.
_vec = args.image_size.split(',')
assert len(_vec)==2
image_size = (int(_vec[0]), int(_vec[1]))
# --model is "prefix,epoch", the two arguments of load_checkpoint.
_vec = args.model.split(',')
assert len(_vec)==2
prefix = _vec[0]
epoch = int(_vec[1])
print('loading',prefix, epoch)
# A negative --gpu selects the CPU.
if args.gpu>=0:
    ctx = mx.gpu(args.gpu)
else:
    ctx = mx.cpu()

sym, arg_params, aux_params = mx.model.load_checkpoint(prefix, epoch)
# Keep only the sub-graph that ends at the 'heatmap_output' node.
all_layers = sym.get_internals()
sym = all_layers['heatmap_output']
model = mx.mod.Module(symbol=sym, context=ctx, label_names = None)
model.bind(for_training=False, data_shapes=[('data', (args.batch_size, 3, image_size[0], image_size[1]))])
model.set_params(arg_params, aux_params)

img_path = './test.png'
img = cv2.imread(img_path)
if img is None:
    # cv2.imread signals failure by returning None instead of raising.
    raise IOError('cannot read image: %s' % img_path)
rimg = cv2.resize(img, (image_size[1], image_size[0]))  # dsize is (width, height)
img = cv2.cvtColor(rimg, cv2.COLOR_BGR2RGB)
img = np.transpose(img, (2,0,1)) # HWC -> CHW, RGB
# Same (H, W) order as the bind shape above; the original allocated (W, H),
# which only worked for square inputs.
input_blob = np.zeros( (args.batch_size, 3, image_size[0], image_size[1]),dtype=np.uint8 )
for i in range(args.batch_size):
    input_blob[i] = img
data = mx.nd.array(input_blob)
print(data.shape)
label = mx.nd.zeros( (args.batch_size, 84, 64, 64) )
db = mx.io.DataBatch(data=(data,), label=(label,))

stat = []
warmup = 2  # leading iterations excluded from the reported mean
for i in range(args.iterations+warmup):
    time_now = datetime.datetime.now()
    model.forward(db, is_train=False)
    # asnumpy() copies the result to the host, so the timing includes it.
    output = model.get_outputs()[-1].asnumpy()
    time_now2 = datetime.datetime.now()
    diff = time_now2 - time_now
    stat.append(diff.total_seconds())
stat = stat[warmup:]
print(np.mean(stat)/args.batch_size)

545
alignment/data.py Normal file
View File

@@ -0,0 +1,545 @@
# pylint: skip-file
import mxnet as mx
import numpy as np
import sys, os
import random
import math
import scipy.misc
import cv2
import logging
import sklearn
import datetime
import img_helper
from mxnet.io import DataIter
from mxnet import ndarray as nd
from mxnet import io
from mxnet import recordio
from PIL import Image
class FaceSegIter0(DataIter):
    """Sequential, un-augmented batch iterator over an indexed RecordIO file.

    Every record is decoded to an image and its header label is reshaped to
    ``(68, 2)``.  Batches carry data of shape ``(batch_size, 3, 384, 384)``
    and labels of shape ``(batch_size, 68, 2)``.
    """

    def __init__(self, batch_size,
                 path_imgrec = None,
                 data_name = "data",
                 label_name = "softmax_label"):
        self.batch_size = batch_size
        self.data_name = data_name
        self.label_name = label_name
        assert path_imgrec
        logging.info('loading recordio %s...',
                     path_imgrec)
        # The index file shares the record file's name, with the last four
        # characters (the extension) replaced by ".idx".
        idx_path = path_imgrec[0:-4]+".idx"
        self.imgrec = mx.recordio.MXIndexedRecordIO(idx_path, path_imgrec, 'r') # pylint: disable=redefined-variable-type
        self.seq = list(self.imgrec.keys)
        self.cur = 0
        self.reset()
        self.num_classes = 68
        self.record_img_size = 384
        self.input_img_size = self.record_img_size
        side = self.input_img_size
        self.data_shape = (3, side, side)
        self.label_shape = (self.num_classes, 2)
        self.provide_data = [(data_name, (batch_size,) + self.data_shape)]
        self.provide_label = [(label_name, (batch_size,) + self.label_shape)]

    def reset(self):
        """Resets the iterator to the beginning of the data."""
        print('reset')
        self.cur = 0

    def next_sample(self):
        """Read the record under the cursor and return ``(image, landmarks)``.

        Raises StopIteration once the cursor has passed the last record.
        """
        if self.cur >= len(self.seq):
            raise StopIteration
        record_key = self.seq[self.cur]
        self.cur += 1
        header, encoded = recordio.unpack(self.imgrec.read_idx(record_key))
        image = mx.image.imdecode(encoded).asnumpy()
        landmarks = np.array(header.label).reshape( (self.num_classes,2) )
        return image, landmarks

    def next(self):
        """Returns the next batch of data.

        A final short batch is returned with ``pad`` set to the number of
        unfilled slots; StopIteration is raised only when no sample at all
        could be read.
        """
        batch_size = self.batch_size
        batch_data = nd.empty((batch_size,)+self.data_shape)
        batch_label = nd.empty((batch_size,)+self.label_shape)
        filled = 0
        try:
            while filled < batch_size:
                image, landmarks = self.next_sample()
                # Move the channel axis first: (H, W, C) -> (C, H, W).
                chw = nd.transpose(nd.array(image), axes=(2, 0, 1))
                target = nd.array(landmarks)
                batch_data[filled][:] = chw
                batch_label[filled][:] = target
                filled += 1
        except StopIteration:
            if filled == 0:
                raise
        return mx.io.DataBatch([batch_data], [batch_label], batch_size - filled)
class FaceSegIter(DataIter):
def __init__(self, batch_size,
             per_batch_size = 0,
             path_imgrec = None,
             aug_level = 0,
             force_mirror = False,
             use_coherent = 0,
             args = None,
             data_name = "data",
             label_name = "softmax_label"):
    """Open the indexed RecordIO file and set up batch shapes.

    Args:
        batch_size: number of samples per returned batch.
        per_batch_size: per-device slice size; ``next`` uses it to place
            the paired sample when ``use_coherent`` is set.
        path_imgrec: path of the record file; the index file is expected
            at the same path with the last four characters replaced by
            ``.idx``.
        aug_level: ``> 0`` enables shuffling on reset and random
            augmentation, and takes the heatmap size from
            ``args.output_label_size``; otherwise heatmaps use
            ``args.input_img_size``.
        force_mirror: always apply the horizontal flip in ``do_aug``.
        use_coherent: non-zero adds a ``coherent_label`` output of shape
            ``(batch_size, 6)``.
        args: namespace providing ``num_classes``, ``input_img_size`` and,
            when ``aug_level > 0``, ``output_label_size``.
        data_name: name of the data entry in ``provide_data``.
        label_name: name of the heatmap entry in ``provide_label``.

    Raises:
        AssertionError: if ``path_imgrec`` is empty.
        ValueError: if ``args`` is not supplied.
    """
    self.aug_level = aug_level
    self.force_mirror = force_mirror
    self.use_coherent = use_coherent
    self.batch_size = batch_size
    self.per_batch_size = per_batch_size
    self.data_name = data_name
    self.label_name = label_name
    assert path_imgrec
    if args is None:
        # Fail before touching the file instead of with an AttributeError
        # on ``None.num_classes`` further down.
        raise ValueError('args must provide num_classes and input_img_size '
                         '(and output_label_size when aug_level > 0)')
    logging.info('loading recordio %s...',
                 path_imgrec)
    path_imgidx = path_imgrec[0:-4]+".idx"
    self.imgrec = mx.recordio.MXIndexedRecordIO(path_imgidx, path_imgrec, 'r') # pylint: disable=redefined-variable-type
    self.seq = list(self.imgrec.keys)
    print('train size', len(self.seq))
    self.cur = 0
    self.reset()
    self.num_classes = args.num_classes
    self.record_img_size = 384
    self.input_img_size = args.input_img_size
    self.data_shape = (3, self.input_img_size, self.input_img_size)
    self.label_classes = self.num_classes
    if aug_level>0:
        self.output_label_size = args.output_label_size
        self.label_shape = (self.label_classes, self.output_label_size, self.output_label_size)
    else:
        self.output_label_size = args.input_img_size
        self.label_shape = (self.num_classes, self.output_label_size, self.output_label_size)
    self.provide_data = [(data_name, (batch_size,) + self.data_shape)]
    self.provide_label = [(label_name, (batch_size,) + self.label_shape)]
    if self.use_coherent>0:
        self.provide_label += [("coherent_label", (batch_size,6))]
    self.img_num = 0
    self.invalid_num = 0
    self.mode = 1   # selects the crop strategy in next_sample / do_aug
    self.vis = 0    # > 0 dumps that many debug crops from do_aug
    self.stats = [0,0]
    # 1-based landmark index pairs whose heatmaps are swapped when a sample
    # is flipped horizontally (do_aug subtracts 1 before indexing).
    # The original list carried (49,55) twice; the swap is idempotent, so
    # dropping the duplicate does not change the result.
    self.mirror_set = [
        (22,23),
        (21,24),
        (20,25),
        (19,26),
        (18,27),
        (40,43),
        (39,44),
        (38,45),
        (37,46),
        (42,47),
        (41,48),
        (33,35),
        (32,36),
        (51,53),
        (50,54),
        (62,64),
        (61,65),
        (49,55),
        (68,66),
        (60,56),
        (59,57),
        (1,17),
        (2,16),
        (3,15),
        (4,14),
        (5,13),
        (6,12),
        (7,11),
        (8,10),
    ]
def get_data_shape(self):
    """Return the per-sample data shape tuple stored in ``self.data_shape``."""
    return self.data_shape
#def get_label_shape(self):
# return self.label_shape
def get_shape_dict(self):
    """Return a ``{name: shape}`` dict covering every data and label entry.

    Label entries are applied after data entries, so a label with the same
    name as a data entry wins.
    """
    shapes = {name: shape for name, shape in self.provide_data}
    shapes.update((name, shape) for name, shape in self.provide_label)
    return shapes
def get_label_names(self):
    """Return the label names from ``provide_label``, in declaration order."""
    return [name for name, _shape in self.provide_label]
def reset(self):
    """Resets the iterator to the beginning of the data.

    The record order is reshuffled in place first when augmentation is
    enabled (``aug_level > 0``).
    """
    print('reset')
    augmenting = self.aug_level > 0
    if augmenting:
        random.shuffle(self.seq)
    self.cur = 0
def next_sample(self):
    """Helper function for reading in next sample.

    Reads the record under the cursor, decodes the image, crops it
    according to ``self.mode`` and shifts the landmarks into the crop's
    coordinate frame.

    Returns:
        (img, hlabel, annot): the cropped image array, the
        ``(num_classes, 2)`` landmark array relative to the crop origin,
        and a dict holding ``'invalid'`` (True when any landmark lay
        outside the crop) plus ``'scale'`` (256) when ``self.mode == 1``.

    Raises:
        StopIteration: once every key in ``self.seq`` has been consumed.
    """
    if self.cur >= len(self.seq):
        raise StopIteration
    idx = self.seq[self.cur]
    self.cur += 1
    s = self.imgrec.read_idx(idx)
    header, img = recordio.unpack(s)
    img = mx.image.imdecode(img).asnumpy()
    #label = np.zeros( (self.num_classes, self.record_img_size, self.record_img_size), dtype=np.uint8)
    hlabel = np.array(header.label).reshape( (self.num_classes,2) )
    annot = {}
    # Bounding box of all landmarks (integer-truncated); only the mode 0
    # crop below consumes it.
    ul = np.array( (50000,50000), dtype=np.int32)
    br = np.array( (0,0), dtype=np.int32)
    for i in xrange(hlabel.shape[0]):
        #hlabel[i] = hlabel[i][::-1]
        # Component 0 is used as the row index and component 1 as the
        # column index when cropping below.
        h = int(hlabel[i][0])
        w = int(hlabel[i][1])
        key = np.array((h,w))
        #print(key.shape, ul.shape, br.shape)
        ul = np.minimum(key, ul)
        br = np.maximum(key, br)
        #label[h][w] = i+1
        #label[i][h][w] = 1.0
        #print(h,w,i+1)
    if self.mode==0:
        # Landmark bounding box grown by fixed margins, clamped to the image.
        ul_margin = np.array( (60,30) )
        br_margin = np.array( (30,30) )
        crop_ul = ul
        crop_ul = ul-ul_margin
        crop_ul = np.maximum(crop_ul, np.array( (0,0), dtype=np.int32) )
        crop_br = br
        crop_br = br+br_margin
        crop_br = np.minimum(crop_br, np.array( (img.shape[0],img.shape[1]), dtype=np.int32 ) )
    elif self.mode==1:
        # Keep the full record_img_size square and hand do_aug a fixed scale.
        #mm = (self.record_img_size - 256)//2
        #crop_ul = (mm, mm)
        #crop_br = (self.record_img_size-mm, self.record_img_size-mm)
        crop_ul = (0,0)
        crop_br = (self.record_img_size, self.record_img_size)
        annot['scale'] = 256
        #annot['center'] = np.array( (self.record_img_size/2+10, self.record_img_size/2) )
    else:
        # Fixed inset: 48 rows / 64 columns removed from each side.
        mm = (48, 64)
        #mm = (64, 80)
        crop_ul = mm
        crop_br = (self.record_img_size-mm[0], self.record_img_size-mm[1])
    crop_ul = np.array(crop_ul)
    crop_br = np.array(crop_br)
    img = img[crop_ul[0]:crop_br[0],crop_ul[1]:crop_br[1],:]
    #print(img.shape, crop_ul, crop_br)
    invalid = False
    for i in xrange(hlabel.shape[0]):
        # A landmark outside [crop_ul, crop_br) flags the whole sample.
        if (hlabel[i]<crop_ul).any() or (hlabel[i] >= crop_br).any():
            invalid = True
        # Re-express the landmark relative to the crop origin.
        hlabel[i] -= crop_ul
    #mm = np.amin(ul)
    #mm2 = self.record_img_size - br
    #mm2 = np.amin(mm2)
    #mm = min(mm, mm2)
    #print('mm',mm, ul, br)
    #print('invalid', invalid)
    if invalid:
        self.invalid_num+=1
    annot['invalid'] = invalid
    #annot['scale'] = (self.record_img_size - mm*2)*1.1
    #if self.mode==1:
    #  annot['scale'] = 256
    #print(annot)
    #center = ul+br
    #center /= 2.0
    #annot['center'] = center
    #img = img[ul[0]:br[0],ul[1]:br[1],:]
    #scale = br-ul
    #scale = max(scale[0], scale[1])
    #print(img.shape)
    return img, hlabel, annot
def do_aug(self, data, label, annot):
    """Crop (and optionally rotate / rescale / mirror) one sample and render
    one heatmap per landmark.

    Args:
        data: image array as returned by ``next_sample``.
        label: ``(num_landmarks, 2)`` landmark array; each row is reversed
            before being passed to ``img_helper.transform``.
        annot: dict that may carry ``'scale'`` and ``'center'``; missing
            values default to the larger image side and the image centre.

    Returns:
        ``None`` when no attempt rendered every landmark within the retry
        budget (3 tries, or 6 when ``aug_level == 0``).  Otherwise
        ``(cropped, label_out)`` when ``use_coherent`` is falsy, else
        ``(cropped, label_out, cropped2, label2_out, M)`` where ``M`` is a
        flattened 2x3 rotation matrix built from ``rotate``.
    """
    if self.vis:
        self.img_num+=1
    rotate = 0
    # Defined up front: the use_coherent==2 paths read rotate2 even when
    # aug_level == 0, where the loop below never assigns it (the original
    # raised NameError in that configuration).
    rotate2 = 0
    if 'scale' in annot:
        scale = annot['scale']
    else:
        scale = max(data.shape[0], data.shape[1])
    if 'center' in annot:
        center = annot['center']
    else:
        center = np.array( (data.shape[0]/2, data.shape[1]/2) )
    max_retry = 3
    if self.aug_level==0:
        max_retry = 6
    retry = 0
    found = False
    _scale = scale
    out_size = (self.input_img_size, self.input_img_size)
    hm_size = (self.output_label_size, self.output_label_size)
    while retry<max_retry:
        retry+=1
        succ = True
        if self.aug_level>0:
            # Random rotation in [-40, 40) and a scale factor drawn around 1
            # and clamped to [0.8, 1.2].
            rotate = np.random.randint(-40, 40)
            rotate2 = 0
            scale_config = 0.2
            _scale = min(1+scale_config, max(1-scale_config, (np.random.randn() * scale_config) + 1))
            _scale *= scale
            _scale = int(_scale)
        if self.mode==1:
            cropped = img_helper.crop2(data, center, _scale, out_size, rot=rotate)
            if self.use_coherent==2:
                cropped2 = img_helper.crop2(data, center, _scale, out_size, rot=rotate2)
        else:
            cropped = img_helper.crop(data, center, _scale, out_size, rot=rotate)
            if self.use_coherent==2:
                # Mirrors the mode==1 branch; previously cropped2 was left
                # undefined here and the final return raised NameError.
                cropped2 = img_helper.crop(data, center, _scale, out_size, rot=rotate2)
        label_out = np.zeros(self.label_shape, dtype=np.float32)
        label2_out = np.zeros(self.label_shape, dtype=np.float32)
        # G == 0 means "one-hot" targets: heatmaps are rendered with _g = 1
        # and collapsed to their argmax further down.
        G = 0
        _g = G
        if G==0:
            _g = 1
        for i in range(label.shape[0]):
            pt = label[i].copy()
            pt = pt[::-1]
            _pt = pt.copy()
            trans = img_helper.transform(_pt, center, _scale, hm_size, rot=rotate)
            # A falsy result from gaussian() aborts this attempt
            # (presumably the point fell outside the map -- confirm in img_helper).
            if not img_helper.gaussian(label_out[i], trans, _g):
                succ = False
                break
            if self.use_coherent==2:
                _pt = pt.copy()
                trans2 = img_helper.transform(_pt, center, _scale, hm_size, rot=rotate2)
                if not img_helper.gaussian(label2_out[i], trans2, _g):
                    succ = False
                    break
        if not succ:
            # Without random augmentation the only thing that can change
            # between attempts is the scale, so widen the crop.
            if self.aug_level==0:
                _scale+=20
            continue

        if self.use_coherent==1:
            # Second view = horizontally mirrored copy of the first, with
            # left/right landmark channels swapped.
            cropped2 = np.copy(cropped)
            for k in range(cropped2.shape[2]):
                cropped2[:,:,k] = np.fliplr(cropped2[:,:,k])
            label2_out = np.copy(label_out)
            for k in range(label2_out.shape[0]):
                label2_out[k,:,:] = np.fliplr(label2_out[k,:,:])
            new_label2_out = np.copy(label2_out)
            for item in self.mirror_set:
                mir = (item[0]-1, item[1]-1)
                new_label2_out[mir[1]] = label2_out[mir[0]]
                new_label2_out[mir[0]] = label2_out[mir[1]]
            label2_out = new_label2_out
        elif self.use_coherent==2:
            pass
        elif ((self.aug_level>0 and np.random.rand() < 0.5) or self.force_mirror): #flip aug
            for k in range(cropped.shape[2]):
                cropped[:,:,k] = np.fliplr(cropped[:,:,k])
            for k in range(label_out.shape[0]):
                label_out[k,:,:] = np.fliplr(label_out[k,:,:])
            new_label_out = np.copy(label_out)
            for item in self.mirror_set:
                mir = (item[0]-1, item[1]-1)
                new_label_out[mir[1]] = label_out[mir[0]]
                new_label_out[mir[0]] = label_out[mir[1]]
            label_out = new_label_out

        if G==0:
            # Collapse every heatmap to a single 1.0 at its maximum.
            for k in range(label.shape[0]):
                ind = np.unravel_index(np.argmax(label_out[k], axis=None), label_out[k].shape)
                label_out[k,:,:] = 0.0
                label_out[k,ind[0],ind[1]] = 1.0
                if self.use_coherent:
                    ind = np.unravel_index(np.argmax(label2_out[k], axis=None), label2_out[k].shape)
                    label2_out[k,:,:] = 0.0
                    label2_out[k,ind[0],ind[1]] = 1.0
        found = True
        break

    if not found:
        return None

    if self.vis>0 and self.img_num<=self.vis:
        # Debug dump of the crop with the landmark maxima drawn on it.
        # NOTE(review): scipy.misc.imsave is unavailable in recent SciPy
        # releases; this path only runs when self.vis > 0.
        print('crop', data.shape, center, _scale, rotate, cropped.shape)
        filename = './vis/cropped_%d.jpg' % (self.img_num)
        print('save', filename)
        draw = cropped.copy()
        alabel = label_out.copy()
        for i in range(label.shape[0]):
            a = cv2.resize(alabel[i], (self.input_img_size, self.input_img_size))
            ind = np.unravel_index(np.argmax(a, axis=None), a.shape)
            cv2.circle(draw, (ind[1], ind[0]), 1, (0, 0, 255), 2)
        scipy.misc.imsave(filename, draw)
    if not self.use_coherent:
        return cropped, label_out
    else:
        # Relative rotation between the two views, in radians.
        rotate2 = 0
        r = rotate - rotate2
        r = math.pi*r/180
        cos_r = math.cos(r)
        sin_r = math.sin(r)
        M = np.array( [
            [cos_r, -1*sin_r, 0.0],
            [sin_r, cos_r, 0.0]
        ] )
        return cropped, label_out, cropped2, label2_out, M.flatten()
def next(self):
    """Returns the next batch of data.

    Samples for which ``do_aug`` returns ``None`` or that fail
    ``check_valid_image`` are skipped.  When the records run out mid-batch
    the partial batch is returned with ``pad = batch_size - i``;
    StopIteration propagates only if ``i`` is still 0.

    With ``use_coherent`` set, each sample yields two entries: the first
    view at index ``i`` and the second at ``i + per_batch_size//2``; both
    share the same 6-value ``coherent_label``.
    """
    #print('next')
    batch_size = self.batch_size
    batch_data = nd.empty((batch_size,)+self.data_shape)
    batch_label = nd.empty((batch_size,)+self.label_shape)
    if self.use_coherent:
        # batch_label2 is allocated but not written or returned below.
        batch_label2 = nd.empty((batch_size,)+self.label_shape)
        batch_coherent_label = nd.empty((batch_size,6))
    i = 0
    #self.cutoff = random.randint(800,1280)
    try:
        while i < batch_size:
            #print('N', i)
            data, label, annot = self.next_sample()
            if not self.use_coherent:
                R = self.do_aug(data, label, annot)
                if R is None:
                    continue
                data, label = R
                #data, label, data2, label2, M = R
                #ind = np.unravel_index(np.argmax(label[0], axis=None), label[0].shape)
                #print(label.shape, np.count_nonzero(label[0]), ind)
                #print(label[0,25:35,0:10])
                data = nd.array(data)
                # Move the channel axis first: (H, W, C) -> (C, H, W).
                data = nd.transpose(data, axes=(2, 0, 1))
                label = nd.array(label)
                #print(data.shape, label.shape)
                try:
                    self.check_valid_image(data)
                except RuntimeError as e:
                    logging.debug('Invalid image, skipping: %s', str(e))
                    continue
                batch_data[i][:] = data
                batch_label[i][:] = label
                i += 1
            else:
                R = self.do_aug(data, label, annot)
                if R is None:
                    continue
                data, label, data2, label2, M = R
                data = nd.array(data)
                data = nd.transpose(data, axes=(2, 0, 1))
                label = nd.array(label)
                data2 = nd.array(data2)
                data2 = nd.transpose(data2, axes=(2, 0, 1))
                label2 = nd.array(label2)
                M = nd.array(M)
                #print(data.shape, label.shape)
                try:
                    self.check_valid_image(data)
                except RuntimeError as e:
                    logging.debug('Invalid image, skipping: %s', str(e))
                    continue
                batch_data[i][:] = data
                batch_label[i][:] = label
                #batch_label2[i][:] = label2
                batch_coherent_label[i][:] = M
                #i+=1
                # The second view goes half a per-device slice further on.
                # NOTE(review): per_batch_size defaults to 0, which makes the
                # modulo below divide by zero -- callers presumably always
                # pass it when use_coherent is set; confirm.
                j = i+self.per_batch_size//2
                batch_data[j][:] = data2
                batch_label[j][:] = label2
                batch_coherent_label[j][:] = M
                i += 1
                # Once the paired index reaches the end of a per-device
                # slice, jump i past the slots already filled with pairs.
                if j%self.per_batch_size==self.per_batch_size-1:
                    i = j+1
    except StopIteration:
        if not i:
            raise StopIteration
    #return {self.data_name  :  batch_data,
    #        self.label_name :  batch_label}
    #print(batch_data.shape, batch_label.shape)
    if not self.use_coherent:
        return mx.io.DataBatch([batch_data], [batch_label], batch_size - i)
    else:
        #return mx.io.DataBatch([batch_data], [batch_label, batch_label2, batch_coherent_label], batch_size - i)
        return mx.io.DataBatch([batch_data], [batch_label, batch_coherent_label], batch_size - i)
def check_data_shape(self, data_shape):
    """Validate that ``data_shape`` is a 3-element CxHxW shape with C == 3.

    Raises:
        ValueError: if the length is not 3 or the first entry is not 3.
    """
    if len(data_shape) != 3:
        raise ValueError('data_shape should have length 3, with dimensions CxHxW')
    channels = data_shape[0]
    if channels != 3:
        raise ValueError('This iterator expects inputs to have 3 channels.')
def check_valid_image(self, data):
    """Reject data whose first element is zero-dimensional.

    Raises:
        RuntimeError: if ``data[0].shape`` is empty.
    """
    first_plane_shape = data[0].shape
    if not len(first_plane_shape):
        raise RuntimeError('Data shape is wrong')
def imdecode(self, s):
    """Decodes a string or byte string to an NDArray.

    See mx.image.imdecode for more details.
    """
    # The original called a bare ``imdecode`` that is not defined or imported
    # in this module (NameError); ``next_sample`` already decodes records
    # through ``mx.image.imdecode``, so use the same function here.
    return mx.image.imdecode(s)
def read_image(self, fname):
    """Reads the file `fname` and returns its raw (still encoded) bytes.

    `fname` is resolved against ``self.path_root`` when that attribute is
    set to a non-empty value; otherwise it is opened as given.

    Example usage:
    ----------
    >>> dataIter.read_image('Face.jpg') # returns the file's raw bytes.
    """
    # ``path_root`` is never assigned in ``__init__``, so reading it
    # unconditionally raised AttributeError.
    root = getattr(self, 'path_root', None)
    path = os.path.join(root, fname) if root else fname
    with open(path, 'rb') as fin:
        img = fin.read()
    return img

70
alignment/draw.py Normal file
View File

@@ -0,0 +1,70 @@
import numpy as np
import skimage.draw
def line(img, pt1, pt2, color, width):
    """Draw a filled line segment of the given width on ``img`` in place.

    Args:
        img: image array; its shape bounds the drawn polygon.
        pt1, pt2: end points; element 0 is used as the column coordinate
            and element 1 as the row coordinate.
        color: value assigned to the covered pixels; make sure its
            dimension matches the number of channels in ``img``.
        width: line thickness in pixels.

    Returns:
        The same ``img`` object.
    """
    # Offset used to build the four polygon corners.
    # Builtin ``float`` replaces the ``np.float`` alias, which no longer
    # exists in current NumPy releases.
    diff = np.array([pt1[1] - pt2[1], pt1[0] - pt2[0]], float)
    mag = np.linalg.norm(diff)
    if mag >= 1:
        # Rescale the offset to half the requested width.
        diff *= width / (2 * mag)
        x = np.array([pt1[0] - diff[0], pt2[0] - diff[0], pt2[0] + diff[0], pt1[0] + diff[0]], int)
        y = np.array([pt1[1] + diff[1], pt2[1] + diff[1], pt2[1] - diff[1], pt1[1] - diff[1]], int)
    else:
        # End points (nearly) coincide: draw a width x width square at pt1.
        d = float(width) / 2
        x = np.array([pt1[0] - d, pt1[0] + d, pt1[0] + d, pt1[0] - d], int)
        y = np.array([pt1[1] - d, pt1[1] - d, pt1[1] + d, pt1[1] + d], int)
    # noinspection PyArgumentList
    rr, cc = skimage.draw.polygon(y, x, img.shape)
    img[rr, cc] = color
    return img
def limb(img, pt1, pt2, color, width):
    """Draw a limb between two joints, tolerating a missing annotation.

    A joint counts as present when its first coordinate is positive.  With
    both present a line is drawn; with only one present a circle of radius
    ``width`` is drawn at it; with neither, nothing is drawn.
    """
    has_first = pt1[0] > 0
    has_second = pt2[0] > 0
    if has_first and has_second:
        line(img, pt1, pt2, color, width)
        return
    if has_first:
        circle(img, pt1, color, width)
    elif has_second:
        circle(img, pt2, color, width)
def gaussian(img, pt, sigma):
    """Stamp an un-normalised 2D gaussian (peak value 1) onto ``img`` in place.

    Args:
        img: 2-D array to draw on.
        pt: centre; ``pt[0]`` is the column and ``pt[1]`` the row.
        sigma: standard deviation; the patch spans 3 sigma on each side.

    Returns:
        The same ``img`` object.  It is left untouched when the patch lies
        entirely outside the image.  Covered pixels are overwritten, not
        blended with their previous value.
    """
    # Patch corners as [col, row]: upper-left inclusive, bottom-right exclusive.
    ul = [int(pt[0] - 3 * sigma), int(pt[1] - 3 * sigma)]
    br = [int(pt[0] + 3 * sigma + 1), int(pt[1] + 3 * sigma + 1)]
    # Both upper bounds use >= now.  With the former ``ul[0] > shape[1]`` a
    # patch starting exactly at the right edge fell through to an empty
    # slice assignment; the result is the same, the exit is just earlier.
    if (ul[0] >= img.shape[1] or ul[1] >= img.shape[0] or
            br[0] < 0 or br[1] < 0):
        return img
    # Generate gaussian
    size = 6 * sigma + 1
    x = np.arange(0, size, 1, float)
    y = x[:, np.newaxis]
    x0 = y0 = size // 2
    # The gaussian is not normalized, we want the center value to equal 1
    g = np.exp(- ((x - x0) ** 2 + (y - y0) ** 2) / (2 * sigma ** 2))
    # Part of the patch that falls inside the image
    g_x = max(0, -ul[0]), min(br[0], img.shape[1]) - ul[0]
    g_y = max(0, -ul[1]), min(br[1], img.shape[0]) - ul[1]
    # Matching region of the image
    img_x = max(0, ul[0]), min(br[0], img.shape[1])
    img_y = max(0, ul[1]), min(br[1], img.shape[0])
    img[img_y[0]:img_y[1], img_x[0]:img_x[1]] = g[g_y[0]:g_y[1], g_x[0]:g_x[1]]
    return img
def circle(img, pt, color, radius):
    """Draw a filled circle on ``img`` in place and return ``img``.

    Args:
        img: image array; its shape clips the drawn pixels.
        pt: centre; ``pt[0]`` is the column and ``pt[1]`` the row.
        color: value assigned to the covered pixels.
        radius: circle radius in pixels.
    """
    # ``skimage.draw.circle`` was replaced by ``disk`` in newer scikit-image
    # releases; prefer ``disk`` when it exists and fall back otherwise.
    if hasattr(skimage.draw, 'disk'):
        rr, cc = skimage.draw.disk((pt[1], pt[0]), radius, shape=img.shape)
    else:
        rr, cc = skimage.draw.circle(pt[1], pt[0], radius, img.shape)
    img[rr, cc] = color
    return img

853
alignment/hg2.py Normal file
View File

@@ -0,0 +1,853 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import mxnet as mx
import numpy as np
# Bit-width passed as ``act_bit`` to the QActivation / QConvolution ops in the
# binarized branches below.
ACT_BIT = 1
# NOTE(review): N, use_STN, use_DLA and DCN are not referenced in the part of
# the file reviewed here -- presumably network-building switches consumed
# further down; confirm before changing them.
N = 4
use_STN = False
use_DLA = 0
DCN = 0
def Conv(**kwargs):
    """Create an ``mx.sym.Convolution`` from the given keyword arguments.

    Single indirection point for every plain convolution in this file.
    """
    return mx.sym.Convolution(**kwargs)
def Act(data, act_type, name):
    """Build an activation symbol.

    ``'prelu'`` is routed to ``mx.sym.LeakyReLU``; every other ``act_type``
    goes to ``mx.symbol.Activation``.
    """
    if act_type != 'prelu':
        return mx.symbol.Activation(data=data, act_type=act_type, name=name)
    return mx.sym.LeakyReLU(data = data, act_type='prelu', name = name)
def lin(data, num_filter, workspace, name, binarize, dcn):
    """Projection layer with three variants.

    * ``binarize``: BN -> ReLU -> 1x1 ``QConvolution_v1`` -> BN.
    * ``dcn`` (and not ``binarize``): BN -> ReLU -> 3x3 deformable conv,
      with offsets predicted by a separate 3x3 conv.
    * otherwise: 1x1 conv -> BN -> ReLU.
    """
    momentum = 0.9
    weight_bit = 1
    if binarize:
        pre_bn = mx.sym.BatchNorm(data=data, fix_gamma=False, momentum=momentum, eps=2e-5, name=name + '_bn')
        pre_act = Act(data=pre_bn, act_type='relu', name=name + '_relu')
        qconv = mx.sym.QConvolution_v1(data=pre_act, num_filter=num_filter, kernel=(1,1), stride=(1,1), pad=(0,0),
                                       no_bias=True, workspace=workspace, name=name + '_conv', act_bit=ACT_BIT, weight_bit=weight_bit)
        return mx.sym.BatchNorm(data=qconv, fix_gamma=False, momentum=momentum, eps=2e-5, name=name + '_bn2')
    if dcn:
        pre_bn = mx.sym.BatchNorm(data=data, fix_gamma=False, momentum=momentum, eps=2e-5, name=name + '_bn')
        pre_act = Act(data=pre_bn, act_type='relu', name=name + '_relu')
        offsets = mx.symbol.Convolution(name=name+'_conv_offset', data = pre_act,
                                        num_filter=18, pad=(1, 1), kernel=(3, 3), stride=(1, 1))
        return mx.contrib.symbol.DeformableConvolution(name=name+"_conv", data=pre_act, offset=offsets,
                                                       num_filter=num_filter, pad=(1,1), kernel=(3, 3), num_deformable_group=1, stride=(1, 1), dilate=(1, 1), no_bias=False)
    conv = Conv(data=data, num_filter=num_filter, kernel=(1,1), stride=(1,1), pad=(0,0),
                no_bias=True, workspace=workspace, name=name + '_conv')
    post_bn = mx.sym.BatchNorm(data=conv, fix_gamma=False, momentum=momentum, eps=2e-5, name=name + '_bn')
    return Act(data=post_bn, act_type='relu', name=name + '_relu')
def lin2(data, num_filter, workspace, name):
    """1x1 convolution followed by batch norm and ReLU."""
    momentum = 0.9
    projected = Conv(data=data, num_filter=num_filter, kernel=(1,1), stride=(1,1), pad=(0,0),
                     no_bias=True, workspace=workspace, name=name + '_conv')
    normed = mx.sym.BatchNorm(data=projected, fix_gamma=False, momentum=momentum, eps=2e-5, name=name + '_bn')
    return Act(data=normed, act_type='relu', name=name + '_relu')
def lin3(data, num_filter, workspace, name, k, g=1, d=1):
    """k x k grouped convolution followed by batch norm and ReLU.

    For ``k == 3`` the convolution is dilated by ``d`` and padded by ``d``;
    for any other ``k`` no dilation is passed and the padding is
    ``(k-1)//2``.  ``g`` is the number of convolution groups.
    """
    momentum = 0.9
    if k == 3:
        padding = (d, d)
        extra = {'dilate': (d, d)}
    else:
        half = (k-1)//2
        padding = (half, half)
        extra = {}
    conv = Conv(data=data, num_filter=num_filter, kernel=(k,k), stride=(1,1), pad=padding, num_group=g,
                no_bias=True, workspace=workspace, name=name + '_conv', **extra)
    normed = mx.sym.BatchNorm(data=conv, fix_gamma=False, momentum=momentum, eps=2e-5, name=name + '_bn')
    return Act(data=normed, act_type='relu', name=name + '_relu')
def lin3_red(data, num_filter, workspace, name, k, g=1):
    """Strided 3x3 grouped convolution, batch norm, then sigmoid.

    Here ``k`` is the stride, not the kernel size.  The activation is a
    sigmoid even though its symbol name ends in ``_relu``.
    """
    momentum = 0.9
    reduced = Conv(data=data, num_filter=num_filter, kernel=(3,3), stride=(k,k), pad=(1,1), num_group=g,
                   no_bias=True, workspace=workspace, name=name + '_conv', attr={'lr_mult':'1'})
    normed = mx.sym.BatchNorm(data=reduced, fix_gamma=False, momentum=momentum, eps=2e-5, name=name + '_bn', attr={'lr_mult':'1'})
    return Act(data=normed, act_type='sigmoid', name=name + '_relu')
class RES:
    """Recursively built, memoised grid of symbols indexed by ``(w, h)``.

    ``get_output(n, n)`` is the input; every other cell is derived from
    cells with larger indices.  ``get()`` returns the symbol of cell (1, 1).
    """
    def __init__(self, data, nFilters, nModules, n, workspace, name, dilate, group):
        self.data = data
        self.nFilters = nFilters
        # nModules is stored but not read by the methods of this class.
        self.nModules = nModules
        self.n = n
        self.workspace = workspace
        self.name = name
        self.dilate = dilate
        self.group = group
        # (w, h) -> (symbol, channel count); each cell is built only once.
        self.sym_map = {}
    def get_output(self, w, h):
        """Return ``(symbol, num_filters)`` for cell ``(w, h)``, building it on first use."""
        key = (w, h)
        if key in self.sym_map:
            return self.sym_map[key]
        ret = None
        if h==self.n:
            if w==self.n:
                # Base case: the input itself.
                ret = (self.data, self.nFilters)
            else:
                # Row h == n: each step towards smaller w halves the channels.
                x = self.get_output(w+1, h)
                f = int(x[1]*0.5)
                if w!=self.n-1:
                    body = lin3(x[0], f, self.workspace, "%s_w%d_h%d_1"%(self.name, w, h), 3, self.group, 1)
                else:
                    # Only the first reduction uses the configured dilation.
                    body = lin3(x[0], f, self.workspace, "%s_w%d_h%d_1"%(self.name, w, h), 3, self.group, self.dilate)
                ret = (body,f)
        else:
            x = self.get_output(w+1, h+1)
            y = self.get_output(w, h+1)
            if h%2==1 and h!=w:
                # Convolution with num_group equal to the channel count.
                xbody = lin3(x[0], x[1], self.workspace, "%s_w%d_h%d_2"%(self.name, w, h), 3, x[1])
                #xbody = xbody+x[0]
            else:
                xbody = x[0]
            #xbody = x[0]
            #xbody = lin3(x[0], x[1], self.workspace, "%s_w%d_h%d_2"%(self.name, w, h), 3, x[1])
            if w==0:
                ybody = lin3(y[0], y[1], self.workspace, "%s_w%d_h%d_3"%(self.name, w, h), 3, self.group)
            else:
                ybody = y[0]
            # Concatenate y with ybody along the channel axis, then average
            # with the x branch.
            ybody = mx.sym.concat(y[0], ybody, dim=1)
            body = mx.sym.add_n(xbody,ybody, name="%s_w%d_h%d_add"%(self.name, w, h))
            body = body/2
            ret = (body, x[1])
        self.sym_map[key] = ret
        return ret
    def get(self):
        """Return the symbol of cell (1, 1)."""
        return self.get_output(1, 1)[0]
def residual_unit_a(data, num_filter, stride, dim_match, name, binarize, dcn, dilate, **kwargs):
    """Pre-activation bottleneck residual unit (1x1 -> 3x3 -> 1x1).

    Args:
        data: input symbol.
        num_filter: output channels; the first two convolutions use half.
        stride: stride of the shortcut convolution (used only when
            ``dim_match`` is False).
        dim_match: True to use ``data`` itself as the shortcut.
        name: prefix for all symbol names.
        binarize: use ``QActivation`` / ``QConvolution`` instead of the
            ReLU / ``Conv`` pair.
        dcn, dilate: accepted but not used in this function body.
        **kwargs: optional ``bn_mom`` (0.9), ``workspace`` (256) and
            ``memonger`` (False).

    Returns:
        Symbol for ``conv3 + shortcut``.
    """
    bn_mom = kwargs.get('bn_mom', 0.9)
    workspace = kwargs.get('workspace', 256)
    memonger = kwargs.get('memonger', False)
    bit = 1
    #print('in unit2')
    # the same as https://github.com/facebook/fb.resnet.torch#notes, a bit difference with origin paper
    bn1 = mx.sym.BatchNorm(data=data, fix_gamma=False, eps=2e-5, momentum=bn_mom, name=name + '_bn1')
    if not binarize:
        act1 = Act(data=bn1, act_type='relu', name=name + '_relu1')
        conv1 = Conv(data=act1, num_filter=int(num_filter*0.5), kernel=(1,1), stride=(1,1), pad=(0,0),
                     no_bias=True, workspace=workspace, name=name + '_conv1')
    else:
        # NOTE(review): QActivation / QConvolution are presumably provided by
        # a quantisation-enabled MXNet build -- confirm.
        act1 = mx.sym.QActivation(data=bn1, act_bit=ACT_BIT, name=name + '_relu1', backward_only=True)
        conv1 = mx.sym.QConvolution(data=act1, num_filter=int(num_filter*0.5), kernel=(1,1), stride=(1,1), pad=(0,0),
                                    no_bias=True, workspace=workspace, name=name + '_conv1', act_bit=ACT_BIT, weight_bit=bit)
    bn2 = mx.sym.BatchNorm(data=conv1, fix_gamma=False, eps=2e-5, momentum=bn_mom, name=name + '_bn2')
    if not binarize:
        act2 = Act(data=bn2, act_type='relu', name=name + '_relu2')
        conv2 = Conv(data=act2, num_filter=int(num_filter*0.5), kernel=(3,3), stride=(1,1), pad=(1,1),
                     no_bias=True, workspace=workspace, name=name + '_conv2')
    else:
        act2 = mx.sym.QActivation(data=bn2, act_bit=ACT_BIT, name=name + '_relu2', backward_only=True)
        conv2 = mx.sym.QConvolution(data=act2, num_filter=int(num_filter*0.5), kernel=(3,3), stride=(1,1), pad=(1,1),
                                    no_bias=True, workspace=workspace, name=name + '_conv2', act_bit=ACT_BIT, weight_bit=bit)
    bn3 = mx.sym.BatchNorm(data=conv2, fix_gamma=False, eps=2e-5, momentum=bn_mom, name=name + '_bn3')
    if not binarize:
        act3 = Act(data=bn3, act_type='relu', name=name + '_relu3')
        conv3 = Conv(data=act3, num_filter=num_filter, kernel=(1,1), stride=(1,1), pad=(0,0), no_bias=True,
                     workspace=workspace, name=name + '_conv3')
    else:
        act3 = mx.sym.QActivation(data=bn3, act_bit=ACT_BIT, name=name + '_relu3', backward_only=True)
        conv3 = mx.sym.QConvolution(data=act3, num_filter=num_filter, kernel=(1,1), stride=(1,1), pad=(0,0),
                                    no_bias=True, workspace=workspace, name=name + '_conv3', act_bit=ACT_BIT, weight_bit=bit)
    #if binarize:
    #  conv3 = mx.sym.BatchNorm(data=conv3, fix_gamma=False, eps=2e-5, momentum=bn_mom, name=name + '_bn4')
    if dim_match:
        shortcut = data
    else:
        # The projection shortcut branches off after the first BN/activation
        # (act1), not from the raw input.
        if not binarize:
            shortcut = Conv(data=act1, num_filter=num_filter, kernel=(1,1), stride=stride, no_bias=True,
                            workspace=workspace, name=name+'_sc')
        else:
            shortcut = mx.sym.QConvolution(data=act1, num_filter=num_filter, kernel=(1,1), stride=stride, pad=(0,0),
                                           no_bias=True, workspace=workspace, name=name + '_sc', act_bit=ACT_BIT, weight_bit=bit)
    if memonger:
        # Sets the mirror_stage attribute on the shortcut symbol.
        shortcut._set_attr(mirror_stage='True')
    return conv3 + shortcut
def residual_unit_g(data, num_filter, stride, dim_match, name, binarize, dcn, dilation, **kwargs):
    """Residual unit made of three chained 3x3 convolutions whose outputs are concatenated.

    The convolutions produce ``num_filter/2``, ``num_filter/4`` and
    ``num_filter/4`` channels; their concatenation is added to the shortcut.

    Args:
        data: input symbol.
        num_filter: total output channels.
        stride: stride of the shortcut convolution (used only when
            ``dim_match`` is False).
        dim_match: True to use ``data`` itself as the shortcut.
        name: prefix for all symbol names.
        binarize: use ``QActivation`` / ``QConvolution_v1`` and add batch
            norms after the concat and the shortcut convolution.
        dcn: use deformable convolutions (non-binarized path only).
        dilation: dilation and padding of the plain convolutions; the
            deformable and binarized paths use fixed pad 1 / dilate 1.
        **kwargs: optional ``bn_mom`` (0.9), ``workspace`` (256) and
            ``memonger`` (False).

    Returns:
        Symbol for ``conv4 + shortcut``.
    """
    bn_mom = kwargs.get('bn_mom', 0.9)
    workspace = kwargs.get('workspace', 256)
    memonger = kwargs.get('memonger', False)
    bit = 1
    #print('in unit2')
    # the same as https://github.com/facebook/fb.resnet.torch#notes, a bit difference with origin paper
    bn1 = mx.sym.BatchNorm(data=data, fix_gamma=False, eps=2e-5, momentum=bn_mom, name=name + '_bn1')
    if not binarize:
        act1 = Act(data=bn1, act_type='relu', name=name + '_relu1')
        if not dcn:
            conv1 = Conv(data=act1, num_filter=int(num_filter*0.5), kernel=(3,3), stride=(1,1), pad=(dilation,dilation), dilate=(dilation,dilation),
                         no_bias=True, workspace=workspace, name=name + '_conv1')
        else:
            # 18 offset channels: 2 per tap of the 3x3 kernel, one deformable group.
            conv1_offset = mx.symbol.Convolution(name=name+'_conv1_offset', data = act1,
                                                 num_filter=18, pad=(1, 1), kernel=(3, 3), stride=(1, 1))
            conv1 = mx.contrib.symbol.DeformableConvolution(name=name+'_conv1', data=act1, offset=conv1_offset,
                                                            num_filter=int(num_filter*0.5), pad=(1,1), kernel=(3, 3), num_deformable_group=1, stride=(1, 1), dilate=(1, 1), no_bias=True)
    else:
        act1 = mx.sym.QActivation(data=bn1, act_bit=ACT_BIT, name=name + '_relu1', backward_only=True)
        conv1 = mx.sym.QConvolution_v1(data=act1, num_filter=int(num_filter*0.5), kernel=(3,3), stride=(1,1), pad=(1,1),
                                       no_bias=True, workspace=workspace, name=name + '_conv1', act_bit=ACT_BIT, weight_bit=bit)
    bn2 = mx.sym.BatchNorm(data=conv1, fix_gamma=False, eps=2e-5, momentum=bn_mom, name=name + '_bn2')
    if not binarize:
        act2 = Act(data=bn2, act_type='relu', name=name + '_relu2')
        if not dcn:
            conv2 = Conv(data=act2, num_filter=int(num_filter*0.25), kernel=(3,3), stride=(1,1), pad=(dilation,dilation), dilate=(dilation,dilation),
                         no_bias=True, workspace=workspace, name=name + '_conv2')
        else:
            conv2_offset = mx.symbol.Convolution(name=name+'_conv2_offset', data = act2,
                                                 num_filter=18, pad=(1, 1), kernel=(3, 3), stride=(1, 1))
            conv2 = mx.contrib.symbol.DeformableConvolution(name=name+'_conv2', data=act2, offset=conv2_offset,
                                                            num_filter=int(num_filter*0.25), pad=(1,1), kernel=(3, 3), num_deformable_group=1, stride=(1, 1), dilate=(1, 1), no_bias=True)
    else:
        act2 = mx.sym.QActivation(data=bn2, act_bit=ACT_BIT, name=name + '_relu2', backward_only=True)
        conv2 = mx.sym.QConvolution_v1(data=act2, num_filter=int(num_filter*0.25), kernel=(3,3), stride=(1,1), pad=(1,1),
                                       no_bias=True, workspace=workspace, name=name + '_conv2', act_bit=ACT_BIT, weight_bit=bit)
    bn3 = mx.sym.BatchNorm(data=conv2, fix_gamma=False, eps=2e-5, momentum=bn_mom, name=name + '_bn3')
    if not binarize:
        act3 = Act(data=bn3, act_type='relu', name=name + '_relu3')
        if not dcn:
            conv3 = Conv(data=act3, num_filter=int(num_filter*0.25), kernel=(3,3), stride=(1,1), pad=(dilation,dilation), dilate=(dilation,dilation),
                         no_bias=True, workspace=workspace, name=name + '_conv3')
        else:
            conv3_offset = mx.symbol.Convolution(name=name+'_conv3_offset', data = act3,
                                                 num_filter=18, pad=(1, 1), kernel=(3, 3), stride=(1, 1))
            conv3 = mx.contrib.symbol.DeformableConvolution(name=name+'_conv3', data=act3, offset=conv3_offset,
                                                            num_filter=int(num_filter*0.25), pad=(1,1), kernel=(3, 3), num_deformable_group=1, stride=(1, 1), dilate=(1, 1), no_bias=True)
    else:
        act3 = mx.sym.QActivation(data=bn3, act_bit=ACT_BIT, name=name + '_relu3', backward_only=True)
        conv3 = mx.sym.QConvolution_v1(data=act3, num_filter=int(num_filter*0.25), kernel=(3,3), stride=(1,1), pad=(1,1),
                                       no_bias=True, workspace=workspace, name=name + '_conv3', act_bit=ACT_BIT, weight_bit=bit)
    # Concatenate the three stage outputs (Concat's default axis).
    conv4 = mx.symbol.Concat(*[conv1, conv2, conv3])
    if binarize:
        conv4 = mx.sym.BatchNorm(data=conv4, fix_gamma=False, eps=2e-5, momentum=bn_mom, name=name + '_bn4')
    if dim_match:
        shortcut = data
    else:
        # The projection shortcut branches off after the first BN/activation.
        if not binarize:
            shortcut = Conv(data=act1, num_filter=num_filter, kernel=(1,1), stride=stride, no_bias=True,
                            workspace=workspace, name=name+'_sc')
        else:
            #assert(False)
            shortcut = mx.sym.QConvolution_v1(data=act1, num_filter=num_filter, kernel=(1,1), stride=stride, pad=(0,0),
                                              no_bias=True, workspace=workspace, name=name + '_sc', act_bit=ACT_BIT, weight_bit=bit)
            shortcut = mx.sym.BatchNorm(data=shortcut, fix_gamma=False, eps=2e-5, momentum=bn_mom, name=name + '_sc_bn')
    if memonger:
        # Sets the mirror_stage attribute on the shortcut symbol.
        shortcut._set_attr(mirror_stage='True')
    return conv4 + shortcut
    #return bn4 + shortcut
    #return act4 + shortcut
def ConvFactory(data, num_filter, kernel, stride=(1, 1), pad=(0, 0), act_type="relu", mirror_attr={}, with_act=True):
    """Build a Convolution -> BatchNorm (-> Activation) symbol chain.

    Args:
        data: input symbol.
        num_filter: number of convolution output channels.
        kernel: convolution kernel size.
        stride: convolution stride.
        pad: convolution padding.
        act_type: activation type handed to ``mx.symbol.Activation``.
        mirror_attr: attribute dict attached to the activation symbol; this
            function only forwards it and does not modify it.
        with_act: when false the batch-norm output is returned directly.

    Returns:
        The activation symbol, or the batch-norm symbol when ``with_act`` is
        false.
    """
    # Symbols are created in the order conv -> bn -> act so that MXNet's
    # automatic naming stays the same.
    normed = mx.symbol.BatchNorm(
        data=mx.symbol.Convolution(
            data=data, num_filter=num_filter, kernel=kernel, stride=stride, pad=pad))
    if not with_act:
        return normed
    return mx.symbol.Activation(
        data=normed, act_type=act_type, attr=mirror_attr)
def block17(net, input_num_channels, scale=1.0, with_act=True, act_type='relu', mirror_attr={}):
    """Two-branch residual block built from ``ConvFactory`` units.

    One branch is a single 1x1 unit; the other is a 1x1 unit followed by a
    (1, 7) and a (7, 1) unit. Both are concatenated, projected back to
    ``input_num_channels`` with a 1x1 unit (no activation), scaled by
    ``scale`` and added to ``net``.

    Args:
        net: input symbol.
        input_num_channels: channel count of the 1x1 projection.
        scale: multiplier applied to the projected branch before the sum.
        with_act: apply a final activation when true.
        act_type: activation type for the final activation.
        mirror_attr: attribute dict attached to the final activation.

    Returns:
        The summed symbol, activated if ``with_act`` is true.
    """
    # NOTE(review): the 129-filter width and the (1, 2)/(2, 1) pads look
    # unusual; they are kept as-is because changing them alters parameter
    # shapes.
    short_branch = ConvFactory(net, 192, (1, 1))
    long_branch = ConvFactory(net, 129, (1, 1))
    long_branch = ConvFactory(long_branch, 160, (1, 7), pad=(1, 2))
    long_branch = ConvFactory(long_branch, 192, (7, 1), pad=(2, 1))
    merged = mx.symbol.Concat(*[short_branch, long_branch])
    projected = ConvFactory(
        merged, input_num_channels, (1, 1), with_act=False)
    net = net + scale * projected
    if not with_act:
        return net
    return mx.symbol.Activation(
        data=net, act_type=act_type, attr=mirror_attr)
def block35(net, input_num_channels, scale=1.0, with_act=True, act_type='relu', mirror_attr={}):
    """Three-branch residual block built from ``ConvFactory`` units.

    Branch widths are fractions (0.25, 0.375, 0.5) of
    ``input_num_channels``. The branches are concatenated, projected back to
    ``input_num_channels`` with a 1x1 unit (no activation), scaled by
    ``scale`` and added to ``net``.

    Args:
        net: input symbol.
        input_num_channels: channel count used for branch widths and for the
            1x1 projection.
        scale: multiplier applied to the projected branch before the sum.
        with_act: apply a final activation when true.
        act_type: activation type for the final activation.
        mirror_attr: attribute dict attached to the final activation.

    Returns:
        The summed symbol, activated if ``with_act`` is true.
    """
    M = 1.0  # global width multiplier for all branches

    def _width(fraction):
        # Channel count for a branch layer, truncated to an int.
        return int(input_num_channels*fraction*M)

    branch_a = ConvFactory(net, _width(0.25), (1, 1))
    branch_b = ConvFactory(net, _width(0.25), (1, 1))
    branch_b = ConvFactory(branch_b, _width(0.25), (3, 3), pad=(1, 1))
    branch_c = ConvFactory(net, _width(0.25), (1, 1))
    branch_c = ConvFactory(branch_c, _width(0.375), (3, 3), pad=(1, 1))
    branch_c = ConvFactory(branch_c, _width(0.5), (3, 3), pad=(1, 1))
    merged = mx.symbol.Concat(*[branch_a, branch_b, branch_c])
    projected = ConvFactory(
        merged, input_num_channels, (1, 1), with_act=False)
    net = net + scale * projected
    if not with_act:
        return net
    return mx.symbol.Activation(
        data=net, act_type=act_type, attr=mirror_attr)
def residual_unit_i(data, num_filter, stride, dim_match, name, binarize, dcn, dilate, **kwargs):
    """Residual unit that uses ``block35`` when the shape is preserved.

    When the unit must downsample (``stride[0] > 1``) or change the channel
    count (``dim_match`` is false) the work is delegated to
    ``residual_unit_a`` with all arguments forwarded unchanged. Otherwise a
    ``block35`` with ``num_filter`` channels is returned; ``name``, ``dcn``,
    ``dilate`` and ``kwargs`` are not used on that path.

    Raises:
        AssertionError: if ``binarize`` is truthy (not supported here).
    """
    assert not binarize
    if stride[0]>1 or not dim_match:
        return residual_unit_a(data, num_filter, stride, dim_match, name, binarize, dcn, dilate, **kwargs)
    return block35(data, num_filter)
def residual_unit_cab(data, num_filter, stride, dim_match, name, binarize, dcn, dilate, **kwargs):
    """Residual unit backed by the ``RES`` builder when the shape is preserved.

    When the unit must downsample (``stride[0] > 1``) or change the channel
    count (``dim_match`` is false) the work is delegated to
    ``residual_unit_g`` with all arguments forwarded unchanged. Otherwise a
    ``RES`` object is constructed and its ``get()`` result returned;
    ``binarize`` and ``dcn`` are not used on that path.

    Keyword Args:
        workspace: forwarded to ``RES`` (default 256).
    """
    workspace = kwargs.get('workspace', 256)
    if stride[0]>1 or not dim_match:
        return residual_unit_g(data, num_filter, stride, dim_match, name, binarize, dcn, dilate, **kwargs)
    # Positional arguments 1 and 4 and the trailing 1 are passed through to
    # RES exactly as before; their meaning is defined by RES (not visible here).
    res = RES(data, num_filter, 1, 4, workspace, name, dilate, 1)
    return res.get()
def residual_unit(data, num_filter, stride, dim_match, name, binarize, dcn, dilate, **kwargs):
    """Single entry point for the residual-unit variant used by the network.

    Every argument is forwarded unchanged to ``residual_unit_cab``.
    """
    unit_builder = residual_unit_cab
    return unit_builder(data, num_filter, stride, dim_match, name, binarize, dcn, dilate, **kwargs)
def hourglass(data, nFilters, nModules, n, workspace, name, binarize, dcn):
    """Recursively build one hourglass module.

    The input is processed by an upper branch at full resolution and by a
    lower branch that is max-pooled by 2, processed (recursively while
    ``n > 1``), and upsampled by a grouped 2x2 deconvolution. Both branches
    are summed.

    Args:
        data: input symbol.
        nFilters: channel count used by every residual unit and by the
            deconvolution.
        nModules: number of residual units chained at each stage.
        n: remaining recursion depth.
        workspace: workspace size forwarded to the deconvolution.
        name: prefix for all symbol names created here.
        binarize: forwarded to ``residual_unit``.
        dcn: only forwarded to the recursive call; residual units in this
            function always receive ``False``.

    Returns:
        Symbol ``add_n(upper_branch, upsampled_lower_branch)``.
    """
    factor = 2
    unit_dcn = False

    def _chain(sym, name_fmt):
        # Apply nModules residual units in sequence; name_fmt takes (name, index).
        for idx in xrange(nModules):
            sym = residual_unit(sym, nFilters, (1,1), True, name_fmt%(name,idx), binarize, unit_dcn, 1)
        return sym

    upper = _chain(data, "%s_up1_%d")
    pooled = mx.sym.Pooling(data=data, kernel=(factor, factor), stride=(factor,factor), pad=(0,0), pool_type='max')
    lower = _chain(pooled, "%s_low1_%d")
    if n>1:
        # Recurse one level deeper on the pooled branch.
        lower = hourglass(lower, nFilters, nModules, n-1, workspace, "%s_%d"%(name, n-1), binarize, dcn)
    else:
        # Innermost level: plain residual units instead of another hourglass.
        lower = _chain(lower, "%s_low2_%d")
    lower = _chain(lower, "%s_low3_%d")
    # The upsampling deconvolution is frozen (lr_mult and wd_mult are 0).
    upsampled = mx.symbol.Deconvolution(data=lower, num_filter=nFilters, kernel=(factor,factor),
        stride=(factor, factor),
        num_group=nFilters, no_bias=True, name='%s_upsampling_%s'%(name,n),
        attr={'lr_mult': '0.0', 'wd_mult': '0.0'}, workspace=workspace)
    return mx.symbol.add_n(upper, upsampled)
def hourglass2(data, nFilters, nModules, n, workspace, name, binarize, dcn):
    """Hourglass variant without pooling/upsampling that uses dilation instead.

    Has the same branch layout as ``hourglass``, but the lower branch keeps
    the input resolution and every residual unit is built with
    ``dilate = 2**(4-n)``.

    Args:
        data: input symbol.
        nFilters: channel count used by every residual unit.
        nModules: number of residual units chained at each stage.
        n: remaining recursion depth; also determines the dilation.
        workspace: only forwarded to the recursive call.
        name: prefix for all symbol names created here.
        binarize: forwarded to ``residual_unit``.
        dcn: only forwarded to the recursive call.

    Returns:
        Symbol ``add_n(upper_branch, lower_branch)``.
    """
    # The original code derived a flag from `dcn`/`DCN`/`N` and then
    # unconditionally reset it to False, so deformable convolution was never
    # enabled here. That behaviour is kept, stated directly.
    unit_dcn = False
    dilate = 2**(4-n)

    def _chain(sym, name_fmt):
        # Apply nModules residual units in sequence; name_fmt takes (name, index).
        for idx in xrange(nModules):
            sym = residual_unit(sym, nFilters, (1,1), True, name_fmt%(name,idx), binarize, unit_dcn, dilate)
        return sym

    up1 = _chain(data, "%s_up1_%d")
    low1 = _chain(data, "%s_low1_%d")
    if n>1:
        low2 = hourglass2(low1, nFilters, nModules, n-1, workspace, "%s_%d"%(name, n-1), binarize, dcn)
    else:
        low2 = _chain(low1, "%s_low2_%d")
    low3 = _chain(low2, "%s_low3_%d")
    return mx.symbol.add_n(up1, low3)
class DLA:
    """Lazily builds a grid of symbols addressed by ``(w, h)`` and merges them.

    ``get_output(w, h)`` returns a ``(symbol, size)`` tuple. Row ``h == n``
    is a chain that starts from the input at ``w == n`` and is max-pooled
    once per step towards smaller ``w``; every other cell is an average of
    cells from row ``h + 1`` (and, for some settings of the module-level
    ``use_DLA``, a third cell from the same row). Results are memoised in
    ``sym_map``.

    Relies on module-level names defined outside this class: ``N``,
    ``use_DLA``, ``RES``, ``lin3``, ``lin3_red``, ``Conv`` and ``Act``.
    """
    def __init__(self, data, nFilters, nModules, n, workspace, name):
        """Store the configuration; no symbols are created here."""
        self.data = data
        self.nFilters = nFilters
        self.nModules = nModules
        self.n = n
        self.workspace = workspace
        self.name = name
        # Memo table: (w, h) -> (symbol, size) as returned by get_output.
        self.sym_map = {}
    def residual_unit(self, data, name, dilate=1, group=1):
        """Build a ``RES`` block on ``data`` and return its ``get()`` result."""
        res = RES(data, self.nFilters, self.nModules, 4, self.workspace, name, dilate, group)
        return res.get()
        #body = data
        #for i in xrange(self.nModules):
        #  body = residual_unit(body, self.nFilters, (1,1), True, name, False, False, 1)
        #return body
    def get_output(self, w, h):
        """Return the memoised ``(symbol, size)`` tuple for grid cell ``(w, h)``.

        ``size`` starts at 64 for the input cell and is halved each time the
        symbol is max-pooled; presumably it is the spatial resolution of the
        symbol -- TODO confirm against the caller's label size.

        Raises:
            AssertionError: if ``w`` or ``h`` is outside ``[1, N+1]``.
        """
        #print(w,h)
        assert w>=1 and w<=N+1
        assert h>=1 and h<=N+1
        s = 2
        bn_mom = 0.9
        key = (w,h)
        if key in self.sym_map:
            return self.sym_map[key]
        ret = None
        if h==self.n:
            # Base row: a chain that starts from the raw input and downsamples.
            if w==self.n:
                ret = self.data,64
            #elif w==1:
            #  x = self.get_output(w+1, h)
            #  body = self.residual_unit(x[0], "%s_w%d_h%d_1"%(self.name, w, h))
            #  body = self.residual_unit(body, "%s_w%d_h%d_2"%(self.name, w, h), 2)
            #  ret = body,x[1]
            else:
                # residual -> 2x2 max-pool -> residual, on the cell to the right.
                x = self.get_output(w+1, h)
                body = self.residual_unit(x[0], "%s_w%d_h%d_1"%(self.name, w, h))
                body = mx.sym.Pooling(data=body, kernel=(s, s), stride=(s,s), pad=(0,0), pool_type='max')
                body = self.residual_unit(body, "%s_w%d_h%d_2"%(self.name, w, h))
                ret = body, x[1]//2
        else:
            # Aggregation cell: combine (w+1, h+1) and (w, h+1) from the row below.
            x = self.get_output(w+1, h+1)
            y = self.get_output(w, h+1)
            #xbody = Conv(data=x, num_filter=self.nFilters, kernel=(3,3), stride=(1,1), pad=(1,1),
            #                          no_bias=True, workspace=self.workspace, name="%s_w%d_h%d_x_conv"%(self.name, w, h))
            #xbody = mx.sym.BatchNorm(data=xbody, fix_gamma=False, momentum=bn_mom, eps=2e-5, name="%s_w%d_h%d_x_bn"%(self.name, w, h))
            #xbody = Act(data=xbody, act_type='relu', name="%s_w%d_h%d_x_act"%(self.name, w, h))
            # NOTE(review): HC is assigned but only referenced in commented-out code.
            HC = False
            if use_DLA<10:
                if h%2==1 and h!=w:
                    xbody = lin3(x[0], self.nFilters, self.workspace, "%s_w%d_h%d_x"%(self.name, w, h), 3, self.nFilters, 1)
                    HC = True
                    #xbody = x[0]
                else:
                    xbody = x[0]
            else:
                xbody = lin3(x[0], self.nFilters, self.workspace, "%s_w%d_h%d_x"%(self.name, w, h), 3, 1, 1)
            #xbody = x[0]
            if x[1]//y[1]==2:
                # y's size value is half of x's: upsample y by 2 with a deconvolution.
                if w>1:
                    # Learnable deconvolution followed by BN and activation.
                    ybody = mx.symbol.Deconvolution(data=y[0], num_filter=self.nFilters, kernel=(s,s),
                      stride=(s, s),
                      name='%s_upsampling_w%d_h%d'%(self.name,w, h),
                      attr={'lr_mult': '1.0'}, workspace=self.workspace)
                    ybody = mx.sym.BatchNorm(data=ybody, fix_gamma=False, momentum=bn_mom, eps=2e-5, name="%s_w%d_h%d_y_bn"%(self.name, w, h))
                    ybody = Act(data=ybody, act_type='relu', name="%s_w%d_h%d_y_act"%(self.name, w, h))
                    #ybody = Conv(data=ybody, num_filter=self.nFilters, kernel=(3,3), stride=(1,1), pad=(1,1),
                    #                          no_bias=True, name="%s_w%d_h%d_y_conv2"%(self.name, w, h), workspace=self.workspace)
                    #ybody = mx.sym.BatchNorm(data=ybody, fix_gamma=False, momentum=bn_mom, eps=2e-5, name="%s_w%d_h%d_y_bn2"%(self.name, w, h))
                    #ybody = Act(data=ybody, act_type='relu', name="%s_w%d_h%d_y_act2"%(self.name, w, h))
                else:
                    # NOTE(review): h>=1 always holds because of the assert at the
                    # top of this method, so the else branch below appears unreachable.
                    if h>=1:
                        # Frozen grouped deconvolution (lr_mult/wd_mult 0), then a residual unit.
                        ybody = mx.symbol.Deconvolution(data=y[0], num_filter=self.nFilters, kernel=(s,s),
                          stride=(s, s),
                          num_group=self.nFilters, no_bias=True, name='%s_upsampling_w%d_h%d'%(self.name,w, h),
                          attr={'lr_mult': '0.0', 'wd_mult': '0.0'}, workspace=self.workspace)
                        #ybody = mx.sym.BatchNorm(data=ybody, fix_gamma=False, momentum=bn_mom, eps=2e-5, name="%s_w%d_h%d_y_bn"%(self.name, w, h))
                        # NOTE(review): math is imported here but not used in this method.
                        import math
                        #ybody = Act(data=ybody, act_type='relu', name="%s_w%d_h%d_y_act"%(self.name, w, h))
                        ybody = self.residual_unit(ybody, "%s_w%d_h%d_4"%(self.name, w, h))
                    else:
                        ybody = mx.symbol.Deconvolution(data=y[0], num_filter=self.nFilters, kernel=(s,s),
                          stride=(s, s),
                          name='%s_upsampling_w%d_h%d'%(self.name,w, h),
                          attr={'lr_mult': '1.0'}, workspace=self.workspace)
                        ybody = mx.sym.BatchNorm(data=ybody, fix_gamma=False, momentum=bn_mom, eps=2e-5, name="%s_w%d_h%d_y_bn"%(self.name, w, h))
                        ybody = Act(data=ybody, act_type='relu', name="%s_w%d_h%d_y_act"%(self.name, w, h))
                        ybody = Conv(data=ybody, num_filter=self.nFilters, kernel=(3,3), stride=(1,1), pad=(1,1),
                                                  no_bias=True, name="%s_w%d_h%d_y_conv2"%(self.name, w, h), workspace=self.workspace)
                        ybody = mx.sym.BatchNorm(data=ybody, fix_gamma=False, momentum=bn_mom, eps=2e-5, name="%s_w%d_h%d_y_bn2"%(self.name, w, h))
                        ybody = Act(data=ybody, act_type='relu', name="%s_w%d_h%d_y_act2"%(self.name, w, h))
            else:
                # Size values already agree: just refine y with a residual unit.
                ybody = self.residual_unit(y[0], "%s_w%d_h%d_5"%(self.name, w, h))
            #if not HC:
            if use_DLA<10:
                if use_DLA>1 and h==3 and w==2:
                    # Special cell: the x/y average is multiplied (broadcast) by an
                    # average-pooled version of the same-row neighbour (w+1, h).
                    z = self.get_output(w+1, h)
                    zbody = z[0]
                    #zbody = lin3_red(zbody, self.nFilters, self.workspace, "%s_w%d_h%d_z"%(self.name, w, h), 2, self.nFilters)
                    #zbody = mx.sym.Pooling(data=zbody, kernel=(s, s), stride=(s,s), pad=(0,0), pool_type='avg')
                    # Kernel and stride both equal z's size value.
                    zbody = mx.sym.Pooling(data=zbody, kernel=(z[1], z[1]), stride=(z[1],z[1]), pad=(0,0), pool_type='avg')
                    #zbody = mx.sym.Activation(data = zbody, act_type='sigmoid')
                    #body = zbody+ybody
                    #body = body/2
                    body = xbody+ybody
                    body = body/2
                    #body = body*zbody
                    body = mx.sym.broadcast_mul(body, zbody)
                    #body = mx.sym.add_n(*[xbody, ybody, zbody])
                    #body = body/3
                else:
                    body = xbody+ybody
                    body = body/2
            else:
                if use_DLA==12 and h!=w:
                    # Three-way average including a lin3_red-processed same-row neighbour.
                    zbody = self.get_output(w+1, h)[0]
                    zbody = lin3_red(zbody, self.nFilters, self.workspace, "%s_w%d_h%d_z"%(self.name, w, h), 2, 1)
                    body = mx.sym.add_n(*[xbody, ybody, zbody])
                    body = body/3
                else:
                    body = xbody+ybody
                    body = body/2
            ret = body, x[1]
        assert ret is not None
        self.sym_map[key] = ret
        return ret
    def get(self):
        """Return the symbol (without the size value) of cell ``(1, 1)``."""
        return self.get_output(1, 1)[0]
def l2_loss(x, y):
    """Return the mean of the element-wise squared difference of two symbols."""
    diff = x-y
    return mx.symbol.mean(diff*diff)
def ce_loss(x, y):
    """Cross-entropy between a softmax over axes 2 and 3 of ``x`` and ``y``.

    ``x`` is exponentiated and normalised by its sum over axes (2, 3), the
    log of the result is multiplied by ``-y``, and the mean over all
    elements is returned.
    """
    # NOTE(review): exp() is applied without subtracting the maximum first;
    # large activations may overflow.
    exp_x = mx.sym.exp(x)
    normaliser = mx.sym.sum(exp_x, axis=[2,3], keepdims=True)
    prob = mx.sym.broadcast_div(exp_x, normaliser)
    neg_log_likelihood = mx.sym.log(prob)*y*-1.0
    return mx.symbol.mean(neg_log_likelihood)
def get_symbol(num_classes, **kwargs):
    """Build the stacked heatmap-regression training symbol.

    The network is a stem (conv + residual units + one max-pool) followed by
    ``nStacks`` stages. Each stage is either a ``DLA`` grid (``use_dla > 0``)
    or an ``hourglass``, followed by residual units, a ``lin`` block and a
    heatmap head with ``num_classes`` output channels. Every stage adds a
    cross-entropy loss term against ``softmax_label``.

    Args:
        num_classes: number of heatmap channels produced by each head.

    Keyword Args:
        bn_mom: batch-norm momentum (default 0.9).
        workspace: convolution workspace size (default 256).
        binarize: forwarded to the residual units / hourglass (default False).
        input_size: 128 or 256 (default 128).
        label_size: 64 or 32 (default 64).
        use_coherent: 0 disables the coherence loss; 1 uses a flip-based
            loss; other truthy values use the ``coherent_label`` affine
            input (default 0).
        use_dla: stored in the module-level ``use_DLA`` (default 0).
        use_N: stored in the module-level ``N`` (default 4).
        use_DCN: stored in the module-level ``DCN`` (default 0).
        per_batch_size: batch size per device, used to split the batch in
            halves for the coherence loss (default 0).

    Returns:
        ``mx.symbol.Group`` of, in order: the main loss, the coherence loss
        (if any), ``BlockGrad(gt_label)`` when ``use_coherent > 1``,
        ``BlockGrad(M)`` when ``use_coherent > 0``, and the gradient-blocked
        heatmap of the last stack.

    Raises:
        AssertionError: if ``label_size`` or ``input_size`` has an
            unsupported value.
    """
    # These module-level names are overwritten on every call and read by
    # other builders in this file (e.g. DLA, hourglass2).
    global use_DLA
    global N
    global DCN
    # 1-based index pairs, made symmetric and 0-based below. Presumably the
    # left/right mirrored landmarks of the 68-point layout -- TODO confirm.
    # NOTE(review): (49,55) is listed twice; the duplicate is harmless.
    mirror_set = [
        (22,23),
        (21,24),
        (20,25),
        (19,26),
        (18,27),
        (40,43),
        (39,44),
        (38,45),
        (37,46),
        (42,47),
        (41,48),
        (33,35),
        (32,36),
        (51,53),
        (50,54),
        (62,64),
        (61,65),
        (49,55),
        (49,55),
        (68,66),
        (60,56),
        (59,57),
        (1,17),
        (2,16),
        (3,15),
        (4,14),
        (5,13),
        (6,12),
        (7,11),
        (8,10),
      ]
    mirror_map = {}
    for mm in mirror_set:
        mirror_map[mm[0]-1] = mm[1]-1
        mirror_map[mm[1]-1] = mm[0]-1
    sFilters = 64
    mFilters = 128
    nFilters = 256
    nModules = 1
    nStacks = 2
    bn_mom = kwargs.get('bn_mom', 0.9)
    workspace = kwargs.get('workspace', 256)
    binarize = kwargs.get('binarize', False)
    input_size = kwargs.get('input_size', 128)
    label_size = kwargs.get('label_size', 64)
    use_coherent = kwargs.get('use_coherent', 0)
    use_DLA = kwargs.get('use_dla', 0)
    N = kwargs.get('use_N', 4)
    DCN = kwargs.get('use_DCN', 0)
    per_batch_size = kwargs.get('per_batch_size', 0)
    print('binarize', binarize)
    print('use_coherent', use_coherent)
    print('use_DLA', use_DLA)
    print('use_N', N)
    print('use_DCN', DCN)
    print('per_batch_size', per_batch_size)
    assert(label_size==64 or label_size==32)
    assert(input_size==128 or input_size==256)
    # Ratio between input and label resolution; selects the stem convolution.
    D = input_size // label_size
    print(input_size, label_size, D)
    dcn = False
    # From here on no caller-supplied keyword is forwarded to residual_unit.
    kwargs = {}
    data = mx.sym.Variable(name='data')
    # Input normalisation: (x - 127.5) * (1/128).
    data = data-127.5
    data = data*0.0078125
    gt_label = mx.symbol.Variable(name='softmax_label')
    losses = []
    closses = []
    if use_coherent:
        M = mx.sym.Variable(name="coherent_label")
        #gt_label2 = mx.sym.Variable(name="softmax_label2")
        coherent_weight = 0.0001
    ref_label = gt_label
    # use_STN is not assigned in this function; it is read as a name defined
    # elsewhere in the module.
    if use_STN:
        # Small localisation network that predicts a single rotation angle;
        # the same rotation is applied to both the input and the label.
        lr_mult = '0.00001'
        loc_net = Conv(data=data, num_filter=sFilters, kernel=(7, 7), stride=(2,2), pad=(3, 3),
                                  no_bias=True, name="stn_conv0", workspace=workspace, attr={'lr_mult': lr_mult})
        loc_net = mx.sym.BatchNorm(data=loc_net, fix_gamma=False, eps=2e-5, momentum=bn_mom, name='stn_bn0', attr={'lr_mult': lr_mult})
        loc_net = Act(data=loc_net, act_type='relu', name='stn_relu0')
        loc_net = Conv(data=loc_net, num_filter=mFilters, kernel=(3,3), stride=(1,1), pad=(1,1),
                                  no_bias=True, name="stn_conv1", workspace=workspace, attr={'lr_mult': lr_mult})
        loc_net = mx.sym.BatchNorm(data=loc_net, fix_gamma=False, eps=2e-5, momentum=bn_mom, name='stn_bn1', attr={'lr_mult': lr_mult})
        loc_net = Act(data=loc_net, act_type='relu', name='stn_relu1')
        loc_net = mx.sym.Pooling(data=loc_net, kernel=(2, 2), stride=(2,2), pad=(0,0), pool_type='max')
        loc_net = mx.sym.FullyConnected(data=loc_net, num_hidden=int(nFilters*0.5), name='loc_net_half', attr={'lr_mult': lr_mult})
        loc_net = mx.sym.Activation(data=loc_net, act_type='tanh', name='loc_net_act')
        #loc_net = mx.sym.Activation(data=loc_net, act_type='relu', name='loc_net_act')
        #loc_theta = mx.sym.FullyConnected(data=loc_net, num_hidden=6, name='loc_theta', attr={'lr_mult': lr_mult})
        #loc_theta = mx.sym.Activation(data=loc_theta, act_type='tanh', name='loc_theta_tanh')
        loc_theta = mx.sym.FullyConnected(data=loc_net, num_hidden=1, name='loc_theta', attr={'lr_mult': lr_mult})
        loc_theta = mx.sym.Activation(data=loc_theta, act_type='tanh', name='loc_theta_tanh')
        # tanh output scaled to (-0.5, 0.5), then used as an angle for sin/cos.
        loc_theta = loc_theta*0.5
        sin_t = mx.sym.sin(loc_theta)
        m_sin_t = sin_t*-1.0
        cos_t = mx.sym.cos(loc_theta)
        zero_t = mx.sym.zeros_like(loc_theta)
        # Six affine parameters [cos, -sin, 0, sin, cos, 0]: a pure rotation.
        loc_theta = mx.sym.concat(*[cos_t, m_sin_t, zero_t, sin_t, cos_t, zero_t], dim=1)
        data = mx.sym.SpatialTransformer(data = data, loc = loc_theta, target_shape=(input_size,input_size), transform_type="affine", sampler_type="bilinear")
        ref_label = mx.sym.SpatialTransformer(data = ref_label, loc = loc_theta, target_shape=(label_size,label_size), transform_type="affine", sampler_type="bilinear")
    #data = mx.sym.BatchNorm(data=data, fix_gamma=False, eps=2e-5, momentum=bn_mom, name='bn_data')
    # Stem: a stride-2 7x7 conv when D is 4, otherwise a stride-1 3x3 conv;
    # a single 2x2 max-pool follows further below in both cases.
    if D==4:
        body = Conv(data=data, num_filter=sFilters, kernel=(7, 7), stride=(2,2), pad=(3, 3),
                                  no_bias=True, name="conv0", workspace=workspace)
    else:
        body = Conv(data=data, num_filter=sFilters, kernel=(3, 3), stride=(1,1), pad=(1, 1),
                                  no_bias=True, name="conv0", workspace=workspace)
    body = mx.sym.BatchNorm(data=body, fix_gamma=False, eps=2e-5, momentum=bn_mom, name='bn0')
    body = Act(data=body, act_type='relu', name='relu0')
    body = residual_unit(body, mFilters, (1,1), sFilters==mFilters, 'res0', False, dcn, 1, **kwargs)
    #body = residual_unit(body, nFilters, (1,1), False, 'res0', binarize, **kwargs)
    body = mx.sym.Pooling(data=body, kernel=(2, 2), stride=(2,2), pad=(0,0), pool_type='max')
    body = residual_unit(body, mFilters, (1,1), True, 'res1', False, dcn, 1, **kwargs) #TODO
    #body = residual_unit(body, nFilters, (1,1), True, 'res1', binarize, **kwargs) #TODO
    body = residual_unit(body, nFilters, (1,1), mFilters==nFilters, 'res2', binarize, dcn, 1, **kwargs) #binarize=True?
    #body = residual_unit(body, nFilters, (1,1), False, 'res2', False, **kwargs) #binarize=True?
    use_lin = True
    heatmap = None
    for i in xrange(nStacks):
        # Kept for the inter-stack skip connection added at the end of the stage.
        shortcut = body
        if use_DLA>0:
            dla = DLA(body, nFilters, nModules, N+1, workspace, 'dla%d'%(i))
            body = dla.get()
        else:
            body = hourglass(body, nFilters, nModules, N, workspace, 'stack%d_hg'%(i), binarize, dcn)
        for j in xrange(nModules):
            body = residual_unit(body, nFilters, (1,1), True, 'stack%d_unit%d'%(i,j), binarize, dcn, 1, **kwargs)
        if use_lin:
            _dcn = True if DCN>=2 else False
            ll = lin(body, nFilters, workspace, name='stack%d_ll'%(i), binarize = False, dcn = _dcn) #TODO
            #ll = lin(body, nFilters, workspace, name='stack%d_ll'%(i), binarize = binarize)
        else:
            ll = body
        # The last stack's head is named "heatmap"; earlier ones "heatmap<i>".
        _name = "heatmap%d"%(i)
        if i==nStacks-1:
            _name = "heatmap"
        _dcn = True if DCN>=2 else False
        if not _dcn:
            out = Conv(data=ll, num_filter=num_classes, kernel=(1, 1), stride=(1,1), pad=(0,0),
                                      name=_name, workspace=workspace)
        else:
            # Deformable 3x3 head; the offset conv has 18 output channels.
            out_offset = mx.symbol.Convolution(name=_name+'_offset', data = ll,
                  num_filter=18, pad=(1, 1), kernel=(3, 3), stride=(1, 1))
            out = mx.contrib.symbol.DeformableConvolution(name=_name, data=ll, offset=out_offset,
                  num_filter=num_classes, pad=(1,1), kernel=(3, 3), num_deformable_group=1, stride=(1, 1), dilate=(1, 1), no_bias=False)
            #out = Conv(data=ll, num_filter=num_classes, kernel=(3,3), stride=(1,1), pad=(1,1),
            #                          name=_name, workspace=workspace)
        if i==nStacks-1:
            heatmap = out
        #outs.append(out)
        if use_coherent>0:
            #px = mx.sym.slice_axis(out, axis=0, begin=0, end=b)
            #py = mx.sym.slice_axis(ref_label, axis=0, begin=0, end=b)
            px = out
            py = ref_label
            gloss = ce_loss(px, py)
            gloss = gloss/nStacks
            losses.append(gloss)
            # Split the batch into a first half (ux) and a second half (dx);
            # the coherence loss compares a transformed ux with dx.
            b = per_batch_size//2
            ux = mx.sym.slice_axis(out, axis=0, begin=0, end=b)
            dx = mx.sym.slice_axis(out, axis=0, begin=b, end=b*2)
            if use_coherent==1:
                # Flip along axis 3 and swap channels according to mirror_map.
                # NOTE(review): the channel count is hard-coded to 68 here
                # rather than taken from num_classes.
                ux = mx.sym.flip(ux, axis=3)
                ux_list = [None]*68
                for k in xrange(68):
                    if k in mirror_map:
                        vk = mirror_map[k]
                        #print('k', k, vk)
                        ux_list[vk] = mx.sym.slice_axis(ux, axis=1, begin=k, end=k+1)
                    else:
                        ux_list[k] = mx.sym.slice_axis(ux, axis=1, begin=k, end=k+1)
                ux = mx.sym.concat(*ux_list, dim=1)
                #dx = mx.sym.slice_axis(ref_label, axis=0, begin=b, end=b*2)
                #closs = ce_loss(ux, dx)
                closs = l2_loss(ux, dx)
                closs = closs/nStacks
                closses.append(closs)
            else:
                # Warp ux with the first b rows of the coherent_label input,
                # used as affine parameters.
                m = mx.sym.slice_axis(M, axis=0, begin=0, end=b)
                ux = mx.sym.SpatialTransformer(data=ux, loc=m, target_shape=(label_size, label_size), transform_type='affine', sampler_type='bilinear')
                closs = l2_loss(ux, dx)
                closs = closs/nStacks
                closses.append(closs)
        else:
            loss = ce_loss(out, ref_label)
            loss = loss/nStacks
            losses.append(loss)
        if i<nStacks-1:
            # Feed the next stack: skip connection + 1x1-projected features
            # + 1x1-projected heatmap.
            if use_lin:
                ll2 = Conv(data=ll, num_filter=nFilters, kernel=(1, 1), stride=(1,1), pad=(0,0),
                                          name="stack%d_ll2"%(i), workspace=workspace)
            else:
                ll2 = body
            out2 = Conv(data=out, num_filter=nFilters, kernel=(1, 1), stride=(1,1), pad=(0,0),
                                      name="stack%d_out2"%(i), workspace=workspace)
            body = mx.symbol.add_n(shortcut, ll2, out2)
            _dcn = True if (DCN==1 or DCN==3) else False
            if _dcn:
                _name = "stack%d_out3" % (i)
                out3_offset = mx.symbol.Convolution(name=_name+'_offset', data = body,
                      num_filter=18, pad=(1, 1), kernel=(3, 3), stride=(1, 1))
                out3 = mx.contrib.symbol.DeformableConvolution(name=_name, data=body, offset=out3_offset,
                      num_filter=nFilters, pad=(1,1), kernel=(3, 3), num_deformable_group=1, stride=(1, 1), dilate=(1, 1), no_bias=False)
                body = out3
            #elif use_STN:
            #  loc_net = dla.get2()
            #  #loc_net = mx.sym.Pooling(data=loc_net, global_pool=True, kernel=(7, 7), pool_type='avg', name='loc_net_pool')
            #  loc_net = mx.sym.FullyConnected(data=loc_net, num_hidden=int(nFilters*0.5), name='loc_net_half', attr={'lr_mult': '0.0001'})
            #  loc_net = mx.sym.Activation(data=loc_net, act_type='tanh', name='loc_net_act')
            #  #loc_net = mx.sym.Activation(data=loc_net, act_type='relu', name='loc_net_act')
            #  loc_theta = mx.sym.FullyConnected(data=loc_net, num_hidden=6, name='loc_theta', attr={'lr_mult': '0.0001'})
            #  loc_theta = mx.sym.Activation(data=loc_theta, act_type='tanh', name='loc_theta_tanh')
            #  body = mx.sym.SpatialTransformer(data = body, loc = loc_theta, target_shape=(label_size,label_size), transform_type="affine", sampler_type="bilinear")
            #  ref_label = mx.sym.SpatialTransformer(data = gt_label, loc = loc_theta, target_shape=(label_size,label_size), transform_type="affine", sampler_type="bilinear")
    # The prediction output carries no gradient; only the losses train the net.
    pred = mx.symbol.BlockGrad(heatmap)
    loss = mx.symbol.add_n(*losses)
    loss = mx.symbol.MakeLoss(loss)
    syms = [loss]
    if len(closses)>0:
        closs = mx.symbol.add_n(*closses)
        closs = mx.symbol.MakeLoss(closs, grad_scale = coherent_weight)
        syms.append(closs)
        #syms.append(mx.symbol.BlockGrad(M))
        #syms.append(mx.symbol.BlockGrad(px))
        #syms.append(mx.symbol.BlockGrad(qx))
        #syms.append(mx.symbol.BlockGrad(m))
        #syms.append(mx.symbol.BlockGrad(closs))
    if use_coherent>1:
        syms.append(mx.symbol.BlockGrad(gt_label))
    if use_coherent>0:
        syms.append(mx.symbol.BlockGrad(M))
    # pred is always the last output of the group.
    syms.append(pred)
    sym = mx.symbol.Group( syms )
    return sym
def init_weights(sym, data_shape_dict):
    """Create explicit initial values for selected arguments of ``sym``.

    Argument shapes are inferred from ``data_shape_dict``. Only arguments
    whose names match one of the patterns below receive a value:

    * ``*offset_weight`` / ``*offset_bias``: zeros.
    * ``fc6_*``: ``_weight`` ~ normal(0, 0.01), ``_bias`` zeros.
    * names containing ``upsampling``: bilinear kernel.
    * names containing ``loc_theta``: ``_weight`` zeros,
      ``_bias`` ~ normal(0, 0.01).

    Returns:
        ``(arg_params, aux_params)``; ``aux_params`` is always empty.
    """
    print('in hg2')
    argument_names = sym.list_arguments()
    auxiliary_names = sym.list_auxiliary_states()
    arg_shape, _, aux_shape = sym.infer_shape(**data_shape_dict)
    shape_by_arg = dict(zip(argument_names, arg_shape))
    shape_by_aux = dict(zip(auxiliary_names, aux_shape))
    arg_params = {}
    aux_params = {}
    for key, shape in shape_by_arg.iteritems():
        if key.endswith(('offset_weight', 'offset_bias')):
            # Zero offsets for the deformable-convolution offset branches.
            print('initializing',key)
            arg_params[key] = mx.nd.zeros(shape = shape)
        elif key.startswith('fc6_'):
            if key.endswith('_weight'):
                print('initializing',key)
                arg_params[key] = mx.random.normal(0, 0.01, shape=shape)
            elif key.endswith('_bias'):
                print('initializing',key)
                arg_params[key] = mx.nd.zeros(shape=shape)
        elif 'upsampling' in key:
            print('initializing upsampling_weight', key)
            kernel = mx.nd.zeros(shape=shape)
            arg_params[key] = kernel
            # Fill the zero tensor in place with a bilinear kernel.
            mx.init.Initializer()._init_bilinear(key, kernel)
        elif 'loc_theta' in key:
            print('initializing STN', key, shape)
            if key.endswith('_weight'):
                arg_params[key] = mx.nd.zeros(shape=shape)
            elif key.endswith('_bias'):
                arg_params[key] = mx.random.normal(0, 0.01, shape=shape)
    return arg_params, aux_params

145
alignment/img_helper.py Normal file
View File

@@ -0,0 +1,145 @@
import numpy as np
import scipy.misc
import scipy.signal
import math
#import draw
#import ref
# =============================================================================
# General image processing functions
# =============================================================================
def get_transform(center, scale, res, rot=0):
    """Return the 3x3 matrix mapping image coordinates to crop coordinates.

    A square window of side ``scale`` centred on ``center`` (x, y) is mapped
    onto an output of size ``res`` (rows, cols). When ``rot`` is non-zero a
    rotation by ``-rot`` degrees about the output centre is appended.
    """
    out_rows, out_cols = res[0], res[1]
    side = scale
    mat = np.zeros((3, 3))
    # Scale, then shift so that `center` lands on the output centre.
    mat[0, 0] = float(out_cols) / side
    mat[0, 2] = out_cols * (-float(center[0]) / side + .5)
    mat[1, 1] = float(out_rows) / side
    mat[1, 2] = out_rows * (-float(center[1]) / side + .5)
    mat[2, 2] = 1
    if rot == 0:
        return mat
    # Negated to match the direction of rotation used when cropping.
    angle = -rot * np.pi / 180
    sin_a, cos_a = np.sin(angle), np.cos(angle)
    rotation = np.zeros((3, 3))
    rotation[0, :2] = [cos_a, -sin_a]
    rotation[1, :2] = [sin_a, cos_a]
    rotation[2, 2] = 1
    # Rotate about the output centre: shift to origin, rotate, shift back.
    to_origin = np.eye(3)
    to_origin[0, 2] = -res[1]/2
    to_origin[1, 2] = -res[0]/2
    from_origin = to_origin.copy()
    from_origin[:2, 2] *= -1
    return np.dot(from_origin, np.dot(rotation, np.dot(to_origin, mat)))
def transform(pt, center, scale, res, invert=0, rot=0):
    """Map a point between image and crop coordinates.

    Uses the matrix from ``get_transform``; with ``invert`` truthy the
    inverse mapping is applied instead. The result is truncated to integers.

    Returns:
        Length-2 integer array with the transformed (x, y).
    """
    mat = get_transform(center, scale, res, rot=rot)
    if invert:
        mat = np.linalg.inv(mat)
    homogeneous = np.array([pt[0], pt[1], 1.])
    mapped = np.dot(mat, homogeneous)
    return mapped[:2].astype(int)
def crop_center(img,crop_size):
    """Return the central ``crop_size`` (rows, cols) window of a 3-D image."""
    rows, cols = img.shape[0], img.shape[1]
    crop_rows, crop_cols = crop_size[0], crop_size[1]
    top = rows//2-(crop_rows//2)
    left = cols//2-(crop_cols//2)
    return img[top:(top+crop_rows), left:(left+crop_cols), :]
def crop(img, center, scale, res, rot=0):
    """Cut a window around ``center`` out of ``img`` and resize it to ``res``.

    The window corners are obtained by mapping the output corners back into
    image coordinates with ``transform(..., invert=1)``. Parts of the window
    that fall outside the image are zero-filled. When ``rot`` is non-zero the
    window is enlarged by ``pad`` pixels on every side before rotating so that
    the rotated result still contains the full context, and the padding is
    stripped again afterwards.

    Args:
        img: 2-D or 3-D image array.
        center: (x, y) window centre in image coordinates.
        scale: side length of the window in image pixels.
        res: output size, passed to ``scipy.misc.imresize``.
        rot: rotation in degrees, passed to ``scipy.misc.imrotate``.

    Returns:
        The resized crop as returned by ``scipy.misc.imresize``.
    """
    # NOTE(review): scipy.misc.imrotate / imresize were deprecated in SciPy
    # 1.0.0 and removed in 1.2.0 / 1.3.0; this function needs an older SciPy.
    # Upper-left and bottom-right corners of the window in image coordinates.
    ul = np.array(transform([0, 0], center, scale, res, invert=1))
    br = np.array(transform(res, center, scale, res, invert=1))
    # Padding so that when rotated the proper amount of context is included.
    pad = int(np.linalg.norm(br - ul) / 2 - float(br[1] - ul[1]) / 2)
    rotate = not rot == 0
    if rotate:
        ul -= pad
        br += pad
    new_shape = [br[1] - ul[1], br[0] - ul[0]]
    if len(img.shape) > 2:
        new_shape += [img.shape[2]]
    new_img = np.zeros(new_shape)
    # Range to fill in the new array.
    new_x = max(0, -ul[0]), min(br[0], len(img[0])) - ul[0]
    new_y = max(0, -ul[1]), min(br[1], len(img)) - ul[1]
    # Range to sample from the original image.
    old_x = max(0, ul[0]), min(len(img[0]), br[0])
    old_y = max(0, ul[1]), min(len(img), br[1])
    new_img[new_y[0]:new_y[1], new_x[0]:new_x[1]] = img[old_y[0]:old_y[1], old_x[0]:old_x[1]]
    if rotate:
        new_img = scipy.misc.imrotate(new_img, rot)
        # Remove the padding. With pad == 0 the slice [0:-0] would be empty,
        # so there is nothing to strip in that case.
        if pad > 0:
            new_img = new_img[pad:-pad, pad:-pad]
    return scipy.misc.imresize(new_img, res)
def crop2(img, center, scale, res, rot=0):
    """Crop the largest centred square, optionally rotate, then centre-crop.

    The largest square around ``center`` that fits inside ``img`` is taken
    (``center[0]`` indexes axis 0, ``center[1]`` axis 1), rotated by ``rot``
    degrees if non-zero, reduced to a ``scale`` x ``scale`` centre crop and
    resized to ``res``.
    """
    c0, c1 = center[0], center[1]
    # Largest half-width that keeps the square inside the image.
    rad = np.min( [c0, img.shape[0] - c0, c1, img.shape[1] - c1] )
    window = img[(c0-rad):(c0+rad),(c1-rad):(c1+rad),:]
    if rot != 0:
        window = scipy.misc.imrotate(window, rot)
    window = crop_center(window, (scale,scale))
    return scipy.misc.imresize(window, res)
def nms(img):
    """Non-maximum suppression on a 2-D array with a 3x3 window.

    Every element that is smaller than the maximum of its 3x3 neighbourhood
    is set to zero; the input is not modified.
    """
    window = 3
    footprint = np.ones((window, window))
    # Rank window**2 - 1 is the largest value in the neighbourhood.
    local_max = scipy.signal.order_filter(img, footprint, window ** 2 - 1)
    suppressed = img.copy()
    suppressed[(local_max - img) > 0] = 0
    return suppressed
def gaussian(img, pt, sigma):
    """Draw an unnormalised 2-D Gaussian (peak value 1) into ``img`` in place.

    The Gaussian patch covers ``3 * sigma`` pixels on each side of ``pt``;
    the part that overlaps the image overwrites the pixels there.

    Args:
        img: 2-D array that is modified in place.
        pt: (x, y) centre of the Gaussian.
        sigma: standard deviation, must be positive.

    Returns:
        True if at least part of the patch was written, False if the patch
        lies completely outside the image (``img`` is left untouched).

    Raises:
        AssertionError: if ``sigma`` is not positive.
    """
    assert(sigma>0)
    height, width = img.shape[0], img.shape[1]
    # Patch corners: ul is inclusive, br is exclusive.
    ul = [int(pt[0] - 3 * sigma), int(pt[1] - 3 * sigma)]
    br = [int(pt[0] + 3 * sigma + 1), int(pt[1] + 3 * sigma + 1)]
    # Reject patches with no pixel inside the image. Because br is
    # exclusive, br == 0 (and ul == width/height) already means "no overlap";
    # both axes are tested the same way.
    if (ul[0] >= width or ul[1] >= height or
            br[0] <= 0 or br[1] <= 0):
        return False
    # Generate the gaussian patch.
    size = 6 * sigma + 1
    x = np.arange(0, size, 1, float)
    y = x[:, np.newaxis]
    x0 = y0 = size // 2
    # The gaussian is not normalized, we want the center value to equal 1.
    g = np.exp(- ((x - x0) ** 2 + (y - y0) ** 2) / (2 * sigma ** 2))
    # Usable gaussian range.
    g_x = max(0, -ul[0]), min(br[0], width) - ul[0]
    g_y = max(0, -ul[1]), min(br[1], height) - ul[1]
    # Image range.
    img_x = max(0, ul[0]), min(br[0], width)
    img_y = max(0, ul[1]), min(br[1], height)
    img[img_y[0]:img_y[1], img_x[0]:img_x[1]] = g[g_y[0]:g_y[1], g_x[0]:g_x[1]]
    return True

56
alignment/infer.py Normal file
View File

@@ -0,0 +1,56 @@
import argparse
import cv2
import numpy as np
import sys
import mxnet as mx
# Single-image demo: load a checkpoint, run one forward pass on ./test.png
# and draw the arg-max location of every output heatmap onto the image.
parser = argparse.ArgumentParser(description='face model test')
# general
# --image-size is "<height>,<width>"; --model is "<checkpoint prefix>,<epoch>".
parser.add_argument('--image-size', default='128,128', help='')
parser.add_argument('--model', default='./models/test,15', help='path to load model.')
# A negative gpu id selects the CPU.
parser.add_argument('--gpu', default=-1, type=int, help='gpu id')
args = parser.parse_args()
_vec = args.image_size.split(',')
assert len(_vec)==2
image_size = (int(_vec[0]), int(_vec[1]))
_vec = args.model.split(',')
assert len(_vec)==2
prefix = _vec[0]
epoch = int(_vec[1])
print('loading',prefix, epoch)
if args.gpu>=0:
    ctx = mx.gpu(args.gpu)
else:
    ctx = mx.cpu()
sym, arg_params, aux_params = mx.model.load_checkpoint(prefix, epoch)
# Keep only the sub-graph up to the internal output named 'heatmap_output',
# dropping the loss outputs of the training symbol.
all_layers = sym.get_internals()
sym = all_layers['heatmap_output']
model = mx.mod.Module(symbol=sym, context=ctx, label_names = None)
#model.bind(data_shapes=[('data', (args.batch_size, 3, image_size[0], image_size[1]))], label_shapes=[('softmax_label', (args.batch_size,))])
# Batch size is fixed to 1.
model.bind(data_shapes=[('data', (1, 3, image_size[0], image_size[1]))])
model.set_params(arg_params, aux_params)
#img_path = '/raid5data/dplearn/megaface/facescrubr/112x112/Tom_Hanks/Tom_Hanks_54745.png'
img_path = './test.png'
# NOTE(review): cv2.imread returns None when the file cannot be read; that
# case is not checked here.
img = cv2.imread(img_path)
# cv2.resize takes (width, height); rimg stays BGR and is drawn on later.
rimg = cv2.resize(img, (image_size[1], image_size[0]))
img = cv2.cvtColor(rimg, cv2.COLOR_BGR2RGB)
img = np.transpose(img, (2,0,1)) # HWC -> CHW, RGB
input_blob = np.expand_dims(img, axis=0) # add the batch axis: 1*C*H*W
data = mx.nd.array(input_blob)
db = mx.io.DataBatch(data=(data,))
model.forward(db, is_train=False)
output = model.get_outputs()[0].asnumpy()
#print(output[0,80])
#sys.exit(0)
# NOTE(review): presumably the ./vis directory must already exist -- confirm.
filename = "./vis/draw_%s" % img_path.split('/')[-1]
# One heatmap per output channel: upscale it to the input size and mark its
# maximum on the resized image.
for i in xrange(output.shape[1]):
    a = output[0,i,:,:]
    a = cv2.resize(a, (image_size[1], image_size[0]))
    # ind is (row, col); cv2.circle expects (x, y), hence the swap below.
    ind = np.unravel_index(np.argmax(a, axis=None), a.shape)
    cv2.circle(rimg, (ind[1], ind[0]), 1, (0, 0, 255), 2)
    print(i, ind)
cv2.imwrite(filename, rimg)

355
alignment/train.py Normal file
View File

@@ -0,0 +1,355 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
#import hg
import hg2 as hg
import logging
import argparse
from data import FaceSegIter
import mxnet as mx
import mxnet.optimizer as optimizer
import numpy as np
import os
import sys
import math
import random
import cv2
# Module-level placeholder for the parsed options; NMEMetric2.update reads
# args.input_img_size and args.norm_type, so it presumably gets assigned
# before training starts -- TODO confirm where.
args = None
# Root logger, configured to emit INFO-level messages.
logger = logging.getLogger()
logger.setLevel(logging.INFO)
class LossValueMetric(mx.metric.EvalMetric):
    """Metric named ``lossvalue`` that averages the network's first output.

    Each ``update`` call adds element 0 of the first prediction array to the
    running sum and counts as one instance.
    """

    def __init__(self):
        self.axis = 1
        super(LossValueMetric, self).__init__(
            'lossvalue', axis=self.axis,
            output_names=None, label_names=None)
        self.losses = []

    def update(self, labels, preds):
        """Accumulate the loss value; ``labels`` is ignored."""
        first_output = preds[0].asnumpy()
        self.sum_metric += first_output[0]
        self.num_inst += 1.0
class NMEMetric(mx.metric.EvalMetric):
    """Metric named ``NME``: mean arg-max distance between heatmaps.

    For every (sample, channel) pair the arg-max position of the predicted
    heatmap is compared with that of the label heatmap; the mean Euclidean
    distance is divided by ``sqrt(label.shape[2] * label.shape[3])``.
    """

    def __init__(self):
        self.axis = 1
        super(NMEMetric, self).__init__(
            'NME', axis=self.axis,
            output_names=None, label_names=None)
        # Number of update() calls so far.
        self.count = 0

    def update(self, labels, preds):
        """Accumulate the metric using only the last prediction output."""
        self.count+=1
        # Only the last network output (the heatmap) is evaluated.
        preds = [preds[-1]]
        for label, pred_label in zip(labels, preds):
            label = label.asnumpy()
            pred_label = pred_label.asnumpy()
            rows, cols = label.shape[2], label.shape[3]
            nme = []
            for b in xrange(pred_label.shape[0]):
                for p in xrange(pred_label.shape[1]):
                    heatmap_gt = label[b][p]
                    # cv2.resize takes dsize as (width, height), i.e.
                    # (cols, rows), so that the result has the same
                    # (rows, cols) shape as the label heatmap.
                    heatmap_pred = cv2.resize(pred_label[b][p], (cols, rows))
                    ind_gt = np.array(np.unravel_index(np.argmax(heatmap_gt, axis=None), heatmap_gt.shape))
                    ind_pred = np.array(np.unravel_index(np.argmax(heatmap_pred, axis=None), heatmap_pred.shape))
                    dist = np.sqrt(np.sum(np.square(ind_gt - ind_pred)))
                    nme.append(dist)
            nme = np.mean(nme)
            nme /= np.sqrt(float(rows*cols))
            self.sum_metric += nme
            self.num_inst += 1.0
class NMEMetric2(mx.metric.EvalMetric):
    """Metric named ``NME2``: per-sample normalised mean landmark error.

    Labels may be 4-D (one heatmap per landmark, the arg-max is used) or 3-D
    (landmark coordinates taken as-is). Predictions are heatmaps that are
    resized to ``args.input_img_size`` before their arg-max is taken. The
    mean distance of a sample is divided either by the distance between the
    two eye centres (``args.norm_type == '2d'``) or by the diagonal of the
    ground-truth bounding box. Reads the module-level ``args``.
    """
    def __init__(self):
        self.axis = 1
        super(NMEMetric2, self).__init__(
            'NME2', axis=self.axis,
            output_names=None, label_names=None)
        #self.losses = []
        # Number of update() calls so far.
        self.count = 0
    def update(self, labels, preds):
        """Accumulate the metric using only the last prediction output."""
        self.count+=1
        # Only the last network output (the heatmap) is evaluated.
        preds = [preds[-1]]
        for label, pred_label in zip(labels, preds):
            label = label.asnumpy()
            pred_label = pred_label.asnumpy()
            #print('label', np.count_nonzero(label[0][36]))
            #print('acc',label.shape, pred_label.shape)
            #print(label.ndim)
            nme = []
            for b in xrange(pred_label.shape[0]):
                # record[0..3]: ground truth of landmarks 36, 39, 42, 45
                # (used below as the two corners of each eye);
                # record[4] / record[5]: running element-wise min / max of
                # all ground-truth points (bounding box corners).
                record = [None]*6
                item = []
                # Skip samples whose label is all zero (for heatmap labels
                # only channel 36 is inspected).
                if label.ndim==4:
                    _heatmap = label[b][36]
                    if np.count_nonzero(_heatmap)==0:
                        continue
                else:#ndim==3
                    #print(label[b])
                    if np.count_nonzero(label[b])==0:
                        continue
                for p in xrange(pred_label.shape[1]):
                    if label.ndim==4:
                        heatmap_gt = label[b][p]
                        ind_gt = np.unravel_index(np.argmax(heatmap_gt, axis=None), heatmap_gt.shape)
                        ind_gt = np.array(ind_gt)
                    else:
                        ind_gt = label[b][p]
                        #ind_gt = ind_gt.astype(np.int)
                        #print(ind_gt)
                    heatmap_pred = pred_label[b][p]
                    # NOTE(review): the prediction is resized to the input
                    # image size while 4-D labels are used at their own
                    # resolution; presumably both are expected to match --
                    # confirm against the data iterator.
                    heatmap_pred = cv2.resize(heatmap_pred, (args.input_img_size, args.input_img_size))
                    ind_pred = np.unravel_index(np.argmax(heatmap_pred, axis=None), heatmap_pred.shape)
                    ind_pred = np.array(ind_pred)
                    #print(ind_gt.shape)
                    #print(ind_pred)
                    if p==36:
                        #print('b', b, p, ind_gt, np.count_nonzero(heatmap_gt))
                        record[0] = ind_gt
                    elif p==39:
                        record[1] = ind_gt
                    elif p==42:
                        record[2] = ind_gt
                    elif p==45:
                        record[3] = ind_gt
                    # Grow the ground-truth bounding box.
                    if record[4] is None or record[5] is None:
                        record[4] = ind_gt
                        record[5] = ind_gt
                    else:
                        record[4] = np.minimum(record[4], ind_gt)
                        record[5] = np.maximum(record[5], ind_gt)
                    #print(ind_gt.shape, ind_pred.shape)
                    value = np.sqrt(np.sum(np.square(ind_gt - ind_pred)))
                    item.append(value)
                _nme = np.mean(item)
                # NOTE(review): _dist is not checked for zero before dividing.
                if args.norm_type=='2d':
                    # Normalise by the distance between the two eye centres.
                    left_eye = (record[0]+record[1])/2
                    right_eye = (record[2]+record[3])/2
                    _dist = np.sqrt(np.sum(np.square(left_eye - right_eye)))
                    #print('eye dist', _dist, left_eye, right_eye)
                    _nme /= _dist
                else:
                    # Normalise by the diagonal of the ground-truth bounding box.
                    #_dist = np.sqrt(float(label.shape[2]*label.shape[3]))
                    _dist = np.sqrt(np.sum(np.square(record[5] - record[4])))
                    #print(_dist)
                    _nme /= _dist
                nme.append(_nme)
            #print('nme', nme)
            #nme = np.mean(nme)
            # Batches in which every sample was skipped do not count.
            if len(nme)>0:
                self.sum_metric += np.mean(nme)
                self.num_inst += 1.0
def main(args):
    """Train the landmark heatmap network and periodically validate and checkpoint it.

    The training set is read from ``train.rec`` under ``args.data_dir``. Every
    ``args.verbose`` batches the NME2 metric is printed for each validation
    target whose ``.rec`` file exists in ``args.data_dir`` and, when
    ``args.ckpt > 0``, a checkpoint is written under ``args.prefix``.

    Parameters
    ----------
    args : argparse.Namespace
        Parsed command-line options. Besides being read, it is mutated:
        ``args.ctx_num`` and ``args.batch_size`` are set from the number of
        devices in use.
    """
    _seed = 727
    random.seed(_seed)
    np.random.seed(_seed)
    mx.random.seed(_seed)

    # One GPU context per entry of CUDA_VISIBLE_DEVICES. Using .get() lets an
    # unset variable fall through to the CPU path instead of raising KeyError.
    cvd = os.environ.get('CUDA_VISIBLE_DEVICES', '').strip()
    ctx = []
    if len(cvd) > 0:
        ctx = [mx.gpu(i) for i in range(len(cvd.split(',')))]
    if len(ctx) == 0:
        ctx = [mx.cpu()]
        print('use cpu')
    else:
        print('gpu num:', len(ctx))

    args.ctx_num = len(ctx)
    args.batch_size = args.per_batch_size * args.ctx_num
    print('Call with', args)

    train_iter = FaceSegIter(
        path_imgrec=os.path.join(args.data_dir, 'train.rec'),
        batch_size=args.batch_size,
        per_batch_size=args.per_batch_size,
        aug_level=1,
        use_coherent=args.use_coherent,
        args=args,
    )
    # Validation sets; any whose .rec file is missing is skipped in val_test().
    targets = ['ibug', 'cofw_testset', '300W', 'AFLW2000-3D']

    data_shape = train_iter.get_data_shape()
    sym = hg.get_symbol(
        num_classes=args.num_classes,
        binarize=args.binarize,
        label_size=args.output_label_size,
        input_size=args.input_img_size,
        use_coherent=args.use_coherent,
        use_dla=args.use_dla,
        use_N=args.use_N,
        use_DCN=args.use_DCN,
        per_batch_size=args.per_batch_size,
    )
    if len(args.pretrained) == 0:
        data_shape_dict = train_iter.get_shape_dict()
        arg_params, aux_params = hg.init_weights(sym, data_shape_dict)
    else:
        # --pretrained is given as '<prefix>,<epoch>'.
        vec = args.pretrained.split(',')
        print('loading', vec)
        _, arg_params, aux_params = mx.model.load_checkpoint(vec[0], int(vec[1]))

    model = mx.mod.Module(
        context=ctx,
        symbol=sym,
        label_names=train_iter.get_label_names(),
    )

    lr = args.lr
    _rescale_grad = 1.0 / args.ctx_num
    opt = optimizer.Nadam(learning_rate=lr, wd=args.wd,
                          rescale_grad=_rescale_grad, clip_gradient=5.0)
    initializer = mx.init.Xavier(rnd_type='gaussian', factor_type="in", magnitude=2)
    _cb = mx.callback.Speedometer(args.batch_size, 10)
    _metric = LossValueMetric()
    eval_metrics = [_metric]

    if len(args.lr_steps) == 0:
        lr_steps = [16000, 24000, 30000]
    else:
        lr_steps = [int(x) for x in args.lr_steps.split(',')]
    # The step values are scaled by 40 // batch_size. Without the floor of 1 a
    # batch size above 40 would turn every step into 0 and, because the batch
    # counter starts at 1, silently disable learning-rate decay.
    _a = max(1, 40 // args.batch_size)
    lr_steps = [step * _a for step in lr_steps]
    print('lr-steps', lr_steps)

    # Single-element list so the nested callbacks can mutate the counter.
    global_step = [0]

    def val_test():
        """Print the NME2 metric for every available validation target."""
        all_layers = sym.get_internals()
        vsym = all_layers['heatmap_output']
        vmodel = mx.mod.Module(symbol=vsym, context=ctx, label_names = None)
        vmodel.bind(data_shapes=[('data', (args.batch_size,)+data_shape)])
        arg_params, aux_params = model.get_params()
        vmodel.set_params(arg_params, aux_params)
        # NOTE(review): vmodel is prepared above, but the forward pass and the
        # metric update below go through the training `model` -- confirm
        # whether that is intended before changing either side.
        for target in targets:
            _file = os.path.join(args.data_dir, '%s.rec'%target)
            if not os.path.exists(_file):
                continue
            val_iter = FaceSegIter(
                path_imgrec=_file,
                batch_size=args.batch_size,
                aug_level=0,
                args=args,
            )
            _metric = NMEMetric2()
            val_metric = mx.metric.create(_metric)
            val_metric.reset()
            val_iter.reset()
            for i, eval_batch in enumerate(val_iter):
                # Forward the data only; labels are passed straight to the metric.
                batch_data = mx.io.DataBatch(eval_batch.data)
                model.forward(batch_data, is_train=False)
                model.update_metric(val_metric, eval_batch.label)
            nme_value = val_metric.get_name_value()[0][1]
            print('[%d][%s]NME: %f'%(global_step[0], target, nme_value))

    def _batch_callback(param):
        """Per-batch hook: speed logging, LR decay, validation and checkpointing."""
        _cb(param)
        global_step[0] += 1
        mbatch = global_step[0]
        # Multiply the learning rate by 0.2 whenever a scheduled step is reached.
        for _lr in lr_steps:
            if mbatch == _lr:
                opt.lr *= 0.2
                print('lr change to', opt.lr)
                break
        if mbatch % 1000 == 0:
            print('lr-batch-epoch:',opt.lr,param.nbatch,param.epoch)
        if mbatch > 0 and mbatch % args.verbose == 0:
            val_test()
            if args.ckpt > 0:
                # Checkpoints are numbered by validation round, not by epoch.
                msave = mbatch // args.verbose
                print('saving', msave)
                arg, aux = model.get_params()
                mx.model.save_checkpoint(args.prefix, msave, model.symbol, arg, aux)

    model.fit(
        train_iter,
        begin_epoch=0,
        num_epoch=args.end_epoch,
        eval_data=None,
        eval_metric=eval_metrics,
        kvstore='device',
        optimizer=opt,
        initializer=initializer,
        arg_params=arg_params,
        aux_params=aux_params,
        allow_missing=True,
        batch_end_callback=_batch_callback,
        epoch_end_callback=None,
    )
if __name__ == '__main__':
    # Script entry point: build the command-line interface, parse it and
    # start training. `args` is left as a module-level name on purpose.
    parser = argparse.ArgumentParser(description='Train face 3d')

    # (flag, add_argument keyword options). The order is kept as-is so the
    # generated --help output does not change.
    _cli_options = [
        # general
        ('--data-dir', dict(default='./data', help='')),
        ('--prefix', dict(default='./models/test',
                          help='directory to save model.')),
        ('--pretrained', dict(default='', help='')),
        ('--lr-steps', dict(default='', type=str, help='')),
        ('--verbose', dict(type=int, default=2000, help='')),
        ('--retrain', dict(action='store_true', default=False,
                           help='true means continue training.')),
        ('--binarize', dict(action='store_true', default=False, help='')),
        ('--end-epoch', dict(type=int, default=87,
                             help='training epoch size.')),
        ('--per-batch-size', dict(type=int, default=20, help='')),
        ('--num-classes', dict(type=int, default=68, help='')),
        ('--input-img-size', dict(type=int, default=128, help='')),
        ('--output-label-size', dict(type=int, default=64, help='')),
        ('--lr', dict(type=float, default=2.5e-4, help='')),
        ('--wd', dict(type=float, default=5e-4, help='')),
        ('--ckpt', dict(type=int, default=1, help='')),
        ('--norm-type', dict(type=str, default='2d', help='')),
        ('--use-coherent', dict(type=int, default=1, help='')),
        ('--use-dla', dict(type=int, default=1, help='')),
        ('--use-N', dict(type=int, default=3, help='')),
        ('--use-DCN', dict(type=int, default=2, help='')),
    ]
    for _flag, _options in _cli_options:
        parser.add_argument(_flag, **_options)

    args = parser.parse_args()
    main(args)