1
0
Fork 0
mirror of https://github.com/immich-app/immich.git synced 2025-01-16 16:56:46 +01:00
This commit is contained in:
mertalev 2024-10-13 17:16:30 -04:00
parent d554630534
commit 662b0f844a
No known key found for this signature in database
GPG key ID: 46904880C3E8B346
2 changed files with 9 additions and 7 deletions

View file

@ -19,9 +19,8 @@ class FaceRecognizer(InferenceModel):
depends = [(ModelType.DETECTION, ModelTask.FACIAL_RECOGNITION)] depends = [(ModelType.DETECTION, ModelTask.FACIAL_RECOGNITION)]
identity = (ModelType.RECOGNITION, ModelTask.FACIAL_RECOGNITION) identity = (ModelType.RECOGNITION, ModelTask.FACIAL_RECOGNITION)
def __init__(self, model_name: str, min_score: float = 0.7, **model_kwargs: Any) -> None: def __init__(self, model_name: str, **model_kwargs: Any) -> None:
super().__init__(model_name, **model_kwargs) super().__init__(model_name, **model_kwargs)
self.min_score = model_kwargs.pop("minScore", min_score)
self.batch = self.model_format == ModelFormat.ONNX self.batch = self.model_format == ModelFormat.ONNX
def _load(self) -> ModelSession: def _load(self) -> ModelSession:

View file

@ -323,7 +323,7 @@ class TestAnnSession:
session.run(None, input_feed) session.run(None, input_feed)
ann_session.return_value.execute.assert_called_once_with(123, [input1, input2]) ann_session.return_value.execute.assert_called_once_with(123, [input1, input2])
np_spy.call_count == 2 assert np_spy.call_count == 2
np_spy.assert_has_calls([mock.call(input1), mock.call(input2)]) np_spy.assert_has_calls([mock.call(input1), mock.call(input2)])
@ -456,11 +456,14 @@ class TestCLIP:
class TestFaceRecognition: class TestFaceRecognition:
def test_set_min_score(self, mocker: MockerFixture) -> None: def test_set_min_score(self, snapshot_download: mock.Mock, ort_session: mock.Mock, path: mock.Mock) -> None:
mocker.patch.object(FaceRecognizer, "load") path.return_value.__truediv__.return_value.__truediv__.return_value.suffix = ".onnx"
face_recognizer = FaceRecognizer("buffalo_s", cache_dir="test_cache", min_score=0.5)
assert face_recognizer.min_score == 0.5 face_detector = FaceDetector("buffalo_s", min_score=0.5, cache_dir="test_cache")
face_detector.load()
assert face_detector.min_score == 0.5
assert face_detector.model.det_thresh == 0.5
def test_detection(self, cv_image: cv2.Mat, mocker: MockerFixture) -> None: def test_detection(self, cv_image: cv2.Mat, mocker: MockerFixture) -> None:
mocker.patch.object(FaceDetector, "load") mocker.patch.object(FaceDetector, "load")