diff --git a/machine-learning/app/models/facial_recognition/detection.py b/machine-learning/app/models/facial_recognition/detection.py index 2efacd447e..8d6a752aaa 100644 --- a/machine-learning/app/models/facial_recognition/detection.py +++ b/machine-learning/app/models/facial_recognition/detection.py @@ -2,6 +2,7 @@ from pathlib import Path from typing import Any import numpy as np +import onnxruntime as ort from insightface.model_zoo import RetinaFace from numpy.typing import NDArray @@ -26,6 +27,7 @@ class FaceDetector(InferenceModel): def _load(self) -> ModelSession: session = self._make_session(self.model_path) + self._squeeze_outputs(session) self.model = RetinaFace(session=session) self.model.prepare(ctx_id=0, det_thresh=self.min_score, input_size=(640, 640)) @@ -44,5 +46,15 @@ class FaceDetector(InferenceModel): def _detect(self, inputs: NDArray[np.uint8] | bytes) -> tuple[NDArray[np.float32], NDArray[np.float32]]: return self.model.detect(inputs) # type: ignore + def _squeeze_outputs(self, session: ort.InferenceSession) -> None: + original_run = session.run + + def run(output_names: list[str], input_feed: dict[str, NDArray[np.float32]]) -> list[NDArray[np.float32]]: + out: list[NDArray[np.float32]] = original_run(output_names, input_feed) + out = [o.squeeze(axis=0) for o in out] + return out + + session.run = run + def configure(self, **kwargs: Any) -> None: self.model.det_thresh = kwargs.pop("minScore", self.model.det_thresh)