mirror of
https://github.com/immich-app/immich.git
synced 2024-12-29 15:11:58 +00:00
refactor(ml): modularization and styling (#2835)
* basic refactor and styling * removed batching * module entrypoint * removed unused imports * model superclass, model cache now in app state * fixed cache dir and enforced abstract method --------- Co-authored-by: Alex Tran <alex.tran1502@gmail.com>
This commit is contained in:
parent
837ad24f58
commit
a2f5674bbb
12 changed files with 281 additions and 182 deletions
|
@ -21,8 +21,8 @@ ENV NODE_ENV=production \
|
||||||
PYTHONDONTWRITEBYTECODE=1 \
|
PYTHONDONTWRITEBYTECODE=1 \
|
||||||
PYTHONUNBUFFERED=1 \
|
PYTHONUNBUFFERED=1 \
|
||||||
PATH="/opt/venv/bin:$PATH" \
|
PATH="/opt/venv/bin:$PATH" \
|
||||||
PYTHONPATH=`pwd`
|
PYTHONPATH=/usr/src
|
||||||
|
|
||||||
COPY --from=builder /opt/venv /opt/venv
|
COPY --from=builder /opt/venv /opt/venv
|
||||||
COPY app .
|
COPY app .
|
||||||
ENTRYPOINT ["python", "main.py"]
|
ENTRYPOINT ["python", "-m", "app.main"]
|
||||||
|
|
0
machine-learning/app/__init__.py
Normal file
0
machine-learning/app/__init__.py
Normal file
|
@ -1,5 +1,10 @@
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
from pydantic import BaseSettings
|
from pydantic import BaseSettings
|
||||||
|
|
||||||
|
from .schemas import ModelType
|
||||||
|
|
||||||
|
|
||||||
class Settings(BaseSettings):
|
class Settings(BaseSettings):
|
||||||
cache_folder: str = "/cache"
|
cache_folder: str = "/cache"
|
||||||
classification_model: str = "microsoft/resnet-50"
|
classification_model: str = "microsoft/resnet-50"
|
||||||
|
@ -15,8 +20,12 @@ class Settings(BaseSettings):
|
||||||
min_face_score: float = 0.7
|
min_face_score: float = 0.7
|
||||||
|
|
||||||
class Config(BaseSettings.Config):
|
class Config(BaseSettings.Config):
|
||||||
env_prefix = 'MACHINE_LEARNING_'
|
env_prefix = "MACHINE_LEARNING_"
|
||||||
case_sensitive = False
|
case_sensitive = False
|
||||||
|
|
||||||
|
|
||||||
|
def get_cache_dir(model_name: str, model_type: ModelType) -> Path:
|
||||||
|
return Path(settings.cache_folder, model_type.value, model_name)
|
||||||
|
|
||||||
|
|
||||||
settings = Settings()
|
settings = Settings()
|
||||||
|
|
|
@ -1,52 +1,58 @@
|
||||||
import os
|
import os
|
||||||
import io
|
from io import BytesIO
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from cache import ModelCache
|
import cv2
|
||||||
from schemas import (
|
import numpy as np
|
||||||
|
import uvicorn
|
||||||
|
from fastapi import Body, Depends, FastAPI
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
|
from .config import settings
|
||||||
|
from .models.base import InferenceModel
|
||||||
|
from .models.cache import ModelCache
|
||||||
|
from .schemas import (
|
||||||
EmbeddingResponse,
|
EmbeddingResponse,
|
||||||
FaceResponse,
|
FaceResponse,
|
||||||
TagResponse,
|
|
||||||
MessageResponse,
|
MessageResponse,
|
||||||
|
ModelType,
|
||||||
|
TagResponse,
|
||||||
TextModelRequest,
|
TextModelRequest,
|
||||||
TextResponse,
|
TextResponse,
|
||||||
)
|
)
|
||||||
import uvicorn
|
|
||||||
from PIL import Image
|
|
||||||
from fastapi import FastAPI, HTTPException, Depends, Body
|
|
||||||
from models import get_model, run_classification, run_facial_recognition
|
|
||||||
from config import settings
|
|
||||||
|
|
||||||
_model_cache = None
|
|
||||||
|
|
||||||
app = FastAPI()
|
app = FastAPI()
|
||||||
|
|
||||||
|
|
||||||
@app.on_event("startup")
|
@app.on_event("startup")
|
||||||
async def startup_event() -> None:
|
async def startup_event() -> None:
|
||||||
global _model_cache
|
app.state.model_cache = ModelCache(ttl=settings.model_ttl, revalidate=True)
|
||||||
_model_cache = ModelCache(ttl=settings.model_ttl, revalidate=True)
|
same_clip = settings.clip_image_model == settings.clip_text_model
|
||||||
|
app.state.clip_vision_type = ModelType.CLIP if same_clip else ModelType.CLIP_VISION
|
||||||
|
app.state.clip_text_type = ModelType.CLIP if same_clip else ModelType.CLIP_TEXT
|
||||||
models = [
|
models = [
|
||||||
(settings.classification_model, "image-classification"),
|
(settings.classification_model, ModelType.IMAGE_CLASSIFICATION),
|
||||||
(settings.clip_image_model, "clip"),
|
(settings.clip_image_model, app.state.clip_vision_type),
|
||||||
(settings.clip_text_model, "clip"),
|
(settings.clip_text_model, app.state.clip_text_type),
|
||||||
(settings.facial_recognition_model, "facial-recognition"),
|
(settings.facial_recognition_model, ModelType.FACIAL_RECOGNITION),
|
||||||
]
|
]
|
||||||
|
|
||||||
# Get all models
|
# Get all models
|
||||||
for model_name, model_type in models:
|
for model_name, model_type in models:
|
||||||
if settings.eager_startup:
|
if settings.eager_startup:
|
||||||
await _model_cache.get_cached_model(model_name, model_type)
|
await app.state.model_cache.get(model_name, model_type)
|
||||||
else:
|
else:
|
||||||
get_model(model_name, model_type)
|
InferenceModel.from_model_type(model_type, model_name)
|
||||||
|
|
||||||
|
|
||||||
def dep_model_cache():
|
def dep_pil_image(byte_image: bytes = Body(...)) -> Image.Image:
|
||||||
if _model_cache is None:
|
return Image.open(BytesIO(byte_image))
|
||||||
raise HTTPException(status_code=500, detail="Unable to load model.")
|
|
||||||
|
|
||||||
|
def dep_cv_image(byte_image: bytes = Body(...)) -> cv2.Mat:
|
||||||
|
byte_image_np = np.frombuffer(byte_image, np.uint8)
|
||||||
|
return cv2.imdecode(byte_image_np, cv2.IMREAD_COLOR)
|
||||||
|
|
||||||
def dep_input_image(image: bytes = Body(...)) -> Image:
|
|
||||||
return Image.open(io.BytesIO(image))
|
|
||||||
|
|
||||||
@app.get("/", response_model=MessageResponse)
|
@app.get("/", response_model=MessageResponse)
|
||||||
async def root() -> dict[str, str]:
|
async def root() -> dict[str, str]:
|
||||||
|
@ -62,19 +68,14 @@ def ping() -> str:
|
||||||
"/image-classifier/tag-image",
|
"/image-classifier/tag-image",
|
||||||
response_model=TagResponse,
|
response_model=TagResponse,
|
||||||
status_code=200,
|
status_code=200,
|
||||||
dependencies=[Depends(dep_model_cache)],
|
|
||||||
)
|
)
|
||||||
async def image_classification(
|
async def image_classification(
|
||||||
image: Image = Depends(dep_input_image)
|
image: Image.Image = Depends(dep_pil_image),
|
||||||
) -> list[str]:
|
) -> list[str]:
|
||||||
try:
|
model = await app.state.model_cache.get(
|
||||||
model = await _model_cache.get_cached_model(
|
settings.classification_model, ModelType.IMAGE_CLASSIFICATION
|
||||||
settings.classification_model, "image-classification"
|
|
||||||
)
|
)
|
||||||
labels = run_classification(model, image, settings.min_tag_score)
|
labels = model.predict(image)
|
||||||
except Exception as ex:
|
|
||||||
raise HTTPException(status_code=500, detail=str(ex))
|
|
||||||
else:
|
|
||||||
return labels
|
return labels
|
||||||
|
|
||||||
|
|
||||||
|
@ -82,13 +83,14 @@ async def image_classification(
|
||||||
"/sentence-transformer/encode-image",
|
"/sentence-transformer/encode-image",
|
||||||
response_model=EmbeddingResponse,
|
response_model=EmbeddingResponse,
|
||||||
status_code=200,
|
status_code=200,
|
||||||
dependencies=[Depends(dep_model_cache)],
|
|
||||||
)
|
)
|
||||||
async def clip_encode_image(
|
async def clip_encode_image(
|
||||||
image: Image = Depends(dep_input_image)
|
image: Image.Image = Depends(dep_pil_image),
|
||||||
) -> list[float]:
|
) -> list[float]:
|
||||||
model = await _model_cache.get_cached_model(settings.clip_image_model, "clip")
|
model = await app.state.model_cache.get(
|
||||||
embedding = model.encode(image).tolist()
|
settings.clip_image_model, app.state.clip_vision_type
|
||||||
|
)
|
||||||
|
embedding = model.predict(image)
|
||||||
return embedding
|
return embedding
|
||||||
|
|
||||||
|
|
||||||
|
@ -96,13 +98,12 @@ async def clip_encode_image(
|
||||||
"/sentence-transformer/encode-text",
|
"/sentence-transformer/encode-text",
|
||||||
response_model=EmbeddingResponse,
|
response_model=EmbeddingResponse,
|
||||||
status_code=200,
|
status_code=200,
|
||||||
dependencies=[Depends(dep_model_cache)],
|
|
||||||
)
|
)
|
||||||
async def clip_encode_text(
|
async def clip_encode_text(payload: TextModelRequest) -> list[float]:
|
||||||
payload: TextModelRequest
|
model = await app.state.model_cache.get(
|
||||||
) -> list[float]:
|
settings.clip_text_model, app.state.clip_text_type
|
||||||
model = await _model_cache.get_cached_model(settings.clip_text_model, "clip")
|
)
|
||||||
embedding = model.encode(payload.text).tolist()
|
embedding = model.predict(payload.text)
|
||||||
return embedding
|
return embedding
|
||||||
|
|
||||||
|
|
||||||
|
@ -110,22 +111,21 @@ async def clip_encode_text(
|
||||||
"/facial-recognition/detect-faces",
|
"/facial-recognition/detect-faces",
|
||||||
response_model=FaceResponse,
|
response_model=FaceResponse,
|
||||||
status_code=200,
|
status_code=200,
|
||||||
dependencies=[Depends(dep_model_cache)],
|
|
||||||
)
|
)
|
||||||
async def facial_recognition(
|
async def facial_recognition(
|
||||||
image: bytes = Body(...),
|
image: cv2.Mat = Depends(dep_cv_image),
|
||||||
) -> list[dict[str, Any]]:
|
) -> list[dict[str, Any]]:
|
||||||
model = await _model_cache.get_cached_model(
|
model = await app.state.model_cache.get(
|
||||||
settings.facial_recognition_model, "facial-recognition"
|
settings.facial_recognition_model, ModelType.FACIAL_RECOGNITION
|
||||||
)
|
)
|
||||||
faces = run_facial_recognition(model, image)
|
faces = model.predict(image)
|
||||||
return faces
|
return faces
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
is_dev = os.getenv("NODE_ENV") == "development"
|
is_dev = os.getenv("NODE_ENV") == "development"
|
||||||
uvicorn.run(
|
uvicorn.run(
|
||||||
"main:app",
|
"app.main:app",
|
||||||
host=settings.host,
|
host=settings.host,
|
||||||
port=settings.port,
|
port=settings.port,
|
||||||
reload=is_dev,
|
reload=is_dev,
|
||||||
|
|
|
@ -1,119 +0,0 @@
|
||||||
import torch
|
|
||||||
from insightface.app import FaceAnalysis
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
from transformers import pipeline, Pipeline
|
|
||||||
from sentence_transformers import SentenceTransformer
|
|
||||||
from typing import Any, BinaryIO
|
|
||||||
import cv2 as cv
|
|
||||||
import numpy as np
|
|
||||||
from PIL import Image
|
|
||||||
from config import settings
|
|
||||||
|
|
||||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
|
||||||
|
|
||||||
|
|
||||||
def get_model(model_name: str, model_type: str, **model_kwargs):
|
|
||||||
"""
|
|
||||||
Instantiates the specified model.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
model_name: Name of model in the model hub used for the task.
|
|
||||||
model_type: Model type or task, which determines which model zoo is used.
|
|
||||||
`facial-recognition` uses Insightface, while all other models use the HF Model Hub.
|
|
||||||
|
|
||||||
Options:
|
|
||||||
`image-classification`, `clip`,`facial-recognition`, `tokenizer`, `processor`
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
model: The requested model.
|
|
||||||
"""
|
|
||||||
|
|
||||||
cache_dir = _get_cache_dir(model_name, model_type)
|
|
||||||
match model_type:
|
|
||||||
case "facial-recognition":
|
|
||||||
model = _load_facial_recognition(
|
|
||||||
model_name, cache_dir=cache_dir, **model_kwargs
|
|
||||||
)
|
|
||||||
case "clip":
|
|
||||||
model = SentenceTransformer(
|
|
||||||
model_name, cache_folder=cache_dir, **model_kwargs
|
|
||||||
)
|
|
||||||
case _:
|
|
||||||
model = pipeline(
|
|
||||||
model_type,
|
|
||||||
model_name,
|
|
||||||
model_kwargs={"cache_dir": cache_dir, **model_kwargs},
|
|
||||||
)
|
|
||||||
|
|
||||||
return model
|
|
||||||
|
|
||||||
|
|
||||||
def run_classification(
|
|
||||||
model: Pipeline, image: Image, min_score: float | None = None
|
|
||||||
):
|
|
||||||
predictions: list[dict[str, Any]] = model(image) # type: ignore
|
|
||||||
result = {
|
|
||||||
tag
|
|
||||||
for pred in predictions
|
|
||||||
for tag in pred["label"].split(", ")
|
|
||||||
if min_score is None or pred["score"] >= min_score
|
|
||||||
}
|
|
||||||
|
|
||||||
return list(result)
|
|
||||||
|
|
||||||
|
|
||||||
def run_facial_recognition(
|
|
||||||
model: FaceAnalysis, image: bytes
|
|
||||||
) -> list[dict[str, Any]]:
|
|
||||||
file_bytes = np.frombuffer(image, dtype=np.uint8)
|
|
||||||
img = cv.imdecode(file_bytes, cv.IMREAD_COLOR)
|
|
||||||
height, width, _ = img.shape
|
|
||||||
results = []
|
|
||||||
faces = model.get(img)
|
|
||||||
|
|
||||||
for face in faces:
|
|
||||||
x1, y1, x2, y2 = face.bbox
|
|
||||||
|
|
||||||
results.append(
|
|
||||||
{
|
|
||||||
"imageWidth": width,
|
|
||||||
"imageHeight": height,
|
|
||||||
"boundingBox": {
|
|
||||||
"x1": round(x1),
|
|
||||||
"y1": round(y1),
|
|
||||||
"x2": round(x2),
|
|
||||||
"y2": round(y2),
|
|
||||||
},
|
|
||||||
"score": face.det_score.item(),
|
|
||||||
"embedding": face.normed_embedding.tolist(),
|
|
||||||
}
|
|
||||||
)
|
|
||||||
return results
|
|
||||||
|
|
||||||
|
|
||||||
def _load_facial_recognition(
|
|
||||||
model_name: str,
|
|
||||||
min_face_score: float | None = None,
|
|
||||||
cache_dir: Path | str | None = None,
|
|
||||||
**model_kwargs,
|
|
||||||
):
|
|
||||||
if cache_dir is None:
|
|
||||||
cache_dir = _get_cache_dir(model_name, "facial-recognition")
|
|
||||||
if isinstance(cache_dir, Path):
|
|
||||||
cache_dir = cache_dir.as_posix()
|
|
||||||
if min_face_score is None:
|
|
||||||
min_face_score = settings.min_face_score
|
|
||||||
|
|
||||||
model = FaceAnalysis(
|
|
||||||
name=model_name,
|
|
||||||
root=cache_dir,
|
|
||||||
allowed_modules=["detection", "recognition"],
|
|
||||||
**model_kwargs,
|
|
||||||
)
|
|
||||||
model.prepare(ctx_id=0, det_thresh=min_face_score, det_size=(640, 640))
|
|
||||||
return model
|
|
||||||
|
|
||||||
|
|
||||||
def _get_cache_dir(model_name: str, model_type: str) -> Path:
|
|
||||||
return Path(settings.cache_folder, device, model_type, model_name)
|
|
3
machine-learning/app/models/__init__.py
Normal file
3
machine-learning/app/models/__init__.py
Normal file
|
@ -0,0 +1,3 @@
|
||||||
|
from .clip import CLIPSTTextEncoder, CLIPSTVisionEncoder
|
||||||
|
from .facial_recognition import FaceRecognizer
|
||||||
|
from .image_classification import ImageClassifier
|
52
machine-learning/app/models/base.py
Normal file
52
machine-learning/app/models/base.py
Normal file
|
@ -0,0 +1,52 @@
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from abc import abstractmethod, ABC
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from ..config import get_cache_dir
|
||||||
|
from ..schemas import ModelType
|
||||||
|
|
||||||
|
|
||||||
|
class InferenceModel(ABC):
|
||||||
|
_model_type: ModelType
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
model_name: str,
|
||||||
|
cache_dir: Path | None = None,
|
||||||
|
):
|
||||||
|
self.model_name = model_name
|
||||||
|
self._cache_dir = (
|
||||||
|
cache_dir
|
||||||
|
if cache_dir is not None
|
||||||
|
else get_cache_dir(model_name, self.model_type)
|
||||||
|
)
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def predict(self, inputs: Any) -> Any:
|
||||||
|
...
|
||||||
|
|
||||||
|
@property
|
||||||
|
def model_type(self) -> ModelType:
|
||||||
|
return self._model_type
|
||||||
|
|
||||||
|
@property
|
||||||
|
def cache_dir(self) -> Path:
|
||||||
|
return self._cache_dir
|
||||||
|
|
||||||
|
@cache_dir.setter
|
||||||
|
def cache_dir(self, cache_dir: Path):
|
||||||
|
self._cache_dir = cache_dir
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_model_type(
|
||||||
|
cls, model_type: ModelType, model_name, **model_kwargs
|
||||||
|
) -> InferenceModel:
|
||||||
|
subclasses = {
|
||||||
|
subclass._model_type: subclass for subclass in cls.__subclasses__()
|
||||||
|
}
|
||||||
|
if model_type not in subclasses:
|
||||||
|
raise ValueError(f"Unsupported model type: {model_type}")
|
||||||
|
|
||||||
|
return subclasses[model_type](model_name, **model_kwargs)
|
|
@ -1,8 +1,11 @@
|
||||||
from aiocache.plugins import TimingPlugin, BasePlugin
|
import asyncio
|
||||||
|
|
||||||
from aiocache.backends.memory import SimpleMemoryCache
|
from aiocache.backends.memory import SimpleMemoryCache
|
||||||
from aiocache.lock import OptimisticLock
|
from aiocache.lock import OptimisticLock
|
||||||
from typing import Any
|
from aiocache.plugins import BasePlugin, TimingPlugin
|
||||||
from models import get_model
|
|
||||||
|
from ..schemas import ModelType
|
||||||
|
from .base import InferenceModel
|
||||||
|
|
||||||
|
|
||||||
class ModelCache:
|
class ModelCache:
|
||||||
|
@ -10,7 +13,7 @@ class ModelCache:
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
ttl: int | None = None,
|
ttl: float | None = None,
|
||||||
revalidate: bool = False,
|
revalidate: bool = False,
|
||||||
timeout: int | None = None,
|
timeout: int | None = None,
|
||||||
profiling: bool = False,
|
profiling: bool = False,
|
||||||
|
@ -35,9 +38,9 @@ class ModelCache:
|
||||||
ttl=ttl, timeout=timeout, plugins=plugins, namespace=None
|
ttl=ttl, timeout=timeout, plugins=plugins, namespace=None
|
||||||
)
|
)
|
||||||
|
|
||||||
async def get_cached_model(
|
async def get(
|
||||||
self, model_name: str, model_type: str, **model_kwargs
|
self, model_name: str, model_type: ModelType, **model_kwargs
|
||||||
) -> Any:
|
) -> InferenceModel:
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
model_name: Name of model in the model hub used for the task.
|
model_name: Name of model in the model hub used for the task.
|
||||||
|
@ -47,11 +50,16 @@ class ModelCache:
|
||||||
model: The requested model.
|
model: The requested model.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
key = self.cache.build_key(model_name, model_type)
|
key = self.cache.build_key(model_name, model_type.value)
|
||||||
model = await self.cache.get(key)
|
model = await self.cache.get(key)
|
||||||
if model is None:
|
if model is None:
|
||||||
async with OptimisticLock(self.cache, key) as lock:
|
async with OptimisticLock(self.cache, key) as lock:
|
||||||
model = get_model(model_name, model_type, **model_kwargs)
|
model = await asyncio.get_running_loop().run_in_executor(
|
||||||
|
None,
|
||||||
|
lambda: InferenceModel.from_model_type(
|
||||||
|
model_type, model_name, **model_kwargs
|
||||||
|
),
|
||||||
|
)
|
||||||
await lock.cas(model, ttl=self.ttl)
|
await lock.cas(model, ttl=self.ttl)
|
||||||
return model
|
return model
|
||||||
|
|
37
machine-learning/app/models/clip.py
Normal file
37
machine-learning/app/models/clip.py
Normal file
|
@ -0,0 +1,37 @@
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from PIL.Image import Image
|
||||||
|
from sentence_transformers import SentenceTransformer
|
||||||
|
|
||||||
|
from ..schemas import ModelType
|
||||||
|
from .base import InferenceModel
|
||||||
|
|
||||||
|
|
||||||
|
class CLIPSTEncoder(InferenceModel):
|
||||||
|
_model_type = ModelType.CLIP
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
model_name: str,
|
||||||
|
cache_dir: Path | None = None,
|
||||||
|
**model_kwargs,
|
||||||
|
):
|
||||||
|
super().__init__(model_name, cache_dir)
|
||||||
|
self.model = SentenceTransformer(
|
||||||
|
self.model_name,
|
||||||
|
cache_folder=self.cache_dir.as_posix(),
|
||||||
|
**model_kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
def predict(self, image_or_text: Image | str) -> list[float]:
|
||||||
|
return self.model.encode(image_or_text).tolist()
|
||||||
|
|
||||||
|
|
||||||
|
# stubs to allow different behavior between the two in the future
|
||||||
|
# and handle loading different image and text clip models
|
||||||
|
class CLIPSTVisionEncoder(CLIPSTEncoder):
|
||||||
|
_model_type = ModelType.CLIP_VISION
|
||||||
|
|
||||||
|
|
||||||
|
class CLIPSTTextEncoder(CLIPSTEncoder):
|
||||||
|
_model_type = ModelType.CLIP_TEXT
|
59
machine-learning/app/models/facial_recognition.py
Normal file
59
machine-learning/app/models/facial_recognition.py
Normal file
|
@ -0,0 +1,59 @@
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import cv2
|
||||||
|
from insightface.app import FaceAnalysis
|
||||||
|
|
||||||
|
from ..config import settings
|
||||||
|
from ..schemas import ModelType
|
||||||
|
from .base import InferenceModel
|
||||||
|
|
||||||
|
|
||||||
|
class FaceRecognizer(InferenceModel):
|
||||||
|
_model_type = ModelType.FACIAL_RECOGNITION
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
model_name: str,
|
||||||
|
min_score: float = settings.min_face_score,
|
||||||
|
cache_dir: Path | None = None,
|
||||||
|
**model_kwargs,
|
||||||
|
):
|
||||||
|
super().__init__(model_name, cache_dir)
|
||||||
|
self.min_score = min_score
|
||||||
|
model = FaceAnalysis(
|
||||||
|
name=self.model_name,
|
||||||
|
root=self.cache_dir.as_posix(),
|
||||||
|
allowed_modules=["detection", "recognition"],
|
||||||
|
**model_kwargs,
|
||||||
|
)
|
||||||
|
model.prepare(
|
||||||
|
ctx_id=0,
|
||||||
|
det_thresh=self.min_score,
|
||||||
|
det_size=(640, 640),
|
||||||
|
)
|
||||||
|
self.model = model
|
||||||
|
|
||||||
|
def predict(self, image: cv2.Mat) -> list[dict[str, Any]]:
|
||||||
|
height, width, _ = image.shape
|
||||||
|
results = []
|
||||||
|
faces = self.model.get(image)
|
||||||
|
|
||||||
|
for face in faces:
|
||||||
|
x1, y1, x2, y2 = face.bbox
|
||||||
|
|
||||||
|
results.append(
|
||||||
|
{
|
||||||
|
"imageWidth": width,
|
||||||
|
"imageHeight": height,
|
||||||
|
"boundingBox": {
|
||||||
|
"x1": round(x1),
|
||||||
|
"y1": round(y1),
|
||||||
|
"x2": round(x2),
|
||||||
|
"y2": round(y2),
|
||||||
|
},
|
||||||
|
"score": face.det_score.item(),
|
||||||
|
"embedding": face.normed_embedding.tolist(),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
return results
|
40
machine-learning/app/models/image_classification.py
Normal file
40
machine-learning/app/models/image_classification.py
Normal file
|
@ -0,0 +1,40 @@
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from PIL.Image import Image
|
||||||
|
from transformers.pipelines import pipeline
|
||||||
|
|
||||||
|
from ..config import settings
|
||||||
|
from ..schemas import ModelType
|
||||||
|
from .base import InferenceModel
|
||||||
|
|
||||||
|
|
||||||
|
class ImageClassifier(InferenceModel):
|
||||||
|
_model_type = ModelType.IMAGE_CLASSIFICATION
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
model_name: str,
|
||||||
|
min_score: float = settings.min_tag_score,
|
||||||
|
cache_dir: Path | None = None,
|
||||||
|
**model_kwargs,
|
||||||
|
):
|
||||||
|
super().__init__(model_name, cache_dir)
|
||||||
|
self.min_score = min_score
|
||||||
|
|
||||||
|
self.model = pipeline(
|
||||||
|
self.model_type.value,
|
||||||
|
self.model_name,
|
||||||
|
model_kwargs={"cache_dir": self.cache_dir, **model_kwargs},
|
||||||
|
)
|
||||||
|
|
||||||
|
def predict(self, image: Image) -> list[str]:
|
||||||
|
predictions = self.model(image)
|
||||||
|
tags = list(
|
||||||
|
{
|
||||||
|
tag
|
||||||
|
for pred in predictions
|
||||||
|
for tag in pred["label"].split(", ")
|
||||||
|
if pred["score"] >= self.min_score
|
||||||
|
}
|
||||||
|
)
|
||||||
|
return tags
|
|
@ -1,3 +1,5 @@
|
||||||
|
from enum import Enum
|
||||||
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
|
||||||
|
@ -54,3 +56,11 @@ class Face(BaseModel):
|
||||||
|
|
||||||
class FaceResponse(BaseModel):
|
class FaceResponse(BaseModel):
|
||||||
__root__: list[Face]
|
__root__: list[Face]
|
||||||
|
|
||||||
|
|
||||||
|
class ModelType(Enum):
|
||||||
|
IMAGE_CLASSIFICATION = "image-classification"
|
||||||
|
CLIP = "clip"
|
||||||
|
CLIP_VISION = "clip-vision"
|
||||||
|
CLIP_TEXT = "clip-text"
|
||||||
|
FACIAL_RECOGNITION = "facial-recognition"
|
||||||
|
|
Loading…
Reference in a new issue