# Mirror of https://github.com/deepinsight/insightface.git
# (synced 2026-05-16 13:46:15 +00:00; original file: 840 lines, 19 KiB, Python)
import torch
|
|
import numpy as np
|
|
import networkx as nx
|
|
import scipy.sparse as sp
|
|
|
|
# Adjacency lists for the 7-node "pascal" label graph (node -> neighbour
# nodes), in the dict-of-lists form accepted by preprocess_adj().
# Presumably the PASCAL-Person-Part part labels with node 0 = background --
# confirm against the dataset's label map.
# NOTE(review): unlike cihp_graph / atr_graph, every node here lists itself as
# a neighbour. preprocess_adj() also adds an identity matrix, so the diagonal
# may end up counted twice -- confirm this is intended.
pascal_graph = {
    0: [0],
    1: [1, 2],
    2: [1, 2, 3, 5],
    3: [2, 3, 4],
    4: [3, 4],
    5: [2, 5, 6],
    6: [5, 6],
}
|
|
|
|
# Adjacency lists for the 20-node "cihp" label graph (node -> neighbour
# nodes), in the dict-of-lists form accepted by preprocess_adj().
# Presumably the CIHP human-parsing labels with node 0 = background (it has
# no neighbours here) -- confirm against the dataset's label map.
# Self-loops are not listed; preprocess_adj() adds them via an identity matrix.
cihp_graph = {
    0: [],
    1: [2, 13],
    2: [1, 13],
    3: [14, 15],
    4: [13],
    5: [6, 7, 9, 10, 11, 12, 14, 15],
    6: [5, 7, 10, 11, 14, 15, 16, 17],
    7: [5, 6, 9, 10, 11, 12, 14, 15],
    8: [16, 17, 18, 19],
    9: [5, 7, 10, 16, 17, 18, 19],
    10: [5, 6, 7, 9, 11, 12, 13, 14, 15, 16, 17],
    11: [5, 6, 7, 10, 13],
    12: [5, 7, 10, 16, 17],
    13: [1, 2, 4, 10, 11],
    14: [3, 5, 6, 7, 10],
    15: [3, 5, 6, 7, 10],
    16: [6, 8, 9, 10, 12, 18],
    17: [6, 8, 9, 10, 12, 19],
    18: [8, 9, 16],
    19: [8, 9, 17],
}
|
|
|
|
# Adjacency lists for the 18-node "atr" label graph (node -> neighbour
# nodes), in the dict-of-lists form accepted by preprocess_adj().
# Presumably the ATR human-parsing labels with node 0 = background (it has
# no neighbours here) -- confirm against the dataset's label map.
# Self-loops are not listed; preprocess_adj() adds them via an identity matrix.
atr_graph = {
    0: [],
    1: [2, 11],
    2: [1, 11],
    3: [11],
    4: [5, 6, 7, 11, 14, 15, 17],
    5: [4, 6, 7, 8, 12, 13],
    6: [4, 5, 7, 8, 9, 10, 12, 13],
    7: [4, 11, 12, 13, 14, 15],
    8: [5, 6],
    9: [6, 12],
    10: [6, 13],
    11: [1, 2, 3, 4, 7, 14, 15, 17],
    12: [5, 6, 7, 9],
    13: [5, 6, 7, 10],
    14: [4, 7, 11, 16],
    15: [4, 7, 11, 16],
    16: [14, 15],
    17: [4, 11],
}
|
|
|
|
# 7 x 20 binary matrix: one row per pascal label, one column per cihp label.
# Presumably entry [i, j] == 1 means cihp label j is folded into pascal
# label i (a hand-made cross-dataset label mapping) -- verify against callers.
# Rows are raw 0/1 membership; the __main__ demo row-normalizes them with
# row_norm().
cihp2pascal_adj = np.array(
    [
        [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 1, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0],
        [0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1],
    ]
)
|
|
|
|
# 7 x 20 float matrix with the same orientation as cihp2pascal_adj (rows:
# pascal labels, columns: cihp labels). All values lie in (0, 1], with 1.0
# at [0, 0]. Presumably a similarity between embeddings of the two datasets'
# label names (hence "nlp") -- confirm how it was generated.
cihp2pascal_nlp_adj = np.array(
    [
        [
            1.0, 0.35333052, 0.32727194, 0.17418084, 0.18757584,
            0.40608522, 0.37503981, 0.35448462, 0.22598555, 0.23893579,
            0.33064262, 0.28923404, 0.27986573, 0.4211553, 0.36915778,
            0.41377746, 0.32485771, 0.37248222, 0.36865639, 0.41500332,
        ],
        [
            0.39615879, 0.46201529, 0.52321467, 0.30826114, 0.25669527,
            0.54747773, 0.3670523, 0.3901983, 0.27519473, 0.3433325,
            0.52728509, 0.32771333, 0.34819325, 0.63882953, 0.68042925,
            0.69368576, 0.63395791, 0.65344337, 0.59538781, 0.6071375,
        ],
        [
            0.16373166, 0.21663339, 0.3053872, 0.28377612, 0.1372435,
            0.4448808, 0.29479995, 0.31092595, 0.22703953, 0.33983576,
            0.75778818, 0.2619818, 0.37069392, 0.35184867, 0.49877512,
            0.49979437, 0.51853277, 0.52517541, 0.32517741, 0.32377309,
        ],
        [
            0.32687232, 0.38482461, 0.37693463, 0.41610834, 0.20415749,
            0.76749079, 0.35139853, 0.3787411, 0.28411737, 0.35155421,
            0.58792618, 0.31141718, 0.40585111, 0.51189218, 0.82042737,
            0.8342413, 0.70732188, 0.72752501, 0.60327325, 0.61431337,
        ],
        [
            0.34069369, 0.34817292, 0.37525998, 0.36497069, 0.17841617,
            0.69746208, 0.31731463, 0.34628951, 0.25167277, 0.32072379,
            0.56711286, 0.24894776, 0.37000453, 0.52600859, 0.82483993,
            0.84966274, 0.7033991, 0.73449378, 0.56649608, 0.58888791,
        ],
        [
            0.28477487, 0.35139564, 0.42742352, 0.41664321, 0.20004676,
            0.78566833, 0.42237487, 0.41048549, 0.37933812, 0.46542516,
            0.62444759, 0.3274493, 0.49466009, 0.49314658, 0.71244233,
            0.71497003, 0.8234787, 0.83566589, 0.62597135, 0.62626812,
        ],
        [
            0.3011378, 0.31775977, 0.42922647, 0.36896257, 0.17597556,
            0.72214655, 0.39162804, 0.38137872, 0.34980296, 0.43818419,
            0.60879174, 0.26762545, 0.46271161, 0.51150476, 0.72318109,
            0.73678399, 0.82620388, 0.84942166, 0.5943811, 0.60607602,
        ],
    ]
)
|
|
|
|
# 7 x 18 float matrix: rows indexed by the 7 pascal labels, columns by the
# 18 atr labels. All values lie in (0, 1]. Presumably label-name embedding
# similarity like cihp2pascal_nlp_adj -- confirm how it was generated.
# NOTE(review): orientation here is first-named x second-named dataset
# ("pascal2atr" -> rows are pascal). In cihp2pascal_adj / cihp2pascal_nlp_adj
# the rows are the *second*-named dataset (pascal). Check that callers
# transpose consistently.
pascal2atr_nlp_adj = np.array(
    [
        [
            1.0, 0.35333052, 0.32727194, 0.18757584, 0.40608522, 0.27986573,
            0.23893579, 0.27600672, 0.30964391, 0.36865639, 0.41500332, 0.4211553,
            0.32485771, 0.37248222, 0.36915778, 0.41377746, 0.32006291, 0.28923404,
        ],
        [
            0.39615879, 0.46201529, 0.52321467, 0.25669527, 0.54747773, 0.34819325,
            0.3433325, 0.26603942, 0.45162929, 0.59538781, 0.6071375, 0.63882953,
            0.63395791, 0.65344337, 0.68042925, 0.69368576, 0.44354613, 0.32771333,
        ],
        [
            0.16373166, 0.21663339, 0.3053872, 0.1372435, 0.4448808, 0.37069392,
            0.33983576, 0.26563416, 0.35443504, 0.32517741, 0.32377309, 0.35184867,
            0.51853277, 0.52517541, 0.49877512, 0.49979437, 0.21750868, 0.2619818,
        ],
        [
            0.32687232, 0.38482461, 0.37693463, 0.20415749, 0.76749079, 0.40585111,
            0.35155421, 0.28271333, 0.52684576, 0.60327325, 0.61431337, 0.51189218,
            0.70732188, 0.72752501, 0.82042737, 0.8342413, 0.40137029, 0.31141718,
        ],
        [
            0.34069369, 0.34817292, 0.37525998, 0.17841617, 0.69746208, 0.37000453,
            0.32072379, 0.27268885, 0.47426719, 0.56649608, 0.58888791, 0.52600859,
            0.7033991, 0.73449378, 0.82483993, 0.84966274, 0.37830796, 0.24894776,
        ],
        [
            0.28477487, 0.35139564, 0.42742352, 0.20004676, 0.78566833, 0.49466009,
            0.46542516, 0.32662614, 0.55780359, 0.62597135, 0.62626812, 0.49314658,
            0.8234787, 0.83566589, 0.71244233, 0.71497003, 0.41223219, 0.3274493,
        ],
        [
            0.3011378, 0.31775977, 0.42922647, 0.17597556, 0.72214655, 0.46271161,
            0.43818419, 0.3192333, 0.50979216, 0.5943811, 0.60607602, 0.51150476,
            0.82620388, 0.84942166, 0.72318109, 0.73678399, 0.39259827, 0.26762545,
        ],
    ]
)
|
|
|
|
# 20 x 18 float matrix: rows indexed by the 20 cihp labels, columns by the
# 18 atr labels (first-named x second-named dataset, like pascal2atr_nlp_adj).
# All values lie in (0, 1]. The exact 1.0 entries (e.g. [0, 0], [1, 1],
# [2, 2]) presumably mark label names shared by both datasets -- confirm how
# the matrix was generated.
cihp2atr_nlp_adj = np.array(
    [
        [
            1.0, 0.35333052, 0.32727194, 0.18757584, 0.40608522, 0.27986573,
            0.23893579, 0.27600672, 0.30964391, 0.36865639, 0.41500332, 0.4211553,
            0.32485771, 0.37248222, 0.36915778, 0.41377746, 0.32006291, 0.28923404,
        ],
        [
            0.35333052, 1.0, 0.39206695, 0.42143438, 0.4736689, 0.47139544,
            0.51999208, 0.38354847, 0.45628529, 0.46514124, 0.50083501, 0.4310595,
            0.39371443, 0.4319752, 0.42938598, 0.46384034, 0.44833757, 0.6153155,
        ],
        [
            0.32727194, 0.39206695, 1.0, 0.32836702, 0.52603065, 0.39543695,
            0.3622627, 0.43575346, 0.33866223, 0.45202552, 0.48421, 0.53669903,
            0.47266611, 0.50925436, 0.42286557, 0.45403656, 0.37221304, 0.40999322,
        ],
        [
            0.17418084, 0.46892601, 0.25774838, 0.31816231, 0.39330317, 0.34218382,
            0.48253904, 0.22084125, 0.41335728, 0.52437572, 0.5191713, 0.33576117,
            0.44230914, 0.44250678, 0.44330833, 0.43887264, 0.50693611, 0.39278795,
        ],
        [
            0.18757584, 0.42143438, 0.32836702, 1.0, 0.35030067, 0.30110947,
            0.41055555, 0.34338879, 0.34336307, 0.37704433, 0.38810141, 0.34702081,
            0.24171562, 0.25433078, 0.24696241, 0.2570884, 0.4465962, 0.45263213,
        ],
        [
            0.40608522, 0.4736689, 0.52603065, 0.35030067, 1.0, 0.54372584,
            0.58300258, 0.56674191, 0.555266, 0.66599594, 0.68567555, 0.55716359,
            0.62997328, 0.65638548, 0.61219615, 0.63183318, 0.54464151, 0.44293752,
        ],
        [
            0.37503981, 0.50675565, 0.4761106, 0.37561813, 0.60419403, 0.77912403,
            0.64595517, 0.85939662, 0.46037144, 0.52348817, 0.55875094, 0.37741886,
            0.455671, 0.49434392, 0.38479954, 0.41804074, 0.47285709, 0.57236283,
        ],
        [
            0.35448462, 0.50576632, 0.51030446, 0.35841033, 0.55106903, 0.50257274,
            0.52591451, 0.4283053, 0.39991808, 0.42327211, 0.42853819, 0.42071825,
            0.41240559, 0.42259136, 0.38125352, 0.3868255, 0.47604934, 0.51811717,
        ],
        [
            0.22598555, 0.5053299, 0.36301185, 0.38002282, 0.49700941, 0.45625243,
            0.62876479, 0.4112051, 0.33944371, 0.48322639, 0.50318714, 0.29207815,
            0.38801966, 0.41119094, 0.29199072, 0.31021029, 0.41594871, 0.54961962,
        ],
        [
            0.23893579, 0.51999208, 0.3622627, 0.41055555, 0.58300258, 0.68874251,
            1.0, 0.56977937, 0.49918447, 0.48484363, 0.51615925, 0.41222306,
            0.49535971, 0.53134951, 0.3807616, 0.41050298, 0.48675801, 0.51112664,
        ],
        [
            0.33064262, 0.306412, 0.60679935, 0.25592294, 0.58738706, 0.40379627,
            0.39679161, 0.33618385, 0.39235148, 0.45474013, 0.4648476, 0.59306762,
            0.58976007, 0.60778661, 0.55400397, 0.56551297, 0.3698029, 0.33860535,
        ],
        [
            0.28923404, 0.6153155, 0.40999322, 0.45263213, 0.44293752, 0.60359359,
            0.51112664, 0.46578181, 0.45656936, 0.38142307, 0.38525582, 0.33327223,
            0.35360175, 0.36156453, 0.3384992, 0.34261229, 0.49297863, 1.0,
        ],
        [
            0.27986573, 0.47139544, 0.39543695, 0.30110947, 0.54372584, 1.0,
            0.68874251, 0.67765588, 0.48690078, 0.44010641, 0.44921156, 0.32321099,
            0.48311542, 0.4982002, 0.39378102, 0.40297733, 0.45309735, 0.60359359,
        ],
        [
            0.4211553, 0.4310595, 0.53669903, 0.34702081, 0.55716359, 0.32321099,
            0.41222306, 0.25721705, 0.36633509, 0.5397475, 0.56429928, 1.0,
            0.55796926, 0.58842844, 0.57930828, 0.60410597, 0.41615326, 0.33327223,
        ],
        [
            0.36915778, 0.42938598, 0.42286557, 0.24696241, 0.61219615, 0.39378102,
            0.3807616, 0.28089866, 0.48450394, 0.77400821, 0.68813814, 0.57930828,
            0.8856886, 0.81673412, 1.0, 0.92279623, 0.46969152, 0.3384992,
        ],
        [
            0.41377746, 0.46384034, 0.45403656, 0.2570884, 0.63183318, 0.40297733,
            0.41050298, 0.332879, 0.48799542, 0.69231828, 0.77015091, 0.60410597,
            0.79788484, 0.88232104, 0.92279623, 1.0, 0.45685017, 0.34261229,
        ],
        [
            0.32485771, 0.39371443, 0.47266611, 0.24171562, 0.62997328, 0.48311542,
            0.49535971, 0.32477932, 0.51486622, 0.79353556, 0.69768738, 0.55796926,
            1.0, 0.92373745, 0.8856886, 0.79788484, 0.47883134, 0.35360175,
        ],
        [
            0.37248222, 0.4319752, 0.50925436, 0.25433078, 0.65638548, 0.4982002,
            0.53134951, 0.38057074, 0.52403969, 0.72035243, 0.78711147, 0.58842844,
            0.92373745, 1.0, 0.81673412, 0.88232104, 0.47109935, 0.36156453,
        ],
        [
            0.36865639, 0.46514124, 0.45202552, 0.37704433, 0.66599594, 0.44010641,
            0.48484363, 0.39636574, 0.50175258, 1.0, 0.91320249, 0.5397475,
            0.79353556, 0.72035243, 0.77400821, 0.69231828, 0.59087008, 0.38142307,
        ],
        [
            0.41500332, 0.50083501, 0.48421, 0.38810141, 0.68567555, 0.44921156,
            0.51615925, 0.45156472, 0.50438158, 0.91320249, 1.0, 0.56429928,
            0.69768738, 0.78711147, 0.68813814, 0.77015091, 0.57698754, 0.38525582,
        ],
    ]
)
|
|
|
|
|
|
def normalize_adj(adj):
    """Symmetrically normalize an adjacency matrix.

    Computes ``D^{-1/2} A^T D^{-1/2}`` where ``D`` is the diagonal matrix of
    row sums of ``A``. For a symmetric ``A`` this is the usual GCN
    normalization ``D^{-1/2} A D^{-1/2}``.

    Args:
        adj: Dense array or scipy sparse matrix accepted by
            ``scipy.sparse.coo_matrix``.

    Returns:
        The normalized matrix in scipy COO format. Rows summing to zero get a
        scale factor of 0 instead of ``inf``.
    """
    coo = sp.coo_matrix(adj)
    degrees = np.array(coo.sum(1))
    inv_sqrt_deg = np.power(degrees, -0.5).flatten()
    # Zero-degree rows would otherwise produce inf scale factors.
    inv_sqrt_deg[np.isinf(inv_sqrt_deg)] = 0.0
    scale = sp.diags(inv_sqrt_deg)
    right_scaled = coo.dot(scale)  # A @ D^{-1/2}
    return right_scaled.transpose().dot(scale).tocoo()
