mirror of
https://github.com/deepinsight/insightface.git
synced 2025-12-30 08:02:27 +00:00
1086 lines
44 KiB
Python
1086 lines
44 KiB
Python
from __future__ import absolute_import
|
|
from __future__ import division
|
|
from __future__ import print_function
|
|
import mxnet as mx
|
|
import numpy as np
|
|
from config import config
|
|
|
|
# Bit width passed as ``act_bit`` to the quantized operators
# (QActivation / QConvolution / QConvolution_v1) used by binarized blocks.
ACT_BIT = 1

# Momentum shared by the BatchNorm layers built in this module.
bn_mom = 0.9

# Default ``workspace`` argument handed to the convolutions built here.
workspace = 256

# When True, residual shortcuts are tagged with ``mirror_stage='True'``.
memonger = False
|
|
|
|
|
|
def Conv(**kwargs):
    """Create an ``mx.sym.Convolution`` symbol.

    Every keyword argument is forwarded unchanged; the wrapper only gives
    the plain convolutions in this module a single construction point.
    """
    return mx.sym.Convolution(**kwargs)
|
|
|
|
|
|
def Act(data, act_type, name):
    """Apply the activation called ``act_type`` to ``data``.

    ``'prelu'`` is built with ``mx.sym.LeakyReLU``; every other type is
    delegated to ``mx.symbol.Activation``.

    Args:
        data: input symbol.
        act_type: activation name, e.g. ``'relu'`` or ``'prelu'``.
        name: name of the created layer.
    """
    if act_type != 'prelu':
        return mx.symbol.Activation(data=data, act_type=act_type, name=name)
    return mx.sym.LeakyReLU(data=data, act_type='prelu', name=name)
|
|
|
|
|
|
#def lin(data, num_filter, workspace, name, binarize, dcn):
|
|
# bit = 1
|
|
# if not binarize:
|
|
# if not dcn:
|
|
# conv1 = Conv(data=data, num_filter=num_filter, kernel=(1,1), stride=(1,1), pad=(0,0),
|
|
# no_bias=True, workspace=workspace, name=name + '_conv')
|
|
# bn1 = mx.sym.BatchNorm(data=conv1, fix_gamma=False, momentum=bn_mom, eps=2e-5, name=name + '_bn')
|
|
# act1 = Act(data=bn1, act_type='relu', name=name + '_relu')
|
|
# return act1
|
|
# else:
|
|
# bn1 = mx.sym.BatchNorm(data=data, fix_gamma=False, momentum=bn_mom, eps=2e-5, name=name + '_bn')
|
|
# act1 = Act(data=bn1, act_type='relu', name=name + '_relu')
|
|
# conv1_offset = mx.symbol.Convolution(name=name+'_conv_offset', data = act1,
|
|
# num_filter=18, pad=(1, 1), kernel=(3, 3), stride=(1, 1))
|
|
# conv1 = mx.contrib.symbol.DeformableConvolution(name=name+"_conv", data=act1, offset=conv1_offset,
|
|
# num_filter=num_filter, pad=(1,1), kernel=(3, 3), num_deformable_group=1, stride=(1, 1), dilate=(1, 1), no_bias=False)
|
|
# #conv1 = Conv(data=act1, num_filter=num_filter, kernel=(3,3), stride=(1,1), pad=(1,1),
|
|
# # no_bias=False, workspace=workspace, name=name + '_conv')
|
|
# return conv1
|
|
# else:
|
|
# bn1 = mx.sym.BatchNorm(data=data, fix_gamma=False, momentum=bn_mom, eps=2e-5, name=name + '_bn')
|
|
# act1 = Act(data=bn1, act_type='relu', name=name + '_relu')
|
|
# conv1 = mx.sym.QConvolution_v1(data=act1, num_filter=num_filter, kernel=(1,1), stride=(1,1), pad=(0,0),
|
|
# no_bias=True, workspace=workspace, name=name + '_conv', act_bit=ACT_BIT, weight_bit=bit)
|
|
# conv1 = mx.sym.BatchNorm(data=conv1, fix_gamma=False, momentum=bn_mom, eps=2e-5, name=name + '_bn2')
|
|
# return conv1
|
|
|
|
|
|
def lin3(data, num_filter, workspace, name, k, g=1, d=1):
    """Conv(k x k, stride 1, no bias) -> BatchNorm -> ReLU.

    Args:
        data: input symbol.
        num_filter: number of output channels.
        workspace: convolution ``workspace`` argument.
        name: prefix for the ``_conv``, ``_bn`` and ``_relu`` layer names.
        k: square kernel size.
        g: number of convolution groups.
        d: dilation. It is only used when ``k == 3``, where padding equals
            ``d`` so the spatial size is kept. For any other ``k`` the
            padding is ``(k - 1) // 2`` and ``d`` is ignored.

    Returns:
        The ReLU output symbol.
    """
    conv_args = dict(data=data,
                     num_filter=num_filter,
                     kernel=(k, k),
                     stride=(1, 1),
                     num_group=g,
                     no_bias=True,
                     workspace=workspace,
                     name=name + '_conv')
    if k == 3:
        conv_args['pad'] = (d, d)
        conv_args['dilate'] = (d, d)
    else:
        half = (k - 1) // 2
        conv_args['pad'] = (half, half)
    normed = mx.sym.BatchNorm(data=Conv(**conv_args),
                              fix_gamma=False,
                              momentum=bn_mom,
                              eps=2e-5,
                              name=name + '_bn')
    return Act(data=normed, act_type='relu', name=name + '_relu')
