import mxnet as mx
import mxnet.ndarray as nd
import mxnet.gluon as gluon
import mxnet.gluon.nn as nn
import mxnet.autograd as ag
import numpy as np
from rcnn.config import config
from rcnn.PY_OP import rpn_fpn_ohem, rpn_fpn_ohem2, rpn_fpn_ohem3

USE_DCN = False
MM = 1.0

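# Small gluon building blocks for the MnasNet backbone below: ConvBlock is a
# conv (padding 1) + BN + ReLU, Conv1x1 is a pointwise conv + BN with an
# optional ReLU (omitted when is_linear=True), and DWise is a depthwise
# (grouped) conv + BN + ReLU.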
def ConvBlock(channels, kernel_size, strides, **kwargs):
    out = nn.HybridSequential(**kwargs)
    with out.name_scope():
        out.add(
            nn.Conv2D(channels, kernel_size, strides=strides, padding=1, use_bias=False),
            nn.BatchNorm(scale=True),
            nn.Activation('relu')
        )
    return out


def Conv1x1(channels, is_linear=False, **kwargs):
    out = nn.HybridSequential(**kwargs)
    with out.name_scope():
        out.add(
            nn.Conv2D(channels, 1, padding=0, use_bias=False),
            nn.BatchNorm(scale=True)
        )
        if not is_linear:
            out.add(nn.Activation('relu'))
    return out


def DWise(channels, strides, kernel_size=3, **kwargs):
    out = nn.HybridSequential(**kwargs)
    with out.name_scope():
        out.add(
            nn.Conv2D(channels, kernel_size, strides=strides, padding=kernel_size // 2, groups=channels, use_bias=False),
            nn.BatchNorm(scale=True),
            nn.Activation('relu')
        )
    return out

class SepCONV(nn.HybridBlock):
    def __init__(self, inp, output, kernel_size, depth_multiplier=1, with_bn=True, **kwargs):
        super(SepCONV, self).__init__(**kwargs)
        with self.name_scope():
            self.net = nn.HybridSequential()
            cn = int(inp*depth_multiplier)

            if output is None:
                self.net.add(
                    nn.Conv2D(in_channels=inp, channels=cn, groups=inp, kernel_size=kernel_size, strides=(1,1),
                              padding=kernel_size // 2, use_bias=not with_bn)
                )
            else:
                self.net.add(
                    nn.Conv2D(in_channels=inp, channels=cn, groups=inp, kernel_size=kernel_size, strides=(1,1),
                              padding=kernel_size // 2, use_bias=False),
                    nn.BatchNorm(),
                    nn.Activation('relu'),
                    nn.Conv2D(in_channels=cn, channels=output, kernel_size=(1,1), strides=(1,1),
                              use_bias=not with_bn)
                )

            self.with_bn = with_bn
            self.act = nn.Activation('relu')
            if with_bn:
                self.bn = nn.BatchNorm()

    def hybrid_forward(self, F, x):
        x = self.net(x)
        if self.with_bn:
            x = self.bn(x)
        if self.act is not None:
            x = self.act(x)
        return x

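# ExpandedConv is the MobileNetV2-style inverted-residual bottleneck used by
# MnasNet: 1x1 expansion (factor t) -> depthwise conv -> linear 1x1 projection,
# with a skip connection when stride is 1 and input/output shapes match.
# ExpandedConvSequence stacks `repeats` of these blocks, striding only the first.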
class ExpandedConv(nn.HybridBlock):
    def __init__(self, inp, oup, t, strides, kernel=3, same_shape=True, **kwargs):
        super(ExpandedConv, self).__init__(**kwargs)

        self.same_shape = same_shape
        self.strides = strides
        with self.name_scope():
            self.bottleneck = nn.HybridSequential()
            self.bottleneck.add(
                Conv1x1(inp*t, prefix="expand_"),
                DWise(inp*t, self.strides, kernel, prefix="dwise_"),
                Conv1x1(oup, is_linear=True, prefix="linear_")
            )

    def hybrid_forward(self, F, x):
        out = self.bottleneck(x)
        if self.strides == 1 and self.same_shape:
            out = F.elemwise_add(out, x)
        return out


def ExpandedConvSequence(t, k, inp, oup, repeats, first_strides, **kwargs):
    seq = nn.HybridSequential(**kwargs)
    with seq.name_scope():
        seq.add(ExpandedConv(inp, oup, t, first_strides, k, same_shape=False))
        curr_inp = oup
        for i in range(1, repeats):
            seq.add(ExpandedConv(curr_inp, oup, t, 1))
            curr_inp = oup
    return seq

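# MnasNet backbone (width scaled by `multiplier`). `features` stacks the stem
# conv, a separable conv, the inverted-residual stages listed in
# interverted_residual_setting, and a final 1x1 conv to `last_channels`.
#
# Standalone usage sketch (illustrative only, not part of the detector pipeline):
#   net = Mnasnet(multiplier=1.0, prefix="")
#   net.initialize()
#   feat = net(mx.nd.zeros((1, 3, 224, 224)))   # NCHW input -> stride-32 feature map
# get_mnet_conv() below instead calls the block on an mx.symbol.Variable to obtain
# a symbolic graph and taps intermediate layers by name.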
class Mnasnet(nn.HybridBlock):
    def __init__(self, multiplier=1.0, **kwargs):
        super(Mnasnet, self).__init__(**kwargs)
        mm = multiplier

        self.first_oup = 32
        self.interverted_residual_setting = [
            # t (expansion), c (channels), n (repeats), s (first stride), k (kernel), prefix
            [3, int(24*mm), 3, 2, 3, "stage2_"],    # -> 56x56
            [3, int(40*mm), 3, 2, 5, "stage3_"],    # -> 28x28
            [6, int(80*mm), 3, 2, 5, "stage4_1_"],  # -> 14x14
            [6, int(96*mm), 2, 1, 3, "stage4_2_"],  # -> 14x14
            [6, int(192*mm), 4, 2, 5, "stage5_1_"], # -> 7x7
            [6, int(320*mm), 1, 1, 3, "stage5_2_"], # -> 7x7
        ]
        self.last_channels = 1280

        with self.name_scope():
            self.features = nn.HybridSequential()
            self.features.add(ConvBlock(self.first_oup, 3, 2, prefix="stage1_conv0_"))
            self.features.add(SepCONV(self.first_oup, 16, 3, prefix="stage1_sepconv0_"))
            inp = 16
            for i, (t, c, n, s, k, prefix) in enumerate(self.interverted_residual_setting):
                oup = c
                self.features.add(ExpandedConvSequence(t, k, inp, oup, n, s, prefix=prefix))
                inp = oup

            self.features.add(Conv1x1(self.last_channels, prefix="stage5_3_"))

    def hybrid_forward(self, F, x):
        x = self.features(x)
        return x

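# From here on the detector is assembled with the MXNet symbol API.
# conv_act_layer builds a conv (+ optional activation) with explicit weight/bias
# Variables so lr/wd multipliers can be attached; with dcn=True it uses a 3x3
# deformable convolution whose offsets are predicted by an extra conv.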
def conv_act_layer(from_layer, name, num_filter, kernel=(1,1), pad=(0,0),
                   stride=(1,1), act_type="relu", bias_wd_mult=0.0, dcn=False):
    weight = mx.symbol.Variable(name="{}_weight".format(name),
                                init=mx.init.Normal(0.01), attr={'__lr_mult__': '1.0'})
    bias = mx.symbol.Variable(name="{}_bias".format(name),
                              init=mx.init.Constant(0.0), attr={'__lr_mult__': '2.0', '__wd_mult__': str(bias_wd_mult)})
    if not dcn:
        conv = mx.symbol.Convolution(data=from_layer, kernel=kernel, pad=pad,
                                     stride=stride, num_filter=num_filter, name="{}".format(name), weight=weight, bias=bias)
    else:
        assert kernel[0] == 3 and kernel[1] == 3
        num_group = 1
        f = num_group*18
        offset_weight = mx.symbol.Variable(name="{}_offset_weight".format(name),
                                           init=mx.init.Constant(0.0), attr={'__lr_mult__': '1.0'})
        offset_bias = mx.symbol.Variable(name="{}_offset_bias".format(name),
                                         init=mx.init.Constant(0.0), attr={'__lr_mult__': '2.0', '__wd_mult__': str(bias_wd_mult)})
        conv_offset = mx.symbol.Convolution(name=name+'_offset', data=from_layer, weight=offset_weight, bias=offset_bias,
                                            num_filter=f, pad=(1, 1), kernel=(3, 3), stride=(1, 1))
        conv = mx.contrib.symbol.DeformableConvolution(name=name, data=from_layer, offset=conv_offset, weight=weight, bias=bias,
                                                        num_filter=num_filter, pad=(1,1), kernel=(3, 3), num_deformable_group=num_group, stride=(1, 1), no_bias=False)
    if len(act_type) > 0:
        relu = mx.symbol.Activation(data=conv, act_type=act_type,
                                    name="{}_{}".format(name, act_type))
    else:
        relu = conv
    return relu

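# SSH-style modules: ssh_context_module produces two context branches
# (effective 5x5 and 7x7 receptive fields built from stacked 3x3 convs), and
# ssh_detection_module concatenates them with a plain 3x3 branch and applies ReLU.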
def ssh_context_module(body, num_filters, name):
    conv_dimred = conv_act_layer(body, name+'_conv1',
                                 num_filters, kernel=(3, 3), pad=(1, 1), stride=(1, 1), act_type='relu', dcn=False)
    conv5x5 = conv_act_layer(conv_dimred, name+'_conv2',
                             num_filters, kernel=(3, 3), pad=(1, 1), stride=(1, 1), act_type='', dcn=USE_DCN)
    conv7x7_1 = conv_act_layer(conv_dimred, name+'_conv3_1',
                               num_filters, kernel=(3, 3), pad=(1, 1), stride=(1, 1), act_type='relu', dcn=False)
    conv7x7 = conv_act_layer(conv7x7_1, name+'_conv3_2',
                             num_filters, kernel=(3, 3), pad=(1, 1), stride=(1, 1), act_type='', dcn=USE_DCN)
    return (conv5x5, conv7x7)

def ssh_detection_module(body, num_filters, name):
    conv3x3 = conv_act_layer(body, name+'_conv1',
                             num_filters, kernel=(3, 3), pad=(1, 1), stride=(1, 1), act_type='', dcn=USE_DCN)
    conv5x5, conv7x7 = ssh_context_module(body, num_filters//2, name+'_context')
    ret = mx.sym.concat(*[conv3x3, conv5x5, conv7x7], dim=1, name=name+'_concat')
    ret = mx.symbol.Activation(data=ret, act_type='relu', name=name+'_concat_relu')
    return ret


def conv_bn(input, filter, ksize, stride, padding, act_type='relu', name=''):
    conv = mx.symbol.Convolution(data=input, kernel=(ksize, ksize), pad=(padding, padding),
                                 stride=(stride, stride), num_filter=filter, name=name+"_conv")
    ret = mx.sym.BatchNorm(data=conv, fix_gamma=False, eps=2e-5, momentum=0.9, name=name + '_bn')
    if act_type is not None:
        ret = mx.symbol.Activation(data=ret, act_type=act_type,
                                   name="{}_{}".format(name, act_type))
    return ret


def cpm(input, name):
    # residual
    branch1 = conv_bn(input, 1024, 1, 1, 0, act_type=None, name=name+"_branch1")
    branch2a = conv_bn(input, 256, 1, 1, 0, act_type='relu', name=name+"_branch2a")
    branch2b = conv_bn(branch2a, 256, 3, 1, 1, act_type='relu', name=name+"_branch2b")
    branch2c = conv_bn(branch2b, 1024, 1, 1, 0, act_type=None, name=name+"_branch2c")
    sum = branch1 + branch2c
    rescomb = mx.symbol.Activation(data=sum, act_type='relu', name="%s_relu2" % (name))

    ssh_out = ssh_detection_module(rescomb, 256, name=name+"_ssh")
    return ssh_out

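# Builds the three detection feature maps from the MnasNet backbone:
# stride 8 (c1 fused top-down with the upsampled stride-16 map), stride 16 (c2)
# and stride 32 (c3), each passed through an SSH detection module.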
def get_mnet_conv(data):
    mm = MM
    net = Mnasnet(mm, prefix="")
    body = net(data)

    all_layers = body.get_internals()
    #print(all_layers)
    c1 = all_layers['stage3_expandedconv2_elemwise_add0_output']
    c2 = all_layers['stage4_2_expandedconv1_elemwise_add0_output']
    #c3 = all_layers['stage5_3_relu0_fwd_output']
    c3 = all_layers['stage5_2_expandedconv0_linear_batchnorm0_fwd_output']

    F1 = int(256*mm)
    F2 = int(128*mm)
    _bwm = 1.0
    conv4_128 = conv_act_layer(c1, 'ssh_m1_red_conv',
                               F2, kernel=(1, 1), pad=(0, 0), stride=(1, 1), act_type='relu', bias_wd_mult=_bwm)
    conv5_128 = conv_act_layer(c2, 'ssh_m2_red_conv',
                               F2, kernel=(1, 1), pad=(0, 0), stride=(1, 1), act_type='relu', bias_wd_mult=_bwm)
    conv5_128_up = mx.symbol.Deconvolution(data=conv5_128, num_filter=F2, kernel=(4, 4), stride=(2, 2), pad=(1, 1),
                                           num_group=F2, no_bias=True, attr={'__lr_mult__': '0.0', '__wd_mult__': '0.0'},
                                           name='ssh_m2_red_upsampling')
    #conv5_128_up = mx.symbol.UpSampling(conv5_128, scale=2, sample_type='nearest', workspace=512, name='ssh_m2_red_up', num_args=1)
    conv4_128 = mx.symbol.Crop(*[conv4_128, conv5_128_up])
    #conv5_128_up = mx.symbol.Crop(*[conv5_128_up, conv4_128])

    conv_sum = conv4_128 + conv5_128_up
    #conv_sum = conv_1x1

    m1_conv = conv_act_layer(conv_sum, 'ssh_m1_conv',
                             F2, kernel=(3, 3), pad=(1, 1), stride=(1, 1), act_type='relu', bias_wd_mult=_bwm)
    m1 = ssh_detection_module(m1_conv, F2, 'ssh_m1_det')
    m2 = ssh_detection_module(c2, F1, 'ssh_m2_det')
    m3 = ssh_detection_module(c3, F1, 'ssh_m3_det')
    return {8: m1, 16: m2, 32: m3}

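# Per-stride training head. For the given FPN level this emits:
#   - a SoftmaxOutput over anchor classes (with optional "maxout" background/
#     foreground scoring controlled by config.USE_MAXOUT),
#   - a smooth-L1 bbox regression loss, and
#   - an optional smooth-L1 landmark loss (10 values per anchor),
# with bbox/landmark weights optionally re-weighted by the rpn_fpn_ohem3 custom op.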
def get_out(conv_fpn_feat, prefix, stride, landmark=False, lr_mult=1.0):
    A = config.NUM_ANCHORS
    ret_group = []
    num_anchors = config.RPN_ANCHOR_CFG[str(stride)]['NUM_ANCHORS']
    label = mx.symbol.Variable(name='%s_label_stride%d' % (prefix, stride))
    bbox_target = mx.symbol.Variable(name='%s_bbox_target_stride%d' % (prefix, stride))
    bbox_weight = mx.symbol.Variable(name='%s_bbox_weight_stride%d' % (prefix, stride))
    if landmark:
        landmark_target = mx.symbol.Variable(name='%s_landmark_target_stride%d' % (prefix, stride))
        landmark_weight = mx.symbol.Variable(name='%s_landmark_weight_stride%d' % (prefix, stride))
    rpn_relu = conv_fpn_feat[stride]
    maxout_stat = 0
    if config.USE_MAXOUT >= 1 and stride == config.RPN_FEAT_STRIDE[-1]:
        maxout_stat = 1
    if config.USE_MAXOUT >= 2 and stride != config.RPN_FEAT_STRIDE[-1]:
        maxout_stat = 2

    if maxout_stat == 0:
        rpn_cls_score = conv_act_layer(rpn_relu, '%s_rpn_cls_score_stride%d' % (prefix, stride), 2*num_anchors,
                                       kernel=(1, 1), pad=(0, 0), stride=(1, 1), act_type='')
    elif maxout_stat == 1:
        cls_list = []
        for a in range(num_anchors):
            rpn_cls_score_bg = conv_act_layer(rpn_relu, '%s_rpn_cls_score_stride%d_anchor%d_bg' % (prefix, stride, a), 3,
                                              kernel=(1, 1), pad=(0, 0), stride=(1, 1), act_type='')
            rpn_cls_score_bg = mx.sym.max(rpn_cls_score_bg, axis=1, keepdims=True)
            cls_list.append(rpn_cls_score_bg)
            rpn_cls_score_fg = conv_act_layer(rpn_relu, '%s_rpn_cls_score_stride%d_anchor%d_fg' % (prefix, stride, a), 1,
                                              kernel=(1, 1), pad=(0, 0), stride=(1, 1), act_type='')
            cls_list.append(rpn_cls_score_fg)
        rpn_cls_score = mx.sym.concat(*cls_list, dim=1, name='%s_rpn_cls_score_stride%d' % (prefix, stride))
    else:
        cls_list = []
        for a in range(num_anchors):
            rpn_cls_score_bg = conv_act_layer(rpn_relu, '%s_rpn_cls_score_stride%d_anchor%d_bg' % (prefix, stride, a), 1,
                                              kernel=(1, 1), pad=(0, 0), stride=(1, 1), act_type='')
            cls_list.append(rpn_cls_score_bg)
            rpn_cls_score_fg = conv_act_layer(rpn_relu, '%s_rpn_cls_score_stride%d_anchor%d_fg' % (prefix, stride, a), 3,
                                              kernel=(1, 1), pad=(0, 0), stride=(1, 1), act_type='')
            rpn_cls_score_fg = mx.sym.max(rpn_cls_score_fg, axis=1, keepdims=True)
            cls_list.append(rpn_cls_score_fg)
        rpn_cls_score = mx.sym.concat(*cls_list, dim=1, name='%s_rpn_cls_score_stride%d' % (prefix, stride))

    rpn_bbox_pred = conv_act_layer(rpn_relu, '%s_rpn_bbox_pred_stride%d' % (prefix, stride), 4*num_anchors,
                                   kernel=(1, 1), pad=(0, 0), stride=(1, 1), act_type='')

    # prepare rpn data
    rpn_cls_score_reshape = mx.symbol.Reshape(data=rpn_cls_score,
                                              shape=(0, 2, -1),
                                              name="%s_rpn_cls_score_reshape_stride%s" % (prefix, stride))
    rpn_bbox_pred_reshape = mx.symbol.Reshape(data=rpn_bbox_pred,
                                              shape=(0, 0, -1),
                                              name="%s_rpn_bbox_pred_reshape_stride%s" % (prefix, stride))
    if landmark:
        rpn_landmark_pred = conv_act_layer(rpn_relu, '%s_rpn_landmark_pred_stride%d' % (prefix, stride), 10*num_anchors,
                                           kernel=(1, 1), pad=(0, 0), stride=(1, 1), act_type='')
        rpn_landmark_pred_reshape = mx.symbol.Reshape(data=rpn_landmark_pred,
                                                      shape=(0, 0, -1),
                                                      name="%s_rpn_landmark_pred_reshape_stride%s" % (prefix, stride))

    if config.TRAIN.RPN_ENABLE_OHEM >= 2:
        label, anchor_weight = mx.sym.Custom(op_type='rpn_fpn_ohem3', stride=int(stride), network=config.network, dataset=config.dataset, prefix=prefix, cls_score=rpn_cls_score_reshape, labels=label)

        _bbox_weight = mx.sym.tile(anchor_weight, (1, 1, 4))
        _bbox_weight = _bbox_weight.reshape((0, -1, A * 4)).transpose((0, 2, 1))
        bbox_weight = mx.sym.elemwise_mul(bbox_weight, _bbox_weight, name='%s_bbox_weight_mul_stride%s' % (prefix, stride))

        if landmark:
            _landmark_weight = mx.sym.tile(anchor_weight, (1, 1, 10))
            _landmark_weight = _landmark_weight.reshape((0, -1, A * 10)).transpose((0, 2, 1))
            landmark_weight = mx.sym.elemwise_mul(landmark_weight, _landmark_weight, name='%s_landmark_weight_mul_stride%s' % (prefix, stride))
        #if not config.FACE_LANDMARK:
        #    label, bbox_weight = mx.sym.Custom(op_type='rpn_fpn_ohem', stride=int(stride), cls_score=rpn_cls_score_reshape, bbox_weight=bbox_weight, labels=label)
        #else:
        #    label, bbox_weight, landmark_weight = mx.sym.Custom(op_type='rpn_fpn_ohem2', stride=int(stride), cls_score=rpn_cls_score_reshape, bbox_weight=bbox_weight, landmark_weight=landmark_weight, labels=label)

    # cls loss
    rpn_cls_prob = mx.symbol.SoftmaxOutput(data=rpn_cls_score_reshape,
                                           label=label,
                                           multi_output=True,
                                           normalization='valid', use_ignore=True, ignore_label=-1,
                                           grad_scale=lr_mult,
                                           name='%s_rpn_cls_prob_stride%d' % (prefix, stride))
    ret_group.append(rpn_cls_prob)
    ret_group.append(mx.sym.BlockGrad(label))

    # bbox loss
    bbox_diff = rpn_bbox_pred_reshape - bbox_target
    bbox_diff = bbox_diff * bbox_weight
    rpn_bbox_loss_ = mx.symbol.smooth_l1(name='%s_rpn_bbox_loss_stride%d_' % (prefix, stride), scalar=3.0, data=bbox_diff)
    rpn_bbox_loss = mx.sym.MakeLoss(name='%s_rpn_bbox_loss_stride%d' % (prefix, stride), data=rpn_bbox_loss_, grad_scale=1.0*lr_mult / (config.TRAIN.RPN_BATCH_SIZE))
    ret_group.append(rpn_bbox_loss)
    ret_group.append(mx.sym.BlockGrad(bbox_weight))

    # landmark loss
    if landmark:
        landmark_diff = rpn_landmark_pred_reshape - landmark_target
        landmark_diff = landmark_diff * landmark_weight
        rpn_landmark_loss_ = mx.symbol.smooth_l1(name='%s_rpn_landmark_loss_stride%d_' % (prefix, stride), scalar=3.0, data=landmark_diff)
        rpn_landmark_loss = mx.sym.MakeLoss(name='%s_rpn_landmark_loss_stride%d' % (prefix, stride), data=rpn_landmark_loss_, grad_scale=0.5*lr_mult / (config.TRAIN.RPN_BATCH_SIZE))
        ret_group.append(rpn_landmark_loss)
        ret_group.append(mx.sym.BlockGrad(landmark_weight))
    return ret_group

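# Assembles the full training symbol: shared MnasNet/SSH features plus one group
# of outputs per stride in config.RPN_FEAT_STRIDE for the 'face' task, and
# optionally a second set for 'head' boxes when config.HEAD_BOX is set.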
def get_mnet_train():
    data = mx.symbol.Variable(name="data")

    # shared convolutional layers
    conv_fpn_feat = get_mnet_conv(data)
    ret_group = []
    for stride in config.RPN_FEAT_STRIDE:
        ret = get_out(conv_fpn_feat, 'face', stride, config.FACE_LANDMARK, lr_mult=1.0)
        ret_group += ret
        if config.HEAD_BOX:
            ret = get_out(conv_fpn_feat, 'head', stride, False, lr_mult=1.0)
            ret_group += ret

    return mx.sym.Group(ret_group)
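
# Usage sketch (illustrative; assumes rcnn.config has been populated for this
# network, e.g. by the training script, before the symbol is built):
#   sym = get_mnet_train()
#   print(sym.list_outputs())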