mirror of
https://github.com/immich-app/immich.git
synced 2024-12-29 15:11:58 +00:00
squeeze output dims
This commit is contained in:
parent
4d862525bc
commit
7e587c2703
1 changed files with 12 additions and 0 deletions
|
@ -2,6 +2,7 @@ from pathlib import Path
|
|||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
import onnxruntime as ort
|
||||
from insightface.model_zoo import RetinaFace
|
||||
from numpy.typing import NDArray
|
||||
|
||||
|
@ -26,6 +27,7 @@ class FaceDetector(InferenceModel):
|
|||
|
||||
def _load(self) -> ModelSession:
|
||||
session = self._make_session(self.model_path)
|
||||
self._squeeze_outputs(session)
|
||||
self.model = RetinaFace(session=session)
|
||||
self.model.prepare(ctx_id=0, det_thresh=self.min_score, input_size=(640, 640))
|
||||
|
||||
|
@ -44,5 +46,15 @@ class FaceDetector(InferenceModel):
|
|||
def _detect(self, inputs: NDArray[np.uint8] | bytes) -> tuple[NDArray[np.float32], NDArray[np.float32]]:
|
||||
return self.model.detect(inputs) # type: ignore
|
||||
|
||||
def _squeeze_outputs(self, session: ort.InferenceSession) -> None:
|
||||
original_run = session.run
|
||||
|
||||
def run(output_names: list[str], input_feed: dict[str, NDArray[np.float32]]) -> list[NDArray[np.float32]]:
|
||||
out: list[NDArray[np.float32]] = original_run(output_names, input_feed)
|
||||
out = [o.squeeze(axis=0) for o in out]
|
||||
return out
|
||||
|
||||
session.run = run
|
||||
|
||||
def configure(self, **kwargs: Any) -> None:
|
||||
self.model.det_thresh = kwargs.pop("minScore", self.model.det_thresh)
|
||||
|
|
Loading…
Reference in a new issue