1
0
Fork 0
mirror of https://github.com/immich-app/immich.git synced 2025-01-19 18:26:46 +01:00
immich/machine-learning/app/sessions/ann.py

59 lines
1.8 KiB
Python
Raw Permalink Normal View History

from __future__ import annotations
from pathlib import Path
from typing import Any, NamedTuple
import numpy as np
from numpy.typing import NDArray
from ann.ann import Ann
2024-06-25 18:00:24 +02:00
from app.schemas import SessionNode
from ..config import log, settings
class AnnSession:
    """
    Wrapper for ANN to be drop-in replacement for ONNX session.

    Provides ``get_inputs``, ``get_outputs`` and ``run`` on top of a single
    ``Ann`` handle with one model loaded into it. The model and the handle are
    released in ``__del__``.
    """

    def __init__(self, model_path: Path, cache_dir: Path = settings.cache_folder) -> None:
        """
        Create an ``Ann`` handle and load ``model_path`` into it.

        Args:
            model_path: Model file passed to ``Ann.load``. The same path with an
                ``.anncache`` suffix is passed as ``cached_network_path``.
            cache_dir: Directory containing the ``gpu-tuning.ann`` tuning file.
                The default is read from ``settings.cache_folder`` once, when
                the class is defined, not on each call.
        """
        self.model_path = model_path
        self.cache_dir = cache_dir
        self.ann = Ann(tuning_level=settings.ann_tuning_level, tuning_file=(cache_dir / "gpu-tuning.ann").as_posix())

        log.info("Loading ANN model %s ...", model_path)
        self.model = self.ann.load(
            model_path.as_posix(),
            cached_network_path=model_path.with_suffix(".anncache").as_posix(),
            fp16=settings.ann_fp16_turbo,
        )
        log.info("Loaded ANN model with ID %d", self.model)

    def __del__(self) -> None:
        """
        Unload the model and destroy the ``Ann`` handle.

        Python also runs ``__del__`` on objects whose ``__init__`` raised, so
        attributes are checked before use: if ``Ann(...)`` failed there is
        nothing to release, and if ``load`` failed the handle must still be
        destroyed rather than leaked.
        """
        if not hasattr(self, "ann"):
            # Ann(...) itself failed; nothing was allocated.
            return
        try:
            # hasattr (not truthiness) so a model ID of 0 is still unloaded.
            if hasattr(self, "model"):
                self.ann.unload(self.model)
                log.info("Unloaded ANN model %d", self.model)
        finally:
            # Destroy the handle even if loading failed or unloading raised.
            self.ann.destroy()

    def get_inputs(self) -> list[SessionNode]:
        """
        Describe the model inputs, one node per entry in ``Ann.input_shapes``.

        Node names are not populated (always ``None``); only ``shape`` is set.
        """
        shapes = self.ann.input_shapes[self.model]
        return [AnnNode(None, s) for s in shapes]

    def get_outputs(self) -> list[SessionNode]:
        """
        Describe the model outputs, one node per entry in ``Ann.output_shapes``.

        Node names are not populated (always ``None``); only ``shape`` is set.
        """
        shapes = self.ann.output_shapes[self.model]
        return [AnnNode(None, s) for s in shapes]

    def run(
        self,
        output_names: list[str] | None,
        input_feed: dict[str, NDArray[np.float32]] | dict[str, NDArray[np.int32]],
        run_options: Any = None,
    ) -> list[NDArray[np.float32]]:
        """
        Execute the loaded model with an ONNX-session-style call signature.

        Args:
            output_names: Accepted for interface compatibility; ignored.
            input_feed: Input arrays keyed by name. The keys are ignored: the
                values are passed to ``Ann.execute`` positionally, in the
                dict's iteration order (presumably this must match the model's
                input order; confirm against callers).
            run_options: Accepted for interface compatibility; ignored.

        Returns:
            Whatever ``Ann.execute`` returns for this model.
        """
        # Hand C-contiguous buffers to the native library.
        inputs: list[NDArray[np.float32]] = [np.ascontiguousarray(v) for v in input_feed.values()]
        return self.ann.execute(self.model, inputs)
class AnnNode(NamedTuple):
    """
    Lightweight node description returned by ``AnnSession.get_inputs`` and
    ``AnnSession.get_outputs`` in place of ONNX session node objects.
    """

    # Node name; AnnSession currently always sets this to None.
    name: str | None
    # Dimensions of the tensor, as reported by Ann's shape tables.
    shape: tuple[int, ...]