from __future__ import annotations

from pathlib import Path
from typing import Any, NamedTuple

import numpy as np
from numpy.typing import NDArray

from ann.ann import Ann
from app.schemas import SessionNode

from ..config import log, settings


class AnnSession:
    """
    Wrapper for ANN to be drop-in replacement for ONNX session.

    Exposes ``get_inputs()``, ``get_outputs()`` and ``run()`` with
    ONNX-session-compatible signatures, backed by one ``Ann`` runtime that holds
    one loaded model for the lifetime of this object.
    """

    def __init__(self, model_path: Path, cache_dir: Path = settings.cache_folder) -> None:
        """
        Create an ANN runtime and load a model into it.

        Args:
            model_path: Model file to load. The compiled network is cached next
                to it, same name with a ``.anncache`` suffix.
            cache_dir: Directory holding the GPU tuning file ``gpu-tuning.ann``.
                Defaults to ``settings.cache_folder`` as read once, when this
                class was defined.
        """
        self.model_path = model_path
        self.cache_dir = cache_dir
        self.ann = Ann(
            tuning_level=settings.ann_tuning_level,
            tuning_file=(cache_dir / "gpu-tuning.ann").as_posix(),
        )

        log.info("Loading ANN model %s ...", model_path)
        self.model = self.ann.load(
            model_path.as_posix(),
            cached_network_path=model_path.with_suffix(".anncache").as_posix(),
            fp16=settings.ann_fp16_turbo,
        )
        log.info("Loaded ANN model with ID %d", self.model)

    def __del__(self) -> None:
        """
        Unload the model and destroy the ANN runtime.

        Safe on partially-initialised instances: if ``__init__`` raised before
        ``self.ann`` or ``self.model`` was assigned, the corresponding step is
        skipped. Once the runtime exists, ``destroy()`` is always attempted, even
        if ``unload()`` raises.
        """
        ann = getattr(self, "ann", None)
        if ann is None:
            # Ann(...) itself failed (or never ran); nothing to release.
            return
        model = getattr(self, "model", None)
        try:
            if model is not None:
                ann.unload(model)
                log.info("Unloaded ANN model %d", model)
        finally:
            # Release the runtime even when load() failed in __init__ or
            # unload() raised above, otherwise it would leak.
            ann.destroy()

    def get_inputs(self) -> list[SessionNode]:
        """Return one unnamed node per model input, carrying its shape."""
        shapes = self.ann.input_shapes[self.model]
        return [AnnNode(None, s) for s in shapes]

    def get_outputs(self) -> list[SessionNode]:
        """Return one unnamed node per model output, carrying its shape."""
        shapes = self.ann.output_shapes[self.model]
        return [AnnNode(None, s) for s in shapes]

    def run(
        self,
        output_names: list[str] | None,
        input_feed: dict[str, NDArray[np.float32]] | dict[str, NDArray[np.int32]],
        run_options: Any = None,
    ) -> list[NDArray[np.float32]]:
        """
        Execute the model on ``input_feed``.

        Args:
            output_names: Accepted for ONNX compatibility; ignored.
            input_feed: Input arrays keyed by name. The names are ignored:
                arrays are handed to ANN positionally, in dict insertion order,
                so callers must insert them in the model's input order.
            run_options: Accepted for ONNX compatibility; ignored.

        Returns:
            Whatever ``Ann.execute`` returns for this model (presumably one
            array per model output -- TODO confirm against ``ann.ann``).
        """
        # Presumably Ann.execute needs C-contiguous buffers; ascontiguousarray
        # only copies arrays that are not already contiguous.
        inputs: list[NDArray[Any]] = [np.ascontiguousarray(v) for v in input_feed.values()]
        return self.ann.execute(self.model, inputs)


class AnnNode(NamedTuple):
    """
    Descriptor for one model input or output.

    Instances are built by ``AnnSession.get_inputs`` / ``get_outputs`` from the
    shapes the ``Ann`` runtime reports for the loaded model.
    """

    # Tensor name; AnnSession always passes None because ANN exposes only shapes.
    name: str | None
    # Tensor shape as reported by Ann.input_shapes / Ann.output_shapes.
    shape: tuple[int, ...]