1
0
Fork 0
mirror of https://github.com/immich-app/immich.git synced 2025-01-01 08:31:59 +00:00

chore(ml): use strict mypy (#5001)

* improved typing

* improved export typing

* strict mypy & check export folder

* formatting

* add formatting checks for export folder

* re-added init call
This commit is contained in:
Mert 2023-11-13 11:18:46 -05:00 committed by GitHub
parent 9fa9ad05b1
commit 935f471ccb
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
10 changed files with 70 additions and 55 deletions

View file

@ -168,13 +168,13 @@ jobs:
poetry install --with dev poetry install --with dev
- name: Lint with ruff - name: Lint with ruff
run: | run: |
poetry run ruff check --format=github app poetry run ruff check --format=github app export
- name: Check black formatting - name: Check black formatting
run: | run: |
poetry run black --check app poetry run black --check app export
- name: Run mypy type checking - name: Run mypy type checking
run: | run: |
poetry run mypy --install-types --non-interactive app/ poetry run mypy --install-types --non-interactive --strict app/ export/
- name: Run tests and coverage - name: Run tests and coverage
run: | run: |
poetry run pytest --cov app poetry run pytest --cov app

View file

@ -36,7 +36,8 @@ def deployed_app() -> TestClient:
@pytest.fixture(scope="session") @pytest.fixture(scope="session")
def responses() -> dict[str, Any]: def responses() -> dict[str, Any]:
return json.load(open("responses.json", "r")) responses: dict[str, Any] = json.load(open("responses.json", "r"))
return responses
@pytest.fixture(scope="session") @pytest.fixture(scope="session")

View file

@ -7,7 +7,7 @@ from zipfile import BadZipFile
import orjson import orjson
from fastapi import FastAPI, Form, HTTPException, UploadFile from fastapi import FastAPI, Form, HTTPException, UploadFile
from fastapi.responses import ORJSONResponse from fastapi.responses import ORJSONResponse
from onnxruntime.capi.onnxruntime_pybind11_state import InvalidProtobuf, NoSuchFile # type: ignore from onnxruntime.capi.onnxruntime_pybind11_state import InvalidProtobuf, NoSuchFile
from starlette.formparsers import MultiPartParser from starlette.formparsers import MultiPartParser
from app.models.base import InferenceModel from app.models.base import InferenceModel

View file

@ -8,6 +8,7 @@ from typing import Any
import onnxruntime as ort import onnxruntime as ort
from huggingface_hub import snapshot_download from huggingface_hub import snapshot_download
from typing_extensions import Buffer
from ..config import get_cache_dir, get_hf_model_name, log, settings from ..config import get_cache_dir, get_hf_model_name, log, settings
from ..schemas import ModelType from ..schemas import ModelType
@ -139,11 +140,12 @@ class InferenceModel(ABC):
# HF deep copies configs, so we need to make session options picklable # HF deep copies configs, so we need to make session options picklable
class PicklableSessionOptions(ort.SessionOptions): class PicklableSessionOptions(ort.SessionOptions): # type: ignore[misc]
def __getstate__(self) -> bytes: def __getstate__(self) -> bytes:
return pickle.dumps([(attr, getattr(self, attr)) for attr in dir(self) if not callable(getattr(self, attr))]) return pickle.dumps([(attr, getattr(self, attr)) for attr in dir(self) if not callable(getattr(self, attr))])
def __setstate__(self, state: Any) -> None: def __setstate__(self, state: Buffer) -> None:
self.__init__() # type: ignore self.__init__() # type: ignore[misc]
for attr, val in pickle.loads(state): attrs: list[tuple[str, Any]] = pickle.loads(state)
for attr, val in attrs:
setattr(self, attr, val) setattr(self, attr, val)

View file

@ -6,7 +6,7 @@ from aiocache.plugins import BasePlugin, TimingPlugin
from app.models import from_model_type from app.models import from_model_type
from ..schemas import ModelType from ..schemas import ModelType, has_profiling
from .base import InferenceModel from .base import InferenceModel
@ -50,20 +50,20 @@ class ModelCache:
key = f"{model_name}{model_type.value}{model_kwargs.get('mode', '')}" key = f"{model_name}{model_type.value}{model_kwargs.get('mode', '')}"
async with OptimisticLock(self.cache, key) as lock: async with OptimisticLock(self.cache, key) as lock:
model = await self.cache.get(key) model: InferenceModel | None = await self.cache.get(key)
if model is None: if model is None:
model = from_model_type(model_type, model_name, **model_kwargs) model = from_model_type(model_type, model_name, **model_kwargs)
await lock.cas(model, ttl=self.ttl) await lock.cas(model, ttl=self.ttl)
return model return model
async def get_profiling(self) -> dict[str, float] | None: async def get_profiling(self) -> dict[str, float] | None:
if not hasattr(self.cache, "profiling"): if not has_profiling(self.cache):
return None return None
return self.cache.profiling # type: ignore return self.cache.profiling
class RevalidationPlugin(BasePlugin): class RevalidationPlugin(BasePlugin): # type: ignore[misc]
"""Revalidates cache item's TTL after cache hit.""" """Revalidates cache item's TTL after cache hit."""
async def post_get( async def post_get(

View file

@ -51,7 +51,7 @@ class BaseCLIPEncoder(InferenceModel):
provider_options=self.provider_options, provider_options=self.provider_options,
) )
def _predict(self, image_or_text: Image.Image | str) -> list[float]: def _predict(self, image_or_text: Image.Image | str) -> ndarray_f32:
if isinstance(image_or_text, bytes): if isinstance(image_or_text, bytes):
image_or_text = Image.open(BytesIO(image_or_text)) image_or_text = Image.open(BytesIO(image_or_text))
@ -60,16 +60,16 @@ class BaseCLIPEncoder(InferenceModel):
if self.mode == "text": if self.mode == "text":
raise TypeError("Cannot encode image as text-only model") raise TypeError("Cannot encode image as text-only model")
outputs = self.vision_model.run(None, self.transform(image_or_text)) outputs: ndarray_f32 = self.vision_model.run(None, self.transform(image_or_text))[0][0]
case str(): case str():
if self.mode == "vision": if self.mode == "vision":
raise TypeError("Cannot encode text as vision-only model") raise TypeError("Cannot encode text as vision-only model")
outputs = self.text_model.run(None, self.tokenize(image_or_text)) outputs = self.text_model.run(None, self.tokenize(image_or_text))[0][0]
case _: case _:
raise TypeError(f"Expected Image or str, but got: {type(image_or_text)}") raise TypeError(f"Expected Image or str, but got: {type(image_or_text)}")
return outputs[0][0].tolist() return outputs
@abstractmethod @abstractmethod
def tokenize(self, text: str) -> dict[str, ndarray_i32]: def tokenize(self, text: str) -> dict[str, ndarray_i32]:
@ -151,11 +151,13 @@ class OpenCLIPEncoder(BaseCLIPEncoder):
@cached_property @cached_property
def model_cfg(self) -> dict[str, Any]: def model_cfg(self) -> dict[str, Any]:
return json.load(self.model_cfg_path.open()) model_cfg: dict[str, Any] = json.load(self.model_cfg_path.open())
return model_cfg
@cached_property @cached_property
def preprocess_cfg(self) -> dict[str, Any]: def preprocess_cfg(self) -> dict[str, Any]:
return json.load(self.preprocess_cfg_path.open()) preprocess_cfg: dict[str, Any] = json.load(self.preprocess_cfg_path.open())
return preprocess_cfg
class MCLIPEncoder(OpenCLIPEncoder): class MCLIPEncoder(OpenCLIPEncoder):

View file

@ -8,7 +8,7 @@ from insightface.model_zoo import ArcFaceONNX, RetinaFace
from insightface.utils.face_align import norm_crop from insightface.utils.face_align import norm_crop
from app.config import clean_name from app.config import clean_name
from app.schemas import ModelType, ndarray_f32 from app.schemas import BoundingBox, Face, ModelType, ndarray_f32
from .base import InferenceModel from .base import InferenceModel
@ -52,7 +52,7 @@ class FaceRecognizer(InferenceModel):
) )
self.rec_model.prepare(ctx_id=0) self.rec_model.prepare(ctx_id=0)
def _predict(self, image: ndarray_f32 | bytes) -> list[dict[str, Any]]: def _predict(self, image: ndarray_f32 | bytes) -> list[Face]:
if isinstance(image, bytes): if isinstance(image, bytes):
image = cv2.imdecode(np.frombuffer(image, np.uint8), cv2.IMREAD_COLOR) image = cv2.imdecode(np.frombuffer(image, np.uint8), cv2.IMREAD_COLOR)
bboxes, kpss = self.det_model.detect(image) bboxes, kpss = self.det_model.detect(image)
@ -67,9 +67,8 @@ class FaceRecognizer(InferenceModel):
height, width, _ = image.shape height, width, _ = image.shape
for (x1, y1, x2, y2), score, kps in zip(bboxes, scores, kpss): for (x1, y1, x2, y2), score, kps in zip(bboxes, scores, kpss):
cropped_img = norm_crop(image, kps) cropped_img = norm_crop(image, kps)
embedding = self.rec_model.get_feat(cropped_img)[0].tolist() embedding: ndarray_f32 = self.rec_model.get_feat(cropped_img)[0]
results.append( face: Face = {
{
"imageWidth": width, "imageWidth": width,
"imageHeight": height, "imageHeight": height,
"boundingBox": { "boundingBox": {
@ -81,7 +80,7 @@ class FaceRecognizer(InferenceModel):
"score": score, "score": score,
"embedding": embedding, "embedding": embedding,
} }
) results.append(face)
return results return results
@property @property

View file

@ -66,7 +66,7 @@ class ImageClassifier(InferenceModel):
def _predict(self, image: Image.Image | bytes) -> list[str]: def _predict(self, image: Image.Image | bytes) -> list[str]:
if isinstance(image, bytes): if isinstance(image, bytes):
image = Image.open(BytesIO(image)) image = Image.open(BytesIO(image))
predictions: list[dict[str, Any]] = self.model(image) # type: ignore predictions: list[dict[str, Any]] = self.model(image)
tags = [tag for pred in predictions for tag in pred["label"].split(", ") if pred["score"] >= self.min_score] tags = [tag for pred in predictions for tag in pred["label"].split(", ") if pred["score"] >= self.min_score]
return tags return tags

View file

@ -1,17 +1,12 @@
from enum import StrEnum from enum import StrEnum
from typing import TypeAlias from typing import Any, Protocol, TypeAlias, TypedDict, TypeGuard
import numpy as np import numpy as np
from pydantic import BaseModel from pydantic import BaseModel
ndarray_f32: TypeAlias = np.ndarray[int, np.dtype[np.float32]]
def to_lower_camel(string: str) -> str: ndarray_i64: TypeAlias = np.ndarray[int, np.dtype[np.int64]]
tokens = [token.capitalize() if i > 0 else token for i, token in enumerate(string.split("_"))] ndarray_i32: TypeAlias = np.ndarray[int, np.dtype[np.int32]]
return "".join(tokens)
class TextModelRequest(BaseModel):
text: str
class TextResponse(BaseModel): class TextResponse(BaseModel):
@ -22,7 +17,7 @@ class MessageResponse(BaseModel):
message: str message: str
class BoundingBox(BaseModel): class BoundingBox(TypedDict):
x1: int x1: int
y1: int y1: int
x2: int x2: int
@ -35,6 +30,17 @@ class ModelType(StrEnum):
FACIAL_RECOGNITION = "facial-recognition" FACIAL_RECOGNITION = "facial-recognition"
ndarray_f32: TypeAlias = np.ndarray[int, np.dtype[np.float32]] class HasProfiling(Protocol):
ndarray_i64: TypeAlias = np.ndarray[int, np.dtype[np.int64]] profiling: dict[str, float]
ndarray_i32: TypeAlias = np.ndarray[int, np.dtype[np.int32]]
class Face(TypedDict):
boundingBox: BoundingBox
embedding: ndarray_f32
imageWidth: int
imageHeight: int
score: float
def has_profiling(obj: Any) -> TypeGuard[HasProfiling]:
return hasattr(obj, "profiling") and type(obj.profiling) == dict

View file

@ -1,6 +1,7 @@
import tempfile import tempfile
import warnings import warnings
from dataclasses import dataclass, field from dataclasses import dataclass, field
from math import e
from pathlib import Path from pathlib import Path
import open_clip import open_clip
@ -69,10 +70,12 @@ def export_image_encoder(model: open_clip.CLIP, model_cfg: OpenCLIPModelConfig,
output_path = Path(output_path) output_path = Path(output_path)
def encode_image(image: torch.Tensor) -> torch.Tensor: def encode_image(image: torch.Tensor) -> torch.Tensor:
return model.encode_image(image, normalize=True) output = model.encode_image(image, normalize=True)
assert isinstance(output, torch.Tensor)
return output
args = (torch.randn(1, 3, model_cfg.image_size, model_cfg.image_size),) args = (torch.randn(1, 3, model_cfg.image_size, model_cfg.image_size),)
traced = torch.jit.trace(encode_image, args) traced = torch.jit.trace(encode_image, args) # type: ignore[no-untyped-call]
with warnings.catch_warnings(): with warnings.catch_warnings():
warnings.simplefilter("ignore", UserWarning) warnings.simplefilter("ignore", UserWarning)
@ -91,10 +94,12 @@ def export_text_encoder(model: open_clip.CLIP, model_cfg: OpenCLIPModelConfig, o
output_path = Path(output_path) output_path = Path(output_path)
def encode_text(text: torch.Tensor) -> torch.Tensor: def encode_text(text: torch.Tensor) -> torch.Tensor:
return model.encode_text(text, normalize=True) output = model.encode_text(text, normalize=True)
assert isinstance(output, torch.Tensor)
return output
args = (torch.ones(1, model_cfg.sequence_length, dtype=torch.int32),) args = (torch.ones(1, model_cfg.sequence_length, dtype=torch.int32),)
traced = torch.jit.trace(encode_text, args) traced = torch.jit.trace(encode_text, args) # type: ignore[no-untyped-call]
with warnings.catch_warnings(): with warnings.catch_warnings():
warnings.simplefilter("ignore", UserWarning) warnings.simplefilter("ignore", UserWarning)