diff --git a/machine-learning/app/models/facial_recognition/detection.py b/machine-learning/app/models/facial_recognition/detection.py index 8d6a752aaa..18dd12214c 100644 --- a/machine-learning/app/models/facial_recognition/detection.py +++ b/machine-learning/app/models/facial_recognition/detection.py @@ -7,6 +7,7 @@ from insightface.model_zoo import RetinaFace from numpy.typing import NDArray from app.models.base import InferenceModel +from app.models.session import ort_has_batch_dim, ort_squeeze_outputs from app.models.transforms import decode_cv2 from app.schemas import FaceDetectionOutput, ModelSession, ModelTask, ModelType @@ -27,7 +28,8 @@ class FaceDetector(InferenceModel): def _load(self) -> ModelSession: session = self._make_session(self.model_path) - self._squeeze_outputs(session) + if isinstance(session, ort.InferenceSession) and ort_has_batch_dim(session): + ort_squeeze_outputs(session) self.model = RetinaFace(session=session) self.model.prepare(ctx_id=0, det_thresh=self.min_score, input_size=(640, 640)) @@ -46,15 +48,5 @@ class FaceDetector(InferenceModel): def _detect(self, inputs: NDArray[np.uint8] | bytes) -> tuple[NDArray[np.float32], NDArray[np.float32]]: return self.model.detect(inputs) # type: ignore - def _squeeze_outputs(self, session: ort.InferenceSession) -> None: - original_run = session.run - - def run(output_names: list[str], input_feed: dict[str, NDArray[np.float32]]) -> list[NDArray[np.float32]]: - out: list[NDArray[np.float32]] = original_run(output_names, input_feed) - out = [o.squeeze(axis=0) for o in out] - return out - - session.run = run - def configure(self, **kwargs: Any) -> None: self.model.det_thresh = kwargs.pop("minScore", self.model.det_thresh) diff --git a/machine-learning/app/models/facial_recognition/recognition.py b/machine-learning/app/models/facial_recognition/recognition.py index cb8093dd95..bc362fb156 100644 --- a/machine-learning/app/models/facial_recognition/recognition.py +++ 
b/machine-learning/app/models/facial_recognition/recognition.py @@ -2,16 +2,15 @@ from pathlib import Path from typing import Any import numpy as np -import onnx import onnxruntime as ort from insightface.model_zoo import ArcFaceONNX from insightface.utils.face_align import norm_crop from numpy.typing import NDArray -from onnx.tools.update_model_dims import update_inputs_outputs_dims from PIL import Image from app.config import clean_name, log from app.models.base import InferenceModel +from app.models.session import ort_add_batch_dim, ort_has_batch_dim from app.models.transforms import decode_cv2 from app.schemas import FaceDetectionOutput, FacialRecognitionOutput, ModelSession, ModelTask, ModelType @@ -32,8 +31,9 @@ class FaceRecognizer(InferenceModel): def _load(self) -> ModelSession: session = self._make_session(self.model_path) - if not self._has_batch_dim(session): - self._add_batch_dim(self.model_path) + if isinstance(session, ort.InferenceSession) and not ort_has_batch_dim(session): + log.debug(f"Adding batch dimension to model {self.model_path}") + ort_add_batch_dim(self.model_path, self.model_path) session = self._make_session(self.model_path) self.model = ArcFaceONNX( self.model_path.with_suffix(".onnx").as_posix(), @@ -62,16 +62,3 @@ class FaceRecognizer(InferenceModel): def _crop(self, image: NDArray[np.uint8], faces: FaceDetectionOutput) -> list[NDArray[np.uint8]]: return [norm_crop(image, landmark) for landmark in faces["landmarks"]] - - def _has_batch_dim(self, session: ort.InferenceSession) -> bool: - return not isinstance(session, ort.InferenceSession) or session.get_inputs()[0].shape[0] == "batch" - - def _add_batch_dim(self, model_path: Path) -> None: - log.debug(f"Adding batch dimension to model {model_path}") - proto = onnx.load(model_path) - static_input_dims = [shape.dim_value for shape in proto.graph.input[0].type.tensor_type.shape.dim[1:]] - static_output_dims = [shape.dim_value for shape in 
from pathlib import Path
from typing import Any

import numpy as np
import onnx
import onnxruntime as ort
from numpy.typing import NDArray
from onnx.shape_inference import infer_shapes
from onnx.tools.update_model_dims import update_inputs_outputs_dims


def ort_has_batch_dim(session: ort.InferenceSession) -> bool:
    """Return True when the first model input's leading dimension is named "batch".

    "batch" is the symbolic name that `ort_add_batch_dim` writes, so this
    reports whether a model already carries that dimension.

    Args:
        session: The ONNX Runtime session to inspect.

    Returns:
        True if the first input's first dimension equals "batch". False
        otherwise, including when the model has no inputs or the first input
        has an empty shape.
    """
    inputs = session.get_inputs()
    if not inputs:
        return False
    shape = inputs[0].shape
    # An empty shape has no leading axis at all, hence no batch dimension.
    return bool(shape) and shape[0] == "batch"


def ort_squeeze_outputs(session: ort.InferenceSession) -> None:
    """Patch `session.run` in place so each output loses its leading axis.

    The session's `run` attribute is replaced by a wrapper that calls the
    original `run` and applies `squeeze(axis=0)` to every returned array.
    NumPy raises ValueError from `squeeze` when that axis is not of length 1.

    Args:
        session: The session whose `run` attribute is replaced.
    """
    original_run = session.run

    def run(
        output_names: list[str],
        input_feed: dict[str, NDArray[np.float32]],
        run_options: Any = None,
    ) -> list[NDArray[np.float32]]:
        # Forward run_options so the wrapper stays call-compatible with the
        # method it replaces.
        outputs: list[NDArray[np.float32]] = original_run(output_names, input_feed, run_options)
        return [output.squeeze(axis=0) for output in outputs]

    session.run = run


def _trailing_dims(value_info: onnx.ValueInfoProto) -> list[int | str]:
    """Return every dimension after the first, keeping symbolic ones symbolic."""
    trailing: list[int | str] = []
    for dim in value_info.type.tensor_type.shape.dim[1:]:
        if dim.dim_param:
            # Reading dim_value on a symbolic dimension yields 0, which would
            # be written back as a static size of 0. Keep the name instead.
            trailing.append(dim.dim_param)
        elif dim.HasField("dim_value"):
            trailing.append(dim.dim_value)
        else:
            # Neither field is set: a negative value asks
            # update_inputs_outputs_dims to generate a dim_param.
            trailing.append(-1)
    return trailing


def ort_add_batch_dim(input_path: Path, output_path: Path) -> None:
    """Rewrite an ONNX model so its first input and output lead with "batch".

    The model is loaded from `input_path`, the leading dimension of the first
    graph input and first graph output is set to the symbolic name "batch",
    shapes are re-inferred, and the result is saved to `output_path`. The two
    paths may be the same file.

    NOTE(review): only the first graph input and first graph output are
    passed to `update_inputs_outputs_dims`; models with more than one input
    or output look unsupported here -- confirm before using it on such models.

    Args:
        input_path: Path of the ONNX model to read.
        output_path: Path the updated model is written to.
    """
    proto = onnx.load(input_path)
    graph_input = proto.graph.input[0]
    graph_output = proto.graph.output[0]
    input_dims = {graph_input.name: ["batch", *_trailing_dims(graph_input)]}
    output_dims = {graph_output.name: ["batch", *_trailing_dims(graph_output)]}
    updated_proto = update_inputs_outputs_dims(proto, input_dims, output_dims)
    onnx.save(infer_shapes(updated_proto), output_path)