Files
insightface/alignment/symbol/sym_heatmap.py
2019-06-07 12:44:48 +08:00

619 lines
30 KiB
Python

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import mxnet as mx
import numpy as np
from config import config
# Bit-width passed as `act_bit` to the quantized (QActivation / QConvolution) operators.
ACT_BIT = 1
# Momentum passed to the BatchNorm layers built by the module-level helpers below.
bn_mom = 0.9
# Value forwarded as the `workspace` argument of the convolution operators.
workspace = 256
# When True, residual shortcuts get the attribute mirror_stage='True'.
memonger = False
def Conv(**kwargs):
    """Create an ``mx.sym.Convolution`` symbol, forwarding every keyword argument."""
    return mx.sym.Convolution(**kwargs)
def Act(data, act_type, name):
    """Return an activation symbol applied to ``data``.

    ``act_type == 'prelu'`` is built with ``mx.sym.LeakyReLU``; every other
    value is handed to ``mx.symbol.Activation`` unchanged.
    """
    if act_type != 'prelu':
        return mx.symbol.Activation(data=data, act_type=act_type, name=name)
    return mx.sym.LeakyReLU(data=data, act_type='prelu', name=name)
#def lin(data, num_filter, workspace, name, binarize, dcn):
# bit = 1
# if not binarize:
# if not dcn:
# conv1 = Conv(data=data, num_filter=num_filter, kernel=(1,1), stride=(1,1), pad=(0,0),
# no_bias=True, workspace=workspace, name=name + '_conv')
# bn1 = mx.sym.BatchNorm(data=conv1, fix_gamma=False, momentum=bn_mom, eps=2e-5, name=name + '_bn')
# act1 = Act(data=bn1, act_type='relu', name=name + '_relu')
# return act1
# else:
# bn1 = mx.sym.BatchNorm(data=data, fix_gamma=False, momentum=bn_mom, eps=2e-5, name=name + '_bn')
# act1 = Act(data=bn1, act_type='relu', name=name + '_relu')
# conv1_offset = mx.symbol.Convolution(name=name+'_conv_offset', data = act1,
# num_filter=18, pad=(1, 1), kernel=(3, 3), stride=(1, 1))
# conv1 = mx.contrib.symbol.DeformableConvolution(name=name+"_conv", data=act1, offset=conv1_offset,
# num_filter=num_filter, pad=(1,1), kernel=(3, 3), num_deformable_group=1, stride=(1, 1), dilate=(1, 1), no_bias=False)
# #conv1 = Conv(data=act1, num_filter=num_filter, kernel=(3,3), stride=(1,1), pad=(1,1),
# # no_bias=False, workspace=workspace, name=name + '_conv')
# return conv1
# else:
# bn1 = mx.sym.BatchNorm(data=data, fix_gamma=False, momentum=bn_mom, eps=2e-5, name=name + '_bn')
# act1 = Act(data=bn1, act_type='relu', name=name + '_relu')
# conv1 = mx.sym.QConvolution_v1(data=act1, num_filter=num_filter, kernel=(1,1), stride=(1,1), pad=(0,0),
# no_bias=True, workspace=workspace, name=name + '_conv', act_bit=ACT_BIT, weight_bit=bit)
# conv1 = mx.sym.BatchNorm(data=conv1, fix_gamma=False, momentum=bn_mom, eps=2e-5, name=name + '_bn2')
# return conv1
def lin3(data, num_filter, workspace, name, k, g=1, d=1):
    """Build a bias-free k x k convolution followed by BatchNorm and ReLU.

    Args:
        data: input symbol.
        num_filter: number of output channels.
        workspace: value forwarded to the convolution's ``workspace`` argument.
        name: prefix for the ``_conv`` / ``_bn`` / ``_relu`` layer names.
        k: square kernel size.
        g: number of convolution groups.
        d: dilation; only applied when ``k == 3``.

    Returns:
        The ReLU output symbol.
    """
    if k == 3:
        # Only the 3x3 case honours the dilation factor (pad and dilate both d).
        geometry = dict(pad=(d, d), dilate=(d, d))
    else:
        half = (k - 1) // 2
        geometry = dict(pad=(half, half))
    conv = Conv(data=data, num_filter=num_filter, kernel=(k, k), stride=(1, 1), num_group=g,
                no_bias=True, workspace=workspace, name=name + '_conv', **geometry)
    normed = mx.sym.BatchNorm(data=conv, fix_gamma=False, momentum=bn_mom, eps=2e-5, name=name + '_bn')
    return Act(data=normed, act_type='relu', name=name + '_relu')
