mirror of
https://github.com/immich-app/immich.git
synced 2025-01-27 22:22:45 +01:00
feat(ml): add face models (#4952)
added models to config dropdown fixed downloading updated tests use hf for face models formatting
This commit is contained in:
parent
7fca0d8da5
commit
328a58ac0d
8 changed files with 101 additions and 94 deletions
machine-learning/app
web/src/lib/components/admin-page/settings/machine-learning-settings
|
@ -38,8 +38,16 @@ class LogSettings(BaseSettings):
|
|||
_clean_name = str.maketrans(":\\/", "___", ".")
|
||||
|
||||
|
||||
def clean_name(model_name: str) -> str:
|
||||
return model_name.split("/")[-1].translate(_clean_name)
|
||||
|
||||
|
||||
def get_cache_dir(model_name: str, model_type: ModelType) -> Path:
|
||||
return Path(settings.cache_folder) / model_type.value / model_name.translate(_clean_name)
|
||||
return Path(settings.cache_folder) / model_type.value / clean_name(model_name)
|
||||
|
||||
|
||||
def get_hf_model_name(model_name: str) -> str:
|
||||
return f"immich-app/{clean_name(model_name)}"
|
||||
|
||||
|
||||
LOG_LEVELS: dict[str, int] = {
|
||||
|
|
|
@ -3,7 +3,8 @@ from typing import Any
|
|||
from app.schemas import ModelType
|
||||
|
||||
from .base import InferenceModel
|
||||
from .clip import MCLIPEncoder, OpenCLIPEncoder, is_mclip, is_openclip
|
||||
from .clip import MCLIPEncoder, OpenCLIPEncoder
|
||||
from .constants import is_insightface, is_mclip, is_openclip
|
||||
from .facial_recognition import FaceRecognizer
|
||||
from .image_classification import ImageClassifier
|
||||
|
||||
|
@ -15,11 +16,12 @@ def from_model_type(model_type: ModelType, model_name: str, **model_kwargs: Any)
|
|||
return OpenCLIPEncoder(model_name, **model_kwargs)
|
||||
elif is_mclip(model_name):
|
||||
return MCLIPEncoder(model_name, **model_kwargs)
|
||||
else:
|
||||
raise ValueError(f"Unknown CLIP model {model_name}")
|
||||
case ModelType.FACIAL_RECOGNITION:
|
||||
return FaceRecognizer(model_name, **model_kwargs)
|
||||
if is_insightface(model_name):
|
||||
return FaceRecognizer(model_name, **model_kwargs)
|
||||
case ModelType.IMAGE_CLASSIFICATION:
|
||||
return ImageClassifier(model_name, **model_kwargs)
|
||||
case _:
|
||||
raise ValueError(f"Unknown model type {model_type}")
|
||||
|
||||
raise ValueError(f"Unknown ${model_type} model {model_name}")
|
||||
|
|
|
@ -7,8 +7,9 @@ from shutil import rmtree
|
|||
from typing import Any
|
||||
|
||||
import onnxruntime as ort
|
||||
from huggingface_hub import snapshot_download
|
||||
|
||||
from ..config import get_cache_dir, log, settings
|
||||
from ..config import get_cache_dir, get_hf_model_name, log, settings
|
||||
from ..schemas import ModelType
|
||||
|
||||
|
||||
|
@ -78,9 +79,13 @@ class InferenceModel(ABC):
|
|||
def configure(self, **model_kwargs: Any) -> None:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def _download(self) -> None:
|
||||
...
|
||||
snapshot_download(
|
||||
get_hf_model_name(self.model_name),
|
||||
cache_dir=self.cache_dir,
|
||||
local_dir=self.cache_dir,
|
||||
local_dir_use_symlinks=False,
|
||||
)
|
||||
|
||||
@abstractmethod
|
||||
def _load(self) -> None:
|
||||
|
|
|
@ -7,11 +7,10 @@ from typing import Any, Literal
|
|||
|
||||
import numpy as np
|
||||
import onnxruntime as ort
|
||||
from huggingface_hub import snapshot_download
|
||||
from PIL import Image
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
from app.config import log
|
||||
from app.config import clean_name, log
|
||||
from app.models.transforms import crop, get_pil_resampling, normalize, resize, to_numpy
|
||||
from app.schemas import ModelType, ndarray_f32, ndarray_i32, ndarray_i64
|
||||
|
||||
|
@ -117,15 +116,7 @@ class OpenCLIPEncoder(BaseCLIPEncoder):
|
|||
mode: Literal["text", "vision"] | None = None,
|
||||
**model_kwargs: Any,
|
||||
) -> None:
|
||||
super().__init__(_clean_model_name(model_name), cache_dir, mode, **model_kwargs)
|
||||
|
||||
def _download(self) -> None:
|
||||
snapshot_download(
|
||||
f"immich-app/{self.model_name}",
|
||||
cache_dir=self.cache_dir,
|
||||
local_dir=self.cache_dir,
|
||||
local_dir_use_symlinks=False,
|
||||
)
|
||||
super().__init__(clean_name(model_name), cache_dir, mode, **model_kwargs)
|
||||
|
||||
def _load(self) -> None:
|
||||
super()._load()
|
||||
|
@ -171,52 +162,3 @@ class MCLIPEncoder(OpenCLIPEncoder):
|
|||
def tokenize(self, text: str) -> dict[str, ndarray_i32]:
|
||||
tokens: dict[str, ndarray_i64] = self.tokenizer(text, return_tensors="np")
|
||||
return {k: v.astype(np.int32) for k, v in tokens.items()}
|
||||
|
||||
|
||||
_OPENCLIP_MODELS = {
|
||||
"RN50__openai",
|
||||
"RN50__yfcc15m",
|
||||
"RN50__cc12m",
|
||||
"RN101__openai",
|
||||
"RN101__yfcc15m",
|
||||
"RN50x4__openai",
|
||||
"RN50x16__openai",
|
||||
"RN50x64__openai",
|
||||
"ViT-B-32__openai",
|
||||
"ViT-B-32__laion2b_e16",
|
||||
"ViT-B-32__laion400m_e31",
|
||||
"ViT-B-32__laion400m_e32",
|
||||
"ViT-B-32__laion2b-s34b-b79k",
|
||||
"ViT-B-16__openai",
|
||||
"ViT-B-16__laion400m_e31",
|
||||
"ViT-B-16__laion400m_e32",
|
||||
"ViT-B-16-plus-240__laion400m_e31",
|
||||
"ViT-B-16-plus-240__laion400m_e32",
|
||||
"ViT-L-14__openai",
|
||||
"ViT-L-14__laion400m_e31",
|
||||
"ViT-L-14__laion400m_e32",
|
||||
"ViT-L-14__laion2b-s32b-b82k",
|
||||
"ViT-L-14-336__openai",
|
||||
"ViT-H-14__laion2b-s32b-b79k",
|
||||
"ViT-g-14__laion2b-s12b-b42k",
|
||||
}
|
||||
|
||||
|
||||
_MCLIP_MODELS = {
|
||||
"LABSE-Vit-L-14",
|
||||
"XLM-Roberta-Large-Vit-B-32",
|
||||
"XLM-Roberta-Large-Vit-B-16Plus",
|
||||
"XLM-Roberta-Large-Vit-L-14",
|
||||
}
|
||||
|
||||
|
||||
def _clean_model_name(model_name: str) -> str:
|
||||
return model_name.split("/")[-1].replace("::", "__")
|
||||
|
||||
|
||||
def is_openclip(model_name: str) -> bool:
|
||||
return _clean_model_name(model_name) in _OPENCLIP_MODELS
|
||||
|
||||
|
||||
def is_mclip(model_name: str) -> bool:
|
||||
return _clean_model_name(model_name) in _MCLIP_MODELS
|
||||
|
|
57
machine-learning/app/models/constants.py
Normal file
57
machine-learning/app/models/constants.py
Normal file
|
@ -0,0 +1,57 @@
|
|||
from app.config import clean_name
|
||||
|
||||
_OPENCLIP_MODELS = {
|
||||
"RN50__openai",
|
||||
"RN50__yfcc15m",
|
||||
"RN50__cc12m",
|
||||
"RN101__openai",
|
||||
"RN101__yfcc15m",
|
||||
"RN50x4__openai",
|
||||
"RN50x16__openai",
|
||||
"RN50x64__openai",
|
||||
"ViT-B-32__openai",
|
||||
"ViT-B-32__laion2b_e16",
|
||||
"ViT-B-32__laion400m_e31",
|
||||
"ViT-B-32__laion400m_e32",
|
||||
"ViT-B-32__laion2b-s34b-b79k",
|
||||
"ViT-B-16__openai",
|
||||
"ViT-B-16__laion400m_e31",
|
||||
"ViT-B-16__laion400m_e32",
|
||||
"ViT-B-16-plus-240__laion400m_e31",
|
||||
"ViT-B-16-plus-240__laion400m_e32",
|
||||
"ViT-L-14__openai",
|
||||
"ViT-L-14__laion400m_e31",
|
||||
"ViT-L-14__laion400m_e32",
|
||||
"ViT-L-14__laion2b-s32b-b82k",
|
||||
"ViT-L-14-336__openai",
|
||||
"ViT-H-14__laion2b-s32b-b79k",
|
||||
"ViT-g-14__laion2b-s12b-b42k",
|
||||
}
|
||||
|
||||
|
||||
_MCLIP_MODELS = {
|
||||
"LABSE-Vit-L-14",
|
||||
"XLM-Roberta-Large-Vit-B-32",
|
||||
"XLM-Roberta-Large-Vit-B-16Plus",
|
||||
"XLM-Roberta-Large-Vit-L-14",
|
||||
}
|
||||
|
||||
|
||||
_INSIGHTFACE_MODELS = {
|
||||
"antelopev2",
|
||||
"buffalo_l",
|
||||
"buffalo_m",
|
||||
"buffalo_s",
|
||||
}
|
||||
|
||||
|
||||
def is_openclip(model_name: str) -> bool:
|
||||
return clean_name(model_name) in _OPENCLIP_MODELS
|
||||
|
||||
|
||||
def is_mclip(model_name: str) -> bool:
|
||||
return clean_name(model_name) in _MCLIP_MODELS
|
||||
|
||||
|
||||
def is_insightface(model_name: str) -> bool:
|
||||
return clean_name(model_name) in _INSIGHTFACE_MODELS
|
|
@ -1,4 +1,3 @@
|
|||
import zipfile
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
|
@ -7,8 +6,8 @@ import numpy as np
|
|||
import onnxruntime as ort
|
||||
from insightface.model_zoo import ArcFaceONNX, RetinaFace
|
||||
from insightface.utils.face_align import norm_crop
|
||||
from insightface.utils.storage import BASE_REPO_URL, download_file
|
||||
|
||||
from app.config import clean_name
|
||||
from app.schemas import ModelType, ndarray_f32
|
||||
|
||||
from .base import InferenceModel
|
||||
|
@ -25,37 +24,21 @@ class FaceRecognizer(InferenceModel):
|
|||
**model_kwargs: Any,
|
||||
) -> None:
|
||||
self.min_score = model_kwargs.pop("minScore", min_score)
|
||||
super().__init__(model_name, cache_dir, **model_kwargs)
|
||||
|
||||
def _download(self) -> None:
|
||||
zip_file = self.cache_dir / f"{self.model_name}.zip"
|
||||
download_file(f"{BASE_REPO_URL}/{self.model_name}.zip", zip_file)
|
||||
with zipfile.ZipFile(zip_file, "r") as zip:
|
||||
members = zip.namelist()
|
||||
det_file = next(model for model in members if model.startswith("det_"))
|
||||
rec_file = next(model for model in members if model.startswith("w600k_"))
|
||||
zip.extractall(self.cache_dir, members=[det_file, rec_file])
|
||||
zip_file.unlink()
|
||||
super().__init__(clean_name(model_name), cache_dir, **model_kwargs)
|
||||
|
||||
def _load(self) -> None:
|
||||
try:
|
||||
det_file = next(self.cache_dir.glob("det_*.onnx"))
|
||||
rec_file = next(self.cache_dir.glob("w600k_*.onnx"))
|
||||
except StopIteration:
|
||||
raise FileNotFoundError("Facial recognition models not found in cache directory")
|
||||
|
||||
self.det_model = RetinaFace(
|
||||
session=ort.InferenceSession(
|
||||
det_file.as_posix(),
|
||||
self.det_file.as_posix(),
|
||||
sess_options=self.sess_options,
|
||||
providers=self.providers,
|
||||
provider_options=self.provider_options,
|
||||
),
|
||||
)
|
||||
self.rec_model = ArcFaceONNX(
|
||||
rec_file.as_posix(),
|
||||
self.rec_file.as_posix(),
|
||||
session=ort.InferenceSession(
|
||||
rec_file.as_posix(),
|
||||
self.rec_file.as_posix(),
|
||||
sess_options=self.sess_options,
|
||||
providers=self.providers,
|
||||
provider_options=self.provider_options,
|
||||
|
@ -103,7 +86,15 @@ class FaceRecognizer(InferenceModel):
|
|||
|
||||
@property
|
||||
def cached(self) -> bool:
|
||||
return self.cache_dir.is_dir() and any(self.cache_dir.glob("*.onnx"))
|
||||
return self.det_file.is_file() and self.rec_file.is_file()
|
||||
|
||||
@property
|
||||
def det_file(self) -> Path:
|
||||
return self.cache_dir / "detection" / "model.onnx"
|
||||
|
||||
@property
|
||||
def rec_file(self) -> Path:
|
||||
return self.cache_dir / "recognition" / "model.onnx"
|
||||
|
||||
def configure(self, **model_kwargs: Any) -> None:
|
||||
self.det_model.det_thresh = model_kwargs.pop("minScore", self.det_model.det_thresh)
|
||||
|
|
|
@ -106,13 +106,13 @@ class TestCLIP:
|
|||
class TestFaceRecognition:
|
||||
def test_set_min_score(self, mocker: MockerFixture) -> None:
|
||||
mocker.patch.object(FaceRecognizer, "load")
|
||||
face_recognizer = FaceRecognizer("test_model_name", cache_dir="test_cache", min_score=0.5)
|
||||
face_recognizer = FaceRecognizer("buffalo_s", cache_dir="test_cache", min_score=0.5)
|
||||
|
||||
assert face_recognizer.min_score == 0.5
|
||||
|
||||
def test_basic(self, cv_image: cv2.Mat, mocker: MockerFixture) -> None:
|
||||
mocker.patch.object(FaceRecognizer, "load")
|
||||
face_recognizer = FaceRecognizer("test_model_name", min_score=0.0, cache_dir="test_cache")
|
||||
face_recognizer = FaceRecognizer("buffalo_s", min_score=0.0, cache_dir="test_cache")
|
||||
|
||||
det_model = mock.Mock()
|
||||
num_faces = 2
|
||||
|
|
|
@ -160,11 +160,13 @@
|
|||
|
||||
<SettingSelect
|
||||
label="FACIAL RECOGNITION MODEL"
|
||||
desc="Smaller models are faster and use less memory, but perform worse. Note that you must re-run the Recognize Faces job for all images upon changing a model."
|
||||
desc="Models are listed in descending order of size. Larger models are slower and use more memory, but produce better results. Note that you must re-run the Recognize Faces job for all images upon changing a model."
|
||||
name="facial-recognition-model"
|
||||
bind:value={machineLearningConfig.facialRecognition.modelName}
|
||||
options={[
|
||||
{ value: 'antelopev2', text: 'antelopev2' },
|
||||
{ value: 'buffalo_l', text: 'buffalo_l' },
|
||||
{ value: 'buffalo_m', text: 'buffalo_m' },
|
||||
{ value: 'buffalo_s', text: 'buffalo_s' },
|
||||
]}
|
||||
disabled={disabled || !machineLearningConfig.enabled || !machineLearningConfig.facialRecognition.enabled}
|
||||
|
|
Loading…
Reference in a new issue