import json
from abc import abstractmethod
from functools import cached_property
from pathlib import Path
from typing import Any

import numpy as np
from numpy.typing import NDArray
from PIL import Image

from app.config import log
from app.models.base import InferenceModel
from app.models.transforms import crop_pil, decode_pil, get_pil_resampling, normalize, resize_pil, to_numpy
from app.schemas import ModelSession, ModelTask, ModelType


class BaseCLIPVisualEncoder(InferenceModel):
    """Common base for CLIP visual encoders.

    Decodes the incoming image, hands it to the subclass-provided
    ``transform`` and feeds the result to the inference session. Also exposes
    the JSON model/preprocessing configs as lazily loaded, cached properties.
    """

    depends = []
    identity = (ModelType.VISUAL, ModelTask.SEARCH)

    def _predict(self, inputs: Image.Image | bytes, **kwargs: Any) -> NDArray[np.float32]:
        """Run the session on a single image and return its output.

        Args:
            inputs: Either a PIL image or raw bytes; passed to ``decode_pil``.
            **kwargs: Accepted for interface compatibility; not used here.

        Returns:
            The first element of the first session output.
            NOTE(review): presumably the embedding of the single image in a
            batch of one -- confirm against the exported model.
        """
        image = decode_pil(inputs)
        res: NDArray[np.float32] = self.session.run(None, self.transform(image))[0][0]
        return res

    @abstractmethod
    def transform(self, image: Image.Image) -> dict[str, NDArray[np.float32]]:
        """Convert a PIL image into the input mapping passed to ``session.run``."""
        pass

    @property
    def model_cfg_path(self) -> Path:
        """Path of ``config.json`` inside ``cache_dir``."""
        return self.cache_dir / "config.json"

    @property
    def preprocess_cfg_path(self) -> Path:
        """Path of ``preprocess_cfg.json`` inside ``model_dir``."""
        return self.model_dir / "preprocess_cfg.json"

    @cached_property
    def model_cfg(self) -> dict[str, Any]:
        """Parsed contents of ``model_cfg_path``; read once and then cached.

        Raises:
            OSError: If the config file cannot be opened.
            json.JSONDecodeError: If the file is not valid JSON.
        """
        log.debug(f"Loading model config for CLIP model '{self.model_name}'")
        # Use a context manager so the file handle is closed deterministically
        # instead of lingering until garbage collection.
        with self.model_cfg_path.open() as cfg_file:
            model_cfg: dict[str, Any] = json.load(cfg_file)
        log.debug(f"Loaded model config for CLIP model '{self.model_name}'")
        return model_cfg

    @cached_property
    def preprocess_cfg(self) -> dict[str, Any]:
        """Parsed contents of ``preprocess_cfg_path``; read once and then cached.

        Raises:
            OSError: If the config file cannot be opened.
            json.JSONDecodeError: If the file is not valid JSON.
        """
        log.debug(f"Loading visual preprocessing config for CLIP model '{self.model_name}'")
        # Same as above: close the handle as soon as parsing is done.
        with self.preprocess_cfg_path.open() as cfg_file:
            preprocess_cfg: dict[str, Any] = json.load(cfg_file)
        log.debug(f"Loaded visual preprocessing config for CLIP model '{self.model_name}'")
        return preprocess_cfg


class OpenClipVisualEncoder(BaseCLIPVisualEncoder):
    """CLIP visual encoder whose preprocessing is driven by ``preprocess_cfg``."""

    def _load(self) -> ModelSession:
        """Read preprocessing parameters from the config, then defer to the parent loader."""
        cfg = self.preprocess_cfg

        # "size" may be given as a list or as a bare int; for a list only the
        # first entry is kept.
        configured_size: list[int] | int = cfg["size"]
        if isinstance(configured_size, list):
            self.size = configured_size[0]
        else:
            self.size = configured_size

        # Stored on the instance; note that transform() below does not reference it.
        self.resampling = get_pil_resampling(cfg["interpolation"])
        self.mean = np.array(cfg["mean"], dtype=np.float32)
        self.std = np.array(cfg["std"], dtype=np.float32)

        return super()._load()

    def transform(self, image: Image.Image) -> dict[str, NDArray[np.float32]]:
        """Resize, crop and normalize ``image``, returning it under the ``"image"`` key.

        The normalized array has its axes reordered with ``transpose(2, 0, 1)``
        and gains a new leading axis before being returned.
        """
        target = self.size
        cropped = crop_pil(resize_pil(image, target), target)
        normalized = normalize(to_numpy(cropped), self.mean, self.std)
        reordered = normalized.transpose(2, 0, 1)
        return {"image": np.expand_dims(reordered, 0)}