mirror of
https://github.com/deepinsight/insightface.git
synced 2026-05-17 06:12:02 +00:00
294 lines
13 KiB
Python
294 lines
13 KiB
Python
import mxnet as mx
|
|
import numpy as np
|
|
|
|
|
|
|
|
def _Conv(data, num_filter, kernel, stride, pad, name, no_bias=False, workspace=256):
    """Create a named convolution with explicitly declared weight/bias variables.

    Two variables are declared up front: ``<name>_weight`` and ``<name>_bias``.
    The bias variable carries ``lr_mult=2.0`` and ``wd_mult=0.0``.  Every other
    argument is handed to ``mx.sym.Convolution`` unchanged.

    Returns the symbol produced by ``mx.sym.Convolution``.
    """
    kernel_var = mx.symbol.Variable(name + '_weight')
    # The bias variable is always declared and passed, even when no_bias is set.
    offset_var = mx.symbol.Variable(name + '_bias', lr_mult=2.0, wd_mult=0.0)
    return mx.sym.Convolution(
        data=data,
        weight=kernel_var,
        bias=offset_var,
        num_filter=num_filter,
        kernel=kernel,
        stride=stride,
        pad=pad,
        no_bias=no_bias,
        workspace=workspace,
        name=name,
    )
|
|
|
|
def Conv(**kwargs):
    """Keyword-only convolution wrapper that declares its own parameters.

    ``kwargs`` is forwarded as-is to ``mx.sym.Convolution``; its ``name``
    entry is additionally used to derive the variable names
    ``<name>_weight`` and ``<name>_bias``.  The bias variable is declared
    with ``lr_mult=2.0`` and ``wd_mult=0.0``.

    Returns the symbol produced by ``mx.sym.Convolution``.
    """
    prefix = kwargs.get('name')
    params = {}
    params['weight'] = mx.symbol.Variable(prefix + '_weight')
    params['bias'] = mx.symbol.Variable(prefix + '_bias', lr_mult=2.0, wd_mult=0.0)
    return mx.sym.Convolution(weight=params['weight'], bias=params['bias'], **kwargs)
|
|
|
|
|
|
def Act(data, name):
    """Return ``data`` passed through ``mx.sym.LeakyReLU`` with ``act_type='prelu'``.

    ``name`` becomes the name of the activation symbol.
    """
    return mx.sym.LeakyReLU(data=data, act_type='prelu', name=name)
|
|
|
|
def resnet_unit0(data, num_filter, name, workspace = 256):
    """Plain (shortcut-free) unit: 3x3 convolution, batch norm, then activation.

    The three symbols are named ``<name>_conv``, ``<name>_bn`` and
    ``<name>_relu``.  Returns the activation output.
    """
    momentum = 0.9
    conv_out = Conv(
        data=data, num_filter=num_filter, kernel=(3,3), stride=(1,1), pad=(1, 1),
        name=name+"_conv", workspace=workspace)
    normed = mx.sym.BatchNorm(
        data=conv_out, fix_gamma=False, eps=2e-5, momentum=momentum, name=name+'_bn')
    return Act(data=normed, name=name+'_relu')
|
|
|
|
def resnet_unit1(data, num_filter, name, dim_match=True, workspace = 256):
    """Residual unit laid out as conv1-bn1-relu1-conv2-bn2, optional add, relu2.

    Both convolutions are 3x3 / stride 1 / pad 1 with ``num_filter`` outputs.
    When ``dim_match`` is true the unit input is added to the ``bn2`` output
    before the final activation; otherwise no shortcut is applied.
    """
    momentum = 0.9

    def conv3x3(x, suffix):
        return Conv(data=x, num_filter=num_filter, kernel=(3,3), stride=(1,1), pad=(1, 1),
                    name=name+suffix, workspace=workspace)

    def batch_norm(x, suffix):
        return mx.sym.BatchNorm(data=x, fix_gamma=False, eps=2e-5, momentum=momentum,
                                name=name+suffix)

    out = conv3x3(data, "_conv1")
    out = batch_norm(out, '_bn1')
    out = Act(data=out, name=name+'_relu1')
    out = conv3x3(out, "_conv2")
    out = batch_norm(out, '_bn2')
    if dim_match:
        out = out+data
    return Act(data=out, name=name+'_relu2')
