mirror of
https://github.com/immich-app/immich.git
synced 2025-01-04 02:46:47 +01:00
fix(ml): limit load retries (#10494)
This commit is contained in:
parent
79a8ab71ef
commit
a42af06889
3 changed files with 26 additions and 11 deletions
|
@ -192,23 +192,18 @@ async def load(model: InferenceModel) -> InferenceModel:
|
||||||
return model
|
return model
|
||||||
|
|
||||||
def _load(model: InferenceModel) -> InferenceModel:
|
def _load(model: InferenceModel) -> InferenceModel:
|
||||||
|
if model.load_attempts > 1:
|
||||||
|
raise HTTPException(500, f"Failed to load model '{model.model_name}'")
|
||||||
with lock:
|
with lock:
|
||||||
model.load()
|
model.load()
|
||||||
return model
|
return model
|
||||||
|
|
||||||
try:
|
try:
|
||||||
await run(_load, model)
|
return await run(_load, model)
|
||||||
return model
|
|
||||||
except (OSError, InvalidProtobuf, BadZipFile, NoSuchFile):
|
except (OSError, InvalidProtobuf, BadZipFile, NoSuchFile):
|
||||||
log.warning(
|
log.warning(f"Failed to load {model.model_type.replace('_', ' ')} model '{model.model_name}'. Clearing cache.")
|
||||||
(
|
|
||||||
f"Failed to load {model.model_type.replace('_', ' ')} model '{model.model_name}'."
|
|
||||||
"Clearing cache and retrying."
|
|
||||||
)
|
|
||||||
)
|
|
||||||
model.clear_cache()
|
model.clear_cache()
|
||||||
await run(_load, model)
|
return await run(_load, model)
|
||||||
return model
|
|
||||||
|
|
||||||
|
|
||||||
async def idle_shutdown_task() -> None:
|
async def idle_shutdown_task() -> None:
|
||||||
|
|
|
@ -31,6 +31,7 @@ class InferenceModel(ABC):
|
||||||
**model_kwargs: Any,
|
**model_kwargs: Any,
|
||||||
) -> None:
|
) -> None:
|
||||||
self.loaded = False
|
self.loaded = False
|
||||||
|
self.load_attempts = 0
|
||||||
self.model_name = clean_name(model_name)
|
self.model_name = clean_name(model_name)
|
||||||
self.cache_dir = Path(cache_dir) if cache_dir is not None else self.cache_dir_default
|
self.cache_dir = Path(cache_dir) if cache_dir is not None else self.cache_dir_default
|
||||||
self.providers = providers if providers is not None else self.providers_default
|
self.providers = providers if providers is not None else self.providers_default
|
||||||
|
@ -48,9 +49,11 @@ class InferenceModel(ABC):
|
||||||
def load(self) -> None:
|
def load(self) -> None:
|
||||||
if self.loaded:
|
if self.loaded:
|
||||||
return
|
return
|
||||||
|
self.load_attempts += 1
|
||||||
|
|
||||||
self.download()
|
self.download()
|
||||||
log.info(f"Loading {self.model_type.replace('-', ' ')} model '{self.model_name}' to memory")
|
attempt = f"Attempt #{self.load_attempts + 1} to load" if self.load_attempts else "Loading"
|
||||||
|
log.info(f"{attempt} {self.model_type.replace('-', ' ')} model '{self.model_name}' to memory")
|
||||||
self.session = self._load()
|
self.session = self._load()
|
||||||
self.loaded = True
|
self.loaded = True
|
||||||
|
|
||||||
|
|
|
@ -11,6 +11,7 @@ import cv2
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import onnxruntime as ort
|
import onnxruntime as ort
|
||||||
import pytest
|
import pytest
|
||||||
|
from fastapi import HTTPException
|
||||||
from fastapi.testclient import TestClient
|
from fastapi.testclient import TestClient
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from pytest import MonkeyPatch
|
from pytest import MonkeyPatch
|
||||||
|
@ -627,6 +628,7 @@ class TestLoad:
|
||||||
async def test_load(self) -> None:
|
async def test_load(self) -> None:
|
||||||
mock_model = mock.Mock(spec=InferenceModel)
|
mock_model = mock.Mock(spec=InferenceModel)
|
||||||
mock_model.loaded = False
|
mock_model.loaded = False
|
||||||
|
mock_model.load_attempts = 0
|
||||||
|
|
||||||
res = await load(mock_model)
|
res = await load(mock_model)
|
||||||
|
|
||||||
|
@ -650,6 +652,7 @@ class TestLoad:
|
||||||
mock_model.model_task = ModelTask.SEARCH
|
mock_model.model_task = ModelTask.SEARCH
|
||||||
mock_model.load.side_effect = [OSError, None]
|
mock_model.load.side_effect = [OSError, None]
|
||||||
mock_model.loaded = False
|
mock_model.loaded = False
|
||||||
|
mock_model.load_attempts = 0
|
||||||
|
|
||||||
res = await load(mock_model)
|
res = await load(mock_model)
|
||||||
|
|
||||||
|
@ -657,6 +660,20 @@ class TestLoad:
|
||||||
mock_model.clear_cache.assert_called_once()
|
mock_model.clear_cache.assert_called_once()
|
||||||
assert mock_model.load.call_count == 2
|
assert mock_model.load.call_count == 2
|
||||||
|
|
||||||
|
async def test_load_clears_cache_and_raises_if_os_error_and_already_retried(self) -> None:
|
||||||
|
mock_model = mock.Mock(spec=InferenceModel)
|
||||||
|
mock_model.model_name = "test_model_name"
|
||||||
|
mock_model.model_type = ModelType.VISUAL
|
||||||
|
mock_model.model_task = ModelTask.SEARCH
|
||||||
|
mock_model.loaded = False
|
||||||
|
mock_model.load_attempts = 2
|
||||||
|
|
||||||
|
with pytest.raises(HTTPException):
|
||||||
|
await load(mock_model)
|
||||||
|
|
||||||
|
mock_model.clear_cache.assert_not_called()
|
||||||
|
mock_model.load.assert_not_called()
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(
|
@pytest.mark.skipif(
|
||||||
not settings.test_full,
|
not settings.test_full,
|
||||||
|
|
Loading…
Reference in a new issue