import json
from abc import abstractmethod
from functools import cached_property
from io import BytesIO
from pathlib import Path
from typing import Any, Literal

import numpy as np
import onnxruntime as ort
from PIL import Image
from transformers import AutoTokenizer

from app.config import clean_name, log
from app.models.transforms import crop, get_pil_resampling, normalize, resize, to_numpy
from app.schemas import ModelType, ndarray_f32, ndarray_i32, ndarray_i64

from .base import InferenceModel


class BaseCLIPEncoder(InferenceModel):
    _model_type = ModelType.CLIP

    def __init__(
        self,
        model_name: str,
        cache_dir: str | None = None,
        mode: Literal["text", "vision"] | None = None,
        **model_kwargs: Any,
    ) -> None:
        self.mode = mode
        super().__init__(model_name, cache_dir, **model_kwargs)

    def _load(self) -> None:
        if self.mode == "text" or self.mode is None:
            log.debug(f"Loading clip text model '{self.model_name}'")

            self.text_model = ort.InferenceSession(
                self.textual_path.as_posix(),
                sess_options=self.sess_options,
                providers=self.providers,
                provider_options=self.provider_options,
            )

        if self.mode == "vision" or self.mode is None:
            log.debug(f"Loading clip vision model '{self.model_name}'")

            self.vision_model = ort.InferenceSession(
                self.visual_path.as_posix(),
                sess_options=self.sess_options,
                providers=self.providers,
                provider_options=self.provider_options,
            )

    def _predict(self, image_or_text: Image.Image | str) -> list[float]:
        if isinstance(image_or_text, bytes):
            image_or_text = Image.open(BytesIO(image_or_text))

        match image_or_text:
            case Image.Image():
                if self.mode == "text":
                    raise TypeError("Cannot encode image as text-only model")

                outputs = self.vision_model.run(None, self.transform(image_or_text))
            case str():
                if self.mode == "vision":
                    raise TypeError("Cannot encode text as vision-only model")

                outputs = self.text_model.run(None, self.tokenize(image_or_text))
            case _:
                raise TypeError(f"Expected Image or str, but got: {type(image_or_text)}")

        return outputs[0][0].tolist()

    @abstractmethod
    def tokenize(self, text: str) -> dict[str, ndarray_i32]:
        pass

    @abstractmethod
    def transform(self, image: Image.Image) -> dict[str, ndarray_f32]:
        pass

    @property
    def textual_dir(self) -> Path:
        return self.cache_dir / "textual"

    @property
    def visual_dir(self) -> Path:
        return self.cache_dir / "visual"

    @property
    def model_cfg_path(self) -> Path:
        return self.cache_dir / "config.json"

    @property
    def textual_path(self) -> Path:
        return self.textual_dir / "model.onnx"

    @property
    def visual_path(self) -> Path:
        return self.visual_dir / "model.onnx"

    @property
    def preprocess_cfg_path(self) -> Path:
        return self.visual_dir / "preprocess_cfg.json"

    @property
    def cached(self) -> bool:
        return self.textual_path.is_file() and self.visual_path.is_file()


class OpenCLIPEncoder(BaseCLIPEncoder):
    """CLIP encoder for OpenCLIP models exported to ONNX.

    Text is tokenized with the Hugging Face tokenizer stored in the textual
    directory, padded/truncated to ``text_cfg.context_length`` from
    ``config.json``. Images are resized, cropped and normalized using values
    from ``visual/preprocess_cfg.json``.
    """

    def __init__(
        self,
        model_name: str,
        cache_dir: str | None = None,
        mode: Literal["text", "vision"] | None = None,
        **model_kwargs: Any,
    ) -> None:
        """Create the encoder.

        ``model_name`` is normalized with ``clean_name`` first. All other
        arguments are forwarded unchanged to ``BaseCLIPEncoder``.
        """
        super().__init__(clean_name(model_name), cache_dir, mode, **model_kwargs)

    def _load(self) -> None:
        """Load the ONNX session(s), the tokenizer and the preprocessing parameters."""
        super()._load()

        self.tokenizer = AutoTokenizer.from_pretrained(self.textual_dir)
        self.sequence_length = self.model_cfg["text_cfg"]["context_length"]

        preprocess_cfg = self.preprocess_cfg
        # ``size`` may be given as a list; only its first element is used.
        # NOTE(review): this assumes square inputs — confirm for non-square configs.
        size = preprocess_cfg["size"]
        self.size = size[0] if isinstance(size, list) else size
        # NOTE(review): ``self.resampling`` is not passed to ``resize`` in
        # ``transform`` below; confirm whether ``resize`` should honour it.
        self.resampling = get_pil_resampling(preprocess_cfg["interpolation"])
        self.mean = np.array(preprocess_cfg["mean"], dtype=np.float32)
        self.std = np.array(preprocess_cfg["std"], dtype=np.float32)

    def tokenize(self, text: str) -> dict[str, ndarray_i32]:
        """Tokenize ``text`` to a fixed-length ``int32`` id tensor.

        The text is padded to and truncated at ``self.sequence_length``. It is
        returned under the input name ``"text"`` without an attention mask.
        """
        input_ids: ndarray_i64 = self.tokenizer(
            text,
            max_length=self.sequence_length,
            return_tensors="np",
            return_attention_mask=False,
            padding="max_length",
            truncation=True,
        ).input_ids
        return {"text": input_ids.astype(np.int32)}

    def transform(self, image: Image.Image) -> dict[str, ndarray_f32]:
        """Preprocess ``image`` into the vision model's ``"image"`` input.

        The image is resized and cropped to ``self.size``, then normalized with
        ``self.mean`` and ``self.std``. Axes are then reordered with
        ``transpose(2, 0, 1)`` and a leading batch axis is added.
        NOTE(review): this assumes ``to_numpy`` returns (H, W, C), so the
        result is (1, C, H, W) — confirm.
        """
        image = resize(image, self.size)
        image = crop(image, self.size)
        image_np = to_numpy(image)
        image_np = normalize(image_np, self.mean, self.std)
        return {"image": np.expand_dims(image_np.transpose(2, 0, 1), 0)}

    @cached_property
    def model_cfg(self) -> dict[str, Any]:
        """Parsed ``config.json``, read once per instance."""
        # ``with`` closes the handle (the bare ``open()`` leaked it); JSON is UTF-8 by spec.
        with self.model_cfg_path.open(encoding="utf-8") as cfg_file:
            return json.load(cfg_file)

    @cached_property
    def preprocess_cfg(self) -> dict[str, Any]:
        """Parsed ``preprocess_cfg.json``, read once per instance."""
        with self.preprocess_cfg_path.open(encoding="utf-8") as cfg_file:
            return json.load(cfg_file)


class MCLIPEncoder(OpenCLIPEncoder):
    """OpenCLIP-style encoder whose text model receives every tokenizer output.

    Unlike ``OpenCLIPEncoder.tokenize``, no fixed-length padding or truncation
    is requested. All tensors the tokenizer produces are forwarded, not only
    the input ids.
    """

    def tokenize(self, text: str) -> dict[str, ndarray_i32]:
        """Tokenize ``text`` and cast every returned tensor to ``int32``.

        Keys are whatever names the tokenizer emits, and they are passed
        straight to the text model as input names.
        """
        encoded: dict[str, ndarray_i64] = self.tokenizer(text, return_tensors="np")
        model_inputs: dict[str, ndarray_i32] = {}
        for input_name, tensor in encoded.items():
            model_inputs[input_name] = tensor.astype(np.int32)
        return model_inputs