|
|
|
|
|
|
def ConvFactory(data,
                num_filter,
                kernel,
                stride=(1, 1),
                pad=(0, 0),
                act_type="relu",
                mirror_attr={},
                with_act=True,
                dcn=False,
                name=''):
    """Convolution -> BatchNorm -> optional activation.

    Args:
        data: input symbol.
        num_filter: number of output channels.
        kernel: kernel shape of the plain convolution.
        stride: convolution stride (used by both the plain and DCN paths).
        pad: padding of the plain convolution.
        act_type: activation passed to ``Act`` when ``with_act`` is True.
        mirror_attr: unused; kept for signature compatibility.
        with_act: append the activation layer when True.
        dcn: use a 3x3 deformable convolution (pad 1) instead of the plain
            convolution. In that case ``kernel`` and ``pad`` are ignored.
        name: prefix for the created layer names.

    Returns:
        The activation symbol, or the BatchNorm symbol if ``with_act`` is
        False.
    """
    if not dcn:
        conv = mx.symbol.Convolution(data=data,
                                     num_filter=num_filter,
                                     kernel=kernel,
                                     stride=stride,
                                     pad=pad,
                                     no_bias=True,
                                     workspace=workspace,
                                     name=name + '_conv')
    else:
        # The offset map must have the same spatial size as the deformable
        # conv output. Use the same stride here; with a fixed (1, 1) stride
        # the shapes disagree whenever stride > 1.
        # 18 = 2 (x/y) * 3 * 3 kernel positions * 1 deformable group.
        conv_offset = mx.symbol.Convolution(name=name + '_conv_offset',
                                            data=data,
                                            num_filter=18,
                                            pad=(1, 1),
                                            kernel=(3, 3),
                                            stride=stride)
        conv = mx.contrib.symbol.DeformableConvolution(name=name + "_conv",
                                                       data=data,
                                                       offset=conv_offset,
                                                       num_filter=num_filter,
                                                       pad=(1, 1),
                                                       kernel=(3, 3),
                                                       num_deformable_group=1,
                                                       stride=stride,
                                                       dilate=(1, 1),
                                                       no_bias=False)
    bn = mx.symbol.BatchNorm(data=conv,
                             fix_gamma=False,
                             momentum=bn_mom,
                             eps=2e-5,
                             name=name + '_bn')
    if not with_act:
        return bn
    return Act(bn, act_type, name=name + '_relu')
|
|
|
|
|
|
class CAB:
    """Recursive grid of ``lin3`` units built over a single input symbol.

    Cells are addressed as ``(w, h)``. ``get_output`` returns a
    ``(symbol, channels)`` pair per cell and memoizes it in ``sym_map``.

    * Row ``h == n`` is a chain starting from ``(n, n) = (data, nFilters)``.
      Each step to the left halves the channel count with a 3x3 ``lin3``.
      Only the step at ``w == n - 1`` uses ``dilate``.
    * Every other cell averages its ``(w + 1, h + 1)`` neighbour (``x``)
      with ``(w, h + 1)`` (``y``) concatenated along channels.
    """

    def __init__(self, data, nFilters, nModules, n, workspace, name, dilate,
                 group):
        self.data, self.nFilters = data, nFilters
        self.nModules, self.n = nModules, n
        self.workspace, self.name = workspace, name
        self.dilate, self.group = dilate, group
        # Memo of built cells: (w, h) -> (symbol, channels).
        self.sym_map = {}

    def _bottom_row(self, w, h):
        """Build cell ``(w, h)`` on the ``h == n`` row."""
        if w == self.n:
            return (self.data, self.nFilters)
        src, channels = self.get_output(w + 1, h)
        half = int(channels * 0.5)
        dilation = self.dilate if w == self.n - 1 else 1
        body = lin3(src, half, self.workspace,
                    "%s_w%d_h%d_1" % (self.name, w, h), 3, self.group,
                    dilation)
        return (body, half)

    def _merge(self, w, h):
        """Build cell ``(w, h)`` for ``h != n`` from its two neighbours."""
        x = self.get_output(w + 1, h + 1)
        y = self.get_output(w, h + 1)
        if h % 2 == 1 and h != w:
            # Depthwise 3x3 (groups == channels) refinement of x.
            xbody = lin3(x[0], x[1], self.workspace,
                         "%s_w%d_h%d_2" % (self.name, w, h), 3, x[1])
        else:
            xbody = x[0]
        if w != 0:
            ybody = y[0]
        else:
            ybody = lin3(y[0], y[1], self.workspace,
                         "%s_w%d_h%d_3" % (self.name, w, h), 3, self.group)
        # y has half of x's channels; concatenating doubles them to match.
        ybody = mx.sym.concat(y[0], ybody, dim=1)
        summed = mx.sym.add_n(xbody,
                              ybody,
                              name="%s_w%d_h%d_add" % (self.name, w, h))
        return (summed / 2, x[1])

    def get_output(self, w, h):
        """Return (and memoize) the ``(symbol, channels)`` pair for ``(w, h)``."""
        key = (w, h)
        try:
            return self.sym_map[key]
        except KeyError:
            pass
        if h == self.n:
            ret = self._bottom_row(w, h)
        else:
            ret = self._merge(w, h)
        self.sym_map[key] = ret
        return ret

    def get(self):
        """Return the output symbol of the top-left cell ``(1, 1)``."""
        return self.get_output(1, 1)[0]