def ConvFactory(data, num_filter, kernel, stride=(1, 1), pad=(0, 0), act_type="relu", mirror_attr={}, with_act=True, dcn=False, name=''):
    """Convolution -> BatchNorm (-> activation) building block.

    Args:
        data: input symbol.
        num_filter: number of output channels.
        kernel, pad: geometry of the plain convolution (ignored when ``dcn``
            is set; the deformable path always uses a 3x3 kernel with pad 1).
        stride: convolution stride (used by both paths).
        act_type: activation type handed to ``Act`` when ``with_act`` is true.
        mirror_attr: accepted for signature compatibility; not read here.
        with_act: when false, the BatchNorm output is returned directly.
        dcn: use a deformable convolution driven by a learned 18-channel offset map.
        name: prefix for all layer names.

    Returns:
        The activation symbol, or the BatchNorm symbol if ``with_act`` is false.
    """
    if dcn:
        offsets = mx.symbol.Convolution(name=name+'_conv_offset', data=data,
                                        num_filter=18, pad=(1, 1), kernel=(3, 3), stride=(1, 1))
        features = mx.contrib.symbol.DeformableConvolution(
            name=name+"_conv", data=data, offset=offsets,
            num_filter=num_filter, pad=(1,1), kernel=(3,3), num_deformable_group=1,
            stride=stride, dilate=(1, 1), no_bias=False)
    else:
        features = mx.symbol.Convolution(
            data=data, num_filter=num_filter, kernel=kernel, stride=stride, pad=pad,
            no_bias=True, workspace=workspace, name=name+'_conv')
    normed = mx.symbol.BatchNorm(data=features, fix_gamma=False, momentum=bn_mom, eps=2e-5, name=name+'_bn')
    if not with_act:
        return normed
    return Act(normed, act_type, name=name+'_relu')
class CAB:
    """Recursive builder for a triangular grid of feature nodes over ``data``.

    Nodes are addressed by ``(w, h)`` and memoised in ``sym_map``. Every node
    is a ``(symbol, channels)`` tuple: node ``(n, n)`` is the input itself,
    nodes on row ``h == n`` halve the channel count of their right neighbour,
    and every other node averages its diagonal neighbour with a
    channel-doubled copy of the node below it.
    """

    def __init__(self, data, nFilters, nModules, n, workspace, name, dilate, group):
        self.data = data
        self.nFilters = nFilters
        self.nModules = nModules  # stored only; not read by this class
        self.n = n
        self.workspace = workspace
        self.name = name
        self.dilate = dilate
        self.group = group
        self.sym_map = {}  # (w, h) -> (symbol, channels)

    def _shrink(self, w, h):
        """Node on row ``h == n``: conv of the right neighbour with half its channels."""
        src, src_channels = self.get_output(w + 1, h)
        channels = int(src_channels * 0.5)
        # Only the node directly next to the input uses the configured dilation.
        dilation = self.dilate if w == self.n - 1 else 1
        body = lin3(src, channels, self.workspace, "%s_w%d_h%d_1"%(self.name, w, h), 3, self.group, dilation)
        return (body, channels)

    def _merge(self, w, h):
        """Interior node: average of the diagonal node and the concatenated lower node."""
        diag, diag_channels = self.get_output(w + 1, h + 1)
        lower, lower_channels = self.get_output(w, h + 1)
        if h % 2 == 1 and h != w:
            # Per-channel 3x3 conv (group count equals channel count).
            xbody = lin3(diag, diag_channels, self.workspace, "%s_w%d_h%d_2"%(self.name, w, h), 3, diag_channels)
        else:
            xbody = diag
        if w == 0:
            ybody = lin3(lower, lower_channels, self.workspace, "%s_w%d_h%d_3"%(self.name, w, h), 3, self.group)
        else:
            ybody = lower
        # Concatenating doubles the channels so they match the diagonal branch.
        ybody = mx.sym.concat(lower, ybody, dim=1)
        body = mx.sym.add_n(xbody, ybody, name="%s_w%d_h%d_add"%(self.name, w, h))
        return (body / 2, diag_channels)

    def get_output(self, w, h):
        """Return the memoised ``(symbol, channels)`` tuple for node ``(w, h)``."""
        key = (w, h)
        try:
            return self.sym_map[key]
        except KeyError:
            pass
        if h != self.n:
            node = self._merge(w, h)
        elif w != self.n:
            node = self._shrink(w, h)
        else:
            node = (self.data, self.nFilters)
        self.sym_map[key] = node
        return node

    def get(self):
        """Return the symbol of node ``(1, 1)``."""
        return self.get_output(1, 1)[0]
