update arcface_onnx

This commit is contained in:
Jia Guo
2021-06-21 19:43:26 +08:00
parent 4ea2cceaf8
commit 72847c059d

View File

@@ -56,20 +56,12 @@ class ArcFaceONNX:
self.input_name = input_name
self.output_names = output_names
assert len(self.output_names)==1
self.output_shape = outputs[0].shape
def prepare(self, ctx_id, **kwargs):
if ctx_id<0:
self.session.set_providers(['CPUExecutionProvider'])
def get_feat(self, img):
assert img.shape[2] == 3
input_size = tuple(img.shape[0:2][::-1])
assert input_size==self.input_size
blob = cv2.dnn.blobFromImage(img, 1.0/self.input_std, input_size, (self.input_mean, self.input_mean, self.input_mean), swapRB=True)
net_outs = self.session.run(self.output_names, {self.input_name : blob})
feat = net_outs[0]
return feat
def get(self, img, face):
aimg = face_align.norm_crop(img, landmark=face.kps)
face.embedding = self.get_feat(aimg).flatten()
@@ -82,7 +74,7 @@ class ArcFaceONNX:
sim = np.dot(feat1, feat2) / (norm(feat1) * norm(feat2))
return sim
def forward(self, imgs):
def get_feat(self, imgs):
if not isinstance(imgs, list):
imgs = [imgs]
input_size = self.input_size
@@ -92,4 +84,9 @@ class ArcFaceONNX:
net_out = self.session.run(self.output_names, {self.input_name: blob})[0]
return net_out
def forward(self, batch_data):
blob = (batch_data - self.input_mean) / self.input_std
net_out = self.session.run(self.output_names, {self.input_name: blob})[0]
return net_out