mirror of
https://github.com/deepinsight/insightface.git
synced 2026-05-18 14:55:42 +00:00
add cutoff
This commit is contained in:
14
src/data.py
14
src/data.py
@@ -58,7 +58,7 @@ class FaceImageIter(io.DataIter):
|
||||
def __init__(self, batch_size, data_shape,
|
||||
path_imgrec = None,
|
||||
shuffle=False, aug_list=None, mean = None,
|
||||
rand_mirror = False,
|
||||
rand_mirror = False, cutoff = 0,
|
||||
c2c_threshold = 0.0, output_c2c = 0, c2c_mode = -10,
|
||||
ctx_num = 0, images_per_identity = 0, data_extra = None, hard_mining = False,
|
||||
triplet_params = None, coco_mode = False,
|
||||
@@ -204,6 +204,7 @@ class FaceImageIter(io.DataIter):
|
||||
self.image_size = '%d,%d'%(data_shape[1],data_shape[2])
|
||||
self.rand_mirror = rand_mirror
|
||||
print('rand_mirror', rand_mirror)
|
||||
self.cutoff = cutoff
|
||||
#self.cast_aug = mx.image.CastAug()
|
||||
#self.color_aug = mx.image.ColorJitterAug(0.4, 0.4, 0.4)
|
||||
self.ctx_num = ctx_num
|
||||
@@ -824,6 +825,17 @@ class FaceImageIter(io.DataIter):
|
||||
_data = _data.astype('float32')
|
||||
_data -= self.nd_mean
|
||||
_data *= 0.0078125
|
||||
if self.cutoff>0:
|
||||
centerh = random.randint(0, _data.shape[0]-1)
|
||||
centerw = random.randint(0, _data.shape[1]-1)
|
||||
half = self.cutoff//2
|
||||
starth = max(0, centerh-half)
|
||||
endh = min(_data.shape[0], centerh+half)
|
||||
startw = max(0, centerw-half)
|
||||
endw = min(_data.shape[1], centerw+half)
|
||||
_data = _data.astype('float32')
|
||||
#print(starth, endh, startw, endw, _data.shape)
|
||||
_data[starth:endh, startw:endw, :] = 127.5
|
||||
#_npdata = _data.asnumpy()
|
||||
#if landmark is not None:
|
||||
# _npdata = face_preprocess.preprocess(_npdata, bbox = bbox, landmark=landmark, image_size=self.image_size)
|
||||
|
||||
Reference in New Issue
Block a user