def conv_resnet(data, num_filter, stride, dim_match, name, binarize, dcn, dilate, **kwargs):
    """Residual unit made of three BatchNorm -> activation -> convolution stages.

    The stages use 1x1, 3x3 and 1x1 kernels with ``num_filter/2``,
    ``num_filter/2`` and ``num_filter`` output channels. With ``binarize``
    set, the quantized ``QActivation`` / ``QConvolution`` operators replace
    the ReLU / convolution pair.

    Args:
        data: input symbol.
        num_filter: number of output channels.
        stride: stride of the projection shortcut (the main path is stride 1).
        dim_match: when true the input itself is the shortcut; otherwise a
            1x1 projection of the first stage's activation is used.
        name: prefix for all layer names.
        binarize: select the quantized operators.
        dcn, dilate, **kwargs: accepted for signature parity; not read here.

    Returns:
        Sum of the third convolution and the shortcut.
    """
    bit = 1
    half_filters = int(num_filter*0.5)
    # (stage index, output channels, kernel, pad)
    stage_specs = (
        (1, half_filters, (1,1), (0,0)),
        (2, half_filters, (3,3), (1,1)),
        (3, num_filter, (1,1), (0,0)),
    )
    feat = data
    first_act = None
    for idx, filters, kernel, pad in stage_specs:
        normed = mx.sym.BatchNorm(data=feat, fix_gamma=False, eps=2e-5, momentum=bn_mom,
                                  name=name + '_bn%d' % idx)
        if binarize:
            activated = mx.sym.QActivation(data=normed, act_bit=ACT_BIT, name=name + '_relu%d' % idx,
                                           backward_only=True)
            feat = mx.sym.QConvolution(data=activated, num_filter=filters, kernel=kernel, stride=(1,1), pad=pad,
                                       no_bias=True, workspace=workspace, name=name + '_conv%d' % idx,
                                       act_bit=ACT_BIT, weight_bit=bit)
        else:
            activated = Act(data=normed, act_type='relu', name=name + '_relu%d' % idx)
            feat = Conv(data=activated, num_filter=filters, kernel=kernel, stride=(1,1), pad=pad,
                        no_bias=True, workspace=workspace, name=name + '_conv%d' % idx)
        if idx == 1:
            # The projection shortcut branches off after the first activation.
            first_act = activated
    if dim_match:
        shortcut = data
    elif binarize:
        shortcut = mx.sym.QConvolution(data=first_act, num_filter=num_filter, kernel=(1,1), stride=stride, pad=(0,0),
                                       no_bias=True, workspace=workspace, name=name + '_sc',
                                       act_bit=ACT_BIT, weight_bit=bit)
    else:
        shortcut = Conv(data=first_act, num_filter=num_filter, kernel=(1,1), stride=stride, no_bias=True,
                        workspace=workspace, name=name+'_sc')
    if memonger:
        shortcut._set_attr(mirror_stage='True')
    return feat + shortcut
def conv_hpm(data, num_filter, stride, dim_match, name, binarize, dcn, dilation, **kwargs):
    """Residual unit whose output concatenates three chained 3x3 conv stages.

    Each stage is BatchNorm -> activation -> 3x3 convolution, producing
    ``num_filter/2``, ``num_filter/4`` and ``num_filter/4`` channels. The
    three stage outputs are concatenated and added to the shortcut.

    Args:
        data: input symbol.
        num_filter: total number of output channels.
        stride: stride of the projection shortcut (the main path is stride 1).
        dim_match: when true the input itself is the shortcut; otherwise a
            1x1 projection of the first stage's activation is used.
        name: prefix for all layer names.
        binarize: use ``QActivation`` / ``QConvolution_v1``; this also adds a
            BatchNorm after the concatenation and after the projected shortcut.
        dcn: (non-binarized only) use deformable 3x3 convolutions; this path
            ignores ``dilation``.
        dilation: pad/dilate of the plain 3x3 convolutions.
        **kwargs: accepted for signature parity; not read here.

    Returns:
        Sum of the concatenated stage outputs and the shortcut.
    """
    bit = 1
    # (stage index, output channels)
    stage_specs = (
        (1, int(num_filter*0.5)),
        (2, int(num_filter*0.25)),
        (3, int(num_filter*0.25)),
    )
    feat = data
    first_act = None
    stage_outputs = []
    for idx, filters in stage_specs:
        conv_name = name + '_conv%d' % idx
        normed = mx.sym.BatchNorm(data=feat, fix_gamma=False, eps=2e-5, momentum=bn_mom,
                                  name=name + '_bn%d' % idx)
        if binarize:
            activated = mx.sym.QActivation(data=normed, act_bit=ACT_BIT, name=name + '_relu%d' % idx,
                                           backward_only=True)
            feat = mx.sym.QConvolution_v1(data=activated, num_filter=filters, kernel=(3,3), stride=(1,1), pad=(1,1),
                                          no_bias=True, workspace=workspace, name=conv_name,
                                          act_bit=ACT_BIT, weight_bit=bit)
        else:
            activated = Act(data=normed, act_type='relu', name=name + '_relu%d' % idx)
            if dcn:
                offsets = mx.symbol.Convolution(name=conv_name + '_offset', data=activated,
                                                num_filter=18, pad=(1, 1), kernel=(3, 3), stride=(1, 1))
                feat = mx.contrib.symbol.DeformableConvolution(
                    name=conv_name, data=activated, offset=offsets,
                    num_filter=filters, pad=(1,1), kernel=(3, 3), num_deformable_group=1,
                    stride=(1, 1), dilate=(1, 1), no_bias=True)
            else:
                feat = Conv(data=activated, num_filter=filters, kernel=(3,3), stride=(1,1),
                            pad=(dilation,dilation), dilate=(dilation,dilation),
                            no_bias=True, workspace=workspace, name=conv_name)
        if idx == 1:
            # The projection shortcut branches off after the first activation.
            first_act = activated
        stage_outputs.append(feat)
    merged = mx.symbol.Concat(*stage_outputs)
    if binarize:
        merged = mx.sym.BatchNorm(data=merged, fix_gamma=False, eps=2e-5, momentum=bn_mom, name=name + '_bn4')
    if dim_match:
        shortcut = data
    elif binarize:
        shortcut = mx.sym.QConvolution_v1(data=first_act, num_filter=num_filter, kernel=(1,1), stride=stride, pad=(0,0),
                                          no_bias=True, workspace=workspace, name=name + '_sc',
                                          act_bit=ACT_BIT, weight_bit=bit)
        shortcut = mx.sym.BatchNorm(data=shortcut, fix_gamma=False, eps=2e-5, momentum=bn_mom, name=name + '_sc_bn')
    else:
        shortcut = Conv(data=first_act, num_filter=num_filter, kernel=(1,1), stride=stride, no_bias=True,
                        workspace=workspace, name=name+'_sc')
    if memonger:
        shortcut._set_attr(mirror_stage='True')
    return merged + shortcut
