1
0
Fork 0
mirror of https://github.com/immich-app/immich.git synced 2025-01-07 20:36:48 +01:00
immich/machine-learning/app/models/image_classification.py

43 lines
1.3 KiB
Python
Raw Normal View History

from pathlib import Path
from typing import Any
from huggingface_hub import snapshot_download
from PIL.Image import Image
from transformers.pipelines import pipeline
from ..config import settings
from ..schemas import ModelType
from .base import InferenceModel
class ImageClassifier(InferenceModel):
    """Tag images with a Hugging Face ``transformers`` classification pipeline.

    Predictions scoring below ``min_score`` are discarded. Each remaining
    prediction's label is split on ``", "`` into individual tags, and the
    tags are returned without duplicates.
    """

    _model_type = ModelType.IMAGE_CLASSIFICATION

    def __init__(
        self,
        model_name: str,
        min_score: float = settings.min_tag_score,
        cache_dir: Path | str | None = None,
        **model_kwargs: Any,
    ) -> None:
        """Create the classifier.

        Args:
            model_name: Model identifier, forwarded to the base class. It is
                presumably stored as ``self.model_name``, which is later used
                as the Hub ``repo_id`` and as the pipeline model.
            min_score: Minimum (inclusive) prediction score for a label to be
                turned into tags. The default is ``settings.min_tag_score``
                as evaluated when this class is defined, not at call time.
            cache_dir: Model cache directory, forwarded to the base class.
            **model_kwargs: Extra keyword arguments forwarded to the base class.
        """
        # Assigned before super().__init__(). NOTE(review): presumably because
        # the base constructor may download/load the model; confirm in base.py.
        self.min_score = min_score
        super().__init__(model_name, cache_dir, **model_kwargs)

    def _download(self, **model_kwargs: Any) -> None:
        """Fetch the model snapshot from the Hugging Face Hub into ``self.cache_dir``.

        Only ``*.bin``, ``*.json`` and ``*.txt`` files are downloaded.
        ``model_kwargs`` is accepted for interface compatibility but unused.
        """
        snapshot_download(
            cache_dir=self.cache_dir,
            repo_id=self.model_name,
            allow_patterns=["*.bin", "*.json", "*.txt"],
        )

    def _load(self, **model_kwargs: Any) -> None:
        """Build the ``transformers`` pipeline and store it on ``self.model``.

        The pipeline task is ``self.model_type.value``. Because
        ``model_kwargs`` is unpacked after ``cache_dir``, a caller-supplied
        ``cache_dir`` key overrides ``self.cache_dir``.
        """
        self.model = pipeline(
            self.model_type.value,
            self.model_name,
            model_kwargs={"cache_dir": self.cache_dir, **model_kwargs},
        )

    def _predict(self, image: Image) -> list[str]:
        """Return the tags predicted for *image*.

        Each prediction whose ``"score"`` is at least ``self.min_score`` has
        its ``"label"`` split on ``", "``. The resulting tags are returned in
        prediction order, with duplicates removed (first occurrence kept).
        """
        predictions: list[dict[str, Any]] = self.model(image)  # type: ignore
        # Filter on score first so low-scoring labels are never split.
        tags = (
            tag
            for pred in predictions
            if pred["score"] >= self.min_score
            for tag in pred["label"].split(", ")
        )
        # Different labels can share a synonym, so the same tag may be produced
        # more than once; dict.fromkeys drops repeats while preserving order.
        return list(dict.fromkeys(tags))