|
|
|
|
|
|
def conv_resnet(data, num_filter, stride, dim_match, name, binarize, dcn,
                dilate, **kwargs):
    """Pre-activation bottleneck residual unit.

    Main path: BN-act-conv1x1(num_filter/2) -> BN-act-conv3x3(num_filter/2,
    ``stride``) -> BN-act-conv1x1(num_filter). It is added to ``data`` when
    ``dim_match`` is true, otherwise to a 1x1 projection (with ``stride``)
    of the first activation.

    Args:
        data: input symbol.
        num_filter: output channels of the unit.
        stride: spatial stride. It is applied to the 3x3 conv and to the
            projection shortcut so both paths have the same output size.
        dim_match: use the identity shortcut when true.
        name: prefix for the created layer names.
        binarize: use QActivation/QConvolution instead of ReLU/Convolution.
        dcn, dilate: accepted for signature compatibility with the other
            block builders; not used by this unit.

    Returns:
        The symbol ``conv3 + shortcut``.
    """
    bit = 1
    # the same as https://github.com/facebook/fb.resnet.torch#notes, a bit difference with origin paper
    bn1 = mx.sym.BatchNorm(data=data,
                           fix_gamma=False,
                           eps=2e-5,
                           momentum=bn_mom,
                           name=name + '_bn1')
    if not binarize:
        act1 = Act(data=bn1, act_type='relu', name=name + '_relu1')
        conv1 = Conv(data=act1,
                     num_filter=int(num_filter * 0.5),
                     kernel=(1, 1),
                     stride=(1, 1),
                     pad=(0, 0),
                     no_bias=True,
                     workspace=workspace,
                     name=name + '_conv1')
    else:
        act1 = mx.sym.QActivation(data=bn1,
                                  act_bit=ACT_BIT,
                                  name=name + '_relu1',
                                  backward_only=True)
        conv1 = mx.sym.QConvolution(data=act1,
                                    num_filter=int(num_filter * 0.5),
                                    kernel=(1, 1),
                                    stride=(1, 1),
                                    pad=(0, 0),
                                    no_bias=True,
                                    workspace=workspace,
                                    name=name + '_conv1',
                                    act_bit=ACT_BIT,
                                    weight_bit=bit)
    bn2 = mx.sym.BatchNorm(data=conv1,
                           fix_gamma=False,
                           eps=2e-5,
                           momentum=bn_mom,
                           name=name + '_bn2')
    # The spatial stride goes on the 3x3 conv (as in the fb.resnet.torch
    # bottleneck), so the main path downsamples like the strided shortcut.
    # Leaving it at (1, 1) made `conv3 + shortcut` mismatch for stride > 1.
    if not binarize:
        act2 = Act(data=bn2, act_type='relu', name=name + '_relu2')
        conv2 = Conv(data=act2,
                     num_filter=int(num_filter * 0.5),
                     kernel=(3, 3),
                     stride=stride,
                     pad=(1, 1),
                     no_bias=True,
                     workspace=workspace,
                     name=name + '_conv2')
    else:
        act2 = mx.sym.QActivation(data=bn2,
                                  act_bit=ACT_BIT,
                                  name=name + '_relu2',
                                  backward_only=True)
        conv2 = mx.sym.QConvolution(data=act2,
                                    num_filter=int(num_filter * 0.5),
                                    kernel=(3, 3),
                                    stride=stride,
                                    pad=(1, 1),
                                    no_bias=True,
                                    workspace=workspace,
                                    name=name + '_conv2',
                                    act_bit=ACT_BIT,
                                    weight_bit=bit)
    bn3 = mx.sym.BatchNorm(data=conv2,
                           fix_gamma=False,
                           eps=2e-5,
                           momentum=bn_mom,
                           name=name + '_bn3')
    if not binarize:
        act3 = Act(data=bn3, act_type='relu', name=name + '_relu3')
        conv3 = Conv(data=act3,
                     num_filter=num_filter,
                     kernel=(1, 1),
                     stride=(1, 1),
                     pad=(0, 0),
                     no_bias=True,
                     workspace=workspace,
                     name=name + '_conv3')
    else:
        act3 = mx.sym.QActivation(data=bn3,
                                  act_bit=ACT_BIT,
                                  name=name + '_relu3',
                                  backward_only=True)
        conv3 = mx.sym.QConvolution(data=act3,
                                    num_filter=num_filter,
                                    kernel=(1, 1),
                                    stride=(1, 1),
                                    pad=(0, 0),
                                    no_bias=True,
                                    workspace=workspace,
                                    name=name + '_conv3',
                                    act_bit=ACT_BIT,
                                    weight_bit=bit)
    if dim_match:
        shortcut = data
    else:
        if not binarize:
            shortcut = Conv(data=act1,
                            num_filter=num_filter,
                            kernel=(1, 1),
                            stride=stride,
                            no_bias=True,
                            workspace=workspace,
                            name=name + '_sc')
        else:
            shortcut = mx.sym.QConvolution(data=act1,
                                           num_filter=num_filter,
                                           kernel=(1, 1),
                                           stride=stride,
                                           pad=(0, 0),
                                           no_bias=True,
                                           workspace=workspace,
                                           name=name + '_sc',
                                           act_bit=ACT_BIT,
                                           weight_bit=bit)
    if memonger:
        shortcut._set_attr(mirror_stage='True')
    return conv3 + shortcut
|
|
|
|
|
|
def _hpm_stage(data, num_filter, stride, name, idx, binarize, dcn, dilation):
    """One BN -> activation -> 3x3 conv stage of ``conv_hpm``.

    Creates ``<name>_bn<idx>``, ``<name>_relu<idx>`` and ``<name>_conv<idx>``
    (plus ``<name>_conv<idx>_offset`` for DCN). Returns ``(act, conv)``; the
    activation is returned because stage 1's activation feeds the projection
    shortcut.
    """
    bit = 1
    bn = mx.sym.BatchNorm(data=data,
                          fix_gamma=False,
                          eps=2e-5,
                          momentum=bn_mom,
                          name='%s_bn%d' % (name, idx))
    if binarize:
        act = mx.sym.QActivation(data=bn,
                                 act_bit=ACT_BIT,
                                 name='%s_relu%d' % (name, idx),
                                 backward_only=True)
        conv = mx.sym.QConvolution_v1(data=act,
                                      num_filter=num_filter,
                                      kernel=(3, 3),
                                      stride=stride,
                                      pad=(1, 1),
                                      no_bias=True,
                                      workspace=workspace,
                                      name='%s_conv%d' % (name, idx),
                                      act_bit=ACT_BIT,
                                      weight_bit=bit)
        return act, conv
    act = Act(data=bn, act_type='relu', name='%s_relu%d' % (name, idx))
    if dcn:
        # The offset map must match the deformable conv output size, so it
        # uses the same stride. 18 = 2 * 3 * 3 for one deformable group.
        offset = mx.symbol.Convolution(name='%s_conv%d_offset' % (name, idx),
                                       data=act,
                                       num_filter=18,
                                       pad=(1, 1),
                                       kernel=(3, 3),
                                       stride=stride)
        conv = mx.contrib.symbol.DeformableConvolution(
            name='%s_conv%d' % (name, idx),
            data=act,
            offset=offset,
            num_filter=num_filter,
            pad=(1, 1),
            kernel=(3, 3),
            num_deformable_group=1,
            stride=stride,
            dilate=(1, 1),
            no_bias=True)
    else:
        conv = Conv(data=act,
                    num_filter=num_filter,
                    kernel=(3, 3),
                    stride=stride,
                    pad=(dilation, dilation),
                    dilate=(dilation, dilation),
                    no_bias=True,
                    workspace=workspace,
                    name='%s_conv%d' % (name, idx))
    return act, conv


