From 4ed75f2ac9be0d59089d029d23316f420a3b3f40 Mon Sep 17 00:00:00 2001 From: Mert <101130780+mertalev@users.noreply.github.com> Date: Sun, 4 Aug 2024 17:00:36 -0400 Subject: [PATCH] refactor(server): add config events for clip (#11575) use config events for clip, add tests formatting --- server/src/interfaces/search.interface.ts | 3 +- server/src/repositories/search.repository.ts | 55 ++--- .../src/services/smart-info.service.spec.ts | 210 ++++++++++++++++++ server/src/services/smart-info.service.ts | 70 ++++-- .../repositories/search.repository.mock.ts | 3 +- 5 files changed, 288 insertions(+), 53 deletions(-) diff --git a/server/src/interfaces/search.interface.ts b/server/src/interfaces/search.interface.ts index c17f833615..d77cd62cd1 100644 --- a/server/src/interfaces/search.interface.ts +++ b/server/src/interfaces/search.interface.ts @@ -170,7 +170,6 @@ export interface AssetDuplicateResult { } export interface ISearchRepository { - init(modelName: string): Promise<void>; searchMetadata(pagination: SearchPaginationOptions, options: AssetSearchOptions): Paginated<AssetEntity>; searchSmart(pagination: SearchPaginationOptions, options: SmartSearchOptions): Paginated<AssetEntity>; searchDuplicates(options: AssetDuplicateSearch): Promise<AssetDuplicateResult[]>; @@ -179,4 +178,6 @@ export interface ISearchRepository { searchPlaces(placeName: string): Promise<GeodataPlacesEntity[]>; getAssetsByCity(userIds: string[]): Promise<AssetEntity[]>; deleteAllSearchEmbeddings(): Promise<void>; + getDimensionSize(): Promise<number>; + setDimensionSize(dimSize: number): Promise<void>; } diff --git a/server/src/repositories/search.repository.ts b/server/src/repositories/search.repository.ts index a4c7edab91..9abe62a12d 100644 --- a/server/src/repositories/search.repository.ts +++ b/server/src/repositories/search.repository.ts @@ -21,7 +21,6 @@ import { } from 'src/interfaces/search.interface'; import { asVector, searchAssetBuilder } from 'src/utils/database'; import { Instrumentation } from 'src/utils/instrumentation'; -import { getCLIPModelInfo } from 'src/utils/misc'; import { Paginated, PaginationMode, PaginationResult, paginatedBuilder } from 'src/utils/pagination'; import { isValidInteger } from 'src/validation'; import { Repository, SelectQueryBuilder } from 'typeorm'; @@ -55,17 +54,6 @@ export class SearchRepository implements ISearchRepository { ' INNER JOIN cte ON asset.id = cte."assetId" ORDER BY exif.city'; } - async init(modelName: string): Promise<void> { - const { dimSize } = getCLIPModelInfo(modelName); - const curDimSize = await this.getDimSize(); - this.logger.verbose(`Current database CLIP dimension size is ${curDimSize}`); - - if (dimSize != curDimSize) { - this.logger.log(`Dimension size of model ${modelName} is ${dimSize}, but database expects ${curDimSize}.`); - await this.updateDimSize(dimSize); - } - } - @GenerateSql({ params: [ { page: 1, size: 100 }, @@ -300,32 +288,7 @@ export class SearchRepository implements ISearchRepository { ); } - private async updateDimSize(dimSize: number): Promise<void> { - if (!isValidInteger(dimSize, { min: 1, max: 2 ** 16 })) { - throw new Error(`Invalid CLIP dimension size: ${dimSize}`); - } - - const curDimSize = await this.getDimSize(); - if (curDimSize === dimSize) { - return; - } - - this.logger.log(`Updating database CLIP dimension size to ${dimSize}.`); - - await this.smartSearchRepository.manager.transaction(async (manager) => { - await manager.clear(SmartSearchEntity); - await manager.query(`ALTER TABLE smart_search ALTER COLUMN embedding SET DATA TYPE vector(${dimSize})`); - await manager.query(`REINDEX INDEX clip_index`); - }); - - this.logger.log(`Successfully updated database CLIP dimension size from ${curDimSize} to ${dimSize}.`); - } - - deleteAllSearchEmbeddings(): Promise<void> { - return this.smartSearchRepository.clear(); - } - - private async getDimSize(): Promise<number> { + async getDimensionSize(): Promise<number> { const res = await this.smartSearchRepository.manager.query(` SELECT atttypmod as dimsize FROM pg_attribute f @@ -342,6 +305,22 @@ export class SearchRepository implements ISearchRepository { return dimSize; } + setDimensionSize(dimSize: number): Promise<void> { + if (!isValidInteger(dimSize, { min: 1, max: 2 ** 16 })) { + throw new Error(`Invalid CLIP dimension size: ${dimSize}`); + } + + return this.smartSearchRepository.manager.transaction(async (manager) => { + await manager.clear(SmartSearchEntity); + await manager.query(`ALTER TABLE smart_search ALTER COLUMN embedding SET DATA TYPE vector(${dimSize})`); + await manager.query(`REINDEX INDEX clip_index`); + }); + } + + async deleteAllSearchEmbeddings(): Promise<void> { + return this.smartSearchRepository.clear(); + } + private getRuntimeConfig(numResults?: number): string { if (getVectorExtension() === DatabaseExtension.VECTOR) { return 'SET LOCAL hnsw.ef_search = 1000;'; // mitigate post-filter recall diff --git a/server/src/services/smart-info.service.spec.ts b/server/src/services/smart-info.service.spec.ts index 95f76edc49..f18dc91ff1 100644 --- a/server/src/services/smart-info.service.spec.ts +++ b/server/src/services/smart-info.service.spec.ts @@ -1,3 +1,4 @@ +import { SystemConfig } from 'src/config'; import { IAssetRepository, WithoutProperty } from 'src/interfaces/asset.interface'; import { IDatabaseRepository } from 'src/interfaces/database.interface'; import { IJobRepository, JobName, JobStatus } from 'src/interfaces/job.interface'; @@ -45,6 +46,215 @@ describe(SmartInfoService.name, () => { expect(sut).toBeDefined(); }); + describe('onConfigValidateEvent', () => { + it('should allow a valid model', () => { + expect(() => + sut.onConfigValidateEvent({ + newConfig: { machineLearning: { clip: { modelName: 'ViT-B-16__openai' } } } as SystemConfig, + oldConfig: {} as SystemConfig, + }), + ).not.toThrow(); + }); + + it('should allow including organization', () => { + expect(() => + sut.onConfigValidateEvent({ + newConfig: { machineLearning: { clip: { modelName: 'immich-app/ViT-B-16__openai' } } } as SystemConfig, + oldConfig: {} as SystemConfig, + }), + ).not.toThrow(); + }); + + it('should fail for an unsupported model', () => { + expect(() => + sut.onConfigValidateEvent({ + newConfig: { machineLearning: { clip: { modelName: 'test-model' } } } as SystemConfig, + oldConfig: {} as SystemConfig, + }), + ).toThrow('Unknown CLIP model: test-model'); + }); + }); + + describe('onBootstrapEvent', () => { + it('should return if not microservices', async () => { + await sut.onBootstrapEvent('api'); + + expect(systemMock.get).not.toHaveBeenCalled(); + expect(searchMock.getDimensionSize).not.toHaveBeenCalled(); + expect(searchMock.setDimensionSize).not.toHaveBeenCalled(); + expect(searchMock.deleteAllSearchEmbeddings).not.toHaveBeenCalled(); + expect(jobMock.getQueueStatus).not.toHaveBeenCalled(); + expect(jobMock.pause).not.toHaveBeenCalled(); + expect(jobMock.waitForQueueCompletion).not.toHaveBeenCalled(); + expect(jobMock.resume).not.toHaveBeenCalled(); + }); + + it('should return if machine learning is disabled', async () => { + systemMock.get.mockResolvedValue(systemConfigStub.machineLearningDisabled); + + await sut.onBootstrapEvent('microservices'); + + expect(systemMock.get).toHaveBeenCalledTimes(1); + expect(searchMock.getDimensionSize).not.toHaveBeenCalled(); + expect(searchMock.setDimensionSize).not.toHaveBeenCalled(); + expect(searchMock.deleteAllSearchEmbeddings).not.toHaveBeenCalled(); + expect(jobMock.getQueueStatus).not.toHaveBeenCalled(); + expect(jobMock.pause).not.toHaveBeenCalled(); + expect(jobMock.waitForQueueCompletion).not.toHaveBeenCalled(); + expect(jobMock.resume).not.toHaveBeenCalled(); + }); + + it('should return if model and DB dimension size are equal', async () => { + searchMock.getDimensionSize.mockResolvedValue(512); + + await sut.onBootstrapEvent('microservices'); + + expect(systemMock.get).toHaveBeenCalledTimes(1); + expect(searchMock.getDimensionSize).toHaveBeenCalledTimes(1); + expect(searchMock.setDimensionSize).not.toHaveBeenCalled(); + expect(searchMock.deleteAllSearchEmbeddings).not.toHaveBeenCalled(); + expect(jobMock.getQueueStatus).not.toHaveBeenCalled(); + expect(jobMock.pause).not.toHaveBeenCalled(); + expect(jobMock.waitForQueueCompletion).not.toHaveBeenCalled(); + expect(jobMock.resume).not.toHaveBeenCalled(); + }); + + it('should update DB dimension size if model and DB have different values', async () => { + searchMock.getDimensionSize.mockResolvedValue(768); + jobMock.getQueueStatus.mockResolvedValue({ isActive: false, isPaused: false }); + + await sut.onBootstrapEvent('microservices'); + + expect(systemMock.get).toHaveBeenCalledTimes(1); + expect(searchMock.getDimensionSize).toHaveBeenCalledTimes(1); + expect(searchMock.setDimensionSize).toHaveBeenCalledWith(512); + expect(jobMock.getQueueStatus).toHaveBeenCalledTimes(1); + expect(jobMock.pause).toHaveBeenCalledTimes(1); + expect(jobMock.waitForQueueCompletion).toHaveBeenCalledTimes(1); + expect(jobMock.resume).toHaveBeenCalledTimes(1); + }); + + it('should skip pausing and resuming queue if already paused', async () => { + searchMock.getDimensionSize.mockResolvedValue(768); + jobMock.getQueueStatus.mockResolvedValue({ isActive: false, isPaused: true }); + + await sut.onBootstrapEvent('microservices'); + + expect(systemMock.get).toHaveBeenCalledTimes(1); + expect(searchMock.getDimensionSize).toHaveBeenCalledTimes(1); + expect(searchMock.setDimensionSize).toHaveBeenCalledWith(512); + expect(jobMock.getQueueStatus).toHaveBeenCalledTimes(1); + expect(jobMock.pause).not.toHaveBeenCalled(); + expect(jobMock.waitForQueueCompletion).toHaveBeenCalledTimes(1); + expect(jobMock.resume).not.toHaveBeenCalled(); + }); + }); + + describe('onConfigUpdateEvent', () => { + it('should return if machine learning is disabled', async () => { + systemMock.get.mockResolvedValue(systemConfigStub.machineLearningDisabled); + + await sut.onConfigUpdateEvent({ + newConfig: systemConfigStub.machineLearningDisabled as SystemConfig, + oldConfig: systemConfigStub.machineLearningDisabled as SystemConfig, + }); + + expect(systemMock.get).not.toHaveBeenCalled(); + expect(searchMock.getDimensionSize).not.toHaveBeenCalled(); + expect(searchMock.setDimensionSize).not.toHaveBeenCalled(); + expect(searchMock.deleteAllSearchEmbeddings).not.toHaveBeenCalled(); + expect(jobMock.getQueueStatus).not.toHaveBeenCalled(); + expect(jobMock.pause).not.toHaveBeenCalled(); + expect(jobMock.waitForQueueCompletion).not.toHaveBeenCalled(); + expect(jobMock.resume).not.toHaveBeenCalled(); + }); + + it('should return if model and DB dimension size are equal', async () => { + searchMock.getDimensionSize.mockResolvedValue(512); + + await sut.onConfigUpdateEvent({ + newConfig: { + machineLearning: { clip: { modelName: 'ViT-B-16__openai', enabled: true }, enabled: true }, + } as SystemConfig, + oldConfig: { + machineLearning: { clip: { modelName: 'ViT-B-16__openai', enabled: true }, enabled: true }, + } as SystemConfig, + }); + + expect(searchMock.getDimensionSize).toHaveBeenCalledTimes(1); + expect(searchMock.setDimensionSize).not.toHaveBeenCalled(); + expect(searchMock.deleteAllSearchEmbeddings).not.toHaveBeenCalled(); + expect(jobMock.getQueueStatus).not.toHaveBeenCalled(); + expect(jobMock.pause).not.toHaveBeenCalled(); + expect(jobMock.waitForQueueCompletion).not.toHaveBeenCalled(); + expect(jobMock.resume).not.toHaveBeenCalled(); + }); + + it('should update DB dimension size if model and DB have different values', async () => { + searchMock.getDimensionSize.mockResolvedValue(512); + jobMock.getQueueStatus.mockResolvedValue({ isActive: false, isPaused: false }); + + await sut.onConfigUpdateEvent({ + newConfig: { + machineLearning: { clip: { modelName: 'ViT-L-14-quickgelu__dfn2b', enabled: true }, enabled: true }, + } as SystemConfig, + oldConfig: { + machineLearning: { clip: { modelName: 'ViT-B-16__openai', enabled: true }, enabled: true }, + } as SystemConfig, + }); + + expect(searchMock.getDimensionSize).toHaveBeenCalledTimes(1); + expect(searchMock.setDimensionSize).toHaveBeenCalledWith(768); + expect(jobMock.getQueueStatus).toHaveBeenCalledTimes(1); + expect(jobMock.pause).toHaveBeenCalledTimes(1); + expect(jobMock.waitForQueueCompletion).toHaveBeenCalledTimes(1); + expect(jobMock.resume).toHaveBeenCalledTimes(1); + }); + + it('should clear embeddings if old and new models are different', async () => { + searchMock.getDimensionSize.mockResolvedValue(512); + jobMock.getQueueStatus.mockResolvedValue({ isActive: false, isPaused: false }); + + await sut.onConfigUpdateEvent({ + newConfig: { + machineLearning: { clip: { modelName: 'ViT-B-32__openai', enabled: true }, enabled: true }, + } as SystemConfig, + oldConfig: { + machineLearning: { clip: { modelName: 'ViT-B-16__openai', enabled: true }, enabled: true }, + } as SystemConfig, + }); + + expect(searchMock.deleteAllSearchEmbeddings).toHaveBeenCalled(); + expect(searchMock.getDimensionSize).toHaveBeenCalledTimes(1); + expect(searchMock.setDimensionSize).not.toHaveBeenCalled(); + expect(jobMock.getQueueStatus).toHaveBeenCalledTimes(1); + expect(jobMock.pause).toHaveBeenCalledTimes(1); + expect(jobMock.waitForQueueCompletion).toHaveBeenCalledTimes(1); + expect(jobMock.resume).toHaveBeenCalledTimes(1); + }); + + it('should skip pausing and resuming queue if already paused', async () => { + searchMock.getDimensionSize.mockResolvedValue(512); + jobMock.getQueueStatus.mockResolvedValue({ isActive: false, isPaused: true }); + + await sut.onConfigUpdateEvent({ + newConfig: { + machineLearning: { clip: { modelName: 'ViT-B-32__openai', enabled: true }, enabled: true }, + } as SystemConfig, + oldConfig: { + machineLearning: { clip: { modelName: 'ViT-B-16__openai', enabled: true }, enabled: true }, + } as SystemConfig, + }); + + expect(searchMock.getDimensionSize).toHaveBeenCalledTimes(1); + expect(searchMock.setDimensionSize).not.toHaveBeenCalled(); + expect(jobMock.getQueueStatus).toHaveBeenCalledTimes(1); + expect(jobMock.pause).not.toHaveBeenCalled(); + expect(jobMock.waitForQueueCompletion).toHaveBeenCalledTimes(1); + expect(jobMock.resume).not.toHaveBeenCalled(); + }); + }); + describe('handleQueueEncodeClip', () => { it('should do nothing if machine learning is disabled', async () => { systemMock.get.mockResolvedValue(systemConfigStub.machineLearningDisabled); diff --git a/server/src/services/smart-info.service.ts b/server/src/services/smart-info.service.ts index 72372470de..1957f3885c 100644 --- a/server/src/services/smart-info.service.ts +++ b/server/src/services/smart-info.service.ts @@ -1,4 +1,5 @@ import { Inject, Injectable } from '@nestjs/common'; +import { SystemConfig } from 'src/config'; import { SystemConfigCore } from 'src/cores/system-config.core'; import { IAssetRepository, WithoutProperty } from 'src/interfaces/asset.interface'; import { DatabaseLock, IDatabaseRepository } from 'src/interfaces/database.interface'; @@ -16,7 +17,7 @@ import { ILoggerRepository } from 'src/interfaces/logger.interface'; import { IMachineLearningRepository } from 'src/interfaces/machine-learning.interface'; import { ISearchRepository } from 'src/interfaces/search.interface'; import { ISystemMetadataRepository } from 'src/interfaces/system-metadata.interface'; -import { isSmartSearchEnabled } from 'src/utils/misc'; +import { getCLIPModelInfo, isSmartSearchEnabled } from 'src/utils/misc'; import { usePagination } from 'src/utils/pagination'; @Injectable() @@ -36,24 +37,67 @@ export class SmartInfoService implements OnEvents { this.configCore = SystemConfigCore.create(systemMetadataRepository, this.logger); } - async init() { - await this.jobRepository.pause(QueueName.SMART_SEARCH); + async onBootstrapEvent(app: 'api' | 'microservices') { + if (app !== 'microservices') { + return; + } - await this.jobRepository.waitForQueueCompletion(QueueName.SMART_SEARCH); + const config = await this.configCore.getConfig({ withCache: false }); + await this.init(config); + } - const { machineLearning } = await this.configCore.getConfig({ withCache: false }); - - await this.databaseRepository.withLock(DatabaseLock.CLIPDimSize, () => - this.repository.init(machineLearning.clip.modelName), - ); - - await this.jobRepository.resume(QueueName.SMART_SEARCH); + onConfigValidateEvent({ newConfig }: SystemConfigUpdateEvent) { + try { + getCLIPModelInfo(newConfig.machineLearning.clip.modelName); + } catch { + throw new Error( + `Unknown CLIP model: ${newConfig.machineLearning.clip.modelName}. Please check the model name for typos and confirm this is a supported model.`, + ); + } } async onConfigUpdateEvent({ oldConfig, newConfig }: SystemConfigUpdateEvent) { - if (oldConfig.machineLearning.clip.modelName !== newConfig.machineLearning.clip.modelName) { - await this.repository.init(newConfig.machineLearning.clip.modelName); + await this.init(newConfig, oldConfig); + } + + private async init(newConfig: SystemConfig, oldConfig?: SystemConfig) { + if (!isSmartSearchEnabled(newConfig.machineLearning)) { + return; } + + await this.databaseRepository.withLock(DatabaseLock.CLIPDimSize, async () => { + const { dimSize } = getCLIPModelInfo(newConfig.machineLearning.clip.modelName); + const dbDimSize = await this.repository.getDimensionSize(); + this.logger.verbose(`Current database CLIP dimension size is ${dbDimSize}`); + + const modelChange = + oldConfig && oldConfig.machineLearning.clip.modelName !== newConfig.machineLearning.clip.modelName; + const dimSizeChange = dbDimSize !== dimSize; + if (!modelChange && !dimSizeChange) { + return; + } + + const { isPaused } = await this.jobRepository.getQueueStatus(QueueName.SMART_SEARCH); + if (!isPaused) { + await this.jobRepository.pause(QueueName.SMART_SEARCH); + } + await this.jobRepository.waitForQueueCompletion(QueueName.SMART_SEARCH); + + if (dimSizeChange) { + this.logger.log( + `Dimension size of model ${newConfig.machineLearning.clip.modelName} is ${dimSize}, but database expects ${dbDimSize}.`, + ); + this.logger.log(`Updating database CLIP dimension size to ${dimSize}.`); + await this.repository.setDimensionSize(dimSize); + this.logger.log(`Successfully updated database CLIP dimension size from ${dbDimSize} to ${dimSize}.`); + } else { + await this.repository.deleteAllSearchEmbeddings(); + } + + if (!isPaused) { + await this.jobRepository.resume(QueueName.SMART_SEARCH); + } + }); } async handleQueueEncodeClip({ force }: IBaseJob): Promise<JobStatus> { diff --git a/server/test/repositories/search.repository.mock.ts b/server/test/repositories/search.repository.mock.ts index 7da93e02af..fd244c6f5c 100644 --- a/server/test/repositories/search.repository.mock.ts +++ b/server/test/repositories/search.repository.mock.ts @@ -3,7 +3,6 @@ import { Mocked, vitest } from 'vitest'; export const newSearchRepositoryMock = (): Mocked<ISearchRepository> => { return { - init: vitest.fn(), searchMetadata: vitest.fn(), searchSmart: vitest.fn(), searchDuplicates: vitest.fn(), @@ -12,5 +11,7 @@ export const newSearchRepositoryMock = (): Mocked<ISearchRepository> => { searchPlaces: vitest.fn(), getAssetsByCity: vitest.fn(), deleteAllSearchEmbeddings: vitest.fn(), + getDimensionSize: vitest.fn(), + setDimensionSize: vitest.fn(), }; };