2021-03-19 12:36:15 +08:00
|
|
|
from .iresnet import iresnet18, iresnet34, iresnet50, iresnet100, iresnet200
|
2021-07-08 18:07:32 +08:00
|
|
|
from .mobilefacenet import get_mbf
|
2021-05-09 13:28:42 +08:00
|
|
|
|
2021-05-12 15:02:40 +08:00
|
|
|
|
2021-05-09 13:28:42 +08:00
|
|
|
def get_model(name, **kwargs):
    """Build a backbone network from its short name.

    Args:
        name: Backbone identifier. Supported values are ``"r18"``, ``"r34"``,
            ``"r50"``, ``"r100"``, ``"r200"``, ``"r2060"`` and ``"mbf"``.
        **kwargs: For the ``"r*"`` backbones, forwarded unchanged to the
            matching ``iresnet*`` constructor. For ``"mbf"``, only ``fp16``
            (default ``False``) and ``num_features`` (default ``512``) are
            read; any other keys are ignored.

    Returns:
        Whatever the selected backbone constructor returns.

    Raises:
        ValueError: If ``name`` is not one of the supported identifiers.
    """
    # resnet
    # NOTE(review): the leading positional ``False`` passed to every iresnet
    # constructor is presumably a ``pretrained`` flag -- confirm in .iresnet.
    if name == "r18":
        return iresnet18(False, **kwargs)
    elif name == "r34":
        return iresnet34(False, **kwargs)
    elif name == "r50":
        return iresnet50(False, **kwargs)
    elif name == "r100":
        return iresnet100(False, **kwargs)
    elif name == "r200":
        return iresnet200(False, **kwargs)
    elif name == "r2060":
        # Imported inside the branch so the iresnet2060 module is only loaded
        # when this backbone is actually requested.
        from .iresnet2060 import iresnet2060
        return iresnet2060(False, **kwargs)
    elif name == "mbf":
        # get_mbf is called with explicit keywords only, so unrelated kwargs
        # are deliberately not forwarded.
        fp16 = kwargs.get("fp16", False)
        num_features = kwargs.get("num_features", 512)
        return get_mbf(fp16=fp16, num_features=num_features)
    else:
        # Name the bad value so misconfigured callers get an actionable error
        # instead of an empty ValueError.
        raise ValueError(
            f"Unknown model name {name!r}; expected one of "
            "'r18', 'r34', 'r50', 'r100', 'r200', 'r2060', 'mbf'"
        )
|