mirror of
https://github.com/deepinsight/insightface.git
synced 2026-05-12 02:32:41 +00:00
121 lines
3.8 KiB
Python
121 lines
3.8 KiB
Python
'''
FLOPs counter for MXNet symbols: counts Convolution and FullyConnected layers.

@author: insightface
'''
|
|
|
|
from __future__ import absolute_import
|
|
from __future__ import division
|
|
from __future__ import print_function
|
|
|
|
import sys
|
|
import os
|
|
import json
|
|
import argparse
|
|
import numpy as np
|
|
import mxnet as mx
|
|
|
|
|
|
def is_no_bias(attr):
    """Return True if the node attribute dict ``attr`` sets ``no_bias``.

    Operator attributes read from the symbol JSON are strings (e.g. the
    ``kernel`` and ``num_group`` values parsed elsewhere in this module), so
    string values are matched case-insensitively against ``"true"`` and
    ``"1"``. Non-string values (such as a Python ``bool``) are compared with
    ``== True``.

    Args:
        attr: Mapping of attribute name to value, or ``None``.

    Returns:
        ``True`` when ``no_bias`` is present and truthy, otherwise ``False``.
    """
    if not attr or 'no_bias' not in attr:
        return False
    value = attr['no_bias']
    if not isinstance(value, str):
        # Keep the historical semantics for real booleans / numbers.
        return bool(value == True)  # noqa: E712 - intentional equality test
    return value.strip().lower() in ('true', '1')
|
|
|
|
|
|
def count_fc_flops(input_filter, output_filter, attr):
    """Return the FLOP count of a fully-connected layer for one sample.

    Every one of the ``output_filter`` outputs is charged
    ``2 * input_filter`` operations. When ``attr`` requests ``no_bias``,
    one operation per output is removed (there is no bias addition).

    Args:
        input_filter: Number of (flattened) input features.
        output_filter: Number of output units.
        attr: Node attribute dict; only ``no_bias`` is consulted.

    Returns:
        The FLOP count, truncated to ``int``.
    """
    dense_ops = 2 * input_filter * output_filter
    if not is_no_bias(attr):
        return int(dense_ops)
    # Without a bias each output loses its final addition.
    return int(dense_ops - output_filter)
|
|
|
|
|
|
def count_conv_flops(input_shape, output_shape, attr):
    """Return the FLOP count of a ``Convolution`` node for one sample.

    Each output element of a group combines ``C_in / num_group *
    prod(kernel)`` inputs: that many multiplies, one fewer additions, and
    one bias addition when a bias is present. Per output element this is::

        2 * (C_in / num_group) * prod(kernel)        # with bias
        2 * (C_in / num_group) * prod(kernel) - 1    # with no_bias

    The total is that count times the number of output elements
    ``prod(output_shape[1:])``. The batch dimension (index 0) is ignored.
    Kernels of any rank (1-D, 2-D, 3-D) are supported.

    Args:
        input_shape: Input shape; assumes channels-first layout, so index 1
            is the input channel count (TODO confirm for non-NCHW layouts).
        output_shape: Output shape; every dimension after index 0 is
            multiplied in.
        attr: Node attribute dict of strings. ``kernel`` looks like
            ``"(3, 3)"``; ``num_group`` and ``no_bias`` are optional.

    Returns:
        The FLOP count as an ``int``.
    """
    # "(3, 3)" -> [3, 3]; also tolerates 1-D kernels written as "(3,)".
    kernel = [int(k) for k in attr['kernel'].strip('()[] ').split(',')
              if k.strip()]
    num_group = int(attr.get('num_group', 1))
    in_channels_per_group = input_shape[1] // num_group

    per_output = 2 * in_channels_per_group * int(np.prod(kernel))
    if is_no_bias(attr):
        # The "-1" applies per output element, AFTER dividing by groups;
        # subtracting before the group division overcounts grouped convs.
        per_output -= 1

    num_outputs = int(np.prod(output_shape[1:]))
    return int(per_output * num_outputs)
