From 848ba685eb25908e755ea30e54d78eaafbae5693 Mon Sep 17 00:00:00 2001
From: Mert <101130780+mertalev@users.noreply.github.com>
Date: Tue, 11 Jul 2023 13:01:21 -0400
Subject: [PATCH] fix(ml): race condition when loading models (#3207)

* sync model loading; disable model TTL by default

* disable revalidation if model unloading is disabled

* move lock so the cache check happens inside it
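
Previously the cache check happened before the OptimisticLock was
taken, and the model was built through run_in_executor, which suspends
the request mid-load; two concurrent requests could both see a miss
and load the same model. The check now happens under the lock and the
model is built inline, so the load finishes before control returns to
the event loop, and the result is stored with a compare-and-swap
(lock.cas). Revalidation only matters when entries can expire, so it
is now tied to model_ttl > 0, and model_ttl defaults to 0 (models are
never unloaded). A runnable sketch of the resulting pattern follows
the diff.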
---
 machine-learning/app/config.py       |  2 +-
 machine-learning/app/main.py         |  2 +-
 machine-learning/app/models/cache.py | 12 ++++--------
 3 files changed, 6 insertions(+), 10 deletions(-)

diff --git a/machine-learning/app/config.py b/machine-learning/app/config.py
index 70520b27ca..f5cb835953 100644
--- a/machine-learning/app/config.py
+++ b/machine-learning/app/config.py
@@ -13,7 +13,7 @@ class Settings(BaseSettings):
     facial_recognition_model: str = "buffalo_l"
     min_tag_score: float = 0.9
     eager_startup: bool = True
-    model_ttl: int = 300
+    model_ttl: int = 0
     host: str = "0.0.0.0"
     port: int = 3003
     workers: int = 1
diff --git a/machine-learning/app/main.py b/machine-learning/app/main.py
index 35ee27204c..264eb2ee87 100644
--- a/machine-learning/app/main.py
+++ b/machine-learning/app/main.py
@@ -25,7 +25,7 @@ app = FastAPI()
 
 
 def init_state() -> None:
-    app.state.model_cache = ModelCache(ttl=settings.model_ttl, revalidate=True)
+    app.state.model_cache = ModelCache(ttl=settings.model_ttl, revalidate=settings.model_ttl > 0)
 
 
 async def load_models() -> None:
diff --git a/machine-learning/app/models/cache.py b/machine-learning/app/models/cache.py
index 086a57c5ae..b9d5f75a0d 100644
--- a/machine-learning/app/models/cache.py
+++ b/machine-learning/app/models/cache.py
@@ -1,4 +1,3 @@
-import asyncio
 from typing import Any
 
 from aiocache.backends.memory import SimpleMemoryCache
@@ -48,13 +47,10 @@ class ModelCache:
         """
 
         key = self.cache.build_key(model_name, model_type.value)
-        model = await self.cache.get(key)
-        if model is None:
-            async with OptimisticLock(self.cache, key) as lock:
-                model = await asyncio.get_running_loop().run_in_executor(
-                    None,
-                    lambda: InferenceModel.from_model_type(model_type, model_name, **model_kwargs),
-                )
+        async with OptimisticLock(self.cache, key) as lock:
+            model = await self.cache.get(key)
+            if model is None:
+                model = InferenceModel.from_model_type(model_type, model_name, **model_kwargs)
                 await lock.cas(model, ttl=self.ttl)
         return model
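
Note: below is a minimal, runnable sketch of the check-and-load
pattern adopted above, using aiocache's SimpleMemoryCache and
OptimisticLock as in the diff; the load_model stub and main driver are
illustrative only, not code from this repo.

    import asyncio

    from aiocache.backends.memory import SimpleMemoryCache
    from aiocache.lock import OptimisticLock

    cache = SimpleMemoryCache()

    def load_model(name: str) -> str:
        # stand-in for InferenceModel.from_model_type; deliberately
        # synchronous, so nothing else runs on the loop during the load
        return f"model:{name}"

    async def get_model(name: str) -> str:
        async with OptimisticLock(cache, name) as lock:
            model = await cache.get(name)  # miss check is under the lock
            if model is None:
                model = load_model(name)
                # cas writes only if the key is unchanged since the lock
                # was entered; ttl=0 means no expiry in SimpleMemoryCache
                await lock.cas(model, ttl=0)
        return model

    async def main() -> None:
        print(await get_model("buffalo_l"))  # -> model:buffalo_l

    asyncio.run(main())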