def conv_hpm(data, num_filter, stride, dim_match, name, binarize, dcn,
             dilation, **kwargs):
    """Multi-scale residual unit: three chained 3x3 stages, concatenated.

    Stage 1 produces ``num_filter/2`` channels. Stages 2 and 3 (each fed by
    the previous stage) produce ``num_filter/4`` channels each. The three
    outputs are concatenated (BatchNorm'ed when binarized) and added to the
    shortcut. The concat has ``num_filter`` channels only when ``num_filter``
    is divisible by 4.

    Args:
        data: input symbol.
        num_filter: output channels of the unit.
        stride: spatial stride. It is applied to stage 1, so every
            concatenated branch and the projection shortcut are downsampled
            alike. Previously only the shortcut was strided, which broke the
            final addition for stride > 1.
        dim_match: use the identity shortcut when true.
        name: prefix for the created layer names.
        binarize: use QActivation/QConvolution_v1 (pad 1, no dilation).
        dcn: use 3x3 deformable convolutions (dilation ignored).
        dilation: dilation of the plain 3x3 convolutions.

    Returns:
        The symbol ``conv4 + shortcut``.
    """
    bit = 1
    # the same as https://github.com/facebook/fb.resnet.torch#notes, a bit difference with origin paper
    act1, conv1 = _hpm_stage(data, int(num_filter * 0.5), stride, name, 1,
                             binarize, dcn, dilation)
    _, conv2 = _hpm_stage(conv1, int(num_filter * 0.25), (1, 1), name, 2,
                          binarize, dcn, dilation)
    _, conv3 = _hpm_stage(conv2, int(num_filter * 0.25), (1, 1), name, 3,
                          binarize, dcn, dilation)
    conv4 = mx.symbol.Concat(*[conv1, conv2, conv3])
    if binarize:
        conv4 = mx.sym.BatchNorm(data=conv4,
                                 fix_gamma=False,
                                 eps=2e-5,
                                 momentum=bn_mom,
                                 name=name + '_bn4')
    if dim_match:
        shortcut = data
    else:
        if not binarize:
            shortcut = Conv(data=act1,
                            num_filter=num_filter,
                            kernel=(1, 1),
                            stride=stride,
                            no_bias=True,
                            workspace=workspace,
                            name=name + '_sc')
        else:
            shortcut = mx.sym.QConvolution_v1(data=act1,
                                              num_filter=num_filter,
                                              kernel=(1, 1),
                                              stride=stride,
                                              pad=(0, 0),
                                              no_bias=True,
                                              workspace=workspace,
                                              name=name + '_sc',
                                              act_bit=ACT_BIT,
                                              weight_bit=bit)
            # Normalize the quantized projection, mirroring the bn4 applied
            # to the binarized main path.
            shortcut = mx.sym.BatchNorm(data=shortcut,
                                        fix_gamma=False,
                                        eps=2e-5,
                                        momentum=bn_mom,
                                        name=name + '_sc_bn')
    if memonger:
        shortcut._set_attr(mirror_stage='True')
    return conv4 + shortcut
|
|
|
|
|
|
def block17(net,
            input_num_channels,
            scale=1.0,
            with_act=True,
            act_type='relu',
            mirror_attr={},
            name=''):
    """Inception-ResNet style block with 1x7 / 7x1 factorized convolutions.

    Two branches (1x1 conv; 1x1 -> 1x7 -> 7x1 convs) are concatenated,
    projected back to ``input_num_channels`` by a 1x1 conv without
    activation, scaled by ``scale`` and added to ``net``. An activation is
    applied to the sum when ``with_act`` is true.
    """
    # NOTE(review): 129 filters and the (1, 2)/(2, 1) paddings look like
    # typos for 128 and (0, 3)/(3, 0). With these paddings the intermediate
    # map is (H + 2, W - 2), and the pair restores (H, W). They are kept
    # unchanged because they determine parameter shapes/checkpoints.
    branch0 = ConvFactory(net, 192, (1, 1), name=name + '_conv')
    branch1 = ConvFactory(net, 129, (1, 1), name=name + '_conv1_0')
    branch1 = ConvFactory(branch1,
                          160, (1, 7),
                          pad=(1, 2),
                          name=name + '_conv1_1')
    branch1 = ConvFactory(branch1,
                          192, (7, 1),
                          pad=(2, 1),
                          name=name + '_conv1_2')
    mixed = mx.symbol.Concat(branch0, branch1)
    residual = ConvFactory(mixed,
                           input_num_channels, (1, 1),
                           with_act=False,
                           name=name + '_conv_out')
    net = net + scale * residual
    if not with_act:
        return net
    return mx.symbol.Activation(data=net,
                                act_type=act_type,
                                attr=mirror_attr)
|
|
|
|
|
|
def block35(net,
            input_num_channels,
            scale=1.0,
            with_act=True,
            act_type='relu',
            mirror_attr={},
            name=''):
    """Inception-ResNet style block with three 3x3-based branches.

    Branch widths are fractions of ``input_num_channels``:
      * 1x1 (1/4)
      * 1x1 (1/4) -> 3x3 (1/4)
      * 1x1 (1/4) -> 3x3 (3/8) -> 3x3 (1/2)

    The concatenation is projected back to ``input_num_channels`` by a 1x1
    conv without activation, scaled by ``scale`` and added to ``net``. An
    activation is applied to the sum when ``with_act`` is true.
    """
    M = 1.0

    def width(frac):
        # Branch channel count as a fraction of the block width.
        return int(input_num_channels * frac * M)

    branch0 = ConvFactory(net, width(0.25), (1, 1), name=name + '_conv')
    branch1 = ConvFactory(net, width(0.25), (1, 1), name=name + '_conv1_0')
    branch1 = ConvFactory(branch1, width(0.25), (3, 3),
                          pad=(1, 1),
                          name=name + '_conv1_1')
    branch2 = ConvFactory(net, width(0.25), (1, 1), name=name + '_conv2_0')
    branch2 = ConvFactory(branch2, width(0.375), (3, 3),
                          pad=(1, 1),
                          name=name + '_conv2_1')
    branch2 = ConvFactory(branch2, width(0.5), (3, 3),
                          pad=(1, 1),
                          name=name + '_conv2_2')
    mixed = mx.symbol.Concat(branch0, branch1, branch2)
    residual = ConvFactory(mixed,
                           input_num_channels, (1, 1),
                           with_act=False,
                           name=name + '_conv_out')
    net = net + scale * residual
    if not with_act:
        return net
    return mx.symbol.Activation(data=net,
                                act_type=act_type,
                                attr=mirror_attr)
