mirror of
https://github.com/immich-app/immich.git
synced 2025-01-16 16:56:46 +01:00
chore(ml): use strict mypy (#5001)
* improved typing * improved export typing * strict mypy & check export folder * formatting * add formatting checks for export folder * re-added init call
This commit is contained in:
parent
9fa9ad05b1
commit
935f471ccb
10 changed files with 70 additions and 55 deletions
6
.github/workflows/test.yml
vendored
6
.github/workflows/test.yml
vendored
|
@ -168,13 +168,13 @@ jobs:
|
|||
poetry install --with dev
|
||||
- name: Lint with ruff
|
||||
run: |
|
||||
poetry run ruff check --format=github app
|
||||
poetry run ruff check --format=github app export
|
||||
- name: Check black formatting
|
||||
run: |
|
||||
poetry run black --check app
|
||||
poetry run black --check app export
|
||||
- name: Run mypy type checking
|
||||
run: |
|
||||
poetry run mypy --install-types --non-interactive app/
|
||||
poetry run mypy --install-types --non-interactive --strict app/ export/
|
||||
- name: Run tests and coverage
|
||||
run: |
|
||||
poetry run pytest --cov app
|
||||
|
|
|
@ -36,7 +36,8 @@ def deployed_app() -> TestClient:
|
|||
|
||||
@pytest.fixture(scope="session")
|
||||
def responses() -> dict[str, Any]:
|
||||
return json.load(open("responses.json", "r"))
|
||||
responses: dict[str, Any] = json.load(open("responses.json", "r"))
|
||||
return responses
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
|
|
|
@ -7,7 +7,7 @@ from zipfile import BadZipFile
|
|||
import orjson
|
||||
from fastapi import FastAPI, Form, HTTPException, UploadFile
|
||||
from fastapi.responses import ORJSONResponse
|
||||
from onnxruntime.capi.onnxruntime_pybind11_state import InvalidProtobuf, NoSuchFile # type: ignore
|
||||
from onnxruntime.capi.onnxruntime_pybind11_state import InvalidProtobuf, NoSuchFile
|
||||
from starlette.formparsers import MultiPartParser
|
||||
|
||||
from app.models.base import InferenceModel
|
||||
|
|
|
@ -8,6 +8,7 @@ from typing import Any
|
|||
|
||||
import onnxruntime as ort
|
||||
from huggingface_hub import snapshot_download
|
||||
from typing_extensions import Buffer
|
||||
|
||||
from ..config import get_cache_dir, get_hf_model_name, log, settings
|
||||
from ..schemas import ModelType
|
||||
|
@ -139,11 +140,12 @@ class InferenceModel(ABC):
|
|||
|
||||
|
||||
# HF deep copies configs, so we need to make session options picklable
|
||||
class PicklableSessionOptions(ort.SessionOptions):
|
||||
class PicklableSessionOptions(ort.SessionOptions): # type: ignore[misc]
|
||||
def __getstate__(self) -> bytes:
|
||||
return pickle.dumps([(attr, getattr(self, attr)) for attr in dir(self) if not callable(getattr(self, attr))])
|
||||
|
||||
def __setstate__(self, state: Any) -> None:
|
||||
self.__init__() # type: ignore
|
||||
for attr, val in pickle.loads(state):
|
||||
def __setstate__(self, state: Buffer) -> None:
|
||||
self.__init__() # type: ignore[misc]
|
||||
attrs: list[tuple[str, Any]] = pickle.loads(state)
|
||||
for attr, val in attrs:
|
||||
setattr(self, attr, val)
|
||||
|
|
|
@ -6,7 +6,7 @@ from aiocache.plugins import BasePlugin, TimingPlugin
|
|||
|
||||
from app.models import from_model_type
|
||||
|
||||
from ..schemas import ModelType
|
||||
from ..schemas import ModelType, has_profiling
|
||||
from .base import InferenceModel
|
||||
|
||||
|
||||
|
@ -50,20 +50,20 @@ class ModelCache:
|
|||
|
||||
key = f"{model_name}{model_type.value}{model_kwargs.get('mode', '')}"
|
||||
async with OptimisticLock(self.cache, key) as lock:
|
||||
model = await self.cache.get(key)
|
||||
model: InferenceModel | None = await self.cache.get(key)
|
||||
if model is None:
|
||||
model = from_model_type(model_type, model_name, **model_kwargs)
|
||||
await lock.cas(model, ttl=self.ttl)
|
||||
return model
|
||||
|
||||
async def get_profiling(self) -> dict[str, float] | None:
|
||||
if not hasattr(self.cache, "profiling"):
|
||||
if not has_profiling(self.cache):
|
||||
return None
|
||||
|
||||
return self.cache.profiling # type: ignore
|
||||
return self.cache.profiling
|
||||
|
||||
|
||||
class RevalidationPlugin(BasePlugin):
|
||||
class RevalidationPlugin(BasePlugin): # type: ignore[misc]
|
||||
"""Revalidates cache item's TTL after cache hit."""
|
||||
|
||||
async def post_get(
|
||||
|
|
|
@ -51,7 +51,7 @@ class BaseCLIPEncoder(InferenceModel):
|
|||
provider_options=self.provider_options,
|
||||
)
|
||||
|
||||
def _predict(self, image_or_text: Image.Image | str) -> list[float]:
|
||||
def _predict(self, image_or_text: Image.Image | str) -> ndarray_f32:
|
||||
if isinstance(image_or_text, bytes):
|
||||
image_or_text = Image.open(BytesIO(image_or_text))
|
||||
|
||||
|
@ -60,16 +60,16 @@ class BaseCLIPEncoder(InferenceModel):
|
|||
if self.mode == "text":
|
||||
raise TypeError("Cannot encode image as text-only model")
|
||||
|
||||
outputs = self.vision_model.run(None, self.transform(image_or_text))
|
||||
outputs: ndarray_f32 = self.vision_model.run(None, self.transform(image_or_text))[0][0]
|
||||
case str():
|
||||
if self.mode == "vision":
|
||||
raise TypeError("Cannot encode text as vision-only model")
|
||||
|
||||
outputs = self.text_model.run(None, self.tokenize(image_or_text))
|
||||
outputs = self.text_model.run(None, self.tokenize(image_or_text))[0][0]
|
||||
case _:
|
||||
raise TypeError(f"Expected Image or str, but got: {type(image_or_text)}")
|
||||
|
||||
return outputs[0][0].tolist()
|
||||
return outputs
|
||||
|
||||
@abstractmethod
|
||||
def tokenize(self, text: str) -> dict[str, ndarray_i32]:
|
||||
|
@ -151,11 +151,13 @@ class OpenCLIPEncoder(BaseCLIPEncoder):
|
|||
|
||||
@cached_property
|
||||
def model_cfg(self) -> dict[str, Any]:
|
||||
return json.load(self.model_cfg_path.open())
|
||||
model_cfg: dict[str, Any] = json.load(self.model_cfg_path.open())
|
||||
return model_cfg
|
||||
|
||||
@cached_property
|
||||
def preprocess_cfg(self) -> dict[str, Any]:
|
||||
return json.load(self.preprocess_cfg_path.open())
|
||||
preprocess_cfg: dict[str, Any] = json.load(self.preprocess_cfg_path.open())
|
||||
return preprocess_cfg
|
||||
|
||||
|
||||
class MCLIPEncoder(OpenCLIPEncoder):
|
||||
|
|
|
@ -8,7 +8,7 @@ from insightface.model_zoo import ArcFaceONNX, RetinaFace
|
|||
from insightface.utils.face_align import norm_crop
|
||||
|
||||
from app.config import clean_name
|
||||
from app.schemas import ModelType, ndarray_f32
|
||||
from app.schemas import BoundingBox, Face, ModelType, ndarray_f32
|
||||
|
||||
from .base import InferenceModel
|
||||
|
||||
|
@ -52,7 +52,7 @@ class FaceRecognizer(InferenceModel):
|
|||
)
|
||||
self.rec_model.prepare(ctx_id=0)
|
||||
|
||||
def _predict(self, image: ndarray_f32 | bytes) -> list[dict[str, Any]]:
|
||||
def _predict(self, image: ndarray_f32 | bytes) -> list[Face]:
|
||||
if isinstance(image, bytes):
|
||||
image = cv2.imdecode(np.frombuffer(image, np.uint8), cv2.IMREAD_COLOR)
|
||||
bboxes, kpss = self.det_model.detect(image)
|
||||
|
@ -67,9 +67,8 @@ class FaceRecognizer(InferenceModel):
|
|||
height, width, _ = image.shape
|
||||
for (x1, y1, x2, y2), score, kps in zip(bboxes, scores, kpss):
|
||||
cropped_img = norm_crop(image, kps)
|
||||
embedding = self.rec_model.get_feat(cropped_img)[0].tolist()
|
||||
results.append(
|
||||
{
|
||||
embedding: ndarray_f32 = self.rec_model.get_feat(cropped_img)[0]
|
||||
face: Face = {
|
||||
"imageWidth": width,
|
||||
"imageHeight": height,
|
||||
"boundingBox": {
|
||||
|
@ -81,7 +80,7 @@ class FaceRecognizer(InferenceModel):
|
|||
"score": score,
|
||||
"embedding": embedding,
|
||||
}
|
||||
)
|
||||
results.append(face)
|
||||
return results
|
||||
|
||||
@property
|
||||
|
|
|
@ -66,7 +66,7 @@ class ImageClassifier(InferenceModel):
|
|||
def _predict(self, image: Image.Image | bytes) -> list[str]:
|
||||
if isinstance(image, bytes):
|
||||
image = Image.open(BytesIO(image))
|
||||
predictions: list[dict[str, Any]] = self.model(image) # type: ignore
|
||||
predictions: list[dict[str, Any]] = self.model(image)
|
||||
tags = [tag for pred in predictions for tag in pred["label"].split(", ") if pred["score"] >= self.min_score]
|
||||
|
||||
return tags
|
||||
|
|
|
@ -1,17 +1,12 @@
|
|||
from enum import StrEnum
|
||||
from typing import TypeAlias
|
||||
from typing import Any, Protocol, TypeAlias, TypedDict, TypeGuard
|
||||
|
||||
import numpy as np
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
def to_lower_camel(string: str) -> str:
|
||||
tokens = [token.capitalize() if i > 0 else token for i, token in enumerate(string.split("_"))]
|
||||
return "".join(tokens)
|
||||
|
||||
|
||||
class TextModelRequest(BaseModel):
|
||||
text: str
|
||||
ndarray_f32: TypeAlias = np.ndarray[int, np.dtype[np.float32]]
|
||||
ndarray_i64: TypeAlias = np.ndarray[int, np.dtype[np.int64]]
|
||||
ndarray_i32: TypeAlias = np.ndarray[int, np.dtype[np.int32]]
|
||||
|
||||
|
||||
class TextResponse(BaseModel):
|
||||
|
@ -22,7 +17,7 @@ class MessageResponse(BaseModel):
|
|||
message: str
|
||||
|
||||
|
||||
class BoundingBox(BaseModel):
|
||||
class BoundingBox(TypedDict):
|
||||
x1: int
|
||||
y1: int
|
||||
x2: int
|
||||
|
@ -35,6 +30,17 @@ class ModelType(StrEnum):
|
|||
FACIAL_RECOGNITION = "facial-recognition"
|
||||
|
||||
|
||||
ndarray_f32: TypeAlias = np.ndarray[int, np.dtype[np.float32]]
|
||||
ndarray_i64: TypeAlias = np.ndarray[int, np.dtype[np.int64]]
|
||||
ndarray_i32: TypeAlias = np.ndarray[int, np.dtype[np.int32]]
|
||||
class HasProfiling(Protocol):
|
||||
profiling: dict[str, float]
|
||||
|
||||
|
||||
class Face(TypedDict):
|
||||
boundingBox: BoundingBox
|
||||
embedding: ndarray_f32
|
||||
imageWidth: int
|
||||
imageHeight: int
|
||||
score: float
|
||||
|
||||
|
||||
def has_profiling(obj: Any) -> TypeGuard[HasProfiling]:
|
||||
return hasattr(obj, "profiling") and type(obj.profiling) == dict
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
import tempfile
|
||||
import warnings
|
||||
from dataclasses import dataclass, field
|
||||
from math import e
|
||||
from pathlib import Path
|
||||
|
||||
import open_clip
|
||||
|
@ -69,10 +70,12 @@ def export_image_encoder(model: open_clip.CLIP, model_cfg: OpenCLIPModelConfig,
|
|||
output_path = Path(output_path)
|
||||
|
||||
def encode_image(image: torch.Tensor) -> torch.Tensor:
|
||||
return model.encode_image(image, normalize=True)
|
||||
output = model.encode_image(image, normalize=True)
|
||||
assert isinstance(output, torch.Tensor)
|
||||
return output
|
||||
|
||||
args = (torch.randn(1, 3, model_cfg.image_size, model_cfg.image_size),)
|
||||
traced = torch.jit.trace(encode_image, args)
|
||||
traced = torch.jit.trace(encode_image, args) # type: ignore[no-untyped-call]
|
||||
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("ignore", UserWarning)
|
||||
|
@ -91,10 +94,12 @@ def export_text_encoder(model: open_clip.CLIP, model_cfg: OpenCLIPModelConfig, o
|
|||
output_path = Path(output_path)
|
||||
|
||||
def encode_text(text: torch.Tensor) -> torch.Tensor:
|
||||
return model.encode_text(text, normalize=True)
|
||||
output = model.encode_text(text, normalize=True)
|
||||
assert isinstance(output, torch.Tensor)
|
||||
return output
|
||||
|
||||
args = (torch.ones(1, model_cfg.sequence_length, dtype=torch.int32),)
|
||||
traced = torch.jit.trace(encode_text, args)
|
||||
traced = torch.jit.trace(encode_text, args) # type: ignore[no-untyped-call]
|
||||
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("ignore", UserWarning)
|
||||
|
|
Loading…
Reference in a new issue