From 888f98d630778d0b2bd98f2d871854ce7677a708 Mon Sep 17 00:00:00 2001
From: Emad Barsoum
Date: Mon, 20 Feb 2017 00:23:39 -0800
Subject: [PATCH] Add support to region of interest.

---
 src/ferplus.py   |  19 +++-
 src/img_util.py  |  19 ++--
 src/rect_util.py | 222 +++++++++++++++++++++++++++++++++++++++++++++++
 src/train.py     |   1 +
 4 files changed, 248 insertions(+), 13 deletions(-)
 create mode 100644 src/rect_util.py

diff --git a/src/ferplus.py b/src/ferplus.py
index 3f45ba2..522253e 100644
--- a/src/ferplus.py
+++ b/src/ferplus.py
@@ -12,6 +12,7 @@ import random as rnd
 from collections import namedtuple
 
 from PIL import Image
+from rect_util import Rect
 import img_util as imgu
 import matplotlib.pyplot as plt
 
@@ -124,7 +125,15 @@ class FERPlusReader(object):
         targets = np.empty(shape=(current_batch_size, self.emotion_count), dtype=np.float32)
         for idx in range(self.batch_start, batch_end):
             index = self.indices[idx]
-            distorted_image = imgu.distort_img(self.data[index][1], self.width, self.height, self.max_shift, self.max_scale, self.max_angle, self.max_skew, self.do_flip)
+            distorted_image = imgu.distort_img(self.data[index][1],
+                                               self.data[index][3],
+                                               self.width,
+                                               self.height,
+                                               self.max_shift,
+                                               self.max_scale,
+                                               self.max_angle,
+                                               self.max_skew,
+                                               self.do_flip)
             final_image = imgu.preproc_img(distorted_image, A=self.A, A_pinv=self.A_pinv)
 
             inputs[idx-self.batch_start] = final_image
@@ -152,14 +161,18 @@ class FERPlusReader(object):
                     image_path = os.path.join(folder_path, row[0])
                     image_data = Image.open(image_path)
                     image_data.load()
-                    
+
+                    # face rectangle
+                    box = list(map(int, row[1][1:-1].split(',')))
+                    face_rc = Rect(box)
+
                     emotion_raw = list(map(float, row[2:len(row)]))
                     emotion = self._process_data(emotion_raw, mode)
                     idx = np.argmax(emotion)
                     if idx < self.emotion_count: # not unknown or non-face
                         emotion = emotion[:-2]
                         emotion = [float(i)/sum(emotion) for i in emotion]
-                        self.data.append((image_path, image_data, emotion))
+                        self.data.append((image_path, image_data, emotion, face_rc))
                         self.per_emotion_count[idx] += 1
 
         self.indices = np.arange(len(self.data))
diff --git a/src/img_util.py b/src/img_util.py
index b89ace2..0dfdf4b 100644
--- a/src/img_util.py
+++ b/src/img_util.py
@@ -7,6 +7,7 @@ import numpy as np
 import random as rnd
 from PIL import Image
 from scipy import ndimage
+from rect_util import Rect
 
 def compute_norm_mat(base_width, base_height):
     # normalization matrix used in image pre-processing
@@ -39,7 +40,7 @@ def preproc_img(img, A, A_pinv):
     diff = diff/std
     return diff.reshape(img.shape)
 
-def distort_img(img, out_width, out_height, max_shift, max_scale, max_angle, max_skew, flip=True):
+def distort_img(img, roi, out_width, out_height, max_shift, max_scale, max_angle, max_skew, flip=True):
     shift_y = out_height*max_shift*rnd.uniform(-1.0,1.0)
     shift_x = out_width*max_shift*rnd.uniform(-1.0,1.0)
 
