refine test for stacked dense unet

This commit is contained in:
Jia Guo
2019-01-08 22:52:28 +08:00
parent ebc02d5391
commit f0bf4e6cc9
5 changed files with 44 additions and 242 deletions

108
alignment/test.py Executable file → Normal file
View File

@@ -1,71 +1,51 @@
import argparse
import cv2
import sys
import numpy as np
import datetime
from alignment import Alignment
sys.path.append('../SSH')
from ssh_detector import SSHDetector
import os
import mxnet as mx
#short_max = 800
scales = [1200, 1600]
t = 2
detector = SSHDetector('../SSH/model/e2ef', 0)
alignment = Alignment('./model/3d_I5', 12)
out_filename = './out.png'
class Handler:
def __init__(self, prefix, epoch, ctx_id=0):
print('loading',prefix, epoch)
ctx = mx.gpu(ctx_id)
sym, arg_params, aux_params = mx.model.load_checkpoint(prefix, epoch)
all_layers = sym.get_internals()
sym = all_layers['heatmap_output']
image_size = (128, 128)
self.image_size = image_size
model = mx.mod.Module(symbol=sym, context=ctx, label_names = None)
#model = mx.mod.Module(symbol=sym, context=ctx)
model.bind(for_training=False, data_shapes=[('data', (1, 3, image_size[0], image_size[1]))])
model.set_params(arg_params, aux_params)
self.model = model
def get(self, img):
rimg = cv2.resize(img, (self.image_size[1], self.image_size[0]))
img = cv2.cvtColor(rimg, cv2.COLOR_BGR2RGB)
img = np.transpose(img, (2,0,1)) #3*112*112, RGB
input_blob = np.zeros( (1, 3, self.image_size[1], self.image_size[0]),dtype=np.uint8 )
input_blob[0] = img
data = mx.nd.array(input_blob)
db = mx.io.DataBatch(data=(data,))
self.model.forward(db, is_train=False)
alabel = self.model.get_outputs()[-1].asnumpy()[0]
ret = np.zeros( (alabel.shape[0], 2), dtype=np.float32)
for i in xrange(alabel.shape[0]):
a = cv2.resize(alabel[i], (self.image_size[1], self.image_size[0]))
ind = np.unravel_index(np.argmax(a, axis=None), a.shape)
#ret[i] = (ind[0], ind[1]) #h, w
ret[i] = (ind[1], ind[0]) #w, h
return ret
ctx_id = 0
img_path = './test.png'
img = cv2.imread(img_path)
handler = Handler('./model/SDU', 1, ctx_id)
landmark = handler.get(img)
#visualize landmark
f = '../sample-images/t1.jpg'
if len(sys.argv)>1:
f = sys.argv[1]
img = cv2.imread(f)
im_shape = img.shape
print(im_shape)
target_size = scales[0]
max_size = scales[1]
im_size_min = np.min(im_shape[0:2])
im_size_max = np.max(im_shape[0:2])
if im_size_min>target_size or im_size_max>max_size:
im_scale = float(target_size) / float(im_size_min)
# prevent bigger axis from being more than max_size:
if np.round(im_scale * im_size_max) > max_size:
im_scale = float(max_size) / float(im_size_max)
img = cv2.resize(img, None, None, fx=im_scale, fy=im_scale)
print('resize to', img.shape)
for i in xrange(t-1): #warmup
faces = detector.detect(img, 0.5)
timea = datetime.datetime.now()
faces = detector.detect(img, 0.5)
timeb = datetime.datetime.now()
diff = timeb - timea
print('detection uses', diff.total_seconds(), 'seconds')
print('find', faces.shape[0], 'faces')
for face in faces:
#print(face)
cv2.rectangle(img, (face[0], face[1]), (face[2], face[3]), (255, 0, 0), 1)
w = face[2] - face[0]
h = face[3] - face[1]
wc = int( (face[2]+face[0])/2 )
hc = int( (face[3]+face[1])/2 )
size = int(max(w, h)*1.3)
scale = 100.0/max(w,h)
M = [
[scale, 0, 64-wc*scale],
[0, scale, 64-hc*scale],
]
M = np.array(M)
IM = cv2.invertAffineTransform(M)
#print(M, IM)
ebox = cv2.warpAffine(img, M, (128, 128))
#ebox = cv2.getRectSubPix(img, (size, size), (wc, hc))
landmark = alignment.get(ebox)
#print(landmark.shape)
for l in range(landmark.shape[0]):
point = np.ones( (3,), dtype=np.float32)
point[0:2] = landmark[l]
point = np.dot(IM, point)
pp = (int(point[0]), int(point[1]))
#print(pp)
cv2.circle(img, (pp[0], pp[1]), 1, (0, 0, 255), 1)
print('write to', out_filename)
cv2.imwrite(out_filename, img)