from pathlib import Path
from typing import Any

from huggingface_hub import snapshot_download
from optimum.onnxruntime import ORTModelForImageClassification
from optimum.pipelines import pipeline
from PIL.Image import Image
from transformers import AutoImageProcessor

from ..config import settings
from ..schemas import ModelType
from .base import InferenceModel


class ImageClassifier(InferenceModel):
    """Inference model that turns an image into a list of tag strings.

    Predictions whose score is below ``min_score`` are dropped; the labels of
    the remaining predictions are split on ``", "`` into individual tags.
    """

    _model_type = ModelType.IMAGE_CLASSIFICATION

    def __init__(
        self,
        model_name: str,
        min_score: float = settings.min_tag_score,
        cache_dir: Path | str | None = None,
        **model_kwargs: Any,
    ) -> None:
        """Store the score threshold, then delegate setup to the base class.

        Args:
            model_name: Identifier of the model; used as the hub ``repo_id``
                when downloading and as the pipeline model when no local
                ``model.onnx`` exists.
            min_score: Minimum score a prediction needs for its tags to be
                returned. Defaults to ``settings.min_tag_score``.
            cache_dir: Forwarded to the base class initializer.
            **model_kwargs: Forwarded to the base class initializer.
        """
        # Assigned before the base initializer runs, matching the original
        # ordering (NOTE(review): presumably the base class may load the model
        # during __init__ -- confirm against InferenceModel).
        self.min_score = min_score
        super().__init__(model_name, cache_dir, **model_kwargs)

    def _download(self, **model_kwargs: Any) -> None:
        """Fetch the model repository snapshot into ``self.cache_dir``.

        Only files matching ``*.bin``, ``*.json`` and ``*.txt`` are requested.
        ``model_kwargs`` is accepted but not used here.
        """
        wanted_files = ["*.bin", "*.json", "*.txt"]
        snapshot_download(
            repo_id=self.model_name,
            cache_dir=self.cache_dir,
            local_dir=self.cache_dir,
            local_dir_use_symlinks=True,
            allow_patterns=wanted_files,
        )

    def _load(self, **model_kwargs: Any) -> None:
        """Build the classification pipeline and store it on ``self.model``.

        When ``model.onnx`` exists in the cache directory, the ONNX model is
        loaded from there. Otherwise the pipeline is created from
        ``self.model_name`` and the session options are pointed at that path
        (presumably so the runtime writes the optimized model there for later
        loads -- confirm against the session options type).
        """
        image_processor = AutoImageProcessor.from_pretrained(self.cache_dir)

        # Runtime settings override any caller-supplied values of the same name.
        model_kwargs.update(
            cache_dir=self.cache_dir,
            provider=self.providers[0],
            provider_options=self.provider_options[0],
            session_options=self.sess_options,
        )

        task = self.model_type.value
        onnx_file = self.cache_dir / "model.onnx"

        if not onnx_file.exists():
            # The same session options object is already referenced from
            # model_kwargs, so this assignment is visible to the pipeline.
            self.sess_options.optimized_model_filepath = onnx_file.as_posix()
            self.model = pipeline(
                task,
                self.model_name,
                model_kwargs=model_kwargs,
                feature_extractor=image_processor,
            )
            return

        ort_model = ORTModelForImageClassification.from_pretrained(self.cache_dir, **model_kwargs)
        self.model = pipeline(task, ort_model, feature_extractor=image_processor)

    def _predict(self, image: Image) -> list[str]:
        """Classify ``image`` and return the tags of sufficiently scored labels.

        Each prediction is read as a mapping with ``"label"`` and ``"score"``
        entries. A label may hold several tags separated by ``", "``; each
        becomes its own entry in the result, in prediction order.
        """
        results: list[dict[str, Any]] = self.model(image)  # type: ignore

        tags: list[str] = []
        for result in results:
            # Split first, then test the score: same lookup order as before.
            names = result["label"].split(", ")
            if result["score"] >= self.min_score:
                tags.extend(names)
        return tags