|
|
|
|
def resnet_unit2(data, num_filter, name, dim_match=True, workspace = 256):
    """Residual unit laid out as bn1-relu1-conv1-bn2-relu2-conv2, optional add.

    Both convolutions are 3x3 / stride 1 / pad 1 with ``num_filter`` outputs.
    When ``dim_match`` is true the unit input is added to the ``conv2``
    output; no activation follows the addition.
    """
    momentum = 0.9

    def bn_act(x, index):
        # BatchNorm followed by the activation sharing the same index.
        normed = mx.sym.BatchNorm(data=x, fix_gamma=False, eps=2e-5, momentum=momentum,
                                  name=name+'_bn%d' % index)
        return Act(data=normed, name=name+'_relu%d' % index)

    def conv3x3(x, index):
        return Conv(data=x, num_filter=num_filter, kernel=(3,3), stride=(1,1), pad=(1, 1),
                    name=name+"_conv%d" % index, workspace=workspace)

    out = conv3x3(bn_act(data, 1), 1)
    out = conv3x3(bn_act(out, 2), 2)
    if not dim_match:
        return out
    return out+data
|
|
|
|
def resnet_unit3(data, num_filter, name, dim_match=True, workspace = 256):
    """Residual unit laid out as bn1-conv1-bn2-relu1-conv2-bn3, optional add.

    Both convolutions are 3x3 / stride 1 / pad 1 with ``num_filter`` outputs.
    When ``dim_match`` is true the unit input is added to the ``bn3``
    output; no activation follows the addition.
    """
    momentum = 0.9
    bn_kwargs = dict(fix_gamma=False, eps=2e-5, momentum=momentum)
    conv_kwargs = dict(num_filter=num_filter, kernel=(3,3), stride=(1,1), pad=(1, 1))

    residual = mx.sym.BatchNorm(data=data, name=name+'_bn1', **bn_kwargs)
    residual = Conv(data=residual, name=name+"_conv1", workspace=workspace, **conv_kwargs)
    residual = mx.sym.BatchNorm(data=residual, name=name+'_bn2', **bn_kwargs)
    residual = Act(data=residual, name=name+'_relu1')
    residual = Conv(data=residual, name=name+"_conv2", workspace=workspace, **conv_kwargs)
    residual = mx.sym.BatchNorm(data=residual, name=name+'_bn3', **bn_kwargs)
    return residual+data if dim_match else residual
|
|
|
|
def resnet_unit4(data, num_filter, name, dim_match=True, workspace = 256):
    """Bottleneck residual unit: (bn-relu-conv) x 3 with an optional identity add.

    Layout: bn1-relu1-conv1(1x1) - bn2-relu2-conv2(3x3) - bn3-relu3-conv3(1x1).
    ``conv1`` and ``conv2`` use ``int(num_filter*0.25)`` filters and ``conv3``
    restores ``num_filter``.  When ``dim_match`` is true the unit input is
    added to the ``conv3`` output.

    NOTE: ``conv1`` previously consumed the raw ``data`` instead of the
    ``bn1``/``relu1`` output, which silently dropped those two layers from
    the graph.  It now consumes the activated tensor, consistent with
    ``resnet_unit2``.  Checkpoints trained with the old graph do not contain
    ``<name>_bn1`` / ``<name>_relu1`` parameters.
    """
    bn_mom = 0.9
    bottleneck_filters = int(num_filter*0.25)
    shortcut = data
    body = mx.sym.BatchNorm(data=data, fix_gamma=False, eps=2e-5, momentum=bn_mom, name=name+'_bn1')
    body = Act(data=body, name=name+'_relu1')
    # Feed the pre-activated tensor (not the raw input) into the first conv.
    body = Conv(data=body, num_filter=bottleneck_filters, kernel=(1,1), stride=(1,1), pad=(0,0),
                name=name+"_conv1", workspace=workspace)
    body = mx.sym.BatchNorm(data=body, fix_gamma=False, eps=2e-5, momentum=bn_mom, name=name+'_bn2')
    body = Act(data=body, name=name+'_relu2')
    body = Conv(data=body, num_filter=bottleneck_filters, kernel=(3,3), stride=(1,1), pad=(1, 1),
                name=name+"_conv2", workspace=workspace)
    body = mx.sym.BatchNorm(data=body, fix_gamma=False, eps=2e-5, momentum=bn_mom, name=name+'_bn3')
    body = Act(data=body, name=name+'_relu3')
    body = Conv(data=body, num_filter=num_filter, kernel=(1,1), stride=(1,1), pad=(0, 0),
                name=name+"_conv3", workspace=workspace)
    if dim_match:
        body = body+shortcut
    return body
