mirror of
https://github.com/MarcosRodrigoT/ViT-Face-Recognition.git
synced 2025-12-30 08:02:29 +00:00
Added multi-gpu capabilities
This commit is contained in:
@@ -11,11 +11,14 @@ from data_generator import create_data_generators
|
||||
HYPERPARAMETERS
|
||||
"""
|
||||
|
||||
# Distribute training
|
||||
strategy = tf.distribute.MirroredStrategy()
|
||||
|
||||
# Input
|
||||
image_size = 224
|
||||
|
||||
# Hyper-parameters
|
||||
batch_size = 128
|
||||
batch_size = 128 * strategy.num_replicas_in_sync
|
||||
num_epochs = 25
|
||||
learning_rate = 0.0001
|
||||
num_classes = 8631
|
||||
@@ -32,15 +35,16 @@ train_gen, val_gen, test_gen = create_data_generators(target_size=image_size, ba
|
||||
MODEL
|
||||
"""
|
||||
|
||||
resnet_model = tf.keras.applications.ResNet50(
|
||||
with strategy.scope():
|
||||
resnet_model = tf.keras.applications.ResNet50(
|
||||
include_top=False,
|
||||
weights="imagenet",
|
||||
input_shape=(image_size, image_size, 3),
|
||||
pooling=None,
|
||||
)
|
||||
Y = GlobalAvgPool2D()(resnet_model.output)
|
||||
Y = Dense(units=num_classes, activation='softmax', kernel_initializer=GlorotUniform())(Y)
|
||||
resnet_model = Model(inputs=resnet_model.input, outputs=Y, name='ResNet50')
|
||||
)
|
||||
Y = GlobalAvgPool2D()(resnet_model.output)
|
||||
Y = Dense(units=num_classes, activation='softmax', kernel_initializer=GlorotUniform())(Y)
|
||||
resnet_model = Model(inputs=resnet_model.input, outputs=Y, name='ResNet50')
|
||||
resnet_model.summary()
|
||||
|
||||
|
||||
@@ -48,8 +52,9 @@ resnet_model.summary()
|
||||
MODEL COMPILE
|
||||
"""
|
||||
|
||||
optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)
|
||||
resnet_model.compile(
|
||||
with strategy.scope():
|
||||
optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)
|
||||
resnet_model.compile(
|
||||
optimizer=optimizer,
|
||||
loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
|
||||
metrics=[
|
||||
@@ -58,7 +63,7 @@ resnet_model.compile(
|
||||
tf.keras.metrics.SparseTopKCategoricalAccuracy(k=10, name='top-10-accuracy'),
|
||||
tf.keras.metrics.SparseTopKCategoricalAccuracy(k=100, name='top-100-accuracy'),
|
||||
]
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
"""
|
||||
|
||||
@@ -11,11 +11,14 @@ from data_generator import create_data_generators
|
||||
HYPERPARAMETERS
|
||||
"""
|
||||
|
||||
# Distribute training
|
||||
strategy = tf.distribute.MirroredStrategy()
|
||||
|
||||
# Input
|
||||
image_size = 224
|
||||
|
||||
# Hyper-parameters
|
||||
batch_size = 128
|
||||
batch_size = 128 * strategy.num_replicas_in_sync
|
||||
num_epochs = 25
|
||||
learning_rate = 0.0001
|
||||
num_classes = 8631
|
||||
@@ -32,15 +35,16 @@ train_gen, val_gen, test_gen = create_data_generators(target_size=image_size, ba
|
||||
MODEL
|
||||
"""
|
||||
|
||||
vgg_model = tf.keras.applications.VGG16(
|
||||
with strategy.scope():
|
||||
vgg_model = tf.keras.applications.VGG16(
|
||||
include_top=True,
|
||||
weights="imagenet",
|
||||
input_shape=(image_size, image_size, 3),
|
||||
pooling=None,
|
||||
)
|
||||
Y = vgg_model.layers[-2].output
|
||||
Y = Dense(units=num_classes, activation='softmax', kernel_initializer=GlorotUniform)(Y)
|
||||
vgg_model = Model(inputs=vgg_model.input, outputs=Y, name='VGG16')
|
||||
)
|
||||
Y = vgg_model.layers[-2].output
|
||||
Y = Dense(units=num_classes, activation='softmax', kernel_initializer=GlorotUniform)(Y)
|
||||
vgg_model = Model(inputs=vgg_model.input, outputs=Y, name='VGG16')
|
||||
vgg_model.summary()
|
||||
|
||||
|
||||
@@ -48,8 +52,9 @@ vgg_model.summary()
|
||||
MODEL COMPILE
|
||||
"""
|
||||
|
||||
optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)
|
||||
vgg_model.compile(
|
||||
with strategy.scope():
|
||||
optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)
|
||||
vgg_model.compile(
|
||||
optimizer=optimizer,
|
||||
loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
|
||||
metrics=[
|
||||
@@ -58,7 +63,7 @@ vgg_model.compile(
|
||||
tf.keras.metrics.SparseTopKCategoricalAccuracy(k=10, name='top-10-accuracy'),
|
||||
tf.keras.metrics.SparseTopKCategoricalAccuracy(k=100, name='top-100-accuracy'),
|
||||
]
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
"""
|
||||
|
||||
@@ -10,11 +10,14 @@ from data_generator import create_data_generators
|
||||
HYPERPARAMETERS
|
||||
"""
|
||||
|
||||
# Distribute training
|
||||
strategy = tf.distribute.MirroredStrategy()
|
||||
|
||||
# Input
|
||||
image_size = 224
|
||||
|
||||
# Hyper-parameters
|
||||
batch_size = 128
|
||||
batch_size = 128 * strategy.num_replicas_in_sync
|
||||
num_epochs = 25
|
||||
learning_rate = 0.0001
|
||||
num_classes = 8631
|
||||
@@ -31,14 +34,15 @@ train_gen, val_gen, test_gen = create_data_generators(target_size=image_size, ba
|
||||
MODEL
|
||||
"""
|
||||
|
||||
base_model = vit.vit_b32(
|
||||
with strategy.scope():
|
||||
base_model = vit.vit_b32(
|
||||
image_size=image_size,
|
||||
pretrained=True,
|
||||
include_top=False,
|
||||
pretrained_top=False,
|
||||
)
|
||||
y = tf.keras.layers.Dense(num_classes, activation='softmax')(base_model.output)
|
||||
vit_model = tf.keras.models.Model(inputs=base_model.input, outputs=y)
|
||||
)
|
||||
y = tf.keras.layers.Dense(num_classes, activation='softmax')(base_model.output)
|
||||
vit_model = tf.keras.models.Model(inputs=base_model.input, outputs=y)
|
||||
vit_model.summary()
|
||||
|
||||
|
||||
@@ -46,8 +50,9 @@ vit_model.summary()
|
||||
MODEL COMPILE
|
||||
"""
|
||||
|
||||
optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)
|
||||
vit_model.compile(
|
||||
with strategy.scope():
|
||||
optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)
|
||||
vit_model.compile(
|
||||
optimizer=optimizer,
|
||||
loss=keras.losses.SparseCategoricalCrossentropy(from_logits=False),
|
||||
metrics=[
|
||||
@@ -56,7 +61,7 @@ vit_model.compile(
|
||||
keras.metrics.SparseTopKCategoricalAccuracy(k=10, name="top-10-accuracy"),
|
||||
keras.metrics.SparseTopKCategoricalAccuracy(k=100, name="top-100-accuracy"),
|
||||
]
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user