From 7e587c2703fe64d2a616197eb651448cf60a4d86 Mon Sep 17 00:00:00 2001
From: mertalev <101130780+mertalev@users.noreply.github.com>
Date: Thu, 6 Jun 2024 23:07:30 -0400
Subject: [PATCH] squeeze output dims

---
Review notes (free text after the "---" separator; not part of the commit
message and not part of the diff):

* What the patch does: adds FaceDetector._squeeze_outputs(session), which
  saves the session's current `run` attribute and assigns a replacement
  function to `session.run`. The replacement calls the saved `run` with
  (output_names, input_feed) and returns a new list in which every array
  has had `.squeeze(axis=0)` applied. `_load` now calls this on the session
  it created, before handing that session to `RetinaFace(session=...)`.
  The method mutates the session in place and returns None.

* The replacement `run` accepts exactly two parameters (output_names,
  input_feed). Any caller that passes an additional argument to
  `session.run` will get a TypeError after this patch.
  NOTE(review): confirm RetinaFace only ever calls `run` with these two
  arguments.

* `squeeze(axis=0)` only succeeds when axis 0 has length 1; NumPy raises
  ValueError otherwise.
  NOTE(review): presumably every output of the detection model carries a
  leading dimension of size 1 - verify against the exported models.

* `session` is annotated as `ort.InferenceSession`, but the return type of
  `self._make_session` is not visible in this patch.
  NOTE(review): confirm the object returned there permits assigning to its
  `run` attribute.

 .../app/models/facial_recognition/detection.py | 12 ++++++++++++
 1 file changed, 12 insertions(+)

diff --git a/machine-learning/app/models/facial_recognition/detection.py b/machine-learning/app/models/facial_recognition/detection.py
index 2efacd447e..8d6a752aaa 100644
--- a/machine-learning/app/models/facial_recognition/detection.py
+++ b/machine-learning/app/models/facial_recognition/detection.py
@@ -2,6 +2,7 @@ from pathlib import Path
 from typing import Any
 
 import numpy as np
+import onnxruntime as ort
 from insightface.model_zoo import RetinaFace
 from numpy.typing import NDArray
 
@@ -26,6 +27,7 @@ class FaceDetector(InferenceModel):
 
     def _load(self) -> ModelSession:
         session = self._make_session(self.model_path)
+        self._squeeze_outputs(session)
         self.model = RetinaFace(session=session)
         self.model.prepare(ctx_id=0, det_thresh=self.min_score, input_size=(640, 640))
 
@@ -44,5 +46,15 @@ class FaceDetector(InferenceModel):
     def _detect(self, inputs: NDArray[np.uint8] | bytes) -> tuple[NDArray[np.float32], NDArray[np.float32]]:
         return self.model.detect(inputs)  # type: ignore
 
+    def _squeeze_outputs(self, session: ort.InferenceSession) -> None:
+        original_run = session.run
+
+        def run(output_names: list[str], input_feed: dict[str, NDArray[np.float32]]) -> list[NDArray[np.float32]]:
+            out: list[NDArray[np.float32]] = original_run(output_names, input_feed)
+            out = [o.squeeze(axis=0) for o in out]
+            return out
+
+        session.run = run
+
     def configure(self, **kwargs: Any) -> None:
         self.model.det_thresh = kwargs.pop("minScore", self.model.det_thresh)