'''
@author: insightface

Estimate the FLOPs of an MXNet symbol by summing the cost of its
Convolution and FullyConnected nodes.
'''

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import sys
import os
import json
import argparse
import numpy as np

# The counting helpers below are pure Python and only need a symbol-like
# object handed to them, so a missing MXNet install must not prevent importing
# this module.  The original error is kept and re-raised by main(), which is
# the only place that calls into MXNet directly.
_MXNET_IMPORT_ERROR = None
try:
    import mxnet as mx
except ImportError as _err:
    mx = None
    _MXNET_IMPORT_ERROR = _err


def _prod(values):
    '''Return the integer product of an iterable of ints (1 when empty).'''
    result = 1
    for value in values:
        result *= int(value)
    return result


def is_no_bias(attr):
    '''Return True when a node attribute dict switches the bias term off.

    The flag is accepted both as the boolean ``True`` and as the string
    ``'True'``; a missing ``no_bias`` key means the bias is present.
    '''
    return attr.get('no_bias') in (True, 'True')


def count_fc_flops(input_filter, output_filter, attr):
    '''Return the FLOPs of a fully connected layer.

    input_filter  -- number of input features per sample.
    output_filter -- number of output features per sample.
    attr          -- node attribute dict (only ``no_bias`` is consulted).

    Every output needs ``input_filter`` multiplications and
    ``input_filter - 1`` additions, plus one more addition for the bias.
    '''
    ret = 2 * input_filter * output_filter
    if is_no_bias(attr):
        ret -= output_filter
    return int(ret)


def count_conv_flops(input_shape, output_shape, attr):
    '''Return the FLOPs of a convolution layer.

    input_shape  -- shape of the input, channels on axis 1.
    output_shape -- shape of the output, channels on axis 1 and spatial
                    dimensions from axis 2 onwards.
    attr         -- node attribute dict; ``kernel`` is a string such as
                    ``'(3, 3)'``, ``num_group`` and ``no_bias`` are optional.

    Kernels of any rank are supported: the kernel extents and the output
    spatial extents are each multiplied together.
    '''
    # Drop the surrounding brackets and ignore empty tokens so that a
    # one-element tuple such as "(3,)" parses as well.
    kernel = [int(tok) for tok in attr['kernel'].strip('()[] ').split(',')
              if tok.strip()]
    num_group = int(attr.get('num_group', 1))
    # With grouping each output channel only reads its own group's inputs.
    # The channel count has to be reduced *before* the bias correction below,
    # otherwise the "-1" would be divided by num_group too.
    in_channels = input_shape[1] // num_group
    per_output = 2 * in_channels * _prod(kernel)
    if is_no_bias(attr):
        # Without a bias there is one addition less per output element.
        per_output -= 1
    return int(per_output * output_shape[1] * _prod(output_shape[2:]))


def count_flops(sym, **data_shapes):
    '''Return the total FLOPs of all Convolution/FullyConnected nodes in sym.

    sym         -- symbol providing ``get_internals()`` and ``tojson()``.
    data_shapes -- input shapes forwarded to ``infer_shape``,
                   e.g. ``data=(1, 3, 112, 112)``.

    Nodes of any other operator type contribute nothing to the total.
    Raises KeyError when the shape of a counted node or of its first input
    cannot be resolved.
    '''
    all_layers = sym.get_internals()
    _, out_shapes, _ = all_layers.infer_shape(**data_shapes)
    out_shape_dict = dict(zip(all_layers.list_outputs(), out_shapes))
    nodes = json.loads(sym.tojson())['nodes']

    nodeid_shape = {}
    for nodeid, node in enumerate(nodes):
        name = node['name']
        shape = out_shape_dict.get(name + "_output")
        if shape is None and node['op'] == 'null':
            # Variables (e.g. the network input) are listed under their bare
            # name; without this a layer fed directly by one cannot be sized.
            shape = out_shape_dict.get(name)
        if shape is not None:
            nodeid_shape[nodeid] = shape

    total = 0
    for nodeid, node in enumerate(nodes):
        op = node['op']
        if op not in ('Convolution', 'FullyConnected'):
            continue
        attr = node['attrs']
        output_shape = nodeid_shape[nodeid]
        # The first input of both operators is the data tensor.
        input_shape = nodeid_shape[node['inputs'][0][0]]
        if op == 'Convolution':
            total += count_conv_flops(input_shape, output_shape, attr)
        else:
            # Every non-batch axis feeds the layer, whatever the input rank.
            input_filter = _prod(input_shape[1:])
            total += count_fc_flops(input_filter, output_shape[1], attr)
    return total


def flops_str(FLOPs):
    '''Format a FLOP count with one decimal and a K/M/G/T suffix.

    Values below 1000 are printed without a suffix.
    '''
    preset = [(1e12, 'T'), (1e9, 'G'), (1e6, 'M'), (1e3, 'K')]
    for scale, suffix in preset:
        if FLOPs // scale > 0:
            return "%.1f%s" % (FLOPs / scale, suffix)
    return "%.1f" % (FLOPs)


def main():
    '''Load a checkpoint given as ``prefix,epoch`` and print its FLOPs.'''
    parser = argparse.ArgumentParser(description='flops counter')
    # general
    #parser.add_argument('--model', default='../models2/y2-arcface-retinat1/model,1', help='path to load model.')
    #parser.add_argument('--model', default='../models2/r100fc-arcface-retinaa/model,1', help='path to load model.')
    parser.add_argument('--model', default='../models2/r50fc-arcface-emore/model,1', help='path to load model.')
    args = parser.parse_args()
    _vec = args.model.split(',')
    # Report bad input as a usage error; an assert would vanish under -O.
    if len(_vec) != 2:
        parser.error("--model must look like 'prefix,epoch', got %r" % (args.model,))
    prefix = _vec[0]
    epoch = int(_vec[1])
    if mx is None:
        # Surface the import failure that was deferred at module load.
        raise _MXNET_IMPORT_ERROR
    print('loading', prefix, epoch)
    sym, arg_params, aux_params = mx.model.load_checkpoint(prefix, epoch)
    all_layers = sym.get_internals()
    sym = all_layers['fc1_output']
    FLOPs = count_flops(sym, data=(1, 3, 112, 112))
    print('FLOPs:', FLOPs)


if __name__ == '__main__':
    main()