2023-06-25 05:18:09 +02:00
|
|
|
from __future__ import annotations
|
|
|
|
|
2023-08-25 06:28:51 +02:00
|
|
|
import pickle
|
2023-06-27 19:21:50 +02:00
|
|
|
from abc import ABC, abstractmethod
|
2023-06-25 05:18:09 +02:00
|
|
|
from pathlib import Path
|
2023-06-27 23:01:24 +02:00
|
|
|
from shutil import rmtree
|
2023-06-25 05:18:09 +02:00
|
|
|
from typing import Any
|
|
|
|
|
2023-08-25 06:28:51 +02:00
|
|
|
import onnxruntime as ort
|
2023-11-12 02:04:49 +01:00
|
|
|
from huggingface_hub import snapshot_download
|
2023-11-13 17:18:46 +01:00
|
|
|
from typing_extensions import Buffer
|
2023-06-27 23:01:24 +02:00
|
|
|
|
2024-01-11 18:26:46 +01:00
|
|
|
import ann.ann
|
2024-01-22 00:22:39 +01:00
|
|
|
from app.models.constants import SUPPORTED_PROVIDERS
|
2024-01-11 18:26:46 +01:00
|
|
|
|
2023-11-12 02:04:49 +01:00
|
|
|
from ..config import get_cache_dir, get_hf_model_name, log, settings
|
2023-06-25 05:18:09 +02:00
|
|
|
from ..schemas import ModelType
|
2024-01-11 18:26:46 +01:00
|
|
|
from .ann import AnnSession
|
2023-06-25 05:18:09 +02:00
|
|
|
|
|
|
|
|
|
|
|
class InferenceModel(ABC):
|
|
|
|
_model_type: ModelType
|
|
|
|
|
2023-08-06 04:45:13 +02:00
|
|
|
def __init__(
|
2023-08-25 06:28:51 +02:00
|
|
|
self,
|
|
|
|
model_name: str,
|
|
|
|
cache_dir: Path | str | None = None,
|
2024-01-22 00:22:39 +01:00
|
|
|
providers: list[str] | None = None,
|
|
|
|
provider_options: list[dict[str, Any]] | None = None,
|
|
|
|
sess_options: ort.SessionOptions | None = None,
|
2023-08-25 06:28:51 +02:00
|
|
|
**model_kwargs: Any,
|
2023-08-06 04:45:13 +02:00
|
|
|
) -> None:
|
2023-09-09 11:02:44 +02:00
|
|
|
self.loaded = False
|
2024-01-22 00:22:39 +01:00
|
|
|
self.model_name = model_name
|
|
|
|
self.cache_dir = Path(cache_dir) if cache_dir is not None else self.cache_dir_default
|
|
|
|
self.providers = providers if providers is not None else self.providers_default
|
|
|
|
self.provider_options = provider_options if provider_options is not None else self.provider_options_default
|
|
|
|
self.sess_options = sess_options if sess_options is not None else self.sess_options_default
|
2023-08-25 06:28:51 +02:00
|
|
|
|
2023-09-09 11:02:44 +02:00
|
|
|
def download(self) -> None:
|
2023-08-06 04:45:13 +02:00
|
|
|
if not self.cached:
|
2023-08-30 10:22:01 +02:00
|
|
|
log.info(
|
2023-12-21 02:47:56 +01:00
|
|
|
f"Downloading {self.model_type.replace('-', ' ')} model '{self.model_name}'. This may take a while."
|
2023-08-30 10:22:01 +02:00
|
|
|
)
|
2023-09-09 11:02:44 +02:00
|
|
|
self._download()
|
2023-06-27 23:01:24 +02:00
|
|
|
|
2023-09-09 11:02:44 +02:00
|
|
|
def load(self) -> None:
|
|
|
|
if self.loaded:
|
|
|
|
return
|
|
|
|
self.download()
|
2023-12-21 02:47:56 +01:00
|
|
|
log.info(f"Loading {self.model_type.replace('-', ' ')} model '{self.model_name}' to memory")
|
2023-09-09 11:02:44 +02:00
|
|
|
self._load()
|
|
|
|
self.loaded = True
|
2023-08-06 04:45:13 +02:00
|
|
|
|
2023-08-29 15:58:00 +02:00
|
|
|
def predict(self, inputs: Any, **model_kwargs: Any) -> Any:
|
2023-09-09 11:02:44 +02:00
|
|
|
self.load()
|
2023-08-29 15:58:00 +02:00
|
|
|
if model_kwargs:
|
|
|
|
self.configure(**model_kwargs)
|
2023-08-06 04:45:13 +02:00
|
|
|
return self._predict(inputs)
|
|
|
|
|
|
|
|
@abstractmethod
|
|
|
|
def _predict(self, inputs: Any) -> Any:
|
2023-06-27 23:01:24 +02:00
|
|
|
...
|
|
|
|
|
2023-08-29 15:58:00 +02:00
|
|
|
def configure(self, **model_kwargs: Any) -> None:
|
|
|
|
pass
|
|
|
|
|
2023-09-09 11:02:44 +02:00
|
|
|
def _download(self) -> None:
|
2023-11-12 02:04:49 +01:00
|
|
|
snapshot_download(
|
|
|
|
get_hf_model_name(self.model_name),
|
|
|
|
cache_dir=self.cache_dir,
|
|
|
|
local_dir=self.cache_dir,
|
|
|
|
local_dir_use_symlinks=False,
|
|
|
|
)
|
2023-08-06 04:45:13 +02:00
|
|
|
|
|
|
|
@abstractmethod
|
2023-09-09 11:02:44 +02:00
|
|
|
def _load(self) -> None:
|
2023-06-25 05:18:09 +02:00
|
|
|
...
|
|
|
|
|
2023-06-27 23:01:24 +02:00
|
|
|
def clear_cache(self) -> None:
|
|
|
|
if not self.cache_dir.exists():
|
2024-01-22 00:22:39 +01:00
|
|
|
log.warning(
|
2023-12-21 02:47:56 +01:00
|
|
|
f"Attempted to clear cache for model '{self.model_name}', but cache directory does not exist",
|
2023-08-30 10:22:01 +02:00
|
|
|
)
|
2023-06-27 23:01:24 +02:00
|
|
|
return
|
2023-08-06 04:45:13 +02:00
|
|
|
if not rmtree.avoids_symlink_attacks:
|
2023-12-21 02:47:56 +01:00
|
|
|
raise RuntimeError("Attempted to clear cache, but rmtree is not safe on this platform")
|
2023-06-27 23:01:24 +02:00
|
|
|
|
2023-08-06 04:45:13 +02:00
|
|
|
if self.cache_dir.is_dir():
|
2023-08-30 10:22:01 +02:00
|
|
|
log.info(f"Cleared cache directory for model '{self.model_name}'.")
|
2023-08-06 04:45:13 +02:00
|
|
|
rmtree(self.cache_dir)
|
|
|
|
else:
|
2024-01-22 00:22:39 +01:00
|
|
|
log.warning(
|
2023-08-30 10:22:01 +02:00
|
|
|
(
|
|
|
|
f"Encountered file instead of directory at cache path "
|
|
|
|
f"for '{self.model_name}'. Removing file and replacing with a directory."
|
|
|
|
),
|
|
|
|
)
|
2023-08-06 04:45:13 +02:00
|
|
|
self.cache_dir.unlink()
|
|
|
|
self.cache_dir.mkdir(parents=True, exist_ok=True)
|
2023-08-25 06:28:51 +02:00
|
|
|
|
2024-01-11 18:26:46 +01:00
|
|
|
def _make_session(self, model_path: Path) -> AnnSession | ort.InferenceSession:
|
|
|
|
armnn_path = model_path.with_suffix(".armnn")
|
|
|
|
if settings.ann and ann.ann.is_available and armnn_path.is_file():
|
|
|
|
session = AnnSession(armnn_path)
|
|
|
|
elif model_path.is_file():
|
|
|
|
session = ort.InferenceSession(
|
|
|
|
model_path.as_posix(),
|
|
|
|
sess_options=self.sess_options,
|
|
|
|
providers=self.providers,
|
|
|
|
provider_options=self.provider_options,
|
|
|
|
)
|
|
|
|
else:
|
|
|
|
raise ValueError(f"the file model_path='{model_path}' does not exist")
|
|
|
|
return session
|
|
|
|
|
2024-01-22 00:22:39 +01:00
|
|
|
@property
|
|
|
|
def model_type(self) -> ModelType:
|
|
|
|
return self._model_type
|
|
|
|
|
|
|
|
@property
|
|
|
|
def cache_dir(self) -> Path:
|
|
|
|
return self._cache_dir
|
|
|
|
|
|
|
|
@cache_dir.setter
|
|
|
|
def cache_dir(self, cache_dir: Path) -> None:
|
|
|
|
self._cache_dir = cache_dir
|
|
|
|
|
|
|
|
@property
|
|
|
|
def cache_dir_default(self) -> Path:
|
|
|
|
return get_cache_dir(self.model_name, self.model_type)
|
|
|
|
|
|
|
|
@property
|
|
|
|
def cached(self) -> bool:
|
|
|
|
return self.cache_dir.exists() and any(self.cache_dir.iterdir())
|
|
|
|
|
|
|
|
@property
|
|
|
|
def providers(self) -> list[str]:
|
|
|
|
return self._providers
|
|
|
|
|
|
|
|
@providers.setter
|
|
|
|
def providers(self, providers: list[str]) -> None:
|
|
|
|
log.debug(
|
|
|
|
(f"Setting '{self.model_name}' execution providers to {providers}, " "in descending order of preference"),
|
|
|
|
)
|
|
|
|
self._providers = providers
|
|
|
|
|
|
|
|
@property
|
|
|
|
def providers_default(self) -> list[str]:
|
|
|
|
available_providers = set(ort.get_available_providers())
|
|
|
|
log.debug(f"Available ORT providers: {available_providers}")
|
|
|
|
return [provider for provider in SUPPORTED_PROVIDERS if provider in available_providers]
|
|
|
|
|
|
|
|
@property
|
|
|
|
def provider_options(self) -> list[dict[str, Any]]:
|
|
|
|
return self._provider_options
|
|
|
|
|
|
|
|
@provider_options.setter
|
|
|
|
def provider_options(self, provider_options: list[dict[str, Any]]) -> None:
|
|
|
|
log.debug(f"Setting execution provider options to {provider_options}")
|
|
|
|
self._provider_options = provider_options
|
|
|
|
|
|
|
|
@property
|
|
|
|
def provider_options_default(self) -> list[dict[str, Any]]:
|
|
|
|
options = []
|
|
|
|
for provider in self.providers:
|
|
|
|
match provider:
|
|
|
|
case "CPUExecutionProvider" | "CUDAExecutionProvider":
|
|
|
|
option = {"arena_extend_strategy": "kSameAsRequested"}
|
|
|
|
case "OpenVINOExecutionProvider":
|
|
|
|
try:
|
|
|
|
device_ids: list[str] = ort.capi._pybind_state.get_available_openvino_device_ids()
|
|
|
|
log.debug(f"Available OpenVINO devices: {device_ids}")
|
|
|
|
gpu_devices = [device_id for device_id in device_ids if device_id.startswith("GPU")]
|
|
|
|
option = {"device_id": gpu_devices[0]} if gpu_devices else {}
|
|
|
|
except AttributeError as e:
|
|
|
|
log.warning("Failed to get OpenVINO device IDs. Using default options.")
|
|
|
|
log.error(e)
|
|
|
|
option = {}
|
|
|
|
case _:
|
|
|
|
option = {}
|
|
|
|
options.append(option)
|
|
|
|
return options
|
|
|
|
|
|
|
|
@property
|
|
|
|
def sess_options(self) -> ort.SessionOptions:
|
|
|
|
return self._sess_options
|
|
|
|
|
|
|
|
@sess_options.setter
|
|
|
|
def sess_options(self, sess_options: ort.SessionOptions) -> None:
|
|
|
|
log.debug(f"Setting execution_mode to {sess_options.execution_mode.name}")
|
|
|
|
log.debug(f"Setting inter_op_num_threads to {sess_options.inter_op_num_threads}")
|
|
|
|
log.debug(f"Setting intra_op_num_threads to {sess_options.intra_op_num_threads}")
|
|
|
|
self._sess_options = sess_options
|
|
|
|
|
|
|
|
@property
|
|
|
|
def sess_options_default(self) -> ort.SessionOptions:
|
|
|
|
sess_options = PicklableSessionOptions()
|
|
|
|
sess_options.enable_cpu_mem_arena = False
|
|
|
|
|
|
|
|
# avoid thread contention between models
|
|
|
|
if settings.model_inter_op_threads > 0:
|
|
|
|
sess_options.inter_op_num_threads = settings.model_inter_op_threads
|
|
|
|
# these defaults work well for CPU, but bottleneck GPU
|
|
|
|
elif settings.model_inter_op_threads == 0 and self.providers == ["CPUExecutionProvider"]:
|
|
|
|
sess_options.inter_op_num_threads = 1
|
|
|
|
|
|
|
|
if settings.model_intra_op_threads > 0:
|
|
|
|
sess_options.intra_op_num_threads = settings.model_intra_op_threads
|
|
|
|
elif settings.model_intra_op_threads == 0 and self.providers == ["CPUExecutionProvider"]:
|
|
|
|
sess_options.intra_op_num_threads = 2
|
|
|
|
|
|
|
|
if sess_options.inter_op_num_threads > 1:
|
|
|
|
sess_options.execution_mode = ort.ExecutionMode.ORT_PARALLEL
|
|
|
|
|
|
|
|
return sess_options
|
|
|
|
|
2023-08-25 06:28:51 +02:00
|
|
|
|
|
|
|
# HF deep copies configs, so we need to make session options picklable
class PicklableSessionOptions(ort.SessionOptions):  # type: ignore[misc]
    """``ort.SessionOptions`` subclass that can go through ``pickle`` / ``copy.deepcopy``.

    The pickled state is a ``pickle``-encoded list of ``(name, value)`` pairs. It covers
    every attribute reported by ``dir()`` whose value is not callable. Unpickling re-runs
    ``__init__`` and then assigns each pair back onto the new instance.
    """

    def __getstate__(self) -> bytes:
        """Serialize all non-callable attributes of this instance to bytes."""
        snapshot: list[tuple[str, Any]] = []
        for name in dir(self):
            value = getattr(self, name)
            if callable(value):
                # Skip methods and other callables; only plain option values are state.
                continue
            snapshot.append((name, value))
        return pickle.dumps(snapshot)

    def __setstate__(self, state: Buffer) -> None:
        """Restore attributes from bytes produced by ``__getstate__``.

        NOTE: ``pickle.loads`` can execute arbitrary code, so ``state`` must only ever
        come from ``__getstate__`` (e.g. via ``copy.deepcopy``), never from untrusted
        input.
        """
        # Unpickling creates the object without calling __init__. Initialise it first so
        # the underlying SessionOptions exists before attributes are assigned.
        self.__init__()  # type: ignore[misc]
        restored: list[tuple[str, Any]] = pickle.loads(state)
        for name, value in restored:
            setattr(self, name, value)
|