mirror of
https://github.com/deepinsight/insightface.git
synced 2026-05-19 15:41:33 +00:00
67 lines
2.6 KiB
Python
Executable File
67 lines
2.6 KiB
Python
Executable File
from ..builder import DETECTORS
|
|
from .faster_rcnn import FasterRCNN
|
|
|
|
|
|
@DETECTORS.register_module()
class TridentFasterRCNN(FasterRCNN):
    """Implementation of `TridentNet <https://arxiv.org/abs/1901.01892>`_

    Faster R-CNN variant whose backbone and RoI head both expose
    ``num_branch`` and ``test_branch_idx``. Image metas (and, in training,
    ground truths) are repeated ``num_branch`` times before being handed to
    the RPN / RoI heads so that they line up with the multi-branch features.
    """

    def __init__(self,
                 backbone,
                 rpn_head,
                 roi_head,
                 train_cfg,
                 test_cfg,
                 neck=None,
                 pretrained=None):
        """Build the detector and cache the branch settings.

        All arguments are forwarded unchanged to ``FasterRCNN.__init__``.

        Raises:
            AssertionError: If the backbone and the RoI head disagree on
                ``num_branch`` or ``test_branch_idx``.
        """
        super(TridentFasterRCNN, self).__init__(
            backbone=backbone,
            neck=neck,
            rpn_head=rpn_head,
            roi_head=roi_head,
            train_cfg=train_cfg,
            test_cfg=test_cfg,
            pretrained=pretrained)
        # Backbone and RoI head must be configured with the same branch
        # layout; otherwise the replicated metas would not match.
        assert self.backbone.num_branch == self.roi_head.num_branch
        assert self.backbone.test_branch_idx == self.roi_head.test_branch_idx
        self.num_branch = self.backbone.num_branch
        self.test_branch_idx = self.backbone.test_branch_idx

    def simple_test(self, img, img_metas, proposals=None, rescale=False):
        """Test without augmentation.

        Args:
            img: Input passed to ``self.extract_feat``.
            img_metas: Image meta information; repeated with ``*`` below, so
                presumably a list with one entry per image (confirm against
                the caller).
            proposals: Optional precomputed proposals. When ``None`` the
                RPN head generates them.
            rescale: Forwarded to ``self.roi_head.simple_test``.

        Returns:
            Whatever ``self.roi_head.simple_test`` returns.
        """
        assert self.with_bbox, 'Bbox head must be implemented.'
        x = self.extract_feat(img)
        # Build the replicated metas unconditionally: the RoI head call below
        # needs them even when proposals are supplied by the caller.
        # Previously they were only defined in the ``proposals is None``
        # branch, so passing proposals raised UnboundLocalError.
        # All branches are used when test_branch_idx == -1, otherwise one.
        num_branch = (self.num_branch if self.test_branch_idx == -1 else 1)
        trident_img_metas = img_metas * num_branch
        if proposals is None:
            proposal_list = self.rpn_head.simple_test_rpn(x, trident_img_metas)
        else:
            proposal_list = proposals

        return self.roi_head.simple_test(
            x, proposal_list, trident_img_metas, rescale=rescale)

    def aug_test(self, imgs, img_metas, rescale=False):
        """Test with augmentations.

        If rescale is False, then returned bboxes and masks will fit the scale
        of imgs[0].
        """
        x = self.extract_feats(imgs)
        num_branch = (self.num_branch if self.test_branch_idx == -1 else 1)
        # Repeat each augmentation's metas once per active branch for the RPN.
        trident_img_metas = [metas * num_branch for metas in img_metas]
        proposal_list = self.rpn_head.aug_test_rpn(x, trident_img_metas)
        # NOTE: the RoI head receives the original (non-replicated) metas.
        return self.roi_head.aug_test(
            x, proposal_list, img_metas, rescale=rescale)

    def forward_train(self, img, img_metas, gt_bboxes, gt_labels, **kwargs):
        """make copies of img and gts to fit multi-branch.

        ``img_metas``, ``gt_bboxes`` and ``gt_labels`` are each repeated
        ``self.num_branch`` times and converted to tuples before delegating
        to ``FasterRCNN.forward_train``.

        NOTE(review): ``**kwargs`` is accepted but not forwarded to the
        parent; any extra keyword arguments are ignored here.
        """
        trident_gt_bboxes = tuple(gt_bboxes * self.num_branch)
        trident_gt_labels = tuple(gt_labels * self.num_branch)
        trident_img_metas = tuple(img_metas * self.num_branch)

        return super(TridentFasterRCNN,
                     self).forward_train(img, trident_img_metas,
                                         trident_gt_bboxes, trident_gt_labels)
|