# File: ViT-Face-Recognition/inceptionV3_train.py
# Repository-viewer metadata: 2023-06-01 14:46:48 +02:00 · 206 lines · 5.1 KiB · Python

import datetime
import tensorflow as tf
from tensorflow.keras.layers import GlobalAvgPool2D, Dense
from tensorflow.keras.initializers import GlorotUniform
from tensorflow.keras.models import Model
import matplotlib.pyplot as plt
from data_generator import create_data_generators
"""
HYPERPARAMETERS
"""
# Distribute training with tf.distribute.MirroredStrategy.
strategy = tf.distribute.MirroredStrategy()

# Input: side length passed to the data generators and used for the
# (image_size, image_size, 3) model input.
image_size = 224

# Hyper-parameters
# Global batch size: 128 samples for every replica kept in sync.
per_replica_batch_size = 128
batch_size = per_replica_batch_size * strategy.num_replicas_in_sync
num_epochs = 25
learning_rate = 1e-4
# Width of the softmax head; presumably the number of identities in the
# training set -- confirm against data_generator.
num_classes = 8631
"""
DATASET
"""
# Train / validation / test generators from the project-local data_generator module.
train_gen, val_gen, test_gen = create_data_generators(
    target_size=image_size,
    batch_size=batch_size,
)
"""
MODEL
"""
with strategy.scope():
    # ImageNet-initialised InceptionV3 without its classification top and
    # without built-in pooling.
    backbone = tf.keras.applications.InceptionV3(
        weights="imagenet",
        include_top=False,
        pooling=None,
        input_shape=(image_size, image_size, 3),
    )
    # New head: global average pooling followed by a softmax classifier
    # with one unit per class.
    pooled_features = GlobalAvgPool2D()(backbone.output)
    class_probabilities = Dense(
        units=num_classes,
        activation='softmax',
        kernel_initializer=GlorotUniform(),
    )(pooled_features)
    inception_model = Model(
        inputs=backbone.input,
        outputs=class_probabilities,
        name='InceptionV3',
    )

# Print the layer table, widened to 150 columns.
inception_model.summary(line_length=150)
"""
MODEL COMPILE
"""
with strategy.scope():
    optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)

    # Top-1 accuracy first, then top-k accuracy for k = 5, 10 and 100.
    tracked_metrics = [tf.keras.metrics.SparseCategoricalAccuracy(name='accuracy')]
    tracked_metrics.extend(
        tf.keras.metrics.SparseTopKCategoricalAccuracy(k=k, name=f'top-{k}-accuracy')
        for k in (5, 10, 100)
    )

    inception_model.compile(
        optimizer=optimizer,
        # from_logits=False: the loss is given probabilities, not raw logits.
        loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
        metrics=tracked_metrics,
    )
"""
CALLBACKS
"""
# Every artefact configured here (checkpoint, CSV log, TensorBoard logs) is
# written below ./tmp. Create the directory up front instead of relying on
# one of the callbacks to do so as a side effect; makedirs succeeds when the
# directory already exists.
tf.io.gfile.makedirs("./tmp")

# checkpoint callback: after each epoch, keep only the weights with the
# highest validation accuracy seen so far.
checkpoint_filepath = "./tmp/checkpoint"
checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_filepath,
    monitor='val_accuracy',
    verbose=1,
    save_best_only=True,
    save_weights_only=True,
    mode='max',
    save_freq='epoch',
)

# csv logger callback: per-epoch metrics, appended to any existing log file.
csv_filepath = "./tmp/training_log.csv"
csv_logger = tf.keras.callbacks.CSVLogger(
    csv_filepath,
    separator=',',
    append=True,
)

# early stopping callback: stop once val_loss has gone 7 epochs without
# improving. Note that this monitors val_loss, whereas the checkpoint above
# monitors val_accuracy.
early_stopping = tf.keras.callbacks.EarlyStopping(
    monitor='val_loss',
    min_delta=0,
    patience=7,
    verbose=0,
    mode='auto',
)

# tensorboard callback: one sub-directory of ./tmp/logs per run. The path
# separator matters: without it each run is written to a sibling directory
# named "logs<timestamp>" instead of a child of ./tmp/logs.
run_timestamp = datetime.datetime.now().strftime("%d%m%Y-%H%M%S")
tb_callback = tf.keras.callbacks.TensorBoard(
    log_dir="./tmp/logs/" + run_timestamp,
    histogram_freq=1,
    write_graph=True,
    update_freq='epoch',
)
"""
LOAD PRE-TRAINED MODEL WEIGHTS
"""
# Restore previously saved weights before training starts (presumably the
# result of an earlier run -- confirm). This call is unconditional.
best_weights = "./saved_results/Models/Inception_V3/checkpoint"
inception_model.load_weights(best_weights)
"""
TRAIN THE MODEL
"""
# Callbacks are passed in this order: checkpointing, CSV logging,
# early stopping, TensorBoard.
training_callbacks = [checkpoint_callback, csv_logger, early_stopping, tb_callback]
history = inception_model.fit(
    train_gen,
    validation_data=val_gen,
    epochs=num_epochs,
    callbacks=training_callbacks,
)
"""
EVALUATE THE MODEL
"""
# Switch to the weights stored at checkpoint_filepath before scoring.
inception_model.load_weights(checkpoint_filepath)

# Evaluate on the test generator: the loss followed by the four compiled metrics.
test_scores = inception_model.evaluate(test_gen)
loss, accuracy, top_five_accuracy, top_ten_accuracy, top_hundred_accuracy = test_scores

# Express the accuracies as percentages rounded to two decimals.
accuracy, top_five_accuracy, top_ten_accuracy, top_hundred_accuracy = [
    round(fraction * 100, 2)
    for fraction in (accuracy, top_five_accuracy, top_ten_accuracy, top_hundred_accuracy)
]

for label, percentage in (
    ("Accuracy", accuracy),
    ("Top 5 Accuracy", top_five_accuracy),
    ("Top 10 Accuracy", top_ten_accuracy),
    ("Top 100 Accuracy", top_hundred_accuracy),
):
    print(f"{label} on the test set: {percentage}%.")
"""
HISTORY FIGURES
"""
# PLOTS
# One figure per tracked quantity, in this order: accuracy, top-5, top-10,
# top-100, loss. Each figure overlays the training curve and the validation
# curve ('val_' + key) and is saved as './tmp/model <label>.png'.
for metric_key, metric_label in (
    ('accuracy', 'accuracy'),
    ('top-5-accuracy', 'top 5 accuracy'),
    ('top-10-accuracy', 'top 10 accuracy'),
    ('top-100-accuracy', 'top 100 accuracy'),
    ('loss', 'loss'),
):
    figure_title = f'model {metric_label}'
    plt.plot(history.history[metric_key])
    plt.plot(history.history[f'val_{metric_key}'])
    plt.title(figure_title)
    plt.ylabel(metric_label)
    plt.xlabel('epoch')
    plt.legend(['train', 'val'], loc='upper left')
    plt.savefig(f'./tmp/{figure_title}.png')
    # Close the current figure so the next metric starts on fresh axes.
    plt.close()