''' @author: insightface ''' from __future__ import absolute_import from __future__ import division from __future__ import print_function import sys import os import json import argparse import numpy as np import mxnet as mx def is_no_bias(attr): ret = False if 'no_bias' in attr and (attr['no_bias'] == True or attr['no_bias'] == 'True'): ret = True return ret def count_fc_flops(input_filter, output_filter, attr): #print(input_filter, output_filter ,attr) ret = 2 * input_filter * output_filter if is_no_bias(attr): ret -= output_filter return int(ret) def count_conv_flops(input_shape, output_shape, attr): kernel = attr['kernel'][1:-1].split(',') kernel = [int(x) for x in kernel] #print('kernel', kernel) if is_no_bias(attr): ret = (2 * input_shape[1] * kernel[0] * kernel[1] - 1) * output_shape[2] * output_shape[3] * output_shape[1] else: ret = 2 * input_shape[1] * kernel[0] * kernel[1] * output_shape[ 2] * output_shape[3] * output_shape[1] num_group = 1 if 'num_group' in attr: num_group = int(attr['num_group']) ret /= num_group return int(ret) def count_flops(sym, **data_shapes): all_layers = sym.get_internals() #print(all_layers) arg_shapes, out_shapes, aux_shapes = all_layers.infer_shape(**data_shapes) out_shape_dict = dict(zip(all_layers.list_outputs(), out_shapes)) nodes = json.loads(sym.tojson())['nodes'] nodeid_shape = {} for nodeid, node in enumerate(nodes): name = node['name'] layer_name = name + "_output" if layer_name in out_shape_dict: nodeid_shape[nodeid] = out_shape_dict[layer_name] #print(nodeid_shape) FLOPs = 0 for nodeid, node in enumerate(nodes): flops = 0 if node['op'] == 'Convolution': output_shape = nodeid_shape[nodeid] name = node['name'] attr = node['attrs'] input_nodeid = node['inputs'][0][0] input_shape = nodeid_shape[input_nodeid] flops = count_conv_flops(input_shape, output_shape, attr) elif node['op'] == 'FullyConnected': attr = node['attrs'] output_shape = nodeid_shape[nodeid] input_nodeid = node['inputs'][0][0] input_shape = nodeid_shape[input_nodeid] output_filter = output_shape[1] input_filter = input_shape[1] * input_shape[2] * input_shape[3] #assert len(input_shape)==4 and input_shape[2]==1 and input_shape[3]==1 flops = count_fc_flops(input_filter, output_filter, attr) #print(node, flops) FLOPs += flops return FLOPs def flops_str(FLOPs): preset = [(1e12, 'T'), (1e9, 'G'), (1e6, 'M'), (1e3, 'K')] for p in preset: if FLOPs // p[0] > 0: N = FLOPs / p[0] ret = "%.1f%s" % (N, p[1]) return ret ret = "%.1f" % (FLOPs) return ret if __name__ == '__main__': parser = argparse.ArgumentParser(description='flops counter') # general #parser.add_argument('--model', default='../models2/y2-arcface-retinat1/model,1', help='path to load model.') #parser.add_argument('--model', default='../models2/r100fc-arcface-retinaa/model,1', help='path to load model.') parser.add_argument('--model', default='../models2/r50fc-arcface-emore/model,1', help='path to load model.') args = parser.parse_args() _vec = args.model.split(',') assert len(_vec) == 2 prefix = _vec[0] epoch = int(_vec[1]) print('loading', prefix, epoch) sym, arg_params, aux_params = mx.model.load_checkpoint(prefix, epoch) all_layers = sym.get_internals() sym = all_layers['fc1_output'] FLOPs = count_flops(sym, data=(1, 3, 112, 112)) print('FLOPs:', FLOPs)