diff --git a/machine-learning/app/test_main.py b/machine-learning/app/test_main.py
index e20a3e6c81..318c3b0453 100644
--- a/machine-learning/app/test_main.py
+++ b/machine-learning/app/test_main.py
@@ -75,9 +75,9 @@ class TestCLIP:
         embedding = clip_encoder.predict(pil_image)
 
         assert clip_encoder.mode == "vision"
-        assert isinstance(embedding, list)
-        assert len(embedding) == clip_model_cfg["embed_dim"]
-        assert all([isinstance(num, float) for num in embedding])
+        assert isinstance(embedding, np.ndarray)
+        assert embedding.shape[0] == clip_model_cfg["embed_dim"]
+        assert embedding.dtype == np.float32
         clip_encoder.vision_model.run.assert_called_once()
 
     def test_basic_text(
@@ -97,9 +97,9 @@ class TestCLIP:
         embedding = clip_encoder.predict("test search query")
 
         assert clip_encoder.mode == "text"
-        assert isinstance(embedding, list)
-        assert len(embedding) == clip_model_cfg["embed_dim"]
-        assert all([isinstance(num, float) for num in embedding])
+        assert isinstance(embedding, np.ndarray)
+        assert embedding.shape[0] == clip_model_cfg["embed_dim"]
+        assert embedding.dtype == np.float32
         clip_encoder.text_model.run.assert_called_once()
 
 
@@ -133,9 +133,9 @@ class TestFaceRecognition:
         for face in faces:
             assert face["imageHeight"] == 800
             assert face["imageWidth"] == 600
-            assert isinstance(face["embedding"], list)
-            assert len(face["embedding"]) == 512
-            assert all([isinstance(num, float) for num in face["embedding"]])
+            assert isinstance(face["embedding"], np.ndarray)
+            assert face["embedding"].shape[0] == 512
+            assert face["embedding"].dtype == np.float32
 
         det_model.detect.assert_called_once()
         assert rec_model.get_feat.call_count == num_faces