@@ -57,21 +58,19 @@ def distort_img(img, out_width, out_height, max_shift, max_scale, max_angle, max
     scale_x = rnd.uniform(1.0, max_scale)
     if rnd.choice([True, False]):
         scale_x = 1.0/scale_x
-    T_im = crop_img(img, out_width, out_height, shift_x, shift_y, scale_x, scale_y, angle, sk_x, sk_y)
+    T_im = crop_img(img, roi, out_width, out_height, shift_x, shift_y, scale_x, scale_y, angle, sk_x, sk_y)
     if flip and rnd.choice([True, False]):
         T_im = np.fliplr(T_im)
     return T_im
 
-def crop_img(img, crop_width, crop_height, shift_x, shift_y, scale_x, scale_y, angle, skew_x, skew_y):
-    width, height = img.size
-    center_x = width/2.0
-    center_y = height/2.0
-    ctr_in = np.array((center_y, center_x))
+def crop_img(img, roi, crop_width, crop_height, shift_x, shift_y, scale_x, scale_y, angle, skew_x, skew_y):
+    # current face center
+    ctr_in = np.array((roi.center().y, roi.center().x))
     ctr_out = np.array((crop_height/2.0+shift_y, crop_width/2.0+shift_x))
     out_shape = (crop_height, crop_width)
-    s_y = scale_y*(height-1)*1.0/(crop_height-1)
-    s_x = scale_x*(width-1)*1.0/(crop_width-1)
-    
+    s_y = scale_y*(roi.height()-1)*1.0/(crop_height-1)
+    s_x = scale_x*(roi.width()-1)*1.0/(crop_width-1)
+
     # rotation and scale
     ang = angle*np.pi/180.0
     transform = np.array([[np.cos(ang), -np.sin(ang)], [np.sin(ang), np.cos(ang)]])
diff --git a/src/rect_util.py b/src/rect_util.py
new file mode 100644
index 0000000..f41a381
--- /dev/null
+++ b/src/rect_util.py
@@ -0,0 +1,222 @@
+#
+# Copyright (c) Microsoft. All rights reserved.
+# Licensed under the MIT license. See LICENSE.md file in the project root for full license information.
+#
+
+import math
+
+class Point(object):
+    def __init__(self, x=0.0, y=0.0):
+        self.x = x
+        self.y = y
+
+    def __add__(self, p):
+        """Point(x1+x2, y1+y2)"""
+        return Point(self.x+p.x, self.y+p.y)
+
+    def __sub__(self, p):
+        """Point(x1-x2, y1-y2)"""
+        return Point(self.x-p.x, self.y-p.y)
+
+    def __mul__( self, scalar ):
+        """Point(x*scalar, y*scalar)"""
+        return Point(self.x*scalar, self.y*scalar)
+
+    def __div__(self, scalar):
+        """Point(x/scalar, y/scalar)"""
+        return Point(self.x/scalar, self.y/scalar)
+    __truediv__ = __div__  # Python 3 looks up __truediv__ for the "/" operator
+
+    def __str__(self):
+        return "(%s, %s)" % (self.x, self.y)
+
+    def length(self):
+        return math.sqrt(self.x**2 + self.y**2)
+
+    def distance_to(self, p):
+        """Calculate the distance between two points."""
+        return (self - p).length()
+
+    def as_tuple(self):
+        """(x, y)"""
+        return (self.x, self.y)
+
+    def clone(self):
+        """Return a full copy of this point."""
+        return Point(self.x, self.y)
+
+    def integerize(self):
+        """Convert co-ordinate values to integers."""
+        self.x = int(self.x+0.5)
+        self.y = int(self.y+0.5)
+
+    def floatize(self):
+        """Convert co-ordinate values to floats."""
+        self.x = float(self.x)
+        self.y = float(self.y)
+
+    def reset(self, x, y):
+        """Reset x & y coordinates."""
+        self.x = x
+        self.y = y
+
+    def shift(self, pt):
+        """Move to new (x+pt.x,y+pt.y)."""
+        self.x = self.x + pt.x
+        self.y = self.y + pt.y
+
+    def shift_xy(self, dx, dy):
+        """Move to new (x+dx,y+dy)."""
+        self.x = self.x + dx
+        self.y = self.y + dy
+
+    def rotate(self, rad):
+        """Rotate counter-clockwise by rad radians.
+        Positive y goes *up,* as in traditional mathematics.
+        The new position is returned as a new Point.
+        """
+        s, c = [f(rad) for f in (math.sin, math.cos)]
+        x, y = (c*self.x - s*self.y, s*self.x + c*self.y)
+        return Point(x,y)
+
+    def rotate_about(self, p, theta):
+        """Rotate counter-clockwise around a point, by theta radians.
+        Positive y goes *up,* as in traditional mathematics.
+        The new position is returned as a new Point.
+        """
+        result = self.clone()
+        result.shift_xy(-p.x, -p.y)  # shift() takes a Point; shift_xy() takes dx, dy
+        result = result.rotate(theta)  # rotate() returns a new Point, it does not mutate
+        result.shift_xy(p.x, p.y)
+        return result
+
+class Rect(object):
+    """The rectangle stores left, top, right, and bottom values.
+    Coordinates are based on screen coordinates.
+    origin                               top
+       +-----> x increases                |
+       |                           left  -+-  right
+       v                                  |
+    y increases                         bottom
+    """
+
+    def __init__(self, box):
+        """Initialize a rectangle from a (left, top, right, bottom) sequence."""
+        self.left = box[0]
+        self.top = box[1]
+        self.right = box[2]
+        self.bottom = box[3]
+
+    def as_tuple(self):
+        """(left, top, right, bottom)"""
+        return (self.left, self.top, self.right, self.bottom)
+
+    def width(self):
+        """Width"""
+        return (self.right - self.left)
+
+    def height(self):
+        """Height"""
+        return (self.bottom - self.top)
+
+    def contains(self, pt):
+        """Return true if a point is inside the rectangle."""
+        x,y = pt.as_tuple()
+        return (self.left <= x <= self.right and
+                self.top <= y <= self.bottom)
+
+    def shift(self, pt):
+        """Shift by pt.x and pt.y."""
+        self.left = self.left + pt.x
+        self.right = self.right + pt.x
+        self.top = self.top + pt.y
+        self.bottom = self.bottom + pt.y
+
+    def shift_xy(self, dx, dy):
+        """Shift by dx and dy."""
+        self.left = self.left + dx
+        self.right = self.right + dx
+        self.top = self.top + dy
+        self.bottom = self.bottom + dy
+
+    def equal(self, other):
+        """Return true if a rectangle is identical to this rectangle."""
+        return (self.left == other.left and self.right == other.right and
+                self.top == other.top and self.bottom == other.bottom)
+
+    def overlaps(self, other):
+        """Return true if a rectangle overlaps this rectangle."""
+        return (self.right > other.left and self.left < other.right and
+                self.top < other.bottom and self.bottom > other.top)
+
+    def intersect(self, other):
+        """Return the intersect rectangle.
+        Note we don't check here whether the intersection is valid
+        If needed, call overlaps() first to check
+        """
+        return Rect((max(self.left, other.left),
+                     max(self.top, other.top),
+                     min(self.right, other.right),
+                     min(self.bottom, other.bottom)))
+
+    def clamp(self, xmin, ymin, xmax, ymax):
+        """Clamp this rectangle in place to the given bounds (returns None).
+        Note we don't check here whether the output is valid
+        If needed, call overlaps() first to check
+        """
+        self.left = max(self.left, xmin)
+        self.right = min(self.right, xmax)
+        self.top = max(self.top, ymin)
+        self.bottom = min(self.bottom, ymax)
+
+    def top_left(self):
+        """Return the top-left corner as a Point."""
+        return Point(self.left, self.top)
+
+    def bottom_right(self):
+        """Return the bottom-right corner as a Point."""
+        return Point(self.right, self.bottom)
+
+    def center(self):
+        """Return the center as a Point."""
+        return Point((self.left+self.right)/2.0, (self.top+self.bottom)/2.0)
+
+    def mult(self, xmul, ymul):
+        """Return a rectangle with all coordinates multiplied by a number."""
+        return Rect((self.left*xmul, self.top*ymul, self.right*xmul, self.bottom*ymul))
+
+    def scale(self, scale):
+        """Return a scaled rectangle with identical center."""
+        xctr = (self.left + self.right)/2.0
+        yctr = (self.top + self.bottom)/2.0
+        width = self.width()*scale
+        height = self.height()*scale
+        xstart = xctr-width/2.0
+        ystart = yctr-height/2.0
+        return Rect((xstart, ystart, xstart+width, ystart+height))
+
+    def cocenter(self, new_width, new_height):
+        """Return a new rectangle with identical center."""
+        xctr = (self.left + self.right)/2.0
+        yctr = (self.top + self.bottom)/2.0
+        xstart = xctr - new_width/2.0
+        ystart = yctr - new_height/2.0
+        return Rect((xstart, ystart, xstart+new_width, ystart+new_height))
+
+    def integerize(self):
+        """Convert co-ordinate values to integers."""
+        self.left = int(self.left+0.5)
+        self.right = int(self.right+0.5)
+        self.top = int(self.top+0.5)
+        self.bottom = int(self.bottom+0.5)
+
+    def floatize(self):
+        """Convert co-ordinate values to floats."""
+        self.left = float(self.left)
+        self.right = float(self.right)
+        self.top = float(self.top)
+        self.bottom = float(self.bottom)
+
+    def __str__( self ):
+        return "<Rect (%s,%s)-(%s,%s)>" % (self.left,self.top,
+                                           self.right,self.bottom)
diff --git a/src/train.py b/src/train.py
index f53d0d8..d18dd18 100644
--- a/src/train.py
+++ b/src/train.py
@@ -29,6 +29,7 @@ emotion_table = {'neutral'  : 0,
                  'fear'     : 6,
                  'contempt' : 7}
 
+# List of folders for training, validation and test.
 train_folders = ['FER2013Train']
 valid_folders = ['FER2013Valid']
 test_folders  = ['FER2013Test']