1
0
Fork 0
mirror of https://github.com/immich-app/immich.git synced 2025-01-01 08:31:59 +00:00

chore(ml): improve shutdown (#5689)

This commit is contained in:
Mert 2023-12-14 14:51:24 -05:00 committed by GitHub
parent 9768931275
commit d729c863c8
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 80 additions and 48 deletions

View file

@ -1,12 +1,16 @@
import logging import logging
import os import os
import sys
from pathlib import Path from pathlib import Path
from socket import socket
import gunicorn
import starlette import starlette
from gunicorn.arbiter import Arbiter
from pydantic import BaseSettings from pydantic import BaseSettings
from rich.console import Console from rich.console import Console
from rich.logging import RichHandler from rich.logging import RichHandler
from uvicorn import Server
from uvicorn.workers import UvicornWorker
from .schemas import ModelType from .schemas import ModelType
@ -69,10 +73,26 @@ log_settings = LogSettings()
class CustomRichHandler(RichHandler): class CustomRichHandler(RichHandler):
def __init__(self) -> None: def __init__(self) -> None:
console = Console(color_system="standard", no_color=log_settings.no_color) console = Console(color_system="standard", no_color=log_settings.no_color)
super().__init__( super().__init__(show_path=False, omit_repeated_times=False, console=console, tracebacks_suppress=[starlette])
show_path=False, omit_repeated_times=False, console=console, tracebacks_suppress=[gunicorn, starlette]
)
log = logging.getLogger("gunicorn.access") log = logging.getLogger("gunicorn.access")
log.setLevel(LOG_LEVELS.get(log_settings.log_level.lower(), logging.INFO)) log.setLevel(LOG_LEVELS.get(log_settings.log_level.lower(), logging.INFO))
# patches this issue https://github.com/encode/uvicorn/discussions/1803
class CustomUvicornServer(Server):
async def shutdown(self, sockets: list[socket] | None = None) -> None:
for sock in sockets or []:
sock.close()
await super().shutdown()
class CustomUvicornWorker(UvicornWorker):
async def _serve(self) -> None:
self.config.app = self.wsgi
server = CustomUvicornServer(config=self.config)
self._install_sigquit_handler()
await server.serve(sockets=self.sockets)
if not server.started:
sys.exit(Arbiter.WORKER_BOOT_ERROR)

View file

@ -1,5 +1,4 @@
import json import json
from pathlib import Path
from typing import Any, Iterator from typing import Any, Iterator
from unittest import mock from unittest import mock
@ -8,7 +7,7 @@ import pytest
from fastapi.testclient import TestClient from fastapi.testclient import TestClient
from PIL import Image from PIL import Image
from .main import app, init_state from .main import app
from .schemas import ndarray_f32 from .schemas import ndarray_f32
@ -29,9 +28,9 @@ def mock_get_model() -> Iterator[mock.Mock]:
@pytest.fixture(scope="session") @pytest.fixture(scope="session")
def deployed_app() -> TestClient: def deployed_app() -> Iterator[TestClient]:
init_state() with TestClient(app) as client:
return TestClient(app) yield client
@pytest.fixture(scope="session") @pytest.fixture(scope="session")

View file

@ -1,15 +1,16 @@
import asyncio import asyncio
import gc import gc
import os import os
import signal
import sys import sys
import threading import threading
import time import time
from concurrent.futures import ThreadPoolExecutor from concurrent.futures import ThreadPoolExecutor
from typing import Any from typing import Any, Iterator
from zipfile import BadZipFile from zipfile import BadZipFile
import orjson import orjson
from fastapi import FastAPI, Form, HTTPException, UploadFile from fastapi import Depends, FastAPI, Form, HTTPException, UploadFile
from fastapi.responses import ORJSONResponse from fastapi.responses import ORJSONResponse
from onnxruntime.capi.onnxruntime_pybind11_state import InvalidProtobuf, NoSuchFile from onnxruntime.capi.onnxruntime_pybind11_state import InvalidProtobuf, NoSuchFile
from starlette.formparsers import MultiPartParser from starlette.formparsers import MultiPartParser
@ -27,9 +28,16 @@ from .schemas import (
MultiPartParser.max_file_size = 2**26 # spools to disk if payload is 64 MiB or larger MultiPartParser.max_file_size = 2**26 # spools to disk if payload is 64 MiB or larger
app = FastAPI() app = FastAPI()
model_cache = ModelCache(ttl=settings.model_ttl, revalidate=settings.model_ttl > 0)
thread_pool: ThreadPoolExecutor | None = None
lock = threading.Lock()
active_requests = 0
last_called: float | None = None
def init_state() -> None:
app.state.model_cache = ModelCache(ttl=settings.model_ttl, revalidate=settings.model_ttl > 0) @app.on_event("startup")
def startup() -> None:
global thread_pool
log.info( log.info(
( (
"Created in-memory cache with unloading " "Created in-memory cache with unloading "
@ -37,17 +45,30 @@ def init_state() -> None:
) )
) )
# asyncio is a huge bottleneck for performance, so we use a thread pool to run blocking code # asyncio is a huge bottleneck for performance, so we use a thread pool to run blocking code
app.state.thread_pool = ThreadPoolExecutor(settings.request_threads) if settings.request_threads > 0 else None thread_pool = ThreadPoolExecutor(settings.request_threads) if settings.request_threads > 0 else None
app.state.lock = threading.Lock()
app.state.last_called = None
if settings.model_ttl > 0 and settings.model_ttl_poll_s > 0: if settings.model_ttl > 0 and settings.model_ttl_poll_s > 0:
asyncio.ensure_future(idle_shutdown_task()) asyncio.ensure_future(idle_shutdown_task())
log.info(f"Initialized request thread pool with {settings.request_threads} threads.") log.info(f"Initialized request thread pool with {settings.request_threads} threads.")
@app.on_event("startup") @app.on_event("shutdown")
async def startup_event() -> None: def shutdown() -> None:
init_state() log.handlers.clear()
for model in model_cache.cache._cache.values():
del model
if thread_pool is not None:
thread_pool.shutdown()
gc.collect()
def update_state() -> Iterator[None]:
global active_requests, last_called
active_requests += 1
last_called = time.time()
try:
yield
finally:
active_requests -= 1
@app.get("/", response_model=MessageResponse) @app.get("/", response_model=MessageResponse)
@ -60,7 +81,7 @@ def ping() -> str:
return "pong" return "pong"
@app.post("/predict") @app.post("/predict", dependencies=[Depends(update_state)])
async def predict( async def predict(
model_name: str = Form(alias="modelName"), model_name: str = Form(alias="modelName"),
model_type: ModelType = Form(alias="modelType"), model_type: ModelType = Form(alias="modelType"),
@ -79,17 +100,16 @@ async def predict(
except orjson.JSONDecodeError: except orjson.JSONDecodeError:
raise HTTPException(400, f"Invalid options JSON: {options}") raise HTTPException(400, f"Invalid options JSON: {options}")
model = await load(await app.state.model_cache.get(model_name, model_type, **kwargs)) model = await load(await model_cache.get(model_name, model_type, **kwargs))
model.configure(**kwargs) model.configure(**kwargs)
outputs = await run(model, inputs) outputs = await run(model, inputs)
return ORJSONResponse(outputs) return ORJSONResponse(outputs)
async def run(model: InferenceModel, inputs: Any) -> Any: async def run(model: InferenceModel, inputs: Any) -> Any:
app.state.last_called = time.time() if thread_pool is None:
if app.state.thread_pool is None:
return model.predict(inputs) return model.predict(inputs)
return await asyncio.get_running_loop().run_in_executor(app.state.thread_pool, model.predict, inputs) return await asyncio.get_running_loop().run_in_executor(thread_pool, model.predict, inputs)
async def load(model: InferenceModel) -> InferenceModel: async def load(model: InferenceModel) -> InferenceModel:
@ -97,15 +117,15 @@ async def load(model: InferenceModel) -> InferenceModel:
return model return model
def _load() -> None: def _load() -> None:
with app.state.lock: with lock:
model.load() model.load()
loop = asyncio.get_running_loop() loop = asyncio.get_running_loop()
try: try:
if app.state.thread_pool is None: if thread_pool is None:
model.load() model.load()
else: else:
await loop.run_in_executor(app.state.thread_pool, _load) await loop.run_in_executor(thread_pool, _load)
return model return model
except (OSError, InvalidProtobuf, BadZipFile, NoSuchFile): except (OSError, InvalidProtobuf, BadZipFile, NoSuchFile):
log.warn( log.warn(
@ -115,32 +135,23 @@ async def load(model: InferenceModel) -> InferenceModel:
) )
) )
model.clear_cache() model.clear_cache()
if app.state.thread_pool is None: if thread_pool is None:
model.load() model.load()
else: else:
await loop.run_in_executor(app.state.thread_pool, _load) await loop.run_in_executor(thread_pool, _load)
return model return model
async def idle_shutdown_task() -> None: async def idle_shutdown_task() -> None:
while True: while True:
log.debug("Checking for inactivity...") log.debug("Checking for inactivity...")
if app.state.last_called is not None and time.time() - app.state.last_called > settings.model_ttl: if (
last_called is not None
and not active_requests
and not lock.locked()
and time.time() - last_called > settings.model_ttl
):
log.info("Shutting down due to inactivity.") log.info("Shutting down due to inactivity.")
loop = asyncio.get_running_loop() os.kill(os.getpid(), signal.SIGINT)
for task in asyncio.all_tasks(loop): break
if task is not asyncio.current_task():
try:
task.cancel()
except asyncio.CancelledError:
pass
sys.stderr.close()
sys.stdout.close()
sys.stdout = sys.stderr = open(os.devnull, "w")
try:
await app.state.model_cache.cache.clear()
gc.collect()
loop.stop()
except asyncio.CancelledError:
pass
await asyncio.sleep(settings.model_ttl_poll_s) await asyncio.sleep(settings.model_ttl_poll_s)

View file

@ -1,6 +1,7 @@
#!/usr/bin/env sh #!/usr/bin/env sh
export LD_PRELOAD="/usr/lib/$(arch)-linux-gnu/libmimalloc.so.2" export LD_PRELOAD="/usr/lib/$(arch)-linux-gnu/libmimalloc.so.2"
export LD_BIND_NOW=1
: "${MACHINE_LEARNING_HOST:=0.0.0.0}" : "${MACHINE_LEARNING_HOST:=0.0.0.0}"
: "${MACHINE_LEARNING_PORT:=3003}" : "${MACHINE_LEARNING_PORT:=3003}"
@ -8,8 +9,9 @@ export LD_PRELOAD="/usr/lib/$(arch)-linux-gnu/libmimalloc.so.2"
: "${MACHINE_LEARNING_WORKER_TIMEOUT:=120}" : "${MACHINE_LEARNING_WORKER_TIMEOUT:=120}"
gunicorn app.main:app \ gunicorn app.main:app \
-k uvicorn.workers.UvicornWorker \ -k app.config.CustomUvicornWorker \
-w $MACHINE_LEARNING_WORKERS \ -w $MACHINE_LEARNING_WORKERS \
-b $MACHINE_LEARNING_HOST:$MACHINE_LEARNING_PORT \ -b $MACHINE_LEARNING_HOST:$MACHINE_LEARNING_PORT \
-t $MACHINE_LEARNING_WORKER_TIMEOUT \ -t $MACHINE_LEARNING_WORKER_TIMEOUT \
--log-config-json log_conf.json --log-config-json log_conf.json \
--graceful-timeout 0