2023-10-31 11:02:04 +01:00
|
|
|
from typing import Any
|
|
|
|
|
|
|
|
from app.schemas import ModelType
|
|
|
|
|
|
|
|
from .base import InferenceModel
|
2023-11-12 02:04:49 +01:00
|
|
|
from .clip import MCLIPEncoder, OpenCLIPEncoder
|
|
|
|
from .constants import is_insightface, is_mclip, is_openclip
|
2023-06-25 05:18:09 +02:00
|
|
|
from .facial_recognition import FaceRecognizer
|
2023-10-31 11:02:04 +01:00
|
|
|
|
|
|
|
|
|
|
|
def from_model_type(model_type: ModelType, model_name: str, **model_kwargs: Any) -> InferenceModel:
|
|
|
|
match model_type:
|
|
|
|
case ModelType.CLIP:
|
|
|
|
if is_openclip(model_name):
|
|
|
|
return OpenCLIPEncoder(model_name, **model_kwargs)
|
|
|
|
elif is_mclip(model_name):
|
|
|
|
return MCLIPEncoder(model_name, **model_kwargs)
|
|
|
|
case ModelType.FACIAL_RECOGNITION:
|
2023-11-12 02:04:49 +01:00
|
|
|
if is_insightface(model_name):
|
|
|
|
return FaceRecognizer(model_name, **model_kwargs)
|
2023-10-31 11:02:04 +01:00
|
|
|
case _:
|
|
|
|
raise ValueError(f"Unknown model type {model_type}")
|
2023-11-12 02:04:49 +01:00
|
|
|
|
|
|
|
raise ValueError(f"Unknown ${model_type} model {model_name}")
|