|
|
|
|
|
|
def preprocess_adj(adj):
    """Build the self-looped, normalized adjacency matrix of a label graph.

    Args:
        adj: Dict-of-lists graph description (node -> list of neighbour
            nodes), as accepted by ``networkx.from_dict_of_lists``.

    Returns:
        Dense matrix ``normalize_adj(A + I)``. ``A`` is the graph's adjacency
        matrix from ``networkx.adjacency_matrix`` (a scipy sparse object, not
        a numpy array), with rows/columns in the node order networkx assigns.
        ``I`` is the identity, which adds a self-loop to every node.
    """
    graph = nx.from_dict_of_lists(adj)
    sparse_adj = nx.adjacency_matrix(graph)
    num_nodes = sparse_adj.shape[0]
    with_self_loops = sparse_adj + sp.eye(num_nodes)
    return normalize_adj(with_self_loops).todense()
|
|
|
|
|
|
def row_norm(inputs):
    """Scale every row of ``inputs`` so that it sums to 1.

    Args:
        inputs: Iterable of rows supporting ``.sum()`` and division, e.g. a
            2-D ``np.ndarray`` iterated row by row.

    Returns:
        A list holding one normalized row per input row. A row that sums to
        0 is still divided by its sum (for numpy rows this yields nan/inf).
    """
    return [row / row.sum() for row in inputs]
