mirror of
https://gitcode.com/gh_mirrors/fe/FERPlus.git
synced 2025-12-30 05:22:26 +00:00
Add name for each layer.
This commit is contained in:
@@ -51,22 +51,23 @@ class VGG13(object):
|
||||
with ct.default_options(activation=ct.relu, init=ct.glorot_uniform()):
|
||||
model = ct.Sequential([
|
||||
ct.For(range(2), lambda i: [
|
||||
ct.Convolution((3,3), [64,128][i], pad=True),
|
||||
ct.Convolution((3,3), [64,128][i], pad=True),
|
||||
ct.MaxPooling((2,2), strides=(2,2)),
|
||||
ct.Dropout(0.25)
|
||||
ct.Convolution((3,3), [64,128][i], pad=True, name='conv{}-1'.format(i+1)),
|
||||
ct.Convolution((3,3), [64,128][i], pad=True, name='conv{}-2'.format(i+1)),
|
||||
ct.MaxPooling((2,2), strides=(2,2), name='pool{}-1'.format(i+1)),
|
||||
ct.Dropout(0.25, name='drop{}-1'.format(i+1))
|
||||
]),
|
||||
ct.For(range(2), lambda i: [
|
||||
ct.Convolution((3,3), [256,256][i], pad=True),
|
||||
ct.Convolution((3,3), [256,256][i], pad=True),
|
||||
ct.Convolution((3,3), [256,256][i], pad=True),
|
||||
ct.MaxPooling((2,2), strides=(2,2)),
|
||||
ct.Dropout(0.25)
|
||||
ct.Convolution((3,3), [256,256][i], pad=True, name='conv{}-1'.format(i+3)),
|
||||
ct.Convolution((3,3), [256,256][i], pad=True, name='conv{}-2'.format(i+3)),
|
||||
ct.Convolution((3,3), [256,256][i], pad=True, name='conv{}-3'.format(i+3)),
|
||||
ct.MaxPooling((2,2), strides=(2,2), name='pool{}-1'.format(i+3)),
|
||||
ct.Dropout(0.25, name='drop{}-1'.format(i+3))
|
||||
]),
|
||||
ct.For(range(2), lambda : [
|
||||
ct.Dense(1024),
|
||||
ct.Dropout(0.5)
|
||||
ct.For(range(2), lambda i: [
|
||||
ct.Dense(1024, activation=None, name='fc{}'.format(i+5)),
|
||||
ct.Activation(activation=ct.relu, name='relu{}'.format(i+5)),
|
||||
ct.Dropout(0.5, name='drop{}'.format(i+5))
|
||||
]),
|
||||
ct.Dense(num_classes, activation=None)
|
||||
ct.Dense(num_classes, activation=None, name='output')
|
||||
])
|
||||
return model
|
||||
|
||||
Reference in New Issue
Block a user