import os import json import os.path as osp import io import torch import numpy as np from mmcv import Config from mmdet.models import build_detector from mmcv.cnn import get_model_complexity_info def get_flops(cfg, input_shape): model = build_detector( cfg.model, train_cfg=cfg.train_cfg, test_cfg=cfg.test_cfg) if torch.cuda.is_available(): model.cuda() model.eval() if hasattr(model, 'forward_dummy'): model.forward = model.forward_dummy else: raise NotImplementedError( 'FLOPs counter is currently not currently supported with {}'. format(model.__class__.__name__)) buf = io.StringIO() all_flops, params = get_model_complexity_info(model, input_shape, print_per_layer_stat=True, as_strings=False, ost=buf) buf = buf.getvalue() lines = buf.split("\n") names = ['(stem)', '(layer1)', '(layer2)', '(layer3)', '(layer4)', '(neck)', '(bbox_head)'] name_ptr = 0 line_num = 0 _flops = [] while name_ptr