mirror of https://github.com/immich-app/immich.git
chore(ml): use strict mypy (#5001)
* improved typing
* improved export typing
* strict mypy & check export folder
* formatting
* add formatting checks for export folder
* re-added init call
parent 9fa9ad05b1
commit 935f471ccb
10 changed files with 70 additions and 55 deletions
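
The heart of the change is the new --strict flag, which turns on checks such as disallow_untyped_defs and warn_return_any. A minimal sketch of what that enforces, and of the annotate-then-return fix used throughout this diff (illustrative code, not taken from the commit):

import json
from typing import Any


def load_untyped(path):  # fails under --strict: function is missing type annotations
    return json.load(open(path))


def load_typed(path: str) -> dict[str, Any]:
    # json.load returns Any; returning it directly trips warn_return_any,
    # so the value is annotated first and then returned
    data: dict[str, Any] = json.load(open(path))
    return data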
.github/workflows/test.yml (vendored) | 6
@@ -168,13 +168,13 @@ jobs:
           poetry install --with dev
       - name: Lint with ruff
         run: |
-          poetry run ruff check --format=github app
+          poetry run ruff check --format=github app export
       - name: Check black formatting
         run: |
-          poetry run black --check app
+          poetry run black --check app export
       - name: Run mypy type checking
         run: |
-          poetry run mypy --install-types --non-interactive app/
+          poetry run mypy --install-types --non-interactive --strict app/ export/
       - name: Run tests and coverage
         run: |
           poetry run pytest --cov app
@@ -36,7 +36,8 @@ def deployed_app() -> TestClient:
 
 @pytest.fixture(scope="session")
 def responses() -> dict[str, Any]:
-    return json.load(open("responses.json", "r"))
+    responses: dict[str, Any] = json.load(open("responses.json", "r"))
+    return responses
 
 
 @pytest.fixture(scope="session")
@@ -7,7 +7,7 @@ from zipfile import BadZipFile
 import orjson
 from fastapi import FastAPI, Form, HTTPException, UploadFile
 from fastapi.responses import ORJSONResponse
-from onnxruntime.capi.onnxruntime_pybind11_state import InvalidProtobuf, NoSuchFile  # type: ignore
+from onnxruntime.capi.onnxruntime_pybind11_state import InvalidProtobuf, NoSuchFile
 from starlette.formparsers import MultiPartParser
 
 from app.models.base import InferenceModel
@@ -8,6 +8,7 @@ from typing import Any
 
 import onnxruntime as ort
 from huggingface_hub import snapshot_download
+from typing_extensions import Buffer
 
 from ..config import get_cache_dir, get_hf_model_name, log, settings
 from ..schemas import ModelType
@@ -139,11 +140,12 @@ class InferenceModel(ABC):
 
 
 # HF deep copies configs, so we need to make session options picklable
-class PicklableSessionOptions(ort.SessionOptions):
+class PicklableSessionOptions(ort.SessionOptions):  # type: ignore[misc]
     def __getstate__(self) -> bytes:
         return pickle.dumps([(attr, getattr(self, attr)) for attr in dir(self) if not callable(getattr(self, attr))])
 
-    def __setstate__(self, state: Any) -> None:
-        self.__init__()  # type: ignore
-        for attr, val in pickle.loads(state):
+    def __setstate__(self, state: Buffer) -> None:
+        self.__init__()  # type: ignore[misc]
+        attrs: list[tuple[str, Any]] = pickle.loads(state)
+        for attr, val in attrs:
             setattr(self, attr, val)
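
A note on why this class exists: ort.SessionOptions is a pybind11 object with no default pickle support, and Hugging Face deep-copies configs (see the comment above). A hypothetical round-trip, assuming PicklableSessionOptions is imported from app.models.base:

import pickle

from app.models.base import PicklableSessionOptions

opts = PicklableSessionOptions()
opts.intra_op_num_threads = 2  # a regular ort.SessionOptions attribute

# __getstate__ serializes (name, value) pairs for the non-callable attributes;
# __setstate__ builds a fresh SessionOptions and restores them
restored = pickle.loads(pickle.dumps(opts))
assert restored.intra_op_num_threads == 2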
@@ -6,7 +6,7 @@ from aiocache.plugins import BasePlugin, TimingPlugin
 
 from app.models import from_model_type
 
-from ..schemas import ModelType
+from ..schemas import ModelType, has_profiling
 from .base import InferenceModel
 
 
@@ -50,20 +50,20 @@ class ModelCache:
 
         key = f"{model_name}{model_type.value}{model_kwargs.get('mode', '')}"
         async with OptimisticLock(self.cache, key) as lock:
-            model = await self.cache.get(key)
+            model: InferenceModel | None = await self.cache.get(key)
             if model is None:
                 model = from_model_type(model_type, model_name, **model_kwargs)
                 await lock.cas(model, ttl=self.ttl)
         return model
 
     async def get_profiling(self) -> dict[str, float] | None:
-        if not hasattr(self.cache, "profiling"):
+        if not has_profiling(self.cache):
             return None
 
-        return self.cache.profiling  # type: ignore
+        return self.cache.profiling
 
 
-class RevalidationPlugin(BasePlugin):
+class RevalidationPlugin(BasePlugin):  # type: ignore[misc]
     """Revalidates cache item's TTL after cache hit."""
 
     async def post_get(
@@ -51,7 +51,7 @@ class BaseCLIPEncoder(InferenceModel):
             provider_options=self.provider_options,
         )
 
-    def _predict(self, image_or_text: Image.Image | str) -> list[float]:
+    def _predict(self, image_or_text: Image.Image | str) -> ndarray_f32:
         if isinstance(image_or_text, bytes):
             image_or_text = Image.open(BytesIO(image_or_text))
 
@@ -60,16 +60,16 @@ class BaseCLIPEncoder(InferenceModel):
                 if self.mode == "text":
                     raise TypeError("Cannot encode image as text-only model")
 
-                outputs = self.vision_model.run(None, self.transform(image_or_text))
+                outputs: ndarray_f32 = self.vision_model.run(None, self.transform(image_or_text))[0][0]
             case str():
                 if self.mode == "vision":
                     raise TypeError("Cannot encode text as vision-only model")
 
-                outputs = self.text_model.run(None, self.tokenize(image_or_text))
+                outputs = self.text_model.run(None, self.tokenize(image_or_text))[0][0]
             case _:
                 raise TypeError(f"Expected Image or str, but got: {type(image_or_text)}")
 
-        return outputs[0][0].tolist()
+        return outputs
 
     @abstractmethod
     def tokenize(self, text: str) -> dict[str, ndarray_i32]:
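
The _predict signature change above leans on the ndarray_f32 alias now exported from app.schemas: by parameterizing np.ndarray over its dtype, mypy can flag dtype mismatches that a bare np.ndarray annotation would miss. An illustrative sketch, not from the commit:

import numpy as np
from typing import TypeAlias

ndarray_f32: TypeAlias = np.ndarray[int, np.dtype[np.float32]]


def embedding() -> ndarray_f32:
    return np.zeros(512, dtype=np.float32)  # accepted


def bad_embedding() -> ndarray_f32:
    return np.zeros(512, dtype=np.int64)  # rejected by mypy: incompatible dtype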
@@ -151,11 +151,13 @@ class OpenCLIPEncoder(BaseCLIPEncoder):
 
     @cached_property
     def model_cfg(self) -> dict[str, Any]:
-        return json.load(self.model_cfg_path.open())
+        model_cfg: dict[str, Any] = json.load(self.model_cfg_path.open())
+        return model_cfg
 
     @cached_property
     def preprocess_cfg(self) -> dict[str, Any]:
-        return json.load(self.preprocess_cfg_path.open())
+        preprocess_cfg: dict[str, Any] = json.load(self.preprocess_cfg_path.open())
+        return preprocess_cfg
 
 
 class MCLIPEncoder(OpenCLIPEncoder):
@@ -8,7 +8,7 @@ from insightface.model_zoo import ArcFaceONNX, RetinaFace
 from insightface.utils.face_align import norm_crop
 
 from app.config import clean_name
-from app.schemas import ModelType, ndarray_f32
+from app.schemas import BoundingBox, Face, ModelType, ndarray_f32
 
 from .base import InferenceModel
 
@@ -52,7 +52,7 @@ class FaceRecognizer(InferenceModel):
         )
         self.rec_model.prepare(ctx_id=0)
 
-    def _predict(self, image: ndarray_f32 | bytes) -> list[dict[str, Any]]:
+    def _predict(self, image: ndarray_f32 | bytes) -> list[Face]:
         if isinstance(image, bytes):
             image = cv2.imdecode(np.frombuffer(image, np.uint8), cv2.IMREAD_COLOR)
         bboxes, kpss = self.det_model.detect(image)
@@ -67,21 +67,20 @@ class FaceRecognizer(InferenceModel):
         height, width, _ = image.shape
         for (x1, y1, x2, y2), score, kps in zip(bboxes, scores, kpss):
             cropped_img = norm_crop(image, kps)
-            embedding = self.rec_model.get_feat(cropped_img)[0].tolist()
-            results.append(
-                {
-                    "imageWidth": width,
-                    "imageHeight": height,
-                    "boundingBox": {
-                        "x1": x1,
-                        "y1": y1,
-                        "x2": x2,
-                        "y2": y2,
-                    },
-                    "score": score,
-                    "embedding": embedding,
-                }
-            )
+            embedding: ndarray_f32 = self.rec_model.get_feat(cropped_img)[0]
+            face: Face = {
+                "imageWidth": width,
+                "imageHeight": height,
+                "boundingBox": {
+                    "x1": x1,
+                    "y1": y1,
+                    "x2": x2,
+                    "y2": y2,
+                },
+                "score": score,
+                "embedding": embedding,
+            }
+            results.append(face)
         return results
 
     @property
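
With Face defined as a TypedDict (see the schema changes below), the face literal above is checked key-by-key rather than treated as dict[str, Any]. A small sketch, assuming Face from app.schemas and a 512-dimensional embedding:

import numpy as np

from app.schemas import Face

face: Face = {
    "imageWidth": 640,
    "imageHeight": 480,
    "boundingBox": {"x1": 0, "y1": 0, "x2": 10, "y2": 10},
    "score": 0.99,
    "embedding": np.zeros(512, dtype=np.float32),
}
# mypy rejects misspelled keys, missing keys, and wrongly typed values here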
@@ -66,7 +66,7 @@ class ImageClassifier(InferenceModel):
     def _predict(self, image: Image.Image | bytes) -> list[str]:
         if isinstance(image, bytes):
             image = Image.open(BytesIO(image))
-        predictions: list[dict[str, Any]] = self.model(image)  # type: ignore
+        predictions: list[dict[str, Any]] = self.model(image)
         tags = [tag for pred in predictions for tag in pred["label"].split(", ") if pred["score"] >= self.min_score]
 
         return tags
@@ -1,17 +1,12 @@
 from enum import StrEnum
-from typing import TypeAlias
+from typing import Any, Protocol, TypeAlias, TypedDict, TypeGuard
 
 import numpy as np
 from pydantic import BaseModel
 
-
-def to_lower_camel(string: str) -> str:
-    tokens = [token.capitalize() if i > 0 else token for i, token in enumerate(string.split("_"))]
-    return "".join(tokens)
-
-
-class TextModelRequest(BaseModel):
-    text: str
+ndarray_f32: TypeAlias = np.ndarray[int, np.dtype[np.float32]]
+ndarray_i64: TypeAlias = np.ndarray[int, np.dtype[np.int64]]
+ndarray_i32: TypeAlias = np.ndarray[int, np.dtype[np.int32]]
 
 
 class TextResponse(BaseModel):
@@ -22,7 +17,7 @@ class MessageResponse(BaseModel):
     message: str
 
 
-class BoundingBox(BaseModel):
+class BoundingBox(TypedDict):
     x1: int
     y1: int
     x2: int
|
@ -35,6 +30,17 @@ class ModelType(StrEnum):
|
||||||
FACIAL_RECOGNITION = "facial-recognition"
|
FACIAL_RECOGNITION = "facial-recognition"
|
||||||
|
|
||||||
|
|
||||||
ndarray_f32: TypeAlias = np.ndarray[int, np.dtype[np.float32]]
|
class HasProfiling(Protocol):
|
||||||
ndarray_i64: TypeAlias = np.ndarray[int, np.dtype[np.int64]]
|
profiling: dict[str, float]
|
||||||
ndarray_i32: TypeAlias = np.ndarray[int, np.dtype[np.int32]]
|
|
||||||
|
|
||||||
|
class Face(TypedDict):
|
||||||
|
boundingBox: BoundingBox
|
||||||
|
embedding: ndarray_f32
|
||||||
|
imageWidth: int
|
||||||
|
imageHeight: int
|
||||||
|
score: float
|
||||||
|
|
||||||
|
|
||||||
|
def has_profiling(obj: Any) -> TypeGuard[HasProfiling]:
|
||||||
|
return hasattr(obj, "profiling") and type(obj.profiling) == dict
|
||||||
|
|
|
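
HasProfiling and has_profiling together replace the hasattr check plus type: ignore that the cache previously needed: after has_profiling returns True, mypy narrows the argument to HasProfiling. Roughly how ModelCache.get_profiling (above) uses it, as a standalone sketch:

from typing import Any

from app.schemas import has_profiling


def get_profiling(cache: Any) -> dict[str, float] | None:
    if not has_profiling(cache):
        return None
    # cache is narrowed to HasProfiling, so the attribute access type-checks
    return cache.profiling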
@@ -1,6 +1,7 @@
 import tempfile
 import warnings
 from dataclasses import dataclass, field
+from math import e
 from pathlib import Path
 
 import open_clip
@@ -69,10 +70,12 @@ def export_image_encoder(model: open_clip.CLIP, model_cfg: OpenCLIPModelConfig,
     output_path = Path(output_path)
 
     def encode_image(image: torch.Tensor) -> torch.Tensor:
-        return model.encode_image(image, normalize=True)
+        output = model.encode_image(image, normalize=True)
+        assert isinstance(output, torch.Tensor)
+        return output
 
     args = (torch.randn(1, 3, model_cfg.image_size, model_cfg.image_size),)
-    traced = torch.jit.trace(encode_image, args)
+    traced = torch.jit.trace(encode_image, args)  # type: ignore[no-untyped-call]
 
     with warnings.catch_warnings():
         warnings.simplefilter("ignore", UserWarning)
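
Two things happen in this export hunk (and in the matching text-encoder hunk below): open_clip's encode methods appear untyped to mypy, so their result is Any, and the assert isinstance both narrows it to torch.Tensor and validates it at trace time; torch.jit.trace is itself untyped, hence the targeted ignore code. The same narrowing idiom in isolation (illustrative sketch):

import torch
import torch.nn as nn


def call_and_narrow(module: nn.Module, x: torch.Tensor) -> torch.Tensor:
    output = module(x)  # nn.Module.__call__ is typed to return Any
    assert isinstance(output, torch.Tensor)  # runtime check doubles as narrowing
    return output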
|
@ -91,10 +94,12 @@ def export_text_encoder(model: open_clip.CLIP, model_cfg: OpenCLIPModelConfig, o
|
||||||
output_path = Path(output_path)
|
output_path = Path(output_path)
|
||||||
|
|
||||||
def encode_text(text: torch.Tensor) -> torch.Tensor:
|
def encode_text(text: torch.Tensor) -> torch.Tensor:
|
||||||
return model.encode_text(text, normalize=True)
|
output = model.encode_text(text, normalize=True)
|
||||||
|
assert isinstance(output, torch.Tensor)
|
||||||
|
return output
|
||||||
|
|
||||||
args = (torch.ones(1, model_cfg.sequence_length, dtype=torch.int32),)
|
args = (torch.ones(1, model_cfg.sequence_length, dtype=torch.int32),)
|
||||||
traced = torch.jit.trace(encode_text, args)
|
traced = torch.jit.trace(encode_text, args) # type: ignore[no-untyped-call]
|
||||||
|
|
||||||
with warnings.catch_warnings():
|
with warnings.catch_warnings():
|
||||||
warnings.simplefilter("ignore", UserWarning)
|
warnings.simplefilter("ignore", UserWarning)
|
||||||
|
|