# Mirror of https://github.com/deepinsight/insightface.git
# Synced 2026-05-16 13:46:15 +00:00
# 754 lines, 27 KiB, Python
import torch
|
|
import torch.nn as nn
|
|
from torch.nn import Parameter
|
|
import torch.nn.functional as F
|
|
from collections import OrderedDict
|
|
from . import deeplab_xception, gcn
|
|
from torch.nn.parameter import Parameter
|
|
|
|
|
|
#######################
|
|
# base model
|
|
#######################
|
|
|
|
|
|
class deeplab_xception_transfer_basemodel(deeplab_xception.DeepLabv3_plus):
    """DeepLabv3+ (Xception) network extended with a target-graph branch.

    The decoder output is passed to ``gcn.Featuremaps_to_Graph``, through three
    ``gcn.GraphConvolution`` layers and back through ``gcn.Graph_to_Featuremaps``.
    The result is added to a 1x1-conv + ReLU skip path and fed to the inherited
    ``semantic`` head.
    """

    def __init__(
        self,
        nInputChannels=3,
        n_classes=7,
        os=16,
        input_channels=256,
        hidden_layers=128,
        out_channels=256,
    ):
        """Build the backbone and the target-graph modules.

        Args:
            nInputChannels: Forwarded to ``DeepLabv3_plus.__init__``.
            n_classes: Forwarded to ``DeepLabv3_plus.__init__``; also used as the
                ``nodes`` argument of the target graph modules.
            os: Forwarded to ``DeepLabv3_plus.__init__``.
            input_channels: Channel count given to the feature-map/graph
                converters and used for both sides of the skip convolution.
            hidden_layers: Input and output feature size of each graph
                convolution.
            out_channels: ``output_channels`` of the graph-to-feature-map module.
        """
        super(deeplab_xception_transfer_basemodel, self).__init__(
            nInputChannels=nInputChannels,
            n_classes=n_classes,
            os=os,
        )

        # No source-graph modules are created here; subclasses such as
        # deeplab_xception_transfer_projection add them.

        ### target graph
        self.target_featuremap_2_graph = gcn.Featuremaps_to_Graph(
            input_channels=input_channels,
            hidden_layers=hidden_layers,
            nodes=n_classes,
        )
        self.target_graph_conv1 = gcn.GraphConvolution(hidden_layers, hidden_layers)
        self.target_graph_conv2 = gcn.GraphConvolution(hidden_layers, hidden_layers)
        self.target_graph_conv3 = gcn.GraphConvolution(hidden_layers, hidden_layers)

        self.target_graph_2_fea = gcn.Graph_to_Featuremaps(
            input_channels=input_channels,
            output_channels=out_channels,
            hidden_layers=hidden_layers,
            nodes=n_classes,
        )
        self.target_skip_conv = nn.Sequential(
            nn.Conv2d(input_channels, input_channels, kernel_size=1),
            nn.ReLU(True),
        )

    def load_source_model(self, state_dict):
        """Copy matching entries of a checkpoint into this model, best effort.

        Every key has ``"module."`` removed. A key that contains ``"graph"`` but
        none of ``"source"``, ``"target"``, ``"fc_graph"`` or
        ``"transpose_graph"`` is renamed: ``"featuremap_2_graph"`` becomes
        ``"source_featuremap_2_graph"``, otherwise ``"graph"`` becomes
        ``"source_graph"``. Keys unknown to the model are reported and skipped
        (silently when they contain ``"num_batch"``); entries that cannot be
        copied are reported and skipped. Finally, model keys that the checkpoint
        never mentioned are printed.

        Args:
            state_dict: Mapping from parameter name to tensor or ``Parameter``.
        """
        own_state = self.state_dict()
        excluded_tags = ("source", "target", "fc_graph", "transpose_graph")
        # Names seen in the checkpoint (after renaming). Recorded before the
        # existence check, exactly like the original bookkeeping dict.
        seen_keys = set()

        for name, param in state_dict.items():
            name = name.replace("module.", "")
            if "graph" in name and not any(tag in name for tag in excluded_tags):
                if "featuremap_2_graph" in name:
                    name = name.replace(
                        "featuremap_2_graph", "source_featuremap_2_graph"
                    )
                else:
                    name = name.replace("graph", "source_graph")
            seen_keys.add(name)

            if name not in own_state:
                if "num_batch" not in name:
                    print('unexpected key "{}" in state_dict'.format(name))
                continue

            if isinstance(param, Parameter):
                # backwards compatibility for serialized parameters
                param = param.data

            try:
                own_state[name].copy_(param)
            except Exception:
                # Best-effort load: report the offending entry and keep going.
                # Unlike a bare ``except``, this lets KeyboardInterrupt and
                # SystemExit propagate.
                print(
                    "While copying the parameter named {}, whose dimensions in the model are"
                    " {} and whose dimensions in the checkpoint are {}, ...".format(
                        name, own_state[name].size(), param.size()
                    )
                )
                continue

        missing = set(own_state.keys()) - seen_keys
        if missing:
            print('missing keys in state_dict: "{}"'.format(missing))

    def get_target_parameter(self):
        """Split parameters by name.

        Returns:
            A ``(selected, other)`` pair of lists: ``selected`` holds parameters
            whose name contains ``"target"`` or ``"semantic"``, ``other`` holds
            the rest. Both keep ``named_parameters()`` order.
        """
        selected = []
        other = []
        for name, param in self.named_parameters():
            if "target" in name or "semantic" in name:
                selected.append(param)
            else:
                other.append(param)
        return selected, other

    def get_semantic_parameter(self):
        """Return the parameters whose name contains ``"semantic"``."""
        return [
            param for name, param in self.named_parameters() if "semantic" in name
        ]

    def get_source_parameter(self):
        """Return the parameters whose name contains ``"source"``."""
        return [param for name, param in self.named_parameters() if "source" in name]

    def forward(self, input, adj1_target=None, adj2_source=None, adj3_transfer=None):
        """Run the backbone, the target-graph branch and the semantic head.

        Args:
            input: Image batch; its size from dimension 2 onwards is the size
                the output is resized to.
            adj1_target: Passed as ``adj`` to each target graph convolution.
            adj2_source: Unused in this class.
            adj3_transfer: Unused in this class.

        Returns:
            The output of ``self.semantic`` bilinearly resized to
            ``input.size()[2:]``.
        """
        x, low_level_features = self.xception_features(input)

        # ASPP branches plus the globally pooled branch, concatenated on dim 1.
        x1 = self.aspp1(x)
        x2 = self.aspp2(x)
        x3 = self.aspp3(x)
        x4 = self.aspp4(x)
        x5 = self.global_avg_pool(x)
        # F.interpolate is the documented replacement for deprecated F.upsample.
        x5 = F.interpolate(
            x5, size=x4.size()[2:], mode="bilinear", align_corners=True
        )
        x = torch.cat((x1, x2, x3, x4, x5), dim=1)

        x = self.concat_projection_conv1(x)
        x = self.concat_projection_bn1(x)
        x = self.relu(x)
        x = F.interpolate(
            x, size=low_level_features.size()[2:], mode="bilinear", align_corners=True
        )

        low_level_features = self.feature_projection_conv1(low_level_features)
        low_level_features = self.feature_projection_bn1(low_level_features)
        low_level_features = self.relu(low_level_features)

        x = torch.cat((x, low_level_features), dim=1)
        x = self.decoder(x)

        ### add graph
        # target graph
        graph = self.target_featuremap_2_graph(x)
        graph = self.target_graph_conv1.forward(graph, adj=adj1_target, relu=True)
        graph = self.target_graph_conv2.forward(graph, adj=adj1_target, relu=True)
        graph = self.target_graph_conv3.forward(graph, adj=adj1_target, relu=True)

        # Back to a feature map, then residual-style sum with the skip path.
        graph = self.target_graph_2_fea.forward(graph, x)
        x = self.target_skip_conv(x)
        x = x + graph

        ###
        x = self.semantic(x)
        x = F.interpolate(
            x, size=input.size()[2:], mode="bilinear", align_corners=True
        )

        return x
