mirror of
https://github.com/immich-app/immich.git
synced 2025-01-01 08:31:59 +00:00
feat(ml) backend takes image over HTTP (#2783)
* using pydantic BaseSetting * ML API takes image file as input * keeping image in memory * reducing duplicate code * using bytes instead of UploadFile & other small code improvements * removed form-multipart, using HTTP body * format code --------- Co-authored-by: Alex Tran <alex.tran1502@gmail.com>
This commit is contained in:
parent
3e804f16df
commit
34201be74c
8 changed files with 116 additions and 80 deletions
|
@ -36,7 +36,6 @@ services:
|
||||||
- 3003:3003
|
- 3003:3003
|
||||||
volumes:
|
volumes:
|
||||||
- ../machine-learning/app:/usr/src/app
|
- ../machine-learning/app:/usr/src/app
|
||||||
- ${UPLOAD_LOCATION}:/usr/src/app/upload
|
|
||||||
- model-cache:/cache
|
- model-cache:/cache
|
||||||
env_file:
|
env_file:
|
||||||
- .env
|
- .env
|
||||||
|
|
|
@ -33,7 +33,6 @@ services:
|
||||||
container_name: immich_machine_learning
|
container_name: immich_machine_learning
|
||||||
image: ghcr.io/immich-app/immich-machine-learning:${IMMICH_VERSION:-release}
|
image: ghcr.io/immich-app/immich-machine-learning:${IMMICH_VERSION:-release}
|
||||||
volumes:
|
volumes:
|
||||||
- ${UPLOAD_LOCATION}:/usr/src/app/upload
|
|
||||||
- model-cache:/cache
|
- model-cache:/cache
|
||||||
env_file:
|
env_file:
|
||||||
- .env
|
- .env
|
||||||
|
|
22
machine-learning/app/config.py
Normal file
22
machine-learning/app/config.py
Normal file
|
@ -0,0 +1,22 @@
|
||||||
|
from pydantic import BaseSettings
|
||||||
|
|
||||||
|
class Settings(BaseSettings):
|
||||||
|
cache_folder: str = "/cache"
|
||||||
|
classification_model: str = "microsoft/resnet-50"
|
||||||
|
clip_image_model: str = "clip-ViT-B-32"
|
||||||
|
clip_text_model: str = "clip-ViT-B-32"
|
||||||
|
facial_recognition_model: str = "buffalo_l"
|
||||||
|
min_tag_score: float = 0.9
|
||||||
|
eager_startup: bool = True
|
||||||
|
model_ttl: int = 300
|
||||||
|
host: str = "0.0.0.0"
|
||||||
|
port: int = 3003
|
||||||
|
workers: int = 1
|
||||||
|
min_face_score: float = 0.7
|
||||||
|
|
||||||
|
class Config(BaseSettings.Config):
|
||||||
|
env_prefix = 'MACHINE_LEARNING_'
|
||||||
|
case_sensitive = False
|
||||||
|
|
||||||
|
|
||||||
|
settings = Settings()
|
|
@ -1,4 +1,5 @@
|
||||||
import os
|
import os
|
||||||
|
import io
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from cache import ModelCache
|
from cache import ModelCache
|
||||||
|
@ -9,52 +10,44 @@ from schemas import (
|
||||||
MessageResponse,
|
MessageResponse,
|
||||||
TextModelRequest,
|
TextModelRequest,
|
||||||
TextResponse,
|
TextResponse,
|
||||||
VisionModelRequest,
|
|
||||||
)
|
)
|
||||||
import uvicorn
|
import uvicorn
|
||||||
|
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from fastapi import FastAPI, HTTPException
|
from fastapi import FastAPI, HTTPException, Depends, Body
|
||||||
from models import get_model, run_classification, run_facial_recognition
|
from models import get_model, run_classification, run_facial_recognition
|
||||||
|
from config import settings
|
||||||
classification_model = os.getenv(
|
|
||||||
"MACHINE_LEARNING_CLASSIFICATION_MODEL", "microsoft/resnet-50"
|
|
||||||
)
|
|
||||||
clip_image_model = os.getenv("MACHINE_LEARNING_CLIP_IMAGE_MODEL", "clip-ViT-B-32")
|
|
||||||
clip_text_model = os.getenv("MACHINE_LEARNING_CLIP_TEXT_MODEL", "clip-ViT-B-32")
|
|
||||||
facial_recognition_model = os.getenv(
|
|
||||||
"MACHINE_LEARNING_FACIAL_RECOGNITION_MODEL", "buffalo_l"
|
|
||||||
)
|
|
||||||
|
|
||||||
min_tag_score = float(os.getenv("MACHINE_LEARNING_MIN_TAG_SCORE", 0.9))
|
|
||||||
eager_startup = (
|
|
||||||
os.getenv("MACHINE_LEARNING_EAGER_STARTUP", "true") == "true"
|
|
||||||
) # loads all models at startup
|
|
||||||
model_ttl = int(os.getenv("MACHINE_LEARNING_MODEL_TTL", 300))
|
|
||||||
|
|
||||||
_model_cache = None
|
_model_cache = None
|
||||||
|
|
||||||
app = FastAPI()
|
app = FastAPI()
|
||||||
|
|
||||||
|
|
||||||
@app.on_event("startup")
|
@app.on_event("startup")
|
||||||
async def startup_event() -> None:
|
async def startup_event() -> None:
|
||||||
global _model_cache
|
global _model_cache
|
||||||
_model_cache = ModelCache(ttl=model_ttl, revalidate=True)
|
_model_cache = ModelCache(ttl=settings.model_ttl, revalidate=True)
|
||||||
models = [
|
models = [
|
||||||
(classification_model, "image-classification"),
|
(settings.classification_model, "image-classification"),
|
||||||
(clip_image_model, "clip"),
|
(settings.clip_image_model, "clip"),
|
||||||
(clip_text_model, "clip"),
|
(settings.clip_text_model, "clip"),
|
||||||
(facial_recognition_model, "facial-recognition"),
|
(settings.facial_recognition_model, "facial-recognition"),
|
||||||
]
|
]
|
||||||
|
|
||||||
# Get all models
|
# Get all models
|
||||||
for model_name, model_type in models:
|
for model_name, model_type in models:
|
||||||
if eager_startup:
|
if settings.eager_startup:
|
||||||
await _model_cache.get_cached_model(model_name, model_type)
|
await _model_cache.get_cached_model(model_name, model_type)
|
||||||
else:
|
else:
|
||||||
get_model(model_name, model_type)
|
get_model(model_name, model_type)
|
||||||
|
|
||||||
|
|
||||||
|
def dep_model_cache():
|
||||||
|
if _model_cache is None:
|
||||||
|
raise HTTPException(status_code=500, detail="Unable to load model.")
|
||||||
|
|
||||||
|
def dep_input_image(image: bytes = Body(...)) -> Image:
|
||||||
|
return Image.open(io.BytesIO(image))
|
||||||
|
|
||||||
@app.get("/", response_model=MessageResponse)
|
@app.get("/", response_model=MessageResponse)
|
||||||
async def root() -> dict[str, str]:
|
async def root() -> dict[str, str]:
|
||||||
return {"message": "Immich ML"}
|
return {"message": "Immich ML"}
|
||||||
|
@ -65,29 +58,36 @@ def ping() -> str:
|
||||||
return "pong"
|
return "pong"
|
||||||
|
|
||||||
|
|
||||||
@app.post("/image-classifier/tag-image", response_model=TagResponse, status_code=200)
|
@app.post(
|
||||||
async def image_classification(payload: VisionModelRequest) -> list[str]:
|
"/image-classifier/tag-image",
|
||||||
if _model_cache is None:
|
response_model=TagResponse,
|
||||||
raise HTTPException(status_code=500, detail="Unable to load model.")
|
status_code=200,
|
||||||
|
dependencies=[Depends(dep_model_cache)],
|
||||||
model = await _model_cache.get_cached_model(
|
)
|
||||||
classification_model, "image-classification"
|
async def image_classification(
|
||||||
)
|
image: Image = Depends(dep_input_image)
|
||||||
labels = run_classification(model, payload.image_path, min_tag_score)
|
) -> list[str]:
|
||||||
return labels
|
try:
|
||||||
|
model = await _model_cache.get_cached_model(
|
||||||
|
settings.classification_model, "image-classification"
|
||||||
|
)
|
||||||
|
labels = run_classification(model, image, settings.min_tag_score)
|
||||||
|
except Exception as ex:
|
||||||
|
raise HTTPException(status_code=500, detail=str(ex))
|
||||||
|
else:
|
||||||
|
return labels
|
||||||
|
|
||||||
|
|
||||||
@app.post(
|
@app.post(
|
||||||
"/sentence-transformer/encode-image",
|
"/sentence-transformer/encode-image",
|
||||||
response_model=EmbeddingResponse,
|
response_model=EmbeddingResponse,
|
||||||
status_code=200,
|
status_code=200,
|
||||||
|
dependencies=[Depends(dep_model_cache)],
|
||||||
)
|
)
|
||||||
async def clip_encode_image(payload: VisionModelRequest) -> list[float]:
|
async def clip_encode_image(
|
||||||
if _model_cache is None:
|
image: Image = Depends(dep_input_image)
|
||||||
raise HTTPException(status_code=500, detail="Unable to load model.")
|
) -> list[float]:
|
||||||
|
model = await _model_cache.get_cached_model(settings.clip_image_model, "clip")
|
||||||
model = await _model_cache.get_cached_model(clip_image_model, "clip")
|
|
||||||
image = Image.open(payload.image_path)
|
|
||||||
embedding = model.encode(image).tolist()
|
embedding = model.encode(image).tolist()
|
||||||
return embedding
|
return embedding
|
||||||
|
|
||||||
|
@ -96,33 +96,38 @@ async def clip_encode_image(payload: VisionModelRequest) -> list[float]:
|
||||||
"/sentence-transformer/encode-text",
|
"/sentence-transformer/encode-text",
|
||||||
response_model=EmbeddingResponse,
|
response_model=EmbeddingResponse,
|
||||||
status_code=200,
|
status_code=200,
|
||||||
|
dependencies=[Depends(dep_model_cache)],
|
||||||
)
|
)
|
||||||
async def clip_encode_text(payload: TextModelRequest) -> list[float]:
|
async def clip_encode_text(
|
||||||
if _model_cache is None:
|
payload: TextModelRequest
|
||||||
raise HTTPException(status_code=500, detail="Unable to load model.")
|
) -> list[float]:
|
||||||
|
model = await _model_cache.get_cached_model(settings.clip_text_model, "clip")
|
||||||
model = await _model_cache.get_cached_model(clip_text_model, "clip")
|
|
||||||
embedding = model.encode(payload.text).tolist()
|
embedding = model.encode(payload.text).tolist()
|
||||||
return embedding
|
return embedding
|
||||||
|
|
||||||
|
|
||||||
@app.post(
|
@app.post(
|
||||||
"/facial-recognition/detect-faces", response_model=FaceResponse, status_code=200
|
"/facial-recognition/detect-faces",
|
||||||
|
response_model=FaceResponse,
|
||||||
|
status_code=200,
|
||||||
|
dependencies=[Depends(dep_model_cache)],
|
||||||
)
|
)
|
||||||
async def facial_recognition(payload: VisionModelRequest) -> list[dict[str, Any]]:
|
async def facial_recognition(
|
||||||
if _model_cache is None:
|
image: bytes = Body(...),
|
||||||
raise HTTPException(status_code=500, detail="Unable to load model.")
|
) -> list[dict[str, Any]]:
|
||||||
|
|
||||||
model = await _model_cache.get_cached_model(
|
model = await _model_cache.get_cached_model(
|
||||||
facial_recognition_model, "facial-recognition"
|
settings.facial_recognition_model, "facial-recognition"
|
||||||
)
|
)
|
||||||
faces = run_facial_recognition(model, payload.image_path)
|
faces = run_facial_recognition(model, image)
|
||||||
return faces
|
return faces
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
host = os.getenv("MACHINE_LEARNING_HOST", "0.0.0.0")
|
|
||||||
port = int(os.getenv("MACHINE_LEARNING_PORT", 3003))
|
|
||||||
is_dev = os.getenv("NODE_ENV") == "development"
|
is_dev = os.getenv("NODE_ENV") == "development"
|
||||||
|
uvicorn.run(
|
||||||
uvicorn.run("main:app", host=host, port=port, reload=is_dev, workers=1)
|
"main:app",
|
||||||
|
host=settings.host,
|
||||||
|
port=settings.port,
|
||||||
|
reload=is_dev,
|
||||||
|
workers=settings.workers,
|
||||||
|
)
|
||||||
|
|
|
@ -1,14 +1,15 @@
|
||||||
import torch
|
import torch
|
||||||
from insightface.app import FaceAnalysis
|
from insightface.app import FaceAnalysis
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
import os
|
|
||||||
|
|
||||||
from transformers import pipeline, Pipeline
|
from transformers import pipeline, Pipeline
|
||||||
from sentence_transformers import SentenceTransformer
|
from sentence_transformers import SentenceTransformer
|
||||||
from typing import Any
|
from typing import Any, BinaryIO
|
||||||
import cv2 as cv
|
import cv2 as cv
|
||||||
|
import numpy as np
|
||||||
|
from PIL import Image
|
||||||
|
from config import settings
|
||||||
|
|
||||||
cache_folder = os.getenv("MACHINE_LEARNING_CACHE_FOLDER", "/cache")
|
|
||||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||||
|
|
||||||
|
|
||||||
|
@ -49,9 +50,9 @@ def get_model(model_name: str, model_type: str, **model_kwargs):
|
||||||
|
|
||||||
|
|
||||||
def run_classification(
|
def run_classification(
|
||||||
model: Pipeline, image_path: str, min_score: float | None = None
|
model: Pipeline, image: Image, min_score: float | None = None
|
||||||
):
|
):
|
||||||
predictions: list[dict[str, Any]] = model(image_path) # type: ignore
|
predictions: list[dict[str, Any]] = model(image) # type: ignore
|
||||||
result = {
|
result = {
|
||||||
tag
|
tag
|
||||||
for pred in predictions
|
for pred in predictions
|
||||||
|
@ -63,9 +64,10 @@ def run_classification(
|
||||||
|
|
||||||
|
|
||||||
def run_facial_recognition(
|
def run_facial_recognition(
|
||||||
model: FaceAnalysis, image_path: str
|
model: FaceAnalysis, image: bytes
|
||||||
) -> list[dict[str, Any]]:
|
) -> list[dict[str, Any]]:
|
||||||
img = cv.imread(image_path)
|
file_bytes = np.frombuffer(image, dtype=np.uint8)
|
||||||
|
img = cv.imdecode(file_bytes, cv.IMREAD_COLOR)
|
||||||
height, width, _ = img.shape
|
height, width, _ = img.shape
|
||||||
results = []
|
results = []
|
||||||
faces = model.get(img)
|
faces = model.get(img)
|
||||||
|
@ -101,7 +103,7 @@ def _load_facial_recognition(
|
||||||
if isinstance(cache_dir, Path):
|
if isinstance(cache_dir, Path):
|
||||||
cache_dir = cache_dir.as_posix()
|
cache_dir = cache_dir.as_posix()
|
||||||
if min_face_score is None:
|
if min_face_score is None:
|
||||||
min_face_score = float(os.getenv("MACHINE_LEARNING_MIN_FACE_SCORE", 0.7))
|
min_face_score = settings.min_face_score
|
||||||
|
|
||||||
model = FaceAnalysis(
|
model = FaceAnalysis(
|
||||||
name=model_name,
|
name=model_name,
|
||||||
|
@ -114,4 +116,4 @@ def _load_facial_recognition(
|
||||||
|
|
||||||
|
|
||||||
def _get_cache_dir(model_name: str, model_type: str) -> Path:
|
def _get_cache_dir(model_name: str, model_type: str) -> Path:
|
||||||
return Path(cache_folder, device, model_type, model_name)
|
return Path(settings.cache_folder, device, model_type, model_name)
|
||||||
|
|
|
@ -9,14 +9,6 @@ def to_lower_camel(string: str) -> str:
|
||||||
return "".join(tokens)
|
return "".join(tokens)
|
||||||
|
|
||||||
|
|
||||||
class VisionModelRequest(BaseModel):
|
|
||||||
image_path: str
|
|
||||||
|
|
||||||
class Config:
|
|
||||||
alias_generator = to_lower_camel
|
|
||||||
allow_population_by_field_name = True
|
|
||||||
|
|
||||||
|
|
||||||
class TextModelRequest(BaseModel):
|
class TextModelRequest(BaseModel):
|
||||||
text: str
|
text: str
|
||||||
|
|
||||||
|
|
16
machine-learning/poetry.lock
generated
16
machine-learning/poetry.lock
generated
|
@ -1733,6 +1733,8 @@ files = [
|
||||||
{file = "scikit_image-0.21.0-cp39-cp39-macosx_12_0_arm64.whl", hash = "sha256:c01e3ab0a1fabfd8ce30686d4401b7ed36e6126c9d4d05cb94abf6bdc46f7ac9"},
|
{file = "scikit_image-0.21.0-cp39-cp39-macosx_12_0_arm64.whl", hash = "sha256:c01e3ab0a1fabfd8ce30686d4401b7ed36e6126c9d4d05cb94abf6bdc46f7ac9"},
|
||||||
{file = "scikit_image-0.21.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8ef5d8d1099317b7b315b530348cbfa68ab8ce32459de3c074d204166951025c"},
|
{file = "scikit_image-0.21.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8ef5d8d1099317b7b315b530348cbfa68ab8ce32459de3c074d204166951025c"},
|
||||||
{file = "scikit_image-0.21.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:78b1e96c59cab640ca5c5b22c501524cfaf34cbe0cb51ba73bd9a9ede3fb6e1d"},
|
{file = "scikit_image-0.21.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:78b1e96c59cab640ca5c5b22c501524cfaf34cbe0cb51ba73bd9a9ede3fb6e1d"},
|
||||||
|
{file = "scikit_image-0.21.0-cp39-cp39-win_amd64.whl", hash = "sha256:9cffcddd2a5594c0a06de2ae3e1e25d662745a26f94fda31520593669677c010"},
|
||||||
|
{file = "scikit_image-0.21.0.tar.gz", hash = "sha256:b33e823c54e6f11873ea390ee49ef832b82b9f70752c8759efd09d5a4e3d87f0"},
|
||||||
]
|
]
|
||||||
|
|
||||||
[package.dependencies]
|
[package.dependencies]
|
||||||
|
@ -2088,9 +2090,9 @@ opt-einsum = ["opt-einsum (>=3.3)"]
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "torch"
|
name = "torch"
|
||||||
version = "2.0.1+cpu"
|
version = "2.0.1+cpu"
|
||||||
description = ""
|
description = "Tensors and Dynamic neural networks in Python with strong GPU acceleration"
|
||||||
optional = false
|
optional = false
|
||||||
python-versions = "*"
|
python-versions = ">=3.8.0"
|
||||||
files = [
|
files = [
|
||||||
{file = "torch-2.0.1+cpu-cp310-cp310-linux_x86_64.whl", hash = "sha256:fec257249ba014c68629a1994b0c6e7356e20e1afc77a87b9941a40e5095285d"},
|
{file = "torch-2.0.1+cpu-cp310-cp310-linux_x86_64.whl", hash = "sha256:fec257249ba014c68629a1994b0c6e7356e20e1afc77a87b9941a40e5095285d"},
|
||||||
{file = "torch-2.0.1+cpu-cp310-cp310-win_amd64.whl", hash = "sha256:ca88b499973c4c027e32c4960bf20911d7e984bd0c55cda181dc643559f3d93f"},
|
{file = "torch-2.0.1+cpu-cp310-cp310-win_amd64.whl", hash = "sha256:ca88b499973c4c027e32c4960bf20911d7e984bd0c55cda181dc643559f3d93f"},
|
||||||
|
@ -2102,6 +2104,16 @@ files = [
|
||||||
{file = "torch-2.0.1+cpu-cp39-cp39-win_amd64.whl", hash = "sha256:f263f8e908288427ae81441fef540377f61e339a27632b1bbe33cf78292fdaea"},
|
{file = "torch-2.0.1+cpu-cp39-cp39-win_amd64.whl", hash = "sha256:f263f8e908288427ae81441fef540377f61e339a27632b1bbe33cf78292fdaea"},
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[package.dependencies]
|
||||||
|
filelock = "*"
|
||||||
|
jinja2 = "*"
|
||||||
|
networkx = "*"
|
||||||
|
sympy = "*"
|
||||||
|
typing-extensions = "*"
|
||||||
|
|
||||||
|
[package.extras]
|
||||||
|
opt-einsum = ["opt-einsum (>=3.3)"]
|
||||||
|
|
||||||
[package.source]
|
[package.source]
|
||||||
type = "legacy"
|
type = "legacy"
|
||||||
url = "https://download.pytorch.org/whl/cpu"
|
url = "https://download.pytorch.org/whl/cpu"
|
||||||
|
|
|
@ -1,21 +1,26 @@
|
||||||
import { DetectFaceResult, IMachineLearningRepository, MachineLearningInput, MACHINE_LEARNING_URL } from '@app/domain';
|
import { DetectFaceResult, IMachineLearningRepository, MachineLearningInput, MACHINE_LEARNING_URL } from '@app/domain';
|
||||||
import { Injectable } from '@nestjs/common';
|
import { Injectable } from '@nestjs/common';
|
||||||
import axios from 'axios';
|
import axios from 'axios';
|
||||||
|
import { createReadStream } from 'fs';
|
||||||
|
|
||||||
const client = axios.create({ baseURL: MACHINE_LEARNING_URL });
|
const client = axios.create({ baseURL: MACHINE_LEARNING_URL });
|
||||||
|
|
||||||
@Injectable()
|
@Injectable()
|
||||||
export class MachineLearningRepository implements IMachineLearningRepository {
|
export class MachineLearningRepository implements IMachineLearningRepository {
|
||||||
|
private post<T>(input: MachineLearningInput, endpoint: string): Promise<T> {
|
||||||
|
return client.post<T>(endpoint, createReadStream(input.imagePath)).then((res) => res.data);
|
||||||
|
}
|
||||||
|
|
||||||
classifyImage(input: MachineLearningInput): Promise<string[]> {
|
classifyImage(input: MachineLearningInput): Promise<string[]> {
|
||||||
return client.post<string[]>('/image-classifier/tag-image', input).then((res) => res.data);
|
return this.post<string[]>(input, '/image-classifier/tag-image');
|
||||||
}
|
}
|
||||||
|
|
||||||
detectFaces(input: MachineLearningInput): Promise<DetectFaceResult[]> {
|
detectFaces(input: MachineLearningInput): Promise<DetectFaceResult[]> {
|
||||||
return client.post<DetectFaceResult[]>('/facial-recognition/detect-faces', input).then((res) => res.data);
|
return this.post<DetectFaceResult[]>(input, '/facial-recognition/detect-faces');
|
||||||
}
|
}
|
||||||
|
|
||||||
encodeImage(input: MachineLearningInput): Promise<number[]> {
|
encodeImage(input: MachineLearningInput): Promise<number[]> {
|
||||||
return client.post<number[]>('/sentence-transformer/encode-image', input).then((res) => res.data);
|
return this.post<number[]>(input, '/sentence-transformer/encode-image');
|
||||||
}
|
}
|
||||||
|
|
||||||
encodeText(input: string): Promise<number[]> {
|
encodeText(input: string): Promise<number[]> {
|
||||||
|
|
Loading…
Reference in a new issue