1
0
Fork 0
mirror of https://github.com/immich-app/immich.git synced 2025-01-17 01:06:46 +01:00

chore(ml): improved logging (#3918)

* fixed `minScore` not being set correctly

* apply to init

* don't send `enabled`

* fix eslint warning

* added logger

* added logging

* refinements

* enable access log for info level

* formatting

* merged strings

---------

Co-authored-by: Alex <alex.tran1502@gmail.com>
This commit is contained in:
Mert 2023-08-30 04:22:01 -04:00 committed by GitHub
parent df26e12db6
commit 54b2779b79
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 92 additions and 11 deletions

View file

@ -1,7 +1,11 @@
import logging
import os import os
from pathlib import Path from pathlib import Path
import starlette
from pydantic import BaseSettings from pydantic import BaseSettings
from rich.console import Console
from rich.logging import RichHandler
from .schemas import ModelType from .schemas import ModelType
@ -23,6 +27,14 @@ class Settings(BaseSettings):
case_sensitive = False case_sensitive = False
class LogSettings(BaseSettings):
log_level: str = "info"
no_color: bool = False
class Config:
case_sensitive = False
_clean_name = str.maketrans(":\\/", "___", ".") _clean_name = str.maketrans(":\\/", "___", ".")
@ -30,4 +42,26 @@ def get_cache_dir(model_name: str, model_type: ModelType) -> Path:
return Path(settings.cache_folder) / model_type.value / model_name.translate(_clean_name) return Path(settings.cache_folder) / model_type.value / model_name.translate(_clean_name)
LOG_LEVELS: dict[str, int] = {
"critical": logging.ERROR,
"error": logging.ERROR,
"warning": logging.WARNING,
"warn": logging.WARNING,
"info": logging.INFO,
"log": logging.INFO,
"debug": logging.DEBUG,
"verbose": logging.DEBUG,
}
settings = Settings() settings = Settings()
log_settings = LogSettings()
console = Console(color_system="standard", no_color=log_settings.no_color)
logging.basicConfig(
format="%(message)s",
handlers=[
RichHandler(show_path=False, omit_repeated_times=False, console=console, tracebacks_suppress=[starlette])
],
)
log = logging.getLogger("uvicorn")
log.setLevel(LOG_LEVELS.get(log_settings.log_level.lower(), logging.INFO))

View file

@ -1,4 +1,5 @@
import asyncio import asyncio
import logging
import os import os
from concurrent.futures import ThreadPoolExecutor from concurrent.futures import ThreadPoolExecutor
from typing import Any from typing import Any
@ -11,7 +12,7 @@ from starlette.formparsers import MultiPartParser
from app.models.base import InferenceModel from app.models.base import InferenceModel
from .config import settings from .config import log, settings
from .models.cache import ModelCache from .models.cache import ModelCache
from .schemas import ( from .schemas import (
MessageResponse, MessageResponse,
@ -20,14 +21,20 @@ from .schemas import (
) )
MultiPartParser.max_file_size = 2**24 # spools to disk if payload is 16 MiB or larger MultiPartParser.max_file_size = 2**24 # spools to disk if payload is 16 MiB or larger
app = FastAPI() app = FastAPI()
def init_state() -> None: def init_state() -> None:
app.state.model_cache = ModelCache(ttl=settings.model_ttl, revalidate=settings.model_ttl > 0) app.state.model_cache = ModelCache(ttl=settings.model_ttl, revalidate=settings.model_ttl > 0)
log.info(
(
"Created in-memory cache with unloading "
f"{f'after {settings.model_ttl}s of inactivity' if settings.model_ttl > 0 else 'disabled'}."
)
)
# asyncio is a huge bottleneck for performance, so we use a thread pool to run blocking code # asyncio is a huge bottleneck for performance, so we use a thread pool to run blocking code
app.state.thread_pool = ThreadPoolExecutor(settings.request_threads) app.state.thread_pool = ThreadPoolExecutor(settings.request_threads)
log.info(f"Initialized request thread pool with {settings.request_threads} threads.")
@app.on_event("startup") @app.on_event("startup")
@ -77,4 +84,6 @@ if __name__ == "__main__":
port=settings.port, port=settings.port,
reload=is_dev, reload=is_dev,
workers=settings.workers, workers=settings.workers,
log_config=None,
access_log=log.isEnabledFor(logging.INFO),
) )

