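"""Data iterators for face recognition training with MXNet.

``FaceImageIter`` streams face images and identity labels from an indexed
RecordIO file (``.rec``/``.idx``) with optional shuffling, random mirroring,
color jitter, JPEG-compression augmentation and random cutoff erasing.
``FaceImageIterList`` draws batches from a list of such iterators, and
``DummyIter`` returns a fixed all-zero batch, e.g. for throughput testing.
"""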
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import random
import logging
import sys
import numbers
import math
import datetime
import numpy as np
import cv2

import mxnet as mx
from mxnet import ndarray as nd
from mxnet import io
from mxnet import recordio

logger = logging.getLogger()


class FaceImageIter(io.DataIter):
    def __init__(self,
                 batch_size,
                 data_shape,
                 path_imgrec=None,
                 shuffle=False,
                 aug_list=None,
                 mean=None,
                 rand_mirror=False,
                 cutoff=0,
                 color_jittering=0,
                 images_filter=0,
                 data_name='data',
                 label_name='softmax_label',
                 context=0,
                 context_num=1,
                 **kwargs):
        super(FaceImageIter, self).__init__()
        assert path_imgrec
        self.context = context
        self.context_num = context_num
        if path_imgrec:
            logging.info('loading recordio %s...', path_imgrec)
            path_imgidx = path_imgrec[0:-4] + ".idx"
            self.imgrec = recordio.MXIndexedRecordIO(path_imgidx, path_imgrec,
                                                     'r')
            s = self.imgrec.read_idx(0)
            header, _ = recordio.unpack(s)
            if header.flag > 0:
                # record 0 stores the [start, end) range of identity headers
                self.header0 = (int(header.label[0]), int(header.label[1]))
                self.imgidx = []
                self.id2range = {}
                self.seq_identity = range(int(header.label[0]),
                                          int(header.label[1]))
                for identity in self.seq_identity:
                    s = self.imgrec.read_idx(identity)
                    header, _ = recordio.unpack(s)
                    a, b = int(header.label[0]), int(header.label[1])
                    count = b - a
                    if count < images_filter:
                        # drop identities with too few images
                        continue
                    self.id2range[identity] = (a, b)
                    self.imgidx += range(a, b)
                self.data_length = len(self.imgidx)
            else:
                self.imgidx = list(self.imgrec.keys)
            if shuffle:
                self.seq = self.imgidx
                self.oseq = self.imgidx
            else:
                self.seq = None

        self.mean = mean
        self.nd_mean = None
        self.epoch = 0

        if self.mean:
            self.mean = np.array(self.mean, dtype=np.float32).reshape(1, 1, 3)
            self.nd_mean = mx.nd.array(self.mean).reshape((1, 1, 3))

        self.check_data_shape(data_shape)
        self.provide_data = [(data_name, (batch_size, ) + data_shape)]
        self.batch_size = batch_size
        self.data_shape = data_shape
        self.shuffle = shuffle
        self.image_size = '%d,%d' % (data_shape[1], data_shape[2])
        self.rand_mirror = rand_mirror

        self.cutoff = cutoff
        self.color_jittering = color_jittering
        self.CJA = mx.image.ColorJitterAug(0.125, 0.125, 0.125)
        self.provide_label = [(label_name, (batch_size, ))]

        self.cur = 0
        self.nbatch = 0
        self.is_init = False
        # self.seq is only populated when shuffle=True, so training is
        # expected to run with shuffle enabled.
        self.num_samples_per_gpu = int(
            math.floor(len(self.seq) * 1.0 / self.context_num))

    def reset(self):
        """Resets the iterator to the beginning of the data."""
        self.epoch += 1
        self.cur = 0
        if self.shuffle:
            random.seed(self.epoch)
            random.shuffle(self.seq)
        if self.seq is None and self.imgrec is not None:
            self.imgrec.reset()

    def num_samples(self):
        return len(self.seq)

    def next_sample(self):
        """Returns the next (label, image, bbox, landmark) sample."""
        if self.seq is not None:
            while True:
                if self.cur >= self.num_samples_per_gpu:
                    raise StopIteration
                # each worker reads its own contiguous slice of the sequence
                idx = self.seq[self.num_samples_per_gpu * self.context +
                               self.cur]
                self.cur += 1
                if self.imgrec is not None:
                    s = self.imgrec.read_idx(idx)
                    header, img = recordio.unpack(s)
                    label = header.label
                    if not isinstance(label, numbers.Number):
                        label = label[0]
                    return int(label), img, None, None
                else:
                    # image-list mode; requires self.imglist to be set
                    label, fname, bbox, landmark = self.imglist[idx]
                    return label, self.read_image(fname), bbox, landmark
        else:
            s = self.imgrec.read()
            if s is None:
                raise StopIteration
            header, img = recordio.unpack(s)
            return header.label, img, None, None

    def brightness_aug(self, src, x):
        """Scales pixel intensities by a random factor in [1 - x, 1 + x]."""
        alpha = 1.0 + random.uniform(-x, x)
        src *= alpha
        return src

    def contrast_aug(self, src, x):
        """Blends the image with its mean gray level by a random factor."""
        alpha = 1.0 + random.uniform(-x, x)
        coef = nd.array([[[0.299, 0.587, 0.114]]])
        gray = src * coef
        gray = (3.0 * (1.0 - alpha) / gray.size) * nd.sum(gray)
        src *= alpha
        src += gray
        return src

    def saturation_aug(self, src, x):
        """Blends each pixel with its own gray value by a random factor."""
        alpha = 1.0 + random.uniform(-x, x)
        coef = nd.array([[[0.299, 0.587, 0.114]]])
        gray = src * coef
        gray = nd.sum(gray, axis=2, keepdims=True)
        gray *= (1.0 - alpha)
        src *= alpha
        src += gray
        return src

    def color_aug(self, img, x):
        # `x` is unused; jitter strength is fixed at 0.125 via self.CJA.
        return self.CJA(img)

    def mirror_aug(self, img):
        """Horizontally flips the image with probability 0.5."""
        _rd = random.randint(0, 1)
        if _rd == 1:
            for c in range(img.shape[2]):
                img[:, :, c] = np.fliplr(img[:, :, c])
        return img

    def compress_aug(self, img):
        """Simulates JPEG artifacts by re-encoding at a random low quality."""
        # PIL is imported lazily so it remains an optional dependency
        from PIL import Image
        from io import BytesIO
        buf = BytesIO()
        img = Image.fromarray(img.asnumpy(), 'RGB')
        q = random.randint(2, 20)
        img.save(buf, format='JPEG', quality=q)
        buf = buf.getvalue()
        img = Image.open(BytesIO(buf))
        return nd.array(np.asarray(img, 'float32'))

    def next(self):
        """Returns the next batch of data."""
        if not self.is_init:
            self.reset()
            self.is_init = True
        # print('in next', self.cur, self.labelcur)
        self.nbatch += 1

        batch_size = self.batch_size
        c, h, w = self.data_shape
        batch_data = nd.empty((batch_size, c, h, w))
        if self.provide_label is not None:
            batch_label = nd.empty(self.provide_label[0][1])
        i = 0
        try:
            while i < batch_size:
                label, s, bbox, landmark = self.next_sample()
                _data = self.imdecode(s)
                if _data.shape[0] != self.data_shape[1]:
                    _data = mx.image.resize_short(_data, self.data_shape[1])
                if self.rand_mirror:
                    _rd = random.randint(0, 1)
                    if _rd == 1:
                        _data = mx.ndarray.flip(data=_data, axis=1)
                if self.color_jittering > 0:
                    if self.color_jittering > 1:
                        # level > 1 also applies random JPEG compression
                        _rd = random.randint(0, 1)
                        if _rd == 1:
                            _data = self.compress_aug(_data)
                    # print('do color aug')
                    _data = _data.astype('float32', copy=False)
                    # print(_data.__class__)
                    _data = self.color_aug(_data, 0.125)
                if self.nd_mean is not None:
                    _data = _data.astype('float32', copy=False)
                    _data -= self.nd_mean
                    _data *= 0.0078125
                if self.cutoff > 0:
                    _rd = random.randint(0, 1)
                    if _rd == 1:
                        # print('do cutoff aug', self.cutoff)
                        # erase a random square patch (cutout-style)
                        centerh = random.randint(0, _data.shape[0] - 1)
                        centerw = random.randint(0, _data.shape[1] - 1)
                        half = self.cutoff // 2
                        starth = max(0, centerh - half)
                        endh = min(_data.shape[0], centerh + half)
                        startw = max(0, centerw - half)
                        endw = min(_data.shape[1], centerw + half)
                        # print(starth, endh, startw, endw, _data.shape)
                        _data[starth:endh, startw:endw, :] = 128
                data = [_data]
                try:
                    self.check_valid_image(data)
                except RuntimeError as e:
                    logging.debug('Invalid image, skipping: %s', str(e))
                    continue
                # print('aa', data[0].shape)
                # data = self.augmentation_transform(data)
                # print('bb', data[0].shape)
                for datum in data:
                    assert i < batch_size, 'Batch size must be multiples of augmenter output length'
                    # print(datum.shape)
                    batch_data[i][:] = self.postprocess_data(datum)
                    batch_label[i][:] = label
                    i += 1
        except StopIteration:
            if i < batch_size:
                raise StopIteration

        return io.DataBatch([batch_data], [batch_label], batch_size - i)

    def check_data_shape(self, data_shape):
        """Checks if the input data shape is valid."""
        if not len(data_shape) == 3:
            raise ValueError(
                'data_shape should have length 3, with dimensions CxHxW')
        if not data_shape[0] == 3:
            raise ValueError(
                'This iterator expects inputs to have 3 channels.')

    def check_valid_image(self, data):
        """Checks if the input data is valid."""
        if len(data[0].shape) == 0:
            raise RuntimeError('Data shape is wrong')

    def imdecode(self, s):
        """Decodes a string or byte string to an NDArray.

        See mx.img.imdecode for more details."""
        img = mx.image.imdecode(s)  # mx.ndarray
        return img

    def read_image(self, fname):
        """Reads an input image `fname` and returns the raw (undecoded) bytes.

        Example usage:
        ----------
        >>> dataIter.read_image('Face.jpg') # returns raw image bytes.
        """
        # note: only used in image-list mode, which expects self.path_root to
        # point at the image root directory.
        with open(os.path.join(self.path_root, fname), 'rb') as fin:
            img = fin.read()
        return img

    def augmentation_transform(self, data):
        """Transforms input data with specified augmentation."""
        for aug in self.auglist:
            data = [ret for src in data for ret in aug(src)]
        return data

    def postprocess_data(self, datum):
        """Final postprocessing step before image is loaded into the batch."""
        # HWC -> CHW for MXNet's NCHW layout
        return nd.transpose(datum, axes=(2, 0, 1))

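# Usage sketch (illustrative only; the .rec path below is a placeholder and
# assumes a matching .idx index file exists alongside it):
#
#   train_iter = FaceImageIter(batch_size=128,
#                              data_shape=(3, 112, 112),
#                              path_imgrec='/path/to/train.rec',
#                              shuffle=True,
#                              rand_mirror=True,
#                              color_jittering=1)
#   batch = train_iter.next()  # io.DataBatch with one data and one label array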

class FaceImageIterList(io.DataIter):
    def __init__(self, iter_list):
        assert len(iter_list) > 0
        self.provide_data = iter_list[0].provide_data
        self.provide_label = iter_list[0].provide_label
        self.iter_list = iter_list
        self.cur_iter = None

    def reset(self):
        # cur_iter is None until next() has been called at least once
        if self.cur_iter is not None:
            self.cur_iter.reset()

    def next(self):
        self.cur_iter = random.choice(self.iter_list)
        while True:
            try:
                ret = self.cur_iter.next()
            except StopIteration:
                self.cur_iter.reset()
                continue
            return ret

# dummy
class DummyIter(mx.io.DataIter):
    def __init__(self,
                 batch_size,
                 data_shape,
                 batches=1000,
                 mode='',
                 dtype='float32'):
        super(DummyIter, self).__init__(batch_size)
        self.data_shape = (batch_size, ) + data_shape
        self.label_shape = (batch_size, )
        self.provide_data = [('data', self.data_shape)]
        self.provide_label = [('softmax_label', self.label_shape)]
        # self.provide_label = [('label', self.label_shape)]
        # if mode == 'perseus':
        #     self.provide_label = []
        self.batch = mx.io.DataBatch(
            data=[mx.nd.zeros(self.data_shape, dtype=dtype)],
            label=[mx.nd.zeros(self.label_shape, dtype=dtype)])
        self._batches = 0
        self.batches = batches

    def next(self):
        if self._batches < self.batches:
            self._batches += 1
            return self.batch
        else:
            self._batches = 0
            raise StopIteration
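

# Minimal self-contained sketch exercising DummyIter (needs no data on disk);
# the batch size and shape below are illustrative values, not project defaults.
if __name__ == '__main__':
    dummy = DummyIter(batch_size=4, data_shape=(3, 112, 112), batches=2)
    n = 0
    try:
        while True:
            batch = dummy.next()
            n += 1
            print('batch', n, batch.data[0].shape, batch.label[0].shape)
    except StopIteration:
        pass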