2023-06-28 01:21:33 +02:00
|
|
|
from io import BytesIO
|
2023-08-06 04:45:13 +02:00
|
|
|
from typing import TypeAlias
|
2023-06-28 01:21:33 +02:00
|
|
|
from unittest import mock
|
|
|
|
|
|
|
|
import cv2
|
2023-08-06 04:45:13 +02:00
|
|
|
import numpy as np
|
2023-06-28 01:21:33 +02:00
|
|
|
import pytest
|
|
|
|
from fastapi.testclient import TestClient
|
|
|
|
from PIL import Image
|
2023-08-06 04:45:13 +02:00
|
|
|
from pytest_mock import MockerFixture
|
2023-06-28 01:21:33 +02:00
|
|
|
|
|
|
|
from .config import settings
|
|
|
|
from .models.cache import ModelCache
|
|
|
|
from .models.clip import CLIPSTEncoder
|
|
|
|
from .models.facial_recognition import FaceRecognizer
|
|
|
|
from .models.image_classification import ImageClassifier
|
|
|
|
from .schemas import ModelType
|
|
|
|
|
2023-08-06 04:45:13 +02:00
|
|
|
ndarray: TypeAlias = np.ndarray[int, np.dtype[np.float32]]
|
2023-06-28 01:21:33 +02:00
|
|
|
|
|
|
|
|
2023-08-06 04:45:13 +02:00
|
|
|
class TestImageClassifier:
    """Unit tests for ImageClassifier with model download/load patched out."""

    # Fake classifier output, ordered by descending score. The third label contains
    # a comma; test_min_score expects predict() to split it into two labels.
    classifier_preds = [
        {"label": "that's an image alright", "score": 0.8},
        {"label": "well it ends with .jpg", "score": 0.1},
        {"label": "idk, im just seeing bytes", "score": 0.05},
        {"label": "not sure", "score": 0.04},
        {"label": "probably a virus", "score": 0.01},
    ]

    def test_eager_init(self, mocker: MockerFixture) -> None:
        """With eager=True the model is loaded immediately and extra kwargs reach load()."""
        mocker.patch.object(ImageClassifier, "download")
        mock_load = mocker.patch.object(ImageClassifier, "load")
        classifier = ImageClassifier("test_model_name", cache_dir="test_cache", eager=True, test_arg="test_arg")

        assert classifier.model_name == "test_model_name"
        mock_load.assert_called_once_with(test_arg="test_arg")

    def test_lazy_init(self, mocker: MockerFixture) -> None:
        """With eager=False the model is only downloaded; load() is deferred."""
        mock_download = mocker.patch.object(ImageClassifier, "download")
        mock_load = mocker.patch.object(ImageClassifier, "load")
        classifier = ImageClassifier("test_model_name", cache_dir="test_cache", eager=False, test_arg="test_arg")

        assert classifier.model_name == "test_model_name"
        mock_download.assert_called_once_with(test_arg="test_arg")
        mock_load.assert_not_called()

    def test_min_score(self, pil_image: Image.Image, mocker: MockerFixture) -> None:
        """Labels scoring below min_score are dropped; comma-joined labels are split."""
        mocker.patch.object(ImageClassifier, "load")
        classifier = ImageClassifier("test_model_name", cache_dir="test_cache", min_score=0.0)
        assert classifier.min_score == 0.0

        # Replace the (patched-out) model with a callable returning the canned predictions.
        classifier.model = mock.Mock()
        classifier.model.return_value = self.classifier_preds

        all_labels = classifier.predict(pil_image)
        classifier.min_score = 0.5
        filtered_labels = classifier.predict(pil_image)

        assert all_labels == [
            "that's an image alright",
            "well it ends with .jpg",
            "idk",
            "im just seeing bytes",
            "not sure",
            "probably a virus",
        ]
        # Only the 0.8-score prediction clears the 0.5 threshold.
        assert filtered_labels == ["that's an image alright"]