View file

@ -1,6 +1,5 @@
from __future__ import annotations from __future__ import annotations
import os
import pickle import pickle
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from pathlib import Path from pathlib import Path
@ -11,7 +10,7 @@ from zipfile import BadZipFile
import onnxruntime as ort import onnxruntime as ort
from onnxruntime.capi.onnxruntime_pybind11_state import InvalidProtobuf # type: ignore from onnxruntime.capi.onnxruntime_pybind11_state import InvalidProtobuf # type: ignore
from ..config import get_cache_dir, settings from ..config import get_cache_dir, log, settings
from ..schemas import ModelType from ..schemas import ModelType
@ -37,22 +36,41 @@ class InferenceModel(ABC):
self.provider_options = model_kwargs.pop( self.provider_options = model_kwargs.pop(
"provider_options", [{"arena_extend_strategy": "kSameAsRequested"}] * len(self.providers) "provider_options", [{"arena_extend_strategy": "kSameAsRequested"}] * len(self.providers)
) )
log.debug(
(
f"Setting '{self.model_name}' execution providers to {self.providers}"
"in descending order of preference"
),
)
log.debug(f"Setting execution provider options to {self.provider_options}")
self.sess_options = PicklableSessionOptions() self.sess_options = PicklableSessionOptions()
# avoid thread contention between models # avoid thread contention between models
if inter_op_num_threads > 1: if inter_op_num_threads > 1:
self.sess_options.execution_mode = ort.ExecutionMode.ORT_PARALLEL self.sess_options.execution_mode = ort.ExecutionMode.ORT_PARALLEL
log.debug(f"Setting execution_mode to {self.sess_options.execution_mode.name}")
log.debug(f"Setting inter_op_num_threads to {inter_op_num_threads}")
log.debug(f"Setting intra_op_num_threads to {intra_op_num_threads}")
self.sess_options.inter_op_num_threads = inter_op_num_threads self.sess_options.inter_op_num_threads = inter_op_num_threads
self.sess_options.intra_op_num_threads = intra_op_num_threads self.sess_options.intra_op_num_threads = intra_op_num_threads
try: try:
loader(**model_kwargs) loader(**model_kwargs)
except (OSError, InvalidProtobuf, BadZipFile): except (OSError, InvalidProtobuf, BadZipFile):
log.warn(
(
f"Failed to load {self.model_type.replace('_', ' ')} model '{self.model_name}'."
"Clearing cache and retrying."
)
)
self.clear_cache() self.clear_cache()
loader(**model_kwargs) loader(**model_kwargs)
def download(self, **model_kwargs: Any) -> None: def download(self, **model_kwargs: Any) -> None:
if not self.cached: if not self.cached:
print(f"Downloading {self.model_type.value.replace('_', ' ')} model. This may take a while...") log.info(
(f"Downloading {self.model_type.replace('_', ' ')} model '{self.model_name}'." "This may take a while.")
)
self._download(**model_kwargs) self._download(**model_kwargs)
def load(self, **model_kwargs: Any) -> None: def load(self, **model_kwargs: Any) -> None:
@ -62,7 +80,7 @@ class InferenceModel(ABC):
def predict(self, inputs: Any, **model_kwargs: Any) -> Any: def predict(self, inputs: Any, **model_kwargs: Any) -> Any:
if not self._loaded: if not self._loaded:
print(f"Loading {self.model_type.value.replace('_', ' ')} model...") log.info(f"Loading {self.model_type.replace('_', ' ')} model '{self.model_name}'")
self.load() self.load()
if model_kwargs: if model_kwargs:
self.configure(**model_kwargs) self.configure(**model_kwargs)
@ -109,13 +127,23 @@ class InferenceModel(ABC):
def clear_cache(self) -> None: def clear_cache(self) -> None:
if not self.cache_dir.exists(): if not self.cache_dir.exists():
log.warn(
f"Attempted to clear cache for model '{self.model_name}' but cache directory does not exist.",
)
return return
if not rmtree.avoids_symlink_attacks: if not rmtree.avoids_symlink_attacks:
raise RuntimeError("Attempted to clear cache, but rmtree is not safe on this platform.") raise RuntimeError("Attempted to clear cache, but rmtree is not safe on this platform.")
if self.cache_dir.is_dir(): if self.cache_dir.is_dir():
log.info(f"Cleared cache directory for model '{self.model_name}'.")
rmtree(self.cache_dir) rmtree(self.cache_dir)
else: else:
log.warn(
(
f"Encountered file instead of directory at cache path "
f"for '{self.model_name}'. Removing file and replacing with a directory."
),
)
self.cache_dir.unlink() self.cache_dir.unlink()
self.cache_dir.mkdir(parents=True, exist_ok=True) self.cache_dir.mkdir(parents=True, exist_ok=True)

