
feat(ml) backend takes image over HTTP (#2783)

* using pydantic BaseSettings

* ML API takes image file as input

* keeping image in memory

* reducing duplicate code

* using bytes instead of UploadFile & other small code improvements

* removed form-multipart, using HTTP body

* format code

---------

Co-authored-by: Alex Tran <alex.tran1502@gmail.com>
Authored by Zeeshan Khan on 2023-06-17 22:49:19 -05:00; committed by GitHub
parent 3e804f16df
commit 34201be74c
8 changed files with 116 additions and 80 deletions
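
Before the file changes: a minimal client-side sketch of the behaviour the commit message describes, where the ML endpoints read raw image bytes from the HTTP request body instead of a multipart form or a shared upload path. The base URL and photo.jpg are assumptions for illustration; the endpoint path and default port 3003 come from the diffs below.

import requests

# Hypothetical local image; any file readable as bytes works.
with open("photo.jpg", "rb") as f:
    image_bytes = f.read()

resp = requests.post(
    "http://localhost:3003/image-classifier/tag-image",   # endpoint reworked in this commit
    data=image_bytes,                                      # raw bytes in the body, no multipart form
    headers={"Content-Type": "application/octet-stream"},
)
print(resp.json())  # list of tags above the configured min_tag_score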

View file

@@ -36,7 +36,6 @@ services:
      - 3003:3003
    volumes:
      - ../machine-learning/app:/usr/src/app
-      - ${UPLOAD_LOCATION}:/usr/src/app/upload
      - model-cache:/cache
    env_file:
      - .env

View file

@@ -33,7 +33,6 @@ services:
    container_name: immich_machine_learning
    image: ghcr.io/immich-app/immich-machine-learning:${IMMICH_VERSION:-release}
    volumes:
-      - ${UPLOAD_LOCATION}:/usr/src/app/upload
      - model-cache:/cache
    env_file:
      - .env

View file

@@ -0,0 +1,22 @@
+from pydantic import BaseSettings
+
+
+class Settings(BaseSettings):
+    cache_folder: str = "/cache"
+    classification_model: str = "microsoft/resnet-50"
+    clip_image_model: str = "clip-ViT-B-32"
+    clip_text_model: str = "clip-ViT-B-32"
+    facial_recognition_model: str = "buffalo_l"
+    min_tag_score: float = 0.9
+    eager_startup: bool = True
+    model_ttl: int = 300
+    host: str = "0.0.0.0"
+    port: int = 3003
+    workers: int = 1
+    min_face_score: float = 0.7
+
+    class Config(BaseSettings.Config):
+        env_prefix = 'MACHINE_LEARNING_'
+        case_sensitive = False
+
+
+settings = Settings()
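
For context on how the new settings object behaves: a standalone sketch, assuming pydantic v1-style BaseSettings as imported above. DemoSettings and the chosen fields are illustrative only; the point is that any field can be overridden through an environment variable carrying the MACHINE_LEARNING_ prefix, which is what replaces the scattered os.getenv calls removed from main.py and models.py below.

import os
from pydantic import BaseSettings

class DemoSettings(BaseSettings):
    model_ttl: int = 300
    min_tag_score: float = 0.9

    class Config:
        env_prefix = "MACHINE_LEARNING_"
        case_sensitive = False

# The prefixed variable overrides the default; matching is case-insensitive.
os.environ["MACHINE_LEARNING_MODEL_TTL"] = "600"
print(DemoSettings().model_ttl)      # 600
print(DemoSettings().min_tag_score)  # 0.9 (default kept)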

View file

@@ -1,4 +1,5 @@
import os
+import io
from typing import Any

from cache import ModelCache
@@ -9,52 +10,44 @@ from schemas import (
    MessageResponse,
    TextModelRequest,
    TextResponse,
-    VisionModelRequest,
)
import uvicorn
from PIL import Image
-from fastapi import FastAPI, HTTPException
+from fastapi import FastAPI, HTTPException, Depends, Body
from models import get_model, run_classification, run_facial_recognition
+from config import settings

-classification_model = os.getenv(
-    "MACHINE_LEARNING_CLASSIFICATION_MODEL", "microsoft/resnet-50"
-)
-clip_image_model = os.getenv("MACHINE_LEARNING_CLIP_IMAGE_MODEL", "clip-ViT-B-32")
-clip_text_model = os.getenv("MACHINE_LEARNING_CLIP_TEXT_MODEL", "clip-ViT-B-32")
-facial_recognition_model = os.getenv(
-    "MACHINE_LEARNING_FACIAL_RECOGNITION_MODEL", "buffalo_l"
-)
-
-min_tag_score = float(os.getenv("MACHINE_LEARNING_MIN_TAG_SCORE", 0.9))
-eager_startup = (
-    os.getenv("MACHINE_LEARNING_EAGER_STARTUP", "true") == "true"
-)  # loads all models at startup
-model_ttl = int(os.getenv("MACHINE_LEARNING_MODEL_TTL", 300))
-
_model_cache = None

app = FastAPI()


@app.on_event("startup")
async def startup_event() -> None:
    global _model_cache
-    _model_cache = ModelCache(ttl=model_ttl, revalidate=True)
+    _model_cache = ModelCache(ttl=settings.model_ttl, revalidate=True)
    models = [
-        (classification_model, "image-classification"),
-        (clip_image_model, "clip"),
-        (clip_text_model, "clip"),
-        (facial_recognition_model, "facial-recognition"),
+        (settings.classification_model, "image-classification"),
+        (settings.clip_image_model, "clip"),
+        (settings.clip_text_model, "clip"),
+        (settings.facial_recognition_model, "facial-recognition"),
    ]
    # Get all models
    for model_name, model_type in models:
-        if eager_startup:
+        if settings.eager_startup:
            await _model_cache.get_cached_model(model_name, model_type)
        else:
            get_model(model_name, model_type)


+def dep_model_cache():
+    if _model_cache is None:
+        raise HTTPException(status_code=500, detail="Unable to load model.")
+
+
+def dep_input_image(image: bytes = Body(...)) -> Image:
+    return Image.open(io.BytesIO(image))


@app.get("/", response_model=MessageResponse)
async def root() -> dict[str, str]:
    return {"message": "Immich ML"}
@@ -65,29 +58,36 @@ def ping() -> str:
    return "pong"


-@app.post("/image-classifier/tag-image", response_model=TagResponse, status_code=200)
-async def image_classification(payload: VisionModelRequest) -> list[str]:
-    if _model_cache is None:
-        raise HTTPException(status_code=500, detail="Unable to load model.")
-
-    model = await _model_cache.get_cached_model(
-        classification_model, "image-classification"
-    )
-    labels = run_classification(model, payload.image_path, min_tag_score)
-    return labels
+@app.post(
+    "/image-classifier/tag-image",
+    response_model=TagResponse,
+    status_code=200,
+    dependencies=[Depends(dep_model_cache)],
+)
+async def image_classification(
+    image: Image = Depends(dep_input_image)
+) -> list[str]:
+    try:
+        model = await _model_cache.get_cached_model(
+            settings.classification_model, "image-classification"
+        )
+        labels = run_classification(model, image, settings.min_tag_score)
+    except Exception as ex:
+        raise HTTPException(status_code=500, detail=str(ex))
+    else:
+        return labels


@app.post(
    "/sentence-transformer/encode-image",
    response_model=EmbeddingResponse,
    status_code=200,
+    dependencies=[Depends(dep_model_cache)],
)
-async def clip_encode_image(payload: VisionModelRequest) -> list[float]:
-    if _model_cache is None:
-        raise HTTPException(status_code=500, detail="Unable to load model.")
-
-    model = await _model_cache.get_cached_model(clip_image_model, "clip")
-    image = Image.open(payload.image_path)
+async def clip_encode_image(
+    image: Image = Depends(dep_input_image)
+) -> list[float]:
+    model = await _model_cache.get_cached_model(settings.clip_image_model, "clip")
    embedding = model.encode(image).tolist()
    return embedding
@@ -96,33 +96,38 @@ async def clip_encode_image(payload: VisionModelRequest) -> list[float]:
    "/sentence-transformer/encode-text",
    response_model=EmbeddingResponse,
    status_code=200,
+    dependencies=[Depends(dep_model_cache)],
)
-async def clip_encode_text(payload: TextModelRequest) -> list[float]:
-    if _model_cache is None:
-        raise HTTPException(status_code=500, detail="Unable to load model.")
-
-    model = await _model_cache.get_cached_model(clip_text_model, "clip")
+async def clip_encode_text(
+    payload: TextModelRequest
+) -> list[float]:
+    model = await _model_cache.get_cached_model(settings.clip_text_model, "clip")
    embedding = model.encode(payload.text).tolist()
    return embedding


@app.post(
-    "/facial-recognition/detect-faces", response_model=FaceResponse, status_code=200
+    "/facial-recognition/detect-faces",
+    response_model=FaceResponse,
+    status_code=200,
+    dependencies=[Depends(dep_model_cache)],
)
-async def facial_recognition(payload: VisionModelRequest) -> list[dict[str, Any]]:
-    if _model_cache is None:
-        raise HTTPException(status_code=500, detail="Unable to load model.")
-
-    model = await _model_cache.get_cached_model(
-        facial_recognition_model, "facial-recognition"
-    )
-    faces = run_facial_recognition(model, payload.image_path)
+async def facial_recognition(
+    image: bytes = Body(...),
+) -> list[dict[str, Any]]:
+    model = await _model_cache.get_cached_model(
+        settings.facial_recognition_model, "facial-recognition"
+    )
+    faces = run_facial_recognition(model, image)
    return faces


if __name__ == "__main__":
-    host = os.getenv("MACHINE_LEARNING_HOST", "0.0.0.0")
-    port = int(os.getenv("MACHINE_LEARNING_PORT", 3003))
    is_dev = os.getenv("NODE_ENV") == "development"

-    uvicorn.run("main:app", host=host, port=port, reload=is_dev, workers=1)
+    uvicorn.run(
+        "main:app",
+        host=settings.host,
+        port=settings.port,
+        reload=is_dev,
+        workers=settings.workers,
+    )
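
Not part of the diff: a self-contained sketch of the Body/Depends pattern used above, where the raw request body is turned into an in-memory PIL image without touching disk. The /size route and the test call are illustrative only, and the snippet assumes a recent FastAPI whose TestClient is httpx-based (content= keyword).

import io

from fastapi import Body, Depends, FastAPI
from fastapi.testclient import TestClient
from PIL import Image

app = FastAPI()

def dep_input_image(image: bytes = Body(...)) -> Image.Image:
    # FastAPI hands the raw, non-JSON request body to a bytes body parameter.
    return Image.open(io.BytesIO(image))

@app.post("/size")
def size(image: Image.Image = Depends(dep_input_image)) -> list[int]:
    return list(image.size)

# Build a tiny PNG in memory and post it as the request body.
buf = io.BytesIO()
Image.new("RGB", (8, 6)).save(buf, format="PNG")
client = TestClient(app)
resp = client.post("/size", content=buf.getvalue(),
                   headers={"Content-Type": "application/octet-stream"})
print(resp.json())  # [8, 6]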

View file

@@ -1,14 +1,15 @@
import torch
from insightface.app import FaceAnalysis
from pathlib import Path
-import os
from transformers import pipeline, Pipeline
from sentence_transformers import SentenceTransformer
-from typing import Any
+from typing import Any, BinaryIO
import cv2 as cv
+import numpy as np
+from PIL import Image
+from config import settings

-cache_folder = os.getenv("MACHINE_LEARNING_CACHE_FOLDER", "/cache")
-
device = "cuda" if torch.cuda.is_available() else "cpu"
@@ -49,9 +50,9 @@ def get_model(model_name: str, model_type: str, **model_kwargs):
def run_classification(
-    model: Pipeline, image_path: str, min_score: float | None = None
+    model: Pipeline, image: Image, min_score: float | None = None
):
-    predictions: list[dict[str, Any]] = model(image_path)  # type: ignore
+    predictions: list[dict[str, Any]] = model(image)  # type: ignore
    result = {
        tag
        for pred in predictions
@@ -63,9 +64,10 @@ def run_classification(
def run_facial_recognition(
-    model: FaceAnalysis, image_path: str
+    model: FaceAnalysis, image: bytes
) -> list[dict[str, Any]]:
-    img = cv.imread(image_path)
+    file_bytes = np.frombuffer(image, dtype=np.uint8)
+    img = cv.imdecode(file_bytes, cv.IMREAD_COLOR)
    height, width, _ = img.shape
    results = []
    faces = model.get(img)
@@ -101,7 +103,7 @@ def _load_facial_recognition(
    if isinstance(cache_dir, Path):
        cache_dir = cache_dir.as_posix()
    if min_face_score is None:
-        min_face_score = float(os.getenv("MACHINE_LEARNING_MIN_FACE_SCORE", 0.7))
+        min_face_score = settings.min_face_score

    model = FaceAnalysis(
        name=model_name,
@@ -114,4 +116,4 @@ def _load_facial_recognition(
def _get_cache_dir(model_name: str, model_type: str) -> Path:
-    return Path(cache_folder, device, model_type, model_name)
+    return Path(settings.cache_folder, device, model_type, model_name)
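
For reference, a standalone sketch of the in-memory decode path that replaces cv.imread above; the dummy image stands in for the bytes the endpoint now receives over HTTP.

import cv2 as cv
import numpy as np

# Encode a dummy image to bytes, standing in for the HTTP request body.
dummy = np.zeros((4, 4, 3), dtype=np.uint8)
ok, encoded = cv.imencode(".png", dummy)
assert ok
raw: bytes = encoded.tobytes()

# The new decode path: wrap the bytes in a uint8 buffer and let OpenCV decode it,
# producing the same BGR ndarray cv.imread would have returned from a file on disk.
buffer = np.frombuffer(raw, dtype=np.uint8)
img = cv.imdecode(buffer, cv.IMREAD_COLOR)
print(img.shape)  # (4, 4, 3)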

View file

@@ -9,14 +9,6 @@ def to_lower_camel(string: str) -> str:
    return "".join(tokens)


-class VisionModelRequest(BaseModel):
-    image_path: str
-
-    class Config:
-        alias_generator = to_lower_camel
-        allow_population_by_field_name = True
-
-
class TextModelRequest(BaseModel):
    text: str

View file

@@ -1733,6 +1733,8 @@ files = [
    {file = "scikit_image-0.21.0-cp39-cp39-macosx_12_0_arm64.whl", hash = "sha256:c01e3ab0a1fabfd8ce30686d4401b7ed36e6126c9d4d05cb94abf6bdc46f7ac9"},
    {file = "scikit_image-0.21.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8ef5d8d1099317b7b315b530348cbfa68ab8ce32459de3c074d204166951025c"},
    {file = "scikit_image-0.21.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:78b1e96c59cab640ca5c5b22c501524cfaf34cbe0cb51ba73bd9a9ede3fb6e1d"},
+    {file = "scikit_image-0.21.0-cp39-cp39-win_amd64.whl", hash = "sha256:9cffcddd2a5594c0a06de2ae3e1e25d662745a26f94fda31520593669677c010"},
+    {file = "scikit_image-0.21.0.tar.gz", hash = "sha256:b33e823c54e6f11873ea390ee49ef832b82b9f70752c8759efd09d5a4e3d87f0"},
]

[package.dependencies]
@@ -2088,9 +2090,9 @@ opt-einsum = ["opt-einsum (>=3.3)"]
[[package]]
name = "torch"
version = "2.0.1+cpu"
-description = ""
+description = "Tensors and Dynamic neural networks in Python with strong GPU acceleration"
optional = false
-python-versions = "*"
+python-versions = ">=3.8.0"
files = [
    {file = "torch-2.0.1+cpu-cp310-cp310-linux_x86_64.whl", hash = "sha256:fec257249ba014c68629a1994b0c6e7356e20e1afc77a87b9941a40e5095285d"},
    {file = "torch-2.0.1+cpu-cp310-cp310-win_amd64.whl", hash = "sha256:ca88b499973c4c027e32c4960bf20911d7e984bd0c55cda181dc643559f3d93f"},
@@ -2102,6 +2104,16 @@ files = [
    {file = "torch-2.0.1+cpu-cp39-cp39-win_amd64.whl", hash = "sha256:f263f8e908288427ae81441fef540377f61e339a27632b1bbe33cf78292fdaea"},
]

+[package.dependencies]
+filelock = "*"
+jinja2 = "*"
+networkx = "*"
+sympy = "*"
+typing-extensions = "*"
+
+[package.extras]
+opt-einsum = ["opt-einsum (>=3.3)"]
+
[package.source]
type = "legacy"
url = "https://download.pytorch.org/whl/cpu"

View file

@@ -1,21 +1,26 @@
import { DetectFaceResult, IMachineLearningRepository, MachineLearningInput, MACHINE_LEARNING_URL } from '@app/domain';
import { Injectable } from '@nestjs/common';
import axios from 'axios';
+import { createReadStream } from 'fs';

const client = axios.create({ baseURL: MACHINE_LEARNING_URL });

@Injectable()
export class MachineLearningRepository implements IMachineLearningRepository {
+  private post<T>(input: MachineLearningInput, endpoint: string): Promise<T> {
+    return client.post<T>(endpoint, createReadStream(input.imagePath)).then((res) => res.data);
+  }
+
  classifyImage(input: MachineLearningInput): Promise<string[]> {
-    return client.post<string[]>('/image-classifier/tag-image', input).then((res) => res.data);
+    return this.post<string[]>(input, '/image-classifier/tag-image');
  }

  detectFaces(input: MachineLearningInput): Promise<DetectFaceResult[]> {
-    return client.post<DetectFaceResult[]>('/facial-recognition/detect-faces', input).then((res) => res.data);
+    return this.post<DetectFaceResult[]>(input, '/facial-recognition/detect-faces');
  }

  encodeImage(input: MachineLearningInput): Promise<number[]> {
-    return client.post<number[]>('/sentence-transformer/encode-image', input).then((res) => res.data);
+    return this.post<number[]>(input, '/sentence-transformer/encode-image');
  }

  encodeText(input: string): Promise<number[]> {