Add support to region of interest.

This commit is contained in:
Emad Barsoum
2017-02-20 00:23:39 -08:00
parent b3290259a9
commit 888f98d630
4 changed files with 248 additions and 13 deletions

View File

@@ -12,6 +12,7 @@ import random as rnd
from collections import namedtuple
from PIL import Image
from rect_util import Rect
import img_util as imgu
import matplotlib.pyplot as plt
@@ -124,7 +125,15 @@ class FERPlusReader(object):
targets = np.empty(shape=(current_batch_size, self.emotion_count), dtype=np.float32)
for idx in range(self.batch_start, batch_end):
index = self.indices[idx]
distorted_image = imgu.distort_img(self.data[index][1], self.width, self.height, self.max_shift, self.max_scale, self.max_angle, self.max_skew, self.do_flip)
distorted_image = imgu.distort_img(self.data[index][1],
self.data[index][3],
self.width,
self.height,
self.max_shift,
self.max_scale,
self.max_angle,
self.max_skew,
self.do_flip)
final_image = imgu.preproc_img(distorted_image, A=self.A, A_pinv=self.A_pinv)
inputs[idx-self.batch_start] = final_image
@@ -152,14 +161,18 @@ class FERPlusReader(object):
image_path = os.path.join(folder_path, row[0])
image_data = Image.open(image_path)
image_data.load()
# face rectangle
box = list(map(int, row[1][1:-1].split(',')))
face_rc = Rect(box)
emotion_raw = list(map(float, row[2:len(row)]))
emotion = self._process_data(emotion_raw, mode)
idx = np.argmax(emotion)
if idx < self.emotion_count: # not unknown or non-face
emotion = emotion[:-2]
emotion = [float(i)/sum(emotion) for i in emotion]
self.data.append((image_path, image_data, emotion))
self.data.append((image_path, image_data, emotion, face_rc))
self.per_emotion_count[idx] += 1
self.indices = np.arange(len(self.data))

View File

