mirror of
https://github.com/immich-app/immich.git
synced 2025-01-01 08:31:59 +00:00
chore(ml): improved logging (#3918)
* fixed `minScore` not being set correctly
* apply to init
* don't send `enabled`
* fix eslint warning
* added logger
* added logging
* refinements
* enable access log for info level
* formatting
* merged strings

---------

Co-authored-by: Alex <alex.tran1502@gmail.com>
This commit is contained in:
parent
df26e12db6
commit
54b2779b79
5 changed files with 92 additions and 11 deletions
|
@ -1,7 +1,11 @@
|
||||||
|
import logging
|
||||||
import os
|
import os
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
|
import starlette
|
||||||
from pydantic import BaseSettings
|
from pydantic import BaseSettings
|
||||||
|
from rich.console import Console
|
||||||
|
from rich.logging import RichHandler
|
||||||
|
|
||||||
from .schemas import ModelType
|
from .schemas import ModelType
|
||||||
|
|
||||||
|
@ -23,6 +27,14 @@ class Settings(BaseSettings):
|
||||||
case_sensitive = False
|
case_sensitive = False
|
||||||
|
|
||||||
|
|
||||||
|
class LogSettings(BaseSettings):
|
||||||
|
log_level: str = "info"
|
||||||
|
no_color: bool = False
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
case_sensitive = False
|
||||||
|
|
||||||
|
|
||||||
_clean_name = str.maketrans(":\\/", "___", ".")
|
_clean_name = str.maketrans(":\\/", "___", ".")
|
||||||
|
|
||||||
|
|
||||||
|
@ -30,4 +42,26 @@ def get_cache_dir(model_name: str, model_type: ModelType) -> Path:
|
||||||
return Path(settings.cache_folder) / model_type.value / model_name.translate(_clean_name)
|
return Path(settings.cache_folder) / model_type.value / model_name.translate(_clean_name)
|
||||||
|
|
||||||
|
|
||||||
|
LOG_LEVELS: dict[str, int] = {
|
||||||
|
"critical": logging.ERROR,
|
||||||
|
"error": logging.ERROR,
|
||||||
|
"warning": logging.WARNING,
|
||||||
|
"warn": logging.WARNING,
|
||||||
|
"info": logging.INFO,
|
||||||
|
"log": logging.INFO,
|
||||||
|
"debug": logging.DEBUG,
|
||||||
|
"verbose": logging.DEBUG,
|
||||||
|
}
|
||||||
|
|
||||||
settings = Settings()
|
settings = Settings()
|
||||||
|
log_settings = LogSettings()
|
||||||
|
|
||||||
|
console = Console(color_system="standard", no_color=log_settings.no_color)
|
||||||
|
logging.basicConfig(
|
||||||
|
format="%(message)s",
|
||||||
|
handlers=[
|
||||||
|
RichHandler(show_path=False, omit_repeated_times=False, console=console, tracebacks_suppress=[starlette])
|
||||||
|
],
|
||||||
|
)
|
||||||
|
log = logging.getLogger("uvicorn")
|
||||||
|
log.setLevel(LOG_LEVELS.get(log_settings.log_level.lower(), logging.INFO))
|
||||||
|
|
|
@ -1,4 +1,5 @@
|
||||||
import asyncio
|
import asyncio
|
||||||
|
import logging
|
||||||
import os
|
import os
|
||||||
from concurrent.futures import ThreadPoolExecutor
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
@ -11,7 +12,7 @@ from starlette.formparsers import MultiPartParser
|
||||||
|
|
||||||
from app.models.base import InferenceModel
|
from app.models.base import InferenceModel
|
||||||
|
|
||||||
from .config import settings
|
from .config import log, settings
|
||||||
from .models.cache import ModelCache
|
from .models.cache import ModelCache
|
||||||
from .schemas import (
|
from .schemas import (
|
||||||
MessageResponse,
|
MessageResponse,
|
||||||
|
@ -20,14 +21,20 @@ from .schemas import (
|
||||||
)
|
)
|
||||||
|
|
||||||
MultiPartParser.max_file_size = 2**24 # spools to disk if payload is 16 MiB or larger
|
MultiPartParser.max_file_size = 2**24 # spools to disk if payload is 16 MiB or larger
|
||||||
|
|
||||||
app = FastAPI()
|
app = FastAPI()
|
||||||
|
|
||||||
|
|
||||||
def init_state() -> None:
|
def init_state() -> None:
|
||||||
app.state.model_cache = ModelCache(ttl=settings.model_ttl, revalidate=settings.model_ttl > 0)
|
app.state.model_cache = ModelCache(ttl=settings.model_ttl, revalidate=settings.model_ttl > 0)
|
||||||
|
log.info(
|
||||||
|
(
|
||||||
|
"Created in-memory cache with unloading "
|
||||||
|
f"{f'after {settings.model_ttl}s of inactivity' if settings.model_ttl > 0 else 'disabled'}."
|
||||||
|
)
|
||||||
|
)
|
||||||
# asyncio is a huge bottleneck for performance, so we use a thread pool to run blocking code
|
# asyncio is a huge bottleneck for performance, so we use a thread pool to run blocking code
|
||||||
app.state.thread_pool = ThreadPoolExecutor(settings.request_threads)
|
app.state.thread_pool = ThreadPoolExecutor(settings.request_threads)
|
||||||
|
log.info(f"Initialized request thread pool with {settings.request_threads} threads.")
|
||||||
|
|
||||||
|
|
||||||
@app.on_event("startup")
|
@app.on_event("startup")
|
||||||
|
@ -77,4 +84,6 @@ if __name__ == "__main__":
|
||||||
port=settings.port,
|
port=settings.port,
|
||||||
reload=is_dev,
|
reload=is_dev,
|
||||||
workers=settings.workers,
|
workers=settings.workers,
|
||||||
|
log_config=None,
|
||||||
|
access_log=log.isEnabledFor(logging.INFO),
|
||||||
)
|
)
|
||||||
|
|
|
@ -1,6 +1,5 @@
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import os
|
|
||||||
import pickle
|
import pickle
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
@ -11,7 +10,7 @@ from zipfile import BadZipFile
|
||||||
import onnxruntime as ort
|
import onnxruntime as ort
|
||||||
from onnxruntime.capi.onnxruntime_pybind11_state import InvalidProtobuf # type: ignore
|
from onnxruntime.capi.onnxruntime_pybind11_state import InvalidProtobuf # type: ignore
|
||||||
|
|
||||||
from ..config import get_cache_dir, settings
|
from ..config import get_cache_dir, log, settings
|
||||||
from ..schemas import ModelType
|
from ..schemas import ModelType
|
||||||
|
|
||||||
|
|
||||||
|
@ -37,22 +36,41 @@ class InferenceModel(ABC):
|
||||||
self.provider_options = model_kwargs.pop(
|
self.provider_options = model_kwargs.pop(
|
||||||
"provider_options", [{"arena_extend_strategy": "kSameAsRequested"}] * len(self.providers)
|
"provider_options", [{"arena_extend_strategy": "kSameAsRequested"}] * len(self.providers)
|
||||||
)
|
)
|
||||||
|
log.debug(
|
||||||
|
(
|
||||||
|
f"Setting '{self.model_name}' execution providers to {self.providers}"
|
||||||
|
"in descending order of preference"
|
||||||
|
),
|
||||||
|
)
|
||||||
|
log.debug(f"Setting execution provider options to {self.provider_options}")
|
||||||
self.sess_options = PicklableSessionOptions()
|
self.sess_options = PicklableSessionOptions()
|
||||||
# avoid thread contention between models
|
# avoid thread contention between models
|
||||||
if inter_op_num_threads > 1:
|
if inter_op_num_threads > 1:
|
||||||
self.sess_options.execution_mode = ort.ExecutionMode.ORT_PARALLEL
|
self.sess_options.execution_mode = ort.ExecutionMode.ORT_PARALLEL
|
||||||
|
|
||||||
|
log.debug(f"Setting execution_mode to {self.sess_options.execution_mode.name}")
|
||||||
|
log.debug(f"Setting inter_op_num_threads to {inter_op_num_threads}")
|
||||||
|
log.debug(f"Setting intra_op_num_threads to {intra_op_num_threads}")
|
||||||
self.sess_options.inter_op_num_threads = inter_op_num_threads
|
self.sess_options.inter_op_num_threads = inter_op_num_threads
|
||||||
self.sess_options.intra_op_num_threads = intra_op_num_threads
|
self.sess_options.intra_op_num_threads = intra_op_num_threads
|
||||||
|
|
||||||
try:
|
try:
|
||||||
loader(**model_kwargs)
|
loader(**model_kwargs)
|
||||||
except (OSError, InvalidProtobuf, BadZipFile):
|
except (OSError, InvalidProtobuf, BadZipFile):
|
||||||
|
log.warn(
|
||||||
|
(
|
||||||
|
f"Failed to load {self.model_type.replace('_', ' ')} model '{self.model_name}'."
|
||||||
|
"Clearing cache and retrying."
|
||||||
|
)
|
||||||
|
)
|
||||||
self.clear_cache()
|
self.clear_cache()
|
||||||
loader(**model_kwargs)
|
loader(**model_kwargs)
|
||||||
|
|
||||||
def download(self, **model_kwargs: Any) -> None:
|
def download(self, **model_kwargs: Any) -> None:
|
||||||
if not self.cached:
|
if not self.cached:
|
||||||
print(f"Downloading {self.model_type.value.replace('_', ' ')} model. This may take a while...")
|
log.info(
|
||||||
|
(f"Downloading {self.model_type.replace('_', ' ')} model '{self.model_name}'." "This may take a while.")
|
||||||
|
)
|
||||||
self._download(**model_kwargs)
|
self._download(**model_kwargs)
|
||||||
|
|
||||||
def load(self, **model_kwargs: Any) -> None:
|
def load(self, **model_kwargs: Any) -> None:
|
||||||
|
@ -62,7 +80,7 @@ class InferenceModel(ABC):
|
||||||
|
|
||||||
def predict(self, inputs: Any, **model_kwargs: Any) -> Any:
|
def predict(self, inputs: Any, **model_kwargs: Any) -> Any:
|
||||||
if not self._loaded:
|
if not self._loaded:
|
||||||
print(f"Loading {self.model_type.value.replace('_', ' ')} model...")
|
log.info(f"Loading {self.model_type.replace('_', ' ')} model '{self.model_name}'")
|
||||||
self.load()
|
self.load()
|
||||||
if model_kwargs:
|
if model_kwargs:
|
||||||
self.configure(**model_kwargs)
|
self.configure(**model_kwargs)
|
||||||
|
@ -109,13 +127,23 @@ class InferenceModel(ABC):
|
||||||
|
|
||||||
def clear_cache(self) -> None:
|
def clear_cache(self) -> None:
|
||||||
if not self.cache_dir.exists():
|
if not self.cache_dir.exists():
|
||||||
|
log.warn(
|
||||||
|
f"Attempted to clear cache for model '{self.model_name}' but cache directory does not exist.",
|
||||||
|
)
|
||||||
return
|
return
|
||||||
if not rmtree.avoids_symlink_attacks:
|
if not rmtree.avoids_symlink_attacks:
|
||||||
raise RuntimeError("Attempted to clear cache, but rmtree is not safe on this platform.")
|
raise RuntimeError("Attempted to clear cache, but rmtree is not safe on this platform.")
|
||||||
|
|
||||||
if self.cache_dir.is_dir():
|
if self.cache_dir.is_dir():
|
||||||
|
log.info(f"Cleared cache directory for model '{self.model_name}'.")
|
||||||
rmtree(self.cache_dir)
|
rmtree(self.cache_dir)
|
||||||
else:
|
else:
|
||||||
|
log.warn(
|
||||||
|
(
|
||||||
|
f"Encountered file instead of directory at cache path "
|
||||||
|
f"for '{self.model_name}'. Removing file and replacing with a directory."
|
||||||
|
),
|
||||||
|
)
|
||||||
self.cache_dir.unlink()
|
self.cache_dir.unlink()
|
||||||
self.cache_dir.mkdir(parents=True, exist_ok=True)
|
self.cache_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
|
|
@ -12,6 +12,7 @@ from clip_server.model.tokenization import Tokenizer
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from torchvision.transforms import CenterCrop, Compose, Normalize, Resize, ToTensor
|
from torchvision.transforms import CenterCrop, Compose, Normalize, Resize, ToTensor
|
||||||
|
|
||||||
|
from ..config import log
|
||||||
from ..schemas import ModelType
|
from ..schemas import ModelType
|
||||||
from .base import InferenceModel
|
from .base import InferenceModel
|
||||||
|
|
||||||
|
@ -105,9 +106,11 @@ class CLIPEncoder(InferenceModel):
|
||||||
if model_name in _MODELS:
|
if model_name in _MODELS:
|
||||||
return model_name
|
return model_name
|
||||||
elif model_name in _ST_TO_JINA_MODEL_NAME:
|
elif model_name in _ST_TO_JINA_MODEL_NAME:
|
||||||
print(
|
log.warn(
|
||||||
(f"Warning: Sentence-Transformer model names such as '{model_name}' are no longer supported."),
|
(
|
||||||
(f"Using '{_ST_TO_JINA_MODEL_NAME[model_name]}' instead as it is the best match for '{model_name}'."),
|
f"Sentence-Transformer models like '{model_name}' are not supported."
|
||||||
|
f"Using '{_ST_TO_JINA_MODEL_NAME[model_name]}' instead as it is the best match for '{model_name}'."
|
||||||
|
),
|
||||||
)
|
)
|
||||||
return _ST_TO_JINA_MODEL_NAME[model_name]
|
return _ST_TO_JINA_MODEL_NAME[model_name]
|
||||||
else:
|
else:
|
||||||
|
|
|
@ -8,6 +8,7 @@ from optimum.pipelines import pipeline
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from transformers import AutoImageProcessor
|
from transformers import AutoImageProcessor
|
||||||
|
|
||||||
|
from ..config import log
|
||||||
from ..schemas import ModelType
|
from ..schemas import ModelType
|
||||||
from .base import InferenceModel
|
from .base import InferenceModel
|
||||||
|
|
||||||
|
@ -35,19 +36,25 @@ class ImageClassifier(InferenceModel):
|
||||||
)
|
)
|
||||||
|
|
||||||
def _load(self, **model_kwargs: Any) -> None:
|
def _load(self, **model_kwargs: Any) -> None:
|
||||||
processor = AutoImageProcessor.from_pretrained(self.cache_dir)
|
processor = AutoImageProcessor.from_pretrained(self.cache_dir, cache_dir=self.cache_dir)
|
||||||
|
model_path = self.cache_dir / "model.onnx"
|
||||||
model_kwargs |= {
|
model_kwargs |= {
|
||||||
"cache_dir": self.cache_dir,
|
"cache_dir": self.cache_dir,
|
||||||
"provider": self.providers[0],
|
"provider": self.providers[0],
|
||||||
"provider_options": self.provider_options[0],
|
"provider_options": self.provider_options[0],
|
||||||
"session_options": self.sess_options,
|
"session_options": self.sess_options,
|
||||||
}
|
}
|
||||||
model_path = self.cache_dir / "model.onnx"
|
|
||||||
|
|
||||||
if model_path.exists():
|
if model_path.exists():
|
||||||
model = ORTModelForImageClassification.from_pretrained(self.cache_dir, **model_kwargs)
|
model = ORTModelForImageClassification.from_pretrained(self.cache_dir, **model_kwargs)
|
||||||
self.model = pipeline(self.model_type.value, model, feature_extractor=processor)
|
self.model = pipeline(self.model_type.value, model, feature_extractor=processor)
|
||||||
else:
|
else:
|
||||||
|
log.info(
|
||||||
|
(
|
||||||
|
f"ONNX model not found in cache directory for '{self.model_name}'."
|
||||||
|
"Exporting optimized model for future use."
|
||||||
|
),
|
||||||
|
)
|
||||||
self.sess_options.optimized_model_filepath = model_path.as_posix()
|
self.sess_options.optimized_model_filepath = model_path.as_posix()
|
||||||
self.model = pipeline(
|
self.model = pipeline(
|
||||||
self.model_type.value,
|
self.model_type.value,
|
||||||
|
|
Loading…
Reference in a new issue