|
|
|
|
def resnet_unit5(data, num_filter, name, dim_match=True, workspace = 256):
    """Grouped bottleneck unit: conv1(1x1)-bn1-relu1-conv2(3x3, 32 groups)-bn2-relu2-conv3(1x1)-bn3.

    ``conv1`` and ``conv2`` use ``int(num_filter*0.5)`` filters; ``conv3``
    outputs ``num_filter``.  When ``dim_match`` is true the unit input is
    added to the ``bn3`` output.  A final activation (``relu3``) is always
    applied.
    """
    momentum = 0.9
    inner_filters = int(num_filter*0.5)

    def batch_norm(x, suffix):
        return mx.sym.BatchNorm(data=x, fix_gamma=False, eps=2e-5, momentum=momentum,
                                name=name+suffix)

    out = Conv(data=data, num_filter=inner_filters, kernel=(1,1), stride=(1,1), pad=(0,0),
               name=name+"_conv1", workspace=workspace)
    out = Act(data=batch_norm(out, '_bn1'), name=name+'_relu1')
    out = Conv(data=out, num_filter=inner_filters, kernel=(3,3), stride=(1,1), pad=(1, 1),
               num_group=32, name=name+"_conv2", workspace=workspace)
    out = Act(data=batch_norm(out, '_bn2'), name=name+'_relu2')
    out = Conv(data=out, num_filter=num_filter, kernel=(1,1), stride=(1,1), pad=(0,0),
               name=name+"_conv3", workspace=workspace)
    out = batch_norm(out, '_bn3')
    if dim_match:
        out = out+data
    return Act(data=out, name=name+'_relu3')
|
|
|
|
def resnet_unit6(data, num_filter, name, dim_match=True, workspace = 256):
    """Grouped unit that widens internally: conv1(1x1)-bn1-relu1-conv2(3x3, 32 groups)-bn2-relu2-conv3(1x1)-bn3.

    ``conv1`` and ``conv2`` use ``num_filter*4`` filters; ``conv3`` outputs
    ``num_filter``.  When ``dim_match`` is true the unit input is added to
    the ``bn3`` output.  A final activation (``relu3``) is always applied.
    """
    momentum = 0.9
    wide_filters = num_filter*4

    def batch_norm(x, suffix):
        return mx.sym.BatchNorm(data=x, fix_gamma=False, eps=2e-5, momentum=momentum,
                                name=name+suffix)

    out = Conv(data=data, num_filter=wide_filters, kernel=(1,1), stride=(1,1), pad=(0,0),
               name=name+"_conv1", workspace=workspace)
    out = Act(data=batch_norm(out, '_bn1'), name=name+'_relu1')
    out = Conv(data=out, num_filter=wide_filters, kernel=(3,3), stride=(1,1), pad=(1, 1),
               num_group=32, name=name+"_conv2", workspace=workspace)
    out = Act(data=batch_norm(out, '_bn2'), name=name+'_relu2')
    out = Conv(data=out, num_filter=num_filter, kernel=(1,1), stride=(1,1), pad=(0,0),
               name=name+"_conv3", workspace=workspace)
    out = batch_norm(out, '_bn3')
    if dim_match:
        out = out+data
    return Act(data=out, name=name+'_relu3')
