# immich/machine-learning/app/main.py
# Mirror of https://github.com/immich-app/immich.git (synced 2025-01-04 02:46:47 +01:00).

import asyncio
from concurrent.futures import ThreadPoolExecutor
from typing import Any
import orjson
from fastapi import FastAPI, Form, HTTPException, UploadFile
from fastapi.responses import ORJSONResponse
from starlette.formparsers import MultiPartParser
from app.models.base import InferenceModel
from .config import log, settings
from .models.cache import ModelCache
from .schemas import (
MessageResponse,
ModelType,
TextResponse,
)
# Multipart uploads spool to disk if the payload is 16 MiB or larger.
MultiPartParser.max_file_size = 16 * 1024 * 1024

# The ASGI application; routes and startup hooks below are registered on it.
app = FastAPI()
def init_state() -> None:
    """Create the shared model cache and request thread pool on ``app.state``.

    Sets ``app.state.model_cache`` to a ``ModelCache`` configured from
    ``settings.model_ttl`` (revalidation is enabled only when the TTL is
    positive) and ``app.state.thread_pool`` to a ``ThreadPoolExecutor`` with
    ``settings.request_threads`` workers, or ``None`` when that setting is
    not positive.
    """
    unloading_enabled = settings.model_ttl > 0
    app.state.model_cache = ModelCache(ttl=settings.model_ttl, revalidate=unloading_enabled)
    unloading = f"after {settings.model_ttl}s of inactivity" if unloading_enabled else "disabled"
    log.info(f"Created in-memory cache with unloading {unloading}.")

    # asyncio is a huge bottleneck for performance, so we use a thread pool to run blocking code
    if settings.request_threads > 0:
        app.state.thread_pool = ThreadPoolExecutor(settings.request_threads)
        log.info(f"Initialized request thread pool with {settings.request_threads} threads.")
    else:
        # No pool is created here, so do not claim one was initialized.
        app.state.thread_pool = None
        log.info("Request thread pool is disabled.")
@app.on_event("startup")
async def startup_event() -> None:
    """Handler registered for the app's ``startup`` event; delegates to ``init_state``."""
    init_state()
@app.get("/", response_model=MessageResponse)
async def root() -> dict[str, str]:
    """Return a one-key mapping whose ``message`` identifies the service."""
    banner = "Immich ML"
    return {"message": banner}
@app.get("/ping", response_model=TextResponse)
def ping() -> str:
    """Answer ``GET /ping`` with the fixed text ``pong``."""
    reply = "pong"
    return reply
@app.post("/predict")
async def predict(
    model_name: str = Form(alias="modelName"),
    model_type: ModelType = Form(alias="modelType"),
    options: str = Form(default="{}"),
    text: str | None = Form(default=None),
    image: UploadFile | None = None,
) -> Any:
    """Run a cached model on an uploaded image or a text string.

    Args:
        model_name: Form field ``modelName``; passed to the model cache lookup.
        model_type: Form field ``modelType``; passed to the model cache lookup.
        options: JSON object (as a string) whose members are forwarded as
            keyword arguments to the model cache lookup. Defaults to ``"{}"``.
        text: Text input, used only when no image is supplied.
        image: Uploaded file; when present its raw bytes are the model input
            and take precedence over ``text``.

    Returns:
        An ``ORJSONResponse`` wrapping whatever the model's ``predict`` returned.

    Raises:
        HTTPException: 400 if neither ``image`` nor ``text`` is provided, or
            if ``options`` is not valid JSON or does not decode to an object.
    """
    if image is not None:
        inputs: str | bytes = await image.read()
    elif text is not None:
        inputs = text
    else:
        raise HTTPException(400, "Either image or text must be provided")

    # `options` is client-supplied; reject bad payloads as a client error
    # instead of letting the decode/unpack failure surface as a 500.
    try:
        model_options = orjson.loads(options)
    except orjson.JSONDecodeError as error:
        raise HTTPException(400, "options must be valid JSON") from error
    if not isinstance(model_options, dict):
        raise HTTPException(400, "options must be a JSON object")

    model: InferenceModel = await app.state.model_cache.get(model_name, model_type, **model_options)
    outputs = await run(model, inputs)
    return ORJSONResponse(outputs)
async def run(model: InferenceModel, inputs: Any) -> Any:
    """Call ``model.predict(inputs)`` and return its result.

    When ``app.state.thread_pool`` is set, the call is submitted to that
    executor through the running event loop and awaited; otherwise
    ``model.predict`` is invoked directly inside this coroutine.
    """
    executor = app.state.thread_pool
    if executor is None:
        return model.predict(inputs)

    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(executor, model.predict, inputs)