mirror of
https://github.com/deepinsight/insightface.git
synced 2026-05-19 07:27:52 +00:00
add get_head
This commit is contained in:
@@ -1,7 +1,7 @@
|
||||
import os
|
||||
|
||||
# MXNET_CPU_WORKER_NTHREADS must be greater than 1 for custom op to work on CPU
|
||||
os.environ['MXNET_CPU_WORKER_NTHREADS'] = '2'
|
||||
#os.environ['MXNET_CPU_WORKER_NTHREADS'] = '2'
|
||||
import mxnet as mx
|
||||
|
||||
|
||||
|
||||
@@ -177,16 +177,15 @@ def get_symbol(num_classes = 1000, num_layers=92, **kwargs):
|
||||
|
||||
## define Dual Path Network
|
||||
data = mx.symbol.Variable(name="data")
|
||||
data = data-127.5
|
||||
data = data*0.0078125
|
||||
|
||||
# conv1
|
||||
if version_input==0:
|
||||
conv1_x_1 = Conv(data=data, num_filter=128, kernel=(7, 7), name='conv1_x_1', pad=(3,3), stride=(2,2))
|
||||
else:
|
||||
conv1_x_1 = Conv(data=data, num_filter=128, kernel=(3, 3), name='conv1_x_1', pad=(3,3), stride=(1,1))
|
||||
conv1_x_1 = BN_AC(conv1_x_1, name='conv1_x_1__relu-sp')
|
||||
conv1_x_x = mx.symbol.Pooling(data=conv1_x_1, pool_type="max", kernel=(3, 3), pad=(1,1), stride=(2,2), name="pool1")
|
||||
#data = data-127.5
|
||||
#data = data*0.0078125
|
||||
#if version_input==0:
|
||||
# conv1_x_1 = Conv(data=data, num_filter=128, kernel=(7, 7), name='conv1_x_1', pad=(3,3), stride=(2,2))
|
||||
#else:
|
||||
# conv1_x_1 = Conv(data=data, num_filter=128, kernel=(3, 3), name='conv1_x_1', pad=(3,3), stride=(1,1))
|
||||
#conv1_x_1 = BN_AC(conv1_x_1, name='conv1_x_1__relu-sp')
|
||||
#conv1_x_x = mx.symbol.Pooling(data=conv1_x_1, pool_type="max", kernel=(3, 3), pad=(1,1), stride=(2,2), name="pool1")
|
||||
conv1_x_x = symbol_utils.get_head(data, version_input, 128)
|
||||
|
||||
# conv2
|
||||
bw = 256
|
||||
|
||||
@@ -44,3 +44,68 @@ def get_fc1(last_conv, num_classes, fc_type):
|
||||
fc1 = Act(data=fc1, act_type='relu', name='fc1_relu')
|
||||
return fc1
|
||||
|
||||
def residual_unit_v3(data, num_filter, stride, dim_match, name, **kwargs):
|
||||
|
||||
"""Return ResNet Unit symbol for building ResNet
|
||||
Parameters
|
||||
----------
|
||||
data : str
|
||||
Input data
|
||||
num_filter : int
|
||||
Number of output channels
|
||||
bnf : int
|
||||
Bottle neck channels factor with regard to num_filter
|
||||
stride : tuple
|
||||
Stride used in convolution
|
||||
dim_match : Boolean
|
||||
True means channel number between input and output is the same, otherwise means differ
|
||||
name : str
|
||||
Base name of the operators
|
||||
workspace : int
|
||||
Workspace used in convolution operator
|
||||
"""
|
||||
bn_mom = kwargs.get('bn_mom', 0.9)
|
||||
workspace = kwargs.get('workspace', 256)
|
||||
memonger = kwargs.get('memonger', False)
|
||||
#print('in unit3')
|
||||
bn1 = mx.sym.BatchNorm(data=data, fix_gamma=False, eps=2e-5, momentum=bn_mom, name=name + '_bn1')
|
||||
conv1 = Conv(data=bn1, num_filter=num_filter, kernel=(3,3), stride=(1,1), pad=(1,1),
|
||||
no_bias=True, workspace=workspace, name=name + '_conv1')
|
||||
bn2 = mx.sym.BatchNorm(data=conv1, fix_gamma=False, eps=2e-5, momentum=bn_mom, name=name + '_bn2')
|
||||
act1 = Act(data=bn2, act_type='relu', name=name + '_relu1')
|
||||
conv2 = Conv(data=act1, num_filter=num_filter, kernel=(3,3), stride=stride, pad=(1,1),
|
||||
no_bias=True, workspace=workspace, name=name + '_conv2')
|
||||
bn3 = mx.sym.BatchNorm(data=conv2, fix_gamma=False, eps=2e-5, momentum=bn_mom, name=name + '_bn3')
|
||||
|
||||
if dim_match:
|
||||
shortcut = data
|
||||
else:
|
||||
conv1sc = Conv(data=data, num_filter=num_filter, kernel=(1,1), stride=stride, no_bias=True,
|
||||
workspace=workspace, name=name+'_conv1sc')
|
||||
shortcut = mx.sym.BatchNorm(data=conv1sc, fix_gamma=False, momentum=bn_mom, eps=2e-5, name=name + '_sc')
|
||||
if memonger:
|
||||
shortcut._set_attr(mirror_stage='True')
|
||||
return bn3 + shortcut
|
||||
|
||||
|
||||
def get_head(data, version_input, num_filter):
|
||||
data = data-127.5
|
||||
data = data*0.0078125
|
||||
#data = mx.sym.BatchNorm(data=data, fix_gamma=True, eps=2e-5, momentum=bn_mom, name='bn_data')
|
||||
if version_input==0:
|
||||
body = Conv(data=data, num_filter=num_filter, kernel=(7, 7), stride=(2,2), pad=(3, 3),
|
||||
no_bias=True, name="conv0", workspace=workspace)
|
||||
body = mx.sym.BatchNorm(data=body, fix_gamma=False, eps=2e-5, momentum=bn_mom, name='bn0')
|
||||
body = Act(data=body, act_type='relu', name='relu0')
|
||||
body = mx.sym.Pooling(data=body, kernel=(3, 3), stride=(2,2), pad=(1,1), pool_type='max')
|
||||
else:
|
||||
body = data
|
||||
body = Conv(data=body, num_filter=num_filter, kernel=(3,3), stride=(1,1), pad=(1, 1),
|
||||
no_bias=True, name="conv0", workspace=workspace)
|
||||
body = mx.sym.BatchNorm(data=body, fix_gamma=False, eps=2e-5, momentum=bn_mom, name='bn0')
|
||||
body = Act(data=body, act_type='relu', name='relu0')
|
||||
body = residual_unit_v3(body, num_filter, (2, 2), False,
|
||||
name='head', **kwargs)
|
||||
return body
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user