
fixed setting different clip, removed unused stubs (#2987)

Authored by Mert on 2023-06-27 13:21:50 -04:00, committed by GitHub
parent b3e97a1a0c
commit 4d3ce0a65e
5 changed files with 6 additions and 25 deletions

@@ -27,13 +27,10 @@ app = FastAPI()
 @app.on_event("startup")
 async def startup_event() -> None:
     app.state.model_cache = ModelCache(ttl=settings.model_ttl, revalidate=True)
-    same_clip = settings.clip_image_model == settings.clip_text_model
-    app.state.clip_vision_type = ModelType.CLIP if same_clip else ModelType.CLIP_VISION
-    app.state.clip_text_type = ModelType.CLIP if same_clip else ModelType.CLIP_TEXT

     models = [
         (settings.classification_model, ModelType.IMAGE_CLASSIFICATION),
-        (settings.clip_image_model, app.state.clip_vision_type),
-        (settings.clip_text_model, app.state.clip_text_type),
+        (settings.clip_image_model, ModelType.CLIP),
+        (settings.clip_text_model, ModelType.CLIP),
         (settings.facial_recognition_model, ModelType.FACIAL_RECOGNITION),
     ]
@@ -87,9 +84,7 @@ async def image_classification(
 async def clip_encode_image(
     image: Image.Image = Depends(dep_pil_image),
 ) -> list[float]:
-    model = await app.state.model_cache.get(
-        settings.clip_image_model, app.state.clip_vision_type
-    )
+    model = await app.state.model_cache.get(settings.clip_image_model, ModelType.CLIP)
     embedding = model.predict(image)
     return embedding
@@ -100,9 +95,7 @@ async def clip_encode_image(
     status_code=200,
 )
 async def clip_encode_text(payload: TextModelRequest) -> list[float]:
-    model = await app.state.model_cache.get(
-        settings.clip_text_model, app.state.clip_text_type
-    )
+    model = await app.state.model_cache.get(settings.clip_text_model, ModelType.CLIP)
     embedding = model.predict(payload.text)
     return embedding
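
The change above works only if the model cache can tell the two CLIP models apart by name once they both register under ModelType.CLIP. A minimal toy sketch of that idea; this is not the project's ModelCache, and the (model name, model type) keying, the async get() shape, and the checkpoint names are assumptions made purely for illustration:

import asyncio
from enum import Enum


class ModelType(Enum):
    CLIP = "clip"


class ToyModelCache:
    """Toy stand-in for the app's model cache, keyed by (model name, model type)."""

    def __init__(self) -> None:
        self._entries: dict[tuple[str, ModelType], str] = {}

    async def get(self, model_name: str, model_type: ModelType) -> str:
        key = (model_name, model_type)
        if key not in self._entries:
            # A real cache would load and store the model here; we store a label.
            self._entries[key] = f"loaded {model_type.value} model '{model_name}'"
        return self._entries[key]


async def main() -> None:
    cache = ToyModelCache()
    # Two different CLIP checkpoints coexist under ModelType.CLIP, so the
    # CLIP_VISION / CLIP_TEXT types are no longer needed to keep them apart.
    print(await cache.get("clip-ViT-B-32", ModelType.CLIP))
    print(await cache.get("clip-ViT-B-16", ModelType.CLIP))


asyncio.run(main())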

@@ -1,3 +1,3 @@
-from .clip import CLIPSTTextEncoder, CLIPSTVisionEncoder
+from .clip import CLIPSTEncoder
 from .facial_recognition import FaceRecognizer
 from .image_classification import ImageClassifier

@@ -1,6 +1,6 @@
 from __future__ import annotations

-from abc import abstractmethod, ABC
+from abc import ABC, abstractmethod
 from pathlib import Path
 from typing import Any

@@ -25,13 +25,3 @@ class CLIPSTEncoder(InferenceModel):
     def predict(self, image_or_text: Image | str) -> list[float]:
         return self.model.encode(image_or_text).tolist()
-
-
-# stubs to allow different behavior between the two in the future
-# and handle loading different image and text clip models
-class CLIPSTVisionEncoder(CLIPSTEncoder):
-    _model_type = ModelType.CLIP_VISION
-
-
-class CLIPSTTextEncoder(CLIPSTEncoder):
-    _model_type = ModelType.CLIP_TEXT
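
The vision/text stubs can go because a single sentence-transformers CLIP model (the ST in CLIPSTEncoder) encodes both images and text through the same encode() call, which is what the surviving predict() does. A rough standalone sketch of that behavior; the checkpoint name and image path are placeholders, not values taken from this commit:

from PIL import Image
from sentence_transformers import SentenceTransformer

# One CLIP model serves both endpoints: images and text map into the same
# embedding space, so separate vision/text encoder classes are unnecessary.
model = SentenceTransformer("clip-ViT-B-32")

image_embedding = model.encode(Image.open("photo.jpg")).tolist()  # placeholder image path
text_embedding = model.encode("a photo of a dog").tolist()

# Both vectors have the same dimensionality (512 for ViT-B/32).
print(len(image_embedding), len(text_embedding))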

@@ -61,6 +61,4 @@ class FaceResponse(BaseModel):
 class ModelType(Enum):
     IMAGE_CLASSIFICATION = "image-classification"
     CLIP = "clip"
-    CLIP_VISION = "clip-vision"
-    CLIP_TEXT = "clip-text"
     FACIAL_RECOGNITION = "facial-recognition"