mirror of
https://github.com/deepinsight/insightface.git
synced 2026-05-20 00:10:28 +00:00
75 lines
2.7 KiB
Python
Executable File
75 lines
2.7 KiB
Python
Executable File
import torch.nn as nn
|
|
from mmcv.cnn import ConvModule, xavier_init
|
|
|
|
from ..builder import NECKS
|
|
|
|
|
|
@NECKS.register_module()
class ChannelMapper(nn.Module):
    r"""Channel Mapper to reduce/increase channels of backbone features.

    This is used to reduce/increase channels of backbone features. Each input
    scale gets its own ``ConvModule`` mapping ``in_channels[i]`` channels to
    ``out_channels`` channels, so every output level has the same channel
    count.

    Args:
        in_channels (List[int] | Tuple[int]): Number of input channels per
            scale.
        out_channels (int): Number of output channels (used at each scale).
        kernel_size (int, optional): kernel_size for reducing channels (used
            at each scale). Padding is ``(kernel_size - 1) // 2``, which keeps
            the spatial size for odd kernel sizes (assuming ConvModule's
            default stride of 1). Default: 3.
        conv_cfg (dict, optional): Config dict for convolution layer.
            Default: None.
        norm_cfg (dict, optional): Config dict for normalization layer.
            Default: None.
        act_cfg (dict, optional): Config dict for activation layer in
            ConvModule. Default: dict(type='ReLU').

    Example:
        >>> import torch
        >>> in_channels = [2, 3, 5, 7]
        >>> scales = [340, 170, 84, 43]
        >>> inputs = [torch.rand(1, c, s, s)
        ...           for c, s in zip(in_channels, scales)]
        >>> self = ChannelMapper(in_channels, 11, 3).eval()
        >>> outputs = self.forward(inputs)
        >>> for i in range(len(outputs)):
        ...     print(f'outputs[{i}].shape = {outputs[i].shape}')
        outputs[0].shape = torch.Size([1, 11, 340, 340])
        outputs[1].shape = torch.Size([1, 11, 170, 170])
        outputs[2].shape = torch.Size([1, 11, 84, 84])
        outputs[3].shape = torch.Size([1, 11, 43, 43])
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size=3,
                 conv_cfg=None,
                 norm_cfg=None,
                 act_cfg=dict(type='ReLU')):
        # NOTE: the ``act_cfg`` default dict is shared across calls; this
        # class only forwards it to ConvModule and never mutates it.
        super().__init__()
        # Only iteration over ``in_channels`` is needed, so tuples are
        # accepted alongside lists.
        assert isinstance(in_channels, (list, tuple)), (
            f'in_channels must be a list or tuple, got {type(in_channels)}')

        # "Same"-style padding for odd kernel sizes.
        padding = (kernel_size - 1) // 2
        # One channel-mapping conv per input scale, in ``in_channels`` order.
        self.convs = nn.ModuleList([
            ConvModule(
                in_channel,
                out_channels,
                kernel_size,
                padding=padding,
                conv_cfg=conv_cfg,
                norm_cfg=norm_cfg,
                act_cfg=act_cfg) for in_channel in in_channels
        ])

    def init_weights(self):
        """Initialize the weights of ChannelMapper module.

        ConvModule applies its own default initialization (msra for the conv,
        plus its norm initialization). This method overrides every
        ``nn.Conv2d`` found in the module tree with uniform Xavier
        initialization. Normalization layers are left untouched.
        """
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                xavier_init(m, distribution='uniform')

    def forward(self, inputs):
        """Forward function.

        Args:
            inputs (Sequence[Tensor]): One feature map per scale, in the same
                order as ``in_channels``. ``inputs[i]`` is fed to the conv
                built for ``in_channels[i]`` input channels.

        Returns:
            tuple[Tensor]: One mapped feature map per scale, each with
                ``out_channels`` channels.
        """
        assert len(inputs) == len(self.convs), (
            f'expected {len(self.convs)} inputs, got {len(inputs)}')
        return tuple(conv(x) for conv, x in zip(self.convs, inputs))
|