#return bn4 + shortcut
#return act4 + shortcut
def block17(net, input_num_channels, scale=1.0, with_act=True, act_type='relu', mirror_attr={}, name=''):
    """Two-branch residual block built from ``ConvFactory`` units.

    Branch A is a 1x1 convolution (192 channels); branch B chains 1x1 (129),
    1x7 (160) and 7x1 (192) convolutions. Their concatenation is projected
    back to ``input_num_channels`` without activation, scaled by ``scale``
    and added to ``net``.

    Returns:
        The activated sum when ``with_act`` is true (``mirror_attr`` is passed
        as the activation's ``attr``), otherwise the raw sum.
    """
    branch_a = ConvFactory(net, 192, (1, 1), name=name+'_conv')
    branch_b = ConvFactory(net, 129, (1, 1), name=name+'_conv1_0')
    branch_b = ConvFactory(branch_b, 160, (1, 7), pad=(1, 2), name=name+'_conv1_1')
    branch_b = ConvFactory(branch_b, 192, (7, 1), pad=(2, 1), name=name+'_conv1_2')
    merged = mx.symbol.Concat(*[branch_a, branch_b])
    projected = ConvFactory(merged, input_num_channels, (1, 1), with_act=False, name=name+'_conv_out')
    net = net + scale * projected
    if not with_act:
        return net
    return mx.symbol.Activation(data=net, act_type=act_type, attr=mirror_attr)
def block35(net, input_num_channels, scale=1.0, with_act=True, act_type='relu', mirror_attr={}, name=''):
    """Three-branch residual block built from ``ConvFactory`` units.

    Branch widths are fractions of ``input_num_channels``: a 1x1 branch (1/4),
    a 1x1 -> 3x3 branch (1/4, 1/4) and a 1x1 -> 3x3 -> 3x3 branch
    (1/4, 3/8, 1/2). The concatenated branches are projected back to
    ``input_num_channels`` without activation, scaled by ``scale`` and added
    to ``net``.

    Returns:
        The activated sum when ``with_act`` is true (``mirror_attr`` is passed
        as the activation's ``attr``), otherwise the raw sum.
    """
    M = 1.0  # global width multiplier for the branches
    quarter = int(input_num_channels*0.25*M)
    branch0 = ConvFactory(net, quarter, (1, 1), name=name+'_conv')
    branch1 = ConvFactory(net, quarter, (1, 1), name=name+'_conv1_0')
    branch1 = ConvFactory(branch1, quarter, (3, 3), pad=(1, 1), name=name+'_conv1_1')
    branch2 = ConvFactory(net, quarter, (1, 1), name=name+'_conv2_0')
    branch2 = ConvFactory(branch2, int(input_num_channels*0.375*M), (3, 3), pad=(1, 1), name=name+'_conv2_1')
    branch2 = ConvFactory(branch2, int(input_num_channels*0.5*M), (3, 3), pad=(1, 1), name=name+'_conv2_2')
    merged = mx.symbol.Concat(*[branch0, branch1, branch2])
    projected = ConvFactory(merged, input_num_channels, (1, 1), with_act=False, name=name+'_conv_out')
    net = net + scale * projected
    if not with_act:
        return net
    return mx.symbol.Activation(data=net, act_type=act_type, attr=mirror_attr)
def conv_inception(data, num_filter, stride, dim_match, name, binarize, dcn, dilate, **kwargs):
    """Inception-style unit; binarization is not supported (asserted).

    A ``block35`` is used only when the unit keeps both resolution and channel
    count (``stride[0] <= 1`` and ``dim_match``); every other configuration is
    delegated to ``conv_resnet``.
    """
    assert not binarize
    keeps_shape = dim_match and not stride[0] > 1
    if keeps_shape:
        return block35(data, num_filter, name=name+'_block35')
    return conv_resnet(data, num_filter, stride, dim_match, name, binarize, dcn, dilate, **kwargs)
def conv_cab(data, num_filter, stride, dim_match, name, binarize, dcn, dilate, **kwargs):
    """CAB-based unit.

    A ``CAB`` of depth 4 with group 1 is used only when the unit keeps both
    resolution and channel count (``stride[0] <= 1`` and ``dim_match``); every
    other configuration is delegated to ``conv_hpm``.
    """
    keeps_shape = dim_match and not stride[0] > 1
    if not keeps_shape:
        return conv_hpm(data, num_filter, stride, dim_match, name, binarize, dcn, dilate, **kwargs)
    return CAB(data, num_filter, 1, 4, workspace, name, dilate, 1).get()
