1
0
Fork 0
mirror of https://github.com/immich-app/immich.git synced 2025-01-06 03:46:47 +01:00
immich/machine-learning/app/models/image_classification.py

43 lines
1.1 KiB
Python
Raw Normal View History

from pathlib import Path
from typing import Any
from PIL.Image import Image
from transformers.pipelines import pipeline
from ..config import settings
from ..schemas import ModelType
from .base import InferenceModel
class ImageClassifier(InferenceModel):
_model_type = ModelType.IMAGE_CLASSIFICATION
def __init__(
self,
model_name: str,
min_score: float = settings.min_tag_score,
cache_dir: Path | None = None,
**model_kwargs,
) -> None:
self.min_score = min_score
super().__init__(model_name, cache_dir, **model_kwargs)
def load(self, **model_kwargs: Any) -> None:
self.model = pipeline(
self.model_type.value,
self.model_name,
model_kwargs={"cache_dir": self.cache_dir, **model_kwargs},
)
def predict(self, image: Image) -> list[str]:
predictions = self.model(image)
tags = list(
{
tag
for pred in predictions
for tag in pred["label"].split(", ")
if pred["score"] >= self.min_score
}
)
return tags