mirror of
https://gitcode.com/gh_mirrors/fe/FERPlus.git
synced 2025-12-30 05:22:26 +00:00
Add support to region of interest.
This commit is contained in:
@@ -12,6 +12,7 @@ import random as rnd
|
||||
from collections import namedtuple
|
||||
|
||||
from PIL import Image
|
||||
from rect_util import Rect
|
||||
import img_util as imgu
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
@@ -124,7 +125,15 @@ class FERPlusReader(object):
|
||||
targets = np.empty(shape=(current_batch_size, self.emotion_count), dtype=np.float32)
|
||||
for idx in range(self.batch_start, batch_end):
|
||||
index = self.indices[idx]
|
||||
distorted_image = imgu.distort_img(self.data[index][1], self.width, self.height, self.max_shift, self.max_scale, self.max_angle, self.max_skew, self.do_flip)
|
||||
distorted_image = imgu.distort_img(self.data[index][1],
|
||||
self.data[index][3],
|
||||
self.width,
|
||||
self.height,
|
||||
self.max_shift,
|
||||
self.max_scale,
|
||||
self.max_angle,
|
||||
self.max_skew,
|
||||
self.do_flip)
|
||||
final_image = imgu.preproc_img(distorted_image, A=self.A, A_pinv=self.A_pinv)
|
||||
|
||||
inputs[idx-self.batch_start] = final_image
|
||||
@@ -152,14 +161,18 @@ class FERPlusReader(object):
|
||||
image_path = os.path.join(folder_path, row[0])
|
||||
image_data = Image.open(image_path)
|
||||
image_data.load()
|
||||
|
||||
|
||||
# face rectangle
|
||||
box = list(map(int, row[1][1:-1].split(',')))
|
||||
face_rc = Rect(box)
|
||||
|
||||
emotion_raw = list(map(float, row[2:len(row)]))
|
||||
emotion = self._process_data(emotion_raw, mode)
|
||||
idx = np.argmax(emotion)
|
||||
if idx < self.emotion_count: # not unknown or non-face
|
||||
emotion = emotion[:-2]
|
||||
emotion = [float(i)/sum(emotion) for i in emotion]
|
||||
self.data.append((image_path, image_data, emotion))
|
||||
self.data.append((image_path, image_data, emotion, face_rc))
|
||||
self.per_emotion_count[idx] += 1
|
||||
|
||||
self.indices = np.arange(len(self.data))
|
||||
|
||||
@@ -7,6 +7,7 @@ import numpy as np
|
||||
import random as rnd
|
||||
from PIL import Image
|
||||
from scipy import ndimage
|
||||
from rect_util import Rect
|
||||
|
||||
def compute_norm_mat(base_width, base_height):
|
||||
# normalization matrix used in image pre-processing
|
||||
@@ -39,7 +40,7 @@ def preproc_img(img, A, A_pinv):
|
||||
diff = diff/std
|
||||
return diff.reshape(img.shape)
|
||||
|
||||
def distort_img(img, out_width, out_height, max_shift, max_scale, max_angle, max_skew, flip=True):
|
||||
def distort_img(img, roi, out_width, out_height, max_shift, max_scale, max_angle, max_skew, flip=True):
|
||||
shift_y = out_height*max_shift*rnd.uniform(-1.0,1.0)
|
||||
shift_x = out_width*max_shift*rnd.uniform(-1.0,1.0)
|
||||
|
||||
@@ -57,21 +58,19 @@ def distort_img(img, out_width, out_height, max_shift, max_scale, max_angle, max
|
||||
scale_x = rnd.uniform(1.0, max_scale)
|
||||
if rnd.choice([True, False]):
|
||||
scale_x = 1.0/scale_x
|
||||
T_im = crop_img(img, out_width, out_height, shift_x, shift_y, scale_x, scale_y, angle, sk_x, sk_y)
|
||||
T_im = crop_img(img, roi, out_width, out_height, shift_x, shift_y, scale_x, scale_y, angle, sk_x, sk_y)
|
||||
if flip and rnd.choice([True, False]):
|
||||
T_im = np.fliplr(T_im)
|
||||
return T_im
|
||||
|
||||
def crop_img(img, crop_width, crop_height, shift_x, shift_y, scale_x, scale_y, angle, skew_x, skew_y):
|
||||
width, height = img.size
|
||||
center_x = width/2.0
|
||||
center_y = height/2.0
|
||||
ctr_in = np.array((center_y, center_x))
|
||||
def crop_img(img, roi, crop_width, crop_height, shift_x, shift_y, scale_x, scale_y, angle, skew_x, skew_y):
|
||||
# current face center
|
||||
ctr_in = np.array((roi.center().y, roi.center().x))
|
||||
ctr_out = np.array((crop_height/2.0+shift_y, crop_width/2.0+shift_x))
|
||||
out_shape = (crop_height, crop_width)
|
||||
s_y = scale_y*(height-1)*1.0/(crop_height-1)
|
||||
s_x = scale_x*(width-1)*1.0/(crop_width-1)
|
||||
|
||||
s_y = scale_y*(roi.height()-1)*1.0/(crop_height-1)
|
||||
s_x = scale_x*(roi.width()-1)*1.0/(crop_width-1)
|
||||
|
||||
# rotation and scale
|
||||
ang = angle*np.pi/180.0
|
||||
transform = np.array([[np.cos(ang), -np.sin(ang)], [np.sin(ang), np.cos(ang)]])
|
||||
|
||||
222
src/rect_util.py
Normal file
222
src/rect_util.py
Normal file
@@ -0,0 +1,222 @@
|
||||
#
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
# Licensed under the MIT license. See LICENSE.md file in the project root for full license information.
|
||||
#
|
||||
|
||||
import math
|
||||
|
||||
class Point(object):
|
||||
def __init__(self, x=0.0, y=0.0):
|
||||
self.x = x
|
||||
self.y = y
|
||||
|
||||
def __add__(self, p):
|
||||
"""Point(x1+x2, y1+y2)"""
|
||||
return Point(self.x+p.x, self.y+p.y)
|
||||
|
||||
def __sub__(self, p):
|
||||
"""Point(x1-x2, y1-y2)"""
|
||||
return Point(self.x-p.x, self.y-p.y)
|
||||
|
||||
def __mul__( self, scalar ):
|
||||
"""Point(x1*x2, y1*y2)"""
|
||||
return Point(self.x*scalar, self.y*scalar)
|
||||
|
||||
def __div__(self, scalar):
|
||||
"""Point(x1/x2, y1/y2)"""
|
||||
return Point(self.x/scalar, self.y/scalar)
|
||||
|
||||
def __str__(self):
|
||||
return "(%s, %s)" % (self.x, self.y)
|
||||
|
||||
def length(self):
|
||||
return math.sqrt(self.x**2 + self.y**2)
|
||||
|
||||
def distance_to(self, p):
|
||||
"""Calculate the distance between two points."""
|
||||
return (self - p).length()
|
||||
|
||||
def as_tuple(self):
|
||||
"""(x, y)"""
|
||||
return (self.x, self.y)
|
||||
|
||||
def clone(self):
|
||||
"""Return a full copy of this point."""
|
||||
return Point(self.x, self.y)
|
||||
|
||||
def integerize(self):
|
||||
"""Convert co-ordinate values to integers."""
|
||||
self.x = int(self.x+0.5)
|
||||
self.y = int(self.y+0.5)
|
||||
|
||||
def floatize(self):
|
||||
"""Convert co-ordinate values to floats."""
|
||||
self.x = float(self.x)
|
||||
self.y = float(self.y)
|
||||
|
||||
def reset(self, x, y):
|
||||
"""Reset x & y coordinates."""
|
||||
self.x = x
|
||||
self.y = y
|
||||
|
||||
def shift(self, pt):
|
||||
"""Move to new (x+pt.x,y+pt.y)."""
|
||||
self.x = self.x + pt.x
|
||||
self.y = self.y + pt.y
|
||||
|
||||
def shift_xy(self, dx, dy):
|
||||
"""Move to new (x+dx,y+dy)."""
|
||||
self.x = self.x + dx
|
||||
self.y = self.y + dy
|
||||
|
||||
def rotate(self, rad):
|
||||
"""Rotate counter-clockwise by rad radians.
|
||||
Positive y goes *up,* as in traditional mathematics.
|
||||
The new position is returned as a new Point.
|
||||
"""
|
||||
s, c = [f(rad) for f in (math.sin, math.cos)]
|
||||
x, y = (c*self.x - s*self.y, s*self.x + c*self.y)
|
||||
return Point(x,y)
|
||||
|
||||
def rotate_about(self, p, theta):
|
||||
"""Rotate counter-clockwise around a point, by theta degrees.
|
||||
Positive y goes *up,* as in traditional mathematics.
|
||||
The new position is returned as a new Point.
|
||||
"""
|
||||
result = self.clone()
|
||||
result.shift(-p.x, -p.y)
|
||||
result.rotate(theta)
|
||||
result.shift(p.x, p.y)
|
||||
return result
|
||||
|
||||
class Rect(object):
|
||||
"""The rectangle stores left, top, right, and bottom values.
|
||||
Coordinates are based on screen coordinates.
|
||||
origin top
|
||||
+-----> x increases |
|
||||
| left -+- right
|
||||
v |
|
||||
y increases bottom
|
||||
"""
|
||||
|
||||
def __init__(self, box):
|
||||
"""Initialize a rectangle from two points."""
|
||||
self.left = box[0]
|
||||
self.top = box[1]
|
||||
self.right = box[2]
|
||||
self.bottom = box[3]
|
||||
|
||||
def as_tuple(self):
|
||||
"""(left, top, right, bottom)"""
|
||||
return (self.left, self.top, self.right, self.bottom)
|
||||
|
||||
def width(self):
|
||||
"""Width"""
|
||||
return (self.right - self.left)
|
||||
|
||||
def height(self):
|
||||
"""Height"""
|
||||
return (self.bottom - self.top)
|
||||
|
||||
def contains(self, pt):
|
||||
"""Return true if a point is inside the rectangle."""
|
||||
x,y = pt.as_tuple()
|
||||
return (self.left <= x <= self.right and
|
||||
self.top <= y <= self.bottom)
|
||||
|
||||
def shift(self, pt):
|
||||
"""Shift by pt.x and pt.y."""
|
||||
self.left = self.left + pt.x
|
||||
self.right = self.right + pt.x
|
||||
self.top = self.top + pt.y
|
||||
self.bottom = self.bottom + pt.y
|
||||
|
||||
def shift_xy(self, dx, dy):
|
||||
"""Shift by dx and dy."""
|
||||
self.left = self.left + dx
|
||||
self.right = self.right + dx
|
||||
self.top = self.top + dy
|
||||
self.bottom = self.bottom + dy
|
||||
|
||||
def equal(self, other):
|
||||
"""Return true if a rectangle is identical to this rectangle."""
|
||||
return (self.right == other.left and self.left == other.right and
|
||||
self.top == other.bottom and self.bottom == other.top)
|
||||
|
||||
def overlaps(self, other):
|
||||
"""Return true if a rectangle overlaps this rectangle."""
|
||||
return (self.right > other.left and self.left < other.right and
|
||||
self.top < other.bottom and self.bottom > other.top)
|
||||
|
||||
def intersect(self, other):
|
||||
"""Return the intersect rectangle.
|
||||
Note we don't check here whether the intersection is valid
|
||||
If needed, call overlaps() first to check
|
||||
"""
|
||||
return Rect((max(self.left, other.left),
|
||||
max(self.top, other.top),
|
||||
min(self.right, other.right),
|
||||
min(self.bottom, other.bottom)))
|
||||
|
||||
def clamp(self, xmin, ymin, xmax, ymax):
|
||||
"""Return clamped rectangle based on the other rectangle.
|
||||
Note we don't check here whether the output is valid
|
||||
If needed, call overlaps() first to check
|
||||
"""
|
||||
self.left = max(self.left, xmin)
|
||||
self.right = min(self.right, xmax)
|
||||
self.top = max(self.top, ymin)
|
||||
self.bottom = min(self.bottom, ymax)
|
||||
|
||||
def top_left(self):
|
||||
"""Return the top-left corner as a Point."""
|
||||
return Point(self.left, self.top)
|
||||
|
||||
def bottom_right(self):
|
||||
"""Return the bottom-right corner as a Point."""
|
||||
return Point(self.right, self.bottom)
|
||||
|
||||
def center(self):
|
||||
"""Return the center as a Point."""
|
||||
return Point((self.left+self.right)/2.0, (self.top+self.bottom)/2.0)
|
||||
|
||||
def mult(self, xmul, ymul):
|
||||
"""Return a rectangle with all coordinates multipled by a number."""
|
||||
return Rect((self.left*xmul, self.top*ymul, self.right*xmul, self.bottom*ymul))
|
||||
|
||||
def scale(self, scale):
|
||||
"""Return a scaled rectangle with identical center."""
|
||||
xctr = (self.left + self.right)/2.0
|
||||
yctr = (self.top + self.bottom)/2.0
|
||||
width = self.width()*scale
|
||||
height = self.height()*scale
|
||||
xstart = xctr-width/2.0
|
||||
ystart = yctr-height/2.0
|
||||
return Rect((xstart, ystart, xstart+width, ystart+height))
|
||||
|
||||
def cocenter(self, new_width, new_height):
|
||||
"""Return a new rectangle with identical center."""
|
||||
xctr = (self.left + self.right)/2.0
|
||||
yctr = (self.top + self.bottom)/2.0
|
||||
xstart = xctr - new_width/2.0
|
||||
ystart = yctr - new_height/2.0
|
||||
return Rect((xstart, ystart, xstart+new_width, ystart+new_height))
|
||||
|
||||
def integerize(self):
|
||||
"""Convert co-ordinate values to integers."""
|
||||
self.left = int(self.left+0.5)
|
||||
self.right = int(self.right+0.5)
|
||||
self.top = int(self.top+0.5)
|
||||
self.bottom = int(self.bottom+0.5)
|
||||
|
||||
def floatize(self):
|
||||
"""Convert co-ordinate values to floats."""
|
||||
self.left = float(self.left)
|
||||
self.right = float(self.right)
|
||||
self.top = float(self.top)
|
||||
self.bottom = float(self.bottom)
|
||||
|
||||
def __str__( self ):
|
||||
return "<Rect (%s,%s)-(%s,%s)>" % (self.left,self.top,
|
||||
self.right,self.bottom)
|
||||
|
||||
@@ -29,6 +29,7 @@ emotion_table = {'neutral' : 0,
|
||||
'fear' : 6,
|
||||
'contempt' : 7}
|
||||
|
||||
# List of folders for training, validation and test.
|
||||
train_folders = ['FER2013Train']
|
||||
valid_folders = ['FER2013Valid']
|
||||
test_folders = ['FER2013Test']
|
||||
|
||||
Reference in New Issue
Block a user