|
|
|
|
|
|
class deeplab_xception_transfer_basemodel_savememory(deeplab_xception.DeepLabv3_plus):
    """DeepLabv3+ (Xception) network with a target-graph branch.

    Same layout as ``deeplab_xception_transfer_basemodel`` except that the graph
    is converted back to a feature map with ``gcn.Graph_to_Featuremaps_savemem``.
    """

    def __init__(
        self,
        nInputChannels=3,
        n_classes=7,
        os=16,
        input_channels=256,
        hidden_layers=128,
        out_channels=256,
    ):
        """Build the backbone and the target-graph modules.

        Args:
            nInputChannels: Forwarded to ``DeepLabv3_plus.__init__``.
            n_classes: Forwarded to ``DeepLabv3_plus.__init__``; also used as the
                ``nodes`` argument of the target graph modules.
            os: Forwarded to ``DeepLabv3_plus.__init__``.
            input_channels: Channel count given to the feature-map/graph
                converters and used for both sides of the skip convolution.
            hidden_layers: Input and output feature size of each graph
                convolution.
            out_channels: ``output_channels`` of the graph-to-feature-map module.
        """
        super(deeplab_xception_transfer_basemodel_savememory, self).__init__(
            nInputChannels=nInputChannels,
            n_classes=n_classes,
            os=os,
        )

        ### target graph (no source-graph modules are created in this class)
        self.target_featuremap_2_graph = gcn.Featuremaps_to_Graph(
            input_channels=input_channels,
            hidden_layers=hidden_layers,
            nodes=n_classes,
        )
        self.target_graph_conv1 = gcn.GraphConvolution(hidden_layers, hidden_layers)
        self.target_graph_conv2 = gcn.GraphConvolution(hidden_layers, hidden_layers)
        self.target_graph_conv3 = gcn.GraphConvolution(hidden_layers, hidden_layers)

        self.target_graph_2_fea = gcn.Graph_to_Featuremaps_savemem(
            input_channels=input_channels,
            output_channels=out_channels,
            hidden_layers=hidden_layers,
            nodes=n_classes,
        )
        self.target_skip_conv = nn.Sequential(
            nn.Conv2d(input_channels, input_channels, kernel_size=1),
            nn.ReLU(True),
        )

    def load_source_model(self, state_dict):
        """Copy matching entries of a checkpoint into this model, best effort.

        Every key has ``"module."`` removed. A key that contains ``"graph"`` but
        none of ``"source"``, ``"target"``, ``"fc_graph"`` or
        ``"transpose_graph"`` is renamed: ``"featuremap_2_graph"`` becomes
        ``"source_featuremap_2_graph"``, otherwise ``"graph"`` becomes
        ``"source_graph"``. Keys unknown to the model are reported and skipped
        (silently when they contain ``"num_batch"``); entries that cannot be
        copied are reported and skipped. Finally, model keys that the checkpoint
        never mentioned are printed.

        Args:
            state_dict: Mapping from parameter name to tensor or ``Parameter``.
        """
        own_state = self.state_dict()
        excluded_tags = ("source", "target", "fc_graph", "transpose_graph")
        # Names seen in the checkpoint (after renaming). Recorded before the
        # existence check, exactly like the original bookkeeping dict.
        seen_keys = set()

        for name, param in state_dict.items():
            name = name.replace("module.", "")
            if "graph" in name and not any(tag in name for tag in excluded_tags):
                if "featuremap_2_graph" in name:
                    name = name.replace(
                        "featuremap_2_graph", "source_featuremap_2_graph"
                    )
                else:
                    name = name.replace("graph", "source_graph")
            seen_keys.add(name)

            if name not in own_state:
                if "num_batch" not in name:
                    print('unexpected key "{}" in state_dict'.format(name))
                continue

            if isinstance(param, Parameter):
                # backwards compatibility for serialized parameters
                param = param.data

            try:
                own_state[name].copy_(param)
            except Exception:
                # Best-effort load: report the offending entry and keep going.
                # Unlike a bare ``except``, this lets KeyboardInterrupt and
                # SystemExit propagate.
                print(
                    "While copying the parameter named {}, whose dimensions in the model are"
                    " {} and whose dimensions in the checkpoint are {}, ...".format(
                        name, own_state[name].size(), param.size()
                    )
                )
                continue

        missing = set(own_state.keys()) - seen_keys
        if missing:
            print('missing keys in state_dict: "{}"'.format(missing))

    def get_target_parameter(self):
        """Split parameters by name.

        Returns:
            A ``(selected, other)`` pair of lists: ``selected`` holds parameters
            whose name contains ``"target"`` or ``"semantic"``, ``other`` holds
            the rest. Both keep ``named_parameters()`` order.
        """
        selected = []
        other = []
        for name, param in self.named_parameters():
            if "target" in name or "semantic" in name:
                selected.append(param)
            else:
                other.append(param)
        return selected, other

    def get_semantic_parameter(self):
        """Return the parameters whose name contains ``"semantic"``."""
        return [
            param for name, param in self.named_parameters() if "semantic" in name
        ]

    def get_source_parameter(self):
        """Return the parameters whose name contains ``"source"``."""
        return [param for name, param in self.named_parameters() if "source" in name]

    def forward(self, input, adj1_target=None, adj2_source=None, adj3_transfer=None):
        """Run the backbone, the target-graph branch and the semantic head.

        Args:
            input: Image batch; its size from dimension 2 onwards is the size
                the output is resized to.
            adj1_target: Passed as ``adj`` to each target graph convolution.
            adj2_source: Unused in this class.
            adj3_transfer: Unused in this class.

        Returns:
            The output of ``self.semantic`` bilinearly resized to
            ``input.size()[2:]``.
        """
        x, low_level_features = self.xception_features(input)

        # ASPP branches plus the globally pooled branch, concatenated on dim 1.
        x1 = self.aspp1(x)
        x2 = self.aspp2(x)
        x3 = self.aspp3(x)
        x4 = self.aspp4(x)
        x5 = self.global_avg_pool(x)
        # F.interpolate is the documented replacement for deprecated F.upsample.
        x5 = F.interpolate(
            x5, size=x4.size()[2:], mode="bilinear", align_corners=True
        )
        x = torch.cat((x1, x2, x3, x4, x5), dim=1)

        x = self.concat_projection_conv1(x)
        x = self.concat_projection_bn1(x)
        x = self.relu(x)
        x = F.interpolate(
            x, size=low_level_features.size()[2:], mode="bilinear", align_corners=True
        )

        low_level_features = self.feature_projection_conv1(low_level_features)
        low_level_features = self.feature_projection_bn1(low_level_features)
        low_level_features = self.relu(low_level_features)

        x = torch.cat((x, low_level_features), dim=1)
        x = self.decoder(x)

        ### add graph
        # target graph
        graph = self.target_featuremap_2_graph(x)
        graph = self.target_graph_conv1.forward(graph, adj=adj1_target, relu=True)
        graph = self.target_graph_conv2.forward(graph, adj=adj1_target, relu=True)
        graph = self.target_graph_conv3.forward(graph, adj=adj1_target, relu=True)

        # Back to a feature map, then residual-style sum with the skip path.
        graph = self.target_graph_2_fea.forward(graph, x)
        x = self.target_skip_conv(x)
        x = x + graph

        ###
        x = self.semantic(x)
        x = F.interpolate(
            x, size=input.size()[2:], mode="bilinear", align_corners=True
        )

        return x
