Added multi-gpu capabilities

This commit is contained in:
mrt
2023-05-31 11:40:35 +02:00
parent 9230013fb2
commit feca6ce642
3 changed files with 77 additions and 62 deletions

View File

@@ -11,11 +11,14 @@ from data_generator import create_data_generators
HYPERPARAMETERS HYPERPARAMETERS
""" """
# Distribute training
strategy = tf.distribute.MirroredStrategy()
# Input # Input
image_size = 224 image_size = 224
# Hyper-parameters # Hyper-parameters
batch_size = 128 batch_size = 128 * strategy.num_replicas_in_sync
num_epochs = 25 num_epochs = 25
learning_rate = 0.0001 learning_rate = 0.0001
num_classes = 8631 num_classes = 8631
@@ -32,15 +35,16 @@ train_gen, val_gen, test_gen = create_data_generators(target_size=image_size, ba
MODEL MODEL
""" """
resnet_model = tf.keras.applications.ResNet50( with strategy.scope():
resnet_model = tf.keras.applications.ResNet50(
include_top=False, include_top=False,
weights="imagenet", weights="imagenet",
input_shape=(image_size, image_size, 3), input_shape=(image_size, image_size, 3),
pooling=None, pooling=None,
) )
Y = GlobalAvgPool2D()(resnet_model.output) Y = GlobalAvgPool2D()(resnet_model.output)
Y = Dense(units=num_classes, activation='softmax', kernel_initializer=GlorotUniform())(Y) Y = Dense(units=num_classes, activation='softmax', kernel_initializer=GlorotUniform())(Y)
resnet_model = Model(inputs=resnet_model.input, outputs=Y, name='ResNet50') resnet_model = Model(inputs=resnet_model.input, outputs=Y, name='ResNet50')
resnet_model.summary() resnet_model.summary()
@@ -48,8 +52,9 @@ resnet_model.summary()
MODEL COMPILE MODEL COMPILE
""" """
optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate) with strategy.scope():
resnet_model.compile( optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)
resnet_model.compile(
optimizer=optimizer, optimizer=optimizer,
loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False), loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
metrics=[ metrics=[
@@ -58,7 +63,7 @@ resnet_model.compile(
tf.keras.metrics.SparseTopKCategoricalAccuracy(k=10, name='top-10-accuracy'), tf.keras.metrics.SparseTopKCategoricalAccuracy(k=10, name='top-10-accuracy'),
tf.keras.metrics.SparseTopKCategoricalAccuracy(k=100, name='top-100-accuracy'), tf.keras.metrics.SparseTopKCategoricalAccuracy(k=100, name='top-100-accuracy'),
] ]
) )
""" """

View File

@@ -11,11 +11,14 @@ from data_generator import create_data_generators
HYPERPARAMETERS HYPERPARAMETERS
""" """
# Distribute training
strategy = tf.distribute.MirroredStrategy()
# Input # Input
image_size = 224 image_size = 224
# Hyper-parameters # Hyper-parameters
batch_size = 128 batch_size = 128 * strategy.num_replicas_in_sync
num_epochs = 25 num_epochs = 25
learning_rate = 0.0001 learning_rate = 0.0001
num_classes = 8631 num_classes = 8631
@@ -32,15 +35,16 @@ train_gen, val_gen, test_gen = create_data_generators(target_size=image_size, ba
MODEL MODEL
""" """
vgg_model = tf.keras.applications.VGG16( with strategy.scope():
vgg_model = tf.keras.applications.VGG16(
include_top=True, include_top=True,
weights="imagenet", weights="imagenet",
input_shape=(image_size, image_size, 3), input_shape=(image_size, image_size, 3),
pooling=None, pooling=None,
) )
Y = vgg_model.layers[-2].output Y = vgg_model.layers[-2].output
Y = Dense(units=num_classes, activation='softmax', kernel_initializer=GlorotUniform)(Y) Y = Dense(units=num_classes, activation='softmax', kernel_initializer=GlorotUniform)(Y)
vgg_model = Model(inputs=vgg_model.input, outputs=Y, name='VGG16') vgg_model = Model(inputs=vgg_model.input, outputs=Y, name='VGG16')
vgg_model.summary() vgg_model.summary()
@@ -48,8 +52,9 @@ vgg_model.summary()
MODEL COMPILE MODEL COMPILE
""" """
optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate) with strategy.scope():
vgg_model.compile( optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)
vgg_model.compile(
optimizer=optimizer, optimizer=optimizer,
loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False), loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
metrics=[ metrics=[
@@ -58,7 +63,7 @@ vgg_model.compile(
tf.keras.metrics.SparseTopKCategoricalAccuracy(k=10, name='top-10-accuracy'), tf.keras.metrics.SparseTopKCategoricalAccuracy(k=10, name='top-10-accuracy'),
tf.keras.metrics.SparseTopKCategoricalAccuracy(k=100, name='top-100-accuracy'), tf.keras.metrics.SparseTopKCategoricalAccuracy(k=100, name='top-100-accuracy'),
] ]
) )
""" """

View File

@@ -10,11 +10,14 @@ from data_generator import create_data_generators
HYPERPARAMETERS HYPERPARAMETERS
""" """
# Distribute training
strategy = tf.distribute.MirroredStrategy()
# Input # Input
image_size = 224 image_size = 224
# Hyper-parameters # Hyper-parameters
batch_size = 128 batch_size = 128 * strategy.num_replicas_in_sync
num_epochs = 25 num_epochs = 25
learning_rate = 0.0001 learning_rate = 0.0001
num_classes = 8631 num_classes = 8631
@@ -31,14 +34,15 @@ train_gen, val_gen, test_gen = create_data_generators(target_size=image_size, ba
MODEL MODEL
""" """
base_model = vit.vit_b32( with strategy.scope():
base_model = vit.vit_b32(
image_size=image_size, image_size=image_size,
pretrained=True, pretrained=True,
include_top=False, include_top=False,
pretrained_top=False, pretrained_top=False,
) )
y = tf.keras.layers.Dense(num_classes, activation='softmax')(base_model.output) y = tf.keras.layers.Dense(num_classes, activation='softmax')(base_model.output)
vit_model = tf.keras.models.Model(inputs=base_model.input, outputs=y) vit_model = tf.keras.models.Model(inputs=base_model.input, outputs=y)
vit_model.summary() vit_model.summary()
@@ -46,8 +50,9 @@ vit_model.summary()
MODEL COMPILE MODEL COMPILE
""" """
optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate) with strategy.scope():
vit_model.compile( optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)
vit_model.compile(
optimizer=optimizer, optimizer=optimizer,
loss=keras.losses.SparseCategoricalCrossentropy(from_logits=False), loss=keras.losses.SparseCategoricalCrossentropy(from_logits=False),
metrics=[ metrics=[
@@ -56,7 +61,7 @@ vit_model.compile(
keras.metrics.SparseTopKCategoricalAccuracy(k=10, name="top-10-accuracy"), keras.metrics.SparseTopKCategoricalAccuracy(k=10, name="top-10-accuracy"),
keras.metrics.SparseTopKCategoricalAccuracy(k=100, name="top-100-accuracy"), keras.metrics.SparseTopKCategoricalAccuracy(k=100, name="top-100-accuracy"),
] ]
) )
""" """