|
|
|
|
|
|
|
|
|
|
|
|
class TestCLIP:
    """Unit tests for CLIPSTEncoder with the underlying model mocked."""

    # Random 512-element float32 vector, generated once at import time; it is
    # returned by the mocked model's encode().
    embedding = np.random.rand(512).astype(np.float32)

    @staticmethod
    def _assert_embedding(embedding: object) -> None:
        """Assert that *embedding* is a 512-element list of Python floats."""
        assert isinstance(embedding, list)
        assert len(embedding) == 512
        assert all(isinstance(num, float) for num in embedding)

    def test_eager_init(self, mocker: MockerFixture) -> None:
        """With eager=True the model is loaded immediately and extra kwargs reach load()."""
        mocker.patch.object(CLIPSTEncoder, "download")
        mock_load = mocker.patch.object(CLIPSTEncoder, "load")
        clip_model = CLIPSTEncoder("test_model_name", cache_dir="test_cache", eager=True, test_arg="test_arg")

        assert clip_model.model_name == "test_model_name"
        mock_load.assert_called_once_with(test_arg="test_arg")

    def test_lazy_init(self, mocker: MockerFixture) -> None:
        """With eager=False the model is only downloaded; load() is deferred."""
        mock_download = mocker.patch.object(CLIPSTEncoder, "download")
        mock_load = mocker.patch.object(CLIPSTEncoder, "load")
        clip_model = CLIPSTEncoder("test_model_name", cache_dir="test_cache", eager=False, test_arg="test_arg")

        assert clip_model.model_name == "test_model_name"
        mock_download.assert_called_once_with(test_arg="test_arg")
        mock_load.assert_not_called()

    def test_basic_image(self, pil_image: Image.Image, mocker: MockerFixture) -> None:
        """Encoding an image returns a list-of-floats embedding via a single encode() call."""
        mocker.patch.object(CLIPSTEncoder, "load")
        clip_encoder = CLIPSTEncoder("test_model_name", cache_dir="test_cache")
        clip_encoder.model = mock.Mock()
        clip_encoder.model.encode.return_value = self.embedding

        embedding = clip_encoder.predict(pil_image)

        self._assert_embedding(embedding)
        clip_encoder.model.encode.assert_called_once()

    def test_basic_text(self, mocker: MockerFixture) -> None:
        """Encoding text returns a list-of-floats embedding via a single encode() call."""
        mocker.patch.object(CLIPSTEncoder, "load")
        clip_encoder = CLIPSTEncoder("test_model_name", cache_dir="test_cache")
        clip_encoder.model = mock.Mock()
        clip_encoder.model.encode.return_value = self.embedding

        embedding = clip_encoder.predict("test search query")

        self._assert_embedding(embedding)
        clip_encoder.model.encode.assert_called_once()
|
2023-06-28 01:21:33 +02:00
|
|
|
|
|
|
|
|
|
|
|
class TestFaceRecognition:
    """Unit tests for FaceRecognizer with detection and recognition models mocked."""

    def test_eager_init(self, mocker: MockerFixture) -> None:
        """With eager=True the model is loaded immediately and extra kwargs reach load()."""
        mocker.patch.object(FaceRecognizer, "download")
        mock_load = mocker.patch.object(FaceRecognizer, "load")
        face_model = FaceRecognizer("test_model_name", cache_dir="test_cache", eager=True, test_arg="test_arg")

        assert face_model.model_name == "test_model_name"
        mock_load.assert_called_once_with(test_arg="test_arg")

    def test_lazy_init(self, mocker: MockerFixture) -> None:
        """With eager=False the model is only downloaded; load() is deferred."""
        mock_download = mocker.patch.object(FaceRecognizer, "download")
        mock_load = mocker.patch.object(FaceRecognizer, "load")
        face_model = FaceRecognizer("test_model_name", cache_dir="test_cache", eager=False, test_arg="test_arg")

        assert face_model.model_name == "test_model_name"
        mock_download.assert_called_once_with(test_arg="test_arg")
        mock_load.assert_not_called()

    def test_set_min_score(self, mocker: MockerFixture) -> None:
        """The min_score constructor argument is stored on the instance."""
        mocker.patch.object(FaceRecognizer, "load")
        face_recognizer = FaceRecognizer("test_model_name", cache_dir="test_cache", min_score=0.5)

        assert face_recognizer.min_score == 0.5

    def test_basic(self, cv_image: cv2.Mat, mocker: MockerFixture) -> None:
        """predict() yields one record per detected face with image size and a 512-float embedding."""
        mocker.patch.object(FaceRecognizer, "load")
        face_recognizer = FaceRecognizer("test_model_name", min_score=0.0, cache_dir="test_cache")
        num_faces = 2

        # Fake detector output: per face, 4 box values followed by a 0.67 score,
        # plus a (num_faces, 5, 2) keypoint array.
        bbox = np.random.rand(num_faces, 4).astype(np.float32)
        score = np.full((num_faces, 1), 0.67, dtype=np.float32)
        kpss = np.random.rand(num_faces, 5, 2).astype(np.float32)
        det_model = mock.Mock()
        det_model.detect.return_value = (np.concatenate([bbox, score], axis=-1), kpss)
        face_recognizer.det_model = det_model

        # Fake recognizer: every get_feat() call returns the same (num_faces, 512) array.
        rec_model = mock.Mock()
        rec_model.get_feat.return_value = np.random.rand(num_faces, 512).astype(np.float32)
        face_recognizer.rec_model = rec_model

        faces = face_recognizer.predict(cv_image)

        assert len(faces) == num_faces
        for face in faces:
            # Expected to match the cv_image fixture's dimensions (800 tall, 600 wide).
            assert face["imageHeight"] == 800
            assert face["imageWidth"] == 600
            assert isinstance(face["embedding"], list)
            assert len(face["embedding"]) == 512
            assert all(isinstance(num, float) for num in face["embedding"])

        # One detection pass over the image, one recognition call per face.
        det_model.detect.assert_called_once()
        assert rec_model.get_feat.call_count == num_faces