|
|
|
|
def resnet_unit7(data, num_filter, name, dim_match=True, workspace = 256):
    """Residual unit built from two grouped 3x3 convolutions (``num_group=32``).

    Layout: conv1-bn1-relu1-conv2-bn2, optional add of the unit input when
    ``dim_match`` is true, then a final activation (``relu2``).
    """
    momentum = 0.9

    def grouped_conv(x, suffix):
        return Conv(data=x, num_filter=num_filter, kernel=(3,3), stride=(1,1), pad=(1, 1),
                    num_group=32, name=name+suffix, workspace=workspace)

    def batch_norm(x, suffix):
        return mx.sym.BatchNorm(data=x, fix_gamma=False, eps=2e-5, momentum=momentum,
                                name=name+suffix)

    out = batch_norm(grouped_conv(data, "_conv1"), '_bn1')
    out = Act(data=out, name=name+'_relu1')
    out = batch_norm(grouped_conv(out, "_conv2"), '_bn2')
    if dim_match:
        out = out+data
    return Act(data=out, name=name+'_relu2')
|
|
|
|
def resnet_unit100(data, num_filter, name, dim_match=True, workspace = 256):
    """Residual unit laid out as bn1-conv1-bn2-relu1-conv2-bn3 plus a shortcut.

    The shortcut is always added.  With ``dim_match`` true it is the unit
    input itself; otherwise it is a 1x1 convolution named
    ``<name>_shortcut`` applied to the ``relu1`` output (not to the unit
    input).
    """
    momentum = 0.9

    def batch_norm(x, suffix):
        return mx.sym.BatchNorm(data=x, fix_gamma=False, eps=2e-5, momentum=momentum,
                                name=name+suffix)

    def conv3x3(x, suffix):
        return Conv(data=x, num_filter=num_filter, kernel=(3,3), stride=(1,1), pad=(1, 1),
                    name=name+suffix, workspace=workspace)

    hidden = batch_norm(data, '_bn1')
    hidden = conv3x3(hidden, "_conv1")
    hidden = batch_norm(hidden, '_bn2')
    activated = Act(data=hidden, name=name+'_relu1')
    residual = batch_norm(conv3x3(activated, "_conv2"), '_bn3')

    if dim_match:
        skip = data
    else:
        # Projection branch reads from the intermediate activation.
        skip = Conv(data=activated, num_filter=num_filter, kernel=(1,1), pad=(0,0),
                    name=name+"_shortcut", workspace=workspace)
    return residual+skip
|
|
|
|
def resnet_unit(rtype, data, num_filter, name, dim_match=True, workspace = 256):
    """Build one residual unit of the flavour selected by ``rtype``.

    Supported ``rtype`` values are 1-7 and 100, mapping to ``resnet_unit1``
    .. ``resnet_unit7`` and ``resnet_unit100``.  All other arguments are
    forwarded to the selected builder by keyword.

    Returns the symbol produced by the selected builder.

    Raises:
        AssertionError: if ``rtype`` is not a supported value.  It is raised
            explicitly (rather than via an ``assert`` statement) so the check
            still fires under ``python -O`` instead of silently returning
            ``None``.
    """
    supported = (1, 2, 3, 4, 5, 6, 7, 100)
    # Validate before touching any builder so a bad rtype fails immediately.
    if rtype not in supported:
        raise AssertionError(
            "unsupported rtype %r; expected one of %r" % (rtype, supported))

    builders = {
        1: resnet_unit1,
        2: resnet_unit2,
        3: resnet_unit3,
        4: resnet_unit4,
        5: resnet_unit5,
        6: resnet_unit6,
        7: resnet_unit7,
        100: resnet_unit100,
    }
    return builders[rtype](data=data, num_filter=num_filter, name=name,
                           dim_match=dim_match, workspace=workspace)
