mirror of
https://github.com/MarcosRodrigoT/ViT-Face-Recognition.git
synced 2025-12-29 23:52:28 +00:00
Added InceptionV3 model experiments
This commit is contained in:
205
inceptionV3_train.py
Normal file
205
inceptionV3_train.py
Normal file
@@ -0,0 +1,205 @@
|
||||
import datetime
|
||||
import tensorflow as tf
|
||||
from tensorflow.keras.layers import GlobalAvgPool2D, Dense
|
||||
from tensorflow.keras.initializers import GlorotUniform
|
||||
from tensorflow.keras.models import Model
|
||||
import matplotlib.pyplot as plt
|
||||
from data_generator import create_data_generators
|
||||
|
||||
|
||||
"""
|
||||
HYPERPARAMETERS
|
||||
"""
|
||||
|
||||
# Distribute training
|
||||
strategy = tf.distribute.MirroredStrategy()
|
||||
|
||||
# Input
|
||||
image_size = 224
|
||||
|
||||
# Hyper-parameters
|
||||
batch_size = 128 * strategy.num_replicas_in_sync
|
||||
num_epochs = 25
|
||||
learning_rate = 0.0001
|
||||
num_classes = 8631
|
||||
|
||||
|
||||
"""
|
||||
DATASET
|
||||
"""
|
||||
|
||||
train_gen, val_gen, test_gen = create_data_generators(target_size=image_size, batch_size=batch_size)
|
||||
|
||||
|
||||
"""
|
||||
MODEL
|
||||
"""
|
||||
|
||||
with strategy.scope():
|
||||
inception_model = tf.keras.applications.InceptionV3(
|
||||
include_top=False,
|
||||
weights="imagenet",
|
||||
input_shape=(image_size, image_size, 3),
|
||||
pooling=None,
|
||||
)
|
||||
Y = GlobalAvgPool2D()(inception_model.output)
|
||||
Y = Dense(units=num_classes, activation='softmax', kernel_initializer=GlorotUniform())(Y)
|
||||
inception_model = Model(inputs=inception_model.input, outputs=Y, name='InceptionV3')
|
||||
inception_model.summary(line_length=150)
|
||||
|
||||
|
||||
"""
|
||||
MODEL COMPILE
|
||||
"""
|
||||
|
||||
with strategy.scope():
|
||||
optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)
|
||||
inception_model.compile(
|
||||
optimizer=optimizer,
|
||||
loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
|
||||
metrics=[
|
||||
tf.keras.metrics.SparseCategoricalAccuracy(name='accuracy'),
|
||||
tf.keras.metrics.SparseTopKCategoricalAccuracy(k=5, name='top-5-accuracy'),
|
||||
tf.keras.metrics.SparseTopKCategoricalAccuracy(k=10, name='top-10-accuracy'),
|
||||
tf.keras.metrics.SparseTopKCategoricalAccuracy(k=100, name='top-100-accuracy'),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
"""
|
||||
CALLBACKS
|
||||
"""
|
||||
|
||||
# checkpoint callback
|
||||
checkpoint_filepath = "./tmp/checkpoint"
|
||||
checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
|
||||
filepath=checkpoint_filepath,
|
||||
monitor='val_accuracy',
|
||||
verbose=1,
|
||||
save_best_only=True,
|
||||
save_weights_only=True,
|
||||
mode='max',
|
||||
save_freq='epoch',
|
||||
)
|
||||
|
||||
# csv logger callback
|
||||
csv_filepath = "./tmp/training_log.csv"
|
||||
csv_logger = tf.keras.callbacks.CSVLogger(
|
||||
csv_filepath,
|
||||
separator=',',
|
||||
append=True,
|
||||
)
|
||||
|
||||
# early stopping callback
|
||||
early_stopping = tf.keras.callbacks.EarlyStopping(
|
||||
monitor='val_loss',
|
||||
min_delta=0,
|
||||
patience=7,
|
||||
verbose=0,
|
||||
mode='auto',
|
||||
)
|
||||
|
||||
# tensorboard callback
|
||||
tb_callback = tf.keras.callbacks.TensorBoard(
|
||||
log_dir="./tmp/logs" + datetime.datetime.now().strftime("%d%m%Y-%H%M%S"),
|
||||
histogram_freq=1,
|
||||
write_graph=True,
|
||||
update_freq='epoch',
|
||||
)
|
||||
|
||||
|
||||
"""
|
||||
LOAD PRE-TRAINED MODEL WEIGHTS
|
||||
"""
|
||||
|
||||
# Load pre-trained model weights before training
|
||||
best_weights = "./saved_results/Models/Inception_V3/checkpoint"
|
||||
inception_model.load_weights(best_weights)
|
||||
|
||||
|
||||
"""
|
||||
TRAIN THE MODEL
|
||||
"""
|
||||
|
||||
history = inception_model.fit(
|
||||
train_gen,
|
||||
epochs=num_epochs,
|
||||
validation_data=val_gen,
|
||||
callbacks=[
|
||||
checkpoint_callback,
|
||||
csv_logger,
|
||||
early_stopping,
|
||||
tb_callback,
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
"""
|
||||
EVALUATE THE MODEL
|
||||
"""
|
||||
|
||||
# Load best weights seen during training
|
||||
inception_model.load_weights(checkpoint_filepath)
|
||||
|
||||
# Evaluate the model
|
||||
loss, accuracy, top_five_accuracy, top_ten_accuracy, top_hundred_accuracy = inception_model.evaluate(test_gen)
|
||||
accuracy = round(accuracy * 100, 2)
|
||||
top_five_accuracy = round(top_five_accuracy * 100, 2)
|
||||
top_ten_accuracy = round(top_ten_accuracy * 100, 2)
|
||||
top_hundred_accuracy = round(top_hundred_accuracy * 100, 2)
|
||||
print(f"Accuracy on the test set: {accuracy}%.")
|
||||
print(f"Top 5 Accuracy on the test set: {top_five_accuracy}%.")
|
||||
print(f"Top 10 Accuracy on the test set: {top_ten_accuracy}%.")
|
||||
print(f"Top 100 Accuracy on the test set: {top_hundred_accuracy}%.")
|
||||
|
||||
|
||||
"""
|
||||
HISTORY FIGURES
|
||||
"""
|
||||
|
||||
# PLOTS
|
||||
# Accuracy
|
||||
plt.plot(history.history['accuracy'])
|
||||
plt.plot(history.history['val_accuracy'])
|
||||
plt.title('model accuracy')
|
||||
plt.ylabel('accuracy')
|
||||
plt.xlabel('epoch')
|
||||
plt.legend(['train', 'val'], loc='upper left')
|
||||
plt.savefig('./tmp/model accuracy.png')
|
||||
plt.close()
|
||||
# Top 5 accuracy
|
||||
plt.plot(history.history['top-5-accuracy'])
|
||||
plt.plot(history.history['val_top-5-accuracy'])
|
||||
plt.title('model top 5 accuracy')
|
||||
plt.ylabel('top 5 accuracy')
|
||||
plt.xlabel('epoch')
|
||||
plt.legend(['train', 'val'], loc='upper left')
|
||||
plt.savefig('./tmp/model top 5 accuracy.png')
|
||||
plt.close()
|
||||
# Top 10 accuracy
|
||||
plt.plot(history.history['top-10-accuracy'])
|
||||
plt.plot(history.history['val_top-10-accuracy'])
|
||||
plt.title('model top 10 accuracy')
|
||||
plt.ylabel('top 10 accuracy')
|
||||
plt.xlabel('epoch')
|
||||
plt.legend(['train', 'val'], loc='upper left')
|
||||
plt.savefig('./tmp/model top 10 accuracy.png')
|
||||
plt.close()
|
||||
# Top 100 accuracy
|
||||
plt.plot(history.history['top-100-accuracy'])
|
||||
plt.plot(history.history['val_top-100-accuracy'])
|
||||
plt.title('model top 100 accuracy')
|
||||
plt.ylabel('top 100 accuracy')
|
||||
plt.xlabel('epoch')
|
||||
plt.legend(['train', 'val'], loc='upper left')
|
||||
plt.savefig('./tmp/model top 100 accuracy.png')
|
||||
plt.close()
|
||||
# Loss
|
||||
plt.plot(history.history['loss'])
|
||||
plt.plot(history.history['val_loss'])
|
||||
plt.title('model loss')
|
||||
plt.ylabel('loss')
|
||||
plt.xlabel('epoch')
|
||||
plt.legend(['train', 'val'], loc='upper left')
|
||||
plt.savefig('./tmp/model loss.png')
|
||||
plt.close()
|
||||
763
saved_results/Models/Inception_V3/history.txt
Normal file
763
saved_results/Models/Inception_V3/history.txt
Normal file
File diff suppressed because one or more lines are too long
26
saved_results/Models/Inception_V3/training_log.csv
Normal file
26
saved_results/Models/Inception_V3/training_log.csv
Normal file
@@ -0,0 +1,26 @@
|
||||
epoch,accuracy,loss,top-10-accuracy,top-100-accuracy,top-5-accuracy,val_accuracy,val_loss,val_top-10-accuracy,val_top-100-accuracy,val_top-5-accuracy
|
||||
0,0.5044412612915039,2.828429698944092,0.7204577326774597,0.8692970275878906,0.6647382378578186,0.7837027311325073,1.0061465501785278,0.936477541923523,0.9873324036598206,0.9071956872940063
|
||||
1,0.8421162366867065,0.7245597839355469,0.9566552639007568,0.9910308718681335,0.9360066652297974,0.8924019932746887,0.4839460253715515,0.9766190648078918,0.9957222938537598,0.963849663734436
|
||||
2,0.9061304330825806,0.4162036180496216,0.978788435459137,0.9959475994110107,0.9674435257911682,0.933899462223053,0.2888556718826294,0.9882490634918213,0.9979566335678101,0.9808204174041748
|
||||
3,0.9348474144935608,0.2803215980529785,0.9875955581665039,0.9978038668632507,0.9802804589271545,0.9474136233329773,0.21838736534118652,0.9925395250320435,0.9988923668861389,0.9875679612159729
|
||||
4,0.9516826272010803,0.20140601694583893,0.99233478307724,0.9988057613372803,0.9873526096343994,0.9636204838752747,0.1476803421974182,0.995849609375,0.9994462132453918,0.992743194103241
|
||||
5,0.962658703327179,0.15122036635875702,0.9951260685920715,0.9993436336517334,0.9915807843208313,0.970030665397644,0.11526346951723099,0.9973773956298828,0.9997262954711914,0.9952257871627808
|
||||
6,0.9698299169540405,0.11761530488729477,0.996909499168396,0.999656617641449,0.9943767189979553,0.9773320555686951,0.08726964145898819,0.9984531402587891,0.9998408555984497,0.996893584728241
|
||||
7,0.9752572774887085,0.09351618587970734,0.9980387091636658,0.999820351600647,0.9961870312690735,0.9809413552284241,0.07012616842985153,0.9992042779922485,0.9999681711196899,0.9982494711875916
|
||||
8,0.9792612195014954,0.0765235498547554,0.9987742900848389,0.9999083876609802,0.9973798394203186,0.9840795993804932,0.05660514906048775,0.9994398355484009,0.9999681711196899,0.9986759424209595
|
||||
9,0.9820313453674316,0.06444480270147324,0.99919193983078,0.9999600648880005,0.9981245994567871,0.985015332698822,0.05419314652681351,0.9996180534362793,0.999993622303009,0.9989942312240601
|
||||
10,0.9844216108322144,0.054814841598272324,0.9994914531707764,0.99997878074646,0.998673141002655,0.9872050881385803,0.04437025636434555,0.9998026490211487,1.0,0.9993698000907898
|
||||
11,0.9862917065620422,0.04778851196169853,0.9996541142463684,0.9999901056289673,0.9990150928497314,0.9895476698875427,0.03730124607682228,0.9998854398727417,1.0,0.9995225667953491
|
||||
12,0.9875460863113403,0.042636893689632416,0.9997605681419373,0.9999961256980896,0.999241054058075,0.9906170964241028,0.032722555100917816,0.9999363422393799,1.0,0.9996690154075623
|
||||
13,0.9887682795524597,0.038302622735500336,0.9998284578323364,0.9999978542327881,0.9994253516197205,0.9895667433738708,0.03481917828321457,0.9999172687530518,0.999993622303009,0.9996435046195984
|
||||
14,0.9897658824920654,0.03466902673244476,0.9998719692230225,0.9999985694885254,0.9995275139808655,0.9924694895744324,0.02594899572432041,0.999961793422699,1.0,0.9997517466545105
|
||||
15,0.9904996752738953,0.03201822564005852,0.9998924732208252,0.9999996423721313,0.9996092319488525,0.990744411945343,0.030490176752209663,0.9999299645423889,1.0,0.9996880888938904
|
||||
16,0.991253674030304,0.029298337176442146,0.9999310374259949,0.9999996423721313,0.9996870160102844,0.9921830296516418,0.02582458406686783,0.9999745488166809,1.0,0.9997581243515015
|
||||
17,0.991923451423645,0.02704434096813202,0.9999434351921082,1.0,0.9997220635414124,0.9935643672943115,0.02145766094326973,0.9999363422393799,1.0,0.9998090267181396
|
||||
18,0.9924065470695496,0.02545708231627941,0.999944806098938,0.9999996423721313,0.9997464418411255,0.9940418004989624,0.01996505819261074,0.999961793422699,1.0,0.9998408555984497
|
||||
19,0.9929522275924683,0.02354452945291996,0.9999568462371826,1.0,0.9997846484184265,0.9938316941261292,0.020735710859298706,0.9999490976333618,1.0,0.9997962713241577
|
||||
20,0.9932354688644409,0.02255505509674549,0.9999586343765259,1.0,0.9998037219047546,0.9954167604446411,0.015650002285838127,0.9999809265136719,1.0,0.9999172687530518
|
||||
21,0.9937002062797546,0.02080986276268959,0.9999681711196899,1.0,0.9998359084129333,0.9952639937400818,0.015985960140824318,0.9999809265136719,1.0,0.9999490976333618
|
||||
22,0.9940460324287415,0.019749460741877556,0.999971330165863,1.0,0.9998387098312378,0.995353102684021,0.01564055308699608,0.9999681711196899,1.0,0.9999299645423889
|
||||
23,0.9942670464515686,0.018744397908449173,0.999977707862854,1.0,0.9998733997344971,0.9952448606491089,0.01479360368102789,1.0,1.0,0.9999299645423889
|
||||
24,0.9946118593215942,0.017900148406624794,0.9999762773513794,1.0,0.9998723268508911,0.9954358339309692,0.014946679584681988,0.999993622303009,1.0,0.9999363422393799
|
||||
|
Reference in New Issue
Block a user