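"""Locust load tests for the Immich machine-learning service.

Each user class below posts form data to the /predict endpoint of a locally
running ML service (see `host` on InferenceLoadTest); example invocations are
sketched at the bottom of this file.
"""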
import json
from argparse import ArgumentParser
from io import BytesIO
from typing import Any

from locust import HttpUser, events, task
from locust.env import Environment
from PIL import Image

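# Shared in-memory JPEG; generated once at test start and reused by every simulated user.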
byte_image = BytesIO()


@events.init_command_line_parser.add_listener
def _(parser: ArgumentParser) -> None:
    parser.add_argument("--tag-model", type=str, default="microsoft/resnet-50")
    parser.add_argument("--clip-model", type=str, default="ViT-B-32::openai")
    parser.add_argument("--face-model", type=str, default="buffalo_l")
    parser.add_argument(
        "--tag-min-score",
        type=float,
        default=0.0,
        help="Returns all tags at or above this score. The default returns all tags.",
    )
    parser.add_argument(
        "--face-min-score",
        type=float,
        default=0.034,
        help=(
            "Returns all faces at or above this score. The default returns 1 face per request; "
            "setting this to 0 blows up the number of faces to the thousands."
        ),
    )
    parser.add_argument("--image-size", type=int, default=1000)

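# Build the synthetic test image once per run so every request posts an identical payload.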
@events.test_start.add_listener
def on_test_start(environment: Environment, **kwargs: Any) -> None:
    global byte_image
    assert environment.parsed_options is not None
    image = Image.new("RGB", (environment.parsed_options.image_size, environment.parsed_options.image_size))
    byte_image = BytesIO()
    image.save(byte_image, format="jpeg")

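# Base class: holds the shared payload and host; each subclass below exercises one model type.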
class InferenceLoadTest(HttpUser):
    abstract: bool = True
    host = "http://127.0.0.1:3003"
    data: bytes
    headers: dict[str, str] = {"Content-Type": "image/jpg"}

    # re-use the image across all instances in a process
    def on_start(self) -> None:
        global byte_image
        self.data = byte_image.getvalue()

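# Image classification: tags the test image with the configured tag model.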
class ClassificationFormDataLoadTest(InferenceLoadTest):
    @task
    def classify(self) -> None:
        data = [
            ("modelName", self.environment.parsed_options.tag_model),
            ("modelType", "image-classification"),
            ("options", json.dumps({"minScore": self.environment.parsed_options.tag_min_score})),
        ]
        files = {"image": self.data}
        self.client.post("/predict", data=data, files=files)

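# CLIP text encoding: no image is attached, so this measures the text tower alone.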
class CLIPTextFormDataLoadTest(InferenceLoadTest):
    @task
    def encode_text(self) -> None:
        data = [
            ("modelName", self.environment.parsed_options.clip_model),
            ("modelType", "clip"),
            ("options", json.dumps({"mode": "text"})),
            ("text", "test search query"),
        ]
        self.client.post("/predict", data=data)

class CLIPVisionFormDataLoadTest(InferenceLoadTest):
    @task
    def encode_image(self) -> None:
        data = [
            ("modelName", self.environment.parsed_options.clip_model),
            ("modelType", "clip"),
            ("options", json.dumps({"mode": "vision"})),
        ]
        files = {"image": self.data}
        self.client.post("/predict", data=data, files=files)

class RecognitionFormDataLoadTest(InferenceLoadTest):
    @task
    def recognize(self) -> None:
        data = [
            ("modelName", self.environment.parsed_options.face_model),
            ("modelType", "facial-recognition"),
            ("options", json.dumps({"minScore": self.environment.parsed_options.face_min_score})),
        ]
        files = {"image": self.data}
        self.client.post("/predict", data=data, files=files)
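
# Example invocations (sketches; flag values are illustrative):
#
#   locust -f locustfile.py RecognitionFormDataLoadTest --face-min-score 0.5
#   locust -f locustfile.py CLIPTextFormDataLoadTest --clip-model ViT-B-32::openai
#
# Naming a user class restricts the run to that task; with no class named,
# Locust runs all four user classes together.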