#!/usr/bin/env python
# -*- encoding: utf-8 -*-
"""
@Author  :   Qingping Zheng
@Contact :   qingpingzheng2014@gmail.com
@File    :   ddgcn.py
@Time    :   10/01/21 00:00 PM
@Desc    :
@License :   Licensed under the Apache License, Version 2.0 (the "License");
@Copyright : Copyright 2022 The Authors. All Rights Reserved.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import torch
import torch.nn.functional as F
import torch.nn as nn

from inplace_abn import InPlaceABNSync

class SpatialGCN(nn.Module):
    """Spatial graph reasoning over the pixels of a feature map.

    The input is projected into key/query/value maps by 1x1 convolutions,
    a channel-affinity matrix is built from query @ value, normalised with
    a softmax, and propagated back through the key map.  A 1-D graph
    convolution (``conv_wg`` + BN) updates the node states, and the result
    is fused with the input through a zero-initialised residual gate
    ``gamma`` (so the module starts out as the identity).
    """

    def __init__(self, plane, abn=InPlaceABNSync):
        super(SpatialGCN, self).__init__()
        hidden = plane // 2

        # 1x1 projections into the reduced (plane // 2) graph space.
        self.node_k = nn.Conv2d(plane, hidden, kernel_size=1)
        self.node_v = nn.Conv2d(plane, hidden, kernel_size=1)
        self.node_q = nn.Conv2d(plane, hidden, kernel_size=1)

        # Graph state update: 1-D conv over node features, then BN.
        self.conv_wg = nn.Conv1d(hidden, hidden, kernel_size=1, bias=False)
        self.bn_wg = nn.BatchNorm1d(hidden)
        self.softmax = nn.Softmax(dim=2)

        # Re-projection back to the original channel count.
        self.out = nn.Sequential(
            nn.Conv2d(hidden, plane, kernel_size=1),
            abn(plane))

        # Residual gate, zero-initialised.
        self.gamma = nn.Parameter(torch.zeros(1))

    def forward(self, x):
        k = self.node_k(x)
        v = self.node_v(x)
        q = self.node_q(x)

        b, c, h, w = k.size()
        # Flatten the spatial grid: k, v -> (B, HW, C'); q -> (B, C', HW).
        k = k.view(b, c, -1).permute(0, 2, 1)
        q = q.view(b, c, -1)
        v = v.view(b, c, -1).permute(0, 2, 1)

        # A = k * q ; AV = k * (q * v) ; AVW = k * (q * v) * w
        # Channel affinity (B, C', C'), softmax-normalised, spread over pixels.
        att = self.softmax(torch.bmm(q, v))
        att = torch.bmm(k, att)

        # (B, C', HW) layout for the 1-D graph convolution.
        att = att.transpose(1, 2).contiguous()
        att = self.bn_wg(self.conv_wg(att))
        att = att.view(b, c, h, -1)

        # Gated residual fusion with the input.
        return self.gamma * self.out(att) + x
class DDualGCN(nn.Module):
    """Feature GCN with coordinate GCN.

    Two branches: (1) a *local* branch downsamples the input with three
    stride-2 depthwise convolutions, applies :class:`SpatialGCN`, and uses
    the upsampled result to modulate the full-resolution features;
    (2) a *feature* branch projects the input into a low-dimensional
    interaction space (``phi`` / ``theta``), performs graph reasoning
    there, re-projects, and adds it back through a zero-initialised gate
    ``gamma1``.  The branches are fused by concatenation + 1x1 conv.
    """

    def __init__(self, planes, abn=InPlaceABNSync, ratio=4):
        super(DDualGCN, self).__init__()
        reduced = planes // ratio

        # Projection Space: phi yields node features, theta the projection basis.
        self.phi = nn.Conv2d(planes, reduced * 2, kernel_size=1, bias=False)
        self.bn_phi = abn(reduced * 2)
        self.theta = nn.Conv2d(planes, reduced, kernel_size=1, bias=False)
        self.bn_theta = abn(reduced)

        # Interaction Space — Adjacency Matrix: (-)A_g
        self.conv_adj = nn.Conv1d(reduced, reduced, kernel_size=1, bias=False)
        self.bn_adj = nn.BatchNorm1d(reduced)

        # State Update Function: W_g
        self.conv_wg = nn.Conv1d(reduced * 2, reduced * 2, kernel_size=1, bias=False)
        self.bn_wg = nn.BatchNorm1d(reduced * 2)

        # Last fc: re-projection back to the full channel count.
        self.conv3 = nn.Conv2d(reduced * 2, planes, kernel_size=1, bias=False)
        self.bn3 = abn(planes)

        # 8x downsampling path feeding the local spatial-GCN branch.
        self.local = nn.Sequential(
            nn.Conv2d(planes, planes, 3, groups=planes, stride=2, padding=1, bias=False),
            abn(planes),
            nn.Conv2d(planes, planes, 3, groups=planes, stride=2, padding=1, bias=False),
            abn(planes),
            nn.Conv2d(planes, planes, 3, groups=planes, stride=2, padding=1, bias=False),
            abn(planes))
        self.gcn_local_attention = SpatialGCN(planes, abn)

        # Branch fusion: concat -> 1x1 conv.
        self.final = nn.Sequential(
            nn.Conv2d(planes * 2, planes, kernel_size=1, bias=False),
            abn(planes))

        # Zero-initialised gate for the feature-GCN residual.
        self.gamma1 = nn.Parameter(torch.zeros(1))

    def to_matrix(self, x):
        """Collapse the spatial grid: (N, C, H, W) -> (N, C, H*W)."""
        n, c = x.size(0), x.size(1)
        return x.view(n, c, -1)

    def forward(self, feat):
        x = feat

        # # # # Local # # # #
        local = self.gcn_local_attention(self.local(feat))
        local = F.interpolate(local, size=x.size()[2:], mode='bilinear',
                              align_corners=True)
        spatial_local_feat = x * local + x

        # # # # Projection Space # # # #
        x_sqz = self.to_matrix(self.bn_phi(self.phi(x)))
        basis = self.to_matrix(self.bn_theta(self.theta(x)))

        # Project onto the basis (channel graph nodes).
        z_idt = torch.matmul(x_sqz, basis.transpose(1, 2))

        # # # # Interaction Space # # # #
        z = z_idt.transpose(1, 2).contiguous()
        z = self.bn_adj(self.conv_adj(z))
        z = z.transpose(1, 2).contiguous()
        # Laplacian smoothing: (I - A_g)Z => Z - A_gZ
        z = z + z_idt

        z = self.bn_wg(self.conv_wg(z))

        # # # # Re-projection Space # # # #
        y = torch.matmul(z, basis)
        n, _, h, w = x.size()
        y = self.bn3(self.conv3(y.view(n, -1, h, w)))

        # Gated residual for the feature-GCN branch.
        g_out = self.gamma1 * y + x

        # cat or sum, nearly the same results
        return self.final(torch.cat((spatial_local_feat, g_out), 1))
class DDualGCNHead(nn.Module):
    """Head wrapping DDualGCN: 3x3 conv -> DDualGCN -> 3x3 conv -> bottleneck.

    The bottleneck concatenates the head's raw input with the processed
    features before a final 3x3 convolution, so backbone information is
    preserved alongside the graph-reasoned features.
    """

    def __init__(self, inplanes, interplanes, abn=InPlaceABNSync):
        super(DDualGCNHead, self).__init__()
        self.conva = nn.Sequential(
            nn.Conv2d(inplanes, interplanes, 3, padding=1, bias=False),
            abn(interplanes))
        self.dualgcn = DDualGCN(interplanes, abn)
        self.convb = nn.Sequential(
            nn.Conv2d(interplanes, interplanes, 3, padding=1, bias=False),
            abn(interplanes))

        self.bottleneck = nn.Sequential(
            nn.Conv2d(inplanes + interplanes, interplanes, kernel_size=3,
                      padding=1, dilation=1, bias=False),
            abn(interplanes)
        )

    def forward(self, x):
        out = self.convb(self.dualgcn(self.conva(x)))
        return self.bottleneck(torch.cat([x, out], 1))