mirror of
https://gitcode.com/gh_mirrors/fe/FERPlus.git
synced 2025-12-30 05:22:26 +00:00
Update the model for the update CNTK API.
This commit is contained in:
@@ -7,15 +7,7 @@ import os
|
||||
import sys
|
||||
import math
|
||||
import numpy as np
|
||||
|
||||
from cntk.blocks import default_options
|
||||
from cntk.layers import Convolution, MaxPooling, AveragePooling, Dropout, BatchNormalization, Dense
|
||||
from cntk.models import Sequential, LayerStack
|
||||
from cntk.utils import *
|
||||
from cntk.initializer import glorot_uniform
|
||||
from cntk import Trainer
|
||||
from cntk.learner import momentum_sgd, learning_rate_schedule, UnitType, momentum_as_time_constant_schedule
|
||||
from cntk.ops import input_variable, constant, parameter, relu
|
||||
import cntk as ct
|
||||
|
||||
def build_model(num_classes, model_name):
|
||||
'''
|
||||
@@ -56,26 +48,25 @@ class VGG13(object):
|
||||
self._model = self._create_model(num_classes)
|
||||
|
||||
def _create_model(self, num_classes):
|
||||
with default_options(activation=relu, init=glorot_uniform()):
|
||||
model = Sequential([
|
||||
LayerStack(2, lambda i: [
|
||||
Convolution((3,3), [64,128][i], pad=True),
|
||||
Convolution((3,3), [64,128][i], pad=True),
|
||||
MaxPooling((2,2), strides=(2,2)),
|
||||
Dropout(0.25)
|
||||
with ct.default_options(activation=ct.relu, init=ct.glorot_uniform()):
|
||||
model = ct.Sequential([
|
||||
ct.For(range(2), lambda i: [
|
||||
ct.Convolution((3,3), [64,128][i], pad=True),
|
||||
ct.Convolution((3,3), [64,128][i], pad=True),
|
||||
ct.MaxPooling((2,2), strides=(2,2)),
|
||||
ct.Dropout(0.25)
|
||||
]),
|
||||
LayerStack(2, lambda i: [
|
||||
Convolution((3,3), [256,256][i], pad=True),
|
||||
Convolution((3,3), [256,256][i], pad=True),
|
||||
Convolution((3,3), [256,256][i], pad=True),
|
||||
MaxPooling((2,2), strides=(2,2)),
|
||||
Dropout(0.25)
|
||||
ct.For(range(2), lambda i: [
|
||||
ct.Convolution((3,3), [256,256][i], pad=True),
|
||||
ct.Convolution((3,3), [256,256][i], pad=True),
|
||||
ct.Convolution((3,3), [256,256][i], pad=True),
|
||||
ct.MaxPooling((2,2), strides=(2,2)),
|
||||
ct.Dropout(0.25)
|
||||
]),
|
||||
LayerStack(2, lambda : [
|
||||
Dense(1024),
|
||||
Dropout(0.5)
|
||||
ct.For(range(2), lambda : [
|
||||
ct.Dense(1024),
|
||||
ct.Dropout(0.5)
|
||||
]),
|
||||
Dense(num_classes, activation=None)
|
||||
ct.Dense(num_classes, activation=None)
|
||||
])
|
||||
|
||||
return model
|
||||
|
||||
13
src/train.py
13
src/train.py
@@ -18,7 +18,6 @@ from ferplus import *
|
||||
from cntk import Trainer
|
||||
from cntk.learner import sgd, momentum_sgd, learning_rate_schedule, UnitType, momentum_as_time_constant_schedule
|
||||
from cntk.ops import cross_entropy_with_softmax, classification_error
|
||||
from cntk.ops import input_variable, constant, parameter, softmax
|
||||
import cntk as ct
|
||||
|
||||
emotion_table = {'neutral' : 0,
|
||||
@@ -67,8 +66,8 @@ def main(base_folder, training_mode='majority', model_name='VGG13', max_epochs =
|
||||
model = build_model(num_classes, model_name)
|
||||
|
||||
# set the input variables.
|
||||
input_var = input_variable((1, model.input_height, model.input_width), np.float32)
|
||||
label_var = input_variable((num_classes), np.float32)
|
||||
input_var = ct.input_variable((1, model.input_height, model.input_width), np.float32)
|
||||
label_var = ct.input_variable((num_classes), np.float32)
|
||||
|
||||
# read FER+ dataset.
|
||||
logging.info("Loading data...")
|
||||
@@ -84,7 +83,7 @@ def main(base_folder, training_mode='majority', model_name='VGG13', max_epochs =
|
||||
|
||||
# get the probalistic output of the model.
|
||||
z = model.model(input_var)
|
||||
pred = softmax(z)
|
||||
pred = ct.softmax(z)
|
||||
|
||||
epoch_size = train_data_reader.size()
|
||||
minibatch_size = 32
|
||||
@@ -101,7 +100,7 @@ def main(base_folder, training_mode='majority', model_name='VGG13', max_epochs =
|
||||
|
||||
# construct the trainer
|
||||
learner = momentum_sgd(z.parameters, lr_schedule, mm_schedule)
|
||||
trainer = Trainer(z, train_loss, pe, learner)
|
||||
trainer = Trainer(z, (train_loss, pe), learner)
|
||||
|
||||
# Get minibatches of images to train with and perform model training
|
||||
max_val_accuracy = 0.0
|
||||
@@ -127,8 +126,8 @@ def main(base_folder, training_mode='majority', model_name='VGG13', max_epochs =
|
||||
trainer.train_minibatch({input_var : images, label_var : labels})
|
||||
|
||||
# keep track of statistics.
|
||||
training_loss += get_train_loss(trainer) * current_batch_size
|
||||
training_accuracy += get_train_eval_criterion(trainer) * current_batch_size
|
||||
training_loss += trainer.previous_minibatch_loss_average * current_batch_size
|
||||
training_accuracy += trainer.previous_minibatch_evaluation_average * current_batch_size
|
||||
|
||||
training_accuracy /= train_data_reader.size()
|
||||
training_accuracy = 1.0 - training_accuracy
|
||||
|
||||
Reference in New Issue
Block a user