|
|
|
|
|
|
def conv_inception(data, num_filter, stride, dim_match, name, binarize, dcn,
                   dilate, **kwargs):
    """Inception residual unit (``block35``) for shape-preserving calls.

    Binarization is not supported. When the unit must change resolution or
    channel count (``stride[0] > 1`` or ``not dim_match``), it falls back to
    ``conv_resnet``.
    """
    assert not binarize
    keeps_shape = stride[0] <= 1 and dim_match
    if not keeps_shape:
        return conv_resnet(data, num_filter, stride, dim_match, name, binarize,
                           dcn, dilate, **kwargs)
    return block35(data, num_filter, name=name + '_block35')
|
|
|
|
|
|
def conv_cab(data, num_filter, stride, dim_match, name, binarize, dcn, dilate,
             **kwargs):
    """CAB residual unit for shape-preserving calls.

    Falls back to ``conv_hpm`` when the unit must change resolution or
    channel count (``stride[0] > 1`` or ``not dim_match``). Otherwise it
    builds a depth-4, group-1 ``CAB`` over ``data``.
    """
    keeps_shape = stride[0] <= 1 and dim_match
    if not keeps_shape:
        return conv_hpm(data, num_filter, stride, dim_match, name, binarize,
                        dcn, dilate, **kwargs)
    return CAB(data, num_filter, 1, 4, workspace, name, dilate, 1).get()
|
|
|
|
|
|
def conv_block(data, num_filter, stride, dim_match, name, binarize, dcn,
               dilate):
    """Build one residual unit of the type selected by ``config.net_block``.

    Supported values are ``'resnet'``, ``'inception'``, ``'hpm'`` and
    ``'cab'``. All builders receive the same positional arguments.

    Raises:
        ValueError: if ``config.net_block`` is not one of the supported
            values. Previously such values silently returned None, which
            failed later with an unrelated graph-construction error.
    """
    builders = {
        'resnet': conv_resnet,
        'inception': conv_inception,
        'hpm': conv_hpm,
        'cab': conv_cab,
    }
    builder = builders.get(config.net_block)
    if builder is None:
        raise ValueError("unknown config.net_block %r; expected one of %s" %
                         (config.net_block, sorted(builders)))
    return builder(data, num_filter, stride, dim_match, name, binarize, dcn,
                   dilate)
|
|
|
|
|
|
def hourglass(data, nFilters, nModules, n, workspace, name, binarize, dcn):
    """Recursive hourglass module of depth ``n``.

    The upper branch applies ``nModules`` residual units at full resolution.
    The lower branch max-pools by 2, applies residual units, recurses (or,
    at ``n == 1``, applies another set of units), applies more units, and is
    nearest-upsampled by 2. The two branches are summed.

    The residual units are always built with DCN disabled; ``dcn`` is only
    forwarded to the recursive call.
    """
    scale = 2
    unit_dcn = False

    def _units(body, tag):
        # Apply nModules shape-preserving residual units named <name>_<tag>_<i>.
        for idx in range(nModules):
            body = conv_block(body, nFilters, (1, 1), True,
                              "%s_%s_%d" % (name, tag, idx), binarize,
                              unit_dcn, 1)
        return body

    upper = _units(data, 'up1')
    lower = mx.sym.Pooling(data=data,
                           kernel=(scale, scale),
                           stride=(scale, scale),
                           pad=(0, 0),
                           pool_type='max')
    lower = _units(lower, 'low1')
    if n > 1:
        lower = hourglass(lower, nFilters, nModules, n - 1, workspace,
                          "%s_%d" % (name, n - 1), binarize, dcn)
    else:
        lower = _units(lower, 'low2')  #TODO
    lower = _units(lower, 'low3')
    upsampled = mx.symbol.UpSampling(lower,
                                     scale=scale,
                                     sample_type='nearest',
                                     workspace=512,
                                     name='%s_upsampling_%s' % (name, n),
                                     num_args=1)
    return mx.symbol.add_n(upper, upsampled)