def conv_block(data, num_filter, stride, dim_match, name, binarize, dcn, dilate):
    """Build one residual unit of the type selected by ``config.net_block``.

    Args:
        data: input symbol.
        num_filter: number of output channels.
        stride: stride tuple forwarded to the unit builder.
        dim_match: whether the input already has the output's shape.
        name: prefix for all layer names.
        binarize: forwarded to the unit builder.
        dcn: forwarded to the unit builder.
        dilate: forwarded to the unit builder.

    Returns:
        The output symbol of the selected unit.

    Raises:
        ValueError: if ``config.net_block`` is not one of 'resnet',
            'inception', 'hpm' or 'cab'. Previously this case fell through
            and returned ``None``, which only failed later inside MXNet with
            an unrelated error.
    """
    builders = {
        'resnet': conv_resnet,
        'inception': conv_inception,
        'hpm': conv_hpm,
        'cab': conv_cab,
    }
    builder = builders.get(config.net_block)
    if builder is None:
        raise ValueError("unknown config.net_block %r; expected one of %s"
                         % (config.net_block, sorted(builders)))
    return builder(data, num_filter, stride, dim_match, name, binarize, dcn, dilate)
def hourglass(data, nFilters, nModules, n, workspace, name, binarize, dcn):
    """Recursively build an hourglass module of depth ``n``.

    The upper path applies ``nModules`` units at the input resolution. The
    lower path max-pools by 2, applies ``nModules`` units, recurses (or, at
    ``n == 1``, applies ``nModules`` more units), applies ``nModules`` units
    again and is upsampled by 2 with nearest-neighbour interpolation. The two
    paths are summed.

    ``dcn`` is only forwarded to the recursive call; the units built here
    always receive ``False`` for it.
    """
    factor = 2
    unit_dcn = False

    def run_units(sym, template):
        # Chain nModules shape-preserving units, named from `template`.
        for idx in range(nModules):
            sym = conv_block(sym, nFilters, (1,1), True, template%(name,idx), binarize, unit_dcn, 1)
        return sym

    upper = run_units(data, "%s_up1_%d")
    lower = mx.sym.Pooling(data=data, kernel=(factor, factor), stride=(factor,factor), pad=(0,0), pool_type='max')
    lower = run_units(lower, "%s_low1_%d")
    if n > 1:
        lower = hourglass(lower, nFilters, nModules, n-1, workspace, "%s_%d"%(name, n-1), binarize, dcn)
    else:
        lower = run_units(lower, "%s_low2_%d")
    lower = run_units(lower, "%s_low3_%d")
    restored = mx.symbol.UpSampling(lower, scale=factor, sample_type='nearest', workspace=512,
                                    name='%s_upsampling_%s'%(name,n), num_args=1)
    return mx.symbol.add_n(upper, restored)
