mirror of
https://github.com/immich-app/immich.git
synced 2025-01-23 04:02:45 +01:00
feat(ml): improve test coverage (#7041)
* update e2e
* tokenizer tests
* more tests, remove unnecessary code
* fix e2e setting
* add tests for loading model
* update workflow
* fixed test
This commit is contained in:
parent
6e853e2a9d
commit
0c4df216d7
8 changed files with 501 additions and 1636 deletions
.github/workflows
machine-learning
2
.github/workflows/test.yml
vendored
2
.github/workflows/test.yml
vendored
|
@ -247,7 +247,7 @@ jobs:
|
||||||
poetry run mypy --install-types --non-interactive --strict app/
|
poetry run mypy --install-types --non-interactive --strict app/
|
||||||
- name: Run tests and coverage
|
- name: Run tests and coverage
|
||||||
run: |
|
run: |
|
||||||
poetry run pytest --cov app
|
poetry run pytest app --cov=app --cov-report term-missing
|
||||||
|
|
||||||
generated-api-up-to-date:
|
generated-api-up-to-date:
|
||||||
name: OpenAPI Clients
|
name: OpenAPI Clients
|
||||||
|
|
|
@ -119,16 +119,12 @@ async def load(model: InferenceModel) -> InferenceModel:
|
||||||
if model.loaded:
|
if model.loaded:
|
||||||
return model
|
return model
|
||||||
|
|
||||||
def _load() -> None:
|
def _load(model: InferenceModel) -> None:
|
||||||
with lock:
|
with lock:
|
||||||
model.load()
|
model.load()
|
||||||
|
|
||||||
loop = asyncio.get_running_loop()
|
|
||||||
try:
|
try:
|
||||||
if thread_pool is None:
|
await run(_load, model)
|
||||||
model.load()
|
|
||||||
else:
|
|
||||||
await loop.run_in_executor(thread_pool, _load)
|
|
||||||
return model
|
return model
|
||||||
except (OSError, InvalidProtobuf, BadZipFile, NoSuchFile):
|
except (OSError, InvalidProtobuf, BadZipFile, NoSuchFile):
|
||||||
log.warning(
|
log.warning(
|
||||||
|
@ -138,10 +134,7 @@ async def load(model: InferenceModel) -> InferenceModel:
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
model.clear_cache()
|
model.clear_cache()
|
||||||
if thread_pool is None:
|
await run(_load, model)
|
||||||
model.load()
|
|
||||||
else:
|
|
||||||
await loop.run_in_executor(thread_pool, _load)
|
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -21,4 +21,4 @@ def from_model_type(model_type: ModelType, model_name: str, **model_kwargs: Any)
|
||||||
case _:
|
case _:
|
||||||
raise ValueError(f"Unknown model type {model_type}")
|
raise ValueError(f"Unknown model type {model_type}")
|
||||||
|
|
||||||
raise ValueError(f"Unknown ${model_type} model {model_name}")
|
raise ValueError(f"Unknown {model_type} model {model_name}")
|
||||||
|
|
|
@ -1,6 +1,5 @@
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import pickle
|
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from shutil import rmtree
|
from shutil import rmtree
|
||||||
|
@ -11,7 +10,6 @@ import onnxruntime as ort
|
||||||
from huggingface_hub import snapshot_download
|
from huggingface_hub import snapshot_download
|
||||||
from onnx.shape_inference import infer_shapes
|
from onnx.shape_inference import infer_shapes
|
||||||
from onnx.tools.update_model_dims import update_inputs_outputs_dims
|
from onnx.tools.update_model_dims import update_inputs_outputs_dims
|
||||||
from typing_extensions import Buffer
|
|
||||||
|
|
||||||
import ann.ann
|
import ann.ann
|
||||||
from app.models.constants import STATIC_INPUT_PROVIDERS, SUPPORTED_PROVIDERS
|
from app.models.constants import STATIC_INPUT_PROVIDERS, SUPPORTED_PROVIDERS
|
||||||
|
@ -200,7 +198,7 @@ class InferenceModel(ABC):
|
||||||
|
|
||||||
@providers.setter
|
@providers.setter
|
||||||
def providers(self, providers: list[str]) -> None:
|
def providers(self, providers: list[str]) -> None:
|
||||||
log.debug(
|
log.info(
|
||||||
(f"Setting '{self.model_name}' execution providers to {providers}, " "in descending order of preference"),
|
(f"Setting '{self.model_name}' execution providers to {providers}, " "in descending order of preference"),
|
||||||
)
|
)
|
||||||
self._providers = providers
|
self._providers = providers
|
||||||
|
@ -217,7 +215,7 @@ class InferenceModel(ABC):
|
||||||
|
|
||||||
@provider_options.setter
|
@provider_options.setter
|
||||||
def provider_options(self, provider_options: list[dict[str, Any]]) -> None:
|
def provider_options(self, provider_options: list[dict[str, Any]]) -> None:
|
||||||
log.info(f"Setting execution provider options to {provider_options}")
|
log.debug(f"Setting execution provider options to {provider_options}")
|
||||||
self._provider_options = provider_options
|
self._provider_options = provider_options
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
@ -255,7 +253,7 @@ class InferenceModel(ABC):
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def sess_options_default(self) -> ort.SessionOptions:
|
def sess_options_default(self) -> ort.SessionOptions:
|
||||||
sess_options = PicklableSessionOptions()
|
sess_options = ort.SessionOptions()
|
||||||
sess_options.enable_cpu_mem_arena = False
|
sess_options.enable_cpu_mem_arena = False
|
||||||
|
|
||||||
# avoid thread contention between models
|
# avoid thread contention between models
|
||||||
|
@ -287,15 +285,3 @@ class InferenceModel(ABC):
|
||||||
@property
|
@property
|
||||||
def preferred_runtime_default(self) -> ModelRuntime:
|
def preferred_runtime_default(self) -> ModelRuntime:
|
||||||
return ModelRuntime.ARMNN if ann.ann.is_available and settings.ann else ModelRuntime.ONNX
|
return ModelRuntime.ARMNN if ann.ann.is_available and settings.ann else ModelRuntime.ONNX
|
||||||
|
|
||||||
|
|
||||||
# HF deep copies configs, so we need to make session options picklable
|
|
||||||
class PicklableSessionOptions(ort.SessionOptions): # type: ignore[misc]
|
|
||||||
def __getstate__(self) -> bytes:
|
|
||||||
return pickle.dumps([(attr, getattr(self, attr)) for attr in dir(self) if not callable(getattr(self, attr))])
|
|
||||||
|
|
||||||
def __setstate__(self, state: Buffer) -> None:
|
|
||||||
self.__init__() # type: ignore[misc]
|
|
||||||
attrs: list[tuple[str, Any]] = pickle.loads(state)
|
|
||||||
for attr, val in attrs:
|
|
||||||
setattr(self, attr, val)
|
|
||||||
|
|
|
@ -80,20 +80,3 @@ class RevalidationPlugin(BasePlugin): # type: ignore[misc]
|
||||||
key = client.build_key(key, namespace)
|
key = client.build_key(key, namespace)
|
||||||
if key in client._handlers:
|
if key in client._handlers:
|
||||||
await client.expire(key, client.ttl)
|
await client.expire(key, client.ttl)
|
||||||
|
|
||||||
async def post_multi_get(
|
|
||||||
self,
|
|
||||||
client: SimpleMemoryCache,
|
|
||||||
keys: list[str],
|
|
||||||
ret: list[Any] | None = None,
|
|
||||||
namespace: str | None = None,
|
|
||||||
**kwargs: Any,
|
|
||||||
) -> None:
|
|
||||||
if ret is None:
|
|
||||||
return
|
|
||||||
|
|
||||||
for key, val in zip(keys, ret):
|
|
||||||
if namespace is not None:
|
|
||||||
key = client.build_key(key, namespace)
|
|
||||||
if val is not None and key in client._handlers:
|
|
||||||
await client.expire(key, client.ttl)
|
|
||||||
|
|
|
@ -144,9 +144,7 @@ class OpenCLIPEncoder(BaseCLIPEncoder):
|
||||||
|
|
||||||
def _load(self) -> None:
|
def _load(self) -> None:
|
||||||
super()._load()
|
super()._load()
|
||||||
text_cfg: dict[str, Any] = self.model_cfg["text_cfg"]
|
self._load_tokenizer()
|
||||||
context_length: int = text_cfg.get("context_length", 77)
|
|
||||||
pad_token: int = self.tokenizer_cfg["pad_token"]
|
|
||||||
|
|
||||||
size: list[int] | int = self.preprocess_cfg["size"]
|
size: list[int] | int = self.preprocess_cfg["size"]
|
||||||
self.size = size[0] if isinstance(size, list) else size
|
self.size = size[0] if isinstance(size, list) else size
|
||||||
|
@ -155,11 +153,19 @@ class OpenCLIPEncoder(BaseCLIPEncoder):
|
||||||
self.mean = np.array(self.preprocess_cfg["mean"], dtype=np.float32)
|
self.mean = np.array(self.preprocess_cfg["mean"], dtype=np.float32)
|
||||||
self.std = np.array(self.preprocess_cfg["std"], dtype=np.float32)
|
self.std = np.array(self.preprocess_cfg["std"], dtype=np.float32)
|
||||||
|
|
||||||
|
def _load_tokenizer(self) -> Tokenizer:
|
||||||
log.debug(f"Loading tokenizer for CLIP model '{self.model_name}'")
|
log.debug(f"Loading tokenizer for CLIP model '{self.model_name}'")
|
||||||
|
|
||||||
|
text_cfg: dict[str, Any] = self.model_cfg["text_cfg"]
|
||||||
|
context_length: int = text_cfg.get("context_length", 77)
|
||||||
|
pad_token: str = self.tokenizer_cfg["pad_token"]
|
||||||
|
|
||||||
self.tokenizer: Tokenizer = Tokenizer.from_file(self.tokenizer_file_path.as_posix())
|
self.tokenizer: Tokenizer = Tokenizer.from_file(self.tokenizer_file_path.as_posix())
|
||||||
|
|
||||||
pad_id: int = self.tokenizer.token_to_id(pad_token)
|
pad_id: int = self.tokenizer.token_to_id(pad_token)
|
||||||
self.tokenizer.enable_padding(length=context_length, pad_token=pad_token, pad_id=pad_id)
|
self.tokenizer.enable_padding(length=context_length, pad_token=pad_token, pad_id=pad_id)
|
||||||
self.tokenizer.enable_truncation(max_length=context_length)
|
self.tokenizer.enable_truncation(max_length=context_length)
|
||||||
|
|
||||||
log.debug(f"Loaded tokenizer for CLIP model '{self.model_name}'")
|
log.debug(f"Loaded tokenizer for CLIP model '{self.model_name}'")
|
||||||
|
|
||||||
def tokenize(self, text: str) -> dict[str, NDArray[np.int32]]:
|
def tokenize(self, text: str) -> dict[str, NDArray[np.int32]]:
|
||||||
|
|
|
@ -1,7 +1,8 @@
|
||||||
import json
|
import json
|
||||||
import pickle
|
|
||||||
from io import BytesIO
|
from io import BytesIO
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
from random import randint
|
||||||
|
from types import SimpleNamespace
|
||||||
from typing import Any, Callable
|
from typing import Any, Callable
|
||||||
from unittest import mock
|
from unittest import mock
|
||||||
|
|
||||||
|
@ -13,10 +14,12 @@ from fastapi.testclient import TestClient
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from pytest_mock import MockerFixture
|
from pytest_mock import MockerFixture
|
||||||
|
|
||||||
|
from app.main import load
|
||||||
|
|
||||||
from .config import log, settings
|
from .config import log, settings
|
||||||
from .models.base import InferenceModel, PicklableSessionOptions
|
from .models.base import InferenceModel
|
||||||
from .models.cache import ModelCache
|
from .models.cache import ModelCache
|
||||||
from .models.clip import OpenCLIPEncoder
|
from .models.clip import MCLIPEncoder, OpenCLIPEncoder
|
||||||
from .models.facial_recognition import FaceRecognizer
|
from .models.facial_recognition import FaceRecognizer
|
||||||
from .schemas import ModelRuntime, ModelType
|
from .schemas import ModelRuntime, ModelType
|
||||||
|
|
||||||
|
@ -72,6 +75,17 @@ class TestBase:
|
||||||
{"arena_extend_strategy": "kSameAsRequested"},
|
{"arena_extend_strategy": "kSameAsRequested"},
|
||||||
]
|
]
|
||||||
|
|
||||||
|
def test_sets_openvino_device_id_if_possible(self, mocker: MockerFixture) -> None:
|
||||||
|
mocked = mocker.patch("app.models.base.ort.capi._pybind_state")
|
||||||
|
mocked.get_available_openvino_device_ids.return_value = ["GPU.0", "CPU"]
|
||||||
|
|
||||||
|
encoder = OpenCLIPEncoder("ViT-B-32__openai", providers=["OpenVINOExecutionProvider", "CPUExecutionProvider"])
|
||||||
|
|
||||||
|
assert encoder.provider_options == [
|
||||||
|
{"device_id": "GPU.0"},
|
||||||
|
{"arena_extend_strategy": "kSameAsRequested"},
|
||||||
|
]
|
||||||
|
|
||||||
def test_sets_provider_options_kwarg(self) -> None:
|
def test_sets_provider_options_kwarg(self) -> None:
|
||||||
encoder = OpenCLIPEncoder(
|
encoder = OpenCLIPEncoder(
|
||||||
"ViT-B-32__openai",
|
"ViT-B-32__openai",
|
||||||
|
@ -119,7 +133,7 @@ class TestBase:
|
||||||
def test_sets_default_cache_dir(self) -> None:
|
def test_sets_default_cache_dir(self) -> None:
|
||||||
encoder = OpenCLIPEncoder("ViT-B-32__openai")
|
encoder = OpenCLIPEncoder("ViT-B-32__openai")
|
||||||
|
|
||||||
assert encoder.cache_dir == Path("/cache/clip/ViT-B-32__openai")
|
assert encoder.cache_dir == Path(settings.cache_folder) / "clip" / "ViT-B-32__openai"
|
||||||
|
|
||||||
def test_sets_cache_dir_kwarg(self) -> None:
|
def test_sets_cache_dir_kwarg(self) -> None:
|
||||||
cache_dir = Path("/test_cache")
|
cache_dir = Path("/test_cache")
|
||||||
|
@ -170,7 +184,7 @@ class TestBase:
|
||||||
encoder.clear_cache()
|
encoder.clear_cache()
|
||||||
|
|
||||||
mock_rmtree.assert_called_once_with(encoder.cache_dir)
|
mock_rmtree.assert_called_once_with(encoder.cache_dir)
|
||||||
assert info.call_count == 2
|
info.assert_called_with(f"Cleared cache directory for model '{encoder.model_name}'.")
|
||||||
|
|
||||||
def test_clear_cache_warns_if_path_does_not_exist(self, mocker: MockerFixture) -> None:
|
def test_clear_cache_warns_if_path_does_not_exist(self, mocker: MockerFixture) -> None:
|
||||||
mock_rmtree = mocker.patch("app.models.base.rmtree", autospec=True)
|
mock_rmtree = mocker.patch("app.models.base.rmtree", autospec=True)
|
||||||
|
@ -267,7 +281,7 @@ class TestBase:
|
||||||
def test_download(self, mocker: MockerFixture) -> None:
|
def test_download(self, mocker: MockerFixture) -> None:
|
||||||
mock_snapshot_download = mocker.patch("app.models.base.snapshot_download")
|
mock_snapshot_download = mocker.patch("app.models.base.snapshot_download")
|
||||||
|
|
||||||
encoder = OpenCLIPEncoder("ViT-B-32__openai")
|
encoder = OpenCLIPEncoder("ViT-B-32__openai", cache_dir="/path/to/cache")
|
||||||
encoder.download()
|
encoder.download()
|
||||||
|
|
||||||
mock_snapshot_download.assert_called_once_with(
|
mock_snapshot_download.assert_called_once_with(
|
||||||
|
@ -348,6 +362,60 @@ class TestCLIP:
|
||||||
assert embedding.dtype == np.float32
|
assert embedding.dtype == np.float32
|
||||||
mocked.run.assert_called_once()
|
mocked.run.assert_called_once()
|
||||||
|
|
||||||
|
def test_openclip_tokenizer(
|
||||||
|
self,
|
||||||
|
mocker: MockerFixture,
|
||||||
|
clip_model_cfg: dict[str, Any],
|
||||||
|
clip_preprocess_cfg: Callable[[Path], dict[str, Any]],
|
||||||
|
clip_tokenizer_cfg: Callable[[Path], dict[str, Any]],
|
||||||
|
) -> None:
|
||||||
|
mocker.patch.object(OpenCLIPEncoder, "download")
|
||||||
|
mocker.patch.object(OpenCLIPEncoder, "model_cfg", clip_model_cfg)
|
||||||
|
mocker.patch.object(OpenCLIPEncoder, "preprocess_cfg", clip_preprocess_cfg)
|
||||||
|
mocker.patch.object(OpenCLIPEncoder, "tokenizer_cfg", clip_tokenizer_cfg)
|
||||||
|
mock_tokenizer = mocker.patch("app.models.clip.Tokenizer.from_file", autospec=True).return_value
|
||||||
|
mock_ids = [randint(0, 50000) for _ in range(77)]
|
||||||
|
mock_tokenizer.encode.return_value = SimpleNamespace(ids=mock_ids)
|
||||||
|
|
||||||
|
clip_encoder = OpenCLIPEncoder("ViT-B-32__openai", cache_dir="test_cache", mode="text")
|
||||||
|
clip_encoder._load_tokenizer()
|
||||||
|
tokens = clip_encoder.tokenize("test search query")
|
||||||
|
|
||||||
|
assert "text" in tokens
|
||||||
|
assert isinstance(tokens["text"], np.ndarray)
|
||||||
|
assert tokens["text"].shape == (1, 77)
|
||||||
|
assert tokens["text"].dtype == np.int32
|
||||||
|
assert np.allclose(tokens["text"], np.array([mock_ids], dtype=np.int32), atol=0)
|
||||||
|
|
||||||
|
def test_mclip_tokenizer(
|
||||||
|
self,
|
||||||
|
mocker: MockerFixture,
|
||||||
|
clip_model_cfg: dict[str, Any],
|
||||||
|
clip_preprocess_cfg: Callable[[Path], dict[str, Any]],
|
||||||
|
clip_tokenizer_cfg: Callable[[Path], dict[str, Any]],
|
||||||
|
) -> None:
|
||||||
|
mocker.patch.object(OpenCLIPEncoder, "download")
|
||||||
|
mocker.patch.object(OpenCLIPEncoder, "model_cfg", clip_model_cfg)
|
||||||
|
mocker.patch.object(OpenCLIPEncoder, "preprocess_cfg", clip_preprocess_cfg)
|
||||||
|
mocker.patch.object(OpenCLIPEncoder, "tokenizer_cfg", clip_tokenizer_cfg)
|
||||||
|
mock_tokenizer = mocker.patch("app.models.clip.Tokenizer.from_file", autospec=True).return_value
|
||||||
|
mock_ids = [randint(0, 50000) for _ in range(77)]
|
||||||
|
mock_attention_mask = [randint(0, 1) for _ in range(77)]
|
||||||
|
mock_tokenizer.encode.return_value = SimpleNamespace(ids=mock_ids, attention_mask=mock_attention_mask)
|
||||||
|
|
||||||
|
clip_encoder = MCLIPEncoder("ViT-B-32__openai", cache_dir="test_cache", mode="text")
|
||||||
|
clip_encoder._load_tokenizer()
|
||||||
|
tokens = clip_encoder.tokenize("test search query")
|
||||||
|
|
||||||
|
assert "input_ids" in tokens
|
||||||
|
assert "attention_mask" in tokens
|
||||||
|
assert isinstance(tokens["input_ids"], np.ndarray)
|
||||||
|
assert isinstance(tokens["attention_mask"], np.ndarray)
|
||||||
|
assert tokens["input_ids"].shape == (1, 77)
|
||||||
|
assert tokens["attention_mask"].shape == (1, 77)
|
||||||
|
assert np.allclose(tokens["input_ids"], np.array([mock_ids], dtype=np.int32), atol=0)
|
||||||
|
assert np.allclose(tokens["attention_mask"], np.array([mock_attention_mask], dtype=np.int32), atol=0)
|
||||||
|
|
||||||
|
|
||||||
class TestFaceRecognition:
|
class TestFaceRecognition:
|
||||||
def test_set_min_score(self, mocker: MockerFixture) -> None:
|
def test_set_min_score(self, mocker: MockerFixture) -> None:
|
||||||
|
@ -420,12 +488,75 @@ class TestCache:
|
||||||
mock_lock_cls.return_value.__aenter__.return_value.cas.assert_called_with(mock.ANY, ttl=100)
|
mock_lock_cls.return_value.__aenter__.return_value.cas.assert_called_with(mock.ANY, ttl=100)
|
||||||
|
|
||||||
@mock.patch("app.models.cache.SimpleMemoryCache.expire")
|
@mock.patch("app.models.cache.SimpleMemoryCache.expire")
|
||||||
async def test_revalidate(self, mock_cache_expire: mock.Mock, mock_get_model: mock.Mock) -> None:
|
async def test_revalidate_get(self, mock_cache_expire: mock.Mock, mock_get_model: mock.Mock) -> None:
|
||||||
model_cache = ModelCache(ttl=100, revalidate=True)
|
model_cache = ModelCache(ttl=100, revalidate=True)
|
||||||
await model_cache.get("test_model_name", ModelType.FACIAL_RECOGNITION)
|
await model_cache.get("test_model_name", ModelType.FACIAL_RECOGNITION)
|
||||||
await model_cache.get("test_model_name", ModelType.FACIAL_RECOGNITION)
|
await model_cache.get("test_model_name", ModelType.FACIAL_RECOGNITION)
|
||||||
mock_cache_expire.assert_called_once_with(mock.ANY, 100)
|
mock_cache_expire.assert_called_once_with(mock.ANY, 100)
|
||||||
|
|
||||||
|
async def test_profiling(self, mock_get_model: mock.Mock) -> None:
|
||||||
|
model_cache = ModelCache(ttl=100, profiling=True)
|
||||||
|
await model_cache.get("test_model_name", ModelType.FACIAL_RECOGNITION)
|
||||||
|
profiling = await model_cache.get_profiling()
|
||||||
|
assert isinstance(profiling, dict)
|
||||||
|
assert profiling == model_cache.cache.profiling
|
||||||
|
|
||||||
|
async def test_loads_mclip(self) -> None:
|
||||||
|
model_cache = ModelCache()
|
||||||
|
|
||||||
|
model = await model_cache.get("XLM-Roberta-Large-Vit-B-32", ModelType.CLIP, mode="text")
|
||||||
|
|
||||||
|
assert isinstance(model, MCLIPEncoder)
|
||||||
|
assert model.model_name == "XLM-Roberta-Large-Vit-B-32"
|
||||||
|
|
||||||
|
async def test_raises_exception_if_invalid_model_type(self) -> None:
|
||||||
|
invalid: Any = SimpleNamespace(value="invalid")
|
||||||
|
model_cache = ModelCache()
|
||||||
|
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
await model_cache.get("XLM-Roberta-Large-Vit-B-32", invalid, mode="text")
|
||||||
|
|
||||||
|
async def test_raises_exception_if_unknown_model_name(self) -> None:
|
||||||
|
model_cache = ModelCache()
|
||||||
|
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
await model_cache.get("test_model_name", ModelType.CLIP, mode="text")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
class TestLoad:
|
||||||
|
async def test_load(self) -> None:
|
||||||
|
mock_model = mock.Mock(spec=InferenceModel)
|
||||||
|
mock_model.loaded = False
|
||||||
|
|
||||||
|
res = await load(mock_model)
|
||||||
|
|
||||||
|
assert res is mock_model
|
||||||
|
mock_model.load.assert_called_once()
|
||||||
|
mock_model.clear_cache.assert_not_called()
|
||||||
|
|
||||||
|
async def test_load_returns_model_if_loaded(self) -> None:
|
||||||
|
mock_model = mock.Mock(spec=InferenceModel)
|
||||||
|
mock_model.loaded = True
|
||||||
|
|
||||||
|
res = await load(mock_model)
|
||||||
|
|
||||||
|
assert res is mock_model
|
||||||
|
mock_model.load.assert_not_called()
|
||||||
|
|
||||||
|
async def test_load_clears_cache_and_retries_if_os_error(self) -> None:
|
||||||
|
mock_model = mock.Mock(spec=InferenceModel)
|
||||||
|
mock_model.model_name = "test_model_name"
|
||||||
|
mock_model.model_type = ModelType.CLIP
|
||||||
|
mock_model.load.side_effect = [OSError, None]
|
||||||
|
mock_model.loaded = False
|
||||||
|
|
||||||
|
res = await load(mock_model)
|
||||||
|
|
||||||
|
assert res is mock_model
|
||||||
|
mock_model.clear_cache.assert_called_once()
|
||||||
|
assert mock_model.load.call_count == 2
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(
|
@pytest.mark.skipif(
|
||||||
not settings.test_full,
|
not settings.test_full,
|
||||||
|
@ -437,15 +568,21 @@ class TestEndpoints:
|
||||||
) -> None:
|
) -> None:
|
||||||
byte_image = BytesIO()
|
byte_image = BytesIO()
|
||||||
pil_image.save(byte_image, format="jpeg")
|
pil_image.save(byte_image, format="jpeg")
|
||||||
|
expected = responses["clip"]["image"]
|
||||||
|
|
||||||
response = deployed_app.post(
|
response = deployed_app.post(
|
||||||
"http://localhost:3003/predict",
|
"http://localhost:3003/predict",
|
||||||
data={"modelName": "ViT-B-32__openai", "modelType": "clip", "options": json.dumps({"mode": "vision"})},
|
data={"modelName": "ViT-B-32__openai", "modelType": "clip", "options": json.dumps({"mode": "vision"})},
|
||||||
files={"image": byte_image.getvalue()},
|
files={"image": byte_image.getvalue()},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
actual = response.json()
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
assert response.json() == responses["clip"]["image"]
|
assert np.allclose(expected, actual)
|
||||||
|
|
||||||
def test_clip_text_endpoint(self, responses: dict[str, Any], deployed_app: TestClient) -> None:
|
def test_clip_text_endpoint(self, responses: dict[str, Any], deployed_app: TestClient) -> None:
|
||||||
|
expected = responses["clip"]["text"]
|
||||||
|
|
||||||
response = deployed_app.post(
|
response = deployed_app.post(
|
||||||
"http://localhost:3003/predict",
|
"http://localhost:3003/predict",
|
||||||
data={
|
data={
|
||||||
|
@ -455,12 +592,15 @@ class TestEndpoints:
|
||||||
"options": json.dumps({"mode": "text"}),
|
"options": json.dumps({"mode": "text"}),
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
actual = response.json()
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
assert response.json() == responses["clip"]["text"]
|
assert np.allclose(expected, actual)
|
||||||
|
|
||||||
def test_face_endpoint(self, pil_image: Image.Image, responses: dict[str, Any], deployed_app: TestClient) -> None:
|
def test_face_endpoint(self, pil_image: Image.Image, responses: dict[str, Any], deployed_app: TestClient) -> None:
|
||||||
byte_image = BytesIO()
|
byte_image = BytesIO()
|
||||||
pil_image.save(byte_image, format="jpeg")
|
pil_image.save(byte_image, format="jpeg")
|
||||||
|
expected = responses["facial-recognition"]
|
||||||
|
|
||||||
response = deployed_app.post(
|
response = deployed_app.post(
|
||||||
"http://localhost:3003/predict",
|
"http://localhost:3003/predict",
|
||||||
|
@ -471,15 +611,13 @@ class TestEndpoints:
|
||||||
},
|
},
|
||||||
files={"image": byte_image.getvalue()},
|
files={"image": byte_image.getvalue()},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
actual = response.json()
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
assert response.json() == responses["facial-recognition"]
|
assert len(expected) == len(actual)
|
||||||
|
for expected_face, actual_face in zip(expected, actual):
|
||||||
|
assert expected_face["imageHeight"] == actual_face["imageHeight"]
|
||||||
def test_sess_options() -> None:
|
assert expected_face["imageWidth"] == actual_face["imageWidth"]
|
||||||
sess_options = PicklableSessionOptions()
|
assert expected_face["boundingBox"] == actual_face["boundingBox"]
|
||||||
sess_options.intra_op_num_threads = 1
|
assert np.allclose(expected_face["embedding"], actual_face["embedding"])
|
||||||
sess_options.inter_op_num_threads = 1
|
assert np.allclose(expected_face["score"], actual_face["score"])
|
||||||
pickled = pickle.dumps(sess_options)
|
|
||||||
unpickled = pickle.loads(pickled)
|
|
||||||
assert unpickled.intra_op_num_threads == 1
|
|
||||||
assert unpickled.inter_op_num_threads == 1
|
|
||||||
|
|
File diff suppressed because it is too large
Load diff
Loading…
Reference in a new issue