@@ -7,6 +7,7 @@ import numpy as np
import random as rnd
from PIL import Image
from scipy import ndimage
from rect_util import Rect
def compute_norm_mat(base_width, base_height):
# normalization matrix used in image pre-processing
@@ -39,7 +40,7 @@ def preproc_img(img, A, A_pinv):
diff = diff/std
return diff.reshape(img.shape)
def distort_img(img, out_width, out_height, max_shift, max_scale, max_angle, max_skew, flip=True):
def distort_img(img, roi, out_width, out_height, max_shift, max_scale, max_angle, max_skew, flip=True):
shift_y = out_height*max_shift*rnd.uniform(-1.0,1.0)
shift_x = out_width*max_shift*rnd.uniform(-1.0,1.0)
@@ -57,21 +58,19 @@ def distort_img(img, out_width, out_height, max_shift, max_scale, max_angle, max
scale_x = rnd.uniform(1.0, max_scale)
if rnd.choice([True, False]):
scale_x = 1.0/scale_x
T_im = crop_img(img, out_width, out_height, shift_x, shift_y, scale_x, scale_y, angle, sk_x, sk_y)
T_im = crop_img(img, roi, out_width, out_height, shift_x, shift_y, scale_x, scale_y, angle, sk_x, sk_y)
if flip and rnd.choice([True, False]):
T_im = np.fliplr(T_im)
return T_im
def crop_img(img, crop_width, crop_height, shift_x, shift_y, scale_x, scale_y, angle, skew_x, skew_y):
width, height = img.size
center_x = width/2.0
center_y = height/2.0
ctr_in = np.array((center_y, center_x))
def crop_img(img, roi, crop_width, crop_height, shift_x, shift_y, scale_x, scale_y, angle, skew_x, skew_y):
# current face center
ctr_in = np.array((roi.center().y, roi.center().x))
ctr_out = np.array((crop_height/2.0+shift_y, crop_width/2.0+shift_x))
out_shape = (crop_height, crop_width)
s_y = scale_y*(height-1)*1.0/(crop_height-1)
s_x = scale_x*(width-1)*1.0/(crop_width-1)
s_y = scale_y*(roi.height()-1)*1.0/(crop_height-1)
s_x = scale_x*(roi.width()-1)*1.0/(crop_width-1)
# rotation and scale
ang = angle*np.pi/180.0
transform = np.array([[np.cos(ang), -np.sin(ang)], [np.sin(ang), np.cos(ang)]])

222
src/rect_util.py Normal file
View File

@@ -0,0 +1,222 @@
#
# Copyright (c) Microsoft. All rights reserved.
# Licensed under the MIT license. See LICENSE.md file in the project root for full license information.
#
import math
class Point(object):
def __init__(self, x=0.0, y=0.0):
self.x = x
self.y = y
def __add__(self, p):
"""Point(x1+x2, y1+y2)"""
return Point(self.x+p.x, self.y+p.y)
def __sub__(self, p):
"""Point(x1-x2, y1-y2)"""
return Point(self.x-p.x, self.y-p.y)
def __mul__( self, scalar ):
"""Point(x1*x2, y1*y2)"""
return Point(self.x*scalar, self.y*scalar)
def __div__(self, scalar):
"""Point(x1/x2, y1/y2)"""
return Point(self.x/scalar, self.y/scalar)
def __str__(self):
return "(%s, %s)" % (self.x, self.y)
def length(self):
return math.sqrt(self.x**2 + self.y**2)
def distance_to(self, p):
"""Calculate the distance between two points."""
return (self - p).length()
def as_tuple(self):
"""(x, y)"""
return (self.x, self.y)
def clone(self):
"""Return a full copy of this point."""
return Point(self.x, self.y)
def integerize(self):
"""Convert co-ordinate values to integers."""
self.x = int(self.x+0.5)
self.y = int(self.y+0.5)
def floatize(self):
"""Convert co-ordinate values to floats."""
self.x = float(self.x)
self.y = float(self.y)
def reset(self, x, y):
"""Reset x & y coordinates."""
self.x = x
self.y = y
def shift(self, pt):
"""Move to new (x+pt.x,y+pt.y)."""
self.x = self.x + pt.x
self.y = self.y + pt.y
def shift_xy(self, dx, dy):
"""Move to new (x+dx,y+dy)."""
self.x = self.x + dx
self.y = self.y + dy
def rotate(self, rad):
"""Rotate counter-clockwise by rad radians.
Positive y goes *up,* as in traditional mathematics.
The new position is returned as a new Point.
"""
s, c = [f(rad) for f in (math.sin, math.cos)]
x, y = (c*self.x - s*self.y, s*self.x + c*self.y)
return Point(x,y)
def rotate_about(self, p, theta):
"""Rotate counter-clockwise around a point, by theta degrees.
Positive y goes *up,* as in traditional mathematics.
The new position is returned as a new Point.
"""
result = self.clone()
result.shift(-p.x, -p.y)
result.rotate(theta)
result.shift(p.x, p.y)
return result
class Rect(object):
"""The rectangle stores left, top, right, and bottom values.
Coordinates are based on screen coordinates.
origin top
+-----> x increases |
| left -+- right
v |
y increases bottom
"""
def __init__(self, box):
"""Initialize a rectangle from two points."""
self.left = box[0]
self.top = box[1]
self.right = box[2]
self.bottom = box[3]
def as_tuple(self):
"""(left, top, right, bottom)"""
return (self.left, self.top, self.right, self.bottom)
def width(self):
"""Width"""
return (self.right - self.left)
def height(self):
"""Height"""
return (self.bottom - self.top)
def contains(self, pt):
"""Return true if a point is inside the rectangle."""
x,y = pt.as_tuple()
return (self.left <= x <= self.right and
self.top <= y <= self.bottom)
def shift(self, pt):
"""Shift by pt.x and pt.y."""
self.left = self.left + pt.x
self.right = self.right + pt.x
self.top = self.top + pt.y
self.bottom = self.bottom + pt.y
def shift_xy(self, dx, dy):
"""Shift by dx and dy."""
self.left = self.left + dx
self.right = self.right + dx
self.top = self.top + dy
self.bottom = self.bottom + dy
def equal(self, other):
"""Return true if a rectangle is identical to this rectangle."""
return (self.right == other.left and self.left == other.right and
self.top == other.bottom and self.bottom == other.top)
def overlaps(self, other):
"""Return true if a rectangle overlaps this rectangle."""
return (self.right > other.left and self.left < other.right and
self.top < other.bottom and self.bottom > other.top)
def intersect(self, other):
"""Return the intersect rectangle.
Note we don't check here whether the intersection is valid
If needed, call overlaps() first to check
"""
return Rect((max(self.left, other.left),
max(self.top, other.top),
min(self.right, other.right),
min(self.bottom, other.bottom)))
def clamp(self, xmin, ymin, xmax, ymax):
"""Return clamped rectangle based on the other rectangle.
Note we don't check here whether the output is valid
If needed, call overlaps() first to check
"""
self.left = max(self.left, xmin)
self.right = min(self.right, xmax)
self.top = max(self.top, ymin)
self.bottom = min(self.bottom, ymax)
def top_left(self):
"""Return the top-left corner as a Point."""
return Point(self.left, self.top)
def bottom_right(self):
"""Return the bottom-right corner as a Point."""
return Point(self.right, self.bottom)
def center(self):
"""Return the center as a Point."""
return Point((self.left+self.right)/2.0, (self.top+self.bottom)/2.0)
def mult(self, xmul, ymul):
"""Return a rectangle with all coordinates multipled by a number."""
return Rect((self.left*xmul, self.top*ymul, self.right*xmul, self.bottom*ymul))
def scale(self, scale):
"""Return a scaled rectangle with identical center."""
xctr = (self.left + self.right)/2.0
yctr = (self.top + self.bottom)/2.0
width = self.width()*scale
height = self.height()*scale
xstart = xctr-width/2.0
ystart = yctr-height/2.0
return Rect((xstart, ystart, xstart+width, ystart+height))
def cocenter(self, new_width, new_height):
"""Return a new rectangle with identical center."""
xctr = (self.left + self.right)/2.0
yctr = (self.top + self.bottom)/2.0
xstart = xctr - new_width/2.0
ystart = yctr - new_height/2.0
return Rect((xstart, ystart, xstart+new_width, ystart+new_height))
def integerize(self):
"""Convert co-ordinate values to integers."""
self.left = int(self.left+0.5)
self.right = int(self.right+0.5)
self.top = int(self.top+0.5)
self.bottom = int(self.bottom+0.5)
def floatize(self):
"""Convert co-ordinate values to floats."""
self.left = float(self.left)
self.right = float(self.right)
self.top = float(self.top)
self.bottom = float(self.bottom)
def __str__( self ):
return "<Rect (%s,%s)-(%s,%s)>" % (self.left,self.top,
self.right,self.bottom)

View File

@@ -29,6 +29,7 @@ emotion_table = {'neutral' : 0,
'fear' : 6,
'contempt' : 7}
# List of folders for training, validation and test.
train_folders = ['FER2013Train']
valid_folders = ['FER2013Valid']
test_folders = ['FER2013Test']