class STA:
    """Recursive builder for a triangular grid of feature nodes over ``data``.

    Nodes are addressed by ``(w, h)`` and memoised in ``sym_map``; ``get()``
    returns the symbol of node ``(1, 1)``. Every node is a ``(symbol, size)``
    tuple.
    NOTE(review): ``size`` starts at 64, is halved after each max-pooling and
    is used as an average-pooling kernel, so it presumably tracks the spatial
    size of the feature map - confirm the hard-coded 64 against the caller.
    """
    def __init__(self, data, nFilters, nModules, n, workspace, name):
        self.data = data
        self.nFilters = nFilters
        self.nModules = nModules
        self.n = n
        self.workspace = workspace
        self.name = name
        # (w, h) -> (symbol, size) cache so every node is built once.
        self.sym_map = {}
    def get_conv(self, data, name, dilate=1, group=1):
        """Apply a depth-4 ``CAB`` block to ``data`` and return its output symbol."""
        cab = CAB(data, self.nFilters, self.nModules, 4, self.workspace, name, dilate, group)
        return cab.get()
    def get_output(self, w, h):
        """Return the memoised ``(symbol, size)`` tuple for node ``(w, h)``.

        Row ``h == n`` is a chain starting at the input (node ``(n, n)``):
        each step to the left is CAB -> 2x2 max-pool -> CAB. Every other node
        averages its diagonal neighbour ``(w+1, h+1)`` with a transformed
        version of the node below it ``(w, h+1)``.
        """
        #print(w,h)
        # Bounds come from config.net_n, not self.n.
        assert w>=1 and w<=config.net_n+1
        assert h>=1 and h<=config.net_n+1
        s = 2
        # Local value shadows the module-level bn_mom inside this method.
        bn_mom = 0.9
        key = (w,h)
        if key in self.sym_map:
            return self.sym_map[key]
        ret = None
        if h==self.n:
            if w==self.n:
                # Root node: the input itself.
                ret = self.data,64
            else:
                x = self.get_output(w+1, h)
                body = self.get_conv(x[0], "%s_w%d_h%d_1"%(self.name, w, h))
                body = mx.sym.Pooling(data=body, kernel=(s, s), stride=(s,s), pad=(0,0), pool_type='max')
                body = self.get_conv(body, "%s_w%d_h%d_2"%(self.name, w, h))
                ret = body, x[1]//2
        else:
            x = self.get_output(w+1, h+1)
            y = self.get_output(w, h+1)
            # HC is assigned but never read (its only use is commented out below).
            HC = False
            if h%2==1 and h!=w:
                # Per-channel 3x3 conv (group count equals nFilters) on the diagonal branch.
                xbody = lin3(x[0], self.nFilters, self.workspace, "%s_w%d_h%d_x"%(self.name, w, h), 3, self.nFilters, 1)
                HC = True
                #xbody = x[0]
            else:
                xbody = x[0]
            if x[1]//y[1]==2:
                # The lower node has half the size of the diagonal one: upsample it by 2.
                if w>1:
                    ybody = mx.symbol.Deconvolution(data=y[0], num_filter=self.nFilters, kernel=(s,s),
                        stride=(s, s),
                        name='%s_upsampling_w%d_h%d'%(self.name,w, h),
                        attr={'lr_mult': '1.0'}, workspace=self.workspace)
                    ybody = mx.sym.BatchNorm(data=ybody, fix_gamma=False, momentum=bn_mom, eps=2e-5, name="%s_w%d_h%d_y_bn"%(self.name, w, h))
                    ybody = Act(data=ybody, act_type='relu', name="%s_w%d_h%d_y_act"%(self.name, w, h))
                else:
                    # h >= 1 is guaranteed by the assertion above, so the
                    # Deconvolution variant in the else branch is never reached.
                    if h>=1:
                        ybody = mx.symbol.UpSampling(y[0], scale=s, sample_type='nearest', workspace=512, name='%s_upsampling_w%d_h%d'%(self.name,w, h), num_args=1)
                        ybody = self.get_conv(ybody, "%s_w%d_h%d_4"%(self.name, w, h))
                    else:
                        ybody = mx.symbol.Deconvolution(data=y[0], num_filter=self.nFilters, kernel=(s,s),
                            stride=(s, s),
                            name='%s_upsampling_w%d_h%d'%(self.name,w, h),
                            attr={'lr_mult': '1.0'}, workspace=self.workspace)
                        ybody = mx.sym.BatchNorm(data=ybody, fix_gamma=False, momentum=bn_mom, eps=2e-5, name="%s_w%d_h%d_y_bn"%(self.name, w, h))
                        ybody = Act(data=ybody, act_type='relu', name="%s_w%d_h%d_y_act"%(self.name, w, h))
                        ybody = Conv(data=ybody, num_filter=self.nFilters, kernel=(3,3), stride=(1,1), pad=(1,1),
                            no_bias=True, name="%s_w%d_h%d_y_conv2"%(self.name, w, h), workspace=self.workspace)
                        ybody = mx.sym.BatchNorm(data=ybody, fix_gamma=False, momentum=bn_mom, eps=2e-5, name="%s_w%d_h%d_y_bn2"%(self.name, w, h))
                        ybody = Act(data=ybody, act_type='relu', name="%s_w%d_h%d_y_act2"%(self.name, w, h))
            else:
                # Same size on both branches: only a CAB block on the lower node.
                ybody = self.get_conv(y[0], "%s_w%d_h%d_5"%(self.name, w, h))
            #if not HC:
            if config.net_sta==2 and h==3 and w==2:
                # Variant 2, node (2, 3) only: pool the right neighbour with a
                # kernel equal to its size value and use the result to scale
                # the averaged branches.
                z = self.get_output(w+1, h)
                zbody = z[0]
                zbody = mx.sym.Pooling(data=zbody, kernel=(z[1], z[1]), stride=(z[1],z[1]), pad=(0,0), pool_type='avg')
                body = xbody+ybody
                body = body/2
                body = mx.sym.broadcast_mul(body, zbody)
            else: #sta==1
                body = xbody+ybody
                body = body/2
            ret = body, x[1]
        assert ret is not None
        self.sym_map[key] = ret
        return ret
    def get(self):
        """Return the symbol of node ``(1, 1)``."""
        return self.get_output(1, 1)[0]
class SymCoherent:
    """Split a batch into two halves and bring the first into the second's frame.

    ``flip_order`` is a permutation of the 68 channel indices; it is applied
    to the first half after that half has been flipped along axis 3.
    NOTE(review): presumably the left/right index swap of a 68-point landmark
    layout - confirm against the dataset definition.
    """

    def __init__(self, per_batch_size):
        self.per_batch_size = per_batch_size
        order = []
        order += range(16, -1, -1)   # 16 .. 0
        order += range(26, 16, -1)   # 26 .. 17
        order += [27, 28, 29, 30]    # unchanged indices
        order += range(35, 30, -1)   # 35 .. 31
        order += [45, 44, 43, 42, 47, 46, 39, 38, 37, 36, 41, 40]
        order += range(54, 47, -1)   # 54 .. 48
        order += range(59, 54, -1)   # 59 .. 55
        order += range(64, 59, -1)   # 64 .. 60
        order += [67, 66, 65]
        self.flip_order = order

    def get(self, data):
        """Return ``(flipped_first_half, second_half)`` of ``data``.

        ``data`` is split along axis 0 at ``per_batch_size // 2``. The first
        half is flipped along axis 3 and its axis-1 channels are re-ordered
        according to ``flip_order``; the second half is returned unchanged.
        """
        half = self.per_batch_size//2
        first = mx.sym.slice_axis(data, axis=0, begin=0, end=half)
        second = mx.sym.slice_axis(data, axis=0, begin=half, end=half*2)
        first = mx.sym.flip(first, axis=3)
        # Re-order channels one slice at a time, then stitch them back together.
        channels = [mx.sym.slice_axis(first, axis=1, begin=idx, end=idx+1)
                    for idx in self.flip_order]
        first = mx.sym.concat(*channels, dim=1)
        return first, second