|
|
|
|
|
|
#######################
|
|
# transfer model
|
|
#######################
|
|
|
|
|
|
class deeplab_xception_transfer_projection(deeplab_xception_transfer_basemodel):
    """Target-graph model with an additional source graph transferred onto it.

    On top of the base model this adds a source feature-map-to-graph module,
    source graph convolutions, a ``gcn.Graph_trans`` module mapping
    ``source_classes`` nodes to ``n_classes`` nodes, and ``fc_graph``, which
    reduces the three concatenated graphs back to ``hidden_layers`` features.
    """

    def __init__(
        self,
        nInputChannels=3,
        n_classes=7,
        os=16,
        input_channels=256,
        hidden_layers=128,
        out_channels=256,
        transfer_graph=None,
        source_classes=20,
    ):
        """Build the base model and the source/transfer graph modules.

        Args:
            nInputChannels: Forwarded to the base class.
            n_classes: Forwarded to the base class; also ``end_nodes`` of the
                transfer module.
            os: Forwarded to the base class.
            input_channels: Forwarded to the base class and to the source
                feature-map-to-graph module.
            hidden_layers: Feature size of the graph modules; ``fc_graph`` maps
                ``hidden_layers * 3`` features to ``hidden_layers``.
            out_channels: Forwarded to the base class.
            transfer_graph: Passed as ``adj`` to ``gcn.Graph_trans``.
            source_classes: Node count of the source graph; also ``begin_nodes``
                of the transfer module.
        """
        super(deeplab_xception_transfer_projection, self).__init__(
            nInputChannels=nInputChannels,
            n_classes=n_classes,
            os=os,
            input_channels=input_channels,
            hidden_layers=hidden_layers,
            out_channels=out_channels,
        )
        self.source_featuremap_2_graph = gcn.Featuremaps_to_Graph(
            input_channels=input_channels,
            hidden_layers=hidden_layers,
            nodes=source_classes,
        )
        self.source_graph_conv1 = gcn.GraphConvolution(hidden_layers, hidden_layers)
        self.source_graph_conv2 = gcn.GraphConvolution(hidden_layers, hidden_layers)
        self.source_graph_conv3 = gcn.GraphConvolution(hidden_layers, hidden_layers)
        self.transpose_graph = gcn.Graph_trans(
            in_features=hidden_layers,
            out_features=hidden_layers,
            adj=transfer_graph,
            begin_nodes=source_classes,
            end_nodes=n_classes,
        )
        self.fc_graph = gcn.GraphConvolution(hidden_layers * 3, hidden_layers)

    def forward(self, input, adj1_target=None, adj2_source=None, adj3_transfer=None):
        """Run the backbone, both graph branches and the semantic head.

        Args:
            input: Image batch; its size from dimension 2 onwards is the size
                the output is resized to.
            adj1_target: Passed as ``adj`` to each target graph convolution.
            adj2_source: Passed as ``adj`` to each source graph convolution.
            adj3_transfer: Passed as ``adj`` to each ``transpose_graph`` call.

        Returns:
            The output of ``self.semantic`` bilinearly resized to
            ``input.size()[2:]``.
        """
        x, low_level_features = self.xception_features(input)

        # ASPP branches plus the globally pooled branch, concatenated on dim 1.
        x1 = self.aspp1(x)
        x2 = self.aspp2(x)
        x3 = self.aspp3(x)
        x4 = self.aspp4(x)
        x5 = self.global_avg_pool(x)
        # F.interpolate is the documented replacement for deprecated F.upsample.
        x5 = F.interpolate(
            x5, size=x4.size()[2:], mode="bilinear", align_corners=True
        )
        x = torch.cat((x1, x2, x3, x4, x5), dim=1)

        x = self.concat_projection_conv1(x)
        x = self.concat_projection_bn1(x)
        x = self.relu(x)
        x = F.interpolate(
            x, size=low_level_features.size()[2:], mode="bilinear", align_corners=True
        )

        low_level_features = self.feature_projection_conv1(low_level_features)
        low_level_features = self.feature_projection_bn1(low_level_features)
        low_level_features = self.relu(low_level_features)

        x = torch.cat((x, low_level_features), dim=1)
        x = self.decoder(x)

        ### add graph
        # source graph
        source_graph = self.source_featuremap_2_graph(x)
        source_graph1 = self.source_graph_conv1.forward(
            source_graph, adj=adj2_source, relu=True
        )
        source_graph2 = self.source_graph_conv2.forward(
            source_graph1, adj=adj2_source, relu=True
        )
        # NOTE(review): source_graph_conv2 is applied a second time here and
        # source_graph_conv3 is never used in forward. Kept unchanged because
        # switching layers would alter results for existing checkpoints --
        # confirm whether this is intentional.
        source_graph3 = self.source_graph_conv2.forward(
            source_graph2, adj=adj2_source, relu=True
        )

        # Each source graph mapped onto the target nodes by the transfer module.
        source_2_target_graph1_v5 = self.transpose_graph.forward(
            source_graph1, adj=adj3_transfer, relu=True
        )
        source_2_target_graph2_v5 = self.transpose_graph.forward(
            source_graph2, adj=adj3_transfer, relu=True
        )
        source_2_target_graph3_v5 = self.transpose_graph.forward(
            source_graph3, adj=adj3_transfer, relu=True
        )

        # target graph
        graph = self.target_featuremap_2_graph(x)

        # graph combine 1 (only this stage squeezes dim 0 of the source terms)
        source_2_target_graph1 = self.similarity_trans(source_graph1, graph)
        graph = torch.cat(
            (
                graph,
                source_2_target_graph1.squeeze(0),
                source_2_target_graph1_v5.squeeze(0),
            ),
            dim=-1,
        )
        graph = self.fc_graph.forward(graph, relu=True)
        graph = self.target_graph_conv1.forward(graph, adj=adj1_target, relu=True)

        # graph combine 2
        source_2_target_graph2 = self.similarity_trans(source_graph2, graph)
        graph = torch.cat(
            (graph, source_2_target_graph2, source_2_target_graph2_v5), dim=-1
        )
        graph = self.fc_graph.forward(graph, relu=True)
        graph = self.target_graph_conv2.forward(graph, adj=adj1_target, relu=True)

        # graph combine 3
        source_2_target_graph3 = self.similarity_trans(source_graph3, graph)
        graph = torch.cat(
            (graph, source_2_target_graph3, source_2_target_graph3_v5), dim=-1
        )
        graph = self.fc_graph.forward(graph, relu=True)
        graph = self.target_graph_conv3.forward(graph, adj=adj1_target, relu=True)

        # Back to a feature map, then residual-style sum with the skip path.
        graph = self.target_graph_2_fea.forward(graph, x)
        x = self.target_skip_conv(x)
        x = x + graph

        ###
        x = self.semantic(x)
        x = F.interpolate(
            x, size=input.size()[2:], mode="bilinear", align_corners=True
        )

        return x

    def similarity_trans(self, source, target):
        """Re-express ``source`` rows on the rows of ``target`` by similarity.

        Both inputs are L2-normalised along the last dimension, so their matrix
        product holds cosine similarities. These are soft-maxed over the last
        dimension and used as weights for a matrix product with the
        un-normalised ``source``.

        Args:
            source: Tensor whose last dimension is the feature dimension.
            target: Tensor with the same feature size as ``source``.

        Returns:
            ``softmax(normalize(target) @ normalize(source)^T) @ source``.
        """
        target_unit = F.normalize(target, p=2, dim=-1)
        source_unit = F.normalize(source, p=2, dim=-1)
        sim = torch.matmul(target_unit, source_unit.transpose(-1, -2))
        sim = F.softmax(sim, dim=-1)
        return torch.matmul(sim, source)

    def load_source_model(self, state_dict):
        """Copy matching entries of a checkpoint into this model, best effort.

        Every key has ``"module."`` removed. A key that contains ``"graph"`` but
        none of ``"source"``, ``"target"``, ``"fc_"`` or ``"transpose_graph"``
        is renamed: ``"featuremap_2_graph"`` becomes
        ``"source_featuremap_2_graph"``, otherwise ``"graph"`` becomes
        ``"source_graph"``. Keys unknown to the model are reported and skipped
        (silently when they contain ``"num_batch"``); entries that cannot be
        copied are reported and skipped. Finally, model keys that the checkpoint
        never mentioned are printed.

        Args:
            state_dict: Mapping from parameter name to tensor or ``Parameter``.
        """
        own_state = self.state_dict()
        excluded_tags = ("source", "target", "fc_", "transpose_graph")
        # Names seen in the checkpoint (after renaming). Recorded before the
        # existence check, exactly like the original bookkeeping dict.
        seen_keys = set()

        for name, param in state_dict.items():
            name = name.replace("module.", "")
            if "graph" in name and not any(tag in name for tag in excluded_tags):
                if "featuremap_2_graph" in name:
                    name = name.replace(
                        "featuremap_2_graph", "source_featuremap_2_graph"
                    )
                else:
                    name = name.replace("graph", "source_graph")
            seen_keys.add(name)

            if name not in own_state:
                if "num_batch" not in name:
                    print('unexpected key "{}" in state_dict'.format(name))
                continue

            if isinstance(param, Parameter):
                # backwards compatibility for serialized parameters
                param = param.data

            try:
                own_state[name].copy_(param)
            except Exception:
                # Best-effort load: report the offending entry and keep going.
                # Unlike a bare ``except``, this lets KeyboardInterrupt and
                # SystemExit propagate.
                print(
                    "While copying the parameter named {}, whose dimensions in the model are"
                    " {} and whose dimensions in the checkpoint are {}, ...".format(
                        name, own_state[name].size(), param.size()
                    )
                )
                continue

        missing = set(own_state.keys()) - seen_keys
        if missing:
            print('missing keys in state_dict: "{}"'.format(missing))
