1
0
Fork 0
mirror of https://github.com/immich-app/immich.git synced 2025-01-04 10:56:47 +01:00

squeeze output dims

This commit is contained in:
mertalev 2024-06-06 23:07:30 -04:00
parent 4d862525bc
commit 7e587c2703
No known key found for this signature in database
GPG key ID: 9181CD92C0A1C5E3

View file

@ -2,6 +2,7 @@ from pathlib import Path
from typing import Any from typing import Any
import numpy as np import numpy as np
import onnxruntime as ort
from insightface.model_zoo import RetinaFace from insightface.model_zoo import RetinaFace
from numpy.typing import NDArray from numpy.typing import NDArray
@ -26,6 +27,7 @@ class FaceDetector(InferenceModel):
def _load(self) -> ModelSession: def _load(self) -> ModelSession:
session = self._make_session(self.model_path) session = self._make_session(self.model_path)
self._squeeze_outputs(session)
self.model = RetinaFace(session=session) self.model = RetinaFace(session=session)
self.model.prepare(ctx_id=0, det_thresh=self.min_score, input_size=(640, 640)) self.model.prepare(ctx_id=0, det_thresh=self.min_score, input_size=(640, 640))
@ -44,5 +46,15 @@ class FaceDetector(InferenceModel):
def _detect(self, inputs: NDArray[np.uint8] | bytes) -> tuple[NDArray[np.float32], NDArray[np.float32]]: def _detect(self, inputs: NDArray[np.uint8] | bytes) -> tuple[NDArray[np.float32], NDArray[np.float32]]:
return self.model.detect(inputs) # type: ignore return self.model.detect(inputs) # type: ignore
def _squeeze_outputs(self, session: ort.InferenceSession) -> None:
original_run = session.run
def run(output_names: list[str], input_feed: dict[str, NDArray[np.float32]]) -> list[NDArray[np.float32]]:
out: list[NDArray[np.float32]] = original_run(output_names, input_feed)
out = [o.squeeze(axis=0) for o in out]
return out
session.run = run
def configure(self, **kwargs: Any) -> None: def configure(self, **kwargs: Any) -> None:
self.model.det_thresh = kwargs.pop("minScore", self.model.det_thresh) self.model.det_thresh = kwargs.pop("minScore", self.model.det_thresh)