|
|
|
|
|
|
def normalize_adj_torch(adj):
    """Symmetrically normalize adjacency matrices: ``D^{-1/2} A D^{-1/2}``.

    ``D`` is the diagonal matrix of row sums (sums over the last dimension)
    of each matrix ``A``. Every ``(N, N)`` matrix in the leading dimensions is
    normalized independently. This covers a plain ``(N, N)`` matrix and a
    4-D ``(B, C, N, N)`` stack, for all batch entries, not only the first.

    A node whose row sum is not strictly positive (isolated node, negative or
    NaN sum) gets a scale factor of 0. Its row and column in the result are
    therefore zero rather than NaN. The substitution happens before the
    power, so no inf/NaN reaches the forward result or the gradients.

    Args:
        adj: Tensor of shape ``(..., N, N)``.

    Returns:
        A new tensor with the same shape as ``adj``. ``adj`` is not modified.
    """
    rowsum = adj.sum(-1)
    positive = rowsum > 0
    # Use 1 for non-positive sums so pow(-0.5) never yields inf/NaN, then
    # zero those scale factors afterwards.
    safe_rowsum = torch.where(positive, rowsum, torch.ones_like(rowsum))
    d_inv_sqrt = safe_rowsum.pow(-0.5)
    d_inv_sqrt = torch.where(positive, d_inv_sqrt, torch.zeros_like(d_inv_sqrt))
    # Elementwise form of diag(d) @ A @ diag(d), broadcast over batch dims:
    # result[..., i, j] = d[..., i] * A[..., i, j] * d[..., j].
    return d_inv_sqrt.unsqueeze(-1) * adj * d_inv_sqrt.unsqueeze(-2)
|
|
|
|
|
|
if __name__ == "__main__":
    # Manual smoke check: print the row-normalized CIHP -> PASCAL mapping,
    # then the raw 0/1 matrix it was computed from.
    normalized = row_norm(cihp2pascal_adj)
    print(normalized)
    print(cihp2pascal_adj)