mirror of
https://github.com/deepinsight/insightface.git
synced 2026-05-23 18:27:49 +00:00
68points face alignment
This commit is contained in:
65
alignment/benchmark.py
Normal file
65
alignment/benchmark.py
Normal file
@@ -0,0 +1,65 @@
|
||||
import argparse
|
||||
import cv2
|
||||
import numpy as np
|
||||
import sys
|
||||
import mxnet as mx
|
||||
import datetime
|
||||
|
||||
parser = argparse.ArgumentParser(description='face model test')
|
||||
# general
|
||||
parser.add_argument('--image-size', default='128,128', help='')
|
||||
parser.add_argument('--model', default='./models/test,15', help='path to load model.')
|
||||
parser.add_argument('--gpu', default=0, type=int, help='gpu id')
|
||||
parser.add_argument('--batch-size', default=10, type=int, help='batch size')
|
||||
parser.add_argument('--iterations', default=10, type=int, help='iterations')
|
||||
args = parser.parse_args()
|
||||
|
||||
_vec = args.image_size.split(',')
|
||||
assert len(_vec)==2
|
||||
image_size = (int(_vec[0]), int(_vec[1]))
|
||||
_vec = args.model.split(',')
|
||||
assert len(_vec)==2
|
||||
prefix = _vec[0]
|
||||
epoch = int(_vec[1])
|
||||
print('loading',prefix, epoch)
|
||||
if args.gpu>=0:
|
||||
ctx = mx.gpu(args.gpu)
|
||||
else:
|
||||
ctx = mx.cpu()
|
||||
sym, arg_params, aux_params = mx.model.load_checkpoint(prefix, epoch)
|
||||
all_layers = sym.get_internals()
|
||||
sym = all_layers['heatmap_output']
|
||||
model = mx.mod.Module(symbol=sym, context=ctx, label_names = None)
|
||||
#model = mx.mod.Module(symbol=sym, context=ctx)
|
||||
model.bind(for_training=False, data_shapes=[('data', (args.batch_size, 3, image_size[0], image_size[1]))])
|
||||
#model.bind(for_training=False, data_shapes=[('data', (args.batch_size, 3, image_size[0], image_size[1]))], label_shapes=[('softmax_label', (args.batch_size,84,64,64))])
|
||||
model.set_params(arg_params, aux_params)
|
||||
img_path = './test.png'
|
||||
|
||||
img = cv2.imread(img_path)
|
||||
|
||||
rimg = cv2.resize(img, (image_size[1], image_size[0]))
|
||||
img = cv2.cvtColor(rimg, cv2.COLOR_BGR2RGB)
|
||||
img = np.transpose(img, (2,0,1)) #3*112*112, RGB
|
||||
input_blob = np.zeros( (args.batch_size, 3, image_size[1], image_size[0]),dtype=np.uint8 )
|
||||
for i in xrange(args.batch_size):
|
||||
input_blob[i] = img
|
||||
data = mx.nd.array(input_blob)
|
||||
print(data.shape)
|
||||
label = mx.nd.zeros( (args.batch_size, 84, 64, 64) )
|
||||
#db = mx.io.DataBatch(data=(data,))
|
||||
db = mx.io.DataBatch(data=(data,), label=(label,))
|
||||
stat = []
|
||||
warmup = 2
|
||||
for i in xrange(args.iterations+warmup):
|
||||
#print(i)
|
||||
time_now = datetime.datetime.now()
|
||||
model.forward(db, is_train=False)
|
||||
output = model.get_outputs()[-1].asnumpy()
|
||||
time_now2 = datetime.datetime.now()
|
||||
diff = time_now2 - time_now
|
||||
stat.append(diff.total_seconds())
|
||||
stat = stat[warmup:]
|
||||
print(np.mean(stat)/args.batch_size)
|
||||
|
||||
|
||||
545
alignment/data.py
Normal file
545
alignment/data.py
Normal file
@@ -0,0 +1,545 @@
|
||||
# pylint: skip-file
|
||||
import mxnet as mx
|
||||
import numpy as np
|
||||
import sys, os
|
||||
import random
|
||||
import math
|
||||
import scipy.misc
|
||||
import cv2
|
||||
import logging
|
||||
import sklearn
|
||||
import datetime
|
||||
import img_helper
|
||||
from mxnet.io import DataIter
|
||||
from mxnet import ndarray as nd
|
||||
from mxnet import io
|
||||
from mxnet import recordio
|
||||
from PIL import Image
|
||||
|
||||
|
||||
class FaceSegIter0(DataIter):
|
||||
def __init__(self, batch_size,
|
||||
path_imgrec = None,
|
||||
data_name = "data",
|
||||
label_name = "softmax_label"):
|
||||
self.batch_size = batch_size
|
||||
self.data_name = data_name
|
||||
self.label_name = label_name
|
||||
assert path_imgrec
|
||||
logging.info('loading recordio %s...',
|
||||
path_imgrec)
|
||||
path_imgidx = path_imgrec[0:-4]+".idx"
|
||||
self.imgrec = mx.recordio.MXIndexedRecordIO(path_imgidx, path_imgrec, 'r') # pylint: disable=redefined-variable-type
|
||||
self.seq = list(self.imgrec.keys)
|
||||
self.cur = 0
|
||||
self.reset()
|
||||
self.num_classes = 68
|
||||
self.record_img_size = 384
|
||||
self.input_img_size = self.record_img_size
|
||||
self.data_shape = (3, self.input_img_size, self.input_img_size)
|
||||
self.label_shape = (self.num_classes, 2)
|
||||
self.provide_data = [(data_name, (batch_size,) + self.data_shape)]
|
||||
self.provide_label = [(label_name, (batch_size,) + self.label_shape)]
|
||||
|
||||
def reset(self):
|
||||
print('reset')
|
||||
"""Resets the iterator to the beginning of the data."""
|
||||
self.cur = 0
|
||||
|
||||
def next_sample(self):
|
||||
"""Helper function for reading in next sample."""
|
||||
if self.cur >= len(self.seq):
|
||||
raise StopIteration
|
||||
idx = self.seq[self.cur]
|
||||
self.cur += 1
|
||||
s = self.imgrec.read_idx(idx)
|
||||
header, img = recordio.unpack(s)
|
||||
img = mx.image.imdecode(img).asnumpy()
|
||||
#label = np.zeros( (self.num_classes, self.record_img_size, self.record_img_size), dtype=np.uint8)
|
||||
hlabel = np.array(header.label).reshape( (self.num_classes,2) )
|
||||
return img, hlabel
|
||||
|
||||
def next(self):
|
||||
"""Returns the next batch of data."""
|
||||
#print('next')
|
||||
batch_size = self.batch_size
|
||||
batch_data = nd.empty((batch_size,)+self.data_shape)
|
||||
batch_label = nd.empty((batch_size,)+self.label_shape)
|
||||
i = 0
|
||||
#self.cutoff = random.randint(800,1280)
|
||||
try:
|
||||
while i < batch_size:
|
||||
#print('N', i)
|
||||
data, label = self.next_sample()
|
||||
data = nd.array(data)
|
||||
data = nd.transpose(data, axes=(2, 0, 1))
|
||||
label = nd.array(label)
|
||||
#print(data.shape, label.shape)
|
||||
batch_data[i][:] = data
|
||||
batch_label[i][:] = label
|
||||
i += 1
|
||||
except StopIteration:
|
||||
if not i:
|
||||
raise StopIteration
|
||||
|
||||
return mx.io.DataBatch([batch_data], [batch_label], batch_size - i)
|
||||
|
||||
class FaceSegIter(DataIter):
|
||||
def __init__(self, batch_size,
|
||||
per_batch_size = 0,
|
||||
path_imgrec = None,
|
||||
aug_level = 0,
|
||||
force_mirror = False,
|
||||
use_coherent = 0,
|
||||
args = None,
|
||||
data_name = "data",
|
||||
label_name = "softmax_label"):
|
||||
self.aug_level = aug_level
|
||||
self.force_mirror = force_mirror
|
||||
self.use_coherent = use_coherent
|
||||
self.batch_size = batch_size
|
||||
self.per_batch_size = per_batch_size
|
||||
self.data_name = data_name
|
||||
self.label_name = label_name
|
||||
assert path_imgrec
|
||||
logging.info('loading recordio %s...',
|
||||
path_imgrec)
|
||||
path_imgidx = path_imgrec[0:-4]+".idx"
|
||||
self.imgrec = mx.recordio.MXIndexedRecordIO(path_imgidx, path_imgrec, 'r') # pylint: disable=redefined-variable-type
|
||||
self.seq = list(self.imgrec.keys)
|
||||
print('train size', len(self.seq))
|
||||
self.cur = 0
|
||||
self.reset()
|
||||
self.num_classes = args.num_classes
|
||||
self.record_img_size = 384
|
||||
self.input_img_size = args.input_img_size
|
||||
self.data_shape = (3, self.input_img_size, self.input_img_size)
|
||||
self.label_classes = self.num_classes
|
||||
if aug_level>0:
|
||||
self.output_label_size = args.output_label_size
|
||||
self.label_shape = (self.label_classes, self.output_label_size, self.output_label_size)
|
||||
else:
|
||||
self.output_label_size = args.input_img_size
|
||||
#self.label_shape = (self.num_classes, 2)
|
||||
self.label_shape = (self.num_classes, self.output_label_size, self.output_label_size)
|
||||
self.provide_data = [(data_name, (batch_size,) + self.data_shape)]
|
||||
self.provide_label = [(label_name, (batch_size,) + self.label_shape)]
|
||||
if self.use_coherent>0:
|
||||
#self.provide_label += [("softmax_label2", (batch_size,)+self.label_shape)]
|
||||
self.provide_label += [("coherent_label", (batch_size,6))]
|
||||
self.img_num = 0
|
||||
self.invalid_num = 0
|
||||
self.mode = 1
|
||||
self.vis = 0
|
||||
self.stats = [0,0]
|
||||
self.mirror_set = [
|
||||
(22,23),
|
||||
(21,24),
|
||||
(20,25),
|
||||
(19,26),
|
||||
(18,27),
|
||||
(40,43),
|
||||
(39,44),
|
||||
(38,45),
|
||||
(37,46),
|
||||
(42,47),
|
||||
(41,48),
|
||||
(33,35),
|
||||
(32,36),
|
||||
(51,53),
|
||||
(50,54),
|
||||
(62,64),
|
||||
(61,65),
|
||||
(49,55),
|
||||
(49,55),
|
||||
(68,66),
|
||||
(60,56),
|
||||
(59,57),
|
||||
(1,17),
|
||||
(2,16),
|
||||
(3,15),
|
||||
(4,14),
|
||||
(5,13),
|
||||
(6,12),
|
||||
(7,11),
|
||||
(8,10),
|
||||
]
|
||||
|
||||
def get_data_shape(self):
|
||||
return self.data_shape
|
||||
|
||||
#def get_label_shape(self):
|
||||
# return self.label_shape
|
||||
|
||||
def get_shape_dict(self):
|
||||
D = {}
|
||||
for (k,v) in self.provide_data:
|
||||
D[k] = v
|
||||
for (k,v) in self.provide_label:
|
||||
D[k] = v
|
||||
return D
|
||||
|
||||
def get_label_names(self):
|
||||
D = []
|
||||
for (k,v) in self.provide_label:
|
||||
D.append(k)
|
||||
return D
|
||||
|
||||
def reset(self):
|
||||
print('reset')
|
||||
"""Resets the iterator to the beginning of the data."""
|
||||
if self.aug_level>0:
|
||||
random.shuffle(self.seq)
|
||||
self.cur = 0
|
||||
|
||||
def next_sample(self):
|
||||
"""Helper function for reading in next sample."""
|
||||
if self.cur >= len(self.seq):
|
||||
raise StopIteration
|
||||
idx = self.seq[self.cur]
|
||||
self.cur += 1
|
||||
s = self.imgrec.read_idx(idx)
|
||||
header, img = recordio.unpack(s)
|
||||
img = mx.image.imdecode(img).asnumpy()
|
||||
#label = np.zeros( (self.num_classes, self.record_img_size, self.record_img_size), dtype=np.uint8)
|
||||
hlabel = np.array(header.label).reshape( (self.num_classes,2) )
|
||||
annot = {}
|
||||
ul = np.array( (50000,50000), dtype=np.int32)
|
||||
br = np.array( (0,0), dtype=np.int32)
|
||||
for i in xrange(hlabel.shape[0]):
|
||||
#hlabel[i] = hlabel[i][::-1]
|
||||
h = int(hlabel[i][0])
|
||||
w = int(hlabel[i][1])
|
||||
key = np.array((h,w))
|
||||
#print(key.shape, ul.shape, br.shape)
|
||||
ul = np.minimum(key, ul)
|
||||
br = np.maximum(key, br)
|
||||
#label[h][w] = i+1
|
||||
#label[i][h][w] = 1.0
|
||||
#print(h,w,i+1)
|
||||
|
||||
if self.mode==0:
|
||||
ul_margin = np.array( (60,30) )
|
||||
br_margin = np.array( (30,30) )
|
||||
crop_ul = ul
|
||||
crop_ul = ul-ul_margin
|
||||
crop_ul = np.maximum(crop_ul, np.array( (0,0), dtype=np.int32) )
|
||||
crop_br = br
|
||||
crop_br = br+br_margin
|
||||
crop_br = np.minimum(crop_br, np.array( (img.shape[0],img.shape[1]), dtype=np.int32 ) )
|
||||
elif self.mode==1:
|
||||
#mm = (self.record_img_size - 256)//2
|
||||
#crop_ul = (mm, mm)
|
||||
#crop_br = (self.record_img_size-mm, self.record_img_size-mm)
|
||||
crop_ul = (0,0)
|
||||
crop_br = (self.record_img_size, self.record_img_size)
|
||||
annot['scale'] = 256
|
||||
#annot['center'] = np.array( (self.record_img_size/2+10, self.record_img_size/2) )
|
||||
else:
|
||||
mm = (48, 64)
|
||||
#mm = (64, 80)
|
||||
crop_ul = mm
|
||||
crop_br = (self.record_img_size-mm[0], self.record_img_size-mm[1])
|
||||
crop_ul = np.array(crop_ul)
|
||||
crop_br = np.array(crop_br)
|
||||
img = img[crop_ul[0]:crop_br[0],crop_ul[1]:crop_br[1],:]
|
||||
#print(img.shape, crop_ul, crop_br)
|
||||
invalid = False
|
||||
for i in xrange(hlabel.shape[0]):
|
||||
if (hlabel[i]<crop_ul).any() or (hlabel[i] >= crop_br).any():
|
||||
invalid = True
|
||||
hlabel[i] -= crop_ul
|
||||
|
||||
#mm = np.amin(ul)
|
||||
#mm2 = self.record_img_size - br
|
||||
#mm2 = np.amin(mm2)
|
||||
#mm = min(mm, mm2)
|
||||
#print('mm',mm, ul, br)
|
||||
#print('invalid', invalid)
|
||||
if invalid:
|
||||
self.invalid_num+=1
|
||||
|
||||
annot['invalid'] = invalid
|
||||
#annot['scale'] = (self.record_img_size - mm*2)*1.1
|
||||
#if self.mode==1:
|
||||
# annot['scale'] = 256
|
||||
#print(annot)
|
||||
|
||||
#center = ul+br
|
||||
#center /= 2.0
|
||||
#annot['center'] = center
|
||||
|
||||
#img = img[ul[0]:br[0],ul[1]:br[1],:]
|
||||
#scale = br-ul
|
||||
#scale = max(scale[0], scale[1])
|
||||
#print(img.shape)
|
||||
return img, hlabel, annot
|
||||
|
||||
def do_aug(self, data, label, annot):
|
||||
if self.vis:
|
||||
self.img_num+=1
|
||||
#if self.img_num<=self.vis:
|
||||
# filename = './vis/raw_%d.jpg' % (self.img_num)
|
||||
# print('save', filename)
|
||||
# draw = data.copy()
|
||||
# for i in xrange(label.shape[0]):
|
||||
# cv2.circle(draw, (label[i][1], label[i][0]), 1, (0, 0, 255), 2)
|
||||
# scipy.misc.imsave(filename, draw)
|
||||
|
||||
rotate = 0
|
||||
#scale = 1.0
|
||||
if 'scale' in annot:
|
||||
scale = annot['scale']
|
||||
else:
|
||||
scale = max(data.shape[0], data.shape[1])
|
||||
if 'center' in annot:
|
||||
center = annot['center']
|
||||
else:
|
||||
center = np.array( (data.shape[0]/2, data.shape[1]/2) )
|
||||
max_retry = 3
|
||||
if self.aug_level==0:
|
||||
max_retry = 6
|
||||
retry = 0
|
||||
found = False
|
||||
_scale = scale
|
||||
while retry<max_retry:
|
||||
retry+=1
|
||||
succ = True
|
||||
if self.aug_level>0:
|
||||
rotate = np.random.randint(-40, 40)
|
||||
#rotate2 = np.random.randint(-40, 40)
|
||||
rotate2 = 0
|
||||
scale_config = 0.2
|
||||
#rotate = 0
|
||||
#scale_config = 0.0
|
||||
_scale = min(1+scale_config, max(1-scale_config, (np.random.randn() * scale_config) + 1))
|
||||
_scale *= scale
|
||||
_scale = int(_scale)
|
||||
#translate = np.random.randint(-5, 5, size=(2,))
|
||||
#center += translate
|
||||
if self.mode==1:
|
||||
cropped = img_helper.crop2(data, center, _scale, (self.input_img_size, self.input_img_size), rot=rotate)
|
||||
if self.use_coherent==2:
|
||||
cropped2 = img_helper.crop2(data, center, _scale, (self.input_img_size, self.input_img_size), rot=rotate2)
|
||||
else:
|
||||
cropped = img_helper.crop(data, center, _scale, (self.input_img_size, self.input_img_size), rot=rotate)
|
||||
#print('cropped', cropped.shape)
|
||||
label_out = np.zeros(self.label_shape, dtype=np.float32)
|
||||
label2_out = np.zeros(self.label_shape, dtype=np.float32)
|
||||
G = 0
|
||||
#if self.use_coherent:
|
||||
# G = 1
|
||||
_g = G
|
||||
if G==0:
|
||||
_g = 1
|
||||
#print('shape', label.shape, label_out.shape)
|
||||
for i in xrange(label.shape[0]):
|
||||
pt = label[i].copy()
|
||||
pt = pt[::-1]
|
||||
#print('before gaussian', label_out[i].shape, pt.shape)
|
||||
_pt = pt.copy()
|
||||
trans = img_helper.transform(_pt, center, _scale, (self.output_label_size, self.output_label_size), rot=rotate)
|
||||
#print(trans.shape)
|
||||
if not img_helper.gaussian(label_out[i], trans, _g):
|
||||
succ = False
|
||||
break
|
||||
if self.use_coherent==2:
|
||||
_pt = pt.copy()
|
||||
trans2 = img_helper.transform(_pt, center, _scale, (self.output_label_size, self.output_label_size), rot=rotate2)
|
||||
if not img_helper.gaussian(label2_out[i], trans2, _g):
|
||||
succ = False
|
||||
break
|
||||
if not succ:
|
||||
if self.aug_level==0:
|
||||
_scale+=20
|
||||
continue
|
||||
|
||||
|
||||
if self.use_coherent==1:
|
||||
cropped2 = np.copy(cropped)
|
||||
for k in xrange(cropped2.shape[2]):
|
||||
cropped2[:,:,k] = np.fliplr(cropped2[:,:,k])
|
||||
label2_out = np.copy(label_out)
|
||||
for k in xrange(label2_out.shape[0]):
|
||||
label2_out[k,:,:] = np.fliplr(label2_out[k,:,:])
|
||||
new_label2_out = np.copy(label2_out)
|
||||
for item in self.mirror_set:
|
||||
mir = (item[0]-1, item[1]-1)
|
||||
new_label2_out[mir[1]] = label2_out[mir[0]]
|
||||
new_label2_out[mir[0]] = label2_out[mir[1]]
|
||||
label2_out = new_label2_out
|
||||
elif self.use_coherent==2:
|
||||
pass
|
||||
elif ((self.aug_level>0 and np.random.rand() < 0.5) or self.force_mirror): #flip aug
|
||||
for k in xrange(cropped.shape[2]):
|
||||
cropped[:,:,k] = np.fliplr(cropped[:,:,k])
|
||||
for k in xrange(label_out.shape[0]):
|
||||
label_out[k,:,:] = np.fliplr(label_out[k,:,:])
|
||||
new_label_out = np.copy(label_out)
|
||||
for item in self.mirror_set:
|
||||
mir = (item[0]-1, item[1]-1)
|
||||
new_label_out[mir[1]] = label_out[mir[0]]
|
||||
new_label_out[mir[0]] = label_out[mir[1]]
|
||||
label_out = new_label_out
|
||||
|
||||
if G==0:
|
||||
for k in xrange(label.shape[0]):
|
||||
ind = np.unravel_index(np.argmax(label_out[k], axis=None), label_out[k].shape)
|
||||
label_out[k,:,:] = 0.0
|
||||
label_out[k,ind[0],ind[1]] = 1.0
|
||||
if self.use_coherent:
|
||||
ind = np.unravel_index(np.argmax(label2_out[k], axis=None), label2_out[k].shape)
|
||||
label2_out[k,:,:] = 0.0
|
||||
label2_out[k,ind[0],ind[1]] = 1.0
|
||||
found = True
|
||||
break
|
||||
|
||||
|
||||
#self.stats[0]+=1
|
||||
if not found:
|
||||
#self.stats[1]+=1
|
||||
#print('find aug error', retry)
|
||||
#print(self.stats)
|
||||
return None
|
||||
|
||||
if self.vis>0 and self.img_num<=self.vis:
|
||||
print('crop', data.shape, center, _scale, rotate, cropped.shape)
|
||||
filename = './vis/cropped_%d.jpg' % (self.img_num)
|
||||
print('save', filename)
|
||||
draw = cropped.copy()
|
||||
alabel = label_out.copy()
|
||||
for i in xrange(label.shape[0]):
|
||||
a = cv2.resize(alabel[i], (self.input_img_size, self.input_img_size))
|
||||
ind = np.unravel_index(np.argmax(a, axis=None), a.shape)
|
||||
cv2.circle(draw, (ind[1], ind[0]), 1, (0, 0, 255), 2)
|
||||
scipy.misc.imsave(filename, draw)
|
||||
if not self.use_coherent:
|
||||
return cropped, label_out
|
||||
else:
|
||||
rotate2 = 0
|
||||
r = rotate - rotate2
|
||||
#r = rotate2 - rotate
|
||||
r = math.pi*r/180
|
||||
cos_r = math.cos(r)
|
||||
sin_r = math.sin(r)
|
||||
#c = cropped2.shape[0]//2
|
||||
#M = cv2.getRotationMatrix2D( (c, c), rotate2-rotate, 1)
|
||||
M = np.array( [
|
||||
[cos_r, -1*sin_r, 0.0],
|
||||
[sin_r, cos_r, 0.0]
|
||||
] )
|
||||
#print(M)
|
||||
#M=np.array([16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0,
|
||||
# 26, 25, 24, 23, 22, 21, 20, 19, 18, 17, 27, 28, 29, 30, 35, 34, 33,
|
||||
# 32, 31, 45, 44, 43, 42, 47, 46, 39, 38, 37, 36, 41, 40, 54, 53, 52,
|
||||
# 51, 50, 49, 48, 59, 58, 57, 56, 55, 64, 63, 62, 61, 60, 67, 66, 65])
|
||||
return cropped, label_out, cropped2, label2_out, M.flatten()
|
||||
|
||||
def next(self):
|
||||
"""Returns the next batch of data."""
|
||||
#print('next')
|
||||
batch_size = self.batch_size
|
||||
batch_data = nd.empty((batch_size,)+self.data_shape)
|
||||
batch_label = nd.empty((batch_size,)+self.label_shape)
|
||||
if self.use_coherent:
|
||||
batch_label2 = nd.empty((batch_size,)+self.label_shape)
|
||||
batch_coherent_label = nd.empty((batch_size,6))
|
||||
i = 0
|
||||
#self.cutoff = random.randint(800,1280)
|
||||
try:
|
||||
while i < batch_size:
|
||||
#print('N', i)
|
||||
data, label, annot = self.next_sample()
|
||||
if not self.use_coherent:
|
||||
R = self.do_aug(data, label, annot)
|
||||
if R is None:
|
||||
continue
|
||||
data, label = R
|
||||
#data, label, data2, label2, M = R
|
||||
#ind = np.unravel_index(np.argmax(label[0], axis=None), label[0].shape)
|
||||
#print(label.shape, np.count_nonzero(label[0]), ind)
|
||||
#print(label[0,25:35,0:10])
|
||||
data = nd.array(data)
|
||||
data = nd.transpose(data, axes=(2, 0, 1))
|
||||
label = nd.array(label)
|
||||
#print(data.shape, label.shape)
|
||||
try:
|
||||
self.check_valid_image(data)
|
||||
except RuntimeError as e:
|
||||
logging.debug('Invalid image, skipping: %s', str(e))
|
||||
continue
|
||||
batch_data[i][:] = data
|
||||
batch_label[i][:] = label
|
||||
i += 1
|
||||
else:
|
||||
R = self.do_aug(data, label, annot)
|
||||
if R is None:
|
||||
continue
|
||||
data, label, data2, label2, M = R
|
||||
data = nd.array(data)
|
||||
data = nd.transpose(data, axes=(2, 0, 1))
|
||||
label = nd.array(label)
|
||||
data2 = nd.array(data2)
|
||||
data2 = nd.transpose(data2, axes=(2, 0, 1))
|
||||
label2 = nd.array(label2)
|
||||
M = nd.array(M)
|
||||
#print(data.shape, label.shape)
|
||||
try:
|
||||
self.check_valid_image(data)
|
||||
except RuntimeError as e:
|
||||
logging.debug('Invalid image, skipping: %s', str(e))
|
||||
continue
|
||||
batch_data[i][:] = data
|
||||
batch_label[i][:] = label
|
||||
#batch_label2[i][:] = label2
|
||||
batch_coherent_label[i][:] = M
|
||||
#i+=1
|
||||
j = i+self.per_batch_size//2
|
||||
batch_data[j][:] = data2
|
||||
batch_label[j][:] = label2
|
||||
batch_coherent_label[j][:] = M
|
||||
i += 1
|
||||
if j%self.per_batch_size==self.per_batch_size-1:
|
||||
i = j+1
|
||||
except StopIteration:
|
||||
if not i:
|
||||
raise StopIteration
|
||||
|
||||
#return {self.data_name : batch_data,
|
||||
# self.label_name : batch_label}
|
||||
#print(batch_data.shape, batch_label.shape)
|
||||
if not self.use_coherent:
|
||||
return mx.io.DataBatch([batch_data], [batch_label], batch_size - i)
|
||||
else:
|
||||
#return mx.io.DataBatch([batch_data], [batch_label, batch_label2, batch_coherent_label], batch_size - i)
|
||||
return mx.io.DataBatch([batch_data], [batch_label, batch_coherent_label], batch_size - i)
|
||||
|
||||
def check_data_shape(self, data_shape):
|
||||
"""Checks if the input data shape is valid"""
|
||||
if not len(data_shape) == 3:
|
||||
raise ValueError('data_shape should have length 3, with dimensions CxHxW')
|
||||
if not data_shape[0] == 3:
|
||||
raise ValueError('This iterator expects inputs to have 3 channels.')
|
||||
|
||||
def check_valid_image(self, data):
|
||||
"""Checks if the input data is valid"""
|
||||
if len(data[0].shape) == 0:
|
||||
raise RuntimeError('Data shape is wrong')
|
||||
|
||||
def imdecode(self, s):
|
||||
"""Decodes a string or byte string to an NDArray.
|
||||
See mx.img.imdecode for more details."""
|
||||
return imdecode(s)
|
||||
|
||||
def read_image(self, fname):
|
||||
"""Reads an input image `fname` and returns the decoded raw bytes.
|
||||
|
||||
Example usage:
|
||||
----------
|
||||
>>> dataIter.read_image('Face.jpg') # returns decoded raw bytes.
|
||||
"""
|
||||
with open(os.path.join(self.path_root, fname), 'rb') as fin:
|
||||
img = fin.read()
|
||||
return img
|
||||
|
||||
|
||||
70
alignment/draw.py
Normal file
70
alignment/draw.py
Normal file
@@ -0,0 +1,70 @@
|
||||
import numpy as np
|
||||
import skimage.draw
|
||||
|
||||
def line(img, pt1, pt2, color, width):
|
||||
# Draw a line on an image
|
||||
# Make sure dimension of color matches number of channels in img
|
||||
|
||||
# First get coordinates for corners of the line
|
||||
diff = np.array([pt1[1] - pt2[1], pt1[0] - pt2[0]], np.float)
|
||||
mag = np.linalg.norm(diff)
|
||||
if mag >= 1:
|
||||
diff *= width / (2 * mag)
|
||||
x = np.array([pt1[0] - diff[0], pt2[0] - diff[0], pt2[0] + diff[0], pt1[0] + diff[0]], int)
|
||||
y = np.array([pt1[1] + diff[1], pt2[1] + diff[1], pt2[1] - diff[1], pt1[1] - diff[1]], int)
|
||||
else:
|
||||
d = float(width) / 2
|
||||
x = np.array([pt1[0] - d, pt1[0] + d, pt1[0] + d, pt1[0] - d], int)
|
||||
y = np.array([pt1[1] - d, pt1[1] - d, pt1[1] + d, pt1[1] + d], int)
|
||||
|
||||
# noinspection PyArgumentList
|
||||
rr, cc = skimage.draw.polygon(y, x, img.shape)
|
||||
img[rr, cc] = color
|
||||
|
||||
return img
|
||||
|
||||
def limb(img, pt1, pt2, color, width):
|
||||
# Specific handling of a limb, in case the annotation isn't there for one of the joints
|
||||
if pt1[0] > 0 and pt2[0] > 0:
|
||||
line(img, pt1, pt2, color, width)
|
||||
elif pt1[0] > 0:
|
||||
circle(img, pt1, color, width)
|
||||
elif pt2[0] > 0:
|
||||
circle(img, pt2, color, width)
|
||||
|
||||
def gaussian(img, pt, sigma):
|
||||
# Draw a 2D gaussian
|
||||
|
||||
# Check that any part of the gaussian is in-bounds
|
||||
ul = [int(pt[0] - 3 * sigma), int(pt[1] - 3 * sigma)]
|
||||
br = [int(pt[0] + 3 * sigma + 1), int(pt[1] + 3 * sigma + 1)]
|
||||
if (ul[0] > img.shape[1] or ul[1] >= img.shape[0] or
|
||||
br[0] < 0 or br[1] < 0):
|
||||
# If not, just return the image as is
|
||||
return img
|
||||
|
||||
# Generate gaussian
|
||||
size = 6 * sigma + 1
|
||||
x = np.arange(0, size, 1, float)
|
||||
y = x[:, np.newaxis]
|
||||
x0 = y0 = size // 2
|
||||
# The gaussian is not normalized, we want the center value to equal 1
|
||||
g = np.exp(- ((x - x0) ** 2 + (y - y0) ** 2) / (2 * sigma ** 2))
|
||||
|
||||
# Usable gaussian range
|
||||
g_x = max(0, -ul[0]), min(br[0], img.shape[1]) - ul[0]
|
||||
g_y = max(0, -ul[1]), min(br[1], img.shape[0]) - ul[1]
|
||||
# Image range
|
||||
img_x = max(0, ul[0]), min(br[0], img.shape[1])
|
||||
img_y = max(0, ul[1]), min(br[1], img.shape[0])
|
||||
|
||||
img[img_y[0]:img_y[1], img_x[0]:img_x[1]] = g[g_y[0]:g_y[1], g_x[0]:g_x[1]]
|
||||
return img
|
||||
|
||||
def circle(img, pt, color, radius):
|
||||
# Draw a circle
|
||||
# Mostly a convenient wrapper for skimage.draw.circle
|
||||
|
||||
rr, cc = skimage.draw.circle(pt[1], pt[0], radius, img.shape)
|
||||
img[rr, cc] = color
|
||||
return img
|
||||
853
alignment/hg2.py
Normal file
853
alignment/hg2.py
Normal file
@@ -0,0 +1,853 @@
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
import mxnet as mx
|
||||
import numpy as np
|
||||
|
||||
|
||||
ACT_BIT = 1
|
||||
N = 4
|
||||
use_STN = False
|
||||
use_DLA = 0
|
||||
DCN = 0
|
||||
|
||||
|
||||
|
||||
def Conv(**kwargs):
|
||||
#name = kwargs.get('name')
|
||||
#_weight = mx.symbol.Variable(name+'_weight')
|
||||
#_bias = mx.symbol.Variable(name+'_bias', lr_mult=2.0, wd_mult=0.0)
|
||||
#body = mx.sym.Convolution(weight = _weight, bias = _bias, **kwargs)
|
||||
body = mx.sym.Convolution(**kwargs)
|
||||
return body
|
||||
|
||||
|
||||
def Act(data, act_type, name):
|
||||
if act_type=='prelu':
|
||||
body = mx.sym.LeakyReLU(data = data, act_type='prelu', name = name)
|
||||
else:
|
||||
body = mx.symbol.Activation(data=data, act_type=act_type, name=name)
|
||||
return body
|
||||
|
||||
def lin(data, num_filter, workspace, name, binarize, dcn):
|
||||
bn_mom = 0.9
|
||||
bit = 1
|
||||
if not binarize:
|
||||
if not dcn:
|
||||
conv1 = Conv(data=data, num_filter=num_filter, kernel=(1,1), stride=(1,1), pad=(0,0),
|
||||
no_bias=True, workspace=workspace, name=name + '_conv')
|
||||
bn1 = mx.sym.BatchNorm(data=conv1, fix_gamma=False, momentum=bn_mom, eps=2e-5, name=name + '_bn')
|
||||
act1 = Act(data=bn1, act_type='relu', name=name + '_relu')
|
||||
return act1
|
||||
else:
|
||||
bn1 = mx.sym.BatchNorm(data=data, fix_gamma=False, momentum=bn_mom, eps=2e-5, name=name + '_bn')
|
||||
act1 = Act(data=bn1, act_type='relu', name=name + '_relu')
|
||||
conv1_offset = mx.symbol.Convolution(name=name+'_conv_offset', data = act1,
|
||||
num_filter=18, pad=(1, 1), kernel=(3, 3), stride=(1, 1))
|
||||
conv1 = mx.contrib.symbol.DeformableConvolution(name=name+"_conv", data=act1, offset=conv1_offset,
|
||||
num_filter=num_filter, pad=(1,1), kernel=(3, 3), num_deformable_group=1, stride=(1, 1), dilate=(1, 1), no_bias=False)
|
||||
#conv1 = Conv(data=act1, num_filter=num_filter, kernel=(3,3), stride=(1,1), pad=(1,1),
|
||||
# no_bias=False, workspace=workspace, name=name + '_conv')
|
||||
return conv1
|
||||
else:
|
||||
bn1 = mx.sym.BatchNorm(data=data, fix_gamma=False, momentum=bn_mom, eps=2e-5, name=name + '_bn')
|
||||
act1 = Act(data=bn1, act_type='relu', name=name + '_relu')
|
||||
conv1 = mx.sym.QConvolution_v1(data=act1, num_filter=num_filter, kernel=(1,1), stride=(1,1), pad=(0,0),
|
||||
no_bias=True, workspace=workspace, name=name + '_conv', act_bit=ACT_BIT, weight_bit=bit)
|
||||
conv1 = mx.sym.BatchNorm(data=conv1, fix_gamma=False, momentum=bn_mom, eps=2e-5, name=name + '_bn2')
|
||||
return conv1
|
||||
|
||||
def lin2(data, num_filter, workspace, name):
|
||||
bn_mom = 0.9
|
||||
conv1 = Conv(data=data, num_filter=num_filter, kernel=(1,1), stride=(1,1), pad=(0,0),
|
||||
no_bias=True, workspace=workspace, name=name + '_conv')
|
||||
bn1 = mx.sym.BatchNorm(data=conv1, fix_gamma=False, momentum=bn_mom, eps=2e-5, name=name + '_bn')
|
||||
act1 = Act(data=bn1, act_type='relu', name=name + '_relu')
|
||||
return act1
|
||||
|
||||
def lin3(data, num_filter, workspace, name, k, g=1, d=1):
|
||||
bn_mom = 0.9
|
||||
#bn1 = mx.sym.BatchNorm(data=data, fix_gamma=False, momentum=bn_mom, eps=2e-5, name=name + '_bn')
|
||||
#act1 = Act(data=bn1, act_type='relu', name=name + '_relu')
|
||||
#conv1 = Conv(data=act1, num_filter=num_filter, kernel=(k,k), stride=(1,1), pad=((k-1)//2,(k-1)//2), num_group=g,
|
||||
# no_bias=True, workspace=workspace, name=name + '_conv')
|
||||
#return conv1
|
||||
if k!=3:
|
||||
conv1 = Conv(data=data, num_filter=num_filter, kernel=(k,k), stride=(1,1), pad=((k-1)//2,(k-1)//2), num_group=g,
|
||||
no_bias=True, workspace=workspace, name=name + '_conv')
|
||||
else:
|
||||
conv1 = Conv(data=data, num_filter=num_filter, kernel=(k,k), stride=(1,1), pad=(d,d), num_group=g, dilate=(d, d),
|
||||
no_bias=True, workspace=workspace, name=name + '_conv')
|
||||
bn1 = mx.sym.BatchNorm(data=conv1, fix_gamma=False, momentum=bn_mom, eps=2e-5, name=name + '_bn')
|
||||
act1 = Act(data=bn1, act_type='relu', name=name + '_relu')
|
||||
ret = act1
|
||||
#if g>1 and k==3 and d==1:
|
||||
# body = Conv(data=ret, num_filter=num_filter, kernel=(1,1), stride=(1,1), pad=(0,0), num_group=1,
|
||||
# no_bias=True, workspace=workspace, name=name + '_conv2')
|
||||
# body = mx.sym.BatchNorm(data=body, fix_gamma=False, momentum=bn_mom, eps=2e-5, name=name + '_bn2')
|
||||
# body = Act(data=body, act_type='relu', name=name + '_relu2')
|
||||
# ret = body
|
||||
return ret
|
||||
|
||||
def lin3_red(data, num_filter, workspace, name, k, g=1):
|
||||
bn_mom = 0.9
|
||||
conv1 = Conv(data=data, num_filter=num_filter, kernel=(3,3), stride=(k,k), pad=(1,1), num_group=g,
|
||||
no_bias=True, workspace=workspace, name=name + '_conv', attr={'lr_mult':'1'})
|
||||
bn1 = mx.sym.BatchNorm(data=conv1, fix_gamma=False, momentum=bn_mom, eps=2e-5, name=name + '_bn', attr={'lr_mult':'1'})
|
||||
act1 = Act(data=bn1, act_type='sigmoid', name=name + '_relu')
|
||||
ret = act1
|
||||
#if g>1 and k==3 and d==1:
|
||||
# body = Conv(data=ret, num_filter=num_filter, kernel=(1,1), stride=(1,1), pad=(0,0), num_group=1,
|
||||
# no_bias=True, workspace=workspace, name=name + '_conv2')
|
||||
# body = mx.sym.BatchNorm(data=body, fix_gamma=False, momentum=bn_mom, eps=2e-5, name=name + '_bn2')
|
||||
# body = Act(data=body, act_type='relu', name=name + '_relu2')
|
||||
# ret = body
|
||||
return ret
|
||||
|
||||
class RES:
|
||||
def __init__(self, data, nFilters, nModules, n, workspace, name, dilate, group):
|
||||
self.data = data
|
||||
self.nFilters = nFilters
|
||||
self.nModules = nModules
|
||||
self.n = n
|
||||
self.workspace = workspace
|
||||
self.name = name
|
||||
self.dilate = dilate
|
||||
self.group = group
|
||||
self.sym_map = {}
|
||||
|
||||
def get_output(self, w, h):
|
||||
key = (w, h)
|
||||
if key in self.sym_map:
|
||||
return self.sym_map[key]
|
||||
ret = None
|
||||
if h==self.n:
|
||||
if w==self.n:
|
||||
ret = (self.data, self.nFilters)
|
||||
else:
|
||||
x = self.get_output(w+1, h)
|
||||
f = int(x[1]*0.5)
|
||||
if w!=self.n-1:
|
||||
body = lin3(x[0], f, self.workspace, "%s_w%d_h%d_1"%(self.name, w, h), 3, self.group, 1)
|
||||
else:
|
||||
body = lin3(x[0], f, self.workspace, "%s_w%d_h%d_1"%(self.name, w, h), 3, self.group, self.dilate)
|
||||
ret = (body,f)
|
||||
else:
|
||||
x = self.get_output(w+1, h+1)
|
||||
y = self.get_output(w, h+1)
|
||||
if h%2==1 and h!=w:
|
||||
xbody = lin3(x[0], x[1], self.workspace, "%s_w%d_h%d_2"%(self.name, w, h), 3, x[1])
|
||||
#xbody = xbody+x[0]
|
||||
else:
|
||||
xbody = x[0]
|
||||
#xbody = x[0]
|
||||
#xbody = lin3(x[0], x[1], self.workspace, "%s_w%d_h%d_2"%(self.name, w, h), 3, x[1])
|
||||
if w==0:
|
||||
ybody = lin3(y[0], y[1], self.workspace, "%s_w%d_h%d_3"%(self.name, w, h), 3, self.group)
|
||||
else:
|
||||
ybody = y[0]
|
||||
ybody = mx.sym.concat(y[0], ybody, dim=1)
|
||||
body = mx.sym.add_n(xbody,ybody, name="%s_w%d_h%d_add"%(self.name, w, h))
|
||||
body = body/2
|
||||
ret = (body, x[1])
|
||||
self.sym_map[key] = ret
|
||||
return ret
|
||||
|
||||
def get(self):
|
||||
return self.get_output(1, 1)[0]
|
||||
|
||||
def residual_unit_a(data, num_filter, stride, dim_match, name, binarize, dcn, dilate, **kwargs):
|
||||
bn_mom = kwargs.get('bn_mom', 0.9)
|
||||
workspace = kwargs.get('workspace', 256)
|
||||
memonger = kwargs.get('memonger', False)
|
||||
bit = 1
|
||||
#print('in unit2')
|
||||
# the same as https://github.com/facebook/fb.resnet.torch#notes, a bit difference with origin paper
|
||||
bn1 = mx.sym.BatchNorm(data=data, fix_gamma=False, eps=2e-5, momentum=bn_mom, name=name + '_bn1')
|
||||
if not binarize:
|
||||
act1 = Act(data=bn1, act_type='relu', name=name + '_relu1')
|
||||
conv1 = Conv(data=act1, num_filter=int(num_filter*0.5), kernel=(1,1), stride=(1,1), pad=(0,0),
|
||||
no_bias=True, workspace=workspace, name=name + '_conv1')
|
||||
else:
|
||||
act1 = mx.sym.QActivation(data=bn1, act_bit=ACT_BIT, name=name + '_relu1', backward_only=True)
|
||||
conv1 = mx.sym.QConvolution(data=act1, num_filter=int(num_filter*0.5), kernel=(1,1), stride=(1,1), pad=(0,0),
|
||||
no_bias=True, workspace=workspace, name=name + '_conv1', act_bit=ACT_BIT, weight_bit=bit)
|
||||
bn2 = mx.sym.BatchNorm(data=conv1, fix_gamma=False, eps=2e-5, momentum=bn_mom, name=name + '_bn2')
|
||||
if not binarize:
|
||||
act2 = Act(data=bn2, act_type='relu', name=name + '_relu2')
|
||||
conv2 = Conv(data=act2, num_filter=int(num_filter*0.5), kernel=(3,3), stride=(1,1), pad=(1,1),
|
||||
no_bias=True, workspace=workspace, name=name + '_conv2')
|
||||
else:
|
||||
act2 = mx.sym.QActivation(data=bn2, act_bit=ACT_BIT, name=name + '_relu2', backward_only=True)
|
||||
conv2 = mx.sym.QConvolution(data=act2, num_filter=int(num_filter*0.5), kernel=(3,3), stride=(1,1), pad=(1,1),
|
||||
no_bias=True, workspace=workspace, name=name + '_conv2', act_bit=ACT_BIT, weight_bit=bit)
|
||||
bn3 = mx.sym.BatchNorm(data=conv2, fix_gamma=False, eps=2e-5, momentum=bn_mom, name=name + '_bn3')
|
||||
if not binarize:
|
||||
act3 = Act(data=bn3, act_type='relu', name=name + '_relu3')
|
||||
conv3 = Conv(data=act3, num_filter=num_filter, kernel=(1,1), stride=(1,1), pad=(0,0), no_bias=True,
|
||||
workspace=workspace, name=name + '_conv3')
|
||||
else:
|
||||
act3 = mx.sym.QActivation(data=bn3, act_bit=ACT_BIT, name=name + '_relu3', backward_only=True)
|
||||
conv3 = mx.sym.QConvolution(data=act3, num_filter=num_filter, kernel=(1,1), stride=(1,1), pad=(0,0),
|
||||
no_bias=True, workspace=workspace, name=name + '_conv3', act_bit=ACT_BIT, weight_bit=bit)
|
||||
#if binarize:
|
||||
# conv3 = mx.sym.BatchNorm(data=conv3, fix_gamma=False, eps=2e-5, momentum=bn_mom, name=name + '_bn4')
|
||||
if dim_match:
|
||||
shortcut = data
|
||||
else:
|
||||
if not binarize:
|
||||
shortcut = Conv(data=act1, num_filter=num_filter, kernel=(1,1), stride=stride, no_bias=True,
|
||||
workspace=workspace, name=name+'_sc')
|
||||
else:
|
||||
shortcut = mx.sym.QConvolution(data=act1, num_filter=num_filter, kernel=(1,1), stride=stride, pad=(0,0),
|
||||
no_bias=True, workspace=workspace, name=name + '_sc', act_bit=ACT_BIT, weight_bit=bit)
|
||||
if memonger:
|
||||
shortcut._set_attr(mirror_stage='True')
|
||||
return conv3 + shortcut
|
||||
|
||||
|
||||
def residual_unit_g(data, num_filter, stride, dim_match, name, binarize, dcn, dilation, **kwargs):
|
||||
bn_mom = kwargs.get('bn_mom', 0.9)
|
||||
workspace = kwargs.get('workspace', 256)
|
||||
memonger = kwargs.get('memonger', False)
|
||||
bit = 1
|
||||
#print('in unit2')
|
||||
# the same as https://github.com/facebook/fb.resnet.torch#notes, a bit difference with origin paper
|
||||
bn1 = mx.sym.BatchNorm(data=data, fix_gamma=False, eps=2e-5, momentum=bn_mom, name=name + '_bn1')
|
||||
if not binarize:
|
||||
act1 = Act(data=bn1, act_type='relu', name=name + '_relu1')
|
||||
if not dcn:
|
||||
conv1 = Conv(data=act1, num_filter=int(num_filter*0.5), kernel=(3,3), stride=(1,1), pad=(dilation,dilation), dilate=(dilation,dilation),
|
||||
no_bias=True, workspace=workspace, name=name + '_conv1')
|
||||
else:
|
||||
conv1_offset = mx.symbol.Convolution(name=name+'_conv1_offset', data = act1,
|
||||
num_filter=18, pad=(1, 1), kernel=(3, 3), stride=(1, 1))
|
||||
conv1 = mx.contrib.symbol.DeformableConvolution(name=name+'_conv1', data=act1, offset=conv1_offset,
|
||||
num_filter=int(num_filter*0.5), pad=(1,1), kernel=(3, 3), num_deformable_group=1, stride=(1, 1), dilate=(1, 1), no_bias=True)
|
||||
else:
|
||||
act1 = mx.sym.QActivation(data=bn1, act_bit=ACT_BIT, name=name + '_relu1', backward_only=True)
|
||||
conv1 = mx.sym.QConvolution_v1(data=act1, num_filter=int(num_filter*0.5), kernel=(3,3), stride=(1,1), pad=(1,1),
|
||||
no_bias=True, workspace=workspace, name=name + '_conv1', act_bit=ACT_BIT, weight_bit=bit)
|
||||
bn2 = mx.sym.BatchNorm(data=conv1, fix_gamma=False, eps=2e-5, momentum=bn_mom, name=name + '_bn2')
|
||||
if not binarize:
|
||||
act2 = Act(data=bn2, act_type='relu', name=name + '_relu2')
|
||||
if not dcn:
|
||||
conv2 = Conv(data=act2, num_filter=int(num_filter*0.25), kernel=(3,3), stride=(1,1), pad=(dilation,dilation), dilate=(dilation,dilation),
|
||||
no_bias=True, workspace=workspace, name=name + '_conv2')
|
||||
else:
|
||||
conv2_offset = mx.symbol.Convolution(name=name+'_conv2_offset', data = act2,
|
||||
num_filter=18, pad=(1, 1), kernel=(3, 3), stride=(1, 1))
|
||||
conv2 = mx.contrib.symbol.DeformableConvolution(name=name+'_conv2', data=act2, offset=conv2_offset,
|
||||
num_filter=int(num_filter*0.25), pad=(1,1), kernel=(3, 3), num_deformable_group=1, stride=(1, 1), dilate=(1, 1), no_bias=True)
|
||||
else:
|
||||
act2 = mx.sym.QActivation(data=bn2, act_bit=ACT_BIT, name=name + '_relu2', backward_only=True)
|
||||
conv2 = mx.sym.QConvolution_v1(data=act2, num_filter=int(num_filter*0.25), kernel=(3,3), stride=(1,1), pad=(1,1),
|
||||
no_bias=True, workspace=workspace, name=name + '_conv2', act_bit=ACT_BIT, weight_bit=bit)
|
||||
bn3 = mx.sym.BatchNorm(data=conv2, fix_gamma=False, eps=2e-5, momentum=bn_mom, name=name + '_bn3')
|
||||
if not binarize:
|
||||
act3 = Act(data=bn3, act_type='relu', name=name + '_relu3')
|
||||
if not dcn:
|
||||
conv3 = Conv(data=act3, num_filter=int(num_filter*0.25), kernel=(3,3), stride=(1,1), pad=(dilation,dilation), dilate=(dilation,dilation),
|
||||
no_bias=True, workspace=workspace, name=name + '_conv3')
|
||||
else:
|
||||
conv3_offset = mx.symbol.Convolution(name=name+'_conv3_offset', data = act3,
|
||||
num_filter=18, pad=(1, 1), kernel=(3, 3), stride=(1, 1))
|
||||
conv3 = mx.contrib.symbol.DeformableConvolution(name=name+'_conv3', data=act3, offset=conv3_offset,
|
||||
num_filter=int(num_filter*0.25), pad=(1,1), kernel=(3, 3), num_deformable_group=1, stride=(1, 1), dilate=(1, 1), no_bias=True)
|
||||
else:
|
||||
act3 = mx.sym.QActivation(data=bn3, act_bit=ACT_BIT, name=name + '_relu3', backward_only=True)
|
||||
conv3 = mx.sym.QConvolution_v1(data=act3, num_filter=int(num_filter*0.25), kernel=(3,3), stride=(1,1), pad=(1,1),
|
||||
no_bias=True, workspace=workspace, name=name + '_conv3', act_bit=ACT_BIT, weight_bit=bit)
|
||||
conv4 = mx.symbol.Concat(*[conv1, conv2, conv3])
|
||||
if binarize:
|
||||
conv4 = mx.sym.BatchNorm(data=conv4, fix_gamma=False, eps=2e-5, momentum=bn_mom, name=name + '_bn4')
|
||||
if dim_match:
|
||||
shortcut = data
|
||||
else:
|
||||
if not binarize:
|
||||
shortcut = Conv(data=act1, num_filter=num_filter, kernel=(1,1), stride=stride, no_bias=True,
|
||||
workspace=workspace, name=name+'_sc')
|
||||
else:
|
||||
#assert(False)
|
||||
shortcut = mx.sym.QConvolution_v1(data=act1, num_filter=num_filter, kernel=(1,1), stride=stride, pad=(0,0),
|
||||
no_bias=True, workspace=workspace, name=name + '_sc', act_bit=ACT_BIT, weight_bit=bit)
|
||||
shortcut = mx.sym.BatchNorm(data=shortcut, fix_gamma=False, eps=2e-5, momentum=bn_mom, name=name + '_sc_bn')
|
||||
if memonger:
|
||||
shortcut._set_attr(mirror_stage='True')
|
||||
return conv4 + shortcut
|
||||
#return bn4 + shortcut
|
||||
#return act4 + shortcut
|
||||
|
||||
def ConvFactory(data, num_filter, kernel, stride=(1, 1), pad=(0, 0), act_type="relu", mirror_attr={}, with_act=True):
|
||||
conv = mx.symbol.Convolution(
|
||||
data=data, num_filter=num_filter, kernel=kernel, stride=stride, pad=pad)
|
||||
bn = mx.symbol.BatchNorm(data=conv)
|
||||
if with_act:
|
||||
act = mx.symbol.Activation(
|
||||
data=bn, act_type=act_type, attr=mirror_attr)
|
||||
return act
|
||||
else:
|
||||
return bn
|
||||
|
||||
def block17(net, input_num_channels, scale=1.0, with_act=True, act_type='relu', mirror_attr={}):
|
||||
tower_conv = ConvFactory(net, 192, (1, 1))
|
||||
tower_conv1_0 = ConvFactory(net, 129, (1, 1))
|
||||
tower_conv1_1 = ConvFactory(tower_conv1_0, 160, (1, 7), pad=(1, 2))
|
||||
tower_conv1_2 = ConvFactory(tower_conv1_1, 192, (7, 1), pad=(2, 1))
|
||||
tower_mixed = mx.symbol.Concat(*[tower_conv, tower_conv1_2])
|
||||
tower_out = ConvFactory(
|
||||
tower_mixed, input_num_channels, (1, 1), with_act=False)
|
||||
net = net+scale * tower_out
|
||||
if with_act:
|
||||
act = mx.symbol.Activation(
|
||||
data=net, act_type=act_type, attr=mirror_attr)
|
||||
return act
|
||||
else:
|
||||
return net
|
||||
|
||||
def block35(net, input_num_channels, scale=1.0, with_act=True, act_type='relu', mirror_attr={}):
|
||||
M = 1.0
|
||||
tower_conv = ConvFactory(net, int(input_num_channels*0.25*M), (1, 1))
|
||||
tower_conv1_0 = ConvFactory(net, int(input_num_channels*0.25*M), (1, 1))
|
||||
tower_conv1_1 = ConvFactory(tower_conv1_0, int(input_num_channels*0.25*M), (3, 3), pad=(1, 1))
|
||||
tower_conv2_0 = ConvFactory(net, int(input_num_channels*0.25*M), (1, 1))
|
||||
tower_conv2_1 = ConvFactory(tower_conv2_0, int(input_num_channels*0.375*M), (3, 3), pad=(1, 1))
|
||||
tower_conv2_2 = ConvFactory(tower_conv2_1, int(input_num_channels*0.5*M), (3, 3), pad=(1, 1))
|
||||
tower_mixed = mx.symbol.Concat(*[tower_conv, tower_conv1_1, tower_conv2_2])
|
||||
tower_out = ConvFactory(
|
||||
tower_mixed, input_num_channels, (1, 1), with_act=False)
|
||||
|
||||
net = net+scale * tower_out
|
||||
if with_act:
|
||||
act = mx.symbol.Activation(
|
||||
data=net, act_type=act_type, attr=mirror_attr)
|
||||
return act
|
||||
else:
|
||||
return net
|
||||
|
||||
def residual_unit_i(data, num_filter, stride, dim_match, name, binarize, dcn, dilate, **kwargs):
|
||||
bn_mom = kwargs.get('bn_mom', 0.9)
|
||||
workspace = kwargs.get('workspace', 256)
|
||||
memonger = kwargs.get('memonger', False)
|
||||
assert not binarize
|
||||
if stride[0]>1 or not dim_match:
|
||||
return residual_unit_a(data, num_filter, stride, dim_match, name, binarize, dcn, dilate, **kwargs)
|
||||
conv4 = block35(data, num_filter)
|
||||
return conv4
|
||||
|
||||
def residual_unit_cab(data, num_filter, stride, dim_match, name, binarize, dcn, dilate, **kwargs):
|
||||
bn_mom = kwargs.get('bn_mom', 0.9)
|
||||
workspace = kwargs.get('workspace', 256)
|
||||
memonger = kwargs.get('memonger', False)
|
||||
if stride[0]>1 or not dim_match:
|
||||
return residual_unit_g(data, num_filter, stride, dim_match, name, binarize, dcn, dilate, **kwargs)
|
||||
res = RES(data, num_filter, 1, 4, workspace, name, dilate, 1)
|
||||
return res.get()
|
||||
|
||||
def residual_unit(data, num_filter, stride, dim_match, name, binarize, dcn, dilate, **kwargs):
|
||||
#binarize = False
|
||||
#binarize = BINARIZE
|
||||
return residual_unit_cab(data, num_filter, stride, dim_match, name, binarize, dcn, dilate, **kwargs)
|
||||
|
||||
def hourglass(data, nFilters, nModules, n, workspace, name, binarize, dcn):
|
||||
s = 2
|
||||
_dcn = False
|
||||
up1 = data
|
||||
for i in xrange(nModules):
|
||||
up1 = residual_unit(up1, nFilters, (1,1), True, "%s_up1_%d"%(name,i), binarize, _dcn, 1)
|
||||
low1 = mx.sym.Pooling(data=data, kernel=(s, s), stride=(s,s), pad=(0,0), pool_type='max')
|
||||
for i in xrange(nModules):
|
||||
low1 = residual_unit(low1, nFilters, (1,1), True, "%s_low1_%d"%(name,i), binarize, _dcn, 1)
|
||||
if n>1:
|
||||
low2 = hourglass(low1, nFilters, nModules, n-1, workspace, "%s_%d"%(name, n-1), binarize, dcn)
|
||||
else:
|
||||
low2 = low1
|
||||
for i in xrange(nModules):
|
||||
low2 = residual_unit(low2, nFilters, (1,1), True, "%s_low2_%d"%(name,i), binarize, _dcn, 1) #TODO
|
||||
#low2 = residual_unit(low2, nFilters, (1,1), True, "%s_low2_%d"%(name,i), False) #TODO
|
||||
low3 = low2
|
||||
for i in xrange(nModules):
|
||||
low3 = residual_unit(low3, nFilters, (1,1), True, "%s_low3_%d"%(name,i), binarize, _dcn, 1)
|
||||
up2 = mx.symbol.Deconvolution(data=low3, num_filter=nFilters, kernel=(s,s),
|
||||
stride=(s, s),
|
||||
num_group=nFilters, no_bias=True, name='%s_upsampling_%s'%(name,n),
|
||||
attr={'lr_mult': '0.0', 'wd_mult': '0.0'}, workspace=workspace)
|
||||
return mx.symbol.add_n(up1, up2)
|
||||
|
||||
def hourglass2(data, nFilters, nModules, n, workspace, name, binarize, dcn):
|
||||
s = 2
|
||||
_dcn = dcn
|
||||
if DCN and n==N:
|
||||
_dcn = True
|
||||
_dcn = False
|
||||
up1 = data
|
||||
dilate = 2**(4-n)
|
||||
for i in xrange(nModules):
|
||||
up1 = residual_unit(up1, nFilters, (1,1), True, "%s_up1_%d"%(name,i), binarize, _dcn, dilate)
|
||||
#low1 = mx.sym.Pooling(data=data, kernel=(s, s), stride=(s,s), pad=(0,0), pool_type='max')
|
||||
low1 = data
|
||||
for i in xrange(nModules):
|
||||
low1 = residual_unit(low1, nFilters, (1,1), True, "%s_low1_%d"%(name,i), binarize, _dcn, dilate)
|
||||
if n>1:
|
||||
low2 = hourglass2(low1, nFilters, nModules, n-1, workspace, "%s_%d"%(name, n-1), binarize, dcn)
|
||||
else:
|
||||
low2 = low1
|
||||
for i in xrange(nModules):
|
||||
low2 = residual_unit(low2, nFilters, (1,1), True, "%s_low2_%d"%(name,i), binarize, _dcn, dilate) #TODO
|
||||
#low2 = residual_unit(low2, nFilters, (1,1), True, "%s_low2_%d"%(name,i), False) #TODO
|
||||
low3 = low2
|
||||
for i in xrange(nModules):
|
||||
low3 = residual_unit(low3, nFilters, (1,1), True, "%s_low3_%d"%(name,i), binarize, _dcn, dilate)
|
||||
up2 = low3
|
||||
#up2 = mx.symbol.Deconvolution(data=low3, num_filter=nFilters, kernel=(s,s),
|
||||
# stride=(s, s),
|
||||
# num_group=nFilters, no_bias=True, name='%s_upsampling_%s'%(name,n),
|
||||
# attr={'lr_mult': '0.0', 'wd_mult': '0.0'}, workspace=workspace)
|
||||
return mx.symbol.add_n(up1, up2)
|
||||
|
||||
|
||||
class DLA:
|
||||
def __init__(self, data, nFilters, nModules, n, workspace, name):
|
||||
self.data = data
|
||||
self.nFilters = nFilters
|
||||
self.nModules = nModules
|
||||
self.n = n
|
||||
self.workspace = workspace
|
||||
self.name = name
|
||||
self.sym_map = {}
|
||||
|
||||
|
||||
def residual_unit(self, data, name, dilate=1, group=1):
|
||||
res = RES(data, self.nFilters, self.nModules, 4, self.workspace, name, dilate, group)
|
||||
return res.get()
|
||||
#body = data
|
||||
#for i in xrange(self.nModules):
|
||||
# body = residual_unit(body, self.nFilters, (1,1), True, name, False, False, 1)
|
||||
#return body
|
||||
|
||||
def get_output(self, w, h):
|
||||
#print(w,h)
|
||||
assert w>=1 and w<=N+1
|
||||
assert h>=1 and h<=N+1
|
||||
s = 2
|
||||
bn_mom = 0.9
|
||||
key = (w,h)
|
||||
if key in self.sym_map:
|
||||
return self.sym_map[key]
|
||||
ret = None
|
||||
if h==self.n:
|
||||
if w==self.n:
|
||||
ret = self.data,64
|
||||
#elif w==1:
|
||||
# x = self.get_output(w+1, h)
|
||||
# body = self.residual_unit(x[0], "%s_w%d_h%d_1"%(self.name, w, h))
|
||||
# body = self.residual_unit(body, "%s_w%d_h%d_2"%(self.name, w, h), 2)
|
||||
# ret = body,x[1]
|
||||
else:
|
||||
x = self.get_output(w+1, h)
|
||||
body = self.residual_unit(x[0], "%s_w%d_h%d_1"%(self.name, w, h))
|
||||
body = mx.sym.Pooling(data=body, kernel=(s, s), stride=(s,s), pad=(0,0), pool_type='max')
|
||||
body = self.residual_unit(body, "%s_w%d_h%d_2"%(self.name, w, h))
|
||||
ret = body, x[1]//2
|
||||
else:
|
||||
x = self.get_output(w+1, h+1)
|
||||
y = self.get_output(w, h+1)
|
||||
#xbody = Conv(data=x, num_filter=self.nFilters, kernel=(3,3), stride=(1,1), pad=(1,1),
|
||||
# no_bias=True, workspace=self.workspace, name="%s_w%d_h%d_x_conv"%(self.name, w, h))
|
||||
#xbody = mx.sym.BatchNorm(data=xbody, fix_gamma=False, momentum=bn_mom, eps=2e-5, name="%s_w%d_h%d_x_bn"%(self.name, w, h))
|
||||
#xbody = Act(data=xbody, act_type='relu', name="%s_w%d_h%d_x_act"%(self.name, w, h))
|
||||
|
||||
HC = False
|
||||
|
||||
if use_DLA<10:
|
||||
if h%2==1 and h!=w:
|
||||
xbody = lin3(x[0], self.nFilters, self.workspace, "%s_w%d_h%d_x"%(self.name, w, h), 3, self.nFilters, 1)
|
||||
HC = True
|
||||
#xbody = x[0]
|
||||
else:
|
||||
xbody = x[0]
|
||||
else:
|
||||
xbody = lin3(x[0], self.nFilters, self.workspace, "%s_w%d_h%d_x"%(self.name, w, h), 3, 1, 1)
|
||||
#xbody = x[0]
|
||||
if x[1]//y[1]==2:
|
||||
if w>1:
|
||||
ybody = mx.symbol.Deconvolution(data=y[0], num_filter=self.nFilters, kernel=(s,s),
|
||||
stride=(s, s),
|
||||
name='%s_upsampling_w%d_h%d'%(self.name,w, h),
|
||||
attr={'lr_mult': '1.0'}, workspace=self.workspace)
|
||||
ybody = mx.sym.BatchNorm(data=ybody, fix_gamma=False, momentum=bn_mom, eps=2e-5, name="%s_w%d_h%d_y_bn"%(self.name, w, h))
|
||||
ybody = Act(data=ybody, act_type='relu', name="%s_w%d_h%d_y_act"%(self.name, w, h))
|
||||
#ybody = Conv(data=ybody, num_filter=self.nFilters, kernel=(3,3), stride=(1,1), pad=(1,1),
|
||||
# no_bias=True, name="%s_w%d_h%d_y_conv2"%(self.name, w, h), workspace=self.workspace)
|
||||
#ybody = mx.sym.BatchNorm(data=ybody, fix_gamma=False, momentum=bn_mom, eps=2e-5, name="%s_w%d_h%d_y_bn2"%(self.name, w, h))
|
||||
#ybody = Act(data=ybody, act_type='relu', name="%s_w%d_h%d_y_act2"%(self.name, w, h))
|
||||
else:
|
||||
if h>=1:
|
||||
ybody = mx.symbol.Deconvolution(data=y[0], num_filter=self.nFilters, kernel=(s,s),
|
||||
stride=(s, s),
|
||||
num_group=self.nFilters, no_bias=True, name='%s_upsampling_w%d_h%d'%(self.name,w, h),
|
||||
attr={'lr_mult': '0.0', 'wd_mult': '0.0'}, workspace=self.workspace)
|
||||
#ybody = mx.sym.BatchNorm(data=ybody, fix_gamma=False, momentum=bn_mom, eps=2e-5, name="%s_w%d_h%d_y_bn"%(self.name, w, h))
|
||||
import math
|
||||
#ybody = Act(data=ybody, act_type='relu', name="%s_w%d_h%d_y_act"%(self.name, w, h))
|
||||
ybody = self.residual_unit(ybody, "%s_w%d_h%d_4"%(self.name, w, h))
|
||||
else:
|
||||
ybody = mx.symbol.Deconvolution(data=y[0], num_filter=self.nFilters, kernel=(s,s),
|
||||
stride=(s, s),
|
||||
name='%s_upsampling_w%d_h%d'%(self.name,w, h),
|
||||
attr={'lr_mult': '1.0'}, workspace=self.workspace)
|
||||
ybody = mx.sym.BatchNorm(data=ybody, fix_gamma=False, momentum=bn_mom, eps=2e-5, name="%s_w%d_h%d_y_bn"%(self.name, w, h))
|
||||
ybody = Act(data=ybody, act_type='relu', name="%s_w%d_h%d_y_act"%(self.name, w, h))
|
||||
ybody = Conv(data=ybody, num_filter=self.nFilters, kernel=(3,3), stride=(1,1), pad=(1,1),
|
||||
no_bias=True, name="%s_w%d_h%d_y_conv2"%(self.name, w, h), workspace=self.workspace)
|
||||
ybody = mx.sym.BatchNorm(data=ybody, fix_gamma=False, momentum=bn_mom, eps=2e-5, name="%s_w%d_h%d_y_bn2"%(self.name, w, h))
|
||||
ybody = Act(data=ybody, act_type='relu', name="%s_w%d_h%d_y_act2"%(self.name, w, h))
|
||||
else:
|
||||
ybody = self.residual_unit(y[0], "%s_w%d_h%d_5"%(self.name, w, h))
|
||||
#if not HC:
|
||||
if use_DLA<10:
|
||||
if use_DLA>1 and h==3 and w==2:
|
||||
z = self.get_output(w+1, h)
|
||||
zbody = z[0]
|
||||
#zbody = lin3_red(zbody, self.nFilters, self.workspace, "%s_w%d_h%d_z"%(self.name, w, h), 2, self.nFilters)
|
||||
#zbody = mx.sym.Pooling(data=zbody, kernel=(s, s), stride=(s,s), pad=(0,0), pool_type='avg')
|
||||
zbody = mx.sym.Pooling(data=zbody, kernel=(z[1], z[1]), stride=(z[1],z[1]), pad=(0,0), pool_type='avg')
|
||||
#zbody = mx.sym.Activation(data = zbody, act_type='sigmoid')
|
||||
|
||||
#body = zbody+ybody
|
||||
#body = body/2
|
||||
body = xbody+ybody
|
||||
body = body/2
|
||||
#body = body*zbody
|
||||
body = mx.sym.broadcast_mul(body, zbody)
|
||||
#body = mx.sym.add_n(*[xbody, ybody, zbody])
|
||||
#body = body/3
|
||||
else:
|
||||
body = xbody+ybody
|
||||
body = body/2
|
||||
else:
|
||||
if use_DLA==12 and h!=w:
|
||||
zbody = self.get_output(w+1, h)[0]
|
||||
zbody = lin3_red(zbody, self.nFilters, self.workspace, "%s_w%d_h%d_z"%(self.name, w, h), 2, 1)
|
||||
body = mx.sym.add_n(*[xbody, ybody, zbody])
|
||||
body = body/3
|
||||
else:
|
||||
body = xbody+ybody
|
||||
body = body/2
|
||||
ret = body, x[1]
|
||||
|
||||
assert ret is not None
|
||||
self.sym_map[key] = ret
|
||||
return ret
|
||||
|
||||
def get(self):
|
||||
return self.get_output(1, 1)[0]
|
||||
|
||||
|
||||
def l2_loss(x, y):
|
||||
loss = x-y
|
||||
loss = loss*loss
|
||||
loss = mx.symbol.mean(loss)
|
||||
return loss
|
||||
|
||||
def ce_loss(x, y):
|
||||
body = mx.sym.exp(x)
|
||||
sums = mx.sym.sum(body, axis=[2,3], keepdims=True)
|
||||
body = mx.sym.broadcast_div(body, sums)
|
||||
loss = mx.sym.log(body)
|
||||
loss = loss*y*-1.0
|
||||
loss = mx.symbol.mean(loss)
|
||||
return loss
|
||||
|
||||
def get_symbol(num_classes, **kwargs):
|
||||
global use_DLA
|
||||
global N
|
||||
global DCN
|
||||
mirror_set = [
|
||||
(22,23),
|
||||
(21,24),
|
||||
(20,25),
|
||||
(19,26),
|
||||
(18,27),
|
||||
(40,43),
|
||||
(39,44),
|
||||
(38,45),
|
||||
(37,46),
|
||||
(42,47),
|
||||
(41,48),
|
||||
(33,35),
|
||||
(32,36),
|
||||
(51,53),
|
||||
(50,54),
|
||||
(62,64),
|
||||
(61,65),
|
||||
(49,55),
|
||||
(49,55),
|
||||
(68,66),
|
||||
(60,56),
|
||||
(59,57),
|
||||
(1,17),
|
||||
(2,16),
|
||||
(3,15),
|
||||
(4,14),
|
||||
(5,13),
|
||||
(6,12),
|
||||
(7,11),
|
||||
(8,10),
|
||||
]
|
||||
mirror_map = {}
|
||||
for mm in mirror_set:
|
||||
mirror_map[mm[0]-1] = mm[1]-1
|
||||
mirror_map[mm[1]-1] = mm[0]-1
|
||||
sFilters = 64
|
||||
mFilters = 128
|
||||
nFilters = 256
|
||||
|
||||
nModules = 1
|
||||
nStacks = 2
|
||||
bn_mom = kwargs.get('bn_mom', 0.9)
|
||||
workspace = kwargs.get('workspace', 256)
|
||||
binarize = kwargs.get('binarize', False)
|
||||
input_size = kwargs.get('input_size', 128)
|
||||
label_size = kwargs.get('label_size', 64)
|
||||
use_coherent = kwargs.get('use_coherent', 0)
|
||||
use_DLA = kwargs.get('use_dla', 0)
|
||||
N = kwargs.get('use_N', 4)
|
||||
DCN = kwargs.get('use_DCN', 0)
|
||||
per_batch_size = kwargs.get('per_batch_size', 0)
|
||||
print('binarize', binarize)
|
||||
print('use_coherent', use_coherent)
|
||||
print('use_DLA', use_DLA)
|
||||
print('use_N', N)
|
||||
print('use_DCN', DCN)
|
||||
print('per_batch_size', per_batch_size)
|
||||
assert(label_size==64 or label_size==32)
|
||||
assert(input_size==128 or input_size==256)
|
||||
D = input_size // label_size
|
||||
print(input_size, label_size, D)
|
||||
dcn = False
|
||||
kwargs = {}
|
||||
data = mx.sym.Variable(name='data')
|
||||
data = data-127.5
|
||||
data = data*0.0078125
|
||||
gt_label = mx.symbol.Variable(name='softmax_label')
|
||||
losses = []
|
||||
closses = []
|
||||
if use_coherent:
|
||||
M = mx.sym.Variable(name="coherent_label")
|
||||
#gt_label2 = mx.sym.Variable(name="softmax_label2")
|
||||
coherent_weight = 0.0001
|
||||
ref_label = gt_label
|
||||
if use_STN:
|
||||
lr_mult = '0.00001'
|
||||
loc_net = Conv(data=data, num_filter=sFilters, kernel=(7, 7), stride=(2,2), pad=(3, 3),
|
||||
no_bias=True, name="stn_conv0", workspace=workspace, attr={'lr_mult': lr_mult})
|
||||
loc_net = mx.sym.BatchNorm(data=loc_net, fix_gamma=False, eps=2e-5, momentum=bn_mom, name='stn_bn0', attr={'lr_mult': lr_mult})
|
||||
loc_net = Act(data=loc_net, act_type='relu', name='stn_relu0')
|
||||
loc_net = Conv(data=loc_net, num_filter=mFilters, kernel=(3,3), stride=(1,1), pad=(1,1),
|
||||
no_bias=True, name="stn_conv1", workspace=workspace, attr={'lr_mult': lr_mult})
|
||||
loc_net = mx.sym.BatchNorm(data=loc_net, fix_gamma=False, eps=2e-5, momentum=bn_mom, name='stn_bn1', attr={'lr_mult': lr_mult})
|
||||
loc_net = Act(data=loc_net, act_type='relu', name='stn_relu1')
|
||||
loc_net = mx.sym.Pooling(data=loc_net, kernel=(2, 2), stride=(2,2), pad=(0,0), pool_type='max')
|
||||
|
||||
loc_net = mx.sym.FullyConnected(data=loc_net, num_hidden=int(nFilters*0.5), name='loc_net_half', attr={'lr_mult': lr_mult})
|
||||
loc_net = mx.sym.Activation(data=loc_net, act_type='tanh', name='loc_net_act')
|
||||
#loc_net = mx.sym.Activation(data=loc_net, act_type='relu', name='loc_net_act')
|
||||
#loc_theta = mx.sym.FullyConnected(data=loc_net, num_hidden=6, name='loc_theta', attr={'lr_mult': lr_mult})
|
||||
#loc_theta = mx.sym.Activation(data=loc_theta, act_type='tanh', name='loc_theta_tanh')
|
||||
loc_theta = mx.sym.FullyConnected(data=loc_net, num_hidden=1, name='loc_theta', attr={'lr_mult': lr_mult})
|
||||
loc_theta = mx.sym.Activation(data=loc_theta, act_type='tanh', name='loc_theta_tanh')
|
||||
loc_theta = loc_theta*0.5
|
||||
sin_t = mx.sym.sin(loc_theta)
|
||||
m_sin_t = sin_t*-1.0
|
||||
cos_t = mx.sym.cos(loc_theta)
|
||||
zero_t = mx.sym.zeros_like(loc_theta)
|
||||
loc_theta = mx.sym.concat(*[cos_t, m_sin_t, zero_t, sin_t, cos_t, zero_t], dim=1)
|
||||
data = mx.sym.SpatialTransformer(data = data, loc = loc_theta, target_shape=(input_size,input_size), transform_type="affine", sampler_type="bilinear")
|
||||
ref_label = mx.sym.SpatialTransformer(data = ref_label, loc = loc_theta, target_shape=(label_size,label_size), transform_type="affine", sampler_type="bilinear")
|
||||
#data = mx.sym.BatchNorm(data=data, fix_gamma=False, eps=2e-5, momentum=bn_mom, name='bn_data')
|
||||
if D==4:
|
||||
body = Conv(data=data, num_filter=sFilters, kernel=(7, 7), stride=(2,2), pad=(3, 3),
|
||||
no_bias=True, name="conv0", workspace=workspace)
|
||||
else:
|
||||
body = Conv(data=data, num_filter=sFilters, kernel=(3, 3), stride=(1,1), pad=(1, 1),
|
||||
no_bias=True, name="conv0", workspace=workspace)
|
||||
body = mx.sym.BatchNorm(data=body, fix_gamma=False, eps=2e-5, momentum=bn_mom, name='bn0')
|
||||
body = Act(data=body, act_type='relu', name='relu0')
|
||||
|
||||
body = residual_unit(body, mFilters, (1,1), sFilters==mFilters, 'res0', False, dcn, 1, **kwargs)
|
||||
#body = residual_unit(body, nFilters, (1,1), False, 'res0', binarize, **kwargs)
|
||||
|
||||
body = mx.sym.Pooling(data=body, kernel=(2, 2), stride=(2,2), pad=(0,0), pool_type='max')
|
||||
|
||||
body = residual_unit(body, mFilters, (1,1), True, 'res1', False, dcn, 1, **kwargs) #TODO
|
||||
#body = residual_unit(body, nFilters, (1,1), True, 'res1', binarize, **kwargs) #TODO
|
||||
body = residual_unit(body, nFilters, (1,1), mFilters==nFilters, 'res2', binarize, dcn, 1, **kwargs) #binarize=True?
|
||||
#body = residual_unit(body, nFilters, (1,1), False, 'res2', False, **kwargs) #binarize=True?
|
||||
|
||||
use_lin = True
|
||||
heatmap = None
|
||||
|
||||
for i in xrange(nStacks):
|
||||
shortcut = body
|
||||
if use_DLA>0:
|
||||
dla = DLA(body, nFilters, nModules, N+1, workspace, 'dla%d'%(i))
|
||||
body = dla.get()
|
||||
else:
|
||||
body = hourglass(body, nFilters, nModules, N, workspace, 'stack%d_hg'%(i), binarize, dcn)
|
||||
for j in xrange(nModules):
|
||||
body = residual_unit(body, nFilters, (1,1), True, 'stack%d_unit%d'%(i,j), binarize, dcn, 1, **kwargs)
|
||||
if use_lin:
|
||||
_dcn = True if DCN>=2 else False
|
||||
ll = lin(body, nFilters, workspace, name='stack%d_ll'%(i), binarize = False, dcn = _dcn) #TODO
|
||||
#ll = lin(body, nFilters, workspace, name='stack%d_ll'%(i), binarize = binarize)
|
||||
else:
|
||||
ll = body
|
||||
_name = "heatmap%d"%(i)
|
||||
if i==nStacks-1:
|
||||
_name = "heatmap"
|
||||
|
||||
_dcn = True if DCN>=2 else False
|
||||
if not _dcn:
|
||||
out = Conv(data=ll, num_filter=num_classes, kernel=(1, 1), stride=(1,1), pad=(0,0),
|
||||
name=_name, workspace=workspace)
|
||||
else:
|
||||
out_offset = mx.symbol.Convolution(name=_name+'_offset', data = ll,
|
||||
num_filter=18, pad=(1, 1), kernel=(3, 3), stride=(1, 1))
|
||||
out = mx.contrib.symbol.DeformableConvolution(name=_name, data=ll, offset=out_offset,
|
||||
num_filter=num_classes, pad=(1,1), kernel=(3, 3), num_deformable_group=1, stride=(1, 1), dilate=(1, 1), no_bias=False)
|
||||
#out = Conv(data=ll, num_filter=num_classes, kernel=(3,3), stride=(1,1), pad=(1,1),
|
||||
# name=_name, workspace=workspace)
|
||||
if i==nStacks-1:
|
||||
heatmap = out
|
||||
#outs.append(out)
|
||||
if use_coherent>0:
|
||||
#px = mx.sym.slice_axis(out, axis=0, begin=0, end=b)
|
||||
#py = mx.sym.slice_axis(ref_label, axis=0, begin=0, end=b)
|
||||
px = out
|
||||
py = ref_label
|
||||
gloss = ce_loss(px, py)
|
||||
gloss = gloss/nStacks
|
||||
losses.append(gloss)
|
||||
|
||||
b = per_batch_size//2
|
||||
ux = mx.sym.slice_axis(out, axis=0, begin=0, end=b)
|
||||
dx = mx.sym.slice_axis(out, axis=0, begin=b, end=b*2)
|
||||
if use_coherent==1:
|
||||
ux = mx.sym.flip(ux, axis=3)
|
||||
ux_list = [None]*68
|
||||
for k in xrange(68):
|
||||
if k in mirror_map:
|
||||
vk = mirror_map[k]
|
||||
#print('k', k, vk)
|
||||
ux_list[vk] = mx.sym.slice_axis(ux, axis=1, begin=k, end=k+1)
|
||||
else:
|
||||
ux_list[k] = mx.sym.slice_axis(ux, axis=1, begin=k, end=k+1)
|
||||
ux = mx.sym.concat(*ux_list, dim=1)
|
||||
#dx = mx.sym.slice_axis(ref_label, axis=0, begin=b, end=b*2)
|
||||
#closs = ce_loss(ux, dx)
|
||||
closs = l2_loss(ux, dx)
|
||||
closs = closs/nStacks
|
||||
closses.append(closs)
|
||||
else:
|
||||
m = mx.sym.slice_axis(M, axis=0, begin=0, end=b)
|
||||
ux = mx.sym.SpatialTransformer(data=ux, loc=m, target_shape=(label_size, label_size), transform_type='affine', sampler_type='bilinear')
|
||||
closs = l2_loss(ux, dx)
|
||||
closs = closs/nStacks
|
||||
closses.append(closs)
|
||||
|
||||
else:
|
||||
loss = ce_loss(out, ref_label)
|
||||
loss = loss/nStacks
|
||||
losses.append(loss)
|
||||
|
||||
if i<nStacks-1:
|
||||
if use_lin:
|
||||
ll2 = Conv(data=ll, num_filter=nFilters, kernel=(1, 1), stride=(1,1), pad=(0,0),
|
||||
name="stack%d_ll2"%(i), workspace=workspace)
|
||||
else:
|
||||
ll2 = body
|
||||
out2 = Conv(data=out, num_filter=nFilters, kernel=(1, 1), stride=(1,1), pad=(0,0),
|
||||
name="stack%d_out2"%(i), workspace=workspace)
|
||||
body = mx.symbol.add_n(shortcut, ll2, out2)
|
||||
_dcn = True if (DCN==1 or DCN==3) else False
|
||||
if _dcn:
|
||||
_name = "stack%d_out3" % (i)
|
||||
out3_offset = mx.symbol.Convolution(name=_name+'_offset', data = body,
|
||||
num_filter=18, pad=(1, 1), kernel=(3, 3), stride=(1, 1))
|
||||
out3 = mx.contrib.symbol.DeformableConvolution(name=_name, data=body, offset=out3_offset,
|
||||
num_filter=nFilters, pad=(1,1), kernel=(3, 3), num_deformable_group=1, stride=(1, 1), dilate=(1, 1), no_bias=False)
|
||||
body = out3
|
||||
#elif use_STN:
|
||||
# loc_net = dla.get2()
|
||||
# #loc_net = mx.sym.Pooling(data=loc_net, global_pool=True, kernel=(7, 7), pool_type='avg', name='loc_net_pool')
|
||||
# loc_net = mx.sym.FullyConnected(data=loc_net, num_hidden=int(nFilters*0.5), name='loc_net_half', attr={'lr_mult': '0.0001'})
|
||||
# loc_net = mx.sym.Activation(data=loc_net, act_type='tanh', name='loc_net_act')
|
||||
# #loc_net = mx.sym.Activation(data=loc_net, act_type='relu', name='loc_net_act')
|
||||
# loc_theta = mx.sym.FullyConnected(data=loc_net, num_hidden=6, name='loc_theta', attr={'lr_mult': '0.0001'})
|
||||
# loc_theta = mx.sym.Activation(data=loc_theta, act_type='tanh', name='loc_theta_tanh')
|
||||
# body = mx.sym.SpatialTransformer(data = body, loc = loc_theta, target_shape=(label_size,label_size), transform_type="affine", sampler_type="bilinear")
|
||||
# ref_label = mx.sym.SpatialTransformer(data = gt_label, loc = loc_theta, target_shape=(label_size,label_size), transform_type="affine", sampler_type="bilinear")
|
||||
|
||||
pred = mx.symbol.BlockGrad(heatmap)
|
||||
loss = mx.symbol.add_n(*losses)
|
||||
|
||||
loss = mx.symbol.MakeLoss(loss)
|
||||
syms = [loss]
|
||||
if len(closses)>0:
|
||||
closs = mx.symbol.add_n(*closses)
|
||||
closs = mx.symbol.MakeLoss(closs, grad_scale = coherent_weight)
|
||||
syms.append(closs)
|
||||
#syms.append(mx.symbol.BlockGrad(M))
|
||||
#syms.append(mx.symbol.BlockGrad(px))
|
||||
#syms.append(mx.symbol.BlockGrad(qx))
|
||||
#syms.append(mx.symbol.BlockGrad(m))
|
||||
#syms.append(mx.symbol.BlockGrad(closs))
|
||||
if use_coherent>1:
|
||||
syms.append(mx.symbol.BlockGrad(gt_label))
|
||||
if use_coherent>0:
|
||||
syms.append(mx.symbol.BlockGrad(M))
|
||||
syms.append(pred)
|
||||
sym = mx.symbol.Group( syms )
|
||||
return sym
|
||||
|
||||
def init_weights(sym, data_shape_dict):
|
||||
print('in hg2')
|
||||
arg_name = sym.list_arguments()
|
||||
aux_name = sym.list_auxiliary_states()
|
||||
arg_shape, _, aux_shape = sym.infer_shape(**data_shape_dict)
|
||||
arg_shape_dict = dict(zip(arg_name, arg_shape))
|
||||
aux_shape_dict = dict(zip(aux_name, aux_shape))
|
||||
#print(aux_shape)
|
||||
#print(aux_params)
|
||||
#print(arg_shape_dict)
|
||||
arg_params = {}
|
||||
aux_params = {}
|
||||
for k,v in arg_shape_dict.iteritems():
|
||||
#print(k,v)
|
||||
if k.endswith('offset_weight') or k.endswith('offset_bias'):
|
||||
print('initializing',k)
|
||||
arg_params[k] = mx.nd.zeros(shape = v)
|
||||
elif k.startswith('fc6_'):
|
||||
if k.endswith('_weight'):
|
||||
print('initializing',k)
|
||||
arg_params[k] = mx.random.normal(0, 0.01, shape=v)
|
||||
elif k.endswith('_bias'):
|
||||
print('initializing',k)
|
||||
arg_params[k] = mx.nd.zeros(shape=v)
|
||||
elif k.find('upsampling')>=0:
|
||||
print('initializing upsampling_weight', k)
|
||||
arg_params[k] = mx.nd.zeros(shape=arg_shape_dict[k])
|
||||
init = mx.init.Initializer()
|
||||
init._init_bilinear(k, arg_params[k])
|
||||
elif k.find('loc_theta')>=0:
|
||||
print('initializing STN', k, v)
|
||||
if k.endswith('_weight'):
|
||||
arg_params[k] = mx.nd.zeros(shape=v)
|
||||
elif k.endswith('_bias'):
|
||||
#val = np.array([4.0, 0.0, 0.0, 0.0, 4.0, 0.0], dtype=np.float32)
|
||||
#val = np.array([0.0], dtype=np.float32)
|
||||
#arg_params[k] = mx.nd.array(val)
|
||||
arg_params[k] = mx.random.normal(0, 0.01, shape=v)
|
||||
return arg_params, aux_params
|
||||
|
||||
145
alignment/img_helper.py
Normal file
145
alignment/img_helper.py
Normal file
@@ -0,0 +1,145 @@
|
||||
import numpy as np
|
||||
import scipy.misc
|
||||
import scipy.signal
|
||||
import math
|
||||
|
||||
#import draw
|
||||
#import ref
|
||||
|
||||
# =============================================================================
|
||||
# General image processing functions
|
||||
# =============================================================================
|
||||
|
||||
def get_transform(center, scale, res, rot=0):
|
||||
# Generate transformation matrix
|
||||
#h = 200 * scale
|
||||
#h = 100 * scale
|
||||
h = scale
|
||||
t = np.zeros((3, 3))
|
||||
t[0, 0] = float(res[1]) / h
|
||||
t[1, 1] = float(res[0]) / h
|
||||
t[0, 2] = res[1] * (-float(center[0]) / h + .5)
|
||||
t[1, 2] = res[0] * (-float(center[1]) / h + .5)
|
||||
t[2, 2] = 1
|
||||
if not rot == 0:
|
||||
rot = -rot # To match direction of rotation from cropping
|
||||
rot_mat = np.zeros((3,3))
|
||||
rot_rad = rot * np.pi / 180
|
||||
sn,cs = np.sin(rot_rad), np.cos(rot_rad)
|
||||
rot_mat[0,:2] = [cs, -sn]
|
||||
rot_mat[1,:2] = [sn, cs]
|
||||
rot_mat[2,2] = 1
|
||||
# Need to rotate around center
|
||||
t_mat = np.eye(3)
|
||||
t_mat[0,2] = -res[1]/2
|
||||
t_mat[1,2] = -res[0]/2
|
||||
t_inv = t_mat.copy()
|
||||
t_inv[:2,2] *= -1
|
||||
t = np.dot(t_inv,np.dot(rot_mat,np.dot(t_mat,t)))
|
||||
return t
|
||||
|
||||
def transform(pt, center, scale, res, invert=0, rot=0):
|
||||
# Transform pixel location to different reference
|
||||
t = get_transform(center, scale, res, rot=rot)
|
||||
if invert:
|
||||
t = np.linalg.inv(t)
|
||||
new_pt = np.array([pt[0], pt[1], 1.]).T
|
||||
new_pt = np.dot(t, new_pt)
|
||||
#print('new_pt', new_pt.shape, new_pt)
|
||||
return new_pt[:2].astype(int)
|
||||
|
||||
def crop_center(img,crop_size):
|
||||
y,x = img.shape[0], img.shape[1]
|
||||
startx = x//2-(crop_size[1]//2)
|
||||
starty = y//2-(crop_size[0]//2)
|
||||
#print(startx, starty, crop_size)
|
||||
return img[starty:(starty+crop_size[0]),startx:(startx+crop_size[1]),:]
|
||||
|
||||
def crop(img, center, scale, res, rot=0):
|
||||
# Upper left point
|
||||
ul = np.array(transform([0, 0], center, scale, res, invert=1))
|
||||
# Bottom right point
|
||||
br = np.array(transform(res, center, scale, res, invert=1))
|
||||
|
||||
# Padding so that when rotated proper amount of context is included
|
||||
pad = int(np.linalg.norm(br - ul) / 2 - float(br[1] - ul[1]) / 2)
|
||||
if not rot == 0:
|
||||
ul -= pad
|
||||
br += pad
|
||||
|
||||
new_shape = [br[1] - ul[1], br[0] - ul[0]]
|
||||
if len(img.shape) > 2:
|
||||
new_shape += [img.shape[2]]
|
||||
new_img = np.zeros(new_shape)
|
||||
#print('new_img', new_img.shape)
|
||||
|
||||
# Range to fill new array
|
||||
new_x = max(0, -ul[0]), min(br[0], len(img[0])) - ul[0]
|
||||
new_y = max(0, -ul[1]), min(br[1], len(img)) - ul[1]
|
||||
# Range to sample from original image
|
||||
old_x = max(0, ul[0]), min(len(img[0]), br[0])
|
||||
old_y = max(0, ul[1]), min(len(img), br[1])
|
||||
new_img[new_y[0]:new_y[1], new_x[0]:new_x[1]] = img[old_y[0]:old_y[1], old_x[0]:old_x[1]]
|
||||
|
||||
if not rot == 0:
|
||||
# Remove padding
|
||||
#print('before rotate', new_img.shape, rot)
|
||||
new_img = scipy.misc.imrotate(new_img, rot)
|
||||
new_img = new_img[pad:-pad, pad:-pad]
|
||||
|
||||
return scipy.misc.imresize(new_img, res)
|
||||
|
||||
def crop2(img, center, scale, res, rot=0):
|
||||
# Upper left point
|
||||
rad = np.min( [center[0], img.shape[0] - center[0], center[1], img.shape[1] - center[1]] )
|
||||
new_img = img[(center[0]-rad):(center[0]+rad),(center[1]-rad):(center[1]+rad),:]
|
||||
#print('new_img', new_img.shape)
|
||||
if not rot == 0:
|
||||
new_img = scipy.misc.imrotate(new_img, rot)
|
||||
new_img = crop_center(new_img, (scale,scale))
|
||||
return scipy.misc.imresize(new_img, res)
|
||||
|
||||
def nms(img):
|
||||
# Do non-maximum suppression on a 2D array
|
||||
win_size = 3
|
||||
domain = np.ones((win_size, win_size))
|
||||
maxes = scipy.signal.order_filter(img, domain, win_size ** 2 - 1)
|
||||
diff = maxes - img
|
||||
result = img.copy()
|
||||
result[diff > 0] = 0
|
||||
return result
|
||||
|
||||
|
||||
def gaussian(img, pt, sigma):
|
||||
# Draw a 2D gaussian
|
||||
assert(sigma>0)
|
||||
|
||||
# Check that any part of the gaussian is in-bounds
|
||||
ul = [int(pt[0] - 3 * sigma), int(pt[1] - 3 * sigma)]
|
||||
br = [int(pt[0] + 3 * sigma + 1), int(pt[1] + 3 * sigma + 1)]
|
||||
if (ul[0] > img.shape[1] or ul[1] >= img.shape[0] or
|
||||
br[0] < 0 or br[1] < 0):
|
||||
# If not, just return the image as is
|
||||
#print('gaussian error')
|
||||
return False
|
||||
#return img
|
||||
|
||||
# Generate gaussian
|
||||
size = 6 * sigma + 1
|
||||
x = np.arange(0, size, 1, float)
|
||||
y = x[:, np.newaxis]
|
||||
x0 = y0 = size // 2
|
||||
# The gaussian is not normalized, we want the center value to equal 1
|
||||
g = np.exp(- ((x - x0) ** 2 + (y - y0) ** 2) / (2 * sigma ** 2))
|
||||
|
||||
# Usable gaussian range
|
||||
g_x = max(0, -ul[0]), min(br[0], img.shape[1]) - ul[0]
|
||||
g_y = max(0, -ul[1]), min(br[1], img.shape[0]) - ul[1]
|
||||
# Image range
|
||||
img_x = max(0, ul[0]), min(br[0], img.shape[1])
|
||||
img_y = max(0, ul[1]), min(br[1], img.shape[0])
|
||||
|
||||
img[img_y[0]:img_y[1], img_x[0]:img_x[1]] = g[g_y[0]:g_y[1], g_x[0]:g_x[1]]
|
||||
return True
|
||||
#return img
|
||||
|
||||
56
alignment/infer.py
Normal file
56
alignment/infer.py
Normal file
@@ -0,0 +1,56 @@
|
||||
import argparse
|
||||
import cv2
|
||||
import numpy as np
|
||||
import sys
|
||||
import mxnet as mx
|
||||
|
||||
parser = argparse.ArgumentParser(description='face model test')
|
||||
# general
|
||||
parser.add_argument('--image-size', default='128,128', help='')
|
||||
parser.add_argument('--model', default='./models/test,15', help='path to load model.')
|
||||
parser.add_argument('--gpu', default=-1, type=int, help='gpu id')
|
||||
args = parser.parse_args()
|
||||
|
||||
_vec = args.image_size.split(',')
|
||||
assert len(_vec)==2
|
||||
image_size = (int(_vec[0]), int(_vec[1]))
|
||||
_vec = args.model.split(',')
|
||||
assert len(_vec)==2
|
||||
prefix = _vec[0]
|
||||
epoch = int(_vec[1])
|
||||
print('loading',prefix, epoch)
|
||||
if args.gpu>=0:
|
||||
ctx = mx.gpu(args.gpu)
|
||||
else:
|
||||
ctx = mx.cpu()
|
||||
sym, arg_params, aux_params = mx.model.load_checkpoint(prefix, epoch)
|
||||
all_layers = sym.get_internals()
|
||||
sym = all_layers['heatmap_output']
|
||||
model = mx.mod.Module(symbol=sym, context=ctx, label_names = None)
|
||||
#model.bind(data_shapes=[('data', (args.batch_size, 3, image_size[0], image_size[1]))], label_shapes=[('softmax_label', (args.batch_size,))])
|
||||
model.bind(data_shapes=[('data', (1, 3, image_size[0], image_size[1]))])
|
||||
model.set_params(arg_params, aux_params)
|
||||
#img_path = '/raid5data/dplearn/megaface/facescrubr/112x112/Tom_Hanks/Tom_Hanks_54745.png'
|
||||
img_path = './test.png'
|
||||
|
||||
img = cv2.imread(img_path)
|
||||
|
||||
rimg = cv2.resize(img, (image_size[1], image_size[0]))
|
||||
img = cv2.cvtColor(rimg, cv2.COLOR_BGR2RGB)
|
||||
img = np.transpose(img, (2,0,1)) #3*112*112, RGB
|
||||
input_blob = np.expand_dims(img, axis=0) #1*3*112*112
|
||||
data = mx.nd.array(input_blob)
|
||||
db = mx.io.DataBatch(data=(data,))
|
||||
model.forward(db, is_train=False)
|
||||
output = model.get_outputs()[0].asnumpy()
|
||||
#print(output[0,80])
|
||||
#sys.exit(0)
|
||||
filename = "./vis/draw_%s" % img_path.split('/')[-1]
|
||||
for i in xrange(output.shape[1]):
|
||||
a = output[0,i,:,:]
|
||||
a = cv2.resize(a, (image_size[1], image_size[0]))
|
||||
ind = np.unravel_index(np.argmax(a, axis=None), a.shape)
|
||||
cv2.circle(rimg, (ind[1], ind[0]), 1, (0, 0, 255), 2)
|
||||
print(i, ind)
|
||||
cv2.imwrite(filename, rimg)
|
||||
|
||||
355
alignment/train.py
Normal file
355
alignment/train.py
Normal file
@@ -0,0 +1,355 @@
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
#import hg
|
||||
import hg2 as hg
|
||||
import logging
|
||||
import argparse
|
||||
from data import FaceSegIter
|
||||
import mxnet as mx
|
||||
import mxnet.optimizer as optimizer
|
||||
import numpy as np
|
||||
import os
|
||||
import sys
|
||||
import math
|
||||
import random
|
||||
import cv2
|
||||
|
||||
|
||||
args = None
|
||||
|
||||
logger = logging.getLogger()
|
||||
logger.setLevel(logging.INFO)
|
||||
|
||||
class LossValueMetric(mx.metric.EvalMetric):
|
||||
def __init__(self):
|
||||
self.axis = 1
|
||||
super(LossValueMetric, self).__init__(
|
||||
'lossvalue', axis=self.axis,
|
||||
output_names=None, label_names=None)
|
||||
self.losses = []
|
||||
|
||||
def update(self, labels, preds):
|
||||
loss = preds[0].asnumpy()[0]
|
||||
self.sum_metric += loss
|
||||
self.num_inst += 1.0
|
||||
#label = preds[1].asnumpy()
|
||||
#out0 = preds[1].asnumpy()[0][0]
|
||||
#l = 10
|
||||
#out0 = out0[20:20+l, 20:20+l]
|
||||
#out1 = preds[2].asnumpy()[0][0]
|
||||
#out1 = out1[20:20+l, 20:20+l]
|
||||
|
||||
#m = preds[3].asnumpy()[0]
|
||||
#theta = np.arcsin(m[3])
|
||||
#theta = theta/math.pi*180
|
||||
#print(out0)
|
||||
#print('')
|
||||
#print(out1)
|
||||
#print('')
|
||||
#print(m, theta)
|
||||
#print(label[0])
|
||||
#for i in xrange(gt_label.shape[0]):
|
||||
# label0 = gt_label[i][0]
|
||||
# c = np.count_nonzero(label0)
|
||||
# ind = np.unravel_index(np.argmax(label0, axis=None), label0.shape)
|
||||
# print('A', i, ind, label0.shape, c)
|
||||
|
||||
|
||||
|
||||
class NMEMetric(mx.metric.EvalMetric):
|
||||
def __init__(self):
|
||||
self.axis = 1
|
||||
super(NMEMetric, self).__init__(
|
||||
'NME', axis=self.axis,
|
||||
output_names=None, label_names=None)
|
||||
#self.losses = []
|
||||
self.count = 0
|
||||
|
||||
def update(self, labels, preds):
|
||||
self.count+=1
|
||||
preds = [preds[-1]]
|
||||
for label, pred_label in zip(labels, preds):
|
||||
label = label.asnumpy()
|
||||
pred_label = pred_label.asnumpy()
|
||||
#print('acc',label.shape, pred_label.shape)
|
||||
|
||||
nme = []
|
||||
for b in xrange(pred_label.shape[0]):
|
||||
for p in xrange(pred_label.shape[1]):
|
||||
heatmap_gt = label[b][p]
|
||||
heatmap_pred = pred_label[b][p]
|
||||
heatmap_pred = cv2.resize(heatmap_pred, (label.shape[2], label.shape[3]))
|
||||
ind_gt = np.unravel_index(np.argmax(heatmap_gt, axis=None), heatmap_gt.shape)
|
||||
ind_pred = np.unravel_index(np.argmax(heatmap_pred, axis=None), heatmap_pred.shape)
|
||||
ind_gt = np.array(ind_gt)
|
||||
ind_pred = np.array(ind_pred)
|
||||
dist = np.sqrt(np.sum(np.square(ind_gt - ind_pred)))
|
||||
nme.append(dist)
|
||||
nme = np.mean(nme)
|
||||
nme /= np.sqrt(float(label.shape[2]*label.shape[3]))
|
||||
|
||||
self.sum_metric += nme
|
||||
self.num_inst += 1.0
|
||||
|
||||
class NMEMetric2(mx.metric.EvalMetric):
|
||||
def __init__(self):
|
||||
self.axis = 1
|
||||
super(NMEMetric2, self).__init__(
|
||||
'NME2', axis=self.axis,
|
||||
output_names=None, label_names=None)
|
||||
#self.losses = []
|
||||
self.count = 0
|
||||
|
||||
def update(self, labels, preds):
|
||||
self.count+=1
|
||||
preds = [preds[-1]]
|
||||
for label, pred_label in zip(labels, preds):
|
||||
label = label.asnumpy()
|
||||
pred_label = pred_label.asnumpy()
|
||||
#print('label', np.count_nonzero(label[0][36]))
|
||||
#print('acc',label.shape, pred_label.shape)
|
||||
#print(label.ndim)
|
||||
|
||||
nme = []
|
||||
for b in xrange(pred_label.shape[0]):
|
||||
record = [None]*6
|
||||
item = []
|
||||
if label.ndim==4:
|
||||
_heatmap = label[b][36]
|
||||
if np.count_nonzero(_heatmap)==0:
|
||||
continue
|
||||
else:#ndim==3
|
||||
#print(label[b])
|
||||
if np.count_nonzero(label[b])==0:
|
||||
continue
|
||||
for p in xrange(pred_label.shape[1]):
|
||||
if label.ndim==4:
|
||||
heatmap_gt = label[b][p]
|
||||
ind_gt = np.unravel_index(np.argmax(heatmap_gt, axis=None), heatmap_gt.shape)
|
||||
ind_gt = np.array(ind_gt)
|
||||
else:
|
||||
ind_gt = label[b][p]
|
||||
#ind_gt = ind_gt.astype(np.int)
|
||||
#print(ind_gt)
|
||||
heatmap_pred = pred_label[b][p]
|
||||
heatmap_pred = cv2.resize(heatmap_pred, (args.input_img_size, args.input_img_size))
|
||||
ind_pred = np.unravel_index(np.argmax(heatmap_pred, axis=None), heatmap_pred.shape)
|
||||
ind_pred = np.array(ind_pred)
|
||||
#print(ind_gt.shape)
|
||||
#print(ind_pred)
|
||||
if p==36:
|
||||
#print('b', b, p, ind_gt, np.count_nonzero(heatmap_gt))
|
||||
record[0] = ind_gt
|
||||
elif p==39:
|
||||
record[1] = ind_gt
|
||||
elif p==42:
|
||||
record[2] = ind_gt
|
||||
elif p==45:
|
||||
record[3] = ind_gt
|
||||
if record[4] is None or record[5] is None:
|
||||
record[4] = ind_gt
|
||||
record[5] = ind_gt
|
||||
else:
|
||||
record[4] = np.minimum(record[4], ind_gt)
|
||||
record[5] = np.maximum(record[5], ind_gt)
|
||||
#print(ind_gt.shape, ind_pred.shape)
|
||||
value = np.sqrt(np.sum(np.square(ind_gt - ind_pred)))
|
||||
item.append(value)
|
||||
_nme = np.mean(item)
|
||||
if args.norm_type=='2d':
|
||||
left_eye = (record[0]+record[1])/2
|
||||
right_eye = (record[2]+record[3])/2
|
||||
_dist = np.sqrt(np.sum(np.square(left_eye - right_eye)))
|
||||
#print('eye dist', _dist, left_eye, right_eye)
|
||||
_nme /= _dist
|
||||
else:
|
||||
#_dist = np.sqrt(float(label.shape[2]*label.shape[3]))
|
||||
_dist = np.sqrt(np.sum(np.square(record[5] - record[4])))
|
||||
#print(_dist)
|
||||
_nme /= _dist
|
||||
nme.append(_nme)
|
||||
#print('nme', nme)
|
||||
#nme = np.mean(nme)
|
||||
|
||||
if len(nme)>0:
|
||||
self.sum_metric += np.mean(nme)
|
||||
self.num_inst += 1.0
|
||||
|
||||
def main(args):
|
||||
_seed = 727
|
||||
random.seed(_seed)
|
||||
np.random.seed(_seed)
|
||||
mx.random.seed(_seed)
|
||||
ctx = []
|
||||
cvd = os.environ['CUDA_VISIBLE_DEVICES'].strip()
|
||||
if len(cvd)>0:
|
||||
for i in xrange(len(cvd.split(','))):
|
||||
ctx.append(mx.gpu(i))
|
||||
if len(ctx)==0:
|
||||
ctx = [mx.cpu()]
|
||||
print('use cpu')
|
||||
else:
|
||||
print('gpu num:', len(ctx))
|
||||
#ctx = [mx.gpu(0)]
|
||||
args.ctx_num = len(ctx)
|
||||
|
||||
args.batch_size = args.per_batch_size*args.ctx_num
|
||||
|
||||
|
||||
print('Call with', args)
|
||||
train_iter = FaceSegIter(path_imgrec = os.path.join(args.data_dir, 'train.rec'),
|
||||
batch_size = args.batch_size,
|
||||
per_batch_size = args.per_batch_size,
|
||||
aug_level = 1,
|
||||
use_coherent = args.use_coherent,
|
||||
args = args,
|
||||
)
|
||||
targets = ['ibug', 'cofw_testset', '300W', 'AFLW2000-3D']
|
||||
|
||||
data_shape = train_iter.get_data_shape()
|
||||
#label_shape = train_iter.get_label_shape()
|
||||
sym = hg.get_symbol(num_classes=args.num_classes, binarize=args.binarize, label_size=args.output_label_size, input_size=args.input_img_size, use_coherent = args.use_coherent, use_dla = args.use_dla, use_N = args.use_N, use_DCN = args.use_DCN, per_batch_size = args.per_batch_size)
|
||||
if len(args.pretrained)==0:
|
||||
#data_shape_dict = {'data' : (args.per_batch_size,)+data_shape, 'softmax_label' : (args.per_batch_size,)+label_shape}
|
||||
data_shape_dict = train_iter.get_shape_dict()
|
||||
arg_params, aux_params = hg.init_weights(sym, data_shape_dict)
|
||||
else:
|
||||
vec = args.pretrained.split(',')
|
||||
print('loading', vec)
|
||||
_, arg_params, aux_params = mx.model.load_checkpoint(vec[0], int(vec[1]))
|
||||
#sym, arg_params, aux_params = get_symbol(args, arg_params, aux_params)
|
||||
|
||||
model = mx.mod.Module(
|
||||
context = ctx,
|
||||
symbol = sym,
|
||||
label_names = train_iter.get_label_names(),
|
||||
)
|
||||
#lr = 1.0e-3
|
||||
#lr = 2.5e-4
|
||||
lr = args.lr
|
||||
#_rescale_grad = 1.0
|
||||
_rescale_grad = 1.0/args.ctx_num
|
||||
#lr = args.lr
|
||||
#opt = optimizer.SGD(learning_rate=lr, momentum=0.9, wd=5.e-4, rescale_grad=_rescale_grad)
|
||||
#opt = optimizer.Adam(learning_rate=lr, wd=args.wd, rescale_grad=_rescale_grad)
|
||||
opt = optimizer.Nadam(learning_rate=lr, wd=args.wd, rescale_grad=_rescale_grad, clip_gradient=5.0)
|
||||
#opt = optimizer.RMSProp(learning_rate=lr, wd=args.wd, rescale_grad=_rescale_grad)
|
||||
initializer = mx.init.Xavier(rnd_type='gaussian', factor_type="in", magnitude=2)
|
||||
_cb = mx.callback.Speedometer(args.batch_size, 10)
|
||||
_metric = LossValueMetric()
|
||||
#_metric2 = AccMetric()
|
||||
#eval_metrics = [_metric, _metric2]
|
||||
eval_metrics = [_metric]
|
||||
|
||||
#lr_steps = [40000,60000,80000]
|
||||
#lr_steps = [12000,18000,22000]
|
||||
if len(args.lr_steps)==0:
|
||||
lr_steps = [16000,24000,30000]
|
||||
#lr_steps = [14000,24000,30000]
|
||||
#lr_steps = [5000,10000]
|
||||
else:
|
||||
lr_steps = [int(x) for x in args.lr_steps.split(',')]
|
||||
_a = 40//args.batch_size
|
||||
for i in xrange(len(lr_steps)):
|
||||
lr_steps[i] *= _a
|
||||
print('lr-steps', lr_steps)
|
||||
global_step = [0]
|
||||
|
||||
def val_test():
|
||||
all_layers = sym.get_internals()
|
||||
vsym = all_layers['heatmap_output']
|
||||
vmodel = mx.mod.Module(symbol=vsym, context=ctx, label_names = None)
|
||||
#model.bind(data_shapes=[('data', (args.batch_size, 3, image_size[0], image_size[1]))], label_shapes=[('softmax_label', (args.batch_size,))])
|
||||
vmodel.bind(data_shapes=[('data', (args.batch_size,)+data_shape)])
|
||||
arg_params, aux_params = model.get_params()
|
||||
vmodel.set_params(arg_params, aux_params)
|
||||
for target in targets:
|
||||
_file = os.path.join(args.data_dir, '%s.rec'%target)
|
||||
if not os.path.exists(_file):
|
||||
continue
|
||||
val_iter = FaceSegIter(path_imgrec = _file,
|
||||
batch_size = args.batch_size,
|
||||
#batch_size = 4,
|
||||
aug_level = 0,
|
||||
args = args,
|
||||
)
|
||||
_metric = NMEMetric2()
|
||||
val_metric = mx.metric.create(_metric)
|
||||
val_metric.reset()
|
||||
val_iter.reset()
|
||||
for i, eval_batch in enumerate(val_iter):
|
||||
#print(eval_batch.data[0].shape, eval_batch.label[0].shape)
|
||||
batch_data = mx.io.DataBatch(eval_batch.data)
|
||||
model.forward(batch_data, is_train=False)
|
||||
model.update_metric(val_metric, eval_batch.label)
|
||||
nme_value = val_metric.get_name_value()[0][1]
|
||||
print('[%d][%s]NME: %f'%(global_step[0], target, nme_value))
|
||||
|
||||
def _batch_callback(param):
|
||||
_cb(param)
|
||||
global_step[0]+=1
|
||||
mbatch = global_step[0]
|
||||
for _lr in lr_steps:
|
||||
if mbatch==_lr:
|
||||
opt.lr *= 0.2
|
||||
print('lr change to', opt.lr)
|
||||
break
|
||||
if mbatch%1000==0:
|
||||
print('lr-batch-epoch:',opt.lr,param.nbatch,param.epoch)
|
||||
if mbatch>0 and mbatch%args.verbose==0:
|
||||
val_test()
|
||||
if args.ckpt>0:
|
||||
msave = mbatch//args.verbose
|
||||
print('saving', msave)
|
||||
arg, aux = model.get_params()
|
||||
mx.model.save_checkpoint(args.prefix, msave, model.symbol, arg, aux)
|
||||
|
||||
|
||||
model.fit(train_iter,
|
||||
begin_epoch = 0,
|
||||
num_epoch = args.end_epoch,
|
||||
#eval_data = val_iter,
|
||||
eval_data = None,
|
||||
eval_metric = eval_metrics,
|
||||
kvstore = 'device',
|
||||
optimizer = opt,
|
||||
initializer = initializer,
|
||||
arg_params = arg_params,
|
||||
aux_params = aux_params,
|
||||
allow_missing = True,
|
||||
batch_end_callback = _batch_callback,
|
||||
epoch_end_callback = None,
|
||||
)
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser(description='Train face 3d')
|
||||
# general
|
||||
parser.add_argument('--data-dir', default='./data', help='')
|
||||
parser.add_argument('--prefix', default='./models/test',
|
||||
help='directory to save model.')
|
||||
parser.add_argument('--pretrained', default='',
|
||||
help='')
|
||||
parser.add_argument('--lr-steps', default='', type=str, help='')
|
||||
parser.add_argument('--verbose', type=int, default=2000, help='')
|
||||
parser.add_argument('--retrain', action='store_true', default=False,
|
||||
help='true means continue training.')
|
||||
parser.add_argument('--binarize', action='store_true', default=False, help='')
|
||||
parser.add_argument('--end-epoch', type=int, default=87,
|
||||
help='training epoch size.')
|
||||
parser.add_argument('--per-batch-size', type=int, default=20, help='')
|
||||
parser.add_argument('--num-classes', type=int, default=68, help='')
|
||||
parser.add_argument('--input-img-size', type=int, default=128, help='')
|
||||
parser.add_argument('--output-label-size', type=int, default=64, help='')
|
||||
parser.add_argument('--lr', type=float, default=2.5e-4, help='')
|
||||
parser.add_argument('--wd', type=float, default=5e-4, help='')
|
||||
parser.add_argument('--ckpt', type=int, default=1, help='')
|
||||
parser.add_argument('--norm-type', type=str, default='2d', help='')
|
||||
parser.add_argument('--use-coherent', type=int, default=1, help='')
|
||||
parser.add_argument('--use-dla', type=int, default=1, help='')
|
||||
parser.add_argument('--use-N', type=int, default=3, help='')
|
||||
parser.add_argument('--use-DCN', type=int, default=2, help='')
|
||||
args = parser.parse_args()
|
||||
main(args)
|
||||
|
||||
Reference in New Issue
Block a user