Update the model for the update CNTK API.

This commit is contained in:
Emad Barsoum
2017-02-17 22:53:11 -08:00
parent 12632bdc6e
commit b3290259a9
2 changed files with 24 additions and 34 deletions

View File

@@ -7,15 +7,7 @@ import os
import sys
import math
import numpy as np
from cntk.blocks import default_options
from cntk.layers import Convolution, MaxPooling, AveragePooling, Dropout, BatchNormalization, Dense
from cntk.models import Sequential, LayerStack
from cntk.utils import *
from cntk.initializer import glorot_uniform
from cntk import Trainer
from cntk.learner import momentum_sgd, learning_rate_schedule, UnitType, momentum_as_time_constant_schedule
from cntk.ops import input_variable, constant, parameter, relu
import cntk as ct
def build_model(num_classes, model_name):
'''
@@ -56,26 +48,25 @@ class VGG13(object):
self._model = self._create_model(num_classes)
def _create_model(self, num_classes):
with default_options(activation=relu, init=glorot_uniform()):
model = Sequential([
LayerStack(2, lambda i: [
Convolution((3,3), [64,128][i], pad=True),
Convolution((3,3), [64,128][i], pad=True),
MaxPooling((2,2), strides=(2,2)),
Dropout(0.25)
with ct.default_options(activation=ct.relu, init=ct.glorot_uniform()):
model = ct.Sequential([
ct.For(range(2), lambda i: [
ct.Convolution((3,3), [64,128][i], pad=True),
ct.Convolution((3,3), [64,128][i], pad=True),
ct.MaxPooling((2,2), strides=(2,2)),
ct.Dropout(0.25)
]),
LayerStack(2, lambda i: [
Convolution((3,3), [256,256][i], pad=True),
Convolution((3,3), [256,256][i], pad=True),
Convolution((3,3), [256,256][i], pad=True),
MaxPooling((2,2), strides=(2,2)),
Dropout(0.25)
ct.For(range(2), lambda i: [
ct.Convolution((3,3), [256,256][i], pad=True),
ct.Convolution((3,3), [256,256][i], pad=True),
ct.Convolution((3,3), [256,256][i], pad=True),
ct.MaxPooling((2,2), strides=(2,2)),
ct.Dropout(0.25)
]),
LayerStack(2, lambda : [
Dense(1024),
Dropout(0.5)
ct.For(range(2), lambda : [
ct.Dense(1024),
ct.Dropout(0.5)
]),
Dense(num_classes, activation=None)
ct.Dense(num_classes, activation=None)
])
return model

View File

@@ -18,7 +18,6 @@ from ferplus import *
from cntk import Trainer
from cntk.learner import sgd, momentum_sgd, learning_rate_schedule, UnitType, momentum_as_time_constant_schedule
from cntk.ops import cross_entropy_with_softmax, classification_error
from cntk.ops import input_variable, constant, parameter, softmax
import cntk as ct
emotion_table = {'neutral' : 0,
@@ -67,8 +66,8 @@ def main(base_folder, training_mode='majority', model_name='VGG13', max_epochs =
model = build_model(num_classes, model_name)
# set the input variables.
input_var = input_variable((1, model.input_height, model.input_width), np.float32)
label_var = input_variable((num_classes), np.float32)
input_var = ct.input_variable((1, model.input_height, model.input_width), np.float32)
label_var = ct.input_variable((num_classes), np.float32)
# read FER+ dataset.
logging.info("Loading data...")
@@ -84,7 +83,7 @@ def main(base_folder, training_mode='majority', model_name='VGG13', max_epochs =
# get the probalistic output of the model.
z = model.model(input_var)
pred = softmax(z)
pred = ct.softmax(z)
epoch_size = train_data_reader.size()
minibatch_size = 32
@@ -101,7 +100,7 @@ def main(base_folder, training_mode='majority', model_name='VGG13', max_epochs =
# construct the trainer
learner = momentum_sgd(z.parameters, lr_schedule, mm_schedule)
trainer = Trainer(z, train_loss, pe, learner)
trainer = Trainer(z, (train_loss, pe), learner)
# Get minibatches of images to train with and perform model training
max_val_accuracy = 0.0
@@ -127,8 +126,8 @@ def main(base_folder, training_mode='majority', model_name='VGG13', max_epochs =
trainer.train_minibatch({input_var : images, label_var : labels})
# keep track of statistics.
training_loss += get_train_loss(trainer) * current_batch_size
training_accuracy += get_train_eval_criterion(trainer) * current_batch_size
training_loss += trainer.previous_minibatch_loss_average * current_batch_size
training_accuracy += trainer.previous_minibatch_evaluation_average * current_batch_size
training_accuracy /= train_data_reader.size()
training_accuracy = 1.0 - training_accuracy