def l2_loss(x, y):
    """Return the mean smooth-L1 (scalar=1.0) of ``x - y``.

    Despite the name, the squared-error form is not used here.
    """
    residual = mx.symbol.smooth_l1(x-y, scalar=1.0)
    return mx.symbol.mean(residual)
def ce_loss(x, y):
    """Cross-entropy between ``y`` and a softmax of ``x`` taken over axes 2 and 3.

    The per-(axis 0, axis 1) maximum is subtracted before exponentiation for
    numerical stability. The result is the mean of ``-y * log(softmax(x))``
    over all elements.
    NOTE(review): axes 2 and 3 are presumably the spatial axes of an NCHW
    heatmap - confirm.
    """
    peak = mx.sym.max(x, axis=[2,3], keepdims=True)
    shifted = mx.sym.broadcast_minus(x, peak)
    exps = mx.sym.exp(shifted)
    denom = mx.sym.sum(exps, axis=[2,3], keepdims=True)
    probs = mx.sym.broadcast_div(exps, denom)
    weighted = mx.sym.log(probs)*y*-1.0
    return mx.symbol.mean(weighted)
def get_symbol(num_classes):
    """Build the full training symbol of the stacked heatmap network.

    All hyper-parameters are read from ``config``. The input variable is
    named 'data' and the label variable 'softmax_label'.

    Args:
        num_classes: number of heatmap channels produced per stack.

    Returns:
        An ``mx.symbol.Group`` holding, in order: one ``MakeLoss`` of the
        cross-entropy loss per stack; when ``config.net_coherent > 0``, one
        ``MakeLoss`` of the summed coherent losses; and finally the
        gradient-blocked heatmap of the last stack.
    """
    # Channel widths of the stem, scaled by the configured multiplier.
    m = config.multiplier
    sFilters = max(int(64*m), 32)
    mFilters = max(int(128*m), 32)
    nFilters = int(256*m)
    nModules = 1
    nStacks = config.net_stacks
    binarize = config.net_binarize
    input_size = config.input_img_size
    label_size = config.output_label_size
    # The next four locals are only printed; the code below reads config directly.
    use_coherent = config.net_coherent
    use_STA = config.net_sta
    N = config.net_n
    DCN = config.net_dcn
    per_batch_size = config.per_batch_size
    print('binarize', binarize)
    print('use_coherent', use_coherent)
    print('use_STA', use_STA)
    print('use_N', N)
    print('use_DCN', DCN)
    print('per_batch_size', per_batch_size)
    #assert(label_size==64 or label_size==32)
    #assert(input_size==128 or input_size==256)
    coherentor = SymCoherent(per_batch_size)
    # Ratio between input and label resolution; selects the stem convolution.
    D = input_size // label_size
    print(input_size, label_size, D)
    data = mx.sym.Variable(name='data')
    # Normalize the input: (x - 127.5) / 128.
    data = data-127.5
    data = data*0.0078125
    gt_label = mx.symbol.Variable(name='softmax_label')
    losses = []
    closses = []
    ref_label = gt_label
    if D==4:
        # Stride-2 7x7 stem; with the 2x2 pooling below this downsamples by 4.
        body = Conv(data=data, num_filter=sFilters, kernel=(7, 7), stride=(2,2), pad=(3, 3),
            no_bias=True, name="conv0", workspace=workspace)
    else:
        # Stride-1 3x3 stem; only the 2x2 pooling below downsamples.
        body = Conv(data=data, num_filter=sFilters, kernel=(3, 3), stride=(1,1), pad=(1, 1),
            no_bias=True, name="conv0", workspace=workspace)
    body = mx.sym.BatchNorm(data=body, fix_gamma=False, eps=2e-5, momentum=bn_mom, name='bn0')
    body = Act(data=body, act_type='relu', name='relu0')
    dcn = False
    body = conv_block(body, mFilters, (1,1), sFilters==mFilters, 'res0', False, dcn, 1)
    body = mx.sym.Pooling(data=body, kernel=(2, 2), stride=(2,2), pad=(0,0), pool_type='max')
    body = conv_block(body, mFilters, (1,1), True, 'res1', False, dcn, 1) #TODO
    body = conv_block(body, nFilters, (1,1), mFilters==nFilters, 'res2', binarize, dcn, 1) #binarize=True?
    heatmap = None
    for i in range(nStacks):
        # Kept for the inter-stack residual connection built at the end of the iteration.
        shortcut = body
        if config.net_sta>0:
            sta = STA(body, nFilters, nModules, config.net_n+1, workspace, 'sta%d'%(i))
            body = sta.get()
        else:
            body = hourglass(body, nFilters, nModules, config.net_n, workspace, 'stack%d_hg'%(i), binarize, dcn)
        for j in range(nModules):
            body = conv_block(body, nFilters, (1,1), True, 'stack%d_unit%d'%(i,j), binarize, dcn, 1)
        # net_dcn >= 2 switches the per-stack head layers to deformable convolutions.
        _dcn = True if config.net_dcn>=2 else False
        ll = ConvFactory(body, nFilters, (1,1), dcn = _dcn, name='stack%d_ll'%(i))
        # Only the last stack's output is named plain "heatmap".
        _name = "heatmap%d"%(i) if i<nStacks-1 else "heatmap"
        _dcn = True if config.net_dcn>=2 else False
        if not _dcn:
            out = Conv(data=ll, num_filter=num_classes, kernel=(1, 1), stride=(1,1), pad=(0,0),
                name=_name, workspace=workspace)
        else:
            out_offset = mx.symbol.Convolution(name=_name+'_offset', data = ll,
                num_filter=18, pad=(1, 1), kernel=(3, 3), stride=(1, 1))
            out = mx.contrib.symbol.DeformableConvolution(name=_name, data=ll, offset=out_offset,
                num_filter=num_classes, pad=(1,1), kernel=(3, 3), num_deformable_group=1, stride=(1, 1), dilate=(1, 1), no_bias=False)
            #out = Conv(data=ll, num_filter=num_classes, kernel=(3,3), stride=(1,1), pad=(1,1),
            #    name=_name, workspace=workspace)
        if i==nStacks-1:
            heatmap = out
        # Every stack is supervised against the same label.
        loss = ce_loss(out, ref_label)
        #loss = loss/nStacks
        #loss = l2_loss(out, ref_label)
        losses.append(loss)
        if config.net_coherent>0:
            # Coherent loss between the two batch halves returned by SymCoherent.get.
            ux, dx = coherentor.get(out)
            closs = l2_loss(ux, dx)
            closs = closs/nStacks
            closses.append(closs)
        if i<nStacks-1:
            # Feed the next stack: stack input + projected features + re-projected heatmap.
            ll2 = Conv(data=ll, num_filter=nFilters, kernel=(1, 1), stride=(1,1), pad=(0,0),
                name="stack%d_ll2"%(i), workspace=workspace)
            out2 = Conv(data=out, num_filter=nFilters, kernel=(1, 1), stride=(1,1), pad=(0,0),
                name="stack%d_out2"%(i), workspace=workspace)
            body = mx.symbol.add_n(shortcut, ll2, out2)
            # net_dcn of 1 or 3 adds a deformable convolution between stacks.
            _dcn = True if (config.net_dcn==1 or config.net_dcn==3) else False
            if _dcn:
                _name = "stack%d_out3" % (i)
                out3_offset = mx.symbol.Convolution(name=_name+'_offset', data = body,
                    num_filter=18, pad=(1, 1), kernel=(3, 3), stride=(1, 1))
                out3 = mx.contrib.symbol.DeformableConvolution(name=_name, data=body, offset=out3_offset,
                    num_filter=nFilters, pad=(1,1), kernel=(3, 3), num_deformable_group=1, stride=(1, 1), dilate=(1, 1), no_bias=False)
                body = out3
    # The prediction output carries no gradient.
    pred = mx.symbol.BlockGrad(heatmap)
    #loss = mx.symbol.add_n(*losses)
    #loss = mx.symbol.MakeLoss(loss)
    #syms = [loss]
    syms = []
    for loss in losses:
        loss = mx.symbol.MakeLoss(loss)
        syms.append(loss)
    if len(closses)>0:
        coherent_weight = 0.0001
        closs = mx.symbol.add_n(*closses)
        closs = mx.symbol.MakeLoss(closs, grad_scale = coherent_weight)
        syms.append(closs)
    syms.append(pred)
    sym = mx.symbol.Group( syms )
    return sym
