mirror of
https://gitcode.com/gh_mirrors/fe/FERPlus.git
synced 2025-12-30 05:22:26 +00:00
Fix FERPlus to match latest CNTK.
This commit is contained in:
10
src/train.py
10
src/train.py
@@ -90,17 +90,17 @@ def main(base_folder, training_mode='majority', model_name='VGG13', max_epochs =
|
||||
minibatch_size = 32
|
||||
|
||||
# Training config
|
||||
lr_schedule = [model.learning_rate]*20 + [model.learning_rate / 2.0]*20 + [model.learning_rate / 10.0]
|
||||
lr_per_minibatch = learning_rate_schedule(lr_schedule, epoch_size, UnitType.minibatch)
|
||||
momentum_time_constant = momentum_as_time_constant_schedule(-minibatch_size/np.log(0.9), epoch_size)
|
||||
lr_per_minibatch = [model.learning_rate]*20 + [model.learning_rate / 2.0]*20 + [model.learning_rate / 10.0]
|
||||
mm_time_constant = -minibatch_size/np.log(0.9)
|
||||
lr_schedule = learning_rate_schedule(lr_per_minibatch, unit=UnitType.minibatch, epoch_size=epoch_size)
|
||||
mm_schedule = momentum_as_time_constant_schedule(mm_time_constant)
|
||||
|
||||
# loss and error cost
|
||||
train_loss = cost_func(training_mode, pred, label_var)
|
||||
pe = classification_error(z, label_var)
|
||||
|
||||
# construct the trainer
|
||||
learner = momentum_sgd(z.parameters, lr = lr_per_minibatch, momentum = momentum_time_constant)
|
||||
#learner = sgd(pred.parameters, lr = lr_per_minibatch)
|
||||
learner = momentum_sgd(z.parameters, lr_schedule, mm_schedule)
|
||||
trainer = Trainer(z, train_loss, pe, learner)
|
||||
|
||||
# Get minibatches of images to train with and perform model training
|
||||
|
||||
Reference in New Issue
Block a user