1
0
Fork 0
mirror of https://github.com/immich-app/immich.git synced 2025-01-06 03:46:47 +01:00
immich/machine-learning/app/models/image_classification.py
Mert 258b98c262
fix(ml): load models in separate threads (#4034)
* load models in thread

* set clip mode logs to debug level

* updated tests

* made fixtures slightly less ugly

* moved responses to json file

* formatting
2023-09-09 16:02:44 +07:00

75 lines
2.6 KiB
Python

from io import BytesIO
from pathlib import Path
from typing import Any
from huggingface_hub import snapshot_download
from optimum.onnxruntime import ORTModelForImageClassification
from optimum.pipelines import pipeline
from PIL import Image
from transformers import AutoImageProcessor
from ..config import log
from ..schemas import ModelType
from .base import InferenceModel
class ImageClassifier(InferenceModel):
    """Image-classification model that turns pipeline predictions into tags.

    Predictions whose score is below ``min_score`` are dropped; the labels of
    the remaining predictions are split on ``", "`` into individual tags.
    """

    _model_type = ModelType.IMAGE_CLASSIFICATION

    def __init__(
        self,
        model_name: str,
        min_score: float = 0.9,
        cache_dir: Path | str | None = None,
        **model_kwargs: Any,
    ) -> None:
        """Store the score threshold and delegate the rest to the base class.

        Args:
            model_name: Forwarded to the base class; also used as the
                repository id when downloading.
            min_score: Minimum prediction score for a label to be kept.
            cache_dir: Forwarded to the base class.
            **model_kwargs: Extra options. A ``"minScore"`` entry, if present,
                overrides ``min_score`` and is removed before the remaining
                options are forwarded to the base class.
        """
        self.min_score = model_kwargs.pop("minScore", min_score)
        super().__init__(model_name, cache_dir, **model_kwargs)

    def _download(self) -> None:
        """Fetch the ``*.bin``, ``*.json`` and ``*.txt`` files of the model repo into the cache dir."""
        snapshot_download(
            cache_dir=self.cache_dir,
            repo_id=self.model_name,
            allow_patterns=["*.bin", "*.json", "*.txt"],
            local_dir=self.cache_dir,
            local_dir_use_symlinks=True,
        )

    def _load(self) -> None:
        """Build the classification pipeline and store it on ``self.model``.

        If ``model.onnx`` already exists in the cache directory it is loaded
        directly. Otherwise the pipeline is created from the model name, with
        the session options pointed at ``model.onnx`` as the optimized-model
        output path.
        """
        processor = AutoImageProcessor.from_pretrained(self.cache_dir, cache_dir=self.cache_dir)
        model_path = self.cache_dir / "model.onnx"
        model_kwargs = {
            "cache_dir": self.cache_dir,
            "provider": self.providers[0],
            "provider_options": self.provider_options[0],
            "session_options": self.sess_options,
        }

        if model_path.exists():
            model = ORTModelForImageClassification.from_pretrained(self.cache_dir, **model_kwargs)
            self.model = pipeline(self.model_type.value, model, feature_extractor=processor)
        else:
            # The two literals are concatenated, so the first one must end
            # with a space to keep the sentences apart in the log output.
            log.info(
                f"ONNX model not found in cache directory for '{self.model_name}'. "
                "Exporting optimized model for future use."
            )
            # NOTE(review): presumably this makes the runtime write the
            # optimized model to `model_path` so the branch above is taken
            # on the next load -- confirm against the session-options docs.
            self.sess_options.optimized_model_filepath = model_path.as_posix()
            self.model = pipeline(
                self.model_type.value,
                self.model_name,
                model_kwargs=model_kwargs,
                feature_extractor=processor,
            )

    def _predict(self, image: Image.Image | bytes) -> list[str]:
        """Classify ``image`` and return the tags that meet the score threshold.

        Args:
            image: A PIL image, or raw encoded image bytes which are opened
                with PIL first.

        Returns:
            Tags taken from every prediction whose ``"score"`` is at least
            ``self.min_score``; each prediction's ``"label"`` is split on
            ``", "`` so one prediction may contribute several tags.
        """
        if isinstance(image, bytes):
            image = Image.open(BytesIO(image))
        predictions: list[dict[str, Any]] = self.model(image)  # type: ignore
        # Filter on the score before splitting the label so discarded
        # predictions are not split and the check runs once per prediction.
        return [
            tag
            for pred in predictions
            if pred["score"] >= self.min_score
            for tag in pred["label"].split(", ")
        ]

    def configure(self, **model_kwargs: Any) -> None:
        """Update ``min_score`` from a ``"minScore"`` option, if one is given."""
        self.min_score = model_kwargs.pop("minScore", self.min_score)