From 59300d2097e5c2f8e81b94ea6113750f3d14c9ff Mon Sep 17 00:00:00 2001 From: martabal <74269598+martabal@users.noreply.github.com> Date: Wed, 25 Sep 2024 18:22:54 +0200 Subject: [PATCH] feat: preload textual model --- machine-learning/app/main.py | 13 ++++++---- machine-learning/app/models/cache.py | 7 ++++++ machine-learning/app/schemas.py | 11 --------- open-api/immich-openapi-specs.json | 10 ++------ server/src/config.ts | 2 -- server/src/dtos/model-config.dto.ts | 9 +++---- .../interfaces/machine-learning.interface.ts | 13 +++++----- server/src/repositories/event.repository.ts | 24 +++++++++++++++++-- .../machine-learning.repository.ts | 12 +++++++--- .../machine-learning-settings.svelte | 17 ------------- 10 files changed, 59 insertions(+), 59 deletions(-) diff --git a/machine-learning/app/main.py b/machine-learning/app/main.py index da82c3a586..6d359ec2de 100644 --- a/machine-learning/app/main.py +++ b/machine-learning/app/main.py @@ -28,7 +28,6 @@ from .schemas import ( InferenceEntries, InferenceEntry, InferenceResponse, - LoadModelEntry, MessageResponse, ModelFormat, ModelIdentity, @@ -125,17 +124,16 @@ def get_entries(entries: str = Form()) -> InferenceEntries: raise HTTPException(422, "Invalid request format.") -def get_entry(entries: str = Form()) -> LoadModelEntry: +def get_entry(entries: str = Form()) -> InferenceEntry: try: request: PipelineRequest = orjson.loads(entries) for task, types in request.items(): for type, entry in types.items(): - parsed: LoadModelEntry = { + parsed: InferenceEntry = { "name": entry["modelName"], "task": task, "type": type, "options": entry.get("options", {}), - "ttl": entry["ttl"] if "ttl" in entry else settings.ttl, } return parsed except (orjson.JSONDecodeError, ValidationError, KeyError, AttributeError) as e: @@ -163,6 +161,13 @@ async def load_model(entry: InferenceEntry = Depends(get_entry)) -> None: return Response(status_code=200) +@app.post("/unload", response_model=TextResponse) +async def unload_model(entry: InferenceEntry = Depends(get_entry)) -> None: + await model_cache.unload(entry["name"], entry["type"], entry["task"]) + print("unload") + return Response(status_code=200) + + @app.post("/predict", dependencies=[Depends(update_state)]) async def predict( entries: InferenceEntries = Depends(get_entries), diff --git a/machine-learning/app/models/cache.py b/machine-learning/app/models/cache.py index bf8e8a6352..34c9fd2a41 100644 --- a/machine-learning/app/models/cache.py +++ b/machine-learning/app/models/cache.py @@ -58,3 +58,10 @@ class ModelCache: async def revalidate(self, key: str, ttl: int | None) -> None: if ttl is not None and key in self.cache._handlers: await self.cache.expire(key, ttl) + + async def unload(self, model_name: str, model_type: ModelType, model_task: ModelTask) -> None: + key = f"{model_name}{model_type}{model_task}" + async with OptimisticLock(self.cache, key): + value = await self.cache.get(key) + if value is not None: + await self.cache.delete(key) diff --git a/machine-learning/app/schemas.py b/machine-learning/app/schemas.py index b3cf60add9..f051db12c3 100644 --- a/machine-learning/app/schemas.py +++ b/machine-learning/app/schemas.py @@ -109,17 +109,6 @@ class InferenceEntry(TypedDict): options: dict[str, Any] -class LoadModelEntry(InferenceEntry): - ttl: int - - def __init__(self, name: str, task: ModelTask, type: ModelType, options: dict[str, Any], ttl: int): - super().__init__(name=name, task=task, type=type, options=options) - - if ttl <= 0: - raise ValueError("ttl must be a positive integer") - self.ttl = ttl - - InferenceEntries = tuple[list[InferenceEntry], list[InferenceEntry]] diff --git a/open-api/immich-openapi-specs.json b/open-api/immich-openapi-specs.json index 9bcfd23e83..2d8ebaae67 100644 --- a/open-api/immich-openapi-specs.json +++ b/open-api/immich-openapi-specs.json @@ -5307,8 +5307,8 @@ "name": "password", "required": false, "in": "query", + "example": "password", "schema": { - "example": "password", "type": "string" } }, @@ -9510,16 +9510,10 @@ "properties": { "enabled": { "type": "boolean" - }, - "ttl": { - "format": "int64", - "minimum": 0, - "type": "number" } }, "required": [ - "enabled", - "ttl" + "enabled" ], "type": "object" }, diff --git a/server/src/config.ts b/server/src/config.ts index bdafed40a1..cfc0fcaa32 100644 --- a/server/src/config.ts +++ b/server/src/config.ts @@ -122,7 +122,6 @@ export interface SystemConfig { modelName: string; loadTextualModelOnConnection: { enabled: boolean; - ttl: number; }; }; duplicateDetection: { @@ -276,7 +275,6 @@ export const defaults = Object.freeze({ modelName: 'ViT-B-32__openai', loadTextualModelOnConnection: { enabled: false, - ttl: 300, }, }, duplicateDetection: { diff --git a/server/src/dtos/model-config.dto.ts b/server/src/dtos/model-config.dto.ts index a1320e1a49..0c1630e531 100644 --- a/server/src/dtos/model-config.dto.ts +++ b/server/src/dtos/model-config.dto.ts @@ -14,12 +14,9 @@ export class ModelConfig extends TaskConfig { modelName!: string; } -export class LoadTextualModelOnConnection extends TaskConfig { - @IsNumber() - @Min(0) - @Type(() => Number) - @ApiProperty({ type: 'number', format: 'int64' }) - ttl!: number; +export class LoadTextualModelOnConnection { + @ValidateBoolean() + enabled!: boolean; } export class CLIPConfig extends ModelConfig { diff --git a/server/src/interfaces/machine-learning.interface.ts b/server/src/interfaces/machine-learning.interface.ts index 9c87b323a8..205d69f4f5 100644 --- a/server/src/interfaces/machine-learning.interface.ts +++ b/server/src/interfaces/machine-learning.interface.ts @@ -24,17 +24,13 @@ export type ModelPayload = { imagePath: string } | { text: string }; type ModelOptions = { modelName: string }; -export interface LoadModelOptions extends ModelOptions { - ttl: number; -} - export type FaceDetectionOptions = ModelOptions & { minScore: number }; type VisualResponse = { imageHeight: number; imageWidth: number }; export type ClipVisualRequest = { [ModelTask.SEARCH]: { [ModelType.VISUAL]: ModelOptions } }; export type ClipVisualResponse = { [ModelTask.SEARCH]: number[] } & VisualResponse; -export type ClipTextualRequest = { [ModelTask.SEARCH]: { [ModelType.TEXTUAL]: ModelOptions | LoadModelOptions } }; +export type ClipTextualRequest = { [ModelTask.SEARCH]: { [ModelType.TEXTUAL]: ModelOptions } }; export type ClipTextualResponse = { [ModelTask.SEARCH]: number[] }; export type FacialRecognitionRequest = { @@ -50,6 +46,11 @@ export interface Face { score: number; } +export enum LoadTextModelActions { + LOAD, + UNLOAD, +} + export type FacialRecognitionResponse = { [ModelTask.FACIAL_RECOGNITION]: Face[] } & VisualResponse; export type DetectedFaces = { faces: Face[] } & VisualResponse; export type MachineLearningRequest = ClipVisualRequest | ClipTextualRequest | FacialRecognitionRequest; @@ -58,5 +59,5 @@ export interface IMachineLearningRepository { encodeImage(url: string, imagePath: string, config: ModelOptions): Promise; encodeText(url: string, text: string, config: ModelOptions): Promise; detectFaces(url: string, imagePath: string, config: FaceDetectionOptions): Promise; - loadTextModel(url: string, config: ModelOptions): Promise; + prepareTextModel(url: string, config: ModelOptions, action: LoadTextModelActions): Promise; } diff --git a/server/src/repositories/event.repository.ts b/server/src/repositories/event.repository.ts index 5a5c8ba338..02f1e11907 100644 --- a/server/src/repositories/event.repository.ts +++ b/server/src/repositories/event.repository.ts @@ -20,7 +20,7 @@ import { ServerEventMap, } from 'src/interfaces/event.interface'; import { ILoggerRepository } from 'src/interfaces/logger.interface'; -import { IMachineLearningRepository } from 'src/interfaces/machine-learning.interface'; +import { IMachineLearningRepository, LoadTextModelActions } from 'src/interfaces/machine-learning.interface'; import { ISystemMetadataRepository } from 'src/interfaces/system-metadata.interface'; import { AuthService } from 'src/services/auth.service'; import { Instrumentation } from 'src/utils/instrumentation'; @@ -79,7 +79,12 @@ export class EventRepository implements OnGatewayConnection, OnGatewayDisconnect const { machineLearning } = await this.configCore.getConfig({ withCache: true }); if (machineLearning.clip.loadTextualModelOnConnection.enabled) { try { - this.machineLearningRepository.loadTextModel(machineLearning.url, machineLearning.clip); + console.log(this.server); + this.machineLearningRepository.prepareTextModel( + machineLearning.url, + machineLearning.clip, + LoadTextModelActions.LOAD, + ); } catch (error) { this.logger.warn(error); } @@ -100,6 +105,21 @@ export class EventRepository implements OnGatewayConnection, OnGatewayDisconnect async handleDisconnect(client: Socket) { this.logger.log(`Websocket Disconnect: ${client.id}`); await client.leave(client.nsp.name); + if ('background' in client.handshake.query && client.handshake.query.background === 'false') { + const { machineLearning } = await this.configCore.getConfig({ withCache: true }); + if (machineLearning.clip.loadTextualModelOnConnection.enabled && this.server?.engine.clientsCount == 0) { + try { + this.machineLearningRepository.prepareTextModel( + machineLearning.url, + machineLearning.clip, + LoadTextModelActions.UNLOAD, + ); + this.logger.debug('sent request to unload text model'); + } catch (error) { + this.logger.warn(error); + } + } + } } on(event: T, handler: EmitHandler): void { diff --git a/server/src/repositories/machine-learning.repository.ts b/server/src/repositories/machine-learning.repository.ts index a084d7c770..a65b29fa56 100644 --- a/server/src/repositories/machine-learning.repository.ts +++ b/server/src/repositories/machine-learning.repository.ts @@ -7,6 +7,7 @@ import { FaceDetectionOptions, FacialRecognitionResponse, IMachineLearningRepository, + LoadTextModelActions, MachineLearningRequest, ModelPayload, ModelTask, @@ -38,11 +39,16 @@ export class MachineLearningRepository implements IMachineLearningRepository { return res; } - async loadTextModel(url: string, { modelName, loadTextualModelOnConnection: { ttl } }: CLIPConfig) { + private prepareTextModelUrl: Record = { + [LoadTextModelActions.LOAD]: '/load', + [LoadTextModelActions.UNLOAD]: '/unload', + }; + + async prepareTextModel(url: string, { modelName }: CLIPConfig, actions: LoadTextModelActions) { try { - const request = { [ModelTask.SEARCH]: { [ModelType.TEXTUAL]: { modelName, ttl } } }; + const request = { [ModelTask.SEARCH]: { [ModelType.TEXTUAL]: { modelName } } }; const formData = await this.getFormData(request); - const res = await this.fetchData(url, '/load', formData); + const res = await this.fetchData(url, this.prepareTextModelUrl[actions], formData); if (res.status >= 400) { throw new Error(`${errorPrefix} Loadings textual model failed with status ${res.status}: ${res.statusText}`); } diff --git a/web/src/lib/components/admin-page/settings/machine-learning-settings/machine-learning-settings.svelte b/web/src/lib/components/admin-page/settings/machine-learning-settings/machine-learning-settings.svelte index b94a4bd960..a188cbc67a 100644 --- a/web/src/lib/components/admin-page/settings/machine-learning-settings/machine-learning-settings.svelte +++ b/web/src/lib/components/admin-page/settings/machine-learning-settings/machine-learning-settings.svelte @@ -88,23 +88,6 @@ bind:checked={config.machineLearning.clip.loadTextualModelOnConnection.enabled} disabled={disabled || !config.machineLearning.enabled || !config.machineLearning.clip.enabled} /> - -
- -