Files
insightface/src/marginalnet.py
nttstar 7b42654dc1 tiny
2017-11-22 00:06:40 +08:00

294 lines
13 KiB
Python

import mxnet as mx
import numpy as np
def _Conv(data, num_filter, kernel, stride, pad, name, no_bias=False, workspace=256):
_weight = mx.symbol.Variable(name+'_weight')
_bias = mx.symbol.Variable(name+'_bias', lr_mult=2.0, wd_mult=0.0)
body = mx.sym.Convolution(data=data, weight = _weight, bias = _bias, num_filter=num_filter, kernel=kernel, stride=stride, pad=pad, no_bias = no_bias, workspace = workspace, name = name)
return body
def Conv(**kwargs):
name = kwargs.get('name')
_weight = mx.symbol.Variable(name+'_weight')
_bias = mx.symbol.Variable(name+'_bias', lr_mult=2.0, wd_mult=0.0)
body = mx.sym.Convolution(weight = _weight, bias = _bias, **kwargs)
return body
def Act(data, name):
body = mx.sym.LeakyReLU(data = data, act_type='prelu', name = name)
return body
def resnet_unit0(data, num_filter, name, workspace = 256):
bn_mom = 0.9
body = Conv(data=data, num_filter=num_filter, kernel=(3,3), stride=(1,1), pad=(1, 1),
name=name+"_conv", workspace=workspace)
body = mx.sym.BatchNorm(data=body, fix_gamma=False, eps=2e-5, momentum=bn_mom, name=name+'_bn')
body = Act(data=body, name=name+'_relu')
return body
def resnet_unit1(data, num_filter, name, dim_match=True, workspace = 256):
bn_mom = 0.9
shortcut = data
body = Conv(data=data, num_filter=num_filter, kernel=(3,3), stride=(1,1), pad=(1, 1),
name=name+"_conv1", workspace=workspace)
body = mx.sym.BatchNorm(data=body, fix_gamma=False, eps=2e-5, momentum=bn_mom, name=name+'_bn1')
body = Act(data=body, name=name+'_relu1')
body = Conv(data=body, num_filter=num_filter, kernel=(3,3), stride=(1,1), pad=(1, 1),
name=name+"_conv2", workspace=workspace)
body = mx.sym.BatchNorm(data=body, fix_gamma=False, eps=2e-5, momentum=bn_mom, name=name+'_bn2')
if dim_match:
body = body+shortcut
body = Act(data=body, name=name+'_relu2')
return body
def resnet_unit2(data, num_filter, name, dim_match=True, workspace = 256):
bn_mom = 0.9
shortcut = data
body = mx.sym.BatchNorm(data=data, fix_gamma=False, eps=2e-5, momentum=bn_mom, name=name+'_bn1')
body = Act(data=body, name=name+'_relu1')
body = Conv(data=body, num_filter=num_filter, kernel=(3,3), stride=(1,1), pad=(1, 1),
name=name+"_conv1", workspace=workspace)
#body = mx.symbol.Dropout(data=body, p=0.2)
body = mx.sym.BatchNorm(data=body, fix_gamma=False, eps=2e-5, momentum=bn_mom, name=name+'_bn2')
body = Act(data=body, name=name+'_relu2')
body = Conv(data=body, num_filter=num_filter, kernel=(3,3), stride=(1,1), pad=(1, 1),
name=name+"_conv2", workspace=workspace)
if dim_match:
body = body+shortcut
return body
def resnet_unit3(data, num_filter, name, dim_match=True, workspace = 256):
bn_mom = 0.9
shortcut = data
body = mx.sym.BatchNorm(data=data, fix_gamma=False, eps=2e-5, momentum=bn_mom, name=name+'_bn1')
body = Conv(data=body, num_filter=num_filter, kernel=(3,3), stride=(1,1), pad=(1, 1),
name=name+"_conv1", workspace=workspace)
body = mx.sym.BatchNorm(data=body, fix_gamma=False, eps=2e-5, momentum=bn_mom, name=name+'_bn2')
body = Act(data=body, name=name+'_relu1')
body = Conv(data=body, num_filter=num_filter, kernel=(3,3), stride=(1,1), pad=(1, 1),
name=name+"_conv2", workspace=workspace)
body = mx.sym.BatchNorm(data=body, fix_gamma=False, eps=2e-5, momentum=bn_mom, name=name+'_bn3')
if dim_match:
body = body+shortcut
return body
def resnet_unit4(data, num_filter, name, dim_match=True, workspace = 256):
bn_mom = 0.9
shortcut = data
body = mx.sym.BatchNorm(data=data, fix_gamma=False, eps=2e-5, momentum=bn_mom, name=name+'_bn1')
body = Act(data=body, name=name+'_relu1')
body = Conv(data=data, num_filter=int(num_filter*0.25), kernel=(1,1), stride=(1,1), pad=(0,0),
name=name+"_conv1", workspace=workspace)
body = mx.sym.BatchNorm(data=body, fix_gamma=False, eps=2e-5, momentum=bn_mom, name=name+'_bn2')
body = Act(data=body, name=name+'_relu2')
body = Conv(data=body, num_filter=int(num_filter*0.25), kernel=(3,3), stride=(1,1), pad=(1, 1),
name=name+"_conv2", workspace=workspace)
body = mx.sym.BatchNorm(data=body, fix_gamma=False, eps=2e-5, momentum=bn_mom, name=name+'_bn3')
body = Act(data=body, name=name+'_relu3')
body = Conv(data=body, num_filter=num_filter, kernel=(1,1), stride=(1,1), pad=(0, 0),
name=name+"_conv3", workspace=workspace)
if dim_match:
body = body+shortcut
return body
def resnet_unit5(data, num_filter, name, dim_match=True, workspace = 256):
bn_mom = 0.9
shortcut = data
body = Conv(data=data, num_filter=int(num_filter*0.5), kernel=(1,1), stride=(1,1), pad=(0,0),
name=name+"_conv1", workspace=workspace)
body = mx.sym.BatchNorm(data=body, fix_gamma=False, eps=2e-5, momentum=bn_mom, name=name+'_bn1')
body = Act(data=body, name=name+'_relu1')
body = Conv(data=body, num_filter=int(num_filter*0.5), kernel=(3,3), stride=(1,1), pad=(1, 1), num_group=32,
name=name+"_conv2", workspace=workspace)
body = mx.sym.BatchNorm(data=body, fix_gamma=False, eps=2e-5, momentum=bn_mom, name=name+'_bn2')
body = Act(data=body, name=name+'_relu2')
body = Conv(data=body, num_filter=num_filter, kernel=(1,1), stride=(1,1), pad=(0,0),
name=name+"_conv3", workspace=workspace)
body = mx.sym.BatchNorm(data=body, fix_gamma=False, eps=2e-5, momentum=bn_mom, name=name+'_bn3')
if dim_match:
body = body+shortcut
body = Act(data=body, name=name+'_relu3')
return body
def resnet_unit6(data, num_filter, name, dim_match=True, workspace = 256):
bn_mom = 0.9
shortcut = data
body = Conv(data=data, num_filter=num_filter*4, kernel=(1,1), stride=(1,1), pad=(0,0),
name=name+"_conv1", workspace=workspace)
body = mx.sym.BatchNorm(data=body, fix_gamma=False, eps=2e-5, momentum=bn_mom, name=name+'_bn1')
body = Act(data=body, name=name+'_relu1')
body = Conv(data=body, num_filter=num_filter*4, kernel=(3,3), stride=(1,1), pad=(1, 1), num_group=32,
name=name+"_conv2", workspace=workspace)
body = mx.sym.BatchNorm(data=body, fix_gamma=False, eps=2e-5, momentum=bn_mom, name=name+'_bn2')
body = Act(data=body, name=name+'_relu2')
body = Conv(data=body, num_filter=num_filter, kernel=(1,1), stride=(1,1), pad=(0,0),
name=name+"_conv3", workspace=workspace)
body = mx.sym.BatchNorm(data=body, fix_gamma=False, eps=2e-5, momentum=bn_mom, name=name+'_bn3')
if dim_match:
body = body+shortcut
body = Act(data=body, name=name+'_relu3')
return body
def resnet_unit7(data, num_filter, name, dim_match=True, workspace = 256):
bn_mom = 0.9
shortcut = data
body = Conv(data=data, num_filter=num_filter, kernel=(3,3), stride=(1,1), pad=(1, 1), num_group=32,
name=name+"_conv1", workspace=workspace)
body = mx.sym.BatchNorm(data=body, fix_gamma=False, eps=2e-5, momentum=bn_mom, name=name+'_bn1')
body = Act(data=body, name=name+'_relu1')
body = Conv(data=body, num_filter=num_filter, kernel=(3,3), stride=(1,1), pad=(1, 1), num_group=32,
name=name+"_conv2", workspace=workspace)
body = mx.sym.BatchNorm(data=body, fix_gamma=False, eps=2e-5, momentum=bn_mom, name=name+'_bn2')
if dim_match:
body = body+shortcut
body = Act(data=body, name=name+'_relu2')
return body
def resnet_unit100(data, num_filter, name, dim_match=True, workspace = 256):
bn_mom = 0.9
body = mx.sym.BatchNorm(data=data, fix_gamma=False, eps=2e-5, momentum=bn_mom, name=name+'_bn1')
body = Conv(data=body, num_filter=num_filter, kernel=(3,3), stride=(1,1), pad=(1, 1),
name=name+"_conv1", workspace=workspace)
body = mx.sym.BatchNorm(data=body, fix_gamma=False, eps=2e-5, momentum=bn_mom, name=name+'_bn2')
act = Act(data=body, name=name+'_relu1')
body = Conv(data=act, num_filter=num_filter, kernel=(3,3), stride=(1,1), pad=(1, 1),
name=name+"_conv2", workspace=workspace)
body = mx.sym.BatchNorm(data=body, fix_gamma=False, eps=2e-5, momentum=bn_mom, name=name+'_bn3')
if not dim_match:
shortcut = Conv(data=act, num_filter=num_filter, kernel=(1,1), pad=(0,0), name=name+"_shortcut", workspace=workspace)
else:
shortcut = data
body = body+shortcut
return body
def resnet_unit(rtype, data, num_filter, name, dim_match=True, workspace = 256):
if rtype==1:
return resnet_unit1(data=data, num_filter=num_filter, name=name, dim_match=dim_match, workspace=workspace)
elif rtype==2:
return resnet_unit2(data=data, num_filter=num_filter, name=name, dim_match=dim_match, workspace=workspace)
elif rtype==3:
return resnet_unit3(data=data, num_filter=num_filter, name=name, dim_match=dim_match, workspace=workspace)
elif rtype==4:
return resnet_unit4(data=data, num_filter=num_filter, name=name, dim_match=dim_match, workspace=workspace)
elif rtype==5:
return resnet_unit5(data=data, num_filter=num_filter, name=name, dim_match=dim_match, workspace=workspace)
elif rtype==6:
return resnet_unit6(data=data, num_filter=num_filter, name=name, dim_match=dim_match, workspace=workspace)
elif rtype==7:
return resnet_unit7(data=data, num_filter=num_filter, name=name, dim_match=dim_match, workspace=workspace)
elif rtype==100:
return resnet_unit100(data=data, num_filter=num_filter, name=name, dim_match=dim_match, workspace=workspace)
else:
assert(False)
def resnet(data, units, filters, rtype, workspace):
body = resnet_unit0(data=data, num_filter=32, name="stage%d_unit%d"%(0, 0))
for i in xrange(len(units)):
f = filters[i]
dim_match = False
if i==0:
dim_match = True
if rtype>=100:
body = resnet_unit(rtype=rtype, data=body, num_filter=f, name="stage%d_unit%d"%(i+1, 0), dim_match=dim_match) # do not connect to last layer, dim not match
else:
body = resnet_unit0(data=body, num_filter=f, name="stage%d_unit%d"%(i+1, 0)) # do not connect to last layer, dim not match
body = mx.sym.Pooling(data=body, kernel=(2, 2), stride=(2,2), pad=(0,0), pool_type='max', name="stage%d_pool"%(i+1))
for j in xrange(units[i]):
body = resnet_unit(rtype=rtype, data=body, num_filter=f, name="stage%d_unit%d"%(i+1, j+1), dim_match=True)
return body
def get_symbol(num_classes, num_layers, conv_workspace=256):
data = mx.symbol.Variable('data')
bn_mom = 0.9
if num_layers<29:
data = mx.sym.BatchNorm(data=data, fix_gamma=True, eps=2e-5, momentum=bn_mom, name='bn_data')
else:
data = data-127.5
data = data*0.0078125
units = [1,2,5,3] # all number of layers = sum(units)*2+len(units)+1
filter_list = [64, 128, 256, 512]
rtype = 1
ftype = 1
if num_layers==27:
rtype = 1
elif num_layers==28:
rtype = 2
elif num_layers==29:
rtype = 3
#use_last_bn = False
#use_dropout = False
elif num_layers==30:
filter_list = [64, 256, 512, 1024]
rtype = 3
elif num_layers==31:
rtype = 100
elif num_layers==51:
units = [2,3,15,3]
rtype = 3
elif num_layers==52:
filter_list = [64, 256, 512, 1024]
units = [2,3,15,3]
rtype = 7
elif num_layers==74:
units = [2,3,15,3]
rtype = 4
elif num_layers==75:
units = [2,3,15,3]
rtype = 5
elif num_layers==76:
filter_list = [16, 32, 64, 128]
units = [2,3,15,3]
rtype = 6
else:
assert(False)
body = resnet(data = data, units = units, filters = filter_list, rtype=rtype, workspace = conv_workspace)
_weight = mx.symbol.Variable("fc1_weight")
_bias = mx.symbol.Variable("fc1_bias", lr_mult=2.0, wd_mult=0.0)
if ftype==0:
fc1 = mx.sym.FullyConnected(data=body, weight=_weight, bias=_bias, num_hidden=num_classes, name='fc1')
elif ftype==1:
body = mx.symbol.Dropout(data=body, p=0.4)
fc1 = mx.sym.FullyConnected(data=body, weight=_weight, bias=_bias, num_hidden=num_classes, name='pre_fc1')
fc1 = mx.sym.BatchNorm(data=fc1, fix_gamma=True, eps=2e-5, momentum=bn_mom, name='fc1')
else:
body = mx.sym.BatchNorm(data=body, fix_gamma=False, eps=2e-5, momentum=bn_mom, name='bn1')
body = mx.sym.Activation(data=body, act_type='relu', name='relu1')
body = mx.sym.Pooling(data=body, global_pool=True, kernel=(7, 7), pool_type='avg', name='pool1')
body = mx.sym.Flatten(data=body)
fc1 = mx.sym.FullyConnected(data=body, weight=_weight, bias=_bias, num_hidden=num_classes, name='fc1')
return fc1
def init_weights(sym, data_shape_dict, num_layers):
arg_name = sym.list_arguments()
aux_name = sym.list_auxiliary_states()
arg_shape, aaa, aux_shape = sym.infer_shape(**data_shape_dict)
#print(data_shape_dict)
#print(arg_name)
#print(arg_shape)
arg_params = {}
aux_params = None
#print(aaa)
#print(aux_shape)
arg_shape_dict = dict(zip(arg_name, arg_shape))
aux_shape_dict = dict(zip(aux_name, aux_shape))
#print(aux_shape)
#print(aux_params)
#print(arg_shape_dict)
for k,v in arg_shape_dict.iteritems():
#print('find', k)
if k.endswith('_weight') and k.find('_conv')>=0:
if not k.find('_unit0_')>=0:
arg_params[k] = mx.random.normal(0, 0.01, shape=v)
print('init', k)
if k.endswith('_bias'):
arg_params[k] = mx.nd.zeros(shape=v)
print('init', k)
return arg_params, aux_params