1
0
Fork 0
mirror of https://github.com/immich-app/immich.git synced 2025-01-27 22:22:45 +01:00

fix(ml): race condition when loading models

* load models synchronously; disable model TTL by default

* disable revalidation if model unloading is disabled

* moved the lock so that it wraps the cache lookup
This commit is contained in:
Mert 2023-07-11 13:01:21 -04:00 committed by GitHub
parent 9ad024c189
commit 848ba685eb
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 6 additions and 10 deletions
machine-learning/app

View file

@ -13,7 +13,7 @@ class Settings(BaseSettings):
facial_recognition_model: str = "buffalo_l" facial_recognition_model: str = "buffalo_l"
min_tag_score: float = 0.9 min_tag_score: float = 0.9
eager_startup: bool = True eager_startup: bool = True
model_ttl: int = 300 model_ttl: int = 0
host: str = "0.0.0.0" host: str = "0.0.0.0"
port: int = 3003 port: int = 3003
workers: int = 1 workers: int = 1

View file

@ -25,7 +25,7 @@ app = FastAPI()
def init_state() -> None: def init_state() -> None:
app.state.model_cache = ModelCache(ttl=settings.model_ttl, revalidate=True) app.state.model_cache = ModelCache(ttl=settings.model_ttl, revalidate=settings.model_ttl > 0)
async def load_models() -> None: async def load_models() -> None:

View file

@ -1,4 +1,3 @@
import asyncio
from typing import Any from typing import Any
from aiocache.backends.memory import SimpleMemoryCache from aiocache.backends.memory import SimpleMemoryCache
@ -48,13 +47,10 @@ class ModelCache:
""" """
key = self.cache.build_key(model_name, model_type.value) key = self.cache.build_key(model_name, model_type.value)
model = await self.cache.get(key) async with OptimisticLock(self.cache, key) as lock:
if model is None: model = await self.cache.get(key)
async with OptimisticLock(self.cache, key) as lock: if model is None:
model = await asyncio.get_running_loop().run_in_executor( model = InferenceModel.from_model_type(model_type, model_name, **model_kwargs)
None,
lambda: InferenceModel.from_model_type(model_type, model_name, **model_kwargs),
)
await lock.cas(model, ttl=self.ttl) await lock.cas(model, ttl=self.ttl)
return model return model