commit 259386cf13 (parent 7e587c2703)
Author: mertalev, 2024-06-07 00:01:31 -04:00
GPG key ID: 9181CD92C0A1C5E3 (no known key found for this signature in database)
3 changed files with 41 additions and 28 deletions

View file

@@ -7,6 +7,7 @@ from insightface.model_zoo import RetinaFace
 from numpy.typing import NDArray
 
 from app.models.base import InferenceModel
+from app.models.session import ort_has_batch_dim, ort_squeeze_outputs
 from app.models.transforms import decode_cv2
 from app.schemas import FaceDetectionOutput, ModelSession, ModelTask, ModelType
@@ -27,7 +28,8 @@ class FaceDetector(InferenceModel):
 
     def _load(self) -> ModelSession:
         session = self._make_session(self.model_path)
-        self._squeeze_outputs(session)
+        if isinstance(session, ort.InferenceSession) and ort_has_batch_dim(session):
+            ort_squeeze_outputs(session)
         self.model = RetinaFace(session=session)
         self.model.prepare(ctx_id=0, det_thresh=self.min_score, input_size=(640, 640))
 
@@ -46,15 +48,5 @@ class FaceDetector(InferenceModel):
     def _detect(self, inputs: NDArray[np.uint8] | bytes) -> tuple[NDArray[np.float32], NDArray[np.float32]]:
         return self.model.detect(inputs)  # type: ignore
 
-    def _squeeze_outputs(self, session: ort.InferenceSession) -> None:
-        original_run = session.run
-
-        def run(output_names: list[str], input_feed: dict[str, NDArray[np.float32]]) -> list[NDArray[np.float32]]:
-            out: list[NDArray[np.float32]] = original_run(output_names, input_feed)
-            out = [o.squeeze(axis=0) for o in out]
-            return out
-
-        session.run = run
-
     def configure(self, **kwargs: Any) -> None:
         self.model.det_thresh = kwargs.pop("minScore", self.model.det_thresh)

View file

@@ -2,16 +2,15 @@ from pathlib import Path
 from typing import Any
 
 import numpy as np
-import onnx
 import onnxruntime as ort
 from insightface.model_zoo import ArcFaceONNX
 from insightface.utils.face_align import norm_crop
 from numpy.typing import NDArray
-from onnx.tools.update_model_dims import update_inputs_outputs_dims
 from PIL import Image
 
 from app.config import clean_name, log
 from app.models.base import InferenceModel
+from app.models.session import ort_add_batch_dim, ort_has_batch_dim
 from app.models.transforms import decode_cv2
 from app.schemas import FaceDetectionOutput, FacialRecognitionOutput, ModelSession, ModelTask, ModelType
@@ -32,8 +31,9 @@ class FaceRecognizer(InferenceModel):
 
     def _load(self) -> ModelSession:
         session = self._make_session(self.model_path)
-        if not self._has_batch_dim(session):
-            self._add_batch_dim(self.model_path)
+        if isinstance(session, ort.InferenceSession) and not ort_has_batch_dim(session):
+            log.debug(f"Adding batch dimension to model {self.model_path}")
+            ort_add_batch_dim(self.model_path, self.model_path)
             session = self._make_session(self.model_path)
         self.model = ArcFaceONNX(
             self.model_path.with_suffix(".onnx").as_posix(),
@@ -62,16 +62,3 @@
     def _crop(self, image: NDArray[np.uint8], faces: FaceDetectionOutput) -> list[NDArray[np.uint8]]:
         return [norm_crop(image, landmark) for landmark in faces["landmarks"]]
 
-    def _has_batch_dim(self, session: ort.InferenceSession) -> bool:
-        return not isinstance(session, ort.InferenceSession) or session.get_inputs()[0].shape[0] == "batch"
-
-    def _add_batch_dim(self, model_path: Path) -> None:
-        log.debug(f"Adding batch dimension to model {model_path}")
-        proto = onnx.load(model_path)
-        static_input_dims = [shape.dim_value for shape in proto.graph.input[0].type.tensor_type.shape.dim[1:]]
-        static_output_dims = [shape.dim_value for shape in proto.graph.output[0].type.tensor_type.shape.dim[1:]]
-        input_dims = {proto.graph.input[0].name: ["batch"] + static_input_dims}
-        output_dims = {proto.graph.output[0].name: ["batch"] + static_output_dims}
-        updated_proto = update_inputs_outputs_dims(proto, input_dims, output_dims)
-        onnx.save(updated_proto, model_path)

View file

@@ -0,0 +1,34 @@
+from pathlib import Path
+
+import numpy as np
+import onnx
+import onnxruntime as ort
+from numpy.typing import NDArray
+from onnx.shape_inference import infer_shapes
+from onnx.tools.update_model_dims import update_inputs_outputs_dims
+
+
+def ort_has_batch_dim(session: ort.InferenceSession) -> bool:
+    return session.get_inputs()[0].shape[0] == "batch"
+
+
+def ort_squeeze_outputs(session: ort.InferenceSession) -> None:
+    original_run = session.run
+
+    def run(output_names: list[str], input_feed: dict[str, NDArray[np.float32]]) -> list[NDArray[np.float32]]:
+        out: list[NDArray[np.float32]] = original_run(output_names, input_feed)
+        out = [o.squeeze(axis=0) for o in out]
+        return out
+
+    session.run = run
+
+
+def ort_add_batch_dim(input_path: Path, output_path: Path) -> None:
+    proto = onnx.load(input_path)
+    static_input_dims = [shape.dim_value for shape in proto.graph.input[0].type.tensor_type.shape.dim[1:]]
+    static_output_dims = [shape.dim_value for shape in proto.graph.output[0].type.tensor_type.shape.dim[1:]]
+    input_dims = {proto.graph.input[0].name: ["batch"] + static_input_dims}
+    output_dims = {proto.graph.output[0].name: ["batch"] + static_output_dims}
+    updated_proto = update_inputs_outputs_dims(proto, input_dims, output_dims)
+    inferred = infer_shapes(updated_proto)
+    onnx.save(inferred, output_path)
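
The helpers in the new app.models.session module are self-contained, so they can be sanity-checked against a throwaway model. Below is a minimal sketch, not part of this commit: it builds a single-node Identity graph whose leading dimension is the literal 1, rewrites that dimension to the symbolic "batch" axis with ort_add_batch_dim, and wraps session.run with ort_squeeze_outputs so outputs come back with that axis removed. The Identity graph, tensor names, temporary file, opset, and CPU provider are illustrative assumptions; only the ort_* helpers come from the diff above.

from pathlib import Path
from tempfile import TemporaryDirectory

import numpy as np
import onnx
import onnxruntime as ort
from onnx import TensorProto, helper

from app.models.session import ort_add_batch_dim, ort_has_batch_dim, ort_squeeze_outputs

with TemporaryDirectory() as tmp:
    model_path = Path(tmp) / "identity.onnx"

    # Throwaway single-node model whose leading dimension is the static value 1,
    # standing in for an ONNX export without a dynamic batch axis.
    inp = helper.make_tensor_value_info("input", TensorProto.FLOAT, [1, 4])
    out = helper.make_tensor_value_info("output", TensorProto.FLOAT, [1, 4])
    node = helper.make_node("Identity", ["input"], ["output"])
    graph = helper.make_graph([node], "demo", [inp], [out])
    onnx.save(helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)]), model_path.as_posix())

    session = ort.InferenceSession(model_path.as_posix(), providers=["CPUExecutionProvider"])
    print(ort_has_batch_dim(session))  # False: shape[0] is 1, not "batch"

    # Rewrite the leading dim of the first input/output to the symbolic "batch" and reload.
    ort_add_batch_dim(model_path, model_path)
    session = ort.InferenceSession(model_path.as_posix(), providers=["CPUExecutionProvider"])
    print(ort_has_batch_dim(session))  # True: input shape is now ["batch", 4]

    # Wrap session.run so callers that expect unbatched outputs (as the FaceDetector
    # path does before handing the session to RetinaFace) get the leading axis squeezed.
    ort_squeeze_outputs(session)
    (result,) = session.run(None, {"input": np.zeros((1, 4), dtype=np.float32)})
    print(result.shape)  # (4,) rather than (1, 4)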