insightface/detection/retinaface/rcnn/symbol/symbol_mnet.py.bak
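
# RetinaFace detection symbol: a Gluon-defined MNASNet-style backbone whose
# feature maps are fused FPN-style, passed through SSH detection modules and
# fed into per-stride RPN heads built with the MXNet Symbol API.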
import mxnet as mx
import mxnet.ndarray as nd
import mxnet.gluon as gluon
import mxnet.gluon.nn as nn
import mxnet.autograd as ag
import numpy as np
from rcnn.config import config
from rcnn.PY_OP import rpn_fpn_ohem, rpn_fpn_ohem2, rpn_fpn_ohem3
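
# USE_DCN toggles deformable convolutions inside the SSH modules;
# MM is the width multiplier applied to the backbone channel counts.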
USE_DCN = False
MM = 1.0
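
# Convolution block used for the stem: Conv2D (padding=1) -> BatchNorm -> ReLU.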
def ConvBlock(channels, kernel_size, strides, **kwargs):
    out = nn.HybridSequential(**kwargs)
    with out.name_scope():
        out.add(
            nn.Conv2D(channels, kernel_size, strides=strides, padding=1, use_bias=False),
            nn.BatchNorm(scale=True),
            nn.Activation('relu')
        )
    return out
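
# 1x1 pointwise convolution -> BatchNorm, plus ReLU unless is_linear=True
# (the linear projection used at the end of each bottleneck).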
def Conv1x1(channels, is_linear=False, **kwargs):
    out = nn.HybridSequential(**kwargs)
    with out.name_scope():
        out.add(
            nn.Conv2D(channels, 1, padding=0, use_bias=False),
            nn.BatchNorm(scale=True)
        )
        if not is_linear:
            out.add(nn.Activation('relu'))
    return out
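
# Depthwise convolution (groups == channels) -> BatchNorm -> ReLU.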
def DWise(channels, strides, kernel_size=3, **kwargs):
    out = nn.HybridSequential(**kwargs)
    with out.name_scope():
        out.add(
            nn.Conv2D(channels, kernel_size, strides=strides, padding=kernel_size // 2,
                      groups=channels, use_bias=False),
            nn.BatchNorm(scale=True),
            nn.Activation('relu')
        )
    return out
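
# Depthwise-separable convolution: a depthwise conv, optionally followed by
# BN/ReLU and a 1x1 pointwise conv when an output channel count is given;
# hybrid_forward applies a final (optional) BatchNorm and a ReLU.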
class SepCONV(nn.HybridBlock):
    def __init__(self, inp, output, kernel_size, depth_multiplier=1, with_bn=True, **kwargs):
        super(SepCONV, self).__init__(**kwargs)
        with self.name_scope():
            self.net = nn.HybridSequential()
            cn = int(inp * depth_multiplier)
            if output is None:
                self.net.add(
                    nn.Conv2D(in_channels=inp, channels=cn, groups=inp, kernel_size=kernel_size,
                              strides=(1, 1), padding=kernel_size // 2, use_bias=not with_bn)
                )
            else:
                self.net.add(
                    nn.Conv2D(in_channels=inp, channels=cn, groups=inp, kernel_size=kernel_size,
                              strides=(1, 1), padding=kernel_size // 2, use_bias=False),
                    nn.BatchNorm(),
                    nn.Activation('relu'),
                    nn.Conv2D(in_channels=cn, channels=output, kernel_size=(1, 1), strides=(1, 1),
                              use_bias=not with_bn)
                )
            self.with_bn = with_bn
            self.act = nn.Activation('relu')
            if with_bn:
                self.bn = nn.BatchNorm()

    def hybrid_forward(self, F, x):
        x = self.net(x)
        if self.with_bn:
            x = self.bn(x)
        if self.act is not None:
            x = self.act(x)
        return x
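
# Inverted residual block (MobileNetV2/MNASNet style): 1x1 expand -> depthwise
# conv -> 1x1 linear projection, with a skip connection when stride is 1 and
# the input/output shapes match.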
class ExpandedConv(nn.HybridBlock):
    def __init__(self, inp, oup, t, strides, kernel=3, same_shape=True, **kwargs):
        super(ExpandedConv, self).__init__(**kwargs)
        self.same_shape = same_shape
        self.strides = strides
        with self.name_scope():
            self.bottleneck = nn.HybridSequential()
            self.bottleneck.add(
                Conv1x1(inp * t, prefix="expand_"),
                DWise(inp * t, self.strides, kernel, prefix="dwise_"),
                Conv1x1(oup, is_linear=True, prefix="linear_")
            )

    def hybrid_forward(self, F, x):
        out = self.bottleneck(x)
        if self.strides == 1 and self.same_shape:
            out = F.elemwise_add(out, x)
        return out
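
# Stack of `repeats` inverted residual blocks; only the first block may
# downsample (first_strides) or change the channel count.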
def ExpandedConvSequence(t, k, inp, oup, repeats, first_strides, **kwargs):
    seq = nn.HybridSequential(**kwargs)
    with seq.name_scope():
        seq.add(ExpandedConv(inp, oup, t, first_strides, k, same_shape=False))
        curr_inp = oup
        for i in range(1, repeats):
            seq.add(ExpandedConv(curr_inp, oup, t, 1))
            curr_inp = oup
    return seq
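
# MNASNet-style backbone; each row of interverted_residual_setting is
# (expansion t, channels c, repeats n, stride s, kernel k, name prefix),
# and `multiplier` scales the channel widths of every stage.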
class Mnasnet(nn.HybridBlock):
    def __init__(self, multiplier=1.0, **kwargs):
        super(Mnasnet, self).__init__(**kwargs)
        mm = multiplier
        self.first_oup = 32
        self.interverted_residual_setting = [
            # t, c, n, s, k
            [3, int(24 * mm), 3, 2, 3, "stage2_"],     # -> 56x56
            [3, int(40 * mm), 3, 2, 5, "stage3_"],     # -> 28x28
            [6, int(80 * mm), 3, 2, 5, "stage4_1_"],   # -> 14x14
            [6, int(96 * mm), 2, 1, 3, "stage4_2_"],   # -> 14x14
            [6, int(192 * mm), 4, 2, 5, "stage5_1_"],  # -> 7x7
            [6, int(320 * mm), 1, 1, 3, "stage5_2_"],  # -> 7x7
        ]
        self.last_channels = 1280
        with self.name_scope():
            self.features = nn.HybridSequential()
            self.features.add(ConvBlock(self.first_oup, 3, 2, prefix="stage1_conv0_"))
            self.features.add(SepCONV(self.first_oup, 16, 3, prefix="stage1_sepconv0_"))
            inp = 16
            for i, (t, c, n, s, k, prefix) in enumerate(self.interverted_residual_setting):
                oup = c
                self.features.add(ExpandedConvSequence(t, k, inp, oup, n, s, prefix=prefix))
                inp = oup
            self.features.add(Conv1x1(self.last_channels, prefix="stage5_3_"))

    def hybrid_forward(self, F, x):
        x = self.features(x)
        return x
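
# Symbol-API helper: convolution (or, when dcn=True, a 3x3 deformable
# convolution with a learned offset branch) followed by an optional
# activation; act_type='' skips the activation.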
def conv_act_layer(from_layer, name, num_filter, kernel=(1, 1), pad=(0, 0),
                   stride=(1, 1), act_type="relu", bias_wd_mult=0.0, dcn=False):
    weight = mx.symbol.Variable(name="{}_weight".format(name),
                                init=mx.init.Normal(0.01), attr={'__lr_mult__': '1.0'})
    bias = mx.symbol.Variable(name="{}_bias".format(name),
                              init=mx.init.Constant(0.0),
                              attr={'__lr_mult__': '2.0', '__wd_mult__': str(bias_wd_mult)})
    if not dcn:
        conv = mx.symbol.Convolution(data=from_layer, kernel=kernel, pad=pad,
                                     stride=stride, num_filter=num_filter,
                                     name="{}".format(name), weight=weight, bias=bias)
    else:
        assert kernel[0] == 3 and kernel[1] == 3
        num_group = 1
        f = num_group * 18
        offset_weight = mx.symbol.Variable(name="{}_offset_weight".format(name),
                                           init=mx.init.Constant(0.0), attr={'__lr_mult__': '1.0'})
        offset_bias = mx.symbol.Variable(name="{}_offset_bias".format(name),
                                         init=mx.init.Constant(0.0),
                                         attr={'__lr_mult__': '2.0', '__wd_mult__': str(bias_wd_mult)})
        conv_offset = mx.symbol.Convolution(name=name + '_offset', data=from_layer,
                                            weight=offset_weight, bias=offset_bias,
                                            num_filter=f, pad=(1, 1), kernel=(3, 3), stride=(1, 1))
        conv = mx.contrib.symbol.DeformableConvolution(name=name, data=from_layer, offset=conv_offset,
                                                        weight=weight, bias=bias,
                                                        num_filter=num_filter, pad=(1, 1), kernel=(3, 3),
                                                        num_deformable_group=num_group, stride=(1, 1),
                                                        no_bias=False)
    if len(act_type) > 0:
        relu = mx.symbol.Activation(data=conv, act_type=act_type,
                                    name="{}_{}".format(name, act_type))
    else:
        relu = conv
    return relu
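
# SSH context module: a shared 3x3 dimension-reduction conv followed by two
# branches of stacked 3x3 convs that emulate 5x5 and 7x7 receptive fields.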
def ssh_context_module(body, num_filters, name):
    conv_dimred = conv_act_layer(body, name + '_conv1',
                                 num_filters, kernel=(3, 3), pad=(1, 1), stride=(1, 1), act_type='relu', dcn=False)
    conv5x5 = conv_act_layer(conv_dimred, name + '_conv2',
                             num_filters, kernel=(3, 3), pad=(1, 1), stride=(1, 1), act_type='', dcn=USE_DCN)
    conv7x7_1 = conv_act_layer(conv_dimred, name + '_conv3_1',
                               num_filters, kernel=(3, 3), pad=(1, 1), stride=(1, 1), act_type='relu', dcn=False)
    conv7x7 = conv_act_layer(conv7x7_1, name + '_conv3_2',
                             num_filters, kernel=(3, 3), pad=(1, 1), stride=(1, 1), act_type='', dcn=USE_DCN)
    return (conv5x5, conv7x7)
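
# SSH detection module: concatenates a 3x3 branch with the two context-module
# branches along the channel axis and applies a ReLU.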
def ssh_detection_module(body, num_filters, name):
    conv3x3 = conv_act_layer(body, name + '_conv1',
                             num_filters, kernel=(3, 3), pad=(1, 1), stride=(1, 1), act_type='', dcn=USE_DCN)
    conv5x5, conv7x7 = ssh_context_module(body, num_filters // 2, name + '_context')
    ret = mx.sym.concat(*[conv3x3, conv5x5, conv7x7], dim=1, name=name + '_concat')
    ret = mx.symbol.Activation(data=ret, act_type='relu', name=name + '_concat_relu')
    return ret
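
# Conv -> BatchNorm -> optional activation helper (Symbol API).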
def conv_bn(input, filter, ksize, stride, padding, act_type='relu', name=''):
    conv = mx.symbol.Convolution(data=input, kernel=(ksize, ksize), pad=(padding, padding),
                                 stride=(stride, stride), num_filter=filter, name=name + "_conv")
    ret = mx.sym.BatchNorm(data=conv, fix_gamma=False, eps=2e-5, momentum=0.9, name=name + '_bn')
    if act_type is not None:
        ret = mx.symbol.Activation(data=ret, act_type=act_type,
                                   name="{}_{}".format(name, act_type))
    return ret
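
# Context prediction module: a 1x1 shortcut plus a 1x1-3x3-1x1 bottleneck
# branch, merged and passed through an SSH detection module; not referenced
# elsewhere in this file.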
def cpm(input, name):
    # residual
    branch1 = conv_bn(input, 1024, 1, 1, 0, act_type=None, name=name + "_branch1")
    branch2a = conv_bn(input, 256, 1, 1, 0, act_type='relu', name=name + "_branch2a")
    branch2b = conv_bn(branch2a, 256, 3, 1, 1, act_type='relu', name=name + "_branch2b")
    branch2c = conv_bn(branch2b, 1024, 1, 1, 0, act_type=None, name=name + "_branch2c")
    sum = branch1 + branch2c
    rescomb = mx.symbol.Activation(data=sum, act_type='relu', name="%s_relu2" % (name))
    ssh_out = ssh_detection_module(rescomb, 256, name=name + "_ssh")
    return ssh_out
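
# Builds the backbone and the feature pyramid: takes stride-8/16/32 feature
# maps from the Mnasnet internals, fuses the stride-8 map with an upsampled
# (grouped deconvolution with frozen weights) stride-16 map, and applies an
# SSH detection module at each level.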
def get_mnet_conv(data):
    mm = MM
    net = Mnasnet(mm, prefix="")
    body = net(data)
    all_layers = body.get_internals()
    #print(all_layers)
    c1 = all_layers['stage3_expandedconv2_elemwise_add0_output']
    c2 = all_layers['stage4_2_expandedconv1_elemwise_add0_output']
    #c3 = all_layers['stage5_3_relu0_fwd_output']
    c3 = all_layers['stage5_2_expandedconv0_linear_batchnorm0_fwd_output']

    F1 = int(256 * mm)
    F2 = int(128 * mm)
    _bwm = 1.0
    conv4_128 = conv_act_layer(c1, 'ssh_m1_red_conv',
                               F2, kernel=(1, 1), pad=(0, 0), stride=(1, 1), act_type='relu', bias_wd_mult=_bwm)
    conv5_128 = conv_act_layer(c2, 'ssh_m2_red_conv',
                               F2, kernel=(1, 1), pad=(0, 0), stride=(1, 1), act_type='relu', bias_wd_mult=_bwm)
    conv5_128_up = mx.symbol.Deconvolution(data=conv5_128, num_filter=F2, kernel=(4, 4), stride=(2, 2), pad=(1, 1),
                                           num_group=F2, no_bias=True,
                                           attr={'__lr_mult__': '0.0', '__wd_mult__': '0.0'},
                                           name='ssh_m2_red_upsampling')
    #conv5_128_up = mx.symbol.UpSampling(conv5_128, scale=2, sample_type='nearest', workspace=512, name='ssh_m2_red_up', num_args=1)
    conv4_128 = mx.symbol.Crop(*[conv4_128, conv5_128_up])
    #conv5_128_up = mx.symbol.Crop(*[conv5_128_up, conv4_128])
    conv_sum = conv4_128 + conv5_128_up
    #conv_sum = conv_1x1

    m1_conv = conv_act_layer(conv_sum, 'ssh_m1_conv',
                             F2, kernel=(3, 3), pad=(1, 1), stride=(1, 1), act_type='relu', bias_wd_mult=_bwm)
    m1 = ssh_detection_module(m1_conv, F2, 'ssh_m1_det')
    m2 = ssh_detection_module(c2, F1, 'ssh_m2_det')
    m3 = ssh_detection_module(c3, F1, 'ssh_m3_det')
    return {8: m1, 16: m2, 32: m3}
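
# Per-stride RPN head: anchor classification scores (optionally with max-out
# over extra background/foreground channels), bbox regression and optional
# landmark regression, plus their training losses and OHEM-based weighting.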
def get_out(conv_fpn_feat, prefix, stride, landmark=False, lr_mult=1.0):
    A = config.NUM_ANCHORS
    ret_group = []
    num_anchors = config.RPN_ANCHOR_CFG[str(stride)]['NUM_ANCHORS']
    label = mx.symbol.Variable(name='%s_label_stride%d' % (prefix, stride))
    bbox_target = mx.symbol.Variable(name='%s_bbox_target_stride%d' % (prefix, stride))
    bbox_weight = mx.symbol.Variable(name='%s_bbox_weight_stride%d' % (prefix, stride))
    if landmark:
        landmark_target = mx.symbol.Variable(name='%s_landmark_target_stride%d' % (prefix, stride))
        landmark_weight = mx.symbol.Variable(name='%s_landmark_weight_stride%d' % (prefix, stride))
    rpn_relu = conv_fpn_feat[stride]
    maxout_stat = 0
    if config.USE_MAXOUT >= 1 and stride == config.RPN_FEAT_STRIDE[-1]:
        maxout_stat = 1
    if config.USE_MAXOUT >= 2 and stride != config.RPN_FEAT_STRIDE[-1]:
        maxout_stat = 2

    if maxout_stat == 0:
        rpn_cls_score = conv_act_layer(rpn_relu, '%s_rpn_cls_score_stride%d' % (prefix, stride), 2 * num_anchors,
                                       kernel=(1, 1), pad=(0, 0), stride=(1, 1), act_type='')
    elif maxout_stat == 1:
        cls_list = []
        for a in range(num_anchors):
            rpn_cls_score_bg = conv_act_layer(rpn_relu, '%s_rpn_cls_score_stride%d_anchor%d_bg' % (prefix, stride, a), 3,
                                              kernel=(1, 1), pad=(0, 0), stride=(1, 1), act_type='')
            rpn_cls_score_bg = mx.sym.max(rpn_cls_score_bg, axis=1, keepdims=True)
            cls_list.append(rpn_cls_score_bg)
            rpn_cls_score_fg = conv_act_layer(rpn_relu, '%s_rpn_cls_score_stride%d_anchor%d_fg' % (prefix, stride, a), 1,
                                              kernel=(1, 1), pad=(0, 0), stride=(1, 1), act_type='')
            cls_list.append(rpn_cls_score_fg)
        rpn_cls_score = mx.sym.concat(*cls_list, dim=1, name='%s_rpn_cls_score_stride%d' % (prefix, stride))
    else:
        cls_list = []
        for a in range(num_anchors):
            rpn_cls_score_bg = conv_act_layer(rpn_relu, '%s_rpn_cls_score_stride%d_anchor%d_bg' % (prefix, stride, a), 1,
                                              kernel=(1, 1), pad=(0, 0), stride=(1, 1), act_type='')
            cls_list.append(rpn_cls_score_bg)
            rpn_cls_score_fg = conv_act_layer(rpn_relu, '%s_rpn_cls_score_stride%d_anchor%d_fg' % (prefix, stride, a), 3,
                                              kernel=(1, 1), pad=(0, 0), stride=(1, 1), act_type='')
            rpn_cls_score_fg = mx.sym.max(rpn_cls_score_fg, axis=1, keepdims=True)
            cls_list.append(rpn_cls_score_fg)
        rpn_cls_score = mx.sym.concat(*cls_list, dim=1, name='%s_rpn_cls_score_stride%d' % (prefix, stride))

    rpn_bbox_pred = conv_act_layer(rpn_relu, '%s_rpn_bbox_pred_stride%d' % (prefix, stride), 4 * num_anchors,
                                   kernel=(1, 1), pad=(0, 0), stride=(1, 1), act_type='')
    # prepare rpn data
    rpn_cls_score_reshape = mx.symbol.Reshape(data=rpn_cls_score,
                                              shape=(0, 2, -1),
                                              name="%s_rpn_cls_score_reshape_stride%s" % (prefix, stride))
    rpn_bbox_pred_reshape = mx.symbol.Reshape(data=rpn_bbox_pred,
                                              shape=(0, 0, -1),
                                              name="%s_rpn_bbox_pred_reshape_stride%s" % (prefix, stride))
    if landmark:
        rpn_landmark_pred = conv_act_layer(rpn_relu, '%s_rpn_landmark_pred_stride%d' % (prefix, stride), 10 * num_anchors,
                                           kernel=(1, 1), pad=(0, 0), stride=(1, 1), act_type='')
        rpn_landmark_pred_reshape = mx.symbol.Reshape(data=rpn_landmark_pred,
                                                      shape=(0, 0, -1),
                                                      name="%s_rpn_landmark_pred_reshape_stride%s" % (prefix, stride))

    if config.TRAIN.RPN_ENABLE_OHEM >= 2:
        label, anchor_weight = mx.sym.Custom(op_type='rpn_fpn_ohem3', stride=int(stride),
                                             network=config.network, dataset=config.dataset, prefix=prefix,
                                             cls_score=rpn_cls_score_reshape, labels=label)
        _bbox_weight = mx.sym.tile(anchor_weight, (1, 1, 4))
        _bbox_weight = _bbox_weight.reshape((0, -1, A * 4)).transpose((0, 2, 1))
        bbox_weight = mx.sym.elemwise_mul(bbox_weight, _bbox_weight, name='%s_bbox_weight_mul_stride%s' % (prefix, stride))
        if landmark:
            _landmark_weight = mx.sym.tile(anchor_weight, (1, 1, 10))
            _landmark_weight = _landmark_weight.reshape((0, -1, A * 10)).transpose((0, 2, 1))
            landmark_weight = mx.sym.elemwise_mul(landmark_weight, _landmark_weight, name='%s_landmark_weight_mul_stride%s' % (prefix, stride))
        #if not config.FACE_LANDMARK:
        #    label, bbox_weight = mx.sym.Custom(op_type='rpn_fpn_ohem', stride=int(stride), cls_score=rpn_cls_score_reshape, bbox_weight=bbox_weight, labels=label)
        #else:
        #    label, bbox_weight, landmark_weight = mx.sym.Custom(op_type='rpn_fpn_ohem2', stride=int(stride), cls_score=rpn_cls_score_reshape, bbox_weight=bbox_weight, landmark_weight=landmark_weight, labels=label)

    #cls loss
    rpn_cls_prob = mx.symbol.SoftmaxOutput(data=rpn_cls_score_reshape,
                                           label=label,
                                           multi_output=True,
                                           normalization='valid', use_ignore=True, ignore_label=-1,
                                           grad_scale=lr_mult,
                                           name='%s_rpn_cls_prob_stride%d' % (prefix, stride))
    ret_group.append(rpn_cls_prob)
    ret_group.append(mx.sym.BlockGrad(label))

    #bbox loss
    bbox_diff = rpn_bbox_pred_reshape - bbox_target
    bbox_diff = bbox_diff * bbox_weight
    rpn_bbox_loss_ = mx.symbol.smooth_l1(name='%s_rpn_bbox_loss_stride%d_' % (prefix, stride), scalar=3.0, data=bbox_diff)
    rpn_bbox_loss = mx.sym.MakeLoss(name='%s_rpn_bbox_loss_stride%d' % (prefix, stride), data=rpn_bbox_loss_,
                                    grad_scale=1.0 * lr_mult / (config.TRAIN.RPN_BATCH_SIZE))
    ret_group.append(rpn_bbox_loss)
    ret_group.append(mx.sym.BlockGrad(bbox_weight))

    #landmark loss
    if landmark:
        landmark_diff = rpn_landmark_pred_reshape - landmark_target
        landmark_diff = landmark_diff * landmark_weight
        rpn_landmark_loss_ = mx.symbol.smooth_l1(name='%s_rpn_landmark_loss_stride%d_' % (prefix, stride), scalar=3.0, data=landmark_diff)
        rpn_landmark_loss = mx.sym.MakeLoss(name='%s_rpn_landmark_loss_stride%d' % (prefix, stride), data=rpn_landmark_loss_,
                                            grad_scale=0.5 * lr_mult / (config.TRAIN.RPN_BATCH_SIZE))
        ret_group.append(rpn_landmark_loss)
        ret_group.append(mx.sym.BlockGrad(landmark_weight))
    return ret_group
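
# Top-level training symbol: shared backbone/FPN features plus one 'face' head
# (and optionally a 'head' head when config.HEAD_BOX is set) per stride in
# config.RPN_FEAT_STRIDE.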
def get_mnet_train():
    data = mx.symbol.Variable(name="data")
    # shared convolutional layers
    conv_fpn_feat = get_mnet_conv(data)
    ret_group = []
    for stride in config.RPN_FEAT_STRIDE:
        ret = get_out(conv_fpn_feat, 'face', stride, config.FACE_LANDMARK, lr_mult=1.0)
        ret_group += ret
        if config.HEAD_BOX:
            ret = get_out(conv_fpn_feat, 'head', stride, False, lr_mult=1.0)
            ret_group += ret
    return mx.sym.Group(ret_group)
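
# A minimal usage sketch (illustrative, not part of the original file). It
# assumes the RetinaFace config has been populated for an mnet-style network
# (e.g. via rcnn.config.generate_config) so that RPN_FEAT_STRIDE, NUM_ANCHORS
# and RPN_ANCHOR_CFG are consistent; the network/dataset keys below are
# assumptions to adjust to your setup.
#
#     from rcnn.config import config, generate_config
#     generate_config('mnet25', 'retinaface')   # assumed keys
#     sym = get_mnet_train()
#     print(sym.list_outputs())
#
# The grouped symbol contains, per stride, the classification SoftmaxOutput,
# the smooth-L1 bbox loss and (when config.FACE_LANDMARK is set) the landmark loss.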