|
|
|
|
|
|
class STA:
    """Recursive multi-resolution grid of CAB units.

    ``get_symbol`` uses it instead of ``hourglass`` when
    ``config.net_sta > 0``.

    Cells are addressed as ``(w, h)``. ``get_output`` returns a
    ``(symbol, size)`` pair per cell and memoizes it in ``self.sym_map``.
    The second element is halved after each stride-2 max pooling in the
    ``h == n`` row, so it appears to track the cell's spatial resolution.
    """

    def __init__(self, data, nFilters, nModules, n, workspace, name):
        self.data = data
        self.nFilters = nFilters
        self.nModules = nModules
        self.n = n
        self.workspace = workspace
        self.name = name
        # Memo of already-built cells: (w, h) -> (symbol, size).
        self.sym_map = {}

    def get_conv(self, data, name, dilate=1, group=1):
        """Build a depth-4 CAB over ``data`` and return its output symbol."""
        cab = CAB(data, self.nFilters, self.nModules, 4, self.workspace, name,
                  dilate, group)
        return cab.get()

    def get_output(self, w, h):
        """Return (and memoize) the ``(symbol, size)`` pair for cell ``(w, h)``.

        Row ``h == n`` is a chain starting at ``(n, n) = (data, 64)``. Each
        step to the left applies CAB -> 2x2 max pool -> CAB. Every other
        cell averages ``x = (w + 1, h + 1)`` (optionally refined by a
        depthwise lin3) with ``y = (w, h + 1)``. ``y`` is upsampled by 2
        when its size is half of ``x``'s.
        """
        #print(w,h)
        # NOTE(review): the bounds use config.net_n + 1, which matches the n
        # that get_symbol passes to STA; they would not fit an STA built
        # with a different n.
        assert w >= 1 and w <= config.net_n + 1
        assert h >= 1 and h <= config.net_n + 1
        s = 2
        # Local copy; same value as the module-level bn_mom.
        bn_mom = 0.9
        key = (w, h)
        if key in self.sym_map:
            return self.sym_map[key]
        ret = None
        if h == self.n:
            if w == self.n:
                # Seed cell. NOTE(review): 64 looks like the assumed spatial
                # size of the input feature map (it is halved per pooling
                # step and later used as an avg-pool kernel) -- confirm it
                # matches the actual resolution.
                ret = self.data, 64
            else:
                x = self.get_output(w + 1, h)
                body = self.get_conv(x[0], "%s_w%d_h%d_1" % (self.name, w, h))
                body = mx.sym.Pooling(data=body,
                                      kernel=(s, s),
                                      stride=(s, s),
                                      pad=(0, 0),
                                      pool_type='max')
                body = self.get_conv(body, "%s_w%d_h%d_2" % (self.name, w, h))
                ret = body, x[1] // 2
        else:
            x = self.get_output(w + 1, h + 1)
            y = self.get_output(w, h + 1)

            # Only read by the commented-out "#if not HC:" check below.
            HC = False

            if h % 2 == 1 and h != w:
                # Depthwise (groups == nFilters) 3x3 refinement of x.
                xbody = lin3(x[0], self.nFilters, self.workspace,
                             "%s_w%d_h%d_x" % (self.name, w, h), 3,
                             self.nFilters, 1)
                HC = True
                #xbody = x[0]
            else:
                xbody = x[0]
            if x[1] // y[1] == 2:
                # y is half the resolution of x: upsample it by s.
                if w > 1:
                    # Learned 2x upsampling. The name contains 'upsampling',
                    # which init_weights matches for bilinear initialization.
                    ybody = mx.symbol.Deconvolution(
                        data=y[0],
                        num_filter=self.nFilters,
                        kernel=(s, s),
                        stride=(s, s),
                        name='%s_upsampling_w%d_h%d' % (self.name, w, h),
                        attr={'lr_mult': '1.0'},
                        workspace=self.workspace)
                    ybody = mx.sym.BatchNorm(data=ybody,
                                             fix_gamma=False,
                                             momentum=bn_mom,
                                             eps=2e-5,
                                             name="%s_w%d_h%d_y_bn" %
                                             (self.name, w, h))
                    ybody = Act(data=ybody,
                                act_type='relu',
                                name="%s_w%d_h%d_y_act" % (self.name, w, h))
                else:
                    # h >= 1 is guaranteed by the assert above, so the
                    # Deconvolution branch under this `else` is unreachable.
                    if h >= 1:
                        ybody = mx.symbol.UpSampling(
                            y[0],
                            scale=s,
                            sample_type='nearest',
                            workspace=512,
                            name='%s_upsampling_w%d_h%d' % (self.name, w, h),
                            num_args=1)
                        ybody = self.get_conv(
                            ybody, "%s_w%d_h%d_4" % (self.name, w, h))
                    else:
                        ybody = mx.symbol.Deconvolution(
                            data=y[0],
                            num_filter=self.nFilters,
                            kernel=(s, s),
                            stride=(s, s),
                            name='%s_upsampling_w%d_h%d' % (self.name, w, h),
                            attr={'lr_mult': '1.0'},
                            workspace=self.workspace)
                        ybody = mx.sym.BatchNorm(data=ybody,
                                                 fix_gamma=False,
                                                 momentum=bn_mom,
                                                 eps=2e-5,
                                                 name="%s_w%d_h%d_y_bn" %
                                                 (self.name, w, h))
                        ybody = Act(data=ybody,
                                    act_type='relu',
                                    name="%s_w%d_h%d_y_act" %
                                    (self.name, w, h))
                        ybody = Conv(data=ybody,
                                     num_filter=self.nFilters,
                                     kernel=(3, 3),
                                     stride=(1, 1),
                                     pad=(1, 1),
                                     no_bias=True,
                                     name="%s_w%d_h%d_y_conv2" %
                                     (self.name, w, h),
                                     workspace=self.workspace)
                        ybody = mx.sym.BatchNorm(data=ybody,
                                                 fix_gamma=False,
                                                 momentum=bn_mom,
                                                 eps=2e-5,
                                                 name="%s_w%d_h%d_y_bn2" %
                                                 (self.name, w, h))
                        ybody = Act(data=ybody,
                                    act_type='relu',
                                    name="%s_w%d_h%d_y_act2" %
                                    (self.name, w, h))
            else:
                # NOTE(review): when the grid is built from get(), x always
                # has twice y's size, so this branch appears unreachable.
                ybody = self.get_conv(y[0], "%s_w%d_h%d_5" % (self.name, w, h))
            #if not HC:
            if config.net_sta == 2 and h == 3 and w == 2:
                # Gate the averaged map with an average-pooled copy of cell
                # (w + 1, h), pooled with a z[1] x z[1] window.
                z = self.get_output(w + 1, h)
                zbody = z[0]
                zbody = mx.sym.Pooling(data=zbody,
                                       kernel=(z[1], z[1]),
                                       stride=(z[1], z[1]),
                                       pad=(0, 0),
                                       pool_type='avg')
                body = xbody + ybody
                body = body / 2
                body = mx.sym.broadcast_mul(body, zbody)
            else:  #sta==1
                body = xbody + ybody
                body = body / 2
            ret = body, x[1]

        assert ret is not None
        self.sym_map[key] = ret
        return ret

    def get(self):
        """Return the output symbol of the top-left cell ``(1, 1)``."""
        return self.get_output(1, 1)[0]
