mirror of
https://github.com/deepinsight/insightface.git
synced 2025-12-30 08:02:27 +00:00
update arcface_onnx
This commit is contained in:
@@ -56,20 +56,12 @@ class ArcFaceONNX:
|
||||
self.input_name = input_name
|
||||
self.output_names = output_names
|
||||
assert len(self.output_names)==1
|
||||
self.output_shape = outputs[0].shape
|
||||
|
||||
def prepare(self, ctx_id, **kwargs):
|
||||
if ctx_id<0:
|
||||
self.session.set_providers(['CPUExecutionProvider'])
|
||||
|
||||
def get_feat(self, img):
|
||||
assert img.shape[2] == 3
|
||||
input_size = tuple(img.shape[0:2][::-1])
|
||||
assert input_size==self.input_size
|
||||
blob = cv2.dnn.blobFromImage(img, 1.0/self.input_std, input_size, (self.input_mean, self.input_mean, self.input_mean), swapRB=True)
|
||||
net_outs = self.session.run(self.output_names, {self.input_name : blob})
|
||||
feat = net_outs[0]
|
||||
return feat
|
||||
|
||||
def get(self, img, face):
|
||||
aimg = face_align.norm_crop(img, landmark=face.kps)
|
||||
face.embedding = self.get_feat(aimg).flatten()
|
||||
@@ -82,7 +74,7 @@ class ArcFaceONNX:
|
||||
sim = np.dot(feat1, feat2) / (norm(feat1) * norm(feat2))
|
||||
return sim
|
||||
|
||||
def forward(self, imgs):
|
||||
def get_feat(self, imgs):
|
||||
if not isinstance(imgs, list):
|
||||
imgs = [imgs]
|
||||
input_size = self.input_size
|
||||
@@ -92,4 +84,9 @@ class ArcFaceONNX:
|
||||
net_out = self.session.run(self.output_names, {self.input_name: blob})[0]
|
||||
return net_out
|
||||
|
||||
def forward(self, batch_data):
|
||||
blob = (batch_data - self.input_mean) / self.input_std
|
||||
net_out = self.session.run(self.output_names, {self.input_name: blob})[0]
|
||||
return net_out
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user