Added multi-gpu capabilities

This commit is contained in:
mrt
2023-05-31 11:40:35 +02:00
parent 9230013fb2
commit feca6ce642
3 changed files with 77 additions and 62 deletions

View File

@@ -11,11 +11,14 @@ from data_generator import create_data_generators
HYPERPARAMETERS HYPERPARAMETERS
""" """
# Distribute training
strategy = tf.distribute.MirroredStrategy()
# Input # Input
image_size = 224 image_size = 224
# Hyper-parameters # Hyper-parameters
batch_size = 128 batch_size = 128 * strategy.num_replicas_in_sync
num_epochs = 25 num_epochs = 25
learning_rate = 0.0001 learning_rate = 0.0001
num_classes = 8631 num_classes = 8631
@@ -32,6 +35,7 @@ train_gen, val_gen, test_gen = create_data_generators(target_size=image_size, ba
MODEL MODEL
""" """
with strategy.scope():
resnet_model = tf.keras.applications.ResNet50( resnet_model = tf.keras.applications.ResNet50(
include_top=False, include_top=False,
weights="imagenet", weights="imagenet",
@@ -48,6 +52,7 @@ resnet_model.summary()
MODEL COMPILE MODEL COMPILE
""" """
with strategy.scope():
optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate) optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)
resnet_model.compile( resnet_model.compile(
optimizer=optimizer, optimizer=optimizer,

View File

@@ -11,11 +11,14 @@ from data_generator import create_data_generators
HYPERPARAMETERS HYPERPARAMETERS
""" """
# Distribute training
strategy = tf.distribute.MirroredStrategy()
# Input # Input
image_size = 224 image_size = 224
# Hyper-parameters # Hyper-parameters
batch_size = 128 batch_size = 128 * strategy.num_replicas_in_sync
num_epochs = 25 num_epochs = 25
learning_rate = 0.0001 learning_rate = 0.0001
num_classes = 8631 num_classes = 8631
@@ -32,6 +35,7 @@ train_gen, val_gen, test_gen = create_data_generators(target_size=image_size, ba
MODEL MODEL
""" """
with strategy.scope():
vgg_model = tf.keras.applications.VGG16( vgg_model = tf.keras.applications.VGG16(
include_top=True, include_top=True,
weights="imagenet", weights="imagenet",
@@ -48,6 +52,7 @@ vgg_model.summary()
MODEL COMPILE MODEL COMPILE
""" """
with strategy.scope():
optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate) optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)
vgg_model.compile( vgg_model.compile(
optimizer=optimizer, optimizer=optimizer,

View File

@@ -10,11 +10,14 @@ from data_generator import create_data_generators
HYPERPARAMETERS HYPERPARAMETERS
""" """
# Distribute training
strategy = tf.distribute.MirroredStrategy()
# Input # Input
image_size = 224 image_size = 224
# Hyper-parameters # Hyper-parameters
batch_size = 128 batch_size = 128 * strategy.num_replicas_in_sync
num_epochs = 25 num_epochs = 25
learning_rate = 0.0001 learning_rate = 0.0001
num_classes = 8631 num_classes = 8631
@@ -31,6 +34,7 @@ train_gen, val_gen, test_gen = create_data_generators(target_size=image_size, ba
MODEL MODEL
""" """
with strategy.scope():
base_model = vit.vit_b32( base_model = vit.vit_b32(
image_size=image_size, image_size=image_size,
pretrained=True, pretrained=True,
@@ -46,6 +50,7 @@ vit_model.summary()
MODEL COMPILE MODEL COMPILE
""" """
with strategy.scope():
optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate) optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)
vit_model.compile( vit_model.compile(
optimizer=optimizer, optimizer=optimizer,