mirror of
https://github.com/immich-app/immich.git
synced 2025-01-07 20:36:48 +01:00
2b1b43a7e4
* modularize model classes * various fixes * expose port * change response * round coordinates * simplify preload * update server * simplify interface simplify * update tests * composable endpoint * cleanup fixes remove unnecessary interface support text input, cleanup * ew camelcase * update server server fixes fix typing * ml fixes update locustfile fixes * cleaner response * better repo response * update tests formatting and typing rename * undo compose change * linting fix type actually fix typing * stricter typing fix detection-only response no need for defaultdict * update spec file update api linting * update e2e * unnecessary dimension * remove commented code * remove duplicate code * remove unused imports * add batch dim
69 lines
2.6 KiB
Python
69 lines
2.6 KiB
Python
import json
|
|
from abc import abstractmethod
|
|
from functools import cached_property
|
|
from pathlib import Path
|
|
from typing import Any
|
|
|
|
import numpy as np
|
|
from numpy.typing import NDArray
|
|
from PIL import Image
|
|
|
|
from app.config import log
|
|
from app.models.base import InferenceModel
|
|
from app.models.transforms import crop_pil, decode_pil, get_pil_resampling, normalize, resize_pil, to_numpy
|
|
from app.schemas import ModelSession, ModelTask, ModelType
|
|
|
|
|
|
class BaseCLIPVisualEncoder(InferenceModel):
    """Common base for CLIP visual (image) encoders.

    Subclasses supply ``transform``, which turns a PIL image into the named
    inputs fed to ``self.session``. This class handles decoding the input,
    running the session, and lazily loading the JSON configuration files.
    """

    depends = []
    identity = (ModelType.VISUAL, ModelTask.SEARCH)

    def _predict(self, inputs: Image.Image | bytes, **kwargs: Any) -> NDArray[np.float32]:
        """Encode one image and return the session's result for it.

        Args:
            inputs: The image to encode, passed through ``decode_pil`` first.
            **kwargs: Accepted for interface compatibility; not used here.

        Returns:
            The first element of the first output produced by the session.
        """
        image = decode_pil(inputs)
        # [0] selects the first session output, the second [0] its first entry
        # (presumably the single item of the batch axis -- confirm against the model).
        res: NDArray[np.float32] = self.session.run(None, self.transform(image))[0][0]
        return res

    @abstractmethod
    def transform(self, image: Image.Image) -> dict[str, NDArray[np.float32]]:
        """Convert ``image`` into the mapping of input names to arrays passed to the session."""

    @property
    def model_cfg_path(self) -> Path:
        """Path of the model config file, ``config.json`` under ``cache_dir``."""
        return self.cache_dir / "config.json"

    @property
    def preprocess_cfg_path(self) -> Path:
        """Path of the preprocessing config file, ``preprocess_cfg.json`` under ``model_dir``."""
        return self.model_dir / "preprocess_cfg.json"

    @cached_property
    def model_cfg(self) -> dict[str, Any]:
        """Parsed contents of ``model_cfg_path``; read once and cached on the instance."""
        log.debug(f"Loading model config for CLIP model '{self.model_name}'")
        # Use a context manager so the file handle is closed deterministically
        # instead of being left for the garbage collector.
        with self.model_cfg_path.open() as cfg_file:
            model_cfg: dict[str, Any] = json.load(cfg_file)
        log.debug(f"Loaded model config for CLIP model '{self.model_name}'")
        return model_cfg

    @cached_property
    def preprocess_cfg(self) -> dict[str, Any]:
        """Parsed contents of ``preprocess_cfg_path``; read once and cached on the instance."""
        log.debug(f"Loading visual preprocessing config for CLIP model '{self.model_name}'")
        # Same as above: close the handle as soon as parsing is done.
        with self.preprocess_cfg_path.open() as cfg_file:
            preprocess_cfg: dict[str, Any] = json.load(cfg_file)
        log.debug(f"Loaded visual preprocessing config for CLIP model '{self.model_name}'")
        return preprocess_cfg
|
|
|
|
|
|
class OpenClipVisualEncoder(BaseCLIPVisualEncoder):
    """CLIP visual encoder whose preprocessing parameters come from ``preprocess_cfg``."""

    def _load(self) -> ModelSession:
        """Cache preprocessing parameters on the instance, then defer to the parent loader.

        Returns:
            Whatever the parent class's ``_load`` returns.
        """
        cfg = self.preprocess_cfg

        # "size" may be given as a list or as a bare int; for a list only the
        # first entry is kept.
        size: list[int] | int = cfg["size"]
        if isinstance(size, list):
            self.size = size[0]
        else:
            self.size = size

        # NOTE(review): `resampling` is stored here but not read anywhere in the
        # visible code -- confirm whether `transform` is meant to use it.
        self.resampling = get_pil_resampling(cfg["interpolation"])
        self.mean = np.array(cfg["mean"], dtype=np.float32)
        self.std = np.array(cfg["std"], dtype=np.float32)

        return super()._load()

    def transform(self, image: Image.Image) -> dict[str, NDArray[np.float32]]:
        """Resize, crop and normalize ``image`` into the session's ``"image"`` input.

        Args:
            image: The PIL image to preprocess.

        Returns:
            A single-entry dict mapping ``"image"`` to the preprocessed array.
        """
        resized = resize_pil(image, self.size)
        cropped = crop_pil(resized, self.size)
        pixels = normalize(to_numpy(cropped), self.mean, self.std)
        # Move axis 2 to the front, then prepend a new axis of length 1.
        batched = np.expand_dims(pixels.transpose(2, 0, 1), 0)
        return {"image": batched}
|