mirror of
https://github.com/immich-app/immich.git
synced 2025-01-04 10:56:47 +01:00
squeeze output dims
This commit is contained in:
parent
4d862525bc
commit
7e587c2703
1 changed files with 12 additions and 0 deletions
|
@ -2,6 +2,7 @@ from pathlib import Path
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
import onnxruntime as ort
|
||||||
from insightface.model_zoo import RetinaFace
|
from insightface.model_zoo import RetinaFace
|
||||||
from numpy.typing import NDArray
|
from numpy.typing import NDArray
|
||||||
|
|
||||||
|
@ -26,6 +27,7 @@ class FaceDetector(InferenceModel):
|
||||||
|
|
||||||
def _load(self) -> ModelSession:
|
def _load(self) -> ModelSession:
|
||||||
session = self._make_session(self.model_path)
|
session = self._make_session(self.model_path)
|
||||||
|
self._squeeze_outputs(session)
|
||||||
self.model = RetinaFace(session=session)
|
self.model = RetinaFace(session=session)
|
||||||
self.model.prepare(ctx_id=0, det_thresh=self.min_score, input_size=(640, 640))
|
self.model.prepare(ctx_id=0, det_thresh=self.min_score, input_size=(640, 640))
|
||||||
|
|
||||||
|
@ -44,5 +46,15 @@ class FaceDetector(InferenceModel):
|
||||||
def _detect(self, inputs: NDArray[np.uint8] | bytes) -> tuple[NDArray[np.float32], NDArray[np.float32]]:
|
def _detect(self, inputs: NDArray[np.uint8] | bytes) -> tuple[NDArray[np.float32], NDArray[np.float32]]:
|
||||||
return self.model.detect(inputs) # type: ignore
|
return self.model.detect(inputs) # type: ignore
|
||||||
|
|
||||||
|
def _squeeze_outputs(self, session: ort.InferenceSession) -> None:
|
||||||
|
original_run = session.run
|
||||||
|
|
||||||
|
def run(output_names: list[str], input_feed: dict[str, NDArray[np.float32]]) -> list[NDArray[np.float32]]:
|
||||||
|
out: list[NDArray[np.float32]] = original_run(output_names, input_feed)
|
||||||
|
out = [o.squeeze(axis=0) for o in out]
|
||||||
|
return out
|
||||||
|
|
||||||
|
session.run = run
|
||||||
|
|
||||||
def configure(self, **kwargs: Any) -> None:
|
def configure(self, **kwargs: Any) -> None:
|
||||||
self.model.det_thresh = kwargs.pop("minScore", self.model.det_thresh)
|
self.model.det_thresh = kwargs.pop("minScore", self.model.det_thresh)
|
||||||
|
|
Loading…
Reference in a new issue