Added InceptionV3 model experiments

This commit is contained in:
mrt
2023-06-01 14:46:48 +02:00
parent cfea6cd690
commit db46102e4a
3 changed files with 994 additions and 0 deletions

inceptionV3_train.py (Normal file, 205 lines added)

@@ -0,0 +1,205 @@
import datetime
import tensorflow as tf
from tensorflow.keras.layers import GlobalAvgPool2D, Dense
from tensorflow.keras.initializers import GlorotUniform
from tensorflow.keras.models import Model
import matplotlib.pyplot as plt
from data_generator import create_data_generators
"""
HYPERPARAMETERS
"""
# Distribute training
strategy = tf.distribute.MirroredStrategy()
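# MirroredStrategy runs synchronous data-parallel training on all visible GPUs;
# the global batch size below is scaled by the number of replicas accordingly.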
# Input
image_size = 224
# Hyper-parameters
batch_size = 128 * strategy.num_replicas_in_sync
num_epochs = 25
learning_rate = 0.0001
num_classes = 8631
"""
DATASET
"""
train_gen, val_gen, test_gen = create_data_generators(target_size=image_size, batch_size=batch_size)
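# NOTE: create_data_generators is defined in the separate data_generator module. It is
# assumed to return three iterators yielding (image batch, integer label batch) pairs,
# e.g. flow_from_directory with class_mode='sparse', matching the sparse loss/metrics below.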
"""
MODEL
"""
with strategy.scope():
    inception_model = tf.keras.applications.InceptionV3(
        include_top=False,
        weights="imagenet",
        input_shape=(image_size, image_size, 3),
        pooling=None,
    )
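    # Replace the discarded ImageNet classification top with global average pooling
    # followed by a num_classes-way softmax head (Glorot-uniform initialised).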
    Y = GlobalAvgPool2D()(inception_model.output)
    Y = Dense(units=num_classes, activation='softmax', kernel_initializer=GlorotUniform())(Y)
    inception_model = Model(inputs=inception_model.input, outputs=Y, name='InceptionV3')
inception_model.summary(line_length=150)
"""
MODEL COMPILE
"""
with strategy.scope():
    optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)
    inception_model.compile(
        optimizer=optimizer,
        loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
        metrics=[
            tf.keras.metrics.SparseCategoricalAccuracy(name='accuracy'),
            tf.keras.metrics.SparseTopKCategoricalAccuracy(k=5, name='top-5-accuracy'),
            tf.keras.metrics.SparseTopKCategoricalAccuracy(k=10, name='top-10-accuracy'),
            tf.keras.metrics.SparseTopKCategoricalAccuracy(k=100, name='top-100-accuracy'),
        ]
    )
"""
CALLBACKS
"""
# checkpoint callback
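# Keeps only the best weights seen so far (highest val_accuracy); they are restored
# below before the final test-set evaluation.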
checkpoint_filepath = "./tmp/checkpoint"
checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_filepath,
    monitor='val_accuracy',
    verbose=1,
    save_best_only=True,
    save_weights_only=True,
    mode='max',
    save_freq='epoch',
)
# csv logger callback
csv_filepath = "./tmp/training_log.csv"
csv_logger = tf.keras.callbacks.CSVLogger(
    csv_filepath,
    separator=',',
    append=True,
)
# early stopping callback
early_stopping = tf.keras.callbacks.EarlyStopping(
    monitor='val_loss',
    min_delta=0,
    patience=7,
    verbose=0,
    mode='auto',
)
# tensorboard callback
tb_callback = tf.keras.callbacks.TensorBoard(
    log_dir="./tmp/logs/" + datetime.datetime.now().strftime("%d%m%Y-%H%M%S"),
    histogram_freq=1,
    write_graph=True,
    update_freq='epoch',
)
"""
LOAD PRE-TRAINED MODEL WEIGHTS
"""
# Load pre-trained model weights before training
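# NOTE: assumes a checkpoint from an earlier run exists at this path; training resumes
# from those weights instead of the plain ImageNet initialisation.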
best_weights = "./saved_results/Models/Inception_V3/checkpoint"
inception_model.load_weights(best_weights)
"""
TRAIN THE MODEL
"""
history = inception_model.fit(
    train_gen,
    epochs=num_epochs,
    validation_data=val_gen,
    callbacks=[
        checkpoint_callback,
        csv_logger,
        early_stopping,
        tb_callback,
    ]
)
"""
EVALUATE THE MODEL
"""
# Load best weights seen during training
inception_model.load_weights(checkpoint_filepath)
# Evaluate the model
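# evaluate() returns the loss followed by the metrics in the order passed to compile()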
loss, accuracy, top_five_accuracy, top_ten_accuracy, top_hundred_accuracy = inception_model.evaluate(test_gen)
accuracy = round(accuracy * 100, 2)
top_five_accuracy = round(top_five_accuracy * 100, 2)
top_ten_accuracy = round(top_ten_accuracy * 100, 2)
top_hundred_accuracy = round(top_hundred_accuracy * 100, 2)
print(f"Accuracy on the test set: {accuracy}%.")
print(f"Top 5 Accuracy on the test set: {top_five_accuracy}%.")
print(f"Top 10 Accuracy on the test set: {top_ten_accuracy}%.")
print(f"Top 100 Accuracy on the test set: {top_hundred_accuracy}%.")
"""
HISTORY FIGURES
"""
# PLOTS
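# history.history is keyed by the metric names defined above; validation curves use
# the 'val_' prefix (e.g. 'val_top-5-accuracy').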
# Accuracy
plt.plot(history.history['accuracy'])
plt.plot(history.history['val_accuracy'])
plt.title('model accuracy')
plt.ylabel('accuracy')
plt.xlabel('epoch')
plt.legend(['train', 'val'], loc='upper left')
plt.savefig('./tmp/model accuracy.png')
plt.close()
# Top 5 accuracy
plt.plot(history.history['top-5-accuracy'])
plt.plot(history.history['val_top-5-accuracy'])
plt.title('model top 5 accuracy')
plt.ylabel('top 5 accuracy')
plt.xlabel('epoch')
plt.legend(['train', 'val'], loc='upper left')
plt.savefig('./tmp/model top 5 accuracy.png')
plt.close()
# Top 10 accuracy
plt.plot(history.history['top-10-accuracy'])
plt.plot(history.history['val_top-10-accuracy'])
plt.title('model top 10 accuracy')
plt.ylabel('top 10 accuracy')
plt.xlabel('epoch')
plt.legend(['train', 'val'], loc='upper left')
plt.savefig('./tmp/model top 10 accuracy.png')
plt.close()
# Top 100 accuracy
plt.plot(history.history['top-100-accuracy'])
plt.plot(history.history['val_top-100-accuracy'])
plt.title('model top 100 accuracy')
plt.ylabel('top 100 accuracy')
plt.xlabel('epoch')
plt.legend(['train', 'val'], loc='upper left')
plt.savefig('./tmp/model top 100 accuracy.png')
plt.close()
# Loss
plt.plot(history.history['loss'])
plt.plot(history.history['val_loss'])
plt.title('model loss')
plt.ylabel('loss')
plt.xlabel('epoch')
plt.legend(['train', 'val'], loc='upper left')
plt.savefig('./tmp/model loss.png')
plt.close()

File diff suppressed because one or more lines are too long


@@ -0,0 +1,26 @@
epoch,accuracy,loss,top-10-accuracy,top-100-accuracy,top-5-accuracy,val_accuracy,val_loss,val_top-10-accuracy,val_top-100-accuracy,val_top-5-accuracy
0,0.5044412612915039,2.828429698944092,0.7204577326774597,0.8692970275878906,0.6647382378578186,0.7837027311325073,1.0061465501785278,0.936477541923523,0.9873324036598206,0.9071956872940063
1,0.8421162366867065,0.7245597839355469,0.9566552639007568,0.9910308718681335,0.9360066652297974,0.8924019932746887,0.4839460253715515,0.9766190648078918,0.9957222938537598,0.963849663734436
2,0.9061304330825806,0.4162036180496216,0.978788435459137,0.9959475994110107,0.9674435257911682,0.933899462223053,0.2888556718826294,0.9882490634918213,0.9979566335678101,0.9808204174041748
3,0.9348474144935608,0.2803215980529785,0.9875955581665039,0.9978038668632507,0.9802804589271545,0.9474136233329773,0.21838736534118652,0.9925395250320435,0.9988923668861389,0.9875679612159729
4,0.9516826272010803,0.20140601694583893,0.99233478307724,0.9988057613372803,0.9873526096343994,0.9636204838752747,0.1476803421974182,0.995849609375,0.9994462132453918,0.992743194103241
5,0.962658703327179,0.15122036635875702,0.9951260685920715,0.9993436336517334,0.9915807843208313,0.970030665397644,0.11526346951723099,0.9973773956298828,0.9997262954711914,0.9952257871627808
6,0.9698299169540405,0.11761530488729477,0.996909499168396,0.999656617641449,0.9943767189979553,0.9773320555686951,0.08726964145898819,0.9984531402587891,0.9998408555984497,0.996893584728241
7,0.9752572774887085,0.09351618587970734,0.9980387091636658,0.999820351600647,0.9961870312690735,0.9809413552284241,0.07012616842985153,0.9992042779922485,0.9999681711196899,0.9982494711875916
8,0.9792612195014954,0.0765235498547554,0.9987742900848389,0.9999083876609802,0.9973798394203186,0.9840795993804932,0.05660514906048775,0.9994398355484009,0.9999681711196899,0.9986759424209595
9,0.9820313453674316,0.06444480270147324,0.99919193983078,0.9999600648880005,0.9981245994567871,0.985015332698822,0.05419314652681351,0.9996180534362793,0.999993622303009,0.9989942312240601
10,0.9844216108322144,0.054814841598272324,0.9994914531707764,0.99997878074646,0.998673141002655,0.9872050881385803,0.04437025636434555,0.9998026490211487,1.0,0.9993698000907898
11,0.9862917065620422,0.04778851196169853,0.9996541142463684,0.9999901056289673,0.9990150928497314,0.9895476698875427,0.03730124607682228,0.9998854398727417,1.0,0.9995225667953491
12,0.9875460863113403,0.042636893689632416,0.9997605681419373,0.9999961256980896,0.999241054058075,0.9906170964241028,0.032722555100917816,0.9999363422393799,1.0,0.9996690154075623
13,0.9887682795524597,0.038302622735500336,0.9998284578323364,0.9999978542327881,0.9994253516197205,0.9895667433738708,0.03481917828321457,0.9999172687530518,0.999993622303009,0.9996435046195984
14,0.9897658824920654,0.03466902673244476,0.9998719692230225,0.9999985694885254,0.9995275139808655,0.9924694895744324,0.02594899572432041,0.999961793422699,1.0,0.9997517466545105
15,0.9904996752738953,0.03201822564005852,0.9998924732208252,0.9999996423721313,0.9996092319488525,0.990744411945343,0.030490176752209663,0.9999299645423889,1.0,0.9996880888938904
16,0.991253674030304,0.029298337176442146,0.9999310374259949,0.9999996423721313,0.9996870160102844,0.9921830296516418,0.02582458406686783,0.9999745488166809,1.0,0.9997581243515015
17,0.991923451423645,0.02704434096813202,0.9999434351921082,1.0,0.9997220635414124,0.9935643672943115,0.02145766094326973,0.9999363422393799,1.0,0.9998090267181396
18,0.9924065470695496,0.02545708231627941,0.999944806098938,0.9999996423721313,0.9997464418411255,0.9940418004989624,0.01996505819261074,0.999961793422699,1.0,0.9998408555984497
19,0.9929522275924683,0.02354452945291996,0.9999568462371826,1.0,0.9997846484184265,0.9938316941261292,0.020735710859298706,0.9999490976333618,1.0,0.9997962713241577
20,0.9932354688644409,0.02255505509674549,0.9999586343765259,1.0,0.9998037219047546,0.9954167604446411,0.015650002285838127,0.9999809265136719,1.0,0.9999172687530518
21,0.9937002062797546,0.02080986276268959,0.9999681711196899,1.0,0.9998359084129333,0.9952639937400818,0.015985960140824318,0.9999809265136719,1.0,0.9999490976333618
22,0.9940460324287415,0.019749460741877556,0.999971330165863,1.0,0.9998387098312378,0.995353102684021,0.01564055308699608,0.9999681711196899,1.0,0.9999299645423889
23,0.9942670464515686,0.018744397908449173,0.999977707862854,1.0,0.9998733997344971,0.9952448606491089,0.01479360368102789,1.0,1.0,0.9999299645423889
24,0.9946118593215942,0.017900148406624794,0.9999762773513794,1.0,0.9998723268508911,0.9954358339309692,0.014946679584681988,0.999993622303009,1.0,0.9999363422393799