Files
insightface/reconstruction/ostec/external/graphonomy/FaceHairMask/graph.py
2022-05-29 14:26:46 +01:00

840 lines
19 KiB
Python

import torch
import numpy as np
import networkx as nx
import scipy.sparse as sp
# Adjacency list of the 7-node "pascal" graph: node index -> neighbour indices.
# Every node lists itself, so self-loops are already encoded here.
# NOTE(review): presumably the nodes are PASCAL-Person-Part classes — confirm.
pascal_graph = {
    0: [0],
    1: [1, 2],
    2: [1, 2, 3, 5],
    3: [2, 3, 4],
    4: [3, 4],
    5: [2, 5, 6],
    6: [5, 6],
}
# Adjacency list of the 20-node "cihp" graph: node index -> neighbour indices.
# Node 0 has no neighbours and no node lists itself (no self-loops here).
# NOTE(review): presumably the nodes are CIHP dataset classes — confirm.
cihp_graph = {
    0: [],
    1: [2, 13],
    2: [1, 13],
    3: [14, 15],
    4: [13],
    5: [6, 7, 9, 10, 11, 12, 14, 15],
    6: [5, 7, 10, 11, 14, 15, 16, 17],
    7: [5, 6, 9, 10, 11, 12, 14, 15],
    8: [16, 17, 18, 19],
    9: [5, 7, 10, 16, 17, 18, 19],
    10: [5, 6, 7, 9, 11, 12, 13, 14, 15, 16, 17],
    11: [5, 6, 7, 10, 13],
    12: [5, 7, 10, 16, 17],
    13: [1, 2, 4, 10, 11],
    14: [3, 5, 6, 7, 10],
    15: [3, 5, 6, 7, 10],
    16: [6, 8, 9, 10, 12, 18],
    17: [6, 8, 9, 10, 12, 19],
    18: [8, 9, 16],
    19: [8, 9, 17],
}
# Adjacency list of the 18-node "atr" graph: node index -> neighbour indices.
# Node 0 has no neighbours and no node lists itself (no self-loops here).
# NOTE(review): presumably the nodes are ATR dataset classes — confirm.
atr_graph = {
    0: [],
    1: [2, 11],
    2: [1, 11],
    3: [11],
    4: [5, 6, 7, 11, 14, 15, 17],
    5: [4, 6, 7, 8, 12, 13],
    6: [4, 5, 7, 8, 9, 10, 12, 13],
    7: [4, 11, 12, 13, 14, 15],
    8: [5, 6],
    9: [6, 12],
    10: [6, 13],
    11: [1, 2, 3, 4, 7, 14, 15, 17],
    12: [5, 6, 7, 9],
    13: [5, 6, 7, 10],
    14: [4, 7, 11, 16],
    15: [4, 7, 11, 16],
    16: [14, 15],
    17: [4, 11],
}
# Binary 7 x 20 matrix: one row per node of the 7-node graph, one column per
# node of the 20-node graph; a 1 marks a linked pair.
# NOTE(review): presumably rows are PASCAL classes and columns CIHP classes —
# confirm against the caller.
cihp2pascal_adj = np.array(
    [
        [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 1, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0],
        [0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1],
    ]
)
# Dense float matrix with 7 rows of 20 values each, i.e. the same 7 x 20
# layout as `cihp2pascal_adj` but with real-valued weights instead of 0/1.
# NOTE(review): presumably these are similarity scores between class-name
# (word) embeddings of the two label sets — TODO confirm the origin.
cihp2pascal_nlp_adj = np.array(
[
[
1.0,
0.35333052,
0.32727194,
0.17418084,
0.18757584,
0.40608522,
0.37503981,
0.35448462,
0.22598555,
0.23893579,
0.33064262,
0.28923404,
0.27986573,
0.4211553,
0.36915778,
0.41377746,
0.32485771,
0.37248222,
0.36865639,
0.41500332,
],
[
0.39615879,
0.46201529,
0.52321467,
0.30826114,
0.25669527,
0.54747773,
0.3670523,
0.3901983,
0.27519473,
0.3433325,
0.52728509,
0.32771333,
0.34819325,
0.63882953,
0.68042925,
0.69368576,
0.63395791,
0.65344337,
0.59538781,
0.6071375,
],
[
0.16373166,
0.21663339,
0.3053872,
0.28377612,
0.1372435,
0.4448808,
0.29479995,
0.31092595,
0.22703953,
0.33983576,
0.75778818,
0.2619818,
0.37069392,
0.35184867,
0.49877512,
0.49979437,
0.51853277,
0.52517541,
0.32517741,
0.32377309,
],
[
0.32687232,
0.38482461,
0.37693463,
0.41610834,
0.20415749,
0.76749079,
0.35139853,
0.3787411,
0.28411737,
0.35155421,
0.58792618,
0.31141718,
0.40585111,
0.51189218,
0.82042737,
0.8342413,
0.70732188,
0.72752501,
0.60327325,
0.61431337,
],
[
0.34069369,
0.34817292,
0.37525998,
0.36497069,
0.17841617,
0.69746208,
0.31731463,
0.34628951,
0.25167277,
0.32072379,
0.56711286,
0.24894776,
0.37000453,
0.52600859,
0.82483993,
0.84966274,
0.7033991,
0.73449378,
0.56649608,
0.58888791,
],
[
0.28477487,
0.35139564,
0.42742352,
0.41664321,
0.20004676,
0.78566833,
0.42237487,
0.41048549,
0.37933812,
0.46542516,
0.62444759,
0.3274493,
0.49466009,
0.49314658,
0.71244233,
0.71497003,
0.8234787,
0.83566589,
0.62597135,
0.62626812,
],
[
0.3011378,
0.31775977,
0.42922647,
0.36896257,
0.17597556,
0.72214655,
0.39162804,
0.38137872,
0.34980296,
0.43818419,
0.60879174,
0.26762545,
0.46271161,
0.51150476,
0.72318109,
0.73678399,
0.82620388,
0.84942166,
0.5943811,
0.60607602,
],
]
)
# Dense float matrix with 7 rows of 18 values each (7 x 18).
# NOTE(review): presumably similarity scores between the 7 "pascal" classes
# (rows) and the 18 "atr" classes (columns), derived from class-name
# embeddings — TODO confirm the origin and the axis meaning.
pascal2atr_nlp_adj = np.array(
[
[
1.0,
0.35333052,
0.32727194,
0.18757584,
0.40608522,
0.27986573,
0.23893579,
0.27600672,
0.30964391,
0.36865639,
0.41500332,
0.4211553,
0.32485771,
0.37248222,
0.36915778,
0.41377746,
0.32006291,
0.28923404,
],
[
0.39615879,
0.46201529,
0.52321467,
0.25669527,
0.54747773,
0.34819325,
0.3433325,
0.26603942,
0.45162929,
0.59538781,
0.6071375,
0.63882953,
0.63395791,
0.65344337,
0.68042925,
0.69368576,
0.44354613,
0.32771333,
],
[
0.16373166,
0.21663339,
0.3053872,
0.1372435,
0.4448808,
0.37069392,
0.33983576,
0.26563416,
0.35443504,
0.32517741,
0.32377309,
0.35184867,
0.51853277,
0.52517541,
0.49877512,
0.49979437,
0.21750868,
0.2619818,
],
[
0.32687232,
0.38482461,
0.37693463,
0.20415749,
0.76749079,
0.40585111,
0.35155421,
0.28271333,
0.52684576,
0.60327325,
0.61431337,
0.51189218,
0.70732188,
0.72752501,
0.82042737,
0.8342413,
0.40137029,
0.31141718,
],
[
0.34069369,
0.34817292,
0.37525998,
0.17841617,
0.69746208,
0.37000453,
0.32072379,
0.27268885,
0.47426719,
0.56649608,
0.58888791,
0.52600859,
0.7033991,
0.73449378,
0.82483993,
0.84966274,
0.37830796,
0.24894776,
],
[
0.28477487,
0.35139564,
0.42742352,
0.20004676,
0.78566833,
0.49466009,
0.46542516,
0.32662614,
0.55780359,
0.62597135,
0.62626812,
0.49314658,
0.8234787,
0.83566589,
0.71244233,
0.71497003,
0.41223219,
0.3274493,
],
[
0.3011378,
0.31775977,
0.42922647,
0.17597556,
0.72214655,
0.46271161,
0.43818419,
0.3192333,
0.50979216,
0.5943811,
0.60607602,
0.51150476,
0.82620388,
0.84942166,
0.72318109,
0.73678399,
0.39259827,
0.26762545,
],
]
)
# Dense float matrix with 20 rows of 18 values each (20 x 18).
# NOTE(review): presumably similarity scores between the 20 "cihp" classes
# (rows) and the 18 "atr" classes (columns), derived from class-name
# embeddings — TODO confirm the origin and the axis meaning.
cihp2atr_nlp_adj = np.array(
[
[
1.0,
0.35333052,
0.32727194,
0.18757584,
0.40608522,
0.27986573,
0.23893579,
0.27600672,
0.30964391,
0.36865639,
0.41500332,
0.4211553,
0.32485771,
0.37248222,
0.36915778,
0.41377746,
0.32006291,
0.28923404,
],
[
0.35333052,
1.0,
0.39206695,
0.42143438,
0.4736689,
0.47139544,
0.51999208,
0.38354847,
0.45628529,
0.46514124,
0.50083501,
0.4310595,
0.39371443,
0.4319752,
0.42938598,
0.46384034,
0.44833757,
0.6153155,
],
[
0.32727194,
0.39206695,
1.0,
0.32836702,
0.52603065,
0.39543695,
0.3622627,
0.43575346,
0.33866223,
0.45202552,
0.48421,
0.53669903,
0.47266611,
0.50925436,
0.42286557,
0.45403656,
0.37221304,
0.40999322,
],
[
0.17418084,
0.46892601,
0.25774838,
0.31816231,
0.39330317,
0.34218382,
0.48253904,
0.22084125,
0.41335728,
0.52437572,
0.5191713,
0.33576117,
0.44230914,
0.44250678,
0.44330833,
0.43887264,
0.50693611,
0.39278795,
],
[
0.18757584,
0.42143438,
0.32836702,
1.0,
0.35030067,
0.30110947,
0.41055555,
0.34338879,
0.34336307,
0.37704433,
0.38810141,
0.34702081,
0.24171562,
0.25433078,
0.24696241,
0.2570884,
0.4465962,
0.45263213,
],
[
0.40608522,
0.4736689,
0.52603065,
0.35030067,
1.0,
0.54372584,
0.58300258,
0.56674191,
0.555266,
0.66599594,
0.68567555,
0.55716359,
0.62997328,
0.65638548,
0.61219615,
0.63183318,
0.54464151,
0.44293752,
],
[
0.37503981,
0.50675565,
0.4761106,
0.37561813,
0.60419403,
0.77912403,
0.64595517,
0.85939662,
0.46037144,
0.52348817,
0.55875094,
0.37741886,
0.455671,
0.49434392,
0.38479954,
0.41804074,
0.47285709,
0.57236283,
],
[
0.35448462,
0.50576632,
0.51030446,
0.35841033,
0.55106903,
0.50257274,
0.52591451,
0.4283053,
0.39991808,
0.42327211,
0.42853819,
0.42071825,
0.41240559,
0.42259136,
0.38125352,
0.3868255,
0.47604934,
0.51811717,
],
[
0.22598555,
0.5053299,
0.36301185,
0.38002282,
0.49700941,
0.45625243,
0.62876479,
0.4112051,
0.33944371,
0.48322639,
0.50318714,
0.29207815,
0.38801966,
0.41119094,
0.29199072,
0.31021029,
0.41594871,
0.54961962,
],
[
0.23893579,
0.51999208,
0.3622627,
0.41055555,
0.58300258,
0.68874251,
1.0,
0.56977937,
0.49918447,
0.48484363,
0.51615925,
0.41222306,
0.49535971,
0.53134951,
0.3807616,
0.41050298,
0.48675801,
0.51112664,
],
[
0.33064262,
0.306412,
0.60679935,
0.25592294,
0.58738706,
0.40379627,
0.39679161,
0.33618385,
0.39235148,
0.45474013,
0.4648476,
0.59306762,
0.58976007,
0.60778661,
0.55400397,
0.56551297,
0.3698029,
0.33860535,
],
[
0.28923404,
0.6153155,
0.40999322,
0.45263213,
0.44293752,
0.60359359,
0.51112664,
0.46578181,
0.45656936,
0.38142307,
0.38525582,
0.33327223,
0.35360175,
0.36156453,
0.3384992,
0.34261229,
0.49297863,
1.0,
],
[
0.27986573,
0.47139544,
0.39543695,
0.30110947,
0.54372584,
1.0,
0.68874251,
0.67765588,
0.48690078,
0.44010641,
0.44921156,
0.32321099,
0.48311542,
0.4982002,
0.39378102,
0.40297733,
0.45309735,
0.60359359,
],
[
0.4211553,
0.4310595,
0.53669903,
0.34702081,
0.55716359,
0.32321099,
0.41222306,
0.25721705,
0.36633509,
0.5397475,
0.56429928,
1.0,
0.55796926,
0.58842844,
0.57930828,
0.60410597,
0.41615326,
0.33327223,
],
[
0.36915778,
0.42938598,
0.42286557,
0.24696241,
0.61219615,
0.39378102,
0.3807616,
0.28089866,
0.48450394,
0.77400821,
0.68813814,
0.57930828,
0.8856886,
0.81673412,
1.0,
0.92279623,
0.46969152,
0.3384992,
],
[
0.41377746,
0.46384034,
0.45403656,
0.2570884,
0.63183318,
0.40297733,
0.41050298,
0.332879,
0.48799542,
0.69231828,
0.77015091,
0.60410597,
0.79788484,
0.88232104,
0.92279623,
1.0,
0.45685017,
0.34261229,
],
[
0.32485771,
0.39371443,
0.47266611,
0.24171562,
0.62997328,
0.48311542,
0.49535971,
0.32477932,
0.51486622,
0.79353556,
0.69768738,
0.55796926,
1.0,
0.92373745,
0.8856886,
0.79788484,
0.47883134,
0.35360175,
],
[
0.37248222,
0.4319752,
0.50925436,
0.25433078,
0.65638548,
0.4982002,
0.53134951,
0.38057074,
0.52403969,
0.72035243,
0.78711147,
0.58842844,
0.92373745,
1.0,
0.81673412,
0.88232104,
0.47109935,
0.36156453,
],
[
0.36865639,
0.46514124,
0.45202552,
0.37704433,
0.66599594,
0.44010641,
0.48484363,
0.39636574,
0.50175258,
1.0,
0.91320249,
0.5397475,
0.79353556,
0.72035243,
0.77400821,
0.69231828,
0.59087008,
0.38142307,
],
[
0.41500332,
0.50083501,
0.48421,
0.38810141,
0.68567555,
0.44921156,
0.51615925,
0.45156472,
0.50438158,
0.91320249,
1.0,
0.56429928,
0.69768738,
0.78711147,
0.68813814,
0.77015091,
0.57698754,
0.38525582,
],
]
)
def normalize_adj(adj):
    """Scale an adjacency matrix on both sides by the inverse square root of its row sums.

    With ``D`` the diagonal matrix of row sums of ``adj``, the result is
    ``(adj · D^-1/2)ᵀ · D^-1/2``; for a symmetric ``adj`` this equals the
    usual ``D^-1/2 · adj · D^-1/2``.

    Args:
        adj: anything ``scipy.sparse.coo_matrix`` accepts (dense array or
            sparse matrix).

    Returns:
        The normalized matrix in SciPy COO format.
    """
    coo = sp.coo_matrix(adj)
    degree = np.array(coo.sum(1))
    inv_sqrt_degree = np.power(degree, -0.5).flatten()
    # A zero row sum produces inf; map it to 0 so that row/column is zeroed out.
    inv_sqrt_degree = np.where(np.isinf(inv_sqrt_degree), 0.0, inv_sqrt_degree)
    scaling = sp.diags(inv_sqrt_degree)
    column_scaled = coo.dot(scaling)
    return column_scaled.transpose().dot(scaling).tocoo()