|
|
|
|
|
|
class SymCoherent:
    """Build the two views compared by the flip-coherence loss.

    The batch is split into halves. The first half is mirrored along axis 3,
    and its channels are reordered with ``flip_order`` so that each channel
    is paired with its left/right counterpart.
    """

    def __init__(self, per_batch_size):
        self.per_batch_size = per_batch_size
        # Channel permutation applied after mirroring. It is an involution
        # over 68 channels (applying it twice restores the original order).
        self.flip_order = (list(range(16, -1, -1)) +
                           list(range(26, 16, -1)) +
                           [27, 28, 29, 30] +
                           list(range(35, 30, -1)) +
                           [45, 44, 43, 42, 47, 46, 39, 38, 37, 36, 41, 40] +
                           list(range(54, 47, -1)) +
                           list(range(59, 54, -1)) +
                           list(range(64, 59, -1)) +
                           [67, 66, 65])

    def get(self, data):
        """Return ``(mirrored_first_half, second_half)`` of ``data``.

        Assumes ``data``'s batch dimension equals ``per_batch_size``.
        """
        half = self.per_batch_size // 2
        upper = mx.sym.slice_axis(data, axis=0, begin=0, end=half)
        lower = mx.sym.slice_axis(data, axis=0, begin=half, end=half * 2)
        upper = mx.sym.flip(upper, axis=3)
        channels = [mx.sym.slice_axis(upper, axis=1, begin=c, end=c + 1)
                    for c in self.flip_order]
        return mx.sym.concat(*channels, dim=1), lower
|
|
|
|
|
|
def l2_loss(x, y):
    """Mean smooth-L1 distance between ``x`` and ``y``.

    Despite the name, this is not a squared-L2 loss: the element-wise
    difference goes through ``smooth_l1`` (scalar 1.0) and the result is
    averaged over all elements into a scalar.
    """
    return mx.symbol.mean(mx.symbol.smooth_l1(x - y, scalar=1.0))
|
|
|
|
|
|
def ce_loss(x, y):
    """Per-sample cross entropy between spatial softmax of ``x`` and ``y``.

    The softmax is taken over axes 2 and 3 of ``x`` (the spatial axes of
    each channel map). The loss is ``-mean(log_softmax(x) * y)`` over axes
    1-3, giving one value per sample. ``y`` is multiplied element-wise with
    the log-probabilities, so it must broadcast to ``x``'s shape.

    The log-softmax is computed as ``(x - max) - log(sum(exp(x - max)))``.
    The old ``log(exp(x - max) / sum)`` gives -inf wherever ``exp``
    underflows to 0, and ``-inf * y`` is NaN where ``y == 0``. Here
    ``sum >= 1`` (the max element contributes exp(0)), so the log is finite.
    """
    #loss = mx.sym.SoftmaxOutput(data = x, label = y, normalization='valid', multi_output=True)
    x_max = mx.sym.max(x, axis=[2, 3], keepdims=True)
    shifted = mx.sym.broadcast_minus(x, x_max)
    sums = mx.sym.sum(mx.sym.exp(shifted), axis=[2, 3], keepdims=True)
    log_prob = mx.sym.broadcast_minus(shifted, mx.sym.log(sums))
    loss = log_prob * y * -1.0
    loss = mx.symbol.mean(loss, axis=[1, 2, 3])
    return loss
