From d729c863c8566f39655e96f4030d0946ac4ba513 Mon Sep 17 00:00:00 2001
From: Mert <101130780+mertalev@users.noreply.github.com>
Date: Thu, 14 Dec 2023 14:51:24 -0500
Subject: [PATCH] chore(ml): improve shutdown (#5689)

---
 machine-learning/app/config.py   | 28 +++++++++--
 machine-learning/app/conftest.py |  9 ++--
 machine-learning/app/main.py     | 85 ++++++++++++++++++--------------
 machine-learning/start.sh        |  6 ++-
 4 files changed, 80 insertions(+), 48 deletions(-)

diff --git a/machine-learning/app/config.py b/machine-learning/app/config.py
index fa4fefeb37..a0bc01d9a6 100644
--- a/machine-learning/app/config.py
+++ b/machine-learning/app/config.py
@@ -1,12 +1,16 @@
 import logging
 import os
+import sys
 from pathlib import Path
+from socket import socket
 
-import gunicorn
 import starlette
+from gunicorn.arbiter import Arbiter
 from pydantic import BaseSettings
 from rich.console import Console
 from rich.logging import RichHandler
+from uvicorn import Server
+from uvicorn.workers import UvicornWorker
 
 from .schemas import ModelType
 
@@ -69,10 +73,26 @@ log_settings = LogSettings()
 class CustomRichHandler(RichHandler):
     def __init__(self) -> None:
         console = Console(color_system="standard", no_color=log_settings.no_color)
-        super().__init__(
-            show_path=False, omit_repeated_times=False, console=console, tracebacks_suppress=[gunicorn, starlette]
-        )
+        super().__init__(show_path=False, omit_repeated_times=False, console=console, tracebacks_suppress=[starlette])
 
 
 log = logging.getLogger("gunicorn.access")
 log.setLevel(LOG_LEVELS.get(log_settings.log_level.lower(), logging.INFO))
+
+
+# patches this issue https://github.com/encode/uvicorn/discussions/1803
+class CustomUvicornServer(Server):
+    async def shutdown(self, sockets: list[socket] | None = None) -> None:
+        for sock in sockets or []:
+            sock.close()
+        await super().shutdown()
+
+
+class CustomUvicornWorker(UvicornWorker):
+    async def _serve(self) -> None:
+        self.config.app = self.wsgi
+        server = CustomUvicornServer(config=self.config)
+        self._install_sigquit_handler()
+        await server.serve(sockets=self.sockets)
+        if not server.started:
+            sys.exit(Arbiter.WORKER_BOOT_ERROR)
diff --git a/machine-learning/app/conftest.py b/machine-learning/app/conftest.py
index 5e2dc1e847..5ef628f56a 100644
--- a/machine-learning/app/conftest.py
+++ b/machine-learning/app/conftest.py
@@ -1,5 +1,4 @@
 import json
-from pathlib import Path
 from typing import Any, Iterator
 from unittest import mock
 
@@ -8,7 +7,7 @@ import pytest
 from fastapi.testclient import TestClient
 from PIL import Image
 
-from .main import app, init_state
+from .main import app
 from .schemas import ndarray_f32
 
 
@@ -29,9 +28,9 @@ def mock_get_model() -> Iterator[mock.Mock]:
 
 
 @pytest.fixture(scope="session")
-def deployed_app() -> TestClient:
-    init_state()
-    return TestClient(app)
+def deployed_app() -> Iterator[TestClient]:
+    with TestClient(app) as client:
+        yield client
 
 
 @pytest.fixture(scope="session")
diff --git a/machine-learning/app/main.py b/machine-learning/app/main.py
index bf232071b9..7631fe080b 100644
--- a/machine-learning/app/main.py
+++ b/machine-learning/app/main.py
@@ -1,15 +1,16 @@
 import asyncio
 import gc
 import os
+import signal
 import sys
 import threading
 import time
 from concurrent.futures import ThreadPoolExecutor
-from typing import Any
+from typing import Any, Iterator
 from zipfile import BadZipFile
 
 import orjson
-from fastapi import FastAPI, Form, HTTPException, UploadFile
+from fastapi import Depends, FastAPI, Form, HTTPException, UploadFile
 from fastapi.responses import ORJSONResponse
 from onnxruntime.capi.onnxruntime_pybind11_state import InvalidProtobuf, NoSuchFile
 from starlette.formparsers import MultiPartParser
@@ -27,9 +28,16 @@ from .schemas import (
 MultiPartParser.max_file_size = 2**26  # spools to disk if payload is 64 MiB or larger
 
 app = FastAPI()
+model_cache = ModelCache(ttl=settings.model_ttl, revalidate=settings.model_ttl > 0)
+thread_pool: ThreadPoolExecutor | None = None
+lock = threading.Lock()
+active_requests = 0
+last_called: float | None = None
 
 
-def init_state() -> None:
-    app.state.model_cache = ModelCache(ttl=settings.model_ttl, revalidate=settings.model_ttl > 0)
+
+@app.on_event("startup")
+def startup() -> None:
+    global thread_pool
     log.info(
         (
             "Created in-memory cache with unloading "
@@ -37,17 +45,30 @@ def init_state() -> None:
         )
     )
     # asyncio is a huge bottleneck for performance, so we use a thread pool to run blocking code
-    app.state.thread_pool = ThreadPoolExecutor(settings.request_threads) if settings.request_threads > 0 else None
-    app.state.lock = threading.Lock()
-    app.state.last_called = None
+    thread_pool = ThreadPoolExecutor(settings.request_threads) if settings.request_threads > 0 else None
     if settings.model_ttl > 0 and settings.model_ttl_poll_s > 0:
         asyncio.ensure_future(idle_shutdown_task())
     log.info(f"Initialized request thread pool with {settings.request_threads} threads.")
 
 
-@app.on_event("startup")
-async def startup_event() -> None:
-    init_state()
+@app.on_event("shutdown")
+def shutdown() -> None:
+    log.handlers.clear()
+    for model in model_cache.cache._cache.values():
+        del model
+    if thread_pool is not None:
+        thread_pool.shutdown()
+    gc.collect()
+
+
+def update_state() -> Iterator[None]:
+    global active_requests, last_called
+    active_requests += 1
+    last_called = time.time()
+    try:
+        yield
+    finally:
+        active_requests -= 1
 
 
 @app.get("/", response_model=MessageResponse)
@@ -60,7 +81,7 @@ def ping() -> str:
     return "pong"
 
 
-@app.post("/predict")
+@app.post("/predict", dependencies=[Depends(update_state)])
 async def predict(
     model_name: str = Form(alias="modelName"),
     model_type: ModelType = Form(alias="modelType"),
@@ -79,17 +100,16 @@ async def predict(
         except orjson.JSONDecodeError:
             raise HTTPException(400, f"Invalid options JSON: {options}")
 
-    model = await load(await app.state.model_cache.get(model_name, model_type, **kwargs))
+    model = await load(await model_cache.get(model_name, model_type, **kwargs))
     model.configure(**kwargs)
     outputs = await run(model, inputs)
     return ORJSONResponse(outputs)
 
 
 async def run(model: InferenceModel, inputs: Any) -> Any:
-    app.state.last_called = time.time()
-    if app.state.thread_pool is None:
+    if thread_pool is None:
         return model.predict(inputs)
-    return await asyncio.get_running_loop().run_in_executor(app.state.thread_pool, model.predict, inputs)
+    return await asyncio.get_running_loop().run_in_executor(thread_pool, model.predict, inputs)
 
 
 async def load(model: InferenceModel) -> InferenceModel:
@@ -97,15 +117,15 @@ async def load(model: InferenceModel) -> InferenceModel:
         return model
 
     def _load() -> None:
-        with app.state.lock:
+        with lock:
             model.load()
 
     loop = asyncio.get_running_loop()
    try:
-        if app.state.thread_pool is None:
+        if thread_pool is None:
             model.load()
         else:
-            await loop.run_in_executor(app.state.thread_pool, _load)
+            await loop.run_in_executor(thread_pool, _load)
         return model
     except (OSError, InvalidProtobuf, BadZipFile, NoSuchFile):
         log.warn(
@@ -115,32 +135,23 @@ async def load(model: InferenceModel) -> InferenceModel:
             )
         )
         model.clear_cache()
-        if app.state.thread_pool is None:
+        if thread_pool is None:
             model.load()
         else:
-            await loop.run_in_executor(app.state.thread_pool, _load)
+            await loop.run_in_executor(thread_pool, _load)
         return model
 
 
 async def idle_shutdown_task() -> None:
     while True:
         log.debug("Checking for inactivity...")
-        if app.state.last_called is not None and time.time() - app.state.last_called > settings.model_ttl:
+        if (
+            last_called is not None
+            and not active_requests
+            and not lock.locked()
+            and time.time() - last_called > settings.model_ttl
+        ):
             log.info("Shutting down due to inactivity.")
-            loop = asyncio.get_running_loop()
-            for task in asyncio.all_tasks(loop):
-                if task is not asyncio.current_task():
-                    try:
-                        task.cancel()
-                    except asyncio.CancelledError:
-                        pass
-            sys.stderr.close()
-            sys.stdout.close()
-            sys.stdout = sys.stderr = open(os.devnull, "w")
-            try:
-                await app.state.model_cache.cache.clear()
-                gc.collect()
-                loop.stop()
-            except asyncio.CancelledError:
-                pass
+            os.kill(os.getpid(), signal.SIGINT)
+            break
         await asyncio.sleep(settings.model_ttl_poll_s)
diff --git a/machine-learning/start.sh b/machine-learning/start.sh
index 0836213e6a..d522f11435 100755
--- a/machine-learning/start.sh
+++ b/machine-learning/start.sh
@@ -1,6 +1,7 @@
 #!/usr/bin/env sh
 
 export LD_PRELOAD="/usr/lib/$(arch)-linux-gnu/libmimalloc.so.2"
+export LD_BIND_NOW=1
 
 : "${MACHINE_LEARNING_HOST:=0.0.0.0}"
 : "${MACHINE_LEARNING_PORT:=3003}"
@@ -8,8 +9,9 @@ export LD_PRELOAD="/usr/lib/$(arch)-linux-gnu/libmimalloc.so.2"
 : "${MACHINE_LEARNING_WORKER_TIMEOUT:=120}"
 
 gunicorn app.main:app \
-	-k uvicorn.workers.UvicornWorker \
+	-k app.config.CustomUvicornWorker \
 	-w $MACHINE_LEARNING_WORKERS \
 	-b $MACHINE_LEARNING_HOST:$MACHINE_LEARNING_PORT \
 	-t $MACHINE_LEARNING_WORKER_TIMEOUT \
-	--log-config-json log_conf.json
+	--log-config-json log_conf.json \
+	--graceful-timeout 0
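
Note on the approach: the patch replaces the old in-process shutdown hack
(cancelling tasks and closing stdio) with a self-delivered SIGINT, and counts
in-flight requests with a FastAPI generator dependency so an idle worker is
never killed mid-request. Below is a minimal, self-contained sketch of that
pattern, not part of the patch: IDLE_TTL and POLL_S are hypothetical stand-ins
for settings.model_ttl and settings.model_ttl_poll_s, and the /ping route
stands in for the real /predict endpoint.

    # Sketch of the request-tracking + idle-shutdown pattern (assumptions:
    # IDLE_TTL, POLL_S, and the /ping route are illustrative only).
    import asyncio
    import os
    import signal
    import time
    from typing import Iterator

    from fastapi import Depends, FastAPI

    IDLE_TTL = 300  # hypothetical: seconds of inactivity before shutdown
    POLL_S = 10  # hypothetical: polling interval in seconds

    app = FastAPI()
    active_requests = 0
    last_called: float | None = None


    def update_state() -> Iterator[None]:
        # As a generator dependency, the code before `yield` runs when the
        # request arrives and the `finally` block runs once the response is
        # done, so active_requests counts requests currently in flight.
        global active_requests, last_called
        active_requests += 1
        last_called = time.time()
        try:
            yield
        finally:
            active_requests -= 1


    async def idle_shutdown_task() -> None:
        # Poll for inactivity; when nothing is in flight and the TTL has
        # elapsed, deliver SIGINT to this process (POSIX) so the server
        # shuts down through its normal signal path.
        while True:
            if (
                last_called is not None
                and not active_requests
                and time.time() - last_called > IDLE_TTL
            ):
                os.kill(os.getpid(), signal.SIGINT)
                break
            await asyncio.sleep(POLL_S)


    @app.on_event("startup")
    async def startup() -> None:
        asyncio.ensure_future(idle_shutdown_task())


    @app.get("/ping", dependencies=[Depends(update_state)])
    def ping() -> str:
        return "pong"

Run with e.g. `uvicorn sketch:app` (assuming the file is saved as sketch.py):
after IDLE_TTL seconds with no request in flight, the process signals itself
and exits through uvicorn's ordinary SIGINT handling; under gunicorn, the
arbiter can then respawn a fresh worker.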