1
0
Fork 0
mirror of https://github.com/immich-app/immich.git synced 2025-01-04 02:46:47 +01:00

fix(ml): clear model cache on load error (#2951)

* clear model cache on load error

* updated caught exceptions
This commit is contained in:
Mert 2023-06-27 17:01:24 -04:00 committed by GitHub
parent 39a885a37c
commit 47982641b2
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 38 additions and 19 deletions

View file

@ -2,8 +2,11 @@ from __future__ import annotations
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from pathlib import Path from pathlib import Path
from shutil import rmtree
from typing import Any from typing import Any
from onnxruntime.capi.onnxruntime_pybind11_state import InvalidProtobuf
from ..config import get_cache_dir from ..config import get_cache_dir
from ..schemas import ModelType from ..schemas import ModelType
@ -12,10 +15,8 @@ class InferenceModel(ABC):
_model_type: ModelType _model_type: ModelType
def __init__( def __init__(
self, self, model_name: str, cache_dir: Path | None = None, **model_kwargs
model_name: str, ) -> None:
cache_dir: Path | None = None,
):
self.model_name = model_name self.model_name = model_name
self._cache_dir = ( self._cache_dir = (
cache_dir cache_dir
@ -23,6 +24,16 @@ class InferenceModel(ABC):
else get_cache_dir(model_name, self.model_type) else get_cache_dir(model_name, self.model_type)
) )
try:
self.load(**model_kwargs)
except (OSError, InvalidProtobuf):
self.clear_cache()
self.load(**model_kwargs)
@abstractmethod
def load(self, **model_kwargs: Any) -> None:
...
@abstractmethod @abstractmethod
def predict(self, inputs: Any) -> Any: def predict(self, inputs: Any) -> Any:
... ...
@ -36,7 +47,7 @@ class InferenceModel(ABC):
return self._cache_dir return self._cache_dir
@cache_dir.setter @cache_dir.setter
def cache_dir(self, cache_dir: Path): def cache_dir(self, cache_dir: Path) -> None:
self._cache_dir = cache_dir self._cache_dir = cache_dir
@classmethod @classmethod
@ -50,3 +61,13 @@ class InferenceModel(ABC):
raise ValueError(f"Unsupported model type: {model_type}") raise ValueError(f"Unsupported model type: {model_type}")
return subclasses[model_type](model_name, **model_kwargs) return subclasses[model_type](model_name, **model_kwargs)
def clear_cache(self) -> None:
if not self.cache_dir.exists():
return
elif not rmtree.avoids_symlink_attacks:
raise RuntimeError(
"Attempted to clear cache, but rmtree is not safe on this platform."
)
rmtree(self.cache_dir)

View file

@ -1,4 +1,5 @@
from pathlib import Path from pathlib import Path
from typing import Any
from PIL.Image import Image from PIL.Image import Image
from sentence_transformers import SentenceTransformer from sentence_transformers import SentenceTransformer
@ -10,13 +11,7 @@ from .base import InferenceModel
class CLIPSTEncoder(InferenceModel): class CLIPSTEncoder(InferenceModel):
_model_type = ModelType.CLIP _model_type = ModelType.CLIP
def __init__( def load(self, **model_kwargs: Any) -> None:
self,
model_name: str,
cache_dir: Path | None = None,
**model_kwargs,
):
super().__init__(model_name, cache_dir)
self.model = SentenceTransformer( self.model = SentenceTransformer(
self.model_name, self.model_name,
cache_folder=self.cache_dir.as_posix(), cache_folder=self.cache_dir.as_posix(),

View file

@ -18,21 +18,22 @@ class FaceRecognizer(InferenceModel):
min_score: float = settings.min_face_score, min_score: float = settings.min_face_score,
cache_dir: Path | None = None, cache_dir: Path | None = None,
**model_kwargs, **model_kwargs,
): ) -> None:
super().__init__(model_name, cache_dir)
self.min_score = min_score self.min_score = min_score
model = FaceAnalysis( super().__init__(model_name, cache_dir, **model_kwargs)
def load(self, **model_kwargs: Any) -> None:
self.model = FaceAnalysis(
name=self.model_name, name=self.model_name,
root=self.cache_dir.as_posix(), root=self.cache_dir.as_posix(),
allowed_modules=["detection", "recognition"], allowed_modules=["detection", "recognition"],
**model_kwargs, **model_kwargs,
) )
model.prepare( self.model.prepare(
ctx_id=0, ctx_id=0,
det_thresh=self.min_score, det_thresh=self.min_score,
det_size=(640, 640), det_size=(640, 640),
) )
self.model = model
def predict(self, image: cv2.Mat) -> list[dict[str, Any]]: def predict(self, image: cv2.Mat) -> list[dict[str, Any]]:
height, width, _ = image.shape height, width, _ = image.shape

View file

@ -1,4 +1,5 @@
from pathlib import Path from pathlib import Path
from typing import Any
from PIL.Image import Image from PIL.Image import Image
from transformers.pipelines import pipeline from transformers.pipelines import pipeline
@ -17,10 +18,11 @@ class ImageClassifier(InferenceModel):
min_score: float = settings.min_tag_score, min_score: float = settings.min_tag_score,
cache_dir: Path | None = None, cache_dir: Path | None = None,
**model_kwargs, **model_kwargs,
): ) -> None:
super().__init__(model_name, cache_dir)
self.min_score = min_score self.min_score = min_score
super().__init__(model_name, cache_dir, **model_kwargs)
def load(self, **model_kwargs: Any) -> None:
self.model = pipeline( self.model = pipeline(
self.model_type.value, self.model_type.value,
self.model_name, self.model_name,