|
|
|
|
|
|
def _node_output_shapes(sym, nodes, data_shapes):
    """Map each node index in ``nodes`` to the inferred shape of its output."""
    internals = sym.get_internals()
    _, out_shapes, _ = internals.infer_shape(**data_shapes)
    out_shape_dict = dict(zip(internals.list_outputs(), out_shapes))

    shapes = {}
    for nodeid, node in enumerate(nodes):
        name = node['name']
        # Operator outputs are listed as "<name>_output". Fall back to the
        # bare name so graph inputs (e.g. "data") that feed a conv/FC
        # directly can also be resolved instead of raising KeyError later.
        for key in (name + '_output', name):
            if key in out_shape_dict:
                shapes[nodeid] = out_shape_dict[key]
                break
    return shapes


def count_flops(sym, **data_shapes):
    """Estimate the FLOPs of ``sym``'s Convolution and FullyConnected layers.

    All other operators contribute zero. Shapes are obtained via MXNet shape
    inference from ``data_shapes``. Per-layer counts ignore the batch
    dimension.

    Args:
        sym: MXNet symbol to analyse.
        **data_shapes: Input shapes forwarded to ``infer_shape``, e.g.
            ``data=(1, 3, 112, 112)``.

    Returns:
        Total FLOPs as an ``int``.
    """
    nodes = json.loads(sym.tojson())['nodes']
    shapes = _node_output_shapes(sym, nodes, data_shapes)

    total = 0
    for nodeid, node in enumerate(nodes):
        op = node['op']
        if op not in ('Convolution', 'FullyConnected'):
            continue
        attr = node.get('attrs', {})
        output_shape = shapes[nodeid]
        # inputs[0] is the data input: [source_node_id, output_index, ...].
        input_shape = shapes[node['inputs'][0][0]]

        if op == 'Convolution':
            total += count_conv_flops(input_shape, output_shape, attr)
        else:
            # Assumes flatten=True (MXNet's default): all non-batch input
            # dims are flattened into features, so 2-D and 4-D inputs work.
            input_filter = int(np.prod(input_shape[1:]))
            output_filter = output_shape[1]
            total += count_fc_flops(input_filter, output_filter, attr)
    return total
|
|
|
|
|
|
def flops_str(FLOPs):
    """Format a FLOP count with a metric suffix and one decimal place.

    The largest unit among T (1e12), G (1e9), M (1e6) and K (1e3) that fits
    at least once into ``FLOPs`` is used, e.g. ``12500 -> "12.5K"``.
    Values below 1e3 are printed without a suffix.
    """
    units = ((1e12, 'T'), (1e9, 'G'), (1e6, 'M'), (1e3, 'K'))
    for scale, suffix in units:
        # Floor-division test: true once FLOPs reaches one whole unit.
        if FLOPs // scale > 0:
            return "%.1f%s" % (FLOPs / scale, suffix)
    return "%.1f" % (FLOPs)
|
|
|
|
|
|
if __name__ == '__main__':
    # Load an MXNet checkpoint, cut the graph at --output-layer and print the
    # estimated FLOPs for a single (1, 3, H, W) input.
    parser = argparse.ArgumentParser(description='flops counter')
    # general
    parser.add_argument('--model',
                        default='../models2/r50fc-arcface-emore/model,1',
                        help='path to load model.')
    parser.add_argument('--image-size',
                        default='112,112',
                        help='input height,width used for shape inference.')
    parser.add_argument('--output-layer',
                        default='fc1_output',
                        help='internal output whose sub-graph is counted.')
    args = parser.parse_args()

    # --model is "<checkpoint prefix>,<epoch>"; split on the LAST comma so a
    # prefix that itself contains commas still parses. Validate with
    # parser.error instead of assert (asserts are stripped under -O).
    _vec = args.model.rsplit(',', 1)
    if len(_vec) != 2:
        parser.error('--model must look like "prefix,epoch"')
    prefix = _vec[0]
    try:
        epoch = int(_vec[1])
    except ValueError:
        parser.error('epoch in --model must be an integer, got %r' % _vec[1])
    try:
        height, width = [int(x) for x in args.image_size.split(',')]
    except ValueError:
        parser.error('--image-size must look like "height,width"')

    print('loading', prefix, epoch)
    sym, arg_params, aux_params = mx.model.load_checkpoint(prefix, epoch)
    all_layers = sym.get_internals()
    sym = all_layers[args.output_layer]
    FLOPs = count_flops(sym, data=(1, 3, height, width))
    print('FLOPs:', FLOPs)
|