mirror of
https://github.com/immich-app/immich.git
synced 2025-01-16 16:56:46 +01:00
chore(ml): improved logging (#3918)
* fixed `minScore` not being set correctly * apply to init * don't send `enabled` * fix eslint warning * added logger * added logging * refinements * enable access log for info level * formatting * merged strings --------- Co-authored-by: Alex <alex.tran1502@gmail.com>
This commit is contained in:
parent
df26e12db6
commit
54b2779b79
5 changed files with 92 additions and 11 deletions
|
@ -1,7 +1,11 @@
|
|||
import logging
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
import starlette
|
||||
from pydantic import BaseSettings
|
||||
from rich.console import Console
|
||||
from rich.logging import RichHandler
|
||||
|
||||
from .schemas import ModelType
|
||||
|
||||
|
@ -23,6 +27,14 @@ class Settings(BaseSettings):
|
|||
case_sensitive = False
|
||||
|
||||
|
||||
class LogSettings(BaseSettings):
|
||||
log_level: str = "info"
|
||||
no_color: bool = False
|
||||
|
||||
class Config:
|
||||
case_sensitive = False
|
||||
|
||||
|
||||
_clean_name = str.maketrans(":\\/", "___", ".")
|
||||
|
||||
|
||||
|
@ -30,4 +42,26 @@ def get_cache_dir(model_name: str, model_type: ModelType) -> Path:
|
|||
return Path(settings.cache_folder) / model_type.value / model_name.translate(_clean_name)
|
||||
|
||||
|
||||
LOG_LEVELS: dict[str, int] = {
|
||||
"critical": logging.ERROR,
|
||||
"error": logging.ERROR,
|
||||
"warning": logging.WARNING,
|
||||
"warn": logging.WARNING,
|
||||
"info": logging.INFO,
|
||||
"log": logging.INFO,
|
||||
"debug": logging.DEBUG,
|
||||
"verbose": logging.DEBUG,
|
||||
}
|
||||
|
||||
settings = Settings()
|
||||
log_settings = LogSettings()
|
||||
|
||||
console = Console(color_system="standard", no_color=log_settings.no_color)
|
||||
logging.basicConfig(
|
||||
format="%(message)s",
|
||||
handlers=[
|
||||
RichHandler(show_path=False, omit_repeated_times=False, console=console, tracebacks_suppress=[starlette])
|
||||
],
|
||||
)
|
||||
log = logging.getLogger("uvicorn")
|
||||
log.setLevel(LOG_LEVELS.get(log_settings.log_level.lower(), logging.INFO))
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from typing import Any
|
||||
|
@ -11,7 +12,7 @@ from starlette.formparsers import MultiPartParser
|
|||
|
||||
from app.models.base import InferenceModel
|
||||
|
||||
from .config import settings
|
||||
from .config import log, settings
|
||||
from .models.cache import ModelCache
|
||||
from .schemas import (
|
||||
MessageResponse,
|
||||
|
@ -20,14 +21,20 @@ from .schemas import (
|
|||
)
|
||||
|
||||
MultiPartParser.max_file_size = 2**24 # spools to disk if payload is 16 MiB or larger
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
|
||||
def init_state() -> None:
|
||||
app.state.model_cache = ModelCache(ttl=settings.model_ttl, revalidate=settings.model_ttl > 0)
|
||||
log.info(
|
||||
(
|
||||
"Created in-memory cache with unloading "
|
||||
f"{f'after {settings.model_ttl}s of inactivity' if settings.model_ttl > 0 else 'disabled'}."
|
||||
)
|
||||
)
|
||||
# asyncio is a huge bottleneck for performance, so we use a thread pool to run blocking code
|
||||
app.state.thread_pool = ThreadPoolExecutor(settings.request_threads)
|
||||
log.info(f"Initialized request thread pool with {settings.request_threads} threads.")
|
||||
|
||||
|
||||
@app.on_event("startup")
|
||||
|
@ -77,4 +84,6 @@ if __name__ == "__main__":
|
|||
port=settings.port,
|
||||
reload=is_dev,
|
||||
workers=settings.workers,
|
||||
log_config=None,
|
||||
access_log=log.isEnabledFor(logging.INFO),
|
||||
)
|
||||
|
|
|
@ -1,6 +1,5 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import pickle
|
||||
from abc import ABC, abstractmethod
|
||||
from pathlib import Path
|
||||
|
@ -11,7 +10,7 @@ from zipfile import BadZipFile
|
|||
import onnxruntime as ort
|
||||
from onnxruntime.capi.onnxruntime_pybind11_state import InvalidProtobuf # type: ignore
|
||||
|
||||
from ..config import get_cache_dir, settings
|
||||
from ..config import get_cache_dir, log, settings
|
||||
from ..schemas import ModelType
|
||||
|
||||
|
||||
|
@ -37,22 +36,41 @@ class InferenceModel(ABC):
|
|||
self.provider_options = model_kwargs.pop(
|
||||
"provider_options", [{"arena_extend_strategy": "kSameAsRequested"}] * len(self.providers)
|
||||
)
|
||||
log.debug(
|
||||
(
|
||||
f"Setting '{self.model_name}' execution providers to {self.providers}"
|
||||
"in descending order of preference"
|
||||
),
|
||||
)
|
||||
log.debug(f"Setting execution provider options to {self.provider_options}")
|
||||
self.sess_options = PicklableSessionOptions()
|
||||
# avoid thread contention between models
|
||||
if inter_op_num_threads > 1:
|
||||
self.sess_options.execution_mode = ort.ExecutionMode.ORT_PARALLEL
|
||||
|
||||
log.debug(f"Setting execution_mode to {self.sess_options.execution_mode.name}")
|
||||
log.debug(f"Setting inter_op_num_threads to {inter_op_num_threads}")
|
||||
log.debug(f"Setting intra_op_num_threads to {intra_op_num_threads}")
|
||||
self.sess_options.inter_op_num_threads = inter_op_num_threads
|
||||
self.sess_options.intra_op_num_threads = intra_op_num_threads
|
||||
|
||||
try:
|
||||
loader(**model_kwargs)
|
||||
except (OSError, InvalidProtobuf, BadZipFile):
|
||||
log.warn(
|
||||
(
|
||||
f"Failed to load {self.model_type.replace('_', ' ')} model '{self.model_name}'."
|
||||
"Clearing cache and retrying."
|
||||
)
|
||||
)
|
||||
self.clear_cache()
|
||||
loader(**model_kwargs)
|
||||
|
||||
def download(self, **model_kwargs: Any) -> None:
|
||||
if not self.cached:
|
||||
print(f"Downloading {self.model_type.value.replace('_', ' ')} model. This may take a while...")
|
||||
log.info(
|
||||
(f"Downloading {self.model_type.replace('_', ' ')} model '{self.model_name}'." "This may take a while.")
|
||||
)
|
||||
self._download(**model_kwargs)
|
||||
|
||||
def load(self, **model_kwargs: Any) -> None:
|
||||
|
@ -62,7 +80,7 @@ class InferenceModel(ABC):
|
|||
|
||||
def predict(self, inputs: Any, **model_kwargs: Any) -> Any:
|
||||
if not self._loaded:
|
||||
print(f"Loading {self.model_type.value.replace('_', ' ')} model...")
|
||||
log.info(f"Loading {self.model_type.replace('_', ' ')} model '{self.model_name}'")
|
||||
self.load()
|
||||
if model_kwargs:
|
||||
self.configure(**model_kwargs)
|
||||
|
@ -109,13 +127,23 @@ class InferenceModel(ABC):
|
|||
|
||||
def clear_cache(self) -> None:
|
||||
if not self.cache_dir.exists():
|
||||
log.warn(
|
||||
f"Attempted to clear cache for model '{self.model_name}' but cache directory does not exist.",
|
||||
)
|
||||
return
|
||||
if not rmtree.avoids_symlink_attacks:
|
||||
raise RuntimeError("Attempted to clear cache, but rmtree is not safe on this platform.")
|
||||
|
||||
if self.cache_dir.is_dir():
|
||||
log.info(f"Cleared cache directory for model '{self.model_name}'.")
|
||||
rmtree(self.cache_dir)
|
||||
else:
|
||||
log.warn(
|
||||
(
|
||||
f"Encountered file instead of directory at cache path "
|
||||
f"for '{self.model_name}'. Removing file and replacing with a directory."
|
||||
),
|
||||
)
|
||||
self.cache_dir.unlink()
|
||||
self.cache_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
|
|
|
@ -12,6 +12,7 @@ from clip_server.model.tokenization import Tokenizer
|
|||
from PIL import Image
|
||||
from torchvision.transforms import CenterCrop, Compose, Normalize, Resize, ToTensor
|
||||
|
||||
from ..config import log
|
||||
from ..schemas import ModelType
|
||||
from .base import InferenceModel
|
||||
|
||||
|
@ -105,9 +106,11 @@ class CLIPEncoder(InferenceModel):
|
|||
if model_name in _MODELS:
|
||||
return model_name
|
||||
elif model_name in _ST_TO_JINA_MODEL_NAME:
|
||||
print(
|
||||
(f"Warning: Sentence-Transformer model names such as '{model_name}' are no longer supported."),
|
||||
(f"Using '{_ST_TO_JINA_MODEL_NAME[model_name]}' instead as it is the best match for '{model_name}'."),
|
||||
log.warn(
|
||||
(
|
||||
f"Sentence-Transformer models like '{model_name}' are not supported."
|
||||
f"Using '{_ST_TO_JINA_MODEL_NAME[model_name]}' instead as it is the best match for '{model_name}'."
|
||||
),
|
||||
)
|
||||
return _ST_TO_JINA_MODEL_NAME[model_name]
|
||||
else:
|
||||
|
|
|
@ -8,6 +8,7 @@ from optimum.pipelines import pipeline
|
|||
from PIL import Image
|
||||
from transformers import AutoImageProcessor
|
||||
|
||||
from ..config import log
|
||||
from ..schemas import ModelType
|
||||
from .base import InferenceModel
|
||||
|
||||
|
@ -35,19 +36,25 @@ class ImageClassifier(InferenceModel):
|
|||
)
|
||||
|
||||
def _load(self, **model_kwargs: Any) -> None:
|
||||
processor = AutoImageProcessor.from_pretrained(self.cache_dir)
|
||||
processor = AutoImageProcessor.from_pretrained(self.cache_dir, cache_dir=self.cache_dir)
|
||||
model_path = self.cache_dir / "model.onnx"
|
||||
model_kwargs |= {
|
||||
"cache_dir": self.cache_dir,
|
||||
"provider": self.providers[0],
|
||||
"provider_options": self.provider_options[0],
|
||||
"session_options": self.sess_options,
|
||||
}
|
||||
model_path = self.cache_dir / "model.onnx"
|
||||
|
||||
if model_path.exists():
|
||||
model = ORTModelForImageClassification.from_pretrained(self.cache_dir, **model_kwargs)
|
||||
self.model = pipeline(self.model_type.value, model, feature_extractor=processor)
|
||||
else:
|
||||
log.info(
|
||||
(
|
||||
f"ONNX model not found in cache directory for '{self.model_name}'."
|
||||
"Exporting optimized model for future use."
|
||||
),
|
||||
)
|
||||
self.sess_options.optimized_model_filepath = model_path.as_posix()
|
||||
self.model = pipeline(
|
||||
self.model_type.value,
|
||||
|
|
Loading…
Reference in a new issue