|
|
|
|
def resnet(data, units, filters, rtype, workspace):
    """Stack the stem unit and all stages of the network body.

    A stem ``resnet_unit0`` with 32 filters (named ``stage0_unit0``) is
    followed by one stage per entry of ``units``.  Stage ``i`` (1-based
    names) consists of:

    * a first unit ``stage<i>_unit0`` with ``filters[i-1]`` outputs: a
      ``resnet_unit`` of type ``rtype`` when ``rtype >= 100``, otherwise a
      plain ``resnet_unit0``;
    * a 2x2 / stride 2 max pooling named ``stage<i>_pool``;
    * ``units[i-1]`` residual units of type ``rtype`` with ``dim_match=True``.

    ``workspace`` is forwarded to every unit (it used to be accepted but
    ignored, so all units ran with their default of 256).

    Returns the output symbol of the last unit.
    """
    body = resnet_unit0(data=data, num_filter=32, name="stage%d_unit%d"%(0, 0), workspace=workspace)
    # enumerate()/range() instead of xrange() so the loop runs on Python 2 and 3.
    for i, num_units in enumerate(units):
        f = filters[i]
        stage = i+1
        # NOTE(review): dim_match is True only for the first stage even though
        # the stem emits 32 channels; confirm this is intended for rtype >= 100.
        dim_match = (i==0)
        if rtype>=100:
            body = resnet_unit(rtype=rtype, data=body, num_filter=f, name="stage%d_unit%d"%(stage, 0),
                               dim_match=dim_match, workspace=workspace)
        else:
            # Plain unit without a shortcut: the channel count changes here.
            body = resnet_unit0(data=body, num_filter=f, name="stage%d_unit%d"%(stage, 0),
                                workspace=workspace)
        body = mx.sym.Pooling(data=body, kernel=(2, 2), stride=(2,2), pad=(0,0), pool_type='max',
                              name="stage%d_pool"%(stage))
        for j in range(num_units):
            body = resnet_unit(rtype=rtype, data=body, num_filter=f, name="stage%d_unit%d"%(stage, j+1),
                               dim_match=True, workspace=workspace)

    return body