|
|
|
|
|
|
class deeplab_xception_transfer_projection_savemem(
    deeplab_xception_transfer_basemodel_savememory
):
    """Source-to-target transfer model built on the ``savememory`` base model.

    On top of the base model this adds a source feature-map-to-graph module,
    source graph convolutions, a ``gcn.Graph_trans`` module mapping
    ``source_classes`` nodes to ``n_classes`` nodes, and ``fc_graph``, which
    reduces the three concatenated graphs back to ``hidden_layers`` features.
    """

    def __init__(
        self,
        nInputChannels=3,
        n_classes=7,
        os=16,
        input_channels=256,
        hidden_layers=128,
        out_channels=256,
        transfer_graph=None,
        source_classes=20,
    ):
        """Build the base model and the source/transfer graph modules.

        Args:
            nInputChannels: Forwarded to the base class.
            n_classes: Forwarded to the base class; also ``end_nodes`` of the
                transfer module.
            os: Forwarded to the base class.
            input_channels: Forwarded to the base class and to the source
                feature-map-to-graph module.
            hidden_layers: Feature size of the graph modules; ``fc_graph`` maps
                ``hidden_layers * 3`` features to ``hidden_layers``.
            out_channels: Forwarded to the base class.
            transfer_graph: Passed as ``adj`` to ``gcn.Graph_trans``.
            source_classes: Node count of the source graph; also ``begin_nodes``
                of the transfer module.
        """
        super(deeplab_xception_transfer_projection_savemem, self).__init__(
            nInputChannels=nInputChannels,
            n_classes=n_classes,
            os=os,
            input_channels=input_channels,
            hidden_layers=hidden_layers,
            out_channels=out_channels,
        )
        self.source_featuremap_2_graph = gcn.Featuremaps_to_Graph(
            input_channels=input_channels,
            hidden_layers=hidden_layers,
            nodes=source_classes,
        )
        self.source_graph_conv1 = gcn.GraphConvolution(hidden_layers, hidden_layers)
        self.source_graph_conv2 = gcn.GraphConvolution(hidden_layers, hidden_layers)
        self.source_graph_conv3 = gcn.GraphConvolution(hidden_layers, hidden_layers)
        self.transpose_graph = gcn.Graph_trans(
            in_features=hidden_layers,
            out_features=hidden_layers,
            adj=transfer_graph,
            begin_nodes=source_classes,
            end_nodes=n_classes,
        )
        self.fc_graph = gcn.GraphConvolution(hidden_layers * 3, hidden_layers)

    def forward(self, input, adj1_target=None, adj2_source=None, adj3_transfer=None):
        """Run the backbone, both graph branches and the semantic head.

        Args:
            input: Image batch; its size from dimension 2 onwards is the size
                the output is resized to.
            adj1_target: Passed as ``adj`` to each target graph convolution.
            adj2_source: Passed as ``adj`` to each source graph convolution.
            adj3_transfer: Passed as ``adj`` to each ``transpose_graph`` call.

        Returns:
            The output of ``self.semantic`` bilinearly resized to
            ``input.size()[2:]``.
        """
        x, low_level_features = self.xception_features(input)

        # ASPP branches plus the globally pooled branch, concatenated on dim 1.
        x1 = self.aspp1(x)
        x2 = self.aspp2(x)
        x3 = self.aspp3(x)
        x4 = self.aspp4(x)
        x5 = self.global_avg_pool(x)
        # F.interpolate is the documented replacement for deprecated F.upsample.
        x5 = F.interpolate(
            x5, size=x4.size()[2:], mode="bilinear", align_corners=True
        )
        x = torch.cat((x1, x2, x3, x4, x5), dim=1)

        x = self.concat_projection_conv1(x)
        x = self.concat_projection_bn1(x)
        x = self.relu(x)
        x = F.interpolate(
            x, size=low_level_features.size()[2:], mode="bilinear", align_corners=True
        )

        low_level_features = self.feature_projection_conv1(low_level_features)
        low_level_features = self.feature_projection_bn1(low_level_features)
        low_level_features = self.relu(low_level_features)

        x = torch.cat((x, low_level_features), dim=1)
        x = self.decoder(x)

        ### add graph
        # source graph
        source_graph = self.source_featuremap_2_graph(x)
        source_graph1 = self.source_graph_conv1.forward(
            source_graph, adj=adj2_source, relu=True
        )
        source_graph2 = self.source_graph_conv2.forward(
            source_graph1, adj=adj2_source, relu=True
        )
        # NOTE(review): source_graph_conv2 is applied a second time here and
        # source_graph_conv3 is never used in forward. Kept unchanged because
        # switching layers would alter results for existing checkpoints --
        # confirm whether this is intentional.
        source_graph3 = self.source_graph_conv2.forward(
            source_graph2, adj=adj2_source, relu=True
        )

        # Each source graph mapped onto the target nodes by the transfer module.
        source_2_target_graph1_v5 = self.transpose_graph.forward(
            source_graph1, adj=adj3_transfer, relu=True
        )
        source_2_target_graph2_v5 = self.transpose_graph.forward(
            source_graph2, adj=adj3_transfer, relu=True
        )
        source_2_target_graph3_v5 = self.transpose_graph.forward(
            source_graph3, adj=adj3_transfer, relu=True
        )

        # target graph
        graph = self.target_featuremap_2_graph(x)

        # graph combine 1 (only this stage squeezes dim 0 of the source terms)
        source_2_target_graph1 = self.similarity_trans(source_graph1, graph)
        graph = torch.cat(
            (
                graph,
                source_2_target_graph1.squeeze(0),
                source_2_target_graph1_v5.squeeze(0),
            ),
            dim=-1,
        )
        graph = self.fc_graph.forward(graph, relu=True)
        graph = self.target_graph_conv1.forward(graph, adj=adj1_target, relu=True)

        # graph combine 2
        source_2_target_graph2 = self.similarity_trans(source_graph2, graph)
        graph = torch.cat(
            (graph, source_2_target_graph2, source_2_target_graph2_v5), dim=-1
        )
        graph = self.fc_graph.forward(graph, relu=True)
        graph = self.target_graph_conv2.forward(graph, adj=adj1_target, relu=True)

        # graph combine 3
        source_2_target_graph3 = self.similarity_trans(source_graph3, graph)
        graph = torch.cat(
            (graph, source_2_target_graph3, source_2_target_graph3_v5), dim=-1
        )
        graph = self.fc_graph.forward(graph, relu=True)
        graph = self.target_graph_conv3.forward(graph, adj=adj1_target, relu=True)

        # Back to a feature map, then residual-style sum with the skip path.
        graph = self.target_graph_2_fea.forward(graph, x)
        x = self.target_skip_conv(x)
        x = x + graph

        ###
        x = self.semantic(x)
        x = F.interpolate(
            x, size=input.size()[2:], mode="bilinear", align_corners=True
        )

        return x

    def similarity_trans(self, source, target):
        """Re-express ``source`` rows on the rows of ``target`` by similarity.

        Both inputs are L2-normalised along the last dimension, so their matrix
        product holds cosine similarities. These are soft-maxed over the last
        dimension and used as weights for a matrix product with the
        un-normalised ``source``.

        Args:
            source: Tensor whose last dimension is the feature dimension.
            target: Tensor with the same feature size as ``source``.

        Returns:
            ``softmax(normalize(target) @ normalize(source)^T) @ source``.
        """
        target_unit = F.normalize(target, p=2, dim=-1)
        source_unit = F.normalize(source, p=2, dim=-1)
        sim = torch.matmul(target_unit, source_unit.transpose(-1, -2))
        sim = F.softmax(sim, dim=-1)
        return torch.matmul(sim, source)

    def load_source_model(self, state_dict):
        """Copy matching entries of a checkpoint into this model, best effort.

        Every key has ``"module."`` removed. A key that contains ``"graph"`` but
        none of ``"source"``, ``"target"``, ``"fc_"`` or ``"transpose_graph"``
        is renamed: ``"featuremap_2_graph"`` becomes
        ``"source_featuremap_2_graph"``, otherwise ``"graph"`` becomes
        ``"source_graph"``. Keys unknown to the model are reported and skipped
        (silently when they contain ``"num_batch"``); entries that cannot be
        copied are reported and skipped. Finally, model keys that the checkpoint
        never mentioned are printed.

        Args:
            state_dict: Mapping from parameter name to tensor or ``Parameter``.
        """
        own_state = self.state_dict()
        excluded_tags = ("source", "target", "fc_", "transpose_graph")
        # Names seen in the checkpoint (after renaming). Recorded before the
        # existence check, exactly like the original bookkeeping dict.
        seen_keys = set()

        for name, param in state_dict.items():
            name = name.replace("module.", "")
            if "graph" in name and not any(tag in name for tag in excluded_tags):
                if "featuremap_2_graph" in name:
                    name = name.replace(
                        "featuremap_2_graph", "source_featuremap_2_graph"
                    )
                else:
                    name = name.replace("graph", "source_graph")
            seen_keys.add(name)

            if name not in own_state:
                if "num_batch" not in name:
                    print('unexpected key "{}" in state_dict'.format(name))
                continue

            if isinstance(param, Parameter):
                # backwards compatibility for serialized parameters
                param = param.data

            try:
                own_state[name].copy_(param)
            except Exception:
                # Best-effort load: report the offending entry and keep going.
                # Unlike a bare ``except``, this lets KeyboardInterrupt and
                # SystemExit propagate.
                print(
                    "While copying the parameter named {}, whose dimensions in the model are"
                    " {} and whose dimensions in the checkpoint are {}, ...".format(
                        name, own_state[name].size(), param.size()
                    )
                )
                continue

        missing = set(own_state.keys()) - seen_keys
        if missing:
            print('missing keys in state_dict: "{}"'.format(missing))
|
|
|
|
|
|
# if __name__ == '__main__':
|
|
# net = deeplab_xception_transfer_projection_v3v5_more_savemem()
|
|
# img = torch.rand((2,3,128,128))
|
|
# net.eval()
|
|
# a = torch.rand((1,1,7,7))
|
|
# net.forward(img, adj1_target=a)
|