Fix FERPlus to match latest CNTK.

This commit is contained in:
Emad Barsoum
2017-02-08 14:29:17 -08:00
parent 45edecd6e9
commit 12632bdc6e

View File

@@ -90,17 +90,17 @@ def main(base_folder, training_mode='majority', model_name='VGG13', max_epochs =
minibatch_size = 32
# Training config
lr_schedule = [model.learning_rate]*20 + [model.learning_rate / 2.0]*20 + [model.learning_rate / 10.0]
lr_per_minibatch = learning_rate_schedule(lr_schedule, epoch_size, UnitType.minibatch)
momentum_time_constant = momentum_as_time_constant_schedule(-minibatch_size/np.log(0.9), epoch_size)
lr_per_minibatch = [model.learning_rate]*20 + [model.learning_rate / 2.0]*20 + [model.learning_rate / 10.0]
mm_time_constant = -minibatch_size/np.log(0.9)
lr_schedule = learning_rate_schedule(lr_per_minibatch, unit=UnitType.minibatch, epoch_size=epoch_size)
mm_schedule = momentum_as_time_constant_schedule(mm_time_constant)
# loss and error cost
train_loss = cost_func(training_mode, pred, label_var)
pe = classification_error(z, label_var)
# construct the trainer
learner = momentum_sgd(z.parameters, lr = lr_per_minibatch, momentum = momentum_time_constant)
#learner = sgd(pred.parameters, lr = lr_per_minibatch)
learner = momentum_sgd(z.parameters, lr_schedule, mm_schedule)
trainer = Trainer(z, train_loss, pe, learner)
# Get minibatches of images to train with and perform model training