mirror of
https://github.com/deepinsight/insightface.git
synced 2026-05-18 14:55:42 +00:00
for test
This commit is contained in:
@@ -291,7 +291,7 @@ def residual_unit_v3(data, num_filter, stride, dim_match, name, bottle_neck=True
|
||||
return bn3 + shortcut
|
||||
|
||||
def residual_unit(data, num_filter, stride, dim_match, name, bottle_neck=True, use_se=False, bn_mom=0.9, workspace=256, memonger=False):
|
||||
return residual_unit_v3(data, num_filter, stride, dim_match, name, bottle_neck, use_se, bn_mom, workspace, memonger)
|
||||
return residual_unit_v1(data, num_filter, stride, dim_match, name, bottle_neck, use_se, bn_mom, workspace, memonger)
|
||||
|
||||
def resnet(units, num_stages, filter_list, num_classes, bottle_neck=True, use_se=False, bn_mom=0.9, workspace=256, memonger=False):
|
||||
"""Return ResNet symbol of
|
||||
@@ -310,6 +310,8 @@ def resnet(units, num_stages, filter_list, num_classes, bottle_neck=True, use_se
|
||||
workspace : int
|
||||
Workspace used in convolution operator
|
||||
"""
|
||||
L_type = False
|
||||
fc_type = 'B'#'A'-'E'
|
||||
num_unit = len(units)
|
||||
assert(num_unit == num_stages)
|
||||
data = mx.sym.Variable(name='data')
|
||||
@@ -317,41 +319,48 @@ def resnet(units, num_stages, filter_list, num_classes, bottle_neck=True, use_se
|
||||
data = data-127.5
|
||||
data = data*0.0078125
|
||||
#data = mx.sym.BatchNorm(data=data, fix_gamma=True, eps=2e-5, momentum=bn_mom, name='bn_data')
|
||||
#body = Conv(data=data, num_filter=filter_list[0], kernel=(7, 7), stride=(2,2), pad=(3, 3),
|
||||
# no_bias=True, name="conv0", workspace=workspace)
|
||||
body = Conv(data=data, num_filter=filter_list[0], kernel=(3,3), stride=(1,1), pad=(1, 1),
|
||||
no_bias=True, name="conv0", workspace=workspace)
|
||||
if not L_type:
|
||||
body = Conv(data=data, num_filter=filter_list[0], kernel=(7, 7), stride=(2,2), pad=(3, 3),
|
||||
no_bias=True, name="conv0", workspace=workspace)
|
||||
else:
|
||||
body = Conv(data=data, num_filter=filter_list[0], kernel=(3,3), stride=(1,1), pad=(1, 1),
|
||||
no_bias=True, name="conv0", workspace=workspace)
|
||||
body = mx.sym.BatchNorm(data=body, fix_gamma=False, eps=2e-5, momentum=bn_mom, name='bn0')
|
||||
body = Act(data=body, act_type='relu', name='relu0')
|
||||
#body = mx.sym.Pooling(data=body, kernel=(3, 3), stride=(2,2), pad=(1,1), pool_type='max')
|
||||
body = mx.sym.Pooling(data=body, kernel=(3, 3), stride=(2,2), pad=(1,1), pool_type='max')
|
||||
|
||||
for i in range(num_stages):
|
||||
#body = residual_unit(body, filter_list[i+1], (1 if i==0 else 2, 1 if i==0 else 2), False,
|
||||
# name='stage%d_unit%d' % (i + 1, 1), bottle_neck=bottle_neck, use_se=use_se,workspace=workspace,
|
||||
# memonger=memonger)
|
||||
body = residual_unit(body, filter_list[i+1], (2, 2), False,
|
||||
name='stage%d_unit%d' % (i + 1, 1), bottle_neck=bottle_neck, use_se=use_se,workspace=workspace,
|
||||
memonger=memonger)
|
||||
body = residual_unit(body, filter_list[i+1], (1 if i==0 else 2, 1 if i==0 else 2), False,
|
||||
name='stage%d_unit%d' % (i + 1, 1), bottle_neck=bottle_neck, use_se=use_se,workspace=workspace,
|
||||
memonger=memonger)
|
||||
#body = residual_unit(body, filter_list[i+1], (2, 2), False,
|
||||
# name='stage%d_unit%d' % (i + 1, 1), bottle_neck=bottle_neck, use_se=use_se,workspace=workspace,
|
||||
# memonger=memonger)
|
||||
for j in range(units[i]-1):
|
||||
body = residual_unit(body, filter_list[i+1], (1,1), True, name='stage%d_unit%d' % (i+1, j+2),
|
||||
bottle_neck=bottle_neck, use_se=use_se, workspace=workspace, memonger=memonger)
|
||||
#body = residual_unit(body, filter_list[i+1], (2,2), False, name='stage%d_unit%d' % (i+1, units[i]),
|
||||
# bottle_neck=bottle_neck, use_se=use_se, workspace=workspace, memonger=memonger)
|
||||
fc_type = 1#0 or 1
|
||||
|
||||
if fc_type==0:
|
||||
|
||||
if fc_type=='E':
|
||||
body = mx.symbol.Dropout(data=body, p=0.4)
|
||||
fc1 = mx.sym.FullyConnected(data=body, num_hidden=num_classes, name='pre_fc1')
|
||||
fc1 = mx.sym.BatchNorm(data=fc1, fix_gamma=True, eps=2e-5, momentum=bn_mom, name='fc1')
|
||||
else:
|
||||
bn1 = mx.sym.BatchNorm(data=body, fix_gamma=False, eps=2e-5, momentum=bn_mom, name='bn1')
|
||||
relu1 = Act(data=bn1, act_type='relu', name='relu1')
|
||||
# Although kernel is not used here when global_pool=True, we should put one
|
||||
pool1 = mx.sym.Pooling(data=relu1, global_pool=True, kernel=(7, 7), pool_type='avg', name='pool1')
|
||||
flat = mx.sym.Flatten(data=pool1)
|
||||
#fc1 = mx.sym.FullyConnected(data=flat, num_hidden=num_classes, name='fc1')
|
||||
fc1 = mx.sym.FullyConnected(data=body, num_hidden=num_classes, name='pre_fc1')
|
||||
fc1 = mx.sym.BatchNorm(data=fc1, fix_gamma=True, eps=2e-5, momentum=bn_mom, name='fc1')
|
||||
else:
|
||||
body = mx.symbol.Dropout(data=body, p=0.4)
|
||||
fc1 = mx.sym.FullyConnected(data=body, num_hidden=num_classes, name='pre_fc1')
|
||||
fc1 = mx.sym.BatchNorm(data=fc1, fix_gamma=True, eps=2e-5, momentum=bn_mom, name='fc1')
|
||||
if fc_type=='A':
|
||||
fc1 = flat
|
||||
else:
|
||||
#B
|
||||
fc1 = mx.sym.FullyConnected(data=flat, num_hidden=num_classes, name='pre_fc1')
|
||||
if fc_type=='C':
|
||||
fc1 = mx.sym.BatchNorm(data=fc1, fix_gamma=True, eps=2e-5, momentum=bn_mom, name='fc1')
|
||||
elif fc_type=='D':
|
||||
fc1 = mx.sym.BatchNorm(data=fc1, fix_gamma=True, eps=2e-5, momentum=bn_mom, name='fc1')
|
||||
fc1 = Act(data=fc1, act_type='relu', name='fc1_relu')
|
||||
return fc1
|
||||
|
||||
def get_symbol(num_classes, num_layers, conv_workspace=256, **kwargs):
|
||||
|
||||
Reference in New Issue
Block a user