|
2023-06-28 01:21:33 +02:00
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.asyncio
class TestCache:
    """Tests for ModelCache: memoization, kwarg forwarding, TTL, and revalidation.

    ``mock_get_model`` is a fixture standing in for the model factory the cache calls.
    """

    async def test_caches(self, mock_get_model: mock.Mock) -> None:
        """Requesting the same model twice builds it only once."""
        store = ModelCache()
        for _ in range(2):
            await store.get("test_model_name", ModelType.IMAGE_CLASSIFICATION)

        assert len(store.cache._cache) == 1
        mock_get_model.assert_called_once()

    async def test_kwargs_used(self, mock_get_model: mock.Mock) -> None:
        """Extra keyword arguments to get() are forwarded to the model factory."""
        store = ModelCache()
        model_type = ModelType.IMAGE_CLASSIFICATION
        await store.get("test_model_name", model_type, cache_dir="test_cache")

        mock_get_model.assert_called_once_with(model_type, "test_model_name", cache_dir="test_cache")

    async def test_different_clip(self, mock_get_model: mock.Mock) -> None:
        """Different model names of the same type are cached as separate entries."""
        store = ModelCache()
        names = ("test_image_model_name", "test_text_model_name")
        for name in names:
            await store.get(name, ModelType.CLIP)

        expected_calls = [mock.call(ModelType.CLIP, name) for name in names]
        mock_get_model.assert_has_calls(expected_calls)
        assert len(store.cache._cache) == 2

    @mock.patch("app.models.cache.OptimisticLock", autospec=True)
    async def test_model_ttl(self, mock_lock_cls: mock.Mock, mock_get_model: mock.Mock) -> None:
        """The configured ttl is passed through when the cache stores a model."""
        store = ModelCache(ttl=100)
        await store.get("test_model_name", ModelType.IMAGE_CLASSIFICATION)

        entered_lock = mock_lock_cls.return_value.__aenter__.return_value
        entered_lock.cas.assert_called_with(mock.ANY, ttl=100)

    @mock.patch("app.models.cache.SimpleMemoryCache.expire")
    async def test_revalidate(self, mock_cache_expire: mock.Mock, mock_get_model: mock.Mock) -> None:
        """With revalidate=True, a cache hit refreshes the entry's expiry exactly once."""
        store = ModelCache(ttl=100, revalidate=True)
        for _ in range(2):
            await store.get("test_model_name", ModelType.IMAGE_CLASSIFICATION)

        mock_cache_expire.assert_called_once_with(mock.ANY, 100)
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.skipif(
    not settings.test_full,
    reason="More time-consuming since it deploys the app and loads models.",
)
class TestEndpoints:
    """End-to-end tests that POST real payloads to the deployed app.

    Skipped unless ``settings.test_full`` is enabled.
    """

    # "image/jpeg" is the registered MIME type for JPEG ("image/jpg" is not);
    # the payloads below are JPEG-encoded.
    _JPEG_HEADERS = {"Content-Type": "image/jpeg"}

    @staticmethod
    def _to_jpeg_bytes(image: Image.Image) -> bytes:
        """Encode *image* as JPEG and return the raw bytes."""
        buffer = BytesIO()
        image.save(buffer, format="jpeg")
        return buffer.getvalue()

    def test_tagging_endpoint(self, pil_image: Image.Image, deployed_app: TestClient) -> None:
        """The image-classification endpoint accepts a JPEG body."""
        response = deployed_app.post(
            "http://localhost:3003/image-classifier/tag-image",
            content=self._to_jpeg_bytes(pil_image),
            headers=self._JPEG_HEADERS,
        )
        assert response.status_code == 200

    def test_clip_image_endpoint(self, pil_image: Image.Image, deployed_app: TestClient) -> None:
        """The CLIP image-encoding endpoint accepts a JPEG body."""
        response = deployed_app.post(
            "http://localhost:3003/sentence-transformer/encode-image",
            content=self._to_jpeg_bytes(pil_image),
            headers=self._JPEG_HEADERS,
        )
        assert response.status_code == 200

    def test_clip_text_endpoint(self, deployed_app: TestClient) -> None:
        """The CLIP text-encoding endpoint accepts a JSON body with a "text" field."""
        response = deployed_app.post(
            "http://localhost:3003/sentence-transformer/encode-text",
            json={"text": "test search query"},
        )
        assert response.status_code == 200

    def test_face_endpoint(self, pil_image: Image.Image, deployed_app: TestClient) -> None:
        """The face-detection endpoint accepts a JPEG body."""
        response = deployed_app.post(
            "http://localhost:3003/facial-recognition/detect-faces",
            content=self._to_jpeg_bytes(pil_image),
            headers=self._JPEG_HEADERS,
        )
        assert response.status_code == 200
|