def init_weights(sym, data_shape_dict):
    """Create explicit initial values for selected parameters of ``sym``.

    Shapes are inferred from ``data_shape_dict``. Parameters are chosen by
    name:

    * names ending in 'offset_weight' / 'offset_bias' -> zeros;
    * names starting with 'fc6_' -> normal(0, 0.01) for '_weight', zeros
      for '_bias';
    * names containing 'upsampling' -> bilinear-initialised array.

    Args:
        sym: symbol providing ``list_arguments``, ``list_auxiliary_states``
            and ``infer_shape``.
        data_shape_dict: keyword arguments for ``sym.infer_shape``.

    Returns:
        ``(arg_params, aux_params)``; ``aux_params`` is always empty.
    """
    arg_names = sym.list_arguments()
    aux_names = sym.list_auxiliary_states()
    arg_shapes, _, aux_shapes = sym.infer_shape(**data_shape_dict)
    arg_shape_dict = dict(zip(arg_names, arg_shapes))
    aux_shape_dict = dict(zip(aux_names, aux_shapes))  # built but not used further
    arg_params, aux_params = {}, {}
    for key, shape in arg_shape_dict.items():
        if key.endswith(('offset_weight', 'offset_bias')):
            print('initializing', key)
            arg_params[key] = mx.nd.zeros(shape=shape)
        elif key.startswith('fc6_'):
            if key.endswith('_weight'):
                print('initializing', key)
                arg_params[key] = mx.random.normal(0, 0.01, shape=shape)
            elif key.endswith('_bias'):
                print('initializing', key)
                arg_params[key] = mx.nd.zeros(shape=shape)
        elif 'upsampling' in key:
            print('initializing upsampling_weight', key)
            weight = mx.nd.zeros(shape=shape)
            arg_params[key] = weight
            # Fills `weight` in place with a bilinear interpolation kernel.
            mx.init.Initializer()._init_bilinear(key, weight)
    return arg_params, aux_params