|
|
|
|
def get_symbol(num_classes, num_layers, conv_workspace=256):
    """Build the full network symbol for a given depth preset.

    ``num_layers`` selects a preset (units per stage, filters per stage and
    residual unit type); supported values are 27, 28, 29, 30, 31, 51, 52, 74,
    75 and 76.  The input variable ``data`` is normalised with a
    ``fix_gamma=True`` BatchNorm named ``bn_data`` when ``num_layers < 29``,
    and otherwise by subtracting 127.5 and multiplying by 0.0078125.

    The head applied here is: Dropout(p=0.4), a FullyConnected layer named
    ``pre_fc1`` with ``num_classes`` hidden units (parameters ``fc1_weight``
    / ``fc1_bias``), then a ``fix_gamma=True`` BatchNorm named ``fc1``.

    ``conv_workspace`` is passed to ``resnet`` as its ``workspace`` argument.

    Returns the ``fc1`` symbol.

    Raises:
        AssertionError: if ``num_layers`` is not a supported preset.  The
            check runs before any symbol is created and is raised explicitly
            so it is not stripped under ``python -O``.
    """
    base_units = [1,2,5,3] # all number of layers = sum(units)*2+len(units)+1
    deep_units = [2,3,15,3]
    base_filters = [64, 128, 256, 512]
    wide_filters = [64, 256, 512, 1024]
    # num_layers -> (units per stage, filters per stage, residual unit type)
    presets = {
        27: (base_units, base_filters, 1),
        28: (base_units, base_filters, 2),
        29: (base_units, base_filters, 3),
        30: (base_units, wide_filters, 3),
        31: (base_units, base_filters, 100),
        51: (deep_units, base_filters, 3),
        52: (deep_units, wide_filters, 7),
        74: (deep_units, base_filters, 4),
        75: (deep_units, base_filters, 5),
        76: (deep_units, [16, 32, 64, 128], 6),
    }
    if num_layers not in presets:
        raise AssertionError(
            "unsupported num_layers %r; expected one of %r" % (num_layers, sorted(presets)))
    units, filter_list, rtype = presets[num_layers]

    # Head selector; fixed to 1, so only the `ftype==1` branch below runs.
    # The other branches are kept as alternative heads.
    ftype = 1
    bn_mom = 0.9

    data = mx.symbol.Variable('data')
    if num_layers<29:
        data = mx.sym.BatchNorm(data=data, fix_gamma=True, eps=2e-5, momentum=bn_mom, name='bn_data')
    else:
        data = data-127.5
        data = data*0.0078125

    body = resnet(data = data, units = units, filters = filter_list, rtype=rtype, workspace = conv_workspace)

    _weight = mx.symbol.Variable("fc1_weight")
    _bias = mx.symbol.Variable("fc1_bias", lr_mult=2.0, wd_mult=0.0)
    if ftype==0:
        fc1 = mx.sym.FullyConnected(data=body, weight=_weight, bias=_bias, num_hidden=num_classes, name='fc1')
    elif ftype==1:
        body = mx.symbol.Dropout(data=body, p=0.4)
        fc1 = mx.sym.FullyConnected(data=body, weight=_weight, bias=_bias, num_hidden=num_classes, name='pre_fc1')
        fc1 = mx.sym.BatchNorm(data=fc1, fix_gamma=True, eps=2e-5, momentum=bn_mom, name='fc1')
    else:
        body = mx.sym.BatchNorm(data=body, fix_gamma=False, eps=2e-5, momentum=bn_mom, name='bn1')
        body = mx.sym.Activation(data=body, act_type='relu', name='relu1')
        body = mx.sym.Pooling(data=body, global_pool=True, kernel=(7, 7), pool_type='avg', name='pool1')
        body = mx.sym.Flatten(data=body)
        fc1 = mx.sym.FullyConnected(data=body, weight=_weight, bias=_bias, num_hidden=num_classes, name='fc1')

    return fc1
|
|
|
|
def init_weights(sym, data_shape_dict, num_layers):
    """Create initial values for selected arguments of ``sym``.

    Argument shapes are obtained from ``sym.infer_shape(**data_shape_dict)``
    and paired with ``sym.list_arguments()``.  Then:

    * every argument whose name ends with ``_weight``, contains ``_conv`` and
      does not contain ``_unit0_`` is set to
      ``mx.random.normal(0, 0.01, shape=...)``;
    * every argument whose name ends with ``_bias`` is set to zeros.

    Each initialised name is printed.  ``num_layers`` is accepted for
    interface compatibility and is not used.

    Returns:
        ``(arg_params, aux_params)`` where ``arg_params`` maps argument names
        to the arrays created above and ``aux_params`` is always ``None``.
    """
    arg_shape, _, _ = sym.infer_shape(**data_shape_dict)
    arg_shape_dict = dict(zip(sym.list_arguments(), arg_shape))
    arg_params = {}
    aux_params = None
    # .items() instead of .iteritems() so this runs on Python 2 and Python 3.
    for k, v in arg_shape_dict.items():
        if k.endswith('_weight') and '_conv' in k:
            # Convolutions inside the "*_unit0_*" units are left alone.
            if '_unit0_' not in k:
                arg_params[k] = mx.random.normal(0, 0.01, shape=v)
                print('init', k)
        if k.endswith('_bias'):
            arg_params[k] = mx.nd.zeros(shape=v)
            print('init', k)
    return arg_params, aux_params
|
|
|