|
|
|
|
|
|
def get_symbol(num_classes):
    """Build the training symbol of the stacked heatmap network.

    Reads its configuration from ``config`` (multiplier, net_stacks,
    net_binarize, input_img_size, output_label_size, net_coherent, net_sta,
    net_n, net_dcn, per_batch_size).

    Args:
        num_classes: number of heatmap channels predicted by each stack.

    Returns:
        ``mx.symbol.Group`` of: one ``MakeLoss`` per stack (``ce_loss``
        against ``softmax_label``), an optional weighted coherence
        ``MakeLoss`` when ``config.net_coherent > 0``, and the
        gradient-blocked heatmap of the last stack.
    """
    m = config.multiplier
    sFilters = max(int(64 * m), 32)
    mFilters = max(int(128 * m), 32)
    nFilters = int(256 * m)

    nModules = 1
    nStacks = config.net_stacks
    binarize = config.net_binarize
    input_size = config.input_img_size
    label_size = config.output_label_size
    use_coherent = config.net_coherent
    use_STA = config.net_sta
    N = config.net_n
    DCN = config.net_dcn
    per_batch_size = config.per_batch_size
    print('binarize', binarize)
    print('use_coherent', use_coherent)
    print('use_STA', use_STA)
    print('use_N', N)
    print('use_DCN', DCN)
    print('per_batch_size', per_batch_size)
    #assert(label_size==64 or label_size==32)
    #assert(input_size==128 or input_size==256)
    coherentor = SymCoherent(per_batch_size)
    # Downsampling factor from input image to heatmap.
    D = input_size // label_size
    print(input_size, label_size, D)
    data = mx.sym.Variable(name='data')
    # (data - 127.5) / 128: maps [0, 255] to roughly [-1, 1].
    data = data - 127.5
    data = data * 0.0078125
    # NOTE(review): the label is used as a target heatmap in ce_loss (element-
    # wise product with the log-probabilities), so it presumably has the same
    # shape as each stack's output -- confirm against the data iterator.
    gt_label = mx.symbol.Variable(name='softmax_label')
    losses = []
    closses = []
    ref_label = gt_label
    # Stem: conv0 (stride 2 when D == 4, else stride 1) followed by a stride-2
    # max pool below, i.e. a total reduction of 4 or 2.
    if D == 4:
        body = Conv(data=data,
                    num_filter=sFilters,
                    kernel=(7, 7),
                    stride=(2, 2),
                    pad=(3, 3),
                    no_bias=True,
                    name="conv0",
                    workspace=workspace)
    else:
        body = Conv(data=data,
                    num_filter=sFilters,
                    kernel=(3, 3),
                    stride=(1, 1),
                    pad=(1, 1),
                    no_bias=True,
                    name="conv0",
                    workspace=workspace)
    body = mx.sym.BatchNorm(data=body,
                            fix_gamma=False,
                            eps=2e-5,
                            momentum=bn_mom,
                            name='bn0')
    body = Act(data=body, act_type='relu', name='relu0')

    dcn = False
    body = conv_block(body, mFilters, (1, 1), sFilters == mFilters, 'res0',
                      False, dcn, 1)

    body = mx.sym.Pooling(data=body,
                          kernel=(2, 2),
                          stride=(2, 2),
                          pad=(0, 0),
                          pool_type='max')

    body = conv_block(body, mFilters, (1, 1), True, 'res1', False, dcn,
                      1)  #TODO
    body = conv_block(body, nFilters, (1, 1), mFilters == nFilters, 'res2',
                      binarize, dcn, 1)  #binarize=True?

    heatmap = None

    # One iteration per stack. Every stack emits a heatmap and a loss
    # (intermediate supervision); all but the last feed the next stack.
    for i in range(nStacks):
        shortcut = body
        if config.net_sta > 0:
            sta = STA(body, nFilters, nModules, config.net_n + 1, workspace,
                      'sta%d' % (i))
            body = sta.get()
        else:
            body = hourglass(body, nFilters, nModules, config.net_n, workspace,
                             'stack%d_hg' % (i), binarize, dcn)
        for j in range(nModules):
            body = conv_block(body, nFilters, (1, 1), True,
                              'stack%d_unit%d' % (i, j), binarize, dcn, 1)
        # net_dcn >= 2: the ll layer and the heatmap head use deformable
        # convolutions.
        _dcn = True if config.net_dcn >= 2 else False
        ll = ConvFactory(body,
                         nFilters, (1, 1),
                         dcn=_dcn,
                         name='stack%d_ll' % (i))
        _name = "heatmap%d" % (i) if i < nStacks - 1 else "heatmap"
        _dcn = True if config.net_dcn >= 2 else False
        if not _dcn:
            out = Conv(data=ll,
                       num_filter=num_classes,
                       kernel=(1, 1),
                       stride=(1, 1),
                       pad=(0, 0),
                       name=_name,
                       workspace=workspace)
        else:
            out_offset = mx.symbol.Convolution(name=_name + '_offset',
                                               data=ll,
                                               num_filter=18,
                                               pad=(1, 1),
                                               kernel=(3, 3),
                                               stride=(1, 1))
            out = mx.contrib.symbol.DeformableConvolution(
                name=_name,
                data=ll,
                offset=out_offset,
                num_filter=num_classes,
                pad=(1, 1),
                kernel=(3, 3),
                num_deformable_group=1,
                stride=(1, 1),
                dilate=(1, 1),
                no_bias=False)
        #out = Conv(data=ll, num_filter=num_classes, kernel=(3,3), stride=(1,1), pad=(1,1),
        #           name=_name, workspace=workspace)
        if i == nStacks - 1:
            heatmap = out
        loss = ce_loss(out, ref_label)
        #loss = loss/nStacks
        #loss = l2_loss(out, ref_label)
        losses.append(loss)
        if config.net_coherent > 0:
            # Flip-coherence: mirrored first half of the batch vs second half.
            ux, dx = coherentor.get(out)
            closs = l2_loss(ux, dx)
            closs = closs / nStacks
            closses.append(closs)

        if i < nStacks - 1:
            # Input of the next stack: shortcut + remapped ll + remapped heatmap.
            ll2 = Conv(data=ll,
                       num_filter=nFilters,
                       kernel=(1, 1),
                       stride=(1, 1),
                       pad=(0, 0),
                       name="stack%d_ll2" % (i),
                       workspace=workspace)
            out2 = Conv(data=out,
                        num_filter=nFilters,
                        kernel=(1, 1),
                        stride=(1, 1),
                        pad=(0, 0),
                        name="stack%d_out2" % (i),
                        workspace=workspace)
            body = mx.symbol.add_n(shortcut, ll2, out2)
            # net_dcn 1 or 3: an extra deformable conv between stacks.
            _dcn = True if (config.net_dcn == 1
                            or config.net_dcn == 3) else False
            if _dcn:
                _name = "stack%d_out3" % (i)
                out3_offset = mx.symbol.Convolution(name=_name + '_offset',
                                                    data=body,
                                                    num_filter=18,
                                                    pad=(1, 1),
                                                    kernel=(3, 3),
                                                    stride=(1, 1))
                out3 = mx.contrib.symbol.DeformableConvolution(
                    name=_name,
                    data=body,
                    offset=out3_offset,
                    num_filter=nFilters,
                    pad=(1, 1),
                    kernel=(3, 3),
                    num_deformable_group=1,
                    stride=(1, 1),
                    dilate=(1, 1),
                    no_bias=False)
                body = out3

    # NOTE(review): heatmap stays None if nStacks == 0.
    pred = mx.symbol.BlockGrad(heatmap)
    #loss = mx.symbol.add_n(*losses)
    #loss = mx.symbol.MakeLoss(loss)
    #syms = [loss]
    syms = []
    for loss in losses:
        loss = mx.symbol.MakeLoss(loss)
        syms.append(loss)
    if len(closses) > 0:
        coherent_weight = 0.0001
        closs = mx.symbol.add_n(*closses)
        closs = mx.symbol.MakeLoss(closs, grad_scale=coherent_weight)
        syms.append(closs)
    syms.append(pred)
    sym = mx.symbol.Group(syms)
    return sym
|
|
|
|
|
|
def init_weights(sym, data_shape_dict):
    """Create initial values for selected arguments of ``sym``.

    Shapes are inferred from ``data_shape_dict``. Arguments are then handled
    by name:
      * ``...offset_weight`` / ``...offset_bias`` (DCN offset convs): zeros.
      * ``fc6_..._weight``: normal(0, 0.01); ``fc6_..._bias``: zeros.
      * names containing ``upsampling``: bilinear upsampling kernel.

    Other arguments are left out.

    Returns:
        ``(arg_params, aux_params)``. ``aux_params`` is always empty.
    """
    arg_shapes, _, _ = sym.infer_shape(**data_shape_dict)
    arg_shape_dict = dict(zip(sym.list_arguments(), arg_shapes))
    arg_params, aux_params = {}, {}
    for key, shape in arg_shape_dict.items():
        if key.endswith(('offset_weight', 'offset_bias')):
            print('initializing', key)
            arg_params[key] = mx.nd.zeros(shape=shape)
        elif key.startswith('fc6_'):
            if key.endswith('_weight'):
                print('initializing', key)
                arg_params[key] = mx.random.normal(0, 0.01, shape=shape)
            elif key.endswith('_bias'):
                print('initializing', key)
                arg_params[key] = mx.nd.zeros(shape=shape)
        elif 'upsampling' in key:
            print('initializing upsampling_weight', key)
            arg_params[key] = mx.nd.zeros(shape=shape)
            mx.init.Initializer()._init_bilinear(key, arg_params[key])
    return arg_params, aux_params
|