mirror of
https://github.com/MarcosRodrigoT/ViT-Face-Recognition.git
synced 2025-12-30 08:02:29 +00:00
Preprocess mugshot images
This commit is contained in:
@@ -1,8 +1,18 @@
|
||||
import os
|
||||
import pickle
|
||||
import tensorflow as tf
|
||||
from vit_keras import vit
|
||||
|
||||
|
||||
def preprocess_image(img_path):
|
||||
img_ = tf.io.read_file(img_path)
|
||||
img_ = tf.image.decode_jpeg(img_, channels=3)
|
||||
img_ = tf.image.convert_image_dtype(img_, dtype=tf.float32)
|
||||
img_ = tf.image.resize(img_, [224, 224])
|
||||
img_ = tf.expand_dims(img_, axis=0)
|
||||
return img_
|
||||
|
||||
|
||||
"""
|
||||
CREATE DATASET
|
||||
"""
|
||||
@@ -11,19 +21,44 @@ MUGSHOT_DIR = f'{BASE_DIR}/mugshot_frontal_cropped_all'
|
||||
SURVEILLANCE_DIR = f'{BASE_DIR}/surveillance_cameras_all'
|
||||
|
||||
mugshot_data = {}
|
||||
for file in os.listdir(MUGSHOT_DIR):
|
||||
for file in sorted(os.listdir(MUGSHOT_DIR)):
|
||||
person = file.split('_')[0]
|
||||
file_path = os.path.join(MUGSHOT_DIR, file)
|
||||
mugshot_data[person] = {'file': file_path, 'embeddings': None}
|
||||
mugshot_data[person] = {
|
||||
'file': file_path,
|
||||
'embeddings': {
|
||||
'vit': None,
|
||||
'resnet': None,
|
||||
'vgg': None,
|
||||
'inception': None,
|
||||
'mobilenet': None,
|
||||
'efficientnet': None,
|
||||
}
|
||||
}
|
||||
|
||||
surveillance_data = {}
|
||||
for person in mugshot_data.keys():
|
||||
surveillance_data[person] = {'files': [], 'embeddings': []}
|
||||
for file in os.listdir(SURVEILLANCE_DIR):
|
||||
surveillance_data[person] = {
|
||||
'files': [],
|
||||
'embeddings': {
|
||||
'vit': [],
|
||||
'resnet': [],
|
||||
'vgg': [],
|
||||
'inception': [],
|
||||
'mobilenet': [],
|
||||
'efficientnet': [],
|
||||
}
|
||||
}
|
||||
for file in sorted(os.listdir(SURVEILLANCE_DIR)):
|
||||
person = file.split('_')[0]
|
||||
file_path = os.path.join(SURVEILLANCE_DIR, file)
|
||||
surveillance_data[person]['files'].append(file_path)
|
||||
surveillance_data[person]['embeddings'].append(None)
|
||||
surveillance_data[person]['embeddings']['vit'].append(None)
|
||||
surveillance_data[person]['embeddings']['resnet'].append(None)
|
||||
surveillance_data[person]['embeddings']['vgg'].append(None)
|
||||
surveillance_data[person]['embeddings']['inception'].append(None)
|
||||
surveillance_data[person]['embeddings']['mobilenet'].append(None)
|
||||
surveillance_data[person]['embeddings']['efficientnet'].append(None)
|
||||
|
||||
|
||||
"""
|
||||
@@ -124,3 +159,34 @@ efficientnetB0_model.summary()
|
||||
efficientnetB0_model.load_weights("./saved_results/Models/EfficientNet_B0/checkpoint").expect_partial() # suppresses warnings
|
||||
efficientnetB0_model = tf.keras.models.Model(inputs=efficientnetB0_model.input, outputs=efficientnetB0_model.layers[-2].output)
|
||||
efficientnetB0_model.summary()
|
||||
|
||||
|
||||
"""
|
||||
PREPROCESS IMAGES AND COMPUTE EMBEDDINGS
|
||||
"""
|
||||
|
||||
|
||||
try:
|
||||
with open('./saved_results/Tests/SCface/embeddings.pickle', 'rb') as file:
|
||||
mugshot_data, surveillance_data = pickle.load(file)
|
||||
except FileNotFoundError:
|
||||
for person in mugshot_data.keys():
|
||||
print(f'##### Person {person} #####')
|
||||
img = preprocess_image(mugshot_data[person]['file'])
|
||||
|
||||
embeddings_vit = vit_model(img).numpy()
|
||||
embeddings1_resnet = resnet50_model(img).numpy()
|
||||
embeddings1_vgg16 = vgg16_model(img).numpy()
|
||||
embeddings1_inception = inception_model(img).numpy()
|
||||
embeddings1_mobilenet = mobilenet_model(img).numpy()
|
||||
embeddings1_efficientnet = efficientnetB0_model(img).numpy()
|
||||
|
||||
mugshot_data[person]['embeddings']['vit'] = embeddings_vit
|
||||
mugshot_data[person]['embeddings']['resnet'] = embeddings1_resnet
|
||||
mugshot_data[person]['embeddings']['vgg'] = embeddings1_vgg16
|
||||
mugshot_data[person]['embeddings']['inception'] = embeddings1_inception
|
||||
mugshot_data[person]['embeddings']['mobilenet'] = embeddings1_mobilenet
|
||||
mugshot_data[person]['embeddings']['efficientnet'] = embeddings1_efficientnet
|
||||
|
||||
# for person in surveillance_data.keys():
|
||||
# pass
|
||||
|
||||
Reference in New Issue
Block a user