View file

@ -12,6 +12,7 @@ from clip_server.model.tokenization import Tokenizer
from PIL import Image from PIL import Image
from torchvision.transforms import CenterCrop, Compose, Normalize, Resize, ToTensor from torchvision.transforms import CenterCrop, Compose, Normalize, Resize, ToTensor
from ..config import log
from ..schemas import ModelType from ..schemas import ModelType
from .base import InferenceModel from .base import InferenceModel
@ -105,9 +106,11 @@ class CLIPEncoder(InferenceModel):
if model_name in _MODELS: if model_name in _MODELS:
return model_name return model_name
elif model_name in _ST_TO_JINA_MODEL_NAME: elif model_name in _ST_TO_JINA_MODEL_NAME:
print( log.warn(
(f"Warning: Sentence-Transformer model names such as '{model_name}' are no longer supported."), (
(f"Using '{_ST_TO_JINA_MODEL_NAME[model_name]}' instead as it is the best match for '{model_name}'."), f"Sentence-Transformer models like '{model_name}' are not supported."
f"Using '{_ST_TO_JINA_MODEL_NAME[model_name]}' instead as it is the best match for '{model_name}'."
),
) )
return _ST_TO_JINA_MODEL_NAME[model_name] return _ST_TO_JINA_MODEL_NAME[model_name]
else: else:

View file

@ -8,6 +8,7 @@ from optimum.pipelines import pipeline
from PIL import Image from PIL import Image
from transformers import AutoImageProcessor from transformers import AutoImageProcessor
from ..config import log
from ..schemas import ModelType from ..schemas import ModelType
from .base import InferenceModel from .base import InferenceModel
@ -35,19 +36,25 @@ class ImageClassifier(InferenceModel):
) )
def _load(self, **model_kwargs: Any) -> None: def _load(self, **model_kwargs: Any) -> None:
processor = AutoImageProcessor.from_pretrained(self.cache_dir) processor = AutoImageProcessor.from_pretrained(self.cache_dir, cache_dir=self.cache_dir)
model_path = self.cache_dir / "model.onnx"
model_kwargs |= { model_kwargs |= {
"cache_dir": self.cache_dir, "cache_dir": self.cache_dir,
"provider": self.providers[0], "provider": self.providers[0],
"provider_options": self.provider_options[0], "provider_options": self.provider_options[0],
"session_options": self.sess_options, "session_options": self.sess_options,
} }
model_path = self.cache_dir / "model.onnx"
if model_path.exists(): if model_path.exists():
model = ORTModelForImageClassification.from_pretrained(self.cache_dir, **model_kwargs) model = ORTModelForImageClassification.from_pretrained(self.cache_dir, **model_kwargs)
self.model = pipeline(self.model_type.value, model, feature_extractor=processor) self.model = pipeline(self.model_type.value, model, feature_extractor=processor)
else: else:
log.info(
(
f"ONNX model not found in cache directory for '{self.model_name}'."
"Exporting optimized model for future use."
),
)
self.sess_options.optimized_model_filepath = model_path.as_posix() self.sess_options.optimized_model_filepath = model_path.as_posix()
self.model = pipeline( self.model = pipeline(
self.model_type.value, self.model_type.value,