def preprocess_adj(adj):
    """Turn a dict-of-lists graph into a dense, normalized adjacency matrix.

    The graph is built with ``nx.from_dict_of_lists``, an identity matrix is
    added to its adjacency matrix, and the sum is passed through
    ``normalize_adj``.

    Args:
        adj: mapping ``node -> list of neighbour nodes``.

    Returns:
        The normalized matrix converted to dense form via ``.todense()``.
    """
    graph = nx.from_dict_of_lists(adj)
    adjacency = nx.adjacency_matrix(graph)
    node_count = adjacency.shape[0]
    # Identity is added unconditionally, even if the graph already lists self-loops.
    with_identity = adjacency + sp.eye(node_count)
    return normalize_adj(with_identity).todense()
def row_norm(inputs):
    """Divide each item of ``inputs`` by its own sum.

    Args:
        inputs: iterable of array-like rows exposing ``.sum()`` and division
            (e.g. the rows of a 2-D NumPy array).

    Returns:
        A list with one normalized row per input item. A row whose sum is
        zero is not special-cased.
    """
    return [row / row.sum() for row in inputs]
def normalize_adj_torch(adj):
    """Symmetrically normalize adjacency tensor(s): ``D^-1/2 · A · D^-1/2``.

    ``A`` is every square matrix held in the last two dimensions of ``adj``
    and ``D`` is the diagonal matrix of that matrix's row sums.

    Compared with the previous implementation:

    * rows whose sum is zero (``0 ** -0.5 == inf``) no longer leak ``inf``/
      ``nan`` into the result — all non-finite scale factors are set to 0,
      matching what the NumPy ``normalize_adj`` does;
    * for 4-D input every batch entry is normalized, not only ``adj[0]``;
    * any number of leading batch dimensions is accepted.

    Args:
        adj: floating-point tensor of shape ``(..., N, N)``.

    Returns:
        A new tensor with the same shape, dtype and device as ``adj``.

    Raises:
        ValueError: if ``adj`` has fewer than two dimensions or its last two
            dimensions are not equal.
    """
    if adj.dim() < 2 or adj.size(-1) != adj.size(-2):
        raise ValueError(
            "normalize_adj_torch expects a tensor of shape (..., N, N), got "
            + str(tuple(adj.size()))
        )
    row_sums = adj.sum(-1)
    inv_sqrt = row_sums.pow(-0.5)
    # Zero row sums give inf and negative ones give nan; neutralise both.
    inv_sqrt = torch.where(
        torch.isfinite(inv_sqrt), inv_sqrt, torch.zeros_like(inv_sqrt)
    )
    # Broadcasting d_i * a_ij * d_j is the same as diag(d) @ A @ diag(d)
    # without materialising the diagonal matrices.
    return inv_sqrt.unsqueeze(-1) * adj * inv_sqrt.unsqueeze(-2)
# def row_norm(adj):
if __name__ == "__main__":
    # Manual smoke check: print the row-normalized matrix, then the raw one.
    normalized_rows = row_norm(cihp2pascal_adj)
    print(normalized_rows)
    print(cihp2pascal_adj)