1
0
Fork 0
mirror of https://github.com/immich-app/immich.git synced 2024-12-29 15:11:58 +00:00
immich/machine-learning/app/config.py
Mert e7397f35c9
chore(ml): update pydantic (#13230)
* update pydantic

* fix typing

* remove unused import

* remove unused schema
2024-10-13 18:00:21 -04:00

125 lines
3.5 KiB
Python

import concurrent.futures
import logging
import os
import sys
from pathlib import Path
from socket import socket
from gunicorn.arbiter import Arbiter
from pydantic import BaseModel
from pydantic_settings import BaseSettings, SettingsConfigDict
from rich.console import Console
from rich.logging import RichHandler
from uvicorn import Server
from uvicorn.workers import UvicornWorker
# Optional model names to preload, one slot per model kind.
# Both slots default to None (nothing configured for that kind).
# NOTE: documented with comments rather than a docstring so the generated
# pydantic schema (which picks up class docstrings) is unchanged.
class PreloadModelData(BaseModel):
    clip: str | None = None
    facial_recognition: str | None = None
# Runtime configuration of the machine-learning service.
#
# Values are loaded by pydantic-settings from environment variables that carry
# the "MACHINE_LEARNING_" prefix (matched case-insensitively); "__" separates
# the levels of nested values such as `preload`.
# NOTE: documented with comments rather than a docstring so the generated
# pydantic schema (which picks up class docstrings) is unchanged.
class Settings(BaseSettings):
    model_config = SettingsConfigDict(
        case_sensitive=False,
        env_prefix="MACHINE_LEARNING_",
        env_nested_delimiter="__",
        # Several fields below start with "model_"; the protected namespace is
        # moved to "settings_" — presumably to avoid pydantic's warning about
        # its default "model_" namespace.
        protected_namespaces=("settings_",),
    )

    # Storage and model lifetime.
    cache_folder: Path = Path("/cache")
    model_ttl: int = 300
    model_ttl_poll_s: int = 10

    # Server binding and worker count.
    host: str = "0.0.0.0"
    port: int = 3003
    workers: int = 1
    test_full: bool = False

    # Threading: request_threads falls back to 4 when os.cpu_count() yields
    # nothing usable (None or 0).
    request_threads: int = os.cpu_count() or 4
    model_inter_op_threads: int = 0
    model_intra_op_threads: int = 0

    # ANN-related switches.
    ann: bool = True
    ann_fp16_turbo: bool = False
    ann_tuning_level: int = 2

    # Optional preload configuration; None when nothing is configured.
    preload: PreloadModelData | None = None

    @property
    def device_id(self) -> str:
        # Looked up in the process environment on every access (not a parsed
        # settings field); "0" is returned when the variable is unset.
        return os.getenv("MACHINE_LEARNING_DEVICE_ID", "0")
# Logging options, loaded by pydantic-settings from case-insensitive
# environment variables (no env prefix is configured for this class).
# NOTE: documented with comments rather than a docstring so the generated
# pydantic schema (which picks up class docstrings) is unchanged.
class LogSettings(BaseSettings):
    model_config = SettingsConfigDict(case_sensitive=False)

    # Level name; resolved against LOG_LEVELS by the module-level code.
    immich_log_level: str = "info"
    # Passed to the rich Console as `no_color` by CustomRichHandler.
    no_color: bool = False
# Translation table used by clean_name(): ":", "\" and "/" become "_",
# while "." is removed entirely.
_clean_name = str.maketrans({":": "_", "\\": "_", "/": "_", ".": None})


def clean_name(model_name: str) -> str:
    """Return the last "/"-separated segment of *model_name*, sanitized.

    Everything up to and including the final "/" is dropped; in what remains,
    ":" and "\\" are replaced by "_" and "." is deleted.
    """
    _, _, final_segment = model_name.rpartition("/")
    return final_segment.translate(_clean_name)
# Accepted log-level names (lower-case) and the stdlib logging level each one
# selects. Several names are aliases for the same level; note that "critical"
# maps to logging.ERROR, not logging.CRITICAL.
LOG_LEVELS: dict[str, int] = dict(
    critical=logging.ERROR,
    error=logging.ERROR,
    warning=logging.WARNING,
    warn=logging.WARNING,
    info=logging.INFO,
    log=logging.INFO,
    debug=logging.DEBUG,
    verbose=logging.DEBUG,
)
# Both settings objects are instantiated once, at import time.
settings = Settings()
log_settings = LogSettings()
# The configured level name is lower-cased before lookup; a name that is not a
# key of LOG_LEVELS falls back to logging.INFO.
LOG_LEVEL = LOG_LEVELS.get(log_settings.immich_log_level.lower(), logging.INFO)
class CustomRichHandler(RichHandler):
    """RichHandler preconfigured for this service.

    Uses a "standard" color-system console (honoring `log_settings.no_color`),
    enables rich tracebacks, and hides frames coming from the web-framework
    packages listed in `self.excluded`.
    """

    def __init__(self) -> None:
        framework_modules = ["uvicorn", "starlette", "fastapi"]
        # Stored on the instance before the base initializer runs; emit()
        # reads it for every record that carries exception info.
        self.excluded = framework_modules
        # Local variables are only rendered in tracebacks at DEBUG level.
        show_locals = LOG_LEVEL == logging.DEBUG
        rich_console = Console(color_system="standard", no_color=log_settings.no_color)
        super().__init__(
            console=rich_console,
            show_path=False,
            omit_repeated_times=False,
            rich_tracebacks=True,
            tracebacks_suppress=[*framework_modules, concurrent.futures],
            tracebacks_show_locals=show_locals,
        )

    # hack to exclude certain modules from rich tracebacks
    def emit(self, record: logging.LogRecord) -> None:
        """Tag frames from excluded packages, then hand the record to RichHandler.

        Walks the traceback attached to *record* (if any) and sets the
        `_rich_traceback_omit` local in every frame whose source filename
        contains one of the names in `self.excluded`.
        """
        exc_info = record.exc_info
        frame_link = exc_info[2] if exc_info is not None else None
        while frame_link is not None:
            source_file = frame_link.tb_frame.f_code.co_filename
            if any(name in source_file for name in self.excluded):
                frame_link.tb_frame.f_locals["_rich_traceback_omit"] = True
            frame_link = frame_link.tb_next

        return super().emit(record)
# Module-level logger named "ml.log"; its threshold is set to LOG_LEVEL.
log = logging.getLogger("ml.log")
log.setLevel(LOG_LEVEL)
# patches this issue https://github.com/encode/uvicorn/discussions/1803
class CustomUvicornServer(Server):
    """uvicorn Server that closes the given sockets itself before shutting down."""

    async def shutdown(self, sockets: list[socket] | None = None) -> None:
        """Close every socket in *sockets*, then run the base shutdown.

        The base implementation is deliberately invoked without the sockets,
        since they have already been closed here.
        """
        if sockets:
            for listener in sockets:
                listener.close()

        await super().shutdown()
class CustomUvicornWorker(UvicornWorker):
    """UvicornWorker variant that serves through CustomUvicornServer."""

    async def _serve(self) -> None:
        """Serve the worker's app; exit with the boot-error code if startup failed."""
        self.config.app = self.wsgi
        uvicorn_server = CustomUvicornServer(config=self.config)
        self._install_sigquit_handler()
        await uvicorn_server.serve(sockets=self.sockets)
        if uvicorn_server.started:
            return
        # The server returned without ever starting: report a boot failure to
        # the gunicorn arbiter through the process exit code.
        sys.exit(Arbiter.WORKER_BOOT_ERROR)