# Mirror of https://github.com/deepinsight/insightface.git
# (synced 2025-12-30 08:02:27 +00:00; 126 lines, 4.7 KiB, Python)
from collections import OrderedDict, Counter
|
|
|
|
from caffe.proto import caffe_pb2
|
|
from google import protobuf
|
|
import six
|
|
|
|
def param_name_dict():
    """Build a lookup from caffe layer type names to parameter field prefixes.

    Every field of ``caffe_pb2.LayerParameter`` whose name ends in ``_param``
    contributes one entry:

    * the key is the Python class name of that field's value with its last
      ``len('Parameter')`` characters removed;
    * the value is the field name with its ``'_param'`` suffix removed.

    Not every such field corresponds to a layer type; those entries are kept
    as well.

    Returns:
        dict mapping type-name prefix -> field-name prefix. If two fields
        yield the same key, the one appearing later in the descriptor wins.
    """
    field_suffix = '_param'
    type_suffix = 'Parameter'
    template = caffe_pb2.LayerParameter()
    lookup = {}
    for descriptor_field in template.DESCRIPTOR.fields:
        field_name = descriptor_field.name
        if not field_name.endswith(field_suffix):
            continue
        # The field's message class name, e.g. 'ConvolutionParameter'.
        class_name = type(getattr(template, field_name)).__name__
        lookup[class_name[:-len(type_suffix)]] = field_name[:-len(field_suffix)]
    return lookup
|
|
|
|
def assign_proto(proto, name, val):
    """Assign a Python object to field ``name`` of protobuf message ``proto``.

    The assignment is driven by the Python type of ``val`` and recurses into
    nested values:

    * lists become repeated fields; a list of dicts becomes repeated
      sub-messages, one ``add()``-ed element per dict;
    * dicts become (sub-)messages whose fields are assigned key by key;
    * anything else is assigned directly with ``setattr``.

    For convenience, a repeated field given a non-list value is promoted.
    A tuple is converted to a list (``kernel_size=(3, 3)`` -> ``[3, 3]``).
    Any other single value becomes a one-element list
    (``my_repeated_int_field=3`` -> ``[3]``). An empty list leaves the
    repeated field unchanged.

    Args:
        proto: message object that owns the field.
        name: name of the field to assign.
        val: value to assign (scalar, list, tuple or dict).
    """
    field = getattr(proto, name)
    # Repeated containers are the only field values that expose extend().
    is_repeated_field = hasattr(field, 'extend')
    if is_repeated_field:
        if isinstance(val, tuple):
            val = list(val)
        elif not isinstance(val, list):
            val = [val]

    if isinstance(val, list):
        # Check emptiness first: indexing val[0] on [] would raise IndexError.
        if val and isinstance(val[0], dict):
            for item in val:
                proto_item = field.add()
                for k, v in item.items():
                    assign_proto(proto_item, k, v)
        else:
            field.extend(val)
    elif isinstance(val, dict):
        for k, v in val.items():
            assign_proto(field, k, v)
    else:
        setattr(proto, name, val)
|
|
|
|
class Function(object):
    """A layer description that can be serialized to a caffe ``LayerParameter``.

    A Function records the layer type, the layer name, the names of its input
    (bottom) and output (top) blobs, and any layer parameters passed as
    keyword arguments.

    Args:
        type_name: caffe layer type, e.g. ``'Convolution'``.
        layer_name: value written to ``LayerParameter.name``.
        inputs: iterable of bottom names, copied verbatim into
            ``LayerParameter.bottom`` (presumably strings; confirm at callers).
        outputs: iterable of top names, copied verbatim into
            ``LayerParameter.top``. Ignored when ``in_place`` is true.
        **params: layer parameters. Two keys are consumed here and are not
            forwarded to the proto: ``ntop`` (default 1), and ``in_place``
            (default False; when true, the tops are set to the bottoms).
    """

    def __init__(self, type_name, layer_name, inputs, outputs, **params):
        self.type_name = type_name
        self.layer_name = layer_name
        self.inputs = inputs
        self.outputs = outputs
        # pop() so these control kwargs are not later treated as layer params.
        self.ntop = params.pop('ntop', 1)
        self.in_place = params.pop('in_place', False)
        self.params = params

    def _get_name(self, names, autonames):
        """Return a name for this function and memoize it in ``names``.

        When ``self.tops`` has been assigned (it is never set by ``__init__``)
        and is non-empty, and ``ntop > 0``, the name of the first top is used.
        Otherwise an automatic name ``<type_name><count>`` is generated.
        ``autonames`` is expected to default missing counts to 0 (e.g. a
        ``collections.Counter``).
        """
        if self not in names:
            # Reading self.tops unconditionally raised AttributeError because
            # __init__ never assigns it; fall back to an automatic name.
            tops = getattr(self, 'tops', None)
            if self.ntop > 0 and tops:
                names[self] = self._get_top_name(tops[0], names, autonames)
            else:
                autonames[self.type_name] += 1
                names[self] = self.type_name + str(autonames[self.type_name])
        return names[self]

    def _get_top_name(self, top, names, autonames):
        """Return an automatic name for ``top`` and memoize it in ``names``.

        ``top`` must expose ``top.fn.type_name``. The name is that type name
        followed by a per-type counter taken from ``autonames``.
        """
        if top not in names:
            autonames[top.fn.type_name] += 1
            names[top] = top.fn.type_name + str(autonames[top.fn.type_name])
        return names[top]

    def _to_proto(self):
        """Serialize this function into a new ``caffe_pb2.LayerParameter``.

        Keyword params ending in ``'param'`` are assigned directly on the
        layer. Other keyword params are first tried on the type-specific
        ``<prefix>_param`` sub-message (prefix looked up in ``_param_names``).
        """
        layer = caffe_pb2.LayerParameter()
        layer.type = self.type_name
        layer.bottom.extend(self.inputs)
        if self.in_place:
            # In-place layers write their output back into their input blobs.
            layer.top.extend(layer.bottom)
        else:
            layer.top.extend(self.outputs)
        layer.name = self.layer_name

        for k, v in self.params.items():
            # Special case for generic *_param sub-messages.
            if k.endswith('param'):
                assign_proto(layer, k, v)
                continue
            try:
                sub_param = getattr(layer, _param_names[self.type_name] + '_param')
                assign_proto(sub_param, k, v)
            except (AttributeError, KeyError):
                # Unknown layer type, or the key is not a field of the
                # sub-message: treat it as a top-level LayerParameter field.
                assign_proto(layer, k, v)

        return layer
|
|
|
|
class Layers(object):
    """Pseudo-module whose attributes are ``Function`` factories.

    The attribute name becomes the layer type, and all call arguments are
    forwarded to ``Function``. For example,
    ``Layers().Convolution('conv1', ['data'], ['conv1'], kernel_size=3)``
    returns
    ``Function('Convolution', 'conv1', ['data'], ['conv1'], kernel_size=3)``.
    """

    def __getattr__(self, name):
        def layer_fn(*args, **kwargs):
            # Forward arguments unpacked so they line up with
            # Function(type_name, layer_name, inputs, outputs, **params).
            # Passing the args tuple and kwargs dict as two positional
            # arguments always raised TypeError (missing 'outputs').
            return Function(name, *args, **kwargs)
        return layer_fn
|
|
|
|
|
|
|
|
|
|
# Built once at import time. Maps a layer type name (a parameter message
# class name minus its 'Parameter' suffix) to the matching LayerParameter
# field prefix (the field name minus '_param'). Function._to_proto uses it
# to route plain keyword params into the layer's '<prefix>_param' sub-message.
_param_names = param_name_dict()
|
|
|