Mirror of https://github.com/immich-app/immich.git, synced 2024-12-29 15:11:58 +00:00.

fixed setting different clip, removed unused stubs (#2987)

parent b3e97a1a0c
commit 4d3ce0a65e

5 changed files with 6 additions and 25 deletions
```diff
@@ -27,13 +27,10 @@ app = FastAPI()
 @app.on_event("startup")
 async def startup_event() -> None:
     app.state.model_cache = ModelCache(ttl=settings.model_ttl, revalidate=True)
-    same_clip = settings.clip_image_model == settings.clip_text_model
-    app.state.clip_vision_type = ModelType.CLIP if same_clip else ModelType.CLIP_VISION
-    app.state.clip_text_type = ModelType.CLIP if same_clip else ModelType.CLIP_TEXT
     models = [
         (settings.classification_model, ModelType.IMAGE_CLASSIFICATION),
-        (settings.clip_image_model, app.state.clip_vision_type),
-        (settings.clip_text_model, app.state.clip_text_type),
+        (settings.clip_image_model, ModelType.CLIP),
+        (settings.clip_text_model, ModelType.CLIP),
         (settings.facial_recognition_model, ModelType.FACIAL_RECOGNITION),
     ]
 
```
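With the `same_clip` branching gone, both CLIP entries in the startup list share `ModelType.CLIP`. A minimal sketch of how such a (name, type) list could be consumed to warm the cache; `preload` and the gather-based warm-up are assumptions about usage, not code from this commit:

```python
# Illustrative sketch only: preload() and the concurrent warm-up are
# assumptions, not code from this commit.
import asyncio
from enum import Enum


class ModelType(Enum):
    IMAGE_CLASSIFICATION = "image-classification"
    CLIP = "clip"
    FACIAL_RECOGNITION = "facial-recognition"


async def preload(cache, settings) -> None:
    # Both CLIP entries now share ModelType.CLIP, so if the image and text
    # settings name the same checkpoint, the two pairs are identical and a
    # (name, type)-keyed cache loads that model only once.
    models = [
        (settings.classification_model, ModelType.IMAGE_CLASSIFICATION),
        (settings.clip_image_model, ModelType.CLIP),
        (settings.clip_text_model, ModelType.CLIP),
        (settings.facial_recognition_model, ModelType.FACIAL_RECOGNITION),
    ]
    await asyncio.gather(*(cache.get(name, kind) for name, kind in models))
```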
```diff
@@ -87,9 +84,7 @@ async def image_classification(
 async def clip_encode_image(
     image: Image.Image = Depends(dep_pil_image),
 ) -> list[float]:
-    model = await app.state.model_cache.get(
-        settings.clip_image_model, app.state.clip_vision_type
-    )
+    model = await app.state.model_cache.get(settings.clip_image_model, ModelType.CLIP)
     embedding = model.predict(image)
     return embedding
 
```
```diff
@@ -100,9 +95,7 @@ async def clip_encode_image(
     status_code=200,
 )
 async def clip_encode_text(payload: TextModelRequest) -> list[float]:
-    model = await app.state.model_cache.get(
-        settings.clip_text_model, app.state.clip_text_type
-    )
+    model = await app.state.model_cache.get(settings.clip_text_model, ModelType.CLIP)
     embedding = model.predict(payload.text)
     return embedding
 
```
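Both encode endpoints now pass `ModelType.CLIP` directly instead of reading a per-direction type off `app.state`, so when `clip_image_model` and `clip_text_model` name the same checkpoint, the two endpoints resolve to one cache entry. A toy sketch of that keying; `TinyModelCache` and its `loader` callable are hypothetical stand-ins for the real `ModelCache`:

```python
# Hypothetical stand-in for ModelCache, keyed by (name, type).
from typing import Any, Callable


class TinyModelCache:
    def __init__(self, loader: Callable[[str, Any], Any]) -> None:
        self._loader = loader
        self._cache: dict[tuple[str, Any], Any] = {}

    async def get(self, name: str, kind: Any) -> Any:
        key = (name, kind)
        if key not in self._cache:
            # Identical (name, ModelType.CLIP) requests from the image and
            # text endpoints load the underlying model only once.
            self._cache[key] = self._loader(name, kind)
        return self._cache[key]
```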
```diff
@@ -1,3 +1,3 @@
-from .clip import CLIPSTTextEncoder, CLIPSTVisionEncoder
+from .clip import CLIPSTEncoder
 from .facial_recognition import FaceRecognizer
 from .image_classification import ImageClassifier
```
```diff
@@ -1,6 +1,6 @@
 from __future__ import annotations
 
-from abc import abstractmethod, ABC
+from abc import ABC, abstractmethod
 from pathlib import Path
 from typing import Any
 
```
```diff
@@ -25,13 +25,3 @@ class CLIPSTEncoder(InferenceModel):
 
     def predict(self, image_or_text: Image | str) -> list[float]:
         return self.model.encode(image_or_text).tolist()
-
-
-# stubs to allow different behavior between the two in the future
-# and handle loading different image and text clip models
-class CLIPSTVisionEncoder(CLIPSTEncoder):
-    _model_type = ModelType.CLIP_VISION
-
-
-class CLIPSTTextEncoder(CLIPSTEncoder):
-    _model_type = ModelType.CLIP_TEXT
```
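The vision/text stub subclasses could be dropped because a sentence-transformers CLIP checkpoint encodes both modalities through the same `encode()` call, which is exactly what `CLIPSTEncoder.predict` relies on. A standalone sketch of that behavior; the model name `clip-ViT-B-32` and the image path are illustrative choices, not taken from this commit:

```python
# Sketch assuming the underlying model is a sentence-transformers CLIP
# checkpoint; model name and file path below are illustrative.
from PIL import Image
from sentence_transformers import SentenceTransformer

model = SentenceTransformer("clip-ViT-B-32")

# The same encode() call handles both modalities, which is why a single
# predict(image_or_text) method suffices for both endpoints.
text_embedding = model.encode("a photo of a dog").tolist()
image_embedding = model.encode(Image.open("dog.jpg")).tolist()
```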
```diff
@@ -61,6 +61,4 @@ class FaceResponse(BaseModel):
 class ModelType(Enum):
     IMAGE_CLASSIFICATION = "image-classification"
     CLIP = "clip"
-    CLIP_VISION = "clip-vision"
-    CLIP_TEXT = "clip-text"
     FACIAL_RECOGNITION = "facial-recognition"
```
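With `CLIP_VISION` and `CLIP_TEXT` removed, `ModelType` is back to one value per task, and each member still round-trips from its string value. A small sketch of the trimmed enum as this diff leaves it:

```python
from enum import Enum


class ModelType(Enum):
    IMAGE_CLASSIFICATION = "image-classification"
    CLIP = "clip"
    FACIAL_RECOGNITION = "facial-recognition"


# Enum lookup by value recovers the member from its wire string.
assert ModelType("clip") is ModelType.CLIP
```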